// train_svmsgd.cpp (last committed by wester)
// Interactive demo: click to add training points and watch the SVMSGD decision line update.
#include <opencv2/opencv.hpp>
#include "opencv2/video/tracking.hpp"
#include "opencv2/imgproc/imgproc.hpp"
#include "opencv2/highgui/highgui.hpp"

using namespace cv;
using namespace cv::ml;


struct Data
{
    Mat img;
    Mat samples;          //Set of train samples. Contains points on image
    Mat responses;        //Set of responses for train samples

    Data()
    {
        const int WIDTH = 841;
        const int HEIGHT = 594;
        img = Mat::zeros(HEIGHT, WIDTH, CV_8UC3);
        imshow("Train svmsgd", img);
    }
};

// Trains a new SVMSGD model on the train set (samples, responses), one sample per row.
// On success stores the learned weight vector in `weights` and the bias in `shift`
// (the decision function is w0*x + w1*y + shift) and returns true.
bool doTrain(const Mat samples, const Mat responses, Mat &weights, float &shift);

// Finds the two points where the decision line w0*x + w1*y + shift = 0 crosses the borders
// of a width x height canvas, so that the line can be drawn with cv::line.
bool findPointsForLine(const Mat &weights, float shift, Point points[], int width, int height);

// Finds the cross point of the decision line w0*x + w1*y + shift = 0 with a single
// axis-aligned segment (one vertical or horizontal edge of the canvas, as produced by
// fillSegments). Returns false when the line does not cross that segment.
bool findCrossPointWithBorders(const Mat &weights, float shift, const std::pair<Point,Point> &segment, Point &crossPoint);

// Appends the four border segments of a width x height canvas to `segments`:
// right (x = width), bottom (y = height), top (y = 0) and left (x = 0) edges, in that order.
void fillSegments(std::vector<std::pair<Point,Point> > &segments, int width, int height);

// Clears the canvas, redraws every training sample and the line points[0]-points[1],
// then shows the result in the "Train svmsgd" window.
void redraw(Data data, const Point points[2]);

// Adds the point (x, y) with label `response` to the train set, retrains SVMSGD
// and draws the results on the canvas.
void addPointRetrainAndRedraw(Data &data, int x, int y, int response);


bool doTrain( const Mat samples, const Mat responses, Mat &weights, float &shift)
{
    cv::Ptr<SVMSGD> svmsgd = SVMSGD::create();

    cv::Ptr<TrainData> trainData = TrainData::create(samples, cv::ml::ROW_SAMPLE, responses);
    svmsgd->train( trainData );

    if (svmsgd->isTrained())
    {
        weights = svmsgd->getWeights();
        shift = svmsgd->getShift();

        return true;
    }
    return false;
}

void fillSegments(std::vector<std::pair<Point,Point> > &segments, int width, int height)
{
    std::pair<Point,Point> currentSegment;

    currentSegment.first = Point(width, 0);
    currentSegment.second = Point(width, height);
    segments.push_back(currentSegment);

    currentSegment.first = Point(0, height);
    currentSegment.second = Point(width, height);
    segments.push_back(currentSegment);

    currentSegment.first = Point(0, 0);
    currentSegment.second = Point(width, 0);
    segments.push_back(currentSegment);

    currentSegment.first = Point(0, 0);
    currentSegment.second = Point(0, height);
    segments.push_back(currentSegment);
}


bool findCrossPointWithBorders(const Mat &weights, float shift, const std::pair<Point,Point> &segment, Point &crossPoint)
{
    int x = 0;
    int y = 0;
    int xMin = std::min(segment.first.x, segment.second.x);
    int xMax = std::max(segment.first.x, segment.second.x);
    int yMin = std::min(segment.first.y, segment.second.y);
    int yMax = std::max(segment.first.y, segment.second.y);

    CV_Assert(weights.type() == CV_32FC1);
    CV_Assert(xMin == xMax || yMin == yMax);

    if (xMin == xMax && weights.at<float>(1) != 0)
    {
        x = xMin;
        y = static_cast<int>(std::floor( - (weights.at<float>(0) * x + shift) / weights.at<float>(1)));
        if (y >= yMin && y <= yMax)
        {
            crossPoint.x = x;
            crossPoint.y = y;
            return true;
        }
    }
    else if (yMin == yMax && weights.at<float>(0) != 0)
    {
        y = yMin;
        x = static_cast<int>(std::floor( - (weights.at<float>(1) * y + shift) / weights.at<float>(0)));
        if (x >= xMin && x <= xMax)
        {
            crossPoint.x = x;
            crossPoint.y = y;
            return true;
        }
    }
    return false;
}

bool findPointsForLine(const Mat &weights, float shift, Point points[2], int width, int height)
{
    if (weights.empty())
    {
        return false;
    }

    int foundPointsCount = 0;
    std::vector<std::pair<Point,Point> > segments;
    fillSegments(segments, width, height);

    for (uint i = 0; i < segments.size(); i++)
    {
        if (findCrossPointWithBorders(weights, shift, segments[i], points[foundPointsCount]))
            foundPointsCount++;
        if (foundPointsCount >= 2)
            break;
    }

    return true;
}

// Clears the canvas, draws every training sample as a small thick circle whose colour
// depends on the sign of its label, draws the segment points[0]-points[1], and shows
// the result in the "Train svmsgd" window.
// `data` is taken by value, but cv::Mat copies share pixel buffers, so drawing on the
// copy's img updates the caller's canvas as well.
void redraw(Data data, const Point points[2])
{
    Mat &canvas = data.img;
    canvas.setTo(0);

    const Mat &samples = data.samples;
    const Mat &labels = data.responses;
    CV_Assert((samples.type() == CV_32FC1) && (labels.type() == CV_32FC1));

    const int markerRadius = 3;
    for (int row = 0; row < samples.rows; ++row)
    {
        const Point center(static_cast<int>(samples.at<float>(row, 0)),
                           static_cast<int>(samples.at<float>(row, 1)));

        Scalar markerColor;
        if (labels.at<float>(row) > 0)
            markerColor = Scalar(128,128,0);
        else
            markerColor = Scalar(0,128,128);

        circle(canvas, center, markerRadius, markerColor, 5);
    }

    line(canvas, points[0], points[1], cv::Scalar(1,255,1));

    imshow("Train svmsgd", canvas);
}

// Appends the clicked point (x, y) with class label `response` to the train set,
// retrains SVMSGD on the whole set and redraws the canvas.
//
// data     - demo state; samples and responses each grow by one row.
// x, y     - click position in image coordinates.
// response - label stored for the point.
//
// The canvas is redrawn even when training fails or the decision line does not cross
// the visible area, so the newly added point always shows up.
void addPointRetrainAndRedraw(Data &data, int x, int y, int response)
{
    Mat currentSample(1, 2, CV_32FC1);
    currentSample.at<float>(0,0) = static_cast<float>(x);
    currentSample.at<float>(0,1) = static_cast<float>(y);
    data.samples.push_back(currentSample);
    data.responses.push_back(static_cast<float>(response));

    // Default: both end points lie outside the image, so cv::line clips the segment
    // away and draws nothing. Used whenever there is no visible decision line.
    Point points[2] = { Point(-1, -1), Point(-1, -1) };

    Mat weights;        // filled by doTrain; no preallocation needed
    float shift = 0;
    if (doTrain(data.samples, data.responses, weights, shift))
    {
        Point linePoints[2];
        if (findPointsForLine(weights, shift, linePoints, data.img.cols, data.img.rows))
        {
            points[0] = linePoints[0];
            points[1] = linePoints[1];
        }
    }

    redraw(data, points);
}


static void onMouse( int event, int x, int y, int, void* pData)
{
    Data &data = *(Data*)pData;

    switch( event )
    {
    case CV_EVENT_LBUTTONUP:
        addPointRetrainAndRedraw(data, x, y, 1);
        break;

    case CV_EVENT_RBUTTONDOWN:
        addPointRetrainAndRedraw(data, x, y, -1);
        break;
    }

}

int main()
{
    Data data;

    setMouseCallback( "Train svmsgd", onMouse, &data );
    waitKey();

    return 0;
}