Skip to content

Commit 597eda2

Browse files
better forward compatibility with Pydantic v2 (#162)
* better forward compatibility with pydantic v2
* format change
* fix format
1 parent 8f1231c commit 597eda2

File tree

5 files changed

+39
-22
lines changed

5 files changed

+39
-22
lines changed

launch/__init__.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,16 @@
66
77
"""
88

9+
import warnings
910
from typing import Sequence
1011

1112
import pkg_resources
13+
import pydantic
14+
15+
if pydantic.VERSION.startswith("2."):
16+
# HACK: Suppress warning from pydantic v2 about protected namespaces; this is because the
17+
# launch-python-client module is based on v1 and does only the minimum needed to support forward compatibility
18+
warnings.filterwarnings("ignore", category=UserWarning, module="pydantic")
1219

1320
from .client import LaunchClient
1421
from .connection import Connection

launch/docker_image_batch_job_bundle.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -34,18 +34,18 @@ class DockerImageBatchJobBundleResponse(BaseModel):
3434
"""The command to run inside the docker image"""
3535
env: Dict[str, str]
3636
"""Environment variables to be injected into the docker image"""
37-
mount_location: Optional[str]
37+
mount_location: Optional[str] = None
3838
"""Location of a json-formatted file to mount inside the docker image.
3939
Contents get populated at runtime, and this is the method to change behavior on runtime."""
40-
cpus: Optional[str]
40+
cpus: Optional[str] = None
4141
"""Default number of cpus to give to the docker image"""
42-
memory: Optional[str]
42+
memory: Optional[str] = None
4343
"""Default amount of memory to give to the docker image"""
44-
storage: Optional[str]
44+
storage: Optional[str] = None
4545
"""Default amount of disk to give to the docker image"""
46-
gpus: Optional[int]
46+
gpus: Optional[int] = None
4747
"""Default number of gpus to give to the docker image"""
48-
gpu_type: Optional[str]
48+
gpu_type: Optional[str] = None
4949
"""Default type of gpu, e.g. nvidia-tesla-t4, nvidia-ampere-a10 to give to the docker image"""
5050

5151

launch/fine_tune.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ class CreateFineTuneResponse(BaseModel):
2828
class GetFineTuneResponse(BaseModel):
2929
id: str
3030
"""ID of the requested job"""
31-
fine_tuned_model: Optional[str]
31+
fine_tuned_model: Optional[str] = None
3232
"""
3333
Name of the resulting fine-tuned model. This can be plugged into the
3434
Completion API once the fine-tune is complete

launch/model_bundle.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ class ZipArtifactFlavor(BaseModel):
9494
[`CustomFramework`](./#launch.model_bundle.CustomFramework).
9595
"""
9696

97-
app_config: Optional[Dict[str, Any]]
97+
app_config: Optional[Dict[str, Any]] = None
9898
"""Optional configuration for the application."""
9999

100100
location: str
@@ -112,7 +112,7 @@ class RunnableImageLike(BaseModel, ABC):
112112
repository: str
113113
tag: str
114114
command: List[str]
115-
env: Optional[Dict[str, str]]
115+
env: Optional[Dict[str, str]] = None
116116
protocol: Literal["http"] # TODO: add support for other protocols (e.g. grpc)
117117
readiness_initial_delay_seconds: int = 120
118118

@@ -137,15 +137,15 @@ class TritonEnhancedRunnableImageFlavor(RunnableImageLike):
137137

138138
triton_model_repository: str
139139

140-
triton_model_replicas: Optional[Dict[str, str]]
140+
triton_model_replicas: Optional[Dict[str, str]] = None
141141

142142
triton_num_cpu: float
143143

144144
triton_commit_tag: str
145145

146-
triton_storage: Optional[str]
146+
triton_storage: Optional[str] = None
147147

148-
triton_memory: Optional[str]
148+
triton_memory: Optional[str] = None
149149

150150
triton_readiness_initial_delay_seconds: int = 300
151151

@@ -197,7 +197,7 @@ class ModelBundleV2Response(BaseModel):
197197
model_artifact_ids: List[str]
198198
"""IDs of the Model Artifacts associated with the Model Bundle."""
199199

200-
schema_location: Optional[str]
200+
schema_location: Optional[str] = None
201201

202202
flavor: ModelBundleFlavors = Field(..., discriminator="flavor")
203203
"""Flavor of the Model Bundle, representing how the model bundle was packaged.

launch/pydantic_schemas.py

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,25 @@
11
from enum import Enum
2-
from typing import Any, Dict, Set, Type, Union
2+
from typing import Any, Callable, Dict, Set, Type, Union
33

4+
import pydantic
45
from pydantic import BaseModel
56

6-
try:
7+
if hasattr(pydantic, "VERSION") and pydantic.VERSION.startswith("1."):
8+
PYDANTIC_VERSION = 1
79
from pydantic.schema import (
810
get_flat_models_from_models,
911
model_process_schema,
1012
)
11-
except ImportError:
12-
# We assume this is due to the user having pydantic 2.x installed.
13-
from pydantic.v1.schema import ( # type: ignore
14-
get_flat_models_from_models,
15-
model_process_schema,
16-
)
13+
elif hasattr(pydantic, "VERSION") and pydantic.VERSION.startswith("2."):
14+
PYDANTIC_VERSION = 2
15+
else:
16+
raise ImportError("Unsupported pydantic version.")
1717

1818

1919
REF_PREFIX = "#/components/schemas/"
2020

2121

22-
def get_model_definitions(request_schema: Type[BaseModel], response_schema: Type[BaseModel]) -> Dict[str, Any]:
22+
def get_model_definitions_v1(request_schema: Type[BaseModel], response_schema: Type[BaseModel]) -> Dict[str, Any]:
2323
"""
2424
Gets the model schemas in jsonschema format from a sequence of Pydantic BaseModels.
2525
"""
@@ -29,6 +29,16 @@ def get_model_definitions(request_schema: Type[BaseModel], response_schema: Type
2929
return get_model_definitions_from_flat_models(flat_models=flat_models, model_name_map=model_name_map)
3030

3131

32+
def get_model_definitions_v2(request_schema: Type[BaseModel], response_schema: Type[BaseModel]) -> Dict[str, Any]:
33+
return {"RequestSchema": request_schema.model_json_schema(), "ResponseSchema": response_schema.model_json_schema()}
34+
35+
36+
if PYDANTIC_VERSION == 1:
37+
get_model_definitions: Callable = get_model_definitions_v1
38+
elif PYDANTIC_VERSION == 2:
39+
get_model_definitions: Callable = get_model_definitions_v2
40+
41+
3242
def get_model_definitions_from_flat_models(
3343
*,
3444
flat_models: Set[Union[Type[BaseModel], Type[Enum]]],

0 commit comments

Comments (0)