Merged
Commits (changes shown from 54 of 65 commits):
260516d
messing with overriding model name at runtime
mattbrandman Nov 18, 2025
ef4dcf3
got multi model working
mattbrandman Nov 24, 2025
3441db0
removing docstring that was getting tested
mattbrandman Nov 24, 2025
1cb1c82
fixing types
mattbrandman Nov 24, 2025
6ed7461
more type fixes
mattbrandman Nov 24, 2025
a70c360
fixing types
mattbrandman Nov 24, 2025
d8d0039
adding coverage
mattbrandman Nov 25, 2025
7904768
allow arbitrary models
mattbrandman Dec 1, 2025
c04a14d
fixing tests
mattbrandman Dec 1, 2025
ce7c5fd
fixing lint
mattbrandman Dec 1, 2025
9998a5f
remove pandas comment
mattbrandman Dec 1, 2025
0857bbc
adding coverage
mattbrandman Dec 1, 2025
02308df
lint fixes
mattbrandman Dec 1, 2025
5477868
coverage
mattbrandman Dec 1, 2025
eccf3d1
adding is none coverage
mattbrandman Dec 2, 2025
a87d09e
simplifying registration
mattbrandman Dec 3, 2025
6913faa
fixes types
mattbrandman Dec 3, 2025
250606b
fixing coverage adding provider factory
mattbrandman Dec 3, 2025
7f354bb
adding tests
mattbrandman Dec 3, 2025
0ca186c
inlining normalize and setting up set current run context
mattbrandman Dec 4, 2025
6d2474c
required kwarg and rename of selection to model_id
mattbrandman Dec 10, 2025
ce4b17c
rename selection everywhere
mattbrandman Dec 10, 2025
70d2451
reorder arguments
mattbrandman Dec 10, 2025
733f0e3
do not put initial agent through provider factory
mattbrandman Dec 10, 2025
a34c7c3
adding docs
mattbrandman Dec 10, 2025
4d1c5a3
fixing tests
mattbrandman Dec 10, 2025
7a8d91b
fixing coverage
mattbrandman Dec 10, 2025
9902338
fix
mattbrandman Dec 10, 2025
d2930a5
update comments
mattbrandman Dec 13, 2025
d2eb6b0
test
mattbrandman Dec 14, 2025
eaf4cc3
updates
mattbrandman Dec 14, 2025
748294d
updates
mattbrandman Dec 14, 2025
9ca1cdd
fixing formatting
mattbrandman Dec 14, 2025
995f0b8
removing unecessary change
mattbrandman Dec 14, 2025
06fe0e1
parametrize tests
mattbrandman Dec 14, 2025
3f71342
fix example
mattbrandman Dec 14, 2025
71d8da3
no cover
mattbrandman Dec 14, 2025
8be2857
fixes comments
mattbrandman Dec 18, 2025
60e2e70
fix more
mattbrandman Dec 19, 2025
44b2aaa
fixes
mattbrandman Dec 19, 2025
3fede5a
fixing feedback
mattbrandman Dec 19, 2025
c921487
fixing tests
mattbrandman Dec 19, 2025
b489251
simplify
mattbrandman Dec 19, 2025
acee05c
remove confusing sentence
mattbrandman Dec 19, 2025
09298da
move things into temporal model
mattbrandman Dec 19, 2025
f881f8a
fixes
mattbrandman Dec 19, 2025
df87b41
resolve conflict
mattbrandman Dec 19, 2025
0349a82
updates
mattbrandman Dec 19, 2025
f3a8091
first fixes
mattbrandman Dec 19, 2025
4534788
more fixes
mattbrandman Dec 19, 2025
88e90f5
dedupe
mattbrandman Dec 19, 2025
f08b057
fixes
mattbrandman Dec 19, 2025
58adb03
error messages
mattbrandman Dec 19, 2025
732c277
fixing tests
mattbrandman Dec 19, 2025
4072fd0
updating feedback
mattbrandman Dec 19, 2025
1dda0de
fix tests
mattbrandman Dec 19, 2025
3710b1c
fixes
mattbrandman Dec 19, 2025
246a1cc
changing everything to id
mattbrandman Dec 19, 2025
85c2a83
merge
mattbrandman Dec 19, 2025
45dc40f
merge
mattbrandman Dec 19, 2025
3a37765
merge
mattbrandman Dec 19, 2025
4278e81
new tests
mattbrandman Dec 20, 2025
0ac18fc
merge
mattbrandman Dec 20, 2025
61cf85f
updates
mattbrandman Dec 20, 2025
b4dd880
final test
mattbrandman Dec 20, 2025
80 changes: 80 additions & 0 deletions docs/durable_execution/temporal.md
@@ -184,6 +184,86 @@ As the streaming model request activity, workflow, and workflow execution call a
- To get data from the workflow call site or workflow to the event stream handler, you can use a [dependencies object](#agent-run-context-and-dependencies).
- To get data from the event stream handler to the workflow, workflow call site, or a frontend, you need to use an external system that the event stream handler can write to and the event consumer can read from, like a message queue. You can use the dependency object to make sure the same connection string or other unique ID is available in all the places that need it.
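For the second point, here's a minimal sketch of the pattern, assuming a hypothetical `publish()` queue client and an illustrative `stream_channel` field on the deps:

```python
from collections.abc import AsyncIterable
from dataclasses import dataclass
from typing import Any

from pydantic_ai import Agent, RunContext


@dataclass
class Deps:
    stream_channel: str  # unique ID shared by the workflow call site and the handler


async def publish(channel: str, message: str) -> None:
    """Hypothetical stand-in for a message queue client's publish call."""
    ...


async def handle_events(ctx: RunContext[Deps], stream: AsyncIterable[Any]) -> None:
    # The handler receives the same deps as the workflow, so both sides agree on the channel.
    async for event in stream:
        await publish(ctx.deps.stream_channel, str(event))


agent = Agent('openai:gpt-5.2', name='streaming_agent', deps_type=Deps, event_stream_handler=handle_events)
```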

### Model Selection at Runtime

[`Agent.run(model=...)`][pydantic_ai.agent.Agent.run] normally supports both model strings (like `'openai:gpt-5.2'`) and model instances. However, `TemporalAgent` does not support arbitrary model instances because they cannot be serialized for Temporal's replay mechanism.

To use model instances with `TemporalAgent`, you need to pre-register them by passing a dict of model instances to `TemporalAgent(models={...})`. You can then reference them by name or by passing the registered instance directly. If the wrapped agent doesn't have a model set, the first registered model will be used as the default.

Model strings work as they do with a regular agent run. If you need to customize the provider used to resolve a model string (e.g. to inject API keys from deps), you can pass a `provider_factory` to `TemporalAgent`.
> Collaborator: ... which is passed the run context (with a link). I liked the explicit mention you had of getting API keys from deps

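Before the full example, here's a minimal sketch of the three kinds of values `run(model=...)` accepts once models are registered (the agent name and model choices here are illustrative): a registered name, a registered model instance, or a plain model string.

```python
from pydantic_ai import Agent
from pydantic_ai.durable_exec.temporal import TemporalAgent
from pydantic_ai.models.anthropic import AnthropicModel

fast_model = AnthropicModel('claude-sonnet-4-5')

agent = Agent('openai:gpt-5.2', name='simple_agent')
temporal_agent = TemporalAgent(agent, models={'fast': fast_model})

# Inside a workflow, any of these forms is valid:
#   await temporal_agent.run(prompt, model='fast')            # registered name
#   await temporal_agent.run(prompt, model=fast_model)        # registered model instance
#   await temporal_agent.run(prompt, model='openai:gpt-5.2')  # plain model string
```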

Here's an example showing how to pre-register and use multiple models:

```python {title="multi_model_temporal.py" test="skip"}
from dataclasses import dataclass
from typing import Any

from temporalio import workflow

from pydantic_ai import Agent
from pydantic_ai.durable_exec.temporal import TemporalAgent
from pydantic_ai.models.anthropic import AnthropicModel
from pydantic_ai.models.google import GoogleModel
from pydantic_ai.models.openai import OpenAIResponsesModel
from pydantic_ai.providers import Provider
from pydantic_ai.tools import RunContext


@dataclass
class Deps:
    openai_api_key: str | None = None
    anthropic_api_key: str | None = None


# Create models from different providers
default_model = OpenAIResponsesModel('gpt-5.2')
fast_model = AnthropicModel('claude-sonnet-4-5')
reasoning_model = GoogleModel('gemini-2.5-pro')


# Optional: provider factory for dynamic model configuration
# (Collaborator review comment attached here: "I'd rather have a simple example first
# (showing just the values that are supported for run(model=...)), and then break
# provider factory out into a separate example, preceded by a paragraph explaining the
# feature/use case")

def my_provider_factory(run_context: RunContext[Deps], provider_name: str) -> Provider[Any]:
    """Create providers with custom configuration based on run context."""
    if provider_name == 'openai':
        from pydantic_ai.providers.openai import OpenAIProvider

        return OpenAIProvider(api_key=run_context.deps.openai_api_key)
    elif provider_name == 'anthropic':
        from pydantic_ai.providers.anthropic import AnthropicProvider

        return AnthropicProvider(api_key=run_context.deps.anthropic_api_key)
    else:
        raise ValueError(f'Unknown provider: {provider_name}')


agent = Agent(default_model, name='multi_model_agent', deps_type=Deps)

temporal_agent = TemporalAgent(
    agent,
    models={
        'fast': fast_model,
        'reasoning': reasoning_model,
    },
    provider_factory=my_provider_factory,  # Optional
)


@workflow.defn
class MultiModelWorkflow:
    @workflow.run
    async def run(self, prompt: str, use_reasoning: bool, use_fast: bool) -> str:
        if use_reasoning:
            # Select by registered name
            result = await temporal_agent.run(prompt, model='reasoning')
        elif use_fast:
            # Or pass the registered instance directly
            result = await temporal_agent.run(prompt, model=fast_model)
        else:
            # Or pass a model string (uses provider_factory if set)
            result = await temporal_agent.run(prompt, model='openai:gpt-4.1-mini')
        return result.output
```

## Activity Configuration

Temporal activity configuration, like timeouts and retry policies, can be customized by passing [`temporalio.workflow.ActivityConfig`](https://python.temporal.io/temporalio.workflow.ActivityConfig.html) objects to the `TemporalAgent` constructor:
45 changes: 24 additions & 21 deletions pydantic_ai_slim/pydantic_ai/_agent_graph.py
@@ -26,6 +26,7 @@
from pydantic_graph.nodes import End, NodeRunEndT

from . import _output, _system_prompt, exceptions, messages as _messages, models, result, usage as _usage
+from ._run_context import set_current_run_context
from .exceptions import ToolRetryError
from .output import OutputDataT, OutputSpec
from .settings import ModelSettings
@@ -447,25 +448,26 @@ async def stream(
        assert not self._did_stream, 'stream() should only be called once per node'

        model_settings, model_request_parameters, message_history, run_context = await self._prepare_request(ctx)
-        async with ctx.deps.model.request_stream(
-            message_history, model_settings, model_request_parameters, run_context
-        ) as streamed_response:
-            self._did_stream = True
-            ctx.state.usage.requests += 1
-            agent_stream = result.AgentStream[DepsT, T](
-                _raw_stream_response=streamed_response,
-                _output_schema=ctx.deps.output_schema,
-                _model_request_parameters=model_request_parameters,
-                _output_validators=ctx.deps.output_validators,
-                _run_ctx=build_run_context(ctx),
-                _usage_limits=ctx.deps.usage_limits,
-                _tool_manager=ctx.deps.tool_manager,
-            )
-            yield agent_stream
-            # In case the user didn't manually consume the full stream, ensure it is fully consumed here,
-            # otherwise usage won't be properly counted:
-            async for _ in agent_stream:
-                pass
+        with set_current_run_context(run_context):
+            async with ctx.deps.model.request_stream(
+                message_history, model_settings, model_request_parameters, run_context
+            ) as streamed_response:
+                self._did_stream = True
+                ctx.state.usage.requests += 1
+                agent_stream = result.AgentStream[DepsT, T](
+                    _raw_stream_response=streamed_response,
+                    _output_schema=ctx.deps.output_schema,
+                    _model_request_parameters=model_request_parameters,
+                    _output_validators=ctx.deps.output_validators,
+                    _run_ctx=build_run_context(ctx),
+                    _usage_limits=ctx.deps.usage_limits,
+                    _tool_manager=ctx.deps.tool_manager,
+                )
+                yield agent_stream
+                # In case the user didn't manually consume the full stream, ensure it is fully consumed here,
+                # otherwise usage won't be properly counted:
+                async for _ in agent_stream:
+                    pass

        model_response = streamed_response.get()

@@ -478,8 +480,9 @@ async def _make_request(
        if self._result is not None:
            return self._result  # pragma: no cover

-        model_settings, model_request_parameters, message_history, _ = await self._prepare_request(ctx)
-        model_response = await ctx.deps.model.request(message_history, model_settings, model_request_parameters)
+        model_settings, model_request_parameters, message_history, run_context = await self._prepare_request(ctx)
+        with set_current_run_context(run_context):
+            model_response = await ctx.deps.model.request(message_history, model_settings, model_request_parameters)
        ctx.state.usage.requests += 1

        return self._finish_handling(ctx, model_response)
37 changes: 36 additions & 1 deletion pydantic_ai_slim/pydantic_ai/_run_context.py
@@ -1,7 +1,9 @@
from __future__ import annotations as _annotations

import dataclasses
-from collections.abc import Sequence
+from collections.abc import Iterator, Sequence
+from contextlib import contextmanager
+from contextvars import ContextVar
from dataclasses import field
from typing import TYPE_CHECKING, Any, Generic

@@ -71,3 +73,36 @@ def last_attempt(self) -> bool:
        return self.retry == self.max_retries

    __repr__ = _utils.dataclasses_no_defaults_repr


_CURRENT_RUN_CONTEXT: ContextVar[RunContext[Any] | None] = ContextVar(
    'pydantic_ai.current_run_context',
    default=None,
)
"""Context variable storing the current [`RunContext`][pydantic_ai.tools.RunContext]."""


def get_current_run_context() -> RunContext[Any] | None:
    """Get the current run context, if one is set.

    Returns:
        The current [`RunContext`][pydantic_ai.tools.RunContext], or `None` if not in an agent run.
    """
    return _CURRENT_RUN_CONTEXT.get()


@contextmanager
def set_current_run_context(run_context: RunContext[Any]) -> Iterator[None]:
    """Context manager to set the current run context.

    Args:
        run_context: The run context to set as current.

    Yields:
        None
    """
    token = _CURRENT_RUN_CONTEXT.set(run_context)
    try:
        yield
    finally:
        _CURRENT_RUN_CONTEXT.reset(token)
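
To show how the two helpers fit together, here's a small sketch (it imports from the internal module where this diff defines them; the `current_deps` helper is illustrative): the agent graph wraps each model request in `set_current_run_context`, as in the `_agent_graph.py` changes above, and code that runs during the request but isn't passed the `RunContext` directly can look it up with `get_current_run_context`.

```python
from typing import Any

from pydantic_ai._run_context import get_current_run_context


def current_deps() -> Any:
    """Return the deps of the active agent run, or None when called outside a run."""
    run_context = get_current_run_context()
    return None if run_context is None else run_context.deps
```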