Source code for FlagEmbedding.abc.evaluation.runner

import os
import json
import logging
from typing import List, Union, Tuple

from FlagEmbedding import FlagAutoModel, FlagAutoReranker

from .arguments import AbsEvalArgs, AbsEvalModelArgs
from .evaluator import AbsEvaluator
from .searcher import EvalDenseRetriever, EvalReranker
from .data_loader import AbsEvalDataLoader

logger = logging.getLogger(__name__)


class AbsEvalRunner:
    """
    Abstract class of evaluation runner.

    Args:
        eval_args (AbsEvalArgs): :class:`AbsEvalArgs` object with the evaluation arguments.
        model_args (AbsEvalModelArgs): :class:`AbsEvalModelArgs` object with the model arguments.
    """
    def __init__(
        self,
        eval_args: AbsEvalArgs,
        model_args: AbsEvalModelArgs,
    ):
        self.eval_args = eval_args
        self.model_args = model_args

        self.retriever, self.reranker = self.load_retriever_and_reranker()
        self.data_loader = self.load_data_loader()
        self.evaluator = self.load_evaluator()

    @staticmethod
    def get_models(model_args: AbsEvalModelArgs) -> Tuple[FlagAutoModel, Union[FlagAutoReranker, None]]:
        """Get the embedding and reranker models.

        Args:
            model_args (AbsEvalModelArgs): :class:`AbsEvalModelArgs` object with the model arguments.

        Returns:
            Tuple[FlagAutoModel, Union[FlagAutoReranker, None]]: A :class:`FlagAutoModel` embedding model, and a
                :class:`FlagAutoReranker` reranker model if a reranker path is provided.
        """
        embedder = FlagAutoModel.from_finetuned(
            model_name_or_path=model_args.embedder_name_or_path,
            model_class=model_args.embedder_model_class,
            normalize_embeddings=model_args.normalize_embeddings,
            pooling_method=model_args.pooling_method,
            use_fp16=model_args.use_fp16,
            query_instruction_for_retrieval=model_args.query_instruction_for_retrieval,
            query_instruction_format=model_args.query_instruction_format_for_retrieval,
            devices=model_args.devices,
            examples_for_task=model_args.examples_for_task,
            examples_instruction_format=model_args.examples_instruction_format,
            trust_remote_code=model_args.trust_remote_code,
            cache_dir=model_args.cache_dir,
            batch_size=model_args.embedder_batch_size,
            query_max_length=model_args.embedder_query_max_length,
            passage_max_length=model_args.embedder_passage_max_length,
        )
        embedder.model.config._name_or_path = model_args.embedder_name_or_path

        reranker = None
        if model_args.reranker_name_or_path is not None:
            reranker = FlagAutoReranker.from_finetuned(
                model_name_or_path=model_args.reranker_name_or_path,
                model_class=model_args.reranker_model_class,
                peft_path=model_args.reranker_peft_path,
                use_fp16=model_args.use_fp16,
                use_bf16=model_args.use_bf16,
                query_instruction_for_rerank=model_args.query_instruction_for_rerank,
                query_instruction_format=model_args.query_instruction_format_for_rerank,
                passage_instruction_for_rerank=model_args.passage_instruction_for_rerank,
                passage_instruction_format=model_args.passage_instruction_format_for_rerank,
                cache_dir=model_args.cache_dir,
                trust_remote_code=model_args.trust_remote_code,
                devices=model_args.devices,
                normalize=model_args.normalize,
                prompt=model_args.prompt,
                cutoff_layers=model_args.cutoff_layers,
                compress_layers=model_args.compress_layers,
                compress_ratio=model_args.compress_ratio,
                batch_size=model_args.reranker_batch_size,
                query_max_length=model_args.reranker_query_max_length,
                max_length=model_args.reranker_max_length,
            )
            reranker.model.config._name_or_path = model_args.reranker_name_or_path
        return embedder, reranker

    def load_retriever_and_reranker(self) -> Tuple[EvalDenseRetriever, Union[EvalReranker, None]]:
        """Load the retriever and reranker for evaluation.

        Returns:
            Tuple[EvalDenseRetriever, Union[EvalReranker, None]]: A :class:`EvalDenseRetriever` object for retrieval,
                and an :class:`EvalReranker` object if a reranker is provided.
""" embedder, reranker = self.get_models(self.model_args) retriever = EvalDenseRetriever( embedder, search_top_k=self.eval_args.search_top_k, overwrite=self.eval_args.overwrite ) if reranker is not None: reranker = EvalReranker(reranker, rerank_top_k=self.eval_args.rerank_top_k) return retriever, reranker def load_data_loader(self) -> AbsEvalDataLoader: """Load the data loader Returns: AbsEvalDataLoader: Data loader object for that specific task. """ data_loader = AbsEvalDataLoader( eval_name=self.eval_args.eval_name, dataset_dir=self.eval_args.dataset_dir, cache_dir=self.eval_args.cache_path, token=self.eval_args.token, force_redownload=self.eval_args.force_redownload, ) return data_loader def load_evaluator(self) -> AbsEvaluator: """Load the evaluator for evaluation Returns: AbsEvaluator: the evaluator to run the evaluation. """ evaluator = AbsEvaluator( eval_name=self.eval_args.eval_name, data_loader=self.data_loader, overwrite=self.eval_args.overwrite, ) return evaluator @staticmethod def evaluate_metrics( search_results_save_dir: str, output_method: str = "markdown", output_path: str = "./eval_dev_results.md", metrics: Union[str, List[str]] = ["ndcg_at_10", "recall_at_10"] ): """Evaluate the provided metrics and write the results. Args: search_results_save_dir (str): Path to save the search results. output_method (str, optional): Output results to `json` or `markdown`. Defaults to :data:`"markdown"`. output_path (str, optional): Path to write the output. Defaults to :data:`"./eval_dev_results.md"`. metrics (Union[str, List[str]], optional): metrics to use. Defaults to :data:`["ndcg_at_10", "recall_at_10"]`. Raises: FileNotFoundError: Eval results not found ValueError: Invalid output method """ eval_results_dict = {} for model_name in sorted(os.listdir(search_results_save_dir)): model_search_results_save_dir = os.path.join(search_results_save_dir, model_name) if not os.path.isdir(model_search_results_save_dir): continue for reranker_name in sorted(os.listdir(model_search_results_save_dir)): reranker_search_results_save_dir = os.path.join(model_search_results_save_dir, reranker_name) if not os.path.isdir(reranker_search_results_save_dir): continue eval_results_path = os.path.join(reranker_search_results_save_dir, 'EVAL', "eval_results.json") if os.path.exists(eval_results_path): eval_results = json.load(open(eval_results_path, encoding='utf-8')) else: raise FileNotFoundError(f"Eval results not found: {eval_results_path}") if model_name not in eval_results_dict: eval_results_dict[model_name] = {} eval_results_dict[model_name][reranker_name] = eval_results if output_method == "json": AbsEvaluator.output_eval_results_to_json(eval_results_dict, output_path) elif output_method == "markdown": AbsEvaluator.output_eval_results_to_markdown(eval_results_dict, output_path, metrics) else: raise ValueError(f"Invalid output method: {output_method}. Available methods: ['json', 'markdown']") def run(self): """ Run the whole evaluation. 
""" if self.eval_args.dataset_names is None: dataset_names = self.data_loader.available_dataset_names() else: dataset_names = self.data_loader.check_dataset_names(self.eval_args.dataset_names) if len(dataset_names) == 0: logger.info(f"Running {self.eval_args.eval_name} evaluation on the default dataset.") self.evaluator( splits=self.eval_args.splits, search_results_save_dir=self.eval_args.output_dir, retriever=self.retriever, reranker=self.reranker, corpus_embd_save_dir=self.eval_args.corpus_embd_save_dir, ignore_identical_ids=self.eval_args.ignore_identical_ids, k_values=self.eval_args.k_values ) logger.info(f"{self.eval_args.eval_name} evaluation completed.") else: logger.info(f"Running {self.eval_args.eval_name} evaluation on the following dataset names: {dataset_names}") for dataset_name in dataset_names: logger.info(f"Running {self.eval_args.eval_name} evaluation on: {dataset_name}") self.evaluator( splits=self.eval_args.splits, search_results_save_dir=self.eval_args.output_dir, retriever=self.retriever, reranker=self.reranker, corpus_embd_save_dir=self.eval_args.corpus_embd_save_dir, ignore_identical_ids=self.eval_args.ignore_identical_ids, k_values=self.eval_args.k_values, dataset_name=dataset_name, ) logger.info(f"{self.eval_args.eval_name} evaluation on {dataset_names} completed.") logger.info("Start computing metrics.") self.evaluate_metrics( search_results_save_dir=self.eval_args.output_dir, output_method=self.eval_args.eval_output_method, output_path=self.eval_args.eval_output_path, metrics=self.eval_args.eval_metrics )