Source code for FlagEmbedding.finetune.reranker.decoder_only.base.modeling

import logging

import torch
from transformers import PreTrainedModel, AutoTokenizer

from FlagEmbedding.abc.finetune.reranker import AbsRerankerModel

logger = logging.getLogger(__name__)


class CrossDecoderModel(AbsRerankerModel):
    """
    Model class for a decoder-only reranker.

    Args:
        base_model (PreTrainedModel): The underlying pre-trained model used for encoding and scoring input pairs.
        tokenizer (AutoTokenizer, optional): The tokenizer for encoding input text. Defaults to ``None``.
        train_batch_size (int, optional): The training batch size. Defaults to ``4``.
    """
    def __init__(
        self,
        base_model: PreTrainedModel,
        tokenizer: AutoTokenizer = None,
        train_batch_size: int = 4,
    ):
        super().__init__(
            base_model,
            tokenizer=tokenizer,
            train_batch_size=train_batch_size,
        )
    def encode(self, features):
        """Encode input features into relevance scores.

        Args:
            features (dict): Dictionary with input features.

        Returns:
            torch.Tensor: One relevance score per input pair, taken from the model's logits.
        """
        if features is None:
            return None
        outputs = self.model(
            input_ids=features['input_ids'],
            attention_mask=features['attention_mask'],
            position_ids=features.get('position_ids'),
            output_hidden_states=True,
        )
        # Alternative (unused): select logits at per-example label positions.
        # _, max_indices = torch.max(features['labels'], dim=1)
        # predict_indices = max_indices
        # logits = [outputs.logits[i, predict_indices[i], :] for i in range(outputs.logits.shape[0])]
        # logits = torch.stack(logits, dim=0)
        # The logit of the "yes" token (``self.yes_loc``) at the last position is the relevance score.
        scores = outputs.logits[:, -1, self.yes_loc]
        return scores.contiguous()
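
# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the module). It assumes a
# decoder-only causal LM checkpoint; "google/gemma-2b" is only a placeholder
# model name, and the prompt template plus ``yes_loc`` setup are assumed to be
# handled by ``AbsRerankerModel`` and the finetuning runner. This sketch shows
# only the scoring path through ``encode``.
if __name__ == "__main__":
    from transformers import AutoModelForCausalLM

    checkpoint = "google/gemma-2b"  # hypothetical choice; any causal LM works
    tok = AutoTokenizer.from_pretrained(checkpoint)
    base = AutoModelForCausalLM.from_pretrained(checkpoint)

    model = CrossDecoderModel(base, tokenizer=tok, train_batch_size=4)
    model.eval()

    # Tokenize one query-passage pair; the real pipeline applies its own
    # prompt format before this step.
    features = tok(
        "query: what is a reranker? passage: a reranker scores query-passage pairs.",
        return_tensors="pt",
    )

    with torch.no_grad():
        scores = model.encode(features)  # shape: (batch_size,)
    print(scores)

    # Design note: scoring with the yes-token logit reuses the pretrained LM
    # head directly, so no new classification head has to be trained from
    # scratch.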