diff --git a/build.sbt b/build.sbt index a448478f6..b9e747117 100644 --- a/build.sbt +++ b/build.sbt @@ -44,7 +44,9 @@ val localAndCloudCommonDependencies = Seq( "net.snowflake" % "spark-snowflake_2.12" % "2.10.0-spark_3.2", "org.apache.commons" % "commons-lang3" % "3.12.0", "org.xerial" % "sqlite-jdbc" % "3.36.0.3", - "com.github.changvvb" %% "jackson-module-caseclass" % "1.1.1" + "com.github.changvvb" %% "jackson-module-caseclass" % "1.1.1", + "com.azure.cosmos.spark" % "azure-cosmos-spark_3-1_2-12" % "4.11.1", + "org.eclipse.jetty" % "jetty-util" % "9.3.24.v20180605" ) // Common deps val jdbcDrivers = Seq( diff --git a/src/main/scala/com/linkedin/feathr/offline/config/FeathrConfigLoader.scala b/src/main/scala/com/linkedin/feathr/offline/config/FeathrConfigLoader.scala index a1d639ac3..4aa122339 100644 --- a/src/main/scala/com/linkedin/feathr/offline/config/FeathrConfigLoader.scala +++ b/src/main/scala/com/linkedin/feathr/offline/config/FeathrConfigLoader.scala @@ -14,7 +14,7 @@ import com.linkedin.feathr.offline.anchored.anchorExtractor.{SQLConfigurableAnch import com.linkedin.feathr.offline.anchored.feature.{FeatureAnchor, FeatureAnchorWithSource} import com.linkedin.feathr.offline.anchored.keyExtractor.{MVELSourceKeyExtractor, SQLSourceKeyExtractor} import com.linkedin.feathr.offline.client.plugins.{AnchorExtractorAdaptor, FeathrUdfPluginContext, FeatureDerivationFunctionAdaptor, SimpleAnchorExtractorSparkAdaptor, SourceKeyExtractorAdaptor} -import com.linkedin.feathr.offline.config.location.{InputLocation, Jdbc, KafkaEndpoint, LocationUtils, SimplePath} +import com.linkedin.feathr.offline.config.location.{DataLocation, KafkaEndpoint, LocationUtils, SimplePath} import com.linkedin.feathr.offline.derived._ import com.linkedin.feathr.offline.derived.functions.{MvelFeatureDerivationFunction, SQLFeatureDerivationFunction, SeqJoinDerivationFunction, SimpleMvelDerivationFunction} import com.linkedin.feathr.offline.source.{DataSource, SourceFormatType, TimeWindowParams} @@ -735,7 +735,7 @@ private[offline] class DataSourceLoader extends JsonDeserializer[DataSource] { * 2. 
a placeholder with reserved string "PASSTHROUGH" for anchor defined pass-through features, * since anchor defined pass-through features do not have path */ - val path: InputLocation = dataSourceType match { + val path: DataLocation = dataSourceType match { case "KAFKA" => Option(node.get("config")) match { case Some(field: ObjectNode) => @@ -748,7 +748,7 @@ private[offline] class DataSourceLoader extends JsonDeserializer[DataSource] { case "PASSTHROUGH" => SimplePath("PASSTHROUGH") case _ => Option(node.get("location")) match { case Some(field: ObjectNode) => - LocationUtils.getMapper().treeToValue(field, classOf[InputLocation]) + LocationUtils.getMapper().treeToValue(field, classOf[DataLocation]) case None => throw new FeathrConfigException(ErrorLabel.FEATHR_USER_ERROR, s"Data location is not defined for data source ${node.toPrettyString()}") case _ => throw new FeathrConfigException(ErrorLabel.FEATHR_USER_ERROR, diff --git a/src/main/scala/com/linkedin/feathr/offline/config/location/InputLocation.scala b/src/main/scala/com/linkedin/feathr/offline/config/location/DataLocation.scala similarity index 53% rename from src/main/scala/com/linkedin/feathr/offline/config/location/InputLocation.scala rename to src/main/scala/com/linkedin/feathr/offline/config/location/DataLocation.scala index e7c72bea1..71a1c4191 100644 --- a/src/main/scala/com/linkedin/feathr/offline/config/location/InputLocation.scala +++ b/src/main/scala/com/linkedin/feathr/offline/config/location/DataLocation.scala @@ -1,15 +1,19 @@ package com.linkedin.feathr.offline.config.location import com.fasterxml.jackson.annotation.{JsonSubTypes, JsonTypeInfo} +import com.fasterxml.jackson.core.JacksonException import com.fasterxml.jackson.databind.module.SimpleModule import com.fasterxml.jackson.databind.{DeserializationFeature, ObjectMapper} import com.fasterxml.jackson.module.caseclass.mapper.CaseClassObjectMapper import com.jasonclawson.jackson.dataformat.hocon.HoconFactory -import com.linkedin.feathr.common.FeathrJacksonScalaModule +import com.linkedin.feathr.common.{FeathrJacksonScalaModule, Header} import com.linkedin.feathr.offline.config.DataSourceLoader import com.linkedin.feathr.offline.source.DataSource +import com.typesafe.config.{Config, ConfigException} import org.apache.spark.sql.{DataFrame, SparkSession} +import scala.collection.JavaConverters._ + /** * An InputLocation is a data source definition, it can either be HDFS files or a JDBC database connection */ @@ -20,38 +24,50 @@ import org.apache.spark.sql.{DataFrame, SparkSession} new JsonSubTypes.Type(value = classOf[SimplePath], name = "path"), new JsonSubTypes.Type(value = classOf[PathList], name = "pathlist"), new JsonSubTypes.Type(value = classOf[Jdbc], name = "jdbc"), + new JsonSubTypes.Type(value = classOf[GenericLocation], name = "generic"), )) -trait InputLocation { +trait DataLocation { /** * Backward Compatibility * Many existing codes expect a simple path + * * @return the `path` or `url` of the data source * - * WARN: This method is deprecated, you must use match/case on InputLocation, - * and get `path` from `SimplePath` only + * WARN: This method is deprecated, you must use match/case on InputLocation, + * and get `path` from `SimplePath` only */ @deprecated("Do not use this method in any new code, it will be removed soon") def getPath: String /** * Backward Compatibility + * * @return the `path` or `url` of the data source, wrapped in an List * - * WARN: This method is deprecated, you must use match/case on InputLocation, - * and get `paths` from 
`PathList` only + * WARN: This method is deprecated, you must use match/case on InputLocation, + * and get `paths` from `PathList` only */ @deprecated("Do not use this method in any new code, it will be removed soon") def getPathList: List[String] /** * Load DataFrame from Spark session + * * @param ss SparkSession * @return */ def loadDf(ss: SparkSession, dataIOParameters: Map[String, String] = Map()): DataFrame + /** + * Write DataFrame to the location + * @param ss SparkSession + * @param df DataFrame to write + */ + def writeDf(ss: SparkSession, df: DataFrame, header: Option[Header]) + /** * Tell if this location is file based + * * @return boolean */ def isFileBasedLocation(): Boolean @@ -67,6 +83,7 @@ object LocationUtils { /** * String template substitution, replace "...${VAR}.." with corresponding System property or environment variable * Non-existent pattern is replaced by empty string. + * * @param s String template to be processed * @return Processed result */ @@ -76,6 +93,7 @@ object LocationUtils { /** * Get an ObjectMapper to deserialize DataSource + * * @return the ObjectMapper */ def getMapper(): ObjectMapper = { @@ -86,3 +104,50 @@ object LocationUtils { .registerModule(new SimpleModule().addDeserializer(classOf[DataSource], new DataSourceLoader)) } } + +object DataLocation { + /** + * Create DataLocation from string, try parsing the string as JSON and fallback to SimplePath + * @param cfg the input string + * @return DataLocation + */ + def apply(cfg: String): DataLocation = { + val jackson = (new ObjectMapper(new HoconFactory) with CaseClassObjectMapper) + .registerModule(FeathrJacksonScalaModule) // DefaultScalaModule causes a fail on holdem + .configure(DeserializationFeature.ACCEPT_SINGLE_VALUE_AS_ARRAY, true) + .configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false) + .registerModule(new SimpleModule().addDeserializer(classOf[DataSource], new DataSourceLoader)) + try { + // Cfg is either a plain path or a JSON object + if (cfg.trim.startsWith("{")) { + val location = jackson.readValue(cfg, classOf[DataLocation]) + location + } else { + SimplePath(cfg) + } + } catch { + case _ @ (_: ConfigException | _: JacksonException) => SimplePath(cfg) + } + } + + def apply(cfg: Config): DataLocation = { + apply(cfg.root().keySet().asScala.map(key ⇒ key → cfg.getString(key)).toMap) + } + + def apply(cfg: Any): DataLocation = { + val jackson = (new ObjectMapper(new HoconFactory) with CaseClassObjectMapper) + .registerModule(FeathrJacksonScalaModule) // DefaultScalaModule causes a fail on holdem + .configure(DeserializationFeature.ACCEPT_SINGLE_VALUE_AS_ARRAY, true) + .configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false) + .registerModule(new SimpleModule().addDeserializer(classOf[DataSource], new DataSourceLoader)) + try { + val location = jackson.convertValue(cfg, classOf[DataLocation]) + location + } catch { + case e: JacksonException => { + print(e) + SimplePath(cfg.toString) + } + } + } +} diff --git a/src/main/scala/com/linkedin/feathr/offline/config/location/GenericLocation.scala b/src/main/scala/com/linkedin/feathr/offline/config/location/GenericLocation.scala new file mode 100644 index 000000000..1ad8e94ac --- /dev/null +++ b/src/main/scala/com/linkedin/feathr/offline/config/location/GenericLocation.scala @@ -0,0 +1,161 @@ +package com.linkedin.feathr.offline.config.location + +import com.fasterxml.jackson.annotation.JsonAnySetter +import com.fasterxml.jackson.module.caseclass.annotation.CaseClassDeserialize +import 
com.linkedin.feathr.common.Header +import com.linkedin.feathr.common.exception.FeathrException +import com.linkedin.feathr.offline.generation.FeatureGenUtils +import com.linkedin.feathr.offline.join.DataFrameKeyCombiner +import net.minidev.json.annotate.JsonIgnore +import org.apache.log4j.Logger +import org.apache.spark.sql.functions.monotonically_increasing_id +import org.apache.spark.sql.{DataFrame, DataFrameWriter, Row, SparkSession} + +@CaseClassDeserialize() +case class GenericLocation(format: String, mode: Option[String] = None) extends DataLocation { + val log: Logger = Logger.getLogger(getClass) + val options: collection.mutable.Map[String, String] = collection.mutable.Map[String, String]() + val conf: collection.mutable.Map[String, String] = collection.mutable.Map[String, String]() + + /** + * Backward Compatibility + * Much existing code expects a simple path + * + * @return the `path` or `url` of the data source + * + * WARN: This method is deprecated, you must use match/case on DataLocation, + * and get `path` from `SimplePath` only + */ + override def getPath: String = s"GenericLocation(${format})" + + /** + * Backward Compatibility + * + * @return the `path` or `url` of the data source, wrapped in a List + * + * WARN: This method is deprecated, you must use match/case on DataLocation, + * and get `paths` from `PathList` only + */ + override def getPathList: List[String] = List(getPath) + + /** + * Load DataFrame from Spark session + * + * @param ss SparkSession + * @return + */ + override def loadDf(ss: SparkSession, dataIOParameters: Map[String, String]): DataFrame = { + GenericLocationFixes.readDf(ss, this) + } + + /** + * Write DataFrame to the location + * + * @param ss SparkSession + * @param df DataFrame to write + */ + override def writeDf(ss: SparkSession, df: DataFrame, header: Option[Header]): Unit = { + GenericLocationFixes.writeDf(ss, df, header, this) + } + + /** + * Tell if this location is file based + * + * @return boolean + */ + override def isFileBasedLocation(): Boolean = false + + @JsonAnySetter + def setOption(key: String, value: Any): Unit = { + println(s"GenericLocation.setOption(key: $key, value: $value)") + if (key == null) { + log.warn("Got null key, skipping") + return + } + if (value == null) { + log.warn(s"Got null value for key '$key', skipping") + return + } + val v = value.toString + if (v == null) { + log.warn(s"Got invalid value for key '$key', skipping") + return + } + if (key.startsWith("__conf__")) { + conf += (key.stripPrefix("__conf__").replace("__", ".") -> LocationUtils.envSubstitute(v)) + } else { + options += (key.replace("__", ".") -> LocationUtils.envSubstitute(v)) + } + } +} + +/** + * Some Spark connectors need extra actions before read or write, namely CosmosDb and ElasticSearch. + * Need to run specific fixes based on `format`. + */ +object GenericLocationFixes { + def readDf(ss: SparkSession, location: GenericLocation): DataFrame = { + location.conf.foreach(e => { + ss.conf.set(e._1, e._2) + }) + ss.read.format(location.format) + .options(location.options) + .load() + } + + def writeDf(ss: SparkSession, df: DataFrame, header: Option[Header], location: GenericLocation) = { + location.conf.foreach(e => { + ss.conf.set(e._1, e._2) + }) + + location.format.toLowerCase() match { + case "cosmos.oltp" => + // Ensure the database and the table exist before writing + val endpoint = location.options.getOrElse("spark.cosmos.accountEndpoint", throw new FeathrException("Missing spark__cosmos__accountEndpoint")) + val key =
location.options.getOrElse("spark.cosmos.accountKey", throw new FeathrException("Missing spark__cosmos__accountKey")) + val databaseName = location.options.getOrElse("spark.cosmos.database", throw new FeathrException("Missing spark__cosmos__database")) + val tableName = location.options.getOrElse("spark.cosmos.container", throw new FeathrException("Missing spark__cosmos__container")) + ss.conf.set("spark.sql.catalog.cosmosCatalog", "com.azure.cosmos.spark.CosmosCatalog") + ss.conf.set("spark.sql.catalog.cosmosCatalog.spark.cosmos.accountEndpoint", endpoint) + ss.conf.set("spark.sql.catalog.cosmosCatalog.spark.cosmos.accountKey", key) + ss.sql(s"CREATE DATABASE IF NOT EXISTS cosmosCatalog.${databaseName};") + ss.sql(s"CREATE TABLE IF NOT EXISTS cosmosCatalog.${databaseName}.${tableName} using cosmos.oltp TBLPROPERTIES(partitionKeyPath = '/id')") + + // CosmosDb requires the column `id` to exist and be the primary key, and `id` must be in `string` type + val keyDf = if (!df.columns.contains("id")) { + header match { + case Some(h) => { + // Generate key column from header info, which is required by CosmosDb + val (keyCol, keyedDf) = DataFrameKeyCombiner().combine(df, FeatureGenUtils.getKeyColumnsFromHeader(h)) + // Rename key column to `id` + keyedDf.withColumnRenamed(keyCol, "id") + } + case None => { + // If there is no key column, we use an auto-generated monotonic id, + // but in this case the result could be duplicated if you run the job multiple times. + // This function is for offline-storage usage; ideally the user should create a new container for every run. + df.withColumn("id", (monotonically_increasing_id().cast("string"))) + } + } + } else { + // We already have an `id` column + // TODO: Should we do anything here? + // A corner case is that the `id` column exists but is not unique; then the output will be incomplete, as + // CosmosDb will overwrite the old entry with the new one with the same `id`. + // We can either rename the existing `id` column and use a header/autogen key column, or we can tell the user + // to avoid using the `id` column for non-unique data, but both workarounds have pros and cons.
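A possible way to surface that non-unique `id` corner case early, sketched only as an illustration and not part of this patch (it assumes an extra scan of the DataFrame before the write is acceptable; `df` and `location` are the parameters of writeDf above):

    // Sketch: duplicates mean Cosmos DB keeps only the last row written for each id.
    val distinctIds = df.select("id").distinct().count()
    if (distinctIds < df.count()) {
      location.log.warn("Column `id` is not unique; rows sharing an id will overwrite each other in Cosmos DB")
    }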
+ df + } + keyDf.write.format(location.format) + .options(location.options) + .mode(location.mode.getOrElse("append")) // CosmosDb doesn't support ErrorIfExist mode in batch mode + .save() + case _ => + // Normal writing procedure, just set format and options then write + df.write.format(location.format) + .options(location.options) + .mode(location.mode.getOrElse("default")) + .save() + } + } +} diff --git a/src/main/scala/com/linkedin/feathr/offline/config/location/Jdbc.scala b/src/main/scala/com/linkedin/feathr/offline/config/location/Jdbc.scala index 1cef50f7b..2837a87a8 100644 --- a/src/main/scala/com/linkedin/feathr/offline/config/location/Jdbc.scala +++ b/src/main/scala/com/linkedin/feathr/offline/config/location/Jdbc.scala @@ -1,14 +1,16 @@ package com.linkedin.feathr.offline.config.location -import com.fasterxml.jackson.annotation.JsonAlias +import com.fasterxml.jackson.annotation.{JsonAlias, JsonIgnoreProperties} import com.fasterxml.jackson.module.caseclass.annotation.CaseClassDeserialize +import com.linkedin.feathr.common.Header import com.linkedin.feathr.offline.source.dataloader.jdbc.JdbcUtils import com.linkedin.feathr.offline.source.dataloader.jdbc.JdbcUtils.DBTABLE_CONF import org.apache.spark.sql.{DataFrame, SparkSession} import org.eclipse.jetty.util.StringUtil @CaseClassDeserialize() -case class Jdbc(url: String, @JsonAlias(Array("query")) dbtable: String, user: String = "", password: String = "", token: String = "", useToken: Boolean = false, anonymous: Boolean = false) extends InputLocation { +@JsonIgnoreProperties(ignoreUnknown = true) +case class Jdbc(url: String, @JsonAlias(Array("query")) dbtable: String, user: String = "", password: String = "", token: String = "") extends DataLocation { override def loadDf(ss: SparkSession, dataIOParameters: Map[String, String] = Map()): DataFrame = { println(s"Jdbc.loadDf, location is ${this}") var reader = ss.read.format("jdbc") @@ -18,28 +20,21 @@ case class Jdbc(url: String, @JsonAlias(Array("query")) dbtable: String, user: S reader = reader.option("dbtable", ss.conf.get(DBTABLE_CONF)) } else { val q = dbtable.trim - if("\\s".r.findFirstIn(q).nonEmpty) { + if ("\\s".r.findFirstIn(q).nonEmpty) { // This is a SQL instead of a table name reader = reader.option("query", q) } else { reader = reader.option("dbtable", q) } } - if (useToken) { + if (!StringUtil.isBlank(token)) { reader.option("accessToken", LocationUtils.envSubstitute(token)) .option("hostNameInCertificate", "*.database.windows.net") .option("encrypt", true) .load } else { if (StringUtil.isBlank(user) && StringUtil.isBlank(password)) { - if (anonymous) { - reader.load() - } else { - // Fallback to global JDBC credential - println("Fallback to default credential") - ss.conf.set(DBTABLE_CONF, dbtable) - JdbcUtils.loadDataFrame(ss, url) - } + reader.load() } else { reader.option("user", LocationUtils.envSubstitute(user)) .option("password", LocationUtils.envSubstitute(password)) @@ -47,6 +42,13 @@ case class Jdbc(url: String, @JsonAlias(Array("query")) dbtable: String, user: S } } + override def writeDf(ss: SparkSession, df: DataFrame, header: Option[Header]): Unit = { + println(s"Jdbc.writeDf, location is ${this}") + df.write.format("jdbc") + .options(getOptions(ss)) + .save() + } + override def getPath: String = url override def getPathList: List[String] = List(url) @@ -54,36 +56,70 @@ case class Jdbc(url: String, @JsonAlias(Array("query")) dbtable: String, user: S override def isFileBasedLocation(): Boolean = false // These members don't contain actual secrets - 
override def toString: String = s"Jdbc(url=$url, dbtable=$dbtable, useToken=$useToken, anonymous=$anonymous, user=$user, password=$password, token=$token)" + override def toString: String = s"Jdbc(url=$url, dbtable=$dbtable, user=$user, password=$password, token=$token)" + + def getOptions(ss: SparkSession): Map[String, String] = { + val options = collection.mutable.Map[String, String]() + options += ("url" -> url) + + if (StringUtil.isBlank(dbtable)) { + // Fallback to default table name + options += ("dbtable" -> ss.conf.get(DBTABLE_CONF)) + } else { + val q = dbtable.trim + if ("\\s".r.findFirstIn(q).nonEmpty) { + // This is a SQL instead of a table name + options += ("query" -> q) + } else { + options += ("dbtable" -> q) + } + } + if (!StringUtil.isBlank(token)) { + options += ("accessToken" -> LocationUtils.envSubstitute(token)) + options += ("hostNameInCertificate" -> "*.database.windows.net") + options += ("encrypt" -> "true") + } else { + if (!StringUtil.isBlank(user)) { + options += ("user" -> LocationUtils.envSubstitute(user)) + } + if (!StringUtil.isBlank(password)) { + options += ("password" -> LocationUtils.envSubstitute(password)) + } + } + options.toMap + } } -object Jdbc { - /** - * Create JDBC InputLocation with required info and user/password auth - * @param url - * @param dbtable - * @param user - * @param password - * @return Newly created InputLocation instance - */ - def apply(url: String, dbtable: String, user: String, password: String): Jdbc = Jdbc(url, dbtable, user = user, password = password, useToken = false) + object Jdbc { + /** + * Create JDBC InputLocation with required info and user/password auth + * + * @param url + * @param dbtable + * @param user + * @param password + * @return Newly created InputLocation instance + */ + def apply(url: String, dbtable: String, user: String, password: String): Jdbc = Jdbc(url, dbtable, user = user, password = password) - /** - * Create JDBC InputLocation with required info and OAuth token auth - * @param url - * @param dbtable - * @param token - * @return Newly created InputLocation instance - */ - def apply(url: String, dbtable: String, token: String): Jdbc = Jdbc(url, dbtable, token = token, useToken = true) + /** + * Create JDBC InputLocation with required info and OAuth token auth + * + * @param url + * @param dbtable + * @param token + * @return Newly created InputLocation instance + */ + def apply(url: String, dbtable: String, token: String): Jdbc = Jdbc(url, dbtable, token = token) - /** - * Create JDBC InputLocation with required info and OAuth token auth - * In this case, the auth info is taken from default setting passed from CLI/API, details can be found in `Jdbc#loadDf` - * @see com.linkedin.feathr.offline.source.dataloader.jdbc.JDBCUtils#loadDataFrame - * @param url - * @param dbtable - * @return Newly created InputLocation instance - */ - def apply(url: String, dbtable: String): Jdbc = Jdbc(url, dbtable, useToken = false) -} + /** + * Create JDBC InputLocation with required info and OAuth token auth + * In this case, the auth info is taken from default setting passed from CLI/API, details can be found in `Jdbc#loadDf` + * + * @see com.linkedin.feathr.offline.source.dataloader.jdbc.JDBCUtils#loadDataFrame + * @param url + * @param dbtable + * @return Newly created InputLocation instance + */ + def apply(url: String, dbtable: String): Jdbc = Jdbc(url, dbtable) + } diff --git a/src/main/scala/com/linkedin/feathr/offline/config/location/KafkaEndpoint.scala 
b/src/main/scala/com/linkedin/feathr/offline/config/location/KafkaEndpoint.scala index 42bbc9f72..f3818700b 100644 --- a/src/main/scala/com/linkedin/feathr/offline/config/location/KafkaEndpoint.scala +++ b/src/main/scala/com/linkedin/feathr/offline/config/location/KafkaEndpoint.scala @@ -1,6 +1,7 @@ package com.linkedin.feathr.offline.config.location import com.fasterxml.jackson.module.caseclass.annotation.CaseClassDeserialize +import com.linkedin.feathr.common.Header import org.apache.spark.sql.{DataFrame, SparkSession} import org.codehaus.jackson.annotate.JsonProperty @@ -28,9 +29,11 @@ case class KafkaSchema(@JsonProperty("type") `type`: String, @CaseClassDeserialize() case class KafkaEndpoint(@JsonProperty("brokers") brokers: List[String], @JsonProperty("topics") topics: List[String], - @JsonProperty("schema") schema: KafkaSchema) extends InputLocation { + @JsonProperty("schema") schema: KafkaSchema) extends DataLocation { override def loadDf(ss: SparkSession, dataIOParameters: Map[String, String] = Map()): DataFrame = ??? + override def writeDf(ss: SparkSession, df: DataFrame, header: Option[Header]): Unit = ??? + override def getPath: String = "kafka://" + brokers.mkString(",")+":"+topics.mkString(",") override def getPathList: List[String] = ??? diff --git a/src/main/scala/com/linkedin/feathr/offline/config/location/PathList.scala b/src/main/scala/com/linkedin/feathr/offline/config/location/PathList.scala index 7f491ec32..c583de8a3 100644 --- a/src/main/scala/com/linkedin/feathr/offline/config/location/PathList.scala +++ b/src/main/scala/com/linkedin/feathr/offline/config/location/PathList.scala @@ -1,10 +1,11 @@ package com.linkedin.feathr.offline.config.location +import com.linkedin.feathr.common.Header import com.linkedin.feathr.offline.generation.SparkIOUtils import org.apache.spark.sql.{DataFrame, SparkSession} import org.apache.hadoop.mapred.JobConf -case class PathList(paths: List[String]) extends InputLocation { +case class PathList(paths: List[String]) extends DataLocation { override def getPath: String = paths.mkString(";") override def getPathList: List[String] = paths @@ -15,6 +16,8 @@ case class PathList(paths: List[String]) extends InputLocation { SparkIOUtils.createUnionDataFrame(getPathList, dataIOParameters, new JobConf(), List()) //TODO: Add handler support here. Currently there are deserilization issues with adding handlers to factory builder. } + override def writeDf(ss: SparkSession, df: DataFrame, header: Option[Header]): Unit = ??? 
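The writeDf stubs (KafkaEndpoint above, PathList here, SimplePath below) are left as ???; SimplePath output still flows through the SimplePath branch of SparkIOUtils.writeDataFrame later in this diff, so the common output path does not reach them. If a direct implementation were wanted for the single-path case, a minimal sketch, assuming an org.apache.spark.sql.SaveMode import and the existing spark.feathr.outputFormat convention, could look like:

    // Sketch only; fits SimplePath, where getPath returns the single output path.
    // For PathList one would first have to decide which of the paths to write to.
    override def writeDf(ss: SparkSession, df: DataFrame, header: Option[Header]): Unit = {
      val format = ss.conf.get("spark.feathr.outputFormat", "avro")
      df.write.mode(SaveMode.Overwrite).format(format).save(getPath)
    }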
+ override def toString: String = s"PathList(path=[${paths.mkString(",")}])" } diff --git a/src/main/scala/com/linkedin/feathr/offline/config/location/SimplePath.scala b/src/main/scala/com/linkedin/feathr/offline/config/location/SimplePath.scala index de726bf40..d2d1e2db6 100644 --- a/src/main/scala/com/linkedin/feathr/offline/config/location/SimplePath.scala +++ b/src/main/scala/com/linkedin/feathr/offline/config/location/SimplePath.scala @@ -1,17 +1,20 @@ package com.linkedin.feathr.offline.config.location import com.fasterxml.jackson.module.caseclass.annotation.CaseClassDeserialize +import com.linkedin.feathr.common.Header import com.linkedin.feathr.offline.generation.SparkIOUtils import org.apache.hadoop.mapred.JobConf import org.apache.spark.sql.{DataFrame, SparkSession} import org.codehaus.jackson.annotate.JsonProperty @CaseClassDeserialize() -case class SimplePath(@JsonProperty("path") path: String) extends InputLocation { +case class SimplePath(@JsonProperty("path") path: String) extends DataLocation { override def loadDf(ss: SparkSession, dataIOParameters: Map[String, String] = Map()): DataFrame = { SparkIOUtils.createUnionDataFrame(getPathList, dataIOParameters, new JobConf(), List()) // The simple path is not responsible for handling custom data loaders. } + override def writeDf(ss: SparkSession, df: DataFrame, header: Option[Header]): Unit = ??? + override def getPath: String = path override def getPathList: List[String] = List(path) diff --git a/src/main/scala/com/linkedin/feathr/offline/generation/SparkIOUtils.scala b/src/main/scala/com/linkedin/feathr/offline/generation/SparkIOUtils.scala index 8d1e832bf..3d3052c41 100644 --- a/src/main/scala/com/linkedin/feathr/offline/generation/SparkIOUtils.scala +++ b/src/main/scala/com/linkedin/feathr/offline/generation/SparkIOUtils.scala @@ -1,8 +1,7 @@ package com.linkedin.feathr.offline.generation -import com.linkedin.feathr.offline.config.location.{InputLocation, Jdbc, SimplePath} +import com.linkedin.feathr.offline.config.location.{DataLocation, SimplePath} import com.linkedin.feathr.offline.source.dataloader.hdfs.FileFormat -import com.linkedin.feathr.offline.source.dataloader.jdbc.JdbcUtils import com.linkedin.feathr.offline.source.dataloader.DataLoaderHandler import org.apache.avro.generic.GenericRecord import org.apache.hadoop.mapred.JobConf @@ -35,7 +34,7 @@ object SparkIOUtils { df } - def createDataFrame(location: InputLocation, dataIOParams: Map[String, String] = Map(), jobConf: JobConf, dataLoaderHandlers: List[DataLoaderHandler]): DataFrame = { + def createDataFrame(location: DataLocation, dataIOParams: Map[String, String] = Map(), jobConf: JobConf, dataLoaderHandlers: List[DataLoaderHandler]): DataFrame = { var dfOpt: Option[DataFrame] = None breakable { for (dataLoaderHandler <- dataLoaderHandlers) { @@ -54,23 +53,32 @@ object SparkIOUtils { df } - def writeDataFrame( outputDF: DataFrame, path: String, parameters: Map[String, String] = Map(), dataLoaderHandlers: List[DataLoaderHandler]): DataFrame = { + def writeDataFrame( outputDF: DataFrame, outputLocation: DataLocation, parameters: Map[String, String] = Map(), dataLoaderHandlers: List[DataLoaderHandler]): DataFrame = { var dfWritten = false breakable { for (dataLoaderHandler <- dataLoaderHandlers) { - if (dataLoaderHandler.validatePath(path)) { - dataLoaderHandler.writeDataFrame(outputDF, path, parameters) - dfWritten = true - break + outputLocation match { + case SimplePath(path) => { + if (dataLoaderHandler.validatePath(path)) { + 
dataLoaderHandler.writeDataFrame(outputDF, path, parameters) + dfWritten = true + break + } + } } } } if(!dfWritten) { - val output_format = outputDF.sqlContext.getConf("spark.feathr.outputFormat", "avro") - // if the output format is set by spark configurations "spark.feathr.outputFormat" - // we will use that as the job output format; otherwise use avro as default for backward compatibility - outputDF.write.mode(SaveMode.Overwrite).format(output_format).save(path) - outputDF + outputLocation match { + case SimplePath(path) => { + val output_format = outputDF.sqlContext.getConf("spark.feathr.outputFormat", "avro") + // if the output format is set by spark configurations "spark.feathr.outputFormat" + // we will use that as the job output format; otherwise use avro as default for backward compatibility + outputDF.write.mode(SaveMode.Overwrite).format(output_format).save(path) + outputDF + } + case _ => outputLocation.writeDf(SparkSession.builder().getOrCreate(), outputDF, None) + } } outputDF } diff --git a/src/main/scala/com/linkedin/feathr/offline/generation/outputProcessor/WriteToHDFSOutputProcessor.scala b/src/main/scala/com/linkedin/feathr/offline/generation/outputProcessor/WriteToHDFSOutputProcessor.scala index 7ac29a779..2b86642ad 100644 --- a/src/main/scala/com/linkedin/feathr/offline/generation/outputProcessor/WriteToHDFSOutputProcessor.scala +++ b/src/main/scala/com/linkedin/feathr/offline/generation/outputProcessor/WriteToHDFSOutputProcessor.scala @@ -4,6 +4,7 @@ import com.linkedin.feathr.offline.util.Transformations.sortColumns import com.linkedin.feathr.common.configObj.generation.OutputProcessorConfig import com.linkedin.feathr.common.exception.{ErrorLabel, FeathrDataOutputException, FeathrException} import com.linkedin.feathr.common.{Header, RichConfig, TaggedFeatureName} +import com.linkedin.feathr.offline.config.location.{DataLocation, SimplePath} import com.linkedin.feathr.offline.generation.{FeatureDataHDFSProcessUtils, FeatureGenerationPathName} import com.linkedin.feathr.offline.util.{FeatureGenConstants, IncrementalAggUtils} import com.linkedin.feathr.offline.source.dataloader.DataLoaderHandler @@ -149,7 +150,25 @@ private[offline] class WriteToHDFSOutputProcessor(val config: OutputProcessorCon // If it's local, we can't write to HDFS. 
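For reference, a usage sketch of the new SparkIOUtils.writeDataFrame signature above; featureDF stands for any DataFrame produced by the join or generation stage, and the JDBC values are placeholders:

    import com.linkedin.feathr.offline.config.location.{DataLocation, SimplePath}
    import com.linkedin.feathr.offline.generation.SparkIOUtils

    // SimplePath keeps the original behaviour: data loader handlers and the
    // spark.feathr.outputFormat fallback still apply.
    SparkIOUtils.writeDataFrame(featureDF, SimplePath("/join_output"), parameters = Map(), dataLoaderHandlers = List())

    // Any other location falls through to its own writeDf, e.g. a JDBC sink
    // built from a HOCON/JSON string via DataLocation.apply.
    val sink = DataLocation(
      """{ type: "jdbc", url: "jdbc:sqlserver://myserver.database.windows.net:1433;database=mydatabase", dbtable: "dbo.features", user: "${JDBC_USER}", password: "${JDBC_PASSWORD}" }""")
    SparkIOUtils.writeDataFrame(featureDF, sink, parameters = Map(), dataLoaderHandlers = List())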
val skipWrite = if (ss.sparkContext.isLocal) true else false - FeatureDataHDFSProcessUtils.processFeatureDataHDFS(ss, featuresToDF, parentPath, config, skipWrite = skipWrite, endTimeOpt, timestampOpt, dataLoaderHandlers) + location match { + case Some(l) => { + // We have a DataLocation to write the df + l.writeDf(ss, augmentedDF, Some(header)) + (augmentedDF, header) + } + case None => { + FeatureDataHDFSProcessUtils.processFeatureDataHDFS(ss, featuresToDF, parentPath, config, skipWrite = skipWrite, endTimeOpt, timestampOpt, dataLoaderHandlers) + } + } + } + + private val location: Option[DataLocation] = { + if (!config.getParams.getStringWithDefault("type", "").isEmpty) { + // The location param contains 'type' key, assuming it's a DataLocation + Some(DataLocation(config.getParams)) + } else { + None + } } // path parameter name diff --git a/src/main/scala/com/linkedin/feathr/offline/job/FeatureJoinJob.scala b/src/main/scala/com/linkedin/feathr/offline/job/FeatureJoinJob.scala index 4daf0e5f4..ef01044d1 100644 --- a/src/main/scala/com/linkedin/feathr/offline/job/FeatureJoinJob.scala +++ b/src/main/scala/com/linkedin/feathr/offline/job/FeatureJoinJob.scala @@ -9,6 +9,7 @@ import com.linkedin.feathr.offline._ import com.linkedin.feathr.offline.client._ import com.linkedin.feathr.offline.config.FeatureJoinConfig import com.linkedin.feathr.offline.config.datasource.{DataSourceConfigUtils, DataSourceConfigs} +import com.linkedin.feathr.offline.config.location.{DataLocation, SimplePath} import com.linkedin.feathr.offline.generation.SparkIOUtils import com.linkedin.feathr.offline.source.SourceFormatType import com.linkedin.feathr.offline.source.accessor.DataPathHandler @@ -88,11 +89,18 @@ object FeatureJoinJob { } private def checkAuthorization(ss: SparkSession, hadoopConf: Configuration, jobContext: FeathrJoinJobContext, dataLoaderHandlers: List[DataLoaderHandler]): Unit = { - AclCheckUtils.checkWriteAuthorization(hadoopConf, jobContext.jobJoinContext.outputPath) match { - case Failure(e) => - throw new FeathrDataOutputException(ErrorLabel.FEATHR_USER_ERROR, s"No write permission for output path ${jobContext.jobJoinContext.outputPath}.", e) - case Success(_) => log.debug("Checked write authorization on output path: " + jobContext.jobJoinContext.outputPath) + + jobContext.jobJoinContext.outputPath match { + case SimplePath(path) => { + AclCheckUtils.checkWriteAuthorization(hadoopConf, path) match { + case Failure(e) => + throw new FeathrDataOutputException(ErrorLabel.FEATHR_USER_ERROR, s"No write permission for output path ${jobContext.jobJoinContext.outputPath}.", e) + case Success(_) => log.debug("Checked write authorization on output path: " + jobContext.jobJoinContext.outputPath) + } + } + case _ => {} } + jobContext.jobJoinContext.inputData.map(inputData => { val failOnMissing = FeathrUtils.getFeathrJobParam(ss, FeathrUtils.FAIL_ON_MISSING_PARTITION).toBoolean val pathList = getPathList(sourceFormatType=inputData.sourceType, @@ -255,13 +263,13 @@ object FeatureJoinJob { } } - val joinJobContext = { - val feathrLocalConfig = cmdParser.extractOptionalValue("feathr-config") - val feathrFeatureConfig = cmdParser.extractOptionalValue("feature-config") - val localOverrideAll = cmdParser.extractRequiredValue("local-override-all") - val outputPath = cmdParser.extractRequiredValue("output") - val numParts = cmdParser.extractRequiredValue("num-parts").toInt + val feathrLocalConfig = cmdParser.extractOptionalValue("feathr-config") + val feathrFeatureConfig = 
cmdParser.extractOptionalValue("feature-config") + val localOverrideAll = cmdParser.extractRequiredValue("local-override-all") + val outputPath = DataLocation(cmdParser.extractRequiredValue("output")) + val numParts = cmdParser.extractRequiredValue("num-parts").toInt + val joinJobContext = { JoinJobContext( feathrLocalConfig, feathrFeatureConfig, @@ -359,7 +367,10 @@ object FeatureJoinJob { DataSourceConfigUtils.setupHadoopConf(sparkSession, jobContext.dataSourceConfigs) FeathrUdfRegistry.registerUdf(sparkSession) - HdfsUtils.deletePath(jobContext.jobJoinContext.outputPath, recursive = true, conf) + jobContext.jobJoinContext.outputPath match { + case SimplePath(path) => HdfsUtils.deletePath(path, recursive = true, conf) + case _ => {} + } val enableDebugLog = FeathrUtils.getFeathrJobParam(sparkConf, FeathrUtils.ENABLE_DEBUG_OUTPUT).toBoolean if (enableDebugLog) { diff --git a/src/main/scala/com/linkedin/feathr/offline/job/JoinJobContext.scala b/src/main/scala/com/linkedin/feathr/offline/job/JoinJobContext.scala index bdf506796..de72b7bda 100644 --- a/src/main/scala/com/linkedin/feathr/offline/job/JoinJobContext.scala +++ b/src/main/scala/com/linkedin/feathr/offline/job/JoinJobContext.scala @@ -1,6 +1,7 @@ package com.linkedin.feathr.offline.job import com.linkedin.feathr.offline.client.InputData +import com.linkedin.feathr.offline.config.location.{DataLocation, SimplePath} object JoinJobContext { @@ -23,7 +24,7 @@ case class JoinJobContext( feathrLocalConfig: Option[String] = None, feathrFeatureConfig: Option[String] = None, inputData: Option[InputData] = None, - outputPath: String = "/join_output", + outputPath: DataLocation = SimplePath("/join_output"), numParts: Int = 1 ) { } diff --git a/src/main/scala/com/linkedin/feathr/offline/source/DataSource.scala b/src/main/scala/com/linkedin/feathr/offline/source/DataSource.scala index 4f05a3c28..8c132cf4a 100644 --- a/src/main/scala/com/linkedin/feathr/offline/source/DataSource.scala +++ b/src/main/scala/com/linkedin/feathr/offline/source/DataSource.scala @@ -1,6 +1,6 @@ package com.linkedin.feathr.offline.source -import com.linkedin.feathr.offline.config.location.{InputLocation, SimplePath} +import com.linkedin.feathr.offline.config.location.{DataLocation, SimplePath} import com.linkedin.feathr.offline.source.SourceFormatType.SourceFormatType import com.linkedin.feathr.offline.util.{AclCheckUtils, HdfsUtils, LocalFeatureJoinUtils} import org.apache.hadoop.fs.Path @@ -20,7 +20,7 @@ import scala.util.{Failure, Success, Try} * @param timePartitionPattern format of the time partitioned feature */ private[offline] case class DataSource( - val location: InputLocation, + val location: DataLocation, sourceType: SourceFormatType, timeWindowParams: Option[TimeWindowParams], timePartitionPattern: Option[String]) @@ -63,7 +63,7 @@ object DataSource { timeWindowParams: Option[TimeWindowParams] = None, timePartitionPattern: Option[String] = None): DataSource = DataSource(SimplePath(rawPath), sourceType, timeWindowParams, timePartitionPattern) - def apply(inputLocation: InputLocation, + def apply(inputLocation: DataLocation, sourceType: SourceFormatType): DataSource = DataSource(inputLocation, sourceType, None, None) } \ No newline at end of file diff --git a/src/main/scala/com/linkedin/feathr/offline/source/accessor/NonTimeBasedDataSourceAccessor.scala b/src/main/scala/com/linkedin/feathr/offline/source/accessor/NonTimeBasedDataSourceAccessor.scala index 91c6e5665..385a0a833 100644 --- 
a/src/main/scala/com/linkedin/feathr/offline/source/accessor/NonTimeBasedDataSourceAccessor.scala +++ b/src/main/scala/com/linkedin/feathr/offline/source/accessor/NonTimeBasedDataSourceAccessor.scala @@ -1,6 +1,6 @@ package com.linkedin.feathr.offline.source.accessor -import com.linkedin.feathr.offline.config.location.{Jdbc, KafkaEndpoint, PathList, SimplePath} +import com.linkedin.feathr.offline.config.location.{GenericLocation, Jdbc, KafkaEndpoint, PathList, SimplePath} import com.linkedin.feathr.offline.source.DataSource import com.linkedin.feathr.offline.source.dataloader.DataLoaderFactory import com.linkedin.feathr.offline.testfwk.TestFwkUtils @@ -30,7 +30,8 @@ private[offline] class NonTimeBasedDataSourceAccessor( val df = source.location match { case SimplePath(path) => List(path).map(fileLoaderFactory.create(_).loadDataFrame()).reduce((x, y) => x.fuzzyUnion(y)) case PathList(paths) => paths.map(fileLoaderFactory.create(_).loadDataFrame()).reduce((x, y) => x.fuzzyUnion(y)) - case Jdbc(_, _, _, _, _, _, _) => source.location.loadDf(SparkSession.builder().getOrCreate()) + case Jdbc(_, _, _, _, _) => source.location.loadDf(SparkSession.builder().getOrCreate()) + case GenericLocation(_, _) => source.location.loadDf(SparkSession.builder().getOrCreate()) case _ => fileLoaderFactory.createFromLocation(source.location).loadDataFrame() } diff --git a/src/main/scala/com/linkedin/feathr/offline/source/dataloader/BatchDataLoader.scala b/src/main/scala/com/linkedin/feathr/offline/source/dataloader/BatchDataLoader.scala index db430fb4a..63fe275a0 100644 --- a/src/main/scala/com/linkedin/feathr/offline/source/dataloader/BatchDataLoader.scala +++ b/src/main/scala/com/linkedin/feathr/offline/source/dataloader/BatchDataLoader.scala @@ -1,7 +1,7 @@ package com.linkedin.feathr.offline.source.dataloader import com.linkedin.feathr.common.exception.{ErrorLabel, FeathrInputDataException} -import com.linkedin.feathr.offline.config.location.InputLocation +import com.linkedin.feathr.offline.config.location.DataLocation import com.linkedin.feathr.offline.generation.SparkIOUtils import com.linkedin.feathr.offline.job.DataSourceUtils.getSchemaFromAvroDataFile import com.linkedin.feathr.offline.source.dataloader.jdbc.JdbcUtils @@ -16,7 +16,7 @@ import org.apache.spark.sql.{DataFrame, SparkSession} * @param ss the spark session * @param path input data path */ -private[offline] class BatchDataLoader(ss: SparkSession, location: InputLocation, dataLoaderHandlers: List[DataLoaderHandler]) extends DataLoader { +private[offline] class BatchDataLoader(ss: SparkSession, location: DataLocation, dataLoaderHandlers: List[DataLoaderHandler]) extends DataLoader { /** * get the schema of the source. 
It's only used in the deprecated DataSource.getDataSetAndSchema diff --git a/src/main/scala/com/linkedin/feathr/offline/source/dataloader/DataLoaderFactory.scala b/src/main/scala/com/linkedin/feathr/offline/source/dataloader/DataLoaderFactory.scala index d3b94ebfc..057be7e9b 100644 --- a/src/main/scala/com/linkedin/feathr/offline/source/dataloader/DataLoaderFactory.scala +++ b/src/main/scala/com/linkedin/feathr/offline/source/dataloader/DataLoaderFactory.scala @@ -1,6 +1,6 @@ package com.linkedin.feathr.offline.source.dataloader -import com.linkedin.feathr.offline.config.location.InputLocation +import com.linkedin.feathr.offline.config.location.DataLocation import com.linkedin.feathr.offline.source.dataloader.DataLoaderHandler import org.apache.log4j.Logger import org.apache.spark.customized.CustomGenericRowWithSchema @@ -21,7 +21,7 @@ private[offline] trait DataLoaderFactory { */ def create(path: String): DataLoader - def createFromLocation(input: InputLocation): DataLoader = create(input.getPath) + def createFromLocation(input: DataLocation): DataLoader = create(input.getPath) } private[offline] object DataLoaderFactory { diff --git a/src/main/scala/com/linkedin/feathr/offline/source/dataloader/LocalDataLoaderFactory.scala b/src/main/scala/com/linkedin/feathr/offline/source/dataloader/LocalDataLoaderFactory.scala index 4d2fe505f..c6f007333 100644 --- a/src/main/scala/com/linkedin/feathr/offline/source/dataloader/LocalDataLoaderFactory.scala +++ b/src/main/scala/com/linkedin/feathr/offline/source/dataloader/LocalDataLoaderFactory.scala @@ -1,6 +1,6 @@ package com.linkedin.feathr.offline.source.dataloader -import com.linkedin.feathr.offline.config.location.{InputLocation, KafkaEndpoint, SimplePath} +import com.linkedin.feathr.offline.config.location.{DataLocation, KafkaEndpoint, SimplePath} import com.linkedin.feathr.offline.source.dataloader.stream.KafkaDataLoader import com.linkedin.feathr.offline.source.dataloader.DataLoaderHandler import com.linkedin.feathr.offline.util.LocalFeatureJoinUtils @@ -58,7 +58,7 @@ dataLoaderHandlers: List[DataLoaderHandler]) extends DataLoaderFactory { } } - override def createFromLocation(inputLocation: InputLocation): DataLoader = { + override def createFromLocation(inputLocation: DataLocation): DataLoader = { if (inputLocation.isInstanceOf[KafkaEndpoint]) { new KafkaDataLoader(ss, inputLocation.asInstanceOf[KafkaEndpoint]) } else { diff --git a/src/main/scala/com/linkedin/feathr/offline/source/dataloader/StreamingDataLoaderFactory.scala b/src/main/scala/com/linkedin/feathr/offline/source/dataloader/StreamingDataLoaderFactory.scala index 2eafde55b..279cd04a8 100644 --- a/src/main/scala/com/linkedin/feathr/offline/source/dataloader/StreamingDataLoaderFactory.scala +++ b/src/main/scala/com/linkedin/feathr/offline/source/dataloader/StreamingDataLoaderFactory.scala @@ -1,6 +1,6 @@ package com.linkedin.feathr.offline.source.dataloader -import com.linkedin.feathr.offline.config.location.{InputLocation, KafkaEndpoint} +import com.linkedin.feathr.offline.config.location.{DataLocation, KafkaEndpoint} import com.linkedin.feathr.offline.source.dataloader.stream.KafkaDataLoader import org.apache.spark.sql.SparkSession @@ -17,7 +17,7 @@ private[offline] class StreamingDataLoaderFactory(ss: SparkSession) extends Dat * @param input the input location for streaming * @return a [[DataLoader]] */ - override def createFromLocation(input: InputLocation): DataLoader = new KafkaDataLoader(ss, input.asInstanceOf[KafkaEndpoint]) + override def createFromLocation(input: 
DataLocation): DataLoader = new KafkaDataLoader(ss, input.asInstanceOf[KafkaEndpoint]) /** * create a data loader based on the file type. diff --git a/src/main/scala/com/linkedin/feathr/offline/util/SourceUtils.scala b/src/main/scala/com/linkedin/feathr/offline/util/SourceUtils.scala index 03701a95d..e9d3b1335 100644 --- a/src/main/scala/com/linkedin/feathr/offline/util/SourceUtils.scala +++ b/src/main/scala/com/linkedin/feathr/offline/util/SourceUtils.scala @@ -7,7 +7,7 @@ import com.jasonclawson.jackson.dataformat.hocon.HoconFactory import com.linkedin.feathr.common.exception._ import com.linkedin.feathr.common.{AnchorExtractor, DateParam} import com.linkedin.feathr.offline.client.InputData -import com.linkedin.feathr.offline.config.location.{InputLocation, SimplePath} +import com.linkedin.feathr.offline.config.location.{DataLocation, SimplePath} import com.linkedin.feathr.offline.generation.SparkIOUtils import com.linkedin.feathr.offline.mvel.{MvelContext, MvelUtils} import com.linkedin.feathr.offline.source.SourceFormatType @@ -175,7 +175,7 @@ private[offline] object SourceUtils { def safeWriteDF(df: DataFrame, dataPath: String, parameters: Map[String, String], dataLoaderHandlers: List[DataLoaderHandler]): Unit = { val tempBasePath = dataPath.stripSuffix("/") + "_temp_" HdfsUtils.deletePath(dataPath, true) - SparkIOUtils.writeDataFrame(df, tempBasePath, parameters, dataLoaderHandlers) + SparkIOUtils.writeDataFrame(df, SimplePath(tempBasePath), parameters, dataLoaderHandlers) if (HdfsUtils.exists(tempBasePath) && !HdfsUtils.renamePath(tempBasePath, dataPath)) { throw new FeathrDataOutputException( ErrorLabel.FEATHR_ERROR, @@ -644,8 +644,8 @@ private[offline] object SourceUtils { * @param inputPath * @return */ - def loadAsDataFrame(ss: SparkSession, location: InputLocation, - dataLoaderHandlers: List[DataLoaderHandler]): DataFrame = { + def loadAsDataFrame(ss: SparkSession, location: DataLocation, + dataLoaderHandlers: List[DataLoaderHandler]): DataFrame = { val sparkConf = ss.sparkContext.getConf val inputSplitSize = sparkConf.get("spark.feathr.input.split.size", "") val dataIOParameters = Map(SparkIOUtils.SPLIT_SIZE -> inputSplitSize) diff --git a/src/test/scala/com/linkedin/feathr/offline/AnchoredFeaturesIntegTest.scala b/src/test/scala/com/linkedin/feathr/offline/AnchoredFeaturesIntegTest.scala index 8ef0d0a7e..3ca387b55 100644 --- a/src/test/scala/com/linkedin/feathr/offline/AnchoredFeaturesIntegTest.scala +++ b/src/test/scala/com/linkedin/feathr/offline/AnchoredFeaturesIntegTest.scala @@ -2,6 +2,7 @@ package com.linkedin.feathr.offline import com.linkedin.feathr.common.configObj.configbuilder.ConfigBuilderException import com.linkedin.feathr.common.exception.FeathrConfigException +import com.linkedin.feathr.offline.config.location.SimplePath import com.linkedin.feathr.offline.generation.SparkIOUtils import com.linkedin.feathr.offline.job.PreprocessedDataFrameManager import com.linkedin.feathr.offline.mvel.plugins.FeathrMvelPluginContext @@ -144,7 +145,7 @@ class AnchoredFeaturesIntegTest extends FeathrIntegTest { // create a data source from anchorAndDerivations/nullValueSource.avro.json val df = new AvroJsonDataLoader(ss, "nullValueSource.avro.json").loadDataFrame() - SparkIOUtils.writeDataFrame(df, mockDataFolder + "/nullValueSource", parameters=Map(), dataLoaderHandlers=List()) + SparkIOUtils.writeDataFrame(df, SimplePath(mockDataFolder + "/nullValueSource"), parameters=Map(), dataLoaderHandlers=List()) } /** diff --git 
a/src/test/scala/com/linkedin/feathr/offline/config/TestDataSourceLoader.scala b/src/test/scala/com/linkedin/feathr/offline/config/TestDataSourceLoader.scala index 762d3fd77..585e2eab6 100644 --- a/src/test/scala/com/linkedin/feathr/offline/config/TestDataSourceLoader.scala +++ b/src/test/scala/com/linkedin/feathr/offline/config/TestDataSourceLoader.scala @@ -54,7 +54,7 @@ class TestDataSourceLoader extends FunSuite { |""".stripMargin val ds = jackson.readValue(configDoc, classOf[DataSource]) ds.location match { - case Jdbc(url, dbtable, user, password, token, useToken, _) => { + case Jdbc(url, dbtable, user, password, token) => { assert(url == "jdbc:sqlserver://myserver.database.windows.net:1433;database=mydatabase") assert(user=="bar") assert(password=="foo") diff --git a/src/test/scala/com/linkedin/feathr/offline/config/location/TestDesLocation.scala b/src/test/scala/com/linkedin/feathr/offline/config/location/TestDesLocation.scala index 5d5f80998..1be2adf77 100644 --- a/src/test/scala/com/linkedin/feathr/offline/config/location/TestDesLocation.scala +++ b/src/test/scala/com/linkedin/feathr/offline/config/location/TestDesLocation.scala @@ -7,7 +7,7 @@ import com.jasonclawson.jackson.dataformat.hocon.HoconFactory import com.linkedin.feathr.common.FeathrJacksonScalaModule import com.linkedin.feathr.offline.config.DataSourceLoader import com.linkedin.feathr.offline.config.location.LocationUtils.envSubstitute -import com.linkedin.feathr.offline.config.location.{InputLocation, Jdbc, SimplePath} +import com.linkedin.feathr.offline.config.location.{DataLocation, Jdbc, SimplePath} import com.linkedin.feathr.offline.generation.SparkIOUtils import com.linkedin.feathr.offline.source.DataSource import org.apache.spark.{SparkConf, SparkContext} @@ -32,7 +32,7 @@ class TestDesLocation extends FunSuite { test("Deserialize Location") { { val configDoc = """{ path: "abfss://feathrazuretest3fs@feathrazuretest3storage.dfs.core.windows.net/demo_data/green_tripdata_2020-04.csv" }""" - val ds = jackson.readValue(configDoc, classOf[InputLocation]) + val ds = jackson.readValue(configDoc, classOf[DataLocation]) ds match { case SimplePath(path) => { assert(path == "abfss://feathrazuretest3fs@feathrazuretest3storage.dfs.core.windows.net/demo_data/green_tripdata_2020-04.csv") @@ -50,9 +50,9 @@ class TestDesLocation extends FunSuite { | user: "bar" | password: "foo" |}""".stripMargin - val ds = jackson.readValue(configDoc, classOf[InputLocation]) + val ds = jackson.readValue(configDoc, classOf[DataLocation]) ds match { - case Jdbc(url, dbtable, user, password, token, useToken, _) => { + case Jdbc(url, dbtable, user, password, token) => { assert(url == "jdbc:sqlserver://myserver.database.windows.net:1433;database=mydatabase") assert(user == "bar") assert(password == "foo") @@ -69,7 +69,7 @@ class TestDesLocation extends FunSuite { | type: "pathlist" | paths: ["abfss://feathrazuretest3fs@feathrazuretest3storage.dfs.core.windows.net/demo_data/green_tripdata_2020-04.csv"] |}""".stripMargin - val ds = jackson.readValue(configDoc, classOf[InputLocation]) + val ds = jackson.readValue(configDoc, classOf[DataLocation]) ds match { case PathList(pathList) => { assert(pathList == List("abfss://feathrazuretest3fs@feathrazuretest3storage.dfs.core.windows.net/demo_data/green_tripdata_2020-04.csv")) @@ -88,7 +88,7 @@ class TestDesLocation extends FunSuite { | dbtable: "table1" | anonymous: true |}""".stripMargin - val ds = jackson.readValue(configDoc, classOf[InputLocation]) + val ds = jackson.readValue(configDoc, 
classOf[DataLocation]) val _ = SparkSession.builder().config("spark.master", "local").appName("Sqlite test").getOrCreate() @@ -112,7 +112,31 @@ class TestDesLocation extends FunSuite { | query: "select c1, c2 from table1" | anonymous: true |}""".stripMargin - val ds = jackson.readValue(configDoc, classOf[InputLocation]) + val ds = jackson.readValue(configDoc, classOf[DataLocation]) + + val _ = SparkSession.builder().config("spark.master", "local").appName("Sqlite test").getOrCreate() + + val df = SparkIOUtils.createDataFrame(ds, Map(), new JobConf(), List()) + val rows = df.head(3) + assert(rows(0).getLong(0) == 1) + assert(rows(1).getLong(0) == 2) + assert(rows(2).getLong(0) == 3) + assert(rows(0).getString(1) == "r1c2") + assert(rows(1).getString(1) == "r2c2") + assert(rows(2).getString(1) == "r3c2") + } + + test("Test GenericLocation load Sqlite with SQL query") { + val path = s"${System.getProperty("user.dir")}/src/test/resources/mockdata/sqlite/test.db" + val configDoc = + s""" + |{ + | type: "generic" + | format: "jdbc" + | url: "jdbc:sqlite:${path}" + | query: "select c1, c2 from table1" + |}""".stripMargin + val ds = jackson.readValue(configDoc, classOf[DataLocation]) val _ = SparkSession.builder().config("spark.master", "local").appName("Sqlite test").getOrCreate()
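In the same style as the tests above, a hedged example of declaring a generic write-side location, assuming GenericLocation is imported alongside DataLocation; the type and format keys are real, while the endpoint, key, database and container values are placeholders (GenericLocation.setOption rewrites double underscores to dots, and keys prefixed with __conf__ go to the Spark conf):

    test("Deserialize GenericLocation for a cosmos.oltp sink") {
      val configDoc =
        """
          |{
          |  type: "generic"
          |  format: "cosmos.oltp"
          |  spark__cosmos__accountEndpoint: "https://example-cosmos.documents.azure.com:443/"
          |  spark__cosmos__accountKey: "${COSMOS_KEY}"
          |  spark__cosmos__database: "feathr"
          |  spark__cosmos__container: "features"
          |}""".stripMargin
      val ds = jackson.readValue(configDoc, classOf[DataLocation])
      ds match {
        case g: GenericLocation =>
          assert(g.format == "cosmos.oltp")
          // "__" has already been rewritten to "." by setOption
          assert(g.options("spark.cosmos.database") == "feathr")
        case _ => fail("Expected a GenericLocation")
      }
    }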