Source code for FlagEmbedding.finetune.reranker.encoder_only.base.modeling
import logging

from transformers import PreTrainedModel, AutoTokenizer

from FlagEmbedding.abc.finetune.reranker import AbsRerankerModel

logger = logging.getLogger(__name__)
class CrossEncoderModel(AbsRerankerModel):
    """Model class for the cross-encoder reranker.

    Args:
        base_model (PreTrainedModel): The underlying pre-trained model used for encoding and scoring input pairs.
        tokenizer (AutoTokenizer, optional): The tokenizer for encoding input text. Defaults to ``None``.
        train_batch_size (int, optional): The batch size used during training. Defaults to ``4``.
    """
    def __init__(
        self,
        base_model: PreTrainedModel,
        tokenizer: AutoTokenizer = None,
        train_batch_size: int = 4,
    ):
        super().__init__(
            base_model,
            tokenizer=tokenizer,
            train_batch_size=train_batch_size,
        )
    def encode(self, features):
        """Encode input features into logits.

        Args:
            features (dict): Dictionary of input features, e.g. tokenized query-passage pairs.

        Returns:
            torch.Tensor: The logits output from the model.
        """
        return self.model(**features, return_dict=True).logits
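
Below is a minimal usage sketch, not part of the module above. It assumes the base model is a sequence-classification backbone loaded with ``AutoModelForSequenceClassification`` (here the ``BAAI/bge-reranker-base`` checkpoint, chosen only for illustration); the exact loading arguments used by the FlagEmbedding finetuning pipeline may differ.

# Minimal usage sketch (assumptions: checkpoint name, tokenization settings).
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

from FlagEmbedding.finetune.reranker.encoder_only.base.modeling import CrossEncoderModel

model_name = "BAAI/bge-reranker-base"  # example checkpoint, assumed for illustration
tokenizer = AutoTokenizer.from_pretrained(model_name)
base_model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=1)

reranker = CrossEncoderModel(base_model, tokenizer=tokenizer, train_batch_size=4)

# Tokenize one (query, passage) pair as a single cross-encoder input.
features = tokenizer(
    "what is a reranker?",
    "A reranker scores query-passage pairs and reorders retrieved candidates.",
    return_tensors="pt",
    truncation=True,
    max_length=512,
)

with torch.no_grad():
    logits = reranker.encode(features)  # relevance logit(s) for the pair
print(logits)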