From b8923617bcc307c8c424f7192d4109d2d65ea3cb Mon Sep 17 00:00:00 2001 From: Jinghui Mo Date: Wed, 2 Nov 2022 11:39:18 -0400 Subject: [PATCH] Fix passthrough feature reference in sql-based derived feature --- .../derived/strategies/SqlDerivationSpark.scala | 17 ++++++++++++++--- .../offline/AnchoredFeaturesIntegTest.scala | 14 ++++++++++++-- 2 files changed, 26 insertions(+), 5 deletions(-) diff --git a/src/main/scala/com/linkedin/feathr/offline/derived/strategies/SqlDerivationSpark.scala b/src/main/scala/com/linkedin/feathr/offline/derived/strategies/SqlDerivationSpark.scala index 3afa0a6af..c7b44c1cf 100644 --- a/src/main/scala/com/linkedin/feathr/offline/derived/strategies/SqlDerivationSpark.scala +++ b/src/main/scala/com/linkedin/feathr/offline/derived/strategies/SqlDerivationSpark.scala @@ -4,6 +4,7 @@ import com.linkedin.feathr.common.exception.{ErrorLabel, FeathrFeatureTransforma import com.linkedin.feathr.offline.client.DataFrameColName import com.linkedin.feathr.offline.derived.DerivedFeature import com.linkedin.feathr.offline.derived.functions.SQLFeatureDerivationFunction +import com.linkedin.feathr.offline.job.FeatureTransformation.FEATURE_NAME_PREFIX import com.linkedin.feathr.offline.mvel.plugins.FeathrExpressionExecutionContext import org.apache.spark.sql.functions.expr import org.apache.spark.sql.{DataFrame, SparkSession} @@ -21,12 +22,14 @@ class SqlDerivationSpark extends SqlDerivationSparkStrategy { * @param deriveFeature derived feature definition * @param keyTag list of tags represented by integer * @param keyTagId2StringMap Map from the tag integer id to the string tag + * @param asIsFeatureNames features names that does not to be rewritten, i.e. passthrough features, as they do not have key tags * @return Rewritten SQL expression */ private[offline] def rewriteDerivedFeatureExpression( deriveFeature: DerivedFeature, keyTag: Seq[Int], - keyTagId2StringMap: Seq[String]): String = { + keyTagId2StringMap: Seq[String], + asIsFeatureNames: Set[String]): String = { if (!deriveFeature.derivation.isInstanceOf[SQLFeatureDerivationFunction]) { throw new FeathrFeatureTransformationException(ErrorLabel.FEATHR_ERROR, "Should not rewrite derived feature expression for non-SQLDerivedFeatures") } @@ -42,7 +45,7 @@ class SqlDerivationSpark extends SqlDerivationSparkStrategy { val namePattern = if (parameterNames.isEmpty) consumeFeatureName.getFeatureName else parameterNames(index) // getBinding.map(keyTag.get) resolves the call tags val newName = - if (!consumeFeatureName.getBinding.isEmpty // Passthrough features do not have keyTag + if (!asIsFeatureNames.contains(FEATURE_NAME_PREFIX + consumeFeatureName.getFeatureName) // Feature generation code path does not create columns with tags. // The check ensures we do not run into IndexOutOfBoundsException when keyTag & keyTagId2StringMap are empty. && keyTag.nonEmpty @@ -98,7 +101,15 @@ class SqlDerivationSpark extends SqlDerivationSparkStrategy { derivationFunction: SQLFeatureDerivationFunction, mvelContext: Option[FeathrExpressionExecutionContext]): DataFrame = { // sql expression based derived feature needs rewrite, e.g, replace the feature names with feature column names in the dataframe - val rewrittenExpr = rewriteDerivedFeatureExpression(derivedFeature, keyTags, keyTagList) + // Passthrough fields do not need rewrite as they do not have tags. + val passthroughFieldNames = df.schema.fields.map(f => + if (f.name.startsWith(FEATURE_NAME_PREFIX)) { + f.name + } else { + FEATURE_NAME_PREFIX + f.name + } + ).toSet + val rewrittenExpr = rewriteDerivedFeatureExpression(derivedFeature, keyTags, keyTagList, passthroughFieldNames) val tags = Some(keyTags.map(keyTagList).toList) val featureColumnName = DataFrameColName.genFeatureColumnName(derivedFeature.producedFeatureNames.head, tags) df.withColumn(featureColumnName, expr(rewrittenExpr)) diff --git a/src/test/scala/com/linkedin/feathr/offline/AnchoredFeaturesIntegTest.scala b/src/test/scala/com/linkedin/feathr/offline/AnchoredFeaturesIntegTest.scala index db69ea6f2..3735c0f9f 100644 --- a/src/test/scala/com/linkedin/feathr/offline/AnchoredFeaturesIntegTest.scala +++ b/src/test/scala/com/linkedin/feathr/offline/AnchoredFeaturesIntegTest.scala @@ -484,7 +484,16 @@ class AnchoredFeaturesIntegTest extends FeathrIntegTest { | |derivations: { | f_trip_time_distance: { - | definition: "f_trip_distance * f_trip_time_duration" + | definition: "f_trip_distance * f_trip_time_duration" + | type: NUMERIC + | } + | f_trip_time_distance_sql: { + | key: [trip] + | inputs: { + | trip_distance: { key: [trip], feature: f_trip_distance } + | trip_time_duration: { key: [trip], feature: f_trip_time_duration } + | } + | definition.sqlExpr: "trip_distance * trip_time_duration" | type: NUMERIC | } |} @@ -514,7 +523,8 @@ class AnchoredFeaturesIntegTest extends FeathrIntegTest { |featureList: [ | { | key: DOLocationID - | featureList: [f_location_avg_fare, f_trip_time_distance, f_trip_distance, f_trip_time_duration, f_is_long_trip_distance, f_day_of_week] + | featureList: [f_location_avg_fare, f_trip_time_distance, f_trip_distance, + | f_trip_time_duration, f_is_long_trip_distance, f_day_of_week, f_trip_time_distance_sql] | } |] """.stripMargin