Source code for FlagEmbedding.finetune.reranker.encoder_only.base.modeling

from transformers import PreTrainedModel, AutoTokenizer
import logging

from FlagEmbedding.abc.finetune.reranker import AbsRerankerModel

logger = logging.getLogger(__name__)


class CrossEncoderModel(AbsRerankerModel):
    """Model class for reranker.

    Args:
        base_model (PreTrainedModel): The underlying pre-trained model used for encoding and scoring input pairs.
        tokenizer (AutoTokenizer, optional): The tokenizer for encoding input text. Defaults to ``None``.
        train_batch_size (int, optional): The batch size to use. Defaults to ``4``.
    """
    def __init__(
        self,
        base_model: PreTrainedModel,
        tokenizer: AutoTokenizer = None,
        train_batch_size: int = 4,
    ):
        super().__init__(
            base_model,
            tokenizer=tokenizer,
            train_batch_size=train_batch_size,
        )
    def encode(self, features):
        """Encodes input features to logits.

        Args:
            features (dict): Dictionary with input features.

        Returns:
            torch.Tensor: The logits output from the model.
        """
        return self.model(**features, return_dict=True).logits
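
A minimal usage sketch follows (not part of this module). It assumes the base model is loaded with AutoModelForSequenceClassification and that a cross-encoder checkpoint such as BAAI/bge-reranker-base is available; the tokenized query-passage pair is passed to encode as a feature dict.

# Usage sketch under the stated assumptions; checkpoint choice is illustrative.
from transformers import AutoModelForSequenceClassification, AutoTokenizer

from FlagEmbedding.finetune.reranker.encoder_only.base.modeling import CrossEncoderModel

model_name = "BAAI/bge-reranker-base"  # assumed checkpoint, swap for your own
base_model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=1)
tokenizer = AutoTokenizer.from_pretrained(model_name)

reranker = CrossEncoderModel(base_model, tokenizer=tokenizer, train_batch_size=4)

# Tokenize a query-passage pair and score it with the cross-encoder.
features = tokenizer(
    "what is a reranker?",
    "A reranker scores query-passage pairs for relevance.",
    return_tensors="pt",
    truncation=True,
)
scores = reranker.encode(features)  # logits of shape (1, 1)
print(scores)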