"""
Contains the code related for face recogniction/verification using on a
database.
"""
from typing import Tuple, Optional, Union
import sqlite3
import numpy as np
import torch
[docs]class Database:
"""
A simple wrapper for sqlite database
"""
def __init__(self, path):
self.path = path
[docs] def create(self, table_name="tt") -> None:
"""Create a new table. The table name itself is not very important since we are not doing any relational searches.
Args:
table_name (str, optional): Defaults to "tt".
Returns:
None
"""
query = "CREATE TABLE ? (name text, value blob)"
q_new = query.replace("?", table_name)
self._execute(q_new, ())
return None
[docs] def insert(self, name: str, value: bytes, log: bool = False):
"""Inserts values into the database.
Args:
name (str): Name of the person.
value (bytes): 512 dimensional embeddings as bytes.
log (bool, optional): Logs to console. Defaults to False.
"""
self._execute("INSERT INTO tt VALUES (?, ?)",
(name, value), fetch=False)
def __iter__(self):
conn = sqlite3.connect(self.path)
c = conn.cursor()
c.execute("Select * From tt")
conn.commit()
value = c.fetchall()
conn.close()
for x in value:
yield (x)
[docs] def get_value(self, name: str) -> bytes:
"""Returns the embedding associated with the name.
Args:
name (str): Name of the person in the database.
Returns:
bytes: Bytes type embedding.
"""
value = self._execute(
"Select name, value From tt where name=(?)", (name,), fetch=True
)
return value
def _execute(
self, query, query_params, fetch=False
) -> Optional[Union[Tuple[str, float], None]]:
"""Executes the given query with query params, optionally return the output.
Returns:
(name, value) or None
"""
value = None
conn = sqlite3.connect(self.path)
c = conn.cursor()
c.execute(query, query_params)
if fetch:
value = c.fetchone()
conn.commit()
conn.close()
return value
[docs] def delete(self, name: str) -> None:
"""Deletes the matching name from the database
Args:
name (str): name to match
"""
self._execute("DELETE FROM tt WHERE name = ?", (name,), fetch=False)
[docs] def update(self, name: str, value: bytes) -> None:
"""Updates the embedding associated with the name.
Args:
name (str): Name of the person in the database.
value (bytes): Embedding as bytes
"""
self._execute(
"UPDATE tt SET value = ? WHERE name = ?", (value, name), fetch=False
)
[docs] def drop_table(self, table: str) -> None:
"""Drops the table
Args:
table (str): Name of the table.
"""
query = "DROP TABLE ?".replace("?", table)
self._execute(query, (), fetch=False)
[docs]def torch_to_np(array: torch.Tensor) -> np.ndarray:
"""Coverts torch tensor to numpy array, handles the case when torch tensor is
stuck in cuda.
Args:
array (torch.Tensor): Pytorch Tensor
Returns:
np.ndarray
"""
if array.is_cuda:
array = array.cpu()
return array.detach().numpy()
[docs]def load_array(byte: bytes) -> np.ndarray:
"""Loads the given bytes array in a ndarray
Args:
byte (bytes): Bytes array
Returns:
np.ndarray: Numpy array.
"""
return np.ndarray(shape=(1, 512), dtype=np.float32, buffer=byte)
[docs]def compare_embedding(
embedding: Union[torch.Tensor, np.ndarray], db: Database
) -> Union[Tuple[str, float], None]:
"""Compares the given embedding with every embedding in db. Returns the first
matched name and distance if any is found.
Args:
embedding (Union[torch.Tensor, np.ndarray]): Embedding vector, either numpy or torch
db (Database): Database to search the images
Returns:
Union[Tuple[str, float], None]: Name, distance pair if matched.
"""
if isinstance(embedding, torch.Tensor):
embedding = torch_to_np(embedding)
for name, em in db:
arr = load_array(em)
print(arr.shape)
diff = arr - embedding
dist = np.sqrt(np.einsum("ij,ij->j", diff, diff))[0]
print(dist)
if dist < 0.56:
return name, dist
print("Not found")