Source code for FlagEmbedding.abc.finetune.embedder.AbsRunner

import os
import logging
from pathlib import Path
from typing import Tuple
from abc import ABC, abstractmethod
from transformers import set_seed, PreTrainedTokenizer


from .AbsArguments import (
    AbsEmbedderModelArguments,
    AbsEmbedderDataArguments,
    AbsEmbedderTrainingArguments
)
from .AbsTrainer import AbsEmbedderTrainer
from .AbsModeling import AbsEmbedderModel
from .AbsDataset import (
    AbsEmbedderTrainDataset, AbsEmbedderCollator,
    AbsEmbedderSameDatasetTrainDataset, AbsEmbedderSameDatasetCollator
)

logger = logging.getLogger(__name__)


class AbsEmbedderRunner(ABC):
    """Abstract class to run embedding model fine-tuning.

    Args:
        model_args (AbsEmbedderModelArguments): Model arguments.
        data_args (AbsEmbedderDataArguments): Data arguments.
        training_args (AbsEmbedderTrainingArguments): Training arguments.
    """
    def __init__(
        self,
        model_args: AbsEmbedderModelArguments,
        data_args: AbsEmbedderDataArguments,
        training_args: AbsEmbedderTrainingArguments
    ):
        self.model_args = model_args
        self.data_args = data_args
        self.training_args = training_args

        if (
            os.path.exists(training_args.output_dir)
            and os.listdir(training_args.output_dir)
            and training_args.do_train
            and not training_args.overwrite_output_dir
        ):
            raise ValueError(
                f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome."
            )

        # Setup logging
        logging.basicConfig(
            format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
            datefmt="%m/%d/%Y %H:%M:%S",
            level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN,
        )
        logger.warning(
            "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
            training_args.local_rank,
            training_args.device,
            training_args.n_gpu,
            bool(training_args.local_rank != -1),
            training_args.fp16,
        )
        logger.info("Training/evaluation parameters %s", training_args)
        logger.info("Model parameters %s", model_args)
        logger.info("Data parameters %s", data_args)

        # Set seed
        set_seed(training_args.seed)

        self.tokenizer, self.model = self.load_tokenizer_and_model()
        self.train_dataset = self.load_train_dataset()
        self.data_collator = self.load_data_collator()
        self.trainer = self.load_trainer()

    @abstractmethod
    def load_tokenizer_and_model(self) -> Tuple[PreTrainedTokenizer, AbsEmbedderModel]:
        """Abstract method to load the tokenizer and model.

        Returns:
            Tuple[PreTrainedTokenizer, AbsEmbedderModel]: Loaded tokenizer and model instances.
        """
        pass

    @abstractmethod
    def load_trainer(self) -> AbsEmbedderTrainer:
        """Abstract method to load the trainer.

        Returns:
            AbsEmbedderTrainer: The loaded trainer instance.
        """
        pass

    def load_train_dataset(self) -> AbsEmbedderTrainDataset:
        """Loads the training dataset based on data arguments.

        Returns:
            AbsEmbedderTrainDataset: The loaded dataset instance.
        """
        if self.data_args.same_dataset_within_batch:
            train_dataset = AbsEmbedderSameDatasetTrainDataset(
                args=self.data_args,
                default_batch_size=self.training_args.per_device_train_batch_size,
                seed=self.training_args.seed,
                tokenizer=self.tokenizer,
                process_index=self.training_args.process_index,
                num_processes=self.training_args.world_size
            )
            self.training_args.per_device_train_batch_size = 1
            self.training_args.dataloader_num_workers = 0    # avoid multi-processing
        else:
            train_dataset = AbsEmbedderTrainDataset(
                args=self.data_args,
                tokenizer=self.tokenizer
            )
        return train_dataset

    def load_data_collator(self) -> AbsEmbedderCollator:
        """Loads the appropriate data collator.

        Returns:
            AbsEmbedderCollator: Loaded data collator.
        """
        if self.data_args.same_dataset_within_batch:
            EmbedCollator = AbsEmbedderSameDatasetCollator
        else:
            EmbedCollator = AbsEmbedderCollator

        data_collator = EmbedCollator(
            tokenizer=self.tokenizer,
            query_max_len=self.data_args.query_max_len,
            passage_max_len=self.data_args.passage_max_len,
            sub_batch_size=self.training_args.sub_batch_size,
            pad_to_multiple_of=self.data_args.pad_to_multiple_of,
            padding=True,
            return_tensors="pt"
        )
        return data_collator

    def run(self):
        """Executes the training process."""
        Path(self.training_args.output_dir).mkdir(parents=True, exist_ok=True)

        # Training
        self.trainer.train(resume_from_checkpoint=self.training_args.resume_from_checkpoint)
        self.trainer.save_model()
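
A concrete runner only has to implement the two abstract hooks; everything else (dataset, collator, trainer wiring, `run()`) is inherited. The sketch below is illustrative and not part of this module: `MyEmbedderModel`, the `model_args.model_name_or_path` field, and the exact keyword arguments accepted by `AbsEmbedderTrainer` are assumptions, and the concrete runners shipped with FlagEmbedding may differ.

# Hypothetical subclass sketch (assumptions noted above), not library source.
from transformers import AutoModel, AutoTokenizer


class MyEmbedderRunner(AbsEmbedderRunner):
    def load_tokenizer_and_model(self) -> Tuple[PreTrainedTokenizer, AbsEmbedderModel]:
        # Assumes model_args exposes a `model_name_or_path` attribute.
        tokenizer = AutoTokenizer.from_pretrained(self.model_args.model_name_or_path)
        base_model = AutoModel.from_pretrained(self.model_args.model_name_or_path)
        # MyEmbedderModel is a placeholder for an AbsEmbedderModel subclass.
        model = MyEmbedderModel(base_model, tokenizer=tokenizer)
        return tokenizer, model

    def load_trainer(self) -> AbsEmbedderTrainer:
        # Assumes AbsEmbedderTrainer accepts the standard HF Trainer keyword arguments.
        return AbsEmbedderTrainer(
            model=self.model,
            args=self.training_args,
            train_dataset=self.train_dataset,
            data_collator=self.data_collator,
            tokenizer=self.tokenizer,
        )

With such a subclass, fine-tuning reduces to constructing the runner with the three argument objects and calling `MyEmbedderRunner(model_args, data_args, training_args).run()`.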