mirror of
https://gitcode.com/gh_mirrors/es/esp32-opencv.git
synced 2025-08-14 10:40:47 +08:00
initial commit
This commit is contained in:
116
samples/cpp/tree_engine.cpp
Normal file
116
samples/cpp/tree_engine.cpp
Normal file
@ -0,0 +1,116 @@
|
||||
#include "opencv2/ml.hpp"
|
||||
#include "opencv2/core.hpp"
|
||||
#include "opencv2/core/utility.hpp"
|
||||
#include <stdio.h>
|
||||
#include <string>
|
||||
#include <map>
|
||||
|
||||
using namespace cv;
|
||||
using namespace cv::ml;
|
||||
|
||||
static void help()
|
||||
{
|
||||
printf(
|
||||
"\nThis sample demonstrates how to use different decision trees and forests including boosting and random trees.\n"
|
||||
"Usage:\n\t./tree_engine [-r=<response_column>] [-ts=type_spec] <csv filename>\n"
|
||||
"where -r=<response_column> specified the 0-based index of the response (0 by default)\n"
|
||||
"-ts= specifies the var type spec in the form ord[n1,n2-n3,n4-n5,...]cat[m1-m2,m3,m4-m5,...]\n"
|
||||
"<csv filename> is the name of training data file in comma-separated value format\n\n");
|
||||
}
|
||||
|
||||
static void train_and_print_errs(Ptr<StatModel> model, const Ptr<TrainData>& data)
|
||||
{
|
||||
bool ok = model->train(data);
|
||||
if( !ok )
|
||||
{
|
||||
printf("Training failed\n");
|
||||
}
|
||||
else
|
||||
{
|
||||
printf( "train error: %f\n", model->calcError(data, false, noArray()) );
|
||||
printf( "test error: %f\n\n", model->calcError(data, true, noArray()) );
|
||||
}
|
||||
}
|
||||
|
||||
int main(int argc, char** argv)
|
||||
{
|
||||
cv::CommandLineParser parser(argc, argv, "{ help h | | }{r | 0 | }{ts | | }{@input | | }");
|
||||
if (parser.has("help"))
|
||||
{
|
||||
help();
|
||||
return 0;
|
||||
}
|
||||
std::string filename = parser.get<std::string>("@input");
|
||||
int response_idx;
|
||||
std::string typespec;
|
||||
response_idx = parser.get<int>("r");
|
||||
typespec = parser.get<std::string>("ts");
|
||||
if( filename.empty() || !parser.check() )
|
||||
{
|
||||
parser.printErrors();
|
||||
help();
|
||||
return 0;
|
||||
}
|
||||
printf("\nReading in %s...\n\n",filename.c_str());
|
||||
const double train_test_split_ratio = 0.5;
|
||||
|
||||
Ptr<TrainData> data = TrainData::loadFromCSV(filename, 0, response_idx, response_idx+1, typespec);
|
||||
if( data.empty() )
|
||||
{
|
||||
printf("ERROR: File %s can not be read\n", filename.c_str());
|
||||
return 0;
|
||||
}
|
||||
|
||||
data->setTrainTestSplitRatio(train_test_split_ratio);
|
||||
std::cout << "Test/Train: " << data->getNTestSamples() << "/" << data->getNTrainSamples();
|
||||
|
||||
printf("======DTREE=====\n");
|
||||
Ptr<DTrees> dtree = DTrees::create();
|
||||
dtree->setMaxDepth(10);
|
||||
dtree->setMinSampleCount(2);
|
||||
dtree->setRegressionAccuracy(0);
|
||||
dtree->setUseSurrogates(false);
|
||||
dtree->setMaxCategories(16);
|
||||
dtree->setCVFolds(0);
|
||||
dtree->setUse1SERule(false);
|
||||
dtree->setTruncatePrunedTree(false);
|
||||
dtree->setPriors(Mat());
|
||||
train_and_print_errs(dtree, data);
|
||||
|
||||
if( (int)data->getClassLabels().total() <= 2 ) // regression or 2-class classification problem
|
||||
{
|
||||
printf("======BOOST=====\n");
|
||||
Ptr<Boost> boost = Boost::create();
|
||||
boost->setBoostType(Boost::GENTLE);
|
||||
boost->setWeakCount(100);
|
||||
boost->setWeightTrimRate(0.95);
|
||||
boost->setMaxDepth(2);
|
||||
boost->setUseSurrogates(false);
|
||||
boost->setPriors(Mat());
|
||||
train_and_print_errs(boost, data);
|
||||
}
|
||||
|
||||
printf("======RTREES=====\n");
|
||||
Ptr<RTrees> rtrees = RTrees::create();
|
||||
rtrees->setMaxDepth(10);
|
||||
rtrees->setMinSampleCount(2);
|
||||
rtrees->setRegressionAccuracy(0);
|
||||
rtrees->setUseSurrogates(false);
|
||||
rtrees->setMaxCategories(16);
|
||||
rtrees->setPriors(Mat());
|
||||
rtrees->setCalculateVarImportance(true);
|
||||
rtrees->setActiveVarCount(0);
|
||||
rtrees->setTermCriteria(TermCriteria(TermCriteria::MAX_ITER, 100, 0));
|
||||
train_and_print_errs(rtrees, data);
|
||||
cv::Mat ref_labels = data->getClassLabels();
|
||||
cv::Mat test_data = data->getTestSampleIdx();
|
||||
cv::Mat predict_labels;
|
||||
rtrees->predict(data->getSamples(), predict_labels);
|
||||
|
||||
cv::Mat variable_importance = rtrees->getVarImportance();
|
||||
std::cout << "Estimated variable importance" << std::endl;
|
||||
for (int i = 0; i < variable_importance.rows; i++) {
|
||||
std::cout << "Variable " << i << ": " << variable_importance.at<float>(i, 0) << std::endl;
|
||||
}
|
||||
return 0;
|
||||
}
|
Reference in New Issue
Block a user