EMQX source code analysis: the esockd_connection_sup module

This module supervises the connection processes that handle a listener's accepted sockets, so it mainly provides management interfaces for those connections. Its main API is as follows:

1. start_link(Opts, MFA) starts the esockd_connection_sup server. Internally it calls OTP's
gen_server:start_link(?MODULE, [Opts, MFA], []), which in turn invokes the module's init([Opts, MFA]) callback.

2. count_connections(Sup) returns the number of connection processes currently tracked by this server. Internally it calls call(Sup, count_connections) to send a synchronous request, which is handled by the module's handle_call callback.

3. get_max_connections(Sup) returns the maximum number of connections. Internally it calls call(Sup, get_max_connections) to send a synchronous request, which is handled by the module's handle_call callback.

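4. start_connection(Sup, Sock, UpgradeFuns) starts a connection process for an accepted socket. Internally it calls call(Sup, {start_connection, Sock}) to
   send a synchronous request. The module's handle_call callback processes that request: it checks the connection limit and the access rules, then spawns the process. On success, start_connection hands control of the socket to the new connection process and calls esockd_transport:ready/3.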

5. set_max_connections(Sup, MaxConns) sets the maximum number of connections this server accepts. Internally it calls call(Sup, {set_max_connections, MaxConns}) to send a synchronous request, which is handled by the module's handle_call callback. The new limit only affects connections started afterwards.

6. get_shutdown_count(Sup) returns how many connections have shut down, grouped by exit reason, as a list of {Reason, Count} pairs. Internally it calls call(Sup, get_shutdown_count), which is handled by the handle_call callback.

The annotated source code follows.

-module(esockd_connection_sup).

-behaviour(gen_server).

-import(proplists, [get_value/3]).

-export([start_link/2, start_connection/3, count_connections/1]).
-export([get_max_connections/1, set_max_connections/2]).
-export([get_shutdown_count/1]).

%% Allow, Deny
-export([access_rules/1, allow/2, deny/2]).

%% gen_server callbacks
-export([init/1, handle_call/3, handle_cast/2, handle_info/2, terminate/2, code_change/3]).

%% How connection processes are stopped when this server terminates
%% (see terminate_children/1): `brutal_kill' kills them outright,
%% `infinity' sends `shutdown' and waits without limit, and an integer
%% sends `shutdown' and kills whatever is left after that many milliseconds.
-type(shutdown() :: brutal_kill | infinity | pos_integer()).

%% Server state:
%%   curr_connections - map of connection Pid => true for every live connection
%%   max_connections  - limit checked before a new connection is started
%%   access_rules     - rules compiled with esockd_access:compile/1, matched
%%                      against the peer address of each new socket
%%   shutdown         - strategy used by terminate_children/1
%%   mfargs           - callback used to start each connection process
%%                      NOTE(review): typed mfa(), but start_connection_proc/2
%%                      also accepts a bare module or a {M, F} pair
-record(state, {curr_connections :: map(), max_connections :: pos_integer(), access_rules :: list(), shutdown :: shutdown(), mfargs :: mfa()}).
%% Maximum number of client connections when the `max_connections' option is absent.
-define(DEFAULT_MAX_CONNS, 1024).
%% Transport module handed to connection processes and used for socket operations.
-define(TRANSPORT, esockd_transport).
%% Logs an error message prefixed with this module's name.
-define(ERROR_MSG(Format, Args), error_logger:error_msg("[~s] " ++ Format, [?MODULE | Args])).

%% Starts the connection supervisor server, linked to the caller.
%%   Opts - listener options; init/1 reads `shutdown', `max_connections'
%%          and `access_rules' from them.
%%   MFA  - callback used to start each connection process.
-spec(start_link([esockd:option()], esockd:mfargs()) -> {ok, pid()} | ignore | {error, term()}).
start_link(Opts, MFA) ->
    InitArgs = [Opts, MFA],
    gen_server:start_link(?MODULE, InitArgs, []).

%%------------------------------------------------------------------------------
%% API
%%------------------------------------------------------------------------------

%% Starts a connection process for an accepted socket.
%%
%% The connection supervisor Sup is asked (synchronously) to spawn the
%% process. On success, ownership of Sock is transferred from the caller
%% (the acceptor) to the new process, which is then signalled through
%% ?TRANSPORT:ready/3 with UpgradeFuns. The results of both transport calls
%% are intentionally ignored.
%%
%% Returns {ok, ConnPid}, ignore, or {error, Reason} (for example maxlimit
%% or forbidden, as replied by handle_call/3).
start_connection(Sup, Sock, UpgradeFuns) ->
    Result = call(Sup, {start_connection, Sock}),
    case Result of
        ignore ->
            ignore;
        {error, _Reason} ->
            Result;
        {ok, ConnPid} ->
            %% Hand the socket over from the acceptor to the connection.
            _ = ?TRANSPORT:controlling_process(Sock, ConnPid),
            _ = ?TRANSPORT:ready(ConnPid, Sock, UpgradeFuns),
            Result
    end.

%% Starts one connection process by invoking the configured callback with
%% ?TRANSPORT and the socket prepended to its arguments:
%%   Mod              -> Mod:start_link(?TRANSPORT, Sock)
%%   {Mod, Fun}       -> Mod:Fun(?TRANSPORT, Sock)
%%   {Mod, Fun, Args} -> apply(Mod, Fun, [?TRANSPORT, Sock | Args])
%% For example {echo_server, start_link, []} calls
%% echo_server:start_link(?TRANSPORT, Sock).
-spec(start_connection_proc(esockd:mfargs(), esockd_transport:sock()) -> {ok, pid()} | ignore | {error, term()}).
start_connection_proc({M, F, Args}, Sock) when is_atom(M), is_atom(F), is_list(Args) ->
    erlang:apply(M, F, [?TRANSPORT, Sock | Args]);
start_connection_proc({M, F}, Sock) when is_atom(M), is_atom(F) ->
    start_connection_proc({M, F, []}, Sock);
start_connection_proc(M, Sock) when is_atom(M) ->
    start_connection_proc({M, start_link, []}, Sock).


%% Returns the number of connection processes currently tracked by Sup.
-spec(count_connections(pid()) -> integer()).
count_connections(Sup) ->
    Request = count_connections,
    call(Sup, Request).

%% Returns the connection limit currently configured in Sup.
-spec(get_max_connections(pid()) -> integer()).
get_max_connections(Sup) when is_pid(Sup) ->
    Request = get_max_connections,
    call(Sup, Request).

%% Replaces the connection limit of Sup. Only start requests made
%% afterwards are checked against the new value.
-spec(set_max_connections(pid(), integer()) -> ok).
set_max_connections(Sup, MaxConns) when is_pid(Sup) ->
    Request = {set_max_connections, MaxConns},
    call(Sup, Request).

%% Returns how many connection processes exited with each counted reason,
%% as a list of {Reason, Count} pairs. Counted reasons are atoms other than
%% normal/shutdown/killed, and the Atom in {shutdown, Atom} (see
%% connection_crashed/3 and count_shutdown/1).
-spec(get_shutdown_count(pid()) -> [{atom(), pos_integer()}]).
get_shutdown_count(Sup) ->
    call(Sup, get_shutdown_count).

%% Returns the access rules of Sup. CIDR ranges are converted back to
%% strings (see raw/1).
access_rules(Sup) ->
    call(Sup, access_rules).

%% Adds an allow rule for CIDR to Sup. The server replies ok,
%% {error, bad_access_rule} or {error, already_exists}.
allow(Sup, CIDR) ->
    Rule = {allow, CIDR},
    call(Sup, {add_rule, Rule}).

%% Adds a deny rule for CIDR to Sup. Replies as allow/2 does.
deny(Sup, CIDR) ->
    Rule = {deny, CIDR},
    call(Sup, {add_rule, Rule}).

%% Synchronous request to the connection supervisor. It has no timeout, so
%% the caller blocks until the server replies.
call(Sup, Req) ->
    Timeout = infinity,
    gen_server:call(Sup, Req, Timeout).

%%------------------------------------------------------------------------------
%% gen_server callbacks
%%------------------------------------------------------------------------------

%% Builds the initial server state from the listener options.
%% Exits are trapped so that linked connection processes report their
%% termination as {'EXIT', Pid, Reason} messages (see handle_info/2).
%% Options read, with their defaults:
%%   shutdown        - brutal_kill
%%   max_connections - ?DEFAULT_MAX_CONNS
%%   access_rules    - [{allow, all}], each compiled by esockd_access:compile/1
init([Opts, MFA]) ->
    process_flag(trap_exit, true),
    ShutdownMode = get_value(shutdown, Opts, brutal_kill),
    Limit = get_value(max_connections, Opts, ?DEFAULT_MAX_CONNS),
    Configured = get_value(access_rules, Opts, [{allow, all}]),
    Compiled = [esockd_access:compile(Raw) || Raw <- Configured],
    InitialState = #state{curr_connections = #{},
                          max_connections  = Limit,
                          access_rules     = Compiled,
                          shutdown         = ShutdownMode,
                          mfargs           = MFA},
    {ok, InitialState}.

%% Synchronous requests sent through call/2:
%%
%% {start_connection, Sock}
%%   Replies {error, maxlimit} once curr_connections holds MaxConns entries.
%%   Otherwise it looks up the peer address and checks it against the access
%%   rules, replying {error, forbidden} when denied. It then starts the
%%   connection process via start_connection_proc/2 and tracks its pid on
%%   success. Replies {ok, Pid} | ignore | {error, Reason}.
%% count_connections          -> number of tracked connection pids
%% get_max_connections        -> current connection limit
%% {set_max_connections, Max} -> ok; later start requests use the new limit
%% get_shutdown_count         -> [{Reason, Count}] read from the process
%%                               dictionary (written by count_shutdown/1)
%% access_rules               -> rules converted back with raw/1
%% {add_rule, RawRule}        -> ok | {error, bad_access_rule}
%%                               | {error, already_exists}; new rules are
%%                               prepended to the rule list
%% Any other request is logged and answered with `ignored'.
handle_call({start_connection, _Sock}, _From,
            State = #state{curr_connections = Conns, max_connections = MaxConns})
  when map_size(Conns) >= MaxConns ->
    {reply, {error, maxlimit}, State};
handle_call({start_connection, Sock}, _From, State = #state{}) ->
    {Reply, NewState} = admit_connection(Sock, State),
    {reply, Reply, NewState};
handle_call(count_connections, _From, State = #state{curr_connections = Conns}) ->
    {reply, maps:size(Conns), State};
handle_call(get_max_connections, _From, State = #state{max_connections = MaxConns}) ->
    {reply, MaxConns, State};
handle_call({set_max_connections, MaxConns}, _From, State) ->
    {reply, ok, State#state{max_connections = MaxConns}};
handle_call(get_shutdown_count, _From, State) ->
    Counts = [{Reason, Count} || {{shutdown_count, Reason}, Count} <- get()],
    {reply, Counts, State};
handle_call(access_rules, _From, State = #state{access_rules = Rules}) ->
    {reply, [raw(Rule) || Rule <- Rules], State};
handle_call({add_rule, RawRule}, _From, State = #state{}) ->
    {Reply, NewState} = insert_access_rule(RawRule, State),
    {reply, Reply, NewState};
handle_call(Req, _From, State) ->
    ?ERROR_MSG("unexpected call: ~p", [Req]),
    {reply, ignored, State}.

%% Checks the peer address of Sock against the access rules and, when it is
%% allowed, starts the connection process. Returns {Reply, NewState}.
admit_connection(Sock, State = #state{access_rules = Rules}) ->
    case esockd_transport:peername(Sock) of
        {error, Reason} ->
            {{error, Reason}, State};
        {ok, {Addr, _Port}} ->
            case allowed(Addr, Rules) of
                false -> {{error, forbidden}, State};
                true  -> spawn_connection(Sock, State)
            end
    end.

%% Starts the connection process and records its pid on success. An
%% old-style `catch' converts exceptions into a reply instead of crashing
%% this server. Returns {Reply, NewState}.
spawn_connection(Sock, State = #state{curr_connections = Conns, mfargs = MFA}) ->
    case catch start_connection_proc(MFA, Sock) of
        {ok, Pid} when is_pid(Pid) ->
            {{ok, Pid}, State#state{curr_connections = maps:put(Pid, true, Conns)}};
        ignore ->
            {ignore, State};
        {error, _} = Error ->
            {Error, State};
        Other ->
            {{error, Other}, State}
    end.

%% Compiles RawRule and prepends it to the rule list unless an identical
%% compiled rule is already present. Returns {Reply, NewState}.
insert_access_rule(RawRule, State = #state{access_rules = Rules}) ->
    case catch esockd_access:compile(RawRule) of
        {'EXIT', _Error} ->
            {{error, bad_access_rule}, State};
        Rule ->
            case lists:member(Rule, Rules) of
                true  -> {{error, already_exists}, State};
                false -> {ok, State#state{access_rules = [Rule | Rules]}}
            end
    end.

%% No asynchronous requests are defined, so every cast is logged and dropped.
handle_cast(Unexpected, State) ->
    ?ERROR_MSG("unexpected cast: ~p", [Unexpected]),
    {noreply, State}.

%% Handles exit signals, which arrive as messages because init/1 traps exits.
%% When a tracked connection process exits, it is removed from
%% curr_connections and its reason is classified by connection_crashed/3.
%% Any other 'EXIT' or message is logged and ignored.
handle_info({'EXIT', Pid, Reason}, State = #state{curr_connections = Conns}) ->
    case maps:take(Pid, Conns) of
        error ->
            ?ERROR_MSG("unexpected 'EXIT': ~p, reason: ~p", [Pid, Reason]),
            {noreply, State};
        {true, Remaining} ->
            connection_crashed(Pid, Reason, State),
            {noreply, State#state{curr_connections = Remaining}}
    end;
handle_info(Unexpected, State) ->
    ?ERROR_MSG("unexpected info: ~p", [Unexpected]),
    {noreply, State}.

%% Stops every remaining connection process, using the configured shutdown
%% strategy, before this server exits. The termination reason is not used.
terminate(_Why, State) ->
    terminate_children(State).

%% The state needs no migration between code versions.
code_change(_PreviousVsn, State, _Extra) ->
    {ok, State}.

%%------------------------------------------------------------------------------
%% Internal functions
%%------------------------------------------------------------------------------
%% Decides whether a peer address may connect. An address that matches no
%% rule is allowed; otherwise the action of the matched rule decides.
allowed(Addr, Rules) ->
    case esockd_access:match(Addr, Rules) of
        {matched, deny}  -> false;
        {matched, allow} -> true;
        nomatch          -> true
    end.
%% Converts a compiled access rule back into a printable form. For allow and
%% deny rules holding a CIDR range ({Start, End, Len}), the range becomes a
%% string via esockd_cidr:to_string/1. Every other rule, such as
%% {allow, all}, is returned unchanged.
raw({Action, {_Start, _End, _Len} = CIDR}) when Action =:= allow; Action =:= deny ->
    {Action, esockd_cidr:to_string(CIDR)};
raw(Other) ->
    Other.

%% Classifies the exit of a tracked connection process:
%%   normal | shutdown | killed -> ignored
%%   any other atom Reason      -> counted under Reason (count_shutdown/1)
%%   {shutdown, Atom}           -> counted under Atom
%%   {shutdown, Other}          -> reported as connection_shutdown
%%   anything else              -> reported as connection_crashed
connection_crashed(_Pid, Reason, _State)
  when Reason =:= normal; Reason =:= shutdown; Reason =:= killed ->
    ok;
connection_crashed(_Pid, Reason, _State) when is_atom(Reason) ->
    count_shutdown(Reason);
connection_crashed(Pid, {shutdown, Detail}, State) ->
    case is_atom(Detail) of
        true  -> count_shutdown(Detail);
        false -> report_error(connection_shutdown, Detail, Pid, State)
    end;
connection_crashed(Pid, Reason, State) ->
    report_error(connection_crashed, Reason, Pid, State).

%% Increments the per-reason shutdown counter kept in the process dictionary
%% under {shutdown_count, Reason}. Returns the previous value, as put/2 does
%% (undefined the first time a reason is seen).
count_shutdown(Reason) ->
    Key = {shutdown_count, Reason},
    Previous = get(Key),
    Next = if
               Previous =:= undefined -> 1;
               true -> Previous + 1
           end,
    put(Key, Next).

%% Stops every tracked connection process according to the configured
%% shutdown strategy, then reports the children that ended with an
%% unexpected reason:
%%   brutal_kill     - send `kill' and wait for every child to go down
%%   infinity        - send `shutdown' and wait without a time limit
%%   Time (integer)  - send `shutdown'; after Time ms kill whatever is left
%% The error stack is a dict mapping each exit reason to the list of pids
%% that exited with it, so report_error/4 receives a pid list here.
terminate_children(State = #state{curr_connections = Conns, shutdown = Shutdown}) ->
    {Alive, Errors0} = monitor_children(Conns),
    Expected = sets:size(Alive),
    Signal = case Shutdown of
                 brutal_kill -> kill;
                 infinity -> shutdown;
                 Time when is_integer(Time) -> shutdown
             end,
    sets:fold(fun(Child, _) -> exit(Child, Signal) end, ok, Alive),
    %% The kill timer only starts once every child has been signalled.
    TRef = case is_integer(Shutdown) of
               true  -> erlang:start_timer(Shutdown, self(), kill);
               false -> undefined
           end,
    Errors = wait_children(Shutdown, Alive, Expected, TRef, Errors0),
    %% Unroll stacked errors and report them.
    dict:fold(fun(Reason, ReasonPids, _) ->
                  report_error(connection_shutdown_error, Reason, ReasonPids, State)
              end, ok, Errors).

%% Switches every tracked connection from link to monitor (monitor_child/1).
%% Returns {AlivePids, ErrorStack}. AlivePids is a sets set of children still
%% to be waited for. ErrorStack is a dict of Reason => [Pid] for children
%% that had already exited with a reason other than normal.
monitor_children(Conns) ->
    Step = fun(Pid, {Alive, Errors}) ->
                   case monitor_child(Pid) of
                       ok -> {sets:add_element(Pid, Alive), Errors};
                       {error, normal} -> {Alive, Errors};
                       {error, Reason} -> {Alive, dict:append(Reason, Pid, Errors)}
                   end
           end,
    lists:foldl(Step, {sets:new(), dict:new()}, maps:keys(Conns)).

%% Helper for terminate_children/1, adapted from OTP's supervisor shutdown
%% logic: it replaces the link to Pid with a monitor.
%% Returns ok when the child's termination will be observed later as a
%% 'DOWN' message. Returns {error, Reason} when the child had already exited;
%% its 'EXIT' and 'DOWN' messages are then both consumed here.
monitor_child(Pid) ->
    %% Do the monitor operation first so that if the child dies
    %% before the monitoring is done causing a 'DOWN'-message with
    %% reason noproc, we will get the real reason in the 'EXIT'-message
    %% unless a naughty child has already done unlink...
    erlang:monitor(process, Pid),
    unlink(Pid),

    receive
	%% If the child dies before the unlink we must empty
	%% the mail-box of the 'EXIT'-message and the 'DOWN'-message.
	{'EXIT', Pid, Reason} ->
	    receive
		{'DOWN', _, process, Pid, _} ->
		    {error, Reason}
	    end
    after 0 ->
	    %% If a naughty child did unlink and the child dies before
	    %% monitor the result will be that shutdown/2 receives a
	    %% 'DOWN'-message with reason noproc.
	    %% If the child should die after the unlink there
	    %% will be a 'DOWN'-message with a correct reason
	    %% that will be handled in shutdown/2.
	    ok
    end.

%% Waits for the monitored connection processes to go down, collecting
%% unexpected exit reasons in EStack (a dict of Reason => [Pid]).
%%   Shutdown - brutal_kill | infinity | integer(), as in #state.shutdown
%%   Pids     - set of pids whose 'DOWN' message is still awaited
%%   Sz       - number of 'DOWN' messages still awaited (size of Pids)
%%   TRef     - shutdown kill-timer reference, or undefined if none
%%   EStack   - accumulated error stack
%% Adapted from OTP's supervisor.erl.
wait_children(_Shutdown, _Pids, 0, undefined, EStack) ->
    EStack;
wait_children(_Shutdown, _Pids, 0, TRef, EStack) ->
    %% If the timer has expired before its cancellation, we must empty the
    %% mail-box of the 'timeout'-message.
    _ = erlang:cancel_timer(TRef),
    receive
        {timeout, TRef, kill} ->
            EStack
    after 0 ->
            EStack
    end;
%% Every child was sent `kill', so `killed' is the expected reason; anything
%% else is stacked as an error.
wait_children(brutal_kill, Pids, Sz, TRef, EStack) ->
    receive
        {'DOWN', _MRef, process, Pid, killed} ->
            wait_children(brutal_kill, sets:del_element(Pid, Pids), Sz-1, TRef, EStack);

        {'DOWN', _MRef, process, Pid, Reason} ->
            wait_children(brutal_kill, sets:del_element(Pid, Pids),
                          Sz-1, TRef, dict:append(Reason, Pid, EStack))
    end;
%% Every child was sent `shutdown'; `shutdown' and `normal' are expected,
%% and any other reason (including `killed' after a timeout) is stacked.
wait_children(Shutdown, Pids, Sz, TRef, EStack) ->
    receive
        {'DOWN', _MRef, process, Pid, shutdown} ->
            wait_children(Shutdown, sets:del_element(Pid, Pids), Sz-1, TRef, EStack);
        {'DOWN', _MRef, process, Pid, normal} ->
            wait_children(Shutdown, sets:del_element(Pid, Pids), Sz-1, TRef, EStack);
        {'DOWN', _MRef, process, Pid, Reason} ->
            wait_children(Shutdown, sets:del_element(Pid, Pids), Sz-1,
                          TRef, dict:append(Reason, Pid, EStack));
        {timeout, TRef, kill} ->
            %% The shutdown timeout expired, so the remaining children are
            %% killed. No 'DOWN' message was consumed here, so Sz must stay
            %% unchanged. Decrementing it would end the wait one child early
            %% and lose that child's exit report.
            sets:fold(fun(P, _) -> exit(P, kill) end, ok, Pids),
            wait_children(Shutdown, Pids, Sz, undefined, EStack)
    end.

%% Emits an OTP-style supervisor_report about a connection process.
%%   Error  - error context: connection_crashed, connection_shutdown or
%%            connection_shutdown_error
%%   Reason - the exit reason
%%   Pid    - the offending pid (a list of pids for connection_shutdown_error,
%%            because terminate_children/1 groups pids per reason)
%% The supervisor is identified as {self(), ?MODULE}, the form OTP's
%% supervisor uses for unregistered supervisors. Building an atom from the
%% pid would permanently consume atom-table space, because atoms are never
%% garbage-collected.
report_error(Error, Reason, Pid, #state{mfargs = MFA}) ->
    SupName = {self(), ?MODULE},
    ErrorMsg = [{supervisor, SupName},
                {errorContext, Error},
                {reason, Reason},
                {offender, [{pid, Pid}, {name, connection}, {mfargs, MFA}]}],
    error_logger:error_report(supervisor_report, ErrorMsg).

The next article will introduce the basic functions of the esockd_acceptor_sup module.

Guess you like

Origin blog.csdn.net/qq513036862/article/details/88075217