/* Original code has been submitted by Liu Liu.
   ----------------------------------------------------------------------------------
 * Spill-Tree for Approximate KNN Search
 * Author: Liu Liu
 * mailto: liuliu.1987+opencv@gmail.com
 * Refer to Paper:
 * An Investigation of Practical Approximate Nearest Neighbor Algorithms
 * cvMergeSpillTree TBD
 *
 *    Redistribution and use in source and binary forms, with or
 *    without modification, are permitted provided that the following
 *    conditions are met:
 *    Redistributions of source code must retain the above
 *    copyright notice, this list of conditions and the following
 *    disclaimer.
 *    Redistributions in binary form must reproduce the above
 *    copyright notice, this list of conditions and the following
 *    disclaimer in the documentation and/or other materials
 *    provided with the distribution.
 *    The name of Contributor may not be used to endorse or
 *    promote products derived from this software without
 *    specific prior written permission.
 *
 *    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND
 *    CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES,
 *    INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
 *    MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 *    DISCLAIMED.  IN NO EVENT SHALL THE CONTRIBUTORS BE LIABLE FOR ANY
 *    DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
 *    CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
 *    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA,
 *    OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
 *    THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR
 *    TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT
 *    OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY
 *    OF SUCH DAMAGE.
 */

#include "precomp.hpp"
#include "_featuretree.h"

/*
 * A node of the spill tree.  The same struct plays three roles:
 *  - point node: one data row; lives in a doubly linked list where
 *    lc = previous item and rc = next item;
 *  - leaf node:  owns a linked list of point nodes (lc = head, rc = tail,
 *    cc = list length);
 *  - internal node: lc / rc are the left / right subtrees.
 */
struct CvSpillTreeNode
{
    bool leaf;   // true if this node stores its points directly (no subtrees)
    bool spill;  // true if the children overlap (defeatist search is used here)
    CvSpillTreeNode* lc; // left child (projection < split), or list head / previous item
    CvSpillTreeNode* rc; // right child (projection > split), or list tail / next item
    int cc;      // number of points stored under this node
    CvMat* u;    // unit projection vector (internal nodes only)
    CvMat* center; // point nodes: header over the data row; internal nodes: mean of the points
    int i;       // point nodes: row index of the point in the original data
    double r;    // largest distance from center to any point under this node
    double ub;   // upper bound of the overlapping buffer (projection space)
    double lb;   // lower bound of the overlapping buffer (projection space)
    double mp;   // mean projection of the points onto u (split value)
    double p;    // point nodes: projection onto the parent's u (scratch value)
};

struct CvSpillTree
{
    CvSpillTreeNode* root;
    CvMat** refmat; // per-row matrix headers referencing the caller's data
    int total;      // total number of points (rows)
    int naive;      // nodes with at most this many points become leaves (naive search)
    int type;       // element type (CV_MAT_TYPE) of the indexed data
    double rho;     // overlapping split is rejected if a child would get >= rho*cc points
    double tau;     // width of the overlapping buffer relative to the split pair distance
};

struct CvResult
{
    int index;
    double distance;
};

// Returns the node among the first `total` items of the linked list `list`
// whose center is farthest from `node`'s center.
static inline CvSpillTreeNode*
icvFarthestNode( CvSpillTreeNode* node,
                 CvSpillTreeNode* list,
                 int total )
{
    double farthest = -1.;
    CvSpillTreeNode* result = NULL;
    for ( int i = 0; i < total; i++ )
    {
        double norm = cvNorm( node->center, list->center );
        if ( norm > farthest )
        {
            farthest = norm;
            result = list;
        }
        list = list->rc;
    }
    return result;
}

// Returns a shallow copy of `node` (matrix pointers are shared, not duplicated).
static inline CvSpillTreeNode*
icvCloneSpillTreeNode( CvSpillTreeNode* node )
{
    CvSpillTreeNode* result = (CvSpillTreeNode*)cvAlloc( sizeof(CvSpillTreeNode) );
    memcpy( result, node, sizeof(CvSpillTreeNode) );
    return result;
}

// Appends point node `append` to the tail of `node`'s linked list
// (node->lc = head, node->rc = tail) and increments node->cc.
static inline void
icvAppendSpillTreeNode( CvSpillTreeNode* node,
                        CvSpillTreeNode* append )
{
    if ( node->lc == NULL )
    {
        node->lc = node->rc = append;
        node->lc->lc = node->rc->rc = NULL;
    } else {
        append->lc = node->rc;
        append->rc = NULL;
        node->rc->rc = append;
        node->rc = append;
    }
    node->cc++;
}

// Returns a pointer to the first element of row `row` of `mat`.  The row
// stride (mat->step) is honoured, so non-continuous matrices (ROIs) and
// multi-channel data are addressed correctly.
static inline void*
icvRowPtr( const CvMat* mat, int row )
{
    return mat->data.ptr + (size_t)row * mat->step;
}

// Recursively splits `node` (whose points are in its linked list) until every
// leaf holds at most tr->naive points.  d is the dimensionality of the data.
static void
icvDFSInitSpillTreeNode( const CvSpillTree* tr,
                         const int d,
                         CvSpillTreeNode* node )
{
    if ( node->cc <= tr->naive )
    {
        // few enough points: make this a leaf and stop the recursion
        node->leaf = true;
        node->spill = false;
        return;
    }

    // Approximate the farthest pair of points: pick a random point, take the
    // point farthest from it, then the point farthest from that one.
    static CvRNG rng_state = cvRNG(0xdeadbeef);
    int rn = cvRandInt( &rng_state ) % node->cc;
    CvSpillTreeNode* lnode = NULL;
    CvSpillTreeNode* rnode = node->lc;
    for ( int i = 0; i < rn; i++ )
        rnode = rnode->rc;
    lnode = icvFarthestNode( rnode, node->lc, node->cc );
    rnode = icvFarthestNode( lnode, node->lc, node->cc );

    // u: unit vector pointing from rnode to lnode, the split direction
    node->u = cvCreateMat( 1, d, tr->type );
    cvSub( lnode->center, rnode->center, node->u );
    cvNormalize( node->u, node->u );

    // center: mean of all points under this node
    node->center = cvCreateMat( 1, d, tr->type );
    cvZero( node->center );
    CvSpillTreeNode* it = node->lc;
    for ( int i = 0; i < node->cc; i++ )
    {
        cvAdd( it->center, node->center, node->center );
        it = it->rc;
    }
    cvConvertScale( node->center, node->center, 1./node->cc );

    // Project every point onto u, accumulate the mean projection mp and
    // record the radius r (largest distance from center).
    it = node->lc;
    node->r = -1.;
    node->mp = 0;
    for ( int i = 0; i < node->cc; i++ )
    {
        node->mp += ( it->p = cvDotProduct( it->center, node->u ) );
        double norm = cvNorm( node->center, it->center );
        if ( norm > node->r )
            node->r = norm;
        it = it->rc;
    }
    node->mp = node->mp / node->cc;

    // Overlapping buffer [lb, ub] centred on mp; its half-width is
    // tau/2 times the projected distance of the split pair.
    double ob = (lnode->p-rnode->p)*tr->tau*.5;
    node->ub = node->mp+ob;
    node->lb = node->mp-ob;

    // sl / sr: points the left / right child would receive with overlap;
    // l / r: points they would receive with a plain split at mp.
    int sl = 0, l = 0;
    int sr = 0, r = 0;
    it = node->lc;
    for ( int i = 0; i < node->cc; i++ )
    {
        if ( it->p <= node->ub )
            sl++;
        if ( it->p >= node->lb )
            sr++;
        if ( it->p < node->mp )
            l++;
        else
            r++;
        it = it->rc;
    }

    // Precision problem (all points project to the same side, e.g. duplicate
    // points): keep the node as a leaf.
    if (( l == 0 )||( r == 0 ))
    {
        cvReleaseMat( &(node->u) );
        cvReleaseMat( &(node->center) );
        node->leaf = true;
        node->spill = false;
        return;
    }

    CvSpillTreeNode* lc = (CvSpillTreeNode*)cvAlloc( sizeof(CvSpillTreeNode) );
    memset( lc, 0, sizeof(CvSpillTreeNode) );
    CvSpillTreeNode* rc = (CvSpillTreeNode*)cvAlloc( sizeof(CvSpillTreeNode) );
    memset( rc, 0, sizeof(CvSpillTreeNode) );
    lc->lc = lc->rc = rc->lc = rc->rc = NULL;
    lc->cc = rc->cc = 0;

    int undo = cvRound(node->cc*tr->rho);
    // Use a plain (non-overlapping) split when an overlapping split would send
    // too many points (>= rho*cc) to one child.  The extra checks against
    // node->cc guarantee that every child is strictly smaller than this node
    // even when rho > 1; otherwise the recursion could never terminate.
    if (( sl >= undo )||( sr >= undo )||( sl >= node->cc )||( sr >= node->cc ))
    {
        // not a spill point (defeatist search disabled)
        it = node->lc;
        for ( int i = 0; i < node->cc; i++ )
        {
            CvSpillTreeNode* next = it->rc;
            if ( it->p < node->mp )
                icvAppendSpillTreeNode( lc, it );
            else
                icvAppendSpillTreeNode( rc, it );
            it = next;
        }
        node->spill = false;
    } else {
        // a spill point: points inside the buffer go to both children
        it = node->lc;
        for ( int i = 0; i < node->cc; i++ )
        {
            CvSpillTreeNode* next = it->rc;
            if ( it->p < node->lb )
                icvAppendSpillTreeNode( lc, it );
            else if ( it->p > node->ub )
                icvAppendSpillTreeNode( rc, it );
            else {
                // the clone shares the point's center and index, so the
                // search cache prevents it from being reported twice
                CvSpillTreeNode* cit = icvCloneSpillTreeNode( it );
                icvAppendSpillTreeNode( lc, it );
                icvAppendSpillTreeNode( rc, cit );
            }
            it = next;
        }
        node->spill = true;
    }
    node->lc = lc;
    node->rc = rc;

    // recursion process
    icvDFSInitSpillTreeNode( tr, d, node->lc );
    icvDFSInitSpillTreeNode( tr, d, node->rc );
}

// Builds a spill tree over the rows of raw_data (CV_32F or CV_64F).  The tree
// references raw_data's memory, which must outlive the tree.
static CvSpillTree*
icvCreateSpillTree( const CvMat* raw_data,
                    const int naive,
                    const double rho,
                    const double tau )
{
    CV_Assert( raw_data != NULL && raw_data->rows > 0 && raw_data->cols > 0 );
    // only single- and double-precision floating point data are supported
    CV_Assert( CV_MAT_DEPTH(raw_data->type) == CV_32F ||
               CV_MAT_DEPTH(raw_data->type) == CV_64F );

    int n = raw_data->rows;
    int d = raw_data->cols;

    CvSpillTree* tr = (CvSpillTree*)cvAlloc( sizeof(CvSpillTree) );
    tr->root = (CvSpillTreeNode*)cvAlloc( sizeof(CvSpillTreeNode) );
    memset( tr->root, 0, sizeof(CvSpillTreeNode) );
    tr->refmat = (CvMat**)cvAlloc( sizeof(CvMat*)*n );
    tr->total = n;
    tr->naive = naive;
    tr->rho = rho;
    tr->tau = tau;
    // keep only the element type; header flags (e.g. continuity) are irrelevant
    tr->type = CV_MAT_TYPE(raw_data->type);

    // Tie a linked list of point nodes (lc = previous, rc = next), one per
    // row of raw_data, to the root (root->lc = head, root->rc = tail).
    CvSpillTreeNode* prev = NULL;
    for ( int i = 0; i < n; i++ )
    {
        CvSpillTreeNode* newnode = (CvSpillTreeNode*)cvAlloc( sizeof(CvSpillTreeNode) );
        memset( newnode, 0, sizeof(CvSpillTreeNode) );
        // header only: the point data remains owned by raw_data
        newnode->center = cvCreateMatHeader( 1, d, tr->type );
        cvSetData( newnode->center, icvRowPtr( raw_data, i ), raw_data->step );
        tr->refmat[i] = newnode->center;
        newnode->i = i;
        newnode->leaf = true;
        newnode->lc = prev;
        newnode->rc = NULL;
        if ( prev )
            prev->rc = newnode;
        else
            tr->root->lc = newnode;
        prev = newnode;
    }
    tr->root->rc = prev;
    tr->root->cc = n;

    icvDFSInitSpillTreeNode( tr, d, tr->root );
    return tr;
}

// Sift-down for a max-heap of k results keyed on distance.  Entries with
// index == -1 are empty slots and rank above every real result.
static void
icvSpillTreeNodeHeapify( CvResult * heap,
                         int i,
                         const int k )
{
    if ( heap[i].index == -1 )
        return;
    int l, r, largest = i;
    CvResult inp;
    do {
        i = largest;
        r = (i+1)<<1;
        l = r-1;
        if (( l < k )&&( heap[l].index == -1 ))
            largest = l;
        else if (( r < k )&&( heap[r].index == -1 ))
            largest = r;
        else {
            if (( l < k )&&( heap[l].distance > heap[i].distance ))
                largest = l;
            if (( r < k )&&( heap[r].distance > heap[largest].distance ))
                largest = r;
        }
        if ( largest != i )
            CV_SWAP( heap[largest], heap[i], inp );
    } while ( largest != i );
}

// Depth-first k-NN search from `node`.  heap holds the current k best
// results (heap[0] is the worst one), es counts heap insertions (the search
// stops once it reaches emax when emax > 0), and cache marks points whose
// distance has already been computed.
static void
icvSpillTreeDFSearch( CvSpillTree* tr,
                      CvSpillTreeNode* node,
                      CvResult* heap,
                      int* es,
                      const CvMat* desc,
                      const int k,
                      const int emax,
                      bool * cache )
{
    if (( emax > 0 )&&( *es >= emax ))
        return;
    double dist, p = 0;
    double distance;
    // Defeatist descent: while the query lies clearly on one side of a spill
    // node's overlapping buffer, follow only that child, but only if it holds
    // at least k points; otherwise better neighbours could be skipped.
    while ( node->spill )
    {
        if ( !node->leaf )
            p = cvDotProduct( node->u, desc );
        if ( p < node->lb && node->lc->cc >= k )
            node = node->lc;
        else if ( p > node->ub && node->rc->cc >= k )
            node = node->rc;
        else
            break;
        if ( NULL == node )
            return;
    }

    if ( node->leaf )
    {
        // a leaf, naive search
        CvSpillTreeNode* it = node->lc;
        for ( int i = 0; i < node->cc; i++ )
        {
            if ( !cache[it->i] )
            {
                distance = cvNorm( it->center, desc );
                cache[it->i] = true;
                if (( heap[0].index == -1 )||( distance < heap[0].distance ))
                {
                    CvResult current_result;
                    current_result.index = it->i;
                    current_result.distance = distance;
                    heap[0] = current_result;
                    icvSpillTreeNodeHeapify( heap, 0, k );
                    (*es)++;
                }
            }
            it = it->rc;
        }
        return;
    }

    dist = cvNorm( node->center, desc );
    // Triangle inequality: no point in this ball can beat the current worst
    // result, so skip the whole subtree.
    if (( heap[0].index != -1 )&&( dist-node->r > heap[0].distance ))
        return;

    p = cvDotProduct( node->u, desc );
    // guided dfs: visit the child on the query's side of the split first
    if ( p < node->mp )
    {
        icvSpillTreeDFSearch( tr, node->lc, heap, es, desc, k, emax, cache );
        icvSpillTreeDFSearch( tr, node->rc, heap, es, desc, k, emax, cache );
    } else {
        icvSpillTreeDFSearch( tr, node->rc, heap, es, desc, k, emax, cache );
        icvSpillTreeDFSearch( tr, node->lc, heap, es, desc, k, emax, cache );
    }
}

// For every row of desc, writes the indices (into `results`, int elements)
// and distances (into `dist`, double elements) of up to k nearest neighbours,
// closest first.  Slots without a neighbour get index -1; their distance is
// left untouched.
static void
icvFindSpillTreeFeatures( CvSpillTree* tr,
                          const CvMat* desc,
                          CvMat* results,
                          CvMat* dist,
                          const int k,
                          const int emax )
{
    // the query element type must match the indexed data (header flags are
    // ignored); at least one neighbour slot is required by the heap logic
    CV_Assert( CV_MAT_TYPE(desc->type) == tr->type );
    CV_Assert( k > 0 );

    CvResult* heap = (CvResult*)cvAlloc( k*sizeof(heap[0]) );
    bool* cache = (bool*)cvAlloc( sizeof(bool)*tr->total );
    for ( int j = 0; j < desc->rows; j++ )
    {
        CvMat _desc = cvMat( 1, desc->cols, desc->type, icvRowPtr( desc, j ) );
        for ( int i = 0; i < k; i++ )
        {
            CvResult current;
            current.index = -1;
            current.distance = -1;
            heap[i] = current;
        }
        memset( cache, 0, sizeof(bool)*tr->total );
        int es = 0;
        icvSpillTreeDFSearch( tr, tr->root, heap, &es, &_desc, k, emax, cache );

        // heap sort: turn the max-heap into ascending order of distance
        CvResult inp;
        for ( int i = k-1; i > 0; i-- )
        {
            CV_SWAP( heap[i], heap[0], inp );
            icvSpillTreeNodeHeapify( heap, 0, i );
        }

        int* rs = (int*)icvRowPtr( results, j );
        double* dt = (double*)icvRowPtr( dist, j );
        for ( int i = 0; i < k; i++, rs++, dt++ )
            if ( heap[i].index != -1 )
            {
                *rs = heap[i].index;
                *dt = heap[i].distance;
            } else
                *rs = -1;
    }
    cvFree( &heap );
    cvFree( &cache );
}

// Frees `node` and its subtree.  Point-node centers are not released here;
// they are the refmat headers released by icvReleaseSpillTree.
static void
icvDFSReleaseSpillTreeNode( CvSpillTreeNode* node )
{
    if ( node->leaf )
    {
        CvSpillTreeNode* it = node->lc;
        for ( int i = 0; i < node->cc; i++ )
        {
            CvSpillTreeNode* s = it;
            it = it->rc;
            cvFree( &s );
        }
    } else {
        cvReleaseMat( &node->u );
        cvReleaseMat( &node->center );
        icvDFSReleaseSpillTreeNode( node->lc );
        icvDFSReleaseSpillTreeNode( node->rc );
    }
    cvFree( &node );
}

// Releases the whole tree and sets *tr to NULL.
static void
icvReleaseSpillTree( CvSpillTree** tr )
{
    for ( int i = 0; i < (*tr)->total; i++ )
        cvReleaseMat( &((*tr)->refmat[i]) );
    cvFree( &((*tr)->refmat) );
    icvDFSReleaseSpillTreeNode( (*tr)->root );
    cvFree( tr );
}

// CvFeatureTree adapter that owns a CvSpillTree.
class CvSpillTreeWrap : public CvFeatureTree
{
    CvSpillTree* tr;
public:
    CvSpillTreeWrap( const CvMat* raw_data,
                     const int naive,
                     const double rho,
                     const double tau )
    {
        tr = icvCreateSpillTree( raw_data, naive, rho, tau );
    }
    ~CvSpillTreeWrap()
    {
        icvReleaseSpillTree( &tr );
    }

    void FindFeatures( const CvMat* desc, int k, int emax, CvMat* results, CvMat* dist )
    {
        icvFindSpillTreeFeatures( tr, desc, results, dist, k, emax );
    }

private:
    // The wrapper owns `tr` exclusively; a copy would release it twice.
    CvSpillTreeWrap( const CvSpillTreeWrap& );
    CvSpillTreeWrap& operator=( const CvSpillTreeWrap& );
};

// Creates a spill-tree feature index over the rows of raw_data.
CvFeatureTree* cvCreateSpillTree( const CvMat* raw_data,
                                  const int naive,
                                  const double rho,
                                  const double tau )
{
    return new CvSpillTreeWrap( raw_data, naive, rho, tau );
}