You can't reference a DataFrame inside of a udf. As you alluded to, this problem is best solved using a join.
IIUC, you are looking for something like:
from pyspark.sql import Window
import pyspark.sql.functions as F
df1.alias("L").join(df2.alias("R"), (df1.n == df2.x1) | (df1.n == df2.x2), how="left")
.select("L.*", F.sum("w").over(Window.partitionBy("n")).alias("gamma"))
.distinct()
.show()
#+---+---+----------+----------+
#| n|val| distances| gamma|
#+---+---+----------+----------+
#| 1| 1|0.27308652|0.75747334|
#| 3| 1|0.21314497| null|
#| 2| 1|0.24969208|0.03103427|
#+---+---+----------+----------+
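For reference, the code above assumes df1 and df2 with the schemas from the question. Here is a minimal, hypothetical sketch of such inputs (the w values are made up, so the gamma column will not match the exact numbers shown above):

# Hypothetical toy inputs with the assumed schemas (values illustrative only)
df1 = sqlCtx.createDataFrame(
    [(1, 1, 0.27308652), (2, 1, 0.24969208), (3, 1, 0.21314497)],
    ["n", "val", "distances"]
)
df2 = sqlCtx.createDataFrame(
    [(1, 2, 0.5), (4, 1, 0.25)],
    ["x1", "x2", "w"]
)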
Or, if you're more comfortable with Spark SQL syntax, you can register temp tables and do:
df1.registerTempTable("df1")
df2.registerTempTable("df2")
sqlCtx.sql(
"SELECT DISTINCT L.*, SUM(R.w) OVER (PARTITION BY L.n) AS gamma "
"FROM df1 L LEFT JOIN df2 R ON L.n = R.x1 OR L.n = R.x2"
).show()
#+---+---+----------+----------+
#| n|val| distances| gamma|
#+---+---+----------+----------+
#| 1| 1|0.27308652|0.75747334|
#| 3| 1|0.21314497| null|
#| 2| 1|0.24969208|0.03103427|
#+---+---+----------+----------+
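Note that registerTempTable and sqlCtx are the older (Spark 1.x) entry points. On Spark 2.x and later, a sketch of the equivalent using a SparkSession (assumed here to be named spark) would be:

# Spark 2+ equivalent: temp views + SparkSession.sql
df1.createOrReplaceTempView("df1")
df2.createOrReplaceTempView("df2")
spark.sql(
    "SELECT DISTINCT L.*, SUM(R.w) OVER (PARTITION BY L.n) AS gamma "
    "FROM df1 L LEFT JOIN df2 R ON L.n = R.x1 OR L.n = R.x2"
).show()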
Explanation
In both cases we are doing a left join of df1 to df2. This keeps all the rows in df1 regardless of whether there's a match. The join clause is the condition that you specified in your question, so every row in df2 where either x1 or x2 equals n gets joined.
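To see where the later distinct comes in, you can inspect the raw join output before the window aggregation (a sketch using the inputs assumed above); each value of n appears once per matching row of df2:

# Sketch: look at the matches produced by the join itself
df1.alias("L").join(
    df2.alias("R"), (df1.n == df2.x1) | (df1.n == df2.x2), how="left"
).select("L.n", "R.x1", "R.x2", "R.w").show()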
Next we select all of the columns from the left table, partition by n, and sum the values of w. For each value of n, this gives the sum of w over all rows that matched the join condition.
Finally, we keep only distinct rows to eliminate the duplicates introduced by the join (one row per matching row in df2).
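If the window-plus-distinct pattern feels roundabout, one alternative sketch (same assumed schemas, not the approach above) is to aggregate the matched weights first and join the result back onto df1; the groupBy collapses the duplicates, so no distinct is needed:

# Alternative sketch: aggregate w per n, then join back to df1
gamma = df1.join(df2, (df1.n == df2.x1) | (df1.n == df2.x2), how="left")\
    .groupBy("n")\
    .agg(F.sum("w").alias("gamma"))
df1.join(gamma, on="n", how="left").show()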