import torch
from torch import nn, Tensor
import torch.nn.functional as F
import torch.distributed as dist
from transformers import AutoTokenizer
from transformers.file_utils import ModelOutput
import logging
from dataclasses import dataclass
from abc import ABC, abstractmethod
from typing import Dict, Optional, List, Union
logger = logging.getLogger(__name__)
@dataclass
class EmbedderOutput(ModelOutput):
"""
Output information returned by the model.
"""
q_reps: Optional[Tensor] = None
p_reps: Optional[Tensor] = None
loss: Optional[Tensor] = None
scores: Optional[Tensor] = None
class AbsEmbedderModel(ABC, nn.Module):
"""Abstract class of embedding model for training.
Args:
base_model: The base model to train on.
tokenizer (AutoTokenizer, optional): The tokenizer to use. Defaults to ``None``.
        negatives_cross_device (bool, optional): If True, will use negatives gathered from other devices when computing the loss. Defaults to ``False``.
temperature (float, optional): Temperature to control the scale of scores. Defaults to ``1.0``.
        sub_batch_size (int, optional): Sub-batch size during encoding. If negative, the inputs will not be
            split into sub-batches. Defaults to ``-1``.
kd_loss_type (str, optional): Type of knowledge distillation loss. Defaults to ``"kl_div"``.
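
    Example:
        A minimal subclass sketch (illustrative only; the CLS pooling and
        dot-product scoring below are assumptions, not requirements of this
        class)::

            class MyEmbedder(AbsEmbedderModel):
                def encode(self, features):
                    out = self.model(**features, return_dict=True)
                    reps = out.last_hidden_state[:, 0]  # CLS pooling (assumed)
                    return F.normalize(reps, dim=-1)

                def compute_score(self, q_reps, p_reps):
                    return q_reps @ p_reps.T / self.temperature

                def compute_loss(self, scores, target):
                    return F.cross_entropy(scores, target)

                def save(self, output_dir):
                    self.model.save_pretrained(output_dir)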
"""
def __init__(
self,
base_model,
        tokenizer: Optional[AutoTokenizer] = None,
negatives_cross_device: bool = False,
temperature: float = 1.0,
sub_batch_size: int = -1,
kd_loss_type: str = 'kl_div',
):
super().__init__()
self.model = base_model
self.tokenizer = tokenizer
self.temperature = temperature
self.negatives_cross_device = negatives_cross_device
if self.negatives_cross_device:
if not dist.is_initialized():
raise ValueError('Distributed training has not been initialized for representation all gather.')
self.process_rank = dist.get_rank()
self.world_size = dist.get_world_size()
self.sub_batch_size = sub_batch_size
self.kd_loss_type = kd_loss_type
@abstractmethod
def encode(self, features):
"""Abstract method encode and get the embedding.
Args:
features (Union[list, dict]): Features feed to the model.
"""
pass
@abstractmethod
def compute_loss(self, scores, target):
"""Abstract method compute the loss.
Args:
scores (torch.Tensor): Computed score.
target (torch.Tensor): The target value.
"""
pass
@abstractmethod
def compute_score(self, q_reps, p_reps):
"""Abstract method to compute the score.
Args:
            q_reps (torch.Tensor): Query representations.
            p_reps (torch.Tensor): Passage representations.
"""
pass
@abstractmethod
def save(self, output_dir: str):
"""Abstract method to save the model.
Args:
output_dir (str): Directory for saving the model.
"""
pass
def get_local_score(self, q_reps, p_reps, all_scores):
"""Get the local score of queries and passages.
Args:
            q_reps (torch.Tensor): Query representations.
            p_reps (torch.Tensor): Passage representations.
all_scores (torch.Tensor): All the query-passage scores computed.
Returns:
torch.Tensor: Local scores to compute loss.
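
        Example:
            A shape sketch (illustrative values; ``model`` is an instance of a
            concrete subclass)::

                q_reps = torch.randn(2, 8)        # batch_size = 2
                p_reps = torch.randn(6, 8)        # group_size = 6 // 2 = 3
                all_scores = q_reps @ p_reps.T    # (2, 6)
                local = model.get_local_score(q_reps, p_reps, all_scores)
                # local has shape (2, 3): row i holds query i's scores against
                # its own passage group, i.e. columns 3*i .. 3*i + 2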
"""
group_size = p_reps.size(0) // q_reps.size(0)
indices = torch.arange(0, q_reps.size(0), device=q_reps.device) * group_size
specific_scores = []
for i in range(group_size):
specific_scores.append(
all_scores[torch.arange(q_reps.size(0), device=q_reps.device), indices + i]
)
return torch.stack(specific_scores, dim=1).view(q_reps.size(0), -1)
def compute_local_score(self, q_reps, p_reps, compute_score_func=None, **kwargs):
"""Compute the local score of queries and passages.
Args:
            q_reps (torch.Tensor): Query representations.
            p_reps (torch.Tensor): Passage representations.
            compute_score_func (function, optional): Function used to compute the scores. Defaults to ``None``,
                which falls back to :meth:`compute_score`.
Returns:
torch.Tensor: Local scores to compute loss.
"""
if compute_score_func is None:
all_scores = self.compute_score(q_reps, p_reps)
else:
all_scores = compute_score_func(q_reps, p_reps, **kwargs)
        local_scores = self.get_local_score(q_reps, p_reps, all_scores)
        return local_scores
def _compute_no_in_batch_neg_loss(self, q_reps, p_reps, teacher_targets=None, compute_score_func=None, **kwargs):
"""
Compute loss when using no in-batch negatives and no cross-device negatives
"""
group_size = p_reps.size(0) // q_reps.size(0)
local_scores = self.compute_local_score(q_reps, p_reps, compute_score_func, **kwargs) # (batch_size, group_size)
if teacher_targets is not None:
# compute kd loss
loss = self.distill_loss(self.kd_loss_type, teacher_targets, local_scores, group_size=group_size)
# add normal loss if needed
if self.kd_loss_type == "kl_div":
local_targets = torch.zeros(local_scores.size(0), device=local_scores.device, dtype=torch.long) # (batch_size)
loss += self.compute_loss(local_scores, local_targets)
else:
local_targets = torch.zeros(local_scores.size(0), device=local_scores.device, dtype=torch.long) # (batch_size)
loss = self.compute_loss(local_scores, local_targets)
return local_scores, loss
def _compute_in_batch_neg_loss(self, q_reps, p_reps, teacher_targets=None, compute_score_func=None, **kwargs):
"""
Compute loss when only using in-batch negatives
"""
group_size = p_reps.size(0) // q_reps.size(0)
if compute_score_func is None:
scores = self.compute_score(q_reps, p_reps) # (batch_size, batch_size * group_size)
else:
scores = compute_score_func(q_reps, p_reps, **kwargs) # (batch_size, batch_size * group_size)
if teacher_targets is not None:
# compute kd loss
if self.kd_loss_type == "kl_div":
student_scores = self.get_local_score(q_reps, p_reps, scores) # (batch_size, group_size)
loss = self.distill_loss(self.kd_loss_type, teacher_targets, student_scores, group_size)
idxs = torch.arange(q_reps.size(0), device=q_reps.device, dtype=torch.long)
                targets = idxs * group_size  # (batch_size)
loss += self.compute_loss(scores, targets)
elif self.kd_loss_type == "m3_kd_loss":
loss = self.distill_loss(self.kd_loss_type, teacher_targets, scores, group_size)
else:
raise ValueError(f"Invalid kd_loss_type: {self.kd_loss_type}")
else:
idxs = torch.arange(q_reps.size(0), device=q_reps.device, dtype=torch.long)
targets = idxs * group_size # (batch_size)
loss = self.compute_loss(scores, targets)
return scores, loss
def _compute_cross_device_neg_loss(self, q_reps, p_reps, teacher_targets=None, compute_score_func=None, **kwargs):
"""
Compute loss when using both in-batch negatives and cross-device negatives
"""
group_size = p_reps.size(0) // q_reps.size(0)
cross_q_reps = self._dist_gather_tensor(q_reps) # (world_size * batch_size, dim)
cross_p_reps = self._dist_gather_tensor(p_reps) # (world_size * batch_size * group_size, dim)
if compute_score_func is None:
cross_scores = self.compute_score(cross_q_reps, cross_p_reps) # (world_size * batch_size, world_size * batch_size * group_size)
else:
cross_scores = compute_score_func(cross_q_reps, cross_p_reps, **kwargs) # (world_size * batch_size, world_size * batch_size * group_size)
if teacher_targets is not None:
# compute kd loss
if self.kd_loss_type == "kl_div":
student_scores = self.get_local_score(cross_q_reps, cross_p_reps, cross_scores) # (world_size * batch_size, group_size)
student_scores = student_scores[
q_reps.size(0)*self.process_rank : q_reps.size(0)*(self.process_rank+1)
] # (batch_size, group_size)
loss = self.distill_loss(self.kd_loss_type, teacher_targets, student_scores, group_size)
cross_idxs = torch.arange(cross_q_reps.size(0), device=cross_q_reps.device, dtype=torch.long)
cross_targets = cross_idxs * group_size # (world_size * batch_size)
loss += self.compute_loss(cross_scores, cross_targets)
elif self.kd_loss_type == "m3_kd_loss":
cross_teacher_targets = self._dist_gather_tensor(teacher_targets) # (world_size * batch_size, group_size)
loss = self.distill_loss(self.kd_loss_type, cross_teacher_targets, cross_scores, group_size)
else:
raise ValueError(f"Invalid kd_loss_type: {self.kd_loss_type}")
else:
cross_idxs = torch.arange(cross_q_reps.size(0), device=cross_q_reps.device, dtype=torch.long)
cross_targets = cross_idxs * group_size # (world_size * batch_size)
loss = self.compute_loss(cross_scores, cross_targets)
return cross_scores, loss
def forward(
self,
        queries: Optional[Union[Dict[str, Tensor], List[Dict[str, Tensor]]]] = None,
        passages: Optional[Union[Dict[str, Tensor], List[Dict[str, Tensor]]]] = None,
teacher_scores: Union[None, List[float]] = None,
no_in_batch_neg_flag: bool = False,
):
"""The computation performed at every call.
Args:
queries (Union[Dict[str, Tensor], List[Dict[str, Tensor]]], optional): Input queries. Defaults to ``None``.
passages (Union[Dict[str, Tensor], List[Dict[str, Tensor]]], optional): Input passages. Defaults to ``None``.
teacher_scores (Union[None, List[float]], optional): Teacher scores for distillation. Defaults to ``None``.
no_in_batch_neg_flag (bool, optional): If True, use no in-batch negatives and no cross-device negatives. Defaults to ``False``.
Returns:
            EmbedderOutput: Output of the forward call of the model.
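
        Example:
            A training-step sketch (illustrative; ``queries`` and ``passages``
            are assumed to be tokenized batches with one positive followed by
            ``group_size - 1`` negatives per query)::

                model.train()
                output = model(queries=queries, passages=passages)
                output.loss.backward()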
"""
q_reps = self.encode(queries) # (batch_size, dim)
p_reps = self.encode(passages) # (batch_size * group_size, dim)
if self.training:
if teacher_scores is not None:
teacher_scores = torch.tensor(teacher_scores, device=q_reps.device)
teacher_scores = teacher_scores.view(q_reps.size(0), -1).detach() # (batch_size, group_size)
teacher_targets = F.softmax(teacher_scores, dim=-1) # (batch_size, group_size)
else:
teacher_targets = None
if no_in_batch_neg_flag:
compute_loss_func = self._compute_no_in_batch_neg_loss
else:
if self.negatives_cross_device:
compute_loss_func = self._compute_cross_device_neg_loss
else:
compute_loss_func = self._compute_in_batch_neg_loss
scores, loss = compute_loss_func(q_reps, p_reps, teacher_targets=teacher_targets)
else:
loss = None
return EmbedderOutput(
loss=loss,
)
@staticmethod
def distill_loss(kd_loss_type, teacher_targets, student_scores, group_size=None):
"""Compute the distillation loss.
Args:
kd_loss_type (str): Type of knowledge distillation loss, supports "kl_div" and "m3_kd_loss".
teacher_targets (torch.Tensor): Targets from the teacher model.
            student_scores (torch.Tensor): Scores of the student model.
            group_size (int, optional): Number of passages per query, required by ``m3_kd_loss``. Defaults to ``None``.
Raises:
ValueError: Invalid kd_loss_type
Returns:
torch.Tensor: A scalar of computed distillation loss.
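
        Example:
            A minimal ``kl_div`` sketch (illustrative shapes)::

                teacher_targets = F.softmax(torch.randn(4, 3), dim=-1)  # (batch_size, group_size)
                student_scores = torch.randn(4, 3)                      # (batch_size, group_size)
                loss = AbsEmbedderModel.distill_loss('kl_div', teacher_targets, student_scores)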
"""
if kd_loss_type == 'kl_div':
# teacher_targets: (batch_size, group_size) / (world_size * batch_size, group_size)
# student_scores: (batch_size, group_size) / (world_size * batch_size, group_size)
return - torch.mean(
torch.sum(torch.log_softmax(student_scores, dim=-1) * teacher_targets, dim=-1)
)
elif kd_loss_type == 'm3_kd_loss':
# teacher_targets: (batch_size, group_size) / (world_size * batch_size, group_size)
# student_scores: (batch_size, batch_size * group_size) / (world_size * batch_size, world_size * batch_size * group_size)
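            # Iterate over positions within each group: treat the i-th passage
            # of every group as the target, weight its cross-entropy by the
            # teacher's probability for that position, then mask the used
            # column so later iterations cannot select it again.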
labels = torch.arange(student_scores.size(0), device=student_scores.device, dtype=torch.long)
labels = labels * group_size
loss = 0
mask = torch.zeros_like(student_scores)
for i in range(group_size):
temp_target = labels + i
temp_scores = student_scores + mask
                temp_loss = F.cross_entropy(temp_scores, temp_target, reduction="none")  # (batch_size)
loss += torch.mean(teacher_targets[:, i] * temp_loss)
mask = torch.scatter(mask, dim=-1, index=temp_target.unsqueeze(-1),
value=torch.finfo(student_scores.dtype).min)
return loss
else:
raise ValueError(f"Invalid kd_loss_type: {kd_loss_type}")
def _dist_gather_tensor(self, t: Optional[torch.Tensor]):
"""Gather a tensor from all processes in a distributed setting.
Args:
            t (Optional[torch.Tensor]): The input tensor to be gathered. If ``None``, no gathering is performed.
Returns:
Union[torch.Tensor, None]: A concatenated tensor from all processes if ``t`` is not ``None``,
otherwise returns ``None``.
"""
if t is None:
return None
t = t.contiguous()
all_tensors = [torch.empty_like(t) for _ in range(self.world_size)]
dist.all_gather(all_tensors, t)
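        # all_gather returns tensors detached from the autograd graph; put the
        # local rank's original tensor back so gradients still flow through
        # this rank's slice.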
all_tensors[self.process_rank] = t
all_tensors = torch.cat(all_tensors, dim=0)
return all_tensors