11
11
import java .sql .PreparedStatement ;
12
12
import java .sql .ResultSet ;
13
13
import java .sql .SQLException ;
14
- import java .util .ArrayList ;
15
- import java .util .Comparator ;
16
- import java .util .List ;
14
+ import java .util .*;
17
15
import java .util .logging .Level ;
18
16
import java .util .logging .Logger ;
19
17
27
25
public class SQLiteVector implements Vector {
28
26
private static final Logger LOGGER = Logger .getLogger (SQLiteVector .class .getName ());
29
27
private final DatabaseOperator dbOperator ;
28
+ private static final int NUM_HASH_TABLES = 4 ; // Number of hash tables for LSH
29
+ private static final int NUM_PROJECTIONS = 8 ; // Number of random projections per hash table
30
+ private final double [][] randomProjections ; // Random projection vectors for LSH
30
31
31
32
public SQLiteVector () {
32
33
try {
@@ -39,14 +40,66 @@ public SQLiteVector() {
39
40
"vector BLOB NOT NULL," +
40
41
"label TEXT NOT NULL," +
41
42
"entity_id INTEGER NOT NULL," +
42
- "entity_type TEXT NOT NULL)" ;
43
+ "entity_type TEXT NOT NULL," +
44
+ "hash_values TEXT NOT NULL)" ; // Store hash values as comma-separated string
43
45
dbOperator .execute (createTableSQL );
46
+
47
+ // Initialize random projections for LSH
48
+ this .randomProjections = new double [NUM_HASH_TABLES ][NUM_PROJECTIONS ];
49
+ Random random = new Random (42 ); // Fixed seed for reproducibility
50
+ for (int i = 0 ; i < NUM_HASH_TABLES ; i ++) {
51
+ for (int j = 0 ; j < NUM_PROJECTIONS ; j ++) {
52
+ randomProjections [i ][j ] = random .nextGaussian ();
53
+ }
54
+ }
44
55
} catch (ApplicationException e ) {
45
56
LOGGER .log (Level .SEVERE , "Error initializing SQLite database connection" , e );
46
57
throw new ApplicationRuntimeException ("Failed to initialize SQLite database connection" , e );
47
58
}
48
59
}
49
60
61
+ /**
62
+ * Computes LSH hash values for a vector.
63
+ *
64
+ * @param vector The vector to hash
65
+ * @return Array of hash values, one per hash table
66
+ */
67
+ private int [] computeHashValues (double [] vector ) {
68
+ int [] hashValues = new int [NUM_HASH_TABLES ];
69
+ for (int i = 0 ; i < NUM_HASH_TABLES ; i ++) {
70
+ double projection = 0 ;
71
+ for (int j = 0 ; j < NUM_PROJECTIONS ; j ++) {
72
+ projection += vector [j ] * randomProjections [i ][j ];
73
+ }
74
+ hashValues [i ] = projection > 0 ? 1 : 0 ;
75
+ }
76
+ return hashValues ;
77
+ }
78
+
79
+ /**
80
+ * Converts hash values to a string representation for storage.
81
+ */
82
+ private String hashValuesToString (int [] hashValues ) {
83
+ StringBuilder sb = new StringBuilder ();
84
+ for (int i = 0 ; i < hashValues .length ; i ++) {
85
+ if (i > 0 ) sb .append ("," );
86
+ sb .append (hashValues [i ]);
87
+ }
88
+ return sb .toString ();
89
+ }
90
+
91
+ /**
92
+ * Converts string representation back to hash values.
93
+ */
94
+ private int [] stringToHashValues (String hashString ) {
95
+ String [] parts = hashString .split ("," );
96
+ int [] hashValues = new int [parts .length ];
97
+ for (int i = 0 ; i < parts .length ; i ++) {
98
+ hashValues [i ] = Integer .parseInt (parts [i ]);
99
+ }
100
+ return hashValues ;
101
+ }
102
+
50
103
/**
51
104
* Adds a vector to the database with the associated label, entity ID, and entity type.
52
105
*
@@ -58,10 +111,14 @@ public SQLiteVector() {
58
111
*/
59
112
@ Override
60
113
public void add (double [] vector , String label , int entityId , String entityType ) throws ApplicationException {
114
+ // Compute hash values for the vector
115
+ int [] hashValues = computeHashValues (vector );
116
+ String hashString = hashValuesToString (hashValues );
117
+
61
118
// SQL query to insert a new vector into the database
62
- String sql = "INSERT INTO vectors (vector, label, entity_id, entity_type) VALUES (?, ?, ?, ?)" ;
119
+ String sql = "INSERT INTO vectors (vector, label, entity_id, entity_type, hash_values ) VALUES (?, ?, ?, ?, ?)" ;
63
120
try (PreparedStatement statement = dbOperator .preparedStatement (sql , new Object []{
64
- toByteArray (vector ), label , entityId , entityType })) {
121
+ toByteArray (vector ), label , entityId , entityType , hashString })) {
65
122
// Execute the SQL statement to add the vector
66
123
dbOperator .execute (statement );
67
124
} catch (SQLException e ) {
@@ -81,11 +138,23 @@ public void add(double[] vector, String label, int entityId, String entityType)
81
138
*/
82
139
@ Override
83
140
public List <SearchResult > search (double [] queryVector , int topK ) throws ApplicationException {
84
- String sql = "SELECT * FROM vectors" ;
141
+ // Compute hash values for the query vector
142
+ int [] queryHashValues = computeHashValues (queryVector );
143
+ String queryHashString = hashValuesToString (queryHashValues );
144
+
145
+ // Find candidate vectors using LSH
146
+ String sql = "SELECT * FROM vectors WHERE hash_values = ?" ;
85
147
List <SearchResult > results = new ArrayList <>();
148
+ Set <Integer > processedIds = new HashSet <>(); // Track processed vectors to avoid duplicates
86
149
87
- try (ResultSet resultSet = dbOperator .query (sql )) {
150
+ try (PreparedStatement statement = dbOperator .preparedStatement (sql , new Object []{queryHashString });
151
+ ResultSet resultSet = dbOperator .executeQuery (statement )) {
152
+
88
153
while (resultSet .next ()) {
154
+ int id = resultSet .getInt ("id" );
155
+ if (processedIds .contains (id )) continue ;
156
+ processedIds .add (id );
157
+
89
158
byte [] vectorBytes = resultSet .getBytes ("vector" );
90
159
double [] vector = toDoubleArray (vectorBytes );
91
160
String label = resultSet .getString ("label" );
@@ -99,6 +168,25 @@ public List<SearchResult> search(double[] queryVector, int topK) throws Applicat
99
168
throw new ApplicationException ("Failed to search vectors in SQLite" , e );
100
169
}
101
170
171
+ // If we don't have enough results, fall back to exact search
172
+ if (results .size () < topK ) {
173
+ String fallbackSql = "SELECT * FROM vectors WHERE id NOT IN (" +
174
+ String .join ("," , processedIds .stream ().map (String ::valueOf ).toArray (String []::new )) + ")" ;
175
+ try (ResultSet resultSet = dbOperator .query (fallbackSql )) {
176
+ while (resultSet .next () && results .size () < topK ) {
177
+ byte [] vectorBytes = resultSet .getBytes ("vector" );
178
+ double [] vector = toDoubleArray (vectorBytes );
179
+ String label = resultSet .getString ("label" );
180
+ int entityId = resultSet .getInt ("entity_id" );
181
+ String entityType = resultSet .getString ("entity_type" );
182
+ double similarity = cosineSimilarity (queryVector , vector );
183
+ results .add (new SearchResult (vector , label , entityId , entityType , similarity ));
184
+ }
185
+ } catch (SQLException e ) {
186
+ LOGGER .log (Level .SEVERE , "Error in fallback search" , e );
187
+ }
188
+ }
189
+
102
190
results .sort (Comparator .comparingDouble (SearchResult ::getSimilarity ).reversed ());
103
191
return results .subList (0 , Math .min (topK , results .size ()));
104
192
}
0 commit comments