From 068e8365d148a3ca60d29462a6795ff3441551bf Mon Sep 17 00:00:00 2001 From: Chen Xu Date: Fri, 15 Jul 2022 03:19:05 +0800 Subject: [PATCH 01/10] GenericLocation for DataFrame read/write --- .../offline/config/FeathrConfigLoader.scala | 6 +- ...InputLocation.scala => DataLocation.scala} | 69 +++++++++++- .../config/location/GenericLocation.scala | 70 ++++++++++++ .../feathr/offline/config/location/Jdbc.scala | 106 +++++++++++++----- .../config/location/KafkaEndpoint.scala | 4 +- .../offline/config/location/PathList.scala | 4 +- .../offline/config/location/SimplePath.scala | 4 +- .../offline/generation/SparkIOUtils.scala | 34 +++--- .../WriteToHDFSOutputProcessor.scala | 21 +++- .../feathr/offline/job/FeatureJoinJob.scala | 33 ++++-- .../feathr/offline/job/JoinJobContext.scala | 3 +- .../feathr/offline/source/DataSource.scala | 6 +- .../NonTimeBasedDataSourceAccessor.scala | 3 +- .../source/dataloader/BatchDataLoader.scala | 4 +- .../source/dataloader/DataLoaderFactory.scala | 4 +- .../dataloader/LocalDataLoaderFactory.scala | 4 +- .../StreamingDataLoaderFactory.scala | 4 +- .../feathr/offline/util/SourceUtils.scala | 8 +- .../offline/AnchoredFeaturesIntegTest.scala | 3 +- .../config/location/TestDesLocation.scala | 36 +++++- 20 files changed, 336 insertions(+), 90 deletions(-) rename src/main/scala/com/linkedin/feathr/offline/config/location/{InputLocation.scala => DataLocation.scala} (57%) create mode 100644 src/main/scala/com/linkedin/feathr/offline/config/location/GenericLocation.scala diff --git a/src/main/scala/com/linkedin/feathr/offline/config/FeathrConfigLoader.scala b/src/main/scala/com/linkedin/feathr/offline/config/FeathrConfigLoader.scala index 66a1e247e..702329e0a 100644 --- a/src/main/scala/com/linkedin/feathr/offline/config/FeathrConfigLoader.scala +++ b/src/main/scala/com/linkedin/feathr/offline/config/FeathrConfigLoader.scala @@ -13,7 +13,7 @@ import com.linkedin.feathr.offline.ErasedEntityTaggedFeature import com.linkedin.feathr.offline.anchored.anchorExtractor.{SQLConfigurableAnchorExtractor, SimpleConfigurableAnchorExtractor, TimeWindowConfigurableAnchorExtractor} import com.linkedin.feathr.offline.anchored.feature.{FeatureAnchor, FeatureAnchorWithSource} import com.linkedin.feathr.offline.anchored.keyExtractor.{MVELSourceKeyExtractor, SQLSourceKeyExtractor} -import com.linkedin.feathr.offline.config.location.{InputLocation, Jdbc, KafkaEndpoint, LocationUtils, SimplePath} +import com.linkedin.feathr.offline.config.location.{DataLocation, KafkaEndpoint, LocationUtils, SimplePath} import com.linkedin.feathr.offline.derived._ import com.linkedin.feathr.offline.derived.functions.{MvelFeatureDerivationFunction, SQLFeatureDerivationFunction, SeqJoinDerivationFunction, SimpleMvelDerivationFunction} import com.linkedin.feathr.offline.source.{DataSource, SourceFormatType, TimeWindowParams} @@ -712,7 +712,7 @@ private[offline] class DataSourceLoader extends JsonDeserializer[DataSource] { * 2. 
a placeholder with reserved string "PASSTHROUGH" for anchor defined pass-through features, * since anchor defined pass-through features do not have path */ - val path: InputLocation = dataSourceType match { + val path: DataLocation = dataSourceType match { case "KAFKA" => Option(node.get("config")) match { case Some(field: ObjectNode) => @@ -725,7 +725,7 @@ private[offline] class DataSourceLoader extends JsonDeserializer[DataSource] { case "PASSTHROUGH" => SimplePath("PASSTHROUGH") case _ => Option(node.get("location")) match { case Some(field: ObjectNode) => - LocationUtils.getMapper().treeToValue(field, classOf[InputLocation]) + LocationUtils.getMapper().treeToValue(field, classOf[DataLocation]) case None => throw new FeathrConfigException(ErrorLabel.FEATHR_USER_ERROR, s"Data location is not defined for data source ${node.toPrettyString()}") case _ => throw new FeathrConfigException(ErrorLabel.FEATHR_USER_ERROR, diff --git a/src/main/scala/com/linkedin/feathr/offline/config/location/InputLocation.scala b/src/main/scala/com/linkedin/feathr/offline/config/location/DataLocation.scala similarity index 57% rename from src/main/scala/com/linkedin/feathr/offline/config/location/InputLocation.scala rename to src/main/scala/com/linkedin/feathr/offline/config/location/DataLocation.scala index e7c72bea1..781997e1f 100644 --- a/src/main/scala/com/linkedin/feathr/offline/config/location/InputLocation.scala +++ b/src/main/scala/com/linkedin/feathr/offline/config/location/DataLocation.scala @@ -1,6 +1,7 @@ package com.linkedin.feathr.offline.config.location import com.fasterxml.jackson.annotation.{JsonSubTypes, JsonTypeInfo} +import com.fasterxml.jackson.core.JacksonException import com.fasterxml.jackson.databind.module.SimpleModule import com.fasterxml.jackson.databind.{DeserializationFeature, ObjectMapper} import com.fasterxml.jackson.module.caseclass.mapper.CaseClassObjectMapper @@ -8,7 +9,9 @@ import com.jasonclawson.jackson.dataformat.hocon.HoconFactory import com.linkedin.feathr.common.FeathrJacksonScalaModule import com.linkedin.feathr.offline.config.DataSourceLoader import com.linkedin.feathr.offline.source.DataSource +import com.typesafe.config.Config import org.apache.spark.sql.{DataFrame, SparkSession} +import scala.collection.JavaConverters._ /** * An InputLocation is a data source definition, it can either be HDFS files or a JDBC database connection @@ -20,38 +23,50 @@ import org.apache.spark.sql.{DataFrame, SparkSession} new JsonSubTypes.Type(value = classOf[SimplePath], name = "path"), new JsonSubTypes.Type(value = classOf[PathList], name = "pathlist"), new JsonSubTypes.Type(value = classOf[Jdbc], name = "jdbc"), + new JsonSubTypes.Type(value = classOf[GenericLocation], name = "generic"), )) -trait InputLocation { +trait DataLocation { /** * Backward Compatibility * Many existing codes expect a simple path + * * @return the `path` or `url` of the data source * - * WARN: This method is deprecated, you must use match/case on InputLocation, - * and get `path` from `SimplePath` only + * WARN: This method is deprecated, you must use match/case on InputLocation, + * and get `path` from `SimplePath` only */ @deprecated("Do not use this method in any new code, it will be removed soon") def getPath: String /** * Backward Compatibility + * * @return the `path` or `url` of the data source, wrapped in an List * - * WARN: This method is deprecated, you must use match/case on InputLocation, - * and get `paths` from `PathList` only + * WARN: This method is deprecated, you must use match/case on 
InputLocation, + * and get `paths` from `PathList` only */ @deprecated("Do not use this method in any new code, it will be removed soon") def getPathList: List[String] /** * Load DataFrame from Spark session + * * @param ss SparkSession * @return */ def loadDf(ss: SparkSession, dataIOParameters: Map[String, String] = Map()): DataFrame + /** + * Write DataFrame to the location + * @param ss SparkSession + * @param df DataFrame to write + */ + def writeDf(ss: SparkSession, df: DataFrame) + /** * Tell if this location is file based + * * @return boolean */ def isFileBasedLocation(): Boolean @@ -67,6 +82,7 @@ object LocationUtils { /** * String template substitution, replace "...${VAR}.." with corresponding System property or environment variable * Non-existent pattern is replaced by empty string. + * * @param s String template to be processed * @return Processed result */ @@ -76,6 +92,7 @@ object LocationUtils { /** * Get an ObjectMapper to deserialize DataSource + * * @return the ObjectMapper */ def getMapper(): ObjectMapper = { @@ -86,3 +103,45 @@ object LocationUtils { .registerModule(new SimpleModule().addDeserializer(classOf[DataSource], new DataSourceLoader)) } } + +object DataLocation { + /** + * Create DataLocation from string, try parsing the string as JSON and fallback to SimplePath + * @param cfg the input string + * @return DataLocation + */ + def apply(cfg: String): DataLocation = { + val jackson = (new ObjectMapper(new HoconFactory) with CaseClassObjectMapper) + .registerModule(FeathrJacksonScalaModule) // DefaultScalaModule causes a fail on holdem + .configure(DeserializationFeature.ACCEPT_SINGLE_VALUE_AS_ARRAY, true) + .configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false) + .registerModule(new SimpleModule().addDeserializer(classOf[DataSource], new DataSourceLoader)) + try { + val location = jackson.readValue(cfg, classOf[DataLocation]) + location + } catch { + case _: JacksonException => SimplePath(cfg) + } + } + + def apply(cfg: Config): DataLocation = { + apply(cfg.root().keySet().asScala.map(key ⇒ key → cfg.getString(key)).toMap) + } + + def apply(cfg: Any): DataLocation = { + val jackson = (new ObjectMapper(new HoconFactory) with CaseClassObjectMapper) + .registerModule(FeathrJacksonScalaModule) // DefaultScalaModule causes a fail on holdem + .configure(DeserializationFeature.ACCEPT_SINGLE_VALUE_AS_ARRAY, true) + .configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false) + .registerModule(new SimpleModule().addDeserializer(classOf[DataSource], new DataSourceLoader)) + try { + val location = jackson.convertValue(cfg, classOf[DataLocation]) + location + } catch { + case e: JacksonException => { + print(e) + SimplePath(cfg.toString) + } + } + } +} diff --git a/src/main/scala/com/linkedin/feathr/offline/config/location/GenericLocation.scala b/src/main/scala/com/linkedin/feathr/offline/config/location/GenericLocation.scala new file mode 100644 index 000000000..12d473cb7 --- /dev/null +++ b/src/main/scala/com/linkedin/feathr/offline/config/location/GenericLocation.scala @@ -0,0 +1,70 @@ +package com.linkedin.feathr.offline.config.location + +import com.fasterxml.jackson.annotation.JsonAnySetter +import com.fasterxml.jackson.module.caseclass.annotation.CaseClassDeserialize +import net.minidev.json.annotate.JsonIgnore +import org.apache.spark.sql.{DataFrame, SparkSession} + +@CaseClassDeserialize() +case class GenericLocation(format: String, @JsonIgnore options: collection.mutable.Map[String, String] = collection.mutable.Map[String, String]()) extends 
DataLocation { + /** + * Backward Compatibility + * Many existing codes expect a simple path + * + * @return the `path` or `url` of the data source + * + * WARN: This method is deprecated, you must use match/case on DataLocation, + * and get `path` from `SimplePath` only + */ + override def getPath: String = s"GenericLocation(${format})" + + /** + * Backward Compatibility + * + * @return the `path` or `url` of the data source, wrapped in an List + * + * WARN: This method is deprecated, you must use match/case on DataLocation, + * and get `paths` from `PathList` only + */ + override def getPathList: List[String] = List(getPath) + + /** + * Load DataFrame from Spark session + * + * @param ss SparkSession + * @return + */ + override def loadDf(ss: SparkSession, dataIOParameters: Map[String, String]): DataFrame = { + ss.read.format(format) + .options(getOptions) + .load() + } + + /** + * Write DataFrame to the location + * + * @param ss SparkSession + * @param df DataFrame to write + */ + override def writeDf(ss: SparkSession, df: DataFrame): Unit = { + df.write.format(format) + .options(getOptions) + .save() + } + + /** + * Tell if this location is file based + * + * @return boolean + */ + override def isFileBasedLocation(): Boolean = false + + def getOptions(): Map[String, String] = { + options.map(e => e._1 -> LocationUtils.envSubstitute(e._2)).toMap + } + + @JsonAnySetter + def setOption(key: String, value: Any) = { + options += (key -> value.toString) + } +} diff --git a/src/main/scala/com/linkedin/feathr/offline/config/location/Jdbc.scala b/src/main/scala/com/linkedin/feathr/offline/config/location/Jdbc.scala index 1cef50f7b..0d70b1a48 100644 --- a/src/main/scala/com/linkedin/feathr/offline/config/location/Jdbc.scala +++ b/src/main/scala/com/linkedin/feathr/offline/config/location/Jdbc.scala @@ -8,7 +8,7 @@ import org.apache.spark.sql.{DataFrame, SparkSession} import org.eclipse.jetty.util.StringUtil @CaseClassDeserialize() -case class Jdbc(url: String, @JsonAlias(Array("query")) dbtable: String, user: String = "", password: String = "", token: String = "", useToken: Boolean = false, anonymous: Boolean = false) extends InputLocation { +case class Jdbc(url: String, @JsonAlias(Array("query")) dbtable: String, user: String = "", password: String = "", token: String = "", useToken: Boolean = false, anonymous: Boolean = false) extends DataLocation { override def loadDf(ss: SparkSession, dataIOParameters: Map[String, String] = Map()): DataFrame = { println(s"Jdbc.loadDf, location is ${this}") var reader = ss.read.format("jdbc") @@ -18,7 +18,7 @@ case class Jdbc(url: String, @JsonAlias(Array("query")) dbtable: String, user: S reader = reader.option("dbtable", ss.conf.get(DBTABLE_CONF)) } else { val q = dbtable.trim - if("\\s".r.findFirstIn(q).nonEmpty) { + if ("\\s".r.findFirstIn(q).nonEmpty) { // This is a SQL instead of a table name reader = reader.option("query", q) } else { @@ -47,6 +47,18 @@ case class Jdbc(url: String, @JsonAlias(Array("query")) dbtable: String, user: S } } + override def writeDf(ss: SparkSession, df: DataFrame): Unit = { + println(s"Jdbc.writeDf, location is ${this}") + if (StringUtil.isBlank(user) && StringUtil.isBlank(password) && !anonymous && !useToken) { + // Fallback to global JDBC credential + println("Fallback to default credential") + ss.conf.set(DBTABLE_CONF, dbtable) + } + df.write.format("jdbc") + .options(getOptions(ss)) + .save() + } + override def getPath: String = url override def getPathList: List[String] = List(url) @@ -55,35 +67,69 @@ case class 
Jdbc(url: String, @JsonAlias(Array("query")) dbtable: String, user: S // These members don't contain actual secrets override def toString: String = s"Jdbc(url=$url, dbtable=$dbtable, useToken=$useToken, anonymous=$anonymous, user=$user, password=$password, token=$token)" + + def getOptions(ss: SparkSession): Map[String, String] = { + val options = collection.mutable.Map[String, String]() + options += ("url" -> url) + + if (StringUtil.isBlank(dbtable)) { + // Fallback to default table name + options += ("dbtable" -> ss.conf.get(DBTABLE_CONF)) + } else { + val q = dbtable.trim + if ("\\s".r.findFirstIn(q).nonEmpty) { + // This is a SQL instead of a table name + options += ("query" -> q) + } else { + options += ("dbtable" -> q) + } + } + if (useToken) { + options += ("accessToken" -> LocationUtils.envSubstitute(token)) + options += ("hostNameInCertificate" -> "*.database.windows.net") + options += ("encrypt" -> "true") + } else { + if (!StringUtil.isBlank(user)) { + options += ("user" -> LocationUtils.envSubstitute(user)) + } + if (!StringUtil.isBlank(password)) { + options += ("password" -> LocationUtils.envSubstitute(password)) + } + } + options.toMap + } } -object Jdbc { - /** - * Create JDBC InputLocation with required info and user/password auth - * @param url - * @param dbtable - * @param user - * @param password - * @return Newly created InputLocation instance - */ - def apply(url: String, dbtable: String, user: String, password: String): Jdbc = Jdbc(url, dbtable, user = user, password = password, useToken = false) + object Jdbc { + /** + * Create JDBC InputLocation with required info and user/password auth + * + * @param url + * @param dbtable + * @param user + * @param password + * @return Newly created InputLocation instance + */ + def apply(url: String, dbtable: String, user: String, password: String): Jdbc = Jdbc(url, dbtable, user = user, password = password, useToken = false) - /** - * Create JDBC InputLocation with required info and OAuth token auth - * @param url - * @param dbtable - * @param token - * @return Newly created InputLocation instance - */ - def apply(url: String, dbtable: String, token: String): Jdbc = Jdbc(url, dbtable, token = token, useToken = true) + /** + * Create JDBC InputLocation with required info and OAuth token auth + * + * @param url + * @param dbtable + * @param token + * @return Newly created InputLocation instance + */ + def apply(url: String, dbtable: String, token: String): Jdbc = Jdbc(url, dbtable, token = token, useToken = true) - /** - * Create JDBC InputLocation with required info and OAuth token auth - * In this case, the auth info is taken from default setting passed from CLI/API, details can be found in `Jdbc#loadDf` - * @see com.linkedin.feathr.offline.source.dataloader.jdbc.JDBCUtils#loadDataFrame - * @param url - * @param dbtable - * @return Newly created InputLocation instance - */ - def apply(url: String, dbtable: String): Jdbc = Jdbc(url, dbtable, useToken = false) -} + /** + * Create JDBC InputLocation with required info and OAuth token auth + * In this case, the auth info is taken from default setting passed from CLI/API, details can be found in `Jdbc#loadDf` + * + * @see com.linkedin.feathr.offline.source.dataloader.jdbc.JDBCUtils#loadDataFrame + * @param url + * @param dbtable + * @return Newly created InputLocation instance + */ + def apply(url: String, dbtable: String): Jdbc = Jdbc(url, dbtable, useToken = false) + } diff --git a/src/main/scala/com/linkedin/feathr/offline/config/location/KafkaEndpoint.scala 
b/src/main/scala/com/linkedin/feathr/offline/config/location/KafkaEndpoint.scala index 42bbc9f72..e3c3a7298 100644 --- a/src/main/scala/com/linkedin/feathr/offline/config/location/KafkaEndpoint.scala +++ b/src/main/scala/com/linkedin/feathr/offline/config/location/KafkaEndpoint.scala @@ -28,9 +28,11 @@ case class KafkaSchema(@JsonProperty("type") `type`: String, @CaseClassDeserialize() case class KafkaEndpoint(@JsonProperty("brokers") brokers: List[String], @JsonProperty("topics") topics: List[String], - @JsonProperty("schema") schema: KafkaSchema) extends InputLocation { + @JsonProperty("schema") schema: KafkaSchema) extends DataLocation { override def loadDf(ss: SparkSession, dataIOParameters: Map[String, String] = Map()): DataFrame = ??? + override def writeDf(ss: SparkSession, df: DataFrame): Unit = ??? + override def getPath: String = "kafka://" + brokers.mkString(",")+":"+topics.mkString(",") override def getPathList: List[String] = ??? diff --git a/src/main/scala/com/linkedin/feathr/offline/config/location/PathList.scala b/src/main/scala/com/linkedin/feathr/offline/config/location/PathList.scala index 7f491ec32..2f16091ee 100644 --- a/src/main/scala/com/linkedin/feathr/offline/config/location/PathList.scala +++ b/src/main/scala/com/linkedin/feathr/offline/config/location/PathList.scala @@ -4,7 +4,7 @@ import com.linkedin.feathr.offline.generation.SparkIOUtils import org.apache.spark.sql.{DataFrame, SparkSession} import org.apache.hadoop.mapred.JobConf -case class PathList(paths: List[String]) extends InputLocation { +case class PathList(paths: List[String]) extends DataLocation { override def getPath: String = paths.mkString(";") override def getPathList: List[String] = paths @@ -15,6 +15,8 @@ case class PathList(paths: List[String]) extends InputLocation { SparkIOUtils.createUnionDataFrame(getPathList, dataIOParameters, new JobConf(), List()) //TODO: Add handler support here. Currently there are deserilization issues with adding handlers to factory builder. } + override def writeDf(ss: SparkSession, df: DataFrame): Unit = ??? + override def toString: String = s"PathList(path=[${paths.mkString(",")}])" } diff --git a/src/main/scala/com/linkedin/feathr/offline/config/location/SimplePath.scala b/src/main/scala/com/linkedin/feathr/offline/config/location/SimplePath.scala index de726bf40..e94e2211b 100644 --- a/src/main/scala/com/linkedin/feathr/offline/config/location/SimplePath.scala +++ b/src/main/scala/com/linkedin/feathr/offline/config/location/SimplePath.scala @@ -7,11 +7,13 @@ import org.apache.spark.sql.{DataFrame, SparkSession} import org.codehaus.jackson.annotate.JsonProperty @CaseClassDeserialize() -case class SimplePath(@JsonProperty("path") path: String) extends InputLocation { +case class SimplePath(@JsonProperty("path") path: String) extends DataLocation { override def loadDf(ss: SparkSession, dataIOParameters: Map[String, String] = Map()): DataFrame = { SparkIOUtils.createUnionDataFrame(getPathList, dataIOParameters, new JobConf(), List()) // The simple path is not responsible for handling custom data loaders. } + override def writeDf(ss: SparkSession, df: DataFrame): Unit = ??? 
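The `writeDf` stubs added to `KafkaEndpoint`, `PathList`, and `SimplePath` above are deliberately left as `???`: plain-path writes keep going through `SparkIOUtils.writeDataFrame`, which this patch updates below to accept a `DataLocation`. Purely as a hypothetical sketch (not part of this change), a `SimplePath` implementation mirroring that fallback would look roughly like:

    // Hypothetical sketch only -- this patch keeps SimplePath.writeDf unimplemented.
    // Mirrors the SimplePath branch of SparkIOUtils.writeDataFrame further below;
    // needs org.apache.spark.sql.SaveMode in addition to the existing imports.
    override def writeDf(ss: SparkSession, df: DataFrame): Unit = {
      // "spark.feathr.outputFormat" selects the job output format, defaulting to
      // avro for backward compatibility.
      val outputFormat = ss.conf.get("spark.feathr.outputFormat", "avro")
      df.write.mode(SaveMode.Overwrite).format(outputFormat).save(path)
    }
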
+ override def getPath: String = path override def getPathList: List[String] = List(path) diff --git a/src/main/scala/com/linkedin/feathr/offline/generation/SparkIOUtils.scala b/src/main/scala/com/linkedin/feathr/offline/generation/SparkIOUtils.scala index 8d1e832bf..e7c912f17 100644 --- a/src/main/scala/com/linkedin/feathr/offline/generation/SparkIOUtils.scala +++ b/src/main/scala/com/linkedin/feathr/offline/generation/SparkIOUtils.scala @@ -1,8 +1,7 @@ package com.linkedin.feathr.offline.generation -import com.linkedin.feathr.offline.config.location.{InputLocation, Jdbc, SimplePath} +import com.linkedin.feathr.offline.config.location.{DataLocation, SimplePath} import com.linkedin.feathr.offline.source.dataloader.hdfs.FileFormat -import com.linkedin.feathr.offline.source.dataloader.jdbc.JdbcUtils import com.linkedin.feathr.offline.source.dataloader.DataLoaderHandler import org.apache.avro.generic.GenericRecord import org.apache.hadoop.mapred.JobConf @@ -35,7 +34,7 @@ object SparkIOUtils { df } - def createDataFrame(location: InputLocation, dataIOParams: Map[String, String] = Map(), jobConf: JobConf, dataLoaderHandlers: List[DataLoaderHandler]): DataFrame = { + def createDataFrame(location: DataLocation, dataIOParams: Map[String, String] = Map(), jobConf: JobConf, dataLoaderHandlers: List[DataLoaderHandler]): DataFrame = { var dfOpt: Option[DataFrame] = None breakable { for (dataLoaderHandler <- dataLoaderHandlers) { @@ -54,23 +53,32 @@ object SparkIOUtils { df } - def writeDataFrame( outputDF: DataFrame, path: String, parameters: Map[String, String] = Map(), dataLoaderHandlers: List[DataLoaderHandler]): DataFrame = { + def writeDataFrame( outputDF: DataFrame, outputLocation: DataLocation, parameters: Map[String, String] = Map(), dataLoaderHandlers: List[DataLoaderHandler]): DataFrame = { var dfWritten = false breakable { for (dataLoaderHandler <- dataLoaderHandlers) { - if (dataLoaderHandler.validatePath(path)) { - dataLoaderHandler.writeDataFrame(outputDF, path, parameters) - dfWritten = true - break + outputLocation match { + case SimplePath(path) => { + if (dataLoaderHandler.validatePath(path)) { + dataLoaderHandler.writeDataFrame(outputDF, path, parameters) + dfWritten = true + break + } + } } } } if(!dfWritten) { - val output_format = outputDF.sqlContext.getConf("spark.feathr.outputFormat", "avro") - // if the output format is set by spark configurations "spark.feathr.outputFormat" - // we will use that as the job output format; otherwise use avro as default for backward compatibility - outputDF.write.mode(SaveMode.Overwrite).format(output_format).save(path) - outputDF + outputLocation match { + case SimplePath(path) => { + val output_format = outputDF.sqlContext.getConf("spark.feathr.outputFormat", "avro") + // if the output format is set by spark configurations "spark.feathr.outputFormat" + // we will use that as the job output format; otherwise use avro as default for backward compatibility + outputDF.write.mode(SaveMode.Overwrite).format(output_format).save(path) + outputDF + } + case _ => outputLocation.writeDf(SparkSession.builder().getOrCreate(), outputDF) + } } outputDF } diff --git a/src/main/scala/com/linkedin/feathr/offline/generation/outputProcessor/WriteToHDFSOutputProcessor.scala b/src/main/scala/com/linkedin/feathr/offline/generation/outputProcessor/WriteToHDFSOutputProcessor.scala index 7ac29a779..0d0438397 100644 --- a/src/main/scala/com/linkedin/feathr/offline/generation/outputProcessor/WriteToHDFSOutputProcessor.scala +++ 
b/src/main/scala/com/linkedin/feathr/offline/generation/outputProcessor/WriteToHDFSOutputProcessor.scala @@ -4,6 +4,7 @@ import com.linkedin.feathr.offline.util.Transformations.sortColumns import com.linkedin.feathr.common.configObj.generation.OutputProcessorConfig import com.linkedin.feathr.common.exception.{ErrorLabel, FeathrDataOutputException, FeathrException} import com.linkedin.feathr.common.{Header, RichConfig, TaggedFeatureName} +import com.linkedin.feathr.offline.config.location.{DataLocation, SimplePath} import com.linkedin.feathr.offline.generation.{FeatureDataHDFSProcessUtils, FeatureGenerationPathName} import com.linkedin.feathr.offline.util.{FeatureGenConstants, IncrementalAggUtils} import com.linkedin.feathr.offline.source.dataloader.DataLoaderHandler @@ -149,7 +150,25 @@ private[offline] class WriteToHDFSOutputProcessor(val config: OutputProcessorCon // If it's local, we can't write to HDFS. val skipWrite = if (ss.sparkContext.isLocal) true else false - FeatureDataHDFSProcessUtils.processFeatureDataHDFS(ss, featuresToDF, parentPath, config, skipWrite = skipWrite, endTimeOpt, timestampOpt, dataLoaderHandlers) + location match { + case Some(l) => { + // We have a DataLocation to write the df + l.writeDf(ss, augmentedDF) + (augmentedDF, header) + } + case None => { + FeatureDataHDFSProcessUtils.processFeatureDataHDFS(ss, featuresToDF, parentPath, config, skipWrite = skipWrite, endTimeOpt, timestampOpt, dataLoaderHandlers) + } + } + } + + private val location: Option[DataLocation] = { + if (!config.getParams.getStringWithDefault("type", "").isEmpty) { + // The location param contains 'type' key, assuming it's a DataLocation + Some(DataLocation(config.getParams)) + } else { + None + } } // path parameter name diff --git a/src/main/scala/com/linkedin/feathr/offline/job/FeatureJoinJob.scala b/src/main/scala/com/linkedin/feathr/offline/job/FeatureJoinJob.scala index 0288b9451..24ea1fbc5 100644 --- a/src/main/scala/com/linkedin/feathr/offline/job/FeatureJoinJob.scala +++ b/src/main/scala/com/linkedin/feathr/offline/job/FeatureJoinJob.scala @@ -9,6 +9,7 @@ import com.linkedin.feathr.offline._ import com.linkedin.feathr.offline.client._ import com.linkedin.feathr.offline.config.FeatureJoinConfig import com.linkedin.feathr.offline.config.datasource.{DataSourceConfigUtils, DataSourceConfigs} +import com.linkedin.feathr.offline.config.location.{DataLocation, SimplePath} import com.linkedin.feathr.offline.generation.SparkIOUtils import com.linkedin.feathr.offline.source.SourceFormatType import com.linkedin.feathr.offline.source.accessor.DataPathHandler @@ -88,11 +89,18 @@ object FeatureJoinJob { } private def checkAuthorization(ss: SparkSession, hadoopConf: Configuration, jobContext: FeathrJoinJobContext, dataLoaderHandlers: List[DataLoaderHandler]): Unit = { - AclCheckUtils.checkWriteAuthorization(hadoopConf, jobContext.jobJoinContext.outputPath) match { - case Failure(e) => - throw new FeathrDataOutputException(ErrorLabel.FEATHR_USER_ERROR, s"No write permission for output path ${jobContext.jobJoinContext.outputPath}.", e) - case Success(_) => log.debug("Checked write authorization on output path: " + jobContext.jobJoinContext.outputPath) + + jobContext.jobJoinContext.outputPath match { + case SimplePath(path) => { + AclCheckUtils.checkWriteAuthorization(hadoopConf, path) match { + case Failure(e) => + throw new FeathrDataOutputException(ErrorLabel.FEATHR_USER_ERROR, s"No write permission for output path ${jobContext.jobJoinContext.outputPath}.", e) + case Success(_) => 
log.debug("Checked write authorization on output path: " + jobContext.jobJoinContext.outputPath) + } + } + case _ => {} } + jobContext.jobJoinContext.inputData.map(inputData => { val failOnMissing = FeathrUtils.getFeathrJobParam(ss, FeathrUtils.FAIL_ON_MISSING_PARTITION).toBoolean val pathList = getPathList(sourceFormatType=inputData.sourceType, @@ -255,13 +263,13 @@ object FeatureJoinJob { } } - val joinJobContext = { - val feathrLocalConfig = cmdParser.extractOptionalValue("feathr-config") - val feathrFeatureConfig = cmdParser.extractOptionalValue("feature-config") - val localOverrideAll = cmdParser.extractRequiredValue("local-override-all") - val outputPath = cmdParser.extractRequiredValue("output") - val numParts = cmdParser.extractRequiredValue("num-parts").toInt + val feathrLocalConfig = cmdParser.extractOptionalValue("feathr-config") + val feathrFeatureConfig = cmdParser.extractOptionalValue("feature-config") + val localOverrideAll = cmdParser.extractRequiredValue("local-override-all") + val outputPath = DataLocation(cmdParser.extractRequiredValue("output")) + val numParts = cmdParser.extractRequiredValue("num-parts").toInt + val joinJobContext = { JoinJobContext( feathrLocalConfig, feathrFeatureConfig, @@ -359,7 +367,10 @@ object FeatureJoinJob { DataSourceConfigUtils.setupHadoopConf(sparkSession, jobContext.dataSourceConfigs) FeathrUdfRegistry.registerUdf(sparkSession) - HdfsUtils.deletePath(jobContext.jobJoinContext.outputPath, recursive = true, conf) + jobContext.jobJoinContext.outputPath match { + case SimplePath(path) => HdfsUtils.deletePath(path, recursive = true, conf) + case _ => {} + } val enableDebugLog = FeathrUtils.getFeathrJobParam(sparkConf, FeathrUtils.ENABLE_DEBUG_OUTPUT).toBoolean if (enableDebugLog) { diff --git a/src/main/scala/com/linkedin/feathr/offline/job/JoinJobContext.scala b/src/main/scala/com/linkedin/feathr/offline/job/JoinJobContext.scala index bdf506796..de72b7bda 100644 --- a/src/main/scala/com/linkedin/feathr/offline/job/JoinJobContext.scala +++ b/src/main/scala/com/linkedin/feathr/offline/job/JoinJobContext.scala @@ -1,6 +1,7 @@ package com.linkedin.feathr.offline.job import com.linkedin.feathr.offline.client.InputData +import com.linkedin.feathr.offline.config.location.{DataLocation, SimplePath} object JoinJobContext { @@ -23,7 +24,7 @@ case class JoinJobContext( feathrLocalConfig: Option[String] = None, feathrFeatureConfig: Option[String] = None, inputData: Option[InputData] = None, - outputPath: String = "/join_output", + outputPath: DataLocation = SimplePath("/join_output"), numParts: Int = 1 ) { } diff --git a/src/main/scala/com/linkedin/feathr/offline/source/DataSource.scala b/src/main/scala/com/linkedin/feathr/offline/source/DataSource.scala index 4f05a3c28..8c132cf4a 100644 --- a/src/main/scala/com/linkedin/feathr/offline/source/DataSource.scala +++ b/src/main/scala/com/linkedin/feathr/offline/source/DataSource.scala @@ -1,6 +1,6 @@ package com.linkedin.feathr.offline.source -import com.linkedin.feathr.offline.config.location.{InputLocation, SimplePath} +import com.linkedin.feathr.offline.config.location.{DataLocation, SimplePath} import com.linkedin.feathr.offline.source.SourceFormatType.SourceFormatType import com.linkedin.feathr.offline.util.{AclCheckUtils, HdfsUtils, LocalFeatureJoinUtils} import org.apache.hadoop.fs.Path @@ -20,7 +20,7 @@ import scala.util.{Failure, Success, Try} * @param timePartitionPattern format of the time partitioned feature */ private[offline] case class DataSource( - val location: InputLocation, + val location: 
DataLocation, sourceType: SourceFormatType, timeWindowParams: Option[TimeWindowParams], timePartitionPattern: Option[String]) @@ -63,7 +63,7 @@ object DataSource { timeWindowParams: Option[TimeWindowParams] = None, timePartitionPattern: Option[String] = None): DataSource = DataSource(SimplePath(rawPath), sourceType, timeWindowParams, timePartitionPattern) - def apply(inputLocation: InputLocation, + def apply(inputLocation: DataLocation, sourceType: SourceFormatType): DataSource = DataSource(inputLocation, sourceType, None, None) } \ No newline at end of file diff --git a/src/main/scala/com/linkedin/feathr/offline/source/accessor/NonTimeBasedDataSourceAccessor.scala b/src/main/scala/com/linkedin/feathr/offline/source/accessor/NonTimeBasedDataSourceAccessor.scala index 91c6e5665..cebc7d54e 100644 --- a/src/main/scala/com/linkedin/feathr/offline/source/accessor/NonTimeBasedDataSourceAccessor.scala +++ b/src/main/scala/com/linkedin/feathr/offline/source/accessor/NonTimeBasedDataSourceAccessor.scala @@ -1,6 +1,6 @@ package com.linkedin.feathr.offline.source.accessor -import com.linkedin.feathr.offline.config.location.{Jdbc, KafkaEndpoint, PathList, SimplePath} +import com.linkedin.feathr.offline.config.location.{GenericLocation, Jdbc, KafkaEndpoint, PathList, SimplePath} import com.linkedin.feathr.offline.source.DataSource import com.linkedin.feathr.offline.source.dataloader.DataLoaderFactory import com.linkedin.feathr.offline.testfwk.TestFwkUtils @@ -31,6 +31,7 @@ private[offline] class NonTimeBasedDataSourceAccessor( case SimplePath(path) => List(path).map(fileLoaderFactory.create(_).loadDataFrame()).reduce((x, y) => x.fuzzyUnion(y)) case PathList(paths) => paths.map(fileLoaderFactory.create(_).loadDataFrame()).reduce((x, y) => x.fuzzyUnion(y)) case Jdbc(_, _, _, _, _, _, _) => source.location.loadDf(SparkSession.builder().getOrCreate()) + case GenericLocation(_, _) => source.location.loadDf(SparkSession.builder().getOrCreate()) case _ => fileLoaderFactory.createFromLocation(source.location).loadDataFrame() } diff --git a/src/main/scala/com/linkedin/feathr/offline/source/dataloader/BatchDataLoader.scala b/src/main/scala/com/linkedin/feathr/offline/source/dataloader/BatchDataLoader.scala index db430fb4a..63fe275a0 100644 --- a/src/main/scala/com/linkedin/feathr/offline/source/dataloader/BatchDataLoader.scala +++ b/src/main/scala/com/linkedin/feathr/offline/source/dataloader/BatchDataLoader.scala @@ -1,7 +1,7 @@ package com.linkedin.feathr.offline.source.dataloader import com.linkedin.feathr.common.exception.{ErrorLabel, FeathrInputDataException} -import com.linkedin.feathr.offline.config.location.InputLocation +import com.linkedin.feathr.offline.config.location.DataLocation import com.linkedin.feathr.offline.generation.SparkIOUtils import com.linkedin.feathr.offline.job.DataSourceUtils.getSchemaFromAvroDataFile import com.linkedin.feathr.offline.source.dataloader.jdbc.JdbcUtils @@ -16,7 +16,7 @@ import org.apache.spark.sql.{DataFrame, SparkSession} * @param ss the spark session * @param path input data path */ -private[offline] class BatchDataLoader(ss: SparkSession, location: InputLocation, dataLoaderHandlers: List[DataLoaderHandler]) extends DataLoader { +private[offline] class BatchDataLoader(ss: SparkSession, location: DataLocation, dataLoaderHandlers: List[DataLoaderHandler]) extends DataLoader { /** * get the schema of the source. 
It's only used in the deprecated DataSource.getDataSetAndSchema diff --git a/src/main/scala/com/linkedin/feathr/offline/source/dataloader/DataLoaderFactory.scala b/src/main/scala/com/linkedin/feathr/offline/source/dataloader/DataLoaderFactory.scala index d3b94ebfc..057be7e9b 100644 --- a/src/main/scala/com/linkedin/feathr/offline/source/dataloader/DataLoaderFactory.scala +++ b/src/main/scala/com/linkedin/feathr/offline/source/dataloader/DataLoaderFactory.scala @@ -1,6 +1,6 @@ package com.linkedin.feathr.offline.source.dataloader -import com.linkedin.feathr.offline.config.location.InputLocation +import com.linkedin.feathr.offline.config.location.DataLocation import com.linkedin.feathr.offline.source.dataloader.DataLoaderHandler import org.apache.log4j.Logger import org.apache.spark.customized.CustomGenericRowWithSchema @@ -21,7 +21,7 @@ private[offline] trait DataLoaderFactory { */ def create(path: String): DataLoader - def createFromLocation(input: InputLocation): DataLoader = create(input.getPath) + def createFromLocation(input: DataLocation): DataLoader = create(input.getPath) } private[offline] object DataLoaderFactory { diff --git a/src/main/scala/com/linkedin/feathr/offline/source/dataloader/LocalDataLoaderFactory.scala b/src/main/scala/com/linkedin/feathr/offline/source/dataloader/LocalDataLoaderFactory.scala index 4d2fe505f..c6f007333 100644 --- a/src/main/scala/com/linkedin/feathr/offline/source/dataloader/LocalDataLoaderFactory.scala +++ b/src/main/scala/com/linkedin/feathr/offline/source/dataloader/LocalDataLoaderFactory.scala @@ -1,6 +1,6 @@ package com.linkedin.feathr.offline.source.dataloader -import com.linkedin.feathr.offline.config.location.{InputLocation, KafkaEndpoint, SimplePath} +import com.linkedin.feathr.offline.config.location.{DataLocation, KafkaEndpoint, SimplePath} import com.linkedin.feathr.offline.source.dataloader.stream.KafkaDataLoader import com.linkedin.feathr.offline.source.dataloader.DataLoaderHandler import com.linkedin.feathr.offline.util.LocalFeatureJoinUtils @@ -58,7 +58,7 @@ dataLoaderHandlers: List[DataLoaderHandler]) extends DataLoaderFactory { } } - override def createFromLocation(inputLocation: InputLocation): DataLoader = { + override def createFromLocation(inputLocation: DataLocation): DataLoader = { if (inputLocation.isInstanceOf[KafkaEndpoint]) { new KafkaDataLoader(ss, inputLocation.asInstanceOf[KafkaEndpoint]) } else { diff --git a/src/main/scala/com/linkedin/feathr/offline/source/dataloader/StreamingDataLoaderFactory.scala b/src/main/scala/com/linkedin/feathr/offline/source/dataloader/StreamingDataLoaderFactory.scala index 2eafde55b..279cd04a8 100644 --- a/src/main/scala/com/linkedin/feathr/offline/source/dataloader/StreamingDataLoaderFactory.scala +++ b/src/main/scala/com/linkedin/feathr/offline/source/dataloader/StreamingDataLoaderFactory.scala @@ -1,6 +1,6 @@ package com.linkedin.feathr.offline.source.dataloader -import com.linkedin.feathr.offline.config.location.{InputLocation, KafkaEndpoint} +import com.linkedin.feathr.offline.config.location.{DataLocation, KafkaEndpoint} import com.linkedin.feathr.offline.source.dataloader.stream.KafkaDataLoader import org.apache.spark.sql.SparkSession @@ -17,7 +17,7 @@ private[offline] class StreamingDataLoaderFactory(ss: SparkSession) extends Dat * @param input the input location for streaming * @return a [[DataLoader]] */ - override def createFromLocation(input: InputLocation): DataLoader = new KafkaDataLoader(ss, input.asInstanceOf[KafkaEndpoint]) + override def createFromLocation(input: 
DataLocation): DataLoader = new KafkaDataLoader(ss, input.asInstanceOf[KafkaEndpoint]) /** * create a data loader based on the file type. diff --git a/src/main/scala/com/linkedin/feathr/offline/util/SourceUtils.scala b/src/main/scala/com/linkedin/feathr/offline/util/SourceUtils.scala index 95838d1a0..cb85a2f61 100644 --- a/src/main/scala/com/linkedin/feathr/offline/util/SourceUtils.scala +++ b/src/main/scala/com/linkedin/feathr/offline/util/SourceUtils.scala @@ -7,7 +7,7 @@ import com.jasonclawson.jackson.dataformat.hocon.HoconFactory import com.linkedin.feathr.common.exception._ import com.linkedin.feathr.common.{AnchorExtractor, DateParam} import com.linkedin.feathr.offline.client.InputData -import com.linkedin.feathr.offline.config.location.{InputLocation, SimplePath} +import com.linkedin.feathr.offline.config.location.{DataLocation, SimplePath} import com.linkedin.feathr.offline.generation.SparkIOUtils import com.linkedin.feathr.offline.mvel.{MvelContext, MvelUtils} import com.linkedin.feathr.offline.source.SourceFormatType @@ -175,7 +175,7 @@ private[offline] object SourceUtils { def safeWriteDF(df: DataFrame, dataPath: String, parameters: Map[String, String], dataLoaderHandlers: List[DataLoaderHandler]): Unit = { val tempBasePath = dataPath.stripSuffix("/") + "_temp_" HdfsUtils.deletePath(dataPath, true) - SparkIOUtils.writeDataFrame(df, tempBasePath, parameters, dataLoaderHandlers) + SparkIOUtils.writeDataFrame(df, SimplePath(tempBasePath), parameters, dataLoaderHandlers) if (HdfsUtils.exists(tempBasePath) && !HdfsUtils.renamePath(tempBasePath, dataPath)) { throw new FeathrDataOutputException( ErrorLabel.FEATHR_ERROR, @@ -644,8 +644,8 @@ private[offline] object SourceUtils { * @param inputPath * @return */ - def loadAsDataFrame(ss: SparkSession, location: InputLocation, - dataLoaderHandlers: List[DataLoaderHandler]): DataFrame = { + def loadAsDataFrame(ss: SparkSession, location: DataLocation, + dataLoaderHandlers: List[DataLoaderHandler]): DataFrame = { val sparkConf = ss.sparkContext.getConf val inputSplitSize = sparkConf.get("spark.feathr.input.split.size", "") val dataIOParameters = Map(SparkIOUtils.SPLIT_SIZE -> inputSplitSize) diff --git a/src/test/scala/com/linkedin/feathr/offline/AnchoredFeaturesIntegTest.scala b/src/test/scala/com/linkedin/feathr/offline/AnchoredFeaturesIntegTest.scala index caf680564..061b42598 100644 --- a/src/test/scala/com/linkedin/feathr/offline/AnchoredFeaturesIntegTest.scala +++ b/src/test/scala/com/linkedin/feathr/offline/AnchoredFeaturesIntegTest.scala @@ -2,6 +2,7 @@ package com.linkedin.feathr.offline import com.linkedin.feathr.common.configObj.configbuilder.ConfigBuilderException import com.linkedin.feathr.common.exception.FeathrConfigException +import com.linkedin.feathr.offline.config.location.SimplePath import com.linkedin.feathr.offline.generation.SparkIOUtils import com.linkedin.feathr.offline.job.PreprocessedDataFrameManager import com.linkedin.feathr.offline.source.dataloader.{AvroJsonDataLoader, CsvDataLoader} @@ -142,7 +143,7 @@ class AnchoredFeaturesIntegTest extends FeathrIntegTest { // create a data source from anchorAndDerivations/nullValueSource.avro.json val df = new AvroJsonDataLoader(ss, "nullValueSource.avro.json").loadDataFrame() - SparkIOUtils.writeDataFrame(df, mockDataFolder + "/nullValueSource", parameters=Map(), dataLoaderHandlers=List()) + SparkIOUtils.writeDataFrame(df, SimplePath(mockDataFolder + "/nullValueSource"), parameters=Map(), dataLoaderHandlers=List()) } /** diff --git 
a/src/test/scala/com/linkedin/feathr/offline/config/location/TestDesLocation.scala b/src/test/scala/com/linkedin/feathr/offline/config/location/TestDesLocation.scala index 5d5f80998..dcee1a909 100644 --- a/src/test/scala/com/linkedin/feathr/offline/config/location/TestDesLocation.scala +++ b/src/test/scala/com/linkedin/feathr/offline/config/location/TestDesLocation.scala @@ -7,7 +7,7 @@ import com.jasonclawson.jackson.dataformat.hocon.HoconFactory import com.linkedin.feathr.common.FeathrJacksonScalaModule import com.linkedin.feathr.offline.config.DataSourceLoader import com.linkedin.feathr.offline.config.location.LocationUtils.envSubstitute -import com.linkedin.feathr.offline.config.location.{InputLocation, Jdbc, SimplePath} +import com.linkedin.feathr.offline.config.location.{DataLocation, Jdbc, SimplePath} import com.linkedin.feathr.offline.generation.SparkIOUtils import com.linkedin.feathr.offline.source.DataSource import org.apache.spark.{SparkConf, SparkContext} @@ -32,7 +32,7 @@ class TestDesLocation extends FunSuite { test("Deserialize Location") { { val configDoc = """{ path: "abfss://feathrazuretest3fs@feathrazuretest3storage.dfs.core.windows.net/demo_data/green_tripdata_2020-04.csv" }""" - val ds = jackson.readValue(configDoc, classOf[InputLocation]) + val ds = jackson.readValue(configDoc, classOf[DataLocation]) ds match { case SimplePath(path) => { assert(path == "abfss://feathrazuretest3fs@feathrazuretest3storage.dfs.core.windows.net/demo_data/green_tripdata_2020-04.csv") @@ -50,7 +50,7 @@ class TestDesLocation extends FunSuite { | user: "bar" | password: "foo" |}""".stripMargin - val ds = jackson.readValue(configDoc, classOf[InputLocation]) + val ds = jackson.readValue(configDoc, classOf[DataLocation]) ds match { case Jdbc(url, dbtable, user, password, token, useToken, _) => { assert(url == "jdbc:sqlserver://myserver.database.windows.net:1433;database=mydatabase") @@ -69,7 +69,7 @@ class TestDesLocation extends FunSuite { | type: "pathlist" | paths: ["abfss://feathrazuretest3fs@feathrazuretest3storage.dfs.core.windows.net/demo_data/green_tripdata_2020-04.csv"] |}""".stripMargin - val ds = jackson.readValue(configDoc, classOf[InputLocation]) + val ds = jackson.readValue(configDoc, classOf[DataLocation]) ds match { case PathList(pathList) => { assert(pathList == List("abfss://feathrazuretest3fs@feathrazuretest3storage.dfs.core.windows.net/demo_data/green_tripdata_2020-04.csv")) @@ -88,7 +88,7 @@ class TestDesLocation extends FunSuite { | dbtable: "table1" | anonymous: true |}""".stripMargin - val ds = jackson.readValue(configDoc, classOf[InputLocation]) + val ds = jackson.readValue(configDoc, classOf[DataLocation]) val _ = SparkSession.builder().config("spark.master", "local").appName("Sqlite test").getOrCreate() @@ -112,7 +112,31 @@ class TestDesLocation extends FunSuite { | query: "select c1, c2 from table1" | anonymous: true |}""".stripMargin - val ds = jackson.readValue(configDoc, classOf[InputLocation]) + val ds = jackson.readValue(configDoc, classOf[DataLocation]) + + val _ = SparkSession.builder().config("spark.master", "local").appName("Sqlite test").getOrCreate() + + val df = SparkIOUtils.createDataFrame(ds, Map(), new JobConf(), List()) + val rows = df.head(3) + assert(rows(0).getLong(0) == 1) + assert(rows(1).getLong(0) == 2) + assert(rows(2).getLong(0) == 3) + assert(rows(0).getString(1) == "r1c2") + assert(rows(1).getString(1) == "r2c2") + assert(rows(2).getString(1) == "r3c2") + } + + test("Test GenericLocation load Sqlite with SQL query") { + val path = 
s"${System.getProperty("user.dir")}/src/test/resources/mockdata/sqlite/test.db" + val configDoc = + s""" + |{ + | type: "generic" + | format: "jdbc" + | url: "jdbc:sqlite:${path}" + | query: "select c1, c2 from table1" + |}""".stripMargin + val ds = jackson.readValue(configDoc, classOf[DataLocation]) val _ = SparkSession.builder().config("spark.master", "local").appName("Sqlite test").getOrCreate() From 94c26b015acbe8b4da5104df55c2c49692f81ba6 Mon Sep 17 00:00:00 2001 From: Chen Xu Date: Fri, 15 Jul 2022 12:20:51 +0800 Subject: [PATCH 02/10] WIP --- build.sbt | 4 +- .../config/location/GenericLocation.scala | 49 +++++++++++++++---- .../NonTimeBasedDataSourceAccessor.scala | 2 +- 3 files changed, 43 insertions(+), 12 deletions(-) diff --git a/build.sbt b/build.sbt index a448478f6..b9e747117 100644 --- a/build.sbt +++ b/build.sbt @@ -44,7 +44,9 @@ val localAndCloudCommonDependencies = Seq( "net.snowflake" % "spark-snowflake_2.12" % "2.10.0-spark_3.2", "org.apache.commons" % "commons-lang3" % "3.12.0", "org.xerial" % "sqlite-jdbc" % "3.36.0.3", - "com.github.changvvb" %% "jackson-module-caseclass" % "1.1.1" + "com.github.changvvb" %% "jackson-module-caseclass" % "1.1.1", + "com.azure.cosmos.spark" % "azure-cosmos-spark_3-1_2-12" % "4.11.1", + "org.eclipse.jetty" % "jetty-util" % "9.3.24.v20180605" ) // Common deps val jdbcDrivers = Seq( diff --git a/src/main/scala/com/linkedin/feathr/offline/config/location/GenericLocation.scala b/src/main/scala/com/linkedin/feathr/offline/config/location/GenericLocation.scala index 12d473cb7..a42140dcf 100644 --- a/src/main/scala/com/linkedin/feathr/offline/config/location/GenericLocation.scala +++ b/src/main/scala/com/linkedin/feathr/offline/config/location/GenericLocation.scala @@ -2,11 +2,16 @@ package com.linkedin.feathr.offline.config.location import com.fasterxml.jackson.annotation.JsonAnySetter import com.fasterxml.jackson.module.caseclass.annotation.CaseClassDeserialize +import com.linkedin.feathr.common.exception.FeathrException import net.minidev.json.annotate.JsonIgnore import org.apache.spark.sql.{DataFrame, SparkSession} @CaseClassDeserialize() -case class GenericLocation(format: String, @JsonIgnore options: collection.mutable.Map[String, String] = collection.mutable.Map[String, String]()) extends DataLocation { +case class GenericLocation(format: String, + mode: Option[String] = None, + @JsonIgnore options: collection.mutable.Map[String, String] = collection.mutable.Map[String, String](), + @JsonIgnore conf: collection.mutable.Map[String, String] = collection.mutable.Map[String, String]() + ) extends DataLocation { /** * Backward Compatibility * Many existing codes expect a simple path @@ -35,8 +40,11 @@ case class GenericLocation(format: String, @JsonIgnore options: collection.mutab * @return */ override def loadDf(ss: SparkSession, dataIOParameters: Map[String, String]): DataFrame = { + conf.foreach(e => { + ss.conf.set(e._1, e._2) + }) ss.read.format(format) - .options(getOptions) + .options(options) .load() } @@ -47,9 +55,30 @@ case class GenericLocation(format: String, @JsonIgnore options: collection.mutab * @param df DataFrame to write */ override def writeDf(ss: SparkSession, df: DataFrame): Unit = { - df.write.format(format) - .options(getOptions) - .save() + conf.foreach(e => { + ss.conf.set(e._1, e._2) + }) + val keyDf = if (!df.columns.contains("id")) { + if(df.columns.contains("key0")) { + df.withColumnRenamed("key0", "id") + } else { + throw new FeathrException("DataFrame doesn't have id column") + } + } else { + df + } + val w = 
mode match { + case Some(m) => { + keyDf.write.format(format) + .options(options) + .mode(m) + } + case None => { + keyDf.write.format(format) + .options(options) + } + } + w.save() } /** @@ -59,12 +88,12 @@ case class GenericLocation(format: String, @JsonIgnore options: collection.mutab */ override def isFileBasedLocation(): Boolean = false - def getOptions(): Map[String, String] = { - options.map(e => e._1 -> LocationUtils.envSubstitute(e._2)).toMap - } - @JsonAnySetter def setOption(key: String, value: Any) = { - options += (key -> value.toString) + if (key.startsWith("__conf__")) { + conf += (key.stripPrefix("__conf__").replace("__", ".") -> LocationUtils.envSubstitute(value.toString)) + } else { + options += (key.replace("__", ".") -> LocationUtils.envSubstitute(value.toString)) + } } } diff --git a/src/main/scala/com/linkedin/feathr/offline/source/accessor/NonTimeBasedDataSourceAccessor.scala b/src/main/scala/com/linkedin/feathr/offline/source/accessor/NonTimeBasedDataSourceAccessor.scala index cebc7d54e..cb311ce6f 100644 --- a/src/main/scala/com/linkedin/feathr/offline/source/accessor/NonTimeBasedDataSourceAccessor.scala +++ b/src/main/scala/com/linkedin/feathr/offline/source/accessor/NonTimeBasedDataSourceAccessor.scala @@ -31,7 +31,7 @@ private[offline] class NonTimeBasedDataSourceAccessor( case SimplePath(path) => List(path).map(fileLoaderFactory.create(_).loadDataFrame()).reduce((x, y) => x.fuzzyUnion(y)) case PathList(paths) => paths.map(fileLoaderFactory.create(_).loadDataFrame()).reduce((x, y) => x.fuzzyUnion(y)) case Jdbc(_, _, _, _, _, _, _) => source.location.loadDf(SparkSession.builder().getOrCreate()) - case GenericLocation(_, _) => source.location.loadDf(SparkSession.builder().getOrCreate()) + case GenericLocation(_, _, _, _) => source.location.loadDf(SparkSession.builder().getOrCreate()) case _ => fileLoaderFactory.createFromLocation(source.location).loadDataFrame() } From fb2c4e9db521f3d444a8b6a7b37d18395d4df1c8 Mon Sep 17 00:00:00 2001 From: Chen Xu Date: Fri, 15 Jul 2022 16:08:30 +0800 Subject: [PATCH 03/10] Generate id column --- .../config/location/DataLocation.scala | 5 +- .../config/location/GenericLocation.scala | 110 +++++++++++++----- .../feathr/offline/config/location/Jdbc.scala | 3 +- .../config/location/KafkaEndpoint.scala | 3 +- .../offline/config/location/PathList.scala | 3 +- .../offline/config/location/SimplePath.scala | 3 +- .../offline/generation/SparkIOUtils.scala | 2 +- .../WriteToHDFSOutputProcessor.scala | 2 +- 8 files changed, 91 insertions(+), 40 deletions(-) diff --git a/src/main/scala/com/linkedin/feathr/offline/config/location/DataLocation.scala b/src/main/scala/com/linkedin/feathr/offline/config/location/DataLocation.scala index 781997e1f..4aa2bc760 100644 --- a/src/main/scala/com/linkedin/feathr/offline/config/location/DataLocation.scala +++ b/src/main/scala/com/linkedin/feathr/offline/config/location/DataLocation.scala @@ -6,11 +6,12 @@ import com.fasterxml.jackson.databind.module.SimpleModule import com.fasterxml.jackson.databind.{DeserializationFeature, ObjectMapper} import com.fasterxml.jackson.module.caseclass.mapper.CaseClassObjectMapper import com.jasonclawson.jackson.dataformat.hocon.HoconFactory -import com.linkedin.feathr.common.FeathrJacksonScalaModule +import com.linkedin.feathr.common.{FeathrJacksonScalaModule, Header} import com.linkedin.feathr.offline.config.DataSourceLoader import com.linkedin.feathr.offline.source.DataSource import com.typesafe.config.Config import org.apache.spark.sql.{DataFrame, SparkSession} + import 
scala.collection.JavaConverters._ /** @@ -62,7 +63,7 @@ trait DataLocation { * @param ss SparkSession * @param df DataFrame to write */ - def writeDf(ss: SparkSession, df: DataFrame) + def writeDf(ss: SparkSession, df: DataFrame, header: Option[Header]) /** * Tell if this location is file based diff --git a/src/main/scala/com/linkedin/feathr/offline/config/location/GenericLocation.scala b/src/main/scala/com/linkedin/feathr/offline/config/location/GenericLocation.scala index a42140dcf..c31810ec2 100644 --- a/src/main/scala/com/linkedin/feathr/offline/config/location/GenericLocation.scala +++ b/src/main/scala/com/linkedin/feathr/offline/config/location/GenericLocation.scala @@ -2,9 +2,12 @@ package com.linkedin.feathr.offline.config.location import com.fasterxml.jackson.annotation.JsonAnySetter import com.fasterxml.jackson.module.caseclass.annotation.CaseClassDeserialize +import com.linkedin.feathr.common.Header import com.linkedin.feathr.common.exception.FeathrException +import com.linkedin.feathr.offline.generation.FeatureGenUtils import net.minidev.json.annotate.JsonIgnore -import org.apache.spark.sql.{DataFrame, SparkSession} +import org.apache.spark.sql.functions.monotonically_increasing_id +import org.apache.spark.sql.{DataFrame, DataFrameWriter, Row, SparkSession} @CaseClassDeserialize() case class GenericLocation(format: String, @@ -40,12 +43,7 @@ case class GenericLocation(format: String, * @return */ override def loadDf(ss: SparkSession, dataIOParameters: Map[String, String]): DataFrame = { - conf.foreach(e => { - ss.conf.set(e._1, e._2) - }) - ss.read.format(format) - .options(options) - .load() + GenericLocationFixes.readDf(ss, this) } /** @@ -54,31 +52,8 @@ case class GenericLocation(format: String, * @param ss SparkSession * @param df DataFrame to write */ - override def writeDf(ss: SparkSession, df: DataFrame): Unit = { - conf.foreach(e => { - ss.conf.set(e._1, e._2) - }) - val keyDf = if (!df.columns.contains("id")) { - if(df.columns.contains("key0")) { - df.withColumnRenamed("key0", "id") - } else { - throw new FeathrException("DataFrame doesn't have id column") - } - } else { - df - } - val w = mode match { - case Some(m) => { - keyDf.write.format(format) - .options(options) - .mode(m) - } - case None => { - keyDf.write.format(format) - .options(options) - } - } - w.save() + override def writeDf(ss: SparkSession, df: DataFrame, header: Option[Header]): Unit = { + GenericLocationFixes.writeDf(ss, df, header, this) } /** @@ -97,3 +72,74 @@ case class GenericLocation(format: String, } } } + +/** + * Some Spark connectors need extra actions before read or write, namely CosmosDb and ElasticSearch + * Need to run specific fixes base on `format` + */ +object GenericLocationFixes { + def readDf(ss: SparkSession, location: GenericLocation): DataFrame = { + location.conf.foreach(e => { + ss.conf.set(e._1, e._2) + }) + ss.read.format(location.format) + .options(location.options) + .load() + } + + def writeDf(ss: SparkSession, df: DataFrame, header: Option[Header], location: GenericLocation) = { + location.conf.foreach(e => { + ss.conf.set(e._1, e._2) + }) + + location.format.toLowerCase() match { + case "cosmos.oltp" => + // Ensure the database and the table exist before writing + val endpoint = location.options.getOrElse("spark.cosmos.accountEndpoint", throw new FeathrException("Missing spark__cosmos__accountEndpoint")) + val key = location.options.getOrElse("spark.cosmos.accountKey", throw new FeathrException("Missing spark__cosmos__accountKey")) + val databaseName = 
location.options.getOrElse("spark.cosmos.database", throw new FeathrException("Missing spark__cosmos__database")) + val tableName = location.options.getOrElse("spark.cosmos.container", throw new FeathrException("Missing spark__cosmos__container")) + ss.conf.set("spark.sql.catalog.cosmosCatalog", "com.azure.cosmos.spark.CosmosCatalog") + ss.conf.set("spark.sql.catalog.cosmosCatalog.spark.cosmos.accountEndpoint", endpoint) + ss.conf.set("spark.sql.catalog.cosmosCatalog.spark.cosmos.accountKey", key) + ss.sql(s"CREATE DATABASE IF NOT EXISTS cosmosCatalog.${databaseName};") + ss.sql(s"CREATE TABLE IF NOT EXISTS cosmosCatalog.${databaseName}.${tableName} using cosmos.oltp TBLPROPERTIES(partitionKeyPath = '/id')") + + // CosmosDb requires the column `id` to exist and be the primary key + val keyDf = if (!df.columns.contains("id")) { + header match { + case Some(h) => { + // We have the header info, copy the 1st key column to `id`, which is required by CosmosDb + val key = FeatureGenUtils.getKeyColumnsFromHeader(h).head + // Copy key column to `id` + df.withColumn("id", df.col(key)) + } + case None => { + // If there is no key column, we use a auto-generated monotonic id. + // but in this case the result could be duplicated if you run job for multiple times + // This function is for offline-storage usage, ideally user should create a new container for every run + df.withColumn("id", monotonically_increasing_id()) + } + } + } else { + // We already have an `id` column + // TODO: Should we do anything here? + // A corner case is that the `id` column exists but not unique, then the output will be incomplete as + // CosmosDb will overwrite the old entry with the new one with same `id`. + // We can either rename the existing `id` column and use header/autogen key column, or we can tell user + // to avoid using `id` column for non-unique data, but both workarounds have pros and cons. 
+ df + } + keyDf.write.format(location.format) + .options(location.options) + .mode(location.mode.getOrElse("append")) // CosmosDb doesn't support ErrorIfExist mode in batch mode + .save() + case _ => + // Normal writing procedure, just set format and options then write + df.write.format(location.format) + .options(location.options) + .mode(location.mode.getOrElse("default")) + .save() + } + } +} diff --git a/src/main/scala/com/linkedin/feathr/offline/config/location/Jdbc.scala b/src/main/scala/com/linkedin/feathr/offline/config/location/Jdbc.scala index 0d70b1a48..bca5a887c 100644 --- a/src/main/scala/com/linkedin/feathr/offline/config/location/Jdbc.scala +++ b/src/main/scala/com/linkedin/feathr/offline/config/location/Jdbc.scala @@ -2,6 +2,7 @@ package com.linkedin.feathr.offline.config.location import com.fasterxml.jackson.annotation.JsonAlias import com.fasterxml.jackson.module.caseclass.annotation.CaseClassDeserialize +import com.linkedin.feathr.common.Header import com.linkedin.feathr.offline.source.dataloader.jdbc.JdbcUtils import com.linkedin.feathr.offline.source.dataloader.jdbc.JdbcUtils.DBTABLE_CONF import org.apache.spark.sql.{DataFrame, SparkSession} @@ -47,7 +48,7 @@ case class Jdbc(url: String, @JsonAlias(Array("query")) dbtable: String, user: S } } - override def writeDf(ss: SparkSession, df: DataFrame): Unit = { + override def writeDf(ss: SparkSession, df: DataFrame, header: Option[Header]): Unit = { println(s"Jdbc.writeDf, location is ${this}") if (StringUtil.isBlank(user) && StringUtil.isBlank(password) && !anonymous && !useToken) { // Fallback to global JDBC credential diff --git a/src/main/scala/com/linkedin/feathr/offline/config/location/KafkaEndpoint.scala b/src/main/scala/com/linkedin/feathr/offline/config/location/KafkaEndpoint.scala index e3c3a7298..f3818700b 100644 --- a/src/main/scala/com/linkedin/feathr/offline/config/location/KafkaEndpoint.scala +++ b/src/main/scala/com/linkedin/feathr/offline/config/location/KafkaEndpoint.scala @@ -1,6 +1,7 @@ package com.linkedin.feathr.offline.config.location import com.fasterxml.jackson.module.caseclass.annotation.CaseClassDeserialize +import com.linkedin.feathr.common.Header import org.apache.spark.sql.{DataFrame, SparkSession} import org.codehaus.jackson.annotate.JsonProperty @@ -31,7 +32,7 @@ case class KafkaEndpoint(@JsonProperty("brokers") brokers: List[String], @JsonProperty("schema") schema: KafkaSchema) extends DataLocation { override def loadDf(ss: SparkSession, dataIOParameters: Map[String, String] = Map()): DataFrame = ??? - override def writeDf(ss: SparkSession, df: DataFrame): Unit = ??? + override def writeDf(ss: SparkSession, df: DataFrame, header: Option[Header]): Unit = ??? 
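
Taken on its own, the cosmos.oltp branch of GenericLocationFixes.writeDf above reduces to a small piece of DataFrame preparation before the save. A standalone sketch of that preparation is below; ensureIdColumn and keyColumns are illustrative names, the real code derives the keys from the feature Header via FeatureGenUtils.getKeyColumnsFromHeader, and later commits in this series refine the key handling further.

    import org.apache.spark.sql.DataFrame
    import org.apache.spark.sql.functions.{col, monotonically_increasing_id}

    // Sketch only: make sure the `id` column the cosmos.oltp connector expects exists.
    // `keyColumns` stands in for the key names taken from the feature Header.
    def ensureIdColumn(df: DataFrame, keyColumns: Seq[String]): DataFrame =
      if (df.columns.contains("id")) {
        df                                                  // caller already provided an id column
      } else if (keyColumns.nonEmpty) {
        df.withColumn("id", col(keyColumns.head))           // copy the first key column into `id`
      } else {
        df.withColumn("id", monotonically_increasing_id())  // non-deterministic fallback id
      }
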
override def getPath: String = "kafka://" + brokers.mkString(",")+":"+topics.mkString(",") diff --git a/src/main/scala/com/linkedin/feathr/offline/config/location/PathList.scala b/src/main/scala/com/linkedin/feathr/offline/config/location/PathList.scala index 2f16091ee..c583de8a3 100644 --- a/src/main/scala/com/linkedin/feathr/offline/config/location/PathList.scala +++ b/src/main/scala/com/linkedin/feathr/offline/config/location/PathList.scala @@ -1,5 +1,6 @@ package com.linkedin.feathr.offline.config.location +import com.linkedin.feathr.common.Header import com.linkedin.feathr.offline.generation.SparkIOUtils import org.apache.spark.sql.{DataFrame, SparkSession} import org.apache.hadoop.mapred.JobConf @@ -15,7 +16,7 @@ case class PathList(paths: List[String]) extends DataLocation { SparkIOUtils.createUnionDataFrame(getPathList, dataIOParameters, new JobConf(), List()) //TODO: Add handler support here. Currently there are deserilization issues with adding handlers to factory builder. } - override def writeDf(ss: SparkSession, df: DataFrame): Unit = ??? + override def writeDf(ss: SparkSession, df: DataFrame, header: Option[Header]): Unit = ??? override def toString: String = s"PathList(path=[${paths.mkString(",")}])" } diff --git a/src/main/scala/com/linkedin/feathr/offline/config/location/SimplePath.scala b/src/main/scala/com/linkedin/feathr/offline/config/location/SimplePath.scala index e94e2211b..d2d1e2db6 100644 --- a/src/main/scala/com/linkedin/feathr/offline/config/location/SimplePath.scala +++ b/src/main/scala/com/linkedin/feathr/offline/config/location/SimplePath.scala @@ -1,6 +1,7 @@ package com.linkedin.feathr.offline.config.location import com.fasterxml.jackson.module.caseclass.annotation.CaseClassDeserialize +import com.linkedin.feathr.common.Header import com.linkedin.feathr.offline.generation.SparkIOUtils import org.apache.hadoop.mapred.JobConf import org.apache.spark.sql.{DataFrame, SparkSession} @@ -12,7 +13,7 @@ case class SimplePath(@JsonProperty("path") path: String) extends DataLocation { SparkIOUtils.createUnionDataFrame(getPathList, dataIOParameters, new JobConf(), List()) // The simple path is not responsible for handling custom data loaders. } - override def writeDf(ss: SparkSession, df: DataFrame): Unit = ??? + override def writeDf(ss: SparkSession, df: DataFrame, header: Option[Header]): Unit = ??? 
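
With the hunks above, every DataLocation implementation now receives the optional feature Header on write. A minimal custom location satisfying the updated trait could look roughly like the sketch below; the member names and signatures are inferred from their use in the diffs above, and the CsvFolder class itself is illustrative, not part of this patch.

    import com.linkedin.feathr.common.Header
    import org.apache.spark.sql.{DataFrame, SparkSession}

    // Hypothetical file-based location, for illustration only.
    case class CsvFolder(path: String) extends DataLocation {
      override def loadDf(ss: SparkSession, dataIOParameters: Map[String, String] = Map()): DataFrame =
        ss.read.option("header", "true").csv(path)

      // The Header carries key-column metadata; file-based sinks can ignore it,
      // as SimplePath and PathList do above.
      override def writeDf(ss: SparkSession, df: DataFrame, header: Option[Header]): Unit =
        df.write.mode("overwrite").option("header", "true").csv(path)

      override def getPath: String = path
      override def getPathList: List[String] = List(path)
      override def isFileBasedLocation(): Boolean = true
    }
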
override def getPath: String = path diff --git a/src/main/scala/com/linkedin/feathr/offline/generation/SparkIOUtils.scala b/src/main/scala/com/linkedin/feathr/offline/generation/SparkIOUtils.scala index e7c912f17..3d3052c41 100644 --- a/src/main/scala/com/linkedin/feathr/offline/generation/SparkIOUtils.scala +++ b/src/main/scala/com/linkedin/feathr/offline/generation/SparkIOUtils.scala @@ -77,7 +77,7 @@ object SparkIOUtils { outputDF.write.mode(SaveMode.Overwrite).format(output_format).save(path) outputDF } - case _ => outputLocation.writeDf(SparkSession.builder().getOrCreate(), outputDF) + case _ => outputLocation.writeDf(SparkSession.builder().getOrCreate(), outputDF, None) } } outputDF diff --git a/src/main/scala/com/linkedin/feathr/offline/generation/outputProcessor/WriteToHDFSOutputProcessor.scala b/src/main/scala/com/linkedin/feathr/offline/generation/outputProcessor/WriteToHDFSOutputProcessor.scala index 0d0438397..2b86642ad 100644 --- a/src/main/scala/com/linkedin/feathr/offline/generation/outputProcessor/WriteToHDFSOutputProcessor.scala +++ b/src/main/scala/com/linkedin/feathr/offline/generation/outputProcessor/WriteToHDFSOutputProcessor.scala @@ -153,7 +153,7 @@ private[offline] class WriteToHDFSOutputProcessor(val config: OutputProcessorCon location match { case Some(l) => { // We have a DataLocation to write the df - l.writeDf(ss, augmentedDF) + l.writeDf(ss, augmentedDF, Some(header)) (augmentedDF, header) } case None => { From 8d4d5ce753b86f8d5e091d78e5fdcd76cf557815 Mon Sep 17 00:00:00 2001 From: Chen Xu Date: Fri, 15 Jul 2022 17:11:50 +0800 Subject: [PATCH 04/10] Fix unit test --- .../feathr/offline/config/location/DataLocation.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/main/scala/com/linkedin/feathr/offline/config/location/DataLocation.scala b/src/main/scala/com/linkedin/feathr/offline/config/location/DataLocation.scala index 4aa2bc760..c85136eb7 100644 --- a/src/main/scala/com/linkedin/feathr/offline/config/location/DataLocation.scala +++ b/src/main/scala/com/linkedin/feathr/offline/config/location/DataLocation.scala @@ -9,7 +9,7 @@ import com.jasonclawson.jackson.dataformat.hocon.HoconFactory import com.linkedin.feathr.common.{FeathrJacksonScalaModule, Header} import com.linkedin.feathr.offline.config.DataSourceLoader import com.linkedin.feathr.offline.source.DataSource -import com.typesafe.config.Config +import com.typesafe.config.{Config, ConfigException} import org.apache.spark.sql.{DataFrame, SparkSession} import scala.collection.JavaConverters._ @@ -121,7 +121,7 @@ object DataLocation { val location = jackson.readValue(cfg, classOf[DataLocation]) location } catch { - case _: JacksonException => SimplePath(cfg) + case _ @ (_: ConfigException | _: JacksonException) => SimplePath(cfg) } } From 278d3cb2de3e4417d94de3ddd503fdcc4020f54c Mon Sep 17 00:00:00 2001 From: Chen Xu Date: Fri, 15 Jul 2022 18:16:15 +0800 Subject: [PATCH 05/10] Parse string into DataLocation --- .../feathr/offline/config/location/DataLocation.scala | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/main/scala/com/linkedin/feathr/offline/config/location/DataLocation.scala b/src/main/scala/com/linkedin/feathr/offline/config/location/DataLocation.scala index c85136eb7..71a1c4191 100644 --- a/src/main/scala/com/linkedin/feathr/offline/config/location/DataLocation.scala +++ b/src/main/scala/com/linkedin/feathr/offline/config/location/DataLocation.scala @@ -118,8 +118,13 @@ object DataLocation { 
.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false) .registerModule(new SimpleModule().addDeserializer(classOf[DataSource], new DataSourceLoader)) try { - val location = jackson.readValue(cfg, classOf[DataLocation]) - location + // Cfg is either a plain path or a JSON object + if (cfg.trim.startsWith("{")) { + val location = jackson.readValue(cfg, classOf[DataLocation]) + location + } else { + SimplePath(cfg) + } } catch { case _ @ (_: ConfigException | _: JacksonException) => SimplePath(cfg) } From 10bd7f2cf68a40232a50f9cf81bdb0c2e22c747e Mon Sep 17 00:00:00 2001 From: Chen Xu Date: Fri, 15 Jul 2022 20:55:18 +0800 Subject: [PATCH 06/10] Id column must be string --- .../feathr/offline/config/location/GenericLocation.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/main/scala/com/linkedin/feathr/offline/config/location/GenericLocation.scala b/src/main/scala/com/linkedin/feathr/offline/config/location/GenericLocation.scala index c31810ec2..2661f3e1e 100644 --- a/src/main/scala/com/linkedin/feathr/offline/config/location/GenericLocation.scala +++ b/src/main/scala/com/linkedin/feathr/offline/config/location/GenericLocation.scala @@ -105,20 +105,20 @@ object GenericLocationFixes { ss.sql(s"CREATE DATABASE IF NOT EXISTS cosmosCatalog.${databaseName};") ss.sql(s"CREATE TABLE IF NOT EXISTS cosmosCatalog.${databaseName}.${tableName} using cosmos.oltp TBLPROPERTIES(partitionKeyPath = '/id')") - // CosmosDb requires the column `id` to exist and be the primary key + // CosmosDb requires the column `id` to exist and be the primary key, and `id` must be in `string` type val keyDf = if (!df.columns.contains("id")) { header match { case Some(h) => { // We have the header info, copy the 1st key column to `id`, which is required by CosmosDb val key = FeatureGenUtils.getKeyColumnsFromHeader(h).head // Copy key column to `id` - df.withColumn("id", df.col(key)) + df.withColumn("id", df.col(key).cast("string")) } case None => { // If there is no key column, we use a auto-generated monotonic id. 
// but in this case the result could be duplicated if you run job for multiple times // This function is for offline-storage usage, ideally user should create a new container for every run - df.withColumn("id", monotonically_increasing_id()) + df.withColumn("id", (monotonically_increasing_id().cast("string"))) } } } else { From b2c733db577eac11ec54b8d1ce596d2764a76457 Mon Sep 17 00:00:00 2001 From: Chen Xu Date: Mon, 18 Jul 2022 03:51:20 +0800 Subject: [PATCH 07/10] Fix auth logic --- .../feathr/offline/config/location/Jdbc.scala | 31 ++++++------------- .../NonTimeBasedDataSourceAccessor.scala | 2 +- 2 files changed, 11 insertions(+), 22 deletions(-) diff --git a/src/main/scala/com/linkedin/feathr/offline/config/location/Jdbc.scala b/src/main/scala/com/linkedin/feathr/offline/config/location/Jdbc.scala index bca5a887c..2837a87a8 100644 --- a/src/main/scala/com/linkedin/feathr/offline/config/location/Jdbc.scala +++ b/src/main/scala/com/linkedin/feathr/offline/config/location/Jdbc.scala @@ -1,6 +1,6 @@ package com.linkedin.feathr.offline.config.location -import com.fasterxml.jackson.annotation.JsonAlias +import com.fasterxml.jackson.annotation.{JsonAlias, JsonIgnoreProperties} import com.fasterxml.jackson.module.caseclass.annotation.CaseClassDeserialize import com.linkedin.feathr.common.Header import com.linkedin.feathr.offline.source.dataloader.jdbc.JdbcUtils @@ -9,7 +9,8 @@ import org.apache.spark.sql.{DataFrame, SparkSession} import org.eclipse.jetty.util.StringUtil @CaseClassDeserialize() -case class Jdbc(url: String, @JsonAlias(Array("query")) dbtable: String, user: String = "", password: String = "", token: String = "", useToken: Boolean = false, anonymous: Boolean = false) extends DataLocation { +@JsonIgnoreProperties(ignoreUnknown = true) +case class Jdbc(url: String, @JsonAlias(Array("query")) dbtable: String, user: String = "", password: String = "", token: String = "") extends DataLocation { override def loadDf(ss: SparkSession, dataIOParameters: Map[String, String] = Map()): DataFrame = { println(s"Jdbc.loadDf, location is ${this}") var reader = ss.read.format("jdbc") @@ -26,21 +27,14 @@ case class Jdbc(url: String, @JsonAlias(Array("query")) dbtable: String, user: S reader = reader.option("dbtable", q) } } - if (useToken) { + if (!StringUtil.isBlank(token)) { reader.option("accessToken", LocationUtils.envSubstitute(token)) .option("hostNameInCertificate", "*.database.windows.net") .option("encrypt", true) .load } else { if (StringUtil.isBlank(user) && StringUtil.isBlank(password)) { - if (anonymous) { - reader.load() - } else { - // Fallback to global JDBC credential - println("Fallback to default credential") - ss.conf.set(DBTABLE_CONF, dbtable) - JdbcUtils.loadDataFrame(ss, url) - } + reader.load() } else { reader.option("user", LocationUtils.envSubstitute(user)) .option("password", LocationUtils.envSubstitute(password)) @@ -50,11 +44,6 @@ case class Jdbc(url: String, @JsonAlias(Array("query")) dbtable: String, user: S override def writeDf(ss: SparkSession, df: DataFrame, header: Option[Header]): Unit = { println(s"Jdbc.writeDf, location is ${this}") - if (StringUtil.isBlank(user) && StringUtil.isBlank(password) && !anonymous && !useToken) { - // Fallback to global JDBC credential - println("Fallback to default credential") - ss.conf.set(DBTABLE_CONF, dbtable) - } df.write.format("jdbc") .options(getOptions(ss)) .save() @@ -67,7 +56,7 @@ case class Jdbc(url: String, @JsonAlias(Array("query")) dbtable: String, user: S override def isFileBasedLocation(): Boolean = false // 
These members don't contain actual secrets - override def toString: String = s"Jdbc(url=$url, dbtable=$dbtable, useToken=$useToken, anonymous=$anonymous, user=$user, password=$password, token=$token)" + override def toString: String = s"Jdbc(url=$url, dbtable=$dbtable, user=$user, password=$password, token=$token)" def getOptions(ss: SparkSession): Map[String, String] = { val options = collection.mutable.Map[String, String]() @@ -85,7 +74,7 @@ case class Jdbc(url: String, @JsonAlias(Array("query")) dbtable: String, user: S options += ("dbtable" -> q) } } - if (useToken) { + if (!StringUtil.isBlank(token)) { options += ("accessToken" -> LocationUtils.envSubstitute(token)) options += ("hostNameInCertificate" -> "*.database.windows.net") options += ("encrypt" -> "true") @@ -111,7 +100,7 @@ case class Jdbc(url: String, @JsonAlias(Array("query")) dbtable: String, user: S * @param password * @return Newly created InputLocation instance */ - def apply(url: String, dbtable: String, user: String, password: String): Jdbc = Jdbc(url, dbtable, user = user, password = password, useToken = false) + def apply(url: String, dbtable: String, user: String, password: String): Jdbc = Jdbc(url, dbtable, user = user, password = password) /** * Create JDBC InputLocation with required info and OAuth token auth @@ -121,7 +110,7 @@ case class Jdbc(url: String, @JsonAlias(Array("query")) dbtable: String, user: S * @param token * @return Newly created InputLocation instance */ - def apply(url: String, dbtable: String, token: String): Jdbc = Jdbc(url, dbtable, token = token, useToken = true) + def apply(url: String, dbtable: String, token: String): Jdbc = Jdbc(url, dbtable, token = token) /** * Create JDBC InputLocation with required info and OAuth token auth @@ -132,5 +121,5 @@ case class Jdbc(url: String, @JsonAlias(Array("query")) dbtable: String, user: S * @param dbtable * @return Newly created InputLocation instance */ - def apply(url: String, dbtable: String): Jdbc = Jdbc(url, dbtable, useToken = false) + def apply(url: String, dbtable: String): Jdbc = Jdbc(url, dbtable) } diff --git a/src/main/scala/com/linkedin/feathr/offline/source/accessor/NonTimeBasedDataSourceAccessor.scala b/src/main/scala/com/linkedin/feathr/offline/source/accessor/NonTimeBasedDataSourceAccessor.scala index cb311ce6f..51ac76ef0 100644 --- a/src/main/scala/com/linkedin/feathr/offline/source/accessor/NonTimeBasedDataSourceAccessor.scala +++ b/src/main/scala/com/linkedin/feathr/offline/source/accessor/NonTimeBasedDataSourceAccessor.scala @@ -30,7 +30,7 @@ private[offline] class NonTimeBasedDataSourceAccessor( val df = source.location match { case SimplePath(path) => List(path).map(fileLoaderFactory.create(_).loadDataFrame()).reduce((x, y) => x.fuzzyUnion(y)) case PathList(paths) => paths.map(fileLoaderFactory.create(_).loadDataFrame()).reduce((x, y) => x.fuzzyUnion(y)) - case Jdbc(_, _, _, _, _, _, _) => source.location.loadDf(SparkSession.builder().getOrCreate()) + case Jdbc(_, _, _, _, _) => source.location.loadDf(SparkSession.builder().getOrCreate()) case GenericLocation(_, _, _, _) => source.location.loadDf(SparkSession.builder().getOrCreate()) case _ => fileLoaderFactory.createFromLocation(source.location).loadDataFrame() } From 489031b2c784bbe4f69c035f6386784cc96e8a7b Mon Sep 17 00:00:00 2001 From: Chen Xu Date: Mon, 18 Jul 2022 12:15:51 +0800 Subject: [PATCH 08/10] Fix unit test --- .../linkedin/feathr/offline/config/TestDataSourceLoader.scala | 2 +- .../feathr/offline/config/location/TestDesLocation.scala | 2 +- 2 files changed, 2 
insertions(+), 2 deletions(-) diff --git a/src/test/scala/com/linkedin/feathr/offline/config/TestDataSourceLoader.scala b/src/test/scala/com/linkedin/feathr/offline/config/TestDataSourceLoader.scala index 762d3fd77..585e2eab6 100644 --- a/src/test/scala/com/linkedin/feathr/offline/config/TestDataSourceLoader.scala +++ b/src/test/scala/com/linkedin/feathr/offline/config/TestDataSourceLoader.scala @@ -54,7 +54,7 @@ class TestDataSourceLoader extends FunSuite { |""".stripMargin val ds = jackson.readValue(configDoc, classOf[DataSource]) ds.location match { - case Jdbc(url, dbtable, user, password, token, useToken, _) => { + case Jdbc(url, dbtable, user, password, token) => { assert(url == "jdbc:sqlserver://myserver.database.windows.net:1433;database=mydatabase") assert(user=="bar") assert(password=="foo") diff --git a/src/test/scala/com/linkedin/feathr/offline/config/location/TestDesLocation.scala b/src/test/scala/com/linkedin/feathr/offline/config/location/TestDesLocation.scala index dcee1a909..1be2adf77 100644 --- a/src/test/scala/com/linkedin/feathr/offline/config/location/TestDesLocation.scala +++ b/src/test/scala/com/linkedin/feathr/offline/config/location/TestDesLocation.scala @@ -52,7 +52,7 @@ class TestDesLocation extends FunSuite { |}""".stripMargin val ds = jackson.readValue(configDoc, classOf[DataLocation]) ds match { - case Jdbc(url, dbtable, user, password, token, useToken, _) => { + case Jdbc(url, dbtable, user, password, token) => { assert(url == "jdbc:sqlserver://myserver.database.windows.net:1433;database=mydatabase") assert(user == "bar") assert(password == "foo") From 75e20279e858f12b9bdbeee4f7153aaf003e2e7a Mon Sep 17 00:00:00 2001 From: Chen Xu Date: Fri, 22 Jul 2022 13:38:07 +0800 Subject: [PATCH 09/10] Fix id column generation --- .../feathr/offline/config/location/GenericLocation.scala | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/main/scala/com/linkedin/feathr/offline/config/location/GenericLocation.scala b/src/main/scala/com/linkedin/feathr/offline/config/location/GenericLocation.scala index 2661f3e1e..e70bd01e9 100644 --- a/src/main/scala/com/linkedin/feathr/offline/config/location/GenericLocation.scala +++ b/src/main/scala/com/linkedin/feathr/offline/config/location/GenericLocation.scala @@ -5,6 +5,7 @@ import com.fasterxml.jackson.module.caseclass.annotation.CaseClassDeserialize import com.linkedin.feathr.common.Header import com.linkedin.feathr.common.exception.FeathrException import com.linkedin.feathr.offline.generation.FeatureGenUtils +import com.linkedin.feathr.offline.join.DataFrameKeyCombiner import net.minidev.json.annotate.JsonIgnore import org.apache.spark.sql.functions.monotonically_increasing_id import org.apache.spark.sql.{DataFrame, DataFrameWriter, Row, SparkSession} @@ -109,10 +110,10 @@ object GenericLocationFixes { val keyDf = if (!df.columns.contains("id")) { header match { case Some(h) => { - // We have the header info, copy the 1st key column to `id`, which is required by CosmosDb - val key = FeatureGenUtils.getKeyColumnsFromHeader(h).head - // Copy key column to `id` - df.withColumn("id", df.col(key).cast("string")) + // Generate key column from header info, which is required by CosmosDb + val (keyCol, keyedDf) = DataFrameKeyCombiner().combine(df, FeatureGenUtils.getKeyColumnsFromHeader(h)) + // Rename key column to `id` + keyedDf.withColumnRenamed(keyCol, "id") } case None => { // If there is no key column, we use a auto-generated monotonic id. 
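
This commit swaps the single-key copy for DataFrameKeyCombiner, so features with several key columns also end up with a usable `id`. The combiner's internals are not shown in this series; as a rough stand-in, "collapse every key column into one string id" could be approximated as below, where the delimiter and the cast are assumptions rather than what DataFrameKeyCombiner necessarily does.

    import org.apache.spark.sql.DataFrame
    import org.apache.spark.sql.functions.{col, concat_ws}

    // Stand-in for DataFrameKeyCombiner plus withColumnRenamed(keyCol, "id"):
    // build one string id out of all key columns. The "#" delimiter is arbitrary.
    def combineKeysIntoId(df: DataFrame, keyColumns: Seq[String]): DataFrame =
      df.withColumn("id", concat_ws("#", keyColumns.map(c => col(c).cast("string")): _*))
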
From c62007b8d9dc323ff25dbe1cf1c2181fa9626f27 Mon Sep 17 00:00:00 2001 From: Chen Xu Date: Tue, 26 Jul 2022 04:13:56 +0800 Subject: [PATCH 10/10] CosmosDb Sink --- .../config/location/GenericLocation.scala | 31 ++++++++++++++----- .../NonTimeBasedDataSourceAccessor.scala | 2 +- 2 files changed, 24 insertions(+), 9 deletions(-) diff --git a/src/main/scala/com/linkedin/feathr/offline/config/location/GenericLocation.scala b/src/main/scala/com/linkedin/feathr/offline/config/location/GenericLocation.scala index e70bd01e9..1ad8e94ac 100644 --- a/src/main/scala/com/linkedin/feathr/offline/config/location/GenericLocation.scala +++ b/src/main/scala/com/linkedin/feathr/offline/config/location/GenericLocation.scala @@ -7,15 +7,16 @@ import com.linkedin.feathr.common.exception.FeathrException import com.linkedin.feathr.offline.generation.FeatureGenUtils import com.linkedin.feathr.offline.join.DataFrameKeyCombiner import net.minidev.json.annotate.JsonIgnore +import org.apache.log4j.Logger import org.apache.spark.sql.functions.monotonically_increasing_id import org.apache.spark.sql.{DataFrame, DataFrameWriter, Row, SparkSession} @CaseClassDeserialize() -case class GenericLocation(format: String, - mode: Option[String] = None, - @JsonIgnore options: collection.mutable.Map[String, String] = collection.mutable.Map[String, String](), - @JsonIgnore conf: collection.mutable.Map[String, String] = collection.mutable.Map[String, String]() - ) extends DataLocation { +case class GenericLocation(format: String, mode: Option[String] = None) extends DataLocation { + val log: Logger = Logger.getLogger(getClass) + val options: collection.mutable.Map[String, String] = collection.mutable.Map[String, String]() + val conf: collection.mutable.Map[String, String] = collection.mutable.Map[String, String]() + /** * Backward Compatibility * Many existing codes expect a simple path @@ -65,11 +66,25 @@ case class GenericLocation(format: String, override def isFileBasedLocation(): Boolean = false @JsonAnySetter - def setOption(key: String, value: Any) = { + def setOption(key: String, value: Any): Unit = { + println(s"GenericLocation.setOption(key: $key, value: $value)") + if (key == null) { + log.warn("Got null key, skipping") + return + } + if (value == null) { + log.warn(s"Got null value for key '$key', skipping") + return + } + val v = value.toString + if (v == null) { + log.warn(s"Got invalid value for key '$key', skipping") + return + } if (key.startsWith("__conf__")) { - conf += (key.stripPrefix("__conf__").replace("__", ".") -> LocationUtils.envSubstitute(value.toString)) + conf += (key.stripPrefix("__conf__").replace("__", ".") -> LocationUtils.envSubstitute(v)) } else { - options += (key.replace("__", ".") -> LocationUtils.envSubstitute(value.toString)) + options += (key.replace("__", ".") -> LocationUtils.envSubstitute(v)) } } } diff --git a/src/main/scala/com/linkedin/feathr/offline/source/accessor/NonTimeBasedDataSourceAccessor.scala b/src/main/scala/com/linkedin/feathr/offline/source/accessor/NonTimeBasedDataSourceAccessor.scala index 51ac76ef0..385a0a833 100644 --- a/src/main/scala/com/linkedin/feathr/offline/source/accessor/NonTimeBasedDataSourceAccessor.scala +++ b/src/main/scala/com/linkedin/feathr/offline/source/accessor/NonTimeBasedDataSourceAccessor.scala @@ -31,7 +31,7 @@ private[offline] class NonTimeBasedDataSourceAccessor( case SimplePath(path) => List(path).map(fileLoaderFactory.create(_).loadDataFrame()).reduce((x, y) => x.fuzzyUnion(y)) case PathList(paths) => 
paths.map(fileLoaderFactory.create(_).loadDataFrame()).reduce((x, y) => x.fuzzyUnion(y)) case Jdbc(_, _, _, _, _) => source.location.loadDf(SparkSession.builder().getOrCreate()) - case GenericLocation(_, _, _, _) => source.location.loadDf(SparkSession.builder().getOrCreate()) + case GenericLocation(_, _) => source.location.loadDf(SparkSession.builder().getOrCreate()) case _ => fileLoaderFactory.createFromLocation(source.location).loadDataFrame() }
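
In this last commit, options and conf move out of the GenericLocation constructor and become mutable members populated through the @JsonAnySetter, which is why the extractor in NonTimeBasedDataSourceAccessor shrinks to GenericLocation(_, _). The setter also skips null keys and values and keeps the key translation: double underscores become dots, and a __conf__ prefix routes the entry to the Spark conf map instead of the reader/writer options. A small usage sketch, with placeholder values and assuming LocationUtils.envSubstitute passes plain strings through unchanged:

    // Values below are placeholders, not real endpoints or credentials.
    val loc = GenericLocation(format = "cosmos.oltp", mode = Some("append"))
    loc.setOption("spark__cosmos__accountEndpoint", "https://example-account.documents.azure.com:443/")
    loc.setOption("__conf__fs__azure__account__auth__type", "OAuth")
    // loc.options now holds "spark.cosmos.accountEndpoint" -> "https://example-account.documents.azure.com:443/"
    // loc.conf    now holds "fs.azure.account.auth.type"   -> "OAuth"
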