Here is a generic solution to this problem that makes it possible to update any number of nested values, at any level, based on an arbitrary function applied in a recursive traversal:
import org.apache.spark.sql.{Column, DataFrame}
import org.apache.spark.sql.functions.{col, struct}
import org.apache.spark.sql.types.StructType

def mutate(df: DataFrame, fn: Column => Column): DataFrame = {
  // Build a projection in which every leaf field has been passed through `fn`,
  // then rebuild the frame from its rows with the original schema reapplied
  // (explained below)
  df.sqlContext.createDataFrame(df.select(traverse(df.schema, fn): _*).rdd, df.schema)
}

def traverse(schema: StructType, fn: Column => Column, path: String = ""): Array[Column] = {
  schema.fields.map(f => {
    f.dataType match {
      // Nested struct: recurse and re-nest the transformed children
      case s: StructType => struct(traverse(s, fn, path + f.name + "."): _*)
      // Leaf field: apply `fn` to the fully qualified column
      case _ => fn(col(path + f.name))
    }
  })
}
This is effectively equivalent to the usual "just redefine the whole struct as a projection" solutions, but it automates re-nesting the fields into the original structure and preserves nullability and metadata, which are lost when you redefine the structs manually. Annoyingly, preserving those properties doesn't seem to be possible while creating the projection (as far as I can tell), so the code above reapplies the original schema afterwards via createDataFrame.
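To make the nullability/metadata point concrete, here is a minimal sketch that rebuilds a struct by hand and compares schemas. The session setup (`local[1]`), the `organ` schema and the `unit` metadata key are made up for illustration; it is written as a script, so it assumes spark-shell or a similar REPL.

```scala
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.functions.{col, struct}
import org.apache.spark.sql.types._

// Assumption: a local session; in spark-shell this returns the existing one
val spark = SparkSession.builder().master("local[1]").appName("schema-loss-sketch").getOrCreate()

// Hypothetical schema: a nullable struct whose `count` field carries metadata
val meta = new MetadataBuilder().putString("unit", "items").build()
val schema = StructType(Seq(
  StructField("organ", StructType(Seq(
    StructField("name", StringType, nullable = true),
    StructField("count", IntegerType, nullable = true, meta)
  )), nullable = true)
))
val rows = java.util.Arrays.asList(Row(Row("heart", 2)), Row(Row("eye", 3)))
val df = spark.createDataFrame(rows, schema)

// The "redefine the whole struct as a projection" approach, done by hand
val projected = df.select(
  struct(col("organ.name").as("name"), (col("organ.count") - 1).as("count")).as("organ")
)

// Print both schemas to compare nullability; StructField equality also covers metadata
df.printSchema()
projected.printSchema()
println(projected.schema == df.schema)

// Reapplying the original schema to the projected rows restores both properties
val restored = spark.createDataFrame(projected.rdd, df.schema)
println(restored.schema == df.schema)
```

The last step is the same trick `mutate` uses: the projection computes the values, and the original schema is laid back over the resulting rows.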
An example application:
// Needed for toDF; assumes a SparkSession named `spark`, as in spark-shell
import spark.implicits._

case class Organ(name: String, count: Int)
case class Disease(id: Int, name: String, organ: Organ)
case class Drug(id: Int, name: String, alt: Array[String])

val df = Seq(
  (1, Drug(1, "drug1", Array("x", "y")), Disease(1, "disease1", Organ("heart", 2))),
  (2, Drug(2, "drug2", Array("a")), Disease(2, "disease2", Organ("eye", 3)))
).toDF("id", "drug", "disease")
df.show(false)
+---+------------------+-------------------------+
|id |drug              |disease                  |
+---+------------------+-------------------------+
|1  |[1, drug1, [x, y]]|[1, disease1, [heart, 2]]|
|2  |[2, drug2, [a]]   |[2, disease2, [eye, 3]]  |
+---+------------------+-------------------------+
// Update the integer field ("count") at the lowest level:
val df2 = mutate(df, c => if (c.toString == "disease.organ.count") c - 1 else c)
df2.show(false)
+---+------------------+-------------------------+
|id |drug              |disease                  |
+---+------------------+-------------------------+
|1  |[1, drug1, [x, y]]|[1, disease1, [heart, 1]]|
|2  |[2, drug2, [a]]   |[2, disease2, [eye, 2]]  |
+---+------------------+-------------------------+
// The schemas will NOT necessarily be equal unless the metadata and nullability
// of all fields are preserved (as the code above does).
// assertResult(expected)(actual) comes from ScalaTest.
assertResult(df.schema.toString)(df2.schema.toString)
A limitation of this approach is that it cannot add new fields, only update existing ones. If you don't care about preserving nullability/metadata, you can work around that by changing the map into a flatMap and having the function return Array[Column].
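A rough sketch of that flatMap variant, under a few assumptions: `traverseFlat`, `addFn` and the `count_x10` field are names invented here, the function returns Seq[Column] rather than Array[Column] to keep type inference simple, and the session setup assumes spark-shell or a similar REPL. Since the original schema is no longer reapplied, structs are aliased explicitly and the function has to name any column it derives.

```scala
import org.apache.spark.sql.{Column, SparkSession}
import org.apache.spark.sql.functions.{col, struct}
import org.apache.spark.sql.types.StructType

// Assumption: a local session; in spark-shell this returns the existing one
val spark = SparkSession.builder().master("local[1]").appName("flatmap-sketch").getOrCreate()
import spark.implicits._

case class Organ(name: String, count: Int)
case class Disease(id: Int, name: String, organ: Organ)

// Hypothetical variant of `traverse`: `fn` may return any number of columns per leaf
def traverseFlat(schema: StructType, fn: Column => Seq[Column], path: String = ""): Array[Column] =
  schema.fields.toSeq.flatMap { f =>
    f.dataType match {
      // No schema is reapplied afterwards, so each struct is named explicitly
      case s: StructType => Seq(struct(traverseFlat(s, fn, path + f.name + "."): _*).as(f.name))
      case _ => fn(col(path + f.name))
    }
  }.toArray

val df = Seq(
  (1, Disease(1, "disease1", Organ("heart", 2))),
  (2, Disease(2, "disease2", Organ("eye", 3)))
).toDF("id", "disease")

// Keep `count` as it is and add a derived sibling field next to it
val addFn: Column => Seq[Column] = c =>
  if (c.toString == "disease.organ.count") Seq(c, (c * 10).as("count_x10")) else Seq(c)

val df3 = df.select(traverseFlat(df.schema, addFn): _*)
df3.printSchema()
df3.show(false)
```

The resulting schema is whatever the projection produces, so nullability and metadata may differ from the input frame, which is the trade-off mentioned above.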
Additionally, here is a more generic version for Dataset[T]:
import org.apache.spark.sql.{Column, Dataset, Encoder, Row}
import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder}

case class Record(id: Int, drug: Drug, disease: Disease)

def mutateDS[T](df: Dataset[T], fn: Column => Column)(implicit enc: Encoder[T]): Dataset[T] = {
  // Same approach as `mutate`, but the schema comes from the encoder for T
  df.sqlContext.createDataFrame(df.select(traverse(df.schema, fn): _*).rdd, enc.schema).as[T]
}

// To call as typed dataset:
val fn: Column => Column = c => if (c.toString == "disease.organ.count") c - 1 else c
mutateDS(df.as[Record], fn).show(false)

// To call as untyped dataset:
// The explicit Row encoder is necessary regardless of sparkSession.implicits._ imports
implicit val encoder: ExpressionEncoder[Row] = RowEncoder(df.schema)
mutateDS(df, fn).show(false)