Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 1aa4496

Browse files
feat(langchain): register stream transformers on middleware (#37591)
Adds a `transformers` attribute to `AgentMiddleware` so middleware can declare scope-aware `StreamTransformer` factories alongside their `tools` and lifecycle hooks. `create_agent` merges middleware-registered factories with any caller-supplied ones at compile time. ## API ```python class MyMiddleware(AgentMiddleware): transformers = (MyTransformer,) # factory: (scope,) -> StreamTransformer ``` When the agent compiles, the final transformer order on the run mux is: 1. Built-in ``ToolCallTransformer`` 2. Middleware-registered factories, in middleware order 3. Caller-supplied ``transformers=`` from ``create_agent`` This ordering keeps the built-in tool-call projection in front of any consumer transformers and gives caller-supplied entries the final word.
1 parent d2931d8 commit 1aa4496

3 files changed

Lines changed: 168 additions & 5 deletions

File tree

libs/langchain_v1/langchain/agents/factory.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ class _ComposedExtendedModelResponse(Generic[ResponseT]):
8181
from langgraph.graph.state import CompiledStateGraph
8282
from langgraph.runtime import Runtime
8383
from langgraph.store.base import BaseStore
84+
from langgraph.stream._mux import TransformerFactory
8485
from langgraph.types import Checkpointer
8586

8687
from langchain.agents.middleware.types import ToolCallWrapper
@@ -708,7 +709,7 @@ def create_agent(
708709
debug: bool = False,
709710
name: str | None = None,
710711
cache: BaseCache[Any] | None = None,
711-
transformers: Sequence[Callable[[tuple[str, ...]], Any]] | None = None,
712+
transformers: Sequence[TransformerFactory] | None = None,
712713
) -> CompiledStateGraph[
713714
AgentState[ResponseT], ContextT, _InputAgentState, _OutputAgentState[ResponseT]
714715
]:
@@ -806,9 +807,11 @@ def create_agent(
806807
cache: An optional `BaseCache` instance to enable caching of graph execution.
807808
transformers: Optional sequence of scope-aware `StreamTransformer`
808809
factories to register on the compiled graph in addition to
809-
the agent defaults. Each factory is invoked per-scope
810-
(`factory(scope)`) so subgraph mini-muxes get fresh
811-
instances. Appended after the built-in `ToolCallTransformer`.
810+
the agent defaults. Each factory is invoked as `factory(scope)`
811+
so every invocation receives a fresh instance. The final order
812+
on the compiled graph is: `ToolCallTransformer`, then any
813+
factories declared by middleware via
814+
`AgentMiddleware.transformers`, then any factories supplied here.
812815
813816
Returns:
814817
A compiled `StateGraph` that can be used for chat interactions.
@@ -1662,6 +1665,8 @@ async def amodel_node(state: AgentState[Any], runtime: Runtime[ContextT]) -> lis
16621665
if name:
16631666
config["metadata"]["lc_agent_name"] = name
16641667

1668+
middleware_transformers = [t for m in middleware for t in getattr(m, "transformers", ())]
1669+
16651670
return graph.compile(
16661671
checkpointer=checkpointer,
16671672
store=store,
@@ -1670,7 +1675,11 @@ async def amodel_node(state: AgentState[Any], runtime: Runtime[ContextT]) -> lis
16701675
debug=debug,
16711676
name=name,
16721677
cache=cache,
1673-
transformers=[ToolCallTransformer, *(transformers or ())],
1678+
transformers=[
1679+
ToolCallTransformer,
1680+
*middleware_transformers,
1681+
*(transformers or ()),
1682+
],
16741683
).with_config(config)
16751684

16761685

libs/langchain_v1/langchain/agents/middleware/types.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
from langchain_core.language_models.chat_models import BaseChatModel
4141
from langchain_core.tools import BaseTool
4242
from langgraph.runtime import Runtime
43+
from langgraph.stream._mux import TransformerFactory
4344
from langgraph.types import Command
4445

4546
from langchain.agents.structured_output import ResponseFormat
@@ -397,6 +398,16 @@ class AgentMiddleware(Generic[StateT, ContextT, ResponseT]):
397398
tools: Sequence[BaseTool]
398399
"""Additional tools registered by the middleware."""
399400

401+
transformers: Sequence[TransformerFactory] = ()
402+
"""Stream transformer factories registered by the middleware.
403+
404+
Each entry is a scope-aware factory invoked as `factory(scope)` so every
405+
invocation receives a fresh instance. Factories are merged with the
406+
`transformers` argument of [`create_agent`][langchain.agents.create_agent]
407+
at graph compile time, after the `ToolCallTransformer` and before any
408+
user-supplied entries.
409+
"""
410+
400411
@property
401412
def name(self) -> str:
402413
"""The name of the middleware instance.
Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
"""Tests for middleware-registered stream transformers."""
2+
3+
from __future__ import annotations
4+
5+
from typing import TYPE_CHECKING, Any
6+
7+
from langchain_core.messages import HumanMessage
8+
from langgraph.prebuilt import ToolCallTransformer
9+
from langgraph.stream import StreamChannel, StreamTransformer
10+
11+
from langchain.agents.factory import create_agent
12+
from langchain.agents.middleware.types import AgentMiddleware
13+
from tests.unit_tests.agents.model import FakeToolCallingModel
14+
15+
if TYPE_CHECKING:
16+
from langgraph.stream._types import ProtocolEvent
17+
18+
19+
class _MiddlewareMarker(StreamTransformer):
20+
"""Marker transformer used to assert registration order."""
21+
22+
required_stream_modes = ()
23+
24+
def __init__(self, scope: tuple[str, ...] = ()) -> None:
25+
super().__init__(scope)
26+
self._log: StreamChannel[int] = StreamChannel()
27+
28+
def init(self) -> dict[str, Any]:
29+
return {"middleware_marker": self._log}
30+
31+
def process(self, event: ProtocolEvent) -> bool:
32+
del event
33+
return True
34+
35+
36+
class _UserMarker(StreamTransformer):
37+
"""Second marker to verify user-supplied transformers append last."""
38+
39+
required_stream_modes = ()
40+
41+
def __init__(self, scope: tuple[str, ...] = ()) -> None:
42+
super().__init__(scope)
43+
self._log: StreamChannel[int] = StreamChannel()
44+
45+
def init(self) -> dict[str, Any]:
46+
return {"user_marker": self._log}
47+
48+
def process(self, event: ProtocolEvent) -> bool:
49+
del event
50+
return True
51+
52+
53+
def test_middleware_transformer_registered_on_compiled_graph() -> None:
54+
"""A `transformers` factory declared on middleware is wired into the run mux."""
55+
56+
class _Middleware(AgentMiddleware):
57+
transformers = (_MiddlewareMarker,)
58+
59+
agent = create_agent(model=FakeToolCallingModel(), tools=[], middleware=[_Middleware()])
60+
61+
run = agent.stream_events({"messages": [HumanMessage("hi")]}, version="v3")
62+
63+
assert "middleware_marker" in run._mux.extensions # type: ignore[attr-defined]
64+
# Drain to close the run cleanly.
65+
list(run.tool_calls)
66+
67+
68+
def test_middleware_and_user_transformers_compose_in_order() -> None:
69+
"""Order is: built-in `ToolCallTransformer` → middleware → user-supplied."""
70+
71+
class _Middleware(AgentMiddleware):
72+
transformers = (_MiddlewareMarker,)
73+
74+
agent = create_agent(
75+
model=FakeToolCallingModel(),
76+
tools=[],
77+
middleware=[_Middleware()],
78+
transformers=[_UserMarker],
79+
)
80+
81+
run = agent.stream_events({"messages": [HumanMessage("hi")]}, version="v3")
82+
83+
transformers = run._mux._transformers # type: ignore[attr-defined]
84+
tool_call_idx = next(
85+
i for i, t in enumerate(transformers) if isinstance(t, ToolCallTransformer)
86+
)
87+
middleware_idx = next(i for i, t in enumerate(transformers) if isinstance(t, _MiddlewareMarker))
88+
user_idx = next(i for i, t in enumerate(transformers) if isinstance(t, _UserMarker))
89+
90+
assert tool_call_idx < middleware_idx < user_idx, (
91+
"transformers must register as: built-in, then middleware, then user-supplied"
92+
)
93+
94+
list(run.tool_calls)
95+
96+
97+
def test_transformers_from_multiple_middleware_preserve_middleware_order() -> None:
98+
"""Transformers across middleware register in middleware-list order."""
99+
100+
class _MarkerA(_MiddlewareMarker):
101+
def init(self) -> dict[str, Any]:
102+
return {"marker_a": self._log}
103+
104+
class _MarkerB(_MiddlewareMarker):
105+
def init(self) -> dict[str, Any]:
106+
return {"marker_b": self._log}
107+
108+
class _MwA(AgentMiddleware):
109+
transformers = (_MarkerA,)
110+
111+
class _MwB(AgentMiddleware):
112+
transformers = (_MarkerB,)
113+
114+
agent = create_agent(
115+
model=FakeToolCallingModel(),
116+
tools=[],
117+
middleware=[_MwA(), _MwB()],
118+
)
119+
120+
run = agent.stream_events({"messages": [HumanMessage("hi")]}, version="v3")
121+
122+
transformers = run._mux._transformers # type: ignore[attr-defined]
123+
idx_a = next(i for i, t in enumerate(transformers) if isinstance(t, _MarkerA))
124+
idx_b = next(i for i, t in enumerate(transformers) if isinstance(t, _MarkerB))
125+
assert idx_a < idx_b
126+
127+
list(run.tool_calls)
128+
129+
130+
def test_middleware_without_transformers_does_not_affect_registry() -> None:
131+
"""Middleware that omits `transformers` leaves the default registry intact."""
132+
133+
class _Middleware(AgentMiddleware):
134+
pass
135+
136+
agent = create_agent(model=FakeToolCallingModel(), tools=[], middleware=[_Middleware()])
137+
run = agent.stream_events({"messages": [HumanMessage("hi")]}, version="v3")
138+
139+
transformers = run._mux._transformers # type: ignore[attr-defined]
140+
assert any(isinstance(t, ToolCallTransformer) for t in transformers)
141+
assert not any(isinstance(t, _MiddlewareMarker) for t in transformers)
142+
143+
list(run.tool_calls)

0 commit comments

Comments
 (0)