import os
import math
import random
import logging
import datasets
import numpy as np
import torch.distributed as dist
from dataclasses import dataclass
from torch.utils.data import Dataset
from transformers import (
PreTrainedTokenizer,
DataCollatorWithPadding,
BatchEncoding,
DataCollatorForSeq2Seq
)
from typing import List
from .AbsArguments import AbsRerankerDataArguments
logger = logging.getLogger(__name__)
class AbsRerankerTrainDataset(Dataset):
    """Abstract class for reranker training dataset.

    Args:
        args (AbsRerankerDataArguments): Data arguments.
        tokenizer (PreTrainedTokenizer): Tokenizer to use.
    """
    def __init__(
        self,
        args: AbsRerankerDataArguments,
        tokenizer: PreTrainedTokenizer
    ):
        self.args = args
        self.tokenizer = tokenizer

        train_datasets = []
        for data_dir in args.train_data:
            # Each entry is either a single .json/.jsonl file or a directory
            # containing such files; anything else is silently skipped.
            if not os.path.isdir(data_dir):
                if not (data_dir.endswith('.json') or data_dir.endswith('.jsonl')):
                    continue
                temp_dataset = self._load_dataset(data_dir)
                if len(temp_dataset) == 0:
                    continue
                train_datasets.append(temp_dataset)
            else:
                for file in os.listdir(data_dir):
                    if not (file.endswith('.json') or file.endswith('.jsonl')):
                        continue
                    temp_dataset = self._load_dataset(os.path.join(data_dir, file))
                    if len(temp_dataset) == 0:
                        continue
                    train_datasets.append(temp_dataset)
        self.dataset = datasets.concatenate_datasets(train_datasets)
        # Upper bound for one concatenated (query, passage) input.
        self.max_length = self.args.query_max_len + self.args.passage_max_len

    def _load_dataset(self, file_path: str):
        """Load dataset from path.

        Args:
            file_path (str): Path to load the datasets from.

        Raises:
            ValueError: `pos_scores` and `neg_scores` not found in the features of training data

        Returns:
            datasets.Dataset: Loaded HF dataset.
        """
        # Only log on rank 0; the guard keeps single-process (non-distributed)
        # runs from crashing on an uninitialized process group.
        if not (dist.is_available() and dist.is_initialized()) or dist.get_rank() == 0:
            logger.info('loading data from %s ...', file_path)

        temp_dataset = datasets.load_dataset('json', data_files=file_path, split='train', cache_dir=self.args.cache_path)
        if len(temp_dataset) > self.args.max_example_num_per_dataset:
            # Randomly subsample (without replacement) datasets that exceed the cap.
            temp_dataset = temp_dataset.select(
                random.sample(range(len(temp_dataset)), self.args.max_example_num_per_dataset)
            )
        if not self.args.knowledge_distillation:
            # Teacher scores are unused without distillation; drop them so all
            # loaded datasets share a common schema for concatenation.
            if 'pos_scores' in temp_dataset.column_names:
                temp_dataset = temp_dataset.remove_columns(['pos_scores'])
            if 'neg_scores' in temp_dataset.column_names:
                temp_dataset = temp_dataset.remove_columns(['neg_scores'])
        else:
            if 'pos_scores' not in temp_dataset.column_names or 'neg_scores' not in temp_dataset.column_names:
                raise ValueError(f"`pos_scores` and `neg_scores` not found in the features of training data in {file_path}, which is necessary when using knowledge distillation.")
        return temp_dataset

    def _shuffle_text(self, text):
        """shuffle the input text.

        Args:
            text (str): Input text.

        Returns:
            str: Shuffled text.
        """
        # Shuffle only sufficiently long texts, with probability shuffle_ratio:
        # the text is cut into three roughly equal chunks that are permuted.
        if self.args.shuffle_ratio > 0 and len(text) > 100 and random.random() < self.args.shuffle_ratio:
            chunk_size = len(text) // 3 + 1
            split_text = [text[i:i + chunk_size] for i in range(0, len(text), chunk_size)]
            random.shuffle(split_text)
            return " ".join(split_text)
        else:
            return text

    def __len__(self):
        return len(self.dataset)

    def create_one_example(self, qry_encoding: str, doc_encoding: str):
        """Creates a single input example by encoding and preparing a query and document pair for the model.

        Args:
            qry_encoding (str): Query to be encoded.
            doc_encoding (str): Document to be encoded.

        Returns:
            dict: A dictionary containing tokenized and prepared inputs, ready for model consumption.
        """
        # Each side may borrow part of the other's token budget; the combined
        # pair is re-truncated to query_max_len + passage_max_len below.
        qry_inputs = self.tokenizer.encode(
            qry_encoding,
            truncation=True,
            max_length=self.args.query_max_len + self.args.passage_max_len // 4,
            add_special_tokens=False
        )
        doc_inputs = self.tokenizer.encode(
            doc_encoding,
            truncation=True,
            max_length=self.args.passage_max_len + self.args.query_max_len // 2,
            add_special_tokens=False
        )
        item = self.tokenizer.prepare_for_model(
            qry_inputs,
            doc_inputs,
            truncation='only_second',
            max_length=self.args.query_max_len + self.args.passage_max_len,
            padding=False,
        )
        return item

    def __getitem__(self, item):
        """Build one training group: one positive plus sampled negatives.

        Args:
            item (int): Row index into the concatenated dataset.

        Returns:
            tuple: ``(batch_data, teacher_scores)`` where ``batch_data`` is a
            list of tokenized (query, passage) pairs (positive first) and
            ``teacher_scores`` is a matching list of floats, or ``None`` when
            knowledge distillation is disabled.
        """
        data = self.dataset[item]
        train_group_size = self.args.train_group_size

        query = data['query']
        if self.args.query_instruction_for_rerank is not None:
            # Per-example prompt overrides the global instruction when present.
            query = self.args.query_instruction_format.format(
                data['query_prompt'] if 'query_prompt' in data else self.args.query_instruction_for_rerank,
                query
            )

        passages = []
        teacher_scores = []
        assert isinstance(data['pos'], list) and isinstance(data['neg'], list)

        pos_idx = random.randrange(len(data['pos']))
        passages.append(self._shuffle_text(data['pos'][pos_idx]))

        neg_all_idx = list(range(len(data['neg'])))
        if len(data['neg']) < train_group_size - 1:
            # Not enough distinct negatives: tile the index list so we can
            # still draw train_group_size - 1 (repeats allowed across tiles).
            num = math.ceil((train_group_size - 1) / len(data['neg']))
            neg_idxs = random.sample(neg_all_idx * num, train_group_size - 1)
        else:
            neg_idxs = random.sample(neg_all_idx, train_group_size - 1)
        for neg_idx in neg_idxs:
            passages.append(data['neg'][neg_idx])

        if self.args.knowledge_distillation:
            assert isinstance(data['pos_scores'], list) and isinstance(data['neg_scores'], list)
            teacher_scores.append(data['pos_scores'][pos_idx])
            for neg_idx in neg_idxs:
                teacher_scores.append(data['neg_scores'][neg_idx])
            if not all(isinstance(score, (int, float)) for score in teacher_scores):
                raise ValueError("pos_score or neg_score must be digit")
        else:
            teacher_scores = None

        if self.args.passage_instruction_for_rerank is not None:
            passages = [
                self.args.passage_instruction_format.format(
                    data['passage_prompt'] if 'passage_prompt' in data else self.args.passage_instruction_for_rerank, p
                )
                for p in passages
            ]

        batch_data = [self.create_one_example(query, passage) for passage in passages]

        return batch_data, teacher_scores
@dataclass
class AbsRerankerCollator(DataCollatorWithPadding):
    """
    The abstract reranker collator.

    Flattens per-example lists of tokenized (query, passage) pairs into one
    batch, pads them with the tokenizer, and bundles the optional teacher
    scores used for knowledge distillation.
    """
    query_max_len: int = 32
    passage_max_len: int = 128

    def __call__(self, features) -> dict:
        # features: list of (pair_encodings, teacher_scores_or_None) tuples
        # as produced by AbsRerankerTrainDataset.__getitem__.
        teacher_scores = [f[1] for f in features]
        if teacher_scores[0] is None:
            teacher_scores = None
        elif isinstance(teacher_scores[0], list):
            # One flat score list for the whole batch.
            teacher_scores = sum(teacher_scores, [])

        features = [f[0] for f in features]
        if isinstance(features[0], list):
            # Flatten the per-example groups into a single list of encodings.
            features = sum(features, [])

        collated = self.tokenizer.pad(
            features,
            padding=self.padding,
            max_length=self.query_max_len + self.passage_max_len,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors=self.return_tensors,
        )
        return {
            "pair": collated,
            "teacher_scores": teacher_scores,
        }
class AbsLLMRerankerTrainDataset(AbsRerankerTrainDataset):
    """Abstract class for LLM reranker training dataset.

    Args:
        args (AbsRerankerDataArguments): Data arguments.
        tokenizer (PreTrainedTokenizer): Tokenizer to use.
    """
    def __init__(
        self,
        args: AbsRerankerDataArguments,
        tokenizer: PreTrainedTokenizer
    ):
        super().__init__(args, tokenizer)
        # Pre-tokenize the separator once; it is re-used between the query,
        # passage and prompt segments of every example.
        sep = self.args.sep_token
        self.sep_inputs = self.tokenizer(
            sep,
            return_tensors=None,
            add_special_tokens=False
        )['input_ids']

    def __getitem__(self, item) -> List[BatchEncoding]:
        """Build one LLM-style training group for row ``item``.

        Each passage is encoded as
        ``[bos?] query <sep> passage <sep> prompt`` token ids.

        Returns:
            tuple: ``(passages_inputs, teacher_scores)`` — a list of
            encodings (positive first) and the matching teacher scores, or
            ``None`` when knowledge distillation is disabled.
        """
        data = self.dataset[item]
        train_group_size = self.args.train_group_size

        query = data['query']
        if self.args.query_instruction_for_rerank is not None:
            query = self.args.query_instruction_format.format(
                data['query_prompt'] if 'query_prompt' in data else self.args.query_instruction_for_rerank,
                query
            )

        passages = []
        teacher_scores = []
        assert isinstance(data['pos'], list) and isinstance(data['neg'], list)

        pos_idx = random.randrange(len(data['pos']))
        passages.append(self._shuffle_text(data['pos'][pos_idx]))

        neg_all_idx = list(range(len(data['neg'])))
        if len(data['neg']) < train_group_size - 1:
            # Tile the index list when there are too few distinct negatives.
            num = math.ceil((train_group_size - 1) / len(data['neg']))
            neg_idxs = random.sample(neg_all_idx * num, train_group_size - 1)
        else:
            neg_idxs = random.sample(neg_all_idx, train_group_size - 1)
        for neg_idx in neg_idxs:
            passages.append(data['neg'][neg_idx])

        if self.args.knowledge_distillation:
            assert isinstance(data['pos_scores'], list) and isinstance(data['neg_scores'], list)
            teacher_scores.append(data['pos_scores'][pos_idx])
            for neg_idx in neg_idxs:
                teacher_scores.append(data['neg_scores'][neg_idx])
            if not all(isinstance(score, (int, float)) for score in teacher_scores):
                raise ValueError("pos_score or neg_score must be digit")
        else:
            teacher_scores = None

        if self.args.passage_instruction_for_rerank is not None:
            passages = [
                self.args.passage_instruction_format.format(
                    data['passage_prompt'] if 'passage_prompt' in data else self.args.passage_instruction_for_rerank, p
                )
                for p in passages
            ]

        # Reuse the already-fetched row instead of decoding self.dataset[item] again.
        prompt = data.get('prompt', "Given a query A and a passage B, determine whether the passage contains an answer to the query by providing a prediction of either 'Yes' or 'No'.")

        query_inputs = self.tokenizer(
            query,
            return_tensors=None,
            max_length=self.args.query_max_len + self.args.passage_max_len // 4,
            truncation=True,
            add_special_tokens=False
        )
        prompt_inputs = self.tokenizer(
            prompt,
            return_tensors=None,
            add_special_tokens=False
        )['input_ids']
        # Budget left for query + passage after reserving prompt and separator.
        max_length = self.max_length - len(prompt_inputs) - len(self.sep_inputs)

        passages_inputs = []
        for passage in passages:
            passage_inputs = self.tokenizer(
                passage,
                return_tensors=None,
                max_length=self.args.passage_max_len + self.args.query_max_len // 2,
                truncation=True,
                add_special_tokens=False
            )
            # Prepend bos manually only when it is distinct from pad; special
            # tokens are otherwise suppressed so the layout stays explicit.
            if self.tokenizer.bos_token_id is not None and self.tokenizer.bos_token_id != self.tokenizer.pad_token_id:
                pair_inputs = self.tokenizer.prepare_for_model(
                    [self.tokenizer.bos_token_id] + query_inputs['input_ids'],
                    self.sep_inputs + passage_inputs['input_ids'],
                    truncation='only_second',
                    max_length=max_length,
                    padding=False,
                    return_attention_mask=False,
                    return_token_type_ids=False,
                    add_special_tokens=False
                )
            else:
                pair_inputs = self.tokenizer.prepare_for_model(
                    query_inputs['input_ids'],
                    self.sep_inputs + passage_inputs['input_ids'],
                    truncation='only_second',
                    max_length=max_length,
                    padding=False,
                    return_attention_mask=False,
                    return_token_type_ids=False,
                    add_special_tokens=False
                )
            passage_inputs['input_ids'] = pair_inputs['input_ids'] + self.sep_inputs + prompt_inputs
            passage_inputs['attention_mask'] = [1] * len(passage_inputs['input_ids'])
            passage_inputs.pop('token_type_ids', None)
            if 'position_ids' in passage_inputs:
                # Rebuild position ids to match the final concatenated length.
                passage_inputs['position_ids'] = list(range(len(passage_inputs['input_ids'])))
            passages_inputs.append(passage_inputs)

        return passages_inputs, teacher_scores
@dataclass
class AbsLLMRerankerCollator(DataCollatorForSeq2Seq):
    """
    The abstract LLM reranker collator.

    Flattens per-example lists of query+passage+prompt encodings into one
    batch, pads any ``labels`` to a common length (``tokenizer.pad`` does not
    pad labels but requires equal lengths to build tensors), pads the token
    inputs, and bundles the optional teacher scores.
    """
    query_max_len: int = 32
    passage_max_len: int = 128

    def __call__(self, features, return_tensors='pt'):
        if return_tensors is None:
            return_tensors = self.return_tensors

        # features: list of (encodings, teacher_scores_or_None) tuples as
        # produced by AbsLLMRerankerTrainDataset.__getitem__.
        teacher_scores = [f[1] for f in features]
        if teacher_scores[0] is None:
            teacher_scores = None
        elif isinstance(teacher_scores[0], list):
            teacher_scores = sum(teacher_scores, [])

        features = [f[0] for f in features]
        if isinstance(features[0], list):
            features = sum(features, [])

        labels = [feature["labels"] for feature in features] if "labels" in features[0] else None
        # We have to pad the labels before calling `tokenizer.pad` as this method won't pad them and needs them of the
        # same length to return tensors.
        if labels is not None:
            max_label_length = max(len(l) for l in labels)
            if self.pad_to_multiple_of is not None:
                # Round up so padded length is a multiple of pad_to_multiple_of.
                max_label_length = (
                    (max_label_length + self.pad_to_multiple_of - 1)
                    // self.pad_to_multiple_of
                    * self.pad_to_multiple_of
                )

            padding_side = self.tokenizer.padding_side
            for feature in features:
                remainder = [self.label_pad_token_id] * (max_label_length - len(feature["labels"]))
                if isinstance(feature["labels"], list):
                    feature["labels"] = (
                        feature["labels"] + remainder
                        if padding_side == "right" else remainder + feature["labels"]
                    )
                elif padding_side == "right":
                    feature["labels"] = np.concatenate([feature["labels"], remainder]).astype(np.int64)
                else:
                    feature["labels"] = np.concatenate([remainder, feature["labels"]]).astype(np.int64)

        collated = self.tokenizer.pad(
            features,
            padding=self.padding,
            max_length=self.query_max_len + self.passage_max_len,
            return_tensors=return_tensors,
            pad_to_multiple_of=self.pad_to_multiple_of,
        )
        return {
            "pair": collated,
            "teacher_scores": teacher_scores,
        }