Source code for FlagEmbedding.finetune.reranker.decoder_only.layerwise.modeling
import torch
from transformers import PreTrainedModel, AutoTokenizer
import logging
from typing import List, Union, Dict, Optional
from torch import Tensor
from FlagEmbedding.abc.finetune.reranker import AbsRerankerModel, RerankerOutput
logger = logging.getLogger(__name__)
class CrossDecoderModel(AbsRerankerModel):
"""
Model class for decoder only reranker.
Args:
base_model (PreTrainedModel): The underlying pre-trained model used for encoding and scoring input pairs.
tokenizer (AutoTokenizer, optional): The tokenizer for encoding input text. Defaults to ``None``.
train_batch_size (int, optional): The batch size to use. Defaults to ``4``.
start_layer (int, optional): Starting layer for layerwise. Defaults to ``8``.
"""
    def __init__(
        self,
        base_model: PreTrainedModel,
        tokenizer: AutoTokenizer = None,
        train_batch_size: int = 4,
        start_layer: int = 8,
    ):
        super().__init__(
            base_model,
            tokenizer=tokenizer,
            train_batch_size=train_batch_size,
        )
        self.start_layer = start_layer
    def encode(self, features):
        if features is None:
            return None
        outputs = self.model(input_ids=features['input_ids'],
                             attention_mask=features['attention_mask'],
                             position_ids=features.get('position_ids', None),
                             output_hidden_states=True)
        # For the layerwise model, ``outputs.logits`` is expected to hold one
        # score tensor per selected layer; keep the score of the last token of
        # every sequence as that layer's relevance score.
        all_logits = outputs.logits
        all_scores = []
        for logits in all_logits:
            all_scores.append(logits[:, -1].contiguous())
        return all_scores
    def forward(self, pair: Union[Dict[str, Tensor], List[Dict[str, Tensor]]] = None, teacher_scores: Optional[Tensor] = None):
        # List of per-layer score tensors, one entry for each selected layer.
        ranker_logits = self.encode(pair)

        if self.training:
            loss = 0
            # Ranking loss for every layer: within each group of candidates
            # for a query, the first candidate is the positive, so the target
            # index is 0.
            for logits in ranker_logits:
                grouped_logits = logits.view(self.train_batch_size, -1)
                target = torch.zeros(self.train_batch_size, device=grouped_logits.device, dtype=torch.long)
                loss += self.compute_loss(grouped_logits, target)

            if teacher_scores is None:
                # Self-distillation: the scores of the deepest layer act as
                # the teacher for all shallower layers.
                teacher_scores = ranker_logits[-1].view(
                    self.train_batch_size,
                    -1
                )
                teacher_targets = torch.softmax(teacher_scores.detach(), dim=-1)
                for logits in ranker_logits[:-1]:
                    student_scores = logits.view(
                        self.train_batch_size,
                        -1
                    )
                    loss += - torch.mean(torch.sum(torch.log_softmax(student_scores, dim=-1) * teacher_targets, dim=-1))
            else:
                # Distillation from externally provided teacher scores: every
                # layer is trained to match the teacher distribution.
                teacher_scores = torch.Tensor(teacher_scores)
                teacher_scores = teacher_scores.view(self.train_batch_size, -1)
                teacher_targets = torch.softmax(teacher_scores.detach(), dim=-1).to(ranker_logits[-1].device)
                for logits in ranker_logits:
                    student_scores = logits.view(
                        self.train_batch_size,
                        -1
                    )
                    loss += - torch.mean(torch.sum(torch.log_softmax(student_scores, dim=-1) * teacher_targets, dim=-1))
        else:
            loss = None

        return RerankerOutput(
            loss=loss,
            scores=ranker_logits,
        )
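

# ---------------------------------------------------------------------------
# Hedged sketch, not part of the library source: a self-contained toy run of
# the same layerwise objective computed in ``forward`` above, using random
# tensors instead of a real base model. Every shape and name below is an
# illustrative assumption (3 selected layers, 2 queries per batch, 4
# candidates per query with the first one positive), and a plain
# ``CrossEntropyLoss`` stands in for ``compute_loss``.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)

    train_batch_size, group_size, seq_len, num_layers = 2, 4, 16, 3

    # Stand-in for ``outputs.logits``: one (batch * group, seq_len) tensor per layer.
    all_logits = [torch.randn(train_batch_size * group_size, seq_len)
                  for _ in range(num_layers)]

    # Mirror ``encode``: keep only the last-token score of each sequence.
    ranker_logits = [logits[:, -1].contiguous() for logits in all_logits]

    # Mirror the ``teacher_scores is None`` branch of ``forward``: every layer
    # gets a cross-entropy ranking loss against the first candidate, and the
    # deepest layer additionally teaches the shallower ones.
    ce = torch.nn.CrossEntropyLoss()
    loss = torch.tensor(0.0)
    for scores in ranker_logits:
        grouped = scores.view(train_batch_size, -1)               # (batch, group)
        target = torch.zeros(train_batch_size, dtype=torch.long)  # positives at index 0
        loss = loss + ce(grouped, target)

    teacher_targets = torch.softmax(
        ranker_logits[-1].view(train_batch_size, -1).detach(), dim=-1)
    for scores in ranker_logits[:-1]:
        student = torch.log_softmax(scores.view(train_batch_size, -1), dim=-1)
        loss = loss - torch.mean(torch.sum(student * teacher_targets, dim=-1))

    print(f"toy layerwise loss: {loss.item():.4f}")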