DEV Community

Petter Gustafsson


SageMaker model CI/CD

Intro

On our journey to build cheap, scalable GPU inference at Tonar, we chose SageMaker Asynchronous Inference. We manage our infrastructure with Terraform, and this is a short how-to on getting CI/CD to actually work when you continuously update models behind SageMaker endpoints. We'll use a custom Whisper model as the running example.

Expected behaviour...

Going by the AWS and Terraform documentation, you would end up with something like the following code for an autoscaling endpoint.

resource "aws_sagemaker_model" "whisper" {
  execution_role_arn = aws_iam_role.whisper.arn
  name               = "whisper"

  primary_container {
    image          = data.aws_ssm_parameter.whisper_digest_id.value
    model_data_url = "s3://..."
    environment = {
      STAGE              = terraform.workspace
      CONTAINER_ID       = data.aws_ssm_parameter.whisper_digest_id.value
      SENTRY_DSN         = data.aws_ssm_parameter.sentry_dsn.value
      SENTRY_ENVIRONMENT = terraform.workspace
      INSTANCE_TYPE      = "ml.g4dn.xlarge"
    }
  }
}

resource "aws_sagemaker_endpoint_configuration" "whisper" {
  production_variants {
    model_name                                        = aws_sagemaker_model.whisper.name
    variant_name                                      = "AllTraffic"
    initial_instance_count                            = 1
    instance_type                                     = "ml.g4dn.xlarge"
    container_startup_health_check_timeout_in_seconds = 1800
    model_data_download_timeout_in_seconds            = 1200
  }

  async_inference_config {
    output_config {
      s3_output_path  = "s3://..."
      s3_failure_path = "s3://..."
      notification_config {
        include_inference_response_in = ["SUCCESS_NOTIFICATION_TOPIC", "ERROR_NOTIFICATION_TOPIC"]
        success_topic                 = aws_sns_topic.whisper_success_topic.arn
        error_topic                   = aws_sns_topic.whisper_error_topic.arn
      }
    }
    client_config {
      max_concurrent_invocations_per_instance = 1
    }
  }
}

resource "aws_sagemaker_endpoint" "whisper" {
  name                 = "whisper"
  endpoint_config_name = aws_sagemaker_endpoint_configuration.whisper.name

  deployment_config {
    rolling_update_policy {
      maximum_batch_size {
        type  = "CAPACITY_PERCENT"
        value = 50
      }
      maximum_execution_timeout_in_seconds = 900
      wait_interval_in_seconds             = 180
    }
  }

  depends_on = [aws_sagemaker_model.whisper, aws_sagemaker_endpoint_configuration.whisper]
}

resource "aws_appautoscaling_target" "sagemaker_target" {
  max_capacity       = 30
  min_capacity       = 0
  resource_id        = "endpoint/${aws_sagemaker_endpoint.whisper.name}/variant/AllTraffic"
  role_arn           = aws_iam_role.whisper.arn
  scalable_dimension = "sagemaker:variant:DesiredInstanceCount"
  service_namespace  = "sagemaker"
  depends_on         = [aws_sagemaker_endpoint.whisper, aws_sagemaker_endpoint_configuration.whisper]
}

resource "aws_appautoscaling_policy" "sagemaker_policy_regular" {
  name               = "whisper-invocations-scaling-policy"
  policy_type        = "TargetTrackingScaling"
  resource_id        = aws_appautoscaling_target.sagemaker_target.resource_id
  scalable_dimension = aws_appautoscaling_target.sagemaker_target.scalable_dimension
  service_namespace  = aws_appautoscaling_target.sagemaker_target.service_namespace

  target_tracking_scaling_policy_configuration {
    customized_metric_specification {
      metric_name = "ApproximateBacklogSizePerInstance"
      namespace   = "AWS/SageMaker"
      dimensions {
        name  = "EndpointName"
        value = aws_sagemaker_endpoint.whisper.name
      }
      statistic = "Average"
    }
    target_value       = 10
    scale_in_cooldown  = 30
    scale_out_cooldown = 120
  }
}

// Scales from 0 to 1 without waiting for queue to fill up
resource "aws_appautoscaling_policy" "sagemaker_policy_zero_to_one" {
  name               = "whisper-backlog-without-capacity-scaling-policy"
  policy_type        = "StepScaling"
  resource_id        = aws_appautoscaling_target.sagemaker_target.resource_id
  scalable_dimension = aws_appautoscaling_target.sagemaker_target.scalable_dimension
  service_namespace  = aws_appautoscaling_target.sagemaker_target.service_namespace

  step_scaling_policy_configuration {
    adjustment_type         = "ChangeInCapacity"
    cooldown                = 300
    metric_aggregation_type = "Average"
    step_adjustment {
      metric_interval_lower_bound = 0
      scaling_adjustment          = 1
    }
  }
}

resource "aws_cloudwatch_metric_alarm" "sagemaker_policy_zero_to_one" {
  alarm_name          = "whisper-backlog-without-capacity-scaling-policy"
  metric_name         = "HasBacklogWithoutCapacity"
  namespace           = "AWS/SageMaker"
  statistic           = "Average"
  evaluation_periods  = 2
  datapoints_to_alarm = 2
  comparison_operator = "GreaterThanOrEqualToThreshold"
  threshold           = 1
  treat_missing_data  = "missing"
  dimensions = {
    EndpointName = aws_sagemaker_endpoint.whisper.name
  }
  period            = 60
  alarm_description = "This metric is used to trigger the scaling policy that scales from 0 to 1 without waiting for queue to fill up"
  alarm_actions     = [aws_appautoscaling_policy.sagemaker_policy_zero_to_one.arn]
}
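As a rough sanity check on the autoscaling settings above: `ApproximateBacklogSizePerInstance` is the queue backlog divided by the number of instances, so tracking a target of 10 roughly means "one instance per 10 queued requests", clamped to `min_capacity`/`max_capacity`. At zero instances there is nothing per-instance to track, which is why the separate `HasBacklogWithoutCapacity` alarm (2 of 2 one-minute datapoints, so at least about two minutes plus instance startup) adds the first instance. This is a back-of-the-envelope approximation, not Application Auto Scaling's actual algorithm, which also applies cooldowns and its own rounding; `approx_desired_instances` is a hypothetical helper.

```python
import math


def approx_desired_instances(
    backlog: int,
    current_instances: int,
    target_per_instance: float = 10,  # target_value in sagemaker_policy_regular
    min_capacity: int = 0,
    max_capacity: int = 30,
) -> int:
    """Rough estimate of the instance count the two scaling policies aim for."""
    if current_instances == 0:
        # No per-instance metric exists at zero capacity; the step policy
        # triggered by HasBacklogWithoutCapacity adds exactly one instance.
        desired = 1 if backlog > 0 else 0
    else:
        # backlog / instances ~= target  =>  instances ~= backlog / target
        desired = math.ceil(backlog / target_per_instance)
    return max(min_capacity, min(desired, max_capacity))
```

For example, a backlog of 95 requests points at about 10 instances, while a backlog of 1000 is capped at the configured maximum of 30.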

The problem

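This doesn't work because Terraform doesn't fully model the update chain SageMaker requires: models and endpoint configurations are immutable, so a new image means a new model, then a new endpoint configuration that references it, and finally an UpdateEndpoint call that switches the running endpoint over to that configuration. In practice, updating the model ends in one of two scenarios: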

  • The model and endpoint configuration are updated, but the changes aren't applied to the endpoint, so you have to destroy the endpoint to roll them out (not good for production...)
  • Applying the endpoint fails because it still refers to the previous endpoint configuration, which Terraform has already destroyed, and you end up deleting the endpoint again...
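The ordering SageMaker expects can be sketched with boto3. This is a minimal sketch, assuming the new model already exists (Terraform creates it) and leaving out the async inference settings; the function name `roll_out_new_config` and the timestamp naming are illustrative. The key point is that the endpoint keeps serving its old configuration until `update_endpoint` points it at a freshly named one.

```python
from datetime import datetime, timezone


def roll_out_new_config(client, endpoint_name: str, model_name: str, variant: dict) -> str:
    """Create a uniquely named endpoint config for `model_name` and switch the endpoint to it.

    `client` is a boto3 SageMaker client (or anything with the same methods);
    `variant` holds the remaining ProductionVariants fields (instance type, counts, ...).
    """
    # Endpoint configs are immutable, so every rollout needs a new name.
    config_name = f"{endpoint_name}-{datetime.now(timezone.utc):%Y%m%d%H%M%S}"
    client.create_endpoint_config(
        EndpointConfigName=config_name,
        ProductionVariants=[{"ModelName": model_name, "VariantName": "AllTraffic", **variant}],
    )
    # Until this call, the endpoint keeps running on its previous config.
    client.update_endpoint(EndpointName=endpoint_name, EndpointConfigName=config_name)
    return config_name
```

With a real client this would be called as `roll_out_new_config(boto3.client("sagemaker"), "whisper", new_model_name, {"InstanceType": "ml.g4dn.xlarge", "InitialInstanceCount": 1})`, where `new_model_name` is whatever name Terraform gave the new model.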

The dirty solution

So if you just want to update the model's code like any other container, push it to ECR and have it roll out to your endpoint, this is the only solution I've come up with so far.
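The rollout is driven by the image reference stored in the SSM parameter behind `data.aws_ssm_parameter.whisper_digest_id` changing on every push; the post doesn't show that CI step. A minimal sketch, assuming the parameter holds a digest-pinned ECR URI (`<repository>@sha256:...`) written after `docker push`, could look like this. The parameter name and repository URI in the usage line are hypothetical, and `publish_image_digest` is an illustrative helper around the real SSM `put_parameter` call.

```python
def publish_image_digest(ssm_client, parameter_name: str, repository_uri: str, digest: str) -> str:
    """Store `<repository>@<digest>` in SSM so the next Terraform run sees a new image.

    `ssm_client` is a boto3 SSM client (or anything with a compatible put_parameter).
    """
    # Pin by digest: a tag like "latest" would not change the value on a new push.
    if not digest.startswith("sha256:"):
        raise ValueError(f"expected a sha256 digest, got {digest!r}")
    image_ref = f"{repository_uri}@{digest}"
    ssm_client.put_parameter(Name=parameter_name, Value=image_ref, Type="String", Overwrite=True)
    return image_ref
```

Usage would be along the lines of `publish_image_digest(boto3.client("ssm"), "/whisper/image", "<account>.dkr.ecr.eu-north-1.amazonaws.com/whisper", digest)`. Because the value changes on every push, Terraform should plan a new model, and the deploy script then sees an image that differs from the one deployed.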

First, we change the Terraform code to:

// No fixed `name`: Terraform generates a unique one, so a new image yields a new model
resource "aws_sagemaker_model" "whisper" {
  execution_role_arn = aws_iam_role.whisper.arn

  primary_container {
    image          = data.aws_ssm_parameter.whisper_digest_id.value
    model_data_url = "s3://..."
    environment = {
      STAGE              = terraform.workspace
      CONTAINER_ID       = data.aws_ssm_parameter.whisper_digest_id.value
      SENTRY_DSN         = data.aws_ssm_parameter.sentry_dsn.value
      SENTRY_ENVIRONMENT = terraform.workspace
      INSTANCE_TYPE      = "ml.g4dn.xlarge"
    }
  }
}

// Alarm used by the endpoint's deployment config for automatic rollback
resource "aws_cloudwatch_metric_alarm" "sagemaker_endpoint_error_rate" {
  alarm_name          = "EndToEndDeploymentHighErrorRateAlarm"
  alarm_description   = "Monitors the error rate of 4xx errors"
  metric_name         = "Invocation4XXErrors"
  namespace           = "AWS/SageMaker"
  statistic           = "Average"
  evaluation_periods  = 2
  comparison_operator = "GreaterThanThreshold"
  threshold           = 1
  period              = 600
  treat_missing_data  = "notBreaching"
  dimensions = {
    EndpointName = "whisper"
    VariantName  = "AllTraffic"
  }
}

// Runs deploy_model.py, which creates or updates the endpoint and its config via boto3
data "external" "deploy_model" {
  program = ["python", "${path.module}/deploy_model.py"]
  query = {
    deploy_action                 = var.DEPLOY_ACTION
    aws_access_key_id             = var.ACCESS_KEY
    aws_secret_access_key         = var.SECRET_KEY
    endpoint_name                 = "whisper"
    endpoint_alarm_name           = "EndToEndDeploymentHighErrorRateAlarm"
    endpoint_config_model_name    = aws_sagemaker_model.whisper.name
    endpoint_config_model_image   = data.aws_ssm_parameter.whisper_digest_id.value
    endpoint_config_instance_type = "ml.g4dn.xlarge"
    endpoint_config_output_path   = "s3://..."
    endpoint_config_error_path    = "s3://..."
    endpoint_config_success_topic = aws_sns_topic.whisper_success_topic.arn
    endpoint_config_error_topic   = aws_sns_topic.whisper_error_topic.arn
  }
  depends_on = [aws_sagemaker_model.whisper]
}

output "model_deployment" {
  description = "Model deployment"
  value       = data.external.deploy_model.result
}

resource "aws_appautoscaling_target" "sagemaker_target" {
  max_capacity       = 30
  min_capacity       = 0
  resource_id        = "endpoint/${data.external.deploy_model.result["endpoint_name"]}/variant/AllTraffic"
  role_arn           = aws_iam_role.whisper.arn
  scalable_dimension = "sagemaker:variant:DesiredInstanceCount"
  service_namespace  = "sagemaker"
  depends_on         = [data.external.deploy_model]
}

resource "aws_appautoscaling_policy" "sagemaker_policy_regular" {
  name               = "whisper-invocations-scaling-policy"
  policy_type        = "TargetTrackingScaling"
  resource_id        = aws_appautoscaling_target.sagemaker_target.resource_id
  scalable_dimension = aws_appautoscaling_target.sagemaker_target.scalable_dimension
  service_namespace  = aws_appautoscaling_target.sagemaker_target.service_namespace

  target_tracking_scaling_policy_configuration {
    customized_metric_specification {
      metric_name = "ApproximateBacklogSizePerInstance"
      namespace   = "AWS/SageMaker"
      dimensions {
        name  = "EndpointName"
        value = data.external.deploy_model.result["endpoint_name"]
      }
      statistic = "Average"
    }
    target_value       = 10
    scale_in_cooldown  = 30
    scale_out_cooldown = 120
  }
}

// Scales from 0 to 1 without waiting for queue to fill up
resource "aws_appautoscaling_policy" "sagemaker_policy_zero_to_one" {
  name               = "whisper-backlog-without-capacity-scaling-policy"
  policy_type        = "StepScaling"
  resource_id        = aws_appautoscaling_target.sagemaker_target.resource_id
  scalable_dimension = aws_appautoscaling_target.sagemaker_target.scalable_dimension
  service_namespace  = aws_appautoscaling_target.sagemaker_target.service_namespace

  step_scaling_policy_configuration {
    adjustment_type         = "ChangeInCapacity"
    cooldown                = 300
    metric_aggregation_type = "Average"
    step_adjustment {
      metric_interval_lower_bound = 0
      scaling_adjustment          = 1
    }
  }
}

resource "aws_cloudwatch_metric_alarm" "sagemaker_policy_zero_to_one" {
  alarm_name          = "whisper-backlog-without-capacity-scaling-policy"
  metric_name         = "HasBacklogWithoutCapacity"
  namespace           = "AWS/SageMaker"
  statistic           = "Average"
  evaluation_periods  = 2
  datapoints_to_alarm = 2
  comparison_operator = "GreaterThanOrEqualToThreshold"
  threshold           = 1
  treat_missing_data  = "missing"
  dimensions = {
    EndpointName = data.external.deploy_model.result["endpoint_name"]
  }
  period            = 60
  alarm_description = "This metric is used to trigger the scaling policy that scales from 0 to 1 without waiting for queue to fill up"
  alarm_actions     = [aws_appautoscaling_policy.sagemaker_policy_zero_to_one.arn]
}

Basically, we move the create/update/delete handling of the endpoint and its endpoint configuration into a Python script that uses boto3, while Terraform still creates the model. I've hard-coded some values, as you can see below, and pass others into the script as parameters. I also pass in the Terraform action (plan, apply, etc.), since I don't want the script to touch anything unless it's an apply. Maybe not the best way, but it fits well with our GitHub Actions.
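The script plugs into Terraform through the `external` data source protocol: Terraform writes the `query` map to the program's stdin as a JSON object and expects a flat JSON object with string values only on stdout; a non-zero exit is reported as an error. Because Terraform reads data sources while planning (unless it has to defer them to apply), the action guard is what keeps a `terraform plan` from deploying. A minimal sketch of that contract, where `handle`, `respond` and `run` are hypothetical helpers and `deploy` stands in for the boto3 create/update logic:

```python
import json
from typing import Callable


def handle(query: dict, deploy: Callable[[dict], dict]) -> dict:
    # Data sources are also read during plan, so only deploy on a real apply.
    if query.get("deploy_action") != "apply":
        return {"type": "no_change", "endpoint_name": query["endpoint_name"]}
    return deploy(query)


def respond(result: dict) -> str:
    # The external data source accepts only a flat JSON object of string values;
    # reject anything else here instead of letting Terraform fail on it later.
    bad_keys = [key for key, value in result.items() if not isinstance(value, str)]
    if bad_keys:
        raise TypeError(f"non-string values for keys: {bad_keys}")
    return json.dumps(result)


def run(stdin, stdout, deploy: Callable[[dict], dict]) -> None:
    # Terraform sends the `query` map on stdin and parses stdout as `result`.
    stdout.write(respond(handle(json.load(stdin), deploy)))
```

In a script, this would be wired up as `run(sys.stdin, sys.stdout, deploy)`. Validating string-only output in Python gives a clearer error message than the one the provider produces.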

import boto3
import json
import sys
from datetime import datetime
from time import sleep


def endpoint_exists(client, endpoint_name: str) -> tuple[str | None, str | None]:
    """Return (endpoint config name, deployed image), or (None, None) if unavailable."""
    try:
        res = client.describe_endpoint(EndpointName=endpoint_name)
        return (
            res["EndpointConfigName"],
            res["ProductionVariants"][0]["DeployedImages"][0]["SpecifiedImage"],
        )
    except Exception:
        return None, None


def create_endpoint_config(
    client,
    name: str,
    model_name: str,
    instance_type: str,
    output_path: str,
    error_path: str,
    success_topic: str,
    error_topic: str,
):
    # Endpoint configs are immutable, so each deployment gets a timestamped name.
    config_name = f"{name}-{datetime.now().strftime('%Y%m%d%H%M%S')}"
    res = client.create_endpoint_config(
        EndpointConfigName=config_name,
        ProductionVariants=[
            {
                "ModelName": model_name,
                "VariantName": "AllTraffic",
                "InstanceType": instance_type,
                "InitialInstanceCount": 1,
                "ContainerStartupHealthCheckTimeoutInSeconds": 1800,
                "ModelDataDownloadTimeoutInSeconds": 1200,
            }
        ],
        AsyncInferenceConfig={
            "OutputConfig": {
                "S3OutputPath": output_path,
                "S3FailurePath": error_path,
                "NotificationConfig": {
                    "IncludeInferenceResponseIn": [
                        "SUCCESS_NOTIFICATION_TOPIC",
                        "ERROR_NOTIFICATION_TOPIC",
                    ],
                    "SuccessTopic": success_topic,
                    "ErrorTopic": error_topic,
                },
            },
            "ClientConfig": {
                "MaxConcurrentInvocationsPerInstance": 1,
            },
        },
    )
    return config_name, res["EndpointConfigArn"]


def create_endpoint(client, endpoint_name: str, endpoint_config_name: str, alarm_name: str):
    # Blue/green deployment that rolls back automatically when the alarm fires.
    client.create_endpoint(
        EndpointName=endpoint_name,
        EndpointConfigName=endpoint_config_name,
        DeploymentConfig={
            "BlueGreenUpdatePolicy": {
                "TrafficRoutingConfiguration": {
                    "Type": "ALL_AT_ONCE",
                    "WaitIntervalInSeconds": 0,
                },
            },
            "AutoRollbackConfiguration": {
                "Alarms": [
                    {"AlarmName": alarm_name},
                ]
            },
        },
    )
    waiter = client.get_waiter("endpoint_in_service")
    waiter.wait(EndpointName=endpoint_name, WaiterConfig={"Delay": 10, "MaxAttempts": 60})


def delete_endpoint_config(client, endpoint_config_name: str):
    client.delete_endpoint_config(EndpointConfigName=endpoint_config_name)


def update_endpoint(client, endpoint_name: str, endpoint_config_name: str, alarm_name: str):
    # Retry while a previous update is still in progress.
    retry_delay = 10
    for _ in range(60):
        try:
            client.update_endpoint(
                EndpointName=endpoint_name,
                EndpointConfigName=endpoint_config_name,
                # Keep the current instance count and the deployment config
                # (including the rollback alarm) from the previous deployment.
                RetainAllVariantProperties=True,
                RetainDeploymentConfig=True,
            )
            break
        except Exception as e:
            if "Cannot update in-progress endpoint" in str(e):
                sleep(retry_delay)
            else:
                raise
    else:
        # Every attempt hit an in-progress endpoint: fail instead of waiting on nothing.
        raise TimeoutError(f"Endpoint {endpoint_name} stayed in progress; update was not started")
    waiter = client.get_waiter("endpoint_in_service")
    waiter.wait(EndpointName=endpoint_name, WaiterConfig={"Delay": 10, "MaxAttempts": 120})


def main(query: dict):
    # Only touch SageMaker on an apply; the data source is also read during plan.
    if query["deploy_action"] != "apply":
        print(json.dumps({"type": "no_change", "endpoint_name": query["endpoint_name"]}))
        return

    session = boto3.Session(
        aws_access_key_id=query["aws_access_key_id"],
        aws_secret_access_key=query["aws_secret_access_key"],
    )
    ec_model_name = query["endpoint_config_model_name"]
    ec_model_image = query["endpoint_config_model_image"]
    ec_instance_type = query["endpoint_config_instance_type"]
    ec_output_path = query["endpoint_config_output_path"]
    ec_error_path = query["endpoint_config_error_path"]
    ec_success_topic = query["endpoint_config_success_topic"]
    ec_error_topic = query["endpoint_config_error_topic"]
    endpoint_name = query["endpoint_name"]
    endpoint_alarm_name = query["endpoint_alarm_name"]
    client = session.client("sagemaker", region_name="eu-north-1")

    config_name, image_name = endpoint_exists(client, endpoint_name)
    if config_name:
        if ec_model_image == image_name:
            # The endpoint already runs this image: nothing to do.
            print(json.dumps({"type": "no_change", "endpoint_name": endpoint_name}))
        else:
            # New image: create a fresh config, switch the endpoint, drop the old config.
            new_config_name, config_arn = create_endpoint_config(
                client,
                endpoint_name,
                ec_model_name,
                ec_instance_type,
                ec_output_path,
                ec_error_path,
                ec_success_topic,
                ec_error_topic,
            )
            update_endpoint(client, endpoint_name, new_config_name, endpoint_alarm_name)
            delete_endpoint_config(client, config_name)
            print(
                json.dumps(
                    {
                        "type": "update",
                        "endpoint_name": endpoint_name,
                        "endpoint_config_name": new_config_name,
                        "model_name": ec_model_name,
                        "old_endpoint_config_name": config_name,
                        "endpoint_config_arn": config_arn,
                    }
                )
            )
    else:
        # No endpoint yet: create the config and the endpoint.
        new_config_name, config_arn = create_endpoint_config(
            client,
            endpoint_name,
            ec_model_name,
            ec_instance_type,
            ec_output_path,
            ec_error_path,
            ec_success_topic,
            ec_error_topic,
        )
        create_endpoint(client, endpoint_name, new_config_name, endpoint_alarm_name)
        print(
            json.dumps(
                {
                    "type": "new",
                    "endpoint_name": endpoint_name,
                    "endpoint_config_name": new_config_name,
                    "endpoint_config_arn": config_arn,
                }
            )
        )


if __name__ == "__main__":
    # Terraform's external data source sends the `query` map as JSON on stdin.
    main(json.loads(sys.stdin.read()))
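One thing the script above doesn't check is whether an update actually stuck. With an auto-rollback alarm, an update can end with the endpoint `InService` again but still on its previous endpoint configuration, and `main` would then delete the configuration the endpoint is running on. A minimal guard, assuming `DescribeEndpoint` reports the previous `EndpointConfigName` after a rollback (worth confirming in your own account), could be called before `delete_endpoint_config`; `rollout_succeeded` is a hypothetical helper.

```python
def rollout_succeeded(client, endpoint_name: str, new_config_name: str) -> bool:
    """True only if the endpoint is InService and running the new endpoint config.

    `client` is a boto3 SageMaker client. Assumption: after an alarm-triggered
    rollback, DescribeEndpoint reports the previous EndpointConfigName.
    """
    res = client.describe_endpoint(EndpointName=endpoint_name)
    return (
        res["EndpointStatus"] == "InService"
        and res["EndpointConfigName"] == new_config_name
    )
```

In `main`, raising an exception when this returns `False` (instead of calling `delete_endpoint_config(client, config_name)`) would make the script exit non-zero, which the external data source reports as an error, and would leave the old configuration in place.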

Conclusion

It's unfortunate that this has been an open forum issue for more than four years without a fix. But at least this approach lets you handle SageMaker's logic for replacing the underlying containers safely, with a blue/green rollout and automatic rollback.
