Fine-tune and Deploy Mistral 7B with Amazon SageMaker JumpStart

Today, we are excited to announce the capability to fine-tune the Mistral 7B model using Amazon SageMaker JumpStart. You can now fine-tune and deploy Mistral text generation models on SageMaker JumpStart using the Amazon SageMaker Studio UI with a few clicks or using the SageMaker Python SDK.

Foundation models perform very well with generative tasks, from crafting text and summaries and answering questions to producing images and videos. Despite the great generalization capabilities of these models, some use cases involve very specific domain data (such as healthcare or financial services), and the models may not provide good results for them. This creates a need to further fine-tune these generative AI models on use case-specific and domain-specific data.

In this post, we demonstrate how to fine-tune the Mistral 7B model using SageMaker JumpStart.

What is Mistral 7B

Mistral 7B is a foundation model developed by Mistral AI, supporting English text and code generation abilities. It supports a variety of use cases, such as text summarization, classification, text completion, and code completion. To demonstrate the customizability of the model, Mistral AI has also released a Mistral 7B-Instruct model for chat use cases, fine-tuned using a variety of publicly available conversation datasets.

Mistral 7B is a transformer model and uses grouped query attention and sliding window attention to achieve faster inference (low latency) and handle longer sequences. Grouped query attention is an architecture that combines multi-query and multi-head attention to achieve output quality close to multi-head attention and comparable speed to multi-query attention. Sliding window attention restricts each layer to a fixed window of preceding tokens; because the layers are stacked, information from earlier tokens still propagates forward, which helps the model use a longer stretch of context. Mistral 7B has an 8,000-token context length, demonstrates low latency and high throughput, and has strong performance when compared to larger model alternatives, providing low memory requirements at a 7B model size. The model is made available under the permissive Apache 2.0 license, for use without restrictions.
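To make the sliding window idea concrete, the following is a minimal sketch of a causal sliding window attention mask in plain Python. It is an illustration only, not Mistral AI's implementation; the function names and the convention that the window includes the current token are assumptions made for this example.

```python
def sliding_window_mask(seq_len: int, window: int) -> list[list[bool]]:
    """mask[i][j] is True when query position i may attend to key position j.

    Assumed convention: the window counts the current token, so position i
    sees positions i - window + 1 through i and nothing in the future.
    """
    return [
        [0 <= i - j < window for j in range(seq_len)]
        for i in range(seq_len)
    ]


def max_lookback(window: int, num_layers: int) -> int:
    """Farthest distance (in tokens) that information can travel backward
    after stacking num_layers layers, under the same convention."""
    return num_layers * (window - 1)


mask = sliding_window_mask(seq_len=5, window=3)
for row in mask:
    print("".join("x" if allowed else "." for allowed in row))

print(max_lookback(window=3, num_layers=2))
```

Each row of the mask shows which earlier tokens a position can read directly; stacking layers extends the reach because each layer reads hidden states that already summarize their own window.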

You can fine-tune the models using either the SageMaker Studio UI or SageMaker Python SDK. We discuss both methods in this post.

Fine-tune via the SageMaker Studio UI

In SageMaker Studio, you can access the Mistral model via SageMaker JumpStart under Models, notebooks, and solutions, as shown in the following screenshot.

If you don’t see Mistral models, update your SageMaker Studio version by shutting down and restarting. For more information about version updates, refer to Shut down and Update Studio Apps.

On the model page, you can point to the Amazon Simple Storage Service (Amazon S3) bucket containing the training and validation datasets for fine-tuning. In addition, you can configure deployment configuration, hyperparameters, and security settings for fine-tuning. You can then choose Train to start the training job on a SageMaker ML instance.

Deploy the model

After the model is fine-tuned, you can deploy it using the model page on SageMaker JumpStart. The option to deploy the fine-tuned model will appear when fine-tuning is complete, as shown in the following screenshot.

Fine-tune via the SageMaker Python SDK

You can also fine-tune Mistral models using the SageMaker Python SDK. The complete notebook is available on GitHub. In this section, we provide examples of two types of fine-tuning.

Instruction fine-tuning

Instruction tuning is a technique that involves fine-tuning a language model on a collection of natural language processing (NLP) tasks using instructions. In this technique, the model is trained to perform tasks by following textual instructions instead of specific datasets for each task. The model is fine-tuned with a set of input and output examples for each task, allowing the model to generalize to new tasks that it hasn’t been explicitly trained on as long as prompts are provided for the tasks. Instruction tuning helps improve the accuracy and effectiveness of models and is helpful in situations where large datasets aren’t available for specific tasks.

Let’s walk through the fine-tuning code provided in the example notebook with the SageMaker Python SDK.

We use a subset of the Dolly dataset in an instruction tuning format, and specify the template.json file describing the input and the output formats. The training data must be formatted in JSON lines (.jsonl) format, where each line is a dictionary representing a single data sample. In this case, we name it train.jsonl.

The following snippet is an example of train.jsonl. The keys instruction, context, and response in each sample should have corresponding entries {instruction}, {context}, {response} in the template.json.

{
    "instruction": "What is a dispersive prism?", 
    "context": "In optics, a dispersive prism is an optical prism that is used to disperse light, that is, to separate light into its spectral components (the colors of the rainbow). Different wavelengths (colors) of light will be deflected by the prism at different angles. This is a result of the prism material's index of refraction varying with wavelength (dispersion). Generally, longer wavelengths (red) undergo a smaller deviation than shorter wavelengths (blue). The dispersion of white light into colors by a prism led Sir Isaac Newton to conclude that white light consisted of a mixture of different colors.", 
    "response": "A dispersive prism is an optical prism that disperses the light's different wavelengths at different angles. When white light is shined through a dispersive prism it will separate into the different colors of the rainbow."
}
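As a rough sketch of how such a file could be produced and checked before upload, the following writes samples in JSON lines format and verifies that every line carries the three expected keys. The helper names write_jsonl and validate_jsonl are hypothetical and are not part of the SageMaker SDK; the sample content is shortened for illustration.

```python
import json
import os
import tempfile

REQUIRED_KEYS = {"instruction", "context", "response"}


def write_jsonl(samples, path):
    """Write one JSON object per line."""
    with open(path, "w", encoding="utf-8") as f:
        for sample in samples:
            f.write(json.dumps(sample, ensure_ascii=False) + "\n")


def validate_jsonl(path):
    """Return the number of samples; raise ValueError if a line lacks a key."""
    count = 0
    with open(path, encoding="utf-8") as f:
        for line_number, line in enumerate(f, start=1):
            if not line.strip():
                continue  # tolerate blank lines
            sample = json.loads(line)
            missing = REQUIRED_KEYS - sample.keys()
            if missing:
                raise ValueError(f"line {line_number} is missing keys: {sorted(missing)}")
            count += 1
    return count


samples = [
    {
        "instruction": "What is a dispersive prism?",
        "context": "In optics, a dispersive prism is used to disperse light.",
        "response": "A dispersive prism separates light into its spectral components.",
    }
]
path = os.path.join(tempfile.mkdtemp(), "train.jsonl")
write_jsonl(samples, path)
print(validate_jsonl(path))
```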

The following is a sample of template.json:

{
    "prompt": "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Input:\n{context}\n\n",
    "completion": " {response}"
}
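To illustrate how the placeholders line up with the keys of a training sample, the following sketch fills the template with Python's str.format. How JumpStart applies the template internally is not shown in this post, so treat this as an assumption-based illustration; the render function and the shortened sample are hypothetical.

```python
template = {
    "prompt": (
        "Below is an instruction that describes a task, paired with an input "
        "that provides further context. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n### Input:\n{context}\n\n"
    ),
    "completion": " {response}",
}


def render(template, sample):
    """Fill the prompt and completion placeholders from one training sample."""
    prompt = template["prompt"].format(**sample)
    completion = template["completion"].format(**sample)
    return prompt, completion


sample = {
    "instruction": "What is a dispersive prism?",
    "context": "A prism separates light.",
    "response": "An optical prism that disperses light.",
}
prompt, completion = render(template, sample)
print(prompt + completion)
```

A sample that lacks one of the three keys raises a KeyError here, which mirrors the requirement that every placeholder in template.json has a matching key in each line of train.jsonl.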

After you upload the prompt template and the training data to an S3 bucket, you can set the hyperparameters.

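from sagemaker import hyperparameters

# Assumed JumpStart model ID and version for Mistral 7B; confirm them against the example notebook
model_id, model_version = "huggingface-llm-mistral-7b", "*"

# Start from the default hyperparameters for the model, then override selected values
my_hyperparameters = hyperparameters.retrieve_default(model_id=model_id, model_version=model_version)
my_hyperparameters["epoch"] = "1"
my_hyperparameters["per_device_train_batch_size"] = "2"
my_hyperparameters["gradient_accumulation_steps"] = "2"
my_hyperparameters["instruction_tuned"] = "True"
print(my_hyperparameters)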

You can then start the fine-tuning process and deploy the model to an inference endpoint. In the following code, we use an ml.g5.12xlarge instance:

from sagemaker.jumpstart.estimator import JumpStartEstimator

# train_data_location is the S3 URI of the folder that holds train.jsonl and template.json
instruction_tuned_estimator = JumpStartEstimator(
    model_id=model_id,
    hyperparameters=my_hyperparameters,
    instance_type="ml.g5.12xlarge",
)
instruction_tuned_estimator.fit({"train": train_data_location}, logs=True)

# Deploy the fine-tuned model to a real-time inference endpoint
instruction_tuned_predictor = instruction_tuned_estimator.deploy()

Domain adaptation fine-tuning

Domain adaptation fine-tuning is a process that refines a pre-trained LLM to better suit a specific domain or task. By using a smaller, domain-specific dataset, the LLM can be fine-tuned to understand and generate content that is more accurate, relevant, and insightful for that specific domain, while still retaining the vast knowledge it gained during its original training.

The Mistral model can be fine-tuned on any domain-specific dataset. After it’s fine-tuned, it’s expected to generate domain-specific text and solve various NLP tasks in that specific domain. For the training dataset, provide a train directory and an optional validation directory, each containing a single CSV, JSON, or TXT file. For CSV and JSON formats, use data from the text column or the first column if text isn’t present. Ensure only one file exists under each directory. For instance, input data may be SEC filings of Amazon as a text file:

This report includes estimates, projections, statements relating to our
business plans, objectives, and expected operating results that are “forward-
looking statements” within the meaning of the Private Securities Litigation
Reform Act of 1995, Section 27A of the Securities Act of 1933, and Section 21E
of the Securities Exchange Act of 1934. Forward-looking statements may appear
throughout this report, including the following sections: “Business” (Part I,
Item 1 of this Form 10-K), “Risk Factors” (Part I, Item 1A of this Form 10-K),
and “Management’s Discussion and Analysis of Financial Condition and Results
of Operations” (Part II, Item 7 of this Form 10-K). These forward-looking
statements generally are identified by the words “believe,” “project,”
“expect,” “anticipate,” “estimate,” “intend,” “strategy,” “future,”
“opportunity,” “plan,” “may,” “should,” “will,” “would,” “will be,” “will
continue,” “will likely result,” and similar expressions.
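For CSV input, the column selection rule described above can be sketched as follows. This is only an illustration of the rule as stated in this post, not the JumpStart training script; the extract_texts function is hypothetical, and the sketch assumes the CSV file has a header row.

```python
import csv
import io


def extract_texts(csv_text: str) -> list[str]:
    """Return values from the 'text' column, or from the first column
    when no column named 'text' exists. Assumes a header row."""
    reader = csv.reader(io.StringIO(csv_text))
    header = next(reader)
    index = header.index("text") if "text" in header else 0
    return [row[index] for row in reader if row]


with_text_column = "id,text\n1,first passage\n2,second passage\n"
without_text_column = "body,year\nsome filing text,2021\n"

print(extract_texts(with_text_column))
print(extract_texts(without_text_column))
```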

You can start domain adaptation fine-tuning by setting the hyperparameter instruction_tuned to "False". The rest of the steps are similar to the instruction fine-tuning steps.

We set max_input_length to be 2048 on g5.12xlarge. You can use a higher input length on a larger instance type. For details, see the example notebook.
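As a minimal sketch of the settings described above, the following dictionary collects the overrides for a domain adaptation run. It assumes, as in the instruction fine-tuning example, that hyperparameter values are passed as strings and that max_input_length is the hyperparameter key. The default values shown are placeholders for illustration, not the actual JumpStart defaults.

```python
# Placeholder defaults for illustration only; in practice these come from JumpStart.
default_hyperparameters = {"epoch": "1", "instruction_tuned": "True"}

# Overrides for domain adaptation fine-tuning on an ml.g5.12xlarge instance.
domain_adaptation_overrides = {
    "instruction_tuned": "False",
    "max_input_length": "2048",
}

# Later keys win, so the overrides replace the matching defaults.
my_hyperparameters = {**default_hyperparameters, **domain_adaptation_overrides}
print(my_hyperparameters)
```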

Inference example with and without fine-tuning

The following examples show the results of the Mistral 7B model fine-tuned with SEC filing documents of Amazon from 2021–2022. For each input prompt, we compare the output before and after fine-tuning.

Parameters: "temperature": 0.2, "max_new_tokens": 150

Input prompt: This Form 10-K report shows that

Output before fine-tuning:

the company has been in business for 10 years and has a total of 10 employees. The company has a total of 10 employees and is in the business of providing services to the public. The company has a total of 10 employees and is in the business of providing services to the public.

The company has a total of 10 employees and is in the business of providing services to the public. The company has a total of 10 employees and is in the business of providing services to the public. The company has a total of 10 employees and is in the business of providing services to the public. The company has a total of 10 employees and is in the business of providing

Output after fine-tuning:

we had net sales of $1.2 billion, $1.5 billion, and $1.8 billion in 2004, 2003, and 2002. These amounts include sales from our international websites. Net sales from our international websites were $323 million, $242 million, and $150 million in 2004, 2003, and 2002. Net sales from our international websites represented 27%, 16%, and 8% of net sales in 2004, 2003, and 2002.

Input prompt: We serve consumers through

Output before fine-tuning:

our retail network of 1,700 stores, which includes 1,200 stores in the United States, 400 stores in Canada, and 100 stores in Mexico. We also serve consumers through our e-commerce websites.

## What is the biggest home improvement store?

The Home Depot is the largest home improvement retailer in the United States.

## What is the largest home improvement store in the world?

The largest home improvement store in the world is Home Depot.

## What is the biggest home improvement store in the

Output after fine-tuning:

our online and physical stores and focus on selection, price, and convenience. We design our stores to enable hundreds of millions of unique products to be sold by us and by third parties across dozens of product categories. Customers access our offerings through our websites, mobile apps, Alexa, devices, streaming, and physically visiting our stores. We also manufacture and sell electronic devices, including Kindle, Fire tablet, Fire TV, Echo, and Ring, and we develop and produce media content. We seek to offer our customers low prices, fast and free delivery, easy-to-use functionality, and timely customer service. In addition, we offer subscription services such as Amazon Prime, a

Input prompt: Our vision is

Output before fine-tuning:

to be the leading provider of quality, affordable, and accessible health care services in the communities we serve.

Our mission is to provide quality, affordable, and accessible health care services to the communities we serve.

Our values are:

– Integrity: We act with honesty and integrity in all our dealings.
– Respect: We treat everyone with respect and dignity.
– Excellence: We strive for excellence in all we do.
– Compassion: We care for our patients and each other with compassion.
– Teamwork: We work together as a team to achieve our goals.

Our commitment to our patients is to provide quality, affordable, and accessible health care services.

Output after fine-tuning:

to be Earth’s most customer-centric company. We are guided by four principles: customer obsession rather than competitor focus, passion for invention, commitment to operational excellence, and long-term thinking. In each of our segments, we serve our primary customer sets, consisting of consumers, sellers, developers, enterprises, and content creators. In addition, we provide services, such as advertising. We have organized our operations into three segments: North America, International, and AWS. These segments reflect the way the Company evaluates its business performance and manages its operations. Information on our net sales is contained in Item 8 of Part II, “Financial Statements

As you can see, the fine-tuned model provides more specific information related to Amazon compared to the generic pre-trained one. This is because fine-tuning adapts the model to understand the nuances, patterns, and specifics of the provided dataset. By using a pre-trained model and tailoring it with fine-tuning, we ensure that you get the best of both worlds: the broad knowledge of the pre-trained model and the specialized accuracy for your unique dataset. One size may not fit all in the world of machine learning, and fine-tuning is the tailor-made solution you need!

Conclusion

In this post, we discussed fine-tuning the Mistral 7B model using SageMaker JumpStart. We showed how you can use the SageMaker JumpStart console in SageMaker Studio or the SageMaker Python SDK to fine-tune and deploy these models. As a next step, you can try fine-tuning these models on your own dataset using the code provided in the GitHub repository to test and benchmark the results for your use cases.


About the Authors

Xin Huang is a Senior Applied Scientist for Amazon SageMaker JumpStart and Amazon SageMaker built-in algorithms. He focuses on developing scalable machine learning algorithms. His research interests are in the area of natural language processing, explainable deep learning on tabular data, and robust analysis of non-parametric space-time clustering. He has published many papers in ACL, ICDM, KDD conferences, and Royal Statistical Society: Series A.

Vivek Gangasani is an AI/ML Startup Solutions Architect for Generative AI startups at AWS. He helps emerging GenAI startups build innovative solutions using AWS services and accelerated compute. Currently, he is focused on developing strategies for fine-tuning and optimizing the inference performance of large language models. In his free time, Vivek enjoys hiking, watching movies, and trying different cuisines.

Dr. Ashish Khetan is a Senior Applied Scientist with Amazon SageMaker built-in algorithms and helps develop machine learning algorithms. He got his PhD from University of Illinois Urbana-Champaign. He is an active researcher in machine learning and statistical inference, and has published many papers in NeurIPS, ICML, ICLR, JMLR, ACL, and EMNLP conferences.

Harness large language models in fake news detection

Fake news, defined as news that conveys or incorporates false, fabricated, or deliberately misleading information, has been around since the emergence of the printing press. The rapid spread of fake news and disinformation online not only deceives the public, but can also have a profound impact on society, politics, economy, and culture. Examples include:

  • Cultivating distrust in the media
  • Undermining the democratic process
  • Spreading false or discredited science (for example, the anti-vax movement)

Advances in artificial intelligence (AI) and machine learning (ML) have made developing tools for creating and sharing fake news even easier. Early examples include advanced social bots and automated accounts that supercharge the initial stage of spreading fake news. In general, it is not trivial for the public to determine whether such accounts are people or bots. In addition, social bots are not illegal tools, and many companies legally purchase them as part of their marketing strategy. Therefore, it’s not easy to curb the use of social bots systematically.

Recent advances in the field of generative AI make it possible to produce textual content at an unprecedented pace with the help of large language models (LLMs). LLMs are generative AI text models with over 1 billion parameters, and they facilitate the synthesis of high-quality text.

In this post, we explore how you can use LLMs to tackle the prevalent issue of detecting fake news. We suggest that LLMs are sufficiently advanced for this task, especially if improved prompt techniques such as Chain-of-Thought and ReAct are used in conjunction with tools for information retrieval.

We illustrate this by creating a LangChain application that, given a piece of news, informs the user whether the article is true or fake using natural language. The solution also uses Amazon Bedrock, a fully managed service that makes foundation models (FMs) from Amazon and third-party model providers accessible through the AWS Management Console and APIs.

LLMs and fake news

The fake news phenomenon started evolving rapidly with the advent of the internet and more specifically social media (Nielsen et al., 2017). On social media, fake news can be shared quickly in a user’s network, leading the public to form the wrong collective opinion. In addition, people often propagate fake news impulsively, ignoring the factuality of the content if the news resonates with their personal norms (Tsipursky et al., 2018). Research in social science has suggested that cognitive bias (confirmation bias, bandwagon effect, and choice-supportive bias) is one of the most pivotal factors in making irrational decisions in terms of both the creation and consumption of fake news (Kim et al., 2021). This also implies that news consumers share and consume information only in the direction of strengthening their beliefs.

The power of generative AI to produce textual and rich content at an unprecedented pace aggravates the fake news problem. An example worth mentioning is deepfake technology—combining various images on an original video and generating a different video. Besides the disinformation intent that human actors bring to the mix, LLMs add a whole new set of challenges:

  • Factual errors – LLM-generated text has an increased risk of containing factual errors due to the nature of LLM training and the models’ ability to be creative while generating the next words in a sentence. LLM training is based on repeatedly presenting a model with incomplete input, then using ML training techniques until it correctly fills in the gaps, thereby learning language structure and a language-based world model. Consequently, although LLMs are great pattern matchers and re-combiners (“stochastic parrots”), they fail at a number of simple tasks that require logical reasoning or mathematical deduction, and can hallucinate answers. In addition, temperature is one of the LLM input parameters that controls the behavior of the model when generating the next word in a sentence. A higher temperature makes the model more likely to choose lower-probability words, producing a more random response.
  • Lengthy – Generated texts tend to be lengthy and lack a clearly defined granularity for individual facts.
  • Lack of fact-checking – There is no standardized tooling available for fact-checking during the process of text generation.
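The effect of temperature mentioned in the first item can be illustrated with a small, self-contained sketch. It applies the common softmax-with-temperature formulation to made-up scores for three candidate words; the function name and the scores are illustrative assumptions, and a real model service may treat a temperature of 0 as a special case (always picking the top word), which this sketch does not handle.

```python
import math


def softmax_with_temperature(logits, temperature):
    """Turn raw scores into probabilities; lower temperature sharpens them."""
    if temperature <= 0:
        raise ValueError("temperature must be positive in this sketch")
    scaled = [x / temperature for x in logits]
    peak = max(scaled)  # subtract the maximum for numerical stability
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


logits = [2.0, 1.0, 0.0]  # made-up scores for three candidate next words
low = softmax_with_temperature(logits, 0.2)
high = softmax_with_temperature(logits, 1.0)

print([round(p, 3) for p in low])
print([round(p, 3) for p in high])
```

At the lower temperature, almost all of the probability mass sits on the highest-scoring word; at the higher temperature, the less likely words receive a noticeably larger share, which is why responses become more varied.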

Overall, the combination of human psychology and limitations of AI systems has created a perfect storm for the proliferation of fake news and misinformation online.

Solution overview

LLMs are demonstrating outstanding capabilities in language generation, understanding, and few-shot learning. They are trained on a vast corpus of text from the internet, where quality and accuracy of extracted natural language may not be assured.

In this post, we provide a solution to detect fake news based on both the Chain-of-Thought and ReAct (Reasoning and Acting) prompt approaches. First, we discuss those two prompt engineering techniques, then we show their implementation using LangChain and Amazon Bedrock.

The following architecture diagram outlines the solution for our fake news detector.

Architecture diagram for fake news detection.

We use a subset of the FEVER dataset containing a statement and the ground truth about the statement indicating false, true, or unverifiable claims (Thorne J. et al., 2018).

The workflow can be broken down into the following steps:

  1. The user selects one of the statements to check if fake or true.
  2. The statement and the fake news detection task are incorporated into the prompt.
  3. The prompt is passed to LangChain, which invokes the FM in Amazon Bedrock.
  4. Amazon Bedrock returns a response to the user request with the statement True or False.

In this post, we use the Claude v2 model from Anthropic (anthropic.claude-v2). Claude is a generative LLM based on Anthropic’s research into creating reliable, interpretable, and steerable AI systems. Created using techniques like constitutional AI and harmlessness training, Claude excels at thoughtful dialogue, content creation, complex reasoning, creativity, and coding. However, by using Amazon Bedrock and our solution architecture, we also have the flexibility to choose among other FMs provided by Amazon, AI21 Labs, Cohere, and Stability AI.

You can find the implementation details in the following sections. The source code is available in the GitHub repository.

Prerequisites

For this tutorial, you need an AWS account and a bash terminal with Python 3.9 or higher installed on Linux, macOS, or Windows Subsystem for Linux.

We also recommend using either an Amazon SageMaker Studio notebook, an AWS Cloud9 instance, or an Amazon Elastic Compute Cloud (Amazon EC2) instance.

Deploy fake news detection using the Amazon Bedrock API

The solution uses the Amazon Bedrock API, which can be accessed using the AWS Command Line Interface (AWS CLI), the AWS SDK for Python (Boto3), or an Amazon SageMaker notebook. Refer to the Amazon Bedrock User Guide for more information. For this post, we use the Amazon Bedrock API via the AWS SDK for Python.

Set up Amazon Bedrock API environment

To set up your Amazon Bedrock API environment, complete the following steps:

  1. Download the latest Boto3 or upgrade it:
    pip install --upgrade boto3

  2. Make sure you configure the AWS credentials using the aws configure command or pass them to the Boto3 client.
  3. Install the latest version of LangChain:
    pip install "langchain>=0.0.317" --quiet

You can now test your setup using the following Python script. The script instantiates the Amazon Bedrock client using Boto3. Next, we call the list_foundation_models API to get the list of foundation models available for use.

import boto3
import json

YOUR_REGION = "us-east-1"  # replace with the AWS Region you use for Amazon Bedrock

bedrock = boto3.client('bedrock', region_name=YOUR_REGION)
print(json.dumps(bedrock.list_foundation_models(), indent=4))

After successfully running the preceding command, you should get the list of FMs from Amazon Bedrock.
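The full response can be long, so it may help to reduce it to the model IDs of a single provider. The following sketch does this on a hand-written dictionary that imitates the shape of the response (a modelSummaries list whose entries carry modelId and providerName); the sample entries are illustrative and not real API output, and the helper name is hypothetical.

```python
def model_ids_by_provider(response: dict, provider: str) -> list[str]:
    """Collect model IDs for one provider from a list_foundation_models-style response."""
    return [
        summary["modelId"]
        for summary in response.get("modelSummaries", [])
        if summary.get("providerName", "").lower() == provider.lower()
    ]


# Hand-written stand-in for the API response, for illustration only.
sample_response = {
    "modelSummaries": [
        {"modelId": "anthropic.claude-v2", "providerName": "Anthropic"},
        {"modelId": "amazon.titan-text-express-v1", "providerName": "Amazon"},
    ]
}

print(model_ids_by_provider(sample_response, "anthropic"))
```

With a configured client, the same helper could be applied to the dictionary returned by bedrock.list_foundation_models().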

LangChain as a prompt chaining solution

To detect fake news for a given sentence, we follow the zero-shot Chain-of-Thought reasoning process (Wei J. et al., 2022), which is composed of the following steps:

  1. Initially, the model attempts to create a statement about the news prompted.
  2. The model creates a bullet point list of assertions.
  3. For each assertion, the model determines if the assertion is true or false. Note that using this methodology, the model relies exclusively on its internal knowledge (weights computed in the pre-training phase) to reach a verdict. The information is not verified against any external data at this point.
  4. Given the facts, the model answers TRUE or FALSE for the given statement in the prompt.

To achieve these steps, we use LangChain, a framework for developing applications powered by language models. This framework allows us to augment the FMs by chaining together various components to create advanced use cases. In this solution, we use the built-in SimpleSequentialChain in LangChain to create a simple sequential chain. This is very useful, because we can take the output from one chain and use it as the input to another.
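The idea behind a simple sequential chain can be sketched without LangChain: each step is a function from text to text, and the output of one step becomes the input of the next. The following is a plain-Python analogy, not LangChain's implementation; run_sequential and the toy string steps are hypothetical stand-ins for chains that would each call the model.

```python
from functools import reduce
from typing import Callable

Step = Callable[[str], str]


def run_sequential(steps: list[Step], initial_input: str) -> str:
    """Feed the input through each step in order, passing outputs along."""
    return reduce(lambda text, step: step(text), steps, initial_input)


# Toy steps standing in for LLM calls.
steps = [str.strip, str.upper, lambda text: text + "!"]
print(run_sequential(steps, "  hello "))
```

Because every step takes a single string and returns a single string, the steps can be reordered or extended without changing the runner, which is the same constraint SimpleSequentialChain places on the chains it links.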

Amazon Bedrock is integrated with LangChain, so you only need to instantiate it by passing the model_id when instantiating the Amazon Bedrock object. If needed, the model inference parameters can be provided through the model_kwargs argument, such as:

  • maxTokenCount – The maximum number of tokens in the generated response
  • stopSequences – The stop sequence used by the model
  • temperature – A value that ranges between 0–1, with 0 being the most deterministic and 1 being the most creative
  • topP – A value that ranges between 0–1, and is used to control tokens’ choices based on the probability of the potential choices

If this is the first time you are using an Amazon Bedrock foundation model, make sure you request access to the model by selecting it from the list of models on the Model access page on the Amazon Bedrock console, which in our case is claude-v2 from Anthropic.

import boto3
from langchain.llms.bedrock import Bedrock

# YOUR_REGION is the same AWS Region placeholder used when testing the setup
bedrock_runtime = boto3.client(
    service_name='bedrock-runtime',
    region_name=YOUR_REGION,
)
model_kwargs = {
    'max_tokens_to_sample': 8192
}
llm = Bedrock(model_id="anthropic.claude-v2", client=bedrock_runtime, model_kwargs=model_kwargs)

The following function defines the Chain-of-Thought prompt chain we mentioned earlier for detecting fake news. The function takes the Amazon Bedrock object (llm) and the user prompt (q) as arguments. LangChain’s PromptTemplate functionality is used here to predefine a recipe for generating prompts.

from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain
from langchain.chains import SimpleSequentialChain

def generate_and_print(llm, q):
    total_prompt = """"""

    # Step 1: the model is asked to list the assumptions behind the statement
    template = """Here is a statement:
    {statement}
    Make a bullet point list of the assumptions you made when given the above statement.\n\n"""
    prompt_template = PromptTemplate(input_variables=["statement"], template=template)
    assumptions_chain = LLMChain(llm=llm, prompt=prompt_template)
    total_prompt = total_prompt + template

    # Step 2: for each assertion, the model is asked to determine whether it is true or false,
    # based on its internal knowledge alone
    template = """Here is a bullet point list of assertions:
    {assertions}
    For each assertion, determine whether it is true or false. If it is false, explain why.\n\n"""
    prompt_template = PromptTemplate(input_variables=["assertions"], template=template)
    fact_checker_chain = LLMChain(llm=llm, prompt=prompt_template)
    total_prompt = total_prompt + template

    # Step 3: the model is asked for a final TRUE or FALSE verdict on the original statement
    template = """ Based on the above assertions, the final response is FALSE if one of the assertions is FALSE. Otherwise, the final response is TRUE. You should only respond with TRUE or FALSE.'{}'""".format(q)
    template = """{facts}\n""" + template
    prompt_template = PromptTemplate(input_variables=["facts"], template=template)
    answer_chain = LLMChain(llm=llm, prompt=prompt_template)
    total_prompt = total_prompt + template

    # SimpleSequentialChain allows us to take the output from one chain and use it as the input to another
    overall_chain = SimpleSequentialChain(chains=[assumptions_chain, fact_checker_chain, answer_chain], verbose=True)
    answer = overall_chain.run(q)

    return answer

The following code calls the function we defined earlier and provides the answer. The statement is TRUE or FALSE. TRUE means that the statement provided contains correct facts, and FALSE means that the statement contains at least one incorrect fact.

from IPython.display import display, Markdown

q="The first woman to receive a Ph.D. in computer science was Dr. Barbara Liskov, who earned her degree from Stanford University in 1968."
print(f'The statement is: {q}')
display(Markdown(generate_and_print(llm, q)))

An example of a statement and model response is provided in the following output:

The statement is: The first woman to receive a Ph.D. in computer science was Dr. Barbara Liskov, who earned her degree from Stanford University in 1968.

> Entering new SimpleSequentialChain chain...
 Here is a bullet point list of assumptions I made about the statement:

- Dr. Barbara Liskov was the first woman to earn a Ph.D. in computer science. 

- Dr. Liskov earned her Ph.D. from Stanford University.

- She earned her Ph.D. in 1968.

- No other woman earned a Ph.D. in computer science prior to 1968.

- Stanford University had a computer science Ph.D. program in 1968. 

- The statement refers to Ph.D. degrees earned in the United States.
 Here are my assessments of each assertion:

- Dr. Barbara Liskov was the first woman to earn a Ph.D. in computer science.
  - True. Dr. Liskov was the first American woman to earn a Ph.D. in computer science, which she received from Stanford University in 1968.

- Dr. Liskov earned her Ph.D. from Stanford University.
  - True. Multiple sources confirm she received her Ph.D. from Stanford in 1968.

- She earned her Ph.D. in 1968.
  - True. This is consistent across sources.

- No other woman earned a Ph.D. in computer science prior to 1968.
  - False. While she was the first American woman, Mary Kenneth Keller earned a Ph.D. in computer science from the University of Wisconsin in 1965. However, Keller earned her degree in the US as well.

- Stanford University had a computer science Ph.D. program in 1968.
  - True. Stanford established its computer science department and Ph.D. program in 1965.

- The statement refers to Ph.D. degrees earned in the United States.
  - False. The original statement does not specify the country. My assumptions that it refers to the United States is incorrect. Keller earned her Ph.D. in the US before Liskov.
 False
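The chain returns its verdict as free text (in this run, the word False with leading whitespace), so downstream code may want a Boolean instead. The following helper is a hypothetical sketch of that normalization and is not part of the original solution; it assumes the model followed the instruction to answer with only TRUE or FALSE and raises an error otherwise.

```python
def parse_verdict(answer: str) -> bool:
    """Map the chain's final text to a Boolean; reject anything else."""
    token = answer.strip().strip(".'\"").upper()
    if token == "TRUE":
        return True
    if token == "FALSE":
        return False
    raise ValueError(f"unexpected verdict: {answer!r}")


print(parse_verdict(" False"))
print(parse_verdict("TRUE."))
```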

ReAct and tools

In the preceding example, the model correctly identified that the statement is false. However, submitting the query again demonstrates the model’s inability to reliably distinguish the correctness of facts. The model doesn’t have the tools to verify the truthfulness of statements beyond its own training memory, so subsequent runs of the same prompt can lead it to mislabel fake statements as true. The following output shows a different run of the same example:

The statement is: The first woman to receive a Ph.D. in computer science was Dr. Barbara Liskov, who earned her degree from Stanford University in 1968.

> Entering new SimpleSequentialChain chain...
 Here is a bullet point list of assumptions I made about the statement:

- Dr. Barbara Liskov was the first woman to earn a Ph.D. in computer science
- Dr. Liskov earned her Ph.D. degree in 1968 
- Dr. Liskov earned her Ph.D. from Stanford University
- Stanford University awarded Ph.D. degrees in computer science in 1968
- Dr. Liskov was a woman
- Ph.D. degrees existed in 1968
- Computer science existed as a field of study in 1968
 Here are my assessments of each assertion:

- Dr. Barbara Liskov was the first woman to earn a Ph.D. in computer science
    - True. Dr. Liskov was the first woman to earn a Ph.D. in computer science in 1968 from Stanford University.

- Dr. Liskov earned her Ph.D. degree in 1968
    - True. Multiple sources confirm she received her Ph.D. in computer science from Stanford in 1968.

- Dr. Liskov earned her Ph.D. from Stanford University 
    - True. Dr. Liskov earned her Ph.D. in computer science from Stanford University in 1968.

- Stanford University awarded Ph.D. degrees in computer science in 1968
    - True. Stanford awarded Liskov a Ph.D. in computer science in 1968, so they offered the degree at that time.

- Dr. Liskov was a woman
    - True. All biographical information indicates Dr. Liskov is female.

- Ph.D. degrees existed in 1968
    - True. Ph.D. degrees have existed since the late 19th century.

- Computer science existed as a field of study in 1968
    - True. While computer science was a relatively new field in the 1960s, Stanford and other universities offered it as a field of study and research by 1968.
 True

One technique for improving truthfulness is ReAct. ReAct (Yao S. et al., 2023) is a prompting technique that augments the foundation model with an agent’s action space. In this post, as in the ReAct paper, the action space implements information retrieval using search, lookup, and finish actions from a simple Wikipedia web API.

We use ReAct rather than Chain-of-Thought alone because it augments the foundation model with external knowledge retrieval, which helps it determine whether a given piece of news is fake or true.
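To make the Thought/Action/Observation cycle concrete, the following is a minimal, framework-free sketch of a ReAct-style loop. Everything in it is hypothetical and for illustration only: scripted_model stands in for an LLM call, lookup is a toy dictionary tool rather than the Wikipedia API, and the single knowledge entry is taken from the Chain-of-Thought output shown earlier.

```python
from typing import Callable, Dict, List, Tuple

# Toy knowledge base standing in for an external source such as Wikipedia.
KNOWLEDGE: Dict[str, str] = {
    "mary kenneth keller": (
        "Earned a Ph.D. in computer science from the "
        "University of Wisconsin in 1965."
    ),
}


def lookup(query: str) -> str:
    """Hypothetical tool: return a stored fact or a 'not found' message."""
    return KNOWLEDGE.get(query.lower(), "No entry found.")


def scripted_model(transcript: List[str]) -> Tuple[str, str, str]:
    """Stand-in for an LLM call. Returns (thought, action, action_input)."""
    has_observation = any(line.startswith("Observation:") for line in transcript)
    if not has_observation:
        return ("I should check who earned an earlier Ph.D.",
                "lookup", "Mary Kenneth Keller")
    return ("The observation shows a Ph.D. awarded in 1965, before 1968.",
            "finish", "FALSE")


def react(statement: str,
          model: Callable[[List[str]], Tuple[str, str, str]],
          tools: Dict[str, Callable[[str], str]],
          max_steps: int = 5) -> str:
    """Alternate model reasoning and tool calls until the model finishes."""
    transcript = [f"Statement: {statement}"]
    for _ in range(max_steps):
        thought, action, action_input = model(transcript)
        transcript.append(f"Thought: {thought}")
        if action == "finish":
            return action_input
        # Run the chosen tool and feed its result back to the model.
        observation = tools[action](action_input)
        transcript.append(f"Action: {action}[{action_input}]")
        transcript.append(f"Observation: {observation}")
    return "UNKNOWN"  # step budget exhausted without a final answer


answer = react(
    "The first woman to receive a Ph.D. in computer science was "
    "Dr. Barbara Liskov, who earned her degree in 1968.",
    scripted_model,
    {"lookup": lookup},
)
print(answer)
```

The point of the loop is that the final answer is conditioned on an observation retrieved at run time rather than only on what the model remembers from training.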

In this post, we use LangChain’s implementation of ReAct through the agent type ZERO_SHOT_REACT_DESCRIPTION. We modify the previous function to implement ReAct and give the agent access to Wikipedia through the load_tools function from langchain.agents.

We also need to install the Wikipedia package:

!pip install wikipedia

Below is the new code:

from langchain.agents import load_tools, initialize_agent, AgentType

def generate_and_print(llm, q):
    print(f'Inside generate_and_print: q = {q}')
    tools = load_tools(["wikipedia"], llm=llm)
    agent = initialize_agent(tools, llm, 
                             agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, 
                             verbose=True,
                             handle_parsing_errors=True,
                             agent_kwargs={})

    # Prompt template; named so that it does not shadow the built-in input()
    prompt_template = """Here is a statement:
    {statement}
    Is this statement correct? You can use tools to find information if needed.
    The final response is FALSE if the statement is FALSE. Otherwise, TRUE."""

    answer = agent.run(prompt_template.format(statement=q))
    
    return answer

The following is the output of the preceding function given the same statement used before:

> Entering new AgentExecutor chain...
 Here are my thoughts and actions to determine if the statement is true or false:

Thought: To verify if this statement about the first woman to receive a PhD in computer science is true, I should consult a reliable information source like Wikipedia.

Action: Wikipedia
Action Input: first woman to receive phd in computer science
Observation: Page: Fu Foundation School of Engineering and Applied Science
Summary: The Fu Foundation School of Engineering and Applied Science (popularly known as SEAS or Columbia Engineering; previously known as Columbia School of Mines) is the engineering and applied science school of Columbia University. It was founded as the School of Mines in 1863 and then the School of Mines, Engineering and Chemistry before becoming the School of Engineering and Applied Science. On October 1, 1997, the school was renamed in honor of Chinese businessman Z.Y. Fu, who had donated $26 million to the school.
The Fu Foundation School of Engineering and Applied Science maintains a close research tie with other institutions including NASA, IBM, MIT, and The Earth Institute. Patents owned by the school generate over $100 million annually for the university. SEAS faculty and alumni are responsible for technological achievements including the developments of FM radio and the maser.
The School's applied mathematics, biomedical engineering, computer science and the financial engineering program in operations research are very famous and highly ranked. The current SEAS faculty include 27 members of the National Academy of Engineering and one Nobel laureate. In all, the faculty and alumni of Columbia Engineering have won 10 Nobel Prizes in physics, chemistry, medicine, and economics.
The school consists of approximately 300 undergraduates in each graduating class and maintains close links with its undergraduate liberal arts sister school Columbia College which shares housing with SEAS students. The School's current dean is Shih-Fu Chang, who was appointed in 2022.

Page: Doctor of Science
Summary: A Doctor of Science (Latin: Scientiae Doctor; most commonly abbreviated DSc or ScD) is an academic research doctorate awarded in a number of countries throughout the world. In some countries, a Doctor of Science is the degree used for the standard doctorate in the sciences; elsewhere a Doctor of Science is a "higher doctorate" awarded in recognition of a substantial and sustained contribution to scientific knowledge beyond that required for a Doctor of Philosophy (PhD).

Page: Timeline of women in science
Summary: This is a timeline of women in science, spanning from ancient history up to the 21st century. While the timeline primarily focuses on women involved with natural sciences such as astronomy, biology, chemistry and physics, it also includes women from the social sciences (e.g. sociology, psychology) and the formal sciences (e.g. mathematics, computer science), as well as notable science educators and medical scientists. The chronological events listed in the timeline relate to both scientific achievements and gender equality within the sciences.
Thought: Based on the Wikipedia pages, the statement appears to be false. The Wikipedia Timeline of Women in Science page indicates that Adele Goldstine was the first woman to earn a PhD in computer science in 1964 from the University of Michigan, not Barbara Liskov from Stanford in 1968. Therefore, my final answer is:

Final Answer: FALSE

Clean up

To save costs, delete all the resources you deployed as part of the tutorial. If you launched AWS Cloud9 or an EC2 instance, you can delete it via the console or using the AWS CLI. Similarly, you can delete the SageMaker notebook you may have created via the SageMaker console.

Limitations and related work

The field of fake news detection is actively researched in the scientific community. In this post, we used the Chain-of-Thought and ReAct techniques, and in evaluating them we focused only on classification accuracy (whether a given statement is labeled true or false). We haven’t considered other important aspects, such as response speed, and we haven’t extended the solution to knowledge sources beyond Wikipedia.

Although this post focused on two techniques, Chain-of-Thought and ReAct, an extensive body of work has explored how LLMs can detect, eliminate, or mitigate fake news. Lee et al. have proposed an encoder-decoder model that uses NER (named entity recognition) to mask the named entities, in order to ensure that the masked token actually uses the knowledge encoded in the language model. Chern et al. developed FacTool, which uses Chain-of-Thought principles to extract claims from the prompt and then collect relevant evidence for the claims. The LLM then judges the factuality of each claim given the retrieved evidence. Du E. et al. present a complementary approach in which multiple LLMs propose and debate their individual responses and reasoning processes over multiple rounds in order to arrive at a common final answer.

Based on the literature, we see that the effectiveness of LLMs in detecting fake news increases when the LLMs are augmented with external knowledge and multi-agent conversation capability. However, these approaches are more computationally complex because they require multiple model calls and interactions, longer prompts, and lengthy network layer calls. Ultimately, this complexity translates into an increased overall cost. We recommend assessing the cost-to-performance ratio before deploying similar solutions in production.

Conclusion

In this post, we delved into how to use LLMs to tackle the prevalent issue of fake news, which is one of the major challenges of our society nowadays. We started by outlining the challenges presented by fake news, with an emphasis on its potential to sway public sentiment and cause societal disruptions.

We then introduced the concept of LLMs as advanced AI models that are trained on a substantial quantity of data. Due to this extensive training, these models boast an impressive understanding of language, enabling them to produce human-like text. With this capacity, we demonstrated how LLMs can be harnessed in the battle against fake news by using two different prompt techniques, Chain-of-Thought and ReAct.

We underlined how LLMs can facilitate fact-checking services on an unparalleled scale, given their capability to process and analyze vast amounts of text swiftly. This potential for real-time analysis can lead to early detection and containment of fake news. We illustrated this by creating a Python script that, given a statement, tells the user in natural language whether the statement is true or fake.

We concluded by underlining the limitations of the current approach and ended on a hopeful note, stressing that, with the correct safeguards and continuous enhancements, LLMs could become indispensable tools in the fight against fake news.

We’d love to hear from you. Let us know what you think in the comments section, or use the issues forum in the GitHub repository.

Disclaimer: The code provided in this post is meant for educational and experimentation purposes only. It should not be relied upon to detect fake news or misinformation in real-world production systems. No guarantees are made about the accuracy or completeness of fake news detection using this code. Users should exercise caution and perform due diligence before utilizing these techniques in sensitive applications.

To get started with Amazon Bedrock, visit the Amazon Bedrock console.


About the authors

Anamaria Todor is a Principal Solutions Architect based in Copenhagen, Denmark. She saw her first computer when she was 4 years old and never let go of computer science, video games, and engineering since. She has worked in various technical roles, from freelancer, full-stack developer, to data engineer, technical lead, and CTO, at various companies in Denmark, focusing on the gaming and advertising industries. She has been at AWS for over 3 years, working as a Principal Solutions Architect, focusing mainly on life sciences and AI/ML. Anamaria has a bachelor’s in Applied Engineering and Computer Science, a master’s degree in Computer Science, and over 10 years of AWS experience. When she’s not working or playing video games, she’s coaching girls and female professionals in understanding and finding their path through technology.

Marcel Castro is a Senior Solutions Architect based in Oslo, Norway. In his role, Marcel helps customers with architecture, design, and development of cloud-optimized infrastructure. He is a member of the AWS Generative AI Ambassador team with the goal to drive and support EMEA customers on their generative AI journey. He holds a PhD in Computer Science from Sweden and a master’s and bachelor’s degree in Electrical Engineering and Telecommunications from Brazil.

Model management for LoRA fine-tuned models using Llama2 and Amazon SageMaker

In the era of big data and AI, companies are continually seeking ways to use these technologies to gain a competitive edge. One of the hottest areas in AI right now is generative AI, and for good reason. Generative AI offers powerful solutions that push the boundaries of what’s possible in terms of creativity and innovation. At the core of these cutting-edge solutions lies a foundation model (FM), a highly advanced machine learning model that is pre-trained on vast amounts of data. Many of these foundation models have shown remarkable capability in understanding and generating human-like text, making them a valuable tool for a variety of applications, from content creation to customer support automation.

However, these models are not without their challenges. They are exceptionally large and require large amounts of data and computational resources to train. Additionally, optimizing the training process and calibrating the parameters can be a complex and iterative process, requiring expertise and careful experimentation. These can be barriers for many organizations looking to build their own foundation models. To overcome this challenge, many customers are considering fine-tuning existing foundation models. Fine-tuning is a popular technique that adjusts a small portion of model parameters for specific applications while still preserving the knowledge already encoded in the model. It allows organizations to use the power of these models while reducing the resources required to customize them to a specific domain or task.

There are two primary approaches to fine-tuning foundation models: traditional fine-tuning and parameter-efficient fine-tuning. Traditional fine-tuning involves updating all the parameters of the pre-trained model for a specific downstream task. On the other hand, parameter-efficient fine-tuning includes a variety of techniques that allow for customization of a model without updating all the original model parameters. One such technique is called Low-rank Adaptation (LoRA). It involves adding small, task-specific modules to the pre-trained model and training them while keeping the rest of the parameters fixed as shown in the following image.

Source: Generative AI on AWS (O’Reilly, 2023)
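As a rough illustration of why LoRA trains so few parameters, the following sketch compares the trainable parameter count of a full update to a single weight matrix with that of a rank-r LoRA update, where the update is factored into two small matrices of shapes d x r and r x k. The dimensions (4096 x 4096) and the rank (8) are illustrative assumptions, not values taken from this post's notebooks.

```python
def full_update_params(d: int, k: int) -> int:
    """Trainable parameters when the whole d x k weight matrix is updated."""
    return d * k


def lora_update_params(d: int, k: int, r: int) -> int:
    """Trainable parameters for a rank-r LoRA update: B is d x r, A is r x k."""
    return r * (d + k)


# Illustrative values only (assumed, not from the post).
d, k, r = 4096, 4096, 8

full = full_update_params(d, k)
lora = lora_update_params(d, k, r)
fraction = lora / full

print(f"full update : {full:,} parameters")
print(f"LoRA update : {lora:,} parameters")
print(f"LoRA trains {fraction:.4%} of the parameters of this matrix")
```

Because the adapter matrices are this small relative to the frozen base weights, they can be stored and shipped separately, which is the property the rest of this post builds on.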

LoRA has gained popularity recently for several reasons. It offers faster training, reduced memory requirements, and the ability to reuse pre-trained models for multiple downstream tasks. More importantly, the base model and adapter can be stored separately and combined at any time, making it easier to store, distribute, and share fine-tuned versions. However, this introduces a new challenge: how to properly manage these new types of fine-tuned models. Should you combine the base model and adapter or keep them separate? In this post, we walk through best practices for managing LoRA fine-tuned models on Amazon SageMaker to address this emerging question.

Working with FMs on SageMaker Model Registry

In this post, we walk through an end-to-end example of fine-tuning the Llama2 large language model (LLM) using the QLoRA method. QLoRA combines the benefits of parameter-efficient fine-tuning with 4-bit/8-bit quantization to further reduce the resources required to fine-tune an FM to a specific task or use case. For this, we will use the pre-trained 7 billion parameter Llama2 model and fine-tune it on the databricks-dolly-15k dataset. LLMs like Llama2 have billions of parameters and are pre-trained on massive text datasets. Fine-tuning adapts an LLM to a downstream task using a smaller dataset. However, fine-tuning large models is computationally expensive. This is why we will use the QLoRA method to quantize the weights during fine-tuning to reduce this computation cost.

In our examples, you will find two notebooks (llm-finetune-combined-with-registry.ipynb and llm-finetune-separate-with-registry.ipynb). Each works through a different way to handle LoRA fine-tuned models as illustrated in the following diagram:

  1. First, we download the pre-trained Llama2 model with 7 billion parameters using SageMaker Studio Notebooks. LLMs, like Llama2, have shown state-of-the-art performance on natural language processing (NLP) tasks when fine-tuned on domain-specific data.
  2. Next, we fine-tune Llama2 on the databricks-dolly-15k dataset using the QLoRA method. QLoRA reduces the computational cost of fine-tuning by quantizing model weights.
  3. During fine-tuning, we integrate SageMaker Experiments Plus with the Transformers API to automatically log metrics such as gradient and loss.
  4. We then version the fine-tuned Llama2 model in SageMaker Model Registry using two approaches:
    1. Storing the full model
    2. Storing the adapter and base model separately.
  5. Finally, we host the fine-tuned Llama2 models using Deep Java Library (DJL) Serving on a SageMaker Real-time endpoint.

In the following sections, we will dive deeper into each of these steps, to demonstrate the flexibility of SageMaker for different LLM workflows and how these features can help improve the operations of your models.

Prerequisites

Complete the following prerequisites to start experimenting with the code.

  • Create a SageMaker Studio Domain: Amazon SageMaker Studio, specifically Studio Notebooks, is used to kick off the Llama2 fine-tuning task then register and view models within SageMaker Model Registry. SageMaker Experiments is also used to view and compare Llama2 fine-tuning job logs (training loss/test loss/etc.).
  • Create an Amazon Simple Storage Service (S3) bucket: Access to an S3 bucket to store training artifacts and model weights is required. For instructions, refer to Creating a bucket. The sample code used for this post will use the SageMaker default S3 bucket but you can customize it to use any relevant S3 bucket.
  • Set up Model Collections (IAM permissions): Update your SageMaker Execution Role with permissions to resource-groups as listed under Model Registry Collections Developer Guide to implement Model Registry grouping using Model Collections.
  • Accept the Terms & Conditions for Llama2: You will need to accept the end-user license agreement and acceptable use policy for using the Llama2 foundation model.

The examples are available in the GitHub repository. The notebook files were tested using Studio notebooks running on the PyTorch 2.0.0 Python 3.10 GPU Optimized kernel and the ml.g4dn.xlarge instance type.

Experiments plus callback integration

Amazon SageMaker Experiments lets you organize, track, compare and evaluate machine learning (ML) experiments and model versions from any integrated development environment (IDE), including local Jupyter Notebooks, using the SageMaker Python SDK or boto3. It provides the flexibility to log your model metrics, parameters, files, artifacts, plot charts from the different metrics, capture various metadata, search through them and support model reproducibility. Data scientists can quickly compare the performance and hyperparameters for model evaluation through visual charts and tables. They can also use SageMaker Experiments to download the created charts and share the model evaluation with their stakeholders.

Training LLMs can be a slow, expensive, and iterative process. Tracking LLM experimentation at scale is important to keep the model tuning experience consistent. The Hugging Face Transformers API lets users track metrics during training tasks through callbacks. Callbacks are “read only” pieces of code that can inspect the state of the training loop in the PyTorch Trainer for progress reporting and for logging to TensorBoard or, via custom logic (which is included as a part of this codebase), to SageMaker Experiments Plus.

You can import the SageMaker Experiments callback code included in this post’s code repository as shown in the following code block:

# imports a custom implementation of Experiments Callback
from smexperiments_callback import SageMakerExperimentsCallback
...
...
# Create Trainer instance with SageMaker experiments callback
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=validation_dataset,
    data_collator=default_data_collator,
    callbacks=[SageMakerExperimentsCallback] # Add our Experiments Plus Callback function
)

This callback will automatically log the following information into SageMaker Experiments as a part of the training run:

  • Training Parameters and Hyper-Parameters
  • Model Training and Validation loss at Step, Epoch and Final
  • Model Input and Output artifacts (training dataset, validation dataset, model output location, training debugger and more)
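To show the general shape of such a callback, here is a minimal sketch that collects the numeric metrics the Trainer reports through the on_log hook. It is not the SageMakerExperimentsCallback from the repository: the class name InMemoryMetricsCallback is hypothetical, and it appends to a Python list where a real implementation would call an experiment-tracking backend. The import fallback exists only so the sketch runs without transformers installed.

```python
from types import SimpleNamespace

try:
    from transformers import TrainerCallback
except ImportError:  # fallback so the sketch runs without transformers
    TrainerCallback = object


class InMemoryMetricsCallback(TrainerCallback):
    """Hypothetical callback that records numeric metrics per training step."""

    def __init__(self):
        self.history = []

    def on_log(self, args, state, control, logs=None, **kwargs):
        # The Trainer passes the latest metrics in `logs`; the callback only
        # reads them and never modifies the training loop.
        for name, value in (logs or {}).items():
            if isinstance(value, (int, float)):
                # A real callback would send this to a tracking backend.
                self.history.append((state.global_step, name, value))


# Simulate one logging event the way the Trainer would invoke the hook.
callback = InMemoryMetricsCallback()
callback.on_log(
    args=None,
    state=SimpleNamespace(global_step=10),
    control=None,
    logs={"loss": 1.25, "note": "non-numeric values are skipped"},
)
print(callback.history)
```

An instance of a class like this can be passed in the callbacks list of the Trainer in the same way as the callback shown earlier.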

The following graph shows examples of the charts you can display by using that information.

This allows you to compare multiple runs easily using the Analyze feature of SageMaker Experiments. You can select the experiment runs you want to compare, and they will automatically populate comparison graphs.

Register fine-tuned models to Model Registry Collections

Model Registry Collections is a feature of SageMaker Model Registry that allows you to group registered models that are related to each other and organize them in hierarchies to improve model discoverability at scale. We will use Model Registry Collections to keep track of the base model and fine-tuned variants.

Full Model Copy method

The first method combines the base model and LoRA adapter and saves the full fine-tuned model. The following code illustrates the model merging process and saves the combined model using model.save_pretrained().

if args.merge_weights:
        
    trainer.model.save_pretrained(temp_dir, safe_serialization=False)
    # clear memory
    del model
    del trainer
    torch.cuda.empty_cache()
    
    from peft import AutoPeftModelForCausalLM

    # load PEFT model in fp16
    model = AutoPeftModelForCausalLM.from_pretrained(
        temp_dir,
        low_cpu_mem_usage=True,
        torch_dtype=torch.float16,
    )  
    # Merge LoRA and base model and save
    model = model.merge_and_unload()        
    model.save_pretrained(
        args.sm_model_dir, safe_serialization=True, max_shard_size="2GB"
    )

Combining the LoRA adapter and base model into a single model artifact after fine-tuning has advantages and disadvantages. The combined model is self-contained and can be independently managed and deployed without needing the original base model. The model can be tracked as its own entity with a version name reflecting the base model and fine-tuning data. We can adopt a nomenclature of base_model_name + fine-tuned dataset_name to organize the model groups. Optionally, model collections could associate the original and fine-tuned models, but this may not be necessary because the combined model is independent. The following code snippet shows you how to register the fine-tuned model.

# Model Package Group Vars
ft_package_group_name = f"{model_id.replace('/', '--')}-{dataset_name}"
ft_package_group_desc = f"QLoRA for model Mikael110/llama-2-7b-{dataset_name}-fp16"
...
...
...
model_package_group_input_dict = {
    "ModelPackageGroupName" : ft_package_group_name,
    "ModelPackageGroupDescription" : ft_package_group_desc,
    "Tags": ft_tags
}
create_model_package_group_response = sm_client.create_model_package_group(
    **model_package_group_input_dict
)

You can use the training estimator to register the model into Model Registry.

inference_image_uri = sagemaker.image_uris.retrieve(
    "djl-deepspeed", region=region, version="0.23.0"
)
print(f"Image going to be used is ---- > {inference_image_uri}")

model_package = huggingface_estimator.register(
    content_types=["application/json"],
    response_types=["application/json"],
    inference_instances=[
        "ml.p2.16xlarge", 
...
...
...
    ],
    image_uri = inference_image_uri,
    customer_metadata_properties = {"training-image-uri": huggingface_estimator.training_image_uri()},  #Store the training image url
    model_package_group_name=ft_model_pkg_group_name,
    approval_status="Approved"
)

model_package_arn = model_package.model_package_arn
print("Model Package ARN : ", model_package_arn)

From Model Registry, you can retrieve the model package and deploy that model directly.

endpoint_name = f"{name_from_base(model_group_for_base)}-endpoint"

model_package.deploy(
    initial_instance_count=1,
    instance_type="ml.g5.12xlarge",
    endpoint_name=endpoint_name
)
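If you don't already hold the model_package object returned by register(), one way to find the package to deploy is to ask Model Registry for the most recent approved version in a group. The following sketch uses the SageMaker list_model_packages API for that. The helper name latest_approved_package_arn, the StubSageMakerClient class, and the example ARN are hypothetical and exist only so the sketch runs offline; in practice you would pass boto3.client("sagemaker").

```python
from typing import Optional


def latest_approved_package_arn(sm_client, group_name: str) -> Optional[str]:
    """Return the ARN of the newest approved model package in a group."""
    response = sm_client.list_model_packages(
        ModelPackageGroupName=group_name,
        ModelApprovalStatus="Approved",
        SortBy="CreationTime",
        SortOrder="Descending",
        MaxResults=1,
    )
    summaries = response["ModelPackageSummaryList"]
    return summaries[0]["ModelPackageArn"] if summaries else None


class StubSageMakerClient:
    """Hypothetical offline stand-in for boto3.client("sagemaker")."""

    def __init__(self, summaries):
        self._summaries = summaries  # assumed to be sorted newest first

    def list_model_packages(self, **kwargs):
        limit = kwargs.get("MaxResults", len(self._summaries))
        return {"ModelPackageSummaryList": self._summaries[:limit]}


# Made-up ARNs for illustration only.
stub = StubSageMakerClient([
    {"ModelPackageArn": "arn:aws:sagemaker:us-east-1:111122223333:model-package/example-group/2"},
    {"ModelPackageArn": "arn:aws:sagemaker:us-east-1:111122223333:model-package/example-group/1"},
])
print(latest_approved_package_arn(stub, "example-group"))
```

The returned ARN can then be wrapped in a sagemaker.ModelPackage object, which exposes the same deploy() call used above.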

However, there are drawbacks to this approach. Combining the models leads to storage inefficiency and redundancy because the base model is duplicated in each fine-tuned version. Storage needs grow linearly with both model size and the number of fine-tuned models. Taking the Llama2 7B model as an example, the base model is approximately 13 GB and the fine-tuned model is 13.6 GB, so about 96% of the model is duplicated after each fine-tuning job. Additionally, distributing and sharing very large model files becomes more difficult and presents operational challenges, because file transfer and management costs increase with model size and the number of fine-tuning jobs.

Separate adapter and base method

The second method focuses on separation of base weights and adapter weights by saving them as separate model components and loading them sequentially at runtime.

    ..
    ..
    ..
    else:   
        # save finetuned LoRA model and then the tokenizer for inference
        trainer.model.save_pretrained(
            args.sm_model_dir, 
            safe_serialization=True
        )
    tokenizer.save_pretrained(
        args.sm_model_dir
    )

Saving base and adapter weights separately has advantages and disadvantages, similar to the Full Model Copy method. One advantage is that it can save storage space. The base weights, which are the largest component of a fine-tuned model, are only saved once and can be reused with other adapter weights that are tuned for different tasks. For example, the base weights of Llama2-7B are about 13 GB, but each fine-tuning task only needs to store about 0.6 GB of adapter weights, which is a 95% space savings. Another advantage is that base weights can be managed separately from adapter weights using a base-weights-only model registry. This can be useful for SageMaker domains that are running in VPC-only mode without an internet gateway, because the base weights can be accessed without having to go through the internet.
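The storage trade-off between the two methods can be sketched with simple arithmetic. The sizes below are the approximate figures quoted in this post (13 GB for the base weights and 0.6 GB per adapter); the sketch assumes a merged model is roughly the sum of the two, and the function names are illustrative.

```python
BASE_GB = 13.0     # approximate size of the Llama2-7B base weights
ADAPTER_GB = 0.6   # approximate size of one set of LoRA adapter weights


def full_copy_storage_gb(num_finetuned: int) -> float:
    """Every fine-tuned variant stores its own merged copy of the model."""
    return num_finetuned * (BASE_GB + ADAPTER_GB)


def separate_storage_gb(num_finetuned: int) -> float:
    """The base weights are stored once, plus one adapter per variant."""
    return BASE_GB + num_finetuned * ADAPTER_GB


for n in (1, 5, 20):
    full = full_copy_storage_gb(n)
    separate = separate_storage_gb(n)
    print(f"{n:>2} variants: full copy {full:6.1f} GB, separate {separate:5.1f} GB")
```

Both totals grow linearly with the number of variants, but the per-variant cost is about 13.6 GB for full copies versus about 0.6 GB for separate adapters.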

Create Model Package Group for base weights

### Create Model Package Group
base_package_group_name = model_id.replace('/', '--')
base_package_group_desc = "Source: https://huggingface.co/Mikael110/llama-2-7b-guanaco-fp16"
...
...
...
model_package_group_input_dict = {
    "ModelPackageGroupName" : base_package_group_name,
    "ModelPackageGroupDescription" : base_package_group_desc,
    "Tags": base_tags
}
create_model_package_group_response = sm_client.create_model_package_group(
    **model_package_group_input_dict
)

>>>
Created ModelPackageGroup Arn : arn:aws:sagemaker:us-west-2:376678947624:model-package-group/Mikael110--llama-2-7b-guanaco-fp16
...
...
...

### Register Base Model Weights
from sagemaker.huggingface import HuggingFaceModel

# create Hugging Face Model Class
huggingface_model = HuggingFaceModel(
    transformers_version='4.28',
    pytorch_version='2.0',
    py_version='py310',
    model_data=model_data_uri, # this is an S3 path to your base weights as *.tar.gz
    role=role,
)

_response = huggingface_model.register(
    content_types=["application/json"],
    response_types=["application/json"],
    inference_instances=[
    "ml.p2.16xlarge",
    ...
    ],
    transform_instances=[
    "ml.p2.16xlarge",
    ...
    ],
    model_package_group_name=base_model_pkg_group_name,
    approval_status="Approved"
 )

Create Model Package Group for QLoRA weights

The following code shows how to tag QLoRA weights with the dataset/task type and register fine-tuned delta weights into a separate model registry and track the delta weights separately.

### Create Model Package Group for delta weights
ft_package_group_name = f"{model_id.replace('/', '--')}-finetuned-sql"
ft_package_group_desc = "QLoRA for model Mikael110/llama-2-7b-guanaco-fp16"
ft_tags = [
    {
    "Key": "modelType",
    "Value": "QLoRAModel"
    },
    {
    "Key": "fineTuned",
    "Value": "True"
    },
    {
    "Key": "sourceDataset",
    "Value": f"{dataset_name}"
    }
]
model_package_group_input_dict = {
    "ModelPackageGroupName" : ft_package_group_name,
    "ModelPackageGroupDescription" : ft_package_group_desc,
    "Tags": ft_tags
}
create_model_package_group_response = sm_client.create_model_package_group(
    **model_package_group_input_dict
)
print(f'Created ModelPackageGroup Arn : {create_model_package_group_response["ModelPackageGroupArn"]}')
ft_model_pkg_group_name = create_model_package_group_response["ModelPackageGroupArn"]

>>> 
Created ModelPackageGroup Arn : arn:aws:sagemaker:us-east-1:811828458885:model-package-group/mikael110--llama-2-7b-guanaco-fp16-finetuned-sql

...
...
...

### Register Delta Weights QLoRA Model Weights
huggingface_model = HuggingFaceModel(
    transformers_version='4.28',
    pytorch_version='2.0',  
    py_version='py310',
    model_data="s3://sagemaker-us-east-1-811828458885/huggingface-qlora-2308180454/output/model.tar.gz",  # or use huggingface_estimator.model_data
    role=role,
)

_response = huggingface_model.register(
    content_types=["application/json"],
    response_types=["application/json"],
    inference_instances=[
    "ml.p2.16xlarge",
    ...
    ],
    transform_instances=[
    "ml.p2.16xlarge",
    ...
    ],
    model_package_group_name=ft_model_pkg_group_name,
    approval_status="Approved"
)

>>>
Model collection creation status: {'added_groups': ['arn:aws:sagemaker:us-east-1:811828458885:model-package-group/mikael110--llama-2-7b-guanaco-fp16-finetuned-sql'], 'failure': []}

The following snippet shows a view from the Model Registry where the models are split into base and fine-tuned weights.

Managing models, datasets, and tasks for hyper-personalized LLMs can quickly become overwhelming. SageMaker Model Registry Collections can help you group related models together and organize them in a hierarchy to improve model discoverability. This makes it easier to track the relationships between base weights, adapter weights, and fine-tuning task datasets. You can also create complex relationships and linkages between models.

Create a new Collection and add your base model weights to this Collection

# create model collection
base_collection = model_collector.create(
    collection_name=model_group_for_base # ex: "Website_Customer_QnA_Bot_Model"
)

# Add the base weights at first level of model collections as all future models 
# are going to be tuned from the base weights
_response = model_collector.add_model_groups(
    collection_name=base_collection["Arn"],
    model_groups=[base_model_pkg_group_name]
)
print(f"Model collection creation status: {_response}")

>>>
Model collection creation status: {'added_groups': ['arn:aws:sagemaker:us-west-2:376678947624:model-package-group/Mikael110--llama-2-7b-guanaco-fp16'], 'failure': []}

Link all your Fine-Tuned LoRA Adapter Delta Weights to this collection by task and/or dataset

# create model collection for finetuned and link it back to the base
finetuned_collection = model_collector.create(
    collection_name=model_group_for_finetune,
    parent_collection_name=model_group_for_base
)

# add finetuned model package group to the new finetuned collection
_response = model_collector.add_model_groups(
    collection_name=model_group_for_finetune,
    model_groups=[ft_model_pkg_group_name]
)
print(f"Model collection creation status: {_response}")

>>>
Model collection creation status: {'added_groups': ['arn:aws:sagemaker:us-east-1:811828458885:model-package-group/mikael110--llama-2-7b-guanaco-fp16-finetuned-sql'], 'failure': []}

This results in a collection hierarchy that is linked by model/task type and the dataset used to fine-tune the base model.

This method of separating the base and adapter models has some drawbacks. One drawback is complexity in deploying the model. Because there are two separate model artifacts, you need additional steps to repackage the model instead of deploying directly from Model Registry. In the following code example, you first download and repack the latest version of the base model.

!aws s3 cp {base_model_package.model_data} .

!tar -xvf {model_tar_filename} -C ./deepspeed/

!mv ./deepspeed/{model_id} ./deepspeed/base

!rm -rf ./deepspeed/{model_id}

Then download and repack the latest fine-tuned LoRA adapter weights.

!aws s3 cp {LoRA_package.model_data} .

!mkdir -p ./deepspeed/lora/

!tar -xzf model.tar.gz -C ./deepspeed/lora/

Because you will use DJL Serving with DeepSpeed to host the model, your inference directory should look like the following.

deepspeed
    |-serving.properties
    |-requirements.txt
    |-model.py
    |-base/
        |-...
    |-lora/
        |-...

Finally, package the custom inference code, base model, and LoRA adapter in a single .tar.gz file for deployment.

!rm -f model.tar.gz
!tar czvf model.tar.gz -C deepspeed .
s3_code_artifact_deepspeed = sagemaker_session.upload_data("model.tar.gz", default_bucket, f"{s3_key_prefix}/inference")
print(f"S3 Code or Model tar for deepspeed uploaded to --- > {s3_code_artifact_deepspeed}")
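The same packaging step can be done from Python, which makes it possible to check the directory layout before uploading. The following is a minimal sketch and not the notebook's code: the function name package_inference_dir and the list of required entries are assumptions based on the layout shown above.

```python
import tarfile
from pathlib import Path

# Entries expected at the top level of the inference directory.
# Assumption: taken from the layout shown above; adjust for your project.
REQUIRED_ENTRIES = ["serving.properties", "model.py", "base", "lora"]


def package_inference_dir(src_dir, out_path="model.tar.gz"):
    """Check the layout of src_dir, then pack its contents into a .tar.gz."""
    src = Path(src_dir)
    missing = [name for name in REQUIRED_ENTRIES if not (src / name).exists()]
    if missing:
        raise FileNotFoundError(f"Missing from {src}: {missing}")

    with tarfile.open(out_path, "w:gz") as tar:
        # Add each entry relative to src, similar to `tar czvf ... -C deepspeed .`
        for entry in sorted(src.iterdir()):
            tar.add(entry, arcname=entry.name)
    return out_path
```

Failing early on a missing base/ or lora/ directory is cheaper than discovering the problem after the endpoint fails to start.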

Clean up

Clean up your resources by following the instructions in the cleanup section of the notebook. Refer to Amazon SageMaker Pricing for details on the cost of the inference instances.

Conclusion

This post walked you through best practices for managing LoRA fine-tuned models on Amazon SageMaker. We covered two main methods: combining the base and adapter weights into one self-contained model, and separating the base and adapter weights. Both approaches have tradeoffs, but separating weights helps optimize storage and enables advanced model management techniques like SageMaker Model Registry Collections. This allows you to build hierarchies and relationships between models to improve organization and discoverability. We encourage you to try the sample code in the GitHub repository to experiment with these methods yourself. As generative AI progresses rapidly, following model management best practices will help you track experiments, find the right model for your task, and manage specialized LLMs efficiently at scale.


About the authors

James Wu is a Senior AI/ML Specialist Solution Architect at AWS, helping customers design and build AI/ML solutions. James’s work covers a wide range of ML use cases, with a primary interest in computer vision, deep learning, and scaling ML across the enterprise. Prior to joining AWS, James was an architect, developer, and technology leader for over 10 years, including 6 years in engineering and 4 years in marketing & advertising industries.

Pranav Murthy is an AI/ML Specialist Solutions Architect at AWS. He focuses on helping customers build, train, deploy and migrate machine learning (ML) workloads to SageMaker. He previously worked in the semiconductor industry developing large computer vision (CV) and natural language processing (NLP) models to improve semiconductor processes. In his free time, he enjoys playing chess and traveling.

Mecit Gungor is an AI/ML Specialist Solution Architect at AWS helping customers design and build AI/ML solutions at scale. He covers a wide range of AI/ML use cases for Telecommunication customers and currently focuses on Generative AI, LLMs, and training and inference optimization. He can often be found hiking in the wilderness or playing board games with his friends in his free time.

Shelbee Eigenbrode is a Principal AI and Machine Learning Specialist Solutions Architect at Amazon Web Services (AWS). She has been in technology for 24 years spanning multiple industries, technologies, and roles. She is currently focusing on combining her DevOps and ML background into the domain of MLOps to help customers deliver and manage ML workloads at scale. With over 35 patents granted across various technology domains, she has a passion for continuous innovation and using data to drive business outcomes. Shelbee is a co-creator and instructor of the Practical Data Science specialization on Coursera. She is also the Co-Director of Women In Big Data (WiBD), Denver chapter. In her spare time, she likes to spend time with her family, friends, and overactive dogs.

Implement real-time personalized recommendations using Amazon Personalize

At a basic level, Machine Learning (ML) technology learns from data to make predictions. Businesses use their data with an ML-powered personalization service to elevate their customer experience. This approach allows businesses to use data to derive actionable insights and help grow their revenue and brand loyalty.

Amazon Personalize accelerates your digital transformation with ML, making it easier to integrate personalized recommendations into existing websites, applications, email marketing systems, and more. Amazon Personalize enables developers to quickly implement a customized personalization engine, without requiring ML expertise. Amazon Personalize provisions the necessary infrastructure and manages the entire machine learning (ML) pipeline, including processing the data, identifying features, using the most appropriate algorithms, and training, optimizing, and hosting the models. You receive results through an API and pay only for what you use, with no minimum fees or upfront commitments.

The post Architecting near real-time personalized recommendations with Amazon Personalize shows how to architect near real-time personalized recommendations using Amazon Personalize and AWS purpose-built data services. In this post, we walk you through a reference implementation of a real-time personalized recommendation system using Amazon Personalize.

Solution overview

The real-time personalized recommendations solution is implemented using Amazon Personalize, Amazon Simple Storage Service (Amazon S3), Amazon Kinesis Data Streams, AWS Lambda, and Amazon API Gateway.

The architecture is implemented as follows:

  1. Data preparation – Start by creating a dataset group, schemas, and datasets representing your items, interactions, and user data.
  2. Train the model – After importing your data, select the recipe matching your use case, and then create a solution to train a model by creating a solution version. When your solution version is ready, you can create a campaign for your solution version.
  3. Get near real-time recommendations – When you have a campaign, you can integrate calls to the campaign in your application. This is where calls to the GetRecommendations or GetPersonalizedRanking APIs are made to request near real-time recommendations from Amazon Personalize.

For more information, refer to Architecting near real-time personalized recommendations with Amazon Personalize.

The following diagram illustrates the solution architecture.

Implementation

We demonstrate this implementation with a use case about making real-time movie recommendations to an end user based on their interactions with the movie database over time.

The solution is implemented using the following steps:

  1. Prerequisite (Data preparation)
  2. Set up your development environment
  3. Deploy the solution
  4. Create a solution version
  5. Create a campaign
  6. Create an event tracker
  7. Get recommendations
  8. Ingest real-time interactions
  9. Validate real-time recommendations
  10. Cleanup

Prerequisites

Before you get started, make sure you have the following prerequisites:

  • Prepare your training data – Prepare and upload the data to an S3 bucket using the instructions. For this particular use case, you will be uploading interactions data and items data. An interaction is an event that you record and then import as training data. Amazon Personalize generates recommendations primarily based on the interactions data you import into an Interactions dataset. You can record multiple event types, such as click, watch, or like. Although the model created by Amazon Personalize can make suggestions based on a user’s past interactions alone, the quality of these suggestions improves when the model also has data about the associations among users or items. For example, if a user has engaged with movies categorized as Drama in the item dataset, Amazon Personalize will suggest movies (items) with the same genre.
  • Set up your development environment – Install the AWS Command Line Interface (AWS CLI).
  • Configure the CLI with your Amazon account – Configure the AWS CLI with your AWS account information.
  • Install and bootstrap AWS Cloud Development Kit (AWS CDK)
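For reference, an Interactions dataset schema for this use case might look like the following sketch. Amazon Personalize schemas are Avro-format JSON. The choice to include EVENT_TYPE as the only additional field is an assumption made for this illustration, and the deployed stack may define its own schema, so match the fields to the columns in your data.

```python
import json

# Avro schema for an Interactions dataset. USER_ID, ITEM_ID, and TIMESTAMP
# are the required fields; EVENT_TYPE is added here (an assumption) because
# this use case records event types such as "click" and "watch".
interactions_schema = {
    "type": "record",
    "name": "Interactions",
    "namespace": "com.amazonaws.personalize.schema",
    "fields": [
        {"name": "USER_ID", "type": "string"},
        {"name": "ITEM_ID", "type": "string"},
        {"name": "TIMESTAMP", "type": "long"},
        {"name": "EVENT_TYPE", "type": "string"},
    ],
    "version": "1.0",
}

# The schema is passed to Amazon Personalize as a JSON string.
schema_json = json.dumps(interactions_schema)
print(schema_json)
```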

Deploy the solution

To deploy the solution, do the following:

  • Clone the repository to a new folder on your desktop.
  • Deploy the stack to your AWS environment.

Create a solution version

A solution refers to the combination of an Amazon Personalize recipe, customized parameters, and one or more solution versions (trained models). When you deploy the CDK project in the previous step, a solution with a User-Personalization recipe is created for you automatically. A solution version refers to a trained machine learning model. Create a solution version for the implementation.

Create a campaign

A campaign deploys a solution version (trained model) with a provisioned transaction capacity for generating real-time recommendations. Create a campaign for the implementation.

Create an event tracker

Amazon Personalize can make recommendations based on real-time event data only, historical event data only, or both. Record real-time events to build out your interactions data and allow Amazon Personalize to learn from your user’s most recent activity. This keeps your data fresh and improves the relevance of Amazon Personalize recommendations. Before you can record events, you must create an event tracker. An event tracker directs new event data to the Interactions dataset in your dataset group. Create an event tracker for the implementation.

Get recommendations

In this use case, the interaction dataset is composed of movie IDs. Consequently, the recommendations presented to the user will consist of the movie IDs that align most closely with their personal preferences, determined from their historical interactions. You can use the getRecommendations API to retrieve personalized recommendations for a user by sending the user’s userId, the number of recommendations you need, and the campaign ARN. You can find the campaign ARN in the Amazon Personalize console menu.

For example, the following request will retrieve 5 recommendations for the user whose userId is 429:

curl --location 'https://{your-api-id}.execute-api.{your-region}.amazonaws.com/prod/getRecommendations?campaignArn={campaignArn}&userId=429&numResults=5'

The response from the request will be:

{
	"$metadata": {
		"httpStatusCode": 200,
		"requestId": "7159c128-4e16-45a4-9d7e-cf19aa2256e8",
		"attempts": 1,
		"totalRetryDelay": 0
	},
	"itemList": [
	{
		"itemId": "596",
		"score": 0.0243044
	},
	{
		"itemId": "153",
		"score": 0.0151695
	},
	{
		"itemId": "16",
		"score": 0.013694
	},
	{
		"itemId": "261",
		"score": 0.013524
	},
	{
		"itemId": "34",
		"score": 0.0122294
	}
	],
	"recommendationId": "RID-1d-40c1-8d20-dfffbd7b0ac7-CID-06b10f"
}

The items returned by the API call are the movies that Amazon Personalize recommends to the user based on their historical interactions.

The score values are floating-point numbers between 0 and 1.0. They are specific to the current campaign and the recipe associated with it for this use case, and they are computed relative to the scores of all the items in your dataset.
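When calling the API from application code rather than curl, it helps to build the query string with proper URL encoding and to pull the item list out of the response. The following is a minimal sketch. The function names, the example endpoint, and the campaign ARN are hypothetical; only the query parameters and the itemList field come from the request and response shown above.

```python
import json
from urllib.parse import urlencode


def build_recommendations_url(base_url, campaign_arn, user_id, num_results=5):
    """Return the getRecommendations URL with URL-encoded query parameters."""
    query = urlencode(
        {"campaignArn": campaign_arn, "userId": user_id, "numResults": num_results}
    )
    return f"{base_url}/getRecommendations?{query}"


def top_items(response_body):
    """Extract (itemId, score) pairs from a response body, highest score first."""
    items = json.loads(response_body)["itemList"]
    return sorted(
        ((item["itemId"], item["score"]) for item in items),
        key=lambda pair: pair[1],
        reverse=True,
    )


# Hypothetical endpoint and campaign ARN, for illustration only.
url = build_recommendations_url(
    "https://example.execute-api.us-west-2.amazonaws.com/prod",
    "arn:aws:personalize:us-west-2:111122223333:campaign/example",
    user_id="429",
)
print(url)
```

Encoding the parameters matters because a campaign ARN contains colons and slashes, and stray whitespace in the query string would break the request.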

Ingest real-time interactions

In the previous example, recommendations were obtained for the user with an ID of 429 based on their historical interactions with the movie database. For real-time recommendations, the user interactions with the items must be ingested into Amazon Personalize in real-time. These interactions are ingested into the recommendation system through the Amazon Personalize Event Tracker. The type of interaction, also called EventType, is given by the column of the same name in the interaction data dataset (EVENT_TYPE). In this example, the events can be of type “watch” or “click”, but you can have your own types of events according to the needs of your application.

In this example, the exposed API that generates the events of the users with the items receives the “interactions” parameter that corresponds to the number of events (interactions) of a user (UserId) with a single element (itemId) right now. The trackingId parameter can be found in the Amazon Personalize console and in the response of the creation of Event Tracker request.

This example shows a putEvent request that generates one interaction of click type, with an item ID of ‘185’ for the user ID ‘429’, using the current timestamp. Note that in production, ‘sentAt’ should be set to the time of the user’s interaction. In the following example, we set it to the time, in epoch format, at which we wrote the API request for this post. The events are sent to Amazon Kinesis Data Streams through API Gateway, which is why you need to send the stream-name and PartitionKey parameters.

curl --location 'https://iyxhva3ll6.execute-api.us-west-2.amazonaws.com/prod/data' --header 'Content-Type: application/json' --data '{ "stream-name": "my-stream","Data": {"userId" : "429", "interactions": 1, "itemId": "185", "trackingId" : "c90ac6d7-3d89-4abc-8a70-9b09c295cbcd", "eventType": "click", "sentAt":"1698711110"},"PartitionKey":"userId"}'

You will receive a confirmation response similar to the following:

{
	"Message": "Event sent successfully",
	"data": {
		"EncryptionType": "KMS",
		"SequenceNumber": "49..........1901314",
		"ShardId": "shardId-xxxxxxx"
	}
}
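In an application, the request body can be assembled programmatically so that sentAt reflects the actual time of the interaction. The following is a minimal sketch. The function name build_event_payload and its defaults are assumptions, and the field names are copied from the curl request above.

```python
import json
import time


def build_event_payload(stream_name, tracking_id, user_id, item_id,
                        event_type="click", interactions=1, sent_at=None):
    """Build the JSON body for the event-ingestion request shown above."""
    if sent_at is None:
        # Epoch seconds; in production, pass the time of the user's interaction.
        sent_at = int(time.time())
    return json.dumps({
        "stream-name": stream_name,
        "Data": {
            "userId": user_id,
            "interactions": interactions,
            "itemId": item_id,
            "trackingId": tracking_id,
            "eventType": event_type,
            "sentAt": str(sent_at),
        },
        "PartitionKey": "userId",
    })


# Placeholder tracking ID, for illustration only.
payload = build_event_payload("my-stream", "your-tracking-id", "429", "185")
print(payload)
```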

Validate real-time recommendations

Because the interaction dataset has been updated, the recommendations will be automatically updated to consider the new interactions. To validate the recommendations updated in real time, you can call the getRecommendations API again for the same user ID 429, and the result should be different from the previous one. The following results show a new recommendation with an ID of 594, and the recommendations with the IDs 16, 596, 153, and 261 changed their scores. These items brought a new movie genre (‘Animation|Children|Drama|Fantasy|Musical’) into the top 5 recommendations.

Request:

curl --location 'https://{your-api-id}.execute-api.{your-region}.amazonaws.com/prod/getRecommendations?campaignArn={campaignArn}&userId=429&numResults=5'

Response:

{
	"$metadata": {
		"httpStatusCode": 200,
		"requestId": "680f2be8-2e64-47d7-96f7-1c4aa9b9ac9d",
		"attempts": 1,
		"totalRetryDelay": 0
	},
	"itemList": [
	{
		"itemId": "596",
		"score": 0.0288085
	},
	{
		"itemId": "16",
		"score": 0.0134173
	},
	{
		"itemId": "594",
		"score": 0.0129357
	},
	{
		"itemId": "153",
		"score": 0.0129337
	},
	{
		"itemId": "261",
		"score": 0.0123728
	}
	],
	"recommendationId": "RID-dc-44f8-a327-482fb9e54921-CID-06b10f"
}

The response shows that the recommendation provided by Amazon Personalize was updated in real-time.
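To check the change programmatically rather than by eye, you can compare the two itemList arrays. The following is a minimal sketch. The function name compare_recommendations is an assumption, and the data is copied from the two responses in this post.

```python
def compare_recommendations(before, after):
    """Compare two itemList arrays from getRecommendations responses."""
    old = {item["itemId"]: item["score"] for item in before}
    new = {item["itemId"]: item["score"] for item in after}
    return {
        "added": sorted(set(new) - set(old)),
        "removed": sorted(set(old) - set(new)),
        # Score difference for items present in both responses.
        "score_changes": {
            item_id: new[item_id] - old[item_id]
            for item_id in old.keys() & new.keys()
        },
    }


before = [
    {"itemId": "596", "score": 0.0243044},
    {"itemId": "153", "score": 0.0151695},
    {"itemId": "16", "score": 0.013694},
    {"itemId": "261", "score": 0.013524},
    {"itemId": "34", "score": 0.0122294},
]
after = [
    {"itemId": "596", "score": 0.0288085},
    {"itemId": "16", "score": 0.0134173},
    {"itemId": "594", "score": 0.0129357},
    {"itemId": "153", "score": 0.0129337},
    {"itemId": "261", "score": 0.0123728},
]

diff = compare_recommendations(before, after)
print(diff["added"], diff["removed"])  # ['594'] ['34']
```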

Clean up

To avoid unnecessary charges, clean up the solution implementation by using Cleaning up resources.

Conclusion

In this post, we showed you how to implement a real-time personalized recommendations system using Amazon Personalize. The interactions with Amazon Personalize to ingest real-time interactions and get recommendations were executed through a command line tool called curl, but these API calls can be integrated into a business application to achieve the same outcome.

To choose a new recipe for your use case, refer to Real-time personalization. To measure the impact of the recommendations made by Amazon Personalize, refer to Measuring impact of recommendations.


About the Authors

Cristian Marquez is a Senior Cloud Application Architect. He has vast experience designing, building, and delivering enterprise-level software, high load and distributed systems and cloud native applications. He has experience in backend and frontend programming languages, as well as system design and implementation of DevOps practices. He actively helps customers build and secure innovative cloud solutions, solving their business problems and achieving their business goals.

Anand Komandooru is a Senior Cloud Architect at AWS. He joined AWS Professional Services organization in 2021 and helps customers build cloud-native applications on AWS cloud. He has over 20 years of experience building software and his favorite Amazon leadership principle is “Leaders are right a lot.”

Improve LLM responses in RAG use cases by interacting with the user

One of the most common applications of generative AI and large language models (LLMs) is answering questions based on a specific external knowledge corpus. Retrieval-Augmented Generation (RAG) is a popular technique for building question answering systems that use an external knowledge base. To learn more, refer to Build a powerful question answering bot with Amazon SageMaker, Amazon OpenSearch Service, Streamlit, and LangChain.

Traditional RAG systems often struggle to provide satisfactory answers when users ask vague or ambiguous questions without providing sufficient context. This leads to unhelpful responses like “I don’t know” or incorrect, made-up answers provided by an LLM. In this post, we demonstrate a solution to improve the quality of answers in such use cases over traditional RAG systems by introducing an interactive clarification component using LangChain.

The key idea is to enable the RAG system to engage in a conversational dialogue with the user when the initial question is unclear. By asking clarifying questions, prompting the user for more details, and incorporating the new contextual information, the RAG system can gather the necessary context to provide an accurate, helpful answer—even from an ambiguous initial user query.

Solution overview

To demonstrate our solution, we have set up an Amazon Kendra index (composed of the AWS online documentation for Amazon Kendra, Amazon Lex, and Amazon SageMaker), a LangChain agent with an Amazon Bedrock LLM, and a straightforward Streamlit user interface.

Prerequisites

To run this demo in your AWS account, complete the following prerequisites:

  1. Clone the GitHub repository and follow the steps explained in the README.
  2. Deploy an Amazon Kendra index in your AWS account. You can use the following AWS CloudFormation template to create a new index or use an already running index. Deploying a new index might add additional charges to your bill, therefore we recommend deleting it if you no longer need it. Note that the data within the index will be sent to the selected Amazon Bedrock foundation model (FM).
  3. The LangChain agent relies on FMs available in Amazon Bedrock, but this can be adapted to any other LLM that LangChain supports.
  4. To experiment with the sample front end shared with the code, you can use Amazon SageMaker Studio to run a local deployment of the Streamlit app. Note that running this demo will incur some additional costs.

Implement the solution

Traditional RAG agents are often designed as follows. The agent has access to a tool that is used to retrieve documents relevant to a user query. The retrieved documents are then inserted into the LLM prompt, so that the agent can provide an answer based on the retrieved document snippets.

In this post, we implement an agent that has access to KendraRetrievalTool and derives relevant documents from the Amazon Kendra index and provides the answer given the retrieved context:

# LangChain imports for the agent setup; retrieval_qa_chain, llm, and
# conversational_memory are defined in the full code in the GitHub repo
from langchain.agents import AgentType, Tool, initialize_agent

# tool for Kendra retrieval
kendra_tool = Tool(
    name="KendraRetrievalTool",
    func=retrieval_qa_chain,
    description="Use this tool first to answer human questions. The input to this tool should be the question.",
)

# traditional RAG agent
traditional_agent = initialize_agent(
    agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
    tools=[kendra_tool],
    llm=llm,
    early_stopping_method="generate",
    memory=conversational_memory,
)

# user question
answer = traditional_agent.run("How many GPUs does my EC2 instance have?")

Refer to the GitHub repo for the full implementation code. To learn more about traditional RAG use cases, refer to Question answering using Retrieval Augmented Generation with foundation models in Amazon SageMaker JumpStart.

Consider the following example. A user asks “How many GPUs does my EC2 instance have?” As shown in the following screenshot, the agent is looking for the answer using KendraRetrievalTool. However, the agent realizes it doesn’t know which Amazon Elastic Compute Cloud (Amazon EC2) instance type the user is referencing and therefore provides no helpful answer to the user, leading to a poor customer experience.

Regular RAG Response

To address this problem, we define an additional custom tool called AskHumanTool and provide it to the agent. The tool instructs an LLM to read the user question and ask a follow-up question to the user if KendraRetrievalTool is not able to return a good answer. This implies that the agent will now have two tools at its disposal:

# tool for asking human
human_ask_tool = CustomAskHumanTool()

# RAG agent with two tools
improved_agent = initialize_agent(
    agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
    tools=[kendra_tool, human_ask_tool],
    llm=llm,
    early_stopping_method="generate",
    memory=conversational_memory,
)

# user question
answer = improved_agent.run("How many GPUs does my EC2 instance have?")

This allows the agent to either refine the question or provide additional context that is needed to respond to the prompt. To guide the agent to use AskHumanTool for this purpose, we provide the following tool description to the LLM:

Use this tool if you don’t find an answer using the KendraRetrievalTool. 
Ask the human to clarify the question or provide the missing information. 
The input should be a question for the human.

As illustrated in the following screenshot, by using AskHumanTool, the agent is now identifying vague user questions and returning a follow-up question to the user asking to specify what EC2 instance type is being used.

RAG Follow up question

After the user has specified the instance type, the agent is incorporating the additional answer into the context for the original question, before deriving the correct answer.

RAG improved final answer

Note that the agent can now decide whether to use KendraRetrievalTool to retrieve the relevant documents or ask a clarifying question using AskHumanTool. The agent’s decision is based on whether it finds the document snippets inserted into the prompt sufficient to provide the final answer. This flexibility allows the RAG system to support different queries a user may submit, including both well-formulated and vague questions.

In our example, the full agent workflow is as follows:

  1. The user makes a request to the RAG app, asking “How many GPUs does my EC2 instance have?”
  2. The agent uses the LLM to decide what action to take: Find relevant information to answer the user’s request by calling the KendraRetrievalTool.
  3. The agent retrieves information from the Amazon Kendra index using the tool. The snippets from the retrieved documents are inserted into the agent prompt.
  4. The LLM (of the agent) derives that the retrieved documents from Amazon Kendra aren’t helpful or don’t contain enough context to provide an answer to the user’s request.
  5. The agent uses AskHumanTool to formulate a follow-up question: “What is the specific EC2 instance type you are using? Knowing the instance type would help answer how many GPUs it has.” The user provides the answer “ml.g5.12xlarge,” and the agent calls KendraRetrievalTool again, but this time adding the EC2 instance type into the search query.
  6. After running through Steps 2–4 again, the agent derives a useful answer and sends it back to the user.
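The control flow of these steps can be sketched in plain Python, without LangChain. This is an illustration of the retrieve, clarify, and retry loop only. In the actual solution the LLM decides which tool to call, and every name below (answer_with_clarification, fake_retrieve, fake_ask_human) is hypothetical.

```python
def answer_with_clarification(question, retrieve, ask_human, max_rounds=2):
    """Try to answer; if retrieval is insufficient, ask the user and retry.

    retrieve(query) returns an answer string, or None when the retrieved
    documents do not contain enough context.
    ask_human(follow_up) returns the user's reply as a string.
    """
    query = question
    for _ in range(max_rounds):
        answer = retrieve(query)
        if answer is not None:
            return answer
        # Retrieval was not sufficient: ask a follow-up question.
        reply = ask_human(f"Could you clarify or add details to: {question!r}?")
        # Fold the new detail into the search query for the next attempt.
        query = f"{question} ({reply})"
    answer = retrieve(query)
    return answer if answer is not None else "I don't know."


questions_asked = []


def fake_retrieve(query):
    # Stand-in for KendraRetrievalTool: succeeds only if an instance type is named.
    return "Answer found for ml.g5.12xlarge" if "ml.g5.12xlarge" in query else None


def fake_ask_human(follow_up):
    # Stand-in for AskHumanTool: records the question and returns a canned reply.
    questions_asked.append(follow_up)
    return "ml.g5.12xlarge"


answer = answer_with_clarification(
    "How many GPUs does my EC2 instance have?", fake_retrieve, fake_ask_human
)
print(answer)
```

Capping the number of rounds keeps the agent from asking follow-up questions indefinitely when the knowledge base simply does not contain the answer.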

The following diagram illustrates this workflow.

RAG improved architecture

The example described in this post illustrates how the addition of the custom AskHumanTool allows the agent to request clarifying details when needed. This can improve the reliability and accuracy of the responses, leading to a better customer experience in a growing number of RAG applications across different domains.

Clean up

To avoid incurring unnecessary costs, delete the Amazon Kendra index if you’re not using it anymore and shut down the SageMaker Studio instance if you used it to run the demo.

Conclusion

In this post, we showed how to enable a better customer experience for the users of a RAG system by adding a custom tool that enables the system to ask a user for a missing piece of information. This interactive conversational approach represents a promising direction for improving traditional RAG architectures. The ability to resolve vagueness through a dialogue can lead to delivering more satisfactory answers from a knowledge base.

Note that this approach is not limited to RAG use cases; you can use it in other generative AI use cases that depend on an agent at its core, where a custom AskHumanTool can be added.

For more information about using Amazon Kendra with generative AI, refer to Quickly build high-accuracy Generative AI applications on enterprise data using Amazon Kendra, LangChain, and large language models.


About the authors

Antonia Wiebeler is a Data Scientist at the AWS Generative AI Innovation Center, where she enjoys building proofs of concept for customers. Her passion is exploring how generative AI can solve real-world problems and create value for customers. When she is not coding, she enjoys running and competing in triathlons.

Nikita Kozodoi is an Applied Scientist at the AWS Generative AI Innovation Center, where he develops ML solutions to solve customer problems across industries. In his role, he focuses on advancing generative AI to tackle real-world challenges. In his spare time, he loves playing beach volleyball and reading.

Build trust and safety for generative AI applications with Amazon Comprehend and LangChain

We are witnessing a rapid increase in the adoption of large language models (LLMs) that power generative AI applications across industries. LLMs are capable of a variety of tasks, such as generating creative content, answering inquiries via chatbots, generating code, and more.

Organizations looking to use LLMs to power their applications are increasingly wary about data privacy to ensure trust and safety is maintained within their generative AI applications. This includes handling customers’ personally identifiable information (PII) data properly. It also includes preventing abusive and unsafe content from being propagated to LLMs and checking that data generated by LLMs follows the same principles.

In this post, we discuss new features powered by Amazon Comprehend that enable seamless integration to ensure data privacy, content safety, and prompt safety in new and existing generative AI applications.

Amazon Comprehend is a natural language processing (NLP) service that uses machine learning (ML) to uncover information in unstructured data and text within documents. In this post, we discuss why trust and safety with LLMs matter for your workloads. We also delve deeper into how these new moderation capabilities are utilized with the popular generative AI development framework LangChain to introduce a customizable trust and safety mechanism for your use case.

Why trust and safety with LLMs matter

Trust and safety are paramount when working with LLMs due to their profound impact on a wide range of applications, from customer support chatbots to content generation. As these models process vast amounts of data and generate humanlike responses, the potential for misuse or unintended outcomes increases. Ensuring that these AI systems operate within ethical and reliable boundaries is crucial, not just for the reputation of businesses that utilize them, but also for preserving the trust of end-users and customers.

Moreover, as LLMs become more integrated into our daily digital experiences, their influence on our perceptions, beliefs, and decisions grows. Ensuring trust and safety with LLMs goes beyond just technical measures; it speaks to the broader responsibility of AI practitioners and organizations to uphold ethical standards. By prioritizing trust and safety, organizations not only protect their users, but also ensure sustainable and responsible growth of AI in society. It can also help to reduce risk of generating harmful content, and help adhere to regulatory requirements.

In the realm of trust and safety, content moderation is a mechanism that addresses various aspects, including but not limited to:

  • Privacy – Users can inadvertently provide text that contains sensitive information, jeopardizing their privacy. Detecting and redacting any PII is essential.
  • Toxicity – Recognizing and filtering out harmful content, such as hate speech, threats, or abuse, is of utmost importance.
  • User intention – Identifying whether the user input (prompt) is safe or unsafe is critical. Unsafe prompts can explicitly or implicitly express malicious intent, such as requesting personal or private information and generating offensive, discriminatory, or illegal content. Prompts may also implicitly express or request advice on medical, legal, political, controversial, personal, or financial subjects.

Content moderation with Amazon Comprehend

In this section, we discuss the benefits of content moderation with Amazon Comprehend.

Addressing privacy

Amazon Comprehend already addresses privacy through its existing PII detection and redaction abilities via the DetectPIIEntities and ContainsPIIEntities APIs. These two APIs are backed by NLP models that can detect a large number of PII entities such as Social Security numbers (SSNs), credit card numbers, names, addresses, phone numbers, and so on. For a full list of entities, refer to PII universal entity types. DetectPIIEntities also provides the character-level position of each PII entity within a text; for example, for the NAME entity (John Doe) in the sentence “My name is John Doe”, the zero-based begin offset is 11 and the end offset is 19. These offsets can be used to perform masking or redaction of the values, thereby reducing risks of private data propagation into LLMs.
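As an illustration of offset-based masking, the following sketch redacts entity spans from a string. It assumes zero-based begin offsets with an exclusive end offset, which is also how Python slicing works. The function name redact_pii is hypothetical, and the entity list is hand-written for this example rather than real API output.

```python
def redact_pii(text, entities, mask_char="*"):
    """Mask each detected entity span in text.

    entities is a list of dicts with BeginOffset and EndOffset keys, the
    shape assumed here for the Entities list of a PII detection response.
    """
    chars = list(text)
    for entity in entities:
        start, end = entity["BeginOffset"], entity["EndOffset"]
        # Replace the span with mask characters of the same length.
        chars[start:end] = mask_char * (end - start)
    return "".join(chars)


# Hand-written entity list for illustration; not real API output.
sample_text = "Call Jane Roe at 555-0100."
sample_entities = [
    {"Type": "NAME", "BeginOffset": 5, "EndOffset": 13},
    {"Type": "PHONE", "BeginOffset": 17, "EndOffset": 25},
]
print(redact_pii(sample_text, sample_entities))  # Call ******** at ********.
```

Masking with same-length replacements keeps the offsets of the remaining entities valid, so the spans can be processed in any order.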

Addressing toxicity and prompt safety

Today, we are announcing two new Amazon Comprehend features in the form of APIs: Toxicity detection via the DetectToxicContent API, and prompt safety classification via the ClassifyDocument API. Note that DetectToxicContent is a new API, whereas ClassifyDocument is an existing API that now supports prompt safety classification.

Toxicity detection

With Amazon Comprehend toxicity detection, you can identify and flag content that may be harmful, offensive, or inappropriate. This capability is particularly valuable for platforms where users generate content, such as social media sites, forums, chatbots, comment sections, and applications that use LLMs to generate content. The primary goal is to maintain a positive and safe environment by preventing the dissemination of toxic content.

At its core, the toxicity detection model analyzes text to determine the likelihood of it containing hateful content, threats, obscenities, or other forms of harmful text. The model is trained on vast datasets containing examples of both toxic and nontoxic content. The toxicity API evaluates a given piece of text to provide a toxicity classification and confidence score. Generative AI applications can then use this information to take appropriate actions, such as stopping the text from propagating to LLMs. As of this writing, the labels detected by the toxicity detection API are HATE_SPEECH, GRAPHIC, HARASSMENT_OR_ABUSE, SEXUAL, VIOLENCE_OR_THREAT, INSULT, and PROFANITY. The following code demonstrates the API call with Python Boto3 for Amazon Comprehend toxicity detection:

import boto3
client = boto3.client('comprehend')
response = client.detect_toxic_content(
    TextSegments=[{"Text": "What is the capital of France?"},
                  {"Text": "Where do I find good baguette in France?"}],
    LanguageCode='en')
print(response)
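Once a response is available, an application still has to decide which segments to block. The following sketch shows one way to do that. It assumes the response contains a ResultList with one entry per text segment, each holding a Labels list of Name and Score pairs; verify this shape against the API reference. The helper name flag_toxic_segments and the sample scores are made up for illustration:

```python
def flag_toxic_segments(response, threshold=0.5):
    """Return (segment_index, label_names) for segments with a label score >= threshold."""
    flagged = []
    for index, result in enumerate(response.get("ResultList", [])):
        names = [label["Name"] for label in result.get("Labels", [])
                 if label["Score"] >= threshold]
        if names:
            flagged.append((index, names))
    return flagged


# Hypothetical response with made-up scores, not real API output
sample_response = {
    "ResultList": [
        {"Labels": [{"Name": "INSULT", "Score": 0.02},
                    {"Name": "PROFANITY", "Score": 0.01}], "Toxicity": 0.03},
        {"Labels": [{"Name": "INSULT", "Score": 0.91},
                    {"Name": "PROFANITY", "Score": 0.12}], "Toxicity": 0.88},
    ]
}
flagged = flag_toxic_segments(sample_response, threshold=0.6)
print(flagged)  # [(1, ['INSULT'])]
```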

Prompt safety classification

Prompt safety classification with Amazon Comprehend helps classify an input text prompt as safe or unsafe. This capability is crucial for applications like chatbots, virtual assistants, or content moderation tools where understanding the safety of a prompt can determine responses, actions, or content propagation to LLMs.

In essence, prompt safety classification analyzes human input for any explicit or implicit malicious intent, such as requesting personal or private information or the generation of offensive, discriminatory, or illegal content. It also flags prompts looking for advice on medical, legal, political, controversial, personal, or financial subjects. Prompt classification returns two classes for a given text, UNSAFE_PROMPT and SAFE_PROMPT, each with an associated confidence score. Each confidence score ranges from 0 to 1, and the two scores sum to 1. For instance, in a customer support chatbot, the text “How do I reset my password?” signals an intent to seek guidance on password reset procedures and is labeled as SAFE_PROMPT. In contrast, a statement like “I wish something bad happens to you” can be flagged for having a potentially harmful intent and labeled as UNSAFE_PROMPT. It’s important to note that prompt safety classification is primarily focused on detecting intent from human inputs (prompts), rather than machine-generated text (LLM outputs). The following code demonstrates how to access the prompt safety classification feature with the ClassifyDocument API:

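import boto3
client = boto3.client('comprehend')
# Replace <region> with an AWS Region where Amazon Comprehend is available
endpoint_arn = 'arn:aws:comprehend:<region>:aws:document-classifier-endpoint/prompt-safety'
prompt_value = 'How do I reset my password?'
response = client.classify_document(
    Text=prompt_value,
    EndpointArn=endpoint_arn)
print(response)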

Note that endpoint_arn in the preceding code is an AWS-provided Amazon Resource Name (ARN) of the pattern arn:aws:comprehend:<region>:aws:document-classifier-endpoint/prompt-safety, where <region> is the AWS Region of your choice where Amazon Comprehend is available.
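As an illustration of how the ARN pattern and the two returned classes can be used together, here is a small sketch that runs without calling AWS. It assumes the response carries a Classes list of Name and Score pairs. The helper names and the sample scores are hypothetical:

```python
def prompt_safety_endpoint_arn(region):
    """Build the AWS-provided prompt safety endpoint ARN for a Region."""
    return (f"arn:aws:comprehend:{region}:aws:"
            "document-classifier-endpoint/prompt-safety")


def is_unsafe_prompt(response, threshold=0.8):
    """Return True if the UNSAFE_PROMPT score meets or exceeds the threshold."""
    scores = {c["Name"]: c["Score"] for c in response.get("Classes", [])}
    return scores.get("UNSAFE_PROMPT", 0.0) >= threshold


# Hypothetical response with made-up scores, not real API output
sample_response = {"Classes": [{"Name": "SAFE_PROMPT", "Score": 0.93},
                               {"Name": "UNSAFE_PROMPT", "Score": 0.07}]}
arn = prompt_safety_endpoint_arn("us-east-1")
unsafe = is_unsafe_prompt(sample_response)
print(arn, unsafe)
```

Comparing only the UNSAFE_PROMPT score against a threshold is one possible policy; because the two scores sum to 1, it is equivalent to requiring the SAFE_PROMPT score to stay below 1 minus the threshold.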

To demonstrate these capabilities, we built a sample chat application where we ask an LLM to extract PII entities such as address, phone number, and SSN from a given piece of text. The LLM finds and returns the appropriate PII entities, as shown in the image on the left.

With Amazon Comprehend moderation, we can redact the input to the LLM and output from the LLM. In the image on the right, the SSN value is allowed to be passed to the LLM without redaction. However, any SSN value in the LLM’s response is redacted.

The following is an example of how a prompt containing PII can be prevented from reaching the LLM altogether. In this example, a user asks a question that contains PII. We use Amazon Comprehend moderation to detect PII entities in the prompt and interrupt the flow with an error.

The preceding chat examples showcase how Amazon Comprehend moderation applies restrictions on data being sent to an LLM. In the following sections, we explain how this moderation mechanism is implemented using LangChain.

Integration with LangChain

As LLMs are applied to an ever wider range of use cases, it has become equally important to simplify the development of generative AI applications. LangChain is a popular open source framework that simplifies the development of generative AI applications. Amazon Comprehend moderation extends the LangChain framework to offer PII identification and redaction, toxicity detection, and prompt safety classification capabilities via AmazonComprehendModerationChain.

AmazonComprehendModerationChain is a custom implementation of the LangChain base chain interface. This means that applications can use this chain with their own LLM chains to apply the desired moderation to the input prompt as well as to the output text from the LLM. Chains can be built by merging numerous chains or by mixing chains with other components. You can use AmazonComprehendModerationChain with other LLM chains to develop complex AI applications in a modular and flexible manner.

To explain it further, we provide a few samples in the following sections. The source code for the AmazonComprehendModerationChain implementation can be found within the LangChain open source repository. For full documentation of the API interface, refer to the LangChain API documentation for the Amazon Comprehend moderation chain. Using this moderation chain is as simple as initializing an instance of the class with default configurations:

from langchain_experimental.comprehend_moderation import AmazonComprehendModerationChain

comprehend_moderation = AmazonComprehendModerationChain()

Behind the scenes, the moderation chain performs three consecutive moderation checks, namely PII, toxicity, and prompt safety, as explained in the following diagram. This is the default flow for the moderation.

The following code snippet shows a simple example of using the moderation chain with the Amazon FalconLite LLM (which is a quantized version of the Falcon 40B SFT OASST-TOP1 model) hosted in Hugging Face Hub:

from langchain import HuggingFaceHub
from langchain import PromptTemplate, LLMChain
from langchain_experimental.comprehend_moderation import AmazonComprehendModerationChain

template = """Question: {question}
Answer:"""
repo_id = "amazon/FalconLite"
prompt = PromptTemplate(template=template, input_variables=["question"])
llm = HuggingFaceHub(
    repo_id=repo_id,
    model_kwargs={"temperature": 0.5, "max_length": 256}
)
comprehend_moderation = AmazonComprehendModerationChain(verbose=True)
chain = (
    prompt 
    | comprehend_moderation 
    | { "input" : (lambda x: x['output']) | llm }  
    | comprehend_moderation
)

try:
    response = chain.invoke({"question": "An SSN is of the format 123-45-6789. Can you give me John Doe's SSN?"})
except Exception as e:
    print(str(e))
else:
    print(response['output'])

In the preceding example, we augment our chain with comprehend_moderation for both text going into the LLM and text generated by the LLM. This will perform default moderation that will check PII, toxicity, and prompt safety classification in that sequence.

Customize your moderation with filter configurations

You can use the AmazonComprehendModerationChain with specific configurations, which gives you the ability to control what moderations you wish to perform in your generative AI–based application. At the core of the configuration, you have three filter configurations available.

  1. ModerationPiiConfig – Used to configure the PII filter.
  2. ModerationToxicityConfig – Used to configure the toxic content filter.
  3. ModerationPromptSafetyConfig – Used to configure the prompt safety filter.

You can use each of these filter configurations to customize how your moderations behave. Each filter configuration can be initialized with a few common parameters and some parameters unique to that filter. After you define the configurations, you use the BaseModerationConfig class to define the sequence in which the filters must apply to the text. For example, in the following code, we first define the three filter configurations, and subsequently specify the order in which they must apply:

from langchain_experimental.comprehend_moderation import (
    AmazonComprehendModerationChain,
    BaseModerationConfig,
    ModerationPiiConfig,
    ModerationPromptSafetyConfig,
    ModerationToxicityConfig,
)

# Redact SSN values by masking them with "X"
pii_config = ModerationPiiConfig(labels=["SSN"],
                                 redact=True,
                                 mask_character="X")
toxicity_config = ModerationToxicityConfig(threshold=0.6)
prompt_safety_config = ModerationPromptSafetyConfig(threshold=0.8)
# Filters run in the order listed here
moderation_config = BaseModerationConfig(filters=[toxicity_config,
                                                  pii_config,
                                                  prompt_safety_config])
comprehend_moderation = AmazonComprehendModerationChain(moderation_config=moderation_config)

Let’s dive a little deeper to understand what this configuration achieves:

  • First, for the toxicity filter, we specified a threshold of 0.6. This means that if the text contains any of the available toxic labels or entities with a score greater than the threshold, the whole chain will be interrupted.
  • If there is no toxic content found in the text, a PII check is performed. In this case, we’re interested in checking if the text contains SSN values. Because the redact parameter is set to True, the chain will mask the detected SSN values (if any) where the SSN entity’s confidence score is greater than or equal to 0.5, with the mask character specified (X). If redact is set to False, the chain will be interrupted for any SSN detected.
  • Finally, the chain performs prompt safety classification, and will stop the content from propagating further down the chain if the content is classified with UNSAFE_PROMPT with a confidence score of greater than or equal to 0.8.

The following diagram illustrates this workflow.

In case of interruptions to the moderation chain (in this example, applicable for the toxicity and prompt safety classification filters), the chain will raise a Python exception, essentially stopping the chain in progress and allowing you to catch the exception (in a try-except block) and perform any relevant action. The three possible exception types are:

  1. ModerationPIIError
  2. ModerationToxicityError
  3. ModerationPromptSafetyError
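The control flow described above (toxicity check, then PII redaction, then prompt safety, with exceptions on interruption) can be sketched in plain Python. This is not the library's implementation: the exception classes are local stand-ins for the chain's exception types, the scores and SSN spans are passed in as precomputed values instead of coming from Amazon Comprehend, and the function name moderate is hypothetical:

```python
class ToxicityError(Exception):
    """Local stand-in for the chain's toxicity exception."""


class PromptSafetyError(Exception):
    """Local stand-in for the chain's prompt safety exception."""


def moderate(text, toxicity_score, ssn_spans, unsafe_score,
             toxicity_threshold=0.6, safety_threshold=0.8, mask_character="X"):
    # 1. Toxicity filter: interrupt if the score exceeds the threshold
    if toxicity_score > toxicity_threshold:
        raise ToxicityError("toxic content found")
    # 2. PII filter with redact=True: mask each SSN span, do not interrupt
    chars = list(text)
    for begin, end in ssn_spans:
        chars[begin:end] = mask_character * (end - begin)
    text = "".join(chars)
    # 3. Prompt safety filter: interrupt if classified unsafe with enough confidence
    if unsafe_score >= safety_threshold:
        raise PromptSafetyError("unsafe prompt found")
    return text


prompt = "My SSN is 123-45-6789"
start = prompt.index("123-45-6789")
redacted = moderate(prompt, toxicity_score=0.1,
                    ssn_spans=[(start, start + 11)], unsafe_score=0.2)
print(redacted)  # My SSN is XXXXXXXXXXX

try:
    moderate("some text", toxicity_score=0.9, ssn_spans=[], unsafe_score=0.1)
except ToxicityError as error:
    outcome = f"interrupted: {error}"
print(outcome)  # interrupted: toxic content found
```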

You can configure one or more filters using BaseModerationConfig. You can also have the same type of filter with different configurations within the same chain. For example, if your use case is only concerned with PII, you can specify a configuration that must interrupt the chain if an SSN is detected; otherwise, it must perform redaction on age and name PII entities. A configuration for this can be defined as follows:

pii_config1 = ModerationPiiConfig(labels=["SSN"],
    redact=False)
pii_config2 = ModerationPiiConfig(labels=["AGE", "NAME"],
    redact=True, 
    mask_character="X")
moderation_config = BaseModerationConfig(filters=[ pii_config1, 
      pii_config2])
comprehend_moderation = AmazonComprehendModerationChain(moderation_config=moderation_config)

Using callbacks and unique identifiers

If you’re familiar with the concept of workflows, you may also be familiar with callbacks. Callbacks within workflows are independent pieces of code that run when certain conditions are met within the workflow. A callback can either be blocking or nonblocking to the workflow. LangChain chains are, in essence, workflows for LLMs. AmazonComprehendModerationChain allows you to define your own callback functions. Initially, the implementation is limited to asynchronous (nonblocking) callback functions only.

This effectively means that if you use callbacks with the moderation chain, they will run independently of the chain’s run without blocking it. For the moderation chain, you get options to run pieces of code, with any business logic, after each moderation is run, independent of the chain.
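To illustrate what nonblocking means here, the following sketch uses Python's asyncio to schedule a callback without waiting for it. It is a generic pattern, not how the moderation chain is implemented internally; the function names, the beacon keys, and the audit_log list are hypothetical:

```python
import asyncio


async def on_after_check(beacon, unique_id, audit_log):
    """Hypothetical callback: record which check ran and for whom."""
    await asyncio.sleep(0.01)  # simulate slow I/O, such as a database write
    audit_log.append((unique_id, beacon["moderation_type"]))


async def run_check(text, unique_id, audit_log):
    beacon = {"moderation_type": "PII", "moderation_input": text}
    # Schedule the callback without awaiting it, so the check returns immediately
    task = asyncio.create_task(on_after_check(beacon, unique_id, audit_log))
    return text, task


async def main():
    audit_log = []
    text, task = await run_check("Hello", "john.doe@email.com", audit_log)
    logged_before = list(audit_log)  # snapshot taken before the callback finishes
    await task  # wait here only so the example can show the final log
    return text, logged_before, audit_log


text, logged_before, audit_log = asyncio.run(main())
print(logged_before, audit_log)
```

The check hands back its result before the callback has written anything, which is the property that lets business logic such as logging run without delaying the chain.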

You can also optionally provide an arbitrary unique identifier string when creating an AmazonComprehendModerationChain to enable logging and analytics later. For example, if you’re operating a chatbot powered by an LLM, you may want to track users who are consistently abusive or are deliberately or unknowingly exposing personal information. In such cases, it becomes necessary to track the origin of such prompts and perhaps store them in a database or log them appropriately for further action. You can pass a unique ID that distinctly identifies a user, such as their user name or email, or an application name that is generating the prompt.

The combination of callbacks and unique identifiers provides you with a powerful way to implement a moderation chain that fits your use case in a much more cohesive manner with less code that is easier to maintain. The callback handler is available via the BaseModerationCallbackHandler, with three available callbacks: on_after_pii(), on_after_toxicity(), and on_after_prompt_safety(). Each of these callback functions is called asynchronously after the respective moderation check is performed within the chain. These functions also receive two default parameters:

  • moderation_beacon – A dictionary containing details such as the text on which the moderation was performed, the full JSON output of the Amazon Comprehend API, the type of moderation, and whether the supplied labels (in the configuration) were found within the text.
  • unique_id – The unique ID that you assigned while initializing an instance of the AmazonComprehendModerationChain.

The following is an example of how an implementation with callback works. In this case, we defined a single callback that we want the chain to run after the PII check is performed:

from langchain_experimental.comprehend_moderation import BaseModerationCallbackHandler

class MyModCallback(BaseModerationCallbackHandler):
    async def on_after_pii(self, output_beacon, unique_id):
        import json
        moderation_type = output_beacon['moderation_type']
        chain_id = output_beacon['moderation_chain_id']
        with open(f'output-{moderation_type}-{chain_id}.json', 'w') as file:
            data = { 'beacon_data': output_beacon, 'unique_id': unique_id }
            json.dump(data, file)
    
    '''
    # implement this callback for toxicity
    async def on_after_toxicity(self, output_beacon, unique_id):
        pass

    # implement this callback for prompt safety
    async def on_after_prompt_safety(self, output_beacon, unique_id):
        pass
    '''

my_callback = MyModCallback()

We then use the my_callback object while initializing the moderation chain and also pass a unique_id. You may use callbacks and unique identifiers with or without a configuration. When you subclass BaseModerationCallbackHandler, you must implement one or all of the callback methods depending on the filters you intend to use. For brevity, the following example shows a way to use callbacks and unique_id without any configuration:

comprehend_moderation = AmazonComprehendModerationChain(
    moderation_callback=my_callback,
    unique_id='john.doe@email.com')

The following diagram explains how this moderation chain with callbacks and unique identifiers works. Specifically, we implemented the PII callback that should write a JSON file with the data available in the moderation_beacon and the unique_id passed (the user’s email in this case).

In the following Python notebook, we have compiled a few different ways you can configure and use the moderation chain with various LLMs, such as LLMs hosted with Amazon SageMaker JumpStart and LLMs hosted in Hugging Face Hub. The notebook also includes the sample chat application that we discussed earlier.

Conclusion

The transformative potential of large language models and generative AI is undeniable. However, their responsible and ethical use hinges on addressing concerns of trust and safety. By recognizing the challenges and actively implementing measures to mitigate risks, developers, organizations, and society at large can harness the benefits of these technologies while preserving the trust and safety that underpin their successful integration. Use AmazonComprehendModerationChain to add trust and safety features to any LLM workflow, including Retrieval Augmented Generation (RAG) workflows implemented in LangChain.

For information on building RAG-based solutions using LangChain and Amazon Kendra’s highly accurate, machine learning (ML)-powered intelligent search, see Quickly build high-accuracy Generative AI applications on enterprise data using Amazon Kendra, LangChain, and large language models. As a next step, refer to the code samples we created for using Amazon Comprehend moderation with LangChain. For full documentation of the Amazon Comprehend moderation chain API, refer to the LangChain API documentation.


About the authors

Wrick Talukdar is a Senior Architect with the Amazon Comprehend Service team. He works with AWS customers to help them adopt machine learning on a large scale. Outside of work, he enjoys reading and photography.

Anjan Biswas is a Senior AI Services Solutions Architect with a focus on AI/ML and Data Analytics. Anjan is part of the world-wide AI services team and works with customers to help them understand and develop solutions to business problems with AI and ML. Anjan has over 14 years of experience working with global supply chain, manufacturing, and retail organizations, and is actively helping customers get started and scale on AWS AI services.

Nikhil Jha is a Senior Technical Account Manager at Amazon Web Services. His focus areas include AI/ML and analytics. In his spare time, he enjoys playing badminton with his daughter and exploring the outdoors.

Chin Rane is an AI/ML Specialist Solutions Architect at Amazon Web Services. She is passionate about applied mathematics and machine learning. She focuses on designing intelligent document processing solutions for AWS customers. Outside of work, she enjoys salsa and bachata dancing.

Use machine learning without writing a single line of code with Amazon SageMaker Canvas

In the recent past, using machine learning (ML) to make predictions, especially for data in the form of text and images, required extensive ML knowledge for creating and tuning of deep learning models. Today, ML has become more accessible to any user who wants to use ML models to generate business value. With Amazon SageMaker Canvas, you can create predictions for a number of different data types beyond just tabular or time series data without writing a single line of code. These capabilities include pre-trained models for image, text, and document data types.

In this post, we discuss how you can use pre-trained models to retrieve predictions for supported data types beyond tabular data.

Text data

SageMaker Canvas provides a visual, no-code environment for building, training, and deploying ML models. For natural language processing (NLP) tasks, SageMaker Canvas integrates seamlessly with Amazon Comprehend to allow you to perform key NLP capabilities like language detection, entity recognition, sentiment analysis, topic modeling, and more. The integration eliminates the need for any coding or data engineering to use the robust NLP models of Amazon Comprehend. You simply provide your text data and select from four commonly used capabilities: sentiment analysis, language detection, entities extraction, and personal information detection. For each scenario, you can use the UI to test and use batch prediction to select data stored in Amazon Simple Storage Service (Amazon S3).

Analyzing text data on SageMaker Canvas

Sentiment analysis

With sentiment analysis, SageMaker Canvas allows you to analyze the sentiment of your input text. It can determine if the overall sentiment is positive, negative, mixed, or neutral, as shown in the following screenshot. This is useful in situations like analyzing product reviews. For example, the text “I love this product, it’s amazing!” would be classified by SageMaker Canvas as having a positive sentiment, whereas “This product is horrible, I regret buying it” would be labeled as negative sentiment.

Sentiment Analysis on SageMaker Canvas

Entities extraction

SageMaker Canvas can analyze text and automatically detect entities mentioned within it. When a document is sent to SageMaker Canvas for analysis, it will identify people, organizations, locations, dates, quantities, and other entities in the text. This entity extraction capability enables you to quickly gain insights into the key people, places, and details discussed in documents. For a list of supported entities, refer to Entities.

Entities Extraction on SageMaker Canvas

Language detection

SageMaker Canvas can also determine the dominant language of text using Amazon Comprehend. It analyzes text to identify the main language and provides confidence scores for the detected dominant language, but doesn’t indicate percentage breakdowns for multilingual documents. For best results with long documents in multiple languages, split the text into smaller pieces and aggregate the results to estimate language percentages. It works best with at least 20 characters of text.
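The split-and-aggregate approach for multilingual documents can be sketched as follows. This is an illustration only: detect_language is a caller-supplied function standing in for a dominant-language call, toy_detector is a deliberately naive placeholder, and weighting each chunk by its character count is one possible aggregation choice:

```python
from collections import Counter


def estimate_language_shares(chunks, detect_language, min_chars=20):
    """Estimate each language's share of a document, weighted by chunk length."""
    totals = Counter()
    for chunk in chunks:
        if len(chunk) < min_chars:  # very short chunks are detected less reliably
            continue
        totals[detect_language(chunk)] += len(chunk)
    total_chars = sum(totals.values())
    if total_chars == 0:
        return {}
    return {language: count / total_chars for language, count in totals.items()}


def toy_detector(chunk):
    """Naive placeholder detector, for illustration only."""
    return "fr" if "bonjour" in chunk.lower() else "en"


chunks = [
    "Hello, how are you doing today?",
    "Bonjour, comment allez-vous aujourd'hui ?",
    "Short",
]
shares = estimate_language_shares(chunks, toy_detector)
print(shares)
```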

Language Detection on SageMaker Canvas

Personal information detection

You can also protect sensitive data using personal information detection with SageMaker Canvas. It can analyze text documents to automatically detect personally identifiable information (PII) entities, allowing you to locate sensitive data like names, addresses, dates of birth, phone numbers, email addresses, and more. It analyzes documents up to 100 KB and provides a confidence score for each detected entity so you can review and selectively redact the most sensitive information. For a list of entities detected, refer to Detecting PII entities.

PII Detection on SageMaker Canvas

Image data

SageMaker Canvas provides a visual, no-code interface that makes it straightforward for you to use computer vision capabilities by integrating with Amazon Rekognition for image analysis. For example, you can upload a dataset of images, use Amazon Rekognition to detect objects and scenes, and perform text detection to address a wide range of use cases. The visual interface and Amazon Rekognition integration make it possible for non-developers to harness advanced computer vision techniques.

Analyzing image data on SageMaker Canvas

Object detection in images

SageMaker Canvas uses Amazon Rekognition to detect labels (objects) in an image. You can upload the image from the SageMaker Canvas UI or use the Batch Prediction tab to select images stored in an S3 bucket. As shown in the following example, it can extract objects in the image such as clock tower, bus, buildings, and more. You can use the interface to search through the prediction results and sort them.

Object Detection in Images on SageMaker Canvas

Text detection in images

Extracting text from images is a very common use case. Now, you can perform this task with ease on SageMaker Canvas with no code. The text is extracted as line items, as shown in the following screenshot. Short phrases within the image are classified together and identified as a phrase.

Text Detection in Images on SageMaker Canvas

You can perform batch predictions by uploading a set of images, extracting text from all the images in a single batch job, and downloading the results as a CSV file. This solution is useful when you want to extract and detect text in images.

Document data

SageMaker Canvas offers a variety of ready-to-use solutions that solve your day-to-day document understanding needs. These solutions are powered by Amazon Textract. To view all the available options for documents, choose Ready-to-use models in the navigation pane and filter by Documents, as shown in the following screenshot.

Analyzing Document Data on SageMaker Canvas

Document analysis

Document analysis analyzes documents and forms for relationships among detected text. The operations return four categories of document extraction: raw text, forms, tables, and signatures. The solution’s capability of understanding the document structure gives you extra flexibility in the type of data you want to extract from the documents. The following screenshot is an example of what table detection looks like.

Document Analysis on SageMaker Canvas

This solution is able to understand layouts of complex documents, which is helpful when you need to extract specific information in your documents.

Identity document analysis

This solution is designed to analyze documents like personal identification cards, driver’s licenses, or other similar forms of identification. Information such as middle name, county, and place of birth, together with an individual confidence score for each extracted value, is returned for each identity document, as shown in the following screenshot.

Identity Document Analysis on SageMaker Canvas

There is an option to do batch prediction, whereby you can bulk upload sets of identification documents and process them as a batch job. This provides a quick and seamless way to transform identification document details into key-value pairs that can be used for downstream processes such as data analysis.

Expense analysis

Expense analysis is designed to analyze expense documents like invoices and receipts. The following screenshot is an example of what the extracted information looks like.

Expense Analysis on SageMaker Canvas

The results are returned as summary fields and line item fields. Summary fields are key-value pairs extracted from the document, and contain keys such as Grand Total, Due Date, and Tax. Line item fields refer to data that is structured as a table in the document. This is useful for extracting information from the document while retaining its layout.

Document queries

Document queries are designed for you to ask questions about your documents. This is a great solution to use when you have multi-page documents and you want to extract very specific answers from your documents. The following is an example of the types of questions you can ask and what the extracted answers look like.

Document Queries on SageMaker Canvas

The solution provides a straightforward interface for you to interact with your documents. This is helpful when you want to get specific details within large documents.

Conclusion

SageMaker Canvas provides a no-code environment to use ML with ease across various data types like text, images, and documents. The visual interface and integration with AWS services like Amazon Comprehend, Amazon Rekognition, and Amazon Textract eliminate the need for coding and data engineering. You can analyze text for sentiment, entities, languages, and PII. For images, object and text detection enables computer vision use cases. Finally, document analysis can extract text while preserving its layout for downstream processes. The ready-to-use solutions in SageMaker Canvas make it possible for you to harness advanced ML techniques to generate insights from both structured and unstructured data. If you’re interested in using no-code tools with ready-to-use ML models, try out SageMaker Canvas today. For more information, refer to Getting started with using Amazon SageMaker Canvas.


About the authors

Julia Ang is a Solutions Architect based in Singapore. She has worked with customers in a range of fields, from health and public sector to digital native businesses, to adopt solutions according to their business needs. She has also been supporting customers in Southeast Asia and beyond to use AI & ML in their businesses. Outside of work, she enjoys learning about the world through traveling and engaging in creative pursuits.

Loke Jun Kai is a Specialist Solutions Architect for AI/ML based in Singapore. He works with customers across ASEAN to architect machine learning solutions at scale in AWS. Jun Kai is an advocate for Low-Code No-Code machine learning tools. In his spare time, he enjoys being in nature.

Explore advanced techniques for hyperparameter optimization with Amazon SageMaker Automatic Model Tuning

Creating high-performance machine learning (ML) solutions relies on exploring and optimizing training parameters, also known as hyperparameters. Hyperparameters are the knobs and levers that we use to adjust the training process, such as learning rate, batch size, regularization strength, and others, depending on the specific model and task at hand. Exploring hyperparameters involves systematically varying the values of each parameter and observing the impact on model performance. Although this process requires additional efforts, the benefits are significant. Hyperparameter optimization (HPO) can lead to faster training times, improved model accuracy, and better generalization to new data.
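As a toy illustration of exploring hyperparameters by varying their values and observing the effect on a metric, the following sketch runs a plain random search. It does not use SageMaker AMT. The function toy_validation_score is a hypothetical stand-in for training a model and measuring validation performance, and the value ranges are arbitrary:

```python
import random


def toy_validation_score(learning_rate, batch_size):
    """Hypothetical stand-in for training a model and scoring it on validation data."""
    return 1.0 - abs(learning_rate - 0.1) - abs(batch_size - 64) / 1000


def random_search(num_trials, seed=0):
    """Sample hyperparameters at random and keep every (score, params) trial."""
    rng = random.Random(seed)
    trials = []
    for _ in range(num_trials):
        params = {
            "learning_rate": 10 ** rng.uniform(-3, 0),  # log-uniform in [0.001, 1]
            "batch_size": rng.choice([16, 32, 64, 128]),
        }
        trials.append((toy_validation_score(**params), params))
    best = max(trials, key=lambda trial: trial[0])
    return best, trials


best, trials = random_search(20)
print(best)
```

Keeping all trials, not only the best one, is what makes it possible to inspect how sensitive the score is to each hyperparameter afterward.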

We continue our journey from the post Optimize hyperparameters with Amazon SageMaker Automatic Model Tuning. We previously explored a single job optimization, visualized the outcomes for a SageMaker built-in algorithm, and learned about the impact of particular hyperparameter values. Beyond using HPO as a one-time optimization at the end of the model creation cycle, we can also use it across multiple steps in a conversational manner. Each tuning job helps us get closer to good performance, and we also learn how sensitive the model is to certain hyperparameters and can use this understanding to inform the next tuning job. We can revise the hyperparameters and their value ranges based on what we learned and therefore turn this optimization effort into a conversation. And in the same way that we as ML practitioners accumulate knowledge over these runs, Amazon SageMaker Automatic Model Tuning (AMT) with warm starts can carry the knowledge acquired in previous tuning jobs over to the next tuning job as well.

In this post, we run multiple HPO jobs with a custom training algorithm and different HPO strategies such as Bayesian optimization and random search. We also put warm starts into action and visually compare our trials to refine hyperparameter space exploration.

Advanced concepts of SageMaker AMT

In the next sections, we take a closer look at each of the following topics and show how SageMaker AMT can help you implement them in your ML projects:

  • Use custom training code and the popular ML framework Scikit-learn in SageMaker Training
  • Define custom evaluation metrics based on the logs for evaluation and optimization
  • Perform HPO using an appropriate strategy
  • Use warm starts to turn a single hyperparameter search into a dialog with our model
  • Apply advanced visualization techniques from our solution library to compare two HPO strategies and the results of tuning jobs

Whether you’re using the built-in algorithms used in our first post or your own training code, SageMaker AMT offers a seamless user experience for optimizing ML models. It provides key functionality that allows you to focus on the ML problem at hand while automatically keeping track of the trials and results. At the same time, it automatically manages the underlying infrastructure for you.

In this post, we move away from a SageMaker built-in algorithm and use custom code. We use a Random Forest from Scikit-learn, but we stick to the same ML task and dataset as in our first post, which is detecting handwritten digits. We cover the content of the Jupyter notebook 2_advanced_tuning_with_custom_training_and_visualizing.ipynb and invite you to run the code side by side as you read further.

Let’s dive deeper and discover how we can use custom training code, deploy it, and run it, while exploring the hyperparameter search space to optimize our results.

How to build an ML model and perform hyperparameter optimization

What does a typical process for building an ML solution look like? Although there are many possible use cases and a large variety of ML tasks out there, we suggest the following mental model for a stepwise approach:

  1. Understand your ML scenario at hand and select an algorithm based on the requirements. For example, you might want to solve an image recognition task using a supervised learning algorithm. In this post, we continue to use the handwritten image recognition scenario and the same dataset as in our first post.
  2. Decide which implementation of the algorithm in SageMaker Training you want to use. There are various options, both inside SageMaker and external. Additionally, you need to define which underlying metric best fits your task and should be optimized (such as accuracy, F1 score, or ROC). SageMaker supports four options depending on your needs and resources:
    • Use a pre-trained model via Amazon SageMaker JumpStart, which you can use out of the box or fine-tune.
    • Use one of the built-in algorithms for training and tuning, like XGBoost, as we did in our previous post.
    • Train and tune a custom model based on one of the major frameworks like Scikit-learn, TensorFlow, or PyTorch. AWS provides a selection of pre-made Docker images for this purpose. For this post, we use this option, which allows you to experiment quickly by running your own code on top of a pre-made container image.
    • Bring your own custom Docker image in case you want to use a framework or software that is not otherwise supported. This option requires the most effort, but also provides the highest degree of flexibility and control.
  3. Train the model with your data. Depending on the algorithm implementation from the previous step, this can be as simple as referencing your training data and running the training job, or it can require you to provide custom training code as well. In our case, we use some custom training code in Python based on Scikit-learn.
  4. Apply hyperparameter optimization (as a “conversation” with your ML model). After the training, you typically want to optimize the performance of your model by finding the most promising combination of values for your algorithm’s hyperparameters.

Depending on your ML algorithm and model size, the last step of hyperparameter optimization may turn out to be a bigger challenge than expected. The following questions are typical for ML practitioners at this stage and might sound familiar to you:

  • What kind of hyperparameters are impactful for my ML problem?
  • How can I effectively search a huge hyperparameter space to find those best-performing values?
  • How does the combination of certain hyperparameter values influence my performance metric?
  • Costs matter; how can I use my resources in an efficient manner?
  • What kind of tuning experiments are worthwhile, and how can I compare them?

It’s not easy to answer these questions, but there is good news. SageMaker AMT takes the heavy lifting from you, and lets you concentrate on choosing the right HPO strategy and value ranges you want to explore. Additionally, our visualization solution facilitates the iterative analysis and experimentation process to efficiently find well-performing hyperparameter values.

In the next sections, we build a digit recognition model from scratch using Scikit-learn and show all these concepts in action.

Solution overview

SageMaker offers some very handy features to train, evaluate, and tune our model. It covers all functionality of an end-to-end ML lifecycle, so we don’t even need to leave our Jupyter notebook.

In our first post, we used the SageMaker built-in algorithm XGBoost. For demonstration purposes, this time we switch to a Random Forest classifier because we can then show how to provide your own training code. We opted for providing our own Python script and using Scikit-learn as our framework. Now, how do we express that we want to use a specific ML framework? As we will see, SageMaker uses another AWS service in the background to retrieve a pre-built Docker container image for training—Amazon Elastic Container Registry (Amazon ECR).

We cover the following steps in detail, including code snippets and diagrams to connect the dots. As mentioned before, if you have the chance, open the notebook and run the code cells step by step to create the artifacts in your AWS environment. There is no better way of active learning.

  1. First, load and prepare the data. We use Amazon Simple Storage Service (Amazon S3) to upload a file containing our handwritten digits data.
  2. Next, prepare the training script and framework dependencies. We provide the custom training code in Python, reference some dependent libraries, and make a test run.
  3. Define the custom objective metrics. SageMaker lets us define a regular expression to extract the metrics we need from the container log files.
  4. Train the model using the Scikit-learn framework. By referencing a pre-built container image, we create a corresponding Estimator object and pass our custom training script.
  5. Define hyperparameters and run tuning jobs. AMT enables us to try out various HPO strategies. We concentrate on two of them for this post: random search and Bayesian search.
  6. Choose between SageMaker HPO strategies.
  7. Visualize, analyze, and compare tuning results. Our visualization package allows us to discover which strategy performs better and which hyperparameter values deliver the best performance based on our metrics.
  8. Continue the exploration of the hyperparameter space and warm start HPO jobs.

AMT takes care of scaling and managing the underlying compute infrastructure to run the various tuning jobs on Amazon Elastic Compute Cloud (Amazon EC2) instances. This way, you don’t need to provision instances, handle operating system and hardware issues, or aggregate log files on your own. The ML framework image is retrieved from Amazon ECR, and the model artifacts, including tuning results, are stored in Amazon S3. All logs and metrics are collected in Amazon CloudWatch for convenient access and further analysis if needed.

Prerequisites

Because this is a continuation of a series, it is recommended, but not necessarily required, to read our first post about SageMaker AMT and HPO. Apart from that, basic familiarity with ML concepts and Python programming is helpful. We also recommend following along with each step in the accompanying notebook from our GitHub repository while reading this post. The notebook can be run independently from the first one, but needs some code from subfolders. Make sure to clone the full repository in your environment as described in the README file.

Experimenting with the code and using the interactive visualization options greatly enhances your learning experience. So, please check it out.

Load and prepare the data

As a first step, we make sure the downloaded digits data we need for training is accessible to SageMaker. Amazon S3 allows us to do this in a safe and scalable way. Refer to the notebook for the complete source code and feel free to adapt it with your own data.

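import boto3
import pandas as pd
import sagemaker
from sklearn import datasets

# Create the AWS session and SageMaker client used throughout the notebook
boto_sess = boto3.Session()
sm = boto_sess.client('sagemaker')
sm_sess = sagemaker.session.Session(boto_session=boto_sess, sagemaker_client=sm)
BUCKET = sm_sess.default_bucket()
PREFIX = 'amt-visualize-demo'
s3_data_url = f's3://{BUCKET}/{PREFIX}/data'
# Load the digits dataset, append the label column, and write it to a local CSV file
digits         = datasets.load_digits()
digits_df      = pd.DataFrame(digits.data)
digits_df['y'] = digits.target
digits_df.to_csv('data/digits.csv', index=False)
!aws s3 sync data/ {s3_data_url} --exclude '*' --include 'digits.csv'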

The digits.csv file contains feature data and labels. Each digit is represented by pixel values in an 8×8 image, as depicted by the following image for the digit 4.
Digits Dataset from Scikit-learn
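To make the data layout more tangible, here is a minimal sketch of how one flat row of 64 pixel values maps back to an 8×8 image. The helper names row_to_grid and render and the example row are hypothetical and only serve as an illustration; the intensity scale of 0 to 16 is the one used by the Scikit-learn digits dataset.

```python
def row_to_grid(pixels, width=8):
    """Split a flat list of pixel values into rows of `width` values each."""
    if len(pixels) % width != 0:
        raise ValueError("number of pixels must be a multiple of width")
    return [pixels[i:i + width] for i in range(0, len(pixels), width)]


def render(grid, threshold=8):
    """Draw pixels at or above `threshold` as '#', all others as '.'."""
    return "\n".join(
        "".join("#" if value >= threshold else "." for value in row)
        for row in grid
    )


# Made-up example row: a single vertical stroke in column 4 (intensity 16).
row = [16 if i % 8 == 4 else 0 for i in range(64)]
grid = row_to_grid(row)
print(render(grid))
```

In the CSV file, each row additionally carries the label in its last column (y), so only the first 64 values would be passed to a helper like this.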

Prepare the training script and framework dependencies

Now that the data is stored in our S3 bucket, we can define our custom training script based on Scikit-learn in Python. SageMaker gives us the option to simply reference the Python file later for training. Any dependencies like the Scikit-learn or pandas libraries can be provided in two ways:

  • They can be specified explicitly in a requirements.txt file
  • They are pre-installed in the underlying ML container image, which is either provided by SageMaker or custom-built

Both options are generally considered standard ways of managing dependencies, so you might already be familiar with them. SageMaker supports a variety of ML frameworks in a ready-to-use managed environment. This includes many of the most popular data science and ML frameworks like PyTorch, TensorFlow, or Scikit-learn, as in our case. We don’t use an additional requirements.txt file, but feel free to add some libraries to try it out.

The code of our implementation contains a method called fit(), which creates a new classifier for the digit recognition task and trains it. In contrast to our first post where we used the SageMaker built-in XGBoost algorithm, we now use a RandomForestClassifier provided by the ML library sklearn. The call of the fit() method on the classifier object starts the training process using a subset (80%) of our CSV data:

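from pathlib import Path

import pandas as pd
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import precision_recall_fscore_support
from sklearn.model_selection import train_test_split

def fit(train_dir, n_estimators, max_depth, min_samples_leaf, max_features, min_weight_fraction_leaf):
    # Load the digits data; the last column holds the label
    digits = pd.read_csv(Path(train_dir)/'digits.csv')

    # Hold out 20% of the rows for validation
    Xtrain, Xtest, ytrain, ytest = train_test_split(digits.iloc[:, :-1], digits.iloc[:, -1], test_size=.2)

    m = RandomForestClassifier(n_estimators=n_estimators,
                               max_depth=max_depth,
                               min_samples_leaf=min_samples_leaf,
                               max_features=max_features,
                               min_weight_fraction_leaf=min_weight_fraction_leaf)
    m.fit(Xtrain, ytrain)
    predicted = m.predict(Xtest)
    pre, rec, f1, _ = precision_recall_fscore_support(ytest, predicted, pos_label=1, average='weighted')

    # Emit the metrics on stdout so that SageMaker can parse them from the logs
    print(f'pre: {pre:5.3f} rec: {rec:5.3f} f1: {f1:5.3f}')

    return m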

See the full script in our Jupyter notebook on GitHub.

Before you spin up container resources for the full training process, try running the script directly. This is a good practice to quickly ensure the code has no syntax errors, check that the dimensions of your data structures match, and catch other errors early on.

There are two ways to run your code locally. First, you can run it right away in the notebook, which also allows you to use the Python Debugger pdb:

# Running the code from within the notebook. It would then be possible to use the Python Debugger, pdb.
from train import fit
fit('data', 100, 10, 1, 'auto', 0.01)

Alternatively, run the train script from the command line in the same way you may want to use it in a container. This also supports setting various parameters and overwriting the default values as needed, for example:

!cd src && python train.py --train ../data/ --model-dir /tmp/ --n-estimators 100
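The command-line flags above suggest that the script parses its arguments on startup. The following is a minimal sketch of how such an entry point could look, not the actual code of train.py. The function name parse_args and the default values are assumptions; SM_CHANNEL_TRAIN and SM_MODEL_DIR are the environment variables that SageMaker sets inside the training container, and hyperparameters are passed to the script as command-line arguments.

```python
import argparse
import os


def parse_args(argv=None):
    """Parse data locations and hyperparameters (illustrative defaults)."""
    parser = argparse.ArgumentParser()

    # Inside a SageMaker container these locations come from environment
    # variables; the fallbacks allow a local test run.
    parser.add_argument("--train", default=os.environ.get("SM_CHANNEL_TRAIN", "data"))
    parser.add_argument("--model-dir", default=os.environ.get("SM_MODEL_DIR", "/tmp"))

    # One flag per hyperparameter; argparse exposes "--n-estimators"
    # as the attribute "n_estimators".
    parser.add_argument("--n-estimators", type=int, default=100)
    parser.add_argument("--max-depth", type=int, default=10)
    parser.add_argument("--min-samples-leaf", type=int, default=1)
    parser.add_argument("--max-features", default="auto")
    parser.add_argument("--min-weight-fraction-leaf", type=float, default=0.01)

    return parser.parse_args(argv)


args = parse_args(["--train", "../data/", "--model-dir", "/tmp/", "--n-estimators", "100"])
print(args.train, args.model_dir, args.n_estimators, args.max_depth)
```

With this structure, the same script can run locally with explicit flags and in the container, where SageMaker supplies the values.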

As output, you can see the first results for the model’s performance based on the objective metrics precision, recall, and F1-score. For example, pre: 0.970 rec: 0.969 f1: 0.969.

Not bad for such a quick training. But where did these numbers come from and what do we do with them?

Define custom objective metrics

Remember, our goal is to fully train and tune our model based on the objective metrics we consider relevant for our task. Because we use a custom training script, we need to define those metrics for SageMaker explicitly.

Our script emits the metrics precision, recall, and F1-score during training simply by using the print function:

print(f'pre: {pre:5.3f} rec: {rec:5.3f} f1: {f1:5.3f}')

The standard output is captured by SageMaker and sent to CloudWatch as a log stream. To retrieve the metric values and work with them later in SageMaker AMT, we need to provide some information on how to parse that output. We can achieve this by defining regular expression statements (for more information, refer to Monitor and Analyze Training Jobs Using Amazon CloudWatch Metrics):

metric_definitions = [
    {'Name': 'valid-precision',  'Regex': r'pre:\s+(-?[0-9.]+)'},
    {'Name': 'valid-recall',     'Regex': r'rec:\s+(-?[0-9.]+)'},
    {'Name': 'valid-f1',         'Regex': r'f1:\s+(-?[0-9.]+)'}]

Let’s walk through the first metric definition in the preceding code together. SageMaker looks for output in the log that starts with pre: and is followed by one or more whitespace characters and then a number. We want to extract that number, which is why we enclose it in parentheses (a capture group). Every time SageMaker finds a value like that, it turns it into a CloudWatch metric with the name valid-precision.
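To check such expressions before starting a training job, you can apply them locally to a sample log line. The following sketch uses Python’s re module and assumes the whitespace class is written as \s in each pattern. The helper extract_metrics is hypothetical and only mimics the parsing step; it is not SageMaker’s implementation.

```python
import re

metric_definitions = [
    {"Name": "valid-precision", "Regex": r"pre:\s+(-?[0-9.]+)"},
    {"Name": "valid-recall", "Regex": r"rec:\s+(-?[0-9.]+)"},
    {"Name": "valid-f1", "Regex": r"f1:\s+(-?[0-9.]+)"},
]


def extract_metrics(log_line, definitions):
    """Return {metric name: value} for every definition that matches the line."""
    found = {}
    for definition in definitions:
        match = re.search(definition["Regex"], log_line)
        if match:
            # group(1) is the number captured by the parentheses
            found[definition["Name"]] = float(match.group(1))
    return found


# Sample line in the format printed by the training script
log_line = "pre: 0.970 rec: 0.969 f1: 0.969"
metrics = extract_metrics(log_line, metric_definitions)
print(metrics)
```

If a pattern doesn’t match, the corresponding metric is simply missing from the result, which is a quick way to spot a typo in an expression.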

Train the model using the Scikit-learn framework

After we create our training script train.py and instruct SageMaker on how to monitor the metrics within CloudWatch, we define a SageMaker Estimator object. It initiates the training job and uses the instance type we specify. This instance type can differ from the one your Amazon SageMaker Studio notebook runs on, because SageMaker runs training (and inference) jobs on compute instances that are separate from your notebook. This allows you to continue working in your notebook while the jobs run in the background.

The parameter framework_version refers to the Scikit-learn version we use for our training job. Alternatively, we can pass image_uri to the estimator. You can check whether your favorite framework or ML library is available as a pre-built SageMaker Docker image and use it as is or with extensions.

Moreover, we can run SageMaker training jobs on EC2 Spot Instances by setting use_spot_instances to True. Spot Instances use spare EC2 capacity and can save up to 90% of costs. In return, you need to be flexible about when the training jobs run.

from sagemaker import get_execution_role
from sagemaker.sklearn.estimator import SKLearn

estimator = SKLearn(
    'train.py',
    source_dir='src',
    role=get_execution_role(),
    instance_type='ml.m5.large',
    instance_count=1,
    framework_version='0.23-1',
    metric_definitions=metric_definitions,

    # Uncomment the following three lines to use Managed Spot Training
    # use_spot_instances= True,
    # max_run=  60 * 60 * 24,
    # max_wait= 60 * 60 * 24,

    hyperparameters = {'n-estimators': 100,
                       'max-depth': 10,
                       'min-samples-leaf': 1,
                       'max-features': 'auto',
                       'min-weight-fraction-leaf': 0.1}
)

After the Estimator object is set up, we start the training by calling the fit() function, supplying the path to the training dataset on Amazon S3. We can use this same method to provide validation and test data. We set the wait parameter to True so we can use the trained model in the subsequent code cells.

estimator.fit({'train': s3_data_url}, wait=True)

Define hyperparameters and run tuning jobs

So far, we have trained the model with one set of hyperparameter values. But were those values good? Or could we look for better ones? Let’s use the HyperparameterTuner class to run a systematic search over the hyperparameter space. How do we search this space with the tuner? The necessary parameters are the objective metric name and objective type that guide the optimization. The optimization strategy is another key argument for the tuner because it determines how the search space is explored. There are four strategies to choose from:

  • Grid search
  • Random search
  • Bayesian optimization (default)
  • Hyperband

We further describe these strategies and equip you with some guidance to choose one later in this post.

Before we define and run our tuner object, let’s recap our understanding from an architecture perspective. We covered the architectural overview of SageMaker AMT in our last post and reproduce an excerpt of it here for convenience.

Amazon SageMaker Automatic Model Tuning Architecture

We can choose which hyperparameters we want to tune and which to leave static. For the tunable ones, we provide hyperparameter_ranges that define the values the tuner is allowed to explore. Because we use a Random Forest classifier, we take the hyperparameters from the Scikit-learn Random Forest documentation.

We also limit resources with the maximum number of training jobs and parallel training jobs the tuner can use. We will see how these limits help us compare the results of various strategies with each other.
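The tuner configuration refers to a dictionary of ranges named hpt_ranges. A minimal sketch of such a definition with the parameter classes from the SageMaker Python SDK could look as follows. The bounds and the list of categories are illustrative assumptions; the accompanying notebook may use different values.

```python
from sagemaker.tuner import CategoricalParameter, ContinuousParameter, IntegerParameter

# Illustrative ranges only (assumed, not taken from the notebook).
# The keys must match the hyperparameter names the training script accepts.
hpt_ranges = {
    'n-estimators': IntegerParameter(1, 200),
    'max-depth': IntegerParameter(1, 20),
    'min-samples-leaf': IntegerParameter(1, 10),
    'min-weight-fraction-leaf': ContinuousParameter(0.01, 0.5),
    'max-features': CategoricalParameter(['auto', 'sqrt', 'log2']),
}
```

Hyperparameters that are set on the Estimator but don’t appear in this dictionary stay static across all training jobs of the tuning job.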

tuner_parameters = {
    'estimator': estimator,
    'base_tuning_job_name': 'random',
    'metric_definitions': metric_definitions,
    'objective_metric_name': 'valid-f1',
    'objective_type': 'Maximize',
    'hyperparameter_ranges': hpt_ranges,
    'strategy': 'Random',
    'max_jobs': n, # 50
    'max_parallel_jobs': k # 2
    } 

Similar to the Estimator’s fit function, we start a tuning job by calling the tuner’s fit:

random_tuner = HyperparameterTuner(**tuner_parameters)
random_tuner.fit({'train': s3_data_url}, wait=False)

This is all we have to do to let SageMaker run the training jobs (n=50) in the background, each using a different set of hyperparameters. We explore the results later in this post. But before that, let’s start another tuning job, this time applying the Bayesian optimization strategy. We will compare both strategies visually after their completion.

tuner_parameters['strategy']             = 'Bayesian'
tuner_parameters['base_tuning_job_name'] = 'bayesian'
bayesian_tuner = HyperparameterTuner(**tuner_parameters)
bayesian_tuner.fit({'train': s3_data_url}, wait=False)

Note that both tuner jobs can run in parallel because SageMaker orchestrates the required compute instances independently of each other. That’s quite helpful for practitioners who experiment with different approaches at the same time, like we do here.

Choose between SageMaker HPO strategies

When it comes to tuning strategies, you have a few options with SageMaker AMT: grid search, random search, Bayesian optimization, and Hyperband. These strategies determine how the automatic tuning algorithms explore the specified ranges of hyperparameters.

Random search is pretty straightforward. It randomly selects combinations of values from the specified ranges and can be run in a sequential or parallel manner. It’s like throwing darts blindfolded, hoping to hit the target. We have started with this strategy, but will the results improve with another one?
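The sampling idea behind random search fits into a few lines of Python. The following sketch is not how SageMaker AMT implements it; the search_space structure, its bounds, and the function sample_config are hypothetical and only illustrate that every configuration is drawn independently of all others.

```python
import random

# Hypothetical search space: ("int" | "float", low, high) for numeric
# ranges and ("choice", options) for categorical ones.
search_space = {
    "n-estimators": ("int", 1, 200),
    "max-depth": ("int", 1, 20),
    "min-weight-fraction-leaf": ("float", 0.0, 0.5),
    "max-features": ("choice", ["auto", "sqrt", "log2"]),
}


def sample_config(space, rng):
    """Draw one value per hyperparameter, independently of earlier draws."""
    config = {}
    for name, spec in space.items():
        kind = spec[0]
        if kind == "int":
            config[name] = rng.randint(spec[1], spec[2])  # both bounds inclusive
        elif kind == "float":
            config[name] = rng.uniform(spec[1], spec[2])
        elif kind == "choice":
            config[name] = rng.choice(spec[1])
        else:
            raise ValueError(f"unknown range kind: {kind}")
    return config


rng = random.Random(42)  # fixed seed so the sketch is reproducible
trials = [sample_config(search_space, rng) for _ in range(5)]
for trial in trials:
    print(trial)
```

Because no draw depends on the outcome of another, all of these configurations could be evaluated at the same time.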

Bayesian optimization takes a different approach than random search. It considers the history of previous selections and chooses values that are likely to yield the best results. To learn from previous explorations, a new training job can only be started after earlier ones have finished. Makes sense, right? In this way, Bayesian optimization depends on the previous runs. But do you see which HPO strategy allows for higher parallelization?
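A bit of arithmetic illustrates the trade-off between parallelism and learning. The sketch below assumes, as a simplification, that training jobs are launched in fixed batches of max_parallel_jobs; in practice, a new job can start as soon as a slot becomes free. The function name feedback_rounds is hypothetical.

```python
import math


def feedback_rounds(max_jobs, max_parallel_jobs):
    """Number of sequential batches when jobs run in groups of max_parallel_jobs."""
    if max_jobs < 1 or max_parallel_jobs < 1:
        raise ValueError("both arguments must be positive")
    return math.ceil(max_jobs / max_parallel_jobs)


# With 50 training jobs in total, more parallelism means fewer occasions
# at which a sequential strategy can use earlier results.
for parallel in (1, 2, 10, 50):
    print(parallel, feedback_rounds(50, parallel))
```

With our setting of 50 jobs and 2 parallel jobs, this gives 25 rounds. Running all 50 jobs at once would leave a single round, so a strategy that relies on earlier results would have nothing to learn from, whereas random search wouldn’t be affected.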

Hyperband is an interesting one! It uses a multi-fidelity strategy, which means it dynamically allocates resources to the most promising training jobs and stops those that are underperforming. Therefore, Hyperband is computationally efficient with resources, learning from previous training jobs. After stopping the underperforming configurations, a new configuration starts, and its values are chosen randomly.
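The idea of stopping weak candidates early can be illustrated with successive halving, a building block of Hyperband. The following is a toy sketch under strong assumptions and not SageMaker’s implementation: the function toy_score with its made-up learning curve, the quality values, and all names are hypothetical.

```python
def successive_halving(configs, evaluate, min_budget=1, eta=2):
    """Keep the best 1/eta of the configurations per round while multiplying
    the budget by eta, until a single configuration is left."""
    if not configs:
        raise ValueError("at least one configuration is required")
    survivors = list(configs)
    budget = min_budget
    while len(survivors) > 1:
        ranked = sorted(survivors, key=lambda c: evaluate(c, budget), reverse=True)
        survivors = ranked[:max(1, len(ranked) // eta)]  # stop the rest early
        budget *= eta  # survivors get more resources in the next round
    return survivors[0]


def toy_score(config, budget):
    # Made-up learning curve: the score approaches config["quality"]
    # as the budget grows.
    return config["quality"] * (1 - 1 / (budget + 1))


qualities = [0.55, 0.91, 0.72, 0.64, 0.83, 0.60, 0.78, 0.69]
candidates = [{"id": i, "quality": q} for i, q in enumerate(qualities)]
best = successive_halving(candidates, toy_score)
print(best)
```

In this toy setup, eight candidates shrink to four, two, and finally one, and only the survivors receive a larger budget in each round.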

Depending on your needs and the nature of your model, you can choose between random search, Bayesian optimization, or Hyperband as your tuning strategy. Each has its own approach and advantages, so it’s important to consider which one works best for your ML exploration. The good news for ML practitioners is that you can select the best HPO strategy by visually comparing the impact of each trial on the objective metric. In the next section, we see how to visually identify the impact of different strategies.

Visualize, analyze, and compare tuning results

When our tuning jobs are complete, it gets exciting. What results do they deliver? What kind of boost can we expect on our metric compared to our base model? What are the best-performing hyperparameters for our use case?

A quick and straightforward way to view the HPO results is by visiting the SageMaker console. Under Hyperparameter tuning jobs, we can see (per tuning job) the combination of hyperparameter values that have been tested and delivered the best performance as measured by our objective metric (valid-f1).

Metrics for Hyperparameter tuning jobs

Is that all you need? As an ML practitioner, you are probably interested not only in those values. You also want to learn more about the inner workings of your model to explore its full potential and strengthen your intuition with empirical feedback.

A good visualization tool can greatly help you understand the improvement by HPO over time and get empirical feedback on design decisions of your ML model. It shows the impact of each individual hyperparameter on your objective metric and provides guidance to further optimize your tuning results.

We use the amtviz custom visualization package to visualize and analyze tuning jobs. It’s straightforward to use and provides helpful features. We demonstrate its benefit by interpreting some individual charts, and finally comparing random search side by side with Bayesian optimization.

First, let’s create a visualization for random search. We can do this by calling visualize_tuning_job() from amtviz and passing our first tuner object as an argument:

from amtviz import visualize_tuning_job
visualize_tuning_job(random_tuner, advanced=True, trials_only=True)

You will see a couple of charts, but let’s take it step by step. The first scatter plot from the output looks like the following and already gives us some visual clues we wouldn’t recognize in any table.

Hyperparameter Optimization Job Results

Each dot represents the performance of an individual training job (our objective valid-f1 on the y-axis) based on its start time (x-axis), produced by a specific set of hyperparameters. Therefore, we look at the performance of our model as it progresses over the duration of the tuning job.

The dotted line highlights the best result found so far and indicates improvement over time. The best two training jobs achieved an F1 score of around 0.91.
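The best-result-so-far line is a running maximum over the objective values in the order in which the training jobs started. Here is a minimal sketch with made-up scores, not the actual results of our tuning job:

```python
from itertools import accumulate

# Made-up valid-f1 scores, ordered by the start time of the training jobs
scores = [0.62, 0.81, 0.75, 0.88, 0.85, 0.91, 0.79]

# Running maximum: the best score seen up to and including each job
best_so_far = list(accumulate(scores, max))
print(best_so_far)
```

The line only moves up when a training job beats all earlier ones, which is why it looks like a staircase.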

Besides the dotted line showing the cumulative progress, do you see a trend in the chart?

Probably not. And this is expected, because we’re viewing the results of the random HPO strategy. Each training job was run using a different but randomly selected set of hyperparameters. If we continued our tuning job (or ran another one with the same setting), we would probably see some better results over time, but we can’t be sure. Randomness is a tricky thing.

The next charts help you gauge the influence of hyperparameters on the overall performance. All hyperparameters are visualized, but for the sake of brevity, we focus on two of them: n-estimators and max-depth.

Hyperparameter Jobs Details

Our top two training jobs were using n-estimators of around 20 and 80, and max-depth of around 10 and 18, respectively. The exact hyperparameter values are displayed via tooltip for each dot (training job). They are even dynamically highlighted across all charts and give you a multi-dimensional view! Did you see that? Each hyperparameter is plotted against the objective metric, as a separate chart.

Now, what kind of insights do we get about n-estimators?

Based on the left chart, it seems that very low value ranges (below 10) more often deliver poor results compared to higher values. Therefore, higher values may help your model to perform better—interesting.

In contrast, the correlation of the max-depth hyperparameter to our objective metric is rather low. We can’t clearly tell which value ranges are performing better from a general perspective.

In summary, random search can help you find a well-performing set of hyperparameters even in a relatively short amount of time. Also, it isn’t biased towards a good solution but gives a balanced view of the search space. Your resource utilization, however, might not be very efficient. It continues to run training jobs with hyperparameters in value ranges that are known to deliver poor results.

Let’s examine the results of our second tuning job using Bayesian optimization. We can use amtviz to visualize the results in the same way as we did so far for the random search tuner. Or, even better, we can use the capability of the function to compare both tuning jobs in a single set of charts. Quite handy!

visualize_tuning_job([random_tuner, bayesian_tuner], advanced=True, trials_only=True)

Hyperparameter Optimization Job Bayesian VS Random

There are more dots now because we visualize the results of all training jobs for both the random search (orange dots) and the Bayesian optimization (blue dots). On the right side, you can see a density chart visualizing the distribution of all F1-scores. A majority of the training jobs achieved results in the upper part of the F1 scale (over 0.6), which is good!

What is the key takeaway here? The scatter plot clearly shows the benefit of Bayesian optimization. It delivers better results over time because it can learn from previous runs. That’s why we achieved significantly better results using Bayesian compared to random (0.967 vs. 0.919) with the same number of training jobs.

There is even more you can do with amtviz. Let’s drill in.

If you give SageMaker AMT the instruction to run a larger number of jobs for tuning, seeing many trials at once can get messy. That’s one of the reasons why we made these charts interactive. You can click and drag on every hyperparameter scatter plot to zoom in to certain value ranges and refine your visual interpretation of the results. All other charts are automatically updated. That’s pretty helpful, isn’t it? See the next charts as an example and try it for yourself in your notebook!

Hyperparameter Optimization Job Visualisation Features

As a tuning maximalist, you may also decide that running another hyperparameter tuning job could further improve your model performance. But this time, a more specific range of hyperparameter values can be explored because you already know (roughly) where to expect better results. For example, you may choose to focus on values between 100 and 200 for n-estimators, as shown in the chart. This lets AMT focus on the most promising training jobs and increases your tuning efficiency.
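Besides reading the range off the chart, you could derive a candidate range from the best trials programmatically. The following sketch uses made-up trial results; the helper narrowed_range and the choice of the top three trials are assumptions for illustration only.

```python
def narrowed_range(trials, name, metric="valid-f1", top_k=3):
    """Return (min, max) of hyperparameter `name` among the top_k trials."""
    if not trials:
        raise ValueError("at least one trial is required")
    best = sorted(trials, key=lambda t: t[metric], reverse=True)[:top_k]
    values = [t[name] for t in best]
    return min(values), max(values)


# Made-up results of earlier training jobs
trials = [
    {"n-estimators": 8, "valid-f1": 0.71},
    {"n-estimators": 45, "valid-f1": 0.88},
    {"n-estimators": 120, "valid-f1": 0.95},
    {"n-estimators": 160, "valid-f1": 0.96},
    {"n-estimators": 190, "valid-f1": 0.94},
    {"n-estimators": 20, "valid-f1": 0.80},
]
low, high = narrowed_range(trials, "n-estimators")
print(low, high)
```

A range derived this way is only a starting point. Narrowing it too aggressively can exclude regions that were simply not sampled often enough.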

To sum it up, amtviz provides you with a rich set of visualization capabilities that allow you to better understand the impact of your model’s hyperparameters on performance and enable smarter decisions in your tuning activities.

Continue the exploration of the hyperparameter space and warm start HPO jobs

We have seen that AMT helps us explore the hyperparameter search space efficiently. But what if we need multiple rounds of tuning to iteratively improve our results? As mentioned in the beginning, we want to establish an optimization feedback cycle—our “conversation” with the model. Do we need to start from scratch every time?

Let’s look into the concept of running a warm start hyperparameter tuning job. Instead of starting a new tuning job from scratch, it reuses what has been learned in the previous HPO runs. This helps us be more efficient with our tuning time and compute resources, and we can iterate further on top of our previous results. To use warm starts, we create a WarmStartConfig and specify warm_start_type as IDENTICAL_DATA_AND_ALGORITHM. This means that we change the hyperparameter values but we don’t change the data or algorithm. We tell AMT to transfer the previous knowledge to our new tuning job.

By referring to our previous Bayesian optimization and random search tuning jobs as parents, we can use them both for the warm start:

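from sagemaker.tuner import WarmStartConfig, WarmStartTypes

# The parent names are the names of the two completed tuning jobs
warm_start_config = WarmStartConfig(warm_start_type=WarmStartTypes.IDENTICAL_DATA_AND_ALGORITHM,
                                    parents=[bayesian_tuner_name, random_tuner_name])
tuner_parameters['warm_start_config'] = warm_start_config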

To see the benefit of using warm starts, refer to the following charts. These are generated by amtviz in a similar way as we did earlier, but this time we have added another tuning job based on a warm start.

Hyperparameter Optimization Job Warmstart

In the left chart, we can observe that new tuning jobs mostly lie in the upper-right corner of the performance metric graph (see dots marked in orange). The warm start has indeed reused the previous results, which is why those data points are in the top results for F1 score. This improvement is also reflected in the density chart on the right.

In other words, AMT automatically selects promising sets of hyperparameter values based on its knowledge from previous trials. This is shown in the next chart. For example, the algorithm would test a low value for n-estimators less often because these are known to produce poor F1 scores. We don’t waste any resources on that, thanks to warm starts.

Hyperparameter Optimization Visualised Jobs

Clean up

To avoid incurring unwanted costs when you’re done experimenting with HPO, you must remove all files in your S3 bucket with the prefix amt-visualize-demo and also shut down SageMaker Studio resources.

Run the following code in your notebook to remove all S3 files from this post:

!aws s3 rm s3://{BUCKET}/amt-visualize-demo --recursive

If you wish to keep the datasets or the model artifacts, you may modify the prefix in the code to amt-visualize-demo/data to only delete the data or amt-visualize-demo/output to only delete the model artifacts.

Conclusion

We have learned how the art of building ML solutions involves exploring and optimizing hyperparameters. Adjusting those knobs and levers is a demanding yet rewarding process that leads to faster training times, improved model accuracy, and overall better ML solutions. The SageMaker AMT functionality helps us run multiple tuning jobs and warm start them, and provides data points for further review, visual comparison, and analysis.

In this post, we looked into HPO strategies that we use with SageMaker AMT. We started with random search, a straightforward but performant strategy where hyperparameters are randomly sampled from a search space. Next, we compared the results to Bayesian optimization, which uses probabilistic models to guide the search for optimal hyperparameters. After we identified a suitable HPO strategy and good hyperparameter value ranges through initial trials, we showed how to use warm starts to streamline future HPO jobs.

You can explore the hyperparameter search space by comparing quantitative results. We have suggested the side-by-side visual comparison and provided the necessary package for interactive exploration. Let us know in the comments how helpful it was for you on your hyperparameter tuning journey!


About the authors

Ümit Yoldas is a Senior Solutions Architect with Amazon Web Services. He works with enterprise customers across industries in Germany. He’s driven to translate AI concepts into real-world solutions. Outside of work, he enjoys time with family, savoring good food, and pursuing fitness.

Elina Lesyk is a Solutions Architect located in Munich. She is focusing on enterprise customers from the financial services industry. In her free time, you can find Elina building applications with generative AI at some IT meetups, driving a new idea on fixing climate change fast, or running in the forest to prepare for a half-marathon with a typical deviation from the planned schedule.

Mariano Kamp is a Principal Solutions Architect with Amazon Web Services. He works with banks and insurance companies in Germany on machine learning. In his spare time, Mariano enjoys hiking with his wife.
