Welcome to OStack Knowledge Sharing Community for programmer and developer-Open, Learning and Share

apache spark - pyspark use dataframe inside udf

I have two DataFrames, df1:

+---+---+----------+
|  n|val| distances|
+---+---+----------+
|  1|  1|0.27308652|
|  2|  1|0.24969208|
|  3|  1|0.21314497|
+---+---+----------+

and df2

+---+---+----------+
| x1| x2|         w|
+---+---+----------+
|  1|  2|0.03103427|
|  1|  4|0.19012526|
|  1| 10|0.26805446|
|  1|  8|0.26825935|
+---+---+----------+

I want to add a new column to df1 called gamma, which will contain the sum of the w values from df2 where df1.n == df2.x1 OR df1.n == df2.x2.

I tried to use a UDF, but apparently selecting from a different DataFrame inside it does not work, because the values have to be determined before the calculation:

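from pyspark.sql.functions import udf
from pyspark.sql.types import FloatType
gamma_udf = udf(lambda n: float(df2.filter("x1 = %d OR x2 = %d"%(n,n)).groupBy().sum('w').rdd.map(lambda x: x).collect()[0]), FloatType())
df1.withColumn('gamma1', gamma_udf('n'))  # fails: df2 cannot be used inside a UDF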

Is there any way to do this with a join or groupBy, without using loops?

1 Answer

You can't reference a DataFrame inside a UDF. As you alluded to, this problem is best solved using a join.

IIUC, you are looking for something like:

from pyspark.sql import Window
import pyspark.sql.functions as F

# Parentheses let the method chain span several lines
(
    df1.alias("L")
    .join(df2.alias("R"), (df1.n == df2.x1) | (df1.n == df2.x2), how="left")
    .select("L.*", F.sum("w").over(Window.partitionBy("n")).alias("gamma"))
    .distinct()
    .show()
)
#+---+---+----------+----------+
#|  n|val| distances|     gamma|
#+---+---+----------+----------+
#|  1|  1|0.27308652|0.75747334|
#|  3|  1|0.21314497|      null|
#|  2|  1|0.24969208|0.03103427|
#+---+---+----------+----------+

Or, if you're more comfortable with Spark SQL syntax, you can register the DataFrames as temporary views and do:

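# createOrReplaceTempView supersedes the deprecated registerTempTable
df1.createOrReplaceTempView("df1")
df2.createOrReplaceTempView("df2")

# spark is the SparkSession (older code used sqlCtx, an SQLContext)
spark.sql(
    "SELECT DISTINCT L.*, SUM(R.w) OVER (PARTITION BY L.n) AS gamma "
    "FROM df1 L LEFT JOIN df2 R ON L.n = R.x1 OR L.n = R.x2"
).show()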
#+---+---+----------+----------+
#|  n|val| distances|     gamma|
#+---+---+----------+----------+
#|  1|  1|0.27308652|0.75747334|
#|  3|  1|0.21314497|      null|
#|  2|  1|0.24969208|0.03103427|
#+---+---+----------+----------+

Explanation

In both cases we do a left join of df1 to df2. This keeps every row of df1 regardless of whether there's a match; unmatched rows get null for the df2 columns, which is why gamma is null for n = 3.

The join condition is the one you specified in your question, so every row in df2 where either x1 or x2 equals n is joined.

Next, we select all of the columns from the left table and, using a window partitioned by n (the window-function counterpart of a group by), sum the values of w. This gives, for each value of n, the sum over all rows that matched the join condition, attached to every joined row.

Finally, we return only distinct rows to eliminate the duplicates introduced by the join, since each df1 row appears once per matching row in df2.
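
To make the steps above concrete, here is a plain-Python sketch (no Spark) that mimics the left join, the window sum and the DISTINCT on the sample data from the question. It only illustrates the semantics of the query, not how Spark executes it; the variable names are hypothetical.

```python
# Sample data from the question: df1 rows are (n, val, distances),
# df2 rows are (x1, x2, w).
df1 = [(1, 1, 0.27308652), (2, 1, 0.24969208), (3, 1, 0.21314497)]
df2 = [(1, 2, 0.03103427), (1, 4, 0.19012526), (1, 10, 0.26805446), (1, 8, 0.26825935)]

# Step 1: left join on (n == x1) OR (n == x2).
# Each df1 row is repeated once per matching df2 row; with no match
# it appears once with w = None (SQL null).
joined = []
for left in df1:
    n = left[0]
    matches = [w for x1, x2, w in df2 if n == x1 or n == x2]
    if matches:
        joined.extend((left, w) for w in matches)
    else:
        joined.append((left, None))

# Step 2: SUM(w) OVER (PARTITION BY n).
# Every joined row keeps its own columns and gets the total for its n.
# Like SQL SUM, nulls are ignored and an all-null partition sums to None.
totals = {}
for left, w in joined:
    if w is not None:
        totals[left[0]] = totals.get(left[0], 0.0) + w
windowed = [left + (totals.get(left[0]),) for left, _ in joined]

# Step 3: DISTINCT collapses the repeated rows produced by the join
# (dict.fromkeys keeps the first occurrence of each row, in order).
result = list(dict.fromkeys(windowed))

for row in result:
    print(row)
```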


与恶龙缠斗过久,自身亦成为恶龙;凝视深渊过久,深渊将回以凝视…
Welcome to OStack Knowledge Sharing Community for programmer and developer-Open, Learning and Share
Click Here to Ask a Question

...