Commit 3c12489

Initial push of gensim_with_word2vec_model_artifacts_local_serving sample code
1 parent 9591000 commit 3c12489

File tree

5 files changed: +93 -0 lines changed

.gitignore

Lines changed: 4 additions & 0 deletions
@@ -88,3 +88,7 @@ scikit_learn_bring_your_own_container_and_own_model_local_serving/data/train/cal
 scikit_learn_bring_your_own_container_and_own_model_local_serving/data/validation/california_validation.csv
 scikit_learn_bring_your_own_container_and_own_model_local_serving/model.joblib
 scikit_learn_bring_your_own_container_and_own_model_local_serving/model.tar.gz
+gensim_with_word2vec_model_artifacts_local_serving/eval.json
+gensim_with_word2vec_model_artifacts_local_serving/model.tar.gz
+gensim_with_word2vec_model_artifacts_local_serving/vectors.bin
+gensim_with_word2vec_model_artifacts_local_serving/vectors.txt
Lines changed: 29 additions & 0 deletions
@@ -0,0 +1,29 @@
import os
import json
from gensim.models import KeyedVectors


def input_fn(request_body, request_content_type):
    print(f"request_body: {request_body}")
    print(type(request_body))
    if request_content_type == "application/json":
        payload = json.loads(request_body)
        instances = payload["instances"]
        return instances


def predict_fn(instances, word_vectors):
    #########################################
    # Do your custom preprocessing logic here
    #########################################

    print(f"instances: {instances}")
    print("calling model")
    predictions = word_vectors.most_similar(positive=instances)
    return predictions


def model_fn(model_dir):
    print("loading model from: {}".format(model_dir))
    word_vectors = KeyedVectors.load_word2vec_format(os.path.join(model_dir, "vectors.txt"), binary=False)
    print(f'word vectors length: {len(word_vectors)}')
    return word_vectors
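
A minimal sketch (not part of the commit) for exercising the three handlers above outside of SageMaker. It assumes the module is importable as inference.py (the entry point named in the deployment script) and that model.tar.gz has been extracted into a local directory named model_dir containing vectors.txt; both the directory name and the extraction step are illustrative assumptions.

import json
import inference  # the handler module added in this commit

# Load the KeyedVectors the same way the serving container would.
word_vectors = inference.model_fn("model_dir")

# Deserialize a JSON request body, then run the prediction handler on it.
instances = inference.input_fn(json.dumps({"instances": ["king", "queen"]}), "application/json")
predictions = inference.predict_fn(instances, word_vectors)
print(predictions)  # list of (word, cosine similarity) tuples from most_similar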
Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
gensim==4.1.2
Lines changed: 55 additions & 0 deletions
@@ -0,0 +1,55 @@
# This is a sample Python program that serves a Word2Vec model, trained with the BlazingText algorithm, with inference using gensim.
# This implementation will work on your *local computer* or in the *AWS Cloud*.
#
# Prerequisites:
# 1. Install required Python packages:
#    `pip install -r requirements.txt`
# 2. Docker Desktop installed and running on your computer:
#    `docker ps`
# 3. You should have AWS credentials configured on your local machine
#    in order to be able to pull the docker image from ECR.
###############################################################################################

import boto3
from sagemaker.deserializers import JSONDeserializer
from sagemaker.serializers import JSONSerializer
from sagemaker.sklearn import SKLearnModel


DUMMY_IAM_ROLE = 'arn:aws:iam::111111111111:role/service-role/AmazonSageMaker-ExecutionRole-20200101T000001'
s3 = boto3.client('s3')


def main():

    # Download a pre-trained model archive file
    print('Downloading a pre-trained model archive file')
    s3.download_file('aws-ml-blog', 'artifacts/word2vec_algorithm_model_artifacts/model.tar.gz', 'model.tar.gz')

    print('Deploying endpoint in local mode')
    print(
        'Note: if launching for the first time in local mode, container image download might take a few minutes to complete.')
    model = SKLearnModel(
        role=DUMMY_IAM_ROLE,
        model_data='file://./model.tar.gz',
        framework_version='0.23-1',
        py_version='py3',
        source_dir='code',
        entry_point='inference.py'
    )

    print('Deploying endpoint in local mode')
    predictor = model.deploy(initial_instance_count=1, instance_type='local')

    payload = {"instances": ["king", "queen"]}
    predictor.serializer = JSONSerializer()
    predictor.deserializer = JSONDeserializer()
    predictions = predictor.predict(payload)
    print(f"Predictions: {predictions}")

    print('About to delete the endpoint to stop paying (if in cloud mode).')
    predictor.delete_endpoint(predictor.endpoint_name)


if __name__ == "__main__":
    main()
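
As a rough illustration (not part of the sample), the local-mode container started by model.deploy(..., instance_type='local') can also be invoked directly over HTTP via the standard /invocations route of SageMaker inference containers. This sketch assumes local mode has published the serving container on localhost:8080 (its usual default; adjust if your port mapping differs) and uses the requests library, which is not listed in the sample's requirements.

import json
import requests

# POST the same JSON payload the predictor sends; the handler's input_fn
# expects {"instances": [...]} with Content-Type application/json.
response = requests.post(
    "http://localhost:8080/invocations",  # assumption: default local-mode port mapping
    data=json.dumps({"instances": ["king", "queen"]}),
    headers={"Content-Type": "application/json"},
)
print(response.json())  # list of [word, cosine similarity] pairs from most_similar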
Lines changed: 4 additions & 0 deletions
@@ -0,0 +1,4 @@
numpy
pandas
sagemaker>=2.0.0,<3.0.0
sagemaker[local]
