I solved this with a global `groupBy()`, meaning a `groupBy` with no grouping columns, so it aggregates over the whole DataFrame. This works for both numeric and non-numeric columns:
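```scala
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.{col, when}

case class Entry(id: Long, name: String, value: java.lang.Float)
val results = Seq(
  Entry(10, null, null),
  Entry(10, null, null),
  Entry(20, null, null)
)
// `spark` is the SparkSession provided by spark-shell
val df: DataFrame = spark.createDataFrame(results)

// mark all-null columns: 1 = non-null, 0 = null, then a global max per column
val row = df
  .select(df.columns.map(c => when(col(c).isNull, 0).otherwise(1).as(c)): _*)
  .groupBy().max(df.columns: _*)
  .first

// and filter the columns out: keep those whose max is 1 (at least one non-null value)
val colKeep = row.getValuesMap[Int](row.schema.fieldNames)
  .map{c => if (c._2 == 1) Some(c._1) else None }
  .flatten.toArray
// the aggregated names look like "max(id)", so strip "max(" and ")" to recover the original name
df.select(row.schema.fieldNames.intersect(colKeep)
  .map(c => col(c.drop(4).dropRight(1))): _*).show(false)
```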
```
+---+
|id |
+---+
|10 |
|10 |
|20 |
+---+
```
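The `c.drop(4).dropRight(1)` step depends on Spark naming the aggregated columns `max(id)`, `max(name)` and so on. A minimal sketch, assuming it runs in spark-shell (Spark 2.x or later), that keeps the same 0/1 marking but aggregates with `max(...).as(c)` directly inside a `select`. A `select` made only of aggregate expressions is itself a global aggregation, and the alias keeps the original column names, so no string stripping is needed. `flags` and `nonNullCols` are names chosen here for illustration.

```scala
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.functions.{col, max, when}

case class Entry(id: Long, name: String, value: java.lang.Float)

// In spark-shell, getOrCreate() simply returns the existing session
val spark = SparkSession.builder().master("local[*]").appName("drop-null-cols").getOrCreate()

val df: DataFrame = spark.createDataFrame(Seq(
  Entry(10, null, null),
  Entry(10, null, null),
  Entry(20, null, null)
))

// One global aggregation: 1 if the column has at least one non-null value, else 0.
// Aliasing each max back to c keeps the original column names in the result row.
val flags = df
  .select(df.columns.map(c => max(when(col(c).isNull, 0).otherwise(1)).as(c)): _*)
  .first

// Iterate over df.columns so the kept columns stay in their original order
val nonNullCols = df.columns.filter(c => flags.getAs[Int](c) == 1)

df.select(nonNullCols.map(col): _*).show(false)
```

For this sample data only `id` survives, because `name` and `value` are null in every row.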
Edit: I removed the column shuffling. The new approach keeps the columns in their original order.
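An equivalent check, swapped in here for comparison rather than taken from the approach above, uses `count(col)`, which counts only non-null values, so a column is all-null exactly when its count is 0. A sketch of a reusable helper, assuming the same spark-shell setup; `dropAllNullColumns` and the sample data are hypothetical. The sample includes a partly-null column to show that it is kept and that the original column order is preserved.

```scala
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.functions.{col, count}

// Hypothetical helper: drops every column that contains only nulls,
// keeping the remaining columns in their original order.
def dropAllNullColumns(df: DataFrame): DataFrame = {
  // count(col(c)) ignores nulls, so an all-null column has a count of 0
  val counts = df.select(df.columns.map(c => count(col(c)).as(c)): _*).first
  val keep = df.columns.filter(c => counts.getAs[Long](c) > 0)
  df.select(keep.map(col): _*)
}

val spark = SparkSession.builder().master("local[*]").getOrCreate()
import spark.implicits._

// name is null everywhere; value is null in only one row
val sample = Seq(
  (10L, Option.empty[String], Option(1.5f)),
  (20L, Option.empty[String], Option.empty[Float])
).toDF("id", "name", "value")

val cleaned = dropAllNullColumns(sample)
cleaned.show(false)
```

Here `name` is dropped, while `value` is kept because it has one non-null entry.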