go-context实例

使用channel实现各个协程的优雅关闭

package main

import "fmt"
import "time"

func main() {
    messages := make(chan int, 10)
    doneConsumer := make(chan bool, 1)
	doneProducer := make(chan bool, 1)

    defer close(messages)

    // 消费线程
    go func() {
        ticker := time.NewTicker(time.Millisecond * 400)
        for _ = range ticker.C {
            select {
            case <-doneProducer:
				// 清空消息队列中的元素
				for i := 0; i < len(messages); i++ {
                	fmt.Printf("recv message: %d\n", <-messages)
				}

                fmt.Println("child process interrupt...")
				// 主动关闭消费者
				doneConsumer<-true
                return
			case val, ok := <-messages:
				if ok {
                	fmt.Printf("recv message: %d\n", val)
				} else {
                	fmt.Printf("chan has been closed\n")	
				}
			default:
                fmt.Printf("no chan event happen\n")	
            }
        }
    }()

	// 生产线程
	go func() {
		for i := 0; i < 10; i++ {
  			time.Sleep(1 * time.Second)
       		messages <- i
    	}

		//主动关闭生产者
		doneProducer<-true
	}()

    time.Sleep(5 * time.Second)

	// 等待消费者退出
	<-doneConsumer
	fmt.Println("main process exit!")
}

// 学习博客:https://blog.csdn.net/u012190809/article/details/107700495

WithValue

package main

import (
    "context"
    "fmt"
)

type test01String string
type test02String string

func main() {
    // ctx 为 emptyCtx
    // type emptyCtx int
    ctx := context.Background()
    test01(ctx)
}

func test01(ctx context.Context) {
    // ctx01 是 ctx 的 子ctx
    // ctx01 的 key 是 test01String("test01") ,
    // ctx01 的 key 的类型是 test01String , 键值是 "test01"
    // ctx01 的 value 的类型是 String , 实值是 "hello1"
    ctx01 := context.WithValue(ctx, test01String("test01"), "hello1")
    // 为 ctx01 再创建一个 子ctx
    test02(ctx01)
}

func test02(ctx context.Context) {
    // ctx02 是 ctx01 的 子ctx
    // ctx02 的 key 是 test02String("test02") ,
    // ctx02 的 key 的类型是 test02String , 键值是 "test02"
    // ctx02 的 value 的类型是 String , 实值是 "hello2"
    ctx02 := context.WithValue(ctx, test02String("test02"), "hello2")

    // 执行跟 ctx02 强相关的操作
    test03(ctx02)
}

func test03(ctx context.Context) {
    // 获取 ctx02 的 键值为 test01String("test01") 的 实值
    fmt.Println(ctx.Value(test01String("test01"))) // 打印 hello1
    // 获取 ctx02 的 键值为 test02String("test02") 的 实值
    fmt.Println(ctx.Value(test02String("test02"))) // 打印 hello2
    // 获取 ctx02 的 键值为 test02String("test03") 的 实值
    fmt.Println(ctx.Value(test02String("test03"))) // 打印 <nil>
}

//===========================================================================

WithCancel

package main
 
import (
    "context"
    "fmt"
    "time"
)
 
func main() {
    // 根ctx 是 context.Background() ,即 emptyCtx
    // 创建一个 ctx1 , 它的 父ctx 是 emptyCtx
    // type CancelFunc func() :这里的返回值 cancel 是 func() 类型
    ctx1, cancel := context.WithCancel(context.Background())
    // 开启一个 watch1 协程,则该协程的 根ctx 就是 ctx1
    go watch1(ctx1)

    fmt.Println("main: 现在开始等待10秒")
    time.Sleep(10 * time.Second)

    // emptyCtx 退出
    fmt.Println("main: 等待10秒结束,关闭ctx1前")
    cancel()
    fmt.Println("main: 等待10秒结束,关闭ctx1后")

    // 再等待5秒看输出,可以发现父context的子协程和子context的子协程都会被结束掉
    time.Sleep(5 * time.Second)
    fmt.Println("最终结束")
}

func watch1(ctx context.Context) {
    // 创建一个 ctx11 , 它的 父ctx 是 ctx1
    ctx11, _ := context.WithCancel(ctx)
    // 开启一个 watch11 协程,则该协程的 根ctx 就是 ctx11
    go watch11(ctx11)

    // 创建一个 ctx12 , 它的 父ctx 是 ctx1
    ctx12, _ := context.WithCancel(ctx)
    // 开启一个 watch12 协程,则该协程的 根ctx 就是 ctx12
    go watch12(ctx12)

    // 创建一个 ctx13 , 它的 父ctx 是 ctx1
    ctx13, _ := context.WithCancel(ctx)
    // 开启一个 watch13 协程,则该协程的 根ctx 就是 ctx13
    go watch13(ctx13)

    for {
        select {
        // 等待 ctx1 的退出信号
        // ctx.Done() 返回的是 当前ctx 的 chan 
        case <-ctx.Done(): //能取出值即说明是结束信号
            fmt.Println("ctx1: 收到信号,watch1的协程退出")
            return
        default:
            fmt.Println("ctx1: watch1的协程监控中")
            time.Sleep(1 * time.Second)
        }
    }
}

func watch11(ctx context.Context) {
    // 创建一个 ctx111 , 它的 父ctx 是 ctx11
    ctx111, _ := context.WithCancel(ctx)
    // 开启一个 watch111 协程,则该协程的 根ctx 就是 ctx111
    go watch111(ctx111)

    for {
        select {
        // 等待 ctx11 的退出信号
        // ctx.Done() 返回的是 当前ctx 的 chan 
        case <-ctx.Done(): //能取出值即说明是结束信号
            fmt.Println("ctx11: 收到信号, watch11的协程退出")
            return
        default:
            fmt.Println("ctx11: watch11的协程监控中")
            time.Sleep(1 * time.Second)
        }
    }
}

// watch12 在 6秒 退出
// watch12 退出时,注销 watch1 中的登记,退出原因:Canceled
func watch12(ctx context.Context) {
    // 创建一个 ctx121 , 它的 父ctx 是 ctx12
    ctx121, cancel := context.WithCancel(ctx)
    // 开启一个 watch121 协程,则该协程的 根ctx 就是 ctx121
    go watch121(ctx121)

    i := 0
    flag := 0
    for {
        select {
        // 等待 ctx12 的退出信号
        // ctx.Done() 返回的是 当前ctx 的 chan 
        case <-ctx.Done(): //能取出值即说明是结束信号
            fmt.Println("ctx12: 收到信号, watch12的协程退出")
            return
        default:
            fmt.Println("ctx12: watch12的协程监控中")
            time.Sleep(1 * time.Second)
            if i++; i == 6 && flag == 0 {
                fmt.Println("ctx12: watch12的children在6秒后主动cancel前")
                cancel()
                flag = 1
                fmt.Println("ctx12: watch12的children在6秒后主动cancel后")
            }
        }
    }
}

// watch13 在 8秒 退出
// watch13 退出时,注销 watch1 中的登记,退出原因:Canceled
func watch13(ctx context.Context) {
    // 创建一个 ctx131 , 它的 父ctx 是 ctx13
    ctx131, cancel := context.WithCancel(ctx)
    // 开启一个 watch131 协程,则该协程的 根ctx 就是 ctx131
    go watch131(ctx131)

    i := 0
    flag := 0
    for {
        select {
        // 等待 ctx13 的退出信号
        // ctx.Done() 返回的是 当前ctx 的 chan 
        case <-ctx.Done(): //能取出值即说明是结束信号
            fmt.Println("ctx13: 收到信号, watch13的协程退出")
            return
        default:
            fmt.Println("ctx13: watch13的协程监控中")
            time.Sleep(1 * time.Second)
            if i++; i == 8 && flag == 0 {
                fmt.Println("ctx13: watch13的children在8秒后主动cancel前")
                cancel()
                fmt.Println("ctx13: watch13的children在8秒后主动cancel后")
				// 注意,此时不能 return
            }
        }
    }
}

// watch111 等待 ctx11 挨个注销children 时才退出  
func watch111(ctx context.Context) {
    for {
        select {
        case <-ctx.Done():
            fmt.Println("ctx111: 收到信号,watch111的协程退出")
            return
        default:
            fmt.Println("ctx111: watch111的协程监控中")
            time.Sleep(1 * time.Second)
        }
    }
}

// watch121 等待 ctx12 挨个注销children 时才退出  
func watch121(ctx context.Context) {
    for {
        select {
        case <-ctx.Done():
            fmt.Println("ctx121: 收到信号,watch121的协程退出")
            return
        default:
            fmt.Println("ctx121: watch121的协程监控中")
            time.Sleep(1 * time.Second)
        }
    }
}

// watch131 等待 ctx13 挨个注销children 时才退出  
func watch131(ctx context.Context) {
    for {
        select {
        case <-ctx.Done():
            fmt.Println("ctx131: 收到信号,watch131的协程退出")
            return
        default:
            fmt.Println("ctx131: watch131的协程监控中")
            time.Sleep(1 * time.Second)
        }
    }
}
main: 现在开始等待10秒
... ...
ctx12: watch12的children在6秒后主动cancel前
ctx12: watch12的children在6秒后主动cancel后
... ...
ctx121: 收到信号,watch121的协程退出
... ...
ctx13: watch13的children在8秒后主动cancel前
ctx13: watch13的children在8秒后主动cancel后
... ...
ctx131: 收到信号,watch131的协程退出
... ...
main: 等待10秒结束,关闭ctx1前
main: 等待10秒结束,关闭ctx1后
ctx1: 收到信号,watch1的协程退出
ctx11: 收到信号, watch11的协程退出
ctx13: 收到信号, watch13的协程退出
ctx111: 收到信号,watch111的协程退出
ctx12: 收到信号, watch12的协程退出

本地尝试,将 context.Context (interface) 转换为 cancelCtx (struct), 再转为 canceler (interface)
发现不行,因为 canceler 是私有接口
只要能拿到 cancelCtx ,就能调用 cancelCtx.cancel(true/false, err)方法
但是,context.WithCancel(ctx) 原型是 func WithCancel(parent Context) (ctx Context, cancel CancelFunc)
形参是 context.Context (interface)类型,返回值也是 context.Context

//===========================================================================

WithTimeout

package main

import (
    "context"
    "fmt"
    "time"
)

func longRunningCalculation(timeCost int) chan string {
    result:=make(chan string)
    go func (){
        time.Sleep(time.Second*(time.Duration(timeCost)))
        result<-"Done"
    }()

    return result
}

func main() {
	// 创建一个 ctx , 继承自 context.Background()
	// 该ctx 的超时时间是6秒
    ctx,cancel := context.WithTimeout(context.Background(), 6 * time.Second)
    defer cancel()

	// 如果一个函数执行时长超过6秒,则ctx自动退出
	// 如果一个函数执行时长没有超过6秒,则正常打印执行结果
    select{
        case <-ctx.Done():
            fmt.Println(ctx.Err())
            return
        case result:=<-longRunningCalculation(5):
            fmt.Println("Answer is ", result)
    }
    return
}

结果:

(1) Answer is  Done
(2) context deadline exceeded

//===========================================================================

WithDeadline

package main

import (
    "context"
    "fmt"
    "time"
)

func longRunningCalculation(timeCost int) chan string {
    result:=make(chan string)
    go func (){
        time.Sleep(time.Second*(time.Duration(timeCost)))
        result<-"Done"
    }()

    return result
}

func main() {
    ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(time.Second*3))
    defer cancel()

    select{
        case <-ctx.Done():
            fmt.Println(ctx.Err())
            return
        case result:=<-longRunningCalculation(5):
            fmt.Println("Answer is ", result)
    }
    return
}

//===========================================================================

生产者-消费者

package main

import "fmt"
import "time"
import "context"

func main() {
    messages := make(chan int, 10)
    ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)

    defer close(messages)
    defer cancel()

    // 消费线程
    go func(ctx context.Context) {
        ticker := time.NewTicker(time.Millisecond * 400)
        for _ = range ticker.C {
            select {
            case val, ok := <-messages:
                if ok {
                    fmt.Printf("recv message: %d\n", val)
                } else {
                    fmt.Printf("chan has been closed\n")    
                }
            default:
                fmt.Printf("no chan event happen\n")    
            }
        }
    }(ctx)

    // 生产线程
    go func() {
        for i := 0; i < 10; i++ {
            time.Sleep(1 * time.Second)
            messages <- i
        }
    }()

    // 主线程阻塞在 ctx.Done() 上面
    select {
    case <-ctx.Done():
        fmt.Println("child process interrupt...")
		// 把 chan 中的数据全部拿完
        for i := 0; i < len(messages); i++ {
            fmt.Printf("recv message: %d\n", <-messages)
        }
        time.Sleep(1 * time.Second)
        fmt.Println("main process exit!")
    }

    return
}

好奇,两处地方都能接收到信号

import "fmt"
import "time"
import "context"

func main() {
    messages := make(chan int, 10)
    ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)

    defer close(messages)
    defer cancel()

    // 消费线程
    go func(ctx context.Context) {
        ticker := time.NewTicker(time.Millisecond * 400)
        for _ = range ticker.C {
            select {
            case <-ctx.Done(): //这里能收到
                for i := 0; i < len(messages); i++ {
                    fmt.Printf("recv message: %d\n", <-messages)
                }
                fmt.Println("child process interrupt...")
            case val, ok := <-messages:
                if ok {
                    fmt.Printf("recv message: %d\n", val)
                } else {
                    fmt.Printf("chan has been closed\n")    
                }
            default:
                fmt.Printf("no chan event happen\n")    
            }
        }
    }(ctx)

    // 生产线程
    go func() {
        for i := 0; i < 10; i++ {
            time.Sleep(1 * time.Second)
            messages <- i
        }
    }()

    select {
    case <-ctx.Done(): //这里能收到
		// 如果把 sleep 放后面,go func(ctx context.Context) 中能打印好几个 interrupt
        time.Sleep(1 * time.Second)
        fmt.Println("main process exit!")
    }

    return
}

附件

查看 channel 中未读数据的长度

// 学习博客:https://blog.csdn.net/lanyang123456/article/details/83096127

func len(v Type) int

The len built-in function returns the length of v, according to its type:
Array: 
the number of elements in v

Pointer to array: 
the number of elements in *v (even if v is nil).

Slice, or map: 
the number of elements in v; if v is nil, len(v) is zero.

String: 
the number of bytes in v.

Channel: 
the number of elements queued (unread) in the channel buffer; if v is nil, len(v) is zero.

猜你喜欢

转载自blog.csdn.net/wangkai6666/article/details/121180461