Source code for FlagEmbedding.finetune.embedder.encoder_only.m3.trainer

import os
import torch
import logging
from typing import Optional

from FlagEmbedding.abc.finetune.embedder import AbsEmbedderTrainer

logger = logging.getLogger(__name__)


class EncoderOnlyEmbedderM3Trainer(AbsEmbedderTrainer):
    """
    Trainer class for M3.
    """
    def _save(self, output_dir: Optional[str] = None, state_dict=None):
        """Save the model to directory.

        Args:
            output_dir (Optional[str], optional): Output directory to save the model. Defaults to ``None``.

        Raises:
            NotImplementedError: If the model does not expose a ``save`` interface.
        """
        output_dir = output_dir if output_dir is not None else self.args.output_dir
        os.makedirs(output_dir, exist_ok=True)
        logger.info("Saving model checkpoint to %s", output_dir)
        # Save a trained model and configuration using `save_pretrained()`.
        # They can then be reloaded using `from_pretrained()`.
        if not hasattr(self.model, 'save'):
            raise NotImplementedError(
                f'MODEL {self.model.__class__.__name__} '
                f'does not support save interface')
        else:
            self.model.save(output_dir)

        if self.tokenizer is not None and self.is_world_process_zero():
            self.tokenizer.save_pretrained(output_dir)

        torch.save(self.args, os.path.join(output_dir, "training_args.bin"))
        # Save the checkpoint for the sentence-transformers library.
        # if self.is_world_process_zero():
        #     save_ckpt_for_sentence_transformers(output_dir,
        #                                         pooling_mode=self.args.sentence_pooling_method,
        #                                         normlized=self.args.normlized)
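
Note that ``_save`` delegates checkpointing to the wrapped model: it requires ``self.model`` to expose a ``save(output_dir)`` method and raises ``NotImplementedError`` otherwise. The following minimal, self-contained sketch illustrates a model that satisfies this interface; the ``ToyEmbedderModel`` class and the file name it writes are hypothetical examples, not part of FlagEmbedding.

import os
import torch


class ToyEmbedderModel(torch.nn.Module):
    """Hypothetical model exposing the `save(output_dir)` interface the trainer expects."""

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def save(self, output_dir: str):
        # Persist the model weights into the target directory,
        # mirroring how the trainer delegates saving to the model.
        os.makedirs(output_dir, exist_ok=True)
        torch.save(self.state_dict(), os.path.join(output_dir, "pytorch_model.bin"))


# Usage: a model with a `save` method can be checkpointed by the trainer's `_save`.
model = ToyEmbedderModel()
model.save("./toy_checkpoint")  # writes ./toy_checkpoint/pytorch_model.bin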