Allow TemporalAgent to switch model at agent.run-time #3537
Conversation
@DouweM let me know if this is more along the lines of what you are thinking. I did test this locally as well in our repo and it worked as expected.

e089f1e to ef40b85

@mattbrandman Thanks for working on this, Matt! A few high-level thoughts:

f6a2ec9 to 9445096

@DouweM Changed the PR to be more in line with your comments above.

Confirmed this works locally. One thing that does appear to need updating, though I'm not entirely sure where, is that telemetry is printing the default registered model.
d1d85b9 to 4c7e487

```python
*,
name: str | None = None,
additional_models: Mapping[str, Model | models.KnownModelName | str]
| Sequence[models.KnownModelName | str]
```
Model strings that will just be passed to infer_model don't really need to be pre-registered, do they? So we can probably drop this type from the union
```python
self._registered_model_instances: dict[str, Model] = {'default': wrapped_model}
self._registered_model_names: dict[str, str] = {}
self._default_selection = ModelSelection(model_key='default', model_name=None)
```
I don't know if we need a whole ModelSelection object. I'd just store a single string, check if it's a key in the model instance map, and if not, call infer_model on it. But there may be another purpose to this that I'm missing
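For illustration, a self-contained sketch of that simpler scheme. The `ModelRegistry` class and the `infer_model` stub here are hypothetical stand-ins for illustration, not the PR's actual API:

```python
from contextvars import ContextVar


class Model:
    """Toy stand-in for pydantic-ai's Model."""

    def __init__(self, name: str):
        self.name = name


def infer_model(spec: str) -> Model:
    # Stand-in for models.infer_model: would parse 'provider:model' and build a real Model.
    return Model(spec)


class ModelRegistry:
    """Stores a single selection string instead of a ModelSelection object."""

    def __init__(self, default: Model):
        self._instances: dict[str, Model] = {'default': default}
        self.selection: ContextVar[str] = ContextVar('selection', default='default')

    def resolve(self) -> Model:
        key = self.selection.get()
        # If the selection is a registered key, use that instance;
        # otherwise treat it as a model string and infer it.
        if key in self._instances:
            return self._instances[key]
        return infer_model(key)


registry = ModelRegistry(Model('gpt-default'))
registry.selection.set('openai:gpt-5')
print(registry.resolve().name)  # 'openai:gpt-5': falls through to infer_model
```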
```python
    self._register_additional_model(key, value)
else:
    for value in additional_models:
        if not isinstance(value, str):
```
The type already makes this impossible, right? We typically don't check things the type checker would've caught already.
```python
deps_type=self.deps_type,
run_context_type=self.run_context_type,
event_stream_handler=self.event_stream_handler,
model_selection_var=self._model_selection_var,
```
Instead of having a context var that's owned at this level and then passed into the `TemporalModel`, could `TemporalModel` own the context var and have a contextmanager method that we could use like `with self._temporal_model.using_model('...'):`?
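Something like the following could work; a minimal sketch assuming `TemporalModel` owns the var (the class and attribute names here are hypothetical):

```python
from contextlib import contextmanager
from contextvars import ContextVar
from typing import Iterator, Optional


class TemporalModelSketch:
    """Toy stand-in for TemporalModel that owns its own selection var."""

    def __init__(self, default: str):
        self._default = default
        self._selected: ContextVar[Optional[str]] = ContextVar('selected_model', default=None)

    @contextmanager
    def using_model(self, name: str) -> Iterator[None]:
        token = self._selected.set(name)
        try:
            yield
        finally:
            # Restore the previous selection even if the body raises.
            self._selected.reset(token)

    def current_model(self) -> str:
        return self._selected.get() or self._default


model = TemporalModelSketch('default')
with model.using_model('fast'):
    print(model.current_model())  # 'fast'
print(model.current_model())  # 'default' again once the block exits
```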
```python
    key = self._normalize_model_name(name)
    self._registered_model_instances[key] = model

def _register_string_model(self, name: str, model_identifier: str) -> None:
```
Probably not necessary; see above
```python
self._provider_factory = provider_factory

request_activity_name = f'{activity_name_prefix}__model_request'
request_stream_activity_name = f'{activity_name_prefix}__model_request_stream'
```
Since we don't edit this anymore, I'd rather move them back down where they were
```diff
  self.request_activity.__annotations__['deps'] = deps_type

- async def request_stream_activity(params: _RequestParams, deps: AgentDepsT) -> ModelResponse:
+ async def request_stream_activity(params: _RequestParams, deps: Any) -> ModelResponse:
```
Unnecessary change?
25c4692 to bcd5b79

@DouweM Made updates based on feedback.

501af27 to 26269ac

@DouweM Tested locally as well and confirmed it's working correctly with registration and strings. Was a good call to add that; it makes it feel much better.
```python
try:
    model_response = await ctx.deps.model.request(message_history, model_settings, model_request_parameters)
finally:
    CURRENT_RUN_CONTEXT.reset(token)
```
Can we move this to a contextmanager method in `_run_context.py`, so we can do `with set_current_run_context(run_context):` here and in the method above?
```python
__repr__ = _utils.dataclasses_no_defaults_repr


CURRENT_RUN_CONTEXT: ContextVar[RunContext[Any] | None] = ContextVar(
```
In addition to the `set_current_run_context` I suggested above, let's add a `get_current_run_context` that just returns `RunContext[Any] | None` for convenience. Then the actual contextvar can become private.
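A rough sketch of that shape (the helper names follow the suggestion above; `RunContext` is simplified to `Any` here, so this is an illustration rather than the final implementation):

```python
from contextlib import contextmanager
from contextvars import ContextVar
from typing import Any, Iterator, Optional

# Private module-level var; consumers only see the two helpers below.
_current_run_context: ContextVar[Optional[Any]] = ContextVar('_current_run_context', default=None)


@contextmanager
def set_current_run_context(run_context: Any) -> Iterator[None]:
    token = _current_run_context.set(run_context)
    try:
        yield
    finally:
        # Reset in a finally block so the var is restored even on error.
        _current_run_context.reset(token)


def get_current_run_context() -> Optional[Any]:
    return _current_run_context.get()


with set_current_run_context({'run_id': 'run-123'}):
    print(get_current_run_context())  # {'run_id': 'run-123'}
print(get_current_run_context())  # None outside the block
```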
```python
if isinstance(model, Model):
    raise UserError(
        'Model instances cannot be selected at runtime inside a Temporal workflow. '
        'Register the model via the mapping form of `additional_models` and reference it by name.'
```
Maybe we can see if a key for this model instance exists in the registry?
```python
# Store the model spec so we can re-infer with provider_factory if needed
self._default_model_spec: str | None = None
if provider_factory is not None:
    self._default_model_spec = f'{model.system}:{model.model_name}'
```
This has an issue because model.system will be openai for both openai-chat:gpt-5 and openai-responses:gpt-5, so we'd lose information here.
I'm not sure why we need this at all?
If this is so that we can use our own provider factory on the default agent model name, I realize I suggested that, but if it takes tricks like this, I'd rather remove this and use the factory only on model names provided at run time.
Yup, this was to try and make it usable for the default model; will remove!
```python
assert self.event_stream_handler is not None

if params.serialized_run_context is None:  # pragma: no cover
    raise UserError('Serialized run context is required for Temporal streaming activities.')
```
Why would we allow it to be `None` at all then?
```python
self._validate_model_request_parameters(model_request_parameters)

selection = self._current_selection()
```
Unclear method/variable name!
```python
def _resolve_model(self, params: _RequestParams, deps: Any | None) -> Model:
    selection = params.model_selection
    if selection is None:
        # Re-infer the default model if provider_factory is set
```
As mentioned above, I think I'd prefer to not do this
```python
    return models.infer_model(model_name)

def _factory(provider_name: str) -> Provider[Any]:
    return provider_factory(provider_name, run_context, deps)
```
- Can `run_context` be the first arg, as it is everywhere else?
- I don't think we need to pass in `deps` explicitly since it's already in `run_context`; we don't pass it separately anywhere else.
a64bbf0 to 5d94756

docs/durable_execution/temporal.md Outdated

```python
from pydantic_ai import Agent
from pydantic_ai.durable_exec.temporal import TemporalAgent
from pydantic_ai.models.openai import OpenAIModel
```
This class is deprecated :)
docs/durable_execution/temporal.md Outdated

```python
# Create models
default_model = OpenAIModel('gpt-4o')
fast_model = OpenAIModel('gpt-4o-mini')
reasoning_model = OpenAIModel('o1')
```
Let's use more recent models from different providers to illustrate that they don't need to be from the same provider
docs/durable_execution/temporal.md Outdated

1. **By registered name**: Pass the string key used in `additional_models` (e.g., `model='fast'`)
2. **By registered instance**: Pass a model instance that was registered via `additional_models` or as the agent's default model
3. **By provider string**: Pass a model string like `'openai:gpt-4o'`, which will be instantiated at runtime
Mention the provider factory here!
```python
    'pydantic_ai.current_run_context',
    default=None,
)
"""Context variable storing the current [`RunContext`][pydantic_ai._run_context.RunContext]."""
```
This link will not work because it's to a private module, we need to get it from the tools module
tests/test_temporal.py Outdated

```python
async def test_temporal_agent_multi_model_backward_compatibility():
    """Test that single-model agents work as before (backward compatibility)."""
```
I don't think we need this test, as all the existing tests already verify the one-model case works.
```python
# This test verifies that the agent works correctly outside workflows
# even when additional models are registered
result = await temporal_agent.run('Hello')
assert result.output == 'Model 1 response'
```
This is the same as the assert above 😄
Please carefully review all the tests for necessity/sanity.
I think passing a registered model ID should also work outside of a workflow, and I think passing an arbitrary model instance should as well, as a TemporalAgent outside of a workflow should behave like an Agent.
tests/test_temporal.py Outdated

```python
# Unregistered model string should pass through directly
selected = temporal_agent._select_model('openai:gpt-4o')  # pyright: ignore[reportPrivateUsage]
assert selected == 'openai:gpt-4o'
```
I don't want to test private methods like this, just the behavior of the feature.
A lot of the tests here seem unnecessary.
tests/test_temporal.py Outdated

```python
provider_run_ctx = provider_calls[0][0]
assert provider_run_ctx is not None
assert provider_run_ctx.run_id == 'run-123'
assert provider_run_ctx.deps == {'api_key': 'secret'}
```
Can we test the provider factory functionality without all the mocking? Just using the exact same code a user would write
tests/test_temporal.py Outdated

```python
assert params.serialized_run_context is None
# deps should also be None
assert deps is None
```
I didn't review the remaining tests; I think we should be able to remove the vast majority of these and only have tests with code as the user would write it, that verify it works as expected.

@DouweM thanks for the review, will update. I was having some trouble hitting 100% coverage on the tests, hence some of the private-method testing; let me take another pass and see if I can get the coverage needed. As a side note, is there any way to easily run the coverage check locally? I noticed it ran as an action after everything else, but it would be nice to be able to run it quickly locally.
@mattbrandman |
718c2f3 to 0349a82

```python
    return model

def resolve_model_for_outside_workflow(
```
I think the provider factory should still be run for model strings in this case. So maybe this should use `_get_model_id` + `_resolve_model`?
Not sure if that's exactly right, but I feel like these 3 methods have too much overlap to not be combinable 😄
I also think we can come up with a better name for this method. Maybe just `resolve_model`?
Had to make a few changes, since some of the code we call previously expected a run context but this function doesn't have it passed in.
docs/durable_execution/temporal.md Outdated

1. **Pre-register model instances**: Pass a dict of model instances to `TemporalAgent(models={...})`, then reference them by name or by passing the registered instance directly. If the wrapped agent doesn't have a model set, the first registered model will be used as the default.
2. **Use model strings**: Model strings work as expected.
3. **Use a provider factory**: For scenarios where you need to customize the provider (e.g., inject API keys from deps), pass a `provider_factory` to `TemporalAgent`.

Model strings work as expected. For scenarios where you need to customize the provider used by the model string (e.g., inject API keys from deps), you can pass a `provider_factory` to `TemporalAgent`.
... which is passed the run context (with a link). I liked the explicit mention you had of getting API keys from deps
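The doc could sketch that wiring roughly like this. The dataclasses below are simplified stand-ins for pydantic-ai's real `RunContext` and `Provider` types, and the factory signature follows the `run_context`-first suggestion from this thread rather than a settled API:

```python
from dataclasses import dataclass
from typing import Any


@dataclass
class RunContextSketch:
    """Simplified stand-in for pydantic-ai's RunContext."""
    deps: Any


@dataclass
class ProviderSketch:
    """Simplified stand-in for a configured Provider."""
    name: str
    api_key: str


def provider_factory(run_context: RunContextSketch, provider_name: str) -> ProviderSketch:
    # Pull the API key out of the run's deps instead of the environment.
    return ProviderSketch(name=provider_name, api_key=run_context.deps['api_key'])


ctx = RunContextSketch(deps={'api_key': 'secret'})
provider = provider_factory(ctx, 'openai')
print(provider.api_key)  # 'secret'
```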
```python
if m is model:
    return m
# Not registered, return as-is for outside workflow use
return model
```
The `if` above basically does nothing, right? `if m is model: return m` is equivalent to `return model`, so just `if isinstance(model, Model): return model`? Maybe with an `in_workflow` check?
```python
    )
    resolved_model = None
else:
    resolved_model = self._temporal_model.resolve_model(model)
```
It feels unexpected that provider_factory is not used for model strings when outside the workflow (because there's no ctx), but I don't think there's anything we can do about it, so we should document it (in the arg docstring) at least.
627bc0d to 4072fd0

40c3dce to 61cf85f
Fixes #3407
This pull request introduces enhanced support for model selection and management within Temporal agents, as well as improved handling of run context propagation. The main changes allow registering multiple models with a Temporal agent, selecting models by name or provider string at runtime (inside workflows), and ensuring the current run context is properly tracked across async boundaries. These improvements make it easier to use and configure multiple models in Temporal workflows, while maintaining safety and clarity in model selection.
Model selection and registration for Temporal agents:

- Added support for registering multiple models with a Temporal agent via the new `additional_models` argument, and for selecting a model by name or provider string at runtime within workflows. This includes validation to prevent duplicate or invalid model names and ensures that only registered models or provider strings can be selected during workflow execution. (`pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_agent.py`)
- Introduced the `TemporalProviderFactory` type and support for passing a provider factory to Temporal agents and models, enabling custom provider instantiation logic (e.g., injecting API keys from dependencies) when resolving models from provider strings. (`pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_agent.py`; `pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_model.py`)

Model selection logic in Temporal model activities: (`pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_model.py`, R81-R112)

Run context propagation improvements:

- Added a `CURRENT_RUN_CONTEXT` context variable to track the current run context across asynchronous boundaries, and updated agent graph methods to set and reset this variable during model requests and streaming. This ensures that context-dependent logic (such as provider factories) has access to the correct run context throughout execution. (`pydantic_ai_slim/pydantic_ai/_run_context.py`; `pydantic_ai_slim/pydantic_ai/_agent_graph.py`)

Other improvements and minor changes: (`pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_agent.py`; `pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_model.py`)

These changes collectively make Temporal agents more flexible, robust, and easier to configure for advanced use cases involving multiple models and dynamic provider selection.