From 95ed224827938377399712adc3791e84ad01cc9d Mon Sep 17 00:00:00 2001 From: Alessandro Benedetti Date: Mon, 31 Mar 2025 12:33:19 +0100 Subject: [PATCH 01/43] first draft --- .../apache/solr/schema/DenseVectorField.java | 24 +++++++++ .../search/join/BlockJoinChildQParser.java | 4 +- .../search/join/BlockJoinParentQParser.java | 51 +++++++++++++++++-- .../solr/search/join/FiltersQParser.java | 2 +- .../apache/solr/search/neural/KnnQParser.java | 10 +++- .../apache/solr/update/AddUpdateCommand.java | 13 ++++- .../solr/search/join/BJQParserTest.java | 5 +- 7 files changed, 96 insertions(+), 13 deletions(-) diff --git a/solr/core/src/java/org/apache/solr/schema/DenseVectorField.java b/solr/core/src/java/org/apache/solr/schema/DenseVectorField.java index 773c1e6337d1..a3dfb8f9ae95 100644 --- a/solr/core/src/java/org/apache/solr/schema/DenseVectorField.java +++ b/solr/core/src/java/org/apache/solr/schema/DenseVectorField.java @@ -40,10 +40,15 @@ import org.apache.lucene.search.KnnFloatVectorQuery; import org.apache.lucene.search.Query; import org.apache.lucene.search.SortField; +import org.apache.lucene.search.join.BitSetProducer; +import org.apache.lucene.search.join.DiversifyingChildrenByteKnnVectorQuery; +import org.apache.lucene.search.join.DiversifyingChildrenFloatKnnVectorQuery; import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.hnsw.HnswGraph; import org.apache.solr.common.SolrException; +import org.apache.solr.request.SolrQueryRequest; import org.apache.solr.search.QParser; +import org.apache.solr.search.join.BlockJoinParentQParser; import org.apache.solr.uninverting.UninvertingReader; import org.apache.solr.util.vector.ByteDenseVectorParser; import org.apache.solr.util.vector.DenseVectorParser; @@ -384,6 +389,25 @@ public Query getKnnVectorQuery( } } + public Query getMultiValuedKnnVectorQuery(final SolrQueryRequest request, + String fieldName, String vectorToSearch, int topK, Query filterQuery) { + + DenseVectorParser vectorBuilder = + getVectorBuilder(vectorToSearch, DenseVectorParser.BuilderPhase.QUERY); + + BitSetProducer acceptedDocuments = BlockJoinParentQParser.getCachedBitSetProducer(request, filterQuery); + switch (vectorEncoding) { + case FLOAT32: + return new DiversifyingChildrenFloatKnnVectorQuery(fieldName, vectorBuilder.getFloatVector(), null, topK, acceptedDocuments); + case BYTE: + return new DiversifyingChildrenByteKnnVectorQuery(fieldName, vectorBuilder.getByteVector(), null, topK, acceptedDocuments); + default: + throw new SolrException( + SolrException.ErrorCode.SERVER_ERROR, + "Unexpected state. Vector Encoding: " + vectorEncoding); + } + } + /** * Not Supported. Please use the {!knn} query parser to run K nearest neighbors search queries. */ diff --git a/solr/core/src/java/org/apache/solr/search/join/BlockJoinChildQParser.java b/solr/core/src/java/org/apache/solr/search/join/BlockJoinChildQParser.java index bb6c80db07a8..2ac81ca7d95b 100644 --- a/solr/core/src/java/org/apache/solr/search/join/BlockJoinChildQParser.java +++ b/solr/core/src/java/org/apache/solr/search/join/BlockJoinChildQParser.java @@ -33,8 +33,8 @@ public BlockJoinChildQParser( } @Override - protected Query createQuery(Query parentListQuery, Query query, String scoreMode) { - return new ToChildBlockJoinQuery(query, getBitSetProducer(parentListQuery)); + protected Query createQuery(Query allParents, BooleanQuery query, String scoreMode) { + return new ToChildBlockJoinQuery(query, getBitSetProducer(allParents)); } @Override diff --git a/solr/core/src/java/org/apache/solr/search/join/BlockJoinParentQParser.java b/solr/core/src/java/org/apache/solr/search/join/BlockJoinParentQParser.java index 67a38d0fc57a..348ff632327f 100644 --- a/solr/core/src/java/org/apache/solr/search/join/BlockJoinParentQParser.java +++ b/solr/core/src/java/org/apache/solr/search/join/BlockJoinParentQParser.java @@ -18,17 +18,25 @@ import java.io.IOException; import java.io.UncheckedIOException; +import java.util.List; +import java.util.Map; import java.util.Objects; import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.search.BooleanClause; +import org.apache.lucene.search.BooleanQuery; import org.apache.lucene.search.ConstantScoreScorer; import org.apache.lucene.search.ConstantScoreWeight; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.KnnByteVectorQuery; +import org.apache.lucene.search.KnnFloatVectorQuery; import org.apache.lucene.search.Query; import org.apache.lucene.search.QueryVisitor; import org.apache.lucene.search.Scorer; import org.apache.lucene.search.Weight; import org.apache.lucene.search.join.BitSetProducer; +import org.apache.lucene.search.join.DiversifyingChildrenByteKnnVectorQuery; +import org.apache.lucene.search.join.DiversifyingChildrenFloatKnnVectorQuery; import org.apache.lucene.search.join.QueryBitSetProducer; import org.apache.lucene.search.join.ScoreMode; import org.apache.lucene.search.join.ToParentBlockJoinQuery; @@ -38,6 +46,7 @@ import org.apache.solr.request.SolrQueryRequest; import org.apache.solr.search.ExtendedQueryBase; import org.apache.solr.search.QParser; +import org.apache.solr.search.QueryUtils; import org.apache.solr.search.SolrCache; import org.apache.solr.search.SyntaxError; @@ -67,7 +76,7 @@ protected Query parseParentFilter() throws SyntaxError { } @Override - protected Query wrapSubordinateClause(Query subordinate) throws SyntaxError { + protected Query wrapSubordinateClause(BooleanQuery subordinate) throws SyntaxError { String scoreMode = localParams.get("score", ScoreMode.None.name()); Query parentQ = parseParentFilter(); return createQuery(parentQ, subordinate, scoreMode); @@ -78,10 +87,42 @@ protected Query noClausesQuery() throws SyntaxError { return new BitSetProducerQuery(getBitSetProducer(parseParentFilter())); } - protected Query createQuery(final Query parentList, Query query, String scoreMode) - throws SyntaxError { - return new AllParentsAware( - query, getBitSetProducer(parentList), ScoreModeParser.parse(scoreMode), parentList); + protected Query createQuery(final Query allParents, BooleanQuery query, String scoreMode) + throws SyntaxError { + List clauses = query.clauses(); + if (clauses.size() == 1 && clauses.get(0).getQuery().getClass().equals(KnnByteVectorQuery.class)) { + Query acceptedParents = getAcceptedParents(allParents); + KnnByteVectorQuery childQuery = (KnnByteVectorQuery) clauses.get(0).getQuery(); + String vectorField = childQuery.getField(); + byte[] queryVector = childQuery.getTargetCopy(); + int topK = childQuery.getK(); + BitSetProducer parentFilter = getBitSetProducer(acceptedParents); + Query childrenFilter = childQuery.getFilter(); + return new DiversifyingChildrenByteKnnVectorQuery(vectorField, queryVector, childrenFilter, topK, parentFilter); + } else if (clauses.size() == 1 && clauses.get(0).getQuery().getClass().equals(KnnFloatVectorQuery.class)) { + Query acceptedParents = getAcceptedParents(allParents); + KnnFloatVectorQuery childQuery = (KnnFloatVectorQuery) clauses.get(0).getQuery(); + String vectorField = childQuery.getField(); + float[] queryVector = childQuery.getTargetCopy(); + int topK = childQuery.getK(); + BitSetProducer parentFilter = getBitSetProducer(acceptedParents); + Query childrenFilter = childQuery.getFilter(); + return new DiversifyingChildrenFloatKnnVectorQuery(vectorField, queryVector, childrenFilter, topK, parentFilter); + } else { + return new AllParentsAware( + query, getBitSetProducer(allParents), ScoreModeParser.parse(scoreMode), allParents); + } + } + + private Query getAcceptedParents(Query allParents) throws SyntaxError { + List parentFilterQueries = QueryUtils.parseFilterQueries(req); + BooleanQuery.Builder acceptedParentsBuilder = createBuilder(); + for (Query filter:parentFilterQueries) { + acceptedParentsBuilder.add(filter, BooleanClause.Occur.MUST); + } + acceptedParentsBuilder.add(allParents, BooleanClause.Occur.MUST); + Query acceptedParents = acceptedParentsBuilder.build(); + return acceptedParents; } BitSetProducer getBitSetProducer(Query query) { diff --git a/solr/core/src/java/org/apache/solr/search/join/FiltersQParser.java b/solr/core/src/java/org/apache/solr/search/join/FiltersQParser.java index 05c705aa1ce1..45036ebffece 100644 --- a/solr/core/src/java/org/apache/solr/search/join/FiltersQParser.java +++ b/solr/core/src/java/org/apache/solr/search/join/FiltersQParser.java @@ -73,7 +73,7 @@ protected Query unwrapQuery(Query query, BooleanClause.Occur occur) { return query; } - protected Query wrapSubordinateClause(Query subordinate) throws SyntaxError { + protected Query wrapSubordinateClause(BooleanQuery subordinate) throws SyntaxError { return subordinate; } diff --git a/solr/core/src/java/org/apache/solr/search/neural/KnnQParser.java b/solr/core/src/java/org/apache/solr/search/neural/KnnQParser.java index b6d9f2541cd0..f917c7731368 100644 --- a/solr/core/src/java/org/apache/solr/search/neural/KnnQParser.java +++ b/solr/core/src/java/org/apache/solr/search/neural/KnnQParser.java @@ -40,7 +40,13 @@ public Query parse() throws SyntaxError { final String vectorToSearch = getVectorToSearch(); final int topK = localParams.getInt(TOP_K, DEFAULT_TOP_K); - return denseVectorType.getKnnVectorQuery( - schemaField.getName(), vectorToSearch, topK, getFilterQuery()); + if(schemaField.multiValued()){ + return denseVectorType.getMultiValuedKnnVectorQuery(req, + schemaField.getName(), vectorToSearch, topK, getFilterQuery()); + } else { + return denseVectorType.getKnnVectorQuery( + schemaField.getName(), vectorToSearch, topK, getFilterQuery()); + } + } } diff --git a/solr/core/src/java/org/apache/solr/update/AddUpdateCommand.java b/solr/core/src/java/org/apache/solr/update/AddUpdateCommand.java index 5ca176ea66ae..7b3c86ac4c54 100644 --- a/solr/core/src/java/org/apache/solr/update/AddUpdateCommand.java +++ b/solr/core/src/java/org/apache/solr/update/AddUpdateCommand.java @@ -29,6 +29,7 @@ import org.apache.solr.common.SolrInputField; import org.apache.solr.common.params.CommonParams; import org.apache.solr.request.SolrQueryRequest; +import org.apache.solr.schema.DenseVectorField; import org.apache.solr.schema.IndexSchema; import org.apache.solr.schema.SchemaField; @@ -256,7 +257,9 @@ private List flatten(SolrInputDocument root) { /** Extract all child documents from parent that are saved in fields */ private void flattenLabelled( List unwrappedDocs, SolrInputDocument currentDoc, boolean isRoot) { + IndexSchema schema = req.getSchema(); for (SolrInputField field : currentDoc.values()) { + SchemaField sfield = schema.getFieldOrNull(field.getName()); Object value = field.getFirstValue(); // check if value is a childDocument if (value instanceof SolrInputDocument) { @@ -270,7 +273,15 @@ private void flattenLabelled( for (SolrInputDocument child : childrenList) { flattenLabelled(unwrappedDocs, child); } - } + } else if (sfield.getType() instanceof DenseVectorField && sfield.multiValued() && isRoot){ + Collection vectorValues = field.getValues(); + for(Object vectorValue:vectorValues){ + SolrInputDocument singleVectorNestedDoc = new SolrInputDocument(); + singleVectorNestedDoc.setField(field.getName(), vectorValue); + flattenLabelled(unwrappedDocs, singleVectorNestedDoc); + } + + } } if (!isRoot) unwrappedDocs.add(currentDoc); diff --git a/solr/core/src/test/org/apache/solr/search/join/BJQParserTest.java b/solr/core/src/test/org/apache/solr/search/join/BJQParserTest.java index 5fcb7455187e..346ea01a8aa3 100644 --- a/solr/core/src/test/org/apache/solr/search/join/BJQParserTest.java +++ b/solr/core/src/test/org/apache/solr/search/join/BJQParserTest.java @@ -151,8 +151,9 @@ private static void addGrandChildren(List block) { @Test public void testFull() { - String childb = "{!parent which=\"parent_s:[* TO *]\"}child_s:l"; - assertQ(req("q", childb), sixParents); + //{!knn f=vector topK=10}[1.0, 2.0, 3.0, 4.0]&fq:acl:foo + String childb = "{!parent which=\"parent_s:[* TO *]\"}{!knn f=vector topK=10}[1.0, 2.0, 3.0, 4.0]"; + assertQ(req("q", childb,"fq","parent_s:a"), sixParents); } private static final String sixParents[] = From 69dcae345a94474cf139389d8b05ade4c0e220f8 Mon Sep 17 00:00:00 2001 From: Alessandro Benedetti Date: Wed, 2 Apr 2025 09:53:49 +0100 Subject: [PATCH 02/43] Only Nested Vectors changes --- .../apache/solr/schema/DenseVectorField.java | 24 ------------------- .../search/join/BlockJoinParentQParser.java | 1 - .../apache/solr/search/neural/KnnQParser.java | 10 ++------ .../apache/solr/update/AddUpdateCommand.java | 13 +--------- 4 files changed, 3 insertions(+), 45 deletions(-) diff --git a/solr/core/src/java/org/apache/solr/schema/DenseVectorField.java b/solr/core/src/java/org/apache/solr/schema/DenseVectorField.java index a3dfb8f9ae95..773c1e6337d1 100644 --- a/solr/core/src/java/org/apache/solr/schema/DenseVectorField.java +++ b/solr/core/src/java/org/apache/solr/schema/DenseVectorField.java @@ -40,15 +40,10 @@ import org.apache.lucene.search.KnnFloatVectorQuery; import org.apache.lucene.search.Query; import org.apache.lucene.search.SortField; -import org.apache.lucene.search.join.BitSetProducer; -import org.apache.lucene.search.join.DiversifyingChildrenByteKnnVectorQuery; -import org.apache.lucene.search.join.DiversifyingChildrenFloatKnnVectorQuery; import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.hnsw.HnswGraph; import org.apache.solr.common.SolrException; -import org.apache.solr.request.SolrQueryRequest; import org.apache.solr.search.QParser; -import org.apache.solr.search.join.BlockJoinParentQParser; import org.apache.solr.uninverting.UninvertingReader; import org.apache.solr.util.vector.ByteDenseVectorParser; import org.apache.solr.util.vector.DenseVectorParser; @@ -389,25 +384,6 @@ public Query getKnnVectorQuery( } } - public Query getMultiValuedKnnVectorQuery(final SolrQueryRequest request, - String fieldName, String vectorToSearch, int topK, Query filterQuery) { - - DenseVectorParser vectorBuilder = - getVectorBuilder(vectorToSearch, DenseVectorParser.BuilderPhase.QUERY); - - BitSetProducer acceptedDocuments = BlockJoinParentQParser.getCachedBitSetProducer(request, filterQuery); - switch (vectorEncoding) { - case FLOAT32: - return new DiversifyingChildrenFloatKnnVectorQuery(fieldName, vectorBuilder.getFloatVector(), null, topK, acceptedDocuments); - case BYTE: - return new DiversifyingChildrenByteKnnVectorQuery(fieldName, vectorBuilder.getByteVector(), null, topK, acceptedDocuments); - default: - throw new SolrException( - SolrException.ErrorCode.SERVER_ERROR, - "Unexpected state. Vector Encoding: " + vectorEncoding); - } - } - /** * Not Supported. Please use the {!knn} query parser to run K nearest neighbors search queries. */ diff --git a/solr/core/src/java/org/apache/solr/search/join/BlockJoinParentQParser.java b/solr/core/src/java/org/apache/solr/search/join/BlockJoinParentQParser.java index 348ff632327f..bf4eee1902c2 100644 --- a/solr/core/src/java/org/apache/solr/search/join/BlockJoinParentQParser.java +++ b/solr/core/src/java/org/apache/solr/search/join/BlockJoinParentQParser.java @@ -19,7 +19,6 @@ import java.io.IOException; import java.io.UncheckedIOException; import java.util.List; -import java.util.Map; import java.util.Objects; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.search.BooleanClause; diff --git a/solr/core/src/java/org/apache/solr/search/neural/KnnQParser.java b/solr/core/src/java/org/apache/solr/search/neural/KnnQParser.java index f917c7731368..b6d9f2541cd0 100644 --- a/solr/core/src/java/org/apache/solr/search/neural/KnnQParser.java +++ b/solr/core/src/java/org/apache/solr/search/neural/KnnQParser.java @@ -40,13 +40,7 @@ public Query parse() throws SyntaxError { final String vectorToSearch = getVectorToSearch(); final int topK = localParams.getInt(TOP_K, DEFAULT_TOP_K); - if(schemaField.multiValued()){ - return denseVectorType.getMultiValuedKnnVectorQuery(req, - schemaField.getName(), vectorToSearch, topK, getFilterQuery()); - } else { - return denseVectorType.getKnnVectorQuery( - schemaField.getName(), vectorToSearch, topK, getFilterQuery()); - } - + return denseVectorType.getKnnVectorQuery( + schemaField.getName(), vectorToSearch, topK, getFilterQuery()); } } diff --git a/solr/core/src/java/org/apache/solr/update/AddUpdateCommand.java b/solr/core/src/java/org/apache/solr/update/AddUpdateCommand.java index 7b3c86ac4c54..5ca176ea66ae 100644 --- a/solr/core/src/java/org/apache/solr/update/AddUpdateCommand.java +++ b/solr/core/src/java/org/apache/solr/update/AddUpdateCommand.java @@ -29,7 +29,6 @@ import org.apache.solr.common.SolrInputField; import org.apache.solr.common.params.CommonParams; import org.apache.solr.request.SolrQueryRequest; -import org.apache.solr.schema.DenseVectorField; import org.apache.solr.schema.IndexSchema; import org.apache.solr.schema.SchemaField; @@ -257,9 +256,7 @@ private List flatten(SolrInputDocument root) { /** Extract all child documents from parent that are saved in fields */ private void flattenLabelled( List unwrappedDocs, SolrInputDocument currentDoc, boolean isRoot) { - IndexSchema schema = req.getSchema(); for (SolrInputField field : currentDoc.values()) { - SchemaField sfield = schema.getFieldOrNull(field.getName()); Object value = field.getFirstValue(); // check if value is a childDocument if (value instanceof SolrInputDocument) { @@ -273,15 +270,7 @@ private void flattenLabelled( for (SolrInputDocument child : childrenList) { flattenLabelled(unwrappedDocs, child); } - } else if (sfield.getType() instanceof DenseVectorField && sfield.multiValued() && isRoot){ - Collection vectorValues = field.getValues(); - for(Object vectorValue:vectorValues){ - SolrInputDocument singleVectorNestedDoc = new SolrInputDocument(); - singleVectorNestedDoc.setField(field.getName(), vectorValue); - flattenLabelled(unwrappedDocs, singleVectorNestedDoc); - } - - } + } } if (!isRoot) unwrappedDocs.add(currentDoc); From 19ac8945ea850c1271adb1e7fba3e539658eaa8f Mon Sep 17 00:00:00 2001 From: Alessandro Benedetti Date: Fri, 4 Apr 2025 18:45:02 +0100 Subject: [PATCH 03/43] first tests draft, parent filter and children filter missing as a test --- .../search/join/BlockJoinParentQParser.java | 63 +++-- .../solr/search/join/BJQParserTest.java | 5 +- .../BlockJoinNestedVectorsQParserTest.java | 219 ++++++++++++++++++ 3 files changed, 265 insertions(+), 22 deletions(-) create mode 100644 solr/core/src/test/org/apache/solr/search/join/BlockJoinNestedVectorsQParserTest.java diff --git a/solr/core/src/java/org/apache/solr/search/join/BlockJoinParentQParser.java b/solr/core/src/java/org/apache/solr/search/join/BlockJoinParentQParser.java index bf4eee1902c2..732fbbd434e6 100644 --- a/solr/core/src/java/org/apache/solr/search/join/BlockJoinParentQParser.java +++ b/solr/core/src/java/org/apache/solr/search/join/BlockJoinParentQParser.java @@ -38,6 +38,7 @@ import org.apache.lucene.search.join.DiversifyingChildrenFloatKnnVectorQuery; import org.apache.lucene.search.join.QueryBitSetProducer; import org.apache.lucene.search.join.ScoreMode; +import org.apache.lucene.search.join.ToChildBlockJoinQuery; import org.apache.lucene.search.join.ToParentBlockJoinQuery; import org.apache.lucene.util.BitSet; import org.apache.lucene.util.BitSetIterator; @@ -90,37 +91,61 @@ protected Query createQuery(final Query allParents, BooleanQuery query, String s throws SyntaxError { List clauses = query.clauses(); if (clauses.size() == 1 && clauses.get(0).getQuery().getClass().equals(KnnByteVectorQuery.class)) { - Query acceptedParents = getAcceptedParents(allParents); - KnnByteVectorQuery childQuery = (KnnByteVectorQuery) clauses.get(0).getQuery(); - String vectorField = childQuery.getField(); - byte[] queryVector = childQuery.getTargetCopy(); - int topK = childQuery.getK(); - BitSetProducer parentFilter = getBitSetProducer(acceptedParents); - Query childrenFilter = childQuery.getFilter(); - return new DiversifyingChildrenByteKnnVectorQuery(vectorField, queryVector, childrenFilter, topK, parentFilter); + BitSetProducer allParentsBitSet = getBitSetProducer(allParents); + BooleanQuery parentsFilter = getAdditionalParentFilters(); + + KnnByteVectorQuery knnChildrenQuery = (KnnByteVectorQuery) clauses.get(0).getQuery(); + String vectorField = knnChildrenQuery.getField(); + byte[] queryVector = knnChildrenQuery.getTargetCopy(); + int topK = knnChildrenQuery.getK(); + + Query acceptedChildren = getAcceptedChildren(knnChildrenQuery.getFilter(), parentsFilter, allParentsBitSet); + + Query knnChildren = new DiversifyingChildrenByteKnnVectorQuery(vectorField, queryVector, acceptedChildren, topK, allParentsBitSet); + return new ToParentBlockJoinQuery(knnChildren, allParentsBitSet, ScoreModeParser.parse(scoreMode)); } else if (clauses.size() == 1 && clauses.get(0).getQuery().getClass().equals(KnnFloatVectorQuery.class)) { - Query acceptedParents = getAcceptedParents(allParents); - KnnFloatVectorQuery childQuery = (KnnFloatVectorQuery) clauses.get(0).getQuery(); - String vectorField = childQuery.getField(); - float[] queryVector = childQuery.getTargetCopy(); - int topK = childQuery.getK(); - BitSetProducer parentFilter = getBitSetProducer(acceptedParents); - Query childrenFilter = childQuery.getFilter(); - return new DiversifyingChildrenFloatKnnVectorQuery(vectorField, queryVector, childrenFilter, topK, parentFilter); + BitSetProducer allParentsBitSet = getBitSetProducer(allParents); + BooleanQuery parentsFilter = getAdditionalParentFilters(); + + KnnFloatVectorQuery knnChildrenQuery = (KnnFloatVectorQuery) clauses.get(0).getQuery(); + String vectorField = knnChildrenQuery.getField(); + float[] queryVector = knnChildrenQuery.getTargetCopy(); + int topK = knnChildrenQuery.getK(); + + Query acceptedChildren = getAcceptedChildren(knnChildrenQuery.getFilter(), parentsFilter, allParentsBitSet); + + Query knnChildren = new DiversifyingChildrenFloatKnnVectorQuery(vectorField, queryVector, acceptedChildren, topK, allParentsBitSet); + return new ToParentBlockJoinQuery(knnChildren, allParentsBitSet, ScoreModeParser.parse(scoreMode)); } else { return new AllParentsAware( query, getBitSetProducer(allParents), ScoreModeParser.parse(scoreMode), allParents); } } - private Query getAcceptedParents(Query allParents) throws SyntaxError { + private Query getAcceptedChildren(Query knnChildrenQuery, BooleanQuery parentsFilter, BitSetProducer allParentsBitSet) { + Query childrenFilter = knnChildrenQuery; + Query acceptedChildren = childrenFilter; + + if (parentsFilter.clauses().size() >0) { + Query acceptedChildrenBasedOnParentsFilter = new ToChildBlockJoinQuery(parentsFilter, allParentsBitSet); + BooleanQuery.Builder acceptedChildrenBuilder = createBuilder(); + if (childrenFilter != null) { + acceptedChildrenBuilder.add(childrenFilter, BooleanClause.Occur.MUST); + } + acceptedChildrenBuilder.add(acceptedChildrenBasedOnParentsFilter, BooleanClause.Occur.MUST); + + acceptedChildren = acceptedChildrenBuilder.build(); + } + return acceptedChildren; + } + + private BooleanQuery getAdditionalParentFilters() throws SyntaxError { List parentFilterQueries = QueryUtils.parseFilterQueries(req); BooleanQuery.Builder acceptedParentsBuilder = createBuilder(); for (Query filter:parentFilterQueries) { acceptedParentsBuilder.add(filter, BooleanClause.Occur.MUST); } - acceptedParentsBuilder.add(allParents, BooleanClause.Occur.MUST); - Query acceptedParents = acceptedParentsBuilder.build(); + BooleanQuery acceptedParents = acceptedParentsBuilder.build(); return acceptedParents; } diff --git a/solr/core/src/test/org/apache/solr/search/join/BJQParserTest.java b/solr/core/src/test/org/apache/solr/search/join/BJQParserTest.java index 346ea01a8aa3..5fcb7455187e 100644 --- a/solr/core/src/test/org/apache/solr/search/join/BJQParserTest.java +++ b/solr/core/src/test/org/apache/solr/search/join/BJQParserTest.java @@ -151,9 +151,8 @@ private static void addGrandChildren(List block) { @Test public void testFull() { - //{!knn f=vector topK=10}[1.0, 2.0, 3.0, 4.0]&fq:acl:foo - String childb = "{!parent which=\"parent_s:[* TO *]\"}{!knn f=vector topK=10}[1.0, 2.0, 3.0, 4.0]"; - assertQ(req("q", childb,"fq","parent_s:a"), sixParents); + String childb = "{!parent which=\"parent_s:[* TO *]\"}child_s:l"; + assertQ(req("q", childb), sixParents); } private static final String sixParents[] = diff --git a/solr/core/src/test/org/apache/solr/search/join/BlockJoinNestedVectorsQParserTest.java b/solr/core/src/test/org/apache/solr/search/join/BlockJoinNestedVectorsQParserTest.java new file mode 100644 index 000000000000..d42cc461799e --- /dev/null +++ b/solr/core/src/test/org/apache/solr/search/join/BlockJoinNestedVectorsQParserTest.java @@ -0,0 +1,219 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.solr.search.join; + +import org.apache.lucene.search.join.ScoreMode; +import org.apache.solr.SolrTestCaseJ4; +import org.apache.solr.common.SolrException; +import org.apache.solr.common.SolrInputDocument; +import org.apache.solr.core.SolrCore; +import org.apache.solr.metrics.MetricsMap; +import org.apache.solr.metrics.SolrMetricManager; +import org.apache.solr.util.BaseTestHarness; +import org.apache.solr.util.RandomNoReverseMergePolicyFactory; +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.ClassRule; +import org.junit.Test; +import org.junit.rules.TestRule; + +import javax.xml.xpath.XPathConstants; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Locale; +import java.util.Map; + +public class BlockJoinNestedVectorsQParserTest extends SolrTestCaseJ4 { + + private static final String[] klm = new String[] {"k", "l", "m"}; + private static final String[] abcdef = new String[] {"a", "b", "c", "d", "e", "f"}; + private static int vectorsIndex = 30; + private static List floatQueryVector = Arrays.asList(1.0f, 1.0f, 1.0f, 1.0f); + private static List byteQueryVector = Arrays.asList(1, 1, 1, 1); + + + + + @ClassRule + public static final TestRule noReverseMerge = RandomNoReverseMergePolicyFactory.createRule(); + + @BeforeClass + public static void beforeClass() throws Exception { + initCore("solrconfig.xml", "schema15.xml"); + prepareIndex(); + } + + + public static void prepareIndex() throws Exception { + List docsToIndex = prepareDocs(); + for (SolrInputDocument doc : docsToIndex) { + assertU(adoc(doc)); + } + + assertU(commit()); + } + + private static List prepareDocs() { + int parentCount = 10; + int perParentChildren = 3; + List docs = new ArrayList<>(parentCount); + for (int i = 1; i < parentCount + 1; i++) { + SolrInputDocument doc = new SolrInputDocument(); + doc.setField("id", i); + doc.setField("parent_b", true); + + doc.setField("parent_s", abcdef[i % abcdef.length]); + List children = new ArrayList<>(perParentChildren); + + for(int j = 0; j < perParentChildren; j++) { + SolrInputDocument child = new SolrInputDocument(); + child.setField("id", i+""+j); + child.setField("child_s", klm[i % klm.length]); + child.setField("vector", perElementAddFloat(floatQueryVector, vectorsIndex)); + child.setField("vector_byte", perElementAddInteger(byteQueryVector, vectorsIndex)); + vectorsIndex--; + children.add(child); + } + doc.setField("vectors",children); + docs.add(doc); + } + + return docs; + } + + private static List perElementAddFloat(List vector, int value) { + List result = new ArrayList<>(vector.size()); + for (int i = 0; i < vector.size(); i++) { + if (i == 0) { + result.add(vector.get(i) + value); + } else { + result.add(vector.get(i)); + } + } + return result; + } + + + private static List perElementAddInteger(List vector, int value){ + List result = new ArrayList<>(vector.size()); + for (int i = 0; i < vector.size(); i++) { + if (i == 0) { + result.add(vector.get(i) + value); + } else { + result.add(vector.get(i)); + } + } + return result; + } + + @Test + public void childrenRetrievalFloat_filteringByParentMetadata_shouldReturnKnnChildren() { + assertQ( + req( + "fq", "{!child of=$allParents filters=$parent.fq}", + "q", "{!knn f=vector topK=5}[1.0, 1.0, 1.0, 1.0]", + "fl", "id", + "parent.fq", "parent_s:(a c)", + "allParents", "parent_s:[* TO *]"), + "//*[@numFound='5']", + "//result/doc[1]/str[@name='id'][.='82']", + "//result/doc[2]/str[@name='id'][.='81']", + "//result/doc[3]/str[@name='id'][.='80']", + "//result/doc[4]/str[@name='id'][.='62']", + "//result/doc[5]/str[@name='id'][.='61']"); + } + + @Test + public void childrenRetrievalByte_filteringByParentMetadata_shouldReturnKnnChildren() { + assertQ( + req( + "fq", "{!child of=$allParents filters=$parent.fq}", + "q", "{!knn f=vector_byte topK=5}[1, 1, 1, 1]", + "fl", "id", + "parent.fq", "parent_s:(a c)", + "allParents", "parent_s:[* TO *]"), + "//*[@numFound='5']", + "//result/doc[1]/str[@name='id'][.='82']", + "//result/doc[2]/str[@name='id'][.='81']", + "//result/doc[3]/str[@name='id'][.='80']", + "//result/doc[4]/str[@name='id'][.='62']", + "//result/doc[5]/str[@name='id'][.='61']"); + } + + @Test + public void parentRetrievalFloat_knnChildren_shouldReturnKnnChildren() { + assertQ( + req( + // "fq", "parent_s:(a c)", + "q", "{!parent which=$allParents score=max v=$children.q}", + "fl", "id,score", + "children.q", "{!knn f=vector topK=3}[1.0, 1.0, 1.0, 1.0]", + "allParents", "parent_s:[* TO *]"), + "//*[@numFound='3']", + "//result/doc[1]/str[@name='id'][.='10']", + "//result/doc[2]/str[@name='id'][.='9']", + "//result/doc[3]/str[@name='id'][.='8']"); + } + + @Test + public void parentRetrievalFloat_knnChildrenWithParentFilter_shouldReturnKnnChildren() { + assertQ( + req( + "fq", "parent_s:(a c)", + "q", "{!parent which=$allParents score=max v=$children.q}", + "fl", "id,score", + "children.q", "{!knn f=vector topK=3}[1.0, 1.0, 1.0, 1.0]", + "allParents", "parent_s:[* TO *]"), + "//*[@numFound='3']", + "//result/doc[1]/str[@name='id'][.='8']", + "//result/doc[2]/str[@name='id'][.='6']", + "//result/doc[3]/str[@name='id'][.='2']"); + } + + @Test + public void parentRetrievalByte_knnChildren_shouldReturnKnnChildren() { + assertQ( + req( + // "fq", "parent_s:(a c)", + "q", "{!parent which=$allParents score=max v=$children.q}", + "fl", "id,score", + "children.q", "{!knn f=vector_byte topK=3}[1, 1, 1, 1]", + "allParents", "parent_s:[* TO *]"), + "//*[@numFound='3']", + "//result/doc[1]/str[@name='id'][.='10']", + "//result/doc[2]/str[@name='id'][.='9']", + "//result/doc[3]/str[@name='id'][.='8']"); + } + + @Test + public void parentRetrievalByte_knnChildrenWithParentFilter_shouldReturnKnnChildren() { + assertQ( + req( + "fq", "parent_s:(a c)", + "q", "{!parent which=$allParents score=max v=$children.q}", + "fl", "id,score", + "children.q", "{!knn f=vector_byte topK=3}[1, 1, 1, 1]", + "allParents", "parent_s:[* TO *]"), + "//*[@numFound='3']", + "//result/doc[1]/str[@name='id'][.='8']", + "//result/doc[2]/str[@name='id'][.='6']", + "//result/doc[3]/str[@name='id'][.='2']"); + } + + +} From 20792a5113215aac3620118542145af6be5d1497 Mon Sep 17 00:00:00 2001 From: Alessandro Benedetti Date: Tue, 8 Apr 2025 11:14:17 +0100 Subject: [PATCH 04/43] tests cleaned --- .../BlockJoinNestedVectorsQParserTest.java | 123 ++++++++++++------ 1 file changed, 82 insertions(+), 41 deletions(-) diff --git a/solr/core/src/test/org/apache/solr/search/join/BlockJoinNestedVectorsQParserTest.java b/solr/core/src/test/org/apache/solr/search/join/BlockJoinNestedVectorsQParserTest.java index d42cc461799e..5e1e25d222aa 100644 --- a/solr/core/src/test/org/apache/solr/search/join/BlockJoinNestedVectorsQParserTest.java +++ b/solr/core/src/test/org/apache/solr/search/join/BlockJoinNestedVectorsQParserTest.java @@ -16,39 +16,23 @@ */ package org.apache.solr.search.join; -import org.apache.lucene.search.join.ScoreMode; import org.apache.solr.SolrTestCaseJ4; -import org.apache.solr.common.SolrException; import org.apache.solr.common.SolrInputDocument; -import org.apache.solr.core.SolrCore; -import org.apache.solr.metrics.MetricsMap; -import org.apache.solr.metrics.SolrMetricManager; -import org.apache.solr.util.BaseTestHarness; import org.apache.solr.util.RandomNoReverseMergePolicyFactory; -import org.junit.AfterClass; import org.junit.BeforeClass; import org.junit.ClassRule; import org.junit.Test; import org.junit.rules.TestRule; -import javax.xml.xpath.XPathConstants; import java.util.ArrayList; import java.util.Arrays; import java.util.List; -import java.util.Locale; -import java.util.Map; - -public class BlockJoinNestedVectorsQParserTest extends SolrTestCaseJ4 { - - private static final String[] klm = new String[] {"k", "l", "m"}; - private static final String[] abcdef = new String[] {"a", "b", "c", "d", "e", "f"}; - private static int vectorsIndex = 30; - private static List floatQueryVector = Arrays.asList(1.0f, 1.0f, 1.0f, 1.0f); - private static List byteQueryVector = Arrays.asList(1, 1, 1, 1); - - +public class BlockJoinNestedVectorsQParserTest extends SolrTestCaseJ4 { + private static final List FLOAT_QUERY_VECTOR = Arrays.asList(1.0f, 1.0f, 1.0f, 1.0f); + private static final List BYTE_QUERY_VECTOR = Arrays.asList(1, 1, 1, 1); + @ClassRule public static final TestRule noReverseMerge = RandomNoReverseMergePolicyFactory.createRule(); @@ -64,15 +48,28 @@ public static void prepareIndex() throws Exception { for (SolrInputDocument doc : docsToIndex) { assertU(adoc(doc)); } - assertU(commit()); } + /** + * The documents in the index are 10 parents, with some parent level metadata and + * 30 nested documents (with vectors and children level metadata) + * Each parent document has 3 nested documents with vectors. + * + * This allows to run knn queries both at parent/children level and using various pre-filters both + * for parent metadata and children. + * @return + */ private static List prepareDocs() { - int parentCount = 10; - int perParentChildren = 3; - List docs = new ArrayList<>(parentCount); - for (int i = 1; i < parentCount + 1; i++) { + int totalParentDocuments = 10; + int totalNestedVectors = 30; + int perParentChildren = totalNestedVectors / totalParentDocuments; + + final String[] klm = new String[] {"k", "l", "m"}; + final String[] abcdef = new String[] {"a", "b", "c", "d", "e", "f"}; + + List docs = new ArrayList<>(totalParentDocuments); + for (int i = 1; i < totalParentDocuments + 1; i++) { SolrInputDocument doc = new SolrInputDocument(); doc.setField("id", i); doc.setField("parent_b", true); @@ -80,13 +77,14 @@ private static List prepareDocs() { doc.setField("parent_s", abcdef[i % abcdef.length]); List children = new ArrayList<>(perParentChildren); + //nested vector documents have a distance from the query vector inversely proportional to their id for(int j = 0; j < perParentChildren; j++) { SolrInputDocument child = new SolrInputDocument(); child.setField("id", i+""+j); child.setField("child_s", klm[i % klm.length]); - child.setField("vector", perElementAddFloat(floatQueryVector, vectorsIndex)); - child.setField("vector_byte", perElementAddInteger(byteQueryVector, vectorsIndex)); - vectorsIndex--; + child.setField("vector", outDistanceFloat(FLOAT_QUERY_VECTOR, totalNestedVectors)); + child.setField("vector_byte", outDistanceByte(BYTE_QUERY_VECTOR, totalNestedVectors)); + totalNestedVectors--; //the higher the id of the nested document, lower the distance with the query vector children.add(child); } doc.setField("vectors",children); @@ -96,7 +94,15 @@ private static List prepareDocs() { return docs; } - private static List perElementAddFloat(List vector, int value) { + /** + * Generate a resulting float vector with a distance from the original vector that is proportional to the value in input + * (higher the value, higher the distance from the original vector) + * + * @param vector + * @param value + * @return + */ + private static List outDistanceFloat(List vector, int value) { List result = new ArrayList<>(vector.size()); for (int i = 0; i < vector.size(); i++) { if (i == 0) { @@ -107,9 +113,16 @@ private static List perElementAddFloat(List vector, int value) { } return result; } - - private static List perElementAddInteger(List vector, int value){ + /** + * Generate a resulting byte vector with a distance from the original vector that is proportional to the value in input + * (higher the value, higher the distance from the original vector) + * + * @param vector + * @param value + * @return + */ + private static List outDistanceByte(List vector, int value){ List result = new ArrayList<>(vector.size()); for (int i = 0; i < vector.size(); i++) { if (i == 0) { @@ -126,7 +139,7 @@ public void childrenRetrievalFloat_filteringByParentMetadata_shouldReturnKnnChil assertQ( req( "fq", "{!child of=$allParents filters=$parent.fq}", - "q", "{!knn f=vector topK=5}[1.0, 1.0, 1.0, 1.0]", + "q", "{!knn f=vector topK=5}"+FLOAT_QUERY_VECTOR, "fl", "id", "parent.fq", "parent_s:(a c)", "allParents", "parent_s:[* TO *]"), @@ -143,7 +156,7 @@ public void childrenRetrievalByte_filteringByParentMetadata_shouldReturnKnnChild assertQ( req( "fq", "{!child of=$allParents filters=$parent.fq}", - "q", "{!knn f=vector_byte topK=5}[1, 1, 1, 1]", + "q", "{!knn f=vector_byte topK=5}"+BYTE_QUERY_VECTOR, "fl", "id", "parent.fq", "parent_s:(a c)", "allParents", "parent_s:[* TO *]"), @@ -156,13 +169,13 @@ public void childrenRetrievalByte_filteringByParentMetadata_shouldReturnKnnChild } @Test - public void parentRetrievalFloat_knnChildren_shouldReturnKnnChildren() { + public void parentRetrievalFloat_knnChildren_shouldReturnKnnParents() { assertQ( req( // "fq", "parent_s:(a c)", "q", "{!parent which=$allParents score=max v=$children.q}", "fl", "id,score", - "children.q", "{!knn f=vector topK=3}[1.0, 1.0, 1.0, 1.0]", + "children.q", "{!knn f=vector topK=3}"+FLOAT_QUERY_VECTOR, "allParents", "parent_s:[* TO *]"), "//*[@numFound='3']", "//result/doc[1]/str[@name='id'][.='10']", @@ -171,13 +184,13 @@ public void parentRetrievalFloat_knnChildren_shouldReturnKnnChildren() { } @Test - public void parentRetrievalFloat_knnChildrenWithParentFilter_shouldReturnKnnChildren() { + public void parentRetrievalFloat_knnChildrenWithParentFilter_shouldReturnKnnParents() { assertQ( req( "fq", "parent_s:(a c)", "q", "{!parent which=$allParents score=max v=$children.q}", "fl", "id,score", - "children.q", "{!knn f=vector topK=3}[1.0, 1.0, 1.0, 1.0]", + "children.q", "{!knn f=vector topK=3}"+FLOAT_QUERY_VECTOR, "allParents", "parent_s:[* TO *]"), "//*[@numFound='3']", "//result/doc[1]/str[@name='id'][.='8']", @@ -186,13 +199,27 @@ public void parentRetrievalFloat_knnChildrenWithParentFilter_shouldReturnKnnChil } @Test - public void parentRetrievalByte_knnChildren_shouldReturnKnnChildren() { + public void parentRetrievalFloat_knnChildrenWithParentFilterAndChildrenFilter_shouldReturnKnnParents() { + assertQ( + req( + "fq", "parent_s:(a c)", + "q", "{!parent which=$allParents score=max v=$children.q}", + "fl", "id,score", + "children.q", "{!knn f=vector topK=3 preFilter=child_s:m}"+FLOAT_QUERY_VECTOR, + "allParents", "parent_s:[* TO *]"), + "//*[@numFound='2']", + "//result/doc[1]/str[@name='id'][.='8']", + "//result/doc[2]/str[@name='id'][.='2']"); + } + + @Test + public void parentRetrievalByte_knnChildren_shouldReturnKnnParents() { assertQ( req( // "fq", "parent_s:(a c)", "q", "{!parent which=$allParents score=max v=$children.q}", "fl", "id,score", - "children.q", "{!knn f=vector_byte topK=3}[1, 1, 1, 1]", + "children.q", "{!knn f=vector_byte topK=3}"+BYTE_QUERY_VECTOR, "allParents", "parent_s:[* TO *]"), "//*[@numFound='3']", "//result/doc[1]/str[@name='id'][.='10']", @@ -201,19 +228,33 @@ public void parentRetrievalByte_knnChildren_shouldReturnKnnChildren() { } @Test - public void parentRetrievalByte_knnChildrenWithParentFilter_shouldReturnKnnChildren() { + public void parentRetrievalByte_knnChildrenWithParentFilter_shouldReturnKnnParents() { assertQ( req( "fq", "parent_s:(a c)", "q", "{!parent which=$allParents score=max v=$children.q}", "fl", "id,score", - "children.q", "{!knn f=vector_byte topK=3}[1, 1, 1, 1]", + "children.q", "{!knn f=vector_byte topK=3}"+BYTE_QUERY_VECTOR, "allParents", "parent_s:[* TO *]"), "//*[@numFound='3']", "//result/doc[1]/str[@name='id'][.='8']", "//result/doc[2]/str[@name='id'][.='6']", "//result/doc[3]/str[@name='id'][.='2']"); } + + @Test + public void parentRetrievalByte_knnChildrenWithParentFilterAndChildrenFilter_shouldReturnKnnParents() { + assertQ( + req( + "fq", "parent_s:(a c)", + "q", "{!parent which=$allParents score=max v=$children.q}", + "fl", "id,score", + "children.q", "{!knn f=vector_byte topK=3 preFilter=child_s:m}"+BYTE_QUERY_VECTOR, + "allParents", "parent_s:[* TO *]"), + "//*[@numFound='2']", + "//result/doc[1]/str[@name='id'][.='8']", + "//result/doc[2]/str[@name='id'][.='2']"); + } } From f209d38261205da15acf6bb2eb2c0a37ea59561f Mon Sep 17 00:00:00 2001 From: Alessandro Benedetti Date: Tue, 8 Apr 2025 11:45:12 +0100 Subject: [PATCH 05/43] code cleanup --- .../search/join/BlockJoinChildQParser.java | 4 +- .../search/join/BlockJoinParentQParser.java | 51 +++++++++++-------- 2 files changed, 31 insertions(+), 24 deletions(-) diff --git a/solr/core/src/java/org/apache/solr/search/join/BlockJoinChildQParser.java b/solr/core/src/java/org/apache/solr/search/join/BlockJoinChildQParser.java index 2ac81ca7d95b..c49d296b8e76 100644 --- a/solr/core/src/java/org/apache/solr/search/join/BlockJoinChildQParser.java +++ b/solr/core/src/java/org/apache/solr/search/join/BlockJoinChildQParser.java @@ -33,8 +33,8 @@ public BlockJoinChildQParser( } @Override - protected Query createQuery(Query allParents, BooleanQuery query, String scoreMode) { - return new ToChildBlockJoinQuery(query, getBitSetProducer(allParents)); + protected Query createQuery(Query allParents, BooleanQuery parentQuery, String scoreMode) { + return new ToChildBlockJoinQuery(parentQuery, getBitSetProducer(allParents)); } @Override diff --git a/solr/core/src/java/org/apache/solr/search/join/BlockJoinParentQParser.java b/solr/core/src/java/org/apache/solr/search/join/BlockJoinParentQParser.java index 732fbbd434e6..792b6e887cca 100644 --- a/solr/core/src/java/org/apache/solr/search/join/BlockJoinParentQParser.java +++ b/solr/core/src/java/org/apache/solr/search/join/BlockJoinParentQParser.java @@ -87,63 +87,70 @@ protected Query noClausesQuery() throws SyntaxError { return new BitSetProducerQuery(getBitSetProducer(parseParentFilter())); } - protected Query createQuery(final Query allParents, BooleanQuery query, String scoreMode) + protected Query createQuery(final Query allParents, BooleanQuery childrenQuery, String scoreMode) throws SyntaxError { - List clauses = query.clauses(); - if (clauses.size() == 1 && clauses.get(0).getQuery().getClass().equals(KnnByteVectorQuery.class)) { + List childrenClauses = childrenQuery.clauses(); + if (isByteKnnQuery(childrenClauses)) { BitSetProducer allParentsBitSet = getBitSetProducer(allParents); - BooleanQuery parentsFilter = getAdditionalParentFilters(); + BooleanQuery parentsFilter = getParentsFilter(); - KnnByteVectorQuery knnChildrenQuery = (KnnByteVectorQuery) clauses.get(0).getQuery(); + KnnByteVectorQuery knnChildrenQuery = (KnnByteVectorQuery) childrenClauses.get(0).getQuery(); String vectorField = knnChildrenQuery.getField(); byte[] queryVector = knnChildrenQuery.getTargetCopy(); int topK = knnChildrenQuery.getK(); - Query acceptedChildren = getAcceptedChildren(knnChildrenQuery.getFilter(), parentsFilter, allParentsBitSet); + Query acceptedChildren = getChildrenFilter(knnChildrenQuery.getFilter(), parentsFilter, allParentsBitSet); Query knnChildren = new DiversifyingChildrenByteKnnVectorQuery(vectorField, queryVector, acceptedChildren, topK, allParentsBitSet); return new ToParentBlockJoinQuery(knnChildren, allParentsBitSet, ScoreModeParser.parse(scoreMode)); - } else if (clauses.size() == 1 && clauses.get(0).getQuery().getClass().equals(KnnFloatVectorQuery.class)) { + } else if (isFloatKnnQuery(childrenClauses)) { BitSetProducer allParentsBitSet = getBitSetProducer(allParents); - BooleanQuery parentsFilter = getAdditionalParentFilters(); + BooleanQuery parentsFilter = getParentsFilter(); - KnnFloatVectorQuery knnChildrenQuery = (KnnFloatVectorQuery) clauses.get(0).getQuery(); + KnnFloatVectorQuery knnChildrenQuery = (KnnFloatVectorQuery) childrenClauses.get(0).getQuery(); String vectorField = knnChildrenQuery.getField(); float[] queryVector = knnChildrenQuery.getTargetCopy(); int topK = knnChildrenQuery.getK(); - Query acceptedChildren = getAcceptedChildren(knnChildrenQuery.getFilter(), parentsFilter, allParentsBitSet); + Query childrenFilter = getChildrenFilter(knnChildrenQuery.getFilter(), parentsFilter, allParentsBitSet); - Query knnChildren = new DiversifyingChildrenFloatKnnVectorQuery(vectorField, queryVector, acceptedChildren, topK, allParentsBitSet); + Query knnChildren = new DiversifyingChildrenFloatKnnVectorQuery(vectorField, queryVector, childrenFilter, topK, allParentsBitSet); return new ToParentBlockJoinQuery(knnChildren, allParentsBitSet, ScoreModeParser.parse(scoreMode)); } else { return new AllParentsAware( - query, getBitSetProducer(allParents), ScoreModeParser.parse(scoreMode), allParents); + childrenQuery, getBitSetProducer(allParents), ScoreModeParser.parse(scoreMode), allParents); } } - private Query getAcceptedChildren(Query knnChildrenQuery, BooleanQuery parentsFilter, BitSetProducer allParentsBitSet) { - Query childrenFilter = knnChildrenQuery; - Query acceptedChildren = childrenFilter; + private boolean isFloatKnnQuery(List childrenClauses) { + return childrenClauses.size() == 1 && childrenClauses.get(0).getQuery().getClass().equals(KnnFloatVectorQuery.class); + } + + private boolean isByteKnnQuery(List childrenClauses) { + return childrenClauses.size() == 1 && childrenClauses.get(0).getQuery().getClass().equals(KnnByteVectorQuery.class); + } + + private Query getChildrenFilter(Query childrenKnnPreFilter, BooleanQuery parentsFilter, BitSetProducer allParentsBitSet) { + Query childrenFilter = childrenKnnPreFilter; if (parentsFilter.clauses().size() >0) { - Query acceptedChildrenBasedOnParentsFilter = new ToChildBlockJoinQuery(parentsFilter, allParentsBitSet); + Query acceptedChildrenBasedOnParentsFilter = new ToChildBlockJoinQuery(parentsFilter, allParentsBitSet); //no scoring happens here BooleanQuery.Builder acceptedChildrenBuilder = createBuilder(); if (childrenFilter != null) { - acceptedChildrenBuilder.add(childrenFilter, BooleanClause.Occur.MUST); + acceptedChildrenBuilder.add(childrenFilter, BooleanClause.Occur.FILTER); } - acceptedChildrenBuilder.add(acceptedChildrenBasedOnParentsFilter, BooleanClause.Occur.MUST); + acceptedChildrenBuilder.add(acceptedChildrenBasedOnParentsFilter, BooleanClause.Occur.FILTER); - acceptedChildren = acceptedChildrenBuilder.build(); + childrenFilter = acceptedChildrenBuilder.build(); } - return acceptedChildren; + return childrenFilter; } - private BooleanQuery getAdditionalParentFilters() throws SyntaxError { + private BooleanQuery getParentsFilter() throws SyntaxError { List parentFilterQueries = QueryUtils.parseFilterQueries(req); BooleanQuery.Builder acceptedParentsBuilder = createBuilder(); for (Query filter:parentFilterQueries) { - acceptedParentsBuilder.add(filter, BooleanClause.Occur.MUST); + acceptedParentsBuilder.add(filter, BooleanClause.Occur.FILTER); } BooleanQuery acceptedParents = acceptedParentsBuilder.build(); return acceptedParents; From 79c863b044537459250eb1844b055e882b544edf Mon Sep 17 00:00:00 2001 From: Alessandro Benedetti Date: Tue, 8 Apr 2025 12:11:48 +0100 Subject: [PATCH 06/43] draft documentation --- .../pages/searching-nested-documents.adoc | 34 +++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/solr/solr-ref-guide/modules/query-guide/pages/searching-nested-documents.adoc b/solr/solr-ref-guide/modules/query-guide/pages/searching-nested-documents.adoc index 124dce3b6de3..11bf301257b4 100644 --- a/solr/solr-ref-guide/modules/query-guide/pages/searching-nested-documents.adoc +++ b/solr/solr-ref-guide/modules/query-guide/pages/searching-nested-documents.adoc @@ -184,6 +184,22 @@ $ curl 'http://localhost:8983/solr/gettingstarted/select' -d 'omitHeader=true' - ---- ==== +[#vector-search-child] +[CAUTION] +.Vector search - children are nested documents with a vector field +==== +It is quite common to encode the original text of a document into multiple nested vectors. + +This may happen, among other use cases, because you chunked the original text into paragraphs, each of them modeled as a nested document with the paragraph text and the vector representation. + +Solr doesn't need to have denormalised nested documents, you can still retrieve the children paragraphs by knn vector search and prefilter them using parent level metadata. + +[source,text] +---- +$ curl 'http://localhost:8983/solr/gettingstarted/select' -d 'omitHeader=true' --data-urlencode 'fq={!child of=$block_mask filters=$parentsFilter}&q={!knn f=childVectorField topK=5}[1.0,2.5,3.0...]' --data-urlencode 'block_mask=(*:* -{!prefix f="_nest_path_" v="/skus/" parentsFilter="name_s:pen"})' +---- +==== + === Parent Query Parser The inverse of the `{!child}` query parser is the `{!parent}` query parser, which lets you search for the _ancestor_ documents of some child documents matching a wrapped query. @@ -265,6 +281,24 @@ $ curl 'http://localhost:8983/solr/gettingstarted/select' -d 'omitHeader=true' - Note that in the above example, the `/` characters in the `\_nest_path_` were "double escaped" in the `which` parameter, for the <> regarding the `{!child} pasers `of` parameter. ==== +[#vector-search-parent +[CAUTION] +.Vector search - children are nested documents with a vector field +==== +It is quite common to encode the original text of a document into multiple nested vectors. + +This may happen, among other use cases, because you chunked the original text into paragraphs, each of them modeled as a nested document with the paragraph text and the vector representation. + +You can run knn vector search on children documents (with potential prefiltering on children and/or parents metadata) and retrieve top-K parents. + +N.B. Solr ensures that the knn search for children keeps track of parent metadata filtering, guaranteeing top-k parents retrieval + +[source,text] +---- +$ curl 'http://localhost:8983/solr/gettingstarted/select' -d 'omitHeader=true' --data-urlencode 'fq=parentField:term&q={!parent which=$block_mask score=max v=$children.q}' --data-urlencode 'block_mask=(*:* -{!prefix f="_nest_path_" v="/skus/" children.q="{!knn f=vector topK=3 preFilter=childField:term}"[1.0,2.5,3.0...]})' +---- +==== + === Combining Block Join Query Parsers with Child Doc Transformer The combination of these two parsers with the `[child] transformer enables seamless creation of very powerful queries. From cba54734a725b3a5b5efe316753882c25b3f99fd Mon Sep 17 00:00:00 2001 From: Alessandro Benedetti Date: Tue, 8 Apr 2025 12:23:40 +0100 Subject: [PATCH 07/43] tidy --- .../search/join/BlockJoinParentQParser.java | 50 ++-- .../BlockJoinNestedVectorsQParserTest.java | 224 +++++++++--------- 2 files changed, 145 insertions(+), 129 deletions(-) diff --git a/solr/core/src/java/org/apache/solr/search/join/BlockJoinParentQParser.java b/solr/core/src/java/org/apache/solr/search/join/BlockJoinParentQParser.java index 792b6e887cca..0b03f720791d 100644 --- a/solr/core/src/java/org/apache/solr/search/join/BlockJoinParentQParser.java +++ b/solr/core/src/java/org/apache/solr/search/join/BlockJoinParentQParser.java @@ -88,7 +88,7 @@ protected Query noClausesQuery() throws SyntaxError { } protected Query createQuery(final Query allParents, BooleanQuery childrenQuery, String scoreMode) - throws SyntaxError { + throws SyntaxError { List childrenClauses = childrenQuery.clauses(); if (isByteKnnQuery(childrenClauses)) { BitSetProducer allParentsBitSet = getBitSetProducer(allParents); @@ -99,42 +99,58 @@ protected Query createQuery(final Query allParents, BooleanQuery childrenQuery, byte[] queryVector = knnChildrenQuery.getTargetCopy(); int topK = knnChildrenQuery.getK(); - Query acceptedChildren = getChildrenFilter(knnChildrenQuery.getFilter(), parentsFilter, allParentsBitSet); + Query acceptedChildren = + getChildrenFilter(knnChildrenQuery.getFilter(), parentsFilter, allParentsBitSet); - Query knnChildren = new DiversifyingChildrenByteKnnVectorQuery(vectorField, queryVector, acceptedChildren, topK, allParentsBitSet); - return new ToParentBlockJoinQuery(knnChildren, allParentsBitSet, ScoreModeParser.parse(scoreMode)); + Query knnChildren = + new DiversifyingChildrenByteKnnVectorQuery( + vectorField, queryVector, acceptedChildren, topK, allParentsBitSet); + return new ToParentBlockJoinQuery( + knnChildren, allParentsBitSet, ScoreModeParser.parse(scoreMode)); } else if (isFloatKnnQuery(childrenClauses)) { BitSetProducer allParentsBitSet = getBitSetProducer(allParents); BooleanQuery parentsFilter = getParentsFilter(); - - KnnFloatVectorQuery knnChildrenQuery = (KnnFloatVectorQuery) childrenClauses.get(0).getQuery(); + + KnnFloatVectorQuery knnChildrenQuery = + (KnnFloatVectorQuery) childrenClauses.get(0).getQuery(); String vectorField = knnChildrenQuery.getField(); float[] queryVector = knnChildrenQuery.getTargetCopy(); int topK = knnChildrenQuery.getK(); - Query childrenFilter = getChildrenFilter(knnChildrenQuery.getFilter(), parentsFilter, allParentsBitSet); + Query childrenFilter = + getChildrenFilter(knnChildrenQuery.getFilter(), parentsFilter, allParentsBitSet); - Query knnChildren = new DiversifyingChildrenFloatKnnVectorQuery(vectorField, queryVector, childrenFilter, topK, allParentsBitSet); - return new ToParentBlockJoinQuery(knnChildren, allParentsBitSet, ScoreModeParser.parse(scoreMode)); + Query knnChildren = + new DiversifyingChildrenFloatKnnVectorQuery( + vectorField, queryVector, childrenFilter, topK, allParentsBitSet); + return new ToParentBlockJoinQuery( + knnChildren, allParentsBitSet, ScoreModeParser.parse(scoreMode)); } else { return new AllParentsAware( - childrenQuery, getBitSetProducer(allParents), ScoreModeParser.parse(scoreMode), allParents); + childrenQuery, + getBitSetProducer(allParents), + ScoreModeParser.parse(scoreMode), + allParents); } } private boolean isFloatKnnQuery(List childrenClauses) { - return childrenClauses.size() == 1 && childrenClauses.get(0).getQuery().getClass().equals(KnnFloatVectorQuery.class); + return childrenClauses.size() == 1 + && childrenClauses.get(0).getQuery().getClass().equals(KnnFloatVectorQuery.class); } private boolean isByteKnnQuery(List childrenClauses) { - return childrenClauses.size() == 1 && childrenClauses.get(0).getQuery().getClass().equals(KnnByteVectorQuery.class); + return childrenClauses.size() == 1 + && childrenClauses.get(0).getQuery().getClass().equals(KnnByteVectorQuery.class); } - private Query getChildrenFilter(Query childrenKnnPreFilter, BooleanQuery parentsFilter, BitSetProducer allParentsBitSet) { + private Query getChildrenFilter( + Query childrenKnnPreFilter, BooleanQuery parentsFilter, BitSetProducer allParentsBitSet) { Query childrenFilter = childrenKnnPreFilter; - - if (parentsFilter.clauses().size() >0) { - Query acceptedChildrenBasedOnParentsFilter = new ToChildBlockJoinQuery(parentsFilter, allParentsBitSet); //no scoring happens here + + if (parentsFilter.clauses().size() > 0) { + Query acceptedChildrenBasedOnParentsFilter = + new ToChildBlockJoinQuery(parentsFilter, allParentsBitSet); // no scoring happens here BooleanQuery.Builder acceptedChildrenBuilder = createBuilder(); if (childrenFilter != null) { acceptedChildrenBuilder.add(childrenFilter, BooleanClause.Occur.FILTER); @@ -149,7 +165,7 @@ private Query getChildrenFilter(Query childrenKnnPreFilter, BooleanQuery parents private BooleanQuery getParentsFilter() throws SyntaxError { List parentFilterQueries = QueryUtils.parseFilterQueries(req); BooleanQuery.Builder acceptedParentsBuilder = createBuilder(); - for (Query filter:parentFilterQueries) { + for (Query filter : parentFilterQueries) { acceptedParentsBuilder.add(filter, BooleanClause.Occur.FILTER); } BooleanQuery acceptedParents = acceptedParentsBuilder.build(); diff --git a/solr/core/src/test/org/apache/solr/search/join/BlockJoinNestedVectorsQParserTest.java b/solr/core/src/test/org/apache/solr/search/join/BlockJoinNestedVectorsQParserTest.java index 5e1e25d222aa..a85904b66ac8 100644 --- a/solr/core/src/test/org/apache/solr/search/join/BlockJoinNestedVectorsQParserTest.java +++ b/solr/core/src/test/org/apache/solr/search/join/BlockJoinNestedVectorsQParserTest.java @@ -16,6 +16,9 @@ */ package org.apache.solr.search.join; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; import org.apache.solr.SolrTestCaseJ4; import org.apache.solr.common.SolrInputDocument; import org.apache.solr.util.RandomNoReverseMergePolicyFactory; @@ -24,15 +27,10 @@ import org.junit.Test; import org.junit.rules.TestRule; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; - - public class BlockJoinNestedVectorsQParserTest extends SolrTestCaseJ4 { private static final List FLOAT_QUERY_VECTOR = Arrays.asList(1.0f, 1.0f, 1.0f, 1.0f); private static final List BYTE_QUERY_VECTOR = Arrays.asList(1, 1, 1, 1); - + @ClassRule public static final TestRule noReverseMerge = RandomNoReverseMergePolicyFactory.createRule(); @@ -42,7 +40,6 @@ public static void beforeClass() throws Exception { prepareIndex(); } - public static void prepareIndex() throws Exception { List docsToIndex = prepareDocs(); for (SolrInputDocument doc : docsToIndex) { @@ -52,12 +49,13 @@ public static void prepareIndex() throws Exception { } /** - * The documents in the index are 10 parents, with some parent level metadata and - * 30 nested documents (with vectors and children level metadata) - * Each parent document has 3 nested documents with vectors. - * - * This allows to run knn queries both at parent/children level and using various pre-filters both - * for parent metadata and children. + * The documents in the index are 10 parents, with some parent level metadata and 30 nested + * documents (with vectors and children level metadata) Each parent document has 3 nested + * documents with vectors. + * + *

This allows to run knn queries both at parent/children level and using various pre-filters + * both for parent metadata and children. + * * @return */ private static List prepareDocs() { @@ -67,7 +65,7 @@ private static List prepareDocs() { final String[] klm = new String[] {"k", "l", "m"}; final String[] abcdef = new String[] {"a", "b", "c", "d", "e", "f"}; - + List docs = new ArrayList<>(totalParentDocuments); for (int i = 1; i < totalParentDocuments + 1; i++) { SolrInputDocument doc = new SolrInputDocument(); @@ -77,27 +75,29 @@ private static List prepareDocs() { doc.setField("parent_s", abcdef[i % abcdef.length]); List children = new ArrayList<>(perParentChildren); - //nested vector documents have a distance from the query vector inversely proportional to their id - for(int j = 0; j < perParentChildren; j++) { + // nested vector documents have a distance from the query vector inversely proportional to + // their id + for (int j = 0; j < perParentChildren; j++) { SolrInputDocument child = new SolrInputDocument(); - child.setField("id", i+""+j); + child.setField("id", i + "" + j); child.setField("child_s", klm[i % klm.length]); child.setField("vector", outDistanceFloat(FLOAT_QUERY_VECTOR, totalNestedVectors)); child.setField("vector_byte", outDistanceByte(BYTE_QUERY_VECTOR, totalNestedVectors)); - totalNestedVectors--; //the higher the id of the nested document, lower the distance with the query vector + totalNestedVectors--; // the higher the id of the nested document, lower the distance with + // the query vector children.add(child); } - doc.setField("vectors",children); + doc.setField("vectors", children); docs.add(doc); } - + return docs; } /** - * Generate a resulting float vector with a distance from the original vector that is proportional to the value in input - * (higher the value, higher the distance from the original vector) - * + * Generate a resulting float vector with a distance from the original vector that is proportional + * to the value in input (higher the value, higher the distance from the original vector) + * * @param vector * @param value * @return @@ -115,14 +115,14 @@ private static List outDistanceFloat(List vector, int value) { } /** - * Generate a resulting byte vector with a distance from the original vector that is proportional to the value in input - * (higher the value, higher the distance from the original vector) + * Generate a resulting byte vector with a distance from the original vector that is proportional + * to the value in input (higher the value, higher the distance from the original vector) * * @param vector * @param value * @return */ - private static List outDistanceByte(List vector, int value){ + private static List outDistanceByte(List vector, int value) { List result = new ArrayList<>(vector.size()); for (int i = 0; i < vector.size(); i++) { if (i == 0) { @@ -137,124 +137,124 @@ private static List outDistanceByte(List vector, int value){ @Test public void childrenRetrievalFloat_filteringByParentMetadata_shouldReturnKnnChildren() { assertQ( - req( - "fq", "{!child of=$allParents filters=$parent.fq}", - "q", "{!knn f=vector topK=5}"+FLOAT_QUERY_VECTOR, - "fl", "id", - "parent.fq", "parent_s:(a c)", - "allParents", "parent_s:[* TO *]"), - "//*[@numFound='5']", - "//result/doc[1]/str[@name='id'][.='82']", - "//result/doc[2]/str[@name='id'][.='81']", - "//result/doc[3]/str[@name='id'][.='80']", - "//result/doc[4]/str[@name='id'][.='62']", - "//result/doc[5]/str[@name='id'][.='61']"); + req( + "fq", "{!child of=$allParents filters=$parent.fq}", + "q", "{!knn f=vector topK=5}" + FLOAT_QUERY_VECTOR, + "fl", "id", + "parent.fq", "parent_s:(a c)", + "allParents", "parent_s:[* TO *]"), + "//*[@numFound='5']", + "//result/doc[1]/str[@name='id'][.='82']", + "//result/doc[2]/str[@name='id'][.='81']", + "//result/doc[3]/str[@name='id'][.='80']", + "//result/doc[4]/str[@name='id'][.='62']", + "//result/doc[5]/str[@name='id'][.='61']"); } @Test public void childrenRetrievalByte_filteringByParentMetadata_shouldReturnKnnChildren() { assertQ( - req( - "fq", "{!child of=$allParents filters=$parent.fq}", - "q", "{!knn f=vector_byte topK=5}"+BYTE_QUERY_VECTOR, - "fl", "id", - "parent.fq", "parent_s:(a c)", - "allParents", "parent_s:[* TO *]"), - "//*[@numFound='5']", - "//result/doc[1]/str[@name='id'][.='82']", - "//result/doc[2]/str[@name='id'][.='81']", - "//result/doc[3]/str[@name='id'][.='80']", - "//result/doc[4]/str[@name='id'][.='62']", - "//result/doc[5]/str[@name='id'][.='61']"); + req( + "fq", "{!child of=$allParents filters=$parent.fq}", + "q", "{!knn f=vector_byte topK=5}" + BYTE_QUERY_VECTOR, + "fl", "id", + "parent.fq", "parent_s:(a c)", + "allParents", "parent_s:[* TO *]"), + "//*[@numFound='5']", + "//result/doc[1]/str[@name='id'][.='82']", + "//result/doc[2]/str[@name='id'][.='81']", + "//result/doc[3]/str[@name='id'][.='80']", + "//result/doc[4]/str[@name='id'][.='62']", + "//result/doc[5]/str[@name='id'][.='61']"); } @Test public void parentRetrievalFloat_knnChildren_shouldReturnKnnParents() { assertQ( - req( - // "fq", "parent_s:(a c)", - "q", "{!parent which=$allParents score=max v=$children.q}", - "fl", "id,score", - "children.q", "{!knn f=vector topK=3}"+FLOAT_QUERY_VECTOR, - "allParents", "parent_s:[* TO *]"), - "//*[@numFound='3']", - "//result/doc[1]/str[@name='id'][.='10']", - "//result/doc[2]/str[@name='id'][.='9']", - "//result/doc[3]/str[@name='id'][.='8']"); + req( + // "fq", "parent_s:(a c)", + "q", "{!parent which=$allParents score=max v=$children.q}", + "fl", "id,score", + "children.q", "{!knn f=vector topK=3}" + FLOAT_QUERY_VECTOR, + "allParents", "parent_s:[* TO *]"), + "//*[@numFound='3']", + "//result/doc[1]/str[@name='id'][.='10']", + "//result/doc[2]/str[@name='id'][.='9']", + "//result/doc[3]/str[@name='id'][.='8']"); } @Test public void parentRetrievalFloat_knnChildrenWithParentFilter_shouldReturnKnnParents() { assertQ( - req( - "fq", "parent_s:(a c)", - "q", "{!parent which=$allParents score=max v=$children.q}", - "fl", "id,score", - "children.q", "{!knn f=vector topK=3}"+FLOAT_QUERY_VECTOR, - "allParents", "parent_s:[* TO *]"), - "//*[@numFound='3']", - "//result/doc[1]/str[@name='id'][.='8']", - "//result/doc[2]/str[@name='id'][.='6']", - "//result/doc[3]/str[@name='id'][.='2']"); + req( + "fq", "parent_s:(a c)", + "q", "{!parent which=$allParents score=max v=$children.q}", + "fl", "id,score", + "children.q", "{!knn f=vector topK=3}" + FLOAT_QUERY_VECTOR, + "allParents", "parent_s:[* TO *]"), + "//*[@numFound='3']", + "//result/doc[1]/str[@name='id'][.='8']", + "//result/doc[2]/str[@name='id'][.='6']", + "//result/doc[3]/str[@name='id'][.='2']"); } @Test - public void parentRetrievalFloat_knnChildrenWithParentFilterAndChildrenFilter_shouldReturnKnnParents() { + public void + parentRetrievalFloat_knnChildrenWithParentFilterAndChildrenFilter_shouldReturnKnnParents() { assertQ( - req( - "fq", "parent_s:(a c)", - "q", "{!parent which=$allParents score=max v=$children.q}", - "fl", "id,score", - "children.q", "{!knn f=vector topK=3 preFilter=child_s:m}"+FLOAT_QUERY_VECTOR, - "allParents", "parent_s:[* TO *]"), - "//*[@numFound='2']", - "//result/doc[1]/str[@name='id'][.='8']", - "//result/doc[2]/str[@name='id'][.='2']"); + req( + "fq", "parent_s:(a c)", + "q", "{!parent which=$allParents score=max v=$children.q}", + "fl", "id,score", + "children.q", "{!knn f=vector topK=3 preFilter=child_s:m}" + FLOAT_QUERY_VECTOR, + "allParents", "parent_s:[* TO *]"), + "//*[@numFound='2']", + "//result/doc[1]/str[@name='id'][.='8']", + "//result/doc[2]/str[@name='id'][.='2']"); } @Test public void parentRetrievalByte_knnChildren_shouldReturnKnnParents() { assertQ( - req( - // "fq", "parent_s:(a c)", - "q", "{!parent which=$allParents score=max v=$children.q}", - "fl", "id,score", - "children.q", "{!knn f=vector_byte topK=3}"+BYTE_QUERY_VECTOR, - "allParents", "parent_s:[* TO *]"), - "//*[@numFound='3']", - "//result/doc[1]/str[@name='id'][.='10']", - "//result/doc[2]/str[@name='id'][.='9']", - "//result/doc[3]/str[@name='id'][.='8']"); + req( + // "fq", "parent_s:(a c)", + "q", "{!parent which=$allParents score=max v=$children.q}", + "fl", "id,score", + "children.q", "{!knn f=vector_byte topK=3}" + BYTE_QUERY_VECTOR, + "allParents", "parent_s:[* TO *]"), + "//*[@numFound='3']", + "//result/doc[1]/str[@name='id'][.='10']", + "//result/doc[2]/str[@name='id'][.='9']", + "//result/doc[3]/str[@name='id'][.='8']"); } @Test public void parentRetrievalByte_knnChildrenWithParentFilter_shouldReturnKnnParents() { assertQ( - req( - "fq", "parent_s:(a c)", - "q", "{!parent which=$allParents score=max v=$children.q}", - "fl", "id,score", - "children.q", "{!knn f=vector_byte topK=3}"+BYTE_QUERY_VECTOR, - "allParents", "parent_s:[* TO *]"), - "//*[@numFound='3']", - "//result/doc[1]/str[@name='id'][.='8']", - "//result/doc[2]/str[@name='id'][.='6']", - "//result/doc[3]/str[@name='id'][.='2']"); + req( + "fq", "parent_s:(a c)", + "q", "{!parent which=$allParents score=max v=$children.q}", + "fl", "id,score", + "children.q", "{!knn f=vector_byte topK=3}" + BYTE_QUERY_VECTOR, + "allParents", "parent_s:[* TO *]"), + "//*[@numFound='3']", + "//result/doc[1]/str[@name='id'][.='8']", + "//result/doc[2]/str[@name='id'][.='6']", + "//result/doc[3]/str[@name='id'][.='2']"); } @Test - public void parentRetrievalByte_knnChildrenWithParentFilterAndChildrenFilter_shouldReturnKnnParents() { + public void + parentRetrievalByte_knnChildrenWithParentFilterAndChildrenFilter_shouldReturnKnnParents() { assertQ( - req( - "fq", "parent_s:(a c)", - "q", "{!parent which=$allParents score=max v=$children.q}", - "fl", "id,score", - "children.q", "{!knn f=vector_byte topK=3 preFilter=child_s:m}"+BYTE_QUERY_VECTOR, - "allParents", "parent_s:[* TO *]"), - "//*[@numFound='2']", - "//result/doc[1]/str[@name='id'][.='8']", - "//result/doc[2]/str[@name='id'][.='2']"); + req( + "fq", "parent_s:(a c)", + "q", "{!parent which=$allParents score=max v=$children.q}", + "fl", "id,score", + "children.q", "{!knn f=vector_byte topK=3 preFilter=child_s:m}" + BYTE_QUERY_VECTOR, + "allParents", "parent_s:[* TO *]"), + "//*[@numFound='2']", + "//result/doc[1]/str[@name='id'][.='8']", + "//result/doc[2]/str[@name='id'][.='2']"); } - - } From 5f094961a26bddcb7ed89f5bde313efd74011335 Mon Sep 17 00:00:00 2001 From: Alessandro Benedetti Date: Wed, 9 Apr 2025 10:48:52 +0100 Subject: [PATCH 08/43] tidy --- .../join/BlockJoinNestedVectorsQParserTest.java | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/solr/core/src/test/org/apache/solr/search/join/BlockJoinNestedVectorsQParserTest.java b/solr/core/src/test/org/apache/solr/search/join/BlockJoinNestedVectorsQParserTest.java index a85904b66ac8..764febdd9af9 100644 --- a/solr/core/src/test/org/apache/solr/search/join/BlockJoinNestedVectorsQParserTest.java +++ b/solr/core/src/test/org/apache/solr/search/join/BlockJoinNestedVectorsQParserTest.java @@ -56,7 +56,7 @@ public static void prepareIndex() throws Exception { *

This allows to run knn queries both at parent/children level and using various pre-filters * both for parent metadata and children. * - * @return + * @return a list of documents to index */ private static List prepareDocs() { int totalParentDocuments = 10; @@ -98,9 +98,9 @@ private static List prepareDocs() { * Generate a resulting float vector with a distance from the original vector that is proportional * to the value in input (higher the value, higher the distance from the original vector) * - * @param vector - * @param value - * @return + * @param vector a numerical vector + * @param value a numerical value to be added to the first element of the vector + * @return a numerical vector that has a distance from the input vector, proportional to the value */ private static List outDistanceFloat(List vector, int value) { List result = new ArrayList<>(vector.size()); @@ -118,9 +118,9 @@ private static List outDistanceFloat(List vector, int value) { * Generate a resulting byte vector with a distance from the original vector that is proportional * to the value in input (higher the value, higher the distance from the original vector) * - * @param vector - * @param value - * @return + * @param vector a numerical vector + * @param value a numerical value to be added to the first element of the vector + * @return a numerical vector that has a distance from the input vector, proportional to the value */ private static List outDistanceByte(List vector, int value) { List result = new ArrayList<>(vector.size()); From d40def2af820b6002dee7ba4a398c880cc6fd047 Mon Sep 17 00:00:00 2001 From: Alessandro Benedetti Date: Mon, 31 Mar 2025 12:33:19 +0100 Subject: [PATCH 09/43] first draft --- .../apache/solr/schema/DenseVectorField.java | 24 +++++++++ .../search/join/BlockJoinChildQParser.java | 4 +- .../search/join/BlockJoinParentQParser.java | 51 +++++++++++++++++-- .../solr/search/join/FiltersQParser.java | 2 +- .../apache/solr/search/neural/KnnQParser.java | 10 +++- .../apache/solr/update/AddUpdateCommand.java | 13 ++++- .../solr/search/join/BJQParserTest.java | 5 +- 7 files changed, 96 insertions(+), 13 deletions(-) diff --git a/solr/core/src/java/org/apache/solr/schema/DenseVectorField.java b/solr/core/src/java/org/apache/solr/schema/DenseVectorField.java index 773c1e6337d1..a3dfb8f9ae95 100644 --- a/solr/core/src/java/org/apache/solr/schema/DenseVectorField.java +++ b/solr/core/src/java/org/apache/solr/schema/DenseVectorField.java @@ -40,10 +40,15 @@ import org.apache.lucene.search.KnnFloatVectorQuery; import org.apache.lucene.search.Query; import org.apache.lucene.search.SortField; +import org.apache.lucene.search.join.BitSetProducer; +import org.apache.lucene.search.join.DiversifyingChildrenByteKnnVectorQuery; +import org.apache.lucene.search.join.DiversifyingChildrenFloatKnnVectorQuery; import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.hnsw.HnswGraph; import org.apache.solr.common.SolrException; +import org.apache.solr.request.SolrQueryRequest; import org.apache.solr.search.QParser; +import org.apache.solr.search.join.BlockJoinParentQParser; import org.apache.solr.uninverting.UninvertingReader; import org.apache.solr.util.vector.ByteDenseVectorParser; import org.apache.solr.util.vector.DenseVectorParser; @@ -384,6 +389,25 @@ public Query getKnnVectorQuery( } } + public Query getMultiValuedKnnVectorQuery(final SolrQueryRequest request, + String fieldName, String vectorToSearch, int topK, Query filterQuery) { + + DenseVectorParser vectorBuilder = + getVectorBuilder(vectorToSearch, DenseVectorParser.BuilderPhase.QUERY); + + BitSetProducer acceptedDocuments = BlockJoinParentQParser.getCachedBitSetProducer(request, filterQuery); + switch (vectorEncoding) { + case FLOAT32: + return new DiversifyingChildrenFloatKnnVectorQuery(fieldName, vectorBuilder.getFloatVector(), null, topK, acceptedDocuments); + case BYTE: + return new DiversifyingChildrenByteKnnVectorQuery(fieldName, vectorBuilder.getByteVector(), null, topK, acceptedDocuments); + default: + throw new SolrException( + SolrException.ErrorCode.SERVER_ERROR, + "Unexpected state. Vector Encoding: " + vectorEncoding); + } + } + /** * Not Supported. Please use the {!knn} query parser to run K nearest neighbors search queries. */ diff --git a/solr/core/src/java/org/apache/solr/search/join/BlockJoinChildQParser.java b/solr/core/src/java/org/apache/solr/search/join/BlockJoinChildQParser.java index bb6c80db07a8..2ac81ca7d95b 100644 --- a/solr/core/src/java/org/apache/solr/search/join/BlockJoinChildQParser.java +++ b/solr/core/src/java/org/apache/solr/search/join/BlockJoinChildQParser.java @@ -33,8 +33,8 @@ public BlockJoinChildQParser( } @Override - protected Query createQuery(Query parentListQuery, Query query, String scoreMode) { - return new ToChildBlockJoinQuery(query, getBitSetProducer(parentListQuery)); + protected Query createQuery(Query allParents, BooleanQuery query, String scoreMode) { + return new ToChildBlockJoinQuery(query, getBitSetProducer(allParents)); } @Override diff --git a/solr/core/src/java/org/apache/solr/search/join/BlockJoinParentQParser.java b/solr/core/src/java/org/apache/solr/search/join/BlockJoinParentQParser.java index 67a38d0fc57a..348ff632327f 100644 --- a/solr/core/src/java/org/apache/solr/search/join/BlockJoinParentQParser.java +++ b/solr/core/src/java/org/apache/solr/search/join/BlockJoinParentQParser.java @@ -18,17 +18,25 @@ import java.io.IOException; import java.io.UncheckedIOException; +import java.util.List; +import java.util.Map; import java.util.Objects; import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.search.BooleanClause; +import org.apache.lucene.search.BooleanQuery; import org.apache.lucene.search.ConstantScoreScorer; import org.apache.lucene.search.ConstantScoreWeight; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.KnnByteVectorQuery; +import org.apache.lucene.search.KnnFloatVectorQuery; import org.apache.lucene.search.Query; import org.apache.lucene.search.QueryVisitor; import org.apache.lucene.search.Scorer; import org.apache.lucene.search.Weight; import org.apache.lucene.search.join.BitSetProducer; +import org.apache.lucene.search.join.DiversifyingChildrenByteKnnVectorQuery; +import org.apache.lucene.search.join.DiversifyingChildrenFloatKnnVectorQuery; import org.apache.lucene.search.join.QueryBitSetProducer; import org.apache.lucene.search.join.ScoreMode; import org.apache.lucene.search.join.ToParentBlockJoinQuery; @@ -38,6 +46,7 @@ import org.apache.solr.request.SolrQueryRequest; import org.apache.solr.search.ExtendedQueryBase; import org.apache.solr.search.QParser; +import org.apache.solr.search.QueryUtils; import org.apache.solr.search.SolrCache; import org.apache.solr.search.SyntaxError; @@ -67,7 +76,7 @@ protected Query parseParentFilter() throws SyntaxError { } @Override - protected Query wrapSubordinateClause(Query subordinate) throws SyntaxError { + protected Query wrapSubordinateClause(BooleanQuery subordinate) throws SyntaxError { String scoreMode = localParams.get("score", ScoreMode.None.name()); Query parentQ = parseParentFilter(); return createQuery(parentQ, subordinate, scoreMode); @@ -78,10 +87,42 @@ protected Query noClausesQuery() throws SyntaxError { return new BitSetProducerQuery(getBitSetProducer(parseParentFilter())); } - protected Query createQuery(final Query parentList, Query query, String scoreMode) - throws SyntaxError { - return new AllParentsAware( - query, getBitSetProducer(parentList), ScoreModeParser.parse(scoreMode), parentList); + protected Query createQuery(final Query allParents, BooleanQuery query, String scoreMode) + throws SyntaxError { + List clauses = query.clauses(); + if (clauses.size() == 1 && clauses.get(0).getQuery().getClass().equals(KnnByteVectorQuery.class)) { + Query acceptedParents = getAcceptedParents(allParents); + KnnByteVectorQuery childQuery = (KnnByteVectorQuery) clauses.get(0).getQuery(); + String vectorField = childQuery.getField(); + byte[] queryVector = childQuery.getTargetCopy(); + int topK = childQuery.getK(); + BitSetProducer parentFilter = getBitSetProducer(acceptedParents); + Query childrenFilter = childQuery.getFilter(); + return new DiversifyingChildrenByteKnnVectorQuery(vectorField, queryVector, childrenFilter, topK, parentFilter); + } else if (clauses.size() == 1 && clauses.get(0).getQuery().getClass().equals(KnnFloatVectorQuery.class)) { + Query acceptedParents = getAcceptedParents(allParents); + KnnFloatVectorQuery childQuery = (KnnFloatVectorQuery) clauses.get(0).getQuery(); + String vectorField = childQuery.getField(); + float[] queryVector = childQuery.getTargetCopy(); + int topK = childQuery.getK(); + BitSetProducer parentFilter = getBitSetProducer(acceptedParents); + Query childrenFilter = childQuery.getFilter(); + return new DiversifyingChildrenFloatKnnVectorQuery(vectorField, queryVector, childrenFilter, topK, parentFilter); + } else { + return new AllParentsAware( + query, getBitSetProducer(allParents), ScoreModeParser.parse(scoreMode), allParents); + } + } + + private Query getAcceptedParents(Query allParents) throws SyntaxError { + List parentFilterQueries = QueryUtils.parseFilterQueries(req); + BooleanQuery.Builder acceptedParentsBuilder = createBuilder(); + for (Query filter:parentFilterQueries) { + acceptedParentsBuilder.add(filter, BooleanClause.Occur.MUST); + } + acceptedParentsBuilder.add(allParents, BooleanClause.Occur.MUST); + Query acceptedParents = acceptedParentsBuilder.build(); + return acceptedParents; } BitSetProducer getBitSetProducer(Query query) { diff --git a/solr/core/src/java/org/apache/solr/search/join/FiltersQParser.java b/solr/core/src/java/org/apache/solr/search/join/FiltersQParser.java index 05c705aa1ce1..45036ebffece 100644 --- a/solr/core/src/java/org/apache/solr/search/join/FiltersQParser.java +++ b/solr/core/src/java/org/apache/solr/search/join/FiltersQParser.java @@ -73,7 +73,7 @@ protected Query unwrapQuery(Query query, BooleanClause.Occur occur) { return query; } - protected Query wrapSubordinateClause(Query subordinate) throws SyntaxError { + protected Query wrapSubordinateClause(BooleanQuery subordinate) throws SyntaxError { return subordinate; } diff --git a/solr/core/src/java/org/apache/solr/search/neural/KnnQParser.java b/solr/core/src/java/org/apache/solr/search/neural/KnnQParser.java index b6d9f2541cd0..f917c7731368 100644 --- a/solr/core/src/java/org/apache/solr/search/neural/KnnQParser.java +++ b/solr/core/src/java/org/apache/solr/search/neural/KnnQParser.java @@ -40,7 +40,13 @@ public Query parse() throws SyntaxError { final String vectorToSearch = getVectorToSearch(); final int topK = localParams.getInt(TOP_K, DEFAULT_TOP_K); - return denseVectorType.getKnnVectorQuery( - schemaField.getName(), vectorToSearch, topK, getFilterQuery()); + if(schemaField.multiValued()){ + return denseVectorType.getMultiValuedKnnVectorQuery(req, + schemaField.getName(), vectorToSearch, topK, getFilterQuery()); + } else { + return denseVectorType.getKnnVectorQuery( + schemaField.getName(), vectorToSearch, topK, getFilterQuery()); + } + } } diff --git a/solr/core/src/java/org/apache/solr/update/AddUpdateCommand.java b/solr/core/src/java/org/apache/solr/update/AddUpdateCommand.java index 5ca176ea66ae..7b3c86ac4c54 100644 --- a/solr/core/src/java/org/apache/solr/update/AddUpdateCommand.java +++ b/solr/core/src/java/org/apache/solr/update/AddUpdateCommand.java @@ -29,6 +29,7 @@ import org.apache.solr.common.SolrInputField; import org.apache.solr.common.params.CommonParams; import org.apache.solr.request.SolrQueryRequest; +import org.apache.solr.schema.DenseVectorField; import org.apache.solr.schema.IndexSchema; import org.apache.solr.schema.SchemaField; @@ -256,7 +257,9 @@ private List flatten(SolrInputDocument root) { /** Extract all child documents from parent that are saved in fields */ private void flattenLabelled( List unwrappedDocs, SolrInputDocument currentDoc, boolean isRoot) { + IndexSchema schema = req.getSchema(); for (SolrInputField field : currentDoc.values()) { + SchemaField sfield = schema.getFieldOrNull(field.getName()); Object value = field.getFirstValue(); // check if value is a childDocument if (value instanceof SolrInputDocument) { @@ -270,7 +273,15 @@ private void flattenLabelled( for (SolrInputDocument child : childrenList) { flattenLabelled(unwrappedDocs, child); } - } + } else if (sfield.getType() instanceof DenseVectorField && sfield.multiValued() && isRoot){ + Collection vectorValues = field.getValues(); + for(Object vectorValue:vectorValues){ + SolrInputDocument singleVectorNestedDoc = new SolrInputDocument(); + singleVectorNestedDoc.setField(field.getName(), vectorValue); + flattenLabelled(unwrappedDocs, singleVectorNestedDoc); + } + + } } if (!isRoot) unwrappedDocs.add(currentDoc); diff --git a/solr/core/src/test/org/apache/solr/search/join/BJQParserTest.java b/solr/core/src/test/org/apache/solr/search/join/BJQParserTest.java index 5fcb7455187e..346ea01a8aa3 100644 --- a/solr/core/src/test/org/apache/solr/search/join/BJQParserTest.java +++ b/solr/core/src/test/org/apache/solr/search/join/BJQParserTest.java @@ -151,8 +151,9 @@ private static void addGrandChildren(List block) { @Test public void testFull() { - String childb = "{!parent which=\"parent_s:[* TO *]\"}child_s:l"; - assertQ(req("q", childb), sixParents); + //{!knn f=vector topK=10}[1.0, 2.0, 3.0, 4.0]&fq:acl:foo + String childb = "{!parent which=\"parent_s:[* TO *]\"}{!knn f=vector topK=10}[1.0, 2.0, 3.0, 4.0]"; + assertQ(req("q", childb,"fq","parent_s:a"), sixParents); } private static final String sixParents[] = From bb8cfdefc8d8a3a4aba9cd5689dbf1a52110a97c Mon Sep 17 00:00:00 2001 From: Alessandro Benedetti Date: Wed, 2 Apr 2025 09:55:21 +0100 Subject: [PATCH 10/43] Only Multi valued Vectors changes --- .../search/join/BlockJoinChildQParser.java | 4 +- .../search/join/BlockJoinParentQParser.java | 51 ++----------------- .../solr/search/join/FiltersQParser.java | 2 +- .../solr/search/join/BJQParserTest.java | 5 +- 4 files changed, 10 insertions(+), 52 deletions(-) diff --git a/solr/core/src/java/org/apache/solr/search/join/BlockJoinChildQParser.java b/solr/core/src/java/org/apache/solr/search/join/BlockJoinChildQParser.java index 2ac81ca7d95b..bb6c80db07a8 100644 --- a/solr/core/src/java/org/apache/solr/search/join/BlockJoinChildQParser.java +++ b/solr/core/src/java/org/apache/solr/search/join/BlockJoinChildQParser.java @@ -33,8 +33,8 @@ public BlockJoinChildQParser( } @Override - protected Query createQuery(Query allParents, BooleanQuery query, String scoreMode) { - return new ToChildBlockJoinQuery(query, getBitSetProducer(allParents)); + protected Query createQuery(Query parentListQuery, Query query, String scoreMode) { + return new ToChildBlockJoinQuery(query, getBitSetProducer(parentListQuery)); } @Override diff --git a/solr/core/src/java/org/apache/solr/search/join/BlockJoinParentQParser.java b/solr/core/src/java/org/apache/solr/search/join/BlockJoinParentQParser.java index 348ff632327f..67a38d0fc57a 100644 --- a/solr/core/src/java/org/apache/solr/search/join/BlockJoinParentQParser.java +++ b/solr/core/src/java/org/apache/solr/search/join/BlockJoinParentQParser.java @@ -18,25 +18,17 @@ import java.io.IOException; import java.io.UncheckedIOException; -import java.util.List; -import java.util.Map; import java.util.Objects; import org.apache.lucene.index.LeafReaderContext; -import org.apache.lucene.search.BooleanClause; -import org.apache.lucene.search.BooleanQuery; import org.apache.lucene.search.ConstantScoreScorer; import org.apache.lucene.search.ConstantScoreWeight; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.IndexSearcher; -import org.apache.lucene.search.KnnByteVectorQuery; -import org.apache.lucene.search.KnnFloatVectorQuery; import org.apache.lucene.search.Query; import org.apache.lucene.search.QueryVisitor; import org.apache.lucene.search.Scorer; import org.apache.lucene.search.Weight; import org.apache.lucene.search.join.BitSetProducer; -import org.apache.lucene.search.join.DiversifyingChildrenByteKnnVectorQuery; -import org.apache.lucene.search.join.DiversifyingChildrenFloatKnnVectorQuery; import org.apache.lucene.search.join.QueryBitSetProducer; import org.apache.lucene.search.join.ScoreMode; import org.apache.lucene.search.join.ToParentBlockJoinQuery; @@ -46,7 +38,6 @@ import org.apache.solr.request.SolrQueryRequest; import org.apache.solr.search.ExtendedQueryBase; import org.apache.solr.search.QParser; -import org.apache.solr.search.QueryUtils; import org.apache.solr.search.SolrCache; import org.apache.solr.search.SyntaxError; @@ -76,7 +67,7 @@ protected Query parseParentFilter() throws SyntaxError { } @Override - protected Query wrapSubordinateClause(BooleanQuery subordinate) throws SyntaxError { + protected Query wrapSubordinateClause(Query subordinate) throws SyntaxError { String scoreMode = localParams.get("score", ScoreMode.None.name()); Query parentQ = parseParentFilter(); return createQuery(parentQ, subordinate, scoreMode); @@ -87,42 +78,10 @@ protected Query noClausesQuery() throws SyntaxError { return new BitSetProducerQuery(getBitSetProducer(parseParentFilter())); } - protected Query createQuery(final Query allParents, BooleanQuery query, String scoreMode) - throws SyntaxError { - List clauses = query.clauses(); - if (clauses.size() == 1 && clauses.get(0).getQuery().getClass().equals(KnnByteVectorQuery.class)) { - Query acceptedParents = getAcceptedParents(allParents); - KnnByteVectorQuery childQuery = (KnnByteVectorQuery) clauses.get(0).getQuery(); - String vectorField = childQuery.getField(); - byte[] queryVector = childQuery.getTargetCopy(); - int topK = childQuery.getK(); - BitSetProducer parentFilter = getBitSetProducer(acceptedParents); - Query childrenFilter = childQuery.getFilter(); - return new DiversifyingChildrenByteKnnVectorQuery(vectorField, queryVector, childrenFilter, topK, parentFilter); - } else if (clauses.size() == 1 && clauses.get(0).getQuery().getClass().equals(KnnFloatVectorQuery.class)) { - Query acceptedParents = getAcceptedParents(allParents); - KnnFloatVectorQuery childQuery = (KnnFloatVectorQuery) clauses.get(0).getQuery(); - String vectorField = childQuery.getField(); - float[] queryVector = childQuery.getTargetCopy(); - int topK = childQuery.getK(); - BitSetProducer parentFilter = getBitSetProducer(acceptedParents); - Query childrenFilter = childQuery.getFilter(); - return new DiversifyingChildrenFloatKnnVectorQuery(vectorField, queryVector, childrenFilter, topK, parentFilter); - } else { - return new AllParentsAware( - query, getBitSetProducer(allParents), ScoreModeParser.parse(scoreMode), allParents); - } - } - - private Query getAcceptedParents(Query allParents) throws SyntaxError { - List parentFilterQueries = QueryUtils.parseFilterQueries(req); - BooleanQuery.Builder acceptedParentsBuilder = createBuilder(); - for (Query filter:parentFilterQueries) { - acceptedParentsBuilder.add(filter, BooleanClause.Occur.MUST); - } - acceptedParentsBuilder.add(allParents, BooleanClause.Occur.MUST); - Query acceptedParents = acceptedParentsBuilder.build(); - return acceptedParents; + protected Query createQuery(final Query parentList, Query query, String scoreMode) + throws SyntaxError { + return new AllParentsAware( + query, getBitSetProducer(parentList), ScoreModeParser.parse(scoreMode), parentList); } BitSetProducer getBitSetProducer(Query query) { diff --git a/solr/core/src/java/org/apache/solr/search/join/FiltersQParser.java b/solr/core/src/java/org/apache/solr/search/join/FiltersQParser.java index 45036ebffece..05c705aa1ce1 100644 --- a/solr/core/src/java/org/apache/solr/search/join/FiltersQParser.java +++ b/solr/core/src/java/org/apache/solr/search/join/FiltersQParser.java @@ -73,7 +73,7 @@ protected Query unwrapQuery(Query query, BooleanClause.Occur occur) { return query; } - protected Query wrapSubordinateClause(BooleanQuery subordinate) throws SyntaxError { + protected Query wrapSubordinateClause(Query subordinate) throws SyntaxError { return subordinate; } diff --git a/solr/core/src/test/org/apache/solr/search/join/BJQParserTest.java b/solr/core/src/test/org/apache/solr/search/join/BJQParserTest.java index 346ea01a8aa3..5fcb7455187e 100644 --- a/solr/core/src/test/org/apache/solr/search/join/BJQParserTest.java +++ b/solr/core/src/test/org/apache/solr/search/join/BJQParserTest.java @@ -151,9 +151,8 @@ private static void addGrandChildren(List block) { @Test public void testFull() { - //{!knn f=vector topK=10}[1.0, 2.0, 3.0, 4.0]&fq:acl:foo - String childb = "{!parent which=\"parent_s:[* TO *]\"}{!knn f=vector topK=10}[1.0, 2.0, 3.0, 4.0]"; - assertQ(req("q", childb,"fq","parent_s:a"), sixParents); + String childb = "{!parent which=\"parent_s:[* TO *]\"}child_s:l"; + assertQ(req("q", childb), sixParents); } private static final String sixParents[] = From 882eb9c9b97f2a68f746c41f30dc109a5fc898bf Mon Sep 17 00:00:00 2001 From: Alessandro Benedetti Date: Wed, 9 Apr 2025 11:26:55 +0100 Subject: [PATCH 11/43] first draft --- .../apache/solr/schema/DenseVectorField.java | 32 +- .../collection1/conf/schema-densevector.xml | 2 + .../KnnQParserMultiValuedVectorsTest.java | 971 ++++++++++++++++++ 3 files changed, 1002 insertions(+), 3 deletions(-) create mode 100644 solr/core/src/test/org/apache/solr/search/neural/KnnQParserMultiValuedVectorsTest.java diff --git a/solr/core/src/java/org/apache/solr/schema/DenseVectorField.java b/solr/core/src/java/org/apache/solr/schema/DenseVectorField.java index a3dfb8f9ae95..297d8841be56 100644 --- a/solr/core/src/java/org/apache/solr/schema/DenseVectorField.java +++ b/solr/core/src/java/org/apache/solr/schema/DenseVectorField.java @@ -19,6 +19,7 @@ import static java.util.Optional.ofNullable; import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH; import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN; +import static org.apache.solr.schema.IndexSchema.NEST_PATH_FIELD_NAME; import java.lang.invoke.MethodHandles; import java.util.ArrayList; @@ -36,13 +37,19 @@ import org.apache.lucene.queries.function.ValueSource; import org.apache.lucene.queries.function.valuesource.ByteKnnVectorFieldSource; import org.apache.lucene.queries.function.valuesource.FloatKnnVectorFieldSource; +import org.apache.lucene.search.BooleanClause; +import org.apache.lucene.search.BooleanQuery; +import org.apache.lucene.search.DocValuesFieldExistsQuery; import org.apache.lucene.search.KnnByteVectorQuery; import org.apache.lucene.search.KnnFloatVectorQuery; +import org.apache.lucene.search.MatchAllDocsQuery; import org.apache.lucene.search.Query; import org.apache.lucene.search.SortField; import org.apache.lucene.search.join.BitSetProducer; import org.apache.lucene.search.join.DiversifyingChildrenByteKnnVectorQuery; import org.apache.lucene.search.join.DiversifyingChildrenFloatKnnVectorQuery; +import org.apache.lucene.search.join.ScoreMode; +import org.apache.lucene.search.join.ToParentBlockJoinQuery; import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.hnsw.HnswGraph; import org.apache.solr.common.SolrException; @@ -390,22 +397,41 @@ public Query getKnnVectorQuery( } public Query getMultiValuedKnnVectorQuery(final SolrQueryRequest request, - String fieldName, String vectorToSearch, int topK, Query filterQuery) { + String fieldName, String vectorToSearch, int topK, Query filterQuery) { DenseVectorParser vectorBuilder = getVectorBuilder(vectorToSearch, DenseVectorParser.BuilderPhase.QUERY); + BooleanQuery allDocuments = + new BooleanQuery.Builder() + .add(new BooleanClause(new MatchAllDocsQuery(), BooleanClause.Occur.MUST)) + .add( + new BooleanClause( + new DocValuesFieldExistsQuery(NEST_PATH_FIELD_NAME), + BooleanClause.Occur.MUST_NOT)) + .build(); + BitSetProducer acceptedDocuments = BlockJoinParentQParser.getCachedBitSetProducer(request, filterQuery); + BitSetProducer allParentsBitSet = BlockJoinParentQParser.getCachedBitSetProducer(request, allDocuments); + + Query knnOnVectorField; switch (vectorEncoding) { case FLOAT32: - return new DiversifyingChildrenFloatKnnVectorQuery(fieldName, vectorBuilder.getFloatVector(), null, topK, acceptedDocuments); + knnOnVectorField = + new DiversifyingChildrenFloatKnnVectorQuery(fieldName, vectorBuilder.getFloatVector(), null, topK, acceptedDocuments); + break; case BYTE: - return new DiversifyingChildrenByteKnnVectorQuery(fieldName, vectorBuilder.getByteVector(), null, topK, acceptedDocuments); + knnOnVectorField = + new DiversifyingChildrenByteKnnVectorQuery(fieldName, vectorBuilder.getByteVector(), null, topK, acceptedDocuments); + break; default: throw new SolrException( SolrException.ErrorCode.SERVER_ERROR, "Unexpected state. Vector Encoding: " + vectorEncoding); } + + return new ToParentBlockJoinQuery( + knnOnVectorField, allParentsBitSet, ScoreMode.Max); } /** diff --git a/solr/core/src/test-files/solr/collection1/conf/schema-densevector.xml b/solr/core/src/test-files/solr/collection1/conf/schema-densevector.xml index 42db078a6e20..405ecbe2fe5d 100644 --- a/solr/core/src/test-files/solr/collection1/conf/schema-densevector.xml +++ b/solr/core/src/test-files/solr/collection1/conf/schema-densevector.xml @@ -28,8 +28,10 @@ + + diff --git a/solr/core/src/test/org/apache/solr/search/neural/KnnQParserMultiValuedVectorsTest.java b/solr/core/src/test/org/apache/solr/search/neural/KnnQParserMultiValuedVectorsTest.java new file mode 100644 index 000000000000..bf6e3dfa15e4 --- /dev/null +++ b/solr/core/src/test/org/apache/solr/search/neural/KnnQParserMultiValuedVectorsTest.java @@ -0,0 +1,971 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.solr.search.neural; + +import org.apache.solr.SolrTestCaseJ4; +import org.apache.solr.common.SolrException; +import org.apache.solr.common.SolrInputDocument; +import org.apache.solr.common.params.CommonParams; +import org.apache.solr.common.params.SolrParams; +import org.apache.solr.request.SolrQueryRequest; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + +import static org.apache.solr.search.neural.KnnQParser.DEFAULT_TOP_K; + +public class KnnQParserMultiValuedVectorsTest extends SolrTestCaseJ4 { + String IDField = "id"; + String vectorField = "vector"; + String vectorField2 = "vector2"; + String vectorFieldByteEncoding = "vector_byte_encoding"; + + @Before + public void prepareIndex() throws Exception { + /* vectorDimension="4" similarityFunction="cosine" */ + initCore("solrconfig_codec.xml", "schema-densevector.xml"); + + List docsToIndex = this.prepareDocs(); + for (SolrInputDocument doc : docsToIndex) { + assertU(adoc(doc)); + } + + assertU(commit()); + } + + private List prepareDocs() { + int docsCount = 13; + List docs = new ArrayList<>(docsCount); + for (int i = 1; i < docsCount + 1; i++) { + SolrInputDocument doc = new SolrInputDocument(); + doc.addField(IDField, i); + docs.add(doc); + } + + docs.get(0) + .addField(vectorField, Arrays.asList(1f, 2f, 3f, 4f)); // cosine distance vector1= 1.0 + docs.get(1) + .addField( + vectorField, Arrays.asList(1.5f, 2.5f, 3.5f, 4.5f)); // cosine distance vector1= 0.998 + docs.get(2) + .addField( + vectorField, + Arrays.asList(7.5f, 15.5f, 17.5f, 22.5f)); // cosine distance vector1= 0.992 + docs.get(3) + .addField( + vectorField, Arrays.asList(1.4f, 2.4f, 3.4f, 4.4f)); // cosine distance vector1= 0.999 + docs.get(4) + .addField(vectorField, Arrays.asList(30f, 22f, 35f, 20f)); // cosine distance vector1= 0.862 + docs.get(5) + .addField(vectorField, Arrays.asList(40f, 1f, 1f, 200f)); // cosine distance vector1= 0.756 + docs.get(6) + .addField(vectorField, Arrays.asList(5f, 10f, 20f, 40f)); // cosine distance vector1= 0.970 + docs.get(7) + .addField( + vectorField, Arrays.asList(120f, 60f, 30f, 15f)); // cosine distance vector1= 0.515 + docs.get(8) + .addField( + vectorField, Arrays.asList(200f, 50f, 100f, 25f)); // cosine distance vector1= 0.554 + docs.get(9) + .addField( + vectorField, Arrays.asList(1.8f, 2.5f, 3.7f, 4.9f)); // cosine distance vector1= 0.997 + docs.get(10) + .addField(vectorField2, Arrays.asList(1f, 2f, 3f, 4f)); // cosine distance vector2= 1 + docs.get(11) + .addField( + vectorField2, + Arrays.asList(7.5f, 15.5f, 17.5f, 22.5f)); // cosine distance vector2= 0.992 + docs.get(12) + .addField( + vectorField2, Arrays.asList(1.5f, 2.5f, 3.5f, 4.5f)); // cosine distance vector2= 0.998 + + docs.get(0).addField(vectorFieldByteEncoding, Arrays.asList(1, 2, 3, 4)); + docs.get(1).addField(vectorFieldByteEncoding, Arrays.asList(2, 2, 1, 4)); + docs.get(2).addField(vectorFieldByteEncoding, Arrays.asList(1, 2, 1, 2)); + docs.get(3).addField(vectorFieldByteEncoding, Arrays.asList(7, 2, 1, 3)); + docs.get(4).addField(vectorFieldByteEncoding, Arrays.asList(19, 2, 4, 4)); + docs.get(5).addField(vectorFieldByteEncoding, Arrays.asList(19, 2, 4, 4)); + docs.get(6).addField(vectorFieldByteEncoding, Arrays.asList(18, 2, 4, 4)); + docs.get(7).addField(vectorFieldByteEncoding, Arrays.asList(8, 3, 2, 4)); + + return docs; + } + + @After + public void cleanUp() { + clearIndex(); + deleteCore(); + } + + @Test + public void incorrectTopK_shouldThrowException() { + String vectorToSearch = "[1.0, 2.0, 3.0, 4.0]"; + + assertQEx( + "String topK should throw Exception", + "For input string: \"string\"", + req(CommonParams.Q, "{!knn f=vector topK=string}" + vectorToSearch, "fl", "id"), + SolrException.ErrorCode.BAD_REQUEST); + + assertQEx( + "Double topK should throw Exception", + "For input string: \"4.5\"", + req(CommonParams.Q, "{!knn f=vector topK=4.5}" + vectorToSearch, "fl", "id"), + SolrException.ErrorCode.BAD_REQUEST); + } + + @Test + public void topKMissing_shouldReturnDefaultTopK() { + String vectorToSearch = "[1.0, 2.0, 3.0, 4.0]"; + + assertQ( + req(CommonParams.Q, "{!knn f=vector}" + vectorToSearch, "fl", "id"), + "//result[@numFound='" + DEFAULT_TOP_K + "']", + "//result/doc[1]/str[@name='id'][.='1']", + "//result/doc[2]/str[@name='id'][.='4']", + "//result/doc[3]/str[@name='id'][.='2']", + "//result/doc[4]/str[@name='id'][.='10']", + "//result/doc[5]/str[@name='id'][.='3']", + "//result/doc[6]/str[@name='id'][.='7']", + "//result/doc[7]/str[@name='id'][.='5']", + "//result/doc[8]/str[@name='id'][.='6']", + "//result/doc[9]/str[@name='id'][.='9']", + "//result/doc[10]/str[@name='id'][.='8']"); + } + + @Test + public void topK_shouldReturnOnlyTopKResults() { + String vectorToSearch = "[1.0, 2.0, 3.0, 4.0]"; + + assertQ( + req(CommonParams.Q, "{!knn f=vector topK=5}" + vectorToSearch, "fl", "id"), + "//result[@numFound='5']", + "//result/doc[1]/str[@name='id'][.='1']", + "//result/doc[2]/str[@name='id'][.='4']", + "//result/doc[3]/str[@name='id'][.='2']", + "//result/doc[4]/str[@name='id'][.='10']", + "//result/doc[5]/str[@name='id'][.='3']"); + + assertQ( + req(CommonParams.Q, "{!knn f=vector topK=3}" + vectorToSearch, "fl", "id"), + "//result[@numFound='3']", + "//result/doc[1]/str[@name='id'][.='1']", + "//result/doc[2]/str[@name='id'][.='4']", + "//result/doc[3]/str[@name='id'][.='2']"); + } + + @Test + public void incorrectVectorFieldType_shouldThrowException() { + String vectorToSearch = "[1.0, 2.0, 3.0, 4.0]"; + + assertQEx( + "Incorrect vector field type should throw Exception", + "only DenseVectorField is compatible with Vector Query Parsers", + req(CommonParams.Q, "{!knn f=id topK=10}" + vectorToSearch, "fl", "id"), + SolrException.ErrorCode.BAD_REQUEST); + } + + @Test + public void undefinedVectorField_shouldThrowException() { + String vectorToSearch = "[1.0, 2.0, 3.0, 4.0]"; + + assertQEx( + "Undefined vector field should throw Exception", + "undefined field: \"notExistent\"", + req(CommonParams.Q, "{!knn f=notExistent topK=10}" + vectorToSearch, "fl", "id"), + SolrException.ErrorCode.BAD_REQUEST); + } + + @Test + public void missingVectorField_shouldThrowException() { + String vectorToSearch = "[1.0, 2.0, 3.0, 4.0]"; + + assertQEx( + "missing vector field should throw Exception", + "the Dense Vector field 'f' is missing", + req(CommonParams.Q, "{!knn topK=10}" + vectorToSearch, "fl", "id"), + SolrException.ErrorCode.BAD_REQUEST); + } + + @Test + public void correctVectorField_shouldSearchOnThatField() { + String vectorToSearch = "[1.0, 2.0, 3.0, 4.0]"; + + assertQ( + req(CommonParams.Q, "{!knn f=vector2 topK=5}" + vectorToSearch, "fl", "id"), + "//result[@numFound='3']", + "//result/doc[1]/str[@name='id'][.='11']", + "//result/doc[2]/str[@name='id'][.='13']", + "//result/doc[3]/str[@name='id'][.='12']"); + } + + @Test + public void highDimensionFloatVectorField_shouldSearchOnThatField() { + int highDimension = 2048; + List docsToIndex = this.prepareHighDimensionFloatVectorsDocs(highDimension); + for (SolrInputDocument doc : docsToIndex) { + assertU(adoc(doc)); + } + assertU(commit()); + + float[] highDimensionalityQueryVector = new float[highDimension]; + for (int i = 0; i < highDimension; i++) { + highDimensionalityQueryVector[i] = i; + } + String vectorToSearch = Arrays.toString(highDimensionalityQueryVector); + + assertQ( + req(CommonParams.Q, "{!knn f=2048_float_vector topK=1}" + vectorToSearch, "fl", "id"), + "//result[@numFound='1']", + "//result/doc[1]/str[@name='id'][.='1']"); + } + + @Test + public void highDimensionByteVectorField_shouldSearchOnThatField() { + int highDimension = 2048; + List docsToIndex = this.prepareHighDimensionByteVectorsDocs(highDimension); + for (SolrInputDocument doc : docsToIndex) { + assertU(adoc(doc)); + } + assertU(commit()); + + byte[] highDimensionalityQueryVector = new byte[highDimension]; + for (int i = 0; i < highDimension; i++) { + highDimensionalityQueryVector[i] = (byte) (i % 127); + } + String vectorToSearch = Arrays.toString(highDimensionalityQueryVector); + + assertQ( + req(CommonParams.Q, "{!knn f=2048_byte_vector topK=1}" + vectorToSearch, "fl", "id"), + "//result[@numFound='1']", + "//result/doc[1]/str[@name='id'][.='1']"); + } + + private List prepareHighDimensionFloatVectorsDocs(int highDimension) { + int docsCount = 13; + String field = "2048_float_vector"; + List docs = new ArrayList<>(docsCount); + + for (int i = 1; i < docsCount + 1; i++) { + SolrInputDocument doc = new SolrInputDocument(); + doc.addField(IDField, i); + docs.add(doc); + } + + for (int i = 0; i < docsCount; i++) { + List highDimensionalityVector = new ArrayList<>(); + for (int j = i * highDimension; j < highDimension; j++) { + highDimensionalityVector.add(j); + } + docs.get(i).addField(field, highDimensionalityVector); + } + Collections.reverse(docs); + return docs; + } + + private List prepareHighDimensionByteVectorsDocs(int highDimension) { + int docsCount = 13; + String field = "2048_byte_vector"; + List docs = new ArrayList<>(docsCount); + + for (int i = 1; i < docsCount + 1; i++) { + SolrInputDocument doc = new SolrInputDocument(); + doc.addField(IDField, i); + docs.add(doc); + } + + for (int i = 0; i < docsCount; i++) { + List highDimensionalityVector = new ArrayList<>(); + for (int j = i * highDimension; j < highDimension; j++) { + highDimensionalityVector.add(j % 127); + } + docs.get(i).addField(field, highDimensionalityVector); + } + Collections.reverse(docs); + return docs; + } + + @Test + public void vectorByteEncodingField_shouldSearchOnThatField() { + String vectorToSearch = "[2, 2, 1, 3]"; + + assertQ( + req(CommonParams.Q, "{!knn f=vector_byte_encoding topK=2}" + vectorToSearch, "fl", "id"), + "//result[@numFound='2']", + "//result/doc[1]/str[@name='id'][.='2']", + "//result/doc[2]/str[@name='id'][.='3']"); + + vectorToSearch = "[8, 3, 2, 4]"; + + assertQ( + req(CommonParams.Q, "{!knn f=vector_byte_encoding topK=2}" + vectorToSearch, "fl", "id"), + "//result[@numFound='2']", + "//result/doc[1]/str[@name='id'][.='8']", + "//result/doc[2]/str[@name='id'][.='4']"); + } + + @Test + public void vectorByteEncodingField_shouldRaiseExceptionIfQueryUsesFloatVectors() { + String vectorToSearch = "[8.3, 4.3, 2.1, 4.1]"; + + assertQEx( + "incorrect vector element: '8.3'. The expected format is:'[b1,b2..b3]' where each element b is a byte (-128 to 127)", + "incorrect vector element: '8.3'. The expected format is:'[b1,b2..b3]' where each element b is a byte (-128 to 127)", + req(CommonParams.Q, "{!knn f=vector_byte_encoding topK=10}" + vectorToSearch, "fl", "id"), + SolrException.ErrorCode.BAD_REQUEST); + } + + @Test + public void + vectorByteEncodingField_shouldRaiseExceptionWhenQueryContainsValuesOutsideByteValueRange() { + String vectorToSearch = "[1, -129, 3, 5]"; + + assertQEx( + "incorrect vector element: ' -129'. The expected format is:'[b1,b2..b3]' where each element b is a byte (-128 to 127)", + "incorrect vector element: ' -129'. The expected format is:'[b1,b2..b3]' where each element b is a byte (-128 to 127)", + req(CommonParams.Q, "{!knn f=vector_byte_encoding topK=10}" + vectorToSearch, "fl", "id"), + SolrException.ErrorCode.BAD_REQUEST); + + vectorToSearch = "[1, 3, 156, 5]"; + + assertQEx( + "incorrect vector element: ' 156'. The expected format is:'[b1,b2..b3]' where each element b is a byte (-128 to 127)", + "incorrect vector element: ' 156'. The expected format is:'[b1,b2..b3]' where each element b is a byte (-128 to 127)", + req(CommonParams.Q, "{!knn f=vector_byte_encoding topK=10}" + vectorToSearch, "fl", "id"), + SolrException.ErrorCode.BAD_REQUEST); + } + + @Test + public void missingVectorToSearch_shouldThrowException() { + assertQEx( + "missing vector to search should throw Exception", + "the Dense Vector value 'v' to search is missing", + req(CommonParams.Q, "{!knn f=vector topK=10}", "fl", "id"), + SolrException.ErrorCode.BAD_REQUEST); + } + + @Test + public void incorrectVectorToSearchDimension_shouldThrowException() { + String vectorToSearch = "[2.0, 4.4, 3.]"; + assertQEx( + "missing vector to search should throw Exception", + "incorrect vector dimension. The vector value has size 3 while it is expected a vector with size 4", + req(CommonParams.Q, "{!knn f=vector topK=10}" + vectorToSearch, "fl", "id"), + SolrException.ErrorCode.BAD_REQUEST); + + vectorToSearch = "[2.0, 4.4,,]"; + assertQEx( + "incorrect vector to search should throw Exception", + "incorrect vector dimension. The vector value has size 2 while it is expected a vector with size 4", + req(CommonParams.Q, "{!knn f=vector topK=10}" + vectorToSearch, "fl", "id"), + SolrException.ErrorCode.BAD_REQUEST); + } + + @Test + public void incorrectVectorToSearch_shouldThrowException() { + String vectorToSearch = "2.0, 4.4, 3.5, 6.4"; + assertQEx( + "incorrect vector to search should throw Exception", + "incorrect vector format. The expected format is:'[f1,f2..f3]' where each element f is a float", + req(CommonParams.Q, "{!knn f=vector topK=10}" + vectorToSearch, "fl", "id"), + SolrException.ErrorCode.BAD_REQUEST); + + vectorToSearch = "[2.0, 4.4, 3.5, 6.4"; + assertQEx( + "incorrect vector to search should throw Exception", + "incorrect vector format. The expected format is:'[f1,f2..f3]' where each element f is a float", + req(CommonParams.Q, "{!knn f=vector topK=10}" + vectorToSearch, "fl", "id"), + SolrException.ErrorCode.BAD_REQUEST); + + vectorToSearch = "2.0, 4.4, 3.5, 6.4]"; + assertQEx( + "incorrect vector to search should throw Exception", + "incorrect vector format. The expected format is:'[f1,f2..f3]' where each element f is a float", + req(CommonParams.Q, "{!knn f=vector topK=10}" + vectorToSearch, "fl", "id"), + SolrException.ErrorCode.BAD_REQUEST); + + vectorToSearch = "[2.0, 4.4, 3.5, stringElement]"; + assertQEx( + "incorrect vector to search should throw Exception", + "incorrect vector element: ' stringElement'. The expected format is:'[f1,f2..f3]' where each element f is a float", + req(CommonParams.Q, "{!knn f=vector topK=10}" + vectorToSearch, "fl", "id"), + SolrException.ErrorCode.BAD_REQUEST); + + vectorToSearch = "[2.0, 4.4, , ]"; + assertQEx( + "incorrect vector to search should throw Exception", + "incorrect vector element: ' '. The expected format is:'[f1,f2..f3]' where each element f is a float", + req(CommonParams.Q, "{!knn f=vector topK=10}" + vectorToSearch, "fl", "id"), + SolrException.ErrorCode.BAD_REQUEST); + } + + @Test + public void correctQuery_shouldRankBySimilarityFunction() { + String vectorToSearch = "[1.0, 2.0, 3.0, 4.0]"; + + assertQ( + req(CommonParams.Q, "{!knn f=vector topK=10}" + vectorToSearch, "fl", "id"), + "//result[@numFound='10']", + "//result/doc[1]/str[@name='id'][.='1']", + "//result/doc[2]/str[@name='id'][.='4']", + "//result/doc[3]/str[@name='id'][.='2']", + "//result/doc[4]/str[@name='id'][.='10']", + "//result/doc[5]/str[@name='id'][.='3']", + "//result/doc[6]/str[@name='id'][.='7']", + "//result/doc[7]/str[@name='id'][.='5']", + "//result/doc[8]/str[@name='id'][.='6']", + "//result/doc[9]/str[@name='id'][.='9']", + "//result/doc[10]/str[@name='id'][.='8']"); + } + + @Test + public void knnQueryUsedInFilter_shouldFilterResultsBeforeTheQueryExecution() { + String vectorToSearch = "[1.0, 2.0, 3.0, 4.0]"; + assertQ( + req( + CommonParams.Q, + "id:(3 4 9 2)", + "fq", + "{!knn f=vector topK=4}" + vectorToSearch, + "fl", + "id"), + "//result[@numFound='2']", + "//result/doc[1]/str[@name='id'][.='2']", + "//result/doc[2]/str[@name='id'][.='4']"); + } + + @Test + public void knnQueryUsedInFilters_shouldFilterResultsBeforeTheQueryExecution() { + String vectorToSearch = "[1.0, 2.0, 3.0, 4.0]"; + + // topK=4 -> 1,4,2,10 + assertQ( + req( + CommonParams.Q, + "id:(3 4 9 2)", + "fq", + "{!knn f=vector topK=4}" + vectorToSearch, + "fq", + "id:(4 20 9)", + "fl", + "id"), + "//result[@numFound='1']", + "//result/doc[1]/str[@name='id'][.='4']"); + } + + @Test + public void knnQueryUsedInFiltersWithPreFilter_shouldFilterResultsBeforeTheQueryExecution() { + String vectorToSearch = "[1.0, 2.0, 3.0, 4.0]"; + + // topK=4 w/localparam preFilter -> 1,4,7,9 + assertQ( + req( + CommonParams.Q, + "id:(3 4 9 2)", + "fq", + "{!knn f=vector topK=4 preFilter='id:(1 4 7 8 9)'}" + vectorToSearch, + "fq", + "id:(4 20 9)", + "fl", + "id"), + "//result[@numFound='2']", + "//result/doc[1]/str[@name='id'][.='4']", + "//result/doc[2]/str[@name='id'][.='9']"); + } + + @Test + public void knnQueryUsedInFilters_rejectIncludeExclude() { + String vectorToSearch = "[1.0, 2.0, 3.0, 4.0]"; + + for (String fq : + Arrays.asList( + "{!knn f=vector topK=5 includeTags=xxx}" + vectorToSearch, + "{!knn f=vector topK=5 excludeTags=xxx}" + vectorToSearch)) { + assertQEx( + "fq={!knn...} incompatible with include/exclude localparams", + "used as a filter does not support", + req("q", "*:*", "fq", fq), + SolrException.ErrorCode.BAD_REQUEST); + } + } + + @Test + public void knnQueryAsSubQuery() { + final SolrParams common = params("fl", "id", "vec", "[1.0, 2.0, 3.0, 4.0]"); + final String filt = "id:(2 4 7 9 8 20 3)"; + + // When knn parser is a subquery, it should not pre-filter on any global fq params + // topK -> 1,4,2,10,3 -> fq -> 4,2,3 + assertQ( + req(common, "fq", filt, "q", "*:* AND {!knn f=vector topK=5 v=$vec}"), + "//result[@numFound='3']", + "//result/doc[1]/str[@name='id'][.='4']", + "//result/doc[2]/str[@name='id'][.='2']", + "//result/doc[3]/str[@name='id'][.='3']"); + // topK -> 1,4,2,10,3 + '8' -> fq -> 4,2,3,8 + assertQ( + req(common, "fq", filt, "q", "id:8^=0.01 OR {!knn f=vector topK=5 v=$vec}"), + "//result[@numFound='4']", + "//result/doc[1]/str[@name='id'][.='4']", + "//result/doc[2]/str[@name='id'][.='2']", + "//result/doc[3]/str[@name='id'][.='3']", + "//result/doc[4]/str[@name='id'][.='8']"); + } + + @Test + public void knnQueryAsSubQuery_withPreFilter() { + final SolrParams common = params("fl", "id", "vec", "[1.0, 2.0, 3.0, 4.0]"); + final String filt = "id:(2 4 7 9 8 20 3)"; + + // knn subquery should still accept `preFilter` local param + // filt -> topK -> 4,2,3,7,9 + assertQ( + req(common, "q", "*:* AND {!knn f=vector topK=5 preFilter='" + filt + "' v=$vec}"), + "//result[@numFound='5']", + "//result/doc[1]/str[@name='id'][.='4']", + "//result/doc[2]/str[@name='id'][.='2']", + "//result/doc[3]/str[@name='id'][.='3']", + "//result/doc[4]/str[@name='id'][.='7']", + "//result/doc[5]/str[@name='id'][.='9']"); + + // it should not pre-filter on any global fq params + // filt -> topK -> 4,2,3,7,9 -> fq -> 3,9 + assertQ( + req( + common, + "fq", + "id:(1 9 20 3 5 6 8)", + "q", + "*:* AND {!knn f=vector topK=5 preFilter='" + filt + "' v=$vec}"), + "//result[@numFound='2']", + "//result/doc[1]/str[@name='id'][.='3']", + "//result/doc[2]/str[@name='id'][.='9']"); + // filt -> topK -> 4,2,3,7,9 + '8' -> fq -> 8,3,9 + assertQ( + req( + common, + "fq", + "id:(1 9 20 3 5 6 8)", + "q", + "id:8^=100 OR {!knn f=vector topK=5 preFilter='" + filt + "' v=$vec}"), + "//result[@numFound='3']", + "//result/doc[1]/str[@name='id'][.='8']", + "//result/doc[2]/str[@name='id'][.='3']", + "//result/doc[3]/str[@name='id'][.='9']"); + } + + @Test + public void knnQueryAsSubQuery_rejectIncludeExclude() { + final SolrParams common = params("fl", "id", "vec", "[1.0, 2.0, 3.0, 4.0]"); + + for (String knn : + Arrays.asList( + "{!knn f=vector topK=5 includeTags=xxx v=$vec}", + "{!knn f=vector topK=5 excludeTags=xxx v=$vec}")) { + assertQEx( + "knn as subquery incompatible with include/exclude localparams", + "used as a sub-query does not support", + req(common, "q", "*:* OR " + knn), + SolrException.ErrorCode.BAD_REQUEST); + } + } + + @Test + public void knnQueryWithFilterQuery_singlePreFilterEquivilence() { + final String vectorToSearch = "[1.0, 2.0, 3.0, 4.0]"; + final SolrParams common = params("fl", "id"); + + // these requests should be equivalent + final String filt = "id:(1 2 7 20)"; + for (SolrQueryRequest req : + Arrays.asList( + req(common, "q", "{!knn f=vector topK=10}" + vectorToSearch, "fq", filt), + req(common, "q", "{!knn f=vector preFilter=\"" + filt + "\" topK=10}" + vectorToSearch), + req( + common, + "q", + "{!knn f=vector preFilter=$my_filt topK=10}" + vectorToSearch, + "my_filt", + filt))) { + assertQ( + req, + "//result[@numFound='3']", + "//result/doc[1]/str[@name='id'][.='1']", + "//result/doc[2]/str[@name='id'][.='2']", + "//result/doc[3]/str[@name='id'][.='7']"); + } + } + + @Test + public void knnQueryWithFilterQuery_multiPreFilterEquivilence() { + final String vectorToSearch = "[1.0, 2.0, 3.0, 4.0]"; + final SolrParams common = params("fl", "id"); + + // these requests should be equivalent + final String fx = "id:(3 4 9 2 1 )"; // 1 & 10 dropped from intersection + final String fy = "id:(3 4 9 2 10)"; + for (SolrQueryRequest req : + Arrays.asList( + req(common, "q", "{!knn f=vector topK=4}" + vectorToSearch, "fq", fx, "fq", fy), + req( + common, + "q", + "{!knn f=vector preFilter=\"" + + fx + + "\" preFilter=\"" + + fy + + "\" topK=4}" + + vectorToSearch), + req( + common, + "q", + "{!knn f=vector preFilter=$fx preFilter=$fy topK=4}" + vectorToSearch, + "fx", + fx, + "fy", + fy), + req( + common, + "q", + "{!knn f=vector preFilter=$multi_filt topK=4}" + vectorToSearch, + "multi_filt", + fx, + "multi_filt", + fy))) { + assertQ( + req, + "//result[@numFound='4']", + "//result/doc[1]/str[@name='id'][.='4']", + "//result/doc[2]/str[@name='id'][.='2']", + "//result/doc[3]/str[@name='id'][.='3']", + "//result/doc[4]/str[@name='id'][.='9']"); + } + } + + @Test + public void knnQueryWithPreFilter_rejectIncludeExclude() { + final String vectorToSearch = "[1.0, 2.0, 3.0, 4.0]"; + + assertQEx( + "knn preFilter localparm incompatible with include/exclude localparams", + "does not support combining preFilter localparam with either", + // shouldn't matter if global fq w/tag even exists, usage is an error + req("q", "{!knn f=vector preFilter='id:1' includeTags=xxx}" + vectorToSearch), + SolrException.ErrorCode.BAD_REQUEST); + assertQEx( + "knn preFilter localparm incompatible with include/exclude localparams", + "does not support combining preFilter localparam with either", + // shouldn't matter if global fq w/tag even exists, usage is an error + req("q", "{!knn f=vector preFilter='id:1' excludeTags=xxx}" + vectorToSearch), + SolrException.ErrorCode.BAD_REQUEST); + } + + @Test + public void knnQueryWithFilterQuery_preFilterLocalParamOverridesGlobalFilters() { + final String vectorToSearch = "[1.0, 2.0, 3.0, 4.0]"; + + // trivial case: empty preFilter localparam means no pre-filtering + assertQ( + req( + "q", "{!knn f=vector preFilter='' topK=5}" + vectorToSearch, + "fq", "-id:4", + "fl", "id"), + "//result[@numFound='4']", + "//result/doc[1]/str[@name='id'][.='1']", + "//result/doc[2]/str[@name='id'][.='2']", + "//result/doc[3]/str[@name='id'][.='10']", + "//result/doc[4]/str[@name='id'][.='3']"); + + // localparam prefiltering, global fqs applied independently + assertQ( + req( + "q", "{!knn f=vector preFilter='id:(3 4 9 2 7 8)' topK=5}" + vectorToSearch, + "fq", "-id:4", + "fl", "id"), + "//result[@numFound='4']", + "//result/doc[1]/str[@name='id'][.='2']", + "//result/doc[2]/str[@name='id'][.='3']", + "//result/doc[3]/str[@name='id'][.='7']", + "//result/doc[4]/str[@name='id'][.='9']"); + } + + @Test + public void knnQueryWithFilterQuery_localParamIncludeExcludeTags() { + final String vectorToSearch = "[1.0, 2.0, 3.0, 4.0]"; + final SolrParams common = + params( + "fl", "id", + "fq", "{!tag=xx,aa}id:(5 6 7 8 9 10)", + "fq", "{!tag=yy,aa}id:(1 2 3 4 5 6 7)"); + + // These req's are equivalent: pre-filter everything + // So only 7,6,5 are viable for topK=5 + for (SolrQueryRequest req : + Arrays.asList( + // default behavior is all fq's pre-filter, + req(common, "q", "{!knn f=vector topK=5}" + vectorToSearch), + // diff ways of explicitly requesting both fq params + req(common, "q", "{!knn f=vector includeTags=aa topK=5}" + vectorToSearch), + req( + common, + "q", + "{!knn f=vector includeTags=aa excludeTags='' topK=5}" + vectorToSearch), + req( + common, + "q", + "{!knn f=vector includeTags=aa excludeTags=bogus topK=5}" + vectorToSearch), + req( + common, + "q", + "{!knn f=vector includeTags=xx includeTags=yy topK=5}" + vectorToSearch), + req(common, "q", "{!knn f=vector includeTags=xx,yy,bogus topK=5}" + vectorToSearch))) { + assertQ( + req, + "//result[@numFound='3']", + "//result/doc[1]/str[@name='id'][.='7']", + "//result/doc[2]/str[@name='id'][.='5']", + "//result/doc[3]/str[@name='id'][.='6']"); + } + } + + @Test + public void knnQueryWithFilterQuery_localParamsDisablesAllPreFiltering() { + final String vectorToSearch = "[1.0, 2.0, 3.0, 4.0]"; + final SolrParams common = + params( + "fl", "id", + "fq", "{!tag=xx,aa}id:(5 6 7 8 9 10)", + "fq", "{!tag=yy,aa}id:(1 2 3 4 5 6 7)"); + + // These req's are equivalent: pre-filter nothing + // So 1,4,2,10,3,7 are the topK=6 + // Only 7 matches both of the the regular fq params + for (SolrQueryRequest req : + Arrays.asList( + // explicit local empty preFilter + req(common, "q", "{!knn f=vector preFilter='' topK=6}" + vectorToSearch), + // diff ways of explicitly including none of the global fq params + req(common, "q", "{!knn f=vector includeTags='' topK=6}" + vectorToSearch), + req(common, "q", "{!knn f=vector includeTags=bogus topK=6}" + vectorToSearch), + // diff ways of explicitly excluding all of the global fq params + req(common, "q", "{!knn f=vector excludeTags=aa topK=6}" + vectorToSearch), + req( + common, + "q", + "{!knn f=vector includeTags=aa excludeTags=aa topK=6}" + vectorToSearch), + req( + common, + "q", + "{!knn f=vector includeTags=aa excludeTags=xx,yy topK=6}" + vectorToSearch), + req( + common, + "q", + "{!knn f=vector includeTags=xx,yy excludeTags=aa topK=6}" + vectorToSearch), + req(common, "q", "{!knn f=vector excludeTags=xx,yy topK=6}" + vectorToSearch), + req(common, "q", "{!knn f=vector excludeTags=aa topK=6}" + vectorToSearch), + req( + common, + "q", + "{!knn f=vector excludeTags=xx excludeTags=yy topK=6}" + vectorToSearch), + req( + common, + "q", + "{!knn f=vector excludeTags=xx excludeTags=yy,bogus topK=6}" + vectorToSearch), + req(common, "q", "{!knn f=vector excludeTags=xx,yy,bogus topK=6}" + vectorToSearch))) { + assertQ(req, "//result[@numFound='1']", "//result/doc[1]/str[@name='id'][.='7']"); + } + } + + @Test + public void knnQueryWithFilterQuery_localParamCombinedIncludeExcludeTags() { + final String vectorToSearch = "[1.0, 2.0, 3.0, 4.0]"; + final SolrParams common = + params( + "fl", "id", + "fq", "{!tag=xx,aa}id:(5 6 7 8 9 10)", + "fq", "{!tag=yy,aa}id:(1 2 3 4 5 6 7)"); + + // These req's are equivalent: prefilter only the 'yy' fq + // So 1,4,2,3,7 are in the topK=5. + // Only 7 matches the regular 'xx' fq param + for (SolrQueryRequest req : + Arrays.asList( + // diff ways of only using the 'yy' filter + req(common, "q", "{!knn f=vector includeTags=yy,bogus topK=5}" + vectorToSearch), + req( + common, + "q", + "{!knn f=vector includeTags=yy excludeTags='' topK=5}" + vectorToSearch), + req(common, "q", "{!knn f=vector excludeTags=xx,bogus topK=5}" + vectorToSearch), + req( + common, + "q", + "{!knn f=vector includeTags=yy excludeTags=xx topK=5}" + vectorToSearch), + req( + common, + "q", + "{!knn f=vector includeTags=aa excludeTags=xx topK=5}" + vectorToSearch))) { + assertQ(req, "//result[@numFound='1']", "//result/doc[1]/str[@name='id'][.='7']"); + } + } + + @Test + public void knnQueryWithMultiSelectFaceting_excludeTags() { + // NOTE: faceting on id is not very realistic, + // but it confirms what we care about re:filters w/o needing extra fields. + final String facet_xpath = "//lst[@name='facet_fields']/lst[@name='id']/int"; + final String vectorToSearch = "[1.0, 2.0, 3.0, 4.0]"; + + final SolrParams common = + params( + "fl", "id", + "indent", "true", + "q", "{!knn f=vector topK=5 excludeTags=facet_click v=$vec}", + "vec", vectorToSearch, + // mimicing "inStock:true" + "fq", "-id:(2 3)", + "facet", "true", + "facet.mincount", "1", + "facet.field", "{!ex=facet_click}id"); + + // initial query, with basic pre-filter and facet counts + assertQ( + req(common), + "//result[@numFound='5']", + "//result/doc[1]/str[@name='id'][.='1']", + "//result/doc[2]/str[@name='id'][.='4']", + "//result/doc[3]/str[@name='id'][.='10']", + "//result/doc[4]/str[@name='id'][.='7']", + "//result/doc[5]/str[@name='id'][.='5']", + "*[count(" + facet_xpath + ")=5]", + facet_xpath + "[@name='1'][.='1']", + facet_xpath + "[@name='4'][.='1']", + facet_xpath + "[@name='10'][.='1']", + facet_xpath + "[@name='7'][.='1']", + facet_xpath + "[@name='5'][.='1']"); + + // drill down on a single facet constraint + // multi-select means facet counts shouldn't change + // (this proves the knn isn't pre-filtering on the 'facet_click' fq) + assertQ( + req(common, "fq", "{!tag=facet_click}id:(4)"), + "//result[@numFound='1']", + "//result/doc[1]/str[@name='id'][.='4']", + "*[count(" + facet_xpath + ")=5]", + facet_xpath + "[@name='1'][.='1']", + facet_xpath + "[@name='4'][.='1']", + facet_xpath + "[@name='10'][.='1']", + facet_xpath + "[@name='7'][.='1']", + facet_xpath + "[@name='5'][.='1']"); + + // drill down on an additional facet constraint + // multi-select means facet counts shouldn't change + // (this proves the knn isn't pre-filtering on the 'facet_click' fq) + assertQ( + req(common, "fq", "{!tag=facet_click}id:(4 5)"), + "//result[@numFound='2']", + "//result/doc[1]/str[@name='id'][.='4']", + "//result/doc[2]/str[@name='id'][.='5']", + "*[count(" + facet_xpath + ")=5]", + facet_xpath + "[@name='1'][.='1']", + facet_xpath + "[@name='4'][.='1']", + facet_xpath + "[@name='10'][.='1']", + facet_xpath + "[@name='7'][.='1']", + facet_xpath + "[@name='5'][.='1']"); + } + + @Test + public void knnQueryWithCostlyFq_shouldPerformKnnSearchWithPostFilter() { + String vectorToSearch = "[1.0, 2.0, 3.0, 4.0]"; + + assertQ( + req( + CommonParams.Q, + "{!knn f=vector topK=10}" + vectorToSearch, + "fq", + "{!frange cache=false l=0.99}$q", + "fl", + "*,score"), + "//result[@numFound='5']", + "//result/doc[1]/str[@name='id'][.='1']", + "//result/doc[2]/str[@name='id'][.='4']", + "//result/doc[3]/str[@name='id'][.='2']", + "//result/doc[4]/str[@name='id'][.='10']", + "//result/doc[5]/str[@name='id'][.='3']"); + } + + @Test + public void knnQueryWithFilterQueries_shouldPerformKnnSearchWithPreFiltersAndPostFilters() { + String vectorToSearch = "[1.0, 2.0, 3.0, 4.0]"; + + assertQ( + req( + CommonParams.Q, + "{!knn f=vector topK=4}" + vectorToSearch, + "fq", + "id:(3 4 9 2)", + "fq", + "{!frange cache=false l=0.99}$q", + "fl", + "id"), + "//result[@numFound='2']", + "//result/doc[1]/str[@name='id'][.='4']", + "//result/doc[2]/str[@name='id'][.='2']"); + } + + @Test + public void knnQueryWithNegativeFilterQuery_shouldPerformKnnSearchInPreFilteredResults() { + String vectorToSearch = "[1.0, 2.0, 3.0, 4.0]"; + assertQ( + req(CommonParams.Q, "{!knn f=vector topK=4}" + vectorToSearch, "fq", "-id:4", "fl", "id"), + "//result[@numFound='4']", + "//result/doc[1]/str[@name='id'][.='1']", + "//result/doc[2]/str[@name='id'][.='2']", + "//result/doc[3]/str[@name='id'][.='10']", + "//result/doc[4]/str[@name='id'][.='3']"); + } + + /** + * See {@link org.apache.solr.search.ReRankQParserPlugin.ReRankQueryRescorer#combine(float, + * boolean, float)}} for more details. + */ + @Test + public void knnQueryAsRerank_shouldAddSimilarityFunctionScore() { + String vectorToSearch = "[1.0, 2.0, 3.0, 4.0]"; + + assertQ( + req( + CommonParams.Q, + "id:(3 4 9 2)", + "rq", + "{!rerank reRankQuery=$rqq reRankDocs=4 reRankWeight=1}", + "rqq", + "{!knn f=vector topK=4}" + vectorToSearch, + "fl", + "id"), + "//result[@numFound='4']", + "//result/doc[1]/str[@name='id'][.='4']", + "//result/doc[2]/str[@name='id'][.='2']", + "//result/doc[3]/str[@name='id'][.='3']", + "//result/doc[4]/str[@name='id'][.='9']"); + } +} From 993f0852209f5b92b3cd478d4ee8d13e3ec6c96d Mon Sep 17 00:00:00 2001 From: Alessandro Benedetti Date: Wed, 9 Apr 2025 11:30:19 +0100 Subject: [PATCH 12/43] first draft --- .../KnnQParserMultiValuedVectorsTest.java | 151 +++++++++++------- 1 file changed, 92 insertions(+), 59 deletions(-) diff --git a/solr/core/src/test/org/apache/solr/search/neural/KnnQParserMultiValuedVectorsTest.java b/solr/core/src/test/org/apache/solr/search/neural/KnnQParserMultiValuedVectorsTest.java index bf6e3dfa15e4..d1c0704118a8 100644 --- a/solr/core/src/test/org/apache/solr/search/neural/KnnQParserMultiValuedVectorsTest.java +++ b/solr/core/src/test/org/apache/solr/search/neural/KnnQParserMultiValuedVectorsTest.java @@ -22,9 +22,13 @@ import org.apache.solr.common.params.CommonParams; import org.apache.solr.common.params.SolrParams; import org.apache.solr.request.SolrQueryRequest; +import org.apache.solr.util.RandomNoReverseMergePolicyFactory; import org.junit.After; import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.ClassRule; import org.junit.Test; +import org.junit.rules.TestRule; import java.util.ArrayList; import java.util.Arrays; @@ -34,82 +38,111 @@ import static org.apache.solr.search.neural.KnnQParser.DEFAULT_TOP_K; public class KnnQParserMultiValuedVectorsTest extends SolrTestCaseJ4 { - String IDField = "id"; - String vectorField = "vector"; - String vectorField2 = "vector2"; - String vectorFieldByteEncoding = "vector_byte_encoding"; + + @ClassRule + public static final TestRule noReverseMerge = RandomNoReverseMergePolicyFactory.createRule(); - @Before - public void prepareIndex() throws Exception { + @BeforeClass + public static void beforeClass() throws Exception { /* vectorDimension="4" similarityFunction="cosine" */ initCore("solrconfig_codec.xml", "schema-densevector.xml"); + prepareIndex(); + } - List docsToIndex = this.prepareDocs(); + public static void prepareIndex() throws Exception { + List docsToIndex = prepareDocs(); for (SolrInputDocument doc : docsToIndex) { assertU(adoc(doc)); } - assertU(commit()); } - private List prepareDocs() { - int docsCount = 13; - List docs = new ArrayList<>(docsCount); - for (int i = 1; i < docsCount + 1; i++) { + /** + * The documents in the index are 10 parents, with some parent level metadata and 30 nested + * documents (with vectors and children level metadata) Each parent document has 3 nested + * documents with vectors. + * + *

This allows to run knn queries both at parent/children level and using various pre-filters + * both for parent metadata and children. + * + * @return a list of documents to index + */ + private static List prepareDocs() { + int totalParentDocuments = 10; + int totalNestedVectors = 30; + int perParentChildren = totalNestedVectors / totalParentDocuments; + + final String[] klm = new String[] {"k", "l", "m"}; + final String[] abcdef = new String[] {"a", "b", "c", "d", "e", "f"}; + + List docs = new ArrayList<>(totalParentDocuments); + for (int i = 1; i < totalParentDocuments + 1; i++) { SolrInputDocument doc = new SolrInputDocument(); - doc.addField(IDField, i); + doc.setField("id", i); + doc.setField("parent_b", true); + + doc.setField("parent_s", abcdef[i % abcdef.length]); + List children = new ArrayList<>(perParentChildren); + + // nested vector documents have a distance from the query vector inversely proportional to + // their id + for (int j = 0; j < perParentChildren; j++) { + SolrInputDocument child = new SolrInputDocument(); + child.setField("id", i + "" + j); + child.setField("child_s", klm[i % klm.length]); + child.setField("vector", outDistanceFloat(FLOAT_QUERY_VECTOR, totalNestedVectors)); + child.setField("vector_byte", outDistanceByte(BYTE_QUERY_VECTOR, totalNestedVectors)); + totalNestedVectors--; // the higher the id of the nested document, lower the distance with + // the query vector + children.add(child); + } + doc.setField("vectors", children); docs.add(doc); } - docs.get(0) - .addField(vectorField, Arrays.asList(1f, 2f, 3f, 4f)); // cosine distance vector1= 1.0 - docs.get(1) - .addField( - vectorField, Arrays.asList(1.5f, 2.5f, 3.5f, 4.5f)); // cosine distance vector1= 0.998 - docs.get(2) - .addField( - vectorField, - Arrays.asList(7.5f, 15.5f, 17.5f, 22.5f)); // cosine distance vector1= 0.992 - docs.get(3) - .addField( - vectorField, Arrays.asList(1.4f, 2.4f, 3.4f, 4.4f)); // cosine distance vector1= 0.999 - docs.get(4) - .addField(vectorField, Arrays.asList(30f, 22f, 35f, 20f)); // cosine distance vector1= 0.862 - docs.get(5) - .addField(vectorField, Arrays.asList(40f, 1f, 1f, 200f)); // cosine distance vector1= 0.756 - docs.get(6) - .addField(vectorField, Arrays.asList(5f, 10f, 20f, 40f)); // cosine distance vector1= 0.970 - docs.get(7) - .addField( - vectorField, Arrays.asList(120f, 60f, 30f, 15f)); // cosine distance vector1= 0.515 - docs.get(8) - .addField( - vectorField, Arrays.asList(200f, 50f, 100f, 25f)); // cosine distance vector1= 0.554 - docs.get(9) - .addField( - vectorField, Arrays.asList(1.8f, 2.5f, 3.7f, 4.9f)); // cosine distance vector1= 0.997 - docs.get(10) - .addField(vectorField2, Arrays.asList(1f, 2f, 3f, 4f)); // cosine distance vector2= 1 - docs.get(11) - .addField( - vectorField2, - Arrays.asList(7.5f, 15.5f, 17.5f, 22.5f)); // cosine distance vector2= 0.992 - docs.get(12) - .addField( - vectorField2, Arrays.asList(1.5f, 2.5f, 3.5f, 4.5f)); // cosine distance vector2= 0.998 - - docs.get(0).addField(vectorFieldByteEncoding, Arrays.asList(1, 2, 3, 4)); - docs.get(1).addField(vectorFieldByteEncoding, Arrays.asList(2, 2, 1, 4)); - docs.get(2).addField(vectorFieldByteEncoding, Arrays.asList(1, 2, 1, 2)); - docs.get(3).addField(vectorFieldByteEncoding, Arrays.asList(7, 2, 1, 3)); - docs.get(4).addField(vectorFieldByteEncoding, Arrays.asList(19, 2, 4, 4)); - docs.get(5).addField(vectorFieldByteEncoding, Arrays.asList(19, 2, 4, 4)); - docs.get(6).addField(vectorFieldByteEncoding, Arrays.asList(18, 2, 4, 4)); - docs.get(7).addField(vectorFieldByteEncoding, Arrays.asList(8, 3, 2, 4)); - return docs; } + /** + * Generate a resulting float vector with a distance from the original vector that is proportional + * to the value in input (higher the value, higher the distance from the original vector) + * + * @param vector a numerical vector + * @param value a numerical value to be added to the first element of the vector + * @return a numerical vector that has a distance from the input vector, proportional to the value + */ + private static List outDistanceFloat(List vector, int value) { + List result = new ArrayList<>(vector.size()); + for (int i = 0; i < vector.size(); i++) { + if (i == 0) { + result.add(vector.get(i) + value); + } else { + result.add(vector.get(i)); + } + } + return result; + } + + /** + * Generate a resulting byte vector with a distance from the original vector that is proportional + * to the value in input (higher the value, higher the distance from the original vector) + * + * @param vector a numerical vector + * @param value a numerical value to be added to the first element of the vector + * @return a numerical vector that has a distance from the input vector, proportional to the value + */ + private static List outDistanceByte(List vector, int value) { + List result = new ArrayList<>(vector.size()); + for (int i = 0; i < vector.size(); i++) { + if (i == 0) { + result.add(vector.get(i) + value); + } else { + result.add(vector.get(i)); + } + } + return result; + } + @After public void cleanUp() { clearIndex(); From 9713cc026b342275e4500af4ea21f645898d9810 Mon Sep 17 00:00:00 2001 From: Alessandro Benedetti Date: Wed, 9 Apr 2025 14:02:37 +0100 Subject: [PATCH 13/43] first draft --- .../apache/solr/schema/DenseVectorField.java | 5 - .../apache/solr/update/AddUpdateCommand.java | 13 +- .../NestedUpdateProcessorFactory.java | 228 +++-- .../collection1/conf/schema-densevector.xml | 7 + .../solr/schema/DenseVectorFieldTest.java | 8 - .../KnnQParserMultiValuedVectorsTest.java | 886 +----------------- 6 files changed, 159 insertions(+), 988 deletions(-) diff --git a/solr/core/src/java/org/apache/solr/schema/DenseVectorField.java b/solr/core/src/java/org/apache/solr/schema/DenseVectorField.java index 297d8841be56..d8e9c000e215 100644 --- a/solr/core/src/java/org/apache/solr/schema/DenseVectorField.java +++ b/solr/core/src/java/org/apache/solr/schema/DenseVectorField.java @@ -200,11 +200,6 @@ protected boolean enableDocValuesByDefault() { @Override public void checkSchemaField(final SchemaField field) throws SolrException { super.checkSchemaField(field); - if (field.multiValued()) { - throw new SolrException( - SolrException.ErrorCode.SERVER_ERROR, - getClass().getSimpleName() + " fields can not be multiValued: " + field.getName()); - } if (field.hasDocValues()) { throw new SolrException( diff --git a/solr/core/src/java/org/apache/solr/update/AddUpdateCommand.java b/solr/core/src/java/org/apache/solr/update/AddUpdateCommand.java index 7b3c86ac4c54..5ca176ea66ae 100644 --- a/solr/core/src/java/org/apache/solr/update/AddUpdateCommand.java +++ b/solr/core/src/java/org/apache/solr/update/AddUpdateCommand.java @@ -29,7 +29,6 @@ import org.apache.solr.common.SolrInputField; import org.apache.solr.common.params.CommonParams; import org.apache.solr.request.SolrQueryRequest; -import org.apache.solr.schema.DenseVectorField; import org.apache.solr.schema.IndexSchema; import org.apache.solr.schema.SchemaField; @@ -257,9 +256,7 @@ private List flatten(SolrInputDocument root) { /** Extract all child documents from parent that are saved in fields */ private void flattenLabelled( List unwrappedDocs, SolrInputDocument currentDoc, boolean isRoot) { - IndexSchema schema = req.getSchema(); for (SolrInputField field : currentDoc.values()) { - SchemaField sfield = schema.getFieldOrNull(field.getName()); Object value = field.getFirstValue(); // check if value is a childDocument if (value instanceof SolrInputDocument) { @@ -273,15 +270,7 @@ private void flattenLabelled( for (SolrInputDocument child : childrenList) { flattenLabelled(unwrappedDocs, child); } - } else if (sfield.getType() instanceof DenseVectorField && sfield.multiValued() && isRoot){ - Collection vectorValues = field.getValues(); - for(Object vectorValue:vectorValues){ - SolrInputDocument singleVectorNestedDoc = new SolrInputDocument(); - singleVectorNestedDoc.setField(field.getName(), vectorValue); - flattenLabelled(unwrappedDocs, singleVectorNestedDoc); - } - - } + } } if (!isRoot) unwrappedDocs.add(currentDoc); diff --git a/solr/core/src/java/org/apache/solr/update/processor/NestedUpdateProcessorFactory.java b/solr/core/src/java/org/apache/solr/update/processor/NestedUpdateProcessorFactory.java index 96a93b7b45e4..273e8b928388 100644 --- a/solr/core/src/java/org/apache/solr/update/processor/NestedUpdateProcessorFactory.java +++ b/solr/core/src/java/org/apache/solr/update/processor/NestedUpdateProcessorFactory.java @@ -18,13 +18,17 @@ package org.apache.solr.update.processor; import java.io.IOException; +import java.util.ArrayList; import java.util.Collection; + import org.apache.solr.common.SolrException; import org.apache.solr.common.SolrInputDocument; import org.apache.solr.common.SolrInputField; import org.apache.solr.request.SolrQueryRequest; import org.apache.solr.response.SolrQueryResponse; +import org.apache.solr.schema.DenseVectorField; import org.apache.solr.schema.IndexSchema; +import org.apache.solr.schema.SchemaField; import org.apache.solr.update.AddUpdateCommand; /** @@ -37,110 +41,144 @@ */ public class NestedUpdateProcessorFactory extends UpdateRequestProcessorFactory { - @Override - public UpdateRequestProcessor getInstance( - SolrQueryRequest req, SolrQueryResponse rsp, UpdateRequestProcessor next) { - boolean storeParent = shouldStoreDocParent(req.getSchema()); - boolean storePath = shouldStoreDocPath(req.getSchema()); - if (!(storeParent || storePath)) { - return next; + @Override + public UpdateRequestProcessor getInstance( + SolrQueryRequest req, SolrQueryResponse rsp, UpdateRequestProcessor next) { + boolean storeParent = shouldStoreDocParent(req.getSchema()); + boolean storePath = shouldStoreDocPath(req.getSchema()); + if (!(storeParent || storePath)) { + return next; + } + return new NestedUpdateProcessor(req, storeParent, storePath, next); } - return new NestedUpdateProcessor(req, storeParent, storePath, next); - } - - private static boolean shouldStoreDocParent(IndexSchema schema) { - return schema.getFields().containsKey(IndexSchema.NEST_PARENT_FIELD_NAME); - } - - private static boolean shouldStoreDocPath(IndexSchema schema) { - return schema.getFields().containsKey(IndexSchema.NEST_PATH_FIELD_NAME); - } - - private static class NestedUpdateProcessor extends UpdateRequestProcessor { - private static final String PATH_SEP_CHAR = "/"; - private static final String NUM_SEP_CHAR = "#"; - private static final String SINGULAR_VALUE_CHAR = ""; - private boolean storePath; - private boolean storeParent; - private String uniqueKeyFieldName; - - NestedUpdateProcessor( - SolrQueryRequest req, boolean storeParent, boolean storePath, UpdateRequestProcessor next) { - super(next); - this.storeParent = storeParent; - this.storePath = storePath; - this.uniqueKeyFieldName = req.getSchema().getUniqueKeyField().getName(); + + private static boolean shouldStoreDocParent(IndexSchema schema) { + return schema.getFields().containsKey(IndexSchema.NEST_PARENT_FIELD_NAME); } - @Override - public void processAdd(AddUpdateCommand cmd) throws IOException { - SolrInputDocument doc = cmd.getSolrInputDocument(); - processDocChildren(doc, null); - super.processAdd(cmd); + private static boolean shouldStoreDocPath(IndexSchema schema) { + return schema.getFields().containsKey(IndexSchema.NEST_PATH_FIELD_NAME); } - private boolean processDocChildren(SolrInputDocument doc, String fullPath) { - boolean isNested = false; - for (SolrInputField field : doc.values()) { - int childNum = 0; - boolean isSingleVal = !(field.getValue() instanceof Collection); - for (Object val : field) { - if (!(val instanceof SolrInputDocument cDoc)) { - // either all collection items are child docs or none are. - break; - } - final String fieldName = field.getName(); - - if (fieldName.contains(PATH_SEP_CHAR)) { - throw new SolrException( - SolrException.ErrorCode.BAD_REQUEST, - "Field name: '" - + fieldName - + "' contains: '" - + PATH_SEP_CHAR - + "' , which is reserved for the nested URP"); - } - final String sChildNum = isSingleVal ? SINGULAR_VALUE_CHAR : String.valueOf(childNum); - if (!cDoc.containsKey(uniqueKeyFieldName)) { - String parentDocId = doc.getField(uniqueKeyFieldName).getFirstValue().toString(); - cDoc.setField( - uniqueKeyFieldName, generateChildUniqueId(parentDocId, fieldName, sChildNum)); - } - if (!isNested) { - isNested = true; - } - final String lastKeyPath = PATH_SEP_CHAR + fieldName + NUM_SEP_CHAR + sChildNum; - // concat of all paths children.grandChild => /children#1/grandChild# - final String childDocPath = fullPath == null ? lastKeyPath : fullPath + lastKeyPath; - processChildDoc(cDoc, doc, childDocPath); - ++childNum; + private static class NestedUpdateProcessor extends UpdateRequestProcessor { + private static final String PATH_SEP_CHAR = "/"; + private static final String NUM_SEP_CHAR = "#"; + private static final String SINGULAR_VALUE_CHAR = ""; + private boolean storePath; + private boolean storeParent; + private String uniqueKeyFieldName; + private IndexSchema schema; + + NestedUpdateProcessor( + SolrQueryRequest req, boolean storeParent, boolean storePath, UpdateRequestProcessor next) { + super(next); + this.storeParent = storeParent; + this.storePath = storePath; + this.uniqueKeyFieldName = req.getSchema().getUniqueKeyField().getName(); + this.schema = req.getSchema(); } - } - return isNested; - } - private void processChildDoc( - SolrInputDocument sdoc, SolrInputDocument parent, String fullPath) { - if (storePath) { - setPathField(sdoc, fullPath); - } - if (storeParent) { - setParentKey(sdoc, parent); - } - processDocChildren(sdoc, fullPath); - } + @Override + public void processAdd(AddUpdateCommand cmd) throws IOException { + SolrInputDocument doc = cmd.getSolrInputDocument(); + processDocChildren(doc, null); + super.processAdd(cmd); + } - private String generateChildUniqueId(String parentId, String childKey, String childNum) { - // combines parentId with the child's key and childNum. e.g. "10/footnote#1" - return parentId + PATH_SEP_CHAR + childKey + NUM_SEP_CHAR + childNum; - } + private boolean processDocChildren(SolrInputDocument doc, String fullPath) { + boolean isNested = false; + for (SolrInputField field : doc.values()) { + SchemaField sfield = schema.getField(field.getName()); + int childNum = 0; + boolean isSingleVal = !(field.getValue() instanceof Collection); + if (fullPath == null && isMultiValuedVectorField(sfield)) { + ArrayList vectors = new ArrayList<>(field.getValueCount()); + for (Object vectorValue : field.getValues()) { + SolrInputDocument singleVectorNestedDoc = new SolrInputDocument(); + singleVectorNestedDoc.setField(field.getName(), vectorValue); + final String sChildNum = isSingleVal ? SINGULAR_VALUE_CHAR : String.valueOf(childNum); + String parentDocId = doc.getField(uniqueKeyFieldName).getFirstValue().toString(); + singleVectorNestedDoc.setField( + uniqueKeyFieldName, generateChildUniqueId(parentDocId, field.getName(), sChildNum)); - private void setParentKey(SolrInputDocument sdoc, SolrInputDocument parent) { - sdoc.setField(IndexSchema.NEST_PARENT_FIELD_NAME, parent.getFieldValue(uniqueKeyFieldName)); - } + if (!isNested) { + isNested = true; + } + final String lastKeyPath = PATH_SEP_CHAR + field.getName() + NUM_SEP_CHAR + sChildNum; + final String childDocPath = fullPath == null ? lastKeyPath : fullPath + lastKeyPath; + if (storePath) { + setPathField(singleVectorNestedDoc, childDocPath); + } + if (storeParent) { + setParentKey(singleVectorNestedDoc, doc); + } + ++childNum; + vectors.add(singleVectorNestedDoc); + } + doc.setField(field.getName(), vectors); + } else { + for (Object val : field) { + if (!(val instanceof SolrInputDocument cDoc)) { + // either all collection items are child docs or none are. + break; + } + final String fieldName = field.getName(); - private void setPathField(SolrInputDocument sdoc, String fullPath) { - sdoc.setField(IndexSchema.NEST_PATH_FIELD_NAME, fullPath); + if (fieldName.contains(PATH_SEP_CHAR)) { + throw new SolrException( + SolrException.ErrorCode.BAD_REQUEST, + "Field name: '" + + fieldName + + "' contains: '" + + PATH_SEP_CHAR + + "' , which is reserved for the nested URP"); + } + final String sChildNum = isSingleVal ? SINGULAR_VALUE_CHAR : String.valueOf(childNum); + if (!cDoc.containsKey(uniqueKeyFieldName)) { + String parentDocId = doc.getField(uniqueKeyFieldName).getFirstValue().toString(); + cDoc.setField( + uniqueKeyFieldName, generateChildUniqueId(parentDocId, fieldName, sChildNum)); + } + if (!isNested) { + isNested = true; + } + final String lastKeyPath = PATH_SEP_CHAR + fieldName + NUM_SEP_CHAR + sChildNum; + // concat of all paths children.grandChild => /children#1/grandChild# + final String childDocPath = fullPath == null ? lastKeyPath : fullPath + lastKeyPath; + processChildDoc(cDoc, doc, childDocPath); + ++childNum; + } + } + } + return isNested; + } + + private static boolean isMultiValuedVectorField(SchemaField sfield) { + return sfield.getType() instanceof DenseVectorField && sfield.multiValued(); + } + + private void processChildDoc( + SolrInputDocument sdoc, SolrInputDocument parent, String fullPath) { + if (storePath) { + setPathField(sdoc, fullPath); + } + if (storeParent) { + setParentKey(sdoc, parent); + } + processDocChildren(sdoc, fullPath); + } + + private String generateChildUniqueId(String parentId, String childKey, String childNum) { + // combines parentId with the child's key and childNum. e.g. "10/footnote#1" + return parentId + PATH_SEP_CHAR + childKey + NUM_SEP_CHAR + childNum; + } + + private void setParentKey(SolrInputDocument sdoc, SolrInputDocument parent) { + sdoc.setField(IndexSchema.NEST_PARENT_FIELD_NAME, parent.getFieldValue(uniqueKeyFieldName)); + } + + private void setPathField(SolrInputDocument sdoc, String fullPath) { + sdoc.setField(IndexSchema.NEST_PATH_FIELD_NAME, fullPath); + } } - } } diff --git a/solr/core/src/test-files/solr/collection1/conf/schema-densevector.xml b/solr/core/src/test-files/solr/collection1/conf/schema-densevector.xml index 405ecbe2fe5d..d1982569d0b5 100644 --- a/solr/core/src/test-files/solr/collection1/conf/schema-densevector.xml +++ b/solr/core/src/test-files/solr/collection1/conf/schema-densevector.xml @@ -25,8 +25,15 @@ + + + + + + + diff --git a/solr/core/src/test/org/apache/solr/schema/DenseVectorFieldTest.java b/solr/core/src/test/org/apache/solr/schema/DenseVectorFieldTest.java index 4b7533985213..6fe1b7c1ffe6 100644 --- a/solr/core/src/test/org/apache/solr/schema/DenseVectorFieldTest.java +++ b/solr/core/src/test/org/apache/solr/schema/DenseVectorFieldTest.java @@ -74,14 +74,6 @@ public void fieldDefinition_docValues_shouldThrowException() throws Exception { "DenseVectorField fields can not have docValues: vector"); } - @Test - public void fieldDefinition_multiValued_shouldThrowException() throws Exception { - assertConfigs( - "solrconfig-basic.xml", - "bad-schema-densevector-multivalued.xml", - "DenseVectorField fields can not be multiValued: vector"); - } - @Test public void fieldTypeDefinition_nullSimilarityDistance_shouldUseDefaultSimilarityEuclidean() throws Exception { diff --git a/solr/core/src/test/org/apache/solr/search/neural/KnnQParserMultiValuedVectorsTest.java b/solr/core/src/test/org/apache/solr/search/neural/KnnQParserMultiValuedVectorsTest.java index d1c0704118a8..1d20cf959f46 100644 --- a/solr/core/src/test/org/apache/solr/search/neural/KnnQParserMultiValuedVectorsTest.java +++ b/solr/core/src/test/org/apache/solr/search/neural/KnnQParserMultiValuedVectorsTest.java @@ -38,6 +38,8 @@ import static org.apache.solr.search.neural.KnnQParser.DEFAULT_TOP_K; public class KnnQParserMultiValuedVectorsTest extends SolrTestCaseJ4 { + private static final List FLOAT_QUERY_VECTOR = Arrays.asList(1.0f, 1.0f, 1.0f, 1.0f); + private static final List BYTE_QUERY_VECTOR = Arrays.asList(1, 1, 1, 1); @ClassRule public static final TestRule noReverseMerge = RandomNoReverseMergePolicyFactory.createRule(); @@ -52,7 +54,7 @@ public static void beforeClass() throws Exception { public static void prepareIndex() throws Exception { List docsToIndex = prepareDocs(); for (SolrInputDocument doc : docsToIndex) { - assertU(adoc(doc)); + updateJ(jsonAdd(doc), null); } assertU(commit()); } @@ -71,32 +73,26 @@ private static List prepareDocs() { int totalParentDocuments = 10; int totalNestedVectors = 30; int perParentChildren = totalNestedVectors / totalParentDocuments; - - final String[] klm = new String[] {"k", "l", "m"}; + final String[] abcdef = new String[] {"a", "b", "c", "d", "e", "f"}; List docs = new ArrayList<>(totalParentDocuments); for (int i = 1; i < totalParentDocuments + 1; i++) { SolrInputDocument doc = new SolrInputDocument(); doc.setField("id", i); - doc.setField("parent_b", true); - - doc.setField("parent_s", abcdef[i % abcdef.length]); - List children = new ArrayList<>(perParentChildren); - + doc.setField("_text_", abcdef[i % abcdef.length]); + List> floatVectors = new ArrayList<>(perParentChildren); + List> byteVectors = new ArrayList<>(perParentChildren); // nested vector documents have a distance from the query vector inversely proportional to // their id for (int j = 0; j < perParentChildren; j++) { - SolrInputDocument child = new SolrInputDocument(); - child.setField("id", i + "" + j); - child.setField("child_s", klm[i % klm.length]); - child.setField("vector", outDistanceFloat(FLOAT_QUERY_VECTOR, totalNestedVectors)); - child.setField("vector_byte", outDistanceByte(BYTE_QUERY_VECTOR, totalNestedVectors)); + floatVectors.add(outDistanceFloat(FLOAT_QUERY_VECTOR, totalNestedVectors)); + byteVectors.add(outDistanceByte(BYTE_QUERY_VECTOR, totalNestedVectors)); totalNestedVectors--; // the higher the id of the nested document, lower the distance with - // the query vector - children.add(child); } - doc.setField("vectors", children); + doc.setField("vector_multivalued", floatVectors); + doc.setField("vector_byte_multivalued", byteVectors); + docs.add(doc); } @@ -142,863 +138,17 @@ private static List outDistanceByte(List vector, int value) { } return result; } - - @After - public void cleanUp() { - clearIndex(); - deleteCore(); - } - - @Test - public void incorrectTopK_shouldThrowException() { - String vectorToSearch = "[1.0, 2.0, 3.0, 4.0]"; - - assertQEx( - "String topK should throw Exception", - "For input string: \"string\"", - req(CommonParams.Q, "{!knn f=vector topK=string}" + vectorToSearch, "fl", "id"), - SolrException.ErrorCode.BAD_REQUEST); - - assertQEx( - "Double topK should throw Exception", - "For input string: \"4.5\"", - req(CommonParams.Q, "{!knn f=vector topK=4.5}" + vectorToSearch, "fl", "id"), - SolrException.ErrorCode.BAD_REQUEST); - } - - @Test - public void topKMissing_shouldReturnDefaultTopK() { - String vectorToSearch = "[1.0, 2.0, 3.0, 4.0]"; - - assertQ( - req(CommonParams.Q, "{!knn f=vector}" + vectorToSearch, "fl", "id"), - "//result[@numFound='" + DEFAULT_TOP_K + "']", - "//result/doc[1]/str[@name='id'][.='1']", - "//result/doc[2]/str[@name='id'][.='4']", - "//result/doc[3]/str[@name='id'][.='2']", - "//result/doc[4]/str[@name='id'][.='10']", - "//result/doc[5]/str[@name='id'][.='3']", - "//result/doc[6]/str[@name='id'][.='7']", - "//result/doc[7]/str[@name='id'][.='5']", - "//result/doc[8]/str[@name='id'][.='6']", - "//result/doc[9]/str[@name='id'][.='9']", - "//result/doc[10]/str[@name='id'][.='8']"); - } + @Test public void topK_shouldReturnOnlyTopKResults() { - String vectorToSearch = "[1.0, 2.0, 3.0, 4.0]"; - assertQ( - req(CommonParams.Q, "{!knn f=vector topK=5}" + vectorToSearch, "fl", "id"), + req(CommonParams.Q, "{!knn f=vector_multivalued topK=5}" + FLOAT_QUERY_VECTOR, "fl", "id"), "//result[@numFound='5']", - "//result/doc[1]/str[@name='id'][.='1']", - "//result/doc[2]/str[@name='id'][.='4']", - "//result/doc[3]/str[@name='id'][.='2']", - "//result/doc[4]/str[@name='id'][.='10']", - "//result/doc[5]/str[@name='id'][.='3']"); - - assertQ( - req(CommonParams.Q, "{!knn f=vector topK=3}" + vectorToSearch, "fl", "id"), - "//result[@numFound='3']", - "//result/doc[1]/str[@name='id'][.='1']", - "//result/doc[2]/str[@name='id'][.='4']", - "//result/doc[3]/str[@name='id'][.='2']"); - } - - @Test - public void incorrectVectorFieldType_shouldThrowException() { - String vectorToSearch = "[1.0, 2.0, 3.0, 4.0]"; - - assertQEx( - "Incorrect vector field type should throw Exception", - "only DenseVectorField is compatible with Vector Query Parsers", - req(CommonParams.Q, "{!knn f=id topK=10}" + vectorToSearch, "fl", "id"), - SolrException.ErrorCode.BAD_REQUEST); - } - - @Test - public void undefinedVectorField_shouldThrowException() { - String vectorToSearch = "[1.0, 2.0, 3.0, 4.0]"; - - assertQEx( - "Undefined vector field should throw Exception", - "undefined field: \"notExistent\"", - req(CommonParams.Q, "{!knn f=notExistent topK=10}" + vectorToSearch, "fl", "id"), - SolrException.ErrorCode.BAD_REQUEST); - } - - @Test - public void missingVectorField_shouldThrowException() { - String vectorToSearch = "[1.0, 2.0, 3.0, 4.0]"; - - assertQEx( - "missing vector field should throw Exception", - "the Dense Vector field 'f' is missing", - req(CommonParams.Q, "{!knn topK=10}" + vectorToSearch, "fl", "id"), - SolrException.ErrorCode.BAD_REQUEST); - } - - @Test - public void correctVectorField_shouldSearchOnThatField() { - String vectorToSearch = "[1.0, 2.0, 3.0, 4.0]"; - - assertQ( - req(CommonParams.Q, "{!knn f=vector2 topK=5}" + vectorToSearch, "fl", "id"), - "//result[@numFound='3']", - "//result/doc[1]/str[@name='id'][.='11']", - "//result/doc[2]/str[@name='id'][.='13']", - "//result/doc[3]/str[@name='id'][.='12']"); - } - - @Test - public void highDimensionFloatVectorField_shouldSearchOnThatField() { - int highDimension = 2048; - List docsToIndex = this.prepareHighDimensionFloatVectorsDocs(highDimension); - for (SolrInputDocument doc : docsToIndex) { - assertU(adoc(doc)); - } - assertU(commit()); - - float[] highDimensionalityQueryVector = new float[highDimension]; - for (int i = 0; i < highDimension; i++) { - highDimensionalityQueryVector[i] = i; - } - String vectorToSearch = Arrays.toString(highDimensionalityQueryVector); - - assertQ( - req(CommonParams.Q, "{!knn f=2048_float_vector topK=1}" + vectorToSearch, "fl", "id"), - "//result[@numFound='1']", - "//result/doc[1]/str[@name='id'][.='1']"); - } - - @Test - public void highDimensionByteVectorField_shouldSearchOnThatField() { - int highDimension = 2048; - List docsToIndex = this.prepareHighDimensionByteVectorsDocs(highDimension); - for (SolrInputDocument doc : docsToIndex) { - assertU(adoc(doc)); - } - assertU(commit()); - - byte[] highDimensionalityQueryVector = new byte[highDimension]; - for (int i = 0; i < highDimension; i++) { - highDimensionalityQueryVector[i] = (byte) (i % 127); - } - String vectorToSearch = Arrays.toString(highDimensionalityQueryVector); - - assertQ( - req(CommonParams.Q, "{!knn f=2048_byte_vector topK=1}" + vectorToSearch, "fl", "id"), - "//result[@numFound='1']", - "//result/doc[1]/str[@name='id'][.='1']"); - } - - private List prepareHighDimensionFloatVectorsDocs(int highDimension) { - int docsCount = 13; - String field = "2048_float_vector"; - List docs = new ArrayList<>(docsCount); - - for (int i = 1; i < docsCount + 1; i++) { - SolrInputDocument doc = new SolrInputDocument(); - doc.addField(IDField, i); - docs.add(doc); - } - - for (int i = 0; i < docsCount; i++) { - List highDimensionalityVector = new ArrayList<>(); - for (int j = i * highDimension; j < highDimension; j++) { - highDimensionalityVector.add(j); - } - docs.get(i).addField(field, highDimensionalityVector); - } - Collections.reverse(docs); - return docs; - } - - private List prepareHighDimensionByteVectorsDocs(int highDimension) { - int docsCount = 13; - String field = "2048_byte_vector"; - List docs = new ArrayList<>(docsCount); - - for (int i = 1; i < docsCount + 1; i++) { - SolrInputDocument doc = new SolrInputDocument(); - doc.addField(IDField, i); - docs.add(doc); - } - - for (int i = 0; i < docsCount; i++) { - List highDimensionalityVector = new ArrayList<>(); - for (int j = i * highDimension; j < highDimension; j++) { - highDimensionalityVector.add(j % 127); - } - docs.get(i).addField(field, highDimensionalityVector); - } - Collections.reverse(docs); - return docs; - } - - @Test - public void vectorByteEncodingField_shouldSearchOnThatField() { - String vectorToSearch = "[2, 2, 1, 3]"; - - assertQ( - req(CommonParams.Q, "{!knn f=vector_byte_encoding topK=2}" + vectorToSearch, "fl", "id"), - "//result[@numFound='2']", - "//result/doc[1]/str[@name='id'][.='2']", - "//result/doc[2]/str[@name='id'][.='3']"); - - vectorToSearch = "[8, 3, 2, 4]"; - - assertQ( - req(CommonParams.Q, "{!knn f=vector_byte_encoding topK=2}" + vectorToSearch, "fl", "id"), - "//result[@numFound='2']", - "//result/doc[1]/str[@name='id'][.='8']", - "//result/doc[2]/str[@name='id'][.='4']"); - } - - @Test - public void vectorByteEncodingField_shouldRaiseExceptionIfQueryUsesFloatVectors() { - String vectorToSearch = "[8.3, 4.3, 2.1, 4.1]"; - - assertQEx( - "incorrect vector element: '8.3'. The expected format is:'[b1,b2..b3]' where each element b is a byte (-128 to 127)", - "incorrect vector element: '8.3'. The expected format is:'[b1,b2..b3]' where each element b is a byte (-128 to 127)", - req(CommonParams.Q, "{!knn f=vector_byte_encoding topK=10}" + vectorToSearch, "fl", "id"), - SolrException.ErrorCode.BAD_REQUEST); - } - - @Test - public void - vectorByteEncodingField_shouldRaiseExceptionWhenQueryContainsValuesOutsideByteValueRange() { - String vectorToSearch = "[1, -129, 3, 5]"; - - assertQEx( - "incorrect vector element: ' -129'. The expected format is:'[b1,b2..b3]' where each element b is a byte (-128 to 127)", - "incorrect vector element: ' -129'. The expected format is:'[b1,b2..b3]' where each element b is a byte (-128 to 127)", - req(CommonParams.Q, "{!knn f=vector_byte_encoding topK=10}" + vectorToSearch, "fl", "id"), - SolrException.ErrorCode.BAD_REQUEST); - - vectorToSearch = "[1, 3, 156, 5]"; - - assertQEx( - "incorrect vector element: ' 156'. The expected format is:'[b1,b2..b3]' where each element b is a byte (-128 to 127)", - "incorrect vector element: ' 156'. The expected format is:'[b1,b2..b3]' where each element b is a byte (-128 to 127)", - req(CommonParams.Q, "{!knn f=vector_byte_encoding topK=10}" + vectorToSearch, "fl", "id"), - SolrException.ErrorCode.BAD_REQUEST); - } - - @Test - public void missingVectorToSearch_shouldThrowException() { - assertQEx( - "missing vector to search should throw Exception", - "the Dense Vector value 'v' to search is missing", - req(CommonParams.Q, "{!knn f=vector topK=10}", "fl", "id"), - SolrException.ErrorCode.BAD_REQUEST); - } - - @Test - public void incorrectVectorToSearchDimension_shouldThrowException() { - String vectorToSearch = "[2.0, 4.4, 3.]"; - assertQEx( - "missing vector to search should throw Exception", - "incorrect vector dimension. The vector value has size 3 while it is expected a vector with size 4", - req(CommonParams.Q, "{!knn f=vector topK=10}" + vectorToSearch, "fl", "id"), - SolrException.ErrorCode.BAD_REQUEST); - - vectorToSearch = "[2.0, 4.4,,]"; - assertQEx( - "incorrect vector to search should throw Exception", - "incorrect vector dimension. The vector value has size 2 while it is expected a vector with size 4", - req(CommonParams.Q, "{!knn f=vector topK=10}" + vectorToSearch, "fl", "id"), - SolrException.ErrorCode.BAD_REQUEST); - } - - @Test - public void incorrectVectorToSearch_shouldThrowException() { - String vectorToSearch = "2.0, 4.4, 3.5, 6.4"; - assertQEx( - "incorrect vector to search should throw Exception", - "incorrect vector format. The expected format is:'[f1,f2..f3]' where each element f is a float", - req(CommonParams.Q, "{!knn f=vector topK=10}" + vectorToSearch, "fl", "id"), - SolrException.ErrorCode.BAD_REQUEST); - - vectorToSearch = "[2.0, 4.4, 3.5, 6.4"; - assertQEx( - "incorrect vector to search should throw Exception", - "incorrect vector format. The expected format is:'[f1,f2..f3]' where each element f is a float", - req(CommonParams.Q, "{!knn f=vector topK=10}" + vectorToSearch, "fl", "id"), - SolrException.ErrorCode.BAD_REQUEST); - - vectorToSearch = "2.0, 4.4, 3.5, 6.4]"; - assertQEx( - "incorrect vector to search should throw Exception", - "incorrect vector format. The expected format is:'[f1,f2..f3]' where each element f is a float", - req(CommonParams.Q, "{!knn f=vector topK=10}" + vectorToSearch, "fl", "id"), - SolrException.ErrorCode.BAD_REQUEST); - - vectorToSearch = "[2.0, 4.4, 3.5, stringElement]"; - assertQEx( - "incorrect vector to search should throw Exception", - "incorrect vector element: ' stringElement'. The expected format is:'[f1,f2..f3]' where each element f is a float", - req(CommonParams.Q, "{!knn f=vector topK=10}" + vectorToSearch, "fl", "id"), - SolrException.ErrorCode.BAD_REQUEST); - - vectorToSearch = "[2.0, 4.4, , ]"; - assertQEx( - "incorrect vector to search should throw Exception", - "incorrect vector element: ' '. The expected format is:'[f1,f2..f3]' where each element f is a float", - req(CommonParams.Q, "{!knn f=vector topK=10}" + vectorToSearch, "fl", "id"), - SolrException.ErrorCode.BAD_REQUEST); - } - - @Test - public void correctQuery_shouldRankBySimilarityFunction() { - String vectorToSearch = "[1.0, 2.0, 3.0, 4.0]"; - - assertQ( - req(CommonParams.Q, "{!knn f=vector topK=10}" + vectorToSearch, "fl", "id"), - "//result[@numFound='10']", - "//result/doc[1]/str[@name='id'][.='1']", - "//result/doc[2]/str[@name='id'][.='4']", - "//result/doc[3]/str[@name='id'][.='2']", - "//result/doc[4]/str[@name='id'][.='10']", - "//result/doc[5]/str[@name='id'][.='3']", - "//result/doc[6]/str[@name='id'][.='7']", - "//result/doc[7]/str[@name='id'][.='5']", - "//result/doc[8]/str[@name='id'][.='6']", - "//result/doc[9]/str[@name='id'][.='9']", - "//result/doc[10]/str[@name='id'][.='8']"); - } - - @Test - public void knnQueryUsedInFilter_shouldFilterResultsBeforeTheQueryExecution() { - String vectorToSearch = "[1.0, 2.0, 3.0, 4.0]"; - assertQ( - req( - CommonParams.Q, - "id:(3 4 9 2)", - "fq", - "{!knn f=vector topK=4}" + vectorToSearch, - "fl", - "id"), - "//result[@numFound='2']", - "//result/doc[1]/str[@name='id'][.='2']", - "//result/doc[2]/str[@name='id'][.='4']"); - } - - @Test - public void knnQueryUsedInFilters_shouldFilterResultsBeforeTheQueryExecution() { - String vectorToSearch = "[1.0, 2.0, 3.0, 4.0]"; - - // topK=4 -> 1,4,2,10 - assertQ( - req( - CommonParams.Q, - "id:(3 4 9 2)", - "fq", - "{!knn f=vector topK=4}" + vectorToSearch, - "fq", - "id:(4 20 9)", - "fl", - "id"), - "//result[@numFound='1']", - "//result/doc[1]/str[@name='id'][.='4']"); - } - - @Test - public void knnQueryUsedInFiltersWithPreFilter_shouldFilterResultsBeforeTheQueryExecution() { - String vectorToSearch = "[1.0, 2.0, 3.0, 4.0]"; - - // topK=4 w/localparam preFilter -> 1,4,7,9 - assertQ( - req( - CommonParams.Q, - "id:(3 4 9 2)", - "fq", - "{!knn f=vector topK=4 preFilter='id:(1 4 7 8 9)'}" + vectorToSearch, - "fq", - "id:(4 20 9)", - "fl", - "id"), - "//result[@numFound='2']", - "//result/doc[1]/str[@name='id'][.='4']", - "//result/doc[2]/str[@name='id'][.='9']"); - } - - @Test - public void knnQueryUsedInFilters_rejectIncludeExclude() { - String vectorToSearch = "[1.0, 2.0, 3.0, 4.0]"; - - for (String fq : - Arrays.asList( - "{!knn f=vector topK=5 includeTags=xxx}" + vectorToSearch, - "{!knn f=vector topK=5 excludeTags=xxx}" + vectorToSearch)) { - assertQEx( - "fq={!knn...} incompatible with include/exclude localparams", - "used as a filter does not support", - req("q", "*:*", "fq", fq), - SolrException.ErrorCode.BAD_REQUEST); - } - } - - @Test - public void knnQueryAsSubQuery() { - final SolrParams common = params("fl", "id", "vec", "[1.0, 2.0, 3.0, 4.0]"); - final String filt = "id:(2 4 7 9 8 20 3)"; - - // When knn parser is a subquery, it should not pre-filter on any global fq params - // topK -> 1,4,2,10,3 -> fq -> 4,2,3 - assertQ( - req(common, "fq", filt, "q", "*:* AND {!knn f=vector topK=5 v=$vec}"), - "//result[@numFound='3']", - "//result/doc[1]/str[@name='id'][.='4']", - "//result/doc[2]/str[@name='id'][.='2']", - "//result/doc[3]/str[@name='id'][.='3']"); - // topK -> 1,4,2,10,3 + '8' -> fq -> 4,2,3,8 - assertQ( - req(common, "fq", filt, "q", "id:8^=0.01 OR {!knn f=vector topK=5 v=$vec}"), - "//result[@numFound='4']", - "//result/doc[1]/str[@name='id'][.='4']", - "//result/doc[2]/str[@name='id'][.='2']", - "//result/doc[3]/str[@name='id'][.='3']", - "//result/doc[4]/str[@name='id'][.='8']"); - } - - @Test - public void knnQueryAsSubQuery_withPreFilter() { - final SolrParams common = params("fl", "id", "vec", "[1.0, 2.0, 3.0, 4.0]"); - final String filt = "id:(2 4 7 9 8 20 3)"; - - // knn subquery should still accept `preFilter` local param - // filt -> topK -> 4,2,3,7,9 - assertQ( - req(common, "q", "*:* AND {!knn f=vector topK=5 preFilter='" + filt + "' v=$vec}"), - "//result[@numFound='5']", - "//result/doc[1]/str[@name='id'][.='4']", - "//result/doc[2]/str[@name='id'][.='2']", - "//result/doc[3]/str[@name='id'][.='3']", + "//result/doc[1]/str[@name='id'][.='10']", + "//result/doc[2]/str[@name='id'][.='9']", + "//result/doc[3]/str[@name='id'][.='8']", "//result/doc[4]/str[@name='id'][.='7']", - "//result/doc[5]/str[@name='id'][.='9']"); - - // it should not pre-filter on any global fq params - // filt -> topK -> 4,2,3,7,9 -> fq -> 3,9 - assertQ( - req( - common, - "fq", - "id:(1 9 20 3 5 6 8)", - "q", - "*:* AND {!knn f=vector topK=5 preFilter='" + filt + "' v=$vec}"), - "//result[@numFound='2']", - "//result/doc[1]/str[@name='id'][.='3']", - "//result/doc[2]/str[@name='id'][.='9']"); - // filt -> topK -> 4,2,3,7,9 + '8' -> fq -> 8,3,9 - assertQ( - req( - common, - "fq", - "id:(1 9 20 3 5 6 8)", - "q", - "id:8^=100 OR {!knn f=vector topK=5 preFilter='" + filt + "' v=$vec}"), - "//result[@numFound='3']", - "//result/doc[1]/str[@name='id'][.='8']", - "//result/doc[2]/str[@name='id'][.='3']", - "//result/doc[3]/str[@name='id'][.='9']"); - } - - @Test - public void knnQueryAsSubQuery_rejectIncludeExclude() { - final SolrParams common = params("fl", "id", "vec", "[1.0, 2.0, 3.0, 4.0]"); - - for (String knn : - Arrays.asList( - "{!knn f=vector topK=5 includeTags=xxx v=$vec}", - "{!knn f=vector topK=5 excludeTags=xxx v=$vec}")) { - assertQEx( - "knn as subquery incompatible with include/exclude localparams", - "used as a sub-query does not support", - req(common, "q", "*:* OR " + knn), - SolrException.ErrorCode.BAD_REQUEST); - } - } - - @Test - public void knnQueryWithFilterQuery_singlePreFilterEquivilence() { - final String vectorToSearch = "[1.0, 2.0, 3.0, 4.0]"; - final SolrParams common = params("fl", "id"); - - // these requests should be equivalent - final String filt = "id:(1 2 7 20)"; - for (SolrQueryRequest req : - Arrays.asList( - req(common, "q", "{!knn f=vector topK=10}" + vectorToSearch, "fq", filt), - req(common, "q", "{!knn f=vector preFilter=\"" + filt + "\" topK=10}" + vectorToSearch), - req( - common, - "q", - "{!knn f=vector preFilter=$my_filt topK=10}" + vectorToSearch, - "my_filt", - filt))) { - assertQ( - req, - "//result[@numFound='3']", - "//result/doc[1]/str[@name='id'][.='1']", - "//result/doc[2]/str[@name='id'][.='2']", - "//result/doc[3]/str[@name='id'][.='7']"); - } - } - - @Test - public void knnQueryWithFilterQuery_multiPreFilterEquivilence() { - final String vectorToSearch = "[1.0, 2.0, 3.0, 4.0]"; - final SolrParams common = params("fl", "id"); - - // these requests should be equivalent - final String fx = "id:(3 4 9 2 1 )"; // 1 & 10 dropped from intersection - final String fy = "id:(3 4 9 2 10)"; - for (SolrQueryRequest req : - Arrays.asList( - req(common, "q", "{!knn f=vector topK=4}" + vectorToSearch, "fq", fx, "fq", fy), - req( - common, - "q", - "{!knn f=vector preFilter=\"" - + fx - + "\" preFilter=\"" - + fy - + "\" topK=4}" - + vectorToSearch), - req( - common, - "q", - "{!knn f=vector preFilter=$fx preFilter=$fy topK=4}" + vectorToSearch, - "fx", - fx, - "fy", - fy), - req( - common, - "q", - "{!knn f=vector preFilter=$multi_filt topK=4}" + vectorToSearch, - "multi_filt", - fx, - "multi_filt", - fy))) { - assertQ( - req, - "//result[@numFound='4']", - "//result/doc[1]/str[@name='id'][.='4']", - "//result/doc[2]/str[@name='id'][.='2']", - "//result/doc[3]/str[@name='id'][.='3']", - "//result/doc[4]/str[@name='id'][.='9']"); - } - } - - @Test - public void knnQueryWithPreFilter_rejectIncludeExclude() { - final String vectorToSearch = "[1.0, 2.0, 3.0, 4.0]"; - - assertQEx( - "knn preFilter localparm incompatible with include/exclude localparams", - "does not support combining preFilter localparam with either", - // shouldn't matter if global fq w/tag even exists, usage is an error - req("q", "{!knn f=vector preFilter='id:1' includeTags=xxx}" + vectorToSearch), - SolrException.ErrorCode.BAD_REQUEST); - assertQEx( - "knn preFilter localparm incompatible with include/exclude localparams", - "does not support combining preFilter localparam with either", - // shouldn't matter if global fq w/tag even exists, usage is an error - req("q", "{!knn f=vector preFilter='id:1' excludeTags=xxx}" + vectorToSearch), - SolrException.ErrorCode.BAD_REQUEST); - } - - @Test - public void knnQueryWithFilterQuery_preFilterLocalParamOverridesGlobalFilters() { - final String vectorToSearch = "[1.0, 2.0, 3.0, 4.0]"; - - // trivial case: empty preFilter localparam means no pre-filtering - assertQ( - req( - "q", "{!knn f=vector preFilter='' topK=5}" + vectorToSearch, - "fq", "-id:4", - "fl", "id"), - "//result[@numFound='4']", - "//result/doc[1]/str[@name='id'][.='1']", - "//result/doc[2]/str[@name='id'][.='2']", - "//result/doc[3]/str[@name='id'][.='10']", - "//result/doc[4]/str[@name='id'][.='3']"); - - // localparam prefiltering, global fqs applied independently - assertQ( - req( - "q", "{!knn f=vector preFilter='id:(3 4 9 2 7 8)' topK=5}" + vectorToSearch, - "fq", "-id:4", - "fl", "id"), - "//result[@numFound='4']", - "//result/doc[1]/str[@name='id'][.='2']", - "//result/doc[2]/str[@name='id'][.='3']", - "//result/doc[3]/str[@name='id'][.='7']", - "//result/doc[4]/str[@name='id'][.='9']"); - } - - @Test - public void knnQueryWithFilterQuery_localParamIncludeExcludeTags() { - final String vectorToSearch = "[1.0, 2.0, 3.0, 4.0]"; - final SolrParams common = - params( - "fl", "id", - "fq", "{!tag=xx,aa}id:(5 6 7 8 9 10)", - "fq", "{!tag=yy,aa}id:(1 2 3 4 5 6 7)"); - - // These req's are equivalent: pre-filter everything - // So only 7,6,5 are viable for topK=5 - for (SolrQueryRequest req : - Arrays.asList( - // default behavior is all fq's pre-filter, - req(common, "q", "{!knn f=vector topK=5}" + vectorToSearch), - // diff ways of explicitly requesting both fq params - req(common, "q", "{!knn f=vector includeTags=aa topK=5}" + vectorToSearch), - req( - common, - "q", - "{!knn f=vector includeTags=aa excludeTags='' topK=5}" + vectorToSearch), - req( - common, - "q", - "{!knn f=vector includeTags=aa excludeTags=bogus topK=5}" + vectorToSearch), - req( - common, - "q", - "{!knn f=vector includeTags=xx includeTags=yy topK=5}" + vectorToSearch), - req(common, "q", "{!knn f=vector includeTags=xx,yy,bogus topK=5}" + vectorToSearch))) { - assertQ( - req, - "//result[@numFound='3']", - "//result/doc[1]/str[@name='id'][.='7']", - "//result/doc[2]/str[@name='id'][.='5']", - "//result/doc[3]/str[@name='id'][.='6']"); - } - } - - @Test - public void knnQueryWithFilterQuery_localParamsDisablesAllPreFiltering() { - final String vectorToSearch = "[1.0, 2.0, 3.0, 4.0]"; - final SolrParams common = - params( - "fl", "id", - "fq", "{!tag=xx,aa}id:(5 6 7 8 9 10)", - "fq", "{!tag=yy,aa}id:(1 2 3 4 5 6 7)"); - - // These req's are equivalent: pre-filter nothing - // So 1,4,2,10,3,7 are the topK=6 - // Only 7 matches both of the the regular fq params - for (SolrQueryRequest req : - Arrays.asList( - // explicit local empty preFilter - req(common, "q", "{!knn f=vector preFilter='' topK=6}" + vectorToSearch), - // diff ways of explicitly including none of the global fq params - req(common, "q", "{!knn f=vector includeTags='' topK=6}" + vectorToSearch), - req(common, "q", "{!knn f=vector includeTags=bogus topK=6}" + vectorToSearch), - // diff ways of explicitly excluding all of the global fq params - req(common, "q", "{!knn f=vector excludeTags=aa topK=6}" + vectorToSearch), - req( - common, - "q", - "{!knn f=vector includeTags=aa excludeTags=aa topK=6}" + vectorToSearch), - req( - common, - "q", - "{!knn f=vector includeTags=aa excludeTags=xx,yy topK=6}" + vectorToSearch), - req( - common, - "q", - "{!knn f=vector includeTags=xx,yy excludeTags=aa topK=6}" + vectorToSearch), - req(common, "q", "{!knn f=vector excludeTags=xx,yy topK=6}" + vectorToSearch), - req(common, "q", "{!knn f=vector excludeTags=aa topK=6}" + vectorToSearch), - req( - common, - "q", - "{!knn f=vector excludeTags=xx excludeTags=yy topK=6}" + vectorToSearch), - req( - common, - "q", - "{!knn f=vector excludeTags=xx excludeTags=yy,bogus topK=6}" + vectorToSearch), - req(common, "q", "{!knn f=vector excludeTags=xx,yy,bogus topK=6}" + vectorToSearch))) { - assertQ(req, "//result[@numFound='1']", "//result/doc[1]/str[@name='id'][.='7']"); - } - } - - @Test - public void knnQueryWithFilterQuery_localParamCombinedIncludeExcludeTags() { - final String vectorToSearch = "[1.0, 2.0, 3.0, 4.0]"; - final SolrParams common = - params( - "fl", "id", - "fq", "{!tag=xx,aa}id:(5 6 7 8 9 10)", - "fq", "{!tag=yy,aa}id:(1 2 3 4 5 6 7)"); - - // These req's are equivalent: prefilter only the 'yy' fq - // So 1,4,2,3,7 are in the topK=5. - // Only 7 matches the regular 'xx' fq param - for (SolrQueryRequest req : - Arrays.asList( - // diff ways of only using the 'yy' filter - req(common, "q", "{!knn f=vector includeTags=yy,bogus topK=5}" + vectorToSearch), - req( - common, - "q", - "{!knn f=vector includeTags=yy excludeTags='' topK=5}" + vectorToSearch), - req(common, "q", "{!knn f=vector excludeTags=xx,bogus topK=5}" + vectorToSearch), - req( - common, - "q", - "{!knn f=vector includeTags=yy excludeTags=xx topK=5}" + vectorToSearch), - req( - common, - "q", - "{!knn f=vector includeTags=aa excludeTags=xx topK=5}" + vectorToSearch))) { - assertQ(req, "//result[@numFound='1']", "//result/doc[1]/str[@name='id'][.='7']"); - } - } - - @Test - public void knnQueryWithMultiSelectFaceting_excludeTags() { - // NOTE: faceting on id is not very realistic, - // but it confirms what we care about re:filters w/o needing extra fields. - final String facet_xpath = "//lst[@name='facet_fields']/lst[@name='id']/int"; - final String vectorToSearch = "[1.0, 2.0, 3.0, 4.0]"; - - final SolrParams common = - params( - "fl", "id", - "indent", "true", - "q", "{!knn f=vector topK=5 excludeTags=facet_click v=$vec}", - "vec", vectorToSearch, - // mimicing "inStock:true" - "fq", "-id:(2 3)", - "facet", "true", - "facet.mincount", "1", - "facet.field", "{!ex=facet_click}id"); - - // initial query, with basic pre-filter and facet counts - assertQ( - req(common), - "//result[@numFound='5']", - "//result/doc[1]/str[@name='id'][.='1']", - "//result/doc[2]/str[@name='id'][.='4']", - "//result/doc[3]/str[@name='id'][.='10']", - "//result/doc[4]/str[@name='id'][.='7']", - "//result/doc[5]/str[@name='id'][.='5']", - "*[count(" + facet_xpath + ")=5]", - facet_xpath + "[@name='1'][.='1']", - facet_xpath + "[@name='4'][.='1']", - facet_xpath + "[@name='10'][.='1']", - facet_xpath + "[@name='7'][.='1']", - facet_xpath + "[@name='5'][.='1']"); - - // drill down on a single facet constraint - // multi-select means facet counts shouldn't change - // (this proves the knn isn't pre-filtering on the 'facet_click' fq) - assertQ( - req(common, "fq", "{!tag=facet_click}id:(4)"), - "//result[@numFound='1']", - "//result/doc[1]/str[@name='id'][.='4']", - "*[count(" + facet_xpath + ")=5]", - facet_xpath + "[@name='1'][.='1']", - facet_xpath + "[@name='4'][.='1']", - facet_xpath + "[@name='10'][.='1']", - facet_xpath + "[@name='7'][.='1']", - facet_xpath + "[@name='5'][.='1']"); - - // drill down on an additional facet constraint - // multi-select means facet counts shouldn't change - // (this proves the knn isn't pre-filtering on the 'facet_click' fq) - assertQ( - req(common, "fq", "{!tag=facet_click}id:(4 5)"), - "//result[@numFound='2']", - "//result/doc[1]/str[@name='id'][.='4']", - "//result/doc[2]/str[@name='id'][.='5']", - "*[count(" + facet_xpath + ")=5]", - facet_xpath + "[@name='1'][.='1']", - facet_xpath + "[@name='4'][.='1']", - facet_xpath + "[@name='10'][.='1']", - facet_xpath + "[@name='7'][.='1']", - facet_xpath + "[@name='5'][.='1']"); - } - - @Test - public void knnQueryWithCostlyFq_shouldPerformKnnSearchWithPostFilter() { - String vectorToSearch = "[1.0, 2.0, 3.0, 4.0]"; - - assertQ( - req( - CommonParams.Q, - "{!knn f=vector topK=10}" + vectorToSearch, - "fq", - "{!frange cache=false l=0.99}$q", - "fl", - "*,score"), - "//result[@numFound='5']", - "//result/doc[1]/str[@name='id'][.='1']", - "//result/doc[2]/str[@name='id'][.='4']", - "//result/doc[3]/str[@name='id'][.='2']", - "//result/doc[4]/str[@name='id'][.='10']", - "//result/doc[5]/str[@name='id'][.='3']"); - } - - @Test - public void knnQueryWithFilterQueries_shouldPerformKnnSearchWithPreFiltersAndPostFilters() { - String vectorToSearch = "[1.0, 2.0, 3.0, 4.0]"; - - assertQ( - req( - CommonParams.Q, - "{!knn f=vector topK=4}" + vectorToSearch, - "fq", - "id:(3 4 9 2)", - "fq", - "{!frange cache=false l=0.99}$q", - "fl", - "id"), - "//result[@numFound='2']", - "//result/doc[1]/str[@name='id'][.='4']", - "//result/doc[2]/str[@name='id'][.='2']"); - } - - @Test - public void knnQueryWithNegativeFilterQuery_shouldPerformKnnSearchInPreFilteredResults() { - String vectorToSearch = "[1.0, 2.0, 3.0, 4.0]"; - assertQ( - req(CommonParams.Q, "{!knn f=vector topK=4}" + vectorToSearch, "fq", "-id:4", "fl", "id"), - "//result[@numFound='4']", - "//result/doc[1]/str[@name='id'][.='1']", - "//result/doc[2]/str[@name='id'][.='2']", - "//result/doc[3]/str[@name='id'][.='10']", - "//result/doc[4]/str[@name='id'][.='3']"); - } - - /** - * See {@link org.apache.solr.search.ReRankQParserPlugin.ReRankQueryRescorer#combine(float, - * boolean, float)}} for more details. - */ - @Test - public void knnQueryAsRerank_shouldAddSimilarityFunctionScore() { - String vectorToSearch = "[1.0, 2.0, 3.0, 4.0]"; - - assertQ( - req( - CommonParams.Q, - "id:(3 4 9 2)", - "rq", - "{!rerank reRankQuery=$rqq reRankDocs=4 reRankWeight=1}", - "rqq", - "{!knn f=vector topK=4}" + vectorToSearch, - "fl", - "id"), - "//result[@numFound='4']", - "//result/doc[1]/str[@name='id'][.='4']", - "//result/doc[2]/str[@name='id'][.='2']", - "//result/doc[3]/str[@name='id'][.='3']", - "//result/doc[4]/str[@name='id'][.='9']"); + "//result/doc[5]/str[@name='id'][.='6']"); } } From a203a2d50d05d4e25e606249aa4bda67757fc348 Mon Sep 17 00:00:00 2001 From: Alessandro Benedetti Date: Wed, 9 Apr 2025 14:21:39 +0100 Subject: [PATCH 14/43] first draft with working tests --- .../org/apache/solr/schema/DenseVectorField.java | 16 +++++++++++----- .../neural/KnnQParserMultiValuedVectorsTest.java | 11 +++++++++++ 2 files changed, 22 insertions(+), 5 deletions(-) diff --git a/solr/core/src/java/org/apache/solr/schema/DenseVectorField.java b/solr/core/src/java/org/apache/solr/schema/DenseVectorField.java index d8e9c000e215..3f75f3846abb 100644 --- a/solr/core/src/java/org/apache/solr/schema/DenseVectorField.java +++ b/solr/core/src/java/org/apache/solr/schema/DenseVectorField.java @@ -49,6 +49,7 @@ import org.apache.lucene.search.join.DiversifyingChildrenByteKnnVectorQuery; import org.apache.lucene.search.join.DiversifyingChildrenFloatKnnVectorQuery; import org.apache.lucene.search.join.ScoreMode; +import org.apache.lucene.search.join.ToChildBlockJoinQuery; import org.apache.lucene.search.join.ToParentBlockJoinQuery; import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.hnsw.HnswGraph; @@ -396,7 +397,7 @@ public Query getMultiValuedKnnVectorQuery(final SolrQueryRequest request, DenseVectorParser vectorBuilder = getVectorBuilder(vectorToSearch, DenseVectorParser.BuilderPhase.QUERY); - + BooleanQuery allDocuments = new BooleanQuery.Builder() .add(new BooleanClause(new MatchAllDocsQuery(), BooleanClause.Occur.MUST)) @@ -405,19 +406,24 @@ public Query getMultiValuedKnnVectorQuery(final SolrQueryRequest request, new DocValuesFieldExistsQuery(NEST_PATH_FIELD_NAME), BooleanClause.Occur.MUST_NOT)) .build(); - - BitSetProducer acceptedDocuments = BlockJoinParentQParser.getCachedBitSetProducer(request, filterQuery); + BitSetProducer allParentsBitSet = BlockJoinParentQParser.getCachedBitSetProducer(request, allDocuments); + + Query acceptedVectorsBasedOnDocumentFilters = null; + if(filterQuery != null){ + acceptedVectorsBasedOnDocumentFilters = new ToChildBlockJoinQuery(filterQuery, allParentsBitSet); + } + Query knnOnVectorField; switch (vectorEncoding) { case FLOAT32: knnOnVectorField = - new DiversifyingChildrenFloatKnnVectorQuery(fieldName, vectorBuilder.getFloatVector(), null, topK, acceptedDocuments); + new DiversifyingChildrenFloatKnnVectorQuery(fieldName, vectorBuilder.getFloatVector(), acceptedVectorsBasedOnDocumentFilters, topK, allParentsBitSet); break; case BYTE: knnOnVectorField = - new DiversifyingChildrenByteKnnVectorQuery(fieldName, vectorBuilder.getByteVector(), null, topK, acceptedDocuments); + new DiversifyingChildrenByteKnnVectorQuery(fieldName, vectorBuilder.getByteVector(), acceptedVectorsBasedOnDocumentFilters, topK, allParentsBitSet); break; default: throw new SolrException( diff --git a/solr/core/src/test/org/apache/solr/search/neural/KnnQParserMultiValuedVectorsTest.java b/solr/core/src/test/org/apache/solr/search/neural/KnnQParserMultiValuedVectorsTest.java index 1d20cf959f46..29f778bcce1d 100644 --- a/solr/core/src/test/org/apache/solr/search/neural/KnnQParserMultiValuedVectorsTest.java +++ b/solr/core/src/test/org/apache/solr/search/neural/KnnQParserMultiValuedVectorsTest.java @@ -151,4 +151,15 @@ public void topK_shouldReturnOnlyTopKResults() { "//result/doc[4]/str[@name='id'][.='7']", "//result/doc[5]/str[@name='id'][.='6']"); } + + @Test + public void topKWithFilter_shouldReturnOnlyTopKResults() { + assertQ( + req(CommonParams.Q, "{!knn f=vector_multivalued topK=5}" + FLOAT_QUERY_VECTOR, "fl", "id","fq","_text_:(b OR c)"), + "//result[@numFound='4']", + "//result/doc[1]/str[@name='id'][.='8']", + "//result/doc[2]/str[@name='id'][.='7']", + "//result/doc[3]/str[@name='id'][.='2']", + "//result/doc[4]/str[@name='id'][.='1']"); + } } From 59e59ad60b258158e7f2a64993c21a1c00e62e43 Mon Sep 17 00:00:00 2001 From: Alessandro Benedetti Date: Wed, 9 Apr 2025 17:55:03 +0100 Subject: [PATCH 15/43] draft for automatic child transformer --- .../transform/ChildDocTransformer.java | 8 +++++-- .../apache/solr/schema/DenseVectorField.java | 22 ++++++++++++++++--- .../org/apache/solr/search/QueryLimits.java | 5 +++++ .../apache/solr/search/SolrReturnFields.java | 8 ++++++- .../KnnQParserMultiValuedVectorsTest.java | 13 ++++++++++- .../java/org/apache/solr/SolrTestCaseJ4.java | 2 +- 6 files changed, 50 insertions(+), 8 deletions(-) diff --git a/solr/core/src/java/org/apache/solr/response/transform/ChildDocTransformer.java b/solr/core/src/java/org/apache/solr/response/transform/ChildDocTransformer.java index d8a4f0842264..ae1f87d165f3 100644 --- a/solr/core/src/java/org/apache/solr/response/transform/ChildDocTransformer.java +++ b/solr/core/src/java/org/apache/solr/response/transform/ChildDocTransformer.java @@ -51,19 +51,23 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -class ChildDocTransformer extends DocTransformer { +public class ChildDocTransformer extends DocTransformer { private static final Logger log = LoggerFactory.getLogger(MethodHandles.lookup().lookupClass()); private static final String ANON_CHILD_KEY = "_childDocuments_"; private final String name; private final BitSetProducer parentsFilter; // if null; resolve parent via uniqueKey instead - private final DocSet childDocSet; + private DocSet childDocSet; private final int limit; private final boolean isNestedSchema; private final SolrReturnFields childReturnFields; private final String[] extraRequestedFields; + public void setChildDocSet(DocSet childDocSet) { + this.childDocSet = childDocSet; + } + ChildDocTransformer( String name, BitSetProducer parentsFilter, diff --git a/solr/core/src/java/org/apache/solr/schema/DenseVectorField.java b/solr/core/src/java/org/apache/solr/schema/DenseVectorField.java index 3f75f3846abb..1f5c5db0b1c3 100644 --- a/solr/core/src/java/org/apache/solr/schema/DenseVectorField.java +++ b/solr/core/src/java/org/apache/solr/schema/DenseVectorField.java @@ -21,6 +21,7 @@ import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN; import static org.apache.solr.schema.IndexSchema.NEST_PATH_FIELD_NAME; +import java.io.IOException; import java.lang.invoke.MethodHandles; import java.util.ArrayList; import java.util.List; @@ -55,7 +56,11 @@ import org.apache.lucene.util.hnsw.HnswGraph; import org.apache.solr.common.SolrException; import org.apache.solr.request.SolrQueryRequest; +import org.apache.solr.response.transform.ChildDocTransformer; +import org.apache.solr.response.transform.DocTransformer; +import org.apache.solr.response.transform.DocTransformers; import org.apache.solr.search.QParser; +import org.apache.solr.search.QueryLimits; import org.apache.solr.search.join.BlockJoinParentQParser; import org.apache.solr.uninverting.UninvertingReader; import org.apache.solr.util.vector.ByteDenseVectorParser; @@ -430,9 +435,20 @@ public Query getMultiValuedKnnVectorQuery(final SolrQueryRequest request, SolrException.ErrorCode.SERVER_ERROR, "Unexpected state. Vector Encoding: " + vectorEncoding); } - - return new ToParentBlockJoinQuery( - knnOnVectorField, allParentsBitSet, ScoreMode.Max); + try { + //TO DO: if no ChildDocTransformer, then we need to add this one + QueryLimits currentLimits = QueryLimits.getCurrentLimits(); + DocTransformers transformers = (DocTransformers)currentLimits.getRsp().getReturnFields().getTransformer(); + ChildDocTransformer child = (ChildDocTransformer) transformers.getTransformer(1); + child.setChildDocSet(request.getSearcher().getDocSet(knnOnVectorField)); + + return new ToParentBlockJoinQuery( + knnOnVectorField, allParentsBitSet, ScoreMode.Max); + } catch (IOException e) { + throw new SolrException(SolrException.ErrorCode.SERVER_ERROR, e); + } + + } /** diff --git a/solr/core/src/java/org/apache/solr/search/QueryLimits.java b/solr/core/src/java/org/apache/solr/search/QueryLimits.java index e6e0db5eed94..85587f7baa02 100644 --- a/solr/core/src/java/org/apache/solr/search/QueryLimits.java +++ b/solr/core/src/java/org/apache/solr/search/QueryLimits.java @@ -45,6 +45,11 @@ public class QueryLimits implements QueryTimeout { public static QueryLimits NONE = new QueryLimits(); private final SolrQueryResponse rsp; + + public SolrQueryResponse getRsp() { + return rsp; + } + private final boolean allowPartialResults; // short-circuit the checks if any limit has been tripped diff --git a/solr/core/src/java/org/apache/solr/search/SolrReturnFields.java b/solr/core/src/java/org/apache/solr/search/SolrReturnFields.java index baaaeb270ef3..71e29d9ccd62 100644 --- a/solr/core/src/java/org/apache/solr/search/SolrReturnFields.java +++ b/solr/core/src/java/org/apache/solr/search/SolrReturnFields.java @@ -38,6 +38,7 @@ import org.apache.solr.common.params.SolrParams; import org.apache.solr.common.util.GlobPatternUtil; import org.apache.solr.request.SolrQueryRequest; +import org.apache.solr.response.transform.ChildDocTransformerFactory; import org.apache.solr.response.transform.DocTransformer; import org.apache.solr.response.transform.DocTransformers; import org.apache.solr.response.transform.RenameFieldTransformer; @@ -295,7 +296,6 @@ private void add( sp.pos = start; field = null; } - if (field == null) { // We didn't find a simple name, so let's see if it's a globbed field name. // Globbing only works with field names of the recommended form (roughly like java @@ -536,6 +536,12 @@ private void addField( String disp = (key == null) ? field : key; augmenters.addTransformer(new ScoreAugmenter(disp)); } + /* + if("vector_multivalued".equals(field)){ + ChildDocTransformerFactory childFactory = new ChildDocTransformerFactory(); + DocTransformer multiValuedTrans = childFactory.create("vector_multivalued", null, null); + augmenters.addTransformer(multiValuedTrans); + }*/ } @Override diff --git a/solr/core/src/test/org/apache/solr/search/neural/KnnQParserMultiValuedVectorsTest.java b/solr/core/src/test/org/apache/solr/search/neural/KnnQParserMultiValuedVectorsTest.java index 29f778bcce1d..4c20c3bbfde1 100644 --- a/solr/core/src/test/org/apache/solr/search/neural/KnnQParserMultiValuedVectorsTest.java +++ b/solr/core/src/test/org/apache/solr/search/neural/KnnQParserMultiValuedVectorsTest.java @@ -155,7 +155,18 @@ public void topK_shouldReturnOnlyTopKResults() { @Test public void topKWithFilter_shouldReturnOnlyTopKResults() { assertQ( - req(CommonParams.Q, "{!knn f=vector_multivalued topK=5}" + FLOAT_QUERY_VECTOR, "fl", "id","fq","_text_:(b OR c)"), + req(CommonParams.Q, "{!knn f=vector_multivalued topK=5}" + FLOAT_QUERY_VECTOR, "fl", "id,[child childFilter=$allChildren limit=2 fl=id,vector_multivalued]","fq","_text_:(b OR c)","allChildren","_nest_path_:[* TO *]"), + "//result[@numFound='4']", + "//result/doc[1]/str[@name='id'][.='8']", + "//result/doc[2]/str[@name='id'][.='7']", + "//result/doc[3]/str[@name='id'][.='2']", + "//result/doc[4]/str[@name='id'][.='1']"); + } + + @Test + public void topKWithFilterAndChildTransformer_shouldReturnOnlyTopKResults() { + assertQ( + req(CommonParams.Q, "{!knn f=vector_multivalued topK=5}" + FLOAT_QUERY_VECTOR, "fl", "id,score,vector_multivalued,[child fl=vector_multivalued]","fq","_text_:(b OR c)","allChildren","_nest_path_:[* TO *]"), "//result[@numFound='4']", "//result/doc[1]/str[@name='id'][.='8']", "//result/doc[2]/str[@name='id'][.='7']", diff --git a/solr/test-framework/src/java/org/apache/solr/SolrTestCaseJ4.java b/solr/test-framework/src/java/org/apache/solr/SolrTestCaseJ4.java index 126698945eed..deb70d5109aa 100644 --- a/solr/test-framework/src/java/org/apache/solr/SolrTestCaseJ4.java +++ b/solr/test-framework/src/java/org/apache/solr/SolrTestCaseJ4.java @@ -891,7 +891,7 @@ public static void assertQ(String message, SolrQueryRequest req, String... tests // since the default (standard) response format is now JSON // need to explicitly request XML since this class uses XPath ModifiableSolrParams xmlWriterTypeParams = new ModifiableSolrParams(req.getParams()); - xmlWriterTypeParams.set(CommonParams.WT, "xml"); + xmlWriterTypeParams.set(CommonParams.WT, "json"); // for tests, let's turn indention off so we don't have to handle extraneous spaces xmlWriterTypeParams.set("indent", xmlWriterTypeParams.get("indent", "off")); req.setParams(xmlWriterTypeParams); From 723b8b16681acbb9d2bd383f9a25353f23cc0e68 Mon Sep 17 00:00:00 2001 From: Alessandro Benedetti Date: Thu, 10 Apr 2025 12:24:13 +0100 Subject: [PATCH 16/43] draft for automatic child transformer --- .../apache/solr/schema/DenseVectorField.java | 63 ++++++++++--- .../org/apache/solr/search/ReturnFields.java | 2 + .../apache/solr/search/SolrReturnFields.java | 6 +- .../KnnQParserMultiValuedVectorsTest.java | 88 ++++++++++++++++++- .../java/org/apache/solr/SolrTestCaseJ4.java | 2 +- 5 files changed, 142 insertions(+), 19 deletions(-) diff --git a/solr/core/src/java/org/apache/solr/schema/DenseVectorField.java b/solr/core/src/java/org/apache/solr/schema/DenseVectorField.java index 1f5c5db0b1c3..354cbb835e77 100644 --- a/solr/core/src/java/org/apache/solr/schema/DenseVectorField.java +++ b/solr/core/src/java/org/apache/solr/schema/DenseVectorField.java @@ -55,12 +55,15 @@ import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.hnsw.HnswGraph; import org.apache.solr.common.SolrException; +import org.apache.solr.common.params.ModifiableSolrParams; import org.apache.solr.request.SolrQueryRequest; import org.apache.solr.response.transform.ChildDocTransformer; +import org.apache.solr.response.transform.ChildDocTransformerFactory; import org.apache.solr.response.transform.DocTransformer; import org.apache.solr.response.transform.DocTransformers; import org.apache.solr.search.QParser; import org.apache.solr.search.QueryLimits; +import org.apache.solr.search.ReturnFields; import org.apache.solr.search.join.BlockJoinParentQParser; import org.apache.solr.uninverting.UninvertingReader; import org.apache.solr.util.vector.ByteDenseVectorParser; @@ -402,7 +405,7 @@ public Query getMultiValuedKnnVectorQuery(final SolrQueryRequest request, DenseVectorParser vectorBuilder = getVectorBuilder(vectorToSearch, DenseVectorParser.BuilderPhase.QUERY); - + BooleanQuery allDocuments = new BooleanQuery.Builder() .add(new BooleanClause(new MatchAllDocsQuery(), BooleanClause.Occur.MUST)) @@ -411,14 +414,14 @@ public Query getMultiValuedKnnVectorQuery(final SolrQueryRequest request, new DocValuesFieldExistsQuery(NEST_PATH_FIELD_NAME), BooleanClause.Occur.MUST_NOT)) .build(); - + BitSetProducer allParentsBitSet = BlockJoinParentQParser.getCachedBitSetProducer(request, allDocuments); - + Query acceptedVectorsBasedOnDocumentFilters = null; - if(filterQuery != null){ - acceptedVectorsBasedOnDocumentFilters = new ToChildBlockJoinQuery(filterQuery, allParentsBitSet); + if (filterQuery != null) { + acceptedVectorsBasedOnDocumentFilters = new ToChildBlockJoinQuery(filterQuery, allParentsBitSet); } - + Query knnOnVectorField; switch (vectorEncoding) { @@ -436,19 +439,53 @@ public Query getMultiValuedKnnVectorQuery(final SolrQueryRequest request, "Unexpected state. Vector Encoding: " + vectorEncoding); } try { - //TO DO: if no ChildDocTransformer, then we need to add this one - QueryLimits currentLimits = QueryLimits.getCurrentLimits(); - DocTransformers transformers = (DocTransformers)currentLimits.getRsp().getReturnFields().getTransformer(); - ChildDocTransformer child = (ChildDocTransformer) transformers.getTransformer(1); - child.setChildDocSet(request.getSearcher().getDocSet(knnOnVectorField)); - + knnOnVectorField = knnOnVectorField.rewrite(request.getSearcher()); + setAppropriateChildrenListingTransformer(request, fieldName, knnOnVectorField); return new ToParentBlockJoinQuery( knnOnVectorField, allParentsBitSet, ScoreMode.Max); } catch (IOException e) { throw new SolrException(SolrException.ErrorCode.SERVER_ERROR, e); } + } + + private void setAppropriateChildrenListingTransformer(SolrQueryRequest request, String fieldName, Query knnOnVectorField) throws IOException { + QueryLimits currentLimits = QueryLimits.getCurrentLimits(); + ReturnFields returnFields = currentLimits.getRsp().getReturnFields(); + DocTransformer originalTransformer = returnFields.getTransformer(); - + if (originalTransformer == null) { + ChildDocTransformer addBestVectorPerDocument = getDefaultVectorChildrenTransformer(request, fieldName, knnOnVectorField); + returnFields.setTransformer(addBestVectorPerDocument); + } else if (originalTransformer instanceof DocTransformers) { + DocTransformers transformers = (DocTransformers) originalTransformer; + boolean noChildTransformer = true; + for (int i = 0; i < transformers.size() && noChildTransformer; i++) { + DocTransformer t = transformers.getTransformer(i); + if (t instanceof ChildDocTransformer) { + noChildTransformer = false; + } + } + if (noChildTransformer) { + transformers.addTransformer(getDefaultVectorChildrenTransformer(request, fieldName, knnOnVectorField)); + } + } else { + if (!(originalTransformer instanceof ChildDocTransformer)) { + DocTransformers transformers = new DocTransformers(); + transformers.addTransformer(originalTransformer); + transformers.addTransformer(getDefaultVectorChildrenTransformer(request, fieldName, knnOnVectorField)); + returnFields.setTransformer(transformers); + } + } + } + + private static ChildDocTransformer getDefaultVectorChildrenTransformer(SolrQueryRequest request, String fieldName, Query knnOnVectorField) throws IOException { + ChildDocTransformerFactory childFactory = new ChildDocTransformerFactory(); + ModifiableSolrParams params = new ModifiableSolrParams(); + params.add("limit","1"); + params.add("fl",fieldName); + ChildDocTransformer addVectors = (ChildDocTransformer)childFactory.create(fieldName,params, request); + addVectors.setChildDocSet(request.getSearcher().getDocSet(knnOnVectorField)); + return addVectors; } /** diff --git a/solr/core/src/java/org/apache/solr/search/ReturnFields.java b/solr/core/src/java/org/apache/solr/search/ReturnFields.java index 44dcb12491ed..49dd5c1140a2 100644 --- a/solr/core/src/java/org/apache/solr/search/ReturnFields.java +++ b/solr/core/src/java/org/apache/solr/search/ReturnFields.java @@ -89,4 +89,6 @@ public abstract class ReturnFields { /** Returns the DocTransformer used to modify documents, or null */ public abstract DocTransformer getTransformer(); + + public abstract void setTransformer(DocTransformer transformer); } diff --git a/solr/core/src/java/org/apache/solr/search/SolrReturnFields.java b/solr/core/src/java/org/apache/solr/search/SolrReturnFields.java index 71e29d9ccd62..83dd09300e53 100644 --- a/solr/core/src/java/org/apache/solr/search/SolrReturnFields.java +++ b/solr/core/src/java/org/apache/solr/search/SolrReturnFields.java @@ -38,7 +38,6 @@ import org.apache.solr.common.params.SolrParams; import org.apache.solr.common.util.GlobPatternUtil; import org.apache.solr.request.SolrQueryRequest; -import org.apache.solr.response.transform.ChildDocTransformerFactory; import org.apache.solr.response.transform.DocTransformer; import org.apache.solr.response.transform.DocTransformers; import org.apache.solr.response.transform.RenameFieldTransformer; @@ -605,6 +604,11 @@ public DocTransformer getTransformer() { return transformer; } + @Override + public void setTransformer(DocTransformer transformer) { +this.transformer = transformer; + } + @Override public String toString() { final StringBuilder sb = new StringBuilder("SolrReturnFields=("); diff --git a/solr/core/src/test/org/apache/solr/search/neural/KnnQParserMultiValuedVectorsTest.java b/solr/core/src/test/org/apache/solr/search/neural/KnnQParserMultiValuedVectorsTest.java index 4c20c3bbfde1..9a4afc80c66d 100644 --- a/solr/core/src/test/org/apache/solr/search/neural/KnnQParserMultiValuedVectorsTest.java +++ b/solr/core/src/test/org/apache/solr/search/neural/KnnQParserMultiValuedVectorsTest.java @@ -17,12 +17,14 @@ package org.apache.solr.search.neural; import org.apache.solr.SolrTestCaseJ4; +import org.apache.solr.client.solrj.SolrQuery; import org.apache.solr.common.SolrException; import org.apache.solr.common.SolrInputDocument; import org.apache.solr.common.params.CommonParams; import org.apache.solr.common.params.SolrParams; import org.apache.solr.request.SolrQueryRequest; import org.apache.solr.util.RandomNoReverseMergePolicyFactory; +import org.apache.solr.util.RestTestBase; import org.junit.After; import org.junit.Before; import org.junit.BeforeClass; @@ -155,22 +157,100 @@ public void topK_shouldReturnOnlyTopKResults() { @Test public void topKWithFilter_shouldReturnOnlyTopKResults() { assertQ( - req(CommonParams.Q, "{!knn f=vector_multivalued topK=5}" + FLOAT_QUERY_VECTOR, "fl", "id,[child childFilter=$allChildren limit=2 fl=id,vector_multivalued]","fq","_text_:(b OR c)","allChildren","_nest_path_:[* TO *]"), + req(CommonParams.Q, "{!knn f=vector_multivalued topK=5}" + FLOAT_QUERY_VECTOR, "fl", "id","fq","_text_:(b OR c)"), "//result[@numFound='4']", "//result/doc[1]/str[@name='id'][.='8']", "//result/doc[2]/str[@name='id'][.='7']", "//result/doc[3]/str[@name='id'][.='2']", "//result/doc[4]/str[@name='id'][.='1']"); } + @Test - public void topKWithFilterAndChildTransformer_shouldReturnOnlyTopKResults() { + public void topKWithoutTransformer_shouldDefaultToBestChildren() { assertQ( - req(CommonParams.Q, "{!knn f=vector_multivalued topK=5}" + FLOAT_QUERY_VECTOR, "fl", "id,score,vector_multivalued,[child fl=vector_multivalued]","fq","_text_:(b OR c)","allChildren","_nest_path_:[* TO *]"), + req(CommonParams.Q, "{!knn f=vector_multivalued topK=5}" + FLOAT_QUERY_VECTOR, "fl", "id,vector_multivalued","fq","_text_:(b OR c)"), "//result[@numFound='4']", "//result/doc[1]/str[@name='id'][.='8']", + "//result/doc[1]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[1][.='8.0']", + "//result/doc[1]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[2][.='1.0']", + "//result/doc[1]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[3][.='1.0']", + "//result/doc[1]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[4][.='1.0']", "//result/doc[2]/str[@name='id'][.='7']", + "//result/doc[2]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[1][.='11.0']", + "//result/doc[2]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[2][.='1.0']", + "//result/doc[2]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[3][.='1.0']", + "//result/doc[2]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[4][.='1.0']", "//result/doc[3]/str[@name='id'][.='2']", - "//result/doc[4]/str[@name='id'][.='1']"); + "//result/doc[3]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[1][.='26.0']", + "//result/doc[3]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[2][.='1.0']", + "//result/doc[3]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[3][.='1.0']", + "//result/doc[3]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[4][.='1.0']", + "//result/doc[4]/str[@name='id'][.='1']", + "//result/doc[4]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[1][.='29.0']", + "//result/doc[4]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[2][.='1.0']", + "//result/doc[4]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[3][.='1.0']", + "//result/doc[4]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[4][.='1.0']"); + } + + @Test + public void topKWithTransformer_shouldAddDefaultToBestChildren() { + assertQ( + req(CommonParams.Q, "{!knn f=vector_multivalued topK=5}" + FLOAT_QUERY_VECTOR, "fl", "id,vector_multivalued,score","fq","_text_:(b OR c)"), + "//result[@numFound='4']", + "//result/doc[1]/str[@name='id'][.='8']", + "//result/doc[1]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[1][.='8.0']", + "//result/doc[1]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[2][.='1.0']", + "//result/doc[1]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[3][.='1.0']", + "//result/doc[1]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[4][.='1.0']", + "//result/doc[2]/str[@name='id'][.='7']", + "//result/doc[2]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[1][.='11.0']", + "//result/doc[2]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[2][.='1.0']", + "//result/doc[2]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[3][.='1.0']", + "//result/doc[2]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[4][.='1.0']", + "//result/doc[3]/str[@name='id'][.='2']", + "//result/doc[3]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[1][.='26.0']", + "//result/doc[3]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[2][.='1.0']", + "//result/doc[3]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[3][.='1.0']", + "//result/doc[3]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[4][.='1.0']", + "//result/doc[4]/str[@name='id'][.='1']", + "//result/doc[4]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[1][.='29.0']", + "//result/doc[4]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[2][.='1.0']", + "//result/doc[4]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[3][.='1.0']", + "//result/doc[4]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[4][.='1.0']"); + } + + @Test + public void topKWithChildTransformer_shouldUseOriginalChildTransformer() { + assertQ( + req(CommonParams.Q, "{!knn f=vector_multivalued topK=3}" + FLOAT_QUERY_VECTOR, "fl", "id,vector_multivalued,score,[child limit=2 fl=vector_multivalued]","fq","_text_:(b OR c)"), + "//result[@numFound='3']", + "//result/doc[1]/str[@name='id'][.='8']", + "//result/doc[1]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[1][.='10.0']", + "//result/doc[1]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[2][.='1.0']", + "//result/doc[1]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[3][.='1.0']", + "//result/doc[1]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[4][.='1.0']", + "//result/doc[1]/arr[@name='vector_multivalued'][1]/doc[2]/arr[@name='vector_multivalued']/float[1][.='9.0']", + "//result/doc[1]/arr[@name='vector_multivalued'][1]/doc[2]/arr[@name='vector_multivalued']/float[2][.='1.0']", + "//result/doc[1]/arr[@name='vector_multivalued'][1]/doc[2]/arr[@name='vector_multivalued']/float[3][.='1.0']", + "//result/doc[1]/arr[@name='vector_multivalued'][1]/doc[2]/arr[@name='vector_multivalued']/float[4][.='1.0']", + "//result/doc[2]/str[@name='id'][.='7']", + "//result/doc[2]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[1][.='13.0']", + "//result/doc[2]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[2][.='1.0']", + "//result/doc[2]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[3][.='1.0']", + "//result/doc[2]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[4][.='1.0']", + "//result/doc[2]/arr[@name='vector_multivalued'][1]/doc[2]/arr[@name='vector_multivalued']/float[1][.='12.0']", + "//result/doc[2]/arr[@name='vector_multivalued'][1]/doc[2]/arr[@name='vector_multivalued']/float[2][.='1.0']", + "//result/doc[2]/arr[@name='vector_multivalued'][1]/doc[2]/arr[@name='vector_multivalued']/float[3][.='1.0']", + "//result/doc[2]/arr[@name='vector_multivalued'][1]/doc[2]/arr[@name='vector_multivalued']/float[4][.='1.0']", + "//result/doc[3]/str[@name='id'][.='2']", + "//result/doc[3]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[1][.='28.0']", + "//result/doc[3]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[2][.='1.0']", + "//result/doc[3]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[3][.='1.0']", + "//result/doc[3]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[4][.='1.0']", + "//result/doc[3]/arr[@name='vector_multivalued'][1]/doc[2]/arr[@name='vector_multivalued']/float[1][.='27.0']", + "//result/doc[3]/arr[@name='vector_multivalued'][1]/doc[2]/arr[@name='vector_multivalued']/float[2][.='1.0']", + "//result/doc[3]/arr[@name='vector_multivalued'][1]/doc[2]/arr[@name='vector_multivalued']/float[3][.='1.0']", + "//result/doc[3]/arr[@name='vector_multivalued'][1]/doc[2]/arr[@name='vector_multivalued']/float[4][.='1.0']"); } } diff --git a/solr/test-framework/src/java/org/apache/solr/SolrTestCaseJ4.java b/solr/test-framework/src/java/org/apache/solr/SolrTestCaseJ4.java index deb70d5109aa..126698945eed 100644 --- a/solr/test-framework/src/java/org/apache/solr/SolrTestCaseJ4.java +++ b/solr/test-framework/src/java/org/apache/solr/SolrTestCaseJ4.java @@ -891,7 +891,7 @@ public static void assertQ(String message, SolrQueryRequest req, String... tests // since the default (standard) response format is now JSON // need to explicitly request XML since this class uses XPath ModifiableSolrParams xmlWriterTypeParams = new ModifiableSolrParams(req.getParams()); - xmlWriterTypeParams.set(CommonParams.WT, "json"); + xmlWriterTypeParams.set(CommonParams.WT, "xml"); // for tests, let's turn indention off so we don't have to handle extraneous spaces xmlWriterTypeParams.set("indent", xmlWriterTypeParams.get("indent", "off")); req.setParams(xmlWriterTypeParams); From 369314128e601147cde7eec1543d7b51125aa128 Mon Sep 17 00:00:00 2001 From: Alessandro Benedetti Date: Thu, 10 Apr 2025 12:55:35 +0100 Subject: [PATCH 17/43] add best child per document transformer --- .../transform/ChildDocTransformer.java | 12 +++- .../org/apache/solr/search/QueryLimits.java | 4 ++ .../search/join/BlockJoinParentQParser.java | 33 +++++++++ .../BlockJoinNestedVectorsQParserTest.java | 69 +++++++++++++++++++ 4 files changed, 116 insertions(+), 2 deletions(-) diff --git a/solr/core/src/java/org/apache/solr/response/transform/ChildDocTransformer.java b/solr/core/src/java/org/apache/solr/response/transform/ChildDocTransformer.java index d8a4f0842264..9abb0fc34738 100644 --- a/solr/core/src/java/org/apache/solr/response/transform/ChildDocTransformer.java +++ b/solr/core/src/java/org/apache/solr/response/transform/ChildDocTransformer.java @@ -51,14 +51,14 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -class ChildDocTransformer extends DocTransformer { +public class ChildDocTransformer extends DocTransformer { private static final Logger log = LoggerFactory.getLogger(MethodHandles.lookup().lookupClass()); private static final String ANON_CHILD_KEY = "_childDocuments_"; private final String name; private final BitSetProducer parentsFilter; // if null; resolve parent via uniqueKey instead - private final DocSet childDocSet; + private DocSet childDocSet; private final int limit; private final boolean isNestedSchema; private final SolrReturnFields childReturnFields; @@ -96,6 +96,14 @@ public String[] getExtraRequestFields() { return extraRequestedFields; } + public DocSet getChildDocSet() { + return childDocSet; + } + + public void setChildDocSet(DocSet childDocSet) { + this.childDocSet = childDocSet; + } + private int getPrevRootGivenFilter(LeafReaderContext leafReaderContext, int segRootId) throws IOException { final BitSet segParentsBitSet = parentsFilter.getBitSet(leafReaderContext); diff --git a/solr/core/src/java/org/apache/solr/search/QueryLimits.java b/solr/core/src/java/org/apache/solr/search/QueryLimits.java index e6e0db5eed94..2cde716e1c95 100644 --- a/solr/core/src/java/org/apache/solr/search/QueryLimits.java +++ b/solr/core/src/java/org/apache/solr/search/QueryLimits.java @@ -193,4 +193,8 @@ public static QueryLimits getCurrentLimits() { final SolrRequestInfo info = SolrRequestInfo.getRequestInfo(); return info != null ? info.getLimits() : NONE; } + + public SolrQueryResponse getRsp() { + return rsp; + } } diff --git a/solr/core/src/java/org/apache/solr/search/join/BlockJoinParentQParser.java b/solr/core/src/java/org/apache/solr/search/join/BlockJoinParentQParser.java index 0b03f720791d..d3540d702b4d 100644 --- a/solr/core/src/java/org/apache/solr/search/join/BlockJoinParentQParser.java +++ b/solr/core/src/java/org/apache/solr/search/join/BlockJoinParentQParser.java @@ -44,9 +44,14 @@ import org.apache.lucene.util.BitSetIterator; import org.apache.solr.common.params.SolrParams; import org.apache.solr.request.SolrQueryRequest; +import org.apache.solr.response.transform.ChildDocTransformer; +import org.apache.solr.response.transform.DocTransformer; +import org.apache.solr.response.transform.DocTransformers; import org.apache.solr.search.ExtendedQueryBase; import org.apache.solr.search.QParser; +import org.apache.solr.search.QueryLimits; import org.apache.solr.search.QueryUtils; +import org.apache.solr.search.ReturnFields; import org.apache.solr.search.SolrCache; import org.apache.solr.search.SyntaxError; @@ -134,6 +139,34 @@ protected Query createQuery(final Query allParents, BooleanQuery childrenQuery, } } + private void setAppropriateChildrenListingTransformer(SolrQueryRequest request, String fieldName, Query knnOnVectorField) throws IOException { + QueryLimits currentLimits = QueryLimits.getCurrentLimits(); + ReturnFields returnFields = currentLimits.getRsp().getReturnFields(); + DocTransformer originalTransformer = returnFields.getTransformer(); + + if (originalTransformer instanceof DocTransformers) { + DocTransformers transformers = (DocTransformers) originalTransformer; + boolean noChildTransformer = true; + for (int i = 0; i < transformers.size() && noChildTransformer; i++) { + DocTransformer t = transformers.getTransformer(i); + if (t instanceof ChildDocTransformer) { + ChildDocTransformer childTransformer = (ChildDocTransformer) t; + if(childTransformer.getChildDocSet() == null) { + childTransformer.setChildDocSet(); + } + noChildTransformer = false; + } + } + } else { + if ((originalTransformer instanceof ChildDocTransformer)) { + ChildDocTransformer childTransformer = (ChildDocTransformer) originalTransformer; + if(childTransformer.getChildDocSet() == null) { + childTransformer.setChildDocSet(); + } + } + } + } + private boolean isFloatKnnQuery(List childrenClauses) { return childrenClauses.size() == 1 && childrenClauses.get(0).getQuery().getClass().equals(KnnFloatVectorQuery.class); diff --git a/solr/core/src/test/org/apache/solr/search/join/BlockJoinNestedVectorsQParserTest.java b/solr/core/src/test/org/apache/solr/search/join/BlockJoinNestedVectorsQParserTest.java index 764febdd9af9..8015ddb96140 100644 --- a/solr/core/src/test/org/apache/solr/search/join/BlockJoinNestedVectorsQParserTest.java +++ b/solr/core/src/test/org/apache/solr/search/join/BlockJoinNestedVectorsQParserTest.java @@ -21,6 +21,7 @@ import java.util.List; import org.apache.solr.SolrTestCaseJ4; import org.apache.solr.common.SolrInputDocument; +import org.apache.solr.common.params.CommonParams; import org.apache.solr.util.RandomNoReverseMergePolicyFactory; import org.junit.BeforeClass; import org.junit.ClassRule; @@ -257,4 +258,72 @@ public void parentRetrievalByte_knnChildrenWithParentFilter_shouldReturnKnnParen "//result/doc[1]/str[@name='id'][.='8']", "//result/doc[2]/str[@name='id'][.='2']"); } + + @Test + public void parentRetrievalFloat_topKWithChildTransformerWithFilter_shouldUseOriginalChildTransformerFilter() { + assertQ( + req( + "fq", "parent_s:(b c)", + "q", "{!parent which=$allParents score=max v=$children.q}", + "fl", "id,score,vectors,vector,[child limit=2 fl=vector childFilter=$all_children]", + "children.q", "{!knn f=vector topK=3}" + FLOAT_QUERY_VECTOR, + "allParents", "parent_s:[* TO *]", + "all_children", "child_s:[* TO *]"), + "//result[@numFound='3']", + "//result/doc[1]/str[@name='id'][.='8']", + "//result/doc[1]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[1][.='10.0']", + "//result/doc[1]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[2][.='1.0']", + "//result/doc[1]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[3][.='1.0']", + "//result/doc[1]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[4][.='1.0']", + "//result/doc[1]/arr[@name='vectors'][1]/doc[2]/arr[@name='vector']/float[1][.='9.0']", + "//result/doc[1]/arr[@name='vectors'][1]/doc[2]/arr[@name='vector']/float[2][.='1.0']", + "//result/doc[1]/arr[@name='vectors'][1]/doc[2]/arr[@name='vector']/float[3][.='1.0']", + "//result/doc[1]/arr[@name='vectors'][1]/doc[2]/arr[@name='vector']/float[4][.='1.0']", + "//result/doc[2]/str[@name='id'][.='7']", + "//result/doc[2]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[1][.='13.0']", + "//result/doc[2]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[2][.='1.0']", + "//result/doc[2]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[3][.='1.0']", + "//result/doc[2]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[4][.='1.0']", + "//result/doc[2]/arr[@name='vectors'][1]/doc[2]/arr[@name='vector']/float[1][.='12.0']", + "//result/doc[2]/arr[@name='vectors'][1]/doc[2]/arr[@name='vector']/float[2][.='1.0']", + "//result/doc[2]/arr[@name='vectors'][1]/doc[2]/arr[@name='vector']/float[3][.='1.0']", + "//result/doc[2]/arr[@name='vectors'][1]/doc[2]/arr[@name='vector']/float[4][.='1.0']", + "//result/doc[3]/str[@name='id'][.='2']", + "//result/doc[3]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[1][.='28.0']", + "//result/doc[3]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[2][.='1.0']", + "//result/doc[3]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[3][.='1.0']", + "//result/doc[3]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[4][.='1.0']", + "//result/doc[3]/arr[@name='vectors'][1]/doc[2]/arr[@name='vector']/float[1][.='27.0']", + "//result/doc[3]/arr[@name='vectors'][1]/doc[2]/arr[@name='vector']/float[2][.='1.0']", + "//result/doc[3]/arr[@name='vectors'][1]/doc[2]/arr[@name='vector']/float[3][.='1.0']", + "//result/doc[3]/arr[@name='vectors'][1]/doc[2]/arr[@name='vector']/float[4][.='1.0']"); + } + + @Test + public void parentRetrievalFloat_topKWithChildTransformerWithNoFilter_shouldUseBestChildrenVectorTransformerFilter() { + assertQ( + req( + "fq", "parent_s:(b c)", + "q", "{!parent which=$allParents score=max v=$children.q}", + "fl", "id,score,vectors,vector,[child fl=vector]", + "children.q", "{!knn f=vector topK=3}" + FLOAT_QUERY_VECTOR, + "allParents", "parent_s:[* TO *]"), + "//result[@numFound='3']", + "//result/doc[1]/str[@name='id'][.='8']", + "//result/doc[1]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[1][.='8.0']", + "//result/doc[1]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[2][.='1.0']", + "//result/doc[1]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[3][.='1.0']", + "//result/doc[1]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[4][.='1.0']", + "//result/doc[2]/str[@name='id'][.='7']", + "//result/doc[2]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[1][.='11.0']", + "//result/doc[2]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[2][.='1.0']", + "//result/doc[2]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[3][.='1.0']", + "//result/doc[2]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[4][.='1.0']", + "//result/doc[3]/str[@name='id'][.='2']", + "//result/doc[3]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[1][.='26.0']", + "//result/doc[3]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[2][.='1.0']", + "//result/doc[3]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[3][.='1.0']", + "//result/doc[3]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[4][.='1.0']"); + + } } From cdff4ad9ac0e9464b324c542ea40f7dabbf835d1 Mon Sep 17 00:00:00 2001 From: Alessandro Benedetti Date: Thu, 10 Apr 2025 12:59:39 +0100 Subject: [PATCH 18/43] add best child per document transformer --- .../search/join/BlockJoinParentQParser.java | 106 ++++++++++-------- 1 file changed, 58 insertions(+), 48 deletions(-) diff --git a/solr/core/src/java/org/apache/solr/search/join/BlockJoinParentQParser.java b/solr/core/src/java/org/apache/solr/search/join/BlockJoinParentQParser.java index d3540d702b4d..2484a321c1fa 100644 --- a/solr/core/src/java/org/apache/solr/search/join/BlockJoinParentQParser.java +++ b/solr/core/src/java/org/apache/solr/search/join/BlockJoinParentQParser.java @@ -94,52 +94,62 @@ protected Query noClausesQuery() throws SyntaxError { protected Query createQuery(final Query allParents, BooleanQuery childrenQuery, String scoreMode) throws SyntaxError { - List childrenClauses = childrenQuery.clauses(); - if (isByteKnnQuery(childrenClauses)) { - BitSetProducer allParentsBitSet = getBitSetProducer(allParents); - BooleanQuery parentsFilter = getParentsFilter(); - - KnnByteVectorQuery knnChildrenQuery = (KnnByteVectorQuery) childrenClauses.get(0).getQuery(); - String vectorField = knnChildrenQuery.getField(); - byte[] queryVector = knnChildrenQuery.getTargetCopy(); - int topK = knnChildrenQuery.getK(); - - Query acceptedChildren = - getChildrenFilter(knnChildrenQuery.getFilter(), parentsFilter, allParentsBitSet); - - Query knnChildren = - new DiversifyingChildrenByteKnnVectorQuery( - vectorField, queryVector, acceptedChildren, topK, allParentsBitSet); - return new ToParentBlockJoinQuery( - knnChildren, allParentsBitSet, ScoreModeParser.parse(scoreMode)); - } else if (isFloatKnnQuery(childrenClauses)) { - BitSetProducer allParentsBitSet = getBitSetProducer(allParents); - BooleanQuery parentsFilter = getParentsFilter(); - - KnnFloatVectorQuery knnChildrenQuery = - (KnnFloatVectorQuery) childrenClauses.get(0).getQuery(); - String vectorField = knnChildrenQuery.getField(); - float[] queryVector = knnChildrenQuery.getTargetCopy(); - int topK = knnChildrenQuery.getK(); - - Query childrenFilter = - getChildrenFilter(knnChildrenQuery.getFilter(), parentsFilter, allParentsBitSet); - - Query knnChildren = - new DiversifyingChildrenFloatKnnVectorQuery( - vectorField, queryVector, childrenFilter, topK, allParentsBitSet); - return new ToParentBlockJoinQuery( - knnChildren, allParentsBitSet, ScoreModeParser.parse(scoreMode)); - } else { - return new AllParentsAware( - childrenQuery, - getBitSetProducer(allParents), - ScoreModeParser.parse(scoreMode), - allParents); - } + try { + List childrenClauses = childrenQuery.clauses(); + if (isByteKnnQuery(childrenClauses)) { + BitSetProducer allParentsBitSet = getBitSetProducer(allParents); + BooleanQuery parentsFilter = getParentsFilter(); + + KnnByteVectorQuery knnChildrenQuery = (KnnByteVectorQuery) childrenClauses.get(0).getQuery(); + String vectorField = knnChildrenQuery.getField(); + byte[] queryVector = knnChildrenQuery.getTargetCopy(); + int topK = knnChildrenQuery.getK(); + + Query acceptedChildren = + getChildrenFilter(knnChildrenQuery.getFilter(), parentsFilter, allParentsBitSet); + + Query knnChildren = + new DiversifyingChildrenByteKnnVectorQuery( + vectorField, queryVector, acceptedChildren, topK, allParentsBitSet); + knnChildren = knnChildren.rewrite(req.getSearcher()); + this.setAppropriateChildrenListingTransformer(req,knnChildren); + + return new ToParentBlockJoinQuery( + knnChildren, allParentsBitSet, ScoreModeParser.parse(scoreMode)); + } else if (isFloatKnnQuery(childrenClauses)) { + BitSetProducer allParentsBitSet = getBitSetProducer(allParents); + BooleanQuery parentsFilter = getParentsFilter(); + + KnnFloatVectorQuery knnChildrenQuery = + (KnnFloatVectorQuery) childrenClauses.get(0).getQuery(); + String vectorField = knnChildrenQuery.getField(); + float[] queryVector = knnChildrenQuery.getTargetCopy(); + int topK = knnChildrenQuery.getK(); + + Query childrenFilter = + getChildrenFilter(knnChildrenQuery.getFilter(), parentsFilter, allParentsBitSet); + + Query knnChildren = + new DiversifyingChildrenFloatKnnVectorQuery( + vectorField, queryVector, childrenFilter, topK, allParentsBitSet); + knnChildren = knnChildren.rewrite(req.getSearcher()); + this.setAppropriateChildrenListingTransformer(req,knnChildren); + + return new ToParentBlockJoinQuery( + knnChildren, allParentsBitSet, ScoreModeParser.parse(scoreMode)); + } else { + return new AllParentsAware( + childrenQuery, + getBitSetProducer(allParents), + ScoreModeParser.parse(scoreMode), + allParents); + } + } catch (IOException e) { + throw new RuntimeException(e); + } } - private void setAppropriateChildrenListingTransformer(SolrQueryRequest request, String fieldName, Query knnOnVectorField) throws IOException { + private void setAppropriateChildrenListingTransformer(SolrQueryRequest request, Query knnOnVectorField) throws IOException { QueryLimits currentLimits = QueryLimits.getCurrentLimits(); ReturnFields returnFields = currentLimits.getRsp().getReturnFields(); DocTransformer originalTransformer = returnFields.getTransformer(); @@ -151,8 +161,8 @@ private void setAppropriateChildrenListingTransformer(SolrQueryRequest request, DocTransformer t = transformers.getTransformer(i); if (t instanceof ChildDocTransformer) { ChildDocTransformer childTransformer = (ChildDocTransformer) t; - if(childTransformer.getChildDocSet() == null) { - childTransformer.setChildDocSet(); + if (childTransformer.getChildDocSet() == null) { + childTransformer.setChildDocSet(request.getSearcher().getDocSet(knnOnVectorField)); } noChildTransformer = false; } @@ -160,8 +170,8 @@ private void setAppropriateChildrenListingTransformer(SolrQueryRequest request, } else { if ((originalTransformer instanceof ChildDocTransformer)) { ChildDocTransformer childTransformer = (ChildDocTransformer) originalTransformer; - if(childTransformer.getChildDocSet() == null) { - childTransformer.setChildDocSet(); + if (childTransformer.getChildDocSet() == null) { + childTransformer.setChildDocSet(request.getSearcher().getDocSet(knnOnVectorField)); } } } From 88a05366c8096e548727e5f3d89f44195e7a842c Mon Sep 17 00:00:00 2001 From: Alessandro Benedetti Date: Thu, 10 Apr 2025 13:18:54 +0100 Subject: [PATCH 19/43] add best child per document transformer --- .../search/join/BlockJoinParentQParser.java | 3 +- .../BlockJoinNestedVectorsQParserTest.java | 121 ++++++++++++++++++ 2 files changed, 123 insertions(+), 1 deletion(-) diff --git a/solr/core/src/java/org/apache/solr/search/join/BlockJoinParentQParser.java b/solr/core/src/java/org/apache/solr/search/join/BlockJoinParentQParser.java index 2484a321c1fa..ce38a7394b03 100644 --- a/solr/core/src/java/org/apache/solr/search/join/BlockJoinParentQParser.java +++ b/solr/core/src/java/org/apache/solr/search/join/BlockJoinParentQParser.java @@ -42,6 +42,7 @@ import org.apache.lucene.search.join.ToParentBlockJoinQuery; import org.apache.lucene.util.BitSet; import org.apache.lucene.util.BitSetIterator; +import org.apache.solr.common.SolrException; import org.apache.solr.common.params.SolrParams; import org.apache.solr.request.SolrQueryRequest; import org.apache.solr.response.transform.ChildDocTransformer; @@ -145,7 +146,7 @@ protected Query createQuery(final Query allParents, BooleanQuery childrenQuery, allParents); } } catch (IOException e) { - throw new RuntimeException(e); + throw new SolrException(SolrException.ErrorCode.SERVER_ERROR, e); } } diff --git a/solr/core/src/test/org/apache/solr/search/join/BlockJoinNestedVectorsQParserTest.java b/solr/core/src/test/org/apache/solr/search/join/BlockJoinNestedVectorsQParserTest.java index 8015ddb96140..bf0d9d6e9e51 100644 --- a/solr/core/src/test/org/apache/solr/search/join/BlockJoinNestedVectorsQParserTest.java +++ b/solr/core/src/test/org/apache/solr/search/join/BlockJoinNestedVectorsQParserTest.java @@ -324,6 +324,127 @@ public void parentRetrievalFloat_topKWithChildTransformerWithNoFilter_shouldUseB "//result/doc[3]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[2][.='1.0']", "//result/doc[3]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[3][.='1.0']", "//result/doc[3]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[4][.='1.0']"); + } + + @Test + public void parentRetrievalFloat_topKWithOnlyChildTransformerWithNoFilter_shouldUseBestChildrenVectorTransformerFilter() { + assertQ( + req( + "fq", "parent_s:(b c)", + "q", "{!parent which=$allParents score=max v=$children.q}", + "fl", "id,vectors,vector,[child fl=vector]", + "children.q", "{!knn f=vector topK=3}" + FLOAT_QUERY_VECTOR, + "allParents", "parent_s:[* TO *]"), + "//result[@numFound='3']", + "//result/doc[1]/str[@name='id'][.='8']", + "//result/doc[1]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[1][.='8.0']", + "//result/doc[1]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[2][.='1.0']", + "//result/doc[1]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[3][.='1.0']", + "//result/doc[1]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[4][.='1.0']", + "//result/doc[2]/str[@name='id'][.='7']", + "//result/doc[2]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[1][.='11.0']", + "//result/doc[2]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[2][.='1.0']", + "//result/doc[2]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[3][.='1.0']", + "//result/doc[2]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[4][.='1.0']", + "//result/doc[3]/str[@name='id'][.='2']", + "//result/doc[3]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[1][.='26.0']", + "//result/doc[3]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[2][.='1.0']", + "//result/doc[3]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[3][.='1.0']", + "//result/doc[3]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[4][.='1.0']"); + } + + @Test + public void parentRetrievalByte_topKWithChildTransformerWithFilter_shouldUseOriginalChildTransformerFilter() { + assertQ( + req( + "fq", "parent_s:(b c)", + "q", "{!parent which=$allParents score=max v=$children.q}", + "fl", "id,score,vectors,vector_byte,[child limit=2 fl=vector_byte childFilter=$all_children]", + "children.q", "{!knn f=vector_byte topK=3}" + BYTE_QUERY_VECTOR, + "allParents", "parent_s:[* TO *]", + "all_children", "child_s:[* TO *]"), + "//result[@numFound='3']", + "//result/doc[1]/str[@name='id'][.='8']", + "//result/doc[1]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector_byte']/int[1][.='10']", + "//result/doc[1]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector_byte']/int[2][.='1']", + "//result/doc[1]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector_byte']/int[3][.='1']", + "//result/doc[1]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector_byte']/int[4][.='1']", + "//result/doc[1]/arr[@name='vectors'][1]/doc[2]/arr[@name='vector_byte']/int[1][.='9']", + "//result/doc[1]/arr[@name='vectors'][1]/doc[2]/arr[@name='vector_byte']/int[2][.='1']", + "//result/doc[1]/arr[@name='vectors'][1]/doc[2]/arr[@name='vector_byte']/int[3][.='1']", + "//result/doc[1]/arr[@name='vectors'][1]/doc[2]/arr[@name='vector_byte']/int[4][.='1']", + "//result/doc[2]/str[@name='id'][.='7']", + "//result/doc[2]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector_byte']/int[1][.='13']", + "//result/doc[2]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector_byte']/int[2][.='1']", + "//result/doc[2]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector_byte']/int[3][.='1']", + "//result/doc[2]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector_byte']/int[4][.='1']", + "//result/doc[2]/arr[@name='vectors'][1]/doc[2]/arr[@name='vector_byte']/int[1][.='12']", + "//result/doc[2]/arr[@name='vectors'][1]/doc[2]/arr[@name='vector_byte']/int[2][.='1']", + "//result/doc[2]/arr[@name='vectors'][1]/doc[2]/arr[@name='vector_byte']/int[3][.='1']", + "//result/doc[2]/arr[@name='vectors'][1]/doc[2]/arr[@name='vector_byte']/int[4][.='1']", + "//result/doc[3]/str[@name='id'][.='2']", + "//result/doc[3]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector_byte']/int[1][.='28']", + "//result/doc[3]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector_byte']/int[2][.='1']", + "//result/doc[3]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector_byte']/int[3][.='1']", + "//result/doc[3]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector_byte']/int[4][.='1']", + "//result/doc[3]/arr[@name='vectors'][1]/doc[2]/arr[@name='vector_byte']/int[1][.='27']", + "//result/doc[3]/arr[@name='vectors'][1]/doc[2]/arr[@name='vector_byte']/int[2][.='1']", + "//result/doc[3]/arr[@name='vectors'][1]/doc[2]/arr[@name='vector_byte']/int[3][.='1']", + "//result/doc[3]/arr[@name='vectors'][1]/doc[2]/arr[@name='vector_byte']/int[4][.='1']"); + } + @Test + public void parentRetrievalByte_topKWithChildTransformerWithNoFilter_shouldUseBestChildrenVectorTransformerFilter() { + assertQ( + req( + "fq", "parent_s:(b c)", + "q", "{!parent which=$allParents score=max v=$children.q}", + "fl", "id,score,vectors,vector_byte,[child fl=vector_byte]", + "children.q", "{!knn f=vector_byte topK=3}" + BYTE_QUERY_VECTOR, + "allParents", "parent_s:[* TO *]"), + "//result[@numFound='3']", + "//result/doc[1]/str[@name='id'][.='8']", + "//result/doc[1]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector_byte']/int[1][.='8']", + "//result/doc[1]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector_byte']/int[2][.='1']", + "//result/doc[1]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector_byte']/int[3][.='1']", + "//result/doc[1]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector_byte']/int[4][.='1']", + "//result/doc[2]/str[@name='id'][.='7']", + "//result/doc[2]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector_byte']/int[1][.='11']", + "//result/doc[2]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector_byte']/int[2][.='1']", + "//result/doc[2]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector_byte']/int[3][.='1']", + "//result/doc[2]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector_byte']/int[4][.='1']", + "//result/doc[3]/str[@name='id'][.='2']", + "//result/doc[3]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector_byte']/int[1][.='26']", + "//result/doc[3]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector_byte']/int[2][.='1']", + "//result/doc[3]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector_byte']/int[3][.='1']", + "//result/doc[3]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector_byte']/int[4][.='1']"); } + + @Test + public void parentRetrievalByte_topKWithOnlyChildTransformerWithNoFilter_shouldUseBestChildrenVectorTransformerFilter() { + assertQ( + req( + "fq", "parent_s:(b c)", + "q", "{!parent which=$allParents score=max v=$children.q}", + "fl", "id,vectors,vector_byte,[child fl=vector_byte]", + "children.q", "{!knn f=vector_byte topK=3}" + BYTE_QUERY_VECTOR, + "allParents", "parent_s:[* TO *]"), + "//result[@numFound='3']", + "//result/doc[1]/str[@name='id'][.='8']", + "//result/doc[1]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector_byte']/int[1][.='8']", + "//result/doc[1]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector_byte']/int[2][.='1']", + "//result/doc[1]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector_byte']/int[3][.='1']", + "//result/doc[1]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector_byte']/int[4][.='1']", + "//result/doc[2]/str[@name='id'][.='7']", + "//result/doc[2]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector_byte']/int[1][.='11']", + "//result/doc[2]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector_byte']/int[2][.='1']", + "//result/doc[2]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector_byte']/int[3][.='1']", + "//result/doc[2]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector_byte']/int[4][.='1']", + "//result/doc[3]/str[@name='id'][.='2']", + "//result/doc[3]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector_byte']/int[1][.='26']", + "//result/doc[3]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector_byte']/int[2][.='1']", + "//result/doc[3]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector_byte']/int[3][.='1']", + "//result/doc[3]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector_byte']/int[4][.='1']"); + } + } From fcbf770d2019731d7e366c9cb94d548a18a9522c Mon Sep 17 00:00:00 2001 From: Alessandro Benedetti Date: Wed, 16 Apr 2025 17:44:52 +0100 Subject: [PATCH 20/43] minor refinement to avoid some instructions --- .../search/join/BlockJoinParentQParser.java | 130 ++++--- .../BlockJoinNestedVectorsQParserTest.java | 339 +++++++++--------- 2 files changed, 243 insertions(+), 226 deletions(-) diff --git a/solr/core/src/java/org/apache/solr/search/join/BlockJoinParentQParser.java b/solr/core/src/java/org/apache/solr/search/join/BlockJoinParentQParser.java index ce38a7394b03..7e64f385ac84 100644 --- a/solr/core/src/java/org/apache/solr/search/join/BlockJoinParentQParser.java +++ b/solr/core/src/java/org/apache/solr/search/join/BlockJoinParentQParser.java @@ -95,62 +95,64 @@ protected Query noClausesQuery() throws SyntaxError { protected Query createQuery(final Query allParents, BooleanQuery childrenQuery, String scoreMode) throws SyntaxError { - try { - List childrenClauses = childrenQuery.clauses(); - if (isByteKnnQuery(childrenClauses)) { - BitSetProducer allParentsBitSet = getBitSetProducer(allParents); - BooleanQuery parentsFilter = getParentsFilter(); - - KnnByteVectorQuery knnChildrenQuery = (KnnByteVectorQuery) childrenClauses.get(0).getQuery(); - String vectorField = knnChildrenQuery.getField(); - byte[] queryVector = knnChildrenQuery.getTargetCopy(); - int topK = knnChildrenQuery.getK(); - - Query acceptedChildren = - getChildrenFilter(knnChildrenQuery.getFilter(), parentsFilter, allParentsBitSet); - - Query knnChildren = - new DiversifyingChildrenByteKnnVectorQuery( - vectorField, queryVector, acceptedChildren, topK, allParentsBitSet); - knnChildren = knnChildren.rewrite(req.getSearcher()); - this.setAppropriateChildrenListingTransformer(req,knnChildren); - - return new ToParentBlockJoinQuery( - knnChildren, allParentsBitSet, ScoreModeParser.parse(scoreMode)); - } else if (isFloatKnnQuery(childrenClauses)) { - BitSetProducer allParentsBitSet = getBitSetProducer(allParents); - BooleanQuery parentsFilter = getParentsFilter(); - - KnnFloatVectorQuery knnChildrenQuery = - (KnnFloatVectorQuery) childrenClauses.get(0).getQuery(); - String vectorField = knnChildrenQuery.getField(); - float[] queryVector = knnChildrenQuery.getTargetCopy(); - int topK = knnChildrenQuery.getK(); - - Query childrenFilter = - getChildrenFilter(knnChildrenQuery.getFilter(), parentsFilter, allParentsBitSet); - - Query knnChildren = - new DiversifyingChildrenFloatKnnVectorQuery( - vectorField, queryVector, childrenFilter, topK, allParentsBitSet); - knnChildren = knnChildren.rewrite(req.getSearcher()); - this.setAppropriateChildrenListingTransformer(req,knnChildren); - - return new ToParentBlockJoinQuery( - knnChildren, allParentsBitSet, ScoreModeParser.parse(scoreMode)); - } else { - return new AllParentsAware( - childrenQuery, - getBitSetProducer(allParents), - ScoreModeParser.parse(scoreMode), - allParents); - } - } catch (IOException e) { - throw new SolrException(SolrException.ErrorCode.SERVER_ERROR, e); + try { + List childrenClauses = childrenQuery.clauses(); + KnnByteVectorQuery knnByteChildrenQuery = getBytetKnnQuery(childrenClauses); + if (knnByteChildrenQuery != null) { + BitSetProducer allParentsBitSet = getBitSetProducer(allParents); + BooleanQuery parentsFilter = getParentsFilter(); + + String vectorField = knnByteChildrenQuery.getField(); + byte[] queryVector = knnByteChildrenQuery.getTargetCopy(); + int topK = knnByteChildrenQuery.getK(); + + Query acceptedChildren = + getChildrenFilter(knnByteChildrenQuery.getFilter(), parentsFilter, allParentsBitSet); + + Query knnChildren = + new DiversifyingChildrenByteKnnVectorQuery( + vectorField, queryVector, acceptedChildren, topK, allParentsBitSet); + knnChildren = knnChildren.rewrite(req.getSearcher()); + this.setAppropriateChildrenListingTransformer(req, knnChildren); + + return new ToParentBlockJoinQuery( + knnChildren, allParentsBitSet, ScoreModeParser.parse(scoreMode)); + } else { + KnnFloatVectorQuery knnFLoatChildrenQuery = getFloatKnnQuery(childrenClauses); + if (knnFLoatChildrenQuery != null) { + BitSetProducer allParentsBitSet = getBitSetProducer(allParents); + BooleanQuery parentsFilter = getParentsFilter(); + + String vectorField = knnFLoatChildrenQuery.getField(); + float[] queryVector = knnFLoatChildrenQuery.getTargetCopy(); + int topK = knnFLoatChildrenQuery.getK(); + + Query childrenFilter = + getChildrenFilter(knnFLoatChildrenQuery.getFilter(), parentsFilter, allParentsBitSet); + + Query knnChildren = + new DiversifyingChildrenFloatKnnVectorQuery( + vectorField, queryVector, childrenFilter, topK, allParentsBitSet); + knnChildren = knnChildren.rewrite(req.getSearcher()); + this.setAppropriateChildrenListingTransformer(req, knnChildren); + + return new ToParentBlockJoinQuery( + knnChildren, allParentsBitSet, ScoreModeParser.parse(scoreMode)); + } else { + return new AllParentsAware( + childrenQuery, + getBitSetProducer(allParents), + ScoreModeParser.parse(scoreMode), + allParents); + } } + } catch (IOException e) { + throw new SolrException(SolrException.ErrorCode.SERVER_ERROR, e); + } } - private void setAppropriateChildrenListingTransformer(SolrQueryRequest request, Query knnOnVectorField) throws IOException { + private void setAppropriateChildrenListingTransformer( + SolrQueryRequest request, Query knnOnVectorField) throws IOException { QueryLimits currentLimits = QueryLimits.getCurrentLimits(); ReturnFields returnFields = currentLimits.getRsp().getReturnFields(); DocTransformer originalTransformer = returnFields.getTransformer(); @@ -178,14 +180,24 @@ private void setAppropriateChildrenListingTransformer(SolrQueryRequest request, } } - private boolean isFloatKnnQuery(List childrenClauses) { - return childrenClauses.size() == 1 - && childrenClauses.get(0).getQuery().getClass().equals(KnnFloatVectorQuery.class); + private KnnFloatVectorQuery getFloatKnnQuery(List childrenClauses) { + if (childrenClauses.size() == 1) { + Query query = childrenClauses.get(0).getQuery(); + if (query instanceof KnnFloatVectorQuery) { + return (KnnFloatVectorQuery) query; + } + } + return null; } - private boolean isByteKnnQuery(List childrenClauses) { - return childrenClauses.size() == 1 - && childrenClauses.get(0).getQuery().getClass().equals(KnnByteVectorQuery.class); + private KnnByteVectorQuery getBytetKnnQuery(List childrenClauses) { + if (childrenClauses.size() == 1) { + Query query = childrenClauses.get(0).getQuery(); + if (query instanceof KnnByteVectorQuery) { + return (KnnByteVectorQuery) query; + } + } + return null; } private Query getChildrenFilter( diff --git a/solr/core/src/test/org/apache/solr/search/join/BlockJoinNestedVectorsQParserTest.java b/solr/core/src/test/org/apache/solr/search/join/BlockJoinNestedVectorsQParserTest.java index bf0d9d6e9e51..3d198a1cded2 100644 --- a/solr/core/src/test/org/apache/solr/search/join/BlockJoinNestedVectorsQParserTest.java +++ b/solr/core/src/test/org/apache/solr/search/join/BlockJoinNestedVectorsQParserTest.java @@ -21,7 +21,6 @@ import java.util.List; import org.apache.solr.SolrTestCaseJ4; import org.apache.solr.common.SolrInputDocument; -import org.apache.solr.common.params.CommonParams; import org.apache.solr.util.RandomNoReverseMergePolicyFactory; import org.junit.BeforeClass; import org.junit.ClassRule; @@ -258,193 +257,199 @@ public void parentRetrievalByte_knnChildrenWithParentFilter_shouldReturnKnnParen "//result/doc[1]/str[@name='id'][.='8']", "//result/doc[2]/str[@name='id'][.='2']"); } - + @Test - public void parentRetrievalFloat_topKWithChildTransformerWithFilter_shouldUseOriginalChildTransformerFilter() { + public void + parentRetrievalFloat_topKWithChildTransformerWithFilter_shouldUseOriginalChildTransformerFilter() { assertQ( - req( - "fq", "parent_s:(b c)", - "q", "{!parent which=$allParents score=max v=$children.q}", - "fl", "id,score,vectors,vector,[child limit=2 fl=vector childFilter=$all_children]", - "children.q", "{!knn f=vector topK=3}" + FLOAT_QUERY_VECTOR, - "allParents", "parent_s:[* TO *]", - "all_children", "child_s:[* TO *]"), - "//result[@numFound='3']", - "//result/doc[1]/str[@name='id'][.='8']", - "//result/doc[1]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[1][.='10.0']", - "//result/doc[1]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[2][.='1.0']", - "//result/doc[1]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[3][.='1.0']", - "//result/doc[1]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[4][.='1.0']", - "//result/doc[1]/arr[@name='vectors'][1]/doc[2]/arr[@name='vector']/float[1][.='9.0']", - "//result/doc[1]/arr[@name='vectors'][1]/doc[2]/arr[@name='vector']/float[2][.='1.0']", - "//result/doc[1]/arr[@name='vectors'][1]/doc[2]/arr[@name='vector']/float[3][.='1.0']", - "//result/doc[1]/arr[@name='vectors'][1]/doc[2]/arr[@name='vector']/float[4][.='1.0']", - "//result/doc[2]/str[@name='id'][.='7']", - "//result/doc[2]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[1][.='13.0']", - "//result/doc[2]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[2][.='1.0']", - "//result/doc[2]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[3][.='1.0']", - "//result/doc[2]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[4][.='1.0']", - "//result/doc[2]/arr[@name='vectors'][1]/doc[2]/arr[@name='vector']/float[1][.='12.0']", - "//result/doc[2]/arr[@name='vectors'][1]/doc[2]/arr[@name='vector']/float[2][.='1.0']", - "//result/doc[2]/arr[@name='vectors'][1]/doc[2]/arr[@name='vector']/float[3][.='1.0']", - "//result/doc[2]/arr[@name='vectors'][1]/doc[2]/arr[@name='vector']/float[4][.='1.0']", - "//result/doc[3]/str[@name='id'][.='2']", - "//result/doc[3]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[1][.='28.0']", - "//result/doc[3]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[2][.='1.0']", - "//result/doc[3]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[3][.='1.0']", - "//result/doc[3]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[4][.='1.0']", - "//result/doc[3]/arr[@name='vectors'][1]/doc[2]/arr[@name='vector']/float[1][.='27.0']", - "//result/doc[3]/arr[@name='vectors'][1]/doc[2]/arr[@name='vector']/float[2][.='1.0']", - "//result/doc[3]/arr[@name='vectors'][1]/doc[2]/arr[@name='vector']/float[3][.='1.0']", - "//result/doc[3]/arr[@name='vectors'][1]/doc[2]/arr[@name='vector']/float[4][.='1.0']"); + req( + "fq", "parent_s:(b c)", + "q", "{!parent which=$allParents score=max v=$children.q}", + "fl", "id,score,vectors,vector,[child limit=2 fl=vector childFilter=$all_children]", + "children.q", "{!knn f=vector topK=3}" + FLOAT_QUERY_VECTOR, + "allParents", "parent_s:[* TO *]", + "all_children", "child_s:[* TO *]"), + "//result[@numFound='3']", + "//result/doc[1]/str[@name='id'][.='8']", + "//result/doc[1]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[1][.='10.0']", + "//result/doc[1]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[2][.='1.0']", + "//result/doc[1]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[3][.='1.0']", + "//result/doc[1]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[4][.='1.0']", + "//result/doc[1]/arr[@name='vectors'][1]/doc[2]/arr[@name='vector']/float[1][.='9.0']", + "//result/doc[1]/arr[@name='vectors'][1]/doc[2]/arr[@name='vector']/float[2][.='1.0']", + "//result/doc[1]/arr[@name='vectors'][1]/doc[2]/arr[@name='vector']/float[3][.='1.0']", + "//result/doc[1]/arr[@name='vectors'][1]/doc[2]/arr[@name='vector']/float[4][.='1.0']", + "//result/doc[2]/str[@name='id'][.='7']", + "//result/doc[2]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[1][.='13.0']", + "//result/doc[2]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[2][.='1.0']", + "//result/doc[2]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[3][.='1.0']", + "//result/doc[2]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[4][.='1.0']", + "//result/doc[2]/arr[@name='vectors'][1]/doc[2]/arr[@name='vector']/float[1][.='12.0']", + "//result/doc[2]/arr[@name='vectors'][1]/doc[2]/arr[@name='vector']/float[2][.='1.0']", + "//result/doc[2]/arr[@name='vectors'][1]/doc[2]/arr[@name='vector']/float[3][.='1.0']", + "//result/doc[2]/arr[@name='vectors'][1]/doc[2]/arr[@name='vector']/float[4][.='1.0']", + "//result/doc[3]/str[@name='id'][.='2']", + "//result/doc[3]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[1][.='28.0']", + "//result/doc[3]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[2][.='1.0']", + "//result/doc[3]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[3][.='1.0']", + "//result/doc[3]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[4][.='1.0']", + "//result/doc[3]/arr[@name='vectors'][1]/doc[2]/arr[@name='vector']/float[1][.='27.0']", + "//result/doc[3]/arr[@name='vectors'][1]/doc[2]/arr[@name='vector']/float[2][.='1.0']", + "//result/doc[3]/arr[@name='vectors'][1]/doc[2]/arr[@name='vector']/float[3][.='1.0']", + "//result/doc[3]/arr[@name='vectors'][1]/doc[2]/arr[@name='vector']/float[4][.='1.0']"); } @Test - public void parentRetrievalFloat_topKWithChildTransformerWithNoFilter_shouldUseBestChildrenVectorTransformerFilter() { + public void + parentRetrievalFloat_topKWithChildTransformerWithNoFilter_shouldUseBestChildrenVectorTransformerFilter() { assertQ( - req( - "fq", "parent_s:(b c)", - "q", "{!parent which=$allParents score=max v=$children.q}", - "fl", "id,score,vectors,vector,[child fl=vector]", - "children.q", "{!knn f=vector topK=3}" + FLOAT_QUERY_VECTOR, - "allParents", "parent_s:[* TO *]"), - "//result[@numFound='3']", - "//result/doc[1]/str[@name='id'][.='8']", - "//result/doc[1]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[1][.='8.0']", - "//result/doc[1]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[2][.='1.0']", - "//result/doc[1]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[3][.='1.0']", - "//result/doc[1]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[4][.='1.0']", - "//result/doc[2]/str[@name='id'][.='7']", - "//result/doc[2]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[1][.='11.0']", - "//result/doc[2]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[2][.='1.0']", - "//result/doc[2]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[3][.='1.0']", - "//result/doc[2]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[4][.='1.0']", - "//result/doc[3]/str[@name='id'][.='2']", - "//result/doc[3]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[1][.='26.0']", - "//result/doc[3]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[2][.='1.0']", - "//result/doc[3]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[3][.='1.0']", - "//result/doc[3]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[4][.='1.0']"); + req( + "fq", "parent_s:(b c)", + "q", "{!parent which=$allParents score=max v=$children.q}", + "fl", "id,score,vectors,vector,[child fl=vector]", + "children.q", "{!knn f=vector topK=3}" + FLOAT_QUERY_VECTOR, + "allParents", "parent_s:[* TO *]"), + "//result[@numFound='3']", + "//result/doc[1]/str[@name='id'][.='8']", + "//result/doc[1]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[1][.='8.0']", + "//result/doc[1]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[2][.='1.0']", + "//result/doc[1]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[3][.='1.0']", + "//result/doc[1]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[4][.='1.0']", + "//result/doc[2]/str[@name='id'][.='7']", + "//result/doc[2]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[1][.='11.0']", + "//result/doc[2]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[2][.='1.0']", + "//result/doc[2]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[3][.='1.0']", + "//result/doc[2]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[4][.='1.0']", + "//result/doc[3]/str[@name='id'][.='2']", + "//result/doc[3]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[1][.='26.0']", + "//result/doc[3]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[2][.='1.0']", + "//result/doc[3]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[3][.='1.0']", + "//result/doc[3]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[4][.='1.0']"); } @Test - public void parentRetrievalFloat_topKWithOnlyChildTransformerWithNoFilter_shouldUseBestChildrenVectorTransformerFilter() { + public void + parentRetrievalFloat_topKWithOnlyChildTransformerWithNoFilter_shouldUseBestChildrenVectorTransformerFilter() { assertQ( - req( - "fq", "parent_s:(b c)", - "q", "{!parent which=$allParents score=max v=$children.q}", - "fl", "id,vectors,vector,[child fl=vector]", - "children.q", "{!knn f=vector topK=3}" + FLOAT_QUERY_VECTOR, - "allParents", "parent_s:[* TO *]"), - "//result[@numFound='3']", - "//result/doc[1]/str[@name='id'][.='8']", - "//result/doc[1]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[1][.='8.0']", - "//result/doc[1]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[2][.='1.0']", - "//result/doc[1]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[3][.='1.0']", - "//result/doc[1]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[4][.='1.0']", - "//result/doc[2]/str[@name='id'][.='7']", - "//result/doc[2]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[1][.='11.0']", - "//result/doc[2]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[2][.='1.0']", - "//result/doc[2]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[3][.='1.0']", - "//result/doc[2]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[4][.='1.0']", - "//result/doc[3]/str[@name='id'][.='2']", - "//result/doc[3]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[1][.='26.0']", - "//result/doc[3]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[2][.='1.0']", - "//result/doc[3]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[3][.='1.0']", - "//result/doc[3]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[4][.='1.0']"); + req( + "fq", "parent_s:(b c)", + "q", "{!parent which=$allParents score=max v=$children.q}", + "fl", "id,vectors,vector,[child fl=vector]", + "children.q", "{!knn f=vector topK=3}" + FLOAT_QUERY_VECTOR, + "allParents", "parent_s:[* TO *]"), + "//result[@numFound='3']", + "//result/doc[1]/str[@name='id'][.='8']", + "//result/doc[1]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[1][.='8.0']", + "//result/doc[1]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[2][.='1.0']", + "//result/doc[1]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[3][.='1.0']", + "//result/doc[1]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[4][.='1.0']", + "//result/doc[2]/str[@name='id'][.='7']", + "//result/doc[2]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[1][.='11.0']", + "//result/doc[2]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[2][.='1.0']", + "//result/doc[2]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[3][.='1.0']", + "//result/doc[2]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[4][.='1.0']", + "//result/doc[3]/str[@name='id'][.='2']", + "//result/doc[3]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[1][.='26.0']", + "//result/doc[3]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[2][.='1.0']", + "//result/doc[3]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[3][.='1.0']", + "//result/doc[3]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[4][.='1.0']"); } @Test - public void parentRetrievalByte_topKWithChildTransformerWithFilter_shouldUseOriginalChildTransformerFilter() { + public void + parentRetrievalByte_topKWithChildTransformerWithFilter_shouldUseOriginalChildTransformerFilter() { assertQ( - req( - "fq", "parent_s:(b c)", - "q", "{!parent which=$allParents score=max v=$children.q}", - "fl", "id,score,vectors,vector_byte,[child limit=2 fl=vector_byte childFilter=$all_children]", - "children.q", "{!knn f=vector_byte topK=3}" + BYTE_QUERY_VECTOR, - "allParents", "parent_s:[* TO *]", - "all_children", "child_s:[* TO *]"), - "//result[@numFound='3']", - "//result/doc[1]/str[@name='id'][.='8']", - "//result/doc[1]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector_byte']/int[1][.='10']", - "//result/doc[1]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector_byte']/int[2][.='1']", - "//result/doc[1]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector_byte']/int[3][.='1']", - "//result/doc[1]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector_byte']/int[4][.='1']", - "//result/doc[1]/arr[@name='vectors'][1]/doc[2]/arr[@name='vector_byte']/int[1][.='9']", - "//result/doc[1]/arr[@name='vectors'][1]/doc[2]/arr[@name='vector_byte']/int[2][.='1']", - "//result/doc[1]/arr[@name='vectors'][1]/doc[2]/arr[@name='vector_byte']/int[3][.='1']", - "//result/doc[1]/arr[@name='vectors'][1]/doc[2]/arr[@name='vector_byte']/int[4][.='1']", - "//result/doc[2]/str[@name='id'][.='7']", - "//result/doc[2]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector_byte']/int[1][.='13']", - "//result/doc[2]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector_byte']/int[2][.='1']", - "//result/doc[2]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector_byte']/int[3][.='1']", - "//result/doc[2]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector_byte']/int[4][.='1']", - "//result/doc[2]/arr[@name='vectors'][1]/doc[2]/arr[@name='vector_byte']/int[1][.='12']", - "//result/doc[2]/arr[@name='vectors'][1]/doc[2]/arr[@name='vector_byte']/int[2][.='1']", - "//result/doc[2]/arr[@name='vectors'][1]/doc[2]/arr[@name='vector_byte']/int[3][.='1']", - "//result/doc[2]/arr[@name='vectors'][1]/doc[2]/arr[@name='vector_byte']/int[4][.='1']", - "//result/doc[3]/str[@name='id'][.='2']", - "//result/doc[3]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector_byte']/int[1][.='28']", - "//result/doc[3]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector_byte']/int[2][.='1']", - "//result/doc[3]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector_byte']/int[3][.='1']", - "//result/doc[3]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector_byte']/int[4][.='1']", - "//result/doc[3]/arr[@name='vectors'][1]/doc[2]/arr[@name='vector_byte']/int[1][.='27']", - "//result/doc[3]/arr[@name='vectors'][1]/doc[2]/arr[@name='vector_byte']/int[2][.='1']", - "//result/doc[3]/arr[@name='vectors'][1]/doc[2]/arr[@name='vector_byte']/int[3][.='1']", - "//result/doc[3]/arr[@name='vectors'][1]/doc[2]/arr[@name='vector_byte']/int[4][.='1']"); + req( + "fq", "parent_s:(b c)", + "q", "{!parent which=$allParents score=max v=$children.q}", + "fl", + "id,score,vectors,vector_byte,[child limit=2 fl=vector_byte childFilter=$all_children]", + "children.q", "{!knn f=vector_byte topK=3}" + BYTE_QUERY_VECTOR, + "allParents", "parent_s:[* TO *]", + "all_children", "child_s:[* TO *]"), + "//result[@numFound='3']", + "//result/doc[1]/str[@name='id'][.='8']", + "//result/doc[1]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector_byte']/int[1][.='10']", + "//result/doc[1]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector_byte']/int[2][.='1']", + "//result/doc[1]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector_byte']/int[3][.='1']", + "//result/doc[1]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector_byte']/int[4][.='1']", + "//result/doc[1]/arr[@name='vectors'][1]/doc[2]/arr[@name='vector_byte']/int[1][.='9']", + "//result/doc[1]/arr[@name='vectors'][1]/doc[2]/arr[@name='vector_byte']/int[2][.='1']", + "//result/doc[1]/arr[@name='vectors'][1]/doc[2]/arr[@name='vector_byte']/int[3][.='1']", + "//result/doc[1]/arr[@name='vectors'][1]/doc[2]/arr[@name='vector_byte']/int[4][.='1']", + "//result/doc[2]/str[@name='id'][.='7']", + "//result/doc[2]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector_byte']/int[1][.='13']", + "//result/doc[2]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector_byte']/int[2][.='1']", + "//result/doc[2]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector_byte']/int[3][.='1']", + "//result/doc[2]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector_byte']/int[4][.='1']", + "//result/doc[2]/arr[@name='vectors'][1]/doc[2]/arr[@name='vector_byte']/int[1][.='12']", + "//result/doc[2]/arr[@name='vectors'][1]/doc[2]/arr[@name='vector_byte']/int[2][.='1']", + "//result/doc[2]/arr[@name='vectors'][1]/doc[2]/arr[@name='vector_byte']/int[3][.='1']", + "//result/doc[2]/arr[@name='vectors'][1]/doc[2]/arr[@name='vector_byte']/int[4][.='1']", + "//result/doc[3]/str[@name='id'][.='2']", + "//result/doc[3]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector_byte']/int[1][.='28']", + "//result/doc[3]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector_byte']/int[2][.='1']", + "//result/doc[3]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector_byte']/int[3][.='1']", + "//result/doc[3]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector_byte']/int[4][.='1']", + "//result/doc[3]/arr[@name='vectors'][1]/doc[2]/arr[@name='vector_byte']/int[1][.='27']", + "//result/doc[3]/arr[@name='vectors'][1]/doc[2]/arr[@name='vector_byte']/int[2][.='1']", + "//result/doc[3]/arr[@name='vectors'][1]/doc[2]/arr[@name='vector_byte']/int[3][.='1']", + "//result/doc[3]/arr[@name='vectors'][1]/doc[2]/arr[@name='vector_byte']/int[4][.='1']"); } @Test - public void parentRetrievalByte_topKWithChildTransformerWithNoFilter_shouldUseBestChildrenVectorTransformerFilter() { + public void + parentRetrievalByte_topKWithChildTransformerWithNoFilter_shouldUseBestChildrenVectorTransformerFilter() { assertQ( - req( - "fq", "parent_s:(b c)", - "q", "{!parent which=$allParents score=max v=$children.q}", - "fl", "id,score,vectors,vector_byte,[child fl=vector_byte]", - "children.q", "{!knn f=vector_byte topK=3}" + BYTE_QUERY_VECTOR, - "allParents", "parent_s:[* TO *]"), - "//result[@numFound='3']", - "//result/doc[1]/str[@name='id'][.='8']", - "//result/doc[1]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector_byte']/int[1][.='8']", - "//result/doc[1]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector_byte']/int[2][.='1']", - "//result/doc[1]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector_byte']/int[3][.='1']", - "//result/doc[1]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector_byte']/int[4][.='1']", - "//result/doc[2]/str[@name='id'][.='7']", - "//result/doc[2]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector_byte']/int[1][.='11']", - "//result/doc[2]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector_byte']/int[2][.='1']", - "//result/doc[2]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector_byte']/int[3][.='1']", - "//result/doc[2]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector_byte']/int[4][.='1']", - "//result/doc[3]/str[@name='id'][.='2']", - "//result/doc[3]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector_byte']/int[1][.='26']", - "//result/doc[3]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector_byte']/int[2][.='1']", - "//result/doc[3]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector_byte']/int[3][.='1']", - "//result/doc[3]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector_byte']/int[4][.='1']"); + req( + "fq", "parent_s:(b c)", + "q", "{!parent which=$allParents score=max v=$children.q}", + "fl", "id,score,vectors,vector_byte,[child fl=vector_byte]", + "children.q", "{!knn f=vector_byte topK=3}" + BYTE_QUERY_VECTOR, + "allParents", "parent_s:[* TO *]"), + "//result[@numFound='3']", + "//result/doc[1]/str[@name='id'][.='8']", + "//result/doc[1]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector_byte']/int[1][.='8']", + "//result/doc[1]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector_byte']/int[2][.='1']", + "//result/doc[1]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector_byte']/int[3][.='1']", + "//result/doc[1]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector_byte']/int[4][.='1']", + "//result/doc[2]/str[@name='id'][.='7']", + "//result/doc[2]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector_byte']/int[1][.='11']", + "//result/doc[2]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector_byte']/int[2][.='1']", + "//result/doc[2]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector_byte']/int[3][.='1']", + "//result/doc[2]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector_byte']/int[4][.='1']", + "//result/doc[3]/str[@name='id'][.='2']", + "//result/doc[3]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector_byte']/int[1][.='26']", + "//result/doc[3]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector_byte']/int[2][.='1']", + "//result/doc[3]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector_byte']/int[3][.='1']", + "//result/doc[3]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector_byte']/int[4][.='1']"); } @Test - public void parentRetrievalByte_topKWithOnlyChildTransformerWithNoFilter_shouldUseBestChildrenVectorTransformerFilter() { + public void + parentRetrievalByte_topKWithOnlyChildTransformerWithNoFilter_shouldUseBestChildrenVectorTransformerFilter() { assertQ( - req( - "fq", "parent_s:(b c)", - "q", "{!parent which=$allParents score=max v=$children.q}", - "fl", "id,vectors,vector_byte,[child fl=vector_byte]", - "children.q", "{!knn f=vector_byte topK=3}" + BYTE_QUERY_VECTOR, - "allParents", "parent_s:[* TO *]"), - "//result[@numFound='3']", - "//result/doc[1]/str[@name='id'][.='8']", - "//result/doc[1]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector_byte']/int[1][.='8']", - "//result/doc[1]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector_byte']/int[2][.='1']", - "//result/doc[1]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector_byte']/int[3][.='1']", - "//result/doc[1]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector_byte']/int[4][.='1']", - "//result/doc[2]/str[@name='id'][.='7']", - "//result/doc[2]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector_byte']/int[1][.='11']", - "//result/doc[2]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector_byte']/int[2][.='1']", - "//result/doc[2]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector_byte']/int[3][.='1']", - "//result/doc[2]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector_byte']/int[4][.='1']", - "//result/doc[3]/str[@name='id'][.='2']", - "//result/doc[3]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector_byte']/int[1][.='26']", - "//result/doc[3]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector_byte']/int[2][.='1']", - "//result/doc[3]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector_byte']/int[3][.='1']", - "//result/doc[3]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector_byte']/int[4][.='1']"); + req( + "fq", "parent_s:(b c)", + "q", "{!parent which=$allParents score=max v=$children.q}", + "fl", "id,vectors,vector_byte,[child fl=vector_byte]", + "children.q", "{!knn f=vector_byte topK=3}" + BYTE_QUERY_VECTOR, + "allParents", "parent_s:[* TO *]"), + "//result[@numFound='3']", + "//result/doc[1]/str[@name='id'][.='8']", + "//result/doc[1]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector_byte']/int[1][.='8']", + "//result/doc[1]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector_byte']/int[2][.='1']", + "//result/doc[1]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector_byte']/int[3][.='1']", + "//result/doc[1]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector_byte']/int[4][.='1']", + "//result/doc[2]/str[@name='id'][.='7']", + "//result/doc[2]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector_byte']/int[1][.='11']", + "//result/doc[2]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector_byte']/int[2][.='1']", + "//result/doc[2]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector_byte']/int[3][.='1']", + "//result/doc[2]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector_byte']/int[4][.='1']", + "//result/doc[3]/str[@name='id'][.='2']", + "//result/doc[3]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector_byte']/int[1][.='26']", + "//result/doc[3]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector_byte']/int[2][.='1']", + "//result/doc[3]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector_byte']/int[3][.='1']", + "//result/doc[3]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector_byte']/int[4][.='1']"); } - } From cae962ed307f386debea75078a7489f94da4828a Mon Sep 17 00:00:00 2001 From: Alessandro Benedetti Date: Wed, 16 Apr 2025 17:49:02 +0100 Subject: [PATCH 21/43] minor refactor --- .../solr/search/join/BlockJoinParentQParser.java | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/solr/core/src/java/org/apache/solr/search/join/BlockJoinParentQParser.java b/solr/core/src/java/org/apache/solr/search/join/BlockJoinParentQParser.java index 7e64f385ac84..30dd200f40ad 100644 --- a/solr/core/src/java/org/apache/solr/search/join/BlockJoinParentQParser.java +++ b/solr/core/src/java/org/apache/solr/search/join/BlockJoinParentQParser.java @@ -98,10 +98,10 @@ protected Query createQuery(final Query allParents, BooleanQuery childrenQuery, try { List childrenClauses = childrenQuery.clauses(); KnnByteVectorQuery knnByteChildrenQuery = getBytetKnnQuery(childrenClauses); - if (knnByteChildrenQuery != null) { - BitSetProducer allParentsBitSet = getBitSetProducer(allParents); - BooleanQuery parentsFilter = getParentsFilter(); + BitSetProducer allParentsBitSet = getBitSetProducer(allParents); + BooleanQuery parentsFilter = getParentsFilter(); + if (knnByteChildrenQuery != null) { String vectorField = knnByteChildrenQuery.getField(); byte[] queryVector = knnByteChildrenQuery.getTargetCopy(); int topK = knnByteChildrenQuery.getK(); @@ -120,9 +120,6 @@ protected Query createQuery(final Query allParents, BooleanQuery childrenQuery, } else { KnnFloatVectorQuery knnFLoatChildrenQuery = getFloatKnnQuery(childrenClauses); if (knnFLoatChildrenQuery != null) { - BitSetProducer allParentsBitSet = getBitSetProducer(allParents); - BooleanQuery parentsFilter = getParentsFilter(); - String vectorField = knnFLoatChildrenQuery.getField(); float[] queryVector = knnFLoatChildrenQuery.getTargetCopy(); int topK = knnFLoatChildrenQuery.getK(); @@ -140,10 +137,7 @@ protected Query createQuery(final Query allParents, BooleanQuery childrenQuery, knnChildren, allParentsBitSet, ScoreModeParser.parse(scoreMode)); } else { return new AllParentsAware( - childrenQuery, - getBitSetProducer(allParents), - ScoreModeParser.parse(scoreMode), - allParents); + childrenQuery, allParentsBitSet, ScoreModeParser.parse(scoreMode), allParents); } } } catch (IOException e) { From 4a7fdacacc05f6fabd1275fbda077b52fd562c06 Mon Sep 17 00:00:00 2001 From: Alessandro Benedetti Date: Wed, 16 Apr 2025 17:52:06 +0100 Subject: [PATCH 22/43] Update solr/core/src/java/org/apache/solr/search/join/BlockJoinParentQParser.java Co-authored-by: Christine Poerschke --- .../org/apache/solr/search/join/BlockJoinParentQParser.java | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/solr/core/src/java/org/apache/solr/search/join/BlockJoinParentQParser.java b/solr/core/src/java/org/apache/solr/search/join/BlockJoinParentQParser.java index 30dd200f40ad..f882408da5d5 100644 --- a/solr/core/src/java/org/apache/solr/search/join/BlockJoinParentQParser.java +++ b/solr/core/src/java/org/apache/solr/search/join/BlockJoinParentQParser.java @@ -164,13 +164,11 @@ private void setAppropriateChildrenListingTransformer( noChildTransformer = false; } } - } else { - if ((originalTransformer instanceof ChildDocTransformer)) { + } else if ((originalTransformer instanceof ChildDocTransformer)) { ChildDocTransformer childTransformer = (ChildDocTransformer) originalTransformer; if (childTransformer.getChildDocSet() == null) { childTransformer.setChildDocSet(request.getSearcher().getDocSet(knnOnVectorField)); } - } } } From 6a6aed8ca0699b7390831b9da3d9ceb62af96bac Mon Sep 17 00:00:00 2001 From: Alessandro Benedetti Date: Wed, 16 Apr 2025 17:52:40 +0100 Subject: [PATCH 23/43] Update solr/core/src/java/org/apache/solr/search/join/BlockJoinParentQParser.java Co-authored-by: Christine Poerschke --- .../org/apache/solr/search/join/BlockJoinParentQParser.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/solr/core/src/java/org/apache/solr/search/join/BlockJoinParentQParser.java b/solr/core/src/java/org/apache/solr/search/join/BlockJoinParentQParser.java index f882408da5d5..1b43fcf79399 100644 --- a/solr/core/src/java/org/apache/solr/search/join/BlockJoinParentQParser.java +++ b/solr/core/src/java/org/apache/solr/search/join/BlockJoinParentQParser.java @@ -196,7 +196,7 @@ private Query getChildrenFilter( Query childrenKnnPreFilter, BooleanQuery parentsFilter, BitSetProducer allParentsBitSet) { Query childrenFilter = childrenKnnPreFilter; - if (parentsFilter.clauses().size() > 0) { + if (!parentsFilter.clauses().isEmpty()) { Query acceptedChildrenBasedOnParentsFilter = new ToChildBlockJoinQuery(parentsFilter, allParentsBitSet); // no scoring happens here BooleanQuery.Builder acceptedChildrenBuilder = createBuilder(); From db232ec61f418334e9b51d409e6150e08e88baa9 Mon Sep 17 00:00:00 2001 From: Alessandro Benedetti Date: Sat, 6 Dec 2025 02:10:47 +0100 Subject: [PATCH 24/43] new approach following feedback --- .../transform/ChildDocTransformer.java | 12 +- .../org/apache/solr/search/QueryLimits.java | 4 - .../search/join/BlockJoinChildQParser.java | 4 +- .../search/join/BlockJoinParentQParser.java | 144 +----------- .../solr/search/join/FiltersQParser.java | 2 +- .../apache/solr/search/neural/KnnQParser.java | 63 +++++ .../collection1/conf/schema-densevector.xml | 7 +- .../BlockJoinNestedVectorsQParserTest.java | 2 +- .../search/neural/KnnQParserChildTest.java | 217 ++++++++++++++++++ 9 files changed, 295 insertions(+), 160 deletions(-) create mode 100644 solr/core/src/test/org/apache/solr/search/neural/KnnQParserChildTest.java diff --git a/solr/core/src/java/org/apache/solr/response/transform/ChildDocTransformer.java b/solr/core/src/java/org/apache/solr/response/transform/ChildDocTransformer.java index 64dc9c461236..204aa7a6190c 100644 --- a/solr/core/src/java/org/apache/solr/response/transform/ChildDocTransformer.java +++ b/solr/core/src/java/org/apache/solr/response/transform/ChildDocTransformer.java @@ -52,14 +52,14 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -public class ChildDocTransformer extends DocTransformer { +class ChildDocTransformer extends DocTransformer { private static final Logger log = LoggerFactory.getLogger(MethodHandles.lookup().lookupClass()); private static final String ANON_CHILD_KEY = "_childDocuments_"; private final String name; private final BitSetProducer parentsFilter; // if null; resolve parent via uniqueKey instead - private DocSet childDocSet; + private final DocSet childDocSet; private final int limit; private final boolean isNestedSchema; private final SolrReturnFields childReturnFields; @@ -97,14 +97,6 @@ public String[] getExtraRequestFields() { return extraRequestedFields; } - public DocSet getChildDocSet() { - return childDocSet; - } - - public void setChildDocSet(DocSet childDocSet) { - this.childDocSet = childDocSet; - } - private int getPrevRootGivenFilter(LeafReaderContext leafReaderContext, int segRootId) throws IOException { final BitSet segParentsBitSet = parentsFilter.getBitSet(leafReaderContext); diff --git a/solr/core/src/java/org/apache/solr/search/QueryLimits.java b/solr/core/src/java/org/apache/solr/search/QueryLimits.java index 9704c1ec4e9e..35ea4533e93c 100644 --- a/solr/core/src/java/org/apache/solr/search/QueryLimits.java +++ b/solr/core/src/java/org/apache/solr/search/QueryLimits.java @@ -247,8 +247,4 @@ public static QueryLimits getCurrentLimits() { final SolrRequestInfo info = SolrRequestInfo.getRequestInfo(); return info != null ? info.getLimits() : NONE; } - - public SolrQueryResponse getRsp() { - return rsp; - } } diff --git a/solr/core/src/java/org/apache/solr/search/join/BlockJoinChildQParser.java b/solr/core/src/java/org/apache/solr/search/join/BlockJoinChildQParser.java index c49d296b8e76..bb6c80db07a8 100644 --- a/solr/core/src/java/org/apache/solr/search/join/BlockJoinChildQParser.java +++ b/solr/core/src/java/org/apache/solr/search/join/BlockJoinChildQParser.java @@ -33,8 +33,8 @@ public BlockJoinChildQParser( } @Override - protected Query createQuery(Query allParents, BooleanQuery parentQuery, String scoreMode) { - return new ToChildBlockJoinQuery(parentQuery, getBitSetProducer(allParents)); + protected Query createQuery(Query parentListQuery, Query query, String scoreMode) { + return new ToChildBlockJoinQuery(query, getBitSetProducer(parentListQuery)); } @Override diff --git a/solr/core/src/java/org/apache/solr/search/join/BlockJoinParentQParser.java b/solr/core/src/java/org/apache/solr/search/join/BlockJoinParentQParser.java index 557979963321..1d73bbd78aa7 100644 --- a/solr/core/src/java/org/apache/solr/search/join/BlockJoinParentQParser.java +++ b/solr/core/src/java/org/apache/solr/search/join/BlockJoinParentQParser.java @@ -18,41 +18,26 @@ import java.io.IOException; import java.io.UncheckedIOException; -import java.util.List; import java.util.Objects; import org.apache.lucene.index.LeafReaderContext; -import org.apache.lucene.search.BooleanClause; -import org.apache.lucene.search.BooleanQuery; import org.apache.lucene.search.ConstantScoreScorer; import org.apache.lucene.search.ConstantScoreWeight; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.IndexSearcher; -import org.apache.lucene.search.KnnByteVectorQuery; -import org.apache.lucene.search.KnnFloatVectorQuery; import org.apache.lucene.search.Query; import org.apache.lucene.search.QueryVisitor; import org.apache.lucene.search.ScorerSupplier; import org.apache.lucene.search.Weight; import org.apache.lucene.search.join.BitSetProducer; -import org.apache.lucene.search.join.DiversifyingChildrenByteKnnVectorQuery; -import org.apache.lucene.search.join.DiversifyingChildrenFloatKnnVectorQuery; import org.apache.lucene.search.join.QueryBitSetProducer; import org.apache.lucene.search.join.ScoreMode; -import org.apache.lucene.search.join.ToChildBlockJoinQuery; import org.apache.lucene.search.join.ToParentBlockJoinQuery; import org.apache.lucene.util.BitSet; import org.apache.lucene.util.BitSetIterator; -import org.apache.solr.common.SolrException; import org.apache.solr.common.params.SolrParams; import org.apache.solr.request.SolrQueryRequest; -import org.apache.solr.response.transform.ChildDocTransformer; -import org.apache.solr.response.transform.DocTransformer; -import org.apache.solr.response.transform.DocTransformers; import org.apache.solr.search.ExtendedQueryBase; import org.apache.solr.search.QParser; -import org.apache.solr.search.QueryLimits; -import org.apache.solr.search.QueryUtils; -import org.apache.solr.search.ReturnFields; import org.apache.solr.search.SolrCache; import org.apache.solr.search.SyntaxError; import org.apache.solr.util.SolrDefaultScorerSupplier; @@ -83,7 +68,7 @@ protected Query parseParentFilter() throws SyntaxError { } @Override - protected Query wrapSubordinateClause(BooleanQuery subordinate) throws SyntaxError { + protected Query wrapSubordinateClause(Query subordinate) throws SyntaxError { String scoreMode = localParams.get("score", ScoreMode.None.name()); Query parentQ = parseParentFilter(); return createQuery(parentQ, subordinate, scoreMode); @@ -94,131 +79,10 @@ protected Query noClausesQuery() throws SyntaxError { return new BitSetProducerQuery(getBitSetProducer(parseParentFilter())); } - protected Query createQuery(final Query allParents, BooleanQuery childrenQuery, String scoreMode) + protected Query createQuery(final Query parentList, Query query, String scoreMode) throws SyntaxError { - try { - List childrenClauses = childrenQuery.clauses(); - KnnByteVectorQuery knnByteChildrenQuery = getBytetKnnQuery(childrenClauses); - BitSetProducer allParentsBitSet = getBitSetProducer(allParents); - BooleanQuery parentsFilter = getParentsFilter(); - - if (knnByteChildrenQuery != null) { - String vectorField = knnByteChildrenQuery.getField(); - byte[] queryVector = knnByteChildrenQuery.getTargetCopy(); - int topK = knnByteChildrenQuery.getK(); - - Query acceptedChildren = - getChildrenFilter(knnByteChildrenQuery.getFilter(), parentsFilter, allParentsBitSet); - - Query knnChildren = - new DiversifyingChildrenByteKnnVectorQuery( - vectorField, queryVector, acceptedChildren, topK, allParentsBitSet); - knnChildren = knnChildren.rewrite(req.getSearcher()); - this.setAppropriateChildrenListingTransformer(req, knnChildren); - - return new ToParentBlockJoinQuery( - knnChildren, allParentsBitSet, ScoreModeParser.parse(scoreMode)); - } else { - KnnFloatVectorQuery knnFLoatChildrenQuery = getFloatKnnQuery(childrenClauses); - if (knnFLoatChildrenQuery != null) { - String vectorField = knnFLoatChildrenQuery.getField(); - float[] queryVector = knnFLoatChildrenQuery.getTargetCopy(); - int topK = knnFLoatChildrenQuery.getK(); - - Query childrenFilter = - getChildrenFilter(knnFLoatChildrenQuery.getFilter(), parentsFilter, allParentsBitSet); - - Query knnChildren = - new DiversifyingChildrenFloatKnnVectorQuery( - vectorField, queryVector, childrenFilter, topK, allParentsBitSet); - knnChildren = knnChildren.rewrite(req.getSearcher()); - this.setAppropriateChildrenListingTransformer(req, knnChildren); - - return new ToParentBlockJoinQuery( - knnChildren, allParentsBitSet, ScoreModeParser.parse(scoreMode)); - } else { - return new AllParentsAware( - childrenQuery, allParentsBitSet, ScoreModeParser.parse(scoreMode), allParents); - } - } - } catch (IOException e) { - throw new SolrException(SolrException.ErrorCode.SERVER_ERROR, e); - } - } - - private void setAppropriateChildrenListingTransformer( - SolrQueryRequest request, Query knnOnVectorField) throws IOException { - QueryLimits currentLimits = QueryLimits.getCurrentLimits(); - ReturnFields returnFields = currentLimits.getRsp().getReturnFields(); - DocTransformer originalTransformer = returnFields.getTransformer(); - - if (originalTransformer instanceof DocTransformers) { - DocTransformers transformers = (DocTransformers) originalTransformer; - boolean noChildTransformer = true; - for (int i = 0; i < transformers.size() && noChildTransformer; i++) { - DocTransformer t = transformers.getTransformer(i); - if (t instanceof ChildDocTransformer) { - ChildDocTransformer childTransformer = (ChildDocTransformer) t; - if (childTransformer.getChildDocSet() == null) { - childTransformer.setChildDocSet(request.getSearcher().getDocSet(knnOnVectorField)); - } - noChildTransformer = false; - } - } - } else if ((originalTransformer instanceof ChildDocTransformer)) { - ChildDocTransformer childTransformer = (ChildDocTransformer) originalTransformer; - if (childTransformer.getChildDocSet() == null) { - childTransformer.setChildDocSet(request.getSearcher().getDocSet(knnOnVectorField)); - } - } - } - - private KnnFloatVectorQuery getFloatKnnQuery(List childrenClauses) { - if (childrenClauses.size() == 1) { - Query query = childrenClauses.get(0).getQuery(); - if (query instanceof KnnFloatVectorQuery) { - return (KnnFloatVectorQuery) query; - } - } - return null; - } - - private KnnByteVectorQuery getBytetKnnQuery(List childrenClauses) { - if (childrenClauses.size() == 1) { - Query query = childrenClauses.get(0).getQuery(); - if (query instanceof KnnByteVectorQuery) { - return (KnnByteVectorQuery) query; - } - } - return null; - } - - private Query getChildrenFilter( - Query childrenKnnPreFilter, BooleanQuery parentsFilter, BitSetProducer allParentsBitSet) { - Query childrenFilter = childrenKnnPreFilter; - - if (!parentsFilter.clauses().isEmpty()) { - Query acceptedChildrenBasedOnParentsFilter = - new ToChildBlockJoinQuery(parentsFilter, allParentsBitSet); // no scoring happens here - BooleanQuery.Builder acceptedChildrenBuilder = createBuilder(); - if (childrenFilter != null) { - acceptedChildrenBuilder.add(childrenFilter, BooleanClause.Occur.FILTER); - } - acceptedChildrenBuilder.add(acceptedChildrenBasedOnParentsFilter, BooleanClause.Occur.FILTER); - - childrenFilter = acceptedChildrenBuilder.build(); - } - return childrenFilter; - } - - private BooleanQuery getParentsFilter() throws SyntaxError { - List parentFilterQueries = QueryUtils.parseFilterQueries(req); - BooleanQuery.Builder acceptedParentsBuilder = createBuilder(); - for (Query filter : parentFilterQueries) { - acceptedParentsBuilder.add(filter, BooleanClause.Occur.FILTER); - } - BooleanQuery acceptedParents = acceptedParentsBuilder.build(); - return acceptedParents; + return new AllParentsAware( + query, getBitSetProducer(parentList), ScoreModeParser.parse(scoreMode), parentList); } BitSetProducer getBitSetProducer(Query query) { diff --git a/solr/core/src/java/org/apache/solr/search/join/FiltersQParser.java b/solr/core/src/java/org/apache/solr/search/join/FiltersQParser.java index 45036ebffece..05c705aa1ce1 100644 --- a/solr/core/src/java/org/apache/solr/search/join/FiltersQParser.java +++ b/solr/core/src/java/org/apache/solr/search/join/FiltersQParser.java @@ -73,7 +73,7 @@ protected Query unwrapQuery(Query query, BooleanClause.Occur occur) { return query; } - protected Query wrapSubordinateClause(BooleanQuery subordinate) throws SyntaxError { + protected Query wrapSubordinateClause(Query subordinate) throws SyntaxError { return subordinate; } diff --git a/solr/core/src/java/org/apache/solr/search/neural/KnnQParser.java b/solr/core/src/java/org/apache/solr/search/neural/KnnQParser.java index db355e0b84e6..5f22dcd9842b 100644 --- a/solr/core/src/java/org/apache/solr/search/neural/KnnQParser.java +++ b/solr/core/src/java/org/apache/solr/search/neural/KnnQParser.java @@ -17,14 +17,24 @@ package org.apache.solr.search.neural; import java.util.Optional; +import org.apache.lucene.index.VectorEncoding; +import org.apache.lucene.search.BooleanClause; +import org.apache.lucene.search.BooleanQuery; import org.apache.lucene.search.Query; +import org.apache.lucene.search.join.BitSetProducer; +import org.apache.lucene.search.join.DiversifyingChildrenByteKnnVectorQuery; +import org.apache.lucene.search.join.DiversifyingChildrenFloatKnnVectorQuery; +import org.apache.lucene.search.join.ToChildBlockJoinQuery; import org.apache.solr.common.SolrException; import org.apache.solr.common.params.SolrParams; import org.apache.solr.request.SolrQueryRequest; import org.apache.solr.schema.DenseVectorField; import org.apache.solr.schema.SchemaField; import org.apache.solr.search.QParser; +import org.apache.solr.search.QueryUtils; import org.apache.solr.search.SyntaxError; +import org.apache.solr.search.join.BlockJoinParentQParser; +import org.apache.solr.util.vector.DenseVectorParser; public class KnnQParser extends AbstractVectorQParserBase { @@ -41,6 +51,9 @@ public class KnnQParser extends AbstractVectorQParserBase { protected static final String SATURATION_THRESHOLD = "saturationThreshold"; protected static final String PATIENCE = "patience"; + public static final String CHILDREN_OF = "childrenOf"; + public static final String ALL_PARENTS = "allParents"; + public KnnQParser(String qstr, SolrParams localParams, SolrParams params, SolrQueryRequest req) { super(qstr, localParams, params, req); } @@ -104,12 +117,36 @@ protected Query getSeedQuery() throws SolrException, SyntaxError { @Override public Query parse() throws SyntaxError { + final String vectorField = getFieldName(); final SchemaField schemaField = req.getCore().getLatestSchema().getField(getFieldName()); final DenseVectorField denseVectorType = getCheckedFieldType(schemaField); final String vectorToSearch = getVectorToSearch(); final int topK = localParams.getInt(TOP_K, DEFAULT_TOP_K); final Integer filteredSearchThreshold = localParams.getInt(FILTERED_SEARCH_THRESHOLD); + // check for parent diversification logic... + final String parentsFilterQuery = localParams.get(CHILDREN_OF); + if (null != parentsFilterQuery) { + final String allParentsQuery = localParams.get(ALL_PARENTS); + final BitSetProducer allParentsBitSet = BlockJoinParentQParser.getCachedBitSetProducer(req, subQuery(allParentsQuery, null).getQuery()); + final BooleanQuery acceptedParents = getParentsFilter(parentsFilterQuery); + final DenseVectorParser vectorBuilder = + denseVectorType.getVectorBuilder(vectorToSearch, DenseVectorParser.BuilderPhase.QUERY); + final VectorEncoding vectorEncoding = denseVectorType.getVectorEncoding(); + Query acceptedChildren = getChildrenFilter(getFilterQuery(), acceptedParents, allParentsBitSet); + switch (vectorEncoding) { + case FLOAT32: + return new DiversifyingChildrenFloatKnnVectorQuery( + vectorField, vectorBuilder.getFloatVector(), acceptedChildren, topK, allParentsBitSet); + case BYTE: + return new DiversifyingChildrenByteKnnVectorQuery( + vectorField, vectorBuilder.getByteVector(), acceptedChildren, topK, allParentsBitSet); + default: + throw new SolrException( + SolrException.ErrorCode.SERVER_ERROR, + "Unexpected encoding. Vector Encoding: " + vectorEncoding); + }} + return denseVectorType.getKnnVectorQuery( schemaField.getName(), vectorToSearch, @@ -119,4 +156,30 @@ public Query parse() throws SyntaxError { getEarlyTerminationParams(), filteredSearchThreshold); } + + private BooleanQuery getParentsFilter(String parentsFilterQuery) throws SyntaxError { + final Query parentsFilter = subQuery(parentsFilterQuery, null).getQuery(); + BooleanQuery.Builder acceptedParentsBuilder = new BooleanQuery.Builder(); + acceptedParentsBuilder.add(parentsFilter, BooleanClause.Occur.FILTER); + BooleanQuery acceptedParents = acceptedParentsBuilder.build(); + return acceptedParents; + } + + private Query getChildrenFilter( + Query childrenKnnPreFilter, BooleanQuery parentsFilter, BitSetProducer allParentsBitSet) { + Query childrenFilter = childrenKnnPreFilter; + + if (!parentsFilter.clauses().isEmpty()) { + Query acceptedChildrenBasedOnParentsFilter = + new ToChildBlockJoinQuery(parentsFilter, allParentsBitSet); // no scoring happens here + BooleanQuery.Builder acceptedChildrenBuilder = new BooleanQuery.Builder(); + if (childrenFilter != null) { + acceptedChildrenBuilder.add(childrenFilter, BooleanClause.Occur.FILTER); + } + acceptedChildrenBuilder.add(acceptedChildrenBasedOnParentsFilter, BooleanClause.Occur.FILTER); + + childrenFilter = acceptedChildrenBuilder.build(); + } + return childrenFilter; + } } diff --git a/solr/core/src/test-files/solr/collection1/conf/schema-densevector.xml b/solr/core/src/test-files/solr/collection1/conf/schema-densevector.xml index 42db078a6e20..f3d663a40663 100644 --- a/solr/core/src/test-files/solr/collection1/conf/schema-densevector.xml +++ b/solr/core/src/test-files/solr/collection1/conf/schema-densevector.xml @@ -19,13 +19,14 @@ - + - + + @@ -34,6 +35,8 @@ + + diff --git a/solr/core/src/test/org/apache/solr/search/join/BlockJoinNestedVectorsQParserTest.java b/solr/core/src/test/org/apache/solr/search/join/BlockJoinNestedVectorsQParserTest.java index 3d198a1cded2..3bf9bb30fd6a 100644 --- a/solr/core/src/test/org/apache/solr/search/join/BlockJoinNestedVectorsQParserTest.java +++ b/solr/core/src/test/org/apache/solr/search/join/BlockJoinNestedVectorsQParserTest.java @@ -220,7 +220,7 @@ public void parentRetrievalByte_knnChildren_shouldReturnKnnParents() { // "fq", "parent_s:(a c)", "q", "{!parent which=$allParents score=max v=$children.q}", "fl", "id,score", - "children.q", "{!knn f=vector_byte topK=3}" + BYTE_QUERY_VECTOR, + "children.q", "{!knn f=vector_byte topK=3 childrenOf=$allParents' allParents=$allParents}" + BYTE_QUERY_VECTOR, "allParents", "parent_s:[* TO *]"), "//*[@numFound='3']", "//result/doc[1]/str[@name='id'][.='10']", diff --git a/solr/core/src/test/org/apache/solr/search/neural/KnnQParserChildTest.java b/solr/core/src/test/org/apache/solr/search/neural/KnnQParserChildTest.java new file mode 100644 index 000000000000..5c6bcf7d2918 --- /dev/null +++ b/solr/core/src/test/org/apache/solr/search/neural/KnnQParserChildTest.java @@ -0,0 +1,217 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.solr.search.neural; + +import java.util.List; +import java.util.Random; +import java.util.stream.Collectors; +import org.apache.lucene.tests.util.TestUtil; +import org.apache.solr.SolrTestCaseJ4; +import org.apache.solr.common.SolrInputDocument; +import org.apache.solr.common.SolrInputField; +import org.junit.BeforeClass; + +public class KnnQParserChildTest extends SolrTestCaseJ4 { + + private static final int MAX_TOP_K = 100; + private static final int MIN_NUM_PARENTS = MAX_TOP_K * 10; + private static final int MIN_NUM_KIDS_PER_PARENT = 5; + + @BeforeClass + public static void prepareIndex() throws Exception { + /* vectorDimension="4" similarityFunction="cosine" */ + initCore("solrconfig_codec.xml", "schema-densevector.xml"); + + final int numParents = atLeast(MIN_NUM_PARENTS); + for (int p = 0; p < numParents; p++) { + final String parentId = "parent-" + p; + final SolrInputDocument parent = doc(f("id", parentId), f("type_s", "PARENT")); + final int numKids = atLeast(MIN_NUM_KIDS_PER_PARENT); + for (int k = 0; k < numKids; k++) { + final String kidId = parentId + "-kid-" + k; + final SolrInputDocument kid = + doc(f("id", kidId), f("parent_s", parentId), f("type_s", "KID")); + + kid.addField("vector", randomFloatVector(random())); + kid.addField("vector_byte_encoding", randomByteVector(random())); + + parent.addChildDocument(kid); + } + assertU(adoc(parent)); + if (rarely(random())) { + assertU(commit()); + } + } + assertU(commit()); + } + + /** Direct usage knn w/childOf to confim that a diverse set of child docs are returned */ + public void testDiverseKids() { + final int numIters = atLeast(100); + for (int iter = 0; iter < numIters; iter++) { + final String topK = "" + TestUtil.nextInt(random(), 2, MAX_TOP_K); + + // check floats... + assertQ( + req( + "q", + "{!knn f=vector topK=$k childrenOf='type_s:PARENT' allParents='type_s:PARENT'}" + + vecStr(randomFloatVector(random())), + "indent", + "true", + "fl", + "id,parent_s", + "_iter", + "" + iter, + "k", + topK, + "rows", + topK), + "*[count(//doc/str[@name='parent_s' and not(following::str[@name='parent_s']/text() = text())])=" + + topK + + "]", + "*[count(//doc/str[@name='id' and not(following::str[@name='id']/text() = text())])=" + + topK + + "]"); + + // check bytes... + assertQ( + req( + "q", + "{!knn f=vector_byte_encoding topK=$k childrenOf='type_s:PARENT' allParents='type_s:PARENT'}" + + vecStr(randomByteVector(random())), + "indent", + "true", + "fl", + "id,parent_s", + "_iter", + "" + iter, + "k", + topK, + "rows", + topK), + "*[count(//doc/str[@name='parent_s' and not(following::str[@name='parent_s']/text() = text())])=" + + topK + + "]", + "*[count(//doc/str[@name='id' and not(following::str[@name='id']/text() = text())])=" + + topK + + "]"); + } + } + + /** Sanity check that knn w/diversification works as expected when wrapped in parent query */ + public void testParentsOfDiverseKids() { + + final int numIters = atLeast(100); + for (int iter = 0; iter < numIters; iter++) { + final String topK = "" + TestUtil.nextInt(random(), 2, MAX_TOP_K); + + // check floats... + assertQ( + req( + "q", + "{!parent which='type_s:PARENT' score=max v=$knn}", + "knn", + "{!knn f=vector topK=$k childrenOf='type_s:PARENT' allParents='type_s:PARENT'}" + + vecStr(randomFloatVector(random())), + "indent", + "true", + "fl", + "id", + "_iter", + "" + iter, + "k", + topK, + "rows", + topK), + "*[count(//doc/str[@name='id' and not(following::str[@name='id']/text() = text())])=" + + topK + + "]"); + + // check bytes... + assertQ( + req( + "q", + "{!parent which='type_s:PARENT' score=max v=$knn}", + "knn", + "{!knn f=vector_byte_encoding topK=$k childrenOf='type_s:PARENT' allParents='type_s:PARENT'}" + + vecStr(randomByteVector(random())), + "indent", + "true", + "fl", + "id", + "_iter", + "" + iter, + "k", + topK, + "rows", + topK), + "*[count(//doc/str[@name='id' and not(following::str[@name='id']/text() = text())])=" + + topK + + "]"); + } + } + + /** Format a vector as a string for use in queries */ + protected static String vecStr(final List vector) { + return "[" + vector.stream().map(Object::toString).collect(Collectors.joining(",")) + "]"; + } + + /** Random vector of size 4 */ + protected static List randomFloatVector(Random r) { + // we don't want nextFloat() because it's bound by -1:1 + // but we also don't want NaN, or +/- Infinity (so we don't mess with intBitsToFloat) + // we could be fancier to get *all* the possible "real" floats, but this is good enough... + + // Note: bias first vec entry to ensure we never have an all zero vector (invalid w/cosine sim + // used in configs) + return List.of( + 1F + (r.nextFloat() * 10000F), + r.nextFloat() * 10000F, + r.nextFloat() * 10000F, + r.nextFloat() * 10000F); + } + + /** Random vector of size 4 */ + protected static List randomByteVector(Random r) { + final byte[] byteBuff = new byte[4]; + r.nextBytes(byteBuff); + // Note: bias first vec entry to ensure we never have an all zero vector (invalid w/cosine sim + // used in configs) + return List.of( + (Byte) (byte) (byteBuff[0] + 1), + (Byte) byteBuff[1], + (Byte) byteBuff[1], + (Byte) byteBuff[1]); + } + + /** Convenience method for building a SolrInputDocument */ + protected static SolrInputDocument doc(SolrInputField... fields) { + SolrInputDocument d = new SolrInputDocument(); + for (SolrInputField f : fields) { + d.put(f.getName(), f); + } + return d; + } + + /** Convenience method for building a SolrInputField */ + protected static SolrInputField f(String name, Object value) { + final SolrInputField f = new SolrInputField(name); + f.setValue(value); + return f; + } +} From 8c78b23dc35fcd0ce93ac9b8a1e48049dd1e834f Mon Sep 17 00:00:00 2001 From: Alessandro Benedetti Date: Tue, 9 Dec 2025 12:59:56 +0100 Subject: [PATCH 25/43] tests fixed for the new approach --- .../BlockJoinNestedVectorsQParserTest.java | 138 ++++++------------ 1 file changed, 45 insertions(+), 93 deletions(-) diff --git a/solr/core/src/test/org/apache/solr/search/join/BlockJoinNestedVectorsQParserTest.java b/solr/core/src/test/org/apache/solr/search/join/BlockJoinNestedVectorsQParserTest.java index 3bf9bb30fd6a..56789d05844f 100644 --- a/solr/core/src/test/org/apache/solr/search/join/BlockJoinNestedVectorsQParserTest.java +++ b/solr/core/src/test/org/apache/solr/search/join/BlockJoinNestedVectorsQParserTest.java @@ -172,10 +172,9 @@ public void childrenRetrievalByte_filteringByParentMetadata_shouldReturnKnnChild public void parentRetrievalFloat_knnChildren_shouldReturnKnnParents() { assertQ( req( - // "fq", "parent_s:(a c)", "q", "{!parent which=$allParents score=max v=$children.q}", "fl", "id,score", - "children.q", "{!knn f=vector topK=3}" + FLOAT_QUERY_VECTOR, + "children.q", "{!knn f=vector topK=3 childrenOf=$allParents allParents=$allParents}" + FLOAT_QUERY_VECTOR, "allParents", "parent_s:[* TO *]"), "//*[@numFound='3']", "//result/doc[1]/str[@name='id'][.='10']", @@ -184,14 +183,26 @@ public void parentRetrievalFloat_knnChildren_shouldReturnKnnParents() { } @Test - public void parentRetrievalFloat_knnChildrenWithParentFilter_shouldReturnKnnParents() { + public void parentRetrievalFloat_knnChildrenWithNoDiversifying_shouldReturnOneParent() { assertQ( req( - "fq", "parent_s:(a c)", "q", "{!parent which=$allParents score=max v=$children.q}", "fl", "id,score", "children.q", "{!knn f=vector topK=3}" + FLOAT_QUERY_VECTOR, "allParents", "parent_s:[* TO *]"), + "//*[@numFound='1']", + "//result/doc[1]/str[@name='id'][.='10']"); + } + + @Test + public void parentRetrievalFloat_knnChildrenWithParentFilter_shouldReturnKnnParents() { + assertQ( + req( + "q", "{!parent which=$allParents score=max v=$children.q}", + "fl", "id,score", + "children.q", "{!knn f=vector topK=3 childrenOf=$someParents allParents=$allParents}" + FLOAT_QUERY_VECTOR, + "allParents", "parent_s:[* TO *]", + "someParents", "parent_s:(a c)"), "//*[@numFound='3']", "//result/doc[1]/str[@name='id'][.='8']", "//result/doc[2]/str[@name='id'][.='6']", @@ -203,11 +214,11 @@ public void parentRetrievalFloat_knnChildrenWithParentFilter_shouldReturnKnnPare parentRetrievalFloat_knnChildrenWithParentFilterAndChildrenFilter_shouldReturnKnnParents() { assertQ( req( - "fq", "parent_s:(a c)", "q", "{!parent which=$allParents score=max v=$children.q}", "fl", "id,score", - "children.q", "{!knn f=vector topK=3 preFilter=child_s:m}" + FLOAT_QUERY_VECTOR, - "allParents", "parent_s:[* TO *]"), + "children.q", "{!knn f=vector topK=3 preFilter=child_s:m childrenOf=$someParents allParents=$allParents}" + FLOAT_QUERY_VECTOR, + "allParents", "parent_s:[* TO *]", + "someParents", "parent_s:(a c)"), "//*[@numFound='2']", "//result/doc[1]/str[@name='id'][.='8']", "//result/doc[2]/str[@name='id'][.='2']"); @@ -217,10 +228,9 @@ public void parentRetrievalFloat_knnChildrenWithParentFilter_shouldReturnKnnPare public void parentRetrievalByte_knnChildren_shouldReturnKnnParents() { assertQ( req( - // "fq", "parent_s:(a c)", "q", "{!parent which=$allParents score=max v=$children.q}", "fl", "id,score", - "children.q", "{!knn f=vector_byte topK=3 childrenOf=$allParents' allParents=$allParents}" + BYTE_QUERY_VECTOR, + "children.q", "{!knn f=vector_byte topK=3 childrenOf=$allParents allParents=$allParents}" + BYTE_QUERY_VECTOR, "allParents", "parent_s:[* TO *]"), "//*[@numFound='3']", "//result/doc[1]/str[@name='id'][.='10']", @@ -232,11 +242,11 @@ public void parentRetrievalByte_knnChildren_shouldReturnKnnParents() { public void parentRetrievalByte_knnChildrenWithParentFilter_shouldReturnKnnParents() { assertQ( req( - "fq", "parent_s:(a c)", "q", "{!parent which=$allParents score=max v=$children.q}", "fl", "id,score", - "children.q", "{!knn f=vector_byte topK=3}" + BYTE_QUERY_VECTOR, - "allParents", "parent_s:[* TO *]"), + "children.q", "{!knn f=vector_byte topK=3 childrenOf=$someParents allParents=$allParents}" + BYTE_QUERY_VECTOR, + "allParents", "parent_s:[* TO *]", + "someParents", "parent_s:(a c)"), "//*[@numFound='3']", "//result/doc[1]/str[@name='id'][.='8']", "//result/doc[2]/str[@name='id'][.='6']", @@ -248,11 +258,11 @@ public void parentRetrievalByte_knnChildrenWithParentFilter_shouldReturnKnnParen parentRetrievalByte_knnChildrenWithParentFilterAndChildrenFilter_shouldReturnKnnParents() { assertQ( req( - "fq", "parent_s:(a c)", "q", "{!parent which=$allParents score=max v=$children.q}", "fl", "id,score", - "children.q", "{!knn f=vector_byte topK=3 preFilter=child_s:m}" + BYTE_QUERY_VECTOR, - "allParents", "parent_s:[* TO *]"), + "children.q", "{!knn f=vector_byte topK=3 preFilter=child_s:m childrenOf=$someParents allParents=$allParents}" + BYTE_QUERY_VECTOR, + "allParents", "parent_s:[* TO *]", + "someParents", "parent_s:(a c)"), "//*[@numFound='2']", "//result/doc[1]/str[@name='id'][.='8']", "//result/doc[2]/str[@name='id'][.='2']"); @@ -263,12 +273,11 @@ public void parentRetrievalByte_knnChildrenWithParentFilter_shouldReturnKnnParen parentRetrievalFloat_topKWithChildTransformerWithFilter_shouldUseOriginalChildTransformerFilter() { assertQ( req( - "fq", "parent_s:(b c)", "q", "{!parent which=$allParents score=max v=$children.q}", - "fl", "id,score,vectors,vector,[child limit=2 fl=vector childFilter=$all_children]", - "children.q", "{!knn f=vector topK=3}" + FLOAT_QUERY_VECTOR, + "fl", "id,score,vectors,vector,[child limit=2 fl=vector]", + "children.q", "{!knn f=vector topK=3 childrenOf=$someParents allParents=$allParents}" + FLOAT_QUERY_VECTOR, "allParents", "parent_s:[* TO *]", - "all_children", "child_s:[* TO *]"), + "someParents", "parent_s:(a c)"), "//result[@numFound='3']", "//result/doc[1]/str[@name='id'][.='8']", "//result/doc[1]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[1][.='10.0']", @@ -279,12 +288,12 @@ public void parentRetrievalByte_knnChildrenWithParentFilter_shouldReturnKnnParen "//result/doc[1]/arr[@name='vectors'][1]/doc[2]/arr[@name='vector']/float[2][.='1.0']", "//result/doc[1]/arr[@name='vectors'][1]/doc[2]/arr[@name='vector']/float[3][.='1.0']", "//result/doc[1]/arr[@name='vectors'][1]/doc[2]/arr[@name='vector']/float[4][.='1.0']", - "//result/doc[2]/str[@name='id'][.='7']", - "//result/doc[2]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[1][.='13.0']", + "//result/doc[2]/str[@name='id'][.='6']", + "//result/doc[2]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[1][.='16.0']", "//result/doc[2]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[2][.='1.0']", "//result/doc[2]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[3][.='1.0']", "//result/doc[2]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[4][.='1.0']", - "//result/doc[2]/arr[@name='vectors'][1]/doc[2]/arr[@name='vector']/float[1][.='12.0']", + "//result/doc[2]/arr[@name='vectors'][1]/doc[2]/arr[@name='vector']/float[1][.='15.0']", "//result/doc[2]/arr[@name='vectors'][1]/doc[2]/arr[@name='vector']/float[2][.='1.0']", "//result/doc[2]/arr[@name='vectors'][1]/doc[2]/arr[@name='vector']/float[3][.='1.0']", "//result/doc[2]/arr[@name='vectors'][1]/doc[2]/arr[@name='vector']/float[4][.='1.0']", @@ -301,14 +310,14 @@ public void parentRetrievalByte_knnChildrenWithParentFilter_shouldReturnKnnParen @Test public void - parentRetrievalFloat_topKWithChildTransformerWithNoFilter_shouldUseBestChildrenVectorTransformerFilter() { + parentRetrievalFloat_topKWithChildTransformerWithFilter_shouldReturnBestChild() { assertQ( req( - "fq", "parent_s:(b c)", "q", "{!parent which=$allParents score=max v=$children.q}", - "fl", "id,score,vectors,vector,[child fl=vector]", - "children.q", "{!knn f=vector topK=3}" + FLOAT_QUERY_VECTOR, - "allParents", "parent_s:[* TO *]"), + "fl", "id,score,vectors,vector,[child fl=vector childFilter=$children.q]", + "children.q", "{!knn f=vector topK=3 childrenOf=$someParents allParents=$allParents}" + FLOAT_QUERY_VECTOR, + "allParents", "parent_s:[* TO *]", + "someParents", "parent_s:(b c)"), "//result[@numFound='3']", "//result/doc[1]/str[@name='id'][.='8']", "//result/doc[1]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[1][.='8.0']", @@ -327,46 +336,18 @@ public void parentRetrievalByte_knnChildrenWithParentFilter_shouldReturnKnnParen "//result/doc[3]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[4][.='1.0']"); } - @Test - public void - parentRetrievalFloat_topKWithOnlyChildTransformerWithNoFilter_shouldUseBestChildrenVectorTransformerFilter() { - assertQ( - req( - "fq", "parent_s:(b c)", - "q", "{!parent which=$allParents score=max v=$children.q}", - "fl", "id,vectors,vector,[child fl=vector]", - "children.q", "{!knn f=vector topK=3}" + FLOAT_QUERY_VECTOR, - "allParents", "parent_s:[* TO *]"), - "//result[@numFound='3']", - "//result/doc[1]/str[@name='id'][.='8']", - "//result/doc[1]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[1][.='8.0']", - "//result/doc[1]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[2][.='1.0']", - "//result/doc[1]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[3][.='1.0']", - "//result/doc[1]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[4][.='1.0']", - "//result/doc[2]/str[@name='id'][.='7']", - "//result/doc[2]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[1][.='11.0']", - "//result/doc[2]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[2][.='1.0']", - "//result/doc[2]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[3][.='1.0']", - "//result/doc[2]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[4][.='1.0']", - "//result/doc[3]/str[@name='id'][.='2']", - "//result/doc[3]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[1][.='26.0']", - "//result/doc[3]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[2][.='1.0']", - "//result/doc[3]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[3][.='1.0']", - "//result/doc[3]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[4][.='1.0']"); - } @Test public void - parentRetrievalByte_topKWithChildTransformerWithFilter_shouldUseOriginalChildTransformerFilter() { + parentRetrievalByte_topKWithChildTransformer_shouldReturnAllChildren() { assertQ( req( - "fq", "parent_s:(b c)", "q", "{!parent which=$allParents score=max v=$children.q}", "fl", - "id,score,vectors,vector_byte,[child limit=2 fl=vector_byte childFilter=$all_children]", - "children.q", "{!knn f=vector_byte topK=3}" + BYTE_QUERY_VECTOR, + "id,score,vectors,vector_byte,[child limit=2 fl=vector_byte]", + "children.q", "{!knn f=vector_byte topK=3 childrenOf=$someParents allParents=$allParents}" + BYTE_QUERY_VECTOR, "allParents", "parent_s:[* TO *]", - "all_children", "child_s:[* TO *]"), + "someParents", "parent_s:(b c)"), "//result[@numFound='3']", "//result/doc[1]/str[@name='id'][.='8']", "//result/doc[1]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector_byte']/int[1][.='10']", @@ -398,43 +379,14 @@ public void parentRetrievalByte_knnChildrenWithParentFilter_shouldReturnKnnParen } @Test - public void - parentRetrievalByte_topKWithChildTransformerWithNoFilter_shouldUseBestChildrenVectorTransformerFilter() { + public void parentRetrievalByte_topKWithChildTransformerWithFilter_shouldReturnBestChild() { assertQ( req( - "fq", "parent_s:(b c)", "q", "{!parent which=$allParents score=max v=$children.q}", - "fl", "id,score,vectors,vector_byte,[child fl=vector_byte]", - "children.q", "{!knn f=vector_byte topK=3}" + BYTE_QUERY_VECTOR, - "allParents", "parent_s:[* TO *]"), - "//result[@numFound='3']", - "//result/doc[1]/str[@name='id'][.='8']", - "//result/doc[1]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector_byte']/int[1][.='8']", - "//result/doc[1]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector_byte']/int[2][.='1']", - "//result/doc[1]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector_byte']/int[3][.='1']", - "//result/doc[1]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector_byte']/int[4][.='1']", - "//result/doc[2]/str[@name='id'][.='7']", - "//result/doc[2]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector_byte']/int[1][.='11']", - "//result/doc[2]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector_byte']/int[2][.='1']", - "//result/doc[2]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector_byte']/int[3][.='1']", - "//result/doc[2]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector_byte']/int[4][.='1']", - "//result/doc[3]/str[@name='id'][.='2']", - "//result/doc[3]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector_byte']/int[1][.='26']", - "//result/doc[3]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector_byte']/int[2][.='1']", - "//result/doc[3]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector_byte']/int[3][.='1']", - "//result/doc[3]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector_byte']/int[4][.='1']"); - } - - @Test - public void - parentRetrievalByte_topKWithOnlyChildTransformerWithNoFilter_shouldUseBestChildrenVectorTransformerFilter() { - assertQ( - req( - "fq", "parent_s:(b c)", - "q", "{!parent which=$allParents score=max v=$children.q}", - "fl", "id,vectors,vector_byte,[child fl=vector_byte]", - "children.q", "{!knn f=vector_byte topK=3}" + BYTE_QUERY_VECTOR, - "allParents", "parent_s:[* TO *]"), + "fl", "id,score,vectors,vector_byte,[child fl=vector_byte childFilter=$children.q]", + "children.q", "{!knn f=vector_byte topK=3 childrenOf=$someParents allParents=$allParents}" + BYTE_QUERY_VECTOR, + "allParents", "parent_s:[* TO *]", + "someParents", "parent_s:(b c)"), "//result[@numFound='3']", "//result/doc[1]/str[@name='id'][.='8']", "//result/doc[1]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector_byte']/int[1][.='8']", From 6e13aca64e42155dc8aafb370f52e5a8675c5460 Mon Sep 17 00:00:00 2001 From: Alessandro Benedetti Date: Tue, 9 Dec 2025 17:55:16 +0100 Subject: [PATCH 26/43] tidy + documentation --- .../apache/solr/search/neural/KnnQParser.java | 17 +++-- .../BlockJoinNestedVectorsQParserTest.java | 50 ++++++++++----- .../pages/dense-vector-search.adoc | 46 ++++++++++++- .../pages/searching-nested-documents.adoc | 64 +++++++++++++------ 4 files changed, 133 insertions(+), 44 deletions(-) diff --git a/solr/core/src/java/org/apache/solr/search/neural/KnnQParser.java b/solr/core/src/java/org/apache/solr/search/neural/KnnQParser.java index 5f22dcd9842b..b50be79d861a 100644 --- a/solr/core/src/java/org/apache/solr/search/neural/KnnQParser.java +++ b/solr/core/src/java/org/apache/solr/search/neural/KnnQParser.java @@ -31,7 +31,6 @@ import org.apache.solr.schema.DenseVectorField; import org.apache.solr.schema.SchemaField; import org.apache.solr.search.QParser; -import org.apache.solr.search.QueryUtils; import org.apache.solr.search.SyntaxError; import org.apache.solr.search.join.BlockJoinParentQParser; import org.apache.solr.util.vector.DenseVectorParser; @@ -128,16 +127,23 @@ public Query parse() throws SyntaxError { final String parentsFilterQuery = localParams.get(CHILDREN_OF); if (null != parentsFilterQuery) { final String allParentsQuery = localParams.get(ALL_PARENTS); - final BitSetProducer allParentsBitSet = BlockJoinParentQParser.getCachedBitSetProducer(req, subQuery(allParentsQuery, null).getQuery()); + final BitSetProducer allParentsBitSet = + BlockJoinParentQParser.getCachedBitSetProducer( + req, subQuery(allParentsQuery, null).getQuery()); final BooleanQuery acceptedParents = getParentsFilter(parentsFilterQuery); final DenseVectorParser vectorBuilder = denseVectorType.getVectorBuilder(vectorToSearch, DenseVectorParser.BuilderPhase.QUERY); final VectorEncoding vectorEncoding = denseVectorType.getVectorEncoding(); - Query acceptedChildren = getChildrenFilter(getFilterQuery(), acceptedParents, allParentsBitSet); + Query acceptedChildren = + getChildrenFilter(getFilterQuery(), acceptedParents, allParentsBitSet); switch (vectorEncoding) { case FLOAT32: return new DiversifyingChildrenFloatKnnVectorQuery( - vectorField, vectorBuilder.getFloatVector(), acceptedChildren, topK, allParentsBitSet); + vectorField, + vectorBuilder.getFloatVector(), + acceptedChildren, + topK, + allParentsBitSet); case BYTE: return new DiversifyingChildrenByteKnnVectorQuery( vectorField, vectorBuilder.getByteVector(), acceptedChildren, topK, allParentsBitSet); @@ -145,7 +151,8 @@ public Query parse() throws SyntaxError { throw new SolrException( SolrException.ErrorCode.SERVER_ERROR, "Unexpected encoding. Vector Encoding: " + vectorEncoding); - }} + } + } return denseVectorType.getKnnVectorQuery( schemaField.getName(), diff --git a/solr/core/src/test/org/apache/solr/search/join/BlockJoinNestedVectorsQParserTest.java b/solr/core/src/test/org/apache/solr/search/join/BlockJoinNestedVectorsQParserTest.java index 56789d05844f..8ab5ad92113c 100644 --- a/solr/core/src/test/org/apache/solr/search/join/BlockJoinNestedVectorsQParserTest.java +++ b/solr/core/src/test/org/apache/solr/search/join/BlockJoinNestedVectorsQParserTest.java @@ -174,7 +174,9 @@ public void parentRetrievalFloat_knnChildren_shouldReturnKnnParents() { req( "q", "{!parent which=$allParents score=max v=$children.q}", "fl", "id,score", - "children.q", "{!knn f=vector topK=3 childrenOf=$allParents allParents=$allParents}" + FLOAT_QUERY_VECTOR, + "children.q", + "{!knn f=vector topK=3 childrenOf=$allParents allParents=$allParents}" + + FLOAT_QUERY_VECTOR, "allParents", "parent_s:[* TO *]"), "//*[@numFound='3']", "//result/doc[1]/str[@name='id'][.='10']", @@ -200,7 +202,9 @@ public void parentRetrievalFloat_knnChildrenWithParentFilter_shouldReturnKnnPare req( "q", "{!parent which=$allParents score=max v=$children.q}", "fl", "id,score", - "children.q", "{!knn f=vector topK=3 childrenOf=$someParents allParents=$allParents}" + FLOAT_QUERY_VECTOR, + "children.q", + "{!knn f=vector topK=3 childrenOf=$someParents allParents=$allParents}" + + FLOAT_QUERY_VECTOR, "allParents", "parent_s:[* TO *]", "someParents", "parent_s:(a c)"), "//*[@numFound='3']", @@ -216,7 +220,9 @@ public void parentRetrievalFloat_knnChildrenWithParentFilter_shouldReturnKnnPare req( "q", "{!parent which=$allParents score=max v=$children.q}", "fl", "id,score", - "children.q", "{!knn f=vector topK=3 preFilter=child_s:m childrenOf=$someParents allParents=$allParents}" + FLOAT_QUERY_VECTOR, + "children.q", + "{!knn f=vector topK=3 preFilter=child_s:m childrenOf=$someParents allParents=$allParents}" + + FLOAT_QUERY_VECTOR, "allParents", "parent_s:[* TO *]", "someParents", "parent_s:(a c)"), "//*[@numFound='2']", @@ -230,7 +236,9 @@ public void parentRetrievalByte_knnChildren_shouldReturnKnnParents() { req( "q", "{!parent which=$allParents score=max v=$children.q}", "fl", "id,score", - "children.q", "{!knn f=vector_byte topK=3 childrenOf=$allParents allParents=$allParents}" + BYTE_QUERY_VECTOR, + "children.q", + "{!knn f=vector_byte topK=3 childrenOf=$allParents allParents=$allParents}" + + BYTE_QUERY_VECTOR, "allParents", "parent_s:[* TO *]"), "//*[@numFound='3']", "//result/doc[1]/str[@name='id'][.='10']", @@ -244,7 +252,9 @@ public void parentRetrievalByte_knnChildrenWithParentFilter_shouldReturnKnnParen req( "q", "{!parent which=$allParents score=max v=$children.q}", "fl", "id,score", - "children.q", "{!knn f=vector_byte topK=3 childrenOf=$someParents allParents=$allParents}" + BYTE_QUERY_VECTOR, + "children.q", + "{!knn f=vector_byte topK=3 childrenOf=$someParents allParents=$allParents}" + + BYTE_QUERY_VECTOR, "allParents", "parent_s:[* TO *]", "someParents", "parent_s:(a c)"), "//*[@numFound='3']", @@ -260,7 +270,9 @@ public void parentRetrievalByte_knnChildrenWithParentFilter_shouldReturnKnnParen req( "q", "{!parent which=$allParents score=max v=$children.q}", "fl", "id,score", - "children.q", "{!knn f=vector_byte topK=3 preFilter=child_s:m childrenOf=$someParents allParents=$allParents}" + BYTE_QUERY_VECTOR, + "children.q", + "{!knn f=vector_byte topK=3 preFilter=child_s:m childrenOf=$someParents allParents=$allParents}" + + BYTE_QUERY_VECTOR, "allParents", "parent_s:[* TO *]", "someParents", "parent_s:(a c)"), "//*[@numFound='2']", @@ -275,7 +287,9 @@ public void parentRetrievalByte_knnChildrenWithParentFilter_shouldReturnKnnParen req( "q", "{!parent which=$allParents score=max v=$children.q}", "fl", "id,score,vectors,vector,[child limit=2 fl=vector]", - "children.q", "{!knn f=vector topK=3 childrenOf=$someParents allParents=$allParents}" + FLOAT_QUERY_VECTOR, + "children.q", + "{!knn f=vector topK=3 childrenOf=$someParents allParents=$allParents}" + + FLOAT_QUERY_VECTOR, "allParents", "parent_s:[* TO *]", "someParents", "parent_s:(a c)"), "//result[@numFound='3']", @@ -309,13 +323,14 @@ public void parentRetrievalByte_knnChildrenWithParentFilter_shouldReturnKnnParen } @Test - public void - parentRetrievalFloat_topKWithChildTransformerWithFilter_shouldReturnBestChild() { + public void parentRetrievalFloat_topKWithChildTransformerWithFilter_shouldReturnBestChild() { assertQ( req( "q", "{!parent which=$allParents score=max v=$children.q}", "fl", "id,score,vectors,vector,[child fl=vector childFilter=$children.q]", - "children.q", "{!knn f=vector topK=3 childrenOf=$someParents allParents=$allParents}" + FLOAT_QUERY_VECTOR, + "children.q", + "{!knn f=vector topK=3 childrenOf=$someParents allParents=$allParents}" + + FLOAT_QUERY_VECTOR, "allParents", "parent_s:[* TO *]", "someParents", "parent_s:(b c)"), "//result[@numFound='3']", @@ -336,16 +351,15 @@ public void parentRetrievalByte_knnChildrenWithParentFilter_shouldReturnKnnParen "//result/doc[3]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[4][.='1.0']"); } - @Test - public void - parentRetrievalByte_topKWithChildTransformer_shouldReturnAllChildren() { + public void parentRetrievalByte_topKWithChildTransformer_shouldReturnAllChildren() { assertQ( req( "q", "{!parent which=$allParents score=max v=$children.q}", - "fl", - "id,score,vectors,vector_byte,[child limit=2 fl=vector_byte]", - "children.q", "{!knn f=vector_byte topK=3 childrenOf=$someParents allParents=$allParents}" + BYTE_QUERY_VECTOR, + "fl", "id,score,vectors,vector_byte,[child limit=2 fl=vector_byte]", + "children.q", + "{!knn f=vector_byte topK=3 childrenOf=$someParents allParents=$allParents}" + + BYTE_QUERY_VECTOR, "allParents", "parent_s:[* TO *]", "someParents", "parent_s:(b c)"), "//result[@numFound='3']", @@ -384,7 +398,9 @@ public void parentRetrievalByte_topKWithChildTransformerWithFilter_shouldReturnB req( "q", "{!parent which=$allParents score=max v=$children.q}", "fl", "id,score,vectors,vector_byte,[child fl=vector_byte childFilter=$children.q]", - "children.q", "{!knn f=vector_byte topK=3 childrenOf=$someParents allParents=$allParents}" + BYTE_QUERY_VECTOR, + "children.q", + "{!knn f=vector_byte topK=3 childrenOf=$someParents allParents=$allParents}" + + BYTE_QUERY_VECTOR, "allParents", "parent_s:[* TO *]", "someParents", "parent_s:(b c)"), "//result[@numFound='3']", diff --git a/solr/solr-ref-guide/modules/query-guide/pages/dense-vector-search.adoc b/solr/solr-ref-guide/modules/query-guide/pages/dense-vector-search.adoc index 26bbcff16709..af65a273be60 100644 --- a/solr/solr-ref-guide/modules/query-guide/pages/dense-vector-search.adoc +++ b/solr/solr-ref-guide/modules/query-guide/pages/dense-vector-search.adoc @@ -347,7 +347,7 @@ Apache Solr provides three query parsers that work with dense vector fields, tha All parsers return scores for retrieved documents that are the approximate distance to the target vector (defined by the similarityFunction configured at indexing time) and both support "Pre-Filtering" the document graph to reduce the number of candidate vectors evaluated (without needing to compute their vector similarity distances). -Common parameters for both query parsers are: +Common parameters for all query parsers are: `f`:: + @@ -499,6 +499,50 @@ Here is an example of a `knn` search using a `filteredSearchThreshold`: [source,text] ?q={!knn f=vector topK=10 filteredSearchThreshold=60}[1.0, 2.0, 3.0, 4.0] +`childrenOf`:: ++ +[%autowidth,frame=none] +|=== +|Optional |Default: none +|=== ++ +A query that enables the Lucene’s implementation of {lucene-javadocs}/join/org/apache/lucene/search/join/DiversifyingChildrenFloatKnnVectorQuery.html[Diversifying Knn Query] . ++ +This parameter is meant to be a filter query on parent document metadata. +The knn search returns the top-k nearest children documents that satify the filter on the parent. ++ +Only one child per distinct parent is returned. + +Here is an example of a `knn` search using a `childrenOf`: + +[source,text] +?q={!knn f=vector topK=3 childrenOf=$someParents allParents=$allParents}[1.0, 2.0, 3.0, 4.0] +&allParents=*:* -_nest_path_:* +&someParents=color_s:RED + +The search results retrieved are the k=3 nearest documents to the vector in input `[1.0, 2.0, 3.0, 4.0]`, each of them with a different parent. Only the documents with a parent that satisfy the 'color_s:RED' condition are considered candidates for the ANN search. + +`allParents`:: ++ +[%autowidth,frame=none] +|=== +|Optional |Default: none +|Mandatory if using 'childrenOf' parameter|Default: none +|=== ++ +A query that matches ALL parents. +It's required to work with the 'childrenOf' parameter. + + +Here is an example of a `knn` search using a `childrenOf`: + +[source,text] +?q={!knn f=vector topK=3 childrenOf=$someParents allParents=$allParents}[1.0, 2.0, 3.0, 4.0] +&allParents=*:* -_nest_path_:* +&someParents=color_s:RED + +The search results retrieved are the k=3 nearest documents to the vector in input `[1.0, 2.0, 3.0, 4.0]`, each of them with a different parent. The 'allParents' parameter must return all parents to guarantee the correct functioning of the query. + === knn_text_to_vector Query Parser The `knn_text_to_vector` query parser encode a textual query to a vector using a dedicated Large Language Model(fine tuned for the task of encoding text to vector for sentence similarity) and matches k-nearest neighbours documents to such query vector. diff --git a/solr/solr-ref-guide/modules/query-guide/pages/searching-nested-documents.adoc b/solr/solr-ref-guide/modules/query-guide/pages/searching-nested-documents.adoc index 531fa3c147c0..2fd8af1b084c 100644 --- a/solr/solr-ref-guide/modules/query-guide/pages/searching-nested-documents.adoc +++ b/solr/solr-ref-guide/modules/query-guide/pages/searching-nested-documents.adoc @@ -190,13 +190,13 @@ $ curl 'http://localhost:8983/solr/gettingstarted/select' -d 'omitHeader=true' - ==== It is quite common to encode the original text of a document into multiple nested vectors. -This may happen, among other use cases, because you chunked the original text into paragraphs, each of them modeled as a nested document with the paragraph text and the vector representation. +This may happen, among other use cases, because you chunked the original text into paragraphs, each of them modeled as a nested document with the paragraph text and its vector representation. -Solr doesn't need to have denormalised nested documents, you can still retrieve the children paragraphs by knn vector search and prefilter them using parent level metadata. +You don't need the redundant parent metadata in each child document (the traditional flat approach), you can still retrieve the children paragraphs by knn vector search, prefilter them using parent level metadata and finally retrieve only K parents if needed. [source,text] ---- -$ curl 'http://localhost:8983/solr/gettingstarted/select' -d 'omitHeader=true' --data-urlencode 'fq={!child of=$block_mask filters=$parentsFilter}&q={!knn f=childVectorField topK=5}[1.0,2.5,3.0...]' --data-urlencode 'block_mask=(*:* -{!prefix f="_nest_path_" v="/skus/" parentsFilter="name_s:pen"})' +$ curl 'http://localhost:8983/solr/gettingstarted/select' -d 'omitHeader=true' --data-urlencode 'fq={!child of=$block_mask filters=$parentsFilter}&q={!knn f=childVectorField topK=5 childrenOf=$someParents allParents=$allParents}[1.0,2.5,3.0...]' --data-urlencode 'block_mask=(*:* -{!prefix f="_nest_path_" v="/skus/" parentsFilter="name_s:pen"})' ---- ==== @@ -281,24 +281,6 @@ $ curl 'http://localhost:8983/solr/gettingstarted/select' -d 'omitHeader=true' - Note that in the above example, the `/` characters in the `\_nest_path_` were "double escaped" in the `which` parameter, for the <> regarding the `{!child} pasers `of` parameter. ==== -[#vector-search-parent -[CAUTION] -.Vector search - children are nested documents with a vector field -==== -It is quite common to encode the original text of a document into multiple nested vectors. - -This may happen, among other use cases, because you chunked the original text into paragraphs, each of them modeled as a nested document with the paragraph text and the vector representation. - -You can run knn vector search on children documents (with potential prefiltering on children and/or parents metadata) and retrieve top-K parents. - -N.B. Solr ensures that the knn search for children keeps track of parent metadata filtering, guaranteeing top-k parents retrieval - -[source,text] ----- -$ curl 'http://localhost:8983/solr/gettingstarted/select' -d 'omitHeader=true' --data-urlencode 'fq=parentField:term&q={!parent which=$block_mask score=max v=$children.q}' --data-urlencode 'block_mask=(*:* -{!prefix f="_nest_path_" v="/skus/" children.q="{!knn f=vector topK=3 preFilter=childField:term}"[1.0,2.5,3.0...]})' ----- -==== - === Combining Block Join Query Parsers with Child Doc Transformer The combination of these two parsers with the `[child]` transformer enables seamless creation of very powerful queries. @@ -330,3 +312,43 @@ $ curl 'http://localhost:8983/solr/gettingstarted/select' -d 'omitHeader=true' - "_version_":1676585794196733952}]}] }} ---- + +=== Nested Vectors search through Block Join Query Parsers and Child Doc Transformer + + + +When dealing with vector search a possible use case involves having multiple vectors for a single document. + +This in Solr can be implemented with the block join and nested documents. +Each nested document has a vector field (among other metadata) + +This may happen, among other use cases, because you chunked the original text into paragraphs, each of them modeled as a nested document with the paragraph text and the vector representation. + +You can run knn vector search on children documents (with potential prefiltering on children and/or parents metadata) and retrieve top-K parents. + +N.B. Solr ensures that the knn search for children keeps track of parent metadata filtering, guaranteeing top-k parents retrieval + +An example: +[source,text] +?q={!parent which=$allParents score=max v=$children.q}& +children.q={!knn f=vector topK=3 childrenOf=$someParents allParents=$allParents}[1.0, 2.0, 3.0, 4.0]& +allParents=*:* -_nest_path_:*& +someParents=color_s:RED& +fl=id,score,vectors,vector,[child fl=vector childFilter=$children.q] + +The search results retrieved are the top k=3 parents of the nearest children to the vector in input `[1.0, 2.0, 3.0, 4.0]`, ranked by the `similarityFunction` configured at indexing time. + +Let's decompose the query to better explain it: +[source,text] +?q={!parent which=$allParents score=max v=$children.q} + +This query returns the parent solr documents using the block join parent query parser on a query that filters on the children documents ('children.q'). For each child retrieved, the parent is returned. + +[source,text] +children.q={!knn f=vector topK=3 childrenOf=$someParents allParents=$allParents}[1.0, 2.0, 3.0, 4.0] + +This query is a knn vector search on the children documents. +Specifically it retrieves the top k=3 children documents, filtered by 'someParent' metadata. +This query ensures only one child per parent is retrieved. + +---- From a2aff99e35cb1152d11271af4576c60cf90e8bbc Mon Sep 17 00:00:00 2001 From: Alessandro Benedetti Date: Tue, 9 Dec 2025 18:10:32 +0100 Subject: [PATCH 27/43] tidy + documentation --- .../modules/query-guide/pages/dense-vector-search.adoc | 3 --- 1 file changed, 3 deletions(-) diff --git a/solr/solr-ref-guide/modules/query-guide/pages/dense-vector-search.adoc b/solr/solr-ref-guide/modules/query-guide/pages/dense-vector-search.adoc index af65a273be60..def2e9cbeb2b 100644 --- a/solr/solr-ref-guide/modules/query-guide/pages/dense-vector-search.adoc +++ b/solr/solr-ref-guide/modules/query-guide/pages/dense-vector-search.adoc @@ -505,9 +505,6 @@ Here is an example of a `knn` search using a `filteredSearchThreshold`: |=== |Optional |Default: none |=== -+ -A query that enables the Lucene’s implementation of {lucene-javadocs}/join/org/apache/lucene/search/join/DiversifyingChildrenFloatKnnVectorQuery.html[Diversifying Knn Query] . -+ This parameter is meant to be a filter query on parent document metadata. The knn search returns the top-k nearest children documents that satify the filter on the parent. + From bac533ef0d5a34b887935aa2906421bb6c3a0630 Mon Sep 17 00:00:00 2001 From: Alessandro Benedetti Date: Tue, 9 Dec 2025 18:44:49 +0100 Subject: [PATCH 28/43] tidy + documentation --- changelog/unreleased/SOLR-17736.yml | 8 ++++++ .../apache/solr/search/neural/KnnQParser.java | 25 +++++++++++++----- .../BlockJoinNestedVectorsQParserTest.java | 26 ++++++++++++++----- .../search/neural/KnnQParserChildTest.java | 8 +++--- 4 files changed, 50 insertions(+), 17 deletions(-) create mode 100644 changelog/unreleased/SOLR-17736.yml diff --git a/changelog/unreleased/SOLR-17736.yml b/changelog/unreleased/SOLR-17736.yml new file mode 100644 index 000000000000..675d0a365c39 --- /dev/null +++ b/changelog/unreleased/SOLR-17736.yml @@ -0,0 +1,8 @@ +# See https://github.com/apache/solr/blob/main/dev-docs/changelog.adoc +title: Introducing support for nested vector search, enabling the retrieval of nested documents diversified by parent. This enables multi valued vectors scenarios and best child retrieval per parent. +type: added # added, changed, fixed, deprecated, removed, dependency_update, security, other +authors: + - name: Alessandro Benedetti +links: + - name: SOLR-17736 + url: https://issues.apache.org/jira/browse/SOLR-17736 diff --git a/solr/core/src/java/org/apache/solr/search/neural/KnnQParser.java b/solr/core/src/java/org/apache/solr/search/neural/KnnQParser.java index b50be79d861a..61926efb36de 100644 --- a/solr/core/src/java/org/apache/solr/search/neural/KnnQParser.java +++ b/solr/core/src/java/org/apache/solr/search/neural/KnnQParser.java @@ -125,15 +125,24 @@ public Query parse() throws SyntaxError { // check for parent diversification logic... final String parentsFilterQuery = localParams.get(CHILDREN_OF); - if (null != parentsFilterQuery) { - final String allParentsQuery = localParams.get(ALL_PARENTS); + final String allParentsQuery = localParams.get(ALL_PARENTS); + + boolean isDiversifyingChildrenKnnQuery = null != parentsFilterQuery || null != allParentsQuery; + if (isDiversifyingChildrenKnnQuery) { + if (null == allParentsQuery) { + throw new SolrException( + SolrException.ErrorCode.BAD_REQUEST, + "When running a diversifying children KNN query, 'allParents' parameter is required"); + } + final DenseVectorParser vectorBuilder = + denseVectorType.getVectorBuilder(vectorToSearch, DenseVectorParser.BuilderPhase.QUERY); + final VectorEncoding vectorEncoding = denseVectorType.getVectorEncoding(); + final BitSetProducer allParentsBitSet = BlockJoinParentQParser.getCachedBitSetProducer( req, subQuery(allParentsQuery, null).getQuery()); final BooleanQuery acceptedParents = getParentsFilter(parentsFilterQuery); - final DenseVectorParser vectorBuilder = - denseVectorType.getVectorBuilder(vectorToSearch, DenseVectorParser.BuilderPhase.QUERY); - final VectorEncoding vectorEncoding = denseVectorType.getVectorEncoding(); + Query acceptedChildren = getChildrenFilter(getFilterQuery(), acceptedParents, allParentsBitSet); switch (vectorEncoding) { @@ -165,9 +174,11 @@ public Query parse() throws SyntaxError { } private BooleanQuery getParentsFilter(String parentsFilterQuery) throws SyntaxError { - final Query parentsFilter = subQuery(parentsFilterQuery, null).getQuery(); BooleanQuery.Builder acceptedParentsBuilder = new BooleanQuery.Builder(); - acceptedParentsBuilder.add(parentsFilter, BooleanClause.Occur.FILTER); + if (parentsFilterQuery != null) { + final Query parentsFilter = subQuery(parentsFilterQuery, null).getQuery(); + acceptedParentsBuilder.add(parentsFilter, BooleanClause.Occur.FILTER); + } BooleanQuery acceptedParents = acceptedParentsBuilder.build(); return acceptedParents; } diff --git a/solr/core/src/test/org/apache/solr/search/join/BlockJoinNestedVectorsQParserTest.java b/solr/core/src/test/org/apache/solr/search/join/BlockJoinNestedVectorsQParserTest.java index 8ab5ad92113c..65bc72e3bb61 100644 --- a/solr/core/src/test/org/apache/solr/search/join/BlockJoinNestedVectorsQParserTest.java +++ b/solr/core/src/test/org/apache/solr/search/join/BlockJoinNestedVectorsQParserTest.java @@ -134,6 +134,24 @@ private static List outDistanceByte(List vector, int value) { return result; } + @Test + public void parentRetrieval_knnChildrenDiversifyingWithNoAllParents_shouldThrowException() { + assertQEx( + "When running a diversifying children KNN query, 'allParents' parameter is required", + req( + "q", + "{!parent which=$allParents score=max v=$children.q}", + "fl", + "id,score", + "children.q", + "{!knn f=vector topK=3 childrenOf=$someParents}" + FLOAT_QUERY_VECTOR, + "allParents", + "parent_s:[* TO *]", + "someParents", + "parent_s:(a c)"), + 400); + } + @Test public void childrenRetrievalFloat_filteringByParentMetadata_shouldReturnKnnChildren() { assertQ( @@ -174,9 +192,7 @@ public void parentRetrievalFloat_knnChildren_shouldReturnKnnParents() { req( "q", "{!parent which=$allParents score=max v=$children.q}", "fl", "id,score", - "children.q", - "{!knn f=vector topK=3 childrenOf=$allParents allParents=$allParents}" - + FLOAT_QUERY_VECTOR, + "children.q", "{!knn f=vector topK=3 allParents=$allParents}" + FLOAT_QUERY_VECTOR, "allParents", "parent_s:[* TO *]"), "//*[@numFound='3']", "//result/doc[1]/str[@name='id'][.='10']", @@ -236,9 +252,7 @@ public void parentRetrievalByte_knnChildren_shouldReturnKnnParents() { req( "q", "{!parent which=$allParents score=max v=$children.q}", "fl", "id,score", - "children.q", - "{!knn f=vector_byte topK=3 childrenOf=$allParents allParents=$allParents}" - + BYTE_QUERY_VECTOR, + "children.q", "{!knn f=vector_byte topK=3 allParents=$allParents}" + BYTE_QUERY_VECTOR, "allParents", "parent_s:[* TO *]"), "//*[@numFound='3']", "//result/doc[1]/str[@name='id'][.='10']", diff --git a/solr/core/src/test/org/apache/solr/search/neural/KnnQParserChildTest.java b/solr/core/src/test/org/apache/solr/search/neural/KnnQParserChildTest.java index 5c6bcf7d2918..ecf9e0ee31fa 100644 --- a/solr/core/src/test/org/apache/solr/search/neural/KnnQParserChildTest.java +++ b/solr/core/src/test/org/apache/solr/search/neural/KnnQParserChildTest.java @@ -69,7 +69,7 @@ public void testDiverseKids() { assertQ( req( "q", - "{!knn f=vector topK=$k childrenOf='type_s:PARENT' allParents='type_s:PARENT'}" + "{!knn f=vector topK=$k allParents='type_s:PARENT'}" + vecStr(randomFloatVector(random())), "indent", "true", @@ -92,7 +92,7 @@ public void testDiverseKids() { assertQ( req( "q", - "{!knn f=vector_byte_encoding topK=$k childrenOf='type_s:PARENT' allParents='type_s:PARENT'}" + "{!knn f=vector_byte_encoding topK=$k allParents='type_s:PARENT'}" + vecStr(randomByteVector(random())), "indent", "true", @@ -126,7 +126,7 @@ public void testParentsOfDiverseKids() { "q", "{!parent which='type_s:PARENT' score=max v=$knn}", "knn", - "{!knn f=vector topK=$k childrenOf='type_s:PARENT' allParents='type_s:PARENT'}" + "{!knn f=vector topK=$k allParents='type_s:PARENT'}" + vecStr(randomFloatVector(random())), "indent", "true", @@ -148,7 +148,7 @@ public void testParentsOfDiverseKids() { "q", "{!parent which='type_s:PARENT' score=max v=$knn}", "knn", - "{!knn f=vector_byte_encoding topK=$k childrenOf='type_s:PARENT' allParents='type_s:PARENT'}" + "{!knn f=vector_byte_encoding topK=$k allParents='type_s:PARENT'}" + vecStr(randomByteVector(random())), "indent", "true", From de531d2bdc65e223094d970ed772b7187642538c Mon Sep 17 00:00:00 2001 From: Alessandro Benedetti Date: Wed, 10 Dec 2025 17:30:21 +0100 Subject: [PATCH 29/43] tidy + documentation --- .../query-guide/pages/dense-vector-search.adoc | 5 ++--- .../pages/searching-nested-documents.adoc | 16 ---------------- 2 files changed, 2 insertions(+), 19 deletions(-) diff --git a/solr/solr-ref-guide/modules/query-guide/pages/dense-vector-search.adoc b/solr/solr-ref-guide/modules/query-guide/pages/dense-vector-search.adoc index 8480169490c1..eeadca91c7bd 100644 --- a/solr/solr-ref-guide/modules/query-guide/pages/dense-vector-search.adoc +++ b/solr/solr-ref-guide/modules/query-guide/pages/dense-vector-search.adoc @@ -534,12 +534,11 @@ A query that matches ALL parents. It's required to work with the 'childrenOf' parameter. -Here is an example of a `knn` search using a `childrenOf`: +Here is an example of a `knn` search using a `allParents`: [source,text] -?q={!knn f=vector topK=3 childrenOf=$someParents allParents=$allParents}[1.0, 2.0, 3.0, 4.0] +?q={!knn f=vector topK=3 allParents=$allParents}[1.0, 2.0, 3.0, 4.0] &allParents=*:* -_nest_path_:* -&someParents=color_s:RED The search results retrieved are the k=3 nearest documents to the vector in input `[1.0, 2.0, 3.0, 4.0]`, each of them with a different parent. The 'allParents' parameter must return all parents to guarantee the correct functioning of the query. diff --git a/solr/solr-ref-guide/modules/query-guide/pages/searching-nested-documents.adoc b/solr/solr-ref-guide/modules/query-guide/pages/searching-nested-documents.adoc index 2fd8af1b084c..e5a75762013f 100644 --- a/solr/solr-ref-guide/modules/query-guide/pages/searching-nested-documents.adoc +++ b/solr/solr-ref-guide/modules/query-guide/pages/searching-nested-documents.adoc @@ -184,22 +184,6 @@ $ curl 'http://localhost:8983/solr/gettingstarted/select' -d 'omitHeader=true' - ---- ==== -[#vector-search-child] -[CAUTION] -.Vector search - children are nested documents with a vector field -==== -It is quite common to encode the original text of a document into multiple nested vectors. - -This may happen, among other use cases, because you chunked the original text into paragraphs, each of them modeled as a nested document with the paragraph text and its vector representation. - -You don't need the redundant parent metadata in each child document (the traditional flat approach), you can still retrieve the children paragraphs by knn vector search, prefilter them using parent level metadata and finally retrieve only K parents if needed. - -[source,text] ----- -$ curl 'http://localhost:8983/solr/gettingstarted/select' -d 'omitHeader=true' --data-urlencode 'fq={!child of=$block_mask filters=$parentsFilter}&q={!knn f=childVectorField topK=5 childrenOf=$someParents allParents=$allParents}[1.0,2.5,3.0...]' --data-urlencode 'block_mask=(*:* -{!prefix f="_nest_path_" v="/skus/" parentsFilter="name_s:pen"})' ----- -==== - === Parent Query Parser The inverse of the `{!child}` query parser is the `{!parent}` query parser, which lets you search for the _ancestor_ documents of some child documents matching a wrapped query. From 227870b339b27822e5d230636e3d943542d182a5 Mon Sep 17 00:00:00 2001 From: Alessandro Benedetti Date: Thu, 11 Dec 2025 12:15:29 +0100 Subject: [PATCH 30/43] catching up and merging --- .../transform/ChildDocTransformer.java | 6 +- .../apache/solr/schema/DenseVectorField.java | 2 - .../org/apache/solr/search/QueryLimits.java | 5 - .../org/apache/solr/search/ReturnFields.java | 2 - .../apache/solr/search/SolrReturnFields.java | 12 +- .../apache/solr/search/neural/KnnQParser.java | 0 .../NestedUpdateProcessorFactory.java | 251 +++++++++--------- .../search/vector/KnnQParserChildTest.java | 2 +- .../KnnQParserMultiValuedVectorsTest.java | 209 ++++++++------- 9 files changed, 238 insertions(+), 251 deletions(-) delete mode 100644 solr/core/src/java/org/apache/solr/search/neural/KnnQParser.java diff --git a/solr/core/src/java/org/apache/solr/response/transform/ChildDocTransformer.java b/solr/core/src/java/org/apache/solr/response/transform/ChildDocTransformer.java index 47eb7ab786a4..ad2f480e91fe 100644 --- a/solr/core/src/java/org/apache/solr/response/transform/ChildDocTransformer.java +++ b/solr/core/src/java/org/apache/solr/response/transform/ChildDocTransformer.java @@ -59,16 +59,12 @@ public class ChildDocTransformer extends DocTransformer { private final String name; private final BitSetProducer parentsFilter; // if null; resolve parent via uniqueKey instead - private DocSet childDocSet; + private final DocSet childDocSet; private final int limit; private final boolean isNestedSchema; private final SolrReturnFields childReturnFields; private final String[] extraRequestedFields; - public void setChildDocSet(DocSet childDocSet) { - this.childDocSet = childDocSet; - } - ChildDocTransformer( String name, BitSetProducer parentsFilter, diff --git a/solr/core/src/java/org/apache/solr/schema/DenseVectorField.java b/solr/core/src/java/org/apache/solr/schema/DenseVectorField.java index 3613fd33ee83..28317e1da745 100644 --- a/solr/core/src/java/org/apache/solr/schema/DenseVectorField.java +++ b/solr/core/src/java/org/apache/solr/schema/DenseVectorField.java @@ -19,9 +19,7 @@ import static java.util.Optional.ofNullable; import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH; import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN; -import static org.apache.solr.schema.IndexSchema.NEST_PATH_FIELD_NAME; -import java.io.IOException; import java.lang.invoke.MethodHandles; import java.util.ArrayList; import java.util.List; diff --git a/solr/core/src/java/org/apache/solr/search/QueryLimits.java b/solr/core/src/java/org/apache/solr/search/QueryLimits.java index 862c8896f2c7..35ea4533e93c 100644 --- a/solr/core/src/java/org/apache/solr/search/QueryLimits.java +++ b/solr/core/src/java/org/apache/solr/search/QueryLimits.java @@ -49,11 +49,6 @@ public final class QueryLimits implements QueryTimeout { public static final QueryLimits NONE = new QueryLimits(); private final SolrQueryResponse rsp; - - public SolrQueryResponse getRsp() { - return rsp; - } - private final boolean allowPartialResults; // short-circuit the checks if any limit has been tripped diff --git a/solr/core/src/java/org/apache/solr/search/ReturnFields.java b/solr/core/src/java/org/apache/solr/search/ReturnFields.java index 9035b22fedf4..e01db9f1f5c5 100644 --- a/solr/core/src/java/org/apache/solr/search/ReturnFields.java +++ b/solr/core/src/java/org/apache/solr/search/ReturnFields.java @@ -115,6 +115,4 @@ public Set getNonScoreDependentReturnFieldNames() { /** Returns the DocTransformer used to modify documents, or null */ public abstract DocTransformer getTransformer(); - - public abstract void setTransformer(DocTransformer transformer); } diff --git a/solr/core/src/java/org/apache/solr/search/SolrReturnFields.java b/solr/core/src/java/org/apache/solr/search/SolrReturnFields.java index 1b0fa41bd3dc..939cf9b778c2 100644 --- a/solr/core/src/java/org/apache/solr/search/SolrReturnFields.java +++ b/solr/core/src/java/org/apache/solr/search/SolrReturnFields.java @@ -298,6 +298,7 @@ private void add( sp.pos = start; field = null; } + if (field == null) { // We didn't find a simple name, so let's see if it's a globbed field name. // Globbing only works with field names of the recommended form (roughly like java @@ -515,12 +516,6 @@ private void addField( augmenters.addTransformer(new ScoreAugmenter(disp)); scoreDependentFields.put(disp, disp.equals(SCORE) ? "" : SCORE); } - /* - if("vector_multivalued".equals(field)){ - ChildDocTransformerFactory childFactory = new ChildDocTransformerFactory(); - DocTransformer multiValuedTrans = childFactory.create("vector_multivalued", null, null); - augmenters.addTransformer(multiValuedTrans); - }*/ } @Override @@ -589,11 +584,6 @@ public DocTransformer getTransformer() { return transformer; } - @Override - public void setTransformer(DocTransformer transformer) { -this.transformer = transformer; - } - @Override public String toString() { final StringBuilder sb = new StringBuilder("SolrReturnFields=("); diff --git a/solr/core/src/java/org/apache/solr/search/neural/KnnQParser.java b/solr/core/src/java/org/apache/solr/search/neural/KnnQParser.java deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/solr/core/src/java/org/apache/solr/update/processor/NestedUpdateProcessorFactory.java b/solr/core/src/java/org/apache/solr/update/processor/NestedUpdateProcessorFactory.java index 273e8b928388..ba3b3c8419fb 100644 --- a/solr/core/src/java/org/apache/solr/update/processor/NestedUpdateProcessorFactory.java +++ b/solr/core/src/java/org/apache/solr/update/processor/NestedUpdateProcessorFactory.java @@ -20,7 +20,6 @@ import java.io.IOException; import java.util.ArrayList; import java.util.Collection; - import org.apache.solr.common.SolrException; import org.apache.solr.common.SolrInputDocument; import org.apache.solr.common.SolrInputField; @@ -41,144 +40,144 @@ */ public class NestedUpdateProcessorFactory extends UpdateRequestProcessorFactory { - @Override - public UpdateRequestProcessor getInstance( - SolrQueryRequest req, SolrQueryResponse rsp, UpdateRequestProcessor next) { - boolean storeParent = shouldStoreDocParent(req.getSchema()); - boolean storePath = shouldStoreDocPath(req.getSchema()); - if (!(storeParent || storePath)) { - return next; - } - return new NestedUpdateProcessor(req, storeParent, storePath, next); + @Override + public UpdateRequestProcessor getInstance( + SolrQueryRequest req, SolrQueryResponse rsp, UpdateRequestProcessor next) { + boolean storeParent = shouldStoreDocParent(req.getSchema()); + boolean storePath = shouldStoreDocPath(req.getSchema()); + if (!(storeParent || storePath)) { + return next; } - - private static boolean shouldStoreDocParent(IndexSchema schema) { - return schema.getFields().containsKey(IndexSchema.NEST_PARENT_FIELD_NAME); + return new NestedUpdateProcessor(req, storeParent, storePath, next); + } + + private static boolean shouldStoreDocParent(IndexSchema schema) { + return schema.getFields().containsKey(IndexSchema.NEST_PARENT_FIELD_NAME); + } + + private static boolean shouldStoreDocPath(IndexSchema schema) { + return schema.getFields().containsKey(IndexSchema.NEST_PATH_FIELD_NAME); + } + + private static class NestedUpdateProcessor extends UpdateRequestProcessor { + private static final String PATH_SEP_CHAR = "/"; + private static final String NUM_SEP_CHAR = "#"; + private static final String SINGULAR_VALUE_CHAR = ""; + private boolean storePath; + private boolean storeParent; + private String uniqueKeyFieldName; + private IndexSchema schema; + + NestedUpdateProcessor( + SolrQueryRequest req, boolean storeParent, boolean storePath, UpdateRequestProcessor next) { + super(next); + this.storeParent = storeParent; + this.storePath = storePath; + this.uniqueKeyFieldName = req.getSchema().getUniqueKeyField().getName(); + this.schema = req.getSchema(); } - private static boolean shouldStoreDocPath(IndexSchema schema) { - return schema.getFields().containsKey(IndexSchema.NEST_PATH_FIELD_NAME); + @Override + public void processAdd(AddUpdateCommand cmd) throws IOException { + SolrInputDocument doc = cmd.getSolrInputDocument(); + processDocChildren(doc, null); + super.processAdd(cmd); } - private static class NestedUpdateProcessor extends UpdateRequestProcessor { - private static final String PATH_SEP_CHAR = "/"; - private static final String NUM_SEP_CHAR = "#"; - private static final String SINGULAR_VALUE_CHAR = ""; - private boolean storePath; - private boolean storeParent; - private String uniqueKeyFieldName; - private IndexSchema schema; - - NestedUpdateProcessor( - SolrQueryRequest req, boolean storeParent, boolean storePath, UpdateRequestProcessor next) { - super(next); - this.storeParent = storeParent; - this.storePath = storePath; - this.uniqueKeyFieldName = req.getSchema().getUniqueKeyField().getName(); - this.schema = req.getSchema(); - } - - @Override - public void processAdd(AddUpdateCommand cmd) throws IOException { - SolrInputDocument doc = cmd.getSolrInputDocument(); - processDocChildren(doc, null); - super.processAdd(cmd); - } - - private boolean processDocChildren(SolrInputDocument doc, String fullPath) { - boolean isNested = false; - for (SolrInputField field : doc.values()) { - SchemaField sfield = schema.getField(field.getName()); - int childNum = 0; - boolean isSingleVal = !(field.getValue() instanceof Collection); - if (fullPath == null && isMultiValuedVectorField(sfield)) { - ArrayList vectors = new ArrayList<>(field.getValueCount()); - for (Object vectorValue : field.getValues()) { - SolrInputDocument singleVectorNestedDoc = new SolrInputDocument(); - singleVectorNestedDoc.setField(field.getName(), vectorValue); - final String sChildNum = isSingleVal ? SINGULAR_VALUE_CHAR : String.valueOf(childNum); - String parentDocId = doc.getField(uniqueKeyFieldName).getFirstValue().toString(); - singleVectorNestedDoc.setField( - uniqueKeyFieldName, generateChildUniqueId(parentDocId, field.getName(), sChildNum)); - - if (!isNested) { - isNested = true; - } - final String lastKeyPath = PATH_SEP_CHAR + field.getName() + NUM_SEP_CHAR + sChildNum; - final String childDocPath = fullPath == null ? lastKeyPath : fullPath + lastKeyPath; - if (storePath) { - setPathField(singleVectorNestedDoc, childDocPath); - } - if (storeParent) { - setParentKey(singleVectorNestedDoc, doc); - } - ++childNum; - vectors.add(singleVectorNestedDoc); - } - doc.setField(field.getName(), vectors); - } else { - for (Object val : field) { - if (!(val instanceof SolrInputDocument cDoc)) { - // either all collection items are child docs or none are. - break; - } - final String fieldName = field.getName(); - - if (fieldName.contains(PATH_SEP_CHAR)) { - throw new SolrException( - SolrException.ErrorCode.BAD_REQUEST, - "Field name: '" - + fieldName - + "' contains: '" - + PATH_SEP_CHAR - + "' , which is reserved for the nested URP"); - } - final String sChildNum = isSingleVal ? SINGULAR_VALUE_CHAR : String.valueOf(childNum); - if (!cDoc.containsKey(uniqueKeyFieldName)) { - String parentDocId = doc.getField(uniqueKeyFieldName).getFirstValue().toString(); - cDoc.setField( - uniqueKeyFieldName, generateChildUniqueId(parentDocId, fieldName, sChildNum)); - } - if (!isNested) { - isNested = true; - } - final String lastKeyPath = PATH_SEP_CHAR + fieldName + NUM_SEP_CHAR + sChildNum; - // concat of all paths children.grandChild => /children#1/grandChild# - final String childDocPath = fullPath == null ? lastKeyPath : fullPath + lastKeyPath; - processChildDoc(cDoc, doc, childDocPath); - ++childNum; - } - } + private boolean processDocChildren(SolrInputDocument doc, String fullPath) { + boolean isNested = false; + for (SolrInputField field : doc.values()) { + SchemaField sfield = schema.getField(field.getName()); + int childNum = 0; + boolean isSingleVal = !(field.getValue() instanceof Collection); + if (fullPath == null && isMultiValuedVectorField(sfield)) { + ArrayList vectors = new ArrayList<>(field.getValueCount()); + for (Object vectorValue : field.getValues()) { + SolrInputDocument singleVectorNestedDoc = new SolrInputDocument(); + singleVectorNestedDoc.setField(field.getName(), vectorValue); + final String sChildNum = isSingleVal ? SINGULAR_VALUE_CHAR : String.valueOf(childNum); + String parentDocId = doc.getField(uniqueKeyFieldName).getFirstValue().toString(); + singleVectorNestedDoc.setField( + uniqueKeyFieldName, generateChildUniqueId(parentDocId, field.getName(), sChildNum)); + + if (!isNested) { + isNested = true; } - return isNested; - } - - private static boolean isMultiValuedVectorField(SchemaField sfield) { - return sfield.getType() instanceof DenseVectorField && sfield.multiValued(); - } - - private void processChildDoc( - SolrInputDocument sdoc, SolrInputDocument parent, String fullPath) { + final String lastKeyPath = PATH_SEP_CHAR + field.getName() + NUM_SEP_CHAR + sChildNum; + final String childDocPath = fullPath == null ? lastKeyPath : fullPath + lastKeyPath; if (storePath) { - setPathField(sdoc, fullPath); + setPathField(singleVectorNestedDoc, childDocPath); } if (storeParent) { - setParentKey(sdoc, parent); + setParentKey(singleVectorNestedDoc, doc); + } + ++childNum; + vectors.add(singleVectorNestedDoc); + } + doc.setField(field.getName(), vectors); + } else { + for (Object val : field) { + if (!(val instanceof SolrInputDocument cDoc)) { + // either all collection items are child docs or none are. + break; } - processDocChildren(sdoc, fullPath); + final String fieldName = field.getName(); + + if (fieldName.contains(PATH_SEP_CHAR)) { + throw new SolrException( + SolrException.ErrorCode.BAD_REQUEST, + "Field name: '" + + fieldName + + "' contains: '" + + PATH_SEP_CHAR + + "' , which is reserved for the nested URP"); + } + final String sChildNum = isSingleVal ? SINGULAR_VALUE_CHAR : String.valueOf(childNum); + if (!cDoc.containsKey(uniqueKeyFieldName)) { + String parentDocId = doc.getField(uniqueKeyFieldName).getFirstValue().toString(); + cDoc.setField( + uniqueKeyFieldName, generateChildUniqueId(parentDocId, fieldName, sChildNum)); + } + if (!isNested) { + isNested = true; + } + final String lastKeyPath = PATH_SEP_CHAR + fieldName + NUM_SEP_CHAR + sChildNum; + // concat of all paths children.grandChild => /children#1/grandChild# + final String childDocPath = fullPath == null ? lastKeyPath : fullPath + lastKeyPath; + processChildDoc(cDoc, doc, childDocPath); + ++childNum; + } } + } + return isNested; + } - private String generateChildUniqueId(String parentId, String childKey, String childNum) { - // combines parentId with the child's key and childNum. e.g. "10/footnote#1" - return parentId + PATH_SEP_CHAR + childKey + NUM_SEP_CHAR + childNum; - } + private static boolean isMultiValuedVectorField(SchemaField sfield) { + return sfield.getType() instanceof DenseVectorField && sfield.multiValued(); + } - private void setParentKey(SolrInputDocument sdoc, SolrInputDocument parent) { - sdoc.setField(IndexSchema.NEST_PARENT_FIELD_NAME, parent.getFieldValue(uniqueKeyFieldName)); - } + private void processChildDoc( + SolrInputDocument sdoc, SolrInputDocument parent, String fullPath) { + if (storePath) { + setPathField(sdoc, fullPath); + } + if (storeParent) { + setParentKey(sdoc, parent); + } + processDocChildren(sdoc, fullPath); + } - private void setPathField(SolrInputDocument sdoc, String fullPath) { - sdoc.setField(IndexSchema.NEST_PATH_FIELD_NAME, fullPath); - } + private String generateChildUniqueId(String parentId, String childKey, String childNum) { + // combines parentId with the child's key and childNum. e.g. "10/footnote#1" + return parentId + PATH_SEP_CHAR + childKey + NUM_SEP_CHAR + childNum; + } + + private void setParentKey(SolrInputDocument sdoc, SolrInputDocument parent) { + sdoc.setField(IndexSchema.NEST_PARENT_FIELD_NAME, parent.getFieldValue(uniqueKeyFieldName)); + } + + private void setPathField(SolrInputDocument sdoc, String fullPath) { + sdoc.setField(IndexSchema.NEST_PATH_FIELD_NAME, fullPath); } + } } diff --git a/solr/core/src/test/org/apache/solr/search/vector/KnnQParserChildTest.java b/solr/core/src/test/org/apache/solr/search/vector/KnnQParserChildTest.java index ecf9e0ee31fa..2f732fe048be 100644 --- a/solr/core/src/test/org/apache/solr/search/vector/KnnQParserChildTest.java +++ b/solr/core/src/test/org/apache/solr/search/vector/KnnQParserChildTest.java @@ -14,7 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.solr.search.neural; +package org.apache.solr.search.vector; import java.util.List; import java.util.Random; diff --git a/solr/core/src/test/org/apache/solr/search/vector/KnnQParserMultiValuedVectorsTest.java b/solr/core/src/test/org/apache/solr/search/vector/KnnQParserMultiValuedVectorsTest.java index 9a4afc80c66d..fb98ead3c0d7 100644 --- a/solr/core/src/test/org/apache/solr/search/vector/KnnQParserMultiValuedVectorsTest.java +++ b/solr/core/src/test/org/apache/solr/search/vector/KnnQParserMultiValuedVectorsTest.java @@ -14,35 +14,24 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.solr.search.neural; +package org.apache.solr.search.vector; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; import org.apache.solr.SolrTestCaseJ4; -import org.apache.solr.client.solrj.SolrQuery; -import org.apache.solr.common.SolrException; import org.apache.solr.common.SolrInputDocument; import org.apache.solr.common.params.CommonParams; -import org.apache.solr.common.params.SolrParams; -import org.apache.solr.request.SolrQueryRequest; import org.apache.solr.util.RandomNoReverseMergePolicyFactory; -import org.apache.solr.util.RestTestBase; -import org.junit.After; -import org.junit.Before; import org.junit.BeforeClass; import org.junit.ClassRule; import org.junit.Test; import org.junit.rules.TestRule; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collections; -import java.util.List; - -import static org.apache.solr.search.neural.KnnQParser.DEFAULT_TOP_K; - public class KnnQParserMultiValuedVectorsTest extends SolrTestCaseJ4 { private static final List FLOAT_QUERY_VECTOR = Arrays.asList(1.0f, 1.0f, 1.0f, 1.0f); private static final List BYTE_QUERY_VECTOR = Arrays.asList(1, 1, 1, 1); - + @ClassRule public static final TestRule noReverseMerge = RandomNoReverseMergePolicyFactory.createRule(); @@ -56,7 +45,7 @@ public static void beforeClass() throws Exception { public static void prepareIndex() throws Exception { List docsToIndex = prepareDocs(); for (SolrInputDocument doc : docsToIndex) { - updateJ(jsonAdd(doc), null); + updateJ(jsonAdd(doc), null); } assertU(commit()); } @@ -75,7 +64,7 @@ private static List prepareDocs() { int totalParentDocuments = 10; int totalNestedVectors = 30; int perParentChildren = totalNestedVectors / totalParentDocuments; - + final String[] abcdef = new String[] {"a", "b", "c", "d", "e", "f"}; List docs = new ArrayList<>(totalParentDocuments); @@ -140,7 +129,6 @@ private static List outDistanceByte(List vector, int value) { } return result; } - @Test public void topK_shouldReturnOnlyTopKResults() { @@ -157,100 +145,123 @@ public void topK_shouldReturnOnlyTopKResults() { @Test public void topKWithFilter_shouldReturnOnlyTopKResults() { assertQ( - req(CommonParams.Q, "{!knn f=vector_multivalued topK=5}" + FLOAT_QUERY_VECTOR, "fl", "id","fq","_text_:(b OR c)"), - "//result[@numFound='4']", - "//result/doc[1]/str[@name='id'][.='8']", - "//result/doc[2]/str[@name='id'][.='7']", - "//result/doc[3]/str[@name='id'][.='2']", - "//result/doc[4]/str[@name='id'][.='1']"); + req( + CommonParams.Q, + "{!knn f=vector_multivalued topK=5}" + FLOAT_QUERY_VECTOR, + "fl", + "id", + "fq", + "_text_:(b OR c)"), + "//result[@numFound='4']", + "//result/doc[1]/str[@name='id'][.='8']", + "//result/doc[2]/str[@name='id'][.='7']", + "//result/doc[3]/str[@name='id'][.='2']", + "//result/doc[4]/str[@name='id'][.='1']"); } - @Test public void topKWithoutTransformer_shouldDefaultToBestChildren() { assertQ( - req(CommonParams.Q, "{!knn f=vector_multivalued topK=5}" + FLOAT_QUERY_VECTOR, "fl", "id,vector_multivalued","fq","_text_:(b OR c)"), - "//result[@numFound='4']", - "//result/doc[1]/str[@name='id'][.='8']", - "//result/doc[1]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[1][.='8.0']", - "//result/doc[1]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[2][.='1.0']", - "//result/doc[1]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[3][.='1.0']", - "//result/doc[1]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[4][.='1.0']", - "//result/doc[2]/str[@name='id'][.='7']", - "//result/doc[2]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[1][.='11.0']", - "//result/doc[2]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[2][.='1.0']", - "//result/doc[2]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[3][.='1.0']", - "//result/doc[2]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[4][.='1.0']", - "//result/doc[3]/str[@name='id'][.='2']", - "//result/doc[3]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[1][.='26.0']", - "//result/doc[3]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[2][.='1.0']", - "//result/doc[3]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[3][.='1.0']", - "//result/doc[3]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[4][.='1.0']", - "//result/doc[4]/str[@name='id'][.='1']", - "//result/doc[4]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[1][.='29.0']", - "//result/doc[4]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[2][.='1.0']", - "//result/doc[4]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[3][.='1.0']", - "//result/doc[4]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[4][.='1.0']"); + req( + CommonParams.Q, + "{!knn f=vector_multivalued topK=5}" + FLOAT_QUERY_VECTOR, + "fl", + "id,vector_multivalued", + "fq", + "_text_:(b OR c)"), + "//result[@numFound='4']", + "//result/doc[1]/str[@name='id'][.='8']", + "//result/doc[1]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[1][.='8.0']", + "//result/doc[1]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[2][.='1.0']", + "//result/doc[1]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[3][.='1.0']", + "//result/doc[1]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[4][.='1.0']", + "//result/doc[2]/str[@name='id'][.='7']", + "//result/doc[2]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[1][.='11.0']", + "//result/doc[2]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[2][.='1.0']", + "//result/doc[2]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[3][.='1.0']", + "//result/doc[2]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[4][.='1.0']", + "//result/doc[3]/str[@name='id'][.='2']", + "//result/doc[3]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[1][.='26.0']", + "//result/doc[3]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[2][.='1.0']", + "//result/doc[3]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[3][.='1.0']", + "//result/doc[3]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[4][.='1.0']", + "//result/doc[4]/str[@name='id'][.='1']", + "//result/doc[4]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[1][.='29.0']", + "//result/doc[4]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[2][.='1.0']", + "//result/doc[4]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[3][.='1.0']", + "//result/doc[4]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[4][.='1.0']"); } @Test public void topKWithTransformer_shouldAddDefaultToBestChildren() { assertQ( - req(CommonParams.Q, "{!knn f=vector_multivalued topK=5}" + FLOAT_QUERY_VECTOR, "fl", "id,vector_multivalued,score","fq","_text_:(b OR c)"), - "//result[@numFound='4']", - "//result/doc[1]/str[@name='id'][.='8']", - "//result/doc[1]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[1][.='8.0']", - "//result/doc[1]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[2][.='1.0']", - "//result/doc[1]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[3][.='1.0']", - "//result/doc[1]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[4][.='1.0']", - "//result/doc[2]/str[@name='id'][.='7']", - "//result/doc[2]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[1][.='11.0']", - "//result/doc[2]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[2][.='1.0']", - "//result/doc[2]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[3][.='1.0']", - "//result/doc[2]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[4][.='1.0']", - "//result/doc[3]/str[@name='id'][.='2']", - "//result/doc[3]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[1][.='26.0']", - "//result/doc[3]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[2][.='1.0']", - "//result/doc[3]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[3][.='1.0']", - "//result/doc[3]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[4][.='1.0']", - "//result/doc[4]/str[@name='id'][.='1']", - "//result/doc[4]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[1][.='29.0']", - "//result/doc[4]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[2][.='1.0']", - "//result/doc[4]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[3][.='1.0']", - "//result/doc[4]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[4][.='1.0']"); + req( + CommonParams.Q, + "{!knn f=vector_multivalued topK=5}" + FLOAT_QUERY_VECTOR, + "fl", + "id,vector_multivalued,score", + "fq", + "_text_:(b OR c)"), + "//result[@numFound='4']", + "//result/doc[1]/str[@name='id'][.='8']", + "//result/doc[1]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[1][.='8.0']", + "//result/doc[1]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[2][.='1.0']", + "//result/doc[1]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[3][.='1.0']", + "//result/doc[1]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[4][.='1.0']", + "//result/doc[2]/str[@name='id'][.='7']", + "//result/doc[2]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[1][.='11.0']", + "//result/doc[2]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[2][.='1.0']", + "//result/doc[2]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[3][.='1.0']", + "//result/doc[2]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[4][.='1.0']", + "//result/doc[3]/str[@name='id'][.='2']", + "//result/doc[3]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[1][.='26.0']", + "//result/doc[3]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[2][.='1.0']", + "//result/doc[3]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[3][.='1.0']", + "//result/doc[3]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[4][.='1.0']", + "//result/doc[4]/str[@name='id'][.='1']", + "//result/doc[4]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[1][.='29.0']", + "//result/doc[4]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[2][.='1.0']", + "//result/doc[4]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[3][.='1.0']", + "//result/doc[4]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[4][.='1.0']"); } @Test public void topKWithChildTransformer_shouldUseOriginalChildTransformer() { assertQ( - req(CommonParams.Q, "{!knn f=vector_multivalued topK=3}" + FLOAT_QUERY_VECTOR, "fl", "id,vector_multivalued,score,[child limit=2 fl=vector_multivalued]","fq","_text_:(b OR c)"), - "//result[@numFound='3']", - "//result/doc[1]/str[@name='id'][.='8']", - "//result/doc[1]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[1][.='10.0']", - "//result/doc[1]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[2][.='1.0']", - "//result/doc[1]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[3][.='1.0']", - "//result/doc[1]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[4][.='1.0']", - "//result/doc[1]/arr[@name='vector_multivalued'][1]/doc[2]/arr[@name='vector_multivalued']/float[1][.='9.0']", - "//result/doc[1]/arr[@name='vector_multivalued'][1]/doc[2]/arr[@name='vector_multivalued']/float[2][.='1.0']", - "//result/doc[1]/arr[@name='vector_multivalued'][1]/doc[2]/arr[@name='vector_multivalued']/float[3][.='1.0']", - "//result/doc[1]/arr[@name='vector_multivalued'][1]/doc[2]/arr[@name='vector_multivalued']/float[4][.='1.0']", - "//result/doc[2]/str[@name='id'][.='7']", - "//result/doc[2]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[1][.='13.0']", - "//result/doc[2]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[2][.='1.0']", - "//result/doc[2]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[3][.='1.0']", - "//result/doc[2]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[4][.='1.0']", - "//result/doc[2]/arr[@name='vector_multivalued'][1]/doc[2]/arr[@name='vector_multivalued']/float[1][.='12.0']", - "//result/doc[2]/arr[@name='vector_multivalued'][1]/doc[2]/arr[@name='vector_multivalued']/float[2][.='1.0']", - "//result/doc[2]/arr[@name='vector_multivalued'][1]/doc[2]/arr[@name='vector_multivalued']/float[3][.='1.0']", - "//result/doc[2]/arr[@name='vector_multivalued'][1]/doc[2]/arr[@name='vector_multivalued']/float[4][.='1.0']", - "//result/doc[3]/str[@name='id'][.='2']", - "//result/doc[3]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[1][.='28.0']", - "//result/doc[3]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[2][.='1.0']", - "//result/doc[3]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[3][.='1.0']", - "//result/doc[3]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[4][.='1.0']", - "//result/doc[3]/arr[@name='vector_multivalued'][1]/doc[2]/arr[@name='vector_multivalued']/float[1][.='27.0']", - "//result/doc[3]/arr[@name='vector_multivalued'][1]/doc[2]/arr[@name='vector_multivalued']/float[2][.='1.0']", - "//result/doc[3]/arr[@name='vector_multivalued'][1]/doc[2]/arr[@name='vector_multivalued']/float[3][.='1.0']", - "//result/doc[3]/arr[@name='vector_multivalued'][1]/doc[2]/arr[@name='vector_multivalued']/float[4][.='1.0']"); + req( + CommonParams.Q, + "{!knn f=vector_multivalued topK=3}" + FLOAT_QUERY_VECTOR, + "fl", + "id,vector_multivalued,score,[child limit=2 fl=vector_multivalued]", + "fq", + "_text_:(b OR c)"), + "//result[@numFound='3']", + "//result/doc[1]/str[@name='id'][.='8']", + "//result/doc[1]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[1][.='10.0']", + "//result/doc[1]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[2][.='1.0']", + "//result/doc[1]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[3][.='1.0']", + "//result/doc[1]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[4][.='1.0']", + "//result/doc[1]/arr[@name='vector_multivalued'][1]/doc[2]/arr[@name='vector_multivalued']/float[1][.='9.0']", + "//result/doc[1]/arr[@name='vector_multivalued'][1]/doc[2]/arr[@name='vector_multivalued']/float[2][.='1.0']", + "//result/doc[1]/arr[@name='vector_multivalued'][1]/doc[2]/arr[@name='vector_multivalued']/float[3][.='1.0']", + "//result/doc[1]/arr[@name='vector_multivalued'][1]/doc[2]/arr[@name='vector_multivalued']/float[4][.='1.0']", + "//result/doc[2]/str[@name='id'][.='7']", + "//result/doc[2]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[1][.='13.0']", + "//result/doc[2]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[2][.='1.0']", + "//result/doc[2]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[3][.='1.0']", + "//result/doc[2]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[4][.='1.0']", + "//result/doc[2]/arr[@name='vector_multivalued'][1]/doc[2]/arr[@name='vector_multivalued']/float[1][.='12.0']", + "//result/doc[2]/arr[@name='vector_multivalued'][1]/doc[2]/arr[@name='vector_multivalued']/float[2][.='1.0']", + "//result/doc[2]/arr[@name='vector_multivalued'][1]/doc[2]/arr[@name='vector_multivalued']/float[3][.='1.0']", + "//result/doc[2]/arr[@name='vector_multivalued'][1]/doc[2]/arr[@name='vector_multivalued']/float[4][.='1.0']", + "//result/doc[3]/str[@name='id'][.='2']", + "//result/doc[3]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[1][.='28.0']", + "//result/doc[3]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[2][.='1.0']", + "//result/doc[3]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[3][.='1.0']", + "//result/doc[3]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[4][.='1.0']", + "//result/doc[3]/arr[@name='vector_multivalued'][1]/doc[2]/arr[@name='vector_multivalued']/float[1][.='27.0']", + "//result/doc[3]/arr[@name='vector_multivalued'][1]/doc[2]/arr[@name='vector_multivalued']/float[2][.='1.0']", + "//result/doc[3]/arr[@name='vector_multivalued'][1]/doc[2]/arr[@name='vector_multivalued']/float[3][.='1.0']", + "//result/doc[3]/arr[@name='vector_multivalued'][1]/doc[2]/arr[@name='vector_multivalued']/float[4][.='1.0']"); } } From b9865c151d76092b135f10af20db1405d20072cf Mon Sep 17 00:00:00 2001 From: Alessandro Benedetti Date: Wed, 17 Dec 2025 10:54:55 +0100 Subject: [PATCH 31/43] work in progress --- .../transform/ChildDocTransformer.java | 2 +- .../NestedUpdateProcessorFactory.java | 41 +- .../collection1/conf/schema-densevector.xml | 11 +- .../join/BlockJoinMultiValuedVectorsTest.java | 165 +++++++ ...ockJoinNestedVectorsParentQParserTest.java | 448 ++++++++++++++++++ .../BlockJoinNestedVectorsQParserTest.java | 437 ----------------- .../join/BlockJoinNestedVectorsTest.java | 169 +++++++ .../KnnQParserMultiValuedVectorsTest.java | 267 ----------- 8 files changed, 817 insertions(+), 723 deletions(-) create mode 100644 solr/core/src/test/org/apache/solr/search/join/BlockJoinMultiValuedVectorsTest.java create mode 100644 solr/core/src/test/org/apache/solr/search/join/BlockJoinNestedVectorsParentQParserTest.java delete mode 100644 solr/core/src/test/org/apache/solr/search/join/BlockJoinNestedVectorsQParserTest.java create mode 100644 solr/core/src/test/org/apache/solr/search/join/BlockJoinNestedVectorsTest.java delete mode 100644 solr/core/src/test/org/apache/solr/search/vector/KnnQParserMultiValuedVectorsTest.java diff --git a/solr/core/src/java/org/apache/solr/response/transform/ChildDocTransformer.java b/solr/core/src/java/org/apache/solr/response/transform/ChildDocTransformer.java index ad2f480e91fe..204aa7a6190c 100644 --- a/solr/core/src/java/org/apache/solr/response/transform/ChildDocTransformer.java +++ b/solr/core/src/java/org/apache/solr/response/transform/ChildDocTransformer.java @@ -52,7 +52,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -public class ChildDocTransformer extends DocTransformer { +class ChildDocTransformer extends DocTransformer { private static final Logger log = LoggerFactory.getLogger(MethodHandles.lookup().lookupClass()); private static final String ANON_CHILD_KEY = "_childDocuments_"; diff --git a/solr/core/src/java/org/apache/solr/update/processor/NestedUpdateProcessorFactory.java b/solr/core/src/java/org/apache/solr/update/processor/NestedUpdateProcessorFactory.java index ba3b3c8419fb..4e5637bda4cd 100644 --- a/solr/core/src/java/org/apache/solr/update/processor/NestedUpdateProcessorFactory.java +++ b/solr/core/src/java/org/apache/solr/update/processor/NestedUpdateProcessorFactory.java @@ -20,6 +20,7 @@ import java.io.IOException; import java.util.ArrayList; import java.util.Collection; +import java.util.List; import org.apache.solr.common.SolrException; import org.apache.solr.common.SolrInputDocument; import org.apache.solr.common.SolrInputField; @@ -86,12 +87,14 @@ public void processAdd(AddUpdateCommand cmd) throws IOException { private boolean processDocChildren(SolrInputDocument doc, String fullPath) { boolean isNested = false; + List originalVectorFieldsToRemove = new ArrayList<>(); + ArrayList vectors = new ArrayList<>(); for (SolrInputField field : doc.values()) { - SchemaField sfield = schema.getField(field.getName()); + SchemaField sfield = schema.getFieldOrNull(field.getName()); int childNum = 0; boolean isSingleVal = !(field.getValue() instanceof Collection); - if (fullPath == null && isMultiValuedVectorField(sfield)) { - ArrayList vectors = new ArrayList<>(field.getValueCount()); + boolean firstLevelChildren = fullPath == null; + if (firstLevelChildren && sfield!= null && isMultiValuedVectorField(sfield)) { for (Object vectorValue : field.getValues()) { SolrInputDocument singleVectorNestedDoc = new SolrInputDocument(); singleVectorNestedDoc.setField(field.getName(), vectorValue); @@ -104,7 +107,7 @@ private boolean processDocChildren(SolrInputDocument doc, String fullPath) { isNested = true; } final String lastKeyPath = PATH_SEP_CHAR + field.getName() + NUM_SEP_CHAR + sChildNum; - final String childDocPath = fullPath == null ? lastKeyPath : fullPath + lastKeyPath; + final String childDocPath = firstLevelChildren ? lastKeyPath : fullPath + lastKeyPath; if (storePath) { setPathField(singleVectorNestedDoc, childDocPath); } @@ -114,7 +117,7 @@ private boolean processDocChildren(SolrInputDocument doc, String fullPath) { ++childNum; vectors.add(singleVectorNestedDoc); } - doc.setField(field.getName(), vectors); + originalVectorFieldsToRemove.add(field.getName()); } else { for (Object val : field) { if (!(val instanceof SolrInputDocument cDoc)) { @@ -143,28 +146,38 @@ private boolean processDocChildren(SolrInputDocument doc, String fullPath) { } final String lastKeyPath = PATH_SEP_CHAR + fieldName + NUM_SEP_CHAR + sChildNum; // concat of all paths children.grandChild => /children#1/grandChild# - final String childDocPath = fullPath == null ? lastKeyPath : fullPath + lastKeyPath; + final String childDocPath = firstLevelChildren ? lastKeyPath : fullPath + lastKeyPath; processChildDoc(cDoc, doc, childDocPath); ++childNum; } } } + this.cleanOriginalVectorFields(doc, originalVectorFieldsToRemove); + if(vectors.size() > 0) { + doc.setField("vectors", vectors); + } return isNested; } + private void cleanOriginalVectorFields(SolrInputDocument doc, List originalVectorFieldsToRemove) { + for(String fieldName : originalVectorFieldsToRemove) { + doc.removeField(fieldName); + } + } + private static boolean isMultiValuedVectorField(SchemaField sfield) { return sfield.getType() instanceof DenseVectorField && sfield.multiValued(); } private void processChildDoc( - SolrInputDocument sdoc, SolrInputDocument parent, String fullPath) { + SolrInputDocument child, SolrInputDocument parent, String fullPath) { if (storePath) { - setPathField(sdoc, fullPath); + setPathField(child, fullPath); } if (storeParent) { - setParentKey(sdoc, parent); + setParentKey(child, parent); } - processDocChildren(sdoc, fullPath); + processDocChildren(child, fullPath); } private String generateChildUniqueId(String parentId, String childKey, String childNum) { @@ -172,12 +185,12 @@ private String generateChildUniqueId(String parentId, String childKey, String ch return parentId + PATH_SEP_CHAR + childKey + NUM_SEP_CHAR + childNum; } - private void setParentKey(SolrInputDocument sdoc, SolrInputDocument parent) { - sdoc.setField(IndexSchema.NEST_PARENT_FIELD_NAME, parent.getFieldValue(uniqueKeyFieldName)); + private void setParentKey(SolrInputDocument child, SolrInputDocument parent) { + child.setField(IndexSchema.NEST_PARENT_FIELD_NAME, parent.getFieldValue(uniqueKeyFieldName)); } - private void setPathField(SolrInputDocument sdoc, String fullPath) { - sdoc.setField(IndexSchema.NEST_PATH_FIELD_NAME, fullPath); + private void setPathField(SolrInputDocument child, String fullPath) { + child.setField(IndexSchema.NEST_PATH_FIELD_NAME, fullPath); } } } diff --git a/solr/core/src/test-files/solr/collection1/conf/schema-densevector.xml b/solr/core/src/test-files/solr/collection1/conf/schema-densevector.xml index 75fecdb18258..8557c560578a 100644 --- a/solr/core/src/test-files/solr/collection1/conf/schema-densevector.xml +++ b/solr/core/src/test-files/solr/collection1/conf/schema-densevector.xml @@ -18,22 +18,23 @@ - + + - + - + - + @@ -44,6 +45,7 @@ + @@ -62,5 +64,6 @@ + id diff --git a/solr/core/src/test/org/apache/solr/search/join/BlockJoinMultiValuedVectorsTest.java b/solr/core/src/test/org/apache/solr/search/join/BlockJoinMultiValuedVectorsTest.java new file mode 100644 index 000000000000..65c1d09ca8b0 --- /dev/null +++ b/solr/core/src/test/org/apache/solr/search/join/BlockJoinMultiValuedVectorsTest.java @@ -0,0 +1,165 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.solr.search.join; + +import java.util.ArrayList; +import java.util.List; +import org.apache.solr.common.SolrInputDocument; +import org.apache.solr.util.RandomNoReverseMergePolicyFactory; +import org.junit.BeforeClass; +import org.junit.ClassRule; +import org.junit.Test; +import org.junit.rules.TestRule; + +public class BlockJoinMultiValuedVectorsTest extends BlockJoinNestedVectorsParentQParserTest { + + protected static String VECTOR_FIELD = "vector_multivalued"; + protected static String VECTOR_BYTE_FIELD = "vector_byte_multivalued"; + + @ClassRule + public static final TestRule noReverseMerge = RandomNoReverseMergePolicyFactory.createRule(); + + @BeforeClass + public static void beforeClass() throws Exception { + /* vectorDimension="4" similarityFunction="cosine" */ + initCore("solrconfig_codec.xml", "schema-densevector.xml"); + prepareIndex(); + } + + protected static void prepareIndex() throws Exception { + List docsToIndex = prepareDocs(); + for (SolrInputDocument doc : docsToIndex) { + updateJ(jsonAdd(doc), null); + } + assertU(commit()); + } + + /** + * The documents in the index are 10 parents, with some parent level metadata and 30 nested + * documents (with vectors and children level metadata) Each parent document has 3 nested + * documents with vectors. + * + *

This allows to run knn queries both at parent/children level and using various pre-filters + * both for parent metadata and children. + * + * @return a list of documents to index + */ + protected static List prepareDocs() { + int totalParentDocuments = 10; + int totalNestedVectors = 30; + int perParentChildren = totalNestedVectors / totalParentDocuments; + + final String[] abcdef = new String[] {"a", "b", "c", "d", "e", "f"}; + + List docs = new ArrayList<>(totalParentDocuments); + for (int i = 1; i < totalParentDocuments + 1; i++) { + SolrInputDocument doc = new SolrInputDocument(); + doc.setField("id", i); + doc.setField("parent_b", true); + doc.setField("parent_s", abcdef[i % abcdef.length]); + List> floatVectors = new ArrayList<>(perParentChildren); + List> byteVectors = new ArrayList<>(perParentChildren); + // nested vector documents have a distance from the query vector inversely proportional to + // their id + for (int j = 0; j < perParentChildren; j++) { + floatVectors.add(outDistanceFloat(FLOAT_QUERY_VECTOR, totalNestedVectors)); + byteVectors.add(outDistanceByte(BYTE_QUERY_VECTOR, totalNestedVectors)); + totalNestedVectors--; // the higher the id of the nested document, lower the distance with + } + doc.setField(VECTOR_FIELD, floatVectors); + doc.setField(VECTOR_BYTE_FIELD, byteVectors); + + docs.add(doc); + } + + return docs; + } + + @Test + public void parentRetrieval_knnChildrenDiversifyingWithNoAllParents_shouldThrowException() { + super.parentRetrieval_knnChildrenDiversifyingWithNoAllParents_shouldThrowException(VECTOR_FIELD); + } + + @Test + public void childrenRetrievalFloat_filteringByParentMetadata_shouldReturnKnnChildren() { + super.childrenRetrievalFloat_filteringByParentMetadata_shouldReturnKnnChildren(VECTOR_FIELD); + } + + @Test + public void childrenRetrievalByte_filteringByParentMetadata_shouldReturnKnnChildren() { + super.childrenRetrievalByte_filteringByParentMetadata_shouldReturnKnnChildren(VECTOR_BYTE_FIELD); + } + + @Test + public void parentRetrievalFloat_knnChildren_shouldReturnKnnParents() { + super.parentRetrievalFloat_knnChildren_shouldReturnKnnParents(VECTOR_FIELD); + } + + @Test + public void parentRetrievalFloat_knnChildrenWithNoDiversifying_shouldReturnOneParent() { + super.parentRetrievalFloat_knnChildrenWithNoDiversifying_shouldReturnOneParent(VECTOR_FIELD); + } + + @Test + public void parentRetrievalFloat_knnChildrenWithParentFilter_shouldReturnKnnParents() { + super.parentRetrievalFloat_knnChildrenWithParentFilter_shouldReturnKnnParents(VECTOR_FIELD); + } + + @Test + public void + parentRetrievalFloat_knnChildrenWithParentFilterAndChildrenFilter_shouldReturnKnnParents( + ) { + super.parentRetrievalFloat_knnChildrenWithParentFilterAndChildrenFilter_shouldReturnKnnParents(VECTOR_FIELD); + } + + @Test + public void parentRetrievalByte_knnChildren_shouldReturnKnnParents() { + super.parentRetrievalByte_knnChildren_shouldReturnKnnParents(VECTOR_BYTE_FIELD); + } + + @Test + public void parentRetrievalByte_knnChildrenWithParentFilter_shouldReturnKnnParents() { + super.parentRetrievalByte_knnChildrenWithParentFilter_shouldReturnKnnParents(VECTOR_BYTE_FIELD); + } + + @Test + public void + parentRetrievalByte_knnChildrenWithParentFilterAndChildrenFilter_shouldReturnKnnParents() { + super.parentRetrievalByte_knnChildrenWithParentFilterAndChildrenFilter_shouldReturnKnnParents(VECTOR_BYTE_FIELD); + } + + @Test + public void + parentRetrievalFloat_topKWithChildTransformerWithFilter_shouldUseOriginalChildTransformerFilter() { + super.parentRetrievalFloat_topKWithChildTransformerWithFilter_shouldUseOriginalChildTransformerFilter(VECTOR_FIELD); + } + + @Test + public void parentRetrievalFloat_topKWithChildTransformerWithFilter_shouldReturnBestChild() { + super.parentRetrievalFloat_topKWithChildTransformerWithFilter_shouldReturnBestChild(VECTOR_FIELD); + } + + @Test + public void parentRetrievalByte_topKWithChildTransformer_shouldReturnAllChildren() { + super.parentRetrievalByte_topKWithChildTransformer_shouldReturnAllChildren(VECTOR_BYTE_FIELD); + } + + @Test + public void parentRetrievalByte_topKWithChildTransformerWithFilter_shouldReturnBestChild() { + super.parentRetrievalByte_topKWithChildTransformerWithFilter_shouldReturnBestChild(VECTOR_BYTE_FIELD); + } +} diff --git a/solr/core/src/test/org/apache/solr/search/join/BlockJoinNestedVectorsParentQParserTest.java b/solr/core/src/test/org/apache/solr/search/join/BlockJoinNestedVectorsParentQParserTest.java new file mode 100644 index 000000000000..1b2bf02ddc52 --- /dev/null +++ b/solr/core/src/test/org/apache/solr/search/join/BlockJoinNestedVectorsParentQParserTest.java @@ -0,0 +1,448 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.solr.search.join; + +import org.apache.solr.SolrTestCaseJ4; +import org.junit.Test; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +public class BlockJoinNestedVectorsParentQParserTest extends SolrTestCaseJ4 { + protected static final List FLOAT_QUERY_VECTOR = Arrays.asList(1.0f, 1.0f, 1.0f, 1.0f); + protected static final List BYTE_QUERY_VECTOR = Arrays.asList(1, 1, 1, 1); + + + protected static String VECTORS_PSEUDOFIELD = "vectors"; + + /** + * Generate a resulting float vector with a distance from the original vector that is proportional + * to the value in input (higher the value, higher the distance from the original vector) + * + * @param vector a numerical vector + * @param value a numerical value to be added to the first element of the vector + * @return a numerical vector that has a distance from the input vector, proportional to the value + */ + protected static List outDistanceFloat(List vector, int value) { + List result = new ArrayList<>(vector.size()); + for (int i = 0; i < vector.size(); i++) { + if (i == 0) { + result.add(vector.get(i) + value); + } else { + result.add(vector.get(i)); + } + } + return result; + } + + /** + * Generate a resulting byte vector with a distance from the original vector that is proportional + * to the value in input (higher the value, higher the distance from the original vector) + * + * @param vector a numerical vector + * @param value a numerical value to be added to the first element of the vector + * @return a numerical vector that has a distance from the input vector, proportional to the value + */ + protected static List outDistanceByte(List vector, int value) { + List result = new ArrayList<>(vector.size()); + for (int i = 0; i < vector.size(); i++) { + if (i == 0) { + result.add(vector.get(i) + value); + } else { + result.add(vector.get(i)); + } + } + return result; + } + + protected void parentRetrieval_knnChildrenDiversifyingWithNoAllParents_shouldThrowException(String vectorField) { + assertQEx( + "When running a diversifying children KNN query, 'allParents' parameter is required", + req( + "q", + "{!parent which=$allParents score=max v=$children.q}", + "fl", + "id,score", + "children.q", + "{!knn f=" + vectorField + " topK=3 childrenOf=$someParents}" + FLOAT_QUERY_VECTOR, + "allParents", + "parent_s:[* TO *]", + "someParents", + "parent_s:(a c)"), + 400); + } + + protected void childrenRetrievalFloat_filteringByParentMetadata_shouldReturnKnnChildren(String vectorField) { + assertQ( + req( + "fq", "{!child of=$allParents filters=$parent.fq}", + "q", "{!knn f=" + vectorField + " topK=5}" + FLOAT_QUERY_VECTOR, + "fl", "id", + "parent.fq", "parent_s:(a c)", + "allParents", "parent_s:[* TO *]"), + "//*[@numFound='5']", + "//result/doc[1]/str[@name='id'][.='82']", + "//result/doc[2]/str[@name='id'][.='81']", + "//result/doc[3]/str[@name='id'][.='80']", + "//result/doc[4]/str[@name='id'][.='62']", + "//result/doc[5]/str[@name='id'][.='61']"); + } + + protected void childrenRetrievalByte_filteringByParentMetadata_shouldReturnKnnChildren(String vectorByteField) { + assertQ( + req( + "fq", "{!child of=$allParents filters=$parent.fq}", + "q", "{!knn f=" + vectorByteField + " topK=5}" + BYTE_QUERY_VECTOR, + "fl", "id", + "parent.fq", "parent_s:(a c)", + "allParents", "parent_s:[* TO *]"), + "//*[@numFound='5']", + "//result/doc[1]/str[@name='id'][.='82']", + "//result/doc[2]/str[@name='id'][.='81']", + "//result/doc[3]/str[@name='id'][.='80']", + "//result/doc[4]/str[@name='id'][.='62']", + "//result/doc[5]/str[@name='id'][.='61']"); + } + + protected void parentRetrievalFloat_knnChildren_shouldReturnKnnParents(String vectorField) { + assertQ( + req( + "q", "{!parent which=$allParents score=max v=$children.q}", + "fl", "id,score", + "children.q", "{!knn f=" + vectorField + " topK=3 allParents=$allParents}" + FLOAT_QUERY_VECTOR, + "allParents", "parent_s:[* TO *]"), + "//*[@numFound='3']", + "//result/doc[1]/str[@name='id'][.='10']", + "//result/doc[2]/str[@name='id'][.='9']", + "//result/doc[3]/str[@name='id'][.='8']"); + } + + protected void parentRetrievalFloat_knnChildrenWithNoDiversifying_shouldReturnOneParent(String vectorField) { + assertQ( + req( + "q", "{!parent which=$allParents score=max v=$children.q}", + "fl", "id,score", + "children.q", "{!knn f=" + vectorField + " topK=3}" + FLOAT_QUERY_VECTOR, + "allParents", "parent_s:[* TO *]"), + "//*[@numFound='1']", + "//result/doc[1]/str[@name='id'][.='10']"); + } + + protected void parentRetrievalFloat_knnChildrenWithParentFilter_shouldReturnKnnParents(String vectorField) { + assertQ( + req( + "q", "{!parent which=$allParents score=max v=$children.q}", + "fl", "id,score", + "children.q", + "{!knn f=" + vectorField + + " topK=3 childrenOf=$someParents allParents=$allParents}" + + FLOAT_QUERY_VECTOR, + "allParents", "parent_s:[* TO *]", + "someParents", "parent_s:(a c)"), + "//*[@numFound='3']", + "//result/doc[1]/str[@name='id'][.='8']", + "//result/doc[2]/str[@name='id'][.='6']", + "//result/doc[3]/str[@name='id'][.='2']"); + } + + protected void + parentRetrievalFloat_knnChildrenWithParentFilterAndChildrenFilter_shouldReturnKnnParents( + String vectorField) { + assertQ( + req( + "q", "{!parent which=$allParents score=max v=$children.q}", + "fl", "id,score", + "children.q", + "{!knn f=" + vectorField + + " topK=3 preFilter=child_s:m childrenOf=$someParents allParents=$allParents}" + + FLOAT_QUERY_VECTOR, + "allParents", "parent_s:[* TO *]", + "someParents", "parent_s:(a c)"), + "//*[@numFound='2']", + "//result/doc[1]/str[@name='id'][.='8']", + "//result/doc[2]/str[@name='id'][.='2']"); + } + + protected void parentRetrievalByte_knnChildren_shouldReturnKnnParents(String vectorByteField) { + assertQ( + req( + "q", "{!parent which=$allParents score=max v=$children.q}", + "fl", "id,score", + "children.q", "{!knn f=" + vectorByteField + " topK=3 allParents=$allParents}" + BYTE_QUERY_VECTOR, + "allParents", "parent_s:[* TO *]"), + "//*[@numFound='3']", + "//result/doc[1]/str[@name='id'][.='10']", + "//result/doc[2]/str[@name='id'][.='9']", + "//result/doc[3]/str[@name='id'][.='8']"); + } + + protected void parentRetrievalByte_knnChildrenWithParentFilter_shouldReturnKnnParents(String vectorByteField) { + assertQ( + req( + "q", "{!parent which=$allParents score=max v=$children.q}", + "fl", "id,score", + "children.q", + "{!knn f=" + vectorByteField + + " topK=3 childrenOf=$someParents allParents=$allParents}" + + BYTE_QUERY_VECTOR, + "allParents", "parent_s:[* TO *]", + "someParents", "parent_s:(a c)"), + "//*[@numFound='3']", + "//result/doc[1]/str[@name='id'][.='8']", + "//result/doc[2]/str[@name='id'][.='6']", + "//result/doc[3]/str[@name='id'][.='2']"); + } + + protected void + parentRetrievalByte_knnChildrenWithParentFilterAndChildrenFilter_shouldReturnKnnParents(String vectorByteField) { + assertQ( + req( + "q", "{!parent which=$allParents score=max v=$children.q}", + "fl", "id,score", + "children.q", + "{!knn f=" + vectorByteField + + " topK=3 preFilter=child_s:m childrenOf=$someParents allParents=$allParents}" + + BYTE_QUERY_VECTOR, + "allParents", "parent_s:[* TO *]", + "someParents", "parent_s:(a c)"), + "//*[@numFound='2']", + "//result/doc[1]/str[@name='id'][.='8']", + "//result/doc[2]/str[@name='id'][.='2']"); + } + + protected void + parentRetrievalFloat_topKWithChildTransformerWithFilter_shouldUseOriginalChildTransformerFilter( + String vectorField) { + assertQ( + req( + "q", "{!parent which=$allParents score=max v=$children.q}", + "fl", "id,score," + VECTORS_PSEUDOFIELD + "," + vectorField + + ",[child limit=2 fl=vector]", + "children.q", + "{!knn f=" + vectorField + + " topK=3 childrenOf=$someParents allParents=$allParents}" + + FLOAT_QUERY_VECTOR, + "allParents", "parent_s:[* TO *]", + "someParents", "parent_s:(a c)"), + "//result[@numFound='3']", + "//result/doc[1]/str[@name='id'][.='8']", + "//result/doc[1]/arr[@name='" + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + vectorField + "']/float[1][.='10.0']", + "//result/doc[1]/arr[@name='" + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + vectorField + "']/float[2][.='1.0']", + "//result/doc[1]/arr[@name='" + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + vectorField + "']/float[3][.='1.0']", + "//result/doc[1]/arr[@name='" + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + vectorField + "']/float[4][.='1.0']", + "//result/doc[1]/arr[@name='" + VECTORS_PSEUDOFIELD + + "'][1]/doc[2]/arr[@name='" + vectorField + "']/float[1][.='9.0']", + "//result/doc[1]/arr[@name='" + VECTORS_PSEUDOFIELD + + "'][1]/doc[2]/arr[@name='" + vectorField + "']/float[2][.='1.0']", + "//result/doc[1]/arr[@name='" + VECTORS_PSEUDOFIELD + + "'][1]/doc[2]/arr[@name='" + vectorField + "']/float[3][.='1.0']", + "//result/doc[1]/arr[@name='" + VECTORS_PSEUDOFIELD + + "'][1]/doc[2]/arr[@name='" + vectorField + "']/float[4][.='1.0']", + "//result/doc[2]/str[@name='id'][.='6']", + "//result/doc[2]/arr[@name='" + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + vectorField + "']/float[1][.='16.0']", + "//result/doc[2]/arr[@name='" + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + vectorField + "']/float[2][.='1.0']", + "//result/doc[2]/arr[@name='" + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + vectorField + "']/float[3][.='1.0']", + "//result/doc[2]/arr[@name='" + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + vectorField + "']/float[4][.='1.0']", + "//result/doc[2]/arr[@name='" + VECTORS_PSEUDOFIELD + + "'][1]/doc[2]/arr[@name='" + vectorField + "']/float[1][.='15.0']", + "//result/doc[2]/arr[@name='" + VECTORS_PSEUDOFIELD + + "'][1]/doc[2]/arr[@name='" + vectorField + "']/float[2][.='1.0']", + "//result/doc[2]/arr[@name='" + VECTORS_PSEUDOFIELD + + "'][1]/doc[2]/arr[@name='" + vectorField + "']/float[3][.='1.0']", + "//result/doc[2]/arr[@name='" + VECTORS_PSEUDOFIELD + + "'][1]/doc[2]/arr[@name='" + vectorField + "']/float[4][.='1.0']", + "//result/doc[3]/str[@name='id'][.='2']", + "//result/doc[3]/arr[@name='" + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + vectorField + "']/float[1][.='28.0']", + "//result/doc[3]/arr[@name='" + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + vectorField + "']/float[2][.='1.0']", + "//result/doc[3]/arr[@name='" + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + vectorField + "']/float[3][.='1.0']", + "//result/doc[3]/arr[@name='" + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + vectorField + "']/float[4][.='1.0']", + "//result/doc[3]/arr[@name='" + VECTORS_PSEUDOFIELD + + "'][1]/doc[2]/arr[@name='" + vectorField + "']/float[1][.='27.0']", + "//result/doc[3]/arr[@name='" + VECTORS_PSEUDOFIELD + + "'][1]/doc[2]/arr[@name='" + vectorField + "']/float[2][.='1.0']", + "//result/doc[3]/arr[@name='" + VECTORS_PSEUDOFIELD + + "'][1]/doc[2]/arr[@name='" + vectorField + "']/float[3][.='1.0']", + "//result/doc[3]/arr[@name='" + VECTORS_PSEUDOFIELD + + "'][1]/doc[2]/arr[@name='" + vectorField + "']/float[4][.='1.0']"); + } + + protected void parentRetrievalFloat_topKWithChildTransformerWithFilter_shouldReturnBestChild(String vectorField) { + assertQ( + req( + "q", "{!parent which=$allParents score=max v=$children.q}", + "fl", + "id,score," + VECTORS_PSEUDOFIELD + "," + vectorField + + ",[child fl=vector childFilter=$children.q]", + "children.q", + "{!knn f=" + vectorField + + " topK=3 childrenOf=$someParents allParents=$allParents}" + + FLOAT_QUERY_VECTOR, + "allParents", "parent_s:[* TO *]", + "someParents", "parent_s:(b c)"), + "//result[@numFound='3']", + "//result/doc[1]/str[@name='id'][.='8']", + "//result/doc[1]/arr[@name='" + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + vectorField + "']/float[1][.='8.0']", + "//result/doc[1]/arr[@name='" + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + vectorField + "']/float[2][.='1.0']", + "//result/doc[1]/arr[@name='" + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + vectorField + "']/float[3][.='1.0']", + "//result/doc[1]/arr[@name='" + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + vectorField + "']/float[4][.='1.0']", + "//result/doc[2]/str[@name='id'][.='7']", + "//result/doc[2]/arr[@name='" + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + vectorField + "']/float[1][.='11.0']", + "//result/doc[2]/arr[@name='" + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + vectorField + "']/float[2][.='1.0']", + "//result/doc[2]/arr[@name='" + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + vectorField + "']/float[3][.='1.0']", + "//result/doc[2]/arr[@name='" + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + vectorField + "']/float[4][.='1.0']", + "//result/doc[3]/str[@name='id'][.='2']", + "//result/doc[3]/arr[@name='" + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + vectorField + "']/float[1][.='26.0']", + "//result/doc[3]/arr[@name='" + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + vectorField + "']/float[2][.='1.0']", + "//result/doc[3]/arr[@name='" + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + vectorField + "']/float[3][.='1.0']", + "//result/doc[3]/arr[@name='" + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + vectorField + "']/float[4][.='1.0']"); + } + + protected void parentRetrievalByte_topKWithChildTransformer_shouldReturnAllChildren(String vectorByteField) { + assertQ( + req( + "q", "{!parent which=$allParents score=max v=$children.q}", + "fl", + "id,score," + VECTORS_PSEUDOFIELD + "," + vectorByteField + ",[child limit=2 fl=" + + vectorByteField +"]", + "children.q", + "{!knn f=" + vectorByteField + + " topK=3 childrenOf=$someParents allParents=$allParents}" + + BYTE_QUERY_VECTOR, + "allParents", "parent_s:[* TO *]", + "someParents", "parent_s:(b c)"), + "//result[@numFound='3']", + "//result/doc[1]/str[@name='id'][.='8']", + "//result/doc[1]/arr[@name='" + VECTORS_PSEUDOFIELD + "'][1]/doc[1]/arr[@name='" + vectorByteField + + "']/int[1][.='10']", + "//result/doc[1]/arr[@name='" + VECTORS_PSEUDOFIELD + "'][1]/doc[1]/arr[@name='" + vectorByteField + + "']/int[2][.='1']", + "//result/doc[1]/arr[@name='" + VECTORS_PSEUDOFIELD + "'][1]/doc[1]/arr[@name='" + vectorByteField + + "']/int[3][.='1']", + "//result/doc[1]/arr[@name='" + VECTORS_PSEUDOFIELD + "'][1]/doc[1]/arr[@name='" + vectorByteField + + "']/int[4][.='1']", + "//result/doc[1]/arr[@name='" + VECTORS_PSEUDOFIELD + "'][1]/doc[2]/arr[@name='" + vectorByteField + + "']/int[1][.='9']", + "//result/doc[1]/arr[@name='" + VECTORS_PSEUDOFIELD + "'][1]/doc[2]/arr[@name='" + vectorByteField + + "']/int[2][.='1']", + "//result/doc[1]/arr[@name='" + VECTORS_PSEUDOFIELD + "'][1]/doc[2]/arr[@name='" + vectorByteField + + "']/int[3][.='1']", + "//result/doc[1]/arr[@name='" + VECTORS_PSEUDOFIELD + "'][1]/doc[2]/arr[@name='" + vectorByteField + + "']/int[4][.='1']", + "//result/doc[2]/str[@name='id'][.='7']", + "//result/doc[2]/arr[@name='" + VECTORS_PSEUDOFIELD + "'][1]/doc[1]/arr[@name='" + vectorByteField + + "']/int[1][.='13']", + "//result/doc[2]/arr[@name='" + VECTORS_PSEUDOFIELD + "'][1]/doc[1]/arr[@name='" + vectorByteField + + "']/int[2][.='1']", + "//result/doc[2]/arr[@name='" + VECTORS_PSEUDOFIELD + "'][1]/doc[1]/arr[@name='" + vectorByteField + + "']/int[3][.='1']", + "//result/doc[2]/arr[@name='" + VECTORS_PSEUDOFIELD + "'][1]/doc[1]/arr[@name='" + vectorByteField + + "']/int[4][.='1']", + "//result/doc[2]/arr[@name='" + VECTORS_PSEUDOFIELD + "'][1]/doc[2]/arr[@name='" + vectorByteField + + "']/int[1][.='12']", + "//result/doc[2]/arr[@name='" + VECTORS_PSEUDOFIELD + "'][1]/doc[2]/arr[@name='" + vectorByteField + + "']/int[2][.='1']", + "//result/doc[2]/arr[@name='" + VECTORS_PSEUDOFIELD + "'][1]/doc[2]/arr[@name='" + vectorByteField + + "']/int[3][.='1']", + "//result/doc[2]/arr[@name='" + VECTORS_PSEUDOFIELD + "'][1]/doc[2]/arr[@name='" + vectorByteField + + "']/int[4][.='1']", + "//result/doc[3]/str[@name='id'][.='2']", + "//result/doc[3]/arr[@name='" + VECTORS_PSEUDOFIELD + "'][1]/doc[1]/arr[@name='" + vectorByteField + + "']/int[1][.='28']", + "//result/doc[3]/arr[@name='" + VECTORS_PSEUDOFIELD + "'][1]/doc[1]/arr[@name='" + vectorByteField + + "']/int[2][.='1']", + "//result/doc[3]/arr[@name='" + VECTORS_PSEUDOFIELD + "'][1]/doc[1]/arr[@name='" + vectorByteField + + "']/int[3][.='1']", + "//result/doc[3]/arr[@name='" + VECTORS_PSEUDOFIELD + "'][1]/doc[1]/arr[@name='" + vectorByteField + + "']/int[4][.='1']", + "//result/doc[3]/arr[@name='" + VECTORS_PSEUDOFIELD + "'][1]/doc[2]/arr[@name='" + vectorByteField + + "']/int[1][.='27']", + "//result/doc[3]/arr[@name='" + VECTORS_PSEUDOFIELD + "'][1]/doc[2]/arr[@name='" + vectorByteField + + "']/int[2][.='1']", + "//result/doc[3]/arr[@name='" + VECTORS_PSEUDOFIELD + "'][1]/doc[2]/arr[@name='" + vectorByteField + + "']/int[3][.='1']", + "//result/doc[3]/arr[@name='" + VECTORS_PSEUDOFIELD + "'][1]/doc[2]/arr[@name='" + vectorByteField + + "']/int[4][.='1']"); + } + + protected void parentRetrievalByte_topKWithChildTransformerWithFilter_shouldReturnBestChild(String vectorByteField) { + assertQ( + req( + "q", "{!parent which=$allParents score=max v=$children.q}", + "fl", + "id,score," + VECTORS_PSEUDOFIELD + "," + vectorByteField + + ",[child fl=" + vectorByteField + " childFilter=$children.q]", + "children.q", + "{!knn f=" + vectorByteField + + " topK=3 childrenOf=$someParents allParents=$allParents}" + + BYTE_QUERY_VECTOR, + "allParents", "parent_s:[* TO *]", + "someParents", "parent_s:(b c)"), + "//result[@numFound='3']", + "//result/doc[1]/str[@name='id'][.='8']", + "//result/doc[1]/arr[@name='" + VECTORS_PSEUDOFIELD + "'][1]/doc[1]/arr[@name='" + vectorByteField + + "']/int[1][.='8']", + "//result/doc[1]/arr[@name='" + VECTORS_PSEUDOFIELD + "'][1]/doc[1]/arr[@name='" + vectorByteField + + "']/int[2][.='1']", + "//result/doc[1]/arr[@name='" + VECTORS_PSEUDOFIELD + "'][1]/doc[1]/arr[@name='" + vectorByteField + + "']/int[3][.='1']", + "//result/doc[1]/arr[@name='" + VECTORS_PSEUDOFIELD + "'][1]/doc[1]/arr[@name='" + vectorByteField + + "']/int[4][.='1']", + "//result/doc[2]/str[@name='id'][.='7']", + "//result/doc[2]/arr[@name='" + VECTORS_PSEUDOFIELD + "'][1]/doc[1]/arr[@name='" + vectorByteField + + "']/int[1][.='11']", + "//result/doc[2]/arr[@name='" + VECTORS_PSEUDOFIELD + "'][1]/doc[1]/arr[@name='" + vectorByteField + + "']/int[2][.='1']", + "//result/doc[2]/arr[@name='" + VECTORS_PSEUDOFIELD + "'][1]/doc[1]/arr[@name='" + vectorByteField + + "']/int[3][.='1']", + "//result/doc[2]/arr[@name='" + VECTORS_PSEUDOFIELD + "'][1]/doc[1]/arr[@name='" + vectorByteField + + "']/int[4][.='1']", + "//result/doc[3]/str[@name='id'][.='2']", + "//result/doc[3]/arr[@name='" + VECTORS_PSEUDOFIELD + "'][1]/doc[1]/arr[@name='" + vectorByteField + + "']/int[1][.='26']", + "//result/doc[3]/arr[@name='" + VECTORS_PSEUDOFIELD + "'][1]/doc[1]/arr[@name='" + vectorByteField + + "']/int[2][.='1']", + "//result/doc[3]/arr[@name='" + VECTORS_PSEUDOFIELD + "'][1]/doc[1]/arr[@name='" + vectorByteField + + "']/int[3][.='1']", + "//result/doc[3]/arr[@name='" + VECTORS_PSEUDOFIELD + "'][1]/doc[1]/arr[@name='" + vectorByteField + + "']/int[4][.='1']"); + } +} diff --git a/solr/core/src/test/org/apache/solr/search/join/BlockJoinNestedVectorsQParserTest.java b/solr/core/src/test/org/apache/solr/search/join/BlockJoinNestedVectorsQParserTest.java deleted file mode 100644 index 65bc72e3bb61..000000000000 --- a/solr/core/src/test/org/apache/solr/search/join/BlockJoinNestedVectorsQParserTest.java +++ /dev/null @@ -1,437 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.solr.search.join; - -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; -import org.apache.solr.SolrTestCaseJ4; -import org.apache.solr.common.SolrInputDocument; -import org.apache.solr.util.RandomNoReverseMergePolicyFactory; -import org.junit.BeforeClass; -import org.junit.ClassRule; -import org.junit.Test; -import org.junit.rules.TestRule; - -public class BlockJoinNestedVectorsQParserTest extends SolrTestCaseJ4 { - private static final List FLOAT_QUERY_VECTOR = Arrays.asList(1.0f, 1.0f, 1.0f, 1.0f); - private static final List BYTE_QUERY_VECTOR = Arrays.asList(1, 1, 1, 1); - - @ClassRule - public static final TestRule noReverseMerge = RandomNoReverseMergePolicyFactory.createRule(); - - @BeforeClass - public static void beforeClass() throws Exception { - initCore("solrconfig.xml", "schema15.xml"); - prepareIndex(); - } - - public static void prepareIndex() throws Exception { - List docsToIndex = prepareDocs(); - for (SolrInputDocument doc : docsToIndex) { - assertU(adoc(doc)); - } - assertU(commit()); - } - - /** - * The documents in the index are 10 parents, with some parent level metadata and 30 nested - * documents (with vectors and children level metadata) Each parent document has 3 nested - * documents with vectors. - * - *

This allows to run knn queries both at parent/children level and using various pre-filters - * both for parent metadata and children. - * - * @return a list of documents to index - */ - private static List prepareDocs() { - int totalParentDocuments = 10; - int totalNestedVectors = 30; - int perParentChildren = totalNestedVectors / totalParentDocuments; - - final String[] klm = new String[] {"k", "l", "m"}; - final String[] abcdef = new String[] {"a", "b", "c", "d", "e", "f"}; - - List docs = new ArrayList<>(totalParentDocuments); - for (int i = 1; i < totalParentDocuments + 1; i++) { - SolrInputDocument doc = new SolrInputDocument(); - doc.setField("id", i); - doc.setField("parent_b", true); - - doc.setField("parent_s", abcdef[i % abcdef.length]); - List children = new ArrayList<>(perParentChildren); - - // nested vector documents have a distance from the query vector inversely proportional to - // their id - for (int j = 0; j < perParentChildren; j++) { - SolrInputDocument child = new SolrInputDocument(); - child.setField("id", i + "" + j); - child.setField("child_s", klm[i % klm.length]); - child.setField("vector", outDistanceFloat(FLOAT_QUERY_VECTOR, totalNestedVectors)); - child.setField("vector_byte", outDistanceByte(BYTE_QUERY_VECTOR, totalNestedVectors)); - totalNestedVectors--; // the higher the id of the nested document, lower the distance with - // the query vector - children.add(child); - } - doc.setField("vectors", children); - docs.add(doc); - } - - return docs; - } - - /** - * Generate a resulting float vector with a distance from the original vector that is proportional - * to the value in input (higher the value, higher the distance from the original vector) - * - * @param vector a numerical vector - * @param value a numerical value to be added to the first element of the vector - * @return a numerical vector that has a distance from the input vector, proportional to the value - */ - private static List outDistanceFloat(List vector, int value) { - List result = new ArrayList<>(vector.size()); - for (int i = 0; i < vector.size(); i++) { - if (i == 0) { - result.add(vector.get(i) + value); - } else { - result.add(vector.get(i)); - } - } - return result; - } - - /** - * Generate a resulting byte vector with a distance from the original vector that is proportional - * to the value in input (higher the value, higher the distance from the original vector) - * - * @param vector a numerical vector - * @param value a numerical value to be added to the first element of the vector - * @return a numerical vector that has a distance from the input vector, proportional to the value - */ - private static List outDistanceByte(List vector, int value) { - List result = new ArrayList<>(vector.size()); - for (int i = 0; i < vector.size(); i++) { - if (i == 0) { - result.add(vector.get(i) + value); - } else { - result.add(vector.get(i)); - } - } - return result; - } - - @Test - public void parentRetrieval_knnChildrenDiversifyingWithNoAllParents_shouldThrowException() { - assertQEx( - "When running a diversifying children KNN query, 'allParents' parameter is required", - req( - "q", - "{!parent which=$allParents score=max v=$children.q}", - "fl", - "id,score", - "children.q", - "{!knn f=vector topK=3 childrenOf=$someParents}" + FLOAT_QUERY_VECTOR, - "allParents", - "parent_s:[* TO *]", - "someParents", - "parent_s:(a c)"), - 400); - } - - @Test - public void childrenRetrievalFloat_filteringByParentMetadata_shouldReturnKnnChildren() { - assertQ( - req( - "fq", "{!child of=$allParents filters=$parent.fq}", - "q", "{!knn f=vector topK=5}" + FLOAT_QUERY_VECTOR, - "fl", "id", - "parent.fq", "parent_s:(a c)", - "allParents", "parent_s:[* TO *]"), - "//*[@numFound='5']", - "//result/doc[1]/str[@name='id'][.='82']", - "//result/doc[2]/str[@name='id'][.='81']", - "//result/doc[3]/str[@name='id'][.='80']", - "//result/doc[4]/str[@name='id'][.='62']", - "//result/doc[5]/str[@name='id'][.='61']"); - } - - @Test - public void childrenRetrievalByte_filteringByParentMetadata_shouldReturnKnnChildren() { - assertQ( - req( - "fq", "{!child of=$allParents filters=$parent.fq}", - "q", "{!knn f=vector_byte topK=5}" + BYTE_QUERY_VECTOR, - "fl", "id", - "parent.fq", "parent_s:(a c)", - "allParents", "parent_s:[* TO *]"), - "//*[@numFound='5']", - "//result/doc[1]/str[@name='id'][.='82']", - "//result/doc[2]/str[@name='id'][.='81']", - "//result/doc[3]/str[@name='id'][.='80']", - "//result/doc[4]/str[@name='id'][.='62']", - "//result/doc[5]/str[@name='id'][.='61']"); - } - - @Test - public void parentRetrievalFloat_knnChildren_shouldReturnKnnParents() { - assertQ( - req( - "q", "{!parent which=$allParents score=max v=$children.q}", - "fl", "id,score", - "children.q", "{!knn f=vector topK=3 allParents=$allParents}" + FLOAT_QUERY_VECTOR, - "allParents", "parent_s:[* TO *]"), - "//*[@numFound='3']", - "//result/doc[1]/str[@name='id'][.='10']", - "//result/doc[2]/str[@name='id'][.='9']", - "//result/doc[3]/str[@name='id'][.='8']"); - } - - @Test - public void parentRetrievalFloat_knnChildrenWithNoDiversifying_shouldReturnOneParent() { - assertQ( - req( - "q", "{!parent which=$allParents score=max v=$children.q}", - "fl", "id,score", - "children.q", "{!knn f=vector topK=3}" + FLOAT_QUERY_VECTOR, - "allParents", "parent_s:[* TO *]"), - "//*[@numFound='1']", - "//result/doc[1]/str[@name='id'][.='10']"); - } - - @Test - public void parentRetrievalFloat_knnChildrenWithParentFilter_shouldReturnKnnParents() { - assertQ( - req( - "q", "{!parent which=$allParents score=max v=$children.q}", - "fl", "id,score", - "children.q", - "{!knn f=vector topK=3 childrenOf=$someParents allParents=$allParents}" - + FLOAT_QUERY_VECTOR, - "allParents", "parent_s:[* TO *]", - "someParents", "parent_s:(a c)"), - "//*[@numFound='3']", - "//result/doc[1]/str[@name='id'][.='8']", - "//result/doc[2]/str[@name='id'][.='6']", - "//result/doc[3]/str[@name='id'][.='2']"); - } - - @Test - public void - parentRetrievalFloat_knnChildrenWithParentFilterAndChildrenFilter_shouldReturnKnnParents() { - assertQ( - req( - "q", "{!parent which=$allParents score=max v=$children.q}", - "fl", "id,score", - "children.q", - "{!knn f=vector topK=3 preFilter=child_s:m childrenOf=$someParents allParents=$allParents}" - + FLOAT_QUERY_VECTOR, - "allParents", "parent_s:[* TO *]", - "someParents", "parent_s:(a c)"), - "//*[@numFound='2']", - "//result/doc[1]/str[@name='id'][.='8']", - "//result/doc[2]/str[@name='id'][.='2']"); - } - - @Test - public void parentRetrievalByte_knnChildren_shouldReturnKnnParents() { - assertQ( - req( - "q", "{!parent which=$allParents score=max v=$children.q}", - "fl", "id,score", - "children.q", "{!knn f=vector_byte topK=3 allParents=$allParents}" + BYTE_QUERY_VECTOR, - "allParents", "parent_s:[* TO *]"), - "//*[@numFound='3']", - "//result/doc[1]/str[@name='id'][.='10']", - "//result/doc[2]/str[@name='id'][.='9']", - "//result/doc[3]/str[@name='id'][.='8']"); - } - - @Test - public void parentRetrievalByte_knnChildrenWithParentFilter_shouldReturnKnnParents() { - assertQ( - req( - "q", "{!parent which=$allParents score=max v=$children.q}", - "fl", "id,score", - "children.q", - "{!knn f=vector_byte topK=3 childrenOf=$someParents allParents=$allParents}" - + BYTE_QUERY_VECTOR, - "allParents", "parent_s:[* TO *]", - "someParents", "parent_s:(a c)"), - "//*[@numFound='3']", - "//result/doc[1]/str[@name='id'][.='8']", - "//result/doc[2]/str[@name='id'][.='6']", - "//result/doc[3]/str[@name='id'][.='2']"); - } - - @Test - public void - parentRetrievalByte_knnChildrenWithParentFilterAndChildrenFilter_shouldReturnKnnParents() { - assertQ( - req( - "q", "{!parent which=$allParents score=max v=$children.q}", - "fl", "id,score", - "children.q", - "{!knn f=vector_byte topK=3 preFilter=child_s:m childrenOf=$someParents allParents=$allParents}" - + BYTE_QUERY_VECTOR, - "allParents", "parent_s:[* TO *]", - "someParents", "parent_s:(a c)"), - "//*[@numFound='2']", - "//result/doc[1]/str[@name='id'][.='8']", - "//result/doc[2]/str[@name='id'][.='2']"); - } - - @Test - public void - parentRetrievalFloat_topKWithChildTransformerWithFilter_shouldUseOriginalChildTransformerFilter() { - assertQ( - req( - "q", "{!parent which=$allParents score=max v=$children.q}", - "fl", "id,score,vectors,vector,[child limit=2 fl=vector]", - "children.q", - "{!knn f=vector topK=3 childrenOf=$someParents allParents=$allParents}" - + FLOAT_QUERY_VECTOR, - "allParents", "parent_s:[* TO *]", - "someParents", "parent_s:(a c)"), - "//result[@numFound='3']", - "//result/doc[1]/str[@name='id'][.='8']", - "//result/doc[1]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[1][.='10.0']", - "//result/doc[1]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[2][.='1.0']", - "//result/doc[1]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[3][.='1.0']", - "//result/doc[1]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[4][.='1.0']", - "//result/doc[1]/arr[@name='vectors'][1]/doc[2]/arr[@name='vector']/float[1][.='9.0']", - "//result/doc[1]/arr[@name='vectors'][1]/doc[2]/arr[@name='vector']/float[2][.='1.0']", - "//result/doc[1]/arr[@name='vectors'][1]/doc[2]/arr[@name='vector']/float[3][.='1.0']", - "//result/doc[1]/arr[@name='vectors'][1]/doc[2]/arr[@name='vector']/float[4][.='1.0']", - "//result/doc[2]/str[@name='id'][.='6']", - "//result/doc[2]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[1][.='16.0']", - "//result/doc[2]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[2][.='1.0']", - "//result/doc[2]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[3][.='1.0']", - "//result/doc[2]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[4][.='1.0']", - "//result/doc[2]/arr[@name='vectors'][1]/doc[2]/arr[@name='vector']/float[1][.='15.0']", - "//result/doc[2]/arr[@name='vectors'][1]/doc[2]/arr[@name='vector']/float[2][.='1.0']", - "//result/doc[2]/arr[@name='vectors'][1]/doc[2]/arr[@name='vector']/float[3][.='1.0']", - "//result/doc[2]/arr[@name='vectors'][1]/doc[2]/arr[@name='vector']/float[4][.='1.0']", - "//result/doc[3]/str[@name='id'][.='2']", - "//result/doc[3]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[1][.='28.0']", - "//result/doc[3]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[2][.='1.0']", - "//result/doc[3]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[3][.='1.0']", - "//result/doc[3]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[4][.='1.0']", - "//result/doc[3]/arr[@name='vectors'][1]/doc[2]/arr[@name='vector']/float[1][.='27.0']", - "//result/doc[3]/arr[@name='vectors'][1]/doc[2]/arr[@name='vector']/float[2][.='1.0']", - "//result/doc[3]/arr[@name='vectors'][1]/doc[2]/arr[@name='vector']/float[3][.='1.0']", - "//result/doc[3]/arr[@name='vectors'][1]/doc[2]/arr[@name='vector']/float[4][.='1.0']"); - } - - @Test - public void parentRetrievalFloat_topKWithChildTransformerWithFilter_shouldReturnBestChild() { - assertQ( - req( - "q", "{!parent which=$allParents score=max v=$children.q}", - "fl", "id,score,vectors,vector,[child fl=vector childFilter=$children.q]", - "children.q", - "{!knn f=vector topK=3 childrenOf=$someParents allParents=$allParents}" - + FLOAT_QUERY_VECTOR, - "allParents", "parent_s:[* TO *]", - "someParents", "parent_s:(b c)"), - "//result[@numFound='3']", - "//result/doc[1]/str[@name='id'][.='8']", - "//result/doc[1]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[1][.='8.0']", - "//result/doc[1]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[2][.='1.0']", - "//result/doc[1]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[3][.='1.0']", - "//result/doc[1]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[4][.='1.0']", - "//result/doc[2]/str[@name='id'][.='7']", - "//result/doc[2]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[1][.='11.0']", - "//result/doc[2]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[2][.='1.0']", - "//result/doc[2]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[3][.='1.0']", - "//result/doc[2]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[4][.='1.0']", - "//result/doc[3]/str[@name='id'][.='2']", - "//result/doc[3]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[1][.='26.0']", - "//result/doc[3]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[2][.='1.0']", - "//result/doc[3]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[3][.='1.0']", - "//result/doc[3]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector']/float[4][.='1.0']"); - } - - @Test - public void parentRetrievalByte_topKWithChildTransformer_shouldReturnAllChildren() { - assertQ( - req( - "q", "{!parent which=$allParents score=max v=$children.q}", - "fl", "id,score,vectors,vector_byte,[child limit=2 fl=vector_byte]", - "children.q", - "{!knn f=vector_byte topK=3 childrenOf=$someParents allParents=$allParents}" - + BYTE_QUERY_VECTOR, - "allParents", "parent_s:[* TO *]", - "someParents", "parent_s:(b c)"), - "//result[@numFound='3']", - "//result/doc[1]/str[@name='id'][.='8']", - "//result/doc[1]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector_byte']/int[1][.='10']", - "//result/doc[1]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector_byte']/int[2][.='1']", - "//result/doc[1]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector_byte']/int[3][.='1']", - "//result/doc[1]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector_byte']/int[4][.='1']", - "//result/doc[1]/arr[@name='vectors'][1]/doc[2]/arr[@name='vector_byte']/int[1][.='9']", - "//result/doc[1]/arr[@name='vectors'][1]/doc[2]/arr[@name='vector_byte']/int[2][.='1']", - "//result/doc[1]/arr[@name='vectors'][1]/doc[2]/arr[@name='vector_byte']/int[3][.='1']", - "//result/doc[1]/arr[@name='vectors'][1]/doc[2]/arr[@name='vector_byte']/int[4][.='1']", - "//result/doc[2]/str[@name='id'][.='7']", - "//result/doc[2]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector_byte']/int[1][.='13']", - "//result/doc[2]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector_byte']/int[2][.='1']", - "//result/doc[2]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector_byte']/int[3][.='1']", - "//result/doc[2]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector_byte']/int[4][.='1']", - "//result/doc[2]/arr[@name='vectors'][1]/doc[2]/arr[@name='vector_byte']/int[1][.='12']", - "//result/doc[2]/arr[@name='vectors'][1]/doc[2]/arr[@name='vector_byte']/int[2][.='1']", - "//result/doc[2]/arr[@name='vectors'][1]/doc[2]/arr[@name='vector_byte']/int[3][.='1']", - "//result/doc[2]/arr[@name='vectors'][1]/doc[2]/arr[@name='vector_byte']/int[4][.='1']", - "//result/doc[3]/str[@name='id'][.='2']", - "//result/doc[3]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector_byte']/int[1][.='28']", - "//result/doc[3]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector_byte']/int[2][.='1']", - "//result/doc[3]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector_byte']/int[3][.='1']", - "//result/doc[3]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector_byte']/int[4][.='1']", - "//result/doc[3]/arr[@name='vectors'][1]/doc[2]/arr[@name='vector_byte']/int[1][.='27']", - "//result/doc[3]/arr[@name='vectors'][1]/doc[2]/arr[@name='vector_byte']/int[2][.='1']", - "//result/doc[3]/arr[@name='vectors'][1]/doc[2]/arr[@name='vector_byte']/int[3][.='1']", - "//result/doc[3]/arr[@name='vectors'][1]/doc[2]/arr[@name='vector_byte']/int[4][.='1']"); - } - - @Test - public void parentRetrievalByte_topKWithChildTransformerWithFilter_shouldReturnBestChild() { - assertQ( - req( - "q", "{!parent which=$allParents score=max v=$children.q}", - "fl", "id,score,vectors,vector_byte,[child fl=vector_byte childFilter=$children.q]", - "children.q", - "{!knn f=vector_byte topK=3 childrenOf=$someParents allParents=$allParents}" - + BYTE_QUERY_VECTOR, - "allParents", "parent_s:[* TO *]", - "someParents", "parent_s:(b c)"), - "//result[@numFound='3']", - "//result/doc[1]/str[@name='id'][.='8']", - "//result/doc[1]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector_byte']/int[1][.='8']", - "//result/doc[1]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector_byte']/int[2][.='1']", - "//result/doc[1]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector_byte']/int[3][.='1']", - "//result/doc[1]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector_byte']/int[4][.='1']", - "//result/doc[2]/str[@name='id'][.='7']", - "//result/doc[2]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector_byte']/int[1][.='11']", - "//result/doc[2]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector_byte']/int[2][.='1']", - "//result/doc[2]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector_byte']/int[3][.='1']", - "//result/doc[2]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector_byte']/int[4][.='1']", - "//result/doc[3]/str[@name='id'][.='2']", - "//result/doc[3]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector_byte']/int[1][.='26']", - "//result/doc[3]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector_byte']/int[2][.='1']", - "//result/doc[3]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector_byte']/int[3][.='1']", - "//result/doc[3]/arr[@name='vectors'][1]/doc[1]/arr[@name='vector_byte']/int[4][.='1']"); - } -} diff --git a/solr/core/src/test/org/apache/solr/search/join/BlockJoinNestedVectorsTest.java b/solr/core/src/test/org/apache/solr/search/join/BlockJoinNestedVectorsTest.java new file mode 100644 index 000000000000..d5168fb9c764 --- /dev/null +++ b/solr/core/src/test/org/apache/solr/search/join/BlockJoinNestedVectorsTest.java @@ -0,0 +1,169 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.solr.search.join; + +import java.util.ArrayList; +import java.util.List; +import org.apache.solr.common.SolrInputDocument; +import org.apache.solr.util.RandomNoReverseMergePolicyFactory; +import org.junit.BeforeClass; +import org.junit.ClassRule; +import org.junit.Test; +import org.junit.rules.TestRule; + +public class BlockJoinNestedVectorsTest extends BlockJoinNestedVectorsParentQParserTest { + protected static String VECTOR_FIELD = "vector"; + protected static String VECTOR_BYTE_FIELD = "vector_byte_encoding"; + + @ClassRule + public static final TestRule noReverseMerge = RandomNoReverseMergePolicyFactory.createRule(); + + @BeforeClass + public static void beforeClass() throws Exception { + /* vectorDimension="4" similarityFunction="cosine" */ + initCore("solrconfig_codec.xml", "schema-densevector.xml"); + prepareIndex(); + } + + protected static void prepareIndex() throws Exception { + List docsToIndex = prepareDocs(); + for (SolrInputDocument doc : docsToIndex) { + assertU(adoc(doc)); + } + assertU(commit()); + } + + /** + * The documents in the index are 10 parents, with some parent level metadata and 30 nested + * documents (with vectors and children level metadata) Each parent document has 3 nested + * documents with vectors. + * + *

This allows to run knn queries both at parent/children level and using various pre-filters + * both for parent metadata and children. + * + * @return a list of documents to index + */ + private static List prepareDocs() { + int totalParentDocuments = 10; + int totalNestedVectors = 30; + int perParentChildren = totalNestedVectors / totalParentDocuments; + + final String[] klm = new String[] {"k", "l", "m"}; + final String[] abcdef = new String[] {"a", "b", "c", "d", "e", "f"}; + + List docs = new ArrayList<>(totalParentDocuments); + for (int i = 1; i < totalParentDocuments + 1; i++) { + SolrInputDocument doc = new SolrInputDocument(); + doc.setField("id", i); + doc.setField("parent_b", true); + + doc.setField("parent_s", abcdef[i % abcdef.length]); + List children = new ArrayList<>(perParentChildren); + + // nested vector documents have a distance from the query vector inversely proportional to + // their id + for (int j = 0; j < perParentChildren; j++) { + SolrInputDocument child = new SolrInputDocument(); + child.setField("id", i + "" + j); + child.setField("child_s", klm[i % klm.length]); + child.setField("vector", outDistanceFloat(FLOAT_QUERY_VECTOR, totalNestedVectors)); + child.setField("vector_byte_encoding", outDistanceByte(BYTE_QUERY_VECTOR, totalNestedVectors)); + totalNestedVectors--; // the higher the id of the nested document, lower the distance with + // the query vector + children.add(child); + } + doc.setField("vectors", children); + docs.add(doc); + } + + return docs; + } + + @Test + public void parentRetrieval_knnChildrenDiversifyingWithNoAllParents_shouldThrowException() { + super.parentRetrieval_knnChildrenDiversifyingWithNoAllParents_shouldThrowException(VECTOR_FIELD); + } + + @Test + public void childrenRetrievalFloat_filteringByParentMetadata_shouldReturnKnnChildren() { + super.childrenRetrievalFloat_filteringByParentMetadata_shouldReturnKnnChildren(VECTOR_FIELD); + } + + @Test + public void childrenRetrievalByte_filteringByParentMetadata_shouldReturnKnnChildren() { + super.childrenRetrievalByte_filteringByParentMetadata_shouldReturnKnnChildren(VECTOR_BYTE_FIELD); + } + + @Test + public void parentRetrievalFloat_knnChildren_shouldReturnKnnParents() { + super.parentRetrievalFloat_knnChildren_shouldReturnKnnParents(VECTOR_FIELD); + } + + @Test + public void parentRetrievalFloat_knnChildrenWithNoDiversifying_shouldReturnOneParent() { + super.parentRetrievalFloat_knnChildrenWithNoDiversifying_shouldReturnOneParent(VECTOR_FIELD); + } + + @Test + public void parentRetrievalFloat_knnChildrenWithParentFilter_shouldReturnKnnParents() { + super.parentRetrievalFloat_knnChildrenWithParentFilter_shouldReturnKnnParents(VECTOR_FIELD); + } + + @Test + public void + parentRetrievalFloat_knnChildrenWithParentFilterAndChildrenFilter_shouldReturnKnnParents( + ) { + super.parentRetrievalFloat_knnChildrenWithParentFilterAndChildrenFilter_shouldReturnKnnParents(VECTOR_FIELD); + } + + @Test + public void parentRetrievalByte_knnChildren_shouldReturnKnnParents() { + super.parentRetrievalByte_knnChildren_shouldReturnKnnParents(VECTOR_BYTE_FIELD); + } + + @Test + public void parentRetrievalByte_knnChildrenWithParentFilter_shouldReturnKnnParents() { + super.parentRetrievalByte_knnChildrenWithParentFilter_shouldReturnKnnParents(VECTOR_BYTE_FIELD); + } + + @Test + public void + parentRetrievalByte_knnChildrenWithParentFilterAndChildrenFilter_shouldReturnKnnParents() { + super.parentRetrievalByte_knnChildrenWithParentFilterAndChildrenFilter_shouldReturnKnnParents(VECTOR_BYTE_FIELD); + } + + @Test + public void + parentRetrievalFloat_topKWithChildTransformerWithFilter_shouldUseOriginalChildTransformerFilter() { + super.parentRetrievalFloat_topKWithChildTransformerWithFilter_shouldUseOriginalChildTransformerFilter(VECTOR_FIELD); + } + + @Test + public void parentRetrievalFloat_topKWithChildTransformerWithFilter_shouldReturnBestChild() { + super.parentRetrievalFloat_topKWithChildTransformerWithFilter_shouldReturnBestChild(VECTOR_FIELD); + } + + @Test + public void parentRetrievalByte_topKWithChildTransformer_shouldReturnAllChildren() { + super.parentRetrievalByte_topKWithChildTransformer_shouldReturnAllChildren(VECTOR_BYTE_FIELD); + } + + @Test + public void parentRetrievalByte_topKWithChildTransformerWithFilter_shouldReturnBestChild() { + super.parentRetrievalByte_topKWithChildTransformerWithFilter_shouldReturnBestChild(VECTOR_BYTE_FIELD); + } +} diff --git a/solr/core/src/test/org/apache/solr/search/vector/KnnQParserMultiValuedVectorsTest.java b/solr/core/src/test/org/apache/solr/search/vector/KnnQParserMultiValuedVectorsTest.java deleted file mode 100644 index fb98ead3c0d7..000000000000 --- a/solr/core/src/test/org/apache/solr/search/vector/KnnQParserMultiValuedVectorsTest.java +++ /dev/null @@ -1,267 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.solr.search.vector; - -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; -import org.apache.solr.SolrTestCaseJ4; -import org.apache.solr.common.SolrInputDocument; -import org.apache.solr.common.params.CommonParams; -import org.apache.solr.util.RandomNoReverseMergePolicyFactory; -import org.junit.BeforeClass; -import org.junit.ClassRule; -import org.junit.Test; -import org.junit.rules.TestRule; - -public class KnnQParserMultiValuedVectorsTest extends SolrTestCaseJ4 { - private static final List FLOAT_QUERY_VECTOR = Arrays.asList(1.0f, 1.0f, 1.0f, 1.0f); - private static final List BYTE_QUERY_VECTOR = Arrays.asList(1, 1, 1, 1); - - @ClassRule - public static final TestRule noReverseMerge = RandomNoReverseMergePolicyFactory.createRule(); - - @BeforeClass - public static void beforeClass() throws Exception { - /* vectorDimension="4" similarityFunction="cosine" */ - initCore("solrconfig_codec.xml", "schema-densevector.xml"); - prepareIndex(); - } - - public static void prepareIndex() throws Exception { - List docsToIndex = prepareDocs(); - for (SolrInputDocument doc : docsToIndex) { - updateJ(jsonAdd(doc), null); - } - assertU(commit()); - } - - /** - * The documents in the index are 10 parents, with some parent level metadata and 30 nested - * documents (with vectors and children level metadata) Each parent document has 3 nested - * documents with vectors. - * - *

This allows to run knn queries both at parent/children level and using various pre-filters - * both for parent metadata and children. - * - * @return a list of documents to index - */ - private static List prepareDocs() { - int totalParentDocuments = 10; - int totalNestedVectors = 30; - int perParentChildren = totalNestedVectors / totalParentDocuments; - - final String[] abcdef = new String[] {"a", "b", "c", "d", "e", "f"}; - - List docs = new ArrayList<>(totalParentDocuments); - for (int i = 1; i < totalParentDocuments + 1; i++) { - SolrInputDocument doc = new SolrInputDocument(); - doc.setField("id", i); - doc.setField("_text_", abcdef[i % abcdef.length]); - List> floatVectors = new ArrayList<>(perParentChildren); - List> byteVectors = new ArrayList<>(perParentChildren); - // nested vector documents have a distance from the query vector inversely proportional to - // their id - for (int j = 0; j < perParentChildren; j++) { - floatVectors.add(outDistanceFloat(FLOAT_QUERY_VECTOR, totalNestedVectors)); - byteVectors.add(outDistanceByte(BYTE_QUERY_VECTOR, totalNestedVectors)); - totalNestedVectors--; // the higher the id of the nested document, lower the distance with - } - doc.setField("vector_multivalued", floatVectors); - doc.setField("vector_byte_multivalued", byteVectors); - - docs.add(doc); - } - - return docs; - } - - /** - * Generate a resulting float vector with a distance from the original vector that is proportional - * to the value in input (higher the value, higher the distance from the original vector) - * - * @param vector a numerical vector - * @param value a numerical value to be added to the first element of the vector - * @return a numerical vector that has a distance from the input vector, proportional to the value - */ - private static List outDistanceFloat(List vector, int value) { - List result = new ArrayList<>(vector.size()); - for (int i = 0; i < vector.size(); i++) { - if (i == 0) { - result.add(vector.get(i) + value); - } else { - result.add(vector.get(i)); - } - } - return result; - } - - /** - * Generate a resulting byte vector with a distance from the original vector that is proportional - * to the value in input (higher the value, higher the distance from the original vector) - * - * @param vector a numerical vector - * @param value a numerical value to be added to the first element of the vector - * @return a numerical vector that has a distance from the input vector, proportional to the value - */ - private static List outDistanceByte(List vector, int value) { - List result = new ArrayList<>(vector.size()); - for (int i = 0; i < vector.size(); i++) { - if (i == 0) { - result.add(vector.get(i) + value); - } else { - result.add(vector.get(i)); - } - } - return result; - } - - @Test - public void topK_shouldReturnOnlyTopKResults() { - assertQ( - req(CommonParams.Q, "{!knn f=vector_multivalued topK=5}" + FLOAT_QUERY_VECTOR, "fl", "id"), - "//result[@numFound='5']", - "//result/doc[1]/str[@name='id'][.='10']", - "//result/doc[2]/str[@name='id'][.='9']", - "//result/doc[3]/str[@name='id'][.='8']", - "//result/doc[4]/str[@name='id'][.='7']", - "//result/doc[5]/str[@name='id'][.='6']"); - } - - @Test - public void topKWithFilter_shouldReturnOnlyTopKResults() { - assertQ( - req( - CommonParams.Q, - "{!knn f=vector_multivalued topK=5}" + FLOAT_QUERY_VECTOR, - "fl", - "id", - "fq", - "_text_:(b OR c)"), - "//result[@numFound='4']", - "//result/doc[1]/str[@name='id'][.='8']", - "//result/doc[2]/str[@name='id'][.='7']", - "//result/doc[3]/str[@name='id'][.='2']", - "//result/doc[4]/str[@name='id'][.='1']"); - } - - @Test - public void topKWithoutTransformer_shouldDefaultToBestChildren() { - assertQ( - req( - CommonParams.Q, - "{!knn f=vector_multivalued topK=5}" + FLOAT_QUERY_VECTOR, - "fl", - "id,vector_multivalued", - "fq", - "_text_:(b OR c)"), - "//result[@numFound='4']", - "//result/doc[1]/str[@name='id'][.='8']", - "//result/doc[1]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[1][.='8.0']", - "//result/doc[1]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[2][.='1.0']", - "//result/doc[1]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[3][.='1.0']", - "//result/doc[1]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[4][.='1.0']", - "//result/doc[2]/str[@name='id'][.='7']", - "//result/doc[2]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[1][.='11.0']", - "//result/doc[2]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[2][.='1.0']", - "//result/doc[2]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[3][.='1.0']", - "//result/doc[2]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[4][.='1.0']", - "//result/doc[3]/str[@name='id'][.='2']", - "//result/doc[3]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[1][.='26.0']", - "//result/doc[3]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[2][.='1.0']", - "//result/doc[3]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[3][.='1.0']", - "//result/doc[3]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[4][.='1.0']", - "//result/doc[4]/str[@name='id'][.='1']", - "//result/doc[4]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[1][.='29.0']", - "//result/doc[4]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[2][.='1.0']", - "//result/doc[4]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[3][.='1.0']", - "//result/doc[4]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[4][.='1.0']"); - } - - @Test - public void topKWithTransformer_shouldAddDefaultToBestChildren() { - assertQ( - req( - CommonParams.Q, - "{!knn f=vector_multivalued topK=5}" + FLOAT_QUERY_VECTOR, - "fl", - "id,vector_multivalued,score", - "fq", - "_text_:(b OR c)"), - "//result[@numFound='4']", - "//result/doc[1]/str[@name='id'][.='8']", - "//result/doc[1]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[1][.='8.0']", - "//result/doc[1]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[2][.='1.0']", - "//result/doc[1]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[3][.='1.0']", - "//result/doc[1]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[4][.='1.0']", - "//result/doc[2]/str[@name='id'][.='7']", - "//result/doc[2]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[1][.='11.0']", - "//result/doc[2]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[2][.='1.0']", - "//result/doc[2]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[3][.='1.0']", - "//result/doc[2]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[4][.='1.0']", - "//result/doc[3]/str[@name='id'][.='2']", - "//result/doc[3]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[1][.='26.0']", - "//result/doc[3]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[2][.='1.0']", - "//result/doc[3]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[3][.='1.0']", - "//result/doc[3]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[4][.='1.0']", - "//result/doc[4]/str[@name='id'][.='1']", - "//result/doc[4]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[1][.='29.0']", - "//result/doc[4]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[2][.='1.0']", - "//result/doc[4]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[3][.='1.0']", - "//result/doc[4]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[4][.='1.0']"); - } - - @Test - public void topKWithChildTransformer_shouldUseOriginalChildTransformer() { - assertQ( - req( - CommonParams.Q, - "{!knn f=vector_multivalued topK=3}" + FLOAT_QUERY_VECTOR, - "fl", - "id,vector_multivalued,score,[child limit=2 fl=vector_multivalued]", - "fq", - "_text_:(b OR c)"), - "//result[@numFound='3']", - "//result/doc[1]/str[@name='id'][.='8']", - "//result/doc[1]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[1][.='10.0']", - "//result/doc[1]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[2][.='1.0']", - "//result/doc[1]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[3][.='1.0']", - "//result/doc[1]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[4][.='1.0']", - "//result/doc[1]/arr[@name='vector_multivalued'][1]/doc[2]/arr[@name='vector_multivalued']/float[1][.='9.0']", - "//result/doc[1]/arr[@name='vector_multivalued'][1]/doc[2]/arr[@name='vector_multivalued']/float[2][.='1.0']", - "//result/doc[1]/arr[@name='vector_multivalued'][1]/doc[2]/arr[@name='vector_multivalued']/float[3][.='1.0']", - "//result/doc[1]/arr[@name='vector_multivalued'][1]/doc[2]/arr[@name='vector_multivalued']/float[4][.='1.0']", - "//result/doc[2]/str[@name='id'][.='7']", - "//result/doc[2]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[1][.='13.0']", - "//result/doc[2]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[2][.='1.0']", - "//result/doc[2]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[3][.='1.0']", - "//result/doc[2]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[4][.='1.0']", - "//result/doc[2]/arr[@name='vector_multivalued'][1]/doc[2]/arr[@name='vector_multivalued']/float[1][.='12.0']", - "//result/doc[2]/arr[@name='vector_multivalued'][1]/doc[2]/arr[@name='vector_multivalued']/float[2][.='1.0']", - "//result/doc[2]/arr[@name='vector_multivalued'][1]/doc[2]/arr[@name='vector_multivalued']/float[3][.='1.0']", - "//result/doc[2]/arr[@name='vector_multivalued'][1]/doc[2]/arr[@name='vector_multivalued']/float[4][.='1.0']", - "//result/doc[3]/str[@name='id'][.='2']", - "//result/doc[3]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[1][.='28.0']", - "//result/doc[3]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[2][.='1.0']", - "//result/doc[3]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[3][.='1.0']", - "//result/doc[3]/arr[@name='vector_multivalued'][1]/doc[1]/arr[@name='vector_multivalued']/float[4][.='1.0']", - "//result/doc[3]/arr[@name='vector_multivalued'][1]/doc[2]/arr[@name='vector_multivalued']/float[1][.='27.0']", - "//result/doc[3]/arr[@name='vector_multivalued'][1]/doc[2]/arr[@name='vector_multivalued']/float[2][.='1.0']", - "//result/doc[3]/arr[@name='vector_multivalued'][1]/doc[2]/arr[@name='vector_multivalued']/float[3][.='1.0']", - "//result/doc[3]/arr[@name='vector_multivalued'][1]/doc[2]/arr[@name='vector_multivalued']/float[4][.='1.0']"); - } -} From 6475c068ec5ffd57ffab5f7a53827351f83737ab Mon Sep 17 00:00:00 2001 From: Alessandro Benedetti Date: Wed, 17 Dec 2025 16:20:47 +0100 Subject: [PATCH 32/43] work in progress --- .../NestedUpdateProcessorFactory.java | 9 +- .../join/BlockJoinMultiValuedVectorsTest.java | 64 +- ...ockJoinNestedVectorsParentQParserTest.java | 574 +++++++++++++----- .../join/BlockJoinNestedVectorsTest.java | 34 +- 4 files changed, 479 insertions(+), 202 deletions(-) diff --git a/solr/core/src/java/org/apache/solr/update/processor/NestedUpdateProcessorFactory.java b/solr/core/src/java/org/apache/solr/update/processor/NestedUpdateProcessorFactory.java index 4e5637bda4cd..6367cebf474d 100644 --- a/solr/core/src/java/org/apache/solr/update/processor/NestedUpdateProcessorFactory.java +++ b/solr/core/src/java/org/apache/solr/update/processor/NestedUpdateProcessorFactory.java @@ -94,7 +94,7 @@ private boolean processDocChildren(SolrInputDocument doc, String fullPath) { int childNum = 0; boolean isSingleVal = !(field.getValue() instanceof Collection); boolean firstLevelChildren = fullPath == null; - if (firstLevelChildren && sfield!= null && isMultiValuedVectorField(sfield)) { + if (firstLevelChildren && sfield != null && isMultiValuedVectorField(sfield)) { for (Object vectorValue : field.getValues()) { SolrInputDocument singleVectorNestedDoc = new SolrInputDocument(); singleVectorNestedDoc.setField(field.getName(), vectorValue); @@ -153,14 +153,15 @@ private boolean processDocChildren(SolrInputDocument doc, String fullPath) { } } this.cleanOriginalVectorFields(doc, originalVectorFieldsToRemove); - if(vectors.size() > 0) { + if (vectors.size() > 0) { doc.setField("vectors", vectors); } return isNested; } - private void cleanOriginalVectorFields(SolrInputDocument doc, List originalVectorFieldsToRemove) { - for(String fieldName : originalVectorFieldsToRemove) { + private void cleanOriginalVectorFields( + SolrInputDocument doc, List originalVectorFieldsToRemove) { + for (String fieldName : originalVectorFieldsToRemove) { doc.removeField(fieldName); } } diff --git a/solr/core/src/test/org/apache/solr/search/join/BlockJoinMultiValuedVectorsTest.java b/solr/core/src/test/org/apache/solr/search/join/BlockJoinMultiValuedVectorsTest.java index 65c1d09ca8b0..7a9438ed1e8b 100644 --- a/solr/core/src/test/org/apache/solr/search/join/BlockJoinMultiValuedVectorsTest.java +++ b/solr/core/src/test/org/apache/solr/search/join/BlockJoinMultiValuedVectorsTest.java @@ -91,17 +91,42 @@ protected static List prepareDocs() { @Test public void parentRetrieval_knnChildrenDiversifyingWithNoAllParents_shouldThrowException() { - super.parentRetrieval_knnChildrenDiversifyingWithNoAllParents_shouldThrowException(VECTOR_FIELD); + super.parentRetrieval_knnChildrenDiversifyingWithNoAllParents_shouldThrowException( + VECTOR_FIELD); } @Test public void childrenRetrievalFloat_filteringByParentMetadata_shouldReturnKnnChildren() { - super.childrenRetrievalFloat_filteringByParentMetadata_shouldReturnKnnChildren(VECTOR_FIELD); + assertQ( + req( + "fq", "{!child of=$allParents filters=$parent.fq}", + "q", "{!knn f=" + VECTOR_FIELD + " topK=5}" + FLOAT_QUERY_VECTOR, + "fl", "id", + "parent.fq", "parent_s:(a c)", + "allParents", "parent_s:[* TO *]"), + "//*[@numFound='5']", + "//result/doc[1]/str[@name='id'][.='8/vector_byte_multivalued#2']", + "//result/doc[2]/str[@name='id'][.='8/vector_byte_multivalued#1']", + "//result/doc[3]/str[@name='id'][.='8/vector_byte_multivalued#0']", + "//result/doc[4]/str[@name='id'][.='6/vector_byte_multivalued#2']", + "//result/doc[5]/str[@name='id'][.='6/vector_byte_multivalued#1']"); } @Test public void childrenRetrievalByte_filteringByParentMetadata_shouldReturnKnnChildren() { - super.childrenRetrievalByte_filteringByParentMetadata_shouldReturnKnnChildren(VECTOR_BYTE_FIELD); + assertQ( + req( + "fq", "{!child of=$allParents filters=$parent.fq}", + "q", "{!knn f=" + VECTOR_BYTE_FIELD + " topK=5}" + BYTE_QUERY_VECTOR, + "fl", "id", + "parent.fq", "parent_s:(a c)", + "allParents", "parent_s:[* TO *]"), + "//*[@numFound='5']", + "//result/doc[1]/str[@name='id'][.='8/vector_byte_multivalued#2']", + "//result/doc[2]/str[@name='id'][.='8/vector_byte_multivalued#1']", + "//result/doc[3]/str[@name='id'][.='8/vector_byte_multivalued#0']", + "//result/doc[4]/str[@name='id'][.='6/vector_byte_multivalued#2']", + "//result/doc[5]/str[@name='id'][.='6/vector_byte_multivalued#1']"); } @Test @@ -119,13 +144,6 @@ public void parentRetrievalFloat_knnChildrenWithParentFilter_shouldReturnKnnPare super.parentRetrievalFloat_knnChildrenWithParentFilter_shouldReturnKnnParents(VECTOR_FIELD); } - @Test - public void - parentRetrievalFloat_knnChildrenWithParentFilterAndChildrenFilter_shouldReturnKnnParents( - ) { - super.parentRetrievalFloat_knnChildrenWithParentFilterAndChildrenFilter_shouldReturnKnnParents(VECTOR_FIELD); - } - @Test public void parentRetrievalByte_knnChildren_shouldReturnKnnParents() { super.parentRetrievalByte_knnChildren_shouldReturnKnnParents(VECTOR_BYTE_FIELD); @@ -138,28 +156,16 @@ public void parentRetrievalByte_knnChildrenWithParentFilter_shouldReturnKnnParen @Test public void - parentRetrievalByte_knnChildrenWithParentFilterAndChildrenFilter_shouldReturnKnnParents() { - super.parentRetrievalByte_knnChildrenWithParentFilterAndChildrenFilter_shouldReturnKnnParents(VECTOR_BYTE_FIELD); - } - - @Test - public void - parentRetrievalFloat_topKWithChildTransformerWithFilter_shouldUseOriginalChildTransformerFilter() { - super.parentRetrievalFloat_topKWithChildTransformerWithFilter_shouldUseOriginalChildTransformerFilter(VECTOR_FIELD); - } - - @Test - public void parentRetrievalFloat_topKWithChildTransformerWithFilter_shouldReturnBestChild() { - super.parentRetrievalFloat_topKWithChildTransformerWithFilter_shouldReturnBestChild(VECTOR_FIELD); - } - - @Test - public void parentRetrievalByte_topKWithChildTransformer_shouldReturnAllChildren() { + parentRetrievalByte_topKWithChildTransformer_shouldReturnAllChildren() { // new transformer + // all vectors super.parentRetrievalByte_topKWithChildTransformer_shouldReturnAllChildren(VECTOR_BYTE_FIELD); } @Test - public void parentRetrievalByte_topKWithChildTransformerWithFilter_shouldReturnBestChild() { - super.parentRetrievalByte_topKWithChildTransformerWithFilter_shouldReturnBestChild(VECTOR_BYTE_FIELD); + public void + parentRetrievalByte_topKWithChildTransformerWithFilter_shouldReturnBestChild() { // new + // trasnformer best vector + super.parentRetrievalByte_topKWithChildTransformerWithFilter_shouldReturnBestChild( + VECTOR_BYTE_FIELD); } } diff --git a/solr/core/src/test/org/apache/solr/search/join/BlockJoinNestedVectorsParentQParserTest.java b/solr/core/src/test/org/apache/solr/search/join/BlockJoinNestedVectorsParentQParserTest.java index 1b2bf02ddc52..dd904f594860 100644 --- a/solr/core/src/test/org/apache/solr/search/join/BlockJoinNestedVectorsParentQParserTest.java +++ b/solr/core/src/test/org/apache/solr/search/join/BlockJoinNestedVectorsParentQParserTest.java @@ -16,18 +16,15 @@ */ package org.apache.solr.search.join; -import org.apache.solr.SolrTestCaseJ4; -import org.junit.Test; - import java.util.ArrayList; import java.util.Arrays; import java.util.List; +import org.apache.solr.SolrTestCaseJ4; public class BlockJoinNestedVectorsParentQParserTest extends SolrTestCaseJ4 { protected static final List FLOAT_QUERY_VECTOR = Arrays.asList(1.0f, 1.0f, 1.0f, 1.0f); protected static final List BYTE_QUERY_VECTOR = Arrays.asList(1, 1, 1, 1); - protected static String VECTORS_PSEUDOFIELD = "vectors"; /** @@ -70,7 +67,8 @@ protected static List outDistanceByte(List vector, int value) return result; } - protected void parentRetrieval_knnChildrenDiversifyingWithNoAllParents_shouldThrowException(String vectorField) { + protected void parentRetrieval_knnChildrenDiversifyingWithNoAllParents_shouldThrowException( + String vectorField) { assertQEx( "When running a diversifying children KNN query, 'allParents' parameter is required", req( @@ -87,7 +85,8 @@ protected void parentRetrieval_knnChildrenDiversifyingWithNoAllParents_shouldThr 400); } - protected void childrenRetrievalFloat_filteringByParentMetadata_shouldReturnKnnChildren(String vectorField) { + protected void childrenRetrievalFloat_filteringByParentMetadata_shouldReturnKnnChildren( + String vectorField) { assertQ( req( "fq", "{!child of=$allParents filters=$parent.fq}", @@ -103,7 +102,8 @@ protected void childrenRetrievalFloat_filteringByParentMetadata_shouldReturnKnnC "//result/doc[5]/str[@name='id'][.='61']"); } - protected void childrenRetrievalByte_filteringByParentMetadata_shouldReturnKnnChildren(String vectorByteField) { + protected void childrenRetrievalByte_filteringByParentMetadata_shouldReturnKnnChildren( + String vectorByteField) { assertQ( req( "fq", "{!child of=$allParents filters=$parent.fq}", @@ -124,7 +124,8 @@ protected void parentRetrievalFloat_knnChildren_shouldReturnKnnParents(String ve req( "q", "{!parent which=$allParents score=max v=$children.q}", "fl", "id,score", - "children.q", "{!knn f=" + vectorField + " topK=3 allParents=$allParents}" + FLOAT_QUERY_VECTOR, + "children.q", + "{!knn f=" + vectorField + " topK=3 allParents=$allParents}" + FLOAT_QUERY_VECTOR, "allParents", "parent_s:[* TO *]"), "//*[@numFound='3']", "//result/doc[1]/str[@name='id'][.='10']", @@ -132,7 +133,8 @@ protected void parentRetrievalFloat_knnChildren_shouldReturnKnnParents(String ve "//result/doc[3]/str[@name='id'][.='8']"); } - protected void parentRetrievalFloat_knnChildrenWithNoDiversifying_shouldReturnOneParent(String vectorField) { + protected void parentRetrievalFloat_knnChildrenWithNoDiversifying_shouldReturnOneParent( + String vectorField) { assertQ( req( "q", "{!parent which=$allParents score=max v=$children.q}", @@ -143,13 +145,15 @@ protected void parentRetrievalFloat_knnChildrenWithNoDiversifying_shouldReturnOn "//result/doc[1]/str[@name='id'][.='10']"); } - protected void parentRetrievalFloat_knnChildrenWithParentFilter_shouldReturnKnnParents(String vectorField) { + protected void parentRetrievalFloat_knnChildrenWithParentFilter_shouldReturnKnnParents( + String vectorField) { assertQ( req( "q", "{!parent which=$allParents score=max v=$children.q}", "fl", "id,score", "children.q", - "{!knn f=" + vectorField + "{!knn f=" + + vectorField + " topK=3 childrenOf=$someParents allParents=$allParents}" + FLOAT_QUERY_VECTOR, "allParents", "parent_s:[* TO *]", @@ -162,13 +166,14 @@ protected void parentRetrievalFloat_knnChildrenWithParentFilter_shouldReturnKnnP protected void parentRetrievalFloat_knnChildrenWithParentFilterAndChildrenFilter_shouldReturnKnnParents( - String vectorField) { + String vectorField) { assertQ( req( "q", "{!parent which=$allParents score=max v=$children.q}", "fl", "id,score", "children.q", - "{!knn f=" + vectorField + "{!knn f=" + + vectorField + " topK=3 preFilter=child_s:m childrenOf=$someParents allParents=$allParents}" + FLOAT_QUERY_VECTOR, "allParents", "parent_s:[* TO *]", @@ -183,7 +188,11 @@ protected void parentRetrievalByte_knnChildren_shouldReturnKnnParents(String vec req( "q", "{!parent which=$allParents score=max v=$children.q}", "fl", "id,score", - "children.q", "{!knn f=" + vectorByteField + " topK=3 allParents=$allParents}" + BYTE_QUERY_VECTOR, + "children.q", + "{!knn f=" + + vectorByteField + + " topK=3 allParents=$allParents}" + + BYTE_QUERY_VECTOR, "allParents", "parent_s:[* TO *]"), "//*[@numFound='3']", "//result/doc[1]/str[@name='id'][.='10']", @@ -191,13 +200,15 @@ protected void parentRetrievalByte_knnChildren_shouldReturnKnnParents(String vec "//result/doc[3]/str[@name='id'][.='8']"); } - protected void parentRetrievalByte_knnChildrenWithParentFilter_shouldReturnKnnParents(String vectorByteField) { + protected void parentRetrievalByte_knnChildrenWithParentFilter_shouldReturnKnnParents( + String vectorByteField) { assertQ( req( "q", "{!parent which=$allParents score=max v=$children.q}", "fl", "id,score", "children.q", - "{!knn f=" + vectorByteField + "{!knn f=" + + vectorByteField + " topK=3 childrenOf=$someParents allParents=$allParents}" + BYTE_QUERY_VECTOR, "allParents", "parent_s:[* TO *]", @@ -209,13 +220,15 @@ protected void parentRetrievalByte_knnChildrenWithParentFilter_shouldReturnKnnPa } protected void - parentRetrievalByte_knnChildrenWithParentFilterAndChildrenFilter_shouldReturnKnnParents(String vectorByteField) { + parentRetrievalByte_knnChildrenWithParentFilterAndChildrenFilter_shouldReturnKnnParents( + String vectorByteField) { assertQ( req( "q", "{!parent which=$allParents score=max v=$children.q}", "fl", "id,score", "children.q", - "{!knn f=" + vectorByteField + "{!knn f=" + + vectorByteField + " topK=3 preFilter=child_s:m childrenOf=$someParents allParents=$allParents}" + BYTE_QUERY_VECTOR, "allParents", "parent_s:[* TO *]", @@ -227,222 +240,471 @@ protected void parentRetrievalByte_knnChildrenWithParentFilter_shouldReturnKnnPa protected void parentRetrievalFloat_topKWithChildTransformerWithFilter_shouldUseOriginalChildTransformerFilter( - String vectorField) { + String vectorField) { assertQ( req( "q", "{!parent which=$allParents score=max v=$children.q}", - "fl", "id,score," + VECTORS_PSEUDOFIELD + "," + vectorField - + ",[child limit=2 fl=vector]", + "fl", + "id,score," + + VECTORS_PSEUDOFIELD + + "," + + vectorField + + ",[child limit=2 fl=vector]", "children.q", - "{!knn f=" + vectorField + "{!knn f=" + + vectorField + " topK=3 childrenOf=$someParents allParents=$allParents}" + FLOAT_QUERY_VECTOR, "allParents", "parent_s:[* TO *]", "someParents", "parent_s:(a c)"), "//result[@numFound='3']", "//result/doc[1]/str[@name='id'][.='8']", - "//result/doc[1]/arr[@name='" + VECTORS_PSEUDOFIELD - + "'][1]/doc[1]/arr[@name='" + vectorField + "']/float[1][.='10.0']", - "//result/doc[1]/arr[@name='" + VECTORS_PSEUDOFIELD - + "'][1]/doc[1]/arr[@name='" + vectorField + "']/float[2][.='1.0']", - "//result/doc[1]/arr[@name='" + VECTORS_PSEUDOFIELD - + "'][1]/doc[1]/arr[@name='" + vectorField + "']/float[3][.='1.0']", - "//result/doc[1]/arr[@name='" + VECTORS_PSEUDOFIELD - + "'][1]/doc[1]/arr[@name='" + vectorField + "']/float[4][.='1.0']", - "//result/doc[1]/arr[@name='" + VECTORS_PSEUDOFIELD - + "'][1]/doc[2]/arr[@name='" + vectorField + "']/float[1][.='9.0']", - "//result/doc[1]/arr[@name='" + VECTORS_PSEUDOFIELD - + "'][1]/doc[2]/arr[@name='" + vectorField + "']/float[2][.='1.0']", - "//result/doc[1]/arr[@name='" + VECTORS_PSEUDOFIELD - + "'][1]/doc[2]/arr[@name='" + vectorField + "']/float[3][.='1.0']", - "//result/doc[1]/arr[@name='" + VECTORS_PSEUDOFIELD - + "'][1]/doc[2]/arr[@name='" + vectorField + "']/float[4][.='1.0']", + "//result/doc[1]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + vectorField + + "']/float[1][.='10.0']", + "//result/doc[1]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + vectorField + + "']/float[2][.='1.0']", + "//result/doc[1]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + vectorField + + "']/float[3][.='1.0']", + "//result/doc[1]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + vectorField + + "']/float[4][.='1.0']", + "//result/doc[1]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[2]/arr[@name='" + + vectorField + + "']/float[1][.='9.0']", + "//result/doc[1]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[2]/arr[@name='" + + vectorField + + "']/float[2][.='1.0']", + "//result/doc[1]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[2]/arr[@name='" + + vectorField + + "']/float[3][.='1.0']", + "//result/doc[1]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[2]/arr[@name='" + + vectorField + + "']/float[4][.='1.0']", "//result/doc[2]/str[@name='id'][.='6']", - "//result/doc[2]/arr[@name='" + VECTORS_PSEUDOFIELD - + "'][1]/doc[1]/arr[@name='" + vectorField + "']/float[1][.='16.0']", - "//result/doc[2]/arr[@name='" + VECTORS_PSEUDOFIELD - + "'][1]/doc[1]/arr[@name='" + vectorField + "']/float[2][.='1.0']", - "//result/doc[2]/arr[@name='" + VECTORS_PSEUDOFIELD - + "'][1]/doc[1]/arr[@name='" + vectorField + "']/float[3][.='1.0']", - "//result/doc[2]/arr[@name='" + VECTORS_PSEUDOFIELD - + "'][1]/doc[1]/arr[@name='" + vectorField + "']/float[4][.='1.0']", - "//result/doc[2]/arr[@name='" + VECTORS_PSEUDOFIELD - + "'][1]/doc[2]/arr[@name='" + vectorField + "']/float[1][.='15.0']", - "//result/doc[2]/arr[@name='" + VECTORS_PSEUDOFIELD - + "'][1]/doc[2]/arr[@name='" + vectorField + "']/float[2][.='1.0']", - "//result/doc[2]/arr[@name='" + VECTORS_PSEUDOFIELD - + "'][1]/doc[2]/arr[@name='" + vectorField + "']/float[3][.='1.0']", - "//result/doc[2]/arr[@name='" + VECTORS_PSEUDOFIELD - + "'][1]/doc[2]/arr[@name='" + vectorField + "']/float[4][.='1.0']", + "//result/doc[2]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + vectorField + + "']/float[1][.='16.0']", + "//result/doc[2]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + vectorField + + "']/float[2][.='1.0']", + "//result/doc[2]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + vectorField + + "']/float[3][.='1.0']", + "//result/doc[2]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + vectorField + + "']/float[4][.='1.0']", + "//result/doc[2]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[2]/arr[@name='" + + vectorField + + "']/float[1][.='15.0']", + "//result/doc[2]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[2]/arr[@name='" + + vectorField + + "']/float[2][.='1.0']", + "//result/doc[2]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[2]/arr[@name='" + + vectorField + + "']/float[3][.='1.0']", + "//result/doc[2]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[2]/arr[@name='" + + vectorField + + "']/float[4][.='1.0']", "//result/doc[3]/str[@name='id'][.='2']", - "//result/doc[3]/arr[@name='" + VECTORS_PSEUDOFIELD - + "'][1]/doc[1]/arr[@name='" + vectorField + "']/float[1][.='28.0']", - "//result/doc[3]/arr[@name='" + VECTORS_PSEUDOFIELD - + "'][1]/doc[1]/arr[@name='" + vectorField + "']/float[2][.='1.0']", - "//result/doc[3]/arr[@name='" + VECTORS_PSEUDOFIELD - + "'][1]/doc[1]/arr[@name='" + vectorField + "']/float[3][.='1.0']", - "//result/doc[3]/arr[@name='" + VECTORS_PSEUDOFIELD - + "'][1]/doc[1]/arr[@name='" + vectorField + "']/float[4][.='1.0']", - "//result/doc[3]/arr[@name='" + VECTORS_PSEUDOFIELD - + "'][1]/doc[2]/arr[@name='" + vectorField + "']/float[1][.='27.0']", - "//result/doc[3]/arr[@name='" + VECTORS_PSEUDOFIELD - + "'][1]/doc[2]/arr[@name='" + vectorField + "']/float[2][.='1.0']", - "//result/doc[3]/arr[@name='" + VECTORS_PSEUDOFIELD - + "'][1]/doc[2]/arr[@name='" + vectorField + "']/float[3][.='1.0']", - "//result/doc[3]/arr[@name='" + VECTORS_PSEUDOFIELD - + "'][1]/doc[2]/arr[@name='" + vectorField + "']/float[4][.='1.0']"); + "//result/doc[3]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + vectorField + + "']/float[1][.='28.0']", + "//result/doc[3]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + vectorField + + "']/float[2][.='1.0']", + "//result/doc[3]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + vectorField + + "']/float[3][.='1.0']", + "//result/doc[3]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + vectorField + + "']/float[4][.='1.0']", + "//result/doc[3]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[2]/arr[@name='" + + vectorField + + "']/float[1][.='27.0']", + "//result/doc[3]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[2]/arr[@name='" + + vectorField + + "']/float[2][.='1.0']", + "//result/doc[3]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[2]/arr[@name='" + + vectorField + + "']/float[3][.='1.0']", + "//result/doc[3]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[2]/arr[@name='" + + vectorField + + "']/float[4][.='1.0']"); } - protected void parentRetrievalFloat_topKWithChildTransformerWithFilter_shouldReturnBestChild(String vectorField) { + protected void parentRetrievalFloat_topKWithChildTransformerWithFilter_shouldReturnBestChild( + String vectorField) { assertQ( req( - "q", "{!parent which=$allParents score=max v=$children.q}", + "q", + "{!parent which=$allParents score=max v=$children.q}", "fl", - "id,score," + VECTORS_PSEUDOFIELD + "," + vectorField + "id,score," + + VECTORS_PSEUDOFIELD + + "," + + vectorField + ",[child fl=vector childFilter=$children.q]", "children.q", - "{!knn f=" + vectorField - + " topK=3 childrenOf=$someParents allParents=$allParents}" - + FLOAT_QUERY_VECTOR, - "allParents", "parent_s:[* TO *]", - "someParents", "parent_s:(b c)"), + "{!knn f=" + + vectorField + + " topK=3 childrenOf=$someParents allParents=$allParents}" + + FLOAT_QUERY_VECTOR, + "allParents", + "parent_s:[* TO *]", + "someParents", + "parent_s:(b c)"), "//result[@numFound='3']", "//result/doc[1]/str[@name='id'][.='8']", - "//result/doc[1]/arr[@name='" + VECTORS_PSEUDOFIELD - + "'][1]/doc[1]/arr[@name='" + vectorField + "']/float[1][.='8.0']", - "//result/doc[1]/arr[@name='" + VECTORS_PSEUDOFIELD - + "'][1]/doc[1]/arr[@name='" + vectorField + "']/float[2][.='1.0']", - "//result/doc[1]/arr[@name='" + VECTORS_PSEUDOFIELD - + "'][1]/doc[1]/arr[@name='" + vectorField + "']/float[3][.='1.0']", - "//result/doc[1]/arr[@name='" + VECTORS_PSEUDOFIELD - + "'][1]/doc[1]/arr[@name='" + vectorField + "']/float[4][.='1.0']", + "//result/doc[1]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + vectorField + + "']/float[1][.='8.0']", + "//result/doc[1]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + vectorField + + "']/float[2][.='1.0']", + "//result/doc[1]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + vectorField + + "']/float[3][.='1.0']", + "//result/doc[1]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + vectorField + + "']/float[4][.='1.0']", "//result/doc[2]/str[@name='id'][.='7']", - "//result/doc[2]/arr[@name='" + VECTORS_PSEUDOFIELD - + "'][1]/doc[1]/arr[@name='" + vectorField + "']/float[1][.='11.0']", - "//result/doc[2]/arr[@name='" + VECTORS_PSEUDOFIELD - + "'][1]/doc[1]/arr[@name='" + vectorField + "']/float[2][.='1.0']", - "//result/doc[2]/arr[@name='" + VECTORS_PSEUDOFIELD - + "'][1]/doc[1]/arr[@name='" + vectorField + "']/float[3][.='1.0']", - "//result/doc[2]/arr[@name='" + VECTORS_PSEUDOFIELD - + "'][1]/doc[1]/arr[@name='" + vectorField + "']/float[4][.='1.0']", + "//result/doc[2]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + vectorField + + "']/float[1][.='11.0']", + "//result/doc[2]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + vectorField + + "']/float[2][.='1.0']", + "//result/doc[2]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + vectorField + + "']/float[3][.='1.0']", + "//result/doc[2]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + vectorField + + "']/float[4][.='1.0']", "//result/doc[3]/str[@name='id'][.='2']", - "//result/doc[3]/arr[@name='" + VECTORS_PSEUDOFIELD - + "'][1]/doc[1]/arr[@name='" + vectorField + "']/float[1][.='26.0']", - "//result/doc[3]/arr[@name='" + VECTORS_PSEUDOFIELD - + "'][1]/doc[1]/arr[@name='" + vectorField + "']/float[2][.='1.0']", - "//result/doc[3]/arr[@name='" + VECTORS_PSEUDOFIELD - + "'][1]/doc[1]/arr[@name='" + vectorField + "']/float[3][.='1.0']", - "//result/doc[3]/arr[@name='" + VECTORS_PSEUDOFIELD - + "'][1]/doc[1]/arr[@name='" + vectorField + "']/float[4][.='1.0']"); + "//result/doc[3]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + vectorField + + "']/float[1][.='26.0']", + "//result/doc[3]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + vectorField + + "']/float[2][.='1.0']", + "//result/doc[3]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + vectorField + + "']/float[3][.='1.0']", + "//result/doc[3]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + vectorField + + "']/float[4][.='1.0']"); } - protected void parentRetrievalByte_topKWithChildTransformer_shouldReturnAllChildren(String vectorByteField) { + protected void parentRetrievalByte_topKWithChildTransformer_shouldReturnAllChildren( + String vectorByteField) { assertQ( req( - "q", "{!parent which=$allParents score=max v=$children.q}", + "q", + "{!parent which=$allParents score=max v=$children.q}", "fl", - "id,score," + VECTORS_PSEUDOFIELD + "," + vectorByteField + ",[child limit=2 fl=" - + vectorByteField +"]", + "id,score," + + VECTORS_PSEUDOFIELD + + "," + + vectorByteField + + ",[child limit=2 fl=" + + vectorByteField + + "]", "children.q", - "{!knn f=" + vectorByteField - + " topK=3 childrenOf=$someParents allParents=$allParents}" - + BYTE_QUERY_VECTOR, - "allParents", "parent_s:[* TO *]", - "someParents", "parent_s:(b c)"), + "{!knn f=" + + vectorByteField + + " topK=3 childrenOf=$someParents allParents=$allParents}" + + BYTE_QUERY_VECTOR, + "allParents", + "parent_s:[* TO *]", + "someParents", + "parent_s:(b c)"), "//result[@numFound='3']", "//result/doc[1]/str[@name='id'][.='8']", - "//result/doc[1]/arr[@name='" + VECTORS_PSEUDOFIELD + "'][1]/doc[1]/arr[@name='" + vectorByteField + "//result/doc[1]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + vectorByteField + "']/int[1][.='10']", - "//result/doc[1]/arr[@name='" + VECTORS_PSEUDOFIELD + "'][1]/doc[1]/arr[@name='" + vectorByteField + "//result/doc[1]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + vectorByteField + "']/int[2][.='1']", - "//result/doc[1]/arr[@name='" + VECTORS_PSEUDOFIELD + "'][1]/doc[1]/arr[@name='" + vectorByteField + "//result/doc[1]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + vectorByteField + "']/int[3][.='1']", - "//result/doc[1]/arr[@name='" + VECTORS_PSEUDOFIELD + "'][1]/doc[1]/arr[@name='" + vectorByteField + "//result/doc[1]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + vectorByteField + "']/int[4][.='1']", - "//result/doc[1]/arr[@name='" + VECTORS_PSEUDOFIELD + "'][1]/doc[2]/arr[@name='" + vectorByteField + "//result/doc[1]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[2]/arr[@name='" + + vectorByteField + "']/int[1][.='9']", - "//result/doc[1]/arr[@name='" + VECTORS_PSEUDOFIELD + "'][1]/doc[2]/arr[@name='" + vectorByteField + "//result/doc[1]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[2]/arr[@name='" + + vectorByteField + "']/int[2][.='1']", - "//result/doc[1]/arr[@name='" + VECTORS_PSEUDOFIELD + "'][1]/doc[2]/arr[@name='" + vectorByteField + "//result/doc[1]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[2]/arr[@name='" + + vectorByteField + "']/int[3][.='1']", - "//result/doc[1]/arr[@name='" + VECTORS_PSEUDOFIELD + "'][1]/doc[2]/arr[@name='" + vectorByteField + "//result/doc[1]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[2]/arr[@name='" + + vectorByteField + "']/int[4][.='1']", "//result/doc[2]/str[@name='id'][.='7']", - "//result/doc[2]/arr[@name='" + VECTORS_PSEUDOFIELD + "'][1]/doc[1]/arr[@name='" + vectorByteField + "//result/doc[2]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + vectorByteField + "']/int[1][.='13']", - "//result/doc[2]/arr[@name='" + VECTORS_PSEUDOFIELD + "'][1]/doc[1]/arr[@name='" + vectorByteField + "//result/doc[2]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + vectorByteField + "']/int[2][.='1']", - "//result/doc[2]/arr[@name='" + VECTORS_PSEUDOFIELD + "'][1]/doc[1]/arr[@name='" + vectorByteField + "//result/doc[2]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + vectorByteField + "']/int[3][.='1']", - "//result/doc[2]/arr[@name='" + VECTORS_PSEUDOFIELD + "'][1]/doc[1]/arr[@name='" + vectorByteField + "//result/doc[2]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + vectorByteField + "']/int[4][.='1']", - "//result/doc[2]/arr[@name='" + VECTORS_PSEUDOFIELD + "'][1]/doc[2]/arr[@name='" + vectorByteField + "//result/doc[2]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[2]/arr[@name='" + + vectorByteField + "']/int[1][.='12']", - "//result/doc[2]/arr[@name='" + VECTORS_PSEUDOFIELD + "'][1]/doc[2]/arr[@name='" + vectorByteField + "//result/doc[2]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[2]/arr[@name='" + + vectorByteField + "']/int[2][.='1']", - "//result/doc[2]/arr[@name='" + VECTORS_PSEUDOFIELD + "'][1]/doc[2]/arr[@name='" + vectorByteField + "//result/doc[2]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[2]/arr[@name='" + + vectorByteField + "']/int[3][.='1']", - "//result/doc[2]/arr[@name='" + VECTORS_PSEUDOFIELD + "'][1]/doc[2]/arr[@name='" + vectorByteField + "//result/doc[2]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[2]/arr[@name='" + + vectorByteField + "']/int[4][.='1']", "//result/doc[3]/str[@name='id'][.='2']", - "//result/doc[3]/arr[@name='" + VECTORS_PSEUDOFIELD + "'][1]/doc[1]/arr[@name='" + vectorByteField + "//result/doc[3]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + vectorByteField + "']/int[1][.='28']", - "//result/doc[3]/arr[@name='" + VECTORS_PSEUDOFIELD + "'][1]/doc[1]/arr[@name='" + vectorByteField + "//result/doc[3]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + vectorByteField + "']/int[2][.='1']", - "//result/doc[3]/arr[@name='" + VECTORS_PSEUDOFIELD + "'][1]/doc[1]/arr[@name='" + vectorByteField + "//result/doc[3]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + vectorByteField + "']/int[3][.='1']", - "//result/doc[3]/arr[@name='" + VECTORS_PSEUDOFIELD + "'][1]/doc[1]/arr[@name='" + vectorByteField + "//result/doc[3]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + vectorByteField + "']/int[4][.='1']", - "//result/doc[3]/arr[@name='" + VECTORS_PSEUDOFIELD + "'][1]/doc[2]/arr[@name='" + vectorByteField + "//result/doc[3]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[2]/arr[@name='" + + vectorByteField + "']/int[1][.='27']", - "//result/doc[3]/arr[@name='" + VECTORS_PSEUDOFIELD + "'][1]/doc[2]/arr[@name='" + vectorByteField + "//result/doc[3]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[2]/arr[@name='" + + vectorByteField + "']/int[2][.='1']", - "//result/doc[3]/arr[@name='" + VECTORS_PSEUDOFIELD + "'][1]/doc[2]/arr[@name='" + vectorByteField + "//result/doc[3]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[2]/arr[@name='" + + vectorByteField + "']/int[3][.='1']", - "//result/doc[3]/arr[@name='" + VECTORS_PSEUDOFIELD + "'][1]/doc[2]/arr[@name='" + vectorByteField + "//result/doc[3]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[2]/arr[@name='" + + vectorByteField + "']/int[4][.='1']"); } - protected void parentRetrievalByte_topKWithChildTransformerWithFilter_shouldReturnBestChild(String vectorByteField) { + protected void parentRetrievalByte_topKWithChildTransformerWithFilter_shouldReturnBestChild( + String vectorByteField) { assertQ( req( - "q", "{!parent which=$allParents score=max v=$children.q}", + "q", + "{!parent which=$allParents score=max v=$children.q}", "fl", - "id,score," + VECTORS_PSEUDOFIELD + "," + vectorByteField - + ",[child fl=" + vectorByteField + " childFilter=$children.q]", + "id,score," + + VECTORS_PSEUDOFIELD + + "," + + vectorByteField + + ",[child fl=" + + vectorByteField + + " childFilter=$children.q]", "children.q", - "{!knn f=" + vectorByteField - + " topK=3 childrenOf=$someParents allParents=$allParents}" - + BYTE_QUERY_VECTOR, - "allParents", "parent_s:[* TO *]", - "someParents", "parent_s:(b c)"), + "{!knn f=" + + vectorByteField + + " topK=3 childrenOf=$someParents allParents=$allParents}" + + BYTE_QUERY_VECTOR, + "allParents", + "parent_s:[* TO *]", + "someParents", + "parent_s:(b c)"), "//result[@numFound='3']", "//result/doc[1]/str[@name='id'][.='8']", - "//result/doc[1]/arr[@name='" + VECTORS_PSEUDOFIELD + "'][1]/doc[1]/arr[@name='" + vectorByteField + "//result/doc[1]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + vectorByteField + "']/int[1][.='8']", - "//result/doc[1]/arr[@name='" + VECTORS_PSEUDOFIELD + "'][1]/doc[1]/arr[@name='" + vectorByteField + "//result/doc[1]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + vectorByteField + "']/int[2][.='1']", - "//result/doc[1]/arr[@name='" + VECTORS_PSEUDOFIELD + "'][1]/doc[1]/arr[@name='" + vectorByteField + "//result/doc[1]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + vectorByteField + "']/int[3][.='1']", - "//result/doc[1]/arr[@name='" + VECTORS_PSEUDOFIELD + "'][1]/doc[1]/arr[@name='" + vectorByteField + "//result/doc[1]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + vectorByteField + "']/int[4][.='1']", "//result/doc[2]/str[@name='id'][.='7']", - "//result/doc[2]/arr[@name='" + VECTORS_PSEUDOFIELD + "'][1]/doc[1]/arr[@name='" + vectorByteField + "//result/doc[2]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + vectorByteField + "']/int[1][.='11']", - "//result/doc[2]/arr[@name='" + VECTORS_PSEUDOFIELD + "'][1]/doc[1]/arr[@name='" + vectorByteField + "//result/doc[2]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + vectorByteField + "']/int[2][.='1']", - "//result/doc[2]/arr[@name='" + VECTORS_PSEUDOFIELD + "'][1]/doc[1]/arr[@name='" + vectorByteField + "//result/doc[2]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + vectorByteField + "']/int[3][.='1']", - "//result/doc[2]/arr[@name='" + VECTORS_PSEUDOFIELD + "'][1]/doc[1]/arr[@name='" + vectorByteField + "//result/doc[2]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + vectorByteField + "']/int[4][.='1']", "//result/doc[3]/str[@name='id'][.='2']", - "//result/doc[3]/arr[@name='" + VECTORS_PSEUDOFIELD + "'][1]/doc[1]/arr[@name='" + vectorByteField + "//result/doc[3]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + vectorByteField + "']/int[1][.='26']", - "//result/doc[3]/arr[@name='" + VECTORS_PSEUDOFIELD + "'][1]/doc[1]/arr[@name='" + vectorByteField + "//result/doc[3]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + vectorByteField + "']/int[2][.='1']", - "//result/doc[3]/arr[@name='" + VECTORS_PSEUDOFIELD + "'][1]/doc[1]/arr[@name='" + vectorByteField + "//result/doc[3]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + vectorByteField + "']/int[3][.='1']", - "//result/doc[3]/arr[@name='" + VECTORS_PSEUDOFIELD + "'][1]/doc[1]/arr[@name='" + vectorByteField + "//result/doc[3]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + vectorByteField + "']/int[4][.='1']"); } } diff --git a/solr/core/src/test/org/apache/solr/search/join/BlockJoinNestedVectorsTest.java b/solr/core/src/test/org/apache/solr/search/join/BlockJoinNestedVectorsTest.java index d5168fb9c764..27c09a520642 100644 --- a/solr/core/src/test/org/apache/solr/search/join/BlockJoinNestedVectorsTest.java +++ b/solr/core/src/test/org/apache/solr/search/join/BlockJoinNestedVectorsTest.java @@ -81,7 +81,8 @@ private static List prepareDocs() { child.setField("id", i + "" + j); child.setField("child_s", klm[i % klm.length]); child.setField("vector", outDistanceFloat(FLOAT_QUERY_VECTOR, totalNestedVectors)); - child.setField("vector_byte_encoding", outDistanceByte(BYTE_QUERY_VECTOR, totalNestedVectors)); + child.setField( + "vector_byte_encoding", outDistanceByte(BYTE_QUERY_VECTOR, totalNestedVectors)); totalNestedVectors--; // the higher the id of the nested document, lower the distance with // the query vector children.add(child); @@ -95,7 +96,8 @@ private static List prepareDocs() { @Test public void parentRetrieval_knnChildrenDiversifyingWithNoAllParents_shouldThrowException() { - super.parentRetrieval_knnChildrenDiversifyingWithNoAllParents_shouldThrowException(VECTOR_FIELD); + super.parentRetrieval_knnChildrenDiversifyingWithNoAllParents_shouldThrowException( + VECTOR_FIELD); } @Test @@ -105,7 +107,8 @@ public void childrenRetrievalFloat_filteringByParentMetadata_shouldReturnKnnChil @Test public void childrenRetrievalByte_filteringByParentMetadata_shouldReturnKnnChildren() { - super.childrenRetrievalByte_filteringByParentMetadata_shouldReturnKnnChildren(VECTOR_BYTE_FIELD); + super.childrenRetrievalByte_filteringByParentMetadata_shouldReturnKnnChildren( + VECTOR_BYTE_FIELD); } @Test @@ -125,14 +128,14 @@ public void parentRetrievalFloat_knnChildrenWithParentFilter_shouldReturnKnnPare @Test public void - parentRetrievalFloat_knnChildrenWithParentFilterAndChildrenFilter_shouldReturnKnnParents( - ) { - super.parentRetrievalFloat_knnChildrenWithParentFilterAndChildrenFilter_shouldReturnKnnParents(VECTOR_FIELD); + parentRetrievalFloat_knnChildrenWithParentFilterAndChildrenFilter_shouldReturnKnnParents() { + super.parentRetrievalFloat_knnChildrenWithParentFilterAndChildrenFilter_shouldReturnKnnParents( + VECTOR_FIELD); } @Test public void parentRetrievalByte_knnChildren_shouldReturnKnnParents() { - super.parentRetrievalByte_knnChildren_shouldReturnKnnParents(VECTOR_BYTE_FIELD); + super.parentRetrievalByte_knnChildren_shouldReturnKnnParents(VECTOR_BYTE_FIELD); } @Test @@ -142,19 +145,23 @@ public void parentRetrievalByte_knnChildrenWithParentFilter_shouldReturnKnnParen @Test public void - parentRetrievalByte_knnChildrenWithParentFilterAndChildrenFilter_shouldReturnKnnParents() { - super.parentRetrievalByte_knnChildrenWithParentFilterAndChildrenFilter_shouldReturnKnnParents(VECTOR_BYTE_FIELD); + parentRetrievalByte_knnChildrenWithParentFilterAndChildrenFilter_shouldReturnKnnParents() { + super.parentRetrievalByte_knnChildrenWithParentFilterAndChildrenFilter_shouldReturnKnnParents( + VECTOR_BYTE_FIELD); } @Test public void - parentRetrievalFloat_topKWithChildTransformerWithFilter_shouldUseOriginalChildTransformerFilter() { - super.parentRetrievalFloat_topKWithChildTransformerWithFilter_shouldUseOriginalChildTransformerFilter(VECTOR_FIELD); + parentRetrievalFloat_topKWithChildTransformerWithFilter_shouldUseOriginalChildTransformerFilter() { + super + .parentRetrievalFloat_topKWithChildTransformerWithFilter_shouldUseOriginalChildTransformerFilter( + VECTOR_FIELD); } @Test public void parentRetrievalFloat_topKWithChildTransformerWithFilter_shouldReturnBestChild() { - super.parentRetrievalFloat_topKWithChildTransformerWithFilter_shouldReturnBestChild(VECTOR_FIELD); + super.parentRetrievalFloat_topKWithChildTransformerWithFilter_shouldReturnBestChild( + VECTOR_FIELD); } @Test @@ -164,6 +171,7 @@ public void parentRetrievalByte_topKWithChildTransformer_shouldReturnAllChildren @Test public void parentRetrievalByte_topKWithChildTransformerWithFilter_shouldReturnBestChild() { - super.parentRetrievalByte_topKWithChildTransformerWithFilter_shouldReturnBestChild(VECTOR_BYTE_FIELD); + super.parentRetrievalByte_topKWithChildTransformerWithFilter_shouldReturnBestChild( + VECTOR_BYTE_FIELD); } } From bcf9996257e9870490ce83ccbdf26d359438d903 Mon Sep 17 00:00:00 2001 From: Alessandro Benedetti Date: Thu, 18 Dec 2025 12:31:27 +0100 Subject: [PATCH 33/43] work in progress --- .../transform/ChildDocTransformer.java | 36 +++- .../org/apache/solr/schema/IndexSchema.java | 1 + .../NestedUpdateProcessorFactory.java | 4 +- .../join/BlockJoinMultiValuedVectorsTest.java | 24 +-- ...ockJoinNestedVectorsParentQParserTest.java | 168 +----------------- .../join/BlockJoinNestedVectorsTest.java | 105 +++++++++-- .../org/apache/solr/common/SolrDocument.java | 2 + 7 files changed, 150 insertions(+), 190 deletions(-) diff --git a/solr/core/src/java/org/apache/solr/response/transform/ChildDocTransformer.java b/solr/core/src/java/org/apache/solr/response/transform/ChildDocTransformer.java index 204aa7a6190c..2819c98fe384 100644 --- a/solr/core/src/java/org/apache/solr/response/transform/ChildDocTransformer.java +++ b/solr/core/src/java/org/apache/solr/response/transform/ChildDocTransformer.java @@ -19,14 +19,17 @@ import static org.apache.solr.response.transform.ChildDocTransformerFactory.NUM_SEP_CHAR; import static org.apache.solr.response.transform.ChildDocTransformerFactory.PATH_SEP_CHAR; +import static org.apache.solr.schema.IndexSchema.NESTED_VECTORS_PSEUDO_FIELD_NAME; import static org.apache.solr.schema.IndexSchema.NEST_PATH_FIELD_NAME; import java.io.IOException; import java.lang.invoke.MethodHandles; import java.util.ArrayList; +import java.util.Collection; import java.util.HashMap; import java.util.List; import java.util.Map; +import org.apache.lucene.document.StoredField; import org.apache.lucene.index.DocValues; import org.apache.lucene.index.LeafReader; import org.apache.lucene.index.LeafReaderContext; @@ -64,6 +67,8 @@ class ChildDocTransformer extends DocTransformer { private final boolean isNestedSchema; private final SolrReturnFields childReturnFields; private final String[] extraRequestedFields; + private final String multiValuedVectorField = "vector"; + ChildDocTransformer( String name, @@ -219,8 +224,12 @@ public void transform(SolrDocument rootDoc, int rootDocId, DocIterationInfo docI if (isAncestor) { // if this path has pending child docs, add them. - addChildrenToParent( - doc, pendingParentPathsToChildren.remove(fullDocPath)); // no longer pending + if(multiValuedVectorField != null) { + addFlatChildrenToParent(doc, pendingParentPathsToChildren.remove(fullDocPath)); + } else { + addChildrenToParent( + doc, pendingParentPathsToChildren.remove(fullDocPath)); // no longer pending + } } // get parent path @@ -248,7 +257,11 @@ public void transform(SolrDocument rootDoc, int rootDocId, DocIterationInfo docI assert pendingParentPathsToChildren.keySet().size() == 1; // size == 1, so get the last remaining entry - addChildrenToParent(rootDoc, pendingParentPathsToChildren.values().iterator().next()); + if(multiValuedVectorField != null) { + addFlatChildrenToParent(rootDoc, pendingParentPathsToChildren.values().iterator().next()); + } else { + addChildrenToParent(rootDoc, pendingParentPathsToChildren.values().iterator().next()); + } } catch (IOException e) { // TODO DWS: reconsider this unusual error handling approach; shouldn't we rethrow? @@ -285,6 +298,23 @@ private static void addChildrenToParent( parent.setField(trimmedPath, children.get(0)); } + private void addFlatChildrenToParent( + SolrDocument parent, Map> children) { + List solrDocuments = children.get(NESTED_VECTORS_PSEUDO_FIELD_NAME); + for(SolrDocument singleVector: solrDocuments){ + parent.addField(multiValuedVectorField, this.extractVector(singleVector.getFieldValues(multiValuedVectorField))); + } + } + + private Object extractVector(Collection fieldValues) { + List vector = new ArrayList<>(fieldValues.size()); + for (Object fieldValue : fieldValues) { + StoredField storedVectorValue = (StoredField) fieldValue; + vector.add(storedVectorValue.numericValue()); + } + return vector; + } + private static String getLastPath(String path) { int lastIndexOfPathSepChar = path.lastIndexOf(PATH_SEP_CHAR); if (lastIndexOfPathSepChar == -1) { diff --git a/solr/core/src/java/org/apache/solr/schema/IndexSchema.java b/solr/core/src/java/org/apache/solr/schema/IndexSchema.java index 3372bcae0650..ea6a18612295 100644 --- a/solr/core/src/java/org/apache/solr/schema/IndexSchema.java +++ b/solr/core/src/java/org/apache/solr/schema/IndexSchema.java @@ -106,6 +106,7 @@ public class IndexSchema { public static final String NAME = "name"; public static final String NEST_PARENT_FIELD_NAME = "_nest_parent_"; public static final String NEST_PATH_FIELD_NAME = "_nest_path_"; + public static final String NESTED_VECTORS_PSEUDO_FIELD_NAME = "vectors";//"_nested_vectors_"; public static final String REQUIRED = "required"; public static final String SCHEMA = "schema"; public static final String SIMILARITY = "similarity"; diff --git a/solr/core/src/java/org/apache/solr/update/processor/NestedUpdateProcessorFactory.java b/solr/core/src/java/org/apache/solr/update/processor/NestedUpdateProcessorFactory.java index 6367cebf474d..6432a8aa6149 100644 --- a/solr/core/src/java/org/apache/solr/update/processor/NestedUpdateProcessorFactory.java +++ b/solr/core/src/java/org/apache/solr/update/processor/NestedUpdateProcessorFactory.java @@ -31,6 +31,8 @@ import org.apache.solr.schema.SchemaField; import org.apache.solr.update.AddUpdateCommand; +import static org.apache.solr.schema.IndexSchema.NESTED_VECTORS_PSEUDO_FIELD_NAME; + /** * Adds fields to nested documents to support some nested search requirements. It can even generate * uniqueKey fields for nested docs. @@ -154,7 +156,7 @@ private boolean processDocChildren(SolrInputDocument doc, String fullPath) { } this.cleanOriginalVectorFields(doc, originalVectorFieldsToRemove); if (vectors.size() > 0) { - doc.setField("vectors", vectors); + doc.setField(NESTED_VECTORS_PSEUDO_FIELD_NAME, vectors); } return isNested; } diff --git a/solr/core/src/test/org/apache/solr/search/join/BlockJoinMultiValuedVectorsTest.java b/solr/core/src/test/org/apache/solr/search/join/BlockJoinMultiValuedVectorsTest.java index 7a9438ed1e8b..ded54f7aaf8b 100644 --- a/solr/core/src/test/org/apache/solr/search/join/BlockJoinMultiValuedVectorsTest.java +++ b/solr/core/src/test/org/apache/solr/search/join/BlockJoinMultiValuedVectorsTest.java @@ -105,11 +105,11 @@ public void childrenRetrievalFloat_filteringByParentMetadata_shouldReturnKnnChil "parent.fq", "parent_s:(a c)", "allParents", "parent_s:[* TO *]"), "//*[@numFound='5']", - "//result/doc[1]/str[@name='id'][.='8/vector_byte_multivalued#2']", - "//result/doc[2]/str[@name='id'][.='8/vector_byte_multivalued#1']", - "//result/doc[3]/str[@name='id'][.='8/vector_byte_multivalued#0']", - "//result/doc[4]/str[@name='id'][.='6/vector_byte_multivalued#2']", - "//result/doc[5]/str[@name='id'][.='6/vector_byte_multivalued#1']"); + "//result/doc[1]/str[@name='id'][.='8/vector_multivalued#2']", + "//result/doc[2]/str[@name='id'][.='8/vector_multivalued#1']", + "//result/doc[3]/str[@name='id'][.='8/vector_multivalued#0']", + "//result/doc[4]/str[@name='id'][.='6/vector_multivalued#2']", + "//result/doc[5]/str[@name='id'][.='6/vector_multivalued#1']"); } @Test @@ -131,7 +131,7 @@ public void childrenRetrievalByte_filteringByParentMetadata_shouldReturnKnnChild @Test public void parentRetrievalFloat_knnChildren_shouldReturnKnnParents() { - super.parentRetrievalFloat_knnChildren_shouldReturnKnnParents(VECTOR_FIELD); + super.parentRetrieval_knnChildren_shouldReturnKnnParents(VECTOR_FIELD); } @Test @@ -141,31 +141,31 @@ public void parentRetrievalFloat_knnChildrenWithNoDiversifying_shouldReturnOnePa @Test public void parentRetrievalFloat_knnChildrenWithParentFilter_shouldReturnKnnParents() { - super.parentRetrievalFloat_knnChildrenWithParentFilter_shouldReturnKnnParents(VECTOR_FIELD); + super.parentRetrieval_knnChildrenWithParentFilter_shouldReturnKnnParents(VECTOR_FIELD); } @Test public void parentRetrievalByte_knnChildren_shouldReturnKnnParents() { - super.parentRetrievalByte_knnChildren_shouldReturnKnnParents(VECTOR_BYTE_FIELD); + super.parentRetrieval_knnChildren_shouldReturnKnnParents(VECTOR_BYTE_FIELD); } @Test public void parentRetrievalByte_knnChildrenWithParentFilter_shouldReturnKnnParents() { - super.parentRetrievalByte_knnChildrenWithParentFilter_shouldReturnKnnParents(VECTOR_BYTE_FIELD); + super.parentRetrieval_knnChildrenWithParentFilter_shouldReturnKnnParents(VECTOR_BYTE_FIELD); } @Test public void parentRetrievalByte_topKWithChildTransformer_shouldReturnAllChildren() { // new transformer // all vectors - super.parentRetrievalByte_topKWithChildTransformer_shouldReturnAllChildren(VECTOR_BYTE_FIELD); + //super.parentRetrievalByte_topKWithChildTransformer_shouldReturnAllChildren(VECTOR_BYTE_FIELD); } @Test public void parentRetrievalByte_topKWithChildTransformerWithFilter_shouldReturnBestChild() { // new // trasnformer best vector - super.parentRetrievalByte_topKWithChildTransformerWithFilter_shouldReturnBestChild( - VECTOR_BYTE_FIELD); + //super.parentRetrievalByte_topKWithChildTransformerWithFilter_shouldReturnBestChild( + // VECTOR_BYTE_FIELD); } } diff --git a/solr/core/src/test/org/apache/solr/search/join/BlockJoinNestedVectorsParentQParserTest.java b/solr/core/src/test/org/apache/solr/search/join/BlockJoinNestedVectorsParentQParserTest.java index dd904f594860..173e7502d929 100644 --- a/solr/core/src/test/org/apache/solr/search/join/BlockJoinNestedVectorsParentQParserTest.java +++ b/solr/core/src/test/org/apache/solr/search/join/BlockJoinNestedVectorsParentQParserTest.java @@ -85,12 +85,12 @@ protected void parentRetrieval_knnChildrenDiversifyingWithNoAllParents_shouldThr 400); } - protected void childrenRetrievalFloat_filteringByParentMetadata_shouldReturnKnnChildren( + protected void childrenRetrieval_filteringByParentMetadata_shouldReturnKnnChildren( String vectorField) { assertQ( req( "fq", "{!child of=$allParents filters=$parent.fq}", - "q", "{!knn f=" + vectorField + " topK=5}" + FLOAT_QUERY_VECTOR, + "q", "{!knn f=" + vectorField + " topK=5}" + BYTE_QUERY_VECTOR, "fl", "id", "parent.fq", "parent_s:(a c)", "allParents", "parent_s:[* TO *]"), @@ -102,37 +102,6 @@ protected void childrenRetrievalFloat_filteringByParentMetadata_shouldReturnKnnC "//result/doc[5]/str[@name='id'][.='61']"); } - protected void childrenRetrievalByte_filteringByParentMetadata_shouldReturnKnnChildren( - String vectorByteField) { - assertQ( - req( - "fq", "{!child of=$allParents filters=$parent.fq}", - "q", "{!knn f=" + vectorByteField + " topK=5}" + BYTE_QUERY_VECTOR, - "fl", "id", - "parent.fq", "parent_s:(a c)", - "allParents", "parent_s:[* TO *]"), - "//*[@numFound='5']", - "//result/doc[1]/str[@name='id'][.='82']", - "//result/doc[2]/str[@name='id'][.='81']", - "//result/doc[3]/str[@name='id'][.='80']", - "//result/doc[4]/str[@name='id'][.='62']", - "//result/doc[5]/str[@name='id'][.='61']"); - } - - protected void parentRetrievalFloat_knnChildren_shouldReturnKnnParents(String vectorField) { - assertQ( - req( - "q", "{!parent which=$allParents score=max v=$children.q}", - "fl", "id,score", - "children.q", - "{!knn f=" + vectorField + " topK=3 allParents=$allParents}" + FLOAT_QUERY_VECTOR, - "allParents", "parent_s:[* TO *]"), - "//*[@numFound='3']", - "//result/doc[1]/str[@name='id'][.='10']", - "//result/doc[2]/str[@name='id'][.='9']", - "//result/doc[3]/str[@name='id'][.='8']"); - } - protected void parentRetrievalFloat_knnChildrenWithNoDiversifying_shouldReturnOneParent( String vectorField) { assertQ( @@ -145,45 +114,7 @@ protected void parentRetrievalFloat_knnChildrenWithNoDiversifying_shouldReturnOn "//result/doc[1]/str[@name='id'][.='10']"); } - protected void parentRetrievalFloat_knnChildrenWithParentFilter_shouldReturnKnnParents( - String vectorField) { - assertQ( - req( - "q", "{!parent which=$allParents score=max v=$children.q}", - "fl", "id,score", - "children.q", - "{!knn f=" - + vectorField - + " topK=3 childrenOf=$someParents allParents=$allParents}" - + FLOAT_QUERY_VECTOR, - "allParents", "parent_s:[* TO *]", - "someParents", "parent_s:(a c)"), - "//*[@numFound='3']", - "//result/doc[1]/str[@name='id'][.='8']", - "//result/doc[2]/str[@name='id'][.='6']", - "//result/doc[3]/str[@name='id'][.='2']"); - } - - protected void - parentRetrievalFloat_knnChildrenWithParentFilterAndChildrenFilter_shouldReturnKnnParents( - String vectorField) { - assertQ( - req( - "q", "{!parent which=$allParents score=max v=$children.q}", - "fl", "id,score", - "children.q", - "{!knn f=" - + vectorField - + " topK=3 preFilter=child_s:m childrenOf=$someParents allParents=$allParents}" - + FLOAT_QUERY_VECTOR, - "allParents", "parent_s:[* TO *]", - "someParents", "parent_s:(a c)"), - "//*[@numFound='2']", - "//result/doc[1]/str[@name='id'][.='8']", - "//result/doc[2]/str[@name='id'][.='2']"); - } - - protected void parentRetrievalByte_knnChildren_shouldReturnKnnParents(String vectorByteField) { + protected void parentRetrieval_knnChildren_shouldReturnKnnParents(String vectorByteField) { assertQ( req( "q", "{!parent which=$allParents score=max v=$children.q}", @@ -200,7 +131,7 @@ protected void parentRetrievalByte_knnChildren_shouldReturnKnnParents(String vec "//result/doc[3]/str[@name='id'][.='8']"); } - protected void parentRetrievalByte_knnChildrenWithParentFilter_shouldReturnKnnParents( + protected void parentRetrieval_knnChildrenWithParentFilter_shouldReturnKnnParents( String vectorByteField) { assertQ( req( @@ -220,7 +151,7 @@ protected void parentRetrievalByte_knnChildrenWithParentFilter_shouldReturnKnnPa } protected void - parentRetrievalByte_knnChildrenWithParentFilterAndChildrenFilter_shouldReturnKnnParents( + parentRetrieval_knnChildrenWithParentFilterAndChildrenFilter_shouldReturnKnnParents( String vectorByteField) { assertQ( req( @@ -239,7 +170,7 @@ protected void parentRetrievalByte_knnChildrenWithParentFilter_shouldReturnKnnPa } protected void - parentRetrievalFloat_topKWithChildTransformerWithFilter_shouldUseOriginalChildTransformerFilter( + parentRetrievalFloat_topKWithChildTransformer_shouldReturnAllChildren( String vectorField) { assertQ( req( @@ -383,93 +314,6 @@ protected void parentRetrievalByte_knnChildrenWithParentFilter_shouldReturnKnnPa + "']/float[4][.='1.0']"); } - protected void parentRetrievalFloat_topKWithChildTransformerWithFilter_shouldReturnBestChild( - String vectorField) { - assertQ( - req( - "q", - "{!parent which=$allParents score=max v=$children.q}", - "fl", - "id,score," - + VECTORS_PSEUDOFIELD - + "," - + vectorField - + ",[child fl=vector childFilter=$children.q]", - "children.q", - "{!knn f=" - + vectorField - + " topK=3 childrenOf=$someParents allParents=$allParents}" - + FLOAT_QUERY_VECTOR, - "allParents", - "parent_s:[* TO *]", - "someParents", - "parent_s:(b c)"), - "//result[@numFound='3']", - "//result/doc[1]/str[@name='id'][.='8']", - "//result/doc[1]/arr[@name='" - + VECTORS_PSEUDOFIELD - + "'][1]/doc[1]/arr[@name='" - + vectorField - + "']/float[1][.='8.0']", - "//result/doc[1]/arr[@name='" - + VECTORS_PSEUDOFIELD - + "'][1]/doc[1]/arr[@name='" - + vectorField - + "']/float[2][.='1.0']", - "//result/doc[1]/arr[@name='" - + VECTORS_PSEUDOFIELD - + "'][1]/doc[1]/arr[@name='" - + vectorField - + "']/float[3][.='1.0']", - "//result/doc[1]/arr[@name='" - + VECTORS_PSEUDOFIELD - + "'][1]/doc[1]/arr[@name='" - + vectorField - + "']/float[4][.='1.0']", - "//result/doc[2]/str[@name='id'][.='7']", - "//result/doc[2]/arr[@name='" - + VECTORS_PSEUDOFIELD - + "'][1]/doc[1]/arr[@name='" - + vectorField - + "']/float[1][.='11.0']", - "//result/doc[2]/arr[@name='" - + VECTORS_PSEUDOFIELD - + "'][1]/doc[1]/arr[@name='" - + vectorField - + "']/float[2][.='1.0']", - "//result/doc[2]/arr[@name='" - + VECTORS_PSEUDOFIELD - + "'][1]/doc[1]/arr[@name='" - + vectorField - + "']/float[3][.='1.0']", - "//result/doc[2]/arr[@name='" - + VECTORS_PSEUDOFIELD - + "'][1]/doc[1]/arr[@name='" - + vectorField - + "']/float[4][.='1.0']", - "//result/doc[3]/str[@name='id'][.='2']", - "//result/doc[3]/arr[@name='" - + VECTORS_PSEUDOFIELD - + "'][1]/doc[1]/arr[@name='" - + vectorField - + "']/float[1][.='26.0']", - "//result/doc[3]/arr[@name='" - + VECTORS_PSEUDOFIELD - + "'][1]/doc[1]/arr[@name='" - + vectorField - + "']/float[2][.='1.0']", - "//result/doc[3]/arr[@name='" - + VECTORS_PSEUDOFIELD - + "'][1]/doc[1]/arr[@name='" - + vectorField - + "']/float[3][.='1.0']", - "//result/doc[3]/arr[@name='" - + VECTORS_PSEUDOFIELD - + "'][1]/doc[1]/arr[@name='" - + vectorField - + "']/float[4][.='1.0']"); - } - protected void parentRetrievalByte_topKWithChildTransformer_shouldReturnAllChildren( String vectorByteField) { assertQ( diff --git a/solr/core/src/test/org/apache/solr/search/join/BlockJoinNestedVectorsTest.java b/solr/core/src/test/org/apache/solr/search/join/BlockJoinNestedVectorsTest.java index 27c09a520642..3e4683c407f5 100644 --- a/solr/core/src/test/org/apache/solr/search/join/BlockJoinNestedVectorsTest.java +++ b/solr/core/src/test/org/apache/solr/search/join/BlockJoinNestedVectorsTest.java @@ -102,18 +102,18 @@ public void parentRetrieval_knnChildrenDiversifyingWithNoAllParents_shouldThrowE @Test public void childrenRetrievalFloat_filteringByParentMetadata_shouldReturnKnnChildren() { - super.childrenRetrievalFloat_filteringByParentMetadata_shouldReturnKnnChildren(VECTOR_FIELD); + super.childrenRetrieval_filteringByParentMetadata_shouldReturnKnnChildren(VECTOR_FIELD); } @Test public void childrenRetrievalByte_filteringByParentMetadata_shouldReturnKnnChildren() { - super.childrenRetrievalByte_filteringByParentMetadata_shouldReturnKnnChildren( + super.childrenRetrieval_filteringByParentMetadata_shouldReturnKnnChildren( VECTOR_BYTE_FIELD); } @Test public void parentRetrievalFloat_knnChildren_shouldReturnKnnParents() { - super.parentRetrievalFloat_knnChildren_shouldReturnKnnParents(VECTOR_FIELD); + super.parentRetrieval_knnChildren_shouldReturnKnnParents(VECTOR_FIELD); } @Test @@ -123,45 +123,126 @@ public void parentRetrievalFloat_knnChildrenWithNoDiversifying_shouldReturnOnePa @Test public void parentRetrievalFloat_knnChildrenWithParentFilter_shouldReturnKnnParents() { - super.parentRetrievalFloat_knnChildrenWithParentFilter_shouldReturnKnnParents(VECTOR_FIELD); + super.parentRetrieval_knnChildrenWithParentFilter_shouldReturnKnnParents(VECTOR_FIELD); } @Test public void parentRetrievalFloat_knnChildrenWithParentFilterAndChildrenFilter_shouldReturnKnnParents() { - super.parentRetrievalFloat_knnChildrenWithParentFilterAndChildrenFilter_shouldReturnKnnParents( + super.parentRetrieval_knnChildrenWithParentFilterAndChildrenFilter_shouldReturnKnnParents( VECTOR_FIELD); } @Test public void parentRetrievalByte_knnChildren_shouldReturnKnnParents() { - super.parentRetrievalByte_knnChildren_shouldReturnKnnParents(VECTOR_BYTE_FIELD); + super.parentRetrieval_knnChildren_shouldReturnKnnParents(VECTOR_BYTE_FIELD); } @Test public void parentRetrievalByte_knnChildrenWithParentFilter_shouldReturnKnnParents() { - super.parentRetrievalByte_knnChildrenWithParentFilter_shouldReturnKnnParents(VECTOR_BYTE_FIELD); + super.parentRetrieval_knnChildrenWithParentFilter_shouldReturnKnnParents(VECTOR_BYTE_FIELD); } @Test public void parentRetrievalByte_knnChildrenWithParentFilterAndChildrenFilter_shouldReturnKnnParents() { - super.parentRetrievalByte_knnChildrenWithParentFilterAndChildrenFilter_shouldReturnKnnParents( + super.parentRetrieval_knnChildrenWithParentFilterAndChildrenFilter_shouldReturnKnnParents( VECTOR_BYTE_FIELD); } @Test public void - parentRetrievalFloat_topKWithChildTransformerWithFilter_shouldUseOriginalChildTransformerFilter() { + parentRetrievalFloat_topKWithChildTransformerWithFilter_shouldReturnAllChildren() { super - .parentRetrievalFloat_topKWithChildTransformerWithFilter_shouldUseOriginalChildTransformerFilter( + .parentRetrievalFloat_topKWithChildTransformer_shouldReturnAllChildren( VECTOR_FIELD); } @Test public void parentRetrievalFloat_topKWithChildTransformerWithFilter_shouldReturnBestChild() { - super.parentRetrievalFloat_topKWithChildTransformerWithFilter_shouldReturnBestChild( - VECTOR_FIELD); + assertQ( + req( + "q", + "{!parent which=$allParents score=max v=$children.q}", + "fl", + "id,score," + + VECTORS_PSEUDOFIELD + + "," + + VECTOR_FIELD + + ",[child fl=vector childFilter=$children.q]", + "children.q", + "{!knn f=" + + VECTOR_FIELD + + " topK=3 childrenOf=$someParents allParents=$allParents}" + + FLOAT_QUERY_VECTOR, + "allParents", + "parent_s:[* TO *]", + "someParents", + "parent_s:(b c)"), + "//result[@numFound='3']", + "//result/doc[1]/str[@name='id'][.='8']", + "//result/doc[1]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + VECTOR_FIELD + + "']/float[1][.='8.0']", + "//result/doc[1]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + VECTOR_FIELD + + "']/float[2][.='1.0']", + "//result/doc[1]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + VECTOR_FIELD + + "']/float[3][.='1.0']", + "//result/doc[1]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + VECTOR_FIELD + + "']/float[4][.='1.0']", + "//result/doc[2]/str[@name='id'][.='7']", + "//result/doc[2]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + VECTOR_FIELD + + "']/float[1][.='11.0']", + "//result/doc[2]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + VECTOR_FIELD + + "']/float[2][.='1.0']", + "//result/doc[2]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + VECTOR_FIELD + + "']/float[3][.='1.0']", + "//result/doc[2]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + VECTOR_FIELD + + "']/float[4][.='1.0']", + "//result/doc[3]/str[@name='id'][.='2']", + "//result/doc[3]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + VECTOR_FIELD + + "']/float[1][.='26.0']", + "//result/doc[3]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + VECTOR_FIELD + + "']/float[2][.='1.0']", + "//result/doc[3]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + VECTOR_FIELD + + "']/float[3][.='1.0']", + "//result/doc[3]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + VECTOR_FIELD + + "']/float[4][.='1.0']"); } @Test diff --git a/solr/solrj/src/java/org/apache/solr/common/SolrDocument.java b/solr/solrj/src/java/org/apache/solr/common/SolrDocument.java index 16ba4ee00dbf..e18cac9d35c1 100644 --- a/solr/solrj/src/java/org/apache/solr/common/SolrDocument.java +++ b/solr/solrj/src/java/org/apache/solr/common/SolrDocument.java @@ -161,6 +161,8 @@ public void addField(String name, Object value) { _fields.put(name, vals); } + + /////////////////////////////////////////////////////////////////// // Get the field values /////////////////////////////////////////////////////////////////// From c19c33f9ed072594f97b8f9035fde1bab2053719 Mon Sep 17 00:00:00 2001 From: Alessandro Benedetti Date: Thu, 18 Dec 2025 19:33:39 +0100 Subject: [PATCH 34/43] float multi valued vectors managed --- .../transform/ChildDocTransformer.java | 58 +++++++++++++++---- .../org/apache/solr/schema/IndexSchema.java | 2 +- .../join/BlockJoinMultiValuedVectorsTest.java | 50 ++++++++++++++-- 3 files changed, 95 insertions(+), 15 deletions(-) diff --git a/solr/core/src/java/org/apache/solr/response/transform/ChildDocTransformer.java b/solr/core/src/java/org/apache/solr/response/transform/ChildDocTransformer.java index 2819c98fe384..9662f675e1a7 100644 --- a/solr/core/src/java/org/apache/solr/response/transform/ChildDocTransformer.java +++ b/solr/core/src/java/org/apache/solr/response/transform/ChildDocTransformer.java @@ -27,8 +27,10 @@ import java.util.ArrayList; import java.util.Collection; import java.util.HashMap; +import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Set; import org.apache.lucene.document.StoredField; import org.apache.lucene.index.DocValues; import org.apache.lucene.index.LeafReader; @@ -45,7 +47,9 @@ import org.apache.lucene.util.BytesRef; import org.apache.solr.common.SolrDocument; import org.apache.solr.common.SolrException; +import org.apache.solr.schema.DenseVectorField; import org.apache.solr.schema.IndexSchema; +import org.apache.solr.schema.SchemaField; import org.apache.solr.search.BitsFilteredPostingsEnum; import org.apache.solr.search.DocIterationInfo; import org.apache.solr.search.DocSet; @@ -67,7 +71,6 @@ class ChildDocTransformer extends DocTransformer { private final boolean isNestedSchema; private final SolrReturnFields childReturnFields; private final String[] extraRequestedFields; - private final String multiValuedVectorField = "vector"; ChildDocTransformer( @@ -143,6 +146,8 @@ public void transform(SolrDocument rootDoc, int rootDocId, DocIterationInfo docI final Bits liveDocs = leafReaderContext.reader().getLiveDocs(); final int segBaseId = leafReaderContext.docBase; final int segRootId = rootDocId - segBaseId; + Set multiValuedVectorFields = + this.getMultiValuedVectorFields(searcher.getSchema(), childReturnFields); // can return be -1 and that's okay (happens for very first block) final int segPrevRootId; @@ -224,8 +229,11 @@ public void transform(SolrDocument rootDoc, int rootDocId, DocIterationInfo docI if (isAncestor) { // if this path has pending child docs, add them. - if(multiValuedVectorField != null) { - addFlatChildrenToParent(doc, pendingParentPathsToChildren.remove(fullDocPath)); + if (!multiValuedVectorFields.isEmpty()) { + addFlatChildrenToParent( + doc, + pendingParentPathsToChildren.remove(fullDocPath), + multiValuedVectorFields); } else { addChildrenToParent( doc, pendingParentPathsToChildren.remove(fullDocPath)); // no longer pending @@ -257,8 +265,11 @@ public void transform(SolrDocument rootDoc, int rootDocId, DocIterationInfo docI assert pendingParentPathsToChildren.keySet().size() == 1; // size == 1, so get the last remaining entry - if(multiValuedVectorField != null) { - addFlatChildrenToParent(rootDoc, pendingParentPathsToChildren.values().iterator().next()); + if (!multiValuedVectorFields.isEmpty()) { + addFlatChildrenToParent( + rootDoc, + pendingParentPathsToChildren.values().iterator().next(), + multiValuedVectorFields); } else { addChildrenToParent(rootDoc, pendingParentPathsToChildren.values().iterator().next()); } @@ -270,6 +281,27 @@ public void transform(SolrDocument rootDoc, int rootDocId, DocIterationInfo docI } } + private Set getMultiValuedVectorFields( + IndexSchema schema, + SolrReturnFields childReturnFields) { + Set multiValuedVectorsFields = new HashSet<>(); + for (String fieldName : childReturnFields.getExplicitlyRequestedFieldNames()) { + SchemaField sfield = schema.getFieldOrNull(fieldName); + if (sfield.getType() instanceof DenseVectorField && sfield.multiValued()) { + multiValuedVectorsFields.add(fieldName); + } + } + if (multiValuedVectorsFields.size() != childReturnFields.getExplicitlyRequestedFieldNames() + .size()) { + throw new SolrException( + SolrException.ErrorCode.BAD_REQUEST, + "When using the Child transformer to flatten nested vectors, all 'fl' must be " + + "multivalued vector fields"); + } else { + return multiValuedVectorsFields; + } + } + private static void addChildrenToParent( SolrDocument parent, Map> children) { for (Map.Entry> entry : children.entrySet()) { @@ -299,14 +331,20 @@ private static void addChildrenToParent( } private void addFlatChildrenToParent( - SolrDocument parent, Map> children) { - List solrDocuments = children.get(NESTED_VECTORS_PSEUDO_FIELD_NAME); - for(SolrDocument singleVector: solrDocuments){ - parent.addField(multiValuedVectorField, this.extractVector(singleVector.getFieldValues(multiValuedVectorField))); + SolrDocument parent, Map> children, + Set multiValuedVectorFields) { + for (String multiValuedVectorField : multiValuedVectorFields) { + List solrDocuments = children.get(multiValuedVectorField); + List> multiValuedVectors = new ArrayList<>(solrDocuments.size()); + for (SolrDocument singleVector : solrDocuments) { + multiValuedVectors.add(this.extractVector(singleVector.getFieldValues(multiValuedVectorField))); + } + parent.setField(multiValuedVectorField, multiValuedVectors); } } - private Object extractVector(Collection fieldValues) { + private List extractVector(Collection fieldValues) { + //manage Byte List vector = new ArrayList<>(fieldValues.size()); for (Object fieldValue : fieldValues) { StoredField storedVectorValue = (StoredField) fieldValue; diff --git a/solr/core/src/java/org/apache/solr/schema/IndexSchema.java b/solr/core/src/java/org/apache/solr/schema/IndexSchema.java index ea6a18612295..f0e4f2f49f94 100644 --- a/solr/core/src/java/org/apache/solr/schema/IndexSchema.java +++ b/solr/core/src/java/org/apache/solr/schema/IndexSchema.java @@ -106,7 +106,7 @@ public class IndexSchema { public static final String NAME = "name"; public static final String NEST_PARENT_FIELD_NAME = "_nest_parent_"; public static final String NEST_PATH_FIELD_NAME = "_nest_path_"; - public static final String NESTED_VECTORS_PSEUDO_FIELD_NAME = "vectors";//"_nested_vectors_"; + public static final String NESTED_VECTORS_PSEUDO_FIELD_NAME = "_nested_vectors_"; public static final String REQUIRED = "required"; public static final String SCHEMA = "schema"; public static final String SIMILARITY = "similarity"; diff --git a/solr/core/src/test/org/apache/solr/search/join/BlockJoinMultiValuedVectorsTest.java b/solr/core/src/test/org/apache/solr/search/join/BlockJoinMultiValuedVectorsTest.java index ded54f7aaf8b..4f525e66b774 100644 --- a/solr/core/src/test/org/apache/solr/search/join/BlockJoinMultiValuedVectorsTest.java +++ b/solr/core/src/test/org/apache/solr/search/join/BlockJoinMultiValuedVectorsTest.java @@ -163,9 +163,51 @@ public void parentRetrievalByte_knnChildrenWithParentFilter_shouldReturnKnnParen @Test public void - parentRetrievalByte_topKWithChildTransformerWithFilter_shouldReturnBestChild() { // new - // trasnformer best vector - //super.parentRetrievalByte_topKWithChildTransformerWithFilter_shouldReturnBestChild( - // VECTOR_BYTE_FIELD); + parentRetrievalFloat_ChildTransformerWithChildFilter_shouldFlattenAndReturnBestChild() { + assertQ( + req( + "q", "{!parent which=$allParents score=max v=$children.q}", + "fl", "id,"+VECTOR_FIELD+", [child fl="+ VECTOR_FIELD + " childFilter=$children.q]", + "children.q", + "{!knn f=" + + VECTOR_FIELD + + " topK=3 allParents=$allParents}" + + FLOAT_QUERY_VECTOR, + "allParents", "parent_s:[* TO *]"), + "//*[@numFound='3']", + "//result/doc[1]/str[@name='id'][.='10']", + "//result/doc[1]/arr[@name='" + VECTOR_FIELD + "']/arr/float[1][.='2.0']", + "//result/doc[1]/arr[@name='" + VECTOR_FIELD + "']/arr/float[2][.='1.0']", + "//result/doc[1]/arr[@name='" + VECTOR_FIELD + "']/arr/float[3][.='1.0']", + "//result/doc[1]/arr[@name='" + VECTOR_FIELD + "']/arr/float[4][.='1.0']", + "//result/doc[2]/str[@name='id'][.='9']", + "//result/doc[2]/arr[@name='" + VECTOR_FIELD + "']/arr/float[1][.='5.0']", + "//result/doc[2]/arr[@name='" + VECTOR_FIELD + "']/arr/float[2][.='1.0']", + "//result/doc[2]/arr[@name='" + VECTOR_FIELD + "']/arr/float[3][.='1.0']", + "//result/doc[2]/arr[@name='" + VECTOR_FIELD + "']/arr/float[4][.='1.0']", + "//result/doc[3]/str[@name='id'][.='8']", + "//result/doc[3]/arr[@name='" + VECTOR_FIELD + "']/arr/float[1][.='8.0']", + "//result/doc[3]/arr[@name='" + VECTOR_FIELD + "']/arr/float[2][.='1.0']", + "//result/doc[3]/arr[@name='" + VECTOR_FIELD + "']/arr/float[3][.='1.0']", + "//result/doc[3]/arr[@name='" + VECTOR_FIELD + "']/arr/float[4][.='1.0']"); + } + + @Test + public void + parentRetrievalByte_ChildTransformerWithChildFilter_shouldFlattenAndReturnBestChild() { + assertQ( + req( + "q", "{!parent which=$allParents score=max v=$children.q}", + "fl", "id,[child fl="+ VECTOR_BYTE_FIELD + " childFilter=$children.q]", + "children.q", + "{!knn f=" + + VECTOR_BYTE_FIELD + + " topK=3 allParents=$allParents}" + + BYTE_QUERY_VECTOR, + "allParents", "parent_s:[* TO *]"), + "//*[@numFound='3']", + "//result/doc[1]/str[@name='id'][.='10']", + "//result/doc[2]/str[@name='id'][.='9']", + "//result/doc[3]/str[@name='id'][.='8']"); } } From 80cf944a676c8b5660fc569f18cb0c64cd692e87 Mon Sep 17 00:00:00 2001 From: Alessandro Benedetti Date: Fri, 19 Dec 2025 23:33:39 +0100 Subject: [PATCH 35/43] first fully working draft, green tests --- .../transform/ChildDocTransformer.java | 83 ++++++++---- .../join/BlockJoinMultiValuedVectorsTest.java | 126 +++++++++++++++++- 2 files changed, 179 insertions(+), 30 deletions(-) diff --git a/solr/core/src/java/org/apache/solr/response/transform/ChildDocTransformer.java b/solr/core/src/java/org/apache/solr/response/transform/ChildDocTransformer.java index 9662f675e1a7..b9ae1e77737f 100644 --- a/solr/core/src/java/org/apache/solr/response/transform/ChildDocTransformer.java +++ b/solr/core/src/java/org/apache/solr/response/transform/ChildDocTransformer.java @@ -19,7 +19,6 @@ import static org.apache.solr.response.transform.ChildDocTransformerFactory.NUM_SEP_CHAR; import static org.apache.solr.response.transform.ChildDocTransformerFactory.PATH_SEP_CHAR; -import static org.apache.solr.schema.IndexSchema.NESTED_VECTORS_PSEUDO_FIELD_NAME; import static org.apache.solr.schema.IndexSchema.NEST_PATH_FIELD_NAME; import java.io.IOException; @@ -40,6 +39,7 @@ import org.apache.lucene.index.SortedDocValues; import org.apache.lucene.index.Terms; import org.apache.lucene.index.TermsEnum; +import org.apache.lucene.index.VectorEncoding; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.join.BitSetProducer; import org.apache.lucene.util.BitSet; @@ -146,8 +146,16 @@ public void transform(SolrDocument rootDoc, int rootDocId, DocIterationInfo docI final Bits liveDocs = leafReaderContext.reader().getLiveDocs(); final int segBaseId = leafReaderContext.docBase; final int segRootId = rootDocId - segBaseId; - Set multiValuedVectorFields = - this.getMultiValuedVectorFields(searcher.getSchema(), childReturnFields); + Set multiValuedFLoatVectorFields = + this.getMultiValuedVectorFields(searcher.getSchema(), childReturnFields, VectorEncoding.FLOAT32); + Set multiValuedByteVectorFields = + this.getMultiValuedVectorFields(searcher.getSchema(), childReturnFields, VectorEncoding.BYTE); + if ((multiValuedFLoatVectorFields.size() + multiValuedByteVectorFields.size()) != childReturnFields.getExplicitlyRequestedFieldNames() + .size()) { + throw new SolrException( + SolrException.ErrorCode.BAD_REQUEST, + "When using the Child transformer to flatten nested vectors, all 'fl' must be " + + "multivalued vector fields");} // can return be -1 and that's okay (happens for very first block) final int segPrevRootId; @@ -229,11 +237,15 @@ public void transform(SolrDocument rootDoc, int rootDocId, DocIterationInfo docI if (isAncestor) { // if this path has pending child docs, add them. - if (!multiValuedVectorFields.isEmpty()) { - addFlatChildrenToParent( - doc, - pendingParentPathsToChildren.remove(fullDocPath), - multiValuedVectorFields); + if (!multiValuedFLoatVectorFields.isEmpty() || !multiValuedByteVectorFields.isEmpty()) { + addFlatMultiValuedFloatVectorsToParent( + rootDoc, + pendingParentPathsToChildren.values().iterator().next(), + multiValuedFLoatVectorFields); + addFlatMultiValuedByteVectorsToParent( + rootDoc, + pendingParentPathsToChildren.values().iterator().next(), + multiValuedByteVectorFields); } else { addChildrenToParent( doc, pendingParentPathsToChildren.remove(fullDocPath)); // no longer pending @@ -265,11 +277,15 @@ public void transform(SolrDocument rootDoc, int rootDocId, DocIterationInfo docI assert pendingParentPathsToChildren.keySet().size() == 1; // size == 1, so get the last remaining entry - if (!multiValuedVectorFields.isEmpty()) { - addFlatChildrenToParent( + if (!multiValuedFLoatVectorFields.isEmpty() || !multiValuedByteVectorFields.isEmpty()) { + addFlatMultiValuedFloatVectorsToParent( rootDoc, pendingParentPathsToChildren.values().iterator().next(), - multiValuedVectorFields); + multiValuedFLoatVectorFields); + addFlatMultiValuedByteVectorsToParent( + rootDoc, + pendingParentPathsToChildren.values().iterator().next(), + multiValuedByteVectorFields); } else { addChildrenToParent(rootDoc, pendingParentPathsToChildren.values().iterator().next()); } @@ -281,25 +297,16 @@ public void transform(SolrDocument rootDoc, int rootDocId, DocIterationInfo docI } } - private Set getMultiValuedVectorFields( - IndexSchema schema, - SolrReturnFields childReturnFields) { + private Set getMultiValuedVectorFields(IndexSchema schema, + SolrReturnFields childReturnFields, VectorEncoding encoding) { Set multiValuedVectorsFields = new HashSet<>(); for (String fieldName : childReturnFields.getExplicitlyRequestedFieldNames()) { SchemaField sfield = schema.getFieldOrNull(fieldName); - if (sfield.getType() instanceof DenseVectorField && sfield.multiValued()) { + if (sfield.getType() instanceof DenseVectorField && sfield.multiValued() && ((DenseVectorField) sfield.getType()).getVectorEncoding() == encoding) { multiValuedVectorsFields.add(fieldName); } } - if (multiValuedVectorsFields.size() != childReturnFields.getExplicitlyRequestedFieldNames() - .size()) { - throw new SolrException( - SolrException.ErrorCode.BAD_REQUEST, - "When using the Child transformer to flatten nested vectors, all 'fl' must be " - + "multivalued vector fields"); - } else { return multiValuedVectorsFields; - } } private static void addChildrenToParent( @@ -330,21 +337,33 @@ private static void addChildrenToParent( parent.setField(trimmedPath, children.get(0)); } - private void addFlatChildrenToParent( + private void addFlatMultiValuedFloatVectorsToParent( SolrDocument parent, Map> children, Set multiValuedVectorFields) { for (String multiValuedVectorField : multiValuedVectorFields) { List solrDocuments = children.get(multiValuedVectorField); List> multiValuedVectors = new ArrayList<>(solrDocuments.size()); for (SolrDocument singleVector : solrDocuments) { - multiValuedVectors.add(this.extractVector(singleVector.getFieldValues(multiValuedVectorField))); + multiValuedVectors.add(this.extractFloatVector(singleVector.getFieldValues(multiValuedVectorField))); } parent.setField(multiValuedVectorField, multiValuedVectors); } } - private List extractVector(Collection fieldValues) { - //manage Byte + private void addFlatMultiValuedByteVectorsToParent( + SolrDocument parent, Map> children, + Set multiValuedVectorFields) { + for (String multiValuedVectorField : multiValuedVectorFields) { + List solrDocuments = children.get(multiValuedVectorField); + List> multiValuedVectors = new ArrayList<>(solrDocuments.size()); + for (SolrDocument singleVector : solrDocuments) { + multiValuedVectors.add(this.extractByteVector(singleVector.getFieldValues(multiValuedVectorField))); + } + parent.setField(multiValuedVectorField, multiValuedVectors); + } + } + + private List extractFloatVector(Collection fieldValues) { List vector = new ArrayList<>(fieldValues.size()); for (Object fieldValue : fieldValues) { StoredField storedVectorValue = (StoredField) fieldValue; @@ -353,6 +372,16 @@ private List extractVector(Collection fieldValues) { return vector; } + private List extractByteVector(Collection singleVector) { + StoredField vector = (StoredField) singleVector.iterator().next(); + BytesRef byteVector = vector.binaryValue(); + List extractedVector = new ArrayList<>(byteVector.length); + for (Byte element : byteVector.bytes) { + extractedVector.add(element.byteValue()); + } + return extractedVector; + } + private static String getLastPath(String path) { int lastIndexOfPathSepChar = path.lastIndexOf(PATH_SEP_CHAR); if (lastIndexOfPathSepChar == -1) { diff --git a/solr/core/src/test/org/apache/solr/search/join/BlockJoinMultiValuedVectorsTest.java b/solr/core/src/test/org/apache/solr/search/join/BlockJoinMultiValuedVectorsTest.java index 4f525e66b774..7c34d7842a6b 100644 --- a/solr/core/src/test/org/apache/solr/search/join/BlockJoinMultiValuedVectorsTest.java +++ b/solr/core/src/test/org/apache/solr/search/join/BlockJoinMultiValuedVectorsTest.java @@ -192,13 +192,67 @@ public void parentRetrievalByte_knnChildrenWithParentFilter_shouldReturnKnnParen "//result/doc[3]/arr[@name='" + VECTOR_FIELD + "']/arr/float[4][.='1.0']"); } + @Test + public void parentRetrievalFloat_ChildTransformer_shouldFlattenAndReturnAllChildren() { + assertQ( + req( + "q", "{!parent which=$allParents score=max v=$children.q}", + "fl", "id,"+VECTOR_FIELD+", [child fl="+ VECTOR_FIELD + " ]", + "children.q", + "{!knn f=" + + VECTOR_FIELD + + " topK=3 allParents=$allParents}" + + FLOAT_QUERY_VECTOR, + "allParents", "parent_s:[* TO *]"), + "//*[@numFound='3']", + "//result/doc[1]/str[@name='id'][.='10']", + "//result/doc[1]/arr[@name='" + VECTOR_FIELD + "']/arr[1]/float[1][.='4.0']", + "//result/doc[1]/arr[@name='" + VECTOR_FIELD + "']/arr[1]/float[2][.='1.0']", + "//result/doc[1]/arr[@name='" + VECTOR_FIELD + "']/arr[1]/float[3][.='1.0']", + "//result/doc[1]/arr[@name='" + VECTOR_FIELD + "']/arr[1]/float[4][.='1.0']", + "//result/doc[1]/arr[@name='" + VECTOR_FIELD + "']/arr[2]/float[1][.='3.0']", + "//result/doc[1]/arr[@name='" + VECTOR_FIELD + "']/arr[2]/float[2][.='1.0']", + "//result/doc[1]/arr[@name='" + VECTOR_FIELD + "']/arr[2]/float[3][.='1.0']", + "//result/doc[1]/arr[@name='" + VECTOR_FIELD + "']/arr[2]/float[4][.='1.0']", + "//result/doc[1]/arr[@name='" + VECTOR_FIELD + "']/arr[3]/float[1][.='2.0']", + "//result/doc[1]/arr[@name='" + VECTOR_FIELD + "']/arr[3]/float[2][.='1.0']", + "//result/doc[1]/arr[@name='" + VECTOR_FIELD + "']/arr[3]/float[3][.='1.0']", + "//result/doc[1]/arr[@name='" + VECTOR_FIELD + "']/arr[3]/float[4][.='1.0']", + "//result/doc[2]/str[@name='id'][.='9']", + "//result/doc[2]/arr[@name='" + VECTOR_FIELD + "']/arr[1]/float[1][.='7.0']", + "//result/doc[2]/arr[@name='" + VECTOR_FIELD + "']/arr[1]/float[2][.='1.0']", + "//result/doc[2]/arr[@name='" + VECTOR_FIELD + "']/arr[1]/float[3][.='1.0']", + "//result/doc[2]/arr[@name='" + VECTOR_FIELD + "']/arr[1]/float[4][.='1.0']", + "//result/doc[2]/arr[@name='" + VECTOR_FIELD + "']/arr[2]/float[1][.='6.0']", + "//result/doc[2]/arr[@name='" + VECTOR_FIELD + "']/arr[2]/float[2][.='1.0']", + "//result/doc[2]/arr[@name='" + VECTOR_FIELD + "']/arr[2]/float[3][.='1.0']", + "//result/doc[2]/arr[@name='" + VECTOR_FIELD + "']/arr[2]/float[4][.='1.0']", + "//result/doc[2]/arr[@name='" + VECTOR_FIELD + "']/arr[3]/float[1][.='5.0']", + "//result/doc[2]/arr[@name='" + VECTOR_FIELD + "']/arr[3]/float[2][.='1.0']", + "//result/doc[2]/arr[@name='" + VECTOR_FIELD + "']/arr[3]/float[3][.='1.0']", + "//result/doc[2]/arr[@name='" + VECTOR_FIELD + "']/arr[3]/float[4][.='1.0']", + "//result/doc[3]/str[@name='id'][.='8']", + "//result/doc[3]/arr[@name='" + VECTOR_FIELD + "']/arr[1]/float[1][.='10.0']", + "//result/doc[3]/arr[@name='" + VECTOR_FIELD + "']/arr[1]/float[2][.='1.0']", + "//result/doc[3]/arr[@name='" + VECTOR_FIELD + "']/arr[1]/float[3][.='1.0']", + "//result/doc[3]/arr[@name='" + VECTOR_FIELD + "']/arr[1]/float[4][.='1.0']", + "//result/doc[3]/arr[@name='" + VECTOR_FIELD + "']/arr[2]/float[1][.='9.0']", + "//result/doc[3]/arr[@name='" + VECTOR_FIELD + "']/arr[2]/float[2][.='1.0']", + "//result/doc[3]/arr[@name='" + VECTOR_FIELD + "']/arr[2]/float[3][.='1.0']", + "//result/doc[3]/arr[@name='" + VECTOR_FIELD + "']/arr[2]/float[4][.='1.0']", + "//result/doc[3]/arr[@name='" + VECTOR_FIELD + "']/arr[3]/float[1][.='8.0']", + "//result/doc[3]/arr[@name='" + VECTOR_FIELD + "']/arr[3]/float[2][.='1.0']", + "//result/doc[3]/arr[@name='" + VECTOR_FIELD + "']/arr[3]/float[3][.='1.0']", + "//result/doc[3]/arr[@name='" + VECTOR_FIELD + "']/arr[3]/float[4][.='1.0']"); + } + @Test public void - parentRetrievalByte_ChildTransformerWithChildFilter_shouldFlattenAndReturnBestChild() { + parentRetrievalByte_ChildTransformerWithChildFilter_shouldFlattenAndReturnBestChild() { assertQ( req( "q", "{!parent which=$allParents score=max v=$children.q}", - "fl", "id,[child fl="+ VECTOR_BYTE_FIELD + " childFilter=$children.q]", + "fl", "id,"+VECTOR_BYTE_FIELD+", [child fl="+ VECTOR_BYTE_FIELD + " childFilter=$children.q]", "children.q", "{!knn f=" + VECTOR_BYTE_FIELD @@ -207,7 +261,73 @@ public void parentRetrievalByte_knnChildrenWithParentFilter_shouldReturnKnnParen "allParents", "parent_s:[* TO *]"), "//*[@numFound='3']", "//result/doc[1]/str[@name='id'][.='10']", + "//result/doc[1]/arr[@name='" + VECTOR_BYTE_FIELD + "']/arr/int[1][.='2']", + "//result/doc[1]/arr[@name='" + VECTOR_BYTE_FIELD + "']/arr/int[2][.='1']", + "//result/doc[1]/arr[@name='" + VECTOR_BYTE_FIELD + "']/arr/int[3][.='1']", + "//result/doc[1]/arr[@name='" + VECTOR_BYTE_FIELD + "']/arr/int[4][.='1']", "//result/doc[2]/str[@name='id'][.='9']", - "//result/doc[3]/str[@name='id'][.='8']"); + "//result/doc[2]/arr[@name='" + VECTOR_BYTE_FIELD + "']/arr/int[1][.='5']", + "//result/doc[2]/arr[@name='" + VECTOR_BYTE_FIELD + "']/arr/int[2][.='1']", + "//result/doc[2]/arr[@name='" + VECTOR_BYTE_FIELD + "']/arr/int[3][.='1']", + "//result/doc[2]/arr[@name='" + VECTOR_BYTE_FIELD + "']/arr/int[4][.='1']", + "//result/doc[3]/str[@name='id'][.='8']", + "//result/doc[3]/arr[@name='" + VECTOR_BYTE_FIELD + "']/arr/int[1][.='8']", + "//result/doc[3]/arr[@name='" + VECTOR_BYTE_FIELD + "']/arr/int[2][.='1']", + "//result/doc[3]/arr[@name='" + VECTOR_BYTE_FIELD + "']/arr/int[3][.='1']", + "//result/doc[3]/arr[@name='" + VECTOR_BYTE_FIELD + "']/arr/int[4][.='1']"); + } + + @Test + public void parentRetrievalByte_ChildTransformer_shouldFlattenAndReturnAllChildren() { + assertQ( + req( + "q", "{!parent which=$allParents score=max v=$children.q}", + "fl", "id,"+VECTOR_BYTE_FIELD+", [child fl="+ VECTOR_BYTE_FIELD + " ]", + "children.q", + "{!knn f=" + + VECTOR_BYTE_FIELD + + " topK=3 allParents=$allParents}" + + BYTE_QUERY_VECTOR, + "allParents", "parent_s:[* TO *]"), + "//*[@numFound='3']", + "//result/doc[1]/str[@name='id'][.='10']", + "//result/doc[1]/arr[@name='" + VECTOR_BYTE_FIELD + "']/arr[1]/int[1][.='4']", + "//result/doc[1]/arr[@name='" + VECTOR_BYTE_FIELD + "']/arr[1]/int[2][.='1']", + "//result/doc[1]/arr[@name='" + VECTOR_BYTE_FIELD + "']/arr[1]/int[3][.='1']", + "//result/doc[1]/arr[@name='" + VECTOR_BYTE_FIELD + "']/arr[1]/int[4][.='1']", + "//result/doc[1]/arr[@name='" + VECTOR_BYTE_FIELD + "']/arr[2]/int[1][.='3']", + "//result/doc[1]/arr[@name='" + VECTOR_BYTE_FIELD + "']/arr[2]/int[2][.='1']", + "//result/doc[1]/arr[@name='" + VECTOR_BYTE_FIELD + "']/arr[2]/int[3][.='1']", + "//result/doc[1]/arr[@name='" + VECTOR_BYTE_FIELD + "']/arr[2]/int[4][.='1']", + "//result/doc[1]/arr[@name='" + VECTOR_BYTE_FIELD + "']/arr[3]/int[1][.='2']", + "//result/doc[1]/arr[@name='" + VECTOR_BYTE_FIELD + "']/arr[3]/int[2][.='1']", + "//result/doc[1]/arr[@name='" + VECTOR_BYTE_FIELD + "']/arr[3]/int[3][.='1']", + "//result/doc[1]/arr[@name='" + VECTOR_BYTE_FIELD + "']/arr[3]/int[4][.='1']", + "//result/doc[2]/str[@name='id'][.='9']", + "//result/doc[2]/arr[@name='" + VECTOR_BYTE_FIELD + "']/arr[1]/int[1][.='7']", + "//result/doc[2]/arr[@name='" + VECTOR_BYTE_FIELD + "']/arr[1]/int[2][.='1']", + "//result/doc[2]/arr[@name='" + VECTOR_BYTE_FIELD + "']/arr[1]/int[3][.='1']", + "//result/doc[2]/arr[@name='" + VECTOR_BYTE_FIELD + "']/arr[1]/int[4][.='1']", + "//result/doc[2]/arr[@name='" + VECTOR_BYTE_FIELD + "']/arr[2]/int[1][.='6']", + "//result/doc[2]/arr[@name='" + VECTOR_BYTE_FIELD + "']/arr[2]/int[2][.='1']", + "//result/doc[2]/arr[@name='" + VECTOR_BYTE_FIELD + "']/arr[2]/int[3][.='1']", + "//result/doc[2]/arr[@name='" + VECTOR_BYTE_FIELD + "']/arr[2]/int[4][.='1']", + "//result/doc[2]/arr[@name='" + VECTOR_BYTE_FIELD + "']/arr[3]/int[1][.='5']", + "//result/doc[2]/arr[@name='" + VECTOR_BYTE_FIELD + "']/arr[3]/int[2][.='1']", + "//result/doc[2]/arr[@name='" + VECTOR_BYTE_FIELD + "']/arr[3]/int[3][.='1']", + "//result/doc[2]/arr[@name='" + VECTOR_BYTE_FIELD + "']/arr[3]/int[4][.='1']", + "//result/doc[3]/str[@name='id'][.='8']", + "//result/doc[3]/arr[@name='" + VECTOR_BYTE_FIELD + "']/arr[1]/int[1][.='10']", + "//result/doc[3]/arr[@name='" + VECTOR_BYTE_FIELD + "']/arr[1]/int[2][.='1']", + "//result/doc[3]/arr[@name='" + VECTOR_BYTE_FIELD + "']/arr[1]/int[3][.='1']", + "//result/doc[3]/arr[@name='" + VECTOR_BYTE_FIELD + "']/arr[1]/int[4][.='1']", + "//result/doc[3]/arr[@name='" + VECTOR_BYTE_FIELD + "']/arr[2]/int[1][.='9']", + "//result/doc[3]/arr[@name='" + VECTOR_BYTE_FIELD + "']/arr[2]/int[2][.='1']", + "//result/doc[3]/arr[@name='" + VECTOR_BYTE_FIELD + "']/arr[2]/int[3][.='1']", + "//result/doc[3]/arr[@name='" + VECTOR_BYTE_FIELD + "']/arr[2]/int[4][.='1']", + "//result/doc[3]/arr[@name='" + VECTOR_BYTE_FIELD + "']/arr[3]/int[1][.='8']", + "//result/doc[3]/arr[@name='" + VECTOR_BYTE_FIELD + "']/arr[3]/int[2][.='1']", + "//result/doc[3]/arr[@name='" + VECTOR_BYTE_FIELD + "']/arr[3]/int[3][.='1']", + "//result/doc[3]/arr[@name='" + VECTOR_BYTE_FIELD + "']/arr[3]/int[4][.='1']"); } } From c02d8ca9073afae80166b14befeaf0666720da35 Mon Sep 17 00:00:00 2001 From: Alessandro Benedetti Date: Thu, 15 Jan 2026 11:03:27 +0100 Subject: [PATCH 36/43] first fully working draft, green tests --- .../transform/ChildDocTransformer.java | 1 - .../collection1/conf/schema-densevector.xml | 3 -- .../join/BlockJoinMultiValuedVectorsTest.java | 7 ---- .../search/vector/KnnQParserChildTest.java | 1 - .../pages/dense-vector-search.adoc | 40 ------------------- .../org/apache/solr/common/SolrDocument.java | 2 - 6 files changed, 54 deletions(-) diff --git a/solr/core/src/java/org/apache/solr/response/transform/ChildDocTransformer.java b/solr/core/src/java/org/apache/solr/response/transform/ChildDocTransformer.java index b9ae1e77737f..0956e58af170 100644 --- a/solr/core/src/java/org/apache/solr/response/transform/ChildDocTransformer.java +++ b/solr/core/src/java/org/apache/solr/response/transform/ChildDocTransformer.java @@ -72,7 +72,6 @@ class ChildDocTransformer extends DocTransformer { private final SolrReturnFields childReturnFields; private final String[] extraRequestedFields; - ChildDocTransformer( String name, BitSetProducer parentsFilter, diff --git a/solr/core/src/test-files/solr/collection1/conf/schema-densevector.xml b/solr/core/src/test-files/solr/collection1/conf/schema-densevector.xml index 116221a3151d..aaffb1336b6d 100644 --- a/solr/core/src/test-files/solr/collection1/conf/schema-densevector.xml +++ b/solr/core/src/test-files/solr/collection1/conf/schema-densevector.xml @@ -18,8 +18,6 @@ - - @@ -68,6 +66,5 @@ - id diff --git a/solr/core/src/test/org/apache/solr/search/join/BlockJoinMultiValuedVectorsTest.java b/solr/core/src/test/org/apache/solr/search/join/BlockJoinMultiValuedVectorsTest.java index 7c34d7842a6b..e05208677b4b 100644 --- a/solr/core/src/test/org/apache/solr/search/join/BlockJoinMultiValuedVectorsTest.java +++ b/solr/core/src/test/org/apache/solr/search/join/BlockJoinMultiValuedVectorsTest.java @@ -154,13 +154,6 @@ public void parentRetrievalByte_knnChildrenWithParentFilter_shouldReturnKnnParen super.parentRetrieval_knnChildrenWithParentFilter_shouldReturnKnnParents(VECTOR_BYTE_FIELD); } - @Test - public void - parentRetrievalByte_topKWithChildTransformer_shouldReturnAllChildren() { // new transformer - // all vectors - //super.parentRetrievalByte_topKWithChildTransformer_shouldReturnAllChildren(VECTOR_BYTE_FIELD); - } - @Test public void parentRetrievalFloat_ChildTransformerWithChildFilter_shouldFlattenAndReturnBestChild() { diff --git a/solr/core/src/test/org/apache/solr/search/vector/KnnQParserChildTest.java b/solr/core/src/test/org/apache/solr/search/vector/KnnQParserChildTest.java index 9d99f1fdfc21..31fa2fe21b45 100644 --- a/solr/core/src/test/org/apache/solr/search/vector/KnnQParserChildTest.java +++ b/solr/core/src/test/org/apache/solr/search/vector/KnnQParserChildTest.java @@ -17,7 +17,6 @@ package org.apache.solr.search.vector; import java.util.List; -import java.util.Random; import java.util.stream.Collectors; import org.apache.lucene.tests.index.BaseKnnVectorsFormatTestCase; import org.apache.lucene.tests.util.TestUtil; diff --git a/solr/solr-ref-guide/modules/query-guide/pages/dense-vector-search.adoc b/solr/solr-ref-guide/modules/query-guide/pages/dense-vector-search.adoc index 2cb3cac4f4ad..4dc9239fd0b6 100644 --- a/solr/solr-ref-guide/modules/query-guide/pages/dense-vector-search.adoc +++ b/solr/solr-ref-guide/modules/query-guide/pages/dense-vector-search.adoc @@ -560,46 +560,6 @@ Here is an example of a `knn` search using a `parents.preFilter`: The search results retrieved are the k=3 nearest documents to the vector in input `[1.0, 2.0, 3.0, 4.0]`, each of them with a different parent. Only the documents with a parent that satisfy the 'color_s:RED' condition are considered candidates for the ANN search. -`childrenOf`:: -+ -[%autowidth,frame=none] -|=== -|Optional |Default: none -|=== -This parameter is meant to be a filter query on parent document metadata. -The knn search returns the top-k nearest children documents that satify the filter on the parent. -+ -Only one child per distinct parent is returned. - -Here is an example of a `knn` search using a `childrenOf`: - -[source,text] -?q={!knn f=vector topK=3 childrenOf=$someParents allParents=$allParents}[1.0, 2.0, 3.0, 4.0] -&allParents=*:* -_nest_path_:* -&someParents=color_s:RED - -The search results retrieved are the k=3 nearest documents to the vector in input `[1.0, 2.0, 3.0, 4.0]`, each of them with a different parent. Only the documents with a parent that satisfy the 'color_s:RED' condition are considered candidates for the ANN search. - -`allParents`:: -+ -[%autowidth,frame=none] -|=== -|Optional |Default: none -|Mandatory if using 'childrenOf' parameter|Default: none -|=== -+ -A query that matches ALL parents. -It's required to work with the 'childrenOf' parameter. - - -Here is an example of a `knn` search using a `allParents`: - -[source,text] -?q={!knn f=vector topK=3 allParents=$allParents}[1.0, 2.0, 3.0, 4.0] -&allParents=*:* -_nest_path_:* - -The search results retrieved are the k=3 nearest documents to the vector in input `[1.0, 2.0, 3.0, 4.0]`, each of them with a different parent. The 'allParents' parameter must return all parents to guarantee the correct functioning of the query. - === knn_text_to_vector Query Parser The `knn_text_to_vector` query parser encode a textual query to a vector using a dedicated Large Language Model(fine tuned for the task of encoding text to vector for sentence similarity) and matches k-nearest neighbours documents to such query vector. diff --git a/solr/solrj/src/java/org/apache/solr/common/SolrDocument.java b/solr/solrj/src/java/org/apache/solr/common/SolrDocument.java index e18cac9d35c1..16ba4ee00dbf 100644 --- a/solr/solrj/src/java/org/apache/solr/common/SolrDocument.java +++ b/solr/solrj/src/java/org/apache/solr/common/SolrDocument.java @@ -161,8 +161,6 @@ public void addField(String name, Object value) { _fields.put(name, vals); } - - /////////////////////////////////////////////////////////////////// // Get the field values /////////////////////////////////////////////////////////////////// From dfe966a8f7452f18f1fcf9191b4175ebb260f12f Mon Sep 17 00:00:00 2001 From: Alessandro Benedetti Date: Thu, 15 Jan 2026 11:23:56 +0100 Subject: [PATCH 37/43] tidy and changelog --- changelog/unreleased/SOLR-18074.yml | 8 +++ .../transform/ChildDocTransformer.java | 35 ++++++---- .../NestedUpdateProcessorFactory.java | 4 +- .../join/BlockJoinMultiValuedVectorsTest.java | 64 ++++++++++--------- ...ockJoinNestedVectorsParentQParserTest.java | 5 +- .../join/BlockJoinNestedVectorsTest.java | 10 +-- 6 files changed, 71 insertions(+), 55 deletions(-) create mode 100644 changelog/unreleased/SOLR-18074.yml diff --git a/changelog/unreleased/SOLR-18074.yml b/changelog/unreleased/SOLR-18074.yml new file mode 100644 index 000000000000..69dba9c2966b --- /dev/null +++ b/changelog/unreleased/SOLR-18074.yml @@ -0,0 +1,8 @@ +# See https://github.com/apache/solr/blob/main/dev-docs/changelog.adoc +title: Introducing support for multi valued dense vector representation in documents through nested vectors +type: added # added, changed, fixed, deprecated, removed, dependency_update, security, other +authors: + - name: Alessandro Benedetti +links: + - name: SOLR-18074 + url: https://issues.apache.org/jira/browse/SOLR-18074 diff --git a/solr/core/src/java/org/apache/solr/response/transform/ChildDocTransformer.java b/solr/core/src/java/org/apache/solr/response/transform/ChildDocTransformer.java index 0956e58af170..a844dccce744 100644 --- a/solr/core/src/java/org/apache/solr/response/transform/ChildDocTransformer.java +++ b/solr/core/src/java/org/apache/solr/response/transform/ChildDocTransformer.java @@ -146,15 +146,18 @@ public void transform(SolrDocument rootDoc, int rootDocId, DocIterationInfo docI final int segBaseId = leafReaderContext.docBase; final int segRootId = rootDocId - segBaseId; Set multiValuedFLoatVectorFields = - this.getMultiValuedVectorFields(searcher.getSchema(), childReturnFields, VectorEncoding.FLOAT32); + this.getMultiValuedVectorFields( + searcher.getSchema(), childReturnFields, VectorEncoding.FLOAT32); Set multiValuedByteVectorFields = - this.getMultiValuedVectorFields(searcher.getSchema(), childReturnFields, VectorEncoding.BYTE); - if ((multiValuedFLoatVectorFields.size() + multiValuedByteVectorFields.size()) != childReturnFields.getExplicitlyRequestedFieldNames() - .size()) { + this.getMultiValuedVectorFields( + searcher.getSchema(), childReturnFields, VectorEncoding.BYTE); + if ((multiValuedFLoatVectorFields.size() + multiValuedByteVectorFields.size()) + != childReturnFields.getExplicitlyRequestedFieldNames().size()) { throw new SolrException( SolrException.ErrorCode.BAD_REQUEST, "When using the Child transformer to flatten nested vectors, all 'fl' must be " - + "multivalued vector fields");} + + "multivalued vector fields"); + } // can return be -1 and that's okay (happens for very first block) final int segPrevRootId; @@ -296,16 +299,18 @@ public void transform(SolrDocument rootDoc, int rootDocId, DocIterationInfo docI } } - private Set getMultiValuedVectorFields(IndexSchema schema, - SolrReturnFields childReturnFields, VectorEncoding encoding) { + private Set getMultiValuedVectorFields( + IndexSchema schema, SolrReturnFields childReturnFields, VectorEncoding encoding) { Set multiValuedVectorsFields = new HashSet<>(); for (String fieldName : childReturnFields.getExplicitlyRequestedFieldNames()) { SchemaField sfield = schema.getFieldOrNull(fieldName); - if (sfield.getType() instanceof DenseVectorField && sfield.multiValued() && ((DenseVectorField) sfield.getType()).getVectorEncoding() == encoding) { + if (sfield.getType() instanceof DenseVectorField + && sfield.multiValued() + && ((DenseVectorField) sfield.getType()).getVectorEncoding() == encoding) { multiValuedVectorsFields.add(fieldName); } } - return multiValuedVectorsFields; + return multiValuedVectorsFields; } private static void addChildrenToParent( @@ -337,26 +342,30 @@ private static void addChildrenToParent( } private void addFlatMultiValuedFloatVectorsToParent( - SolrDocument parent, Map> children, + SolrDocument parent, + Map> children, Set multiValuedVectorFields) { for (String multiValuedVectorField : multiValuedVectorFields) { List solrDocuments = children.get(multiValuedVectorField); List> multiValuedVectors = new ArrayList<>(solrDocuments.size()); for (SolrDocument singleVector : solrDocuments) { - multiValuedVectors.add(this.extractFloatVector(singleVector.getFieldValues(multiValuedVectorField))); + multiValuedVectors.add( + this.extractFloatVector(singleVector.getFieldValues(multiValuedVectorField))); } parent.setField(multiValuedVectorField, multiValuedVectors); } } private void addFlatMultiValuedByteVectorsToParent( - SolrDocument parent, Map> children, + SolrDocument parent, + Map> children, Set multiValuedVectorFields) { for (String multiValuedVectorField : multiValuedVectorFields) { List solrDocuments = children.get(multiValuedVectorField); List> multiValuedVectors = new ArrayList<>(solrDocuments.size()); for (SolrDocument singleVector : solrDocuments) { - multiValuedVectors.add(this.extractByteVector(singleVector.getFieldValues(multiValuedVectorField))); + multiValuedVectors.add( + this.extractByteVector(singleVector.getFieldValues(multiValuedVectorField))); } parent.setField(multiValuedVectorField, multiValuedVectors); } diff --git a/solr/core/src/java/org/apache/solr/update/processor/NestedUpdateProcessorFactory.java b/solr/core/src/java/org/apache/solr/update/processor/NestedUpdateProcessorFactory.java index 6432a8aa6149..14fb44009f9f 100644 --- a/solr/core/src/java/org/apache/solr/update/processor/NestedUpdateProcessorFactory.java +++ b/solr/core/src/java/org/apache/solr/update/processor/NestedUpdateProcessorFactory.java @@ -17,6 +17,8 @@ package org.apache.solr.update.processor; +import static org.apache.solr.schema.IndexSchema.NESTED_VECTORS_PSEUDO_FIELD_NAME; + import java.io.IOException; import java.util.ArrayList; import java.util.Collection; @@ -31,8 +33,6 @@ import org.apache.solr.schema.SchemaField; import org.apache.solr.update.AddUpdateCommand; -import static org.apache.solr.schema.IndexSchema.NESTED_VECTORS_PSEUDO_FIELD_NAME; - /** * Adds fields to nested documents to support some nested search requirements. It can even generate * uniqueKey fields for nested docs. diff --git a/solr/core/src/test/org/apache/solr/search/join/BlockJoinMultiValuedVectorsTest.java b/solr/core/src/test/org/apache/solr/search/join/BlockJoinMultiValuedVectorsTest.java index e05208677b4b..1284bbfb655f 100644 --- a/solr/core/src/test/org/apache/solr/search/join/BlockJoinMultiValuedVectorsTest.java +++ b/solr/core/src/test/org/apache/solr/search/join/BlockJoinMultiValuedVectorsTest.java @@ -156,17 +156,17 @@ public void parentRetrievalByte_knnChildrenWithParentFilter_shouldReturnKnnParen @Test public void - parentRetrievalFloat_ChildTransformerWithChildFilter_shouldFlattenAndReturnBestChild() { + parentRetrievalFloat_ChildTransformerWithChildFilter_shouldFlattenAndReturnBestChild() { assertQ( req( - "q", "{!parent which=$allParents score=max v=$children.q}", - "fl", "id,"+VECTOR_FIELD+", [child fl="+ VECTOR_FIELD + " childFilter=$children.q]", + "q", + "{!parent which=$allParents score=max v=$children.q}", + "fl", + "id," + VECTOR_FIELD + ", [child fl=" + VECTOR_FIELD + " childFilter=$children.q]", "children.q", - "{!knn f=" - + VECTOR_FIELD - + " topK=3 allParents=$allParents}" - + FLOAT_QUERY_VECTOR, - "allParents", "parent_s:[* TO *]"), + "{!knn f=" + VECTOR_FIELD + " topK=3 allParents=$allParents}" + FLOAT_QUERY_VECTOR, + "allParents", + "parent_s:[* TO *]"), "//*[@numFound='3']", "//result/doc[1]/str[@name='id'][.='10']", "//result/doc[1]/arr[@name='" + VECTOR_FIELD + "']/arr/float[1][.='2.0']", @@ -189,14 +189,14 @@ public void parentRetrievalByte_knnChildrenWithParentFilter_shouldReturnKnnParen public void parentRetrievalFloat_ChildTransformer_shouldFlattenAndReturnAllChildren() { assertQ( req( - "q", "{!parent which=$allParents score=max v=$children.q}", - "fl", "id,"+VECTOR_FIELD+", [child fl="+ VECTOR_FIELD + " ]", + "q", + "{!parent which=$allParents score=max v=$children.q}", + "fl", + "id," + VECTOR_FIELD + ", [child fl=" + VECTOR_FIELD + " ]", "children.q", - "{!knn f=" - + VECTOR_FIELD - + " topK=3 allParents=$allParents}" - + FLOAT_QUERY_VECTOR, - "allParents", "parent_s:[* TO *]"), + "{!knn f=" + VECTOR_FIELD + " topK=3 allParents=$allParents}" + FLOAT_QUERY_VECTOR, + "allParents", + "parent_s:[* TO *]"), "//*[@numFound='3']", "//result/doc[1]/str[@name='id'][.='10']", "//result/doc[1]/arr[@name='" + VECTOR_FIELD + "']/arr[1]/float[1][.='4.0']", @@ -241,17 +241,21 @@ public void parentRetrievalFloat_ChildTransformer_shouldFlattenAndReturnAllChild @Test public void - parentRetrievalByte_ChildTransformerWithChildFilter_shouldFlattenAndReturnBestChild() { + parentRetrievalByte_ChildTransformerWithChildFilter_shouldFlattenAndReturnBestChild() { assertQ( req( - "q", "{!parent which=$allParents score=max v=$children.q}", - "fl", "id,"+VECTOR_BYTE_FIELD+", [child fl="+ VECTOR_BYTE_FIELD + " childFilter=$children.q]", - "children.q", - "{!knn f=" + "q", + "{!parent which=$allParents score=max v=$children.q}", + "fl", + "id," + VECTOR_BYTE_FIELD - + " topK=3 allParents=$allParents}" - + BYTE_QUERY_VECTOR, - "allParents", "parent_s:[* TO *]"), + + ", [child fl=" + + VECTOR_BYTE_FIELD + + " childFilter=$children.q]", + "children.q", + "{!knn f=" + VECTOR_BYTE_FIELD + " topK=3 allParents=$allParents}" + BYTE_QUERY_VECTOR, + "allParents", + "parent_s:[* TO *]"), "//*[@numFound='3']", "//result/doc[1]/str[@name='id'][.='10']", "//result/doc[1]/arr[@name='" + VECTOR_BYTE_FIELD + "']/arr/int[1][.='2']", @@ -274,14 +278,14 @@ public void parentRetrievalFloat_ChildTransformer_shouldFlattenAndReturnAllChild public void parentRetrievalByte_ChildTransformer_shouldFlattenAndReturnAllChildren() { assertQ( req( - "q", "{!parent which=$allParents score=max v=$children.q}", - "fl", "id,"+VECTOR_BYTE_FIELD+", [child fl="+ VECTOR_BYTE_FIELD + " ]", + "q", + "{!parent which=$allParents score=max v=$children.q}", + "fl", + "id," + VECTOR_BYTE_FIELD + ", [child fl=" + VECTOR_BYTE_FIELD + " ]", "children.q", - "{!knn f=" - + VECTOR_BYTE_FIELD - + " topK=3 allParents=$allParents}" - + BYTE_QUERY_VECTOR, - "allParents", "parent_s:[* TO *]"), + "{!knn f=" + VECTOR_BYTE_FIELD + " topK=3 allParents=$allParents}" + BYTE_QUERY_VECTOR, + "allParents", + "parent_s:[* TO *]"), "//*[@numFound='3']", "//result/doc[1]/str[@name='id'][.='10']", "//result/doc[1]/arr[@name='" + VECTOR_BYTE_FIELD + "']/arr[1]/int[1][.='4']", diff --git a/solr/core/src/test/org/apache/solr/search/join/BlockJoinNestedVectorsParentQParserTest.java b/solr/core/src/test/org/apache/solr/search/join/BlockJoinNestedVectorsParentQParserTest.java index 173e7502d929..9eeddbac69c9 100644 --- a/solr/core/src/test/org/apache/solr/search/join/BlockJoinNestedVectorsParentQParserTest.java +++ b/solr/core/src/test/org/apache/solr/search/join/BlockJoinNestedVectorsParentQParserTest.java @@ -169,9 +169,8 @@ protected void parentRetrieval_knnChildrenWithParentFilter_shouldReturnKnnParent "//result/doc[2]/str[@name='id'][.='2']"); } - protected void - parentRetrievalFloat_topKWithChildTransformer_shouldReturnAllChildren( - String vectorField) { + protected void parentRetrievalFloat_topKWithChildTransformer_shouldReturnAllChildren( + String vectorField) { assertQ( req( "q", "{!parent which=$allParents score=max v=$children.q}", diff --git a/solr/core/src/test/org/apache/solr/search/join/BlockJoinNestedVectorsTest.java b/solr/core/src/test/org/apache/solr/search/join/BlockJoinNestedVectorsTest.java index 3e4683c407f5..1568f6f926a9 100644 --- a/solr/core/src/test/org/apache/solr/search/join/BlockJoinNestedVectorsTest.java +++ b/solr/core/src/test/org/apache/solr/search/join/BlockJoinNestedVectorsTest.java @@ -107,8 +107,7 @@ public void childrenRetrievalFloat_filteringByParentMetadata_shouldReturnKnnChil @Test public void childrenRetrievalByte_filteringByParentMetadata_shouldReturnKnnChildren() { - super.childrenRetrieval_filteringByParentMetadata_shouldReturnKnnChildren( - VECTOR_BYTE_FIELD); + super.childrenRetrieval_filteringByParentMetadata_shouldReturnKnnChildren(VECTOR_BYTE_FIELD); } @Test @@ -151,11 +150,8 @@ public void parentRetrievalByte_knnChildrenWithParentFilter_shouldReturnKnnParen } @Test - public void - parentRetrievalFloat_topKWithChildTransformerWithFilter_shouldReturnAllChildren() { - super - .parentRetrievalFloat_topKWithChildTransformer_shouldReturnAllChildren( - VECTOR_FIELD); + public void parentRetrievalFloat_topKWithChildTransformerWithFilter_shouldReturnAllChildren() { + super.parentRetrievalFloat_topKWithChildTransformer_shouldReturnAllChildren(VECTOR_FIELD); } @Test From d187eb356645942f8762c787e9d073ca66d6d339 Mon Sep 17 00:00:00 2001 From: Alessandro Benedetti Date: Thu, 15 Jan 2026 13:16:33 +0100 Subject: [PATCH 38/43] first doc draft --- .../collection1/conf/schema-densevector.xml | 1 - .../pages/dense-vector-search.adoc | 84 +++++++++++++++++-- 2 files changed, 79 insertions(+), 6 deletions(-) diff --git a/solr/core/src/test-files/solr/collection1/conf/schema-densevector.xml b/solr/core/src/test-files/solr/collection1/conf/schema-densevector.xml index aaffb1336b6d..fd7702ea3b9b 100644 --- a/solr/core/src/test-files/solr/collection1/conf/schema-densevector.xml +++ b/solr/core/src/test-files/solr/collection1/conf/schema-densevector.xml @@ -33,7 +33,6 @@ - diff --git a/solr/solr-ref-guide/modules/query-guide/pages/dense-vector-search.adoc b/solr/solr-ref-guide/modules/query-guide/pages/dense-vector-search.adoc index 4dc9239fd0b6..0623865850be 100644 --- a/solr/solr-ref-guide/modules/query-guide/pages/dense-vector-search.adoc +++ b/solr/solr-ref-guide/modules/query-guide/pages/dense-vector-search.adoc @@ -174,12 +174,9 @@ For more details, refer to the official https://arxiv.org/pdf/1603.09320[2018 pa Accepted values: Any integer. -`DenseVectorField` supports the attributes: `indexed`, `stored`. +`DenseVectorField` supports the attributes: `indexed`, `stored`, `multivalued`. -[NOTE] -currently multivalue is not supported - -Here's how a `DenseVectorField` should be indexed: +Here's how a `DenseVectorField` should be indexed when single valued: [tabs#densevectorfield-index] ====== @@ -243,6 +240,46 @@ client.add(Arrays.asList(d1, d2)); ==== ====== +Here's how a `DenseVectorField` should be indexed when multi-valued: + +[tabs#densevectorfield-index] +====== +JSON:: ++ +==== +[source,json] +---- +[ +] +---- +==== + +XML:: ++ +==== +[source,xml] +---- + + + + + +---- +==== + +SolrJ:: ++ +==== +[source,java,indent=0] +---- +final SolrClient client = getSolrClient(); + +TO ADD + +---- +==== +====== + === ScalarQuantizedDenseVectorField Because dense vectors can have a costly size, it may be worthwhile to use a technique called "quantization" which creates a compressed representation of the original vectors. This allows more of the index to be stored in faster memory @@ -593,6 +630,43 @@ The search results retrieved are the k=10 nearest documents to the vector encode For more details on how to work with vectorise text in Apache Solr, please refer to the dedicated page: xref:text-to-vector.adoc[Text to Vector] +=== Handle multivalued vector fields at query time +Behind the scenes a multivalued vector field is handled by Solr as nested documents with a single vector each (see the parameters for the knn query parser that deal with nested vectors 'parents.preFilter' and 'childrenOf'). + +So you should query a multivalued vector fields following the same syntax: +[source,text] +?q={!parent which=$allParents score=max v=$children.q} +&children.q={!knn f=multivaluedvector topK=3 parents.preFilter=$someParents childrenOf=$allParents}[1.0, 2.0, 3.0, 4.0] +&allParents=*:* -_nest_path_:* +&someParents=color_s:RED + +In terms of rendering the results, you need the child transformer if you want to output them flat (you can choose to only return the best vector per result or all vectors): + +All Children +[source,text] +fl=id,multivaluedvector,[child fl="multivaluedvector"] + +==== +[source,json] +---- +[ +] +---- +==== + +Best Child +[source,text] +fl=id,multivaluedvector,[child fl="multivaluedvector" childFilter=$children.q] + +==== +[source,json] +---- +[ +] +---- +==== + + === vectorSimilarity Query Parser The `vectorSimilarity` vector similarity query parser matches documents whose similarity with the target vector is a above a minimum threshold. From 820334274e5c1d9b86c43b9623bcfc7cff603d9f Mon Sep 17 00:00:00 2001 From: Alessandro Benedetti Date: Thu, 15 Jan 2026 13:44:54 +0100 Subject: [PATCH 39/43] tests fixed --- .../response/transform/ChildDocTransformer.java | 5 +++-- .../join/BlockJoinMultiValuedVectorsTest.java | 8 ++++---- .../BlockJoinNestedVectorsParentQParserTest.java | 14 +++++++------- .../search/join/BlockJoinNestedVectorsTest.java | 2 +- 4 files changed, 15 insertions(+), 14 deletions(-) diff --git a/solr/core/src/java/org/apache/solr/response/transform/ChildDocTransformer.java b/solr/core/src/java/org/apache/solr/response/transform/ChildDocTransformer.java index a844dccce744..f6291a9004a8 100644 --- a/solr/core/src/java/org/apache/solr/response/transform/ChildDocTransformer.java +++ b/solr/core/src/java/org/apache/solr/response/transform/ChildDocTransformer.java @@ -151,8 +151,9 @@ public void transform(SolrDocument rootDoc, int rootDocId, DocIterationInfo docI Set multiValuedByteVectorFields = this.getMultiValuedVectorFields( searcher.getSchema(), childReturnFields, VectorEncoding.BYTE); - if ((multiValuedFLoatVectorFields.size() + multiValuedByteVectorFields.size()) - != childReturnFields.getExplicitlyRequestedFieldNames().size()) { + if ((multiValuedFLoatVectorFields.size() + multiValuedByteVectorFields.size()) > 0 && + (multiValuedFLoatVectorFields.size() + multiValuedByteVectorFields.size()) + != childReturnFields.getExplicitlyRequestedFieldNames().size()) { throw new SolrException( SolrException.ErrorCode.BAD_REQUEST, "When using the Child transformer to flatten nested vectors, all 'fl' must be " diff --git a/solr/core/src/test/org/apache/solr/search/join/BlockJoinMultiValuedVectorsTest.java b/solr/core/src/test/org/apache/solr/search/join/BlockJoinMultiValuedVectorsTest.java index 1284bbfb655f..dcf8d9f34a98 100644 --- a/solr/core/src/test/org/apache/solr/search/join/BlockJoinMultiValuedVectorsTest.java +++ b/solr/core/src/test/org/apache/solr/search/join/BlockJoinMultiValuedVectorsTest.java @@ -164,7 +164,7 @@ public void parentRetrievalByte_knnChildrenWithParentFilter_shouldReturnKnnParen "fl", "id," + VECTOR_FIELD + ", [child fl=" + VECTOR_FIELD + " childFilter=$children.q]", "children.q", - "{!knn f=" + VECTOR_FIELD + " topK=3 allParents=$allParents}" + FLOAT_QUERY_VECTOR, + "{!knn f=" + VECTOR_FIELD + " topK=3 childrenOf=$allParents}" + FLOAT_QUERY_VECTOR, "allParents", "parent_s:[* TO *]"), "//*[@numFound='3']", @@ -194,7 +194,7 @@ public void parentRetrievalFloat_ChildTransformer_shouldFlattenAndReturnAllChild "fl", "id," + VECTOR_FIELD + ", [child fl=" + VECTOR_FIELD + " ]", "children.q", - "{!knn f=" + VECTOR_FIELD + " topK=3 allParents=$allParents}" + FLOAT_QUERY_VECTOR, + "{!knn f=" + VECTOR_FIELD + " topK=3 childrenOf=$allParents}" + FLOAT_QUERY_VECTOR, "allParents", "parent_s:[* TO *]"), "//*[@numFound='3']", @@ -253,7 +253,7 @@ public void parentRetrievalFloat_ChildTransformer_shouldFlattenAndReturnAllChild + VECTOR_BYTE_FIELD + " childFilter=$children.q]", "children.q", - "{!knn f=" + VECTOR_BYTE_FIELD + " topK=3 allParents=$allParents}" + BYTE_QUERY_VECTOR, + "{!knn f=" + VECTOR_BYTE_FIELD + " topK=3 childrenOf=$allParents}" + BYTE_QUERY_VECTOR, "allParents", "parent_s:[* TO *]"), "//*[@numFound='3']", @@ -283,7 +283,7 @@ public void parentRetrievalByte_ChildTransformer_shouldFlattenAndReturnAllChildr "fl", "id," + VECTOR_BYTE_FIELD + ", [child fl=" + VECTOR_BYTE_FIELD + " ]", "children.q", - "{!knn f=" + VECTOR_BYTE_FIELD + " topK=3 allParents=$allParents}" + BYTE_QUERY_VECTOR, + "{!knn f=" + VECTOR_BYTE_FIELD + " topK=3 childrenOf=$allParents}" + BYTE_QUERY_VECTOR, "allParents", "parent_s:[* TO *]"), "//*[@numFound='3']", diff --git a/solr/core/src/test/org/apache/solr/search/join/BlockJoinNestedVectorsParentQParserTest.java b/solr/core/src/test/org/apache/solr/search/join/BlockJoinNestedVectorsParentQParserTest.java index 9eeddbac69c9..a4d1f39097cd 100644 --- a/solr/core/src/test/org/apache/solr/search/join/BlockJoinNestedVectorsParentQParserTest.java +++ b/solr/core/src/test/org/apache/solr/search/join/BlockJoinNestedVectorsParentQParserTest.java @@ -77,7 +77,7 @@ protected void parentRetrieval_knnChildrenDiversifyingWithNoAllParents_shouldThr "fl", "id,score", "children.q", - "{!knn f=" + vectorField + " topK=3 childrenOf=$someParents}" + FLOAT_QUERY_VECTOR, + "{!knn f=" + vectorField + " topK=3 parents.preFilter=$someParents}" + FLOAT_QUERY_VECTOR, "allParents", "parent_s:[* TO *]", "someParents", @@ -122,7 +122,7 @@ protected void parentRetrieval_knnChildren_shouldReturnKnnParents(String vectorB "children.q", "{!knn f=" + vectorByteField - + " topK=3 allParents=$allParents}" + + " topK=3 childrenOf=$allParents}" + BYTE_QUERY_VECTOR, "allParents", "parent_s:[* TO *]"), "//*[@numFound='3']", @@ -140,7 +140,7 @@ protected void parentRetrieval_knnChildrenWithParentFilter_shouldReturnKnnParent "children.q", "{!knn f=" + vectorByteField - + " topK=3 childrenOf=$someParents allParents=$allParents}" + + " topK=3 parents.preFilter=$someParents childrenOf=$allParents}" + BYTE_QUERY_VECTOR, "allParents", "parent_s:[* TO *]", "someParents", "parent_s:(a c)"), @@ -160,7 +160,7 @@ protected void parentRetrieval_knnChildrenWithParentFilter_shouldReturnKnnParent "children.q", "{!knn f=" + vectorByteField - + " topK=3 preFilter=child_s:m childrenOf=$someParents allParents=$allParents}" + + " topK=3 preFilter=child_s:m parents.preFilter=$someParents childrenOf=$allParents}" + BYTE_QUERY_VECTOR, "allParents", "parent_s:[* TO *]", "someParents", "parent_s:(a c)"), @@ -183,7 +183,7 @@ protected void parentRetrievalFloat_topKWithChildTransformer_shouldReturnAllChil "children.q", "{!knn f=" + vectorField - + " topK=3 childrenOf=$someParents allParents=$allParents}" + + " topK=3 parents.preFilter=$someParents childrenOf=$allParents}" + FLOAT_QUERY_VECTOR, "allParents", "parent_s:[* TO *]", "someParents", "parent_s:(a c)"), @@ -330,7 +330,7 @@ protected void parentRetrievalByte_topKWithChildTransformer_shouldReturnAllChild "children.q", "{!knn f=" + vectorByteField - + " topK=3 childrenOf=$someParents allParents=$allParents}" + + " topK=3 parents.preFilter=$someParents childrenOf=$allParents}" + BYTE_QUERY_VECTOR, "allParents", "parent_s:[* TO *]", @@ -479,7 +479,7 @@ protected void parentRetrievalByte_topKWithChildTransformerWithFilter_shouldRetu "children.q", "{!knn f=" + vectorByteField - + " topK=3 childrenOf=$someParents allParents=$allParents}" + + " topK=3 parents.preFilter=$someParents childrenOf=$allParents}" + BYTE_QUERY_VECTOR, "allParents", "parent_s:[* TO *]", diff --git a/solr/core/src/test/org/apache/solr/search/join/BlockJoinNestedVectorsTest.java b/solr/core/src/test/org/apache/solr/search/join/BlockJoinNestedVectorsTest.java index 1568f6f926a9..81303600e7da 100644 --- a/solr/core/src/test/org/apache/solr/search/join/BlockJoinNestedVectorsTest.java +++ b/solr/core/src/test/org/apache/solr/search/join/BlockJoinNestedVectorsTest.java @@ -169,7 +169,7 @@ public void parentRetrievalFloat_topKWithChildTransformerWithFilter_shouldReturn "children.q", "{!knn f=" + VECTOR_FIELD - + " topK=3 childrenOf=$someParents allParents=$allParents}" + + " topK=3 parents.preFilter=$someParents childrenOf=$allParents}" + FLOAT_QUERY_VECTOR, "allParents", "parent_s:[* TO *]", From b2df150bbc74d3cedd234a0c931e4131bc1c83db Mon Sep 17 00:00:00 2001 From: Alessandro Benedetti Date: Fri, 23 Jan 2026 18:31:18 +0000 Subject: [PATCH 40/43] documentation completed --- .../pages/dense-vector-search.adoc | 87 ++++++++++++++----- 1 file changed, 65 insertions(+), 22 deletions(-) diff --git a/solr/solr-ref-guide/modules/query-guide/pages/dense-vector-search.adoc b/solr/solr-ref-guide/modules/query-guide/pages/dense-vector-search.adoc index 0623865850be..94938e64c1a6 100644 --- a/solr/solr-ref-guide/modules/query-guide/pages/dense-vector-search.adoc +++ b/solr/solr-ref-guide/modules/query-guide/pages/dense-vector-search.adoc @@ -249,24 +249,16 @@ JSON:: ==== [source,json] ---- -[ +[{ "id": "1", + "vector_multivalued": [[1.0, 2.0, 3.0, 4.0],[5.0, 6.0, 7.0, 8.0]] +}, +{ "id": "2", + "vector_multivalued": [[1.0, 2.0, 3.0, 4.0],[5.0, 6.0, 7.0, 8.0]] +} ] ---- ==== -XML:: -+ -==== -[source,xml] ----- - - - - - ----- -==== - SolrJ:: + ==== @@ -274,7 +266,22 @@ SolrJ:: ---- final SolrClient client = getSolrClient(); -TO ADD +final SolrInputDocument d1 = new SolrInputDocument(); +d1.setField("id", "1"); +List> floatVectors1 = new ArrayList<>(2); +floatVectors1.add(Arrays.asList(1.0f, 2.0f, 3.0f, 4.0f)); +floatVectors1.add(Arrays.asList(5.0f, 6.0f, 7.0f, 8.0f)); +d1.setField("vector_multivalued", floatVectors1); + + +final SolrInputDocument d2 = new SolrInputDocument(); +d2.setField("id", "2"); +List> floatVectors2 = new ArrayList<>(2); +floatVectors2.add(Arrays.asList(1.0f, 2.0f, 3.0f, 4.0f)); +floatVectors2.add(Arrays.asList(5.0f, 6.0f, 7.0f, 8.0f)); +d2.setField("vector_multivalued", floatVectors2); + +client.add(Arrays.asList(d1, d2)); ---- ==== @@ -636,7 +643,7 @@ Behind the scenes a multivalued vector field is handled by Solr as nested docume So you should query a multivalued vector fields following the same syntax: [source,text] ?q={!parent which=$allParents score=max v=$children.q} -&children.q={!knn f=multivaluedvector topK=3 parents.preFilter=$someParents childrenOf=$allParents}[1.0, 2.0, 3.0, 4.0] +&children.q={!knn f=vector_multivalued topK=3 parents.preFilter=$someParents childrenOf=$allParents}[1.0, 2.0, 3.0, 4.0] &allParents=*:* -_nest_path_:* &someParents=color_s:RED @@ -644,25 +651,61 @@ In terms of rendering the results, you need the child transformer if you want to All Children [source,text] -fl=id,multivaluedvector,[child fl="multivaluedvector"] +fl=id,vector_multivalued,[child fl="vector_multivalued"] ==== [source,json] ---- -[ -] +"docs":[ + { + "id":"1", + "vector_multivalued":[ + [ + 1.0,2.0, 3.0, 4.0 + ], + [ + 5.0,6.0, 7.0, 8.0 + ] + ] + }, + { + "id":"2", + "vector_multivalued":[ + [ + 1.0,2.0, 3.0, 4.0 + ], + [ + 5.0,6.0, 7.0, 8.0 + ] + ] + }] ---- ==== Best Child [source,text] -fl=id,multivaluedvector,[child fl="multivaluedvector" childFilter=$children.q] +fl=id,vector_multivalued,[child fl="vector_multivalued" childFilter=$children.q] ==== [source,json] ---- -[ -] +"docs":[ + { + "id":"1", + "vector_multivalued":[ + [ + 1.0,2.0, 3.0, 4.0 + ] + ] + }, + { + "id":"2", + "vector_multivalued":[ + [ + 1.0,2.0, 3.0, 4.0 + ] + ] + }] ---- ==== From 0b127090ff52c87be4a5ee166437ba264ca0fdc8 Mon Sep 17 00:00:00 2001 From: Alessandro Benedetti Date: Fri, 23 Jan 2026 19:02:18 +0000 Subject: [PATCH 41/43] tidy --- .../apache/solr/response/transform/ChildDocTransformer.java | 4 ++-- .../search/join/BlockJoinNestedVectorsParentQParserTest.java | 5 ++++- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/solr/core/src/java/org/apache/solr/response/transform/ChildDocTransformer.java b/solr/core/src/java/org/apache/solr/response/transform/ChildDocTransformer.java index f6291a9004a8..6f799710d0ad 100644 --- a/solr/core/src/java/org/apache/solr/response/transform/ChildDocTransformer.java +++ b/solr/core/src/java/org/apache/solr/response/transform/ChildDocTransformer.java @@ -151,8 +151,8 @@ public void transform(SolrDocument rootDoc, int rootDocId, DocIterationInfo docI Set multiValuedByteVectorFields = this.getMultiValuedVectorFields( searcher.getSchema(), childReturnFields, VectorEncoding.BYTE); - if ((multiValuedFLoatVectorFields.size() + multiValuedByteVectorFields.size()) > 0 && - (multiValuedFLoatVectorFields.size() + multiValuedByteVectorFields.size()) + if ((multiValuedFLoatVectorFields.size() + multiValuedByteVectorFields.size()) > 0 + && (multiValuedFLoatVectorFields.size() + multiValuedByteVectorFields.size()) != childReturnFields.getExplicitlyRequestedFieldNames().size()) { throw new SolrException( SolrException.ErrorCode.BAD_REQUEST, diff --git a/solr/core/src/test/org/apache/solr/search/join/BlockJoinNestedVectorsParentQParserTest.java b/solr/core/src/test/org/apache/solr/search/join/BlockJoinNestedVectorsParentQParserTest.java index a4d1f39097cd..8e374ba01c9f 100644 --- a/solr/core/src/test/org/apache/solr/search/join/BlockJoinNestedVectorsParentQParserTest.java +++ b/solr/core/src/test/org/apache/solr/search/join/BlockJoinNestedVectorsParentQParserTest.java @@ -77,7 +77,10 @@ protected void parentRetrieval_knnChildrenDiversifyingWithNoAllParents_shouldThr "fl", "id,score", "children.q", - "{!knn f=" + vectorField + " topK=3 parents.preFilter=$someParents}" + FLOAT_QUERY_VECTOR, + "{!knn f=" + + vectorField + + " topK=3 parents.preFilter=$someParents}" + + FLOAT_QUERY_VECTOR, "allParents", "parent_s:[* TO *]", "someParents", From 5a5b2d8f621beccf409cd233f4a96c444e05dd59 Mon Sep 17 00:00:00 2001 From: Alessandro Benedetti Date: Thu, 29 Jan 2026 13:18:42 +0100 Subject: [PATCH 42/43] minor null checks --- .../response/transform/ChildDocTransformer.java | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/solr/core/src/java/org/apache/solr/response/transform/ChildDocTransformer.java b/solr/core/src/java/org/apache/solr/response/transform/ChildDocTransformer.java index 6f799710d0ad..bc368cfed6e4 100644 --- a/solr/core/src/java/org/apache/solr/response/transform/ChildDocTransformer.java +++ b/solr/core/src/java/org/apache/solr/response/transform/ChildDocTransformer.java @@ -303,12 +303,17 @@ public void transform(SolrDocument rootDoc, int rootDocId, DocIterationInfo docI private Set getMultiValuedVectorFields( IndexSchema schema, SolrReturnFields childReturnFields, VectorEncoding encoding) { Set multiValuedVectorsFields = new HashSet<>(); - for (String fieldName : childReturnFields.getExplicitlyRequestedFieldNames()) { - SchemaField sfield = schema.getFieldOrNull(fieldName); - if (sfield.getType() instanceof DenseVectorField - && sfield.multiValued() - && ((DenseVectorField) sfield.getType()).getVectorEncoding() == encoding) { - multiValuedVectorsFields.add(fieldName); + Set explicitlyRequestedFieldNames = + childReturnFields.getExplicitlyRequestedFieldNames(); + if (explicitlyRequestedFieldNames != null) { + for (String fieldName : explicitlyRequestedFieldNames) { + SchemaField sfield = schema.getFieldOrNull(fieldName); + if (sfield != null + && sfield.getType() instanceof DenseVectorField + && sfield.multiValued() + && ((DenseVectorField) sfield.getType()).getVectorEncoding() == encoding) { + multiValuedVectorsFields.add(fieldName); + } } } return multiValuedVectorsFields; From 96764522f805e55e3c462cdfca94841911137657 Mon Sep 17 00:00:00 2001 From: Alessandro Benedetti Date: Fri, 30 Jan 2026 15:50:27 +0100 Subject: [PATCH 43/43] minor null checks --- .../test/org/apache/solr/schema/DenseVectorFieldTest.java | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/solr/core/src/test/org/apache/solr/schema/DenseVectorFieldTest.java b/solr/core/src/test/org/apache/solr/schema/DenseVectorFieldTest.java index 5d0a5ce2a9b7..f84eaf40ef0b 100644 --- a/solr/core/src/test/org/apache/solr/schema/DenseVectorFieldTest.java +++ b/solr/core/src/test/org/apache/solr/schema/DenseVectorFieldTest.java @@ -718,7 +718,7 @@ public void denseVectorByteEncoding_shouldRaiseExceptionWithValuesOutsideBoundar assertThat( thrown.getCause().getMessage(), is( - "Error while creating field 'vector_byte_encoding{type=knn_vector_byte_encoding,properties=indexed,stored}' from value '[128, 6, 7, 8]'")); + "Error while creating field 'vector_byte_encoding{type=knn_vector_byte_encoding,properties=indexed,stored,omitNorms,omitTermFreqAndPositions,useDocValuesAsStored}' from value '[128, 6, 7, 8]'")); assertThat( thrown.getCause().getCause().getMessage(), @@ -740,7 +740,7 @@ public void denseVectorByteEncoding_shouldRaiseExceptionWithValuesOutsideBoundar assertThat( thrown.getCause().getMessage(), is( - "Error while creating field 'vector_byte_encoding{type=knn_vector_byte_encoding,properties=indexed,stored}' from value '[1, -129, 7, 8]'")); + "Error while creating field 'vector_byte_encoding{type=knn_vector_byte_encoding,properties=indexed,stored,omitNorms,omitTermFreqAndPositions,useDocValuesAsStored}' from value '[1, -129, 7, 8]'")); assertThat( thrown.getCause().getCause().getMessage(), is( @@ -769,7 +769,7 @@ public void denseVectorByteEncoding_shouldRaiseExceptionWithFloatValues() throws assertThat( thrown.getCause().getMessage(), is( - "Error while creating field 'vector_byte_encoding{type=knn_vector_byte_encoding,properties=indexed,stored}' from value '[14.3, 6.2, 7.2, 8.1]'")); + "Error while creating field 'vector_byte_encoding{type=knn_vector_byte_encoding,properties=indexed,stored,omitNorms,omitTermFreqAndPositions,useDocValuesAsStored}' from value '[14.3, 6.2, 7.2, 8.1]'")); assertThat( thrown.getCause().getCause().getMessage(),