AbsDataset#

AbsEmbedderTrainDataset#

class FlagEmbedding.abc.finetune.embedder.AbsEmbedderTrainDataset(args: AbsEmbedderDataArguments, tokenizer: PreTrainedTokenizer)[source]#

Abstract class for training dataset.

Parameters:
  • args (AbsEmbedderDataArguments) – Data arguments.

  • tokenizer (PreTrainedTokenizer) – Tokenizer to use.

Methods#

AbsEmbedderTrainDataset._load_dataset(file_path: str)[source]#

Load dataset from path.

Parameters:

file_path (str) – Path to load the datasets from.

Raises:

ValueError – pos_scores and neg_scores not found in the features of training data.

Returns:

Loaded HF dataset.

Return type:

datasets.Dataset

AbsEmbedderTrainDataset._shuffle_text(text)[source]#

Shuffle the input text.

Parameters:

text (str) – Input text.

Returns:

Shuffled text.

Return type:

str
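A plausible sketch of chunk-level text shuffling: split the string into fixed-size chunks, permute them, and re-join. The chunk size and seeding here are assumptions for illustration, not the library's exact behaviour:

```python
import random

def shuffle_text(text, chunk_size=3, seed=None):
    """Illustrative: split `text` into fixed-size character chunks,
    shuffle the chunks, and re-join. The result is a permutation of
    the original characters at chunk granularity."""
    rng = random.Random(seed)
    chunks = [text[i:i + chunk_size] for i in range(0, len(text), chunk_size)]
    rng.shuffle(chunks)
    return "".join(chunks)
```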

AbsEmbedderCollator#

class FlagEmbedding.abc.finetune.embedder.AbsEmbedderCollator(tokenizer: PreTrainedTokenizerBase, padding: bool | str | PaddingStrategy = True, max_length: int | None = None, pad_to_multiple_of: int | None = None, return_tensors: str = 'pt', query_max_len: int = 32, passage_max_len: int = 128, sub_batch_size: int = -1)[source]#

The abstract embedder collator.
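The collator pads queries and passages separately, up to query_max_len and passage_max_len respectively. A dependency-free sketch of that padding step (the real collator delegates to the tokenizer; the pad id and batch layout here are assumptions):

```python
def pad_batch(token_id_lists, max_length, pad_id=0):
    """Truncate each sequence to max_length and right-pad with pad_id,
    returning padded input ids and the matching attention mask."""
    ids, mask = [], []
    for seq in token_id_lists:
        seq = seq[:max_length]
        pad = max_length - len(seq)
        ids.append(seq + [pad_id] * pad)
        mask.append([1] * len(seq) + [0] * pad)
    return ids, mask
```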

AbsEmbedderSameDatasetTrainDataset#

class FlagEmbedding.abc.finetune.embedder.AbsEmbedderSameDatasetTrainDataset(args: AbsEmbedderDataArguments, default_batch_size: int, seed: int, tokenizer: PreTrainedTokenizer, process_index: int = 0, num_processes: int = 1)[source]#

Abstract class for training dataset that samples batches from same dataset.

Parameters:
  • args (AbsEmbedderDataArguments) – Data arguments.

  • default_batch_size (int) – The default batch size for training.

  • seed (int) – Random seed.

  • tokenizer (PreTrainedTokenizer) – Tokenizer to use.

  • process_index (int, optional) – Current process index. Defaults to 0.

  • num_processes (int, optional) – Total number of processes. Defaults to 1.

Methods#

AbsEmbedderSameDatasetTrainDataset.refresh_epoch()[source]#

Refresh data for epoch.

AbsEmbedderSameDatasetTrainDataset._load_dataset(file_path: str)[source]#

Load dataset from the given path.

Parameters:

file_path (str) – The local path to load from, or a dataset name to download from the HF hub.

Returns:

The loaded dataset.

Return type:

datasets.Dataset

static AbsEmbedderSameDatasetTrainDataset._get_file_batch_size(temp_dataset: Dataset, default_batch_size: int)[source]#

Get the appropriate batch size for the dataset.

Parameters:
  • temp_dataset (datasets.Dataset) – Loaded datasets.Dataset object.

  • default_batch_size (int) – The default batch size to use if not specified in the dataset.

Returns:

The final batch size to use.

Return type:

int

AbsEmbedderSameDatasetTrainDataset._get_train_group_size(batch_raw_data)[source]#

Get the training group size and data type.

Parameters:

batch_raw_data (datasets.Dataset) – One batch of raw data.

Returns:

int – The training group size. str – The type of data for the task.

Return type:

Tuple[int, str]
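A sketch of one plausible way to pick the group size: cap it at one positive plus the smallest number of negatives available in the batch. The capping rule and field name are assumptions for illustration:

```python
def get_train_group_size(batch_examples, default_group_size=8):
    """Illustrative: the group holds 1 positive plus (group_size - 1)
    negatives, so it cannot exceed 1 + the fewest negatives any
    example in the batch provides."""
    min_negs = min(len(ex["neg"]) for ex in batch_examples)
    return min(default_group_size, min_negs + 1)
```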

AbsEmbedderSameDatasetTrainDataset._create_batch_data(batch_raw_data)[source]#

Create a complete batch of data with queries, documents and teacher scores.

Parameters:

batch_raw_data (datasets.Dataset) – One batch of raw data.

Returns:

List[str] – Queries with instruction format. List[str] – Documents with instruction format. List[float] – Teacher scores for model distillation.

Return type:

Tuple[List[str], List[str], List[float]]
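A sketch of assembling one batch: for each example, keep the query, sample one positive and (group_size - 1) negatives, and collect teacher scores when present. Field names follow the conventions above, but the sampling logic is illustrative:

```python
import random

def create_batch_data(batch_raw, train_group_size=2, seed=None):
    """Illustrative batch assembly: queries, a flat document list of
    one positive plus negatives per query, and optional teacher scores."""
    rng = random.Random(seed)
    queries, documents, teacher_scores = [], [], []
    for example in batch_raw:
        queries.append(example["query"])
        pos = rng.choice(example["pos"])
        negs = rng.sample(example["neg"], k=train_group_size - 1)
        documents.extend([pos] + negs)
        if "pos_scores" in example:
            teacher_scores.extend(example["pos_scores"])
    return queries, documents, teacher_scores
```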

AbsEmbedderSameDatasetCollator#

class FlagEmbedding.abc.finetune.embedder.AbsEmbedderSameDatasetCollator(tokenizer: PreTrainedTokenizerBase, padding: bool | str | PaddingStrategy = True, max_length: int | None = None, pad_to_multiple_of: int | None = None, return_tensors: str = 'pt', query_max_len: int = 32, passage_max_len: int = 128, sub_batch_size: int = -1)[source]#

EmbedCollator for SameDataset. Note that after using this collator, the training_args should be set as:

training_args.per_device_train_batch_size = 1

training_args.dataloader_num_workers = 0    # avoid multi-processing

EmbedderTrainerCallbackForDataRefresh#

class FlagEmbedding.abc.finetune.embedder.EmbedderTrainerCallbackForDataRefresh(train_dataset: AbsEmbedderSameDatasetTrainDataset)[source]#

Callback class to inspect the state of the training loop and take decisions, here refreshing the training dataset at the end of each epoch.

Methods#

EmbedderTrainerCallbackForDataRefresh.on_epoch_end(args: AbsEmbedderTrainingArguments, state: TrainerState, control: TrainerControl, **kwargs)[source]#

Event called at the end of an epoch.
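A minimal sketch of what this callback does: at the end of each epoch it calls refresh_epoch() on the dataset so batches are re-sampled. The lightweight stand-in classes below are for illustration only; the real callback subclasses transformers.TrainerCallback:

```python
class DataRefreshCallback:
    """Illustrative stand-in for EmbedderTrainerCallbackForDataRefresh."""

    def __init__(self, train_dataset):
        self.train_dataset = train_dataset

    def on_epoch_end(self, args, state, control, **kwargs):
        # Ask the dataset to re-shuffle / re-sample its batches.
        self.train_dataset.refresh_epoch()
        return control
```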