Source code for FlagEmbedding.abc.inference.AbsEmbedder

import logging
from tqdm import tqdm, trange
from abc import ABC, abstractmethod
from typing import Any, Union, List, Dict, Literal, Optional

import queue
import multiprocessing as mp
from multiprocessing import Queue

import math
import gc
import torch
import numpy as np
from transformers import is_torch_npu_available

logger = logging.getLogger(__name__)


[docs] class AbsEmbedder(ABC): """ Base class for embedder. Extend this class and implement :meth:`encode_queries`, :meth:`encode_corpus`, :meth:`encode` for custom embedders. Args: model_name_or_path (str): If it's a path to a local model, it loads the model from the path. Otherwise tries to download and load a model from HuggingFace Hub with the name. normalize_embeddings (bool, optional): If True, normalize the embedding vector. Defaults to :data:`True`. use_fp16 (bool, optional): If true, use half-precision floating-point to speed up computation with a slight performance degradation. Defaults to :data:`True`. query_instruction_for_retrieval: (Optional[str], optional): Query instruction for retrieval tasks, which will be used with with :attr:`query_instruction_format`. Defaults to :data:`None`. query_instruction_format: (str, optional): The template for :attr:`query_instruction_for_retrieval`. Defaults to :data:`"{}{}"`. devices (Optional[Union[str, int, List[str], List[int]]], optional): Devices to use for model inference. Defaults to :data:`None`. batch_size (int, optional): Batch size for inference. Defaults to :data:`256`. query_max_length (int, optional): Maximum length for query. Defaults to :data:`512`. passage_max_length (int, optional): Maximum length for passage. Defaults to :data:`512`. convert_to_numpy (bool, optional): If True, the output embedding will be a Numpy array. Otherwise, it will be a Torch Tensor. Defaults to :data:`True`. kwargs (Dict[Any], optional): Additional parameters for HuggingFace Transformers config or children classes. """ def __init__( self, model_name_or_path: str, normalize_embeddings: bool = True, use_fp16: bool = True, query_instruction_for_retrieval: Optional[str] = None, query_instruction_format: str = "{}{}", # specify the format of query_instruction_for_retrieval devices: Optional[Union[str, int, List[str], List[int]]] = None, # inference batch_size: int = 256, query_max_length: int = 512, passage_max_length: int = 512, convert_to_numpy: bool = True, **kwargs: Any, ): query_instruction_format = query_instruction_format.replace('\\n', '\n') self.model_name_or_path = model_name_or_path self.normalize_embeddings = normalize_embeddings self.use_fp16 = use_fp16 self.query_instruction_for_retrieval = query_instruction_for_retrieval self.query_instruction_format = query_instruction_format self.target_devices = self.get_target_devices(devices) self.batch_size = batch_size self.query_max_length = query_max_length self.passage_max_length = passage_max_length self.convert_to_numpy = convert_to_numpy for k in kwargs: setattr(self, k, kwargs[k]) self.kwargs = kwargs # tokenizer and model are initialized in the child class self.tokenizer = None self.model = None self.pool = None def stop_self_pool(self): if self.pool is not None: self.stop_multi_process_pool(self.pool) self.pool = None try: self.model.to('cpu') torch.cuda.empty_cache() except: pass gc.collect()
[docs] @staticmethod def get_target_devices(devices: Union[str, int, List[str], List[int]]) -> List[str]: """ Args: devices (Union[str, int, List[str], List[int]]): specified devices, can be `str`, `int`, list of `str`, or list of `int`. Raises: ValueError: Devices should be a string or an integer or a list of strings or a list of integers. Returns: List[str]: A list of target devices in format. """ if devices is None: if torch.cuda.is_available(): return [f"cuda:{i}" for i in range(torch.cuda.device_count())] elif is_torch_npu_available(): return [f"npu:{i}" for i in range(torch.npu.device_count())] elif torch.backends.mps.is_available(): return [f"mps:{i}" for i in range(torch.mps.device_count())] else: return ["cpu"] elif isinstance(devices, str): return [devices] elif isinstance(devices, int): return [f"cuda:{devices}"] elif isinstance(devices, list): if isinstance(devices[0], str): return devices elif isinstance(devices[0], int): return [f"cuda:{device}" for device in devices] else: raise ValueError("devices should be a string or an integer or a list of strings or a list of integers.") else: raise ValueError("devices should be a string or an integer or a list of strings or a list of integers.")
[docs] @staticmethod def get_detailed_instruct(instruction_format: str, instruction: str, sentence: str): """Combine the instruction and sentence along with the instruction format. Args: instruction_format (str): Format for instruction. instruction (str): The text of instruction. sentence (str): The sentence to concatenate with. Returns: str: The complete sentence with instruction """ return instruction_format.format(instruction, sentence)
[docs] def encode_queries( self, queries: Union[List[str], str], batch_size: Optional[int] = None, max_length: Optional[int] = None, convert_to_numpy: Optional[bool] = None, **kwargs: Any ): """encode the queries using the instruction if provided. Args: queries (Union[List[str], str]): Input queries to encode. batch_size (Optional[int], optional): Number of sentences for each iter. Defaults to :data:`None`. max_length (Optional[int], optional): Maximum length of tokens. Defaults to :data:`None`. convert_to_numpy (Optional[bool], optional): If True, the output embedding will be a Numpy array. Otherwise, it will be a Torch Tensor. Defaults to :data:`None`. Returns: Union[torch.Tensor, np.ndarray]: Return the embedding vectors in a numpy array or tensor. """ if batch_size is None: batch_size = self.batch_size if max_length is None: max_length = self.query_max_length if convert_to_numpy is None: convert_to_numpy = self.convert_to_numpy return self.encode( queries, batch_size=batch_size, max_length=max_length, convert_to_numpy=convert_to_numpy, instruction=self.query_instruction_for_retrieval, instruction_format=self.query_instruction_format, **kwargs )
[docs] def encode_corpus( self, corpus: Union[List[str], str], batch_size: Optional[int] = None, max_length: Optional[int] = None, convert_to_numpy: Optional[bool] = None, **kwargs: Any ): """encode the corpus using the instruction if provided. Args: corpus (Union[List[str], str]): Input corpus to encode. batch_size (Optional[int], optional): Number of sentences for each iter. Defaults to :data:`None`. max_length (Optional[int], optional): Maximum length of tokens. Defaults to :data:`None`. convert_to_numpy (Optional[bool], optional): If True, the output embedding will be a Numpy array. Otherwise, it will be a Torch Tensor. Defaults to :data:`None`. Returns: Union[torch.Tensor, np.ndarray]: Return the embedding vectors in a numpy array or tensor. """ passage_instruction_for_retrieval = self.kwargs.get("passage_instruction_for_retrieval", None) passage_instruction_format = self.kwargs.get("passage_instruction_format", "{}{}") if batch_size is None: batch_size = self.batch_size if max_length is None: max_length = self.passage_max_length if convert_to_numpy is None: convert_to_numpy = self.convert_to_numpy return self.encode( corpus, batch_size=batch_size, max_length=max_length, convert_to_numpy=convert_to_numpy, instruction=passage_instruction_for_retrieval, instruction_format=passage_instruction_format, **kwargs )
[docs] def encode( self, sentences: Union[List[str], str], batch_size: Optional[int] = None, max_length: Optional[int] = None, convert_to_numpy: Optional[bool] = None, instruction: Optional[str] = None, instruction_format: Optional[str] = None, **kwargs: Any ): """encode the input sentences with the embedding model. Args: sentences (Union[List[str], str]): Input sentences to encode. batch_size (Optional[int], optional): Number of sentences for each iter. Defaults to :data:`None`. max_length (Optional[int], optional): Maximum length of tokens. Defaults to :data:`None`. convert_to_numpy (Optional[bool], optional): If True, the output embedding will be a Numpy array. Otherwise, it will be a Torch Tensor. Defaults to :data:`None`. instruction (Optional[str], optional): The text of instruction. Defaults to :data:`None`. instruction_format (Optional[str], optional): Format for instruction. Defaults to :data:`None`. Returns: Union[torch.Tensor, np.ndarray]: return the embedding vectors in a numpy array or tensor. """ if batch_size is None: batch_size = self.batch_size if max_length is None: max_length = self.passage_max_length if convert_to_numpy is None: convert_to_numpy = self.convert_to_numpy if instruction is not None: if isinstance(sentences, str): sentences = self.get_detailed_instruct(instruction_format, instruction, sentences) else: sentences = [self.get_detailed_instruct(instruction_format, instruction, sentence) for sentence in sentences] if isinstance(sentences, str) or len(self.target_devices) == 1: return self.encode_single_device( sentences, batch_size=batch_size, max_length=max_length, convert_to_numpy=convert_to_numpy, device=self.target_devices[0], **kwargs ) if self.pool is None: self.pool = self.start_multi_process_pool(AbsEmbedder._encode_multi_process_worker) embeddings = self.encode_multi_process( sentences, self.pool, batch_size=batch_size, max_length=max_length, convert_to_numpy=convert_to_numpy, **kwargs ) return embeddings
def __del__(self): self.stop_self_pool()
[docs] @abstractmethod def encode_single_device( self, sentences: Union[List[str], str], batch_size: int = 256, max_length: int = 512, convert_to_numpy: bool = True, device: Optional[str] = None, **kwargs: Any, ): """ This method should encode sentences and return embeddings on a single device. """ pass
# adapted from https://github.com/UKPLab/sentence-transformers/blob/1802076d4eae42ff0a5629e1b04e75785d4e193b/sentence_transformers/SentenceTransformer.py#L807
[docs] def start_multi_process_pool( self, process_target_func: Any, ) -> Dict[Literal["input", "output", "processes"], Any]: """ Starts a multi-process pool to process the encoding with several independent processes via :meth:`SentenceTransformer.encode_multi_process <sentence_transformers.SentenceTransformer.encode_multi_process>`. This method is recommended if you want to encode on multiple GPUs or CPUs. It is advised to start only one process per GPU. This method works together with encode_multi_process and stop_multi_process_pool. Returns: Dict[str, Any]: A dictionary with the target processes, an input queue, and an output queue. """ if self.model is None: raise ValueError("Model is not initialized.") logger.info("Start multi-process pool on devices: {}".format(", ".join(map(str, self.target_devices)))) self.model.to("cpu") self.model.share_memory() ctx = mp.get_context("spawn") input_queue = ctx.Queue() output_queue = ctx.Queue() processes = [] for device_id in tqdm(self.target_devices, desc='initial target device'): p = ctx.Process( target=process_target_func, args=(device_id, self, input_queue, output_queue), daemon=True, ) p.start() processes.append(p) return {"input": input_queue, "output": output_queue, "processes": processes}
# adapted from https://github.com/UKPLab/sentence-transformers/blob/1802076d4eae42ff0a5629e1b04e75785d4e193b/sentence_transformers/SentenceTransformer.py#L976
[docs] @staticmethod def _encode_multi_process_worker( target_device: str, model: 'AbsEmbedder', input_queue: Queue, results_queue: Queue ) -> None: """ Internal working process to encode sentences in multi-process setup """ while True: try: chunk_id, sentences, kwargs = ( input_queue.get() ) embeddings = model.encode_single_device( sentences, device=target_device, **kwargs ) results_queue.put([chunk_id, embeddings]) except queue.Empty: break
# copied from https://github.com/UKPLab/sentence-transformers/blob/1802076d4eae42ff0a5629e1b04e75785d4e193b/sentence_transformers/SentenceTransformer.py#L857
[docs] @staticmethod def stop_multi_process_pool(pool: Dict[Literal["input", "output", "processes"], Any]) -> None: """ Stops all processes started with start_multi_process_pool. Args: pool (Dict[str, object]): A dictionary containing the input queue, output queue, and process list. Returns: None """ for p in pool["processes"]: p.terminate() for p in pool["processes"]: p.join() p.close() pool["input"].close() pool["output"].close() pool = None
# adapted from https://github.com/UKPLab/sentence-transformers/blob/1802076d4eae42ff0a5629e1b04e75785d4e193b/sentence_transformers/SentenceTransformer.py#L877
[docs] def encode_multi_process( self, sentences: List[str], pool: Dict[Literal["input", "output", "processes"], Any], **kwargs ): chunk_size = math.ceil(len(sentences) / len(pool["processes"])) input_queue = pool["input"] last_chunk_id = 0 chunk = [] for sentence in sentences: chunk.append(sentence) if len(chunk) >= chunk_size: input_queue.put( [last_chunk_id, chunk, kwargs] ) last_chunk_id += 1 chunk = [] if len(chunk) > 0: input_queue.put([last_chunk_id, chunk, kwargs]) last_chunk_id += 1 output_queue = pool["output"] results_list = sorted( [output_queue.get() for _ in trange(last_chunk_id, desc="Chunks")], key=lambda x: x[0], ) embeddings = self._concatenate_results_from_multi_process([result[1] for result in results_list]) return embeddings
[docs] def _concatenate_results_from_multi_process(self, results_list: List[Union[torch.Tensor, np.ndarray, Any]]): """concatenate and return the results from all the processes Args: results_list (List[Union[torch.Tensor, np.ndarray, Any]]): A list of results from all the processes. Raises: NotImplementedError: Unsupported type for results_list Returns: Union[torch.Tensor, np.ndarray]: return the embedding vectors in a numpy array or tensor. """ if isinstance(results_list[0], torch.Tensor): return torch.cat(results_list, dim=0) elif isinstance(results_list[0], np.ndarray): return np.concatenate(results_list, axis=0) else: raise NotImplementedError("Unsupported type for results_list")