Go HTTP 中间件

在 Go 中实现一个 HTTP 中间件并不难，采用洋葱模型即可。本文代码参考了 https://github.com/urfave/negroni 。

中间件的设计如下。

首先，定义一个 Handler 接口，以及把普通函数适配为该接口的函数类型 HandlerFunc：

// Handler is one layer of the middleware onion. ServeHTTP receives the
// response writer, the request, and next, which runs the rest of the chain.
// Whatever the layer does before calling next happens on the way in; whatever
// it does after next returns happens on the way out.
type Handler interface {
	ServeHTTP(w http.ResponseWriter, req *http.Request, next http.HandlerFunc)
}

// HandlerFunc lets an ordinary function with the Handler signature be used as
// a Handler, just as http.HandlerFunc does for http.Handler.
type HandlerFunc func(w http.ResponseWriter, req *http.Request, next http.HandlerFunc)

// ServeHTTP calls f(w, req, next).
func (f HandlerFunc) ServeHTTP(w http.ResponseWriter, req *http.Request, next http.HandlerFunc) {
	f(w, req, next)
}

然后定义 middleware 结构体，它通过 next 指针把各层 Handler 串成一条链：

// middleware is one link in a singly linked chain of Handlers: handler is the
// layer at this position and next is the link for the layer inside it.
type middleware struct {
	handler Handler
	next    *middleware
}

// ServeHTTP implements http.Handler. It runs handler and passes the rest of
// the chain to it as next.
//
// A nil next is treated as the end of the chain (a no-op), and a nil handler
// simply forwards the request to next. The nil check on next matters because
// ServeHTTP has a value receiver. Forming the method value mw.next.ServeHTTP
// from a nil pointer would dereference it and panic, even if the handler
// never calls next.
func (mw middleware) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
	next := http.HandlerFunc(func(http.ResponseWriter, *http.Request) {})
	if mw.next != nil {
		next = mw.next.ServeHTTP
	}
	if mw.handler == nil {
		next(rw, r)
		return
	}
	mw.handler.ServeHTTP(rw, r, next)
}

它实现了标准库 net/http（server.go）中的 Handler 接口：

// Handler as declared in the standard library's net/http package (server.go),
// quoted here for reference. Any type with this ServeHTTP method can be used
// to serve HTTP requests.
type Handler interface {
	ServeHTTP(ResponseWriter, *Request)
}

最后定义中间件的整体结构 HTTPMiddleware，它同样实现了 http.Handler 接口，因此可以直接交给 httptest.NewServer 等使用：

// HTTPMiddleware is the assembled middleware onion and implements
// http.Handler. handlers holds the Handlers it was built from; mw is the head
// of the linked chain that actually serves requests. Construct it with New.
type HTTPMiddleware struct {
	handlers []Handler
	mw       *middleware
}

// ServeHTTP implements http.Handler by running the chain from its outermost
// layer inwards. A nil *HTTPMiddleware, or a zero HTTPMiddleware with no chain
// built, behaves like an empty chain and writes nothing instead of panicking.
func (hmw *HTTPMiddleware) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
	if hmw == nil || hmw.mw == nil {
		return
	}
	hmw.mw.ServeHTTP(rw, r)
}

接下来初始化这个中间件。注意，链的末尾需要加一个 emptyMiddleware 作为结尾，保证每一层都有可调用的 next。coreMiddleware 指位于最里层的中间件，通常也就是包含业务逻辑的函数；它传给内部 Handler 的 next 是一个空函数，因此最里层调用 next 不会产生任何效果。

// New assembles handlers into a middleware onion. The chain is built once,
// here. handlers[0] becomes the outermost layer and the last handler becomes
// the innermost (core) layer.
func New(handlers ...Handler) *HTTPMiddleware {
	chain := build(handlers)
	hmw := &HTTPMiddleware{handlers: handlers, mw: chain}
	return hmw
}

// build links handlers into a middleware chain and returns its head.
// With no handlers the chain is just the empty middleware. Otherwise the last
// handler becomes the core middleware, and each earlier handler is wrapped
// around the link built so far, working from the back of the slice forwards.
func build(handlers []Handler) *middleware {
	if len(handlers) == 0 {
		return newEmptyMiddleWare()
	}
	last := len(handlers) - 1
	chain := newCoreMiddleWare(handlers[last])
	for i := last - 1; i >= 0; i-- {
		chain = &middleware{handler: handlers[i], next: chain}
	}
	return chain
}

// newEmptyMiddleWare returns the terminal link. It serves as the whole chain
// when there are no handlers, and it sits behind the core middleware.
// Its handler ignores its arguments and never calls next, so it writes
// nothing. next is a non-nil zero middleware so that forming the method value
// mw.next.ServeHTTP never dereferences nil. That zero middleware is never
// invoked.
func newEmptyMiddleWare() *middleware {
	terminal := &middleware{next: new(middleware)}
	terminal.handler = HandlerFunc(func(http.ResponseWriter, *http.Request, http.HandlerFunc) {})
	return terminal
}

// newCoreMiddleWare wraps the innermost handler of the chain. The wrapped
// handler receives a no-op next, so calling next from the core layer does
// nothing. The outer next is ignored, so the empty middleware chained after
// the core is never reached.
func newCoreMiddleWare(handler Handler) *middleware {
	noopNext := func(http.ResponseWriter, *http.Request) {}
	core := func(w http.ResponseWriter, req *http.Request, _ http.HandlerFunc) {
		handler.ServeHTTP(w, req, noopNext)
	}
	return &middleware{handler: HandlerFunc(core), next: newEmptyMiddleWare()}
}

再提供一个 Wrap 函数，把已有的 http.HandlerFunc 转换成 Handler：被包装的函数执行完毕后，会继续调用 next。

// Wrap adapts a plain http.HandlerFunc into a Handler. The wrapped function
// runs first, and the rest of the chain is then invoked through next. A
// wrapped function therefore works both as the core layer and in the middle
// of the chain.
func Wrap(httpHandlerFunc http.HandlerFunc) Handler {
	return HandlerFunc(func(rw http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
		httpHandlerFunc(rw, r)
		next(rw, r)
	})
}

最后编写一些测试用例，验证各层的执行顺序：

// TestFuncTest checks the onion ordering of the middleware chain. Every layer
// writes a marker before and after calling next, so the response body records
// the exact order in which the layers ran.
func TestFuncTest(t *testing.T) {
	fn1 := func(rw http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
		fmt.Println("fn1 before")
		_, _ = rw.Write([]byte("fn1b "))
		next(rw, r)
		_, _ = rw.Write([]byte("fn1a "))
		fmt.Println("fn1 after")
	}
	fn2 := func(rw http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
		fmt.Println("fn2 before")
		_, _ = rw.Write([]byte("fn2b "))
		next(rw, r)
		_, _ = rw.Write([]byte("fn2a "))
		fmt.Println("fn2 after")
	}
	fn3 := func(rw http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
		fmt.Println("fn3 before")
		_, _ = rw.Write([]byte("fn3b "))
		next(rw, r)
		_, _ = rw.Write([]byte("fn3a "))
		fmt.Println("fn3 after")
	}

	fnBuzz := func(rw http.ResponseWriter, r *http.Request) {
		_, _ = rw.Write([]byte("buzz "))
		fmt.Println("buzz logic")
	}

	tableTests := []struct {
		name        string
		expectedStr string
		handlers    []Handler
	}{
		{
			name:        "normal test",
			expectedStr: "fn1b fn2b fn3b fn3a fn2a fn1a ",
			handlers:    []Handler{HandlerFunc(fn1), HandlerFunc(fn2), HandlerFunc(fn3)},
		},
		{
			name:        "empty test",
			expectedStr: "",
			handlers:    nil,
		},
		{
			name:        "with buzz",
			expectedStr: "fn1b fn2b buzz fn2a fn1a ",
			handlers:    []Handler{HandlerFunc(fn1), HandlerFunc(fn2), Wrap(fnBuzz)},
		},
	}

	for _, tt := range tableTests {
		t.Run(tt.name, func(t *testing.T) {
			// Each subtest owns its server and response body. Its defers run
			// when the subtest ends, not when the whole test function returns,
			// and a failure here does not abort the remaining cases.
			ts := httptest.NewServer(New(tt.handlers...))
			defer ts.Close()

			res, err := http.Get(ts.URL)
			if err != nil {
				t.Fatal(err)
			}
			defer res.Body.Close()

			if res.StatusCode != http.StatusOK {
				t.Fatalf("unexpected status %d", res.StatusCode)
			}
			body, err := ioutil.ReadAll(res.Body)
			if err != nil {
				t.Fatalf("reading response body: %v", err)
			}
			// %q makes empty bodies and trailing spaces visible in failures.
			if got := string(body); got != tt.expectedStr {
				t.Errorf("expected %q, but got %q", tt.expectedStr, got)
			}
		})
	}
}

完整代码如下：

// Handler is one layer of the middleware onion. ServeHTTP receives the
// response writer, the request, and next, which runs every layer inside this
// one. Code before the call to next runs on the way in; code after it runs on
// the way out.
type Handler interface {
	ServeHTTP(rw http.ResponseWriter, r *http.Request, next http.HandlerFunc)
}

// HandlerFunc adapts an ordinary function to the Handler interface.
type HandlerFunc func(rw http.ResponseWriter, r *http.Request, next http.HandlerFunc)

// ServeHTTP calls h(rw, r, next).
func (h HandlerFunc) ServeHTTP(rw http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
	h(rw, r, next)
}

// middleware is one link in a singly linked chain of Handlers.
type middleware struct {
	handler Handler
	next    *middleware
}

// ServeHTTP implements http.Handler by running handler with the rest of the
// chain as next. A nil next is treated as a no-op end of chain, and a nil
// handler forwards straight to next. The nil check matters because ServeHTTP
// has a value receiver. Forming the method value mw.next.ServeHTTP from a nil
// pointer would panic even if next were never called.
func (mw middleware) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
	next := http.HandlerFunc(func(http.ResponseWriter, *http.Request) {})
	if mw.next != nil {
		next = mw.next.ServeHTTP
	}
	if mw.handler == nil {
		next(rw, r)
		return
	}
	mw.handler.ServeHTTP(rw, r, next)
}

// HTTPMiddleware is the assembled chain and implements http.Handler.
// Construct it with New.
type HTTPMiddleware struct {
	handlers []Handler
	mw       *middleware
}

// ServeHTTP runs the chain from the outermost layer inwards. A nil
// *HTTPMiddleware or a zero HTTPMiddleware behaves like an empty chain and
// writes nothing instead of panicking.
func (hmw *HTTPMiddleware) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
	if hmw == nil || hmw.mw == nil {
		return
	}
	hmw.mw.ServeHTTP(rw, r)
}

// New builds the chain once. handlers[0] is the outermost layer and the last
// handler is the innermost (core) layer.
func New(handlers ...Handler) *HTTPMiddleware {
	return &HTTPMiddleware{
		handlers: handlers,
		mw:       build(handlers),
	}
}

// Wrap adapts an http.HandlerFunc to a Handler: the function runs first, and
// then the rest of the chain is invoked via next.
func Wrap(httpHandlerFunc http.HandlerFunc) Handler {
	fn := func(rw http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
		httpHandlerFunc(rw, r)
		next(rw, r)
	}
	return HandlerFunc(fn)
}

// build links handlers into a chain and returns its head. An empty list
// yields the empty middleware. Otherwise the last handler becomes the core
// middleware and every earlier handler wraps the one after it.
func build(handlers []Handler) *middleware {
	if len(handlers) == 0 {
		return newEmptyMiddleWare()
	}
	if len(handlers) == 1 {
		return newCoreMiddleWare(handlers[0])
	}
	return &middleware{
		handler: handlers[0],
		next:    build(handlers[1:]),
	}
}

// newEmptyMiddleWare returns a terminal link whose handler does nothing and
// never calls next. Its next is a non-nil zero middleware that is never
// invoked.
func newEmptyMiddleWare() *middleware {
	fn := func(rw http.ResponseWriter, r *http.Request, next http.HandlerFunc) {}
	return &middleware{handler: HandlerFunc(fn), next: &middleware{}}
}

// newCoreMiddleWare wraps the innermost handler, handing it a no-op next so
// that nothing runs beneath it.
func newCoreMiddleWare(handler Handler) *middleware {
	fn := func(rw http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
		handler.ServeHTTP(rw, r, func(http.ResponseWriter, *http.Request) {})
	}
	return &middleware{handler: HandlerFunc(fn), next: newEmptyMiddleWare()}
}