Source code for FlagEmbedding.abc.evaluation.data_loader

"""
Adapted from https://github.com/AIR-Bench/AIR-Bench/blob/0.1.0/air_benchmark/evaluation_utils/data_loader.py
"""
import os
import logging
import subprocess

import datasets
from abc import ABC, abstractmethod
from typing import List, Optional, Union

logger = logging.getLogger(__name__)


class AbsEvalDataLoader(ABC):
    """
    Base class of data loader for evaluation.

    Args:
        eval_name (str): The experiment name of the current evaluation.
        dataset_dir (str, optional): Path to the datasets. Defaults to ``None``.
        cache_dir (str, optional): Path to the HuggingFace cache directory. Defaults to ``None``.
        token (str, optional): HF_TOKEN to access private datasets/models on HF. Defaults to ``None``.
        force_redownload (bool, optional): If True, force redownload of the dataset, overwriting the local copy. Defaults to ``False``.
    """
    def __init__(
        self,
        eval_name: str,
        dataset_dir: Optional[str] = None,
        cache_dir: Optional[str] = None,
        token: Optional[str] = None,
        force_redownload: bool = False
    ):
        self.eval_name = eval_name
        self.dataset_dir = dataset_dir
        if cache_dir is None:
            cache_dir = os.getenv('HF_HUB_CACHE', '~/.cache/huggingface/hub')
        self.cache_dir = os.path.join(cache_dir, eval_name)
        self.token = token
        self.force_redownload = force_redownload
        self.hf_download_mode = None if not force_redownload else "force_redownload"
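    # The class is abstract: a concrete loader must provide `available_splits`,
    # and normally overrides `available_dataset_names` as well. A minimal sketch,
    # with a hypothetical class name and dataset/split names (not part of
    # FlagEmbedding):
    #
    #     class ToyEvalDataLoader(AbsEvalDataLoader):
    #         def available_dataset_names(self) -> List[str]:
    #             return ["toy"]
    #
    #         def available_splits(self, dataset_name: Optional[str] = None) -> List[str]:
    #             return ["dev", "test"]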
    def available_dataset_names(self) -> List[str]:
        """
        Returns:
            List[str]: Available dataset names.
        """
        return []
    @abstractmethod
    def available_splits(self, dataset_name: Optional[str] = None) -> List[str]:
        """
        Returns:
            List[str]: Available splits in the dataset.
        """
        pass
    def check_dataset_names(self, dataset_names: Union[str, List[str]]) -> List[str]:
        """Check the validity of dataset names.

        Args:
            dataset_names (Union[str, List[str]]): A dataset name (str) or a list of dataset names (List[str]).

        Raises:
            ValueError

        Returns:
            List[str]: List of valid dataset names.
        """
        available_dataset_names = self.available_dataset_names()
        if isinstance(dataset_names, str):
            dataset_names = [dataset_names]
        for dataset_name in dataset_names:
            if dataset_name not in available_dataset_names:
                raise ValueError(f"Dataset name '{dataset_name}' not found in the dataset. Available dataset names: {available_dataset_names}")
        return dataset_names
    def check_splits(self, splits: Union[str, List[str]], dataset_name: Optional[str] = None) -> List[str]:
        """Check whether the splits are available in the dataset.

        Args:
            splits (Union[str, List[str]]): Splits to check.
            dataset_name (Optional[str], optional): Name of the dataset to check. Defaults to ``None``.

        Returns:
            List[str]: The available splits.
        """
        available_splits = self.available_splits(dataset_name=dataset_name)
        if isinstance(splits, str):
            splits = [splits]
        checked_splits = []
        for split in splits:
            if split not in available_splits:
                logger.warning(f"Split '{split}' not found in the dataset. Removing it from the list.")
            else:
                checked_splits.append(split)
        return checked_splits
    def load_corpus(self, dataset_name: Optional[str] = None) -> datasets.DatasetDict:
        """Load the corpus from the dataset.

        Args:
            dataset_name (Optional[str], optional): Name of the dataset. Defaults to ``None``.

        Returns:
            datasets.DatasetDict: A dict of corpus with id as key, title and text as value.
        """
        if self.dataset_dir is not None:
            if dataset_name is None:
                save_dir = self.dataset_dir
            else:
                save_dir = os.path.join(self.dataset_dir, dataset_name)
            return self._load_local_corpus(save_dir, dataset_name=dataset_name)
        else:
            return self._load_remote_corpus(dataset_name=dataset_name)
    def load_qrels(self, dataset_name: Optional[str] = None, split: str = 'test') -> datasets.DatasetDict:
        """Load the qrels from the dataset.

        Args:
            dataset_name (Optional[str], optional): Name of the dataset. Defaults to ``None``.
            split (str, optional): The split to load relevance from. Defaults to ``'test'``.

        Raises:
            ValueError

        Returns:
            datasets.DatasetDict: A dict of relevance of query and document.
        """
        if self.dataset_dir is not None:
            if dataset_name is None:
                save_dir = self.dataset_dir
            else:
                checked_dataset_names = self.check_dataset_names(dataset_name)
                if len(checked_dataset_names) == 0:
                    raise ValueError(f"Dataset name {dataset_name} not found in the dataset.")
                dataset_name = checked_dataset_names[0]
                save_dir = os.path.join(self.dataset_dir, dataset_name)
            return self._load_local_qrels(save_dir, dataset_name=dataset_name, split=split)
        else:
            return self._load_remote_qrels(dataset_name=dataset_name, split=split)
    def load_queries(self, dataset_name: Optional[str] = None, split: str = 'test') -> datasets.DatasetDict:
        """Load the queries from the dataset.

        Args:
            dataset_name (Optional[str], optional): Name of the dataset. Defaults to ``None``.
            split (str, optional): The split to load queries from. Defaults to ``'test'``.

        Raises:
            ValueError

        Returns:
            datasets.DatasetDict: A dict of queries with id as key, query text as value.
        """
        if self.dataset_dir is not None:
            if dataset_name is None:
                save_dir = self.dataset_dir
            else:
                checked_dataset_names = self.check_dataset_names(dataset_name)
                if len(checked_dataset_names) == 0:
                    raise ValueError(f"Dataset name {dataset_name} not found in the dataset.")
                dataset_name = checked_dataset_names[0]
                save_dir = os.path.join(self.dataset_dir, dataset_name)
            return self._load_local_queries(save_dir, dataset_name=dataset_name, split=split)
        else:
            return self._load_remote_queries(dataset_name=dataset_name, split=split)
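    # A hedged usage sketch of the public loaders, using the hypothetical
    # `ToyEvalDataLoader` sketched above and a hypothetical local path; the
    # expected on-disk layout is described in the `_load_local_*` methods below:
    #
    #     loader = ToyEvalDataLoader(eval_name="toy_eval", dataset_dir="/data/toy_eval")
    #     corpus = loader.load_corpus(dataset_name="toy")
    #     queries = loader.load_queries(dataset_name="toy", split="test")
    #     qrels = loader.load_qrels(dataset_name="toy", split="test")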
    def _load_remote_corpus(
        self,
        dataset_name: Optional[str] = None,
        save_dir: Optional[str] = None
    ) -> datasets.DatasetDict:
        """Abstract method to load the corpus from a remote dataset, to be overridden in child classes.

        Args:
            dataset_name (Optional[str], optional): Name of the dataset. Defaults to ``None``.
            save_dir (Optional[str], optional): Path to save the newly downloaded corpus. Defaults to ``None``.

        Raises:
            NotImplementedError: Loading remote corpus is not implemented.

        Returns:
            datasets.DatasetDict: A dict of corpus with id as key, title and text as value.
        """
        raise NotImplementedError("Loading remote corpus is not implemented.")
    def _load_remote_qrels(
        self,
        dataset_name: Optional[str] = None,
        split: str = 'test',
        save_dir: Optional[str] = None
    ) -> datasets.DatasetDict:
        """Abstract method to load relevance from a remote dataset, to be overridden in child classes.

        Args:
            dataset_name (Optional[str], optional): Name of the dataset. Defaults to ``None``.
            split (str, optional): Split to load from the remote dataset. Defaults to ``'test'``.
            save_dir (Optional[str], optional): Path to save the newly downloaded relevance. Defaults to ``None``.

        Raises:
            NotImplementedError: Loading remote qrels is not implemented.

        Returns:
            datasets.DatasetDict: A dict of relevance of query and document.
        """
        raise NotImplementedError("Loading remote qrels is not implemented.")
    def _load_remote_queries(
        self,
        dataset_name: Optional[str] = None,
        split: str = 'test',
        save_dir: Optional[str] = None
    ) -> datasets.DatasetDict:
        """Abstract method to load queries from a remote dataset, to be overridden in child classes.

        Args:
            dataset_name (Optional[str], optional): Name of the dataset. Defaults to ``None``.
            split (str, optional): Split to load from the remote dataset. Defaults to ``'test'``.
            save_dir (Optional[str], optional): Path to save the newly downloaded queries. Defaults to ``None``.

        Raises:
            NotImplementedError: Loading remote queries is not implemented.

        Returns:
            datasets.DatasetDict: A dict of queries with id as key, query text as value.
        """
        raise NotImplementedError("Loading remote queries is not implemented.")
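    # The `_load_remote_*` methods are meant to be overridden when a dataset can
    # be fetched from a remote source. A hedged sketch of one possible override,
    # using a hypothetical Hugging Face repository id (not a real dataset):
    #
    #     def _load_remote_queries(self, dataset_name=None, split='test', save_dir=None):
    #         data = datasets.load_dataset(
    #             "my-org/toy-eval-queries",      # hypothetical repo id
    #             split=split,
    #             cache_dir=self.cache_dir,
    #             token=self.token,
    #             download_mode=self.hf_download_mode,
    #         )
    #         return datasets.DatasetDict({e['id']: e['text'] for e in data})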
    def _load_local_corpus(self, save_dir: str, dataset_name: Optional[str] = None) -> datasets.DatasetDict:
        """Load corpus from local dataset.

        Args:
            save_dir (str): Path to save the loaded corpus.
            dataset_name (Optional[str], optional): Name of the dataset. Defaults to ``None``.

        Returns:
            datasets.DatasetDict: A dict of corpus with id as key, title and text as value.
        """
        corpus_path = os.path.join(save_dir, 'corpus.jsonl')
        if self.force_redownload or not os.path.exists(corpus_path):
            logger.warning(f"Corpus not found in {corpus_path}. Trying to download the corpus from the remote and save it to {save_dir}.")
            return self._load_remote_corpus(dataset_name=dataset_name, save_dir=save_dir)
        else:
            corpus_data = datasets.load_dataset('json', data_files=corpus_path, cache_dir=self.cache_dir)['train']

            corpus = {}
            for e in corpus_data:
                corpus[e['id']] = {'title': e.get('title', ""), 'text': e['text']}

            return datasets.DatasetDict(corpus)
    def _load_local_qrels(self, save_dir: str, dataset_name: Optional[str] = None, split: str = 'test') -> datasets.DatasetDict:
        """Load relevance from local dataset.

        Args:
            save_dir (str): Path to save the loaded relevance.
            dataset_name (Optional[str], optional): Name of the dataset. Defaults to ``None``.
            split (str, optional): Split to load from the local dataset. Defaults to ``'test'``.

        Raises:
            ValueError

        Returns:
            datasets.DatasetDict: A dict of relevance of query and document.
        """
        checked_split = self.check_splits(split, dataset_name=dataset_name)
        if len(checked_split) == 0:
            raise ValueError(f"Split {split} not found in the dataset.")
        split = checked_split[0]

        qrels_path = os.path.join(save_dir, f"{split}_qrels.jsonl")
        if self.force_redownload or not os.path.exists(qrels_path):
            logger.warning(f"Qrels not found in {qrels_path}. Trying to download the qrels from the remote and save it to {save_dir}.")
            return self._load_remote_qrels(dataset_name=dataset_name, split=split, save_dir=save_dir)
        else:
            qrels_data = datasets.load_dataset('json', data_files=qrels_path, cache_dir=self.cache_dir)['train']

            qrels = {}
            for data in qrels_data:
                qid = data['qid']
                if qid not in qrels:
                    qrels[qid] = {}
                qrels[qid][data['docid']] = data['relevance']

            return datasets.DatasetDict(qrels)
    def _load_local_queries(self, save_dir: str, dataset_name: Optional[str] = None, split: str = 'test') -> datasets.DatasetDict:
        """Load queries from local dataset.

        Args:
            save_dir (str): Path to save the loaded queries.
            dataset_name (Optional[str], optional): Name of the dataset. Defaults to ``None``.
            split (str, optional): Split to load from the local dataset. Defaults to ``'test'``.

        Raises:
            ValueError

        Returns:
            datasets.DatasetDict: A dict of queries with id as key, query text as value.
        """
        checked_split = self.check_splits(split, dataset_name=dataset_name)
        if len(checked_split) == 0:
            raise ValueError(f"Split {split} not found in the dataset.")
        split = checked_split[0]

        queries_path = os.path.join(save_dir, f"{split}_queries.jsonl")
        if self.force_redownload or not os.path.exists(queries_path):
            logger.warning(f"Queries not found in {queries_path}. Trying to download the queries from the remote and save it to {save_dir}.")
            return self._load_remote_queries(dataset_name=dataset_name, split=split, save_dir=save_dir)
        else:
            queries_data = datasets.load_dataset('json', data_files=queries_path, cache_dir=self.cache_dir)['train']
            queries = {e['id']: e['text'] for e in queries_data}
            return datasets.DatasetDict(queries)
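    # The local loaders above expect one directory per dataset containing JSONL
    # files named `corpus.jsonl`, `{split}_queries.jsonl` and `{split}_qrels.jsonl`,
    # e.g. (illustrative records, not real data):
    #
    #     corpus.jsonl        {"id": "doc-1", "title": "A title", "text": "A passage."}
    #     test_queries.jsonl  {"id": "q-1", "text": "a query"}
    #     test_qrels.jsonl    {"qid": "q-1", "docid": "doc-1", "relevance": 1}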
    def _download_file(self, download_url: str, save_dir: str):
        """Download a file from the provided URL.

        Args:
            download_url (str): Source URL of the file.
            save_dir (str): Path to the directory to save the file.

        Raises:
            FileNotFoundError

        Returns:
            str: The path of the downloaded file.
        """
        save_path = os.path.join(save_dir, download_url.split('/')[-1])

        if self.force_redownload or (not os.path.exists(save_path) or os.path.getsize(save_path) == 0):
            cmd = ["wget", "-O", save_path, download_url]
        else:
            cmd = ["wget", "-nc", "-O", save_path, download_url]

        try:
            subprocess.run(cmd, check=True)
        except subprocess.CalledProcessError as e:
            logger.warning(e.output)

        if not os.path.exists(save_path) or os.path.getsize(save_path) == 0:
            raise FileNotFoundError(f"Failed to download file from {download_url} to {save_path}")
        else:
            logger.info(f"Downloaded file from {download_url} to {save_path}")
            return save_path
    def _get_fpath_size(self, fpath: str) -> int:
        """Get the total size of the file(s) at the provided path.

        Args:
            fpath (str): Path of the file(s) to compute the size of.

        Returns:
            int: The total size in bytes.
        """
        if not os.path.isdir(fpath):
            return os.path.getsize(fpath)
        else:
            total_size = 0
            for dirpath, _, filenames in os.walk(fpath):
                for f in filenames:
                    fp = os.path.join(dirpath, f)
                    total_size += os.path.getsize(fp)
            return total_size
    def _download_gz_file(self, download_url: str, save_dir: str):
        """Download and decompress a gzip file from the provided URL.

        Args:
            download_url (str): Source URL of the gzip file.
            save_dir (str): Path to the directory to save the gzip file.

        Raises:
            FileNotFoundError

        Returns:
            str: The path to the decompressed file.
        """
        gz_file_path = self._download_file(download_url, save_dir)
        cmd = ["gzip", "-d", gz_file_path]
        try:
            subprocess.run(cmd, check=True)
        except subprocess.CalledProcessError as e:
            logger.warning(e.output)

        file_path = gz_file_path.replace(".gz", "")
        if not os.path.exists(file_path) or self._get_fpath_size(file_path) == 0:
            raise FileNotFoundError(f"Failed to unzip file {gz_file_path}")

        return file_path
    def _download_zip_file(self, download_url: str, save_dir: str):
        """Download and unzip a zip file from the provided URL.

        Args:
            download_url (str): Source URL of the zip file.
            save_dir (str): Path to the directory to save the zip file.

        Raises:
            FileNotFoundError

        Returns:
            str: The path to the extracted directory.
        """
        zip_file_path = self._download_file(download_url, save_dir)
        file_path = zip_file_path.replace(".zip", "")
        if self.force_redownload or not os.path.exists(file_path):
            cmd = ["unzip", "-o", zip_file_path, "-d", file_path]
        else:
            cmd = ["unzip", "-n", zip_file_path, "-d", file_path]

        try:
            subprocess.run(cmd, check=True)
        except subprocess.CalledProcessError as e:
            logger.warning(e.output)

        if not os.path.exists(file_path) or self._get_fpath_size(file_path) == 0:
            raise FileNotFoundError(f"Failed to unzip file {zip_file_path}")

        return file_path
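# The download helpers shell out to `wget`, `gzip` and `unzip`, so those tools
# must be available on PATH. A hedged sketch of how a subclass might use them
# inside a `_load_remote_corpus` override (the URL is hypothetical, and
# `save_dir` is assumed to be an existing directory):
#
#     def _load_remote_corpus(self, dataset_name=None, save_dir=None):
#         corpus_path = self._download_gz_file(
#             "https://example.com/toy/corpus.jsonl.gz",   # hypothetical URL
#             save_dir,
#         )
#         data = datasets.load_dataset('json', data_files=corpus_path, cache_dir=self.cache_dir)['train']
#         return datasets.DatasetDict(
#             {e['id']: {'title': e.get('title', ''), 'text': e['text']} for e in data}
#         )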