It can be done more efficiently, as long as you are fine with approximations and don't require exact results (or an exact number of results).
Similarly to my answer to "Efficient string matching in Apache Spark", you can use LSH.
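As a rough sketch of the LSH route, the snippet below uses Spark ML's `MinHashLSH` with Jaccard distance. Which LSH variant fits depends on your distance measure, so treat the choice of `MinHashLSH`, the number of hash tables (5), the distance threshold (0.6), and the names `records` and `pairs` as assumptions made for illustration, not as a recommendation.

```scala
import org.apache.spark.ml.feature.MinHashLSH
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.col

val spark = SparkSession.builder.master("local[*]").appName("lsh-sketch").getOrCreate()
import spark.implicits._

// Toy binary feature vectors (illustrative data only)
val records = Seq(
  (1L, Vectors.sparse(1024, Array(1, 3, 5), Array(1.0, 1.0, 1.0))),
  (2L, Vectors.sparse(1024, Array(3, 8, 12), Array(1.0, 1.0, 1.0))),
  (3L, Vectors.sparse(1024, Array(3, 5), Array(1.0, 1.0)))
).toDF("id", "features")

// Fit the hashing model; more hash tables reduce false negatives at higher cost
val model = new MinHashLSH()
  .setNumHashTables(5) // assumed value, tune for your data
  .setInputCol("features")
  .setOutputCol("hashes")
  .fit(records)

// Approximate self-join: keeps candidate pairs whose Jaccard distance is below 0.6
val pairs = model
  .approxSimilarityJoin(records, records, 0.6, "JaccardDistance")
  .filter(col("datasetA.id") < col("datasetB.id")) // drop self-pairs and mirrored pairs
  .select(col("datasetA.id").as("idA"), col("datasetB.id").as("idB"), col("JaccardDistance"))

pairs.show()
```

Because the join is approximate, some pairs below the threshold may be missed; that is the trade-off for avoiding the full Cartesian product.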
If the feature space is small (or can be reasonably reduced) and each category is relatively small, you can also optimize your code by hand:
- explode the feature array to generate #features records from a single record.
- Self-join the result by feature, compute the distance, and filter out candidates (each pair of records will be compared if and only if they share a specific categorical feature).
- Take top records using your current code.
A minimal example would be (treat it as pseudocode):
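import org.apache.spark.ml.linalg._
import org.apache.spark.sql.functions.{explode, udf}
// Assumes a spark-shell session, where spark.implicits._ is already in scope for $"..." and toDF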
// This is oversimplified. In practice don't assume only sparse scenario
val indices = udf((v: SparseVector) => v.indices)
val df = Seq(
(1L, Vectors.sparse(1024, Array(1, 3, 5), Array(1.0, 1.0, 1.0))),
(2L, Vectors.sparse(1024, Array(3, 8, 12), Array(1.0, 1.0, 1.0))),
(3L, Vectors.sparse(1024, Array(3, 5), Array(1.0, 1.0))),
(4L, Vectors.sparse(1024, Array(11, 21), Array(1.0, 1.0))),
(5L, Vectors.sparse(1024, Array(21, 32), Array(1.0, 1.0)))
).toDF("id", "features")
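// One row per active feature index, then a self-join on that index
val possibleMatches = df
  .withColumn("key", explode(indices($"features")))
  .transform(df => df.alias("left").join(df.alias("right"), Seq("key")))
// intersectionCosine is not defined here; supply your own similarity function
def closeEnough(threshold: Double) = udf((v1: SparseVector, v2: SparseVector) => intersectionCosine(v1, v2) > threshold)
val threshold = 0.5 // placeholder value, tune for your data
possibleMatches.filter(closeEnough(threshold)($"left.features", $"right.features")).select($"left.id", $"right.id").distinct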
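The snippet above calls `intersectionCosine` without defining it. As a stand-in, here is a plain cosine similarity over sparse index/value arrays. The name `sparseCosine` is hypothetical, and ordinary cosine similarity is an assumption; substitute whatever measure your own code uses.

```scala
// Hypothetical stand-in: cosine similarity for two sparse vectors given as
// (indices, values) pairs. Indices are assumed to be sorted in ascending order.
def sparseCosine(
    idx1: Array[Int], val1: Array[Double],
    idx2: Array[Int], val2: Array[Double]): Double = {
  var i = 0
  var j = 0
  var dot = 0.0
  // Merge-walk over both index arrays; only shared indices add to the dot product
  while (i < idx1.length && j < idx2.length) {
    if (idx1(i) == idx2(j)) {
      dot += val1(i) * val2(j)
      i += 1
      j += 1
    } else if (idx1(i) < idx2(j)) {
      i += 1
    } else {
      j += 1
    }
  }
  val norm1 = math.sqrt(val1.map(x => x * x).sum)
  val norm2 = math.sqrt(val2.map(x => x * x).sum)
  // Treat similarity with an all-zero vector as 0 instead of dividing by zero
  if (norm1 == 0.0 || norm2 == 0.0) 0.0 else dot / (norm1 * norm2)
}
```

With Spark's `SparseVector` you could call it as `sparseCosine(v1.indices, v1.values, v2.indices, v2.values)` inside the UDF.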
Note that both solutions are worth the overhead only if the hashing / features are selective enough (and ideally sparse). In the example shown above you'd compare only rows inside the sets {1, 2, 3} and {4, 5}, never between sets.
However, in the worst case scenario (M records, N features) we can make N M² comparisons, instead of M².
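To check whether the join is selective enough before running it, you can estimate the number of rows it will produce from the per-feature bucket sizes. The sketch below does this in plain Scala for the toy data used earlier; the names `records`, `bucketSizes`, `joinRows`, and `crossRows` are illustrative only.

```scala
// Same toy data as in the example, reduced to id -> active feature indices
val records: Map[Long, Seq[Int]] = Map(
  1L -> Seq(1, 3, 5),
  2L -> Seq(3, 8, 12),
  3L -> Seq(3, 5),
  4L -> Seq(11, 21),
  5L -> Seq(21, 32)
)

// Number of records that share each feature index (the size of each join bucket)
val bucketSizes: Map[Int, Long] = records.values.flatten
  .groupBy(identity)
  .map { case (feature, hits) => feature -> hits.size.toLong }

// A self-join on the feature index emits bucketSize^2 rows per bucket
val joinRows: Long = bucketSizes.values.map(n => n * n).sum

// A full cross join emits M^2 rows
val crossRows: Long = records.size.toLong * records.size

println(s"join rows: $joinRows, cross join rows: $crossRows")
```

For this toy data the join emits 22 rows against 25 for the cross join, so the saving is small. The gain grows when buckets are small relative to M, and it disappears when every record carries every feature: each of the N buckets then holds all M records, which gives the N M² worst case.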