Source code for FlagEmbedding.abc.finetune.embedder.AbsDataset

import os
import math
import random
import logging
import datasets
import numpy as np
import torch.distributed as dist
from dataclasses import dataclass
from torch.utils.data import Dataset
from transformers import (
    PreTrainedTokenizer, 
    DataCollatorWithPadding,
    TrainerCallback,
    TrainerState,
    TrainerControl
)

from .AbsArguments import AbsEmbedderDataArguments, AbsEmbedderTrainingArguments

logger = logging.getLogger(__name__)


class AbsEmbedderTrainDataset(Dataset):
    """Abstract class for training dataset.

    Args:
        args (AbsEmbedderDataArguments): Data arguments.
        tokenizer (PreTrainedTokenizer): Tokenizer to use.
    """
    def __init__(
        self,
        args: AbsEmbedderDataArguments,
        tokenizer: PreTrainedTokenizer
    ):
        self.args = args
        self.tokenizer = tokenizer
        self.shuffle_ratio = args.shuffle_ratio

        train_datasets = []
        for data_dir in args.train_data:
            if not os.path.isdir(data_dir):
                if not (data_dir.endswith('.json') or data_dir.endswith('.jsonl')):
                    continue
                temp_dataset = self._load_dataset(data_dir)
                if len(temp_dataset) == 0:
                    continue
                train_datasets.append(temp_dataset)
            else:
                for file in os.listdir(data_dir):
                    if not (file.endswith('.json') or file.endswith('.jsonl')):
                        continue
                    temp_dataset = self._load_dataset(os.path.join(data_dir, file))
                    if len(temp_dataset) == 0:
                        continue
                    train_datasets.append(temp_dataset)
        self.dataset = datasets.concatenate_datasets(train_datasets)

    def _load_dataset(self, file_path: str):
        """Load dataset from path.

        Args:
            file_path (str): Path to load the datasets from.

        Raises:
            ValueError: `pos_scores` and `neg_scores` not found in the features of training data.

        Returns:
            datasets.Dataset: Loaded HF dataset.
        """
        if dist.get_rank() == 0:
            logger.info(f'loading data from {file_path} ...')

        temp_dataset = datasets.load_dataset('json', data_files=file_path, split='train', cache_dir=self.args.cache_path)
        if len(temp_dataset) > self.args.max_example_num_per_dataset:
            temp_dataset = temp_dataset.select(
                random.sample(list(range(len(temp_dataset))), self.args.max_example_num_per_dataset)
            )
        if not self.args.knowledge_distillation:
            if 'pos_scores' in temp_dataset.column_names:
                temp_dataset = temp_dataset.remove_columns(['pos_scores'])
            if 'neg_scores' in temp_dataset.column_names:
                temp_dataset = temp_dataset.remove_columns(['neg_scores'])
        else:
            if 'pos_scores' not in temp_dataset.column_names or 'neg_scores' not in temp_dataset.column_names:
                raise ValueError(
                    f"`pos_scores` and `neg_scores` not found in the features of training data in {file_path}, "
                    "which is necessary when using knowledge distillation."
                )
        return temp_dataset

    def _shuffle_text(self, text):
        """Shuffle the input text.

        Args:
            text (str): Input text.

        Returns:
            str: Shuffled text.
        """
        if self.shuffle_ratio > 0 and len(text) > 100 and random.random() < self.shuffle_ratio:
            # split the text into three roughly equal chunks and shuffle their order
            split_text = []
            chunk_size = len(text) // 3 + 1
            for i in range(0, len(text), chunk_size):
                split_text.append(text[i:i + chunk_size])
            random.shuffle(split_text)
            return " ".join(split_text)
        else:
            return text

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, item):
        data = self.dataset[item]
        train_group_size = self.args.train_group_size

        query = data['query']
        if self.args.query_instruction_for_retrieval is not None:
            query = self.args.query_instruction_format.format(
                data['prompt'] if 'prompt' in data else self.args.query_instruction_for_retrieval,
                query
            )

        passages = []
        teacher_scores = []

        assert isinstance(data['pos'], list) and isinstance(data['neg'], list)

        pos_idx = random.choice(list(range(len(data['pos']))))
        passages.append(self._shuffle_text(data['pos'][pos_idx]))

        neg_all_idx = list(range(len(data['neg'])))
        if len(data['neg']) < train_group_size - 1:
            # oversample the negatives when there are fewer than needed
            num = math.ceil((train_group_size - 1) / len(data['neg']))
            neg_idxs = random.sample(neg_all_idx * num, train_group_size - 1)
        else:
            neg_idxs = random.sample(neg_all_idx, train_group_size - 1)
        for neg_idx in neg_idxs:
            passages.append(data['neg'][neg_idx])

        if self.args.knowledge_distillation:
            assert isinstance(data['pos_scores'], list) and isinstance(data['neg_scores'], list)
            teacher_scores.append(data['pos_scores'][pos_idx])
            for neg_idx in neg_idxs:
                teacher_scores.append(data['neg_scores'][neg_idx])
            if not all(isinstance(score, (int, float)) for score in teacher_scores):
                raise ValueError("pos_score or neg_score must be a number")
        else:
            teacher_scores = None

        if self.args.passage_instruction_for_retrieval is not None:
            passages = [
                self.args.passage_instruction_format.format(
                    self.args.passage_instruction_for_retrieval, p
                )
                for p in passages
            ]

        return query, passages, teacher_scores

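# Illustrative data-format sketch (not part of the library source), inferred from
# `__getitem__` above: each JSON/JSONL record needs `query`, a `pos` list, and a
# `neg` list; `pos_scores`/`neg_scores` are required only when
# `knowledge_distillation` is enabled, and `prompt` optionally overrides
# `query_instruction_for_retrieval`. The field values below are made up.
#
#     {
#         "query": "what is a black hole",
#         "pos": ["A black hole is a region of spacetime where ..."],
#         "neg": ["A wormhole is a hypothetical structure ...", "..."],
#         "pos_scores": [1.0],
#         "neg_scores": [0.3, 0.1],
#         "prompt": "Represent this query for retrieval: "
#     }
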
@dataclass
class AbsEmbedderCollator(DataCollatorWithPadding):
    """
    The abstract embedder collator.
    """
    query_max_len: int = 32
    passage_max_len: int = 128
    sub_batch_size: int = -1

    def __call__(self, features):
        queries = [f[0] for f in features]
        passages = [f[1] for f in features]
        teacher_scores = [f[2] for f in features]

        if teacher_scores[0] is None:
            teacher_scores = None
        elif isinstance(teacher_scores[0], list):
            teacher_scores = sum(teacher_scores, [])

        if isinstance(queries[0], list):
            queries = sum(queries, [])
        if isinstance(passages[0], list):
            passages = sum(passages, [])

        queries_inputs = self.tokenizer(
            queries,
            truncation=True,
            max_length=self.query_max_len,
            return_tensors=None
        )
        passages_inputs = self.tokenizer(
            passages,
            truncation=True,
            max_length=self.passage_max_len,
            return_tensors=None
        )

        if self.sub_batch_size is None or self.sub_batch_size <= 0:
            q_collated = self.tokenizer.pad(
                queries_inputs,
                padding=self.padding,
                max_length=self.query_max_len,
                pad_to_multiple_of=self.pad_to_multiple_of,
                return_tensors=self.return_tensors
            )
            d_collated = self.tokenizer.pad(
                passages_inputs,
                padding=self.padding,
                max_length=self.passage_max_len,
                pad_to_multiple_of=self.pad_to_multiple_of,
                return_tensors=self.return_tensors
            )
        else:
            batch_size = self.sub_batch_size

            q_collated = []
            for i in range(0, len(queries_inputs['attention_mask']), batch_size):
                start = i
                end = min(len(queries_inputs['attention_mask']), i + batch_size)
                sub_features = {}
                for k, v in queries_inputs.items():
                    sub_features[k] = v[start:end]
                q_collated.append(self.tokenizer.pad(
                    sub_features,
                    padding=self.padding,
                    max_length=self.query_max_len,
                    pad_to_multiple_of=self.pad_to_multiple_of,
                    return_tensors=self.return_tensors
                ))

            d_collated = []
            for i in range(0, len(passages_inputs['attention_mask']), batch_size):
                start = i
                end = min(len(passages_inputs['attention_mask']), i + batch_size)
                sub_features = {}
                for k, v in passages_inputs.items():
                    sub_features[k] = v[start:end]
                d_collated.append(self.tokenizer.pad(
                    sub_features,
                    padding=self.padding,
                    max_length=self.passage_max_len,
                    pad_to_multiple_of=self.pad_to_multiple_of,
                    return_tensors=self.return_tensors
                ))

        return {
            "queries": q_collated,
            "passages": d_collated,
            "teacher_scores": teacher_scores,
            "no_in_batch_neg_flag": False
        }

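# Minimal wiring sketch (illustrative; the model name and batch size are
# placeholders, and `train_dataset` is assumed to be an `AbsEmbedderTrainDataset`
# built as above): the collator plugs into a plain PyTorch DataLoader and yields
# the dict consumed by the trainer.
#
#     from torch.utils.data import DataLoader
#     from transformers import AutoTokenizer
#
#     tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-base-en-v1.5")
#     collator = AbsEmbedderCollator(
#         tokenizer=tokenizer, query_max_len=32, passage_max_len=128
#     )
#     loader = DataLoader(train_dataset, batch_size=8, collate_fn=collator)
#     batch = next(iter(loader))   # {'queries': ..., 'passages': ..., ...}
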
class AbsEmbedderSameDatasetTrainDataset(AbsEmbedderTrainDataset):
    """Abstract class for training dataset that samples batches from the same dataset.

    Args:
        args (AbsEmbedderDataArguments): Data arguments.
        default_batch_size (int): The default batch size for training.
        seed (int): Random seed.
        tokenizer (PreTrainedTokenizer): Tokenizer to use.
        process_index (int, optional): Current process index. Defaults to 0.
        num_processes (int, optional): Total number of processes. Defaults to 1.
    """
    def __init__(
        self,
        args: AbsEmbedderDataArguments,
        default_batch_size: int,
        seed: int,
        tokenizer: PreTrainedTokenizer,
        process_index: int = 0,
        num_processes: int = 1
    ):
        self.args = args
        self.shuffle_ratio = args.shuffle_ratio
        self.default_batch_size = default_batch_size
        self.deterministic_generator = np.random.default_rng(seed)
        self.tokenizer = tokenizer
        self.process_index = process_index
        self.num_processes = num_processes
        self.step = 0

        train_datasets = []
        each_data_idxs = []
        batch_size_idxs = []
        no_in_batch_neg_flags = []
        cur_all_num = 0

        small_threshold = args.small_threshold
        drop_threshold = args.drop_threshold

        for data_dir in args.train_data:
            if not os.path.isdir(data_dir):
                # A `no_in_batch_neg` suffix on `data_dir` (before the file extension)
                # indicates that this dataset does not use in-batch negatives.
                no_in_batch_neg_flag = data_dir.split('.')[-2].endswith('no_in_batch_neg')
                if not (data_dir.endswith('.json') or data_dir.endswith('.jsonl')):
                    continue
                temp_dataset = self._load_dataset(data_dir)
                if len(temp_dataset) == 0 or len(temp_dataset) < small_threshold:
                    continue
                else:
                    train_datasets.append(temp_dataset)
                    each_data_idxs.append(np.arange(len(temp_dataset)) + cur_all_num)
                    cur_all_num += len(temp_dataset)
                    batch_size_idxs.append(self._get_file_batch_size(temp_dataset, default_batch_size))
                    no_in_batch_neg_flags.append(no_in_batch_neg_flag)
            else:
                small_datasets = []
                small_batch_size = math.inf

                # A `no_in_batch_neg` suffix on the directory name indicates that
                # this dataset does not use in-batch negatives.
                no_in_batch_neg_flag = data_dir.endswith('no_in_batch_neg')
                for file in os.listdir(data_dir):
                    if not (file.endswith('.json') or file.endswith('.jsonl')):
                        continue
                    temp_dataset = self._load_dataset(os.path.join(data_dir, file))
                    if len(temp_dataset) == 0:
                        continue
                    elif len(temp_dataset) < small_threshold:
                        small_datasets.append(temp_dataset)
                        small_batch_size = min(small_batch_size, self._get_file_batch_size(temp_dataset, default_batch_size))
                    else:
                        train_datasets.append(temp_dataset)
                        each_data_idxs.append(np.arange(len(temp_dataset)) + cur_all_num)
                        cur_all_num += len(temp_dataset)
                        batch_size_idxs.append(self._get_file_batch_size(temp_dataset, default_batch_size))
                        no_in_batch_neg_flags.append(no_in_batch_neg_flag)

                if len(small_datasets) > 0:
                    small_dataset = datasets.concatenate_datasets(small_datasets)
                    if len(small_dataset) >= drop_threshold:
                        train_datasets.append(small_dataset)
                        each_data_idxs.append(np.arange(len(small_dataset)) + cur_all_num)
                        cur_all_num += len(small_dataset)
                        batch_size_idxs.append(small_batch_size)
                        no_in_batch_neg_flags.append(no_in_batch_neg_flag)

        self.dataset = datasets.concatenate_datasets(train_datasets)
        self.each_data_idxs = each_data_idxs
        self.datasets_inxs = np.arange(len(each_data_idxs))
        self.batch_size_idxs = batch_size_idxs
        self.no_in_batch_neg_flags = no_in_batch_neg_flags

        self.refresh_epoch()

    def _load_dataset(self, file_path: str):
        """Load dataset from the given path.

        Args:
            file_path (str): The path to load or download from HF hub.

        Returns:
            datasets.Dataset: The loaded dataset.
        """
        if dist.get_rank() == 0:
            logger.info(f'loading data from {file_path} ...')

        temp_dataset = datasets.load_dataset('json', data_files=file_path, split='train', cache_dir=self.args.cache_path)
        if len(temp_dataset) > self.args.max_example_num_per_dataset:
            temp_dataset = temp_dataset.select(
                random.sample(list(range(len(temp_dataset))), self.args.max_example_num_per_dataset)
            )
        if not self.args.knowledge_distillation:
            if 'pos_scores' in temp_dataset.column_names:
                temp_dataset = temp_dataset.remove_columns(['pos_scores'])
            if 'neg_scores' in temp_dataset.column_names:
                temp_dataset = temp_dataset.remove_columns(['neg_scores'])
        return temp_dataset

    @staticmethod
    def _get_file_batch_size(temp_dataset: datasets.Dataset, default_batch_size: int):
        """Get the appropriate batch size for the dataset.

        Args:
            temp_dataset (datasets.Dataset): Loaded :data:`datasets.Dataset` object.
            default_batch_size (int): The default batch size to use if not specified in the dataset.

        Returns:
            int: The final batch size to use.
        """
        if 'batch_size' in temp_dataset.column_names:
            return temp_dataset['batch_size'][0]
        if 'type' in temp_dataset.column_names:
            data_type = temp_dataset['type'][0]
            if 'symmetric' in data_type:
                return default_batch_size // 2  # make the symmetric data have a smaller batch size
        return default_batch_size

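    # Illustrative override (not part of the source): a file can pin its own
    # batch size by carrying a constant `batch_size` column, while symmetric
    # data without one falls back to half the default. A made-up record:
    #
    #     {"query": "...", "pos": ["..."], "neg": ["..."], "batch_size": 64}
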
    def refresh_epoch(self):
        """
        Refresh data for the epoch.
        """
        logger.info(f'-- Rank {self.process_index}: refresh data --')
        self.deterministic_generator.shuffle(self.datasets_inxs)

        batch_datas = []
        for dataset_inx in self.datasets_inxs:
            self.deterministic_generator.shuffle(self.each_data_idxs[dataset_inx])
            cur_batch_size = self.batch_size_idxs[dataset_inx] * self.num_processes
            no_in_batch_neg_flag = self.no_in_batch_neg_flags[dataset_inx]
            for start_index in range(0, len(self.each_data_idxs[dataset_inx]), cur_batch_size):
                # drop the last batch if it is smaller than the full batch size
                if len(self.each_data_idxs[dataset_inx]) - start_index < cur_batch_size:
                    break
                batch_datas.append((
                    self.each_data_idxs[dataset_inx][start_index:start_index + cur_batch_size],
                    no_in_batch_neg_flag
                ))
        self.deterministic_generator.shuffle(batch_datas)

        self.batch_datas = batch_datas
        self.step = 0

    def __len__(self):
        return len(self.batch_datas) * self.num_processes

    def __getitem__(self, _):
        batch_indices, no_in_batch_neg_flag = self.batch_datas[self.step]
        # slice out this process's shard of the global batch
        cur_batch_size = int(len(batch_indices) / self.num_processes)
        batch_indices = batch_indices[self.process_index * cur_batch_size: (self.process_index + 1) * cur_batch_size]
        batch_data = self.dataset[batch_indices]
        self.step += 1
        queries, passages, teacher_scores = self._create_batch_data(batch_raw_data=batch_data)
        return queries, passages, teacher_scores, no_in_batch_neg_flag

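    # Construction sketch for a distributed run (illustrative; the batch size
    # and seed are placeholders, `data_args`/`tokenizer` are assumed to exist,
    # and the default process group is assumed to be initialized):
    #
    #     dataset = AbsEmbedderSameDatasetTrainDataset(
    #         args=data_args,
    #         default_batch_size=32,
    #         seed=42,
    #         tokenizer=tokenizer,
    #         process_index=dist.get_rank(),
    #         num_processes=dist.get_world_size(),
    #     )
    #     queries, passages, scores, flag = dataset[0]
    #     # the index is ignored; batches are served in order via `self.step`
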
    def _get_train_group_size(self, batch_raw_data):
        """Get the training group size and data type.

        Args:
            batch_raw_data (datasets.Dataset): One batch of raw data.

        Returns:
            int: The training group size.
            str: The type of data for the task.
        """
        if 'type' in batch_raw_data:
            data_type = batch_raw_data['type'][0]
            if data_type in ['only_1neg']:
                return 2, data_type
            elif data_type in ['symmetric_class']:
                return min(len(batch_raw_data['neg'][0]) + 1, self.args.train_group_size), data_type
            else:
                return self.args.train_group_size, data_type
        return self.args.train_group_size, None

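    # Illustrative effect of the `type` field (made-up configuration): with
    # `train_group_size=8`, a batch tagged "only_1neg" yields groups of 2 (one
    # positive, one negative); "symmetric_class" caps the group at
    # `len(neg) + 1`, bounded by the configured size; any other type keeps the
    # configured size of 8.
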
    def _create_batch_data(self, batch_raw_data):
        """Create a complete batch of data with queries, documents and teacher scores.

        Args:
            batch_raw_data (datasets.Dataset): One batch of raw data.

        Returns:
            List[str]: Queries with instruction format.
            List[str]: Documents with instruction format.
            List[float]: Teacher scores for model distillation.
        """
        queries, passages, teacher_scores = [], [], []

        train_group_size, data_type = self._get_train_group_size(batch_raw_data)

        for i in range(len(batch_raw_data['query'])):
            if data_type is not None:
                assert batch_raw_data['type'][i] == data_type, "Data type is not consistent in the same batch"

            queries.append(
                self.args.query_instruction_format.format(
                    batch_raw_data['prompt'][i] if 'prompt' in batch_raw_data else self.args.query_instruction_for_retrieval,
                    batch_raw_data['query'][i]
                )
            )

            tmp_passages = []
            pos_idx = random.choice(list(range(len(batch_raw_data['pos'][i]))))
            pos = self._shuffle_text(batch_raw_data['pos'][i][pos_idx])
            tmp_passages.append(pos)

            neg_all_idx = list(range(len(batch_raw_data['neg'][i])))
            if len(batch_raw_data['neg'][i]) < train_group_size - 1:
                # oversample the negatives when there are fewer than needed
                num = math.ceil((train_group_size - 1) / len(batch_raw_data['neg'][i]))
                neg_idxs = random.sample(neg_all_idx * num, train_group_size - 1)
            else:
                neg_idxs = random.sample(neg_all_idx, train_group_size - 1)
            for neg_idx in neg_idxs:
                tmp_passages.append(batch_raw_data['neg'][i][neg_idx])

            if self.args.knowledge_distillation:
                if 'pos_scores' in batch_raw_data and batch_raw_data['pos_scores'][i] is not None:
                    teacher_scores.append(batch_raw_data['pos_scores'][i][pos_idx])
                for neg_idx in neg_idxs:
                    if 'neg_scores' in batch_raw_data and batch_raw_data['neg_scores'][i] is not None:
                        teacher_scores.append(batch_raw_data['neg_scores'][i][neg_idx])
            else:
                teacher_scores = None

            if data_type is not None and data_type in ['symmetric_sts', 'symmetric_clustering']:
                tmp_passages = [
                    self.args.query_instruction_format.format(
                        batch_raw_data['prompt'][i] if 'prompt' in batch_raw_data else self.args.query_instruction_for_retrieval,
                        p
                    )
                    for p in tmp_passages
                ]
            else:
                if self.args.passage_instruction_for_retrieval is not None:
                    tmp_passages = [
                        self.args.passage_instruction_format.format(
                            self.args.passage_instruction_for_retrieval, p
                        )
                        for p in tmp_passages
                    ]

            passages.extend(tmp_passages)

        if teacher_scores is not None:
            if len(teacher_scores) > 0 and len(passages) > 0:
                assert len(teacher_scores) == len(passages)

        return queries, passages, teacher_scores

@dataclass
class AbsEmbedderSameDatasetCollator(DataCollatorWithPadding):
    """
    EmbedCollator for SameDataset.
    Note that after using this collator, the training_args should be set as:

    ``training_args.per_device_train_batch_size = 1``

    ``training_args.dataloader_num_workers = 0    # avoid multi-processing``
    """
    query_max_len: int = 32
    passage_max_len: int = 128
    sub_batch_size: int = -1

    def __call__(self, features):
        queries = features[0][0]
        passages = features[0][1]
        teacher_scores = features[0][2]
        no_in_batch_neg_flag = features[0][3]

        queries_inputs = self.tokenizer(
            queries,
            truncation=True,
            max_length=self.query_max_len,
            return_tensors=None
        )
        passages_inputs = self.tokenizer(
            passages,
            truncation=True,
            max_length=self.passage_max_len,
            return_tensors=None
        )

        if self.sub_batch_size is None or self.sub_batch_size <= 0:
            q_collated = self.tokenizer.pad(
                queries_inputs,
                padding=self.padding,
                max_length=self.query_max_len,
                pad_to_multiple_of=self.pad_to_multiple_of,
                return_tensors=self.return_tensors,
            )
            d_collated = self.tokenizer.pad(
                passages_inputs,
                padding=self.padding,
                max_length=self.passage_max_len,
                pad_to_multiple_of=self.pad_to_multiple_of,
                return_tensors=self.return_tensors,
            )
        else:
            batch_size = self.sub_batch_size

            q_collated = []
            for i in range(0, len(queries_inputs['attention_mask']), batch_size):
                start = i
                end = min(len(queries_inputs['attention_mask']), i + batch_size)
                sub_features = {}
                for k, v in queries_inputs.items():
                    sub_features[k] = v[start:end]
                q_collated.append(self.tokenizer.pad(
                    sub_features,
                    padding=self.padding,
                    max_length=self.query_max_len,
                    pad_to_multiple_of=self.pad_to_multiple_of,
                    return_tensors=self.return_tensors,
                ))

            d_collated = []
            for i in range(0, len(passages_inputs['attention_mask']), batch_size):
                start = i
                end = min(len(passages_inputs['attention_mask']), i + batch_size)
                sub_features = {}
                for k, v in passages_inputs.items():
                    sub_features[k] = v[start:end]
                d_collated.append(self.tokenizer.pad(
                    sub_features,
                    padding=self.padding,
                    max_length=self.passage_max_len,
                    pad_to_multiple_of=self.pad_to_multiple_of,
                    return_tensors=self.return_tensors,
                ))

        if isinstance(teacher_scores, list) and len(teacher_scores) == 0:
            teacher_scores = None

        return {
            "queries": q_collated,
            "passages": d_collated,
            "teacher_scores": teacher_scores,
            "no_in_batch_neg_flag": no_in_batch_neg_flag
        }

class EmbedderTrainerCallbackForDataRefresh(TrainerCallback):
    """
    Callback to refresh the training dataset at the end of each epoch.
    """
    def __init__(self, train_dataset: AbsEmbedderSameDatasetTrainDataset):
        self.train_dataset = train_dataset

    def on_epoch_end(
        self,
        args: AbsEmbedderTrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs
    ):
        """
        Event called at the end of an epoch.
        """
        self.train_dataset.refresh_epoch()
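
# Registration sketch (illustrative): attaching the callback so that
# `refresh_epoch` re-shuffles the batch composition between epochs. `trainer`
# is assumed to be a `transformers.Trainer` (or subclass) built elsewhere.
#
#     trainer.add_callback(EmbedderTrainerCallbackForDataRefresh(train_dataset))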