Simplest Approach (requires Spark 2.0.1+; not an exact median)
As noted in the comments in reference to the first question, Find median in Spark SQL for double datatype columns, we can use percentile_approx to calculate the median in Spark 2.0.1+. Applied to grouped data in Apache Spark, the query would look like:
val df = Seq(("A", 0.0), ("A", 0.0), ("A", 1.0), ("A", 1.0), ("A", 1.0), ("A", 1.0), ("B", 0.0), ("B", 1.0), ("B", 1.0)).toDF("id", "num")
df.createOrReplaceTempView("df")
spark.sql("select id, percentile_approx(num, 0.5) as median from df group by id order by id").show()
with the output being:
+---+------+
| id|median|
+---+------+
|  A|   1.0|
|  B|   1.0|
+---+------+
That said, this is an approximate value (as opposed to the exact median the question asks for).
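If you'd rather stay in the DataFrame API than register a temp view, here's a minimal sketch of the equivalent call, reusing the same df as above:

import org.apache.spark.sql.functions.expr

// Same approximate median per id, expressed with groupBy/agg instead of SQL.
// percentile_approx also accepts an optional third "accuracy" argument if you
// need to trade memory for a tighter estimate.
df.groupBy("id")
  .agg(expr("percentile_approx(num, 0.5)").as("median"))
  .orderBy("id")
  .show()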
Calculate exact median for grouped data
There are multiple approaches, so I'm sure others on SO can provide better or more efficient examples. But here's a code snippet to calculate the median for grouped data in Spark (verified in Spark 1.6 and Spark 2.1):
import org.apache.spark.rdd.RDD
import spark.implicits._  // needed for toDF on an RDD (use sqlContext.implicits._ on Spark 1.6)

val rdd: RDD[(String, Double)] = sc.parallelize(Seq(("A", 1.0), ("A", 0.0), ("A", 1.0), ("A", 1.0), ("A", 0.0), ("A", 1.0), ("B", 0.0), ("B", 1.0), ("B", 1.0)))

// Scala median function (expects an already sorted list)
def median(inputList: List[Double]): Double = {
  val count = inputList.size
  if (count % 2 == 0) {
    // even count: average the two middle values
    val l = count / 2 - 1
    val r = l + 1
    (inputList(l) + inputList(r)) / 2
  } else
    // odd count: take the middle value
    inputList(count / 2)
}

// Group the values by key and sort each key's list
val setRDD = rdd.groupByKey()
val sortedListRDD = setRDD.mapValues(_.toList.sorted)

// Output DataFrame of id and median
sortedListRDD.map(m => {
  (m._1, median(m._2))
}).toDF("id", "median_of_num").show()
with the output being:
+---+-------------+
| id|median_of_num|
+---+-------------+
|  A|          1.0|
|  B|          1.0|
+---+-------------+
There are some caveats that I should call out as this likely isn't the most efficient implementation:
- It's currently using a groupByKey, which is not very performant. You may want to change this into a reduceByKey instead (more information at Avoid GroupByKey); see the sketch after this list.
- It uses a Scala function to calculate the median, which means the full list of values for each key has to be materialized and sorted in memory on a single executor.
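As a rough sketch of the first caveat's suggestion: because building a list changes the value type, the closest drop-in relative of reduceByKey here is aggregateByKey, which still collects every value per key but lets Spark combine partial lists map-side before the shuffle. This is illustrative only and reuses the rdd and median definitions above:

// Sketch only: build the sorted per-key lists with aggregateByKey instead of groupByKey.
// Every value for a key still ends up in a single in-memory list, so the second caveat remains.
val sortedListRDD2 = rdd
  .aggregateByKey(List.empty[Double])(
    (acc, v) => v :: acc,           // add a value within a partition
    (acc1, acc2) => acc1 ::: acc2   // merge partial lists across partitions
  )
  .mapValues(_.sorted)

sortedListRDD2.map(m => (m._1, median(m._2))).toDF("id", "median_of_num").show()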
This approach should work fine for smaller amounts of data, but if you have millions of rows for each key, I would advise utilizing Spark 2.0.1+ and the percentile_approx approach.