16. 统一异常处理(上):编写自定义错误处理函数

首先贴出 go-kit 中 httptransport.NewServer 的源码:

// NewServer builds a Server around endpoint e, using dec to decode incoming
// requests and enc to encode responses.
//
// The error encoder starts out as DefaultErrorEncoder and the error handler as
// transport.NewLogErrorHandler(log.NewNopLogger()). Every ServerOption is then
// applied in the order given, so an option such as ServerErrorEncoder replaces
// those defaults.
func NewServer(
    e endpoint.Endpoint,
    dec DecodeRequestFunc,
    enc EncodeResponseFunc,
    options ...ServerOption, // variadic options; this is where custom error handling is plugged in
) *Server {
    srv := &Server{e: e, dec: dec, enc: enc}

    // Defaults first, so caller-supplied options can override them below.
    srv.errorEncoder = DefaultErrorEncoder
    srv.errorHandler = transport.NewLogErrorHandler(log.NewNopLogger())

    for i := range options {
        options[i](srv)
    }
    return srv
}

自定义ErrorEncoder

// MyErrorEncoder is a go-kit ErrorEncoder that reports err to the client as a
// plain-text body (err.Error()) with HTTP status 429 Too Many Requests.
//
// Its signature matches httptransport.ErrorEncoder, so it can be installed
// with httptransport.ServerErrorEncoder. Note that it answers 429 for every
// error it receives, not only for rate-limit rejections.
func MyErrorEncoder(ctx context.Context, err error, w http.ResponseWriter) {
    contentType, body := "text/plain; charset=utf-8", []byte(err.Error())
    // Response headers must be set before WriteHeader sends the status line.
    w.Header().Set("Content-Type", contentType)
    w.WriteHeader(http.StatusTooManyRequests) // 429
    // The status has already been sent, so a failed body write cannot be
    // reported to the client; ignoring it is deliberate.
    _, _ = w.Write(body)
}

用自定义的 ErrorEncoder 生成 ServerOption,再用它创建相应的 Handler:

options := []httptransport.ServerOption{
        httptransport.ServerErrorEncoder(Services.MyErrorEncoder),
        // ServerErrorEncoder takes an argument of type ErrorEncoder:
        //   type ErrorEncoder func(ctx context.Context, err error, w http.ResponseWriter)
        // so any function with that signature, such as our MyErrorEncoder, can be passed in.
} // slice of ServerOption values

serverHandler := httptransport.NewServer(endp, Services.DecodeUserRequest, Services.EncodeUserResponse, options...) // expand the slice into NewServer's variadic options parameter while creating the handler

完整代码

package main

import (
    "flag"
    "fmt"
    httptransport "github.com/go-kit/kit/transport/http"
    mymux "github.com/gorilla/mux"
    "golang.org/x/time/rate"
    "gomicro/Services"
    "gomicro/utils"
    "log"
    "net/http"
    "os"
    "os/signal"
    "strconv"
    "syscall"
)

// main reads the -name and -port flags, builds the rate-limited user endpoint
// with the custom error encoder, and serves it (plus a /health check) over
// HTTP. It blocks until either the HTTP server returns an error or
// SIGINT/SIGTERM arrives, then calls utils.UnRegService and logs the cause.
func main() {
    name := flag.String("name", "", "服务名称")
    port := flag.Int("port", 0, "服务端口")
    flag.Parse()
    if *name == "" {
        log.Fatal("请指定服务名")
    }
    if *port == 0 {
        log.Fatal("请指定端口")
    }
    utils.SetServiceNameAndPort(*name, *port) // set service name and port

    user := Services.UserService{}
    // Token bucket: refills at 1 token per second, allows bursts of up to 5.
    limit := rate.NewLimiter(1, 5)
    endp := Services.RateLimit(limit)(Services.GenUserEnPoint(user))

    // Slice of ServerOption values carrying our custom error encoder.
    options := []httptransport.ServerOption{
        httptransport.ServerErrorEncoder(Services.MyErrorEncoder),
    }

    // Create the go-kit server from the endpoint and the request decoder /
    // response encoder defined in the Services package.
    serverHandler := httptransport.NewServer(endp, Services.DecodeUserRequest, Services.EncodeUserResponse, options...)

    r := mymux.NewRouter()
    // r.Handle(`/user/{uid:\d+}`, serverHandler) would accept every HTTP method;
    // the route below is restricted to GET and DELETE.
    r.Methods("GET", "DELETE").Path(`/user/{uid:\d+}`).Handler(serverHandler)
    r.Methods("GET").Path("/health").HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
        writer.Header().Set("Content-type", "application/json")
        writer.Write([]byte(`{"status":"ok"}`))
    })

    // One slot per sender (HTTP goroutine and signal goroutine), so the
    // second sender never blocks after main has stopped receiving.
    errChan := make(chan error, 2)
    go func() {
        utils.RegService() // call the service registration routine
        err := http.ListenAndServe(":"+strconv.Itoa(utils.ServicePort), r) // start the HTTP server
        if err != nil {
            log.Println(err)
            errChan <- err
        }
    }()
    go func() {
        // signal.Notify does not block when delivering a signal, so the
        // channel must be buffered or the signal can be dropped.
        sigChan := make(chan os.Signal, 1)
        signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
        errChan <- fmt.Errorf("%s", <-sigChan)
    }()
    getErr := <-errChan
    utils.UnRegService()
    log.Println(getErr)
}
//因为设置了限流,访问过快时 RateLimit 中间件会返回错误,此时响应的状态码就是我们自定义的 429(注意:MyErrorEncoder 对任何错误都返回 429,并不只针对限流错误)

限流及 Endpoint 代码(Services 包)

package Services

import (
    "context"
    "errors"
    "fmt"
    "github.com/go-kit/kit/endpoint"
    "golang.org/x/time/rate"
    "gomicro/utils"
    "strconv"
)

// UserRequest is the request value handed to the endpoint built by GenUserEnPoint.
type UserRequest struct { // wraps a User request
    Uid    int `json:"uid"` // user id passed to GetName/DelUser (presumably parsed from the {uid} route variable — confirm in DecodeUserRequest)
    Method string           // request method; GenUserEnPoint acts on "GET" and "DELETE"
}

// UserResponse is the value returned by the endpoint built by GenUserEnPoint.
type UserResponse struct {
    Result string `json:"result"` // outcome text of the operation; JSON key "result"
}

// ErrTooManyRequests is returned by the RateLimit middleware when the limiter
// has no token available. Callers such as an error encoder can detect it with
// errors.Is instead of comparing error strings.
var ErrTooManyRequests = errors.New("too many requests")

// RateLimit returns a rate-limiting middleware.
//
// endpoint.Middleware is func(Endpoint) Endpoint, and endpoint.Endpoint is
// func(ctx context.Context, request interface{}) (response interface{}, err error).
//
// Each call asks limit.Allow() for a token. When none is available the request
// is rejected immediately (it does not wait) with ErrTooManyRequests;
// otherwise it is forwarded to next unchanged.
func RateLimit(limit *rate.Limiter) endpoint.Middleware {
    return func(next endpoint.Endpoint) endpoint.Endpoint {
        return func(ctx context.Context, request interface{}) (interface{}, error) {
            if !limit.Allow() {
                return nil, ErrTooManyRequests
            }
            return next(ctx, request)
        }
    }
}

// GenUserEnPoint wraps userService in a go-kit endpoint.
//
// The request must be a UserRequest; any other type yields an error instead
// of panicking on a failed type assertion. The action depends on r.Method:
//   - "GET": the result is userService.GetName(r.Uid) followed by
//     utils.ServicePort, and is also printed to stdout.
//   - "DELETE": calls userService.DelUser(r.Uid); a failure is reported in
//     the result text, not as an endpoint error.
//   - anything else: the result is "nothings".
//
// Apart from a wrong request type, the endpoint always returns a nil error.
func GenUserEnPoint(userService IUserService) endpoint.Endpoint {
    return func(ctx context.Context, request interface{}) (interface{}, error) {
        r, ok := request.(UserRequest)
        if !ok {
            return nil, fmt.Errorf("unexpected request type %T, want UserRequest", request)
        }
        result := "nothings"
        switch r.Method {
        case "GET":
            result = userService.GetName(r.Uid) + strconv.Itoa(utils.ServicePort)
            fmt.Println(result)
        case "DELETE":
            // Deletion errors are surfaced in the response body, not as an endpoint error.
            if delErr := userService.DelUser(r.Uid); delErr != nil {
                result = delErr.Error()
            } else {
                result = fmt.Sprintf("userid为%d的用户已删除", r.Uid)
            }
        }
        return UserResponse{Result: result}, nil
    }
}




猜你喜欢

转载自www.cnblogs.com/hualou/p/12083483.html