Welcome to OStack Knowledge Sharing Community for programmer and developer-Open, Learning and Share
Welcome To Ask or Share your Answers For Others


pyspark - Using a column value as a parameter to a Spark DataFrame function

Consider the following DataFrame:

#+------+---+
#|letter|rpt|
#+------+---+
#|     X|  3|
#|     Y|  1|
#|     Z|  2|
#+------+---+

which can be created using the following code:

df = spark.createDataFrame([("X", 3),("Y", 1),("Z", 2)], ["letter", "rpt"])

Suppose I wanted to repeat each row the number of times specified in the column rpt, just like in this question.

One way would be to replicate my solution to that question using the following pyspark-sql query:

df.createOrReplaceTempView("df")  # register df so the SQL query can refer to it by name
query = """
SELECT *
FROM
  (SELECT DISTINCT *,
                   posexplode(split(repeat(",", rpt), ",")) AS (index, col)
   FROM df) AS a
WHERE index > 0
"""
query = query.replace("\n", " ")  # replace newlines with spaces, avoid EOF error
spark.sql(query).drop("col").sort('letter', 'index').show()
#+------+---+-----+
#|letter|rpt|index|
#+------+---+-----+
#|     X|  3|    1|
#|     X|  3|    2|
#|     X|  3|    3|
#|     Y|  1|    1|
#|     Z|  2|    1|
#|     Z|  2|    2|
#+------+---+-----+

This works and produces the correct answer. However, I am unable to replicate this behavior using the DataFrame API functions.
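You can see why the query filters on index > 0 by mimicking the string trick in plain Python. This is a rough sketch, not Spark code. It assumes, as the output above suggests, two things about Spark:
- split keeps empty strings the way Python's str.split does, so rpt commas split into rpt + 1 empty strings.
- posexplode numbers those elements from 0.

The helper name mimic_repeat_rows is hypothetical.

```python
def mimic_repeat_rows(rows):
    """Plain-Python imitation of posexplode(split(repeat(",", rpt), ",")) + WHERE index > 0."""
    out = []
    for letter, rpt in rows:
        # repeat(",", rpt): a string of rpt commas, e.g. ",,," for rpt = 3
        repeated = "," * rpt
        # split on ",": rpt commas separate rpt + 1 empty strings
        parts = repeated.split(",")
        # posexplode: one (index, value) pair per element, indices start at 0
        for index, _value in enumerate(parts):
            # WHERE index > 0 drops the extra element, leaving exactly rpt rows
            if index > 0:
                out.append((letter, rpt, index))
    return sorted(out)


for row in mimic_repeat_rows([("X", 3), ("Y", 1), ("Z", 2)]):
    print(row)
```

In this sketch, a row with rpt = 0 yields no output at all, because an empty string splits into a single element at index 0.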

I tried:

import pyspark.sql.functions as f
df.select(
    f.posexplode(f.split(f.repeat(",", f.col("rpt")), ",")).alias("index", "col")
).show()

But this results in:

TypeError: 'Column' object is not callable

Why am I able to pass the column as an input to repeat within the query, but not from the API? Is there a way to replicate this behavior using the Spark DataFrame functions?


He who fights with dragons too long becomes a dragon himself; gaze too long into the abyss, and the abyss will gaze back into you…

1 Answer


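One option is to use pyspark.sql.functions.expr. It parses a Spark SQL expression string, so you can use column values as inputs to Spark SQL functions just as in your query. The Python wrappers in pyspark.sql.functions are stricter. repeat(col, n), for example, takes n as a plain Python int (at least in the PySpark versions current when this was asked), so passing a Column for n fails with the TypeError you saw. These wrappers also treat a plain string such as "," as a column name, not as a literal.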

Based on @user8371915's comment I have found that the following works:

from pyspark.sql.functions import expr

df.select(
    '*',
    expr('posexplode(split(repeat(",", rpt), ","))').alias("index", "col")
).where('index > 0').drop("col").sort('letter', 'index').show()
#+------+---+-----+
#|letter|rpt|index|
#+------+---+-----+
#|     X|  3|    1|
#|     X|  3|    2|
#|     X|  3|    3|
#|     Y|  1|    1|
#|     Z|  2|    1|
#|     Z|  2|    2|
#+------+---+-----+


...