diff --git a/core/trino-main/src/main/java/io/trino/connector/CatalogServiceProviderModule.java b/core/trino-main/src/main/java/io/trino/connector/CatalogServiceProviderModule.java index 17f03cf2d8f3..84791d0ba852 100644 --- a/core/trino-main/src/main/java/io/trino/connector/CatalogServiceProviderModule.java +++ b/core/trino-main/src/main/java/io/trino/connector/CatalogServiceProviderModule.java @@ -34,6 +34,7 @@ import io.trino.spi.connector.ConnectorPageSinkProvider; import io.trino.spi.connector.ConnectorPageSourceProvider; import io.trino.spi.connector.ConnectorSplitManager; +import io.trino.spi.function.FunctionProvider; import javax.inject.Inject; import javax.inject.Singleton; @@ -162,6 +163,13 @@ public static CatalogServiceProvider> createAcc return new ConnectorCatalogServiceProvider<>("access control", connectorServicesProvider, ConnectorServices::getAccessControl); } + @Provides + @Singleton + public static CatalogServiceProvider createFunctionProvider(ConnectorServicesProvider connectorServicesProvider) + { + return new ConnectorCatalogServiceProvider<>("function provider", connectorServicesProvider, ConnectorServices::getFunctionProvider); + } + private static class ConnectorAccessControlLazyRegister { @Inject diff --git a/core/trino-main/src/main/java/io/trino/connector/ConnectorServices.java b/core/trino-main/src/main/java/io/trino/connector/ConnectorServices.java index be3e78fe49b1..e94561c179a7 100644 --- a/core/trino-main/src/main/java/io/trino/connector/ConnectorServices.java +++ b/core/trino-main/src/main/java/io/trino/connector/ConnectorServices.java @@ -34,6 +34,7 @@ import io.trino.spi.connector.SystemTable; import io.trino.spi.connector.TableProcedureMetadata; import io.trino.spi.eventlistener.EventListener; +import io.trino.spi.function.FunctionProvider; import io.trino.spi.procedure.Procedure; import io.trino.spi.ptf.ArgumentSpecification; import io.trino.spi.ptf.ConnectorTableFunction; @@ -65,6 +66,7 @@ public class ConnectorServices private final Set systemTables; private final CatalogProcedures procedures; private final CatalogTableProcedures tableProcedures; + private final Optional functionProvider; private final CatalogTableFunctions tableFunctions; private final Optional splitManager; private final Optional pageSourceProvider; @@ -101,6 +103,8 @@ public ConnectorServices(CatalogHandle catalogHandle, Connector connector, Runna requireNonNull(procedures, format("Connector '%s' returned a null table procedures set", catalogHandle)); this.tableProcedures = new CatalogTableProcedures(tableProcedures); + this.functionProvider = requireNonNull(connector.getFunctionProvider(), format("Connector '%s' returned a null function provider", catalogHandle)); + Set tableFunctions = connector.getTableFunctions(); requireNonNull(tableFunctions, format("Connector '%s' returned a null table functions set", catalogHandle)); this.tableFunctions = new CatalogTableFunctions(tableFunctions); @@ -226,6 +230,12 @@ public CatalogTableProcedures getTableProcedures() return tableProcedures; } + public FunctionProvider getFunctionProvider() + { + checkArgument(functionProvider.isPresent(), "Connector '%s' does not have functions", catalogHandle); + return functionProvider.get(); + } + public CatalogTableFunctions getTableFunctions() { return tableFunctions; diff --git a/core/trino-main/src/main/java/io/trino/json/CachingResolver.java b/core/trino-main/src/main/java/io/trino/json/CachingResolver.java index 7b337a0114d1..d3c1ec519e47 100644 --- a/core/trino-main/src/main/java/io/trino/json/CachingResolver.java +++ b/core/trino-main/src/main/java/io/trino/json/CachingResolver.java @@ -19,11 +19,11 @@ import io.trino.Session; import io.trino.collect.cache.NonEvictableCache; import io.trino.json.ir.IrPathNode; -import io.trino.metadata.BoundSignature; import io.trino.metadata.Metadata; import io.trino.metadata.OperatorNotFoundException; import io.trino.metadata.ResolvedFunction; import io.trino.spi.connector.ConnectorSession; +import io.trino.spi.function.BoundSignature; import io.trino.spi.function.OperatorType; import io.trino.spi.type.Type; import io.trino.spi.type.TypeManager; diff --git a/core/trino-main/src/main/java/io/trino/metadata/CatalogSchemaFunctionName.java b/core/trino-main/src/main/java/io/trino/metadata/CatalogSchemaFunctionName.java index 6956779f7136..60f39bf23676 100644 --- a/core/trino-main/src/main/java/io/trino/metadata/CatalogSchemaFunctionName.java +++ b/core/trino-main/src/main/java/io/trino/metadata/CatalogSchemaFunctionName.java @@ -13,6 +13,8 @@ */ package io.trino.metadata; +import io.trino.spi.function.SchemaFunctionName; + import java.util.Objects; public final class CatalogSchemaFunctionName diff --git a/core/trino-main/src/main/java/io/trino/metadata/CatalogTableFunctions.java b/core/trino-main/src/main/java/io/trino/metadata/CatalogTableFunctions.java index 98883e1633b0..657cb19d8adf 100644 --- a/core/trino-main/src/main/java/io/trino/metadata/CatalogTableFunctions.java +++ b/core/trino-main/src/main/java/io/trino/metadata/CatalogTableFunctions.java @@ -14,6 +14,7 @@ package io.trino.metadata; import com.google.common.collect.Maps; +import io.trino.spi.function.SchemaFunctionName; import io.trino.spi.ptf.ConnectorTableFunction; import javax.annotation.concurrent.ThreadSafe; diff --git a/core/trino-main/src/main/java/io/trino/metadata/FunctionBinding.java b/core/trino-main/src/main/java/io/trino/metadata/FunctionBinding.java index d8ec689a19f7..548e086b80eb 100644 --- a/core/trino-main/src/main/java/io/trino/metadata/FunctionBinding.java +++ b/core/trino-main/src/main/java/io/trino/metadata/FunctionBinding.java @@ -14,6 +14,8 @@ package io.trino.metadata; import com.google.common.collect.ImmutableSortedMap; +import io.trino.spi.function.BoundSignature; +import io.trino.spi.function.FunctionId; import io.trino.spi.type.Type; import java.util.Map; diff --git a/core/trino-main/src/main/java/io/trino/metadata/FunctionBundle.java b/core/trino-main/src/main/java/io/trino/metadata/FunctionBundle.java index ef4af5d2ba7c..63167fb3ff1e 100644 --- a/core/trino-main/src/main/java/io/trino/metadata/FunctionBundle.java +++ b/core/trino-main/src/main/java/io/trino/metadata/FunctionBundle.java @@ -13,9 +13,16 @@ */ package io.trino.metadata; -import io.trino.operator.aggregation.AggregationMetadata; -import io.trino.operator.window.WindowFunctionSupplier; +import io.trino.spi.function.AggregationFunctionMetadata; +import io.trino.spi.function.AggregationImplementation; +import io.trino.spi.function.BoundSignature; +import io.trino.spi.function.FunctionDependencies; +import io.trino.spi.function.FunctionDependencyDeclaration; +import io.trino.spi.function.FunctionId; +import io.trino.spi.function.FunctionMetadata; import io.trino.spi.function.InvocationConvention; +import io.trino.spi.function.ScalarFunctionImplementation; +import io.trino.spi.function.WindowFunctionSupplier; import java.util.Collection; @@ -27,13 +34,13 @@ public interface FunctionBundle FunctionDependencyDeclaration getFunctionDependencies(FunctionId functionId, BoundSignature boundSignature); - FunctionInvoker getScalarFunctionInvoker( + ScalarFunctionImplementation getScalarFunctionImplementation( FunctionId functionId, BoundSignature boundSignature, FunctionDependencies functionDependencies, InvocationConvention invocationConvention); - AggregationMetadata getAggregateFunctionImplementation(FunctionId functionId, BoundSignature boundSignature, FunctionDependencies functionDependencies); + AggregationImplementation getAggregationImplementation(FunctionId functionId, BoundSignature boundSignature, FunctionDependencies functionDependencies); - WindowFunctionSupplier getWindowFunctionImplementation(FunctionId functionId, BoundSignature boundSignature, FunctionDependencies functionDependencies); + WindowFunctionSupplier getWindowFunctionSupplier(FunctionId functionId, BoundSignature boundSignature, FunctionDependencies functionDependencies); } diff --git a/core/trino-main/src/main/java/io/trino/metadata/FunctionInvoker.java b/core/trino-main/src/main/java/io/trino/metadata/FunctionInvoker.java deleted file mode 100644 index 22b81535e837..000000000000 --- a/core/trino-main/src/main/java/io/trino/metadata/FunctionInvoker.java +++ /dev/null @@ -1,58 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.trino.metadata; - -import com.google.common.collect.ImmutableList; - -import java.lang.invoke.MethodHandle; -import java.util.List; -import java.util.Optional; - -import static java.util.Objects.requireNonNull; - -public class FunctionInvoker -{ - private final MethodHandle methodHandle; - private final Optional instanceFactory; - private final List> lambdaInterfaces; - - public FunctionInvoker(MethodHandle methodHandle, Optional instanceFactory) - { - this.methodHandle = requireNonNull(methodHandle, "methodHandle is null"); - this.instanceFactory = requireNonNull(instanceFactory, "instanceFactory is null"); - this.lambdaInterfaces = ImmutableList.of(); - } - - public FunctionInvoker(MethodHandle methodHandle, Optional instanceFactory, List> lambdaInterfaces) - { - this.methodHandle = requireNonNull(methodHandle, "methodHandle is null"); - this.instanceFactory = requireNonNull(instanceFactory, "instanceFactory is null"); - this.lambdaInterfaces = requireNonNull(lambdaInterfaces, "lambdaInterfaces is null"); - } - - public MethodHandle getMethodHandle() - { - return methodHandle; - } - - public Optional getInstanceFactory() - { - return instanceFactory; - } - - public List> getLambdaInterfaces() - { - return lambdaInterfaces; - } -} diff --git a/core/trino-main/src/main/java/io/trino/metadata/FunctionManager.java b/core/trino-main/src/main/java/io/trino/metadata/FunctionManager.java index 8c7cb11f60c6..e5a63bc53f66 100644 --- a/core/trino-main/src/main/java/io/trino/metadata/FunctionManager.java +++ b/core/trino-main/src/main/java/io/trino/metadata/FunctionManager.java @@ -17,14 +17,21 @@ import com.google.common.util.concurrent.UncheckedExecutionException; import io.trino.FeaturesConfig; import io.trino.collect.cache.NonEvictableCache; -import io.trino.operator.aggregation.AggregationMetadata; -import io.trino.operator.window.WindowFunctionSupplier; +import io.trino.connector.CatalogServiceProvider; +import io.trino.connector.system.GlobalSystemConnector; import io.trino.spi.TrinoException; import io.trino.spi.block.Block; import io.trino.spi.connector.ConnectorSession; +import io.trino.spi.function.AggregationImplementation; +import io.trino.spi.function.BoundSignature; +import io.trino.spi.function.FunctionDependencies; +import io.trino.spi.function.FunctionId; +import io.trino.spi.function.FunctionProvider; import io.trino.spi.function.InOut; import io.trino.spi.function.InvocationConvention; import io.trino.spi.function.InvocationConvention.InvocationArgumentConvention; +import io.trino.spi.function.ScalarFunctionImplementation; +import io.trino.spi.function.WindowFunctionSupplier; import io.trino.spi.type.Type; import io.trino.spi.type.TypeOperators; import io.trino.type.BlockTypeOperators; @@ -51,14 +58,15 @@ public class FunctionManager { - private final NonEvictableCache specializedScalarCache; - private final NonEvictableCache specializedAggregationCache; + private final NonEvictableCache specializedScalarCache; + private final NonEvictableCache specializedAggregationCache; private final NonEvictableCache specializedWindowCache; + private final CatalogServiceProvider functionProviders; private final GlobalFunctionCatalog globalFunctionCatalog; @Inject - public FunctionManager(GlobalFunctionCatalog globalFunctionCatalog) + public FunctionManager(CatalogServiceProvider functionProviders, GlobalFunctionCatalog globalFunctionCatalog) { specializedScalarCache = buildNonEvictableCache(CacheBuilder.newBuilder() .maximumSize(1000) @@ -72,13 +80,14 @@ public FunctionManager(GlobalFunctionCatalog globalFunctionCatalog) .maximumSize(1000) .expireAfterWrite(1, HOURS)); - this.globalFunctionCatalog = globalFunctionCatalog; + this.functionProviders = requireNonNull(functionProviders, "functionProviders is null"); + this.globalFunctionCatalog = requireNonNull(globalFunctionCatalog, "globalFunctionCatalog is null"); } - public FunctionInvoker getScalarFunctionInvoker(ResolvedFunction resolvedFunction, InvocationConvention invocationConvention) + public ScalarFunctionImplementation getScalarFunctionImplementation(ResolvedFunction resolvedFunction, InvocationConvention invocationConvention) { try { - return uncheckedCacheGet(specializedScalarCache, new FunctionKey(resolvedFunction, invocationConvention), () -> getScalarFunctionInvokerInternal(resolvedFunction, invocationConvention)); + return uncheckedCacheGet(specializedScalarCache, new FunctionKey(resolvedFunction, invocationConvention), () -> getScalarFunctionImplementationInternal(resolvedFunction, invocationConvention)); } catch (UncheckedExecutionException e) { throwIfInstanceOf(e.getCause(), TrinoException.class); @@ -86,22 +95,22 @@ public FunctionInvoker getScalarFunctionInvoker(ResolvedFunction resolvedFunctio } } - private FunctionInvoker getScalarFunctionInvokerInternal(ResolvedFunction resolvedFunction, InvocationConvention invocationConvention) + private ScalarFunctionImplementation getScalarFunctionImplementationInternal(ResolvedFunction resolvedFunction, InvocationConvention invocationConvention) { FunctionDependencies functionDependencies = getFunctionDependencies(resolvedFunction); - FunctionInvoker functionInvoker = globalFunctionCatalog.getScalarFunctionInvoker( + ScalarFunctionImplementation scalarFunctionImplementation = getFunctionProvider(resolvedFunction).getScalarFunctionImplementation( resolvedFunction.getFunctionId(), resolvedFunction.getSignature(), functionDependencies, invocationConvention); - verifyMethodHandleSignature(resolvedFunction.getSignature(), functionInvoker, invocationConvention); - return functionInvoker; + verifyMethodHandleSignature(resolvedFunction.getSignature(), scalarFunctionImplementation, invocationConvention); + return scalarFunctionImplementation; } - public AggregationMetadata getAggregateFunctionImplementation(ResolvedFunction resolvedFunction) + public AggregationImplementation getAggregationImplementation(ResolvedFunction resolvedFunction) { try { - return uncheckedCacheGet(specializedAggregationCache, new FunctionKey(resolvedFunction), () -> getAggregateFunctionImplementationInternal(resolvedFunction)); + return uncheckedCacheGet(specializedAggregationCache, new FunctionKey(resolvedFunction), () -> getAggregationImplementationInternal(resolvedFunction)); } catch (UncheckedExecutionException e) { throwIfInstanceOf(e.getCause(), TrinoException.class); @@ -109,19 +118,19 @@ public AggregationMetadata getAggregateFunctionImplementation(ResolvedFunction r } } - private AggregationMetadata getAggregateFunctionImplementationInternal(ResolvedFunction resolvedFunction) + private AggregationImplementation getAggregationImplementationInternal(ResolvedFunction resolvedFunction) { FunctionDependencies functionDependencies = getFunctionDependencies(resolvedFunction); - return globalFunctionCatalog.getAggregateFunctionImplementation( + return getFunctionProvider(resolvedFunction).getAggregationImplementation( resolvedFunction.getFunctionId(), resolvedFunction.getSignature(), functionDependencies); } - public WindowFunctionSupplier getWindowFunctionImplementation(ResolvedFunction resolvedFunction) + public WindowFunctionSupplier getWindowFunctionSupplier(ResolvedFunction resolvedFunction) { try { - return uncheckedCacheGet(specializedWindowCache, new FunctionKey(resolvedFunction), () -> getWindowFunctionImplementationInternal(resolvedFunction)); + return uncheckedCacheGet(specializedWindowCache, new FunctionKey(resolvedFunction), () -> getWindowFunctionSupplierInternal(resolvedFunction)); } catch (UncheckedExecutionException e) { throwIfInstanceOf(e.getCause(), TrinoException.class); @@ -129,10 +138,10 @@ public WindowFunctionSupplier getWindowFunctionImplementation(ResolvedFunction r } } - private WindowFunctionSupplier getWindowFunctionImplementationInternal(ResolvedFunction resolvedFunction) + private WindowFunctionSupplier getWindowFunctionSupplierInternal(ResolvedFunction resolvedFunction) { FunctionDependencies functionDependencies = getFunctionDependencies(resolvedFunction); - return globalFunctionCatalog.getWindowFunctionImplementation( + return getFunctionProvider(resolvedFunction).getWindowFunctionSupplier( resolvedFunction.getFunctionId(), resolvedFunction.getSignature(), functionDependencies); @@ -140,12 +149,23 @@ private WindowFunctionSupplier getWindowFunctionImplementationInternal(ResolvedF private FunctionDependencies getFunctionDependencies(ResolvedFunction resolvedFunction) { - return new FunctionDependencies(this::getScalarFunctionInvoker, resolvedFunction.getTypeDependencies(), resolvedFunction.getFunctionDependencies()); + return new InternalFunctionDependencies(this::getScalarFunctionImplementation, resolvedFunction.getTypeDependencies(), resolvedFunction.getFunctionDependencies()); } - private static void verifyMethodHandleSignature(BoundSignature boundSignature, FunctionInvoker functionInvoker, InvocationConvention convention) + private FunctionProvider getFunctionProvider(ResolvedFunction resolvedFunction) { - MethodHandle methodHandle = functionInvoker.getMethodHandle(); + if (resolvedFunction.getCatalogHandle().equals(GlobalSystemConnector.CATALOG_HANDLE)) { + return globalFunctionCatalog; + } + + FunctionProvider functionProvider = functionProviders.getService(resolvedFunction.getCatalogHandle()); + checkArgument(functionProvider != null, "No function provider for catalog: '%s' (function '%s')", resolvedFunction.getCatalogHandle(), resolvedFunction.getSignature().getName()); + return functionProvider; + } + + private static void verifyMethodHandleSignature(BoundSignature boundSignature, ScalarFunctionImplementation scalarFunctionImplementation, InvocationConvention convention) + { + MethodHandle methodHandle = scalarFunctionImplementation.getMethodHandle(); MethodType methodType = methodHandle.type(); checkArgument(convention.getArgumentConventions().size() == boundSignature.getArgumentTypes().size(), @@ -155,16 +175,16 @@ private static void verifyMethodHandleSignature(BoundSignature boundSignature, F .mapToInt(InvocationArgumentConvention::getParameterCount) .sum(); expectedParameterCount += methodType.parameterList().stream().filter(ConnectorSession.class::equals).count(); - if (functionInvoker.getInstanceFactory().isPresent()) { + if (scalarFunctionImplementation.getInstanceFactory().isPresent()) { expectedParameterCount++; } checkArgument(expectedParameterCount == methodType.parameterCount(), "Expected %s method parameters, but got %s", expectedParameterCount, methodType.parameterCount()); int parameterIndex = 0; - if (functionInvoker.getInstanceFactory().isPresent()) { + if (scalarFunctionImplementation.getInstanceFactory().isPresent()) { verifyFunctionSignature(convention.supportsInstanceFactory(), "Method requires instance factory, but calling convention does not support an instance factory"); - MethodHandle factoryMethod = functionInvoker.getInstanceFactory().orElseThrow(); + MethodHandle factoryMethod = scalarFunctionImplementation.getInstanceFactory().orElseThrow(); verifyFunctionSignature(methodType.parameterType(parameterIndex).equals(factoryMethod.type().returnType()), "Invalid return type"); parameterIndex++; } @@ -203,7 +223,7 @@ private static void verifyMethodHandleSignature(BoundSignature boundSignature, F verifyFunctionSignature(parameterType.equals(InOut.class), "Expected IN_OUT argument type to be InOut"); break; case FUNCTION: - Class lambdaInterface = functionInvoker.getLambdaInterfaces().get(lambdaArgumentIndex); + Class lambdaInterface = scalarFunctionImplementation.getLambdaInterfaces().get(lambdaArgumentIndex); verifyFunctionSignature(parameterType.equals(lambdaInterface), "Expected function interface to be %s, but is %s", lambdaInterface, parameterType); lambdaArgumentIndex++; @@ -297,6 +317,6 @@ public static FunctionManager createTestingFunctionManager() GlobalFunctionCatalog functionCatalog = new GlobalFunctionCatalog(); functionCatalog.addFunctions(SystemFunctionBundle.create(new FeaturesConfig(), typeOperators, new BlockTypeOperators(typeOperators), UNKNOWN)); functionCatalog.addFunctions(new InternalFunctionBundle(new LiteralFunction(new InternalBlockEncodingSerde(new BlockEncodingManager(), TESTING_TYPE_MANAGER)))); - return new FunctionManager(functionCatalog); + return new FunctionManager(CatalogServiceProvider.fail(), functionCatalog); } } diff --git a/core/trino-main/src/main/java/io/trino/metadata/FunctionResolver.java b/core/trino-main/src/main/java/io/trino/metadata/FunctionResolver.java index eeb9f387a607..6674d38a0a0b 100644 --- a/core/trino-main/src/main/java/io/trino/metadata/FunctionResolver.java +++ b/core/trino-main/src/main/java/io/trino/metadata/FunctionResolver.java @@ -17,13 +17,19 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.Ordering; import io.trino.Session; +import io.trino.connector.CatalogHandle; +import io.trino.connector.system.GlobalSystemConnector; import io.trino.spi.TrinoException; +import io.trino.spi.function.BoundSignature; +import io.trino.spi.function.FunctionMetadata; +import io.trino.spi.function.FunctionNullability; +import io.trino.spi.function.QualifiedFunctionName; +import io.trino.spi.function.Signature; import io.trino.spi.type.Type; import io.trino.spi.type.TypeManager; import io.trino.sql.SqlPathElement; import io.trino.sql.analyzer.TypeSignatureProvider; import io.trino.sql.tree.Identifier; -import io.trino.sql.tree.QualifiedName; import java.util.ArrayList; import java.util.Collection; @@ -38,8 +44,7 @@ import static com.google.common.base.Preconditions.checkState; import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.Iterables.getOnlyElement; -import static io.trino.metadata.GlobalFunctionCatalog.GLOBAL_CATALOG; -import static io.trino.metadata.GlobalFunctionCatalog.GLOBAL_SCHEMA; +import static io.trino.metadata.GlobalFunctionCatalog.BUILTIN_SCHEMA; import static io.trino.spi.StandardErrorCode.AMBIGUOUS_FUNCTION_CALL; import static io.trino.spi.StandardErrorCode.FUNCTION_IMPLEMENTATION_MISSING; import static io.trino.spi.StandardErrorCode.FUNCTION_NOT_FOUND; @@ -61,12 +66,13 @@ public FunctionResolver(Metadata metadata, TypeManager typeManager) this.typeManager = requireNonNull(typeManager, "typeManager is null"); } - boolean isAggregationFunction(Session session, QualifiedName name, Function> candidateLoader) + boolean isAggregationFunction(Session session, QualifiedFunctionName name, Function> candidateLoader) { for (CatalogSchemaFunctionName catalogSchemaFunctionName : toPath(session, name)) { - Collection candidates = candidateLoader.apply(catalogSchemaFunctionName); + Collection candidates = candidateLoader.apply(catalogSchemaFunctionName); if (!candidates.isEmpty()) { return candidates.stream() + .map(CatalogFunctionMetadata::getFunctionMetadata) .map(FunctionMetadata::getKind) .anyMatch(AGGREGATE::equals); } @@ -74,25 +80,25 @@ boolean isAggregationFunction(Session session, QualifiedName name, Function> candidateLoader) + CatalogFunctionBinding resolveCoercion(Session session, QualifiedFunctionName name, Signature signature, Function> candidateLoader) { for (CatalogSchemaFunctionName catalogSchemaFunctionName : toPath(session, name)) { - Collection candidates = candidateLoader.apply(catalogSchemaFunctionName); - List exactCandidates = candidates.stream() - .filter(function -> possibleExactCastMatch(signature, function.getSignature())) + Collection candidates = candidateLoader.apply(catalogSchemaFunctionName); + List exactCandidates = candidates.stream() + .filter(function -> possibleExactCastMatch(signature, function.getFunctionMetadata().getSignature())) .collect(toImmutableList()); - for (FunctionMetadata candidate : exactCandidates) { - if (canBindSignature(session, candidate.getSignature(), signature)) { + for (CatalogFunctionMetadata candidate : exactCandidates) { + if (canBindSignature(session, candidate.getFunctionMetadata().getSignature(), signature)) { return toFunctionBinding(candidate, signature); } } // only consider generic genericCandidates - List genericCandidates = candidates.stream() - .filter(function -> !function.getSignature().getTypeVariableConstraints().isEmpty()) + List genericCandidates = candidates.stream() + .filter(function -> !function.getFunctionMetadata().getSignature().getTypeVariableConstraints().isEmpty()) .collect(toImmutableList()); - for (FunctionMetadata candidate : genericCandidates) { - if (canBindSignature(session, candidate.getSignature(), signature)) { + for (CatalogFunctionMetadata candidate : genericCandidates) { + if (canBindSignature(session, candidate.getFunctionMetadata().getSignature(), signature)) { return toFunctionBinding(candidate, signature); } } @@ -107,7 +113,7 @@ private boolean canBindSignature(Session session, Signature declaredSignature, S .canBind(fromTypeSignatures(actualSignature.getArgumentTypes()), actualSignature.getReturnType()); } - private FunctionBinding toFunctionBinding(FunctionMetadata functionMetadata, Signature signature) + private CatalogFunctionBinding toFunctionBinding(CatalogFunctionMetadata functionMetadata, Signature signature) { BoundSignature boundSignature = new BoundSignature( signature.getName(), @@ -115,10 +121,12 @@ private FunctionBinding toFunctionBinding(FunctionMetadata functionMetadata, Sig signature.getArgumentTypes().stream() .map(typeManager::getType) .collect(toImmutableList())); - return SignatureBinder.bindFunction( - functionMetadata.getFunctionId(), - functionMetadata.getSignature(), - boundSignature); + return new CatalogFunctionBinding( + functionMetadata.getCatalogHandle(), + SignatureBinder.bindFunction( + functionMetadata.getFunctionMetadata().getFunctionId(), + functionMetadata.getFunctionMetadata().getSignature(), + boundSignature)); } private static boolean possibleExactCastMatch(Signature signature, Signature declaredSignature) @@ -135,26 +143,26 @@ private static boolean possibleExactCastMatch(Signature signature, Signature dec return true; } - FunctionBinding resolveFunction( + CatalogFunctionBinding resolveFunction( Session session, - QualifiedName name, + QualifiedFunctionName name, List parameterTypes, - Function> candidateLoader) + Function> candidateLoader) { - ImmutableList.Builder allCandidates = ImmutableList.builder(); + ImmutableList.Builder allCandidates = ImmutableList.builder(); for (CatalogSchemaFunctionName catalogSchemaFunctionName : toPath(session, name)) { - Collection candidates = candidateLoader.apply(catalogSchemaFunctionName); - List exactCandidates = candidates.stream() - .filter(function -> function.getSignature().getTypeVariableConstraints().isEmpty()) + Collection candidates = candidateLoader.apply(catalogSchemaFunctionName); + List exactCandidates = candidates.stream() + .filter(function -> function.getFunctionMetadata().getSignature().getTypeVariableConstraints().isEmpty()) .collect(toImmutableList()); - Optional match = matchFunctionExact(session, exactCandidates, parameterTypes); + Optional match = matchFunctionExact(session, exactCandidates, parameterTypes); if (match.isPresent()) { return match.get(); } - List genericCandidates = candidates.stream() - .filter(function -> !function.getSignature().getTypeVariableConstraints().isEmpty()) + List genericCandidates = candidates.stream() + .filter(function -> !function.getFunctionMetadata().getSignature().getTypeVariableConstraints().isEmpty()) .collect(toImmutableList()); match = matchFunctionExact(session, genericCandidates, parameterTypes); @@ -170,15 +178,15 @@ FunctionBinding resolveFunction( allCandidates.addAll(candidates); } - List candidates = allCandidates.build(); + List candidates = allCandidates.build(); if (candidates.isEmpty()) { throw new TrinoException(FUNCTION_NOT_FOUND, format("Function '%s' not registered", name)); } List expectedParameters = new ArrayList<>(); - for (FunctionMetadata function : candidates) { - String arguments = Joiner.on(", ").join(function.getSignature().getArgumentTypes()); - String constraints = Joiner.on(", ").join(function.getSignature().getTypeVariableConstraints()); + for (CatalogFunctionMetadata function : candidates) { + String arguments = Joiner.on(", ").join(function.getFunctionMetadata().getSignature().getArgumentTypes()); + String constraints = Joiner.on(", ").join(function.getFunctionMetadata().getSignature().getTypeVariableConstraints()); expectedParameters.add(format("%s(%s) %s", name, arguments, constraints).stripTrailing()); } @@ -188,46 +196,43 @@ FunctionBinding resolveFunction( throw new TrinoException(FUNCTION_NOT_FOUND, message); } - public static List toPath(Session session, QualifiedName name) + public static List toPath(Session session, QualifiedFunctionName name) { - List parts = name.getParts(); - checkArgument(parts.size() <= 3, "Function name can only have 3 parts: " + name); - if (parts.size() == 3) { - return ImmutableList.of(new CatalogSchemaFunctionName(parts.get(0), parts.get(1), parts.get(2))); + if (name.getCatalogName().isPresent()) { + return ImmutableList.of(new CatalogSchemaFunctionName(name.getCatalogName().orElseThrow(), name.getSchemaName().orElseThrow(), name.getFunctionName())); } - if (parts.size() == 2) { + if (name.getSchemaName().isPresent()) { String currentCatalog = session.getCatalog() .orElseThrow(() -> new IllegalArgumentException("Session default catalog must be set to resolve a partial function name: " + name)); - return ImmutableList.of(new CatalogSchemaFunctionName(currentCatalog, parts.get(0), parts.get(1))); + return ImmutableList.of(new CatalogSchemaFunctionName(currentCatalog, name.getSchemaName().orElseThrow(), name.getFunctionName())); } ImmutableList.Builder names = ImmutableList.builder(); - String functionName = parts.get(0); // global namespace - names.add(new CatalogSchemaFunctionName(GLOBAL_CATALOG, GLOBAL_SCHEMA, functionName)); + names.add(new CatalogSchemaFunctionName(GlobalSystemConnector.NAME, BUILTIN_SCHEMA, name.getFunctionName())); // add resolved path items for (SqlPathElement sqlPathElement : session.getPath().getParsedPath()) { String catalog = sqlPathElement.getCatalog().map(Identifier::getCanonicalValue).or(session::getCatalog) .orElseThrow(() -> new IllegalArgumentException("Session default catalog must be set to resolve a partial function name: " + name)); - names.add(new CatalogSchemaFunctionName(catalog, sqlPathElement.getSchema().getCanonicalValue(), functionName)); + names.add(new CatalogSchemaFunctionName(catalog, sqlPathElement.getSchema().getCanonicalValue(), name.getFunctionName())); } return names.build(); } - private Optional matchFunctionExact(Session session, List candidates, List actualParameters) + private Optional matchFunctionExact(Session session, List candidates, List actualParameters) { return matchFunction(session, candidates, actualParameters, false); } - private Optional matchFunctionWithCoercion(Session session, Collection candidates, List actualParameters) + private Optional matchFunctionWithCoercion(Session session, Collection candidates, List actualParameters) { return matchFunction(session, candidates, actualParameters, true); } - private Optional matchFunction(Session session, Collection candidates, List parameters, boolean coercionAllowed) + private Optional matchFunction(Session session, Collection candidates, List parameters, boolean coercionAllowed) { List applicableFunctions = identifyApplicableFunctions(session, candidates, parameters, coercionAllowed); if (applicableFunctions.isEmpty()) { @@ -255,11 +260,11 @@ private Optional matchFunction(Session session, Collection identifyApplicableFunctions(Session session, Collection candidates, List actualParameters, boolean allowCoercion) + private List identifyApplicableFunctions(Session session, Collection candidates, List actualParameters, boolean allowCoercion) { ImmutableList.Builder applicableFunctions = ImmutableList.builder(); - for (FunctionMetadata function : candidates) { - new SignatureBinder(session, metadata, typeManager, function.getSignature(), allowCoercion) + for (CatalogFunctionMetadata function : candidates) { + new SignatureBinder(session, metadata, typeManager, function.getFunctionMetadata().getSignature(), allowCoercion) .bind(actualParameters) .ifPresent(signature -> applicableFunctions.add(new ApplicableFunction(function, signature))); } @@ -376,7 +381,7 @@ private static boolean allReturnNullOnGivenInputTypes(List a private static boolean returnsNullOnGivenInputTypes(ApplicableFunction applicableFunction, List parameterTypes) { - FunctionMetadata function = applicableFunction.getFunction(); + FunctionMetadata function = applicableFunction.getFunctionMetadata(); // Window and Aggregation functions have fixed semantic where NULL values are always skipped if (function.getKind() != SCALAR) { @@ -417,25 +422,30 @@ private boolean isMoreSpecificThan(Session session, ApplicableFunction left, App private static class ApplicableFunction { - private final FunctionMetadata function; + private final CatalogFunctionMetadata function; // Ideally this would be a real bound signature, but the resolver algorithm considers functions with illegal types (e.g., char(large_number)) // We could just not consider these applicable functions, but there are tests that depend on the specific error messages for these failures. private final Signature boundSignature; - private ApplicableFunction(FunctionMetadata function, Signature boundSignature) + private ApplicableFunction(CatalogFunctionMetadata function, Signature boundSignature) { this.function = function; this.boundSignature = boundSignature; } - public FunctionMetadata getFunction() + public CatalogFunctionMetadata getFunction() { return function; } + public FunctionMetadata getFunctionMetadata() + { + return function.getFunctionMetadata(); + } + public Signature getDeclaredSignature() { - return function.getSignature(); + return function.getFunctionMetadata().getSignature(); } public Signature getBoundSignature() @@ -447,9 +457,53 @@ public Signature getBoundSignature() public String toString() { return toStringHelper(this) - .add("declaredSignature", function.getSignature()) + .add("declaredSignature", function.getFunctionMetadata().getSignature()) .add("boundSignature", boundSignature) .toString(); } } + + static class CatalogFunctionMetadata + { + private final CatalogHandle catalogHandle; + private final FunctionMetadata functionMetadata; + + public CatalogFunctionMetadata(CatalogHandle catalogHandle, FunctionMetadata functionMetadata) + { + this.catalogHandle = requireNonNull(catalogHandle, "catalogHandle is null"); + this.functionMetadata = requireNonNull(functionMetadata, "functionMetadata is null"); + } + + public CatalogHandle getCatalogHandle() + { + return catalogHandle; + } + + public FunctionMetadata getFunctionMetadata() + { + return functionMetadata; + } + } + + static class CatalogFunctionBinding + { + private final CatalogHandle catalogHandle; + private final FunctionBinding functionBinding; + + private CatalogFunctionBinding(CatalogHandle catalogHandle, FunctionBinding functionBinding) + { + this.catalogHandle = requireNonNull(catalogHandle, "catalogHandle is null"); + this.functionBinding = requireNonNull(functionBinding, "functionBinding is null"); + } + + public CatalogHandle getCatalogHandle() + { + return catalogHandle; + } + + public FunctionBinding getFunctionBinding() + { + return functionBinding; + } + } } diff --git a/core/trino-main/src/main/java/io/trino/metadata/GlobalFunctionCatalog.java b/core/trino-main/src/main/java/io/trino/metadata/GlobalFunctionCatalog.java index f0949d7c3a11..ecdcad44a540 100644 --- a/core/trino-main/src/main/java/io/trino/metadata/GlobalFunctionCatalog.java +++ b/core/trino-main/src/main/java/io/trino/metadata/GlobalFunctionCatalog.java @@ -17,10 +17,20 @@ import com.google.common.collect.ImmutableListMultimap; import com.google.common.collect.ImmutableMap; import com.google.common.collect.Multimap; -import io.trino.operator.aggregation.AggregationMetadata; -import io.trino.operator.window.WindowFunctionSupplier; +import io.trino.spi.function.AggregationFunctionMetadata; +import io.trino.spi.function.AggregationImplementation; +import io.trino.spi.function.BoundSignature; +import io.trino.spi.function.FunctionDependencies; +import io.trino.spi.function.FunctionDependencyDeclaration; +import io.trino.spi.function.FunctionId; +import io.trino.spi.function.FunctionMetadata; +import io.trino.spi.function.FunctionProvider; import io.trino.spi.function.InvocationConvention; import io.trino.spi.function.OperatorType; +import io.trino.spi.function.ScalarFunctionImplementation; +import io.trino.spi.function.SchemaFunctionName; +import io.trino.spi.function.Signature; +import io.trino.spi.function.WindowFunctionSupplier; import io.trino.spi.type.TypeSignature; import javax.annotation.concurrent.ThreadSafe; @@ -34,8 +44,8 @@ import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; import static com.google.common.collect.ImmutableMap.toImmutableMap; -import static io.trino.metadata.Signature.isOperatorName; -import static io.trino.metadata.Signature.unmangleOperator; +import static io.trino.metadata.OperatorNameUtil.isOperatorName; +import static io.trino.metadata.OperatorNameUtil.unmangleOperator; import static io.trino.spi.function.FunctionKind.AGGREGATE; import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.BooleanType.BOOLEAN; @@ -44,9 +54,9 @@ @ThreadSafe public class GlobalFunctionCatalog + implements FunctionProvider { - public static final String GLOBAL_CATALOG = "system"; - public static final String GLOBAL_SCHEMA = "global"; + public static final String BUILTIN_SCHEMA = "builtin"; private volatile FunctionMap functions = new FunctionMap(); public final synchronized void addFunctions(FunctionBundle functionBundle) @@ -119,7 +129,7 @@ public List listFunctions() public Collection getFunctions(SchemaFunctionName name) { - if (!GLOBAL_SCHEMA.equals(name.getSchemaName())) { + if (!BUILTIN_SCHEMA.equals(name.getSchemaName())) { return ImmutableList.of(); } return functions.get(name.getFunctionName()); @@ -135,14 +145,16 @@ public AggregationFunctionMetadata getAggregationFunctionMetadata(FunctionId fun return functions.getFunctionBundle(functionId).getAggregationFunctionMetadata(functionId); } - public WindowFunctionSupplier getWindowFunctionImplementation(FunctionId functionId, BoundSignature boundSignature, FunctionDependencies functionDependencies) + @Override + public WindowFunctionSupplier getWindowFunctionSupplier(FunctionId functionId, BoundSignature boundSignature, FunctionDependencies functionDependencies) { - return functions.getFunctionBundle(functionId).getWindowFunctionImplementation(functionId, boundSignature, functionDependencies); + return functions.getFunctionBundle(functionId).getWindowFunctionSupplier(functionId, boundSignature, functionDependencies); } - public AggregationMetadata getAggregateFunctionImplementation(FunctionId functionId, BoundSignature boundSignature, FunctionDependencies functionDependencies) + @Override + public AggregationImplementation getAggregationImplementation(FunctionId functionId, BoundSignature boundSignature, FunctionDependencies functionDependencies) { - return functions.getFunctionBundle(functionId).getAggregateFunctionImplementation(functionId, boundSignature, functionDependencies); + return functions.getFunctionBundle(functionId).getAggregationImplementation(functionId, boundSignature, functionDependencies); } public FunctionDependencyDeclaration getFunctionDependencies(FunctionId functionId, BoundSignature boundSignature) @@ -150,13 +162,14 @@ public FunctionDependencyDeclaration getFunctionDependencies(FunctionId function return functions.getFunctionBundle(functionId).getFunctionDependencies(functionId, boundSignature); } - public FunctionInvoker getScalarFunctionInvoker( + @Override + public ScalarFunctionImplementation getScalarFunctionImplementation( FunctionId functionId, BoundSignature boundSignature, FunctionDependencies functionDependencies, InvocationConvention invocationConvention) { - return functions.getFunctionBundle(functionId).getScalarFunctionInvoker(functionId, boundSignature, functionDependencies, invocationConvention); + return functions.getFunctionBundle(functionId).getScalarFunctionImplementation(functionId, boundSignature, functionDependencies, invocationConvention); } private static class FunctionMap diff --git a/core/trino-main/src/main/java/io/trino/metadata/InternalFunctionBundle.java b/core/trino-main/src/main/java/io/trino/metadata/InternalFunctionBundle.java index e33638998f0d..39d3fc685486 100644 --- a/core/trino-main/src/main/java/io/trino/metadata/InternalFunctionBundle.java +++ b/core/trino-main/src/main/java/io/trino/metadata/InternalFunctionBundle.java @@ -17,18 +17,25 @@ import com.google.common.collect.ImmutableList; import com.google.common.util.concurrent.UncheckedExecutionException; import io.trino.collect.cache.NonEvictableCache; -import io.trino.operator.aggregation.AggregationMetadata; -import io.trino.operator.scalar.ScalarFunctionImplementation; +import io.trino.operator.scalar.SpecializedSqlScalarFunction; import io.trino.operator.scalar.annotations.ScalarFromAnnotationsParser; import io.trino.operator.window.SqlWindowFunction; import io.trino.operator.window.WindowAnnotationsParser; -import io.trino.operator.window.WindowFunctionSupplier; import io.trino.spi.TrinoException; import io.trino.spi.function.AggregationFunction; +import io.trino.spi.function.AggregationFunctionMetadata; +import io.trino.spi.function.AggregationImplementation; +import io.trino.spi.function.BoundSignature; +import io.trino.spi.function.FunctionDependencies; +import io.trino.spi.function.FunctionDependencyDeclaration; +import io.trino.spi.function.FunctionId; +import io.trino.spi.function.FunctionMetadata; import io.trino.spi.function.InvocationConvention; import io.trino.spi.function.ScalarFunction; +import io.trino.spi.function.ScalarFunctionImplementation; import io.trino.spi.function.ScalarOperator; import io.trino.spi.function.WindowFunction; +import io.trino.spi.function.WindowFunctionSupplier; import java.util.ArrayList; import java.util.Collection; @@ -51,8 +58,8 @@ public class InternalFunctionBundle implements FunctionBundle { // scalar function specialization may involve expensive code generation - private final NonEvictableCache specializedScalarCache; - private final NonEvictableCache specializedAggregationCache; + private final NonEvictableCache specializedScalarCache; + private final NonEvictableCache specializedAggregationCache; private final NonEvictableCache specializedWindowCache; private final Map functions; @@ -109,15 +116,15 @@ public FunctionDependencyDeclaration getFunctionDependencies(FunctionId function } @Override - public FunctionInvoker getScalarFunctionInvoker( + public ScalarFunctionImplementation getScalarFunctionImplementation( FunctionId functionId, BoundSignature boundSignature, FunctionDependencies functionDependencies, InvocationConvention invocationConvention) { - ScalarFunctionImplementation scalarFunctionImplementation; + SpecializedSqlScalarFunction specializedSqlScalarFunction; try { - scalarFunctionImplementation = uncheckedCacheGet( + specializedSqlScalarFunction = uncheckedCacheGet( specializedScalarCache, new FunctionKey(functionId, boundSignature), () -> specializeScalarFunction(functionId, boundSignature, functionDependencies)); @@ -126,17 +133,17 @@ public FunctionInvoker getScalarFunctionInvoker( throwIfInstanceOf(e.getCause(), TrinoException.class); throw new RuntimeException(e.getCause()); } - return scalarFunctionImplementation.getScalarFunctionInvoker(invocationConvention); + return specializedSqlScalarFunction.getScalarFunctionImplementation(invocationConvention); } - private ScalarFunctionImplementation specializeScalarFunction(FunctionId functionId, BoundSignature boundSignature, FunctionDependencies functionDependencies) + private SpecializedSqlScalarFunction specializeScalarFunction(FunctionId functionId, BoundSignature boundSignature, FunctionDependencies functionDependencies) { SqlScalarFunction function = (SqlScalarFunction) getSqlFunction(functionId); return function.specialize(boundSignature, functionDependencies); } @Override - public AggregationMetadata getAggregateFunctionImplementation(FunctionId functionId, BoundSignature boundSignature, FunctionDependencies functionDependencies) + public AggregationImplementation getAggregationImplementation(FunctionId functionId, BoundSignature boundSignature, FunctionDependencies functionDependencies) { try { return uncheckedCacheGet(specializedAggregationCache, new FunctionKey(functionId, boundSignature), () -> specializedAggregation(functionId, boundSignature, functionDependencies)); @@ -147,14 +154,14 @@ public AggregationMetadata getAggregateFunctionImplementation(FunctionId functio } } - private AggregationMetadata specializedAggregation(FunctionId functionId, BoundSignature boundSignature, FunctionDependencies functionDependencies) + private AggregationImplementation specializedAggregation(FunctionId functionId, BoundSignature boundSignature, FunctionDependencies functionDependencies) { SqlAggregationFunction aggregationFunction = (SqlAggregationFunction) functions.get(functionId); return aggregationFunction.specialize(boundSignature, functionDependencies); } @Override - public WindowFunctionSupplier getWindowFunctionImplementation(FunctionId functionId, BoundSignature boundSignature, FunctionDependencies functionDependencies) + public WindowFunctionSupplier getWindowFunctionSupplier(FunctionId functionId, BoundSignature boundSignature, FunctionDependencies functionDependencies) { try { return uncheckedCacheGet(specializedWindowCache, new FunctionKey(functionId, boundSignature), () -> specializeWindow(functionId, boundSignature, functionDependencies)); diff --git a/core/trino-main/src/main/java/io/trino/metadata/FunctionDependencies.java b/core/trino-main/src/main/java/io/trino/metadata/InternalFunctionDependencies.java similarity index 83% rename from core/trino-main/src/main/java/io/trino/metadata/FunctionDependencies.java rename to core/trino-main/src/main/java/io/trino/metadata/InternalFunctionDependencies.java index 24f330914884..70bcd1425905 100644 --- a/core/trino-main/src/main/java/io/trino/metadata/FunctionDependencies.java +++ b/core/trino-main/src/main/java/io/trino/metadata/InternalFunctionDependencies.java @@ -15,11 +15,15 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import io.trino.spi.function.FunctionDependencies; +import io.trino.spi.function.FunctionNullability; import io.trino.spi.function.InvocationConvention; import io.trino.spi.function.OperatorType; +import io.trino.spi.function.QualifiedFunctionName; +import io.trino.spi.function.ScalarFunctionImplementation; +import io.trino.spi.function.Signature; import io.trino.spi.type.Type; import io.trino.spi.type.TypeSignature; -import io.trino.sql.tree.QualifiedName; import java.util.Collection; import java.util.List; @@ -30,23 +34,24 @@ import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.ImmutableMap.toImmutableMap; -import static io.trino.metadata.Signature.isOperatorName; -import static io.trino.metadata.Signature.unmangleOperator; +import static io.trino.metadata.OperatorNameUtil.isOperatorName; +import static io.trino.metadata.OperatorNameUtil.unmangleOperator; import static io.trino.spi.function.OperatorType.CAST; import static java.lang.String.format; import static java.util.Objects.requireNonNull; import static java.util.function.Function.identity; -public class FunctionDependencies +public class InternalFunctionDependencies + implements FunctionDependencies { - private final BiFunction specialization; + private final BiFunction specialization; private final Map types; private final Map functions; private final Map operators; private final Map casts; - public FunctionDependencies( - BiFunction specialization, + public InternalFunctionDependencies( + BiFunction specialization, Map typeDependencies, Collection functionDependencies) { @@ -60,13 +65,14 @@ public FunctionDependencies( .filter(function -> !isOperatorName(function.getSignature().getName())) .collect(toImmutableMap(FunctionKey::new, identity())); this.operators = functionDependencies.stream() - .filter(FunctionDependencies::isOperator) + .filter(InternalFunctionDependencies::isOperator) .collect(toImmutableMap(OperatorKey::new, identity())); this.casts = functionDependencies.stream() - .filter(FunctionDependencies::isCast) + .filter(InternalFunctionDependencies::isCast) .collect(toImmutableMap(CastKey::new, identity())); } + @Override public Type getType(TypeSignature typeSignature) { // CHAR type does not properly roundtrip, so load directly from metadata and then verify type was declared correctly @@ -77,7 +83,8 @@ public Type getType(TypeSignature typeSignature) return type; } - public FunctionNullability getFunctionNullability(QualifiedName name, List parameterTypes) + @Override + public FunctionNullability getFunctionNullability(QualifiedFunctionName name, List parameterTypes) { FunctionKey functionKey = new FunctionKey(name, toTypeSignatures(parameterTypes)); ResolvedFunction resolvedFunction = functions.get(functionKey); @@ -87,6 +94,7 @@ public FunctionNullability getFunctionNullability(QualifiedName name, List return resolvedFunction.getFunctionNullability(); } + @Override public FunctionNullability getOperatorNullability(OperatorType operatorType, List parameterTypes) { OperatorKey operatorKey = new OperatorKey(operatorType, toTypeSignatures(parameterTypes)); @@ -97,6 +105,7 @@ public FunctionNullability getOperatorNullability(OperatorType operatorType, Lis return resolvedFunction.getFunctionNullability(); } + @Override public FunctionNullability getCastNullability(Type fromType, Type toType) { CastKey castKey = new CastKey(fromType.getTypeSignature(), toType.getTypeSignature()); @@ -107,7 +116,8 @@ public FunctionNullability getCastNullability(Type fromType, Type toType) return resolvedFunction.getFunctionNullability(); } - public FunctionInvoker getFunctionInvoker(QualifiedName name, List parameterTypes, InvocationConvention invocationConvention) + @Override + public ScalarFunctionImplementation getScalarFunctionImplementation(QualifiedFunctionName name, List parameterTypes, InvocationConvention invocationConvention) { FunctionKey functionKey = new FunctionKey(name, toTypeSignatures(parameterTypes)); ResolvedFunction resolvedFunction = functions.get(functionKey); @@ -117,7 +127,8 @@ public FunctionInvoker getFunctionInvoker(QualifiedName name, List paramet return specialization.apply(resolvedFunction, invocationConvention); } - public FunctionInvoker getFunctionSignatureInvoker(QualifiedName name, List parameterTypes, InvocationConvention invocationConvention) + @Override + public ScalarFunctionImplementation getScalarFunctionImplementationSignature(QualifiedFunctionName name, List parameterTypes, InvocationConvention invocationConvention) { FunctionKey functionKey = new FunctionKey(name, parameterTypes); ResolvedFunction resolvedFunction = functions.get(functionKey); @@ -127,7 +138,8 @@ public FunctionInvoker getFunctionSignatureInvoker(QualifiedName name, List parameterTypes, InvocationConvention invocationConvention) + @Override + public ScalarFunctionImplementation getOperatorImplementation(OperatorType operatorType, List parameterTypes, InvocationConvention invocationConvention) { OperatorKey operatorKey = new OperatorKey(operatorType, toTypeSignatures(parameterTypes)); ResolvedFunction resolvedFunction = operators.get(operatorKey); @@ -137,7 +149,8 @@ public FunctionInvoker getOperatorInvoker(OperatorType operatorType, List return specialization.apply(resolvedFunction, invocationConvention); } - public FunctionInvoker getOperatorSignatureInvoker(OperatorType operatorType, List parameterTypes, InvocationConvention invocationConvention) + @Override + public ScalarFunctionImplementation getOperatorImplementationSignature(OperatorType operatorType, List parameterTypes, InvocationConvention invocationConvention) { OperatorKey operatorKey = new OperatorKey(operatorType, parameterTypes); ResolvedFunction resolvedFunction = operators.get(operatorKey); @@ -147,7 +160,8 @@ public FunctionInvoker getOperatorSignatureInvoker(OperatorType operatorType, Li return specialization.apply(resolvedFunction, invocationConvention); } - public FunctionInvoker getCastInvoker(Type fromType, Type toType, InvocationConvention invocationConvention) + @Override + public ScalarFunctionImplementation getCastImplementation(Type fromType, Type toType, InvocationConvention invocationConvention) { CastKey castKey = new CastKey(fromType.getTypeSignature(), toType.getTypeSignature()); ResolvedFunction resolvedFunction = casts.get(castKey); @@ -157,7 +171,8 @@ public FunctionInvoker getCastInvoker(Type fromType, Type toType, InvocationConv return specialization.apply(resolvedFunction, invocationConvention); } - public FunctionInvoker getCastSignatureInvoker(TypeSignature fromType, TypeSignature toType, InvocationConvention invocationConvention) + @Override + public ScalarFunctionImplementation getCastImplementationSignature(TypeSignature fromType, TypeSignature toType, InvocationConvention invocationConvention) { CastKey castKey = new CastKey(fromType, toType); ResolvedFunction resolvedFunction = casts.get(castKey); @@ -188,19 +203,19 @@ private static boolean isCast(ResolvedFunction function) public static final class FunctionKey { - private final QualifiedName name; + private final QualifiedFunctionName name; private final List argumentTypes; private FunctionKey(ResolvedFunction resolvedFunction) { Signature signature = resolvedFunction.getSignature().toSignature(); - name = QualifiedName.of(signature.getName()); + name = QualifiedFunctionName.of(signature.getName()); argumentTypes = resolvedFunction.getSignature().getArgumentTypes().stream() .map(Type::getTypeSignature) .collect(toImmutableList()); } - private FunctionKey(QualifiedName name, List argumentTypes) + private FunctionKey(QualifiedFunctionName name, List argumentTypes) { this.name = requireNonNull(name, "name is null"); this.argumentTypes = ImmutableList.copyOf(requireNonNull(argumentTypes, "argumentTypes is null")); diff --git a/core/trino-main/src/main/java/io/trino/metadata/LiteralFunction.java b/core/trino-main/src/main/java/io/trino/metadata/LiteralFunction.java index 157a8106b579..99d7bea38572 100644 --- a/core/trino-main/src/main/java/io/trino/metadata/LiteralFunction.java +++ b/core/trino-main/src/main/java/io/trino/metadata/LiteralFunction.java @@ -16,10 +16,13 @@ import com.google.common.collect.ImmutableList; import com.google.common.primitives.Primitives; import io.airlift.slice.Slice; -import io.trino.operator.scalar.ChoicesScalarFunctionImplementation; -import io.trino.operator.scalar.ScalarFunctionImplementation; +import io.trino.operator.scalar.ChoicesSpecializedSqlScalarFunction; +import io.trino.operator.scalar.SpecializedSqlScalarFunction; import io.trino.spi.block.Block; import io.trino.spi.block.BlockEncodingSerde; +import io.trino.spi.function.BoundSignature; +import io.trino.spi.function.FunctionMetadata; +import io.trino.spi.function.Signature; import io.trino.spi.type.Type; import io.trino.spi.type.TypeSignature; import io.trino.spi.type.VarcharType; @@ -62,7 +65,7 @@ public LiteralFunction(BlockEncodingSerde blockEncodingSerde) } @Override - public ScalarFunctionImplementation specialize(BoundSignature boundSignature) + public SpecializedSqlScalarFunction specialize(BoundSignature boundSignature) { Type parameterType = boundSignature.getArgumentTypes().get(0); Type type = boundSignature.getReturnType(); @@ -88,7 +91,7 @@ else if (type.getJavaType() != Slice.class) { parameterType.getJavaType(), type.getJavaType()); - return new ChoicesScalarFunctionImplementation( + return new ChoicesSpecializedSqlScalarFunction( boundSignature, FAIL_ON_NULL, ImmutableList.of(NEVER_NULL), diff --git a/core/trino-main/src/main/java/io/trino/metadata/Metadata.java b/core/trino-main/src/main/java/io/trino/metadata/Metadata.java index 31a3ab58dcca..5efd5b94e953 100644 --- a/core/trino-main/src/main/java/io/trino/metadata/Metadata.java +++ b/core/trino-main/src/main/java/io/trino/metadata/Metadata.java @@ -47,6 +47,8 @@ import io.trino.spi.connector.TableScanRedirectApplicationResult; import io.trino.spi.connector.TopNApplicationResult; import io.trino.spi.expression.ConnectorExpression; +import io.trino.spi.function.AggregationFunctionMetadata; +import io.trino.spi.function.FunctionMetadata; import io.trino.spi.function.OperatorType; import io.trino.spi.predicate.TupleDomain; import io.trino.spi.security.GrantInfo; diff --git a/core/trino-main/src/main/java/io/trino/metadata/MetadataManager.java b/core/trino-main/src/main/java/io/trino/metadata/MetadataManager.java index 6a6877d5d311..d4b1fb4f8e88 100644 --- a/core/trino-main/src/main/java/io/trino/metadata/MetadataManager.java +++ b/core/trino-main/src/main/java/io/trino/metadata/MetadataManager.java @@ -27,7 +27,9 @@ import io.trino.Session; import io.trino.collect.cache.NonEvictableCache; import io.trino.connector.CatalogHandle; -import io.trino.metadata.AggregationFunctionMetadata.AggregationFunctionMetadataBuilder; +import io.trino.connector.system.GlobalSystemConnector; +import io.trino.metadata.FunctionResolver.CatalogFunctionBinding; +import io.trino.metadata.FunctionResolver.CatalogFunctionMetadata; import io.trino.metadata.ResolvedFunction.ResolvedFunctionDecoder; import io.trino.spi.QueryId; import io.trino.spi.TrinoException; @@ -79,7 +81,15 @@ import io.trino.spi.connector.TopNApplicationResult; import io.trino.spi.expression.ConnectorExpression; import io.trino.spi.expression.Variable; +import io.trino.spi.function.AggregationFunctionMetadata; +import io.trino.spi.function.AggregationFunctionMetadata.AggregationFunctionMetadataBuilder; +import io.trino.spi.function.BoundSignature; +import io.trino.spi.function.FunctionDependencyDeclaration; +import io.trino.spi.function.FunctionId; +import io.trino.spi.function.FunctionMetadata; import io.trino.spi.function.OperatorType; +import io.trino.spi.function.QualifiedFunctionName; +import io.trino.spi.function.Signature; import io.trino.spi.predicate.TupleDomain; import io.trino.spi.security.GrantInfo; import io.trino.spi.security.Identity; @@ -94,9 +104,11 @@ import io.trino.spi.type.TypeNotFoundException; import io.trino.spi.type.TypeOperators; import io.trino.spi.type.TypeSignature; +import io.trino.sql.SqlPathElement; import io.trino.sql.analyzer.TypeSignatureProvider; import io.trino.sql.planner.ConnectorExpressions; import io.trino.sql.planner.PartitioningHandle; +import io.trino.sql.tree.Identifier; import io.trino.sql.tree.QualifiedName; import io.trino.transaction.TransactionManager; import io.trino.type.BlockTypeOperators; @@ -138,11 +150,10 @@ import static io.trino.collect.cache.SafeCaches.buildNonEvictableCache; import static io.trino.metadata.CatalogMetadata.SecurityManagement.CONNECTOR; import static io.trino.metadata.CatalogMetadata.SecurityManagement.SYSTEM; -import static io.trino.metadata.GlobalFunctionCatalog.GLOBAL_CATALOG; +import static io.trino.metadata.OperatorNameUtil.mangleOperatorName; import static io.trino.metadata.QualifiedObjectName.convertFromSchemaTableName; import static io.trino.metadata.RedirectionAwareTableHandle.noRedirection; import static io.trino.metadata.RedirectionAwareTableHandle.withRedirectionTo; -import static io.trino.metadata.Signature.mangleOperatorName; import static io.trino.metadata.SignatureBinder.applyBoundVariables; import static io.trino.spi.StandardErrorCode.FUNCTION_IMPLEMENTATION_MISSING; import static io.trino.spi.StandardErrorCode.FUNCTION_NOT_FOUND; @@ -2044,7 +2055,17 @@ public List listTablePrivileges(Session session, QualifiedTablePrefix @Override public Collection listFunctions(Session session) { - return functions.listFunctions(); + ImmutableList.Builder functions = ImmutableList.builder(); + functions.addAll(this.functions.listFunctions()); + for (SqlPathElement sqlPathElement : session.getPath().getParsedPath()) { + String catalog = sqlPathElement.getCatalog().map(Identifier::getValue).or(session::getCatalog) + .orElseThrow(() -> new IllegalArgumentException("Session default catalog must be set to resolve a partial function name: " + sqlPathElement)); + getOptionalCatalogMetadata(session, catalog).ifPresent(metadata -> { + ConnectorSession connectorSession = session.toConnectorSession(metadata.getCatalogHandle()); + functions.addAll(metadata.getMetadata(session).listFunctions(connectorSession, sqlPathElement.getSchema().getValue().toLowerCase(ENGLISH))); + }); + } + return functions.build(); } @Override @@ -2086,7 +2107,31 @@ public ResolvedFunction resolveOperator(Session session, OperatorType operatorTy private ResolvedFunction resolvedFunctionInternal(Session session, QualifiedName name, List parameterTypes) { return functionDecoder.fromQualifiedName(name) - .orElseGet(() -> resolve(session, functionResolver.resolveFunction(session, name, parameterTypes, this::getFunctions))); + .orElseGet(() -> resolvedFunctionInternal(session, toQualifiedFunctionName(name), parameterTypes)); + } + + private ResolvedFunction resolvedFunctionInternal(Session session, QualifiedFunctionName name, List parameterTypes) + { + CatalogFunctionBinding catalogFunctionBinding = functionResolver.resolveFunction( + session, + name, + parameterTypes, + catalogSchemaFunctionName -> getFunctions(session, catalogSchemaFunctionName)); + return resolve(session, catalogFunctionBinding); + } + + // this is only public for TableFunctionRegistry, which is effectively part of MetadataManager but for some reason is a separate class + public static QualifiedFunctionName toQualifiedFunctionName(QualifiedName qualifiedName) + { + List parts = qualifiedName.getParts(); + checkArgument(parts.size() <= 3, "Function name can only have 3 parts: " + qualifiedName); + if (parts.size() == 3) { + return QualifiedFunctionName.of(parts.get(0), parts.get(1), parts.get(2)); + } + if (parts.size() == 2) { + return QualifiedFunctionName.of(parts.get(0), parts.get(1)); + } + return QualifiedFunctionName.of(parts.get(0)); } @Override @@ -2097,15 +2142,15 @@ public ResolvedFunction getCoercion(Session session, OperatorType operatorType, // todo we should not be caching functions across session return uncheckedCacheGet(coercionCache, new CoercionCacheKey(operatorType, fromType, toType), () -> { String name = mangleOperatorName(operatorType); - FunctionBinding functionBinding = functionResolver.resolveCoercion( + CatalogFunctionBinding functionBinding = functionResolver.resolveCoercion( session, - QualifiedName.of(name), + QualifiedFunctionName.of(name), Signature.builder() .name(name) .returnType(toType) .argumentType(fromType) .build(), - this::getFunctions); + catalogSchemaFunctionName -> getFunctions(session, catalogSchemaFunctionName)); return resolve(session, functionBinding); }); } @@ -2124,34 +2169,52 @@ public ResolvedFunction getCoercion(Session session, OperatorType operatorType, @Override public ResolvedFunction getCoercion(Session session, QualifiedName name, Type fromType, Type toType) { - FunctionBinding functionBinding = functionResolver.resolveCoercion( + CatalogFunctionBinding catalogFunctionBinding = functionResolver.resolveCoercion( session, - name, + toQualifiedFunctionName(name), Signature.builder() .name(name.getSuffix()) .returnType(toType) .argumentType(fromType) .build(), - this::getFunctions); - return resolve(session, functionBinding); + catalogSchemaFunctionName -> getFunctions(session, catalogSchemaFunctionName)); + return resolve(session, catalogFunctionBinding); + } + + private ResolvedFunction resolve(Session session, CatalogFunctionBinding functionBinding) + { + FunctionDependencyDeclaration dependencies = getDependencies( + session, + functionBinding.getCatalogHandle(), + functionBinding.getFunctionBinding().getFunctionId(), + functionBinding.getFunctionBinding().getBoundSignature()); + FunctionMetadata functionMetadata = getFunctionMetadata( + session, + functionBinding.getCatalogHandle(), + functionBinding.getFunctionBinding().getFunctionId(), + functionBinding.getFunctionBinding().getBoundSignature()); + return resolve(session, functionBinding.getCatalogHandle(), functionBinding.getFunctionBinding(), functionMetadata, dependencies); } - private ResolvedFunction resolve(Session session, FunctionBinding functionBinding) + private FunctionDependencyDeclaration getDependencies(Session session, CatalogHandle catalogHandle, FunctionId functionId, BoundSignature boundSignature) { - FunctionDependencyDeclaration declaration = functions.getFunctionDependencies(functionBinding.getFunctionId(), functionBinding.getBoundSignature()); - FunctionMetadata functionMetadata = getFunctionMetadata(functionBinding.getFunctionId(), functionBinding.getBoundSignature()); - return resolve(session, functionBinding, functionMetadata, declaration); + if (catalogHandle.equals(GlobalSystemConnector.CATALOG_HANDLE)) { + return functions.getFunctionDependencies(functionId, boundSignature); + } + ConnectorSession connectorSession = session.toConnectorSession(catalogHandle); + return getMetadata(session, catalogHandle) + .getFunctionDependencies(connectorSession, functionId, boundSignature); } @VisibleForTesting - public ResolvedFunction resolve(Session session, FunctionBinding functionBinding, FunctionMetadata functionMetadata, FunctionDependencyDeclaration declaration) + public ResolvedFunction resolve(Session session, CatalogHandle catalogHandle, FunctionBinding functionBinding, FunctionMetadata functionMetadata, FunctionDependencyDeclaration dependencies) { - Map dependentTypes = declaration.getTypeDependencies().stream() + Map dependentTypes = dependencies.getTypeDependencies().stream() .map(typeSignature -> applyBoundVariables(typeSignature, functionBinding)) .collect(toImmutableMap(Function.identity(), typeManager::getType, (left, right) -> left)); ImmutableSet.Builder functions = ImmutableSet.builder(); - declaration.getFunctionDependencies().stream() + dependencies.getFunctionDependencies().stream() .map(functionDependency -> { try { List argumentTypes = applyBoundVariables(functionDependency.getArgumentTypes(), functionBinding); @@ -2167,7 +2230,7 @@ public ResolvedFunction resolve(Session session, FunctionBinding functionBinding .filter(Objects::nonNull) .forEach(functions::add); - declaration.getOperatorDependencies().stream() + dependencies.getOperatorDependencies().stream() .map(operatorDependency -> { try { List argumentTypes = applyBoundVariables(operatorDependency.getArgumentTypes(), functionBinding); @@ -2183,7 +2246,7 @@ public ResolvedFunction resolve(Session session, FunctionBinding functionBinding .filter(Objects::nonNull) .forEach(functions::add); - declaration.getCastDependencies().stream() + dependencies.getCastDependencies().stream() .map(castDependency -> { try { Type fromType = typeManager.getType(applyBoundVariables(castDependency.getFromType(), functionBinding)); @@ -2202,6 +2265,7 @@ public ResolvedFunction resolve(Session session, FunctionBinding functionBinding return new ResolvedFunction( functionBinding.getBoundSignature(), + catalogHandle, functionBinding.getFunctionId(), functionMetadata.getKind(), functionMetadata.isDeterministic(), @@ -2213,26 +2277,42 @@ public ResolvedFunction resolve(Session session, FunctionBinding functionBinding @Override public boolean isAggregationFunction(Session session, QualifiedName name) { - return functionResolver.isAggregationFunction(session, name, this::getFunctions); + return functionResolver.isAggregationFunction(session, toQualifiedFunctionName(name), catalogSchemaFunctionName -> getFunctions(session, catalogSchemaFunctionName)); } - private Collection getFunctions(CatalogSchemaFunctionName name) + private Collection getFunctions(Session session, CatalogSchemaFunctionName name) { - if (name.getCatalogName().equals(GLOBAL_CATALOG)) { - return functions.getFunctions(name.getSchemaFunctionName()); + if (name.getCatalogName().equals(GlobalSystemConnector.NAME)) { + return functions.getFunctions(name.getSchemaFunctionName()).stream() + .map(function -> new CatalogFunctionMetadata(GlobalSystemConnector.CATALOG_HANDLE, function)) + .collect(toImmutableList()); } - return ImmutableList.of(); + + return getOptionalCatalogMetadata(session, name.getCatalogName()) + .map(metadata -> metadata.getMetadata(session) + .getFunctions(session.toConnectorSession(metadata.getCatalogHandle()), name.getSchemaFunctionName()).stream() + .map(function -> new CatalogFunctionMetadata(metadata.getCatalogHandle(), function)) + .collect(toImmutableList())) + .orElse(ImmutableList.of()); } @Override public FunctionMetadata getFunctionMetadata(Session session, ResolvedFunction resolvedFunction) { - return getFunctionMetadata(resolvedFunction.getFunctionId(), resolvedFunction.getSignature()); + return getFunctionMetadata(session, resolvedFunction.getCatalogHandle(), resolvedFunction.getFunctionId(), resolvedFunction.getSignature()); } - private FunctionMetadata getFunctionMetadata(FunctionId functionId, BoundSignature signature) + private FunctionMetadata getFunctionMetadata(Session session, CatalogHandle catalogHandle, FunctionId functionId, BoundSignature signature) { - FunctionMetadata functionMetadata = functions.getFunctionMetadata(functionId); + FunctionMetadata functionMetadata; + if (catalogHandle.equals(GlobalSystemConnector.CATALOG_HANDLE)) { + functionMetadata = functions.getFunctionMetadata(functionId); + } + else { + ConnectorSession connectorSession = session.toConnectorSession(catalogHandle); + functionMetadata = getMetadata(session, catalogHandle) + .getFunctionMetadata(connectorSession, functionId); + } FunctionMetadata.Builder newMetadata = FunctionMetadata.builder(functionMetadata.getKind()) .functionId(functionMetadata.getFunctionId()) @@ -2277,7 +2357,18 @@ private FunctionMetadata getFunctionMetadata(FunctionId functionId, BoundSignatu @Override public AggregationFunctionMetadata getAggregationFunctionMetadata(Session session, ResolvedFunction resolvedFunction) { - AggregationFunctionMetadata aggregationFunctionMetadata = functions.getAggregationFunctionMetadata(resolvedFunction.getFunctionId()); + Signature functionSignature; + AggregationFunctionMetadata aggregationFunctionMetadata; + if (resolvedFunction.getCatalogHandle().equals(GlobalSystemConnector.CATALOG_HANDLE)) { + functionSignature = functions.getFunctionMetadata(resolvedFunction.getFunctionId()).getSignature(); + aggregationFunctionMetadata = functions.getAggregationFunctionMetadata(resolvedFunction.getFunctionId()); + } + else { + ConnectorSession connectorSession = session.toConnectorSession(resolvedFunction.getCatalogHandle()); + ConnectorMetadata metadata = getMetadata(session, resolvedFunction.getCatalogHandle()); + functionSignature = metadata.getFunctionMetadata(connectorSession, resolvedFunction.getFunctionId()).getSignature(); + aggregationFunctionMetadata = metadata.getAggregationFunctionMetadata(connectorSession, resolvedFunction.getFunctionId()); + } AggregationFunctionMetadataBuilder builder = AggregationFunctionMetadata.builder(); if (aggregationFunctionMetadata.isOrderSensitive()) { @@ -2285,7 +2376,7 @@ public AggregationFunctionMetadata getAggregationFunctionMetadata(Session sessio } if (!aggregationFunctionMetadata.getIntermediateTypes().isEmpty()) { - FunctionBinding functionBinding = toFunctionBinding(resolvedFunction); + FunctionBinding functionBinding = toFunctionBinding(resolvedFunction.getFunctionId(), resolvedFunction.getSignature(), functionSignature); aggregationFunctionMetadata.getIntermediateTypes().stream() .map(typeSignature -> applyBoundVariables(typeSignature, functionBinding)) .forEach(builder::intermediateType); @@ -2294,12 +2385,6 @@ public AggregationFunctionMetadata getAggregationFunctionMetadata(Session sessio return builder.build(); } - private FunctionBinding toFunctionBinding(ResolvedFunction resolvedFunction) - { - Signature functionSignature = functions.getFunctionMetadata(resolvedFunction.getFunctionId()).getSignature(); - return toFunctionBinding(resolvedFunction.getFunctionId(), resolvedFunction.getSignature(), functionSignature); - } - @VisibleForTesting public static FunctionBinding toFunctionBinding(FunctionId functionId, BoundSignature boundSignature, Signature functionSignature) { diff --git a/core/trino-main/src/main/java/io/trino/metadata/OperatorNameUtil.java b/core/trino-main/src/main/java/io/trino/metadata/OperatorNameUtil.java new file mode 100644 index 000000000000..443083373ef5 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/metadata/OperatorNameUtil.java @@ -0,0 +1,44 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.metadata; + +import com.clearspring.analytics.util.Preconditions; +import com.google.common.annotations.VisibleForTesting; +import io.trino.spi.function.OperatorType; + +import java.util.Locale; + +public final class OperatorNameUtil +{ + private static final String OPERATOR_PREFIX = "$operator$"; + + private OperatorNameUtil() {} + + public static boolean isOperatorName(String mangledName) + { + return mangledName.startsWith(OPERATOR_PREFIX); + } + + public static String mangleOperatorName(OperatorType operatorType) + { + return OPERATOR_PREFIX + operatorType.name(); + } + + @VisibleForTesting + public static OperatorType unmangleOperator(String mangledName) + { + Preconditions.checkArgument(mangledName.startsWith(OPERATOR_PREFIX), "not a mangled operator name: %s", mangledName); + return OperatorType.valueOf(mangledName.substring(OPERATOR_PREFIX.length()).toUpperCase(Locale.ENGLISH)); + } +} diff --git a/core/trino-main/src/main/java/io/trino/metadata/PolymorphicScalarFunction.java b/core/trino-main/src/main/java/io/trino/metadata/PolymorphicScalarFunction.java index e18e9d321210..a4640061fcdf 100644 --- a/core/trino-main/src/main/java/io/trino/metadata/PolymorphicScalarFunction.java +++ b/core/trino-main/src/main/java/io/trino/metadata/PolymorphicScalarFunction.java @@ -18,9 +18,11 @@ import io.trino.metadata.PolymorphicScalarFunctionBuilder.MethodAndNativeContainerTypes; import io.trino.metadata.PolymorphicScalarFunctionBuilder.MethodsGroup; import io.trino.metadata.PolymorphicScalarFunctionBuilder.SpecializeContext; -import io.trino.operator.scalar.ChoicesScalarFunctionImplementation; -import io.trino.operator.scalar.ChoicesScalarFunctionImplementation.ScalarImplementationChoice; -import io.trino.operator.scalar.ScalarFunctionImplementation; +import io.trino.operator.scalar.ChoicesSpecializedSqlScalarFunction; +import io.trino.operator.scalar.ChoicesSpecializedSqlScalarFunction.ScalarImplementationChoice; +import io.trino.operator.scalar.SpecializedSqlScalarFunction; +import io.trino.spi.function.BoundSignature; +import io.trino.spi.function.FunctionMetadata; import io.trino.spi.function.InvocationConvention.InvocationArgumentConvention; import io.trino.spi.function.InvocationConvention.InvocationReturnConvention; import io.trino.spi.type.Type; @@ -48,7 +50,7 @@ class PolymorphicScalarFunction } @Override - protected ScalarFunctionImplementation specialize(BoundSignature boundSignature) + protected SpecializedSqlScalarFunction specialize(BoundSignature boundSignature) { ImmutableList.Builder implementationChoices = ImmutableList.builder(); @@ -58,7 +60,7 @@ protected ScalarFunctionImplementation specialize(BoundSignature boundSignature) implementationChoices.add(getScalarFunctionImplementationChoice(functionBinding, choice)); } - return new ChoicesScalarFunctionImplementation(boundSignature, implementationChoices.build()); + return new ChoicesSpecializedSqlScalarFunction(boundSignature, implementationChoices.build()); } private ScalarImplementationChoice getScalarFunctionImplementationChoice( diff --git a/core/trino-main/src/main/java/io/trino/metadata/PolymorphicScalarFunctionBuilder.java b/core/trino-main/src/main/java/io/trino/metadata/PolymorphicScalarFunctionBuilder.java index c87a9faab641..c71800d8deab 100644 --- a/core/trino-main/src/main/java/io/trino/metadata/PolymorphicScalarFunctionBuilder.java +++ b/core/trino-main/src/main/java/io/trino/metadata/PolymorphicScalarFunctionBuilder.java @@ -16,9 +16,11 @@ import com.google.common.collect.ImmutableList; import com.google.common.primitives.Booleans; import io.trino.metadata.PolymorphicScalarFunction.PolymorphicScalarFunctionChoice; +import io.trino.spi.function.FunctionMetadata; import io.trino.spi.function.InvocationConvention.InvocationArgumentConvention; import io.trino.spi.function.InvocationConvention.InvocationReturnConvention; import io.trino.spi.function.OperatorType; +import io.trino.spi.function.Signature; import io.trino.spi.type.Type; import java.lang.reflect.Method; @@ -32,7 +34,7 @@ import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; import static com.google.common.collect.ImmutableList.toImmutableList; -import static io.trino.metadata.Signature.mangleOperatorName; +import static io.trino.metadata.OperatorNameUtil.mangleOperatorName; import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.BLOCK_POSITION; import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.NEVER_NULL; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL; diff --git a/core/trino-main/src/main/java/io/trino/metadata/ResolvedFunction.java b/core/trino-main/src/main/java/io/trino/metadata/ResolvedFunction.java index a3dfd45f67f6..6f397cd2cacd 100644 --- a/core/trino-main/src/main/java/io/trino/metadata/ResolvedFunction.java +++ b/core/trino-main/src/main/java/io/trino/metadata/ResolvedFunction.java @@ -26,7 +26,11 @@ import io.airlift.json.JsonCodecFactory; import io.airlift.json.ObjectMapperProvider; import io.trino.collect.cache.NonEvictableLoadingCache; +import io.trino.connector.CatalogHandle; +import io.trino.spi.function.BoundSignature; +import io.trino.spi.function.FunctionId; import io.trino.spi.function.FunctionKind; +import io.trino.spi.function.FunctionNullability; import io.trino.spi.type.Type; import io.trino.spi.type.TypeId; import io.trino.spi.type.TypeSignature; @@ -57,6 +61,7 @@ public class ResolvedFunction { private static final String PREFIX = "@"; private final BoundSignature signature; + private final CatalogHandle catalogHandle; private final FunctionId functionId; private final FunctionKind functionKind; private final boolean deterministic; @@ -67,6 +72,7 @@ public class ResolvedFunction @JsonCreator public ResolvedFunction( @JsonProperty("signature") BoundSignature signature, + @JsonProperty("catalogHandle") CatalogHandle catalogHandle, @JsonProperty("id") FunctionId functionId, @JsonProperty("functionKind") FunctionKind functionKind, @JsonProperty("deterministic") boolean deterministic, @@ -75,6 +81,7 @@ public ResolvedFunction( @JsonProperty("functionDependencies") Set functionDependencies) { this.signature = requireNonNull(signature, "signature is null"); + this.catalogHandle = requireNonNull(catalogHandle, "catalogHandle is null"); this.functionId = requireNonNull(functionId, "functionId is null"); this.functionKind = requireNonNull(functionKind, "functionKind is null"); this.deterministic = deterministic; @@ -90,6 +97,12 @@ public BoundSignature getSignature() return signature; } + @JsonProperty + public CatalogHandle getCatalogHandle() + { + return catalogHandle; + } + @JsonProperty("id") public FunctionId getFunctionId() { @@ -158,6 +171,7 @@ public boolean equals(Object o) } ResolvedFunction that = (ResolvedFunction) o; return Objects.equals(signature, that.signature) && + Objects.equals(catalogHandle, that.catalogHandle) && Objects.equals(functionId, that.functionId) && functionKind == that.functionKind && deterministic == that.deterministic && @@ -169,7 +183,7 @@ public boolean equals(Object o) @Override public int hashCode() { - return Objects.hash(signature, functionId, functionKind, deterministic, functionNullability, typeDependencies, functionDependencies); + return Objects.hash(signature, catalogHandle, functionId, functionKind, deterministic, functionNullability, typeDependencies, functionDependencies); } @Override diff --git a/core/trino-main/src/main/java/io/trino/metadata/SignatureBinder.java b/core/trino-main/src/main/java/io/trino/metadata/SignatureBinder.java index c5ba362b0d06..748289c8572c 100644 --- a/core/trino-main/src/main/java/io/trino/metadata/SignatureBinder.java +++ b/core/trino-main/src/main/java/io/trino/metadata/SignatureBinder.java @@ -19,6 +19,11 @@ import com.google.common.collect.ImmutableSet; import io.trino.Session; import io.trino.spi.TrinoException; +import io.trino.spi.function.BoundSignature; +import io.trino.spi.function.FunctionId; +import io.trino.spi.function.LongVariableConstraint; +import io.trino.spi.function.Signature; +import io.trino.spi.function.TypeVariableConstraint; import io.trino.spi.type.NamedTypeSignature; import io.trino.spi.type.ParameterKind; import io.trino.spi.type.RowType; diff --git a/core/trino-main/src/main/java/io/trino/metadata/SqlAggregationFunction.java b/core/trino-main/src/main/java/io/trino/metadata/SqlAggregationFunction.java index ab8deb73b9b2..b88677b25104 100644 --- a/core/trino-main/src/main/java/io/trino/metadata/SqlAggregationFunction.java +++ b/core/trino-main/src/main/java/io/trino/metadata/SqlAggregationFunction.java @@ -15,7 +15,11 @@ import com.google.common.collect.ImmutableList; import io.trino.operator.aggregation.AggregationFromAnnotationsParser; -import io.trino.operator.aggregation.AggregationMetadata; +import io.trino.spi.function.AggregationFunctionMetadata; +import io.trino.spi.function.AggregationImplementation; +import io.trino.spi.function.BoundSignature; +import io.trino.spi.function.FunctionDependencies; +import io.trino.spi.function.FunctionMetadata; import java.util.List; @@ -54,12 +58,12 @@ public AggregationFunctionMetadata getAggregationMetadata() return aggregationFunctionMetadata; } - public AggregationMetadata specialize(BoundSignature boundSignature, FunctionDependencies functionDependencies) + public AggregationImplementation specialize(BoundSignature boundSignature, FunctionDependencies functionDependencies) { return specialize(boundSignature); } - protected AggregationMetadata specialize(BoundSignature boundSignature) + protected AggregationImplementation specialize(BoundSignature boundSignature) { throw new UnsupportedOperationException(); } diff --git a/core/trino-main/src/main/java/io/trino/metadata/SqlFunction.java b/core/trino-main/src/main/java/io/trino/metadata/SqlFunction.java index 55e9b20bb36b..c883acd9e348 100644 --- a/core/trino-main/src/main/java/io/trino/metadata/SqlFunction.java +++ b/core/trino-main/src/main/java/io/trino/metadata/SqlFunction.java @@ -13,7 +13,11 @@ */ package io.trino.metadata; -import static io.trino.metadata.FunctionDependencyDeclaration.NO_DEPENDENCIES; +import io.trino.spi.function.BoundSignature; +import io.trino.spi.function.FunctionDependencyDeclaration; +import io.trino.spi.function.FunctionMetadata; + +import static io.trino.spi.function.FunctionDependencyDeclaration.NO_DEPENDENCIES; public interface SqlFunction { diff --git a/core/trino-main/src/main/java/io/trino/metadata/SqlScalarFunction.java b/core/trino-main/src/main/java/io/trino/metadata/SqlScalarFunction.java index 47b0918c07b1..a0f81925085b 100644 --- a/core/trino-main/src/main/java/io/trino/metadata/SqlScalarFunction.java +++ b/core/trino-main/src/main/java/io/trino/metadata/SqlScalarFunction.java @@ -13,7 +13,10 @@ */ package io.trino.metadata; -import io.trino.operator.scalar.ScalarFunctionImplementation; +import io.trino.operator.scalar.SpecializedSqlScalarFunction; +import io.trino.spi.function.BoundSignature; +import io.trino.spi.function.FunctionDependencies; +import io.trino.spi.function.FunctionMetadata; public abstract class SqlScalarFunction implements SqlFunction @@ -31,12 +34,12 @@ public FunctionMetadata getFunctionMetadata() return functionMetadata; } - public ScalarFunctionImplementation specialize(BoundSignature boundSignature, FunctionDependencies functionDependencies) + public SpecializedSqlScalarFunction specialize(BoundSignature boundSignature, FunctionDependencies functionDependencies) { return specialize(boundSignature); } - protected ScalarFunctionImplementation specialize(BoundSignature boundSignature) + protected SpecializedSqlScalarFunction specialize(BoundSignature boundSignature) { throw new UnsupportedOperationException(); } diff --git a/core/trino-main/src/main/java/io/trino/metadata/TableFunctionRegistry.java b/core/trino-main/src/main/java/io/trino/metadata/TableFunctionRegistry.java index 17a831cf3ac4..687a652df28e 100644 --- a/core/trino-main/src/main/java/io/trino/metadata/TableFunctionRegistry.java +++ b/core/trino-main/src/main/java/io/trino/metadata/TableFunctionRegistry.java @@ -15,6 +15,7 @@ import io.trino.connector.CatalogHandle; import io.trino.connector.CatalogServiceProvider; +import io.trino.spi.function.SchemaFunctionName; import io.trino.spi.ptf.ConnectorTableFunction; import javax.annotation.concurrent.ThreadSafe; diff --git a/core/trino-main/src/main/java/io/trino/operator/ParametricFunctionHelpers.java b/core/trino-main/src/main/java/io/trino/operator/ParametricFunctionHelpers.java index 9529e2186f71..fc6531692a2e 100644 --- a/core/trino-main/src/main/java/io/trino/operator/ParametricFunctionHelpers.java +++ b/core/trino-main/src/main/java/io/trino/operator/ParametricFunctionHelpers.java @@ -14,8 +14,8 @@ package io.trino.operator; import io.trino.metadata.FunctionBinding; -import io.trino.metadata.FunctionDependencies; import io.trino.operator.annotations.ImplementationDependency; +import io.trino.spi.function.FunctionDependencies; import java.lang.invoke.MethodHandle; import java.lang.invoke.MethodHandles; diff --git a/core/trino-main/src/main/java/io/trino/operator/ParametricImplementation.java b/core/trino-main/src/main/java/io/trino/operator/ParametricImplementation.java index eea744d3da61..a1184f14fd3c 100644 --- a/core/trino-main/src/main/java/io/trino/operator/ParametricImplementation.java +++ b/core/trino-main/src/main/java/io/trino/operator/ParametricImplementation.java @@ -13,8 +13,8 @@ */ package io.trino.operator; -import io.trino.metadata.FunctionNullability; -import io.trino.metadata.Signature; +import io.trino.spi.function.FunctionNullability; +import io.trino.spi.function.Signature; public interface ParametricImplementation { diff --git a/core/trino-main/src/main/java/io/trino/operator/ParametricImplementationsGroup.java b/core/trino-main/src/main/java/io/trino/operator/ParametricImplementationsGroup.java index 9af04b0d1ee6..7cf2741c33a9 100644 --- a/core/trino-main/src/main/java/io/trino/operator/ParametricImplementationsGroup.java +++ b/core/trino-main/src/main/java/io/trino/operator/ParametricImplementationsGroup.java @@ -15,8 +15,8 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -import io.trino.metadata.FunctionNullability; -import io.trino.metadata.Signature; +import io.trino.spi.function.FunctionNullability; +import io.trino.spi.function.Signature; import io.trino.spi.type.TypeSignature; import java.util.List; diff --git a/core/trino-main/src/main/java/io/trino/operator/WindowFunctionDefinition.java b/core/trino-main/src/main/java/io/trino/operator/WindowFunctionDefinition.java index c896c32d46b9..cb326221e492 100644 --- a/core/trino-main/src/main/java/io/trino/operator/WindowFunctionDefinition.java +++ b/core/trino-main/src/main/java/io/trino/operator/WindowFunctionDefinition.java @@ -16,8 +16,8 @@ import com.google.common.collect.ImmutableList; import io.trino.operator.window.FrameInfo; import io.trino.operator.window.MappedWindowFunction; -import io.trino.operator.window.WindowFunctionSupplier; import io.trino.spi.function.WindowFunction; +import io.trino.spi.function.WindowFunctionSupplier; import io.trino.spi.type.Type; import java.util.Arrays; diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/AccumulatorCompiler.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/AccumulatorCompiler.java index 81b5482db904..dd1f56c53a3c 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/AccumulatorCompiler.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/AccumulatorCompiler.java @@ -27,10 +27,7 @@ import io.airlift.bytecode.control.IfStatement; import io.airlift.bytecode.expression.BytecodeExpression; import io.airlift.bytecode.expression.BytecodeExpressions; -import io.trino.metadata.BoundSignature; -import io.trino.metadata.FunctionNullability; import io.trino.operator.GroupByIdBlock; -import io.trino.operator.aggregation.AggregationMetadata.AccumulatorStateDescriptor; import io.trino.operator.window.InternalWindowIndex; import io.trino.spi.Page; import io.trino.spi.block.Block; @@ -39,6 +36,10 @@ import io.trino.spi.function.AccumulatorState; import io.trino.spi.function.AccumulatorStateFactory; import io.trino.spi.function.AccumulatorStateSerializer; +import io.trino.spi.function.AggregationImplementation; +import io.trino.spi.function.AggregationImplementation.AccumulatorStateDescriptor; +import io.trino.spi.function.BoundSignature; +import io.trino.spi.function.FunctionNullability; import io.trino.spi.function.GroupedAccumulatorState; import io.trino.spi.function.WindowIndex; import io.trino.sql.gen.Binding; @@ -86,41 +87,41 @@ private AccumulatorCompiler() {} public static AccumulatorFactory generateAccumulatorFactory( BoundSignature boundSignature, - AggregationMetadata metadata, + AggregationImplementation implementation, FunctionNullability functionNullability) { // change types used in Aggregation methods to types used in the core Trino engine to simplify code generation - metadata = normalizeAggregationMethods(metadata); + implementation = normalizeAggregationMethods(implementation); DynamicClassLoader classLoader = new DynamicClassLoader(AccumulatorCompiler.class.getClassLoader()); List argumentNullable = functionNullability.getArgumentNullable() - .subList(0, functionNullability.getArgumentNullable().size() - metadata.getLambdaInterfaces().size()); + .subList(0, functionNullability.getArgumentNullable().size() - implementation.getLambdaInterfaces().size()); Constructor accumulatorConstructor = generateAccumulatorClass( boundSignature, Accumulator.class, - metadata, + implementation, argumentNullable, classLoader); Constructor groupedAccumulatorConstructor = generateAccumulatorClass( boundSignature, GroupedAccumulator.class, - metadata, + implementation, argumentNullable, classLoader); return new CompiledAccumulatorFactory( accumulatorConstructor, groupedAccumulatorConstructor, - metadata.getLambdaInterfaces()); + implementation.getLambdaInterfaces()); } private static Constructor generateAccumulatorClass( BoundSignature boundSignature, Class accumulatorInterface, - AggregationMetadata metadata, + AggregationImplementation implementation, List argumentNullable, DynamicClassLoader classLoader) { @@ -134,7 +135,7 @@ private static Constructor generateAccumulatorClass( CallSiteBinder callSiteBinder = new CallSiteBinder(); - List> stateDescriptors = metadata.getAccumulatorStateDescriptors(); + List> stateDescriptors = implementation.getAccumulatorStateDescriptors(); List stateFieldAndDescriptors = new ArrayList<>(); for (int i = 0; i < stateDescriptors.size(); i++) { stateFieldAndDescriptors.add(new StateFieldAndDescriptor( @@ -147,7 +148,7 @@ private static Constructor generateAccumulatorClass( .map(StateFieldAndDescriptor::getStateField) .collect(toImmutableList()); - int lambdaCount = metadata.getLambdaInterfaces().size(); + int lambdaCount = implementation.getLambdaInterfaces().size(); List lambdaProviderFields = new ArrayList<>(lambdaCount); for (int i = 0; i < lambdaCount; i++) { lambdaProviderFields.add(definition.declareField(a(PRIVATE, FINAL), "lambdaProvider_" + i, Supplier.class)); @@ -173,7 +174,7 @@ private static Constructor generateAccumulatorClass( stateFields, argumentNullable, lambdaProviderFields, - metadata.getInputFunction(), + implementation.getInputFunction(), callSiteBinder, grouped); generateGetEstimatedSize(definition, stateFields); @@ -182,7 +183,7 @@ private static Constructor generateAccumulatorClass( definition, stateFieldAndDescriptors, lambdaProviderFields, - metadata.getCombineFunction(), + implementation.getCombineFunction(), callSiteBinder, grouped); @@ -194,10 +195,10 @@ private static Constructor generateAccumulatorClass( } if (grouped) { - generateGroupedEvaluateFinal(definition, stateFields, metadata.getOutputFunction(), callSiteBinder); + generateGroupedEvaluateFinal(definition, stateFields, implementation.getOutputFunction(), callSiteBinder); } else { - generateEvaluateFinal(definition, stateFields, metadata.getOutputFunction(), callSiteBinder); + generateEvaluateFinal(definition, stateFields, implementation.getOutputFunction(), callSiteBinder); } if (grouped) { @@ -215,13 +216,13 @@ private static Constructor generateAccumulatorClass( public static Constructor generateWindowAccumulatorClass( BoundSignature boundSignature, - AggregationMetadata metadata, + AggregationImplementation implementation, FunctionNullability functionNullability) { DynamicClassLoader classLoader = new DynamicClassLoader(AccumulatorCompiler.class.getClassLoader()); List argumentNullable = functionNullability.getArgumentNullable() - .subList(0, functionNullability.getArgumentNullable().size() - metadata.getLambdaInterfaces().size()); + .subList(0, functionNullability.getArgumentNullable().size() - implementation.getLambdaInterfaces().size()); ClassDefinition definition = new ClassDefinition( a(PUBLIC, FINAL), @@ -231,7 +232,7 @@ public static Constructor generateWindowAccumulator CallSiteBinder callSiteBinder = new CallSiteBinder(); - List> stateDescriptors = metadata.getAccumulatorStateDescriptors(); + List> stateDescriptors = implementation.getAccumulatorStateDescriptors(); List stateFieldAndDescriptors = new ArrayList<>(); for (int i = 0; i < stateDescriptors.size(); i++) { stateFieldAndDescriptors.add(new StateFieldAndDescriptor( @@ -244,7 +245,7 @@ public static Constructor generateWindowAccumulator .map(StateFieldAndDescriptor::getStateField) .collect(toImmutableList()); - int lambdaCount = metadata.getLambdaInterfaces().size(); + int lambdaCount = implementation.getLambdaInterfaces().size(); List lambdaProviderFields = new ArrayList<>(lambdaCount); for (int i = 0; i < lambdaCount; i++) { lambdaProviderFields.add(definition.declareField(a(PRIVATE, FINAL), "lambdaProvider_" + i, Supplier.class)); @@ -268,10 +269,10 @@ public static Constructor generateWindowAccumulator stateFields, argumentNullable, lambdaProviderFields, - metadata.getInputFunction(), + implementation.getInputFunction(), "addInput", callSiteBinder); - metadata.getRemoveInputFunction().ifPresent( + implementation.getRemoveInputFunction().ifPresent( removeInputFunction -> generateAddOrRemoveInputWindowIndex( definition, stateFields, @@ -281,7 +282,7 @@ public static Constructor generateWindowAccumulator "removeInput", callSiteBinder)); - generateEvaluateFinal(definition, stateFields, metadata.getOutputFunction(), callSiteBinder); + generateEvaluateFinal(definition, stateFields, implementation.getOutputFunction(), callSiteBinder); generateGetEstimatedSize(definition, stateFields); Class windowAccumulatorClass = defineClass(definition, WindowAccumulator.class, callSiteBinder.getBindings(), classLoader); @@ -1044,18 +1045,23 @@ private static BytecodeExpression generateRequireNotNull(BytecodeExpression expr .cast(expression.getType()); } - private static AggregationMetadata normalizeAggregationMethods(AggregationMetadata metadata) + private static AggregationImplementation normalizeAggregationMethods(AggregationImplementation implementation) { // change aggregations state variables to simply AccumulatorState to avoid any class loader issues in generated code - int stateParameterCount = metadata.getAccumulatorStateDescriptors().size(); - int lambdaParameterCount = metadata.getLambdaInterfaces().size(); - return new AggregationMetadata( - castStateParameters(metadata.getInputFunction(), stateParameterCount, lambdaParameterCount), - metadata.getRemoveInputFunction().map(removeFunction -> castStateParameters(removeFunction, stateParameterCount, lambdaParameterCount)), - metadata.getCombineFunction().map(combineFunction -> castStateParameters(combineFunction, stateParameterCount * 2, lambdaParameterCount)), - castStateParameters(metadata.getOutputFunction(), stateParameterCount, 0), - metadata.getAccumulatorStateDescriptors(), - metadata.getLambdaInterfaces()); + int stateParameterCount = implementation.getAccumulatorStateDescriptors().size(); + int lambdaParameterCount = implementation.getLambdaInterfaces().size(); + AggregationImplementation.Builder builder = AggregationImplementation.builder(); + builder.inputFunction(castStateParameters(implementation.getInputFunction(), stateParameterCount, lambdaParameterCount)); + implementation.getRemoveInputFunction() + .map(removeFunction -> castStateParameters(removeFunction, stateParameterCount, lambdaParameterCount)) + .ifPresent(builder::removeInputFunction); + implementation.getCombineFunction() + .map(combineFunction -> castStateParameters(combineFunction, stateParameterCount * 2, lambdaParameterCount)) + .ifPresent(builder::combineFunction); + builder.outputFunction(castStateParameters(implementation.getOutputFunction(), stateParameterCount, 0)); + builder.accumulatorStateDescriptors(implementation.getAccumulatorStateDescriptors()); + builder.lambdaInterfaces(implementation.getLambdaInterfaces()); + return builder.build(); } private static MethodHandle castStateParameters(MethodHandle inputFunction, int stateParameterCount, int lambdaParameterCount) diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/AggregationFromAnnotationsParser.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/AggregationFromAnnotationsParser.java index 44949ac5d81d..fcf064cbfc70 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/AggregationFromAnnotationsParser.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/AggregationFromAnnotationsParser.java @@ -20,10 +20,7 @@ import com.google.common.collect.MoreCollectors; import io.airlift.log.Logger; import io.trino.metadata.FunctionBinding; -import io.trino.metadata.FunctionDependencies; -import io.trino.metadata.Signature; import io.trino.operator.ParametricImplementationsGroup; -import io.trino.operator.aggregation.AggregationMetadata.AccumulatorStateDescriptor; import io.trino.operator.aggregation.state.InOutStateSerializer; import io.trino.operator.annotations.FunctionsParserHelper; import io.trino.operator.annotations.ImplementationDependency; @@ -34,8 +31,10 @@ import io.trino.spi.function.AccumulatorStateMetadata; import io.trino.spi.function.AccumulatorStateSerializer; import io.trino.spi.function.AggregationFunction; +import io.trino.spi.function.AggregationImplementation.AccumulatorStateDescriptor; import io.trino.spi.function.AggregationState; import io.trino.spi.function.CombineFunction; +import io.trino.spi.function.FunctionDependencies; import io.trino.spi.function.FunctionDependency; import io.trino.spi.function.InOut; import io.trino.spi.function.InputFunction; @@ -43,6 +42,7 @@ import io.trino.spi.function.OperatorDependency; import io.trino.spi.function.OutputFunction; import io.trino.spi.function.RemoveInputFunction; +import io.trino.spi.function.Signature; import io.trino.spi.function.TypeParameter; import io.trino.spi.type.TypeSignature; @@ -66,7 +66,7 @@ import static com.google.common.base.Strings.emptyToNull; import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.Iterables.getOnlyElement; -import static io.trino.operator.aggregation.AggregationImplementation.Parser.parseImplementation; +import static io.trino.operator.aggregation.ParametricAggregationImplementation.Parser.parseImplementation; import static io.trino.operator.aggregation.state.StateCompiler.generateInOutStateFactory; import static io.trino.operator.aggregation.state.StateCompiler.generateStateFactory; import static io.trino.operator.aggregation.state.StateCompiler.generateStateSerializer; @@ -107,11 +107,11 @@ else if (combineFunction.isPresent()) { } // Input functions can have either an exact signature, or generic/calculate signature - List exactImplementations = new ArrayList<>(); - List nonExactImplementations = new ArrayList<>(); + List exactImplementations = new ArrayList<>(); + List nonExactImplementations = new ArrayList<>(); for (Method inputFunction : getInputFunctions(aggregationDefinition, stateDetails)) { Optional removeInputFunction = getRemoveInputFunction(aggregationDefinition, inputFunction); - AggregationImplementation implementation = parseImplementation( + ParametricAggregationImplementation implementation = parseImplementation( aggregationDefinition, header.getName(), stateDetails, @@ -141,13 +141,13 @@ private static List buildFunctions( String name, AggregationHeader header, List> stateDetails, - List exactImplementations, - List nonExactImplementations) + List exactImplementations, + List nonExactImplementations) { ImmutableList.Builder functions = ImmutableList.builder(); // create a separate function for each exact implementation - for (AggregationImplementation exactImplementation : exactImplementations) { + for (ParametricAggregationImplementation exactImplementation : exactImplementations) { functions.add(new ParametricAggregation( exactImplementation.getSignature().withName(name), header, @@ -157,9 +157,9 @@ private static List buildFunctions( // if there are non-exact functions, create a single generic/calculated function using these implementations if (!nonExactImplementations.isEmpty()) { - ParametricImplementationsGroup.Builder implementationsBuilder = ParametricImplementationsGroup.builder(); + ParametricImplementationsGroup.Builder implementationsBuilder = ParametricImplementationsGroup.builder(); nonExactImplementations.forEach(implementationsBuilder::addImplementation); - ParametricImplementationsGroup implementations = implementationsBuilder.build(); + ParametricImplementationsGroup implementations = implementationsBuilder.build(); functions.add(new ParametricAggregation( implementations.getSignature().withName(name), header, @@ -598,10 +598,10 @@ public List getDependencies() public AccumulatorStateDescriptor createAccumulatorStateDescriptor(FunctionBinding functionBinding, FunctionDependencies functionDependencies) { - return new AccumulatorStateDescriptor<>( - stateClass, - serializerGenerator.apply(functionBinding, functionDependencies), - factoryGenerator.apply(functionBinding, functionDependencies)); + return AccumulatorStateDescriptor.builder(stateClass) + .serializer(serializerGenerator.apply(functionBinding, functionDependencies)) + .factory(factoryGenerator.apply(functionBinding, functionDependencies)) + .build(); } @Override diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/AggregationFunctionAdapter.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/AggregationFunctionAdapter.java index 1de30a09f617..84d20bfddf86 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/AggregationFunctionAdapter.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/AggregationFunctionAdapter.java @@ -14,8 +14,8 @@ package io.trino.operator.aggregation; import com.google.common.collect.ImmutableList; -import io.trino.metadata.BoundSignature; import io.trino.spi.block.Block; +import io.trino.spi.function.BoundSignature; import io.trino.spi.type.Type; import java.lang.invoke.MethodHandle; diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/AggregationMetadata.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/AggregationMetadata.java deleted file mode 100644 index 681a82513121..000000000000 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/AggregationMetadata.java +++ /dev/null @@ -1,127 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.trino.operator.aggregation; - -import com.google.common.collect.ImmutableList; -import io.trino.spi.function.AccumulatorState; -import io.trino.spi.function.AccumulatorStateFactory; -import io.trino.spi.function.AccumulatorStateSerializer; - -import java.lang.invoke.MethodHandle; -import java.util.List; -import java.util.Optional; - -import static java.util.Objects.requireNonNull; - -public class AggregationMetadata -{ - private final MethodHandle inputFunction; - private final Optional removeInputFunction; - private final Optional combineFunction; - private final MethodHandle outputFunction; - private final List> accumulatorStateDescriptors; - private final List> lambdaInterfaces; - - public AggregationMetadata( - MethodHandle inputFunction, - Optional removeInputFunction, - Optional combineFunction, - MethodHandle outputFunction, - List> accumulatorStateDescriptors) - { - this( - inputFunction, - removeInputFunction, - combineFunction, - outputFunction, - accumulatorStateDescriptors, - ImmutableList.of()); - } - - public AggregationMetadata( - MethodHandle inputFunction, - Optional removeInputFunction, - Optional combineFunction, - MethodHandle outputFunction, - List> accumulatorStateDescriptors, - List> lambdaInterfaces) - { - this.inputFunction = requireNonNull(inputFunction, "inputFunction is null"); - this.removeInputFunction = requireNonNull(removeInputFunction, "removeInputFunction is null"); - this.combineFunction = requireNonNull(combineFunction, "combineFunction is null"); - this.outputFunction = requireNonNull(outputFunction, "outputFunction is null"); - this.accumulatorStateDescriptors = requireNonNull(accumulatorStateDescriptors, "accumulatorStateDescriptors is null"); - this.lambdaInterfaces = ImmutableList.copyOf(requireNonNull(lambdaInterfaces, "lambdaInterfaces is null")); - } - - public MethodHandle getInputFunction() - { - return inputFunction; - } - - public Optional getRemoveInputFunction() - { - return removeInputFunction; - } - - public Optional getCombineFunction() - { - return combineFunction; - } - - public MethodHandle getOutputFunction() - { - return outputFunction; - } - - public List> getAccumulatorStateDescriptors() - { - return accumulatorStateDescriptors; - } - - public List> getLambdaInterfaces() - { - return lambdaInterfaces; - } - - public static class AccumulatorStateDescriptor - { - private final Class stateInterface; - private final AccumulatorStateSerializer serializer; - private final AccumulatorStateFactory factory; - - public AccumulatorStateDescriptor(Class stateInterface, AccumulatorStateSerializer serializer, AccumulatorStateFactory factory) - { - this.stateInterface = requireNonNull(stateInterface, "stateInterface is null"); - this.serializer = requireNonNull(serializer, "serializer is null"); - this.factory = requireNonNull(factory, "factory is null"); - } - - // this is only used to verify method interfaces - public Class getStateInterface() - { - return stateInterface; - } - - public AccumulatorStateSerializer getSerializer() - { - return serializer; - } - - public AccumulatorStateFactory getFactory() - { - return factory; - } - } -} diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/ParametricAggregation.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/ParametricAggregation.java index 5f7a6f408c1a..d447e88ecd35 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/ParametricAggregation.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/ParametricAggregation.java @@ -15,24 +15,24 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableList; -import io.trino.metadata.AggregationFunctionMetadata; -import io.trino.metadata.AggregationFunctionMetadata.AggregationFunctionMetadataBuilder; -import io.trino.metadata.BoundSignature; import io.trino.metadata.FunctionBinding; -import io.trino.metadata.FunctionDependencies; -import io.trino.metadata.FunctionDependencyDeclaration; -import io.trino.metadata.FunctionDependencyDeclaration.FunctionDependencyDeclarationBuilder; -import io.trino.metadata.FunctionMetadata; -import io.trino.metadata.FunctionNullability; -import io.trino.metadata.Signature; import io.trino.metadata.SignatureBinder; import io.trino.metadata.SqlAggregationFunction; import io.trino.operator.ParametricImplementationsGroup; import io.trino.operator.aggregation.AggregationFromAnnotationsParser.AccumulatorStateDetails; import io.trino.operator.aggregation.AggregationFunctionAdapter.AggregationParameterKind; -import io.trino.operator.aggregation.AggregationMetadata.AccumulatorStateDescriptor; import io.trino.operator.annotations.ImplementationDependency; import io.trino.spi.TrinoException; +import io.trino.spi.function.AggregationFunctionMetadata; +import io.trino.spi.function.AggregationFunctionMetadata.AggregationFunctionMetadataBuilder; +import io.trino.spi.function.AggregationImplementation; +import io.trino.spi.function.BoundSignature; +import io.trino.spi.function.FunctionDependencies; +import io.trino.spi.function.FunctionDependencyDeclaration; +import io.trino.spi.function.FunctionDependencyDeclaration.FunctionDependencyDeclarationBuilder; +import io.trino.spi.function.FunctionMetadata; +import io.trino.spi.function.FunctionNullability; +import io.trino.spi.function.Signature; import java.lang.invoke.MethodHandle; import java.util.Collection; @@ -52,14 +52,14 @@ public class ParametricAggregation extends SqlAggregationFunction { - private final ParametricImplementationsGroup implementations; + private final ParametricImplementationsGroup implementations; private final List> stateDetails; public ParametricAggregation( Signature signature, AggregationHeader details, List> stateDetails, - ParametricImplementationsGroup implementations) + ParametricImplementationsGroup implementations) { super( createFunctionMetadata(signature, details, implementations.getFunctionNullability()), @@ -126,9 +126,9 @@ public FunctionDependencyDeclaration getFunctionDependencies() return builder.build(); } - private static void declareDependencies(FunctionDependencyDeclarationBuilder builder, Collection implementations) + private static void declareDependencies(FunctionDependencyDeclarationBuilder builder, Collection implementations) { - for (AggregationImplementation implementation : implementations) { + for (ParametricAggregationImplementation implementation : implementations) { for (ImplementationDependency dependency : implementation.getInputDependencies()) { dependency.declareDependencies(builder); } @@ -142,44 +142,54 @@ private static void declareDependencies(FunctionDependencyDeclarationBuilder bui } @Override - public AggregationMetadata specialize(BoundSignature boundSignature, FunctionDependencies functionDependencies) + public AggregationImplementation specialize(BoundSignature boundSignature, FunctionDependencies functionDependencies) { // Find implementation matching arguments - AggregationImplementation concreteImplementation = findMatchingImplementation(boundSignature); + ParametricAggregationImplementation concreteImplementation = findMatchingImplementation(boundSignature); + List inputParameterKinds = concreteImplementation.getInputParameterKinds(); // Build state factory and serializer + AggregationImplementation.Builder builder = AggregationImplementation.builder(); FunctionMetadata metadata = getFunctionMetadata(); FunctionBinding functionBinding = SignatureBinder.bindFunction(metadata.getFunctionId(), metadata.getSignature(), boundSignature); - List> accumulatorStateDescriptors = stateDetails.stream() + builder.accumulatorStateDescriptors(stateDetails.stream() .map(state -> state.createAccumulatorStateDescriptor(functionBinding, functionDependencies)) - .collect(toImmutableList()); + .collect(toImmutableList())); // Bind provided dependencies to aggregation method handlers - MethodHandle inputHandle = bindDependencies(concreteImplementation.getInputFunction(), concreteImplementation.getInputDependencies(), functionBinding, functionDependencies); - Optional removeInputHandle = concreteImplementation.getRemoveInputFunction().map( - removeInputFunction -> bindDependencies(removeInputFunction, concreteImplementation.getRemoveInputDependencies(), functionBinding, functionDependencies)); + builder.inputFunction(normalizeInputMethod( + bindDependencies( + concreteImplementation.getInputFunction(), + concreteImplementation.getInputDependencies(), + functionBinding, + functionDependencies), + boundSignature, + inputParameterKinds)); + concreteImplementation.getRemoveInputFunction() + .map(removeInputFunction -> bindDependencies( + removeInputFunction, + concreteImplementation.getRemoveInputDependencies(), + functionBinding, + functionDependencies)) + .map(removeInputFunction -> normalizeInputMethod(removeInputFunction, boundSignature, inputParameterKinds)) + .ifPresent(builder::removeInputFunction); - Optional combineHandle = concreteImplementation.getCombineFunction(); if (getAggregationMetadata().isDecomposable()) { - checkArgument(combineHandle.isPresent(), "Decomposable method %s does not have a combine method", boundSignature.getName()); - combineHandle = combineHandle.map(combineFunction -> bindDependencies(combineFunction, concreteImplementation.getCombineDependencies(), functionBinding, functionDependencies)); + MethodHandle combineHandle = concreteImplementation.getCombineFunction() + .orElseThrow(() -> new IllegalArgumentException(format("Decomposable method %s does not have a combine method", boundSignature.getName()))); + builder.combineFunction(bindDependencies(combineHandle, concreteImplementation.getCombineDependencies(), functionBinding, functionDependencies)); } else { checkArgument(concreteImplementation.getCombineFunction().isEmpty(), "Decomposable method %s does not have a combine method", boundSignature.getName()); } - MethodHandle outputHandle = bindDependencies(concreteImplementation.getOutputFunction(), concreteImplementation.getOutputDependencies(), functionBinding, functionDependencies); + builder.outputFunction(bindDependencies( + concreteImplementation.getOutputFunction(), + concreteImplementation.getOutputDependencies(), + functionBinding, + functionDependencies)); - List inputParameterKinds = concreteImplementation.getInputParameterKinds(); - inputHandle = normalizeInputMethod(inputHandle, boundSignature, inputParameterKinds); - removeInputHandle = removeInputHandle.map(function -> normalizeInputMethod(function, boundSignature, inputParameterKinds)); - - return new AggregationMetadata( - inputHandle, - removeInputHandle, - combineHandle, - outputHandle, - accumulatorStateDescriptors); + return builder.build(); } @VisibleForTesting @@ -189,20 +199,20 @@ public List> getStateDetails() } @VisibleForTesting - public ParametricImplementationsGroup getImplementations() + public ParametricImplementationsGroup getImplementations() { return implementations; } - private AggregationImplementation findMatchingImplementation(BoundSignature boundSignature) + private ParametricAggregationImplementation findMatchingImplementation(BoundSignature boundSignature) { Signature signature = boundSignature.toSignature(); - Optional foundImplementation = Optional.empty(); + Optional foundImplementation = Optional.empty(); if (implementations.getExactImplementations().containsKey(signature)) { foundImplementation = Optional.of(implementations.getExactImplementations().get(signature)); } else { - for (AggregationImplementation candidate : implementations.getGenericImplementations()) { + for (ParametricAggregationImplementation candidate : implementations.getGenericImplementations()) { if (candidate.areTypesAssignable(boundSignature)) { if (foundImplementation.isPresent()) { throw new TrinoException(AMBIGUOUS_FUNCTION_CALL, format("Ambiguous function call (%s) for %s", boundSignature, getFunctionMetadata().getSignature())); diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/AggregationImplementation.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/ParametricAggregationImplementation.java similarity index 97% rename from core/trino-main/src/main/java/io/trino/operator/aggregation/AggregationImplementation.java rename to core/trino-main/src/main/java/io/trino/operator/aggregation/ParametricAggregationImplementation.java index 20542d890fd8..7e3973f4c63d 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/AggregationImplementation.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/ParametricAggregationImplementation.java @@ -15,9 +15,6 @@ import com.google.common.base.VerifyException; import com.google.common.collect.ImmutableList; -import io.trino.metadata.BoundSignature; -import io.trino.metadata.FunctionNullability; -import io.trino.metadata.Signature; import io.trino.operator.ParametricImplementation; import io.trino.operator.aggregation.AggregationFromAnnotationsParser.AccumulatorStateDetails; import io.trino.operator.aggregation.AggregationFunctionAdapter.AggregationParameterKind; @@ -27,7 +24,10 @@ import io.trino.spi.function.AggregationState; import io.trino.spi.function.BlockIndex; import io.trino.spi.function.BlockPosition; +import io.trino.spi.function.BoundSignature; +import io.trino.spi.function.FunctionNullability; import io.trino.spi.function.OutputFunction; +import io.trino.spi.function.Signature; import io.trino.spi.function.SqlType; import io.trino.spi.function.TypeParameter; import io.trino.spi.type.TypeSignature; @@ -65,7 +65,7 @@ import static io.trino.util.Reflection.methodHandle; import static java.util.Objects.requireNonNull; -public class AggregationImplementation +public class ParametricAggregationImplementation implements ParametricImplementation { public static class AggregateNativeContainerType @@ -105,7 +105,7 @@ public boolean isBlockPosition() private final List inputParameterKinds; private final FunctionNullability functionNullability; - public AggregationImplementation( + public ParametricAggregationImplementation( Signature signature, Class definitionClass, MethodHandle inputFunction, @@ -232,9 +232,9 @@ public boolean areTypesAssignable(BoundSignature boundSignature) } @Override - public AggregationImplementation withAlias(String alias) + public ParametricImplementation withAlias(String alias) { - return new AggregationImplementation( + return new ParametricAggregationImplementation( signature.withName(alias), definitionClass, inputFunction, @@ -323,9 +323,9 @@ private Parser( outputHandle = methodHandle(outputFunction); } - private AggregationImplementation get() + private ParametricAggregationImplementation get() { - return new AggregationImplementation( + return new ParametricAggregationImplementation( signatureBuilder.build(), aggregationDefinition, inputHandle, @@ -340,7 +340,7 @@ private AggregationImplementation get() inputParameterKinds); } - public static AggregationImplementation parseImplementation( + public static ParametricAggregationImplementation parseImplementation( Class aggregationDefinition, String name, List> stateDetails, diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/ReduceAggregationFunction.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/ReduceAggregationFunction.java index d88ddd8e9d9d..4959ebab017b 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/ReduceAggregationFunction.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/ReduceAggregationFunction.java @@ -14,12 +14,7 @@ package io.trino.operator.aggregation; import com.google.common.collect.ImmutableList; -import io.trino.metadata.AggregationFunctionMetadata; -import io.trino.metadata.BoundSignature; -import io.trino.metadata.FunctionMetadata; -import io.trino.metadata.Signature; import io.trino.metadata.SqlAggregationFunction; -import io.trino.operator.aggregation.AggregationMetadata.AccumulatorStateDescriptor; import io.trino.operator.aggregation.state.GenericBooleanState; import io.trino.operator.aggregation.state.GenericBooleanStateSerializer; import io.trino.operator.aggregation.state.GenericDoubleState; @@ -29,16 +24,19 @@ import io.trino.operator.aggregation.state.StateCompiler; import io.trino.spi.TrinoException; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.function.AggregationFunctionMetadata; +import io.trino.spi.function.AggregationImplementation; +import io.trino.spi.function.BoundSignature; +import io.trino.spi.function.FunctionMetadata; +import io.trino.spi.function.Signature; import io.trino.spi.type.Type; import io.trino.spi.type.TypeSignature; import io.trino.sql.gen.lambda.BinaryFunctionInterface; import java.lang.invoke.MethodHandle; -import java.util.Optional; import static io.trino.operator.aggregation.AggregationFunctionAdapter.AggregationParameterKind.INPUT_CHANNEL; import static io.trino.operator.aggregation.AggregationFunctionAdapter.AggregationParameterKind.STATE; -import static io.trino.operator.aggregation.AggregationFunctionAdapter.normalizeInputMethod; import static io.trino.spi.StandardErrorCode.NOT_SUPPORTED; import static io.trino.spi.type.TypeSignature.functionType; import static io.trino.util.Reflection.methodHandle; @@ -84,42 +82,46 @@ public ReduceAggregationFunction() } @Override - public AggregationMetadata specialize(BoundSignature boundSignature) + public AggregationImplementation specialize(BoundSignature boundSignature) { Type inputType = boundSignature.getArgumentTypes().get(0); Type stateType = boundSignature.getArgumentTypes().get(1); - MethodHandle inputMethodHandle; - MethodHandle combineMethodHandle; - MethodHandle outputMethodHandle; - AccumulatorStateDescriptor stateDescriptor; - if (stateType.getJavaType() == long.class) { - inputMethodHandle = LONG_STATE_INPUT_FUNCTION; - combineMethodHandle = LONG_STATE_COMBINE_FUNCTION; - outputMethodHandle = LONG_STATE_OUTPUT_FUNCTION.bindTo(stateType); - stateDescriptor = new AccumulatorStateDescriptor<>( - GenericLongState.class, - new GenericLongStateSerializer(stateType), - StateCompiler.generateStateFactory(GenericLongState.class)); + return AggregationImplementation.builder() + .inputFunction(normalizeInputMethod(boundSignature, inputType, LONG_STATE_INPUT_FUNCTION)) + .combineFunction(LONG_STATE_COMBINE_FUNCTION) + .outputFunction(LONG_STATE_OUTPUT_FUNCTION.bindTo(stateType)) + .accumulatorStateDescriptor( + GenericLongState.class, + new GenericLongStateSerializer(stateType), + StateCompiler.generateStateFactory(GenericLongState.class)) + .lambdaInterfaces(BinaryFunctionInterface.class, BinaryFunctionInterface.class) + .build(); } else if (stateType.getJavaType() == double.class) { - inputMethodHandle = DOUBLE_STATE_INPUT_FUNCTION; - combineMethodHandle = DOUBLE_STATE_COMBINE_FUNCTION; - outputMethodHandle = DOUBLE_STATE_OUTPUT_FUNCTION.bindTo(stateType); - stateDescriptor = new AccumulatorStateDescriptor<>( - GenericDoubleState.class, - new GenericDoubleStateSerializer(stateType), - StateCompiler.generateStateFactory(GenericDoubleState.class)); + return AggregationImplementation.builder() + .inputFunction(normalizeInputMethod(boundSignature, inputType, DOUBLE_STATE_INPUT_FUNCTION)) + .combineFunction(DOUBLE_STATE_COMBINE_FUNCTION) + .outputFunction(DOUBLE_STATE_OUTPUT_FUNCTION.bindTo(stateType)) + .accumulatorStateDescriptor( + GenericDoubleState.class, + new GenericDoubleStateSerializer(stateType), + StateCompiler.generateStateFactory(GenericDoubleState.class)) + .lambdaInterfaces(BinaryFunctionInterface.class, BinaryFunctionInterface.class) + .build(); } else if (stateType.getJavaType() == boolean.class) { - inputMethodHandle = BOOLEAN_STATE_INPUT_FUNCTION; - combineMethodHandle = BOOLEAN_STATE_COMBINE_FUNCTION; - outputMethodHandle = BOOLEAN_STATE_OUTPUT_FUNCTION.bindTo(stateType); - stateDescriptor = new AccumulatorStateDescriptor<>( - GenericBooleanState.class, - new GenericBooleanStateSerializer(stateType), - StateCompiler.generateStateFactory(GenericBooleanState.class)); + return AggregationImplementation.builder() + .inputFunction(normalizeInputMethod(boundSignature, inputType, BOOLEAN_STATE_INPUT_FUNCTION)) + .combineFunction(BOOLEAN_STATE_COMBINE_FUNCTION) + .outputFunction(BOOLEAN_STATE_OUTPUT_FUNCTION.bindTo(stateType)) + .accumulatorStateDescriptor( + GenericBooleanState.class, + new GenericBooleanStateSerializer(stateType), + StateCompiler.generateStateFactory(GenericBooleanState.class)) + .lambdaInterfaces(BinaryFunctionInterface.class, BinaryFunctionInterface.class) + .build(); } else { // State with Slice or Block as native container type is intentionally not supported yet, @@ -127,17 +129,13 @@ else if (stateType.getJavaType() == boolean.class) { // See JDK-8017163. throw new TrinoException(NOT_SUPPORTED, format("State type not supported for %s: %s", NAME, stateType.getDisplayName())); } + } + private static MethodHandle normalizeInputMethod(BoundSignature boundSignature, Type inputType, MethodHandle inputMethodHandle) + { inputMethodHandle = inputMethodHandle.asType(inputMethodHandle.type().changeParameterType(1, inputType.getJavaType())); - inputMethodHandle = normalizeInputMethod(inputMethodHandle, boundSignature, ImmutableList.of(STATE, INPUT_CHANNEL, INPUT_CHANNEL), 2); - - return new AggregationMetadata( - inputMethodHandle, - Optional.empty(), - Optional.of(combineMethodHandle), - outputMethodHandle, - ImmutableList.of(stateDescriptor), - ImmutableList.of(BinaryFunctionInterface.class, BinaryFunctionInterface.class)); + inputMethodHandle = AggregationFunctionAdapter.normalizeInputMethod(inputMethodHandle, boundSignature, ImmutableList.of(STATE, INPUT_CHANNEL, INPUT_CHANNEL), 2); + return inputMethodHandle; } public static void input(GenericLongState state, Object value, long initialStateValue, BinaryFunctionInterface inputFunction, BinaryFunctionInterface combineFunction) diff --git a/core/trino-main/src/main/java/io/trino/operator/annotations/CastImplementationDependency.java b/core/trino-main/src/main/java/io/trino/operator/annotations/CastImplementationDependency.java index bf8c3839abb2..58b852f473be 100644 --- a/core/trino-main/src/main/java/io/trino/operator/annotations/CastImplementationDependency.java +++ b/core/trino-main/src/main/java/io/trino/operator/annotations/CastImplementationDependency.java @@ -14,10 +14,10 @@ package io.trino.operator.annotations; import io.trino.metadata.FunctionBinding; -import io.trino.metadata.FunctionDependencies; -import io.trino.metadata.FunctionDependencyDeclaration.FunctionDependencyDeclarationBuilder; -import io.trino.metadata.FunctionInvoker; +import io.trino.spi.function.FunctionDependencies; +import io.trino.spi.function.FunctionDependencyDeclaration.FunctionDependencyDeclarationBuilder; import io.trino.spi.function.InvocationConvention; +import io.trino.spi.function.ScalarFunctionImplementation; import io.trino.spi.type.TypeSignature; import java.util.Objects; @@ -55,11 +55,11 @@ public void declareDependencies(FunctionDependencyDeclarationBuilder builder) } @Override - protected FunctionInvoker getInvoker(FunctionBinding functionBinding, FunctionDependencies functionDependencies, InvocationConvention invocationConvention) + protected ScalarFunctionImplementation getImplementation(FunctionBinding functionBinding, FunctionDependencies functionDependencies, InvocationConvention invocationConvention) { TypeSignature from = applyBoundVariables(fromType, functionBinding); TypeSignature to = applyBoundVariables(toType, functionBinding); - return functionDependencies.getCastSignatureInvoker(from, to, invocationConvention); + return functionDependencies.getCastImplementationSignature(from, to, invocationConvention); } @Override diff --git a/core/trino-main/src/main/java/io/trino/operator/annotations/FunctionImplementationDependency.java b/core/trino-main/src/main/java/io/trino/operator/annotations/FunctionImplementationDependency.java index 1d5244071810..d7b1c820a3ce 100644 --- a/core/trino-main/src/main/java/io/trino/operator/annotations/FunctionImplementationDependency.java +++ b/core/trino-main/src/main/java/io/trino/operator/annotations/FunctionImplementationDependency.java @@ -14,12 +14,12 @@ package io.trino.operator.annotations; import io.trino.metadata.FunctionBinding; -import io.trino.metadata.FunctionDependencies; -import io.trino.metadata.FunctionDependencyDeclaration.FunctionDependencyDeclarationBuilder; -import io.trino.metadata.FunctionInvoker; +import io.trino.spi.function.FunctionDependencies; +import io.trino.spi.function.FunctionDependencyDeclaration.FunctionDependencyDeclarationBuilder; import io.trino.spi.function.InvocationConvention; +import io.trino.spi.function.QualifiedFunctionName; +import io.trino.spi.function.ScalarFunctionImplementation; import io.trino.spi.type.TypeSignature; -import io.trino.sql.tree.QualifiedName; import java.util.List; import java.util.Objects; @@ -30,19 +30,19 @@ public final class FunctionImplementationDependency extends ScalarImplementationDependency { - private final QualifiedName fullyQualifiedName; + private final QualifiedFunctionName fullyQualifiedFunctionName; private final List argumentTypes; - public FunctionImplementationDependency(QualifiedName fullyQualifiedName, List argumentTypes, InvocationConvention invocationConvention, Class type) + public FunctionImplementationDependency(QualifiedFunctionName fullyQualifiedFunctionName, List argumentTypes, InvocationConvention invocationConvention, Class type) { super(invocationConvention, type); - this.fullyQualifiedName = requireNonNull(fullyQualifiedName, "fullyQualifiedName is null"); + this.fullyQualifiedFunctionName = requireNonNull(fullyQualifiedFunctionName, "fullyQualifiedFunctionName is null"); this.argumentTypes = requireNonNull(argumentTypes, "argumentTypes is null"); } - public QualifiedName getFullyQualifiedName() + public QualifiedFunctionName getFullyQualifiedName() { - return fullyQualifiedName; + return fullyQualifiedFunctionName; } public List getArgumentTypes() @@ -53,14 +53,14 @@ public List getArgumentTypes() @Override public void declareDependencies(FunctionDependencyDeclarationBuilder builder) { - builder.addFunctionSignature(fullyQualifiedName, argumentTypes); + builder.addFunctionSignature(fullyQualifiedFunctionName, argumentTypes); } @Override - protected FunctionInvoker getInvoker(FunctionBinding functionBinding, FunctionDependencies functionDependencies, InvocationConvention invocationConvention) + protected ScalarFunctionImplementation getImplementation(FunctionBinding functionBinding, FunctionDependencies functionDependencies, InvocationConvention invocationConvention) { List types = applyBoundVariables(argumentTypes, functionBinding); - return functionDependencies.getFunctionSignatureInvoker(fullyQualifiedName, types, invocationConvention); + return functionDependencies.getScalarFunctionImplementationSignature(fullyQualifiedFunctionName, types, invocationConvention); } @Override @@ -73,13 +73,13 @@ public boolean equals(Object o) return false; } FunctionImplementationDependency that = (FunctionImplementationDependency) o; - return Objects.equals(fullyQualifiedName, that.fullyQualifiedName) && + return Objects.equals(fullyQualifiedFunctionName, that.fullyQualifiedFunctionName) && Objects.equals(argumentTypes, that.argumentTypes); } @Override public int hashCode() { - return Objects.hash(fullyQualifiedName, argumentTypes); + return Objects.hash(fullyQualifiedFunctionName, argumentTypes); } } diff --git a/core/trino-main/src/main/java/io/trino/operator/annotations/FunctionsParserHelper.java b/core/trino-main/src/main/java/io/trino/operator/annotations/FunctionsParserHelper.java index c363eda8b20b..fe910d40ba82 100644 --- a/core/trino-main/src/main/java/io/trino/operator/annotations/FunctionsParserHelper.java +++ b/core/trino-main/src/main/java/io/trino/operator/annotations/FunctionsParserHelper.java @@ -16,18 +16,18 @@ import com.google.common.collect.HashMultimap; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; -import io.trino.metadata.Signature; -import io.trino.metadata.Signature.Builder; -import io.trino.metadata.TypeVariableConstraint; -import io.trino.metadata.TypeVariableConstraint.TypeVariableConstraintBuilder; import io.trino.spi.function.Description; import io.trino.spi.function.IsNull; import io.trino.spi.function.LiteralParameters; import io.trino.spi.function.OperatorType; +import io.trino.spi.function.Signature; +import io.trino.spi.function.Signature.Builder; import io.trino.spi.function.SqlNullable; import io.trino.spi.function.SqlType; import io.trino.spi.function.TypeParameter; import io.trino.spi.function.TypeParameterSpecialization; +import io.trino.spi.function.TypeVariableConstraint; +import io.trino.spi.function.TypeVariableConstraint.TypeVariableConstraintBuilder; import io.trino.spi.type.TypeSignature; import io.trino.spi.type.TypeSignatureParameter; import io.trino.type.Constraint; diff --git a/core/trino-main/src/main/java/io/trino/operator/annotations/ImplementationDependency.java b/core/trino-main/src/main/java/io/trino/operator/annotations/ImplementationDependency.java index 07df7f14616b..68f438c74f3d 100644 --- a/core/trino-main/src/main/java/io/trino/operator/annotations/ImplementationDependency.java +++ b/core/trino-main/src/main/java/io/trino/operator/annotations/ImplementationDependency.java @@ -15,19 +15,19 @@ import com.google.common.collect.ImmutableSet; import io.trino.metadata.FunctionBinding; -import io.trino.metadata.FunctionDependencies; -import io.trino.metadata.FunctionDependencyDeclaration.FunctionDependencyDeclarationBuilder; import io.trino.spi.function.CastDependency; import io.trino.spi.function.Convention; +import io.trino.spi.function.FunctionDependencies; import io.trino.spi.function.FunctionDependency; +import io.trino.spi.function.FunctionDependencyDeclaration.FunctionDependencyDeclarationBuilder; import io.trino.spi.function.InvocationConvention; import io.trino.spi.function.LiteralParameter; import io.trino.spi.function.OperatorDependency; import io.trino.spi.function.OperatorType; +import io.trino.spi.function.QualifiedFunctionName; import io.trino.spi.function.TypeParameter; import io.trino.spi.type.TypeSignature; import io.trino.spi.type.TypeSignatureParameter; -import io.trino.sql.tree.QualifiedName; import java.lang.annotation.Annotation; import java.lang.reflect.AnnotatedElement; @@ -111,7 +111,7 @@ public static ImplementationDependency createDependency(Annotation annotation, S if (annotation instanceof FunctionDependency) { FunctionDependency functionDependency = (FunctionDependency) annotation; return new FunctionImplementationDependency( - QualifiedName.of(functionDependency.name()), + QualifiedFunctionName.of(functionDependency.name()), Arrays.stream(functionDependency.argumentTypes()) .map(signature -> parseTypeSignature(signature, literalParameters)) .collect(toImmutableList()), diff --git a/core/trino-main/src/main/java/io/trino/operator/annotations/LiteralImplementationDependency.java b/core/trino-main/src/main/java/io/trino/operator/annotations/LiteralImplementationDependency.java index 304b76883386..3e691c7e4534 100644 --- a/core/trino-main/src/main/java/io/trino/operator/annotations/LiteralImplementationDependency.java +++ b/core/trino-main/src/main/java/io/trino/operator/annotations/LiteralImplementationDependency.java @@ -14,8 +14,8 @@ package io.trino.operator.annotations; import io.trino.metadata.FunctionBinding; -import io.trino.metadata.FunctionDependencies; -import io.trino.metadata.FunctionDependencyDeclaration.FunctionDependencyDeclarationBuilder; +import io.trino.spi.function.FunctionDependencies; +import io.trino.spi.function.FunctionDependencyDeclaration.FunctionDependencyDeclarationBuilder; import static java.util.Objects.requireNonNull; diff --git a/core/trino-main/src/main/java/io/trino/operator/annotations/OperatorImplementationDependency.java b/core/trino-main/src/main/java/io/trino/operator/annotations/OperatorImplementationDependency.java index 13736365bd82..705be0b3cd4d 100644 --- a/core/trino-main/src/main/java/io/trino/operator/annotations/OperatorImplementationDependency.java +++ b/core/trino-main/src/main/java/io/trino/operator/annotations/OperatorImplementationDependency.java @@ -15,11 +15,11 @@ import com.google.common.collect.ImmutableList; import io.trino.metadata.FunctionBinding; -import io.trino.metadata.FunctionDependencies; -import io.trino.metadata.FunctionDependencyDeclaration.FunctionDependencyDeclarationBuilder; -import io.trino.metadata.FunctionInvoker; +import io.trino.spi.function.FunctionDependencies; +import io.trino.spi.function.FunctionDependencyDeclaration.FunctionDependencyDeclarationBuilder; import io.trino.spi.function.InvocationConvention; import io.trino.spi.function.OperatorType; +import io.trino.spi.function.ScalarFunctionImplementation; import io.trino.spi.type.TypeSignature; import java.util.List; @@ -62,10 +62,10 @@ public void declareDependencies(FunctionDependencyDeclarationBuilder builder) } @Override - protected FunctionInvoker getInvoker(FunctionBinding functionBinding, FunctionDependencies functionDependencies, InvocationConvention invocationConvention) + protected ScalarFunctionImplementation getImplementation(FunctionBinding functionBinding, FunctionDependencies functionDependencies, InvocationConvention invocationConvention) { List types = applyBoundVariables(argumentTypes, functionBinding); - return functionDependencies.getOperatorSignatureInvoker(operator, types, invocationConvention); + return functionDependencies.getOperatorImplementationSignature(operator, types, invocationConvention); } @Override diff --git a/core/trino-main/src/main/java/io/trino/operator/annotations/ScalarImplementationDependency.java b/core/trino-main/src/main/java/io/trino/operator/annotations/ScalarImplementationDependency.java index e3d79441f05b..83cdaf50e96d 100644 --- a/core/trino-main/src/main/java/io/trino/operator/annotations/ScalarImplementationDependency.java +++ b/core/trino-main/src/main/java/io/trino/operator/annotations/ScalarImplementationDependency.java @@ -14,10 +14,10 @@ package io.trino.operator.annotations; import io.trino.metadata.FunctionBinding; -import io.trino.metadata.FunctionDependencies; -import io.trino.metadata.FunctionInvoker; import io.trino.spi.connector.ConnectorSession; +import io.trino.spi.function.FunctionDependencies; import io.trino.spi.function.InvocationConvention; +import io.trino.spi.function.ScalarFunctionImplementation; import java.lang.invoke.MethodHandle; import java.lang.invoke.MethodHandleProxies; @@ -50,12 +50,12 @@ public Class getType() return type; } - protected abstract FunctionInvoker getInvoker(FunctionBinding functionBinding, FunctionDependencies functionDependencies, InvocationConvention invocationConvention); + protected abstract ScalarFunctionImplementation getImplementation(FunctionBinding functionBinding, FunctionDependencies functionDependencies, InvocationConvention invocationConvention); @Override public Object resolve(FunctionBinding functionBinding, FunctionDependencies functionDependencies) { - MethodHandle methodHandle = getInvoker(functionBinding, functionDependencies, invocationConvention).getMethodHandle(); + MethodHandle methodHandle = getImplementation(functionBinding, functionDependencies, invocationConvention).getMethodHandle(); if (invocationConvention.supportsSession() && !methodHandle.type().parameterType(0).equals(ConnectorSession.class)) { methodHandle = dropArguments(methodHandle, 0, ConnectorSession.class); } diff --git a/core/trino-main/src/main/java/io/trino/operator/annotations/TypeImplementationDependency.java b/core/trino-main/src/main/java/io/trino/operator/annotations/TypeImplementationDependency.java index f44a70bc6ce7..26a15cb5726e 100644 --- a/core/trino-main/src/main/java/io/trino/operator/annotations/TypeImplementationDependency.java +++ b/core/trino-main/src/main/java/io/trino/operator/annotations/TypeImplementationDependency.java @@ -14,8 +14,8 @@ package io.trino.operator.annotations; import io.trino.metadata.FunctionBinding; -import io.trino.metadata.FunctionDependencies; -import io.trino.metadata.FunctionDependencyDeclaration.FunctionDependencyDeclarationBuilder; +import io.trino.spi.function.FunctionDependencies; +import io.trino.spi.function.FunctionDependencyDeclaration.FunctionDependencyDeclarationBuilder; import io.trino.spi.type.TypeSignature; import java.util.Objects; diff --git a/core/trino-main/src/main/java/io/trino/operator/scalar/AbstractGreatestLeast.java b/core/trino-main/src/main/java/io/trino/operator/scalar/AbstractGreatestLeast.java index 5e76fd2eb48a..057779be9e66 100644 --- a/core/trino-main/src/main/java/io/trino/operator/scalar/AbstractGreatestLeast.java +++ b/core/trino-main/src/main/java/io/trino/operator/scalar/AbstractGreatestLeast.java @@ -24,12 +24,12 @@ import io.airlift.bytecode.control.IfStatement; import io.airlift.bytecode.expression.BytecodeExpression; import io.airlift.bytecode.instruction.LabelNode; -import io.trino.metadata.BoundSignature; -import io.trino.metadata.FunctionDependencies; -import io.trino.metadata.FunctionDependencyDeclaration; -import io.trino.metadata.FunctionMetadata; -import io.trino.metadata.Signature; import io.trino.metadata.SqlScalarFunction; +import io.trino.spi.function.BoundSignature; +import io.trino.spi.function.FunctionDependencies; +import io.trino.spi.function.FunctionDependencyDeclaration; +import io.trino.spi.function.FunctionMetadata; +import io.trino.spi.function.Signature; import io.trino.spi.type.Type; import io.trino.spi.type.TypeSignature; import io.trino.sql.gen.CallSiteBinder; @@ -98,7 +98,7 @@ public FunctionDependencyDeclaration getFunctionDependencies() } @Override - public ScalarFunctionImplementation specialize(BoundSignature boundSignature, FunctionDependencies functionDependencies) + public SpecializedSqlScalarFunction specialize(BoundSignature boundSignature, FunctionDependencies functionDependencies) { Type type = boundSignature.getReturnType(); checkArgument(type.isOrderable(), "Type must be orderable"); @@ -112,7 +112,7 @@ public ScalarFunctionImplementation specialize(BoundSignature boundSignature, Fu Class clazz = generate(javaTypes, compareMethod); MethodHandle methodHandle = methodHandle(clazz, getFunctionMetadata().getSignature().getName(), javaTypes.toArray(new Class[0])); - return new ChoicesScalarFunctionImplementation( + return new ChoicesSpecializedSqlScalarFunction( boundSignature, NULLABLE_RETURN, nCopies(javaTypes.size(), BOXED_NULLABLE), diff --git a/core/trino-main/src/main/java/io/trino/operator/scalar/ApplyFunction.java b/core/trino-main/src/main/java/io/trino/operator/scalar/ApplyFunction.java index b54cbdf69042..6ffc64c3b3c8 100644 --- a/core/trino-main/src/main/java/io/trino/operator/scalar/ApplyFunction.java +++ b/core/trino-main/src/main/java/io/trino/operator/scalar/ApplyFunction.java @@ -16,10 +16,10 @@ import com.google.common.collect.ImmutableList; import com.google.common.primitives.Primitives; import io.trino.annotation.UsedByGeneratedCode; -import io.trino.metadata.BoundSignature; -import io.trino.metadata.FunctionMetadata; -import io.trino.metadata.Signature; import io.trino.metadata.SqlScalarFunction; +import io.trino.spi.function.BoundSignature; +import io.trino.spi.function.FunctionMetadata; +import io.trino.spi.function.Signature; import io.trino.spi.type.Type; import io.trino.spi.type.TypeSignature; import io.trino.sql.gen.lambda.UnaryFunctionInterface; @@ -62,11 +62,11 @@ private ApplyFunction() } @Override - protected ScalarFunctionImplementation specialize(BoundSignature boundSignature) + protected SpecializedSqlScalarFunction specialize(BoundSignature boundSignature) { Type argumentType = boundSignature.getArgumentTypes().get(0); Type returnType = boundSignature.getReturnType(); - return new ChoicesScalarFunctionImplementation( + return new ChoicesSpecializedSqlScalarFunction( boundSignature, NULLABLE_RETURN, ImmutableList.of(BOXED_NULLABLE, FUNCTION), diff --git a/core/trino-main/src/main/java/io/trino/operator/scalar/ArrayConcatFunction.java b/core/trino-main/src/main/java/io/trino/operator/scalar/ArrayConcatFunction.java index a7e3640854f2..86d42d8b0501 100644 --- a/core/trino-main/src/main/java/io/trino/operator/scalar/ArrayConcatFunction.java +++ b/core/trino-main/src/main/java/io/trino/operator/scalar/ArrayConcatFunction.java @@ -15,14 +15,14 @@ import com.google.common.collect.ImmutableList; import io.trino.annotation.UsedByGeneratedCode; -import io.trino.metadata.BoundSignature; -import io.trino.metadata.FunctionMetadata; -import io.trino.metadata.Signature; import io.trino.metadata.SqlScalarFunction; import io.trino.spi.PageBuilder; import io.trino.spi.TrinoException; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.function.BoundSignature; +import io.trino.spi.function.FunctionMetadata; +import io.trino.spi.function.Signature; import io.trino.spi.type.ArrayType; import io.trino.spi.type.Type; import io.trino.spi.type.TypeSignature; @@ -65,7 +65,7 @@ private ArrayConcatFunction() } @Override - protected ScalarFunctionImplementation specialize(BoundSignature boundSignature) + protected SpecializedSqlScalarFunction specialize(BoundSignature boundSignature) { if (boundSignature.getArity() < 2) { throw new TrinoException(INVALID_FUNCTION_ARGUMENT, "There must be two or more arguments to " + FUNCTION_NAME); @@ -80,7 +80,7 @@ protected ScalarFunctionImplementation specialize(BoundSignature boundSignature) METHOD_HANDLE.bindTo(arrayType.getElementType()), USER_STATE_FACTORY.bindTo(arrayType.getElementType())); - return new ChoicesScalarFunctionImplementation( + return new ChoicesSpecializedSqlScalarFunction( boundSignature, FAIL_ON_NULL, nCopies(boundSignature.getArity(), NEVER_NULL), diff --git a/core/trino-main/src/main/java/io/trino/operator/scalar/ArrayConstructor.java b/core/trino-main/src/main/java/io/trino/operator/scalar/ArrayConstructor.java index 950549882716..6b8a3676f026 100644 --- a/core/trino-main/src/main/java/io/trino/operator/scalar/ArrayConstructor.java +++ b/core/trino-main/src/main/java/io/trino/operator/scalar/ArrayConstructor.java @@ -25,13 +25,13 @@ import io.airlift.bytecode.Variable; import io.airlift.bytecode.control.IfStatement; import io.airlift.bytecode.expression.BytecodeExpression; -import io.trino.metadata.BoundSignature; -import io.trino.metadata.FunctionMetadata; -import io.trino.metadata.Signature; import io.trino.metadata.SqlScalarFunction; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; import io.trino.spi.block.BlockBuilderStatus; +import io.trino.spi.function.BoundSignature; +import io.trino.spi.function.FunctionMetadata; +import io.trino.spi.function.Signature; import io.trino.spi.type.Type; import io.trino.spi.type.TypeSignature; import io.trino.sql.gen.CallSiteBinder; @@ -86,7 +86,7 @@ public ArrayConstructor() } @Override - protected ScalarFunctionImplementation specialize(BoundSignature boundSignature) + protected SpecializedSqlScalarFunction specialize(BoundSignature boundSignature) { ImmutableList.Builder> builder = ImmutableList.builder(); Type type = boundSignature.getArgumentTypes().get(0); @@ -108,7 +108,7 @@ protected ScalarFunctionImplementation specialize(BoundSignature boundSignature) catch (ReflectiveOperationException e) { throw new RuntimeException(e); } - return new ChoicesScalarFunctionImplementation( + return new ChoicesSpecializedSqlScalarFunction( boundSignature, FAIL_ON_NULL, nCopies(stackTypes.size(), BOXED_NULLABLE), diff --git a/core/trino-main/src/main/java/io/trino/operator/scalar/ArrayFlattenFunction.java b/core/trino-main/src/main/java/io/trino/operator/scalar/ArrayFlattenFunction.java index eb27cd66a3ac..323212978b73 100644 --- a/core/trino-main/src/main/java/io/trino/operator/scalar/ArrayFlattenFunction.java +++ b/core/trino-main/src/main/java/io/trino/operator/scalar/ArrayFlattenFunction.java @@ -14,12 +14,12 @@ package io.trino.operator.scalar; import com.google.common.collect.ImmutableList; -import io.trino.metadata.BoundSignature; -import io.trino.metadata.FunctionMetadata; -import io.trino.metadata.Signature; import io.trino.metadata.SqlScalarFunction; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.function.BoundSignature; +import io.trino.spi.function.FunctionMetadata; +import io.trino.spi.function.Signature; import io.trino.spi.type.ArrayType; import io.trino.spi.type.Type; import io.trino.spi.type.TypeSignature; @@ -53,13 +53,13 @@ private ArrayFlattenFunction() } @Override - protected ScalarFunctionImplementation specialize(BoundSignature boundSignature) + protected SpecializedSqlScalarFunction specialize(BoundSignature boundSignature) { ArrayType arrayType = (ArrayType) boundSignature.getReturnType(); MethodHandle methodHandle = METHOD_HANDLE .bindTo(arrayType.getElementType()) .bindTo(arrayType); - return new ChoicesScalarFunctionImplementation( + return new ChoicesSpecializedSqlScalarFunction( boundSignature, FAIL_ON_NULL, ImmutableList.of(NEVER_NULL), diff --git a/core/trino-main/src/main/java/io/trino/operator/scalar/ArrayJoin.java b/core/trino-main/src/main/java/io/trino/operator/scalar/ArrayJoin.java index cd0c03f3bf54..30fb04145069 100644 --- a/core/trino-main/src/main/java/io/trino/operator/scalar/ArrayJoin.java +++ b/core/trino-main/src/main/java/io/trino/operator/scalar/ArrayJoin.java @@ -16,19 +16,19 @@ import com.google.common.collect.ImmutableList; import io.airlift.slice.Slice; import io.trino.annotation.UsedByGeneratedCode; -import io.trino.metadata.BoundSignature; -import io.trino.metadata.FunctionDependencies; -import io.trino.metadata.FunctionDependencyDeclaration; -import io.trino.metadata.FunctionMetadata; -import io.trino.metadata.Signature; import io.trino.metadata.SqlScalarFunction; import io.trino.spi.PageBuilder; import io.trino.spi.TrinoException; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; import io.trino.spi.connector.ConnectorSession; +import io.trino.spi.function.BoundSignature; +import io.trino.spi.function.FunctionDependencies; +import io.trino.spi.function.FunctionDependencyDeclaration; +import io.trino.spi.function.FunctionMetadata; import io.trino.spi.function.InvocationConvention; import io.trino.spi.function.InvocationConvention.InvocationArgumentConvention; +import io.trino.spi.function.Signature; import io.trino.spi.type.ArrayType; import io.trino.spi.type.Type; import io.trino.spi.type.TypeSignature; @@ -105,7 +105,7 @@ public FunctionDependencyDeclaration getFunctionDependencies() } @Override - public ScalarFunctionImplementation specialize(BoundSignature boundSignature, FunctionDependencies functionDependencies) + public SpecializedSqlScalarFunction specialize(BoundSignature boundSignature, FunctionDependencies functionDependencies) { return specializeArrayJoin(boundSignature, functionDependencies, METHOD_HANDLE); } @@ -145,12 +145,12 @@ private static FunctionDependencyDeclaration arrayJoinFunctionDependencies() } @Override - public ScalarFunctionImplementation specialize(BoundSignature boundSignature, FunctionDependencies functionDependencies) + public SpecializedSqlScalarFunction specialize(BoundSignature boundSignature, FunctionDependencies functionDependencies) { return specializeArrayJoin(boundSignature, functionDependencies, METHOD_HANDLE); } - private static ChoicesScalarFunctionImplementation specializeArrayJoin( + private static ChoicesSpecializedSqlScalarFunction specializeArrayJoin( BoundSignature boundSignature, FunctionDependencies functionDependencies, MethodHandle methodHandle) @@ -159,7 +159,7 @@ private static ChoicesScalarFunctionImplementation specializeArrayJoin( Type type = ((ArrayType) boundSignature.getArgumentTypes().get(0)).getElementType(); if (type instanceof UnknownType) { - return new ChoicesScalarFunctionImplementation( + return new ChoicesSpecializedSqlScalarFunction( boundSignature, FAIL_ON_NULL, argumentConventions, @@ -169,7 +169,7 @@ private static ChoicesScalarFunctionImplementation specializeArrayJoin( else { try { InvocationConvention convention = new InvocationConvention(ImmutableList.of(BLOCK_POSITION), NULLABLE_RETURN, true, false); - MethodHandle cast = functionDependencies.getCastInvoker(type, VARCHAR, convention).getMethodHandle(); + MethodHandle cast = functionDependencies.getCastImplementation(type, VARCHAR, convention).getMethodHandle(); // if the cast doesn't take a ConnectorSession, create an adapter that drops the provided session if (cast.type().parameterArray()[0] != ConnectorSession.class) { @@ -177,7 +177,7 @@ private static ChoicesScalarFunctionImplementation specializeArrayJoin( } MethodHandle target = MethodHandles.insertArguments(methodHandle, 0, cast); - return new ChoicesScalarFunctionImplementation( + return new ChoicesSpecializedSqlScalarFunction( boundSignature, FAIL_ON_NULL, argumentConventions, diff --git a/core/trino-main/src/main/java/io/trino/operator/scalar/ArrayReduceFunction.java b/core/trino-main/src/main/java/io/trino/operator/scalar/ArrayReduceFunction.java index a3e2b9e82a9c..685d1be86576 100644 --- a/core/trino-main/src/main/java/io/trino/operator/scalar/ArrayReduceFunction.java +++ b/core/trino-main/src/main/java/io/trino/operator/scalar/ArrayReduceFunction.java @@ -15,11 +15,11 @@ import com.google.common.collect.ImmutableList; import com.google.common.primitives.Primitives; -import io.trino.metadata.BoundSignature; -import io.trino.metadata.FunctionMetadata; -import io.trino.metadata.Signature; import io.trino.metadata.SqlScalarFunction; import io.trino.spi.block.Block; +import io.trino.spi.function.BoundSignature; +import io.trino.spi.function.FunctionMetadata; +import io.trino.spi.function.Signature; import io.trino.spi.type.ArrayType; import io.trino.spi.type.Type; import io.trino.spi.type.TypeSignature; @@ -67,14 +67,14 @@ private ArrayReduceFunction() } @Override - protected ScalarFunctionImplementation specialize(BoundSignature boundSignature) + protected SpecializedSqlScalarFunction specialize(BoundSignature boundSignature) { ArrayType arrayType = (ArrayType) boundSignature.getArgumentTypes().get(0); Type inputType = arrayType.getElementType(); Type intermediateType = boundSignature.getArgumentTypes().get(1); Type outputType = boundSignature.getReturnType(); MethodHandle methodHandle = METHOD_HANDLE.bindTo(inputType); - return new ChoicesScalarFunctionImplementation( + return new ChoicesSpecializedSqlScalarFunction( boundSignature, NULLABLE_RETURN, ImmutableList.of(NEVER_NULL, BOXED_NULLABLE, FUNCTION, FUNCTION), diff --git a/core/trino-main/src/main/java/io/trino/operator/scalar/ArraySubscriptOperator.java b/core/trino-main/src/main/java/io/trino/operator/scalar/ArraySubscriptOperator.java index 14caaab79aae..07019bfff6ee 100644 --- a/core/trino-main/src/main/java/io/trino/operator/scalar/ArraySubscriptOperator.java +++ b/core/trino-main/src/main/java/io/trino/operator/scalar/ArraySubscriptOperator.java @@ -16,12 +16,12 @@ import com.google.common.collect.ImmutableList; import io.airlift.slice.Slice; import io.trino.annotation.UsedByGeneratedCode; -import io.trino.metadata.BoundSignature; -import io.trino.metadata.FunctionMetadata; -import io.trino.metadata.Signature; import io.trino.metadata.SqlScalarFunction; import io.trino.spi.TrinoException; import io.trino.spi.block.Block; +import io.trino.spi.function.BoundSignature; +import io.trino.spi.function.FunctionMetadata; +import io.trino.spi.function.Signature; import io.trino.spi.type.Type; import io.trino.spi.type.TypeSignature; @@ -64,7 +64,7 @@ protected ArraySubscriptOperator() } @Override - protected ScalarFunctionImplementation specialize(BoundSignature boundSignature) + protected SpecializedSqlScalarFunction specialize(BoundSignature boundSignature) { Type elementType = boundSignature.getReturnType(); @@ -87,7 +87,7 @@ else if (elementType.getJavaType() == Slice.class) { } methodHandle = methodHandle.bindTo(elementType); requireNonNull(methodHandle, "methodHandle is null"); - return new ChoicesScalarFunctionImplementation( + return new ChoicesSpecializedSqlScalarFunction( boundSignature, NULLABLE_RETURN, ImmutableList.of(NEVER_NULL, NEVER_NULL), diff --git a/core/trino-main/src/main/java/io/trino/operator/scalar/ArrayToElementConcatFunction.java b/core/trino-main/src/main/java/io/trino/operator/scalar/ArrayToElementConcatFunction.java index 4051e249b50a..2ab487ee2fad 100644 --- a/core/trino-main/src/main/java/io/trino/operator/scalar/ArrayToElementConcatFunction.java +++ b/core/trino-main/src/main/java/io/trino/operator/scalar/ArrayToElementConcatFunction.java @@ -15,11 +15,11 @@ import com.google.common.collect.ImmutableList; import io.airlift.slice.Slice; -import io.trino.metadata.BoundSignature; -import io.trino.metadata.FunctionMetadata; -import io.trino.metadata.Signature; import io.trino.metadata.SqlScalarFunction; import io.trino.spi.block.Block; +import io.trino.spi.function.BoundSignature; +import io.trino.spi.function.FunctionMetadata; +import io.trino.spi.function.Signature; import io.trino.spi.type.Type; import io.trino.spi.type.TypeSignature; @@ -57,7 +57,7 @@ public ArrayToElementConcatFunction() } @Override - protected ScalarFunctionImplementation specialize(BoundSignature boundSignature) + protected SpecializedSqlScalarFunction specialize(BoundSignature boundSignature) { Type type = boundSignature.getArgumentTypes().get(1); MethodHandle methodHandle; @@ -78,7 +78,7 @@ else if (type.getJavaType() == Slice.class) { } methodHandle = methodHandle.bindTo(type); - return new ChoicesScalarFunctionImplementation( + return new ChoicesSpecializedSqlScalarFunction( boundSignature, FAIL_ON_NULL, ImmutableList.of(NEVER_NULL, NEVER_NULL), diff --git a/core/trino-main/src/main/java/io/trino/operator/scalar/ArrayToJsonCast.java b/core/trino-main/src/main/java/io/trino/operator/scalar/ArrayToJsonCast.java index d3a6ff4006ef..9dcc1f25165f 100644 --- a/core/trino-main/src/main/java/io/trino/operator/scalar/ArrayToJsonCast.java +++ b/core/trino-main/src/main/java/io/trino/operator/scalar/ArrayToJsonCast.java @@ -18,11 +18,11 @@ import io.airlift.slice.DynamicSliceOutput; import io.airlift.slice.Slice; import io.airlift.slice.SliceOutput; -import io.trino.metadata.BoundSignature; -import io.trino.metadata.FunctionMetadata; -import io.trino.metadata.Signature; import io.trino.metadata.SqlScalarFunction; import io.trino.spi.block.Block; +import io.trino.spi.function.BoundSignature; +import io.trino.spi.function.FunctionMetadata; +import io.trino.spi.function.Signature; import io.trino.spi.type.ArrayType; import io.trino.spi.type.TypeSignature; import io.trino.util.JsonUtil.JsonGeneratorWriter; @@ -66,14 +66,14 @@ private ArrayToJsonCast(boolean legacyRowToJson) } @Override - protected ScalarFunctionImplementation specialize(BoundSignature boundSignature) + protected SpecializedSqlScalarFunction specialize(BoundSignature boundSignature) { ArrayType arrayType = (ArrayType) boundSignature.getArgumentTypes().get(0); checkCondition(canCastToJson(arrayType), INVALID_CAST_ARGUMENT, "Cannot cast %s to JSON", arrayType); JsonGeneratorWriter writer = JsonGeneratorWriter.createJsonGeneratorWriter(arrayType.getElementType(), legacyRowToJson); MethodHandle methodHandle = METHOD_HANDLE.bindTo(writer); - return new ChoicesScalarFunctionImplementation( + return new ChoicesSpecializedSqlScalarFunction( boundSignature, FAIL_ON_NULL, ImmutableList.of(NEVER_NULL), diff --git a/core/trino-main/src/main/java/io/trino/operator/scalar/ArrayTransformFunction.java b/core/trino-main/src/main/java/io/trino/operator/scalar/ArrayTransformFunction.java index 0cbf1e8ac264..710b532a2574 100644 --- a/core/trino-main/src/main/java/io/trino/operator/scalar/ArrayTransformFunction.java +++ b/core/trino-main/src/main/java/io/trino/operator/scalar/ArrayTransformFunction.java @@ -24,13 +24,13 @@ import io.airlift.bytecode.Variable; import io.airlift.bytecode.control.ForLoop; import io.airlift.bytecode.control.IfStatement; -import io.trino.metadata.BoundSignature; -import io.trino.metadata.FunctionMetadata; -import io.trino.metadata.Signature; import io.trino.metadata.SqlScalarFunction; import io.trino.spi.PageBuilder; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.function.BoundSignature; +import io.trino.spi.function.FunctionMetadata; +import io.trino.spi.function.Signature; import io.trino.spi.type.ArrayType; import io.trino.spi.type.Type; import io.trino.spi.type.TypeSignature; @@ -87,12 +87,12 @@ private ArrayTransformFunction() } @Override - protected ScalarFunctionImplementation specialize(BoundSignature boundSignature) + protected SpecializedSqlScalarFunction specialize(BoundSignature boundSignature) { Type inputType = ((ArrayType) boundSignature.getArgumentTypes().get(0)).getElementType(); Type outputType = ((ArrayType) boundSignature.getReturnType()).getElementType(); Class generatedClass = generateTransform(inputType, outputType); - return new ChoicesScalarFunctionImplementation( + return new ChoicesSpecializedSqlScalarFunction( boundSignature, FAIL_ON_NULL, ImmutableList.of(NEVER_NULL, FUNCTION), diff --git a/core/trino-main/src/main/java/io/trino/operator/scalar/CastFromUnknownOperator.java b/core/trino-main/src/main/java/io/trino/operator/scalar/CastFromUnknownOperator.java index 173ace9ce167..007c74406a66 100644 --- a/core/trino-main/src/main/java/io/trino/operator/scalar/CastFromUnknownOperator.java +++ b/core/trino-main/src/main/java/io/trino/operator/scalar/CastFromUnknownOperator.java @@ -15,10 +15,10 @@ import com.google.common.collect.ImmutableList; import io.trino.annotation.UsedByGeneratedCode; -import io.trino.metadata.BoundSignature; -import io.trino.metadata.FunctionMetadata; -import io.trino.metadata.Signature; import io.trino.metadata.SqlScalarFunction; +import io.trino.spi.function.BoundSignature; +import io.trino.spi.function.FunctionMetadata; +import io.trino.spi.function.Signature; import io.trino.spi.type.Type; import io.trino.spi.type.TypeSignature; @@ -49,11 +49,11 @@ public CastFromUnknownOperator() } @Override - protected ScalarFunctionImplementation specialize(BoundSignature boundSignature) + protected SpecializedSqlScalarFunction specialize(BoundSignature boundSignature) { Type toType = boundSignature.getReturnType(); MethodHandle methodHandle = METHOD_HANDLE_NON_NULL.asType(METHOD_HANDLE_NON_NULL.type().changeReturnType(toType.getJavaType())); - return new ChoicesScalarFunctionImplementation( + return new ChoicesSpecializedSqlScalarFunction( boundSignature, FAIL_ON_NULL, ImmutableList.of(NEVER_NULL), diff --git a/core/trino-main/src/main/java/io/trino/operator/scalar/ChoicesScalarFunctionImplementation.java b/core/trino-main/src/main/java/io/trino/operator/scalar/ChoicesSpecializedSqlScalarFunction.java similarity index 91% rename from core/trino-main/src/main/java/io/trino/operator/scalar/ChoicesScalarFunctionImplementation.java rename to core/trino-main/src/main/java/io/trino/operator/scalar/ChoicesSpecializedSqlScalarFunction.java index 594432a16498..8423edb014a3 100644 --- a/core/trino-main/src/main/java/io/trino/operator/scalar/ChoicesScalarFunctionImplementation.java +++ b/core/trino-main/src/main/java/io/trino/operator/scalar/ChoicesSpecializedSqlScalarFunction.java @@ -15,14 +15,14 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableList; -import io.trino.metadata.BoundSignature; -import io.trino.metadata.FunctionInvoker; import io.trino.spi.TrinoException; import io.trino.spi.connector.ConnectorSession; +import io.trino.spi.function.BoundSignature; import io.trino.spi.function.InvocationConvention; import io.trino.spi.function.InvocationConvention.InvocationArgumentConvention; import io.trino.spi.function.InvocationConvention.InvocationReturnConvention; import io.trino.spi.function.ScalarFunctionAdapter; +import io.trino.spi.function.ScalarFunctionImplementation; import java.lang.invoke.MethodHandle; import java.util.ArrayList; @@ -37,15 +37,15 @@ import static java.util.Comparator.comparingInt; import static java.util.Objects.requireNonNull; -public final class ChoicesScalarFunctionImplementation - implements ScalarFunctionImplementation +public final class ChoicesSpecializedSqlScalarFunction + implements SpecializedSqlScalarFunction { private final ScalarFunctionAdapter functionAdapter = new ScalarFunctionAdapter(RETURN_NULL_ON_NULL); private final BoundSignature boundSignature; private final List choices; - public ChoicesScalarFunctionImplementation( + public ChoicesSpecializedSqlScalarFunction( BoundSignature boundSignature, InvocationReturnConvention returnConvention, List argumentConventions, @@ -54,7 +54,7 @@ public ChoicesScalarFunctionImplementation( this(boundSignature, returnConvention, argumentConventions, ImmutableList.of(), methodHandle, Optional.empty()); } - public ChoicesScalarFunctionImplementation( + public ChoicesSpecializedSqlScalarFunction( BoundSignature boundSignature, InvocationReturnConvention returnConvention, List argumentConventions, @@ -64,7 +64,7 @@ public ChoicesScalarFunctionImplementation( this(boundSignature, returnConvention, argumentConventions, ImmutableList.of(), methodHandle, instanceFactory); } - public ChoicesScalarFunctionImplementation( + public ChoicesSpecializedSqlScalarFunction( BoundSignature boundSignature, InvocationReturnConvention returnConvention, List argumentConventions, @@ -85,7 +85,7 @@ public ChoicesScalarFunctionImplementation( * @param boundSignature * @param choices the list of choices, ordered from generic to specific */ - public ChoicesScalarFunctionImplementation(BoundSignature boundSignature, List choices) + public ChoicesSpecializedSqlScalarFunction(BoundSignature boundSignature, List choices) { this.boundSignature = boundSignature; checkArgument(!choices.isEmpty(), "choices is an empty list"); @@ -99,7 +99,7 @@ public List getChoices() } @Override - public FunctionInvoker getScalarFunctionInvoker(InvocationConvention invocationConvention) + public ScalarFunctionImplementation getScalarFunctionImplementation(InvocationConvention invocationConvention) { List choices = new ArrayList<>(); for (ScalarImplementationChoice choice : this.choices) { @@ -119,10 +119,11 @@ public FunctionInvoker getScalarFunctionInvoker(InvocationConvention invocationC boundSignature.getArgumentTypes(), bestChoice.getInvocationConvention(), invocationConvention); - return new FunctionInvoker( - methodHandle, - bestChoice.getInstanceFactory(), - bestChoice.getLambdaInterfaces()); + ScalarFunctionImplementation.Builder builder = ScalarFunctionImplementation.builder() + .methodHandle(methodHandle); + bestChoice.getInstanceFactory().ifPresent(builder::instanceFactory); + builder.lambdaInterfaces(bestChoice.getLambdaInterfaces()); + return builder.build(); } public static class ScalarImplementationChoice diff --git a/core/trino-main/src/main/java/io/trino/operator/scalar/ConcatFunction.java b/core/trino-main/src/main/java/io/trino/operator/scalar/ConcatFunction.java index bb811bee466d..1f05a7651430 100644 --- a/core/trino-main/src/main/java/io/trino/operator/scalar/ConcatFunction.java +++ b/core/trino-main/src/main/java/io/trino/operator/scalar/ConcatFunction.java @@ -15,11 +15,11 @@ import io.airlift.slice.Slice; import io.airlift.slice.Slices; -import io.trino.metadata.BoundSignature; -import io.trino.metadata.FunctionMetadata; -import io.trino.metadata.Signature; import io.trino.metadata.SqlScalarFunction; import io.trino.spi.TrinoException; +import io.trino.spi.function.BoundSignature; +import io.trino.spi.function.FunctionMetadata; +import io.trino.spi.function.Signature; import io.trino.spi.type.TypeSignature; import java.lang.invoke.MethodHandle; @@ -60,7 +60,7 @@ private ConcatFunction(TypeSignature type, String description) } @Override - protected ScalarFunctionImplementation specialize(BoundSignature boundSignature) + protected SpecializedSqlScalarFunction specialize(BoundSignature boundSignature) { int arity = boundSignature.getArity(); @@ -75,7 +75,7 @@ protected ScalarFunctionImplementation specialize(BoundSignature boundSignature) MethodHandle arrayMethodHandle = methodHandle(ConcatFunction.class, "concat", Slice[].class); MethodHandle customMethodHandle = arrayMethodHandle.asCollector(Slice[].class, arity); - return new ChoicesScalarFunctionImplementation( + return new ChoicesSpecializedSqlScalarFunction( boundSignature, FAIL_ON_NULL, nCopies(arity, NEVER_NULL), diff --git a/core/trino-main/src/main/java/io/trino/operator/scalar/ConcatWsFunction.java b/core/trino-main/src/main/java/io/trino/operator/scalar/ConcatWsFunction.java index 9945a8c0d11d..2b87b118580d 100644 --- a/core/trino-main/src/main/java/io/trino/operator/scalar/ConcatWsFunction.java +++ b/core/trino-main/src/main/java/io/trino/operator/scalar/ConcatWsFunction.java @@ -17,14 +17,14 @@ import io.airlift.slice.Slice; import io.airlift.slice.Slices; import io.trino.annotation.UsedByGeneratedCode; -import io.trino.metadata.BoundSignature; -import io.trino.metadata.FunctionMetadata; -import io.trino.metadata.Signature; import io.trino.metadata.SqlScalarFunction; import io.trino.spi.TrinoException; import io.trino.spi.block.Block; +import io.trino.spi.function.BoundSignature; +import io.trino.spi.function.FunctionMetadata; import io.trino.spi.function.InvocationConvention; import io.trino.spi.function.ScalarFunction; +import io.trino.spi.function.Signature; import io.trino.spi.function.SqlType; import java.lang.invoke.MethodHandle; @@ -105,7 +105,7 @@ public ConcatWsFunction() } @Override - public ScalarFunctionImplementation specialize(BoundSignature boundSignature) + protected SpecializedSqlScalarFunction specialize(BoundSignature boundSignature) { int valueCount = boundSignature.getArity() - 1; if (valueCount < 1) { @@ -115,7 +115,7 @@ public ScalarFunctionImplementation specialize(BoundSignature boundSignature) MethodHandle arrayMethodHandle = methodHandle(ConcatWsFunction.class, "concatWs", Slice.class, Slice[].class); MethodHandle customMethodHandle = arrayMethodHandle.asCollector(Slice[].class, valueCount); - return new ChoicesScalarFunctionImplementation( + return new ChoicesSpecializedSqlScalarFunction( boundSignature, FAIL_ON_NULL, ImmutableList.builder() diff --git a/core/trino-main/src/main/java/io/trino/operator/scalar/ElementToArrayConcatFunction.java b/core/trino-main/src/main/java/io/trino/operator/scalar/ElementToArrayConcatFunction.java index 37a3c70b9767..495adc56091d 100644 --- a/core/trino-main/src/main/java/io/trino/operator/scalar/ElementToArrayConcatFunction.java +++ b/core/trino-main/src/main/java/io/trino/operator/scalar/ElementToArrayConcatFunction.java @@ -15,11 +15,11 @@ import com.google.common.collect.ImmutableList; import io.airlift.slice.Slice; -import io.trino.metadata.BoundSignature; -import io.trino.metadata.FunctionMetadata; -import io.trino.metadata.Signature; import io.trino.metadata.SqlScalarFunction; import io.trino.spi.block.Block; +import io.trino.spi.function.BoundSignature; +import io.trino.spi.function.FunctionMetadata; +import io.trino.spi.function.Signature; import io.trino.spi.type.Type; import io.trino.spi.type.TypeSignature; @@ -57,7 +57,7 @@ public ElementToArrayConcatFunction() } @Override - protected ScalarFunctionImplementation specialize(BoundSignature boundSignature) + protected SpecializedSqlScalarFunction specialize(BoundSignature boundSignature) { Type type = boundSignature.getArgumentTypes().get(0); MethodHandle methodHandle; @@ -78,7 +78,7 @@ else if (type.getJavaType() == Slice.class) { } methodHandle = methodHandle.bindTo(type); - return new ChoicesScalarFunctionImplementation( + return new ChoicesSpecializedSqlScalarFunction( boundSignature, FAIL_ON_NULL, ImmutableList.of(NEVER_NULL, NEVER_NULL), diff --git a/core/trino-main/src/main/java/io/trino/operator/scalar/FormatFunction.java b/core/trino-main/src/main/java/io/trino/operator/scalar/FormatFunction.java index 3015de28ad65..f83bef1aea42 100644 --- a/core/trino-main/src/main/java/io/trino/operator/scalar/FormatFunction.java +++ b/core/trino-main/src/main/java/io/trino/operator/scalar/FormatFunction.java @@ -16,16 +16,17 @@ import com.google.common.collect.ImmutableList; import io.airlift.slice.Slice; import io.trino.annotation.UsedByGeneratedCode; -import io.trino.metadata.BoundSignature; -import io.trino.metadata.FunctionDependencies; -import io.trino.metadata.FunctionDependencyDeclaration; -import io.trino.metadata.FunctionDependencyDeclaration.FunctionDependencyDeclarationBuilder; -import io.trino.metadata.FunctionMetadata; -import io.trino.metadata.Signature; import io.trino.metadata.SqlScalarFunction; import io.trino.spi.TrinoException; import io.trino.spi.block.Block; import io.trino.spi.connector.ConnectorSession; +import io.trino.spi.function.BoundSignature; +import io.trino.spi.function.FunctionDependencies; +import io.trino.spi.function.FunctionDependencyDeclaration; +import io.trino.spi.function.FunctionDependencyDeclaration.FunctionDependencyDeclarationBuilder; +import io.trino.spi.function.FunctionMetadata; +import io.trino.spi.function.QualifiedFunctionName; +import io.trino.spi.function.Signature; import io.trino.spi.type.CharType; import io.trino.spi.type.DecimalType; import io.trino.spi.type.Int128; @@ -35,7 +36,6 @@ import io.trino.spi.type.Type; import io.trino.spi.type.TypeSignature; import io.trino.spi.type.VarcharType; -import io.trino.sql.tree.QualifiedName; import java.lang.invoke.MethodHandle; import java.math.BigDecimal; @@ -130,14 +130,14 @@ private static void addDependencies(FunctionDependencyDeclarationBuilder builder } if (type.equals(JSON)) { - builder.addFunction(QualifiedName.of("json_format"), ImmutableList.of(JSON)); + builder.addFunction(QualifiedFunctionName.of("json_format"), ImmutableList.of(JSON)); return; } builder.addCast(type, VARCHAR); } @Override - public ScalarFunctionImplementation specialize(BoundSignature boundSignature, FunctionDependencies functionDependencies) + public SpecializedSqlScalarFunction specialize(BoundSignature boundSignature, FunctionDependencies functionDependencies) { Type rowType = boundSignature.getArgumentType(1); @@ -146,7 +146,7 @@ public ScalarFunctionImplementation specialize(BoundSignature boundSignature, Fu (type, index) -> converter(functionDependencies, type, toIntExact(index))) .collect(toImmutableList()); - return new ChoicesScalarFunctionImplementation( + return new ChoicesSpecializedSqlScalarFunction( boundSignature, FAIL_ON_NULL, ImmutableList.of(NEVER_NULL, NEVER_NULL), @@ -212,7 +212,7 @@ private static BiFunction valueConverter(Functi } // TODO: support TIME WITH TIME ZONE by https://github.com/trinodb/trino/issues/191 + mapping to java.time.OffsetTime if (type.equals(JSON)) { - MethodHandle handle = functionDependencies.getFunctionInvoker(QualifiedName.of("json_format"), ImmutableList.of(JSON), simpleConvention(FAIL_ON_NULL, NEVER_NULL)).getMethodHandle(); + MethodHandle handle = functionDependencies.getScalarFunctionImplementation(QualifiedFunctionName.of("json_format"), ImmutableList.of(JSON), simpleConvention(FAIL_ON_NULL, NEVER_NULL)).getMethodHandle(); return (session, block) -> convertToString(handle, type.getSlice(block, position)); } if (isShortDecimal(type)) { @@ -248,7 +248,7 @@ else if (type.getJavaType() == Slice.class) { function = (session, block) -> type.getObject(block, position); } - MethodHandle handle = functionDependencies.getCastInvoker(type, VARCHAR, simpleConvention(FAIL_ON_NULL, NEVER_NULL)).getMethodHandle(); + MethodHandle handle = functionDependencies.getCastImplementation(type, VARCHAR, simpleConvention(FAIL_ON_NULL, NEVER_NULL)).getMethodHandle(); return (session, block) -> convertToString(handle, function.apply(session, block)); } diff --git a/core/trino-main/src/main/java/io/trino/operator/scalar/GenericComparisonUnorderedFirstOperator.java b/core/trino-main/src/main/java/io/trino/operator/scalar/GenericComparisonUnorderedFirstOperator.java index 3bac34027a47..da529a8cba97 100644 --- a/core/trino-main/src/main/java/io/trino/operator/scalar/GenericComparisonUnorderedFirstOperator.java +++ b/core/trino-main/src/main/java/io/trino/operator/scalar/GenericComparisonUnorderedFirstOperator.java @@ -13,17 +13,16 @@ */ package io.trino.operator.scalar; -import io.trino.metadata.BoundSignature; -import io.trino.metadata.FunctionInvoker; -import io.trino.metadata.FunctionMetadata; -import io.trino.metadata.Signature; import io.trino.metadata.SqlScalarFunction; +import io.trino.spi.function.BoundSignature; +import io.trino.spi.function.FunctionMetadata; +import io.trino.spi.function.ScalarFunctionImplementation; +import io.trino.spi.function.Signature; import io.trino.spi.type.Type; import io.trino.spi.type.TypeOperators; import io.trino.spi.type.TypeSignature; import java.lang.invoke.MethodHandle; -import java.util.Optional; import static io.trino.spi.function.OperatorType.COMPARISON_UNORDERED_FIRST; import static io.trino.spi.type.IntegerType.INTEGER; @@ -49,12 +48,14 @@ public GenericComparisonUnorderedFirstOperator(TypeOperators typeOperators) } @Override - protected ScalarFunctionImplementation specialize(BoundSignature boundSignature) + protected SpecializedSqlScalarFunction specialize(BoundSignature boundSignature) { Type type = boundSignature.getArgumentType(0); return invocationConvention -> { MethodHandle methodHandle = typeOperators.getComparisonUnorderedFirstOperator(type, invocationConvention); - return new FunctionInvoker(methodHandle, Optional.empty()); + return ScalarFunctionImplementation.builder() + .methodHandle(methodHandle) + .build(); }; } } diff --git a/core/trino-main/src/main/java/io/trino/operator/scalar/GenericComparisonUnorderedLastOperator.java b/core/trino-main/src/main/java/io/trino/operator/scalar/GenericComparisonUnorderedLastOperator.java index c8d894db4f91..a520558fc8c0 100644 --- a/core/trino-main/src/main/java/io/trino/operator/scalar/GenericComparisonUnorderedLastOperator.java +++ b/core/trino-main/src/main/java/io/trino/operator/scalar/GenericComparisonUnorderedLastOperator.java @@ -13,17 +13,16 @@ */ package io.trino.operator.scalar; -import io.trino.metadata.BoundSignature; -import io.trino.metadata.FunctionInvoker; -import io.trino.metadata.FunctionMetadata; -import io.trino.metadata.Signature; import io.trino.metadata.SqlScalarFunction; +import io.trino.spi.function.BoundSignature; +import io.trino.spi.function.FunctionMetadata; +import io.trino.spi.function.ScalarFunctionImplementation; +import io.trino.spi.function.Signature; import io.trino.spi.type.Type; import io.trino.spi.type.TypeOperators; import io.trino.spi.type.TypeSignature; import java.lang.invoke.MethodHandle; -import java.util.Optional; import static io.trino.spi.function.OperatorType.COMPARISON_UNORDERED_LAST; import static io.trino.spi.type.IntegerType.INTEGER; @@ -49,12 +48,14 @@ public GenericComparisonUnorderedLastOperator(TypeOperators typeOperators) } @Override - protected ScalarFunctionImplementation specialize(BoundSignature boundSignature) + protected SpecializedSqlScalarFunction specialize(BoundSignature boundSignature) { Type type = boundSignature.getArgumentType(0); return invocationConvention -> { MethodHandle methodHandle = typeOperators.getComparisonUnorderedLastOperator(type, invocationConvention); - return new FunctionInvoker(methodHandle, Optional.empty()); + return ScalarFunctionImplementation.builder() + .methodHandle(methodHandle) + .build(); }; } } diff --git a/core/trino-main/src/main/java/io/trino/operator/scalar/GenericDistinctFromOperator.java b/core/trino-main/src/main/java/io/trino/operator/scalar/GenericDistinctFromOperator.java index f4cbf9ac64e7..1b142e6d7645 100644 --- a/core/trino-main/src/main/java/io/trino/operator/scalar/GenericDistinctFromOperator.java +++ b/core/trino-main/src/main/java/io/trino/operator/scalar/GenericDistinctFromOperator.java @@ -13,17 +13,16 @@ */ package io.trino.operator.scalar; -import io.trino.metadata.BoundSignature; -import io.trino.metadata.FunctionInvoker; -import io.trino.metadata.FunctionMetadata; -import io.trino.metadata.Signature; import io.trino.metadata.SqlScalarFunction; +import io.trino.spi.function.BoundSignature; +import io.trino.spi.function.FunctionMetadata; +import io.trino.spi.function.ScalarFunctionImplementation; +import io.trino.spi.function.Signature; import io.trino.spi.type.Type; import io.trino.spi.type.TypeOperators; import io.trino.spi.type.TypeSignature; import java.lang.invoke.MethodHandle; -import java.util.Optional; import static io.trino.spi.function.OperatorType.IS_DISTINCT_FROM; import static io.trino.spi.type.BooleanType.BOOLEAN; @@ -50,12 +49,14 @@ public GenericDistinctFromOperator(TypeOperators typeOperators) } @Override - protected ScalarFunctionImplementation specialize(BoundSignature boundSignature) + protected SpecializedSqlScalarFunction specialize(BoundSignature boundSignature) { Type type = boundSignature.getArgumentType(0); return invocationConvention -> { MethodHandle methodHandle = typeOperators.getDistinctFromOperator(type, invocationConvention); - return new FunctionInvoker(methodHandle, Optional.empty()); + return ScalarFunctionImplementation.builder() + .methodHandle(methodHandle) + .build(); }; } } diff --git a/core/trino-main/src/main/java/io/trino/operator/scalar/GenericEqualOperator.java b/core/trino-main/src/main/java/io/trino/operator/scalar/GenericEqualOperator.java index 12ac9fcf4b79..ecd21fa1858c 100644 --- a/core/trino-main/src/main/java/io/trino/operator/scalar/GenericEqualOperator.java +++ b/core/trino-main/src/main/java/io/trino/operator/scalar/GenericEqualOperator.java @@ -13,17 +13,16 @@ */ package io.trino.operator.scalar; -import io.trino.metadata.BoundSignature; -import io.trino.metadata.FunctionInvoker; -import io.trino.metadata.FunctionMetadata; -import io.trino.metadata.Signature; import io.trino.metadata.SqlScalarFunction; +import io.trino.spi.function.BoundSignature; +import io.trino.spi.function.FunctionMetadata; +import io.trino.spi.function.ScalarFunctionImplementation; +import io.trino.spi.function.Signature; import io.trino.spi.type.Type; import io.trino.spi.type.TypeOperators; import io.trino.spi.type.TypeSignature; import java.lang.invoke.MethodHandle; -import java.util.Optional; import static io.trino.spi.function.OperatorType.EQUAL; import static io.trino.spi.type.BooleanType.BOOLEAN; @@ -50,12 +49,14 @@ public GenericEqualOperator(TypeOperators typeOperators) } @Override - protected ScalarFunctionImplementation specialize(BoundSignature boundSignature) + protected SpecializedSqlScalarFunction specialize(BoundSignature boundSignature) { Type type = boundSignature.getArgumentType(0); return invocationConvention -> { MethodHandle methodHandle = typeOperators.getEqualOperator(type, invocationConvention); - return new FunctionInvoker(methodHandle, Optional.empty()); + return ScalarFunctionImplementation.builder() + .methodHandle(methodHandle) + .build(); }; } } diff --git a/core/trino-main/src/main/java/io/trino/operator/scalar/GenericHashCodeOperator.java b/core/trino-main/src/main/java/io/trino/operator/scalar/GenericHashCodeOperator.java index 441d30727fd7..e6b5e114777c 100644 --- a/core/trino-main/src/main/java/io/trino/operator/scalar/GenericHashCodeOperator.java +++ b/core/trino-main/src/main/java/io/trino/operator/scalar/GenericHashCodeOperator.java @@ -13,17 +13,16 @@ */ package io.trino.operator.scalar; -import io.trino.metadata.BoundSignature; -import io.trino.metadata.FunctionInvoker; -import io.trino.metadata.FunctionMetadata; -import io.trino.metadata.Signature; import io.trino.metadata.SqlScalarFunction; +import io.trino.spi.function.BoundSignature; +import io.trino.spi.function.FunctionMetadata; +import io.trino.spi.function.ScalarFunctionImplementation; +import io.trino.spi.function.Signature; import io.trino.spi.type.Type; import io.trino.spi.type.TypeOperators; import io.trino.spi.type.TypeSignature; import java.lang.invoke.MethodHandle; -import java.util.Optional; import static io.trino.spi.function.OperatorType.HASH_CODE; import static io.trino.spi.type.BigintType.BIGINT; @@ -48,12 +47,14 @@ public GenericHashCodeOperator(TypeOperators typeOperators) } @Override - protected ScalarFunctionImplementation specialize(BoundSignature boundSignature) + protected SpecializedSqlScalarFunction specialize(BoundSignature boundSignature) { Type type = boundSignature.getArgumentType(0); return invocationConvention -> { MethodHandle methodHandle = typeOperators.getHashCodeOperator(type, invocationConvention); - return new FunctionInvoker(methodHandle, Optional.empty()); + return ScalarFunctionImplementation.builder() + .methodHandle(methodHandle) + .build(); }; } } diff --git a/core/trino-main/src/main/java/io/trino/operator/scalar/GenericIndeterminateOperator.java b/core/trino-main/src/main/java/io/trino/operator/scalar/GenericIndeterminateOperator.java index 4009b307f394..e57b35e16e6e 100644 --- a/core/trino-main/src/main/java/io/trino/operator/scalar/GenericIndeterminateOperator.java +++ b/core/trino-main/src/main/java/io/trino/operator/scalar/GenericIndeterminateOperator.java @@ -13,17 +13,16 @@ */ package io.trino.operator.scalar; -import io.trino.metadata.BoundSignature; -import io.trino.metadata.FunctionInvoker; -import io.trino.metadata.FunctionMetadata; -import io.trino.metadata.Signature; import io.trino.metadata.SqlScalarFunction; +import io.trino.spi.function.BoundSignature; +import io.trino.spi.function.FunctionMetadata; +import io.trino.spi.function.ScalarFunctionImplementation; +import io.trino.spi.function.Signature; import io.trino.spi.type.Type; import io.trino.spi.type.TypeOperators; import io.trino.spi.type.TypeSignature; import java.lang.invoke.MethodHandle; -import java.util.Optional; import static io.trino.spi.function.OperatorType.INDETERMINATE; import static io.trino.spi.type.BooleanType.BOOLEAN; @@ -49,12 +48,14 @@ public GenericIndeterminateOperator(TypeOperators typeOperators) } @Override - protected ScalarFunctionImplementation specialize(BoundSignature boundSignature) + protected SpecializedSqlScalarFunction specialize(BoundSignature boundSignature) { Type type = boundSignature.getArgumentType(0); return invocationConvention -> { MethodHandle methodHandle = typeOperators.getIndeterminateOperator(type, invocationConvention); - return new FunctionInvoker(methodHandle, Optional.empty()); + return ScalarFunctionImplementation.builder() + .methodHandle(methodHandle) + .build(); }; } } diff --git a/core/trino-main/src/main/java/io/trino/operator/scalar/GenericLessThanOperator.java b/core/trino-main/src/main/java/io/trino/operator/scalar/GenericLessThanOperator.java index 3af7608c0010..24ce2f45512e 100644 --- a/core/trino-main/src/main/java/io/trino/operator/scalar/GenericLessThanOperator.java +++ b/core/trino-main/src/main/java/io/trino/operator/scalar/GenericLessThanOperator.java @@ -13,17 +13,16 @@ */ package io.trino.operator.scalar; -import io.trino.metadata.BoundSignature; -import io.trino.metadata.FunctionInvoker; -import io.trino.metadata.FunctionMetadata; -import io.trino.metadata.Signature; import io.trino.metadata.SqlScalarFunction; +import io.trino.spi.function.BoundSignature; +import io.trino.spi.function.FunctionMetadata; +import io.trino.spi.function.ScalarFunctionImplementation; +import io.trino.spi.function.Signature; import io.trino.spi.type.Type; import io.trino.spi.type.TypeOperators; import io.trino.spi.type.TypeSignature; import java.lang.invoke.MethodHandle; -import java.util.Optional; import static io.trino.spi.function.OperatorType.LESS_THAN; import static io.trino.spi.type.BooleanType.BOOLEAN; @@ -49,12 +48,14 @@ public GenericLessThanOperator(TypeOperators typeOperators) } @Override - protected ScalarFunctionImplementation specialize(BoundSignature boundSignature) + protected SpecializedSqlScalarFunction specialize(BoundSignature boundSignature) { Type type = boundSignature.getArgumentType(0); return invocationConvention -> { MethodHandle methodHandle = typeOperators.getLessThanOperator(type, invocationConvention); - return new FunctionInvoker(methodHandle, Optional.empty()); + return ScalarFunctionImplementation.builder() + .methodHandle(methodHandle) + .build(); }; } } diff --git a/core/trino-main/src/main/java/io/trino/operator/scalar/GenericLessThanOrEqualOperator.java b/core/trino-main/src/main/java/io/trino/operator/scalar/GenericLessThanOrEqualOperator.java index add93d18f854..3535d0167d73 100644 --- a/core/trino-main/src/main/java/io/trino/operator/scalar/GenericLessThanOrEqualOperator.java +++ b/core/trino-main/src/main/java/io/trino/operator/scalar/GenericLessThanOrEqualOperator.java @@ -13,17 +13,16 @@ */ package io.trino.operator.scalar; -import io.trino.metadata.BoundSignature; -import io.trino.metadata.FunctionInvoker; -import io.trino.metadata.FunctionMetadata; -import io.trino.metadata.Signature; import io.trino.metadata.SqlScalarFunction; +import io.trino.spi.function.BoundSignature; +import io.trino.spi.function.FunctionMetadata; +import io.trino.spi.function.ScalarFunctionImplementation; +import io.trino.spi.function.Signature; import io.trino.spi.type.Type; import io.trino.spi.type.TypeOperators; import io.trino.spi.type.TypeSignature; import java.lang.invoke.MethodHandle; -import java.util.Optional; import static io.trino.spi.function.OperatorType.LESS_THAN_OR_EQUAL; import static io.trino.spi.type.BooleanType.BOOLEAN; @@ -49,12 +48,14 @@ public GenericLessThanOrEqualOperator(TypeOperators typeOperators) } @Override - protected ScalarFunctionImplementation specialize(BoundSignature boundSignature) + protected SpecializedSqlScalarFunction specialize(BoundSignature boundSignature) { Type type = boundSignature.getArgumentType(0); return invocationConvention -> { MethodHandle methodHandle = typeOperators.getLessThanOrEqualOperator(type, invocationConvention); - return new FunctionInvoker(methodHandle, Optional.empty()); + return ScalarFunctionImplementation.builder() + .methodHandle(methodHandle) + .build(); }; } } diff --git a/core/trino-main/src/main/java/io/trino/operator/scalar/GenericXxHash64Operator.java b/core/trino-main/src/main/java/io/trino/operator/scalar/GenericXxHash64Operator.java index fb02b6fef758..5343fc5c883e 100644 --- a/core/trino-main/src/main/java/io/trino/operator/scalar/GenericXxHash64Operator.java +++ b/core/trino-main/src/main/java/io/trino/operator/scalar/GenericXxHash64Operator.java @@ -13,17 +13,16 @@ */ package io.trino.operator.scalar; -import io.trino.metadata.BoundSignature; -import io.trino.metadata.FunctionInvoker; -import io.trino.metadata.FunctionMetadata; -import io.trino.metadata.Signature; import io.trino.metadata.SqlScalarFunction; +import io.trino.spi.function.BoundSignature; +import io.trino.spi.function.FunctionMetadata; +import io.trino.spi.function.ScalarFunctionImplementation; +import io.trino.spi.function.Signature; import io.trino.spi.type.Type; import io.trino.spi.type.TypeOperators; import io.trino.spi.type.TypeSignature; import java.lang.invoke.MethodHandle; -import java.util.Optional; import static io.trino.spi.function.OperatorType.XX_HASH_64; import static io.trino.spi.type.BigintType.BIGINT; @@ -48,12 +47,14 @@ public GenericXxHash64Operator(TypeOperators typeOperators) } @Override - protected ScalarFunctionImplementation specialize(BoundSignature boundSignature) + protected SpecializedSqlScalarFunction specialize(BoundSignature boundSignature) { Type type = boundSignature.getArgumentType(0); return invocationConvention -> { MethodHandle methodHandle = typeOperators.getXxHash64Operator(type, invocationConvention); - return new FunctionInvoker(methodHandle, Optional.empty()); + return ScalarFunctionImplementation.builder() + .methodHandle(methodHandle) + .build(); }; } } diff --git a/core/trino-main/src/main/java/io/trino/operator/scalar/IdentityCast.java b/core/trino-main/src/main/java/io/trino/operator/scalar/IdentityCast.java index 8443d48e4c37..f2c600045aa4 100644 --- a/core/trino-main/src/main/java/io/trino/operator/scalar/IdentityCast.java +++ b/core/trino-main/src/main/java/io/trino/operator/scalar/IdentityCast.java @@ -14,10 +14,10 @@ package io.trino.operator.scalar; import com.google.common.collect.ImmutableList; -import io.trino.metadata.BoundSignature; -import io.trino.metadata.FunctionMetadata; -import io.trino.metadata.Signature; import io.trino.metadata.SqlScalarFunction; +import io.trino.spi.function.BoundSignature; +import io.trino.spi.function.FunctionMetadata; +import io.trino.spi.function.Signature; import io.trino.spi.type.Type; import io.trino.spi.type.TypeSignature; @@ -46,11 +46,11 @@ private IdentityCast() } @Override - protected ScalarFunctionImplementation specialize(BoundSignature boundSignature) + protected SpecializedSqlScalarFunction specialize(BoundSignature boundSignature) { Type type = boundSignature.getReturnType(); MethodHandle identity = MethodHandles.identity(type.getJavaType()); - return new ChoicesScalarFunctionImplementation( + return new ChoicesSpecializedSqlScalarFunction( boundSignature, FAIL_ON_NULL, ImmutableList.of(NEVER_NULL), diff --git a/core/trino-main/src/main/java/io/trino/operator/scalar/InvokeFunction.java b/core/trino-main/src/main/java/io/trino/operator/scalar/InvokeFunction.java index c929bc44e27b..6be43122c69b 100644 --- a/core/trino-main/src/main/java/io/trino/operator/scalar/InvokeFunction.java +++ b/core/trino-main/src/main/java/io/trino/operator/scalar/InvokeFunction.java @@ -15,10 +15,10 @@ import com.google.common.collect.ImmutableList; import com.google.common.primitives.Primitives; -import io.trino.metadata.BoundSignature; -import io.trino.metadata.FunctionMetadata; -import io.trino.metadata.Signature; import io.trino.metadata.SqlScalarFunction; +import io.trino.spi.function.BoundSignature; +import io.trino.spi.function.FunctionMetadata; +import io.trino.spi.function.Signature; import io.trino.spi.type.Type; import io.trino.spi.type.TypeSignature; import io.trino.sql.gen.lambda.LambdaFunctionInterface; @@ -57,10 +57,10 @@ private InvokeFunction() } @Override - protected ScalarFunctionImplementation specialize(BoundSignature boundSignature) + protected SpecializedSqlScalarFunction specialize(BoundSignature boundSignature) { Type returnType = boundSignature.getReturnType(); - return new ChoicesScalarFunctionImplementation( + return new ChoicesSpecializedSqlScalarFunction( boundSignature, NULLABLE_RETURN, ImmutableList.of(FUNCTION), diff --git a/core/trino-main/src/main/java/io/trino/operator/scalar/JsonStringToArrayCast.java b/core/trino-main/src/main/java/io/trino/operator/scalar/JsonStringToArrayCast.java index ba642cb3a201..14cb658660b5 100644 --- a/core/trino-main/src/main/java/io/trino/operator/scalar/JsonStringToArrayCast.java +++ b/core/trino-main/src/main/java/io/trino/operator/scalar/JsonStringToArrayCast.java @@ -13,10 +13,10 @@ */ package io.trino.operator.scalar; -import io.trino.metadata.BoundSignature; -import io.trino.metadata.FunctionMetadata; -import io.trino.metadata.Signature; import io.trino.metadata.SqlScalarFunction; +import io.trino.spi.function.BoundSignature; +import io.trino.spi.function.FunctionMetadata; +import io.trino.spi.function.Signature; import io.trino.spi.type.TypeSignature; import static io.trino.operator.scalar.JsonToArrayCast.JSON_TO_ARRAY; @@ -45,7 +45,7 @@ private JsonStringToArrayCast() } @Override - protected ScalarFunctionImplementation specialize(BoundSignature boundSignature) + protected SpecializedSqlScalarFunction specialize(BoundSignature boundSignature) { return JSON_TO_ARRAY.specialize(boundSignature); } diff --git a/core/trino-main/src/main/java/io/trino/operator/scalar/JsonStringToMapCast.java b/core/trino-main/src/main/java/io/trino/operator/scalar/JsonStringToMapCast.java index b250cd2c7cfe..0cc8d99181ca 100644 --- a/core/trino-main/src/main/java/io/trino/operator/scalar/JsonStringToMapCast.java +++ b/core/trino-main/src/main/java/io/trino/operator/scalar/JsonStringToMapCast.java @@ -13,10 +13,10 @@ */ package io.trino.operator.scalar; -import io.trino.metadata.BoundSignature; -import io.trino.metadata.FunctionMetadata; -import io.trino.metadata.Signature; import io.trino.metadata.SqlScalarFunction; +import io.trino.spi.function.BoundSignature; +import io.trino.spi.function.FunctionMetadata; +import io.trino.spi.function.Signature; import io.trino.spi.type.TypeSignature; import static io.trino.operator.scalar.JsonToMapCast.JSON_TO_MAP; @@ -46,7 +46,7 @@ private JsonStringToMapCast() } @Override - protected ScalarFunctionImplementation specialize(BoundSignature boundSignature) + protected SpecializedSqlScalarFunction specialize(BoundSignature boundSignature) { return JSON_TO_MAP.specialize(boundSignature); } diff --git a/core/trino-main/src/main/java/io/trino/operator/scalar/JsonStringToRowCast.java b/core/trino-main/src/main/java/io/trino/operator/scalar/JsonStringToRowCast.java index a0833afcce62..45131229a6f7 100644 --- a/core/trino-main/src/main/java/io/trino/operator/scalar/JsonStringToRowCast.java +++ b/core/trino-main/src/main/java/io/trino/operator/scalar/JsonStringToRowCast.java @@ -13,10 +13,10 @@ */ package io.trino.operator.scalar; -import io.trino.metadata.BoundSignature; -import io.trino.metadata.FunctionMetadata; -import io.trino.metadata.Signature; import io.trino.metadata.SqlScalarFunction; +import io.trino.spi.function.BoundSignature; +import io.trino.spi.function.FunctionMetadata; +import io.trino.spi.function.Signature; import io.trino.spi.type.TypeSignature; import static io.trino.operator.scalar.JsonToRowCast.JSON_TO_ROW; @@ -44,7 +44,7 @@ private JsonStringToRowCast() } @Override - protected ScalarFunctionImplementation specialize(BoundSignature boundSignature) + protected SpecializedSqlScalarFunction specialize(BoundSignature boundSignature) { return JSON_TO_ROW.specialize(boundSignature); } diff --git a/core/trino-main/src/main/java/io/trino/operator/scalar/JsonToArrayCast.java b/core/trino-main/src/main/java/io/trino/operator/scalar/JsonToArrayCast.java index f123f4acaf0f..7c69e63d9431 100644 --- a/core/trino-main/src/main/java/io/trino/operator/scalar/JsonToArrayCast.java +++ b/core/trino-main/src/main/java/io/trino/operator/scalar/JsonToArrayCast.java @@ -18,14 +18,14 @@ import com.google.common.collect.ImmutableList; import io.airlift.slice.Slice; import io.trino.annotation.UsedByGeneratedCode; -import io.trino.metadata.BoundSignature; -import io.trino.metadata.FunctionMetadata; -import io.trino.metadata.Signature; import io.trino.metadata.SqlScalarFunction; import io.trino.spi.TrinoException; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; import io.trino.spi.connector.ConnectorSession; +import io.trino.spi.function.BoundSignature; +import io.trino.spi.function.FunctionMetadata; +import io.trino.spi.function.Signature; import io.trino.spi.type.ArrayType; import io.trino.spi.type.TypeSignature; import io.trino.util.JsonCastException; @@ -68,7 +68,7 @@ private JsonToArrayCast() } @Override - protected ScalarFunctionImplementation specialize(BoundSignature boundSignature) + protected SpecializedSqlScalarFunction specialize(BoundSignature boundSignature) { checkArgument(boundSignature.getArity() == 1, "Expected arity to be 1"); ArrayType arrayType = (ArrayType) boundSignature.getReturnType(); @@ -76,7 +76,7 @@ protected ScalarFunctionImplementation specialize(BoundSignature boundSignature) BlockBuilderAppender arrayAppender = BlockBuilderAppender.createBlockBuilderAppender(arrayType); MethodHandle methodHandle = METHOD_HANDLE.bindTo(arrayType).bindTo(arrayAppender); - return new ChoicesScalarFunctionImplementation( + return new ChoicesSpecializedSqlScalarFunction( boundSignature, NULLABLE_RETURN, ImmutableList.of(NEVER_NULL), diff --git a/core/trino-main/src/main/java/io/trino/operator/scalar/JsonToMapCast.java b/core/trino-main/src/main/java/io/trino/operator/scalar/JsonToMapCast.java index db64c5c9b6c3..db8898fc6d60 100644 --- a/core/trino-main/src/main/java/io/trino/operator/scalar/JsonToMapCast.java +++ b/core/trino-main/src/main/java/io/trino/operator/scalar/JsonToMapCast.java @@ -18,14 +18,14 @@ import com.google.common.collect.ImmutableList; import io.airlift.slice.Slice; import io.trino.annotation.UsedByGeneratedCode; -import io.trino.metadata.BoundSignature; -import io.trino.metadata.FunctionMetadata; -import io.trino.metadata.Signature; import io.trino.metadata.SqlScalarFunction; import io.trino.spi.TrinoException; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; import io.trino.spi.connector.ConnectorSession; +import io.trino.spi.function.BoundSignature; +import io.trino.spi.function.FunctionMetadata; +import io.trino.spi.function.Signature; import io.trino.spi.type.MapType; import io.trino.spi.type.TypeSignature; import io.trino.util.JsonCastException; @@ -71,7 +71,7 @@ private JsonToMapCast() } @Override - protected ScalarFunctionImplementation specialize(BoundSignature boundSignature) + protected SpecializedSqlScalarFunction specialize(BoundSignature boundSignature) { checkArgument(boundSignature.getArity() == 1, "Expected arity to be 1"); MapType mapType = (MapType) boundSignature.getReturnType(); @@ -79,7 +79,7 @@ protected ScalarFunctionImplementation specialize(BoundSignature boundSignature) BlockBuilderAppender mapAppender = createBlockBuilderAppender(mapType); MethodHandle methodHandle = METHOD_HANDLE.bindTo(mapType).bindTo(mapAppender); - return new ChoicesScalarFunctionImplementation( + return new ChoicesSpecializedSqlScalarFunction( boundSignature, NULLABLE_RETURN, ImmutableList.of(NEVER_NULL), diff --git a/core/trino-main/src/main/java/io/trino/operator/scalar/JsonToRowCast.java b/core/trino-main/src/main/java/io/trino/operator/scalar/JsonToRowCast.java index 50bd7bb26f8d..b4cd9d809fa7 100644 --- a/core/trino-main/src/main/java/io/trino/operator/scalar/JsonToRowCast.java +++ b/core/trino-main/src/main/java/io/trino/operator/scalar/JsonToRowCast.java @@ -18,15 +18,15 @@ import com.google.common.collect.ImmutableList; import io.airlift.slice.Slice; import io.trino.annotation.UsedByGeneratedCode; -import io.trino.metadata.BoundSignature; -import io.trino.metadata.FunctionMetadata; -import io.trino.metadata.Signature; import io.trino.metadata.SqlScalarFunction; -import io.trino.metadata.TypeVariableConstraint; import io.trino.spi.TrinoException; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; import io.trino.spi.connector.ConnectorSession; +import io.trino.spi.function.BoundSignature; +import io.trino.spi.function.FunctionMetadata; +import io.trino.spi.function.Signature; +import io.trino.spi.function.TypeVariableConstraint; import io.trino.spi.type.RowType; import io.trino.spi.type.TypeSignature; import io.trino.util.JsonCastException; @@ -73,14 +73,14 @@ private JsonToRowCast() } @Override - protected ScalarFunctionImplementation specialize(BoundSignature boundSignature) + protected SpecializedSqlScalarFunction specialize(BoundSignature boundSignature) { RowType rowType = (RowType) boundSignature.getReturnType(); checkCondition(canCastFromJson(rowType), INVALID_CAST_ARGUMENT, "Cannot cast JSON to %s", rowType); BlockBuilderAppender fieldAppender = createBlockBuilderAppender(rowType); MethodHandle methodHandle = METHOD_HANDLE.bindTo(rowType).bindTo(fieldAppender); - return new ChoicesScalarFunctionImplementation( + return new ChoicesSpecializedSqlScalarFunction( boundSignature, NULLABLE_RETURN, ImmutableList.of(NEVER_NULL), diff --git a/core/trino-main/src/main/java/io/trino/operator/scalar/MapConcatFunction.java b/core/trino-main/src/main/java/io/trino/operator/scalar/MapConcatFunction.java index 66915a6c8d15..fc8744523208 100644 --- a/core/trino-main/src/main/java/io/trino/operator/scalar/MapConcatFunction.java +++ b/core/trino-main/src/main/java/io/trino/operator/scalar/MapConcatFunction.java @@ -15,15 +15,15 @@ import com.google.common.collect.ImmutableList; import io.trino.annotation.UsedByGeneratedCode; -import io.trino.metadata.BoundSignature; -import io.trino.metadata.FunctionMetadata; -import io.trino.metadata.Signature; import io.trino.metadata.SqlScalarFunction; import io.trino.operator.aggregation.TypedSet; import io.trino.spi.PageBuilder; import io.trino.spi.TrinoException; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.function.BoundSignature; +import io.trino.spi.function.FunctionMetadata; +import io.trino.spi.function.Signature; import io.trino.spi.type.MapType; import io.trino.spi.type.Type; import io.trino.spi.type.TypeSignature; @@ -82,7 +82,7 @@ public MapConcatFunction(BlockTypeOperators blockTypeOperators) } @Override - protected ScalarFunctionImplementation specialize(BoundSignature boundSignature) + protected SpecializedSqlScalarFunction specialize(BoundSignature boundSignature) { if (boundSignature.getArity() < 2) { throw new TrinoException(INVALID_FUNCTION_ARGUMENT, "There must be two or more concatenation arguments to " + FUNCTION_NAME); @@ -100,7 +100,7 @@ protected ScalarFunctionImplementation specialize(BoundSignature boundSignature) MethodHandles.insertArguments(METHOD_HANDLE, 0, mapType, keysDistinctOperator, keyHashCode), USER_STATE_FACTORY.bindTo(mapType)); - return new ChoicesScalarFunctionImplementation( + return new ChoicesSpecializedSqlScalarFunction( boundSignature, FAIL_ON_NULL, nCopies(boundSignature.getArity(), NEVER_NULL), diff --git a/core/trino-main/src/main/java/io/trino/operator/scalar/MapConstructor.java b/core/trino-main/src/main/java/io/trino/operator/scalar/MapConstructor.java index 9f38a3681511..2f522a9da9d8 100644 --- a/core/trino-main/src/main/java/io/trino/operator/scalar/MapConstructor.java +++ b/core/trino-main/src/main/java/io/trino/operator/scalar/MapConstructor.java @@ -15,11 +15,6 @@ import com.google.common.collect.ImmutableList; import io.trino.annotation.UsedByGeneratedCode; -import io.trino.metadata.BoundSignature; -import io.trino.metadata.FunctionDependencies; -import io.trino.metadata.FunctionDependencyDeclaration; -import io.trino.metadata.FunctionMetadata; -import io.trino.metadata.Signature; import io.trino.metadata.SqlScalarFunction; import io.trino.spi.PageBuilder; import io.trino.spi.TrinoException; @@ -28,6 +23,11 @@ import io.trino.spi.block.DuplicateMapKeyException; import io.trino.spi.block.MapBlockBuilder; import io.trino.spi.connector.ConnectorSession; +import io.trino.spi.function.BoundSignature; +import io.trino.spi.function.FunctionDependencies; +import io.trino.spi.function.FunctionDependencyDeclaration; +import io.trino.spi.function.FunctionMetadata; +import io.trino.spi.function.Signature; import io.trino.spi.type.MapType; import io.trino.spi.type.TypeSignature; @@ -92,16 +92,16 @@ public FunctionDependencyDeclaration getFunctionDependencies() } @Override - public ScalarFunctionImplementation specialize(BoundSignature boundSignature, FunctionDependencies functionDependencies) + public SpecializedSqlScalarFunction specialize(BoundSignature boundSignature, FunctionDependencies functionDependencies) { MapType mapType = (MapType) boundSignature.getReturnType(); - MethodHandle keyIndeterminate = functionDependencies.getOperatorInvoker( + MethodHandle keyIndeterminate = functionDependencies.getOperatorImplementation( INDETERMINATE, ImmutableList.of(mapType.getKeyType()), simpleConvention(FAIL_ON_NULL, NEVER_NULL)).getMethodHandle(); MethodHandle instanceFactory = constructorMethodHandle(State.class, MapType.class).bindTo(mapType); - return new ChoicesScalarFunctionImplementation( + return new ChoicesSpecializedSqlScalarFunction( boundSignature, FAIL_ON_NULL, ImmutableList.of(NEVER_NULL, NEVER_NULL), diff --git a/core/trino-main/src/main/java/io/trino/operator/scalar/MapElementAtFunction.java b/core/trino-main/src/main/java/io/trino/operator/scalar/MapElementAtFunction.java index 49f29a43d48d..df0acca55dec 100644 --- a/core/trino-main/src/main/java/io/trino/operator/scalar/MapElementAtFunction.java +++ b/core/trino-main/src/main/java/io/trino/operator/scalar/MapElementAtFunction.java @@ -16,14 +16,14 @@ import com.google.common.collect.ImmutableList; import com.google.common.primitives.Primitives; import io.trino.annotation.UsedByGeneratedCode; -import io.trino.metadata.BoundSignature; -import io.trino.metadata.FunctionDependencies; -import io.trino.metadata.FunctionDependencyDeclaration; -import io.trino.metadata.FunctionMetadata; -import io.trino.metadata.Signature; import io.trino.metadata.SqlScalarFunction; import io.trino.spi.block.Block; import io.trino.spi.block.SingleMapBlock; +import io.trino.spi.function.BoundSignature; +import io.trino.spi.function.FunctionDependencies; +import io.trino.spi.function.FunctionDependencyDeclaration; +import io.trino.spi.function.FunctionMetadata; +import io.trino.spi.function.Signature; import io.trino.spi.type.MapType; import io.trino.spi.type.Type; import io.trino.spi.type.TypeSignature; @@ -72,7 +72,7 @@ public FunctionDependencyDeclaration getFunctionDependencies() } @Override - public ScalarFunctionImplementation specialize(BoundSignature boundSignature, FunctionDependencies functionDependencies) + public SpecializedSqlScalarFunction specialize(BoundSignature boundSignature, FunctionDependencies functionDependencies) { MapType mapType = (MapType) boundSignature.getArgumentType(0); Type keyType = mapType.getKeyType(); @@ -94,7 +94,7 @@ else if (keyType.getJavaType() == double.class) { methodHandle = methodHandle.bindTo(valueType); methodHandle = methodHandle.asType(methodHandle.type().changeReturnType(Primitives.wrap(valueType.getJavaType()))); - return new ChoicesScalarFunctionImplementation( + return new ChoicesSpecializedSqlScalarFunction( boundSignature, NULLABLE_RETURN, ImmutableList.of(NEVER_NULL, NEVER_NULL), diff --git a/core/trino-main/src/main/java/io/trino/operator/scalar/MapFilterFunction.java b/core/trino-main/src/main/java/io/trino/operator/scalar/MapFilterFunction.java index 0510628db5d2..88a2095ae2fd 100644 --- a/core/trino-main/src/main/java/io/trino/operator/scalar/MapFilterFunction.java +++ b/core/trino-main/src/main/java/io/trino/operator/scalar/MapFilterFunction.java @@ -25,13 +25,13 @@ import io.airlift.bytecode.control.ForLoop; import io.airlift.bytecode.control.IfStatement; import io.trino.annotation.UsedByGeneratedCode; -import io.trino.metadata.BoundSignature; -import io.trino.metadata.FunctionMetadata; -import io.trino.metadata.Signature; import io.trino.metadata.SqlScalarFunction; import io.trino.spi.PageBuilder; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.function.BoundSignature; +import io.trino.spi.function.FunctionMetadata; +import io.trino.spi.function.Signature; import io.trino.spi.type.MapType; import io.trino.spi.type.Type; import io.trino.spi.type.TypeSignature; @@ -92,10 +92,10 @@ private MapFilterFunction() } @Override - public ScalarFunctionImplementation specialize(BoundSignature boundSignature) + public SpecializedSqlScalarFunction specialize(BoundSignature boundSignature) { MapType mapType = (MapType) boundSignature.getReturnType(); - return new ChoicesScalarFunctionImplementation( + return new ChoicesSpecializedSqlScalarFunction( boundSignature, FAIL_ON_NULL, ImmutableList.of(NEVER_NULL, FUNCTION), diff --git a/core/trino-main/src/main/java/io/trino/operator/scalar/MapSubscriptOperator.java b/core/trino-main/src/main/java/io/trino/operator/scalar/MapSubscriptOperator.java index 78e170edcbaf..282deda32aa9 100644 --- a/core/trino-main/src/main/java/io/trino/operator/scalar/MapSubscriptOperator.java +++ b/core/trino-main/src/main/java/io/trino/operator/scalar/MapSubscriptOperator.java @@ -17,17 +17,17 @@ import com.google.common.primitives.Primitives; import io.airlift.slice.Slice; import io.trino.annotation.UsedByGeneratedCode; -import io.trino.metadata.BoundSignature; -import io.trino.metadata.FunctionDependencies; -import io.trino.metadata.FunctionDependencyDeclaration; -import io.trino.metadata.FunctionMetadata; -import io.trino.metadata.Signature; import io.trino.metadata.SqlScalarFunction; import io.trino.spi.TrinoException; import io.trino.spi.block.Block; import io.trino.spi.block.SingleMapBlock; import io.trino.spi.connector.ConnectorSession; +import io.trino.spi.function.BoundSignature; +import io.trino.spi.function.FunctionDependencies; +import io.trino.spi.function.FunctionDependencyDeclaration; +import io.trino.spi.function.FunctionMetadata; import io.trino.spi.function.InvocationConvention; +import io.trino.spi.function.Signature; import io.trino.spi.type.MapType; import io.trino.spi.type.Type; import io.trino.spi.type.TypeSignature; @@ -77,7 +77,7 @@ public FunctionDependencyDeclaration getFunctionDependencies() } @Override - public ScalarFunctionImplementation specialize(BoundSignature boundSignature, FunctionDependencies functionDependencies) + public SpecializedSqlScalarFunction specialize(BoundSignature boundSignature, FunctionDependencies functionDependencies) { MapType mapType = (MapType) boundSignature.getArgumentType(0); Type keyType = mapType.getKeyType(); @@ -100,7 +100,7 @@ else if (keyType.getJavaType() == double.class) { methodHandle = methodHandle.bindTo(missingKeyExceptionFactory).bindTo(keyType).bindTo(valueType); methodHandle = methodHandle.asType(methodHandle.type().changeReturnType(Primitives.wrap(valueType.getJavaType()))); - return new ChoicesScalarFunctionImplementation( + return new ChoicesSpecializedSqlScalarFunction( boundSignature, NULLABLE_RETURN, ImmutableList.of(NEVER_NULL, NEVER_NULL), @@ -160,7 +160,7 @@ public MissingKeyExceptionFactory(FunctionDependencies functionDependencies, Typ MethodHandle castMethod = null; try { InvocationConvention invocationConvention = new InvocationConvention(ImmutableList.of(BOXED_NULLABLE), NULLABLE_RETURN, true, false); - castMethod = functionDependencies.getCastInvoker(keyType, VARCHAR, invocationConvention).getMethodHandle(); + castMethod = functionDependencies.getCastImplementation(keyType, VARCHAR, invocationConvention).getMethodHandle(); if (!castMethod.type().parameterType(0).equals(ConnectorSession.class)) { castMethod = MethodHandles.dropArguments(castMethod, 0, ConnectorSession.class); } diff --git a/core/trino-main/src/main/java/io/trino/operator/scalar/MapToJsonCast.java b/core/trino-main/src/main/java/io/trino/operator/scalar/MapToJsonCast.java index 17e95b99a56b..1b624be75405 100644 --- a/core/trino-main/src/main/java/io/trino/operator/scalar/MapToJsonCast.java +++ b/core/trino-main/src/main/java/io/trino/operator/scalar/MapToJsonCast.java @@ -19,11 +19,11 @@ import io.airlift.slice.Slice; import io.airlift.slice.SliceOutput; import io.trino.annotation.UsedByGeneratedCode; -import io.trino.metadata.BoundSignature; -import io.trino.metadata.FunctionMetadata; -import io.trino.metadata.Signature; import io.trino.metadata.SqlScalarFunction; import io.trino.spi.block.Block; +import io.trino.spi.function.BoundSignature; +import io.trino.spi.function.FunctionMetadata; +import io.trino.spi.function.Signature; import io.trino.spi.type.MapType; import io.trino.spi.type.Type; import io.trino.spi.type.TypeSignature; @@ -73,7 +73,7 @@ private MapToJsonCast(boolean legacyRowToJson) } @Override - public ScalarFunctionImplementation specialize(BoundSignature boundSignature) + public SpecializedSqlScalarFunction specialize(BoundSignature boundSignature) { MapType mapType = (MapType) boundSignature.getArgumentType(0); Type keyType = mapType.getKeyType(); @@ -84,7 +84,7 @@ public ScalarFunctionImplementation specialize(BoundSignature boundSignature) JsonGeneratorWriter writer = JsonGeneratorWriter.createJsonGeneratorWriter(valueType, legacyRowToJson); MethodHandle methodHandle = METHOD_HANDLE.bindTo(provider).bindTo(writer); - return new ChoicesScalarFunctionImplementation( + return new ChoicesSpecializedSqlScalarFunction( boundSignature, FAIL_ON_NULL, ImmutableList.of(NEVER_NULL), diff --git a/core/trino-main/src/main/java/io/trino/operator/scalar/MapToMapCast.java b/core/trino-main/src/main/java/io/trino/operator/scalar/MapToMapCast.java index a78fc11365e5..54ae82504367 100644 --- a/core/trino-main/src/main/java/io/trino/operator/scalar/MapToMapCast.java +++ b/core/trino-main/src/main/java/io/trino/operator/scalar/MapToMapCast.java @@ -16,19 +16,19 @@ import com.google.common.collect.ImmutableList; import io.airlift.slice.Slice; import io.trino.annotation.UsedByGeneratedCode; -import io.trino.metadata.BoundSignature; -import io.trino.metadata.FunctionDependencies; -import io.trino.metadata.FunctionDependencyDeclaration; -import io.trino.metadata.FunctionMetadata; -import io.trino.metadata.FunctionNullability; -import io.trino.metadata.Signature; import io.trino.metadata.SqlScalarFunction; import io.trino.operator.aggregation.TypedSet; import io.trino.spi.TrinoException; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; import io.trino.spi.connector.ConnectorSession; +import io.trino.spi.function.BoundSignature; +import io.trino.spi.function.FunctionDependencies; +import io.trino.spi.function.FunctionDependencyDeclaration; +import io.trino.spi.function.FunctionMetadata; +import io.trino.spi.function.FunctionNullability; import io.trino.spi.function.InvocationConvention; +import io.trino.spi.function.Signature; import io.trino.spi.type.MapType; import io.trino.spi.type.Type; import io.trino.spi.type.TypeSignature; @@ -111,7 +111,7 @@ public FunctionDependencyDeclaration getFunctionDependencies() } @Override - public ScalarFunctionImplementation specialize(BoundSignature boundSignature, FunctionDependencies functionDependencies) + public SpecializedSqlScalarFunction specialize(BoundSignature boundSignature, FunctionDependencies functionDependencies) { checkArgument(boundSignature.getArity() == 1, "Expected arity to be 1"); MapType fromMapType = (MapType) boundSignature.getArgumentType(0); @@ -126,7 +126,7 @@ public ScalarFunctionImplementation specialize(BoundSignature boundSignature, Fu BlockPositionIsDistinctFrom keyEqual = blockTypeOperators.getDistinctFromOperator(toKeyType); BlockPositionHashCode keyHashCode = blockTypeOperators.getHashCodeOperator(toKeyType); MethodHandle target = MethodHandles.insertArguments(METHOD_HANDLE, 0, keyProcessor, valueProcessor, toMapType, keyEqual, keyHashCode); - return new ChoicesScalarFunctionImplementation(boundSignature, NULLABLE_RETURN, ImmutableList.of(NEVER_NULL), target); + return new ChoicesSpecializedSqlScalarFunction(boundSignature, NULLABLE_RETURN, ImmutableList.of(NEVER_NULL), target); } /** @@ -138,7 +138,7 @@ private MethodHandle buildProcessor(FunctionDependencies functionDependencies, T // Get block position cast, with optional connector session FunctionNullability functionNullability = functionDependencies.getCastNullability(fromType, toType); InvocationConvention invocationConvention = new InvocationConvention(ImmutableList.of(BLOCK_POSITION), functionNullability.isReturnNullable() ? NULLABLE_RETURN : FAIL_ON_NULL, true, false); - MethodHandle cast = functionDependencies.getCastInvoker(fromType, toType, invocationConvention).getMethodHandle(); + MethodHandle cast = functionDependencies.getCastImplementation(fromType, toType, invocationConvention).getMethodHandle(); // Normalize cast to have connector session as first argument if (cast.type().parameterArray()[0] != ConnectorSession.class) { cast = dropArguments(cast, 0, ConnectorSession.class); diff --git a/core/trino-main/src/main/java/io/trino/operator/scalar/MapTransformKeysFunction.java b/core/trino-main/src/main/java/io/trino/operator/scalar/MapTransformKeysFunction.java index 4201c2d06f16..4e01c7d22e16 100644 --- a/core/trino-main/src/main/java/io/trino/operator/scalar/MapTransformKeysFunction.java +++ b/core/trino-main/src/main/java/io/trino/operator/scalar/MapTransformKeysFunction.java @@ -25,9 +25,6 @@ import io.airlift.bytecode.control.ForLoop; import io.airlift.bytecode.control.IfStatement; import io.trino.annotation.UsedByGeneratedCode; -import io.trino.metadata.BoundSignature; -import io.trino.metadata.FunctionMetadata; -import io.trino.metadata.Signature; import io.trino.metadata.SqlScalarFunction; import io.trino.operator.aggregation.TypedSet; import io.trino.spi.ErrorCodeSupplier; @@ -36,6 +33,9 @@ import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; import io.trino.spi.connector.ConnectorSession; +import io.trino.spi.function.BoundSignature; +import io.trino.spi.function.FunctionMetadata; +import io.trino.spi.function.Signature; import io.trino.spi.type.MapType; import io.trino.spi.type.Type; import io.trino.spi.type.TypeSignature; @@ -110,7 +110,7 @@ public MapTransformKeysFunction(BlockTypeOperators blockTypeOperators) } @Override - protected ScalarFunctionImplementation specialize(BoundSignature boundSignature) + protected SpecializedSqlScalarFunction specialize(BoundSignature boundSignature) { MapType inputMapType = (MapType) boundSignature.getArgumentType(0); Type inputKeyType = inputMapType.getKeyType(); @@ -118,7 +118,7 @@ protected ScalarFunctionImplementation specialize(BoundSignature boundSignature) Type outputKeyType = outputMapType.getKeyType(); Type valueType = outputMapType.getValueType(); - return new ChoicesScalarFunctionImplementation( + return new ChoicesSpecializedSqlScalarFunction( boundSignature, FAIL_ON_NULL, ImmutableList.of(NEVER_NULL, FUNCTION), diff --git a/core/trino-main/src/main/java/io/trino/operator/scalar/MapTransformValuesFunction.java b/core/trino-main/src/main/java/io/trino/operator/scalar/MapTransformValuesFunction.java index 6cbcac79e263..2b0cbfde1833 100644 --- a/core/trino-main/src/main/java/io/trino/operator/scalar/MapTransformValuesFunction.java +++ b/core/trino-main/src/main/java/io/trino/operator/scalar/MapTransformValuesFunction.java @@ -27,15 +27,15 @@ import io.airlift.bytecode.control.IfStatement; import io.airlift.bytecode.control.TryCatch; import io.trino.annotation.UsedByGeneratedCode; -import io.trino.metadata.BoundSignature; -import io.trino.metadata.FunctionMetadata; -import io.trino.metadata.Signature; import io.trino.metadata.SqlScalarFunction; import io.trino.spi.ErrorCodeSupplier; import io.trino.spi.PageBuilder; import io.trino.spi.TrinoException; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.function.BoundSignature; +import io.trino.spi.function.FunctionMetadata; +import io.trino.spi.function.Signature; import io.trino.spi.type.MapType; import io.trino.spi.type.Type; import io.trino.spi.type.TypeSignature; @@ -100,7 +100,7 @@ private MapTransformValuesFunction() } @Override - protected ScalarFunctionImplementation specialize(BoundSignature boundSignature) + protected SpecializedSqlScalarFunction specialize(BoundSignature boundSignature) { MapType inputMapType = (MapType) boundSignature.getArgumentType(0); Type inputValueType = inputMapType.getValueType(); @@ -108,7 +108,7 @@ protected ScalarFunctionImplementation specialize(BoundSignature boundSignature) Type keyType = outputMapType.getKeyType(); Type outputValueType = outputMapType.getValueType(); - return new ChoicesScalarFunctionImplementation( + return new ChoicesSpecializedSqlScalarFunction( boundSignature, FAIL_ON_NULL, ImmutableList.of(NEVER_NULL, FUNCTION), diff --git a/core/trino-main/src/main/java/io/trino/operator/scalar/MapZipWithFunction.java b/core/trino-main/src/main/java/io/trino/operator/scalar/MapZipWithFunction.java index d032d3484f23..4c5b3165e52b 100644 --- a/core/trino-main/src/main/java/io/trino/operator/scalar/MapZipWithFunction.java +++ b/core/trino-main/src/main/java/io/trino/operator/scalar/MapZipWithFunction.java @@ -14,14 +14,14 @@ package io.trino.operator.scalar; import com.google.common.collect.ImmutableList; -import io.trino.metadata.BoundSignature; -import io.trino.metadata.FunctionMetadata; -import io.trino.metadata.Signature; import io.trino.metadata.SqlScalarFunction; import io.trino.spi.PageBuilder; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; import io.trino.spi.block.SingleMapBlock; +import io.trino.spi.function.BoundSignature; +import io.trino.spi.function.FunctionMetadata; +import io.trino.spi.function.Signature; import io.trino.spi.type.MapType; import io.trino.spi.type.Type; import io.trino.spi.type.TypeSignature; @@ -68,13 +68,13 @@ private MapZipWithFunction() } @Override - protected ScalarFunctionImplementation specialize(BoundSignature boundSignature) + protected SpecializedSqlScalarFunction specialize(BoundSignature boundSignature) { MapType outputMapType = (MapType) boundSignature.getReturnType(); Type keyType = outputMapType.getKeyType(); Type inputValueType1 = ((MapType) boundSignature.getArgumentType(0)).getValueType(); Type inputValueType2 = ((MapType) boundSignature.getArgumentType(1)).getValueType(); - return new ChoicesScalarFunctionImplementation( + return new ChoicesSpecializedSqlScalarFunction( boundSignature, FAIL_ON_NULL, ImmutableList.of(NEVER_NULL, NEVER_NULL, FUNCTION), diff --git a/core/trino-main/src/main/java/io/trino/operator/scalar/MathFunctions.java b/core/trino-main/src/main/java/io/trino/operator/scalar/MathFunctions.java index c5009f121f4a..806bf47ed519 100644 --- a/core/trino-main/src/main/java/io/trino/operator/scalar/MathFunctions.java +++ b/core/trino-main/src/main/java/io/trino/operator/scalar/MathFunctions.java @@ -18,7 +18,6 @@ import com.google.common.primitives.Shorts; import com.google.common.primitives.SignedBytes; import io.airlift.slice.Slice; -import io.trino.metadata.Signature; import io.trino.metadata.SqlScalarFunction; import io.trino.operator.aggregation.TypedSet; import io.trino.spi.TrinoException; @@ -29,6 +28,7 @@ import io.trino.spi.function.LiteralParameters; import io.trino.spi.function.OperatorDependency; import io.trino.spi.function.ScalarFunction; +import io.trino.spi.function.Signature; import io.trino.spi.function.SqlNullable; import io.trino.spi.function.SqlType; import io.trino.spi.type.Decimals; diff --git a/core/trino-main/src/main/java/io/trino/operator/scalar/ParametricScalar.java b/core/trino-main/src/main/java/io/trino/operator/scalar/ParametricScalar.java index 487fa21d387f..1651fb71a69c 100644 --- a/core/trino-main/src/main/java/io/trino/operator/scalar/ParametricScalar.java +++ b/core/trino-main/src/main/java/io/trino/operator/scalar/ParametricScalar.java @@ -14,14 +14,7 @@ package io.trino.operator.scalar; import com.google.common.annotations.VisibleForTesting; -import io.trino.metadata.BoundSignature; import io.trino.metadata.FunctionBinding; -import io.trino.metadata.FunctionDependencies; -import io.trino.metadata.FunctionDependencyDeclaration; -import io.trino.metadata.FunctionDependencyDeclaration.FunctionDependencyDeclarationBuilder; -import io.trino.metadata.FunctionMetadata; -import io.trino.metadata.FunctionNullability; -import io.trino.metadata.Signature; import io.trino.metadata.SignatureBinder; import io.trino.metadata.SqlScalarFunction; import io.trino.operator.ParametricImplementationsGroup; @@ -29,6 +22,13 @@ import io.trino.operator.scalar.annotations.ParametricScalarImplementation; import io.trino.operator.scalar.annotations.ParametricScalarImplementation.ParametricScalarImplementationChoice; import io.trino.spi.TrinoException; +import io.trino.spi.function.BoundSignature; +import io.trino.spi.function.FunctionDependencies; +import io.trino.spi.function.FunctionDependencyDeclaration; +import io.trino.spi.function.FunctionDependencyDeclaration.FunctionDependencyDeclarationBuilder; +import io.trino.spi.function.FunctionMetadata; +import io.trino.spi.function.FunctionNullability; +import io.trino.spi.function.Signature; import java.util.Collection; import java.util.Optional; @@ -117,21 +117,21 @@ private static void declareDependencies(FunctionDependencyDeclarationBuilder bui } @Override - public ScalarFunctionImplementation specialize(BoundSignature boundSignature, FunctionDependencies functionDependencies) + public SpecializedSqlScalarFunction specialize(BoundSignature boundSignature, FunctionDependencies functionDependencies) { FunctionMetadata metadata = getFunctionMetadata(); FunctionBinding functionBinding = SignatureBinder.bindFunction(metadata.getFunctionId(), metadata.getSignature(), boundSignature); ParametricScalarImplementation exactImplementation = implementations.getExactImplementations().get(boundSignature.toSignature()); if (exactImplementation != null) { - Optional scalarFunctionImplementation = exactImplementation.specialize(functionBinding, functionDependencies); + Optional scalarFunctionImplementation = exactImplementation.specialize(functionBinding, functionDependencies); checkCondition(scalarFunctionImplementation.isPresent(), FUNCTION_IMPLEMENTATION_ERROR, format("Exact implementation of %s do not match expected java types.", boundSignature.getName())); return scalarFunctionImplementation.get(); } - ScalarFunctionImplementation selectedImplementation = null; + SpecializedSqlScalarFunction selectedImplementation = null; for (ParametricScalarImplementation implementation : implementations.getSpecializedImplementations()) { - Optional scalarFunctionImplementation = implementation.specialize(functionBinding, functionDependencies); + Optional scalarFunctionImplementation = implementation.specialize(functionBinding, functionDependencies); if (scalarFunctionImplementation.isPresent()) { checkCondition(selectedImplementation == null, AMBIGUOUS_FUNCTION_IMPLEMENTATION, "Ambiguous implementation for %s with bindings %s", metadata.getSignature(), boundSignature); selectedImplementation = scalarFunctionImplementation.get(); @@ -141,7 +141,7 @@ public ScalarFunctionImplementation specialize(BoundSignature boundSignature, Fu return selectedImplementation; } for (ParametricScalarImplementation implementation : implementations.getGenericImplementations()) { - Optional scalarFunctionImplementation = implementation.specialize(functionBinding, functionDependencies); + Optional scalarFunctionImplementation = implementation.specialize(functionBinding, functionDependencies); if (scalarFunctionImplementation.isPresent()) { checkCondition(selectedImplementation == null, AMBIGUOUS_FUNCTION_IMPLEMENTATION, "Ambiguous implementation for %s with bindings %s", metadata.getSignature(), boundSignature); selectedImplementation = scalarFunctionImplementation.get(); diff --git a/core/trino-main/src/main/java/io/trino/operator/scalar/Re2JCastToRegexpFunction.java b/core/trino-main/src/main/java/io/trino/operator/scalar/Re2JCastToRegexpFunction.java index c0ee022c2fad..d5c212a5a387 100644 --- a/core/trino-main/src/main/java/io/trino/operator/scalar/Re2JCastToRegexpFunction.java +++ b/core/trino-main/src/main/java/io/trino/operator/scalar/Re2JCastToRegexpFunction.java @@ -17,11 +17,11 @@ import com.google.common.collect.ImmutableSet; import io.airlift.slice.Slice; import io.trino.annotation.UsedByGeneratedCode; -import io.trino.metadata.BoundSignature; -import io.trino.metadata.FunctionMetadata; -import io.trino.metadata.Signature; import io.trino.metadata.SqlScalarFunction; import io.trino.spi.TrinoException; +import io.trino.spi.function.BoundSignature; +import io.trino.spi.function.FunctionMetadata; +import io.trino.spi.function.Signature; import io.trino.spi.type.Type; import io.trino.type.Re2JRegexp; @@ -71,11 +71,11 @@ private Re2JCastToRegexpFunction(String sourceType, int dfaStatesLimit, int dfaR } @Override - protected ScalarFunctionImplementation specialize(BoundSignature boundSignature) + protected SpecializedSqlScalarFunction specialize(BoundSignature boundSignature) { Type inputType = boundSignature.getArgumentType(0); Long typeLength = inputType.getTypeSignature().getParameters().get(0).getLongLiteral(); - return new ChoicesScalarFunctionImplementation( + return new ChoicesSpecializedSqlScalarFunction( boundSignature, FAIL_ON_NULL, ImmutableList.of(NEVER_NULL), diff --git a/core/trino-main/src/main/java/io/trino/operator/scalar/RowToJsonCast.java b/core/trino-main/src/main/java/io/trino/operator/scalar/RowToJsonCast.java index f291a2badf06..62c3714222fc 100644 --- a/core/trino-main/src/main/java/io/trino/operator/scalar/RowToJsonCast.java +++ b/core/trino-main/src/main/java/io/trino/operator/scalar/RowToJsonCast.java @@ -19,12 +19,12 @@ import io.airlift.slice.Slice; import io.airlift.slice.SliceOutput; import io.trino.annotation.UsedByGeneratedCode; -import io.trino.metadata.BoundSignature; -import io.trino.metadata.FunctionMetadata; -import io.trino.metadata.Signature; import io.trino.metadata.SqlScalarFunction; -import io.trino.metadata.TypeVariableConstraint; import io.trino.spi.block.Block; +import io.trino.spi.function.BoundSignature; +import io.trino.spi.function.FunctionMetadata; +import io.trino.spi.function.Signature; +import io.trino.spi.function.TypeVariableConstraint; import io.trino.spi.type.Type; import io.trino.spi.type.TypeSignature; import io.trino.spi.type.TypeSignatureParameter; @@ -77,7 +77,7 @@ private RowToJsonCast(boolean legacyRowToJson) } @Override - protected ScalarFunctionImplementation specialize(BoundSignature boundSignature) + protected SpecializedSqlScalarFunction specialize(BoundSignature boundSignature) { Type type = boundSignature.getArgumentType(0); checkCondition(canCastToJson(type), INVALID_CAST_ARGUMENT, "Cannot cast %s to JSON", type); @@ -102,7 +102,7 @@ protected ScalarFunctionImplementation specialize(BoundSignature boundSignature) methodHandle = METHOD_HANDLE.bindTo(fieldNames).bindTo(fieldWriters); } - return new ChoicesScalarFunctionImplementation( + return new ChoicesSpecializedSqlScalarFunction( boundSignature, FAIL_ON_NULL, ImmutableList.of(NEVER_NULL), diff --git a/core/trino-main/src/main/java/io/trino/operator/scalar/RowToRowCast.java b/core/trino-main/src/main/java/io/trino/operator/scalar/RowToRowCast.java index 702d2f811018..d96baba6fe64 100644 --- a/core/trino-main/src/main/java/io/trino/operator/scalar/RowToRowCast.java +++ b/core/trino-main/src/main/java/io/trino/operator/scalar/RowToRowCast.java @@ -23,21 +23,21 @@ import io.airlift.bytecode.Parameter; import io.airlift.bytecode.Scope; import io.airlift.bytecode.Variable; -import io.trino.metadata.BoundSignature; -import io.trino.metadata.FunctionDependencies; -import io.trino.metadata.FunctionDependencyDeclaration; -import io.trino.metadata.FunctionDependencyDeclaration.FunctionDependencyDeclarationBuilder; -import io.trino.metadata.FunctionMetadata; -import io.trino.metadata.Signature; import io.trino.metadata.SqlScalarFunction; -import io.trino.metadata.TypeVariableConstraint; import io.trino.spi.StandardErrorCode; import io.trino.spi.TrinoException; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; import io.trino.spi.block.BlockBuilderStatus; import io.trino.spi.connector.ConnectorSession; +import io.trino.spi.function.BoundSignature; +import io.trino.spi.function.FunctionDependencies; +import io.trino.spi.function.FunctionDependencyDeclaration; +import io.trino.spi.function.FunctionDependencyDeclaration.FunctionDependencyDeclarationBuilder; +import io.trino.spi.function.FunctionMetadata; import io.trino.spi.function.InvocationConvention; +import io.trino.spi.function.Signature; +import io.trino.spi.function.TypeVariableConstraint; import io.trino.spi.type.Type; import io.trino.spi.type.TypeSignature; import io.trino.sql.gen.CachedInstanceBinder; @@ -140,7 +140,7 @@ public FunctionDependencyDeclaration getFunctionDependencies(BoundSignature boun } @Override - public ScalarFunctionImplementation specialize(BoundSignature boundSignature, FunctionDependencies functionDependencies) + public SpecializedSqlScalarFunction specialize(BoundSignature boundSignature, FunctionDependencies functionDependencies) { Type fromType = boundSignature.getArgumentType(0); Type toType = boundSignature.getReturnType(); @@ -149,7 +149,7 @@ public ScalarFunctionImplementation specialize(BoundSignature boundSignature, Fu } Class castOperatorClass = generateRowCast(fromType, toType, functionDependencies); MethodHandle methodHandle = methodHandle(castOperatorClass, "castRow", ConnectorSession.class, Block.class); - return new ChoicesScalarFunctionImplementation( + return new ChoicesSpecializedSqlScalarFunction( boundSignature, FAIL_ON_NULL, ImmutableList.of(NEVER_NULL), @@ -271,7 +271,7 @@ else if (type.getJavaType() == double.class) { private static MethodHandle getNullSafeCast(FunctionDependencies functionDependencies, Type fromElementType, Type toElementType) { - MethodHandle castMethod = functionDependencies.getCastInvoker( + MethodHandle castMethod = functionDependencies.getCastImplementation( fromElementType, toElementType, new InvocationConvention(ImmutableList.of(BLOCK_POSITION), NULLABLE_RETURN, true, false)) diff --git a/core/trino-main/src/main/java/io/trino/operator/scalar/ScalarFunctionImplementation.java b/core/trino-main/src/main/java/io/trino/operator/scalar/SpecializedSqlScalarFunction.java similarity index 75% rename from core/trino-main/src/main/java/io/trino/operator/scalar/ScalarFunctionImplementation.java rename to core/trino-main/src/main/java/io/trino/operator/scalar/SpecializedSqlScalarFunction.java index 2a5a72fb0a98..c9e6f0a5fe71 100644 --- a/core/trino-main/src/main/java/io/trino/operator/scalar/ScalarFunctionImplementation.java +++ b/core/trino-main/src/main/java/io/trino/operator/scalar/SpecializedSqlScalarFunction.java @@ -13,10 +13,10 @@ */ package io.trino.operator.scalar; -import io.trino.metadata.FunctionInvoker; import io.trino.spi.function.InvocationConvention; +import io.trino.spi.function.ScalarFunctionImplementation; -public interface ScalarFunctionImplementation +public interface SpecializedSqlScalarFunction { - FunctionInvoker getScalarFunctionInvoker(InvocationConvention invocationConvention); + ScalarFunctionImplementation getScalarFunctionImplementation(InvocationConvention invocationConvention); } diff --git a/core/trino-main/src/main/java/io/trino/operator/scalar/TryCastFunction.java b/core/trino-main/src/main/java/io/trino/operator/scalar/TryCastFunction.java index c45bc3196f58..6fcb5c21e19a 100644 --- a/core/trino-main/src/main/java/io/trino/operator/scalar/TryCastFunction.java +++ b/core/trino-main/src/main/java/io/trino/operator/scalar/TryCastFunction.java @@ -14,13 +14,13 @@ package io.trino.operator.scalar; import com.google.common.collect.ImmutableList; -import io.trino.metadata.BoundSignature; -import io.trino.metadata.FunctionDependencies; -import io.trino.metadata.FunctionDependencyDeclaration; -import io.trino.metadata.FunctionMetadata; -import io.trino.metadata.Signature; import io.trino.metadata.SqlScalarFunction; +import io.trino.spi.function.BoundSignature; +import io.trino.spi.function.FunctionDependencies; +import io.trino.spi.function.FunctionDependencyDeclaration; +import io.trino.spi.function.FunctionMetadata; import io.trino.spi.function.InvocationConvention; +import io.trino.spi.function.Signature; import io.trino.spi.type.Type; import io.trino.spi.type.TypeSignature; @@ -64,7 +64,7 @@ public FunctionDependencyDeclaration getFunctionDependencies() } @Override - public ScalarFunctionImplementation specialize(BoundSignature boundSignature, FunctionDependencies functionDependencies) + public SpecializedSqlScalarFunction specialize(BoundSignature boundSignature, FunctionDependencies functionDependencies) { Type fromType = boundSignature.getArgumentType(0); Type toType = boundSignature.getReturnType(); @@ -73,13 +73,13 @@ public ScalarFunctionImplementation specialize(BoundSignature boundSignature, Fu // the resulting method needs to return a boxed type InvocationConvention invocationConvention = new InvocationConvention(ImmutableList.of(NEVER_NULL), NULLABLE_RETURN, true, false); - MethodHandle coercion = functionDependencies.getCastInvoker(fromType, toType, invocationConvention).getMethodHandle(); + MethodHandle coercion = functionDependencies.getCastImplementation(fromType, toType, invocationConvention).getMethodHandle(); coercion = coercion.asType(methodType(returnType, coercion.type())); MethodHandle exceptionHandler = dropArguments(constant(returnType, null), 0, RuntimeException.class); MethodHandle tryCastHandle = catchException(coercion, RuntimeException.class, exceptionHandler); - return new ChoicesScalarFunctionImplementation( + return new ChoicesSpecializedSqlScalarFunction( boundSignature, NULLABLE_RETURN, ImmutableList.of(NEVER_NULL), diff --git a/core/trino-main/src/main/java/io/trino/operator/scalar/VersionFunction.java b/core/trino-main/src/main/java/io/trino/operator/scalar/VersionFunction.java index 2a07656bf839..5b4ef6c6cac9 100644 --- a/core/trino-main/src/main/java/io/trino/operator/scalar/VersionFunction.java +++ b/core/trino-main/src/main/java/io/trino/operator/scalar/VersionFunction.java @@ -16,10 +16,10 @@ import com.google.common.collect.ImmutableList; import io.airlift.slice.Slice; import io.trino.annotation.UsedByGeneratedCode; -import io.trino.metadata.BoundSignature; -import io.trino.metadata.FunctionMetadata; -import io.trino.metadata.Signature; import io.trino.metadata.SqlScalarFunction; +import io.trino.spi.function.BoundSignature; +import io.trino.spi.function.FunctionMetadata; +import io.trino.spi.function.Signature; import java.lang.invoke.MethodHandle; @@ -48,10 +48,10 @@ public VersionFunction(String nodeVersion) } @Override - public ScalarFunctionImplementation specialize(BoundSignature boundSignature) + public SpecializedSqlScalarFunction specialize(BoundSignature boundSignature) { MethodHandle methodHandle = METHOD_HANDLE.bindTo(nodeVersion); - return new ChoicesScalarFunctionImplementation( + return new ChoicesSpecializedSqlScalarFunction( boundSignature, FAIL_ON_NULL, ImmutableList.of(), diff --git a/core/trino-main/src/main/java/io/trino/operator/scalar/ZipFunction.java b/core/trino-main/src/main/java/io/trino/operator/scalar/ZipFunction.java index 33e8b1137e29..e792ffc75ac5 100644 --- a/core/trino-main/src/main/java/io/trino/operator/scalar/ZipFunction.java +++ b/core/trino-main/src/main/java/io/trino/operator/scalar/ZipFunction.java @@ -14,13 +14,13 @@ package io.trino.operator.scalar; import io.trino.annotation.UsedByGeneratedCode; -import io.trino.metadata.BoundSignature; -import io.trino.metadata.FunctionMetadata; -import io.trino.metadata.Signature; import io.trino.metadata.SqlScalarFunction; -import io.trino.metadata.TypeVariableConstraint; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.function.BoundSignature; +import io.trino.spi.function.FunctionMetadata; +import io.trino.spi.function.Signature; +import io.trino.spi.function.TypeVariableConstraint; import io.trino.spi.type.ArrayType; import io.trino.spi.type.RowType; import io.trino.spi.type.Type; @@ -80,7 +80,7 @@ private ZipFunction(List typeParameters) } @Override - protected ScalarFunctionImplementation specialize(BoundSignature boundSignature) + protected SpecializedSqlScalarFunction specialize(BoundSignature boundSignature) { List types = boundSignature.getArgumentTypes().stream() .map(ArrayType.class::cast) @@ -88,7 +88,7 @@ protected ScalarFunctionImplementation specialize(BoundSignature boundSignature) .collect(toImmutableList()); List> javaArgumentTypes = nCopies(types.size(), Block.class); MethodHandle methodHandle = METHOD_HANDLE.bindTo(types).asVarargsCollector(Block[].class).asType(methodType(Block.class, javaArgumentTypes)); - return new ChoicesScalarFunctionImplementation( + return new ChoicesSpecializedSqlScalarFunction( boundSignature, FAIL_ON_NULL, nCopies(types.size(), NEVER_NULL), diff --git a/core/trino-main/src/main/java/io/trino/operator/scalar/ZipWithFunction.java b/core/trino-main/src/main/java/io/trino/operator/scalar/ZipWithFunction.java index 078ae255ede6..9c686c7c68fa 100644 --- a/core/trino-main/src/main/java/io/trino/operator/scalar/ZipWithFunction.java +++ b/core/trino-main/src/main/java/io/trino/operator/scalar/ZipWithFunction.java @@ -14,13 +14,13 @@ package io.trino.operator.scalar; import com.google.common.collect.ImmutableList; -import io.trino.metadata.BoundSignature; -import io.trino.metadata.FunctionMetadata; -import io.trino.metadata.Signature; import io.trino.metadata.SqlScalarFunction; import io.trino.spi.PageBuilder; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.function.BoundSignature; +import io.trino.spi.function.FunctionMetadata; +import io.trino.spi.function.Signature; import io.trino.spi.type.ArrayType; import io.trino.spi.type.Type; import io.trino.spi.type.TypeSignature; @@ -67,13 +67,13 @@ private ZipWithFunction() } @Override - protected ScalarFunctionImplementation specialize(BoundSignature boundSignature) + protected SpecializedSqlScalarFunction specialize(BoundSignature boundSignature) { Type leftElementType = ((ArrayType) boundSignature.getArgumentType(0)).getElementType(); Type rightElementType = ((ArrayType) boundSignature.getArgumentType(1)).getElementType(); Type outputElementType = ((ArrayType) boundSignature.getReturnType()).getElementType(); ArrayType outputArrayType = new ArrayType(outputElementType); - return new ChoicesScalarFunctionImplementation( + return new ChoicesSpecializedSqlScalarFunction( boundSignature, FAIL_ON_NULL, ImmutableList.of(NEVER_NULL, NEVER_NULL, FUNCTION), diff --git a/core/trino-main/src/main/java/io/trino/operator/scalar/annotations/ParametricScalarImplementation.java b/core/trino-main/src/main/java/io/trino/operator/scalar/annotations/ParametricScalarImplementation.java index fff0a6588d04..827061ca2f66 100644 --- a/core/trino-main/src/main/java/io/trino/operator/scalar/annotations/ParametricScalarImplementation.java +++ b/core/trino-main/src/main/java/io/trino/operator/scalar/annotations/ParametricScalarImplementation.java @@ -18,25 +18,25 @@ import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import com.google.common.primitives.Primitives; -import io.trino.metadata.BoundSignature; import io.trino.metadata.FunctionBinding; -import io.trino.metadata.FunctionDependencies; -import io.trino.metadata.FunctionNullability; -import io.trino.metadata.Signature; import io.trino.operator.ParametricImplementation; import io.trino.operator.annotations.FunctionsParserHelper; import io.trino.operator.annotations.ImplementationDependency; -import io.trino.operator.scalar.ChoicesScalarFunctionImplementation; -import io.trino.operator.scalar.ChoicesScalarFunctionImplementation.ScalarImplementationChoice; -import io.trino.operator.scalar.ScalarFunctionImplementation; +import io.trino.operator.scalar.ChoicesSpecializedSqlScalarFunction; +import io.trino.operator.scalar.ChoicesSpecializedSqlScalarFunction.ScalarImplementationChoice; +import io.trino.operator.scalar.SpecializedSqlScalarFunction; import io.trino.spi.block.Block; import io.trino.spi.connector.ConnectorSession; import io.trino.spi.function.BlockIndex; import io.trino.spi.function.BlockPosition; +import io.trino.spi.function.BoundSignature; +import io.trino.spi.function.FunctionDependencies; +import io.trino.spi.function.FunctionNullability; import io.trino.spi.function.InOut; import io.trino.spi.function.InvocationConvention.InvocationArgumentConvention; import io.trino.spi.function.InvocationConvention.InvocationReturnConvention; import io.trino.spi.function.IsNull; +import io.trino.spi.function.Signature; import io.trino.spi.function.SqlNullable; import io.trino.spi.function.SqlType; import io.trino.spi.function.TypeParameter; @@ -146,7 +146,7 @@ public FunctionNullability getFunctionNullability() return functionNullability; } - public Optional specialize(FunctionBinding functionBinding, FunctionDependencies functionDependencies) + public Optional specialize(FunctionBinding functionBinding, FunctionDependencies functionDependencies) { List implementationChoices = new ArrayList<>(); for (Map.Entry> entry : specializedTypeParameters.entrySet()) { @@ -198,7 +198,7 @@ public Optional specialize(FunctionBinding functio boundMethodHandle.asType(javaMethodType(choice, boundSignature)), boundConstructor)); } - return Optional.of(new ChoicesScalarFunctionImplementation(boundSignature, implementationChoices)); + return Optional.of(new ChoicesSpecializedSqlScalarFunction(boundSignature, implementationChoices)); } @Override diff --git a/core/trino-main/src/main/java/io/trino/operator/scalar/annotations/ScalarFromAnnotationsParser.java b/core/trino-main/src/main/java/io/trino/operator/scalar/annotations/ScalarFromAnnotationsParser.java index febd5752dfe1..886e9b40b268 100644 --- a/core/trino-main/src/main/java/io/trino/operator/scalar/annotations/ScalarFromAnnotationsParser.java +++ b/core/trino-main/src/main/java/io/trino/operator/scalar/annotations/ScalarFromAnnotationsParser.java @@ -15,7 +15,6 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; -import io.trino.metadata.Signature; import io.trino.metadata.SqlScalarFunction; import io.trino.operator.ParametricImplementationsGroup; import io.trino.operator.annotations.FunctionsParserHelper; @@ -23,6 +22,7 @@ import io.trino.operator.scalar.annotations.ParametricScalarImplementation.SpecializedSignature; import io.trino.spi.function.ScalarFunction; import io.trino.spi.function.ScalarOperator; +import io.trino.spi.function.Signature; import io.trino.spi.function.SqlType; import java.lang.reflect.Constructor; diff --git a/core/trino-main/src/main/java/io/trino/operator/scalar/annotations/ScalarImplementationHeader.java b/core/trino-main/src/main/java/io/trino/operator/scalar/annotations/ScalarImplementationHeader.java index 15ddb0b23ee0..57712a2a7740 100644 --- a/core/trino-main/src/main/java/io/trino/operator/scalar/annotations/ScalarImplementationHeader.java +++ b/core/trino-main/src/main/java/io/trino/operator/scalar/annotations/ScalarImplementationHeader.java @@ -27,7 +27,7 @@ import static com.google.common.base.CaseFormat.LOWER_CAMEL; import static com.google.common.base.CaseFormat.LOWER_UNDERSCORE; import static com.google.common.base.Preconditions.checkArgument; -import static io.trino.metadata.Signature.mangleOperatorName; +import static io.trino.metadata.OperatorNameUtil.mangleOperatorName; import static io.trino.operator.annotations.FunctionsParserHelper.parseDescription; import static java.util.Objects.requireNonNull; diff --git a/core/trino-main/src/main/java/io/trino/operator/scalar/json/JsonArrayFunction.java b/core/trino-main/src/main/java/io/trino/operator/scalar/json/JsonArrayFunction.java index 697dfaf31c80..07cf06316ae3 100644 --- a/core/trino-main/src/main/java/io/trino/operator/scalar/json/JsonArrayFunction.java +++ b/core/trino-main/src/main/java/io/trino/operator/scalar/json/JsonArrayFunction.java @@ -20,14 +20,14 @@ import com.google.common.collect.ImmutableList; import io.trino.annotation.UsedByGeneratedCode; import io.trino.json.ir.TypedValue; -import io.trino.metadata.BoundSignature; -import io.trino.metadata.FunctionMetadata; -import io.trino.metadata.Signature; import io.trino.metadata.SqlScalarFunction; -import io.trino.operator.scalar.ChoicesScalarFunctionImplementation; -import io.trino.operator.scalar.ScalarFunctionImplementation; +import io.trino.operator.scalar.ChoicesSpecializedSqlScalarFunction; +import io.trino.operator.scalar.SpecializedSqlScalarFunction; import io.trino.spi.TrinoException; import io.trino.spi.block.Block; +import io.trino.spi.function.BoundSignature; +import io.trino.spi.function.FunctionMetadata; +import io.trino.spi.function.Signature; import io.trino.spi.type.RowType; import io.trino.spi.type.Type; import io.trino.spi.type.TypeSignature; @@ -73,12 +73,12 @@ private JsonArrayFunction() } @Override - protected ScalarFunctionImplementation specialize(BoundSignature boundSignature) + protected SpecializedSqlScalarFunction specialize(BoundSignature boundSignature) { RowType elementsRowType = (RowType) boundSignature.getArgumentType(0); MethodHandle methodHandle = METHOD_HANDLE .bindTo(elementsRowType); - return new ChoicesScalarFunctionImplementation( + return new ChoicesSpecializedSqlScalarFunction( boundSignature, FAIL_ON_NULL, ImmutableList.of(BOXED_NULLABLE, NEVER_NULL), diff --git a/core/trino-main/src/main/java/io/trino/operator/scalar/json/JsonExistsFunction.java b/core/trino-main/src/main/java/io/trino/operator/scalar/json/JsonExistsFunction.java index 7aa986c6db55..f11acf049410 100644 --- a/core/trino-main/src/main/java/io/trino/operator/scalar/json/JsonExistsFunction.java +++ b/core/trino-main/src/main/java/io/trino/operator/scalar/json/JsonExistsFunction.java @@ -20,17 +20,17 @@ import io.trino.json.JsonPathInvocationContext; import io.trino.json.PathEvaluationError; import io.trino.json.ir.IrJsonPath; -import io.trino.metadata.BoundSignature; import io.trino.metadata.FunctionManager; -import io.trino.metadata.FunctionMetadata; import io.trino.metadata.Metadata; -import io.trino.metadata.Signature; import io.trino.metadata.SqlScalarFunction; -import io.trino.operator.scalar.ChoicesScalarFunctionImplementation; -import io.trino.operator.scalar.ScalarFunctionImplementation; +import io.trino.operator.scalar.ChoicesSpecializedSqlScalarFunction; +import io.trino.operator.scalar.SpecializedSqlScalarFunction; import io.trino.spi.TrinoException; import io.trino.spi.block.Block; import io.trino.spi.connector.ConnectorSession; +import io.trino.spi.function.BoundSignature; +import io.trino.spi.function.FunctionMetadata; +import io.trino.spi.function.Signature; import io.trino.spi.type.Type; import io.trino.spi.type.TypeManager; import io.trino.spi.type.TypeSignature; @@ -86,7 +86,7 @@ public JsonExistsFunction(FunctionManager functionManager, Metadata metadata, Ty } @Override - protected ScalarFunctionImplementation specialize(BoundSignature boundSignature) + protected SpecializedSqlScalarFunction specialize(BoundSignature boundSignature) { Type parametersRowType = boundSignature.getArgumentType(2); MethodHandle methodHandle = METHOD_HANDLE @@ -95,7 +95,7 @@ protected ScalarFunctionImplementation specialize(BoundSignature boundSignature) .bindTo(typeManager) .bindTo(parametersRowType); MethodHandle instanceFactory = constructorMethodHandle(JsonPathInvocationContext.class); - return new ChoicesScalarFunctionImplementation( + return new ChoicesSpecializedSqlScalarFunction( boundSignature, NULLABLE_RETURN, ImmutableList.of(BOXED_NULLABLE, BOXED_NULLABLE, BOXED_NULLABLE, NEVER_NULL), diff --git a/core/trino-main/src/main/java/io/trino/operator/scalar/json/JsonObjectFunction.java b/core/trino-main/src/main/java/io/trino/operator/scalar/json/JsonObjectFunction.java index 862c2a12487e..b42a9aeb9e09 100644 --- a/core/trino-main/src/main/java/io/trino/operator/scalar/json/JsonObjectFunction.java +++ b/core/trino-main/src/main/java/io/trino/operator/scalar/json/JsonObjectFunction.java @@ -21,14 +21,14 @@ import io.airlift.slice.Slice; import io.trino.annotation.UsedByGeneratedCode; import io.trino.json.ir.TypedValue; -import io.trino.metadata.BoundSignature; -import io.trino.metadata.FunctionMetadata; -import io.trino.metadata.Signature; import io.trino.metadata.SqlScalarFunction; -import io.trino.operator.scalar.ChoicesScalarFunctionImplementation; -import io.trino.operator.scalar.ScalarFunctionImplementation; +import io.trino.operator.scalar.ChoicesSpecializedSqlScalarFunction; +import io.trino.operator.scalar.SpecializedSqlScalarFunction; import io.trino.spi.TrinoException; import io.trino.spi.block.Block; +import io.trino.spi.function.BoundSignature; +import io.trino.spi.function.FunctionMetadata; +import io.trino.spi.function.Signature; import io.trino.spi.type.RowType; import io.trino.spi.type.Type; import io.trino.spi.type.TypeSignature; @@ -79,7 +79,7 @@ private JsonObjectFunction() } @Override - protected ScalarFunctionImplementation specialize(BoundSignature boundSignature) + protected SpecializedSqlScalarFunction specialize(BoundSignature boundSignature) { RowType keysRowType = (RowType) boundSignature.getArgumentType(0); RowType valuesRowType = (RowType) boundSignature.getArgumentType(1); @@ -87,7 +87,7 @@ protected ScalarFunctionImplementation specialize(BoundSignature boundSignature) MethodHandle methodHandle = METHOD_HANDLE .bindTo(keysRowType) .bindTo(valuesRowType); - return new ChoicesScalarFunctionImplementation( + return new ChoicesSpecializedSqlScalarFunction( boundSignature, FAIL_ON_NULL, ImmutableList.of(BOXED_NULLABLE, BOXED_NULLABLE, NEVER_NULL, NEVER_NULL), diff --git a/core/trino-main/src/main/java/io/trino/operator/scalar/json/JsonQueryFunction.java b/core/trino-main/src/main/java/io/trino/operator/scalar/json/JsonQueryFunction.java index eb6a7abed3f2..f21a3678aa7d 100644 --- a/core/trino-main/src/main/java/io/trino/operator/scalar/json/JsonQueryFunction.java +++ b/core/trino-main/src/main/java/io/trino/operator/scalar/json/JsonQueryFunction.java @@ -24,17 +24,17 @@ import io.trino.json.PathEvaluationError; import io.trino.json.ir.IrJsonPath; import io.trino.json.ir.TypedValue; -import io.trino.metadata.BoundSignature; import io.trino.metadata.FunctionManager; -import io.trino.metadata.FunctionMetadata; import io.trino.metadata.Metadata; -import io.trino.metadata.Signature; import io.trino.metadata.SqlScalarFunction; -import io.trino.operator.scalar.ChoicesScalarFunctionImplementation; -import io.trino.operator.scalar.ScalarFunctionImplementation; +import io.trino.operator.scalar.ChoicesSpecializedSqlScalarFunction; +import io.trino.operator.scalar.SpecializedSqlScalarFunction; import io.trino.spi.TrinoException; import io.trino.spi.block.Block; import io.trino.spi.connector.ConnectorSession; +import io.trino.spi.function.BoundSignature; +import io.trino.spi.function.FunctionMetadata; +import io.trino.spi.function.Signature; import io.trino.spi.type.Type; import io.trino.spi.type.TypeManager; import io.trino.spi.type.TypeSignature; @@ -102,7 +102,7 @@ public JsonQueryFunction(FunctionManager functionManager, Metadata metadata, Typ } @Override - protected ScalarFunctionImplementation specialize(BoundSignature boundSignature) + protected SpecializedSqlScalarFunction specialize(BoundSignature boundSignature) { Type parametersRowType = boundSignature.getArgumentType(2); MethodHandle methodHandle = METHOD_HANDLE @@ -111,7 +111,7 @@ protected ScalarFunctionImplementation specialize(BoundSignature boundSignature) .bindTo(typeManager) .bindTo(parametersRowType); MethodHandle instanceFactory = constructorMethodHandle(JsonPathInvocationContext.class); - return new ChoicesScalarFunctionImplementation( + return new ChoicesSpecializedSqlScalarFunction( boundSignature, NULLABLE_RETURN, ImmutableList.of(BOXED_NULLABLE, BOXED_NULLABLE, BOXED_NULLABLE, NEVER_NULL, NEVER_NULL, NEVER_NULL), diff --git a/core/trino-main/src/main/java/io/trino/operator/scalar/json/JsonValueFunction.java b/core/trino-main/src/main/java/io/trino/operator/scalar/json/JsonValueFunction.java index d337b5a33998..a2f848f434de 100644 --- a/core/trino-main/src/main/java/io/trino/operator/scalar/json/JsonValueFunction.java +++ b/core/trino-main/src/main/java/io/trino/operator/scalar/json/JsonValueFunction.java @@ -25,19 +25,19 @@ import io.trino.json.ir.IrJsonPath; import io.trino.json.ir.SqlJsonLiteralConverter.JsonLiteralConversionError; import io.trino.json.ir.TypedValue; -import io.trino.metadata.BoundSignature; import io.trino.metadata.FunctionManager; -import io.trino.metadata.FunctionMetadata; import io.trino.metadata.Metadata; import io.trino.metadata.OperatorNotFoundException; import io.trino.metadata.ResolvedFunction; -import io.trino.metadata.Signature; import io.trino.metadata.SqlScalarFunction; -import io.trino.operator.scalar.ChoicesScalarFunctionImplementation; -import io.trino.operator.scalar.ScalarFunctionImplementation; +import io.trino.operator.scalar.ChoicesSpecializedSqlScalarFunction; +import io.trino.operator.scalar.SpecializedSqlScalarFunction; import io.trino.spi.TrinoException; import io.trino.spi.block.Block; import io.trino.spi.connector.ConnectorSession; +import io.trino.spi.function.BoundSignature; +import io.trino.spi.function.FunctionMetadata; +import io.trino.spi.function.Signature; import io.trino.spi.type.Type; import io.trino.spi.type.TypeManager; import io.trino.spi.type.TypeSignature; @@ -112,7 +112,7 @@ public JsonValueFunction(FunctionManager functionManager, Metadata metadata, Typ } @Override - protected ScalarFunctionImplementation specialize(BoundSignature boundSignature) + protected SpecializedSqlScalarFunction specialize(BoundSignature boundSignature) { Type parametersRowType = boundSignature.getArgumentType(2); Type returnType = boundSignature.getReturnType(); @@ -140,7 +140,7 @@ else if (returnType.getJavaType().equals(Slice.class)) { .bindTo(parametersRowType) .bindTo(returnType); MethodHandle instanceFactory = constructorMethodHandle(JsonPathInvocationContext.class); - return new ChoicesScalarFunctionImplementation( + return new ChoicesSpecializedSqlScalarFunction( boundSignature, NULLABLE_RETURN, ImmutableList.of(BOXED_NULLABLE, BOXED_NULLABLE, BOXED_NULLABLE, NEVER_NULL, BOXED_NULLABLE, NEVER_NULL, BOXED_NULLABLE), diff --git a/core/trino-main/src/main/java/io/trino/operator/window/AggregationWindowFunctionSupplier.java b/core/trino-main/src/main/java/io/trino/operator/window/AggregationWindowFunctionSupplier.java index 8f38d11dd62b..a30dc2418965 100644 --- a/core/trino-main/src/main/java/io/trino/operator/window/AggregationWindowFunctionSupplier.java +++ b/core/trino-main/src/main/java/io/trino/operator/window/AggregationWindowFunctionSupplier.java @@ -13,11 +13,12 @@ */ package io.trino.operator.window; -import io.trino.metadata.BoundSignature; -import io.trino.metadata.FunctionNullability; -import io.trino.operator.aggregation.AggregationMetadata; import io.trino.operator.aggregation.WindowAccumulator; +import io.trino.spi.function.AggregationImplementation; +import io.trino.spi.function.BoundSignature; +import io.trino.spi.function.FunctionNullability; import io.trino.spi.function.WindowFunction; +import io.trino.spi.function.WindowFunctionSupplier; import java.lang.reflect.Constructor; import java.util.List; @@ -33,13 +34,13 @@ public class AggregationWindowFunctionSupplier private final boolean hasRemoveInput; private final List> lambdaInterfaces; - public AggregationWindowFunctionSupplier(BoundSignature boundSignature, AggregationMetadata aggregationMetadata, FunctionNullability functionNullability) + public AggregationWindowFunctionSupplier(BoundSignature boundSignature, AggregationImplementation aggregationImplementation, FunctionNullability functionNullability) { requireNonNull(boundSignature, "boundSignature is null"); - requireNonNull(aggregationMetadata, "aggregationMetadata is null"); - constructor = generateWindowAccumulatorClass(boundSignature, aggregationMetadata, functionNullability); - hasRemoveInput = aggregationMetadata.getRemoveInputFunction().isPresent(); - lambdaInterfaces = aggregationMetadata.getLambdaInterfaces(); + requireNonNull(aggregationImplementation, "aggregationMetadata is null"); + constructor = generateWindowAccumulatorClass(boundSignature, aggregationImplementation, functionNullability); + hasRemoveInput = aggregationImplementation.getRemoveInputFunction().isPresent(); + lambdaInterfaces = aggregationImplementation.getLambdaInterfaces(); } @Override diff --git a/core/trino-main/src/main/java/io/trino/operator/window/ReflectionWindowFunctionSupplier.java b/core/trino-main/src/main/java/io/trino/operator/window/ReflectionWindowFunctionSupplier.java index c99e5e431c0e..8a7fc503910e 100644 --- a/core/trino-main/src/main/java/io/trino/operator/window/ReflectionWindowFunctionSupplier.java +++ b/core/trino-main/src/main/java/io/trino/operator/window/ReflectionWindowFunctionSupplier.java @@ -16,6 +16,7 @@ import com.google.common.base.VerifyException; import com.google.common.collect.ImmutableList; import io.trino.spi.function.WindowFunction; +import io.trino.spi.function.WindowFunctionSupplier; import java.lang.reflect.Constructor; import java.util.List; diff --git a/core/trino-main/src/main/java/io/trino/operator/window/SqlWindowFunction.java b/core/trino-main/src/main/java/io/trino/operator/window/SqlWindowFunction.java index 8f0344f1a41c..f69f9dd9728f 100644 --- a/core/trino-main/src/main/java/io/trino/operator/window/SqlWindowFunction.java +++ b/core/trino-main/src/main/java/io/trino/operator/window/SqlWindowFunction.java @@ -13,11 +13,12 @@ */ package io.trino.operator.window; -import io.trino.metadata.BoundSignature; -import io.trino.metadata.FunctionDependencies; -import io.trino.metadata.FunctionMetadata; -import io.trino.metadata.Signature; import io.trino.metadata.SqlFunction; +import io.trino.spi.function.BoundSignature; +import io.trino.spi.function.FunctionDependencies; +import io.trino.spi.function.FunctionMetadata; +import io.trino.spi.function.Signature; +import io.trino.spi.function.WindowFunctionSupplier; import java.util.Optional; diff --git a/core/trino-main/src/main/java/io/trino/operator/window/WindowAnnotationsParser.java b/core/trino-main/src/main/java/io/trino/operator/window/WindowAnnotationsParser.java index 4960beb616e4..652d2288b05a 100644 --- a/core/trino-main/src/main/java/io/trino/operator/window/WindowAnnotationsParser.java +++ b/core/trino-main/src/main/java/io/trino/operator/window/WindowAnnotationsParser.java @@ -14,8 +14,8 @@ package io.trino.operator.window; import com.google.common.collect.ImmutableSet; -import io.trino.metadata.Signature; import io.trino.spi.function.Description; +import io.trino.spi.function.Signature; import io.trino.spi.function.WindowFunction; import io.trino.spi.function.WindowFunctionSignature; diff --git a/core/trino-main/src/main/java/io/trino/operator/window/pattern/MatchAggregation.java b/core/trino-main/src/main/java/io/trino/operator/window/pattern/MatchAggregation.java index 0c67e3fdd02c..9ab111087620 100644 --- a/core/trino-main/src/main/java/io/trino/operator/window/pattern/MatchAggregation.java +++ b/core/trino-main/src/main/java/io/trino/operator/window/pattern/MatchAggregation.java @@ -15,7 +15,6 @@ import io.trino.memory.context.AggregatedMemoryContext; import io.trino.memory.context.LocalMemoryContext; -import io.trino.metadata.BoundSignature; import io.trino.operator.aggregation.WindowAccumulator; import io.trino.operator.window.AggregationWindowFunctionSupplier; import io.trino.operator.window.MappedWindowIndex; @@ -24,6 +23,7 @@ import io.trino.spi.TrinoException; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.function.BoundSignature; import java.util.List; import java.util.function.Supplier; diff --git a/core/trino-main/src/main/java/io/trino/sql/InterpretedFunctionInvoker.java b/core/trino-main/src/main/java/io/trino/sql/InterpretedFunctionInvoker.java index 15f8f772349e..d6ecb5a0ca54 100644 --- a/core/trino-main/src/main/java/io/trino/sql/InterpretedFunctionInvoker.java +++ b/core/trino-main/src/main/java/io/trino/sql/InterpretedFunctionInvoker.java @@ -14,14 +14,14 @@ package io.trino.sql; import com.google.common.collect.ImmutableList; -import io.trino.metadata.BoundSignature; -import io.trino.metadata.FunctionInvoker; import io.trino.metadata.FunctionManager; -import io.trino.metadata.FunctionNullability; import io.trino.metadata.ResolvedFunction; import io.trino.spi.connector.ConnectorSession; +import io.trino.spi.function.BoundSignature; +import io.trino.spi.function.FunctionNullability; import io.trino.spi.function.InvocationConvention; import io.trino.spi.function.InvocationConvention.InvocationArgumentConvention; +import io.trino.spi.function.ScalarFunctionImplementation; import io.trino.spi.type.Type; import io.trino.type.FunctionType; @@ -60,15 +60,15 @@ public Object invoke(ResolvedFunction function, ConnectorSession session, Object */ public Object invoke(ResolvedFunction function, ConnectorSession session, List arguments) { - FunctionInvoker invoker = functionManager.getScalarFunctionInvoker(function, getInvocationConvention(function.getSignature(), function.getFunctionNullability())); - MethodHandle method = invoker.getMethodHandle(); + ScalarFunctionImplementation implementation = functionManager.getScalarFunctionImplementation(function, getInvocationConvention(function.getSignature(), function.getFunctionNullability())); + MethodHandle method = implementation.getMethodHandle(); List actualArguments = new ArrayList<>(); // handle function on instance method, to allow use of fields - if (invoker.getInstanceFactory().isPresent()) { + if (implementation.getInstanceFactory().isPresent()) { try { - actualArguments.add(invoker.getInstanceFactory().get().invoke()); + actualArguments.add(implementation.getInstanceFactory().get().invoke()); } catch (Throwable throwable) { throw propagate(throwable); @@ -90,7 +90,7 @@ public Object invoke(ResolvedFunction function, ConnectorSession session, List resolveTableFunction(TableFunctionInvocation node) { - for (CatalogSchemaFunctionName name : toPath(session, node.getName())) { + for (CatalogSchemaFunctionName name : toPath(session, toQualifiedFunctionName(node.getName()))) { CatalogHandle catalogHandle = getRequiredCatalogHandle(metadata, session, node, name.getCatalogName()); Optional resolved = tableFunctionRegistry.resolve(catalogHandle, name.getSchemaFunctionName()); if (resolved.isPresent()) { diff --git a/core/trino-main/src/main/java/io/trino/sql/gen/BytecodeGeneratorContext.java b/core/trino-main/src/main/java/io/trino/sql/gen/BytecodeGeneratorContext.java index 7eda9a0c8d9b..dc373ac44d93 100644 --- a/core/trino-main/src/main/java/io/trino/sql/gen/BytecodeGeneratorContext.java +++ b/core/trino-main/src/main/java/io/trino/sql/gen/BytecodeGeneratorContext.java @@ -16,10 +16,10 @@ import io.airlift.bytecode.BytecodeNode; import io.airlift.bytecode.Scope; import io.airlift.bytecode.Variable; -import io.trino.metadata.FunctionInvoker; import io.trino.metadata.FunctionManager; import io.trino.metadata.ResolvedFunction; import io.trino.spi.function.InvocationConvention; +import io.trino.spi.function.ScalarFunctionImplementation; import io.trino.sql.relational.RowExpression; import java.lang.invoke.MethodHandle; @@ -77,9 +77,9 @@ public BytecodeNode generate(RowExpression expression) return rowExpressionCompiler.compile(expression, scope); } - public FunctionInvoker getScalarFunctionInvoker(ResolvedFunction resolvedFunction, InvocationConvention invocationConvention) + public ScalarFunctionImplementation getScalarFunctionImplementation(ResolvedFunction resolvedFunction, InvocationConvention invocationConvention) { - return functionManager.getScalarFunctionInvoker(resolvedFunction, invocationConvention); + return functionManager.getScalarFunctionImplementation(resolvedFunction, invocationConvention); } /** diff --git a/core/trino-main/src/main/java/io/trino/sql/gen/BytecodeUtils.java b/core/trino-main/src/main/java/io/trino/sql/gen/BytecodeUtils.java index e3d10fef72f7..672698e39634 100644 --- a/core/trino-main/src/main/java/io/trino/sql/gen/BytecodeUtils.java +++ b/core/trino-main/src/main/java/io/trino/sql/gen/BytecodeUtils.java @@ -25,16 +25,16 @@ import io.airlift.bytecode.expression.BytecodeExpression; import io.airlift.bytecode.instruction.LabelNode; import io.airlift.slice.Slice; -import io.trino.metadata.BoundSignature; -import io.trino.metadata.FunctionInvoker; import io.trino.metadata.FunctionManager; -import io.trino.metadata.FunctionNullability; import io.trino.metadata.ResolvedFunction; import io.trino.spi.block.BlockBuilder; import io.trino.spi.connector.ConnectorSession; +import io.trino.spi.function.BoundSignature; +import io.trino.spi.function.FunctionNullability; import io.trino.spi.function.InOut; import io.trino.spi.function.InvocationConvention; import io.trino.spi.function.InvocationConvention.InvocationArgumentConvention; +import io.trino.spi.function.ScalarFunctionImplementation; import io.trino.spi.type.Type; import io.trino.sql.gen.InputReferenceCompiler.InputReferenceNode; import io.trino.type.FunctionType; @@ -166,7 +166,7 @@ public static BytecodeNode generateInvocation( scope, resolvedFunction.getSignature().getName(), resolvedFunction.getFunctionNullability(), - invocationConvention -> functionManager.getScalarFunctionInvoker(resolvedFunction, invocationConvention), + invocationConvention -> functionManager.getScalarFunctionImplementation(resolvedFunction, invocationConvention), arguments, binder); } @@ -175,7 +175,7 @@ public static BytecodeNode generateInvocation( Scope scope, String functionName, FunctionNullability functionNullability, - Function functionInvokerProvider, + Function functionImplementationProvider, List arguments, CallSiteBinder binder) { @@ -184,7 +184,7 @@ public static BytecodeNode generateInvocation( functionName, functionNullability, Collections.nCopies(arguments.size(), false), - functionInvokerProvider, + functionImplementationProvider, instanceFactory -> { throw new IllegalArgumentException("Simple method invocation can not be used with functions that require an instance factory"); }, @@ -217,7 +217,7 @@ public static BytecodeNode generateFullInvocation( resolvedFunction.getSignature().getArgumentTypes().stream() .map(FunctionType.class::isInstance) .collect(toImmutableList()), - invocationConvention -> functionManager.getScalarFunctionInvoker(resolvedFunction, invocationConvention), + invocationConvention -> functionManager.getScalarFunctionImplementation(resolvedFunction, invocationConvention), instanceFactory, argumentCompilers, binder); @@ -228,7 +228,7 @@ private static BytecodeNode generateFullInvocation( String functionName, FunctionNullability functionNullability, List argumentIsFunctionType, - Function functionInvokerProvider, + Function functionImplementationProvider, Function instanceFactory, List>, BytecodeNode>> argumentCompilers, CallSiteBinder binder) @@ -253,15 +253,15 @@ private static BytecodeNode generateFullInvocation( functionNullability.isReturnNullable() ? NULLABLE_RETURN : FAIL_ON_NULL, true, true); - FunctionInvoker functionInvoker = functionInvokerProvider.apply(invocationConvention); + ScalarFunctionImplementation implementation = functionImplementationProvider.apply(invocationConvention); - Binding binding = binder.bind(functionInvoker.getMethodHandle()); + Binding binding = binder.bind(implementation.getMethodHandle()); LabelNode end = new LabelNode("end"); BytecodeBlock block = new BytecodeBlock() .setDescription("invoke " + functionName); - Optional instance = functionInvoker.getInstanceFactory() + Optional instance = implementation.getInstanceFactory() .map(instanceFactory); // Index of current parameter in the MethodHandle @@ -283,7 +283,7 @@ private static BytecodeNode generateFullInvocation( Class type = methodType.parameterArray()[currentParameterIndex]; stackTypes.add(type); if (instance.isPresent() && !instanceIsBound) { - checkState(type.equals(functionInvoker.getInstanceFactory().get().type().returnType()), "Mismatched type for instance parameter"); + checkState(type.equals(implementation.getInstanceFactory().get().type().returnType()), "Mismatched type for instance parameter"); block.append(instance.get()); instanceIsBound = true; } @@ -330,7 +330,7 @@ else if (type == ConnectorSession.class) { currentParameterIndex++; break; case FUNCTION: - Class lambdaInterface = functionInvoker.getLambdaInterfaces().get(lambdaArgumentIndex); + Class lambdaInterface = implementation.getLambdaInterfaces().get(lambdaArgumentIndex); block.append(argumentCompilers.get(realParameterIndex).apply(Optional.of(lambdaInterface))); lambdaArgumentIndex++; break; diff --git a/core/trino-main/src/main/java/io/trino/sql/gen/InCodeGenerator.java b/core/trino-main/src/main/java/io/trino/sql/gen/InCodeGenerator.java index 4391d2625247..91e4ee13727e 100644 --- a/core/trino-main/src/main/java/io/trino/sql/gen/InCodeGenerator.java +++ b/core/trino-main/src/main/java/io/trino/sql/gen/InCodeGenerator.java @@ -124,9 +124,9 @@ public BytecodeNode generateExpression(BytecodeGeneratorContext generatorContext SwitchGenerationCase switchGenerationCase = checkSwitchGenerationCase(type, testExpressions); - MethodHandle equalsMethodHandle = generatorContext.getScalarFunctionInvoker(resolvedEqualsFunction, simpleConvention(NULLABLE_RETURN, NEVER_NULL, NEVER_NULL)).getMethodHandle(); - MethodHandle hashCodeMethodHandle = generatorContext.getScalarFunctionInvoker(resolvedHashCodeFunction, simpleConvention(FAIL_ON_NULL, NEVER_NULL)).getMethodHandle(); - MethodHandle indeterminateMethodHandle = generatorContext.getScalarFunctionInvoker(resolvedIsIndeterminate, simpleConvention(FAIL_ON_NULL, NEVER_NULL)).getMethodHandle(); + MethodHandle equalsMethodHandle = generatorContext.getScalarFunctionImplementation(resolvedEqualsFunction, simpleConvention(NULLABLE_RETURN, NEVER_NULL, NEVER_NULL)).getMethodHandle(); + MethodHandle hashCodeMethodHandle = generatorContext.getScalarFunctionImplementation(resolvedHashCodeFunction, simpleConvention(FAIL_ON_NULL, NEVER_NULL)).getMethodHandle(); + MethodHandle indeterminateMethodHandle = generatorContext.getScalarFunctionImplementation(resolvedIsIndeterminate, simpleConvention(FAIL_ON_NULL, NEVER_NULL)).getMethodHandle(); ImmutableListMultimap.Builder hashBucketsBuilder = ImmutableListMultimap.builder(); ImmutableList.Builder defaultBucket = ImmutableList.builder(); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/ExpressionInterpreter.java b/core/trino-main/src/main/java/io/trino/sql/planner/ExpressionInterpreter.java index 50b63e87d68c..8f30611685b6 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/ExpressionInterpreter.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/ExpressionInterpreter.java @@ -22,7 +22,6 @@ import io.trino.Session; import io.trino.execution.warnings.WarningCollector; import io.trino.likematcher.LikeMatcher; -import io.trino.metadata.FunctionNullability; import io.trino.metadata.Metadata; import io.trino.metadata.ResolvedFunction; import io.trino.operator.scalar.ArraySubscriptOperator; @@ -33,6 +32,7 @@ import io.trino.spi.block.RowBlockBuilder; import io.trino.spi.block.SingleRowBlock; import io.trino.spi.connector.ConnectorSession; +import io.trino.spi.function.FunctionNullability; import io.trino.spi.function.InvocationConvention; import io.trino.spi.function.OperatorType; import io.trino.spi.type.ArrayType; @@ -619,8 +619,8 @@ protected Object visitInPredicate(InPredicate node, Object context) set = FastutilSetHelper.toFastutilHashSet( objectSet, type, - plannerContext.getFunctionManager().getScalarFunctionInvoker(metadata.resolveOperator(session, HASH_CODE, ImmutableList.of(type)), simpleConvention(FAIL_ON_NULL, NEVER_NULL)).getMethodHandle(), - plannerContext.getFunctionManager().getScalarFunctionInvoker(metadata.resolveOperator(session, EQUAL, ImmutableList.of(type, type)), simpleConvention(NULLABLE_RETURN, NEVER_NULL, NEVER_NULL)).getMethodHandle()); + plannerContext.getFunctionManager().getScalarFunctionImplementation(metadata.resolveOperator(session, HASH_CODE, ImmutableList.of(type)), simpleConvention(FAIL_ON_NULL, NEVER_NULL)).getMethodHandle(), + plannerContext.getFunctionManager().getScalarFunctionImplementation(metadata.resolveOperator(session, EQUAL, ImmutableList.of(type, type)), simpleConvention(NULLABLE_RETURN, NEVER_NULL, NEVER_NULL)).getMethodHandle()); } inListCache.put(valueList, set); } @@ -743,7 +743,7 @@ protected Object visitArithmeticUnary(ArithmeticUnaryExpression node, Object con case MINUS -> { ResolvedFunction resolvedOperator = metadata.resolveOperator(session, OperatorType.NEGATION, types(node.getValue())); InvocationConvention invocationConvention = new InvocationConvention(ImmutableList.of(NEVER_NULL), FAIL_ON_NULL, true, false); - MethodHandle handle = plannerContext.getFunctionManager().getScalarFunctionInvoker(resolvedOperator, invocationConvention).getMethodHandle(); + MethodHandle handle = plannerContext.getFunctionManager().getScalarFunctionImplementation(resolvedOperator, invocationConvention).getMethodHandle(); if (handle.type().parameterCount() > 0 && handle.type().parameterType(0) == ConnectorSession.class) { handle = handle.bindTo(connectorSession); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java b/core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java index 1867a04ade5b..193545092f29 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java @@ -42,8 +42,6 @@ import io.trino.execution.buffer.OutputBuffer; import io.trino.execution.buffer.PagesSerdeFactory; import io.trino.index.IndexManager; -import io.trino.metadata.BoundSignature; -import io.trino.metadata.FunctionId; import io.trino.metadata.MergeHandle; import io.trino.metadata.Metadata; import io.trino.metadata.ResolvedFunction; @@ -101,7 +99,6 @@ import io.trino.operator.WindowOperator.WindowOperatorFactory; import io.trino.operator.WorkProcessorPipelineSourceOperator; import io.trino.operator.aggregation.AccumulatorFactory; -import io.trino.operator.aggregation.AggregationMetadata; import io.trino.operator.aggregation.AggregatorFactory; import io.trino.operator.aggregation.DistinctAccumulatorFactory; import io.trino.operator.aggregation.OrderedAccumulatorFactory; @@ -138,7 +135,6 @@ import io.trino.operator.window.PartitionerSupplier; import io.trino.operator.window.PatternRecognitionPartitionerSupplier; import io.trino.operator.window.RegularPartitionerSupplier; -import io.trino.operator.window.WindowFunctionSupplier; import io.trino.operator.window.matcher.IrRowPatternToProgramRewriter; import io.trino.operator.window.matcher.Matcher; import io.trino.operator.window.matcher.Program; @@ -161,7 +157,11 @@ import io.trino.spi.connector.DynamicFilter; import io.trino.spi.connector.RecordSet; import io.trino.spi.connector.SortOrder; +import io.trino.spi.function.AggregationImplementation; +import io.trino.spi.function.BoundSignature; +import io.trino.spi.function.FunctionId; import io.trino.spi.function.FunctionKind; +import io.trino.spi.function.WindowFunctionSupplier; import io.trino.spi.predicate.Domain; import io.trino.spi.predicate.NullableValue; import io.trino.spi.type.RowType; @@ -1162,14 +1162,14 @@ private WindowFunctionSupplier getWindowFunctionImplementation(ResolvedFunction { if (resolvedFunction.getFunctionKind() == FunctionKind.AGGREGATE) { return uncheckedCacheGet(aggregationWindowFunctionSupplierCache, new FunctionKey(resolvedFunction.getFunctionId(), resolvedFunction.getSignature()), () -> { - AggregationMetadata aggregationMetadata = plannerContext.getFunctionManager().getAggregateFunctionImplementation(resolvedFunction); + AggregationImplementation aggregationImplementation = plannerContext.getFunctionManager().getAggregationImplementation(resolvedFunction); return new AggregationWindowFunctionSupplier( resolvedFunction.getSignature(), - aggregationMetadata, + aggregationImplementation, resolvedFunction.getFunctionNullability()); }); } - return plannerContext.getFunctionManager().getWindowFunctionImplementation(resolvedFunction); + return plannerContext.getFunctionManager().getWindowFunctionSupplier(resolvedFunction); } @Override @@ -1507,7 +1507,7 @@ else if (matchNumberSymbols.contains(pointer.getInputSymbol())) { boolean classifierInvolved = false; ResolvedFunction resolvedFunction = pointer.getFunction(); - AggregationMetadata aggregationMetadata = plannerContext.getFunctionManager().getAggregateFunctionImplementation(pointer.getFunction()); + AggregationImplementation aggregationImplementation = plannerContext.getFunctionManager().getAggregationImplementation(pointer.getFunction()); ImmutableList.Builder> builder = ImmutableList.builder(); List signatureTypes = resolvedFunction.getSignature().getArgumentTypes(); @@ -1529,7 +1529,7 @@ else if (matchNumberSymbols.contains(pointer.getInputSymbol())) { .collect(toImmutableList()); // TODO when we support lambda arguments: lambda cannot have runtime-evaluated symbols -- add check in the Analyzer - List> lambdaProviders = makeLambdaProviders(lambdaExpressions, aggregationMetadata.getLambdaInterfaces(), functionTypes); + List> lambdaProviders = makeLambdaProviders(lambdaExpressions, aggregationImplementation.getLambdaInterfaces(), functionTypes); // handle non-lambda arguments List valueChannels = new ArrayList<>(); @@ -1577,7 +1577,7 @@ else if (symbol.equals(matchNumberArgumentSymbol)) { new FunctionKey(resolvedFunction.getFunctionId(), resolvedFunction.getSignature()), () -> new AggregationWindowFunctionSupplier( resolvedFunction.getSignature(), - aggregationMetadata, + aggregationImplementation, resolvedFunction.getFunctionNullability())); matchAggregations.add(new MatchAggregationInstantiator( resolvedFunction.getSignature(), @@ -3667,13 +3667,13 @@ private AggregatorFactory buildAggregatorFactory( } ResolvedFunction resolvedFunction = aggregation.getResolvedFunction(); - AggregationMetadata aggregationMetadata = plannerContext.getFunctionManager().getAggregateFunctionImplementation(aggregation.getResolvedFunction()); + AggregationImplementation aggregationImplementation = plannerContext.getFunctionManager().getAggregationImplementation(aggregation.getResolvedFunction()); AccumulatorFactory accumulatorFactory = uncheckedCacheGet( accumulatorFactoryCache, new FunctionKey(resolvedFunction.getFunctionId(), resolvedFunction.getSignature()), () -> generateAccumulatorFactory( resolvedFunction.getSignature(), - aggregationMetadata, + aggregationImplementation, resolvedFunction.getFunctionNullability())); if (aggregation.isDistinct()) { @@ -3720,7 +3720,7 @@ private AggregatorFactory buildAggregatorFactory( pagesIndexFactory); } - ImmutableList intermediateTypes = aggregationMetadata.getAccumulatorStateDescriptors().stream() + ImmutableList intermediateTypes = aggregationImplementation.getAccumulatorStateDescriptors().stream() .map(stateDescriptor -> stateDescriptor.getSerializer().getSerializedType()) .collect(toImmutableList()); Type intermediateType = (intermediateTypes.size() == 1) ? getOnlyElement(intermediateTypes) : RowType.anonymous(intermediateTypes); @@ -3738,7 +3738,7 @@ private AggregatorFactory buildAggregatorFactory( .filter(FunctionType.class::isInstance) .map(FunctionType.class::cast) .collect(toImmutableList()); - List> lambdaProviders = makeLambdaProviders(lambdaExpressions, aggregationMetadata.getLambdaInterfaces(), functionTypes); + List> lambdaProviders = makeLambdaProviders(lambdaExpressions, aggregationImplementation.getLambdaInterfaces(), functionTypes); return new AggregatorFactory( accumulatorFactory, diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PruneCountAggregationOverScalar.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PruneCountAggregationOverScalar.java index b7a398fe73eb..3a3d4d1d04d6 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PruneCountAggregationOverScalar.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PruneCountAggregationOverScalar.java @@ -16,9 +16,9 @@ import com.google.common.collect.ImmutableList; import io.trino.matching.Captures; import io.trino.matching.Pattern; -import io.trino.metadata.FunctionId; import io.trino.metadata.Metadata; import io.trino.metadata.ResolvedFunction; +import io.trino.spi.function.FunctionId; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.iterative.Rule; import io.trino.sql.planner.plan.AggregationNode; diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushAggregationIntoTableScan.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushAggregationIntoTableScan.java index 9c47f1fca32f..c72c82638f4a 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushAggregationIntoTableScan.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushAggregationIntoTableScan.java @@ -19,7 +19,6 @@ import io.trino.matching.Capture; import io.trino.matching.Captures; import io.trino.matching.Pattern; -import io.trino.metadata.BoundSignature; import io.trino.metadata.Metadata; import io.trino.metadata.TableHandle; import io.trino.spi.connector.AggregateFunction; @@ -29,6 +28,7 @@ import io.trino.spi.connector.SortItem; import io.trino.spi.expression.ConnectorExpression; import io.trino.spi.expression.Variable; +import io.trino.spi.function.BoundSignature; import io.trino.spi.predicate.TupleDomain; import io.trino.spi.type.Type; import io.trino.sql.PlannerContext; diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushFilterThroughCountAggregation.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushFilterThroughCountAggregation.java index 97612f283c78..e6db9385aa4a 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushFilterThroughCountAggregation.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushFilterThroughCountAggregation.java @@ -20,7 +20,7 @@ import io.trino.matching.Capture; import io.trino.matching.Captures; import io.trino.matching.Pattern; -import io.trino.metadata.BoundSignature; +import io.trino.spi.function.BoundSignature; import io.trino.spi.predicate.Domain; import io.trino.spi.predicate.Range; import io.trino.spi.predicate.TupleDomain; diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushPartialAggregationThroughExchange.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushPartialAggregationThroughExchange.java index 13c87ead056c..97c28495a32b 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushPartialAggregationThroughExchange.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushPartialAggregationThroughExchange.java @@ -17,8 +17,8 @@ import io.trino.matching.Capture; import io.trino.matching.Captures; import io.trino.matching.Pattern; -import io.trino.metadata.AggregationFunctionMetadata; import io.trino.metadata.ResolvedFunction; +import io.trino.spi.function.AggregationFunctionMetadata; import io.trino.spi.type.RowType; import io.trino.spi.type.Type; import io.trino.sql.PlannerContext; diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ReplaceWindowWithRowNumber.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ReplaceWindowWithRowNumber.java index 85e9744fb755..f0df14377fb0 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ReplaceWindowWithRowNumber.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ReplaceWindowWithRowNumber.java @@ -15,8 +15,8 @@ import io.trino.matching.Captures; import io.trino.matching.Pattern; -import io.trino.metadata.BoundSignature; import io.trino.metadata.Metadata; +import io.trino.spi.function.BoundSignature; import io.trino.sql.planner.iterative.Rule; import io.trino.sql.planner.plan.RowNumberNode; import io.trino.sql.planner.plan.WindowNode; diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/SimplifyCountOverConstant.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/SimplifyCountOverConstant.java index 63bba215993a..f5d06f991790 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/SimplifyCountOverConstant.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/SimplifyCountOverConstant.java @@ -20,9 +20,9 @@ import io.trino.matching.Capture; import io.trino.matching.Captures; import io.trino.matching.Pattern; -import io.trino.metadata.BoundSignature; import io.trino.metadata.ResolvedFunction; import io.trino.security.AllowAllAccessControl; +import io.trino.spi.function.BoundSignature; import io.trino.sql.PlannerContext; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.iterative.Rule; diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/Util.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/Util.java index 9396c15a859c..d3b998c9a82e 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/Util.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/Util.java @@ -16,7 +16,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; import com.google.common.collect.Sets; -import io.trino.metadata.BoundSignature; +import io.trino.spi.function.BoundSignature; import io.trino.sql.planner.PlanNodeIdAllocator; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.SymbolsExtractor; diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/ExpressionEquivalence.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/ExpressionEquivalence.java index 7f9dba2d9e1e..b24b369d675a 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/ExpressionEquivalence.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/ExpressionEquivalence.java @@ -47,7 +47,7 @@ import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.collect.ImmutableList.toImmutableList; -import static io.trino.metadata.Signature.mangleOperatorName; +import static io.trino.metadata.OperatorNameUtil.mangleOperatorName; import static io.trino.spi.function.OperatorType.EQUAL; import static io.trino.spi.function.OperatorType.IS_DISTINCT_FROM; import static io.trino.spi.type.BooleanType.BOOLEAN; diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/HashGenerationOptimizer.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/HashGenerationOptimizer.java index e86a201cb869..83cc0ba8f4f5 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/HashGenerationOptimizer.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/HashGenerationOptimizer.java @@ -83,7 +83,7 @@ import static com.google.common.base.Verify.verify; import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.ImmutableMap.toImmutableMap; -import static io.trino.metadata.Signature.mangleOperatorName; +import static io.trino.metadata.OperatorNameUtil.mangleOperatorName; import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.sql.planner.SystemPartitioningHandle.FIXED_HASH_DISTRIBUTION; import static io.trino.sql.planner.plan.ChildReplacer.replaceChildren; diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/IndexJoinOptimizer.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/IndexJoinOptimizer.java index f0e5c981bc53..bbd64514424f 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/IndexJoinOptimizer.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/IndexJoinOptimizer.java @@ -23,10 +23,10 @@ import io.trino.Session; import io.trino.cost.TableStatsProvider; import io.trino.execution.warnings.WarningCollector; -import io.trino.metadata.BoundSignature; import io.trino.metadata.ResolvedFunction; import io.trino.metadata.ResolvedIndex; import io.trino.spi.connector.ColumnHandle; +import io.trino.spi.function.BoundSignature; import io.trino.spi.predicate.TupleDomain; import io.trino.sql.PlannerContext; import io.trino.sql.planner.DomainTranslator; diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/WindowFilterPushDown.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/WindowFilterPushDown.java index 36465e6f50ed..40d58c7cf932 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/WindowFilterPushDown.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/WindowFilterPushDown.java @@ -17,7 +17,7 @@ import io.trino.Session; import io.trino.cost.TableStatsProvider; import io.trino.execution.warnings.WarningCollector; -import io.trino.metadata.FunctionId; +import io.trino.spi.function.FunctionId; import io.trino.spi.predicate.Domain; import io.trino.spi.predicate.Range; import io.trino.spi.predicate.TupleDomain; diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/plan/AggregationNode.java b/core/trino-main/src/main/java/io/trino/sql/planner/plan/AggregationNode.java index a152160e81c7..df41262bf228 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/plan/AggregationNode.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/plan/AggregationNode.java @@ -20,9 +20,9 @@ import com.google.common.collect.ImmutableSet; import com.google.common.collect.Iterables; import io.trino.Session; -import io.trino.metadata.AggregationFunctionMetadata; import io.trino.metadata.Metadata; import io.trino.metadata.ResolvedFunction; +import io.trino.spi.function.AggregationFunctionMetadata; import io.trino.sql.planner.OrderingScheme; import io.trino.sql.planner.Symbol; import io.trino.sql.tree.Expression; diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/plan/StatisticAggregations.java b/core/trino-main/src/main/java/io/trino/sql/planner/plan/StatisticAggregations.java index d42323ace31b..b2cf2c26786b 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/plan/StatisticAggregations.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/plan/StatisticAggregations.java @@ -18,8 +18,8 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import io.trino.Session; -import io.trino.metadata.AggregationFunctionMetadata; import io.trino.metadata.ResolvedFunction; +import io.trino.spi.function.AggregationFunctionMetadata; import io.trino.spi.type.RowType; import io.trino.spi.type.Type; import io.trino.sql.PlannerContext; diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/sanity/TypeValidator.java b/core/trino-main/src/main/java/io/trino/sql/planner/sanity/TypeValidator.java index 80ea9c5b41d6..2ca5007dbf28 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/sanity/TypeValidator.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/sanity/TypeValidator.java @@ -16,7 +16,7 @@ import com.google.common.collect.ListMultimap; import io.trino.Session; import io.trino.execution.warnings.WarningCollector; -import io.trino.metadata.BoundSignature; +import io.trino.spi.function.BoundSignature; import io.trino.spi.type.Type; import io.trino.spi.type.TypeManager; import io.trino.sql.PlannerContext; diff --git a/core/trino-main/src/main/java/io/trino/sql/relational/SpecialForm.java b/core/trino-main/src/main/java/io/trino/sql/relational/SpecialForm.java index 63b3a087600d..4d3b1ae1e53d 100644 --- a/core/trino-main/src/main/java/io/trino/sql/relational/SpecialForm.java +++ b/core/trino-main/src/main/java/io/trino/sql/relational/SpecialForm.java @@ -15,9 +15,9 @@ import com.google.common.base.Joiner; import com.google.common.collect.ImmutableList; -import io.trino.metadata.BoundSignature; +import io.trino.metadata.OperatorNameUtil; import io.trino.metadata.ResolvedFunction; -import io.trino.metadata.Signature; +import io.trino.spi.function.BoundSignature; import io.trino.spi.function.OperatorType; import io.trino.spi.type.Type; @@ -66,7 +66,7 @@ public List getFunctionDependencies() public ResolvedFunction getOperatorDependency(OperatorType operator) { - String mangleOperatorName = Signature.mangleOperatorName(operator); + String mangleOperatorName = OperatorNameUtil.mangleOperatorName(operator); for (ResolvedFunction function : functionDependencies) { if (function.getSignature().getName().equals(mangleOperatorName)) { return function; @@ -80,7 +80,7 @@ public Optional getCastDependency(Type fromType, Type toType) if (fromType.equals(toType)) { return Optional.empty(); } - BoundSignature boundSignature = new BoundSignature(Signature.mangleOperatorName(CAST), toType, ImmutableList.of(fromType)); + BoundSignature boundSignature = new BoundSignature(OperatorNameUtil.mangleOperatorName(CAST), toType, ImmutableList.of(fromType)); for (ResolvedFunction function : functionDependencies) { if (function.getSignature().equals(boundSignature)) { return Optional.of(function); diff --git a/core/trino-main/src/main/java/io/trino/sql/relational/optimizer/ExpressionOptimizer.java b/core/trino-main/src/main/java/io/trino/sql/relational/optimizer/ExpressionOptimizer.java index 44523535ad3b..2deb772241af 100644 --- a/core/trino-main/src/main/java/io/trino/sql/relational/optimizer/ExpressionOptimizer.java +++ b/core/trino-main/src/main/java/io/trino/sql/relational/optimizer/ExpressionOptimizer.java @@ -16,8 +16,8 @@ import com.google.common.collect.ImmutableList; import io.trino.Session; import io.trino.metadata.FunctionManager; -import io.trino.metadata.FunctionMetadata; import io.trino.metadata.Metadata; +import io.trino.spi.function.FunctionMetadata; import io.trino.spi.type.ArrayType; import io.trino.spi.type.MapType; import io.trino.spi.type.RowType; @@ -39,7 +39,7 @@ import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; import static com.google.common.collect.ImmutableList.toImmutableList; -import static io.trino.metadata.Signature.mangleOperatorName; +import static io.trino.metadata.OperatorNameUtil.mangleOperatorName; import static io.trino.operator.scalar.JsonStringToArrayCast.JSON_STRING_TO_ARRAY_NAME; import static io.trino.operator.scalar.JsonStringToMapCast.JSON_STRING_TO_MAP_NAME; import static io.trino.operator.scalar.JsonStringToRowCast.JSON_STRING_TO_ROW_NAME; diff --git a/core/trino-main/src/main/java/io/trino/sql/rewrite/ShowQueriesRewrite.java b/core/trino-main/src/main/java/io/trino/sql/rewrite/ShowQueriesRewrite.java index 3ab9a417dfb6..92123f6349a0 100644 --- a/core/trino-main/src/main/java/io/trino/sql/rewrite/ShowQueriesRewrite.java +++ b/core/trino-main/src/main/java/io/trino/sql/rewrite/ShowQueriesRewrite.java @@ -25,7 +25,6 @@ import io.trino.execution.warnings.WarningCollector; import io.trino.metadata.CatalogInfo; import io.trino.metadata.ColumnPropertyManager; -import io.trino.metadata.FunctionMetadata; import io.trino.metadata.MaterializedViewDefinition; import io.trino.metadata.MaterializedViewPropertyManager; import io.trino.metadata.Metadata; @@ -45,6 +44,7 @@ import io.trino.spi.connector.ConnectorTableMetadata; import io.trino.spi.connector.SchemaTableName; import io.trino.spi.function.FunctionKind; +import io.trino.spi.function.FunctionMetadata; import io.trino.spi.security.PrincipalType; import io.trino.spi.security.TrinoPrincipal; import io.trino.spi.session.PropertyMetadata; diff --git a/core/trino-main/src/main/java/io/trino/testing/LocalQueryRunner.java b/core/trino-main/src/main/java/io/trino/testing/LocalQueryRunner.java index 49171de52b7f..c1cb5791576b 100644 --- a/core/trino-main/src/main/java/io/trino/testing/LocalQueryRunner.java +++ b/core/trino-main/src/main/java/io/trino/testing/LocalQueryRunner.java @@ -221,6 +221,7 @@ import static io.trino.connector.CatalogServiceProviderModule.createAccessControlProvider; import static io.trino.connector.CatalogServiceProviderModule.createAnalyzePropertyManager; import static io.trino.connector.CatalogServiceProviderModule.createColumnPropertyManager; +import static io.trino.connector.CatalogServiceProviderModule.createFunctionProvider; import static io.trino.connector.CatalogServiceProviderModule.createIndexProvider; import static io.trino.connector.CatalogServiceProviderModule.createMaterializedViewPropertyManager; import static io.trino.connector.CatalogServiceProviderModule.createNodePartitioningProvider; @@ -367,28 +368,18 @@ private LocalQueryRunner( this.globalFunctionCatalog = new GlobalFunctionCatalog(); globalFunctionCatalog.addFunctions(new InternalFunctionBundle(new LiteralFunction(blockEncodingSerde))); globalFunctionCatalog.addFunctions(SystemFunctionBundle.create(featuresConfig, typeOperators, blockTypeOperators, nodeManager.getCurrentNode().getNodeVersion())); - this.functionManager = new FunctionManager(globalFunctionCatalog); Metadata metadata = metadataProvider.getMetadata( new DisabledSystemSecurityMetadata(), transactionManager, globalFunctionCatalog, typeManager); - globalFunctionCatalog.addFunctions(new InternalFunctionBundle( - new JsonExistsFunction(functionManager, metadata, typeManager), - new JsonValueFunction(functionManager, metadata, typeManager), - new JsonQueryFunction(functionManager, metadata, typeManager))); typeRegistry.addType(new JsonPath2016Type(new TypeDeserializer(typeManager), blockEncodingSerde)); - this.plannerContext = new PlannerContext(metadata, typeOperators, blockEncodingSerde, typeManager, functionManager); this.joinCompiler = new JoinCompiler(typeOperators); PageIndexerFactory pageIndexerFactory = new GroupByHashPageIndexerFactory(joinCompiler, blockTypeOperators); this.groupProvider = new TestingGroupProvider(); this.accessControl = new TestingAccessControlManager(transactionManager, eventListenerManager); accessControl.loadSystemAccessControl(AllowAllSystemAccessControl.NAME, ImmutableMap.of()); - this.pageFunctionCompiler = new PageFunctionCompiler(functionManager, 0); - this.expressionCompiler = new ExpressionCompiler(functionManager, pageFunctionCompiler); - this.joinFilterFunctionCompiler = new JoinFilterFunctionCompiler(functionManager); - HandleResolver handleResolver = new HandleResolver(); NodeInfo nodeInfo = new NodeInfo("test"); @@ -412,6 +403,7 @@ private LocalQueryRunner( this.sessionPropertyManager = createSessionPropertyManager(catalogManager, extraSessionProperties, taskManagerConfig, featuresConfig, optimizerConfig); this.nodePartitioningManager = new NodePartitioningManager(nodeScheduler, blockTypeOperators, createNodePartitioningProvider(catalogManager)); TableProceduresRegistry tableProceduresRegistry = new TableProceduresRegistry(createTableProceduresProvider(catalogManager)); + this.functionManager = new FunctionManager(createFunctionProvider(catalogManager), globalFunctionCatalog); TableFunctionRegistry tableFunctionRegistry = new TableFunctionRegistry(createTableFunctionProvider(catalogManager)); this.schemaPropertyManager = createSchemaPropertyManager(catalogManager); this.columnPropertyManager = createColumnPropertyManager(catalogManager); @@ -422,6 +414,16 @@ private LocalQueryRunner( accessControl.setConnectorAccessControlProvider(createAccessControlProvider(catalogManager)); + globalFunctionCatalog.addFunctions(new InternalFunctionBundle( + new JsonExistsFunction(functionManager, metadata, typeManager), + new JsonValueFunction(functionManager, metadata, typeManager), + new JsonQueryFunction(functionManager, metadata, typeManager))); + + this.plannerContext = new PlannerContext(metadata, typeOperators, blockEncodingSerde, typeManager, functionManager); + this.pageFunctionCompiler = new PageFunctionCompiler(functionManager, 0); + this.expressionCompiler = new ExpressionCompiler(functionManager, pageFunctionCompiler); + this.joinFilterFunctionCompiler = new JoinFilterFunctionCompiler(functionManager); + this.statementAnalyzerFactory = new StatementAnalyzerFactory( plannerContext, sqlParser, diff --git a/core/trino-main/src/main/java/io/trino/type/DecimalCasts.java b/core/trino-main/src/main/java/io/trino/type/DecimalCasts.java index 55d1fc533a18..023a247242ea 100644 --- a/core/trino-main/src/main/java/io/trino/type/DecimalCasts.java +++ b/core/trino-main/src/main/java/io/trino/type/DecimalCasts.java @@ -23,9 +23,9 @@ import io.airlift.slice.SliceOutput; import io.trino.annotation.UsedByGeneratedCode; import io.trino.metadata.PolymorphicScalarFunctionBuilder; -import io.trino.metadata.Signature; import io.trino.metadata.SqlScalarFunction; import io.trino.spi.TrinoException; +import io.trino.spi.function.Signature; import io.trino.spi.type.DecimalConversions; import io.trino.spi.type.DecimalType; import io.trino.spi.type.Decimals; diff --git a/core/trino-main/src/main/java/io/trino/type/DecimalOperators.java b/core/trino-main/src/main/java/io/trino/type/DecimalOperators.java index e999f97c7cc8..f9932e99fb6c 100644 --- a/core/trino-main/src/main/java/io/trino/type/DecimalOperators.java +++ b/core/trino-main/src/main/java/io/trino/type/DecimalOperators.java @@ -17,11 +17,11 @@ import io.trino.annotation.UsedByGeneratedCode; import io.trino.metadata.PolymorphicScalarFunctionBuilder; import io.trino.metadata.PolymorphicScalarFunctionBuilder.SpecializeContext; -import io.trino.metadata.Signature; import io.trino.metadata.SqlScalarFunction; import io.trino.spi.TrinoException; import io.trino.spi.function.LiteralParameters; import io.trino.spi.function.ScalarOperator; +import io.trino.spi.function.Signature; import io.trino.spi.function.SqlType; import io.trino.spi.type.DecimalType; import io.trino.spi.type.Decimals; diff --git a/core/trino-main/src/main/java/io/trino/type/DecimalSaturatedFloorCasts.java b/core/trino-main/src/main/java/io/trino/type/DecimalSaturatedFloorCasts.java index 986f6a59e561..35a4fa66cd17 100644 --- a/core/trino-main/src/main/java/io/trino/type/DecimalSaturatedFloorCasts.java +++ b/core/trino-main/src/main/java/io/trino/type/DecimalSaturatedFloorCasts.java @@ -16,8 +16,8 @@ import com.google.common.collect.ImmutableList; import io.trino.annotation.UsedByGeneratedCode; import io.trino.metadata.PolymorphicScalarFunctionBuilder; -import io.trino.metadata.Signature; import io.trino.metadata.SqlScalarFunction; +import io.trino.spi.function.Signature; import io.trino.spi.type.Int128; import io.trino.spi.type.Type; import io.trino.spi.type.TypeSignature; diff --git a/core/trino-main/src/main/java/io/trino/type/DecimalToDecimalCasts.java b/core/trino-main/src/main/java/io/trino/type/DecimalToDecimalCasts.java index 207161296561..d2dea79c9c6f 100644 --- a/core/trino-main/src/main/java/io/trino/type/DecimalToDecimalCasts.java +++ b/core/trino-main/src/main/java/io/trino/type/DecimalToDecimalCasts.java @@ -16,8 +16,8 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; import io.trino.metadata.PolymorphicScalarFunctionBuilder; -import io.trino.metadata.Signature; import io.trino.metadata.SqlScalarFunction; +import io.trino.spi.function.Signature; import io.trino.spi.type.DecimalConversions; import io.trino.spi.type.DecimalType; diff --git a/core/trino-main/src/main/java/io/trino/util/MinMaxCompare.java b/core/trino-main/src/main/java/io/trino/util/MinMaxCompare.java index 26846c4bf842..506c3500fc32 100644 --- a/core/trino-main/src/main/java/io/trino/util/MinMaxCompare.java +++ b/core/trino-main/src/main/java/io/trino/util/MinMaxCompare.java @@ -15,8 +15,8 @@ import com.google.common.collect.ImmutableList; import io.trino.annotation.UsedByGeneratedCode; -import io.trino.metadata.FunctionDependencies; -import io.trino.metadata.FunctionDependencyDeclaration; +import io.trino.spi.function.FunctionDependencies; +import io.trino.spi.function.FunctionDependencyDeclaration; import io.trino.spi.function.InvocationConvention; import io.trino.spi.function.OperatorType; import io.trino.spi.type.Type; @@ -49,7 +49,7 @@ public static FunctionDependencyDeclaration getMinMaxCompareFunctionDependencies public static MethodHandle getMinMaxCompare(FunctionDependencies dependencies, Type type, InvocationConvention convention, boolean min) { OperatorType comparisonOperator = getMinMaxCompareOperatorType(min); - MethodHandle handle = dependencies.getOperatorInvoker(comparisonOperator, List.of(type, type), convention).getMethodHandle(); + MethodHandle handle = dependencies.getOperatorImplementation(comparisonOperator, List.of(type, type), convention).getMethodHandle(); return comparisonToMinMaxResult(min, handle); } diff --git a/core/trino-main/src/test/java/io/trino/metadata/AbstractMockMetadata.java b/core/trino-main/src/test/java/io/trino/metadata/AbstractMockMetadata.java index 840d736c9474..a22f5aba756f 100644 --- a/core/trino-main/src/test/java/io/trino/metadata/AbstractMockMetadata.java +++ b/core/trino-main/src/test/java/io/trino/metadata/AbstractMockMetadata.java @@ -21,6 +21,7 @@ import io.airlift.slice.Slice; import io.trino.Session; import io.trino.connector.CatalogHandle; +import io.trino.connector.system.GlobalSystemConnector; import io.trino.metadata.ResolvedFunction.ResolvedFunctionDecoder; import io.trino.spi.TrinoException; import io.trino.spi.connector.AggregateFunction; @@ -52,6 +53,10 @@ import io.trino.spi.connector.TableScanRedirectApplicationResult; import io.trino.spi.connector.TopNApplicationResult; import io.trino.spi.expression.ConnectorExpression; +import io.trino.spi.function.AggregationFunctionMetadata; +import io.trino.spi.function.BoundSignature; +import io.trino.spi.function.FunctionMetadata; +import io.trino.spi.function.FunctionNullability; import io.trino.spi.function.OperatorType; import io.trino.spi.predicate.TupleDomain; import io.trino.spi.security.GrantInfo; @@ -74,9 +79,9 @@ import java.util.OptionalLong; import java.util.Set; -import static io.trino.metadata.FunctionId.toFunctionId; import static io.trino.metadata.RedirectionAwareTableHandle.noRedirection; import static io.trino.spi.StandardErrorCode.FUNCTION_NOT_FOUND; +import static io.trino.spi.function.FunctionId.toFunctionId; import static io.trino.spi.function.FunctionKind.SCALAR; import static io.trino.spi.type.DoubleType.DOUBLE; import static io.trino.type.InternalTypeManager.TESTING_TYPE_MANAGER; @@ -757,6 +762,7 @@ public ResolvedFunction resolveFunction(Session session, QualifiedName name, Lis BoundSignature boundSignature = new BoundSignature(nameSuffix, DOUBLE, ImmutableList.of()); return new ResolvedFunction( boundSignature, + GlobalSystemConnector.CATALOG_HANDLE, toFunctionId(boundSignature.toSignature()), SCALAR, true, diff --git a/core/trino-main/src/test/java/io/trino/metadata/CountingAccessMetadata.java b/core/trino-main/src/test/java/io/trino/metadata/CountingAccessMetadata.java index 095721c2e96b..d4dcee212246 100644 --- a/core/trino-main/src/test/java/io/trino/metadata/CountingAccessMetadata.java +++ b/core/trino-main/src/test/java/io/trino/metadata/CountingAccessMetadata.java @@ -49,6 +49,8 @@ import io.trino.spi.connector.TableScanRedirectApplicationResult; import io.trino.spi.connector.TopNApplicationResult; import io.trino.spi.expression.ConnectorExpression; +import io.trino.spi.function.AggregationFunctionMetadata; +import io.trino.spi.function.FunctionMetadata; import io.trino.spi.function.OperatorType; import io.trino.spi.predicate.TupleDomain; import io.trino.spi.security.GrantInfo; diff --git a/core/trino-main/src/test/java/io/trino/metadata/TestGlobalFunctionCatalog.java b/core/trino-main/src/test/java/io/trino/metadata/TestGlobalFunctionCatalog.java index 0315d62c606e..43bce79e1e7e 100644 --- a/core/trino-main/src/test/java/io/trino/metadata/TestGlobalFunctionCatalog.java +++ b/core/trino-main/src/test/java/io/trino/metadata/TestGlobalFunctionCatalog.java @@ -17,11 +17,15 @@ import com.google.common.collect.ImmutableSet; import io.trino.FeaturesConfig; import io.trino.client.NodeVersion; -import io.trino.operator.scalar.ChoicesScalarFunctionImplementation; -import io.trino.operator.scalar.ScalarFunctionImplementation; +import io.trino.operator.scalar.ChoicesSpecializedSqlScalarFunction; +import io.trino.operator.scalar.SpecializedSqlScalarFunction; +import io.trino.spi.function.BoundSignature; +import io.trino.spi.function.FunctionMetadata; import io.trino.spi.function.OperatorType; import io.trino.spi.function.ScalarFunction; +import io.trino.spi.function.Signature; import io.trino.spi.function.SqlType; +import io.trino.spi.function.TypeVariableConstraint; import io.trino.spi.type.ArrayType; import io.trino.spi.type.StandardTypes; import io.trino.spi.type.Type; @@ -41,11 +45,11 @@ import static com.google.common.collect.ImmutableSet.toImmutableSet; import static io.trino.SessionTestUtils.TEST_SESSION; import static io.trino.metadata.InternalFunctionBundle.extractFunctions; -import static io.trino.metadata.Signature.mangleOperatorName; -import static io.trino.metadata.Signature.unmangleOperator; -import static io.trino.metadata.TypeVariableConstraint.typeVariable; +import static io.trino.metadata.OperatorNameUtil.mangleOperatorName; +import static io.trino.metadata.OperatorNameUtil.unmangleOperator; import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.NEVER_NULL; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL; +import static io.trino.spi.function.TypeVariableConstraint.typeVariable; import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.DecimalType.createDecimalType; import static io.trino.spi.type.HyperLogLogType.HYPER_LOG_LOG; @@ -265,7 +269,7 @@ public void testResolveFunctionForUnknown() private static List listOperators(Metadata metadata) { Set operatorNames = Arrays.stream(OperatorType.values()) - .map(Signature::mangleOperatorName) + .map(OperatorNameUtil::mangleOperatorName) .collect(toImmutableSet()); return metadata.listFunctions(TEST_SESSION).stream() @@ -355,9 +359,9 @@ private InternalFunctionBundle createFunctionsFromSignatures() functions.add(new SqlScalarFunction(functionMetadata) { @Override - protected ScalarFunctionImplementation specialize(BoundSignature boundSignature) + protected SpecializedSqlScalarFunction specialize(BoundSignature boundSignature) { - return new ChoicesScalarFunctionImplementation( + return new ChoicesSpecializedSqlScalarFunction( boundSignature, FAIL_ON_NULL, nCopies(boundSignature.getArity(), NEVER_NULL), diff --git a/core/trino-main/src/test/java/io/trino/metadata/TestPolymorphicScalarFunction.java b/core/trino-main/src/test/java/io/trino/metadata/TestPolymorphicScalarFunction.java index 47f66859ab40..d41909c586a7 100644 --- a/core/trino-main/src/test/java/io/trino/metadata/TestPolymorphicScalarFunction.java +++ b/core/trino-main/src/test/java/io/trino/metadata/TestPolymorphicScalarFunction.java @@ -18,10 +18,13 @@ import com.google.common.collect.ImmutableSet; import io.airlift.slice.Slice; import io.airlift.slice.Slices; -import io.trino.operator.scalar.ChoicesScalarFunctionImplementation; +import io.trino.operator.scalar.ChoicesSpecializedSqlScalarFunction; import io.trino.spi.block.Block; import io.trino.spi.block.LongArrayBlock; +import io.trino.spi.function.BoundSignature; import io.trino.spi.function.InvocationConvention; +import io.trino.spi.function.Signature; +import io.trino.spi.function.TypeVariableConstraint; import io.trino.spi.type.DecimalType; import io.trino.spi.type.Int128; import io.trino.spi.type.TypeSignature; @@ -96,26 +99,26 @@ public void testSelectsMultipleChoiceWithBlockPosition() .build(); BoundSignature shortDecimalBoundSignature = new BoundSignature(signature.getName(), BOOLEAN, ImmutableList.of(SHORT_DECIMAL_BOUND_TYPE, SHORT_DECIMAL_BOUND_TYPE)); - ChoicesScalarFunctionImplementation functionImplementation = (ChoicesScalarFunctionImplementation) function.specialize( + ChoicesSpecializedSqlScalarFunction specializedFunction = (ChoicesSpecializedSqlScalarFunction) function.specialize( shortDecimalBoundSignature, - new FunctionDependencies(FUNCTION_MANAGER::getScalarFunctionInvoker, ImmutableMap.of(), ImmutableSet.of())); + new InternalFunctionDependencies(FUNCTION_MANAGER::getScalarFunctionImplementation, ImmutableMap.of(), ImmutableSet.of())); - assertEquals(functionImplementation.getChoices().size(), 2); + assertEquals(specializedFunction.getChoices().size(), 2); assertEquals( - functionImplementation.getChoices().get(0).getInvocationConvention(), + specializedFunction.getChoices().get(0).getInvocationConvention(), new InvocationConvention(ImmutableList.of(NULL_FLAG, NULL_FLAG), FAIL_ON_NULL, false, false)); assertEquals( - functionImplementation.getChoices().get(1).getInvocationConvention(), + specializedFunction.getChoices().get(1).getInvocationConvention(), new InvocationConvention(ImmutableList.of(BLOCK_POSITION, BLOCK_POSITION), FAIL_ON_NULL, false, false)); Block block1 = new LongArrayBlock(0, Optional.empty(), new long[0]); Block block2 = new LongArrayBlock(0, Optional.empty(), new long[0]); - assertFalse((boolean) functionImplementation.getChoices().get(1).getMethodHandle().invoke(block1, 0, block2, 0)); + assertFalse((boolean) specializedFunction.getChoices().get(1).getMethodHandle().invoke(block1, 0, block2, 0)); BoundSignature longDecimalBoundSignature = new BoundSignature(signature.getName(), BOOLEAN, ImmutableList.of(LONG_DECIMAL_BOUND_TYPE, LONG_DECIMAL_BOUND_TYPE)); - functionImplementation = (ChoicesScalarFunctionImplementation) function.specialize( + specializedFunction = (ChoicesSpecializedSqlScalarFunction) function.specialize( longDecimalBoundSignature, - new FunctionDependencies(FUNCTION_MANAGER::getScalarFunctionInvoker, ImmutableMap.of(), ImmutableSet.of())); - assertTrue((boolean) functionImplementation.getChoices().get(1).getMethodHandle().invoke(block1, 0, block2, 0)); + new InternalFunctionDependencies(FUNCTION_MANAGER::getScalarFunctionImplementation, ImmutableMap.of(), ImmutableSet.of())); + assertTrue((boolean) specializedFunction.getChoices().get(1).getMethodHandle().invoke(block1, 0, block2, 0)); } @Test @@ -132,10 +135,10 @@ public void testSelectsMethodBasedOnArgumentTypes() .withExtraParameters(context -> ImmutableList.of(context.getLiteral("x"))))) .build(); - ChoicesScalarFunctionImplementation functionImplementation = (ChoicesScalarFunctionImplementation) function.specialize( + ChoicesSpecializedSqlScalarFunction specializedFunction = (ChoicesSpecializedSqlScalarFunction) function.specialize( BOUND_SIGNATURE, - new FunctionDependencies(FUNCTION_MANAGER::getScalarFunctionInvoker, ImmutableMap.of(), ImmutableSet.of())); - assertEquals(functionImplementation.getChoices().get(0).getMethodHandle().invoke(INPUT_SLICE), (long) INPUT_VARCHAR_LENGTH); + new InternalFunctionDependencies(FUNCTION_MANAGER::getScalarFunctionImplementation, ImmutableMap.of(), ImmutableSet.of())); + assertEquals(specializedFunction.getChoices().get(0).getMethodHandle().invoke(INPUT_SLICE), (long) INPUT_VARCHAR_LENGTH); } @Test @@ -152,11 +155,11 @@ public void testSelectsMethodBasedOnReturnType() .withExtraParameters(context -> ImmutableList.of(42)))) .build(); - ChoicesScalarFunctionImplementation functionImplementation = (ChoicesScalarFunctionImplementation) function.specialize( + ChoicesSpecializedSqlScalarFunction specializedFunction = (ChoicesSpecializedSqlScalarFunction) function.specialize( BOUND_SIGNATURE, - new FunctionDependencies(FUNCTION_MANAGER::getScalarFunctionInvoker, ImmutableMap.of(), ImmutableSet.of())); + new InternalFunctionDependencies(FUNCTION_MANAGER::getScalarFunctionImplementation, ImmutableMap.of(), ImmutableSet.of())); - assertEquals(functionImplementation.getChoices().get(0).getMethodHandle().invoke(INPUT_SLICE), VARCHAR_TO_BIGINT_RETURN_VALUE); + assertEquals(specializedFunction.getChoices().get(0).getMethodHandle().invoke(INPUT_SLICE), VARCHAR_TO_BIGINT_RETURN_VALUE); } @Test @@ -178,10 +181,10 @@ public void testSameLiteralInArgumentsAndReturnValue() BoundSignature boundSignature = new BoundSignature(signature.getName(), createVarcharType(INPUT_VARCHAR_LENGTH), ImmutableList.of(createVarcharType(INPUT_VARCHAR_LENGTH))); - ChoicesScalarFunctionImplementation functionImplementation = (ChoicesScalarFunctionImplementation) function.specialize( + ChoicesSpecializedSqlScalarFunction specializedFunction = (ChoicesSpecializedSqlScalarFunction) function.specialize( boundSignature, - new FunctionDependencies(FUNCTION_MANAGER::getScalarFunctionInvoker, ImmutableMap.of(), ImmutableSet.of())); - Slice slice = (Slice) functionImplementation.getChoices().get(0).getMethodHandle().invoke(INPUT_SLICE); + new InternalFunctionDependencies(FUNCTION_MANAGER::getScalarFunctionImplementation, ImmutableMap.of(), ImmutableSet.of())); + Slice slice = (Slice) specializedFunction.getChoices().get(0).getMethodHandle().invoke(INPUT_SLICE); assertEquals(slice, VARCHAR_TO_VARCHAR_RETURN_VALUE); } @@ -208,10 +211,10 @@ public void testTypeParameters() BoundSignature boundSignature = new BoundSignature(signature.getName(), VARCHAR, ImmutableList.of(VARCHAR)); - ChoicesScalarFunctionImplementation functionImplementation = (ChoicesScalarFunctionImplementation) function.specialize( + ChoicesSpecializedSqlScalarFunction specializedFunction = (ChoicesSpecializedSqlScalarFunction) function.specialize( boundSignature, - new FunctionDependencies(FUNCTION_MANAGER::getScalarFunctionInvoker, ImmutableMap.of(), ImmutableSet.of())); - Slice slice = (Slice) functionImplementation.getChoices().get(0).getMethodHandle().invoke(INPUT_SLICE); + new InternalFunctionDependencies(FUNCTION_MANAGER::getScalarFunctionImplementation, ImmutableMap.of(), ImmutableSet.of())); + Slice slice = (Slice) specializedFunction.getChoices().get(0).getMethodHandle().invoke(INPUT_SLICE); assertEquals(slice, VARCHAR_TO_VARCHAR_RETURN_VALUE); } @@ -232,7 +235,7 @@ public void testSetsHiddenToTrueForOperators() .build(); BoundSignature boundSignature = new BoundSignature(signature.getName(), createVarcharType(INPUT_VARCHAR_LENGTH), ImmutableList.of(createVarcharType(INPUT_VARCHAR_LENGTH))); - function.specialize(boundSignature, new FunctionDependencies(FUNCTION_MANAGER::getScalarFunctionInvoker, ImmutableMap.of(), ImmutableSet.of())); + function.specialize(boundSignature, new InternalFunctionDependencies(FUNCTION_MANAGER::getScalarFunctionImplementation, ImmutableMap.of(), ImmutableSet.of())); } @Test @@ -274,7 +277,7 @@ public void testFailIfTwoMethodsWithSameArguments() .implementation(methodsGroup -> methodsGroup.methods("varcharToBigintReturnExtraParameter"))) .build(); - assertThatThrownBy(() -> function.specialize(BOUND_SIGNATURE, new FunctionDependencies(FUNCTION_MANAGER::getScalarFunctionInvoker, ImmutableMap.of(), ImmutableSet.of()))) + assertThatThrownBy(() -> function.specialize(BOUND_SIGNATURE, new InternalFunctionDependencies(FUNCTION_MANAGER::getScalarFunctionImplementation, ImmutableMap.of(), ImmutableSet.of()))) .isInstanceOf(IllegalStateException.class) .hasMessageMatching("two matching methods \\(varcharToBigintReturnFirstExtraParameter and varcharToBigintReturnExtraParameter\\) for parameter types \\[varchar\\(10\\)\\]"); } diff --git a/core/trino-main/src/test/java/io/trino/metadata/TestResolvedFunction.java b/core/trino-main/src/test/java/io/trino/metadata/TestResolvedFunction.java index b09517d59972..90c30af43571 100644 --- a/core/trino-main/src/test/java/io/trino/metadata/TestResolvedFunction.java +++ b/core/trino-main/src/test/java/io/trino/metadata/TestResolvedFunction.java @@ -15,7 +15,12 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; +import io.trino.connector.system.GlobalSystemConnector; import io.trino.metadata.ResolvedFunction.ResolvedFunctionDecoder; +import io.trino.spi.function.BoundSignature; +import io.trino.spi.function.FunctionId; +import io.trino.spi.function.FunctionNullability; +import io.trino.spi.function.Signature; import io.trino.spi.type.Type; import io.trino.spi.type.TypeId; import io.trino.spi.type.TypeSignature; @@ -51,6 +56,7 @@ private static ResolvedFunction createResolvedFunction(String name, int depth) { return new ResolvedFunction( new BoundSignature(name + "_" + depth, createVarcharType(10 + depth), ImmutableList.of(createVarcharType(20 + depth), createVarcharType(30 + depth))), + GlobalSystemConnector.CATALOG_HANDLE, FunctionId.toFunctionId(Signature.builder() .name(name) .returnType(new TypeSignature("x")) diff --git a/core/trino-main/src/test/java/io/trino/metadata/TestSignature.java b/core/trino-main/src/test/java/io/trino/metadata/TestSignature.java index 7c88287c8ed1..5517264d1305 100644 --- a/core/trino-main/src/test/java/io/trino/metadata/TestSignature.java +++ b/core/trino-main/src/test/java/io/trino/metadata/TestSignature.java @@ -17,6 +17,7 @@ import io.airlift.json.JsonCodec; import io.airlift.json.JsonCodecFactory; import io.airlift.json.ObjectMapperProvider; +import io.trino.spi.function.Signature; import io.trino.spi.type.Type; import io.trino.spi.type.TypeSignature; import io.trino.type.TypeDeserializer; diff --git a/core/trino-main/src/test/java/io/trino/metadata/TestSignatureBinder.java b/core/trino-main/src/test/java/io/trino/metadata/TestSignatureBinder.java index bbe0a2894199..11cab519d155 100644 --- a/core/trino-main/src/test/java/io/trino/metadata/TestSignatureBinder.java +++ b/core/trino-main/src/test/java/io/trino/metadata/TestSignatureBinder.java @@ -15,6 +15,8 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; +import io.trino.spi.function.Signature; +import io.trino.spi.function.TypeVariableConstraint; import io.trino.spi.type.ArrayType; import io.trino.spi.type.RowType; import io.trino.spi.type.Type; diff --git a/core/trino-main/src/test/java/io/trino/metadata/TestingFunctionResolution.java b/core/trino-main/src/test/java/io/trino/metadata/TestingFunctionResolution.java index f0271b886cf1..87207e262699 100644 --- a/core/trino-main/src/test/java/io/trino/metadata/TestingFunctionResolution.java +++ b/core/trino-main/src/test/java/io/trino/metadata/TestingFunctionResolution.java @@ -18,6 +18,7 @@ import io.trino.security.AllowAllAccessControl; import io.trino.spi.function.InvocationConvention; import io.trino.spi.function.OperatorType; +import io.trino.spi.function.ScalarFunctionImplementation; import io.trino.spi.type.Type; import io.trino.spi.type.TypeSignature; import io.trino.sql.PlannerContext; @@ -131,9 +132,9 @@ public ResolvedFunction resolveFunction(QualifiedName name, List metadata.resolveFunction(session, name, parameterTypes)); } - public FunctionInvoker getScalarFunctionInvoker(QualifiedName name, List parameterTypes, InvocationConvention invocationConvention) + public ScalarFunctionImplementation getScalarFunction(QualifiedName name, List parameterTypes, InvocationConvention invocationConvention) { - return inTransaction(session -> plannerContext.getFunctionManager().getScalarFunctionInvoker(metadata.resolveFunction(session, name, parameterTypes), invocationConvention)); + return inTransaction(session -> plannerContext.getFunctionManager().getScalarFunctionImplementation(metadata.resolveFunction(session, name, parameterTypes), invocationConvention)); } public TestingAggregationFunction getAggregateFunction(QualifiedName name, List parameterTypes) @@ -143,7 +144,7 @@ public TestingAggregationFunction getAggregateFunction(QualifiedName name, List< return new TestingAggregationFunction( resolvedFunction.getSignature(), resolvedFunction.getFunctionNullability(), - plannerContext.getFunctionManager().getAggregateFunctionImplementation(resolvedFunction)); + plannerContext.getFunctionManager().getAggregationImplementation(resolvedFunction)); }); } diff --git a/core/trino-main/src/test/java/io/trino/operator/GenericLongFunction.java b/core/trino-main/src/test/java/io/trino/operator/GenericLongFunction.java index 5c1506eed8ce..ec7bb98c9497 100644 --- a/core/trino-main/src/test/java/io/trino/operator/GenericLongFunction.java +++ b/core/trino-main/src/test/java/io/trino/operator/GenericLongFunction.java @@ -14,12 +14,12 @@ package io.trino.operator; import com.google.common.collect.ImmutableList; -import io.trino.metadata.BoundSignature; -import io.trino.metadata.FunctionMetadata; -import io.trino.metadata.Signature; import io.trino.metadata.SqlScalarFunction; -import io.trino.operator.scalar.ChoicesScalarFunctionImplementation; -import io.trino.operator.scalar.ScalarFunctionImplementation; +import io.trino.operator.scalar.ChoicesSpecializedSqlScalarFunction; +import io.trino.operator.scalar.SpecializedSqlScalarFunction; +import io.trino.spi.function.BoundSignature; +import io.trino.spi.function.FunctionMetadata; +import io.trino.spi.function.Signature; import java.lang.invoke.MethodHandle; import java.util.function.LongUnaryOperator; @@ -52,10 +52,10 @@ public final class GenericLongFunction } @Override - protected ScalarFunctionImplementation specialize(BoundSignature boundSignature) + protected SpecializedSqlScalarFunction specialize(BoundSignature boundSignature) { MethodHandle methodHandle = METHOD_HANDLE.bindTo(longUnaryOperator); - return new ChoicesScalarFunctionImplementation(boundSignature, FAIL_ON_NULL, ImmutableList.of(NEVER_NULL), methodHandle); + return new ChoicesSpecializedSqlScalarFunction(boundSignature, FAIL_ON_NULL, ImmutableList.of(NEVER_NULL), methodHandle); } public static long apply(LongUnaryOperator longUnaryOperator, long value) diff --git a/core/trino-main/src/test/java/io/trino/operator/TestAnnotationEngine.java b/core/trino-main/src/test/java/io/trino/operator/TestAnnotationEngine.java index ce97ab55dd2c..4e60a1aee6ff 100644 --- a/core/trino-main/src/test/java/io/trino/operator/TestAnnotationEngine.java +++ b/core/trino-main/src/test/java/io/trino/operator/TestAnnotationEngine.java @@ -13,7 +13,7 @@ */ package io.trino.operator; -import io.trino.operator.aggregation.AggregationImplementation; +import io.trino.operator.aggregation.ParametricAggregationImplementation; import io.trino.operator.scalar.ParametricScalar; import static org.testng.Assert.assertEquals; @@ -32,7 +32,7 @@ void assertImplementationCount(ParametricImplementationsGroup implementations assertEquals(implementations.getGenericImplementations().size(), generic); } - void assertDependencyCount(AggregationImplementation implementation, int input, int combine, int output) + void assertDependencyCount(ParametricAggregationImplementation implementation, int input, int combine, int output) { assertEquals(implementation.getInputDependencies().size(), input); assertEquals(implementation.getCombineDependencies().size(), combine); diff --git a/core/trino-main/src/test/java/io/trino/operator/TestAnnotationEngineForAggregates.java b/core/trino-main/src/test/java/io/trino/operator/TestAnnotationEngineForAggregates.java index 455e86a69634..5435e3ce5296 100644 --- a/core/trino-main/src/test/java/io/trino/operator/TestAnnotationEngineForAggregates.java +++ b/core/trino-main/src/test/java/io/trino/operator/TestAnnotationEngineForAggregates.java @@ -17,18 +17,15 @@ import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import io.airlift.slice.Slice; -import io.trino.metadata.AggregationFunctionMetadata; -import io.trino.metadata.BoundSignature; +import io.trino.connector.system.GlobalSystemConnector; import io.trino.metadata.FunctionBinding; -import io.trino.metadata.FunctionDependencies; import io.trino.metadata.FunctionManager; -import io.trino.metadata.FunctionMetadata; +import io.trino.metadata.InternalFunctionDependencies; import io.trino.metadata.MetadataManager; import io.trino.metadata.ResolvedFunction; -import io.trino.metadata.Signature; import io.trino.metadata.SqlAggregationFunction; -import io.trino.operator.aggregation.AggregationImplementation; import io.trino.operator.aggregation.ParametricAggregation; +import io.trino.operator.aggregation.ParametricAggregationImplementation; import io.trino.operator.aggregation.state.LongState; import io.trino.operator.aggregation.state.NullableDoubleState; import io.trino.operator.aggregation.state.NullableLongState; @@ -40,17 +37,22 @@ import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; import io.trino.spi.function.AggregationFunction; +import io.trino.spi.function.AggregationFunctionMetadata; import io.trino.spi.function.AggregationState; import io.trino.spi.function.BlockIndex; import io.trino.spi.function.BlockPosition; +import io.trino.spi.function.BoundSignature; import io.trino.spi.function.CombineFunction; import io.trino.spi.function.Convention; import io.trino.spi.function.Description; +import io.trino.spi.function.FunctionDependencies; +import io.trino.spi.function.FunctionMetadata; import io.trino.spi.function.InputFunction; import io.trino.spi.function.LiteralParameter; import io.trino.spi.function.LiteralParameters; import io.trino.spi.function.OperatorDependency; import io.trino.spi.function.OutputFunction; +import io.trino.spi.function.Signature; import io.trino.spi.function.SqlType; import io.trino.spi.function.TypeParameter; import io.trino.spi.function.TypeParameterSpecialization; @@ -95,7 +97,7 @@ public class TestAnnotationEngineForAggregates { private static final MetadataManager METADATA = createTestMetadataManager(); private static final FunctionManager FUNCTION_MANAGER = createTestingFunctionManager(); - private static final FunctionDependencies NO_FUNCTION_DEPENDENCIES = new FunctionDependencies(FUNCTION_MANAGER::getScalarFunctionInvoker, ImmutableMap.of(), ImmutableSet.of()); + private static final FunctionDependencies NO_FUNCTION_DEPENDENCIES = new InternalFunctionDependencies(FUNCTION_MANAGER::getScalarFunctionImplementation, ImmutableMap.of(), ImmutableSet.of()); @AggregationFunction("simple_exact_aggregate") @Description("Simple exact aggregate description") @@ -133,9 +135,9 @@ public void testSimpleExactAggregationParse() assertEquals(aggregation.getFunctionMetadata().getDescription(), "Simple exact aggregate description"); assertTrue(aggregation.getFunctionMetadata().isDeterministic()); assertEquals(aggregation.getFunctionMetadata().getSignature(), expectedSignature); - ParametricImplementationsGroup implementations = aggregation.getImplementations(); + ParametricImplementationsGroup implementations = aggregation.getImplementations(); assertImplementationCount(implementations, 1, 0, 0); - AggregationImplementation implementation = getOnlyElement(implementations.getExactImplementations().values()); + ParametricAggregationImplementation implementation = getOnlyElement(implementations.getExactImplementations().values()); assertEquals(implementation.getDefinitionClass(), ExactAggregationFunction.class); assertDependencyCount(implementation, 0, 0, 0); assertFalse(implementation.hasSpecializedTypeParameters()); @@ -238,7 +240,7 @@ public void testNotAnnotatedAggregateStateAggregationParse() { ParametricAggregation aggregation = getOnlyElement(parseFunctionDefinitions(NotAnnotatedAggregateStateAggregationFunction.class)); - AggregationImplementation implementation = getOnlyElement(aggregation.getImplementations().getExactImplementations().values()); + ParametricAggregationImplementation implementation = getOnlyElement(aggregation.getImplementations().getExactImplementations().values()); assertEquals(implementation.getInputParameterKinds(), ImmutableList.of(STATE, INPUT_CHANNEL)); BoundSignature boundSignature = new BoundSignature(aggregation.getFunctionMetadata().getSignature().getName(), DoubleType.DOUBLE, ImmutableList.of(DoubleType.DOUBLE)); @@ -344,9 +346,9 @@ public void testSimpleGenericAggregationFunctionParse() assertTrue(aggregation.getFunctionMetadata().isDeterministic()); assertEquals(aggregation.getFunctionMetadata().getSignature(), expectedSignature); assertEquals(aggregation.getStateDetails(), ImmutableList.of(toAccumulatorStateDetails(NullableLongState.class, ImmutableList.of()))); - ParametricImplementationsGroup implementations = aggregation.getImplementations(); + ParametricImplementationsGroup implementations = aggregation.getImplementations(); assertImplementationCount(implementations, 0, 0, 2); - AggregationImplementation implementationDouble = implementations.getGenericImplementations().stream() + ParametricAggregationImplementation implementationDouble = implementations.getGenericImplementations().stream() .filter(impl -> impl.getInputFunction().type().equals(methodType(void.class, NullableLongState.class, double.class))) .collect(toImmutableList()) .get(0); @@ -355,7 +357,7 @@ public void testSimpleGenericAggregationFunctionParse() assertFalse(implementationDouble.hasSpecializedTypeParameters()); assertEquals(implementationDouble.getInputParameterKinds(), ImmutableList.of(STATE, INPUT_CHANNEL)); - AggregationImplementation implementationLong = implementations.getGenericImplementations().stream() + ParametricAggregationImplementation implementationLong = implementations.getGenericImplementations().stream() .filter(impl -> impl.getInputFunction().type().equals(methodType(void.class, NullableLongState.class, long.class))) .collect(toImmutableList()) .get(0); @@ -414,9 +416,9 @@ public void testSimpleBlockInputAggregationParse() assertEquals(aggregation.getFunctionMetadata().getDescription(), "Simple aggregate with @BlockPosition usage"); assertTrue(aggregation.getFunctionMetadata().isDeterministic()); assertEquals(aggregation.getFunctionMetadata().getSignature(), expectedSignature); - ParametricImplementationsGroup implementations = aggregation.getImplementations(); + ParametricImplementationsGroup implementations = aggregation.getImplementations(); assertImplementationCount(implementations, 1, 0, 0); - AggregationImplementation implementation = getOnlyElement(implementations.getExactImplementations().values()); + ParametricAggregationImplementation implementation = getOnlyElement(implementations.getExactImplementations().values()); assertEquals(implementation.getDefinitionClass(), BlockInputAggregationFunction.class); assertDependencyCount(implementation, 0, 0, 0); assertFalse(implementation.hasSpecializedTypeParameters()); @@ -499,15 +501,15 @@ public void testSimpleImplicitSpecializedAggregationParse() assertEquals(aggregation.getFunctionMetadata().getDescription(), "Simple implicit specialized aggregate"); assertTrue(aggregation.getFunctionMetadata().isDeterministic()); assertEquals(aggregation.getFunctionMetadata().getSignature(), expectedSignature); - ParametricImplementationsGroup implementations = aggregation.getImplementations(); + ParametricImplementationsGroup implementations = aggregation.getImplementations(); assertImplementationCount(implementations, 0, 0, 2); - AggregationImplementation implementation1 = implementations.getSpecializedImplementations().get(0); + ParametricAggregationImplementation implementation1 = implementations.getSpecializedImplementations().get(0); assertTrue(implementation1.hasSpecializedTypeParameters()); assertFalse(implementation1.hasSpecializedTypeParameters()); assertEquals(implementation1.getInputParameterKinds(), ImmutableList.of(STATE, INPUT_CHANNEL, INPUT_CHANNEL)); - AggregationImplementation implementation2 = implementations.getSpecializedImplementations().get(1); + ParametricAggregationImplementation implementation2 = implementations.getSpecializedImplementations().get(1); assertTrue(implementation2.hasSpecializedTypeParameters()); assertFalse(implementation2.hasSpecializedTypeParameters()); assertEquals(implementation2.getInputParameterKinds(), ImmutableList.of(STATE, INPUT_CHANNEL, INPUT_CHANNEL)); @@ -589,14 +591,14 @@ public void testSimpleExplicitSpecializedAggregationParse() assertEquals(aggregation.getFunctionMetadata().getDescription(), "Simple explicit specialized aggregate"); assertTrue(aggregation.getFunctionMetadata().isDeterministic()); assertEquals(aggregation.getFunctionMetadata().getSignature(), expectedSignature); - ParametricImplementationsGroup implementations = aggregation.getImplementations(); + ParametricImplementationsGroup implementations = aggregation.getImplementations(); assertImplementationCount(implementations, 0, 1, 1); - AggregationImplementation implementation1 = implementations.getSpecializedImplementations().get(0); + ParametricAggregationImplementation implementation1 = implementations.getSpecializedImplementations().get(0); assertTrue(implementation1.hasSpecializedTypeParameters()); assertFalse(implementation1.hasSpecializedTypeParameters()); assertEquals(implementation1.getInputParameterKinds(), ImmutableList.of(STATE, INPUT_CHANNEL)); - AggregationImplementation implementation2 = implementations.getSpecializedImplementations().get(1); + ParametricAggregationImplementation implementation2 = implementations.getSpecializedImplementations().get(1); assertTrue(implementation2.hasSpecializedTypeParameters()); assertFalse(implementation2.hasSpecializedTypeParameters()); assertEquals(implementation2.getInputParameterKinds(), ImmutableList.of(STATE, INPUT_CHANNEL)); @@ -674,13 +676,13 @@ public void testMultiOutputAggregationParse() assertEquals(aggregation2.getFunctionMetadata().getSignature(), expectedSignature2); assertEquals(aggregation2.getFunctionMetadata().getDescription(), "Simple multi output function aggregate generic description"); - ParametricImplementationsGroup implementations1 = aggregation1.getImplementations(); + ParametricImplementationsGroup implementations1 = aggregation1.getImplementations(); assertImplementationCount(implementations1, 1, 0, 0); - ParametricImplementationsGroup implementations2 = aggregation2.getImplementations(); + ParametricImplementationsGroup implementations2 = aggregation2.getImplementations(); assertImplementationCount(implementations2, 1, 0, 0); - AggregationImplementation implementation = getOnlyElement(implementations1.getExactImplementations().values()); + ParametricAggregationImplementation implementation = getOnlyElement(implementations1.getExactImplementations().values()); assertEquals(implementation.getDefinitionClass(), MultiOutputAggregationFunction.class); assertDependencyCount(implementation, 0, 0, 0); assertFalse(implementation.hasSpecializedTypeParameters()); @@ -750,9 +752,9 @@ public void testInjectOperatorAggregateParse() assertEquals(aggregation.getFunctionMetadata().getDescription(), "Simple aggregate with operator injected"); assertTrue(aggregation.getFunctionMetadata().isDeterministic()); assertEquals(aggregation.getFunctionMetadata().getSignature(), expectedSignature); - ParametricImplementationsGroup implementations = aggregation.getImplementations(); + ParametricImplementationsGroup implementations = aggregation.getImplementations(); - AggregationImplementation implementation = getOnlyElement(implementations.getExactImplementations().values()); + ParametricAggregationImplementation implementation = getOnlyElement(implementations.getExactImplementations().values()); assertEquals(implementation.getDefinitionClass(), InjectOperatorAggregateFunction.class); assertDependencyCount(implementation, 1, 1, 1); @@ -814,10 +816,10 @@ public void testInjectTypeAggregateParse() assertEquals(aggregation.getFunctionMetadata().getDescription(), "Simple aggregate with type injected"); assertTrue(aggregation.getFunctionMetadata().isDeterministic()); assertEquals(aggregation.getFunctionMetadata().getSignature(), expectedSignature); - ParametricImplementationsGroup implementations = aggregation.getImplementations(); + ParametricImplementationsGroup implementations = aggregation.getImplementations(); assertEquals(implementations.getGenericImplementations().size(), 1); - AggregationImplementation implementation = implementations.getGenericImplementations().get(0); + ParametricAggregationImplementation implementation = implementations.getGenericImplementations().get(0); assertEquals(implementation.getDefinitionClass(), InjectTypeAggregateFunction.class); assertDependencyCount(implementation, 1, 1, 1); @@ -878,10 +880,10 @@ public void testInjectLiteralAggregateParse() assertEquals(aggregation.getFunctionMetadata().getDescription(), "Simple aggregate with type literal"); assertTrue(aggregation.getFunctionMetadata().isDeterministic()); assertEquals(aggregation.getFunctionMetadata().getSignature(), expectedSignature); - ParametricImplementationsGroup implementations = aggregation.getImplementations(); + ParametricImplementationsGroup implementations = aggregation.getImplementations(); assertEquals(implementations.getGenericImplementations().size(), 1); - AggregationImplementation implementation = implementations.getGenericImplementations().get(0); + ParametricAggregationImplementation implementation = implementations.getGenericImplementations().get(0); assertEquals(implementation.getDefinitionClass(), InjectLiteralAggregateFunction.class); assertDependencyCount(implementation, 1, 1, 1); @@ -946,10 +948,10 @@ public void testLongConstraintAggregateFunctionParse() assertEquals(aggregation.getFunctionMetadata().getDescription(), "Parametric aggregate with parametric type returned"); assertTrue(aggregation.getFunctionMetadata().isDeterministic()); assertEquals(aggregation.getFunctionMetadata().getSignature(), expectedSignature); - ParametricImplementationsGroup implementations = aggregation.getImplementations(); + ParametricImplementationsGroup implementations = aggregation.getImplementations(); assertEquals(implementations.getGenericImplementations().size(), 1); - AggregationImplementation implementation = implementations.getGenericImplementations().get(0); + ParametricAggregationImplementation implementation = implementations.getGenericImplementations().get(0); assertEquals(implementation.getDefinitionClass(), LongConstraintAggregateFunction.class); assertDependencyCount(implementation, 0, 0, 0); @@ -1009,9 +1011,9 @@ public void testFixedTypeParameterInjectionAggregateFunctionParse() assertTrue(aggregation.getFunctionMetadata().isDeterministic()); assertEquals(aggregation.getFunctionMetadata().getSignature(), expectedSignature); assertEquals(aggregation.getStateDetails(), ImmutableList.of(toAccumulatorStateDetails(NullableDoubleState.class, ImmutableList.of()))); - ParametricImplementationsGroup implementations = aggregation.getImplementations(); + ParametricImplementationsGroup implementations = aggregation.getImplementations(); assertImplementationCount(implementations, 1, 0, 0); - AggregationImplementation implementationDouble = implementations.getExactImplementations().get(expectedSignature); + ParametricAggregationImplementation implementationDouble = implementations.getExactImplementations().get(expectedSignature); assertEquals(implementationDouble.getDefinitionClass(), FixedTypeParameterInjectionAggregateFunction.class); assertDependencyCount(implementationDouble, 1, 1, 1); assertFalse(implementationDouble.hasSpecializedTypeParameters()); @@ -1073,9 +1075,9 @@ public void testPartiallyFixedTypeParameterInjectionAggregateFunctionParse() assertTrue(aggregation.getFunctionMetadata().isDeterministic()); assertEquals(aggregation.getFunctionMetadata().getSignature(), expectedSignature); assertEquals(aggregation.getStateDetails(), ImmutableList.of(toAccumulatorStateDetails(NullableDoubleState.class, ImmutableList.of()))); - ParametricImplementationsGroup implementations = aggregation.getImplementations(); + ParametricImplementationsGroup implementations = aggregation.getImplementations(); assertImplementationCount(implementations, 0, 0, 1); - AggregationImplementation implementationDouble = getOnlyElement(implementations.getGenericImplementations()); + ParametricAggregationImplementation implementationDouble = getOnlyElement(implementations.getGenericImplementations()); assertEquals(implementationDouble.getDefinitionClass(), PartiallyFixedTypeParameterInjectionAggregateFunction.class); assertDependencyCount(implementationDouble, 1, 1, 1); assertFalse(implementationDouble.hasSpecializedTypeParameters()); @@ -1171,8 +1173,8 @@ private static void specializeAggregationFunction(BoundSignature boundSignature, assertFalse(aggregationMetadata.isOrderSensitive()); assertFalse(aggregationMetadata.getIntermediateTypes().isEmpty()); - ResolvedFunction resolvedFunction = METADATA.resolve(TEST_SESSION, functionBinding, functionMetadata, aggregation.getFunctionDependencies(boundSignature)); - FunctionDependencies functionDependencies = new FunctionDependencies(FUNCTION_MANAGER::getScalarFunctionInvoker, resolvedFunction.getTypeDependencies(), resolvedFunction.getFunctionDependencies()); + ResolvedFunction resolvedFunction = METADATA.resolve(TEST_SESSION, GlobalSystemConnector.CATALOG_HANDLE, functionBinding, functionMetadata, aggregation.getFunctionDependencies(boundSignature)); + FunctionDependencies functionDependencies = new InternalFunctionDependencies(FUNCTION_MANAGER::getScalarFunctionImplementation, resolvedFunction.getTypeDependencies(), resolvedFunction.getFunctionDependencies()); aggregation.specialize(boundSignature, functionDependencies); } } diff --git a/core/trino-main/src/test/java/io/trino/operator/TestAnnotationEngineForScalars.java b/core/trino-main/src/test/java/io/trino/operator/TestAnnotationEngineForScalars.java index 8a01c5c81ad8..2cd72a178388 100644 --- a/core/trino-main/src/test/java/io/trino/operator/TestAnnotationEngineForScalars.java +++ b/core/trino-main/src/test/java/io/trino/operator/TestAnnotationEngineForScalars.java @@ -17,25 +17,25 @@ import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import io.airlift.slice.Slice; -import io.trino.metadata.BoundSignature; -import io.trino.metadata.FunctionDependencies; import io.trino.metadata.FunctionManager; -import io.trino.metadata.FunctionMetadata; -import io.trino.metadata.Signature; +import io.trino.metadata.InternalFunctionDependencies; import io.trino.metadata.SqlScalarFunction; import io.trino.operator.annotations.ImplementationDependency; import io.trino.operator.annotations.LiteralImplementationDependency; import io.trino.operator.annotations.TypeImplementationDependency; -import io.trino.operator.scalar.ChoicesScalarFunctionImplementation; +import io.trino.operator.scalar.ChoicesSpecializedSqlScalarFunction; import io.trino.operator.scalar.ParametricScalar; import io.trino.operator.scalar.annotations.ParametricScalarImplementation.ParametricScalarImplementationChoice; import io.trino.operator.scalar.annotations.ScalarFromAnnotationsParser; import io.trino.spi.block.Block; +import io.trino.spi.function.BoundSignature; import io.trino.spi.function.Description; +import io.trino.spi.function.FunctionMetadata; import io.trino.spi.function.IsNull; import io.trino.spi.function.LiteralParameter; import io.trino.spi.function.LiteralParameters; import io.trino.spi.function.ScalarFunction; +import io.trino.spi.function.Signature; import io.trino.spi.function.SqlNullable; import io.trino.spi.function.SqlType; import io.trino.spi.function.TypeParameter; @@ -98,9 +98,9 @@ public void testSingleImplementationScalarParse() assertImplementationCount(scalar, 1, 0, 0); BoundSignature boundSignature = new BoundSignature(expectedSignature.getName(), DOUBLE, ImmutableList.of(DOUBLE)); - ChoicesScalarFunctionImplementation specialized = (ChoicesScalarFunctionImplementation) scalar.specialize( + ChoicesSpecializedSqlScalarFunction specialized = (ChoicesSpecializedSqlScalarFunction) scalar.specialize( boundSignature, - new FunctionDependencies(FUNCTION_MANAGER::getScalarFunctionInvoker, ImmutableMap.of(), ImmutableSet.of())); + new InternalFunctionDependencies(FUNCTION_MANAGER::getScalarFunctionImplementation, ImmutableMap.of(), ImmutableSet.of())); assertFalse(specialized.getChoices().get(0).getInstanceFactory().isPresent()); } @@ -187,9 +187,9 @@ public void testWithNullablePrimitiveArgScalarParse() assertTrue(functionMetadata.getFunctionNullability().isArgumentNullable(1)); BoundSignature boundSignature = new BoundSignature(expectedSignature.getName(), DOUBLE, ImmutableList.of(DOUBLE, DOUBLE)); - ChoicesScalarFunctionImplementation specialized = (ChoicesScalarFunctionImplementation) scalar.specialize( + ChoicesSpecializedSqlScalarFunction specialized = (ChoicesSpecializedSqlScalarFunction) scalar.specialize( boundSignature, - new FunctionDependencies(FUNCTION_MANAGER::getScalarFunctionInvoker, ImmutableMap.of(), ImmutableSet.of())); + new InternalFunctionDependencies(FUNCTION_MANAGER::getScalarFunctionImplementation, ImmutableMap.of(), ImmutableSet.of())); assertFalse(specialized.getChoices().get(0).getInstanceFactory().isPresent()); } @@ -229,9 +229,9 @@ public void testWithNullableComplexArgScalarParse() assertTrue(functionMetadata.getFunctionNullability().isArgumentNullable(1)); BoundSignature boundSignature = new BoundSignature(expectedSignature.getName(), DOUBLE, ImmutableList.of(DOUBLE, DOUBLE)); - ChoicesScalarFunctionImplementation specialized = (ChoicesScalarFunctionImplementation) scalar.specialize( + ChoicesSpecializedSqlScalarFunction specialized = (ChoicesSpecializedSqlScalarFunction) scalar.specialize( boundSignature, - new FunctionDependencies(FUNCTION_MANAGER::getScalarFunctionInvoker, ImmutableMap.of(), ImmutableSet.of())); + new InternalFunctionDependencies(FUNCTION_MANAGER::getScalarFunctionImplementation, ImmutableMap.of(), ImmutableSet.of())); assertFalse(specialized.getChoices().get(0).getInstanceFactory().isPresent()); } diff --git a/core/trino-main/src/test/java/io/trino/operator/aggregation/AbstractTestAggregationFunction.java b/core/trino-main/src/test/java/io/trino/operator/aggregation/AbstractTestAggregationFunction.java index 60d4b4b35505..c74e5c7c175f 100644 --- a/core/trino-main/src/test/java/io/trino/operator/aggregation/AbstractTestAggregationFunction.java +++ b/core/trino-main/src/test/java/io/trino/operator/aggregation/AbstractTestAggregationFunction.java @@ -24,6 +24,7 @@ import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; import io.trino.spi.block.RunLengthEncodedBlock; +import io.trino.spi.function.AggregationImplementation; import io.trino.spi.function.WindowIndex; import io.trino.spi.type.Type; import io.trino.sql.tree.QualifiedName; @@ -146,14 +147,14 @@ public void testSlidingWindow() WindowIndex windowIndex = new PagesWindowIndex(pagesIndex, 0, totalPositions - 1); ResolvedFunction resolvedFunction = functionResolution.resolveFunction(QualifiedName.of(getFunctionName()), fromTypes(getFunctionParameterTypes())); - AggregationMetadata aggregationMetadata = functionResolution.getPlannerContext().getFunctionManager().getAggregateFunctionImplementation(resolvedFunction); - WindowAccumulator aggregation = createWindowAccumulator(resolvedFunction, aggregationMetadata); + AggregationImplementation aggregationImplementation = functionResolution.getPlannerContext().getFunctionManager().getAggregationImplementation(resolvedFunction); + WindowAccumulator aggregation = createWindowAccumulator(resolvedFunction, aggregationImplementation); int oldStart = 0; int oldWidth = 0; for (int start = 0; start < totalPositions; ++start) { int width = windowWidths[start]; // Note that add/removeInput's interval is inclusive on both ends - if (aggregationMetadata.getRemoveInputFunction().isPresent()) { + if (aggregationImplementation.getRemoveInputFunction().isPresent()) { for (int oldi = oldStart; oldi < oldStart + oldWidth; ++oldi) { if (oldi < start || oldi >= start + width) { aggregation.removeInput(windowIndex, oldi, oldi); @@ -166,7 +167,7 @@ public void testSlidingWindow() } } else { - aggregation = createWindowAccumulator(resolvedFunction, aggregationMetadata); + aggregation = createWindowAccumulator(resolvedFunction, aggregationImplementation); aggregation.addInput(windowIndex, start, start + width - 1); } oldStart = start; @@ -184,12 +185,12 @@ public void testSlidingWindow() } } - private static WindowAccumulator createWindowAccumulator(ResolvedFunction resolvedFunction, AggregationMetadata aggregationMetadata) + private static WindowAccumulator createWindowAccumulator(ResolvedFunction resolvedFunction, AggregationImplementation aggregationImplementation) { try { Constructor constructor = generateWindowAccumulatorClass( resolvedFunction.getSignature(), - aggregationMetadata, + aggregationImplementation, resolvedFunction.getFunctionNullability()); return constructor.newInstance(ImmutableList.of()); } diff --git a/core/trino-main/src/test/java/io/trino/operator/aggregation/TestAccumulatorCompiler.java b/core/trino-main/src/test/java/io/trino/operator/aggregation/TestAccumulatorCompiler.java index 24d174c3750f..a84db0a52e2a 100644 --- a/core/trino-main/src/test/java/io/trino/operator/aggregation/TestAccumulatorCompiler.java +++ b/core/trino-main/src/test/java/io/trino/operator/aggregation/TestAccumulatorCompiler.java @@ -15,8 +15,6 @@ import com.google.common.collect.ImmutableList; import io.airlift.bytecode.DynamicClassLoader; -import io.trino.metadata.BoundSignature; -import io.trino.metadata.FunctionNullability; import io.trino.operator.aggregation.state.StateCompiler; import io.trino.server.PluginManager; import io.trino.spi.Page; @@ -25,6 +23,9 @@ import io.trino.spi.function.AccumulatorState; import io.trino.spi.function.AccumulatorStateFactory; import io.trino.spi.function.AccumulatorStateSerializer; +import io.trino.spi.function.AggregationImplementation; +import io.trino.spi.function.BoundSignature; +import io.trino.spi.function.FunctionNullability; import io.trino.spi.type.LongTimestamp; import io.trino.spi.type.RealType; import io.trino.spi.type.TimestampType; @@ -32,7 +33,6 @@ import org.testng.annotations.Test; import java.lang.invoke.MethodHandle; -import java.util.Optional; import static io.trino.operator.aggregation.AggregationFunctionAdapter.AggregationParameterKind.INPUT_CHANNEL; import static io.trino.operator.aggregation.AggregationFunctionAdapter.AggregationParameterKind.STATE; @@ -85,21 +85,18 @@ private static void assertGenerateAccumulator(Cl inputFunction = normalizeInputMethod(inputFunction, signature, STATE, INPUT_CHANNEL); MethodHandle combineFunction = methodHandle(aggregation, "combine", stateInterface, stateInterface); MethodHandle outputFunction = methodHandle(aggregation, "output", stateInterface, BlockBuilder.class); - AggregationMetadata metadata = new AggregationMetadata( - inputFunction, - Optional.empty(), - Optional.of(combineFunction), - outputFunction, - ImmutableList.of(new AggregationMetadata.AccumulatorStateDescriptor<>( - stateInterface, - stateSerializer, - stateFactory))); + AggregationImplementation implementation = AggregationImplementation.builder() + .inputFunction(inputFunction) + .combineFunction(combineFunction) + .outputFunction(outputFunction) + .accumulatorStateDescriptor(stateInterface, stateSerializer, stateFactory) + .build(); FunctionNullability functionNullability = new FunctionNullability(false, ImmutableList.of(false)); // test if we can compile aggregation - AccumulatorFactory accumulatorFactory = AccumulatorCompiler.generateAccumulatorFactory(signature, metadata, functionNullability); + AccumulatorFactory accumulatorFactory = AccumulatorCompiler.generateAccumulatorFactory(signature, implementation, functionNullability); assertThat(accumulatorFactory).isNotNull(); - assertThat(AccumulatorCompiler.generateWindowAccumulatorClass(signature, metadata, functionNullability)).isNotNull(); + assertThat(AccumulatorCompiler.generateWindowAccumulatorClass(signature, implementation, functionNullability)).isNotNull(); TestingAggregationFunction aggregationFunction = new TestingAggregationFunction( ImmutableList.of(TIMESTAMP_PICOS), diff --git a/core/trino-main/src/test/java/io/trino/operator/aggregation/TestingAggregationFunction.java b/core/trino-main/src/test/java/io/trino/operator/aggregation/TestingAggregationFunction.java index eb9ea8e3fe91..8947bb94ed18 100644 --- a/core/trino-main/src/test/java/io/trino/operator/aggregation/TestingAggregationFunction.java +++ b/core/trino-main/src/test/java/io/trino/operator/aggregation/TestingAggregationFunction.java @@ -14,8 +14,9 @@ package io.trino.operator.aggregation; import com.google.common.collect.ImmutableList; -import io.trino.metadata.BoundSignature; -import io.trino.metadata.FunctionNullability; +import io.trino.spi.function.AggregationImplementation; +import io.trino.spi.function.BoundSignature; +import io.trino.spi.function.FunctionNullability; import io.trino.spi.type.RowType; import io.trino.spi.type.Type; import io.trino.spi.type.TypeOperators; @@ -43,15 +44,15 @@ public class TestingAggregationFunction private final AccumulatorFactory factory; private final DistinctAccumulatorFactory distinctFactory; - public TestingAggregationFunction(BoundSignature signature, FunctionNullability functionNullability, AggregationMetadata aggregationMetadata) + public TestingAggregationFunction(BoundSignature signature, FunctionNullability functionNullability, AggregationImplementation aggregationImplementation) { this.parameterTypes = signature.getArgumentTypes(); - List intermediateTypes = aggregationMetadata.getAccumulatorStateDescriptors().stream() + List intermediateTypes = aggregationImplementation.getAccumulatorStateDescriptors().stream() .map(stateDescriptor -> stateDescriptor.getSerializer().getSerializedType()) .collect(toImmutableList()); intermediateType = (intermediateTypes.size() == 1) ? getOnlyElement(intermediateTypes) : RowType.anonymous(intermediateTypes); this.finalType = signature.getReturnType(); - this.factory = generateAccumulatorFactory(signature, aggregationMetadata, functionNullability); + this.factory = generateAccumulatorFactory(signature, aggregationImplementation, functionNullability); distinctFactory = new DistinctAccumulatorFactory( factory, parameterTypes, diff --git a/core/trino-main/src/test/java/io/trino/operator/scalar/AbstractTestFunctions.java b/core/trino-main/src/test/java/io/trino/operator/scalar/AbstractTestFunctions.java index 3b49618a8a07..09f1a5b5f292 100644 --- a/core/trino-main/src/test/java/io/trino/operator/scalar/AbstractTestFunctions.java +++ b/core/trino-main/src/test/java/io/trino/operator/scalar/AbstractTestFunctions.java @@ -35,7 +35,7 @@ import static io.airlift.testing.Closeables.closeAllRuntimeException; import static io.trino.SessionTestUtils.TEST_SESSION; -import static io.trino.metadata.Signature.mangleOperatorName; +import static io.trino.metadata.OperatorNameUtil.mangleOperatorName; import static io.trino.operator.scalar.timestamp.VarcharToTimestampCast.castToLongTimestamp; import static io.trino.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT; import static io.trino.spi.StandardErrorCode.NOT_SUPPORTED; diff --git a/core/trino-main/src/test/java/io/trino/operator/scalar/BenchmarkArrayFilter.java b/core/trino-main/src/test/java/io/trino/operator/scalar/BenchmarkArrayFilter.java index 60e3b826c714..6628df298941 100644 --- a/core/trino-main/src/test/java/io/trino/operator/scalar/BenchmarkArrayFilter.java +++ b/core/trino-main/src/test/java/io/trino/operator/scalar/BenchmarkArrayFilter.java @@ -15,11 +15,8 @@ import com.google.common.collect.ImmutableList; import io.trino.jmh.Benchmarks; -import io.trino.metadata.BoundSignature; -import io.trino.metadata.FunctionMetadata; import io.trino.metadata.InternalFunctionBundle; import io.trino.metadata.ResolvedFunction; -import io.trino.metadata.Signature; import io.trino.metadata.SqlScalarFunction; import io.trino.metadata.TestingFunctionResolution; import io.trino.operator.DriverYieldSignal; @@ -27,6 +24,9 @@ import io.trino.spi.Page; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.function.BoundSignature; +import io.trino.spi.function.FunctionMetadata; +import io.trino.spi.function.Signature; import io.trino.spi.type.ArrayType; import io.trino.spi.type.Type; import io.trino.spi.type.TypeSignature; @@ -206,10 +206,10 @@ private ExactArrayFilterFunction() } @Override - protected ScalarFunctionImplementation specialize(BoundSignature boundSignature) + protected SpecializedSqlScalarFunction specialize(BoundSignature boundSignature) { Type type = ((ArrayType) boundSignature.getReturnType()).getElementType(); - return new ChoicesScalarFunctionImplementation( + return new ChoicesSpecializedSqlScalarFunction( boundSignature, FAIL_ON_NULL, ImmutableList.of(NEVER_NULL, NEVER_NULL), diff --git a/core/trino-main/src/test/java/io/trino/operator/scalar/TestParametricScalarImplementationValidation.java b/core/trino-main/src/test/java/io/trino/operator/scalar/TestParametricScalarFunctionImplementationValidation.java similarity index 74% rename from core/trino-main/src/test/java/io/trino/operator/scalar/TestParametricScalarImplementationValidation.java rename to core/trino-main/src/test/java/io/trino/operator/scalar/TestParametricScalarFunctionImplementationValidation.java index 34d4bdea0169..cd459568b1d3 100644 --- a/core/trino-main/src/test/java/io/trino/operator/scalar/TestParametricScalarImplementationValidation.java +++ b/core/trino-main/src/test/java/io/trino/operator/scalar/TestParametricScalarFunctionImplementationValidation.java @@ -14,8 +14,8 @@ package io.trino.operator.scalar; import com.google.common.collect.ImmutableList; -import io.trino.metadata.BoundSignature; import io.trino.spi.connector.ConnectorSession; +import io.trino.spi.function.BoundSignature; import org.testng.annotations.Test; import java.lang.invoke.MethodHandle; @@ -28,33 +28,33 @@ import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.testng.Assert.assertEquals; -public class TestParametricScalarImplementationValidation +public class TestParametricScalarFunctionImplementationValidation { - private static final MethodHandle STATE_FACTORY = methodHandle(TestParametricScalarImplementationValidation.class, "createState"); + private static final MethodHandle STATE_FACTORY = methodHandle(TestParametricScalarFunctionImplementationValidation.class, "createState"); @Test public void testConnectorSessionPosition() { // Without cached instance factory - MethodHandle validFunctionMethodHandle = methodHandle(TestParametricScalarImplementationValidation.class, "validConnectorSessionParameterPosition", ConnectorSession.class, long.class, long.class); - ChoicesScalarFunctionImplementation validFunction = new ChoicesScalarFunctionImplementation( + MethodHandle validFunctionMethodHandle = methodHandle(TestParametricScalarFunctionImplementationValidation.class, "validConnectorSessionParameterPosition", ConnectorSession.class, long.class, long.class); + ChoicesSpecializedSqlScalarFunction validFunction = new ChoicesSpecializedSqlScalarFunction( new BoundSignature("test", BIGINT, ImmutableList.of(BIGINT, BIGINT)), FAIL_ON_NULL, ImmutableList.of(NEVER_NULL, NEVER_NULL), validFunctionMethodHandle); assertEquals(validFunction.getChoices().get(0).getMethodHandle(), validFunctionMethodHandle); - assertThatThrownBy(() -> new ChoicesScalarFunctionImplementation( + assertThatThrownBy(() -> new ChoicesSpecializedSqlScalarFunction( new BoundSignature("test", BIGINT, ImmutableList.of(BIGINT, BIGINT)), FAIL_ON_NULL, ImmutableList.of(NEVER_NULL, NEVER_NULL), - methodHandle(TestParametricScalarImplementationValidation.class, "invalidConnectorSessionParameterPosition", long.class, long.class, ConnectorSession.class))) + methodHandle(TestParametricScalarFunctionImplementationValidation.class, "invalidConnectorSessionParameterPosition", long.class, long.class, ConnectorSession.class))) .isInstanceOf(IllegalArgumentException.class) .hasMessage("ConnectorSession must be the first argument when instanceFactory is not present"); // With cached instance factory - MethodHandle validFunctionWithInstanceFactoryMethodHandle = methodHandle(TestParametricScalarImplementationValidation.class, "validConnectorSessionParameterPosition", Object.class, ConnectorSession.class, long.class, long.class); - ChoicesScalarFunctionImplementation validFunctionWithInstanceFactory = new ChoicesScalarFunctionImplementation( + MethodHandle validFunctionWithInstanceFactoryMethodHandle = methodHandle(TestParametricScalarFunctionImplementationValidation.class, "validConnectorSessionParameterPosition", Object.class, ConnectorSession.class, long.class, long.class); + ChoicesSpecializedSqlScalarFunction validFunctionWithInstanceFactory = new ChoicesSpecializedSqlScalarFunction( new BoundSignature("test", BIGINT, ImmutableList.of(BIGINT, BIGINT)), FAIL_ON_NULL, ImmutableList.of(NEVER_NULL, NEVER_NULL), @@ -62,11 +62,11 @@ public void testConnectorSessionPosition() Optional.of(STATE_FACTORY)); assertEquals(validFunctionWithInstanceFactory.getChoices().get(0).getMethodHandle(), validFunctionWithInstanceFactoryMethodHandle); - assertThatThrownBy(() -> new ChoicesScalarFunctionImplementation( + assertThatThrownBy(() -> new ChoicesSpecializedSqlScalarFunction( new BoundSignature("test", BIGINT, ImmutableList.of(BIGINT, BIGINT)), FAIL_ON_NULL, ImmutableList.of(NEVER_NULL, NEVER_NULL), - methodHandle(TestParametricScalarImplementationValidation.class, "invalidConnectorSessionParameterPosition", Object.class, long.class, long.class, ConnectorSession.class), + methodHandle(TestParametricScalarFunctionImplementationValidation.class, "invalidConnectorSessionParameterPosition", Object.class, long.class, long.class, ConnectorSession.class), Optional.of(STATE_FACTORY))) .isInstanceOf(IllegalArgumentException.class) .hasMessage("ConnectorSession must be the second argument when instanceFactory is present"); diff --git a/core/trino-main/src/test/java/io/trino/sql/gen/TestVarArgsToArrayAdapterGenerator.java b/core/trino-main/src/test/java/io/trino/sql/gen/TestVarArgsToArrayAdapterGenerator.java index 91aa4fd15a36..41ae562596b6 100644 --- a/core/trino-main/src/test/java/io/trino/sql/gen/TestVarArgsToArrayAdapterGenerator.java +++ b/core/trino-main/src/test/java/io/trino/sql/gen/TestVarArgsToArrayAdapterGenerator.java @@ -15,14 +15,14 @@ import com.google.common.base.Joiner; import io.trino.annotation.UsedByGeneratedCode; -import io.trino.metadata.BoundSignature; -import io.trino.metadata.FunctionMetadata; -import io.trino.metadata.Signature; import io.trino.metadata.SqlScalarFunction; import io.trino.operator.scalar.AbstractTestFunctions; -import io.trino.operator.scalar.ChoicesScalarFunctionImplementation; -import io.trino.operator.scalar.ScalarFunctionImplementation; +import io.trino.operator.scalar.ChoicesSpecializedSqlScalarFunction; +import io.trino.operator.scalar.SpecializedSqlScalarFunction; +import io.trino.spi.function.BoundSignature; +import io.trino.spi.function.FunctionMetadata; import io.trino.spi.function.InvocationConvention.InvocationReturnConvention; +import io.trino.spi.function.Signature; import org.testng.annotations.BeforeClass; import org.testng.annotations.Test; @@ -87,7 +87,7 @@ private TestVarArgsSum() } @Override - protected ScalarFunctionImplementation specialize(BoundSignature boundSignature) + protected SpecializedSqlScalarFunction specialize(BoundSignature boundSignature) { VarArgsToArrayAdapterGenerator.MethodHandleAndConstructor methodHandleAndConstructor = generateVarArgsToArrayAdapter( long.class, @@ -95,7 +95,7 @@ protected ScalarFunctionImplementation specialize(BoundSignature boundSignature) boundSignature.getArity(), METHOD_HANDLE, USER_STATE_FACTORY); - return new ChoicesScalarFunctionImplementation( + return new ChoicesSpecializedSqlScalarFunction( boundSignature, InvocationReturnConvention.FAIL_ON_NULL, nCopies(boundSignature.getArity(), NEVER_NULL), diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/TestEffectivePredicateExtractor.java b/core/trino-main/src/test/java/io/trino/sql/planner/TestEffectivePredicateExtractor.java index 82759d018451..9ca9bdfb3708 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/TestEffectivePredicateExtractor.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/TestEffectivePredicateExtractor.java @@ -20,10 +20,8 @@ import com.google.common.collect.ImmutableSet; import com.google.common.collect.Maps; import io.trino.Session; +import io.trino.connector.system.GlobalSystemConnector; import io.trino.metadata.AbstractMockMetadata; -import io.trino.metadata.BoundSignature; -import io.trino.metadata.FunctionMetadata; -import io.trino.metadata.FunctionNullability; import io.trino.metadata.Metadata; import io.trino.metadata.ResolvedFunction; import io.trino.metadata.TableHandle; @@ -34,6 +32,9 @@ import io.trino.spi.connector.ConnectorTableHandle; import io.trino.spi.connector.ConnectorTableProperties; import io.trino.spi.connector.SortOrder; +import io.trino.spi.function.BoundSignature; +import io.trino.spi.function.FunctionMetadata; +import io.trino.spi.function.FunctionNullability; import io.trino.spi.predicate.Domain; import io.trino.spi.predicate.TupleDomain; import io.trino.spi.type.RowType; @@ -93,7 +94,7 @@ import static com.google.common.base.Preconditions.checkState; import static com.google.common.collect.ImmutableList.toImmutableList; -import static io.trino.metadata.FunctionId.toFunctionId; +import static io.trino.spi.function.FunctionId.toFunctionId; import static io.trino.spi.function.FunctionKind.SCALAR; import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.DoubleType.DOUBLE; @@ -1193,6 +1194,7 @@ private static ResolvedFunction fakeFunction(String name) BoundSignature boundSignature = new BoundSignature(name, UNKNOWN, ImmutableList.of()); return new ResolvedFunction( boundSignature, + GlobalSystemConnector.CATALOG_HANDLE, toFunctionId(boundSignature.toSignature()), SCALAR, true, diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/TestLiteralEncoder.java b/core/trino-main/src/test/java/io/trino/sql/planner/TestLiteralEncoder.java index 4ba428f1d08b..33af3f07de8a 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/TestLiteralEncoder.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/TestLiteralEncoder.java @@ -17,13 +17,14 @@ import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import io.airlift.slice.Slice; -import io.trino.metadata.BoundSignature; -import io.trino.metadata.FunctionNullability; +import io.trino.connector.system.GlobalSystemConnector; import io.trino.metadata.LiteralFunction; import io.trino.metadata.ResolvedFunction; -import io.trino.metadata.Signature; import io.trino.operator.scalar.Re2JCastToRegexpFunction; import io.trino.security.AllowAllAccessControl; +import io.trino.spi.function.BoundSignature; +import io.trino.spi.function.FunctionNullability; +import io.trino.spi.function.Signature; import io.trino.spi.type.LongTimestamp; import io.trino.spi.type.LongTimestampWithTimeZone; import io.trino.spi.type.TimeZoneKey; @@ -46,11 +47,11 @@ import static io.airlift.slice.Slices.utf8Slice; import static io.airlift.testing.Assertions.assertEqualsIgnoreCase; import static io.trino.SessionTestUtils.TEST_SESSION; -import static io.trino.metadata.FunctionId.toFunctionId; import static io.trino.metadata.LiteralFunction.LITERAL_FUNCTION_NAME; import static io.trino.operator.scalar.JoniRegexpCasts.castVarcharToJoniRegexp; import static io.trino.operator.scalar.JsonFunctions.castVarcharToJsonPath; import static io.trino.operator.scalar.StringFunctions.castVarcharToCodePoints; +import static io.trino.spi.function.FunctionId.toFunctionId; import static io.trino.spi.function.FunctionKind.SCALAR; import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.CharType.createCharType; @@ -85,6 +86,7 @@ public class TestLiteralEncoder private final ResolvedFunction literalFunction = new ResolvedFunction( new BoundSignature(LITERAL_FUNCTION_NAME, VARBINARY, ImmutableList.of(VARBINARY)), + GlobalSystemConnector.CATALOG_HANDLE, new LiteralFunction(PLANNER_CONTEXT.getBlockEncodingSerde()).getFunctionMetadata().getFunctionId(), SCALAR, true, @@ -94,6 +96,7 @@ public class TestLiteralEncoder private final ResolvedFunction base64Function = new ResolvedFunction( new BoundSignature("from_base64", VARBINARY, ImmutableList.of(VARCHAR)), + GlobalSystemConnector.CATALOG_HANDLE, toFunctionId(Signature.builder() .name("from_base64") .returnType(VARBINARY) diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/TestingPlannerContext.java b/core/trino-main/src/test/java/io/trino/sql/planner/TestingPlannerContext.java index 037f5a3d8942..636b5630d33a 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/TestingPlannerContext.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/TestingPlannerContext.java @@ -14,6 +14,7 @@ package io.trino.sql.planner; import io.trino.FeaturesConfig; +import io.trino.connector.CatalogServiceProvider; import io.trino.metadata.BlockEncodingManager; import io.trino.metadata.FunctionBundle; import io.trino.metadata.FunctionManager; @@ -131,7 +132,7 @@ public PlannerContext build() metadata = builder.build(); } - FunctionManager functionManager = new FunctionManager(globalFunctionCatalog); + FunctionManager functionManager = new FunctionManager(CatalogServiceProvider.fail(), globalFunctionCatalog); globalFunctionCatalog.addFunctions(new InternalFunctionBundle( new JsonExistsFunction(functionManager, metadata, typeManager), new JsonValueFunction(functionManager, metadata, typeManager), diff --git a/core/trino-spi/src/main/java/io/trino/spi/connector/Connector.java b/core/trino-spi/src/main/java/io/trino/spi/connector/Connector.java index 91192179a66b..8fdb5222449e 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/connector/Connector.java +++ b/core/trino-spi/src/main/java/io/trino/spi/connector/Connector.java @@ -13,13 +13,16 @@ */ package io.trino.spi.connector; +import io.trino.spi.Experimental; import io.trino.spi.eventlistener.EventListener; +import io.trino.spi.function.FunctionProvider; import io.trino.spi.procedure.Procedure; import io.trino.spi.ptf.ConnectorTableFunction; import io.trino.spi.session.PropertyMetadata; import io.trino.spi.transaction.IsolationLevel; import java.util.List; +import java.util.Optional; import java.util.Set; import static java.util.Collections.emptyList; @@ -133,6 +136,15 @@ default Set getSystemTables() return emptySet(); } + /** + * @return the set of procedures provided by this connector + */ + @Experimental(eta = "2022-10-31") + default Optional getFunctionProvider() + { + return Optional.empty(); + } + /** * @return the set of procedures provided by this connector */ diff --git a/core/trino-spi/src/main/java/io/trino/spi/connector/ConnectorMetadata.java b/core/trino-spi/src/main/java/io/trino/spi/connector/ConnectorMetadata.java index f030788eef9c..aa0bc2f7a015 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/connector/ConnectorMetadata.java +++ b/core/trino-spi/src/main/java/io/trino/spi/connector/ConnectorMetadata.java @@ -18,6 +18,12 @@ import io.trino.spi.expression.ConnectorExpression; import io.trino.spi.expression.Constant; import io.trino.spi.expression.Variable; +import io.trino.spi.function.AggregationFunctionMetadata; +import io.trino.spi.function.BoundSignature; +import io.trino.spi.function.FunctionDependencyDeclaration; +import io.trino.spi.function.FunctionId; +import io.trino.spi.function.FunctionMetadata; +import io.trino.spi.function.SchemaFunctionName; import io.trino.spi.predicate.TupleDomain; import io.trino.spi.ptf.ConnectorTableFunctionHandle; import io.trino.spi.security.GrantInfo; @@ -854,6 +860,46 @@ default Optional resolveIndex(ConnectorSession session, return Optional.empty(); } + /** + * List available functions. + */ + default Collection listFunctions(ConnectorSession session, String schemaName) + { + return List.of(); + } + + /** + * Get all functions with specified name. + */ + default Collection getFunctions(ConnectorSession session, SchemaFunctionName name) + { + return List.of(); + } + + /** + * Return the function with the specified id. + */ + default FunctionMetadata getFunctionMetadata(ConnectorSession session, FunctionId functionId) + { + throw new IllegalArgumentException("Unknown function " + functionId); + } + + /** + * Returns the aggregation metadata for the aggregation function with the specified id. + */ + default AggregationFunctionMetadata getAggregationFunctionMetadata(ConnectorSession session, FunctionId functionId) + { + throw new IllegalArgumentException("Unknown function " + functionId); + } + + /** + * Returns the dependencies of the function with the specified id. + */ + default FunctionDependencyDeclaration getFunctionDependencies(ConnectorSession session, FunctionId functionId, BoundSignature boundSignature) + { + throw new IllegalArgumentException("Unknown function " + functionId); + } + /** * Does the specified role exist. */ diff --git a/core/trino-main/src/main/java/io/trino/metadata/AggregationFunctionMetadata.java b/core/trino-spi/src/main/java/io/trino/spi/function/AggregationFunctionMetadata.java similarity index 84% rename from core/trino-main/src/main/java/io/trino/metadata/AggregationFunctionMetadata.java rename to core/trino-spi/src/main/java/io/trino/spi/function/AggregationFunctionMetadata.java index 70931ec3f7a0..a02eb3518666 100644 --- a/core/trino-main/src/main/java/io/trino/metadata/AggregationFunctionMetadata.java +++ b/core/trino-spi/src/main/java/io/trino/spi/function/AggregationFunctionMetadata.java @@ -11,18 +11,19 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package io.trino.metadata; +package io.trino.spi.function; -import com.google.common.collect.ImmutableList; +import io.trino.spi.Experimental; import io.trino.spi.type.Type; import io.trino.spi.type.TypeSignature; import java.util.ArrayList; import java.util.List; +import java.util.StringJoiner; -import static com.google.common.base.MoreObjects.toStringHelper; import static java.util.Objects.requireNonNull; +@Experimental(eta = "2022-10-31") public class AggregationFunctionMetadata { private final boolean orderSensitive; @@ -31,7 +32,7 @@ public class AggregationFunctionMetadata private AggregationFunctionMetadata(boolean orderSensitive, List intermediateTypes) { this.orderSensitive = orderSensitive; - this.intermediateTypes = ImmutableList.copyOf(requireNonNull(intermediateTypes, "intermediateTypes is null")); + this.intermediateTypes = List.copyOf(requireNonNull(intermediateTypes, "intermediateTypes is null")); } public boolean isOrderSensitive() @@ -52,9 +53,9 @@ public List getIntermediateTypes() @Override public String toString() { - return toStringHelper(this) - .add("orderSensitive", orderSensitive) - .add("intermediateTypes", intermediateTypes) + return new StringJoiner(", ", AggregationFunctionMetadata.class.getSimpleName() + "[", "]") + .add("orderSensitive=" + orderSensitive) + .add("intermediateTypes=" + intermediateTypes) .toString(); } diff --git a/core/trino-spi/src/main/java/io/trino/spi/function/AggregationImplementation.java b/core/trino-spi/src/main/java/io/trino/spi/function/AggregationImplementation.java new file mode 100644 index 000000000000..dd308264975e --- /dev/null +++ b/core/trino-spi/src/main/java/io/trino/spi/function/AggregationImplementation.java @@ -0,0 +1,225 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.spi.function; + +import io.trino.spi.Experimental; + +import java.lang.invoke.MethodHandle; +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; + +import static java.util.Objects.requireNonNull; + +@Experimental(eta = "2022-10-31") +public class AggregationImplementation +{ + private final MethodHandle inputFunction; + private final Optional removeInputFunction; + private final Optional combineFunction; + private final MethodHandle outputFunction; + private final List> accumulatorStateDescriptors; + private final List> lambdaInterfaces; + + private AggregationImplementation( + MethodHandle inputFunction, + Optional removeInputFunction, + Optional combineFunction, + MethodHandle outputFunction, + List> accumulatorStateDescriptors, + List> lambdaInterfaces) + { + this.inputFunction = requireNonNull(inputFunction, "inputFunction is null"); + this.removeInputFunction = requireNonNull(removeInputFunction, "removeInputFunction is null"); + this.combineFunction = requireNonNull(combineFunction, "combineFunction is null"); + this.outputFunction = requireNonNull(outputFunction, "outputFunction is null"); + this.accumulatorStateDescriptors = requireNonNull(accumulatorStateDescriptors, "accumulatorStateDescriptors is null"); + this.lambdaInterfaces = List.copyOf(requireNonNull(lambdaInterfaces, "lambdaInterfaces is null")); + } + + public MethodHandle getInputFunction() + { + return inputFunction; + } + + public Optional getRemoveInputFunction() + { + return removeInputFunction; + } + + public Optional getCombineFunction() + { + return combineFunction; + } + + public MethodHandle getOutputFunction() + { + return outputFunction; + } + + public List> getAccumulatorStateDescriptors() + { + return accumulatorStateDescriptors; + } + + public List> getLambdaInterfaces() + { + return lambdaInterfaces; + } + + public static class AccumulatorStateDescriptor + { + private final Class stateInterface; + private final AccumulatorStateSerializer serializer; + private final AccumulatorStateFactory factory; + + private AccumulatorStateDescriptor(Class stateInterface, AccumulatorStateSerializer serializer, AccumulatorStateFactory factory) + { + this.stateInterface = requireNonNull(stateInterface, "stateInterface is null"); + this.serializer = requireNonNull(serializer, "serializer is null"); + this.factory = requireNonNull(factory, "factory is null"); + } + + // this is only used to verify method interfaces + public Class getStateInterface() + { + return stateInterface; + } + + public AccumulatorStateSerializer getSerializer() + { + return serializer; + } + + public AccumulatorStateFactory getFactory() + { + return factory; + } + + public static Builder builder(Class stateInterface) + { + return new Builder<>(stateInterface); + } + + public static class Builder + { + private final Class stateInterface; + private AccumulatorStateSerializer serializer; + private AccumulatorStateFactory factory; + + private Builder(Class stateInterface) + { + this.stateInterface = requireNonNull(stateInterface, "stateInterface is null"); + } + + public Builder serializer(AccumulatorStateSerializer serializer) + { + this.serializer = serializer; + return this; + } + + public Builder factory(AccumulatorStateFactory factory) + { + this.factory = factory; + return this; + } + + public AccumulatorStateDescriptor build() + { + return new AccumulatorStateDescriptor<>(stateInterface, serializer, factory); + } + } + } + + public static Builder builder() + { + return new Builder(); + } + + public static class Builder + { + private MethodHandle inputFunction; + private Optional removeInputFunction = Optional.empty(); + private Optional combineFunction = Optional.empty(); + private MethodHandle outputFunction; + private List> accumulatorStateDescriptors = new ArrayList<>(); + private List> lambdaInterfaces = List.of(); + + private Builder() {} + + public Builder inputFunction(MethodHandle inputFunction) + { + this.inputFunction = requireNonNull(inputFunction, "inputFunction is null"); + return this; + } + + public Builder removeInputFunction(MethodHandle removeInputFunction) + { + this.removeInputFunction = Optional.of(requireNonNull(removeInputFunction, "removeInputFunction is null")); + return this; + } + + public Builder combineFunction(MethodHandle combineFunction) + { + this.combineFunction = Optional.of(requireNonNull(combineFunction, "combineFunction is null")); + return this; + } + + public Builder outputFunction(MethodHandle outputFunction) + { + this.outputFunction = requireNonNull(outputFunction, "outputFunction is null"); + return this; + } + + public Builder accumulatorStateDescriptor(Class stateInterface, AccumulatorStateSerializer serializer, AccumulatorStateFactory factory) + { + this.accumulatorStateDescriptors.add(AccumulatorStateDescriptor.builder(stateInterface) + .serializer(serializer) + .factory(factory) + .build()); + return this; + } + + public Builder accumulatorStateDescriptors(List> accumulatorStateDescriptors) + { + requireNonNull(accumulatorStateDescriptors, "accumulatorStateDescriptors is null"); + + this.accumulatorStateDescriptors = new ArrayList<>(); + this.accumulatorStateDescriptors.addAll(accumulatorStateDescriptors); + return this; + } + + public Builder lambdaInterfaces(Class... lambdaInterfaces) + { + return lambdaInterfaces(List.of(lambdaInterfaces)); + } + + public Builder lambdaInterfaces(List> lambdaInterfaces) + { + this.lambdaInterfaces = List.copyOf(requireNonNull(lambdaInterfaces, "lambdaInterfaces is null")); + return this; + } + + public AggregationImplementation build() + { + return new AggregationImplementation( + inputFunction, + removeInputFunction, + combineFunction, + outputFunction, + accumulatorStateDescriptors, + lambdaInterfaces); + } + } +} diff --git a/core/trino-main/src/main/java/io/trino/metadata/BoundSignature.java b/core/trino-spi/src/main/java/io/trino/spi/function/BoundSignature.java similarity index 90% rename from core/trino-main/src/main/java/io/trino/metadata/BoundSignature.java rename to core/trino-spi/src/main/java/io/trino/spi/function/BoundSignature.java index 599a891999e3..6100ca8058ef 100644 --- a/core/trino-main/src/main/java/io/trino/metadata/BoundSignature.java +++ b/core/trino-spi/src/main/java/io/trino/spi/function/BoundSignature.java @@ -11,20 +11,21 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package io.trino.metadata; +package io.trino.spi.function; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; -import com.google.common.collect.ImmutableList; +import io.trino.spi.Experimental; import io.trino.spi.type.Type; import java.util.List; import java.util.Objects; -import static com.google.common.collect.ImmutableList.toImmutableList; import static java.util.Objects.requireNonNull; import static java.util.stream.Collectors.joining; +import static java.util.stream.Collectors.toUnmodifiableList; +@Experimental(eta = "2022-10-31") public class BoundSignature { private final String name; @@ -39,7 +40,7 @@ public BoundSignature( { this.name = requireNonNull(name, "name is null"); this.returnType = requireNonNull(returnType, "returnType is null"); - this.argumentTypes = ImmutableList.copyOf(requireNonNull(argumentTypes, "argumentTypes is null")); + this.argumentTypes = List.copyOf(requireNonNull(argumentTypes, "argumentTypes is null")); } @JsonProperty @@ -77,7 +78,7 @@ public Signature toSignature() .returnType(returnType) .argumentTypes(argumentTypes.stream() .map(Type::getTypeSignature) - .collect(toImmutableList())) + .collect(toUnmodifiableList())) .build(); } diff --git a/core/trino-spi/src/main/java/io/trino/spi/function/FunctionDependencies.java b/core/trino-spi/src/main/java/io/trino/spi/function/FunctionDependencies.java new file mode 100644 index 000000000000..b3f27a4c1ced --- /dev/null +++ b/core/trino-spi/src/main/java/io/trino/spi/function/FunctionDependencies.java @@ -0,0 +1,44 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.spi.function; + +import io.trino.spi.Experimental; +import io.trino.spi.type.Type; +import io.trino.spi.type.TypeSignature; + +import java.util.List; + +@Experimental(eta = "2022-10-31") +public interface FunctionDependencies +{ + Type getType(TypeSignature typeSignature); + + FunctionNullability getFunctionNullability(QualifiedFunctionName name, List parameterTypes); + + FunctionNullability getOperatorNullability(OperatorType operatorType, List parameterTypes); + + FunctionNullability getCastNullability(Type fromType, Type toType); + + ScalarFunctionImplementation getScalarFunctionImplementation(QualifiedFunctionName name, List parameterTypes, InvocationConvention invocationConvention); + + ScalarFunctionImplementation getScalarFunctionImplementationSignature(QualifiedFunctionName name, List parameterTypes, InvocationConvention invocationConvention); + + ScalarFunctionImplementation getOperatorImplementation(OperatorType operatorType, List parameterTypes, InvocationConvention invocationConvention); + + ScalarFunctionImplementation getOperatorImplementationSignature(OperatorType operatorType, List parameterTypes, InvocationConvention invocationConvention); + + ScalarFunctionImplementation getCastImplementation(Type fromType, Type toType, InvocationConvention invocationConvention); + + ScalarFunctionImplementation getCastImplementationSignature(TypeSignature fromType, TypeSignature toType, InvocationConvention invocationConvention); +} diff --git a/core/trino-main/src/main/java/io/trino/metadata/FunctionDependencyDeclaration.java b/core/trino-spi/src/main/java/io/trino/spi/function/FunctionDependencyDeclaration.java similarity index 86% rename from core/trino-main/src/main/java/io/trino/metadata/FunctionDependencyDeclaration.java rename to core/trino-spi/src/main/java/io/trino/spi/function/FunctionDependencyDeclaration.java index 6a7c8b5c1b7c..f376d69b3f2e 100644 --- a/core/trino-main/src/main/java/io/trino/metadata/FunctionDependencyDeclaration.java +++ b/core/trino-spi/src/main/java/io/trino/spi/function/FunctionDependencyDeclaration.java @@ -11,14 +11,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package io.trino.metadata; +package io.trino.spi.function; -import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableSet; -import io.trino.spi.function.OperatorType; +import io.trino.spi.Experimental; import io.trino.spi.type.Type; import io.trino.spi.type.TypeSignature; -import io.trino.sql.tree.QualifiedName; import java.util.LinkedHashSet; import java.util.List; @@ -26,10 +23,11 @@ import java.util.Set; import java.util.stream.Collectors; -import static com.google.common.collect.ImmutableList.toImmutableList; import static java.lang.String.format; import static java.util.Objects.requireNonNull; +import static java.util.stream.Collectors.toUnmodifiableList; +@Experimental(eta = "2022-10-31") public class FunctionDependencyDeclaration { public static final FunctionDependencyDeclaration NO_DEPENDENCIES = builder().build(); @@ -50,10 +48,10 @@ private FunctionDependencyDeclaration( Set operatorDependencies, Set castDependencies) { - this.typeDependencies = ImmutableSet.copyOf(requireNonNull(typeDependencies, "typeDependencies is null")); - this.functionDependencies = ImmutableSet.copyOf(requireNonNull(functionDependencies, "functionDependencies is null")); - this.operatorDependencies = ImmutableSet.copyOf(requireNonNull(operatorDependencies, "operatorDependencies is null")); - this.castDependencies = ImmutableSet.copyOf(requireNonNull(castDependencies, "castDependencies is null")); + this.typeDependencies = Set.copyOf(requireNonNull(typeDependencies, "typeDependencies is null")); + this.functionDependencies = Set.copyOf(requireNonNull(functionDependencies, "functionDependencies is null")); + this.operatorDependencies = Set.copyOf(requireNonNull(operatorDependencies, "operatorDependencies is null")); + this.castDependencies = Set.copyOf(requireNonNull(castDependencies, "castDependencies is null")); } public Set getTypeDependencies() @@ -91,32 +89,32 @@ public FunctionDependencyDeclarationBuilder addType(TypeSignature typeSignature) return this; } - public FunctionDependencyDeclarationBuilder addFunction(QualifiedName name, List parameterTypes) + public FunctionDependencyDeclarationBuilder addFunction(QualifiedFunctionName name, List parameterTypes) { functionDependencies.add(new FunctionDependency(name, parameterTypes.stream() .map(Type::getTypeSignature) - .collect(toImmutableList()), false)); + .collect(toUnmodifiableList()), false)); return this; } - public FunctionDependencyDeclarationBuilder addFunctionSignature(QualifiedName name, List parameterTypes) + public FunctionDependencyDeclarationBuilder addFunctionSignature(QualifiedFunctionName name, List parameterTypes) { functionDependencies.add(new FunctionDependency(name, parameterTypes, false)); return this; } - public FunctionDependencyDeclarationBuilder addOptionalFunction(QualifiedName name, List parameterTypes) + public FunctionDependencyDeclarationBuilder addOptionalFunction(QualifiedFunctionName name, List parameterTypes) { functionDependencies.add(new FunctionDependency( name, parameterTypes.stream() .map(Type::getTypeSignature) - .collect(toImmutableList()), + .collect(toUnmodifiableList()), true)); return this; } - public FunctionDependencyDeclarationBuilder addOptionalFunctionSignature(QualifiedName name, List parameterTypes) + public FunctionDependencyDeclarationBuilder addOptionalFunctionSignature(QualifiedFunctionName name, List parameterTypes) { functionDependencies.add(new FunctionDependency(name, parameterTypes, true)); return this; @@ -126,7 +124,7 @@ public FunctionDependencyDeclarationBuilder addOperator(OperatorType operatorTyp { operatorDependencies.add(new OperatorDependency(operatorType, parameterTypes.stream() .map(Type::getTypeSignature) - .collect(toImmutableList()), false)); + .collect(toUnmodifiableList()), false)); return this; } @@ -142,7 +140,7 @@ public FunctionDependencyDeclarationBuilder addOptionalOperator(OperatorType ope operatorType, parameterTypes.stream() .map(Type::getTypeSignature) - .collect(toImmutableList()), + .collect(toUnmodifiableList()), true)); return this; } @@ -189,18 +187,18 @@ public FunctionDependencyDeclaration build() public static final class FunctionDependency { - private final QualifiedName name; + private final QualifiedFunctionName name; private final List argumentTypes; private final boolean optional; - private FunctionDependency(QualifiedName name, List argumentTypes, boolean optional) + private FunctionDependency(QualifiedFunctionName name, List argumentTypes, boolean optional) { this.name = requireNonNull(name, "name is null"); - this.argumentTypes = ImmutableList.copyOf(requireNonNull(argumentTypes, "argumentTypes is null")); + this.argumentTypes = List.copyOf(requireNonNull(argumentTypes, "argumentTypes is null")); this.optional = optional; } - public QualifiedName getName() + public QualifiedFunctionName getName() { return name; } @@ -253,7 +251,7 @@ public static final class OperatorDependency private OperatorDependency(OperatorType operatorType, List argumentTypes, boolean optional) { this.operatorType = requireNonNull(operatorType, "operatorType is null"); - this.argumentTypes = ImmutableList.copyOf(requireNonNull(argumentTypes, "argumentTypes is null")); + this.argumentTypes = List.copyOf(requireNonNull(argumentTypes, "argumentTypes is null")); this.optional = optional; } diff --git a/core/trino-main/src/main/java/io/trino/metadata/FunctionId.java b/core/trino-spi/src/main/java/io/trino/spi/function/FunctionId.java similarity index 76% rename from core/trino-main/src/main/java/io/trino/metadata/FunctionId.java rename to core/trino-spi/src/main/java/io/trino/spi/function/FunctionId.java index b8ffbf46eb06..d78bee164bc2 100644 --- a/core/trino-main/src/main/java/io/trino/metadata/FunctionId.java +++ b/core/trino-spi/src/main/java/io/trino/spi/function/FunctionId.java @@ -11,16 +11,17 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package io.trino.metadata; +package io.trino.spi.function; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonValue; +import io.trino.spi.Experimental; import java.util.Locale; -import static com.google.common.base.Preconditions.checkArgument; import static java.util.Objects.requireNonNull; +@Experimental(eta = "2022-10-31") public class FunctionId { private final String id; @@ -29,9 +30,15 @@ public class FunctionId public FunctionId(String id) { requireNonNull(id, "id is null"); - checkArgument(!id.isEmpty(), "id must not be empty"); - checkArgument(id.toLowerCase(Locale.US).equals(id), "id must be lowercase"); - checkArgument(!id.contains("@"), "id must not contain '@'"); + if (id.isEmpty()) { + throw new IllegalArgumentException("id must not be empty"); + } + if (!id.toLowerCase(Locale.US).equals(id)) { + throw new IllegalArgumentException("id must be lowercase"); + } + if (id.contains("@")) { + throw new IllegalArgumentException("id must not contain '@'"); + } this.id = id; } diff --git a/core/trino-main/src/main/java/io/trino/metadata/FunctionMetadata.java b/core/trino-spi/src/main/java/io/trino/spi/function/FunctionMetadata.java similarity index 89% rename from core/trino-main/src/main/java/io/trino/metadata/FunctionMetadata.java rename to core/trino-spi/src/main/java/io/trino/spi/function/FunctionMetadata.java index a30735ba4ba2..96d39b009ec1 100644 --- a/core/trino-main/src/main/java/io/trino/metadata/FunctionMetadata.java +++ b/core/trino-spi/src/main/java/io/trino/spi/function/FunctionMetadata.java @@ -11,21 +11,20 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package io.trino.metadata; +package io.trino.spi.function; -import com.google.common.collect.ImmutableList; -import com.google.common.primitives.Booleans; -import io.trino.spi.function.FunctionKind; +import io.trino.spi.Experimental; +import java.util.ArrayList; import java.util.Collections; import java.util.List; -import static com.google.common.base.Preconditions.checkArgument; import static io.trino.spi.function.FunctionKind.AGGREGATE; import static io.trino.spi.function.FunctionKind.SCALAR; import static io.trino.spi.function.FunctionKind.WINDOW; import static java.util.Objects.requireNonNull; +@Experimental(eta = "2022-10-31") public class FunctionMetadata { private final FunctionId functionId; @@ -53,7 +52,9 @@ private FunctionMetadata( this.signature = requireNonNull(signature, "signature is null"); this.canonicalName = requireNonNull(canonicalName, "canonicalName is null"); this.functionNullability = requireNonNull(functionNullability, "functionNullability is null"); - checkArgument(functionNullability.getArgumentNullable().size() == signature.getArgumentTypes().size(), "signature and functionNullability must have same argument count"); + if (functionNullability.getArgumentNullable().size() != signature.getArgumentTypes().size()) { + throw new IllegalArgumentException("signature and functionNullability must have same argument count"); + } this.hidden = hidden; this.deterministic = deterministic; @@ -168,7 +169,7 @@ private Builder(FunctionKind kind) public Builder signature(Signature signature) { this.signature = signature; - if (Signature.isOperatorName(signature.getName())) { + if (signature.isOperator()) { hidden = true; description = ""; } @@ -190,12 +191,16 @@ public Builder nullable() public Builder argumentNullability(boolean... argumentNullability) { requireNonNull(argumentNullability, "argumentNullability is null"); - return argumentNullability(ImmutableList.copyOf(Booleans.asList(argumentNullability))); + List list = new ArrayList<>(argumentNullability.length); + for (boolean nullability : argumentNullability) { + list.add(nullability); + } + return argumentNullability(list); } public Builder argumentNullability(List argumentNullability) { - this.argumentNullability = argumentNullability; + this.argumentNullability = List.copyOf(requireNonNull(argumentNullability, "argumentNullability is null")); return this; } @@ -223,7 +228,9 @@ public Builder noDescription() public Builder description(String description) { requireNonNull(description, "description is null"); - checkArgument(!description.isEmpty(), "description is empty"); + if (description.isBlank()) { + throw new IllegalArgumentException("description is blank"); + } this.description = description; return this; } diff --git a/core/trino-main/src/main/java/io/trino/metadata/FunctionNullability.java b/core/trino-spi/src/main/java/io/trino/spi/function/FunctionNullability.java similarity index 91% rename from core/trino-main/src/main/java/io/trino/metadata/FunctionNullability.java rename to core/trino-spi/src/main/java/io/trino/spi/function/FunctionNullability.java index 88e4e2691dc2..bc7757938fa9 100644 --- a/core/trino-main/src/main/java/io/trino/metadata/FunctionNullability.java +++ b/core/trino-spi/src/main/java/io/trino/spi/function/FunctionNullability.java @@ -11,11 +11,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package io.trino.metadata; +package io.trino.spi.function; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; -import com.google.common.collect.ImmutableList; +import io.trino.spi.Experimental; import java.util.List; import java.util.Objects; @@ -23,6 +23,7 @@ import static java.util.Objects.requireNonNull; import static java.util.stream.Collectors.joining; +@Experimental(eta = "2022-10-31") public class FunctionNullability { private final boolean returnNullable; @@ -34,7 +35,7 @@ public FunctionNullability( @JsonProperty("argumentNullable") List argumentNullable) { this.returnNullable = returnNullable; - this.argumentNullable = ImmutableList.copyOf(requireNonNull(argumentNullable, "argumentNullable is null")); + this.argumentNullable = List.copyOf(requireNonNull(argumentNullable, "argumentNullable is null")); } @JsonProperty diff --git a/core/trino-spi/src/main/java/io/trino/spi/function/FunctionProvider.java b/core/trino-spi/src/main/java/io/trino/spi/function/FunctionProvider.java new file mode 100644 index 000000000000..25841879eff4 --- /dev/null +++ b/core/trino-spi/src/main/java/io/trino/spi/function/FunctionProvider.java @@ -0,0 +1,30 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.spi.function; + +import io.trino.spi.Experimental; + +@Experimental(eta = "2022-10-31") +public interface FunctionProvider +{ + ScalarFunctionImplementation getScalarFunctionImplementation( + FunctionId functionId, + BoundSignature boundSignature, + FunctionDependencies functionDependencies, + InvocationConvention invocationConvention); + + AggregationImplementation getAggregationImplementation(FunctionId functionId, BoundSignature boundSignature, FunctionDependencies functionDependencies); + + WindowFunctionSupplier getWindowFunctionSupplier(FunctionId functionId, BoundSignature boundSignature, FunctionDependencies functionDependencies); +} diff --git a/core/trino-main/src/main/java/io/trino/metadata/LongVariableConstraint.java b/core/trino-spi/src/main/java/io/trino/spi/function/LongVariableConstraint.java similarity index 87% rename from core/trino-main/src/main/java/io/trino/metadata/LongVariableConstraint.java rename to core/trino-spi/src/main/java/io/trino/spi/function/LongVariableConstraint.java index 9e849c4991a1..11691f052512 100644 --- a/core/trino-main/src/main/java/io/trino/metadata/LongVariableConstraint.java +++ b/core/trino-spi/src/main/java/io/trino/spi/function/LongVariableConstraint.java @@ -11,13 +11,17 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package io.trino.metadata; +package io.trino.spi.function; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; +import io.trino.spi.Experimental; import java.util.Objects; +import static java.util.Objects.requireNonNull; + +@Experimental(eta = "2022-10-31") public class LongVariableConstraint { private final String name; @@ -25,8 +29,8 @@ public class LongVariableConstraint LongVariableConstraint(String name, String expression) { - this.name = name; - this.expression = expression; + this.name = requireNonNull(name, "name is null"); + this.expression = requireNonNull(expression, "expression is null"); } @JsonProperty diff --git a/core/trino-spi/src/main/java/io/trino/spi/function/QualifiedFunctionName.java b/core/trino-spi/src/main/java/io/trino/spi/function/QualifiedFunctionName.java new file mode 100644 index 000000000000..70a2f0da3487 --- /dev/null +++ b/core/trino-spi/src/main/java/io/trino/spi/function/QualifiedFunctionName.java @@ -0,0 +1,107 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.spi.function; + +import io.trino.spi.Experimental; + +import java.util.Objects; +import java.util.Optional; + +import static java.util.Objects.requireNonNull; + +@Experimental(eta = "2022-10-31") +public class QualifiedFunctionName +{ + private final Optional catalogName; + private final Optional schemaName; + private final String functionName; + + public static QualifiedFunctionName of(String functionName) + { + return new QualifiedFunctionName(Optional.empty(), Optional.empty(), functionName); + } + + public static QualifiedFunctionName of(String schemaName, String functionName) + { + return new QualifiedFunctionName(Optional.empty(), Optional.of(schemaName), functionName); + } + + public static QualifiedFunctionName of(String catalogName, String schemaName, String functionName) + { + return new QualifiedFunctionName(Optional.of(catalogName), Optional.of(schemaName), functionName); + } + + private QualifiedFunctionName(Optional catalogName, Optional schemaName, String functionName) + { + this.catalogName = requireNonNull(catalogName, "catalogName is null"); + if (catalogName.map(String::isEmpty).orElse(false)) { + throw new IllegalArgumentException("catalogName is empty"); + } + this.schemaName = requireNonNull(schemaName, "schemaName is null"); + if (schemaName.map(String::isEmpty).orElse(false)) { + throw new IllegalArgumentException("schemaName is empty"); + } + if (catalogName.isPresent() && schemaName.isEmpty()) { + throw new IllegalArgumentException("Schema name must be provided when catalog name is provided"); + } + this.functionName = requireNonNull(functionName, "functionName is null"); + if (functionName.isEmpty()) { + throw new IllegalArgumentException("functionName is empty"); + } + } + + public Optional getCatalogName() + { + return catalogName; + } + + public Optional getSchemaName() + { + return schemaName; + } + + public String getFunctionName() + { + return functionName; + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + QualifiedFunctionName that = (QualifiedFunctionName) o; + return catalogName.equals(that.catalogName) && + schemaName.equals(that.schemaName) && + functionName.equals(that.functionName); + } + + @Override + public int hashCode() + { + return Objects.hash(catalogName, schemaName, functionName); + } + + @Override + public String toString() + { + return catalogName.map(name -> name + ".").orElse("") + + schemaName.map(name -> name + ".").orElse("") + + functionName; + } +} diff --git a/core/trino-spi/src/main/java/io/trino/spi/function/ScalarFunctionImplementation.java b/core/trino-spi/src/main/java/io/trino/spi/function/ScalarFunctionImplementation.java new file mode 100644 index 000000000000..f08b52308b1d --- /dev/null +++ b/core/trino-spi/src/main/java/io/trino/spi/function/ScalarFunctionImplementation.java @@ -0,0 +1,89 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.spi.function; + +import io.trino.spi.Experimental; + +import java.lang.invoke.MethodHandle; +import java.util.List; +import java.util.Optional; + +import static java.util.Objects.requireNonNull; + +@Experimental(eta = "2022-10-31") +public class ScalarFunctionImplementation +{ + private final MethodHandle methodHandle; + private final Optional instanceFactory; + private final List> lambdaInterfaces; + + private ScalarFunctionImplementation(MethodHandle methodHandle, Optional instanceFactory, List> lambdaInterfaces) + { + this.methodHandle = requireNonNull(methodHandle, "methodHandle is null"); + this.instanceFactory = requireNonNull(instanceFactory, "instanceFactory is null"); + this.lambdaInterfaces = List.copyOf(requireNonNull(lambdaInterfaces, "lambdaInterfaces is null")); + } + + public MethodHandle getMethodHandle() + { + return methodHandle; + } + + public Optional getInstanceFactory() + { + return instanceFactory; + } + + public List> getLambdaInterfaces() + { + return lambdaInterfaces; + } + + public static Builder builder() + { + return new Builder(); + } + + public static class Builder + { + private MethodHandle methodHandle; + private MethodHandle instanceFactory; + private List> lambdaInterfaces = List.of(); + + private Builder() {} + + public Builder methodHandle(MethodHandle methodHandle) + { + this.methodHandle = requireNonNull(methodHandle, "methodHandle is null"); + return this; + } + + public Builder instanceFactory(MethodHandle instanceFactory) + { + this.instanceFactory = requireNonNull(instanceFactory, "instanceFactory is null"); + return this; + } + + public Builder lambdaInterfaces(List> lambdaInterfaces) + { + this.lambdaInterfaces = List.copyOf(requireNonNull(lambdaInterfaces, "lambdaInterfaces is null")); + return this; + } + + public ScalarFunctionImplementation build() + { + return new ScalarFunctionImplementation(methodHandle, Optional.ofNullable(instanceFactory), lambdaInterfaces); + } + } +} diff --git a/core/trino-main/src/main/java/io/trino/metadata/SchemaFunctionName.java b/core/trino-spi/src/main/java/io/trino/spi/function/SchemaFunctionName.java similarity index 84% rename from core/trino-main/src/main/java/io/trino/metadata/SchemaFunctionName.java rename to core/trino-spi/src/main/java/io/trino/spi/function/SchemaFunctionName.java index 409d3acf124f..eeeb48517c11 100644 --- a/core/trino-main/src/main/java/io/trino/metadata/SchemaFunctionName.java +++ b/core/trino-spi/src/main/java/io/trino/spi/function/SchemaFunctionName.java @@ -11,13 +11,15 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package io.trino.metadata; +package io.trino.spi.function; + +import io.trino.spi.Experimental; import java.util.Objects; -import static com.google.common.base.Preconditions.checkArgument; import static java.util.Objects.requireNonNull; +@Experimental(eta = "2022-10-31") public final class SchemaFunctionName { private final String schemaName; @@ -26,9 +28,13 @@ public final class SchemaFunctionName public SchemaFunctionName(String schemaName, String functionName) { this.schemaName = requireNonNull(schemaName, "schemaName is null"); - checkArgument(!schemaName.isEmpty(), "schemaName is empty"); + if (schemaName.isEmpty()) { + throw new IllegalArgumentException("schemaName is empty"); + } this.functionName = requireNonNull(functionName, "functionName is null"); - checkArgument(!functionName.isEmpty(), "functionName is empty"); + if (functionName.isEmpty()) { + throw new IllegalArgumentException("functionName is empty"); + } } public String getSchemaName() diff --git a/core/trino-main/src/main/java/io/trino/metadata/Signature.java b/core/trino-spi/src/main/java/io/trino/spi/function/Signature.java similarity index 86% rename from core/trino-main/src/main/java/io/trino/metadata/Signature.java rename to core/trino-spi/src/main/java/io/trino/spi/function/Signature.java index 06f5f342f23c..b4b772b0a2df 100644 --- a/core/trino-main/src/main/java/io/trino/metadata/Signature.java +++ b/core/trino-spi/src/main/java/io/trino/spi/function/Signature.java @@ -11,14 +11,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package io.trino.metadata; +package io.trino.spi.function; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; -import com.google.common.annotations.VisibleForTesting; -import com.google.common.base.Joiner; -import com.google.common.collect.ImmutableList; -import io.trino.spi.function.OperatorType; +import io.trino.spi.Experimental; import io.trino.spi.type.Type; import io.trino.spi.type.TypeSignature; @@ -28,12 +25,14 @@ import java.util.Objects; import java.util.stream.Collectors; -import static com.google.common.base.Preconditions.checkArgument; import static java.util.Objects.requireNonNull; +import static java.util.stream.Collectors.joining; import static java.util.stream.Stream.concat; +@Experimental(eta = "2022-10-31") public class Signature { + // Copied from OperatorNameUtil private static final String OPERATOR_PREFIX = "$operator$"; private final String name; @@ -56,28 +55,16 @@ private Signature( requireNonNull(longVariableConstraints, "longVariableConstraints is null"); this.name = name; - this.typeVariableConstraints = ImmutableList.copyOf(typeVariableConstraints); - this.longVariableConstraints = ImmutableList.copyOf(longVariableConstraints); + this.typeVariableConstraints = List.copyOf(typeVariableConstraints); + this.longVariableConstraints = List.copyOf(longVariableConstraints); this.returnType = requireNonNull(returnType, "returnType is null"); - this.argumentTypes = ImmutableList.copyOf(requireNonNull(argumentTypes, "argumentTypes is null")); + this.argumentTypes = List.copyOf(requireNonNull(argumentTypes, "argumentTypes is null")); this.variableArity = variableArity; } - public static boolean isOperatorName(String mangledName) + boolean isOperator() { - return mangledName.startsWith(OPERATOR_PREFIX); - } - - public static String mangleOperatorName(OperatorType operatorType) - { - return OPERATOR_PREFIX + operatorType.name(); - } - - @VisibleForTesting - public static OperatorType unmangleOperator(String mangledName) - { - checkArgument(mangledName.startsWith(OPERATOR_PREFIX), "not a mangled operator name: %s", mangledName); - return OperatorType.valueOf(mangledName.substring(OPERATOR_PREFIX.length()).toUpperCase(Locale.ENGLISH)); + return name.startsWith(OPERATOR_PREFIX); } @JsonProperty @@ -148,7 +135,10 @@ public String toString() longVariableConstraints.stream().map(LongVariableConstraint::toString)) .collect(Collectors.toList()); - return name + (allConstraints.isEmpty() ? "" : "<" + Joiner.on(",").join(allConstraints) + ">") + "(" + Joiner.on(",").join(argumentTypes) + "):" + returnType; + return name + + (allConstraints.isEmpty() ? "" : allConstraints.stream().collect(joining(",", "<", ">"))) + + argumentTypes.stream().map(Objects::toString).collect(joining(",", "(", ")")) + + ":" + returnType; } public Signature withName(String name) @@ -186,7 +176,7 @@ public Builder name(String name) public Builder operatorType(OperatorType operatorType) { - this.name = mangleOperatorName(requireNonNull(operatorType, "operatorType is null")); + this.name = OPERATOR_PREFIX + requireNonNull(operatorType, "operatorType is null").name(); return this; } diff --git a/core/trino-main/src/main/java/io/trino/metadata/TypeVariableConstraint.java b/core/trino-spi/src/main/java/io/trino/spi/function/TypeVariableConstraint.java similarity index 92% rename from core/trino-main/src/main/java/io/trino/metadata/TypeVariableConstraint.java rename to core/trino-spi/src/main/java/io/trino/spi/function/TypeVariableConstraint.java index e6a2fc0ab781..bfad4c87686a 100644 --- a/core/trino-main/src/main/java/io/trino/metadata/TypeVariableConstraint.java +++ b/core/trino-spi/src/main/java/io/trino/spi/function/TypeVariableConstraint.java @@ -11,12 +11,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package io.trino.metadata; +package io.trino.spi.function; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; -import com.google.common.base.Joiner; -import com.google.common.collect.ImmutableSet; +import io.trino.spi.Experimental; import io.trino.spi.type.Type; import io.trino.spi.type.TypeSignature; @@ -26,7 +25,9 @@ import java.util.Set; import static java.util.Objects.requireNonNull; +import static java.util.stream.Collectors.joining; +@Experimental(eta = "2022-10-31") public class TypeVariableConstraint { private final String name; @@ -51,8 +52,8 @@ private TypeVariableConstraint( if (variadicBound.map(bound -> !bound.equalsIgnoreCase("row")).orElse(false)) { throw new IllegalArgumentException("variadicBound must be row but is " + variadicBound.get()); } - this.castableTo = ImmutableSet.copyOf(requireNonNull(castableTo, "castableTo is null")); - this.castableFrom = ImmutableSet.copyOf(requireNonNull(castableFrom, "castableFrom is null")); + this.castableTo = Set.copyOf(requireNonNull(castableTo, "castableTo is null")); + this.castableFrom = Set.copyOf(requireNonNull(castableFrom, "castableFrom is null")); } @JsonProperty @@ -105,10 +106,10 @@ public String toString() value += ":" + variadicBound + "<*>"; } if (!castableTo.isEmpty()) { - value += ":castableTo(" + Joiner.on(", ").join(castableTo) + ")"; + value += castableTo.stream().map(Object::toString).collect(joining(", ", ":castableTo(", ")")); } if (!castableFrom.isEmpty()) { - value += ":castableFrom(" + Joiner.on(", ").join(castableFrom) + ")"; + value += castableFrom.stream().map(Object::toString).collect(joining(", ", ":castableFrom(", ")")); } return value; } diff --git a/core/trino-main/src/main/java/io/trino/operator/window/WindowFunctionSupplier.java b/core/trino-spi/src/main/java/io/trino/spi/function/WindowFunctionSupplier.java similarity index 89% rename from core/trino-main/src/main/java/io/trino/operator/window/WindowFunctionSupplier.java rename to core/trino-spi/src/main/java/io/trino/spi/function/WindowFunctionSupplier.java index 2a0c3a248c85..6f327588fc94 100644 --- a/core/trino-main/src/main/java/io/trino/operator/window/WindowFunctionSupplier.java +++ b/core/trino-spi/src/main/java/io/trino/spi/function/WindowFunctionSupplier.java @@ -11,13 +11,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package io.trino.operator.window; +package io.trino.spi.function; -import io.trino.spi.function.WindowFunction; +import io.trino.spi.Experimental; import java.util.List; import java.util.function.Supplier; +@Experimental(eta = "2022-10-31") public interface WindowFunctionSupplier { WindowFunction createWindowFunction(boolean ignoreNulls, List> lambdaProviders); diff --git a/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/classloader/ClassLoaderSafeConnectorMetadata.java b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/classloader/ClassLoaderSafeConnectorMetadata.java index 6120b74ec2b8..5f0ec0cf1d3a 100644 --- a/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/classloader/ClassLoaderSafeConnectorMetadata.java +++ b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/classloader/ClassLoaderSafeConnectorMetadata.java @@ -62,6 +62,12 @@ import io.trino.spi.connector.TableScanRedirectApplicationResult; import io.trino.spi.connector.TopNApplicationResult; import io.trino.spi.expression.ConnectorExpression; +import io.trino.spi.function.AggregationFunctionMetadata; +import io.trino.spi.function.BoundSignature; +import io.trino.spi.function.FunctionDependencyDeclaration; +import io.trino.spi.function.FunctionId; +import io.trino.spi.function.FunctionMetadata; +import io.trino.spi.function.SchemaFunctionName; import io.trino.spi.predicate.TupleDomain; import io.trino.spi.ptf.ConnectorTableFunctionHandle; import io.trino.spi.security.GrantInfo; @@ -679,6 +685,48 @@ public Optional resolveIndex(ConnectorSession session, C } } + @Override + public Collection listFunctions(ConnectorSession session, String schemaName) + { + try (ThreadContextClassLoader ignored = new ThreadContextClassLoader(classLoader)) { + return delegate.listFunctions(session, schemaName); + } + } + + @Override + public Collection getFunctions(ConnectorSession session, SchemaFunctionName name) + { + try (ThreadContextClassLoader ignored = new ThreadContextClassLoader(classLoader)) { + return delegate.getFunctions(session, name); + } + } + + @Override + public FunctionMetadata getFunctionMetadata(ConnectorSession session, FunctionId functionId) + { + try (ThreadContextClassLoader ignored = new ThreadContextClassLoader(classLoader)) { + return delegate.getFunctionMetadata(session, functionId); + } + } + + @Override + public AggregationFunctionMetadata getAggregationFunctionMetadata(ConnectorSession session, + FunctionId functionId) + { + try (ThreadContextClassLoader ignored = new ThreadContextClassLoader(classLoader)) { + return delegate.getAggregationFunctionMetadata(session, functionId); + } + } + + @Override + public FunctionDependencyDeclaration getFunctionDependencies(ConnectorSession session, + FunctionId functionId, BoundSignature boundSignature) + { + try (ThreadContextClassLoader ignored = new ThreadContextClassLoader(classLoader)) { + return delegate.getFunctionDependencies(session, functionId, boundSignature); + } + } + @Override public boolean roleExists(ConnectorSession session, String role) { diff --git a/testing/trino-benchmark/src/main/java/io/trino/benchmark/AbstractOperatorBenchmark.java b/testing/trino-benchmark/src/main/java/io/trino/benchmark/AbstractOperatorBenchmark.java index 1dfb8bcde37f..6cd816c7a2e9 100644 --- a/testing/trino-benchmark/src/main/java/io/trino/benchmark/AbstractOperatorBenchmark.java +++ b/testing/trino-benchmark/src/main/java/io/trino/benchmark/AbstractOperatorBenchmark.java @@ -39,7 +39,6 @@ import io.trino.operator.PageSourceOperator; import io.trino.operator.TaskContext; import io.trino.operator.TaskStats; -import io.trino.operator.aggregation.AggregationMetadata; import io.trino.operator.project.InputPageProjection; import io.trino.operator.project.PageProcessor; import io.trino.operator.project.PageProjection; @@ -48,6 +47,7 @@ import io.trino.spi.connector.ColumnHandle; import io.trino.spi.connector.ConnectorPageSource; import io.trino.spi.connector.DynamicFilter; +import io.trino.spi.function.AggregationImplementation; import io.trino.spi.type.Type; import io.trino.spiller.SpillSpaceTracker; import io.trino.split.SplitSource; @@ -156,8 +156,8 @@ protected final List getColumnTypes(String tableName, String... columnName protected final BenchmarkAggregationFunction createAggregationFunction(String name, Type... argumentTypes) { ResolvedFunction resolvedFunction = localQueryRunner.getMetadata().resolveFunction(session, QualifiedName.of(name), fromTypes(argumentTypes)); - AggregationMetadata aggregationMetadata = localQueryRunner.getFunctionManager().getAggregateFunctionImplementation(resolvedFunction); - return new BenchmarkAggregationFunction(resolvedFunction, aggregationMetadata); + AggregationImplementation aggregationImplementation = localQueryRunner.getFunctionManager().getAggregationImplementation(resolvedFunction); + return new BenchmarkAggregationFunction(resolvedFunction, aggregationImplementation); } protected final OperatorFactory createTableScanOperator(int operatorId, PlanNodeId planNodeId, String tableName, String... columnNames) diff --git a/testing/trino-benchmark/src/main/java/io/trino/benchmark/BenchmarkAggregationFunction.java b/testing/trino-benchmark/src/main/java/io/trino/benchmark/BenchmarkAggregationFunction.java index 0e00c363ab30..62ca20d61bae 100644 --- a/testing/trino-benchmark/src/main/java/io/trino/benchmark/BenchmarkAggregationFunction.java +++ b/testing/trino-benchmark/src/main/java/io/trino/benchmark/BenchmarkAggregationFunction.java @@ -14,11 +14,11 @@ package io.trino.benchmark; import com.google.common.collect.ImmutableList; -import io.trino.metadata.BoundSignature; import io.trino.metadata.ResolvedFunction; import io.trino.operator.aggregation.AccumulatorFactory; -import io.trino.operator.aggregation.AggregationMetadata; import io.trino.operator.aggregation.AggregatorFactory; +import io.trino.spi.function.AggregationImplementation; +import io.trino.spi.function.BoundSignature; import io.trino.spi.type.Type; import io.trino.sql.planner.plan.AggregationNode.Step; @@ -34,12 +34,12 @@ public class BenchmarkAggregationFunction private final AccumulatorFactory accumulatorFactory; private final Type finalType; - public BenchmarkAggregationFunction(ResolvedFunction resolvedFunction, AggregationMetadata aggregationMetadata) + public BenchmarkAggregationFunction(ResolvedFunction resolvedFunction, AggregationImplementation aggregationImplementation) { BoundSignature signature = resolvedFunction.getSignature(); - intermediateType = getOnlyElement(aggregationMetadata.getAccumulatorStateDescriptors()).getSerializer().getSerializedType(); + intermediateType = getOnlyElement(aggregationImplementation.getAccumulatorStateDescriptors()).getSerializer().getSerializedType(); finalType = signature.getReturnType(); - accumulatorFactory = generateAccumulatorFactory(signature, aggregationMetadata, resolvedFunction.getFunctionNullability()); + accumulatorFactory = generateAccumulatorFactory(signature, aggregationImplementation, resolvedFunction.getFunctionNullability()); } public AggregatorFactory bind(List inputChannels) diff --git a/testing/trino-testing/src/main/java/io/trino/testing/StatefulSleepingSum.java b/testing/trino-testing/src/main/java/io/trino/testing/StatefulSleepingSum.java index 4eff66d1cff7..aad4a27639ee 100644 --- a/testing/trino-testing/src/main/java/io/trino/testing/StatefulSleepingSum.java +++ b/testing/trino-testing/src/main/java/io/trino/testing/StatefulSleepingSum.java @@ -13,12 +13,12 @@ */ package io.trino.testing; -import io.trino.metadata.BoundSignature; -import io.trino.metadata.FunctionMetadata; -import io.trino.metadata.Signature; import io.trino.metadata.SqlScalarFunction; -import io.trino.operator.scalar.ChoicesScalarFunctionImplementation; -import io.trino.operator.scalar.ScalarFunctionImplementation; +import io.trino.operator.scalar.ChoicesSpecializedSqlScalarFunction; +import io.trino.operator.scalar.SpecializedSqlScalarFunction; +import io.trino.spi.function.BoundSignature; +import io.trino.spi.function.FunctionMetadata; +import io.trino.spi.function.Signature; import java.util.Optional; import java.util.concurrent.ThreadLocalRandom; @@ -56,10 +56,10 @@ private StatefulSleepingSum() } @Override - protected ScalarFunctionImplementation specialize(BoundSignature boundSignature) + protected SpecializedSqlScalarFunction specialize(BoundSignature boundSignature) { int args = 4; - return new ChoicesScalarFunctionImplementation( + return new ChoicesSpecializedSqlScalarFunction( boundSignature, FAIL_ON_NULL, nCopies(args, NEVER_NULL),