
Commit 0e49d61

update spark, metric.md
1 parent f0b77e2 commit 0e49d61

File tree

8 files changed: +224 -29 lines changed

src/main/resources/metrics_learning.md

Lines changed: 36 additions & 24 deletions
@@ -5,40 +5,35 @@
 * False Positive (FP) - label is negative, predicted as positive
 * False Negative (FN) - label is positive, predicted as negative

-
-|metrics|definition|description|
-|-------|----------|-----------|
-|Precision (Positive Predictive Value)|\\(PPV=\frac{TP}{TP + FP}\\)|precision|
-|Recall (True Positive Rate)|\\(TPR=\frac{TP}{P}=\frac{TP}{TP + FN}\\)|recall|
-|FPR|\\(FPR = \frac{FP}{FP+TN}\\)||
-|F-measure|\\(F(\beta) = \left(1 + \beta^2\right) \cdot \left(\frac{PPV \cdot TPR}{\beta^2 \cdot PPV + TPR}\right) = \frac{1}{\frac{1}{1+\beta^2}\frac{1}{\text{PPV}}+\frac{\beta^2}{1+\beta^2}\frac{1}{\text{TPR}}}\\)|\\(\beta\\) encodes the preference between the two metrics: \\(\beta < 1\\) weights Precision more;<br/>\\(\beta > 1\\) weights Recall more;<br/>\\(\beta = 1\\) reduces to F1.|
-|Receiver Operating Characteristic (ROC)|\\(FPR(T)=\int^\infty_{T} P_0(x)\,dx\\) <br/> \\(TPR(T)=\int^\infty_{T} P_1(x)\,dx\\)||
-|Area Under ROC Curve|\\(AUROC=\int^1_{0} \frac{TP}{P} d\left(\frac{FP}{N}\right)\\)||
-|Area Under Precision-Recall Curve|\\(AUPRC=\int^1_{0} \frac{TP}{TP+FP} d\left(\frac{TP}{P}\right)\\)||
-|MAE|\\(MAE = \frac{1}{n}\sum_{i=1}^n \mid y_i - \hat{y_i} \mid\\)|negative-oriented (lower is better), range \\([0, \infty)\\);<br/>compared with RMSE: easy to interpret, understand and compute|
-|RMSE|\\(RMSE = \sqrt{\frac{1}{n}\sum_{i=1}^n \left(y_i - \hat{y_i}\right)^2}\\)|also written RMSD (root mean square deviation); negative-oriented (lower is better), range \\([0, \infty)\\);<br/>bounds the error magnitude more tightly and flags large errors effectively|
-|DCG|\\(DCG = rel_1+\sum_{i=2}^p \frac{rel_i}{\log_2 i}\\)|DCG is maximal, iDCG (ideal DCG), when items are sorted by monotonically decreasing relevance|
-|NDCG|\\(NDCG = \frac{DCG}{iDCG}\\)| |
-
-### MAE: mean absolute error
-
-### RMSE: root mean square error
+### Confusion Matrix
+![](https://note.youdao.com/yws/api/personal/file/7456F54FE899436D863546AAF7A20F77?method=download&shareKey=a823568a6551ae56eb90cedaf2c594a9)

 ### KS TEST
 - Based on the cumulative distribution function; tests whether data follow a given distribution, or whether two distributions are the same;
 - Measures a model's ability to separate positives from negatives, i.e. the degree of separation between the positive and negative score distributions;
+- The table below shows the per-bucket and cumulative counts of positive and negative samples:
+![](https://note.youdao.com/yws/api/personal/file/1C8823E28461422B8ACB38FD8ADAEFC7?method=download&shareKey=6bf418d12724853f1c36c9fd099a534e "ks chart")
+- The chart below shows how the two cumulative distributions evolve as the bucket threshold moves:
+![](https://note.youdao.com/yws/api/personal/file/324FAED6FCE84E9788B3C575056E7293?method=download&shareKey=97ea70dfce5e51be408470a935505275 "ks chart")
+- In bucket 7 the gap between the two cumulative distributions reaches its maximum (94% - 12% = 82%), i.e. the KS statistic is 0.82.

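To make the bucket construction concrete, here is a minimal sketch (not part of this commit) of a hypothetical `ksStatistic` helper that computes the KS statistic over equal-width score buckets, assuming scores fall in [0, 1) and both classes are present:

```scala
import org.apache.spark.rdd.RDD

// KS statistic: the maximum gap between the positive and negative
// cumulative score distributions, approximated over equal-width buckets.
def ksStatistic(scoreLabel: RDD[(Double, Double)], numBuckets: Int = 10): Double = {
  val posTotal = scoreLabel.filter(_._2 == 1.0).count().toDouble
  val negTotal = scoreLabel.filter(_._2 == 0.0).count().toDouble
  // per-bucket (positive, negative) counts
  val counts = scoreLabel
    .map { case (score, label) =>
      val bucket = math.min((score * numBuckets).toInt, numBuckets - 1)
      (bucket, if (label == 1.0) (1L, 0L) else (0L, 1L))
    }
    .reduceByKey { case ((p1, n1), (p2, n2)) => (p1 + p2, n1 + n2) }
    .collect()
    .sortBy(_._1)
  // walk the buckets, tracking the largest gap between the two CDFs
  var cumPos = 0L; var cumNeg = 0L; var ks = 0.0
  for ((_, (p, n)) <- counts) {
    cumPos += p; cumNeg += n
    ks = math.max(ks, math.abs(cumPos / posTotal - cumNeg / negTotal))
  }
  ks
}
```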

-### DCG: discounted cumulative gain
-- In information retrieval, a metric for the ranking quality of a search engine (ranking system, recommender system);
-- DCG grows when high-relevance items are ranked near the top, and shrinks as they move further down;
+### ROC: Receiver operating characteristic
+- The curve traced by the points (FPR, TPR);

+![](https://note.youdao.com/yws/api/personal/file/409EEFE4D636423B9022ED9F60488D18?method=download&shareKey=6b22beb44139d97e26106d17648efa1a "roc")

-### NDCG: Normalized DCG
-
-### AUC: Area Under ROC Curve(Receiver operating characteristic)
+### AUC: Area Under ROC Curve
+- An AUC of 1 means the classifier separates positives from negatives perfectly: a perfect classifier;
+- An AUC of 0.5 means the classifier cannot distinguish positives from negatives: a purely random classifier;
+- An AUC above 0.8 is generally considered a good classifier.

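Spark's BinaryClassificationMetrics (see the links below) computes these curve areas directly; a small usage sketch, assuming an RDD of (score, 0/1 label) pairs:

```scala
import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
import org.apache.spark.rdd.RDD

// scoreAndLabels: (predicted score, true label in {0.0, 1.0})
def reportAreas(scoreAndLabels: RDD[(Double, Double)]): Unit = {
  val metrics = new BinaryClassificationMetrics(scoreAndLabels)
  println(s"AUROC = ${metrics.areaUnderROC()}")  // 1.0 = perfect, 0.5 = random
  println(s"AUPRC = ${metrics.areaUnderPR()}")
  metrics.roc().take(5).foreach(println)         // the first few (FPR, TPR) points
}
```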

 ### AUPRC: Area Under Precision-Recall Curve

+### DCG: discounted cumulative gain
+- In information retrieval, a metric for the ranking quality of a search engine (ranking system, recommender system);
+- DCG grows when high-relevance items are ranked near the top, and shrinks as they move further down;
+
+### NDCG: Normalized DCG

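The DCG/NDCG definitions translate directly into a few lines of Scala; an illustrative sketch (the `dcg`/`ndcg` helpers are hypothetical), with the relevance grades \\(rel_i\\) given in ranked order:

```scala
// DCG = rel_1 + sum_{i=2..p} rel_i / log2(i), with rels in ranked order
def dcg(rels: Seq[Double]): Double =
  rels.zipWithIndex.map { case (rel, idx) =>
    if (idx == 0) rel else rel / (math.log(idx + 1) / math.log(2))
  }.sum

def ndcg(rels: Seq[Double]): Double = {
  val ideal = dcg(rels.sortBy(-_))  // iDCG: relevance sorted descending
  if (ideal == 0.0) 0.0 else dcg(rels) / ideal
}

// e.g. ndcg(Seq(3.0, 2.0, 3.0, 0.0, 1.0, 2.0))
```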

 > https://en.wikipedia.org/wiki/Mean_absolute_error
 > https://en.wikipedia.org/wiki/Root-mean-square_deviation
@@ -47,4 +42,21 @@
 > http://spark.apache.org/docs/2.2.1/mllib-evaluation-metrics.html
 > https://en.wikipedia.org/wiki/Receiver_operating_characteristic

+
+### Formulas
+
+|metrics|definition|description|
+|-------|----------|-----------|
+|Precision (Positive Predictive Value)|\\(PPV=\frac{TP}{TP + FP}\\)|precision|
+|Recall (True Positive Rate)|\\(TPR=\frac{TP}{P}=\frac{TP}{TP + FN}\\)|recall|
+|FPR|\\(FPR = \frac{FP}{FP+TN}\\)||
+|F-measure|\\(F(\beta) = \left(1 + \beta^2\right) \cdot \left(\frac{PPV \cdot TPR}{\beta^2 \cdot PPV + TPR}\right) = \frac{1}{\frac{1}{1+\beta^2}\frac{1}{\text{PPV}}+\frac{\beta^2}{1+\beta^2}\frac{1}{\text{TPR}}}\\)|\\(\beta\\) encodes the preference between the two metrics: \\(\beta < 1\\) weights Precision more;<br/>\\(\beta > 1\\) weights Recall more;<br/>\\(\beta = 1\\) reduces to F1.|
+|Receiver Operating Characteristic (ROC)|\\(FPR(T)=\int^\infty_{T} P_0(x)\,dx\\) <br/> \\(TPR(T)=\int^\infty_{T} P_1(x)\,dx\\)||
+|Area Under ROC Curve|\\(AUROC=\int^1_{0} \frac{TP}{P} d\left(\frac{FP}{N}\right)\\)||
+|Area Under Precision-Recall Curve|\\(AUPRC=\int^1_{0} \frac{TP}{TP+FP} d\left(\frac{TP}{P}\right)\\)||
+|MAE (mean absolute error)|\\(MAE = \frac{1}{n}\sum_{i=1}^n \mid y_i - \hat{y_i} \mid\\)|negative-oriented (lower is better), range \\([0, \infty)\\);<br/>compared with RMSE: easy to interpret, understand and compute|
+|RMSE (root mean square error)|\\(RMSE = \sqrt{\frac{1}{n}\sum_{i=1}^n \left(y_i - \hat{y_i}\right)^2}\\)|also written RMSD (root mean square deviation); negative-oriented (lower is better), range \\([0, \infty)\\);<br/>bounds the error magnitude more tightly and flags large errors effectively|
+|DCG|\\(DCG = rel_1+\sum_{i=2}^p \frac{rel_i}{\log_2 i}\\)|DCG is maximal, iDCG (ideal DCG), when items are sorted by monotonically decreasing relevance|
+|NDCG|\\(NDCG = \frac{DCG}{iDCG}\\)| |
+

 <script type="text/javascript" src="http://cdn.mathjax.org/mathjax/latest/MathJax.js?config=default"></script>
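As a worked example of the formulas in the table (illustrative, not part of the original file; the `fMeasure` helper is hypothetical), precision, recall and \\(F(\beta)\\) from raw confusion counts:

```scala
// F(beta) from TP/FP/FN; beta < 1 favours precision, beta > 1 favours recall
def fMeasure(tp: Long, fp: Long, fn: Long, beta: Double = 1.0): Double = {
  val precision = tp.toDouble / (tp + fp)
  val recall    = tp.toDouble / (tp + fn)
  val b2 = beta * beta
  (1 + b2) * precision * recall / (b2 * precision + recall)
}

// fMeasure(80, 20, 20)       // precision = recall = 0.8, so F1 = 0.8
// fMeasure(80, 20, 20, 0.5)  // weights precision more heavily
```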

src/main/scala/com/hyzs/spark/ml/MatrixOps.scala

Lines changed: 0 additions & 1 deletion
@@ -23,7 +23,6 @@ object MatrixOps extends App{
     IndexedRow(index, vector)
   }

-
  val rowNum = matrix.count()

  val first = matrix.first()
src/main/scala/com/hyzs/spark/ml/ModelEvaluation.scala

Lines changed: 38 additions & 4 deletions

@@ -1,20 +1,54 @@
 package com.hyzs.spark.ml

+import com.hyzs.spark.ml.evaluation.{BinaryConfusionMatrix, BinaryConfusionMatrixImpl, BinaryLabelCounter}
 import org.apache.spark.rdd.RDD
 import com.hyzs.spark.utils.SparkUtils._
+import com.hyzs.spark.utils.BaseUtil._
+import org.apache.spark.mllib.linalg.{Vector, Vectors}
+import org.apache.spark.sql._
+import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics

 /**
   * Created by xk on 2018/5/9.
   */
 object ModelEvaluation {
+  val threshold = 0.5

-  // row( index, label, score)
-  def loadDataFromTable(tableName:String): RDD[(Int,Double)] = {
-    val scoreRdd = spark.table(tableName).rdd.map(row => (row.getInt(1), row.getDouble(2)))
+  // src row(index, score, label), result row(score, label)
+  def loadDataFromTable(tableName:String): RDD[Row] = {
+    val scoreRdd = spark.table(tableName).rdd.map(row => anySeqToRow(Seq(row(1), row(2))))
     scoreRdd
   }

-  val scoresRdd:RDD[(Int, Double)] = loadDataFromTable("scores")
+  // src row(score, label), result row(score, label, pred_label)
+  def getLabeledRDD(threshold:Double, rdd:RDD[Row]): RDD[Row] = {
+    val labeledRdd = rdd.map( row => {
+      val score = toDoubleDynamic(row(0))
+      // Row.fromSeq keeps each value a separate field; Row(...) would wrap the whole Seq as one field
+      if(score >= threshold) Row.fromSeq(row.toSeq :+ 1.0)
+      else Row.fromSeq(row.toSeq :+ 0.0)
+    })
+    labeledRdd
+  }
+
+  def getConfusionMatrix(threshold:Double, labeledRdd:RDD[Row]): BinaryConfusionMatrix = {
+    val posNum = labeledRdd.filter(row => row.getDouble(1) == 1.0).count()
+    val negNum = labeledRdd.filter(row => row.getDouble(1) == 0.0).count()
+    val truePosNum = labeledRdd.filter(row => row.getDouble(1) == 1.0 && row.getDouble(2) == 1.0).count()
+    val falsePosNum = labeledRdd.filter(row => row.getDouble(1) == 0.0 && row.getDouble(2) == 1.0).count()
+    val posCount = new BinaryLabelCounter(truePosNum, falsePosNum)
+    val totalCount = new BinaryLabelCounter(posNum, negNum)
+    val confusion = BinaryConfusionMatrixImpl(posCount, totalCount)
+    confusion
+  }
+
+  val scoresRdd:RDD[Row] = loadDataFromTable("scores")
+  val predictRdd:RDD[Row] = getLabeledRDD(threshold, scoresRdd)
+
+  val metrics = new BinaryClassificationMetrics(
+    scoresRdd.map( row => (toDoubleDynamic(row(0)), toDoubleDynamic(row(1))) )
+  )
+  val precision = metrics.precisionByThreshold


 }
+
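A possible way to consume these helpers (a sketch, assuming the `scores` table exists with (index, score, label) rows and the caller sits inside `com.hyzs.spark.ml`, where the `private[ml]` matrix types are visible); precision and recall follow directly from the confusion-matrix counts:

```scala
val cm = ModelEvaluation.getConfusionMatrix(0.5, ModelEvaluation.predictRdd)
val precision = cm.numTruePositives.toDouble / (cm.numTruePositives + cm.numFalsePositives)
val recall    = cm.numTruePositives.toDouble / cm.numPositives  // numPositives = TP + FN
println(s"precision = $precision, recall = $recall")
```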
src/main/scala/com/hyzs/spark/ml/evaluation/BinaryConfusionMatrix.scala

Lines changed: 71 additions & 0 deletions

@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.hyzs.spark.ml.evaluation
+
+
+/**
+  * Created by xk on 2018/5/10.
+  */
+private[ml] trait BinaryConfusionMatrix {
+
+  /** number of true positives */
+  def numTruePositives: Long
+
+  /** number of false positives */
+  def numFalsePositives: Long
+
+  /** number of false negatives */
+  def numFalseNegatives: Long
+
+  /** number of true negatives */
+  def numTrueNegatives: Long
+
+  /** number of positives */
+  def numPositives: Long = numTruePositives + numFalseNegatives
+
+  /** number of negatives */
+  def numNegatives: Long = numFalsePositives + numTrueNegatives
+}
+
+/**
+  * Implementation of [[BinaryConfusionMatrix]], adapted from
+  * org.apache.spark.mllib.evaluation.binary.BinaryConfusionMatrixImpl.
+  *
+  * @param count label counter for labels with scores greater than or equal to the current score
+  * @param totalCount label counter for all labels
+  */
+private[ml] case class BinaryConfusionMatrixImpl( count: BinaryLabelCounter,
+                                                  totalCount: BinaryLabelCounter) extends BinaryConfusionMatrix {
+
+  /** number of true positives */
+  override def numTruePositives: Long = count.numPositives
+
+  /** number of false positives */
+  override def numFalsePositives: Long = count.numNegatives
+
+  /** number of false negatives */
+  override def numFalseNegatives: Long = totalCount.numPositives - count.numPositives
+
+  /** number of true negatives */
+  override def numTrueNegatives: Long = totalCount.numNegatives - count.numNegatives
+
+  /** number of positives */
+  override def numPositives: Long = totalCount.numPositives
+
+  /** number of negatives */
+  override def numNegatives: Long = totalCount.numNegatives
+}
src/main/scala/com/hyzs/spark/ml/evaluation/BinaryLabelCounter.scala

Lines changed: 36 additions & 0 deletions

@@ -0,0 +1,36 @@
+package com.hyzs.spark.ml.evaluation
+
+/**
+  * Created by xk on 2018/5/10.
+  */
+/**
+  * A counter for positives and negatives.
+  *
+  * @param numPositives number of positive labels
+  * @param numNegatives number of negative labels
+  */
+private[ml] class BinaryLabelCounter( var numPositives: Long = 0L,
+                                      var numNegatives: Long = 0L) extends Serializable {
+
+  /** Processes a label. */
+  def +=(label: Double): BinaryLabelCounter = {
+    // Though we assume 1.0 for positive and 0.0 for negative, the following check will handle
+    // -1.0 for negative as well.
+    if (label >= 0.5) numPositives += 1L else numNegatives += 1L
+    this
+  }
+
+  /** Merges another counter. */
+  def +=(other: BinaryLabelCounter): BinaryLabelCounter = {
+    numPositives += other.numPositives
+    numNegatives += other.numNegatives
+    this
+  }
+
+  override def clone: BinaryLabelCounter = {
+    new BinaryLabelCounter(numPositives, numNegatives)
+  }
+
+  override def toString: String = s"{numPos: $numPositives, numNeg: $numNegatives}"
+}
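Because the counter is Serializable and `+=` returns `this`, it composes with `RDD.aggregate`; a usage sketch (the `countLabels` helper is hypothetical and would need to live where the `private[ml]` class is visible, i.e. under `com.hyzs.spark.ml`):

```scala
import org.apache.spark.rdd.RDD

// labels: RDD[Double] with values in {0.0, 1.0} (or -1.0 for negative)
def countLabels(labels: RDD[Double]): BinaryLabelCounter =
  labels.aggregate(new BinaryLabelCounter())(
    (counter, label) => counter += label,  // fold one label into a partition-local counter
    (c1, c2) => c1 += c2                   // merge partition counters
  )
```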

src/main/scala/com/hyzs/spark/utils/BaseUtil.scala

Lines changed: 20 additions & 0 deletions
@@ -1,6 +1,9 @@
 package com.hyzs.spark.utils

 import java.text.SimpleDateFormat
+
+import org.apache.spark.mllib.linalg.{Vector, Vectors}
+import org.apache.spark.sql.Row
 import scala.util.{Failure, Success, Try}
 /**
   * Created by Administrator on 2018/2/5.
@@ -18,7 +21,24 @@ object BaseUtil {
     }
   }

+  def toDoubleDynamic(x: Any): Double = x match {
+    case s: String => s.toDouble
+    case num: java.lang.Number => num.doubleValue()
+    case _ => throw new ClassCastException("cannot cast to double")
+  }

+  def anySeqToSparkVector[T](x: Any): Vector = x match {
+    case a: Array[T] => Vectors.dense(a.map(toDoubleDynamic))
+    case s: Seq[Any] => Vectors.dense(s.toArray.map(toDoubleDynamic))
+    case v: Vector => v
+    case _ => throw new ClassCastException("unsupported class")
+  }

+  // Row.fromSeq keeps each converted value a separate field; Row(...) would nest the whole Seq as one field
+  def anySeqToRow[T](x: Any): Row = x match {
+    case a: Array[T] => Row.fromSeq(a.map(toDoubleDynamic))
+    case s: Seq[Any] => Row.fromSeq(s.map(toDoubleDynamic))
+    case r: Row => Row.fromSeq(r.toSeq.map(toDoubleDynamic))
+    case _ => throw new ClassCastException("unsupported class")
+  }

 }

src/main/scala/com/hyzs/spark/utils/SparkUtils.scala

Lines changed: 3 additions & 0 deletions
@@ -6,6 +6,7 @@ import com.fasterxml.jackson.module.scala.DefaultScalaModule
 import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs.{FileStatus, FileSystem, FileUtil, Path}
 import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.mllib.linalg.{Vector, Vectors}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{Dataset, Row, SparkSession}
 import org.apache.spark.{SparkConf, SparkContext}
@@ -154,4 +155,6 @@ object SparkUtils {
     SizeEstimator.estimate(rdd)
   }

+
+
 }

src/test/scala/ScalaTest.scala

Lines changed: 20 additions & 0 deletions
@@ -4,6 +4,8 @@ import scala.annotation.tailrec
 import java.io._

 import com.hyzs.spark.utils.BaseUtil
+import com.hyzs.spark.utils.BaseUtil._
+import org.apache.spark.sql.Row

 import scala.io.Source
 import scala.util.Random
@@ -87,4 +89,22 @@ class ScalaTest extends FunSuite{
     writer.close()
   }

+  test("evaluation test file"){
+    val writer = new PrintWriter(new File("d:/evaluation_test.txt"))
+    for( i <- 0 until 100000){
+      writer.write(i + ",")
+      writer.write(Random.nextDouble + ",")
+      writer.write(Random.nextInt(2) + "\n")
+    }
+    writer.close()
+  }
+
+  test("type cast in spark"){
+    val row = Row(1, 3, 4.0, "5")
+    for(v <- row.toSeq){
+      println(toDoubleDynamic(v))
+    }
+    println(anySeqToSparkVector(Array(1, 2.3, 3)))
+    println(anySeqToSparkVector(row.toSeq))
+  }
 }
