Distributed fine-tuning of a BERT Large model for a Question-Answering Task using Hugging Face Transformers on Amazon SageMaker

From training new models to deploying them in production, Amazon SageMaker offers the most complete set of tools for startups and enterprises to harness the power of machine learning (ML) and deep learning.

With its Transformers open-source library and ML platform, Hugging Face makes transfer learning and the latest ML models accessible to the global AI community, reducing the time needed for data scientists and ML engineers in companies around the world to take advantage of every new scientific advancement.

Applying Transformers to new NLP tasks or domains requires fine-tuning of large language models, a technique leveraging the accumulated knowledge of pre-trained models to adapt them to a new task or specific type of documents in an additional, efficient training process.

Fine-tuning a model to produce accurate predictions for the business problem at hand requires training large Transformers models, such as BERT, BART, RoBERTa, and T5, which can be challenging to do in a scalable way.

Hugging Face has been working closely with SageMaker to deliver ready-to-use Deep Learning Containers (DLCs) that make training and deploying the latest Transformers models easier and faster than ever. Because features such as SageMaker Data Parallel (SMDP), SageMaker Model Parallel (SMMP), and S3 pipe mode are integrated into the container, companies can drastically reduce the time it takes to create Transformers-based ML solutions such as question answering, text and image generation, search result optimization, customer support automation, conversational interfaces, semantic search, document analysis, and many more applications.

In this post, we focus on the deep integration of SageMaker distributed libraries with Hugging Face, which enables data scientists to accelerate training and fine-tuning of Transformers models from days to hours, all in SageMaker.

Overview of distributed training

ML practitioners and data scientists face two scaling challenges when training models: scaling model size (number of parameters and layers) and scaling training data. Scaling either the model size or the training data can improve accuracy, but in deep learning the amount of memory on the accelerator (for example, a GPU) can limit how large the model and the batch of training data processed at each step can be. For example, when training a large language model, the batch size is often limited to a small number of samples, which can result in a less accurate model.
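To make the memory limit concrete, the following rough estimate (a sketch, assuming plain fp32 training with the Adam optimizer) counts only the model state: 4 bytes per weight, 4 per gradient, and 8 for Adam's two moment estimates. Activations, which grow with batch size and sequence length, come on top of this, which is why batch size is usually the first thing to shrink on a 16 GB GPU such as the V100 in ml.p3.16xlarge instances.

```python
def training_memory_gb(num_params: int,
                       bytes_per_weight: int = 4,
                       bytes_per_grad: int = 4,
                       optimizer_bytes_per_param: int = 8) -> float:
    """Rough lower bound on accelerator memory for model state during training.

    Assumes fp32 weights and gradients plus Adam's two fp32 moment buffers
    (8 bytes per parameter). Activations are deliberately NOT included.
    """
    bytes_per_param = bytes_per_weight + bytes_per_grad + optimizer_bytes_per_param
    return num_params * bytes_per_param / 1e9


# BERT Large has roughly 340 million parameters.
print(f"{training_memory_gb(340_000_000):.2f} GB")  # 5.44 GB before activations
```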

Distributed training can split up the workload to train the model among multiple processors, called workers. These workers operate in parallel to speed up model training.

Based on what we want to scale (model or data), there are two approaches to distributed training: data parallel and model parallel.

Data parallel is the most common approach to distributed training. Data parallelism entails placing a copy of the model architecture and weights on each accelerator. Then, rather than passing the entire training set through a single accelerator, we partition the training set across the accelerators and get through it faster. This adds a step in which the accelerators exchange their gradient information (for example, through a parameter server or an all-reduce operation), but that communication time is more than offset by each accelerator iterating over only a fraction of the dataset. As a result, data parallelism can significantly reduce training time. For example, a training job that takes 4 hours without parallelization can take 24 minutes with distributed training. SageMaker distributed training also implements cutting-edge techniques for gradient updates.
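The following toy NumPy simulation illustrates why data parallelism produces the same update as single-device training. It is a sketch, not SageMaker's implementation: each simulated worker computes the mean squared error gradient of a linear model on its own shard, and averaging the local gradients stands in for the all-reduce (or parameter server) step. With equal-sized shards, the result matches the full-batch gradient.

```python
import numpy as np


def mse_gradient(w: np.ndarray, X: np.ndarray, y: np.ndarray) -> np.ndarray:
    """Gradient of mean squared error for a linear model y_hat = X @ w."""
    residual = X @ w - y
    return 2.0 * X.T @ residual / len(y)


def data_parallel_gradient(w, X, y, num_workers: int) -> np.ndarray:
    """Simulate data parallelism: every worker holds the same w, computes the
    gradient on its own shard, and the local gradients are averaged."""
    X_shards = np.array_split(X, num_workers)
    y_shards = np.array_split(y, num_workers)
    local_grads = [mse_gradient(w, Xs, ys) for Xs, ys in zip(X_shards, y_shards)]
    return np.mean(local_grads, axis=0)  # stands in for all-reduce averaging


rng = np.random.default_rng(0)
X = rng.normal(size=(64, 3))
y = rng.normal(size=64)
w = np.zeros(3)

# With equal-sized shards, the averaged gradient equals the full-batch gradient.
print(np.allclose(data_parallel_gradient(w, X, y, num_workers=4), mse_gradient(w, X, y)))  # True
```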

Model parallelism is used when a model is too large to fit on a single accelerator (GPU). This approach divides the model architecture into shards and places them on different accelerators. How the model is split depends on the neural network architecture, and each shard typically includes several layers. The accelerators communicate each time the data being processed (the activations in the forward pass, and the gradients in the backward pass) moves from one shard to the next.
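The single-process sketch below illustrates the idea with a toy NumPy network. It does not use the SageMaker model parallelism library and ignores pipelining and the backward pass; `shard_layers` and `model_parallel_forward` are hypothetical helper names. The point is that splitting consecutive layers into shards changes where the computation runs, not its result.

```python
import numpy as np


def make_layers(num_layers: int, width: int, seed: int = 0):
    """Create a toy MLP as a list of (weight, bias) pairs."""
    rng = np.random.default_rng(seed)
    return [(rng.normal(scale=0.1, size=(width, width)), np.zeros(width))
            for _ in range(num_layers)]


def forward(layers, x):
    for W, b in layers:
        x = np.tanh(x @ W + b)
    return x


def shard_layers(layers, num_devices: int):
    """Split consecutive layers into num_devices contiguous shards."""
    bounds = np.linspace(0, len(layers), num_devices + 1).astype(int)
    return [layers[bounds[i]:bounds[i + 1]] for i in range(num_devices)]


def model_parallel_forward(shards, x):
    # Each shard would live on its own accelerator; handing the activations
    # `x` from one shard to the next is where communication would happen.
    for shard in shards:
        x = forward(shard, x)
    return x


layers = make_layers(num_layers=8, width=16)
x = np.ones((2, 16))
shards = shard_layers(layers, num_devices=4)
print([len(s) for s in shards])  # [2, 2, 2, 2]
print(np.allclose(model_parallel_forward(shards, x), forward(layers, x)))  # True
```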

To summarize, use data parallelism when training is time-intensive because of a large dataset, or when you want to accelerate your training experiments. Use model parallelism when your model can't fit onto a single accelerator.

Prerequisites

To perform distributed training of Hugging Face Transformers models in SageMaker, you need to complete the following prerequisites:

Implement distributed training

The Hugging Face Transformers library provides a Trainer API that is optimized to train or fine-tune the models the library provides. You can also use it on your own models if they work the same way as Transformers models; see Trainer for more details. This API is used in the library's example scripts, which show how to preprocess data for various NLP tasks and which you can use as templates for a script that solves your own problem. The promise of the Trainer API is that such a script works out of the box on any distributed setup, including SageMaker.

The Trainer API takes everything needed for training: your datasets; your model (or a function that returns your model); a compute_metrics function that returns the metrics you want to track from the arrays of predictions and labels; your optimizer and learning rate scheduler (good defaults are provided); and all the hyperparameters you can tune for your training, grouped in a data class called TrainingArguments. With all of that, it exposes three methods (train, evaluate, and predict) to train your model, get the metric results on any dataset, or get the predictions on any dataset. To learn more about the Trainer object, refer to Fine-tuning a model with the Trainer API and the video The Trainer API, which walks you through a simple example.
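As an illustration of the pieces the Trainer takes, here is a minimal sketch for a classification task (run_qa.py, used later in this post, computes its own question-answering metrics instead). compute_metrics unpacks the EvalPrediction it receives into predictions and labels; `build_trainer` is a hypothetical helper, and the TrainingArguments values are only examples.

```python
import numpy as np


def compute_metrics(eval_pred):
    """Accuracy from classification logits.

    The Trainer calls this with an EvalPrediction, which unpacks into
    (predictions, label_ids).
    """
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    return {"accuracy": float((predictions == labels).mean())}


def build_trainer(model, train_dataset, eval_dataset):
    """Hypothetical helper wiring a model, datasets, and metrics into a Trainer."""
    from transformers import Trainer, TrainingArguments  # imported lazily

    args = TrainingArguments(
        output_dir="/opt/ml/model",
        per_device_train_batch_size=4,
        num_train_epochs=2,
    )
    return Trainer(
        model=model,
        args=args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        compute_metrics=compute_metrics,
    )


# Usage: trainer = build_trainer(model, train_ds, eval_ds); trainer.train(); trainer.evaluate()
print(compute_metrics((np.array([[0.1, 0.9], [0.8, 0.2]]), np.array([1, 1]))))  # {'accuracy': 0.5}
```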

Behind the scenes, the Trainer API starts by analyzing the environment in which you are launching your script when you create the TrainingArguments. For instance, if you launched your training with SageMaker, it looks at the SM_FRAMEWORK_PARAMS variable in the environment to detect if you enabled SageMaker data parallelism or model parallelism. Then it gets the relevant variables (such as the rank of the process or the world size) from the environment before performing the necessary initialization steps (such as smdistributed.dataparallel.torch.distributed.init_process_group()).
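The following sketch shows the kind of check described above: read SM_FRAMEWORK_PARAMS as JSON and look for a data parallelism flag. It is modeled on, but not copied from, the Transformers source; the key name `sagemaker_distributed_dataparallel_enabled` is our assumption about what SageMaker sets, and the library's real check also verifies that the smdistributed package can be imported.

```python
import json
import os
from typing import Mapping, Optional


def sagemaker_dp_enabled(env: Optional[Mapping[str, str]] = None) -> bool:
    """Return True if SM_FRAMEWORK_PARAMS reports SageMaker data parallelism.

    The key name checked here is an assumption for illustration.
    """
    env = os.environ if env is None else env
    raw = env.get("SM_FRAMEWORK_PARAMS", "{}")
    try:
        params = json.loads(raw)
    except json.JSONDecodeError:
        return False
    if not isinstance(params, dict):
        return False
    return bool(params.get("sagemaker_distributed_dataparallel_enabled", False))


print(sagemaker_dp_enabled({"SM_FRAMEWORK_PARAMS": '{"sagemaker_distributed_dataparallel_enabled": true}'}))  # True
print(sagemaker_dp_enabled({}))  # False
```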

The Trainer contains the whole training loop, so it can adjust the necessary steps to make sure the smdistributed.dataparallel backend is used when necessary, without you having to change a line of code in your script. The same script still runs (albeit much more slowly) on your local machine for debugging. The Trainer automatically shards your dataset so that each process sees different samples, reshuffling at each epoch; synchronizes your gradients before the optimization step; applies mixed precision training if you activated it; uses gradient accumulation if you can't fit a big batch size on your GPUs; and handles many more optimizations.
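The per-epoch sharding can be sketched as follows. This is a simplified illustration in the style of PyTorch's DistributedSampler, not the Trainer's own code: every process builds the same shuffled permutation from a shared seed plus the epoch number, then keeps every world_size-th index starting at its own rank. Unlike DistributedSampler, this sketch does not pad shards to equal length.

```python
import random
from typing import List


def shard_indices(dataset_size: int, rank: int, world_size: int,
                  epoch: int, seed: int = 42) -> List[int]:
    """Indices that process `rank` should see during `epoch`.

    All processes derive the same permutation from seed + epoch, so their
    shards are disjoint, and the permutation changes from epoch to epoch.
    """
    indices = list(range(dataset_size))
    random.Random(seed + epoch).shuffle(indices)
    return indices[rank::world_size]


world_size = 4
epoch0 = [shard_indices(10, r, world_size, epoch=0) for r in range(world_size)]
print(sorted(i for shard in epoch0 for i in shard) == list(range(10)))  # True
```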

If you activated model parallelism, it makes sure the processes that have to see the same data (if their dp_rank is the same) get the same batches, and that processes with different dp_rank don’t see the same samples, again with a reshuffle at each epoch. It makes sure the state dictionaries of the model or optimizers are properly synchronized when checkpointing, and again handles all the optimizations such as mixed precision and gradient accumulation.

When using the evaluate and predict methods, the Trainer performs a distributed evaluation, to take advantage of all your GPUs. It properly handles splitting your data for each process (process of the same dp_rank if model parallelism is activated) and makes sure that the predictions are properly gathered in the same order as the dataset you’re using before they are sent to the compute_metrics function or just returned. Using the Trainer API is not mandatory. Users can still use Keras or PyTorch within Hugging Face. However, the Trainer API can provide a helpful abstraction layer.
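The ordering guarantee can be sketched like this, assuming contiguous per-process chunks that are padded by wrapping around to the start of the dataset. This is similar in spirit to how the Trainer's distributed evaluation works, but it is not its actual code. Concatenating the per-process outputs in rank order and dropping the padding recovers the dataset order.

```python
import math
from typing import Callable, List, Sequence


def contiguous_shards(n: int, world_size: int) -> List[List[int]]:
    """Split indices 0..n-1 into equal contiguous chunks, padding by wrapping
    around to the start so every process gets the same number of samples."""
    per_rank = math.ceil(n / world_size)
    padded = list(range(n)) + [i % n for i in range(per_rank * world_size - n)]
    return [padded[r * per_rank:(r + 1) * per_rank] for r in range(world_size)]


def distributed_predict(data: Sequence, predict: Callable, world_size: int) -> list:
    """Simulate sharded prediction followed by an ordered gather."""
    shards = contiguous_shards(len(data), world_size)
    per_rank_outputs = [[predict(data[i]) for i in shard] for shard in shards]
    # Concatenating in rank order restores dataset order; then drop the padding.
    gathered = [out for outputs in per_rank_outputs for out in outputs]
    return gathered[:len(data)]


data = [3, 1, 4, 1, 5, 9, 2]
print(distributed_predict(data, lambda v: v * 10, world_size=3))
# [30, 10, 40, 10, 50, 90, 20]
```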

Train a model using SageMaker Hugging Face Estimators

An Estimator is a high-level interface for SageMaker training that handles end-to-end SageMaker training and deployment tasks. Your training script runs when you call fit on a HuggingFace Estimator. In the Estimator, you define which fine-tuning script to use as entry_point, which instance_type to use, and which hyperparameters are passed in. For more information about HuggingFace parameters, see Hugging Face Estimator.

Distributed training: Data parallel

In this example, we use the new Hugging Face DLCs and the SageMaker SDK to train a distributed Transformers model on the question answering task using the Transformers and datasets libraries. The bert-large-uncased-whole-word-masking model is fine-tuned on the SQuAD dataset.

The following code samples show you steps of creating a HuggingFace estimator for distributed training with data parallelism.

  1. Choose a Hugging Face Transformers script:
    # git configuration to download our fine-tuning script
    git_config = {'repo': 'https://github.com/huggingface/transformers.git','branch': 'v4.6.1'}

When you create a HuggingFace Estimator, you can specify a training script stored in a GitHub repository as the entry point for the Estimator, so you don't have to download the scripts locally. You can use git_config to run the Hugging Face Transformers example scripts, and set the branch that matches your transformers_version. For example, if you use transformers_version 4.6.1, you have to use 'branch': 'v4.6.1'.

  2. Configure training hyperparameters that are passed into the training job:
    # hyperparameters, which are passed into the training job
    hyperparameters = {
        'model_name_or_path': 'bert-large-uncased-whole-word-masking',
        'dataset_name': 'squad',
        'do_train': True,
        'do_eval': True,
        'fp16': True,
        'per_device_train_batch_size': 4,
        'per_device_eval_batch_size': 4,
        'num_train_epochs': 2,
        'max_seq_length': 384,
        'max_steps': 100,  # when set, max_steps takes precedence over num_train_epochs
        'pad_to_max_length': True,
        'doc_stride': 128,
        'output_dir': '/opt/ml/model'
    }

As hyperparameters, we can define any TrainingArguments field as well as the arguments defined in the training script.
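For context, SageMaker passes these hyperparameters to the entry point as command-line arguments, and run_qa.py parses them into its argument data classes with HfArgumentParser. The helper below is a hypothetical illustration of that translation, not the SageMaker training toolkit itself.

```python
from typing import Dict, List


def to_cli_args(hyperparameters: Dict[str, object]) -> List[str]:
    """Render a hyperparameter dict as '--key value' command-line arguments.

    Illustrative only: it mimics, but is not, what SageMaker does.
    """
    args: List[str] = []
    for key, value in hyperparameters.items():
        args.extend([f"--{key}", str(value)])
    return args


print(to_cli_args({"model_name_or_path": "bert-large-uncased-whole-word-masking",
                   "do_train": True, "max_steps": 100}))
# ['--model_name_or_path', 'bert-large-uncased-whole-word-masking', '--do_train', 'True', '--max_steps', '100']
```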

  3. Define the distribution parameters in the HuggingFace Estimator:
    # configuration for running training on smdistributed Data Parallel
    distribution = {'smdistributed':{'dataparallel':{ 'enabled': True }}}

You can use the SageMaker data parallelism library out of the box for distributed training. We added the functionality of data parallelism directly into the Trainer. To enable data parallelism, you can simply add a single parameter to your HuggingFace Estimator to let your Trainer-based code use it automatically.

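  4. Create a HuggingFace Estimator including the parameters defined in the previous steps and start training:
import sagemaker
from sagemaker.huggingface import HuggingFace

role = sagemaker.get_execution_role()  # IAM role; works inside SageMaker notebooks

# estimator
huggingface_estimator = HuggingFace(entry_point='run_qa.py',
                                    source_dir='./examples/pytorch/question-answering',
                                    git_config=git_config,
                                    instance_type='ml.p3.16xlarge',
                                    instance_count=2,
                                    volume_size=200,
                                    role=role,
                                    transformers_version='4.6.1',
                                    pytorch_version='1.7.1',
                                    py_version='py36',
                                    distribution=distribution,
                                    hyperparameters=hyperparameters)

# starting the train job
huggingface_estimator.fit()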

The Hugging Face Transformers repository contains several examples and scripts for fine-tuning models on tasks from language modeling to token classification. In our case, we use run_qa.py from the examples/pytorch/question-answering examples.

smdistributed.dataparallel supports model training on SageMaker with the following instance types only. For best performance, we recommend using an instance type that supports Elastic Fabric Adapter (EFA):

  • ml.p3.16xlarge
  • ml.p3dn.24xlarge (Recommended)
  • ml.p4d.24xlarge (Recommended)

To get the best performance and the most out of SMDataParallel, you should use at least two instances, but you can also use one for testing this example.
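Adding instances also changes the effective batch size, which is worth keeping in mind when comparing runs or tuning the learning rate. A quick calculation for this example, assuming the 8 GPUs of an ml.p3.16xlarge and no gradient accumulation:

```python
def global_batch_size(per_device_batch_size: int, gpus_per_instance: int,
                      instance_count: int, gradient_accumulation_steps: int = 1) -> int:
    """Number of samples contributing to each optimizer step under data parallelism."""
    return per_device_batch_size * gpus_per_instance * instance_count * gradient_accumulation_steps


# ml.p3.16xlarge has 8 NVIDIA V100 GPUs; the example uses 2 instances and
# per_device_train_batch_size=4.
print(global_batch_size(4, 8, 2))  # 64
```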

The following example notebook provides more detailed step-by-step guidance.

Distributed training: Model parallel

For distributed training with model parallelism, we use the Hugging Face Transformers and datasets libraries together with the SageMaker SDK for sequence classification on the General Language Understanding Evaluation (GLUE) benchmark on a multi-node, multi-GPU cluster using the SageMaker model parallelism library.

As with data parallelism, we first set the git configuration, training hyperparameters, and distribution parameters in the HuggingFace Estimator:

# git configuration to download our fine-tuning script
git_config = {'repo': 'https://github.com/huggingface/transformers.git','branch': 'v4.6.1'}

# hyperparameters, which are passed into the training job
hyperparameters = {
    # also set 'model_name_or_path' to the model you want to fine-tune
    'task_name': 'mnli',
    'per_device_train_batch_size': 16,
    'per_device_eval_batch_size': 16,
    'do_train': True,
    'do_eval': True,
    'do_predict': True,
    'num_train_epochs': 2,
    'max_steps': 500,
    'output_dir': '/opt/ml/model'
}

# configuration for running training on smdistributed Model Parallel
mpi_options = {
    "enabled" : True,
    "processes_per_host" : 8,
}
smp_options = {
    "enabled": True,
    "parameters": {
        "microbatches": 4,
        "placement_strategy": "spread",
        "pipeline": "interleaved",
        "optimize": "speed",
        "partitions": 4,
        "ddp": True,
    }
}

distribution = {
    "smdistributed": {"modelparallel": smp_options},
    "mpi": mpi_options
}

The model parallelism library uses MPI internally, so to use model parallelism, MPI must be enabled through the distribution parameter. In the preceding code, “processes_per_host” specifies the number of processes MPI launches on each host. We suggest these settings for development and testing; for production workloads that need extensive GPU capacity, you can contact AWS Support. For more information, see Run a SageMaker Distributed Model Parallel Training Job.
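A quick way to reason about these settings is to count processes. This is a sketch under two assumptions: pipeline parallelism only (no tensor parallelism), and a single ml.p3.16xlarge instance, which the configuration above does not specify. Each model replica spans `partitions` processes, and with ddp enabled the remaining factor becomes the number of data-parallel replicas.

```python
def model_parallel_layout(instance_count: int, processes_per_host: int, partitions: int) -> dict:
    """Rough process accounting for pipeline model parallelism with ddp enabled."""
    total = instance_count * processes_per_host
    if total % partitions:
        raise ValueError("total processes must be divisible by partitions")
    return {"total_processes": total, "model_replicas": total // partitions}


# processes_per_host=8 and partitions=4 come from the configuration above;
# instance_count=1 is an assumption.
print(model_parallel_layout(1, 8, 4))  # {'total_processes': 8, 'model_replicas': 2}
```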

The following example notebook contains the complete code scripts.

Spot Instances

With the Hugging Face framework extension for the SageMaker Python SDK, we can also take advantage of fully managed Amazon Elastic Compute Cloud (Amazon EC2) Spot Instances and save up to 90% of our training cost.

Unless your training job completes quickly, we recommend using checkpointing with managed spot training so that an interrupted job can resume; to do so, define checkpoint_s3_uri.
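On the training-script side, resuming is usually a matter of finding the newest checkpoint. The sketch below assumes the script writes Trainer checkpoints (folders named checkpoint-<step>) to /opt/ml/checkpoints, the local path that SageMaker synchronizes with checkpoint_s3_uri by default. `latest_checkpoint` is a hypothetical helper; Transformers also provides a similar function, transformers.trainer_utils.get_last_checkpoint.

```python
import os
import re
from typing import Optional

CHECKPOINT_DIR = re.compile(r"^checkpoint-(\d+)$")


def latest_checkpoint(folder: str) -> Optional[str]:
    """Return the 'checkpoint-<step>' subfolder with the highest step, or None."""
    if not os.path.isdir(folder):
        return None
    candidates = []
    for name in os.listdir(folder):
        match = CHECKPOINT_DIR.match(name)
        if match and os.path.isdir(os.path.join(folder, name)):
            candidates.append((int(match.group(1)), name))
    if not candidates:
        return None
    return os.path.join(folder, max(candidates)[1])


# In a training script (assumption: output_dir is /opt/ml/checkpoints):
# trainer.train(resume_from_checkpoint=latest_checkpoint("/opt/ml/checkpoints"))
```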

To use Spot Instances with the HuggingFace Estimator, set the use_spot_instances parameter to True and define max_wait and max_run; max_wait must be equal to or greater than max_run. For more information about the managed spot training lifecycle, see Managed Spot Training in Amazon SageMaker.

The following is a code snippet for setting up a spot training Estimator:

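import sagemaker
from sagemaker.huggingface import HuggingFace

sess = sagemaker.Session()
role = sagemaker.get_execution_role()  # IAM role; works inside SageMaker notebooks

# hyperparameters, which are passed into the training job
hyperparameters={'epochs': 1,
                 'train_batch_size': 32,
                 }

# s3 uri where our checkpoints will be uploaded during training
job_name = "using-spot"
checkpoint_s3_uri = f's3://{sess.default_bucket()}/{job_name}/checkpoints'

huggingface_estimator = HuggingFace(entry_point='train.py',
                            instance_type='ml.p3.2xlarge',  # example instance type
                            instance_count=1,
                            role=role,
                            transformers_version='4.6.1',
                            pytorch_version='1.7.1',
                            py_version='py36',
                            base_job_name=job_name,
                            checkpoint_s3_uri=checkpoint_s3_uri,
                            use_spot_instances=True,
                            max_wait=3600, # must be equal to or greater than max_run, in seconds
                            max_run=1000, # expected max run time in seconds
                            hyperparameters = hyperparameters)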

The following notebook contains the complete code scripts.

Conclusion

In this post, we discussed distributed training of Hugging Face Transformers using SageMaker. We first reviewed the use cases for data parallelism vs. model parallelism. Data parallelism is typically most appropriate (though not exclusively) when training is bottlenecked by compute, whereas model parallelism is the choice when a model can't fit within the memory of a single accelerator. We then showed how to train with both methods.

In the data parallelism use case we discussed, training a model on a single p3.2xlarge instance (with a single GPU) takes 4 hours and costs roughly $15 at the time of this writing. With data parallelism, we can train the same model in 24 minutes at a cost of $28. Although the cost roughly doubles, the training time drops by a factor of 10. When you need to train many models within a short period of time, data parallelism makes this possible at a relatively small cost increase. Model parallelism, for its part, makes it possible to train models that previously could not be trained at all because of hardware limitations. Both features enable new workflows for ML practitioners and are readily accessible through the HuggingFace Estimator as a part of the SageMaker Python SDK. Deploying these models to hosted endpoints follows the same procedure as for other Estimators.
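The speedup and cost ratio follow directly from the figures quoted above:

```python
single_gpu = {"minutes": 4 * 60, "cost_usd": 15}   # one p3.2xlarge, 4 hours
data_parallel = {"minutes": 24, "cost_usd": 28}    # distributed run

speedup = single_gpu["minutes"] / data_parallel["minutes"]
cost_ratio = data_parallel["cost_usd"] / single_gpu["cost_usd"]
print(f"speedup: {speedup:.0f}x, cost: {cost_ratio:.2f}x")  # speedup: 10x, cost: 1.87x
```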

This integration enables other features that are part of the SageMaker ecosystem. For example, you can use Spot Instances by adding a simple flag to the Estimator for additional cost-optimization. As a next step, you can find and run the training demo and example notebook.

About the Authors

Archis Joglekar is an AI/ML Partner Solutions Architect in the Emerging Technologies team. He is interested in performant, scalable deep learning and scientific computing using the building blocks at AWS. His past experiences range from computational physics research to machine learning platform development in academia, national labs, and startups. His time away from the computer is spent playing soccer and with friends and family.

James Yi is a Sr. AI/ML Partner Solutions Architect in the Emerging Technologies team at Amazon Web Services. He is passionate about working with enterprise customers and partners to design, deploy, and scale AI/ML applications that deliver business value. Outside of work, he enjoys playing soccer, traveling, and spending time with his family.

Philipp Schmid is a Machine Learning Engineer and Tech Lead at Hugging Face, where he leads the collaboration with the Amazon SageMaker team. He is passionate about democratizing, optimizing, and productionizing cutting-edge NLP models and improving the ease of use for Deep Learning.

Sylvain Gugger is a Research Engineer at Hugging Face and one of the main maintainers of the Transformers library. He loves open-source software and helping the community use it.

Jeff Boudier builds products at Hugging Face, creator of Transformers, the leading open-source ML library. Previously Jeff was a co-founder of Stupeflix, acquired by GoPro, where he served as director of Product Management, Product Marketing, Business Development and Corporate Development.

Read More

Detect NLP data drift using custom Amazon SageMaker Model Monitor

Natural language understanding is applied in a wide range of use cases, from chatbots and virtual assistants, to machine translation and text summarization. To ensure that these applications are running at an expected level of performance, it’s important that data in the training and production environments is from the same distribution. When the data that is used for inference (production data) differs from the data used during model training, we encounter a phenomenon known as data drift. When data drift occurs, the model is no longer relevant to the data in production and likely performs worse than expected. It’s important to continuously monitor the inference data and compare it to the data used during training.

You can use Amazon SageMaker to quickly build, train, and deploy machine learning (ML) models at any scale. As a proactive measure against model degradation, you can use Amazon SageMaker Model Monitor to continuously monitor the quality of your ML models in real time. With Model Monitor, you can also configure alerts to notify and trigger actions if any drift in model performance is observed. Early and proactive detection of these deviations enables you to take corrective actions, such as collecting new ground truth training data, retraining models, and auditing upstream systems, without having to manually monitor models or build additional tooling.

Model Monitor offers four different types of monitoring capabilities to detect and mitigate model drift in real time:

  • Data quality – Helps detect change in data schemas and statistical properties of independent variables and alerts when a drift is detected.
  • Model quality – For monitoring model performance characteristics such as accuracy or precision in real time, Model Monitor allows you to ingest the ground truth labels collected from your applications. Model Monitor automatically merges the ground truth information with prediction data to compute the model performance metrics.
  • Model bias – Model Monitor is integrated with Amazon SageMaker Clarify to improve visibility into potential bias. Although your initial data or model may not be biased, changes in the world may cause bias to develop over time in a model that has already been trained.
  • Model explainability – Drift detection alerts you when a change occurs in the relative importance of feature attributions.

In this post, we discuss the types of data quality drift that are applicable to text data. We also present an approach to detecting data drift in text data using Model Monitor.

Data drift in NLP

Data drift can be classified into three categories depending on whether the distribution shift is happening on the input or on the output side, or whether the relationship between the input and the output has changed.

Covariate shift

In a covariate shift, the distribution of inputs changes over time, but the conditional distribution P(y|x) doesn’t change. This type of drift is called covariate shift because the problem arises due to a shift in the distribution of the covariates (features). For example, in an email spam classification model, distribution of training data (email corpora) may diverge from the distribution of data during scoring.

Label shift

While covariate shift focuses on changes in the feature distribution, label shift focuses on changes in the distribution of the class variable. This type of shift is essentially the reverse of covariate shift. An intuitive way to think about it is to consider an unbalanced dataset. If our training set is 50% spam and 50% non-spam, but in reality only 10% of our emails are non-spam, then the target label distribution has shifted.
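A simple way to see label shift is to compare class proportions between training and production labels. The sketch below uses the numbers from the spam example and assumes you have (or can estimate) labels for production traffic; `label_shift` is a hypothetical helper.

```python
from collections import Counter
from typing import Dict, Iterable


def label_distribution(labels: Iterable[str]) -> Dict[str, float]:
    counts = Counter(labels)
    total = sum(counts.values())
    return {label: n / total for label, n in counts.items()}


def label_shift(train_labels, prod_labels) -> Dict[str, float]:
    """Per-class change in proportion (production minus training)."""
    p_train = label_distribution(train_labels)
    p_prod = label_distribution(prod_labels)
    classes = set(p_train) | set(p_prod)
    return {c: round(p_prod.get(c, 0.0) - p_train.get(c, 0.0), 6) for c in sorted(classes)}


train = ["spam"] * 50 + ["non-spam"] * 50  # 50% spam in training
prod = ["spam"] * 90 + ["non-spam"] * 10   # only 10% non-spam in production
print(label_shift(train, prod))  # {'non-spam': -0.4, 'spam': 0.4}
```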

Concept shift

Concept shift differs from covariate and label shift in that it's not related to the data distribution or the class distribution, but to the relationship between the two variables. For example, email spammers often use a variety of concepts to get past spam filter models, and the concepts found in emails during training may change as time goes by.

Now that we understand the different types of data drift, let’s see how we can use Model Monitor to detect covariate shift in text data.

Solution overview

Unlike tabular data, which is structured and bounded, textual data is complex, high dimensional, and free form. To efficiently detect drift in NLP, we work with embeddings, which are low-dimensional representations of the text. You can obtain embeddings using various language models such as Word2Vec and transformer-based models like BERT. These models project high-dimensional data into low-dimensional spaces while preserving the semantic information of the text. The results are dense and contextually meaningful vectors, which can be used for various downstream tasks, including monitoring for data drift.

In our solution, we use embeddings to detect the covariate shift of English sentences. We utilize Model Monitor to facilitate continuous monitoring for a text classifier that is deployed to a production environment. Our approach consists of the following steps:

  1. Fine-tune a BERT model using SageMaker.
  2. Deploy a fine-tuned BERT classifier as a real-time endpoint with data capture enabled.
  3. Create a baseline dataset that consists of a sample of the sentences used to train the BERT classifier.
  4. Create a custom SageMaker monitoring job to calculate the cosine similarity between the data captured in production and the baseline dataset.

The following diagram illustrates the solution workflow:

Fine-tune a BERT model

In this post, we use the Corpus of Linguistic Acceptability (CoLA), a dataset of 10,657 English sentences from published linguistics literature, labeled as grammatical or ungrammatical. We use SageMaker training to fine-tune a BERT model on the CoLA dataset by defining a PyTorch estimator. For more information on how to use this SDK with PyTorch, see Use PyTorch with the SageMaker Python SDK. Calling the fit() method of the estimator launches the training job:

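import sagemaker
from sagemaker.pytorch import PyTorch

role = sagemaker.get_execution_role()
# place to save model artifact (bucket and model_prefix are defined earlier)
output_path = f"s3://{bucket}/{model_prefix}"

estimator = PyTorch(
    entry_point="train_deploy.py",
    role=role,
    framework_version="1.7.1",      # example framework version
    py_version="py3",
    instance_count=1,
    instance_type="ml.p3.2xlarge",  # example instance type
    output_path=output_path,
    hyperparameters={"epochs": 1, "num_labels": 2, "backend": "gloo"},
    disable_profiler=True, # disable debugger
)
estimator.fit({"training": inputs_train, "testing": inputs_test})  # S3 URIs of the CoLA splits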

Deploy the model

After training our model, we host it on a SageMaker endpoint. To make the endpoint load the model and serve predictions, we implement a few methods in train_deploy.py:

  • model_fn() – Loads the saved model and returns a model object that can be used for model serving. The SageMaker PyTorch model server loads our model by invoking model_fn.
  • input_fn() – Deserializes and prepares the prediction input. In this example, our request body is first serialized to JSON and then sent to the model serving endpoint. Therefore, in input_fn(), we first deserialize the JSON-formatted request body and return the input as a torch.tensor, as required for BERT.
  • predict_fn() – Performs the prediction and returns the result.

Enable Model Monitor data capture

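We enable Model Monitor data capture to record the input data in an Amazon Simple Storage Service (Amazon S3) bucket so we can reference it later (s3_capture_upload_path is the S3 prefix where captured data is written):

from sagemaker.model_monitor import DataCaptureConfig
data_capture_config = DataCaptureConfig(enable_capture=True, sampling_percentage=100, destination_s3_uri=s3_capture_upload_path)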

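Then we create a real-time SageMaker endpoint with the model trained in the previous step (the hosting instance type shown is an example):

predictor = estimator.deploy(endpoint_name='nlp-data-drift-bert-endpoint', initial_instance_count=1,
                             instance_type='ml.m5.xlarge', data_capture_config=data_capture_config)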


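We run prediction using the predictor object that we created in the previous step, after setting the JSON serializer and deserializer used by the inference endpoint:

from sagemaker.serializers import JSONSerializer
from sagemaker.deserializers import JSONDeserializer

predictor.serializer = JSONSerializer()
predictor.deserializer = JSONDeserializer()
endpoint_name = predictor.endpoint_name
print("Sending test traffic to the endpoint {}. \nPlease wait...".format(endpoint_name))
result = predictor.predict([
    "Thanks so much for driving me home",
    "Thanks so much for cooking dinner. I really appreciate it",
    "Nice to meet you, Sergio. So, where are you from"])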

The real-time endpoint is configured to capture both the request and the response data, which are stored in Amazon S3. You can view the captured data in the S3 location specified in the data capture configuration.
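To inspect the captured data programmatically, you can parse the JSON Lines files that data capture writes to Amazon S3. The sketch below assumes the record layout we have seen in Model Monitor capture files (captureData.endpointInput with data and encoding fields) and a JSON request body; verify the field names against your own capture files. `captured_inputs` is a hypothetical helper.

```python
import base64
import json
from typing import Any, List


def captured_inputs(jsonl_text: str) -> List[Any]:
    """Extract decoded endpoint inputs from data-capture JSON Lines text.

    Assumes each line has captureData.endpointInput.{data, encoding} and
    that the request body is JSON; only JSON and BASE64 encodings are handled.
    """
    inputs = []
    for line in jsonl_text.splitlines():
        if not line.strip():
            continue
        endpoint_input = json.loads(line)["captureData"]["endpointInput"]
        data = endpoint_input["data"]
        if endpoint_input.get("encoding", "JSON") == "BASE64":
            data = base64.b64decode(data).decode("utf-8")
        inputs.append(json.loads(data))
    return inputs


sample = json.dumps({"captureData": {"endpointInput": {
    "observedContentType": "application/json", "mode": "INPUT",
    "data": json.dumps(["Thanks so much for driving me home"]), "encoding": "JSON"}}})
print(captured_inputs(sample))  # [['Thanks so much for driving me home']]
```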

Create a baseline

We use the fine-tuned BERT model to extract sentence embedding features from the training data. We use these vectors as high-quality feature inputs for comparing cosine similarity because BERT produces dynamic word representations that capture semantic context. Complete the following steps to get sentence embeddings:

  1. Use a BERT tokenizer to get token IDs for each token (input_id) in the input sentence and a mask that indicates which elements in the input sequence are tokens vs. padding elements (attention_mask_id). We use the BERT tokenizer's encode_plus function to get these values for each input sentence:
from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')  # assumption: tokenizer of the fine-tuned base model
encoded_dict = tokenizer.encode_plus(
               sent,                      # Input sentence to encode.
               add_special_tokens = True, # Add '[CLS]' and '[SEP]'
               max_length = 64,           # Maximum sequence length
               padding = 'max_length',    # Pad sentence to max_length
               truncation = True,         # Truncate sentence to max_length
               return_attention_mask = True, # BERT model needs attention_mask
               return_tensors = 'pt',     # Return PyTorch tensors.
               )
input_ids = encoded_dict['input_ids']
attention_mask_ids = encoded_dict['attention_mask']

We pass input_ids and attention_mask_ids to the model to fetch the hidden states of the network. The hidden states have four dimensions, in the following order:

  • Layer number (BERT has 12 layers; the hidden states also include the embedding output, for 13 entries in total)
  • Batch number (1 sentence)
  • Word token indexes
  • Hidden units (768 features)
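  2. Average the token vectors of the second-to-last hidden layer over the real (non-padding) tokens in the sentence to get a single vector (sentence embedding):
import torch

with torch.no_grad():                                   # inference only, no gradients needed
    outputs = model(input_ids, attention_mask=attention_mask_ids, output_hidden_states=True)
hidden_states = outputs.hidden_states                   # embedding output + one entry per layer
token_vecs = hidden_states[-2][0]                       # second-to-last layer, first (only) sentence
mask = attention_mask_ids[0].unsqueeze(-1).float()      # 1 for real tokens, 0 for padding
sentence_embedding = (token_vecs * mask).sum(dim=0) / mask.sum()  # average of real token vectors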
  3. Convert the sentence embeddings to NumPy arrays and store them in an Amazon S3 location as a baseline that Model Monitor uses:
import numpy as np

sentence_embeddings_list = []
for i in sentence_embeddings:  # one torch tensor per training sentence
    sentence_embeddings_list.append(i.numpy())

np.save('embeddings.npy', sentence_embeddings_list)

#Upload the sentence embedding to S3
!aws s3 cp embeddings.npy s3://{bucket}/{model_prefix}/embeddings/

Evaluation script

Model Monitor provides a pre-built container with the ability to analyze the data captured from endpoints for tabular datasets. If you want to bring your own container, Model Monitor provides extension points that you can use. When you create a MonitoringSchedule, Model Monitor ultimately kicks off processing jobs. Therefore, the container needs to be aware of the processing job contract. We need to create an evaluation script that is compatible with container contract inputs and outputs.
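The custom container receives its settings from environment variables. The violation code later in this section reads fields such as env.max_ratio_threshold and env.sagemaker_endpoint_name, and a minimal sketch of such an env object follows. The variable names other than THRESHOLD reflect our understanding of the bring-your-own-container contract; THRESHOLD is a hypothetical name, and the default paths and threshold are placeholders.

```python
import os
from dataclasses import dataclass
from typing import Mapping, Optional


@dataclass
class Env:
    """Settings a custom Model Monitor container reads from its environment."""
    dataset_source: str
    output_path: str
    sagemaker_endpoint_name: str
    sagemaker_monitoring_schedule_name: str
    max_ratio_threshold: float

    @classmethod
    def from_environ(cls, environ: Optional[Mapping[str, str]] = None) -> "Env":
        e = os.environ if environ is None else environ
        return cls(
            dataset_source=e.get("dataset_source", "/opt/ml/processing/input/endpoint"),
            output_path=e.get("output_path", "/opt/ml/processing/resultdata"),
            sagemaker_endpoint_name=e.get("sagemaker_endpoint_name", ""),
            sagemaker_monitoring_schedule_name=e.get("sagemaker_monitoring_schedule_name", ""),
            max_ratio_threshold=float(e.get("THRESHOLD", "0.5")),  # placeholder default
        )


env = Env.from_environ()
```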

Model Monitor runs the evaluation code on all the samples captured during the monitoring schedule. For each inference data point, we calculate the sentence embedding using the same logic described earlier. Cosine similarity is used as the metric to measure how similar an inference data point is to the sentence embeddings in the baseline. Mathematically, it is the cosine of the angle between two sentence embedding vectors. A high cosine similarity score indicates similar sentence embeddings, and a lower score may indicate data drift. We calculate the average of all the cosine similarity scores, and if it's less than the threshold, the data point is captured in the violation report. Depending on the use case, you can use other distance metrics, such as Manhattan or Euclidean distance, to measure the similarity of sentence embeddings.
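In symbols, with e(x) the embedding of an inference sentence x, b_1, ..., b_N the baseline embeddings, and τ the threshold, the score and the drift check described above are as follows. Note that SciPy's cosine function returns the cosine distance, 1 − sim, which is why the violation code computes 1 - cosine(...).

```latex
\operatorname{sim}(\mathbf{u}, \mathbf{v}) = \cos\theta
  = \frac{\mathbf{u} \cdot \mathbf{v}}{\lVert \mathbf{u} \rVert \, \lVert \mathbf{v} \rVert},
\qquad
\bar{s}(x) = \frac{1}{N} \sum_{i=1}^{N} \operatorname{sim}\bigl(\mathbf{e}(x), \mathbf{b}_i\bigr),
\qquad
\text{violation if } \bar{s}(x) < \tau
```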

The following diagram shows how we use SageMaker Model Monitor to establish a baseline and detect data drift using cosine similarity.

The following is the code for calculating the violations; the complete evaluation script is available on GitHub:

from scipy.spatial.distance import cosine

cosine_score = 0.0
for embed_item in embedding_list: # all sentence embeddings from baseline
    # scipy's cosine() returns a distance, so 1 - distance is the cosine similarity
    cosine_score += (1 - cosine(input_sentence_embedding, embed_item))
cosine_score_avg = cosine_score/(len(embedding_list)) # average cosine score of input sentence
if cosine_score_avg < env.max_ratio_threshold: # compare average cosine score against a threshold
    sent_cosine_dict[record] = cosine_score_avg # capture details for violation report
    violations.append({  # violations: list later written to the violation report
            "sentence": record,
            "avg_cosine_score": cosine_score_avg,
            "feature_name": "sent_cosine_score",
            "constraint_check_type": "baseline_drift_check",
            "endpoint_name" : env.sagemaker_endpoint_name,
            "monitoring_schedule_name": env.sagemaker_monitoring_schedule_name
        })

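Because this snippet depends on objects defined elsewhere in the evaluation script (`embedding_list`, `env`, `record`), a self-contained sketch of the same check can be handy for unit testing. It uses plain Python instead of SciPy, and the function names and input layout are our own assumptions, not part of the published script:

```python
import math


def cosine_similarity(u, v):
    """Cosine similarity of two vectors (equal to 1 - scipy.spatial.distance.cosine(u, v))."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


def find_drift_violations(inference_embeddings, baseline_embeddings, threshold=0.5):
    """Flag every inference sentence whose average cosine similarity to the
    baseline embeddings is below `threshold`.

    inference_embeddings: dict mapping sentence text -> embedding vector (assumed layout)
    baseline_embeddings: list of baseline embedding vectors
    """
    violations = []
    for sentence, embedding in inference_embeddings.items():
        scores = [cosine_similarity(embedding, b) for b in baseline_embeddings]
        avg_score = sum(scores) / len(scores)
        if avg_score < threshold:
            violations.append({
                "feature_name": "sent_cosine_score",
                "constraint_check_type": "baseline_drift_check",
                "sentence": sentence,
                "avg_cosine_score": avg_score,
            })
    return {"violations": violations}


# Toy 2-D "embeddings" for illustration; real sentence embeddings have hundreds of dimensions
baseline = [[1.0, 0.0], [0.8, 0.6]]
inputs = {"close to baseline": [1.0, 0.1], "drifted": [-1.0, 0.0]}
report = find_drift_violations(inputs, baseline)
print([v["sentence"] for v in report["violations"]])  # ['drifted']
```

If you swap in a Euclidean or Manhattan distance, remember to flip the comparison: for a distance, a larger value means less similar.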
Measure data drift using Model Monitor

In this section, we focus on measuring data drift using Model Monitor. Model Monitor’s pre-built monitors are powered by Deequ, a library built on top of Apache Spark for defining unit tests for data that measure data quality in large datasets. You don’t need to write code to use these pre-built monitoring capabilities, but you also have the flexibility to write code that provides custom analysis. You can collect and review all metrics emitted by Model Monitor in Amazon SageMaker Studio, so you can visually analyze your model performance without writing additional code.

In certain scenarios, for instance when the data is non-tabular, the default processing job (powered by Deequ) doesn’t suffice because it only supports tabular datasets. The pre-built monitors may not be sufficient to generate sophisticated metrics to detect drifts, and may necessitate bringing your own metrics. In the next sections, we describe the setup to bring in your metrics by building a custom container.

Build the custom Model Monitor container

We use the evaluation script from the previous section to build a Docker container and push it to Amazon Elastic Container Registry (Amazon ECR):

# Build a Docker container and push it to Amazon ECR

import boto3

account_id = boto3.client('sts').get_caller_identity().get('Account')
ecr_repository = 'nlp-data-drift-bert-v1'
tag = ':latest'
region = boto3.session.Session().region_name
sm = boto3.client('sagemaker')
uri_suffix = 'amazonaws.com'
if region in ['cn-north-1', 'cn-northwest-1']:
    uri_suffix = 'amazonaws.com.cn'
# Build the registry and image URI for every Region, not only the China Regions
registry = f'{account_id}.dkr.ecr.{region}.{uri_suffix}'
processing_repository_uri = f'{registry}/{ecr_repository + tag}'

# Creating the ECR repository and pushing the container image

!docker build -t $ecr_repository docker

# Authenticate Docker to ECR (get-login-password replaces get-login, which AWS CLI v2 removed)
!aws ecr get-login-password --region $region | docker login --username AWS --password-stdin $registry

!aws ecr create-repository --repository-name $ecr_repository

!docker tag {ecr_repository + tag} $processing_repository_uri
!docker push $processing_repository_uri
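The image URI logic is easy to get wrong across Regions, so it can help to factor it into a small pure function that can be tested without AWS credentials. This is a sketch; `ecr_image_uri` is a hypothetical helper name, and the account ID below is a placeholder:

```python
CHINA_REGIONS = ("cn-north-1", "cn-northwest-1")


def ecr_image_uri(account_id, region, repository, tag="latest"):
    """Build an Amazon ECR image URI; China Regions use the amazonaws.com.cn domain."""
    uri_suffix = "amazonaws.com.cn" if region in CHINA_REGIONS else "amazonaws.com"
    return f"{account_id}.dkr.ecr.{region}.{uri_suffix}/{repository}:{tag}"


print(ecr_image_uri("123456789012", "us-east-1", "nlp-data-drift-bert-v1"))
# 123456789012.dkr.ecr.us-east-1.amazonaws.com/nlp-data-drift-bert-v1:latest
```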

When the custom Docker container is in Amazon ECR, we can schedule a model monitoring job and generate a violations report, as demonstrated in the next sections.

Schedule a model monitoring job

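To schedule a model monitoring job, we create a ModelMonitor instance and set its image_uri to the Docker container image that we pushed to Amazon ECR in the previous section:

from sagemaker.model_monitor import ModelMonitor

monitor = ModelMonitor(
    role=role,  # SageMaker execution role, assumed to be defined earlier in the notebook
    image_uri=processing_repository_uri,  # custom container pushed to Amazon ECR above
    env={'THRESHOLD': '0.5', 'bucket': bucket},
)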

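We schedule the monitoring job using the create_monitoring_schedule API, which can run the job on an hourly or daily basis. The destination parameter sets the Amazon S3 location where the job writes its output, as shown in the following code:

from sagemaker.model_monitor import CronExpressionGenerator, MonitoringOutput
from sagemaker.processing import ProcessingOutput

destination = f's3://{sagemaker_session.default_bucket()}/{prefix}/{endpoint_name}/monitoring_schedule'

# Container path where the evaluation script writes its report (assumed path)
processing_output = ProcessingOutput(output_name='result', source='/opt/ml/processing/resultdata', destination=destination)
output = MonitoringOutput(source=processing_output.source, destination=processing_output.destination)

# Schedule name is illustrative; use CronExpressionGenerator.daily() for a daily run
monitor.create_monitoring_schedule(
    monitor_schedule_name='nlp-data-drift-schedule',
    endpoint_input=endpoint_name,
    output=output,
    schedule_cron_expression=CronExpressionGenerator.hourly(),
)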


To describe and list the monitoring schedule and its runs, you can use the following commands:


Data drift violation report

When the model monitoring job is complete, you can navigate to the destination S3 path to access the violation reports. The report contains all inputs whose average cosine score (avg_cosine_score) is below the threshold set through the THRESHOLD environment variable (0.5) of the ModelMonitor instance. Such inputs indicate that the data observed during inference is drifting beyond the established baseline.

The following JSON shows the generated violation report:

{
    "violations": [
        {
            "feature_name": "sent_cosine_score",
            "constraint_check_type": "baseline_drift_check",
            "sentence": "Thanks so much for driving me home",
            "avg_cosine_score": 0.36653404209142876
        },
        {
            "feature_name": "sent_cosine_score",
            "constraint_check_type": "baseline_drift_check",
            "sentence": "Thanks so much for cooking dinner. I really appreciate it",
            "avg_cosine_score": 0.34974955975723576
        },
        {
            "feature_name": "sent_cosine_score",
            "constraint_check_type": "baseline_drift_check",
            "sentence": "Nice to meet you, Sergio. So, where are you from",
            "avg_cosine_score": 0.378982806084463
        }
    ]
}

Finally, based on this observation, you can configure your model for retraining. You can also enable Amazon Simple Notification Service (Amazon SNS) notifications to send alerts when violations occur.
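After downloading a report from the destination S3 path, a few lines of Python are enough to rank the drifted sentences, for example before deciding whether to retrain. This sketch assumes the report has already been read into a string; the helper name is hypothetical:

```python
import json


def rank_drifted_sentences(report_json):
    """Return (sentence, avg_cosine_score) pairs from a violation report, most drifted first."""
    report = json.loads(report_json)
    pairs = [
        (v["sentence"], v["avg_cosine_score"])
        for v in report.get("violations", [])
        if v.get("constraint_check_type") == "baseline_drift_check"
    ]
    return sorted(pairs, key=lambda pair: pair[1])  # lowest average cosine score first


sample_report = json.dumps({"violations": [
    {"feature_name": "sent_cosine_score", "constraint_check_type": "baseline_drift_check",
     "sentence": "Thanks so much for driving me home", "avg_cosine_score": 0.36653404209142876},
    {"feature_name": "sent_cosine_score", "constraint_check_type": "baseline_drift_check",
     "sentence": "Thanks so much for cooking dinner. I really appreciate it", "avg_cosine_score": 0.34974955975723576},
]})
for sentence, score in rank_drifted_sentences(sample_report):
    print(f"{score:.3f}  {sentence}")
```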


Model Monitor enables you to maintain the high quality of your models in production. In this post, we highlighted the challenges of monitoring data drift on unstructured data like text, and provided an intuitive approach to detect data drift using a custom monitoring script. You can find the code associated with the post in the following GitHub repository. Additionally, you can customize the solution to use other distance metrics, such as maximum mean discrepancy (MMD), a non-parametric metric that measures the distance between the source and target marginal distributions in the embedding space.
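As a starting point for that customization, the following sketch computes a biased estimate of squared MMD between a set of baseline embeddings and a window of inference embeddings using a Gaussian (RBF) kernel. The kernel parameter `gamma`, the function names, and any threshold you compare the result against are assumptions to tune for your data. Unlike the per-sentence cosine check, MMD scores a whole batch at once:

```python
import math


def rbf_kernel(u, v, gamma=1.0):
    """Gaussian (RBF) kernel: exp(-gamma * ||u - v||^2)."""
    sq_dist = sum((a - b) ** 2 for a, b in zip(u, v))
    return math.exp(-gamma * sq_dist)


def mmd_squared(xs, ys, gamma=1.0):
    """Biased estimate of squared MMD between samples xs and ys."""
    def mean_kernel(a, b):
        return sum(rbf_kernel(p, q, gamma) for p in a for q in b) / (len(a) * len(b))

    return mean_kernel(xs, xs) + mean_kernel(ys, ys) - 2 * mean_kernel(xs, ys)


baseline = [[0.0, 0.0], [0.0, 1.0], [1.0, 0.0]]
similar = [[0.1, 0.0], [0.0, 1.1], [1.0, 0.1]]
shifted = [[5.0, 5.0], [5.0, 6.0], [6.0, 5.0]]
print(mmd_squared(baseline, similar) < mmd_squared(baseline, shifted))  # True
```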

About the Authors

Vikram Elango is an AI/ML Specialist Solutions Architect at Amazon Web Services, based in Virginia, USA. Vikram helps financial and insurance industry customers with design and thought leadership to build and deploy machine learning applications at scale. He is currently focused on natural language processing, responsible AI, inference optimization, and scaling ML across the enterprise. In his spare time, he enjoys traveling, hiking, cooking, and camping with his family.

Raghu Ramesha is an ML Solutions Architect with the Amazon SageMaker Service team. He focuses on helping customers migrate ML production workloads to SageMaker at scale. He specializes in machine learning, AI, and computer vision domains, and holds a master’s degree in Computer Science from UT Dallas. In his free time, he enjoys traveling and photography.

Tony Chen is a Machine Learning Solutions Architect at Amazon Web Services, helping customers design scalable and robust machine learning capabilities in the cloud. As a former data scientist and data engineer, he leverages his experience to help tackle some of the most challenging problems organizations face with operationalizing machine learning.


Computer vision-based anomaly detection using Amazon Lookout for Vision and AWS Panorama

This is the second post in a two-part series on how Tyson Foods, Inc. is using computer vision applications at the edge to automate industrial processes inside its meat processing plants. In Part 1, we discussed an inventory counting application for packaging lines built with Amazon SageMaker and AWS Panorama. In this post, we discuss a vision-based anomaly detection solution at the edge for predictive maintenance of industrial equipment.

Operational excellence is a key priority at Tyson Foods. Predictive maintenance is an essential asset for achieving this objective by continuously improving overall equipment effectiveness (OEE). In 2021, Tyson Foods launched a machine learning (ML) based computer vision project to identify failing product carriers during production to prevent them from impacting team member safety, operations, or product quality. When a product carrier breaks or moves into the wrong position, production must be stopped. If it’s not caught in time, it poses a threat to team member safety and machinery. With a manual inspection method, an operator inspects 8,000 pins per line. This is a slow and challenging task because attention to detail is critical. ML practitioners at Tyson Foods have built computer vision models to automate the inspection process and detect anomalies continuously. This process can enable the maintenance team to reduce the cycle time and improve the reliability of inspecting 8,000 pins.

Developing a custom ML model to analyze images and detect anomalies, and making it run efficiently at the edge, is a challenging task that requires specialized expertise, time, and resources; the entire development cycle may take months to complete. With the approaches described in Part 1 of this series, we completed the project for monitoring the condition of the product carriers at Tyson Foods in record time using AWS managed services such as Amazon Lookout for Vision.

Solution overview

The patterns, code, and infrastructure designed for the tray counting use case in Part 1 were readily replicated in the product carrier project. Although at first glance these projects may seem very different, at their core they are made up of the same five components: image capture, labeling, model training, frame deduplication, and inference.

This post demonstrates how to set up a computer vision-based anomaly detection solution for failing product carriers (or similar manufacturing line assemblies) using AWS Panorama and Lookout for Vision. The workflow begins with inference by an object detection model on an AWS Panorama device at the edge. The image is cropped to the detected bounding boxes, and the crops are passed to the Lookout for Vision anomaly detection model, which classifies the pin images. The anomalous pin images and model results are sent to the cloud, where they are available for additional processing.

The following diagram illustrates this architecture.


To follow along with this post, you need the following:

Train an object detection model

The first stage of our multi-model inference design is an SSD object detection model trained to detect product carriers (pins) and flags. The detected pin images are used to train the anomaly classification model using Lookout for Vision. The flag, which marks the beginning of the product carrier line, helps us track each loop cycle and deduplicate anomaly detections.

The following image is an example inference result from the pin detection SSD model.

Train an anomaly classification model using Lookout for Vision

Lookout for Vision is a fully managed ML service that uses computer vision to help identify visual defects in objects. It allows you to build an anomaly detection model quickly with little-to-no code and requires very little data to start (minimum 20 normal and 10 anomaly images). Training a Lookout for Vision model follows a four-step process:

  1. Create a Lookout for Vision project.
  2. Build a product carrier dataset.
  3. Train and tune the Lookout for Vision model.
  4. Export the Lookout for Vision model for inference.

In this section, we walk you through Steps 1–3.

Create a Lookout for Vision project

For instructions on creating a Lookout for Vision project, see Creating your project.

Build a product carrier dataset

The Lookout for Vision dataset has to consist of square images in JPG or PNG format, with a minimum size of 64 x 64 pixels and a maximum size of 4096 x 4096 pixels. To generate a dataset that satisfies these requirements, we crop each bounding box and resize it while preserving the original aspect ratio (padding the remaining area), using the following Python code. We add this code to the image capture pipeline described in Part 1 to generate the final 150 x 150 pixel images for Lookout for Vision.

import cv2
import numpy as np

def crop_n_resize_image(self, img, bbox, size, padColor=0):

    # crop the bounding box (bbox = [x_min, y_min, x_max, y_max]) ==============
    crop = img[bbox[1]:bbox[3], bbox[0]:bbox[2]].copy()
    # cropped image size
    h, w = crop.shape[:2]
    # desired output image size
    sh, sw = size

    # interpolation method
    if h > sh or w > sw:  # shrinking image
        interp = cv2.INTER_AREA
    else:  # stretching image
        interp = cv2.INTER_CUBIC

    # aspect ratio of image
    aspect = w / h

    # compute scaling and pad sizing
    if aspect > 1:  # horizontal image
        new_w = sw
        new_h = int(np.round(new_w / aspect))
        pad_vert = (sh - new_h) / 2
        pad_top, pad_bot = int(np.floor(pad_vert)), int(np.ceil(pad_vert))
        pad_left, pad_right = 0, 0
    elif aspect < 1:  # vertical image
        new_h = sh
        new_w = int(np.round(new_h * aspect))
        pad_horz = (sw - new_w) / 2
        pad_left, pad_right = int(np.floor(pad_horz)), int(np.ceil(pad_horz))
        pad_top, pad_bot = 0, 0
    else:  # square image
        new_h, new_w = sh, sw
        pad_left, pad_right, pad_top, pad_bot = 0, 0, 0, 0

    # set pad color
    if len(img.shape) == 3 and not isinstance(padColor, (list, tuple, np.ndarray)):  # color image but only one color provided
        padColor = [padColor] * 3

    # scale and pad
    scaled_img = cv2.resize(crop, (new_w, new_h), interpolation=interp)
    scaled_img = cv2.copyMakeBorder(scaled_img, pad_top, pad_bot, pad_left, pad_right, borderType=cv2.BORDER_CONSTANT, value=padColor)

    return scaled_img
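The resize-and-pad arithmetic in this function can be checked without OpenCV. The following sketch, with a hypothetical helper name, reproduces only the size and padding calculation. For a 150 x 150 target, a crop 100 pixels wide and 50 pixels high is scaled to 150 x 75 and padded with 37 rows on top and 38 on the bottom:

```python
import math


def letterbox_geometry(crop_h, crop_w, target_h, target_w):
    """Return (new_h, new_w, pad_top, pad_bot, pad_left, pad_right), mirroring the
    branches of crop_n_resize_image. Python's round() and np.round both round
    halves to even, so the results match."""
    aspect = crop_w / crop_h
    if aspect > 1:  # wider than tall: fit the width, pad top and bottom
        new_w = target_w
        new_h = round(new_w / aspect)
        pad = (target_h - new_h) / 2
        return new_h, new_w, math.floor(pad), math.ceil(pad), 0, 0
    if aspect < 1:  # taller than wide: fit the height, pad left and right
        new_h = target_h
        new_w = round(new_h * aspect)
        pad = (target_w - new_w) / 2
        return new_h, new_w, 0, 0, math.floor(pad), math.ceil(pad)
    return target_h, target_w, 0, 0, 0, 0  # square crop: no padding needed


print(letterbox_geometry(50, 100, 150, 150))  # (75, 150, 37, 38, 0, 0)
```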

The following are examples of processed product carrier images.

We label the images through Amazon SageMaker Ground Truth, which returns a label manifest file. This file is imported into Lookout for Vision to create the anomaly detection dataset. You can label the images within the Lookout for Vision platform, but we didn’t use that approach in this project. The following screenshot shows the labeled dataset on the Lookout for Vision console.

Train and tune the Lookout for Vision model

Training an anomaly detection model in Lookout for Vision takes a single click. Lookout for Vision automatically holds out 20% of the data as a test set to validate the model performance. The key to good model results is labeling and image quality. The initial image size was too small, and critical details were lost at that resolution; increasing the resolution from 64 x 64 to 150 x 150 resulted in a significant jump in model accuracy. To tune the labels, the development team spent a significant amount of time with subject matter experts from the plant, drawing on their knowledge to design the definitions for each class. It was imperative that these class definitions be very clear, and it took a few iterations to get them right. The following screenshot shows the results achieved with well-established class definitions.

Develop the AWS Panorama application

The AWS Panorama application is an inference container deployed to the AWS Panorama Appliance to process input video streams, run inference, and output video results using the AWS Panorama SDK. Most of the inference code is the same as in Part 1; the following features are added specifically for this product carrier use case:

  • Build a frame inference trigger
  • Run Lookout for Vision inference
  • Isolate pin locations and deduplicate anomaly detections

Build a frame inference trigger

For this use case, our product carriers move continuously across the video frame, and the same pins may be detected repeatedly until they move out of the camera view. To avoid sending duplicate pins to the Lookout for Vision model for anomaly classification and wasting compute resources, we developed a software trigger in our inference code that downsamples the frames and reduces the number of duplicate pins sent for inference. In the following screenshot, the minimum number of pins detected is 8 and the maximum is 10.

The trigger logic uses product carrier IDs, a counter of the new product carriers moving into the camera view. We increment this counter when the number of bounding boxes in a frame reaches the maximum value. As shown in the preceding figure, there is a minimum and a maximum number of bounding boxes detected at any given time. The count oscillates between the minimum and maximum values, and each oscillation corresponds to a new product carrier moving into the camera view. The following figure illustrates the oscillation pattern. Because a camera frame can only fit six product carriers, we know an entire frame has shifted off when the product carrier ID has incremented by 6.
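A minimal sketch of this trigger logic follows. The class name, the rising-edge rule used to increment the carrier ID, and the default values (a maximum of 10 boxes from the screenshot, six carriers per frame) are our assumptions about how the description could be implemented, not the production code:

```python
class FrameTrigger:
    """Hypothetical frame-downsampling trigger.

    The product carrier ID increments each time the per-frame box count rises
    to its maximum (a new carrier entering the view). A frame is sent for
    inference once every `carriers_per_frame` increments, that is, after a
    whole frame's worth of carriers has shifted out of view.
    """

    def __init__(self, max_boxes=10, carriers_per_frame=6):
        self.max_boxes = max_boxes
        self.carriers_per_frame = carriers_per_frame
        self.carrier_id = 0
        self._last_trigger_id = 0
        self._at_max = False

    def update(self, num_boxes):
        """Feed one frame's box count; return True if this frame should be sent for inference."""
        reached_max = num_boxes >= self.max_boxes
        if reached_max and not self._at_max:  # rising edge: a new carrier entered the view
            self.carrier_id += 1
        self._at_max = reached_max
        if self.carrier_id - self._last_trigger_id >= self.carriers_per_frame:
            self._last_trigger_id = self.carrier_id
            return True
        return False


counts = [8, 10] * 13  # box counts oscillating between the minimum and maximum
trigger = FrameTrigger()
fired = [i for i, c in enumerate(counts) if trigger.update(c)]
print(fired)  # [11, 23]
```

Counting only rising edges keeps several consecutive frames at the maximum from being counted as several new carriers.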

Run Lookout for Vision inference

We crop the bounding boxes from the frame image and process them using the same resize function described earlier, and then forward these images to the Lookout for Vision model for anomaly classification. In response, the Lookout for Vision model produces a label (normal or anomaly) and confidence score.

Isolate pin locations and deduplicate anomaly detections

Lastly, it was important to identify the relative location of each product carrier and to generate only one entry per bad pin to avoid duplicates. To track pin locations, the inference code uses the flag as a point of reference and counts the product carrier ID. When an anomaly is detected, the product carrier ID is recorded with the pin image to provide the location relative to the flag. We also use the flag to deduplicate anomaly detections and to track when the entire product carrier line has looped around: a cycle ID parameter is incremented every time the flag appears, and per-cycle parameters such as the product carrier ID are reset to 0 to start a new cycle.
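The bookkeeping described above can be sketched as a small state object. This is an illustration under our own assumptions (the class and method names are hypothetical), reporting each anomalous carrier at most once per cycle:

```python
class PinTracker:
    """Hypothetical cycle/carrier bookkeeping for deduplicating anomaly detections.

    The flag starts a new cycle: the cycle ID increments and the product carrier
    ID resets to 0. Each anomalous carrier is reported at most once per cycle.
    """

    def __init__(self):
        self.cycle_id = 0
        self.carrier_id = 0
        self._reported = set()  # carrier IDs already reported in the current cycle

    def on_flag(self):
        """The flag marks the start of the product carrier line."""
        self.cycle_id += 1
        self.carrier_id = 0
        self._reported.clear()

    def on_new_carrier(self):
        self.carrier_id += 1

    def report_anomaly(self):
        """Return (cycle_id, carrier_id) for a new anomaly, or None if it is a duplicate."""
        if self.carrier_id in self._reported:
            return None
        self._reported.add(self.carrier_id)
        return (self.cycle_id, self.carrier_id)


tracker = PinTracker()
tracker.on_flag()
tracker.on_new_carrier()
tracker.on_new_carrier()
print(tracker.report_anomaly())  # (1, 2)
print(tracker.report_anomaly())  # None: same carrier, same cycle
```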

Deploy models at the edge with AWS Panorama

When we have the models and the inference code ready, we package the object detection model, inference code, and camera stream into a container and deploy to AWS Panorama using the same deployment pattern described in Part 1.

Email alerts

Whenever the system detects an anomaly, the image containing the defective pin is sent to Amazon S3 for storage, and the associated metadata is sent to AWS IoT SiteWise. At the end of each shift, an Amazon EventBridge event triggers an AWS Lambda function, which uses the images and metadata to send a summary email to the plant staff. The plant staff uses this information when making repairs during shift change.
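The summary email could be assembled in the Lambda function from the stored metadata. The following is a sketch of the formatting step only, assuming each record carries at least a `carrier_id` field (a hypothetical schema). Sending the message, for example through Amazon SES or Amazon SNS, is left out:

```python
from collections import Counter


def build_shift_summary(records):
    """Summarize one shift's anomaly records as plain email text, most-flagged carriers first."""
    counts = Counter(record["carrier_id"] for record in records)
    lines = [f"{len(records)} anomaly detections across {len(counts)} product carriers this shift."]
    for carrier_id, n in sorted(counts.items(), key=lambda item: (-item[1], item[0])):
        lines.append(f"Product carrier {carrier_id}: flagged in {n} cycle(s)")
    return "\n".join(lines)


print(build_shift_summary([{"carrier_id": 17}, {"carrier_id": 3}, {"carrier_id": 17}]))
```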


In this post, we demonstrated how to set up a vision-based anomaly detection system in a production environment using Lookout for Vision and AWS Panorama. With this solution, plants can save 1 hour of team member time per day per line. This would save this plant alone an estimated 15,000 hours of skilled labor annually. This would free up the time of valuable Tyson team members to complete other, more complex tasks.

The models trained in this process performed well. The SSD pin detection model achieved 95% accuracy across both classes, and the Lookout for Vision model was tuned to 99.1% accuracy for failing pin detection. Although the project uses two models, the inference code easily kept up with line speed, running at around 10 FPS.

By far the most exciting result of this project was the speedup in development time. Although this project utilizes two models and more complex application code than the project in Part 1, it took 12% less developer time to complete. This agility is only possible because of the repeatable patterns established in Part 1 and using managed services from AWS. This combination made our final solutions faster to scale and industry ready. Learn more about Amazon Lookout for Vision by going to the Amazon Lookout for Vision Resources page. You can also view other examples of AWS Panorama in action by going to the GitHub repo.

About the Authors

Audrey Timmerman is a Sr Applications Developer at Tyson Foods. She is a Computer Engineering Graduate from the University of Arkansas and has been on the Emerging Technology team at Tyson Foods for 2 years. She has an interest in computer vision, machine learning, and IoT applications.

James Wu is a Senior Customer Solutions Manager at AWS, based in Dallas, TX. He works with customers to accelerate their cloud journey and fast-track their business value realization. In addition, James is passionate about developing and scaling large AI/ML solutions across various domains. Prior to joining AWS, he led a multi-discipline innovation technology team with ML engineers and software developers for a top global firm in the market and advertising industry.

Farooq Sabir is a Senior AI/ML Specialist Solutions Architect at AWS. He holds a PhD in Electrical Engineering from the University of Texas at Austin. He helps customers solve their business problems using data science, machine learning, artificial intelligence, and numerical optimization.

Elizabeth Samara Rubio is a Principal Specialist in the WWSO at Amazon Web Services, driving new AI/ML and computer vision solutions across industries, including industrial and manufacturing sectors. Prior to joining Amazon, Elizabeth was a Managing Director at Accenture leading North America Industry X growth and strategy, Divisional Vice President at AMETEK, and Business Unit Manager at Cognex.

Shreyas Subramanian is an AI/ML specialist Solutions Architect, and helps customers by using Machine Learning to solve their business challenges on the AWS Cloud.


Label text for aspect-based sentiment analysis using SageMaker Ground Truth

The Amazon Machine Learning Solutions Lab (MLSL) recently created a tool for annotating text with named-entity recognition (NER) and relationship labels using Amazon SageMaker Ground Truth. Annotators use this tool to label text with named entities and link their relationships, thereby building a dataset for training state-of-the-art natural language processing (NLP) machine learning (ML) models. Most importantly, the tool is now publicly available to all AWS customers.

Customer Use Case: Booking.com

Booking.com is one of the world’s leading online travel platforms. Understanding what customers are saying about the company’s 28 million+ property listings on the platform is essential for maintaining a top-notch customer experience. Previously, Booking.com could only use traditional sentiment analysis to interpret customer-generated reviews at scale. To make these interpretations more specific, Booking.com recently turned to the MLSL for help with building a custom annotated dataset for training an aspect-based sentiment analysis model.

Traditional sentiment analysis is the process of classifying a piece of text as positive, negative, or neutral as a singular sentiment. This works to broadly understand if users are satisfied or unsatisfied with a particular experience. For example, with traditional sentiment analysis, the following text may be classified as “neutral”:

Our stay at the hotel was nice. The staff was friendly and the rooms were clean, but our beds were quite uncomfortable.

Aspect-based sentiment analysis offers a more nuanced understanding of content. In the case of Booking.com, rather than taking a customer review as a whole and classifying it categorically, it can take sentiment from within a review and assign it to specific aspects. For example, customer reviews of a given hotel might praise the immaculate pool and fitness area, but give critical feedback on the restaurant and lounge.

With aspect-based sentiment analysis, the statement that traditional sentiment analysis would have classified as “neutral” becomes:

Our stay at the hotel was nice. The staff was friendly and the rooms were clean, but our beds were quite uncomfortable.

  • Hotel: Positive
  • Staff: Positive
  • Room: Positive
  • Beds: Negative

Booking.com sought to build a custom aspect-based sentiment analysis model that would tell them which specific parts of the guest experience (from a list of 50+ aspects) were positive, negative, or neutral.

Before Booking.com could build a training dataset for this model, they needed a way to annotate it. MLSL’s annotation tool provided the much-needed customized solution. Human review was performed on a large collection of hotel reviews. Then, annotators completed named-entity annotation on sentiment and guest-experience text spans and phrases before linking appropriate spans together.

The new aspect-based model lets Booking.com personalize both accommodations and reviews to its customers. Highlighting the positive and negative aspects of each accommodation enables the customers to choose their perfect match. In addition, different customers care about different aspects of the accommodation, and the new model opens up the opportunity to show the most relevant reviews to each one.

Labeling Requirements

Although Ground Truth provides a built-in NER text annotation capability, it doesn’t provide the ability to link entities together. With this in mind, Booking.com and MLSL worked out the following high-level requirements for a new named entity recognition text labeling tool that:

  • Accepts as input: text, entity labels, relationship labels, and classification labels.
  • Optionally accepts as input pre-annotated data with the preceding label and relationship annotations.
  • Presents the annotator with either unannotated or pre-annotated text.
  • Allows annotators to highlight and annotate arbitrary text with an entity label.
  • Allows annotators to create relationships between two entity annotations.
  • Allows annotators to easily navigate large numbers of entity labels.
  • Supports grouping entity labels into categories.
  • Allows overlapping relationships, which means that the same annotated text segment can be related to more than one other annotated text segment.
  • Allows overlapping entity label annotations, which means that two annotations can overlap the same piece of text. For example, the text “Seattle Space Needle” can have both the annotations “Seattle” → “locations”, and “Seattle Space Needle” → “attractions”.
  • Output format is compatible with input format, and it can be fed back into subsequent labeling tasks.
  • Supports UTF-8 encoded text containing emoji and other multi-byte characters.
  • Supports left-to-right languages.

Sample Annotation

Consider the following document:

We loved the location of this hotel! The rooftop lounge gave us the perfect view of space needle. It is also a short drive away from pike place market and the waterfront.
Food was only available via room service, which was a little disappointing but makes sense in this post-pandemic world.
Overall, a reasonably priced experience.

Loading this document into the new NER annotation tool presents a worker with the following interface:

Worker presented with an unannotated document

In this case, the worker’s job is to:

  • Label entities related to the property (location, price, food, etc.)
  • Label entities related to sentiment (positive, negative, or neutral)
  • Link property-related named entities to sentiment-related keywords to accurately capture the guest experience

Worker performing annotations

Annotation speed was an important consideration of the tool. Using a sequence of intuitive keyboard shortcuts and mouse gestures, annotators can drive the interface and:

  • Add and remove named entity annotations
  • Add relationships between named entities
  • Jump to the beginning and end of the document
  • Submit the document

Additionally, the tool supports overlapping labels. For example, in the phrase “Seattle Space Needle”, “Seattle” is annotated both as a location on its own and as part of the attraction name.

The completed annotation provides a more complete, nuanced analysis of the data:

Completed document

Relationships can be configured at many levels, from entity categories to other entity categories (for example, from “food” to “sentiment”), or between individual entity types. Relationships are directed, so annotators can link an aspect like food to a sentiment, but not vice versa (unless explicitly enabled). When annotators draw a relationship, the tool automatically deduces the relationship label and direction.
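The directed rule can be illustrated with a small check against an allowedRelationships list (its fields are described in the input format later in this post). This is a sketch of one plausible matching rule, not the tool’s source: items in the list are OR’ed, and within an item we assume a source matches if either its label or its category is listed (likewise for the target). The category names are hypothetical:

```python
def relationship_allowed(allowed_relationships, source, target):
    """source/target: dicts with a 'label' and an optional 'category'."""
    if not allowed_relationships:  # assumption: an omitted list places no restriction
        return True
    for rule in allowed_relationships:  # items are OR'ed together
        source_ok = (source["label"] in rule.get("sourceEntityLabels", [])
                     or source.get("category") in rule.get("sourceEntityLabelCategories", []))
        target_ok = (target["label"] in rule.get("targetEntityLabels", [])
                     or target.get("category") in rule.get("targetEntityLabelCategories", []))
        if source_ok and target_ok:
            return True
    return False


aspect_to_sentiment = [{"sourceEntityLabelCategories": ["Aspect"],
                        "targetEntityLabelCategories": ["Sentiment"]}]
food = {"label": "food", "category": "Aspect"}
negative = {"label": "negative", "category": "Sentiment"}
print(relationship_allowed(aspect_to_sentiment, food, negative))  # True
print(relationship_allowed(aspect_to_sentiment, negative, food))  # False: wrong direction
```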

Configuring the NER Annotation Tool

In this section, we cover how to customize the NER annotation tool for customer-specific use cases. This includes configuring:

  • The input text to annotate
  • Entity labels
  • Relationship labels
  • Classification labels
  • Pre-annotated data
  • Worker instructions

We’ll cover the specifics of the input and output document formats, as well as provide some examples of each.

Input Document Format

The NER annotation tool expects the following JSON-formatted input document (fields with a question mark next to the name are optional):

{
  text: string;
  tokenRows?: string[][];
  documentId?: string;
  entityLabels?: {
    name: string;
    shortName?: string;
    category?: string;
    shortCategory?: string;
    color?: string;
  }[];
  classificationLabels?: string[];
  relationshipLabels?: {
    name: string;
    allowedRelationships?: {
      sourceEntityLabelCategories?: string[];
      targetEntityLabelCategories?: string[];
      sourceEntityLabels?: string[];
      targetEntityLabels?: string[];
    }[];
  }[];
  entityAnnotations?: {
    id: string;
    start: number;
    end: number;
    text: string;
    label: string;
    labelCategory?: string;
  }[];
  relationshipAnnotations?: {
    sourceEntityAnnotationId: string;
    targetEntityAnnotationId: string;
    label: string;
  }[];
  classificationAnnotations?: string[];
  meta?: {
    instructions?: string;
    disableSubmitConfirmation?: boolean;
    multiClassification?: boolean;
  };
}

In a nutshell, the input format has these characteristics:

  • Either entityLabels or classificationLabels (or both) are required to annotate.
  • If entityLabels are given, then relationshipLabels can be added.
  • Relationships can be allowed between different entity/category labels or a mix of these.
  • The “source” of a relationship is the entity that the directed arrow starts with, while the “target” is where it’s heading.
Field Type Description
text string Required. Input text for annotation.
tokenRows string[][] Optional. Custom tokenization of input text. Array of arrays of strings. Top level array represents each row of text (line breaks), and second level array represents tokens on each row. All characters/runes in the input text must be accounted for in tokenRows, including any white space.
documentId string Optional. Optional value for customers to keep track of document being annotated.
entityLabels object[] Required if classificationLabels is blank. Array of entity labels.
entityLabels[].name string Required. Entity label display name.
entityLabels[].category string Optional. Entity label category name.
entityLabels[].shortName string Optional. Display this text over annotated entities rather than the full name.
entityLabels[].shortCategory string Optional. Display this text in the entity annotation select dropdown instead of the first four letters of the category name.
entityLabels.color string Optional. Hex color code with “#” prefix. If blank, then it will automatically assign a color to the entity label.
relationshipLabels object[] Optional. Array of relationship labels.
relationshipLabels[].name string Required. Relationship label display name.
relationshipLabels[].allowedRelationships object[] Optional. Array of values restricting what types of source and destination entity labels this relationship can be assigned to. Each item in array is “OR’ed” together.
relationshipLabels[].allowedRelationships[].sourceEntityLabelCategories string[] Required to set either sourceEntityLabelCategories or sourceEntityLabels (or both). List of legal source entity label category types for this relationship.
relationshipLabels[].allowedRelationships[].targetEntityLabelCategories string[] Required to set either targetEntityLabelCategories or targetEntityLabels (or both). List of legal target entity label category types for this relationship.
relationshipLabels[].allowedRelationships[].sourceEntityLabels string[] Required to set either sourceEntityLabelCategories or sourceEntityLabels (or both). List of legal source entity label types for this relationship.
relationshipLabels[].allowedRelationships[].sourceEntityLabels string[] Required to set either targetEntityLabelCategories or targetEntityLabels (or both). List of legal target entity label types for this relationship.
classificationLabels string[] Required if entityLabels is blank. List of document level classification labels.
entityAnnotations object[] Optional. Array of entity annotations to pre-annotate input text with.
entityAnnotations[].id string Required. Unique identifier for this entity annotation. Used to reference this entity in relationshipAnnotations.
entityAnnotations[].start number Required. Start rune offset of this entity annotation.
entityAnnotations[].end number Required. End rune offset of this entity annotation.
entityAnnotations[].text string Required. Text content between start and end rune offset.
entityAnnotations[].label string Required. Associated entity label name (from the names in entityLabels).
entityAnnotations[].labelCategory string Optional. Associated entity label category (from the categories in entityLabels).
relationshipAnnotations object[] Optional. Array of relationship annotations.
relationshipAnnotations[].sourceEntityAnnotationId string Required. Source entity annotation ID for this relationship.
relationshipAnnotations[].targetEntityAnnotationId string Required. Target entity annotation ID for this relationship.
relationshipAnnotations[].label string Required. Associated relationship label name.
classificationAnnotations string[] Optional. Array of classifications to pre-annotate the document with.
meta object Optional. Additional configuration parameters.
meta.instructions string Optional. Instructions for the labeling annotator in Markdown format.
meta.disableSubmitConfirmation boolean Optional. Set to true to disable submit confirmation modal.
meta.multiClassification boolean Optional. Set to true to enable multi-label mode for classificationLabels.

Here are a few sample documents to get a better sense of this input format

Documents that adhere to this schema are provided to Ground Truth as individual line items in an input manifest.

Output Document Format

The output format is designed to feed back easily into a new annotation task. Optional fields in the output document are set if they are also set in the input document. The only difference between the input and output formats is the meta object.

{
  text: string;
  tokenRows?: string[][];
  documentId?: string;
  entityLabels?: {
    name: string;
    shortName?: string;
    category?: string;
    shortCategory?: string;
    color?: string;
  }[];
  relationshipLabels?: {
    name: string;
    allowedRelationships?: {
      sourceEntityLabelCategories?: string[];
      targetEntityLabelCategories?: string[];
      sourceEntityLabels?: string[];
      targetEntityLabels?: string[];
    }[];
  }[];
  classificationLabels?: string[];
  entityAnnotations?: {
    id: string;
    start: number;
    end: number;
    text: string;
    labelCategory?: string;
    label: string;
  }[];
  relationshipAnnotations?: {
    sourceEntityAnnotationId: string;
    targetEntityAnnotationId: string;
    label: string;
  }[];
  classificationAnnotations?: string[];
  meta: {
    instructions?: string;
    disableSubmitConfirmation?: boolean;
    multiClassification: boolean;
    runes: string[];
    rejected: boolean;
    rejectedReason: string;
  };
}

Field Type Description
meta.rejected boolean Is set to true if the annotator rejected this document.
meta.rejectedReason string Annotator’s reason given for rejecting the document.
meta.runes string[] Array of runes accounting for all of the characters in the input text. Used to calculate entity annotation start and end offsets.

Here is a sample output document that’s been annotated:

Runes note:

A “rune” in this context is a single highlight-able character in text, including multi-byte characters such as emoji.

  • Because different programming languages represent multi-byte characters differently, using “Runes” to define every highlight-able character as a single atomic element means that we have an unambiguous way to describe any given text selection.
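  • For example, Python 3’s len() counts the Swedish flag (🇸🇪) as two characters, because the flag consists of two Unicode code points.

    JavaScript’s String.length, however, counts the same emoji as four characters, because each of those code points takes two UTF-16 code units.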

To eliminate any ambiguity, we will treat the Swedish flag (and all other emoji and multi-byte characters) as a single atomic element.

  • Offset: Rune position relative to Input Text (starting with index 0)
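To make the offsets concrete, here is a small Python sketch. It shows why raw string lengths disagree across languages and how an annotation’s text could be rebuilt from meta.runes. The sample sentence, the runes array, and the assumption that end is exclusive are illustrative, not taken from the tool itself.

```python
# The Swedish flag is two regional-indicator code points: U+1F1F8 U+1F1EA
flag = "\U0001F1F8\U0001F1EA"


def utf16_length(s: str) -> int:
    """Count UTF-16 code units, which is what JavaScript's String.length reports."""
    return len(s.encode("utf-16-le")) // 2


def rune_text(runes: list, start: int, end: int) -> str:
    """Rebuild an annotation's text from rune offsets (end assumed exclusive)."""
    return "".join(runes[start:end])


text = "I love " + flag + "!"
# Hypothetical meta.runes for this text: the flag is a single atomic rune
runes = list("I love ") + [flag, "!"]

print(len(flag), utf16_length(flag))   # 2 4
print(len(text), len(runes))           # 10 9
print(rune_text(runes, 2, 6))          # love
print(rune_text(runes, 7, 8) == flag)  # True
```

Because every rune is one array element, offset 7 points at the whole flag regardless of how the host language counts characters.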

Performing NER Annotations with Ground Truth

As a fully managed data labeling service, Ground Truth builds training datasets for ML. For this use case, we use Ground Truth to send a collection of text documents to a pool of workers for annotation. Finally, we review the results for quality.

Ground Truth can be configured to build a data labeling job using the new NER tool as a custom template.

Specifically, we will:

  1. Create a private labeling workforce of workers to perform the annotation task
  2. Create a Ground Truth input manifest with the documents we want to annotate and then upload it to Amazon Simple Storage Service (Amazon S3)
  3. Create pre-labeling task and post-labeling task Lambda functions
  4. Create a Ground Truth labeling job using the custom NER template
  5. Annotate documents
  6. Review results

NER Tool Resources

A complete list of referenced resources and sample documents can be found in the following table:

Description Filename
Production custom worker task template worker-template.liquid.html
Sample Ground Truth Pre-Labeling Lambda smgt-ner-pre-labeling-task-lambda.py
Sample Ground Truth Post-Labeling Lambda smgt-ner-post-labeling-task-lambda.py
Sample Input Document #1 (pre-labeled) review-01.json
Sample Input Document #2 (pre-labeled) review-02.json
Sample Input Document #3 (custom tokenization) review-03.json
Sample Input Document #4 (Document classification) review-04.json
Sample Ground Truth Input Manifest reviews.manifest
Output for Sample Input Document #1 review-01-output.json

Labeling Workforce Creation

Ground Truth uses SageMaker labeling workforces to manage workers and distribute tasks. Create a private workforce, a worker team called ner-worker-team, and assign yourself to the team using the instructions found in Create a Private Workforce (Amazon SageMaker Console).

Once you’ve added yourself to a private workforce and confirmed your email, note the worker portal URL from the AWS Management Console:

  • Navigate to SageMaker
  • Navigate to Ground Truth → Labeling workforces
  • Select the Private tab
  • Note the Labeling portal sign-in URL

Log in to the worker portal to view and start work on labeling tasks.

Input Manifest

The Ground Truth input data manifest is a JSON-lines file where each line contains a single worker task. In our case, each line contains a single JSON-encoded input document containing the text that we want to annotate and the NER annotation schema.

Download a sample input manifest reviews.manifest from https://assets.solutions-lab.ml/NER/0.2.1/sample-data/reviews.manifest

Note: each row in the input manifest needs a top-level key source or source-ref. You can learn more in Use an Input Manifest File in the Amazon SageMaker Developer Guide.
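As an illustration, the following Python sketch serializes input documents into manifest rows. It assumes each row stores the JSON-encoded Input Document as a string under the top-level source key; check this against the downloaded reviews.manifest before relying on it. The document content is made up.

```python
import json

# Hypothetical Input Document following the schema described earlier
documents = [
    {
        "documentId": "review-example",
        "text": "The room was clean and the staff were friendly.",
        "entityLabels": [
            {"name": "Room", "category": "Property"},
            {"name": "Staff", "category": "Service"},
        ],
    }
]


def to_manifest(docs) -> str:
    """Return JSON Lines text: one {"source": "<JSON-encoded document>"} object per line."""
    return "".join(json.dumps({"source": json.dumps(doc)}) + "\n" for doc in docs)


manifest_text = to_manifest(documents)
print(manifest_text)
```

Saving manifest_text to a file gives you a manifest that can be uploaded in the same way as reviews.manifest.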

Upload Input Manifest to Amazon S3

Upload this input manifest to an S3 bucket using the AWS Management Console or from the command line, replacing your-bucket with an actual bucket name.

aws s3 cp reviews.manifest s3://your-bucket/ner-input/reviews.manifest

Download custom worker template

Download the NER tool custom worker template from https://assets.solutions-lab.ml/NER/0.2.1/worker-template.liquid.html by viewing the source and saving the contents locally, or from the command line:

wget https://assets.solutions-lab.ml/NER/0.2.1/worker-template.liquid.html

Create pre-labeling task and post-labeling task Lambda functions

Download the sample pre-labeling task Lambda function smgt-ner-pre-labeling-task-lambda.py from https://assets.solutions-lab.ml/NER/0.2.1/sample-scripts/smgt-ner-pre-labeling-task-lambda.py

Download the sample post-labeling task Lambda function smgt-ner-post-labeling-task-lambda.py from https://assets.solutions-lab.ml/NER/0.2.1/sample-scripts/smgt-ner-post-labeling-task-lambda.py
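The downloaded sample functions are the ones to deploy. To show the general shape of a pre-labeling function, here is a minimal Python sketch, not the contents of smgt-ner-pre-labeling-task-lambda.py: Ground Truth passes the manifest row in event["dataObject"], and the function returns a taskInput object that the worker template can read. The taskInput key name document is hypothetical, and rows that use source-ref (which would require reading from S3) are not handled.

```python
import json


def lambda_handler(event, context):
    """Minimal pre-labeling sketch: hand the manifest row's document to the worker template."""
    data_object = event["dataObject"]
    if "source" not in data_object:
        # source-ref rows would require fetching the object from S3; omitted in this sketch
        raise ValueError("this sketch only handles manifest rows with a 'source' key")
    source = data_object["source"]
    # The Input Document may arrive as a JSON-encoded string or as an already parsed object
    document = json.loads(source) if isinstance(source, str) else source
    return {
        "taskInput": {"document": document},  # hypothetical key name
        "isHumanAnnotationRequired": "true",
    }
```

Calling lambda_handler({"dataObject": {"source": json.dumps({"text": "Great view"})}}, None) returns the parsed document under taskInput.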

  • Create pre-labeling task Lambda function from the AWS Management Console:
    • Navigate to Lambda
    • Select Create function
    • Specify Function name as smgt-ner-pre-labeling-task-lambda
    • Select Runtime → Python 3.6
    • Select Create function
    • In Function code → lambda_function.py, paste the contents of smgt-ner-pre-labeling-task-lambda.py
    • Select Deploy
  • Create post-labeling task Lambda function from the AWS Management Console:
    • Navigate to Lambda
    • Select Create function
    • Specify Function name as smgt-ner-post-labeling-task-lambda
    • Select Runtime → Python 3.6
    • Expand Change default execution role
    • Select Create a new role from AWS policy templates
    • Enter the Role name: smgt-ner-post-labeling-task-lambda-role
    • Select Create function
    • Select the Permissions tab
    • Select the Role name: smgt-ner-post-labeling-task-lambda-role to open the IAM console
    • Add two policies to the role
      • Select Attach policies
      • Attach the AmazonS3FullAccess policy
      • Select Add inline policy
      • Select the JSON tab
      • Paste in the following inline policy:
            {
                "Version": "2012-10-17",
                "Statement": {
                    "Effect": "Allow",
                    "Action": "sts:AssumeRole",
                    "Resource": "arn:aws:iam::YOUR_ACCOUNT_NUMBER:role/service-role/AmazonSageMaker-ExecutionRole-*"
                }
            }

    • Navigate back to the smgt-ner-post-labeling-task-lambda Lambda function configuration page
    • Select the Configuration tab
    • In Function code → lambda_function.py, paste the contents of smgt-ner-post-labeling-task-lambda.py
    • Select Deploy
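For orientation, the core of a post-labeling function is consolidation: Ground Truth invokes it with event["payload"]["s3Uri"] pointing to a JSON file of worker answers (reading that file is omitted here). The Python sketch below follows the request and response shapes described in the Ground Truth documentation for custom post-annotation Lambda functions; it is not the code in smgt-ner-post-labeling-task-lambda.py. It simply keeps the first worker’s answer.

```python
import json


def consolidate(annotation_items, label_attribute_name):
    """Turn raw worker answers into Ground Truth's consolidation response.

    annotation_items: the parsed JSON list found at event["payload"]["s3Uri"].
    Each worker answer's annotationData.content is a JSON string.
    """
    results = []
    for item in annotation_items:
        answers = item.get("annotations", [])
        # Keep the first worker's answer; a real job might merge or vote instead
        content = json.loads(answers[0]["annotationData"]["content"]) if answers else {}
        results.append({
            "datasetObjectId": item["datasetObjectId"],
            "consolidatedAnnotation": {"content": {label_attribute_name: content}},
        })
    return results
```

In a handler, you would load the file from S3, call consolidate(items, event["labelAttributeName"]), and return the resulting list.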

Create a Ground Truth labeling job

From the AWS Management Console:

  • Navigate to the Amazon SageMaker service
  • Navigate to Ground Truth → Labeling jobs
  • Select Create labeling job
  • Specify a Job Name
  • Select Manual Data Setup
  • Specify the Input dataset location where you uploaded the input manifest earlier (e.g., s3://your-bucket/ner-input/reviews.manifest)
  • Specify the Output dataset location to point to a different folder in the same bucket (e.g., s3://your-bucket/ner-output/)
  • Specify an IAM Role by selecting Create new role
    • Allow this role to access any S3 bucket by selecting S3 buckets you specify → Any S3 bucket when creating the policy
    • In a new AWS Management Console window, open the IAM console and select Roles
    • Search for the name of the role that you just created (for example, AmazonSageMaker-ExecutionRole-20210301T154158)
    • Select the role name to open the role in the console
    • Attach a policy and edit the trust relationship:
      • Select Attach policies
      • Attach the AWSLambda_FullAccess to the role
      • Select Trust relationships → Edit trust relationship
      • Edit the trust relationship JSON,
      • Replace YOUR_ACCOUNT_NUMBER with your numerical AWS account number so that it reads:
          {
              "Version": "2012-10-17",
              "Statement": [
                  {
                      "Effect": "Allow",
                      "Principal": {
                          "Service": "sagemaker.amazonaws.com"
                      },
                      "Action": "sts:AssumeRole"
                  },
                  {
                      "Effect": "Allow",
                      "Principal": {
                          "AWS": "arn:aws:iam::YOUR_ACCOUNT_NUMBER:role/service-role/smgt-ner-post-labeling-task-lambda-role"
                      },
                      "Action": "sts:AssumeRole"
                  }
              ]
          }

      • Save the trust relationship
  • Return to the new Ground Truth job in the previous AWS Management Console window: under Task Category, select Custom
  • Select Next
  • Select Worker types: Private
  • Select the Private team: ner-worker-team that was created in the preceding section
  • In the Custom labeling task setup text area, clear the default content and paste in the content of the worker-template.liquid.html file obtained earlier
  • Specify the Pre-labeling task Lambda function with the previously created function: smgt-ner-pre-labeling-task-lambda
  • Specify the Post-labeling task Lambda function with the function created earlier: smgt-ner-post-labeling-task-lambda
  • Select Create

Annotate documents

Once the Ground Truth job is created, we can start annotating documents. Open the worker portal for the workforce created earlier (in the AWS Management Console, navigate to SageMaker → Ground Truth → Labeling workforces, select the Private tab, and open the Labeling portal sign-in URL).

Sign in and select the first labeling task in the table, and then select “Start working” to open the annotator. Perform your annotations and select submit on all three of the sample documents.

Review results

As Ground Truth annotators complete tasks, results will be available in the output S3 bucket:


Once all tasks for a labeling job are complete, the consolidated output is available in the output.manifest file located here:


This output manifest is a JSON-lines file with one annotated text document per line in the “Output Document Format” specified previously. This file is compatible with the “Input Document Format”, and it can be fed directly into a subsequent Ground Truth job for another round of annotation. Alternatively, it can be parsed and sent to an ML training job. Some scenarios where we might employ a second round of annotations are:

  • Breaking the annotation process into two steps where the first annotator identifies entity annotations and the second annotator draws relationships
  • Taking a sample of our output.manifest and sending it to a second, more experienced annotator for review as a quality control check
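When sending results to a training job instead, the manifest can be parsed line by line. The Python sketch below assumes, as stated above, that each line is an Output Document; the sample lines are made up for the example.

```python
import json


def iter_entities(manifest_lines):
    """Yield (documentId, label, text) for every entity annotation in accepted documents."""
    for line in manifest_lines:
        line = line.strip()
        if not line:
            continue  # tolerate blank lines
        doc = json.loads(line)
        if doc.get("meta", {}).get("rejected"):
            continue  # the annotator rejected this document
        for ann in doc.get("entityAnnotations", []):
            yield doc.get("documentId"), ann["label"], ann["text"]


sample = [
    json.dumps({
        "documentId": "r1",
        "text": "Friendly staff",
        "entityAnnotations": [{"id": "e1", "start": 9, "end": 14, "text": "staff", "label": "Staff"}],
        "meta": {"rejected": False, "runes": list("Friendly staff")},
    }),
    json.dumps({"documentId": "r2", "text": "???", "meta": {"rejected": True, "rejectedReason": "Not a review"}}),
]
print(list(iter_entities(sample)))  # [('r1', 'Staff', 'staff')]
```

With a real file, `with open("output.manifest", encoding="utf-8") as f: entities = list(iter_entities(f))` works the same way, because a file object yields one line at a time.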

Custom Ground Truth Annotation Templates

The NER annotation tool described in this document is implemented as a custom Ground Truth annotation template. AWS customers can build their own custom annotation interfaces using the instructions found here:


By working together, Booking.com and the Amazon MLSL were able to develop a powerful text annotation tool that is capable of creating complex named-entity recognition and relationship annotations.

We encourage AWS customers with an NER text annotation use case to try the tool described in this post. If you’d like help accelerating the use of ML in your products and services, please contact the Amazon Machine Learning Solutions Lab.

About the Authors

Dan Noble is a Software Development Engineer at Amazon where he helps build delightful user experiences. In his spare time, he enjoys reading, exercising, and having adventures with his family.

Pri Nonis is a Deep Learning Architect at the Amazon ML Solutions Lab, where he works with customers across various verticals and helps them accelerate their cloud migration journey and solve their ML problems using state-of-the-art solutions and technologies.

Niharika Jayanthi is a Front End Engineer at AWS, where she develops custom annotation solutions for Amazon SageMaker customers. Outside of work, she enjoys going to museums and working out.

Amit Beka is a Machine Learning Manager at Booking.com, with over 15 years of experience in software development and machine learning. He is fascinated with people and languages, and how computers are still puzzled by both.


Optimize your inference jobs using dynamic batch inference with TorchServe on Amazon SageMaker

In deep learning, batch processing refers to feeding multiple inputs into a model. Although it’s essential during training, it can be very helpful to manage the cost and optimize throughput during inference time as well. Hardware accelerators are optimized for parallelism, and batching helps saturate the compute capacity and often leads to higher throughput.

Batching can be helpful in several scenarios during model deployment in production. Here we broadly categorize them into two use cases:

  • Real-time applications where several inference requests are received from different clients and are dynamically batched and fed to the serving model. Latency is usually important in these use cases.
  • Offline applications where several inputs or requests are batched on the client side and sent to the serving model. Higher throughput is often the objective for these use cases, which helps manage the cost. Example use cases include video analysis and model evaluation.

Amazon SageMaker provides two popular options for your inference jobs. For real-time applications, SageMaker Hosting uses TorchServe as the backend serving library that handles the dynamic batching of the received requests. For offline applications, you can use SageMaker batch transform jobs. In this post, we go through an example of each option to help you get started.

Because TorchServe is natively integrated with SageMaker via the SageMaker PyTorch inference toolkit, you can easily deploy a PyTorch model onto TorchServe using SageMaker Hosting. There may also be times when you need to customize your environment further using custom Docker images. In this post, we first show how to deploy a real-time endpoint using the native SageMaker PyTorch inference toolkit and configure the batch size to optimize throughput. In the second example, we demonstrate how to use a custom Docker image to apply advanced TorchServe configurations that aren’t available as environment variables to optimize your batch inference job.

Best practices for batch inference

Batch processing can increase throughput and optimize your resources because it helps complete a larger number of inferences in a certain amount of time at the expense of latency. To optimize model deployment for higher throughput, the general guideline is to increase the batch size until throughput decreases. This most often suits offline applications, where several inputs are batched (such as video frames, images, or text) to get prediction outputs.

For real-time applications, latency is often a main concern. There’s a trade-off between throughput and latency: a larger batch size usually increases throughput but also increases latency, so you may need to adjust the batch size to meet your latency SLA. In terms of best practices on the cloud, the cost per a given number of inferences (for example, per 1,000 inferences) is a helpful guideline for making an informed decision that meets your business needs. One contributing factor in managing the cost is choosing the right accelerator. For more information, see Choose the best AI accelerator and model compilation for computer vision inference with Amazon SageMaker.
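To make this trade-off concrete, a short Python sketch can pick the batch size with the best throughput that still meets a latency SLA, and convert that throughput into a cost per 1,000 inferences. The benchmark numbers and the hourly price in the usage lines are hypothetical; substitute your own measurements and instance pricing.

```python
def best_batch_size(measurements, latency_sla_ms):
    """Return (batch_size, inferences_per_second) with the highest throughput under the SLA.

    measurements: iterable of (batch_size, latency_ms) pairs from your own benchmark,
    where latency_ms is the time to process one batch.
    """
    best = None
    for batch_size, latency_ms in measurements:
        if latency_ms > latency_sla_ms:
            continue  # violates the latency SLA
        throughput = batch_size / (latency_ms / 1000.0)
        if best is None or throughput > best[1]:
            best = (batch_size, throughput)
    return best


def cost_per_1k(price_per_hour, inferences_per_second):
    """Instance cost of 1,000 inferences at a sustained throughput."""
    return price_per_hour / (inferences_per_second * 3600.0) * 1000.0


# Hypothetical benchmark: throughput peaks at batch size 8, then drops
bench = [(1, 10.0), (4, 20.0), (8, 35.0), (16, 80.0)]
choice = best_batch_size(bench, latency_sla_ms=100.0)
print(choice[0], round(choice[1], 1))         # 8 228.6
print(round(cost_per_1k(1.0, choice[1]), 5))  # 0.00122
```

Even though the 100 ms SLA admits batch size 16, it is not chosen, because its throughput (200 inferences per second) is lower than at batch size 8.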

TorchServe dynamic batching on SageMaker

TorchServe is the native PyTorch library for serving models in production at scale. It’s a joint development from Facebook and AWS. TorchServe allows you to monitor, add custom metrics, support multiple models, scale up and down the number of workers through secure management APIs, and provide inference and explanation endpoints.

To support batch processing, TorchServe provides a dynamic batching feature. It aggregates the received requests within a specified time frame, batches them together, and sends the batch for inference. The received requests are processed through the handlers in TorchServe. TorchServe has several default handlers, and you’re welcome to author a custom handler if your use case isn’t covered. When using a custom handler, make sure that the batch inference logic has been implemented in the handler. An example of a custom handler with batch inference support is available on GitHub.
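As a rough illustration of what batch inference logic means in a handler: TorchServe passes handle() a list containing one entry per request in the batch, and the handler must return a list with exactly one response per request, in the same order. The class below mirrors the preprocess/inference/postprocess/handle structure of TorchServe’s BaseHandler but does not import TorchServe; the word-count “model” and the response format are placeholders.

```python
import json


class BatchWordCountHandler:
    """Hypothetical handler showing batch-aware preprocess/inference/postprocess steps."""

    def preprocess(self, data):
        texts = []
        for request in data:  # one entry per request in the batch
            payload = request.get("data")
            if payload is None:
                payload = request.get("body")
            if isinstance(payload, (bytes, bytearray)):
                payload = payload.decode("utf-8")
            if isinstance(payload, dict):  # already-decoded JSON
                payload = payload.get("text", "")
            texts.append(str(payload))
        return texts

    def inference(self, texts):
        # Placeholder for one batched model forward pass over all texts
        return [len(t.split()) for t in texts]

    def postprocess(self, outputs):
        # One response per request, in the order the requests arrived
        return [json.dumps({"num_words": n}) for n in outputs]

    def handle(self, data, context=None):
        return self.postprocess(self.inference(self.preprocess(data)))


handler = BatchWordCountHandler()
print(handler.handle([{"body": b"hello world"}, {"data": {"text": "a b c"}}]))
```

The key design point is that inference runs once over the whole list; a handler that processes only data[0] would silently drop the rest of the batch.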

You can configure dynamic batching using two settings, batch_size and max_batch_delay, either through environment variables in SageMaker or through the config.properties file in TorchServe (if using a custom container). TorchServe sends a batch for inference as soon as either condition is met: batch_size requests have been collected, or max_batch_delay (in milliseconds) has elapsed while waiting for more requests.
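The interplay of the two settings can be simulated in a few lines of Python. This is a simplified model, not TorchServe’s scheduler: it assumes a batch window opens when the first request of a batch arrives and that request arrival times (in milliseconds) are already sorted.

```python
def simulate_dynamic_batching(arrivals_ms, batch_size, max_batch_delay_ms):
    """Group sorted request arrival times into (dispatch_time_ms, request_indices) batches.

    A batch is dispatched as soon as it holds batch_size requests or
    max_batch_delay_ms has elapsed since its first request, whichever comes first.
    """
    batches, current, opened_at = [], [], None
    for i, t in enumerate(arrivals_ms):
        if current and t - opened_at > max_batch_delay_ms:
            # The window expired before this request arrived: flush at the deadline
            batches.append((opened_at + max_batch_delay_ms, current))
            current = []
        if not current:
            opened_at = t  # first request opens a new window
        current.append(i)
        if len(current) == batch_size:
            batches.append((t, current))  # batch is full: dispatch immediately
            current = []
    if current:
        batches.append((opened_at + max_batch_delay_ms, current))
    return batches


print(simulate_dynamic_batching([0, 10, 20, 30, 200, 205], batch_size=3, max_batch_delay_ms=50))
# [(20, [0, 1, 2]), (80, [3]), (250, [4, 5])]
```

The first three requests fill a batch immediately, the fourth waits out its 50 ms window alone, and the last two are dispatched together when their window expires.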

With the TorchServe integration in SageMaker, you can deploy PyTorch models natively on SageMaker by defining a SageMaker PyTorch model. You can add custom model loading, inference, and preprocessing and postprocessing logic in a script passed as an entry point to the SageMaker PyTorch model (see the following example code). Alternatively, you can use a custom container to deploy your models. For more information, see The SageMaker PyTorch Model Server.

You can set the batch size for PyTorch models on SageMaker through environment variables. If you choose to use a custom container, you can bundle settings in config.properties with your model when packaging your model in TorchServe. The following code snippet shows an example of how to set the batch size using environment variables and how to deploy a PyTorch model on SageMaker:

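from sagemaker.deserializers import BytesDeserializer
from sagemaker.pytorch.model import PyTorchModel
from sagemaker.serializers import JSONSerializer

# TorchServe dynamic batching settings (example values; the delay is in milliseconds)
env_variables_dict = {"SAGEMAKER_TS_BATCH_SIZE": "8", "SAGEMAKER_TS_MAX_BATCH_DELAY": "50"}

# model_artifact (S3 URI of the model archive) and role (execution role ARN) are defined elsewhere; versions are examples
pytorch_model = PyTorchModel(model_data=model_artifact, role=role, entry_point="inference.py",
                             framework_version="1.9", py_version="py38", env=env_variables_dict)

predictor = pytorch_model.deploy(initial_instance_count=1, instance_type="ml.c5.2xlarge",
                                 serializer=JSONSerializer(), deserializer=BytesDeserializer())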

In the code snippet, model_artifact refers to all the required files for loading back the trained model, which is archived in a .tar file and pushed into an Amazon Simple Storage Service (Amazon S3) bucket. The inference.py is similar to the TorchServe custom handler; it has several functions that you can override to accommodate the model initialization, preprocessing and postprocessing of received requests, and inference logic.

The following notebook shows a full example of deploying a Hugging Face BERT model.

If you need a custom container, you can build a custom container image and push it to the Amazon Elastic Container Registry (Amazon ECR) repository. The model artifact in this case can be a TorchServe .mar file that bundles the model artifacts along with the handler. We demonstrate this in the next section, where we use a SageMaker batch transform job.

SageMaker batch transform job

For offline use cases where requests are batched from a data source such as a dataset, SageMaker provides batch transform jobs. These jobs enable you to read data from an S3 bucket and write the results to a target S3 bucket. For more information, see Use Batch Transform to Get Inferences from Large Datasets. A full example of batch inference using batch transform jobs can be found in the following notebook, where we use a machine translation model from the FLORES competition. In this example, we show how to use a custom container to score our model using SageMaker. Using a custom inference container allows you to further customize your TorchServe configuration. In this example, we want to change the batch size and disable JSON decoding, which we can do through the TorchServe config.properties file.

When using a custom handler for TorchServe, we need to make sure that the handler implements the batch inference logic. Each handler can have custom functions to perform preprocessing, inference, and postprocessing. An example of a custom handler with batch inference support is available on GitHub.

We use our custom container to bundle the model artifacts with the handler as we do in TorchServe (making a .mar file). We also need an entry point to the Docker container that starts TorchServe with the batch size and JSON decoding set in config.properties. We demonstrate this in the example notebook.
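For orientation, a config.properties for this setup could be generated as follows. The sketch assumes TorchServe’s decode_input_request property and its models property (per-model settings such as batchSize and maxBatchDelay given as JSON); the model name, .mar file name, and values are hypothetical, and a real container would typically need further settings such as the model store location.

```python
import json


def render_torchserve_config(model_name, mar_name, batch_size, max_batch_delay_ms,
                             decode_input_request=False):
    """Render a minimal TorchServe config.properties as a string."""
    models = {
        model_name: {
            "1.0": {
                "defaultVersion": True,
                "marName": mar_name,
                "batchSize": batch_size,
                "maxBatchDelay": max_batch_delay_ms,
            }
        }
    }
    lines = [
        # false passes request bodies to the handler without JSON/text decoding
        "decode_input_request=" + ("true" if decode_input_request else "false"),
        # single-line JSON, so no backslash line continuations are needed
        "models=" + json.dumps(models),
    ]
    return "\n".join(lines) + "\n"


print(render_torchserve_config("translator", "translator.mar", batch_size=8, max_batch_delay_ms=100))
```

The container’s entry point would then start TorchServe with this file, for example through the torchserve --ts-config option.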

The SageMaker batch transform job requires access to the input files from an S3 bucket, where it divides the input files into mini batches and sends them for inference. Consider the following points when configuring the batch transform job:

  • Place the input files (such as a dataset) in an S3 bucket and set it as a data source in the job settings.
  • Assign an S3 bucket in which to save the results of the batch transform job.
  • Set BatchStrategy to MultiRecord and SplitType to Line if you need the batch transform job to make mini batches from the input file. If it can’t automatically split the dataset into mini batches, you can divide it into mini batches by putting each batch in a separate input file, placed in the data source S3 bucket.
  • Make sure that each mini batch fits in memory. SageMaker usually handles this automatically; however, when dividing batches manually, you need to tune the batch size based on the available memory.
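When the dataset has to be divided manually, a small helper can write one file per mini batch. This Python sketch assumes a JSON Lines dataset and a fixed number of records per file; the file naming scheme and the 100-record default are arbitrary choices, and you would then upload the output directory to the data source S3 prefix.

```python
import os
import tempfile


def split_jsonl(input_path, output_dir, records_per_file=100):
    """Write every `records_per_file` non-empty lines of a JSON Lines file to its own file."""
    os.makedirs(output_dir, exist_ok=True)
    with open(input_path, encoding="utf-8") as f:
        records = [line for line in f if line.strip()]
    paths = []
    for start in range(0, len(records), records_per_file):
        path = os.path.join(output_dir, f"batch-{start // records_per_file:05d}.jsonl")
        with open(path, "w", encoding="utf-8") as out:
            out.writelines(records[start:start + records_per_file])
        paths.append(path)
    return paths


# Usage on a throwaway five-record dataset, two records per mini batch
with tempfile.TemporaryDirectory() as tmp:
    src = os.path.join(tmp, "data.jsonl")
    with open(src, "w", encoding="utf-8") as f:
        f.writelines(f'{{"id": {i}}}\n' for i in range(5))
    print([os.path.basename(p) for p in split_jsonl(src, os.path.join(tmp, "batches"), 2)])
```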

The following code is an example for a batch transform job:

import time

s3_bucket_name = "sagemaker-us-west-2-XXXXXXXX"
batch_input = f"s3://{s3_bucket_name}/folder/jobname_TorchServe_SageMaker/"
batch_output = f"s3://{s3_bucket_name}/folder/jobname_TorchServe_SageMaker_output/"

batch_job_name = "job-batch" + time.strftime("%Y-%m-%d-%H-%M-%S", time.gmtime())

request = {
    "ModelClientConfig": {
        "InvocationsTimeoutInSeconds": 3600,
        "InvocationsMaxRetries": 1,
    },
    "TransformJobName": batch_job_name,
    "ModelName": model_name,  # SageMaker model created from the custom container
    "BatchStrategy": "MultiRecord",
    "TransformOutput": {"S3OutputPath": batch_output, "AssembleWith": "Line", "Accept": "application/json"},
    "TransformInput": {
        "DataSource": {
            "S3DataSource": {"S3DataType": "S3Prefix", "S3Uri": batch_input}
        },
        "SplitType": "Line",
        "ContentType": "application/json",
    },
    "TransformResources": {"InstanceType": "ml.p2.xlarge", "InstanceCount": 1},
}

When we launch a transform job with the preceding settings, it reads the input files from the source S3 location in mini batches and sends them for inference. The results are written to the S3 output path specified in TransformOutput.

The following code snippet shows how to create and launch a job using the preceding settings and then poll its status:

import time

import boto3

sm = boto3.client("sagemaker")
sm.create_transform_job(**request)

while True:
    response = sm.describe_transform_job(TransformJobName=batch_job_name)
    status = response["TransformJobStatus"]
    if status == "Completed":
        print("Transform job ended with status: " + status)
        break
    if status in ("Failed", "Stopped"):
        message = response.get("FailureReason", "job was stopped")
        print("Transform job ended with status {}: {}".format(status, message))
        raise Exception("Transform job did not complete")
    print("Transform job is still in status: " + status)
    time.sleep(30)  # wait before polling again


In this post, we reviewed the two modes SageMaker offers for online and offline inference. The former uses dynamic batching provided in TorchServe to batch the requests from multiple clients. The latter uses a SageMaker transform job to batch the requests from input files in an S3 bucket and run inference.

We also showed how to serve models on SageMaker using native SageMaker PyTorch inference toolkit container images, and how to use custom containers for use cases that require advanced TorchServe configuration settings.

As TorchServe continues to evolve to address the needs of the PyTorch community, new features are integrated into SageMaker to provide performant ways for serving models in production. For more information, check out the TorchServe GitHub repo and the SageMaker examples.

About the Authors

Phi Nguyen is a solutions architect at AWS helping customers with their cloud journey with a special focus on data lake, analytics, semantic technologies and machine learning. In his spare time, you can find him biking to work, coaching his son’s soccer team or enjoying nature walks with his family.

Nikhil Kulkarni is a software developer with AWS Machine Learning, focusing on making machine learning workloads more performant on the cloud and is a co-creator of AWS Deep Learning Containers for training and inference. He’s passionate about distributed Deep Learning Systems. Outside of work, he enjoys reading books, fiddling with the guitar and making pizza.

Hamid Shojanazeri is a Partner Engineer at PyTorch working on OSS high-performance model optimization and serving. Hamid holds a Ph.D. in computer vision and worked as a researcher in multimedia labs in Australia and Malaysia and as NLP lead at Opus.ai. He likes to find simpler solutions to hard problems and is an art enthusiast in his spare time.

Geeta Chauhan leads AI Partner Engineering at Meta AI with expertise in building resilient, anti-fragile, large scale distributed platforms for startups and Fortune 500s. Her team works with strategic partners, machine learning leaders across the industry and all major cloud service providers for building and launching new AI product services and experiences; and taking PyTorch models from research to production. She is a winner of Women in IT – Silicon Valley – CTO of the year 2019, an ACM Distinguished Speaker and thought leader on topics ranging from Ethics in AI, Deep Learning, Blockchain, IoT. She is passionate about promoting use of AI for Good.


Graph-based recommendation system with Neptune ML: An illustration on social network link prediction challenges

Recommendation systems are one of the most widely adopted machine learning (ML) technologies in real-world applications, ranging from social networks to ecommerce platforms. Users of many online systems rely on recommendation systems to make new friendships, discover new music according to suggested music lists, or even make ecommerce purchase decisions based on the recommended products. In social networks, one common use case is to recommend new friends to a user based on the users’ other connections. Users with common friends likely know each other. Therefore, they should have a higher score for a recommendation system to propose if they haven’t been connected yet.

Social networks can naturally be expressed in a graph, where the nodes represent people, and the connections between people, such as friendship or co-workers, are represented by edges. The following illustrates one such social network. Let’s imagine that we have a social network with the members (nodes) Bill, Terry, Henry, Gary, and Alistair. Their relationships are represented by a link (edge), and each person’s interests, such as sports, arts, games, and comics, are represented by node properties.

The objective here is to predict if there is a potential missing link between members. For example, should we recommend a connection between Henry and Terry? Looking at the graph, we can see that they have two mutual friends, Gary and Alistair. Therefore, there is a good chance that Henry and Terry either already knew each other or may get to know each other soon. How about Henry and Bill? They don’t have any mutual friends, but they do have some weak connection through their friends’ connections. In addition, they both have similar interests in arts, comics, and games. Should we promote this connection? All of these questions and intuitions are the core logic of social network recommendation systems.

One possible way to do this is recommending relationships based on graph exploration. In graph query languages such as Apache TinkerPop Gremlin, implementing rule sets such as counting common friends is relatively easy, and such a rule can be used to determine the link between Henry and Terry. However, these rule sets become very complicated when we want to account for other attributes such as node properties, connection strength, etc. Let’s imagine a rule set to determine the link between Henry and Bill. This rule set must account for their common interests and their weak connections through certain paths in the graph. To increase robustness, we might also need to add a distance factor to favor strong connections and penalize the weak ones. Similarly, we would want a factor to favor common interests. Soon, the rule sets that can reveal complex hidden patterns will become impossible to enumerate.
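For comparison, the common-friends rule is just as easy to express in plain Python as in a graph query language. The sketch below uses hypothetical edges (the post’s figure is not reproduced here) and ranks unconnected pairs by number of mutual friends.

```python
from itertools import combinations

# Hypothetical friendships, chosen to match the narrative above
EDGES = [("Henry", "Gary"), ("Henry", "Alistair"), ("Terry", "Gary"),
         ("Terry", "Alistair"), ("Bill", "Terry")]


def adjacency(edges):
    """Undirected adjacency sets."""
    adj = {}
    for a, b in edges:
        adj.setdefault(a, set()).add(b)
        adj.setdefault(b, set()).add(a)
    return adj


def rank_by_common_friends(adj):
    """Return (mutual_count, u, v, mutual_friends) for unconnected pairs, best first."""
    ranked = []
    for u, v in combinations(sorted(adj), 2):
        if v in adj[u]:
            continue  # already friends
        mutual = adj[u] & adj[v]
        if mutual:
            ranked.append((len(mutual), u, v, sorted(mutual)))
    return sorted(ranked, reverse=True)


for row in rank_by_common_friends(adjacency(EDGES)):
    print(row)
```

Henry and Terry come out on top with two mutual friends, while Henry and Bill are never scored at all: the rule has no way to use their shared interests.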

ML technology lets us discover hidden patterns by learning algorithms. One example is XGBoost, which is widely used for classification or regression tasks. However, algorithms such as XGBoost use a conventional ML approach based on a tabular data format. These approaches aren’t optimized for graph data structures, and they require complex feature engineering to cope with these data patterns.

In the preceding social network example, the graph interaction information is critical to improving the recommendation accuracy. Graph neural networks (GNNs) are a family of deep learning (DL) models that can be applied to graph data to perform edge-level, node-level, or graph-level prediction tasks. GNNs can use individual node characteristics as well as graph structure information when learning the graph representation and underlying patterns. As a result, GNN-based methods have set new standards on many recommender system benchmarks in recent years. For more detailed information, see the recent research papers A Comprehensive Survey on Graph Neural Networks and Graph Learning based Recommender Systems: A Review.

A well-known example of such a use case comes from Pinterest. Researchers and engineers there trained PinSage, a graph convolutional network described in the paper Graph Convolutional Neural Networks for Web-Scale Recommender Systems, on a graph with three billion nodes representing pins and boards and 18 billion edges. PinSage generates high-quality embeddings that represent pins (visual bookmarks to online content). These embeddings can be used for a wide range of downstream recommendation tasks, such as nearest-neighbor lookups in the learned embedding space for content discovery and recommendations.

In this post, we walk you through how to use GNNs for recommendation use cases by casting them as a link prediction problem, and we illustrate how Neptune ML can facilitate the implementation. We also provide sample code on GitHub to train your first GNN with Neptune ML and make recommendation inferences on the demo graph through link prediction tasks.

Link prediction with Graph Neural Networks

Returning to the social network example, suppose we want to recommend new friends to Henry. Both Terry and Bill are good candidates. Terry has more common friends with Henry (Gary and Alistair) but no common interests. Bill shares common interests with Henry (arts, comics, games) but has no common friends. Which one is the better recommendation? When we frame this as a link prediction problem, the task is to assign a score to any possible link between two nodes. The higher the link score, the more likely the recommendation is to result in a connection. By learning the link structures already present in the graph, a link prediction model can generalize to predict new links that ‘complete’ the graph.

The link score is predicted by a function f whose parameters are learned during the training phase. Because f makes a prediction for any pair of nodes in the graph, the feature vectors associated with the nodes are essential to the learning process. To predict the link score between Henry and Bill, we have a set of raw data features (arts, comics, games) that represent Henry and Bill. A GNN transforms these features, together with the connections in the graph, into new representations known as node embeddings. We can also supplement or replace the initial raw features with vectors from an embedding lookup table that is learned during training. Ideally, the embeddings for Henry and Bill capture both their interests and their topological information from the graph.
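The following is a minimal sketch of the pieces named above. It shows raw interest features, an optional lookup-table row appended to them, and a scoring function f. The choice of f (sigmoid of a dot product) is an assumption. It is one common option, and the scoring function Neptune ML uses may differ. The vocabulary, node list, and table width are also hypothetical. For brevity, f is applied directly to the input vectors here. In the full model, the GNN layers first turn these inputs into node embeddings.

```python
import numpy as np

# Hypothetical setup: the vocabulary, node list, and table width are illustrative.
VOCAB = ["arts", "comics", "games"]
NODES = ["Henry", "Bill", "Terry"]
rng = np.random.default_rng(seed=0)
# Embedding lookup table with one row per node. In a real model these rows are
# learned during training; here they are only randomly initialized.
lookup_table = rng.normal(scale=0.1, size=(len(NODES), 4))


def raw_features(interests):
    """Multi-hot vector of the node's raw interest features over VOCAB."""
    return np.array([1.0 if v in interests else 0.0 for v in VOCAB])


def node_input(name, interests):
    """Supplement the raw features with the node's lookup-table row."""
    return np.concatenate([raw_features(interests), lookup_table[NODES.index(name)]])


def f(e1, e2):
    """Assumed link score in (0, 1): sigmoid of the dot product of two vectors."""
    return 1.0 / (1.0 + np.exp(-np.dot(e1, e2)))


henry = node_input("Henry", {"arts", "comics", "games"})
bill = node_input("Bill", {"arts", "comics", "games"})
terry = node_input("Terry", set())
print(f(henry, bill), f(henry, terry))
```

Because the dot product is symmetric, f(e1, e2) equals f(e2, e1). That symmetry suits an undirected relationship such as friendship.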

How GNNs work

A GNN transforms the initial node features into node embeddings by using a technique called message passing. The message passing process is illustrated in the following figure. First, the node attributes or features are converted into numerical values. In our case, we one-hot encode the categorical features (Henry’s interests: arts, comics, games). Then, the first GNN layer aggregates the raw features (in black) of all of Henry’s neighbors, Gary and Alistair, to form a new set of features (in yellow). A common approach is to apply a linear transformation to each neighbor’s features, aggregate the results with a normalized sum, and pass the sum through a non-linear activation function, such as ReLU, to generate a new vector. The following figure illustrates how message passing works for node Henry. The GNN message passing algorithm computes representations for all of the graph nodes in the same way, and these representations are used as the input features for the second layer.

The second GNN layer repeats the same process. It takes the features computed by the first layer (in yellow) as input. It then aggregates the new embedded features of Gary and Alistair, which already summarize their own neighbors, and generates the second-layer feature vector for Henry (in orange). By repeating the message passing mechanism, we extend the feature aggregation to 2-hop neighbors. In our illustration we stop at 2-hop neighbors, but extending to 3-hop neighbors works the same way: add another GNN layer.
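The two paragraphs above can be sketched in plain NumPy. Each layer applies a linear transformation to every neighbor's features, averages them (a normalized sum), and applies ReLU. Stacking two layers reaches 2-hop neighbors. This is a minimal illustration, not Neptune ML's implementation. The graph, the interests of Gary, Alistair, and Terry, and the identity weight matrices are all hypothetical. Identity weights just keep the numbers easy to inspect; a trained GNN learns W1 and W2. Like the description above, this sketch aggregates only neighbors. Many GNN variants also add each node's own features (a self-loop).

```python
import numpy as np

# Hypothetical toy graph; interests of Gary, Alistair, and Terry are invented.
NAMES = ["Henry", "Gary", "Alistair", "Terry", "Bill"]
INTERESTS = {
    "Henry": {"arts", "comics", "games"},
    "Gary": {"arts", "music"},
    "Alistair": {"football"},
    "Terry": {"football", "music"},
    "Bill": {"arts", "comics", "games"},
}
VOCAB = ["arts", "comics", "games", "football", "music"]
EDGES = [("Henry", "Gary"), ("Henry", "Alistair"), ("Terry", "Gary"), ("Terry", "Alistair")]


def one_hot(interests):
    """Multi-hot encoding of a node's categorical interests over VOCAB."""
    return np.array([1.0 if v in interests else 0.0 for v in VOCAB])


def build_neighbors(names, edges):
    """Undirected adjacency lists keyed by node index."""
    idx = {n: i for i, n in enumerate(names)}
    nbrs = {i: [] for i in range(len(names))}
    for a, b in edges:
        nbrs[idx[a]].append(idx[b])
        nbrs[idx[b]].append(idx[a])
    return nbrs


def gnn_layer(h, neighbors, weight):
    """One message-passing layer: linearly transform each neighbor's features,
    average them (a normalized sum), then apply ReLU."""
    out = np.zeros((h.shape[0], weight.shape[1]))
    for v, nbrs in neighbors.items():
        if nbrs:  # isolated nodes (Bill here) keep a zero vector in this sketch
            out[v] = (h[nbrs] @ weight).mean(axis=0)
    return np.maximum(out, 0.0)


h0 = np.stack([one_hot(INTERESTS[n]) for n in NAMES])
neighbors = build_neighbors(NAMES, EDGES)
W1 = np.eye(len(VOCAB))  # stand-ins for learned weights
W2 = np.eye(len(VOCAB))
h1 = gnn_layer(h0, neighbors, W1)  # layer 1: Henry sees Gary and Alistair
h2 = gnn_layer(h1, neighbors, W2)  # layer 2: Henry now also sees Terry (2 hops)
```

One way to check the 2-hop claim is to change Terry's raw features and recompute. Henry's first-layer vector stays the same, while his second-layer vector changes.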

The final embeddings of Henry and Bill (in orange) are used to compute the score. During training, the target link score is 1 when an edge exists between the two nodes (a positive sample) and 0 when no edge exists between them (a negative sample). The error, or loss, between the target score and the prediction f(e1, e2) is then back-propagated through the previous layers to adjust the weights. Once training is finished, we can use the embedding of each node to compute link scores with the function f.
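The training objective above can be sketched as binary cross-entropy over positive pairs (existing edges, target 1) and sampled negative pairs (non-edges, target 0). It assumes f(e1, e2) = sigmoid(e1 · e2). To stay short, the sketch trains the node vectors directly, like a lookup table, with a hand-derived gradient. In a GNN, the same gradient would instead flow back into the layer weights. The graph, embedding size, and learning rate are hypothetical.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def sample_negatives(num_nodes, edges, k, rng):
    """Uniformly sample k node pairs that are not existing edges (negative samples)."""
    existing = {frozenset(e) for e in edges}
    negatives = []
    while len(negatives) < k:
        i, j = (int(x) for x in rng.integers(0, num_nodes, size=2))
        if i != j and frozenset((i, j)) not in existing:
            negatives.append((i, j))
    return negatives


def loss_and_grad(emb, pos, neg):
    """Mean binary cross-entropy between target scores (1 for positive pairs,
    0 for negative pairs) and sigmoid(e_i . e_j), plus its gradient."""
    pairs = pos + neg
    labels = [1.0] * len(pos) + [0.0] * len(neg)
    loss, grad = 0.0, np.zeros_like(emb)
    for (i, j), y in zip(pairs, labels):
        s = sigmoid(emb[i] @ emb[j])
        loss -= y * np.log(s) + (1.0 - y) * np.log(1.0 - s)
        g = s - y  # derivative of the BCE loss with respect to the logit e_i . e_j
        grad[i] += g * emb[j]
        grad[j] += g * emb[i]
    n = len(pairs)
    return loss / n, grad / n


rng = np.random.default_rng(0)
edges = [(0, 1), (0, 2), (3, 1), (3, 2)]  # hypothetical friendships
emb = rng.normal(scale=0.5, size=(5, 8))  # trainable node vectors
neg = sample_negatives(5, edges, k=4, rng=rng)
initial_loss, _ = loss_and_grad(emb, edges, neg)
for _ in range(200):
    _, grad = loss_and_grad(emb, edges, neg)
    emb -= 0.1 * grad  # gradient descent step
final_loss, _ = loss_and_grad(emb, edges, neg)
print(initial_loss, final_loss)
```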

In this example, we simplified the learning task by using a homogeneous graph, in which all of the nodes and edges are of the same type: all of the nodes are of the “People” type, and all of the edges are of the “friends with” type. However, the learning algorithm also supports heterogeneous graphs with different node and edge types. For example, we could extend the previous use case to recommend products to users who share similar interactions and interests. For more detail, see the research paper Modeling Relational Data with Graph Convolutional Networks.
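To show what changes for a heterogeneous graph, here is a minimal sketch in the style of the relational GCN (R-GCN) from the paper cited above. Each edge type gets its own weight matrix. Messages are averaged per relation, summed across relations, and added to a self-connection term. The user/product graph, the relation names, and the fixed weights are hypothetical. For simplicity, all node types share one feature space.

```python
import numpy as np


def rgcn_layer(h, edges_by_relation, rel_weights, self_weight):
    """One relational message-passing layer (row-vector form):
    h_i' = ReLU(h_i W_0 + sum_r mean_{j in N_r(i)} h_j W_r).
    Messages flow along directed edges (src -> dst), and each relation type r
    uses its own weight matrix W_r."""
    out = h @ self_weight  # self-connection term h_i W_0
    for rel, edges in edges_by_relation.items():
        incoming = {}
        for src, dst in edges:
            incoming.setdefault(dst, []).append(src)
        for dst, srcs in incoming.items():
            out[dst] += (h[srcs] @ rel_weights[rel]).mean(axis=0)
    return np.maximum(out, 0.0)


# Hypothetical graph: nodes 0-1 are users, nodes 2-3 are products.
h = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.0, 0.0]])
edges_by_relation = {
    "friends_with": [(0, 1), (1, 0)],
    "purchased": [(0, 2), (1, 3)],
    "purchased_by": [(2, 0), (3, 1)],  # reverse edges so users receive product messages
}
rel_weights = {
    "friends_with": np.eye(2),
    "purchased": np.eye(2),
    "purchased_by": 2.0 * np.eye(2),
}
out = rgcn_layer(h, edges_by_relation, rel_weights, self_weight=np.zeros((2, 2)))
print(out)
```

Because "purchased_by" has its own (here doubled) weights, user 0's new vector weighs the product it bought differently from its friend. A homogeneous GNN cannot express that distinction.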

At AWS re:Invent 2020, we introduced Amazon Neptune ML, which lets our customers train ML models on graph data, without necessarily having deep ML expertise. In this example, with the help of Neptune ML, we will show you how to build your own recommender system on graph data.

Train your Graph Convolution Network with Amazon Neptune ML

Neptune ML uses graph neural network technology to automatically create, train, and deploy ML models on your graph data. Neptune ML supports common graph prediction tasks, such as node classification and regression, edge classification and regression, and link prediction.

It is powered by:

  • Amazon Neptune: a fast, reliable, and fully managed graph database, which is optimized for storing billions of relationships and querying the graph with millisecond latency. Amazon Neptune supports three open standards for building graph applications: Apache TinkerPop Gremlin, RDF SPARQL, and openCypher. Learn more at Overview of Amazon Neptune Features.
  • Amazon SageMaker: a fully managed service that provides every developer and data scientist with the ability to prepare, build, train, and deploy ML models quickly.
  • Deep Graph Library (DGL): an open-source, high-performance, and scalable Python package for DL on graphs. It provides fast and memory-efficient message passing primitives for training Graph Neural Networks. Neptune ML uses DGL to automatically choose and train the best ML model for your workload. This enables you to make ML-based predictions on graph data in hours instead of weeks.

The easiest way to get started with Neptune ML is to use the AWS CloudFormation quickstart template. The template installs all of the necessary components, including a Neptune DB cluster, and sets up the network configurations, IAM roles, and associated SageMaker notebook instance with pre-populated notebook samples for Neptune ML.

The following figure illustrates different steps for Neptune ML to train a GNN-based recommendation system. Let’s zoom in on each step and explore what it involves:

  1. Data export configuration

The first step in our Neptune ML process is to export the graph data from the Neptune cluster. We must specify the parameters and model configuration for the data export task. We use the Neptune workbench for all of the configurations and commands. The workbench lets us work with the Neptune DB cluster using Jupyter notebooks hosted by Amazon SageMaker, and it provides a number of notebook magic commands that save a great deal of time and effort. Here are our example export parameters:

    export_params = {
        "command": "export-pg",
        "params": {
            "endpoint": neptune_host,
            "profile": "neptune_ml",
            "cloneCluster": False
        },
        "outputS3Path": f'{s3_bucket_uri}/neptune-export',
        "additionalParams": {
            "neptune_ml": {
                "version": "v2.0",
                "targets": [
                    {
                        "edge": ["User", "FRIEND", "User"],
                        "type": "link_prediction"
                    }
                ],
                "features": [
                    {
                        "node": "User",
                        "property": "interests",
                        "type": "category",
                        "separator": ";"
                    }
                ]
            }
        },
        "jobSize": "small"
    }

In export_params, we configure the basic setup, such as the Neptune cluster endpoint and the Amazon Simple Storage Service (Amazon S3) output path where the exported data is stored. The configuration in additionalParams specifies the ML task to perform. In this example, the optional targets entry restricts link prediction to a particular edge type (User, FRIEND, User). If no target type is specified, Neptune ML assumes that the task is link prediction. The parameters also describe the data stored in our graph and how the ML model should interpret it: “User” is the node type, and “interests” is a node property.

To run each step in the ML building process, simply use Neptune workbench commands. The Neptune workbench contains a line magic and a cell magic that can save you a lot of time managing these steps. To run the data export, use the Neptune workbench command: %neptune_ml export start

Once the export job completes, the Neptune graph is exported in CSV format and stored in an S3 bucket. There are two types of files: nodes.csv and edges.csv. A file named training-data-configuration.json is also generated; it contains the configuration that Neptune ML needs to perform model training.

See Export data from Neptune for Neptune ML for more information.

  2. Data preprocessing

Neptune ML performs feature extraction and encoding as part of the data processing step. Common types of property preprocessing include encoding categorical features with one-hot encoding, bucketing numerical features, and using word2vec to encode string properties or other free-form text property values.

In our example, we will simply use the property “interests”. Neptune ML encodes the values as multi-categorical. However, if a categorical value is complex (more than three words per node), then Neptune ML infers the property type to be text and uses the text_word2vec encoding.
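To illustrate what a multi-categorical encoding produces, the following sketch splits each raw "interests" value on the ";" separator from the export configuration and builds a multi-hot matrix. It is a hypothetical helper written for illustration, not Neptune ML's own preprocessing code. The sample property values are invented.

```python
def encode_multi_categorical(values, separator=";"):
    """Split each raw property string on the separator and build a multi-hot
    matrix over the sorted vocabulary of all observed categories."""
    tokens = [[t.strip() for t in v.split(separator) if t.strip()] for v in values]
    vocab = sorted({t for ts in tokens for t in ts})
    index = {t: i for i, t in enumerate(vocab)}
    rows = []
    for ts in tokens:
        row = [0] * len(vocab)
        for t in ts:
            row[index[t]] = 1  # one column per category, several may be set
        rows.append(row)
    return vocab, rows


# Hypothetical "interests" values for three users.
vocab, rows = encode_multi_categorical(["arts;comics;games", "football", "arts; music"])
print(vocab)
print(rows)
```

Unlike plain one-hot encoding, a row can have several 1s, which is what lets one property hold a set of interests.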

To run data preprocessing, use the following Neptune notebook magic command: %neptune_ml dataprocessing start

At the end of this step, a DGL graph is generated from the exported dataset for use by the model training step. Neptune ML automatically tunes the model with hyperparameter optimization (HPO) tuning jobs defined in model-hpo-configuration.json. We can download and modify this file to tune the model’s hyperparameters, such as batch-size, num-hidden, num-epochs, and dropout. Here is a sample configuration.json file.
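If you edit the downloaded hyperparameter configuration by script rather than by hand, a small helper like the following can help. It makes one assumption about the file's layout: each hyperparameter is described by a JSON object with "param" and "range" keys, wherever that object sits in the file. Check the file Neptune ML actually generates before relying on this. The helper names and the example structure are hypothetical.

```python
import json


def set_param_range(node, name, new_range):
    """Recursively find every dict whose "param" equals `name` and replace its
    "range". Returns the number of entries updated."""
    updated = 0
    if isinstance(node, dict):
        if node.get("param") == name and "range" in node:
            node["range"] = list(new_range)
            updated += 1
        for value in node.values():
            updated += set_param_range(value, name, new_range)
    elif isinstance(node, list):
        for item in node:
            updated += set_param_range(item, name, new_range)
    return updated


def edit_hpo_file(path, name, new_range):
    """Load a locally downloaded JSON config, update one range, and save it back."""
    with open(path) as fh:
        config = json.load(fh)
    if set_param_range(config, name, new_range) == 0:
        raise KeyError(f"hyperparameter {name!r} not found in {path}")
    with open(path, "w") as fh:
        json.dump(config, fh, indent=2)
    return config


# Hypothetical in-memory structure standing in for the downloaded file.
example = {"models": [{"1-tier-param": [
    {"param": "num-hidden", "range": [16, 128]},
    {"param": "dropout", "range": [0.0, 0.5]},
]}]}
set_param_range(example, "num-hidden", (32, 64))
```

After editing, upload the file back to the S3 location the training step reads from, so that the next tuning job uses the new ranges.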

See Processing the graph data exported from Neptune for training for more information.

  3. Model training

The next step is the automated training of the GNN model. The model training is done in two stages. The first stage uses a SageMaker Processing job to generate a model training strategy. This is a configuration set that specifies what type of model and model hyperparameter ranges will be used for the model training.

Then, a SageMaker hyperparameter tuning job is launched. It runs a pre-specified number of model training trials on the processed data, tries different hyperparameter combinations according to the model-hpo-configuration.json file, and stores the model artifacts generated by training in the output Amazon S3 location.
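Conceptually, a tuning job samples hyperparameter combinations, runs one training trial per combination, and keeps the best-scoring trial. The sketch below uses plain random search to illustrate that loop. SageMaker also supports other strategies, such as Bayesian optimization, and this is not how the service is implemented. The search space and the `fake_validation_metric` objective are hypothetical stand-ins for real training jobs.

```python
import random


def random_search(space, objective, num_trials, seed=0):
    """Sample num_trials hyperparameter combinations from `space`, score each with
    `objective` (a stand-in for one training job), and keep the best trial."""
    rng = random.Random(seed)
    trials = []
    for _ in range(num_trials):
        params = {
            name: rng.randint(low, high) if kind == "int" else rng.uniform(low, high)
            for name, (low, high, kind) in space.items()
        }
        trials.append((objective(params), params))
    best_score, best_params = max(trials, key=lambda t: t[0])
    return best_score, best_params, trials


# Hypothetical search space; a real trial would train a GNN and report a
# validation metric instead of evaluating this toy function.
space = {"num-hidden": (16, 128, "int"), "dropout": (0.0, 0.5, "float")}


def fake_validation_metric(p):
    return -((p["num-hidden"] - 64) / 64) ** 2 - (p["dropout"] - 0.2) ** 2


best_score, best_params, trials = random_search(space, fake_validation_metric, num_trials=20)
print(best_score, best_params)
```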

To start the training step, you can use the %neptune_ml training start command.

Once all of the training jobs are complete, the Hyperparameter tuning job will save the artifacts from the best performing model, which will be used for inference.

At the end of the training, Neptune ML will instruct SageMaker to save the trained model, the raw embeddings calculated for the nodes and edges, and the mapping information between the embeddings and node indices.

See Training a model using Neptune ML for more information.

  4. Create an inference endpoint in Amazon SageMaker

Now that the graph representation is learned, we can deploy the learned model behind an endpoint to serve inference requests. The model input is the user for whom we want to generate friend recommendations, along with the edge type. The output is the list of users most likely to become that user’s friends.

To deploy the model to the SageMaker endpoint instance, use the %neptune_ml endpoint create command.

  5. Query the ML model using Gremlin

Once the endpoint is ready, we can use it for graph inference queries. Neptune ML supports graph inference queries in Gremlin and SPARQL. In our example, we can now check the friend recommendations that Neptune ML makes for the user “Henry”. The query uses nearly the same syntax as a regular traversal of the FRIEND edge. It lists the other users that the model predicts are connected to Henry through a FRIEND edge.

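    %%gremlin
    g.with("Neptune#ml.endpoint", "<your-endpoint-name>").
      V().hasLabel('User').has('name', 'Henry').
      out('FRIEND').with("Neptune#ml.prediction").hasLabel('User').values('name')

1 Bill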

The Neptune#ml.prediction step returns the connections predicted by the model that we just trained on the social graph. Bill is returned, as we expected.

Here is another sample prediction query that is used to predict the top eight users that are most likely to connect with Henry:

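    %%gremlin
    g.with("Neptune#ml.endpoint", "<your-endpoint-name>").
      with("Neptune#ml.limit", 8).
      V().hasLabel('User').has('name', 'Henry').
      out('FRIEND').with("Neptune#ml.prediction").hasLabel('User').values('name')

1 Bill, 2 Colin, 3 Sarah, 4 Gordon, 5 Mary, 6 Josie, 7 Arnold, 8 Terry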

The results are ranked from the strongest predicted connection to the weakest. The links (Henry, FRIEND, Colin) and (Henry, FRIEND, Terry) are also proposed. Graph-based ML can surface such links because it explores complex interaction patterns in the graph.
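A ranked list like this can be sketched as scoring every candidate against the query node's embedding and keeping the top k. The sketch assumes the link score is the sigmoid of a dot product, and Neptune ML's actual scoring may differ. The embeddings are hand-picked and hypothetical. Excluding existing friends (Gary here) is a design choice of this sketch, not a documented Neptune ML behavior.

```python
import numpy as np


def recommend(emb, names, query, exclude, k):
    """Rank every other node by the assumed link score sigmoid(e_query . e_candidate)
    and return the top k, skipping nodes listed in `exclude`."""
    q = names.index(query)
    scores = 1.0 / (1.0 + np.exp(-(emb @ emb[q])))
    candidates = [i for i, n in enumerate(names) if i != q and n not in exclude]
    ranked = sorted(candidates, key=lambda i: scores[i], reverse=True)
    return [(names[i], round(float(scores[i]), 3)) for i in ranked[:k]]


# Hypothetical 2-dimensional embeddings, hand-picked for illustration only.
names = ["Henry", "Gary", "Bill", "Colin", "Terry"]
emb = np.array([[1.0, 0.0], [1.0, 0.0], [0.9, 0.1], [0.5, 0.5], [0.1, 0.2]])
print(recommend(emb, names, "Henry", exclude={"Gary"}, k=2))
```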

See Gremlin inference queries in Neptune ML for more information.

Model transform or retraining when graph data changes

Another question you might ask is: what if my social network changes, or what if I want to make recommendations for newly added users? In these scenarios, where the graph changes continuously, you may need to update ML predictions with the newest graph data. The model artifacts generated by training are directly tied to the training graph. This means that the inference endpoint must be updated once the entities in the original training graph change.

However, you don’t need to retrain the whole model to make predictions on the updated graph. With an incremental model inference workflow, you only need to export the Neptune DB data, perform incremental data preprocessing, run a model batch transform job, and then update the inference endpoint. The model-transform step takes the trained model from the main workflow and the results of the incremental data preprocessing step as inputs. It then outputs a new model artifact, created from the up-to-date graph data, to use for inference.

The model-transform step deserves special attention. It can compute model artifacts on graph data that was not used for model training: the node embeddings are recomputed, and any existing node embeddings are overwritten. Neptune ML applies the GNN encoder learned by the previously trained model to the new graph data nodes and their new features. Therefore, the new graph data must be processed using the same feature encodings, and it must adhere to the same graph schema as the original graph data. For more Neptune ML implementation details, see Generating new model artifacts.
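The idea behind the model-transform step can be sketched as follows. The encoder weights stay frozen, the training-time feature encoding is reused, and embeddings are simply recomputed on the updated graph. This is an illustration of the concept, not Neptune ML's code. The one-layer encoder, the vocabulary, the stand-in weights W, and the new user "Zoe" are hypothetical. The sketch enforces only the "same feature encodings" requirement. Schema checks are left out.

```python
import numpy as np

# Feature encoding frozen at training time (hypothetical vocabulary).
VOCAB = ["arts", "comics", "games", "football"]


def encode(interests):
    """Reuse the training-time multi-hot encoding; unseen values are rejected."""
    unknown = set(interests) - set(VOCAB)
    if unknown:
        raise ValueError(f"values not in the training vocabulary: {sorted(unknown)}")
    return np.array([1.0 if v in interests else 0.0 for v in VOCAB])


def transform(node_interests, edges, W):
    """Recompute embeddings for the current graph with frozen encoder weights W:
    one mean-aggregation message-passing layer followed by ReLU. No training."""
    names = list(node_interests)
    idx = {n: i for i, n in enumerate(names)}
    neighbors = {i: [] for i in range(len(names))}
    for a, b in edges:
        neighbors[idx[a]].append(idx[b])
        neighbors[idx[b]].append(idx[a])
    h = np.stack([encode(node_interests[n]) for n in names])
    out = np.zeros((len(names), W.shape[1]))
    for v, nbrs in neighbors.items():
        if nbrs:
            out[v] = (h[nbrs] @ W).mean(axis=0)
    return dict(zip(names, np.maximum(out, 0.0)))


# Stand-in for weights learned in the original training run (hypothetical).
W = np.abs(np.random.default_rng(0).normal(size=(len(VOCAB), 3)))

graph_v1 = {"Henry": {"arts", "comics"}, "Gary": {"arts"}}
emb_v1 = transform(graph_v1, [("Henry", "Gary")], W)

# The graph changes: a new user, Zoe (hypothetical), befriends Henry.
graph_v2 = {**graph_v1, "Zoe": {"football"}}
emb_v2 = transform(graph_v2, [("Henry", "Gary"), ("Henry", "Zoe")], W)
```

Zoe gets an embedding without any retraining, and Henry's embedding is refreshed because his neighborhood changed. Gary's embedding stays the same. A new interest value that the original encoding never saw cannot be represented. That is why the new data must use the same feature encodings.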

Alternatively, you can retrain the whole model if the graph changes dramatically, or if the previously trained model no longer accurately represents the underlying interactions. In this case, reusing the learned model parameters on the new graph can’t guarantee similar model performance, so you must retrain your model on the new graph. To accelerate the hyperparameter search, Neptune ML can use warm start to leverage information from the previous model training task. The results of previous training jobs are used to select good hyperparameter combinations to search over in the new tuning job.

See workflows for handling evolving graph data for more details.


In this post, you have seen how Neptune ML and GNNs can help you make recommendations on graph data using a link prediction task by combining information from the complex interaction patterns in the graph.

Link prediction is one way to implement a recommendation system on a graph, but you can build your recommender in many other ways. You can use the embeddings learned during link prediction training to cluster the nodes into segments in an unsupervised manner, and then recommend items to users that belong to the same segment. You can also feed the embeddings into a downstream similarity-based recommendation system as an input feature. This additional feature encodes semantic information derived from the graph and can significantly improve the overall precision of the system. Learn more about Amazon Neptune ML by visiting the website, or feel free to ask questions in the comments!
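The clustering idea can be sketched with a plain k-means over node embeddings, followed by a lookup of the other members of a user's segment. The embeddings below are hypothetical and chosen to form two clear groups. In practice you would export the embeddings that Neptune ML computed, and you might prefer a library implementation of k-means.

```python
import numpy as np


def kmeans(x, k, iters=20, seed=0):
    """Plain k-means: assign each embedding to its nearest center, then move
    each center to the mean of the embeddings assigned to it."""
    rng = np.random.default_rng(seed)
    centers = x[rng.choice(len(x), size=k, replace=False)].copy()
    for _ in range(iters):
        dists = ((x[:, None, :] - centers[None, :, :]) ** 2).sum(axis=-1)
        labels = dists.argmin(axis=1)
        for c in range(k):
            if np.any(labels == c):
                centers[c] = x[labels == c].mean(axis=0)
    return labels, centers


def same_segment(names, labels, query):
    """Other nodes in the query node's segment: candidates for recommendation."""
    seg = labels[names.index(query)]
    return [n for n, lab in zip(names, labels) if lab == seg and n != query]


# Hypothetical node embeddings forming two well-separated groups.
names = ["Henry", "Bill", "Colin", "Terry", "Gary", "Mary"]
emb = np.array([[0.0, 0.0], [0.1, 0.0], [0.0, 0.1], [5.0, 5.0], [5.1, 5.0], [5.0, 5.1]])
labels, _ = kmeans(emb, k=2)
print(same_segment(names, labels, "Henry"))
```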

About the Authors

Yanwei Cui, PhD, is a Machine Learning Specialist Solutions Architect at AWS. He started machine learning research at IRISA (Research Institute of Computer Science and Random Systems), and has several years of experience building artificial intelligence powered industrial applications in computer vision, natural language processing and online user behavior prediction. At AWS, he shares his domain expertise and helps customers unlock business potential and drive actionable outcomes with machine learning at scale. Outside of work, he enjoys reading and traveling.

Will Badr is a Principal AI/ML Specialist SA who works as part of the global Amazon Machine Learning team. Will is passionate about using technology in innovative ways to positively impact the community. In his spare time, he likes to go diving, play soccer and explore the Pacific Islands.
