Source code for FlagEmbedding.evaluation.air_bench.arguments

from dataclasses import dataclass, field
from typing import List, Optional
from air_benchmark import EvalArgs as AIRBenchEvalArgs


@dataclass
class AIRBenchEvalModelArgs:
    """Evaluation model arguments for AIR Bench.

    The fields fall into three groups:

    * embedder settings (``embedder_*``, pooling, normalisation, precision,
      devices and the query/example instruction templates),
    * reranker settings (``reranker_*``, the query/passage instruction
      templates for reranking, ``prompt``, ``cutoff_layers``,
      ``compress_ratio`` and ``compress_layers``),
    * inference settings (batch sizes and maximum sequence lengths).

    Each field carries its own description in ``metadata["help"]``; some also
    carry ``"choices"``, ``"nargs"`` or ``"required"`` entries in the same
    mapping. ``embedder_name_or_path`` is the only field without a default and
    must therefore always be supplied.

    Fields whose metadata declares ``"nargs": "+"`` are annotated as optional
    lists, matching the fact that they default to ``None`` and otherwise hold
    one or more values.
    """

    embedder_name_or_path: str = field(
        metadata={"help": "The embedder name or path.", "required": True}
    )
    embedder_model_class: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "The embedder model class. Available classes: "
                "['encoder-only-base', 'encoder-only-m3', 'decoder-only-base', 'decoder-only-icl']. "
                "Default: None. For the custom model, you need to specify the model class."
            ),
            "choices": [
                "encoder-only-base",
                "encoder-only-m3",
                "decoder-only-base",
                "decoder-only-icl",
            ],
        },
    )
    normalize_embeddings: bool = field(
        default=True, metadata={"help": "whether to normalize the embeddings"}
    )
    pooling_method: str = field(
        default="cls", metadata={"help": "The pooling method for the embedder."}
    )
    use_fp16: bool = field(
        default=True, metadata={"help": "whether to use fp16 for inference"}
    )
    # "nargs": "+" means one or more values, hence an optional list.
    devices: Optional[List[str]] = field(
        default=None,
        metadata={"help": "Devices to use for inference.", "nargs": "+"},
    )
    query_instruction_for_retrieval: Optional[str] = field(
        default=None, metadata={"help": "Instruction for query"}
    )
    query_instruction_format_for_retrieval: str = field(
        default="{}{}", metadata={"help": "Format for query instruction"}
    )
    examples_for_task: Optional[str] = field(
        default=None, metadata={"help": "Examples for task"}
    )
    examples_instruction_format: str = field(
        default="{}{}", metadata={"help": "Format for examples instruction"}
    )
    trust_remote_code: bool = field(
        default=False, metadata={"help": "Trust remote code"}
    )
    reranker_name_or_path: Optional[str] = field(
        default=None, metadata={"help": "The reranker name or path."}
    )
    reranker_model_class: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "The reranker model class. Available classes: "
                "['encoder-only-base', 'decoder-only-base', 'decoder-only-layerwise', "
                "'decoder-only-lightweight']. "
                "Default: None. For the custom model, you need to specify the model class."
            ),
            "choices": [
                "encoder-only-base",
                "decoder-only-base",
                "decoder-only-layerwise",
                "decoder-only-lightweight",
            ],
        },
    )
    reranker_peft_path: Optional[str] = field(
        default=None, metadata={"help": "The reranker peft path."}
    )
    use_bf16: bool = field(
        default=False, metadata={"help": "whether to use bf16 for inference"}
    )
    query_instruction_for_rerank: Optional[str] = field(
        default=None, metadata={"help": "Instruction for query"}
    )
    query_instruction_format_for_rerank: str = field(
        default="{}{}", metadata={"help": "Format for query instruction"}
    )
    passage_instruction_for_rerank: Optional[str] = field(
        default=None, metadata={"help": "Instruction for passage"}
    )
    passage_instruction_format_for_rerank: str = field(
        default="{}{}", metadata={"help": "Format for passage instruction"}
    )
    model_cache_dir: Optional[str] = field(
        default=None, metadata={"help": "Cache directory for models."}
    )
    # ================ for inference ===============
    embedder_batch_size: int = field(
        default=3000, metadata={"help": "Batch size for inference."}
    )
    reranker_batch_size: int = field(
        default=3000, metadata={"help": "Batch size for inference."}
    )
    embedder_query_max_length: int = field(
        default=512, metadata={"help": "Max length for query."}
    )
    embedder_passage_max_length: int = field(
        default=512, metadata={"help": "Max length for passage."}
    )
    reranker_query_max_length: Optional[int] = field(
        default=None, metadata={"help": "Max length for reranking."}
    )
    reranker_max_length: int = field(
        default=512, metadata={"help": "Max length for reranking."}
    )
    normalize: bool = field(
        default=False,
        metadata={"help": "whether to normalize the reranking scores"},
    )
    prompt: Optional[str] = field(
        default=None, metadata={"help": "The prompt for the reranker."}
    )
    cutoff_layers: Optional[List[int]] = field(
        default=None,
        metadata={"help": "The output layers of layerwise/lightweight reranker."},
    )
    compress_ratio: int = field(
        default=1,
        metadata={"help": "The compress ratio of lightweight reranker."},
    )
    # "nargs": "+" means one or more values, hence an optional list.
    compress_layers: Optional[List[int]] = field(
        default=None,
        metadata={
            "help": "The compress layers of lightweight reranker.",
            "nargs": "+",
        },
    )

    def __post_init__(self):
        r"""Replace the two-character sequence ``\n`` with a real newline.

        Applied to the four instruction-format templates, so a template given
        as backslash followed by ``n`` ends up containing an actual line
        break.
        """
        for name in (
            "query_instruction_format_for_retrieval",
            "examples_instruction_format",
            "query_instruction_format_for_rerank",
            "passage_instruction_format_for_rerank",
        ):
            template = getattr(self, name)
            # "\\n" is backslash + "n" (two characters), not a newline.
            if "\\n" in template:
                setattr(self, name, template.replace("\\n", "\n"))