Thanks for visiting codestin.com
Credit goes to github.com

Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import com.linkedin.feathr.common.exception.{ErrorLabel, FeathrFeatureTransforma
import com.linkedin.feathr.offline.client.DataFrameColName
import com.linkedin.feathr.offline.derived.DerivedFeature
import com.linkedin.feathr.offline.derived.functions.SQLFeatureDerivationFunction
import com.linkedin.feathr.offline.job.FeatureTransformation.FEATURE_NAME_PREFIX
import com.linkedin.feathr.offline.mvel.plugins.FeathrExpressionExecutionContext
import org.apache.spark.sql.functions.expr
import org.apache.spark.sql.{DataFrame, SparkSession}
Expand All @@ -21,12 +22,14 @@ class SqlDerivationSpark extends SqlDerivationSparkStrategy {
* @param deriveFeature derived feature definition
* @param keyTag list of tags represented by integer
* @param keyTagId2StringMap Map from the tag integer id to the string tag
* @param asIsFeatureNames feature names that do not need to be rewritten, i.e. passthrough features, as they do not have key tags
* @return Rewritten SQL expression
*/
private[offline] def rewriteDerivedFeatureExpression(
deriveFeature: DerivedFeature,
keyTag: Seq[Int],
keyTagId2StringMap: Seq[String]): String = {
keyTagId2StringMap: Seq[String],
asIsFeatureNames: Set[String]): String = {
if (!deriveFeature.derivation.isInstanceOf[SQLFeatureDerivationFunction]) {
throw new FeathrFeatureTransformationException(ErrorLabel.FEATHR_ERROR, "Should not rewrite derived feature expression for non-SQLDerivedFeatures")
}
Expand All @@ -42,7 +45,7 @@ class SqlDerivationSpark extends SqlDerivationSparkStrategy {
val namePattern = if (parameterNames.isEmpty) consumeFeatureName.getFeatureName else parameterNames(index)
// getBinding.map(keyTag.get) resolves the call tags
val newName =
if (!consumeFeatureName.getBinding.isEmpty // Passthrough features do not have keyTag
if (!asIsFeatureNames.contains(FEATURE_NAME_PREFIX + consumeFeatureName.getFeatureName)
// Feature generation code path does not create columns with tags.
// The check ensures we do not run into IndexOutOfBoundsException when keyTag & keyTagId2StringMap are empty.
&& keyTag.nonEmpty
Expand Down Expand Up @@ -98,7 +101,15 @@ class SqlDerivationSpark extends SqlDerivationSparkStrategy {
derivationFunction: SQLFeatureDerivationFunction,
mvelContext: Option[FeathrExpressionExecutionContext]): DataFrame = {
// A SQL-expression-based derived feature needs rewriting, e.g., replacing the feature names with the feature column names in the dataframe
val rewrittenExpr = rewriteDerivedFeatureExpression(derivedFeature, keyTags, keyTagList)
// Passthrough fields do not need rewrite as they do not have tags.
val passthroughFieldNames = df.schema.fields.map(f =>
if (f.name.startsWith(FEATURE_NAME_PREFIX)) {
f.name
} else {
FEATURE_NAME_PREFIX + f.name
}
).toSet
val rewrittenExpr = rewriteDerivedFeatureExpression(derivedFeature, keyTags, keyTagList, passthroughFieldNames)
val tags = Some(keyTags.map(keyTagList).toList)
val featureColumnName = DataFrameColName.genFeatureColumnName(derivedFeature.producedFeatureNames.head, tags)
df.withColumn(featureColumnName, expr(rewrittenExpr))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -484,7 +484,16 @@ class AnchoredFeaturesIntegTest extends FeathrIntegTest {
|
|derivations: {
| f_trip_time_distance: {
| definition: "f_trip_distance * f_trip_time_duration"
| definition: "f_trip_distance * f_trip_time_duration"
| type: NUMERIC
| }
| f_trip_time_distance_sql: {
| key: [trip]
| inputs: {
| trip_distance: { key: [trip], feature: f_trip_distance }
| trip_time_duration: { key: [trip], feature: f_trip_time_duration }
| }
| definition.sqlExpr: "trip_distance * trip_time_duration"
| type: NUMERIC
| }
|}
Expand Down Expand Up @@ -514,7 +523,8 @@ class AnchoredFeaturesIntegTest extends FeathrIntegTest {
|featureList: [
| {
| key: DOLocationID
| featureList: [f_location_avg_fare, f_trip_time_distance, f_trip_distance, f_trip_time_duration, f_is_long_trip_distance, f_day_of_week]
| featureList: [f_location_avg_fare, f_trip_time_distance, f_trip_distance,
| f_trip_time_duration, f_is_long_trip_distance, f_day_of_week, f_trip_time_distance_sql]
| }
|]
""".stripMargin
Expand Down