Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit f7825f8

Browse files
authored
example: add gemini gen ai embedding model demo (#54)
Added gen ai embedding model demo built with tidb vector
1 parent 8c777d1 commit f7825f8

File tree

4 files changed

+376
-0
lines changed

4 files changed

+376
-0
lines changed
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
# GeminiAI Embedding Example
2+
3+
This example demonstrates how to utilize GeminiAI embedding for semantic search. According to GeminiAI's [documentation](https://ai.google.dev/gemini-api/docs/embeddings), we will use cosine similarity to calculate vector distance.
4+
5+
You can run this example in two ways:
6+
7+
- [Run in Jupyter Notebook](#jupyter-notebook)
8+
- [Run in Local](#run-in-local)
9+
10+
## Jupyter Notebook
11+
12+
Notebook: [example.ipynb](./example.ipynb)
13+
14+
Try it in the [Google colab](https://colab.research.google.com/github/pingcap/tidb-vector-python/blob/main/examples/gemini-ai-embeddings-demo/example.ipynb).
15+
16+
## Run in Local
17+
18+
### Create a virtual environment
19+
20+
```bash
21+
python3 -m venv .venv
22+
source .venv/bin/activate
23+
```
24+
25+
### Install the requirements
26+
27+
```bash
28+
pip install -r requirements.txt
29+
```
30+
31+
### Set the environment variables
32+
33+
Get the `GEMINI_API_KEY` from [GeminiAI](https://ai.google.dev/gemini-api/docs/quickstart)
34+
35+
Get the `TIDB_HOST`, `TIDB_USERNAME`, and `TIDB_PASSWORD` from the TiDB Cloud console, as described in the [Prerequisites](../README.md#prerequisites) section.
36+
37+
```bash
38+
export GEMINI_API_KEY="*******"
39+
export TIDB_HOST="gateway01.*******.shared.aws.tidbcloud.com"
40+
export TIDB_USERNAME="****.root"
41+
export TIDB_PASSWORD="****"
42+
```
43+
44+
### Run the example
45+
46+
```bash
47+
python3 example.py
48+
```
Lines changed: 249 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,249 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "markdown",
5+
"metadata": {
6+
"id": "ewKGZW06kmIv"
7+
},
8+
"source": [
9+
"# Example of Embedding\n",
10+
"\n",
11+
"It is an embedding example that uses `tidb_vector_python` as its library."
12+
]
13+
},
14+
{
15+
"cell_type": "markdown",
16+
"metadata": {
17+
"id": "F1fsS576izUl"
18+
},
19+
"source": [
20+
"## Install Dependencies"
21+
]
22+
},
23+
{
24+
"cell_type": "code",
25+
"execution_count": null,
26+
"metadata": {
27+
"id": "pTpKX_lDizUp"
28+
},
29+
"outputs": [],
30+
"source": [
31+
"%%capture\n",
32+
"%pip install google.generativeai peewee pymysql tidb_vector"
33+
]
34+
},
35+
{
36+
"cell_type": "markdown",
37+
"metadata": {
38+
"id": "psEHGWiHizUq"
39+
},
40+
"source": [
41+
"## Preapre the environment\n",
42+
"\n",
43+
"> **Note:**\n",
44+
">\n",
45+
"> - You can get the `GEMINI_API_KEY` from [GeminiAI](https://ai.google.dev/gemini-api/docs/quickstart).\n",
46+
"> - You can get the `TIDB_HOST`, `TIDB_USERNAME`, and `TIDB_PASSWORD` from the TiDB Cloud console, as described in the [Prerequisites](../README.md#prerequisites) section.\n",
47+
"\n",
48+
"Set the embedding model as `models/embedding-001`, and\n",
49+
"the amount of embedding dimensions is `768`."
50+
]
51+
},
52+
{
53+
"cell_type": "code",
54+
"execution_count": null,
55+
"metadata": {
56+
"id": "MgKOjwmYizUq"
57+
},
58+
"outputs": [],
59+
"source": [
60+
"import getpass\n",
61+
"\n",
62+
"GEMINI_API_KEY = getpass.getpass(\"Enter your GeminiAI API key: \")\n",
63+
"TIDB_HOST = input(\"Enter your TiDB host: \")\n",
64+
"TIDB_USERNAME = input(\"Enter your TiDB username: \")\n",
65+
"TIDB_PASSWORD = getpass.getpass(\"Enter your TiDB password: \")\n",
66+
"\n",
67+
"embedding_model = \"models/embedding-001\"\n",
68+
"embedding_dimensions = 768"
69+
]
70+
},
71+
{
72+
"cell_type": "markdown",
73+
"metadata": {
74+
"id": "3WbH_BITizUr"
75+
},
76+
"source": [
77+
"## Initial the Clients of OpenAI and Database"
78+
]
79+
},
80+
{
81+
"cell_type": "code",
82+
"execution_count": null,
83+
"metadata": {
84+
"id": "UWtcs58-izUr"
85+
},
86+
"outputs": [],
87+
"source": [
88+
"import google.generativeai as genai\n",
89+
"from peewee import Model, MySQLDatabase, TextField, SQL\n",
90+
"from tidb_vector.peewee import VectorField\n",
91+
"\n",
92+
"genai.configure(api_key=GEMINI_API_KEY)\n",
93+
"db = MySQLDatabase(\n",
94+
" 'test',\n",
95+
" user=TIDB_USERNAME,\n",
96+
" password=TIDB_PASSWORD,\n",
97+
" host=TIDB_HOST,\n",
98+
" port=4000,\n",
99+
" ssl_verify_cert=True,\n",
100+
" ssl_verify_identity=True\n",
101+
")\n",
102+
"db.connect()"
103+
]
104+
},
105+
{
106+
"cell_type": "markdown",
107+
"metadata": {
108+
"id": "uOyjrmWJizUr"
109+
},
110+
"source": [
111+
"## Prepare the Context\n",
112+
"\n",
113+
"In this case, contexts are the documents, use the openai embeddings model to get the embeddings of the documents, and store them in the TiDB."
114+
]
115+
},
116+
{
117+
"cell_type": "code",
118+
"execution_count": null,
119+
"metadata": {
120+
"id": "_e5P_m0MizUs"
121+
},
122+
"outputs": [],
123+
"source": [
124+
"documents = [\n",
125+
" \"TiDB is an open-source distributed SQL database that supports Hybrid Transactional and Analytical Processing (HTAP) workloads.\",\n",
126+
" \"TiFlash is the key component that makes TiDB essentially an Hybrid Transactional/Analytical Processing (HTAP) database. As a columnar storage extension of TiKV, TiFlash provides both good isolation level and strong consistency guarantee.\",\n",
127+
" \"TiKV is a distributed and transactional key-value database, which provides transactional APIs with ACID compliance. With the implementation of the Raft consensus algorithm and consensus state stored in RocksDB, TiKV guarantees data consistency between multiple replicas and high availability. \",\n",
128+
"]\n",
129+
"\n",
130+
"class DocModel(Model):\n",
131+
" text = TextField()\n",
132+
" embedding = VectorField(dimensions=embedding_dimensions)\n",
133+
"\n",
134+
" class Meta:\n",
135+
" database = db\n",
136+
" table_name = \"gemini_embedding_test\"\n",
137+
"\n",
138+
" def __str__(self):\n",
139+
" return self.text\n",
140+
"\n",
141+
"db.drop_tables([DocModel])\n",
142+
"db.create_tables([DocModel])\n",
143+
"\n",
144+
"embeddings = genai.embed_content(model=embedding_model, content=documents, task_type=\"retrieval_document\")\n",
145+
"data_source = [\n",
146+
" {\"text\": doc, \"embedding\": emb}\n",
147+
" for doc, emb in zip(documents, embeddings['embedding'])\n",
148+
"]\n",
149+
"DocModel.insert_many(data_source).execute()"
150+
]
151+
},
152+
{
153+
"cell_type": "markdown",
154+
"metadata": {
155+
"id": "zMP-P1g8izUs"
156+
},
157+
"source": [
158+
"## Initial the Vector of Question\n",
159+
"\n",
160+
"Ask a question, use the openai embeddings model to get the embeddings of the question"
161+
]
162+
},
163+
{
164+
"cell_type": "code",
165+
"execution_count": null,
166+
"metadata": {
167+
"id": "-zrTOxs4izUt"
168+
},
169+
"outputs": [],
170+
"source": [
171+
"question = \"what is TiKV?\"\n",
172+
"question_embedding = genai.embed_content(model=embedding_model, content=[question], task_type=\"retrieval_query\")['embedding'][0]"
173+
]
174+
},
175+
{
176+
"cell_type": "markdown",
177+
"metadata": {
178+
"id": "atc0gXVZizUt"
179+
},
180+
"source": [
181+
"## Retrieve by Cosine Distance of Vectors\n",
182+
"Get the relevant documents from the TiDB by comparing the embeddings of the question and the documents"
183+
]
184+
},
185+
{
186+
"cell_type": "code",
187+
"execution_count": null,
188+
"metadata": {
189+
"id": "DTtJRX64izUt"
190+
},
191+
"outputs": [],
192+
"source": [
193+
"related_docs = DocModel.select(\n",
194+
" DocModel.text, DocModel.embedding.cosine_distance(question_embedding).alias(\"distance\")\n",
195+
").order_by(SQL(\"distance\")).limit(3)\n",
196+
"\n",
197+
"print(\"Question:\", question)\n",
198+
"print(\"Related documents:\")\n",
199+
"for doc in related_docs:\n",
200+
" print(doc.distance, doc.text)"
201+
]
202+
},
203+
{
204+
"cell_type": "markdown",
205+
"metadata": {
206+
"id": "bYBetPchmNUp"
207+
},
208+
"source": [
209+
"## Cleanup"
210+
]
211+
},
212+
{
213+
"cell_type": "code",
214+
"execution_count": null,
215+
"metadata": {
216+
"id": "Lh27gC7gizUt"
217+
},
218+
"outputs": [],
219+
"source": [
220+
"db.close()"
221+
]
222+
}
223+
],
224+
"metadata": {
225+
"colab": {
226+
"provenance": [],
227+
"toc_visible": true
228+
},
229+
"kernelspec": {
230+
"display_name": ".venv",
231+
"language": "python",
232+
"name": "python3"
233+
},
234+
"language_info": {
235+
"codemirror_mode": {
236+
"name": "ipython",
237+
"version": 3
238+
},
239+
"file_extension": ".py",
240+
"mimetype": "text/x-python",
241+
"name": "python",
242+
"nbconvert_exporter": "python",
243+
"pygments_lexer": "ipython3",
244+
"version": "3.10.13"
245+
}
246+
},
247+
"nbformat": 4,
248+
"nbformat_minor": 0
249+
}
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
import os
2+
from peewee import Model, MySQLDatabase, TextField, SQL
3+
from tidb_vector.peewee import VectorField
4+
import google.generativeai as genai # Hypothetical import for Gemini API client
5+
6+
# Init Gemini client
7+
# Adjust the initialization according to the Gemini API documentation
8+
genai.configure(api_key=os.environ.get('GEMINI_API_KEY'))
9+
embedding_model = 'models/embedding-001' # Replace with the actual model name
10+
embedding_dimensions = 768 # Adjust if different for the Gemini model
11+
12+
# Init TiDB connection
13+
db = MySQLDatabase(
14+
'test',
15+
user=os.environ.get('TIDB_USERNAME'),
16+
password=os.environ.get('TIDB_PASSWORD'),
17+
host=os.environ.get('TIDB_HOST'),
18+
port=4000,
19+
ssl_verify_cert=True,
20+
ssl_verify_identity=True
21+
)
22+
23+
documents = [
24+
"TiDB is an open-source distributed SQL database that supports Hybrid Transactional and Analytical Processing (HTAP) workloads.",
25+
"TiFlash is the key component that makes TiDB essentially an Hybrid Transactional/Analytical Processing (HTAP) database. As a columnar storage extension of TiKV, TiFlash provides both good isolation level and strong consistency guarantee.",
26+
"TiKV is a distributed and transactional key-value database, which provides transactional APIs with ACID compliance. With the implementation of the Raft consensus algorithm and consensus state stored in RocksDB, TiKV guarantees data consistency between multiple replicas and high availability.",
27+
]
28+
29+
# Define a model with a VectorField to store the embeddings
30+
class DocModel(Model):
31+
text = TextField()
32+
embedding = VectorField(dimensions=embedding_dimensions)
33+
34+
class Meta:
35+
database = db
36+
table_name = "gemini_embedding_test"
37+
38+
def __str__(self):
39+
return self.text
40+
41+
db.connect()
42+
db.drop_tables([DocModel])
43+
db.create_tables([DocModel])
44+
45+
# Insert the documents and their embeddings into TiDB
46+
embeddings = genai.embed_content(model=embedding_model, content=documents, task_type="retrieval_document")
47+
data_source = [
48+
{"text": doc, "embedding": emb}
49+
for doc, emb in zip(documents, embeddings['embedding'])
50+
]
51+
DocModel.insert_many(data_source).execute()
52+
53+
# Query the most similar documents to a question
54+
# 1. Generate the embedding of the question
55+
# 2. Query the most similar documents based on the cosine distance in TiDB
56+
# 3. Print the results
57+
question = "what is TiKV?"
58+
question_embedding = genai.embed_content(model=embedding_model, content=[question], task_type="retrieval_query")['embedding'][0]
59+
related_docs = DocModel.select(
60+
DocModel.text, DocModel.embedding.cosine_distance(question_embedding).alias("distance")
61+
).order_by(SQL("distance")).limit(3)
62+
63+
print("Question:", question)
64+
print("Related documents:")
65+
for doc in related_docs:
66+
print(doc.distance, doc.text)
67+
68+
db.close()
69+
70+
# Expected Output:
71+
#
72+
# Question: what is TiKV?
73+
# Related documents:
74+
# 0.22371791507562544 TiKV is a distributed and transactional key-value database, which provides transactional APIs with ACID compliance. With the implementation of the Raft consensus algorithm and consensus state stored in RocksDB, TiKV guarantees data consistency between multiple replicas and high availability.
75+
# 0.3317073143109729 TiFlash is the key component that makes TiDB essentially an Hybrid Transactional/Analytical Processing (HTAP) database. As a columnar storage extension of TiKV, TiFlash provides both good isolation level and strong consistency guarantee.
76+
# 0.3690570695898543 TiDB is an open-source distributed SQL database that supports Hybrid Transactional and Analytical Processing (HTAP) workloads.
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
google.generativeai
2+
peewee
3+
tidb-vector

0 commit comments

Comments
 (0)