# Source code for FlagEmbedding.evaluation.air_bench.runner

from typing import Union, Tuple
from air_benchmark import AIRBench

from FlagEmbedding.abc.evaluation import (
    AbsEvalRunner,
    EvalDenseRetriever, EvalReranker
)

from .arguments import AIRBenchEvalArgs, AIRBenchEvalModelArgs


class AIRBenchEvalRunner:
    """Evaluation runner for AIR-Bench.

    A dense retriever (and, when a reranker model is configured, a reranker)
    is built from the model arguments at construction time. Calling
    :meth:`run` hands them to ``AIRBench`` together with the benchmark
    selection taken from the evaluation arguments.

    Args:
        eval_args (AIRBenchEvalArgs): :class:`AIRBenchEvalArgs` object with the
            evaluation arguments.
        model_args (AIRBenchEvalModelArgs): :class:`AIRBenchEvalModelArgs`
            object with the model arguments.
    """

    def __init__(
        self,
        eval_args: AIRBenchEvalArgs,
        model_args: AIRBenchEvalModelArgs,
    ):
        self.eval_args = eval_args
        self.model_args = model_args
        # Mirror ``model_cache_dir`` under the ``cache_dir`` attribute name.
        # This is presumably the name the shared model loader
        # (AbsEvalRunner.get_models) reads; TODO confirm.
        # NOTE: this mutates the caller's ``model_args`` object in place.
        model_args.cache_dir = model_args.model_cache_dir

        retriever, reranker = self.load_retriever_and_reranker()
        self.retriever = retriever
        self.reranker = reranker

    def load_retriever_and_reranker(self) -> Tuple[EvalDenseRetriever, Union[EvalReranker, None]]:
        """Load the retriever and the optional reranker used for evaluation.

        Returns:
            Tuple[EvalDenseRetriever, Union[EvalReranker, None]]: An
            :class:`EvalDenseRetriever` wrapping the loaded embedder, and an
            :class:`EvalReranker` wrapping the loaded reranker when one was
            provided (``None`` otherwise).
        """
        args = self.eval_args
        embedder, raw_reranker = AbsEvalRunner.get_models(self.model_args)

        dense_retriever = EvalDenseRetriever(
            embedder,
            search_top_k=args.search_top_k,
            overwrite=args.overwrite,
        )

        # Only wrap the reranker if the model loader actually produced one.
        wrapped_reranker = (
            None
            if raw_reranker is None
            else EvalReranker(raw_reranker, rerank_top_k=args.rerank_top_k)
        )
        return dense_retriever, wrapped_reranker

    def run(self):
        """Run the whole AIR-Bench evaluation.

        The benchmark is configured from ``eval_args`` (version, task types,
        domains, languages, splits, cache directory). It is run with the
        retriever and reranker loaded in ``__init__``, passing
        ``eval_args.output_dir`` and ``eval_args.overwrite`` through.
        """
        args = self.eval_args
        benchmark_options = {
            "benchmark_version": args.benchmark_version,
            "task_types": args.task_types,
            "domains": args.domains,
            "languages": args.languages,
            "splits": args.splits,
            "cache_dir": args.cache_dir,
        }
        benchmark = AIRBench(**benchmark_options)
        benchmark.run(
            self.retriever,
            reranker=self.reranker,
            output_dir=args.output_dir,
            overwrite=args.overwrite,
        )