# Source code for FlagEmbedding.finetune.embedder.decoder_only.base.runner

import logging
from typing import Tuple
from pathlib import Path
from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizer

from FlagEmbedding.abc.finetune.embedder.AbsArguments import AbsEmbedderDataArguments, AbsEmbedderTrainingArguments
from FlagEmbedding.abc.finetune.embedder import AbsEmbedderRunner, AbsEmbedderModel, EmbedderTrainerCallbackForDataRefresh

from .arguments import DecoderOnlyEmbedderModelArguments
from .trainer import DecoderOnlyEmbedderTrainer
from .modeling import BiDecoderOnlyEmbedderModel
from .load_model import get_model, save_merged_model

logger = logging.getLogger(__name__)


class DecoderOnlyEmbedderRunner(AbsEmbedderRunner):
    """Runner class for decoder only embedding model.

    Args:
        model_args (DecoderOnlyEmbedderModelArguments): Model arguments instance.
        data_args (AbsEmbedderDataArguments): Data arguments instance.
        training_args (AbsEmbedderTrainingArguments): Trainer arguments.
    """

    def __init__(
        self,
        model_args: DecoderOnlyEmbedderModelArguments,
        data_args: AbsEmbedderDataArguments,
        training_args: AbsEmbedderTrainingArguments
    ):
        # All setup is delegated to the base runner; this override only narrows
        # the annotated type of ``model_args`` to the decoder-only variant.
        super().__init__(model_args, data_args, training_args)

    def load_tokenizer_and_model(self) -> Tuple[PreTrainedTokenizer, AbsEmbedderModel]:
        """Load tokenizer and model.

        The tokenizer is given a pad token if it lacks one (``unk`` is preferred,
        ``eos`` is the fallback), is switched to left padding, and is optionally
        extended with ``model_args.additional_special_tokens``. The base model is
        obtained from ``get_model`` and wrapped in ``BiDecoderOnlyEmbedderModel``.

        Returns:
            Tuple[PreTrainedTokenizer, AbsEmbedderModel]: Tokenizer and model instances.
        """
        tokenizer = AutoTokenizer.from_pretrained(
            self.model_args.tokenizer_name if self.model_args.tokenizer_name else self.model_args.model_name_or_path,
            token=self.model_args.token,
            cache_dir=self.model_args.cache_dir,
            use_fast=False,
            add_eos_token=True
        )

        # Guarantee a pad token: reuse ``unk`` when available, otherwise ``eos``.
        if tokenizer.pad_token is None:
            if tokenizer.unk_token is not None:
                tokenizer.pad_token = tokenizer.unk_token
                tokenizer.pad_token_id = tokenizer.unk_token_id
            else:
                tokenizer.pad_token = tokenizer.eos_token
                tokenizer.pad_token_id = tokenizer.eos_token_id
        # NOTE(review): left padding presumably keeps the last position a real
        # token for last-token pooling -- confirm against the modeling code.
        tokenizer.padding_side = 'left'

        # ``resize`` is forwarded to ``get_model`` together with the new
        # vocabulary size; it is only set when tokens were actually added.
        resize = False
        if self.model_args.additional_special_tokens is not None:
            special_tokens_dict = {'additional_special_tokens': self.model_args.additional_special_tokens}
            add_num = tokenizer.add_special_tokens(special_tokens_dict)
            if add_num > 0:
                resize = True
                logger.info(f"Add {add_num} special tokens to the tokenizer. Special tokens: {self.model_args.additional_special_tokens}")
            else:
                logger.warning(f"Special tokens {self.model_args.additional_special_tokens} already exists in the tokenizer.")
        base_model = get_model(self.model_args, self.training_args.output_dir, resize, len(tokenizer))

        # The config is loaded only to be logged; it is not passed to the model
        # constructed below.
        num_labels = 1
        config = AutoConfig.from_pretrained(
            self.model_args.config_name if self.model_args.config_name else self.model_args.model_name_or_path,
            num_labels=num_labels,
            cache_dir=self.model_args.cache_dir,
            token=self.model_args.token,
            trust_remote_code=self.model_args.trust_remote_code,
        )
        logger.info('Config: %s', config)

        model = BiDecoderOnlyEmbedderModel(
            base_model,
            tokenizer=tokenizer,
            negatives_cross_device=self.training_args.negatives_cross_device,
            temperature=self.training_args.temperature,
            sub_batch_size=self.training_args.sub_batch_size,
            kd_loss_type=self.training_args.kd_loss_type,
            sentence_pooling_method=self.training_args.sentence_pooling_method,
            normalize_embeddings=self.training_args.normalize_embeddings
        )

        if self.training_args.gradient_checkpointing:
            model.enable_input_require_grads()

        # Freeze every parameter whose name contains "position_embeddings".
        if self.training_args.fix_position_embedding:
            for k, v in model.named_parameters():
                if "position_embeddings" in k:
                    # Use the module logger (not the root logger) so this message
                    # is routed like every other message emitted by this class.
                    logger.info(f"Freeze the parameters for {k}")
                    v.requires_grad = False
        return tokenizer, model

    def load_trainer(self) -> DecoderOnlyEmbedderTrainer:
        """Load the trainer.

        When ``data_args.same_dataset_within_batch`` is set, an
        ``EmbedderTrainerCallbackForDataRefresh`` bound to the training dataset
        is registered on the trainer.

        Returns:
            DecoderOnlyEmbedderTrainer: Loaded trainer instance.
        """
        trainer = DecoderOnlyEmbedderTrainer(
            model=self.model,
            args=self.training_args,
            train_dataset=self.train_dataset,
            data_collator=self.data_collator,
            tokenizer=self.tokenizer
        )
        if self.data_args.same_dataset_within_batch:
            trainer.add_callback(EmbedderTrainerCallbackForDataRefresh(self.train_dataset))
        return trainer

    def run(self):
        """
        Run the finetune.

        Creates the output directory, trains (optionally resuming from
        ``training_args.resume_from_checkpoint``), saves the model, and, if
        ``model_args.save_merged_lora_model`` is set, saves the merged model
        from the process whose ``process_index`` is 0.
        """
        Path(self.training_args.output_dir).mkdir(parents=True, exist_ok=True)

        # Training
        self.trainer.train(resume_from_checkpoint=self.training_args.resume_from_checkpoint)
        self.trainer.save_model()

        # save merged model (only on process index 0)
        if self.model_args.save_merged_lora_model and self.training_args.process_index == 0:
            save_merged_model(self.model_args, self.training_args.output_dir)