
apache spark sql - Scala: how to know which probability corresponds to which class?

I created a random forest classifier to predict something. The label is either "yes" (= 1.0) or "no" (= 0.0).

I apply my model to a test set. Here is my code and my result:

import org.apache.spark.ml.tuning.CrossValidatorModel
import org.apache.spark.sql._
import org.apache.spark.sql.functions._

// Load the cross-validated random forest model and the test data
val modelrf = CrossValidatorModel.load("modelSupervise/newModel")
val test = spark.sql("""select * from dc.newTest""")

// Score the test data
val predictions = modelrf.transform(test)

predictions.select("id", "label", "rawPrediction", "probability", "prediction").show(20, false)


+--------+--------------+----------------------------------------+-----------------------------------------+----------+
|id      |label         |rawPrediction                           |probability                              |prediction|
+--------+--------------+----------------------------------------+-----------------------------------------+----------+
|1       |0             |[18.954508743604,1.0454912563959982]    |[0.9477254371802001,0.05227456281979992] |0.0       |
|2       |0             |[19.396893651115214,0.6031063488847838] |[0.9698446825557608,0.030155317444239195]|0.0       |
|3       |0             |[19.562942473138747,0.4370575268612524] |[0.9781471236569373,0.02185287634306262] |0.0       |
|4       |0             |[19.072030495384865,0.9279695046151306] |[0.9536015247692434,0.04639847523075654] |0.0       |
|5       |0             |[19.43338228765314,0.5666177123468583]  |[0.9716691143826571,0.02833088561734292] |0.0       |
|6       |0             |[19.696154641398266,0.3038453586017339] |[0.9848077320699133,0.015192267930086694]|0.0       |
|7       |0             |[19.561887703818552,0.4381122961814507] |[0.9780943851909274,0.02190561480907253] |0.0       |
|8       |0             |[19.670868420870097,0.32913157912990343]|[0.9835434210435048,0.01645657895649517] |0.0       |
|9       |0             |[19.31258444658832,0.6874155534116762]  |[0.9656292223294163,0.034370777670583816]|0.0       |
|10      |1             |[19.324118365007614,0.6758816349923846] |[0.9662059182503807,0.03379408174961923] |0.0       |
|11      |0             |[19.671923190190295,0.32807680980970505]|[0.9835961595095147,0.016403840490485253]|0.0       |
|12      |0             |[5.549867107480572,14.450132892519427]  |[0.2774933553740286,0.7225066446259714]  |1.0       |
|13      |0             |[8.302734500577003,11.697265499422995]  |[0.41513672502885013,0.5848632749711498] |1.0       |
|14      |0             |[3.719926021010336,16.280073978989666]  |[0.1859963010505168,0.8140036989494831]  |1.0       |
|15      |1             |[4.9810130629790486,15.018986937020955] |[0.2490506531489524,0.7509493468510476]  |1.0       |
|16      |1             |[7.575144612227263,12.424855387772734]  |[0.37875723061136324,0.6212427693886368] |1.0       |
|17      |0             |[9.763210063340546,10.236789936659454]  |[0.4881605031670273,0.5118394968329727]  |1.0       |
|18      |0             |[9.475787091640768,10.524212908359234]  |[0.4737893545820384,0.5262106454179617]  |1.0       |
|19      |1             |[4.236097613170449,15.763902386829551]  |[0.21180488065852243,0.7881951193414776] |1.0       |
|20      |0             |[8.748700591583557,11.251299408416445]  |[0.43743502957917785,0.5625649704208222] |1.0       |
|21      |0             |[8.908800090849974,11.091199909150026]  |[0.4454400045424987,0.5545599954575013]  |1.0       |
|22      |1             |[9.726530070446398,10.273469929553602]  |[0.4863265035223199,0.5136734964776801]  |1.0       |
|23      |1             |[8.908800090849974,11.091199909150026]  |[0.4454400045424987,0.5545599954575013]  |1.0       |
+--------+--------------+----------------------------------------+-----------------------------------------+----------+

Here is what I understand so far:

For id=1, 18.95 trees predict the value "0.0" and 1.045 trees predict the value "1.0". I thought that Spark orders the values of the "rawPrediction" vector by class value: the first element corresponds to class "0" and the second one to class "1".
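
As a quick sanity check on the numbers above (assuming the probability column is just rawPrediction normalized by the total vote count, which the figures seem to confirm):

// Sanity check for id = 1: the probability vector is rawPrediction
// divided by the sum of the votes (about 20 trees here)
val raw   = Array(18.954508743604, 1.0454912563959982)
val probs = raw.map(_ / raw.sum)   // ~ Array(0.9477..., 0.0523...), matching the table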

But if that ordering is correct, and if the labels were "yes" or "no" instead of 0 and 1, what order would Spark use? Alphabetical order?

I did some research and found this question: Random Forest Classifier: To which class corresponds the probabilities

The question is the same but for the "probability" vector. Which element of the vector corresponds to the probability of predicting "0", and which element corresponds to the probability of predicting "1"?

I do not understand the answer...

How can I know, for each row, the probability that the model predicts "yes" (or 1)? Does Spark order the probabilities numerically or alphabetically depending on the type of the label...?

Thank you in advance!!


1 Reply


Here is the answer!!! In my question I load a model.

But the explanation lies earlier, in how the model was fitted.

To fit the model I use a label indexer (Spark's StringIndexer) on my target. This indexer transforms the target into an index, ordered by descending frequency.

For example: if my target contains 20% "aa" and 80% "bb", the label indexer will create a "label" column that takes the value 0 for "bb" and 1 for "aa" (because "bb" is more frequent than "aa").
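
A minimal sketch of that step, assuming a training DataFrame named train with a string column "target" (both names are just for illustration):

import org.apache.spark.ml.feature.StringIndexer

// By default StringIndexer assigns indices by descending frequency:
// the most frequent value becomes 0.0, the next one 1.0, and so on.
val labelIndexer = new StringIndexer()
  .setInputCol("target")
  .setOutputCol("label")
  .fit(train)

// labels(i) is the original string value mapped to index i
println(labelIndexer.labels.mkString(", "))   // e.g. "bb, aa"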

When we fit a random forest on this indexed label, the elements of the probability vector follow that same order, i.e. the order of descending frequency in the training set.

In binary classification:

  • first probability = probability of the most frequent class in the training set
  • second probability = probability of the least frequent class in the training set
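
And if you want the probability of one specific class as its own column (say class index 1), a quick sketch using a UDF over the probability vector (the UDF and column names here are just illustrative):

import org.apache.spark.ml.linalg.Vector
import org.apache.spark.sql.functions.{col, udf}

// Extract the probability of class index 1 from the "probability" vector column
val probOfClass1 = udf((v: Vector) => v(1))

predictions
  .withColumn("prob_class_1", probOfClass1(col("probability")))
  .select("id", "label", "prob_class_1", "prediction")
  .show(20, false)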
