// test_precomp.hpp
// Diagnostic suppressions for GNU-compatible compilers.
#if defined(__GNUC__)
   // Always ignored when __GNUC__ is defined.
#  pragma GCC diagnostic ignored "-Wmissing-declarations"
   // Two further diagnostics are ignored when building with clang or on Apple platforms.
#  if defined(__clang__) || defined(__APPLE__)
#    pragma GCC diagnostic ignored "-Wmissing-prototypes"
#    pragma GCC diagnostic ignored "-Wextra"
#  endif
#endif

// Include guard: everything up to the closing #endif is processed at most
// once per translation unit.
#ifndef __OPENCV_TEST_PRECOMP_HPP__
#define __OPENCV_TEST_PRECOMP_HPP__

// OpenCV headers: test system, ML module, core C API.
#include "opencv2/ts/ts.hpp"
#include "opencv2/ml/ml.hpp"
#include "opencv2/core/core_c.h"
// Standard library headers.
#include <iostream>
#include <map>

// String identifiers for the ML models.
// NOTE(review): presumably these are the values passed as `_modelName` to the
// test-class constructors declared below - confirm against the test sources.
#define CV_NBAYES "nbayes"
#define CV_KNEAREST "knearest"
#define CV_SVM "svm"
#define CV_EM "em"
#define CV_ANN "ann"
#define CV_DTREE "dtree"
#define CV_BOOST "boost"
#define CV_RTREES "rtrees"
#define CV_ERTREES "ertrees"

class CV_MLBaseTest : public cvtest::BaseTest
{
public:
    CV_MLBaseTest( const char* _modelName );
    virtual ~CV_MLBaseTest();
protected:
    virtual int read_params( CvFileStorage* fs );
    virtual void run( int startFrom );
    virtual int prepare_test_case( int testCaseIdx );
    virtual std::string& get_validation_filename();
    virtual int run_test_case( int testCaseIdx ) = 0;
    virtual int validate_test_results( int testCaseIdx ) = 0;

    int train( int testCaseIdx );
wester committed
42
    float get_error( int testCaseIdx, int type, std::vector<float> *resp = 0 );
wester committed
43 44 45
    void save( const char* filename );
    void load( const char* filename );

wester committed
46
    CvMLData data;
wester committed
47 48 49 50
    std::string modelName, validationFN;
    std::vector<std::string> dataSetNames;
    cv::FileStorage validationFS;

wester committed
51 52 53 54 55 56 57 58 59
    // MLL models
    CvNormalBayesClassifier* nbayes;
    CvKNearest* knearest;
    CvSVM* svm;
    CvANN_MLP* ann;
    CvDTree* dtree;
    CvBoost* boost;
    CvRTrees* rtrees;
    CvERTrees* ertrees;
wester committed
60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87

    std::map<int, int> cls_map;

    int64 initSeed;
};

// Concrete ML test derived from CV_MLBaseTest.
// Supplies the two hooks that the base class declares pure virtual and adds
// no data members of its own.
class CV_AMLTest : public CV_MLBaseTest
{
public:
    // _modelName is the name of the model under test.
    CV_AMLTest(const char* _modelName);

protected:
    // Implementations of the pure-virtual hooks from CV_MLBaseTest.
    virtual int run_test_case(int testCaseIdx);
    virtual int validate_test_results(int testCaseIdx);
};

// Concrete ML test derived from CV_MLBaseTest.
// Supplies the two hooks that the base class declares pure virtual and keeps
// two response vectors plus two file names as its own state.
class CV_SLMLTest : public CV_MLBaseTest
{
public:
    // _modelName is the name of the model under test.
    CV_SLMLTest(const char* _modelName);

protected:
    // Implementations of the pure-virtual hooks from CV_MLBaseTest.
    virtual int run_test_case(int testCaseIdx);
    virtual int validate_test_results(int testCaseIdx);

    // Predicted responses for test data (two separate sets).
    std::vector<float> test_resps1;
    std::vector<float> test_resps2;

    // NOTE(review): presumably the file names paired with the two response
    // sets above - confirm in the implementation.
    std::string fname1;
    std::string fname2;
};

#endif // __OPENCV_TEST_PRECOMP_HPP__