BGE Explanation#

In this section, we will walk through the structure of BGE and BGE-v1.5 and how they generate embeddings.

0. Installation#

Install the required packages in your environment.

%%capture
%pip install -U transformers FlagEmbedding

1. Encode sentences#

To see exactly how a sentence is encoded, let's first load the tokenizer and model from HF Transformers instead of FlagEmbedding:

from transformers import AutoTokenizer, AutoModel
import torch

tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-base-en-v1.5")
model = AutoModel.from_pretrained("BAAI/bge-base-en-v1.5")

sentences = ["embedding", "I love machine learning and nlp"]

Run the following cell to inspect the architecture of bge-base-en-v1.5. It has exactly the same structure as BERT-base: 12 encoder layers and a hidden dimension of 768.

Note that corresponding BGE and BGE-v1.5 models share the same architecture; for example, bge-base-en and bge-base-en-v1.5 are structurally identical.

model.eval()
BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(30522, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0-11): 12 x BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
        (intermediate): BertIntermediate(
          (dense): Linear(in_features=768, out_features=3072, bias=True)
          (intermediate_act_fn): GELUActivation()
        )
        (output): BertOutput(
          (dense): Linear(in_features=3072, out_features=768, bias=True)
          (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
  )
  (pooler): BertPooler(
    (dense): Linear(in_features=768, out_features=768, bias=True)
    (activation): Tanh()
  )
)
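
As noted above, BGE and BGE-v1.5 checkpoints of the same size share this architecture. If you want to confirm it yourself, here is a minimal sketch (assuming the bge-base-en checkpoint is also available from the Hugging Face Hub) that compares the two configs:

from transformers import AutoConfig

# Both configs should report 12 hidden layers and a hidden size of 768.
config_v1 = AutoConfig.from_pretrained("BAAI/bge-base-en")
config_v15 = AutoConfig.from_pretrained("BAAI/bge-base-en-v1.5")

print(config_v1.num_hidden_layers, config_v1.hidden_size)
print(config_v15.num_hidden_layers, config_v15.hidden_size)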

First, let’s tokenize the sentences.

inputs = tokenizer(
    sentences, 
    padding=True, 
    truncation=True, 
    return_tensors='pt', 
    max_length=512
)
inputs
{'input_ids': tensor([[  101,  7861,  8270,  4667,   102,     0,     0,     0,     0],
        [  101,  1045,  2293,  3698,  4083,  1998, 17953,  2361,   102]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1]])}

From the results, we can see that each sentence begins with token 101 and ends with token 102: these are BERT's [CLS] and [SEP] special tokens. The shorter sentence is padded with 0s ([PAD]) so that both sequences in the batch have the same length, which is also reflected in the attention mask.
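
To double-check, you can map the IDs back to their string tokens with the tokenizer loaded above (a quick sketch; the exact subword split depends on the vocabulary):

# 101 -> [CLS], 102 -> [SEP], and the trailing 0s -> [PAD]
print(tokenizer.convert_ids_to_tokens(inputs['input_ids'][0].tolist()))

Next, feed the tokenized inputs to the model and take the last hidden state of every token; the output has shape (batch_size, sequence_length, hidden_size).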

last_hidden_state = model(**inputs, return_dict=True).last_hidden_state
last_hidden_state.shape
torch.Size([2, 9, 768])

Here we implement the pooling function with two choices: using the last hidden state of the [CLS] token, or mean pooling over the whole last hidden state.

def pooling(last_hidden_state: torch.Tensor, pooling_method='cls', attention_mask: torch.Tensor = None):
    if pooling_method == 'cls':
        # use the last hidden state of the [CLS] token (position 0)
        return last_hidden_state[:, 0]
    elif pooling_method == 'mean':
        # average the last hidden states of non-padding tokens
        s = torch.sum(last_hidden_state * attention_mask.unsqueeze(-1).float(), dim=1)
        d = attention_mask.sum(dim=1, keepdim=True).float()
        return s / d

Unlike the more commonly used mean pooling, BGE is trained to use the last hidden state of [CLS] as the sentence embedding:

sentence_embeddings = model_output[0][:, 0]

If you use mean pooling, there will be a significant decrease in performance. Therefore, make sure to use the correct method to obtain sentence vectors.

embeddings = pooling(
    last_hidden_state, 
    pooling_method='cls', 
    attention_mask=inputs['attention_mask']
)
embeddings.shape
torch.Size([2, 768])
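
For illustration, you can also compute mean-pooled embeddings with the same function; this only demonstrates the alternative branch, since BGE should always be used with [CLS] pooling:

# mean pooling over non-padding tokens (for comparison only;
# BGE was trained with [CLS] pooling, so use 'cls' in practice)
mean_embeddings = pooling(
    last_hidden_state, 
    pooling_method='mean', 
    attention_mask=inputs['attention_mask']
)
mean_embeddings.shape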

Putting these pieces together, we get the full encoding function:

def _encode(sentences, max_length=512, convert_to_numpy=True):

    # handle the case of single sentence and a list of sentences
    input_was_string = False
    if isinstance(sentences, str):
        sentences = [sentences]
        input_was_string = True

    inputs = tokenizer(
        sentences, 
        padding=True, 
        truncation=True, 
        return_tensors='pt', 
        max_length=max_length
    )

    last_hidden_state = model(**inputs, return_dict=True).last_hidden_state
    
    embeddings = pooling(
        last_hidden_state, 
        pooling_method='cls', 
        attention_mask=inputs['attention_mask']
    )

    # normalize the embedding vectors
    embeddings = torch.nn.functional.normalize(embeddings, dim=-1)

    # convert to numpy if needed
    if convert_to_numpy:
        embeddings = embeddings.detach().numpy()

    return embeddings[0] if input_was_string else embeddings

2. Comparison#

Now let's run the function we wrote to get the embeddings of the two sentences. Since _encode L2-normalizes its outputs, the inner product of two embedding vectors is directly their cosine similarity:

embeddings = _encode(sentences)
print(f"Embeddings:\n{embeddings}")

scores = embeddings @ embeddings.T
print(f"Similarity scores:\n{scores}")
Embeddings:
[[ 1.4549762e-02 -9.6840411e-03  3.7761475e-03 ... -8.5092714e-04
   2.8417887e-02  6.3214332e-02]
 [ 3.3924331e-05 -3.2998275e-03  1.7206438e-02 ...  3.5703944e-03
   1.8721525e-02 -2.0371782e-02]]
Similarity scores:
[[0.9999997 0.6077381]
 [0.6077381 0.9999999]]

Then, run the API provided in FlagEmbedding:

from FlagEmbedding import FlagModel

model = FlagModel('BAAI/bge-base-en-v1.5')

embeddings = model.encode(sentences)
print(f"Embeddings:\n{embeddings}")

scores = embeddings @ embeddings.T
print(f"Similarity scores:\n{scores}")
Embeddings:
[[ 1.4549762e-02 -9.6840411e-03  3.7761475e-03 ... -8.5092714e-04
   2.8417887e-02  6.3214332e-02]
 [ 3.3924331e-05 -3.2998275e-03  1.7206438e-02 ...  3.5703944e-03
   1.8721525e-02 -2.0371782e-02]]
Similarity scores:
[[0.9999997 0.6077381]
 [0.6077381 0.9999999]]

As expected, the two encoding functions return exactly the same results. The full implementation in FlagEmbedding additionally handles large datasets by batching and provides GPU support and parallelization. Feel free to check the source code for more details.
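
For example, here is a sketch of the batched usage (parameter names such as use_fp16, batch_size, and max_length follow the FlagEmbedding documentation, but may differ slightly across versions, so check the source of FlagModel for your installed release):

from FlagEmbedding import FlagModel

# use_fp16 speeds up inference on GPU at a small cost in precision
model = FlagModel('BAAI/bge-base-en-v1.5', use_fp16=True)

# encode() internally splits large inputs into batches
embeddings = model.encode(
    sentences,
    batch_size=256,
    max_length=512,
)
print(embeddings.shape)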