Welcome to OGeek Q&A Community for programmer and developer-Open, Learning and Share
Welcome To Ask or Share your Answers For Others

scala - Convert ArrayType(FloatType,false) to VectorUDT

I want to run K-Means cluster analysis on the itemFactors produced by ALS. The itemFactors of an ALSModel is a DataFrame containing the id and the features of each item, but K-Means does not accept the features column in this form.

Here's the code for collaborative filtering using ALS:

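import org.apache.spark.ml.recommendation.ALS

// Collaborative filtering with ALS on (userId, movieId, rating) rows
val als = new ALS()
  .setRegParam(0.01)
  .setNonnegative(false)
  .setUserCol("userId")
  .setItemCol("movieId")
  .setRatingCol("rating")

val model = als.fit(training)
val predictions = model.transform(testing)

// One row per item: the id plus its learned latent factors
val item_factors = model.itemFactors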

The item_factors DataFrame looks like this:

+---+-------------------------------------------------------------------------------------------------------------------------------------+
|id |features                                                                                                                             |
+---+-------------------------------------------------------------------------------------------------------------------------------------+
|10 |[-0.1317064, 0.07098049, -0.042259596, -0.28769347, 0.58783025, -0.33474237, 0.31248248, -0.34541374, 0.33257273, 0.06327486]        |
|20 |[-0.0033912044, 0.31334892, -0.080896676, -0.75597364, -0.016326033, -0.34558973, 0.045129072, -0.38614395, -0.02269395, -0.16486467]|
|30 |[0.19784503, -0.313929, -0.67753965, -0.7700008, 0.08975326, -0.03427274, 0.49707127, 0.05604595, 0.078268416, 0.08767615]           |
|40 |[0.29390565, -0.22765353, -0.9278744, -0.59953785, 0.184721, -0.061099682, 0.33711356, 0.094112396, 0.08261518, -0.30668002]         |
|50 |[-0.4070981, -0.0013739555, -0.21247752, -0.3771588, 0.3029064, -0.3883846, 0.4752892, 0.30097932, 0.5130039, 0.2938855]             |
|60 |[0.1413918, -0.074142076, -0.87392575, -0.07855377, -0.11006678, -0.44359666, 0.33419594, -0.16027139, -0.2440797, -0.1596081]       |
|70 |[-0.26080364, -0.11437138, 0.046630252, -0.70999575, 0.014645281, -0.69176155, 0.05397229, -0.24038066, -0.429569, 0.5660369]        |
|80 |[0.6104476, -0.35322133, -0.80230886, -0.5302148, -0.26538768, -0.25481275, 0.20784922, -0.10604211, 0.26007786, 0.47488773]         |
|90 |[0.6976714, -0.5851011, -0.64844996, -0.82472694, 0.102610275, -0.45195442, 0.24074861, 0.2683314, 0.11396688, -0.52693856]          |
|100|[-0.11564436, 0.21467225, -0.42873487, -0.54825515, 0.20628366, -0.28728506, 0.18303588, 0.11490151, -0.033433616, -0.08694091]      |
|110|[-0.530162, 0.22694068, -0.30889827, -0.091455124, 0.52988344, -0.7247424, 0.029707031, 0.43658048, 0.21511139, -0.22376455]         |
|120|[0.59780246, -0.3396686, -0.58882934, -0.11867501, -0.6055776, -0.82480395, -0.22715187, -0.4544479, 0.012708589, -0.22158282]       |
|130|[0.9630984, -0.012603591, -0.37178686, -1.0995674, -0.57324636, -0.7460034, 1.2981551, 0.15384857, -1.0350431, -0.58156097]          |
|140|[-0.1617866, 0.3927005, -0.26183906, -0.3666182, -0.015750444, -0.28372696, 0.3577147, -0.18155682, 0.22410324, -0.5632848]          |
|150|[-0.20490485, 0.37170428, -0.47898963, 0.0686825, 0.31148073, -0.4663402, 0.2088939, -0.0071071014, 0.44748953, 0.0067634075]        |
|160|[0.31892687, 0.30109385, -0.036033046, -0.58646286, 0.015361498, -0.5640331, 0.010378816, -0.52527076, -0.20914118, -0.07263985]     |
|170|[0.13082151, -0.082676716, 0.15034986, -0.7333888, 0.14089121, -0.34780806, 0.51327425, -0.43825528, 0.2210635, -0.19778338]         |
|180|[-0.45791233, -0.64516217, 0.3496911, -0.6879449, 0.11970334, -0.3473338, 0.30204558, -0.18284592, 0.5934964, 0.06711411]            |
|190|[0.41464698, 0.04347724, -0.9297292, -1.2885705, -0.5567429, 0.2531382, 0.11184802, -0.46155334, -0.3385828, 0.789031]               |
|200|[0.37707302, -0.023397477, -0.47769275, -0.99200153, -0.11546725, -0.125011, -0.07772487, -0.5624814, -0.026348682, -0.33438805]     |
+---+-------------------------------------------------------------------------------------------------------------------------------------+

And here is the code for K-Means clustering.

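import org.apache.spark.ml.clustering.KMeans

val kmeans = new KMeans().setK(10).setSeed(1L)
val kmeans_model = kmeans.fit(item_factors)
// Named differently from the ALS `predictions` above to avoid redefining that val
val cluster_predictions = kmeans_model.transform(item_factors)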

Feeding the item_factors DataFrame into K-Means fails with the following error:

Exception in thread "main" java.lang.IllegalArgumentException: requirement failed: Column features must be of type org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7 but was actually ArrayType(FloatType,false).
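
The mismatch named in the error can be checked by inspecting the schema of the features column. The snippet below is a minimal sketch: it assumes a local SparkSession and uses a small hand-built DataFrame as a hypothetical stand-in for model.itemFactors, so the values and the name isFloatArray are illustrative only.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types.{ArrayType, FloatType}

// Assumption: a local session; in spark-shell the existing one is reused.
val spark = SparkSession.builder()
  .appName("schema-check")
  .master("local[*]")
  .getOrCreate()

// Hypothetical stand-in for model.itemFactors: an id plus an array of floats.
val item_factors = spark.createDataFrame(Seq(
  (10, Array(-0.13f, 0.07f)),
  (20, Array(-0.003f, 0.31f))
)).toDF("id", "features")

// Print the schema and compare the column type with the one in the error message.
item_factors.printSchema()
val featuresType = item_factors.schema("features").dataType
val isFloatArray = featuresType == ArrayType(FloatType, containsNull = false)
```

If isFloatArray is true, the column is an array of floats rather than an ML vector, which is what the K-Means requirement check rejects.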

1 Reply

You can map the array column to a vector column with a UDF:

import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.functions.{col, udf}

val itemFactors = model.itemFactors
// Wrap each array of doubles in a dense ML vector
val convertUDF = udf((array: Seq[Double]) => Vectors.dense(array.toArray))
// The factors are array<float>, so cast to array<double> before applying the UDF
val withVector = itemFactors
  .withColumn("features", convertUDF(col("features").cast("array<double>")))
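
To show the conversion feeding into K-Means end to end, here is a minimal sketch under a few assumptions: a local SparkSession, a small hand-built DataFrame as a hypothetical stand-in for model.itemFactors, and k = 2 instead of 10 because the toy data has only four rows. The names toVector and clustered are illustrative.

```scala
import org.apache.spark.ml.clustering.KMeans
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, udf}

// Assumption: a local session; in spark-shell the existing one is reused.
val spark = SparkSession.builder()
  .appName("item-factor-clustering")
  .master("local[*]")
  .getOrCreate()

// Hypothetical stand-in for model.itemFactors (features is array<float>).
val itemFactors = spark.createDataFrame(Seq(
  (10, Array(-0.13f, 0.07f)),
  (20, Array(-0.003f, 0.31f)),
  (30, Array(0.19f, -0.31f)),
  (40, Array(0.29f, -0.22f))
)).toDF("id", "features")

// Cast array<float> to array<double>, then wrap each array in a dense vector.
val toVector = udf((xs: Seq[Double]) => Vectors.dense(xs.toArray))
val withVector = itemFactors
  .withColumn("features", toVector(col("features").cast("array<double>")))

// K-Means reads the "features" column by default and adds a "prediction" column.
val kmeans = new KMeans().setK(2).setSeed(1L)
val clustered = kmeans.fit(withVector).transform(withVector)
clustered.show(false)
```

On Spark 3.1 or later, org.apache.spark.ml.functions.array_to_vector should be usable in place of the hand-written UDF; that is a substitution, not what the answer above uses.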
