相关推荐recommended
Python FastAPI系列:自定义FastAPI middleware中间件
作者:mmseoamin日期:2024-02-02

Python FastAPI系列:自定义FastAPI middleware中间件

    • FastAPI middleware中间件执行逻辑
    • 创建FastAPI middleware中间件
      • 使用装饰器创建中间件
      • 通过继承BaseHTTPMiddleware创建中间件
      • 根据ASGI规范创建中间件

        ​​

        在一些情况下,我们需要对整个FastAPI应用的全部或部分路由执行一些通用的功能,例如身份验证、日志记录、错误处理等,我们可以通过自定义FastAPI middleware中间件来完成。在FastAPI中也自带了一些常用的中间件来完成请求协议限定、跨域提交等。

        一般情况下,碰到以下需求场景时,可以考虑使用FastAPI middleware来实现:

        1、身份验证:验证请求的身份,例如检查 JWT token 或使用 OAuth2 进行验证。

        2、日志记录:记录请求和响应的日志,包括请求方法、URL、响应状态码等信息。

        3、错误处理:处理应用程序中的异常情况,例如捕获异常并返回自定义的错误响应。

        4、请求处理:对请求进行处理,例如解析请求参数、验证请求数据等。

        5、缓存:可以使用中间件来实现缓存功能,例如在中间件中检查缓存中是否存在请求的响应,如果存在则直接返回缓存的响应。

        FastAPI middleware中间件执行逻辑

        FastAPI middleware中间件工作在每一次的Request请求和Response响应之间,可以对Request和Response进行修改,其详细过程如下:

        1、接收来自客户端的Request请求;

        2、针对该次Request请求,自定义操作;

        3、然后将Request请求传回原路由,由路由中定义的业务逻辑继续处理该次Request请求;

        4、原路由业务处理完毕后,FastAPI middleware中间件将获得该路由产生的Response响应结果,此时可以针对Response响应结果自定义操作;

        5、最后将Response响应结果发送给客户端;

        创建FastAPI middleware中间件

        使用装饰器创建中间件

        在FastAPI应用中可以使用app.middleware(“http”)装饰器创建FastAPI middleware中间件。在以下例子中,我们在程序进入middleware时记录一对开始结束时间作为middleware时间,在程序进入路由业务逻辑中记录一对开始结束时间作为router时间,其中router开始时间为获取middleware写入request的middleware时间基础上加1小时,详细代码如下:

        main.py

        import time
        from datetime import datetime, timedelta
        import uvicorn as uvicorn
        from fastapi import FastAPI
        from starlette.requests import Request
        from starlette.responses import Response
        app = FastAPI()
        # 将时间格式化为字符串
        def _time2str(time_str):
            return datetime.strftime(time_str, '%Y-%m-%d %H:%M:%S')
        # 将字符串转换为时间
        def _str2time(time_str):
            return datetime.strptime(time_str, '%Y-%m-%d %H:%M:%S')
        # 定义中间件
        @app.middleware("http")
        async def process_time_middleware(request: Request, call_next):
            # 接收来自客户端的Request请求;
            headers = dict(request.scope['headers'])
            # 定义middleware开始时间
            middleware_start_time = _time2str(datetime.now())
            # 将middleware开始时间添加到request的headers中,这里request.headers是一个可读可写的对象,但是它的值是不可变的,所以这里需要将request.headers转换为字典,然后再修改字典的值,最后再将字典转换为元组,赋值给request.scope['headers'];
            headers[b'middleware_start_time'] = middleware_start_time.encode('utf-8')
            request.scope['headers'] = [(k, v) for k, v in headers.items()]
            # 将Request请求传回原路由
            response = await call_next(request)
            # 为了更好的观察middleware的执行过程,这里让middleware休眠1秒钟
            time.sleep(1)
            # 接收来自原路由的Response响应,将middleware结束时间添加到response的headers中
            response.headers["middleware_start_time"] = middleware_start_time
            response.headers["middleware_end_time"] = _time2str(datetime.now())
            return response
        @app.get("/")
        async def index(request: Request, response: Response):
            # 在路由中获取middleware通过request传递过来的middleware开始时间
            middleware_start_time = _str2time(request.headers.get("middleware_start_time"))
            # 在middleware_start_time的基础上加1小时,作为router开始时间,加2小时,作为router结束时间
            router_start_time = _str2time(middleware_start_time) + timedelta(hours=1)
            router_end_time = _str2time(request.headers.get("middleware_start_time")) + timedelta(hours=2)
            # 将router开始时间和router结束时间添加到response的headers中
            response.headers["router_start_time"] = _time2str(router_start_time)
            response.headers["router_end_time"] = _time2str(router_end_time)
            return "test middleware"
        if __name__ == '__main__':
            uvicorn.run(app="main:app", port=8088, reload=True)
        

        下图为FastAPI middleware中间件执行结果:

        Python FastAPI系列:自定义FastAPI middleware中间件,FastAPI middleware中间件执行结果,第1张

        通过继承BaseHTTPMiddleware创建中间件

        在实际的项目中往往会定义多个FastAPI middleware中间件以实现完整的业务需求,这时我们可以通过通过继承starlette.middleware.base.Middleware类来创建自定义的中间件。以下为上述需求通过继承BaseHTTPMiddleware方式的实现情况:

        项目文件可拆分为main.py(主程序),process_time_middleware.py(自定义中间件类),utils.py(工具函数)

        main.py

        from starlette.responses import Response
        from fapi.process_time_middleware import ProcessTimeMiddleware
        from fapi.utils import _str2time, _time2str
        app = FastAPI()
        # 将中间件添加到主程序中
        app.add_middleware(ProcessTimeMiddleware, header_namespace="middleware")
        @app.get("/")
        async def index(request: Request, response: Response):
            # 在路由中获取middleware通过request传递过来的middleware开始时间
            middleware_start_time = _str2time(request.headers.get("middleware_start_time"))
            # 在middleware_start_time的基础上加1小时,作为router开始时间,加2小时,作为router结束时间
            router_start_time = middleware_start_time + timedelta(hours=1)
            router_end_time = middleware_start_time + timedelta(hours=2)
            # 将router开始时间和router结束时间添加到response的headers中
            response.headers["router_start_time"] = _time2str(router_start_time)
            response.headers["router_end_time"] = _time2str(router_end_time)
            return "test middleware"
        if __name__ == '__main__':
            uvicorn.run(app="main:app", port=8088, reload=True)
        

        process_time_middleware.py

        import time
        from datetime import datetime
        from fastapi import Request
        from starlette.middleware.base import BaseHTTPMiddleware
        from fapi.utils import _time2str
        class ProcessTimeMiddleware(BaseHTTPMiddleware):
            def __init__(self, app, header_namespace: str):
                super().__init__(app)
                # 自定义参数,用于定义middleware的header名称空间
                self.header_namespace = header_namespace
            async def dispatch(self, request: Request, call_next):
                # 接收来自客户端的Request请求;
                headers = dict(request.scope['headers'])
                # 定义middleware开始时间
                middleware_start_time = _time2str(datetime.now())
                # 将middleware开始时间添加到request的headers中,这里request.headers是一个可读可写的对象,但是它的值是不可变的,所以这里需要将request.headers转换为字典,然后再修改字典的值,最后再将字典转换为元组,赋值给request.scope['headers'];
                headers[b'middleware_start_time'] = middleware_start_time.encode('utf-8')
                request.scope['headers'] = [(k, v) for k, v in headers.items()]
                # 将Request请求传回原路由
                response = await call_next(request)
                # 为了更好的观察middleware的执行过程,这里让middleware休眠1秒钟
                time.sleep(1)
                # 接收来自原路由的Response响应,将middleware结束时间添加到response的headers中
                response.headers[f"{self.header_namespace}_start_time"] = middleware_start_time
                response.headers[f"{self.header_namespace}_end_time"] = _time2str(datetime.now())
                return response
        

        utils.py

        from datetime import datetime
        # 将时间格式化为字符串
        def _time2str(time_str):
            return datetime.strftime(time_str, '%Y-%m-%d %H:%M:%S')
        # 将字符串转换为时间
        def _str2time(time_str):
            return datetime.strptime(time_str, '%Y-%m-%d %H:%M:%S')
        

        根据ASGI规范创建中间件

        根据ASGI规范来创建的中间件可以获得更加底层的功能,并增强了跨框架和服务器的互操作性。以下代码将通过创建纯ASGI类来实现上述例子中的需求。

        项目文件依然可拆分为main.py(主程序),process_time_asgi_middleware.py(自定义ASGI中间件类),utils.py(工具函数),其中main.py和utils.py与上面的保持一致:

        main.py

        from datetime import timedelta
        import uvicorn as uvicorn
        from fastapi import FastAPI
        from starlette.requests import Request
        from starlette.responses import Response
        from fapi.process_time_asgi_middleware import ProcessTimeASGIMiddleware
        from fapi.utils import _str2time, _time2str
        app = FastAPI()
        # 将ASGI中间件添加到主程序中
        app.add_middleware(ProcessTimeASGIMiddleware, header_namespace="middleware")
        @app.get("/")
        async def index(request: Request, response: Response):
            # 在路由中获取middleware通过request传递过来的middleware开始时间
            middleware_start_time = _str2time(request.headers.get("middleware_start_time"))
            # 在middleware_start_time的基础上加1小时,作为router开始时间,加2小时,作为router结束时间
            router_start_time = middleware_start_time + timedelta(hours=1)
            router_end_time = middleware_start_time + timedelta(hours=2)
            # 将router开始时间和router结束时间添加到response的headers中
            response.headers["router_start_time"] = _time2str(router_start_time)
            response.headers["router_end_time"] = _time2str(router_end_time)
            return "test middleware"
        if __name__ == '__main__':
            uvicorn.run(app="main:app", port=8088, reload=True)
        

        process_time_asgi_middleware.py

        import time
        from datetime import datetime
        from fastapi import Request
        from starlette.datastructures import MutableHeaders
        from fapi.utils import _time2str
        class ProcessTimeASGIMiddleware:
            def __init__(self, app, header_namespace: str):
                self.app = app
                # 自定义参数,用于定义middleware的header名称空间
                self.header_namespace = header_namespace
            #ASGI 中间件必须是接受三个参数的可调用对象,即 scope、receive、send;
            async def __call__(self, scope, receive, send):
                request = Request(scope)
                # 接收来自客户端的Request请求;
                headers = dict(request.scope['headers'])
                # 定义middleware开始时间
                middleware_start_time = _time2str(datetime.now())
                headers[b'middleware_start_time'] = middleware_start_time.encode('utf-8')
                request.scope['headers'] = [(k, v) for k, v in headers.items()]
                
                # 定义Send函数,用于将middleware开始时间和middleware结束时间添加到response的headers中
                async def add_headers(message):
                    if message["type"] == "http.response.start":
                        new_headers = MutableHeaders(scope=message)
                        new_headers.append(f"{self.header_namespace}_start_time", middleware_start_time)
                        new_headers.append(f"{self.header_namespace}_end_time", _time2str(datetime.now()))
                    await send(message)
                # 将scope、receive、add_headers传递给原始的ASGI应用程序
                return await self.app(scope, receive, add_headers)
        

        utils.py

        from datetime import datetime
        # 将时间格式化为字符串
        def _time2str(time_str):
            return datetime.strftime(time_str, '%Y-%m-%d %H:%M:%S')
        # 将字符串转换为时间
        def _str2time(time_str):
            return datetime.strptime(time_str, '%Y-%m-%d %H:%M:%S')
        

        在代码中,ASGI中间件必须是接受三个参数的可调用对象,即 scope、receive、send。其中

        1、scope是保存有关连接的信息的字典,其中scope[“type”]的type可以是:

        “http”:用于 HTTP 请求。

        “websocket”:用于 WebSocket 连接。

        “lifespan”:用于 ASGI 生命周期消息。

        2、receive用于与ASGI服务器交换ASGI事件消息。这些消息的类型和内容取决于作用域类型。

        当然,也可以使用函数来代替纯ASGI中间件类:

        import functools
        def asgi_middleware():
            def asgi_decorator(app):
                @functools.wraps(app)
                async def wrapped_app(scope, receive, send):
                    await app(scope, receive, send)
                return wrapped_app
            return asgi_decorator
        

        综上,在FastAPI中利用middleware中间件可以获得更加强大的功能,并使得程序更加优雅可读!