(MCTS)蒙特卡洛树搜索——参数寻优

版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接: https://blog.csdn.net/weixin_38316806/article/details/102631077

Zero.写作动机

对给定参数区间内部进行搜索,寻找到最优参数近似解的方法有很多。比如网格搜索。但是网格搜索太过暴力,往往花销过大。这里介绍一种新的参数寻优方法——蒙特卡洛树搜索
网络上关于蒙特卡洛方法几乎清一色都是在介绍Buffon实验并以此估计某个量。这里,我们介绍蒙特卡洛树用于参数寻优。

一、模型原理

下面推荐几个博客,这些文章已经介绍得很好了:

①https://blog.csdn.net/ljyt2/article/details/78332802
②https://www.jianshu.com/p/a34f06885ef8

二、编程实现

Version one: Python
https://www.jianshu.com/p/a34f06885ef8

Version Two: Matlab
鉴于实际需求,笔者在Python版本的基础上实现了matlab版本,涉及到matlab的面向对象编程。读者诸君按需获取即可

state.m文件

classdef State < handle
    properties
        value
        round
        choices
        PATH
        x2 %因为假定现在只用MCST找到第三步迭代的最优参数
        y 
        sigma
        im
    end
    methods
        function self= State(x2,y,sigma,im)
        %在这里进行初始化
        self.value = 0;
        self.round = 0;
        self.choices = [];
        self.PATH = [0.1:0.2:3];
        self.x2 = x2;
        self.y = y;
        self.sigma = sigma;
        self.im = im;
       
        end
        
        function state = new_state(self)
            choice = randperm(numel(self.PATH));
            choice = self.PATH(choice(1));%从一维数组中进行随机采样
            state = State(self.x2,self.y,self.sigma,self.im);
            %对于辣椒的彩色图片,第三步迭代的默认两个参数是0.7, 0.8
            value_ = 0;
            if numel(self.choices) == 1 %当前在选择第二个参数
                %计算潜在的value
                x3 = step(self.x2, self.y, self.sigma^2, 15, 7, self.choices(1), choice);
                value_ = - (sum(sum((x3 - self.im).^2)) / numel(x3)); %反向来
            elseif numel(self.choices) == 0 %当前在选择第一个参数
                %计算潜在的value
                x3 = step(self.x2, self.y, self.sigma^2, 15, 7, choice,0.8);
                value_ = - (sum(sum((x3 - self.im).^2)) / numel(x3)); %反向来
            else
                value_ = 0;
            end
            
            
            %得到一个参数的选择结果
            state.value = self.value +  value_; %价值计算函数需要更改
            state.round = self.round+1;
            state.choices = [ self.choices,choice ];%扩充当前的选择
        end
        
        function display(self)
            fprintf(1,'class State:\n');%表示在终端上进行输出
            fprintf(1,'value = %f\n',self.value);
            fprintf(1,'round = %d\n',self.round);
            fprintf(1,'ready to show the choice array:\n');
            for i = 1:numel(self.choices)
                if i == 1
                    fprintf(1,'[');
                end
                fprintf(1,'%d,',self.choices(i));
                if i == numel(self.choices)
                    fprintf(1,']');
                end
            end
        end
        
    end
end

Node.m文件

classdef Node < handle
    properties
        parent
        children 
        quality
        visit
        state
        MAX_DEPTH = 2
        MAX_CHOICE = numel([0.1:0.2:3]) %其实代表的是children数组的长度的上限
    end
    methods
        function self= Node()
            self.quality = 0.0;
            self.visit = 0;
           %剩下的变量没有定义
        end
        
        function add_child(self,node)
            
            fprintf(1,'printing node in function add_child\n');
            node          
            
            self.children = [self.children,node];
            node.parent = self;
        end
        
        function display(self)
            fprintf(1,'class Node:\n');%表示在终端上进行输出
            fprintf(1,'quality = %f\n',self.quality);
            fprintf(1,'visit = %d\n',self.visit);
        end
        
        function  child_node = expand(cnt_node)
            %随机选择一个之前没有扩展过的——也就是不在children列表中的一个子节点进行扩展,随机性在new_state的时候的随机函数中体现出来
            %返回当前结点扩展出的子节点
            fprintf(1,'printing node in function EXPAND\n');
            cnt_node
            fprintf(1,'printing value of the ori_state in function EXPAND:%f\n\n',cnt_node.state.value);
            cnt_node.state.choices
            
            state = new_state(cnt_node.state);
            %拿到当前结点的children列表中的子节点的状态
            sub_state_value_list = [];
            for i = 1:numel(cnt_node.children)
                sub_state_value_list(i) = cnt_node.children(i).state.value;
            end
            fprintf(1,'printing value of the new_state in function EXPAND:%f\n\n',state.value);
            state.choices
            
            while ismember(state.value,sub_state_value_list)
                fprintf(1,'printing value of the new_state in function EXPAND:%f\n\n',state.value);
                 state.choices
                state = new_state(cnt_node.state);
            end
            child_node = Node();
            child_node.state = state;
            add_child(cnt_node,child_node);
            
            
            fprintf(1,'printing value of the end_child_state in function EXPAND\n');            
            for i = 1:numel(cnt_node.children)
                fprintf(1,'printing value of the end_child_state in function EXPAND:%f\n\n',cnt_node.children(i).state.value);
                cnt_node.children(i).state.choices
            end                        
            
        end
        
        function best = best_child(node)
            %返回当前结点的children列表中最适合作为扩展结点的子节点
            
            fprintf(1,'printing node in function BEST_CHILD\n');
            node   
            
            best_score = -100000000; %代表负无穷
            best = -1 ;%初始化
            for i=  1:numel(node.children)
                C = 1/sqrt(2.0);
                sub_node = node.children(i);
                left = sub_node.quality / sub_node.visit; %分母是被访问的次数
                right = 2.0*log(node.visit)/sub_node.visit;
                score = left+C*sqrt(right);
                
                if score >best_score
                    best = sub_node;
                    best_score = score;
                end
            end
        end
        
        function node = tree_policy(node)
            fprintf(1,'printing node in function TREE_POLICY\n');
            node   
        %选择+expand扩展
        %调用逻辑:如果当前结点还有子节点没有被添加到children列表——也就是还没有expand过,那么就从还没有扩展过的子节点中随机选择一个进行扩展,并返回该被需选中的子节点
        %调用逻辑:如果当前结点是叶子结点,直接返回该结点
        %调用逻辑:如果当前结点的所有子节点都已经被加入到了children列表,那么就从中选择一个收益最高的结点进行扩展,并且返回该结点
            %选择是否是叶子结点
            count = 0;
            while node.state.round < node.MAX_DEPTH
                fprintf(1,'running while-end with count:%d in Node.m/line73\n',count);
                count = count +1;
                if numel(node.children) < node.MAX_CHOICE
                    node = expand(node);
                    return
                else
                    node = best_child(node);
                end
            end     
        end
        
        function expanded_value = default_policy(node)
            fprintf(1,'printing node in function DEFAULT_POLICY\n');
            node   
        %模拟
        %算一次从当前结点随机走到叶节点的收益
            now_state = node.state;
            count= 0;
            while now_state.round < node.MAX_DEPTH
                fprintf(1,'running while-end with count:%d in Node.m/line90\n',count);
                count = count +1;
                now_state = new_state(now_state);
            end
            
            expanded_value = now_state.value;
            
        end
        
        function backup(node,reward)
            fprintf(1,'printing node in function BACKUP\n');
            node   
            %从当前结点带着reward回溯到根节点,并且增加路径上的每个结点的visit次数和quality
            while ~isempty(node)               
                fprintf(1,'not empty\n');
                node.visit = node.visit +1;
                node.quality = node.quality+reward;
                node = node.parent;
            end
        
        end
        
        function best = mcts(node)
            %似乎是多次尝试扩展,选择当前扩展到children列表中的子节点中的收益最好的一个子结点进行扩展,并且返回该被选中的子节点
          %  times =  5 ;%为什么是5?
            times = 20;
            for i = 1:times
                expand = tree_policy(node);%当前结点向下选择扩展一个结点
                reward = default_policy(expand);%计算从该扩展结点走到叶子结点的随机一条路径的一种收益情况
                backup(expand,reward);
            end
            best = best_child(node);
            
        end
        
        function main(self)
            init_state = State();
            init_node = Node();
            init_node.state = init_state;
            cnt_node = init_node;
            
            
            for i = 1:self.MAX_DEPTH
                cnt_node = mcts(cnt_node);
            end
            
        end
        
    end
end

Notice.

在matlab的实现版本中,注意两种不同的类的写法classdef name < handle是引用类型,这样的类可以作为另外一个类的属性存在。classdef name是按value类型,这样的类如果想要使用自己的实例对象作为类的一个属性会报错。
上面的Node类和State类都属于引用类型。

猜你喜欢

转载自blog.csdn.net/weixin_38316806/article/details/102631077