Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit feea47a

Browse files
committed
Use hashValues to improve the vector search.
1 parent e68f6a5 commit feea47a

File tree

1 file changed

+96
-8
lines changed

1 file changed

+96
-8
lines changed

src/main/java/org/tinystruct/data/component/SQLiteVector.java

Lines changed: 96 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,7 @@
1111
import java.sql.PreparedStatement;
1212
import java.sql.ResultSet;
1313
import java.sql.SQLException;
14-
import java.util.ArrayList;
15-
import java.util.Comparator;
16-
import java.util.List;
14+
import java.util.*;
1715
import java.util.logging.Level;
1816
import java.util.logging.Logger;
1917

@@ -27,6 +25,9 @@
2725
public class SQLiteVector implements Vector {
2826
private static final Logger LOGGER = Logger.getLogger(SQLiteVector.class.getName());
2927
private final DatabaseOperator dbOperator;
28+
private static final int NUM_HASH_TABLES = 4; // Number of hash tables for LSH
29+
private static final int NUM_PROJECTIONS = 8; // Number of random projections per hash table
30+
private final double[][] randomProjections; // Random projection vectors for LSH
3031

3132
public SQLiteVector() {
3233
try {
@@ -39,14 +40,66 @@ public SQLiteVector() {
3940
"vector BLOB NOT NULL," +
4041
"label TEXT NOT NULL," +
4142
"entity_id INTEGER NOT NULL," +
42-
"entity_type TEXT NOT NULL)";
43+
"entity_type TEXT NOT NULL," +
44+
"hash_values TEXT NOT NULL)"; // Store hash values as comma-separated string
4345
dbOperator.execute(createTableSQL);
46+
47+
// Initialize random projections for LSH
48+
this.randomProjections = new double[NUM_HASH_TABLES][NUM_PROJECTIONS];
49+
Random random = new Random(42); // Fixed seed for reproducibility
50+
for (int i = 0; i < NUM_HASH_TABLES; i++) {
51+
for (int j = 0; j < NUM_PROJECTIONS; j++) {
52+
randomProjections[i][j] = random.nextGaussian();
53+
}
54+
}
4455
} catch (ApplicationException e) {
4556
LOGGER.log(Level.SEVERE, "Error initializing SQLite database connection", e);
4657
throw new ApplicationRuntimeException("Failed to initialize SQLite database connection", e);
4758
}
4859
}
4960

61+
/**
62+
* Computes LSH hash values for a vector.
63+
*
64+
* @param vector The vector to hash
65+
* @return Array of hash values, one per hash table
66+
*/
67+
private int[] computeHashValues(double[] vector) {
68+
int[] hashValues = new int[NUM_HASH_TABLES];
69+
for (int i = 0; i < NUM_HASH_TABLES; i++) {
70+
double projection = 0;
71+
for (int j = 0; j < NUM_PROJECTIONS; j++) {
72+
projection += vector[j] * randomProjections[i][j];
73+
}
74+
hashValues[i] = projection > 0 ? 1 : 0;
75+
}
76+
return hashValues;
77+
}
78+
79+
/**
80+
* Converts hash values to a string representation for storage.
81+
*/
82+
private String hashValuesToString(int[] hashValues) {
83+
StringBuilder sb = new StringBuilder();
84+
for (int i = 0; i < hashValues.length; i++) {
85+
if (i > 0) sb.append(",");
86+
sb.append(hashValues[i]);
87+
}
88+
return sb.toString();
89+
}
90+
91+
/**
92+
* Converts string representation back to hash values.
93+
*/
94+
private int[] stringToHashValues(String hashString) {
95+
String[] parts = hashString.split(",");
96+
int[] hashValues = new int[parts.length];
97+
for (int i = 0; i < parts.length; i++) {
98+
hashValues[i] = Integer.parseInt(parts[i]);
99+
}
100+
return hashValues;
101+
}
102+
50103
/**
51104
* Adds a vector to the database with the associated label, entity ID, and entity type.
52105
*
@@ -58,10 +111,14 @@ public SQLiteVector() {
58111
*/
59112
@Override
60113
public void add(double[] vector, String label, int entityId, String entityType) throws ApplicationException {
114+
// Compute hash values for the vector
115+
int[] hashValues = computeHashValues(vector);
116+
String hashString = hashValuesToString(hashValues);
117+
61118
// SQL query to insert a new vector into the database
62-
String sql = "INSERT INTO vectors (vector, label, entity_id, entity_type) VALUES (?, ?, ?, ?)";
119+
String sql = "INSERT INTO vectors (vector, label, entity_id, entity_type, hash_values) VALUES (?, ?, ?, ?, ?)";
63120
try (PreparedStatement statement = dbOperator.preparedStatement(sql, new Object[]{
64-
toByteArray(vector), label, entityId, entityType})) {
121+
toByteArray(vector), label, entityId, entityType, hashString})) {
65122
// Execute the SQL statement to add the vector
66123
dbOperator.execute(statement);
67124
} catch (SQLException e) {
@@ -81,11 +138,23 @@ public void add(double[] vector, String label, int entityId, String entityType)
81138
*/
82139
@Override
83140
public List<SearchResult> search(double[] queryVector, int topK) throws ApplicationException {
84-
String sql = "SELECT * FROM vectors";
141+
// Compute hash values for the query vector
142+
int[] queryHashValues = computeHashValues(queryVector);
143+
String queryHashString = hashValuesToString(queryHashValues);
144+
145+
// Find candidate vectors using LSH
146+
String sql = "SELECT * FROM vectors WHERE hash_values = ?";
85147
List<SearchResult> results = new ArrayList<>();
148+
Set<Integer> processedIds = new HashSet<>(); // Track processed vectors to avoid duplicates
86149

87-
try (ResultSet resultSet = dbOperator.query(sql)) {
150+
try (PreparedStatement statement = dbOperator.preparedStatement(sql, new Object[]{queryHashString});
151+
ResultSet resultSet = dbOperator.executeQuery(statement)) {
152+
88153
while (resultSet.next()) {
154+
int id = resultSet.getInt("id");
155+
if (processedIds.contains(id)) continue;
156+
processedIds.add(id);
157+
89158
byte[] vectorBytes = resultSet.getBytes("vector");
90159
double[] vector = toDoubleArray(vectorBytes);
91160
String label = resultSet.getString("label");
@@ -99,6 +168,25 @@ public List<SearchResult> search(double[] queryVector, int topK) throws Applicat
99168
throw new ApplicationException("Failed to search vectors in SQLite", e);
100169
}
101170

171+
// If we don't have enough results, fall back to exact search
172+
if (results.size() < topK) {
173+
String fallbackSql = "SELECT * FROM vectors WHERE id NOT IN (" +
174+
String.join(",", processedIds.stream().map(String::valueOf).toArray(String[]::new)) + ")";
175+
try (ResultSet resultSet = dbOperator.query(fallbackSql)) {
176+
while (resultSet.next() && results.size() < topK) {
177+
byte[] vectorBytes = resultSet.getBytes("vector");
178+
double[] vector = toDoubleArray(vectorBytes);
179+
String label = resultSet.getString("label");
180+
int entityId = resultSet.getInt("entity_id");
181+
String entityType = resultSet.getString("entity_type");
182+
double similarity = cosineSimilarity(queryVector, vector);
183+
results.add(new SearchResult(vector, label, entityId, entityType, similarity));
184+
}
185+
} catch (SQLException e) {
186+
LOGGER.log(Level.SEVERE, "Error in fallback search", e);
187+
}
188+
}
189+
102190
results.sort(Comparator.comparingDouble(SearchResult::getSimilarity).reversed());
103191
return results.subList(0, Math.min(topK, results.size()));
104192
}

0 commit comments

Comments
 (0)