diff --git a/src/main/scala/com/linkedin/feathr/offline/client/DataFrameColName.scala b/src/main/scala/com/linkedin/feathr/offline/client/DataFrameColName.scala index 830bc34f1..e1ae67a88 100644 --- a/src/main/scala/com/linkedin/feathr/offline/client/DataFrameColName.scala +++ b/src/main/scala/com/linkedin/feathr/offline/client/DataFrameColName.scala @@ -1,5 +1,6 @@ package com.linkedin.feathr.offline.client +import com.google.common.annotations.VisibleForTesting import com.linkedin.feathr.common._ import com.linkedin.feathr.common.exception.{ErrorLabel, FeathrFeatureTransformationException} import com.linkedin.feathr.offline.anchored.feature.FeatureAnchorWithSource @@ -357,11 +358,13 @@ object DataFrameColName { /** * generate header info (e.g, feature type, feature column name map) for output dataframe of * feature join or feature generation + * * @param featureToColumnNameMap map of feature to its column name in the dataframe * @param inferredFeatureTypeConfigs feature name to inferred feature types * @return header info for a dataframe that contains the features in featureToColumnNameMap */ - private def generateHeader( + @VisibleForTesting + def generateHeader( featureToColumnNameMap: Map[TaggedFeatureName, String], allAnchoredFeatures: Map[String, FeatureAnchorWithSource], allDerivedFeatures: Map[String, DerivedFeature], @@ -370,13 +373,10 @@ object DataFrameColName { // if the feature type is unspecified in the anchor config, we will use FeatureTypes.UNSPECIFIED val anchoredFeatureTypes: Map[String, FeatureTypeConfig] = allAnchoredFeatures.map { case (featureName, anchorWithSource) => - val featureTypeOpt = anchorWithSource.featureAnchor.getFeatureTypes.map(types => { - // Get the actual type in the output dataframe, the type is inferred and stored previously, if not specified by users - val inferredType = inferredFeatureTypeConfigs.getOrElse(featureName, FeatureTypeConfig.UNDEFINED_TYPE_CONFIG) - val fType = new 
FeatureTypeConfig(types.getOrElse(featureName, FeatureTypes.UNSPECIFIED)) - if (fType == FeatureTypeConfig.UNDEFINED_TYPE_CONFIG) inferredType else fType - }) - val featureType = featureTypeOpt.getOrElse(FeatureTypeConfig.UNDEFINED_TYPE_CONFIG) + val featureTypeOpt = anchorWithSource.featureAnchor.featureTypeConfigs.get(featureName) + // When the anchor has no type config entry for this feature, fall back to the previously inferred type (UNDEFINED_TYPE_CONFIG if none was inferred) + val inferredType = inferredFeatureTypeConfigs.getOrElse(featureName, FeatureTypeConfig.UNDEFINED_TYPE_CONFIG) + val featureType = featureTypeOpt.getOrElse(inferredType) featureName -> featureType } diff --git a/src/test/scala/com/linkedin/feathr/offline/client/TestDataFrameColName.scala b/src/test/scala/com/linkedin/feathr/offline/client/TestDataFrameColName.scala index 1f7cf3d5c..540cae74a 100644 --- a/src/test/scala/com/linkedin/feathr/offline/client/TestDataFrameColName.scala +++ b/src/test/scala/com/linkedin/feathr/offline/client/TestDataFrameColName.scala @@ -1,9 +1,12 @@ package com.linkedin.feathr.offline.client -import com.linkedin.feathr.common.{DateParam, JoiningFeatureParams, TaggedFeatureName} +import com.linkedin.feathr.common.{DateParam, FeatureTypeConfig, JoiningFeatureParams, TaggedFeatureName} import com.linkedin.feathr.offline.TestFeathr +import com.linkedin.feathr.offline.anchored.feature.{FeatureAnchor, FeatureAnchorWithSource} import org.apache.spark.sql.Row import org.apache.spark.sql.types.{DoubleType, StringType, StructField, StructType} +import org.mockito.Mockito.when +import org.scalatest.mockito.MockitoSugar.mock import org.testng.Assert.assertEquals import org.testng.annotations.Test @@ -59,4 +62,26 @@ class TestDataFrameColName extends TestFeathr { val taggedFeature3 = new TaggedFeatureName("x", "seq_join_a_names") assertEquals(taggedFeatureToNewColumnNameMap(taggedFeature3)._2, "seq_join_a_names") } + + @Test(description = "Inferred feature type should be honored when user does not 
provide feature type") + def testGenerateHeader(): Unit = { + val mockFeatureAnchor = mock[FeatureAnchor] + // Mock the case where the user does not define a feature type: the anchor's type config map is empty + when(mockFeatureAnchor.featureTypeConfigs).thenReturn(Map.empty[String, FeatureTypeConfig]) + + val mockFeatureAnchorWithSource = mock[FeatureAnchorWithSource] + when(mockFeatureAnchorWithSource.featureAnchor).thenReturn(mockFeatureAnchor) + val taggedFeatureName = new TaggedFeatureName("id", "f") + val featureToColumnNameMap: Map[TaggedFeatureName, String] = Map(taggedFeatureName -> "f") + val allAnchoredFeatures: Map[String, FeatureAnchorWithSource] = Map("f" -> mockFeatureAnchorWithSource) + // Mock the case where the type is inferred to be numeric + val inferredFeatureTypeConfigs: Map[String, FeatureTypeConfig] = Map("f" -> FeatureTypeConfig.NUMERIC_TYPE_CONFIG) + val header = DataFrameColName.generateHeader( + featureToColumnNameMap, + allAnchoredFeatures, + Map(), + inferredFeatureTypeConfigs) + // The header should report the inferred type for the feature, i.e. numeric + assertEquals(header.featureInfoMap.get(taggedFeatureName).get.featureType, FeatureTypeConfig.NUMERIC_TYPE_CONFIG) + } }