We can solve this with the explode() function. In pandas, the same wide-to-long reshaping is done with melt().
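To make the wide-to-long idea concrete, here is a minimal plain-Python sketch of what melt (and the explode approach below) produces, using the same three stores as a list of dicts. The helper name `melt_rows` is hypothetical. The default output keys `variable`/`value` mirror pandas' defaults. This is neither Spark nor pandas code. Rows come out store by store, like the Spark output, whereas pandas.melt emits one block per column.

```python
def melt_rows(rows, id_vars, var_name="variable", value_name="value"):
    """Turn each wide row into one long row per non-id column (hypothetical helper)."""
    long_rows = []
    for row in rows:
        for key, value in row.items():
            if key in id_vars:
                continue  # id columns are repeated on every long row, not unpivoted
            out = {k: row[k] for k in id_vars}
            out[var_name] = key
            out[value_name] = value
            long_rows.append(out)
    return long_rows


stores = [
    {"store_id": 100, "qty_on_hand_milk": 30, "qty_on_hand_bread": 105, "qty_on_hand_eggs": 35},
    {"store_id": 200, "qty_on_hand_milk": 55, "qty_on_hand_bread": 85, "qty_on_hand_eggs": 65},
    {"store_id": 300, "qty_on_hand_milk": 20, "qty_on_hand_bread": 125, "qty_on_hand_eggs": 90},
]

long_rows = melt_rows(stores, id_vars=["store_id"])
print(len(long_rows))  # 9: 3 stores x 3 value columns
print(long_rows[0])    # {'store_id': 100, 'variable': 'qty_on_hand_milk', 'value': 30}
```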
# Loading the requisite packages
# Note: importing sum here shadows Python's built-in sum()
from pyspark.sql.functions import col, explode, array, struct, expr, sum, lit
# Creating the DataFrame
df = sqlContext.createDataFrame([(100,30,105,35),(200,55,85,65),(300,20,125,90)],('store_id','qty_on_hand_milk','qty_on_hand_bread','qty_on_hand_eggs'))
df.show()
+--------+----------------+-----------------+----------------+
|store_id|qty_on_hand_milk|qty_on_hand_bread|qty_on_hand_eggs|
+--------+----------------+-----------------+----------------+
|     100|              30|              105|              35|
|     200|              55|               85|              65|
|     300|              20|              125|              90|
+--------+----------------+-----------------+----------------+
The function below unpivots the DataFrame by exploding an array of (column name, value) structs, one struct per non-id column:
def to_explode(df, by):
    # by must be a list of id column names (it is concatenated with lists below)
    # Filter dtypes and split into column names and type descriptions
    cols, dtypes = zip(*((c, t) for (c, t) in df.dtypes if c not in by))
    # array() needs all its elements to share one type, so the value columns must match
    assert len(set(dtypes)) == 1, "All columns have to be of the same type"
    # Create and explode an array of (column_name, column_value) structs
    kvs = explode(array([
        struct(lit(c).alias("CATEGORY"), col(c).alias("qty_on_hand")) for c in cols
    ])).alias("kvs")
    return df.select(by + [kvs]).select(by + ["kvs.CATEGORY", "kvs.qty_on_hand"])
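The first two statements of to_explode only inspect df.dtypes, which PySpark returns as a list of (column name, type string) pairs. The plain-Python sketch below replays that split and the homogeneity check on a hand-written dtypes list, so the assertion's behaviour can be seen without a Spark session. The 'bigint' strings assume Spark infers the Python ints as LongType, and `split_value_columns` and the extra `store_name` column are hypothetical.

```python
def split_value_columns(dtypes, by):
    """Mimic the first lines of to_explode: drop id columns, require one type."""
    cols, types = zip(*((c, t) for (c, t) in dtypes if c not in by))
    # Same check as to_explode: all value columns must have an identical type string
    assert len(set(types)) == 1, "All columns have to be of the same type"
    return list(cols), types[0]


# Assumed df.dtypes for the example DataFrame (Python ints inferred as bigint)
dtypes = [
    ("store_id", "bigint"),
    ("qty_on_hand_milk", "bigint"),
    ("qty_on_hand_bread", "bigint"),
    ("qty_on_hand_eggs", "bigint"),
]

cols, value_type = split_value_columns(dtypes, by=["store_id"])
print(cols)        # ['qty_on_hand_milk', 'qty_on_hand_bread', 'qty_on_hand_eggs']
print(value_type)  # bigint

# A hypothetical string column among the value columns trips the assertion
mixed = dtypes + [("store_name", "string")]
try:
    split_value_columns(mixed, by=["store_id"])
except AssertionError as err:
    print(err)     # All columns have to be of the same type
```

In practice this means any non-numeric column (a store name, say) has to be listed in `by` or dropped before calling to_explode.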
Applying the function to this DataFrame, then dropping store_id since it is not needed for the totals:
df = to_explode(df, ['store_id']).drop('store_id')
df.show()
+-----------------+-----------+
|         CATEGORY|qty_on_hand|
+-----------------+-----------+
| qty_on_hand_milk|         30|
|qty_on_hand_bread|        105|
| qty_on_hand_eggs|         35|
| qty_on_hand_milk|         55|
|qty_on_hand_bread|         85|
| qty_on_hand_eggs|         65|
| qty_on_hand_milk|         20|
|qty_on_hand_bread|        125|
| qty_on_hand_eggs|         90|
+-----------------+-----------+
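Next, we remove the prefix qty_on_hand_ from the CATEGORY column with expr(). Note that the SQL substring() function uses 1-based positions rather than 0-based ones, so substring(CATEGORY, 13) keeps everything from the 13th character onward, i.e. right after the 12-character prefix: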
df = df.withColumn('CATEGORY',expr('substring(CATEGORY, 13)'))
df.show()
+--------+-----------+
|CATEGORY|qty_on_hand|
+--------+-----------+
|    milk|         30|
|   bread|        105|
|    eggs|         35|
|    milk|         55|
|   bread|         85|
|    eggs|         65|
|    milk|         20|
|   bread|        125|
|    eggs|         90|
+--------+-----------+
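The off-by-one detail is easy to check in plain Python. The sketch below uses `sql_substring`, a hypothetical stand-in for the two-argument SQL substring(str, pos), restricted to positive positions. Spark also accepts negative positions, which count from the end, and those are not modelled here. It is compared with ordinary 0-based slicing. Deriving the start position from len(PREFIX) avoids hard-coding 13.

```python
PREFIX = "qty_on_hand_"


def sql_substring(s, pos):
    """Emulate SQL substring(s, pos) for pos >= 1: pos is a 1-based start."""
    return s[pos - 1:]


start = len(PREFIX) + 1  # 12-character prefix, so the first kept character is number 13
print(start)                                      # 13
print(sql_substring("qty_on_hand_bread", start))  # bread
print("qty_on_hand_bread"[len(PREFIX):])          # bread (0-based Python slice)
```

Within Spark itself, expr("regexp_replace(CATEGORY, '^qty_on_hand_', '')") would be an alternative that avoids the position arithmetic entirely.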
Finally, we group by CATEGORY and sum qty_on_hand with agg():
df = df.groupBy(['CATEGORY']).agg(sum('qty_on_hand').alias('total_qty_on_hand'))
df.show()
+--------+-----------------+
|CATEGORY|total_qty_on_hand|
+--------+-----------------+
|    eggs|              190|
|   bread|              315|
|    milk|              105|
+--------+-----------------+
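As a cross-check of the totals, here is a plain-Python sketch of the same group-and-sum on the nine (CATEGORY, qty_on_hand) rows, using collections.defaultdict. It only illustrates what groupBy().agg(sum(...)) computes. It is not how Spark executes the aggregation.

```python
from collections import defaultdict

# The nine (CATEGORY, qty_on_hand) rows after the substring step
rows = [
    ("milk", 30), ("bread", 105), ("eggs", 35),
    ("milk", 55), ("bread", 85), ("eggs", 65),
    ("milk", 20), ("bread", 125), ("eggs", 90),
]

totals = defaultdict(int)
for category, qty in rows:
    totals[category] += qty  # same role as sum('qty_on_hand') within each group

for category in sorted(totals):
    print(category, totals[category])
# bread 315
# eggs 190
# milk 105
```

Spark does not guarantee the order of groupBy output rows (eggs happens to come first above), so append .orderBy('CATEGORY') if a stable order matters.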