C++数据结构——二叉树之二AVL树

本次实现的AVL树基于上一篇文章中实现的普通二叉树(btree.h),其实现已在 https://blog.csdn.net/qq811299838/article/details/104038745 一文中列出,此处不再重复。

AVL树使用普通二叉树的功能采用组合方式而非继承。

AVL树最大的特点是:每一个结点的左右子树高度差的绝对值都不超过1(即小于2)。这一性质比较简单,因此也容易实现。

编译环境:GCC 7.3、vs 2005

不多说,直接上代码

#ifndef __AVL_TREE_H__
#define __AVL_TREE_H__

#if __cplusplus >= 201103L
#include <type_traits>
#include <utility> // std::forward、std::move
#endif

#if __cplusplus >= 201103L
#define null nullptr
#else
#define null NULL
#endif 

#include "btree.h"

/* Default three-way comparator used by AVLTree, built only on _Tp's operator<.
 * @return  1 if a1 < a2 (AVLTree then descends to the left), -1 if a2 < a1
 *          (descends to the right), 0 if neither is less (treated as equal keys).
 * operator() is const so the comparator can also be invoked through a const object.
 */
template<typename _Tp>
struct Comparator
{
    int operator()(const _Tp &a1, const _Tp &a2) const
    {
        if(a1 < a2) return 1;
        if(a2 < a1) return -1;
        return 0;
    }
};

template<typename _Tp, typename _Compare = Comparator<_Tp>>
class AVLTree
{
public:
    typedef _Tp                  value_type;
    typedef _Tp &                reference;
    typedef _Tp *                pointer;
    typedef const _Tp &          const_reference;
    typedef unsigned long        size_type;
    typedef _Compare             compare_type;

#if __cplusplus >= 201103L
    typedef _Tp &&        rvalue_reference;
#endif

    typedef BinaryTree<value_type> tree_type;
    typedef typename tree_type::node_type node_type;

public:
    typedef node_type*(*iterator_func)(node_type*);
    
    template<iterator_func _Next, iterator_func _Prev>
    struct iterator_impl
    {
        node_type *_M_node;

        iterator_impl(node_type *node = null)
         : _M_node(node) { }
        
        iterator_impl operator++()
        { return iterator_impl(_M_node = _Next(_M_node)); }

        iterator_impl operator++(int)
        { 
            iterator_impl ret(_M_node);
            _M_node = _Next(_M_node);
            return ret;
        }

        iterator_impl operator--()
        { return iterator_impl(_M_node = _Prev(_M_node)); }

        iterator_impl operator--(int)
        {
            iterator_impl ret(_M_node);
            _M_node = _Prev(_M_node);
            return ret;
        }

        reference operator*()
		{ return *_M_node->value(); }

		pointer operator->()
		{ return _M_node->value(); }

        bool operator==(const iterator_impl &it) const 
		{ return _M_node == it._M_node; }

		bool operator!=(const iterator_impl &it) const
		{ return _M_node != it._M_node; }
    };
    
    template<iterator_func _Next, iterator_func _Prev>
    struct const_iterator_impl
    {
        const node_type *_M_node;

        const_iterator_impl(const node_type *node = null)
         : _M_node(node) { }

        const_iterator_impl(const iterator_impl<_Next, _Prev> &it)
         : _M_node(it._M_node) { }
        
        const_iterator_impl operator++()
        { return const_iterator_impl(_M_node = _Next(const_cast<node_type*>(_M_node))); }

        const_iterator_impl operator++(int)
        { 
            const_iterator_impl ret(_M_node);
            _M_node = _Next(const_cast<node_type*>(_M_node));
            return ret;
        }

        const_iterator_impl operator--()
        { return const_iterator_impl(_M_node = _Prev(const_cast<node_type*>(_M_node))); }

        const_iterator_impl operator--(int)
        {
            const_iterator_impl ret(_M_node);
            _M_node = _Prev(const_cast<node_type*>(_M_node));
            return ret;
        }

        reference operator*()
		{ return *_M_node->value(); }

		pointer operator->()
		{ return _M_node->value(); }

        bool operator==(const const_iterator_impl &it) const 
		{ return _M_node == it._M_node; }

		bool operator!=(const const_iterator_impl &it) const
		{ return _M_node != it._M_node; }
    };

public:
    typedef iterator_impl<&tree_type::middle_next, &tree_type::middle_previous> iterator;
    typedef iterator_impl<&tree_type::middle_previous, &tree_type::middle_next> reverse_iterator;
    typedef const_iterator_impl<&tree_type::middle_next, &tree_type::middle_previous> const_iterator;
    typedef const_iterator_impl<&tree_type::middle_previous, &tree_type::middle_next> const_reverse_iterator;

public:
    AVLTree() { }

    AVLTree(const AVLTree &tree)
     : _M_tree(tree._M_tree) { }

    template<typename _ForwardIterator>
    AVLTree(_ForwardIterator b, _ForwardIterator e)
    {
        while(b != e)
        { insert(*b++); }
    }

#if __cplusplus >= 201103L
    AVLTree(AVLTree &&tree)
     : _M_tree(std::move(tree._M_tree)) { }
#endif

    size_type size() const 
    { return _M_tree.size(); }

    size_type depth() const 
    { return _M_tree.depth(); }

    bool empty() const 
    { return size() == 0; }

    iterator begin()
    { return iterator(_M_tree.left_child_under(_M_tree.root())); }

    iterator end() 
    { return iterator(); }

    const_iterator begin() const 
    { return const_iterator(_M_tree.left_child_under(_M_tree.root())); }

    const_iterator end() const 
    { return const_iterator(); }

    const_iterator cbegin() const 
    { return begin(); }

    const_iterator cend() const 
    { return end(); }

    reverse_iterator rbegin() 
    { return reverse_iterator(_M_tree.right_child_under(_M_tree.root())); }

    reverse_iterator rend() 
    { return reverse_iterator(); }

    const_reverse_iterator rbegin() const 
    { return const_reverse_iterator(_M_tree.right_child_under(_M_tree.root())); }

    const_reverse_iterator rend() const 
    { return const_reverse_iterator(); }

    const_reverse_iterator crbegin() const 
    { return rbegin(); }

    const_reverse_iterator crend() const 
    { return rend(); }

    iterator insert(const_reference v)
    { return _M_insert(v); }

#if __cplusplus >= 201103L
    iterator insert(rvalue_reference v)
    { return _M_insert(std::move(v)); }
#endif
    // 删除后,迭代器失效
    void erase(const_iterator it)
    { _M_adjust(_M_tree.erase(const_cast<node_type*>(it._M_node))); }

    void clear()
    { _M_tree.clear(); }

    iterator find(const_reference v)
    { return _M_find<value_type, compare_type>(v); }

    template<typename _CompareType>
    const_iterator find(const_reference v) const
    { return _M_find<value_type, _CompareType>(v); }

    const tree_type& get_tree() const
    { return _M_tree; }

private:
    iterator _M_insert(const_reference v)
    {
        iterator found;
        if(_M_find_and_insert<value_type, compare_type>(v, found))
        { *found = v; }
        return found;
    }
#if __cplusplus >= 201103L
    iterator _M_insert(rvalue_reference v)
    {
        iterator found;
        if(_M_find_and_insert<value_type, compare_type>(std::move(v), found))
        { *found = v; }
        return found;
    }
#endif

    template<typename _InputType, typename _CompareType>
    iterator _M_find(const _InputType &input)
    {
        node_type *node = _M_tree.root();
        while(null != node)
        {
            int res = _CompareType()(input, *node->value());
            if(res == 0)
            { return iterator(node); }
            if(res > 0)
            { node = node->left_child(); }
            else
            { node = node->right_child(); }
        }
        return iterator();
    }
    /* 进行旋转操作
     * @first  第一次旋转的支点结点,当该结点的左右子树深度一致时,不进行旋转
     * @second  第二次旋转的支点结点,必然旋转
     * @rotate_func  第二次旋转的结点的旋转操作函数
     * @return  返回新的支点结点
     */
    node_type* _M_rotate(node_type *first, node_type *second, node_type*(tree_type::*rot)(node_type*))
    {
        size_type left_depth = null == first->left_child() ? 0 : first->left_child()->depth();
        size_type right_depth = null == first->right_child() ? 0 : first->right_child()->depth();
        
        if(left_depth < right_depth)
        { _M_tree.left_rotate(first); }
        else if(right_depth < left_depth)
        { _M_tree.right_rotate(first); }

        return (_M_tree.*rot)(second);
    }

    // 插入新结点后,以新结点为基点进行向上调整整棵树
    void _M_adjust(node_type *node)
    {
        node_type *visit = node;  
        while(null != visit)
        {
            size_type left_depth = null == visit->left_child() ? 0 : visit->left_child()->depth();
            size_type right_depth = null == visit->right_child() ? 0 : visit->right_child()->depth();

            if(left_depth > right_depth && left_depth - right_depth >= 2)
            { visit = _M_rotate(visit->left_child(), visit, &tree_type::right_rotate); }
            else if(right_depth > left_depth && right_depth - left_depth >= 2)
            { visit = _M_rotate(visit->right_child(), visit, &tree_type::left_rotate);  }

            visit = visit->parent();
        } 
    }

    /* 插入结点,如果结点不存在则插入新结点
     * @input  插入的值
     * @result  结点的迭代器
     * @return  如果结点本来已存在,则返回true
     */
    template<typename _InputType, typename _CompareType>
    bool _M_find_and_insert(const _InputType &input, iterator &result)
    {
        if(empty())
        { 
            result._M_node = _M_tree.append_root(input);
            return false; 
        }

        node_type *node = _M_tree.root();
        while(true)
        {
            int res = _CompareType()(input, *node->value());
            if(res == 0)
            { 
                result._M_node = node;
                return true; 
            }
            if(res > 0)
            {
                if(null == node->left_child())
                {
                    node = _M_tree.append_left(node, input);
                    break;
                }
                node = node->left_child();
            }
            else
            {
                if(null == node->right_child())
                {
                    node = _M_tree.append_right(node, input);
                    break;
                }
                node = node->right_child();
            }
        }
        _M_adjust(node);
        result._M_node = node;
        return false;
    }

#if __cplusplus >= 201103L
    template<typename _InputType, typename _CompareType>
    bool _M_find_and_insert(_InputType &&input, iterator &result)
    {
        if(empty())
        { 
            result._M_node = _M_tree.append_root(input);
            return false; 
        }

        node_type *node = _M_tree.root();
        while(true)
        {
            int res = _CompareType()(std::forward<value_type>(input), *node->value());
            if(res == 0)
            { 
                result._M_node = node;
                return true; 
            }
            if(res > 0)
            {
                if(null == node->left_child())
                {
                    node = _M_tree.append_left(node, std::move(input));
                    break;
                }
                node = node->left_child();
            }
            else
            {
                if(null == node->right_child())
                {
                    node = _M_tree.append_right(node, std::move(input));
                    break;
                }
                node = node->right_child();
            }
        }
        _M_adjust(node);
        result._M_node = node;
        return false;
    }
#endif

private:
    tree_type _M_tree;
};

#endif

下面是测试代码:

#include <iostream>
#include <list>
#include "avltree.h"

#if __cplusplus < 201103L
#include <sstream>
#endif

#define MAX_NUMBER_BIT 5

// Formats v padded with spaces on both sides to MAX_NUMBER_BIT columns so that
// print_tree rows line up. Strings already at least that wide are returned
// unpadded (the previous unsigned subtraction wrapped around for them).
static std::string get_string(int v)
{
#if __cplusplus < 201103L
	std::stringstream ss;
	ss << v;
	const std::string digits = ss.str();
#else
	const std::string digits = std::to_string(v);
#endif
	const std::size_t width = MAX_NUMBER_BIT;
	if(digits.size() >= width)
	{ return digits; }
	const std::size_t padding = width - digits.size();
	// Odd padding puts the extra space on the right, so the total width is always exact.
	const std::size_t left = padding / 2;
	return std::string(left, ' ') + digits + std::string(padding - left, ' ');
}

// Sample element type for the demo: wraps an int and is ordered by that int.
struct T
{
	T(int v = 0) : value(v) { }

	// Writes the value (space-padded by get_string) to std::cout.
	// const, so it can be called through pointers/references to const T.
	void print() const
	{ std::cout << get_string(value); }

	bool operator<(const T& t) const
	{ return value < t.value; }

	int value;
};

typedef AVLTree<T> TestTree;
typedef TestTree::node_type NodeType;

static void print_tree(const BinaryTree<T> &tree)
{
	std::list<const NodeType*> s;
	s.push_back(tree.root());
	bool break_flag = false;
	while(!break_flag)
	{
		break_flag = true;
		std::size_t count = s.size();
		int print_count = 0;
		while(count-- > 0)
		{
			const NodeType *t = s.front();
			s.pop_front();
			if(null == t)
			{ 
				T().print();
				s.push_back(null);
				s.push_back(null); 
			}
			else 
			{ 
				t->value()->print(); 
				s.push_back(t->left_child());
				s.push_back(t->right_child());
				break_flag = false;
			}
			if(++print_count % 2 == 0)
			{ std::cout << "|"; }
		}
		std::cout << std::endl;
	}
}
void main_func()
{
	TestTree avl;
	avl.insert(T(10));
	avl.insert(T(30));
	avl.insert(T(1));
	avl.insert(T(31));
	TestTree::iterator it1 = avl.insert(T(32));
	avl.insert(T(33));
	avl.insert(T(34));
	avl.insert(T(35));
	avl.insert(T(36));
	avl.insert(T(37));
	avl.insert(T(38));

	std::cout << "size: " << avl.size() << std::endl;
	std::cout << "depth: " << avl.depth() << std::endl;
	print_tree(avl.get_tree());

	std::cout << "--------------erase node 1----------------" << std::endl;
	avl.erase(it1);
	std::cout << "size: " << avl.size() << std::endl;
	std::cout << "depth: " << avl.depth() << std::endl;
	print_tree(avl.get_tree());
	std::cout << "--------------erase node 2----------------" << std::endl;
	avl.erase(avl.find(T(35)));
	std::cout << "size: " << avl.size() << std::endl;
	std::cout << "depth: " << avl.depth() << std::endl;
	print_tree(avl.get_tree());

	std::cout << std::endl << "--------------iterator visit-----------" << std::endl;
	for(TestTree::const_iterator it = avl.begin(); it != avl.end(); ++it)
	{
		std::cout << it->value << ' ';
	}
	std::cout << std::endl;
}

int main()
{
    main_func();
#if defined(_WIN32)
    // "pause" is a cmd.exe built-in; on other platforms the shell would only
    // report an unknown command, so the console is held open on Windows only.
    system("pause");
#endif
    return 0;
}

测试结果:(原文此处为运行结果截图,未能保留。)

发布了19 篇原创文章 · 获赞 1 · 访问量 2774

猜你喜欢

转载自blog.csdn.net/qq811299838/article/details/104210700