Reranker#

A reranker uses a cross-encoder architecture: it takes the query and the text together as input and directly outputs a similarity score. This makes it more accurate than an embedding model at scoring query-text relevance, at the cost of slower speed. A complete retrieval system therefore usually uses a retriever in the first stage to search a large corpus, followed by a reranker that reorders the retrieved results more precisely.

In this tutorial, we will walk through a text retrieval pipeline with a reranker and evaluate the results before and after reranking.
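The two-stage idea can be sketched without any model at all. The following is a minimal illustration, not the tutorial's actual pipeline: `retrieve_then_rerank`, `word_overlap`, and `overlap_ratio` are hypothetical names, and the two toy scorers merely stand in for an embedding retriever and a cross-encoder reranker.

```python
from typing import Callable, List, Tuple

Scorer = Callable[[str, str], float]

def retrieve_then_rerank(query: str, corpus: List[str],
                         cheap_score: Scorer, precise_score: Scorer,
                         k: int) -> List[Tuple[str, float]]:
    # Stage 1: score every passage with the cheap scorer and keep the top k.
    candidates = sorted(corpus, key=lambda p: cheap_score(query, p), reverse=True)[:k]
    # Stage 2: rescore only those k candidates with the expensive scorer.
    rescored = [(p, precise_score(query, p)) for p in candidates]
    return sorted(rescored, key=lambda pair: pair[1], reverse=True)

# Toy stand-ins (assumptions for illustration only).
def word_overlap(query: str, passage: str) -> float:
    # number of distinct words shared by query and passage
    return float(len(set(query.lower().split()) & set(passage.lower().split())))

def overlap_ratio(query: str, passage: str) -> float:
    # shared words divided by passage length, so focused passages score higher
    return word_overlap(query, passage) / max(len(passage.split()), 1)

toy_corpus = [
    "today is a sunny day",
    "the tiger is the largest living cat species",
    "the giant panda is a bear species endemic to china",
    "a panda is a bear",
]
results = retrieve_then_rerank("what is panda", toy_corpus,
                               word_overlap, overlap_ratio, k=2)
for passage, score in results:
    print(f"{score:.2f}\t{passage}")
```

The expensive scorer runs on only `k` passages instead of the whole corpus, which is why the slower reranker remains affordable in the second stage.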

Note: Steps 1-4 are identical to those in the evaluation tutorial. We suggest going through that tutorial first if you are not familiar with retrieval.

0. Setup#

Install the dependencies in the environment.

%pip install -U FlagEmbedding faiss-cpu

1. Dataset#

Download and preprocess the MS MARCO dataset.

from datasets import load_dataset
import numpy as np

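# load the dev split of the dataset
data = load_dataset("namespace-Pt/msmarco", split="dev")
# use the first 100 queries for evaluation
queries = np.array(data[:100]["query"])
# flatten the positive passages of the first 5000 samples into one corpus list
corpus = sum(data[:5000]["positive"], [])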

2. Embedding#

from FlagEmbedding import FlagModel

# get the BGE embedding model
model = FlagModel('BAAI/bge-base-en-v1.5',
                  query_instruction_for_retrieval="Represent this sentence for searching relevant passages:",
                  use_fp16=True)

# get the embedding of the corpus
corpus_embeddings = model.encode(corpus)

print("shape of the corpus embeddings:", corpus_embeddings.shape)
print("data type of the embeddings: ", corpus_embeddings.dtype)
Inference Embeddings: 100%|██████████| 21/21 [01:59<00:00,  5.68s/it]
shape of the corpus embeddings: (5331, 768)
data type of the embeddings:  float32

3. Indexing#

import faiss

# get the length of our embedding vectors, vectors by bge-base-en-v1.5 have length 768
dim = corpus_embeddings.shape[-1]

# create the faiss index and store the corpus embeddings into the vector space
index = faiss.index_factory(dim, 'Flat', faiss.METRIC_INNER_PRODUCT)
corpus_embeddings = corpus_embeddings.astype(np.float32)
index.train(corpus_embeddings)
index.add(corpus_embeddings)

print(f"total number of vectors: {index.ntotal}")
total number of vectors: 5331
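A 'Flat' index with the inner-product metric performs an exact, exhaustive search: every query is compared against every stored vector. The sketch below reproduces that behavior in plain Python on toy 2-dimensional vectors. `brute_force_search` is a hypothetical helper written for illustration, not part of faiss.

```python
def brute_force_search(queries, corpus_vecs, k):
    """Exhaustive inner-product search; returns (scores, ids) per query, best first."""
    all_scores, all_ids = [], []
    for q in queries:
        # inner product between the query and every corpus vector
        scored = [(sum(a * b for a, b in zip(q, v)), i)
                  for i, v in enumerate(corpus_vecs)]
        # highest inner product first, then keep the top k
        scored.sort(key=lambda pair: pair[0], reverse=True)
        top = scored[:k]
        all_scores.append([s for s, _ in top])
        all_ids.append([i for _, i in top])
    return all_scores, all_ids

toy_vectors = [[1.0, 0.0], [0.0, 1.0], [0.6, 0.8]]
toy_queries = [[1.0, 0.0]]
toy_scores, toy_ids = brute_force_search(toy_queries, toy_vectors, k=2)
print(toy_ids)     # ids of the two closest vectors for each query
print(toy_scores)  # their inner products
```

Because nothing is learned for a flat index, the `index.train` call above has no effect on it; it only matters for index types that need training, such as IVF or PQ.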

4. Retrieval#

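from tqdm import tqdm

# encode the queries; encode_queries prepends the retrieval instruction
query_embeddings = model.encode_queries(queries)
# positive passages of each sample, aligned with the queries by position
ground_truths = [d["positive"] for d in data]
# convert to an array so it can be indexed with the id arrays returned by faiss
corpus = np.asarray(corpus)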

res_scores, res_ids, res_text = [], [], []
query_size = len(query_embeddings)
batch_size = 256
# The cutoffs we will use during evaluation, and set k to be the maximum of the cutoffs.
cut_offs = [1, 10]
k = max(cut_offs)

for i in tqdm(range(0, query_size, batch_size), desc="Searching"):
    q_embedding = query_embeddings[i: min(i+batch_size, query_size)].astype(np.float32)
    # search the top k answers for each of the queries
    score, idx = index.search(q_embedding, k=k)
    res_scores += list(score)
    res_ids += list(idx)
    res_text += list(corpus[idx])
Searching: 100%|██████████| 1/1 [00:00<00:00, 22.35it/s]

5. Reranking#

Now we will use a reranker to rerank the list of answers we retrieved using our index. Hopefully, this will lead to better results.

The following table lists the available BGE rerankers. Feel free to try them out and compare their results!

| Model | Language | Parameters | Description | Base Model |
|---|---|---|---|---|
| BAAI/bge-reranker-v2-m3 | Multilingual | 568M | A lightweight cross-encoder model with strong multilingual capabilities; easy to deploy, with fast inference. | XLM-RoBERTa-Large |
| BAAI/bge-reranker-v2-gemma | Multilingual | 2.51B | A cross-encoder model suitable for multilingual contexts; performs well in both English proficiency and multilingual capabilities. | Gemma2-2B |
| BAAI/bge-reranker-v2-minicpm-layerwise | Multilingual | 2.72B | A cross-encoder model suitable for multilingual contexts; performs well in both English and Chinese proficiency, and allows freedom to select layers for output, facilitating accelerated inference. | MiniCPM |
| BAAI/bge-reranker-v2.5-gemma2-lightweight | Multilingual | 9.24B | A cross-encoder model suitable for multilingual contexts; performs well in both English and Chinese proficiency, and allows freedom to select layers, compress ratio, and compress layers for output, facilitating accelerated inference. | Gemma2-9B |
| BAAI/bge-reranker-large | Chinese and English | 560M | A cross-encoder model which is more accurate but less efficient. | XLM-RoBERTa-Large |
| BAAI/bge-reranker-base | Chinese and English | 278M | A cross-encoder model which is more accurate but less efficient. | XLM-RoBERTa-Base |

First, let’s use a small example to see how a reranker works:

from FlagEmbedding import FlagReranker

reranker = FlagReranker('BAAI/bge-reranker-large', use_fp16=True) 
# Setting use_fp16 to True speeds up computation with a slight performance degradation

# use the compute_score() function to calculate scores for each input sentence pair
scores = reranker.compute_score([
    ['what is panda?', 'Today is a sunny day'], 
    ['what is panda?', 'The tiger (Panthera tigris) is a member of the genus Panthera and the largest living cat species native to Asia.'],
    ['what is panda?', 'The giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China.']
    ])
print(scores)
[-9.474676132202148, -2.823843240737915, 5.76226806640625]
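The scores printed above are unbounded raw values, so only their order and relative size are meaningful. If a value in (0, 1) is easier to work with, a sigmoid is a common choice. The sketch below applies one to the three scores from the output above. The FlagEmbedding README describes a `normalize=True` option for `compute_score` that applies the same kind of mapping; check it against your installed version before relying on it.

```python
import math

# raw reranker scores copied from the output above
raw_scores = [-9.474676132202148, -2.823843240737915, 5.76226806640625]

def sigmoid(x: float) -> float:
    # squashes any real number into the open interval (0, 1)
    return 1.0 / (1.0 + math.exp(-x))

normalized = [sigmoid(s) for s in raw_scores]
for raw, norm in zip(raw_scores, normalized):
    print(f"{raw:>8.3f} -> {norm:.5f}")
```

The mapping is monotonic, so the ranking of the passages does not change; only the scale does.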

Now, let’s use the reranker to rerank our previously retrieved results:

new_ids, new_scores, new_text = [], [], []
for i in range(len(queries)):
    # get the new scores of the previously retrieved results
    new_score = reranker.compute_score([[queries[i], text] for text in res_text[i]])
    # sort the lists of ids and scores by the new scores
    new_id = [tup[1] for tup in sorted(list(zip(new_score, res_ids[i])), reverse=True)]
    new_scores.append(sorted(new_score, reverse=True))
    new_ids.append(new_id)
    new_text.append(corpus[new_id])
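The loop above sorts `(score, id)` tuples, so passages with equal scores are ordered by their ids. An alternative is to compute one permutation from the new scores and apply it to every parallel list, which keeps ids, texts, and scores aligned by construction. This is a minimal sketch on toy data; `rerank_one` is a hypothetical helper, not part of FlagEmbedding.

```python
def rerank_one(ids, texts, new_scores):
    """Reorder parallel lists by descending new score using a single permutation."""
    # positions sorted by score, highest first
    order = sorted(range(len(new_scores)), key=lambda j: new_scores[j], reverse=True)
    return ([ids[j] for j in order],
            [texts[j] for j in order],
            [new_scores[j] for j in order])

toy_ids = [42, 7, 19]
toy_texts = ["passage a", "passage b", "passage c"]
toy_new_scores = [0.1, 2.5, -1.0]

sorted_ids, sorted_texts, sorted_scores = rerank_one(toy_ids, toy_texts, toy_new_scores)
print(sorted_ids)
print(sorted_texts)
print(sorted_scores)
```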

6. Evaluate#

For details of these metrics, please check out the evaluation tutorial.

6.1 Recall#

def calc_recall(preds, truths, cutoffs):
    recalls = np.zeros(len(cutoffs))
    for text, truth in zip(preds, truths):
        for i, c in enumerate(cutoffs):
            # positives that appear among the top c predictions
            hits = np.intersect1d(truth, text[:c])
            # at most min(c, len(truth)) positives can appear in the top c
            recalls[i] += len(hits) / max(min(c, len(truth)), 1)
    # average over all queries
    recalls /= len(preds)
    return recalls
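To make the metric concrete, here is a small worked example for a single query. It is a sketch that assumes the capped definition of recall@k, where the denominator is the smaller of k and the number of positives. `recall_at_k` and the passage names are hypothetical.

```python
def recall_at_k(pred, truth, k):
    # number of positives found among the top k predictions
    hits = len(set(pred[:k]) & set(truth))
    # at most min(k, len(truth)) positives can fit into the top k
    return hits / max(min(k, len(truth)), 1)

toy_pred = ["p3", "p1", "p9", "p2"]   # ranked predictions for one query
toy_truth = ["p1", "p2"]              # its two positive passages

r1 = recall_at_k(toy_pred, toy_truth, 1)   # top 1 holds no positive
r2 = recall_at_k(toy_pred, toy_truth, 2)   # top 2 holds one of two positives
r4 = recall_at_k(toy_pred, toy_truth, 4)   # top 4 holds both positives
print(r1, r2, r4)
```

Reranking only reorders the retrieved list, so recall at the largest cutoff cannot change; improvements can only show up at smaller cutoffs.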

Before reranking:

recalls_init = calc_recall(res_text, ground_truths, cut_offs)
for i, c in enumerate(cut_offs):
    print(f"recall@{c}:\t{recalls_init[i]}")
recall@1:	0.97
recall@10:	1.0

After reranking:

recalls_rerank = calc_recall(new_text, ground_truths, cut_offs)
for i, c in enumerate(cut_offs):
    print(f"recall@{c}:\t{recalls_rerank[i]}")
recall@1:	0.99
recall@10:	1.0

6.2 MRR#

def MRR(preds, truth, cutoffs):
    mrr = [0 for _ in range(len(cutoffs))]
    for pred, t in zip(preds, truth):
        for i, c in enumerate(cutoffs):
            for j, p in enumerate(pred):
                if j < c and p in t:
                    mrr[i] += 1/(j+1)
                    break
    mrr = [k/len(preds) for k in mrr]
    return mrr
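As a worked example, consider three queries whose first relevant passage appears at rank 1, at rank 3, and not at all. The sketch below computes the reciprocal rank per query and averages them; `reciprocal_rank` and the toy data are hypothetical.

```python
def reciprocal_rank(pred, truth, k):
    # 1 / rank of the first relevant prediction within the top k, or 0 if none
    for rank, p in enumerate(pred[:k], start=1):
        if p in truth:
            return 1.0 / rank
    return 0.0

toy_preds = [
    ["a", "b", "c"],   # first relevant passage at rank 1
    ["d", "e", "f"],   # first relevant passage at rank 3
    ["g", "h", "i"],   # no relevant passage retrieved
]
toy_truths = [["a"], ["f"], ["z"]]

mrr_at_1 = sum(reciprocal_rank(p, t, 1) for p, t in zip(toy_preds, toy_truths)) / len(toy_preds)
mrr_at_3 = sum(reciprocal_rank(p, t, 3) for p, t in zip(toy_preds, toy_truths)) / len(toy_preds)
print(mrr_at_1, mrr_at_3)
```

At a cutoff of 1 only the first query contributes, giving 1/3; at a cutoff of 3 the second query adds 1/3 of a point, giving (1 + 1/3 + 0) / 3.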

Before reranking:

mrr_init = MRR(res_text, ground_truths, cut_offs)
for i, c in enumerate(cut_offs):
    print(f"MRR@{c}:\t{mrr_init[i]}")
MRR@1:	0.97
MRR@10:	0.9825

After reranking:

mrr_rerank = MRR(new_text, ground_truths, cut_offs)
for i, c in enumerate(cut_offs):
    print(f"MRR@{c}:\t{mrr_rerank[i]}")
MRR@1:	0.99
MRR@10:	0.995

6.3 nDCG#

Before reranking:

from sklearn.metrics import ndcg_score

pred_hard_encodings = []
for pred, label in zip(res_text, ground_truths):
    pred_hard_encoding = list(np.isin(pred, label).astype(int))
    pred_hard_encodings.append(pred_hard_encoding)

for i, c in enumerate(cut_offs):
    nDCG = ndcg_score(pred_hard_encodings, res_scores, k=c)
    print(f"nDCG@{c}: {nDCG}")
nDCG@1: 0.97
nDCG@10: 0.9869253606521631

After reranking:

pred_hard_encodings_rerank = []
for pred, label in zip(new_text, ground_truths):
    pred_hard_encoding = list(np.isin(pred, label).astype(int))
    pred_hard_encodings_rerank.append(pred_hard_encoding)

for i, c in enumerate(cut_offs):
    nDCG = ndcg_score(pred_hard_encodings_rerank, new_scores, k=c)
    print(f"nDCG@{c}: {nDCG}")
nDCG@1: 0.99
nDCG@10: 0.9963092975357145
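For intuition about these numbers, nDCG with binary relevance can be computed by hand. The following is a sketch under two assumptions: gains are the raw 0/1 labels with a log2 position discount, and the ideal ranking is built only from the labels of the retrieved list, as in the `ndcg_score` calls above. `dcg` and `ndcg` are hypothetical helpers, not the scikit-learn implementation.

```python
import math

def dcg(gains, k):
    # discounted cumulative gain: the gain at 0-based position pos is divided by log2(pos + 2)
    return sum(g / math.log2(pos + 2) for pos, g in enumerate(gains[:k]))

def ndcg(relevance_in_ranked_order, k):
    # the ideal ordering places all relevant items first
    ideal = dcg(sorted(relevance_in_ranked_order, reverse=True), k)
    return dcg(relevance_in_ranked_order, k) / ideal if ideal > 0 else 0.0

top_hit = ndcg([1, 0, 0], k=3)       # positive at rank 1
second_hit = ndcg([0, 1, 0], k=3)    # positive at rank 2
missed_at_1 = ndcg([0, 1, 0], k=1)   # positive falls outside the cutoff
print(top_hit, second_hit, missed_at_1)
```

A single positive at rank 2 scores 1 / log2(3), roughly 0.63, so a query whose positive moves from rank 2 to rank 1 gains about 0.37 in nDCG.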