Take advantage of advanced deployment strategies using Amazon SageMaker deployment guardrails

Deployment guardrails in Amazon SageMaker provide a new set of deployment capabilities that allow you to implement advanced deployment strategies and minimize risk when deploying new model versions on SageMaker hosting. Depending on your use case, you can use a variety of deployment strategies to release new model versions. Each of these strategies relies on a mechanism to shift inference traffic to one or more versions of a deployed model. The chosen strategy depends on the business requirements for your machine learning (ML) use case. However, any strategy should include the ability to monitor the performance of new model versions and automatically roll back to a previous version as needed, to minimize the potential risk of introducing a new model version with errors. Deployment guardrails offer new advanced deployment capabilities and, as of this writing, support two new traffic shifting policies, canary and linear, as well as the ability to automatically roll back when issues are detected.

As part of your MLOps strategy to create repeatable and reliable mechanisms to deploy your models, you should also ensure that the chosen deployment strategy is implemented as part of your automated deployment pipeline. Deployment guardrails use the existing SageMaker CreateEndpoint and UpdateEndpoint APIs, so you can modify your existing deployment pipeline configurations to take advantage of the new deployment capabilities.

In this post, we show you how to use the new deployment guardrail capabilities to deploy your model versions using both a canary and linear deployment strategy.

Solution overview

Amazon SageMaker inference provides managed deployment strategies for testing new versions of your models in production. We cover two new traffic shifting policies in this post: canary and linear. For each of these traffic shifting modes, two HTTPS endpoints are provisioned to reduce deployment risk as traffic is shifted from the original endpoint variant to the new endpoint variant. You configure the endpoints to contain one or more compute instances to deploy your trained model and perform inference requests. SageMaker manages the routing of traffic between the two endpoints. You define Amazon CloudWatch metrics and alarms to monitor the new endpoint for a set baking period while traffic is shifted. If a CloudWatch alarm is triggered, SageMaker performs an auto-rollback and routes all traffic to the original endpoint variant. If no CloudWatch alarms are triggered, the original endpoint variant is stopped and the new endpoint variant continues to receive all traffic. The following diagrams illustrate shifting traffic to the new endpoint.

Let’s dive deeper into examples of the canary and linear traffic shifting policies.

We go over the following high-level steps as part of the deployment procedure:

  1. Create the model and endpoint configurations required for the three scenarios: the baseline, the update containing the incompatible model version, and the update with the correct model version.
  2. Invoke the baseline endpoint prior to the update.
  3. Specify the CloudWatch alarms used to trigger the rollbacks.
  4. Update the endpoint to trigger a rollback using either the canary or linear strategy.

First, let’s start with canary deployment.

Canary deployment

The canary deployment option lets you shift one small portion of your traffic (a canary) to the green fleet and monitor it for a baking period. If the canary succeeds on the green fleet, the rest of the traffic is shifted from the blue fleet to the green fleet before stopping the blue fleet.

To demonstrate canary deployments and the auto-rollback feature, we update an endpoint with an incompatible model version and deploy it as a canary fleet, taking a small percentage of the traffic. Requests sent to this canary fleet result in errors, which trigger a rollback using preconfigured CloudWatch alarms. We also demonstrate a success scenario where no alarms are tripped and the update succeeds.

Create and deploy the models

First, we upload our pre-trained models to Amazon Simple Storage Service (Amazon S3). These models were trained using the XGBoost churn prediction notebook in SageMaker. You can also use your own pre-trained models in this step. If you already have a pre-trained model in Amazon S3, you can add it by specifying the s3_key.

The models in this example are used to predict the probability of a mobile customer leaving their current mobile operator. The dataset we use is publicly available and was mentioned in the book Discovering Knowledge in Data by Daniel T. Larose.

Upload the models with the following code:

model_url = S3Uploader.upload(local_path="model/xgb-churn-prediction-model.tar.gz",
desired_s3_uri=f"s3://{bucket}/{prefix}")
model_url2 = S3Uploader.upload(local_path="model/xgb-churn-prediction-model2.tar.gz",
desired_s3_uri=f"s3://{bucket}/{prefix}")

Next, we create our model definitions. We start with deploying the pre-trained churn prediction models. Here, we create the model objects with the image and model data. The three URIs correspond to the baseline version, the update containing the incompatible version, and the update containing the correct model version:

image_uri = image_uris.retrieve('xgboost', boto3.Session().region_name, '0.90-1')
# image_uri2 uses a newer version of XGBoost that is incompatible with the model artifact, in order to simulate model faults
image_uri2 = image_uris.retrieve('xgboost', boto3.Session().region_name, '1.2-1')
image_uri3 = image_uris.retrieve('xgboost', boto3.Session().region_name, '0.90-2')
model_name = f"DEMO-xgb-churn-pred-{datetime.now():%Y-%m-%d-%H-%M-%S}" 
model_name2 = f"DEMO-xgb-churn-pred2-{datetime.now():%Y-%m-%d-%H-%M-%S}"
model_name3 = f"DEMO-xgb-churn-pred3-{datetime.now():%Y-%m-%d-%H-%M-%S}"

resp = sm.create_model(
    ModelName=model_name,
    ExecutionRoleArn=role,
    Containers=[{
       'Image': image_uri,
       'ModelDataUrl': model_url
     }])

resp = sm.create_model(
    ModelName=model_name2,
    ExecutionRoleArn=role,
    Containers=[{
       'Image':image_uri2,
       'ModelDataUrl': model_url2
     }])

resp = sm.create_model(
    ModelName=model_name3,
    ExecutionRoleArn=role,
    Containers=[{
       'Image':image_uri3,
       'ModelDataUrl': model_url2
     }])

Now that the three models are created, we create the three endpoint configs:

ep_config_name = f"DEMO-EpConfig-1-{datetime.now():%Y-%m-%d-%H-%M-%S}" 
ep_config_name2 = f"DEMO-EpConfig-2-{datetime.now():%Y-%m-%d-%H-%M-%S}" 
ep_config_name3 = f"DEMO-EpConfig-3-{datetime.now():%Y-%m-%d-%H-%M-%S}"

resp = sm.create_endpoint_config(
     EndpointConfigName=ep_config_name,
     ProductionVariants=[
        {
          'VariantName': "AllTraffic",
          'ModelName': model_name,
          'InstanceType': "ml.m5.xlarge",
          "InitialInstanceCount": 3
        }
      ])

resp = sm.create_endpoint_config(
     EndpointConfigName=ep_config_name2,
     ProductionVariants=[
        {
          'VariantName': "AllTraffic",
          'ModelName': model_name2,
          'InstanceType': "ml.m5.xlarge",
          "InitialInstanceCount": 3
        }
      ])


resp = sm.create_endpoint_config(
      EndpointConfigName=ep_config_name3,
      ProductionVariants=[
         {
           'VariantName': "AllTraffic",
           'ModelName': model_name3,
           'InstanceType': "ml.m5.xlarge",
           "InitialInstanceCount": 3
         }
     ])

We then deploy the baseline model to a SageMaker endpoint:

resp = sm.create_endpoint(
          EndpointName=endpoint_name,
          EndpointConfigName=ep_config_name
)
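
Endpoint creation takes several minutes. Before sending traffic, it helps to wait until the endpoint reaches the InService status. The following is a minimal sketch using the boto3 SageMaker waiter; the sample notebook may use its own wait helper, and sm and endpoint_name are the client and name defined earlier.

# Wait until the endpoint is InService before invoking it
waiter = sm.get_waiter("endpoint_in_service")
waiter.wait(EndpointName=endpoint_name)

resp = sm.describe_endpoint(EndpointName=endpoint_name)
print("Endpoint status:", resp["EndpointStatus"])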

Invoke the endpoint

This step invokes the endpoint with sample data, up to a maximum invocation count and with a wait interval between requests. See the following code:

def invoke_endpoint(endpoint_name, max_invocations=300, wait_interval_sec=1, should_raise_exp=False):
    print(f"Sending test traffic to the endpoint {endpoint_name}. nPlease wait...")
 
    count = 0
    with open('test_data/test-dataset-input-cols.csv', 'r') as f:
        for row in f:
            payload = row.rstrip('\n')
            try:
                response = sm_runtime.invoke_endpoint(EndpointName=endpoint_name,
                                                      ContentType='text/csv', 
                                                      Body=payload)
                response['Body'].read()
                print(".", end="", flush=True)
            except Exception as e:
                print("E", end="", flush=True)
                if should_raise_exp:
                    raise e
            count += 1
            if count > max_invocations:
                break
            time.sleep(wait_interval_sec)
 
    print("nDone!")
 
invoke_endpoint(endpoint_name, max_invocations=100)

For a full list of metrics, see Monitor Amazon SageMaker with Amazon CloudWatch.

Then we plot graphs to show the metrics Invocations, Invocation4XXErrors, Invocation5XXErrors, ModelLatency, and OverheadLatency against the endpoint over time.

You can observe a flat line for Invocation4XXErrors and Invocation5XXErrors because we’re using the correct model version and configs. Additionally, ModelLatency and OverheadLatency start decreasing over time.
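
The notebook uses helper functions to produce these plots. As a rough illustration (not the notebook’s exact helper), the following sketch pulls one endpoint metric from CloudWatch with boto3 and plots it with pandas; the function name, statistic choices, and time window are assumptions you can adjust.

import boto3
import pandas as pd
from datetime import datetime, timedelta

cw_client = boto3.client("cloudwatch")

def plot_endpoint_metric(endpoint_name, metric_name, statistic="Sum",
                         variant_name="AllTraffic", minutes=60):
    """Fetch a SageMaker endpoint metric from CloudWatch and plot it over time."""
    now = datetime.utcnow()
    stats = cw_client.get_metric_statistics(
        Namespace="AWS/SageMaker",
        MetricName=metric_name,
        Dimensions=[
            {"Name": "EndpointName", "Value": endpoint_name},
            {"Name": "VariantName", "Value": variant_name},
        ],
        StartTime=now - timedelta(minutes=minutes),
        EndTime=now,
        Period=60,
        Statistics=[statistic],
    )
    df = pd.DataFrame(stats["Datapoints"])
    if df.empty:
        print(f"No datapoints yet for {metric_name}")
        return
    df.sort_values("Timestamp").plot(x="Timestamp", y=statistic, title=metric_name, legend=False)

# Invocation and error counts are summed per period; latency metrics use Average
plot_endpoint_metric(endpoint_name, "Invocation5XXErrors", statistic="Sum")
plot_endpoint_metric(endpoint_name, "ModelLatency", statistic="Average")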

Create CloudWatch alarms to monitor endpoint performance

We create CloudWatch alarms to monitor endpoint performance with the metrics Invocation5XXErrors and ModelLatency.

We use metric dimensions EndpointName and VariantName to select the metric for each endpoint config and variant. See the following code:

def create_auto_rollback_alarm(alarm_name, endpoint_name, variant_name, metric_name, statistic, threshold):
    cw.put_metric_alarm(
        AlarmName=alarm_name,
        AlarmDescription='Test SageMaker endpoint deployment auto-rollback alarm',
        ActionsEnabled=False,
        Namespace='AWS/SageMaker',
        MetricName=metric_name,
        Statistic=statistic,
        Dimensions=[
            {
                'Name': 'EndpointName',
                'Value': endpoint_name
            },
            {
                'Name': 'VariantName',
                'Value': variant_name
            }
        ],
        Period=60,
        EvaluationPeriods=1,
        Threshold=threshold,
        ComparisonOperator='GreaterThanOrEqualToThreshold',
        TreatMissingData='notBreaching'
    )

# alarm on 1% 5xx error rate for 1 minute
create_auto_rollback_alarm(error_alarm, endpoint_name, 'AllTraffic', 'Invocation5XXErrors', 'Average', 1)
# alarm on model latency >= 10 ms for 1 minute
create_auto_rollback_alarm(latency_alarm, endpoint_name, 'AllTraffic', 'ModelLatency', 'Average', 10000)

Update the endpoint with deployment configurations

We define the following deployment configuration to perform a blue/green update strategy with canary traffic shifting from the old to the new stack. The canary traffic shifting option can reduce the blast radius of a regressive update to the endpoint. In contrast, with the all-at-once traffic shifting option, invocation requests start faulting at 100% immediately after the traffic is flipped. In canary mode, invocation requests are shifted to the new model version gradually, preventing errors from impacting 100% of the traffic. Additionally, the auto-rollback alarms monitor the metrics during the canary stage.

The following diagram illustrates the workflow of our rollback use case.

We update the endpoint with an incompatible model version to simulate errors and trigger a rollback:

canary_deployment_config = {
    "BlueGreenUpdatePolicy": {
        "TrafficRoutingConfiguration": {
            "Type": "CANARY",
            "CanarySize": {
                "Type": "INSTANCE_COUNT", # or use "CAPACITY_PERCENT" as 30%, 50%
                "Value": 1
            },
            "WaitIntervalInSeconds": 300, # wait for 5 minutes before enabling traffic on the rest of fleet
        },
        "TerminationWaitInSeconds": 120, # wait for 2 minutes before terminating the old stack
        "MaximumExecutionTimeoutInSeconds": 1800 # maximum timeout for deployment
    },
    "AutoRollbackConfiguration": {
        "Alarms": [
            {
                "AlarmName": error_alarm
            },
            {
                "AlarmName": latency_alarm
            }
        ],
    }
}
 
# update endpoint request with new DeploymentConfig parameter
sm.update_endpoint(
    EndpointName=endpoint_name,
    EndpointConfigName=ep_config_name2,
    DeploymentConfig=canary_deployment_config
)

When we invoke the endpoint, we encounter errors because of the incompatible version of the model (ep_config_name2), and this leads to a rollback to the stable version of the model (ep_config_name). This is reflected in the following graphs as Invocation5XXErrors and ModelLatency increase during this rollback phase.
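
To confirm which alarm triggered the rollback and which endpoint configuration is active afterwards, you can inspect the alarm states and the endpoint description. The following is a minimal sketch, assuming cw and sm are the CloudWatch and SageMaker boto3 clients used earlier and error_alarm and latency_alarm are the alarm name variables from the notebook.

# Check the state of the auto-rollback alarms after the failed update
resp = cw.describe_alarms(AlarmNames=[error_alarm, latency_alarm])
for alarm in resp["MetricAlarms"]:
    print(alarm["AlarmName"], "->", alarm["StateValue"])

# After the rollback, the endpoint should be back on the original (stable) endpoint config
endpoint = sm.describe_endpoint(EndpointName=endpoint_name)
print("Endpoint status:", endpoint["EndpointStatus"])
print("Active endpoint config:", endpoint["EndpointConfigName"])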

The following diagram shows a success case where we use the same canary deployment configuration but a valid endpoint configuration.

We update the endpoint configuration to a valid version (using the same canary deployment config as the rollback case):

# update endpoint with a valid version of DeploymentConfig
sm.update_endpoint(
    EndpointName=endpoint_name,
    EndpointConfigName=ep_config_name3,
    RetainDeploymentConfig=True
)

We plot graphs to show the Invocations, Invocation5XXErrors, and ModelLatency metrics against the endpoint. When the new endpoint config-3 (correct model version) starts getting deployed, it takes over endpoint config-2 (incompatible due to model version) without any errors. We can see this in the graphs as Invocation5XXErrors and ModelLatency decrease during this transition phase.

Next, let’s see how linear deployments are configured and how it works.

Linear deployment

The linear deployment option provides even more customization over how many traffic-shifting steps to make and what percentage of traffic to shift for each step. Whereas canary shifting lets you shift traffic in two steps, linear shifting extends this to n linearly spaced steps.

To demonstrate linear deployments and the auto-rollback feature, we update an endpoint with an incompatible model version and deploy it as a linear fleet, taking a small percentage of the traffic. Requests sent to this linear fleet result in errors, which trigger a rollback using preconfigured CloudWatch alarms. We also demonstrate a success scenario where no alarms are tripped and the update succeeds.

The steps to create the models, invoke the endpoint, and create the CloudWatch alarms are the same as with the canary method.

We define the following deployment configuration to perform a blue/green update strategy with linear traffic shifting from the old to the new stack. The linear traffic shifting option can reduce the blast radius of a regressive update to the endpoint. In contrast, with the all-at-once traffic shifting option, invocation requests start faulting at 100% immediately after the traffic is flipped. In linear mode, invocation requests are shifted to the new model version gradually, with a controlled percentage of traffic shifting for each step. You can use the auto-rollback alarms to monitor the metrics during the linear traffic shifting stage.

The following diagram shows the workflow for our linear rollback case.

We update the endpoint with an incompatible model version to simulate errors and trigger a rollback:

linear_deployment_config = {
    "BlueGreenUpdatePolicy": {
        "TrafficRoutingConfiguration": {
            "Type": "LINEAR",
            "LinearStepSize": {
                "Type": "CAPACITY_PERCENT",
                "Value": 33, # 33% of whole fleet capacity (33% * 3 = 1 instance)
            },
            "WaitIntervalInSeconds": 180, # wait for 3 minutes before enabling traffic on the rest of fleet
        },
        "TerminationWaitInSeconds": 120, # wait for 2 minutes before terminating the old stack
        "MaximumExecutionTimeoutInSeconds": 1800 # maximum timeout for deployment
    },
    "AutoRollbackConfiguration": {
        "Alarms": [
            {
                "AlarmName": error_alarm
            },
            {
                "AlarmName": latency_alarm
            }
        ],
    }
}
 
# update endpoint request with new DeploymentConfig parameter
sm.update_endpoint(
    EndpointName=endpoint_name,
    EndpointConfigName=ep_config_name2,
    DeploymentConfig=linear_deployment_config
)

When we invoke the endpoint, we encounter errors because of the incompatible version of the model (ep_config_name2), which leads to the rollback to the stable version of the model (ep_config_name). We can see this in the following graphs as the Invocation5XXErrors and ModelLatency metrics increase during this rollback phase.

Let’s look at a success case where we use the same linear deployment configuration but a valid endpoint configuration. The following diagram illustrates our workflow.

We update the endpoint to a valid endpoint configuration version with the same linear deployment configuration:

# update endpoint with a valid version of DeploymentConfig
sm.update_endpoint(
    EndpointName=endpoint_name,
    EndpointConfigName=ep_config_name3,
    RetainDeploymentConfig=True
)

Then we plot graphs to show the Invocations, Invocation5XXErrors, and ModelLatency metrics against the endpoint.

As the new endpoint config-3 (correct model version) starts getting deployed, it takes over endpoint config-2 (incompatible due to model version) without any errors. We can see this in the following graphs as Invocation5XXErrors and ModelLatency decrease during this transition phase.

Considerations and best practices

Now that we’ve walked through a comprehensive example, let’s recap some best practices and considerations:

  • Pick the right health check – The CloudWatch alarms determine whether the traffic shift to the new endpoint variant succeeds. In our example, we used Invocation5XXErrors (caused by the endpoint failing to return a valid result) and ModelLatency, which measures how long the model takes to return a response. You can consider other built-in metrics in some cases, like OverheadLatency, which accounts for other causes of latency, such as unusually large response payloads. You can also have your inference code record custom metrics (see the sketch after this list), and you can configure the alarm measurement evaluation interval. For more information about available metrics, see SageMaker Endpoint Invocation Metrics.
  • Pick the most suitable traffic shifting policy – The all-at-once policy is a good choice if you just want to make sure that the new endpoint variant is healthy and able to serve traffic. The canary policy is useful if you want to avoid affecting too much traffic if the new endpoint variant has a problem, or if you want to evaluate a custom metric on a small percentage of traffic before shifting over. For example, perhaps you want to emit a custom metric that checks for inference response distribution, and make sure it falls within expected ranges. The linear policy is a more conservative and more complex take on the canary pattern.
  • Monitor the alarms – The alarms you use to trigger rollback should also cause other actions, like notifying an operations team.
  • Use the same deployment strategy in multiple environments – As part of an overall MLOps pipeline, use the same deployment strategy in test as well as production environments, so that you become comfortable with the behavior. This consideration implies that you can inject realistic load onto your test endpoints.
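
For example, the following is a minimal, hypothetical sketch of recording a custom metric from inference code with the CloudWatch PutMetricData API. The namespace, metric name, and dimensions are illustrative choices, not part of the sample notebook.

import boto3

cloudwatch = boto3.client("cloudwatch")

def record_prediction_score(endpoint_name, variant_name, score):
    """Emit a custom metric (for example, a model confidence score) for each invocation."""
    cloudwatch.put_metric_data(
        Namespace="Custom/ModelQuality",       # illustrative namespace
        MetricData=[{
            "MetricName": "PredictionScore",   # illustrative metric name
            "Dimensions": [
                {"Name": "EndpointName", "Value": endpoint_name},
                {"Name": "VariantName", "Value": variant_name},
            ],
            "Value": score,
            "Unit": "None",
        }],
    )

# A CloudWatch alarm on this custom metric (created the same way as the earlier
# create_auto_rollback_alarm examples) can then be referenced in the
# AutoRollbackConfiguration of the deployment config.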

Conclusion

In this post, we introduced SageMaker inference’s new deployment guardrail options, which let you manage deployment of a new model version in a safe and controlled way. We reviewed the new traffic shifting policies, canary and linear, and showed how to use them in a realistic example. Finally, we discussed some best practices and considerations. Get started today with deployment guardrails on the SageMaker console or, for more information, review Deployment Guardrails.


About the Authors

Raghu Ramesha is an ML Solutions Architect with the Amazon SageMaker Services SA team. He focuses on helping customers migrate ML production workloads to SageMaker at scale. He specializes in machine learning, AI, and computer vision domains, and holds a master’s degree in Computer Science from UT Dallas. In his free time, he enjoys traveling and photography.

Shelbee Eigenbrode is a Principal AI and Machine Learning Specialist Solutions Architect at Amazon Web Services (AWS). She has been in technology for 24 years spanning multiple industries, technologies, and roles. She is currently focusing on combining her DevOps and ML background into the domain of MLOps to help customers deliver and manage ML workloads at scale. With over 35 patents granted across various technology domains, she has a passion for continuous innovation and using data to drive business outcomes. Shelbee is a co-creator and instructor of the Practical Data Science specialization on Coursera. She is also the Co-Director of Women In Big Data (WiBD), Denver chapter. In her spare time, she likes to spend time with her family, friends, and overactive dogs.

Randy DeFauw is a Principal Solutions Architect. He’s an electrical engineer by training who’s been working in technology for 23 years at companies ranging from startups to large defense firms. A fascination with distributed consensus systems led him into the big data space, where he discovered a passion for analytics and machine learning. He started using AWS in his Hadoop days, where he saw how easy it was to set up large complex infrastructure, and then realized that the cloud solved some of the challenges he saw with Hadoop. Randy picked up an MBA so he could learn how business leaders think and talk, and found that the soft skill classes were some of the most interesting ones he took. Lately, he’s been dabbling with reinforcement learning as a way to tackle optimization problems, and re-reading Martin Kleppmann’s book on data intensive design.

Lauren Mullennex is a Solutions Architect based in Denver, CO. She works with customers to help them architect solutions on AWS. In her spare time, she enjoys hiking and cooking Hawaiian cuisine.

Read More

Train graph neural nets for millions of proteins on Amazon SageMaker and Amazon DocumentDB (with MongoDB compatibility)

There are over 180,000 unique proteins with 3D structures determined, with tens of thousands of new structures resolved every year. This is only a small fraction of the 200 million known proteins with distinctive sequences. Recent deep learning algorithms such as AlphaFold can accurately predict 3D structures of proteins from their sequences, which helps scale protein 3D structure data to the millions. Graph neural networks (GNNs) have emerged as an effective deep learning approach to extract information from protein structures, which can be represented as graphs of amino acid residues. Individual protein graphs usually contain a few hundred nodes, which is manageable in size. Tens of thousands of protein graphs can easily be stored in serialized data structures such as TFRecord for training GNNs. However, training GNNs on millions of protein structures is challenging. Data serialization isn’t scalable to millions of protein structures because it requires loading the entire terabyte-scale dataset into memory.

In this post, we introduce a scalable deep learning solution that allows you to train GNNs on millions of proteins stored in Amazon DocumentDB (with MongoDB compatibility) using Amazon SageMaker.

For illustrative purposes, we use publicly available experimentally determined protein structures from the Protein Data Bank and computationally predicted protein structures from the AlphaFold Protein Structure Database. The machine learning (ML) problem is to develop a discriminator GNN model to distinguish between experimental and predicted structures based on protein graphs constructed from their 3D structures.

Overview of solution

We first parse the protein structures into JSON records with multiple types of data structures, such as an n-dimensional array and nested object, to store the proteins’ atomic coordinates, properties, and identifiers. Storing a JSON record for a protein’s structure takes 45 KB on average; we project storing 100 million proteins would take around 4.2 TB. Amazon DocumentDB storage automatically scales with the data in your cluster volume in 10 GB increments, up to 64 TB. Therefore, the support for JSON data structure and scalability makes Amazon DocumentDB a natural choice.

We next build a GNN model to predict protein properties using graphs of amino acid residues constructed from their structures. The GNN model is trained using SageMaker and configured to efficiently retrieve batches of protein structures from the database.

Finally, we analyze the trained GNN model to gain some insights into the predictions.

We walk through the following steps for this tutorial:

  1. Create resources using an AWS CloudFormation template.
  2. Prepare protein structures and properties and ingest the data into Amazon DocumentDB.
  3. Train a GNN on the protein structures using SageMaker.
  4. Load and evaluate the trained GNN model.

The code and notebooks used in this post are available in the GitHub repo.

Prerequisites

For this walkthrough, you should have the following prerequisites:

Running this tutorial for an hour should cost no more than $2.00.

Create resources

We provide a CloudFormation template to create the required AWS resources for this post, with a similar architecture as in the post Analyzing data stored in Amazon DocumentDB (with MongoDB compatibility) using Amazon SageMaker. For instructions on creating a CloudFormation stack, see the video Simplify your Infrastructure Management using AWS CloudFormation.

The CloudFormation stack provisions the following:

  • A VPC with three private subnets for Amazon DocumentDB and two public subnets intended for the SageMaker notebook instance and ML training containers, respectively.
  • An Amazon DocumentDB cluster with three nodes, one in each private subnet.
  • A Secrets Manager secret to store login credentials for Amazon DocumentDB. This allows us to avoid storing plaintext credentials in our SageMaker instance.
  • A SageMaker notebook instance to prepare data, orchestrate training jobs, and run interactive analyses.

When creating the CloudFormation stack, you need to specify the following:

  • Name for your CloudFormation stack
  • Amazon DocumentDB user name and password (to be stored in Secrets Manager)
  • Amazon DocumentDB instance type (default db.r5.large)
  • SageMaker instance type (default ml.t3.xlarge)

It should take about 15 minutes to create the CloudFormation stack. The following diagram shows the resource architecture.

Prepare protein structures and properties and ingest the data into Amazon DocumentDB

All the subsequent code in this section is in the Jupyter notebook Prepare_data.ipynb in the SageMaker instance created in your CloudFormation stack.

This notebook handles the procedures required for preparing and ingesting protein structure data into Amazon DocumentDB.

  1. We first download predicted protein structures from AlphaFold DB in PDB format and the matching experimental structures from the Protein Data Bank.

For demonstration purposes, we only use proteins from the thermophilic archaeon Methanocaldococcus jannaschii, which has the smallest proteome (1,773 proteins) for us to work with. You are welcome to try using proteins from other species.

  2. We connect to an Amazon DocumentDB cluster by retrieving the credentials stored in Secrets Manager:
def get_secret(stack_name):

    # Create a Secrets Manager client
    session = boto3.session.Session()
    client = session.client(
        service_name="secretsmanager",
        region_name=session.region_name
    )

    secret_name = f"{stack_name}-DocDBSecret"
    get_secret_value_response = client.get_secret_value(SecretId=secret_name)
    secret = get_secret_value_response["SecretString"]

    return json.loads(secret)

secrets = get_secret("gnn-proteins")

# connect to DocDB
uri = (
    "mongodb://{}:{}@{}:{}/?tls=true&tlsCAFile=rds-combined-ca-bundle.pem"
    "&replicaSet=rs0&readPreference=secondaryPreferred&retryWrites=false"
).format(secrets["username"], secrets["password"], secrets["host"], secrets["port"])

client = MongoClient(uri)

db = client["proteins"] # create a database
collection = db["proteins"] # create a collection
  3. After we set up the connection to Amazon DocumentDB, we parse the PDB files into JSON records to ingest into the database.

We provide utility functions required for parsing PDB files in pdb_parse.py. The parse_pdb_file_to_json_record function does the heavy lifting of extracting atomic coordinates from one or multiple peptide chains in a PDB file and returns one or a list of JSON documents, which can be directly ingested into the Amazon DocumentDB collection as a document. See the following code:

recs = parse_pdb_file_to_json_record(pdb_parser, pdb_file, pdb_id)
collection.insert_many(recs)

After we ingest the parsed protein data into Amazon DocumentDB, we can update the contents of the protein documents. For instance, it makes our model training logistics easier if we add a field indicating whether a protein structure should be used in the training, validation, or test sets.

  4. We first retrieve all the documents with the field is_AF to stratify documents using an aggregation pipeline:
match = {"is_AF": {"$exists": True}}
project = {"y": "$is_AF"}

pipeline = [
    {"$match": match},
    {"$project": project},
]
# aggregation pipeline
cur = collection.aggregate(pipeline)
# retrieve documents from the DB cursor
docs = [doc for doc in cur]
# convert to a data frame:
df = pd.DataFrame(docs)
# stratified split: full -> train/test
df_train, df_test = train_test_split(
    df, 
    test_size=0.2,
    stratify=df["y"], 
    random_state=42
)
# stratified split: train -> train/valid
df_train, df_valid = train_test_split(
    df_train, 
    test_size=0.2,
    stratify=df_train["y"], 
    random_state=42
)
  5. Next, we use the update_many function to store the split information back to Amazon DocumentDB:
for split, df_split in zip(
    ["train", "valid", "test"],
    [df_train, df_valid, df_test]
):
    result = collection.update_many(
        {"_id": {"$in": df_split["_id"].tolist()}}, 
        {"$set": {"split": split}}
    )
    print("Number of documents modified:", result.modified_count)

Train a GNN on the protein structures using SageMaker

All the subsequent code in this section is in the Train_and_eval.ipynb notebook in the SageMaker instance created in your CloudFormation stack.

This notebook trains a GNN model on the protein structure datasets stored in the Amazon DocumentDB.

We first need to implement a PyTorch dataset class for our protein dataset capable of retrieving mini-batches of protein documents from Amazon DocumentDB. It’s more efficient to retrieve batches of documents by the built-in primary ID (_id).

  1. We use the iterable-style dataset by extending the IterableDataset, which pre-fetches the _id and labels of the documents at initialization:
class ProteinDataset(data.IterableDataset):
    """
    An iterable-style dataset for proteins in DocumentDB
    Args:
        pipeline: an aggregation pipeline to retrieve data from DocumentDB
        db_uri: URI of the DocumentDB
        db_name: name of the database
        collection_name: name of the collection
        k: k used for kNN when creating a graph from atomic coordinates
    """

    def __init__(
        self, pipeline, db_uri="", db_name="", collection_name="", k=3
    ):

        self.db_uri = db_uri
        self.db_name = db_name
        self.collection_name = collection_name
        self.k = k

        client = MongoClient(self.db_uri, connect=False)
        collection = client[self.db_name][self.collection_name]
        # pre-fetch the metadata as docs from DocumentDB
        self.docs = [doc for doc in collection.aggregate(pipeline)]
        # mapping document '_id' to label
        self.labels = {doc["_id"]: doc["y"] for doc in self.docs}
  2. The ProteinDataset performs a database read operation in the __iter__ method. It tries to evenly split the workload if there are multiple workers:
    def __iter__(self):
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:
            # single-process data loading, return the full iterator
            protein_ids = [doc["_id"] for doc in self.docs]

        else:  # in a worker process
            # split workload
            start = 0
            end = len(self.docs)
            per_worker = int(
                math.ceil((end - start) / float(worker_info.num_workers))
            )
            worker_id = worker_info.id
            iter_start = start + worker_id * per_worker
            iter_end = min(iter_start + per_worker, end)

            protein_ids = [
                doc["_id"] for doc in self.docs[iter_start:iter_end]
            ]

        # retrieve a list of proteins by _id from DocDB
        with MongoClient(self.db_uri) as client:
            collection = client[self.db_name][self.collection_name]
            cur = collection.find(
                {"_id": {"$in": protein_ids}},
                projection={"coords": True, "seq": True},
            )
            return (
                (
                    convert_to_graph(protein, k=self.k),
                    self.labels[protein["_id"]],
                )
                for protein in cur
            )
  3. The preceding __iter__ method also converts the atomic coordinates of proteins into DGLGraph objects after they’re loaded from Amazon DocumentDB via the convert_to_graph function. This function constructs a k-nearest neighbor (kNN) graph for the amino acid residues using the 3D coordinates of the C-alpha atoms and adds one-hot encoded node features to represent residue identities:
def convert_to_graph(protein, k=3):
    """
    Convert a protein (dict) to a dgl graph using kNN.
    """
    coords = torch.tensor(protein["coords"])
    X_ca = coords[:, 1]
    # construct knn graph from C-alpha coordinates
    g = dgl.knn_graph(X_ca, k=k)
    seq = protein["seq"]
    node_features = torch.tensor([d1_to_index[residue] for residue in seq])
    node_features = F.one_hot(node_features, num_classes=len(d1_to_index)).to(
        dtype=torch.float
    )

    # add node features
    g.ndata["h"] = node_features
    return g
  4. With the ProteinDataset implemented, we can initialize instances for the train, validation, and test datasets and wrap the training instance with BufferedShuffleDataset to enable shuffling.
  5. We further wrap them with torch.utils.data.DataLoader to work with other components of the SageMaker PyTorch Estimator training script (see the sketch after this list).
  6. Next, we implement a simple two-layered graph convolution network (GCN) with a global attention pooling layer for ease of interpretation:
class GCN(nn.Module):
    """A two layer Graph Conv net with Global Attention Pooling over the
    nodes.
    Args:
        in_feats: int, dim of input node features
        h_feats: int, dim of hidden layers
        num_classes: int, number of output units
    """

    def __init__(self, in_feats, h_feats, num_classes):
        super(GCN, self).__init__()
        self.conv1 = GraphConv(in_feats, h_feats)
        self.conv2 = GraphConv(h_feats, h_feats)
        # the gate layer that maps node feature to outputs
        self.gate_nn = nn.Linear(h_feats, num_classes)
        self.gap = GlobalAttentionPooling(self.gate_nn)
        # the output layer making predictions
        self.output = nn.Linear(h_feats, num_classes)

    def _conv_forward(self, g):
        """forward pass through the GraphConv layers"""
        in_feat = g.ndata["h"]
        h = self.conv1(g, in_feat)
        h = F.relu(h)
        h = self.conv2(g, h)
        h = F.relu(h)
        return h

    def forward(self, g):
        h = self._conv_forward(g)
        h = self.gap(g, h)
        return self.output(h)

    def attention_scores(self, g):
        """Calculate attention scores"""
        h = self._conv_forward(g)
        with g.local_scope():
            gate = self.gap.gate_nn(h)
            g.ndata["gate"] = gate
            gate = dgl.softmax_nodes(g, "gate")
            g.ndata.pop("gate")
            return gate
  7. Afterwards, we can train this GCN on the ProteinDataset instance for a binary classification task of predicting whether a protein structure is predicted by AlphaFold or not. We use binary cross entropy as the objective function and the Adam optimizer for stochastic gradient optimization. The full training script can be found in src/main.py.
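
The following is a minimal sketch of steps 4 and 5 above: wrapping the training split for shuffling and batching. The aggregation pipeline, shuffle buffer size, and batch size are illustrative assumptions, the collate helper assumes each sample is a (DGLGraph, label) pair as produced by ProteinDataset, and BufferedShuffleDataset may need to be imported or provided differently depending on the installed PyTorch version.

import dgl
import torch
from torch.utils.data import DataLoader

def collate_graphs(samples):
    """Batch a list of (DGLGraph, label) pairs into one batched graph and a label tensor."""
    graphs, labels = map(list, zip(*samples))
    return dgl.batch(graphs), torch.tensor(labels, dtype=torch.float)

# Aggregation pipeline selecting only documents assigned to the training split (illustrative)
train_pipeline = [
    {"$match": {"split": "train", "is_AF": {"$exists": True}}},
    {"$project": {"y": "$is_AF"}},
]

train_dataset = ProteinDataset(
    train_pipeline,
    db_uri=uri,              # DocumentDB URI built from the Secrets Manager credentials
    db_name="proteins",
    collection_name="proteins",
    k=4,
)

# Shuffle the iterable-style dataset with a bounded buffer, then batch with DataLoader
train_loader = DataLoader(
    torch.utils.data.BufferedShuffleDataset(train_dataset, buffer_size=256),
    batch_size=64,
    collate_fn=collate_graphs,
    num_workers=2,
)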

Next, we set up the SageMaker PyTorch Estimator to handle the training job. To allow the managed Docker container initiated by SageMaker to connect to Amazon DocumentDB, we need to configure the subnet and security group for the Estimator.

  1. We retrieve the subnet ID where the Network Address Translation (NAT) gateway resides, as well as the security group ID of our Amazon DocumentDB cluster by name:
ec2 = boto3.client("ec2")
# find the NAT gateway's subnet ID 
resp = ec2.describe_subnets(
    Filters=[{"Name": "tag:Name", "Values": ["{}-NATSubnet".format(stack_name)]}]
)
nat_subnet_id = resp["Subnets"][0]["SubnetId"]
# find security group id of the DocumentDB
resp = ec2.describe_security_groups(
    Filters=[{
        "Name": "tag:Name", 
        "Values": ["{}-SG-DocumentDB".format(stack_name)]
    }])
sg_id = resp["SecurityGroups"][0]["GroupId"]
  2. Finally, we can kick off the training of our GCN model using SageMaker:
from sagemaker.pytorch import PyTorch

CODE_PATH = "main.py"

params = {
    "patience": 5, 
    "n-epochs": 200,
    "batch-size": 64,
    "db-host": secrets["host"],
    "db-username": secrets["username"], 
    "db-password": secrets["password"], 
    "db-port": secrets["port"],
    "knn": 4,
}

estimator = PyTorch(
    entry_point=CODE_PATH,
    source_dir="src",
    role=role,
    instance_count=1,
    instance_type="ml.p3.2xlarge", # 'ml.c4.2xlarge' for CPU
    framework_version="1.7.1",
    py_version="py3",
    hyperparameters=params,
    sagemaker_session=sess,
    subnets=[nat_subnet_id], 
    security_group_ids=[sg_id],
)
# run the training job:
estimator.fit()

Load and evaluate the trained GNN model

When the training job is complete, we can load the trained GCN model and perform some in-depth evaluation.

The code for the following steps is also available in the notebook Train_and_eval.ipynb.

SageMaker training jobs save the model artifacts into the default S3 bucket, the URI of which can be accessed from the estimator.model_data attribute. We can also navigate to the Training jobs page on the SageMaker console to find the trained model to evaluate.

  1. For research purposes, we can load the model artifact (learned parameters) into a PyTorch state_dict using the following function:
def load_sagemaker_model_artifact(s3_bucket, key):
    """Load a PyTorch model artifact (model.tar.gz) produced by a SageMaker
    Training job.
    Args:
        s3_bucket: str, s3 bucket name (s3://bucket_name)
        key: object key: path to model.tar.gz from within the bucket
    Returns:
        state_dict: dict representing the PyTorch checkpoint
    """
    # load the s3 object
    s3 = boto3.client("s3")
    obj = s3.get_object(Bucket=s3_bucket, Key=key)
    # read into memory
    model_artifact = BytesIO(obj["Body"].read())
    # parse out the state dict from the tar.gz file
    tar = tarfile.open(fileobj=model_artifact)
    for member in tar.getmembers():
        pth = tar.extractfile(member).read()

    state_dict = torch.load(BytesIO(pth), map_location=torch.device("cpu"))
    return state_dict

state_dict = load_sagemaker_model_artifact(
    bucket,
    key=estimator.model_data.split(bucket)[1].lstrip("/")
)

# initialize a GCN model
model = GCN(dim_nfeats, 16, n_classes)
# load the learned parameters
model.load_state_dict(state_dict["model_state_dict"])
  2. Next, we perform quantitative model evaluation on the full test set by calculating accuracy:
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
num_correct = 0
num_tests = 0
model.eval()
with torch.no_grad():
    for batched_graph, labels in test_loader:
        batched_graph = batched_graph.to(device)
        labels = labels.to(device)
        logits = model(batched_graph)
        preds = (logits.sigmoid() > 0.5).to(labels.dtype)
        num_correct += (preds == labels).sum().item()
        num_tests += len(labels)

print('Test accuracy: {:.6f}'.format(num_correct / num_tests))

We found our GCN model achieved an accuracy of 74.3%, whereas the dummy baseline model making predictions based on class priors only achieved 56.3%.

We’re also interested in interpretability of our GCN model. Because we implement a global attention pooling layer, we can compute the attention scores across nodes to explain specific predictions made by the model.

  3. Next, we compute the attention scores and overlay them on the protein graphs for a pair of structures (AlphaFold predicted and experimental) from the same peptide:
pair = ["AF-Q57887", "1JT8-A"]
cur = collection.find(
    {"id": {"$in": pair}},
)

for doc in cur:
    # convert to dgl.graph object
    graph = convert_to_graph(doc, k=4)
    
    with torch.no_grad():
        # make prediction
        pred = model(graph).sigmoid()
        # calculate attention scores for a protein graph
        attn = model.attention_scores(graph)
    
    pred = pred.item()
    attn = attn.numpy()
    
    # convert to networkx graph for visualization
    graph = graph.to_networkx().to_undirected()
    # calculate graph layout
    pos = nx.spring_layout(graph, iterations=500)
    
    fig, ax = plt.subplots(figsize=(8, 8))
    nx.draw(
        graph, 
        pos, 
        node_color=attn.flatten(),
        cmap="Reds",
        with_labels=True, 
        font_size=8,
        ax=ax
    )
    ax.set(title="{}, p(is_predicted)={:.6f}".format(doc["id"], pred))
plt.show()

The preceding code produces the following protein graphs overlaid with attention scores on the nodes. We find that the model’s global attention pooling layer can highlight certain residues in the protein graph as being important for making the prediction of whether the protein structure is predicted by AlphaFold. This indicates that these residues may have distinctive graph topologies in predicted and experimental protein structures.

In summary, we showcase a scalable deep learning solution to train GNNs on protein structures stored in Amazon DocumentDB. Although the tutorial only uses thousands of proteins for training, this solution is scalable to millions of proteins. Unlike other approaches such as serializing the entire protein dataset, our approach transfers the memory-heavy workloads to the database, making the memory complexity for the training jobs O(batch_size), which is independent of the total number of proteins to train.

Clean up

To avoid incurring future charges, delete the CloudFormation stack you created. This removes all the resources you provisioned using the CloudFormation template, including the VPC, Amazon DocumentDB cluster, and SageMaker instance. For instructions, see Deleting a stack on the AWS CloudFormation console.

Conclusion

We described a cloud-based deep learning architecture scalable to millions of protein structures by storing them in Amazon DocumentDB and efficiently retrieving mini-batches of data from SageMaker.

To learn more about the use of GNN in protein property predictions, check out our recent publication LM-GVP, A Generalizable Deep Learning Framework for Protein Property Prediction from Sequence and Structure.


About the Authors

Zichen Wang, PhD, is an Applied Scientist in the Amazon Machine Learning Solutions Lab. With several years of research experience in developing ML and statistical methods using biological and medical data, he works with customers across various verticals to solve their ML problems.

Selvan Senthivel is a Senior ML Engineer with the Amazon ML Solutions Lab at AWS, focusing on helping customers on machine learning, deep learning problems, and end-to-end ML solutions. He was a founding engineering lead of Amazon Comprehend Medical and contributed to the design and architecture of multiple AWS AI services.

Read More

Identity verification using Amazon Rekognition

In-person user identity verification is slow to scale, costly, and high friction for users. Machine learning (ML) powered facial recognition technology can enable online user identity verification. Amazon Rekognition offers pre-trained facial recognition capabilities that you can quickly add to your user onboarding and authentication workflows to verify opted-in users’ identities online. No ML expertise is required. With Amazon Rekognition, you can onboard and authenticate users in seconds while detecting fraudulent or duplicate accounts. As a result, you can grow users faster, reduce fraud, and lower user verification costs.

In this post, we describe a typical identity verification workflow and show how to build an identity verification solution using various Amazon Rekognition APIs. We provide a complete sample implementation in our GitHub repository.

User registration workflow

The following figure shows a sample workflow of a new user registration. Typical steps in this process are:

  1. User captures selfie image and the image of a government-issued identity document.
  2. Quality check of the selfie image and optional liveness detection of the user face.
  3. Comparison of the selfie image with the identity document face image.
  4. Check of the selfie against a database of existing user faces.

You can customize the flow according to your business process. It often contains some or all of the steps presented in the preceding diagram. You can choose to run all the steps synchronously (wait for one step to complete before moving on to the next step). Alternatively, you can run some of the steps highlighted in orange asynchronously (don’t wait for that step to complete) to speed up the user registration process and improve the customer experience. If the steps aren’t successful, you must roll back the user registration.

In addition to new user registration, another common flow is an existing or returning user login. In this flow, a check of the user face (selfie) is performed against a previously registered face. Typical steps in this process include user face capture (selfie), check of the selfie image quality, and search and compare of the selfie against the faces database. The following diagram shows a possible flow.

You can customize the steps of the process according to your business needs, and choose to include or exclude the liveness detection.

Solution overview

The following reference architecture shows how you can use Amazon Rekognition, along with other AWS services, to implement identity verification.

The architecture includes the following components:

  1. Applications invoke Amazon API Gateway to route requests to the correct AWS Lambda function depending on the user flow. There are four major actions in this solution: authenticate, register, register with ID card, and update.
  2. API Gateway uses a service integration to run the AWS Step Functions express state machine corresponding to the specific path called from API Gateway. Within each step, Lambda functions are responsible for triggering the correct set of calls to and from Amazon DynamoDB and Amazon Simple Storage Service (Amazon S3), along with the relevant Amazon Rekognition APIs.
  3. DynamoDB holds face IDs (face-id), S3 path URIs, and unique IDs (for example employee ID number) for each face-id. Amazon S3 stores all the face images.
  4. The final major component of the solution is Amazon Rekognition. Each flow (authenticate, register, register with ID card, and update) calls different Amazon Rekognition APIs depending on the task.

Before we deploy the solution, it’s important to know the following concepts and API descriptions (a brief usage sketch follows the list):

  • Collections – Amazon Rekognition stores information about detected faces in server-side containers known as collections. You can use the facial information that’s stored in a collection to search for known faces in images, stored videos, and streaming videos. You can use collections in a variety of scenarios. For example, you might create a face collection to store scanned badge images by using the IndexFaces operation. When an employee enters the building, an image of the employee’s face is captured and sent to the SearchFacesByImage operation. If the face match produces a sufficiently high similarity score (say 99%), you can authenticate the employee.
  • DetectFaces API – This API detects faces within an image provided as input and returns information about faces. In a user registration workflow, this operation may help you screen images before moving to the next step. For example, you can check if a photo contains a face, if the person identified is in the right orientation, and if they’re not wearing any face blocker such as sunglasses or a cap.
  • IndexFaces API – This API detects faces in the input image and adds them to the specified collection. This operation is used to add a screened image to a collection for future queries.
  • SearchFacesByImage API – For a given input image, the API first detects the largest face in the image, and then searches the specified collection for matching faces. The operation compares the features of the input face with face features in the specified collection.
  • CompareFaces API – This API compares a face in the source input image with each of the 100 largest faces detected in the target input image. If the source image contains multiple faces, the service detects the largest face and compares it with each face detected in the target image. For our use case, we expect both the source and target image to contain a single face.
  • DeleteFaces API – This API deletes faces from a collection. You specify a collection ID and an array of face IDs to remove.

Prerequisites

Before you get started, complete the following prerequisites:

  1. Create an AWS account.
  2. Clone the sample repo on your local machine:
    git clone https://github.com/aws-samples/rekognition-identity-verification.git

We use the test client in this repository to test the various workflows.

  3. Install Python 3.6+ on your local machine.

Deploy the solution

Choose the appropriate AWS CloudFormation stack to provision the solution into your AWS account in your preferred Region:

N. Virginia (us-east-1)

Oregon (us-west-2)

As we discussed earlier, this solution uses API Gateway integrated with Step Functions and Amazon Rekognition APIs to run the identity verification workflows. To test the solution, follow the steps in the code repository to use the provided test client.

The following sections describe the various workflows implemented via Step Functions.

New user registration

The following image illustrates the Step Functions definition for new user registration. The steps are defined in the register_user.py file.

Three functions are called in this workflow: detect-faces, search-faces, and index-faces. The detect-faces function calls the Amazon Rekognition DetectFaces API to determine if a face is detected in an image and is usable. Some of the quality checks include confirming that only a single face is present in the image, ensuring the face isn’t obscured by sunglasses or a hat, and confirming that the face isn’t overly rotated by checking the pose dimensions. If the image passes the quality check, the search-faces function searches for an existing face match in the Amazon Rekognition collections, confirming that the similarity score meets your FaceMatchThreshold objective. For more information, refer to the section on using similarity thresholds to match faces. If the face image doesn’t exist in the collections, the index-faces function is called to index the face into the collection. The face image metadata is stored in the DynamoDB table and the face images are stored in an S3 bucket.
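
As an illustration of the kind of quality checks the detect-faces function performs (a hedged sketch, not the repository’s actual implementation), the following code calls DetectFaces with all attributes and checks for a single, unobstructed, roughly frontal face. The pose threshold and file name are assumptions.

import boto3

rekognition = boto3.client("rekognition")

def passes_quality_check(image_bytes, max_pose_angle=30.0):
    """Illustrative check: exactly one face, no sunglasses, and a roughly frontal pose."""
    response = rekognition.detect_faces(
        Image={"Bytes": image_bytes},
        Attributes=["ALL"],
    )
    faces = response["FaceDetails"]
    if len(faces) != 1:
        return False
    face = faces[0]
    if face["Sunglasses"]["Value"]:
        return False
    pose = face["Pose"]
    if abs(pose["Yaw"]) > max_pose_angle or abs(pose["Pitch"]) > max_pose_angle:
        return False
    return True

with open("selfie.jpg", "rb") as f:  # illustrative file name
    print(passes_quality_check(f.read()))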

To register a new user, run the app.py script (test client) with the following command:

python3 src/test-client/app.py register -z Riv-Prod -r <region> -u <username> -p <path to face image>

If the new user registration succeeds, the face image attribute information is added in DynamoDB.

New user registration with ID card

The steps to register a new user with an ID card are similar to the steps for registering a new user. The following image illustrates the steps, which are defined in the register_idcard.py file.

The same three functions that we used to register a user (detect-faces, search-faces, and index-faces) are called in this workflow. First, the customer captures an image of their ID and a live image of their face. The face image is checked to confirm it meets our defined quality standards using the DetectFaces API. If the image meets the quality standards, the live face image is compared to the face in the ID to determine if they’re a match. If the images don’t match, the user receives an error and the process ends. If the images match, we check if the face already exists in the Amazon Rekognition collections using the SearchFacesByImage API. The search results are compared to the user’s current face image. If the user already exists, the user isn’t registered. If the user doesn’t exist in the collections, the relevant properties are extracted from the ID card. You can extract key-value pairs from identity documents using the newly launched Amazon Textract AnalyzeID API. The extracted properties from the ID card are merged and the user’s face is indexed in the DynamoDB table. After the image is indexed, the new user ID registration process is complete.
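
The following hedged sketch (illustrative, not the repository’s code) shows the two service calls described above: comparing the live selfie to the face on the ID with the Rekognition CompareFaces API, and extracting key-value pairs from the ID with the Amazon Textract AnalyzeID API. The file names and similarity threshold are assumptions.

import boto3

rekognition = boto3.client("rekognition")
textract = boto3.client("textract")

with open("selfie.jpg", "rb") as f:
    selfie_bytes = f.read()
with open("id_card.jpg", "rb") as f:
    id_bytes = f.read()

# Compare the live selfie against the face on the ID card
compare_response = rekognition.compare_faces(
    SourceImage={"Bytes": selfie_bytes},
    TargetImage={"Bytes": id_bytes},
    SimilarityThreshold=90,
)
is_match = len(compare_response["FaceMatches"]) > 0

# Extract key-value pairs (name, date of birth, document number, and so on) from the ID
analyze_response = textract.analyze_id(DocumentPages=[{"Bytes": id_bytes}])
for doc in analyze_response["IdentityDocuments"]:
    for field in doc["IdentityDocumentFields"]:
        print(field["Type"]["Text"], "->", field["ValueDetection"]["Text"])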

Existing user authentication

The following image illustrates the workflow for authenticating an existing user. The steps are defined in the auth.py file.

This Step Function workflow calls three functions: detect-faces, compare-faces, and search-faces. After the detect-faces function verifies that the captured face image is valid, the compare-faces function checks the DynamoDB table for a face image that matches an existing user. If a match is found in DynamoDB, the user authenticates successfully. If a match isn’t found, the search-faces function is called to search for the face image in the collections. The user is verified and the authentication process completes if their face image exists in the collections. Otherwise, the user’s access is denied.

To test authenticating an existing user, run the app.py script (test client) with the following command:

python3 src/test-client/app.py auth -z Riv-Prod -r <region> -u <username> -p <path to face image>

Existing user login with a request for photo update

The following image illustrates the workflow to update an existing user’s photo. The steps are defined in the update.py file.

This workflow calls four functions: detect-faces, compare-faces, search-faces, and index-faces. The steps are similar to the steps in the existing user authentication workflow. After the user captures their face image and the image quality is checked, we check for a matching face image in DynamoDB using the compare-faces function. If a match for the user is found, their user profile is updated, their new face image is indexed by calling the index-faces function, and the update process completes. Alternatively, if a match isn’t found, the search-faces function is called to search for the face image in the Amazon Rekognition collections. If the face image is found in the collection, the user’s profile is updated and their new face image is indexed. The user’s access is denied if their image isn’t found in the collections.

To update an existing user’s photo, run the app.py script (test client) with the following command:

python3 src/test-client/app.py update -z Riv-Prod -r <region> -u <username> -p <path to updated face image>

Clean up

To prevent accruing additional charges in your AWS account, delete the resources you provisioned by navigating to the AWS CloudFormation console and deleting the Riv-Prod stack.

Deleting the stack doesn’t delete the S3 bucket you created. This bucket stores all the face images. If you choose to delete the S3 bucket, navigate to the Amazon S3 console, empty the bucket, and then confirm you want to permanently delete it.

Conclusion

Amazon Rekognition makes it easy to add image analysis to your identity verification applications using proven, highly scalable, deep learning technology that requires no ML expertise to use. Amazon Rekognition provides face detection and comparison capabilities. With a combination of the DetectFaces, CompareFaces, IndexFaces, and SearchFacesByImage APIs, you can implement the common flows around new user registration and existing user logins.

Amazon Rekognition collections provide a method to store information about detected faces in server-side containers. You can then use the facial information stored in a collection to search for known faces in images. When using collections, you don’t need to store original photos after you index faces in the collection. Amazon Rekognition collections don’t persist actual images. Instead, the underlying detection algorithm detects the faces in the input image, extracts facial features into a feature vector for each face, and stores it in the collection.

To start your journey towards identity verification, visit Identity Verification using Amazon Rekognition.


About the Authors

Nate Bachmeier is an AWS Senior Solutions Architect that nomadically explores New York, one cloud integration at a time. He specializes in migrating and modernizing applications. Besides this, Nate is a full-time student and has two kids.

Anthony Pasquariello is an Enterprise Solutions Architect based in New York City. He provides technical consultation to customers during their cloud journey, especially around security best practices. He has an MS and BS in electrical and computer engineering from Boston University. In his free time, he enjoys ramen, writing non-fiction, and philosophy.

Lauren Mullennex is a Solutions Architect based in Denver, CO. She works with customers to help them architect solutions on AWS. In her spare time, she enjoys hiking and cooking Hawaiian cuisine.

Amit Gupta is a Senior AI Services Solutions Architect at AWS. He is passionate about enabling customers with well-architected machine learning solutions at scale.


Introducing hybrid machine learning

Gartner predicts that by the end of 2024, 75% of enterprises will shift from piloting to operationalizing artificial intelligence (AI), and the vast majority of workloads will end up in the cloud in the long run. For some enterprises that plan to migrate to the cloud, the complexity, magnitude, and length of migrations may be daunting. The speed of different teams and their appetites for new tooling can vary dramatically. An enterprise’s data science team may be hungry for adopting the latest cloud technology, while the application development team is focused on running their web applications on premises. Even with a multi-year cloud migration plan, some of the product releases must be built on the cloud in order to meet the enterprise’s business outcomes.

For these customers, we propose hybrid machine learning (ML) patterns as an intermediate step in your journey to the cloud. Hybrid ML patterns are those that involve a minimum of two compute environments, typically local compute resources such as personal laptops or corporate data centers, and the cloud. With the hybrid ML architecture patterns described in this post, enterprises can achieve their desired business goals without having to wait for the cloud migration to complete. At the end of the day, we want to support customer success in all shapes and forms.

We have published a new whitepaper, Hybrid Machine Learning, to help you integrate the cloud with existing on-premises ML infrastructure. For more whitepapers from AWS, see AWS Whitepapers & Guides.

Hybrid ML architecture patterns

The whitepaper gives you an overview of the various hybrid ML patterns across the entire ML lifecycle, including ML model development, data preparation, training, deployment, and ongoing management. The following table summarizes the eight different hybrid ML architectural patterns we discuss in the whitepaper. For each pattern, we provide a preliminary reference architecture in addition to the advantages and disadvantages. We also identify a “when to move” criterion to help you make decisions—for example, when the level of effort to maintain and scale a given pattern has exceeded the value it provides.

Development | Training | Deployment
Develop on personal computers, train and host in the cloud | Train locally, deploy in the cloud | Serve ML models in the cloud to applications hosted on premises
Develop on local servers, train and host in the cloud | Store data locally, train and deploy in the cloud | Host ML models with Lambda@Edge to applications on premises
Develop in the cloud while connecting to data hosted on premises | Train with a third-party SaaS provider to host in the cloud | —
— | Train in the cloud, deploy ML models on premises | Orchestrate hybrid ML workloads with Kubeflow and Amazon EKS Anywhere

In this post, we dive deep into the hybrid architecture pattern for deployment with a focus on serving models hosted in the cloud to applications hosted on premises.

Architecture overview

The most common use case for this hybrid pattern is enterprise migrations. Your data science team may be ready to deploy to the cloud, but your application team is still refactoring their code to host on cloud-native services. This approach enables the data scientists to bring their newest models to market, while the application team separately considers when, where, and how to move the rest of the application to the cloud.

The following diagram shows the architecture for hosting an ML model via Amazon SageMaker in an AWS Region, serving responses to requests from applications hosted on premises.


Technical deep dive

In this section, we dive deep into the technical architecture, focusing on the components that make up the hybrid workload and referring to resources elsewhere as necessary.

Let’s take a real-world use case of a retail company whose application development team has hosted their ecommerce web application on premises. The company wants to improve brand loyalty, grow sales and revenue, and increase efficiencies by using data to create more sophisticated and unique customer experiences. They intend to increase customer engagement by 50% by adding a “recommended for you” widget on their home screen. However, they’re struggling to deliver personalized experiences because of the limitations of static, rule-based systems and the complexity, cost, and integration friction of their legacy, on-premises architecture.

The application team has a 5-year enterprise migration strategy to refactor their web application using cloud-native architecture to move to the cloud, whereas the data science teams are ready to begin implementation in the cloud. With the hybrid architecture pattern described in this post, the company can achieve their desired business outcome quickly without having to wait for the 5-year enterprise migration to complete.

The data scientists develop the ML model, perform training, and deploy the trained model in the cloud. The ecommerce web application that’s hosted on premises consumes the ML model via the exposed endpoints. Let’s walk through this in detail.

In the model development phase, data scientists can use local development environments, such as PyCharm or Jupyter installations on their personal computer, and then connect to the cloud via AWS Identity and Access Management (IAM) permissions and interface with AWS service APIs through the AWS Command Line Interface (AWS CLI) or an AWS SDK (such as Boto3). They also have the flexibility to use Amazon SageMaker Studio, a single web-based visual interface that comes with common data science packages and kernels preinstalled for model development.
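For example, a data scientist working in a local Jupyter or PyCharm environment might authenticate with a named credentials profile and call SageMaker APIs through Boto3; the profile name and Region below are placeholders.

import boto3

# Use locally configured IAM credentials (for example, a named CLI profile).
session = boto3.Session(profile_name="ml-dev", region_name="us-east-1")
sagemaker_client = session.client("sagemaker")

# Verify connectivity by listing recent training jobs in the account.
for job in sagemaker_client.list_training_jobs(MaxResults=5)["TrainingJobSummaries"]:
    print(job["TrainingJobName"], job["TrainingJobStatus"])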

Data scientists can take advantage of SageMaker training capabilities, including access to on-demand CPU and GPU instances, automatic model tuning, managed Spot Instances, checkpointing for saving the state of models, managed distributed training, and many more, using the SageMaker training SDK and APIs. For an overview on training models with SageMaker, see Train a Model with Amazon SageMaker.
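As a rough sketch, a training job that uses managed Spot Instances and checkpointing with the SageMaker Python SDK could look like the following; the container image, IAM role, and S3 paths are placeholders.

import sagemaker
from sagemaker.estimator import Estimator

session = sagemaker.Session()

estimator = Estimator(
    image_uri="<training-image-uri>",                # placeholder training container
    role="<execution-role-arn>",                     # placeholder IAM role
    instance_count=1,
    instance_type="ml.m5.xlarge",
    use_spot_instances=True,                         # managed Spot training
    max_run=3600,                                    # max training time in seconds
    max_wait=7200,                                   # max wait for Spot capacity
    checkpoint_s3_uri="s3://<bucket>/checkpoints/",  # save model state between restarts
    output_path="s3://<bucket>/output/",
    sagemaker_session=session,
)

estimator.fit({"train": "s3://<bucket>/train/"})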

After the model is trained, data scientists can deploy it using SageMaker hosting capabilities and expose an HTTPS endpoint that serves predictions to end applications hosted on premises. The application development teams can integrate their on-premises applications with the ML model via the SageMaker hosted endpoint to get inference results. Application developers can access the deployed models through application programming interface (API) requests with response times as low as a few milliseconds. This supports use cases requiring real-time responses, such as personalized product recommendations.
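Continuing the training sketch above, deploying the trained model behind a real-time endpoint is a single call; the endpoint name and instance type are illustrative.

# Deploy the trained model behind a managed real-time HTTPS endpoint.
predictor = estimator.deploy(
    initial_instance_count=1,
    instance_type="ml.m5.xlarge",
    endpoint_name="recommendation-endpoint",  # assumed endpoint name
)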

The client application on premises connects to the ML model hosted on the SageMaker endpoint over a private network, using a VPN or AWS Direct Connect connection, to provide inference results to its end users. The client application can use any HTTP client library to invoke the endpoint with a POST request, the expected payload, and the necessary authentication credentials configured programmatically. SageMaker also provides commands and libraries that abstract low-level details such as authentication with the AWS credentials saved in the client application environment: the sagemaker-runtime invoke-endpoint command from the AWS CLI, the SageMaker Runtime client in Boto3 (AWS SDK for Python), and the Predictor class in the SageMaker Python SDK.
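From the on-premises client, an invocation with the Boto3 SageMaker Runtime client might look like the following; the endpoint name and JSON payload format are assumptions and depend on how the model was packaged.

import json
import boto3

runtime = boto3.client("sagemaker-runtime", region_name="us-east-1")

payload = {"customer_id": "12345"}  # illustrative request body

response = runtime.invoke_endpoint(
    EndpointName="recommendation-endpoint",  # assumed endpoint name
    ContentType="application/json",
    Body=json.dumps(payload),
)

recommendations = json.loads(response["Body"].read())
print(recommendations)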

To make the endpoint accessible over the internet, we can use Amazon API Gateway. Although you can directly access SageMaker hosted endpoints from API Gateway, a common pattern you can use is adding an AWS Lambda function in between. You can use the Lambda function for any preprocessing, which may be needed in order to send the request in the format expected by the endpoint, or postprocessing for transforming the response into the format required by the client application. For more information, see Call an Amazon SageMaker model endpoint using Amazon API Gateway and AWS Lambda.
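For instance, a minimal Lambda handler sitting behind API Gateway could simply forward the request body to the endpoint and return the model’s response; the environment variable name is an assumption, and a real handler would add validation and error handling.

import os
import boto3

runtime = boto3.client("sagemaker-runtime")
ENDPOINT_NAME = os.environ["ENDPOINT_NAME"]  # assumed environment variable

def lambda_handler(event, context):
    # Forward the API Gateway request body to the SageMaker endpoint.
    response = runtime.invoke_endpoint(
        EndpointName=ENDPOINT_NAME,
        ContentType="application/json",
        Body=event["body"],
    )
    # Return the model output in the proxy-integration format API Gateway expects.
    return {
        "statusCode": 200,
        "headers": {"Content-Type": "application/json"},
        "body": response["Body"].read().decode("utf-8"),
    }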


The following diagram illustrates how the data science team develops the ML model, performs training, and deploys the trained model in the cloud, while the application development team develops and deploys the ecommerce web application on premises.


After the model is deployed into the production environment, your data scientists can use Amazon SageMaker Model Monitor to continuously monitor the quality of the ML models in real time. They can also set up automated alerts for deviations in model quality, such as data drift and anomalies. Amazon CloudWatch Logs collects the monitoring log files and notifies you when the quality of the model reaches defined thresholds. This enables your data scientists to take corrective actions, such as retraining models, auditing upstream systems, or fixing quality issues, without having to monitor models manually. By relying on managed AWS services, your data science team avoids the overhead of implementing monitoring solutions from scratch.
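As an illustration, a data quality monitoring schedule might be attached to the endpoint with the SageMaker Python SDK as follows; the IAM role, S3 locations, and schedule are placeholders, and data capture must already be enabled on the endpoint.

from sagemaker.model_monitor import CronExpressionGenerator, DefaultModelMonitor
from sagemaker.model_monitor.dataset_format import DatasetFormat

monitor = DefaultModelMonitor(
    role="<execution-role-arn>",  # placeholder IAM role
    instance_count=1,
    instance_type="ml.m5.xlarge",
)

# Build baseline statistics and constraints from the training dataset.
monitor.suggest_baseline(
    baseline_dataset="s3://<bucket>/train/train.csv",
    dataset_format=DatasetFormat.csv(header=True),
    output_s3_uri="s3://<bucket>/monitoring/baseline/",
)

# Compare captured endpoint traffic against the baseline every hour.
monitor.create_monitoring_schedule(
    monitor_schedule_name="recommendation-endpoint-monitor",
    endpoint_input="recommendation-endpoint",  # assumed endpoint name
    output_s3_uri="s3://<bucket>/monitoring/reports/",
    statistics=monitor.baseline_statistics(),
    constraints=monitor.suggested_constraints(),
    schedule_cron_expression=CronExpressionGenerator.hourly(),
)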

Your data scientists can reduce the overall time required to deploy their ML models in production by automating load testing and model tuning across SageMaker ML instances by using Amazon SageMaker Inference Recommender. It helps your data scientists select the best instance type and configuration (such as instance count, container parameters, and model optimizations) for their ML models.
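A hedged sketch of starting a default Inference Recommender job through Boto3 follows; the job name and model package ARN are placeholders, and the model must already be registered in the SageMaker Model Registry.

import boto3

sagemaker_client = boto3.client("sagemaker")

# Kick off a default recommendation job that load tests the registered model
# across a set of candidate instance types.
sagemaker_client.create_inference_recommendations_job(
    JobName="recommendation-model-ir-job",  # assumed job name
    JobType="Default",
    RoleArn="<execution-role-arn>",         # placeholder IAM role
    InputConfig={"ModelPackageVersionArn": "<model-package-version-arn>"},
)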

Lastly, it’s always a best practice to decouple hosting your ML model from hosting your application. In this approach, the data scientists host the ML model on dedicated resources that are separate from the application, which greatly simplifies the process of pushing out better models and is a key step in the innovation flywheel. It also avoids tight coupling between the hosted ML model and the application, so the model can be scaled and tuned for performance independently.

In addition to improving the model as research evolves, this approach provides the ability to redeploy a model with updated data. The global COVID-19 pandemic demonstrated that markets change all the time, and ML models need to stay up to date with the latest trends. The only way to deliver on that requirement is to retrain and redeploy your model with updated data.

Conclusion

Check out the whitepaper Hybrid Machine Learning, in which we look at additional patterns for hosting ML models via Lambda@Edge, AWS Outposts, AWS Local Zones, and AWS Wavelength. We explore hybrid ML patterns across the entire ML lifecycle. We look at developing locally, while training and deploying in the cloud. We discuss patterns for training locally to deploy on the cloud, and even to host ML models in the cloud to serve applications on premises.

How are you integrating the cloud with your existing on-premises ML infrastructure? Please share your feedback about hybrid ML in the comments so we can continue to improve our products, features, and documentation. If you want to engage the authors of this document for advice on your cloud migration, contact us at hybrid-ml-support@amazon.com.


About the Authors

Alak Eswaradass is a Solutions Architect at AWS, based in Chicago, Illinois. She is passionate about helping customers design cloud architectures utilizing AWS services to solve business challenges. She hangs out with her daughters and explores the outdoors in her free time.

Emily Webber joined AWS just after SageMaker launched, and has been trying to tell the world about it ever since! Outside of building new ML experiences for customers, Emily enjoys meditating and studying Tibetan Buddhism.

Roop Bains is a Solutions Architect at AWS focusing on AI/ML. He is passionate about machine learning and helping customers achieve their business objectives. In his spare time, he enjoys reading and hiking.
