Source code for FlagEmbedding.finetune.reranker.encoder_only.base.runner
import logging
from typing import Tuple
from transformers import (
AutoModelForSequenceClassification, AutoConfig,
AutoTokenizer, PreTrainedTokenizer
)
from FlagEmbedding.abc.finetune.reranker import AbsRerankerRunner, AbsRerankerModel
from FlagEmbedding.finetune.reranker.encoder_only.base.modeling import CrossEncoderModel
from FlagEmbedding.finetune.reranker.encoder_only.base.trainer import EncoderOnlyRerankerTrainer
logger = logging.getLogger(__name__)
[docs]
class EncoderOnlyRerankerRunner(AbsRerankerRunner):
"""
Encoder only reranker runner for finetuning.
"""
[docs]
def load_tokenizer_and_model(self) -> Tuple[PreTrainedTokenizer, AbsRerankerModel]:
"""Load the tokenizer and model.
Returns:
Tuple[PreTrainedTokenizer, AbsEmbedderModel]: Tokenizer and model instances.
"""
tokenizer = AutoTokenizer.from_pretrained(
self.model_args.model_name_or_path,
cache_dir=self.model_args.cache_dir,
token=self.model_args.token,
trust_remote_code=self.model_args.trust_remote_code
)
num_labels = 1
config = AutoConfig.from_pretrained(
self.model_args.config_name if self.model_args.config_name else self.model_args.model_name_or_path,
num_labels=num_labels,
cache_dir=self.model_args.cache_dir,
token=self.model_args.token,
trust_remote_code=self.model_args.trust_remote_code,
)
logger.info('Config: %s', config)
base_model = AutoModelForSequenceClassification.from_pretrained(
self.model_args.model_name_or_path,
config=config,
cache_dir=self.model_args.cache_dir,
token=self.model_args.token,
from_tf=bool(".ckpt" in self.model_args.model_name_or_path),
trust_remote_code=self.model_args.trust_remote_code
)
model = CrossEncoderModel(
base_model,
tokenizer=tokenizer,
train_batch_size=self.training_args.per_device_train_batch_size,
)
if self.training_args.gradient_checkpointing:
model.enable_input_require_grads()
return tokenizer, model
[docs]
def load_trainer(self) -> EncoderOnlyRerankerTrainer:
"""Load the trainer.
Returns:
EncoderOnlyRerankerTrainer: Loaded trainer instance.
"""
trainer = EncoderOnlyRerankerTrainer(
model=self.model,
args=self.training_args,
train_dataset=self.train_dataset,
data_collator=self.data_collator,
tokenizer=self.tokenizer
)
return trainer