/*M///////////////////////////////////////////////////////////////////////////////////////
//
//  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
//
//  By downloading, copying, installing or using the software you agree to this license.
//  If you do not agree to this license, do not download, install,
//  copy or use the software.
//
//
//                        Intel License Agreement
//
// Copyright (C) 2000, Intel Corporation, all rights reserved.
// Third party copyrights are property of their respective owners.
//
// Redistribution and use in source and binary forms, with or without modification,
// are permitted provided that the following conditions are met:
//
//   * Redistribution's of source code must retain the above copyright notice,
//     this list of conditions and the following disclaimer.
//
//   * Redistribution's in binary form must reproduce the above copyright notice,
//     this list of conditions and the following disclaimer in the documentation
//     and/or other materials provided with the distribution.
//
//   * The name of Intel Corporation may not be used to endorse or promote products
//     derived from this software without specific prior written permission.
//
// This software is provided by the copyright holders and contributors "as is" and
// any express or implied warranties, including, but not limited to, the implied
// warranties of merchantability and fitness for a particular purpose are disclaimed.
// In no event shall the Intel Corporation or contributors be liable for any direct,
// indirect, incidental, special, exemplary, or consequential damages
// (including, but not limited to, procurement of substitute goods or services;
// loss of use, data, or profits; or business interruption) however caused
// and on any theory of liability, whether in contract, strict liability,
// or tort (including negligence or otherwise) arising in any way out of
// the use of this software, even if advised of the possibility of such damage.
//
//M*/

#include "precomp.hpp"

/****************************************************************************************\
*                          K-Nearest Neighbors Classifier                                *
\****************************************************************************************/

// k Nearest Neighbors
CvKNearest::CvKNearest()
{
    samples = 0;
    clear();
}


CvKNearest::~CvKNearest()
{
    clear();
}


CvKNearest::CvKNearest( const CvMat* _train_data, const CvMat* _responses,
                        const CvMat* _sample_idx, bool _is_regression, int _max_k )
{
    samples = 0;
    train( _train_data, _responses, _sample_idx, _is_regression, _max_k, false );
}


void CvKNearest::clear()
{
    while( samples )
    {
        CvVectors* next_samples = samples->next;
        cvFree( &samples->data.fl );
        cvFree( &samples );
        samples = next_samples;
    }
    var_count = 0;
    total = 0;
    max_k = 0;
}


int CvKNearest::get_max_k() const { return max_k; }

int CvKNearest::get_var_count() const { return var_count; }

bool CvKNearest::is_regression() const { return regression; }

int CvKNearest::get_sample_count() const { return total; }

bool CvKNearest::train( const CvMat* _train_data, const CvMat* _responses,
                        const CvMat* _sample_idx, bool _is_regression,
                        int _max_k, bool _update_base )
{
    bool ok = false;
    CvMat* responses = 0;

    CV_FUNCNAME( "CvKNearest::train" );

    __BEGIN__;

    CvVectors* _samples = 0;
    float** _data = 0;
    int _count = 0, _dims = 0, _dims_all = 0, _rsize = 0;

    if( !_update_base )
        clear();

    // Prepare training data and related parameters.
    // Treat categorical responses as ordered - to prevent class label compression and
    // to enable entering new classes in the updates
    CV_CALL( cvPrepareTrainData( "CvKNearest::train", _train_data, CV_ROW_SAMPLE,
        _responses, CV_VAR_ORDERED, 0, _sample_idx, true, (const float***)&_data,
        &_count, &_dims, &_dims_all, &responses, 0, 0 ));

    if( !responses )
        CV_ERROR( CV_StsNoMem, "Could not allocate memory for responses" );

    if( _update_base && _dims != var_count )
        CV_ERROR( CV_StsBadArg, "The newly added data have different dimensionality" );

    if( !_update_base )
    {
        if( _max_k < 1 )
            CV_ERROR( CV_StsOutOfRange, "max_k must be a positive number" );

        regression = _is_regression;
        var_count = _dims;
        max_k = _max_k;
    }

    _rsize = _count*sizeof(float);
    CV_CALL( _samples = (CvVectors*)cvAlloc( sizeof(*_samples) + _rsize ));
    _samples->next = samples;
    _samples->type = CV_32F;
    _samples->data.fl = _data;
    _samples->count = _count;
    total += _count;

    samples = _samples;
    memcpy( _samples + 1, responses->data.fl, _rsize );

    ok = true;

    __END__;

    if( responses && responses->data.ptr != _responses->data.ptr )
        cvReleaseMat(&responses);

    return ok;
}



void CvKNearest::find_neighbors_direct( const CvMat* _samples, int k, int start, int end,
                    float* neighbor_responses, const float** neighbors, float* dist ) const
{
    int i, j, count = end - start, k1 = 0, k2 = 0, d = var_count;
    CvVectors* s = samples;

    for( ; s != 0; s = s->next )
    {
        int n = s->count;
        for( j = 0; j < n; j++ )
        {
            for( i = 0; i < count; i++ )
            {
                double sum = 0;
                Cv32suf si;
                const float* v = s->data.fl[j];
                const float* u = (float*)(_samples->data.ptr + _samples->step*(start + i));
                Cv32suf* dd = (Cv32suf*)(dist + i*k);
                float* nr;
                const float** nn;
                int t, ii, ii1;

                for( t = 0; t <= d - 4; t += 4 )
                {
                    double t0 = u[t] - v[t], t1 = u[t+1] - v[t+1];
                    double t2 = u[t+2] - v[t+2], t3 = u[t+3] - v[t+3];
                    sum += t0*t0 + t1*t1 + t2*t2 + t3*t3;
                }

                for( ; t < d; t++ )
                {
                    double t0 = u[t] - v[t];
                    sum += t0*t0;
                }

                si.f = (float)sum;
                for( ii = k1-1; ii >= 0; ii-- )
                    if( si.i > dd[ii].i )
                        break;
                if( ii >= k-1 )
                    continue;

                nr = neighbor_responses + i*k;
                nn = neighbors ? neighbors + (start + i)*k : 0;
                for( ii1 = k2 - 1; ii1 > ii; ii1-- )
                {
                    dd[ii1+1].i = dd[ii1].i;
                    nr[ii1+1] = nr[ii1];
                    if( nn ) nn[ii1+1] = nn[ii1];
                }
                dd[ii+1].i = si.i;
                nr[ii+1] = ((float*)(s + 1))[j];
                if( nn )
                    nn[ii+1] = v;
            }
            k1 = MIN( k1+1, k );
            k2 = MIN( k1, k-1 );
        }
    }
}


float CvKNearest::write_results( int k, int k1, int start, int end,
    const float* neighbor_responses, const float* dist,
    CvMat* _results, CvMat* _neighbor_responses,
    CvMat* _dist, Cv32suf* sort_buf ) const
{
    float result = 0.f;
    int i, j, j1, count = end - start;
    double inv_scale = 1./k1;
    int rstep = _results && !CV_IS_MAT_CONT(_results->type) ? _results->step/sizeof(result) : 1;

    for( i = 0; i < count; i++ )
    {
        const Cv32suf* nr = (const Cv32suf*)(neighbor_responses + i*k);
        float* dst;
        float r;
        if( _results || start+i == 0 )
        {
            if( regression )
            {
                double s = 0;
                for( j = 0; j < k1; j++ )
                    s += nr[j].f;
                r = (float)(s*inv_scale);
            }
            else
            {
                int prev_start = 0, best_count = 0, cur_count;
                Cv32suf best_val;

                for( j = 0; j < k1; j++ )
                    sort_buf[j].i = nr[j].i;

                for( j = k1-1; j > 0; j-- )
                {
                    bool swap_fl = false;
                    for( j1 = 0; j1 < j; j1++ )
                        if( sort_buf[j1].i > sort_buf[j1+1].i )
                        {
                            int t;
                            CV_SWAP( sort_buf[j1].i, sort_buf[j1+1].i, t );
                            swap_fl = true;
                        }
                    if( !swap_fl )
                        break;
                }

                best_val.i = 0;
                for( j = 1; j <= k1; j++ )
                    if( j == k1 || sort_buf[j].i != sort_buf[j-1].i )
                    {
                        cur_count = j - prev_start;
                        if( best_count < cur_count )
                        {
                            best_count = cur_count;
                            best_val.i = sort_buf[j-1].i;
                        }
                        prev_start = j;
                    }
                r = best_val.f;
            }

            if( start+i == 0 )
                result = r;

            if( _results )
                _results->data.fl[(start + i)*rstep] = r;
        }

        if( _neighbor_responses )
        {
            dst = (float*)(_neighbor_responses->data.ptr +
                (start + i)*_neighbor_responses->step);
            for( j = 0; j < k1; j++ )
                dst[j] = nr[j].f;
            for( ; j < k; j++ )
                dst[j] = 0.f;
        }

        if( _dist )
        {
            dst = (float*)(_dist->data.ptr + (start + i)*_dist->step);
            for( j = 0; j < k1; j++ )
                dst[j] = dist[j + i*k];
            for( ; j < k; j++ )
                dst[j] = 0.f;
        }
    }

    return result;
}

struct P1 : cv::ParallelLoopBody {
  P1(const CvKNearest* _pointer, int _buf_sz, int _k, const CvMat* __samples, const float** __neighbors,
     int _k1, CvMat* __results, CvMat* __neighbor_responses, CvMat* __dist, float* _result)
  {
    pointer = _pointer;
    k = _k;
    _samples = __samples;
    _neighbors = __neighbors;
    k1 = _k1;
    _results = __results;
    _neighbor_responses = __neighbor_responses;
    _dist = __dist;
    result = _result;
    buf_sz = _buf_sz;
  }

  const CvKNearest* pointer;
  int k;
  const CvMat* _samples;
  const float** _neighbors;
  int k1;
  CvMat* _results;
  CvMat* _neighbor_responses;
  CvMat* _dist;
  float* result;
  int buf_sz;

  void operator()( const cv::Range& range ) const
  {
    cv::AutoBuffer<float> buf(buf_sz);
    for(int i = range.start; i < range.end; i += 1 )
    {
        float* neighbor_responses = &buf[0];
        float* dist = neighbor_responses + 1*k;
        Cv32suf* sort_buf = (Cv32suf*)(dist + 1*k);

        pointer->find_neighbors_direct( _samples, k, i, i + 1,
                    neighbor_responses, _neighbors, dist );

        float r = pointer->write_results( k, k1, i, i + 1, neighbor_responses, dist,
                                 _results, _neighbor_responses, _dist, sort_buf );

        if( i == 0 )
            *result = r;
    }
  }

};

float CvKNearest::find_nearest( const CvMat* _samples, int k, CvMat* _results,
    const float** _neighbors, CvMat* _neighbor_responses, CvMat* _dist ) const
{
    float result = 0.f;
    const int max_blk_count = 128, max_buf_sz = 1 << 12;

    if( !samples )
        CV_Error( CV_StsError, "The search tree must be constructed first using train method" );

    if( !CV_IS_MAT(_samples) ||
        CV_MAT_TYPE(_samples->type) != CV_32FC1 ||
        _samples->cols != var_count )
        CV_Error( CV_StsBadArg, "Input samples must be floating-point matrix (<num_samples>x<var_count>)" );

    if( _results && (!CV_IS_MAT(_results) ||
        (_results->cols != 1 && _results->rows != 1) ||
        _results->cols + _results->rows - 1 != _samples->rows) )
        CV_Error( CV_StsBadArg,
        "The results must be 1d vector containing as much elements as the number of samples" );

    if( _results && CV_MAT_TYPE(_results->type) != CV_32FC1 &&
        (CV_MAT_TYPE(_results->type) != CV_32SC1 || regression))
        CV_Error( CV_StsUnsupportedFormat,
        "The results must be floating-point or integer (in case of classification) vector" );

    if( k < 1 || k > max_k )
        CV_Error( CV_StsOutOfRange, "k must be within 1..max_k range" );

    if( _neighbor_responses )
    {
        if( !CV_IS_MAT(_neighbor_responses) || CV_MAT_TYPE(_neighbor_responses->type) != CV_32FC1 ||
            _neighbor_responses->rows != _samples->rows || _neighbor_responses->cols != k )
            CV_Error( CV_StsBadArg,
            "The neighbor responses (if present) must be floating-point matrix of <num_samples> x <k> size" );
    }

    if( _dist )
    {
        if( !CV_IS_MAT(_dist) || CV_MAT_TYPE(_dist->type) != CV_32FC1 ||
            _dist->rows != _samples->rows || _dist->cols != k )
            CV_Error( CV_StsBadArg,
            "The distances from the neighbors (if present) must be floating-point matrix of <num_samples> x <k> size" );
    }

    int count = _samples->rows;
    int count_scale = k*2;
    int blk_count0 = MIN( count, max_blk_count );
    int buf_sz = MIN( blk_count0 * count_scale, max_buf_sz );
    blk_count0 = MAX( buf_sz/count_scale, 1 );
    blk_count0 += blk_count0 % 2;
    blk_count0 = MIN( blk_count0, count );
    buf_sz = blk_count0 * count_scale + k;
    int k1 = get_sample_count();
    k1 = MIN( k1, k );

    cv::parallel_for_(cv::Range(0, count), P1(this, buf_sz, k, _samples, _neighbors, k1,
                                             _results, _neighbor_responses, _dist, &result)
    );

    return result;
}


using namespace cv;

CvKNearest::CvKNearest( const Mat& _train_data, const Mat& _responses,
                       const Mat& _sample_idx, bool _is_regression, int _max_k )
{
    samples = 0;
    train(_train_data, _responses, _sample_idx, _is_regression, _max_k, false );
}

bool CvKNearest::train( const Mat& _train_data, const Mat& _responses,
                        const Mat& _sample_idx, bool _is_regression,
                        int _max_k, bool _update_base )
{
    CvMat tdata = _train_data, responses = _responses, sidx = _sample_idx;

    return train(&tdata, &responses, sidx.data.ptr ? &sidx : 0, _is_regression, _max_k, _update_base );
}


float CvKNearest::find_nearest( const Mat& _samples, int k, Mat* _results,
                                const float** _neighbors, Mat* _neighbor_responses,
                                Mat* _dist ) const
{
    CvMat s = _samples, results, *presults = 0, nresponses, *pnresponses = 0, dist, *pdist = 0;

    if( _results )
    {
        if(!(_results->data && (_results->type() == CV_32F ||
            (_results->type() == CV_32S && regression)) &&
             (_results->cols == 1 || _results->rows == 1) &&
             _results->cols + _results->rows - 1 == _samples.rows) )
            _results->create(_samples.rows, 1, CV_32F);
        presults = &(results = *_results);
    }

    if( _neighbor_responses )
    {
        if(!(_neighbor_responses->data && _neighbor_responses->type() == CV_32F &&
             _neighbor_responses->cols == k && _neighbor_responses->rows == _samples.rows) )
            _neighbor_responses->create(_samples.rows, k, CV_32F);
        pnresponses = &(nresponses = *_neighbor_responses);
    }

    if( _dist )
    {
        if(!(_dist->data && _dist->type() == CV_32F &&
             _dist->cols == k && _dist->rows == _samples.rows) )
            _dist->create(_samples.rows, k, CV_32F);
        pdist = &(dist = *_dist);
    }

    return find_nearest(&s, k, presults, _neighbors, pnresponses, pdist );
}


float CvKNearest::find_nearest( const cv::Mat& _samples, int k, CV_OUT cv::Mat& results,
                                CV_OUT cv::Mat& neighborResponses, CV_OUT cv::Mat& dists) const
{
    return find_nearest(_samples, k, &results, 0, &neighborResponses, &dists);
}

/* End of file */