Since Spark 2.3 you can use pandas_udf. GROUPED_MAP takes Callable[[pandas.DataFrame], pandas.DataFrame], or in other words a function which maps from a Pandas DataFrame of the same shape as the input to the output DataFrame.
For example, if the data looks like this:
df = spark.createDataFrame(
    [("a", 1, 0), ("a", -1, 42), ("b", 3, -1), ("b", 10, -2)],
    ("key", "value1", "value2")
)
and you want to compute the average of the pairwise minimum of value1 and value2, you have to define the output schema:
from pyspark.sql.types import *

schema = StructType([
    StructField("key", StringType()),
    StructField("avg_min", DoubleType())
])
and define the pandas_udf:
import pandas as pd

from pyspark.sql.functions import pandas_udf
from pyspark.sql.functions import PandasUDFType

@pandas_udf(schema, functionType=PandasUDFType.GROUPED_MAP)
def g(df):
    # df is a Pandas DataFrame holding all rows of a single group.
    # Compute the mean of the row-wise minimum of value1 and value2.
    result = pd.DataFrame(df.groupby(df.key).apply(
        lambda x: x.loc[:, ["value1", "value2"]].min(axis=1).mean()
    ))
    # Turn the key index back into a column so the output matches the schema.
    result.reset_index(inplace=True, drop=False)
    return result
and apply it:
df.groupby("key").apply(g).show()
+---+-------+
|key|avg_min|
+---+-------+
| b| -1.5|
| a| -0.5|
+---+-------+
Excluding the schema definition and the decorator, your current Pandas code can be applied as-is.
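To make that concrete, here is a rough sketch of what the equivalent stand-alone Pandas code (no Spark involved; the variable name pdf is chosen just for illustration) might look like:

import pandas as pd

pdf = pd.DataFrame(
    [("a", 1, 0), ("a", -1, 42), ("b", 3, -1), ("b", 10, -2)],
    columns=["key", "value1", "value2"]
)

# The same groupby / row-wise min / mean logic as inside g, minus the
# pandas_udf decorator and the explicit output schema.
pdf.groupby("key").apply(
    lambda x: x.loc[:, ["value1", "value2"]].min(axis=1).mean()
)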
Since Spark 2.4.0 there is also a GROUPED_AGG variant, which takes Callable[[pandas.Series, ...], T], where T is a primitive scalar:
import numpy as np

@pandas_udf(DoubleType(), functionType=PandasUDFType.GROUPED_AGG)
def f(x, y):
    # x and y are Pandas Series containing all values of a group; return a scalar.
    return np.minimum(x, y).mean()
which can be used with the standard groupBy / agg construct:
df.groupBy("key").agg(f("value1", "value2").alias("avg_min")).show()
+---+-------+
|key|avg_min|
+---+-------+
| b| -1.5|
| a| -0.5|
+---+-------+
Please note that neither GROUPED_MAP nor GROUPED_AGG pandas_udf behaves the same way as UserDefinedAggregateFunction or Aggregator; they are closer to groupByKey or window functions with an unbounded frame. Data is shuffled first, and only after that is the UDF applied.
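For example (a minimal sketch, assuming Spark 2.4+ and the f and df defined above), the GROUPED_AGG UDF can also be used over an unbounded window, which attaches the per-key result to every row instead of collapsing the group:

from pyspark.sql import Window

# Unbounded frame: every row in a partition sees the whole group.
w = (Window
    .partitionBy("key")
    .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing))

df.withColumn("avg_min", f("value1", "value2").over(w)).show()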
For optimized execution you should implement a Scala UserDefinedAggregateFunction and add a Python wrapper.
See also User defined function to be applied to Window in PySpark?