-
Notifications
You must be signed in to change notification settings - Fork 4.5k
Expand file tree
/
Copy pathbase.py
More file actions
116 lines (98 loc) · 4.42 KB
/
base.py
File metadata and controls
116 lines (98 loc) · 4.42 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC
from abc import abstractmethod
from typing import Any
import apache_beam as beam
from apache_beam.ml.rag.types import EmbeddableItem
class VectorDatabaseWriteConfig(ABC):
  """Abstract base class for vector database configurations in RAG pipelines.

  VectorDatabaseWriteConfig defines the interface for configuring vector
  database writes in RAG pipelines. Implementations should provide
  database-specific configuration and create appropriate write transforms.

  The configuration flow:
    1. Subclass provides database-specific configuration (table names, etc)
    2. create_write_transform() creates appropriate PTransform for writing
    3. Transform handles converting EmbeddableItems to database-specific
       format

  Because create_write_transform() is declared abstract, a subclass that does
  not override it cannot be instantiated.

  Example implementation:
    >>> class BigQueryVectorWriterConfig(VectorDatabaseWriteConfig):
    ...   def __init__(self, table: str):
    ...     self.table = table
    ...
    ...   def create_write_transform(self):
    ...     return beam.io.WriteToBigQuery(
    ...         table=self.table
    ...     )
  """
  @abstractmethod
  def create_write_transform(self) -> beam.PTransform[EmbeddableItem, Any]:
    """Creates a PTransform that writes embeddings to the vector database.

    Returns:
      A PTransform that accepts PCollection[EmbeddableItem] and writes the
      embeddings and metadata to the configured vector database.
      The transform should handle:
        - Converting EmbeddableItem format to database schema
        - Setting up database connection/client
        - Writing with appropriate batching/error handling

    Raises:
      NotImplementedError: If this base implementation is invoked directly
        (for example through super()); the concrete type is passed as the
        exception argument to identify the offending class.
    """
    raise NotImplementedError(type(self))
class VectorDatabaseWriteTransform(beam.PTransform):
  """PTransform that writes embedded items to a vector database.

  The transform is driven by a VectorDatabaseWriteConfig: the config is
  validated when the transform is constructed, and at expansion time the
  config is asked for its database-specific write transform, which is then
  applied to the input PCollection.

  Example usage (BigQueryVectorConfig stands for any concrete
  VectorDatabaseWriteConfig subclass):
    >>> config = BigQueryVectorConfig(
    ...     table='project.dataset.embeddings',
    ...     embedding_column='embedding'
    ... )
    >>>
    >>> with beam.Pipeline() as p:
    ...   items = p | beam.Create([...])  # PCollection[EmbeddableItem]
    ...   items | VectorDatabaseWriteTransform(config)

  Args:
    database_config: Configuration for the target vector database.
      Must be an instance of a VectorDatabaseWriteConfig subclass that
      implements create_write_transform().

  Raises:
    TypeError: If database_config is not a VectorDatabaseWriteConfig instance.
  """
  def __init__(self, database_config: VectorDatabaseWriteConfig):
    """Validates and stores the database config.

    Args:
      database_config: Configuration for target vector database.

    Raises:
      TypeError: If database_config is not a VectorDatabaseWriteConfig
        instance.
    """
    # Reject anything that is not a write config up front, so a wrong
    # argument fails at construction time rather than inside expand().
    if not isinstance(database_config, VectorDatabaseWriteConfig):
      received = type(database_config)
      raise TypeError(
          "database_config must be VectorDatabaseWriteConfig, "
          f"got {received}")
    self.database_config = database_config

  # NOTE(review): the return annotation names a PTransform, but the value
  # returned below is the result of applying one to `pcoll` - confirm the
  # intended annotation before changing it.
  def expand(
      self, pcoll: beam.PCollection[EmbeddableItem]
  ) -> beam.PTransform[EmbeddableItem, Any]:
    """Builds the database-specific write transform and applies it.

    Args:
      pcoll: PCollection of EmbeddableItems with embeddings to write to the
        vector database. Each EmbeddableItem must have:
          - An embedding
          - An ID
          - Metadata used to filter results as specified by database config

    Returns:
      Result of writing to database (implementation specific).
    """
    return pcoll | self.database_config.create_write_transform()