import os
import json
import logging
import datasets
from tqdm import tqdm
from typing import List, Optional
from FlagEmbedding.abc.evaluation import AbsEvalDataLoader
from .utils.normalize_text import normalize_text
logger = logging.getLogger(__name__)
class MKQAEvalDataLoader(AbsEvalDataLoader):
"""
Data loader class for MKQA.
"""
def available_dataset_names(self) -> List[str]:
"""
Get the available dataset names.
Returns:
List[str]: All the available dataset names.
"""
return ['en', 'ar', 'fi', 'ja', 'ko', 'ru', 'es', 'sv', 'he', 'th', 'da', 'de', 'fr', 'it', 'nl', 'pl', 'pt', 'hu', 'vi', 'ms', 'km', 'no', 'tr', 'zh_cn', 'zh_hk', 'zh_tw']
def available_splits(self, dataset_name: Optional[str] = None) -> List[str]:
"""
Get the available splits.
Args:
dataset_name (Optional[str], optional): Dataset name. Defaults to ``None``.
Returns:
List[str]: All the available splits for the dataset.
"""
return ["test"]
def load_corpus(self, dataset_name: Optional[str] = None) -> datasets.DatasetDict:
"""Load the corpus.
Args:
dataset_name (Optional[str], optional): Name of the dataset. Defaults to None.
Returns:
datasets.DatasetDict: Loaded datasets instance of corpus.
"""
if self.dataset_dir is not None:
# same corpus for all languages
save_dir = self.dataset_dir
return self._load_local_corpus(save_dir, dataset_name=dataset_name)
else:
return self._load_remote_corpus(dataset_name=dataset_name)
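# Usage sketch (illustrative, not part of the module's API surface): how this
# loader can be driven end to end. The constructor keyword arguments are
# assumptions about the AbsEvalDataLoader base class and may not match its
# exact signature; the cache path is hypothetical.
#
#     loader = MKQAEvalDataLoader(
#         eval_name="mkqa",
#         dataset_dir=None,            # no local copy -> remote BeIR/nq corpus
#         cache_dir="./cache/mkqa",    # hypothetical cache location
#     )
#     corpus = loader.load_corpus(dataset_name="en")
#     # expected to map docid -> {"title": ..., "text": ...}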
def _load_local_qrels(self, save_dir: str, dataset_name: Optional[str] = None, split: str = 'test') -> datasets.DatasetDict:
"""Try to load qrels from local datasets.
Args:
save_dir (str): Directory where the data files are saved.
dataset_name (Optional[str], optional): Name of the dataset. Defaults to ``None``.
split (str, optional): Split of the dataset. Defaults to ``'test'``.
Raises:
ValueError: If the requested split is not available for the dataset. If no local qrels are found, they are downloaded from the remote instead.
Returns:
datasets.DatasetDict: Loaded datasets instance of qrels.
"""
checked_split = self.check_splits(split)
if len(checked_split) == 0:
raise ValueError(f"Split {split} not found in the dataset.")
split = checked_split[0]
qrels_path = os.path.join(save_dir, f"{split}_qrels.jsonl")
if self.force_redownload or not os.path.exists(qrels_path):
logger.warning(f"Qrels not found in {qrels_path}. Trying to download the qrels from the remote and save it to {save_dir}.")
return self._load_remote_qrels(dataset_name=dataset_name, split=split, save_dir=save_dir)
else:
qrels_data = datasets.load_dataset('json', data_files=qrels_path, cache_dir=self.cache_dir)['train']
qrels = {}
for data in qrels_data:
qid = data['qid']
qrels[qid] = data['answers']
return datasets.DatasetDict(qrels)
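# Expected local file layout, inferred from the reader above: each line of
# ``{split}_qrels.jsonl`` is a JSON object pairing a query id with its gold
# answer strings (values below are illustrative), e.g.
#
#     {"qid": "12345", "answers": ["an answer", "an alternative answer"]}
#
# Note that for MKQA these "qrels" store answer strings, not document-level
# relevance judgments.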
def _load_remote_corpus(
self,
dataset_name: Optional[str] = None,
save_dir: Optional[str] = None
) -> datasets.DatasetDict:
"""
Refer to https://arxiv.org/pdf/2402.03216: the same English corpus, taken from the BeIR NQ dataset, is used for every MKQA language.
"""
corpus = datasets.load_dataset(
"BeIR/nq", "corpus",
cache_dir=self.cache_dir,
trust_remote_code=True,
download_mode=self.hf_download_mode
)["corpus"]
if save_dir is not None:
os.makedirs(save_dir, exist_ok=True)
save_path = os.path.join(save_dir, "corpus.jsonl")
corpus_dict = {}
with open(save_path, "w", encoding="utf-8") as f:
for data in tqdm(corpus, desc="Loading and Saving corpus"):
docid, title, text = str(data["_id"]), normalize_text(data["title"]).lower(), normalize_text(data["text"]).lower()
_data = {
"id": docid,
"title": title,
"text": text
}
corpus_dict[docid] = {
"title": title,
"text": text
}
f.write(json.dumps(_data, ensure_ascii=False) + "\n")
logging.info(f"{self.eval_name} corpus saved to {save_path}")
else:
corpus_dict = {}
for data in tqdm(corpus, desc="Loading corpus"):
docid, title, text = str(data["_id"]), normalize_text(data["title"]), normalize_text(data["text"])
corpus_dict[docid] = {
"title": title,
"text": text
}
return datasets.DatasetDict(corpus_dict)
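# When ``save_dir`` is given, each line written to ``corpus.jsonl`` has the
# shape below (values illustrative). Title and text are normalized and
# lowercased in the saving branch, whereas the no-save branch above keeps the
# original casing.
#
#     {"id": "doc0", "title": "example title", "text": "example passage text"}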
def _load_remote_qrels(
self,
dataset_name: str,
split: str = 'test',
save_dir: Optional[str] = None
) -> datasets.DatasetDict:
"""Load remote qrels from HF.
Args:
dataset_name (str): Name of the dataset.
split (str, optional): Split of the dataset. Defaults to ``'test'``.
save_dir (Optional[str], optional): Directory to save the dataset. Defaults to ``None``.
Returns:
datasets.DatasetDict: Loaded datasets instance of qrels.
"""
endpoint = f"{os.getenv('HF_ENDPOINT', 'https://huggingface.co')}/datasets/Shitao/bge-m3-data"
queries_download_url = f"{endpoint}/resolve/main/MKQA_test-data.zip"
qrels_save_dir = self._download_zip_file(queries_download_url, self.cache_dir)
qrels_save_path = os.path.join(qrels_save_dir, f"{dataset_name}.jsonl")
if save_dir is not None:
os.makedirs(save_dir, exist_ok=True)
save_path = os.path.join(save_dir, f"{split}_qrels.jsonl")
qrels_dict = {}
with open(save_path, "w", encoding="utf-8") as f1:
with open(qrels_save_path, "r", encoding="utf-8") as f2:
for line in tqdm(f2.readlines(), desc="Loading and Saving qrels"):
data = json.loads(line)
qid, answers = str(data["id"]), data["answers"]
_data = {
"qid": qid,
"answers": answers
}
qrels_dict[qid] = answers
f1.write(json.dumps(_data, ensure_ascii=False) + "\n")
logging.info(f"{self.eval_name} {dataset_name} qrels saved to {save_path}")
else:
qrels_dict = {}
with open(qrels_save_path, "r", encoding="utf-8") as f:
for line in tqdm(f.readlines(), desc="Loading qrels"):
data = json.loads(line)
qid, answers = str(data["id"]), data["answers"]
qrels_dict[qid] = answers
return datasets.DatasetDict(qrels_dict)
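# Each per-language file in the downloaded MKQA_test-data.zip (e.g. ``en.jsonl``)
# is expected to carry at least the fields read here and in
# ``_load_remote_queries`` below; an illustrative line:
#
#     {"id": "q0", "question": "example question text", "answers": ["example answer"]}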
def _load_remote_queries(
self,
dataset_name: str,
split: str = 'test',
save_dir: Optional[str] = None
) -> datasets.DatasetDict:
"""Load the queries from HF.
Args:
dataset_name (str): Name of the dataset.
split (str, optional): Split of the dataset. Defaults to ``'test'``.
save_dir (Optional[str], optional): Directory to save the dataset. Defaults to ``None``.
Returns:
datasets.DatasetDict: Loaded datasets instance of queries.
"""
endpoint = f"{os.getenv('HF_ENDPOINT', 'https://huggingface.co')}/datasets/Shitao/bge-m3-data"
queries_download_url = f"{endpoint}/resolve/main/MKQA_test-data.zip"
queries_save_dir = self._download_zip_file(queries_download_url, self.cache_dir)
queries_save_path = os.path.join(queries_save_dir, f"{dataset_name}.jsonl")
if save_dir is not None:
os.makedirs(save_dir, exist_ok=True)
save_path = os.path.join(save_dir, f"{split}_queries.jsonl")
queries_dict = {}
with open(save_path, "w", encoding="utf-8") as f1:
with open(queries_save_path, "r", encoding="utf-8") as f2:
for line in tqdm(f2.readlines(), desc="Loading and Saving queries"):
data = json.loads(line)
qid, query = str(data["id"]), data["question"]
_data = {
"id": qid,
"text": query
}
queries_dict[qid] = query
f1.write(json.dumps(_data, ensure_ascii=False) + "\n")
logging.info(f"{self.eval_name} {dataset_name} queries saved to {save_path}")
else:
queries_dict = {}
with open(queries_save_path, "r", encoding="utf-8") as f:
for line in tqdm(f.readlines(), desc="Loading queries"):
data = json.loads(line)
qid, query = str(data["id"]), data["question"]
queries_dict[qid] = query
return datasets.DatasetDict(queries_dict)
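# Minimal smoke-test sketch, illustrative only: the constructor keyword
# arguments are assumptions about the AbsEvalDataLoader base class, and
# ``load_queries`` / ``load_qrels`` are assumed to be inherited from it.
# Verify both against the installed FlagEmbedding version before relying on
# this block.
if __name__ == "__main__":
    loader = MKQAEvalDataLoader(
        eval_name="mkqa",
        dataset_dir=None,          # force the remote download path
        cache_dir="./cache/mkqa",  # hypothetical cache location
    )
    print("available languages:", loader.available_dataset_names())
    queries = loader.load_queries(dataset_name="en", split="test")
    qrels = loader.load_qrels(dataset_name="en", split="test")
    print(f"loaded {len(queries)} queries and {len(qrels)} qrels entries")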