/**
 * Loads a decision-tree model from a JSON file.
 *
 * <p>The whole file is read into a single string and deserialized with Gson into the
 * {@code Classifier} model stored in {@code this.clf}.
 *
 * @param file path of the JSON model file
 * @throws FileNotFoundException if {@code file} cannot be opened
 * @throws java.util.NoSuchElementException if the file contains no input to read
 */
public DecisionTreeClassifier(String file) throws FileNotFoundException {
    String jsonStr;
    // try-with-resources closes the Scanner (and the file handle it owns) even if
    // reading fails; the previous version never closed it and leaked the handle.
    try (Scanner scanner = new Scanner(new File(file))) {
        // "\\Z" as the delimiter makes next() return the whole input as one token.
        jsonStr = scanner.useDelimiter("\\Z").next();
    }
    this.clf = new Gson().fromJson(jsonStr, Classifier.class);
}
/**
 * Predicts the class index for {@code features}, starting the tree walk at {@code node}.
 *
 * <p>At each internal node, the feature selected by {@code clf.indices[node]} is compared
 * with {@code clf.thresholds[node]}. Values {@code <=} the threshold go to the left child
 * and all others go to the right child. A threshold of {@code -2} marks a leaf. That leaf's
 * {@code clf.classes} row is reduced to a class index by {@link #findMax(int[])}.
 * NOTE(review): -2 looks like scikit-learn's TREE_UNDEFINED sentinel, which presumes the
 * model JSON was exported from scikit-learn; confirm against the exporter.
 *
 * <p>The walk is iterative rather than recursive, so very deep trees cannot exhaust the
 * call stack.
 *
 * @param features feature vector; must be long enough for every index in {@code clf.indices}
 * @param node index of the node to start from ({@code 0} is the root)
 * @return index of the largest entry in the reached leaf's {@code classes} row
 * @throws IllegalStateException if the child links contain a cycle (malformed model)
 */
public int predict(double[] features, int node) {
    int current = node;
    // In a well-formed tree, a walk to a leaf visits each node at most once. Needing more
    // steps than there are nodes can only mean cyclic child links, so fail fast instead
    // of looping forever.
    int remainingSteps = this.clf.thresholds.length;
    while (this.clf.thresholds[current] != -2) {
        if (remainingSteps-- <= 0) {
            throw new IllegalStateException(
                    "Cycle detected in tree child links starting at node " + node);
        }
        if (features[this.clf.indices[current]] <= this.clf.thresholds[current]) {
            current = this.clf.leftChilds[current];
        } else {
            current = this.clf.rightChilds[current];
        }
    }
    return findMax(this.clf.classes[current]);
}

/**
 * Predicts the class index for {@code features}, starting at the root node (index 0).
 *
 * @param features feature vector to classify
 * @return the predicted class index
 */
public int predict(double[] features) {
    return this.predict(features, 0);
}
/**
 * Returns the position of the largest value in {@code nums} (an argmax).
 *
 * <p>When several entries share the maximum, the lowest such position wins, because only
 * a strictly larger value replaces the current best. An empty array yields {@code 0}.
 *
 * @param nums values to scan; must not be {@code null}
 * @return index of the first maximal element, or {@code 0} if {@code nums} is empty
 */
private int findMax(int[] nums) {
    int best = 0;
    // Position 0 is the initial best, so scanning can begin at position 1.
    for (int candidate = 1; candidate < nums.length; candidate++) {
        if (nums[candidate] > nums[best]) {
            best = candidate;
        }
    }
    return best;
}
/**
 * Command-line entry point: classifies one feature vector with a JSON model.
 *
 * <p>Usage: {@code <model.json> <feature1> <feature2> ...}. The predicted class index is
 * printed to standard output. If the first argument is missing or does not end in
 * {@code .json}, nothing is done.
 *
 * <p>The statements below were previously swallowed by the {@code // Features:}-style
 * labels once the lines were joined, so {@code main} did nothing. Each label now sits on
 * its own line.
 *
 * @param args model path followed by the feature values
 * @throws FileNotFoundException if the model file cannot be opened
 */
public static void main(String[] args) throws FileNotFoundException {
    if (args.length > 0 && args[0].endsWith(".json")) {

        // Features:
        double[] features = new double[args.length - 1];
        for (int i = 1, l = args.length; i < l; i++) {
            features[i - 1] = Double.parseDouble(args[i]);
        }

        // Parameters:
        String modelData = args[0];

        // Estimators:
        DecisionTreeClassifier clf = new DecisionTreeClassifier(modelData);

        // Prediction:
        int prediction = clf.predict(features);
        System.out.println(prediction);

    }
}
} // closes the enclosing class
从代码中可以看到，主要是先用 Gson 将 JSON 模型文件反序列化为 Classifier 对象，再由 #predict 方法从根节点出发，在每个节点上比较特征值与阈值，沿树逐层向下走到叶节点，得出预测结果。这符合我们对决策树模型的认知。
以随机森林模型为例，在 Android 上初始化时，只需使用模型文件构建模型对象：
// Build the random-forest model once, at initialization, from the model file.
// NOTE(review): `path` is presumably the location of the exported JSON model file; confirm at the call site.
mRandomForestClf = new RandomForestClassifier(path);
private int recognizeOnSampleRF(Mat matBinSample) { int featureSize = matBinSample.rows() * matBinSample.cols(); double[] features = new double[featureSize];
Mat matDouble = new Mat(matBinSample.rows(), matBinSample.cols(), CvType.CV_64F); matBinSample.convertTo(matDouble, CvType.CV_64F); matDouble.get(0, 0, features);
int prediction = mRandomForestClf.predict(features);