I have the following example Spark DataFrame:
rdd = sc.parallelize([(1,"19:00:00", "19:30:00", 30), (1,"19:30:00", "19:40:00", 10),(1,"19:40:00", "19:43:00", 3), (2,"20:00:00", "20:10:00", 10), (1,"20:05:00", "20:15:00", 10),(1,"20:15:00", "20:35:00", 20)])
df = spark.createDataFrame(rdd, ["user_id", "start_time", "end_time", "duration"])
df.show()
+-------+----------+--------+--------+
|user_id|start_time|end_time|duration|
+-------+----------+--------+--------+
|      1|  19:00:00|19:30:00|      30|
|      1|  19:30:00|19:40:00|      10|
|      1|  19:40:00|19:43:00|       3|
|      2|  20:00:00|20:10:00|      10|
|      1|  20:05:00|20:15:00|      10|
|      1|  20:15:00|20:35:00|      20|
+-------+----------+--------+--------+
I want to group consecutive rows based on the start and end times. For instance, for the same user_id, if a row's start time is the same as the previous row's end time, I want to group them together and sum the duration.
The desired result is:
+-------+----------+--------+--------+
|user_id|start_time|end_time|duration|
+-------+----------+--------+--------+
|      1|  19:00:00|19:43:00|      43|
|      2|  20:00:00|20:10:00|      10|
|      1|  20:05:00|20:35:00|      30|
+-------+----------+--------+--------+
The first three rows of the dataframe were grouped together because they all correspond to user_id 1 and the start times and end times form a continuous timeline.
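As a reference for the rule described above, here is a minimal plain-Python sketch (not Spark) that merges a row into the previous one when the user is the same and the start time equals the previous end time. The function name `merge_consecutive` is hypothetical, and the sketch assumes the times are zero-padded `HH:mm:ss` strings within a single day, so that sorting them as strings matches chronological order.

```python
# Sample rows: (user_id, start_time, end_time, duration)
rows = [
    (1, "19:00:00", "19:30:00", 30),
    (1, "19:30:00", "19:40:00", 10),
    (1, "19:40:00", "19:43:00", 3),
    (2, "20:00:00", "20:10:00", 10),
    (1, "20:05:00", "20:15:00", 10),
    (1, "20:15:00", "20:35:00", 20),
]


def merge_consecutive(rows):
    """Merge rows of one user whose start_time equals the previous end_time."""
    merged = []
    # Sort by user, then by start time (assumes zero-padded HH:mm:ss strings).
    for user_id, start, end, duration in sorted(rows, key=lambda r: (r[0], r[1])):
        if merged and merged[-1][0] == user_id and merged[-1][2] == start:
            # Same user and no gap: extend the previous group.
            prev = merged[-1]
            merged[-1] = (user_id, prev[1], end, prev[3] + duration)
        else:
            # Different user or a gap: start a new group.
            merged.append((user_id, start, end, duration))
    # Order the groups by start time, as in the desired result.
    return sorted(merged, key=lambda r: r[1])


expected = merge_consecutive(rows)
for row in expected:
    print(row)
```

The output of this sketch can serve as a small oracle when checking a Spark implementation on the sample data.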
This was my initial approach:
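Use a window function to get the next start time (lead here, which is equivalent to lag with an offset of -1):
from pyspark.sql import Window
from pyspark.sql.functions import lead, unix_timestamp

# compute the next start time within each user, ordered by start_time
window = Window.partitionBy('user_id').orderBy('start_time')
# lead(col, 1) returns the value from the following row, like lag(col, -1)
df = df.withColumn("next_start_time", lead(df.start_time, 1).over(window))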
df.show()
+-------+----------+--------+--------+---------------+
|user_id|start_time|end_time|duration|next_start_time|
+-------+----------+--------+--------+---------------+
|      1|  19:00:00|19:30:00|      30|       19:30:00|
|      1|  19:30:00|19:40:00|      10|       19:40:00|
|      1|  19:40:00|19:43:00|       3|       20:05:00|
|      1|  20:05:00|20:15:00|      10|       20:15:00|
|      1|  20:15:00|20:35:00|      20|           null|
|      2|  20:00:00|20:10:00|      10|           null|
+-------+----------+--------+--------+---------------+
Then get the difference, in seconds, between the current row's end time and the next row's start time:
time_fmt = "HH:mm:ss"
# unix_timestamp returns seconds, so the difference is in seconds
time_diff = unix_timestamp('next_start_time', format=time_fmt) - unix_timestamp('end_time', format=time_fmt)
df = df.withColumn("difference", time_diff)
df.show()
+-------+----------+--------+--------+---------------+----------+
|user_id|start_time|end_time|duration|next_start_time|difference|
+-------+----------+--------+--------+---------------+----------+
|      1|  19:00:00|19:30:00|      30|       19:30:00|         0|
|      1|  19:30:00|19:40:00|      10|       19:40:00|         0|
|      1|  19:40:00|19:43:00|       3|       20:05:00|      1320|
|      1|  20:05:00|20:15:00|      10|       20:15:00|         0|
|      1|  20:15:00|20:35:00|      20|           null|      null|
|      2|  20:00:00|20:10:00|      10|           null|      null|
+-------+----------+--------+--------+---------------+----------+
My idea was then to use the sum function over a window to get the cumulative sum of duration and do a groupBy on it, but that approach was flawed for many reasons.
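One common way to make a cumulative-sum idea like this work is to take the running total over a "new group" flag instead of over duration, so that the running total becomes a group id. The sketch below shows that swapped-in technique in plain Python rather than Spark, so it runs without a cluster. In PySpark the three steps would presumably map onto lag over the window, a running sum over the same window, and a groupBy with min, max, and sum aggregates, but that mapping is an assumption and is not tested here. The names `flags`, `group_ids`, and `result` are hypothetical.

```python
from itertools import accumulate, groupby

# Sample rows: (user_id, start_time, end_time, duration)
rows = [
    (1, "19:00:00", "19:30:00", 30),
    (1, "19:30:00", "19:40:00", 10),
    (1, "19:40:00", "19:43:00", 3),
    (2, "20:00:00", "20:10:00", 10),
    (1, "20:05:00", "20:15:00", 10),
    (1, "20:15:00", "20:35:00", 20),
]

# Order by user, then start time (assumes zero-padded HH:mm:ss strings).
ordered = sorted(rows, key=lambda r: (r[0], r[1]))

# Step 1: flag is 0 when a row continues the previous row
# (same user, start equals previous end), otherwise 1.
flags = []
for i, cur in enumerate(ordered):
    prev = ordered[i - 1] if i else None
    continues = prev is not None and prev[0] == cur[0] and prev[2] == cur[1]
    flags.append(0 if continues else 1)

# Step 2: the running sum of the flags gives every row a group id.
group_ids = list(accumulate(flags))

# Step 3: aggregate each group: first start, last end, summed duration.
result = []
for _, members in groupby(zip(group_ids, ordered), key=lambda pair: pair[0]):
    group = [row for _, row in members]
    result.append(
        (group[0][0], group[0][1], group[-1][2], sum(r[3] for r in group))
    )

for row in result:
    print(row)
```

Here the running sum is taken over the whole sorted list, which works because a change of user also raises the flag. With a window partitioned by user_id the ids would restart for each user, so the grouping key would need to be the pair of user_id and group id.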