C/C++学习笔记 Vantage Point Trees的C++实现

        下面代码是VP 树的 C++ 实现,递归search()函数决定是搜索左孩子、右孩子还是两个孩子。为了有效地维护结果列表,我们使用优先级队列。

// A VP-Tree implementation, by Steve Hanov. ([email protected])
// Released to the Public Domain
// Based on "Data Structures and Algorithms for Nearest Neighbor Search" by Peter N. Yianilos
#include <stdlib.h>
#include <algorithm>
#include <vector>
#include <stdio.h>
#include <queue>
#include <limits>

template<typename T, double (*distance)( const T&, const T& )>
class VpTree
{
public:
    VpTree() : _root(0) {}

    ~VpTree() {
        delete _root;
    }

    void create( const std::vector& items ) {
        delete _root;
        _items = items;
        _root = buildFromPoints(0, items.size());
    }

    void search( const T& target, int k, std::vector* results, 
        std::vector<double>* distances) 
    {
        std::priority_queue<HeapItem> heap;

        _tau = std::numeric_limits::max();
        search( _root, target, k, heap );

        results->clear(); distances->clear();

        while( !heap.empty() ) {
            results->push_back( _items[heap.top().index] );
            distances->push_back( heap.top().dist );
            heap.pop();
        }

        std::reverse( results->begin(), results->end() );
        std::reverse( distances->begin(), distances->end() );
    }

private:
    std::vector<T> _items;
    double _tau;

    struct Node 
    {
        int index;
        double threshold;
        Node* left;
        Node* right;

        Node() :
            index(0), threshold(0.), left(0), right(0) {}

        ~Node() {
            delete left;
            delete right;
        }
    }* _root;

    struct HeapItem {
        HeapItem( int index, double dist) :
            index(index), dist(dist) {}
        int index;
        double dist;
        bool operator<( const HeapItem& o ) const {
            return dist < o.dist;   
        }
    };

    struct DistanceComparator
    {
        const T& item;
        DistanceComparator( const T& item ) : item(item) {}
        bool operator()(const T& a, const T& b) {
            return distance( item, a ) < distance( item, b );
        }
    };

    Node* buildFromPoints( int lower, int upper )
    {
        if ( upper == lower ) {
            return NULL;
        }

        Node* node = new Node();
        node->index = lower;

        if ( upper - lower > 1 ) {

            // choose an arbitrary point and move it to the start
            int i = (int)((double)rand() / RAND_MAX * (upper - lower - 1) ) + lower;
            std::swap( _items[lower], _items[i] );

            int median = ( upper + lower ) / 2;

            // partitian around the median distance
            std::nth_element( 
                _items.begin() + lower + 1, 
                _items.begin() + median,
                _items.begin() + upper,
                DistanceComparator( _items[lower] ));

            // what was the median?
            node->threshold = distance( _items[lower], _items[median] );

            node->index = lower;
            node->left = buildFromPoints( lower + 1, median );
            node->right = buildFromPoints( median, upper );
        }

        return node;

        下面是调用示例程序,运行时需要读取数据文件 cities.txt(下载地址见文末)。

#include "VpTree.h"
#include <stdio.h>
#include <stdlib.h>
#include <sys/time.h>
#include <stdint.h>
#include <string>
#include <string.h>
#include <math.h>

// NOTE(review): DIM and NUM are not referenced anywhere in the visible code;
// they may be leftovers from another benchmark. Confirm before removing.
#define DIM 200
#define NUM 32000

/// Stores the current wall-clock time, in microseconds since the Unix epoch,
/// into *val. (Presumably named after the Win32 API so the timing code reads
/// the same on both platforms.)
void QueryPerformanceCounter( uint64_t* val )
{
    timeval tv;
    // POSIX leaves the behaviour unspecified for a non-null timezone argument.
    gettimeofday( &tv, NULL );
    // Widen before multiplying so a 32-bit time_t cannot overflow.
    *val = static_cast<uint64_t>( tv.tv_sec ) * 1000000u
         + static_cast<uint64_t>( tv.tv_usec );
}

/// One record of the cities database. main() stores the whole input line in
/// `city` and parses the two trailing comma-separated numbers into the
/// coordinates.
struct Point {
    std::string city;
    double latitude;
    double longitude;

    // Zero the coordinates so a Point whose fields are never assigned (e.g. a
    // line that fails to parse) holds defined values instead of garbage.
    Point() : latitude( 0. ), longitude( 0. ) {}
};

/// Straight-line (Euclidean) distance between two points. Latitude and
/// longitude are treated as plain planar coordinates, not a great-circle
/// distance.
double distance( const Point& p1, const Point& p2 )
{
    const double dLat = p1.latitude - p2.latitude;
    const double dLon = p1.longitude - p2.longitude;
    const double squared = dLat * dLat + dLon * dLon;
    return sqrt( squared );
}

/// Candidate neighbour for linear_search()'s bounded heap: an index into the
/// item vector plus its distance from the query.
struct HeapItem {
    int index;
    double dist;

    HeapItem( int idx, double d ) : index( idx ), dist( d ) {}

    // Compare by distance only, so std::priority_queue keeps the farthest
    // candidate on top.
    bool operator<( const HeapItem& other ) const {
        return other.dist > dist;
    }
};

void linear_search( const std::vector<Point>& items, const Point& target, int k, std::vector<Point>* results, 
    std::vector<double>* distances) 
{
    std::priority_queue<HeapItem> heap;
    for ( int i = 0; i < items.size(); i++ ) {
        double dist = distance( target, items[i] );
        if ( heap.size() < k || dist < heap.top().dist ) {
            if (heap.size() == k) heap.pop();
            heap.push( HeapItem( i, dist ) );
        }
    }

    results->clear();
    distances->clear();
    while( !heap.empty() ) {
        results->push_back( items[heap.top().index] );
        distances->push_back( heap.top().dist );
        heap.pop();
    }

    std::reverse( results->begin(), results->end() );
    std::reverse( distances->begin(), distances->end() );
}

int main( int argc, char* argv[] ) {
    std::vector<Point> points;
    printf("Reading cities database...\n");
    FILE* file = fopen("cities.txt", "rt");
    for(;;) {
        char buffer[1000];
        Point point;
        if ( !fgets(buffer, 1000, file ) ) {
            fclose( file );
            break;
        }
        point.city = buffer;
        size_t comma = point.city.rfind(",");
        sscanf(buffer + comma + 1, "%lg", &point.longitude);
        comma = point.city.rfind(",", comma-1);
        sscanf(buffer + comma + 1, "%lg", &point.latitude);
        //printf("%lg, %lg\n", point.latitude, point.longitude);
        points.push_back(point);
        //if(points.size()>50000)break;
    }
    
    VpTree<Point, distance> tree;
    uint64_t start, end;
    QueryPerformanceCounter( &start );
    tree.create( points );
    QueryPerformanceCounter( &end );
    printf("Create took %d\n", (int)(end-start));

    Point point;
    point.latitude = 43.466438;
    point.longitude = -80.519185;
    std::vector<Point> results;
    std::vector<double> distances;

    QueryPerformanceCounter( &start );
    tree.search( point, 8, &results, &distances );
    QueryPerformanceCounter( &end );
    printf("Search took %d\n", (int)(end-start));

    for( int i = 0; i < results.size(); i++ ) {
        printf("%s %lg\n", results[i].city.c_str(), distances[i]);
    }

    printf("---\n");
    QueryPerformanceCounter( &start );
    linear_search( points, point, 8, &results, &distances );
    QueryPerformanceCounter( &end );
    printf("Linear search took %d\n", (int)(end-start));

    for( int i = 0; i < results.size(); i++ ) {
        printf("%s %lg\n", results[i].city.c_str(), distances[i]);
    }


    return 0;
}

        上面的程序运行时需要用到的数据文件 cities.txt 可从以下地址下载(下载后需解压为 cities.txt):

http://stevehanov.ca/blog/cities.txt.gz

猜你喜欢

转载自blog.csdn.net/bashendixie5/article/details/132167539
今日推荐