The solution I suggested in Stratified sampling in Spark
is pretty straightforward to convert from Scala to Python (or even to Java; see What's the easiest way to stratify a Spark Dataset?).
Nevertheless, I'll rewrite it in Python. Let's start by creating a toy DataFrame:
from pyspark.sql.functions import lit
# `rows` instead of `list`, so the Python built-in is not shadowed
rows = [(2147481832,23355149,1),(2147481832,973010692,1),(2147481832,2134870842,1),(2147481832,541023347,1),(2147481832,1682206630,1),(2147481832,1138211459,1),(2147481832,852202566,1),(2147481832,201375938,1),(2147481832,486538879,1),(2147481832,919187908,1),(214748183,919187908,1),(214748183,91187908,1)]
df = spark.createDataFrame(rows, ["x1","x2","x3"])
df.show()
# +----------+----------+---+
# | x1| x2| x3|
# +----------+----------+---+
# |2147481832| 23355149| 1|
# |2147481832| 973010692| 1|
# |2147481832|2134870842| 1|
# |2147481832| 541023347| 1|
# |2147481832|1682206630| 1|
# |2147481832|1138211459| 1|
# |2147481832| 852202566| 1|
# |2147481832| 201375938| 1|
# |2147481832| 486538879| 1|
# |2147481832| 919187908| 1|
# | 214748183| 919187908| 1|
# | 214748183| 91187908| 1|
# +----------+----------+---+
This DataFrame has 12 rows, as you can see:
df.count()
# 12
They are distributed as follows:
df.groupBy("x1").count().show()
# +----------+-----+
# | x1|count|
# +----------+-----+
# |2147481832| 10|
# | 214748183| 2|
# +----------+-----+
Now let's sample. First, we'll set the seed:
seed = 12
Then find the keys to stratify on, assign each one a fraction, and sample:
fractions = df.select("x1").distinct().withColumn("fraction", lit(0.8)).rdd.collectAsMap()
print(fractions)
# {2147481832: 0.8, 214748183: 0.8}
sampled_df = df.stat.sampleBy("x1", fractions, seed)
sampled_df.show()
# +----------+---------+---+
# | x1| x2| x3|
# +----------+---------+---+
# |2147481832| 23355149| 1|
# |2147481832|973010692| 1|
# |2147481832|541023347| 1|
# |2147481832|852202566| 1|
# |2147481832|201375938| 1|
# |2147481832|486538879| 1|
# |2147481832|919187908| 1|
# | 214748183|919187908| 1|
# | 214748183| 91187908| 1|
# +----------+---------+---+
We can now check the content of our sample:
sampled_df.count()
# 9
sampled_df.groupBy("x1").count().show()
# +----------+-----+
# | x1|count|
# +----------+-----+
# |2147481832| 7|
# | 214748183| 2|
# +----------+-----+
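Note that the sample holds 7 of 10 rows for one key and 2 of 2 for the other, not exactly 80% of each. That is expected: sampleBy treats each fraction as a per-row keep probability, so stratum sizes are only approximate. The plain-Python sketch below illustrates the idea. It is not Spark's implementation, `sample_by` is a hypothetical helper written for this example, and its random draws will not match Spark's for the same seed.

```python
import random

def sample_by(rows, key_index, fractions, seed):
    """Plain-Python illustration of per-stratum Bernoulli sampling.

    Each row is kept independently with probability fractions[key];
    keys missing from `fractions` get probability 0.
    """
    rng = random.Random(seed)
    kept = []
    for row in rows:
        # One uniform draw per row, compared with that row's stratum fraction.
        if rng.random() < fractions.get(row[key_index], 0.0):
            kept.append(row)
    return kept

# Same shape as the toy data: 10 rows in one stratum, 2 in the other.
rows = [(2147481832, i, 1) for i in range(10)] + [(214748183, i, 1) for i in range(2)]
fractions = {2147481832: 0.8, 214748183: 0.8}

# Expected (not guaranteed) number of rows per stratum: fraction * stratum size.
expected = {k: f * sum(1 for r in rows if r[0] == k) for k, f in fractions.items()}
print(expected)  # {2147481832: 8.0, 214748183: 1.6}

sample = sample_by(rows, 0, fractions, seed=12)
print(len(sample))  # depends on the seed; need not equal 0.8 * 12
```

With strata this small, the deviation from the expected sizes can be large in relative terms. If you need an exact number of rows per key, a different approach (for example, ranking rows within each key and filtering on the rank) would be required.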