from tqdm import tqdm, trange
from typing import cast, Any, List, Union, Optional
import torch
import numpy as np
from transformers import AutoModel, AutoTokenizer
from FlagEmbedding.abc.inference import AbsEmbedder
class BaseEmbedder(AbsEmbedder):
"""
Base embedder for encoder-only models.
Args:
model_name_or_path (str): If it is a path to a local model, the model is loaded from that path. Otherwise, the model is
downloaded from the HuggingFace Hub under that name and then loaded.
normalize_embeddings (bool, optional): If True, normalize the embedding vectors. Defaults to :data:`True`.
use_fp16 (bool, optional): If True, use half-precision floating point (fp16) to speed up computation, at the cost of a
slight loss in accuracy. Defaults to :data:`True`.
query_instruction_for_retrieval (Optional[str], optional): Query instruction for retrieval tasks, which will be used with
:attr:`query_instruction_format`. Defaults to :data:`None`.
query_instruction_format (str, optional): The template for :attr:`query_instruction_for_retrieval`. Defaults to :data:`"{}{}"`.
devices (Optional[Union[str, int, List[str], List[int]]], optional): Devices to use for model inference. Defaults to :data:`None`.
pooling_method (str, optional): Pooling method to get embedding vector from the last hidden state. Defaults to :data:`"cls"`.
trust_remote_code (bool, optional): trust_remote_code for HF datasets or models. Defaults to :data:`False`.
cache_dir (Optional[str], optional): Cache directory for the model. Defaults to :data:`None`.
batch_size (int, optional): Batch size for inference. Defaults to :data:`256`.
query_max_length (int, optional): Maximum length for query. Defaults to :data:`512`.
passage_max_length (int, optional): Maximum length for passage. Defaults to :data:`512`.
convert_to_numpy (bool, optional): If True, the output embedding will be a Numpy array. Otherwise, it will be a Torch Tensor.
Defaults to :data:`True`.
Attributes:
DEFAULT_POOLING_METHOD: The default pooling method when running the model.
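Examples:
    A minimal usage sketch (illustrative only; the checkpoint name is just an example of an
    encoder-only embedding model on the HuggingFace Hub)::

        embedder = BaseEmbedder("BAAI/bge-base-en-v1.5", use_fp16=False)
        embeddings = embedder.encode(["hello world", "an example sentence"])
        # with convert_to_numpy=True (the default), `embeddings` is a
        # numpy array of shape (num_sentences, hidden_size)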
"""
DEFAULT_POOLING_METHOD = "cls"
def __init__(
self,
model_name_or_path: str,
normalize_embeddings: bool = True,
use_fp16: bool = True,
query_instruction_for_retrieval: Optional[str] = None,
query_instruction_format: str = "{}{}", # specify the format of query_instruction_for_retrieval
devices: Optional[Union[str, List[str]]] = None, # specify devices, such as "cuda:0" or ["cuda:0", "cuda:1"]
# Additional parameters for BaseEmbedder
pooling_method: str = "cls",
trust_remote_code: bool = False,
cache_dir: Optional[str] = None,
# inference
batch_size: int = 256,
query_max_length: int = 512,
passage_max_length: int = 512,
convert_to_numpy: bool = True,
**kwargs: Any,
):
super().__init__(
model_name_or_path,
normalize_embeddings=normalize_embeddings,
use_fp16=use_fp16,
query_instruction_for_retrieval=query_instruction_for_retrieval,
query_instruction_format=query_instruction_format,
devices=devices,
batch_size=batch_size,
query_max_length=query_max_length,
passage_max_length=passage_max_length,
convert_to_numpy=convert_to_numpy,
**kwargs
)
self.pooling_method = pooling_method
self.tokenizer = AutoTokenizer.from_pretrained(
model_name_or_path,
trust_remote_code=trust_remote_code,
cache_dir=cache_dir
)
self.model = AutoModel.from_pretrained(
model_name_or_path,
trust_remote_code=trust_remote_code,
cache_dir=cache_dir
)
def encode_queries(
self,
queries: Union[List[str], str],
batch_size: Optional[int] = None,
max_length: Optional[int] = None,
convert_to_numpy: Optional[bool] = None,
**kwargs: Any
) -> Union[np.ndarray, torch.Tensor]:
"""Encode the queries.
Args:
queries (Union[List[str], str]): Input queries to encode.
batch_size (Optional[int], optional): Number of sentences per iteration. Defaults to :data:`None`.
max_length (Optional[int], optional): Maximum length of tokens. Defaults to :data:`None`.
convert_to_numpy (Optional[bool], optional): If True, the output embedding will be a Numpy array. Otherwise, it will
be a Torch Tensor. Defaults to :data:`None`.
Returns:
Union[torch.Tensor, np.ndarray]: Return the embedding vectors in a numpy array or tensor.
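Examples:
    Illustrative sketch only; it assumes ``query_instruction_for_retrieval`` was set when the
    embedder was created, so each query is wrapped with that instruction (via
    :attr:`query_instruction_format`) before encoding::

        query_embeddings = embedder.encode_queries(
            ["what is a dense embedding?"],
            batch_size=16,
        )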
"""
return super().encode_queries(
queries,
batch_size=batch_size,
max_length=max_length,
convert_to_numpy=convert_to_numpy,
**kwargs
)
def encode_corpus(
self,
corpus: Union[List[str], str],
batch_size: Optional[int] = None,
max_length: Optional[int] = None,
convert_to_numpy: Optional[bool] = None,
**kwargs: Any
) -> Union[np.ndarray, torch.Tensor]:
"""Encode the corpus using the instruction if provided.
Args:
corpus (Union[List[str], str]): Input corpus to encode.
batch_size (Optional[int], optional): Number of sentences per iteration. Defaults to :data:`None`.
max_length (Optional[int], optional): Maximum length of tokens. Defaults to :data:`None`.
convert_to_numpy (Optional[bool], optional): If True, the output embedding will be a Numpy array. Otherwise, it will
be a Torch Tensor. Defaults to :data:`None`.
Returns:
Union[torch.Tensor, np.ndarray]: Return the embedding vectors in a numpy array or tensor.
"""
return super().encode_corpus(
corpus,
batch_size=batch_size,
max_length=max_length,
convert_to_numpy=convert_to_numpy,
**kwargs
)
def encode(
self,
sentences: Union[List[str], str],
batch_size: Optional[int] = None,
max_length: Optional[int] = None,
convert_to_numpy: Optional[bool] = None,
**kwargs: Any
) -> Union[np.ndarray, torch.Tensor]:
"""Encode the input sentences with the embedding model.
Args:
sentences (Union[List[str], str]): Input sentences to encode.
batch_size (Optional[int], optional): Number of sentences per iteration. Defaults to :data:`None`.
max_length (Optional[int], optional): Maximum length of tokens. Defaults to :data:`None`.
convert_to_numpy (Optional[bool], optional): If True, the output embedding will be a Numpy array. Otherwise, it will
be a Torch Tensor. Defaults to :data:`None`.
Returns:
Union[torch.Tensor, np.ndarray]: Return the embedding vectors in a numpy array or tensor.
"""
return super().encode(
sentences,
batch_size=batch_size,
max_length=max_length,
convert_to_numpy=convert_to_numpy,
**kwargs
)
@torch.no_grad()
def encode_single_device(
self,
sentences: Union[List[str], str],
batch_size: int = 256,
max_length: int = 512,
convert_to_numpy: bool = True,
device: Optional[str] = None,
**kwargs: Any
):
"""Encode input sentences by a single device.
Args:
sentences (Union[List[str], str]): Input sentences to encode.
batch_size (int, optional): Number of sentences per iteration. Defaults to :data:`256`.
max_length (int, optional): Maximum length of tokens. Defaults to :data:`512`.
convert_to_numpy (bool, optional): If True, the output embedding will be a Numpy array. Otherwise, it will
be a Torch Tensor. Defaults to :data:`True`.
device (Optional[str], optional): Device to use for encoding. Defaults to :data:`None`.
Returns:
Union[torch.Tensor, np.ndarray]: Return the embedding vectors in a numpy array or tensor.
"""
if device is None:
device = self.target_devices[0]
if device == "cpu": self.use_fp16 = False
if self.use_fp16: self.model.half()
self.model.to(device)
self.model.eval()
input_was_string = False
if isinstance(sentences, str):
sentences = [sentences]
input_was_string = True
# tokenize without padding to get the correct length
all_inputs = []
for start_index in trange(0, len(sentences), batch_size, desc='pre tokenize',
disable=len(sentences) < 256):
sentences_batch = sentences[start_index:start_index + batch_size]
inputs_batch = self.tokenizer(
sentences_batch,
truncation=True,
max_length=max_length,
**kwargs
)
inputs_batch = [{
k: inputs_batch[k][i] for k in inputs_batch.keys()
} for i in range(len(sentences_batch))]
all_inputs.extend(inputs_batch)
# sort by length for less padding
length_sorted_idx = np.argsort([-len(x['input_ids']) for x in all_inputs])
all_inputs_sorted = [all_inputs[i] for i in length_sorted_idx]
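# (negating the lengths gives a descending sort, so the longest sequences come
# first and the batch-size probe below exercises the worst case up front)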
# adjust batch size
flag = False
while flag is False:
try:
inputs_batch = self.tokenizer.pad(
all_inputs_sorted[: batch_size],
padding=True,
return_tensors='pt',
**kwargs
).to(device)
last_hidden_state = self.model(**inputs_batch, return_dict=True).last_hidden_state
embeddings = self.pooling(last_hidden_state, inputs_batch['attention_mask'])
flag = True
except RuntimeError:
    # torch.cuda.OutOfMemoryError is a subclass of RuntimeError, so CUDA OOM is
    # caught here as well; shrink the batch size and retry the probe
    batch_size = batch_size * 3 // 4
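# the probe's outputs are thrown away; all sentences are re-encoded below with
# the batch size that was found to fit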
# encode
all_embeddings = []
for start_index in tqdm(range(0, len(sentences), batch_size), desc="Inference Embeddings",
disable=len(sentences) < 256):
inputs_batch = all_inputs_sorted[start_index:start_index + batch_size]
inputs_batch = self.tokenizer.pad(
inputs_batch,
padding=True,
return_tensors='pt',
**kwargs
).to(device)
last_hidden_state = self.model(**inputs_batch, return_dict=True).last_hidden_state
embeddings = self.pooling(last_hidden_state, inputs_batch['attention_mask'])
if self.normalize_embeddings:
embeddings = torch.nn.functional.normalize(embeddings, dim=-1)
embeddings = cast(torch.Tensor, embeddings)
if convert_to_numpy:
embeddings = embeddings.cpu().numpy()
all_embeddings.append(embeddings)
if convert_to_numpy:
all_embeddings = np.concatenate(all_embeddings, axis=0)
else:
all_embeddings = torch.cat(all_embeddings, dim=0)
# adjust the order of embeddings
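# (np.argsort of the sort indices is the inverse permutation, mapping each
# embedding back to the position of its original sentence)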
all_embeddings = all_embeddings[np.argsort(length_sorted_idx)]
# return the embeddings
if input_was_string:
return all_embeddings[0]
return all_embeddings
def pooling(
self,
last_hidden_state: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None
):
"""The pooling function.
Args:
last_hidden_state (torch.Tensor): The last hidden state of the model.
attention_mask (Optional[torch.Tensor], optional): Attention mask. Defaults to :data:`None`.
Raises:
NotImplementedError: pooling method not implemented.
Returns:
torch.Tensor: The embedding vectors after pooling.
"""
if self.pooling_method == 'cls':
return last_hidden_state[:, 0]
elif self.pooling_method == 'mean':
s = torch.sum(last_hidden_state * attention_mask.unsqueeze(-1).float(), dim=1)
d = attention_mask.sum(dim=1, keepdim=True).float()
return s / d
else:
raise NotImplementedError(f"pooling method {self.pooling_method} not implemented")