import contextlib
import json
import logging
import os
from typing import List, Optional

import datasets
from tqdm import tqdm

from FlagEmbedding.abc.evaluation import AbsEvalDataLoader
# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)
class MSMARCOEvalDataLoader(AbsEvalDataLoader):
    """
    Data loader class for MSMARCO.

    Supports the ``"passage"`` and ``"document"`` datasets, each with the
    ``"dev"``, ``"dl19"`` and ``"dl20"`` splits.
    """

    def available_dataset_names(self) -> List[str]:
        """
        Get the available dataset names.

        Returns:
            List[str]: All the available dataset names.
        """
        return ["passage", "document"]

    def available_splits(self, dataset_name: Optional[str] = None) -> List[str]:
        """
        Get the available splits.

        Args:
            dataset_name (Optional[str], optional): Dataset name. Not used: every
                dataset exposes the same splits. Defaults to ``None``.

        Returns:
            List[str]: All the available splits for the dataset.
        """
        return ["dev", "dl19", "dl20"]

    @staticmethod
    def _prepare_jsonl_path(save_dir: Optional[str], filename: str) -> Optional[str]:
        """Return ``save_dir/filename`` (creating ``save_dir`` if needed), or ``None`` if ``save_dir`` is ``None``."""
        if save_dir is None:
            return None
        os.makedirs(save_dir, exist_ok=True)
        return os.path.join(save_dir, filename)

    @staticmethod
    def _open_jsonl_writer(save_path: Optional[str]):
        """Open ``save_path`` for writing, or return a no-op context yielding ``None`` when it is ``None``."""
        if save_path is None:
            return contextlib.nullcontext()
        return open(save_path, "w", encoding="utf-8")

    @staticmethod
    def _read_trec_qrels_file(path: str) -> List[tuple]:
        """Parse a whitespace-separated ``qid <ignored> docid relevance`` qrels file.

        Args:
            path (str): Path to the plain-text qrels file.

        Returns:
            List[tuple]: ``(qid, docid, relevance)`` tuples, with ``relevance`` as ``int``.
        """
        triples = []
        with open(path, "r", encoding="utf-8") as f:
            for line in f:
                fields = line.split()
                if not fields:
                    # Tolerate blank lines (e.g. a trailing empty line).
                    continue
                qid, _, docid, rel = fields
                triples.append((str(qid), str(docid), int(rel)))
        return triples

    @staticmethod
    def _read_tsv_queries_file(path: str) -> List[tuple]:
        """Parse a ``qid<TAB>query`` file.

        Args:
            path (str): Path to the plain-text queries file.

        Returns:
            List[tuple]: ``(qid, query)`` string tuples.
        """
        pairs = []
        with open(path, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                qid, query = line.split("\t", 1)
                pairs.append((str(qid), query))
        return pairs

    def _load_remote_corpus(
        self,
        dataset_name: str,
        save_dir: Optional[str] = None
    ) -> datasets.DatasetDict:
        """Load the corpus dataset from HF.

        Args:
            dataset_name (str): Name of the dataset. ``"passage"`` loads the passage
                corpus; any other value loads the document corpus.
            save_dir (Optional[str], optional): Directory to save the dataset. When given,
                the corpus is also written to ``save_dir/corpus.jsonl``. Defaults to ``None``.

        Returns:
            datasets.DatasetDict: Loaded datasets instance of corpus, mapping each document
                id to a ``{"title": ..., "text": ...}`` dict.
        """
        if dataset_name == 'passage':
            corpus = datasets.load_dataset(
                'Tevatron/msmarco-passage-corpus',
                'default',
                trust_remote_code=True,
                cache_dir=self.cache_dir,
                download_mode=self.hf_download_mode
            )['train']
            id_key, text_key = "docid", "text"
        else:
            corpus = datasets.load_dataset(
                'irds/msmarco-document',
                'docs',
                trust_remote_code=True,
                cache_dir=self.cache_dir,
                download_mode=self.hf_download_mode
            )
            # Without ``split=``, ``load_dataset`` returns a DatasetDict, and iterating it
            # yields split names rather than records. Take its split (presumably the only
            # one -- verify if the upstream dataset ever gains more).
            if isinstance(corpus, datasets.DatasetDict):
                corpus = corpus[next(iter(corpus))]
            id_key, text_key = "doc_id", "body"

        save_path = self._prepare_jsonl_path(save_dir, "corpus.jsonl")
        desc = "Loading corpus" if save_path is None else "Loading and Saving corpus"
        corpus_dict = {}
        with self._open_jsonl_writer(save_path) as f:
            for data in tqdm(corpus, desc=desc):
                docid, title, text = data[id_key], data["title"], data[text_key]
                corpus_dict[docid] = {"title": title, "text": text}
                if f is not None:
                    _data = {"id": docid, "title": title, "text": text}
                    f.write(json.dumps(_data, ensure_ascii=False) + "\n")
        if save_path is not None:
            logger.info("%s %s corpus saved to %s", self.eval_name, dataset_name, save_path)
        return datasets.DatasetDict(corpus_dict)

    def _load_remote_qrels(
        self,
        dataset_name: Optional[str] = None,
        split: str = 'dev',
        save_dir: Optional[str] = None
    ) -> datasets.DatasetDict:
        """Load the qrels from HF.

        Args:
            dataset_name (str): Name of the dataset. ``"passage"`` selects passage qrels;
                any other value (including ``None``) selects document qrels.
            split (str, optional): Split of the dataset. ``"dev"`` and ``"dl19"`` are
                matched explicitly; any other value is treated as ``"dl20"``.
                Defaults to ``'dev'``.
            save_dir (Optional[str], optional): Directory to save the dataset. When given,
                the qrels are also written to ``save_dir/{split}_qrels.jsonl``.
                Defaults to ``None``.

        Returns:
            datasets.DatasetDict: Loaded datasets instance of qrel, mapping
                ``qid -> {docid: relevance}``.
        """
        qrels = None
        qrels_download_url = None
        if dataset_name == 'passage':
            if split == 'dev':
                qrels = datasets.load_dataset(
                    'BeIR/msmarco-qrels',
                    split='validation',
                    trust_remote_code=True,
                    cache_dir=self.cache_dir,
                    download_mode=self.hf_download_mode
                )
            elif split == 'dl19':
                qrels_download_url = "https://trec.nist.gov/data/deep/2019qrels-pass.txt"
            else:
                qrels_download_url = "https://trec.nist.gov/data/deep/2020qrels-pass.txt"
        else:
            if split == 'dev':
                qrels_download_url = "https://msmarco.z22.web.core.windows.net/msmarcoranking/msmarco-docdev-qrels.tsv.gz"
            elif split == 'dl19':
                qrels_download_url = "https://trec.nist.gov/data/deep/2019qrels-docs.txt"
            else:
                qrels_download_url = "https://trec.nist.gov/data/deep/2020qrels-docs.txt"

        if qrels_download_url is None:
            # Passage dev qrels come from the HF dataset loaded above.
            records = [
                (str(data['query-id']), str(data['corpus-id']), int(data['score']))
                for data in qrels
            ]
        else:
            # Gzip-compressed files go through ``_download_gz_file`` (as the query loader
            # does for its ``.tsv.gz`` files) so that the file parsed below is plain text.
            if qrels_download_url.endswith(".gz"):
                qrels_save_path = self._download_gz_file(qrels_download_url, self.cache_dir)
            else:
                qrels_save_path = self._download_file(qrels_download_url, self.cache_dir)
            records = self._read_trec_qrels_file(qrels_save_path)

        save_path = self._prepare_jsonl_path(save_dir, f"{split}_qrels.jsonl")
        desc = "Loading qrels" if save_path is None else "Loading and Saving qrels"
        qrels_dict = {}
        with self._open_jsonl_writer(save_path) as f:
            for qid, docid, rel in tqdm(records, desc=desc):
                qrels_dict.setdefault(qid, {})[docid] = rel
                if f is not None:
                    _data = {"qid": qid, "docid": docid, "relevance": rel}
                    f.write(json.dumps(_data, ensure_ascii=False) + "\n")
        if save_path is not None:
            logger.info("%s %s qrels saved to %s", self.eval_name, dataset_name, save_path)
        return datasets.DatasetDict(qrels_dict)

    def _load_remote_queries(
        self,
        dataset_name: Optional[str] = None,
        split: str = 'test',
        save_dir: Optional[str] = None
    ) -> datasets.DatasetDict:
        """Load the queries from HF.

        Only queries that have at least one qrels entry for the same split are kept.

        Args:
            dataset_name (str): Name of the dataset. ``"passage"`` selects passage queries;
                any other value selects document queries.
            split (str, optional): Split of the dataset. Supported values are ``"dev"``,
                ``"dl19"`` and ``"dl20"``. Defaults to ``'test'``.
                NOTE(review): the default is not one of the supported splits.
            save_dir (Optional[str], optional): Directory to save the dataset. When given,
                the queries are also written to ``save_dir/{split}_queries.jsonl``.
                Defaults to ``None``.

        Returns:
            datasets.DatasetDict: Loaded datasets instance of queries, mapping
                ``qid -> query text``.
        """
        queries = None
        queries_save_path = None
        if split == 'dev':
            if dataset_name == 'passage':
                queries = datasets.load_dataset(
                    'BeIR/msmarco',
                    'queries',
                    trust_remote_code=True,
                    cache_dir=self.cache_dir,
                    download_mode=self.hf_download_mode
                )['queries']
            else:
                # This must be the dev *queries* file (``qid<TAB>query``), not the qrels file.
                queries_download_url = "https://msmarco.z22.web.core.windows.net/msmarcoranking/msmarco-docdev-queries.tsv.gz"
                queries_save_path = self._download_gz_file(queries_download_url, self.cache_dir)
        else:
            # "dl19" -> msmarco-test2019-queries, "dl20" -> msmarco-test2020-queries.
            year = split.replace("dl", "")
            queries_download_url = f"https://msmarco.z22.web.core.windows.net/msmarcoranking/msmarco-test20{year}-queries.tsv.gz"
            queries_save_path = self._download_gz_file(queries_download_url, self.cache_dir)
        # Used to keep only the queries that have relevance judgements in this split.
        qrels = self.load_qrels(dataset_name=dataset_name, split=split)
        judged_qids = qrels.keys()

        if queries_save_path is not None:
            pairs = self._read_tsv_queries_file(queries_save_path)
        else:
            pairs = [(data['_id'], data['text']) for data in queries]

        save_path = self._prepare_jsonl_path(save_dir, f"{split}_queries.jsonl")
        desc = "Loading queries" if save_path is None else "Loading and Saving queries"
        queries_dict = {}
        with self._open_jsonl_writer(save_path) as f:
            for qid, query in tqdm(pairs, desc=desc):
                if qid not in judged_qids:
                    continue
                queries_dict[qid] = query
                if f is not None:
                    _data = {"id": qid, "text": query}
                    f.write(json.dumps(_data, ensure_ascii=False) + "\n")
        if save_path is not None:
            logger.info("%s %s queries saved to %s", self.eval_name, dataset_name, save_path)
        return datasets.DatasetDict(queries_dict)