Welcome to OStack Knowledge Sharing Community for programmer and developer-Open, Learning and Share
Welcome To Ask or Share your Answers For Others

pyspark - Avoid performance impact of a single partition mode in Spark window functions

My question is triggered by the use case of calculating the differences between consecutive rows in a Spark DataFrame.

For example, I have:

>>> df.show()
+-----+----------+
|index|      col1|
+-----+----------+
|  0.0|0.58734024|
|  1.0|0.67304325|
|  2.0|0.85154736|
|  3.0| 0.5449719|
+-----+----------+

If I choose to calculate these using "Window" functions, then I can do that like so:

>>> from pyspark.sql import Window
>>> import pyspark.sql.functions as f
>>> winSpec = Window.partitionBy(df.index >= 0).orderBy(df.index.asc())
>>> df.withColumn('diffs_col1', f.lag(df.col1, -1).over(winSpec) - df.col1).show()
+-----+----------+-----------+
|index|      col1| diffs_col1|
+-----+----------+-----------+
|  0.0|0.58734024|0.085703015|
|  1.0|0.67304325| 0.17850411|
|  2.0|0.85154736|-0.30657548|
|  3.0| 0.5449719|       null|
+-----+----------+-----------+
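
As the output shows, the negative offset makes lag look one row ahead, so each diff is "next value minus current value" and the last row has nothing to compare against. A plain-Python sketch of the same arithmetic on the four values above (no Spark involved; the names col1 and diffs are only for illustration):

```python
col1 = [0.58734024, 0.67304325, 0.85154736, 0.5449719]

# Next value minus current value; the last row has no successor, hence None.
diffs = [nxt - cur for cur, nxt in zip(col1, col1[1:])] + [None]

for idx, (value, diff) in enumerate(zip(col1, diffs)):
    print(idx, value, diff)
```

The printed values differ from the Spark output only in the last digits, which is what you would expect if the column above holds single-precision floats.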

Question: I explicitly put all rows of the dataframe into a single partition. What is the performance impact of this, why does it occur, and how can I avoid it? I ask because when I do not specify a partition, I get the following warning:

16/12/24 13:52:27 WARN WindowExec: No Partition Defined for Window operation! Moving all data to a single partition, this can cause serious performance degradation.

1 Answer

In practice, the performance impact will be almost the same as if you omitted the partitionBy clause altogether. All records will be shuffled to a single partition, sorted locally, and iterated sequentially one by one.

The difference is only in the number of partitions created in total. Let's illustrate that with an example using a simple dataset with 10 partitions and 1000 records:

df = spark.range(0, 1000, 1, 10).toDF("index").withColumn("col1", f.randn(42))

If you define a frame without a partitionBy clause

w_unpart = Window.orderBy(f.col("index").asc())

and use it with lag

df_lag_unpart = df.withColumn(
    "diffs_col1", f.lag("col1", 1).over(w_unpart) - f.col("col1")
)

there will be only one partition in total:

df_lag_unpart.rdd.glom().map(len).collect()
[1000]

By comparison, a frame definition with a dummy partitioning key (simplified a bit compared to your code):

w_part = Window.partitionBy(f.lit(0)).orderBy(f.col("index").asc())

will use a number of partitions equal to spark.sql.shuffle.partitions:

spark.conf.set("spark.sql.shuffle.partitions", 11)

df_lag_part = df.withColumn(
    "diffs_col1", f.lag("col1", 1).over(w_part) - f.col("col1")
)

df_lag_part.rdd.glom().count()
11

with only one non-empty partition:

df_lag_part.rdd.glom().filter(lambda x: x).count()
1
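
To see why a constant key leaves all but one partition empty, here is a rough plain-Python model of hash partitioning. It is only a sketch: assign_partition is a hypothetical stand-in, and Spark uses a different hash function internally, but the principle (same key, same partition) is the same.

```python
NUM_PARTITIONS = 11  # mirrors spark.sql.shuffle.partitions in the example above


def assign_partition(key, num_partitions=NUM_PARTITIONS):
    # Hypothetical stand-in for Spark's hash partitioner (Spark hashes differently).
    return hash(key) % num_partitions


records = list(range(1000))

# Every record carries the same constant key, like partitionBy(f.lit(0)).
buckets = [[] for _ in range(NUM_PARTITIONS)]
for rec in records:
    buckets[assign_partition(0)].append(rec)

sizes = [len(bucket) for bucket in buckets]
non_empty = sum(1 for size in sizes if size > 0)
print(len(buckets), non_empty, max(sizes))
```

All 1000 records land in a single bucket, so the other ten partitions exist but do no work.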

Unfortunately, there is no universal solution to this problem in PySpark. This is just an inherent mechanism of the implementation combined with the distributed processing model.

Since the index column is sequential, you could generate an artificial partitioning key with a fixed number of records per block:

rec_per_block = df.count() // int(spark.conf.get("spark.sql.shuffle.partitions"))

df_with_block = df.withColumn(
    "block", (f.col("index") / rec_per_block).cast("int")
)

and use it to define the frame specification:

w_with_block = Window.partitionBy("block").orderBy("index")

df_lag_with_block = df_with_block.withColumn(
    "diffs_col1", f.lag("col1", 1).over(w_with_block) - f.col("col1")
)

This will use the expected number of partitions:

df_lag_with_block.rdd.glom().count()
11

with roughly uniform data distribution (we cannot avoid hash collisions):

df_lag_with_block.rdd.glom().map(len).collect()
[0, 180, 0, 90, 90, 0, 90, 90, 100, 90, 270]

but with a number of gaps on the block boundaries:

df_lag_with_block.where(f.col("diffs_col1").isNull()).count()
12
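
The count of 12 follows from the block arithmetic: 1000 // 11 is 90, and 1000 records in blocks of 90 give 12 blocks, each with one null lag on its first row. A plain-Python check of that arithmetic (no Spark involved; the ceiling-division variant at the end is my own assumption about how one might cap the number of blocks, not part of the approach above):

```python
from collections import Counter

n_records = 1000
n_shuffle_partitions = 11

# Floor division, as in rec_per_block above.
rec_per_block = n_records // n_shuffle_partitions
block_sizes = Counter(index // rec_per_block for index in range(n_records))
print(rec_per_block, len(block_sizes), block_sizes[max(block_sizes)])

# Assumed alternative: ceiling division yields at most n_shuffle_partitions blocks.
ceil_rec_per_block = -(-n_records // n_shuffle_partitions)
ceil_blocks = {index // ceil_rec_per_block for index in range(n_records)}
print(ceil_rec_per_block, len(ceil_blocks))
```

With floor division the last block holds only 10 records. Twelve block keys hashed into 11 partitions also mean at least two blocks must share a partition, which is part of why the sizes above are uneven.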

Since boundaries are easy to compute:

from itertools import chain

boundary_idxs = sorted(chain.from_iterable(
    # Here we depend on sequential identifiers
    # This could be generalized to any monotonically increasing
    # id by taking min and max per block
    (idx - 1, idx) for idx in 
    df_lag_with_block.groupBy("block").min("index")
        .drop("block").rdd.flatMap(lambda x: x)
        .collect()))[2:]  # The first pair (-1, 0) carries no useful information
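
To make the result concrete, this is roughly what the same computation looks like in plain Python for the 1000-record example, assuming rec_per_block is 90 as computed earlier (the block_mins dictionary is a hypothetical stand-in for the groupBy("block").min("index") result):

```python
from itertools import chain

rec_per_block = 90
indexes = range(1000)

# Smallest index in each block, the equivalent of groupBy("block").min("index").
block_mins = {}
for index in indexes:
    block = index // rec_per_block
    block_mins[block] = min(index, block_mins.get(block, index))

# Pair each block start with the row just before it, then drop (-1, 0):
# the very first row has no predecessor in any block.
boundary_idxs = sorted(chain.from_iterable(
    (idx - 1, idx) for idx in block_mins.values()
))[2:]

print(boundary_idxs[:4], boundary_idxs[-2:], len(boundary_idxs))
```

Each remaining pair is the last row of one block followed by the first row of the next, which are exactly the rows needed to fill one gap.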

you can always select:

missing = df_with_block.where(f.col("index").isin(boundary_idxs))

and fill these separately:

# We use window without partitions here. Since number of records
# will be small this won't be a performance issue
# but will generate "Moving all data to a single partition" warning
missing_with_lag = missing.withColumn(
    "diffs_col1", f.lag("col1", 1).over(w_unpart) - f.col("col1")
).select("index", f.col("diffs_col1").alias("diffs_fill"))

and join:

combined = (df_lag_with_block
    .join(missing_with_lag, ["index"], "leftouter")
    .withColumn("diffs_col1", f.coalesce("diffs_col1", "diffs_fill")))

to get the desired result:

mismatched = combined.join(df_lag_unpart, ["index"], "outer").where(
    combined["diffs_col1"] != df_lag_unpart["diffs_col1"]
)
assert mismatched.count() == 0
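
If it helps to reason about the whole procedure without a cluster, the same three steps (per-block lag, fill pass over boundary rows, coalesce) can be modelled in plain Python. This is only a sketch under the assumptions of the example: sequential indexes starting at 0, 90 records per block, and a hypothetical lag_diff helper standing in for f.lag("col1", 1) - f.col("col1").

```python
import random

random.seed(42)
col1 = [random.gauss(0, 1) for _ in range(1000)]
rec_per_block = 90


def lag_diff(rows):
    # rows: (index, value) pairs sorted by index; previous value minus current.
    out, prev = {}, None
    for index, value in rows:
        out[index] = None if prev is None else prev - value
        prev = value
    return out


rows = list(enumerate(col1))
expected = lag_diff(rows)  # reference: one global window

# Per-block windows leave a None on the first row of every block.
blocks = {}
for index, value in rows:
    blocks.setdefault(index // rec_per_block, []).append((index, value))
per_block = {}
for block_rows in blocks.values():
    per_block.update(lag_diff(block_rows))

# Fill pass over boundary rows only (last row of a block, first row of the next).
boundary = set()
for block_rows in blocks.values():
    first = block_rows[0][0]
    if first > 0:
        boundary.update((first - 1, first))
fill = lag_diff([(i, col1[i]) for i in sorted(boundary)])

# Coalesce: keep the per-block value, fall back to the fill value when it is None.
combined = {
    i: per_block[i] if per_block[i] is not None else fill.get(i)
    for i in per_block
}
print(combined == expected)
```

Note that the fill pass produces meaningless values for the last row of each block (its neighbour in the small frame is not its real predecessor), but the coalesce step never uses them because the per-block result for those rows is already non-null.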
