Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Add function resolution to SPI #12588

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Sep 4, 2022
Prev Previous commit
Next Next commit
Rename AggregationMetadata to AggregationImplementation
  • Loading branch information
dain committed Sep 3, 2022
commit 0b1f5f13c8ec9217ada99930587995a305827b0d
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
*/
package io.trino.metadata;

import io.trino.operator.aggregation.AggregationMetadata;
import io.trino.operator.aggregation.AggregationImplementation;
import io.trino.operator.window.WindowFunctionSupplier;
import io.trino.spi.function.InvocationConvention;

Expand All @@ -33,7 +33,7 @@ ScalarFunctionImplementation getScalarFunctionImplementation(
FunctionDependencies functionDependencies,
InvocationConvention invocationConvention);

AggregationMetadata getAggregateFunctionImplementation(FunctionId functionId, BoundSignature boundSignature, FunctionDependencies functionDependencies);
AggregationImplementation getAggregationImplementation(FunctionId functionId, BoundSignature boundSignature, FunctionDependencies functionDependencies);

WindowFunctionSupplier getWindowFunctionImplementation(FunctionId functionId, BoundSignature boundSignature, FunctionDependencies functionDependencies);
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import com.google.common.util.concurrent.UncheckedExecutionException;
import io.trino.FeaturesConfig;
import io.trino.collect.cache.NonEvictableCache;
import io.trino.operator.aggregation.AggregationMetadata;
import io.trino.operator.aggregation.AggregationImplementation;
import io.trino.operator.window.WindowFunctionSupplier;
import io.trino.spi.TrinoException;
import io.trino.spi.block.Block;
Expand Down Expand Up @@ -52,7 +52,7 @@
public class FunctionManager
{
private final NonEvictableCache<FunctionKey, ScalarFunctionImplementation> specializedScalarCache;
private final NonEvictableCache<FunctionKey, AggregationMetadata> specializedAggregationCache;
private final NonEvictableCache<FunctionKey, AggregationImplementation> specializedAggregationCache;
private final NonEvictableCache<FunctionKey, WindowFunctionSupplier> specializedWindowCache;

private final GlobalFunctionCatalog globalFunctionCatalog;
Expand Down Expand Up @@ -98,21 +98,21 @@ private ScalarFunctionImplementation getScalarFunctionImplementationInternal(Res
return scalarFunctionImplementation;
}

public AggregationMetadata getAggregateFunctionImplementation(ResolvedFunction resolvedFunction)
public AggregationImplementation getAggregationImplementation(ResolvedFunction resolvedFunction)
{
try {
return uncheckedCacheGet(specializedAggregationCache, new FunctionKey(resolvedFunction), () -> getAggregateFunctionImplementationInternal(resolvedFunction));
return uncheckedCacheGet(specializedAggregationCache, new FunctionKey(resolvedFunction), () -> getAggregationImplementationInternal(resolvedFunction));
}
catch (UncheckedExecutionException e) {
throwIfInstanceOf(e.getCause(), TrinoException.class);
throw new RuntimeException(e.getCause());
}
}

private AggregationMetadata getAggregateFunctionImplementationInternal(ResolvedFunction resolvedFunction)
private AggregationImplementation getAggregationImplementationInternal(ResolvedFunction resolvedFunction)
{
FunctionDependencies functionDependencies = getFunctionDependencies(resolvedFunction);
return globalFunctionCatalog.getAggregateFunctionImplementation(
return globalFunctionCatalog.getAggregationImplementation(
resolvedFunction.getFunctionId(),
resolvedFunction.getSignature(),
functionDependencies);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import com.google.common.collect.ImmutableListMultimap;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Multimap;
import io.trino.operator.aggregation.AggregationMetadata;
import io.trino.operator.aggregation.AggregationImplementation;
import io.trino.operator.window.WindowFunctionSupplier;
import io.trino.spi.function.InvocationConvention;
import io.trino.spi.function.OperatorType;
Expand Down Expand Up @@ -140,9 +140,9 @@ public WindowFunctionSupplier getWindowFunctionImplementation(FunctionId functio
return functions.getFunctionBundle(functionId).getWindowFunctionImplementation(functionId, boundSignature, functionDependencies);
}

public AggregationMetadata getAggregateFunctionImplementation(FunctionId functionId, BoundSignature boundSignature, FunctionDependencies functionDependencies)
public AggregationImplementation getAggregationImplementation(FunctionId functionId, BoundSignature boundSignature, FunctionDependencies functionDependencies)
{
return functions.getFunctionBundle(functionId).getAggregateFunctionImplementation(functionId, boundSignature, functionDependencies);
return functions.getFunctionBundle(functionId).getAggregationImplementation(functionId, boundSignature, functionDependencies);
}

public FunctionDependencyDeclaration getFunctionDependencies(FunctionId functionId, BoundSignature boundSignature)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import com.google.common.collect.ImmutableList;
import com.google.common.util.concurrent.UncheckedExecutionException;
import io.trino.collect.cache.NonEvictableCache;
import io.trino.operator.aggregation.AggregationMetadata;
import io.trino.operator.aggregation.AggregationImplementation;
import io.trino.operator.scalar.SpecializedSqlScalarFunction;
import io.trino.operator.scalar.annotations.ScalarFromAnnotationsParser;
import io.trino.operator.window.SqlWindowFunction;
Expand Down Expand Up @@ -52,7 +52,7 @@ public class InternalFunctionBundle
{
// scalar function specialization may involve expensive code generation
private final NonEvictableCache<FunctionKey, SpecializedSqlScalarFunction> specializedScalarCache;
private final NonEvictableCache<FunctionKey, AggregationMetadata> specializedAggregationCache;
private final NonEvictableCache<FunctionKey, AggregationImplementation> specializedAggregationCache;
private final NonEvictableCache<FunctionKey, WindowFunctionSupplier> specializedWindowCache;
private final Map<FunctionId, SqlFunction> functions;

Expand Down Expand Up @@ -136,7 +136,7 @@ private SpecializedSqlScalarFunction specializeScalarFunction(FunctionId functio
}

@Override
public AggregationMetadata getAggregateFunctionImplementation(FunctionId functionId, BoundSignature boundSignature, FunctionDependencies functionDependencies)
public AggregationImplementation getAggregationImplementation(FunctionId functionId, BoundSignature boundSignature, FunctionDependencies functionDependencies)
{
try {
return uncheckedCacheGet(specializedAggregationCache, new FunctionKey(functionId, boundSignature), () -> specializedAggregation(functionId, boundSignature, functionDependencies));
Expand All @@ -147,7 +147,7 @@ public AggregationMetadata getAggregateFunctionImplementation(FunctionId functio
}
}

private AggregationMetadata specializedAggregation(FunctionId functionId, BoundSignature boundSignature, FunctionDependencies functionDependencies)
private AggregationImplementation specializedAggregation(FunctionId functionId, BoundSignature boundSignature, FunctionDependencies functionDependencies)
{
SqlAggregationFunction aggregationFunction = (SqlAggregationFunction) functions.get(functionId);
return aggregationFunction.specialize(boundSignature, functionDependencies);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

import com.google.common.collect.ImmutableList;
import io.trino.operator.aggregation.AggregationFromAnnotationsParser;
import io.trino.operator.aggregation.AggregationMetadata;
import io.trino.operator.aggregation.AggregationImplementation;

import java.util.List;

Expand Down Expand Up @@ -54,12 +54,12 @@ public AggregationFunctionMetadata getAggregationMetadata()
return aggregationFunctionMetadata;
}

public AggregationMetadata specialize(BoundSignature boundSignature, FunctionDependencies functionDependencies)
public AggregationImplementation specialize(BoundSignature boundSignature, FunctionDependencies functionDependencies)
{
return specialize(boundSignature);
}

protected AggregationMetadata specialize(BoundSignature boundSignature)
protected AggregationImplementation specialize(BoundSignature boundSignature)
{
throw new UnsupportedOperationException();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
import io.trino.metadata.BoundSignature;
import io.trino.metadata.FunctionNullability;
import io.trino.operator.GroupByIdBlock;
import io.trino.operator.aggregation.AggregationMetadata.AccumulatorStateDescriptor;
import io.trino.operator.aggregation.AggregationImplementation.AccumulatorStateDescriptor;
import io.trino.operator.window.InternalWindowIndex;
import io.trino.spi.Page;
import io.trino.spi.block.Block;
Expand Down Expand Up @@ -86,41 +86,41 @@ private AccumulatorCompiler() {}

public static AccumulatorFactory generateAccumulatorFactory(
BoundSignature boundSignature,
AggregationMetadata metadata,
AggregationImplementation implementation,
FunctionNullability functionNullability)
{
// change types used in Aggregation methods to types used in the core Trino engine to simplify code generation
metadata = normalizeAggregationMethods(metadata);
implementation = normalizeAggregationMethods(implementation);

DynamicClassLoader classLoader = new DynamicClassLoader(AccumulatorCompiler.class.getClassLoader());

List<Boolean> argumentNullable = functionNullability.getArgumentNullable()
.subList(0, functionNullability.getArgumentNullable().size() - metadata.getLambdaInterfaces().size());
.subList(0, functionNullability.getArgumentNullable().size() - implementation.getLambdaInterfaces().size());

Constructor<? extends Accumulator> accumulatorConstructor = generateAccumulatorClass(
boundSignature,
Accumulator.class,
metadata,
implementation,
argumentNullable,
classLoader);

Constructor<? extends GroupedAccumulator> groupedAccumulatorConstructor = generateAccumulatorClass(
boundSignature,
GroupedAccumulator.class,
metadata,
implementation,
argumentNullable,
classLoader);

return new CompiledAccumulatorFactory(
accumulatorConstructor,
groupedAccumulatorConstructor,
metadata.getLambdaInterfaces());
implementation.getLambdaInterfaces());
}

private static <T> Constructor<? extends T> generateAccumulatorClass(
BoundSignature boundSignature,
Class<T> accumulatorInterface,
AggregationMetadata metadata,
AggregationImplementation implementation,
List<Boolean> argumentNullable,
DynamicClassLoader classLoader)
{
Expand All @@ -134,7 +134,7 @@ private static <T> Constructor<? extends T> generateAccumulatorClass(

CallSiteBinder callSiteBinder = new CallSiteBinder();

List<AccumulatorStateDescriptor<?>> stateDescriptors = metadata.getAccumulatorStateDescriptors();
List<AccumulatorStateDescriptor<?>> stateDescriptors = implementation.getAccumulatorStateDescriptors();
List<StateFieldAndDescriptor> stateFieldAndDescriptors = new ArrayList<>();
for (int i = 0; i < stateDescriptors.size(); i++) {
stateFieldAndDescriptors.add(new StateFieldAndDescriptor(
Expand All @@ -147,7 +147,7 @@ private static <T> Constructor<? extends T> generateAccumulatorClass(
.map(StateFieldAndDescriptor::getStateField)
.collect(toImmutableList());

int lambdaCount = metadata.getLambdaInterfaces().size();
int lambdaCount = implementation.getLambdaInterfaces().size();
List<FieldDefinition> lambdaProviderFields = new ArrayList<>(lambdaCount);
for (int i = 0; i < lambdaCount; i++) {
lambdaProviderFields.add(definition.declareField(a(PRIVATE, FINAL), "lambdaProvider_" + i, Supplier.class));
Expand All @@ -173,7 +173,7 @@ private static <T> Constructor<? extends T> generateAccumulatorClass(
stateFields,
argumentNullable,
lambdaProviderFields,
metadata.getInputFunction(),
implementation.getInputFunction(),
callSiteBinder,
grouped);
generateGetEstimatedSize(definition, stateFields);
Expand All @@ -182,7 +182,7 @@ private static <T> Constructor<? extends T> generateAccumulatorClass(
definition,
stateFieldAndDescriptors,
lambdaProviderFields,
metadata.getCombineFunction(),
implementation.getCombineFunction(),
callSiteBinder,
grouped);

Expand All @@ -194,10 +194,10 @@ private static <T> Constructor<? extends T> generateAccumulatorClass(
}

if (grouped) {
generateGroupedEvaluateFinal(definition, stateFields, metadata.getOutputFunction(), callSiteBinder);
generateGroupedEvaluateFinal(definition, stateFields, implementation.getOutputFunction(), callSiteBinder);
}
else {
generateEvaluateFinal(definition, stateFields, metadata.getOutputFunction(), callSiteBinder);
generateEvaluateFinal(definition, stateFields, implementation.getOutputFunction(), callSiteBinder);
}

if (grouped) {
Expand All @@ -215,13 +215,13 @@ private static <T> Constructor<? extends T> generateAccumulatorClass(

public static Constructor<? extends WindowAccumulator> generateWindowAccumulatorClass(
BoundSignature boundSignature,
AggregationMetadata metadata,
AggregationImplementation implementation,
FunctionNullability functionNullability)
{
DynamicClassLoader classLoader = new DynamicClassLoader(AccumulatorCompiler.class.getClassLoader());

List<Boolean> argumentNullable = functionNullability.getArgumentNullable()
.subList(0, functionNullability.getArgumentNullable().size() - metadata.getLambdaInterfaces().size());
.subList(0, functionNullability.getArgumentNullable().size() - implementation.getLambdaInterfaces().size());

ClassDefinition definition = new ClassDefinition(
a(PUBLIC, FINAL),
Expand All @@ -231,7 +231,7 @@ public static Constructor<? extends WindowAccumulator> generateWindowAccumulator

CallSiteBinder callSiteBinder = new CallSiteBinder();

List<AccumulatorStateDescriptor<?>> stateDescriptors = metadata.getAccumulatorStateDescriptors();
List<AccumulatorStateDescriptor<?>> stateDescriptors = implementation.getAccumulatorStateDescriptors();
List<StateFieldAndDescriptor> stateFieldAndDescriptors = new ArrayList<>();
for (int i = 0; i < stateDescriptors.size(); i++) {
stateFieldAndDescriptors.add(new StateFieldAndDescriptor(
Expand All @@ -244,7 +244,7 @@ public static Constructor<? extends WindowAccumulator> generateWindowAccumulator
.map(StateFieldAndDescriptor::getStateField)
.collect(toImmutableList());

int lambdaCount = metadata.getLambdaInterfaces().size();
int lambdaCount = implementation.getLambdaInterfaces().size();
List<FieldDefinition> lambdaProviderFields = new ArrayList<>(lambdaCount);
for (int i = 0; i < lambdaCount; i++) {
lambdaProviderFields.add(definition.declareField(a(PRIVATE, FINAL), "lambdaProvider_" + i, Supplier.class));
Expand All @@ -268,10 +268,10 @@ public static Constructor<? extends WindowAccumulator> generateWindowAccumulator
stateFields,
argumentNullable,
lambdaProviderFields,
metadata.getInputFunction(),
implementation.getInputFunction(),
"addInput",
callSiteBinder);
metadata.getRemoveInputFunction().ifPresent(
implementation.getRemoveInputFunction().ifPresent(
removeInputFunction -> generateAddOrRemoveInputWindowIndex(
definition,
stateFields,
Expand All @@ -281,7 +281,7 @@ public static Constructor<? extends WindowAccumulator> generateWindowAccumulator
"removeInput",
callSiteBinder));

generateEvaluateFinal(definition, stateFields, metadata.getOutputFunction(), callSiteBinder);
generateEvaluateFinal(definition, stateFields, implementation.getOutputFunction(), callSiteBinder);
generateGetEstimatedSize(definition, stateFields);

Class<? extends WindowAccumulator> windowAccumulatorClass = defineClass(definition, WindowAccumulator.class, callSiteBinder.getBindings(), classLoader);
Expand Down Expand Up @@ -1044,18 +1044,18 @@ private static BytecodeExpression generateRequireNotNull(BytecodeExpression expr
.cast(expression.getType());
}

private static AggregationMetadata normalizeAggregationMethods(AggregationMetadata metadata)
private static AggregationImplementation normalizeAggregationMethods(AggregationImplementation implementation)
{
// change aggregations state variables to simply AccumulatorState to avoid any class loader issues in generated code
int stateParameterCount = metadata.getAccumulatorStateDescriptors().size();
int lambdaParameterCount = metadata.getLambdaInterfaces().size();
return new AggregationMetadata(
castStateParameters(metadata.getInputFunction(), stateParameterCount, lambdaParameterCount),
metadata.getRemoveInputFunction().map(removeFunction -> castStateParameters(removeFunction, stateParameterCount, lambdaParameterCount)),
metadata.getCombineFunction().map(combineFunction -> castStateParameters(combineFunction, stateParameterCount * 2, lambdaParameterCount)),
castStateParameters(metadata.getOutputFunction(), stateParameterCount, 0),
metadata.getAccumulatorStateDescriptors(),
metadata.getLambdaInterfaces());
int stateParameterCount = implementation.getAccumulatorStateDescriptors().size();
int lambdaParameterCount = implementation.getLambdaInterfaces().size();
return new AggregationImplementation(
castStateParameters(implementation.getInputFunction(), stateParameterCount, lambdaParameterCount),
implementation.getRemoveInputFunction().map(removeFunction -> castStateParameters(removeFunction, stateParameterCount, lambdaParameterCount)),
implementation.getCombineFunction().map(combineFunction -> castStateParameters(combineFunction, stateParameterCount * 2, lambdaParameterCount)),
castStateParameters(implementation.getOutputFunction(), stateParameterCount, 0),
implementation.getAccumulatorStateDescriptors(),
implementation.getLambdaInterfaces());
}

private static MethodHandle castStateParameters(MethodHandle inputFunction, int stateParameterCount, int lambdaParameterCount)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import io.trino.metadata.FunctionDependencies;
import io.trino.metadata.Signature;
import io.trino.operator.ParametricImplementationsGroup;
import io.trino.operator.aggregation.AggregationMetadata.AccumulatorStateDescriptor;
import io.trino.operator.aggregation.AggregationImplementation.AccumulatorStateDescriptor;
import io.trino.operator.aggregation.state.InOutStateSerializer;
import io.trino.operator.annotations.FunctionsParserHelper;
import io.trino.operator.annotations.ImplementationDependency;
Expand Down
Loading