#include "opencv2/ml/ml.hpp"
#include "opencv2/core/core_c.h"
#include <stdio.h>
#include <string.h>
#include <map>

static void help()
{
    printf(
        "\nThis sample demonstrates how to use different decision trees and forests including boosting and random trees:\n"
        "CvDTree dtree;\n"
        "CvBoost boost;\n"
        "CvRTrees rtrees;\n"
        "CvERTrees ertrees;\n"
        "CvGBTrees gbtrees;\n"
        "Call:\n\t./tree_engine [-r <response_column>] [-c] <csv filename>\n"
        "where -r <response_column> specified the 0-based index of the response (0 by default)\n"
        "-c specifies that the response is categorical (it's ordered by default) and\n"
        "<csv filename> is the name of training data file in comma-separated value format\n\n");
}


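// Counts the number of distinct class labels in the response column.
// Returns -1 if any response value is not an integer, i.e. the data looks like a regression problem.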
static int count_classes(CvMLData& data)
{
    cv::Mat r(data.get_responses());
    std::map<int, int> rmap;
    int i, n = (int)r.total();
    for( i = 0; i < n; i++ )
    {
        float val = r.at<float>(i);
        int ival = cvRound(val);
        if( ival != val )
            return -1;
        rmap[ival] = 1;
    }
    return (int)rmap.size();
}

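// Prints train/test errors and, when available, the variable importance values
// sorted in descending order of importance.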
static void print_result(float train_err, float test_err, const CvMat* _var_imp)
{
    printf( "train error    %f\n", train_err );
    printf( "test error    %f\n\n", test_err );

    if (_var_imp)
    {
        cv::Mat var_imp(_var_imp), sorted_idx;
        cv::sortIdx(var_imp, sorted_idx, CV_SORT_EVERY_ROW + CV_SORT_DESCENDING);

        printf( "variable importance:\n" );
        int i, n = (int)var_imp.total();
        int type = var_imp.type();
        CV_Assert(type == CV_32F || type == CV_64F);

        for( i = 0; i < n; i++)
        {
            int k = sorted_idx.at<int>(i);
            printf( "%d\t%f\n", k, type == CV_32F ? var_imp.at<float>(k) : var_imp.at<double>(k));
        }
    }
    printf("\n");
}

int main(int argc, char** argv)
{
    if(argc < 2)
    {
        help();
        return 0;
    }
    const char* filename = 0;
    int response_idx = 0;
    bool categorical_response = false;

    for(int i = 1; i < argc; i++)
    {
        if(strcmp(argv[i], "-r") == 0)
            sscanf(argv[++i], "%d", &response_idx);
        else if(strcmp(argv[i], "-c") == 0)
            categorical_response = true;
        else if(argv[i][0] != '-' )
            filename = argv[i];
        else
        {
            printf("Error. Invalid option %s\n", argv[i]);
            help();
            return -1;
        }
    }

    printf("\nReading in %s...\n\n",filename);
    CvDTree dtree;
    CvBoost boost;
    CvRTrees rtrees;
    CvERTrees ertrees;
    CvGBTrees gbtrees;

    CvMLData data;


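    // Split the loaded samples 50/50 into train and test subsets.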
    CvTrainTestSplit spl( 0.5f );

    if ( data.read_csv( filename ) == 0)
    {
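        // Tell CvMLData which column is the response and whether it should be treated as categorical.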
        data.set_response_idx( response_idx );
        if(categorical_response)
            data.change_var_type( response_idx, CV_VAR_CATEGORICAL );
        data.set_train_test_split( &spl );

        printf("======DTREE=====\n");
        dtree.train( &data, CvDTreeParams( 10, 2, 0, false, 16, 0, false, false, 0 ));
        print_result( dtree.calc_error( &data, CV_TRAIN_ERROR), dtree.calc_error( &data, CV_TEST_ERROR ), dtree.get_var_importance() );

        // CvBoost supports only two-class classification, so boosting is run
        // only when the response is categorical with exactly two classes.
        if( categorical_response && count_classes(data) == 2 )
        {
            printf("======BOOST=====\n");
            boost.train( &data, CvBoostParams(CvBoost::DISCRETE, 100, 0.95, 2, false, 0));
            print_result( boost.calc_error( &data, CV_TRAIN_ERROR ), boost.calc_error( &data, CV_TEST_ERROR ), 0 ); // boosting doesn't compute variable importance
        }

        printf("======RTREES=====\n");
        rtrees.train( &data, CvRTParams( 10, 2, 0, false, 16, 0, true, 0, 100, 0, CV_TERMCRIT_ITER ));
        print_result( rtrees.calc_error( &data, CV_TRAIN_ERROR), rtrees.calc_error( &data, CV_TEST_ERROR ), rtrees.get_var_importance() );

        printf("======ERTREES=====\n");
        ertrees.train( &data, CvRTParams( 18, 2, 0, false, 16, 0, true, 0, 100, 0, CV_TERMCRIT_ITER ));
        print_result( ertrees.calc_error( &data, CV_TRAIN_ERROR), ertrees.calc_error( &data, CV_TEST_ERROR ), ertrees.get_var_importance() );

        printf("======GBTREES=====\n");
        if (categorical_response)
            gbtrees.train( &data, CvGBTreesParams(CvGBTrees::DEVIANCE_LOSS, 100, 0.1f, 0.8f, 5, false));
        else
            gbtrees.train( &data, CvGBTreesParams(CvGBTrees::SQUARED_LOSS, 100, 0.1f, 0.8f, 5, false));
        print_result( gbtrees.calc_error( &data, CV_TRAIN_ERROR), gbtrees.calc_error( &data, CV_TEST_ERROR ), 0 ); //doesn't compute importance
    }
    else
        printf("File can not be read");

    return 0;
}