Develop and train large models cost-efficiently with Metaflow and AWS Trainium

This is a guest post co-authored with Ville Tuulos (Co-founder and CEO) and Eddie Mattia (Data Scientist) of Outerbounds.

To build a production-grade AI system today (for example, to do multilingual sentiment analysis of customer support conversations), what are the primary technical challenges? Historically, natural language processing (NLP) would be a primary research and development expense. In 2024, however, organizations are using large language models (LLMs), which require relatively little focus on NLP, shifting research and development from modeling to the infrastructure needed to support LLM workflows.

For AWS and Outerbounds customers, the goal is to build a differentiated machine learning and artificial intelligence (ML/AI) system and reliably improve it over time. This often means that relying on a third-party LLM API won’t do for security, control, and scale reasons. Owning the infrastructure and the know-how to run the workflows that power AI systems is a requirement.

Returning to the original question, three MLOps challenges may arise:

  • You need high-quality data to train and fine-tune models
  • You need a diverse cloud infrastructure for experimentation, training, tracking, and orchestrating the production system
  • You need a significant amount of compute to power the system

In this post, we highlight a collaboration between Outerbounds and AWS that takes a step towards addressing the last two challenges. First, the AWS Trainium accelerator provides a high-performance, cost-effective, and readily available solution for training and fine-tuning large models. Second, open source Metaflow provides the necessary software infrastructure to build production-grade ML/AI systems in a developer-friendly manner. It provides an approachable, robust Python API for the full infrastructure stack of ML/AI, from data and compute to workflows and observability.

In the following sections, we first introduce Metaflow and the Trainium integration. We then show how to set up the infrastructure stack you need to take your own data assets and pre-train or fine-tune a state-of-the-art Llama2 model on Trainium hardware.

Metaflow overview

Metaflow was originally developed at Netflix to enable data scientists and ML engineers to build ML/AI systems quickly and deploy them on production-grade infrastructure. Netflix open sourced the framework in 2019 with integrations to AWS services like AWS Batch, AWS Step Functions (see Unbundling Data Science Workflows with Metaflow and AWS Step Functions), Kubernetes, and throughput-optimized Amazon Simple Storage Service (Amazon S3), so you can build your own Netflix-scale ML/AI environment in your AWS account.

The key motivation of Metaflow is to address the typical needs of all ML/AI projects with a straightforward, human-centric API, from prototype to production (and back). The following figure illustrates this workflow.

Typical workflow with Metaflow and AWS Trainium

Metaflow’s coherent APIs simplify the process of building real-world ML/AI systems in teams. Metaflow helps scientists and engineers access, move, and manipulate data efficiently; track and version experiments and models; orchestrate and integrate workflows to surrounding systems; and scale compute to the cloud easily. Moreover, it has first-class support for teams, such as namespacing and deploying workflows in versioned production branches.

Now, with today’s announcement, you have another straightforward compute option for workflows that need to train or fine-tune demanding deep learning models: running them on Trainium.

How Metaflow integrates with Trainium

From a Metaflow developer’s perspective, using Trainium is similar to using other accelerators. After a Metaflow deployment is configured to access Trainium chips through the compute platform customers use with Metaflow (which we discuss later in this post), ML engineers and data scientists can operate autonomously in their deep learning code. Scientists can write PyTorch and Hugging Face code and use the AWS Neuron SDK along with the NeuronX Distributed library to target Trainium devices, while Metaflow integrates with the underlying AWS services and handles how the code actually runs at scale.

As illustrated by the following figure, you can declare the following in a few lines of Python code:

  • How many nodes to launch
  • How many Trainium devices to use per node
  • How the nodes are interconnected (Elastic Fabric Adapter)
  • How often to check the resource utilization
  • What training script the torchrun process should run on each node

Configuring a training job using a Metaflow FlowSpec

You can initialize the training process in the start step, which directs the next train step to run on two parallel instances (num_parallel=2). The decorators of the train step configure your desired training setup:

  • @torchrun – Sets up PyTorch Distributed across two instances
  • @batch – Configures the Trainium nodes, managed by AWS Batch
  • @neuron_monitor – Activates the monitoring UI that allows you to monitor the utilization of the Trainium cores

Metaflow allows you to configure all this functionality in a few lines of code. However, the main benefit is that you can embed Trainium-based training code inside a larger production system, using the scaffolding provided by Metaflow.
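For reference, here is a minimal sketch of what such a flow can look like in code. It assumes the decorator names shown in the preceding figure (@torchrun and @batch, with @neuron_monitor available for the monitoring UI), that @torchrun is provided by the metaflow-torchrun extension, and an illustrative training script name (train_llama2.py); see the repository accompanying this post for the exact arguments.

from metaflow import FlowSpec, step, batch, torchrun, current

class TrainiumTrainFlow(FlowSpec):

    @step
    def start(self):
        # Fan out the train step across two parallel Trainium nodes
        self.next(self.train, num_parallel=2)

    @batch(trainium=16, efa=8, image="YOUR_IMAGE_IN_ECR")  # 16 Trainium devices and 8 EFA interfaces per node
    @torchrun
    @step
    def train(self):
        # torchrun launches the (illustrative) training script on every node in the cluster;
        # @neuron_monitor can also be added here to enable the NeuronCore utilization UI
        current.torch.run(entrypoint="train_llama2.py")
        self.next(self.join)

    @step
    def join(self, inputs):
        # Gather results from the parallel training tasks
        self.next(self.end)

    @step
    def end(self):
        pass

if __name__ == "__main__":
    TrainiumTrainFlow()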

Benefits of using Trainium with Metaflow

Trainium and Metaflow work together to solve problems like what we discussed earlier in this post. The Trainium devices and Neuron software stack make it straightforward for teams to access and effectively use the high-performance hardware needed for cutting-edge AI.

Trainium provides a few key benefits for building real-world AI systems:

  • Trainium instances can help reduce generative AI model training and fine-tuning costs by up to 50% over comparable instances on AWS
  • Trainium instances are readily available in many AWS Regions, are often easier to obtain than GPU-based instance types, and can be scaled in the most popular Regions worldwide
  • The hardware and software are mature and actively developed by AWS

If you have been struggling with GPU availability and cost, you’ll surely appreciate these benefits. Using Trainium effectively can require a bit of infrastructure effort and knowledge, which is a key motivation for this integration. Through Metaflow and the deployment scripts provided in this post, you should be able to get started with Trainium with ease.

Besides easy access, using Trainium with Metaflow brings a few additional benefits:

Infrastructure accessibility

Metaflow is known for its developer-friendly APIs that allow ML/AI developers to focus on developing models and applications, and not worry about infrastructure. Metaflow helps engineers manage the infrastructure, making sure it integrates with existing systems and policies effortlessly.

Data, model, and configuration management

Metaflow provides built-in, seamless artifact persistence, tracking, and versioning, which covers the full state of the workflows, making sure you’ll follow MLOps best practices. Thanks to Metaflow’s high-throughput S3 client, you can load and save datasets and model checkpoints very quickly, without having to worry about extra infrastructure such as shared file systems. You can use artifacts to manage configuration, so everything from hyperparameters to cluster sizing can be managed in a single file, tracked alongside the results.
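To illustrate this pattern, the following is a small sketch (with hypothetical artifact and file names) that stores hyperparameters as versioned flow artifacts and persists a checkpoint through Metaflow’s high-throughput S3 client, scoped to the run’s datastore location.

from metaflow import FlowSpec, step, S3

class ConfigAndCheckpointFlow(FlowSpec):

    @step
    def start(self):
        # Values assigned to self become artifacts: versioned, tracked, and available to later steps
        self.config = {"learning_rate": 3e-4, "num_nodes": 2, "seq_len": 4096}
        self.next(self.train)

    @step
    def train(self):
        # ... training code would run here and write local_checkpoint.pt ...
        with S3(run=self) as s3:
            # put_files uploads local files into this run's datastore and returns their S3 locations
            self.uploaded_checkpoints = s3.put_files([("checkpoint.pt", "local_checkpoint.pt")])
        self.next(self.end)

    @step
    def end(self):
        pass

if __name__ == "__main__":
    ConfigAndCheckpointFlow()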

Observability

Metaflow comes with a convenient UI, which you can customize to observe metrics and data that matter to your use cases in real time. In the case of Trainium, we provide a custom visualization that allows you to monitor utilization of the NeuronCores inside Trainium instances, making sure that resources are used efficiently. The following screenshot shows an example of the visualization for core (top) and memory (bottom) utilization.

Visualizing NeuronCore and memory utilization

Multi-node compute

Finally, a huge benefit of Metaflow is that you can use it to manage advanced multi-instance training clusters, which would take a lot of involved engineering otherwise. For instance, you can train a large PyTorch model, sharded across Trainium instances, using Metaflow’s @torchrun and @batch decorators.

Behind the scenes, the decorators set up a training cluster using AWS Batch multi-node parallel jobs with a specified number of Trainium instances, configured to train a PyTorch model across the instances. By using the launch template we provide in this post, the setup benefits from low-latency, high-throughput networking via Elastic Fabric Adapter (EFA) network interfaces.

Solution overview

As a practical example, let’s set up the complete stack required to pre-train Llama2 for a few epochs on Trainium using Metaflow. The same recipe applies to the fine-tuning examples in the repository.

Deploy and configure Metaflow

If you already use a Metaflow deployment, you can skip to the next step to deploy the Trainium compute environment.

Deployment

To deploy a Metaflow stack using AWS CloudFormation, complete the following steps:

  1. Download the CloudFormation template.
  2. On the CloudFormation console, choose Stacks in the navigation pane.
  3. Choose Create new stack.
  4. For Prepare template, select Template is ready.
  5. For Template source, select Upload a template file.
  6. Upload the template.
  7. Choose Next.

Deploy Metaflow stack using CloudFormation

  8. If you are brand new to Metaflow, or are trying this recipe as a proof of concept, we suggest you change the APIBasicAuth parameter to false and leave all other default parameter settings.
  9. Complete the stack creation process.

Specify stack details

After you create the CloudFormation stack and configure Metaflow to use the stack resources, there is no additional setup required. For more information about the Metaflow components that AWS CloudFormation deploys, see AWS Managed with CloudFormation.

Configuration

To use the stack you just deployed from your laptop or cloud workstation, complete the following steps:

  1. Prepare a Python environment and install Metaflow in it:
pip install metaflow
  2. Run the Metaflow AWS configuration wizard in a terminal:
metaflow configure aws

After the CloudFormation stack deployment is complete, the Outputs on the stack details page will contain a list of resource names and their values, which you can use in the Metaflow AWS configuration prompts.

Deploy a Trainium compute environment

The default Metaflow deployment from the previous step has an AWS Batch compute environment, but it will not be able to schedule jobs to run on Amazon Elastic Compute Cloud (Amazon EC2) instances with Trainium devices. To deploy an AWS Batch compute environment for use with Trainium accelerators, you can use the following CloudFormation template. Complete the following steps:

  1. Download the CloudFormation template.
  2. On the CloudFormation console, choose Stacks in the navigation pane.
  3. Choose Create new stack.
  4. For Prepare template, select Template is ready.
  5. For Template source, select Upload a template file.
  6. Upload the template.
  7. Choose Next.
  8. Complete the stack creation process.

Take note of the name of the AWS Batch job queue that you created to use in a later step.

Prepare a base Docker image to run Metaflow tasks

Metaflow tasks run inside Docker containers when AWS Batch is used as a compute backend. To run Trainium jobs, developers need to build a custom image and specify it in the @batch decorator used to declare task resources:

@batch(trainium=16, efa=8, image="YOUR_IMAGE_IN_ECR")
@step
def train_llama2(self):
    # neuron distributed training code

To make the image, complete the following steps:

  1. Create an Amazon Elastic Container Registry (Amazon ECR) repository to store your image in.
  2. Create and log in to an EC2 instance with sufficient memory. For this post, we used Ubuntu x86 OS on a c5.4xlarge instance.
  3. Install Docker.
  4. Copy the following Dockerfile to your instance.
  5. Authenticate with the upstream base image provider:
aws ecr get-login-password \
  --region $REGION | docker login \
  --username AWS \
  --password-stdin 763104351884.dkr.ecr.$REGION.amazonaws.com
  6. Build the image:
docker build . -t $YOUR_IMAGE_NAME:$YOUR_IMAGE_TAG
  7. On the Amazon ECR console, navigate to the ECR repository you created, and you will find the commands needed to authenticate from the EC2 instance and push your image.

Clone the repository on your workstation

Now you’re ready to verify the infrastructure is working properly, after which you can run complex distributed training code like Llama2 training. To get started, clone the examples repository to the workstation where you configured Metaflow with AWS:

git clone https://github.com/outerbounds/metaflow-trainium

Verify the infrastructure with an allreduce example

To validate your infrastructure configuration, complete the following steps:

  1. Navigate to the allreduce example:
cd allreduce-trn
  2. Open the flow.py file and make sure to set the job queue and image to the name of the queue you deployed with AWS CloudFormation and the image you pushed to Amazon ECR, respectively.
  3. To run the allreduce code, run the following Metaflow command:
python flow.py --package-suffixes=.sh run

You can find the logs (truncated in the following code snippet for readability) in the Metaflow UI:

Task is starting (status SUBMITTED)...
Task is starting (status RUNNABLE)... (parallel node status: [SUBMITTED:3])
Task is starting (status STARTING)... (parallel node status: [SUBMITTED:3])
Task is starting (status RUNNING)... (parallel node status: [SUBMITTED:3])
Setting up task environment.
Downloading code package...
Code package downloaded.
Task is starting.
...
Compiler status PASS
result OK step 0: tensor([[64., 64., 64.],
[64., 64., 64.]], device='xla:1')
...
result OK step 900: tensor([[64., 64., 64.],
[64., 64., 64.]], device='xla:1')
Before final rendezvous
Waiting for batch secondary tasks to finish

Configure and run any Neuron distributed code

If the allreduce test runs successfully, you are ready to move on to meaningful workloads. To complete this onboarding, complete the following steps:

  1. Navigate to the llama2-7b-pretrain-trn directory.
  2. Similar to the allreduce example, before using this code, you need to modify the config.py file so that it matches the AWS Batch job queue and ECR image that you created. Open the file, find these lines, and modify them to your values:
class BatchJobConfig:
    # <snip>
    image: str = "YOUR_IMAGE"
    job_queue: str = "YOUR_QUEUE"
  3. After modifying these values, and any others you want to experiment with, run the following command:
python config.py
  4. Then run the workflow to pre-train your own Llama2 model from scratch:
python flow.py run --config-file config.yaml

This will train the model on however many nodes you specify in the config.py file, and will push the trained model result to Amazon S3 storage, versioned by Metaflow’s data store using the flow name and run ID.

Logs will appear like the following (truncated from a sample run of five steps for readability):

Task is starting (status SUBMITTED)...
Task is starting (status RUNNABLE)... (parallel node status: [SUBMITTED:3])
Task is starting (status STARTING)... (parallel node status: [SUBMITTED:3])
Task is starting (status RUNNING)... (parallel node status: [SUBMITTED:3])
Setting up task environment.
Downloading code package...
Code package downloaded.
Task is starting.
...
initializing tensor model parallel with size 8
initializing pipeline model parallel with size 1
initializing data parallel with size 16
...
Epoch 0 begin Fri Mar 15 21:19:10 2024
...
Compiler status PASS
...
(0, 3) step_loss : 15.4375 learning_rate : 3.00e-04 throughput : 4.38
(0, 4) step_loss : 12.1250 learning_rate : 1.50e-04 throughput : 5.47
(0, 5) step_loss : 11.8750 learning_rate : 0.00e+00 throughput : 6.44
...
Writing data to the provided results file: /metaflow/metaflow/metrics.json
...
Waiting for batch secondary tasks to finish
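Because results are versioned by flow name and run ID, you can retrieve them later from any environment configured against the same Metaflow deployment. The following sketch uses Metaflow’s Client API; the flow name and artifact name are illustrative and depend on how the flow in the repository stores its outputs.

from metaflow import Flow

# Look up the most recent successful run of the pre-training flow (names here are illustrative)
run = Flow("Llama2PretrainFlow").latest_successful_run
print("Run ID:", run.id)

# Artifacts assigned to self during the run are available via run.data
print("Trained model stored at:", run.data.model_checkpoint)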

Clean up

To clean up resources, delete the CloudFormation stacks for your Metaflow deployment and Trainium compute environment:

aws cloudformation delete-stack --stack-name metaflow
aws cloudformation delete-stack --stack-name trn1-batch

Conclusion

You can get started experimenting with the solution presented in this post in your environment today. Follow the instructions in the GitHub repository to pre-train a Llama2 model on Trainium devices. Additionally, we have prepared examples for fine-tuning Llama2 and BERT models, demonstrating how you can use the Optimum Neuron package to use the integration from this post with any Hugging Face model.

We are happy to help you get started. Join the Metaflow community Slack for support, to provide feedback, and share experiences!


About the authors

Ville Tuulos is a co-founder and CEO of Outerbounds, a developer-friendly ML/AI platform. He has been developing infrastructure for ML and AI for over two decades in academia and as a leader at a number of companies. At Netflix, he led the ML infrastructure team that created Metaflow, a popular open-source, human-centric foundation for ML/AI systems. He is also the author of a book, Effective Data Science Infrastructure, published by Manning.

Eddie Mattia has a background in scientific computing and more recently has been building machine learning developer tools. He has worked as a researcher in academia, in customer-facing and engineering roles at MLOps startups, and as a product manager at Intel. Currently, Eddie is working to improve the open-source Metaflow project and is building tools for AI researchers and MLOps developers at Outerbounds.

Vidyasagar specializes in high performance computing, numerical simulations, optimization techniques and software development across industrial and academic environments. At AWS, Vidyasagar is a Senior Solutions Architect developing predictive models, generative AI and simulation technologies. Vidyasagar has a PhD from the California Institute of Technology.

Diwakar Bansal is an AWS Senior Specialist focused on business development and go-to-market for generative AI and machine learning accelerated computing services. Diwakar has led product definition, global business development, and marketing of technology products in the fields of IoT, edge computing, and autonomous driving, focusing on bringing AI and machine learning to these domains. Diwakar is passionate about public speaking and thought leadership in the cloud and generative AI space.

Sadaf Rasool is a Machine Learning Engineer with the Annapurna ML Accelerator team at AWS. As an enthusiastic and optimistic AI/ML professional, he holds firm to the belief that the ethical and responsible application of AI has the potential to enhance society in the years to come, fostering both economic growth and social well-being.

Scott Perry is a Solutions Architect on the Annapurna ML accelerator team at AWS. Based in Canada, he helps customers deploy and optimize deep learning training and inference workloads using AWS Inferentia and AWS Trainium. His interests include large language models, deep reinforcement learning, IoT, and genomics.

Cohere Command R and R+ are now available in Amazon SageMaker JumpStart

This blog post is co-written with Pradeep Prabhakaran from Cohere. 

Today, we are excited to announce that Cohere Command R and R+ foundation models are available through Amazon SageMaker JumpStart to deploy and run inference. Command R/R+ are the state-of-the-art retrieval augmented generation (RAG)-optimized models designed to tackle enterprise-grade workloads.

In this post, we walk through how to discover and deploy Cohere Command R/R+ via SageMaker JumpStart.

What are Cohere Command R and Command R+?

Cohere Command R is a family of highly scalable language models that balance high performance with strong accuracy. The Command R family, which includes the Command R and Command R+ models, is optimized for RAG-based workflows such as conversational interaction and long-context tasks, enabling companies to move beyond proof of concept and into production. These powerful models are designed to handle complex tasks with high performance and strong accuracy, making them suitable for real-world applications.

Command R boasts high precision on RAG and tool use tasks, low latency and high throughput, a long 128,000-token context length, and strong capabilities across 10 key languages: English, French, Spanish, Italian, German, Portuguese, Japanese, Korean, Arabic, and Chinese.

Command R+ is the newest model, optimized for extremely performant conversational interaction and long-context tasks. It is recommended for workflows that lean on complex RAG functionality and multi-step tool use (agents), while Command R is well-suited for simpler RAG and single-step tool use tasks, as well as applications where price is a major consideration.

What is SageMaker JumpStart?

With SageMaker JumpStart, you can choose from a broad selection of publicly available foundation models. ML practitioners can deploy foundation models to dedicated SageMaker instances from a network-isolated environment and customize models using SageMaker for model training and deployment. You can now discover and deploy Cohere Command R/R+ models with a few clicks in Amazon SageMaker Studio or programmatically through the SageMaker Python SDK. Doing so enables you to derive model performance and machine learning operations (MLOps) controls with SageMaker features such as SageMaker Pipelines, SageMaker Debugger, or container logs.

The model is deployed in an AWS secure environment and under your virtual private cloud (VPC) controls, helping provide data security. Cohere Command R/R+ models are available today for deployment and inferencing in Amazon SageMaker Studio in us-east-1 (N. Virginia), us-east-2 (Ohio), us-west-1 (N. California), us-west-2 (Oregon), Canada (Central), eu-central-1 (Frankfurt), eu-west-1 (Ireland), eu-west-2 (London), eu-west-3 (Paris), eu-north-1 (Stockholm), ap-southeast-1 (Singapore), ap-southeast-2 (Sydney), ap-northeast-1 (Tokyo), ap-northeast-2 (Seoul), ap-south-1 (Mumbai), and sa-east-1 (Sao Paulo).

Discover models

You can access the foundation models through SageMaker JumpStart in the SageMaker Studio UI and the SageMaker Python SDK. In this section, we go over how to discover the models in SageMaker Studio.

From the SageMaker JumpStart landing page, you can easily discover various models by browsing through different hubs, which are named after model providers. The Cohere Command R and R+ models are available in the Cohere hub. If you don’t see these models, ensure you have the latest SageMaker Studio version by shutting down and restarting Studio Classic Apps.

To find the Command R and R+ models, search for “Command R” in the search box located at the top left of the SageMaker JumpStart landing page. Each model can be deployed on Amazon Elastic Compute Cloud (Amazon EC2) P5 instances powered by NVIDIA H100 Tensor Core GPUs (p5.48xlarge) and Amazon EC2 P4de instances powered by NVIDIA A100 Tensor Core GPUs (ml.p4de.24xlarge).

Deploy a model

To illustrate model deployment, we’ll deploy Cohere Command R+ on NVIDIA H100. Choose the model card to open the corresponding model detail page.

When you choose Deploy, a window appears prompting you to subscribe to the model on AWS Marketplace. Choose Subscribe, which redirects you to the AWS Marketplace listing for Cohere Command R+ (H100). Follow the on-screen instructions to complete the subscription process.

Once subscribed, return to the model detail page and choose Deploy in the window. The deployment process initiates.

Alternatively, you can choose Notebooks on the model card and open the example notebook in JupyterLab. This notebook provides end-to-end guidance on deploying the model for inference and cleaning up resources. You can also find this example notebook in the Cohere SageMaker GitHub repository. To ensure the security of the endpoint, you can configure an AWS Key Management Service (AWS KMS) key for a SageMaker endpoint configuration.

If an endpoint has already been created, you can simply connect to it:

co = Client(region_name=region)

co.connect_to_endpoint(endpoint_name="cohere-command-r-plus")

Real-time inference

Once you have connected to your endpoint, you can perform real-time inference using co.chat.

message = "Write a LinkedIn post about starting a career in tech:"
response = co.chat(message=message, stream=False)

Multilingual capabilities

Command R/R+ is optimized to perform well in 10 key languages, as listed in the introduction. Additionally, pre-training data was included for the following 13 languages: Russian, Polish, Turkish, Vietnamese, Dutch, Czech, Indonesian, Ukrainian, Romanian, Greek, Hindi, Hebrew, and Persian.

The model has been trained to respond in the language of the user. Here’s an example in French:

co.chat(
  message="Écris une description de produit pour une voiture électrique en 50 à 75 mots"
)

Here’s what the response might look like:

Découvrez la voiture électrique qui va révolutionner votre façon de conduire.
Avec son design élégant, cette voiture offre une expérience de conduit unique avec une accélération puissante et une autonomie impressionnante. Sa technologie avancée vous garantit une charge rapide et une fiabilité inégalée. Avec sa conception innovante et durable, cette voiture est parfaite pour les trajets urbains et les longues distances. Profitez d'une conduite silencieuse et vivez l'expérience de la voiture électrique!

Command R/R+ can also perform cross-lingual tasks, such as translation or answering questions about content in other languages.
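For instance, reusing the same co.chat call shown earlier, a cross-lingual request might look like the following (the prompt text is illustrative):

# Ask the model to translate English text into German using the same chat API as before
response = co.chat(
    message="Translate the following sentence to German: 'Our flight departs at 9 AM tomorrow.'",
    stream=False,
)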

Chat with documents (RAG)

Command R/R+ can ground its generations. This means that it can generate responses based on a list of supplied document snippets, and it includes citations in its response indicating the source of the information.

For example, the code snippet that follows produces an answer to “How deep is the Mariana Trench” along with inline citations based on the provided documents.

Request:

message="How deep is the Mariana Trench"
documents = [
    {
       "id": "national_geographic_everest",
       "title": "Height of Mount Everest",
       "snippet": "The height of Mount Everest is 29,035 feet",
       "url": "https://education.nationalgeographic.org/resource/mount-everest/",
    },
    {
        "id": "national_geographic_mariana",
        "title": "Depth of the Mariana Trench",
        "snippet": "The depth of the Mariana Trench is 36,070 feet",
        "url": "https://www.nationalgeographic.org/activity/mariana-trench-deepest-place-earth",
    }
]

response = co.chat(message=message, documents=documents, stream=False)

Response:

{
   text: "The depth of the Mariana Trench is 36,070 feet.",
   citations: [
      {'start': 35, 'end': 47, 'text': '36,070 feet.', 'document_ids': ['national_geographic_mariana']}
   ],
   documents: [
      {'id': 'national_geographic_mariana',
       'snippet': 'The depth of the Mariana Trench is 36,070 feet',
       'title': 'Depth of the Mariana Trench',
       'url': 'https://www.nationalgeographic.org/activity/mariana-trench-deepest-place-earth'}
   ]
}

Single-step and multi-step tool use

Command R/R+ comes with a Tool Use API that enables the language model to interact with user-defined tools to automate highly sophisticated tasks. Command R/R+ in Tool Use mode creates API payloads (JSONs with specific parameters) based on user interactions and conversational history. These can be used to instruct any other application or tool.

For example, an application can be instructed to automatically categorize and route support tickets to the appropriate individual, change a status in customer relationship management (CRM) software, or retrieve relevant snippets from a vector database. Tool use comes in two variants, single-step and multi-step:

  • Single-step tool use enables a richer set of behaviors by leveraging data stored in tools, taking actions through APIs, interacting with a vector database, querying a search engine, etc.
  • Multi-step tool use is an extension of this basic idea and allows the model to call more than one tool in a sequence of steps, using the results from one tool call in a subsequent step. This process allows the language model to reason, perform dynamic actions, and quickly adapt based on information coming from external sources.

To explore these capabilities further, you can refer to the provided Jupyter notebook and Cohere’s AWS GitHub repository, which offer additional examples showcasing various use cases and applications.
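As a rough illustration of the single-step variant, the following sketch shows how a tool schema might be passed to the model. It assumes the SageMaker client mirrors Cohere’s Chat API tool-use parameters (tools, parameter_definitions, and tool_calls on the response); the tool name and fields are hypothetical, so consult the notebook and repository above for the exact interface.

# Hypothetical tool definition; the model decides whether to call it and with which parameters
tools = [
    {
        "name": "query_support_tickets",
        "description": "Retrieves open support tickets for a given product area.",
        "parameter_definitions": {
            "product_area": {
                "description": "The product area to filter tickets by.",
                "type": "str",
                "required": True,
            }
        },
    }
]

response = co.chat(
    message="Show me the open tickets for the billing system.",
    tools=tools,
    stream=False,
)

# The response is expected to contain the structured API payloads (tool calls) the model wants to make
print(response.tool_calls)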

Clean up

After you’ve finished running the notebook and exploring the Cohere Command R and R+ models, it’s essential to clean up the resources you’ve created to avoid incurring unnecessary charges. Run the following commands to delete the endpoint and stop billing:

co.delete_endpoint()
co.close()

Conclusion

In this post, we explored how to leverage the powerful capabilities of Cohere’s Command R and R+ models on Amazon SageMaker JumpStart. These state-of-the-art large language models are specifically designed to excel at real-world enterprise use cases, offering unparalleled performance and scalability. With their availability on SageMaker JumpStart and AWS Marketplace, you now have seamless access to these cutting-edge models, enabling you to unlock new levels of productivity and innovation in your natural language processing projects.


About the authors

Pradeep Prabhakaran is a Customer Solutions Architect at Cohere. In his current role at Cohere, Pradeep acts as a trusted technical advisor to customers and partners, providing guidance and strategies to help them realize the full potential of Cohere’s cutting-edge Generative AI platform. Prior to joining Cohere, Pradeep was a Principal Customer Solutions Manager at Amazon Web Services, where he led Enterprise Cloud transformation programs for large enterprises. Prior to AWS, Pradeep has held various leadership positions at consulting companies such as Slalom, Deloitte, and Wipro. Pradeep holds a Bachelor’s degree in Engineering and is based in Dallas, TX.

James Yi is a Senior AI/ML Partner Solutions Architect at Amazon Web Services. He spearheads AWS’s strategic partnerships in Emerging Technologies, guiding engineering teams to design and develop cutting-edge joint solutions in GenAI. He enables field and technical teams to seamlessly deploy, operate, secure, and integrate partner solutions on AWS. James collaborates closely with business leaders to define and execute joint Go-To-Market strategies, driving cloud-based business growth. Outside of work, he enjoys playing soccer, traveling, and spending time with his family.

Revolutionizing large language model training with Arcee and AWS Trainium

This is a guest post by Mark McQuade, Malikeh Ehghaghi, and Shamane Siri from Arcee.

In recent years, large language models (LLMs) have gained attention for their effectiveness, leading various industries to adapt general LLMs to their own data for improved results, which makes efficient training and hardware availability crucial. At Arcee, we focus primarily on enhancing the domain adaptation of LLMs in a client-centric manner. Arcee’s innovative continual pre-training (CPT) and model merging techniques have brought a significant leap forward in the efficient training of LLMs, with particularly strong evaluations in the medical, legal, and financial verticals. Close collaboration with the AWS Trainium team has also played a major role in making the Arcee platform extremely performant, not only accelerating model training but also reducing overall costs and enforcing compliance and data integrity in the secure AWS environment. In this post, we show you how we make our continual pre-training efficient by using Trainium chips.

Understanding continual pre-training

Arcee recognizes the critical importance of CPT [1] in tailoring models to specific domains, as evidenced by previous studies such as PMC-LLaMA [2] and ChipNeMo [3]. These projects showcase the power of domain adaptation pre-training in enhancing model performance across diverse fields, from medical applications to industrial chip design. Inspired by these endeavors, our approach to CPT involves extending the training of base models like Llama 2 using domain-specific datasets, allowing us to fine-tune models to the nuances of specialized fields. To further amplify the efficiency of our CPT process, we collaborated with the Trainium team, using their cutting-edge technology to enhance a Llama 2 [4] model using a PubMed dataset [2] comprising 88 billion tokens. This collaboration represents a significant milestone in our quest for innovation, and through this post, we’re excited to share the transformative insights we’ve gained. Join us as we unveil the future of domain-specific model adaptation and the potential of CPT with Trainium in optimizing model performance for real-world applications.

Dataset collection

We followed the methodology outlined in the PMC-Llama paper [6] to assemble our dataset, which includes PubMed papers sourced from the Semantic Scholar API and various medical texts cited within the paper, culminating in a comprehensive collection of 88 billion tokens. For further details on the dataset, the original paper offers in-depth information.

To prepare this dataset for training, we used the Llama 2 tokenizer within an AWS Glue pipeline for efficient processing. We then organized the data so that each row contained 4,096 tokens, adhering to recommendations from the Neuron Distributed tutorials.
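As a rough sketch of that packing step (outside of AWS Glue, with placeholder documents, and assuming access to the gated Llama 2 tokenizer on Hugging Face), the text can be tokenized and concatenated into fixed 4,096-token rows roughly as follows:

from transformers import AutoTokenizer

SEQ_LEN = 4096
# Requires access to the gated meta-llama repository on Hugging Face
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

def pack_documents(documents, seq_len=SEQ_LEN):
    """Tokenize raw documents and pack the resulting token stream into fixed-length rows."""
    buffer, rows = [], []
    for text in documents:
        buffer.extend(tokenizer(text)["input_ids"])
        while len(buffer) >= seq_len:
            rows.append(buffer[:seq_len])
            buffer = buffer[seq_len:]
    return rows  # each row holds exactly seq_len token IDs

# Example usage with placeholder documents
rows = pack_documents(["First PubMed abstract ...", "Second PubMed abstract ..."])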

Why Trainium?

Continual pre-training techniques like the ones described in this post require access to high-performance compute instances, which have become more difficult to obtain as more developers use generative artificial intelligence (AI) and LLMs for their applications. Traditionally, these workloads have been deployed to GPUs; however, in recent years, the cost and availability of GPUs have stifled model building innovations. With the introduction of Trainium, we are able to unlock new techniques that let us continue innovating and build models more efficiently and, most importantly, at lower cost. Trainium is the second-generation machine learning (ML) accelerator that AWS purpose built to help developers access high-performance model training accelerators and lower training costs by up to 50% over comparable Amazon Elastic Compute Cloud (Amazon EC2) instances. With Trainium available in AWS Regions worldwide, developers don’t have to take expensive, long-term compute reservations just to get access to clusters of GPUs to build their models. Trainium instances offer developers the performance they need with the elasticity they want, to optimize both for training efficiency and for lowering model building costs.

Setting up the Trainium cluster

We used AWS ParallelCluster to build a High Performance Computing (HPC) compute environment that uses Trn1 compute nodes to run our distributed ML training job (see the GitHub tutorial). You can also use developer flows like Amazon SageMaker, Amazon Elastic Kubernetes Service (Amazon EKS), Ray, or others (to learn more, see Developer Flows). After the nodes were launched, we ran a training task to confirm that the nodes were working, and used Slurm commands to check the job status. In this part, we used the ParallelCluster pcluster command with a .yaml configuration file to generate the cluster. Our cluster consisted of 16 nodes, each equipped with a trn1n.32xlarge instance featuring 32 GB of VRAM.

We set up our ParallelCluster infrastructure as shown in the following diagram (source).

As shown in the preceding figure, inside a VPC, there are two subnets, a public one and a private one. The head node resides in the public subnet, and the compute fleet (in this case, Trn1 instances) is in the private subnet. A NAT gateway is also needed in order for nodes in the private subnet to connect to clients outside the VPC. In the following section, we describe how to set up the necessary infrastructure for Trn1 ParallelCluster.

Set up the environment

To set up your environment, complete the following steps:

  1. Install the VPC and necessary components for ParallelCluster. For instructions, see VPC setup for ParallelCluster with Trn1.
  2. Create and launch ParallelCluster in the VPC. For instructions, see Create ParallelCluster.

Now you can launch a training job to submit a model training script as a slurm job.

Deploy to Trainium

Trainium-based EC2 Trn1 instances use the AWS Neuron SDK and support common ML frameworks like PyTorch and TensorFlow. Neuron allows for effortless distributed training and has integrations with NeMo Megatron and Neuron Distributed.

When engaging with Trainium, it’s crucial to understand several key parameters:

  • Tensor parallel size – This determines the level of tensor parallelization, particularly in self-attention computations within transformers, and is crucial for optimizing memory usage (not computational time efficiency) during model loading
  • NeuronCores – Each Trainium device has two NeuronCores, and an eight-node setup equates to a substantial 256 cores
  • Mini batch – This reflects the number of examples processed in each batch as determined by the data loader
  • World size – This is the total count of nodes involved in the training operation

A deep understanding of these parameters is vital for anyone looking to harness the power of Trainium devices effectively.
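To make the relationship between these parameters concrete, here is a small worked computation for a cluster like the one described earlier (16 trn1n.32xlarge nodes with 32 NeuronCores each), assuming a tensor parallel size of 8, no pipeline parallelism, and a mini batch of 1 per replica; the values actually used in training are set in the repository’s configuration.

# Worked example: how cluster size translates into degrees of parallelism (illustrative values)
nodes = 16                      # trn1n.32xlarge instances in the cluster
neuron_cores_per_node = 32      # 16 Trainium devices per node x 2 NeuronCores each
world_size = nodes              # total count of nodes involved in training

total_cores = nodes * neuron_cores_per_node       # 512 NeuronCores
tensor_parallel_size = 8                          # cores cooperating on one model shard
pipeline_parallel_size = 1                        # assumed: no pipeline parallelism
data_parallel_size = total_cores // (tensor_parallel_size * pipeline_parallel_size)  # 64 replicas

mini_batch = 1                                    # examples per replica per step (assumed)
global_batch_size = mini_batch * data_parallel_size  # 64 sequences per optimizer step
print(total_cores, data_parallel_size, global_batch_size)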

Train the model

For this post, we train a Llama 2 7B model with tensor parallelism. For a streamlined and effective training process, we adhered to the following steps:

  1. Download the Llama 2 full checkpoints (model weights and tokenizer) from Hugging Face.
  2. Convert these checkpoints to a format compatible with the Neuron Distributed setup, so they can be efficiently utilized in our training infrastructure.
  3. Determine the number of steps required per epoch, incorporating the effective batch size and dataset size to tailor the training process to our specific needs.
  4. Launch the training job, carefully monitoring its progress and performance.
  5. Periodically save training checkpoints. Initially, this process may be slow due to its synchronous nature, but improvements are anticipated as the NeuronX team works on enhancements.
  6. Finally, convert the saved checkpoints back to a standard format for subsequent use, employing scripts for seamless conversion.

For more details, you can find the full implementation of the training steps in the following GitHub repository.
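Regarding step 3 in the preceding list, the arithmetic is straightforward; the following sketch derives the number of steps per epoch from the dataset size, sequence length, and effective (global) batch size, using illustrative values rather than the exact run settings.

# Illustrative computation of training steps per epoch
dataset_tokens = 88_000_000_000           # ~88 billion tokens in the PubMed dataset
seq_len = 4096                            # tokens per packed row
global_batch_size = 256                   # sequences per optimizer step (assumed)

sequences = dataset_tokens // seq_len             # ~21.5 million rows of 4,096 tokens
steps_per_epoch = sequences // global_batch_size  # optimizer steps to see the data once
print(f"{sequences:,} sequences -> {steps_per_epoch:,} steps per epoch")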

Clean up

Don’t forget to tear down any resources you set up in this post.

Results

Our study focused on evaluating the quality of the CPT-enhanced checkpoints. We monitored the perplexity of a held-out PubMed dataset [6] across various checkpoints obtained during training, which provided valuable insights into the model’s performance improvements over time.

Through this journey, we’ve advanced our model’s capabilities, and hope to contribute to the broader community’s understanding of effective model adaptation strategies.

The following figure shows the perplexity of the baseline Llama 2 7B checkpoint vs. its CPT-enhanced checkpoint on the PMC test dataset. Based on these findings, continual pre-training on domain-specific raw data, specifically PubMed papers in our study, resulted in an enhancement of the Llama 2 7B checkpoint, leading to improved perplexity of the model on the PMC test set.

The following figure shows the perplexity of the CPT-enhanced checkpoints of the Llama 2 7B model across varying numbers of trained tokens. The increasing number of trained tokens correlated with enhanced model performance, as measured by the perplexity metric.

The following figure shows the perplexity comparison between the baseline Llama 2 7B model and its CPT-enhanced checkpoints, with and without data mixing. This underscores the significance of data mixing, where we added 1% of general tokens to the domain-specific dataset: the CPT-enhanced checkpoint with data mixing performed better than both the baseline Llama 2 7B model and the CPT-enhanced checkpoint trained solely on PubMed data.

Conclusion

Arcee’s innovative approach to CPT and model merging, as demonstrated through our collaboration with Trainium, signifies a transformative advancement in the training of LLMs, particularly in specialized domains such as medical research. By using the extensive capabilities of Trainium, we have not only accelerated the model training process, but also significantly reduced costs, with an emphasis on security and compliance that provides data integrity within a secure AWS environment.

The results from our training experiments, as seen in the improved perplexity scores of domain-specific models, underscore the effectiveness of our method in enhancing the performance and applicability of LLMs across various fields. This is particularly evident from the direct comparisons of time-to-train metrics between Trainium and traditional GPU setups, where Trainium’s efficiency and cost-effectiveness shine.

Furthermore, our case study using PubMed data for domain-specific training highlights the potential of Arcee’s CPT strategies to fine-tune models to the nuances of highly specialized datasets, thereby creating more accurate and reliable tools for professionals in those fields.

As we continue to push the boundaries of what’s possible in LLM training, we encourage researchers, developers, and enterprises to take advantage of the scalability, efficiency, and enhanced security features of Trainium and Arcee’s methodologies. These technologies not only facilitate more effective model training, but also open up new avenues for innovation and practical application in AI-driven industries.

The integration of Trainium’s advanced ML capabilities with Arcee’s pioneering strategies in model training and adaptation is poised to revolutionize the landscape of LLM development, making it more accessible, economical, and tailored to meet the evolving demands of diverse industries.

To learn more about Arcee.ai, visit Arcee.ai or reach out to our team.

Additional resources

References

  1. Gupta, Kshitij, et al. “Continual Pre-Training of Large Language Models: How to (re)warm your model?” arXiv preprint arXiv:2308.04014 (2023).
  2. Wu, Chaoyi, et al. “PMC-LLaMA: Towards building open-source language models for medicine.” arXiv preprint arXiv:2305.10415 (2023).
  3. Liu, Mingjie, et al. “ChipNeMo: Domain-adapted LLMs for chip design.” arXiv preprint arXiv:2311.00176 (2023).
  4. Touvron, Hugo, et al. “Llama 2: Open foundation and fine-tuned chat models.” arXiv preprint arXiv:2307.09288 (2023).
  5. AWS Trainium (Trn1) instances: https://aws.amazon.com/ec2/instance-types/trn1/
  6. Wu, Chaoyi, et al. “PMC-LLaMA: Further finetuning LLaMA on medical papers.” arXiv preprint arXiv:2304.14454 (2023).

About the Authors

Mark McQuade is the CEO/Co-Founder at Arcee. Mark co-founded Arcee with a vision to empower enterprises with industry-specific AI solutions. This idea emerged from his time at Hugging Face, where he helped spearhead the Monetization team, collaborating with high-profile enterprises. This frontline experience exposed him to critical industry pain points: the reluctance to rely on closed source APIs and the challenges of training open source models without compromising data security.

Shamane Siri Ph.D. is the Head of Applied NLP Research at Arcee. Before joining Arcee, Shamane worked in both industry and academia, developing recommendation systems using language models to address the cold start problem, and focusing on information retrieval, multi-modal emotion recognition, and summarization. Shamane has also collaborated with the Hugging Face Transformers crew and Meta Reality Labs on cutting-edge projects. He holds a PhD from the University of Auckland, where he specialized in domain adaptation of foundational language models.

Malikeh Ehghaghi is an Applied NLP Research Engineer at Arcee. Malikeh’s research interests are NLP, domain-adaptation of LLMs, ML for healthcare, and responsible AI. She earned an MScAC degree in Computer Science from the University of Toronto. She previously collaborated with Lavita AI as a Machine Learning Consultant, developing healthcare chatbots in partnership with Dartmouth Center for Precision Health and Artificial Intelligence. She also worked as a Machine Learning Research Scientist at Cambridge Cognition Inc. and Winterlight Labs, with a focus on monitoring and detection of mental health disorders through speech and language. Malikeh has authored several publications presented at top-tier conferences such as ACL, COLING, AAAI, NAACL, IEEE-BHI, and MICCAI.

Databricks DBRX is now available in Amazon SageMaker JumpStart

Today, we are excited to announce that the DBRX model, an open, general-purpose large language model (LLM) developed by Databricks, is available for customers through Amazon SageMaker JumpStart to deploy with one click for running inference. The DBRX LLM employs a fine-grained mixture-of-experts (MoE) architecture, pre-trained on 12 trillion tokens of carefully curated data, with a maximum context length of 32,000 tokens.

You can try out this model with SageMaker JumpStart, a machine learning (ML) hub that provides access to algorithms and models so you can quickly get started with ML. In this post, we walk through how to discover and deploy the DBRX model.

What is the DBRX model

DBRX is a sophisticated decoder-only LLM built on transformer architecture. It employs a fine-grained MoE architecture, incorporating 132 billion total parameters, with 36 billion of these parameters being active for any given input.

The model underwent pre-training using a dataset consisting of 12 trillion tokens of text and code. In contrast to other open MoE models like Mixtral and Grok-1, DBRX features a fine-grained approach, using a larger number of smaller experts for optimized performance. DBRX has 16 experts and chooses 4, whereas Mixtral and Grok-1 have 8 experts and choose 2.

The model is made available under the Databricks Open Model license, for use without restrictions.

What is SageMaker JumpStart

SageMaker JumpStart is a fully managed platform that offers state-of-the-art foundation models for various use cases such as content writing, code generation, question answering, copywriting, summarization, classification, and information retrieval. It provides a collection of pre-trained models that you can deploy quickly and with ease, accelerating the development and deployment of ML applications. One of the key components of SageMaker JumpStart is the Model Hub, which offers a vast catalog of pre-trained models, such as DBRX, for a variety of tasks.

You can now discover and deploy DBRX models with a few clicks in Amazon SageMaker Studio or programmatically through the SageMaker Python SDK, enabling you to derive model performance and MLOps controls with Amazon SageMaker features such as Amazon SageMaker Pipelines, Amazon SageMaker Debugger, or container logs. The model is deployed in an AWS secure environment and under your VPC controls, helping provide data security.

Discover models in SageMaker JumpStart

You can access the DBRX model through SageMaker JumpStart in the SageMaker Studio UI and the SageMaker Python SDK. In this section, we go over how to discover the models in SageMaker Studio.

SageMaker Studio is an integrated development environment (IDE) that provides a single web-based visual interface where you can access purpose-built tools to perform all ML development steps, from preparing data to building, training, and deploying your ML models. For more details on how to get started and set up SageMaker Studio, refer to Amazon SageMaker Studio.

In SageMaker Studio, you can access SageMaker JumpStart by choosing JumpStart in the navigation pane.

From the SageMaker JumpStart landing page, you can search for “DBRX” in the search box. The search results will list DBRX Instruct and DBRX Base.

You can choose the model card to view details about the model such as license, data used to train, and how to use the model. You will also find the Deploy button to deploy the model and create an endpoint.

Deploy the model in SageMaker JumpStart

Deployment starts when you choose the Deploy button. After deployment finishes, you will see that an endpoint is created. You can test the endpoint by passing a sample inference request payload or by selecting the testing option using the SDK. When you select the option to use the SDK, you will see example code that you can use in the notebook editor of your choice in SageMaker Studio.

DBRX Base

To deploy using the SDK, we start by selecting the DBRX Base model, specified by the model_id with value huggingface-llm-dbrx-base. You can deploy any of the selected models on SageMaker with the following code. Similarly, you can deploy DBRX Instruct using its own model ID.

from sagemaker.jumpstart.model import JumpStartModel

accept_eula = True

model = JumpStartModel(model_id="huggingface-llm-dbrx-base")
predictor = model.deploy(accept_eula=accept_eula)

This deploys the model on SageMaker with default configurations, including the default instance type and default VPC configurations. You can change these configurations by specifying non-default values in JumpStartModel. The accept_eula value must be explicitly set to True in order to accept the end-user license agreement (EULA). Also make sure you have the account-level service quota for using ml.p4d.24xlarge or ml.p4de.24xlarge for endpoint usage as one or more instances. You can follow the instructions here in order to request a service quota increase.
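For example, a hedged sketch of overriding the default instance type might look like the following; the instance type shown is one of the supported options mentioned above, and the exact keyword arguments accepted by deploy are documented in the SageMaker Python SDK.

from sagemaker.jumpstart.model import JumpStartModel

model = JumpStartModel(model_id="huggingface-llm-dbrx-base")

# Override the default instance type at deployment time (assumes the service quota mentioned above is in place)
predictor = model.deploy(
    accept_eula=True,
    instance_type="ml.p4d.24xlarge",
)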

After it’s deployed, you can run inference against the deployed endpoint through the SageMaker predictor:

payload = {
    "inputs": "Hello!",
    "parameters": {
        "max_new_tokens": 10,
    },
}
predictor.predict(payload)

Example prompts

You can interact with the DBRX Base model like any standard text generation model, where the model processes an input sequence and outputs predicted next words in the sequence. In this section, we provide some example prompts and sample output.

Code generation

Using the preceding example, we can use code generation prompts as follows:

payload = {
    "inputs": "Write a function to read a CSV file in Python using pandas library:",
    "parameters": {
        "max_new_tokens": 30,
    },
}
response = predictor.predict(payload)["generated_text"].strip()
print(response)

The following is the output:

import pandas as pd 
df = pd.read_csv("file_name.csv") 
#The above code will import pandas library and then read the CSV file using read_csv

Sentiment analysis

You can perform sentiment analysis using a prompt like the following with DBRX:

payload = {
"inputs": """
Tweet: "I am so excited for the weekend!"
Sentiment: Positive

Tweet: "Why does traffic have to be so terrible?"
Sentiment: Negative

Tweet: "Just saw a great movie, would recommend it."
Sentiment: Positive

Tweet: "According to the weather report, it will be cloudy today."
Sentiment: Neutral

Tweet: "This restaurant is absolutely terrible."
Sentiment: Negative

Tweet: "I love spending time with my family."
Sentiment:""",
"parameters": {
"max_new_tokens": 2,
},
}
response = predictor.predict(payload)["generated_text"].strip()
print(response)

The following is the output:

Positive

Question answering

You can use a question answering prompt like the following with DBRX:

# Question answering
payload = {
    "inputs": "Respond to the question: How did the development of transportation systems, such as railroads and steamships, impact global trade and cultural exchange?",
    "parameters": {
        "max_new_tokens": 225,
    },
}
response = predictor.predict(payload)["generated_text"].strip()
print(response)

The following is the output:

The development of transportation systems, such as railroads and steamships, impacted global trade and cultural exchange in a number of ways. 
The documents provided show that the development of these systems had a profound effect on the way people and goods were able to move around the world. 
One of the most significant impacts of the development of transportation systems was the way it facilitated global trade. 
The documents show that the development of railroads and steamships made it possible for goods to be transported more quickly and efficiently than ever before. 
This allowed for a greater exchange of goods between different parts of the world, which in turn led to a greater exchange of ideas and cultures. 
Another impact of the development of transportation systems was the way it facilitated cultural exchange. The documents show that the development of railroads and steamships made it possible for people to travel more easily and quickly than ever before. 
This allowed for a greater exchange of ideas and cultures between different parts of the world. Overall, the development of transportation systems, such as railroads and steamships, had a profound impact on global trade and cultural exchange.

 

DBRX Instruct

The instruction-tuned version of DBRX accepts formatted instructions where conversation roles must start with a prompt from the user and alternate between user instructions and the assistant (DBRX-instruct). The instruction format must be strictly respected, otherwise the model will generate suboptimal outputs. The template to build a prompt for the Instruct model is defined as follows:

<|im_start|>system
{system_message} <|im_end|>
<|im_start|>user
{human_message} <|im_end|>
<|im_start|>assistant

<|im_start|> and <|im_end|> are special tokens for beginning of string (BOS) and end of string (EOS). The model can contain multiple conversation turns between system, user, and assistant, allowing for the incorporation of few-shot examples to enhance the model’s responses.

The following code shows how you can format the prompt in instruction format:

from typing import Dict, List

def format_instructions(instructions: List[Dict[str, str]]) -> str:
    """Format instructions where conversation roles must alternate system/user/assistant/user/assistant/..."""
    prompt: List[str] = []
    for instruction in instructions:
        if instruction["role"] == "system":
            prompt.extend(["<|im_start|>system\n", (instruction["content"]).strip(), " <|im_end|>\n"])
        elif instruction["role"] == "user":
            prompt.extend(["<|im_start|>user\n", (instruction["content"]).strip(), " <|im_end|>\n"])
        else:
            raise ValueError(f"Invalid role: {instruction['role']}. Role must be either 'user' or 'system'.")
    prompt.extend(["<|im_start|>assistant\n"])
    return "".join(prompt)

def print_instructions(prompt: str, response: Dict[str, str]) -> None:
    bold, unbold = '\033[1m', '\033[0m'
    print(f"{bold}> Input{unbold}\n{prompt}\n\n{bold}> Output{unbold}\n{response['generated_text'].strip()}\n")

Knowledge retrieval

You can use the following prompt for knowledge retrieval:

instructions = [
    {"role": "system", "content": "You are an expert in suggesting diet plans optimized for different sports or fitness activities, taking into account energy requirements, and recovery requirements."},
    {"role": "user", "content": "I am trying to build muscle mass and increase my strength. What kind of high-protein, calorie-dense diet plan would you recommend for me?"}
]
prompt = format_instructions(instructions)
payload = {
    "inputs": prompt,
    "parameters": {"max_new_tokens": 700, "temperature": 0.5, "do_sample": True}
}
response = predictor.predict(payload)
print_instructions(prompt, response)

The following is the output:

> Input
<|im_start|>system
You are an expert in suggesting diet plans optimized for different sports or fitness activities, taking into account energy requirements, and recovery requirements. <|im_end|>
<|im_start|>user
I am trying to build muscle mass and increase my strength. What kind of high-protein, calorie-dense diet plan would you recommend for me? <|im_end|>
<|im_start|>assistant

> Output
Absolutely, I'd be happy to help you with that. Given your goal of building muscle mass and increasing strength, you'll want a diet that's high in protein and provides sufficient calories to support muscle growth and recovery. Here's a basic outline of what your diet plan could look like:

1. **Breakfast:**
   - 3 whole eggs
   - 3 egg whites
   - 1 cup of oatmeal
   - 1 scoop of whey protein
   - 1 banana

2. **Mid-Morning Snack:**
   - 1.5 ounces of almonds
   - 1 apple

3. **Lunch:**
   - 8 ounces of chicken breast
   - 1 cup of brown rice
   - 1 cup of mixed vegetables

4. **Afternoon Snack:**
   - 1.5 cups of Greek yogurt
   - 1 cup of mixed berries

5. **Dinner:**
   - 8 ounces of salmon or lean beef
   - 1 cup of quinoa
   - 1.5 cups of mixed vegetables

6. **Evening Snack:**
   - 1 scoop of casein protein
   - 1 tablespoon of almond butter

This is a high-protein, calorie-dense diet that includes a variety of nutrient-rich foods. It's important to note that everyone's nutritional needs are different, so you may need to adjust portion sizes or food choices based on your specific needs and preferences.

Also, remember to stay hydrated by drinking plenty of water throughout the day. Regular exercise, particularly strength training, is crucial to achieving your goals. It's also important to get enough sleep to aid in muscle recovery and growth.

Before starting any new diet or exercise program, it's a good idea to consult with a healthcare professional or a registered dietitian. They can provide personalized advice based on your health history and current lifestyle.

Code generation

DBRX models demonstrate benchmarked strengths for coding tasks. For example, see the following code:

instructions = [
    {"role": "system", "content": "You are an expert in coding and programming. Your role is to provide examples, explanations, and insights related to programming languages, algorithms, data structures, and software development practices."},
    {"role": "user", "content": "Write a Python script to read a CSV file containing stock prices and plot the closing prices over time using Matplotlib."}
]
prompt = format_instructions(instructions)
payload = {
    "inputs": prompt,
    "parameters": {"max_new_tokens": 315, "temperature": 0.3, "do_sample": True}
}
response = predictor.predict(payload)
print_instructions(prompt, response)

The following is the output:

> Input
<|im_start|>system
You are an expert in coding and programming. Your role is to provide examples, explanations, and insights related to programming languages, algorithms, data structures, and software development practices. <|im_end|>
<|im_start|>user
Write a Python script to read a CSV file containing stock prices and plot the closing prices over time using Matplotlib. <|im_end|>
<|im_start|>assistant

> Output
Sure, I'd be happy to help you with that. Here's a simple Python script that uses the `pandas` library to read a CSV file and `matplotlib` to plot the closing prices over time.

```python
import pandas as pd
import matplotlib.pyplot as plt

# Load the data from a CSV file
df = pd.read_csv('stock_prices.csv')

# Plot the closing prices over time
plt.figure(figsize=(12,5))
plt.plot(df['Date'], df['Close'])
plt.xlabel('Date')
plt.ylabel('Price')
plt.title('Stock Prices Over Time')
plt.show()
```

In this script, replace `'stock_prices.csv'` with the path to your CSV file. The file should have columns named 'Date' and 'Close' for this script to work correctly. If your file uses different column names, you'll need to adjust the script accordingly.

The `pandas` library is a powerful tool for data manipulation and analysis. It can read data from a variety of sources, including CSV files, and represent the data in a tabular format. The `matplotlib` library is a popular choice for creating static, animated, and interactive visualizations in Python.

This script first loads the data from the CSV file into a pandas DataFrame. It then plots the 'Close' column against the 'Date' column using matplotlib's `plot()` function. The `figure()` function is used to specify the size of the plot, and `show()` is used to display the plot.

Mathematics and reasoning

The DBRX models also report strengths in mathematical accuracy. For example, see the following code:

instructions = [
  {"role": "system", "content": "You are an expert in mathematics and reasoning. Your role is to provide examples, explanations, and insights related to mathematical concepts, problem-solving techniques, and logical reasoning."},
  {"role": "user", "content": "A car rental company charges $40 per day plus $0.25 per mile driven. If I rented a car for 3 days and drove 250 miles, what would be the total cost? Break down the calculation steps before giving the final answer."}
]
prompt = format_instructions(instructions)
payload = {
    "inputs": prompt,
    "parameters": {"max_new_tokens": 400, "temperature": 0.2, "do_sample": True}
}
response = predictor.predict(payload)
print_instructions(prompt, response)

DBRX walks through the math logic step by step, as shown in the following output:

> Input
<|im_start|>system
You are an expert in mathematics and reasoning. Your role is to provide examples, explanations, and insights related to mathematical concepts, problem-solving techniques, and logical reasoning. <|im_end|>
<|im_start|>user
A car rental company charges $40 per day plus $0.25 per mile driven. If I rented a car for 3 days and drove 250 miles, what would be the total cost? Break down the calculation steps before giving the final answer. <|im_end|>
<|im_start|>assistant

> Output
Sure, I'd be happy to help with that! The problem you've presented involves both fixed costs (the daily rate) and variable costs (the cost per mile). Let's break it down:

1. First, let's calculate the total cost of the fixed daily rate. The car rental company charges $40 per day, and you rented the car for 3 days. So, the total cost for the daily rate is $40/day * 3 days = $120.

2. Next, let's calculate the total cost of the variable rate, which is the cost per mile driven. The company charges $0.25 per mile, and you drove 250 miles. So, the total cost for the mileage is $0.25/mile * 250 miles = $62.50.

3. Finally, to find the total cost of renting the car, we add the total cost of the daily rate and the total cost of the mileage. So, the total cost is $120 + $62.50 = $182.50.

So, the total cost of renting the car for 3 days and driving 250 miles would be $182.50.

Clean up

After you’re done running the notebook, make sure to delete all resources that you created in the process so your billing is stopped. Use the following code:

predictor.delete_model()
predictor.delete_endpoint()

Conclusion

In this post, we showed you how to get started with DBRX in SageMaker Studio and deploy the model for inference. Because foundation models are pre-trained, they can help lower training and infrastructure costs and enable customization for your use case. Visit SageMaker JumpStart in SageMaker Studio now to get started.

About the Authors

Shikhar Kwatra is an AI/ML Specialist Solutions Architect at Amazon Web Services, working with a leading Global System Integrator. He has earned the title of one of the Youngest Indian Master Inventors with over 400 patents in the AI/ML and IoT domains. He has over 8 years of industry experience from startups to large-scale enterprises, from IoT Research Engineer, Data Scientist, to Data & AI Architect. Shikhar aids in architecting, building, and maintaining cost-efficient, scalable cloud environments for organizations and supports GSI partners in building strategic industry

Niithiyn Vijeaswaran is a Solutions Architect at AWS. His area of focus is generative AI and AWS AI Accelerators. He holds a Bachelor’s degree in Computer Science and Bioinformatics. Niithiyn works closely with the Generative AI GTM team to enable AWS customers on multiple fronts and accelerate their adoption of generative AI. He’s an avid fan of the Dallas Mavericks and enjoys collecting sneakers.

Sebastian Bustillo is a Solutions Architect at AWS. He focuses on AI/ML technologies with a profound passion for generative AI and compute accelerators. At AWS, he helps customers unlock business value through generative AI. When he’s not at work, he enjoys brewing a perfect cup of specialty coffee and exploring the world with his wife.

Armando Diaz is a Solutions Architect at AWS. He focuses on generative AI, AI/ML, and data analytics. At AWS, Armando helps customers integrating cutting-edge generative AI capabilities into their systems, fostering innovation and competitive advantage. When he’s not at work, he enjoys spending time with his wife and family, hiking, and traveling the world.

Read More

Knowledge Bases in Amazon Bedrock now simplifies asking questions on a single document

Knowledge Bases in Amazon Bedrock now simplifies asking questions on a single document

At AWS re:Invent 2023, we announced the general availability of Knowledge Bases for Amazon Bedrock. With Knowledge Bases for Amazon Bedrock, you can securely connect foundation models (FMs) in Amazon Bedrock to your company data for fully managed Retrieval Augmented Generation (RAG).

In previous posts, we covered new capabilities like hybrid search support, metadata filtering to improve retrieval accuracy, and how Knowledge Bases for Amazon Bedrock manages the end-to-end RAG workflow.

Today, we’re introducing the new capability to chat with your document with zero setup in Knowledge Bases for Amazon Bedrock. With this new capability, you can securely ask questions on single documents, without the overhead of setting up a vector database or ingesting data, making it effortless for businesses to use their enterprise data. You only need to provide a relevant data file as input and choose your FM to get started.

But before we jump into the details of this feature, let’s start with the basics and understand what RAG is, its benefits, and how this new capability enables content retrieval and generation for temporal needs.

What is Retrieval Augmented Generation?

FM-powered artificial intelligence (AI) assistants have limitations, such as providing outdated information or struggling with context outside their training data. RAG addresses these issues by allowing FMs to cross-reference authoritative knowledge sources before generating responses.

With RAG, when a user asks a question, the system retrieves relevant context from a curated knowledge base, such as company documentation. It provides this context to the FM, which uses it to generate a more informed and precise response. RAG helps overcome FM limitations by augmenting its capabilities with an organization’s proprietary knowledge, enabling chatbots and AI assistants to provide up-to-date, context-specific information tailored to business needs without retraining the entire FM. At AWS, we recognize RAG’s potential and have worked to simplify its adoption through Knowledge Bases for Amazon Bedrock, providing a fully managed RAG experience.

Short-term and instant information needs

Although a knowledge base does all the heavy lifting and serves as a persistent large store of enterprise knowledge, you might require temporary access to data for specific tasks or analysis within isolated user sessions. Traditional RAG approaches are not optimized for these short-term, session-based data access scenarios.

Businesses incur charges for data storage and management. This may make RAG less cost-effective for organizations with highly dynamic or ephemeral information requirements, especially when data is only needed for specific, isolated tasks or analyses.

Ask questions on a single document with zero setup

This new capability to chat with your document within Knowledge Bases for Amazon Bedrock addresses the aforementioned challenges. It provides a zero-setup method to use your single document for content retrieval and generation-related tasks, along with the FMs provided by Amazon Bedrock. With this new capability, you can ask questions of your data without the overhead of setting up a vector database or ingesting data, making it effortless to use your enterprise data.

You can now interact with your documents in real time without prior data ingestion or database configuration. You don’t need to take any further data readiness steps before querying the data.

This zero-setup approach makes it straightforward to use your enterprise information assets with generative AI using Amazon Bedrock.

Use cases and benefits

Consider a recruiting firm that needs to analyze resumes and match candidates with suitable job opportunities based on their experience and skills. Previously, you would have to set up a knowledge base, invoking a data ingestion workflow to make sure only authorized recruiters can access the data. Additionally, you would need to manage cleanup when the data was no longer required for a session or candidate. In the end, you would pay more for the vector database storage and management than for the actual FM usage. This new feature in Knowledge Bases for Amazon Bedrock enables recruiters to quickly and ephemerally analyze resumes and match candidates with suitable job opportunities based on the candidate’s experience and skill set.

For another example, consider a product manager at a technology company who needs to quickly analyze customer feedback and support tickets to identify common issues and areas for improvement. With this new capability, you can simply upload a document to extract insights in no time. For example, you could ask “What are the requirements for the mobile app?” or “What are the common pain points mentioned by customers regarding our onboarding process?” This feature empowers you to rapidly synthesize this information without the hassle of data preparation or any management overhead. You can also request summaries or key takeaways, such as “What are the highlights from this requirements document?”

The benefits of this feature extend beyond cost savings and operational efficiency. By eliminating the need for vector databases and data ingestion, this new capability within Knowledge Bases for Amazon Bedrock helps secure your proprietary data, making it accessible only within the context of isolated user sessions.

Now that we’ve covered the feature benefits and the use cases it enables, let’s dive into how you can start using this new feature from Knowledge Bases for Amazon Bedrock.

Chat with your document in Knowledge Bases for Amazon Bedrock

You have multiple options to begin using this feature:

  • The Amazon Bedrock console
  • The Amazon Bedrock RetrieveAndGenerate API (SDK)

Let’s see how we can get started using the Amazon Bedrock console:

  1. On the Amazon Bedrock console, under Orchestration in the navigation pane, choose Knowledge bases.
  2. Choose Chat with your document.
  3. Under Model, choose Select model.
  4. Choose your model. For this example, we use the Claude 3 Sonnet model (only Sonnet is supported at the time of launch).
  5. Choose Apply.
  6. Under Data, you can upload the document you want to chat with or point to the Amazon Simple Storage Service (Amazon S3) bucket location that contains your file. For this post, we upload a document from our computer.

The supported file formats are PDF, MD (Markdown), TXT, DOCX, HTML, CSV, XLS, and XLSX. Make sure that the file size does not exceed 10 MB and that the file contains no more than 20,000 tokens. A token is a unit of text, such as a word, sub-word, number, or symbol, that is processed as a single entity. Because of the preset ingestion token limit, keep your file under 10 MB; note that a text-heavy file well under 10 MB can still exceed the token limit.
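
If you want to check a file before uploading it, the following minimal sketch applies the limits described above; the four-characters-per-token heuristic and the file name are assumptions for illustration only, not the tokenizer or naming that Amazon Bedrock uses.

import os

MAX_FILE_BYTES = 10 * 1024 * 1024  # 10 MB upload limit
MAX_TOKENS = 20_000                # preset ingestion token limit

def precheck_document(path: str) -> None:
    """Warn if a plain-text file is likely to exceed the chat-with-your-document limits."""
    size = os.path.getsize(path)
    if size > MAX_FILE_BYTES:
        raise ValueError(f"{path} is {size} bytes, which exceeds the 10 MB limit")
    # Rough token estimate (~4 characters per token) for text formats such as TXT, MD, or CSV.
    # Binary formats like PDF or DOCX would need text extraction before estimating.
    with open(path, encoding="utf-8", errors="ignore") as f:
        approx_tokens = len(f.read()) / 4
    if approx_tokens > MAX_TOKENS:
        print(f"Warning: ~{int(approx_tokens)} estimated tokens may exceed the 20,000-token limit")

precheck_document("customer_feedback.txt")  # hypothetical file name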

You’re now ready to chat with your document.

As shown in the following screenshot, you can chat with your document in real time.

To customize your prompt, enter your prompt under System prompt.

Similarly, you can use the AWS SDK through the retrieve_and_generate API in major coding languages. In the following example, we use the AWS SDK for Python (Boto3):

import boto3

bedrock_client = boto3.client(service_name='bedrock-agent-runtime')
model_id = "your_model_id_here"    # Replace with your modelID
document_uri = "your_s3_uri_here"  # Replace with your S3 URI

def retrieveAndGenerate(input_text, sourceType, model_id, document_s3_uri=None, data=None):
    region = 'us-west-2'  
    model_arn = f'arn:aws:bedrock:{region}::foundation-model/{model_id}'

    if sourceType == "S3":
        return bedrock_client.retrieve_and_generate(
            input={'text': input_text},
            retrieveAndGenerateConfiguration={
                'type': 'EXTERNAL_SOURCES',
                'externalSourcesConfiguration': {
                    'modelArn': model_arn,
                    'sources': [
                        {
                            "sourceType": sourceType,
                            "s3Location": {
                                "uri": document_s3_uri  
                            }
                        }
                    ]
                }
            }
        )
        
    else:
        return bedrock_client.retrieve_and_generate(
            input={'text': input_text},
            retrieveAndGenerateConfiguration={
                'type': 'EXTERNAL_SOURCES',
                'externalSourcesConfiguration': {
                    'modelArn': model_arn,
                    'sources': [
                        {
                            "sourceType": sourceType,
                            "byteContent": {
                                "identifier": "testFile.txt",
                                "contentType": "text/plain",
                                "data": data  
                            }
                        }
                    ]
                }
            }
        )

response = retrieveAndGenerate(
                                input_text="What is the main topic of this document?",
                                sourceType="S3", 
                                model_id=model_id,
                                document_s3_uri=document_uri
                              )
                    
print(response['output']['text'])

Conclusion

In this post, we covered how Knowledge Bases for Amazon Bedrock now simplifies asking questions on a single document. We explored the core concepts behind RAG, the challenges this new feature addresses, and the various use cases it enables across different roles and industries. We also demonstrated how to configure and use this capability through the Amazon Bedrock console and the AWS SDK, showcasing the simplicity and flexibility of this feature, which provides a zero-setup solution to gather information from a single document, without setting up a vector database.

To further explore the capabilities of Knowledge Bases for Amazon Bedrock, refer to the following resources:

Share and learn with our generative AI community at community.aws.


About the authors

Suman Debnath is a Principal Developer Advocate for Machine Learning at Amazon Web Services. He frequently speaks at AI/ML conferences, events, and meetups around the world. He is passionate about large-scale distributed systems and is an avid fan of Python.

Sebastian Munera is a Software Engineer in the Amazon Bedrock Knowledge Bases team at AWS where he focuses on building customer solutions that leverage Generative AI and RAG applications. He has previously worked on building Generative AI-based solutions for customers to streamline their processes and Low code/No code applications. In his spare time he enjoys running, lifting and tinkering with technology.

Read More

Deploy a Hugging Face (PyAnnote) speaker diarization model on Amazon SageMaker as an asynchronous endpoint

Deploy a Hugging Face (PyAnnote) speaker diarization model on Amazon SageMaker as an asynchronous endpoint

Speaker diarization, an essential process in audio analysis, segments an audio file based on speaker identity. This post delves into integrating Hugging Face’s PyAnnote for speaker diarization with Amazon SageMaker asynchronous endpoints.

We provide a comprehensive guide on how to deploy speaker segmentation and clustering solutions using SageMaker on the AWS Cloud. You can use this solution for applications dealing with multi-speaker (over 100) audio recordings.

Solution overview

Amazon Transcribe is the go-to service for speaker diarization in AWS. However, for unsupported languages, you can use other models (in our case, PyAnnote) that are deployed on SageMaker for inference. For short audio files where inference takes up to 60 seconds, you can use real-time inference. For inference that takes longer than 60 seconds, you should use asynchronous inference. The added benefit of asynchronous inference is the cost savings from auto scaling the instance count to zero when there are no requests to process.

Hugging Face is a popular open source hub for machine learning (ML) models. AWS and Hugging Face have a partnership that allows a seamless integration through SageMaker with a set of AWS Deep Learning Containers (DLCs) for training and inference in PyTorch or TensorFlow, and Hugging Face estimators and predictors for the SageMaker Python SDK. SageMaker features and capabilities help developers and data scientists get started with natural language processing (NLP) on AWS with ease.

The integration for this solution involves using Hugging Face’s pre-trained speaker diarization model using the PyAnnote library. PyAnnote is an open source toolkit written in Python for speaker diarization. This model, trained on the sample audio dataset, enables effective speaker partitioning in audio files. The model is deployed on SageMaker as an asynchronous endpoint setup, providing efficient and scalable processing of diarization tasks.

The following diagram illustrates the solution architecture.

For this post, we use the following audio file.

Stereo or multi-channel audio files are automatically downmixed to mono by averaging the channels. Audio files sampled at a different rate are resampled to 16kHz automatically upon loading.
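
If you want to replicate this preprocessing locally, for example to inspect an audio file before invoking the endpoint, a minimal sketch with librosa (which also appears in the requirements.txt later in this post) could look like the following; the file name is hypothetical.

import librosa

# Load a local copy of the audio, downmixing to mono and resampling to 16 kHz,
# mirroring what the diarization pipeline does when it loads audio
waveform, sample_rate = librosa.load("meeting_recording.wav", sr=16000, mono=True)  # hypothetical file
print(waveform.shape, sample_rate)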

Prerequisites

Complete the following prerequisites:

  1. Create a SageMaker domain.
  2. Make sure your AWS Identity and Access Management (IAM) user has the necessary access permissions for creating a SageMaker role.
  3. Make sure the AWS account has a service quota for hosting a SageMaker endpoint for an ml.g5.2xlarge instance.

Create a model function for accessing PyAnnote speaker diarization from Hugging Face

You can use the Hugging Face Hub to access the desired pre-trained PyAnnote speaker diarization model. You use the same script for downloading the model file when creating the SageMaker endpoint.


See the following code:

from pyannote.audio import Pipeline

def model_fn(model_dir):
    # Load the pre-trained speaker diarization pipeline from the Hugging Face Hub
    model = Pipeline.from_pretrained(
        "pyannote/speaker-diarization-3.1",
        use_auth_token="Replace-with-the-Hugging-face-auth-token")
    return model

Package the model code

Prepare essential files like inference.py, which contains the inference code:

%%writefile model/code/inference.py
from pyannote.audio import Pipeline
import subprocess
import boto3
from urllib.parse import urlparse
import pandas as pd
from io import StringIO
import os
import torch

def model_fn(model_dir):
    # Load the pre-trained speaker diarization pipeline from the Hugging Face Hub
    model = Pipeline.from_pretrained(
        "pyannote/speaker-diarization-3.1",
        use_auth_token="hf_oBxxxxxxxxxxxx")  # Replace with your Hugging Face access token
    return model


def diarization_from_s3(model, s3_file, language=None):
    s3 = boto3.client("s3")
    o = urlparse(s3_file, allow_fragments=False)
    bucket = o.netloc
    key = o.path.lstrip("/")
    s3.download_file(bucket, key, "tmp.wav")
    result = model("tmp.wav")
    data = {} 
    for turn, _, speaker in result.itertracks(yield_label=True):
        data[turn] = (turn.start, turn.end, speaker)
    data_df = pd.DataFrame(data.values(), columns=["start", "end", "speaker"])
    print(data_df.shape)
    result = data_df.to_json(orient="split")
    return result


def predict_fn(data, model):
    s3_file = data.pop("s3_file")
    language = data.pop("language", None)
    result = diarization_from_s3(model, s3_file, language)
    return {
        "diarization_from_s3": result
    }

Prepare a requirements.txt file, which lists the Python libraries necessary to run the inference:

with open("model/code/requirements.txt", "w") as f:
    f.write("transformers==4.25.1\n")
    f.write("boto3\n")
    f.write("pyannote.audio\n")
    f.write("soundfile\n")
    f.write("librosa\n")
    f.write("onnxruntime\n")
    f.write("wget\n")
    f.write("pandas")

Lastly, compress the inference.py and requirements.txt files and save them as model.tar.gz:

!tar zcvf model.tar.gz *

Configure a SageMaker model

Define a SageMaker model resource by specifying the image URI, model data location in Amazon Simple Storage Service (S3), and SageMaker role:

import sagemaker
import boto3

sess = sagemaker.Session()

sagemaker_session_bucket = None
if sagemaker_session_bucket is None and sess is not None:
    sagemaker_session_bucket = sess.default_bucket()

try:
    role = sagemaker.get_execution_role()
except ValueError:
    iam = boto3.client("iam")
    role = iam.get_role(RoleName="sagemaker_execution_role")["Role"]["Arn"]

sess = sagemaker.Session(default_bucket=sagemaker_session_bucket)

print(f"sagemaker role arn: {role}")
print(f"sagemaker bucket: {sess.default_bucket()}")
print(f"sagemaker session region: {sess.boto_region_name}")

Upload the model to Amazon S3

Upload the zipped PyAnnote Hugging Face model file to an S3 bucket:

s3_location = f"s3://{sagemaker_session_bucket}/whisper/model/model.tar.gz"
!aws s3 cp model.tar.gz $s3_location

Create a SageMaker asynchronous endpoint

Configure an asynchronous endpoint for deploying the model on SageMaker using the provided asynchronous inference configuration:

from sagemaker.huggingface.model import HuggingFaceModel
from sagemaker.async_inference.async_inference_config import AsyncInferenceConfig
from sagemaker.s3 import s3_path_join
from sagemaker.utils import name_from_base

async_endpoint_name = name_from_base("custom-asyc")

# create Hugging Face Model Class
huggingface_model = HuggingFaceModel(
    model_data=s3_location,  # path to your model and script
    role=role,  # iam role with permissions to create an Endpoint
    transformers_version="4.17",  # transformers version used
    pytorch_version="1.10",  # pytorch version used
    py_version="py38",  # python version used
)

# create async endpoint configuration
async_config = AsyncInferenceConfig(
    output_path=s3_path_join(
        "s3://", sagemaker_session_bucket, "async_inference/output"
    ),  # Where our results will be stored
    # Add notification SNS topics if needed
    notification_config={
        # "SuccessTopic": "PUT YOUR SUCCESS SNS TOPIC ARN",
        # "ErrorTopic": "PUT YOUR ERROR SNS TOPIC ARN",
    },  #  Notification configuration
)

env = {"MODEL_SERVER_WORKERS": "2"}

# Deploy the endpoint
async_predictor = huggingface_model.deploy(
    initial_instance_count=1,
    instance_type="ml.g5.2xlarge",  # GPU instance from the prerequisites
    async_inference_config=async_config,
    endpoint_name=async_endpoint_name,
    env=env,
)

Test the endpoint

Evaluate the endpoint functionality by sending an audio file for diarization and retrieving the JSON output stored in the specified S3 output path:

from sagemaker.async_inference import WaiterConfig

# Replace with the S3 path of your audio object (hypothetical example shown)
data = {"s3_file": "s3://your-bucket/your-audio-file.wav"}

res = async_predictor.predict_async(data=data)
print(f"Response output path: {res.output_path}")
print("Start polling to get response:")

config = WaiterConfig(
    max_attempts=10,  # number of attempts
    delay=10,         # time in seconds to wait between attempts
)
res.get_result(config)

To deploy this solution at scale, we suggest using AWS Lambda, Amazon Simple Notification Service (Amazon SNS), or Amazon Simple Queue Service (Amazon SQS). These services are designed for scalability, event-driven architectures, and efficient resource utilization. They can help decouple the asynchronous inference process from the result processing, allowing you to scale each component independently and handle bursts of inference requests more effectively.

Results

Model output is stored at s3://sagemaker-xxxx/async_inference/output/. The output shows that the audio recording has been segmented into speaker turns, each described by three columns:

  • Start (start time in seconds)
  • End (end time in seconds)
  • Speaker (speaker label)

The following code shows an example of our results:

[0.9762308998, 8.9049235993, "SPEAKER_01"]

[9.533106961, 12.1646859083, "SPEAKER_01"]

[13.1324278438, 13.9303904924, "SPEAKER_00"]

[14.3548387097, 26.1884550085, "SPEAKER_00"]

[27.2410865874, 28.2258064516, "SPEAKER_01"]

[28.3446519525, 31.298811545, "SPEAKER_01"]

Clean up

You can set a scaling policy with a MinCapacity of 0; asynchronous inference lets you automatically scale down to zero when there are no requests. You don’t need to delete the endpoint; it scales up from zero when needed again, reducing costs when it’s not in use. See the following code:

# Application Auto Scaling client for the SageMaker endpoint
client = boto3.client('application-autoscaling')

# This is the format in which Application Auto Scaling references the endpoint
resource_id = 'endpoint/' + async_endpoint_name + '/variant/' + 'variant1'  # replace 'variant1' with your production variant name

# Define and register your endpoint variant
response = client.register_scalable_target(
    ServiceNamespace='sagemaker',
    ResourceId=resource_id,
    ScalableDimension='sagemaker:variant:DesiredInstanceCount',  # The number of EC2 instances for your Amazon SageMaker model endpoint variant
    MinCapacity=0,
    MaxCapacity=5
)

If you want to delete the endpoint, use the following code:

async_predictor.delete_endpoint(async_endpoint_name)

Benefits of asynchronous endpoint deployment

This solution offers the following benefits:

  • The solution can efficiently handle multiple or large audio files.
  • This example uses a single instance for demonstration. If you want to use this solution for hundreds or thousands of audio files and use an asynchronous endpoint to process them across multiple instances, you can use an auto scaling policy, which is designed for a large number of source documents. Auto scaling dynamically adjusts the number of instances provisioned for a model in response to changes in your workload.
  • The solution optimizes resources and reduces system load by separating long-running tasks from real-time inference.

Conclusion

In this post, we provided a straightforward approach to deploy Hugging Face’s speaker diarization model on SageMaker using Python scripts. Using an asynchronous endpoint provides an efficient and scalable means to deliver diarization predictions as a service, accommodating concurrent requests seamlessly.

Get started today with asynchronous speaker diarization for your audio projects. Reach out in the comments if you have any questions about getting your own asynchronous diarization endpoint up and running.


About the Authors

Sanjay Tiwary is a Specialist Solutions Architect AI/ML who spends his time working with strategic customers to define business requirements, provide L300 sessions around specific use cases, and design AI/ML applications and services that are scalable, reliable, and performant. He has helped launch and scale the AI/ML powered Amazon SageMaker service and has implemented several proofs of concept using Amazon AI services. He has also developed the advanced analytics platform as a part of the digital transformation journey.

Kiran Challapalli is a deep tech business developer with the AWS public sector. He has more than 8 years of experience in AI/ML and 23 years of overall software development and sales experience. Kiran helps public sector businesses across India explore and co-create cloud-based solutions that use AI, ML, and generative AI—including large language models—technologies.

Read More

Evaluate the text summarization capabilities of LLMs for enhanced decision-making on AWS

Evaluate the text summarization capabilities of LLMs for enhanced decision-making on AWS

Organizations across industries are using automatic text summarization to more efficiently handle vast amounts of information and make better decisions. In the financial sector, investment banks condense earnings reports down to key takeaways to rapidly analyze quarterly performance. Media companies use summarization to monitor news and social media so journalists can quickly write stories on developing issues. Government agencies summarize lengthy policy documents and reports to help policymakers strategize and prioritize goals.

By creating condensed versions of long, complex documents, summarization technology enables users to focus on the most salient content. This leads to better comprehension and retention of critical information. The time savings allow stakeholders to review more material in less time, gaining a broader perspective. With enhanced understanding and more synthesized insights, organizations can make better informed strategic decisions, accelerate research, improve productivity, and increase their impact. The transformative power of advanced summarization capabilities will only continue growing as more industries adopt artificial intelligence (AI) to harness overflowing information streams.

In this post, we explore leading approaches for evaluating summarization accuracy objectively, including ROUGE metrics, METEOR, and BERTScore. Understanding the strengths and weaknesses of these techniques can help guide selection and improvement efforts. The overall goal of this post is to demystify summarization evaluation to help teams better benchmark performance on this critical capability as they seek to maximize value.

Types of summarization

Summarization can generally be divided into two main types: extractive summarization and abstractive summarization. Both approaches aim to condense long pieces of text into shorter forms, capturing the most critical information or essence of the original content, but they do so in fundamentally different ways.

Extractive summarization involves identifying and extracting key phrases, sentences, or segments from the original text without altering them. The system selects parts of the text deemed most informative or representative of the whole. Extractive summarization is useful if accuracy is critical and the summary needs to reflect the exact information from the original text. These could be use cases like highlighting specific legal terms, obligations, and rights outlined in the terms of use. The most common techniques used for extractive summarization are term frequency-inverse document frequency (TF-IDF), sentence scoring, text rank algorithm, and supervised machine learning (ML).
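
To make the idea of TF-IDF-based sentence scoring concrete, the following is a minimal extractive summarization sketch using scikit-learn; the naive sentence splitter and the sum-of-weights scoring heuristic are simplifying assumptions for illustration, not a production approach.

import re
from sklearn.feature_extraction.text import TfidfVectorizer

def extractive_summary(text: str, num_sentences: int = 2) -> str:
    """Score each sentence by the sum of its TF-IDF weights and keep the top ones in original order."""
    # Naive sentence splitter for illustration; a real system would use an NLP library
    sentences = [s.strip() for s in re.split(r"(?<=[.!?])\s+", text) if s.strip()]
    if len(sentences) <= num_sentences:
        return text
    tfidf = TfidfVectorizer(stop_words="english").fit_transform(sentences)
    scores = tfidf.sum(axis=1).A1  # one relevance score per sentence
    top = sorted(sorted(range(len(sentences)), key=lambda i: scores[i], reverse=True)[:num_sentences])
    return " ".join(sentences[i] for i in top)

document = (
    "Investment banks condense earnings reports down to key takeaways. "
    "Summarization helps analysts review quarterly performance quickly. "
    "The analysts also enjoyed the catering at the earnings call."
)
print(extractive_summary(document, num_sentences=2))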

Abstractive summarization goes a step further by generating new phrases and sentences that were not in the original text, essentially paraphrasing and condensing the original content. This approach requires a deeper understanding of the text, because the AI needs to interpret the meaning and then express it in a new, concise form. Large language models (LLMs) are best suited for abstractive summarization because the transformer models use attention mechanisms to focus on relevant parts of the input text when generating summaries. The attention mechanism allows the model to assign different weights to different words or tokens in the input sequence, enabling it to capture long-range dependencies and contextually relevant information.

In addition to these two primary types, there are hybrid approaches that combine extractive and abstractive methods. These approaches might start with extractive summarization to identify the most important content and then use abstractive techniques to rewrite or condense that content into a fluent summary.

The challenge

Finding the optimal method to evaluate summary quality remains an open challenge. As organizations increasingly rely on automatic text summarization to distill key information from documents, the need grows for standardized techniques to measure summarization accuracy. Ideally, these evaluation metrics would quantify how well machine-generated summaries extract the most salient content from source texts and present coherent summaries reflecting the original meaning and context.

However, developing robust evaluation methodologies for text summarization presents difficulties:

  • Human-authored reference summaries used for comparison often exhibit high variability based on subjective determinations of importance
  • Nuanced aspects of summary quality like fluency, readability, and coherence prove difficult to quantify programmatically
  • Wide variation exists across summarization methods from statistical algorithms to neural networks, complicating direct comparisons

Recall-Oriented Understudy for Gisting Evaluation (ROUGE)

ROUGE metrics, such as ROUGE-N and ROUGE-L, play a crucial role in evaluating the quality of machine-generated summaries compared to human-written reference summaries. These metrics focus on assessing the overlap between the content of machine-generated and human-crafted summaries by analyzing n-grams, which are groups of words or tokens. For instance, ROUGE-1 evaluates the match of individual words (unigrams), whereas ROUGE-2 considers pairs of words (bigrams). Additionally, ROUGE-L assesses the longest common subsequence of words between the two texts, allowing for flexibility in word order.

To illustrate this, consider the following examples:

  • ROUGE-1 metric – ROUGE-1 evaluates the overlap of unigrams (single words) between a generated summary and a reference summary. For example, if a reference summary contains “The quick brown fox jumps,” and the generated summary is “The brown fox jumps quickly,” the ROUGE-1 metric would consider “The,” “brown,” “fox,” and “jumps” as overlapping unigrams. ROUGE-1 focuses on the presence of individual words in the summaries, measuring how well the generated summary captures the key words from the reference summary.
  • ROUGE-2 metric – ROUGE-2 assesses the overlap of bigrams (pairs of adjacent words) between a generated summary and a reference summary. For instance, if the reference summary has “The cat is sleeping,” and the generated summary reads “A cat is sleeping,” ROUGE-2 would identify “cat is” and “is sleeping” as overlapping bigrams. ROUGE-2 provides insight into how well the generated summary maintains the sequence and context of word pairs compared to the reference summary.
  • ROUGE-N metric – ROUGE-N is a generalized form where N represents any number, allowing evaluation based on n-grams (sequences of N words). Considering N=3, if the reference summary states “The sun is shining brightly,” and the generated summary is “Sun is shining brightly,” ROUGE-3 would recognize “sun is shining” and “is shining brightly” as matching trigrams. ROUGE-N offers flexibility to evaluate summaries based on different lengths of word sequences, providing a more comprehensive assessment of content overlap.

These examples illustrate how ROUGE-1, ROUGE-2, and ROUGE-N metrics function in evaluating automatic summarization or machine translation tasks by comparing generated summaries with reference summaries based on different levels of word sequences.

Calculate a ROUGE-N score

You can use the following steps to calculate a ROUGE-N score (a minimal Python sketch follows these steps):

  1. Tokenize the generated summary and the reference summary into individual words or tokens using basic tokenization methods like splitting by whitespace or natural language processing (NLP) libraries.
  2. Generate n-grams (contiguous sequences of N words) from both the generated summary and the reference summary.
  3. Count the number of overlapping n-grams between the generated summary and the reference summary.
  4. Calculate precision, recall, and F1 score:
    • Precision – The number of overlapping n-grams divided by the total number of n-grams in the generated summary.
    • Recall – The number of overlapping n-grams divided by the total number of n-grams in the reference summary.
    • F1 score – The harmonic mean of precision and recall, calculated as (2 * precision * recall) / (precision + recall).
  5. The aggregate F1 score obtained from calculating precision, recall, and F1 score for each row in the dataset is considered as the ROUGE-N score.
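
The following is a minimal, self-contained sketch of these steps for a single candidate/reference pair; it is for illustration only and doesn’t replicate every detail of standard ROUGE implementations (for example, stemming or multi-reference handling).

from collections import Counter
from typing import List, Tuple

def ngrams(tokens: List[str], n: int) -> Counter:
    """Count contiguous n-grams in a token list."""
    return Counter(tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1))

def rouge_n(candidate: str, reference: str, n: int = 2) -> Tuple[float, float, float]:
    """Return (precision, recall, F1) for ROUGE-N based on clipped n-gram overlap."""
    cand_counts = ngrams(candidate.lower().split(), n)
    ref_counts = ngrams(reference.lower().split(), n)
    overlap = sum((cand_counts & ref_counts).values())  # overlapping n-grams, clipped per n-gram
    precision = overlap / max(sum(cand_counts.values()), 1)
    recall = overlap / max(sum(ref_counts.values()), 1)
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) else 0.0
    return precision, recall, f1

print(rouge_n("the brown fox jumps quickly", "the quick brown fox jumps", n=1))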

Limitations

ROUGE has the following limitations:

  • Narrow focus on lexical overlap – The core idea behind ROUGE is to compare the system-generated summary to a set of reference or human-created summaries, and measure the lexical overlap between them. This means ROUGE has a very narrow focus on word-level similarity. It doesn’t actually evaluate semantic meaning, coherence, or readability of the summary. A system could achieve high ROUGE scores by simply extracting sentences word-for-word from the original text, without generating a coherent or concise summary.
  • Insensitivity to paraphrasing – Because ROUGE relies on lexical matching, it can’t detect semantic equivalence between words and phrases. Therefore, paraphrasing and use of synonyms will often lead to lower ROUGE scores, even if the meaning is preserved. This disadvantages systems that paraphrase or summarize in an abstractive way.
  • Lack of semantic understanding – ROUGE doesn’t evaluate whether the system truly understood the meanings and concepts in the original text. A summary could achieve high lexical overlap with references, while missing the main ideas or containing factual inconsistencies. ROUGE would not identify these issues.

When to use ROUGE

ROUGE is simple and fast to calculate. Use it as a baseline or benchmark for summary quality related to content selection. ROUGE metrics are most effectively employed in scenarios involving abstractive summarization tasks, automatic summarization evaluation, assessments of LLMs, and comparative analyses of different summarization approaches. By using ROUGE metrics in these contexts, stakeholders can quantitatively evaluate the quality and effectiveness of summary generation processes.

Metric for Evaluation of Translation with Explicit Ordering (METEOR)

One of the major challenges in evaluating summarization systems is assessing how well the generated summary flows logically, rather than just selecting relevant words and phrases from the source text. Simply extracting relevant keywords and sentences doesn’t necessarily produce a coherent and cohesive summary. The summary should flow smoothly and connect ideas logically, even if they aren’t presented in the same order as the original document.

The flexibility of matching by reducing words to their root or base form (For example, after stemming, words like “running,” “runs,” and “ran” all become “run”) and synonyms means METEOR correlates better with human judgements of summary quality. It can identify if important content is preserved, even if the wording differs. This is a key advantage over n-gram based metrics like ROUGE, which only look for exact token matches. METEOR also gives higher scores to summaries that focus on the most salient content from the reference. Lower scores are given to repetitive or irrelevant information. This aligns well with the goal of summarization to keep the most important content only. METEOR is a semantically meaningful metric that can overcome some of the limitations of n-gram matching for evaluating text summarization. The incorporation of stemming and synonyms allows for better assessment of information overlap and content accuracy.

To illustrate this, consider the following examples:

Reference Summary: Leaves fall during autumn.

Generated Summary 1: Leaves drop in fall.

Generated Summary 2: Leaves green in summer.

The words that match between the reference and generated summary 1 are highlighted:

Reference Summary: Leaves fall during autumn.

Generated Summary 1: Leaves drop in fall.

Even though “fall” and “autumn” are different tokens, METEOR recognizes them as synonyms through its synonym matching. “Drop” and “fall” are identified as a stemmed match. For generated summary 2, there are no matches with the reference summary besides “Leaves,” so this summary would receive a much lower METEOR score. The more semantically meaningful matches, the higher the METEOR score. This allows METEOR to better evaluate the content and accuracy of summaries compared to simple n-gram matching.

Calculate a METEOR score

Complete the following steps to calculate a METEOR score (a minimal sketch follows these steps):

  1. Tokenize the generated summary and the reference summary into individual words or tokens using basic tokenization methods like splitting by whitespace or NLP libraries.
  2. Calculate the unigram precision, recall, and F-mean score, giving more weightage to recall than precision.
  3. Apply a penalty for exact matches to avoid overemphasizing them. The penalty is chosen based on dataset characteristics, task requirements, and the balance between precision and recall. Subtract this penalty from the F-mean score calculated in Step 2.
  4. Calculate the F-mean score for stemmed forms (reducing words to their base or root form) and synonyms for unigrams where applicable. Aggregate this with the earlier calculated F-mean score to obtain the final METEOR score. The METEOR score ranges from 0–1, where 0 indicates no similarity between the generated summary and reference summary, and 1 indicates perfect alignment. Typically, summarization scores fall between 0–0.6.
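
In practice, you can compute METEOR with the implementation in NLTK, as shown in the following sketch; it assumes a recent version of the nltk package (where inputs must be pre-tokenized) with the WordNet data downloaded, and it reuses the example summaries from earlier in this section.

import nltk
from nltk.translate.meteor_score import meteor_score

# One-time download of the lexical resources METEOR needs for synonym matching
nltk.download("wordnet", quiet=True)
nltk.download("omw-1.4", quiet=True)

reference = "Leaves fall during autumn.".lower().split()
generated = "Leaves drop in fall.".lower().split()

# meteor_score expects a list of tokenized references and one tokenized hypothesis
score = meteor_score([reference], generated)
print(f"METEOR score: {score:.3f}")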

Limitations

When employing the METEOR metric for evaluating summarization tasks, several challenges may arise:

  • Semantic complexity – METEOR’s emphasis on semantic similarity can struggle to capture the nuanced meanings and context in complex summarization tasks, potentially leading to inaccuracies in evaluation.
  • Reference variability – Variability in human-generated reference summaries can impact METEOR scores, because differences in reference content may affect the evaluation of machine-generated summaries.
  • Linguistic diversity – The effectiveness of METEOR may vary across languages due to linguistic variations, syntax differences, and semantic nuances, posing challenges in multilingual summarization evaluations.
  • Length discrepancy – Evaluating summaries of varying lengths can be challenging for METEOR, because discrepancies in length compared to the reference summary may result in penalties or inaccuracies in assessment.
  • Parameter tuning – Optimizing METEOR’s parameters for different datasets and summarization tasks can be time-consuming and require careful tuning to make sure the metric provides accurate evaluations.
  • Evaluation bias – There is a risk of evaluation bias with METEOR if not properly adjusted or calibrated for specific summarization domains or tasks. This can potentially lead to skewed results and affect the reliability of the evaluation process.

By being aware of these challenges and considering them when using METEOR as a metric for summarization tasks, researchers and practitioners can navigate potential limitations and make more informed decisions in their evaluation processes.

When to use METEOR

METEOR is commonly used to automatically evaluate the quality of text summaries. It is preferable to use METEOR as an evaluation metric when the order of ideas, concepts, or entities in the summary matters. METEOR considers the order and matches n-grams between the generated summary and reference summaries. It rewards summaries that preserve sequential information. Unlike metrics like ROUGE, which rely on overlap of n-grams with reference summaries, METEOR matches stems, synonyms, and paraphrases. METEOR works better when there can be multiple correct ways of summarizing the original text. METEOR incorporates WordNet synonyms and stemmed tokens when matching n-grams. In short, summaries that are semantically similar but use different words or phrasing will still score well. METEOR has a built-in penalty for summaries with repetitive n-grams. Therefore, it discourages word-for-word extraction or lack of abstraction. METEOR is a good choice when semantic similarity, order of ideas, and fluent phrasing are important for judging summary quality. It is less appropriate for tasks where only lexical overlap with reference summaries matters.

BERTScore

Surface-level lexical measures like ROUGE and METEOR evaluate summarization systems by comparing the word overlap between a candidate summary and a reference summary. However, they rely heavily on exact string matching between words and phrases. This means they may miss semantic similarities between words and phrases that have different surface forms but similar underlying meanings. By relying only on surface matching, these metrics may underestimate the quality of system summaries that use synonymous words or paraphrase concepts differently from reference summaries. Two summaries could convey nearly identical information but receive low surface-level scores due to vocabulary differences.

BERTScore is a way to automatically evaluate how good a summary is by comparing it to a reference summary written by a human. It uses BERT, a popular NLP technique, to understand the meaning and context of words in the candidate summary and reference summary. Specifically, it looks at each word or token in the candidate summary and finds the most similar word in the reference summary based on the BERT embeddings, which are vector representations of the meaning and context of each word. It measures the similarity using cosine similarity, which tells how close the vectors are to each other. For each word in the candidate summary, it finds the most related word in the reference summary using BERT’s understanding of language. It compares all these word similarities across the whole summary to get an overall score of how semantically similar the candidate summary is to the reference summary. The more similar the words and meanings captured by BERT, the higher the BERTScore. This allows it to automatically evaluate the quality of a generated summary by comparing it to a human reference without needing human evaluation each time.

To illustrate this, imagine you have a machine-generated summary: “The quick brown fox jumps over the lazy dog.” Now, let’s consider a human-crafted reference summary: “A fast brown fox leaps over a sleeping canine.”

Calculate a BERTScore

Complete the following steps to calculate a BERTScore (a minimal sketch follows these steps):

  1. BERTScore uses contextual embeddings to represent each token in both the candidate (machine-generated) and reference (human-crafted) sentences. Contextual embeddings are a type of word representation in NLP that captures the meaning of a word based on its context within a sentence or text. Unlike traditional word embeddings that assign a fixed vector to each word regardless of its context, contextual embeddings consider the surrounding words to generate a unique representation for each word depending on how it is used in a specific sentence.
  2. The metric then computes the similarity between each token in the candidate sentence with each token in the reference sentence using cosine similarity. Cosine similarity helps us quantify how closely related two sets of data are by focusing on the direction they point in a multi-dimensional space, making it a valuable tool for tasks like search algorithms, NLP, and recommendation systems.
  3. By comparing the contextual embeddings and computing similarity scores for all tokens, BERTScore generates a comprehensive evaluation that captures the semantic relevance and context of the generated summary compared to the human-crafted reference.
  4. The final BERTScore output provides a similarity score that reflects how well the machine-generated summary aligns with the reference summary in terms of meaning and context.

In essence, BERTScore goes beyond traditional metrics by considering the semantic nuances and context of sentences, offering a more sophisticated evaluation that closely mirrors human judgment. This advanced approach enhances the accuracy and reliability of evaluating summarization tasks, making BERTScore a valuable tool in assessing text generation systems.

Limitations

Although BERTScore offers significant advantages in evaluating summarization tasks, it also comes with certain limitations that need to be considered:

  • Computational intensity – BERTScore can be computationally intensive due to its reliance on pre-trained language models like BERT. This can lead to longer evaluation times, especially when processing large volumes of text data.
  • Dependency on pre-trained models – The effectiveness of BERTScore is highly dependent on the quality and relevance of the pre-trained language model used. In scenarios where the pre-trained model may not adequately capture the nuances of the text, the evaluation results may be affected.
  • Scalability – Scaling BERTScore for large datasets or real-time applications can be challenging due to its computational demands. Implementing BERTScore in production environments may require optimization strategies to provide efficient performance.
  • Domain specificity – BERTScore’s performance may vary across different domains or specialized text types. Adapting the metric to specific domains or tasks may require fine-tuning or adjustments to produce accurate evaluations.
  • Interpretability – Although BERTScore provides a comprehensive evaluation based on contextual embeddings, interpreting the specific reasons behind the similarity scores generated for each token can be complex and may require additional analysis.
  • Reference-free evaluation – Although BERTScore reduces the reliance on reference summaries for evaluation, this reference-free approach may not fully capture all aspects of summarization quality, particularly in scenarios where human-crafted references are essential for assessing content relevance and coherence.

Acknowledging these limitations can help you make informed decisions when using BERTScore as a metric for evaluating summarization tasks, providing a balanced understanding of its strengths and constraints.

When to use BERTScore

BERTScore can evaluate the quality of text summarization by comparing a generated summary to a reference summary. It uses neural networks like BERT to measure semantic similarity beyond just exact word or phrase matching. This makes BERTScore very useful when semantic fidelity, meaning preserving the full meaning and content, is critical for your summarization task. BERTScore will give higher scores to summaries that convey the same information as the reference summary, even if they use different words and sentence structures. The bottom line is that BERTScore is ideal for summarization tasks where retaining the full semantic meaning, not just keywords or topics, is vital. Its advanced neural scoring allows it to compare meaning beyond surface-level word matching. This makes it suitable for cases where subtle differences in wording can substantially alter overall meaning and implications. BERTScore, in particular, excels in capturing semantic similarity, which is crucial for assessing the quality of abstractive summaries like those produced by Retrieval Augmented Generation (RAG) models.

Model evaluation frameworks

Model evaluation frameworks are essential for accurately gauging the performance of various summarization models. These frameworks are instrumental in comparing models, providing coherence between generated summaries and source content, and pinpointing deficiencies in evaluation methods. By conducting thorough assessments and consistent benchmarking, these frameworks propel text summarization research by advocating standardized evaluation practices and enabling multifaceted model comparisons.

In AWS, the FMEval library within Amazon SageMaker Clarify streamlines the evaluation and selection of foundation models (FMs) for tasks like text summarization, question answering, and classification. It empowers you to evaluate FMs based on metrics such as accuracy, robustness, creativity, bias, and toxicity, supporting both automated and human-in-the-loop evaluations for LLMs. With UI-based or programmatic evaluations, FMEval generates detailed reports with visualizations to quantify model risks like inaccuracies, toxicity, or bias, helping organizations align with their responsible generative AI guidelines. In this section, we demonstrate how to use the FMEval library.

Evaluate Claude v2 on summarization accuracy using Amazon Bedrock

The following code snippet is an example of how to interact with the Anthropic Claude model using Python code:

import json
import boto3
# Create an Amazon Bedrock Runtime client (assumes AWS credentials and Region are configured)
bedrock_runtime = boto3.client("bedrock-runtime")
# We use Claude v2 in this example.
# See https://docs.anthropic.com/claude/reference/claude-on-amazon-bedrock#list-available-models
# for instructions on how to list the model IDs for all available Claude model variants.
model_id = 'anthropic.claude-v2'
accept = "application/json"
contentType = "application/json"
# `prompt_data` is structured in the format that the Claude model expects, as documented here:
# https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-claude.html#model-parameters-claude-request-body
# The Claude text completions API requires the prompt to start with "\n\nHuman:" and contain "\n\nAssistant:"
prompt_data = "\n\nHuman: Who is Barack Obama?\n\nAssistant:"
# For more details on parameters that can be included in `body` (such as "max_tokens_to_sample"),
# see https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-claude.html#model-parameters-claude-request-body
body = json.dumps({"prompt": prompt_data, "max_tokens_to_sample": 500})
# Invoke the model
response = bedrock_runtime.invoke_model(
    body=body, modelId=model_id, accept=accept, contentType=contentType
)
# Parse the invocation response
response_body = json.loads(response.get("body").read())
print(response_body.get("completion"))

In simple terms, this code performs the following actions:

  1. Import the necessary libraries (json and boto3) and create the Amazon Bedrock Runtime client (bedrock_runtime) used to invoke the model.
  2. Define the model ID as anthropic.claude-v2 and set the content type for the request.
  3. Create a prompt_data variable that structures the input data for the Claude model. In this case, it asks the question “Who is Barack Obama?” and expects a response from the model.
  4. Construct a JSON object named body that includes the prompt data, and specify additional parameters like the maximum number of tokens to generate.
  5. Invoke the Claude model using bedrock_runtime.invoke_model with the defined parameters.
  6. Parse the response from the model, extract the completion (generated text), and print it out.

Make sure the AWS Identity and Access Management (IAM) role associated with the Amazon SageMaker Studio user profile has access to the Amazon Bedrock models being invoked. Refer to Identity-based policy examples for Amazon Bedrock for guidance on best practices and examples of identity-based policies for Amazon Bedrock.

Using the FMEval library to evaluate the summarized output from Claude

We use the following code to evaluate the summarized output:

from fmeval.data_loaders.data_config import DataConfig
from fmeval.model_runners.bedrock_model_runner import BedrockModelRunner
from fmeval.constants import MIME_TYPE_JSONLINES
from fmeval.eval_algorithms.summarization_accuracy import SummarizationAccuracy
config = DataConfig(
    dataset_name="gigaword_sample",
    dataset_uri="gigaword_sample.jsonl",
    dataset_mime_type=MIME_TYPE_JSONLINES,
    model_input_location="document",
    target_output_location="summary"
)
bedrock_model_runner = BedrockModelRunner(
    model_id=model_id,
    output='completion',
    content_template='{"prompt": $prompt, "max_tokens_to_sample": 500}'
)
eval_algo = SummarizationAccuracy()
eval_output = eval_algo.evaluate(
    model=bedrock_model_runner,
    dataset_config=config,
    prompt_template="Human: Summarise the following text in one sentence: $feature\n\nAssistant:\n",
    save=True
)

In the preceding code snippet, to evaluate text summarization using the FMEval library, we complete the following steps:

  1. Create a ModelRunner to perform invocation on your LLM. The FMEval library provides built-in support for Amazon SageMaker endpoints and Amazon SageMaker JumpStart LLMs. You can also extend the ModelRunner interface for any LLMs hosted anywhere.
  2. Use supported evaluation algorithms, such as toxicity, accuracy (for example, summarization accuracy), and semantic robustness, based on your evaluation needs.
  3. Customize the evaluation configuration parameters for your specific use case.
  4. Use the evaluation algorithm with either built-in or custom datasets to evaluate your LLM model. The dataset used in this case is sourced from the following GitHub repo.

Refer to the developer guide and examples for detailed usage of evaluation algorithms.
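
The evaluate call returns a list of evaluation outputs containing the aggregated dataset-level scores. The following is a minimal sketch of inspecting them, assuming the eval_output object from the previous snippet and the attribute names exposed by the FMEval dataclasses:

import json
# Serialize the evaluation outputs for a readable summary
print(json.dumps(eval_output, default=vars, indent=4))
# Or iterate over the aggregated dataset scores (for example, METEOR, ROUGE, and BERTScore)
for eval_score in eval_output[0].dataset_scores:
    print(eval_score.name, eval_score.value)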

The following table summarizes the results of the evaluation.

model _input model_output target_output prompt scores meteor_score rouge_score bert_score
John Edward
0 Bates, formerly of Spalding, Linco…..
I cannot make any definitive
judgments, as th…
A former
Lincolnshire Police officer carried o…
Human: John
Edward Bates, formerly of Spalding…
[{‘name’: ‘meteor’, ‘value’:
0.101010101010101…
0.10101 0 0.557155
23 October 2015
Last updated at
17:44 BST|nIt’…
Here are some key points about hurricane/trop.. Hurricane Patricia has been rated as a categor… Human: 23
October 2015 Last updated at 17:44
B…
[{‘name’: meteor’, “value’:
0.102339181286549..
0.102339 0.018265 0.441421
Ferrari appeared in a position to challenge un… Here are the key points from the article:nin… Lewis Hamilton stormed to pole position at the… Human: Ferrari appeared in a position to chall… [{‘name’: ‘meteor’, ‘value’:
0.322543352601156…
0.322543 0.078212 0.606487
The Bath-born player, 28, has made 36
appearan…
Okay, let me summarize the key points:/nin- E….. Newport Gwent Dragons number eight Ed Jackson Human: The Bath-born player, 28, has made 36 a… [{‘name’: ‘meteor’, ‘value’:
0105740181268882…
0.10574 0.012987 0.539488
Weaknesses in the way mice swapped data with c… Here are the key points I gathered from the a… Hackers could gain access to home and Human:
Weaknesses in the
swar mice swapped data
[{‘name’: ‘meteor’, ‘value’:
0.201048289433848…
0.201048 0.021858 0.526947

Check out the sample notebook for more details about the summarization evaluation that we discussed in this post.

Conclusion

ROUGE, METEOR, and BERTScore all measure the quality of machine-generated summaries, but focus on different aspects like lexical overlap, fluency, or semantic similarity. Make sure to select the metric that aligns with what defines “good” for your specific summarization use case. You can also use a combination of metrics. This provides a more well-rounded evaluation and guards against potential weaknesses of any individual metric. With the right measurements, you can iteratively improve your summarizers to meet whichever notion of accuracy matters most.

Additionally, FM and LLM evaluation is necessary to be able to productionize these models at scale. With FMEval, you get not only a vast set of built-in algorithms across many NLP tasks, but also a scalable and flexible tool for large-scale evaluations of your own models, datasets, and algorithms. To scale up, you can use this package in your LLMOps pipelines to evaluate multiple models. To learn more about FMEval in AWS and how to use it effectively, refer to Use SageMaker Clarify to evaluate large language models. For further understanding and insights into the capabilities of SageMaker Clarify in evaluating FMs, see Amazon SageMaker Clarify Makes It Easier to Evaluate and Select Foundation Models.


About the Authors

Dinesh Kumar Subramani is a Senior Solutions Architect based in Edinburgh, Scotland. He specializes in artificial intelligence and machine learning, and is a member of the technical field community within Amazon. Dinesh works closely with UK Central Government customers to solve their problems using AWS services. Outside of work, Dinesh enjoys spending quality time with his family, playing chess, and exploring a diverse range of music.

Pranav Sharma is an AWS leader driving technology and business transformation initiatives across Europe, the Middle East, and Africa. He has experience in designing and running artificial intelligence platforms in production that support millions of customers and deliver business outcomes. He has played technology and people leadership roles for Global Financial Services organizations. Outside of work, he likes to read, play tennis with his son, and watch movies.

Read More

Enhance conversational AI with advanced routing techniques with Amazon Bedrock

Enhance conversational AI with advanced routing techniques with Amazon Bedrock

Conversational artificial intelligence (AI) assistants are engineered to provide precise, real-time responses through intelligent routing of queries to the most suitable AI functions. With AWS generative AI services like Amazon Bedrock, developers can create systems that expertly manage and respond to user requests. Amazon Bedrock is a fully managed service that offers a choice of high-performing foundation models (FMs) from leading AI companies like AI21 Labs, Anthropic, Cohere, Meta, Stability AI, and Amazon using a single API, along with a broad set of capabilities you need to build generative AI applications with security, privacy, and responsible AI.

This post assesses two primary approaches for developing AI assistants: using managed services such as Agents for Amazon Bedrock, and employing open source technologies like LangChain. We explore the advantages and challenges of each, so you can choose the most suitable path for your needs.

What is an AI assistant?

An AI assistant is an intelligent system that understands natural language queries and interacts with various tools, data sources, and APIs to perform tasks or retrieve information on behalf of the user. Effective AI assistants possess the following key capabilities:

  • Natural language processing (NLP) and conversational flow
  • Knowledge base integration and semantic searches to understand and retrieve relevant information based on the nuances of conversation context
  • Running tasks, such as database queries and custom AWS Lambda functions
  • Handling specialized conversations and user requests

We demonstrate the benefits of AI assistants using Internet of Things (IoT) device management as an example. In this use case, AI can help technicians manage machinery efficiently with commands that fetch data or automate tasks, streamlining operations in manufacturing.

Agents for Amazon Bedrock approach

Agents for Amazon Bedrock allows you to build generative AI applications that can run multi-step tasks across a company’s systems and data sources. It offers the following key capabilities:

  • Automatic prompt creation from instructions, API details, and data source information, saving weeks of prompt engineering effort
  • Retrieval Augmented Generation (RAG) to securely connect agents to a company’s data sources and provide relevant responses
  • Orchestration and running of multi-step tasks by breaking down requests into logical sequences and calling necessary APIs
  • Visibility into the agent’s reasoning through a chain-of-thought (CoT) trace, allowing troubleshooting and steering of model behavior
  • Prompt engineering abilities to modify the automatically generated prompt template for enhanced control over agents

You can use Agents for Amazon Bedrock and Knowledge Bases for Amazon Bedrock to build and deploy AI assistants for complex routing use cases. They provide a strategic advantage for developers and organizations by simplifying infrastructure management, enhancing scalability, improving security, and reducing undifferentiated heavy lifting. They also allow for simpler application layer code because the routing logic, vectorization, and memory are fully managed.

Solution overview

This solution introduces a conversational AI assistant tailored for IoT device management and operations when using Anthropic’s Claude v2.1 on Amazon Bedrock. The AI assistant’s core functionality is governed by a comprehensive set of instructions, known as a system prompt, which delineates its capabilities and areas of expertise. This guidance makes sure the AI assistant can handle a wide range of tasks, from managing device information to running operational commands.

"""The following is the system prompt that outlines the full scope of the AI assistant's capabilities:
You are an IoT Ops agent that handles the following activities:
- Looking up IoT device information
- Checking IoT operating metrics (historical data)
- Performing actions on a device-by-device ID
- Answering general questions
You can check device information (Device ID, Features, Technical Specifications, Installation Guide, Maintenance and Troubleshooting, Safety Guidelines, Warranty, and Support) from the "IotDeviceSpecs" knowledge base.
Additionally, you can access device historical data or device metrics. The device metrics are stored in an Athena DB named "iot_ops_glue_db" in a table named "iot_device_metrics". 
The table schema includes fields for oil level, temperature, pressure, received_at timestamp, and device_id.
The available actions you can perform on the devices include start, shutdown, and reboot."""

Equipped with these capabilities, as detailed in the system prompt, the AI assistant follows a structured workflow to address user questions. The following figure provides a visual representation of this workflow, illustrating each step from initial user interaction to the final response.

Workflow from initial user interaction to the final response

The workflow is composed of the following steps:

  1. The process begins when a user requests the assistant to perform a task; for example, asking for the maximum data points for a specific IoT device device_xxx. This text input is captured and sent to the AI assistant.
  2. The AI assistant interprets the user’s text input. It uses the provided conversation history, action groups, and knowledge bases to understand the context and determine the necessary tasks.
  3. After the user’s intent is parsed and understood, the AI assistant defines tasks. This is based on the instructions that are interpreted by the assistant as per the system prompt and user’s input.
  4. The tasks are then run through a series of API calls. This is done using ReAct prompting, which breaks down the task into a series of steps that are processed sequentially:
    1. For device metrics checks, we use the check-device-metrics action group, which involves an API call to Lambda functions that then query Amazon Athena for the requested data.
    2. For direct device actions like start, stop, or reboot, we use the action-on-device action group, which invokes a Lambda function. This function initiates a process that sends commands to the IoT device. For this post, the Lambda function sends notifications using Amazon Simple Email Service (Amazon SES).
    3. We use Knowledge Bases for Amazon Bedrock to fetch information from historical data stored as embeddings in the Amazon OpenSearch Service vector database.
  5. After the tasks are complete, the final response is generated by the Amazon Bedrock FM and conveyed back to the user.
  6. Agents for Amazon Bedrock automatically stores information using a stateful session to maintain the same conversation. The state is deleted after a configurable idle timeout elapses.

Technical overview

The following diagram illustrates the architecture to deploy an AI assistant with Agents for Amazon Bedrock.

Architecture diagram to deploy an AI assistant with Agents for Amazon Bedrock.

It consists of the following key components:

  • Conversational interface – The conversational interface uses Streamlit, an open source Python library that simplifies the creation of custom, visually appealing web apps for machine learning (ML) and data science. It is hosted on Amazon Elastic Container Service (Amazon ECS) with AWS Fargate, and it is accessed using an Application Load Balancer. You can use Fargate with Amazon ECS to run containers without having to manage servers, clusters, or virtual machines.
  • Agents for Amazon Bedrock – Agents for Amazon Bedrock completes the user queries through a series of reasoning steps and corresponding actions based on ReAct prompting:
    • Knowledge Bases for Amazon Bedrock – Knowledge Bases for Amazon Bedrock provides fully managed RAG to supply the AI assistant with access to your data. In our use case, we uploaded device specifications into an Amazon Simple Storage Service (Amazon S3) bucket. It serves as the data source to the knowledge base.
    • Action groups – These are defined API schemas that invoke specific Lambda functions to interact with IoT devices and other AWS services.
    • Anthropic Claude v2.1 on Amazon Bedrock – This model interprets user queries and orchestrates the flow of tasks.
    • Amazon Titan Embeddings – This model serves as a text embeddings model, transforming natural language text—from single words to complex documents—into numerical vectors. This enables vector search capabilities, allowing the system to semantically match user queries with the most relevant knowledge base entries for effective search.

The solution is integrated with AWS services such as Lambda for running code in response to API calls, Athena for querying datasets, OpenSearch Service for searching through knowledge bases, and Amazon S3 for storage. These services work together to provide a seamless experience for IoT device operations management through natural language commands.
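
For reference, the following is a minimal sketch of how an application such as the Streamlit front end could call the agent programmatically with the boto3 bedrock-agent-runtime client (the agent ID, alias ID, and session handling are illustrative assumptions):

import uuid
import boto3
# Runtime client for invoking Agents for Amazon Bedrock
bedrock_agent_runtime = boto3.client("bedrock-agent-runtime")
def ask_agent(agent_id, agent_alias_id, session_id, user_input):
    # Invoke the agent; the response body is returned as an event stream
    response = bedrock_agent_runtime.invoke_agent(
        agentId=agent_id,
        agentAliasId=agent_alias_id,
        sessionId=session_id,  # reuse the same session ID to keep conversation state
        inputText=user_input,
    )
    completion = ""
    for event in response["completion"]:
        chunk = event.get("chunk")
        if chunk:
            completion += chunk["bytes"].decode("utf-8")
    return completion
# Example usage with hypothetical identifiers
print(ask_agent("AGENT_ID", "AGENT_ALIAS_ID", str(uuid.uuid4()),
                "What are the max metrics for device 1009?"))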

Benefits

This solution offers the following benefits:

  • Implementation complexity:
    • Fewer lines of code are required, because Agents for Amazon Bedrock abstracts away much of the underlying complexity, reducing development effort
    • Managing vector databases like OpenSearch Service is simplified, because Knowledge Bases for Amazon Bedrock handles vectorization and storage
    • Integration with various AWS services is more streamlined through pre-defined action groups
  • Developer experience:
    • The Amazon Bedrock console provides a user-friendly interface for prompt development, testing, and root cause analysis (RCA), enhancing the overall developer experience
  • Agility and flexibility:
    • Agents for Amazon Bedrock allows for seamless upgrades to newer FMs (such as Claude 3.0) when they become available, so your solution stays up to date with the latest advancements
    • Service quotas and limitations are managed by AWS, reducing the overhead of monitoring and scaling infrastructure
  • Security:
    • Amazon Bedrock is a fully managed service, adhering to AWS’s stringent security and compliance standards, potentially simplifying organizational security reviews

Although Agents for Amazon Bedrock offers a streamlined and managed solution for building conversational AI applications, some organizations may prefer an open source approach. In such cases, you can use frameworks like LangChain, which we discuss in the next section.

LangChain dynamic routing approach

LangChain is an open source framework that simplifies building conversational AI by allowing the integration of large language models (LLMs) and dynamic routing capabilities. With LangChain Expression Language (LCEL), developers can define the routing, which allows you to create non-deterministic chains where the output of a previous step defines the next step. Routing helps provide structure and consistency in interactions with LLMs.

For this post, we use the same example as the AI assistant for IoT device management. However, the main difference is that we need to handle the system prompts separately and treat each chain as a separate entity. The routing chain decides the destination chain based on the user’s input. The decision is made with the support of an LLM by passing the system prompt, chat history, and user’s question.

Solution overview

The following diagram illustrates the dynamic routing solution workflow.

Dynamic routing solution workflow with LangChain

The workflow consists of the following steps:

  1. The user presents a question to the AI assistant. For example, “What are the max metrics for device 1009?”
  2. An LLM evaluates each question along with the chat history from the same session to determine its nature and which subject area it falls under (such as SQL, action, search, or SME). The LLM classifies the input and the LCEL routing chain takes that input.
  3. The router chain selects the destination chain based on the input, and the LLM is provided with the following system prompt:
"""Given the user question below, classify it as one of the candidate prompts. You may want to modify the input considering the chat history and the context of the question. 
Sometimes the user may just assume that you have the context of the conversation and may not provide a clear input. Hence, you are being provided with the chat history for more context. 
Respond with only a Markdown code snippet containing a JSON object formatted EXACTLY as specified below. 
Do not provide an explanation to your classification beside the Markdown, I just need to know your decision on which destination and next_inputs
<candidate prompt>
physics: Good for answering questions about physics
sql: sql: Good for querying sql from AWS Athena. User input may look like: get me max or min for device x?
lambdachain: Good to execute actions with Amazon Lambda like shutting down a device or turning off an engine User input can be like, shutdown device x, or terminate process y, etc.
rag: Good to search knowledgebase and retrieve information about devices and other related information. User question can be like: what do you know about device x?
default: if the input is not well suited for any of the candidate prompts above. this could be used to carry on the conversation and respond to queries like provide a summary of the conversation
</candidate prompt>"""

The LLM evaluates the user’s question along with the chat history to determine the nature of the query and which subject area it falls under. The LLM then classifies the input and outputs a JSON response in the following format:

<Markdown>
```json
{{
"destination": string  name of the prompt to use
"next_inputs": string  a potentially modified version of the original input
}}
```

  4. The router chain uses this JSON response to invoke the corresponding destination chain. There are four subject-specific destination chains, each with its own system prompt (a simplified code sketch of this routing follows the list):
    1. SQL-related queries are sent to the SQL destination chain for database interactions. You can use LCEL to build the SQL chain.
    2. Action-oriented questions invoke the custom Lambda destination chain for running operations. With LCEL, you can define your own custom function; in our case, it’s a function to run a predefined Lambda function to send an email with a device ID parsed. Example user input might be “Shut down device 1009.”
    3. Search-focused inquiries proceed to the RAG destination chain for information retrieval.
    4. SME-related questions go to the SME/expert destination chain for specialized insights.
  5. Each destination chain takes the input and runs the necessary models or functions:
    1. The SQL chain uses Athena for running queries.
    2. The RAG chain uses OpenSearch Service for semantic search.
    3. The custom Lambda chain runs Lambda functions for actions.
    4. The SME/expert chain provides insights using the Amazon Bedrock model.
  6. Responses from each destination chain are formulated into coherent insights by the LLM. These insights are then delivered to the user, completing the query cycle.
  7. User input and responses are stored in Amazon DynamoDB to provide context to the LLM for the current session and from past interactions. The duration of persisted information in DynamoDB is controlled by the application.
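
The following is a minimal, simplified sketch of this routing pattern with LCEL. It assumes the langchain-aws integration package for the Amazon Bedrock chat model, omits chat history and two of the destination chains, and uses illustrative prompts rather than the exact ones from the solution:

# Requires: pip install langchain-core langchain-aws
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnableLambda
from langchain_aws import ChatBedrock
llm = ChatBedrock(model_id="anthropic.claude-v2:1")
# Router chain: classify the question into one of the candidate destinations
router_prompt = ChatPromptTemplate.from_template(
    "Classify the user question as one of: sql, lambdachain, rag, default. "
    "Respond with the single word only.\n\nQuestion: {question}"
)
router_chain = router_prompt | llm | StrOutputParser()
# Simplified destination chains (each has its own system prompt in the full solution)
sql_chain = ChatPromptTemplate.from_template(
    "Write an Amazon Athena SQL query that answers: {question}") | llm | StrOutputParser()
rag_chain = ChatPromptTemplate.from_template(
    "Answer using the device knowledge base: {question}") | llm | StrOutputParser()
default_chain = ChatPromptTemplate.from_template("{question}") | llm | StrOutputParser()
def route(inputs):
    # Returning a runnable from this function causes LCEL to invoke it with the same input
    destination = router_chain.invoke({"question": inputs["question"]}).strip().lower()
    if "sql" in destination:
        return sql_chain
    if "rag" in destination:
        return rag_chain
    return default_chain
full_chain = RunnableLambda(route)
print(full_chain.invoke({"question": "What are the max metrics for device 1009?"}))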

Technical overview

The following diagram illustrates the architecture of the LangChain dynamic routing solution.

Architecture diagram of the LangChain dynamic routing solution

The web application is built on Streamlit hosted on Amazon ECS with Fargate, and it is accessed using an Application Load Balancer. We use Anthropic’s Claude v2.1 on Amazon Bedrock as our LLM. The web application interacts with the model using LangChain libraries. It also interacts with a variety of other AWS services, such as OpenSearch Service, Athena, and DynamoDB, to fulfill end users’ needs.

Benefits

This solution offers the following benefits:

  • Implementation complexity:
    • Although it requires more code and custom development, LangChain provides greater flexibility and control over the routing logic and integration with various components.
    • Managing vector databases like OpenSearch Service requires additional setup and configuration efforts. The vectorization process is implemented in code.
    • Integrating with AWS services may involve more custom code and configuration.
  • Developer experience:
    • LangChain’s Python-based approach and extensive documentation can be appealing to developers already familiar with Python and open source tools.
    • Prompt development and debugging may require more manual effort compared to using the Amazon Bedrock console.
  • Agility and flexibility:
    • LangChain supports a wide range of LLMs, allowing you to switch between different models or providers, fostering flexibility.
    • The open source nature of LangChain enables community-driven improvements and customizations.
  • Security:
    • As an open source framework, LangChain may require more rigorous security reviews and vetting within organizations, potentially adding overhead.

Conclusion

Conversational AI assistants are transformative tools for streamlining operations and enhancing user experiences. This post explored two powerful approaches using AWS services: the managed Agents for Amazon Bedrock and the flexible, open source LangChain dynamic routing. The choice between these approaches hinges on your organization’s requirements, development preferences, and desired level of customization. Regardless of the path taken, AWS empowers you to create intelligent AI assistants that revolutionize business and customer interactions.

Find the solution code and deployment assets in our GitHub repository, where you can follow the detailed steps for each conversational AI approach.


About the Authors

Ameer Hakme is an AWS Solutions Architect based in Pennsylvania. He collaborates with Independent Software Vendors (ISVs) in the Northeast region, assisting them in designing and building scalable and modern platforms on the AWS Cloud. An expert in AI/ML and generative AI, Ameer helps customers unlock the potential of these cutting-edge technologies. In his leisure time, he enjoys riding his motorcycle and spending quality time with his family.

Sharon Li is an AI/ML Solutions Architect at Amazon Web Services based in Boston, with a passion for designing and building Generative AI applications on AWS. She collaborates with customers to leverage AWS AI/ML services for innovative solutions.

Kawsar Kamal is a senior solutions architect at Amazon Web Services with over 15 years of experience in the infrastructure automation and security space. He helps clients design and build scalable DevSecOps and AI/ML solutions in the Cloud.

Read More

Improve LLM performance with human and AI feedback on Amazon SageMaker for Amazon Engineering

Improve LLM performance with human and AI feedback on Amazon SageMaker for Amazon Engineering

The Amazon EU Design and Construction (Amazon D&C) team is the engineering team designing and constructing Amazon warehouses. The team navigates a large volume of documents and locates the right information to make sure the warehouse design meets the highest standards. In the post A generative AI-powered solution on Amazon SageMaker to help Amazon EU Design and Construction, we presented a question answering bot solution using a Retrieval Augmented Generation (RAG) pipeline with a fine-tuned large language model (LLM) for Amazon D&C to efficiently retrieve accurate information from a large volume of unorganized documents, and provide timely and high-quality services in their construction projects. The Amazon D&C team implemented the solution in a pilot for Amazon engineers and collected user feedback.

In this post, we share how we analyzed the feedback data, identified the accuracy limitations and hallucinations in the responses that RAG provided, and used the human evaluation scores to train the model through reinforcement learning. To increase training samples for better learning, we also used another LLM to generate feedback scores. This method addressed the RAG limitation and further improved the bot response quality. We present the reinforcement learning process and the benchmarking results to demonstrate the LLM performance improvement. The solution uses Amazon SageMaker JumpStart as the core service for model deployment, fine-tuning, and reinforcement learning.

Collect feedback from Amazon engineers in a pilot project

After developing the solution described in A generative AI-powered solution on Amazon SageMaker to help Amazon EU Design and Construction, the Amazon D&C team deployed the solution and ran a pilot project with Amazon engineers. The engineers accessed the pilot system through a web application developed by Streamlit, connected with the RAG pipeline. In the pipeline, we used Amazon OpenSearch Service for the vector database, and deployed a fine-tuned Mistral-7B-Instruct model on Amazon SageMaker.

One of the key objectives of the pilot is to collect feedback from Amazon engineers and use the feedback to further reduce LLM hallucinations. To achieve this, we developed a feedback collection module in the UI, as shown in the following figure, and stored the web session information and user feedback in Amazon DynamoDB. Through the feedback collection UI, Amazon engineers can select from five satisfaction levels: strongly disagree, disagree, neutral, agree, and strongly agree, corresponding to the feedback scores from 1–5. They can also provide a better answer to the question or comment on why the LLM response is not satisfactory.

human feedback UI
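
The following is a minimal sketch of how such a feedback record could be written to Amazon DynamoDB with boto3 (the table name and attribute names are hypothetical; the exact schema is not specified in this post):

import time
import uuid
import boto3
# DynamoDB table that stores web session information and user feedback (hypothetical name)
dynamodb = boto3.resource("dynamodb")
feedback_table = dynamodb.Table("qa-bot-user-feedback")
def store_feedback(session_id, question, llm_answer, score, user_notes=""):
    # Persist one feedback record; score is 1-5 (strongly disagree to strongly agree)
    feedback_table.put_item(
        Item={
            "feedback_id": str(uuid.uuid4()),
            "session_id": session_id,
            "question": question,
            "llm_answer": llm_answer,
            "feedback_score": score,
            "user_notes": user_notes,
            "created_at": int(time.time()),
        }
    )
store_feedback("session-123", "How many car parking spaces do we have?",
               "Based on the given context, we don't have enough information...", 1,
               "This is specified on page 21 of design criteria section 01 13 10")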

During the pilot, users provided 118 feedback responses. The following chart shows the distribution from negative to positive. We observed that incorrect answers, marked by negative feedback (disagree and strongly disagree), accounted for 45% of the total feedback (53 out of 118). Some of the errors (about 10% of negative feedback and 7.5% of overall responses) can be addressed by user education and prompt engineering. Additionally, we can address the issue with the LLM fine-tuning and reinforcement learning solution described in the next section.

User feedback from pilot

Some users simply left a note, such as “Great!” for the strongly agree answers, or “Doesn’t answer the question,” for the strongly disagree answers. Other users provided scores and explained in their notes how they judged the LLM answers. In the following example, the user specifically provided the relevant document and content to correct the LLM hallucination.

Question How many car parking spaces do we have?
Bot Answer

Based on the given context, we don’t have enough information to determine the number of car parking spaces. The context discusses the minimum size of parking spaces and the requirements for external lighting levels, but it does not provide a specific number of parking spaces.

**You can refer to further explanations in the following resources:**

* ARS GEN 10.0/05.01.02. Design Criteria & Appendices/Performance Package AR Sortable Design Criteria v20.1.1.pdf – page: 10

* ARS GEN 10.0/05.01.02. Design Criteria & Appendices/Performance Package AR Sortable Design Criteria v20.1.1.pdf – page: 79

User score Strongly Disagree
User notes This is specified on page 21 of design criteria section 01 13 10

Improve bot response with supervised fine-tuning and reinforcement learning

The solution consists of three steps of fine-tuning:

  1. Conduct supervised fine-tuning using labeled data. This method was described in A generative AI-powered solution on Amazon SageMaker to help Amazon EU Design and Construction.
  2. Collect user feedback to label the question-answer pairs for further LLM tuning.
  3. When the training data is ready, further tune the model using reinforcement learning from human feedback (RLHF).

RLHF is widely used throughout generative artificial intelligence (AI) and LLM applications. It incorporates human feedback in the rewards function and trains the model with a reinforcement learning algorithm to maximize rewards, which makes the model perform tasks more aligned with human goals. The following diagram shows the pipeline of the steps.

Fine tuning workflow

We tested the methodology using the Amazon D&C documents with a Mistral-7B model on SageMaker JumpStart.

Supervised fine-tuning

In the previous post, we demonstrated how the fine-tuned Falcon-7B model outperforms the RAG pipeline and improves the quality and accuracy of the QA bot response. For this post, we performed supervised fine-tuning on the Mistral-7B model. The supervised fine-tuning used the PEFT/LoRA technique (LoRA_r = 512, LoRA_alpha = 1024) on 436,207,616 parameters (5.68% of the total 7,677,964,288 parameters). The training was conducted on a p3.8x node with 137 samples synthetically generated by LLM and validated by humans; the process converged well after 20 epochs, as shown in the following figure.

SFT training process

The fine-tuned model was validated on 274 samples, and the inference results were compared with the reference answers using a semantic similarity score. The score is 0.8100, which is higher than the score of 0.6419 from the traditional RAG.
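
One common way to compute such a semantic similarity score is cosine similarity between sentence embeddings. The following is a minimal sketch using the open source sentence-transformers library (the embedding model and example answers are illustrative assumptions, not the exact setup behind the reported scores):

# Requires: pip install sentence-transformers
from sentence_transformers import SentenceTransformer, util
# Illustrative embedding model; any sentence-embedding model could be used
model = SentenceTransformer("all-MiniLM-L6-v2")
reference_answer = "Concrete Screed flooring is applicable only in floor-tiled areas."
model_answer = "Concrete Screed flooring is applicable only in areas being tiled."
embeddings = model.encode([reference_answer, model_answer])
score = util.cos_sim(embeddings[0], embeddings[1]).item()
print(f"Semantic similarity: {score:.4f}")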

Collect human and AI feedback for reinforcement learning

For RLHF, a sufficient number of high-quality training samples labeled by subject matter experts (SMEs) is essential. However, poor-quality human labels will likely cause worse model performance than the original model after RLHF training. SMEs’ time is a scarce resource in any organization; reviewing hundreds or thousands of LLM responses and providing feedback requires a significant time investment from SMEs, which may not have a clear return on investment.

To address this challenge, we adopted the reinforcement learning from AI feedback (RLAIF) method. RLAIF employs an AI assistant (another LLM) to provide evaluation scores, rather than humans. In this hybrid learning approach, the learning agent refines its actions based not only on interaction with a human but also on feedback provided by another AI model. This approach is much more scalable for providing sufficient training data for reinforcement learning, and much less subjective because the training doesn’t depend on the particular preferences of a small group of SMEs. The SMEs may still be involved, but only to supervise the process and examine the AI feedback quality. This tremendously reduces the SMEs’ workload, because the AI-generated scores and labels are much less complicated for SMEs to analyze, filter, and sort.

In this use case, we provide the following prompt to an Anthropic Claude 2 model to generate the evaluation scores as the training samples for the reinforcement learning step:

"""

Human:

I am going to provide you with a question, a reference answer, and an machine generated answer to the question.

question: {question}

reference answer: {reference_answer}

machine_answer: {machine_answer}

Please analyze the question, the reference answer, and the machine generated answer to determine if the machine generated answer semantically matches the reference answer.

First think through your answer inside of <thinking> tags, then assign a score between 0.0 and 5.0 based on how well the machine answer matches the reference answer. Score 5.0 means a completely match, score 0.0 means a completely mismatch. Answer the score inside of <scoring> tags.

Also provide the reason to give this score inside of <reason> tags.

Assistant: <thinking>

"""

The Claude 2 model generates an evaluation score and a comment on why the score is given, as shown in the following example.

Question What is meant by ‘Fit Out Works’ in this context?
Reference Answer Fit Out Works refers to Works carried out under the construction contract by or on behalf of Amazon via the Developer up to First Receive Date from a notional Developer Shell & Core Base-build position. These Works include Amazon’s required Fit-Outs, big-box installations, and custom
QA Bot response Fit Out Works refers to Works carried out under the construction contract by or on behalf of Amazon via the Developer up to First Receive Date from a notional Developer Shell & Core Base-build position.
AI score 4.5
AI comment The machine answer provides the key information from the standard answer – that Fit Out Works refers to Works carried out under the construction contract by or on behalf of Amazon via the Developer up to First Receive Date from a notional Developer Shell & Core Base-build position. The additional details about including Amazon’s required Fit-Outs, big-box installations, and customizations are relevant and help provide further context, so I’ve scored the machine answer 4.5 out of 5. It captures the essence of the standard answer and provides some useful extra details.

Out of the 274 validation questions, the supervised fine-tuned model generated 159 responses that have AI scores greater than 4. We observed 60 answers with scores lower than 3; there is room to improve the overall response quality.

Feedback score before RLHF

The Amazon Engineering SMEs validated this AI feedback and acknowledged the benefits of using AI scores. Without AI feedback, the SMEs would need some time to review and analyze each LLM response to identify the cut-off answers and hallucinations, and to judge whether the LLM is returning correct contents and key concepts. AI feedback provides AI scores automatically and enables the SMEs to use filtering, sorting, and grouping to validate the scores and identify trends in the responses. This reduces the average SME’s review time by 80%.

Reinforcement learning from human and AI feedback

When the training samples are ready, we use the proximal policy optimization (PPO) algorithm to perform reinforcement learning. PPO uses a policy gradient method, which takes small steps to update the policy in the learning process, so that the learning agents can reliably reach the optimal policy network. This makes the training process more stable and reduces the possibility of divergence.

During the training, first we use the human- and AI-labeled data to build a reward model, which will be used to guide the weight updates in the learning process. For this use case, we select a distilroberta-base reward model and train it with samples in the following format:

[Instruction, Chosen_response, Rejected_response]

The following is an example of a training record.

Instruction According to the context, what is specified for inclusive and accessible design?
Chosen_response BREEAM Credit HEA06 – inclusive and accessible design – The building is designed to be fit for purpose, appropriate and accessible by all potential users. An access strategy is developed in line with the BREEAM Check list A3
Rejected_response The context states that

The reward model is trained with the learning rate 1e-5. As shown in the following chart, the training converges well after 10 epochs.

RLHF training process

Then the reward model is used for reinforcement learning. For this use case, we use PEFT/LoRA to accelerate the learning. The LoRA_r is 32 and LoRA_alpha is 64. The PPO-based learning uses a learning rate of 1.41e-5 over three epochs. Reinforcement training is highly volatile and very sensitive to the label data and feedback data. We use the reward_mean metric (the average reward obtained from the environment, as measured by the reward model) to observe the learning process.
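
The following is a heavily simplified sketch of a single PPO update using the open source Hugging Face TRL library. It assumes the classic PPOTrainer API, placeholder model paths, and a single training sample; the actual SageMaker-based implementation is not shown in this post:

# Requires: pip install trl transformers torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from trl import PPOConfig, PPOTrainer, AutoModelForCausalLMWithValueHead
# Placeholder identifiers; the real policy and reward models are trained as described above
policy_name = "mistralai/Mistral-7B-Instruct-v0.1"
reward_name = "path/to/trained-distilroberta-reward-model"
tokenizer = AutoTokenizer.from_pretrained(policy_name)
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLMWithValueHead.from_pretrained(policy_name)
reward_tokenizer = AutoTokenizer.from_pretrained(reward_name)
reward_model = AutoModelForSequenceClassification.from_pretrained(reward_name)
config = PPOConfig(model_name=policy_name, learning_rate=1.41e-5, batch_size=1, mini_batch_size=1)
ppo_trainer = PPOTrainer(config=config, model=model, tokenizer=tokenizer)
query_txt = "According to the context, what is specified for inclusive and accessible design?"
query_tensor = tokenizer.encode(query_txt, return_tensors="pt")[0]
# Generate a response with the current policy
response_tensor = ppo_trainer.generate([query_tensor], return_prompt=False, max_new_tokens=128)[0]
response_txt = tokenizer.decode(response_tensor, skip_special_tokens=True)
# Score the (query, response) pair with the trained reward model
reward_inputs = reward_tokenizer(query_txt, response_txt, return_tensors="pt", truncation=True)
reward = reward_model(**reward_inputs).logits[0, 0]
# One PPO optimization step: update the policy to increase the expected reward
stats = ppo_trainer.step([query_tensor], [response_tensor], [reward])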

After RLAIF, we can see the number of bot responses with positive scores (above 3) increase from 78.1% to 85.5%, and those with negative scores (below 3) decrease from 21.9% to 14.5%. In total, the average AI feedback score increases by about 8%, from 3.9 to 4.2.

Improvement by RLAIF

You can see the improvement in the following example.

Question For floor areas being tiled, what screed is needed?
Reference Concrete Screed flooring is applicable only in floor-tiled areas. Refer to drawing series A-32 to see where floor-tiles are applicable.
Bot response before RLAIF Concrete Screed flooring is applicable only in
AI score before RLAIF 3
Bot response after RLAIF Concrete Screed flooring is applicable only in areas being tiled. Refer to drawing series A-32 to see where tile flooring is applicable.
AI score after RLAIF 4.5

Conclusion

This post is a continuation of our work with the Amazon EU Design and Construction team as discussed in A generative AI-powered solution on Amazon SageMaker to help Amazon EU Design and Construction. In this post, we showed how we generated human and AI feedback data to fine-tune the Mistral-7B model with reinforcement learning. The model after RLAIF provided better performance for Amazon Engineering’s question answering bot, improving the AI feedback score by 8%. In the Amazon D&C team’s pilot project, using RLAIF reduced the validation workload for SMEs by an estimated 80%. As the next step, we will scale up this solution by connecting with Amazon Engineering’s data infrastructure, and design a framework to automate the continuous learning process with a human in the loop. We will also further improve the AI feedback quality by tuning the prompt template.

Through this process, we learned how to further improve the quality and performance of question answering tasks through RLHF and RLAIF.

  • Human validation and augmentation are essential to provide accurate and responsible outputs from LLMs. The human feedback can be used in RLHF to further improve the model response.
  • RLAIF automates the evaluation and learning cycle. The AI-generated feedback is less subjective because it doesn’t depend on a particular preference from a small pool of SMEs.
  • RLAIF is more scalable to improve the bot quality through continued reinforcement learning while minimizing the efforts required from SMEs. It is especially useful for developing domain-specific generative AI solutions within large organizations.
  • This process should be done on a regular basis, especially when new domain data is available to be covered by the solution.

In this use case, we used SageMaker JumpStart to test multiple LLMs and experiment with multiple LLM training approaches. It significantly accelerates the AI feedback and learning cycle with maximized efficiency and quality. For your own project, you can introduce the human-in-the-loop approach to collect your users’ feedback, or generate AI feedback using another LLM. Then you can follow the three-step process defined in this post to fine-tune your models using RLHF and RLAIF. We recommend experimenting with the methods using SageMaker JumpStart to speed up the process.


About the Author

Yunfei Bai is a Senior Solutions Architect at AWS. With a background in AI/ML, data science, and analytics, Yunfei helps customers adopt AWS services to deliver business results. He designs AI/ML and data analytics solutions that overcome complex technical challenges and drive strategic objectives. Yunfei has a PhD in Electronic and Electrical Engineering. Outside of work, Yunfei enjoys reading and music.

Elad Dwek is a Construction Technology Manager at Amazon. With a background in construction and project management, Elad helps teams adopt new technologies and data-based processes to deliver construction projects. He identifies needs and solutions, and facilitates the development of the bespoke attributes. Elad has an MBA and a BSc in Structural Engineering. Outside of work, Elad enjoys yoga, woodworking, and traveling with his family.

Luca Cerabone is a Business Intelligence Engineer at Amazon. Drawing from his background in data science and analytics, Luca crafts tailored technical solutions to meet the unique needs of his customers, driving them towards more sustainable and scalable processes. Armed with an MSc in Data Science, Luca enjoys engaging in DIY projects, gardening and experimenting with culinary delights in his leisure moments.

Read More