中间件

目标

  • 设计并实现 Web 框架的中间件机制。
  • 实现通用的Logger中间件,能够记录请求到响应所花费的时间。

中间件的概念

中间件,就是非业务的技术类组件。Web 框架本身不可能去理解所有的业务,因而不可能实现所有的功能。因此,框架需要有一个插口,允许用户自己定义功能,嵌入到框架中,仿佛这个功能是框架原生支持的一样。因此,对中间件而言,需要考虑2个比较关键的点:

  • 插入点在哪?使用框架的人并不关心底层逻辑的具体实现,如果插入点太底层,中间件逻辑就会非常复杂。如果插入点离用户太近,那和用户直接定义一组函数,每次在 Handler 中手工调用没有多大的优势了。
  • 中间件的输入是什么?中间件的输入,决定了扩展能力。暴露的参数太少,用户发挥空间有限。

那对于一个 Web 框架而言,中间件应该设计成什么样呢?接下来的实现,基本参考了 Gin 框架。

中间件设计

Gee 的中间件的定义与路由映射的 Handler 一致,处理的输入是Context对象。插入点是框架接收到请求初始化Context对象后,允许用户使用自己定义的中间件做一些额外的处理,例如记录日志等,以及对Context进行二次加工。另外通过调用(*Context).Next()函数,中间件可等待用户自己定义的 Handler处理结束后,再做一些额外的操作,例如计算本次处理所用时间等。即 Gee 的中间件支持用户在请求被处理的前后,做一些额外的操作。举个例子,我们希望最终能够支持如下定义的中间件,c.Next()表示等待执行其他的中间件或用户的Handler

1
2
3
4
5
6
7
8
9
10
func Logger() HandlerFunc {
return func(c *Context) {
// 开始时间
t := time.Now()
// 执行next
c.Next()
// 计算总时间
log.Printf("[%d] %s in %v", c.StatusCode, c.Req.RequestURI, time.Since(t))
}
}

中间件的调用顺序以及介绍可以看这篇文章

gin框架注册中间件

在gin框架中,我们可以为每个路由添加任意数量的中间件。

image-20221129142317272

image-20221129142332221

image-20221129142348794

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
package main

import (
"fmt"
"github.com/gin-gonic/gin"
"net/http"
)

func indexHandler(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{
"msg": "index",
})
println("index")
}

// 定义一个中间件
func m2(c *gin.Context) {
fmt.Println("m2 in ...")
c.Next() // 调用后续的处理函数
//c.Abort() // 阻止调用后续的处理函数
fmt.Println("m2 out ...")
}

// 定义一个中间件
func m1(c *gin.Context) {
fmt.Println("m1 in ...")
c.Next() // 调用后续的处理函数
//c.Abort() // 阻止调用后续的处理函数
fmt.Println("m1 out ...")
}

func main() {
r := gin.Default()

r.GET("/index", m1, m2, indexHandler)

r.Run(":9090")
}

上程序运行结果为:

image-20221129142506736

此Web框架中间件设计

需要支持设置多个中间件,依次进行调用。

设计思路:当接收到请求后,匹配路由,该请求的所有信息都保存在Context中。中间件也不例外,接收到请求后,应查找所有应作用于该路由的中间件,保存在Context中,依次进行调用。

context.go

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
type Context struct {
// origin objects
Writer http.ResponseWriter
Req *http.Request
// request info
Path string
Method string
Params map[string]string
// response info
StatusCode int
// middleware
handlers []HandlerFunc
index int // 记录当前执行到第几个中间件
}

func newContext(w http.ResponseWriter, req *http.Request) *Context {
return &Context{
Path: req.URL.Path,
Method: req.Method,
Req: req,
Writer: w,
index: -1,
}
}

// 依次调用中间件
func (c *Context) Next() {
c.index++
s := len(c.handlers)
for ; c.index < s; c.index++ {
c.handlers[c.index](c)
}
}

看一个例子:

1
2
3
4
5
6
7
8
9
10
func A(c *Context) {
part1
c.Next()
part2
}
func B(c *Context) {
part3
c.Next()
part4
}

上述代码执行顺序就是part1 -> part3 -> Handler -> part 4 -> part2

代码实现

将中间件应用到某个 Group

gee.go

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
// 将中间件应用到某个 Group 
func (group *RouterGroup) Use(middlewares ...HandlerFunc) {
group.middlewares = append(group.middlewares, middlewares...)
}

//ServeHTTP 函数当接收到一个具体请求时,要判断该请求适用于哪些中间件
func (engine *Engine) ServeHTTP(w http.ResponseWriter, req *http.Request) {
var middlewares []HandlerFunc
// 在这里我们简单通过 URL 的前缀来判断。
for _, group := range engine.groups {
if strings.HasPrefix(req.URL.Path, group.prefix) {
middlewares = append(middlewares, group.middlewares...)
}
}
c := newContext(w, req)
c.handlers = middlewares // 得到中间件列表后,赋值给 c.handlers。
engine.router.handle(c)
}

router.go

1
2
3
4
5
6
7
8
9
10
11
12
13
14
func (r *router) handle(c *Context) {
n, params := r.getRoute(c.Method, c.Path)

if n != nil {
key := c.Method + "-" + n.pattern
c.Params = params
c.handlers = append(c.handlers, r.handlers[key])
} else {
c.handlers = append(c.handlers, func(c *Context) {
c.String(http.StatusNotFound, "404 NOT FOUND: %s\n", c.Path)
})
}
c.Next()
}

代码地址