Conversation

@mattbrandman mattbrandman commented Nov 24, 2025

Fixes #3407

This pull request introduces enhanced support for model selection and management within Temporal agents, as well as improved handling of run context propagation. The main changes allow registering multiple models with a Temporal agent, selecting models by name or provider string at runtime (inside workflows), and ensuring the current run context is properly tracked across async boundaries. These improvements make it easier to use and configure multiple models in Temporal workflows, while maintaining safety and clarity in model selection.

Model selection and registration for Temporal agents:

  • Added support for registering multiple models with a Temporal agent via the new additional_models argument, and for selecting a model by name or provider string at runtime within workflows. This includes validation to prevent duplicate or invalid model names and ensures that only registered models or provider strings can be selected during workflow execution. (pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_agent.py)

  • Introduced the TemporalProviderFactory type and support for passing a provider factory to Temporal agents and models, enabling custom provider instantiation logic (e.g., injecting API keys from dependencies) when resolving models from provider strings. (pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_agent.py; pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_model.py)

Model selection logic in Temporal model activities:

  • Updated the Temporal model wrapper to support runtime model selection, including a context variable for the current model selection and logic to resolve the correct model instance for each request or stream activity. This ensures the correct model is used for each workflow step, whether by registered name or provider string. (pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_model.py)

Run context propagation improvements:

  • Introduced the CURRENT_RUN_CONTEXT context variable to track the current run context across asynchronous boundaries, and updated agent graph methods to set and reset this variable during model requests and streaming. This ensures that context-dependent logic (such as provider factories) has access to the correct run context throughout execution. (pydantic_ai_slim/pydantic_ai/_run_context.py; pydantic_ai_slim/pydantic_ai/_agent_graph.py)

Other improvements and minor changes:

  • Updated imports and type annotations to support the new features and improve clarity. (pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_agent.py; pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_model.py)

These changes collectively make Temporal agents more flexible, robust, and easier to configure for advanced use cases involving multiple models and dynamic provider selection.

@mattbrandman
Contributor Author

@DouweM let me know if this is more along the lines of what you are thinking. I did test this locally as well in our repo and it worked as expected.

@mattbrandman mattbrandman marked this pull request as ready for review November 24, 2025 21:30
@mattbrandman mattbrandman force-pushed the modelsetting-override branch 2 times, most recently from e089f1e to ef40b85 on November 25, 2025 16:07
@DouweM DouweM self-assigned this Nov 26, 2025

DouweM commented Nov 26, 2025

@mattbrandman Thanks for working on this Matt!

A few high level thoughts:

  1. I wonder if we can use a single TemporalModel and just swap out what self.wrapped points at. We could pass the additional models to TemporalModel, and use a context manager + a model key passed into the request/request_stream activities to select the correct model.
  2. Activity names should never use generated values, as users need to be able to keep already-running activities working when they change the underlying code. So if we need model-specific info in activity names, it should be the (normalized) provider:model name if a string is provided; if an instance is provided, the user should explicitly name it. So maybe have additional_models be dict[str, Model | str] | list[str], not allowing un-named model instances. Note that if point 1 works, we may not need the model name in the activity name at all, as they'll all use the same activities.
  3. If we do that, I suppose we can support arbitrary provider:model strings passed at agent.run time, as the TemporalModel can infer_model them, and they don't need dedicated pre-registered activities. Specific model instances would still need to be registered up front, so they can be referenced by name. Maybe agent.run(model=...) should only take str | KnownModelName then, so we don't need to look things up by id().
  4. I don't know if the extra generic param is worth the effort, as it restricts the types but not the specific instances or registered model names, so we'd need to rely on a runtime check and error anyway.
  5. If we support arbitrary model names on agent.run, it could be worth allowing a provider_factory to be registered to override how providers are built (see the arg on infer_model). Our version of that could also take run context, so the provider can be configured with an API key from deps, for example.
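The single-wrapper idea in point 1 can be sketched with stdlib contextvars alone. This is a hypothetical illustration, not the actual pydantic_ai API: TemporalModelSketch, using_model, and the string-valued models are all stand-ins.

```python
from contextlib import contextmanager
from contextvars import ContextVar
from typing import Iterator


class TemporalModelSketch:
    """Hypothetical single wrapper that swaps the active model by key (point 1)."""

    def __init__(self, default_model: str, additional_models: dict[str, str]) -> None:
        self._models = {'default': default_model, **additional_models}
        # Context variable so the selection is safe across async boundaries.
        self._selection: ContextVar[str] = ContextVar('model_selection', default='default')

    @contextmanager
    def using_model(self, key: str) -> Iterator[None]:
        # Select a registered model for the duration of the block, resetting on exit.
        if key not in self._models:
            raise ValueError(f'unknown model key: {key!r}')
        token = self._selection.set(key)
        try:
            yield
        finally:
            self._selection.reset(token)

    @property
    def wrapped(self) -> str:
        # The model the next request/request_stream activity would use.
        return self._models[self._selection.get()]


model = TemporalModelSketch('openai:gpt-5', {'fast': 'openai:gpt-5-mini'})
with model.using_model('fast'):
    assert model.wrapped == 'openai:gpt-5-mini'
assert model.wrapped == 'openai:gpt-5'
```

Because all selections flow through the one wrapper, every model shares the same request/request_stream activities, which is why point 2's naming concern may disappear entirely.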
@mattbrandman
Contributor Author

@DouweM changed the PR to be more in line with your comments above.

@mattbrandman
Contributor Author

Confirmed this works locally. One thing that appears to need updating, though I'm not entirely sure where, is that telemetry is printing the default registered model.

@mattbrandman mattbrandman force-pushed the modelsetting-override branch 2 times, most recently from d1d85b9 to 4c7e487 on December 2, 2025 19:00
*,
name: str | None = None,
additional_models: Mapping[str, Model | models.KnownModelName | str]
| Sequence[models.KnownModelName | str]
Collaborator:

Model strings that will just be passed to infer_model don't really need to be pre-registered, do they? So we can probably drop this type from the union


self._registered_model_instances: dict[str, Model] = {'default': wrapped_model}
self._registered_model_names: dict[str, str] = {}
self._default_selection = ModelSelection(model_key='default', model_name=None)
Collaborator:

I don't know if we need a whole ModelSelection object. I'd just store a single string, check if it's a key in the model instance map, and if not, call infer_model on it. But there may be another purpose to this that I'm missing

self._register_additional_model(key, value)
else:
for value in additional_models:
if not isinstance(value, str):
Collaborator:

The type already makes this impossible right? We typically don't check things the type checker would've caught already

deps_type=self.deps_type,
run_context_type=self.run_context_type,
event_stream_handler=self.event_stream_handler,
model_selection_var=self._model_selection_var,
Collaborator:

Instead of having a context var that's owned at this level and then passed into the TemporalModel, could TemporalModel own the context var and have a contextmanager method that we could use like with self._temporal_model.using_model('...'):?

key = self._normalize_model_name(name)
self._registered_model_instances[key] = model

def _register_string_model(self, name: str, model_identifier: str) -> None:
Collaborator:

Probably not necessary; see above

self._provider_factory = provider_factory

request_activity_name = f'{activity_name_prefix}__model_request'
request_stream_activity_name = f'{activity_name_prefix}__model_request_stream'
Collaborator:

Since we don't edit this anymore, I'd rather move them back down where they were

self.request_activity.__annotations__['deps'] = deps_type

async def request_stream_activity(params: _RequestParams, deps: AgentDepsT) -> ModelResponse:
async def request_stream_activity(params: _RequestParams, deps: Any) -> ModelResponse:
Collaborator:

Unnecessary change?

@mattbrandman mattbrandman force-pushed the modelsetting-override branch from 25c4692 to bcd5b79 on December 3, 2025 18:56
@mattbrandman
Contributor Author

@DouweM made updates based on feedback.

@mattbrandman mattbrandman force-pushed the modelsetting-override branch from 501af27 to 26269ac on December 4, 2025 01:50
@mattbrandman mattbrandman requested a review from DouweM December 5, 2025 16:38
@mattbrandman
Contributor Author

@DouweM tested locally as well and confirmed it's working correctly with registration and strings. Was a good call to add that; it makes it feel much better.

try:
model_response = await ctx.deps.model.request(message_history, model_settings, model_request_parameters)
finally:
CURRENT_RUN_CONTEXT.reset(token)
Collaborator:

Can we move this to a contextmanager method in _run_context.py, so we can do with set_current_run_context(run_context): here and in the method above?

__repr__ = _utils.dataclasses_no_defaults_repr


CURRENT_RUN_CONTEXT: ContextVar[RunContext[Any] | None] = ContextVar(
Collaborator:

In addition to the set_current_run_context I suggested above, let's add a get_current_run_context that just returns RunContext[Any] | None for convenience. Then the actual contextvar can become private.

if isinstance(model, Model):
raise UserError(
'Model instances cannot be selected at runtime inside a Temporal workflow. '
'Register the model via the mapping form of `additional_models` and reference it by name.'
Collaborator:

Maybe we can see if a key for this model instance exists in the registry?

# Store the model spec so we can re-infer with provider_factory if needed
self._default_model_spec: str | None = None
if provider_factory is not None:
self._default_model_spec = f'{model.system}:{model.model_name}'
Collaborator:

This has an issue because model.system will be openai for both openai-chat:gpt-5 and openai-responses:gpt-5, so we'd lose information here.

I'm not sure why we need this at all?

If this is so that we can use our own provider factory on the default agent model name, I realize I suggested that, but if it takes tricks like this, I'd rather remove this and use the factory only on model names provided at run time.

Contributor Author:

Yup, this was to try to make it usable for the default model. Will remove!

assert self.event_stream_handler is not None

if params.serialized_run_context is None: # pragma: no cover
raise UserError('Serialized run context is required for Temporal streaming activities.')
Collaborator:

Why would we allow it to be none at all then?


self._validate_model_request_parameters(model_request_parameters)

selection = self._current_selection()
Collaborator:

Unclear method/variable name!

def _resolve_model(self, params: _RequestParams, deps: Any | None) -> Model:
selection = params.model_selection
if selection is None:
# Re-infer the default model if provider_factory is set
Collaborator:

As mentioned above, I think I'd prefer to not do this

return models.infer_model(model_name)

def _factory(provider_name: str) -> Provider[Any]:
return provider_factory(provider_name, run_context, deps)
Collaborator:

  • Can run_context be the first arg as it is everywhere else?
  • I don't think we need to pass in deps explicitly since it's already in run_context; we don't pass it separately anywhere else
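The suggested signature change can be sketched like this, with run context first and deps read off the context rather than passed separately. RunContext is a minimal stand-in and the dict-shaped "provider" is purely illustrative:

```python
from dataclasses import dataclass
from typing import Any


@dataclass
class RunContext:
    # Minimal stand-in for pydantic_ai's RunContext; deps already lives on it,
    # so the factory never needs a separate `deps` parameter.
    deps: Any = None


def provider_factory(run_context: RunContext, provider_name: str) -> dict[str, Any]:
    # Hypothetical factory: configure the provider with an API key taken from deps.
    return {'provider': provider_name, 'api_key': run_context.deps['api_key']}


provider = provider_factory(RunContext(deps={'api_key': 'secret'}), 'openai')
assert provider == {'provider': 'openai', 'api_key': 'secret'}
```

Keeping run_context as the first argument matches the convention used elsewhere in the codebase, and dropping the explicit deps argument removes a redundant parameter.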
@mattbrandman mattbrandman force-pushed the modelsetting-override branch 3 times, most recently from a64bbf0 to 5d94756 on December 10, 2025 19:28
@mattbrandman mattbrandman requested a review from DouweM December 10, 2025 20:44

from pydantic_ai import Agent
from pydantic_ai.durable_exec.temporal import TemporalAgent
from pydantic_ai.models.openai import OpenAIModel
Collaborator:

This class is deprecated :)

# Create models
default_model = OpenAIModel('gpt-4o')
fast_model = OpenAIModel('gpt-4o-mini')
reasoning_model = OpenAIModel('o1')
Collaborator:

Let's use more recent models from different providers to illustrate that they don't need to be from the same provider


1. **By registered name**: Pass the string key used in `additional_models` (e.g., `model='fast'`)
2. **By registered instance**: Pass a model instance that was registered via `additional_models` or as the agent's default model
3. **By provider string**: Pass a model string like `'openai:gpt-4o'` which will be instantiated at runtime
Collaborator:

Mention the provider factory here!

'pydantic_ai.current_run_context',
default=None,
)
"""Context variable storing the current [`RunContext`][pydantic_ai._run_context.RunContext]."""
Collaborator:

This link will not work because it's to a private module, we need to get it from the tools module



async def test_temporal_agent_multi_model_backward_compatibility():
"""Test that single-model agents work as before (backward compatibility)."""
Collaborator:

I don't think we need this test, as all the existing tests already verify the one-model case works.

# This test verifies that the agent works correctly outside workflows
# even when additional models are registered
result = await temporal_agent.run('Hello')
assert result.output == 'Model 1 response'
Collaborator:

This is the same as the assert above 😄

Please carefully review all the tests for necessity/sanity.

I think passing a registered model ID should also work outside of a workflow, and I think passing an arbitrary model instance should as well, as a TemporalAgent outside of a workflow should behave like an Agent.


# Unregistered model string should pass through directly
selected = temporal_agent._select_model('openai:gpt-4o') # pyright: ignore[reportPrivateUsage]
assert selected == 'openai:gpt-4o'
Collaborator:

I don't want to test private methods like this, just the behavior of the feature.

A lot of the tests here seem unnecessary.

provider_run_ctx = provider_calls[0][0]
assert provider_run_ctx is not None
assert provider_run_ctx.run_id == 'run-123'
assert provider_run_ctx.deps == {'api_key': 'secret'}
Collaborator:

Can we test the provider factory functionality without all the mocking? Just using the exact same code a user would write


assert params.serialized_run_context is None
# deps should also be None
assert deps is None
Collaborator:

I didn't review the remaining tests; I think we should be able to remove the vast majority of these and only have tests with code as the user would write it, that verify it works as expected.

@mattbrandman
Contributor Author

@DouweM thanks for the review, will update. I was having some trouble hitting 100% coverage on the tests, hence some of the private-method testing; let me take another pass and see if I can get the coverage needed. As a side note, is there any way to easily run the coverage check? I noticed it runs as an action after everything else, but it would be nice to be able to run it quickly locally.

@DouweM
Collaborator

DouweM commented Dec 12, 2025

As a side note, is there any way to easily run the coverage check? I noticed it runs as an action after everything else, but it would be nice to be able to run it quickly locally.

@mattbrandman make test should do that for you!


return model

def resolve_model_for_outside_workflow(
Collaborator:

I think the provider factory should still be run for model strings in this case. So maybe this should use _get_model_id + _resolve_model?

Not sure if that's exactly right, but I feel like these 3 methods have too much overlap to not be combinable 😄

I also think we can come up with a better name for this method. Maybe just resolve_model?

Contributor Author:

Had to make a few changes, since some of the code we call previously expected a run context but this function doesn't have it passed in.

@mattbrandman mattbrandman requested a review from DouweM December 19, 2025 22:08
1. **Pre-register model instances**: Pass a dict of model instances to `TemporalAgent(models={...})`, then reference them by name or by passing the registered instance directly. If the wrapped agent doesn't have a model set, the first registered model will be used as the default.
2. **Use model strings**: Model strings work as expected.
3. **Use a provider factory**: For scenarios where you need to customize the provider (e.g., inject API keys from deps), pass a `provider_factory` to `TemporalAgent`.
Model strings work as expected. For scenarios where you need to customize the provider used by the model string (e.g., inject API keys from deps), you can pass a `provider_factory` to `TemporalAgent`.
Collaborator:

... which is passed the run context (with a link). I liked the explicit mention you had of getting API keys from deps.

if m is model:
return m
# Not registered, return as-is for outside workflow use
return model
Collaborator:

The if above basically does nothing right? if m is model: return m is equivalent to return model, so just if isinstance(model, Model): return model? Maybe with an in_workflow check?

)
resolved_model = None
else:
resolved_model = self._temporal_model.resolve_model(model)
Collaborator:

It feels unexpected that provider_factory is not used for model strings when outside the workflow (because there's no ctx), but I don't think there's anything we can do about it, so we should document it (in the arg docstring) at least.

@DouweM DouweM changed the title Allow multiple models in temporal agent Allow TemporalAgent to switch model at agent.run-time Dec 19, 2025
@mattbrandman mattbrandman requested a review from DouweM December 20, 2025 00:53
@DouweM DouweM merged commit ad5ba4f into pydantic:main Dec 20, 2025
30 checks passed

2 participants