在Android平台上使用决策树方法进行简单的图像分类识别

什么是决策树

决策树是一类通过自动学习数据间特征,不断分裂树形决策,最终得到一堆if-else从而进行分类的机器学习算法。

一图以蔽之:

image-20201105150720166

决策树的详细原理这里不再深究,网上有很多相关介绍,这里我们只需要知道用 scikit-learn 工具可以非常方便得实现模型训练和使用就好了。本文关心的主题是训练的代码怎么写,以及具体要怎样在 Android 平台使用。

如何训练模型

  • 准备和清洗数据,按类别放到各自文件夹下,随后导入数据,我们的样本如下:

image-20201105172015887

  • 设置好各自 label, 然后将图片二值化处理并降维(由于代码历史原因,具体降维时会将特征缩放到 -1,1 的二值化),得到特征如下:

image-20201105172207754

我这里将每个特征压缩到了 18*15,上图中的第一个为负样本(就是各类噪音数据作为一个单独的类别)。

  • 使用 train_test_split 切割训练和验证集
1
2
3
4
5
from sklearn.model_selection import train_test_split
X_train,X_test,y_train,y_test = train_test_split(ds_feature,ds_label,test_size=0.1)

X_train = np.array(X_train).reshape((np.shape(X_train)[0], 18*15)) # flatten
X_test = np.array(X_test).reshape((np.shape(X_test)[0], 18*15))
  • 真正的训练很简单,使用 sklearn 中的决策树对象在数据集上 fit 即可。然后验证。若是使用随机森林方法,也是类似写法,只是要调的参数多一点,生成的模型也成倍增大。
1
2
3
4
5
6
7
from sklearn import tree

clf = tree.DecisionTreeClassifier(max_depth=11)
clf.fit(X_train, y_train)

clf.score(X_test, y_test)
# 0.9962825278810409

其中 max_depth 是多次尝试调参后的选择。需要自己在数据集上尝试。

决策树的可视化

1
2
3
4
5
6
7
8
9
10
11
import pydotplus
import graphviz
from sklearn import tree
from IPython.display import Image

dot_tree = tree.export_graphviz(clf, out_file=None,
feature_names=dummy_feature_names,class_names=class_names,
filled=True, rounded=True, special_characters=True)
graph = pydotplus.graph_from_dot_data(dot_tree)
# img = Image(graph.create_png())
graph.write_png("out.png")

打开生成的文件就可以看到类似前文中的二叉树图像。

如何在 Android 上运行

直接使用 Java 实现决策树代码,事实上,这个代码甚至于不需要自己写,可以用一个 github 上的工具辅助生成 java 代码并保存模型数据。

1
2
3
4
5
6
7
8
from sklearn.ensemble import DecisionTreeClassifier

clf = tree.DecisionTreeClassifier(max_depth=11)
clf.fit(X_train, y_train)

from sklearn_porter import Porter
porter_clf = Porter(clf, language='java')
output = porter_clf.export(export_data=True)

将 output 打印出来,得到如下 java 代码,直接整个拷贝到自己的项目中去就能用了!

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import com.google.gson.Gson;
import java.io.File;
import java.io.FileNotFoundException;
import java.util.Scanner;


class DecisionTreeClassifier {

private class Classifier {
private int[] leftChilds;
private int[] rightChilds;
private double[] thresholds;
private int[] indices;
private int[][] classes;
}
private Classifier clf;

public DecisionTreeClassifier(String file) throws FileNotFoundException {
String jsonStr = new Scanner(new File(file)).useDelimiter("\\Z").next();
this.clf = new Gson().fromJson(jsonStr, Classifier.class);
}

public int predict(double[] features, int node) {
if (this.clf.thresholds[node] != -2) {
if (features[this.clf.indices[node]] <= this.clf.thresholds[node]) {
return predict(features, this.clf.leftChilds[node]);
} else {
return predict(features, this.clf.rightChilds[node]);
}
}
return findMax(this.clf.classes[node]);
}
public int predict(double[] features) {
return this.predict(features, 0);
}

private int findMax(int[] nums) {
int index = 0;
for (int i = 0; i < nums.length; i++) {
index = nums[i] > nums[index] ? i : index;
}
return index;
}

public static void main(String[] args) throws FileNotFoundException {
if (args.length > 0 && args[0].endsWith(".json")) {

// Features:
double[] features = new double[args.length-1];
for (int i = 1, l = args.length; i < l; i++) {
features[i - 1] = Double.parseDouble(args[i]);
}

// Parameters:
String modelData = args[0];

// Estimators:
DecisionTreeClassifier clf = new DecisionTreeClassifier(modelData);

// Prediction:
int prediction = clf.predict(features);
System.out.println(prediction);

}
}
}

从代码中可以看到,主要是使用 gson 加载 Tree 对象,使用 #predict 方法,深搜计算模型结果,符合我们对决策树模型的认知。

以随机森林模型为例, 在 Android 上,在初始化时,仅需使用模型文件构建模型对象

1
mRandomForestClf = new RandomForestClassifier(path);

然后在预测时,假定输入特征是从 opencv 来的,将特征展开为一维后,直接调用模型的 #predict 方法就能计算得到结果了:

1
2
3
4
5
6
7
8
9
10
11
12
private int recognizeOnSampleRF(Mat matBinSample) {
int featureSize = matBinSample.rows() * matBinSample.cols();
double[] features = new double[featureSize];

Mat matDouble = new Mat(matBinSample.rows(), matBinSample.cols(), CvType.CV_64F);
matBinSample.convertTo(matDouble, CvType.CV_64F);
matDouble.get(0, 0, features);

int prediction = mRandomForestClf.predict(features);

return prediction;
}

是不是比预想的要简单太多?毕竟决策树的具体算法代码居然由大佬为你自动生成好了。

决策树算法由于是 if-else 的组合,所以推演计算迅速,效率很高。由于深度可以无限增加,从而理论上可以拟合各种问题。 但是要注意,决策树非常容易过拟合,因此实作时一般换用随机森林模型,并要对深度、剪枝等参数作控制。

(Hidden Content)