|
| 1 | +{ |
| 2 | + "cells": [ |
| 3 | + { |
| 4 | + "cell_type": "markdown", |
| 5 | + "metadata": { |
| 6 | + "id": "ewKGZW06kmIv" |
| 7 | + }, |
| 8 | + "source": [ |
| 9 | + "# Example of Embedding\n", |
| 10 | + "\n", |
| 11 | + "This example generates text embeddings with the Gemini API, then stores and queries them in TiDB using the `tidb_vector` Python library." |
| 12 | + ] |
| 13 | + }, |
| 14 | + { |
| 15 | + "cell_type": "markdown", |
| 16 | + "metadata": { |
| 17 | + "id": "F1fsS576izUl" |
| 18 | + }, |
| 19 | + "source": [ |
| 20 | + "## Install Dependencies" |
| 21 | + ] |
| 22 | + }, |
| 23 | + { |
| 24 | + "cell_type": "code", |
| 25 | + "execution_count": null, |
| 26 | + "metadata": { |
| 27 | + "id": "pTpKX_lDizUp" |
| 28 | + }, |
| 29 | + "outputs": [], |
| 30 | + "source": [ |
| 31 | + "%%capture\n", |
| 32 | + "%pip install google.generativeai peewee pymysql tidb_vector" |
| 33 | + ] |
| 34 | + }, |
| 35 | + { |
| 36 | + "cell_type": "markdown", |
| 37 | + "metadata": { |
| 38 | + "id": "psEHGWiHizUq" |
| 39 | + }, |
| 40 | + "source": [ |
| 41 | + "## Prepare the Environment\n", |
| 42 | + "\n", |
| 43 | + "> **Note:**\n", |
| 44 | + ">\n", |
| 45 | + "> - You can get the `GEMINI_API_KEY` by following the [Gemini API quickstart](https://ai.google.dev/gemini-api/docs/quickstart).\n", |
| 46 | + "> - You can get the `TIDB_HOST`, `TIDB_USERNAME`, and `TIDB_PASSWORD` from the TiDB Cloud console, as described in the [Prerequisites](../README.md#prerequisites) section.\n", |
| 47 | + "\n", |
| 48 | + "Set the embedding model to `models/embedding-001`, which\n", |
| 49 | + "produces `768`-dimensional embeddings." |
| 50 | + ] |
| 51 | + }, |
| 52 | + { |
| 53 | + "cell_type": "code", |
| 54 | + "execution_count": null, |
| 55 | + "metadata": { |
| 56 | + "id": "MgKOjwmYizUq" |
| 57 | + }, |
| 58 | + "outputs": [], |
| 59 | + "source": [ |
| 60 | + "import getpass\n", |
| 61 | + "\n", |
| 62 | + "GEMINI_API_KEY = getpass.getpass(\"Enter your Gemini API key: \")\n", |
| 63 | + "TIDB_HOST = input(\"Enter your TiDB host: \")\n", |
| 64 | + "TIDB_USERNAME = input(\"Enter your TiDB username: \")\n", |
| 65 | + "TIDB_PASSWORD = getpass.getpass(\"Enter your TiDB password: \")\n", |
| 66 | + "\n", |
| 67 | + "embedding_model = \"models/embedding-001\"\n", |
| 68 | + "embedding_dimensions = 768" |
| 69 | + ] |
| 70 | + }, |
| 71 | + { |
| 72 | + "cell_type": "markdown", |
| 73 | + "metadata": { |
| 74 | + "id": "3WbH_BITizUr" |
| 75 | + }, |
| 76 | + "source": [ |
| 77 | + "## Initialize the Gemini and Database Clients" |
| 78 | + ] |
| 79 | + }, |
| 80 | + { |
| 81 | + "cell_type": "code", |
| 82 | + "execution_count": null, |
| 83 | + "metadata": { |
| 84 | + "id": "UWtcs58-izUr" |
| 85 | + }, |
| 86 | + "outputs": [], |
| 87 | + "source": [ |
| 88 | + "import google.generativeai as genai\n", |
| 89 | + "from peewee import Model, MySQLDatabase, TextField, SQL\n", |
| 90 | + "from tidb_vector.peewee import VectorField\n", |
| 91 | + "\n", |
| 92 | + "genai.configure(api_key=GEMINI_API_KEY)\n", |
| 93 | + "db = MySQLDatabase(\n", |
| 94 | + " 'test',\n", |
| 95 | + " user=TIDB_USERNAME,\n", |
| 96 | + " password=TIDB_PASSWORD,\n", |
| 97 | + " host=TIDB_HOST,\n", |
| 98 | + " port=4000,\n", |
| 99 | + " ssl_verify_cert=True,\n", |
| 100 | + " ssl_verify_identity=True\n", |
| 101 | + ")\n", |
| 102 | + "db.connect()" |
| 103 | + ] |
| 104 | + }, |
| 105 | + { |
| 106 | + "cell_type": "markdown", |
| 107 | + "metadata": { |
| 108 | + "id": "uOyjrmWJizUr" |
| 109 | + }, |
| 110 | + "source": [ |
| 111 | + "## Prepare the Context\n", |
| 112 | + "\n", |
| 113 | + "In this case, the contexts are the documents. Use the Gemini embedding model to get the embeddings of the documents, and store them in TiDB." |
| 114 | + ] |
| 115 | + }, |
| 116 | + { |
| 117 | + "cell_type": "code", |
| 118 | + "execution_count": null, |
| 119 | + "metadata": { |
| 120 | + "id": "_e5P_m0MizUs" |
| 121 | + }, |
| 122 | + "outputs": [], |
| 123 | + "source": [ |
| 124 | + "documents = [\n", |
| 125 | + " \"TiDB is an open-source distributed SQL database that supports Hybrid Transactional and Analytical Processing (HTAP) workloads.\",\n", |
| 126 | + " \"TiFlash is the key component that makes TiDB essentially an Hybrid Transactional/Analytical Processing (HTAP) database. As a columnar storage extension of TiKV, TiFlash provides both good isolation level and strong consistency guarantee.\",\n", |
| 127 | + " \"TiKV is a distributed and transactional key-value database, which provides transactional APIs with ACID compliance. With the implementation of the Raft consensus algorithm and consensus state stored in RocksDB, TiKV guarantees data consistency between multiple replicas and high availability. \",\n", |
| 128 | + "]\n", |
| 129 | + "\n", |
| 130 | + "class DocModel(Model):\n", |
| 131 | + " text = TextField()\n", |
| 132 | + " embedding = VectorField(dimensions=embedding_dimensions)\n", |
| 133 | + "\n", |
| 134 | + " class Meta:\n", |
| 135 | + " database = db\n", |
| 136 | + " table_name = \"gemini_embedding_test\"\n", |
| 137 | + "\n", |
| 138 | + " def __str__(self):\n", |
| 139 | + " return self.text\n", |
| 140 | + "\n", |
| 141 | + "db.drop_tables([DocModel])\n", |
| 142 | + "db.create_tables([DocModel])\n", |
| 143 | + "\n", |
| 144 | + "embeddings = genai.embed_content(model=embedding_model, content=documents, task_type=\"retrieval_document\")\n", |
| 145 | + "data_source = [\n", |
| 146 | + " {\"text\": doc, \"embedding\": emb}\n", |
| 147 | + " for doc, emb in zip(documents, embeddings['embedding'])\n", |
| 148 | + "]\n", |
| 149 | + "DocModel.insert_many(data_source).execute()" |
| 150 | + ] |
| 151 | + }, |
| 152 | + { |
| 153 | + "cell_type": "markdown", |
| 154 | + "metadata": { |
| 155 | + "id": "zMP-P1g8izUs" |
| 156 | + }, |
| 157 | + "source": [ |
| 158 | + "## Embed the Question\n", |
| 159 | + "\n", |
| 160 | + "Ask a question, and use the Gemini embedding model to get its embedding." |
| 161 | + ] |
| 162 | + }, |
| 163 | + { |
| 164 | + "cell_type": "code", |
| 165 | + "execution_count": null, |
| 166 | + "metadata": { |
| 167 | + "id": "-zrTOxs4izUt" |
| 168 | + }, |
| 169 | + "outputs": [], |
| 170 | + "source": [ |
| 171 | + "question = \"what is TiKV?\"\n", |
| 172 | + "question_embedding = genai.embed_content(model=embedding_model, content=[question], task_type=\"retrieval_query\")['embedding'][0]" |
| 173 | + ] |
| 174 | + }, |
| 175 | + { |
| 176 | + "cell_type": "markdown", |
| 177 | + "metadata": { |
| 178 | + "id": "atc0gXVZizUt" |
| 179 | + }, |
| 180 | + "source": [ |
| 181 | + "## Retrieve by Cosine Distance of Vectors\n", |
| 182 | + "Get the relevant documents from TiDB by comparing the embedding of the question with the embeddings of the documents." |
| 183 | + ] |
| 184 | + }, |
| 185 | + { |
| 186 | + "cell_type": "code", |
| 187 | + "execution_count": null, |
| 188 | + "metadata": { |
| 189 | + "id": "DTtJRX64izUt" |
| 190 | + }, |
| 191 | + "outputs": [], |
| 192 | + "source": [ |
| 193 | + "related_docs = DocModel.select(\n", |
| 194 | + " DocModel.text, DocModel.embedding.cosine_distance(question_embedding).alias(\"distance\")\n", |
| 195 | + ").order_by(SQL(\"distance\")).limit(3)\n", |
| 196 | + "\n", |
| 197 | + "print(\"Question:\", question)\n", |
| 198 | + "print(\"Related documents:\")\n", |
| 199 | + "for doc in related_docs:\n", |
| 200 | + " print(doc.distance, doc.text)" |
| 201 | + ] |
| 202 | + }, |
| 203 | + { |
| 204 | + "cell_type": "markdown", |
| 205 | + "metadata": { |
| 206 | + "id": "bYBetPchmNUp" |
| 207 | + }, |
| 208 | + "source": [ |
| 209 | + "## Cleanup" |
| 210 | + ] |
| 211 | + }, |
| 212 | + { |
| 213 | + "cell_type": "code", |
| 214 | + "execution_count": null, |
| 215 | + "metadata": { |
| 216 | + "id": "Lh27gC7gizUt" |
| 217 | + }, |
| 218 | + "outputs": [], |
| 219 | + "source": [ |
| 220 | + "db.close()" |
| 221 | + ] |
| 222 | + } |
| 223 | + ], |
| 224 | + "metadata": { |
| 225 | + "colab": { |
| 226 | + "provenance": [], |
| 227 | + "toc_visible": true |
| 228 | + }, |
| 229 | + "kernelspec": { |
| 230 | + "display_name": ".venv", |
| 231 | + "language": "python", |
| 232 | + "name": "python3" |
| 233 | + }, |
| 234 | + "language_info": { |
| 235 | + "codemirror_mode": { |
| 236 | + "name": "ipython", |
| 237 | + "version": 3 |
| 238 | + }, |
| 239 | + "file_extension": ".py", |
| 240 | + "mimetype": "text/x-python", |
| 241 | + "name": "python", |
| 242 | + "nbconvert_exporter": "python", |
| 243 | + "pygments_lexer": "ipython3", |
| 244 | + "version": "3.10.13" |
| 245 | + } |
| 246 | + }, |
| 247 | + "nbformat": 4, |
| 248 | + "nbformat_minor": 0 |
| 249 | +} |
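
The retrieval cell in this notebook orders rows by `cosine_distance`, which TiDB evaluates on the server. As a rough illustration of what that ordering means, here is a minimal pure-Python sketch, assuming cosine distance is defined as one minus cosine similarity. The helper names `cosine_distance` and `rank_by_distance` are hypothetical and are not part of `tidb_vector` or `peewee`; the toy vectors stand in for real 768-dimensional embeddings.

```python
import math


def cosine_distance(a, b):
    """Return 1 - cosine similarity of two equal-length, non-zero vectors."""
    if len(a) != len(b):
        raise ValueError("vectors must have the same number of dimensions")
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        raise ValueError("cosine distance is undefined for a zero vector")
    return 1.0 - dot / (norm_a * norm_b)


def rank_by_distance(query_embedding, docs):
    """Sort (text, embedding) pairs by ascending distance to the query.

    Hypothetical in-memory stand-in for the ORDER BY distance query
    that the notebook sends to TiDB.
    """
    scored = [(cosine_distance(query_embedding, emb), text) for text, emb in docs]
    return sorted(scored, key=lambda pair: pair[0])


if __name__ == "__main__":
    # Toy 2-dimensional vectors, used only to keep the example readable.
    toy_docs = [("about TiFlash", [0.0, 1.0]), ("about TiKV", [1.0, 0.1])]
    for distance, text in rank_by_distance([1.0, 0.0], toy_docs):
        print(distance, text)
```

A smaller distance means the vectors point in a more similar direction, so the first row returned is the closest match to the question.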