/*M///////////////////////////////////////////////////////////////////////////////////////
//
//  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
//
//  By downloading, copying, installing or using the software you agree to this license.
//  If you do not agree to this license, do not download, install,
//  copy or use the software.
//
//
//                        Intel License Agreement
//
// Copyright (C) 2000, Intel Corporation, all rights reserved.
// Third party copyrights are property of their respective owners.
//
// Redistribution and use in source and binary forms, with or without modification,
// are permitted provided that the following conditions are met:
//
//   * Redistribution's of source code must retain the above copyright notice,
//     this list of conditions and the following disclaimer.
//
//   * Redistribution's in binary form must reproduce the above copyright notice,
//     this list of conditions and the following disclaimer in the documentation
//     and/or other materials provided with the distribution.
//
//   * The name of Intel Corporation may not be used to endorse or promote products
//     derived from this software without specific prior written permission.
//
// This software is provided by the copyright holders and contributors "as is" and
// any express or implied warranties, including, but not limited to, the implied
// warranties of merchantability and fitness for a particular purpose are disclaimed.
// In no event shall the Intel Corporation or contributors be liable for any direct,
// indirect, incidental, special, exemplary, or consequential damages
// (including, but not limited to, procurement of substitute goods or services;
// loss of use, data, or profits; or business interruption) however caused
// and on any theory of liability, whether in contract, strict liability,
// or tort (including negligence or otherwise) arising in any way out of
// the use of this software, even if advised of the possibility of such damage.
//
//M*/

#include "precomp.hpp"

/****************************************************************************************\
                                COPYRIGHT NOTICE
                                ----------------

  The code has been derived from libsvm library (version 2.6)
  (http://www.csie.ntu.edu.tw/~cjlin/libsvm).

  Here is the original copyright:
------------------------------------------------------------------------------------------
    Copyright (c) 2000-2003 Chih-Chung Chang and Chih-Jen Lin
    All rights reserved.

    Redistribution and use in source and binary forms, with or without
    modification, are permitted provided that the following conditions
    are met:

    1. Redistributions of source code must retain the above copyright
    notice, this list of conditions and the following disclaimer.

    2. Redistributions in binary form must reproduce the above copyright
    notice, this list of conditions and the following disclaimer in the
    documentation and/or other materials provided with the distribution.

    3. Neither name of copyright holders nor the names of its contributors
    may be used to endorse or promote products derived from this software
    without specific prior written permission.


    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE REGENTS OR
    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
\****************************************************************************************/

wester committed
84 85 86
using namespace cv;

#define CV_SVM_MIN_CACHE_SIZE  (40 << 20)  /* 40Mb */
wester committed
87

wester committed
88 89 90 91
#include <stdarg.h>
#include <ctype.h>

#if 1
wester committed
92
typedef float Qfloat;
wester committed
93 94 95 96 97
#define QFLOAT_TYPE CV_32F
#else
typedef double Qfloat;
#define QFLOAT_TYPE CV_64F
#endif
wester committed
98 99

// Param Grid
wester committed
100
bool CvParamGrid::check() const
wester committed
101
{
wester committed
102
    bool ok = false;
wester committed
103

wester committed
104 105
    CV_FUNCNAME( "CvParamGrid::check" );
    __BEGIN__;
wester committed
106

wester committed
107 108 109 110 111 112
    if( min_val > max_val )
        CV_ERROR( CV_StsBadArg, "Lower bound of the grid must be less then the upper one" );
    if( min_val < DBL_EPSILON )
        CV_ERROR( CV_StsBadArg, "Lower bound of the grid must be positive" );
    if( step < 1. + FLT_EPSILON )
        CV_ERROR( CV_StsBadArg, "Grid step must greater then 1" );
wester committed
113

wester committed
114 115 116 117 118 119 120 121
    ok = true;

    __END__;

    return ok;
}

// Returns the built-in search grid for one tunable SVM parameter.
//   param_id - one of CvSVM::C, GAMMA, P, NU, COEF, DEGREE.
// For any other id cvError(CV_StsBadArg, ...) is raised and the
// default-constructed grid is returned.
CvParamGrid CvSVM::get_default_grid( int param_id )
{
    // One row per supported parameter: { id, min_val, max_val, step }.
    const struct { int id; double lo, hi, factor; } presets[] =
    {
        { CvSVM::C,      0.1,  500, 5  },
        { CvSVM::GAMMA,  1e-5, 0.6, 15 },
        { CvSVM::P,      0.01, 100, 7  },
        { CvSVM::NU,     0.01, 0.2, 3  },
        { CvSVM::COEF,   0.1,  300, 14 },
        { CvSVM::DEGREE, 0.01, 4,   7  }
    };
    const int preset_count = (int)(sizeof(presets)/sizeof(presets[0]));

    CvParamGrid grid;

    // Linear search; the first matching row wins.
    int n = 0;
    while( n < preset_count && presets[n].id != param_id )
        n++;

    if( n < preset_count )
    {
        grid.min_val = presets[n].lo;
        grid.max_val = presets[n].hi;
        grid.step = presets[n].factor;
    }
    else
    {
        cvError( CV_StsBadArg, "CvSVM::get_default_grid", "Invalid type of parameter "
            "(use one of CvSVM::C, CvSVM::GAMMA et al.)", __FILE__, __LINE__ );
    }

    return grid;
}

// SVM training parameters
// Default parameters: C_SVC with an RBF kernel, gamma = 1, C = 1, every other
// numeric parameter zero, no class weights, and a termination criterion of
// at most 1000 iterations or an accuracy of FLT_EPSILON.
CvSVMParams::CvSVMParams()
    : svm_type(CvSVM::C_SVC),
      kernel_type(CvSVM::RBF),
      degree(0),
      gamma(1),
      coef0(0),
      C(1),
      nu(0),
      p(0),
      class_weights(0),
      term_crit(cvTermCriteria( CV_TERMCRIT_ITER+CV_TERMCRIT_EPS, 1000, FLT_EPSILON ))
{
}
wester committed
173 174


wester committed
175 176 177 178 179 180 181 182 183
// Stores every argument verbatim in the member of the same name
// (_Con goes to C). The class_weights pointer is copied, not the matrix.
CvSVMParams::CvSVMParams( int _svm_type, int _kernel_type,
    double _degree, double _gamma, double _coef0,
    double _Con, double _nu, double _p,
    CvMat* _class_weights, CvTermCriteria _term_crit )
    : svm_type(_svm_type),
      kernel_type(_kernel_type),
      degree(_degree),
      gamma(_gamma),
      coef0(_coef0),
      C(_Con),
      nu(_nu),
      p(_p),
      class_weights(_class_weights),
      term_crit(_term_crit)
{
}
wester committed
184 185


wester committed
186
/////////////////////////////////////// SVM kernel ///////////////////////////////////////
wester committed
187

wester committed
188 189 190 191
// Creates an unbound kernel: no parameters and no evaluation routine yet.
CvSVMKernel::CvSVMKernel()
{
    // Same state that clear() establishes.
    params = 0;
    calc_func = 0;
}
wester committed
192 193


wester committed
194 195 196 197 198
// Detaches the kernel from its parameters and evaluation routine.
// Only the two pointers are reset; nothing is freed here.
void CvSVMKernel::clear()
{
    calc_func = 0;
    params = 0;
}
wester committed
199 200


wester committed
201 202 203
// Intentionally empty: the destructor performs no cleanup.
CvSVMKernel::~CvSVMKernel()
{
}
wester committed
204

wester committed
205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234

// Creates a kernel bound to _params; see create() for how _calc_func is used.
CvSVMKernel::CvSVMKernel( const CvSVMParams* _params, Calc _calc_func )
{
    // create() begins by calling clear(), so both members are reset
    // before they are assigned.
    create( _params, _calc_func );
}


// Binds the kernel to a parameter set and chooses the evaluation routine.
//   _params    - stored as-is; dereferenced here only when _calc_func is null.
//   _calc_func - custom evaluation routine; when null, one is chosen from
//                _params->kernel_type (RBF, POLY, SIGMOID, anything else -> linear).
// Always returns true.
bool CvSVMKernel::create( const CvSVMParams* _params, Calc _calc_func )
{
    clear();
    params = _params;
    calc_func = _calc_func;

    if( calc_func )
        return true;

    const int type = params->kernel_type;

    if( type == CvSVM::RBF )
        calc_func = &CvSVMKernel::calc_rbf;
    else if( type == CvSVM::POLY )
        calc_func = &CvSVMKernel::calc_poly;
    else if( type == CvSVM::SIGMOID )
        calc_func = &CvSVMKernel::calc_sigmoid;
    else
        calc_func = &CvSVMKernel::calc_linear;

    return true;
}


void CvSVMKernel::calc_non_rbf_base( int vcount, int var_count, const float** vecs,
                                     const float* another, Qfloat* results,
                                     double alpha, double beta )
{
    int j, k;
    for( j = 0; j < vcount; j++ )
wester committed
235
    {
wester committed
236 237 238 239 240 241 242 243
        const float* sample = vecs[j];
        double s = 0;
        for( k = 0; k <= var_count - 4; k += 4 )
            s += sample[k]*another[k] + sample[k+1]*another[k+1] +
                 sample[k+2]*another[k+2] + sample[k+3]*another[k+3];
        for( ; k < var_count; k++ )
            s += sample[k]*another[k];
        results[j] = (Qfloat)(s*alpha + beta);
wester committed
244
    }
wester committed
245 246 247 248 249 250 251 252
}


// Linear kernel: results[j] = <vecs[j], another>
// (the shared base routine with scale 1 and offset 0).
void CvSVMKernel::calc_linear( int vcount, int var_count, const float** vecs,
                               const float* another, Qfloat* results )
{
    const double scale = 1, offset = 0;
    calc_non_rbf_base( vcount, var_count, vecs, another, results, scale, offset );
}
wester committed
253

wester committed
254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272

void CvSVMKernel::calc_poly( int vcount, int var_count, const float** vecs,
                             const float* another, Qfloat* results )
{
    CvMat R = cvMat( 1, vcount, QFLOAT_TYPE, results );
    calc_non_rbf_base( vcount, var_count, vecs, another, results, params->gamma, params->coef0 );
    if( vcount > 0 )
        cvPow( &R, &R, params->degree );
}


void CvSVMKernel::calc_sigmoid( int vcount, int var_count, const float** vecs,
                                const float* another, Qfloat* results )
{
    int j;
    calc_non_rbf_base( vcount, var_count, vecs, another, results,
                       -2*params->gamma, -2*params->coef0 );
    // TODO: speedup this
    for( j = 0; j < vcount; j++ )
wester committed
273
    {
wester committed
274 275 276 277 278 279
        Qfloat t = results[j];
        double e = exp(-fabs(t));
        if( t > 0 )
            results[j] = (Qfloat)((1. - e)/(1. + e));
        else
            results[j] = (Qfloat)((e - 1.)/(e + 1.));
wester committed
280
    }
wester committed
281 282 283 284 285 286 287 288 289
}


void CvSVMKernel::calc_rbf( int vcount, int var_count, const float** vecs,
                            const float* another, Qfloat* results )
{
    CvMat R = cvMat( 1, vcount, QFLOAT_TYPE, results );
    double gamma = -params->gamma;
    int j, k;
wester committed
290

wester committed
291
    for( j = 0; j < vcount; j++ )
wester committed
292
    {
wester committed
293 294 295 296
        const float* sample = vecs[j];
        double s = 0;

        for( k = 0; k <= var_count - 4; k += 4 )
wester committed
297
        {
wester committed
298 299 300 301 302 303 304 305 306
            double t0 = sample[k] - another[k];
            double t1 = sample[k+1] - another[k+1];

            s += t0*t0 + t1*t1;

            t0 = sample[k+2] - another[k+2];
            t1 = sample[k+3] - another[k+3];

            s += t0*t0 + t1*t1;
wester committed
307
        }
wester committed
308 309

        for( ; k < var_count; k++ )
wester committed
310
        {
wester committed
311 312
            double t0 = sample[k] - another[k];
            s += t0*t0;
wester committed
313
        }
wester committed
314
        results[j] = (Qfloat)(s*gamma);
wester committed
315 316
    }

wester committed
317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370
    if( vcount > 0 )
        cvExp( &R, &R );
}


void CvSVMKernel::calc( int vcount, int var_count, const float** vecs,
                        const float* another, Qfloat* results )
{
    const Qfloat max_val = (Qfloat)(FLT_MAX*1e-3);
    int j;
    (this->*calc_func)( vcount, var_count, vecs, another, results );
    for( j = 0; j < vcount; j++ )
    {
        if( results[j] > max_val )
            results[j] = max_val;
    }
}


// Generalized SMO+SVMlight algorithm
// Solves:
//
//  min [0.5(\alpha^T Q \alpha) + b^T \alpha]
//
//      y^T \alpha = \delta
//      y_i = +1 or -1
//      0 <= alpha_i <= Cp for y_i = 1
//      0 <= alpha_i <= Cn for y_i = -1
//
// Given:
//
//  Q, b, y, Cp, Cn, and an initial feasible point \alpha
//  l is the size of vectors and matrices
//  eps is the stopping criterion
//
// solution will be put in \alpha, objective value will be put in obj
//

// Returns the solver to its unbound state: releases the memory storage via
// cvReleaseMemStorage and nulls the pointers and callbacks listed below.
// sample_count, var_count, alpha_count, C, eps and max_iter are left as they were.
void CvSVMSolver::clear()
{
    cvReleaseMemStorage( &storage );

    // problem description and work arrays
    samples = 0;
    y = 0;
    alpha = 0;
    b = 0;
    G = 0;
    buf[0] = buf[1] = 0;
    rows = 0;

    // kernel and strategy callbacks
    kernel = 0;
    get_row_func = 0;
    select_working_set_func = 0;
    calc_rho_func = 0;
}
wester committed
371 372


wester committed
373 374 375 376 377
// Creates an unbound solver.
CvSVMSolver::CvSVMSolver()
{
    // clear() hands `storage` to cvReleaseMemStorage, so it must hold a
    // defined value before the call.
    storage = 0;
    clear();
}
wester committed
378 379


wester committed
380
// Releases the solver's memory storage (see clear()).
CvSVMSolver::~CvSVMSolver()
{
    clear();
}

wester committed
385

wester committed
386 387 388 389 390 391 392 393 394
// Creates a solver and immediately binds it to a problem by forwarding every
// argument to create(). The result of create() is not reported to the caller.
CvSVMSolver::CvSVMSolver( int _sample_count, int _var_count, const float** _samples, schar* _y,
                int _alpha_count, double* _alpha, double _Cp, double _Cn,
                CvMemStorage* _storage, CvSVMKernel* _kernel, GetRow _get_row,
                SelectWorkingSet _select_working_set, CalcRho _calc_rho )
{
    // create() starts with clear(), which passes `storage` to
    // cvReleaseMemStorage; give it a defined value first.
    storage = 0;

    create( _sample_count, _var_count, _samples, _y,
            _alpha_count, _alpha, _Cp, _Cn,
            _storage, _kernel,
            _get_row, _select_working_set, _calc_rho );
}
wester committed
395 396


wester committed
397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470
// Binds the solver to one optimization problem and allocates its work buffers.
//
//   _sample_count, _var_count, _samples - number of training vectors, their
//                   length, and the array of pointers to them (stored, not copied)
//   _y            - label array (stored, not copied); get_C() tests its sign
//   _alpha_count  - number of dual variables
//   _alpha        - caller-owned array of _alpha_count dual variables
//   _Cp, _Cn      - upper bounds, stored as C[1] and C[0] respectively
//   _storage      - parent storage; everything here is allocated from a child
//                   storage created from it
//   _kernel       - kernel object; its params supply term_crit and svm_type
//   _get_row, _select_working_set, _calc_rho - strategy callbacks; a null
//                   value selects a default according to the SVM type
//
// Returns true on success; `ok` stays false if CV_ERROR leaves the block early.
bool CvSVMSolver::create( int _sample_count, int _var_count, const float** _samples, schar* _y,
                int _alpha_count, double* _alpha, double _Cp, double _Cn,
                CvMemStorage* _storage, CvSVMKernel* _kernel, GetRow _get_row,
                SelectWorkingSet _select_working_set, CalcRho _calc_rho )
{
    bool ok = false;
    int i, svm_type;

    CV_FUNCNAME( "CvSVMSolver::create" );

    __BEGIN__;

    int rows_hdr_size;

    clear();

    sample_count = _sample_count;
    var_count = _var_count;
    samples = _samples;
    y = _y;
    alpha_count = _alpha_count;
    alpha = _alpha;
    kernel = _kernel;

    C[0] = _Cn;
    C[1] = _Cp;
    eps = kernel->params->term_crit.epsilon;
    max_iter = kernel->params->term_crit.max_iter;
    storage = cvCreateChildMemStorage( _storage );

    b = (double*)cvMemStorageAlloc( storage, alpha_count*sizeof(b[0]));
    alpha_status = (schar*)cvMemStorageAlloc( storage, alpha_count*sizeof(alpha_status[0]));
    G = (double*)cvMemStorageAlloc( storage, alpha_count*sizeof(G[0]));
    // two scratch rows, each twice sample_count long (get_row_svr writes both halves)
    for( i = 0; i < 2; i++ )
        buf[i] = (Qfloat*)cvMemStorageAlloc( storage, sample_count*2*sizeof(buf[i][0]) );
    svm_type = kernel->params->svm_type;

    select_working_set_func = _select_working_set;
    if( !select_working_set_func )
        select_working_set_func = svm_type == CvSVM::NU_SVC || svm_type == CvSVM::NU_SVR ?
        &CvSVMSolver::select_working_set_nu_svm : &CvSVMSolver::select_working_set;

    calc_rho_func = _calc_rho;
    if( !calc_rho_func )
        calc_rho_func = svm_type == CvSVM::NU_SVC || svm_type == CvSVM::NU_SVR ?
            &CvSVMSolver::calc_rho_nu_svm : &CvSVMSolver::calc_rho;

    // Use the SVM type taken from the kernel's parameters, exactly as the two
    // defaults above do. The solver's own `params` member is not assigned
    // anywhere in this function or in clear(), so it must not be read here.
    get_row_func = _get_row;
    if( !get_row_func )
        get_row_func = svm_type == CvSVM::EPS_SVR ||
                       svm_type == CvSVM::NU_SVR ? &CvSVMSolver::get_row_svr :
                       svm_type == CvSVM::C_SVC ||
                       svm_type == CvSVM::NU_SVC ? &CvSVMSolver::get_row_svc :
                       &CvSVMSolver::get_row_one_class;

    cache_line_size = sample_count*sizeof(Qfloat);
    // cache size = max(num_of_samples^2*sizeof(Qfloat)*0.25, CV_SVM_MIN_CACHE_SIZE)
    // (assuming that for large training sets ~25% of Q matrix is used)
    // NOTE(review): the product below may overflow for very large sample_count
    // if these members are plain int - confirm against the class declaration.
    cache_size = MAX( cache_line_size*sample_count/4, CV_SVM_MIN_CACHE_SIZE );

    // the size of Q matrix row headers
    rows_hdr_size = sample_count*sizeof(rows[0]);
    if( rows_hdr_size > storage->block_size )
        CV_ERROR( CV_StsOutOfRange, "Too small storage block size" );

    // empty LRU list: the sentinel points at itself
    lru_list.prev = lru_list.next = &lru_list;
    rows = (CvSVMKernelRow*)cvMemStorageAlloc( storage, rows_hdr_size );
    memset( rows, 0, rows_hdr_size );

    ok = true;

    __END__;

    return ok;
}


wester committed
474
float* CvSVMSolver::get_row_base( int i, bool* _existed )
wester committed
475
{
wester committed
476 477 478 479 480 481
    int i1 = i < sample_count ? i : i - sample_count;
    CvSVMKernelRow* row = rows + i1;
    bool existed = row->data != 0;
    Qfloat* data;

    if( existed || cache_size <= 0 )
wester committed
482
    {
wester committed
483 484 485 486 487 488 489 490
        CvSVMKernelRow* del_row = existed ? row : lru_list.prev;
        data = del_row->data;
        assert( data != 0 );

        // delete row from the LRU list
        del_row->data = 0;
        del_row->prev->next = del_row->next;
        del_row->next->prev = del_row->prev;
wester committed
491
    }
wester committed
492
    else
wester committed
493
    {
wester committed
494 495
        data = (Qfloat*)cvMemStorageAlloc( storage, cache_line_size );
        cache_size -= cache_line_size;
wester committed
496
    }
wester committed
497 498 499 500 501 502 503 504

    // insert row into the LRU list
    row->data = data;
    row->prev = &lru_list;
    row->next = lru_list.next;
    row->prev->next = row->next->prev = row;

    if( !existed )
wester committed
505
    {
wester committed
506
        kernel->calc( sample_count, var_count, samples, samples[i1], row->data );
wester committed
507
    }
wester committed
508 509 510 511 512 513 514 515 516 517 518

    if( _existed )
        *_existed = existed;

    return row->data;
}


float* CvSVMSolver::get_row_svc( int i, float* row, float*, bool existed )
{
    if( !existed )
wester committed
519
    {
wester committed
520 521 522 523 524 525 526 527 528 529 530 531 532 533
        const schar* _y = y;
        int j, len = sample_count;
        assert( _y && i < sample_count );

        if( _y[i] > 0 )
        {
            for( j = 0; j < len; j++ )
                row[j] = _y[j]*row[j];
        }
        else
        {
            for( j = 0; j < len; j++ )
                row[j] = -_y[j]*row[j];
        }
wester committed
534
    }
wester committed
535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550
    return row;
}


// One-class variant: the kernel row is used as is, so the input pointer is
// returned unchanged and all other arguments are ignored.
float* CvSVMSolver::get_row_one_class( int, float* row, float*, bool )
{
    float* const unchanged = row;
    return unchanged;
}


float* CvSVMSolver::get_row_svr( int i, float* row, float* dst, bool )
{
    int j, len = sample_count;
    Qfloat* dst_pos = dst;
    Qfloat* dst_neg = dst + len;
    if( i >= len )
wester committed
551
    {
wester committed
552 553
        Qfloat* temp;
        CV_SWAP( dst_pos, dst_neg, temp );
wester committed
554
    }
wester committed
555 556

    for( j = 0; j < len; j++ )
wester committed
557
    {
wester committed
558 559 560
        Qfloat t = row[j];
        dst_pos[j] = t;
        dst_neg[j] = -t;
wester committed
561
    }
wester committed
562
    return dst;
wester committed
563 564 565
}


wester committed
566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611

// Fetches row i of the Q matrix: obtains the (possibly cached) kernel row and
// passes it through the formulation-specific get_row_func together with dst
// and the "was already cached" flag.
float* CvSVMSolver::get_row( int i, float* dst )
{
    bool was_cached = false;
    float* const kernel_row = get_row_base( i, &was_cached );

    return (this->*get_row_func)( i, kernel_row, dst, was_cached );
}


// Helper macros for the solver methods below. They expand in place and rely on
// the members alpha, alpha_status, y and C being in scope.

// alpha_status[i] > 0: alpha[i] has reached its upper bound get_C(i)
#undef is_upper_bound
#define is_upper_bound(i) (alpha_status[i] > 0)

// alpha_status[i] < 0: alpha[i] is at (or below) zero
#undef is_lower_bound
#define is_lower_bound(i) (alpha_status[i] < 0)

// alpha_status[i] == 0: alpha[i] lies strictly between the two bounds
#undef is_free
#define is_free(i) (alpha_status[i] == 0)

// upper bound for alpha[i]: C[1] when y[i] > 0, C[0] otherwise
#undef get_C
#define get_C(i) (C[y[i]>0])

// recomputes alpha_status[i] from alpha[i]: 1 = at upper bound, -1 = at lower bound, 0 = free
#undef update_alpha_status
#define update_alpha_status(i) \
    alpha_status[i] = (schar)(alpha[i] >= get_C(i) ? 1 : alpha[i] <= 0 ? -1 : 0)

// placeholder that expands to nothing
#undef reconstruct_gradient
#define reconstruct_gradient() /* empty for now */


// Runs the pairwise optimization loop on the problem prepared by create().
// The caller must have filled alpha (feasible start) and b beforehand.
//   si - receives rho and r (from calc_rho_func), the objective value
//        0.5 * sum(alpha[i]*(G[i] + b[i])), and the two upper bounds.
// Returns false when a gradient entry starts out larger than 1e200 in magnitude
// (or, in _DEBUG builds, when G or alpha blow up during the iterations);
// true otherwise, including when the loop stops because max_iter was reached.
bool CvSVMSolver::solve_generic( CvSVMSolutionInfo& si )
{
    int iter = 0;
    int i, j, k;

    // 1. classify every alpha and start the gradient from the linear term
    for( i = 0; i < alpha_count; i++ )
    {
        update_alpha_status(i);
        G[i] = b[i];
        if( fabs(G[i]) > 1e200 )
            return false;
    }

    //    add alpha_i * Q_i for every alpha that is not at its lower bound
    for( i = 0; i < alpha_count; i++ )
    {
        if( is_lower_bound(i) )
            continue;

        const Qfloat* Q_row = get_row( i, buf[0] );
        const double weight = alpha[i];

        for( j = 0; j < alpha_count; j++ )
            G[j] += weight*Q_row[j];
    }

    // 2. optimization loop
    for(;;)
    {
#ifdef _DEBUG
        for( i = 0; i < alpha_count; i++ )
        {
            if( fabs(G[i]) > 1e+300 || fabs(alpha[i]) > 1e16 )
                return false;
        }
#endif

        // pick the next pair; stop when optimal or out of iterations
        if( (this->*select_working_set_func)( i, j ) != 0 || iter++ >= max_iter )
            break;

        const Qfloat* Q_i = get_row( i, buf[0] );
        const Qfloat* Q_j = get_row( j, buf[1] );

        const double C_i = get_C(i);
        const double C_j = get_C(j);

        const double old_alpha_i = alpha[i];
        const double old_alpha_j = alpha[j];
        double alpha_i = old_alpha_i;
        double alpha_j = old_alpha_j;

        if( y[i] == y[j] )
        {
            // same label: alpha_i + alpha_j stays constant
            double denom = Q_i[i]+Q_j[j]-2*Q_i[j];
            double step = (G[i]-G[j])/MAX(fabs(denom),FLT_EPSILON);
            const double sum = alpha_i + alpha_j;
            alpha_i -= step;
            alpha_j += step;

            // clip the pair back into the box [0,C_i] x [0,C_j]
            if( sum > C_i && alpha_i > C_i )
            {
                alpha_i = C_i;
                alpha_j = sum - C_i;
            }
            else if( sum <= C_i && alpha_j < 0)
            {
                alpha_j = 0;
                alpha_i = sum;
            }

            if( sum > C_j && alpha_j > C_j )
            {
                alpha_j = C_j;
                alpha_i = sum - C_j;
            }
            else if( sum <= C_j && alpha_i < 0 )
            {
                alpha_i = 0;
                alpha_j = sum;
            }
        }
        else
        {
            // different labels: alpha_i - alpha_j stays constant
            double denom = Q_i[i]+Q_j[j]+2*Q_i[j];
            double step = (-G[i]-G[j])/MAX(fabs(denom),FLT_EPSILON);
            const double diff = alpha_i - alpha_j;
            alpha_i += step;
            alpha_j += step;

            // clip the pair back into the box [0,C_i] x [0,C_j]
            if( diff > 0 && alpha_j < 0 )
            {
                alpha_j = 0;
                alpha_i = diff;
            }
            else if( diff <= 0 && alpha_i < 0 )
            {
                alpha_i = 0;
                alpha_j = -diff;
            }

            if( diff > C_i - C_j && alpha_i > C_i )
            {
                alpha_i = C_i;
                alpha_j = C_i - diff;
            }
            else if( diff <= C_i - C_j && alpha_j > C_j )
            {
                alpha_j = C_j;
                alpha_i = C_j + diff;
            }
        }

        // store the new pair and refresh its bound status
        alpha[i] = alpha_i;
        alpha[j] = alpha_j;
        update_alpha_status(i);
        update_alpha_status(j);

        // propagate the change of the two alphas into the gradient
        const double delta_alpha_i = alpha_i - old_alpha_i;
        const double delta_alpha_j = alpha_j - old_alpha_j;

        for( k = 0; k < alpha_count; k++ )
            G[k] += Q_i[k]*delta_alpha_i + Q_j[k]*delta_alpha_j;
    }

    // calculate rho
    (this->*calc_rho_func)( si.rho, si.r );

    // calculate objective value
    si.obj = 0;
    for( i = 0; i < alpha_count; i++ )
        si.obj += alpha[i] * (G[i] + b[i]);
    si.obj *= 0.5;

    si.upper_bound_p = C[1];
    si.upper_bound_n = C[0];

    return true;
}


// return 1 if already optimal, return 0 otherwise
bool
CvSVMSolver::select_working_set( int& out_i, int& out_j )
{
    // return i,j which maximize -grad(f)^T d , under constraint
    // if alpha_i == C, d != +1
    // if alpha_i == 0, d != -1
    double Gmax1 = -DBL_MAX;        // max { -grad(f)_i * d | y_i*d = +1 }
    int Gmax1_idx = -1;

    double Gmax2 = -DBL_MAX;        // max { -grad(f)_i * d | y_i*d = -1 }
    int Gmax2_idx = -1;

    int i;
wester committed
757

wester committed
758 759 760
    for( i = 0; i < alpha_count; i++ )
    {
        double t;
wester committed
761

wester committed
762 763 764
        if( y[i] > 0 )    // y = +1
        {
            if( !is_upper_bound(i) && (t = -G[i]) > Gmax1 )  // d = +1
wester committed
765
            {
wester committed
766 767
                Gmax1 = t;
                Gmax1_idx = i;
wester committed
768
            }
wester committed
769
            if( !is_lower_bound(i) && (t = G[i]) > Gmax2 )  // d = -1
wester committed
770
            {
wester committed
771 772
                Gmax2 = t;
                Gmax2_idx = i;
wester committed
773
            }
wester committed
774 775 776 777
        }
        else        // y = -1
        {
            if( !is_upper_bound(i) && (t = -G[i]) > Gmax2 )  // d = +1
wester committed
778
            {
wester committed
779 780 781 782 783 784 785 786 787 788
                Gmax2 = t;
                Gmax2_idx = i;
            }
            if( !is_lower_bound(i) && (t = G[i]) > Gmax1 )  // d = -1
            {
                Gmax1 = t;
                Gmax1_idx = i;
            }
        }
    }
wester committed
789

wester committed
790 791
    out_i = Gmax1_idx;
    out_j = Gmax2_idx;
wester committed
792

wester committed
793 794
    return Gmax1 + Gmax2 < eps;
}
wester committed
795 796


wester committed
797 798 799 800 801
void
CvSVMSolver::calc_rho( double& rho, double& r )
{
    int i, nr_free = 0;
    double ub = DBL_MAX, lb = -DBL_MAX, sum_free = 0;
wester committed
802

wester committed
803 804 805
    for( i = 0; i < alpha_count; i++ )
    {
        double yG = y[i]*G[i];
wester committed
806

wester committed
807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826
        if( is_lower_bound(i) )
        {
            if( y[i] > 0 )
                ub = MIN(ub,yG);
            else
                lb = MAX(lb,yG);
        }
        else if( is_upper_bound(i) )
        {
            if( y[i] < 0)
                ub = MIN(ub,yG);
            else
                lb = MAX(lb,yG);
        }
        else
        {
            ++nr_free;
            sum_free += yG;
        }
    }
wester committed
827

wester committed
828 829 830
    rho = nr_free > 0 ? sum_free/nr_free : (ub + lb)*0.5;
    r = 0;
}
wester committed
831 832


wester committed
833 834 835 836 837 838 839 840 841 842 843 844 845 846
bool
CvSVMSolver::select_working_set_nu_svm( int& out_i, int& out_j )
{
    // return i,j which maximize -grad(f)^T d , under constraint
    // if alpha_i == C, d != +1
    // if alpha_i == 0, d != -1
    double Gmax1 = -DBL_MAX;    // max { -grad(f)_i * d | y_i = +1, d = +1 }
    int Gmax1_idx = -1;

    double Gmax2 = -DBL_MAX;    // max { -grad(f)_i * d | y_i = +1, d = -1 }
    int Gmax2_idx = -1;

    double Gmax3 = -DBL_MAX;    // max { -grad(f)_i * d | y_i = -1, d = +1 }
    int Gmax3_idx = -1;
wester committed
847

wester committed
848 849
    double Gmax4 = -DBL_MAX;    // max { -grad(f)_i * d | y_i = -1, d = -1 }
    int Gmax4_idx = -1;
wester committed
850

wester committed
851
    int i;
wester committed
852

wester committed
853 854 855 856 857 858 859 860 861 862 863 864 865 866 867
    for( i = 0; i < alpha_count; i++ )
    {
        double t;

        if( y[i] > 0 )    // y == +1
        {
            if( !is_upper_bound(i) && (t = -G[i]) > Gmax1 )  // d = +1
            {
                Gmax1 = t;
                Gmax1_idx = i;
            }
            if( !is_lower_bound(i) && (t = G[i]) > Gmax2 )  // d = -1
            {
                Gmax2 = t;
                Gmax2_idx = i;
wester committed
868
            }
wester committed
869 870 871 872 873 874 875 876 877 878 879 880 881 882 883
        }
        else        // y == -1
        {
            if( !is_upper_bound(i) && (t = -G[i]) > Gmax3 )  // d = +1
            {
                Gmax3 = t;
                Gmax3_idx = i;
            }
            if( !is_lower_bound(i) && (t = G[i]) > Gmax4 )  // d = -1
            {
                Gmax4 = t;
                Gmax4_idx = i;
            }
        }
    }
wester committed
884

wester committed
885 886
    if( MAX(Gmax1 + Gmax2, Gmax3 + Gmax4) < eps )
        return 1;
wester committed
887

wester committed
888 889 890 891 892 893 894 895 896 897 898 899
    if( Gmax1 + Gmax2 > Gmax3 + Gmax4 )
    {
        out_i = Gmax1_idx;
        out_j = Gmax2_idx;
    }
    else
    {
        out_i = Gmax3_idx;
        out_j = Gmax4_idx;
    }
    return 0;
}
wester committed
900 901


wester committed
902 903 904 905 906 907 908 909 910 911
void
CvSVMSolver::calc_rho_nu_svm( double& rho, double& r )
{
    int nr_free1 = 0, nr_free2 = 0;
    double ub1 = DBL_MAX, ub2 = DBL_MAX;
    double lb1 = -DBL_MAX, lb2 = -DBL_MAX;
    double sum_free1 = 0, sum_free2 = 0;
    double r1, r2;

    int i;
wester committed
912

wester committed
913 914 915 916 917 918 919 920 921 922 923 924 925 926
    for( i = 0; i < alpha_count; i++ )
    {
        double G_i = G[i];
        if( y[i] > 0 )
        {
            if( is_lower_bound(i) )
                ub1 = MIN( ub1, G_i );
            else if( is_upper_bound(i) )
                lb1 = MAX( lb1, G_i );
            else
            {
                ++nr_free1;
                sum_free1 += G_i;
            }
wester committed
927
        }
wester committed
928 929 930 931 932 933 934 935 936 937 938 939 940 941 942 943 944 945 946 947
        else
        {
            if( is_lower_bound(i) )
                ub2 = MIN( ub2, G_i );
            else if( is_upper_bound(i) )
                lb2 = MAX( lb2, G_i );
            else
            {
                ++nr_free2;
                sum_free2 += G_i;
            }
        }
    }

    r1 = nr_free1 > 0 ? sum_free1/nr_free1 : (ub1 + lb1)*0.5;
    r2 = nr_free2 > 0 ? sum_free2/nr_free2 : (ub2 + lb2)*0.5;

    rho = (r1 - r2)*0.5;
    r = (r1 + r2)*0.5;
}
wester committed
948

wester committed
949 950 951 952 953 954 955 956 957 958 959 960 961 962 963 964 965 966 967 968 969 970 971 972 973 974 975 976 977 978 979 980 981 982 983 984 985 986 987 988 989 990 991 992 993 994 995 996 997 998

/*
///////////////////////// construct and solve various formulations ///////////////////////
*/

bool CvSVMSolver::solve_c_svc( int _sample_count, int _var_count, const float** _samples, schar* _y,
                               double _Cp, double _Cn, CvMemStorage* _storage,
                               CvSVMKernel* _kernel, double* _alpha, CvSVMSolutionInfo& _si )
{
    int i;

    if( !create( _sample_count, _var_count, _samples, _y, _sample_count,
                 _alpha, _Cp, _Cn, _storage, _kernel, &CvSVMSolver::get_row_svc,
                 &CvSVMSolver::select_working_set, &CvSVMSolver::calc_rho ))
        return false;

    for( i = 0; i < sample_count; i++ )
    {
        alpha[i] = 0;
        b[i] = -1;
    }

    if( !solve_generic( _si ))
        return false;

    for( i = 0; i < sample_count; i++ )
        alpha[i] *= y[i];

    return true;
}


// nu-SVC formulation.
//
// Configures the solver with unit bounds (rows from get_row_svc, working set
// from select_working_set_nu_svm, offsets from calc_rho_nu_svm). Each of the
// two label groups gets an initial alpha mass of nu*sample_count/2, handed out
// greedily in sample order with at most 1.0 per sample; b is zero.
//
// After a successful solve, alpha, rho and the objective are rescaled by 1/r
// (the objective by 1/r^2) and both upper bounds in _si are set to 1/r.
//
// Returns false when either create() or solve_generic() fails.
bool CvSVMSolver::solve_nu_svc( int _sample_count, int _var_count, const float** _samples, schar* _y,
                                CvMemStorage* _storage, CvSVMKernel* _kernel,
                                double* _alpha, CvSVMSolutionInfo& _si )
{
    if( !create( _sample_count, _var_count, _samples, _y, _sample_count,
                 _alpha, 1., 1., _storage, _kernel, &CvSVMSolver::get_row_svc,
                 &CvSVMSolver::select_working_set_nu_svm, &CvSVMSolver::calc_rho_nu_svm ))
        return false;

    // remaining alpha mass still to be distributed in each label group
    double budget_pos = kernel->params->nu * sample_count * 0.5;
    double budget_neg = kernel->params->nu * sample_count * 0.5;

    for( int k = 0; k < sample_count; k++ )
    {
        double& budget = y[k] > 0 ? budget_pos : budget_neg;

        alpha[k] = MIN(1.0, budget);
        budget -= alpha[k];
        b[k] = 0;
    }

    if( !solve_generic( _si ))
        return false;

    const double inv_r = 1./_si.r;

    for( int k = 0; k < sample_count; k++ )
        alpha[k] *= y[k]*inv_r;

    _si.rho *= inv_r;
    _si.obj *= (inv_r*inv_r);
    _si.upper_bound_p = inv_r;
    _si.upper_bound_n = inv_r;

    return true;
}


// One-class formulation.
//
// Configures the solver with unit bounds (rows from get_row_one_class, working
// set from select_working_set, offset from calc_rho), allocates the label array
// and fills it with +1, sets b to zero and builds an initial alpha whose sum is
// nu*sample_count: the first n entries are 1 and one further entry holds the
// fractional remainder.
//
// n is obtained by truncation (as in libsvm) rather than by rounding. Rounding
// may yield n > nu*sample_count, which made the remainder entry negative and
// started the optimization from a point outside the box 0 <= alpha <= 1.
//
// Returns the result of solve_generic(), or false when create() fails.
bool CvSVMSolver::solve_one_class( int _sample_count, int _var_count, const float** _samples,
                                   CvMemStorage* _storage, CvSVMKernel* _kernel,
                                   double* _alpha, CvSVMSolutionInfo& _si )
{
    int i, n;
    double nu = _kernel->params->nu;

    if( !create( _sample_count, _var_count, _samples, 0, _sample_count,
                 _alpha, 1., 1., _storage, _kernel, &CvSVMSolver::get_row_one_class,
                 &CvSVMSolver::select_working_set, &CvSVMSolver::calc_rho ))
        return false;

    y = (schar*)cvMemStorageAlloc( storage, sample_count*sizeof(y[0]) );

    // number of samples that start with a full unit of alpha;
    // truncation keeps the remainder nu*sample_count - n non-negative
    n = (int)(nu*sample_count);

    for( i = 0; i < sample_count; i++ )
    {
        y[i] = 1;
        b[i] = 0;
        alpha[i] = i < n ? 1 : 0;
    }

    if( n < sample_count )
        alpha[n] = nu * sample_count - n;
    else if( n > 0 ) // never index alpha[-1] on an empty sample set
        alpha[n-1] = nu * sample_count - (n-1);

    return solve_generic(_si);
}


// epsilon-SVR formulation.
//
// The problem is solved over 2*sample_count variables, both bounds being the
// kernel parameter C. Sample i owns variable i (label +1, b = p - _y[i]) and
// variable i + sample_count (label -1, b = p + _y[i]); all alphas start at 0.
// The label and alpha arrays of the doubled problem are allocated here.
//
// On success _alpha[i] receives alpha[i] - alpha[i + sample_count].
// Returns false when either create() or solve_generic() fails.
bool CvSVMSolver::solve_eps_svr( int _sample_count, int _var_count, const float** _samples,
                                 const float* _y, CvMemStorage* _storage,
                                 CvSVMKernel* _kernel, double* _alpha, CvSVMSolutionInfo& _si )
{
    const double p = _kernel->params->p;
    const double kernel_param_c = _kernel->params->C;

    if( !create( _sample_count, _var_count, _samples, 0,
                 _sample_count*2, 0, kernel_param_c, kernel_param_c, _storage, _kernel, &CvSVMSolver::get_row_svr,
                 &CvSVMSolver::select_working_set, &CvSVMSolver::calc_rho ))
        return false;

    y = (schar*)cvMemStorageAlloc( storage, sample_count*2*sizeof(y[0]) );
    alpha = (double*)cvMemStorageAlloc( storage, alpha_count*sizeof(alpha[0]) );

    for( int k = 0; k < sample_count; k++ )
    {
        const int mirror = k + sample_count;

        // first half of the doubled problem
        y[k] = 1;
        b[k] = p - _y[k];
        alpha[k] = 0;

        // second half of the doubled problem
        y[mirror] = -1;
        b[mirror] = p + _y[k];
        alpha[mirror] = 0;
    }

    if( !solve_generic( _si ))
        return false;

    // collapse the two halves into one coefficient per sample
    for( int k = 0; k < sample_count; k++ )
        _alpha[k] = alpha[k] - alpha[k+sample_count];

    return true;
}


bool CvSVMSolver::solve_nu_svr( int _sample_count, int _var_count, const float** _samples,
                                const float* _y, CvMemStorage* _storage,
                                CvSVMKernel* _kernel, double* _alpha, CvSVMSolutionInfo& _si )
{
    int i;
    double kernel_param_c = _kernel->params->C, sum;

    if( !create( _sample_count, _var_count, _samples, 0,
                 _sample_count*2, 0, 1., 1., _storage, _kernel, &CvSVMSolver::get_row_svr,
                 &CvSVMSolver::select_working_set_nu_svm, &CvSVMSolver::calc_rho_nu_svm ))
        return false;

    y = (schar*)cvMemStorageAlloc( storage, sample_count*2*sizeof(y[0]) );
    alpha = (double*)cvMemStorageAlloc( storage, alpha_count*sizeof(alpha[0]) );
    sum = kernel_param_c * _kernel->params->nu * sample_count * 0.5;

    for( i = 0; i < sample_count; i++ )
    {
        alpha[i] = alpha[i + sample_count] = MIN(sum, kernel_param_c);
        sum -= alpha[i];

        b[i] = -_y[i];
        y[i] = 1;

        b[i + sample_count] = _y[i];
        y[i + sample_count] = -1;
    }

    if( !solve_generic( _si ))
        return false;

    for( i = 0; i < sample_count; i++ )
        _alpha[i] = alpha[i] - alpha[i+sample_count];

    return true;
}


//////////////////////////////////////////////////////////////////////////////////////////

// Creates an empty model: every owned pointer is nulled first so that the
// clear() call below, which releases them, operates on well-defined values.
CvSVM::CvSVM()
{
    default_model_name = "my_svm";

    // helper objects
    solver = 0;
    kernel = 0;

    // model data
    storage = 0;
    var_idx = 0;
    class_weights = 0;
    class_labels = 0;
    decision_func = 0;

    clear();
}


// Releases everything the model owns by delegating to clear().
CvSVM::~CvSVM()
{
    this->clear();
}


// Resets the model to the empty state: frees the decision functions, the
// label / weight / variable-index matrices and the memory storage, deletes
// the kernel and the solver, and zeroes the support-vector bookkeeping.
void CvSVM::clear()
{
    cvFree( &decision_func );
    cvReleaseMat( &class_labels );
    cvReleaseMat( &class_weights );
    cvReleaseMemStorage( &storage );
    cvReleaseMat( &var_idx );

    delete kernel;
    kernel = 0;

    delete solver;
    solver = 0;

    sv_total = 0;
    sv = 0;
    var_all = 0;
}


// Creates a model and immediately trains it on the given data. The owned
// pointers are nulled first; the result of train() is not reported to the
// caller.
CvSVM::CvSVM( const CvMat* _train_data, const CvMat* _responses,
    const CvMat* _var_idx, const CvMat* _sample_idx, CvSVMParams _params )
{
    default_model_name = "my_svm";

    // helper objects
    solver = 0;
    kernel = 0;

    // model data
    storage = 0;
    var_idx = 0;
    class_weights = 0;
    class_labels = 0;
    decision_func = 0;

    train( _train_data, _responses, _var_idx, _sample_idx, _params );
}


// Returns the number of support vectors currently stored in the model.
int CvSVM::get_support_vector_count() const
{
    const int count = sv_total;
    return count;
}


// Returns the i-th stored support vector, or 0 when no support vectors are
// stored or i is outside [0, sv_total).
const float* CvSVM::get_support_vector(int i) const
{
    if( !sv )
        return 0;

    // the unsigned comparison also rejects negative indices
    if( (unsigned)i >= (unsigned)sv_total )
        return 0;

    return sv[i];
}


// Copies _params into the model and normalizes / validates the copy.
//
// Parameters that the selected kernel or SVM type does not use are reset
// (gamma to 1 for the linear kernel, the others to 0); parameters that are
// used are range-checked and an error is raised through CV_ERROR when a check
// fails. The termination criteria are passed through cvCheckTermCriteria and
// epsilon is kept at or above DBL_EPSILON.
//
// Returns true when all checks passed.
bool CvSVM::set_params( const CvSVMParams& _params )
{
    bool ok = false;

    CV_FUNCNAME( "CvSVM::set_params" );

    __BEGIN__;

    params = _params;

    const int kernel_type = params.kernel_type;
    const int svm_type = params.svm_type;

    const bool kernel_known = kernel_type == LINEAR || kernel_type == POLY ||
                              kernel_type == SIGMOID || kernel_type == RBF;
    const bool svm_known = svm_type == C_SVC || svm_type == NU_SVC ||
                           svm_type == ONE_CLASS || svm_type == EPS_SVR ||
                           svm_type == NU_SVR;

    // ---- kernel parameters ----
    if( !kernel_known )
        CV_ERROR( CV_StsBadArg, "Unknown/unsupported kernel type" );

    if( kernel_type != LINEAR )
    {
        if( params.gamma <= 0 )
            CV_ERROR( CV_StsOutOfRange, "gamma parameter of the kernel must be positive" );
    }
    else
        params.gamma = 1;

    if( kernel_type == SIGMOID || kernel_type == POLY )
    {
        if( params.coef0 < 0 )
            CV_ERROR( CV_StsOutOfRange, "The kernel parameter <coef0> must be positive or zero" );
    }
    else
        params.coef0 = 0;

    if( kernel_type == POLY )
    {
        if( params.degree <= 0 )
            CV_ERROR( CV_StsOutOfRange, "The kernel parameter <degree> must be positive" );
    }
    else
        params.degree = 0;

    // ---- formulation parameters ----
    if( !svm_known )
        CV_ERROR( CV_StsBadArg, "Unknown/unsupported SVM type" );

    if( svm_type != ONE_CLASS && svm_type != NU_SVC )
    {
        if( params.C <= 0 )
            CV_ERROR( CV_StsOutOfRange, "The parameter C must be positive" );
    }
    else
        params.C = 0;

    if( svm_type != C_SVC && svm_type != EPS_SVR )
    {
        if( params.nu <= 0 || params.nu >= 1 )
            CV_ERROR( CV_StsOutOfRange, "The parameter nu must be between 0 and 1" );
    }
    else
        params.nu = 0;

    if( svm_type == EPS_SVR )
    {
        if( params.p <= 0 )
            CV_ERROR( CV_StsOutOfRange, "The parameter p must be positive" );
    }
    else
        params.p = 0;

    // class weights are only kept for C_SVC
    if( svm_type != C_SVC )
        params.class_weights = 0;

    // ---- termination criteria ----
    params.term_crit = cvCheckTermCriteria( params.term_crit, DBL_EPSILON, INT_MAX );
    params.term_crit.epsilon = MAX( params.term_crit.epsilon, DBL_EPSILON );
    ok = true;

    __END__;

    return ok;
}



void CvSVM::create_kernel()
{
    kernel = new CvSVMKernel(&params,0);
}


void CvSVM::create_solver( )
{
    solver = new CvSVMSolver;
}


// switching function
bool CvSVM::train1( int sample_count, int var_count, const float** samples,
                    const void* _responses, double Cp, double Cn,
                    CvMemStorage* _storage, double* alpha, double& rho )
{
    bool ok = false;

    //CV_FUNCNAME( "CvSVM::train1" );

    __BEGIN__;

    CvSVMSolutionInfo si;
    int svm_type = params.svm_type;
wester committed
1294

wester committed
1295
    si.rho = 0;
wester committed
1296

wester committed
1297 1298 1299 1300 1301 1302 1303 1304 1305 1306
    ok = svm_type == C_SVC ? solver->solve_c_svc( sample_count, var_count, samples, (schar*)_responses,
                                                  Cp, Cn, _storage, kernel, alpha, si ) :
         svm_type == NU_SVC ? solver->solve_nu_svc( sample_count, var_count, samples, (schar*)_responses,
                                                    _storage, kernel, alpha, si ) :
         svm_type == ONE_CLASS ? solver->solve_one_class( sample_count, var_count, samples,
                                                          _storage, kernel, alpha, si ) :
         svm_type == EPS_SVR ? solver->solve_eps_svr( sample_count, var_count, samples, (float*)_responses,
                                                      _storage, kernel, alpha, si ) :
         svm_type == NU_SVR ? solver->solve_nu_svr( sample_count, var_count, samples, (float*)_responses,
                                                    _storage, kernel, alpha, si ) : false;
wester committed
1307

wester committed
1308 1309 1310 1311 1312 1313 1314 1315 1316 1317 1318 1319 1320 1321 1322 1323 1324 1325 1326 1327 1328 1329 1330 1331 1332 1333 1334 1335 1336 1337 1338 1339 1340 1341 1342 1343 1344 1345 1346 1347 1348 1349 1350 1351 1352 1353 1354
    rho = si.rho;

    __END__;

    return ok;
}


bool CvSVM::do_train( int svm_type, int sample_count, int var_count, const float** samples,
                    const CvMat* responses, CvMemStorage* temp_storage, double* alpha )
{
    bool ok = false;

    CV_FUNCNAME( "CvSVM::do_train" );

    __BEGIN__;

    CvSVMDecisionFunc* df = 0;
    const int sample_size = var_count*sizeof(samples[0][0]);
    int i, j, k;

    cvClearMemStorage( storage );

    if( svm_type == ONE_CLASS || svm_type == EPS_SVR || svm_type == NU_SVR )
    {
        int sv_count = 0;

        CV_CALL( decision_func = df =
            (CvSVMDecisionFunc*)cvAlloc( sizeof(df[0]) ));

        df->rho = 0;
        if( !train1( sample_count, var_count, samples, svm_type == ONE_CLASS ? 0 :
            responses->data.i, 0, 0, temp_storage, alpha, df->rho ))
            EXIT;

        for( i = 0; i < sample_count; i++ )
            sv_count += fabs(alpha[i]) > 0;

        CV_Assert(sv_count != 0);

        sv_total = df->sv_count = sv_count;
        CV_CALL( df->alpha = (double*)cvMemStorageAlloc( storage, sv_count*sizeof(df->alpha[0])) );
        CV_CALL( sv = (float**)cvMemStorageAlloc( storage, sv_count*sizeof(sv[0])));

        for( i = k = 0; i < sample_count; i++ )
        {
            if( fabs(alpha[i]) > 0 )
wester committed
1355
            {
wester committed
1356 1357 1358 1359 1360 1361 1362 1363 1364 1365 1366 1367 1368 1369 1370 1371 1372 1373 1374 1375 1376 1377 1378 1379 1380 1381 1382 1383 1384
                CV_CALL( sv[k] = (float*)cvMemStorageAlloc( storage, sample_size ));
                memcpy( sv[k], samples[i], sample_size );
                df->alpha[k++] = alpha[i];
            }
        }
    }
    else
    {
        int class_count = class_labels->cols;
        int* sv_tab = 0;
        const float** temp_samples = 0;
        int* class_ranges = 0;
        schar* temp_y = 0;
        assert( svm_type == CvSVM::C_SVC || svm_type == CvSVM::NU_SVC );

        if( svm_type == CvSVM::C_SVC && params.class_weights )
        {
            const CvMat* cw = params.class_weights;

            if( !CV_IS_MAT(cw) || (cw->cols != 1 && cw->rows != 1) ||
                cw->rows + cw->cols - 1 != class_count ||
                (CV_MAT_TYPE(cw->type) != CV_32FC1 && CV_MAT_TYPE(cw->type) != CV_64FC1) )
                CV_ERROR( CV_StsBadArg, "params.class_weights must be 1d floating-point vector "
                    "containing as many elements as the number of classes" );

            CV_CALL( class_weights = cvCreateMat( cw->rows, cw->cols, CV_64F ));
            CV_CALL( cvConvert( cw, class_weights ));
            CV_CALL( cvScale( class_weights, class_weights, params.C ));
        }
wester committed
1385

wester committed
1386 1387 1388 1389 1390 1391 1392 1393 1394 1395 1396 1397 1398 1399 1400 1401 1402 1403 1404 1405 1406 1407 1408 1409 1410
        CV_CALL( decision_func = df = (CvSVMDecisionFunc*)cvAlloc(
            (class_count*(class_count-1)/2)*sizeof(df[0])));

        CV_CALL( sv_tab = (int*)cvMemStorageAlloc( temp_storage, sample_count*sizeof(sv_tab[0]) ));
        memset( sv_tab, 0, sample_count*sizeof(sv_tab[0]) );
        CV_CALL( class_ranges = (int*)cvMemStorageAlloc( temp_storage,
                            (class_count + 1)*sizeof(class_ranges[0])));
        CV_CALL( temp_samples = (const float**)cvMemStorageAlloc( temp_storage,
                            sample_count*sizeof(temp_samples[0])));
        CV_CALL( temp_y = (schar*)cvMemStorageAlloc( temp_storage, sample_count));

        class_ranges[class_count] = 0;
        cvSortSamplesByClasses( samples, responses, class_ranges, 0 );
        //check that while cross-validation there were the samples from all the classes
        if( class_ranges[class_count] <= 0 )
            CV_ERROR( CV_StsBadArg, "While cross-validation one or more of the classes have "
            "been fell out of the sample. Try to enlarge <CvSVMParams::k_fold>" );

        if( svm_type == NU_SVC )
        {
            // check if nu is feasible
            for(i = 0; i < class_count; i++ )
            {
                int ci = class_ranges[i+1] - class_ranges[i];
                for( j = i+1; j< class_count; j++ )
wester committed
1411
                {
wester committed
1412 1413
                    int cj = class_ranges[j+1] - class_ranges[j];
                    if( params.nu*(ci + cj)*0.5 > MIN( ci, cj ) )
wester committed
1414
                    {
wester committed
1415 1416
                        // !!!TODO!!! add some diagnostic
                        EXIT; // exit immediately; will release the model and return NULL pointer
wester committed
1417 1418 1419 1420 1421
                    }
                }
            }
        }

wester committed
1422 1423
        // train n*(n-1)/2 classifiers
        for( i = 0; i < class_count; i++ )
wester committed
1424
        {
wester committed
1425
            for( j = i+1; j < class_count; j++, df++ )
wester committed
1426
            {
wester committed
1427 1428 1429 1430
                int si = class_ranges[i], ci = class_ranges[i+1] - si;
                int sj = class_ranges[j], cj = class_ranges[j+1] - sj;
                double Cp = params.C, Cn = Cp;
                int k1 = 0, sv_count = 0;
wester committed
1431

wester committed
1432
                for( k = 0; k < ci; k++ )
wester committed
1433
                {
wester committed
1434 1435
                    temp_samples[k] = samples[si + k];
                    temp_y[k] = 1;
wester committed
1436
                }
wester committed
1437 1438

                for( k = 0; k < cj; k++ )
wester committed
1439
                {
wester committed
1440 1441
                    temp_samples[ci + k] = samples[sj + k];
                    temp_y[ci + k] = -1;
wester committed
1442
                }
wester committed
1443 1444

                if( class_weights )
wester committed
1445
                {
wester committed
1446 1447
                    Cp = class_weights->data.db[i];
                    Cn = class_weights->data.db[j];
wester committed
1448 1449
                }

wester committed
1450 1451 1452
                if( !train1( ci + cj, var_count, temp_samples, temp_y,
                             Cp, Cn, temp_storage, alpha, df->rho ))
                    EXIT;
wester committed
1453

wester committed
1454 1455
                for( k = 0; k < ci + cj; k++ )
                    sv_count += fabs(alpha[k]) > 0;
wester committed
1456

wester committed
1457
                df->sv_count = sv_count;
wester committed
1458

wester committed
1459 1460 1461 1462
                CV_CALL( df->alpha = (double*)cvMemStorageAlloc( temp_storage,
                                                sv_count*sizeof(df->alpha[0])));
                CV_CALL( df->sv_index = (int*)cvMemStorageAlloc( temp_storage,
                                                sv_count*sizeof(df->sv_index[0])));
wester committed
1463

wester committed
1464
                for( k = 0; k < ci; k++ )
wester committed
1465
                {
wester committed
1466
                    if( fabs(alpha[k]) > 0 )
wester committed
1467
                    {
wester committed
1468 1469 1470
                        sv_tab[si + k] = 1;
                        df->sv_index[k1] = si + k;
                        df->alpha[k1++] = alpha[k];
wester committed
1471 1472
                    }
                }
wester committed
1473 1474

                for( k = 0; k < cj; k++ )
wester committed
1475
                {
wester committed
1476
                    if( fabs(alpha[ci + k]) > 0 )
wester committed
1477
                    {
wester committed
1478 1479 1480
                        sv_tab[sj + k] = 1;
                        df->sv_index[k1] = sj + k;
                        df->alpha[k1++] = alpha[ci + k];
wester committed
1481 1482 1483
                    }
                }
            }
wester committed
1484
        }
wester committed
1485

wester committed
1486 1487 1488 1489 1490 1491
        // allocate support vectors and initialize sv_tab
        for( i = 0, k = 0; i < sample_count; i++ )
        {
            if( sv_tab[i] )
                sv_tab[i] = ++k;
        }
wester committed
1492

wester committed
1493 1494 1495 1496 1497 1498
        sv_total = k;
        CV_CALL( sv = (float**)cvMemStorageAlloc( storage, sv_total*sizeof(sv[0])));

        for( i = 0, k = 0; i < sample_count; i++ )
        {
            if( sv_tab[i] )
wester committed
1499
            {
wester committed
1500 1501 1502
                CV_CALL( sv[k] = (float*)cvMemStorageAlloc( storage, sample_size ));
                memcpy( sv[k], samples[i], sample_size );
                k++;
wester committed
1503 1504 1505
            }
        }

wester committed
1506
        df = (CvSVMDecisionFunc*)decision_func;
wester committed
1507

wester committed
1508 1509 1510 1511
        // set sv pointers
        for( i = 0; i < class_count; i++ )
        {
            for( j = i+1; j < class_count; j++, df++ )
wester committed
1512
            {
wester committed
1513
                for( k = 0; k < df->sv_count; k++ )
wester committed
1514
                {
wester committed
1515 1516
                    df->sv_index[k] = sv_tab[df->sv_index[k]]-1;
                    assert( (unsigned)df->sv_index[k] < (unsigned)sv_total );
wester committed
1517 1518
                }
            }
wester committed
1519 1520
        }
    }
wester committed
1521

wester committed
1522 1523
    optimize_linear_svm();
    ok = true;
wester committed
1524

wester committed
1525 1526 1527 1528
    __END__;

    return ok;
}
wester committed
1529 1530


wester committed
1531 1532 1533 1534 1535
void CvSVM::optimize_linear_svm()
{
    // we optimize only linear SVM: compress all the support vectors into one.
    if( params.kernel_type != LINEAR )
        return;
wester committed
1536

wester committed
1537 1538
    int class_count = class_labels ? class_labels->cols :
            params.svm_type == CvSVM::ONE_CLASS ? 1 : 0;
wester committed
1539

wester committed
1540 1541
    int i, df_count = class_count > 1 ? class_count*(class_count-1)/2 : 1;
    CvSVMDecisionFunc* df = decision_func;
wester committed
1542

wester committed
1543 1544 1545 1546 1547 1548
    for( i = 0; i < df_count; i++ )
    {
        int sv_count = df[i].sv_count;
        if( sv_count != 1 )
            break;
    }
wester committed
1549

wester committed
1550 1551 1552 1553
    // if every decision functions uses a single support vector;
    // it's already compressed. skip it then.
    if( i == df_count )
        return;
wester committed
1554

wester committed
1555 1556 1557 1558
    int var_count = get_var_count();
    cv::AutoBuffer<double> vbuf(var_count);
    double* v = vbuf;
    float** new_sv = (float**)cvMemStorageAlloc(storage, df_count*sizeof(new_sv[0]));
wester committed
1559

wester committed
1560 1561 1562 1563 1564 1565 1566
    for( i = 0; i < df_count; i++ )
    {
        new_sv[i] = (float*)cvMemStorageAlloc(storage, var_count*sizeof(new_sv[i][0]));
        float* dst = new_sv[i];
        memset(v, 0, var_count*sizeof(v[0]));
        int j, k, sv_count = df[i].sv_count;
        for( j = 0; j < sv_count; j++ )
wester committed
1567
        {
wester committed
1568 1569 1570 1571 1572 1573 1574 1575 1576 1577 1578 1579
            const float* src = class_count > 1 && df[i].sv_index ? sv[df[i].sv_index[j]] : sv[j];
            double a = df[i].alpha[j];
            for( k = 0; k < var_count; k++ )
                v[k] += src[k]*a;
        }
        for( k = 0; k < var_count; k++ )
            dst[k] = (float)v[k];
        df[i].sv_count = 1;
        df[i].alpha[0] = 1.;
        if( class_count > 1 && df[i].sv_index )
            df[i].sv_index[0] = i;
    }
wester committed
1580

wester committed
1581 1582 1583
    sv = new_sv;
    sv_total = df_count;
}
wester committed
1584 1585


wester committed
1586 1587 1588 1589 1590 1591 1592
bool CvSVM::train( const CvMat* _train_data, const CvMat* _responses,
    const CvMat* _var_idx, const CvMat* _sample_idx, CvSVMParams _params )
{
    bool ok = false;
    CvMat* responses = 0;
    CvMemStorage* temp_storage = 0;
    const float** samples = 0;
wester committed
1593

wester committed
1594
    CV_FUNCNAME( "CvSVM::train" );
wester committed
1595

wester committed
1596
    __BEGIN__;
wester committed
1597

wester committed
1598 1599 1600
    int svm_type, sample_count, var_count, sample_size;
    int block_size = 1 << 16;
    double* alpha;
wester committed
1601

wester committed
1602 1603
    clear();
    CV_CALL( set_params( _params ));
wester committed
1604

wester committed
1605
    svm_type = _params.svm_type;
wester committed
1606

wester committed
1607 1608 1609 1610 1611 1612 1613 1614
    /* Prepare training data and related parameters */
    CV_CALL( cvPrepareTrainData( "CvSVM::train", _train_data, CV_ROW_SAMPLE,
                                 svm_type != CvSVM::ONE_CLASS ? _responses : 0,
                                 svm_type == CvSVM::C_SVC ||
                                 svm_type == CvSVM::NU_SVC ? CV_VAR_CATEGORICAL :
                                 CV_VAR_ORDERED, _var_idx, _sample_idx,
                                 false, &samples, &sample_count, &var_count, &var_all,
                                 &responses, &class_labels, &var_idx ));
wester committed
1615 1616


wester committed
1617
    sample_size = var_count*sizeof(samples[0][0]);
wester committed
1618

wester committed
1619 1620 1621 1622 1623
    // make the storage block size large enough to fit all
    // the temporary vectors and output support vectors.
    block_size = MAX( block_size, sample_count*(int)sizeof(CvSVMKernelRow));
    block_size = MAX( block_size, sample_count*2*(int)sizeof(double) + 1024 );
    block_size = MAX( block_size, sample_size*2 + 1024 );
wester committed
1624

wester committed
1625 1626 1627
    CV_CALL( storage = cvCreateMemStorage(block_size + sizeof(CvMemBlock) + sizeof(CvSeqBlock)));
    CV_CALL( temp_storage = cvCreateChildMemStorage(storage));
    CV_CALL( alpha = (double*)cvMemStorageAlloc(temp_storage, sample_count*sizeof(double)));
wester committed
1628

wester committed
1629 1630
    create_kernel();
    create_solver();
wester committed
1631

wester committed
1632 1633
    if( !do_train( svm_type, sample_count, var_count, samples, responses, temp_storage, alpha ))
        EXIT;
wester committed
1634

wester committed
1635
    ok = true; // model has been trained successfully
wester committed
1636

wester committed
1637
    __END__;
wester committed
1638

wester committed
1639 1640 1641 1642 1643
    delete solver;
    solver = 0;
    cvReleaseMemStorage( &temp_storage );
    cvReleaseMat( &responses );
    cvFree( &samples );
wester committed
1644

wester committed
1645 1646
    if( cvGetErrStatus() < 0 || !ok )
        clear();
wester committed
1647

wester committed
1648 1649
    return ok;
}
wester committed
1650

wester committed
1651 1652 1653 1654 1655 1656 1657
// An index paired with the share of one counter in the total of two counters.
struct indexedratio
{
    double val;                         // ratio written by eval()
    int ind;                            // index carried along with the ratio
    int count_smallest, count_biggest;  // the two counters the ratio is built from

    // Stores count_smallest / (count_smallest + count_biggest) in val;
    // the division is carried out in floating point.
    void eval()
    {
        const int total = count_smallest + count_biggest;
        val = static_cast<double>(count_smallest)/total;
    }
};
wester committed
1658

wester committed
1659 1660 1661 1662 1663 1664 1665
// Three-way comparison of two indexedratio records by their val field:
// returns -1 when a's ratio is smaller, 1 when it is larger and 0 otherwise.
static int CV_CDECL
icvCmpIndexedratio( const void* a, const void* b )
{
    const double lhs = ((const indexedratio*)a)->val;
    const double rhs = ((const indexedratio*)b)->val;

    if( lhs < rhs )
        return -1;
    if( lhs > rhs )
        return 1;
    return 0;
}
wester committed
1666

wester committed
1667 1668 1669 1670 1671 1672 1673 1674 1675 1676 1677 1678 1679 1680 1681 1682 1683 1684 1685 1686 1687 1688 1689 1690 1691 1692 1693 1694 1695 1696 1697 1698 1699
bool CvSVM::train_auto( const CvMat* _train_data, const CvMat* _responses,
    const CvMat* _var_idx, const CvMat* _sample_idx, CvSVMParams _params, int k_fold,
    CvParamGrid C_grid, CvParamGrid gamma_grid, CvParamGrid p_grid,
    CvParamGrid nu_grid, CvParamGrid coef_grid, CvParamGrid degree_grid,
    bool balanced)
{
    bool ok = false;
    CvMat* responses = 0;
    CvMat* responses_local = 0;
    CvMemStorage* temp_storage = 0;
    const float** samples = 0;
    const float** samples_local = 0;

    CV_FUNCNAME( "CvSVM::train_auto" );
    __BEGIN__;

    int svm_type, sample_count, var_count, sample_size;
    int block_size = 1 << 16;
    double* alpha;
    RNG* rng = &theRNG();

    // all steps are logarithmic and must be > 1
    double degree_step = 10, g_step = 10, coef_step = 10, C_step = 10, nu_step = 10, p_step = 10;
    double gamma = 0, curr_c = 0, degree = 0, coef = 0, p = 0, nu = 0;
    double best_degree = 0, best_gamma = 0, best_coef = 0, best_C = 0, best_nu = 0, best_p = 0;
    float min_error = FLT_MAX, error;

    if( _params.svm_type == CvSVM::ONE_CLASS )
    {
        if(!train( _train_data, _responses, _var_idx, _sample_idx, _params ))
            EXIT;
        return true;
    }
wester committed
1700

wester committed
1701
    clear();
wester committed
1702

wester committed
1703 1704
    if( k_fold < 2 )
        CV_ERROR( CV_StsBadArg, "Parameter <k_fold> must be > 1" );
wester committed
1705

wester committed
1706 1707
    CV_CALL(set_params( _params ));
    svm_type = _params.svm_type;
wester committed
1708

wester committed
1709 1710 1711 1712 1713 1714 1715 1716 1717
    // All the parameters except, possibly, <coef0> are positive.
    // <coef0> is nonnegative
    if( C_grid.step <= 1 )
    {
        C_grid.min_val = C_grid.max_val = params.C;
        C_grid.step = 10;
    }
    else
        CV_CALL(C_grid.check());
wester committed
1718

wester committed
1719 1720 1721 1722 1723 1724 1725
    if( gamma_grid.step <= 1 )
    {
        gamma_grid.min_val = gamma_grid.max_val = params.gamma;
        gamma_grid.step = 10;
    }
    else
        CV_CALL(gamma_grid.check());
wester committed
1726

wester committed
1727 1728 1729 1730 1731 1732 1733
    if( p_grid.step <= 1 )
    {
        p_grid.min_val = p_grid.max_val = params.p;
        p_grid.step = 10;
    }
    else
        CV_CALL(p_grid.check());
wester committed
1734

wester committed
1735 1736 1737 1738 1739 1740 1741
    if( nu_grid.step <= 1 )
    {
        nu_grid.min_val = nu_grid.max_val = params.nu;
        nu_grid.step = 10;
    }
    else
        CV_CALL(nu_grid.check());
wester committed
1742

wester committed
1743 1744 1745 1746 1747 1748 1749
    if( coef_grid.step <= 1 )
    {
        coef_grid.min_val = coef_grid.max_val = params.coef0;
        coef_grid.step = 10;
    }
    else
        CV_CALL(coef_grid.check());
wester committed
1750

wester committed
1751 1752 1753 1754 1755 1756 1757 1758 1759 1760 1761 1762 1763 1764 1765 1766 1767 1768 1769 1770 1771 1772 1773 1774 1775 1776 1777 1778 1779 1780 1781 1782 1783 1784 1785 1786 1787 1788 1789 1790 1791 1792 1793 1794 1795 1796 1797 1798
    if( degree_grid.step <= 1 )
    {
        degree_grid.min_val = degree_grid.max_val = params.degree;
        degree_grid.step = 10;
    }
    else
        CV_CALL(degree_grid.check());

    // these parameters are not used:
    if( params.kernel_type != CvSVM::POLY )
        degree_grid.min_val = degree_grid.max_val = params.degree;
    if( params.kernel_type == CvSVM::LINEAR )
        gamma_grid.min_val = gamma_grid.max_val = params.gamma;
    if( params.kernel_type != CvSVM::POLY && params.kernel_type != CvSVM::SIGMOID )
        coef_grid.min_val = coef_grid.max_val = params.coef0;
    if( svm_type == CvSVM::NU_SVC || svm_type == CvSVM::ONE_CLASS )
        C_grid.min_val = C_grid.max_val = params.C;
    if( svm_type == CvSVM::C_SVC || svm_type == CvSVM::EPS_SVR )
        nu_grid.min_val = nu_grid.max_val = params.nu;
    if( svm_type != CvSVM::EPS_SVR )
        p_grid.min_val = p_grid.max_val = params.p;

    CV_ASSERT( g_step > 1 && degree_step > 1 && coef_step > 1);
    CV_ASSERT( p_step > 1 && C_step > 1 && nu_step > 1 );

    /* Prepare training data and related parameters */
    CV_CALL(cvPrepareTrainData( "CvSVM::train_auto", _train_data, CV_ROW_SAMPLE,
                                 svm_type != CvSVM::ONE_CLASS ? _responses : 0,
                                 svm_type == CvSVM::C_SVC ||
                                 svm_type == CvSVM::NU_SVC ? CV_VAR_CATEGORICAL :
                                 CV_VAR_ORDERED, _var_idx, _sample_idx,
                                 false, &samples, &sample_count, &var_count, &var_all,
                                 &responses, &class_labels, &var_idx ));

    sample_size = var_count*sizeof(samples[0][0]);

    // make the storage block size large enough to fit all
    // the temporary vectors and output support vectors.
    block_size = MAX( block_size, sample_count*(int)sizeof(CvSVMKernelRow));
    block_size = MAX( block_size, sample_count*2*(int)sizeof(double) + 1024 );
    block_size = MAX( block_size, sample_size*2 + 1024 );

    CV_CALL( storage = cvCreateMemStorage(block_size + sizeof(CvMemBlock) + sizeof(CvSeqBlock)));
    CV_CALL(temp_storage = cvCreateChildMemStorage(storage));
    CV_CALL(alpha = (double*)cvMemStorageAlloc(temp_storage, sample_count*sizeof(double)));

    create_kernel();
    create_solver();
wester committed
1799

wester committed
1800 1801 1802 1803 1804 1805
    {
    const int testset_size = sample_count/k_fold;
    const int trainset_size = sample_count - testset_size;
    const int last_testset_size = sample_count - testset_size*(k_fold-1);
    const int last_trainset_size = sample_count - last_testset_size;
    const bool is_regression = (svm_type == EPS_SVR) || (svm_type == NU_SVR);
wester committed
1806

wester committed
1807 1808
    size_t resp_elem_size = CV_ELEM_SIZE(responses->type);
    size_t size = 2*last_trainset_size*sizeof(samples[0]);
wester committed
1809

wester committed
1810 1811 1812 1813 1814
    samples_local = (const float**) cvAlloc( size );
    memset( samples_local, 0, size );

    responses_local = cvCreateMat( 1, trainset_size, CV_MAT_TYPE(responses->type) );
    cvZero( responses_local );
wester committed
1815

wester committed
1816 1817
    // randomly permute samples and responses
    for(int i = 0; i < sample_count; i++ )
wester committed
1818
    {
wester committed
1819 1820 1821 1822 1823 1824 1825 1826 1827 1828 1829 1830 1831 1832 1833 1834 1835 1836 1837 1838 1839 1840 1841 1842 1843 1844 1845 1846 1847 1848 1849 1850 1851 1852 1853 1854 1855 1856 1857 1858 1859 1860 1861 1862 1863 1864 1865 1866 1867 1868 1869 1870 1871 1872 1873 1874 1875 1876 1877 1878 1879 1880 1881 1882 1883 1884 1885 1886 1887 1888 1889 1890 1891 1892 1893 1894 1895 1896 1897 1898 1899 1900 1901 1902 1903 1904 1905 1906 1907 1908 1909 1910 1911 1912 1913 1914 1915 1916 1917 1918 1919 1920 1921 1922 1923 1924 1925 1926 1927 1928
        int i1 = (*rng)(sample_count);
        int i2 = (*rng)(sample_count);
        const float* temp;
        float t;
        int y;

        CV_SWAP( samples[i1], samples[i2], temp );
        if( is_regression )
            CV_SWAP( responses->data.fl[i1], responses->data.fl[i2], t );
        else
            CV_SWAP( responses->data.i[i1], responses->data.i[i2], y );
    }

    if (!is_regression && class_labels->cols==2 && balanced)
    {
        // count class samples
        int num_0=0,num_1=0;
        for (int i=0; i<sample_count; ++i)
        {
            if (responses->data.i[i]==class_labels->data.i[0])
                ++num_0;
            else
                ++num_1;
        }

        int label_smallest_class;
        int label_biggest_class;
        if (num_0 < num_1)
        {
            label_biggest_class = class_labels->data.i[1];
            label_smallest_class = class_labels->data.i[0];
        }
        else
        {
            label_biggest_class = class_labels->data.i[0];
            label_smallest_class = class_labels->data.i[1];
            int y;
            CV_SWAP(num_0,num_1,y);
        }
        const double class_ratio = (double) num_0/sample_count;
        // calculate class ratio of each fold
        indexedratio *ratios=0;
        ratios = (indexedratio*) cvAlloc(k_fold*sizeof(*ratios));
        for (int k=0, i_begin=0; k<k_fold; ++k, i_begin+=testset_size)
        {
            int count0=0;
            int count1=0;
            int i_end = i_begin + (k<k_fold-1 ? testset_size : last_testset_size);
            for (int i=i_begin; i<i_end; ++i)
            {
                if (responses->data.i[i]==label_smallest_class)
                    ++count0;
                else
                    ++count1;
            }
            ratios[k].ind = k;
            ratios[k].count_smallest = count0;
            ratios[k].count_biggest = count1;
            ratios[k].eval();
        }
        // initial distance
        qsort(ratios, k_fold, sizeof(ratios[0]), icvCmpIndexedratio);
        double old_dist = 0.0;
        for (int k=0; k<k_fold; ++k)
            old_dist += std::abs(ratios[k].val-class_ratio);
        double new_dist = 1.0;
        // iterate to make the folds more balanced
        while (new_dist > 0.0)
        {
            if (ratios[0].count_biggest==0 || ratios[k_fold-1].count_smallest==0)
                break; // we are not able to swap samples anymore
            // what if we swap the samples, calculate the new distance
            ratios[0].count_smallest++;
            ratios[0].count_biggest--;
            ratios[0].eval();
            ratios[k_fold-1].count_smallest--;
            ratios[k_fold-1].count_biggest++;
            ratios[k_fold-1].eval();
            qsort(ratios, k_fold, sizeof(ratios[0]), icvCmpIndexedratio);
            new_dist = 0.0;
            for (int k=0; k<k_fold; ++k)
                new_dist += std::abs(ratios[k].val-class_ratio);
            if (new_dist < old_dist)
            {
                // swapping really improves, so swap the samples
                // index of the biggest_class sample from the minimum ratio fold
                int i1 = ratios[0].ind * testset_size;
                for ( ; i1<sample_count; ++i1)
                {
                    if (responses->data.i[i1]==label_biggest_class)
                        break;
                }
                // index of the smallest_class sample from the maximum ratio fold
                int i2 = ratios[k_fold-1].ind * testset_size;
                for ( ; i2<sample_count; ++i2)
                {
                    if (responses->data.i[i2]==label_smallest_class)
                        break;
                }
                // swap
                const float* temp;
                int y;
                CV_SWAP( samples[i1], samples[i2], temp );
                CV_SWAP( responses->data.i[i1], responses->data.i[i2], y );
                old_dist = new_dist;
            }
            else
                break; // does not improve, so break the loop
        }
        cvFree(&ratios);
wester committed
1929 1930
    }

wester committed
1931 1932 1933
    int* cls_lbls = class_labels ? class_labels->data.i : 0;
    curr_c = C_grid.min_val;
    do
wester committed
1934
    {
wester committed
1935 1936 1937 1938 1939 1940 1941 1942 1943 1944 1945 1946 1947 1948 1949 1950 1951 1952 1953 1954 1955 1956 1957 1958 1959 1960 1961 1962 1963 1964 1965 1966 1967 1968 1969 1970 1971 1972 1973 1974 1975 1976 1977 1978 1979 1980 1981 1982 1983 1984 1985 1986 1987 1988 1989 1990 1991 1992 1993 1994 1995 1996 1997 1998 1999 2000 2001 2002 2003 2004 2005 2006 2007 2008 2009 2010 2011 2012 2013 2014 2015 2016 2017 2018 2019 2020 2021 2022 2023 2024 2025 2026 2027 2028 2029 2030 2031 2032 2033 2034 2035 2036 2037 2038 2039 2040 2041 2042 2043 2044 2045
      params.C = curr_c;
      gamma = gamma_grid.min_val;
      do
      {
        params.gamma = gamma;
        p = p_grid.min_val;
        do
        {
          params.p = p;
          nu = nu_grid.min_val;
          do
          {
            params.nu = nu;
            coef = coef_grid.min_val;
            do
            {
              params.coef0 = coef;
              degree = degree_grid.min_val;
              do
              {
                params.degree = degree;

                float** test_samples_ptr = (float**)samples;
                uchar* true_resp = responses->data.ptr;
                int test_size = testset_size;
                int train_size = trainset_size;

                error = 0;
                for(int k = 0; k < k_fold; k++ )
                {
                    memcpy( samples_local, samples, sizeof(samples[0])*test_size*k );
                    memcpy( samples_local + test_size*k, test_samples_ptr + test_size,
                        sizeof(samples[0])*(sample_count - testset_size*(k+1)) );

                    memcpy( responses_local->data.ptr, responses->data.ptr, resp_elem_size*test_size*k );
                    memcpy( responses_local->data.ptr + resp_elem_size*test_size*k,
                        true_resp + resp_elem_size*test_size,
                        resp_elem_size*(sample_count - testset_size*(k+1)) );

                    if( k == k_fold - 1 )
                    {
                        test_size = last_testset_size;
                        train_size = last_trainset_size;
                        responses_local->cols = last_trainset_size;
                    }

                    // Train SVM on <train_size> samples
                    if( !do_train( svm_type, train_size, var_count,
                        (const float**)samples_local, responses_local, temp_storage, alpha ) )
                        EXIT;

                    // Compute test set error on <test_size> samples
                    for(int i = 0; i < test_size; i++, true_resp += resp_elem_size, test_samples_ptr++ )
                    {
                        float resp = predict( *test_samples_ptr, var_count );
                        error += is_regression ? powf( resp - *(float*)true_resp, 2 )
                            : ((int)resp != cls_lbls[*(int*)true_resp]);
                    }
                }
                if( min_error > error )
                {
                    min_error   = error;
                    best_degree = degree;
                    best_gamma  = gamma;
                    best_coef   = coef;
                    best_C      = curr_c;
                    best_nu     = nu;
                    best_p      = p;
                }
                degree *= degree_grid.step;
              }
              while( degree < degree_grid.max_val );
              coef *= coef_grid.step;
            }
            while( coef < coef_grid.max_val );
            nu *= nu_grid.step;
          }
          while( nu < nu_grid.max_val );
          p *= p_grid.step;
        }
        while( p < p_grid.max_val );
        gamma *= gamma_grid.step;
      }
      while( gamma < gamma_grid.max_val );
      curr_c *= C_grid.step;
    }
    while( curr_c < C_grid.max_val );
    }

    min_error /= (float) sample_count;

    params.C      = best_C;
    params.nu     = best_nu;
    params.p      = best_p;
    params.gamma  = best_gamma;
    params.degree = best_degree;
    params.coef0  = best_coef;

    CV_CALL(ok = do_train( svm_type, sample_count, var_count, samples, responses, temp_storage, alpha ));

    __END__;

    delete solver;
    solver = 0;
    cvReleaseMemStorage( &temp_storage );
    cvReleaseMat( &responses );
    cvReleaseMat( &responses_local );
    cvFree( &samples );
    cvFree( &samples_local );

    if( cvGetErrStatus() < 0 || !ok )
wester committed
2046 2047
        clear();

wester committed
2048 2049
    return ok;
}
wester committed
2050

wester committed
2051 2052 2053 2054
float CvSVM::predict( const float* row_sample, int row_len, bool returnDFVal ) const
{
    assert( kernel );
    assert( row_sample );
wester committed
2055

wester committed
2056 2057 2058
    int var_count = get_var_count();
    assert( row_len == var_count );
    (void)row_len;
wester committed
2059

wester committed
2060 2061
    int class_count = class_labels ? class_labels->cols :
                  params.svm_type == ONE_CLASS ? 1 : 0;
wester committed
2062

wester committed
2063 2064 2065
    float result = 0;
    cv::AutoBuffer<float> _buffer(sv_total + (class_count+1)*2);
    float* buffer = _buffer;
wester committed
2066

wester committed
2067 2068 2069
    if( params.svm_type == EPS_SVR ||
        params.svm_type == NU_SVR ||
        params.svm_type == ONE_CLASS )
wester committed
2070
    {
wester committed
2071 2072 2073
        CvSVMDecisionFunc* df = (CvSVMDecisionFunc*)decision_func;
        int i, sv_count = df->sv_count;
        double sum = -df->rho;
wester committed
2074

wester committed
2075 2076 2077
        kernel->calc( sv_count, var_count, (const float**)sv, row_sample, buffer );
        for( i = 0; i < sv_count; i++ )
            sum += buffer[i]*df->alpha[i];
wester committed
2078

wester committed
2079 2080 2081 2082
        result = params.svm_type == ONE_CLASS ? (float)(sum > 0) : (float)sum;
    }
    else if( params.svm_type == C_SVC ||
             params.svm_type == NU_SVC )
wester committed
2083
    {
wester committed
2084 2085 2086 2087 2088 2089 2090 2091 2092
        CvSVMDecisionFunc* df = (CvSVMDecisionFunc*)decision_func;
        int* vote = (int*)(buffer + sv_total);
        int i, j, k;

        memset( vote, 0, class_count*sizeof(vote[0]));
        kernel->calc( sv_total, var_count, (const float**)sv, row_sample, buffer );
        double sum = 0.;

        for( i = 0; i < class_count; i++ )
wester committed
2093
        {
wester committed
2094 2095 2096 2097 2098 2099 2100 2101 2102
            for( j = i+1; j < class_count; j++, df++ )
            {
                sum = -df->rho;
                int sv_count = df->sv_count;
                for( k = 0; k < sv_count; k++ )
                    sum += df->alpha[k]*buffer[df->sv_index[k]];

                vote[sum > 0 ? i : j]++;
            }
wester committed
2103
        }
wester committed
2104 2105

        for( i = 1, k = 0; i < class_count; i++ )
wester committed
2106
        {
wester committed
2107 2108
            if( vote[i] > vote[k] )
                k = i;
wester committed
2109
        }
wester committed
2110 2111 2112 2113 2114
        result = returnDFVal && class_count == 2 ? (float)sum : (float)(class_labels->data.i[k]);
    }
    else
        CV_Error( CV_StsBadArg, "INTERNAL ERROR: Unknown SVM type, "
                                "the SVM structure is probably corrupted" );
wester committed
2115

wester committed
2116 2117
    return result;
}
wester committed
2118

wester committed
2119 2120 2121 2122
float CvSVM::predict( const CvMat* sample, bool returnDFVal ) const
{
    float result = 0;
    float* row_sample = 0;
wester committed
2123

wester committed
2124
    CV_FUNCNAME( "CvSVM::predict" );
wester committed
2125

wester committed
2126
    __BEGIN__;
wester committed
2127

wester committed
2128
    int class_count;
wester committed
2129

wester committed
2130 2131
    if( !kernel )
        CV_ERROR( CV_StsBadArg, "The SVM should be trained first" );
wester committed
2132

wester committed
2133 2134
    class_count = class_labels ? class_labels->cols :
                  params.svm_type == ONE_CLASS ? 1 : 0;
wester committed
2135

wester committed
2136 2137 2138
    CV_CALL( cvPreparePredictData( sample, var_all, var_idx,
                                   class_count, 0, &row_sample ));
    result = predict( row_sample, get_var_count(), returnDFVal );
wester committed
2139

wester committed
2140
    __END__;
wester committed
2141

wester committed
2142 2143 2144 2145 2146 2147 2148 2149
    if( sample && (!CV_IS_MAT(sample) || sample->data.fl != row_sample) )
        cvFree( &row_sample );

    return result;
}

struct predict_body_svm : ParallelLoopBody {
    predict_body_svm(const CvSVM* _pointer, float* _result, const CvMat* _samples, CvMat* _results)
wester committed
2150
    {
wester committed
2151 2152 2153 2154 2155
        pointer = _pointer;
        result = _result;
        samples = _samples;
        results = _results;
    }
wester committed
2156

wester committed
2157 2158 2159 2160
    const CvSVM* pointer;
    float* result;
    const CvMat* samples;
    CvMat* results;
wester committed
2161

wester committed
2162 2163 2164
    void operator()( const cv::Range& range ) const
    {
        for(int i = range.start; i < range.end; i++ )
wester committed
2165
        {
wester committed
2166 2167 2168 2169 2170 2171 2172 2173 2174 2175
            CvMat sample;
            cvGetRow( samples, &sample, i );
            int r = (int)pointer->predict(&sample);
            if (results)
                results->data.fl[i] = (float)r;
            if (i == 0)
                *result = (float)r;
    }
    }
};
wester committed
2176

wester committed
2177 2178 2179 2180 2181 2182 2183 2184
float CvSVM::predict(const CvMat* samples, CV_OUT CvMat* results) const
{
    float result = 0;
    cv::parallel_for_(cv::Range(0, samples->rows),
             predict_body_svm(this, &result, samples, results)
    );
    return result;
}
wester committed
2185

wester committed
2186 2187 2188 2189 2190 2191
// Array-based wrapper around the CvMat batch prediction.
// <_results> is (re)created as a single-column CV_32F matrix with as many rows
// as <_samples> has, then both arrays are handed over as CvMat headers.
void CvSVM::predict( cv::InputArray _samples, cv::OutputArray _results ) const
{
    const int sample_rows = _samples.size().height;
    _results.create(sample_rows, 1, CV_32F);

    CvMat samples_hdr = _samples.getMat();
    CvMat results_hdr = _results.getMat();
    predict(&samples_hdr, &results_hdr);
}
wester committed
2192

wester committed
2193 2194 2195 2196 2197 2198 2199 2200 2201 2202 2203 2204 2205 2206
// Constructs the model and trains it right away on the given data.
// The pointer members are zeroed before train() is called; the boolean result
// of train() is not reported to the caller.
CvSVM::CvSVM( const Mat& _train_data, const Mat& _responses,
              const Mat& _var_idx, const Mat& _sample_idx, CvSVMParams _params )
{
    default_model_name = "my_svm";

    // start from an empty model
    storage = 0;
    kernel = 0;
    solver = 0;
    decision_func = 0;
    var_idx = 0;
    class_labels = 0;
    class_weights = 0;

    train( _train_data, _responses, _var_idx, _sample_idx, _params );
}
wester committed
2207

wester committed
2208 2209 2210 2211 2212 2213
bool CvSVM::train( const Mat& _train_data, const Mat& _responses,
                  const Mat& _var_idx, const Mat& _sample_idx, CvSVMParams _params )
{
    CvMat tdata = _train_data, responses = _responses, vidx = _var_idx, sidx = _sample_idx;
    return train(&tdata, &responses, vidx.data.ptr ? &vidx : 0, sidx.data.ptr ? &sidx : 0, _params);
}
wester committed
2214 2215


wester committed
2216 2217 2218 2219 2220 2221 2222 2223 2224 2225
bool CvSVM::train_auto( const Mat& _train_data, const Mat& _responses,
                       const Mat& _var_idx, const Mat& _sample_idx, CvSVMParams _params, int k_fold,
                       CvParamGrid C_grid, CvParamGrid gamma_grid, CvParamGrid p_grid,
                       CvParamGrid nu_grid, CvParamGrid coef_grid, CvParamGrid degree_grid, bool balanced )
{
    CvMat tdata = _train_data, responses = _responses, vidx = _var_idx, sidx = _sample_idx;
    return train_auto(&tdata, &responses, vidx.data.ptr ? &vidx : 0,
                      sidx.data.ptr ? &sidx : 0, _params, k_fold, C_grid, gamma_grid, p_grid,
                      nu_grid, coef_grid, degree_grid, balanced);
}
wester committed
2226

wester committed
2227 2228 2229 2230 2231
float CvSVM::predict( const Mat& _sample, bool returnDFVal ) const
{
    CvMat sample = _sample;
    return predict(&sample, returnDFVal);
}
wester committed
2232 2233


wester committed
2234 2235 2236 2237 2238 2239 2240 2241 2242 2243 2244 2245 2246 2247 2248 2249 2250 2251 2252 2253 2254 2255 2256 2257 2258
// Writes the training parameters of the model to <fs>.
// The SVM type and the kernel type are stored as strings when they are known
// values and as integers otherwise. Only the parameters that belong to the
// stored type are written; for an unknown type all of them are written.
void CvSVM::write_params( CvFileStorage* fs ) const
{
    //CV_FUNCNAME( "CvSVM::write_params" );

    __BEGIN__;

    const int svm_type = params.svm_type;
    const int kernel_type = params.kernel_type;

    // symbolic names; a null pointer marks an unknown value
    const char* svm_type_str = 0;
    switch( svm_type )
    {
    case CvSVM::C_SVC:     svm_type_str = "C_SVC";     break;
    case CvSVM::NU_SVC:    svm_type_str = "NU_SVC";    break;
    case CvSVM::ONE_CLASS: svm_type_str = "ONE_CLASS"; break;
    case CvSVM::EPS_SVR:   svm_type_str = "EPS_SVR";   break;
    case CvSVM::NU_SVR:    svm_type_str = "NU_SVR";    break;
    default:               break;
    }

    const char* kernel_type_str = 0;
    switch( kernel_type )
    {
    case CvSVM::LINEAR:  kernel_type_str = "LINEAR";  break;
    case CvSVM::POLY:    kernel_type_str = "POLY";    break;
    case CvSVM::RBF:     kernel_type_str = "RBF";     break;
    case CvSVM::SIGMOID: kernel_type_str = "SIGMOID"; break;
    default:             break;
    }

    const bool unknown_svm = !svm_type_str;
    const bool unknown_kernel = !kernel_type_str;

    if( unknown_svm )
        cvWriteInt( fs, "svm_type", svm_type );
    else
        cvWriteString( fs, "svm_type", svm_type_str );

    // save kernel
    cvStartWriteStruct( fs, "kernel", CV_NODE_MAP + CV_NODE_FLOW );

    if( unknown_kernel )
        cvWriteInt( fs, "type", kernel_type );
    else
        cvWriteString( fs, "type", kernel_type_str );

    const bool write_degree = kernel_type == CvSVM::POLY || unknown_kernel;
    const bool write_gamma = kernel_type != CvSVM::LINEAR || unknown_kernel;
    const bool write_coef0 = kernel_type == CvSVM::POLY ||
                             kernel_type == CvSVM::SIGMOID || unknown_kernel;

    if( write_degree )
        cvWriteReal( fs, "degree", params.degree );

    if( write_gamma )
        cvWriteReal( fs, "gamma", params.gamma );

    if( write_coef0 )
        cvWriteReal( fs, "coef0", params.coef0 );

    cvEndWriteStruct(fs);

    // parameters of the SVM formulation
    const bool write_C = svm_type == CvSVM::C_SVC || svm_type == CvSVM::EPS_SVR ||
                         svm_type == CvSVM::NU_SVR || unknown_svm;
    const bool write_nu = svm_type == CvSVM::NU_SVC || svm_type == CvSVM::ONE_CLASS ||
                          svm_type == CvSVM::NU_SVR || unknown_svm;
    const bool write_p = svm_type == CvSVM::EPS_SVR || unknown_svm;

    if( write_C )
        cvWriteReal( fs, "C", params.C );

    if( write_nu )
        cvWriteReal( fs, "nu", params.nu );

    if( write_p )
        cvWriteReal( fs, "p", params.p );

    // termination criteria: only the parts enabled in term_crit.type are stored
    cvStartWriteStruct( fs, "term_criteria", CV_NODE_MAP + CV_NODE_FLOW );
    if( params.term_crit.type & CV_TERMCRIT_EPS )
        cvWriteReal( fs, "epsilon", params.term_crit.epsilon );
    if( params.term_crit.type & CV_TERMCRIT_ITER )
        cvWriteInt( fs, "iterations", params.term_crit.max_iter );
    cvEndWriteStruct( fs );

    __END__;
}
wester committed
2299 2300


wester committed
2301 2302 2303 2304
// Sanity check of the basic model dimensions: there must be at least one
// support vector, the number of used variables must lie in [1, var_all] and
// the class count must not be negative.
static bool isSvmModelApplicable(int sv_total, int var_all, int var_count, int class_count)
{
    if( sv_total <= 0 )
        return false;
    if( var_count <= 0 || var_count > var_all )
        return false;
    return class_count >= 0;
}
wester committed
2305 2306


wester committed
2307 2308 2309
// Stores the whole model in <fs> as a map called <name>: the parameters, the
// variable and class information, all support vectors and the decision
// functions.
//
// When the model dimensions fail isSvmModelApplicable(), the file storage is
// released and CV_StsParseError is raised.
// NOTE(review): cvReleaseFileStorage() is applied to the local copy of <fs>;
// whether the caller is expected to release the storage again is not visible
// here - confirm.
void CvSVM::write( CvFileStorage* fs, const char* name ) const
{
    CV_FUNCNAME( "CvSVM::write" );

    __BEGIN__;

    // every initialized local is declared before the first error macro
    int i, df_count;
    const int var_count = get_var_count();
    int class_count = 0;
    if( class_labels )
        class_count = class_labels->cols;
    else if( params.svm_type == CvSVM::ONE_CLASS )
        class_count = 1;
    const CvSVMDecisionFunc* const df = decision_func;

    if( !isSvmModelApplicable(sv_total, var_all, var_count, class_count) )
    {
        cvReleaseFileStorage( &fs );
        fs = NULL;

        CV_ERROR( CV_StsParseError, "SVM model data is invalid, check sv_count, var_* and class_count tags" );
    }

    cvStartWriteStruct( fs, name, CV_NODE_MAP, CV_TYPE_NAME_ML_SVM );

    write_params( fs );

    cvWriteInt( fs, "var_all", var_all );
    cvWriteInt( fs, "var_count", var_count );

    if( class_count )
    {
        cvWriteInt( fs, "class_count", class_count );

        if( class_labels )
            cvWrite( fs, "class_labels", class_labels );

        if( class_weights )
            cvWrite( fs, "class_weights", class_weights );
    }

    if( var_idx )
        cvWrite( fs, "var_idx", var_idx );

    // write the joint collection of support vectors
    // (a sequence with one flow sequence of <var_count> floats per vector)
    cvWriteInt( fs, "sv_total", sv_total );
    cvStartWriteStruct( fs, "support_vectors", CV_NODE_SEQ );
    for( i = 0; i < sv_total; i++ )
    {
        cvStartWriteStruct( fs, 0, CV_NODE_SEQ + CV_NODE_FLOW );
        cvWriteRawData( fs, sv[i], var_count, "f" );
        cvEndWriteStruct( fs );
    }

    cvEndWriteStruct( fs );

    // write decision functions
    // (one per pair of classes, or a single one when there are less than two classes)
    if( class_count > 1 )
        df_count = class_count*(class_count-1)/2;
    else
        df_count = 1;

    cvStartWriteStruct( fs, "decision_functions", CV_NODE_SEQ );
    for( i = 0; i < df_count; i++ )
    {
        const CvSVMDecisionFunc& cur = df[i];
        int sv_count = cur.sv_count;

        cvStartWriteStruct( fs, 0, CV_NODE_MAP );
        cvWriteInt( fs, "sv_count", sv_count );
        cvWriteReal( fs, "rho", cur.rho );

        cvStartWriteStruct( fs, "alpha", CV_NODE_SEQ+CV_NODE_FLOW );
        cvWriteRawData( fs, cur.alpha, cur.sv_count, "d" );
        cvEndWriteStruct( fs );

        if( class_count > 1 )
        {
            // indices of the support vectors used by this decision function
            cvStartWriteStruct( fs, "index", CV_NODE_SEQ+CV_NODE_FLOW );
            cvWriteRawData( fs, cur.sv_index, cur.sv_count, "i" );
            cvEndWriteStruct( fs );
        }
        else
            CV_ASSERT( sv_count == sv_total );
        cvEndWriteStruct( fs );
    }
    cvEndWriteStruct( fs );   // "decision_functions"
    cvEndWriteStruct( fs );   // the model map opened under <name>

    __END__;
}
wester committed
2387 2388


wester committed
2389 2390 2391
/* Reads the SVM parameters stored under <svm_node> and applies them via set_params().

   fs       - file storage the node belongs to
   svm_node - node holding the "svm_type", "kernel", "C", "nu", "p" and
              (optionally) "term_criteria" entries

   "svm_type" and the kernel "type" may be stored either as integer codes or as
   symbolic names. Raises CV_StsBadArg when the "svm_type" tag is absent and
   CV_StsParseError when the kernel tags are absent or a symbolic name is unknown. */
void CvSVM::read_params( CvFileStorage* fs, CvFileNode* svm_node )
{
    CV_FUNCNAME( "CvSVM::read_params" );

    __BEGIN__;

    int svm_type, kernel_type;
    CvSVMParams _params;

    CvFileNode* tmp_node = cvGetFileNodeByName( fs, svm_node, "svm_type" );
    CvFileNode* kernel_node;
    if( !tmp_node )
        CV_ERROR( CV_StsBadArg, "svm_type tag is not found" );

    // the SVM type is stored either as an integer code or as a symbolic name
    if( CV_NODE_TYPE(tmp_node->tag) == CV_NODE_INT )
        svm_type = cvReadInt( tmp_node, -1 );
    else
    {
        const char* svm_type_str = cvReadString( tmp_node, "" );
        svm_type =
            strcmp( svm_type_str, "C_SVC" ) == 0 ? CvSVM::C_SVC :
            strcmp( svm_type_str, "NU_SVC" ) == 0 ? CvSVM::NU_SVC :
            strcmp( svm_type_str, "ONE_CLASS" ) == 0 ? CvSVM::ONE_CLASS :
            strcmp( svm_type_str, "EPS_SVR" ) == 0 ? CvSVM::EPS_SVR :
            strcmp( svm_type_str, "NU_SVR" ) == 0 ? CvSVM::NU_SVR : -1;

        if( svm_type < 0 )
            CV_ERROR( CV_StsParseError, "Missing or invalid SVM type" );
    }

    kernel_node = cvGetFileNodeByName( fs, svm_node, "kernel" );
    if( !kernel_node )
        CV_ERROR( CV_StsParseError, "SVM kernel tag is not found" );

    tmp_node = cvGetFileNodeByName( fs, kernel_node, "type" );
    if( !tmp_node )
        CV_ERROR( CV_StsParseError, "SVM kernel type tag is not found" );

    // same dual encoding (integer code or symbolic name) for the kernel type
    if( CV_NODE_TYPE(tmp_node->tag) == CV_NODE_INT )
        kernel_type = cvReadInt( tmp_node, -1 );
    else
    {
        const char* kernel_type_str = cvReadString( tmp_node, "" );
        kernel_type =
            strcmp( kernel_type_str, "LINEAR" ) == 0 ? CvSVM::LINEAR :
            strcmp( kernel_type_str, "POLY" ) == 0 ? CvSVM::POLY :
            strcmp( kernel_type_str, "RBF" ) == 0 ? CvSVM::RBF :
            strcmp( kernel_type_str, "SIGMOID" ) == 0 ? CvSVM::SIGMOID : -1;

        if( kernel_type < 0 )
            CV_ERROR( CV_StsParseError, "Missing or invalid SVM kernel type" );
    }

    _params.svm_type = svm_type;
    _params.kernel_type = kernel_type;

    // kernel parameters live under the "kernel" node; absent entries default to 0
    _params.degree = cvReadRealByName( fs, kernel_node, "degree", 0 );
    _params.gamma = cvReadRealByName( fs, kernel_node, "gamma", 0 );
    _params.coef0 = cvReadRealByName( fs, kernel_node, "coef0", 0 );

    // the remaining parameters are read from the SVM node itself; absent entries default to 0
    _params.C = cvReadRealByName( fs, svm_node, "C", 0 );
    _params.nu = cvReadRealByName( fs, svm_node, "nu", 0 );
    _params.p = cvReadRealByName( fs, svm_node, "p", 0 );
    _params.class_weights = 0;

    tmp_node = cvGetFileNodeByName( fs, svm_node, "term_criteria" );
    if( tmp_node )
    {
        // a criterion is enabled only if its value is present (i.e. not the negative default)
        _params.term_crit.epsilon = cvReadRealByName( fs, tmp_node, "epsilon", -1. );
        _params.term_crit.max_iter = cvReadIntByName( fs, tmp_node, "iterations", -1 );
        _params.term_crit.type = (_params.term_crit.epsilon >= 0 ? CV_TERMCRIT_EPS : 0) +
                               (_params.term_crit.max_iter >= 0 ? CV_TERMCRIT_ITER : 0);
    }
    else
        _params.term_crit = cvTermCriteria( CV_TERMCRIT_EPS + CV_TERMCRIT_ITER, 1000, FLT_EPSILON );

    set_params( _params );

    __END__;
}
wester committed
2468

wester committed
2469 2470 2471
/* Loads a trained SVM model from the file storage node <svm_node>.

   fs       - file storage the node belongs to
   svm_node - node holding the model; CV_StsParseError is raised when it is NULL

   The current model is cleared first. Then the parameters, the top-level counters,
   the optional class_labels/class_weights/var_idx matrices, the support vectors and
   the decision functions are read, in that order. Any missing or inconsistent entry
   raises CV_StsParseError. Unless the file stores optimize_linear == 0,
   optimize_linear_svm() is invoked before the kernel is created. */
void CvSVM::read( CvFileStorage* fs, CvFileNode* svm_node )
{
    // sentinel returned by cvReadRealByName() when "rho" is absent
    const double not_found_dbl = DBL_MAX;

    CV_FUNCNAME( "CvSVM::read" );

    __BEGIN__;

    int i, var_count, df_count, class_count;
    int block_size = 1 << 16, sv_size;
    CvFileNode *sv_node, *df_node;
    CvSVMDecisionFunc* df;
    CvSeqReader reader;

    if( !svm_node )
        CV_ERROR( CV_StsParseError, "The requested element is not found" );

    clear();

    // read SVM parameters
    read_params( fs, svm_node );

    // and top-level data
    sv_total = cvReadIntByName( fs, svm_node, "sv_total", -1 );
    var_all = cvReadIntByName( fs, svm_node, "var_all", -1 );
    var_count = cvReadIntByName( fs, svm_node, "var_count", var_all );
    class_count = cvReadIntByName( fs, svm_node, "class_count", 0 );

    if( !isSvmModelApplicable(sv_total, var_all, var_count, class_count) )
        CV_ERROR( CV_StsParseError, "SVM model data is invalid, check sv_total, var_* and class_count tags" );

    CV_CALL( class_labels = (CvMat*)cvReadByName( fs, svm_node, "class_labels" ));
    CV_CALL( class_weights = (CvMat*)cvReadByName( fs, svm_node, "class_weights" ));
    CV_CALL( var_idx = (CvMat*)cvReadByName( fs, svm_node, "var_idx" ));

    if( class_count > 1 && (!class_labels ||
        !CV_IS_MAT(class_labels) || class_labels->cols != class_count))
        CV_ERROR( CV_StsParseError, "Array of class labels is missing or invalid" );

    if( var_count < var_all && (!var_idx || !CV_IS_MAT(var_idx) || var_idx->cols != var_count) )
        CV_ERROR( CV_StsParseError, "var_idx array is missing or invalid" );

    // read support vectors
    sv_node = cvGetFileNodeByName( fs, svm_node, "support_vectors" );
    if( !sv_node || !CV_NODE_IS_SEQ(sv_node->tag))
        CV_ERROR( CV_StsParseError, "Missing or invalid sequence of support vectors" );

    // the loop below consumes exactly sv_total elements, so the stored sequence
    // must provide that many (an empty sequence would leave reader.ptr unusable)
    if( sv_node->data.seq->total != sv_total )
        CV_ERROR( CV_StsParseError, "The number of stored support vectors does not match sv_total" );

    // make a storage block large enough for the biggest single allocation
    block_size = MAX( block_size, sv_total*(int)sizeof(CvSVMKernelRow));
    block_size = MAX( block_size, sv_total*2*(int)sizeof(double));
    block_size = MAX( block_size, var_all*(int)sizeof(double));

    CV_CALL( storage = cvCreateMemStorage(block_size + sizeof(CvMemBlock) + sizeof(CvSeqBlock)));
    CV_CALL( sv = (float**)cvMemStorageAlloc( storage,
                                sv_total*sizeof(sv[0]) ));

    CV_CALL( cvStartReadSeq( sv_node->data.seq, &reader, 0 ));
    sv_size = var_count*sizeof(sv[0][0]);

    for( i = 0; i < sv_total; i++ )
    {
        CvFileNode* sv_elem = (CvFileNode*)reader.ptr;
        // with a single variable the element is allowed not to be a sequence
        CV_ASSERT( var_count == 1 || (CV_NODE_IS_SEQ(sv_elem->tag) &&
                   sv_elem->data.seq->total == var_count) );

        CV_CALL( sv[i] = (float*)cvMemStorageAlloc( storage, sv_size ));
        CV_CALL( cvReadRawData( fs, sv_elem, sv[i], "f" ));
        CV_NEXT_SEQ_ELEM( sv_node->data.seq->elem_size, reader );
    }

    // read decision functions
    df_count = class_count > 1 ? class_count*(class_count-1)/2 : 1;
    df_node = cvGetFileNodeByName( fs, svm_node, "decision_functions" );
    if( !df_node || !CV_NODE_IS_SEQ(df_node->tag) ||
        df_node->data.seq->total != df_count )
        CV_ERROR( CV_StsParseError, "decision_functions is missing or is not a collection "
                  "or has a wrong number of elements" );

    CV_CALL( df = decision_func = (CvSVMDecisionFunc*)cvAlloc( df_count*sizeof(df[0]) ));
    cvStartReadSeq( df_node->data.seq, &reader, 0 );

    for( i = 0; i < df_count; i++ )
    {
        CvFileNode* df_elem = (CvFileNode*)reader.ptr;
        CvFileNode* alpha_node = cvGetFileNodeByName( fs, df_elem, "alpha" );

        int sv_count = cvReadIntByName( fs, df_elem, "sv_count", -1 );
        if( sv_count <= 0 )
            CV_ERROR( CV_StsParseError, "sv_count is missing or non-positive" );
        // NOTE(review): a decision function is expected to use a subset of the joint
        // support vector collection, hence at most sv_total of them - confirm against
        // the prediction code.
        if( sv_count > sv_total )
            CV_ERROR( CV_StsParseError, "sv_count exceeds sv_total" );
        df[i].sv_count = sv_count;

        df[i].rho = cvReadRealByName( fs, df_elem, "rho", not_found_dbl );
        if( fabs(df[i].rho - not_found_dbl) < DBL_EPSILON )
            CV_ERROR( CV_StsParseError, "rho is missing" );

        if( !alpha_node )
            CV_ERROR( CV_StsParseError, "alpha is missing in the decision function" );

        CV_CALL( df[i].alpha = (double*)cvMemStorageAlloc( storage,
                                        sv_count*sizeof(df[i].alpha[0])));
        CV_ASSERT( sv_count == 1 || (CV_NODE_IS_SEQ(alpha_node->tag) &&
                   alpha_node->data.seq->total == sv_count) );
        CV_CALL( cvReadRawData( fs, alpha_node, df[i].alpha, "d" ));

        if( class_count > 1 )
        {
            int k;
            CvFileNode* index_node = cvGetFileNodeByName( fs, df_elem, "index" );
            if( !index_node )
                CV_ERROR( CV_StsParseError, "index is missing in the decision function" );
            CV_CALL( df[i].sv_index = (int*)cvMemStorageAlloc( storage,
                                            sv_count*sizeof(df[i].sv_index[0])));
            CV_ASSERT( sv_count == 1 || (CV_NODE_IS_SEQ(index_node->tag) &&
                   index_node->data.seq->total == sv_count) );
            CV_CALL( cvReadRawData( fs, index_node, df[i].sv_index, "i" ));

            // NOTE(review): the indices are expected to address sv[0..sv_total-1];
            // reject anything else so a damaged file cannot point outside the array.
            for( k = 0; k < sv_count; k++ )
                if( df[i].sv_index[k] < 0 || df[i].sv_index[k] >= sv_total )
                    CV_ERROR( CV_StsParseError, "index in the decision function is out of range" );
        }
        else
            df[i].sv_index = 0;

        CV_NEXT_SEQ_ELEM( df_node->data.seq->elem_size, reader );
    }

    // optimization is applied unless the file explicitly stores optimize_linear == 0
    if( cvReadIntByName(fs, svm_node, "optimize_linear", 1) != 0 )
        optimize_linear_svm();
    create_kernel();

    __END__;
}
wester committed
2595

wester committed
2596
#if 0
wester committed
2597

wester committed
2598 2599 2600 2601
static void*
icvCloneSVM( const void* _src )
{
    CvSVMModel* dst = 0;
wester committed
2602

wester committed
2603
    CV_FUNCNAME( "icvCloneSVM" );
wester committed
2604

wester committed
2605
    __BEGIN__;
wester committed
2606

wester committed
2607 2608 2609 2610
    const CvSVMModel* src = (const CvSVMModel*)_src;
    int var_count, class_count;
    int i, sv_total, df_count;
    int sv_size;
wester committed
2611

wester committed
2612 2613
    if( !CV_IS_SVM(src) )
        CV_ERROR( !src ? CV_StsNullPtr : CV_StsBadArg, "Input pointer is NULL or invalid" );
wester committed
2614

wester committed
2615 2616 2617 2618 2619
    // 0. create initial CvSVMModel structure
    CV_CALL( dst = icvCreateSVM() );
    dst->params = src->params;
    dst->params.weight_labels = 0;
    dst->params.weights = 0;
wester committed
2620

wester committed
2621 2622 2623 2624 2625 2626 2627
    dst->var_all = src->var_all;
    if( src->class_labels )
        dst->class_labels = cvCloneMat( src->class_labels );
    if( src->class_weights )
        dst->class_weights = cvCloneMat( src->class_weights );
    if( src->comp_idx )
        dst->comp_idx = cvCloneMat( src->comp_idx );
wester committed
2628

wester committed
2629 2630 2631 2632 2633 2634 2635
    var_count = src->comp_idx ? src->comp_idx->cols : src->var_all;
    class_count = src->class_labels ? src->class_labels->cols :
                  src->params.svm_type == CvSVM::ONE_CLASS ? 1 : 0;
    sv_total = dst->sv_total = src->sv_total;
    CV_CALL( dst->storage = cvCreateMemStorage( src->storage->block_size ));
    CV_CALL( dst->sv = (float**)cvMemStorageAlloc( dst->storage,
                                    sv_total*sizeof(dst->sv[0]) ));
wester committed
2636

wester committed
2637
    sv_size = var_count*sizeof(dst->sv[0][0]);
wester committed
2638

wester committed
2639
    for( i = 0; i < sv_total; i++ )
wester committed
2640
    {
wester committed
2641 2642
        CV_CALL( dst->sv[i] = (float*)cvMemStorageAlloc( dst->storage, sv_size ));
        memcpy( dst->sv[i], src->sv[i], sv_size );
wester committed
2643 2644
    }

wester committed
2645
    df_count = class_count > 1 ? class_count*(class_count-1)/2 : 1;
wester committed
2646

wester committed
2647
    CV_CALL( dst->decision_func = cvAlloc( df_count*sizeof(CvSVMDecisionFunc) ));
wester committed
2648

wester committed
2649
    for( i = 0; i < df_count; i++ )
wester committed
2650
    {
wester committed
2651 2652 2653 2654 2655 2656 2657 2658 2659 2660 2661 2662 2663 2664 2665 2666 2667 2668 2669
        const CvSVMDecisionFunc *sdf =
            (const CvSVMDecisionFunc*)src->decision_func+i;
        CvSVMDecisionFunc *ddf =
            (CvSVMDecisionFunc*)dst->decision_func+i;
        int sv_count = sdf->sv_count;
        ddf->sv_count = sv_count;
        ddf->rho = sdf->rho;
        CV_CALL( ddf->alpha = (double*)cvMemStorageAlloc( dst->storage,
                                        sv_count*sizeof(ddf->alpha[0])));
        memcpy( ddf->alpha, sdf->alpha, sv_count*sizeof(ddf->alpha[0]));

        if( class_count > 1 )
        {
            CV_CALL( ddf->sv_index = (int*)cvMemStorageAlloc( dst->storage,
                                                sv_count*sizeof(ddf->sv_index[0])));
            memcpy( ddf->sv_index, sdf->sv_index, sv_count*sizeof(ddf->sv_index[0]));
        }
        else
            ddf->sv_index = 0;
wester committed
2670 2671
    }

wester committed
2672
    __END__;
wester committed
2673

wester committed
2674 2675
    if( cvGetErrStatus() < 0 && dst )
        icvReleaseSVM( &dst );
wester committed
2676

wester committed
2677 2678
    return dst;
}
wester committed
2679

wester committed
2680 2681 2682 2683 2684 2685 2686 2687 2688 2689 2690 2691 2692 2693 2694 2695 2696
static int icvRegisterSVMType()
{
    CvTypeInfo info;
    memset( &info, 0, sizeof(info) );

    info.flags = 0;
    info.header_size = sizeof( info );
    info.is_instance = icvIsSVM;
    info.release = (CvReleaseFunc)icvReleaseSVM;
    info.read = icvReadSVM;
    info.write = icvWriteSVM;
    info.clone = icvCloneSVM;
    info.type_name = CV_TYPE_NAME_ML_SVM;
    cvRegisterType( &info );

    return 1;
}
wester committed
2697 2698


wester committed
2699 2700 2701 2702 2703 2704 2705 2706 2707 2708 2709 2710 2711 2712 2713 2714 2715 2716 2717 2718 2719 2720 2721 2722 2723 2724 2725 2726 2727 2728 2729 2730 2731 2732 2733 2734 2735 2736 2737 2738 2739 2740 2741 2742 2743 2744 2745 2746 2747 2748 2749 2750 2751 2752 2753 2754 2755 2756 2757 2758 2759 2760 2761 2762 2763 2764 2765 2766 2767 2768 2769 2770 2771 2772 2773 2774 2775 2776 2777 2778 2779 2780
static int svm = icvRegisterSVMType();

/* Trains an SVM model with parameters selected by cross-validation.

   Each of degree, gamma, coef0, C, nu and p is searched over the geometric
   progression begin, begin*step, begin*step^2, ... while the value is <= end.
   For every parameter the range is chosen as follows:
     - grid pointer is NULL: the value from <model_params> is used as a single fixed
       value (a non-positive value is first replaced by a built-in default);
     - grid is all zeros (min_val == max_val == step == 0): the built-in default range;
     - otherwise: [min_val, max_val] with the given step, which must be > 1.
   Parameters that the selected kernel / SVM type does not use are always held fixed.

   Every combination is rated with cvCrossValidation(). A higher rate wins unless
   cv_params->is_regression is set, in which case a lower rate wins; on equal rates
   the smaller C is preferred. The best combination is written into <model_params>
   and the final model is trained with it via cvTrainSVM().

   Raises CV_StsBadArg when <model_params> or <cross_valid_params> is NULL, when a grid
   is malformed, or when the lower bound of any effective range is not positive
   (multiplying such a bound by the step would never get past the upper bound).
   Returns the trained model, or 0 if an error was raised. */
CV_IMPL CvStatModel*
cvTrainSVM_CrossValidation( const CvMat* train_data, int tflag,
            const CvMat* responses,
            CvStatModelParams* model_params,
            const CvStatModelParams* cross_valid_params,
            const CvMat* comp_idx,
            const CvMat* sample_idx,
            const CvParamGrid* degree_grid,
            const CvParamGrid* gamma_grid,
            const CvParamGrid* coef_grid,
            const CvParamGrid* C_grid,
            const CvParamGrid* nu_grid,
            const CvParamGrid* p_grid )
{
    CvStatModel* svm = 0;

    CV_FUNCNAME("cvTrainSVM_CrossValidation");
    __BEGIN__;

    // built-in default search ranges
    double degree_step = 7,
           g_step      = 15,
           coef_step   = 14,
           C_step      = 20,
           nu_step     = 5,
           p_step      = 7; // all steps must be > 1
    double degree_begin = 0.01, degree_end = 2;
    double g_begin      = 1e-5, g_end      = 0.5;
    double coef_begin   = 0.1,  coef_end   = 300;
    double C_begin      = 0.1,  C_end      = 6000;
    double nu_begin     = 0.01,  nu_end    = 0.4;
    double p_begin      = 0.01, p_end      = 100;

    double rate = 0, gamma = 0, C = 0, degree = 0, coef = 0, p = 0, nu = 0;

    double best_rate    = 0;
    double best_degree  = degree_begin;
    double best_gamma   = g_begin;
    double best_coef    = coef_begin;
    double best_C       = C_begin;
    double best_nu      = nu_begin;
    double best_p       = p_begin;

    CvSVMModelParams svm_params, *psvm_params;
    CvCrossValidationParams* cv_params = (CvCrossValidationParams*)cross_valid_params;
    int svm_type, kernel;
    int is_regression;

    if( !model_params )
        CV_ERROR( CV_StsBadArg, "model_params pointer is NULL" );
    if( !cv_params )
        CV_ERROR( CV_StsBadArg, "cross_valid_params pointer is NULL" );

    // svm_params is a working copy; psvm_params refers to the caller's structure,
    // which receives the best parameters at the end
    svm_params = *(CvSVMModelParams*)model_params;
    psvm_params = (CvSVMModelParams*)model_params;
    svm_type = svm_params.svm_type;
    kernel = svm_params.kernel_type;

    // replace non-positive values by defaults so that they can serve as fixed values
    svm_params.degree = svm_params.degree > 0 ? svm_params.degree : 1;
    svm_params.gamma = svm_params.gamma > 0 ? svm_params.gamma : 1;
    svm_params.coef0 = svm_params.coef0 > 0 ? svm_params.coef0 : 1e-6;
    svm_params.C = svm_params.C > 0 ? svm_params.C : 1;
    svm_params.nu = svm_params.nu > 0 ? svm_params.nu : 1;
    svm_params.p = svm_params.p > 0 ? svm_params.p : 1;

    if( degree_grid )
    {
        if( !(degree_grid->max_val == 0 && degree_grid->min_val == 0 &&
              degree_grid->step == 0) )
        {
            if( degree_grid->min_val > degree_grid->max_val )
                CV_ERROR( CV_StsBadArg,
                "lower bound of grid should not exceed the upper one");
            if( degree_grid->step <= 1 )
                CV_ERROR( CV_StsBadArg, "grid step should be greater than 1" );
            degree_begin = degree_grid->min_val;
            degree_end   = degree_grid->max_val;
            degree_step  = degree_grid->step;
        }
    }
    else
        degree_begin = degree_end = svm_params.degree;

    if( gamma_grid )
    {
        if( !(gamma_grid->max_val == 0 && gamma_grid->min_val == 0 &&
              gamma_grid->step == 0) )
        {
            if( gamma_grid->min_val > gamma_grid->max_val )
                CV_ERROR( CV_StsBadArg,
                "lower bound of grid should not exceed the upper one");
            if( gamma_grid->step <= 1 )
                CV_ERROR( CV_StsBadArg, "grid step should be greater than 1" );
            g_begin = gamma_grid->min_val;
            g_end   = gamma_grid->max_val;
            g_step  = gamma_grid->step;
        }
    }
    else
        g_begin = g_end = svm_params.gamma;

    if( coef_grid )
    {
        if( !(coef_grid->max_val == 0 && coef_grid->min_val == 0 &&
              coef_grid->step == 0) )
        {
            if( coef_grid->min_val > coef_grid->max_val )
                CV_ERROR( CV_StsBadArg,
                "lower bound of grid should not exceed the upper one");
            if( coef_grid->step <= 1 )
                CV_ERROR( CV_StsBadArg, "grid step should be greater than 1" );
            coef_begin = coef_grid->min_val;
            coef_end   = coef_grid->max_val;
            coef_step  = coef_grid->step;
        }
    }
    else
        coef_begin = coef_end = svm_params.coef0;

    if( C_grid )
    {
        if( !(C_grid->max_val == 0 && C_grid->min_val == 0 && C_grid->step == 0))
        {
            if( C_grid->min_val > C_grid->max_val )
                CV_ERROR( CV_StsBadArg,
                "lower bound of grid should not exceed the upper one");
            if( C_grid->step <= 1 )
                CV_ERROR( CV_StsBadArg, "grid step should be greater than 1" );
            C_begin = C_grid->min_val;
            C_end   = C_grid->max_val;
            C_step  = C_grid->step;
        }
    }
    else
        C_begin = C_end = svm_params.C;

    if( nu_grid )
    {
        if(!(nu_grid->max_val == 0 && nu_grid->min_val == 0 && nu_grid->step==0))
        {
            if( nu_grid->min_val > nu_grid->max_val )
                CV_ERROR( CV_StsBadArg,
                "lower bound of grid should not exceed the upper one");
            if( nu_grid->step <= 1 )
                CV_ERROR( CV_StsBadArg, "grid step should be greater than 1" );
            nu_begin = nu_grid->min_val;
            nu_end   = nu_grid->max_val;
            nu_step  = nu_grid->step;
        }
    }
    else
        nu_begin = nu_end = svm_params.nu;

    if( p_grid )
    {
        if( !(p_grid->max_val == 0 && p_grid->min_val == 0 && p_grid->step == 0))
        {
            if( p_grid->min_val > p_grid->max_val )
                CV_ERROR( CV_StsBadArg,
                "lower bound of grid should not exceed the upper one");
            if( p_grid->step <= 1 )
                CV_ERROR( CV_StsBadArg, "grid step should be greater than 1" );
            p_begin = p_grid->min_val;
            p_end   = p_grid->max_val;
            p_step  = p_grid->step;
        }
    }
    else
        p_begin = p_end = svm_params.p;

    // these parameters are not used:
    if( kernel != CvSVM::POLY )
        degree_begin = degree_end = svm_params.degree;

    if( kernel == CvSVM::LINEAR )
        g_begin = g_end = svm_params.gamma;

    if( kernel != CvSVM::POLY && kernel != CvSVM::SIGMOID )
        coef_begin = coef_end = svm_params.coef0;

    if( svm_type == CvSVM::NU_SVC || svm_type == CvSVM::ONE_CLASS )
        C_begin = C_end = svm_params.C;

    if( svm_type == CvSVM::C_SVC || svm_type == CvSVM::EPS_SVR )
        nu_begin = nu_end = svm_params.nu;

    if( svm_type != CvSVM::EPS_SVR )
        p_begin = p_end = svm_params.p;

    // The search multiplies each value by its step (> 1). A lower bound <= 0 never
    // grows past the upper bound, so the loops below would not terminate. The check
    // is done on the effective ranges, i.e. after unused parameters were fixed.
    if( degree_begin <= 0 || g_begin <= 0 || coef_begin <= 0 ||
        C_begin <= 0 || nu_begin <= 0 || p_begin <= 0 )
        CV_ERROR( CV_StsBadArg, "lower bound of a searched grid should be positive" );

    // for regression the rate is minimized, otherwise it is maximized
    is_regression = cv_params->is_regression;
    best_rate = is_regression ? FLT_MAX : 0;

    assert( g_step > 1 && degree_step > 1 && coef_step > 1);
    assert( p_step > 1 && C_step > 1 && nu_step > 1 );

    for( degree = degree_begin; degree <= degree_end; degree *= degree_step )
    {
      svm_params.degree = degree;
      //printf("degree = %.3f\n", degree );
      for( gamma= g_begin; gamma <= g_end; gamma *= g_step )
      {
        svm_params.gamma = gamma;
        //printf("   gamma = %.3f\n", gamma );
        for( coef = coef_begin; coef <= coef_end; coef *= coef_step )
        {
          svm_params.coef0 = coef;
          //printf("      coef = %.3f\n", coef );
          for( C = C_begin; C <= C_end; C *= C_step )
          {
            svm_params.C = C;
            //printf("         C = %.3f\n", C );
            for( nu = nu_begin; nu <= nu_end; nu *= nu_step )
            {
              svm_params.nu = nu;
              //printf("            nu = %.3f\n", nu );
              for( p = p_begin; p <= p_end; p *= p_step )
              {
                int well;
                svm_params.p = p;
                //printf("               p = %.3f\n", p );

                CV_CALL(rate = cvCrossValidation( train_data, tflag, responses, &cvTrainSVM,
                    cross_valid_params, (CvStatModelParams*)&svm_params, comp_idx, sample_idx ));

                // strictly better rate, or an equal rate achieved with a smaller C
                well = (rate > best_rate && !is_regression) ||
                       (rate < best_rate && is_regression);
                if( well || (rate == best_rate && C < best_C) )
                {
                    best_rate   = rate;
                    best_degree = degree;
                    best_gamma  = gamma;
                    best_coef   = coef;
                    best_C      = C;
                    best_nu     = nu;
                    best_p      = p;
                }
                //printf("                  rate = %.2f\n", rate );
              }
            }
          }
        }
      }
    }
    //printf("The best:\nrate = %.2f%% degree = %f gamma = %f coef = %f c = %f nu = %f p = %f\n",
      //  best_rate, best_degree, best_gamma, best_coef, best_C, best_nu, best_p );

    // publish the best combination through the caller's parameter structure
    psvm_params->C      = best_C;
    psvm_params->nu     = best_nu;
    psvm_params->p      = best_p;
    psvm_params->gamma  = best_gamma;
    psvm_params->degree = best_degree;
    psvm_params->coef0  = best_coef;

    // train the final model with the selected parameters
    CV_CALL(svm = cvTrainSVM( train_data, tflag, responses, model_params, comp_idx, sample_idx ));

    __END__;

    return svm;
}

wester committed
2962
#endif
wester committed
2963 2964

/* End of file. */