-
Notifications
You must be signed in to change notification settings - Fork 4.5k
Expand file tree
/
Copy pathpostgres.py
More file actions
209 lines (177 loc) · 7.97 KB
/
postgres.py
File metadata and controls
209 lines (177 loc) · 7.97 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from collections import Counter
from typing import Callable
from typing import Dict
from typing import List
from typing import NamedTuple
from typing import Optional
from typing import Union

import apache_beam as beam
from apache_beam.coders import registry
from apache_beam.coders.row_coder import RowCoder
from apache_beam.io.jdbc import WriteToJdbc
from apache_beam.ml.rag.ingestion.base import VectorDatabaseWriteConfig
from apache_beam.ml.rag.ingestion.jdbc_common import ConnectionConfig
from apache_beam.ml.rag.ingestion.jdbc_common import WriteConfig
from apache_beam.ml.rag.ingestion.postgres_common import ColumnSpec
from apache_beam.ml.rag.ingestion.postgres_common import ColumnSpecsBuilder
from apache_beam.ml.rag.ingestion.postgres_common import ConflictResolution
from apache_beam.ml.rag.types import EmbeddableItem
_LOGGER = logging.getLogger(__name__)
MetadataSpec = Union[ColumnSpec, Dict[str, ColumnSpec]]
class _PostgresQueryBuilder:
  """Builds SQL queries and row-conversion logic for writing EmbeddableItems
  to a Postgres table via JdbcIO.

  On construction this validates the column specs, creates a NamedTuple
  record type mirroring the target schema (and registers a RowCoder for it),
  and seeds the conflict-resolution update-field defaults.
  """
  def __init__(
      self,
      table_name: str,
      *,
      column_specs: List[ColumnSpec],
      conflict_resolution: Optional[ConflictResolution] = None):
    """Builds SQL queries for writing EmbeddableItems to Postgres.

    Args:
      table_name: Target Postgres table name.
      column_specs: Ordered column definitions; each supplies the column
        name, Python type, SQL placeholder, and a value-extraction function.
      conflict_resolution: Optional ON CONFLICT strategy; ``None`` emits a
        plain INSERT.

    Raises:
      ValueError: If two column specs share the same column name.
    """
    self.table_name = table_name
    self.column_specs = column_specs
    self.conflict_resolution = conflict_resolution

    # Reject duplicate column names up front: a duplicate would make both
    # the NamedTuple field list and the INSERT column list ambiguous.
    # Counter gives a single O(n) pass instead of count() per name.
    names = [col.column_name for col in self.column_specs]
    duplicates = {
        name
        for name, count in Counter(names).items() if count > 1
    }
    if duplicates:
      raise ValueError(f"Duplicate column names found: {duplicates}")

    # Create a NamedTuple type matching the schema so rows can be encoded
    # with a schema-aware RowCoder.
    fields = [(col.column_name, col.python_type) for col in self.column_specs]
    type_name = f"VectorRecord_{table_name}"
    self.record_type = NamedTuple(type_name, fields)  # type: ignore
    # Register coder for the generated record type.
    registry.register_coder(self.record_type, RowCoder)

    # Default the update fields to all configured columns when the caller
    # did not specify them. NOTE(review): the truthiness filter only drops
    # falsy (empty) column names; it does not exclude the conflict-key
    # columns — presumably maybe_set_default_update_fields handles that.
    # TODO: confirm against postgres_common.ConflictResolution.
    if self.conflict_resolution:
      self.conflict_resolution.maybe_set_default_update_fields(
          [col.column_name for col in self.column_specs if col.column_name])

  def build_insert(self) -> str:
    """Build the parameterized INSERT query, with proper type casting via
    each column's placeholder and any configured ON CONFLICT clause."""
    # Column names and their (possibly typecast) placeholders, in spec order.
    fields = [col.column_name for col in self.column_specs]
    placeholders = [col.placeholder for col in self.column_specs]
    # Build base query.
    query = f"""
    INSERT INTO {self.table_name}
    ({', '.join(fields)})
    VALUES ({', '.join(placeholders)})
    """
    # Add conflict handling if configured.
    if self.conflict_resolution:
      query += f" {self.conflict_resolution.get_conflict_clause()}"
    _LOGGER.info("Query with placeholders %s", query)
    return query

  def create_converter(self) -> Callable[[EmbeddableItem], NamedTuple]:
    """Creates a function converting an EmbeddableItem to a record_type
    instance by applying each column's value_fn."""
    def convert(chunk: EmbeddableItem) -> self.record_type:  # type: ignore
      return self.record_type(
          **{col.column_name: col.value_fn(chunk)
             for col in self.column_specs})  # type: ignore

    return convert
class PostgresVectorWriterConfig(VectorDatabaseWriteConfig):
  """Configuration for writing vectors to Postgres using jdbc."""

  # Sentinel distinguishing "argument omitted" from an explicit None, since
  # None is a meaningful value for conflict_resolution (no conflict clause).
  _UNSET = object()

  def __init__(
      self,
      connection_config: ConnectionConfig,
      table_name: str,
      *,
      write_config: Optional[WriteConfig] = None,
      column_specs: Optional[List[ColumnSpec]] = None,
      conflict_resolution: Optional[ConflictResolution] = _UNSET):  # type: ignore[assignment]
    """Configuration for writing vectors to Postgres using jdbc.

    Supports flexible schema configuration through column specifications and
    conflict resolution strategies.

    Args:
      connection_config:
        :class:`~apache_beam.ml.rag.ingestion.jdbc_common.ConnectionConfig`.
      table_name: Target table name.
      write_config: JdbcIO :class:`~.jdbc_common.WriteConfig` to control
        batch sizes, autosharding, etc. If None, uses defaults.
      column_specs:
        Use :class:`~.postgres_common.ColumnSpecsBuilder` to configure how
        embeddings and metadata are written to a database schema.
        If None, uses default EmbeddableItem schema.
      conflict_resolution: Optional
        :class:`~.postgres_common.ConflictResolution`
        strategy for handling insert conflicts. ON CONFLICT DO NOTHING by
        default; pass None explicitly to disable conflict handling.

    Examples:
      Simple case with default schema:

      >>> config = PostgresVectorWriterConfig(
      ...     connection_config=ConnectionConfig(...),
      ...     table_name='embeddings'
      ... )

      Custom schema with metadata fields:

      >>> specs = (ColumnSpecsBuilder()
      ...          .with_id_spec(column_name="my_id_column")
      ...          .with_embedding_spec(column_name="embedding_vec")
      ...          .add_metadata_field(field="source", column_name="src")
      ...          .add_metadata_field(
      ...              "timestamp",
      ...              column_name="created_at",
      ...              sql_typecast="::timestamp"
      ...          )
      ...          .build())

      Minimal schema (only ID + embedding written)

      >>> specs = (ColumnSpecsBuilder()
      ...          .with_id_spec()
      ...          .with_embedding_spec()
      ...          .build())
      >>> config = PostgresVectorWriterConfig(
      ...     connection_config=ConnectionConfig(...),
      ...     table_name='embeddings',
      ...     column_specs=specs
      ... )
    """
    self.connection_config = connection_config
    # Construct defaults per call. Definition-time defaults would be shared
    # across all instances, and the default ConflictResolution is mutated
    # downstream (_PostgresQueryBuilder calls maybe_set_default_update_fields
    # on it), so a shared instance would leak state between configs.
    self.write_config = (
        write_config if write_config is not None else WriteConfig())
    if column_specs is None:
      column_specs = ColumnSpecsBuilder.with_defaults().build()
    if conflict_resolution is PostgresVectorWriterConfig._UNSET:
      conflict_resolution = ConflictResolution(
          on_conflict_fields=[], action='IGNORE')
    # NamedTuple is created and registered here during pipeline construction.
    self.query_builder = _PostgresQueryBuilder(
        table_name,
        column_specs=column_specs,
        conflict_resolution=conflict_resolution)

  def create_write_transform(self) -> beam.PTransform:
    """Returns the PTransform implementing the configured Postgres write."""
    return _WriteToPostgresVectorDatabase(self)
class _WriteToPostgresVectorDatabase(beam.PTransform):
  """Internal PTransform that converts EmbeddableItems to schema'd records
  and writes them to Postgres through JdbcIO."""
  def __init__(self, config: PostgresVectorWriterConfig):
    self.config = config
    self.query_builder = config.query_builder
    self.connection_config = config.connection_config
    self.write_config = config.write_config

  def expand(self, pcoll: beam.PCollection[EmbeddableItem]):
    builder = self.query_builder
    conn = self.connection_config
    write_cfg = self.write_config
    # Map each EmbeddableItem onto the registered NamedTuple record type,
    # then hand the rows to JdbcIO with the prebuilt INSERT statement.
    records = (
        pcoll
        | "Convert to Records" >> beam.Map(builder.create_converter()))
    return records | "Write to Postgres" >> WriteToJdbc(
        table_name=builder.table_name,
        driver_class_name="org.postgresql.Driver",
        jdbc_url=conn.jdbc_url,
        username=conn.username,
        password=conn.password,
        statement=builder.build_insert(),
        connection_properties=conn.connection_properties,
        connection_init_sqls=conn.connection_init_sqls,
        autosharding=write_cfg.autosharding,
        max_connections=write_cfg.max_connections,
        write_batch_size=write_cfg.write_batch_size,
        **conn.additional_jdbc_args)