{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# BGE Explanation" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In this section, we will walk through the structure of BGE and BGE-v1.5 and how they generate embeddings." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 0. Installation" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Install the required packages in your environment." ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [], "source": [ "%%capture\n", "%pip install -U transformers FlagEmbedding" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 1. Encode sentences" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To see exactly how a sentence is encoded, let's first load the tokenizer and model with Hugging Face Transformers instead of FlagEmbedding." ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [], "source": [ "from transformers import AutoTokenizer, AutoModel\n", "import torch\n", "\n", "tokenizer = AutoTokenizer.from_pretrained(\"BAAI/bge-base-en-v1.5\")\n", "model = AutoModel.from_pretrained(\"BAAI/bge-base-en-v1.5\")\n", "\n", "sentences = [\"embedding\", \"I love machine learning and nlp\"]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Run the following cell to switch the model to evaluation mode (which disables dropout) and print the architecture of bge-base-en-v1.5. It has exactly the same structure as BERT-base: 12 encoder layers and a hidden dimension of 768.\n", "\n", "Note that corresponding BGE and BGE-v1.5 models share the same architecture. For example, bge-base-en and bge-base-en-v1.5 have the same structure."
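, "\n", "\n", "As a quick check on these sizes, the parameter count of BERT-base can be worked out from the module shapes printed below, with vocabulary size $V = 30522$, 512 position embeddings, 2 token types, hidden size $H = 768$, feed-forward size $4H = 3072$ and $L = 12$ encoder layers. A `Linear` layer with $n$ inputs and $m$ outputs has $nm + m$ parameters and a `LayerNorm` has $2H$. Each encoder layer therefore holds $4(H^2 + H)$ for the query, key, value and attention-output projections, $(4H^2 + 4H) + (4H^2 + H)$ for the feed-forward block and $2 \\cdot 2H$ for its two `LayerNorm`s, i.e. $12H^2 + 13H$ in total:\n", "\n", "$$\n", "N = \\underbrace{(V + 512 + 2)H + 2H}_{\\text{embeddings}} + L\\,\\underbrace{(12H^2 + 13H)}_{\\text{encoder layer}} + \\underbrace{H^2 + H}_{\\text{pooler}} = 23{,}837{,}184 + 12 \\times 7{,}087{,}872 + 590{,}592 = 109{,}482{,}240 \\approx 110\\text{M}\n", "$$\n", "\n", "Running `sum(p.numel() for p in model.parameters())` on the loaded model should report the same number."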
] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "BertModel(\n", " (embeddings): BertEmbeddings(\n", " (word_embeddings): Embedding(30522, 768, padding_idx=0)\n", " (position_embeddings): Embedding(512, 768)\n", " (token_type_embeddings): Embedding(2, 768)\n", " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " (encoder): BertEncoder(\n", " (layer): ModuleList(\n", " (0-11): 12 x BertLayer(\n", " (attention): BertAttention(\n", " (self): BertSelfAttention(\n", " (query): Linear(in_features=768, out_features=768, bias=True)\n", " (key): Linear(in_features=768, out_features=768, bias=True)\n", " (value): Linear(in_features=768, out_features=768, bias=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " (output): BertSelfOutput(\n", " (dense): Linear(in_features=768, out_features=768, bias=True)\n", " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " (intermediate): BertIntermediate(\n", " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", " (intermediate_act_fn): GELUActivation()\n", " )\n", " (output): BertOutput(\n", " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " )\n", " )\n", " (pooler): BertPooler(\n", " (dense): Linear(in_features=768, out_features=768, bias=True)\n", " (activation): Tanh()\n", " )\n", ")" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model.eval()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "First, let's tokenize the sentences." 
] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'input_ids': tensor([[ 101, 7861, 8270, 4667, 102, 0, 0, 0, 0],\n", " [ 101, 1045, 2293, 3698, 4083, 1998, 17953, 2361, 102]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0],\n", " [0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 0, 0, 0, 0],\n", " [1, 1, 1, 1, 1, 1, 1, 1, 1]])}" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "source": [ "inputs = tokenizer(\n", " sentences, \n", " padding=True, \n", " truncation=True, \n", " return_tensors='pt', \n", " max_length=512\n", ")\n", "inputs" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "From the output, we can see that each sentence begins with token 101 and ends with token 102: these are the `[CLS]` and `[SEP]` special tokens used by BERT. The shorter sentence is padded with token 0 (`[PAD]`) up to the length of the longer one, and its `attention_mask` is 0 at those padding positions." ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([2, 9, 768])" ] }, "execution_count": 22, "metadata": {}, "output_type": "execute_result" } ], "source": [ "last_hidden_state = model(**inputs, return_dict=True).last_hidden_state\n", "last_hidden_state.shape" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The last hidden state has shape `(batch_size, sequence_length, hidden_size)`, here `(2, 9, 768)`, and a pooling function reduces it to one vector per sentence. Here we implement the pooling function with two options: take the last hidden state of the `[CLS]` token, or average the last hidden states of all non-padding tokens (mean pooling)."
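] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To make the two options concrete before writing the PyTorch version, the next cell is a minimal NumPy sketch on hand-made values (hypothetical numbers, not real BGE outputs). The illustrative names `toy_hidden` and `toy_mask` are chosen so they do not overwrite `last_hidden_state` and `inputs` from the cells above. Position 0 stands in for `[CLS]` and the last position is padding: `[CLS]` pooling simply keeps the first vector, while mean pooling uses the mask to ignore the padding row, so its large values `[9, 9, 9]` do not affect the result `[2, 1, 1]`." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "\n", "# Hypothetical toy values (not real BGE outputs): 1 sentence, 4 token positions, hidden size 3.\n", "# Position 0 stands in for [CLS]; position 3 is padding.\n", "toy_hidden = np.array([[[1.0, 0.0, 2.0],\n", "                        [3.0, 1.0, 0.0],\n", "                        [2.0, 2.0, 1.0],\n", "                        [9.0, 9.0, 9.0]]])\n", "toy_mask = np.array([[1, 1, 1, 0]])\n", "\n", "# [CLS] pooling: keep the vector at position 0 of each sentence\n", "toy_cls = toy_hidden[:, 0]\n", "\n", "# mean pooling: zero out padding, sum over positions, divide by the number of real tokens\n", "masked_sum = (toy_hidden * toy_mask[..., None]).sum(axis=1)\n", "toy_mean = masked_sum / toy_mask.sum(axis=1, keepdims=True)\n", "\n", "print('CLS pooling: ', toy_cls)\n", "print('Mean pooling:', toy_mean)"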
] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [], "source": [ "def pooling(last_hidden_state: torch.Tensor, pooling_method='cls', attention_mask: torch.Tensor = None):\n", "    if pooling_method == 'cls':\n", "        # [CLS] pooling: the last hidden state of the first token in each sequence\n", "        return last_hidden_state[:, 0]\n", "    elif pooling_method == 'mean':\n", "        # mean pooling: average over real tokens only, using the attention mask to skip padding\n", "        s = torch.sum(last_hidden_state * attention_mask.unsqueeze(-1).float(), dim=1)\n", "        d = attention_mask.sum(dim=1, keepdim=True).float()\n", "        return s / d\n", "    else:\n", "        raise ValueError(f'Unknown pooling_method: {pooling_method}')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Unlike the more commonly used mean pooling, BGE is trained to use the last hidden state of the `[CLS]` token as the sentence embedding:\n", "\n", "`sentence_embeddings = model_output[0][:, 0]`\n", "\n", "Using mean pooling instead leads to a significant drop in performance, so make sure to use `[CLS]` pooling when computing sentence embeddings with BGE." ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([2, 768])" ] }, "execution_count": 24, "metadata": {}, "output_type": "execute_result" } ], "source": [ "embeddings = pooling(\n", " last_hidden_state, \n", " pooling_method='cls', \n", " attention_mask=inputs['attention_mask']\n", ")\n", "embeddings.shape" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Putting these steps together (tokenization, forward pass, `[CLS]` pooling and L2 normalization), we get the complete encoding function:" ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [], "source": [ "def _encode(sentences, max_length=512, convert_to_numpy=True):\n", "\n", "    # accept either a single sentence or a list of sentences\n", "    input_was_string = False\n", "    if isinstance(sentences, str):\n", "        sentences = [sentences]\n", "        input_was_string = True\n", "\n", "    inputs = tokenizer(\n", "        sentences,\n", "        padding=True,\n", "        truncation=True,\n", "        return_tensors='pt',\n", "        max_length=max_length\n", "    )\n", "\n", "    # inference only, so gradient tracking is not needed\n", "    with torch.no_grad():\n", "        last_hidden_state = model(**inputs, return_dict=True).last_hidden_state\n", "\n", "    embeddings = pooling(\n", "        last_hidden_state,\n", "        pooling_method='cls',\n", "        attention_mask=inputs['attention_mask']\n", "    )\n", "\n", "    # L2-normalize each embedding so that the dot product of two embeddings equals their cosine similarity\n", "    embeddings = torch.nn.functional.normalize(embeddings, dim=-1)\n", "\n", "    # convert to numpy if needed\n", "    if convert_to_numpy:\n", "        embeddings = embeddings.detach().numpy()\n", "\n", "    return embeddings[0] if input_was_string else embeddings" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 2. Comparison" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now let's run our function to get the embeddings of the two sentences. Since the embeddings are L2-normalized, their dot products are cosine similarities:" ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Embeddings:\n", "[[ 1.4549762e-02 -9.6840411e-03 3.7761475e-03 ... -8.5092714e-04\n", " 2.8417887e-02 6.3214332e-02]\n", " [ 3.3924331e-05 -3.2998275e-03 1.7206438e-02 ... 3.5703944e-03\n", " 1.8721525e-02 -2.0371782e-02]]\n", "Similarity scores:\n", "[[0.9999997 0.6077381]\n", " [0.6077381 0.9999999]]\n" ] } ], "source": [ "embeddings = _encode(sentences)\n", "print(f\"Embeddings:\\n{embeddings}\")\n", "\n", "scores = embeddings @ embeddings.T\n", "print(f\"Similarity scores:\\n{scores}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, compute the same embeddings with the `FlagModel` API from FlagEmbedding:" ] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Embeddings:\n", "[[ 1.4549762e-02 -9.6840411e-03 3.7761475e-03 ... -8.5092714e-04\n", " 2.8417887e-02 6.3214332e-02]\n", " [ 3.3924331e-05 -3.2998275e-03 1.7206438e-02 ... 
3.5703944e-03\n", " 1.8721525e-02 -2.0371782e-02]]\n", "Similarity scores:\n", "[[0.9999997 0.6077381]\n", " [0.6077381 0.9999999]]\n" ] } ], "source": [ "from FlagEmbedding import FlagModel\n", "\n", "# use a new name so the Hugging Face `model` loaded above is not overwritten\n", "flag_model = FlagModel('BAAI/bge-base-en-v1.5')\n", "\n", "embeddings = flag_model.encode(sentences)\n", "print(f\"Embeddings:\\n{embeddings}\")\n", "\n", "scores = embeddings @ embeddings.T\n", "print(f\"Similarity scores:\\n{scores}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As expected, both approaches produce the same embeddings and similarity scores. The full implementation in FlagEmbedding additionally encodes large datasets in batches and supports GPUs and parallelization. See the [source code](https://github.com/FlagOpen/FlagEmbedding/blob/master/FlagEmbedding/flag_models.py#L370) for more details." ] } ], "metadata": { "kernelspec": { "display_name": "base", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.13" } }, "nbformat": 4, "nbformat_minor": 2 }