Welcome to OGeek Q&A Community for programmer and developer-Open, Learning and Share

scala - How to add columns into org.apache.spark.sql.Row inside of mapPartitions

I am new to Scala and Spark, so please keep that in mind :)

I have three questions:

  1. How should I define a function to pass to df.rdd.mapPartitions if I want to create a new Row with a few additional columns?
  2. How can I add a few columns to a Row object (or create a new one)?
  3. How do I create a DataFrame from the resulting RDD?

Thank you in advance.

1 Reply

Usually there is no need for this, and it is better to use UDFs. That said, here are answers to the three questions.
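For comparison, the UDF route mentioned above might look like the sketch below. It assumes Spark 2.0 or later (it uses SparkSession rather than the sqlContext used elsewhere in this answer). The object name UdfAlternative, the helper addColumns and the signOf UDF are hypothetical names chosen only for illustration.

```scala
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.functions.{col, lit, udf}

object UdfAlternative {
  // Hypothetical UDF for illustration: -1 for negative input, 1 otherwise.
  val signOf = udf((value: Double) => if (value < 0) -1 else 1)

  // Adds a constant column "z" and a derived column "v"
  // without leaving the DataFrame API.
  def addColumns(df: DataFrame): DataFrame =
    df.withColumn("z", lit(-1))
      .withColumn("v", signOf(col("y")))

  def main(args: Array[String]): Unit = {
    // Local session for experimenting; a real job would configure this differently.
    val spark = SparkSession.builder()
      .master("local[1]")
      .appName("udf-alternative")
      .getOrCreate()
    import spark.implicits._

    val df = Seq((1.0, 2.0), (0.0, -1.0)).toDF("x", "y")
    addColumns(df).show()
    spark.stop()
  }
}
```

With this approach Spark tracks the schema of the new columns itself, so there is no need to build a StructType by hand.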

How should I define a function to pass to df.rdd.mapPartitions if I want to create a new Row with a few additional columns?

It should take an Iterator[Row] and return an Iterator[T], so in your case you should use something like this:

import org.apache.spark.sql.Row

def transformRows(iter: Iterator[Row]): Iterator[Row] = ???
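One reason to reach for mapPartitions rather than a plain map is that the function sees a whole partition at once, so it can keep per-partition state. The following is a minimal sketch of that idea, not the only way to fill in the ??? above: it appends each row's position within its partition. The object name PartitionTransform is hypothetical.

```scala
import org.apache.spark.sql.Row

object PartitionTransform {
  // Receives every row of one partition and returns a new iterator.
  // zipWithIndex pairs each row with its 0-based position in the partition,
  // and the mapping stays lazy, so rows are processed one at a time.
  def transformRows(iter: Iterator[Row]): Iterator[Row] =
    iter.zipWithIndex.map { case (row, indexInPartition) =>
      Row.fromSeq(row.toSeq :+ indexInPartition)
    }
}
```

Because the function only depends on Iterator[Row], it can be tried on a plain Scala iterator, for example PartitionTransform.transformRows(Iterator(Row(1.0, 2.0))), before passing it to df.rdd.mapPartitions.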

How can I add a few columns to a Row object (or create a new one)?

There are multiple ways of accessing Row values, including the Row.get* methods and Row.toSeq. A new Row can be created using Row.apply, Row.fromSeq, Row.fromTuple or RowFactory. For example:

// Copy the existing values and append two constant Int values
def transformRow(row: Row): Row = Row.fromSeq(row.toSeq ++ Array[Any](-1, 1))
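The other constructors listed above can be used in the same way. The sketch below reads two fields with typed getters and builds the result with Row.apply and Row.fromTuple. It assumes the first two fields are non-null Doubles, as in the example DataFrame used in this answer; RowExamples, withSum and withSumFromTuple are hypothetical names.

```scala
import org.apache.spark.sql.Row

object RowExamples {
  // Reads the first two fields by position and appends their sum.
  // Assumption: both fields hold non-null Double values.
  def withSum(row: Row): Row = {
    val x = row.getDouble(0)
    val y = row.getDouble(1)
    Row(x, y, x + y) // Row.apply builds a Row from the listed values
  }

  // Same result, built from a tuple instead of a value list.
  def withSumFromTuple(row: Row): Row = {
    val x = row.getDouble(0)
    val y = row.getDouble(1)
    Row.fromTuple((x, y, x + y))
  }
}
```

Positional getters are used here because name-based access such as row.getAs[Double]("x") only works on rows that carry a schema, which is the case for rows coming from a DataFrame but not for rows built by hand with Row.apply.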

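How do I create a DataFrame from the resulting RDD?

If you have an RDD[Row], you can use SQLContext.createDataFrame and provide a schema that matches the rows. In Spark 2.0 and later, SparkSession.createDataFrame accepts the same two arguments.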

Putting this all together:

import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

// Outside spark-shell, toDF also requires: import sqlContext.implicits._
val df = sc.parallelize(Seq(
    (1.0, 2.0), (0.0, -1.0),
    (3.0, 4.0), (6.0, -2.3))).toDF("x", "y")

// Apply transformRow (defined above) lazily to every row of a partition
def transformRows(iter: Iterator[Row]): Iterator[Row] = iter.map(transformRow)

// Original schema plus the two appended non-nullable integer columns
val newSchema = StructType(df.schema.fields ++ Array(
  StructField("z", IntegerType, false), StructField("v", IntegerType, false)))

sqlContext.createDataFrame(df.rdd.mapPartitions(transformRows), newSchema).show

// +---+----+---+---+
// |  x|   y|  z|  v|
// +---+----+---+---+
// |1.0| 2.0| -1|  1|
// |0.0|-1.0| -1|  1|
// |3.0| 4.0| -1|  1|
// |6.0|-2.3| -1|  1|
// +---+----+---+---+
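
On Spark 2.0 or later the same pipeline can be written against SparkSession instead of sqlContext. This is a sketch under that assumption; the object name AddColumnsViaRdd and the helper addColumns are hypothetical, and the session is passed in rather than created here.

```scala
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

object AddColumnsViaRdd {
  // Copy the existing values and append two constant Int values.
  def transformRow(row: Row): Row =
    Row.fromSeq(row.toSeq ++ Seq[Any](-1, 1))

  // Rebuilds a DataFrame from the transformed RDD[Row].
  // The schema must list the appended fields in the same order as transformRow adds them.
  def addColumns(spark: SparkSession, df: DataFrame): DataFrame = {
    val newSchema = StructType(df.schema.fields ++ Array(
      StructField("z", IntegerType, nullable = false),
      StructField("v", IntegerType, nullable = false)))
    val transformed = df.rdd.mapPartitions(iter => iter.map(transformRow))
    spark.createDataFrame(transformed, newSchema)
  }
}
```

Calling AddColumnsViaRdd.addColumns(spark, df).show with the df from above should print the same four columns x, y, z and v.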
