Source code for FlagEmbedding.finetune.reranker.decoder_only.base.arguments
from dataclasses import dataclass, field
from typing import List, Optional

from FlagEmbedding.abc.finetune.reranker import AbsRerankerModelArguments
def default_target_modules() -> List[str]:
    """Return the default list of module names that LoRA is applied to.

    Used as a dataclass ``default_factory``, so a fresh list is built on
    every call and instances never share (or mutate) a common default.

    Returns:
        List[str]: Names of the projection modules targeted by default.
    """
    return ['v_proj', 'q_proj', 'k_proj', 'gate_proj', 'down_proj', 'o_proj', 'up_proj']
[docs]
@dataclass
class RerankerModelArguments(AbsRerankerModelArguments):
    """
    Model argument class for decoder only reranker.

    Extends ``AbsRerankerModelArguments`` with LoRA-related options
    (``use_lora``, ``lora_rank``, ``lora_alpha``, ``lora_dropout``,
    ``target_modules``, ``modules_to_save``, ``save_merged_lora_model``),
    a flash-attention switch (``use_flash_attn``) and two PEFT-related
    options (``from_peft``, ``raw_peft``).

    Fields whose default is ``None`` are annotated ``Optional[...]`` so the
    declared type matches the value the field actually holds when unset.
    """
    use_lora: bool = field(
        default=True,
        metadata={"help": "If passed, will use LORA (low-rank parameter-efficient training) to train the model."}
    )
    lora_rank: int = field(
        default=64,
        metadata={"help": "The rank of lora."}
    )
    lora_alpha: float = field(
        default=16,
        metadata={"help": "The alpha parameter of lora."}
    )
    lora_dropout: float = field(
        default=0.1,
        metadata={"help": "The dropout rate of lora modules."}
    )
    # default_factory gives every instance its own list of target modules.
    target_modules: List[str] = field(
        default_factory=default_target_modules,
        metadata={"help": "The target modules to apply LORA."}
    )
    # Unset (None) unless the caller supplies module names.
    modules_to_save: Optional[List[str]] = field(
        default=None,
        metadata={"help": "List of modules that should be saved in the final checkpoint."}
    )
    use_flash_attn: bool = field(
        default=False,
        metadata={"help": "If passed, will use flash attention to train the model."}
    )
    # Option currently disabled; kept commented out as in the original file.
    # use_slow_tokenizer: bool = field(
    #     default=False,
    #     metadata={"help": "If passed, will use a slow tokenizer (not backed by the 🤗 Tokenizers library)."}
    # )
    # NOTE(review): no help text is defined for this field here; presumably
    # identifies an existing PEFT adapter to start from -- confirm against
    # the code that consumes it.
    from_peft: Optional[str] = field(
        default=None
    )
    # NOTE(review): no help text is defined for this field here; presumably
    # a list of PEFT adapter identifiers -- confirm against the code that
    # consumes it.
    raw_peft: Optional[List[str]] = field(
        default=None
    )
    save_merged_lora_model: bool = field(
        default=False,
        metadata={"help": "If passed, will merge the lora modules and save the entire model."}
    )