Skip to content

4.2 Custom Middleware

本节介绍如何创建自定义中间件,实现特定的业务逻辑。


概述

当内置中间件不满足需求时,可以创建自定义中间件。LangChain 提供两种实现方式:

  • 装饰器方式:简单、单钩子
  • 类方式:复杂、多钩子

钩子类型

Node-style Hooks(节点钩子)

在执行的特定点顺序运行:

| 钩子 | 触发时机 | 调用次数 |
|------|----------|----------|
| before_agent | Agent 启动前 | 每次调用 1 次 |
| before_model | 每次模型调用前 | 每次模型调用 |
| after_model | 每次模型响应后 | 每次模型调用 |
| after_agent | Agent 完成后 | 每次调用 1 次 |

Wrap-style Hooks(包装钩子)

控制处理器的调用方式:钩子接收请求对象 request 和处理器 handler,调用 handler(request) 执行实际调用,并可在调用前后添加逻辑;多个包装钩子像函数调用一样嵌套:

| 钩子 | 包装目标 |
|------|----------|
| wrap_model_call | 模型调用 |
| wrap_tool_call | 工具调用 |

装饰器方式

最简单的创建方式,适合单个钩子:

before_model

在每次模型调用前执行:

python
from langchain.agents import before_model, AgentState, Runtime

@before_model(can_jump_to=["end"])
def check_message_limit(state: AgentState, runtime: Runtime):
    """Stop the agent once the conversation reaches 50 messages.

    Returning ``{"jump_to": "end"}`` skips the model call and ends the run.
    The hook declares ``can_jump_to=["end"]`` because jumps are only honoured
    for targets the hook has declared. Returns ``None`` to continue normally.
    """
    if len(state["messages"]) >= 50:
        print("消息过多,终止执行")
        return {"jump_to": "end"}  # terminate the run early
    return None

agent = create_agent(
    "gpt-4o",
    tools=[my_tools],
    middleware=[check_message_limit]
)

after_model

在每次模型响应后执行(钩子签名为 (state, runtime),模型响应是 state["messages"] 中的最后一条消息):

python
from langchain.agents import after_model

@after_model
def log_response(state: AgentState, runtime: Runtime):
    """Log the first 100 characters of the latest model response.

    ``after_model`` hooks receive only ``(state, runtime)``; the model's reply
    has already been appended and is the last entry of ``state["messages"]``.
    """
    response = state["messages"][-1]
    print(f"模型响应: {response.content[:100]}...")
    return None

agent = create_agent(
    "gpt-4o",
    tools=[my_tools],
    middleware=[log_response]
)

before_agent

在 Agent 启动前执行(每次调用只执行一次):

python
from langchain.agents import before_agent

import time


@before_agent
def initialize_session(state: AgentState, runtime: Runtime):
    """Start a session: assign a session id and record the start time.

    Node-style hooks communicate changes through their return value; values
    assigned into ``state`` in place are not persisted. The returned keys
    (``session_id``, ``start_time``) must be declared in the agent's state schema.
    ``generate_session_id`` is a user-supplied helper not defined in this snippet.
    """
    session_id = generate_session_id()
    start_time = time.time()
    print(f"Session started: {session_id}")
    return {"session_id": session_id, "start_time": start_time}

after_agent

在 Agent 完成后执行:

python
from langchain.agents import after_agent

import time


@after_agent
def cleanup_session(state: AgentState, runtime: Runtime):
    """Report how long the agent session took.

    Reads ``start_time`` from state (expected to be written by a
    ``before_agent`` hook such as ``initialize_session``). If it is missing,
    the duration is reported as unknown rather than computing ``now - 0``,
    which would print the number of seconds since the Unix epoch.
    """
    start_time = state.get("start_time")
    if start_time is None:
        print("Session ended. Duration: unknown")
        return None
    duration = time.time() - start_time
    print(f"Session ended. Duration: {duration:.2f}s")
    return None

wrap_model_call

包装模型调用:

python
from langchain.agents import wrap_model_call

import time


@wrap_model_call
def time_model_call(request, handler):
    """Measure and print the duration of each model call.

    Wrap-style hooks receive the model ``request`` and a ``handler``;
    ``handler(request)`` performs the actual call (including any inner
    wrappers) and its result is returned unchanged.
    """
    start = time.perf_counter()  # monotonic: unaffected by system clock changes
    try:
        return handler(request)
    finally:
        # Report the timing even when the model call raises.
        duration = time.perf_counter() - start
        print(f"Model call took {duration:.2f}s")

wrap_tool_call

包装工具调用:

python
from langchain.agents import wrap_tool_call

@wrap_tool_call
def validate_tool_call(request, handler):
    """Log every tool call and refuse calls to ``dangerous_tool``.

    ``request.tool_call`` carries the tool ``name`` and ``args``;
    ``handler(request)`` executes the tool (including any inner wrappers).

    Raises:
        ValueError: if the model tries to call ``dangerous_tool``.
    """
    tool_name = request.tool_call["name"]
    tool_args = request.tool_call["args"]
    print(f"Calling tool: {tool_name}")
    print(f"Arguments: {tool_args}")

    # Validate (or rewrite) the arguments here before the tool runs.
    if tool_name == "dangerous_tool":
        raise ValueError("This tool is not allowed")

    result = handler(request)  # execute the actual tool call
    print(f"Tool result: {result}")
    return result

类方式

适合复杂的多钩子中间件:

基本结构

python
from langchain.agents import AgentMiddleware, AgentState, Runtime

class LoggingMiddleware(AgentMiddleware):
    """Print a log line at each stage of an agent run.

    Args:
        log_level: Label prefixed to every log line, e.g. ``"debug"``.
    """

    def __init__(self, log_level: str = "info"):
        super().__init__()
        self.log_level = log_level

    def before_agent(self, state: AgentState, runtime: Runtime):
        """Called once when the agent run starts."""
        print(f"[{self.log_level}] Agent starting...")
        return None

    def before_model(self, state: AgentState, runtime: Runtime):
        """Called before every model call; logs the current message count."""
        msg_count = len(state["messages"])
        print(f"[{self.log_level}] Model call with {msg_count} messages")
        return None

    def after_model(self, state: AgentState, runtime: Runtime):
        """Called after every model call; the reply is ``state["messages"][-1]``."""
        print(f"[{self.log_level}] Model responded")
        return None

    def after_agent(self, state: AgentState, runtime: Runtime):
        """Called once when the agent run finishes."""
        print(f"[{self.log_level}] Agent completed")
        return None

# Usage
agent = create_agent(
    "gpt-4o",
    tools=[my_tools],
    middleware=[LoggingMiddleware(log_level="debug")]
)

复杂示例:分析中间件

python
import time


class AnalyticsMiddleware(AgentMiddleware):
    """Report per-run usage analytics through an external client.

    Counts model calls, tool calls and tokens, measures total run time, and
    sends ``agent_started`` / ``tool_called`` / ``agent_completed`` events via
    ``analytics_client.track(event_name, payload)``.

    Args:
        analytics_client: Object exposing ``track(event: str, data: dict)``.
    """

    def __init__(self, analytics_client):
        super().__init__()
        self.client = analytics_client
        self.metrics = self._fresh_metrics()

    @staticmethod
    def _fresh_metrics():
        """Return a zeroed metrics dict for a new agent run."""
        return {
            "model_calls": 0,
            "tool_calls": 0,
            "tokens_used": 0,
            "start_time": None,
        }

    def before_agent(self, state, runtime):
        # One middleware instance is reused for every invocation of the agent,
        # so reset the counters here; otherwise they accumulate across runs.
        self.metrics = self._fresh_metrics()
        self.metrics["start_time"] = time.time()
        self.client.track("agent_started", {
            # NOTE(review): langgraph's Runtime may not expose ``config``; the thread
            # id normally lives in config["configurable"]["thread_id"] (e.g. via
            # langgraph.config.get_config()) — confirm against the installed version.
            "thread_id": runtime.config.get("thread_id")
        })
        return None

    def after_model(self, state, runtime):
        self.metrics["model_calls"] += 1
        # The model reply is the last message. ``usage_metadata`` holds
        # provider-independent token counts and is None when usage is not reported.
        usage = getattr(state["messages"][-1], "usage_metadata", None) or {}
        self.metrics["tokens_used"] += usage.get("total_tokens", 0)
        return None

    def wrap_tool_call(self, request, handler):
        self.metrics["tool_calls"] += 1
        start = time.perf_counter()
        try:
            return handler(request)
        finally:
            # Track the call even if the tool raised.
            self.client.track("tool_called", {
                "tool_name": request.tool_call["name"],
                "duration": time.perf_counter() - start,
            })

    def after_agent(self, state, runtime):
        start_time = self.metrics["start_time"]
        total_time = time.time() - start_time if start_time is not None else None
        self.client.track("agent_completed", {
            "model_calls": self.metrics["model_calls"],
            "tool_calls": self.metrics["tool_calls"],
            "tokens_used": self.metrics["tokens_used"],
            "total_time": total_time,
        })
        return None

自定义状态 Schema

扩展 Agent 状态以存储自定义数据。AgentState 是 TypedDict,字段不能设置默认值(写在类体中的默认值会被忽略),可选字段应在读取时通过 state.get(key, default) 提供默认值:

python
from langchain.agents import AgentState, create_agent
from typing import Optional

class CustomAgentState(AgentState):
    """Agent state extended with application-specific fields.

    AgentState is a TypedDict, so its fields cannot carry default values: a
    default written in the class body is silently ignored and the key is just
    absent until something writes it. Read optional fields with
    ``state.get(key, default)``.
    """

    user_id: str
    session_id: Optional[str]
    request_count: int
    custom_data: dict


@before_model
def track_requests(state: CustomAgentState, runtime):
    """Increment ``request_count`` before every model call.

    The new value is returned as a state update; assigning into ``state`` in
    place would not be persisted.
    """
    return {"request_count": state.get("request_count", 0) + 1}


agent = create_agent(
    "gpt-4o",
    tools=[my_tools],
    state_schema=CustomAgentState,
    middleware=[track_requests]
)

Agent 跳转

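使用 jump_to 控制 Agent 流程。钩子必须声明允许的跳转目标:装饰器写作 @before_model(can_jump_to=["end"]),类方式使用 @hook_config(can_jump_to=["end"]);要追加的消息应通过返回值中的 "messages" 提交,而不是直接修改 state: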

python
@before_model(can_jump_to=["end"])
def check_safety(state, runtime):
    """Safety gate: end the run early when the latest message is unsafe.

    The warning is returned as a ``messages`` update (merged into the history
    by the state's message reducer) together with ``jump_to: "end"``.
    Appending to ``state["messages"]`` in place would bypass the reducer.
    ``contains_dangerous_content`` is a user-supplied predicate.
    """
    last_message = state["messages"][-1].content

    # Detect unsafe content
    if contains_dangerous_content(last_message):
        return {
            "messages": [AIMessage(content="检测到不安全内容,终止执行。")],
            "jump_to": "end",  # skip the model call and finish
        }

    return None

执行顺序

python
# Middleware list
middleware = [middleware_1, middleware_2, middleware_3]

# Before hooks: run in list order
# middleware_1.before_model() → middleware_2.before_model() → middleware_3.before_model()

# After hooks: run in reverse list order
# middleware_3.after_model() → middleware_2.after_model() → middleware_1.after_model()

# Wrap hooks: nested — the first middleware is the outermost wrapper
# middleware_1.wrap_model_call(
#     middleware_2.wrap_model_call(
#         middleware_3.wrap_model_call(
#             actual_model_call()
#         )
#     )
# )

实际案例

速率限制中间件

python
import time
from collections import deque

class RateLimitMiddleware(AgentMiddleware):
    """Sliding-window rate limiter for model calls.

    Allows at most ``max_requests`` model calls within any ``window_seconds``
    window. When the limit is reached, ``before_model`` blocks (sleeps) until
    the oldest recorded call leaves the window.

    Args:
        max_requests: Maximum number of model calls per window (must be >= 1).
        window_seconds: Length of the sliding window, in seconds.

    Raises:
        ValueError: if ``max_requests`` is less than 1.
    """

    def __init__(self, max_requests: int = 10, window_seconds: int = 60):
        super().__init__()
        if max_requests < 1:
            raise ValueError("max_requests must be at least 1")
        self.max_requests = max_requests
        self.window_seconds = window_seconds
        # Monotonic timestamps of the calls still inside the window, oldest first.
        self.request_times = deque()

    def _evict_expired(self, now):
        """Drop timestamps that are at least ``window_seconds`` older than *now*."""
        cutoff = now - self.window_seconds
        while self.request_times and self.request_times[0] <= cutoff:
            self.request_times.popleft()

    def before_model(self, state, runtime):
        # Monotonic clock: interval math must not jump with system clock changes.
        now = time.monotonic()
        self._evict_expired(now)

        # Re-check after waking up: the call is admitted only once the oldest
        # entry has really left the window.
        while len(self.request_times) >= self.max_requests:
            wait_time = self.request_times[0] + self.window_seconds - now
            print(f"Rate limit exceeded. Waiting {wait_time:.1f}s...")
            time.sleep(max(wait_time, 0.0))
            # Record when the call actually proceeds, not when it was requested;
            # recording the stale pre-sleep time would let the window overfill.
            now = time.monotonic()
            self._evict_expired(now)

        self.request_times.append(now)
        return None

成本追踪中间件

python
class CostTrackingMiddleware(AgentMiddleware):
    """Accumulate an estimated USD cost of all model calls in ``total_cost``.

    The PRICING table is illustrative; verify against current provider prices.
    """

    # Price per 1K tokens (USD)
    PRICING = {
        "gpt-4o": {"input": 0.005, "output": 0.015},
        "gpt-3.5-turbo": {"input": 0.0005, "output": 0.0015},
    }
    # Pricing entry used when the model name is unknown.
    DEFAULT_MODEL = "gpt-4o"

    def __init__(self):
        super().__init__()
        self.total_cost = 0.0

    @classmethod
    def _pricing_for(cls, model):
        """Return the PRICING entry for *model*.

        Providers often report dated names such as ``"gpt-4o-2024-08-06"``, so
        the longest PRICING key that prefixes *model* wins; unknown models fall
        back to DEFAULT_MODEL.
        """
        matches = [name for name in cls.PRICING if model and model.startswith(name)]
        if not matches:
            return cls.PRICING[cls.DEFAULT_MODEL]
        return cls.PRICING[max(matches, key=len)]

    def after_model(self, state, runtime):
        """Add the cost of the latest model response (``state["messages"][-1]``)."""
        response = state["messages"][-1]
        # ``usage_metadata`` is LangChain's provider-independent token count
        # (input_tokens / output_tokens); it is None when usage is not reported.
        usage = getattr(response, "usage_metadata", None)
        if usage:
            metadata = getattr(response, "response_metadata", None) or {}
            model = metadata.get("model_name") or metadata.get("model")
            pricing = self._pricing_for(model)
            input_cost = (usage.get("input_tokens", 0) / 1000) * pricing["input"]
            output_cost = (usage.get("output_tokens", 0) / 1000) * pricing["output"]

            self.total_cost += input_cost + output_cost
            print(f"Current cost: ${self.total_cost:.4f}")

        return None

缓存中间件

python
import hashlib
import json

from langchain.agents.middleware import hook_config


class CacheMiddleware(AgentMiddleware):
    """Cache final model answers keyed by the conversation that produced them.

    ``before_model`` hashes the current message history; on a hit it returns
    the cached text as an AIMessage and jumps to the end, skipping the model.
    ``after_model`` stores plain-text answers (no tool calls) under the hash of
    the history *before* the answer was appended, i.e. the same key
    ``before_model`` computes for that prompt.

    Args:
        cache_client: Object exposing ``get(key)`` and ``set(key, value, ttl=...)``.
    """

    def __init__(self, cache_client):
        super().__init__()
        self.cache = cache_client

    @staticmethod
    def _cache_key(messages):
        """Hash a message sequence (type + content) into a stable cache key."""
        messages_str = json.dumps([
            {"role": m.type, "content": m.content}
            for m in messages
        ])
        # sha256 rather than md5: md5 is unavailable on FIPS-restricted builds.
        return hashlib.sha256(messages_str.encode()).hexdigest()

    def _get_cache_key(self, state):
        """Cache key for the full message history in *state*."""
        return self._cache_key(state["messages"])

    @hook_config(can_jump_to=["end"])
    def before_model(self, state, runtime):
        """Return the cached answer and end the run on a cache hit."""
        cache_key = self._get_cache_key(state)
        cached = self.cache.get(cache_key)

        if cached:
            print("Cache hit!")
            # Return the message as an update instead of appending in place.
            return {
                "messages": [AIMessage(content=cached)],
                "jump_to": "end",
            }

        return None

    def after_model(self, state, runtime):
        """Store the latest answer under the key of the prompt that produced it."""
        *history, response = state["messages"]
        # Tool-calling responses are intermediate steps, not reusable answers.
        if response.content and not getattr(response, "tool_calls", None):
            self.cache.set(self._cache_key(history), response.content, ttl=3600)
        return None

最佳实践

| 实践 | 说明 |
|------|------|
| 保持专注 | 每个中间件只做一件事 |
| 优雅降级 | 中间件错误不应导致 Agent 崩溃 |
| 避免阻塞 | 不要在钩子中执行耗时操作 |
| 记录日志 | 便于调试和监控 |
| 测试充分 | 单独测试每个中间件 |

调试技巧

python
@before_model
def debug_state(state, runtime):
    """Debug helper: print a snapshot of the agent state before each model call.

    Prints the message count, the first 100 characters of the newest message
    and the runtime config between two divider lines. Returns ``None``, so the
    state is left unchanged.
    """
    # NOTE(review): langgraph's Runtime may not expose a ``config`` attribute
    # (it carries ``context`` / ``store``) — confirm against the installed version.
    divider = "=" * 50
    print(divider)
    history = state["messages"]
    print(f"Messages: {len(history)}")
    newest = history[-1]
    print(f"Last message: {newest.content[:100]}")
    print(f"Config: {runtime.config}")
    print(divider)
    return None

上一节:4.1 Built-in Middleware

下一章:5.0 Advanced Usage

基于 MIT 许可证发布。内容版权归作者所有。