Source code for FlagEmbedding.inference.embedder.decoder_only.base

from tqdm import tqdm, trange
from typing import cast, Any, List, Union, Optional

import torch
import numpy as np
from transformers import AutoModel, AutoTokenizer

from FlagEmbedding.abc.inference import AbsEmbedder


# Pooling function for LLM-based embedding models
def last_token_pool(last_hidden_states: torch.Tensor,
                    attention_mask: torch.Tensor) -> torch.Tensor:
    """Last token pooling method.

    Args:
        last_hidden_states (torch.Tensor): The last hidden states of the model, with shape ``(batch_size, seq_len, hidden_dim)``.
        attention_mask (torch.Tensor): Attention mask of shape ``(batch_size, seq_len)``, where padding positions are 0.

    Returns:
        torch.Tensor: The embedding vectors after pooling.
    """
    left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
    if left_padding:
        return last_hidden_states[:, -1]
    else:
        sequence_lengths = attention_mask.sum(dim=1) - 1
        batch_size = last_hidden_states.shape[0]
        return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths]
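

# Illustrative example (added for clarity; not part of the original module): how
# last_token_pool picks the final non-padding token under both padding schemes.
# The tensor values are made up purely for readability.
def _last_token_pool_example() -> None:
    hidden = torch.arange(12, dtype=torch.float32).reshape(2, 3, 2)  # (batch=2, seq_len=3, dim=2)

    # Right padding: the second row has only 2 real tokens, so per-row lengths are used.
    right_mask = torch.tensor([[1, 1, 1], [1, 1, 0]])
    pooled = last_token_pool(hidden, right_mask)
    assert torch.equal(pooled[0], hidden[0, 2]) and torch.equal(pooled[1], hidden[1, 1])

    # Left padding: every row ends with a real token, so position -1 is taken directly.
    left_mask = torch.tensor([[1, 1, 1], [0, 1, 1]])
    pooled = last_token_pool(hidden, left_mask)
    assert torch.equal(pooled, hidden[:, -1])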


class BaseLLMEmbedder(AbsEmbedder):
    """Base embedder class for decoder-only LLM embedding models.

    Args:
        model_name_or_path (str): If it's a path to a local model, it loads the model from the path. Otherwise it tries to download and
            load a model from HuggingFace Hub with the name.
        normalize_embeddings (bool, optional): If True, normalize the embedding vector. Defaults to :data:`True`.
        use_fp16 (bool, optional): If True, use half-precision floating-point to speed up computation with a slight performance
            degradation. Defaults to :data:`True`.
        query_instruction_for_retrieval (Optional[str], optional): Query instruction for retrieval tasks, which will be used with
            :attr:`query_instruction_format`. Defaults to :data:`None`.
        query_instruction_format (str, optional): The template for :attr:`query_instruction_for_retrieval`. Defaults to
            :data:`"Instruct: {}\nQuery: {}"`.
        devices (Optional[Union[str, List[str]]], optional): Devices to use for model inference. Defaults to :data:`None`.
        trust_remote_code (bool, optional): Whether to trust remote code when loading the model or tokenizer. Defaults to :data:`False`.
        cache_dir (Optional[str], optional): Cache directory for the model. Defaults to :data:`None`.
        batch_size (int, optional): Batch size for inference. Defaults to :data:`256`.
        query_max_length (int, optional): Maximum length for query. Defaults to :data:`512`.
        passage_max_length (int, optional): Maximum length for passage. Defaults to :data:`512`.
        convert_to_numpy (bool, optional): If True, the output embedding will be a Numpy array. Otherwise, it will be a Torch Tensor.
            Defaults to :data:`True`.

    Attributes:
        DEFAULT_POOLING_METHOD: The default pooling method when running the model.
    """
    DEFAULT_POOLING_METHOD = "last_token"

    def __init__(
        self,
        model_name_or_path: str,
        normalize_embeddings: bool = True,
        use_fp16: bool = True,
        query_instruction_for_retrieval: Optional[str] = None,
        query_instruction_format: str = "Instruct: {}\nQuery: {}",  # specify the format of query_instruction_for_retrieval
        devices: Optional[Union[str, List[str]]] = None,  # specify devices, such as "cuda:0" or ["cuda:0", "cuda:1"]
        # Additional parameters for BaseLLMEmbedder
        trust_remote_code: bool = False,
        cache_dir: Optional[str] = None,
        # inference
        batch_size: int = 256,
        query_max_length: int = 512,
        passage_max_length: int = 512,
        convert_to_numpy: bool = True,
        **kwargs: Any,
    ):
        super().__init__(
            model_name_or_path,
            normalize_embeddings=normalize_embeddings,
            use_fp16=use_fp16,
            query_instruction_for_retrieval=query_instruction_for_retrieval,
            query_instruction_format=query_instruction_format,
            devices=devices,
            batch_size=batch_size,
            query_max_length=query_max_length,
            passage_max_length=passage_max_length,
            convert_to_numpy=convert_to_numpy,
            **kwargs
        )

        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name_or_path,
            trust_remote_code=trust_remote_code,
            cache_dir=cache_dir
        )
        self.model = AutoModel.from_pretrained(
            model_name_or_path,
            trust_remote_code=trust_remote_code,
            cache_dir=cache_dir
        )

        if self.kwargs.get("pooling_method", "last_token") != "last_token":
            raise ValueError("Pooling method must be 'last_token' for LLM-based models.")
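
    # Note (added for clarity; not in the original source): with the default
    # query_instruction_format "Instruct: {}\nQuery: {}" and, for example, the
    # instruction "Given a web search query, retrieve relevant passages", a query
    # like "what is a large language model?" is expected to be rendered by the
    # base class roughly as:
    #   "Instruct: Given a web search query, retrieve relevant passages\nQuery: what is a large language model?"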

    def encode_queries(
        self,
        queries: Union[List[str], str],
        batch_size: Optional[int] = None,
        max_length: Optional[int] = None,
        convert_to_numpy: Optional[bool] = None,
        **kwargs: Any
    ) -> Union[np.ndarray, torch.Tensor]:
        """Encode the queries.

        Args:
            queries (Union[List[str], str]): Input queries to encode.
            batch_size (Optional[int], optional): Number of sentences for each iter. Defaults to :data:`None`.
            max_length (Optional[int], optional): Maximum length of tokens. Defaults to :data:`None`.
            convert_to_numpy (Optional[bool], optional): If True, the output embedding will be a Numpy array. Otherwise, it will
                be a Torch Tensor. Defaults to :data:`None`.

        Returns:
            Union[torch.Tensor, np.ndarray]: Return the embedding vectors in a numpy array or tensor.
        """
        return super().encode_queries(
            queries,
            batch_size=batch_size,
            max_length=max_length,
            convert_to_numpy=convert_to_numpy,
            **kwargs
        )

    def encode_corpus(
        self,
        corpus: Union[List[str], str],
        batch_size: Optional[int] = None,
        max_length: Optional[int] = None,
        convert_to_numpy: Optional[bool] = None,
        **kwargs: Any
    ) -> Union[np.ndarray, torch.Tensor]:
        """Encode the corpus.

        Args:
            corpus (Union[List[str], str]): Input corpus to encode.
            batch_size (Optional[int], optional): Number of sentences for each iter. Defaults to :data:`None`.
            max_length (Optional[int], optional): Maximum length of tokens. Defaults to :data:`None`.
            convert_to_numpy (Optional[bool], optional): If True, the output embedding will be a Numpy array. Otherwise, it will
                be a Torch Tensor. Defaults to :data:`None`.

        Returns:
            Union[torch.Tensor, np.ndarray]: Return the embedding vectors in a numpy array or tensor.
        """
        return super().encode_corpus(
            corpus,
            batch_size=batch_size,
            max_length=max_length,
            convert_to_numpy=convert_to_numpy,
            **kwargs
        )

    def encode(
        self,
        sentences: Union[List[str], str],
        batch_size: Optional[int] = None,
        max_length: Optional[int] = None,
        convert_to_numpy: Optional[bool] = None,
        **kwargs: Any
    ) -> Union[np.ndarray, torch.Tensor]:
        """Encode the input sentences with the embedding model.

        Args:
            sentences (Union[List[str], str]): Input sentences to encode.
            batch_size (Optional[int], optional): Number of sentences for each iter. Defaults to :data:`None`.
            max_length (Optional[int], optional): Maximum length of tokens. Defaults to :data:`None`.
            convert_to_numpy (Optional[bool], optional): If True, the output embedding will be a Numpy array. Otherwise, it will
                be a Torch Tensor. Defaults to :data:`None`.

        Returns:
            Union[torch.Tensor, np.ndarray]: Return the embedding vectors in a numpy array or tensor.
        """
        return super().encode(
            sentences,
            batch_size=batch_size,
            max_length=max_length,
            convert_to_numpy=convert_to_numpy,
            **kwargs
        )
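
    # Note (added for clarity; not in the original source): encode_single_device below
    # sorts inputs by descending token length so each padded batch wastes little compute,
    # then restores the original order with the inverse permutation, e.g.
    #   lengths = [2, 5, 3] -> length_sorted_idx = np.argsort([-2, -5, -3]) = [1, 2, 0]
    #   sorted_embeddings[np.argsort(length_sorted_idx)]  # back in input order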

    @torch.no_grad()
    def encode_single_device(
        self,
        sentences: Union[List[str], str],
        batch_size: int = 256,
        max_length: int = 512,
        convert_to_numpy: bool = True,
        device: Optional[str] = None,
        **kwargs: Any  # extra tokenizer kwargs, e.g. add `pad_to_multiple_of=8` for bge-multilingual-gemma2
    ):
        """Encode input sentences by a single device.

        Args:
            sentences (Union[List[str], str]): Input sentences to encode.
            batch_size (int, optional): Number of sentences for each iter. Defaults to :data:`256`.
            max_length (int, optional): Maximum length of tokens. Defaults to :data:`512`.
            convert_to_numpy (bool, optional): If True, the output embedding will be a Numpy array. Otherwise, it will
                be a Torch Tensor. Defaults to :data:`True`.
            device (Optional[str], optional): Device to use for encoding. Defaults to :data:`None`.

        Returns:
            Union[torch.Tensor, np.ndarray]: Return the embedding vectors in a numpy array or tensor.
        """
        if device is None:
            device = self.target_devices[0]

        if device == "cpu":
            self.use_fp16 = False
        if self.use_fp16:
            self.model.half()

        self.model.to(device)
        self.model.eval()

        input_was_string = False
        if isinstance(sentences, str):
            sentences = [sentences]
            input_was_string = True

        # tokenize without padding to get the correct token lengths
        all_inputs = []
        for start_index in trange(0, len(sentences), batch_size, desc='pre tokenize',
                                  disable=len(sentences) < 256):
            sentences_batch = sentences[start_index:start_index + batch_size]
            inputs_batch = self.tokenizer(
                sentences_batch,
                truncation=True,
                max_length=max_length,
                **kwargs
            )
            inputs_batch = [{
                k: inputs_batch[k][i] for k in inputs_batch.keys()
            } for i in range(len(sentences_batch))]
            all_inputs.extend(inputs_batch)

        # sort by length for less padding
        length_sorted_idx = np.argsort([-len(x['input_ids']) for x in all_inputs])
        all_inputs_sorted = [all_inputs[i] for i in length_sorted_idx]

        # adjust batch size: probe with the longest sequences and shrink by 25% on failure;
        # the probe's embeddings are discarded and recomputed in the encoding loop below
        flag = False
        while flag is False:
            try:
                inputs_batch = self.tokenizer.pad(
                    all_inputs_sorted[: batch_size],
                    padding=True,
                    return_tensors='pt',
                    **kwargs
                ).to(device)
                last_hidden_state = self.model(**inputs_batch, return_dict=True).last_hidden_state
                embeddings = last_token_pool(last_hidden_state, inputs_batch['attention_mask'])
                flag = True
            except RuntimeError:
                # CUDA OOM is raised as torch.cuda.OutOfMemoryError, a RuntimeError
                # subclass, so this single handler covers it
                batch_size = batch_size * 3 // 4

        # encode
        all_embeddings = []
        for start_index in tqdm(range(0, len(sentences), batch_size), desc="Inference Embeddings",
                                disable=len(sentences) < 256):
            inputs_batch = all_inputs_sorted[start_index:start_index + batch_size]
            inputs_batch = self.tokenizer.pad(
                inputs_batch,
                padding=True,
                return_tensors='pt',
                **kwargs
            ).to(device)
            last_hidden_state = self.model(**inputs_batch, return_dict=True).last_hidden_state
            embeddings = last_token_pool(last_hidden_state, inputs_batch['attention_mask'])
            if self.normalize_embeddings:
                embeddings = torch.nn.functional.normalize(embeddings, dim=-1)
            embeddings = cast(torch.Tensor, embeddings)

            if convert_to_numpy:
                embeddings = embeddings.cpu().numpy()
            all_embeddings.append(embeddings)

        if convert_to_numpy:
            all_embeddings = np.concatenate(all_embeddings, axis=0)
        else:
            all_embeddings = torch.cat(all_embeddings, dim=0)

        # restore the original input order (inverse of the length sort)
        all_embeddings = all_embeddings[np.argsort(length_sorted_idx)]

        # return the embeddings
        if input_was_string:
            return all_embeddings[0]
        return all_embeddings
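

# Usage sketch (added for illustration; not part of the original module). The checkpoint,
# instruction, and device below are assumptions; any decoder-only embedding model on the
# HuggingFace Hub that uses last-token pooling should work the same way.
if __name__ == "__main__":
    embedder = BaseLLMEmbedder(
        "BAAI/bge-multilingual-gemma2",  # example checkpoint, swap in your own model
        query_instruction_for_retrieval="Given a web search query, retrieve relevant passages that answer the query.",
        use_fp16=True,
        devices="cuda:0",  # assumes a GPU is available; use "cpu" otherwise
    )
    # encode_queries applies the query instruction template to each query
    query_embeddings = embedder.encode_queries(["what is a large language model?"])
    passage_embeddings = embedder.encode_corpus(
        ["A large language model is a neural network trained on large amounts of text."]
    )
    # embeddings are L2-normalized numpy arrays by default, so the dot product is cosine similarity
    scores = query_embeddings @ passage_embeddings.T
    print(scores)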