/* Original code has been submitted by Liu Liu.
   ----------------------------------------------------------------------------------
   * Spill-Tree for Approximate KNN Search
   * Author: Liu Liu
   * mailto: liuliu.1987+opencv@gmail.com
   * Refer to Paper:
   * An Investigation of Practical Approximate Nearest Neighbor Algorithms
   * cvMergeSpillTree TBD
   *
   * Redistribution and use in source and binary forms, with or
   * without modification, are permitted provided that the following
   * conditions are met:
   * 	Redistributions of source code must retain the above
   * 	copyright notice, this list of conditions and the following
   * 	disclaimer.
   * 	Redistributions in binary form must reproduce the above
   * 	copyright notice, this list of conditions and the following
   * 	disclaimer in the documentation and/or other materials
   * 	provided with the distribution.
   * 	The name of Contributor may not be used to endorse or
   * 	promote products derived from this software without
   * 	specific prior written permission.
   *
   * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND
   * CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES,
   * INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
   * MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
   * DISCLAIMED. IN NO EVENT SHALL THE CONTRIBUTORS BE LIABLE FOR ANY
   * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
   * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
   * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA,
   * OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
   * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR
   * TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT
   * OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY
   * OF SUCH DAMAGE.
   */

#include "precomp.hpp"
#include "_featuretree.h"

// One node of the spill tree. The same struct is used for three roles:
//  * internal (split) node: lc/rc are the two subtrees, u/center/r/lb/ub/mp
//    describe the split;
//  * leaf / not-yet-split node: lc..rc is a doubly linked list of cc data-point
//    items that is searched linearly;
//  * data-point item inside such a list: center is a header onto one row of the
//    training matrix, i is that row index, lc/rc are the previous/next items.
struct CvSpillTreeNode
{
  bool leaf; // true for leaves (and data-point items); leaves are searched linearly
  bool spill; // true if this split is overlapping: points projecting into [lb, ub] went to both children (searched with defeatist search)
  CvSpillTreeNode* lc; // left child (projection < split value); head of the point list in a leaf; previous item in a point list
  CvSpillTreeNode* rc; // right child (projection >= split value); tail of the point list in a leaf; next item in a point list
  int cc; // number of data points appended under this node (spilled duplicates included)
  CvMat* u; // unit projection vector (internal nodes only)
  CvMat* center; // internal node: mean of its points; data-point item: header onto its data row
  int i; // row index of the data point in the training matrix (data-point items)
  double r; // largest distance from center to any point of the node
  double ub; // upper bound of the overlap buffer along u
  double lb; // lower bound of the overlap buffer along u
  double mp; // mean projection of the node's points onto u (split value)
  double p; // projection of this data point onto the current parent's u (scratch during construction)
};

// A spill tree built over the rows of one training matrix.
struct CvSpillTree
{
  CvSpillTreeNode* root;
  CvMat** refmat; // one row header per training row; shared by original and spilled items, released with the tree
  int total; // number of training rows
  int naive; // nodes with at most this many points become leaves (searched linearly)
  int type; // matrix type of the training data
  double rho; // if either side of an overlapping split would keep >= rho*cc points, a non-overlapping split is used instead
  double tau; // overlap buffer width as a fraction of the projected spread between the two pivot points
};

// One k-NN candidate: training-row index (-1 marks an empty slot) and its distance.
struct CvResult
{
  int index;
  double distance;
};

// Return the element of the "total"-long list starting at "list" (followed
// through the rc links) whose center is farthest from node->center.
// Ties keep the earliest element; NULL is returned when total <= 0.
static inline CvSpillTreeNode*
icvFarthestNode( CvSpillTreeNode* node,
         CvSpillTreeNode* list,
         int total )
{
  CvSpillTreeNode* best = NULL;
  double best_dist = -1.;
  CvSpillTreeNode* cur = list;
  for ( int remaining = total; remaining > 0; remaining--, cur = cur->rc )
    {
      const double dist = cvNorm( node->center, cur->center );
      if ( dist > best_dist )
        {
          best_dist = dist;
          best = cur;
        }
    }
  return best;
}

// Allocate a shallow copy of "node". Every field is copied verbatim, including
// the lc/rc/u/center pointers, so the clone shares those objects with the source.
static inline CvSpillTreeNode*
icvCloneSpillTreeNode( CvSpillTreeNode* node )
{
  CvSpillTreeNode* copy = (CvSpillTreeNode*)cvAlloc( sizeof(*copy) );
  *copy = *node;
  return copy;
}

// Append "append" at the tail of the doubly linked list held by "node"
// (head in node->lc, tail in node->rc; items are chained forward through rc
// and backward through lc) and increment node->cc.
static inline void
icvAppendSpillTreeNode( CvSpillTreeNode* node,
            CvSpillTreeNode* append )
{
  CvSpillTreeNode* tail = node->rc;
  const bool empty = ( node->lc == NULL );
  append->lc = empty ? NULL : tail;
  append->rc = NULL;
  if ( empty )
    node->lc = append;
  else
    tail->rc = append;
  node->rc = append;
  node->cc++;
}

// Typed pointer to element number "step" of a CV_32F or CV_64F matrix. "step" is
// an element offset from the start of the data, not a byte stride. Any other depth
// yields NULL. Row padding is not taken into account, so an offset computed as
// row*cols is only valid for continuous matrices.
#define _dispatch_mat_ptr(x, step) (CV_MAT_DEPTH((x)->type) == CV_32F ? (void*)((x)->data.fl+(step)) : (CV_MAT_DEPTH((x)->type) == CV_64F ? (void*)((x)->data.db+(step)) : (void*)(0)))

// Recursively split "node", whose cc data points are held in the linked list
// node->lc .. node->rc, until every node has at most tr->naive points.
// d is the feature dimension. For each split:
//   * approximate the farthest pair of points (lnode, rnode) and project every
//     point onto u = normalize(lnode - rnode);
//   * split value mp = mean projection; overlap buffer [lb, ub] = mp +/- ob, where
//     ob = tau/2 * (projected spread between lnode and rnode);
//   * if either side of the overlapping split would keep >= rho*cc points, use a
//     plain split at mp (spill = false); otherwise points inside [lb, ub] go to
//     both children (spill = true), the right child receiving a shallow clone.
// On return, node->lc / node->rc are newly allocated child nodes (or, for a leaf,
// still the point list).
static void
icvDFSInitSpillTreeNode( const CvSpillTree* tr,
             const int d,
             CvSpillTreeNode* node )
{
  if ( node->cc <= tr->naive )
    {
      // already get to a leaf, terminate the recursion.
      node->leaf = true;
      node->spill = false;
      return;
    }

  // random select a node, then find a farthest node from this one, then find a farthest from that one...
  // to approximate the farthest node-pair
  // NOTE: rng_state is a function-static shared by every call and is not guarded by any lock.
  static CvRNG rng_state = cvRNG(0xdeadbeef);
  int rn = cvRandInt( &rng_state ) % node->cc;
  CvSpillTreeNode* lnode = NULL;
  CvSpillTreeNode* rnode = node->lc;
  for ( int i = 0; i < rn; i++ )
    rnode = rnode->rc;
  lnode = icvFarthestNode( rnode, node->lc, node->cc );
  rnode = icvFarthestNode( lnode, node->lc, node->cc );

  // u is the projection vector
  node->u = cvCreateMat( 1, d, tr->type );
  cvSub( lnode->center, rnode->center, node->u );
  cvNormalize( node->u, node->u );

  // find the center of node in hyperspace
  node->center = cvCreateMat( 1, d, tr->type );
  cvZero( node->center );
  CvSpillTreeNode* it = node->lc;
  for ( int i = 0; i < node->cc; i++ )
    {
      cvAdd( it->center, node->center, node->center );
      it = it->rc;
    }
  cvConvertScale( node->center, node->center, 1./node->cc );

  // project every node to "u", and find the mean point "mp"
  // also record the radius r = largest distance from the center to any point
  it = node->lc;
  node->r = -1.;
  node->mp = 0;
  for ( int i = 0; i < node->cc; i++ )
    {
      node->mp += ( it->p = cvDotProduct( it->center, node->u ) );
      double norm = cvNorm( node->center, it->center );
      if ( norm > node->r )
    node->r = norm;
      it = it->rc;
    }
  node->mp = node->mp / node->cc;

  // overlapping buffer and upper bound, lower bound
  // sl / sr: sizes the left / right child would get with the overlapping split;
  // l / r: sizes with the plain split at mp
  double ob = (lnode->p-rnode->p)*tr->tau*.5;
  node->ub = node->mp+ob;
  node->lb = node->mp-ob;
  int sl = 0, l = 0;
  int sr = 0, r = 0;
  it = node->lc;
  for ( int i = 0; i < node->cc; i++ )
    {
      if ( it->p <= node->ub )
    sl++;
      if ( it->p >= node->lb )
    sr++;
      if ( it->p < node->mp )
    l++;
      else
    r++;
      it = it->rc;
    }
  // precision problem, return the node as it is.
  // (all points project to one side of mp, so splitting would not shrink the node)
  if (( l == 0 )||( r == 0 ))
    {
      cvReleaseMat( &(node->u) );
      cvReleaseMat( &(node->center) );
      node->leaf = true;
      node->spill = false;
      return;
    }
  CvSpillTreeNode* lc = (CvSpillTreeNode*)cvAlloc( sizeof(CvSpillTreeNode) );
  memset(lc, 0, sizeof(CvSpillTreeNode));
  CvSpillTreeNode* rc = (CvSpillTreeNode*)cvAlloc( sizeof(CvSpillTreeNode) );
  memset(rc, 0, sizeof(CvSpillTreeNode));
  lc->lc = lc->rc = rc->lc = rc->rc = NULL;
  lc->cc = rc->cc = 0;
  int undo = cvRound(node->cc*tr->rho);
  if (( sl >= undo )||( sr >= undo ))
    {
      // it is not a spill point (defeatist search disabled)
      // the overlapping split would leave a side too large: split at mp instead
      it = node->lc;
      for ( int i = 0; i < node->cc; i++ )
    {
      CvSpillTreeNode* next = it->rc;
      if ( it->p < node->mp )
        icvAppendSpillTreeNode( lc, it );
      else
        icvAppendSpillTreeNode( rc, it );
      it = next;
    }
      node->spill = false;
    } else {
      // a spill point
      // points inside [lb, ub] are placed in both children; the copy in the
      // right child is a shallow clone that shares the same center header
      it = node->lc;
      for ( int i = 0; i < node->cc; i++ )
    {
      CvSpillTreeNode* next = it->rc;
      if ( it->p < node->lb )
        icvAppendSpillTreeNode( lc, it );
      else if ( it->p > node->ub )
        icvAppendSpillTreeNode( rc, it );
      else {
        CvSpillTreeNode* cit = icvCloneSpillTreeNode( it );
        icvAppendSpillTreeNode( lc, it );
        icvAppendSpillTreeNode( rc, cit );
      }
      it = next;
    }
      node->spill = true;
    }
  node->lc = lc;
  node->rc = rc;

  // recursion process
  icvDFSInitSpillTreeNode( tr, d, node->lc );
  icvDFSInitSpillTreeNode( tr, d, node->rc );
}

// Build a spill tree over the rows of "raw_data" (one feature vector per row).
//   naive : nodes with at most this many points become leaves
//   rho   : balance threshold that decides between overlapping and plain splits
//   tau   : overlap buffer width relative to the projected pivot spread
// Each data row is referenced, not copied: a 1 x cols header is created per row
// and stored in tr->refmat, so raw_data must outlive the tree.
static CvSpillTree*
icvCreateSpillTree( const CvMat* raw_data,
            const int naive,
            const double rho,
            const double tau )
{
  int n = raw_data->rows;
  int d = raw_data->cols;

  CvSpillTree* tr = (CvSpillTree*)cvAlloc( sizeof(CvSpillTree) );
  tr->root = (CvSpillTreeNode*)cvAlloc( sizeof(CvSpillTreeNode) );
  memset( tr->root, 0, sizeof(CvSpillTreeNode) );
  tr->refmat = (CvMat**)cvAlloc( sizeof(CvMat*)*n );
  tr->total = n;
  tr->naive = naive;
  tr->rho = rho;
  tr->tau = tau;
  tr->type = raw_data->type;

  // tie a linked list of data-point items to the root: head in root->lc,
  // tail in root->rc, items chained forward by rc and backward by lc
  CvSpillTreeNode* prev = NULL;
  for ( int i = 0; i < n; i++ )
    {
      CvSpillTreeNode* item = (CvSpillTreeNode*)cvAlloc( sizeof(CvSpillTreeNode) );
      memset( item, 0, sizeof(CvSpillTreeNode) );
      item->center = cvCreateMatHeader( 1, d, tr->type );
      // address row i through the byte step so padded / ROI matrices work and
      // any element depth yields a valid pointer
      cvSetData( item->center, raw_data->data.ptr + (size_t)i*raw_data->step, raw_data->step );
      tr->refmat[i] = item->center;
      item->i = i;
      item->leaf = true;
      item->lc = prev;
      if ( prev )
        prev->rc = item;
      else
        tr->root->lc = item;
      prev = item;
    }
  tr->root->rc = prev;
  tr->root->cc = n;
  icvDFSInitSpillTreeNode( tr, d, tr->root );
  return tr;
}

// Sift heap[i] down a k-slot max-heap keyed on distance. Empty slots
// (index == -1) rank above every real entry, so they float to the top and are
// replaced first; an empty slot at i needs no sifting.
static void
icvSpillTreeNodeHeapify( CvResult * heap,
             int i,
             const int k )
{
  if ( heap[i].index == -1 )
    return;
  for (;;)
    {
      const int left = 2*i + 1;
      const int right = left + 1;
      int top = i;
      if ( left < k && heap[left].index == -1 )
        top = left;
      else if ( right < k && heap[right].index == -1 )
        top = right;
      else
        {
          if ( left < k && heap[left].distance > heap[top].distance )
            top = left;
          if ( right < k && heap[right].distance > heap[top].distance )
            top = right;
        }
      if ( top == i )
        break;
      CvResult tmp = heap[i];
      heap[i] = heap[top];
      heap[top] = tmp;
      i = top;
    }
}

// Depth-first k-NN search below "node" for the single query row "desc".
//   tr    : only forwarded to the recursive calls
//   heap  : k-slot max-heap on distance; slots with index == -1 are empty and rank
//           above every real entry, so heap[0] is the current worst candidate and
//           heap[0].index != -1 means the heap is full
//   es    : running count of heap insertions for this query
//   emax  : if > 0, stop descending once *es reaches emax
//   cache : per-training-row "already measured" flags, so points duplicated by
//           overlapping (spill) splits are measured only once
static void
icvSpillTreeDFSearch( CvSpillTree* tr,
              CvSpillTreeNode* node,
              CvResult* heap,
              int* es,
              const CvMat* desc,
              const int k,
              const int emax,
                      bool * cache)
{
  if ((emax > 0)&&( *es >= emax ))
    return;
  double dist, p=0;
  double distance;
  // through overlapping nodes, descend into a single side (no backtracking) when
  // the query projects outside the overlap buffer and that side holds >= k points
  while ( node->spill )
    {
      // defeatist search
      if ( !node->leaf )
    p = cvDotProduct( node->u, desc );
      if ( p < node->lb && node->lc->cc >= k ) // check the number of children larger than k otherwise you'll skip over better neighbor
    node = node->lc;
      else if ( p > node->ub && node->rc->cc >= k )
    node = node->rc;
      else
    break;
      // defensive: both children of a split node are allocated during construction
      if ( NULL == node )
    return;
    }
  if ( node->leaf )
    {
      // a leaf, naive search
      CvSpillTreeNode* it = node->lc;
      for ( int i = 0; i < node->cc; i++ )
        {
          if ( !cache[it->i] )
          {
        distance = cvNorm( it->center, desc );
            cache[it->i] = true;
        // replace the current worst candidate (or an empty slot) and restore the heap
        if (( heap[0].index == -1)||( distance < heap[0].distance ))
          {
                CvResult  current_result;
                current_result.index = it->i;
                current_result.distance = distance;
                heap[0] = current_result;
            icvSpillTreeNodeHeapify( heap, 0, k );
        (*es)++;
          }
          }
          it = it->rc;
    }
      return;
    }
  dist = cvNorm( node->center, desc );
  // impossible case, skip
  // (heap full and every point of the node is at least dist - r away from the
  // query, which already exceeds the worst kept distance)
  if (( heap[0].index != -1 )&&( dist-node->r > heap[0].distance ))
    return;
  p = cvDotProduct( node->u, desc );
  // guided dfs
  // visit the side the query projects onto first, then the other side
  if ( p < node->mp )
    {
      icvSpillTreeDFSearch( tr, node->lc, heap, es, desc, k, emax, cache );
      icvSpillTreeDFSearch( tr, node->rc, heap, es, desc, k, emax, cache );
    } else {
    icvSpillTreeDFSearch( tr, node->rc, heap, es, desc, k, emax, cache );
    icvSpillTreeDFSearch( tr, node->lc, heap, es, desc, k, emax, cache );
    }
}

// Approximate k-nearest-neighbour query for every row of "desc".
// For query row j, results row j receives the indices of the neighbours found
// (training rows), in increasing order of distance, and dist row j the matching
// cvNorm distances. Slots with no neighbour get index -1, and their dist entry
// is left untouched. emax > 0 caps the number of heap insertions per query.
// NOTE(review): assumes results is an int (CV_32S) matrix and dist a double
// (CV_64F) matrix, each with at least k columns and desc->rows rows. Confirm in
// the caller.
static void
icvFindSpillTreeFeatures( CvSpillTree* tr,
              const CvMat* desc,
              CvMat* results,
              CvMat* dist,
              const int k,
              const int emax )
{
  // Compare element types only. The raw type fields also carry the continuity
  // flag, which legitimately differs between a full matrix and a ROI of one.
  assert( CV_MAT_TYPE(desc->type) == CV_MAT_TYPE(tr->type) );
  // the search reads heap[0] unconditionally, so an empty heap is not allowed
  if ( k <= 0 )
    return;
  CvResult* heap = (CvResult*)cvAlloc( k*sizeof(heap[0]) );
  bool* cache = (bool*)cvAlloc( sizeof(bool)*tr->total );
  for ( int j = 0; j < desc->rows; j++ )
    {
      // rows are addressed through the byte step so non-continuous inputs work
      CvMat _desc = cvMat( 1, desc->cols, desc->type, desc->data.ptr + (size_t)j*desc->step );
      for ( int i = 0; i < k; i++ )
        {
          heap[i].index = -1;
          heap[i].distance = -1;
        }
      memset( cache, 0, sizeof(bool)*tr->total );
      int es = 0;
      icvSpillTreeDFSearch( tr, tr->root, heap, &es, &_desc, k, emax, cache );
      // heap-sort the max-heap in place: ascending distance, empty slots last
      CvResult inp;
      for ( int i = k-1; i > 0; i-- )
        {
          CV_SWAP( heap[i], heap[0], inp );
          icvSpillTreeNodeHeapify( heap, 0, i );
        }
      int* rs = (int*)(results->data.ptr + (size_t)j*results->step);
      double* dt = (double*)(dist->data.ptr + (size_t)j*dist->step);
      for ( int i = 0; i < k; i++ )
        {
          if ( heap[i].index != -1 )
            {
              rs[i] = heap[i].index;
              dt[i] = heap[i].distance;
            }
          else
            rs[i] = -1;
        }
    }
  cvFree( &heap );
  cvFree( &cache );
}

// Recursively free "node" and everything it owns.
// A leaf owns the cc list items chained from node->lc. Their center headers are
// not released here; icvReleaseSpillTree releases those through tr->refmat.
// An internal node owns its u/center matrices and both subtrees.
static void
icvDFSReleaseSpillTreeNode( CvSpillTreeNode* node )
{
  if ( !node->leaf )
    {
      cvReleaseMat( &node->u );
      cvReleaseMat( &node->center );
      icvDFSReleaseSpillTreeNode( node->lc );
      icvDFSReleaseSpillTreeNode( node->rc );
    }
  else
    {
      CvSpillTreeNode* item = node->lc;
      int left = node->cc;
      while ( left-- > 0 )
        {
          CvSpillTreeNode* doomed = item;
          item = item->rc;
          cvFree( &doomed );
        }
    }
  cvFree( &node );
}

// Destroy a tree built by icvCreateSpillTree and reset *tr to NULL.
static void
icvReleaseSpillTree( CvSpillTree** tr )
{
  CvSpillTree* tree = *tr;
  // The per-row headers are shared by original and cloned (spilled) list items,
  // so they are released exactly once here rather than while walking the nodes.
  CvMat** headers = tree->refmat;
  for ( int row = tree->total - 1; row >= 0; row-- )
    cvReleaseMat( &headers[row] );
  cvFree( &tree->refmat );
  icvDFSReleaseSpillTreeNode( tree->root );
  cvFree( tr );
}

// CvFeatureTree adapter that owns a CvSpillTree built at construction and
// released at destruction.
class CvSpillTreeWrap : public CvFeatureTree {
  CvSpillTree* tr;

  // The wrapper exclusively owns "tr". A member-wise copy would release the same
  // tree twice, so copying is disabled (declared but never defined).
  CvSpillTreeWrap(const CvSpillTreeWrap&);
  CvSpillTreeWrap& operator=(const CvSpillTreeWrap&);
public:
  CvSpillTreeWrap(const CvMat* raw_data,
          const int naive,
          const double rho,
          const double tau) {
    tr = icvCreateSpillTree(raw_data, naive, rho, tau);
  }
  ~CvSpillTreeWrap() {
    icvReleaseSpillTree(&tr);
  }

  // Query the k nearest neighbours of each row of desc (see icvFindSpillTreeFeatures).
  void FindFeatures(const CvMat* desc, int k, int emax, CvMat* results, CvMat* dist) {
    icvFindSpillTreeFeatures(tr, desc, results, dist, k, emax);
  }
};

// Public factory: build a spill tree over the rows of raw_data and return it
// through the generic CvFeatureTree interface. The object is allocated with
// new and is not retained here.
CvFeatureTree* cvCreateSpillTree( const CvMat* raw_data,
                  const int naive,
                  const double rho,
                  const double tau ) {
  CvFeatureTree* tree = new CvSpillTreeWrap(raw_data, naive, rho, tau);
  return tree;
}