"""
Contains the code related to face recognition/verification.
"""
from typing import Union, Tuple, Optional
import torch
import torchvision.transforms as T
import numpy as np
from torch import nn
from facenet_pytorch import MTCNN, InceptionResnetV1
from PIL import Image
import matplotlib.pyplot as plt
from scipy.spatial import distance
from .vgg import VGG_16
from . import utils
# Image Types
import PIL.PngImagePlugin
import PIL.JpegImagePlugin
class DefaultMTCNN(MTCNN):
    """MTCNN face detector preconfigured with a fixed set of parameters.

    Not meant to be called directly.
    """

    def __init__(self):
        # Fixed detector configuration, forwarded unchanged to ``MTCNN``.
        settings = {
            "image_size": 224,
            "margin": 0,
            "min_face_size": 20,
            "thresholds": [0.6, 0.7, 0.7],
            "factor": 0.709,
            "post_process": True,
            "device": "cpu",
        }
        super().__init__(**settings)
def get_model(model_name: str = "vgg") -> torch.nn.Module:
    """Build one of the supported embedding models, switched to eval mode.

    Args:
        model_name (str, optional): Case-insensitive model name, either
            "vgg" or "inception". Defaults to "vgg".

    Returns:
        nn.Module: Pytorch model set to eval

    Raises:
        ValueError: If the lower-cased name is not a supported model.
    """
    utils.type_check(model_name, (str,))

    available_model_names = ("vgg", "inception")
    model_name = model_name.lower()
    if model_name not in available_model_names:
        raise ValueError(f"Expected {model_name} to be one of {available_model_names}")

    if model_name == "inception":
        return InceptionResnetV1(pretrained="casia-webface").eval()

    # Only "vgg" is left at this point.
    network = VGG_16()
    network.load_weights()
    network.eval()
    return network
def load_image(image: Union[str, np.ndarray, Image.Image]) -> np.ndarray:
    """Load an image given as a path, PIL image or numpy array.

    Args:
        image (Union[str, np.ndarray, Image.Image]): Path to an image file,
            an already opened PIL image, or a numpy array.

    Returns:
        np.ndarray: Numpy arrays are returned unchanged. Paths and PIL images
            are converted to RGB and then to a numpy array.
    """
    types = (
        str,
        np.ndarray,
        Image.Image,
        PIL.PngImagePlugin.PngImageFile,
        PIL.JpegImagePlugin.JpegImageFile,
    )
    utils.type_check(image, types)

    if isinstance(image, np.ndarray):
        return image

    # isinstance (rather than an exact type comparison) so that str
    # subclasses are also treated as paths.
    if isinstance(image, str):
        # The context manager closes the underlying file deterministically;
        # convert() returns a fully loaded copy, so the data outlives it.
        with Image.open(image) as opened:
            return np.array(opened.convert("RGB"))

    return np.array(image.convert("RGB"))
def verify(
    img1: Union[Image.Image, str, np.ndarray],
    img2: Union[Image.Image, str, np.ndarray],
    model: nn.Module,
    mtcnn: MTCNN = DefaultMTCNN(),
    dist_fn: str = "euclidean",
    plot: bool = True,
    decision: bool = True,
    threshold: float = 0.56,
) -> Union[Tuple[bool, float], float]:
    """Compares two images by the distance between their embeddings

    Args:
        img1 (Union[Image.Image, str, np.ndarray]): PIL image, path or array
        img2 (Union[Image.Image, str, np.ndarray]): PIL image, path or array
        model (nn.Module): A pytorch model.
        mtcnn (MTCNN, optional): User defined MTCNN. Defaults to a
            DefaultMTCNN instance that is created once, when this function
            is defined, and shared between calls.
        dist_fn (str, optional): One of "euclidean", "cosine" or "both"
            (the mean of the two distances). Defaults to "euclidean".
        plot (bool, optional): Show both images side by side with
            matplotlib. Defaults to True.
        decision (bool, optional): If True, also return whether the distance
            is below ``threshold``. Defaults to True.
        threshold (float, optional): Cut-off used for the decision. The same
            value is applied whichever ``dist_fn`` is selected.
            Defaults to 0.56.

    Returns:
        Union[Tuple[bool, float], float]: ``(dist < threshold, dist)`` when
            ``decision`` is True, otherwise only the distance.

    Raises:
        ValueError: If ``dist_fn`` is not a supported distance name.
    """
    img_types = (str, Image.Image, np.ndarray)
    utils.type_check(img1, img_types)
    utils.type_check(img2, img_types)

    # Reject an unsupported metric before doing any expensive work; without
    # this check the function would fail later with an unbound ``dist``.
    available_dist_fns = ("euclidean", "cosine", "both")
    if dist_fn not in available_dist_fns:
        raise ValueError(f"Expected {dist_fn} to be one of {available_dist_fns}")

    img1 = load_image(img1)
    img2 = load_image(img2)

    img1_embedding = get_embedding(img1, model, mtcnn)
    img2_embedding = get_embedding(img2, model, mtcnn)

    # get_embedding feeds the model a batch of one image; flatten the result
    # because the scipy distance functions expect 1-D vectors.
    img1_embedding_np = torch_to_np(img1_embedding).ravel()
    img2_embedding_np = torch_to_np(img2_embedding).ravel()

    if plot:
        fig = plt.figure()
        ax1 = fig.add_subplot(121)
        display_image(img1, ax1)
        ax2 = fig.add_subplot(122)
        display_image(img2, ax2)
        plt.show()

    if dist_fn == "euclidean":
        dist = distance.euclidean(img1_embedding_np, img2_embedding_np)
    elif dist_fn == "cosine":
        dist = distance.cosine(img1_embedding_np, img2_embedding_np)
    else:
        # "both": average of the euclidean and cosine distances.
        dist1 = distance.euclidean(img1_embedding_np, img2_embedding_np)
        dist2 = distance.cosine(img1_embedding_np, img2_embedding_np)
        dist = np.array([dist1, dist2]).mean()

    if decision:
        return dist < threshold, dist
    return dist
def get_embedding(
    image: Union[np.ndarray, Image.Image], model, mtcnn=DefaultMTCNN()
) -> torch.Tensor:
    """Runs an image through the model after preprocessing

    The face returned by ``extract_face`` is resized to a square of
    ``mtcnn.image_size`` pixels, converted to a tensor and passed to the
    model as a batch of one, all with gradient tracking disabled.

    Args:
        image (Union[np.ndarray, Image.Image]): Numpy array or PIL image
        model (nn.Module): Pytorch model
        mtcnn (optional): Detector handed to ``extract_face``; its
            ``image_size`` sets the resize target. Defaults to a
            DefaultMTCNN instance created when this function is defined.

    Returns:
        torch.Tensor: Output of ``model`` for the single-image batch.
    """
    utils.type_check(image, (np.ndarray, Image.Image))

    side = mtcnn.image_size
    preprocess = T.Compose([T.Resize([side, side]), T.ToTensor()])

    with torch.no_grad():
        face = extract_face(image, mtcnn)
        batch = preprocess(face).unsqueeze(0)
        return model(batch)
def display_image(img: Union[Image.Image, np.ndarray], ax: Optional = None) -> None:
    """Draws an image with matplotlib, optionally on a given axes

    Args:
        img (Union[Image.Image, np.ndarray]): Image to be displayed. PIL
            images are converted to a numpy array before drawing.
        ax ([Axes], optional): Axes to draw on, so that several images can be
            stacked in one figure. When omitted the image is drawn through
            ``plt.imshow``. Defaults to None.
    """
    utils.type_check(img, (Image.Image, np.ndarray))
    utils.check_image_channels(img)

    pixels = np.array(img) if isinstance(img, Image.Image) else img
    # Draw on the supplied axes when there is one, else through pyplot.
    target = ax if ax else plt
    target.imshow(pixels)
def torch_to_np(array: torch.Tensor) -> np.ndarray:
    """Converts a torch tensor to a numpy array, whatever device it lives on.

    Args:
        array (torch.Tensor): Pytorch Tensor

    Returns:
        np.ndarray: The tensor's data, detached from the autograd graph.
    """
    types = (torch.Tensor,)
    utils.type_check(array, types)
    # Detach first so the device copy is not tracked by autograd. cpu()
    # returns the tensor itself when it is already in host memory, so calling
    # it unconditionally also covers non-CUDA accelerator devices.
    return array.detach().cpu().numpy()