BGE Explanation#
In this section, we will walk through the structure of BGE and BGE-v1.5 and how they generate embeddings.
0. Installation#
Install the required packages in your environment.
%%capture
%pip install -U transformers FlagEmbedding
1. Encode sentences#
To understand exactly how a sentence is encoded, let's first load the tokenizer and model from HF Transformers instead of FlagEmbedding:
from transformers import AutoTokenizer, AutoModel
import torch
tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-base-en-v1.5")
model = AutoModel.from_pretrained("BAAI/bge-base-en-v1.5")
sentences = ["embedding", "I love machine learning and nlp"]
Run the following cell to check the structure of bge-base-en-v1.5. It has exactly the same architecture as BERT-base: 12 encoder layers and a hidden dimension of 768.
Note that the corresponding BGE and BGE-v1.5 models share the same structure. For example, bge-base-en and bge-base-en-v1.5 have the same architecture.
model.eval()
BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(30522, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0-11): 12 x BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
        (intermediate): BertIntermediate(
          (dense): Linear(in_features=768, out_features=3072, bias=True)
          (intermediate_act_fn): GELUActivation()
        )
        (output): BertOutput(
          (dense): Linear(in_features=3072, out_features=768, bias=True)
          (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
  )
  (pooler): BertPooler(
    (dense): Linear(in_features=768, out_features=768, bias=True)
    (activation): Tanh()
  )
)
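Beyond reading the printed module tree, the two numbers above can also be checked programmatically through the standard config attributes (a quick sanity check):
# The config exposes the architecture hyperparameters directly.
print(model.config.num_hidden_layers)  # 12 encoder layers
print(model.config.hidden_size)        # hidden dimension of 768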
First, let’s tokenize the sentences.
inputs = tokenizer(
    sentences,
    padding=True,
    truncation=True,
    return_tensors='pt',
    max_length=512
)
inputs
{'input_ids': tensor([[ 101, 7861, 8270, 4667, 102, 0, 0, 0, 0],
[ 101, 1045, 2293, 3698, 4083, 1998, 17953, 2361, 102]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 1, 1, 1, 1]])}
From the results, we can see that each sentence begins with token 101 and ends with token 102. These are the [CLS] and [SEP] special tokens used in BERT.
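To see this concretely, we can map the IDs back to tokens (a small illustrative check using the tokenizer's convert_ids_to_tokens method):
# Map the first sentence's IDs back to tokens: the sequence should start
# with [CLS], end with [SEP], and be padded with [PAD].
print(tokenizer.convert_ids_to_tokens(inputs['input_ids'][0].tolist()))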
last_hidden_state = model(**inputs, return_dict=True).last_hidden_state
last_hidden_state.shape
torch.Size([2, 9, 768])
Here we implement a pooling function with two options: using the last hidden state of [CLS], or mean pooling over the whole last hidden state.
def pooling(last_hidden_state: torch.Tensor, pooling_method='cls', attention_mask: torch.Tensor = None):
    if pooling_method == 'cls':
        # use the last hidden state of the [CLS] token as the sentence embedding
        return last_hidden_state[:, 0]
    elif pooling_method == 'mean':
        # average the last hidden states of the non-padding tokens
        s = torch.sum(last_hidden_state * attention_mask.unsqueeze(-1).float(), dim=1)
        d = attention_mask.sum(dim=1, keepdim=True).float()
        return s / d
Different from the more commonly used mean pooling, BGE is trained to use the last hidden state of [CLS] as the sentence embedding:
sentence_embeddings = model_output[0][:, 0]
If you use mean pooling, there will be a significant decrease in performance. Therefore, make sure to use the correct method to obtain sentence vectors.
embeddings = pooling(
    last_hidden_state,
    pooling_method='cls',
    attention_mask=inputs['attention_mask']
)
embeddings.shape
torch.Size([2, 768])
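For comparison, the same function can produce mean-pooled embeddings. This is shown only to illustrate the difference; remember that BGE models are trained with [CLS] pooling, so these vectors should not be used in practice.
# Mean pooling over the same hidden states, for illustration only.
mean_embeddings = pooling(
    last_hidden_state,
    pooling_method='mean',
    attention_mask=inputs['attention_mask']
)
mean_embeddings.shape  # also torch.Size([2, 768]), but the vectors differ from the [CLS] ones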
Putting everything together, we get the whole encoding function:
def _encode(sentences, max_length=512, convert_to_numpy=True):
    # handle the cases of a single sentence and a list of sentences
    input_was_string = False
    if isinstance(sentences, str):
        sentences = [sentences]
        input_was_string = True

    inputs = tokenizer(
        sentences,
        padding=True,
        truncation=True,
        return_tensors='pt',
        max_length=max_length
    )

    last_hidden_state = model(**inputs, return_dict=True).last_hidden_state

    embeddings = pooling(
        last_hidden_state,
        pooling_method='cls',
        attention_mask=inputs['attention_mask']
    )

    # normalize the embedding vectors
    embeddings = torch.nn.functional.normalize(embeddings, dim=-1)

    # convert to numpy if needed
    if convert_to_numpy:
        embeddings = embeddings.detach().numpy()

    return embeddings[0] if input_was_string else embeddings
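Because the embeddings are L2-normalized, the inner product of two embeddings equals their cosine similarity, which is why the comparison below uses a plain matrix product. As a quick illustrative check:
# Every normalized embedding should have unit L2 norm (~1.0).
emb = _encode(sentences, convert_to_numpy=False)
print(torch.linalg.norm(emb, dim=-1))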
2. Comparison#
Now let’s run the function we wrote to get the embeddings of the two sentences:
embeddings = _encode(sentences)
print(f"Embeddings:\n{embeddings}")
scores = embeddings @ embeddings.T
print(f"Similarity scores:\n{scores}")
Embeddings:
[[ 1.4549762e-02 -9.6840411e-03 3.7761475e-03 ... -8.5092714e-04
2.8417887e-02 6.3214332e-02]
[ 3.3924331e-05 -3.2998275e-03 1.7206438e-02 ... 3.5703944e-03
1.8721525e-02 -2.0371782e-02]]
Similarity scores:
[[0.9999997 0.6077381]
[0.6077381 0.9999999]]
Then, run the same encoding through the API provided by FlagEmbedding:
from FlagEmbedding import FlagModel
model = FlagModel('BAAI/bge-base-en-v1.5')
embeddings = model.encode(sentences)
print(f"Embeddings:\n{embeddings}")
scores = embeddings @ embeddings.T
print(f"Similarity scores:\n{scores}")
Embeddings:
[[ 1.4549762e-02 -9.6840411e-03 3.7761475e-03 ... -8.5092714e-04
2.8417887e-02 6.3214332e-02]
[ 3.3924331e-05 -3.2998275e-03 1.7206438e-02 ... 3.5703944e-03
1.8721525e-02 -2.0371782e-02]]
Similarity scores:
[[0.9999997 0.6077381]
[0.6077381 0.9999999]]
As expected, the two encoding functions return exactly the same results. The full implementation in FlagEmbedding additionally handles large datasets by batching and includes GPU support and parallelization. Feel free to check the source code for more details.
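As a rough illustration of how such batching could look (a minimal sketch using a hypothetical _encode_batched helper, not FlagEmbedding's actual implementation), one could split the input into chunks and concatenate the results:
import numpy as np

def _encode_batched(sentences, batch_size=256, max_length=512):
    # Encode in fixed-size chunks to keep memory bounded, then concatenate.
    all_embeddings = []
    for start in range(0, len(sentences), batch_size):
        batch = sentences[start:start + batch_size]
        all_embeddings.append(_encode(batch, max_length=max_length))
    return np.concatenate(all_embeddings, axis=0)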