diff --git a/changelog/unreleased/SOLR-18074.yml b/changelog/unreleased/SOLR-18074.yml new file mode 100644 index 000000000000..69dba9c2966b --- /dev/null +++ b/changelog/unreleased/SOLR-18074.yml @@ -0,0 +1,8 @@ +# See https://github.com/apache/solr/blob/main/dev-docs/changelog.adoc +title: Introducing support for multi valued dense vector representation in documents through nested vectors +type: added # added, changed, fixed, deprecated, removed, dependency_update, security, other +authors: + - name: Alessandro Benedetti +links: + - name: SOLR-18074 + url: https://issues.apache.org/jira/browse/SOLR-18074 diff --git a/solr/core/src/java/org/apache/solr/response/transform/ChildDocTransformer.java b/solr/core/src/java/org/apache/solr/response/transform/ChildDocTransformer.java index 204aa7a6190c..bc368cfed6e4 100644 --- a/solr/core/src/java/org/apache/solr/response/transform/ChildDocTransformer.java +++ b/solr/core/src/java/org/apache/solr/response/transform/ChildDocTransformer.java @@ -24,9 +24,13 @@ import java.io.IOException; import java.lang.invoke.MethodHandles; import java.util.ArrayList; +import java.util.Collection; import java.util.HashMap; +import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Set; +import org.apache.lucene.document.StoredField; import org.apache.lucene.index.DocValues; import org.apache.lucene.index.LeafReader; import org.apache.lucene.index.LeafReaderContext; @@ -35,6 +39,7 @@ import org.apache.lucene.index.SortedDocValues; import org.apache.lucene.index.Terms; import org.apache.lucene.index.TermsEnum; +import org.apache.lucene.index.VectorEncoding; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.join.BitSetProducer; import org.apache.lucene.util.BitSet; @@ -42,7 +47,9 @@ import org.apache.lucene.util.BytesRef; import org.apache.solr.common.SolrDocument; import org.apache.solr.common.SolrException; +import org.apache.solr.schema.DenseVectorField; import 
org.apache.solr.schema.IndexSchema; +import org.apache.solr.schema.SchemaField; import org.apache.solr.search.BitsFilteredPostingsEnum; import org.apache.solr.search.DocIterationInfo; import org.apache.solr.search.DocSet; @@ -138,6 +145,20 @@ public void transform(SolrDocument rootDoc, int rootDocId, DocIterationInfo docI final Bits liveDocs = leafReaderContext.reader().getLiveDocs(); final int segBaseId = leafReaderContext.docBase; final int segRootId = rootDocId - segBaseId; + Set multiValuedFLoatVectorFields = + this.getMultiValuedVectorFields( + searcher.getSchema(), childReturnFields, VectorEncoding.FLOAT32); + Set multiValuedByteVectorFields = + this.getMultiValuedVectorFields( + searcher.getSchema(), childReturnFields, VectorEncoding.BYTE); + if ((multiValuedFLoatVectorFields.size() + multiValuedByteVectorFields.size()) > 0 + && (multiValuedFLoatVectorFields.size() + multiValuedByteVectorFields.size()) + != childReturnFields.getExplicitlyRequestedFieldNames().size()) { + throw new SolrException( + SolrException.ErrorCode.BAD_REQUEST, + "When using the Child transformer to flatten nested vectors, all 'fl' must be " + + "multivalued vector fields"); + } // can return be -1 and that's okay (happens for very first block) final int segPrevRootId; @@ -219,8 +240,19 @@ public void transform(SolrDocument rootDoc, int rootDocId, DocIterationInfo docI if (isAncestor) { // if this path has pending child docs, add them. 
- addChildrenToParent( - doc, pendingParentPathsToChildren.remove(fullDocPath)); // no longer pending + if (!multiValuedFLoatVectorFields.isEmpty() || !multiValuedByteVectorFields.isEmpty()) { + addFlatMultiValuedFloatVectorsToParent( + rootDoc, + pendingParentPathsToChildren.values().iterator().next(), + multiValuedFLoatVectorFields); + addFlatMultiValuedByteVectorsToParent( + rootDoc, + pendingParentPathsToChildren.values().iterator().next(), + multiValuedByteVectorFields); + } else { + addChildrenToParent( + doc, pendingParentPathsToChildren.remove(fullDocPath)); // no longer pending + } } // get parent path @@ -248,7 +280,18 @@ public void transform(SolrDocument rootDoc, int rootDocId, DocIterationInfo docI assert pendingParentPathsToChildren.keySet().size() == 1; // size == 1, so get the last remaining entry - addChildrenToParent(rootDoc, pendingParentPathsToChildren.values().iterator().next()); + if (!multiValuedFLoatVectorFields.isEmpty() || !multiValuedByteVectorFields.isEmpty()) { + addFlatMultiValuedFloatVectorsToParent( + rootDoc, + pendingParentPathsToChildren.values().iterator().next(), + multiValuedFLoatVectorFields); + addFlatMultiValuedByteVectorsToParent( + rootDoc, + pendingParentPathsToChildren.values().iterator().next(), + multiValuedByteVectorFields); + } else { + addChildrenToParent(rootDoc, pendingParentPathsToChildren.values().iterator().next()); + } } catch (IOException e) { // TODO DWS: reconsider this unusual error handling approach; shouldn't we rethrow? 
@@ -257,6 +300,25 @@ public void transform(SolrDocument rootDoc, int rootDocId, DocIterationInfo docI } } + private Set getMultiValuedVectorFields( + IndexSchema schema, SolrReturnFields childReturnFields, VectorEncoding encoding) { + Set multiValuedVectorsFields = new HashSet<>(); + Set explicitlyRequestedFieldNames = + childReturnFields.getExplicitlyRequestedFieldNames(); + if (explicitlyRequestedFieldNames != null) { + for (String fieldName : explicitlyRequestedFieldNames) { + SchemaField sfield = schema.getFieldOrNull(fieldName); + if (sfield != null + && sfield.getType() instanceof DenseVectorField + && sfield.multiValued() + && ((DenseVectorField) sfield.getType()).getVectorEncoding() == encoding) { + multiValuedVectorsFields.add(fieldName); + } + } + } + return multiValuedVectorsFields; + } + private static void addChildrenToParent( SolrDocument parent, Map> children) { for (Map.Entry> entry : children.entrySet()) { @@ -285,6 +347,55 @@ private static void addChildrenToParent( parent.setField(trimmedPath, children.get(0)); } + private void addFlatMultiValuedFloatVectorsToParent( + SolrDocument parent, + Map> children, + Set multiValuedVectorFields) { + for (String multiValuedVectorField : multiValuedVectorFields) { + List solrDocuments = children.get(multiValuedVectorField); + List> multiValuedVectors = new ArrayList<>(solrDocuments.size()); + for (SolrDocument singleVector : solrDocuments) { + multiValuedVectors.add( + this.extractFloatVector(singleVector.getFieldValues(multiValuedVectorField))); + } + parent.setField(multiValuedVectorField, multiValuedVectors); + } + } + + private void addFlatMultiValuedByteVectorsToParent( + SolrDocument parent, + Map> children, + Set multiValuedVectorFields) { + for (String multiValuedVectorField : multiValuedVectorFields) { + List solrDocuments = children.get(multiValuedVectorField); + List> multiValuedVectors = new ArrayList<>(solrDocuments.size()); + for (SolrDocument singleVector : solrDocuments) { + 
multiValuedVectors.add( + this.extractByteVector(singleVector.getFieldValues(multiValuedVectorField))); + } + parent.setField(multiValuedVectorField, multiValuedVectors); + } + } + + private List extractFloatVector(Collection fieldValues) { + List vector = new ArrayList<>(fieldValues.size()); + for (Object fieldValue : fieldValues) { + StoredField storedVectorValue = (StoredField) fieldValue; + vector.add(storedVectorValue.numericValue()); + } + return vector; + } + + private List extractByteVector(Collection singleVector) { + StoredField vector = (StoredField) singleVector.iterator().next(); + BytesRef byteVector = vector.binaryValue(); + List extractedVector = new ArrayList<>(byteVector.length); + for (Byte element : byteVector.bytes) { + extractedVector.add(element.byteValue()); + } + return extractedVector; + } + private static String getLastPath(String path) { int lastIndexOfPathSepChar = path.lastIndexOf(PATH_SEP_CHAR); if (lastIndexOfPathSepChar == -1) { diff --git a/solr/core/src/java/org/apache/solr/schema/DenseVectorField.java b/solr/core/src/java/org/apache/solr/schema/DenseVectorField.java index cf5942b12bf0..247b389904b0 100644 --- a/solr/core/src/java/org/apache/solr/schema/DenseVectorField.java +++ b/solr/core/src/java/org/apache/solr/schema/DenseVectorField.java @@ -316,11 +316,6 @@ protected boolean enableDocValuesByDefault() { @Override public void checkSchemaField(final SchemaField field) throws SolrException { super.checkSchemaField(field); - if (field.multiValued()) { - throw new SolrException( - SolrException.ErrorCode.SERVER_ERROR, - getClass().getSimpleName() + " fields can not be multiValued: " + field.getName()); - } if (field.hasDocValues()) { throw new SolrException( diff --git a/solr/core/src/java/org/apache/solr/schema/IndexSchema.java b/solr/core/src/java/org/apache/solr/schema/IndexSchema.java index 3372bcae0650..f0e4f2f49f94 100644 --- a/solr/core/src/java/org/apache/solr/schema/IndexSchema.java +++ 
b/solr/core/src/java/org/apache/solr/schema/IndexSchema.java @@ -106,6 +106,7 @@ public class IndexSchema { public static final String NAME = "name"; public static final String NEST_PARENT_FIELD_NAME = "_nest_parent_"; public static final String NEST_PATH_FIELD_NAME = "_nest_path_"; + public static final String NESTED_VECTORS_PSEUDO_FIELD_NAME = "_nested_vectors_"; public static final String REQUIRED = "required"; public static final String SCHEMA = "schema"; public static final String SIMILARITY = "similarity"; diff --git a/solr/core/src/java/org/apache/solr/update/processor/NestedUpdateProcessorFactory.java b/solr/core/src/java/org/apache/solr/update/processor/NestedUpdateProcessorFactory.java index 96a93b7b45e4..14fb44009f9f 100644 --- a/solr/core/src/java/org/apache/solr/update/processor/NestedUpdateProcessorFactory.java +++ b/solr/core/src/java/org/apache/solr/update/processor/NestedUpdateProcessorFactory.java @@ -17,14 +17,20 @@ package org.apache.solr.update.processor; +import static org.apache.solr.schema.IndexSchema.NESTED_VECTORS_PSEUDO_FIELD_NAME; + import java.io.IOException; +import java.util.ArrayList; import java.util.Collection; +import java.util.List; import org.apache.solr.common.SolrException; import org.apache.solr.common.SolrInputDocument; import org.apache.solr.common.SolrInputField; import org.apache.solr.request.SolrQueryRequest; import org.apache.solr.response.SolrQueryResponse; +import org.apache.solr.schema.DenseVectorField; import org.apache.solr.schema.IndexSchema; +import org.apache.solr.schema.SchemaField; import org.apache.solr.update.AddUpdateCommand; /** @@ -63,6 +69,7 @@ private static class NestedUpdateProcessor extends UpdateRequestProcessor { private boolean storePath; private boolean storeParent; private String uniqueKeyFieldName; + private IndexSchema schema; NestedUpdateProcessor( SolrQueryRequest req, boolean storeParent, boolean storePath, UpdateRequestProcessor next) { @@ -70,6 +77,7 @@ private static class 
NestedUpdateProcessor extends UpdateRequestProcessor { this.storeParent = storeParent; this.storePath = storePath; this.uniqueKeyFieldName = req.getSchema().getUniqueKeyField().getName(); + this.schema = req.getSchema(); } @Override @@ -81,53 +89,98 @@ public void processAdd(AddUpdateCommand cmd) throws IOException { private boolean processDocChildren(SolrInputDocument doc, String fullPath) { boolean isNested = false; + List originalVectorFieldsToRemove = new ArrayList<>(); + ArrayList vectors = new ArrayList<>(); for (SolrInputField field : doc.values()) { + SchemaField sfield = schema.getFieldOrNull(field.getName()); int childNum = 0; boolean isSingleVal = !(field.getValue() instanceof Collection); - for (Object val : field) { - if (!(val instanceof SolrInputDocument cDoc)) { - // either all collection items are child docs or none are. - break; - } - final String fieldName = field.getName(); - - if (fieldName.contains(PATH_SEP_CHAR)) { - throw new SolrException( - SolrException.ErrorCode.BAD_REQUEST, - "Field name: '" - + fieldName - + "' contains: '" - + PATH_SEP_CHAR - + "' , which is reserved for the nested URP"); - } - final String sChildNum = isSingleVal ? SINGULAR_VALUE_CHAR : String.valueOf(childNum); - if (!cDoc.containsKey(uniqueKeyFieldName)) { + boolean firstLevelChildren = fullPath == null; + if (firstLevelChildren && sfield != null && isMultiValuedVectorField(sfield)) { + for (Object vectorValue : field.getValues()) { + SolrInputDocument singleVectorNestedDoc = new SolrInputDocument(); + singleVectorNestedDoc.setField(field.getName(), vectorValue); + final String sChildNum = isSingleVal ? 
SINGULAR_VALUE_CHAR : String.valueOf(childNum); String parentDocId = doc.getField(uniqueKeyFieldName).getFirstValue().toString(); - cDoc.setField( - uniqueKeyFieldName, generateChildUniqueId(parentDocId, fieldName, sChildNum)); + singleVectorNestedDoc.setField( + uniqueKeyFieldName, generateChildUniqueId(parentDocId, field.getName(), sChildNum)); + + if (!isNested) { + isNested = true; + } + final String lastKeyPath = PATH_SEP_CHAR + field.getName() + NUM_SEP_CHAR + sChildNum; + final String childDocPath = firstLevelChildren ? lastKeyPath : fullPath + lastKeyPath; + if (storePath) { + setPathField(singleVectorNestedDoc, childDocPath); + } + if (storeParent) { + setParentKey(singleVectorNestedDoc, doc); + } + ++childNum; + vectors.add(singleVectorNestedDoc); } - if (!isNested) { - isNested = true; + originalVectorFieldsToRemove.add(field.getName()); + } else { + for (Object val : field) { + if (!(val instanceof SolrInputDocument cDoc)) { + // either all collection items are child docs or none are. + break; + } + final String fieldName = field.getName(); + + if (fieldName.contains(PATH_SEP_CHAR)) { + throw new SolrException( + SolrException.ErrorCode.BAD_REQUEST, + "Field name: '" + + fieldName + + "' contains: '" + + PATH_SEP_CHAR + + "' , which is reserved for the nested URP"); + } + final String sChildNum = isSingleVal ? SINGULAR_VALUE_CHAR : String.valueOf(childNum); + if (!cDoc.containsKey(uniqueKeyFieldName)) { + String parentDocId = doc.getField(uniqueKeyFieldName).getFirstValue().toString(); + cDoc.setField( + uniqueKeyFieldName, generateChildUniqueId(parentDocId, fieldName, sChildNum)); + } + if (!isNested) { + isNested = true; + } + final String lastKeyPath = PATH_SEP_CHAR + fieldName + NUM_SEP_CHAR + sChildNum; + // concat of all paths children.grandChild => /children#1/grandChild# + final String childDocPath = firstLevelChildren ? 
lastKeyPath : fullPath + lastKeyPath; + processChildDoc(cDoc, doc, childDocPath); + ++childNum; } - final String lastKeyPath = PATH_SEP_CHAR + fieldName + NUM_SEP_CHAR + sChildNum; - // concat of all paths children.grandChild => /children#1/grandChild# - final String childDocPath = fullPath == null ? lastKeyPath : fullPath + lastKeyPath; - processChildDoc(cDoc, doc, childDocPath); - ++childNum; } } + this.cleanOriginalVectorFields(doc, originalVectorFieldsToRemove); + if (vectors.size() > 0) { + doc.setField(NESTED_VECTORS_PSEUDO_FIELD_NAME, vectors); + } return isNested; } + private void cleanOriginalVectorFields( + SolrInputDocument doc, List originalVectorFieldsToRemove) { + for (String fieldName : originalVectorFieldsToRemove) { + doc.removeField(fieldName); + } + } + + private static boolean isMultiValuedVectorField(SchemaField sfield) { + return sfield.getType() instanceof DenseVectorField && sfield.multiValued(); + } + private void processChildDoc( - SolrInputDocument sdoc, SolrInputDocument parent, String fullPath) { + SolrInputDocument child, SolrInputDocument parent, String fullPath) { if (storePath) { - setPathField(sdoc, fullPath); + setPathField(child, fullPath); } if (storeParent) { - setParentKey(sdoc, parent); + setParentKey(child, parent); } - processDocChildren(sdoc, fullPath); + processDocChildren(child, fullPath); } private String generateChildUniqueId(String parentId, String childKey, String childNum) { @@ -135,12 +188,12 @@ private String generateChildUniqueId(String parentId, String childKey, String ch return parentId + PATH_SEP_CHAR + childKey + NUM_SEP_CHAR + childNum; } - private void setParentKey(SolrInputDocument sdoc, SolrInputDocument parent) { - sdoc.setField(IndexSchema.NEST_PARENT_FIELD_NAME, parent.getFieldValue(uniqueKeyFieldName)); + private void setParentKey(SolrInputDocument child, SolrInputDocument parent) { + child.setField(IndexSchema.NEST_PARENT_FIELD_NAME, parent.getFieldValue(uniqueKeyFieldName)); } - private void 
setPathField(SolrInputDocument sdoc, String fullPath) { - sdoc.setField(IndexSchema.NEST_PATH_FIELD_NAME, fullPath); + private void setPathField(SolrInputDocument child, String fullPath) { + child.setField(IndexSchema.NEST_PATH_FIELD_NAME, fullPath); } } } diff --git a/solr/core/src/test-files/solr/collection1/conf/schema-densevector.xml b/solr/core/src/test-files/solr/collection1/conf/schema-densevector.xml index f3d663a40663..fd7702ea3b9b 100644 --- a/solr/core/src/test-files/solr/collection1/conf/schema-densevector.xml +++ b/solr/core/src/test-files/solr/collection1/conf/schema-densevector.xml @@ -18,8 +18,9 @@ - + + @@ -27,15 +28,25 @@ + + + + + + + + + + diff --git a/solr/core/src/test/org/apache/solr/schema/DenseVectorFieldTest.java b/solr/core/src/test/org/apache/solr/schema/DenseVectorFieldTest.java index 18794907df24..f84eaf40ef0b 100644 --- a/solr/core/src/test/org/apache/solr/schema/DenseVectorFieldTest.java +++ b/solr/core/src/test/org/apache/solr/schema/DenseVectorFieldTest.java @@ -93,14 +93,6 @@ public void fieldDefinition_docValues_shouldThrowException() throws Exception { "DenseVectorField fields can not have docValues: vector"); } - @Test - public void fieldDefinition_multiValued_shouldThrowException() throws Exception { - assertConfigs( - "solrconfig-basic.xml", - "bad-schema-densevector-multivalued.xml", - "DenseVectorField fields can not be multiValued: vector"); - } - @Test public void fieldTypeDefinition_nullSimilarityDistance_shouldUseDefaultSimilarityEuclidean() throws Exception { @@ -726,7 +718,7 @@ public void denseVectorByteEncoding_shouldRaiseExceptionWithValuesOutsideBoundar assertThat( thrown.getCause().getMessage(), is( - "Error while creating field 'vector_byte_encoding{type=knn_vector_byte_encoding,properties=indexed,stored}' from value '[128, 6, 7, 8]'")); + "Error while creating field 'vector_byte_encoding{type=knn_vector_byte_encoding,properties=indexed,stored,omitNorms,omitTermFreqAndPositions,useDocValuesAsStored}' from value 
'[128, 6, 7, 8]'")); assertThat( thrown.getCause().getCause().getMessage(), @@ -748,7 +740,7 @@ public void denseVectorByteEncoding_shouldRaiseExceptionWithValuesOutsideBoundar assertThat( thrown.getCause().getMessage(), is( - "Error while creating field 'vector_byte_encoding{type=knn_vector_byte_encoding,properties=indexed,stored}' from value '[1, -129, 7, 8]'")); + "Error while creating field 'vector_byte_encoding{type=knn_vector_byte_encoding,properties=indexed,stored,omitNorms,omitTermFreqAndPositions,useDocValuesAsStored}' from value '[1, -129, 7, 8]'")); assertThat( thrown.getCause().getCause().getMessage(), is( @@ -777,7 +769,7 @@ public void denseVectorByteEncoding_shouldRaiseExceptionWithFloatValues() throws assertThat( thrown.getCause().getMessage(), is( - "Error while creating field 'vector_byte_encoding{type=knn_vector_byte_encoding,properties=indexed,stored}' from value '[14.3, 6.2, 7.2, 8.1]'")); + "Error while creating field 'vector_byte_encoding{type=knn_vector_byte_encoding,properties=indexed,stored,omitNorms,omitTermFreqAndPositions,useDocValuesAsStored}' from value '[14.3, 6.2, 7.2, 8.1]'")); assertThat( thrown.getCause().getCause().getMessage(), diff --git a/solr/core/src/test/org/apache/solr/search/join/BlockJoinMultiValuedVectorsTest.java b/solr/core/src/test/org/apache/solr/search/join/BlockJoinMultiValuedVectorsTest.java new file mode 100644 index 000000000000..dcf8d9f34a98 --- /dev/null +++ b/solr/core/src/test/org/apache/solr/search/join/BlockJoinMultiValuedVectorsTest.java @@ -0,0 +1,330 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.solr.search.join; + +import java.util.ArrayList; +import java.util.List; +import org.apache.solr.common.SolrInputDocument; +import org.apache.solr.util.RandomNoReverseMergePolicyFactory; +import org.junit.BeforeClass; +import org.junit.ClassRule; +import org.junit.Test; +import org.junit.rules.TestRule; + +public class BlockJoinMultiValuedVectorsTest extends BlockJoinNestedVectorsParentQParserTest { + + protected static String VECTOR_FIELD = "vector_multivalued"; + protected static String VECTOR_BYTE_FIELD = "vector_byte_multivalued"; + + @ClassRule + public static final TestRule noReverseMerge = RandomNoReverseMergePolicyFactory.createRule(); + + @BeforeClass + public static void beforeClass() throws Exception { + /* vectorDimension="4" similarityFunction="cosine" */ + initCore("solrconfig_codec.xml", "schema-densevector.xml"); + prepareIndex(); + } + + protected static void prepareIndex() throws Exception { + List docsToIndex = prepareDocs(); + for (SolrInputDocument doc : docsToIndex) { + updateJ(jsonAdd(doc), null); + } + assertU(commit()); + } + + /** + * The documents in the index are 10 parents, with some parent level metadata and 30 nested + * documents (with vectors and children level metadata) Each parent document has 3 nested + * documents with vectors. + * + *

This allows to run knn queries both at parent/children level and using various pre-filters + * both for parent metadata and children. + * + * @return a list of documents to index + */ + protected static List prepareDocs() { + int totalParentDocuments = 10; + int totalNestedVectors = 30; + int perParentChildren = totalNestedVectors / totalParentDocuments; + + final String[] abcdef = new String[] {"a", "b", "c", "d", "e", "f"}; + + List docs = new ArrayList<>(totalParentDocuments); + for (int i = 1; i < totalParentDocuments + 1; i++) { + SolrInputDocument doc = new SolrInputDocument(); + doc.setField("id", i); + doc.setField("parent_b", true); + doc.setField("parent_s", abcdef[i % abcdef.length]); + List> floatVectors = new ArrayList<>(perParentChildren); + List> byteVectors = new ArrayList<>(perParentChildren); + // nested vector documents have a distance from the query vector inversely proportional to + // their id + for (int j = 0; j < perParentChildren; j++) { + floatVectors.add(outDistanceFloat(FLOAT_QUERY_VECTOR, totalNestedVectors)); + byteVectors.add(outDistanceByte(BYTE_QUERY_VECTOR, totalNestedVectors)); + totalNestedVectors--; // the higher the id of the nested document, lower the distance with + } + doc.setField(VECTOR_FIELD, floatVectors); + doc.setField(VECTOR_BYTE_FIELD, byteVectors); + + docs.add(doc); + } + + return docs; + } + + @Test + public void parentRetrieval_knnChildrenDiversifyingWithNoAllParents_shouldThrowException() { + super.parentRetrieval_knnChildrenDiversifyingWithNoAllParents_shouldThrowException( + VECTOR_FIELD); + } + + @Test + public void childrenRetrievalFloat_filteringByParentMetadata_shouldReturnKnnChildren() { + assertQ( + req( + "fq", "{!child of=$allParents filters=$parent.fq}", + "q", "{!knn f=" + VECTOR_FIELD + " topK=5}" + FLOAT_QUERY_VECTOR, + "fl", "id", + "parent.fq", "parent_s:(a c)", + "allParents", "parent_s:[* TO *]"), + "//*[@numFound='5']", + "//result/doc[1]/str[@name='id'][.='8/vector_multivalued#2']", + 
"//result/doc[2]/str[@name='id'][.='8/vector_multivalued#1']", + "//result/doc[3]/str[@name='id'][.='8/vector_multivalued#0']", + "//result/doc[4]/str[@name='id'][.='6/vector_multivalued#2']", + "//result/doc[5]/str[@name='id'][.='6/vector_multivalued#1']"); + } + + @Test + public void childrenRetrievalByte_filteringByParentMetadata_shouldReturnKnnChildren() { + assertQ( + req( + "fq", "{!child of=$allParents filters=$parent.fq}", + "q", "{!knn f=" + VECTOR_BYTE_FIELD + " topK=5}" + BYTE_QUERY_VECTOR, + "fl", "id", + "parent.fq", "parent_s:(a c)", + "allParents", "parent_s:[* TO *]"), + "//*[@numFound='5']", + "//result/doc[1]/str[@name='id'][.='8/vector_byte_multivalued#2']", + "//result/doc[2]/str[@name='id'][.='8/vector_byte_multivalued#1']", + "//result/doc[3]/str[@name='id'][.='8/vector_byte_multivalued#0']", + "//result/doc[4]/str[@name='id'][.='6/vector_byte_multivalued#2']", + "//result/doc[5]/str[@name='id'][.='6/vector_byte_multivalued#1']"); + } + + @Test + public void parentRetrievalFloat_knnChildren_shouldReturnKnnParents() { + super.parentRetrieval_knnChildren_shouldReturnKnnParents(VECTOR_FIELD); + } + + @Test + public void parentRetrievalFloat_knnChildrenWithNoDiversifying_shouldReturnOneParent() { + super.parentRetrievalFloat_knnChildrenWithNoDiversifying_shouldReturnOneParent(VECTOR_FIELD); + } + + @Test + public void parentRetrievalFloat_knnChildrenWithParentFilter_shouldReturnKnnParents() { + super.parentRetrieval_knnChildrenWithParentFilter_shouldReturnKnnParents(VECTOR_FIELD); + } + + @Test + public void parentRetrievalByte_knnChildren_shouldReturnKnnParents() { + super.parentRetrieval_knnChildren_shouldReturnKnnParents(VECTOR_BYTE_FIELD); + } + + @Test + public void parentRetrievalByte_knnChildrenWithParentFilter_shouldReturnKnnParents() { + super.parentRetrieval_knnChildrenWithParentFilter_shouldReturnKnnParents(VECTOR_BYTE_FIELD); + } + + @Test + public void + 
parentRetrievalFloat_ChildTransformerWithChildFilter_shouldFlattenAndReturnBestChild() { + assertQ( + req( + "q", + "{!parent which=$allParents score=max v=$children.q}", + "fl", + "id," + VECTOR_FIELD + ", [child fl=" + VECTOR_FIELD + " childFilter=$children.q]", + "children.q", + "{!knn f=" + VECTOR_FIELD + " topK=3 childrenOf=$allParents}" + FLOAT_QUERY_VECTOR, + "allParents", + "parent_s:[* TO *]"), + "//*[@numFound='3']", + "//result/doc[1]/str[@name='id'][.='10']", + "//result/doc[1]/arr[@name='" + VECTOR_FIELD + "']/arr/float[1][.='2.0']", + "//result/doc[1]/arr[@name='" + VECTOR_FIELD + "']/arr/float[2][.='1.0']", + "//result/doc[1]/arr[@name='" + VECTOR_FIELD + "']/arr/float[3][.='1.0']", + "//result/doc[1]/arr[@name='" + VECTOR_FIELD + "']/arr/float[4][.='1.0']", + "//result/doc[2]/str[@name='id'][.='9']", + "//result/doc[2]/arr[@name='" + VECTOR_FIELD + "']/arr/float[1][.='5.0']", + "//result/doc[2]/arr[@name='" + VECTOR_FIELD + "']/arr/float[2][.='1.0']", + "//result/doc[2]/arr[@name='" + VECTOR_FIELD + "']/arr/float[3][.='1.0']", + "//result/doc[2]/arr[@name='" + VECTOR_FIELD + "']/arr/float[4][.='1.0']", + "//result/doc[3]/str[@name='id'][.='8']", + "//result/doc[3]/arr[@name='" + VECTOR_FIELD + "']/arr/float[1][.='8.0']", + "//result/doc[3]/arr[@name='" + VECTOR_FIELD + "']/arr/float[2][.='1.0']", + "//result/doc[3]/arr[@name='" + VECTOR_FIELD + "']/arr/float[3][.='1.0']", + "//result/doc[3]/arr[@name='" + VECTOR_FIELD + "']/arr/float[4][.='1.0']"); + } + + @Test + public void parentRetrievalFloat_ChildTransformer_shouldFlattenAndReturnAllChildren() { + assertQ( + req( + "q", + "{!parent which=$allParents score=max v=$children.q}", + "fl", + "id," + VECTOR_FIELD + ", [child fl=" + VECTOR_FIELD + " ]", + "children.q", + "{!knn f=" + VECTOR_FIELD + " topK=3 childrenOf=$allParents}" + FLOAT_QUERY_VECTOR, + "allParents", + "parent_s:[* TO *]"), + "//*[@numFound='3']", + "//result/doc[1]/str[@name='id'][.='10']", + "//result/doc[1]/arr[@name='" + 
VECTOR_FIELD + "']/arr[1]/float[1][.='4.0']", + "//result/doc[1]/arr[@name='" + VECTOR_FIELD + "']/arr[1]/float[2][.='1.0']", + "//result/doc[1]/arr[@name='" + VECTOR_FIELD + "']/arr[1]/float[3][.='1.0']", + "//result/doc[1]/arr[@name='" + VECTOR_FIELD + "']/arr[1]/float[4][.='1.0']", + "//result/doc[1]/arr[@name='" + VECTOR_FIELD + "']/arr[2]/float[1][.='3.0']", + "//result/doc[1]/arr[@name='" + VECTOR_FIELD + "']/arr[2]/float[2][.='1.0']", + "//result/doc[1]/arr[@name='" + VECTOR_FIELD + "']/arr[2]/float[3][.='1.0']", + "//result/doc[1]/arr[@name='" + VECTOR_FIELD + "']/arr[2]/float[4][.='1.0']", + "//result/doc[1]/arr[@name='" + VECTOR_FIELD + "']/arr[3]/float[1][.='2.0']", + "//result/doc[1]/arr[@name='" + VECTOR_FIELD + "']/arr[3]/float[2][.='1.0']", + "//result/doc[1]/arr[@name='" + VECTOR_FIELD + "']/arr[3]/float[3][.='1.0']", + "//result/doc[1]/arr[@name='" + VECTOR_FIELD + "']/arr[3]/float[4][.='1.0']", + "//result/doc[2]/str[@name='id'][.='9']", + "//result/doc[2]/arr[@name='" + VECTOR_FIELD + "']/arr[1]/float[1][.='7.0']", + "//result/doc[2]/arr[@name='" + VECTOR_FIELD + "']/arr[1]/float[2][.='1.0']", + "//result/doc[2]/arr[@name='" + VECTOR_FIELD + "']/arr[1]/float[3][.='1.0']", + "//result/doc[2]/arr[@name='" + VECTOR_FIELD + "']/arr[1]/float[4][.='1.0']", + "//result/doc[2]/arr[@name='" + VECTOR_FIELD + "']/arr[2]/float[1][.='6.0']", + "//result/doc[2]/arr[@name='" + VECTOR_FIELD + "']/arr[2]/float[2][.='1.0']", + "//result/doc[2]/arr[@name='" + VECTOR_FIELD + "']/arr[2]/float[3][.='1.0']", + "//result/doc[2]/arr[@name='" + VECTOR_FIELD + "']/arr[2]/float[4][.='1.0']", + "//result/doc[2]/arr[@name='" + VECTOR_FIELD + "']/arr[3]/float[1][.='5.0']", + "//result/doc[2]/arr[@name='" + VECTOR_FIELD + "']/arr[3]/float[2][.='1.0']", + "//result/doc[2]/arr[@name='" + VECTOR_FIELD + "']/arr[3]/float[3][.='1.0']", + "//result/doc[2]/arr[@name='" + VECTOR_FIELD + "']/arr[3]/float[4][.='1.0']", + "//result/doc[3]/str[@name='id'][.='8']", + 
"//result/doc[3]/arr[@name='" + VECTOR_FIELD + "']/arr[1]/float[1][.='10.0']", + "//result/doc[3]/arr[@name='" + VECTOR_FIELD + "']/arr[1]/float[2][.='1.0']", + "//result/doc[3]/arr[@name='" + VECTOR_FIELD + "']/arr[1]/float[3][.='1.0']", + "//result/doc[3]/arr[@name='" + VECTOR_FIELD + "']/arr[1]/float[4][.='1.0']", + "//result/doc[3]/arr[@name='" + VECTOR_FIELD + "']/arr[2]/float[1][.='9.0']", + "//result/doc[3]/arr[@name='" + VECTOR_FIELD + "']/arr[2]/float[2][.='1.0']", + "//result/doc[3]/arr[@name='" + VECTOR_FIELD + "']/arr[2]/float[3][.='1.0']", + "//result/doc[3]/arr[@name='" + VECTOR_FIELD + "']/arr[2]/float[4][.='1.0']", + "//result/doc[3]/arr[@name='" + VECTOR_FIELD + "']/arr[3]/float[1][.='8.0']", + "//result/doc[3]/arr[@name='" + VECTOR_FIELD + "']/arr[3]/float[2][.='1.0']", + "//result/doc[3]/arr[@name='" + VECTOR_FIELD + "']/arr[3]/float[3][.='1.0']", + "//result/doc[3]/arr[@name='" + VECTOR_FIELD + "']/arr[3]/float[4][.='1.0']"); + } + + @Test + public void + parentRetrievalByte_ChildTransformerWithChildFilter_shouldFlattenAndReturnBestChild() { + assertQ( + req( + "q", + "{!parent which=$allParents score=max v=$children.q}", + "fl", + "id," + + VECTOR_BYTE_FIELD + + ", [child fl=" + + VECTOR_BYTE_FIELD + + " childFilter=$children.q]", + "children.q", + "{!knn f=" + VECTOR_BYTE_FIELD + " topK=3 childrenOf=$allParents}" + BYTE_QUERY_VECTOR, + "allParents", + "parent_s:[* TO *]"), + "//*[@numFound='3']", + "//result/doc[1]/str[@name='id'][.='10']", + "//result/doc[1]/arr[@name='" + VECTOR_BYTE_FIELD + "']/arr/int[1][.='2']", + "//result/doc[1]/arr[@name='" + VECTOR_BYTE_FIELD + "']/arr/int[2][.='1']", + "//result/doc[1]/arr[@name='" + VECTOR_BYTE_FIELD + "']/arr/int[3][.='1']", + "//result/doc[1]/arr[@name='" + VECTOR_BYTE_FIELD + "']/arr/int[4][.='1']", + "//result/doc[2]/str[@name='id'][.='9']", + "//result/doc[2]/arr[@name='" + VECTOR_BYTE_FIELD + "']/arr/int[1][.='5']", + "//result/doc[2]/arr[@name='" + VECTOR_BYTE_FIELD + "']/arr/int[2][.='1']", 
+ "//result/doc[2]/arr[@name='" + VECTOR_BYTE_FIELD + "']/arr/int[3][.='1']", + "//result/doc[2]/arr[@name='" + VECTOR_BYTE_FIELD + "']/arr/int[4][.='1']", + "//result/doc[3]/str[@name='id'][.='8']", + "//result/doc[3]/arr[@name='" + VECTOR_BYTE_FIELD + "']/arr/int[1][.='8']", + "//result/doc[3]/arr[@name='" + VECTOR_BYTE_FIELD + "']/arr/int[2][.='1']", + "//result/doc[3]/arr[@name='" + VECTOR_BYTE_FIELD + "']/arr/int[3][.='1']", + "//result/doc[3]/arr[@name='" + VECTOR_BYTE_FIELD + "']/arr/int[4][.='1']"); + } + + @Test + public void parentRetrievalByte_ChildTransformer_shouldFlattenAndReturnAllChildren() { + assertQ( + req( + "q", + "{!parent which=$allParents score=max v=$children.q}", + "fl", + "id," + VECTOR_BYTE_FIELD + ", [child fl=" + VECTOR_BYTE_FIELD + " ]", + "children.q", + "{!knn f=" + VECTOR_BYTE_FIELD + " topK=3 childrenOf=$allParents}" + BYTE_QUERY_VECTOR, + "allParents", + "parent_s:[* TO *]"), + "//*[@numFound='3']", + "//result/doc[1]/str[@name='id'][.='10']", + "//result/doc[1]/arr[@name='" + VECTOR_BYTE_FIELD + "']/arr[1]/int[1][.='4']", + "//result/doc[1]/arr[@name='" + VECTOR_BYTE_FIELD + "']/arr[1]/int[2][.='1']", + "//result/doc[1]/arr[@name='" + VECTOR_BYTE_FIELD + "']/arr[1]/int[3][.='1']", + "//result/doc[1]/arr[@name='" + VECTOR_BYTE_FIELD + "']/arr[1]/int[4][.='1']", + "//result/doc[1]/arr[@name='" + VECTOR_BYTE_FIELD + "']/arr[2]/int[1][.='3']", + "//result/doc[1]/arr[@name='" + VECTOR_BYTE_FIELD + "']/arr[2]/int[2][.='1']", + "//result/doc[1]/arr[@name='" + VECTOR_BYTE_FIELD + "']/arr[2]/int[3][.='1']", + "//result/doc[1]/arr[@name='" + VECTOR_BYTE_FIELD + "']/arr[2]/int[4][.='1']", + "//result/doc[1]/arr[@name='" + VECTOR_BYTE_FIELD + "']/arr[3]/int[1][.='2']", + "//result/doc[1]/arr[@name='" + VECTOR_BYTE_FIELD + "']/arr[3]/int[2][.='1']", + "//result/doc[1]/arr[@name='" + VECTOR_BYTE_FIELD + "']/arr[3]/int[3][.='1']", + "//result/doc[1]/arr[@name='" + VECTOR_BYTE_FIELD + "']/arr[3]/int[4][.='1']", + 
"//result/doc[2]/str[@name='id'][.='9']", + "//result/doc[2]/arr[@name='" + VECTOR_BYTE_FIELD + "']/arr[1]/int[1][.='7']", + "//result/doc[2]/arr[@name='" + VECTOR_BYTE_FIELD + "']/arr[1]/int[2][.='1']", + "//result/doc[2]/arr[@name='" + VECTOR_BYTE_FIELD + "']/arr[1]/int[3][.='1']", + "//result/doc[2]/arr[@name='" + VECTOR_BYTE_FIELD + "']/arr[1]/int[4][.='1']", + "//result/doc[2]/arr[@name='" + VECTOR_BYTE_FIELD + "']/arr[2]/int[1][.='6']", + "//result/doc[2]/arr[@name='" + VECTOR_BYTE_FIELD + "']/arr[2]/int[2][.='1']", + "//result/doc[2]/arr[@name='" + VECTOR_BYTE_FIELD + "']/arr[2]/int[3][.='1']", + "//result/doc[2]/arr[@name='" + VECTOR_BYTE_FIELD + "']/arr[2]/int[4][.='1']", + "//result/doc[2]/arr[@name='" + VECTOR_BYTE_FIELD + "']/arr[3]/int[1][.='5']", + "//result/doc[2]/arr[@name='" + VECTOR_BYTE_FIELD + "']/arr[3]/int[2][.='1']", + "//result/doc[2]/arr[@name='" + VECTOR_BYTE_FIELD + "']/arr[3]/int[3][.='1']", + "//result/doc[2]/arr[@name='" + VECTOR_BYTE_FIELD + "']/arr[3]/int[4][.='1']", + "//result/doc[3]/str[@name='id'][.='8']", + "//result/doc[3]/arr[@name='" + VECTOR_BYTE_FIELD + "']/arr[1]/int[1][.='10']", + "//result/doc[3]/arr[@name='" + VECTOR_BYTE_FIELD + "']/arr[1]/int[2][.='1']", + "//result/doc[3]/arr[@name='" + VECTOR_BYTE_FIELD + "']/arr[1]/int[3][.='1']", + "//result/doc[3]/arr[@name='" + VECTOR_BYTE_FIELD + "']/arr[1]/int[4][.='1']", + "//result/doc[3]/arr[@name='" + VECTOR_BYTE_FIELD + "']/arr[2]/int[1][.='9']", + "//result/doc[3]/arr[@name='" + VECTOR_BYTE_FIELD + "']/arr[2]/int[2][.='1']", + "//result/doc[3]/arr[@name='" + VECTOR_BYTE_FIELD + "']/arr[2]/int[3][.='1']", + "//result/doc[3]/arr[@name='" + VECTOR_BYTE_FIELD + "']/arr[2]/int[4][.='1']", + "//result/doc[3]/arr[@name='" + VECTOR_BYTE_FIELD + "']/arr[3]/int[1][.='8']", + "//result/doc[3]/arr[@name='" + VECTOR_BYTE_FIELD + "']/arr[3]/int[2][.='1']", + "//result/doc[3]/arr[@name='" + VECTOR_BYTE_FIELD + "']/arr[3]/int[3][.='1']", + "//result/doc[3]/arr[@name='" + 
VECTOR_BYTE_FIELD + "']/arr[3]/int[4][.='1']"); + } +} diff --git a/solr/core/src/test/org/apache/solr/search/join/BlockJoinNestedVectorsParentQParserTest.java b/solr/core/src/test/org/apache/solr/search/join/BlockJoinNestedVectorsParentQParserTest.java new file mode 100644 index 000000000000..8e374ba01c9f --- /dev/null +++ b/solr/core/src/test/org/apache/solr/search/join/BlockJoinNestedVectorsParentQParserTest.java @@ -0,0 +1,556 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.solr.search.join; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import org.apache.solr.SolrTestCaseJ4; + +public class BlockJoinNestedVectorsParentQParserTest extends SolrTestCaseJ4 { + protected static final List FLOAT_QUERY_VECTOR = Arrays.asList(1.0f, 1.0f, 1.0f, 1.0f); + protected static final List BYTE_QUERY_VECTOR = Arrays.asList(1, 1, 1, 1); + + protected static String VECTORS_PSEUDOFIELD = "vectors"; + + /** + * Generate a resulting float vector with a distance from the original vector that is proportional + * to the value in input (higher the value, higher the distance from the original vector) + * + * @param vector a numerical vector + * @param value a numerical value to be added to the first element of the vector + * @return a numerical vector that has a distance from the input vector, proportional to the value + */ + protected static List outDistanceFloat(List vector, int value) { + List result = new ArrayList<>(vector.size()); + for (int i = 0; i < vector.size(); i++) { + if (i == 0) { + result.add(vector.get(i) + value); + } else { + result.add(vector.get(i)); + } + } + return result; + } + + /** + * Generate a resulting byte vector with a distance from the original vector that is proportional + * to the value in input (higher the value, higher the distance from the original vector) + * + * @param vector a numerical vector + * @param value a numerical value to be added to the first element of the vector + * @return a numerical vector that has a distance from the input vector, proportional to the value + */ + protected static List outDistanceByte(List vector, int value) { + List result = new ArrayList<>(vector.size()); + for (int i = 0; i < vector.size(); i++) { + if (i == 0) { + result.add(vector.get(i) + value); + } else { + result.add(vector.get(i)); + } + } + return result; + } + + protected void parentRetrieval_knnChildrenDiversifyingWithNoAllParents_shouldThrowException( + String 
vectorField) { + assertQEx( + "When running a diversifying children KNN query, 'allParents' parameter is required", + req( + "q", + "{!parent which=$allParents score=max v=$children.q}", + "fl", + "id,score", + "children.q", + "{!knn f=" + + vectorField + + " topK=3 parents.preFilter=$someParents}" + + FLOAT_QUERY_VECTOR, + "allParents", + "parent_s:[* TO *]", + "someParents", + "parent_s:(a c)"), + 400); + } + + protected void childrenRetrieval_filteringByParentMetadata_shouldReturnKnnChildren( + String vectorField) { + assertQ( + req( + "fq", "{!child of=$allParents filters=$parent.fq}", + "q", "{!knn f=" + vectorField + " topK=5}" + BYTE_QUERY_VECTOR, + "fl", "id", + "parent.fq", "parent_s:(a c)", + "allParents", "parent_s:[* TO *]"), + "//*[@numFound='5']", + "//result/doc[1]/str[@name='id'][.='82']", + "//result/doc[2]/str[@name='id'][.='81']", + "//result/doc[3]/str[@name='id'][.='80']", + "//result/doc[4]/str[@name='id'][.='62']", + "//result/doc[5]/str[@name='id'][.='61']"); + } + + protected void parentRetrievalFloat_knnChildrenWithNoDiversifying_shouldReturnOneParent( + String vectorField) { + assertQ( + req( + "q", "{!parent which=$allParents score=max v=$children.q}", + "fl", "id,score", + "children.q", "{!knn f=" + vectorField + " topK=3}" + FLOAT_QUERY_VECTOR, + "allParents", "parent_s:[* TO *]"), + "//*[@numFound='1']", + "//result/doc[1]/str[@name='id'][.='10']"); + } + + protected void parentRetrieval_knnChildren_shouldReturnKnnParents(String vectorByteField) { + assertQ( + req( + "q", "{!parent which=$allParents score=max v=$children.q}", + "fl", "id,score", + "children.q", + "{!knn f=" + + vectorByteField + + " topK=3 childrenOf=$allParents}" + + BYTE_QUERY_VECTOR, + "allParents", "parent_s:[* TO *]"), + "//*[@numFound='3']", + "//result/doc[1]/str[@name='id'][.='10']", + "//result/doc[2]/str[@name='id'][.='9']", + "//result/doc[3]/str[@name='id'][.='8']"); + } + + protected void parentRetrieval_knnChildrenWithParentFilter_shouldReturnKnnParents( 
+ String vectorByteField) { + assertQ( + req( + "q", "{!parent which=$allParents score=max v=$children.q}", + "fl", "id,score", + "children.q", + "{!knn f=" + + vectorByteField + + " topK=3 parents.preFilter=$someParents childrenOf=$allParents}" + + BYTE_QUERY_VECTOR, + "allParents", "parent_s:[* TO *]", + "someParents", "parent_s:(a c)"), + "//*[@numFound='3']", + "//result/doc[1]/str[@name='id'][.='8']", + "//result/doc[2]/str[@name='id'][.='6']", + "//result/doc[3]/str[@name='id'][.='2']"); + } + + protected void + parentRetrieval_knnChildrenWithParentFilterAndChildrenFilter_shouldReturnKnnParents( + String vectorByteField) { + assertQ( + req( + "q", "{!parent which=$allParents score=max v=$children.q}", + "fl", "id,score", + "children.q", + "{!knn f=" + + vectorByteField + + " topK=3 preFilter=child_s:m parents.preFilter=$someParents childrenOf=$allParents}" + + BYTE_QUERY_VECTOR, + "allParents", "parent_s:[* TO *]", + "someParents", "parent_s:(a c)"), + "//*[@numFound='2']", + "//result/doc[1]/str[@name='id'][.='8']", + "//result/doc[2]/str[@name='id'][.='2']"); + } + + protected void parentRetrievalFloat_topKWithChildTransformer_shouldReturnAllChildren( + String vectorField) { + assertQ( + req( + "q", "{!parent which=$allParents score=max v=$children.q}", + "fl", + "id,score," + + VECTORS_PSEUDOFIELD + + "," + + vectorField + + ",[child limit=2 fl=vector]", + "children.q", + "{!knn f=" + + vectorField + + " topK=3 parents.preFilter=$someParents childrenOf=$allParents}" + + FLOAT_QUERY_VECTOR, + "allParents", "parent_s:[* TO *]", + "someParents", "parent_s:(a c)"), + "//result[@numFound='3']", + "//result/doc[1]/str[@name='id'][.='8']", + "//result/doc[1]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + vectorField + + "']/float[1][.='10.0']", + "//result/doc[1]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + vectorField + + "']/float[2][.='1.0']", + "//result/doc[1]/arr[@name='" + + VECTORS_PSEUDOFIELD + + 
"'][1]/doc[1]/arr[@name='" + + vectorField + + "']/float[3][.='1.0']", + "//result/doc[1]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + vectorField + + "']/float[4][.='1.0']", + "//result/doc[1]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[2]/arr[@name='" + + vectorField + + "']/float[1][.='9.0']", + "//result/doc[1]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[2]/arr[@name='" + + vectorField + + "']/float[2][.='1.0']", + "//result/doc[1]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[2]/arr[@name='" + + vectorField + + "']/float[3][.='1.0']", + "//result/doc[1]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[2]/arr[@name='" + + vectorField + + "']/float[4][.='1.0']", + "//result/doc[2]/str[@name='id'][.='6']", + "//result/doc[2]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + vectorField + + "']/float[1][.='16.0']", + "//result/doc[2]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + vectorField + + "']/float[2][.='1.0']", + "//result/doc[2]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + vectorField + + "']/float[3][.='1.0']", + "//result/doc[2]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + vectorField + + "']/float[4][.='1.0']", + "//result/doc[2]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[2]/arr[@name='" + + vectorField + + "']/float[1][.='15.0']", + "//result/doc[2]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[2]/arr[@name='" + + vectorField + + "']/float[2][.='1.0']", + "//result/doc[2]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[2]/arr[@name='" + + vectorField + + "']/float[3][.='1.0']", + "//result/doc[2]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[2]/arr[@name='" + + vectorField + + "']/float[4][.='1.0']", + "//result/doc[3]/str[@name='id'][.='2']", + "//result/doc[3]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + vectorField + + "']/float[1][.='28.0']", + 
"//result/doc[3]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + vectorField + + "']/float[2][.='1.0']", + "//result/doc[3]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + vectorField + + "']/float[3][.='1.0']", + "//result/doc[3]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + vectorField + + "']/float[4][.='1.0']", + "//result/doc[3]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[2]/arr[@name='" + + vectorField + + "']/float[1][.='27.0']", + "//result/doc[3]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[2]/arr[@name='" + + vectorField + + "']/float[2][.='1.0']", + "//result/doc[3]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[2]/arr[@name='" + + vectorField + + "']/float[3][.='1.0']", + "//result/doc[3]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[2]/arr[@name='" + + vectorField + + "']/float[4][.='1.0']"); + } + + protected void parentRetrievalByte_topKWithChildTransformer_shouldReturnAllChildren( + String vectorByteField) { + assertQ( + req( + "q", + "{!parent which=$allParents score=max v=$children.q}", + "fl", + "id,score," + + VECTORS_PSEUDOFIELD + + "," + + vectorByteField + + ",[child limit=2 fl=" + + vectorByteField + + "]", + "children.q", + "{!knn f=" + + vectorByteField + + " topK=3 parents.preFilter=$someParents childrenOf=$allParents}" + + BYTE_QUERY_VECTOR, + "allParents", + "parent_s:[* TO *]", + "someParents", + "parent_s:(b c)"), + "//result[@numFound='3']", + "//result/doc[1]/str[@name='id'][.='8']", + "//result/doc[1]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + vectorByteField + + "']/int[1][.='10']", + "//result/doc[1]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + vectorByteField + + "']/int[2][.='1']", + "//result/doc[1]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + vectorByteField + + "']/int[3][.='1']", + "//result/doc[1]/arr[@name='" + + VECTORS_PSEUDOFIELD + + 
"'][1]/doc[1]/arr[@name='" + + vectorByteField + + "']/int[4][.='1']", + "//result/doc[1]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[2]/arr[@name='" + + vectorByteField + + "']/int[1][.='9']", + "//result/doc[1]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[2]/arr[@name='" + + vectorByteField + + "']/int[2][.='1']", + "//result/doc[1]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[2]/arr[@name='" + + vectorByteField + + "']/int[3][.='1']", + "//result/doc[1]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[2]/arr[@name='" + + vectorByteField + + "']/int[4][.='1']", + "//result/doc[2]/str[@name='id'][.='7']", + "//result/doc[2]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + vectorByteField + + "']/int[1][.='13']", + "//result/doc[2]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + vectorByteField + + "']/int[2][.='1']", + "//result/doc[2]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + vectorByteField + + "']/int[3][.='1']", + "//result/doc[2]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + vectorByteField + + "']/int[4][.='1']", + "//result/doc[2]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[2]/arr[@name='" + + vectorByteField + + "']/int[1][.='12']", + "//result/doc[2]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[2]/arr[@name='" + + vectorByteField + + "']/int[2][.='1']", + "//result/doc[2]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[2]/arr[@name='" + + vectorByteField + + "']/int[3][.='1']", + "//result/doc[2]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[2]/arr[@name='" + + vectorByteField + + "']/int[4][.='1']", + "//result/doc[3]/str[@name='id'][.='2']", + "//result/doc[3]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + vectorByteField + + "']/int[1][.='28']", + "//result/doc[3]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + vectorByteField + + "']/int[2][.='1']", + 
"//result/doc[3]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + vectorByteField + + "']/int[3][.='1']", + "//result/doc[3]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + vectorByteField + + "']/int[4][.='1']", + "//result/doc[3]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[2]/arr[@name='" + + vectorByteField + + "']/int[1][.='27']", + "//result/doc[3]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[2]/arr[@name='" + + vectorByteField + + "']/int[2][.='1']", + "//result/doc[3]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[2]/arr[@name='" + + vectorByteField + + "']/int[3][.='1']", + "//result/doc[3]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[2]/arr[@name='" + + vectorByteField + + "']/int[4][.='1']"); + } + + protected void parentRetrievalByte_topKWithChildTransformerWithFilter_shouldReturnBestChild( + String vectorByteField) { + assertQ( + req( + "q", + "{!parent which=$allParents score=max v=$children.q}", + "fl", + "id,score," + + VECTORS_PSEUDOFIELD + + "," + + vectorByteField + + ",[child fl=" + + vectorByteField + + " childFilter=$children.q]", + "children.q", + "{!knn f=" + + vectorByteField + + " topK=3 parents.preFilter=$someParents childrenOf=$allParents}" + + BYTE_QUERY_VECTOR, + "allParents", + "parent_s:[* TO *]", + "someParents", + "parent_s:(b c)"), + "//result[@numFound='3']", + "//result/doc[1]/str[@name='id'][.='8']", + "//result/doc[1]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + vectorByteField + + "']/int[1][.='8']", + "//result/doc[1]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + vectorByteField + + "']/int[2][.='1']", + "//result/doc[1]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + vectorByteField + + "']/int[3][.='1']", + "//result/doc[1]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + vectorByteField + + "']/int[4][.='1']", + 
"//result/doc[2]/str[@name='id'][.='7']", + "//result/doc[2]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + vectorByteField + + "']/int[1][.='11']", + "//result/doc[2]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + vectorByteField + + "']/int[2][.='1']", + "//result/doc[2]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + vectorByteField + + "']/int[3][.='1']", + "//result/doc[2]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + vectorByteField + + "']/int[4][.='1']", + "//result/doc[3]/str[@name='id'][.='2']", + "//result/doc[3]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + vectorByteField + + "']/int[1][.='26']", + "//result/doc[3]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + vectorByteField + + "']/int[2][.='1']", + "//result/doc[3]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + vectorByteField + + "']/int[3][.='1']", + "//result/doc[3]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + vectorByteField + + "']/int[4][.='1']"); + } +} diff --git a/solr/core/src/test/org/apache/solr/search/join/BlockJoinNestedVectorsTest.java b/solr/core/src/test/org/apache/solr/search/join/BlockJoinNestedVectorsTest.java new file mode 100644 index 000000000000..81303600e7da --- /dev/null +++ b/solr/core/src/test/org/apache/solr/search/join/BlockJoinNestedVectorsTest.java @@ -0,0 +1,254 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.solr.search.join; + +import java.util.ArrayList; +import java.util.List; +import org.apache.solr.common.SolrInputDocument; +import org.apache.solr.util.RandomNoReverseMergePolicyFactory; +import org.junit.BeforeClass; +import org.junit.ClassRule; +import org.junit.Test; +import org.junit.rules.TestRule; + +public class BlockJoinNestedVectorsTest extends BlockJoinNestedVectorsParentQParserTest { + protected static String VECTOR_FIELD = "vector"; + protected static String VECTOR_BYTE_FIELD = "vector_byte_encoding"; + + @ClassRule + public static final TestRule noReverseMerge = RandomNoReverseMergePolicyFactory.createRule(); + + @BeforeClass + public static void beforeClass() throws Exception { + /* vectorDimension="4" similarityFunction="cosine" */ + initCore("solrconfig_codec.xml", "schema-densevector.xml"); + prepareIndex(); + } + + protected static void prepareIndex() throws Exception { + List docsToIndex = prepareDocs(); + for (SolrInputDocument doc : docsToIndex) { + assertU(adoc(doc)); + } + assertU(commit()); + } + + /** + * The documents in the index are 10 parents, with some parent level metadata and 30 nested + * documents (with vectors and children level metadata) Each parent document has 3 nested + * documents with vectors. + * + *

This allows to run knn queries both at parent/children level and using various pre-filters + * both for parent metadata and children. + * + * @return a list of documents to index + */ + private static List prepareDocs() { + int totalParentDocuments = 10; + int totalNestedVectors = 30; + int perParentChildren = totalNestedVectors / totalParentDocuments; + + final String[] klm = new String[] {"k", "l", "m"}; + final String[] abcdef = new String[] {"a", "b", "c", "d", "e", "f"}; + + List docs = new ArrayList<>(totalParentDocuments); + for (int i = 1; i < totalParentDocuments + 1; i++) { + SolrInputDocument doc = new SolrInputDocument(); + doc.setField("id", i); + doc.setField("parent_b", true); + + doc.setField("parent_s", abcdef[i % abcdef.length]); + List children = new ArrayList<>(perParentChildren); + + // nested vector documents have a distance from the query vector inversely proportional to + // their id + for (int j = 0; j < perParentChildren; j++) { + SolrInputDocument child = new SolrInputDocument(); + child.setField("id", i + "" + j); + child.setField("child_s", klm[i % klm.length]); + child.setField("vector", outDistanceFloat(FLOAT_QUERY_VECTOR, totalNestedVectors)); + child.setField( + "vector_byte_encoding", outDistanceByte(BYTE_QUERY_VECTOR, totalNestedVectors)); + totalNestedVectors--; // the higher the id of the nested document, lower the distance with + // the query vector + children.add(child); + } + doc.setField("vectors", children); + docs.add(doc); + } + + return docs; + } + + @Test + public void parentRetrieval_knnChildrenDiversifyingWithNoAllParents_shouldThrowException() { + super.parentRetrieval_knnChildrenDiversifyingWithNoAllParents_shouldThrowException( + VECTOR_FIELD); + } + + @Test + public void childrenRetrievalFloat_filteringByParentMetadata_shouldReturnKnnChildren() { + super.childrenRetrieval_filteringByParentMetadata_shouldReturnKnnChildren(VECTOR_FIELD); + } + + @Test + public void 
childrenRetrievalByte_filteringByParentMetadata_shouldReturnKnnChildren() { + super.childrenRetrieval_filteringByParentMetadata_shouldReturnKnnChildren(VECTOR_BYTE_FIELD); + } + + @Test + public void parentRetrievalFloat_knnChildren_shouldReturnKnnParents() { + super.parentRetrieval_knnChildren_shouldReturnKnnParents(VECTOR_FIELD); + } + + @Test + public void parentRetrievalFloat_knnChildrenWithNoDiversifying_shouldReturnOneParent() { + super.parentRetrievalFloat_knnChildrenWithNoDiversifying_shouldReturnOneParent(VECTOR_FIELD); + } + + @Test + public void parentRetrievalFloat_knnChildrenWithParentFilter_shouldReturnKnnParents() { + super.parentRetrieval_knnChildrenWithParentFilter_shouldReturnKnnParents(VECTOR_FIELD); + } + + @Test + public void + parentRetrievalFloat_knnChildrenWithParentFilterAndChildrenFilter_shouldReturnKnnParents() { + super.parentRetrieval_knnChildrenWithParentFilterAndChildrenFilter_shouldReturnKnnParents( + VECTOR_FIELD); + } + + @Test + public void parentRetrievalByte_knnChildren_shouldReturnKnnParents() { + super.parentRetrieval_knnChildren_shouldReturnKnnParents(VECTOR_BYTE_FIELD); + } + + @Test + public void parentRetrievalByte_knnChildrenWithParentFilter_shouldReturnKnnParents() { + super.parentRetrieval_knnChildrenWithParentFilter_shouldReturnKnnParents(VECTOR_BYTE_FIELD); + } + + @Test + public void + parentRetrievalByte_knnChildrenWithParentFilterAndChildrenFilter_shouldReturnKnnParents() { + super.parentRetrieval_knnChildrenWithParentFilterAndChildrenFilter_shouldReturnKnnParents( + VECTOR_BYTE_FIELD); + } + + @Test + public void parentRetrievalFloat_topKWithChildTransformerWithFilter_shouldReturnAllChildren() { + super.parentRetrievalFloat_topKWithChildTransformer_shouldReturnAllChildren(VECTOR_FIELD); + } + + @Test + public void parentRetrievalFloat_topKWithChildTransformerWithFilter_shouldReturnBestChild() { + assertQ( + req( + "q", + "{!parent which=$allParents score=max v=$children.q}", + "fl", + "id,score," + + 
VECTORS_PSEUDOFIELD + + "," + + VECTOR_FIELD + + ",[child fl=vector childFilter=$children.q]", + "children.q", + "{!knn f=" + + VECTOR_FIELD + + " topK=3 parents.preFilter=$someParents childrenOf=$allParents}" + + FLOAT_QUERY_VECTOR, + "allParents", + "parent_s:[* TO *]", + "someParents", + "parent_s:(b c)"), + "//result[@numFound='3']", + "//result/doc[1]/str[@name='id'][.='8']", + "//result/doc[1]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + VECTOR_FIELD + + "']/float[1][.='8.0']", + "//result/doc[1]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + VECTOR_FIELD + + "']/float[2][.='1.0']", + "//result/doc[1]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + VECTOR_FIELD + + "']/float[3][.='1.0']", + "//result/doc[1]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + VECTOR_FIELD + + "']/float[4][.='1.0']", + "//result/doc[2]/str[@name='id'][.='7']", + "//result/doc[2]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + VECTOR_FIELD + + "']/float[1][.='11.0']", + "//result/doc[2]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + VECTOR_FIELD + + "']/float[2][.='1.0']", + "//result/doc[2]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + VECTOR_FIELD + + "']/float[3][.='1.0']", + "//result/doc[2]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + VECTOR_FIELD + + "']/float[4][.='1.0']", + "//result/doc[3]/str[@name='id'][.='2']", + "//result/doc[3]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + VECTOR_FIELD + + "']/float[1][.='26.0']", + "//result/doc[3]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + VECTOR_FIELD + + "']/float[2][.='1.0']", + "//result/doc[3]/arr[@name='" + + VECTORS_PSEUDOFIELD + + "'][1]/doc[1]/arr[@name='" + + VECTOR_FIELD + + "']/float[3][.='1.0']", + "//result/doc[3]/arr[@name='" + + VECTORS_PSEUDOFIELD + + 
"'][1]/doc[1]/arr[@name='" + + VECTOR_FIELD + + "']/float[4][.='1.0']"); + } + + @Test + public void parentRetrievalByte_topKWithChildTransformer_shouldReturnAllChildren() { + super.parentRetrievalByte_topKWithChildTransformer_shouldReturnAllChildren(VECTOR_BYTE_FIELD); + } + + @Test + public void parentRetrievalByte_topKWithChildTransformerWithFilter_shouldReturnBestChild() { + super.parentRetrievalByte_topKWithChildTransformerWithFilter_shouldReturnBestChild( + VECTOR_BYTE_FIELD); + } +} diff --git a/solr/solr-ref-guide/modules/query-guide/pages/dense-vector-search.adoc b/solr/solr-ref-guide/modules/query-guide/pages/dense-vector-search.adoc index 4dc9239fd0b6..94938e64c1a6 100644 --- a/solr/solr-ref-guide/modules/query-guide/pages/dense-vector-search.adoc +++ b/solr/solr-ref-guide/modules/query-guide/pages/dense-vector-search.adoc @@ -174,12 +174,9 @@ For more details, refer to the official https://arxiv.org/pdf/1603.09320[2018 pa Accepted values: Any integer. -`DenseVectorField` supports the attributes: `indexed`, `stored`. +`DenseVectorField` supports the attributes: `indexed`, `stored`, `multivalued`. 
-[NOTE] -currently multivalue is not supported - -Here's how a `DenseVectorField` should be indexed: +Here's how a `DenseVectorField` should be indexed when single-valued: [tabs#densevectorfield-index] ====== @@ -243,6 +240,53 @@ client.add(Arrays.asList(d1, d2)); ==== ====== +Here's how a `DenseVectorField` should be indexed when multi-valued: + +[tabs#densevectorfield-index-multivalued] +====== +JSON:: ++ +==== +[source,json] +---- +[{ "id": "1", + "vector_multivalued": [[1.0, 2.0, 3.0, 4.0],[5.0, 6.0, 7.0, 8.0]] +}, +{ "id": "2", + "vector_multivalued": [[1.0, 2.0, 3.0, 4.0],[5.0, 6.0, 7.0, 8.0]] +} +] +---- +==== + +SolrJ:: ++ +==== +[source,java,indent=0] +---- +final SolrClient client = getSolrClient(); + +final SolrInputDocument d1 = new SolrInputDocument(); +d1.setField("id", "1"); +List<List<Float>> floatVectors1 = new ArrayList<>(2); +floatVectors1.add(Arrays.asList(1.0f, 2.0f, 3.0f, 4.0f)); +floatVectors1.add(Arrays.asList(5.0f, 6.0f, 7.0f, 8.0f)); +d1.setField("vector_multivalued", floatVectors1); + + +final SolrInputDocument d2 = new SolrInputDocument(); +d2.setField("id", "2"); +List<List<Float>> floatVectors2 = new ArrayList<>(2); +floatVectors2.add(Arrays.asList(1.0f, 2.0f, 3.0f, 4.0f)); +floatVectors2.add(Arrays.asList(5.0f, 6.0f, 7.0f, 8.0f)); +d2.setField("vector_multivalued", floatVectors2); + +client.add(Arrays.asList(d1, d2)); + +---- +==== +====== + === ScalarQuantizedDenseVectorField Because dense vectors can have a costly size, it may be worthwhile to use a technique called "quantization" which creates a compressed representation of the original vectors.
This allows more of the index to be stored in faster memory @@ -593,6 +637,79 @@ The search results retrieved are the k=10 nearest documents to the vector encode For more details on how to work with vectorise text in Apache Solr, please refer to the dedicated page: xref:text-to-vector.adoc[Text to Vector] +=== Handle multivalued vector fields at query time +Behind the scenes, a multivalued vector field is handled by Solr as nested documents with a single vector each (see the `knn` query parser parameters that deal with nested vectors: `parents.preFilter` and `childrenOf`). + +So you should query a multivalued vector field using the same syntax: +[source,text] +?q={!parent which=$allParents score=max v=$children.q} +&children.q={!knn f=vector_multivalued topK=3 parents.preFilter=$someParents childrenOf=$allParents}[1.0, 2.0, 3.0, 4.0] +&allParents=*:* -_nest_path_:* +&someParents=color_s:RED + +In terms of rendering the results, you need the child transformer if you want to output them flat (you can choose to only return the best vector per result or all vectors): + +All Children +[source,text] +fl=id,vector_multivalued,[child fl="vector_multivalued"] + +==== +[source,json] +---- +"docs":[ + { + "id":"1", + "vector_multivalued":[ + [ + 1.0,2.0, 3.0, 4.0 + ], + [ + 5.0,6.0, 7.0, 8.0 + ] + ] + }, + { + "id":"2", + "vector_multivalued":[ + [ + 1.0,2.0, 3.0, 4.0 + ], + [ + 5.0,6.0, 7.0, 8.0 + ] + ] + }] +---- +==== + +Best Child +[source,text] +fl=id,vector_multivalued,[child fl="vector_multivalued" childFilter=$children.q] + +==== +[source,json] +---- +"docs":[ + { + "id":"1", + "vector_multivalued":[ + [ + 1.0,2.0, 3.0, 4.0 + ] + ] + }, + { + "id":"2", + "vector_multivalued":[ + [ + 1.0,2.0, 3.0, 4.0 + ] + ] + }] +---- +==== + + === vectorSimilarity Query Parser The `vectorSimilarity` vector similarity query parser matches documents whose similarity with the target vector is a above a minimum threshold.