{ "cells": [ { "cell_type": "markdown", "id": "06cff9e4", "metadata": {}, "source": [ "# BGE Series" ] }, { "cell_type": "markdown", "id": "880e229d", "metadata": {}, "source": [ "In this part, we walk through the BGE series and show how to use the BGE embedding models." ] }, { "cell_type": "markdown", "id": "2516fd49", "metadata": {}, "source": [ "## 1. BAAI General Embedding" ] }, { "cell_type": "markdown", "id": "2113ee71", "metadata": {}, "source": [ "BGE stands for BAAI General Embedding. It is a series of embedding models developed and published by the Beijing Academy of Artificial Intelligence (BAAI)." ] }, { "cell_type": "markdown", "id": "16515b99", "metadata": {}, "source": [ "The APIs and usage examples for BGE are maintained in [FlagEmbedding](https://github.com/FlagOpen/FlagEmbedding) on GitHub.\n", "\n", "Run the following cell to install FlagEmbedding in your environment." ] }, { "cell_type": "code", "execution_count": null, "id": "88095fd0", "metadata": {}, "outputs": [], "source": [ "%%capture\n", "%pip install -U FlagEmbedding" ] }, { "cell_type": "code", "execution_count": 3, "id": "a2376217", "metadata": {}, "outputs": [], "source": [ "import os\n", "os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'\n", "# a single GPU is sufficient for small tasks\n", "os.environ['CUDA_VISIBLE_DEVICES'] = '0'" ] }, { "cell_type": "markdown", "id": "bc6e30a0", "metadata": {}, "source": [ "The BGE models are gathered in a [Hugging Face collection](https://huggingface.co/collections/BAAI/bge-66797a74476eb1f085c7446d)." ] }, { "cell_type": "markdown", "id": "67a16ccf", "metadata": {}, "source": [ "## 2. BGE Series Models" ] }, { "cell_type": "markdown", "id": "2e10034a", "metadata": {}, "source": [ "### 2.1 BGE" ] }, { "cell_type": "markdown", "id": "0cdc6702", "metadata": {}, "source": [ "The first version of BGE has 6 models: 'large', 'base', and 'small' sizes for each of English and Chinese. 
" ] }, { "cell_type": "markdown", "id": "04b75f72", "metadata": {}, "source": [ "| Model | Language | Parameters | Model Size | Description | Base Model |\n", "|:-------|:--------:|:--------------:|:--------------:|:-----------------:|:----------------:|\n", "| [BAAI/bge-large-en](https://huggingface.co/BAAI/bge-large-en) | English | 500M | 1.34 GB | Embedding Model which map text into vector | BERT |\n", "| [BAAI/bge-base-en](https://huggingface.co/BAAI/bge-base-en) | English | 109M | 438 MB | a base-scale model but with similar ability to `bge-large-en` | BERT |\n", "| [BAAI/bge-small-en](https://huggingface.co/BAAI/bge-small-en) | English | 33.4M | 133 MB | a small-scale model but with competitive performance | BERT |\n", "| [BAAI/bge-large-zh](https://huggingface.co/BAAI/bge-large-zh) | Chinese | 326M | 1.3 GB | Embedding Model which map text into vector | BERT |\n", "| [BAAI/bge-base-zh](https://huggingface.co/BAAI/bge-base-zh) | Chinese | 102M | 409 MB | a base-scale model but with similar ability to `bge-large-zh` | BERT |\n", "| [BAAI/bge-small-zh](https://huggingface.co/BAAI/bge-small-zh) | Chinese | 24M | 95.8 MB | a small-scale model but with competitive performance | BERT |" ] }, { "cell_type": "markdown", "id": "c9c45d17", "metadata": {}, "source": [ "For inference, simply import FlagModel from FlagEmbedding and initialize the model." 
] }, { "cell_type": "code", "execution_count": null, "id": "89e07751", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[[0.84864 0.7946737 ]\n", " [0.760097 0.85449743]]\n" ] } ], "source": [ "from FlagEmbedding import FlagModel\n", "\n", "# Load BGE model\n", "model = FlagModel(\n", " 'BAAI/bge-base-en',\n", " query_instruction_for_retrieval=\"Represent this sentence for searching relevant passages:\",\n", " query_instruction_format='{}{}',\n", ")\n", "\n", "queries = [\"query 1\", \"query 2\"]\n", "corpus = [\"passage 1\", \"passage 2\"]\n", "\n", "# encode the queries and corpus\n", "q_embeddings = model.encode_queries(queries)\n", "p_embeddings = model.encode_corpus(corpus)\n", "\n", "# compute the similarity scores\n", "scores = q_embeddings @ p_embeddings.T\n", "print(scores)" ] }, { "cell_type": "markdown", "id": "6c8e69ed", "metadata": {}, "source": [ "For general encoding, use either `encode()`:\n", "```python\n", "FlagModel.encode(sentences, batch_size=256, max_length=512, convert_to_numpy=True)\n", "```\n", "or `encode_corpus()`, which directly calls `encode()`:\n", "```python\n", "FlagModel.encode_corpus(corpus, batch_size=256, max_length=512, convert_to_numpy=True)\n", "```\n", "The `encode_queries()` function prepends `query_instruction_for_retrieval` to each input query to form new sentences, then feeds them to `encode()`.\n", "```python\n", "FlagModel.encode_queries(queries, batch_size=256, max_length=512, convert_to_numpy=True)\n", "```" ] }, { "cell_type": "markdown", "id": "2c86a5a3", "metadata": {}, "source": [ "### 2.2 BGE v1.5" ] }, { "cell_type": "markdown", "id": "454ff7aa", "metadata": {}, "source": [ "BGE 1.5 alleviates the issue with the similarity distribution and enhances retrieval ability without instruction." 
] }, { "cell_type": "markdown", "id": "30b1f897", "metadata": {}, "source": [ "| Model | Language | Parameters | Model Size | Description | Base Model |\n", "|:-------|:--------:|:--------------:|:--------------:|:-----------------:|:----------------:|\n", "| [BAAI/bge-large-en-v1.5](https://huggingface.co/BAAI/bge-large-en-v1.5) | English | 335M | 1.34 GB | version 1.5 with more reasonable similarity distribution | BERT |\n", "| [BAAI/bge-base-en-v1.5](https://huggingface.co/BAAI/bge-base-en-v1.5) | English | 109M | 438 MB | version 1.5 with more reasonable similarity distribution | BERT |\n", "| [BAAI/bge-small-en-v1.5](https://huggingface.co/BAAI/bge-small-en-v1.5) | English | 33.4M | 133 MB | version 1.5 with more reasonable similarity distribution | BERT |\n", "| [BAAI/bge-large-zh-v1.5](https://huggingface.co/BAAI/bge-large-zh-v1.5) | Chinese | 326M | 1.3 GB | version 1.5 with more reasonable similarity distribution | BERT |\n", "| [BAAI/bge-base-zh-v1.5](https://huggingface.co/BAAI/bge-base-zh-v1.5) | Chinese | 102M | 409 MB | version 1.5 with more reasonable similarity distribution | BERT |\n", "| [BAAI/bge-small-zh-v1.5](https://huggingface.co/BAAI/bge-small-zh-v1.5) | Chinese | 24M | 95.8 MB | version 1.5 with more reasonable similarity distribution | BERT |" ] }, { "cell_type": "markdown", "id": "ed00c504", "metadata": {}, "source": [ "You can use BGE 1.5 models in exactly the same way as BGE v1 models." 
] }, { "cell_type": "code", "execution_count": 3, "id": "9b17afcc", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "pre tokenize: 100%|██████████| 1/1 [00:00<00:00, 2252.58it/s]\n", "pre tokenize: 100%|██████████| 1/1 [00:00<00:00, 3575.71it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ "[[0.76 0.6714]\n", " [0.6177 0.7603]]\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\n" ] } ], "source": [ "model = FlagModel(\n", " 'BAAI/bge-base-en-v1.5',\n", " query_instruction_for_retrieval=\"Represent this sentence for searching relevant passages:\",\n", " query_instruction_format='{}{}'\n", ")\n", "\n", "queries = [\"query 1\", \"query 2\"]\n", "corpus = [\"passage 1\", \"passage 2\"]\n", "\n", "# encode the queries and corpus\n", "q_embeddings = model.encode_queries(queries)\n", "p_embeddings = model.encode_corpus(corpus)\n", "\n", "# compute the similarity scores\n", "scores = q_embeddings @ p_embeddings.T\n", "print(scores)" ] }, { "cell_type": "markdown", "id": "dcf2a82b", "metadata": {}, "source": [ "### 2.3 BGE M3" ] }, { "cell_type": "markdown", "id": "cc5b5a5e", "metadata": {}, "source": [ "BGE-M3 is a newer BGE model distinguished by its versatility in:\n", "- Multi-Functionality: Simultaneously performs the three common retrieval functionalities of an embedding model: dense retrieval, multi-vector retrieval, and sparse retrieval.\n", "- Multi-Linguality: Supports more than 100 working languages.\n", "- Multi-Granularity: Can process inputs of different granularities, from short sentences to long documents of up to 8192 tokens.\n", "\n", "For more details, see the [paper](https://arxiv.org/pdf/2402.03216)." 
] }, { "cell_type": "markdown", "id": "41348e03", "metadata": {}, "source": [ "| Model | Language | Parameters | Model Size | Description | Base Model |\n", "|:-------|:--------:|:--------------:|:--------------:|:-----------------:|:----------------:|\n", "| [BAAI/bge-m3](https://huggingface.co/BAAI/bge-m3) | Multilingual | 568M | 2.27 GB | Multi-Functionality (dense retrieval, sparse retrieval, multi-vector (ColBERT)), Multi-Linguality, and Multi-Granularity (8192 tokens) | XLM-RoBERTa |" ] }, { "cell_type": "code", "execution_count": 4, "id": "d4647625", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Fetching 30 files: 100%|██████████| 30/30 [00:00<00:00, 194180.74it/s]\n" ] } ], "source": [ "from FlagEmbedding import BGEM3FlagModel\n", "\n", "model = BGEM3FlagModel('BAAI/bge-m3', use_fp16=True)\n", "\n", "sentences = [\"What is BGE M3?\", \"Defination of BM25\"]" ] }, { "cell_type": "markdown", "id": "1f89f1a9", "metadata": {}, "source": [ "```python\n", "BGEM3FlagModel.encode(\n", " sentences, \n", " batch_size=12, \n", " max_length=8192, \n", " return_dense=True, \n", " return_sparse=False, \n", " return_colbert_vecs=False\n", ")\n", "```\n", "It returns a dictionary like:\n", "```python\n", "{\n", " 'dense_vecs': # array of dense embeddings of the inputs if return_dense=True, otherwise None,\n", " 'lexical_weights': # array of dictionaries mapping token ids to their weights if return_sparse=True, otherwise None,\n", " 'colbert_vecs': # array of multi-vector embeddings of the inputs if return_colbert_vecs=True, otherwise None,\n", "}\n", "```" ] }, { "cell_type": "code", "execution_count": 5, "id": "f0b11cf0", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "pre tokenize: 100%|██████████| 1/1 [00:00<00:00, 1148.18it/s]\n" ] } ], "source": [ "# If you don't need such a long length of 8192 input tokens, you can set max_length to a smaller value to speed up 
encoding.\n", "embeddings = model.encode(\n", " sentences, \n", " max_length=10,\n", " return_dense=True, \n", " return_sparse=True, \n", " return_colbert_vecs=True\n", ")" ] }, { "cell_type": "code", "execution_count": 6, "id": "72cba126", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "dense embedding:\n", "[[-0.03412 -0.04706 -0.00087 ... 0.04822 0.007614 -0.02957 ]\n", " [-0.01035 -0.04483 -0.02434 ... -0.008224 0.01497 0.011055]]\n", "sparse embedding:\n", "[defaultdict(<class 'int'>, {'4865': np.float16(0.0836), '83': np.float16(0.0814), '335': np.float16(0.1296), '11679': np.float16(0.2517), '276': np.float16(0.1699), '363': np.float16(0.2695), '32': np.float16(0.04077)}), defaultdict(<class 'int'>, {'262': np.float16(0.05014), '5983': np.float16(0.1367), '2320': np.float16(0.04517), '111': np.float16(0.0634), '90017': np.float16(0.2517), '2588': np.float16(0.3333)})]\n", "multi-vector:\n", "[array([[-8.68966337e-03, -4.89266850e-02, -3.03634931e-03, ...,\n", " -2.21243706e-02, 5.72856329e-02, 1.28355855e-02],\n", " [-8.92937183e-03, -4.67235669e-02, -9.52814799e-03, ...,\n", " -3.14785317e-02, 5.39088845e-02, 6.96671568e-03],\n", " [ 1.84195358e-02, -4.22310382e-02, 8.55499704e-04, ...,\n", " -1.97946690e-02, 3.84313315e-02, 7.71250250e-03],\n", " ...,\n", " [-2.55824160e-02, -1.65533274e-02, -4.21357416e-02, ...,\n", " -4.50234264e-02, 4.41286489e-02, -1.00052059e-02],\n", " [ 5.90990965e-07, -5.53734899e-02, 8.51499755e-03, ...,\n", " -2.29209941e-02, 6.04418293e-02, 9.39912070e-03],\n", " [ 2.57394509e-03, -2.92690992e-02, -1.89342294e-02, ...,\n", " -8.04431178e-03, 3.28964666e-02, 4.38723788e-02]], dtype=float32), array([[ 0.01724418, 0.03835401, -0.02309308, ..., 0.00141706,\n", " 0.02995041, -0.05990082],\n", " [ 0.00996325, 0.03922409, -0.03849588, ..., 0.00591671,\n", " 0.02722516, -0.06510868],\n", " [ 0.01781915, 0.03925728, -0.01710397, ..., 0.00801776,\n", " 0.03987768, -0.05070014],\n", " ...,\n", " [ 0.05478653, 0.00755799, 
0.00328444, ..., -0.01648209,\n", " 0.02405782, 0.00363262],\n", " [ 0.00936953, 0.05028074, -0.02388872, ..., 0.02567679,\n", " 0.00791224, -0.03257877],\n", " [ 0.01803976, 0.0133922 , 0.00019365, ..., 0.0184015 ,\n", " 0.01373822, 0.00315539]], dtype=float32)]\n" ] } ], "source": [ "print(f\"dense embedding:\\n{embeddings['dense_vecs']}\")\n", "print(f\"sparse embedding:\\n{embeddings['lexical_weights']}\")\n", "print(f\"multi-vector:\\n{embeddings['colbert_vecs']}\")" ] }, { "cell_type": "markdown", "id": "14d83caa", "metadata": {}, "source": [ "### 2.4 BGE Multilingual Gemma2" ] }, { "cell_type": "markdown", "id": "fd4c67df", "metadata": {}, "source": [ "BGE Multilingual Gemma2 is an LLM-based multilingual embedding model." ] }, { "cell_type": "markdown", "id": "abdca22e", "metadata": {}, "source": [ "| Model | Language | Parameters | Model Size | Description | Base Model |\n", "|:-------|:--------:|:--------------:|:--------------:|:-----------------:|:----------------:|\n", "| [BAAI/bge-multilingual-gemma2](https://huggingface.co/BAAI/bge-multilingual-gemma2) | Multilingual | 9.24B | 37 GB | LLM-based multilingual embedding model with SOTA results on multilingual benchmarks | Gemma2-9B |" ] }, { "cell_type": "code", "execution_count": 7, "id": "8ec545bc", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Loading checkpoint shards: 100%|██████████| 4/4 [00:00<00:00, 6.34it/s]\n", "pre tokenize: 100%|██████████| 1/1 [00:00<00:00, 816.49it/s]\n", "pre tokenize: 100%|██████████| 1/1 [00:00<00:00, 718.33it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "[[0.559 0.01685 ]\n", " [0.0008683 0.5015 ]]\n" ] } ], "source": [ "from FlagEmbedding import FlagLLMModel\n", "\n", "queries = [\"how much protein should a female eat\", \"summit define\"]\n", "documents = [\n", " \"As a general guideline, the CDC's average requirement of protein for women ages 19 to 70 is 46 grams per day. 
But, as you can see from this chart, you'll need to increase that if you're expecting or training for a marathon. Check out the chart below to see how much protein you should be eating each day.\",\n", " \"Definition of summit for English Language Learners. : 1 the highest point of a mountain : the top of a mountain. : 2 the highest level. : 3 a meeting or series of meetings between the leaders of two or more governments.\"\n", "]\n", "\n", "model = FlagLLMModel('BAAI/bge-multilingual-gemma2', \n", " query_instruction_for_retrieval=\"Given a web search query, retrieve relevant passages that answer the query.\",\n", " use_fp16=True) # Setting use_fp16 to True speeds up computation with a slight performance degradation\n", "\n", "embeddings_1 = model.encode_queries(queries)\n", "embeddings_2 = model.encode_corpus(documents)\n", "similarity = embeddings_1 @ embeddings_2.T\n", "print(similarity)\n" ] }, { "cell_type": "markdown", "id": "8b7b2aa4", "metadata": {}, "source": [ "### 2.5 BGE ICL" ] }, { "cell_type": "markdown", "id": "7c9acb92", "metadata": {}, "source": [ "ICL stands for in-context learning. By providing few-shot examples in the query, BGE ICL can significantly enhance the model's ability to handle new tasks." ] }, { "cell_type": "markdown", "id": "cf6c9345", "metadata": {}, "source": [ "| Model | Language | Parameters | Model Size | Description | Base Model |\n", "|:-------|:--------:|:--------------:|:--------------:|:-----------------:|:----------------:|\n", "| [BAAI/bge-en-icl](https://huggingface.co/BAAI/bge-en-icl) | English | 7.11B | 28.5 GB | LLM-based English embedding model with excellent in-context learning ability. | Mistral-7B |" ] }, { "cell_type": "code", "execution_count": 8, "id": "4595bae7", "metadata": {}, "outputs": [], "source": [ "documents = [\n", " \"As a general guideline, the CDC's average requirement of protein for women ages 19 to 70 is 46 grams per day. 
But, as you can see from this chart, you'll need to increase that if you're expecting or training for a marathon. Check out the chart below to see how much protein you should be eating each day.\",\n", " \"Definition of summit for English Language Learners. : 1 the highest point of a mountain : the top of a mountain. : 2 the highest level. : 3 a meeting or series of meetings between the leaders of two or more governments.\"\n", "]\n", "\n", "examples = [\n", " {\n", " 'instruct': 'Given a web search query, retrieve relevant passages that answer the query.',\n", " 'query': 'what is a virtual interface',\n", " 'response': \"A virtual interface is a software-defined abstraction that mimics the behavior and characteristics of a physical network interface. It allows multiple logical network connections to share the same physical network interface, enabling efficient utilization of network resources. Virtual interfaces are commonly used in virtualization technologies such as virtual machines and containers to provide network connectivity without requiring dedicated hardware. They facilitate flexible network configurations and help in isolating network traffic for security and management purposes.\"\n", " },\n", " {\n", " 'instruct': 'Given a web search query, retrieve relevant passages that answer the query.',\n", " 'query': 'causes of back pain in female for a week',\n", " 'response': \"Back pain in females lasting a week can stem from various factors. Common causes include muscle strain due to lifting heavy objects or improper posture, spinal issues like herniated discs or osteoporosis, menstrual cramps causing referred pain, urinary tract infections, or pelvic inflammatory disease. Pregnancy-related changes can also contribute. Stress and lack of physical activity may exacerbate symptoms. 
Proper diagnosis by a healthcare professional is crucial for effective treatment and management.\"\n", " }\n", "]\n", "\n", "queries = [\"how much protein should a female eat\", \"summit define\"]" ] }, { "cell_type": "code", "execution_count": null, "id": "ffb586c6", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00, 6.55it/s]\n", "pre tokenize: 100%|██████████| 1/1 [00:00<00:00, 366.09it/s]\n", "pre tokenize: 100%|██████████| 1/1 [00:00<00:00, 623.69it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "[[0.6064 0.3018]\n", " [0.257 0.537 ]]\n" ] } ], "source": [ "from FlagEmbedding import FlagICLModel\n", "import os\n", "\n", "model = FlagICLModel('BAAI/bge-en-icl', \n", " examples_for_task=examples, # set `examples_for_task=None` to use model without examples\n", " # examples_instruction_format=\"{}\\n{}\\n{}\" # specify the format to use examples_for_task\n", " )\n", "\n", "embeddings_1 = model.encode_queries(queries)\n", "embeddings_2 = model.encode_corpus(documents)\n", "similarity = embeddings_1 @ embeddings_2.T\n", "\n", "print(similarity)" ] } ], "metadata": { "kernelspec": { "display_name": "dev", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.7" } }, "nbformat": 4, "nbformat_minor": 5 }