diff --git a/alloydb/vector_search/requirements-test.txt b/alloydb/vector_search/requirements-test.txt new file mode 100644 index 00000000000..4a3fe4434f6 --- /dev/null +++ b/alloydb/vector_search/requirements-test.txt @@ -0,0 +1,2 @@ +pytest==8.4.0 +pytest-asyncio==0.24.0 \ No newline at end of file diff --git a/alloydb/vector_search/requirements.txt b/alloydb/vector_search/requirements.txt new file mode 100644 index 00000000000..2c4ee19e607 --- /dev/null +++ b/alloydb/vector_search/requirements.txt @@ -0,0 +1,5 @@ +Flask==3.1.1 +google-cloud-alloydb==0.4.7 +google-cloud-alloydb-connector[pg8000]==1.9.0 +gunicorn==23.0.0 +psycopg2-binary==2.9.10 \ No newline at end of file diff --git a/alloydb/vector_search/resources/example_data.sql b/alloydb/vector_search/resources/example_data.sql new file mode 100644 index 00000000000..3d194634cc1 --- /dev/null +++ b/alloydb/vector_search/resources/example_data.sql @@ -0,0 +1,225 @@ +DROP TABLE IF EXISTS product_inventory; + +DROP TABLE IF EXISTS product; + +CREATE TABLE + product ( + id INT PRIMARY KEY, + name VARCHAR(255) NOT NULL, + description TEXT, + category VARCHAR(255), + color VARCHAR(255), + embedding vector (768) GENERATED ALWAYS AS (embedding ('text-embedding-005', description)) STORED + ); + +CREATE TABLE + product_inventory ( + id INT PRIMARY KEY, + product_id INT REFERENCES product (id), + inventory INT, + price DECIMAL(10, 2) + ); + +INSERT INTO + product (id, name, description, category, color) +VALUES + ( + 1, + 'Stuffed Elephant', + 'Soft plush elephant with floppy ears.', + 'Plush Toys', + 'Gray' + ), + ( + 2, + 'Remote Control Airplane', + 'Easy-to-fly remote control airplane.', + 'Vehicles', + 'Red' + ), + ( + 3, + 'Wooden Train Set', + 'Classic wooden train set with tracks and trains.', + 'Vehicles', + 'Multicolor' + ), + ( + 4, + 'Kids Tool Set', + 'Toy tool set with realistic tools.', + 'Pretend Play', + 'Multicolor' + ), + ( + 5, + 'Play Food Set', + 'Set of realistic play food items.', + 'Pretend 
Play', + 'Multicolor' + ), + ( + 6, + 'Magnetic Tiles', + 'Set of colorful magnetic tiles for building.', + 'Construction Toys', + 'Multicolor' + ), + ( + 7, + 'Kids Microscope', + 'Microscope for kids with different magnification levels.', + 'Educational Toys', + 'White' + ), + ( + 8, + 'Telescope for Kids', + 'Telescope designed for kids to explore the night sky.', + 'Educational Toys', + 'Blue' + ), + ( + 9, + 'Coding Robot', + 'Robot that teaches kids basic coding concepts.', + 'Educational Toys', + 'White' + ), + ( + 10, + 'Kids Camera', + 'Durable camera for kids to take pictures and videos.', + 'Electronics', + 'Pink' + ), + ( + 11, + 'Walkie Talkies', + 'Set of walkie talkies for kids to communicate.', + 'Electronics', + 'Blue' + ), + ( + 12, + 'Karaoke Machine', + 'Karaoke machine with built-in microphone and speaker.', + 'Electronics', + 'Black' + ), + ( + 13, + 'Kids Drum Set', + 'Drum set designed for kids with adjustable height.', + 'Musical Instruments', + 'Blue' + ), + ( + 14, + 'Kids Guitar', + 'Acoustic guitar for kids with nylon strings.', + 'Musical Instruments', + 'Brown' + ), + ( + 15, + 'Kids Keyboard', + 'Electronic keyboard with different instrument sounds.', + 'Musical Instruments', + 'Black' + ), + ( + 16, + 'Art Easel', + 'Double-sided art easel with chalkboard and whiteboard.', + 'Arts & Crafts', + 'White' + ), + ( + 17, + 'Finger Paints', + 'Set of non-toxic finger paints for kids.', + 'Arts & Crafts', + 'Multicolor' + ), + ( + 18, + 'Modeling Clay', + 'Set of colorful modeling clay.', + 'Arts & Crafts', + 'Multicolor' + ), + ( + 19, + 'Watercolor Paint Set', + 'Watercolor paint set with brushes and palette.', + 'Arts & Crafts', + 'Multicolor' + ), + ( + 20, + 'Beading Kit', + 'Kit for making bracelets and necklaces with beads.', + 'Arts & Crafts', + 'Multicolor' + ), + ( + 21, + '3D Puzzle', + '3D puzzle of a famous landmark.', + 'Puzzles', + 'Multicolor' + ), + ( + 22, + 'Race Car Track Set', + 'Race car track set with cars and 
accessories.', + 'Vehicles', + 'Multicolor' + ), + ( + 23, + 'RC Monster Truck', + 'Remote control monster truck with oversized tires.', + 'Vehicles', + 'Green' + ), + ( + 24, + 'Train Track Expansion Set', + 'Expansion set for wooden train tracks.', + 'Vehicles', + 'Multicolor' + ); + +INSERT INTO + product_inventory (id, product_id, inventory, price) +VALUES + (1, 1, 9, 13.09), + (2, 2, 40, 79.82), + (3, 3, 34, 52.49), + (4, 4, 9, 12.03), + (5, 5, 36, 71.29), + (6, 6, 10, 51.49), + (7, 7, 7, 37.35), + (8, 8, 6, 10.87), + (9, 9, 7, 42.47), + (10, 10, 3, 24.35), + (11, 11, 4, 10.20), + (12, 12, 47, 74.57), + (13, 13, 5, 28.54), + (14, 14, 11, 25.58), + (15, 15, 21, 69.84), + (16, 16, 6, 47.73), + (17, 17, 26, 81.00), + (18, 18, 11, 91.60), + (19, 19, 8, 78.53), + (20, 20, 43, 84.33), + (21, 21, 46, 90.01), + (22, 22, 6, 49.82), + (23, 23, 37, 50.20), + (24, 24, 27, 99.27); + +CREATE INDEX product_index ON product USING scann (embedding cosine) +WITH + (num_leaves = 5); \ No newline at end of file diff --git a/alloydb/vector_search/vector_search.py b/alloydb/vector_search/vector_search.py new file mode 100644 index 00000000000..32d4feec8eb --- /dev/null +++ b/alloydb/vector_search/vector_search.py @@ -0,0 +1,102 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
import logging
from typing import Optional, Union

from google.api_core import exceptions as api_exceptions
from google.cloud.alloydb.connector import Connector, IPTypes

from pg8000 import dbapi


def get_db_connection(
    project_id: str,
    region: str,
    cluster_id: str,
    instance_id: str,
    db_user: str,
    db_pass: str,
    db_name: str,
    ip_type: IPTypes,
) -> dbapi.Connection:
    """Open a pg8000 connection to an AlloyDB instance through the connector.

    The instance URI is assembled from the project, region, cluster and
    instance identifiers. Authentication uses the supplied database user and
    password (``enable_iam_auth`` is explicitly disabled).

    Args:
        project_id: Google Cloud project that owns the cluster.
        region: Location of the cluster.
        cluster_id: AlloyDB cluster identifier.
        instance_id: AlloyDB instance identifier.
        db_user: Database user name.
        db_pass: Password for ``db_user``.
        db_name: Database to connect to.
        ip_type: IP type passed through to the connector.

    Returns:
        The connection object returned by ``Connector.connect``.

    Raises:
        ConnectionError: If the connector raises ``Forbidden``, ``NotFound``,
            ``ServiceUnavailable`` or any other ``GoogleAPICallError``. The
            original exception is chained as the cause.
        Exception: Any other error is logged with its traceback and re-raised
            unchanged.

    NOTE(review): on success the ``Connector`` stays open and no reference to
    it is returned, so callers cannot close it themselves. Presumably it has
    to outlive the connection - confirm before changing that.
    """
    instance_uri = (
        f"projects/{project_id}/locations/{region}/"
        f"clusters/{cluster_id}/instances/{instance_id}"
    )
    connector = None
    connected = False
    try:
        connector = Connector()
        connection = connector.connect(
            instance_uri=instance_uri,
            driver="pg8000",
            ip_type=ip_type,
            user=db_user,
            password=db_pass,
            db=db_name,
            enable_iam_auth=False,
        )
        connected = True
        return connection
    except api_exceptions.Forbidden as e:
        raise ConnectionError("Missing IAM permissions to connect to AlloyDB.") from e
    except api_exceptions.NotFound as e:
        raise ConnectionError("The specified AlloyDB instance was not found.") from e
    except api_exceptions.ServiceUnavailable as e:
        raise ConnectionError("AlloyDB service is temporarily unavailable.") from e
    except api_exceptions.GoogleAPICallError as e:
        raise ConnectionError(
            "An error occurred during the AlloyDB connector's API interaction."
        ) from e
    except Exception as e:
        logging.exception(
            "An unexpected error occurred during connection setup: %s", e
        )
        raise
    finally:
        # A connector that never produced a connection is unreachable for the
        # caller, so it has to be released here or it is leaked.
        if connector is not None and not connected:
            try:
                connector.close()
            except Exception:
                # Never let cleanup replace the error that is being raised.
                logging.warning(
                    "Failed to close the AlloyDB connector after a "
                    "connection error.",
                    exc_info=True,
                )


def execute_sql_request(
    db_connection: dbapi.Connection,
    sql_statement: str,
    params: tuple = (),
    fetch_one: bool = False,
    fetch_all: bool = False,
) -> Union[Optional[tuple], list[tuple], bool]:
    """Execute one SQL request and either fetch its result or commit it.

    Args:
        db_connection: Open database connection used to create the cursor.
        sql_statement: SQL text passed to ``cursor.execute``.
        params: Parameters passed to ``cursor.execute`` alongside the SQL.
        fetch_one: Return ``cursor.fetchone()``. Takes precedence over
            ``fetch_all`` when both flags are set.
        fetch_all: Return ``cursor.fetchall()``.

    Returns:
        The fetched row or rows when a fetch flag is set. Otherwise the
        connection is committed and ``True`` is returned.

    Raises:
        Exception: Whatever the driver raises is propagated unchanged. No
            rollback is issued here; that is left to the caller.
    """
    cursor = db_connection.cursor()
    try:
        cursor.execute(sql_statement, params)

        if fetch_one:
            return cursor.fetchone()
        if fetch_all:
            return cursor.fetchall()

        # Neither fetch flag is set: treat the statement as a write.
        db_connection.commit()
        return True
    finally:
        # Release the cursor even when execute/fetch/commit raises.
        cursor.close()


def perform_vector_search(
    db_connection: dbapi.Connection, word_to_find: str, limit: int = 5
) -> tuple[list]:
    """Return the products whose embedding is closest to ``word_to_find``.

    The query orders the ``product`` table by the ``<=>`` distance between
    each stored ``embedding`` and the embedding that the database computes for
    ``word_to_find`` with the ``text-embedding-005`` model.

    Args:
        db_connection: Open database connection.
        word_to_find: Text to embed and search for.
        limit: Maximum number of rows to return.

    Returns:
        All fetched rows, each holding ``id``, ``name``, ``description``,
        ``category`` and ``color``.

    NOTE(review): the value is whatever ``cursor.fetchall()`` returns; the
    ``tuple[list]`` annotation reads as a 1-tuple - confirm against the driver
    and align it with ``execute_sql_request``.
    """
    sql_statement = """
        SELECT id, name, description, category, color
        FROM product
        ORDER BY embedding <=> embedding('text-embedding-005', %s)::vector
        LIMIT %s;
    """
    params = (word_to_find, limit)
    return execute_sql_request(db_connection, sql_statement, params, fetch_all=True)


# Patch metadata and license header of the next file in this diff, preserved
# verbatim:
#
# diff --git a/alloydb/vector_search/vector_search_test.py b/alloydb/vector_search/vector_search_test.py
# new file mode 100644
# index 00000000000..71dffb2e931
# --- /dev/null
# +++ b/alloydb/vector_search/vector_search_test.py
# @@ -0,0 +1,68 @@
#
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from collections.abc import Iterator
from pathlib import Path

from google.cloud.alloydb.connector import IPTypes

from pg8000 import dbapi

import pytest

from vector_search import execute_sql_request, get_db_connection, perform_vector_search


# Connection settings are read from the environment at import time; a missing
# variable raises KeyError before any test runs.
GOOGLE_CLOUD_PROJECT = os.environ["GOOGLE_CLOUD_PROJECT"]
ALLOYDB_REGION = os.environ["ALLOYDB_REGION"]
ALLOYDB_CLUSTER = os.environ["ALLOYDB_CLUSTER"]
ALLOYDB_INSTANCE = os.environ["ALLOYDB_INSTANCE"]

ALLOYDB_DATABASE_NAME = os.environ["ALLOYDB_DATABASE_NAME"]
ALLOYDB_PASSWORD = os.environ["ALLOYDB_PASSWORD"]
ALLOYDB_USERNAME = "postgres"


@pytest.fixture(scope="module")
def db_connection() -> Iterator[dbapi.Connection]:
    """Yield one database connection for the module and close it afterwards."""
    connection = get_db_connection(
        project_id=GOOGLE_CLOUD_PROJECT,
        region=ALLOYDB_REGION,
        cluster_id=ALLOYDB_CLUSTER,
        instance_id=ALLOYDB_INSTANCE,
        db_user=ALLOYDB_USERNAME,
        db_pass=ALLOYDB_PASSWORD,
        db_name=ALLOYDB_DATABASE_NAME,
        ip_type=IPTypes.PUBLIC,
    )
    try:
        yield connection
    finally:
        # Teardown: do not leave the connection open once the module's tests
        # have finished.
        connection.close()


def test_basic_vector_search(db_connection: dbapi.Connection) -> None:
    """Load the example data and check that a vector search honours ``limit``."""
    # Install required extensions
    sql_statement = """
        CREATE EXTENSION IF NOT EXISTS vector;
        CREATE EXTENSION IF NOT EXISTS alloydb_scann;
    """
    execute_sql_request(db_connection, sql_statement)

    # Insert product and product inventory data
    example_data = Path(__file__).parent / "resources/example_data.sql"
    execute_sql_request(db_connection, example_data.read_text(encoding="utf-8"))

    # Perform a Vector search in the DB
    result = perform_vector_search(db_connection, word_to_find="music", limit=3)
    assert len(result) == 3