Golang: Basic Usage of Context

This post covers only how to use Context; it does not analyze the source code.


Programming Guidelines

  • Do not store a Context in a struct. Pass it explicitly to each function that needs it, as the first parameter, conventionally named ctx:
func DoSomething(ctx context.Context, arg Arg) error {
	// ... use ctx ...
}
  • Do not pass a nil Context, even if a function permits it; pass context.TODO if you are not sure which Context to use.
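  • Do not use Context to pass function parameters; use context values only for request-scoped data that crosses process and API boundaries.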

Problems Context Aims to Solve

  • A goroutine’s parent may want to cancel it.
  • A goroutine may want to cancel its children.
  • Any blocking operations within a goroutine need to be preemptable (interruptible) so that it may be canceled.

Context in the Standard Library

Exported API

var Canceled = errors.New("context canceled")
var DeadlineExceeded error = deadlineExceededError{}
type CancelFunc
type Context
func Background() Context
func TODO() Context
func WithCancel(parent Context) (ctx Context, cancel CancelFunc)
func WithDeadline(parent Context, deadline time.Time) (Context, CancelFunc)
func WithTimeout(parent Context, timeout time.Duration) (Context, CancelFunc)
func WithValue(parent Context, key, val interface{}) Context

Implementations of the Context Interface

type emptyCtx int

type cancelCtx struct {
	Context
	// ...
}

type timerCtx struct {
	cancelCtx
	// ...
}

type valueCtx struct {
	Context
	// ...
}

Example Code

cancelCtx Example

Imagine a goroutine that starts several child goroutines. If any child fails, we want every child that is still running to return immediately.

To do this, run all of them under the same context:

package main

import (
	"context"
	"fmt"
	"sync"
	"time"
)

func main() {
	var wg sync.WaitGroup
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	wg.Add(5)
	for i := 0; i < 5; i++ {
		go func(id int) {
			defer wg.Done()
			if err := worker(ctx, id, time.Duration(id)*time.Second); err != nil {
				fmt.Println(err)
				cancel() // the first failure cancels every worker sharing ctx
			}
		}(i)
	}
	wg.Wait()
}

func worker(ctx context.Context, id int, timeUse time.Duration) error {
	fmt.Printf("[worker %d] got job: %v\n", id, timeUse)

	// Make one worker fail on purpose
	if id == 2 {
		time.Sleep(2 * time.Second)
		return fmt.Errorf("[worker %d] something went wrong", id)
	}

	select {
	case <-ctx.Done(): // canceled because a sibling failed
		return fmt.Errorf("[worker %d] quit : %v", id, ctx.Err())
	case <-time.After(timeUse): // the job finished in time
	}
	fmt.Printf("[worker %d] finish job\n", id)
	return nil
}

The output is:

[worker 4] got job: 4s
[worker 3] got job: 3s
[worker 2] got job: 2s
[worker 1] got job: 1s
[worker 0] got job: 0s
[worker 0] finish job
[worker 1] finish job
[worker 2] something went wrong
[worker 3] quit : context canceled
[worker 4] quit : context canceled
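The example above only prints each error. When the caller needs the first error back, one possible pattern (a sketch, not the article's code) records it with sync.Once before calling cancel. Here runAll and failID are hypothetical; the failing worker errors immediately and the others block on ctx.Done(), which keeps the outcome deterministic. The golang.org/x/sync/errgroup package (errgroup.WithContext) offers a ready-made version of this pattern.

```go
package main

import (
	"context"
	"fmt"
	"sync"
)

// runAll (hypothetical) starts n workers under one cancelable context.
// Worker failID fails immediately; the others block until the context is
// canceled. It returns the first error and how many workers were canceled.
func runAll(n, failID int) (firstErr error, canceled int) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	var (
		wg   sync.WaitGroup
		once sync.Once
		mu   sync.Mutex
	)
	for id := 0; id < n; id++ {
		wg.Add(1)
		go func(id int) {
			defer wg.Done()
			if id == failID {
				once.Do(func() { firstErr = fmt.Errorf("worker %d: something went wrong", id) })
				cancel() // tell every sibling to stop
				return
			}
			<-ctx.Done() // a real worker would select on its job and ctx.Done()
			mu.Lock()
			canceled++
			mu.Unlock()
		}(id)
	}
	wg.Wait() // firstErr and canceled are safe to read after Wait
	return firstErr, canceled
}

func main() {
	err, canceled := runAll(5, 2)
	fmt.Println(err)      // worker 2: something went wrong
	fmt.Println(canceled) // 4
}
```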

timerCtx Example

Same scenario as above, but now every worker must finish its job within 1 s; a worker that runs past the limit simply quits. The code:

package main

import (
	"context"
	"fmt"
	"sync"
	"time"
)

func main() {
	var wg sync.WaitGroup
	// Every worker shares a 1-second deadline
	ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
	defer cancel()
	workerNum := 4

	wg.Add(workerNum)
	for i := 0; i < workerNum; i++ {
		go func(id int) {
			defer wg.Done()
			if err := worker(ctx, id, time.Duration(id)*time.Second); err != nil {
				fmt.Println(err)
				cancel()
			}
		}(i)
	}
	wg.Wait()
}

func worker(ctx context.Context, id int, timeUse time.Duration) error {
	fmt.Printf("[worker %d] got job: %v\n", id, timeUse)

	select {
	case <-ctx.Done(): // the shared deadline passed first
		return fmt.Errorf("[worker %d] quit : %v", id, ctx.Err())
	case <-time.After(timeUse): // the job finished in time
	}
	fmt.Printf("[worker %d] finish job\n", id)
	return nil
}

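The output is as follows. Worker 1's job takes exactly as long as the timeout, so whether it finishes or quits with context deadline exceeded depends on timer scheduling and may differ between runs: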

[worker 0] got job: 0s
[worker 1] got job: 1s
[worker 0] finish job
[worker 3] got job: 3s
[worker 2] got job: 2s
[worker 1] finish job
[worker 2] quit : context deadline exceeded
[worker 3] quit : context deadline exceeded

Another example, from sohamkamani/blog-example-go-context-cancellation, combines Context with Go's net/http package to put a timeout on an HTTP request:

package main

import (
	"context"
	"fmt"
	"net/http"
	"time"
)

func main() {
	// Create a context with a deadline of 100 milliseconds.
	// Always call cancel so the context's timer is released.
	ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
	defer cancel()

	// Build a request to the Baidu homepage, bound to the timeout context
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, "https://baidu.com", nil)
	if err != nil {
		fmt.Println("Cannot build request:", err)
		return
	}

	// Create a new HTTP client and execute the request
	client := &http.Client{}
	res, err := client.Do(req)
	// If the request failed (including by timing out), log to STDOUT
	if err != nil {
		fmt.Println("Request failed:", err)
		return
	}
	defer res.Body.Close()
	// Print the status code if the request succeeds
	fmt.Println("Response received, status code:", res.StatusCode)
}

Context Inheritance

This example comes from the book Concurrency in Go.

package main

import (
	"context"
	"fmt"
	"sync"
	"time"
)

// The context inheritance tree
func main() {
	var wg sync.WaitGroup
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	wg.Add(1)
	go func() {
		defer wg.Done()
		if err := printGreet(ctx); err != nil {
			fmt.Printf("cannot print greeting: %v\n", err)
			cancel() // on error, cancel every function running under this ctx
		}
	}()

	wg.Add(1)
	go func() {
		defer wg.Done()
		if err := printFarewell(ctx); err != nil {
			fmt.Printf("cannot print farewell: %v\n", err)
		}
	}()

	wg.Wait()
}

func printGreet(ctx context.Context) error {
	greeting, err := genGreeting(ctx)
	if err != nil {
		return err
	}
	fmt.Printf("%s world!\n", greeting)
	return nil
}

func printFarewell(ctx context.Context) error {
	farewell, err := genFarewell(ctx)
	if err != nil {
		return err
	}
	fmt.Printf("%s world!\n", farewell)
	return nil
}

func genGreeting(ctx context.Context) (string, error) {
	// Derive a child context with a 1-second timeout;
	// when time is up, locale returns an error
	ctx, cancel := context.WithTimeout(ctx, 1*time.Second)
	defer cancel()

	switch locale, err := locale(ctx); {
	case err != nil:
		return "", err
	case locale == "ZH/CN":
		return "你好", nil
	}
	return "", fmt.Errorf("unsupported locale")
}

func genFarewell(ctx context.Context) (string, error) {
	switch locale, err := locale(ctx); {
	case err != nil:
		return "", err
	case locale == "ZH/CN":
		return "再见", nil
	}
	return "", fmt.Errorf("unsupported locale")
}

// locale simulates a blocking call that takes about one minute
func locale(ctx context.Context) (string, error) {
	// Fail fast: if ctx has a deadline that arrives before this call
	// could finish, give up right away
	if deadline, ok := ctx.Deadline(); ok {
		if deadline.Sub(time.Now().Add(1*time.Minute)) <= 0 {
			return "", context.DeadlineExceeded
		}
	}

	select {
	case <-ctx.Done(): // canceled: deadline reached or cancel() called
		return "", ctx.Err()
	case <-time.After(1 * time.Minute): // the simulated work
	}
	return "ZH/CN", nil
}

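The output is below. locale needs about one minute, so against genGreeting's 1 s deadline it fails immediately with context deadline exceeded; main then calls cancel(), and printFarewell, which runs under the same parent context, returns context canceled: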

cannot print greeting: context deadline exceeded
cannot print farewell: context canceled

Figure: program execution flow chart