Source code for FlagEmbedding.inference.auto_reranker

import os
import logging
from typing import Union, Optional

from FlagEmbedding.inference.reranker.model_mapping import (
    RerankerModelClass,
    RERANKER_CLASS_MAPPING,
    AUTO_RERANKER_MAPPING
)

logger = logging.getLogger(__name__)


class FlagAutoReranker:
    """
    Automatically choose the appropriate class to load the reranker model.

    This class is a factory and cannot be instantiated directly; use
    :meth:`FlagAutoReranker.from_finetuned` instead.
    """

    def __init__(self):
        # Direct construction is deliberately blocked: the concrete reranker
        # class is only known once the model name/class has been resolved.
        raise EnvironmentError(
            "FlagAutoReranker is designed to be instantiated using the `FlagAutoReranker.from_finetuned(model_name_or_path)` method."
        )

    @staticmethod
    def _resolve_model_name(model_name_or_path: str) -> str:
        """
        Derive the model name used as the lookup key from a name or path.

        Args:
            model_name_or_path (str): Hub model name or local path.

        Returns:
            str: The last path component, or the name of the parent directory
            when the last component is a ``checkpoint-*`` directory.
        """
        # `os.path.basename("a/b/")` is "", so drop trailing separators first;
        # otherwise a path such as "BAAI/bge-reranker-base/" resolves to an
        # empty model name and the mapping lookup fails.
        trimmed = model_name_or_path.rstrip("/" + os.sep)
        model_name = os.path.basename(trimmed)
        if model_name.startswith("checkpoint-"):
            # Training checkpoints live in `<model_dir>/checkpoint-N`; the
            # enclosing directory carries the model name.
            model_name = os.path.basename(os.path.dirname(trimmed))
        return model_name

    @classmethod
    def from_finetuned(
        cls,
        model_name_or_path: str,
        model_class: Optional[Union[str, RerankerModelClass]] = None,
        use_fp16: bool = False,
        trust_remote_code: Optional[bool] = None,
        **kwargs,
    ):
        """
        Load a finetuned model according to the provided vars.

        Args:
            model_name_or_path (str): If it's a path to a local model, it loads the model from the path. Otherwise tries
                to download and load a model from HuggingFace Hub with the name.
            model_class (Optional[Union[str, RerankerModelClass]], optional): The reranker class to use. When given, it
                takes precedence over the automatic lookup by model name. Defaults to :data:`None`.
            use_fp16 (bool, optional): If true, use half-precision floating-point to speed up computation with a slight
                performance degradation. Defaults to :data:`False`.
            trust_remote_code (Optional[bool], optional): trust_remote_code for HF datasets or models. When
                :data:`None`, it falls back to ``False`` if ``model_class`` is given, otherwise to the value registered
                for the model in ``AUTO_RERANKER_MAPPING``. Defaults to :data:`None`.
            **kwargs: Extra keyword arguments forwarded unchanged to the selected reranker class.

        Raises:
            ValueError: If ``model_class`` is not given and the model name derived from ``model_name_or_path`` is not
                found in ``AUTO_RERANKER_MAPPING``.

        Returns:
            AbsReranker: The reranker class to load model, which is child class of :class:`AbsReranker`.
        """
        model_name = cls._resolve_model_name(model_name_or_path)

        if model_class is not None:
            # An explicit class wins over the name-based lookup.
            _model_class = RERANKER_CLASS_MAPPING[RerankerModelClass(model_class)]
            if trust_remote_code is None:
                trust_remote_code = False
                # Use the module-level logger rather than the root logger so
                # applications can filter/configure this package's output.
                logger.warning(
                    "`trust_remote_code` is not specified, set to default value '%s'.",
                    trust_remote_code,
                )
        else:
            if model_name not in AUTO_RERANKER_MAPPING:
                raise ValueError(
                    f"Model name '{model_name}' not found in the model mapping. You can pull request to add the model to "
                    "`https://github.com/FlagOpen/FlagEmbedding/tree/master/FlagEmbedding/inference/reranker/model_mapping.py`. "
                    "If need, you can create a new `<model>.py` file in `https://github.com/FlagOpen/FlagEmbedding/tree/master/FlagEmbedding/inference/reranker/encoder_only` "
                    "or `https://github.com/FlagOpen/FlagEmbedding/tree/master/FlagEmbedding/inference/reranker/decoder_only`. "
                    "Welcome to contribute! "
                    "You can also directly specify the corresponding `model_class` to instantiate the model."
                )

            model_config = AUTO_RERANKER_MAPPING[model_name]

            _model_class = model_config.model_class
            if trust_remote_code is None:
                trust_remote_code = model_config.trust_remote_code

        # The original (untrimmed) name/path is what the concrete class loads.
        return _model_class(
            model_name_or_path,
            use_fp16=use_fp16,
            trust_remote_code=trust_remote_code,
            **kwargs,
        )