Skip to content

Commit 9d6755e

Browse files
committed
OFA PyTorch sample - Ensure full code locality
1 parent 83390b8 commit 9d6755e

File tree

2 files changed

+13
-4
lines changed

2 files changed

+13
-4
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,3 +179,4 @@ xgboost_script_mode_local_training_and_serving/code/build/lib/build/lib/build/li
179179
xgboost_script_mode_local_training_and_serving/model/model.tar.gz
180180
xgboost_script_mode_local_training_and_serving/model/output.tar.gz
181181
xgboost_script_mode_local_training_and_serving/model/xgboost-model
182+
pytorch_extend_dlc_container_ofa_local_serving/model.tar.gz

pytorch_extend_dlc_container_ofa_local_serving/pytorch_extend_dlc_container_ofa_local_serving.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,26 +15,34 @@
1515
from PIL import Image
1616
import numpy as np
1717
from sagemaker.pytorch import PyTorchModel
18+
from sagemaker.local import LocalSession
19+
import boto3
1820

1921
DUMMY_IAM_ROLE = 'arn:aws:iam::111111111111:role/service-role/AmazonSageMaker-ExecutionRole-20200101T000001'
22+
LOCAL_SESSION = LocalSession()
23+
LOCAL_SESSION.config={'local': {'local_code': True}} # Ensure full code locality, see: https://sagemaker.readthedocs.io/en/stable/overview.html#local-mode
2024

2125

2226
def main():
2327

2428
image = 'sagemaker-ofa-pytorch-extended-local'
2529

30+
print('Downloading model file from S3')
31+
s3 = boto3.client('s3')
32+
s3.download_file('aws-ml-blog', 'artifacts/pytorch-extend-dlc-container-ofa-tiny/model.tar.gz', 'model.tar.gz')
33+
print('Model downloaded')
34+
2635
ofa_hf_model = PyTorchModel(
2736
source_dir="code",
2837
entry_point="inference.py",
2938
role=DUMMY_IAM_ROLE,
30-
model_data="s3://aws-ml-blog/artifacts/pytorch-extend-dlc-container-ofa-tiny/model.tar.gz",
39+
model_data="./model.tar.gz",
3140
image_uri=image,
32-
framework_version='1.8'
41+
framework_version='1.8',
42+
sagemaker_session=LOCAL_SESSION
3343
)
3444

3545
print('Deploying endpoint in local mode')
36-
print(
37-
'Note: model download might take a few minutes to complete due to its size.')
3846
predictor = ofa_hf_model.deploy(
3947
initial_instance_count=1,
4048
instance_type='local',

0 commit comments

Comments
 (0)