import os
import logging
from typing import List, Union, Optional
from FlagEmbedding.inference.embedder.model_mapping import (
EmbedderModelClass,
AUTO_EMBEDDER_MAPPING, EMBEDDER_CLASS_MAPPING
)
# Module-level logger; the auto-loader below reports defaulted arguments through it.
logger: logging.Logger = logging.getLogger(name=__name__)
class FlagAutoModel:
    """
    Automatically choose the appropriate class to load the embedding model.

    This is a factory: it is not meant to be instantiated. Call
    :meth:`FlagAutoModel.from_finetuned` to obtain an instance of the concrete
    embedder class selected for a model.
    """

    def __init__(self):
        # Direct construction is blocked on purpose; from_finetuned() returns an
        # instance of the selected embedder class, never of FlagAutoModel itself.
        raise EnvironmentError(
            "FlagAutoModel is designed to be instantiated using the `FlagAutoModel.from_finetuned(model_name_or_path)` method."
        )

    @staticmethod
    def _model_name_from_path(model_name_or_path: str) -> str:
        """Derive the key used to look up ``model_name_or_path`` in ``AUTO_EMBEDDER_MAPPING``.

        The key is the last path component of ``model_name_or_path``. The path is
        normalized first, so a trailing separator (e.g. ``"./bge-m3/"``) still
        yields ``"bge-m3"`` instead of an empty string. If the last component is a
        ``checkpoint-*`` directory, the name of its parent directory is used.

        Args:
            model_name_or_path (str): Local model path or HuggingFace Hub model id.

        Returns:
            str: The model name used as the mapping key.
        """
        # normpath strips trailing separators; without it basename("x/") == "".
        normalized = os.path.normpath(model_name_or_path)
        model_name = os.path.basename(normalized)
        if model_name.startswith("checkpoint-"):
            # Training checkpoints live under a directory named after the model.
            model_name = os.path.basename(os.path.dirname(normalized))
        return model_name

    @classmethod
    def from_finetuned(
        cls,
        model_name_or_path: str,
        model_class: Optional[Union[str, EmbedderModelClass]] = None,
        normalize_embeddings: bool = True,
        use_fp16: bool = True,
        query_instruction_for_retrieval: Optional[str] = None,
        devices: Optional[Union[str, List[str]]] = None,
        pooling_method: Optional[str] = None,
        trust_remote_code: Optional[bool] = None,
        query_instruction_format: Optional[str] = None,
        **kwargs,
    ):
        """
        Load a finetuned model according to the provided vars.

        The embedder class is chosen in one of two ways:

        * If ``model_class`` is given, it is converted with ``EmbedderModelClass(...)``
          and resolved through ``EMBEDDER_CLASS_MAPPING``. Unspecified options fall back
          to generic defaults, and a warning is logged for each:

          - ``pooling_method`` uses the class's ``DEFAULT_POOLING_METHOD``.
          - ``trust_remote_code`` uses ``False``.
          - ``query_instruction_format`` uses ``"{}{}"``.
        * Otherwise the model name is derived from ``model_name_or_path`` (its last
          path component, or the parent directory for ``checkpoint-*`` directories)
          and looked up in ``AUTO_EMBEDDER_MAPPING``. Unspecified options are taken
          from that model's registered config.

        Args:
            model_name_or_path (str): If it's a path to a local model, it loads the model from the path. Otherwise tries to download and
                load a model from HuggingFace Hub with the name.
            model_class (Optional[Union[str, EmbedderModelClass]], optional): The embedder class to use. When set, the automatic
                name-based lookup is skipped. Defaults to :data:`None`.
            normalize_embeddings (bool, optional): Forwarded unchanged to the embedder class; by its name it toggles normalization
                of the output embeddings (confirm against the concrete embedder). Defaults to :data:`True`.
            use_fp16 (bool, optional): If true, use half-precision floating-point to speed up computation with a slight performance
                degradation. Forwarded to the embedder class. Defaults to :data:`True`.
            query_instruction_for_retrieval (Optional[str], optional): Query instruction for retrieval tasks, which will be used with
                :attr:`query_instruction_format`. Defaults to :data:`None`.
            devices (Optional[Union[str, List[str]]], optional): Devices to use for model inference. Defaults to :data:`None`.
            pooling_method (Optional[str], optional): Pooling method to get embedding vector from the last hidden state.
                Defaults to :data:`None`, which means it is resolved as described above.
            trust_remote_code (Optional[bool], optional): trust_remote_code for HF datasets or models.
                Defaults to :data:`None`, which means it is resolved as described above.
            query_instruction_format (Optional[str], optional): The template for :attr:`query_instruction_for_retrieval`.
                Defaults to :data:`None`, which means it is resolved as described above.
            **kwargs: Extra keyword arguments forwarded unchanged to the embedder class constructor.

        Raises:
            ValueError: If ``model_class`` is not given and the derived model name is not in ``AUTO_EMBEDDER_MAPPING``.
                An invalid ``model_class`` is rejected by the ``EmbedderModelClass(...)`` conversion
                (a ``ValueError`` if it is a standard :class:`enum.Enum`).

        Returns:
            AbsEmbedder: An instance of the selected embedder class, which is a child class of :class:`AbsEmbedder`.
        """
        if model_class is not None:
            # Explicit class: accepts either an EmbedderModelClass member or its value.
            _model_class = EMBEDDER_CLASS_MAPPING[EmbedderModelClass(model_class)]
            if pooling_method is None:
                pooling_method = _model_class.DEFAULT_POOLING_METHOD
                logger.warning(
                    f"`pooling_method` is not specified, use default pooling method '{pooling_method}'."
                )
            if trust_remote_code is None:
                trust_remote_code = False
                logger.warning(
                    f"`trust_remote_code` is not specified, set to default value '{trust_remote_code}'."
                )
            if query_instruction_format is None:
                query_instruction_format = "{}{}"
                logger.warning(
                    f"`query_instruction_format` is not specified, set to default value '{query_instruction_format}'."
                )
        else:
            model_name = cls._model_name_from_path(model_name_or_path)
            if model_name not in AUTO_EMBEDDER_MAPPING:
                raise ValueError(
                    f"Model name '{model_name}' not found in the model mapping. You can pull request to add the model to "
                    "`https://github.com/FlagOpen/FlagEmbedding/tree/master/FlagEmbedding/inference/embedder/model_mapping.py`. "
                    "If need, you can create a new `<model>.py` file in `https://github.com/FlagOpen/FlagEmbedding/tree/master/FlagEmbedding/inference/embedder/encoder_only` "
                    "or `https://github.com/FlagOpen/FlagEmbedding/tree/master/FlagEmbedding/inference/embedder/decoder_only`. "
                    "Welcome to contribute! You can also directly specify the corresponding `model_class` to instantiate the model."
                )
            # Registered model: fill unspecified options from its mapping entry.
            model_config = AUTO_EMBEDDER_MAPPING[model_name]
            _model_class = model_config.model_class
            if pooling_method is None:
                pooling_method = model_config.pooling_method.value
            if trust_remote_code is None:
                trust_remote_code = model_config.trust_remote_code
            if query_instruction_format is None:
                query_instruction_format = model_config.query_instruction_format

        return _model_class(
            model_name_or_path,
            normalize_embeddings=normalize_embeddings,
            use_fp16=use_fp16,
            query_instruction_for_retrieval=query_instruction_for_retrieval,
            query_instruction_format=query_instruction_format,
            devices=devices,
            pooling_method=pooling_method,
            trust_remote_code=trust_remote_code,
            **kwargs,
        )