em.cpp 21.3 KB
Newer Older
wester committed
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48
/*M///////////////////////////////////////////////////////////////////////////////////////
//
//  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
//
//  By downloading, copying, installing or using the software you agree to this license.
//  If you do not agree to this license, do not download, install,
//  copy or use the software.
//
//
//                        Intel License Agreement
//                For Open Source Computer Vision Library
//
// Copyright( C) 2000, Intel Corporation, all rights reserved.
// Third party copyrights are property of their respective owners.
//
// Redistribution and use in source and binary forms, with or without modification,
// are permitted provided that the following conditions are met:
//
//   * Redistribution's of source code must retain the above copyright notice,
//     this list of conditions and the following disclaimer.
//
//   * Redistribution's in binary form must reproduce the above copyright notice,
//     this list of conditions and the following disclaimer in the documentation
//     and/or other materials provided with the distribution.
//
//   * The name of Intel Corporation may not be used to endorse or promote products
//     derived from this software without specific prior written permission.
//
// This software is provided by the copyright holders and contributors "as is" and
// any express or implied warranties, including, but not limited to, the implied
// warranties of merchantability and fitness for a particular purpose are disclaimed.
// In no event shall the Intel Corporation or contributors be liable for any direct,
// indirect, incidental, special, exemplary, or consequential damages
//(including, but not limited to, procurement of substitute goods or services;
// loss of use, data, or profits; or business interruption) however caused
// and on any theory of liability, whether in contract, strict liability,
// or tort(including negligence or otherwise) arising in any way out of
// the use of this software, even ifadvised of the possibility of such damage.
//
//M*/

#include "precomp.hpp"

namespace cv
{

// Lower bound applied to every cached covariance eigenvalue (in decomposeCovs()
// and mStep()) before the eigenvalues are inverted, so a degenerate cluster never
// causes a division by zero or an infinite log-determinant.
const double minEigenValue = DBL_EPSILON;

///////////////////////////////////////////////////////////////////////////////////////////////////////

// Stores the mixture hyper-parameters. The iteration cap and the convergence
// threshold are taken from _termCrit only when the corresponding flag is set in
// _termCrit.type; otherwise DEFAULT_MAX_ITERS and 0 are used.
EM::EM(int _nclusters, int _covMatType, const TermCriteria& _termCrit)
{
    const bool hasIterLimit = (_termCrit.type & TermCriteria::MAX_ITER) != 0;
    const bool hasEpsilon = (_termCrit.type & TermCriteria::EPS) != 0;

    nclusters = _nclusters;
    covMatType = _covMatType;
    maxIters = hasIterLimit ? _termCrit.maxCount : DEFAULT_MAX_ITERS;
    epsilon = hasEpsilon ? _termCrit.epsilon : 0.;
}
wester committed
58

wester committed
59 60 61 62
EM::~EM()
{
    //clear();
}
wester committed
63

wester committed
64 65 66 67 68 69
void EM::clear()
{
    trainSamples.release();
    trainProbs.release();
    trainLogLikelihoods.release();
    trainLabels.release();
wester committed
70

wester committed
71 72 73
    weights.release();
    means.release();
    covs.clear();
wester committed
74

wester committed
75 76 77
    covsEigenValues.clear();
    invCovsEigenValues.clear();
    covsRotateMats.clear();
wester committed
78

wester committed
79 80
    logWeightDivDet.release();
}
wester committed
81 82


wester committed
83
bool EM::train(InputArray samples,
wester committed
84 85 86
               OutputArray logLikelihoods,
               OutputArray labels,
               OutputArray probs)
wester committed
87 88 89 90 91
{
    Mat samplesMat = samples.getMat();
    setTrainData(START_AUTO_STEP, samplesMat, 0, 0, 0, 0);
    return doTrain(START_AUTO_STEP, logLikelihoods, labels, probs);
}
wester committed
92

wester committed
93
bool EM::trainE(InputArray samples,
wester committed
94 95 96 97 98 99
                InputArray _means0,
                InputArray _covs0,
                InputArray _weights0,
                OutputArray logLikelihoods,
                OutputArray labels,
                OutputArray probs)
wester committed
100 101 102 103
{
    Mat samplesMat = samples.getMat();
    vector<Mat> covs0;
    _covs0.getMatVector(covs0);
wester committed
104

wester committed
105
    Mat means0 = _means0.getMat(), weights0 = _weights0.getMat();
wester committed
106

wester committed
107 108 109 110
    setTrainData(START_E_STEP, samplesMat, 0, !_means0.empty() ? &means0 : 0,
                 !_covs0.empty() ? &covs0 : 0, !_weights0.empty() ? &weights0 : 0);
    return doTrain(START_E_STEP, logLikelihoods, labels, probs);
}
wester committed
111

wester committed
112
bool EM::trainM(InputArray samples,
wester committed
113 114 115 116
                InputArray _probs0,
                OutputArray logLikelihoods,
                OutputArray labels,
                OutputArray probs)
wester committed
117 118 119
{
    Mat samplesMat = samples.getMat();
    Mat probs0 = _probs0.getMat();
wester committed
120

wester committed
121 122 123
    setTrainData(START_M_STEP, samplesMat, !_probs0.empty() ? &probs0 : 0, 0, 0, 0);
    return doTrain(START_M_STEP, logLikelihoods, labels, probs);
}
wester committed
124 125


wester committed
126 127 128 129
Vec2d EM::predict(InputArray _sample, OutputArray _probs) const
{
    Mat sample = _sample.getMat();
    CV_Assert(isTrained());
wester committed
130

wester committed
131 132 133 134 135 136
    CV_Assert(!sample.empty());
    if(sample.type() != CV_64FC1)
    {
        Mat tmp;
        sample.convertTo(tmp, CV_64FC1);
        sample = tmp;
wester committed
137
    }
wester committed
138
    sample = sample.reshape(1, 1);
wester committed
139

wester committed
140 141
    Mat probs;
    if( _probs.needed() )
wester committed
142
    {
wester committed
143 144 145
        _probs.create(1, nclusters, CV_64FC1);
        probs = _probs.getMat();
    }
wester committed
146

wester committed
147 148
    return computeProbabilities(sample, !probs.empty() ? &probs : 0);
}
wester committed
149

wester committed
150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199
bool EM::isTrained() const
{
    return !means.empty();
}


static
void checkTrainData(int startStep, const Mat& samples,
                    int nclusters, int covMatType, const Mat* probs, const Mat* means,
                    const vector<Mat>* covs, const Mat* weights)
{
    // Check samples.
    CV_Assert(!samples.empty());
    CV_Assert(samples.channels() == 1);

    int nsamples = samples.rows;
    int dim = samples.cols;

    // Check training params.
    CV_Assert(nclusters > 0);
    CV_Assert(nclusters <= nsamples);
    CV_Assert(startStep == EM::START_AUTO_STEP ||
              startStep == EM::START_E_STEP ||
              startStep == EM::START_M_STEP);
    CV_Assert(covMatType == EM::COV_MAT_GENERIC ||
              covMatType == EM::COV_MAT_DIAGONAL ||
              covMatType == EM::COV_MAT_SPHERICAL);

    CV_Assert(!probs ||
        (!probs->empty() &&
         probs->rows == nsamples && probs->cols == nclusters &&
         (probs->type() == CV_32FC1 || probs->type() == CV_64FC1)));

    CV_Assert(!weights ||
        (!weights->empty() &&
         (weights->cols == 1 || weights->rows == 1) && static_cast<int>(weights->total()) == nclusters &&
         (weights->type() == CV_32FC1 || weights->type() == CV_64FC1)));

    CV_Assert(!means ||
        (!means->empty() &&
         means->rows == nclusters && means->cols == dim &&
         means->channels() == 1));

    CV_Assert(!covs ||
        (!covs->empty() &&
         static_cast<int>(covs->size()) == nclusters));
    if(covs)
    {
        const Size covSize(dim, dim);
        for(size_t i = 0; i < covs->size(); i++)
wester committed
200
        {
wester committed
201 202
            const Mat& m = (*covs)[i];
            CV_Assert(!m.empty() && m.size() == covSize && (m.channels() == 1));
wester committed
203 204 205
        }
    }

wester committed
206
    if(startStep == EM::START_E_STEP)
wester committed
207
    {
wester committed
208
        CV_Assert(means);
wester committed
209
    }
wester committed
210
    else if(startStep == EM::START_M_STEP)
wester committed
211
    {
wester committed
212
        CV_Assert(probs);
wester committed
213
    }
wester committed
214
}
wester committed
215

wester committed
216 217 218 219 220 221 222 223
static
void preprocessSampleData(const Mat& src, Mat& dst, int dstType, bool isAlwaysClone)
{
    if(src.type() == dstType && !isAlwaysClone)
        dst = src;
    else
        src.convertTo(dst, dstType);
}
wester committed
224

wester committed
225 226 227 228
static
void preprocessProbability(Mat& probs)
{
    max(probs, 0., probs);
wester committed
229

wester committed
230 231
    const double uniformProbability = (double)(1./probs.cols);
    for(int y = 0; y < probs.rows; y++)
wester committed
232
    {
wester committed
233
        Mat sampleProbs = probs.row(y);
wester committed
234

wester committed
235 236 237 238 239 240
        double maxVal = 0;
        minMaxLoc(sampleProbs, 0, &maxVal);
        if(maxVal < FLT_EPSILON)
            sampleProbs.setTo(uniformProbability);
        else
            normalize(sampleProbs, sampleProbs, 1, 0, NORM_L1);
wester committed
241
    }
wester committed
242 243 244 245 246 247 248 249 250 251 252 253 254 255 256
}

void EM::setTrainData(int startStep, const Mat& samples,
                      const Mat* probs0,
                      const Mat* means0,
                      const vector<Mat>* covs0,
                      const Mat* weights0)
{
    clear();

    checkTrainData(startStep, samples, nclusters, covMatType, probs0, means0, covs0, weights0);

    bool isKMeansInit = (startStep == EM::START_AUTO_STEP) || (startStep == EM::START_E_STEP && (covs0 == 0 || weights0 == 0));
    // Set checked data
    preprocessSampleData(samples, trainSamples, isKMeansInit ? CV_32FC1 : CV_64FC1, false);
wester committed
257

wester committed
258 259
    // set probs
    if(probs0 && startStep == EM::START_M_STEP)
wester committed
260
    {
wester committed
261 262
        preprocessSampleData(*probs0, trainProbs, CV_64FC1, true);
        preprocessProbability(trainProbs);
wester committed
263 264
    }

wester committed
265 266
    // set weights
    if(weights0 && (startStep == EM::START_E_STEP && covs0))
wester committed
267
    {
wester committed
268 269 270 271
        weights0->convertTo(weights, CV_64FC1);
        weights = weights.reshape(1,1);
        preprocessProbability(weights);
    }
wester committed
272

wester committed
273 274 275
    // set means
    if(means0 && (startStep == EM::START_E_STEP/* || startStep == EM::START_AUTO_STEP*/))
        means0->convertTo(means, isKMeansInit ? CV_32FC1 : CV_64FC1);
wester committed
276

wester committed
277 278 279 280 281 282
    // set covs
    if(covs0 && (startStep == EM::START_E_STEP && weights0))
    {
        covs.resize(nclusters);
        for(size_t i = 0; i < covs0->size(); i++)
            (*covs0)[i].convertTo(covs[i], CV_64FC1);
wester committed
283
    }
wester committed
284
}
wester committed
285

wester committed
286 287 288 289 290 291 292 293
void EM::decomposeCovs()
{
    CV_Assert(!covs.empty());
    covsEigenValues.resize(nclusters);
    if(covMatType == EM::COV_MAT_GENERIC)
        covsRotateMats.resize(nclusters);
    invCovsEigenValues.resize(nclusters);
    for(int clusterIndex = 0; clusterIndex < nclusters; clusterIndex++)
wester committed
294
    {
wester committed
295
        CV_Assert(!covs[clusterIndex].empty());
wester committed
296

wester committed
297
        SVD svd(covs[clusterIndex], SVD::MODIFY_A + SVD::FULL_UV);
wester committed
298

wester committed
299
        if(covMatType == EM::COV_MAT_SPHERICAL)
wester committed
300
        {
wester committed
301 302
            double maxSingularVal = svd.w.at<double>(0);
            covsEigenValues[clusterIndex] = Mat(1, 1, CV_64FC1, Scalar(maxSingularVal));
wester committed
303
        }
wester committed
304
        else if(covMatType == EM::COV_MAT_DIAGONAL)
wester committed
305
        {
wester committed
306
            covsEigenValues[clusterIndex] = svd.w;
wester committed
307
        }
wester committed
308
        else //EM::COV_MAT_GENERIC
wester committed
309
        {
wester committed
310 311
            covsEigenValues[clusterIndex] = svd.w;
            covsRotateMats[clusterIndex] = svd.u;
wester committed
312
        }
wester committed
313 314
        max(covsEigenValues[clusterIndex], minEigenValue, covsEigenValues[clusterIndex]);
        invCovsEigenValues[clusterIndex] = 1./covsEigenValues[clusterIndex];
wester committed
315
    }
wester committed
316
}
wester committed
317

wester committed
318 319 320 321 322 323 324 325 326 327 328 329 330
void EM::clusterTrainSamples()
{
    int nsamples = trainSamples.rows;

    // Cluster samples, compute/update means

    // Convert samples and means to 32F, because kmeans requires this type.
    Mat trainSamplesFlt, meansFlt;
    if(trainSamples.type() != CV_32FC1)
        trainSamples.convertTo(trainSamplesFlt, CV_32FC1);
    else
        trainSamplesFlt = trainSamples;
    if(!means.empty())
wester committed
331
    {
wester committed
332 333 334 335 336
        if(means.type() != CV_32FC1)
            means.convertTo(meansFlt, CV_32FC1);
        else
            meansFlt = means;
    }
wester committed
337

wester committed
338 339
    Mat labels;
    kmeans(trainSamplesFlt, nclusters, labels,  TermCriteria(TermCriteria::COUNT, means.empty() ? 10 : 1, 0.5), 10, KMEANS_PP_CENTERS, meansFlt);
wester committed
340

wester committed
341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359
    // Convert samples and means back to 64F.
    CV_Assert(meansFlt.type() == CV_32FC1);
    if(trainSamples.type() != CV_64FC1)
    {
        Mat trainSamplesBuffer;
        trainSamplesFlt.convertTo(trainSamplesBuffer, CV_64FC1);
        trainSamples = trainSamplesBuffer;
    }
    meansFlt.convertTo(means, CV_64FC1);

    // Compute weights and covs
    weights = Mat(1, nclusters, CV_64FC1, Scalar(0));
    covs.resize(nclusters);
    for(int clusterIndex = 0; clusterIndex < nclusters; clusterIndex++)
    {
        Mat clusterSamples;
        for(int sampleIndex = 0; sampleIndex < nsamples; sampleIndex++)
        {
            if(labels.at<int>(sampleIndex) == clusterIndex)
wester committed
360
            {
wester committed
361 362
                const Mat sample = trainSamples.row(sampleIndex);
                clusterSamples.push_back(sample);
wester committed
363 364
            }
        }
wester committed
365 366 367 368 369
        CV_Assert(!clusterSamples.empty());

        calcCovarMatrix(clusterSamples, covs[clusterIndex], means.row(clusterIndex),
            CV_COVAR_NORMAL + CV_COVAR_ROWS + CV_COVAR_USE_AVG + CV_COVAR_SCALE, CV_64FC1);
        weights.at<double>(clusterIndex) = static_cast<double>(clusterSamples.rows)/static_cast<double>(nsamples);
wester committed
370 371
    }

wester committed
372 373
    decomposeCovs();
}
wester committed
374

wester committed
375 376 377
void EM::computeLogWeightDivDet()
{
    CV_Assert(!covsEigenValues.empty());
wester committed
378

wester committed
379 380 381
    Mat logWeights;
    cv::max(weights, DBL_MIN, weights);
    log(weights, logWeights);
wester committed
382

wester committed
383 384
    logWeightDivDet.create(1, nclusters, CV_64FC1);
    // note: logWeightDivDet = log(weight_k) - 0.5 * log(|det(cov_k)|)
wester committed
385

wester committed
386 387 388 389 390 391
    for(int clusterIndex = 0; clusterIndex < nclusters; clusterIndex++)
    {
        double logDetCov = 0.;
        const int evalCount = static_cast<int>(covsEigenValues[clusterIndex].total());
        for(int di = 0; di < evalCount; di++)
            logDetCov += std::log(covsEigenValues[clusterIndex].at<double>(covMatType != EM::COV_MAT_SPHERICAL ? di : 0));
wester committed
392

wester committed
393 394 395
        logWeightDivDet.at<double>(clusterIndex) = logWeights.at<double>(clusterIndex) - 0.5 * logDetCov;
    }
}
wester committed
396

wester committed
397 398 399 400 401 402 403 404 405 406
bool EM::doTrain(int startStep, OutputArray logLikelihoods, OutputArray labels, OutputArray probs)
{
    int dim = trainSamples.cols;
    // Precompute the empty initial train data in the cases of EM::START_E_STEP and START_AUTO_STEP
    if(startStep != EM::START_M_STEP)
    {
        if(covs.empty())
        {
            CV_Assert(weights.empty());
            clusterTrainSamples();
wester committed
407
        }
wester committed
408
    }
wester committed
409

wester committed
410 411 412
    if(!covs.empty() && covsEigenValues.empty() )
    {
        CV_Assert(invCovsEigenValues.empty());
wester committed
413 414 415
        decomposeCovs();
    }

wester committed
416 417 418 419 420
    if(startStep == EM::START_M_STEP)
        mStep();

    double trainLogLikelihood, prevTrainLogLikelihood = 0.;
    for(int iter = 0; ; iter++)
wester committed
421
    {
wester committed
422 423
        eStep();
        trainLogLikelihood = sum(trainLogLikelihoods)[0];
wester committed
424

wester committed
425 426
        if(iter >= maxIters - 1)
            break;
wester committed
427

wester committed
428 429 430 431 432
        double trainLogLikelihoodDelta = trainLogLikelihood - prevTrainLogLikelihood;
        if( iter != 0 &&
            (trainLogLikelihoodDelta < -DBL_EPSILON ||
             trainLogLikelihoodDelta < epsilon * std::fabs(trainLogLikelihood)))
            break;
wester committed
433

wester committed
434
        mStep();
wester committed
435

wester committed
436
        prevTrainLogLikelihood = trainLogLikelihood;
wester committed
437 438
    }

wester committed
439
    if( trainLogLikelihood <= -DBL_MAX/10000. )
wester committed
440
    {
wester committed
441 442 443 444 445 446 447 448 449
        clear();
        return false;
    }

    // postprocess covs
    covs.resize(nclusters);
    for(int clusterIndex = 0; clusterIndex < nclusters; clusterIndex++)
    {
        if(covMatType == EM::COV_MAT_SPHERICAL)
wester committed
450
        {
wester committed
451 452
            covs[clusterIndex].create(dim, dim, CV_64FC1);
            setIdentity(covs[clusterIndex], Scalar(covsEigenValues[clusterIndex].at<double>(0)));
wester committed
453
        }
wester committed
454
        else if(covMatType == EM::COV_MAT_DIAGONAL)
wester committed
455
        {
wester committed
456
            covs[clusterIndex] = Mat::diag(covsEigenValues[clusterIndex]);
wester committed
457
        }
wester committed
458
    }
wester committed
459

wester committed
460 461 462 463 464 465
    if(labels.needed())
        trainLabels.copyTo(labels);
    if(probs.needed())
        trainProbs.copyTo(probs);
    if(logLikelihoods.needed())
        trainLogLikelihoods.copyTo(logLikelihoods);
wester committed
466

wester committed
467 468 469 470
    trainSamples.release();
    trainProbs.release();
    trainLabels.release();
    trainLogLikelihoods.release();
wester committed
471

wester committed
472 473
    return true;
}
wester committed
474

wester committed
475 476 477 478 479 480 481
Vec2d EM::computeProbabilities(const Mat& sample, Mat* probs) const
{
    // L_ik = log(weight_k) - 0.5 * log(|det(cov_k)|) - 0.5 *(x_i - mean_k)' cov_k^(-1) (x_i - mean_k)]
    // q = arg(max_k(L_ik))
    // probs_ik = exp(L_ik - L_iq) / (1 + sum_j!=q (exp(L_ij - L_iq))
    // see Alex Smola's blog http://blog.smola.org/page/2 for
    // details on the log-sum-exp trick
wester committed
482

wester committed
483 484 485 486
    CV_Assert(!means.empty());
    CV_Assert(sample.type() == CV_64FC1);
    CV_Assert(sample.rows == 1);
    CV_Assert(sample.cols == means.cols);
wester committed
487

wester committed
488
    int dim = sample.cols;
wester committed
489

wester committed
490 491 492 493 494
    Mat L(1, nclusters, CV_64FC1);
    int label = 0;
    for(int clusterIndex = 0; clusterIndex < nclusters; clusterIndex++)
    {
        const Mat centeredSample = sample - means.row(clusterIndex);
wester committed
495

wester committed
496 497
        Mat rotatedCenteredSample = covMatType != EM::COV_MAT_GENERIC ?
                centeredSample : centeredSample * covsRotateMats[clusterIndex];
wester committed
498

wester committed
499 500
        double Lval = 0;
        for(int di = 0; di < dim; di++)
wester committed
501
        {
wester committed
502 503 504
            double w = invCovsEigenValues[clusterIndex].at<double>(covMatType != EM::COV_MAT_SPHERICAL ? di : 0);
            double val = rotatedCenteredSample.at<double>(di);
            Lval += w * val * val;
wester committed
505
        }
wester committed
506 507
        CV_DbgAssert(!logWeightDivDet.empty());
        L.at<double>(clusterIndex) = logWeightDivDet.at<double>(clusterIndex) - 0.5 * Lval;
wester committed
508

wester committed
509 510
        if(L.at<double>(clusterIndex) > L.at<double>(label))
            label = clusterIndex;
wester committed
511 512
    }

wester committed
513 514 515 516 517
    double maxLVal = L.at<double>(label);
    Mat expL_Lmax = L; // exp(L_ij - L_iq)
    for(int i = 0; i < L.cols; i++)
        expL_Lmax.at<double>(i) = std::exp(L.at<double>(i) - maxLVal);
    double expDiffSum = sum(expL_Lmax)[0]; // sum_j(exp(L_ij - L_iq))
wester committed
518

wester committed
519 520 521 522 523 524 525
    if(probs)
    {
        probs->create(1, nclusters, CV_64FC1);
        double factor = 1./expDiffSum;
        expL_Lmax *= factor;
        expL_Lmax.copyTo(*probs);
    }
wester committed
526

wester committed
527 528 529
    Vec2d res;
    res[0] = std::log(expDiffSum)  + maxLVal - 0.5 * dim * CV_LOG2PI;
    res[1] = label;
wester committed
530

wester committed
531 532
    return res;
}
wester committed
533

wester committed
534 535 536 537 538 539
void EM::eStep()
{
    // Compute probs_ik from means_k, covs_k and weights_k.
    trainProbs.create(trainSamples.rows, nclusters, CV_64FC1);
    trainLabels.create(trainSamples.rows, 1, CV_32SC1);
    trainLogLikelihoods.create(trainSamples.rows, 1, CV_64FC1);
wester committed
540

wester committed
541
    computeLogWeightDivDet();
wester committed
542

wester committed
543 544
    CV_DbgAssert(trainSamples.type() == CV_64FC1);
    CV_DbgAssert(means.type() == CV_64FC1);
wester committed
545

wester committed
546 547 548 549 550 551
    for(int sampleIndex = 0; sampleIndex < trainSamples.rows; sampleIndex++)
    {
        Mat sampleProbs = trainProbs.row(sampleIndex);
        Vec2d res = computeProbabilities(trainSamples.row(sampleIndex), &sampleProbs);
        trainLogLikelihoods.at<double>(sampleIndex) = res[0];
        trainLabels.at<int>(sampleIndex) = static_cast<int>(res[1]);
wester committed
552
    }
wester committed
553
}
wester committed
554

wester committed
555 556 557 558
void EM::mStep()
{
    // Update means_k, covs_k and weights_k from probs_ik
    int dim = trainSamples.cols;
wester committed
559

wester committed
560 561 562
    // Update weights
    // not normalized first
    reduce(trainProbs, weights, 0, CV_REDUCE_SUM);
wester committed
563

wester committed
564 565 566
    // Update means
    means.create(nclusters, dim, CV_64FC1);
    means = Scalar(0);
wester committed
567

wester committed
568 569 570 571 572 573 574 575 576
    const double minPosWeight = trainSamples.rows * DBL_EPSILON;
    double minWeight = DBL_MAX;
    int minWeightClusterIndex = -1;
    for(int clusterIndex = 0; clusterIndex < nclusters; clusterIndex++)
    {
        if(weights.at<double>(clusterIndex) <= minPosWeight)
            continue;

        if(weights.at<double>(clusterIndex) < minWeight)
wester committed
577
        {
wester committed
578 579
            minWeight = weights.at<double>(clusterIndex);
            minWeightClusterIndex = clusterIndex;
wester committed
580
        }
wester committed
581 582 583 584 585

        Mat clusterMean = means.row(clusterIndex);
        for(int sampleIndex = 0; sampleIndex < trainSamples.rows; sampleIndex++)
            clusterMean += trainProbs.at<double>(sampleIndex, clusterIndex) * trainSamples.row(sampleIndex);
        clusterMean /= weights.at<double>(clusterIndex);
wester committed
586 587
    }

wester committed
588 589 590 591 592 593 594
    // Update covsEigenValues and invCovsEigenValues
    covs.resize(nclusters);
    covsEigenValues.resize(nclusters);
    if(covMatType == EM::COV_MAT_GENERIC)
        covsRotateMats.resize(nclusters);
    invCovsEigenValues.resize(nclusters);
    for(int clusterIndex = 0; clusterIndex < nclusters; clusterIndex++)
wester committed
595
    {
wester committed
596 597
        if(weights.at<double>(clusterIndex) <= minPosWeight)
            continue;
wester committed
598

wester committed
599 600 601 602
        if(covMatType != EM::COV_MAT_SPHERICAL)
            covsEigenValues[clusterIndex].create(1, dim, CV_64FC1);
        else
            covsEigenValues[clusterIndex].create(1, 1, CV_64FC1);
wester committed
603

wester committed
604 605
        if(covMatType == EM::COV_MAT_GENERIC)
            covs[clusterIndex].create(dim, dim, CV_64FC1);
wester committed
606

wester committed
607 608
        Mat clusterCov = covMatType != EM::COV_MAT_GENERIC ?
            covsEigenValues[clusterIndex] : covs[clusterIndex];
wester committed
609

wester committed
610
        clusterCov = Scalar(0);
wester committed
611

wester committed
612 613
        Mat centeredSample;
        for(int sampleIndex = 0; sampleIndex < trainSamples.rows; sampleIndex++)
wester committed
614
        {
wester committed
615
            centeredSample = trainSamples.row(sampleIndex) - means.row(clusterIndex);
wester committed
616

wester committed
617 618
            if(covMatType == EM::COV_MAT_GENERIC)
                clusterCov += trainProbs.at<double>(sampleIndex, clusterIndex) * centeredSample.t() * centeredSample;
wester committed
619 620
            else
            {
wester committed
621 622
                double p = trainProbs.at<double>(sampleIndex, clusterIndex);
                for(int di = 0; di < dim; di++ )
wester committed
623
                {
wester committed
624 625
                    double val = centeredSample.at<double>(di);
                    clusterCov.at<double>(covMatType != EM::COV_MAT_SPHERICAL ? di : 0) += p*val*val;
wester committed
626 627
                }
            }
wester committed
628
        }
wester committed
629

wester committed
630 631
        if(covMatType == EM::COV_MAT_SPHERICAL)
            clusterCov /= dim;
wester committed
632

wester committed
633
        clusterCov /= weights.at<double>(clusterIndex);
wester committed
634

wester committed
635 636
        // Update covsRotateMats for EM::COV_MAT_GENERIC only
        if(covMatType == EM::COV_MAT_GENERIC)
wester committed
637
        {
wester committed
638 639 640
            SVD svd(covs[clusterIndex], SVD::MODIFY_A + SVD::FULL_UV);
            covsEigenValues[clusterIndex] = svd.w;
            covsRotateMats[clusterIndex] = svd.u;
wester committed
641 642
        }

wester committed
643
        max(covsEigenValues[clusterIndex], minEigenValue, covsEigenValues[clusterIndex]);
wester committed
644

wester committed
645 646
        // update invCovsEigenValues
        invCovsEigenValues[clusterIndex] = 1./covsEigenValues[clusterIndex];
wester committed
647 648
    }

wester committed
649
    for(int clusterIndex = 0; clusterIndex < nclusters; clusterIndex++)
wester committed
650
    {
wester committed
651 652 653 654 655 656 657 658 659 660
        if(weights.at<double>(clusterIndex) <= minPosWeight)
        {
            Mat clusterMean = means.row(clusterIndex);
            means.row(minWeightClusterIndex).copyTo(clusterMean);
            covs[minWeightClusterIndex].copyTo(covs[clusterIndex]);
            covsEigenValues[minWeightClusterIndex].copyTo(covsEigenValues[clusterIndex]);
            if(covMatType == EM::COV_MAT_GENERIC)
                covsRotateMats[minWeightClusterIndex].copyTo(covsRotateMats[clusterIndex]);
            invCovsEigenValues[minWeightClusterIndex].copyTo(invCovsEigenValues[clusterIndex]);
        }
wester committed
661 662
    }

wester committed
663 664 665
    // Normalize weights
    weights /= trainSamples.rows;
}
wester committed
666

wester committed
667
void EM::read(const FileNode& fn)
wester committed
668
{
wester committed
669
    Algorithm::read(fn);
wester committed
670

wester committed
671 672
    decomposeCovs();
    computeLogWeightDivDet();
wester committed
673
}
wester committed
674

wester committed
675 676 677
} // namespace cv

/* End of file. */