# Source code for FlagEmbedding.abc.finetune.reranker.AbsRunner

import os
import logging
from pathlib import Path
from typing import Tuple
from abc import ABC, abstractmethod
from transformers import set_seed, PreTrainedTokenizer


from .AbsArguments import (
    AbsRerankerModelArguments,
    AbsRerankerDataArguments,
    AbsRerankerTrainingArguments
)
from .AbsTrainer import AbsRerankerTrainer
from .AbsModeling import AbsRerankerModel
from .AbsDataset import (
    AbsRerankerTrainDataset, AbsRerankerCollator,
    AbsLLMRerankerTrainDataset, AbsLLMRerankerCollator
)

# Module-level logger, named after this module so records can be filtered by origin.
logger = logging.getLogger(__name__)


class AbsRerankerRunner(ABC):
    """Base runner that assembles the pieces of a reranker fine-tuning job.

    Construction stores the three argument objects, refuses to continue when
    training would write into an existing, non-empty output directory without
    ``overwrite_output_dir``, configures logging, calls ``set_seed`` and then
    builds, in this order, the tokenizer/model pair, the training dataset, the
    data collator and the trainer. Subclasses must implement
    :meth:`load_tokenizer_and_model` and :meth:`load_trainer`.

    Args:
        model_args (AbsRerankerModelArguments): Model arguments.
        data_args (AbsRerankerDataArguments): Data arguments.
        training_args (AbsRerankerTrainingArguments): Training arguments.

    Raises:
        ValueError: If ``training_args.output_dir`` exists and is not empty
            while ``do_train`` is set and ``overwrite_output_dir`` is not.
    """

    def __init__(
        self,
        model_args: AbsRerankerModelArguments,
        data_args: AbsRerankerDataArguments,
        training_args: AbsRerankerTrainingArguments
    ):
        self.model_args = model_args
        self.data_args = data_args
        self.training_args = training_args

        # Do not silently write on top of the contents of an earlier run.
        output_dir_in_use = (
            os.path.exists(training_args.output_dir)
            and os.listdir(training_args.output_dir)
        )
        if output_dir_in_use and training_args.do_train and not training_args.overwrite_output_dir:
            raise ValueError(
                f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome."
            )

        # Ranks -1 and 0 log at INFO; every other rank only logs warnings and above.
        if training_args.local_rank in [-1, 0]:
            log_level = logging.INFO
        else:
            log_level = logging.WARNING
        logging.basicConfig(
            format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
            datefmt="%m/%d/%Y %H:%M:%S",
            level=log_level,
        )

        # Emitted at WARNING so that it is visible on every rank.
        logger.warning(
            "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
            training_args.local_rank,
            training_args.device,
            training_args.n_gpu,
            bool(training_args.local_rank != -1),
            training_args.fp16,
        )
        for message, arguments in (
            ("Training/evaluation parameters %s", training_args),
            ("Model parameters %s", model_args),
            ("Data parameters %s", data_args),
        ):
            logger.info(message, arguments)

        set_seed(training_args.seed)

        # Build order matters: the dataset and the collator both read self.tokenizer.
        tokenizer, model = self.load_tokenizer_and_model()
        self.tokenizer = tokenizer
        self.model = model
        self.train_dataset = self.load_train_dataset()
        self.data_collator = self.load_data_collator()
        self.trainer = self.load_trainer()

    @abstractmethod
    def load_tokenizer_and_model(self) -> Tuple[PreTrainedTokenizer, AbsRerankerModel]:
        """Create the tokenizer and the model; must be provided by subclasses.

        Returns:
            Tuple[PreTrainedTokenizer, AbsRerankerModel]: The tokenizer and the
            model, in that order.
        """

    @abstractmethod
    def load_trainer(self) -> AbsRerankerTrainer:
        """Create the trainer; must be provided by subclasses.

        Returns:
            AbsRerankerTrainer: The trainer instance.
        """

    def load_train_dataset(self) -> AbsRerankerTrainDataset:
        """Build the training dataset matching ``model_args.model_type``.

        ``'encoder'`` selects :class:`AbsRerankerTrainDataset`; any other value
        selects :class:`AbsLLMRerankerTrainDataset`.

        Returns:
            AbsRerankerTrainDataset: The dataset instance.
        """
        uses_encoder = self.model_args.model_type == 'encoder'
        dataset_cls = AbsRerankerTrainDataset if uses_encoder else AbsLLMRerankerTrainDataset
        return dataset_cls(args=self.data_args, tokenizer=self.tokenizer)

    def load_data_collator(self) -> AbsRerankerCollator:
        """Build the data collator matching ``model_args.model_type``.

        ``'encoder'`` selects :class:`AbsRerankerCollator`; any other value
        selects :class:`AbsLLMRerankerCollator`. Length limits and the padding
        multiple are taken from ``data_args``.

        Returns:
            AbsRerankerCollator: The collator instance.
        """
        uses_encoder = self.model_args.model_type == 'encoder'
        collator_cls = AbsRerankerCollator if uses_encoder else AbsLLMRerankerCollator
        return collator_cls(
            tokenizer=self.tokenizer,
            query_max_len=self.data_args.query_max_len,
            passage_max_len=self.data_args.passage_max_len,
            pad_to_multiple_of=self.data_args.pad_to_multiple_of,
            padding=True,
            return_tensors="pt"
        )

    def run(self):
        """Create the output directory, train, then save the model."""
        output_dir = Path(self.training_args.output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

        checkpoint = self.training_args.resume_from_checkpoint
        self.trainer.train(resume_from_checkpoint=checkpoint)
        self.trainer.save_model()