{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# BGE Auto Embedder" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "FlagEmbedding provides a high-level class, `FlagAutoModel`, that unifies inference across embedding models. Besides the BGE series, it also supports other popular open-source embedding models such as E5, GTE, and SFR. This tutorial shows how to use it." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%pip install FlagEmbedding" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 1. Usage" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "First, import `FlagAutoModel` from FlagEmbedding, and use the `from_finetuned()` method to initialize the model:" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "from FlagEmbedding import FlagAutoModel\n", "\n", "model = FlagAutoModel.from_finetuned(\n", "    'BAAI/bge-base-en-v1.5',\n", "    query_instruction_for_retrieval=\"Represent this sentence for searching relevant passages: \",\n", "    devices=\"cuda:0\",   # if not specified, uses all available GPUs, or the CPU when no GPU is available\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Then use the model exactly as you would use `FlagModel` (or `FlagM3Model` for BGE M3, `FlagLLMModel` for BGE Multilingual Gemma2, `FlagICLModel` for BGE ICL):" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "You're using a BertTokenizerFast tokenizer. 
Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "[[0.76 0.6714]\n", " [0.6177 0.7603]]\n" ] } ], "source": [ "queries = [\"query 1\", \"query 2\"]\n", "corpus = [\"passage 1\", \"passage 2\"]\n", "\n", "# encode the queries and corpus\n", "q_embeddings = model.encode_queries(queries)\n", "p_embeddings = model.encode_corpus(corpus)\n", "\n", "# compute the similarity scores\n", "scores = q_embeddings @ p_embeddings.T\n", "print(scores)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 2. Explanation" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`FlagAutoModel` uses an OrderedDict, `AUTO_EMBEDDER_MAPPING`, to store the configurations of all supported models:" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "['bge-en-icl',\n", " 'bge-multilingual-gemma2',\n", " 'bge-m3',\n", " 'bge-large-en-v1.5',\n", " 'bge-base-en-v1.5',\n", " 'bge-small-en-v1.5',\n", " 'bge-large-zh-v1.5',\n", " 'bge-base-zh-v1.5',\n", " 'bge-small-zh-v1.5',\n", " 'bge-large-en',\n", " 'bge-base-en',\n", " 'bge-small-en',\n", " 'bge-large-zh',\n", " 'bge-base-zh',\n", " 'bge-small-zh',\n", " 'e5-mistral-7b-instruct',\n", " 'e5-large-v2',\n", " 'e5-base-v2',\n", " 'e5-small-v2',\n", " 'multilingual-e5-large-instruct',\n", " 'multilingual-e5-large',\n", " 'multilingual-e5-base',\n", " 'multilingual-e5-small',\n", " 'e5-large',\n", " 'e5-base',\n", " 'e5-small',\n", " 'gte-Qwen2-7B-instruct',\n", " 'gte-Qwen2-1.5B-instruct',\n", " 'gte-Qwen1.5-7B-instruct',\n", " 'gte-multilingual-base',\n", " 'gte-large-en-v1.5',\n", " 'gte-base-en-v1.5',\n", " 'gte-large',\n", " 'gte-base',\n", " 'gte-small',\n", " 'gte-large-zh',\n", " 'gte-base-zh',\n", " 'gte-small-zh',\n", " 'SFR-Embedding-2_R',\n", " 'SFR-Embedding-Mistral',\n", 
 " 'Linq-Embed-Mistral']" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from FlagEmbedding.inference.embedder.model_mapping import AUTO_EMBEDDER_MAPPING\n", "\n", "list(AUTO_EMBEDDER_MAPPING.keys())" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "EmbedderConfig(model_class=, pooling_method=, trust_remote_code=False, query_instruction_format='{}\\n{}')\n" ] } ], "source": [ "print(AUTO_EMBEDDER_MAPPING['bge-en-icl'])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Each key maps to an `EmbedderConfig` object, which consists of four attributes:" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "```python\n", "@dataclass\n", "class EmbedderConfig:\n", "    model_class: Type[AbsEmbedder]\n", "    pooling_method: PoolingMethod\n", "    trust_remote_code: bool = False\n", "    query_instruction_format: str = \"{}{}\"\n", "```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Beyond the BGE series, other models such as E5 are supported in the same way:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "print(AUTO_EMBEDDER_MAPPING['e5-mistral-7b-instruct'])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 3. Customization" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If you want to use your own model through `FlagAutoModel`, work through the following questions:\n", "\n", "1. What type of embedding model is it, an encoder or a decoder? Choose the model class accordingly.\n", "2. Which pooling method does it use: CLS token, mean pooling, or last token?\n", "3. Does your model need `trust_remote_code=True` to run?\n", "4. 
Is there a query instruction format for retrieval?\n", "\n", "Once these four attributes are settled, add your model name as the key and the corresponding `EmbedderConfig` as the value to `AUTO_EMBEDDER_MAPPING`. Now give it a try!" ] } ], "metadata": { "kernelspec": { "display_name": "dev", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.7" } }, "nbformat": 4, "nbformat_minor": 2 }