NeILF++: Inter-Reflectable Light Fields for Geometry and Material Estimation

We present a novel differentiable rendering framework for joint geometry, material, and lighting estimation from multi-view images. In contrast to previous methods which assume a simplified environment map or co-located flashlights, in this work, we formulate the lighting of a static scene as one neural incident light field (NeILF) and one outgoing neural radiance field (NeRF). The key insight of the proposed method is the union of the incident and outgoing light fields through physically-based rendering and inter-reflections between surfaces, making it possible to disentangle the scene…

Apple Machine Learning Research

MediaPipe FaceStylizer: On-device real-time few-shot face stylization



In recent years, we have witnessed rising interest across consumers and researchers in integrated augmented reality (AR) experiences using real-time face feature generation and editing functions in mobile applications, including short videos, virtual reality, and gaming. As a result, there is a growing demand for lightweight, yet high-quality face generation and editing models, which are often based on generative adversarial network (GAN) techniques. However, the majority of GAN models suffer from high computational complexity and the need for a large training dataset. In addition, it is also important to employ GAN models responsibly.

In this post, we introduce MediaPipe FaceStylizer, an efficient design for few-shot face stylization. It addresses the model complexity and data efficiency challenges described above and is guided by Google’s responsible AI Principles. The model consists of a face generator and a face encoder. The encoder performs GAN inversion, mapping an input image to a latent code for the generator. We introduce a mobile-friendly synthesis network for the face generator. At each level of the generator, an auxiliary head converts features to RGB, so that high-quality images are generated from coarse to fine granularities. We also carefully design the loss functions for these auxiliary heads and combine them with common GAN loss functions to distill the student generator from the teacher StyleGAN model. The result is a lightweight model that maintains high generation quality. The proposed solution is available in open source through MediaPipe. Users can fine-tune the generator to learn a style from one or a few images using MediaPipe Model Maker. They can then deploy the customized model to on-device face stylization applications using MediaPipe FaceStylizer.

Few-shot on-device face stylization

An end-to-end pipeline

Our goal is to build a pipeline that lets users adapt MediaPipe FaceStylizer to different styles by fine-tuning the model with a few examples. To enable this, we built the pipeline from a GAN inversion encoder and an efficient face generator model (see below). The encoder and generator can then be adapted to different styles via a few-shot learning process. The user first sends a single style image, or a few similar ones, to MediaPipe Model Maker to fine-tune the model. Fine-tuning freezes the encoder and updates only the generator. The training process samples multiple latent codes close to the encoding of the input style images and uses them as input to the generator. The generator is then trained to reconstruct an image of a person’s face in the style of the input style image. It does this by optimizing a joint adversarial loss function that also accounts for style and content. Through this fine-tuning process, MediaPipe FaceStylizer adapts to a customized style that approximates the user’s input. It can then be applied to stylize test images of real human faces.
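The sampling step can be sketched as follows. The post does not say how "close" is defined. This sketch assumes each training latent is a randomly chosen style encoding plus isotropic Gaussian noise. The noise scale `sigma`, the 512-dimensional latent, and the function name `sample_style_latents` are hypothetical choices for illustration.

```python
import numpy as np

def sample_style_latents(style_codes, num_samples, sigma=0.1, seed=0):
    """Hypothetical sampler: draw latents near the encoder outputs of the style images.

    style_codes: (S, d) array, one encoded latent per style image.
    Returns (num_samples, d) latents. Each is a randomly chosen style code
    perturbed by Gaussian noise with standard deviation sigma (an assumption).
    """
    rng = np.random.default_rng(seed)
    style_codes = np.atleast_2d(style_codes)
    picks = rng.integers(0, style_codes.shape[0], size=num_samples)
    noise = rng.normal(scale=sigma, size=(num_samples, style_codes.shape[1]))
    return style_codes[picks] + noise

style_latents = np.zeros((2, 512))  # stand-in for encoder outputs of two style images
latents = sample_style_latents(style_latents, num_samples=16)
print(latents.shape)  # (16, 512)
```

Setting `sigma=0.0` reduces the sampler to simply repeating the style encodings. A larger `sigma` exposes the generator to more variation around the target style.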

Generator: BlazeStyleGAN

The StyleGAN model family has been widely adopted for face generation and various face editing tasks. To support efficient on-device face generation, we based the design of our generator on StyleGAN. This generator, which we call BlazeStyleGAN, is similar to StyleGAN in that it also contains a mapping network and a synthesis network. However, the synthesis network of StyleGAN is the major contributor to the model’s high computational complexity, so we designed and employed a more efficient synthesis network. The improved efficiency and generation quality are achieved by:

  1. Reducing the latent feature dimension in the synthesis network to a quarter of the resolution of the counterpart layers in the teacher StyleGAN,
  2. Designing multiple auxiliary heads to transform the downscaled feature to the image domain to form a coarse-to-fine image pyramid to evaluate the perceptual quality of the reconstruction, and
  3. Skipping all but the final auxiliary head at inference time.

With the newly designed architecture, we train the BlazeStyleGAN model by distilling it from a teacher StyleGAN model. We use a multi-scale perceptual loss and adversarial loss in the distillation to transfer the high fidelity generation capability from the teacher model to the student BlazeStyleGAN model and also to mitigate the artifacts from the teacher model.
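The coarse-to-fine supervision described above can be sketched as a multi-scale loss. Each auxiliary head's output is compared with the teacher image downsampled to the same resolution. The sketch makes three simplifying assumptions:

- A plain L1 distance stands in for the perceptual loss, which would normally compare features from a pretrained network.
- The adversarial term is omitted.
- The pyramid is assumed to be ordered finest first, with each level half the size of the previous one.

All function names here are hypothetical.

```python
import numpy as np

def downsample2x(img):
    """Average-pool an (H, W, C) image by a factor of 2 (H and W must be even)."""
    h, w, c = img.shape
    return img.reshape(h // 2, 2, w // 2, 2, c).mean(axis=(1, 3))

def multiscale_distillation_loss(student_pyramid, teacher_image):
    """Sum of per-level L1 distances between the student's auxiliary-head outputs
    (finest level first) and the teacher image downsampled to each level's size.

    L1 is a stand-in for the perceptual loss used in the actual distillation.
    """
    target = teacher_image
    total = 0.0
    for level, student_img in enumerate(student_pyramid):
        if level > 0:
            target = downsample2x(target)  # move the teacher target one level coarser
        total += float(np.abs(student_img - target).mean())
    return total

rng = np.random.default_rng(0)
teacher = rng.random((8, 8, 3))
pyramid = [teacher, downsample2x(teacher), downsample2x(downsample2x(teacher))]
print(multiscale_distillation_loss(pyramid, teacher))  # 0.0: perfect match at every level
```

Supervising every level gives the coarse heads a direct training signal. This matters because only the final head is kept at inference time.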

More details of the model architecture and training scheme can be found in our paper.

Visual comparison between face samples generated by StyleGAN and BlazeStyleGAN. The images in the first row are generated by the teacher StyleGAN. The images in the second row are generated by the student BlazeStyleGAN. The faces generated by BlazeStyleGAN have visual quality similar to those generated by the teacher model. Some results show that the student BlazeStyleGAN suppresses artifacts of the teacher model through distillation.

The figure above shows some sample results of our BlazeStyleGAN. Compare these with the face images generated by the teacher StyleGAN model (top row). The images generated by the student BlazeStyleGAN (bottom row) maintain high visual quality. Owing to the loss function design in our distillation, they also reduce artifacts produced by the teacher.

An encoder for efficient GAN inversion

To support image-to-image stylization, we also introduced an efficient GAN inversion encoder that maps input images to the latent space of the generator. The encoder is built on a MobileNet V2 backbone and trained with natural face images. Its loss combines two terms. The first is an image perceptual quality loss, which measures content difference, style similarity, and embedding distance. The second is an L1 loss between the input images and the reconstructed images.

On-device performance

The following table lists model complexity in terms of parameter count and FLOPs. The teacher StyleGAN has 33.2M parameters. The BlazeStyleGAN generator significantly reduces model complexity, with only 2.01M parameters and 1.28G FLOPs at an output resolution of 256×256. StyleGAN-1024 generates 1024×1024 images. Compared with it, BlazeStyleGAN-1024 reduces both model size and computational complexity by roughly 94% (2.07M vs. 33.17M parameters; 4.70G vs. 74.3G FLOPs). It shows no notable quality difference and can even suppress artifacts from the teacher StyleGAN model.

Model     Resolution (px)     #Params (M)     FLOPs (G)
StyleGAN     1024     33.17     74.3
BlazeStyleGAN     1024     2.07     4.70
BlazeStyleGAN     512     2.05     1.57
BlazeStyleGAN     256     2.01     1.28
Encoder     256     1.44     0.60
Model complexity measured by parameter numbers and FLOPs.
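The relative reductions quoted for the 1024×1024 models can be checked directly against the table. The short calculation below uses only the table's values. It gives roughly 94% for both parameters and FLOPs. The helper name is ours, chosen for illustration.

```python
def relative_reduction(before, after):
    """Fraction by which a quantity shrinks when going from `before` to `after`."""
    return 1.0 - after / before

# Parameter counts (M) and FLOPs (G) copied from the table above.
stylegan_1024 = {"params_m": 33.17, "flops_g": 74.3}
blaze_1024 = {"params_m": 2.07, "flops_g": 4.70}

for key in ("params_m", "flops_g"):
    reduction = relative_reduction(stylegan_1024[key], blaze_1024[key])
    print(f"{key}: {reduction:.1%} reduction")  # params_m: 93.8%, flops_g: 93.7%
```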

We benchmarked the inference time of MediaPipe FaceStylizer on various high-end mobile devices. The table below reports per-device latency for BlazeStyleGAN-256 and the 256×256 encoder. In our benchmarks, both BlazeStyleGAN-256 and BlazeStyleGAN-512 achieved real-time performance on all GPU devices. The generator runs in under 10 ms on a high-end phone’s GPU. BlazeStyleGAN-256 also achieves real-time performance on the CPU of iOS devices.

Device     BlazeStyleGAN-256 (ms)     Encoder-256 (ms)
iPhone 11     12.14     11.48
iPhone 12     11.99     12.25
iPhone 13 Pro     7.22     5.41
Pixel 6     12.24     11.23
Samsung Galaxy S10     17.01     12.70
Samsung Galaxy S20     8.95     8.20
Latency benchmark of the BlazeStyleGAN, face encoder, and the end-to-end pipeline on various mobile devices.

Fairness evaluation

The model has been trained on a highly diverse dataset of human faces and is expected to be fair across different human faces. The fairness evaluation shows that the model performs well, with balanced results across gender, skin tone, and age.

Face stylization visualization

The following figure shows some face stylization results. The images in the top row (orange boxes) are the style images used to fine-tune the model. The images in the left column (green boxes) are the natural face images used for testing. The 2×4 grid of images is the output of MediaPipe FaceStylizer. Each output blends a natural face from the left-most column with the corresponding face style from the top row. The results demonstrate that our solution achieves high-quality face stylization for several popular styles.

Sample results of our MediaPipe FaceStylizer.

MediaPipe Solutions

MediaPipe FaceStylizer will be released publicly in MediaPipe Solutions. Users can use MediaPipe Model Maker to train a customized face stylization model with their own style images. After training, they can deploy the exported bundle of TFLite model files to applications across platforms (Android, iOS, Web, Python, etc.). This takes just a few lines of code with the MediaPipe Tasks FaceStylizer API.

Acknowledgements

This work is made possible through a collaboration spanning several teams across Google. We’d like to acknowledge contributions from Omer Tov, Yang Zhao, Andrey Vakunov, Fei Deng, Ariel Ephrat, Inbar Mosseri, Lu Wang, Chuo-Ling Chang, Tingbo Hou, and Matthias Grundmann.


Test-time Adaptation with Slot-Centric Models


TLDR: Current SOTA methods for scene understanding, though impressive, often fail to decompose out-of-distribution scenes. In our ICML paper, Slot-TTA (http://slot-tta.github.io), we find that optimizing a reconstruction loss on each test sample improves scene decomposition accuracy.

Problem Statement: In machine learning, we often assume that the train and test splits are IID samples from the same distribution. However, this rarely holds in reality. In fact, distribution shift is happening all the time!

For example, on the left we visualize images from the ImageNet chair category, and on the right, images from the ObjectNet chair category. As you can see, there is a variety of real-world distribution shifts, such as camera pose changes, occlusions, and changes in scene configuration.

So what is the issue? In machine learning, we usually assume a fixed train and test split. In the real world, there is no such universal split; instead, distribution shifts are happening all the time.

Instead of freezing our models at test time, which is what we conventionally do, we should instead continuously adapt them to various distribution shifts.

Given these issues, there has been a lot of work in this domain, which is referred to as test-time adaptation. Test-time adaptation can be broadly classified into two settings:

- supervised test-time adaptation, where you are given access to a few labeled examples;
- unsupervised test-time adaptation, where you do not have access to any labels.

In this work, we focus on unsupervised adaptation, as it is the more general setting.

Within unsupervised test-time adaptation, there are various settings such as batch, online, or single-example adaptation. In this work, we focus on the single-example setting, in which the model adapts to each example in the test set independently. This is more general than batch or online adaptation, which assume access to many unlabeled examples.

What is the prominent approach in this setting?

Sun et al. proposed encoding the input data X into a shared representation z. This representation is passed to both a supervised task decoder and a self-supervised decoder. The whole model is then trained jointly using supervised and self-supervised losses. Joint training couples the self-supervised and supervised tasks, and this coupling enables test-time adaptation using the self-supervised loss alone. Approaches vary in the type of self-supervised loss used:

- TTT uses a rotation prediction loss;
- MT3 uses an instance prediction loss;
- TTT-MAE uses a masked autoencoding loss.

However, all of these approaches focus only on image classification. In our work, we find that joint training with these losses alone is insufficient for scene understanding tasks. Architectural biases can be important for adaptation. Specifically, slot-centric biases, which strongly couple scene decomposition and reconstruction loss, are a natural fit.

Slot-centric generative models attempt to segment scenes into object entities in a completely unsupervised manner by optimizing a reconstruction objective [1,2,3]. Because this objective shares the end goal of scene decomposition, these models are good candidate architectures for TTA.

These methods differ in detail but share the notion of incorporating a fixed set of entities, also known as slots or object files. Each slot extracts information about a single entity during encoding and is “synthesized” back to the input domain during decoding.
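A minimal sketch of how slots can compete for input features follows. It uses a heavily simplified variant of Slot Attention, the slot module Slot-TTA uses according to its architecture description. Several parts of the original Slot Attention are deliberately left out: the learned query/key/value projections, LayerNorm, the GRU update, and the MLP. Here each slot is simply replaced by the attention-weighted mean of its tokens. The function names are hypothetical.

```python
import numpy as np

def softmax(logits, axis):
    shifted = logits - logits.max(axis=axis, keepdims=True)
    exps = np.exp(shifted)
    return exps / exps.sum(axis=axis, keepdims=True)

def simple_slot_attention(inputs, slots, iters=3, eps=1e-8):
    """Simplified Slot Attention (no learned projections, LayerNorm, GRU, or MLP).

    inputs: (N, d) token features; slots: (K, d) initial slot vectors.
    Returns the refined slots (K, d) and the last attention map (N, K).
    """
    d = inputs.shape[1]
    attn = None
    for _ in range(iters):
        logits = inputs @ slots.T / np.sqrt(d)                    # (N, K) token-slot affinities
        attn = softmax(logits, axis=1)                            # softmax over slots: slots compete for tokens
        weights = attn / (attn.sum(axis=0, keepdims=True) + eps)  # normalize per slot over tokens
        slots = weights.T @ inputs                                # each slot = weighted mean of its tokens
    return slots, attn

rng = np.random.default_rng(0)
tokens = rng.normal(size=(16, 4))   # e.g., 16 encoder tokens with 4 features each
init = rng.normal(size=(3, 4))      # 3 slots
slots, attn = simple_slot_attention(tokens, init)
token_to_slot = attn.argmax(axis=1)  # hard per-token assignment = a segmentation
```

The key design choice is normalizing over slots rather than over tokens. Each token's attention must be split among the slots, which pushes different slots to explain different entities.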

Test-time adaptation in Slot-TTA: Segmentation improves when optimizing reconstruction or view synthesis objectives via gradient descent at test-time on a single test sample.

In light of the above, we propose Test-Time Adaptation with Slot-Centric models (Slot-TTA), a semi-supervised model equipped with a slot-centric bottleneck that jointly segments and reconstructs scenes.

At training time, Slot-TTA is trained in a supervised manner to jointly segment and reconstruct 2D (multi-view or single-view) RGB images or 3D point clouds. At test time, the model adapts to a single test sample by updating its network parameters solely by optimizing the reconstruction objective through gradient descent, as shown in the above figure.

Slot-TTA builds on top of slot-centric models by incorporating segmentation supervision during the training phase. Until now, slot-centric models have been neither designed nor used with test-time adaptation (TTA) in mind.

In particular, Engelcke et al. (2020) showed that TTA via reconstruction in slot-centric models fails because of a reconstruction-segmentation trade-off. As the entity bottleneck loosens, reconstruction improves, but segmentation deteriorates. We show that segmentation supervision helps mitigate this trade-off and helps scale to scenes with complicated textures. We also show that TTA in semi-supervised slot-centric models significantly improves scene decomposition.

Model architecture for Slot-TTA for posed multi-view or single-view RGB images (top) and 3D point clouds (bottom). Slot-TTA maps the input (multi-view posed RGB images or a 3D point cloud) to a set of token features with appropriate encoder backbones. It then maps these token features to a set of slot vectors using Slot Attention. Finally, it decodes each slot into its respective segmentation mask and RGB image or 3D point cloud. Renders are fused across all slots using weighted averaging or max-pooling. For RGB images, we show results in multi-view and single-view settings; in the multi-view setting, the decoder is conditioned on a target camera viewpoint. We train Slot-TTA using reconstruction and segmentation losses. At test time, we optimize only the reconstruction loss.
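The fusion step for RGB images can be sketched as mask-weighted averaging. Each slot decodes mask logits and an RGB render. A softmax across slots at every pixel turns the logits into soft masks, which weight the renders, and the argmax of the masks gives the segmentation. The softmax normalization is an assumption borrowed from common slot decoders. The max-pooling variant mentioned for point clouds is not shown. Names and shapes are hypothetical.

```python
import numpy as np

def composite_slots(mask_logits, slot_rgbs):
    """Fuse per-slot renders into one image by mask-weighted averaging.

    mask_logits: (K, H, W) unnormalized mask scores, one map per slot.
    slot_rgbs:   (K, H, W, 3) RGB render decoded from each slot.
    Returns the fused image (H, W, 3), soft masks (K, H, W), and a
    per-pixel segmentation map (H, W) of slot indices.
    """
    shifted = mask_logits - mask_logits.max(axis=0, keepdims=True)
    masks = np.exp(shifted)
    masks /= masks.sum(axis=0, keepdims=True)           # softmax across slots at every pixel
    image = (masks[..., None] * slot_rgbs).sum(axis=0)  # weighted average of slot renders
    segmentation = masks.argmax(axis=0)                 # each pixel goes to its dominant slot
    return image, masks, segmentation

K, H, W = 2, 4, 4
logits = np.zeros((K, H, W))
logits[1] = 10.0                                  # slot 1 dominates every pixel
rgbs = np.stack([np.zeros((H, W, 3)), np.ones((H, W, 3))])
image, masks, seg = composite_slots(logits, rgbs)
```

Because the reconstruction is built from the same masks that define the segmentation, lowering the reconstruction loss at test time also changes the segmentation. This coupling is what the method relies on.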

Our contributions are as follows:

(i) We present an algorithm that significantly improves scene decomposition accuracy for out-of-distribution examples by performing test-time adaptation on each example in the test set independently.

(ii) We showcase the effectiveness of SSL-based TTA approaches for scene decomposition, while previous self-supervised test-time adaptation methods have primarily demonstrated results in classification tasks.

(iii) We introduce semi-supervised learning for slot-centric generative models and show that it enables these methods to continue learning during test time. In contrast, previous works on slot-centric generative models have neither been trained with supervision nor been used for test-time adaptation.

(iv) Lastly, we devise numerous baselines and ablations, and evaluate them across multiple benchmarks and distribution shifts to offer valuable insights into test-time adaptation and object-centric learning.

Results: We test Slot-TTA on the scene understanding tasks of novel view rendering and scene segmentation. We cover input modalities including multi-view posed images, single-view images, and 3D point clouds, on the PartNet, MultiShapeNet-Hard, and CLEVR datasets.

We compare Slot-TTA’s segmentation performance against three groups of methods:

- state-of-the-art supervised feedforward segmentors for RGB images (Mask2Former) and 3D point clouds (Mask3D);
- the state-of-the-art novel view rendering method Semantic-NeRF, which adapts per scene through RGB and segmentation rendering;
- state-of-the-art test-time adaptation methods such as MT3.

We show that Slot-TTA outperforms SOTA feedforward segmenters in out-of-distribution scenes, dramatically outperforms alternative TTA methods and alternative semi-supervised scene decomposition methods, and better exploits multiview information for improving segmentation over semantic NeRF-based multi-view fusion.

Below we show our multi-view RGB results on the MultiShapeNet dataset from Kubric.

We consider various distribution shifts throughout our paper; the results below use the following one.

We train on MultiShapeNet-Easy and test on MultiShapeNet-Hard. The training and test sets share neither object instances nor the number of objects per scene. Specifically, scenes with 5-7 object instances are in the training set, and scenes with 16-30 objects are in the test set.

We consider the following baselines:

(i) Mask2Former (Cheng et al., 2021), a state-of-the-art 2D image segmentor that extends detection transformers (Carion et al., 2020) to the task of image segmentation via using multiscale segmentation decoders with masked attention.

(ii) Mask2Former-BYOL which combines the segmentation model of Cheng et al. (2021) with test time adaptation using BYOL self-supervised loss of MT3 (Bartler et al. (2022)).

(iii) Mask2Former-Recon which combines the segmentation model of Cheng et al. (2021) with an RGB rendering module and an image reconstruction objective for test-time adaptation.

(iv) Semantic-NeRF (Zhi et al., 2021), a NeRF model that adds a segmentation rendering head to the multi-view RGB rendering head of traditional NeRFs. It is fit per scene on all nine available posed RGB images, with the corresponding Mask2Former segmentation maps as input.

(v) Slot-TTA-w/o supervision, a variant of our model that does not use any segmentation supervision. Instead, it is trained only for cross-view image synthesis, similar to OSRT (Sajjadi et al., 2022a).

Instance segmentation ARI (higher is better) in the multi-view RGB setup, for the in-distribution test set of 5-7 object instances and the out-of-distribution test set of 16-30 object instances.
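The Adjusted Rand Index (ARI) reported above compares two labelings up to a permutation of label ids. It is computed from pair counts in their contingency table and corrected for chance: 1.0 means identical groupings, and about 0 means chance-level agreement. A minimal sketch of the standard formula follows. The post does not specify details such as whether background pixels are excluded, so none are assumed here. scikit-learn's `adjusted_rand_score` can be used to cross-check it.

```python
import numpy as np

def comb2(n):
    """Number of unordered pairs, n choose 2 (works elementwise on arrays)."""
    return n * (n - 1) / 2.0

def adjusted_rand_index(labels_true, labels_pred):
    """Adjusted Rand Index between two flat labelings (e.g., per-pixel instance ids)."""
    labels_true = np.asarray(labels_true).ravel()
    labels_pred = np.asarray(labels_pred).ravel()
    _, t = np.unique(labels_true, return_inverse=True)
    _, p = np.unique(labels_pred, return_inverse=True)
    contingency = np.zeros((t.max() + 1, p.max() + 1))
    np.add.at(contingency, (t, p), 1)               # co-occurrence counts
    sum_ij = comb2(contingency).sum()               # pairs grouped together in both
    sum_a = comb2(contingency.sum(axis=1)).sum()    # pairs together in labels_true
    sum_b = comb2(contingency.sum(axis=0)).sum()    # pairs together in labels_pred
    expected = sum_a * sum_b / comb2(labels_true.size)
    max_index = 0.5 * (sum_a + sum_b)
    if max_index == expected:                       # degenerate case, e.g., one cluster each
        return 1.0
    return float((sum_ij - expected) / (max_index - expected))

print(adjusted_rand_index([0, 0, 1, 1], [0, 0, 1, 2]))  # about 0.571 (= 4/7)
print(adjusted_rand_index([0, 0, 1, 1], [1, 1, 0, 0]))  # 1.0: same grouping, ids swapped
```

Because ARI ignores the label ids themselves, it suits slot-based models, whose slot order is arbitrary.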

Our conclusions are as follows:

(i) Slot-TTA with TTA outperforms Mask2Former in out-of-distribution scenes and has comparable performance within the training distribution.

(ii) Mask2Former-BYOL does not improve over Mask2Former, which suggests that adding self-supervised losses of SOTA image classification TTA methods (Bartler et al., 2022) to scene segmentation methods does not help.

(iii) Slot-TTA-w/o supervision (model identical to Sajjadi et al. (2022a)) greatly underperforms a supervised segmentor Mask2Former. This means that unsupervised slot-centric models are still far from reaching their supervised counterparts.

(iv) Slot-TTA-w/o supervision does not improve during test-time adaptation. This suggests segmentation supervision at training time is essential for effective TTA.

(v) Semantic-NeRF, which fuses segmentation masks across views in a geometrically consistent manner, outperforms the single-view segmentation performance of Mask2Former by 3%.

(vi) Slot-TTA which adapts model parameters of the segmentor at test time greatly outperforms Semantic-NeRF in OOD scenes.

(vii) Mask2Former-Recon performs worse with TTA, which suggests that the decoder’s design is very important for aligning the reconstruction and segmentation tasks.

For point clouds, we train the model on certain categories of PartNet and test it on a different set of categories. For quantitative comparisons with the baselines, please refer to our paper. As can be seen in the figure below, Slot-TTA’s point cloud segmentation improves after optimizing the point cloud reconstruction loss.

For 2D RGB images, we train the model with supervision on the CLEVR dataset and test it on CLEVR-Tex. For quantitative comparisons with the baselines, please refer to our paper. As can be seen in the figure below, Slot-TTA’s RGB segmentation improves after optimizing the RGB reconstruction loss.

Finally, we find that Slot-TTA doesn’t just improve the segmentation performance on out-of-distribution scenes, but also improves the performance on other downstream tasks such as novel view synthesis!

Novel view rendering results of Slot-TTA after doing test-time adaptation. As can be seen, our scene segmentation results improve after adding TTA.

Conclusion: We presented Slot-TTA, a novel semi-supervised scene decomposition model equipped with a slot-centric image or point-cloud rendering component for test-time adaptation. We showed that Slot-TTA greatly improves instance segmentation on out-of-distribution scenes using test-time adaptation on reconstruction or novel view synthesis objectives. We compared against numerous baseline methods:

- state-of-the-art feedforward segmentors;
- NeRF-based TTA for multi-view semantic fusion;
- state-of-the-art TTA methods;
- unsupervised or weakly supervised 2D and 3D generative models.

Slot-TTA compares favorably against all of them for scene decomposition of OOD scenes, while remaining competitive within distribution.

Paper authors: Mihir Prabhudesai, Anirudh Goyal, Sujoy Paul, Sjoerd van Steenkiste, Mehdi S. M. Sajjadi, Gaurav Aggarwal, Thomas Kipf, Deepak Pathak, Katerina Fragkiadaki.

Code: <https://github.com/mihirp1998/Slot-TTA>

Webpage: <https://slot-tta.github.io/>

Paper: <https://arxiv.org/abs/2203.11194>


Learn how to build and deploy tool-using LLM agents using AWS SageMaker JumpStart Foundation Models


Large language model (LLM) agents are programs that extend the capabilities of standalone LLMs with 1) access to external tools (APIs, functions, webhooks, plugins, and so on), and 2) the ability to plan and execute tasks in a self-directed fashion. Often, LLMs need to interact with other software, databases, or APIs to accomplish complex tasks. For example, an administrative chatbot that schedules meetings would require access to employees’ calendars and email. With access to tools, LLM agents can become more powerful—at the cost of additional complexity.

In this post, we introduce LLM agents and demonstrate how to build and deploy an e-commerce LLM agent using Amazon SageMaker JumpStart and AWS Lambda. The agent will use tools to provide new capabilities, such as answering questions about returns (“Is my return rtn001 processed?”) and providing updates about orders (“Could you tell me if order 123456 has shipped?”). These new capabilities require LLMs to fetch data from multiple data sources (orders, returns) and perform retrieval augmented generation (RAG).

To power the LLM agent, we use a Flan-UL2 model deployed as a SageMaker endpoint, along with data retrieval tools built with AWS Lambda. The agent can subsequently be integrated with Amazon Lex and used as a chatbot inside websites or Amazon Connect. We conclude the post with items to consider before deploying LLM agents to production. For a fully managed experience for building LLM agents, AWS also provides the agents for Amazon Bedrock feature (in preview).

A brief overview of LLM agent architectures

LLM agents are programs that use LLMs to decide when and how to use tools as necessary to complete complex tasks. With tools and task planning abilities, LLM agents can interact with outside systems and overcome traditional limitations of LLMs, such as knowledge cutoffs, hallucinations, and imprecise calculations. Tools can take a variety of forms, such as API calls, Python functions, or webhook-based plugins. For example, an LLM can use a “retrieval plugin” to fetch relevant context and perform RAG.

So what does it mean for an LLM to pick tools and plan tasks? There are numerous approaches (such as ReAct, MRKL, Toolformer, HuggingGPT, and Transformer Agents) to using LLMs with tools, and advancements are happening rapidly. But one simple way is to prompt an LLM with a list of tools and ask it to determine 1) if a tool is needed to satisfy the user query, and if so, 2) select the appropriate tool. Such a prompt typically looks like the following example and may include few-shot examples to improve the LLM’s reliability in picking the right tool.

'''
Your task is to select a tool to answer a user question. You have access to the following tools.

search: search for an answer in FAQs
order: order items
noop: no tool is needed

{few shot examples}

Question: {input}
Tool:
'''
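The prompt above can be filled in programmatically, and the model's reply checked against the allowed tools. A minimal sketch of both steps follows. It is not tied to any particular LLM SDK; the model call itself is left out. The tool list is copied from the prompt. The template variable names and the rule of taking the first word of the reply are our assumptions.

```python
# Tools copied from the example prompt above.
TOOLS = {
    "search": "search for an answer in FAQs",
    "order": "order items",
    "noop": "no tool is needed",
}

PROMPT_TEMPLATE = """Your task is to select a tool to answer a user question. You have access to the following tools.

{tool_list}

{few_shot_examples}

Question: {question}
Tool:"""

def build_tool_prompt(question, few_shot_examples=""):
    """Render the tool-selection prompt for one user question (hypothetical helper)."""
    tool_list = "\n".join(f"{name}: {desc}" for name, desc in TOOLS.items())
    return PROMPT_TEMPLATE.format(
        tool_list=tool_list, few_shot_examples=few_shot_examples, question=question
    )

def parse_tool_choice(completion):
    """Return the tool the LLM picked, or None if its reply names no known tool."""
    words = completion.strip().split()
    if not words:
        return None
    candidate = words[0].strip(".,:;\"'").lower()
    return candidate if candidate in TOOLS else None

prompt = build_tool_prompt("How do I reset my password?")
print(prompt)
```

Returning None for unknown names lets the caller fall back to a default answer instead of invoking a tool the model made up.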

More complex approaches use a specialized LLM that can directly decode “API calls” or “tool use,” such as GorillaLLM. Such fine-tuned LLMs are trained on API specification datasets to recognize and predict API calls based on instructions. Often, these LLMs need metadata about the available tools in order to output tool invocations, such as descriptions or YAML/JSON schemas for their input parameters. This approach is taken by agents for Amazon Bedrock and by OpenAI function calling. Note that LLMs generally need to be sufficiently large and complex to show tool selection ability.

Typical LLM Agent Architecture

Assuming task planning and tool selection mechanisms are chosen, a typical LLM agent program works in the following sequence:

  1. User request – The program takes a user input such as “Where is my order 123456?” from some client application.
  2. Plan next action(s) and select tool(s) to use – Next, the program uses a prompt to have the LLM generate the next action, for example, “Look up the orders table using OrdersAPI.” The LLM is prompted to suggest a tool name such as OrdersAPI from a predefined list of available tools and their descriptions. Alternatively, the LLM could be instructed to directly generate an API call with input parameters, such as OrdersAPI(123456).
    1. Note that the next action may or may not involve using a tool or API. If not, the LLM would respond to user input without incorporating additional context from tools or simply return a canned response such as, “I cannot answer this question.”
  3. Parse tool request – Next, we need to parse out and validate the tool/action prediction suggested by the LLM. Validation is needed to ensure tool names, APIs, and request parameters aren’t hallucinated and that the tools are properly invoked according to specification. This parsing may require a separate LLM call.
  4. Invoke tool – Once valid tool name(s) and parameter(s) are ensured, we invoke the tool. This could be an HTTP request, function call, and so on.
  5. Parse output – The response from the tool may need additional processing. For example, an API call may result in a long JSON response, where only a subset of fields are of interest to the LLM. Extracting information in a clean, standardized format can help the LLM interpret the result more reliably.
  6. Interpret output – Given the output from the tool, the LLM is prompted again to make sense of it and decide whether it can generate the final answer back to the user or whether additional actions are required.
  7. Terminate or continue to step 2 – Either return a final answer or a default answer in the case of errors or timeouts.
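Steps 3 and 4 can be sketched for the case where the LLM emits a call string such as `OrdersAPI(123456)`. The sketch parses the string, rejects unknown (possibly hallucinated) tool names and malformed arguments, and only then invokes the tool. The registry, the stub `orders_api` function, and the six-digit order id pattern are hypothetical stand-ins for a real API.

```python
import re

def orders_api(order_id):
    """Hypothetical stub standing in for a real orders API."""
    return {"orderId": order_id, "status": "shipped"}

# Hypothetical registry: tool name -> (callable, pattern its single argument must match).
TOOL_REGISTRY = {"OrdersAPI": (orders_api, re.compile(r"\d{6}"))}

CALL_PATTERN = re.compile(r"^\s*([A-Za-z_]\w*)\(\s*([^()]*?)\s*\)\s*$")

def parse_tool_call(llm_output):
    """Step 3: parse 'ToolName(arg)' and validate it against the registry."""
    match = CALL_PATTERN.match(llm_output)
    if match is None:
        raise ValueError(f"not a tool call: {llm_output!r}")
    name, arg = match.groups()
    if name not in TOOL_REGISTRY:
        raise ValueError(f"unknown (possibly hallucinated) tool: {name}")
    _, arg_pattern = TOOL_REGISTRY[name]
    if not arg_pattern.fullmatch(arg):
        raise ValueError(f"invalid argument for {name}: {arg!r}")
    return name, arg

def invoke_tool(name, arg):
    """Step 4: invoke a tool whose name and argument were already validated."""
    func, _ = TOOL_REGISTRY[name]
    return func(arg)

name, arg = parse_tool_call("OrdersAPI(123456)")
print(invoke_tool(name, arg))
```

Raising `ValueError` on any mismatch gives the agent loop one place to catch errors and fall back to a default response (step 7).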

Different agent frameworks execute the previous program flow differently. For example, ReAct combines tool selection and final answer generation into a single prompt, as opposed to using separate prompts for tool selection and answer generation. This logic can also run in a single pass or in a while loop (the “agent loop”). The loop terminates when the final answer is generated, an exception is thrown, or a timeout occurs. What remains constant is that agents use the LLM as the centerpiece to orchestrate planning and tool invocations until the task terminates. Next, we show how to implement a simple agent loop using AWS services.

Solution overview

For this blog post, we implement an e-commerce support LLM agent that provides two functionalities powered by tools:

  • Return status retrieval tool – Answer questions about the status of returns such as, “What is happening to my return rtn001?”
  • Order status retrieval tool – Track the status of orders such as, “What’s the status of my order 123456?”

The agent effectively uses the LLM as a query router. Given a query (“What is the status of order 123456?”), the agent selects the appropriate retrieval tool to query across multiple data sources (that is, returns and orders). We accomplish query routing by having the LLM pick among multiple retrieval tools, each responsible for interacting with a data source and fetching context. This extends the simple RAG pattern, which assumes a single data source.

Both retrieval tools are Lambda functions that take an id (orderId or returnId) as input. Each fetches a JSON object from the data source and converts it into a human-friendly string suitable for use by the LLM. In a real-world scenario, the data source could be a highly scalable NoSQL database such as DynamoDB. For demo purposes, this solution uses a simple Python dict with sample data.
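A retrieval tool of this kind might look like the minimal sketch below. It uses the standard Python Lambda handler signature `(event, context)` and a dict in place of a database. The `orderId` event key follows the text. The sample record, the response shape, and the status codes are hypothetical; the 200 code matches the check in the agent loop shown later.

```python
# Hypothetical demo data standing in for a database such as DynamoDB.
ORDERS = {
    "123456": {"orderId": "123456", "status": "shipped", "items": 2},
}

def lambda_handler(event, context):
    """Hypothetical order status retrieval tool: look up an order by id and
    return a short, LLM-friendly sentence instead of the raw JSON record."""
    order_id = str(event.get("orderId", ""))
    order = ORDERS.get(order_id)
    if order is None:
        return {"statusCode": 404, "body": f"No order found with id {order_id}."}
    summary = (
        f"Order {order['orderId']} has status '{order['status']}' "
        f"with {order['items']} item(s)."
    )
    return {"statusCode": 200, "body": summary}

print(lambda_handler({"orderId": "123456"}, None)["body"])
```

A return status tool would follow the same shape, keyed on `returnId`.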

Additional functionality can be added to the agent by adding retrieval tools and modifying the prompts accordingly. This agent can be tested as a standalone service that integrates with any UI over HTTP, which can be done easily with Amazon Lex.


Here are some additional details about the key components:

  1. LLM inference endpoint – The core of an agent program is an LLM. We will use SageMaker JumpStart foundation model hub to easily deploy the Flan-UL2 model. SageMaker JumpStart makes it easy to deploy LLM inference endpoints to dedicated SageMaker instances.
  2. Agent orchestrator – The agent orchestrator coordinates the interactions among the LLM, the tools, and the client app. For our solution, we use an AWS Lambda function to drive this flow, with the following helper functions.
    • Task (tool) planner – Task planner uses the LLM to suggest one of 1) returns inquiry, 2) order inquiry, or 3) no tool. We use prompt engineering only and Flan-UL2 model as-is without fine-tuning.
    • Tool parser – Tool parser ensures that the tool suggestion from task planner is valid. Notably, we ensure that a single orderId or returnId can be parsed. Otherwise, we respond with a default message.
    • Tool dispatcher – Tool dispatcher invokes tools (Lambda functions) using the valid parameters.
    • Output parser – Output parser cleans and extracts relevant items from JSON into a human-readable string. This task is done both by each retrieval tool as well as within the orchestrator.
    • Output interpreter – Output interpreter’s responsibility is to 1) interpret the output from tool invocation and 2) determine whether the user request can be satisfied or additional steps are needed. If the latter, a final response is generated separately and returned to the user.
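The tool parser's job can be sketched with regular expressions: it accepts a valid tool name and exactly one ID. The ID formats are assumptions inferred from the examples in this post: numeric order IDs such as 123456, and return IDs such as rtn003. The signature matches the `tool_parser(tool_prediction, user_input)` call in the agent loop below, but the body is illustrative. Returning `(None, None)` for `no_tool` lets the loop fall through to its default response. Raising `ValueError` lets the loop surface the error message.

```python
import re

# Assumed ID formats, inferred from the post's examples (123456, rtn003).
ORDER_ID_PATTERN = re.compile(r"\b\d+\b")
RETURN_ID_PATTERN = re.compile(r"\brtn\d+\b", re.IGNORECASE)

VALID_TOOLS = {"order_inquiry", "returns_inquiry", "no_tool"}


def tool_parser(tool_prediction: str, user_input: str):
    """Validate the planner's suggestion and extract exactly one ID.

    Returns (tool_name, tool_input), or (None, None) when no tool is needed.
    Raises ValueError when the suggestion or the ID cannot be parsed.
    """
    tool_name = tool_prediction.strip().lower()
    if tool_name not in VALID_TOOLS:
        # Also rejects comma-separated multi-tool suggestions.
        raise ValueError(f"Unrecognized tool suggestion: {tool_prediction!r}")
    if tool_name == "no_tool":
        return None, None

    pattern = ORDER_ID_PATTERN if tool_name == "order_inquiry" else RETURN_ID_PATTERN
    ids = pattern.findall(user_input)
    if len(ids) != 1:
        raise ValueError("Please provide exactly one order or return ID.")
    return tool_name, ids[0]
```

One limitation of this design is that the numeric pattern also matches unrelated numbers, such as a year. A question containing an order ID and a year is therefore rejected rather than guessed at.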

Now, let’s dive a bit deeper into the key components: agent orchestrator, task planner, and tool dispatcher.

Agent orchestrator

Below is an abbreviated version of the agent loop inside the agent orchestrator Lambda function. The loop uses helper functions such as task_planner and tool_parser to modularize the tasks. It is designed to run at most two times so that the LLM cannot get stuck in an unnecessarily long loop.

#.. imports ..
MAX_LOOP_COUNT = 2 # stop the agent loop after up to 2 iterations
# ... helper function definitions ...
def agent_handler(event):
    user_input = event["query"]
    print(f"user input: {user_input}") 
    
    final_generation = ""
    is_task_complete = False
    loop_count = 0 

    # start of agent loop
    while not is_task_complete and loop_count < MAX_LOOP_COUNT:
        tool_prediction = task_planner(user_input)
        print(f"tool_prediction: {tool_prediction}")  
        
        tool_name, tool_input, tool_output, error_msg = None, None, "", ""

        try:
            tool_name, tool_input = tool_parser(tool_prediction, user_input)
            print(f"tool name: {tool_name}") 
            print(f"tool input: {tool_input}") 
        except Exception as e:
            error_msg = str(e)
            print(f"tool parse error: {error_msg}")  
    
        if tool_name is not None: # if a valid tool is selected and parsed 
            raw_tool_output = tool_dispatch(tool_name, tool_input)
            tool_status, tool_output = output_parser(raw_tool_output)
            print(f"tool status: {tool_status}")  

            if tool_status == 200:
                is_task_complete, final_generation = output_interpreter(user_input, tool_output) 
            else:
                final_generation = tool_output
        else: # if no valid tool was selected and parsed, either return the default msg or error msg
            final_generation = DEFAULT_RESPONSES.NO_TOOL_FEEDBACK if error_msg == "" else error_msg
    
        loop_count += 1

    return {
        'statusCode': 200,
        'body': final_generation
    }

Task planner (tool prediction)

The agent orchestrator uses the task planner to predict a retrieval tool based on the user input. For our LLM agent, we simply use prompt engineering and few-shot prompting to teach the LLM this task in context. More sophisticated agents could use a fine-tuned LLM for tool prediction, which is beyond the scope of this post. The prompt is as follows:

tool_selection_prompt_template = """
Your task is to select appropriate tools to satisfy the user input. If no tool is required, then pick "no_tool"

Tools available are:

returns_inquiry: Database of information about a specific return's status, whether it's pending, processed, etc.
order_inquiry: Information about a specific order's status, such as shipping status, product, amount, etc.
no_tool: No tool is needed to answer the user input.

You can suggest multiple tools, separated by a comma.

Examples:
user: "What are your business hours?"
tool: no_tool

user: "Has order 12345 shipped?"
tool: order_inquiry

user: "Has return ret812 processed?"
tool: returns_inquiry

user: "How many days do I have until returning orders?"
tool: returns_inquiry

user: "What was the order total for order 38745?"
tool: order_inquiry

user: "Can I return my order 38756 based on store policy?"
tool: order_inquiry

user: "Hi"
tool: no_tool

user: "Are you an AI?"
tool: no_tool

user: "How's the weather?"
tool: no_tool

user: "What is the refund status of order 12347?"
tool: order_inquiry

user: "What is the refund status of return ret172?"
tool: returns_inquiry

user input: {}
tool:
"""

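A task planner built on this prompt might look like the following sketch. The JSON payload and response keys (`text_inputs`, `max_length`, `generated_texts`) are assumptions about the JumpStart Flan-UL2 endpoint's schema, so check them against your deployed model. The endpoint name is the one listed in the Cleanup section. The client is injectable so that the function can be exercised without AWS access. When the client is omitted, the function creates a boto3 `sagemaker-runtime` client.

```python
import json

# Abbreviated copy of the prompt above; the "{}" placeholder receives the user input.
TOOL_SELECTION_PROMPT = (
    "Your task is to select appropriate tools to satisfy the user input. "
    'If no tool is required, then pick "no_tool"\n'
    "...\n"
    "user input: {}\n"
    "tool:"
)

# Endpoint name listed in the Cleanup section; adjust it if yours differs.
ENDPOINT_NAME = "sm-jumpstart-flan-bot-endpoint"


def task_planner(user_input, sagemaker_runtime=None, endpoint_name=ENDPOINT_NAME):
    """Ask the LLM which tool to use and return its raw prediction string."""
    if sagemaker_runtime is None:
        import boto3  # imported lazily so the function can be tested without AWS

        sagemaker_runtime = boto3.client("sagemaker-runtime")
    prompt = TOOL_SELECTION_PROMPT.format(user_input)
    # Assumed request/response schema for the JumpStart Flan-UL2 endpoint.
    response = sagemaker_runtime.invoke_endpoint(
        EndpointName=endpoint_name,
        ContentType="application/json",
        Body=json.dumps({"text_inputs": prompt, "max_length": 10}),
    )
    result = json.loads(response["Body"].read())
    return result["generated_texts"][0].strip()
```

A small `max_length` is used because the expected answer is a single tool name. Any validation of that name is left to the tool parser.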
Tool dispatcher

The tool dispatch mechanism uses if/else logic to call the appropriate Lambda function depending on the tool’s name. The following is the implementation of the tool_dispatch helper function. It’s used inside the agent loop and returns the raw response from the tool Lambda function, which is then cleaned by the output_parser function.


def tool_dispatch(tool_name, tool_input):
    #...
     
    tool_response = None 

    if tool_name == "returns_inquiry":
        tool_response = lambda_client.invoke(
            FunctionName=RETURNS_DB_TOOL_LAMBDA,
            InvocationType="RequestResponse",
            Payload=json.dumps({
                "returnId": tool_input
            })
        )
    elif tool_name == "order_inquiry":
        tool_response = lambda_client.invoke(
            FunctionName=ORDERS_DB_TOOL_LAMBDA,
            InvocationType="RequestResponse",
            Payload=json.dumps({
                "orderId": tool_input
            })
        )
    else:
        raise ValueError("Invalid tool invocation")
        
    return tool_response
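The raw response returned by tool_dispatch is a boto3 Lambda `invoke()` result whose `Payload` is a readable stream. The following is a hypothetical output_parser that turns this result into the `(tool_status, tool_output)` pair the agent loop expects. It assumes each tool Lambda returns a JSON object with `statusCode` and `body` keys. The fallback message and the use of status 500 for failed invocations are also assumptions.

```python
import json


def output_parser(raw_tool_output):
    """Turn a boto3 Lambda invoke() response into (status_code, text).

    Assumes the tool Lambda returns {"statusCode": ..., "body": ...}.
    """
    # boto3 sets "FunctionError" when the invoked function raised an error.
    if raw_tool_output.get("FunctionError"):
        return 500, "The tool failed. Please try again later."
    payload = json.loads(raw_tool_output["Payload"].read())
    status = payload.get("statusCode", 500)
    body = payload.get("body", "")
    # Collapse runs of whitespace so the text reads cleanly when passed to the LLM.
    return status, " ".join(str(body).split())
```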

Deploy the solution

Important prerequisites – To get started with the deployment, you need to fulfill the following prerequisites:

  • Access to the AWS Management Console via a user who can launch AWS CloudFormation stacks
  • Familiarity with navigating the AWS Lambda and Amazon Lex consoles
  • Flan-UL2 requires a single ml.g5.12xlarge instance for deployment, which may require increasing resource limits through a support ticket. Our example uses us-east-1 as the Region, so make sure to increase the service quota (if needed) in us-east-1.

Deploy using CloudFormation – You can deploy the solution to us-east-1 by clicking the button below:

Launch stack

Deploying the solution takes about 20 minutes and creates a stack named LLMAgentStack, which:

  • deploys the SageMaker endpoint using Flan-UL2 model from SageMaker JumpStart;
  • deploys three Lambda functions: LLMAgentOrchestrator, LLMAgentReturnsTool, LLMAgentOrdersTool; and
  • deploys an Amazon Lex bot that can be used to test the agent: Sagemaker-Jumpstart-Flan-LLM-Agent-Fallback-Bot.

Test the solution

The stack deploys an Amazon Lex bot named Sagemaker-Jumpstart-Flan-LLM-Agent-Fallback-Bot, which can be used to test the agent end-to-end. A more comprehensive guide is available for testing Amazon Lex bots with a Lambda integration and for how the integration works at a high level. In short, the Amazon Lex bot provides a quick UI for chatting with the LLM agent running inside the Lambda function we built (LLMAgentOrchestrator).

The sample test cases to consider are as follows:

  • Valid order inquiry (for example, “Which item was ordered for 123456?”)
    • Order “123456” is a valid order, so we should expect a reasonable answer (e.g. “Herbal Handsoap”)
  • Valid return inquiry (for example, “When is my return rtn003 processed?”)
    • We should expect a reasonable answer about the return’s status.
  • Irrelevant to both returns and orders (for example, “How is the weather in Scotland right now?”)
    • The question is unrelated to returns or orders, so a default answer should be returned (“Sorry, I cannot answer that question.”)
  • Invalid order inquiry (for example, “Which item was ordered for 383833?”)
    • The ID 383833 does not exist in the orders dataset, so the agent should fail gracefully (for example, “Order not found. Please check your Order ID.”)
  • Invalid return inquiry (for example, “When is my return rtn123 processed?”)
    • Similarly, the ID rtn123 does not exist in the returns dataset, so the agent should fail gracefully.
  • Irrelevant return inquiry (for example, “What is the impact of return rtn001 on world peace?”)
    • Although this question refers to a valid return, it is irrelevant. The LLM is used to filter out questions with irrelevant context.

To run these tests yourself, follow these steps:

  1. On the Amazon Lex console (AWS Console > Amazon Lex), navigate to the bot entitled Sagemaker-Jumpstart-Flan-LLM-Agent-Fallback-Bot. This bot has already been configured to call the LLMAgentOrchestrator Lambda function whenever the FallbackIntent is triggered.
  2. In the navigation pane, choose Intents.
  3. Choose Build at the top right corner.
  4. Wait for the build process to complete. When it’s done, you get a success message.
  5. Test the bot by entering the test cases.

Cleanup

To avoid additional charges, delete the resources created by our solution by following these steps:

  • On the AWS CloudFormation console, select the stack named LLMAgentStack (or the custom name you picked).
  • Choose Delete.
  • Check that the stack is deleted from the CloudFormation console.

Important: double-check that the stack was successfully deleted by ensuring that the Flan-UL2 inference endpoint has been removed.

  • To check, on the AWS console, go to SageMaker > Inference > Endpoints.
  • The page lists all active endpoints.
  • Make sure that sm-jumpstart-flan-bot-endpoint is no longer listed.
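To script the final check instead of using the console, you can use the boto3 SageMaker `list_endpoints` call with its `NameContains` filter, as in the sketch below. The function name is hypothetical. The function inspects only the first page of results, which is sufficient when the name filter matches only a few endpoints. The client is injectable for testing, and by default it targets us-east-1, the Region used in this post.

```python
ENDPOINT_NAME = "sm-jumpstart-flan-bot-endpoint"


def endpoint_still_exists(sagemaker_client=None, endpoint_name=ENDPOINT_NAME):
    """Return True if an endpoint with exactly this name is still listed."""
    if sagemaker_client is None:
        import boto3  # requires AWS credentials

        sagemaker_client = boto3.client("sagemaker", region_name="us-east-1")
    # NameContains filters server-side; only the first page of results is checked.
    response = sagemaker_client.list_endpoints(NameContains=endpoint_name)
    names = [e["EndpointName"] for e in response.get("Endpoints", [])]
    return endpoint_name in names
```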

Considerations for production

Deploying LLM agents to production requires taking extra steps to ensure reliability, performance, and maintainability. Here are some considerations prior to deploying agents in production:

  • Selecting the LLM model to power the agent loop: For the solution discussed in this post, we used a Flan-UL2 model without fine-tuning to perform task planning or tool selection. In practice, using an LLM that is fine-tuned to directly output tool or API requests can increase reliability and performance, as well as simplify development. We could fine-tune an LLM on tool selection tasks or use a model that directly decodes tool tokens like Toolformer.
    • Using fine-tuned models can also simplify adding, removing, and updating the tools available to an agent. With prompt-only approaches, updating tools requires modifying every prompt inside the agent orchestrator, such as those for task planning, tool parsing, and tool dispatch. This can be cumbersome, and performance may degrade if too many tools are provided in context to the LLM.
  • Reliability and performance: LLM agents can be unreliable, especially for complex tasks that cannot be completed within a few loops. Adding output validation and retries, structuring LLM outputs as JSON or YAML, and enforcing timeouts that provide escape hatches for LLMs stuck in loops can all improve reliability.
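The reliability measures listed above can be combined in a small wrapper: output validation, retries, structured JSON output, and a timeout as an escape hatch. This is an illustrative sketch, not part of the deployed solution. The function name, the `{"tool": ...}` JSON shape, and the retry count are assumptions. The deadline is checked only between attempts, so it does not interrupt a call that is already running.

```python
import json
import time


def call_with_validation(generate, prompt, allowed_tools, max_retries=2, timeout_s=30.0):
    """Call an LLM, require a JSON object naming an allowed tool, retry otherwise.

    `generate` is any callable prompt -> str (e.g. a wrapper around an endpoint).
    Returns the parsed dict, or None if every attempt fails or time runs out.
    """
    deadline = time.monotonic() + timeout_s
    for _ in range(max_retries + 1):
        if time.monotonic() > deadline:
            break  # escape hatch: stop instead of looping indefinitely
        raw = generate(prompt)
        try:
            parsed = json.loads(raw)
        except json.JSONDecodeError:
            continue  # malformed output: try again
        if isinstance(parsed, dict) and parsed.get("tool") in allowed_tools:
            return parsed
    return None
```

Returning `None` instead of raising lets the caller fall back to a default response, much as the agent loop above does when no valid tool is parsed.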

Conclusion

In this post, we explored how to build an LLM agent that can use multiple tools from the ground up, using low-level prompt engineering, AWS Lambda functions, and SageMaker JumpStart as building blocks. We discussed the architecture of LLM agents and the agent loop in detail. The concepts and solution architecture introduced in this post may be appropriate for agents that use a small, predefined set of tools. We also discussed several strategies for using agents in production. Agents for Amazon Bedrock, which is in preview, also provides a managed experience for building agents with native support for agentic tool invocations.


About the Author

John Hwang is a Generative AI Architect at AWS with a special focus on Large Language Model (LLM) applications, vector databases, and generative AI product strategy. He is passionate about helping companies with AI/ML product development, and about the future of LLM agents and co-pilots. Prior to joining AWS, he was a Product Manager at Alexa, where he helped bring conversational AI to mobile devices, as well as a derivatives trader at Morgan Stanley. He holds a B.S. in computer science from Stanford University.


On-device content distillation with graph neural networks


In today’s digital age, smartphones and desktop web browsers serve as the primary tools for accessing news and information. However, the proliferation of website clutter — encompassing complex layouts, navigation elements, and extraneous links — significantly impairs both the reading experience and article navigation. This issue is particularly acute for individuals with accessibility requirements.

To improve the user experience and make reading more accessible, Android and Chrome users can use the Reading Mode feature, which processes webpages to offer customizable contrast, adjustable text size, more legible fonts, and text-to-speech. Additionally, Android’s Reading Mode can distill content from apps. Expanding Reading Mode to a wide array of content and improving its performance, while still operating locally on the user’s device without transmitting data externally, poses a unique challenge.

To broaden Reading Mode capabilities without compromising privacy, we have developed a novel on-device content distillation model. Unlike early attempts using DOM Distiller — a heuristic approach limited to news articles — our model excels in both quality and versatility across various types of content. We ensure that article content doesn’t leave the confines of the local environment. Our on-device content distillation model smoothly transforms long-form content into a simple and customizable layout for a more pleasant reading journey while also outperforming the leading alternative approaches. Here we explore details of this research highlighting our approach, methodology, and results.

Graph neural networks

Instead of relying on complicated heuristics that are difficult to maintain and scale to a variety of article layouts, we approach this task as a fully supervised learning problem. This data-driven approach allows the model to generalize better across different layouts, without the constraints and fragility of heuristics. Previous work on optimizing the reading experience relied on parsing, filtering, and modeling the HTML or the document object model (DOM). The DOM is a programming interface that the user’s web browser automatically generates from a site’s HTML; it represents the structure of a document and allows it to be manipulated.

The new Reading Mode model relies on accessibility trees, which provide a streamlined and more accessible representation of the DOM. Accessibility trees are automatically generated from the DOM tree and are used by assistive technologies to allow people with disabilities to interact with web content. They are available in the Chrome web browser and, on Android, through AccessibilityNodeInfo objects, which are provided for both WebView and native application content.

We started by manually collecting and annotating accessibility trees. The Android dataset used for this project comprises on the order of 10k labeled examples, while the Chrome dataset contains approximately 100k labeled examples. We developed a novel tool that uses graph neural networks (GNNs) to distill essential content from the accessibility trees using a multi-class supervised learning approach. The datasets consist of long-form articles sampled from the web and labeled with classes such as headline, paragraph, images, publication date, etc.

GNNs are a natural choice for dealing with tree-like data structures, because unlike traditional models that often demand detailed, hand-crafted features to understand the layout and links within such trees, GNNs learn these connections naturally. To illustrate this, consider the analogy of a family tree. In such a tree, each node represents a family member and the connections denote familial relationships. If one were to predict certain traits using conventional models, features like the “number of immediate family members with a trait” might be needed. However, with GNNs, such manual feature crafting becomes redundant. By directly feeding the tree structure into the model, GNNs utilize a message-passing mechanism where each node communicates with its neighbors. Over time, information gets shared and accumulated across the network, enabling the model to naturally discern intricate relationships.
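The family-tree analogy can be made concrete with one round of message passing. In this toy sketch with hypothetical data, each node sends its trait value to its neighbors and sums what it receives. The result is the hand-crafted feature “number of immediate family members with the trait,” obtained without computing it explicitly. In a real GNN, the message and update functions are learned neural networks over feature vectors rather than a fixed sum. Stacking several rounds spreads information further across the tree.

```python
# A toy family tree (hypothetical data): undirected parent-child edges.
edges = [("grandma", "mom"), ("grandma", "uncle"), ("mom", "me"), ("mom", "sister")]
has_trait = {"grandma": 1, "mom": 0, "uncle": 1, "me": 0, "sister": 1}


def neighbors(edge_list):
    """Build an adjacency list so messages can flow both ways along each edge."""
    adj = {}
    for a, b in edge_list:
        adj.setdefault(a, []).append(b)
        adj.setdefault(b, []).append(a)
    return adj


def message_passing_step(state, adj):
    """One round: each node sums the values (messages) sent by its neighbors."""
    return {node: sum(state[n] for n in adj.get(node, [])) for node in state}


adj = neighbors(edges)
aggregated = message_passing_step(has_trait, adj)
print(aggregated)  # {'grandma': 1, 'mom': 2, 'uncle': 1, 'me': 0, 'sister': 0}
```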

Returning to the context of accessibility trees, this means that GNNs can efficiently distill content by understanding and leveraging the inherent structure and relationships within the tree. This capability allows them to identify and possibly omit non-essential sections based on the information flow within the tree, ensuring more accurate content distillation.

Our architecture heavily follows the encode-process-decode paradigm using a message-passing neural network to classify text nodes. The overall design is illustrated in the figure below. The tree representation of the article is the input to the model. We compute lightweight features based on bounding box information, text information, and accessibility roles. The GNN then propagates each node’s latent representation through the edges of the tree using a message-passing neural network. This propagation process allows nearby nodes, containers, and text elements to share contextual information with each other, enhancing the model’s understanding of the page’s structure and content. Each node then updates its current state based on the message received, providing a more informed basis for classifying the nodes. After a fixed number of message-passing steps, the now contextualized latent representations of the nodes are decoded into essential or non-essential classes. This approach enables the model to leverage both the inherent relationships in the tree and the hand-crafted features representing each node, thereby enriching the final classification.

A visual demonstration of the algorithm in action, processing an article on a mobile device. A graph neural network (GNN) is used to distill essential content from an article. 1. A tree representation of the article is extracted from the application. 2. Lightweight features are computed for each node, represented as vectors. 3. A message-passing neural network propagates information through the edges of the tree and updates each node representation. 4. Leaf nodes containing text content are classified as essential or non-essential content. 5. A decluttered version of the application is composed based on the GNN output.
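The encode-process-decode flow can be sketched in plain Python. This is not the production model. The feature vectors, latent size, and number of steps are assumptions. The weights are random and untrained, so the predicted labels are arbitrary. The sketch only shows the data flow: node features are encoded, propagated along tree edges for a fixed number of steps, and decoded into essential or non-essential classes for the text-bearing leaves.

```python
import math
import random

random.seed(0)
HIDDEN = 4  # hypothetical latent size


def linear(w, x):
    """Matrix-vector product with a list-of-rows weight matrix."""
    return [sum(wi * xi for wi, xi in zip(row, x)) for row in w]


def relu(v):
    return [max(0.0, x) for x in v]


def rand_matrix(rows, cols):
    return [[random.uniform(-1, 1) for _ in range(cols)] for _ in range(rows)]


def classify_tree(features, edges, leaves, steps=3):
    """Encode node features, pass messages along tree edges, decode leaf classes."""
    n_feat = len(next(iter(features.values())))
    w_enc = rand_matrix(HIDDEN, n_feat)      # untrained, for illustration only
    w_msg = rand_matrix(HIDDEN, HIDDEN)
    w_upd = rand_matrix(HIDDEN, 2 * HIDDEN)
    w_dec = rand_matrix(1, HIDDEN)

    # Encode: map each node's hand-crafted features to a latent vector.
    h = {n: relu(linear(w_enc, f)) for n, f in features.items()}

    adj = {n: [] for n in features}
    for a, b in edges:
        adj[a].append(b)
        adj[b].append(a)

    # Process: a fixed number of message-passing steps.
    for _ in range(steps):
        new_h = {}
        for n in h:
            msg = [0.0] * HIDDEN
            for m in adj[n]:
                msg = [s + v for s, v in zip(msg, linear(w_msg, h[m]))]
            # Update from the concatenation [current state, summed messages].
            new_h[n] = relu(linear(w_upd, h[n] + msg))
        h = new_h

    # Decode: a sigmoid score per leaf, thresholded into two classes.
    labels = {}
    for n in leaves:
        score = 1 / (1 + math.exp(-linear(w_dec, h[n])[0]))
        labels[n] = "essential" if score >= 0.5 else "non-essential"
    return labels


# Hypothetical features: [is_container, has_text, relative_font_size].
features = {
    "root": [1.0, 0.0, 0.0],
    "title": [0.0, 1.0, 0.9],
    "para": [0.0, 1.0, 0.2],
    "nav": [0.0, 1.0, 0.0],
}
edges = [("root", "title"), ("root", "para"), ("root", "nav")]
labels = classify_tree(features, edges, ["title", "para", "nav"])
```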

We deliberately restrict the feature set used by the model to improve its generalization across languages and to reduce inference latency on user devices. This was a unique challenge, as we needed to create a lightweight on-device model that preserves privacy.

Our final lightweight Android model has 64k parameters and is 334kB in size with a median latency of 800ms, while the Chrome model has 241k parameters, is 928kB in size, and has a 378ms median latency. By employing such on-device processing, we ensure that user data never leaves the device, reinforcing our responsible approach and commitment to user privacy. The features used in the model can be grouped into intermediate node features, leaf-node text features, and element position features. We performed feature engineering and feature selection to optimize the set of features for model performance and model size. The final model was transformed into TensorFlow Lite format to deploy as an on-device model on Android or Chrome.

Results

We trained the GNN for about 50 epochs on a single GPU. The performance of the Android model on the webpage and native application test sets is presented below:

The table presents the content distillation metrics in Android for webpages and native apps. We report precision, recall and F1-score for three classes: non-essential content, headline, and main body text, including macro average and weighted average by number of instances in each class. Node metrics assess the classification performance at the granularity of the accessibility tree node, which is analogous to a paragraph level. In contrast, word metrics evaluate classification at an individual word level, meaning each word within a node gets the same classification.

Android quality (node metrics)
                          Webpages                          Native apps
Class                     Precision   Recall   F1-score     Precision   Recall   F1-score
non-essential             0.9842      0.9846   0.9844       0.9744      0.9350   0.9543
headline                  0.9187      0.9784   0.9476       0.9183      0.8568   0.8865
main-text                 0.9223      0.9172   0.9197       0.8443      0.9424   0.8907
macro-average             0.9417      0.9600   0.9506       0.9124      0.9114   0.9105
weighted average          0.9736      0.9736   0.9736       0.9392      0.9353   0.9363
headline + main-text      0.9510      0.9683   0.9595       0.9473      0.9507   0.9490

In assessing the results’ quality on commonly visited webpage articles, an F1-score exceeding 0.9 for main-text (essentially paragraphs) corresponds to 88% of these articles being processed without missing any paragraphs. Furthermore, in over 95% of cases, the distillation proves to be valuable for readers. Put simply, the vast majority of readers will perceive the distilled content as both pertinent and precise, with errors or omissions being an infrequent occurrence.

The table below compares Chrome content distillation with other models, such as DOM Distiller and Mozilla Readability, on a set of English-language pages. We reuse metrics from machine translation to compare the quality of these models. The ground-truth main content serves as the reference text, and each model’s output serves as the hypothesis text. The results show that our model substantially outperforms the other DOM-based approaches.

The table presents the comparison between DOM Distiller, Mozilla Readability and the new Chrome model. We report text-based metrics, such as BLEU, CHRF and ROUGE, by comparing the main body text distilled from each model to a ground-truth text manually labeled by raters using our annotation policy.

Chrome model comparison on webpages
Metric         DOM Distiller   Mozilla Readability   Our Chrome model
BLEU           78.97           79.16                 94.59
CHRF           0.92            0.92                  0.98
ROUGE1         84.10           84.62                 95.13
ROUGE2         81.84           82.66                 94.81
ROUGE3         80.21           81.45                 94.60
ROUGEL         83.58           84.02                 95.04
ROUGEL-SUM     83.46           84.03                 95.04

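To make the text-based comparison concrete, the following sketch computes a simplified ROUGE-1 F1 score, which measures the unigram overlap between the ground-truth reference and a model’s hypothesis text. The sketch assumes lowercase whitespace tokenization with no stemming. It therefore will not exactly reproduce standard ROUGE implementations, and it is not necessarily the scoring setup used for the table. The ROUGE scores in the table appear to be on a 0 to 100 scale.

```python
from collections import Counter


def rouge1_f1(reference: str, hypothesis: str) -> float:
    """Unigram-overlap F1 between a reference and a hypothesis text.

    Simplified: lowercase whitespace tokenization, no stemming.
    """
    ref = Counter(reference.lower().split())
    hyp = Counter(hypothesis.lower().split())
    overlap = sum((ref & hyp).values())  # clipped counts of shared tokens
    if overlap == 0:
        return 0.0
    precision = overlap / sum(hyp.values())  # penalizes extra clutter in the output
    recall = overlap / sum(ref.values())     # penalizes missing main content
    return 2 * precision * recall / (precision + recall)
```

For example, a hypothesis that keeps all four reference words but adds two clutter words has precision 4/6 and recall 1, which gives an F1 of 0.8.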
The F1-score of the Chrome content distillation model for headline and main text content on the test sets of different widely spoken languages demonstrates that the Chrome model, in particular, is able to support a wide range of languages.

The table presents the per-language F1-scores of the Chrome model for the headline and main text classes. The language codes correspond to the following languages: German (de), English (en), Spanish (es), French (fr), Italian (it), Persian (fa), Japanese (ja), Korean (ko), Portuguese (pt), Vietnamese (vi), simplified Chinese (zh-Hans) and traditional Chinese (zh-Hant).

Chrome model F1-scores by language
Class       de    en    es    fr    it    fa    ja    ko    pt    vi    zh-Hans  zh-Hant  average
headline    0.91  0.97  0.99  0.98  0.97  0.89  0.97  0.98  0.99  0.98  0.97     0.93     0.96
main text   0.84  0.90  0.93  0.91  0.93  0.87  0.88  0.91  0.91  0.90  0.90     0.90     0.90

Conclusion

The digital age demands both streamlined content presentation and an unwavering commitment to user privacy. Our research highlights the effectiveness of Reading Mode in platforms like Android and Chrome, offering an innovative, data-driven approach to content parsing through Graph Neural Networks. Crucially, our lightweight on-device model ensures that content distillation occurs without compromising user data, with all processes executed locally. This not only enhances the reading experience but also reinforces our dedication to user privacy. As we navigate the evolving landscape of digital content consumption, our findings underscore the paramount importance of prioritizing the user in both experience and security.

Acknowledgements

This project is the result of joint work with Manuel Tragut, Mihai Popa, Abodunrinwa Toki, Abhanshu Sharma, Matt Sharifi, David Petrou and Blaise Aguera y Arcas. We sincerely thank our collaborators Gang Li and Yang Li. We are very grateful to Tom Small for assisting us in preparing the post.


Build a classification pipeline with Amazon Comprehend custom classification (Part I)


“Data locked away in text, audio, social media, and other unstructured sources can be a competitive advantage for firms that figure out how to use it”

Only 18% of organizations in a 2019 survey by Deloitte reported being able to take advantage of unstructured data. The majority of data, between 80% and 90%, is unstructured data. That is a big untapped resource that has the potential to give businesses a competitive edge if they can find out how to use it. It can be difficult to find insights from this data, particularly if efforts are needed to classify, tag, or label it. Amazon Comprehend custom classification can be useful in this situation. Amazon Comprehend is a natural-language processing (NLP) service that uses machine learning to uncover valuable insights and connections in text.

Document categorization or classification has significant benefits across business domains –

  • Improved search and retrieval – Categorizing documents into relevant topics or categories makes it much easier for users to search for and retrieve the documents they need. They can search within specific categories to narrow down results.
  • Knowledge management – Categorizing documents in a systematic way helps to organize an organization’s knowledge base. It makes it easier to locate relevant information and see connections between related content.
  • Streamlined workflows – Automatic document sorting can help streamline many business processes like processing invoices, customer support, or regulatory compliance. Documents can be automatically routed to the right people or workflows.
  • Cost and time savings – Manual document categorization is tedious, time-consuming, and expensive. AI techniques can take over this mundane task and categorize thousands of documents in a short time at a much lower cost.
  • Insight generation – Analyzing trends in document categories can provide useful business insights. For example, an increase in customer complaints in a product category could signify some issues that need to be addressed.
  • Governance and policy enforcement – Setting up document categorization rules helps to ensure that documents are classified correctly according to an organization’s policies and governance standards. This allows for better monitoring and auditing.
  • Personalized experiences – In contexts like website content, document categorization allows for tailored content to be shown to users based on their interests and preferences as determined from their browsing behavior. This can increase user engagement.

The complexity of developing a bespoke classification machine learning model varies depending on a variety of aspects such as data quality, algorithm, scalability, and domain knowledge, to mention a few. It’s essential to start with a clear problem definition, clean and relevant data, and gradually work through the different stages of model development. However, businesses can create their own unique machine learning models using Amazon Comprehend custom classification to automatically classify text documents into categories or tags, to meet business specific requirements and map to business technology and document categories. As human tagging or categorization is no longer necessary, this can save businesses a lot of time, money, and labor. We have made this process simple by automating the whole training pipeline.

In the first part of this multi-part blog series, you will learn how to create a scalable training pipeline and prepare training data for Comprehend custom classification models. We introduce a custom classifier training pipeline that can be deployed in your AWS account with a few clicks. We use the BBC news dataset and train a classifier to identify the class (for example, politics or sports) that a document belongs to. The pipeline enables your organization to respond rapidly to changes and train new models without starting from scratch each time. You can also easily scale up and train multiple models based on demand.

Prerequisites

  • An active AWS account
  • Access to Amazon Comprehend, Amazon S3, AWS Lambda, AWS Step Functions, Amazon SNS, and AWS CloudFormation
  • Training data (semi-structured or text), prepared as described in the following section
  • Basic knowledge of Python and machine learning in general

Prepare training data

This solution accepts input in either text format (for example, CSV) or semi-structured format (for example, PDF).

Text input

Amazon Comprehend custom classification supports two modes: multi-class and multi-label.

In multi-class mode, each document has exactly one class assigned to it. The training data should be prepared as a two-column CSV file, with each line containing a single class and the text of a document that demonstrates the class.

CLASS, Text of document 1
CLASS, Text of document 2
...

Example for BBC news dataset:

Business, Europe blames US over weak dollar...
Tech, Cabs collect mountain of mobiles...
...

In multi-label mode, each document has at least one class assigned to it, but can have more. The training data should be a two-column CSV file, with each line containing one or more classes and the text of the training document. Multiple classes are indicated by placing a delimiter between each class (the pipe character, |, in the following example).

CLASS, Text of document 1
CLASS|CLASS|CLASS, Text of document 2
...

Do not include a header in the CSV file for either training mode.
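The two CSV layouts above can be produced with Python’s standard csv module. This sketch makes a few assumptions: the function name is hypothetical, and the pipe character is used as the multi-label delimiter to match the example. Using csv.writer rather than joining strings by hand quotes any document text that contains commas, quotes, or newlines, which would otherwise shift the columns.

```python
import csv
import io


def to_comprehend_csv(rows, multi_label=False, delimiter="|"):
    """Serialize (labels, text) pairs as header-less training CSV.

    rows: iterable of (label, text) for multi-class mode,
          or (list of labels, text) for multi-label mode.
    """
    buf = io.StringIO()
    writer = csv.writer(buf, lineterminator="\n")
    for labels, text in rows:
        label_field = delimiter.join(labels) if multi_label else labels
        writer.writerow([label_field, text])  # no header row is written
    return buf.getvalue()


print(to_comprehend_csv([("Tech", "Cabs collect mountain of mobiles...")]), end="")
# Tech,Cabs collect mountain of mobiles...
```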

Semi-structured input

Since 2023, Amazon Comprehend also supports training models using semi-structured documents. The training data for semi-structured input consists of a set of labeled documents, which can be pre-identified documents from a document repository that you already have access to. The following is an example of the annotations CSV data required for training:

CLASS, document1.pdf, 1
CLASS, document1.pdf, 2
...

The annotations CSV file contains three columns: the first column contains the label for the document, the second column is the document name (that is, the file name), and the last column is the page number of the document that you want to include in the training dataset. In most cases, if the annotations CSV file is located in the same folder as all the other documents, you only need to specify the document name in the second column. However, if the CSV file is in a different location, you need to specify the path to the document in the second column, such as path/to/prefix/document1.pdf.
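Similarly, the annotations file for semi-structured input can be generated from (label, file name, page) triples. In this hypothetical helper, the optional prefix handles the case where the CSV is not stored alongside the documents, as described above. The output mirrors the header-less example and does not come from any official tooling.

```python
import csv
import io


def annotations_csv(pages, prefix=""):
    """Build a header-less annotations CSV for semi-structured training data.

    pages: iterable of (label, file name, page number)
    prefix: path to prepend when the CSV is not stored next to the documents
    """
    buf = io.StringIO()
    writer = csv.writer(buf, lineterminator="\n")
    for label, file_name, page in pages:
        path = f"{prefix.rstrip('/')}/{file_name}" if prefix else file_name
        writer.writerow([label, path, int(page)])
    return buf.getvalue()
```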

For details on how to prepare your training data, refer to the Amazon Comprehend documentation.

Solution overview

  1. The Amazon Comprehend training pipeline starts when training data (a .csv file for text input, or an annotations .csv file for semi-structured input) is uploaded to a dedicated Amazon Simple Storage Service (Amazon S3) bucket.
  2. An Amazon S3 trigger invokes an AWS Lambda function every time an object is uploaded to the specified Amazon S3 location. The function retrieves the source bucket name and the key of the uploaded object and passes them to the training AWS Step Functions workflow.
  3. After the training Step Functions workflow receives the training data bucket name and object key as input parameters, it runs a custom model training workflow as a series of Lambda functions:
    1. StartComprehendTraining: This Lambda function defines a ComprehendClassifier object based on the type of input files (text or semi-structured) and then starts an Amazon Comprehend custom classification training job by calling the create_document_classifier Application Programming Interface (API), which returns the training job's Amazon Resource Name (ARN). It then checks the status of the training job by calling the describe_document_classifier API. Finally, it returns the training job ARN and job status as output to the next stage of the training workflow.
    2. GetTrainingJobStatus: This Lambda function checks the status of the training job every 15 minutes by calling the describe_document_classifier API, until the training job completes or fails.
    3. GenerateMultiClass or GenerateMultiLabel: If you select yes for the performance report when launching the stack, one of these two Lambda functions analyzes your Amazon Comprehend model outputs, generates a per-class performance analysis, and saves it to Amazon S3.
    4. GenerateMultiClass: This Lambda function is called if your input is in multi-class mode and you select yes for the performance report.
    5. GenerateMultiLabel: This Lambda function is called if your input is in multi-label mode and you select yes for the performance report.
  4. After training completes successfully, the solution generates the following outputs:
    1. Custom Classification Model: A trained model ARN is available in your account for future inference work.
    2. Confusion Matrix [Optional]: Depending on your selection, a confusion matrix (confusion_matrix.json) is available in the user-defined output Amazon S3 path.
    3. Amazon Simple Notification Service notification [Optional]: Depending on your initial selection, a notification email about the training job status is sent to the subscribers.
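The trigger and training steps above can be sketched in Python with boto3-style calls. This is an illustration under assumptions, not the solution's actual Lambda code: the function names, the `STATE_MACHINE_ARN` environment variable, and the classifier name and role are hypothetical, and the semi-structured branch uses the `DocumentType` and `Documents` fields of `InputDataConfig`, which you should confirm against the CreateDocumentClassifier API reference. Clients are passed in as arguments, so you would supply `boto3.client("stepfunctions")` and `boto3.client("comprehend")` in practice and fakes in tests.

```python
import json
import os
from urllib.parse import unquote_plus


def parse_s3_event(event):
    """Step 2: pull the bucket name and object key out of an S3 put event."""
    record = event["Records"][0]["s3"]
    # S3 URL-encodes object keys in event notifications.
    return record["bucket"]["name"], unquote_plus(record["object"]["key"])


def start_training_workflow(event, sfn_client):
    """Step 2: hand the uploaded object to the training state machine."""
    bucket, key = parse_s3_event(event)
    return sfn_client.start_execution(
        stateMachineArn=os.environ["STATE_MACHINE_ARN"],  # assumed env var
        input=json.dumps({"bucket": bucket, "key": key}),
    )


def build_classifier_request(bucket, key, *, name, role_arn, output_uri,
                             language="en", multi_class=True,
                             label_delimiter="|", semi_structured=False):
    """Step 3.1: keyword arguments for comprehend.create_document_classifier."""
    input_config = {"DataFormat": "COMPREHEND_CSV", "S3Uri": f"s3://{bucket}/{key}"}
    if not multi_class:
        input_config["LabelDelimiter"] = label_delimiter
    if semi_structured:
        # Here the CSV is the annotations file; assume the PDFs share its prefix.
        prefix = key.rsplit("/", 1)[0] if "/" in key else ""
        input_config["DocumentType"] = "SEMI_STRUCTURED_DOCUMENT"
        input_config["Documents"] = {"S3Uri": f"s3://{bucket}/{prefix}"}
    return {
        "DocumentClassifierName": name,
        "DataAccessRoleArn": role_arn,
        "InputDataConfig": input_config,
        "OutputDataConfig": {"S3Uri": output_uri},
        "LanguageCode": language,
        "Mode": "MULTI_CLASS" if multi_class else "MULTI_LABEL",
    }


def training_state(comprehend_client, classifier_arn):
    """Step 3.2: collapse the classifier status into running/complete/failed."""
    props = comprehend_client.describe_document_classifier(
        DocumentClassifierArn=classifier_arn
    )["DocumentClassifierProperties"]
    status = props["Status"]
    if status in ("TRAINED", "TRAINED_WITH_WARNING"):
        return "complete"
    if status in ("IN_ERROR", "STOPPED"):
        return "failed"
    return "running"  # e.g. SUBMITTED or TRAINING
```

A polling step such as GetTrainingJobStatus would call `training_state` on each invocation and let the state machine wait (15 minutes in this solution) while it returns `"running"`.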

Walkthrough

Launching the solution

To deploy your pipeline, complete the following steps:

  1. Choose the Launch Stack button.

  2. Choose Next.

  3. Specify the pipeline details with the options that fit your use case.

Information for each stack detail:

  • Stack name (Required) – The name you specify for this AWS CloudFormation stack. The name must be unique in the Region in which you’re creating it.
  • Q01ClassifierInputBucketName (Required) – The name of the Amazon S3 bucket that stores your input data. It must be globally unique; the AWS CloudFormation stack creates the bucket while it’s being launched.
  • Q02ClassifierOutputBucketName (Required) – The name of the Amazon S3 bucket that stores outputs from Amazon Comprehend and the pipeline. It must also be globally unique.
  • Q03InputFormat – A dropdown selection. Based on your data input format, choose text (if your training data is CSV files) or semi-structure (if your training data is semi-structured, such as PDF files).
  • Q04Language – A dropdown selection for the language of the documents, chosen from the supported list. Note that currently only English is supported if your input format is semi-structured.
  • Q05MultiClass – A dropdown selection. Select yes if your input is in multi-class mode; otherwise, select no.
  • Q06LabelDelimiter – Only required if your Q05MultiClass answer is no. This is the delimiter used in your training data to separate each class.
  • Q07ValidationDataset – A dropdown selection. Change the answer to yes if you want to test the performance of the trained classifier with your own test data.
  • Q08S3ValidationPath – Only required if your Q07ValidationDataset answer is yes.
  • Q09PerformanceReport – A dropdown selection. Select yes if you want to generate a class-level performance report after model training. The report is saved in the output bucket you specified in Q02ClassifierOutputBucketName.
  • Q10EmailNotification – A dropdown selection. Select yes if you want to receive a notification after the model is trained.
  • Q11EmailID – Enter a valid email address for receiving the performance report notification. Note that you have to confirm the subscription from your email after the AWS CloudFormation stack is launched before you can receive a notification when training is completed.
  4. In the Configure stack options section, add optional tags, permissions, and other advanced settings.

  5. Choose Next.
  6. Review the stack details and select I acknowledge that AWS CloudFormation might create AWS IAM resources.

  7. Choose Submit. This initiates pipeline deployment in your AWS account.
  8. After the stack is deployed successfully, you can start using the pipeline. Create a /training-data folder under the Amazon S3 location you specified for input. Note: Amazon S3 automatically applies server-side encryption (SSE-S3) to each new object unless you specify a different encryption option. Refer to Data protection in Amazon S3 for more details on data protection and encryption in Amazon S3.

  9. Upload your training data to the folder. (If the training data is semi-structured, upload all the PDF files before uploading the .csv label file.)
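If you prefer to script these console steps, a boto3-based sketch could look like the following. It is an illustration under assumptions, not part of the deployed solution: the helper names are hypothetical, the yes/no values are assumed to match the template's dropdown options, and clients are passed in so you can supply `boto3.client("cloudformation")` and `boto3.client("s3")`. `stack_parameters` enforces the conditional rules from the stack details list and returns the `ParameterKey`/`ParameterValue` shape that `create_stack` expects; `upload_training_data` follows step 9 by uploading the documents before the .csv label file.

```python
from pathlib import Path

# Parameters that become mandatory when another answer has a given value,
# as described in the stack details list.
CONDITIONAL = [
    ("Q05MultiClass", "no", "Q06LabelDelimiter"),
    ("Q07ValidationDataset", "yes", "Q08S3ValidationPath"),
    ("Q10EmailNotification", "yes", "Q11EmailID"),
]


def stack_parameters(answers):
    """Validate the stack answers and convert them for create_stack."""
    required = ["Q01ClassifierInputBucketName", "Q02ClassifierOutputBucketName"]
    required += [dep for key, value, dep in CONDITIONAL if answers.get(key) == value]
    missing = [key for key in required if not answers.get(key)]
    if missing:
        raise ValueError(f"missing required parameters: {missing}")
    return [{"ParameterKey": k, "ParameterValue": str(v)} for k, v in answers.items()]


def upload_training_data(s3_client, bucket, local_dir, prefix="training-data"):
    """Upload documents first and .csv files last, so the label file only
    lands once every document it references is already in Amazon S3."""
    files = sorted(p for p in Path(local_dir).iterdir() if p.is_file())
    # sorted() is stable and False < True, so non-CSV files come first.
    ordered = sorted(files, key=lambda p: p.suffix.lower() == ".csv")
    keys = []
    for path in ordered:
        key = f"{prefix}/{path.name}"
        s3_client.upload_file(str(path), bucket, key)
        keys.append(key)
    return keys


# Example wiring (requires AWS credentials; the template URL is a placeholder):
# import boto3
# boto3.client("cloudformation").create_stack(
#     StackName="comprehend-training-pipeline",
#     TemplateURL="<template URL behind the Launch Stack button>",
#     Parameters=stack_parameters({...}),
#     Capabilities=["CAPABILITY_IAM"],  # or CAPABILITY_NAMED_IAM, per the template
# )
# upload_training_data(boto3.client("s3"), "<input bucket>", "./my-training-data")
```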

You’re done! You’ve successfully deployed your pipeline, and you can check the pipeline status in the deployed Step Functions state machine. When training finishes, the trained model appears in your Amazon Comprehend custom classification panel.

If you choose the model and its version in the Amazon Comprehend console, you can see more details about the model you just trained. These include the mode you selected, which corresponds to the Q05MultiClass option, the number of labels, and the number of training and test documents in your training data. You can also check the overall performance further down the page; however, to check the detailed performance for each class, refer to the performance report generated by the deployed pipeline.

Service quotas

Your AWS account has default quotas for Amazon Comprehend and, if your inputs are in semi-structured format, for Amazon Textract. To view the service quotas, refer to the quotas documentation for Amazon Comprehend and for Amazon Textract.

Clean up

To avoid incurring ongoing charges, delete the resources you created as part of this solution when you’re done.

  1. On the Amazon S3 console, manually delete the contents of the buckets you created for input and output data.
  2. On the AWS CloudFormation console, choose Stacks in the navigation pane.
  3. Select the main stack and choose Delete. This automatically deletes the deployed stack.

  4. Your trained Amazon Comprehend custom classification model remains in your account. If you no longer need it, delete the model in the Amazon Comprehend console.

Conclusion

In this post, we showed you a scalable training pipeline for Amazon Comprehend custom classification models that provides an automated way to train new models efficiently. The AWS CloudFormation template provided makes it possible for you to create your own text classification models with little effort, scaling with demand. The solution adopts the recently announced Euclid feature and accepts inputs in text or semi-structured format.

Now we encourage you, our readers, to test these tools. You can find more details about preparing training data and understanding the custom classifier metrics in the Amazon Comprehend documentation. Try it out and see firsthand how it can streamline your model training process and improve efficiency. Please share your feedback with us!


About the Authors

Sandeep Singh is a Senior Data Scientist with AWS Professional Services. He is passionate about helping customers innovate and achieve their business objectives by developing state-of-the-art AI/ML powered solutions. He is currently focused on Generative AI, LLMs, prompt engineering, and scaling Machine Learning across enterprises. He brings recent AI advancements to create value for customers.

Yanyan Zhang is a Senior Data Scientist in the Energy Delivery team with AWS Professional Services. She is passionate about helping customers solve real problems with AI/ML knowledge. Recently, her focus has been on exploring the potential of Generative AI and LLM. Outside of work, she loves traveling, working out and exploring new things.

Wrick Talukdar is a Senior Architect with the Amazon Comprehend Service team. He works with AWS customers to help them adopt machine learning on a large scale. Outside of work, he enjoys reading and photography.

Fine-tune Falcon 7B and other LLMs on Amazon SageMaker with @remote decorator

Today, generative AI models cover a variety of tasks, including text summarization, Q&A, and image and video generation. To improve the quality of the output, approaches like n-shot learning, prompt engineering, Retrieval Augmented Generation (RAG), and fine-tuning are used. Fine-tuning allows you to adjust these generative AI models to achieve improved performance on your domain-specific tasks.

With Amazon SageMaker, you can now run a SageMaker training job simply by annotating your Python code with the @remote decorator. The SageMaker Python SDK automatically translates your existing workspace environment, and any associated data processing code and datasets, into a SageMaker training job that runs on the training platform. This lets you write code in a more natural, object-oriented way while still using SageMaker capabilities to run training jobs on a remote cluster with minimal changes.

In this post, we show how to fine-tune the Falcon-7B foundation model (FM) using the @remote decorator from the SageMaker Python SDK. The solution also uses the Hugging Face parameter-efficient fine-tuning (PEFT) library and quantization techniques through bitsandbytes to support fine-tuning. The code presented in this post can also be used to fine-tune other FMs, such as Llama-2 13B.

In full precision, this model may not fit into the memory of a single Graphics Processing Unit (GPU), or even several GPUs, and may require a bigger instance. Hence, to fine-tune this model without increasing cost, we use a technique known as Quantized LLMs with Low-Rank Adapters (QLoRA). QLoRA is an efficient fine-tuning approach that reduces the memory usage of LLMs while maintaining very good performance.
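A back-of-the-envelope calculation shows why quantization matters. The sketch below counts model weights only, assuming a nominal 7 billion parameters (Falcon-7B's exact count differs slightly). Activations, gradients, optimizer state, and the LoRA adapters all add memory on top of this, so treat the results as lower bounds rather than measured figures.

```python
def weight_memory_gb(num_params, bits_per_param):
    """Memory needed for the weights alone, in gigabytes (10**9 bytes)."""
    return num_params * bits_per_param / 8 / 1e9


NUM_PARAMS = 7e9  # nominal model size used for this estimate
for precision, bits in [("fp32", 32), ("bf16", 16), ("4-bit", 4)]:
    print(f"{precision}: {weight_memory_gb(NUM_PARAMS, bits):.1f} GB")
# fp32: 28.0 GB
# bf16: 14.0 GB
# 4-bit: 3.5 GB
```

Going from 16-bit to 4-bit weights cuts the weight footprint by a factor of four, which is what leaves room for activations and adapter training on the GPUs.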

Advantages of using @remote decorator

Before going further, let’s look at how the @remote decorator improves developer productivity when working with SageMaker:

  • The @remote decorator triggers a training job directly from native Python code, without explicitly invoking SageMaker Estimators and SageMaker input channels.
  • It lowers the barrier to entry for developers training models on SageMaker.
  • There’s no need to switch integrated development environments (IDEs). Continue writing code in the IDE of your choice and invoke SageMaker training jobs.
  • There’s no need to learn about containers. Continue providing dependencies in a requirements.txt file and supply it to the @remote decorator.

Prerequisites

An AWS account is needed with an AWS Identity and Access Management (AWS IAM) role that has permissions to manage resources created as part of the solution. For details, refer to Creating an AWS account.

In this post, we use Amazon SageMaker Studio with the Data Science 3.0 image and an ml.t3.medium fast launch instance. However, you can use any integrated development environment (IDE) of your choice; you just need to set up your AWS Command Line Interface (AWS CLI) credentials correctly. For more information, refer to Configure the AWS CLI.

To fine-tune Falcon-7B, this post uses an ml.g5.12xlarge instance. Make sure your AWS account has sufficient capacity for this instance type.

Clone this GitHub repository to replicate the solution demonstrated in this post.

Solution overview

  1. Install prerequisites for fine-tuning the Falcon-7B model
  2. Set up remote decorator configurations
  3. Preprocess the dataset containing AWS services FAQs
  4. Fine-tune Falcon-7B on AWS services FAQs
  5. Test the fine-tuned model on sample questions related to AWS services

1. Install prerequisites for fine-tuning the Falcon-7B model

Launch the notebook falcon-7b-qlora-remote-decorator_qa.ipynb in SageMaker Studio, selecting Data Science as the image and Python 3 as the kernel. Install all the required libraries listed in requirements.txt. A few of the libraries also need to be installed on the notebook instance itself, to perform the dataset processing and to trigger the SageMaker training job.

%pip install -r requirements.txt

%pip install -q -U transformers==4.31.0
%pip install -q -U datasets==2.13.1
%pip install -q -U peft==0.4.0
%pip install -q -U accelerate==0.21.0
%pip install -q -U bitsandbytes==0.40.2
%pip install -q -U boto3
%pip install -q -U sagemaker==2.154.0
%pip install -q -U scikit-learn

2. Set up remote decorator configurations

Create a configuration file that specifies all the settings for the Amazon SageMaker training job. The @remote decorator reads this file when it runs the training job. The file contains settings such as dependencies, the training image, the instance type, and the execution role to use for the training job. For a detailed reference of all the settings supported by the config file, check out Configuring and using defaults with the SageMaker Python SDK.

SchemaVersion: '1.0'
SageMaker:
  PythonSDK:
    Modules:
      RemoteFunction:
        Dependencies: ./requirements.txt
        ImageUri: '{aws_account_id}.dkr.ecr.{region}.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04'
        InstanceType: ml.g5.12xlarge
        RoleArn: arn:aws:iam::111122223333:role/ExampleSageMakerRole

It’s not mandatory to use the config.yaml file to work with the @remote decorator; it’s just a cleaner way to supply all configurations to the decorator. It keeps SageMaker and AWS related parameters out of the code, with a one-time effort to set up a config file that can be shared across team members. All the configurations could also be supplied directly as decorator arguments, but that reduces readability and maintainability in the long run. The configuration file can also be created by an administrator and shared with all the users in an environment.
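If you keep config.yaml next to your notebook, the SDK also needs to know where to find it. A minimal sketch, assuming the SageMaker Python SDK's `SAGEMAKER_USER_CONFIG_OVERRIDE` environment variable, which points the SDK at a directory containing a user config file; `write_remote_config` is a hypothetical helper, and the image URI and role ARN are placeholders for your own values. Set the variable early in the notebook, before importing sagemaker, so the defaults are picked up.

```python
import os
from pathlib import Path

CONFIG_TEMPLATE = """\
SchemaVersion: '1.0'
SageMaker:
  PythonSDK:
    Modules:
      RemoteFunction:
        Dependencies: ./requirements.txt
        ImageUri: '{image_uri}'
        InstanceType: {instance_type}
        RoleArn: {role_arn}
"""


def write_remote_config(directory, image_uri, instance_type, role_arn):
    """Write config.yaml into `directory` and point the SDK at that directory."""
    path = Path(directory) / "config.yaml"
    path.write_text(CONFIG_TEMPLATE.format(
        image_uri=image_uri, instance_type=instance_type, role_arn=role_arn))
    # Assumed SDK behavior: user defaults are read from this directory.
    os.environ["SAGEMAKER_USER_CONFIG_OVERRIDE"] = str(Path(directory).resolve())
    return path
```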

3. Preprocess the dataset containing AWS services FAQs

The next step is to load and preprocess the dataset to make it ready for the training job. First, let’s have a look at the dataset:

The dataset contains FAQs for one of the AWS services. As part of QLoRA, bitsandbytes is used to quantize the frozen LLM to 4-bit precision, and LoRA adapters are then attached to it.

Create a prompt template to convert each FAQ sample to a prompt format:

from random import randint

# custom instruct prompt start
prompt_template = "{question}\n---\nAnswer:\n{answer}{eos_token}"

# template dataset to add prompt to each sample
def template_dataset(sample):
    sample["text"] = prompt_template.format(question=sample["question"],
                                            answer=sample["answers"],
                                            eos_token=tokenizer.eos_token)
    return sample

The next step is to convert the inputs (text) to token IDs. This is done by a Hugging Face Transformers tokenizer.

from transformers import AutoTokenizer

model_id = "tiiuae/falcon-7b"

tokenizer = AutoTokenizer.from_pretrained(model_id)
# Falcon's tokenizer defines no pad token, so reuse the EOS token for padding
tokenizer.pad_token = tokenizer.eos_token

Now use the template_dataset function to convert all the FAQs to the prompt format, and set up the train and test datasets.
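With Hugging Face datasets, this typically amounts to `dataset.map(template_dataset, remove_columns=list(dataset.features))` followed by `dataset.train_test_split(test_size=0.1)`. To make the result easy to inspect, the sketch below reproduces both steps in plain Python on a list of dicts. The FAQ record, the `<eos>` stand-in for `tokenizer.eos_token`, the 10% hold-out, and the helper names are all assumptions for illustration.

```python
import random

# The prompt format from template_dataset, with real newline characters.
PROMPT_TEMPLATE = "{question}\n---\nAnswer:\n{answer}{eos_token}"


def format_samples(samples, eos_token):
    """Render each FAQ record into the single `text` field used for training."""
    return [
        {"text": PROMPT_TEMPLATE.format(question=s["question"],
                                        answer=s["answers"],
                                        eos_token=eos_token)}
        for s in samples
    ]


def split_records(records, test_size=0.1, seed=42):
    """Shuffle deterministically, then hold out `test_size` of the records."""
    shuffled = list(records)
    random.Random(seed).shuffle(shuffled)
    n_test = max(1, int(len(shuffled) * test_size))
    return shuffled[n_test:], shuffled[:n_test]


faqs = [{"question": "What is Amazon S3?", "answers": "Object storage."}]  # hypothetical
print(format_samples(faqs, eos_token="<eos>")[0]["text"])
# What is Amazon S3?
# ---
# Answer:
# Object storage.<eos>
```

Appending the EOS token to every sample teaches the model where an answer ends, which matters later when generating free-form answers.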

4. Fine-tune Falcon-7B on AWS services FAQs

Now you can prepare the training script, define the training function train_fn, and put the @remote decorator on the function.

The training function does the following:

  • Tokenizes the train and test datasets
  • Sets up BitsAndBytesConfig, which specifies that the model should be loaded in 4-bit precision but converted to bfloat16 for computation
  • Loads the model
  • Finds the target modules for the LoRA update matrices by using the utility method find_all_linear_names
  • Creates the LoRA configuration, which specifies the rank of the update matrices (r), the scaling factor (lora_alpha), the modules to apply the LoRA update matrices to (target_modules), the dropout probability for LoRA layers (lora_dropout), task_type, and so on
  • Starts the training and evaluation

import bitsandbytes as bnb

def find_all_linear_names(hf_model):
    # Collect the short names of all 4-bit linear layers; these become the LoRA target modules
    lora_module_names = set()
    for name, module in hf_model.named_modules():
        if isinstance(module, bnb.nn.Linear4bit):
            names = name.split(".")
            lora_module_names.add(names[0] if len(names) == 1 else names[-1])

    # Exclude the output head from the LoRA targets
    if "lm_head" in lora_module_names:
        lora_module_names.remove("lm_head")
    return list(lora_module_names)

from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from sagemaker.remote_function import remote
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
import transformers

# Start training
@remote(volume_size=50)
def train_fn(
        model_name,
        train_ds,
        test_ds,
        lora_r=8,
        lora_alpha=32,
        lora_dropout=0.05,
        per_device_train_batch_size=8,
        per_device_eval_batch_size=8,
        learning_rate=2e-4,
        num_train_epochs=1
):
    # tokenize the train and test datasets; use the function arguments, not notebook globals
    lm_train_dataset = train_ds.map(
        lambda sample: tokenizer(sample["text"]), batched=True, batch_size=24, remove_columns=list(train_ds.features)
    )

    lm_test_dataset = test_ds.map(
        lambda sample: tokenizer(sample["text"]), batched=True, remove_columns=list(test_ds.features)
    )

    # Print total number of samples
    print(f"Total number of train samples: {len(lm_train_dataset)}")

    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16
    )
    # Falcon requires you to allow remote code execution. This is because the model uses a new architecture that is not part of transformers yet.
    # The code is provided by the model authors in the repo.
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        trust_remote_code=True,
        quantization_config=bnb_config,
        device_map="auto")

    model.gradient_checkpointing_enable()
    model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=True)

    # get LoRA target modules
    modules = find_all_linear_names(model)
    print(f"Found {len(modules)} modules to apply LoRA to: {modules}")

    config = LoraConfig(
        r=lora_r,
        lora_alpha=lora_alpha,
        target_modules=modules,
        lora_dropout=lora_dropout,
        bias="none",
        task_type="CAUSAL_LM"
    )

    model = get_peft_model(model, config)
    # PeftModel reports how many parameters are trainable versus frozen
    model.print_trainable_parameters()

    trainer = transformers.Trainer(
        model=model,
        train_dataset=lm_train_dataset,
        eval_dataset=lm_test_dataset,
        args=transformers.TrainingArguments(
            per_device_train_batch_size=per_device_train_batch_size,
            per_device_eval_batch_size=per_device_eval_batch_size,
            logging_steps=2,
            num_train_epochs=num_train_epochs,
            learning_rate=learning_rate,
            bf16=True,
            save_strategy="no",
            output_dir="outputs"
        ),
        data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),
    )
    model.config.use_cache = False

    trainer.train()
    trainer.evaluate()

    model.save_pretrained("/opt/ml/model")

Then invoke train_fn():

train_fn(model_id, train_dataset, test_dataset)

The fine-tuning job runs on the Amazon SageMaker training cluster. Wait for the job to finish.
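The LoRA settings passed to train_fn (lora_r=8, lora_alpha=32) explain why the job stays cheap. For a linear layer with weight W of shape d_out x d_in, LoRA freezes W and trains two small matrices, A (r x d_in) and B (d_out x r), adding r(d_in + d_out) parameters, and PEFT scales the update B A by lora_alpha / r. The square 4096 x 4096 layer below is an illustrative shape, not one of Falcon-7B's actual layer sizes.

```python
def lora_trainable_params(d_in, d_out, r):
    """Parameters LoRA adds to one d_in -> d_out linear layer.

    The frozen weight W (d_out x d_in) is left untouched; LoRA trains
    A (r x d_in) and B (d_out x r), so the update B @ A has rank <= r.
    """
    return r * d_in + d_out * r


def lora_scaling(lora_alpha, r):
    """Factor applied to the low-rank update B @ A (lora_alpha / r)."""
    return lora_alpha / r


d_in = d_out = 4096  # illustrative layer shape
full = d_in * d_out
added = lora_trainable_params(d_in, d_out, r=8)
print(f"full: {full:,}  LoRA r=8: {added:,}  ratio: {added / full:.2%}")
# full: 16,777,216  LoRA r=8: 65,536  ratio: 0.39%
print(lora_scaling(lora_alpha=32, r=8))  # 4.0
```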

5. Test the fine-tuned model on sample questions related to AWS services

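Now it’s time to run some tests on the model. First, load the model. The following code expects the fine-tuned adapter weights, which train_fn saves to /opt/ml/model and SageMaker uploads to Amazon S3, to be available locally in ./model: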

from peft import PeftModel, PeftConfig
import torch
from transformers import AutoModelForCausalLM

device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'

config = PeftConfig.from_pretrained("./model")
model = AutoModelForCausalLM.from_pretrained(config.base_model_name_or_path, trust_remote_code=True)
model = PeftModel.from_pretrained(model, "./model")
model.to(device)

Now load a sample question from the training dataset to see the original answer, and then ask the tuned model the same question to compare the answers.
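A minimal sketch of the comparison step, assuming the same prompt template used for training: the prompt stops right after "Answer:" so that the model writes the answer, and a helper strips the echoed prompt from the decoded output. The helper names are hypothetical, and the commented generation call assumes a tokenizer loaded with `AutoTokenizer.from_pretrained`, which the loading code above does not create.

```python
# The training prompt format; at inference time the answer and EOS are left empty.
PROMPT_TEMPLATE = "{question}\n---\nAnswer:\n{answer}{eos_token}"
ANSWER_MARKER = "\n---\nAnswer:\n"


def build_inference_prompt(question):
    """Stop the prompt right after 'Answer:' so the model writes the answer."""
    return PROMPT_TEMPLATE.format(question=question, answer="", eos_token="")


def extract_answer(generated_text):
    """Drop the echoed prompt and keep only the generated answer."""
    _, marker, answer = generated_text.partition(ANSWER_MARKER)
    return answer.strip() if marker else generated_text.strip()


# Example with the model loaded above (tokenizer loading is an assumption):
# from transformers import AutoTokenizer
# tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)
# inputs = tokenizer(build_inference_prompt(question), return_tensors="pt").to(device)
# output_ids = model.generate(input_ids=inputs["input_ids"],
#                             attention_mask=inputs["attention_mask"],
#                             max_new_tokens=256)
# print(extract_answer(tokenizer.decode(output_ids[0], skip_special_tokens=True)))
```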

Here is a sample question from the training set and the original answer:

Now, the same question is asked to the tuned Falcon-7B model:

This concludes the implementation of fine-tuning Falcon-7B on the AWS services FAQ dataset using the @remote decorator from the Amazon SageMaker Python SDK.

Cleaning up

Complete the following steps to clean up your resources:

  • Shut down the Amazon SageMaker Studio instances to avoid incurring additional costs.
  • Clean up your Amazon Elastic File System (Amazon EFS) directory by clearing the Hugging Face cache directory:
    rm -R ~/.cache/huggingface/hub

Conclusion

In this post, we showed you how to use the @remote decorator to fine-tune the Falcon-7B model with QLoRA and Hugging Face PEFT with bitsandbytes, without making significant changes to the training notebook, while using Amazon SageMaker capabilities to run training jobs on a remote cluster.

All the code shown in this post to fine-tune Falcon-7B is available in the GitHub repository. The repository also contains a notebook showing how to fine-tune Llama-13B.

As a next step, we encourage you to check out the @remote decorator functionality and the Python SDK API and use them in the environment and IDE of your choice. Additional examples are available in the amazon-sagemaker-examples repository to get you started quickly.


About the Authors

Bruno Pistone is an AI/ML Specialist Solutions Architect for AWS based in Milan. He works with large customers, helping them to deeply understand their technical needs and design AI and Machine Learning solutions that make the best use of the AWS Cloud and the Amazon Machine Learning stack. His expertise includes Machine Learning end to end, Machine Learning Industrialization, and Generative AI. He enjoys spending time with his friends and exploring new places, as well as travelling to new destinations.

Vikesh Pandey is a Machine Learning Specialist Solutions Architect at AWS, helping customers from financial industries design and build solutions on generative AI and ML. Outside of work, Vikesh enjoys trying out different cuisines and playing outdoor sports.

Simplify access to internal information using Retrieval Augmented Generation and LangChain Agents

This post takes you through the most common challenges that customers face when searching internal documents, and gives you concrete guidance on how AWS services can be used to create a generative AI conversational bot that makes internal information more useful.

Unstructured data accounts for 80% of all the data found within organizations, consisting of repositories of manuals, PDFs, FAQs, emails, and other documents that grow daily. Businesses today rely on continuously growing repositories of internal information, and problems arise when the amount of unstructured data becomes unmanageable. Often, users find themselves reading and checking many different internal sources to find the answers they need.

Internal question and answer forums can help users get highly specific answers, but they also involve longer wait times. In the case of company-specific internal FAQs, long wait times result in lower employee productivity. Question and answer forums are difficult to scale because they rely on manually written answers. With generative AI, there is currently a paradigm shift in how users search for and find information. The next logical step is to use generative AI to condense large documents into smaller, bite-sized information for easier user consumption. Instead of spending a long time reading text or waiting for answers, users can generate summaries in real time based on multiple existing repositories of internal information.

Solution overview

The solution allows customers to retrieve curated responses to questions asked about internal documents by using a transformer model to generate answers to questions about data that it has not been trained on, a technique known as zero-shot prompting. By adopting this solution, customers can gain the following benefits:

  • Find accurate answers to questions based on existing sources of internal documents
  • Reduce the time users spend searching for answers by using Large Language Models (LLMs) to provide near-immediate answers to complex queries using documents with the most updated information
  • Search previously answered questions through a centralized dashboard
  • Reduce stress caused by spending time manually reading information to look for answers

Retrieval Augmented Generation (RAG)

Retrieval Augmented Generation (RAG) reduces some of the shortcomings of LLM-based queries by finding the answers in your knowledge base and using the LLM to summarize the documents into concise responses. Please read this post to learn how to implement the RAG approach with Amazon Kendra. A RAG approach with Amazon Kendra addresses the following risks and limitations associated with LLM-based queries:

  • Hallucinations and traceability – LLMs are trained on large datasets and generate responses based on probabilities. This can lead to inaccurate answers, which are known as hallucinations.
  • Multiple data silos – In order to reference data from multiple sources within your response, one needs to set up a connector ecosystem to aggregate the data. Accessing multiple repositories is manual and time-consuming.
  • Security – Security and privacy are critical considerations when deploying conversational bots powered by RAG and LLMs. Despite using Amazon Comprehend to filter out personal data that may be provided through user queries, there remains a possibility of unintentionally surfacing personal or sensitive information, depending on the ingested data. This means that controlling access to the chatbot is crucial to prevent unintended access to sensitive information.
  • Data relevance – LLMs are trained on data up to a certain date, which means their information is often not current. The cost of training models on recent data is high. To ensure accurate and up-to-date responses, organizations bear the responsibility of regularly updating and enriching the content of the indexed documents.
  • Cost – The cost associated with deploying this solution should be a consideration for businesses. Businesses need to carefully assess their budget and performance requirements when implementing this solution. Running LLMs can require substantial computational resources, which may increase operational costs. These costs can become a limitation for applications that need to operate at a large scale. However, one of the benefits of the AWS Cloud is the flexibility to only pay for what you use. AWS offers a simple, consistent, pay-as-you-go pricing model, so you are charged only for the resources you consume.

Usage of Amazon SageMaker JumpStart

For transformer-based language models, organizations can benefit from using Amazon SageMaker JumpStart, which offers a collection of pre-built machine learning models. Amazon SageMaker JumpStart offers a wide range of text generation and question-answering (Q&A) foundation models that can be easily deployed and used. This solution integrates a FLAN T5-XL Amazon SageMaker JumpStart model, but there are several aspects to keep in mind when choosing a foundation model.

Integrating security in our workflow

Following the best practices of the Security Pillar of the Well-Architected Framework, Amazon Cognito is used for authentication. Amazon Cognito user pools can be integrated with third-party identity providers that support several frameworks used for access control, including Open Authorization (OAuth), OpenID Connect (OIDC), and Security Assertion Markup Language (SAML). Identifying users and their actions allows the solution to maintain traceability. The solution also uses the Amazon Comprehend personally identifiable information (PII) detection feature to automatically identify and redact PII. Redacted PII includes addresses, social security numbers, email addresses, and other sensitive information. This design ensures that any PII provided by the user through the input query is redacted. The PII is not stored, used by Amazon Kendra, or fed to the LLM.
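The redaction step can be sketched as a small function over the entity list that Amazon Comprehend's `detect_pii_entities` API returns (each entity carries `Type`, `BeginOffset`, `EndOffset`, and `Score`). This is a hedged illustration rather than the solution's code: the helper name, the `[TYPE]` placeholder style, and the 0.5 confidence threshold are assumptions.

```python
def redact_pii(text, entities, min_score=0.5):
    """Replace each detected PII span with its entity type, e.g. [EMAIL].

    `entities` follows the shape of the Entities list returned by
    comprehend.detect_pii_entities(Text=..., LanguageCode="en").
    """
    # Work right-to-left so earlier offsets stay valid after each replacement.
    for ent in sorted(entities, key=lambda e: e["BeginOffset"], reverse=True):
        if ent["Score"] < min_score:
            continue  # skip low-confidence detections (threshold is an assumption)
        text = text[:ent["BeginOffset"]] + f"[{ent['Type']}]" + text[ent["EndOffset"]:]
    return text


# Example wiring (requires AWS credentials):
# import boto3
# entities = boto3.client("comprehend").detect_pii_entities(
#     Text=query, LanguageCode="en")["Entities"]
# safe_query = redact_pii(query, entities)
```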

Solution Walkthrough

The following steps describe the workflow of the Question answering over documents flow:

  1. Users send a query through a web interface.
  2. Amazon Cognito is used for authentication, ensuring secure access to the web application.
  3. The web application front-end is hosted on AWS Amplify.
  4. Amazon API Gateway hosts a REST API with various endpoints to handle user requests that are authenticated using Amazon Cognito.
  5. PII redaction with Amazon Comprehend:
    • User Query Processing: When a user submits a query or input, it is first passed through Amazon Comprehend. The service analyzes the text and identifies any PII entities present within the query.
    • PII Extraction: Amazon Comprehend extracts the detected PII entities from the user query.
  6. Relevant Information Retrieval with Amazon Kendra:
    • Amazon Kendra is used to manage an index of documents that contains the information used to generate answers to the user’s queries.
    • The LangChain QA retrieval module is used to build a conversation chain that has relevant information about the user’s queries.
  7. Integration with Amazon SageMaker JumpStart:
    • The AWS Lambda function uses the LangChain library and connects to the Amazon SageMaker JumpStart endpoint with a context-stuffed query. The Amazon SageMaker JumpStart endpoint serves as the interface of the LLM used for inference.
  8. Storing responses and returning it to the user:
    • The response from the LLM is stored in Amazon DynamoDB along with the user’s query, the timestamp, a unique identifier, and other arbitrary identifiers for the item such as question category. Storing the question and answer as discrete items allows the AWS Lambda function to easily recreate a user’s conversation history based on the time when questions were asked.
    • Finally, the response is sent back to the user via an HTTPS request through the Amazon API Gateway REST API integration response.
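Step 8 can be sketched as one item per question-and-answer pair, plus a helper that replays a user's conversation in time order. The attribute names (`id`, `user_id`, `timestamp`, `question`, `answer`, `category`) are hypothetical, as is filtering in Python. In a real table you would more likely query on a `user_id` partition key with `timestamp` as the sort key, and write each item with boto3's `Table.put_item(Item=item)`.

```python
import time
import uuid


def make_item(user_id, question, answer, category=None, now=None):
    """One DynamoDB item per question/answer pair (attribute names are hypothetical)."""
    item = {
        "id": str(uuid.uuid4()),
        "user_id": user_id,
        "timestamp": int(now if now is not None else time.time()),
        "question": question,
        "answer": answer,
    }
    if category:
        item["category"] = category
    return item


def conversation_history(items, user_id):
    """Rebuild one user's conversation, oldest question first."""
    mine = [i for i in items if i["user_id"] == user_id]
    return [(i["question"], i["answer"]) for i in sorted(mine, key=lambda i: i["timestamp"])]
```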

The following steps describe the AWS Lambda functions and their flow through the process:

  1. Check and redact any PII / Sensitive info
  2. LangChain QA Retrieval Chain
    • Search and retrieve relevant info
  3. Context Stuffing & Prompt Engineering
    • LangChain
  4. Inference with LLM
  5. Return and save the response
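The five Lambda steps above can be sketched as one handler that composes them. This is a minimal illustration under stated assumptions, not the solution's code: redaction, retrieval, inference, and storage are passed in as plain Python callables (standing in for Amazon Comprehend, the LangChain Amazon Kendra retriever, the SageMaker JumpStart endpoint, and a DynamoDB write), and `PROMPT_TEMPLATE` and the item attribute names are hypothetical. The `conversation_history` helper shows how storing each question and answer as a separate timestamped item lets the function rebuild a user's conversation, as described in step 8 of the architecture.

```python
import time
import uuid
from typing import Callable, Dict, List

# Hypothetical prompt template used for context stuffing.
PROMPT_TEMPLATE = (
    "Answer the question using only the context below.\n\n"
    "Context:\n{context}\n\nQuestion: {question}\nAnswer:"
)


def stuff_context(question: str, passages: List[str]) -> str:
    """Concatenate all retrieved passages into a single prompt ("stuff" approach)."""
    return PROMPT_TEMPLATE.format(context="\n\n".join(passages), question=question)


def handle_query(
    user_id: str,
    query: str,
    redact: Callable[[str], str],          # step 1: e.g. a Comprehend-based redactor
    retrieve: Callable[[str], List[str]],  # step 2: e.g. a Kendra retriever
    invoke_llm: Callable[[str], str],      # step 4: e.g. a JumpStart endpoint call
    save: Callable[[Dict], None],          # step 5: e.g. a DynamoDB put_item wrapper
    category: str = "general",
) -> str:
    clean_query = redact(query)
    passages = retrieve(clean_query)
    prompt = stuff_context(clean_query, passages)  # step 3
    answer = invoke_llm(prompt)
    # One item per question/answer pair. The integer timestamp (nanoseconds)
    # avoids Python floats, which the boto3 DynamoDB resource rejects.
    save({
        "id": str(uuid.uuid4()),
        "user_id": user_id,
        "timestamp": time.time_ns(),
        "query": clean_query,
        "response": answer,
        "category": category,
    })
    return answer


def conversation_history(items: List[Dict], user_id: str) -> List[Dict]:
    """Return one user's stored items in the order the questions were asked."""
    return sorted((i for i in items if i["user_id"] == user_id), key=lambda i: i["timestamp"])
```

If the DynamoDB table used the user ID as partition key and the timestamp as sort key (our assumption about the schema), a DynamoDB `Query` would already return items in this order, so the sort would happen server-side.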

Use cases

There are many business use cases where customers can use this workflow. The following section explains how the workflow can be used in different industries and verticals.

Employee Assistance

Well-designed corporate training can improve employee satisfaction and reduce the time required to onboard new employees. As organizations grow and become more complex, employees find it difficult to navigate the many sources of internal documents. Internal documents in this context include company guidelines, policies, and standard operating procedures. In this scenario, an employee has a question about how to proceed with and edit a ticket in an internal issue-ticketing system. The employee can use the generative artificial intelligence (AI) conversational bot to ask about and execute the next steps for that specific ticket.

Specific use case: Automate issue resolution for employees based on corporate guidelines.

The following steps describe the AWS Lambda functions and their flow through the process:

  1. LangChain agent to identify the intent
  2. Send notification based on employee request
  3. Modify ticket status

In this architecture diagram, corporate training videos can be ingested through Amazon Transcribe to produce transcripts of the videos. Additionally, corporate training content stored in various sources (e.g., Confluence, Microsoft SharePoint, Google Drive, and Jira) can be indexed through Amazon Kendra connectors. Read this article to learn more about the native connectors you can use as data sources in Amazon Kendra. The Amazon Kendra crawler can then use both the training video transcripts and the documentation stored in these other sources to help the conversational bot answer questions specific to the company's corporate training guidelines. The LangChain agent verifies permissions, modifies the ticket status, and notifies the correct individuals using Amazon Simple Notification Service (Amazon SNS).
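The three employee-assistance steps (identify intent, modify the ticket, send a notification) can be sketched as follows. This is a hypothetical illustration: `identify_intent` is a toy keyword classifier standing in for the LangChain agent's LLM call, `PERMISSIONS` and the status names are invented, and `notify(subject, message)` stands in for an Amazon SNS call such as boto3's `sns.publish(TopicArn=..., Subject=..., Message=...)`.

```python
from dataclasses import dataclass
from typing import Callable, Dict, Set

# Hypothetical permission table: which employees may edit which tickets.
PERMISSIONS: Dict[str, Set[str]] = {"alice": {"T-100"}}
ALLOWED_STATUSES = {"open", "in_progress", "resolved"}
INTENT_TO_STATUS = {"resolve_ticket": "resolved", "start_ticket": "in_progress"}


@dataclass
class Ticket:
    ticket_id: str
    status: str


def identify_intent(message: str) -> str:
    """Toy keyword classifier standing in for the LangChain agent's LLM call."""
    text = message.lower()
    if "close" in text or "resolve" in text:
        return "resolve_ticket"
    if "start" in text or "working on" in text:
        return "start_ticket"
    return "ask_question"


def update_ticket(employee: str, ticket: Ticket, new_status: str,
                  notify: Callable[[str, str], None]) -> Ticket:
    """Verify permission, change the status, then send a notification."""
    if ticket.ticket_id not in PERMISSIONS.get(employee, set()):
        raise PermissionError(f"{employee} may not edit {ticket.ticket_id}")
    if new_status not in ALLOWED_STATUSES:
        raise ValueError(f"unknown status: {new_status}")
    old_status, ticket.status = ticket.status, new_status
    notify(f"Ticket {ticket.ticket_id} updated",
           f"{employee} changed the status from {old_status} to {new_status}.")
    return ticket


def handle_request(employee: str, message: str, ticket: Ticket,
                   notify: Callable[[str, str], None]) -> str:
    intent = identify_intent(message)
    if intent in INTENT_TO_STATUS:
        update_ticket(employee, ticket, INTENT_TO_STATUS[intent], notify)
        return f"{ticket.ticket_id} is now {ticket.status}"
    # Plain questions would go through the Kendra retrieval flow instead.
    return "answer_from_knowledge_base"
```

The permission check runs before any state change or notification, so an unauthorized request leaves the ticket untouched and nobody is notified.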

Customer Support Teams

Quickly resolving customer queries improves the customer experience and encourages brand loyalty. A loyal customer base helps drive sales, which contributes to the bottom line and increases customer engagement. Customer support teams spend considerable effort referencing internal documents and customer relationship management software to answer customer queries about products and services. Internal documents in this context can include generic customer support call scripts, playbooks, escalation guidelines, and business information. The generative AI conversational bot helps with cost optimization because it handles queries on behalf of the customer support team.

Specific use case: Handle an oil change request based on the customer's service history and purchased service plan.

In this architecture diagram, the customer is routed to either the generative AI conversational bot or the Amazon Connect contact center. This decision can be based on the level of support needed or the availability of customer support agents. The LangChain agent identifies the customer’s intent and verifies identity. The LangChain agent also checks the service history and purchased support plan.

The following steps describe the AWS Lambda functions and their flow through the process:

  1. LangChain agent identifies the intent
  2. Retrieve Customer Information
  3. Check customer service history and warranty information
  4. Book appointment, provide more information, or route to contact center
  5. Send email confirmation

Amazon Connect is used to collect the voice and chat logs, and Amazon Comprehend is used to remove personally identifiable information (PII) from these logs. The Amazon Kendra crawler can then use the redacted voice and chat logs, customer call scripts, and customer service support plan policies to create the index. Based on these checks, the generative AI conversational bot decides whether to book an appointment, provide more information, or route the customer to the contact center for further assistance. For cost optimization, the LangChain agent can also generate answers using fewer tokens and a less expensive large language model for lower-priority customer queries.
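The decision and cost-optimization logic described above can be sketched in a few lines. Everything here is a hypothetical illustration rather than the post's implementation: the 5,000-mile interval, the `ServiceRecord` fields, and the endpoint names and token budgets in `choose_model` are placeholders we chose to show the shape of the logic.

```python
from dataclasses import dataclass

# Hypothetical policy value; a real service plan would define this.
OIL_CHANGE_INTERVAL_MILES = 5000


@dataclass
class ServiceRecord:
    plan_includes_oil_change: bool
    miles_at_last_oil_change: int
    current_miles: int


def decide_action(record: ServiceRecord, identity_verified: bool) -> str:
    """Pick one of the three outcomes: book, inform, or route to an agent."""
    if not identity_verified:
        return "route_to_contact_center"
    miles_since = record.current_miles - record.miles_at_last_oil_change
    if miles_since >= OIL_CHANGE_INTERVAL_MILES and record.plan_includes_oil_change:
        return "book_appointment"
    # Not due yet, or not covered by the plan: explain options instead of booking.
    return "provide_more_information"


def choose_model(priority: str) -> dict:
    """Use a cheaper model and a tighter token budget for low-priority queries.

    Endpoint names and token limits are hypothetical placeholders.
    """
    if priority == "low":
        return {"endpoint": "small-llm-endpoint", "max_new_tokens": 128}
    return {"endpoint": "large-llm-endpoint", "max_new_tokens": 512}
```

Keeping the booking decision in deterministic code and reserving the LLM for phrasing the response is one way to make the outcome auditable; the post's LangChain agent may split this work differently.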

Financial Services

Financial services companies rely on timely use of information to stay competitive and comply with financial regulations. Using a generative AI conversational bot, financial analysts and advisors can interact with textual information in a conversational manner and reduce the time and effort needed to make better-informed decisions. Outside of investment and market research, a generative AI conversational bot can also augment human capabilities by handling tasks that would traditionally require more human effort and time. For example, a financial institution specializing in personal loans can increase the rate at which loans are processed while providing better transparency to customers.

Specific use case: Use customer financial history and previous loan applications to decide and explain loan decision.

The following steps describe the AWS Lambda functions and their flow through the process:

  1. LangChain agent to identify the intent
  2. Check customer financial and credit score history
  3. Check internal customer relationship management system
  4. Check standard loan policies and suggest a decision to the employee qualifying the loan
  5. Send notification to customer

This architecture incorporates customer financial data stored in a database and data stored in a customer relationship management (CRM) tool. These data points are used to inform a decision based on the company’s internal loan policies. The customer is able to ask clarifying questions to understand what loans they qualify for and the terms of the loans they can accept. If the generative AI conversational bot is unable to approve a loan application, the user can still ask questions about improving credit scores or alternative financing options.
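Step 4 (check loan policies and suggest a decision) is where the "explain loan decision" part of this use case comes from: if the check returns the specific policy rules that failed, both the employee and the bot can explain the outcome and point the customer toward improving their credit. The sketch below shows that pattern; the thresholds and `Applicant` fields are hypothetical values for illustration only, not real lending criteria.

```python
from dataclasses import dataclass
from typing import List, Tuple

# Hypothetical thresholds for illustration only; a lender's internal loan
# policy documents would supply the real rules.
MIN_CREDIT_SCORE = 650
MAX_DEBT_TO_INCOME = 0.40
MAX_RECENT_DECLINES = 1


@dataclass
class Applicant:
    credit_score: int
    monthly_debt: float
    monthly_income: float
    recent_declined_applications: int


def suggest_loan_decision(a: Applicant) -> Tuple[str, List[str]]:
    """Return a suggested decision plus the policy reasons behind it."""
    reasons: List[str] = []
    if a.credit_score < MIN_CREDIT_SCORE:
        reasons.append(f"credit score {a.credit_score} is below {MIN_CREDIT_SCORE}")
    if a.monthly_income <= 0:
        reasons.append("no verifiable monthly income")
    else:
        dti = a.monthly_debt / a.monthly_income
        if dti > MAX_DEBT_TO_INCOME:
            reasons.append(f"debt-to-income ratio {dti:.2f} exceeds {MAX_DEBT_TO_INCOME:.2f}")
    if a.recent_declined_applications > MAX_RECENT_DECLINES:
        reasons.append(f"{a.recent_declined_applications} recent declined applications")
    decision = "suggest_approve" if not reasons else "suggest_decline"
    return decision, reasons
```

The result is only a suggestion: as in the step list, the employee qualifying the loan makes the final call, and the returned reasons can be passed to the LLM as context for a customer-facing explanation.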

Government

Generative AI conversational bots can greatly benefit government institutions by speeding up communication and improving efficiency and decision-making. They can also provide instant access to internal knowledge bases, helping government employees quickly retrieve information, policies, and procedures (e.g., eligibility criteria, application processes, and citizen services and support). One solution is an interactive system that allows taxpayers and tax professionals to easily find tax-related details and benefits. It can be used to understand user questions, summarize tax documents, and provide clear answers through interactive conversations.

Users can ask questions such as:

  • How does inheritance tax work and what are the tax thresholds?
  • Can you explain the concept of income tax?
  • What are the tax implications when selling a second property?

Additionally, users can submit tax forms to the system, which can help verify the correctness of the information provided.

This architecture illustrates how users can upload completed tax forms to the solution and use it for interactive verification and guidance on how to accurately complete the necessary information.
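Part of verifying an uploaded form can be done with plain rule checks before the LLM is involved, so that the conversational guidance is grounded in concrete issues. The sketch below assumes a hypothetical form schema (`REQUIRED_FIELDS` and a single "total equals the sum of income lines" rule); real tax forms define their own fields and rules, and the post does not specify how verification is implemented.

```python
from typing import Any, Dict, List

# Hypothetical form schema; real tax forms define their own fields and rules.
REQUIRED_FIELDS = ["taxpayer_id", "tax_year", "employment_income", "other_income", "total_income"]


def check_tax_form(form: Dict[str, Any]) -> List[str]:
    """Return human-readable issues; an empty list means no problems were found."""
    issues = [f"missing field: {name}" for name in REQUIRED_FIELDS if form.get(name) in (None, "")]
    if issues:
        return issues
    expected = form["employment_income"] + form["other_income"]
    # Allow for rounding to the cent.
    if abs(expected - form["total_income"]) > 0.005:
        issues.append(f"total_income should be {expected:.2f}, found {form['total_income']:.2f}")
    return issues
```

The returned issues could then be placed in the LLM prompt so that the bot explains each problem and how to correct it in conversational terms.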

Healthcare

Healthcare businesses have the opportunity to automate the use of large amounts of internal patient information, while also addressing common questions regarding use cases such as treatment options, insurance claims, clinical trials, and pharmaceutical research. Using a generative AI conversational bot enables quick and accurate generation of answers about health information from the provided knowledge base. For example, some healthcare professionals spend a lot of time filling in forms to file insurance claims.

In similar settings, clinical trial administrators and researchers need to find information about treatment options. A generative AI conversational bot can use the pre-built connectors in Amazon Kendra to retrieve the most relevant information from the millions of documents published through ongoing research conducted by pharmaceutical companies and universities.

Specific use case: Reduce the errors and time needed to fill out and send insurance forms.

In this architecture diagram, a healthcare professional can use the generative AI conversational bot to determine which forms need to be filled out for the insurance claim. The LangChain agent can then retrieve the right forms, add the needed patient information, and draft responses for the descriptive parts of the forms based on insurance policies and previous forms. The healthcare professional can edit the responses generated by the LLM before approving the form and having it delivered to the insurance portal.

The following steps describe the AWS Lambda functions and their flow through the process:

  1. LangChain agent to identify the intent
  2. Retrieve the patient information needed
  3. Fill out the insurance form based on the patient information and form guideline
  4. Submit the form to the insurance portal after user approval

AWS HealthLake is used to securely store health data, including previous insurance forms and patient information, and Amazon Comprehend is used to remove personally identifiable information (PII) from the previous insurance forms. The Amazon Kendra crawler can then use the set of insurance forms and guidelines to create the index. After the generative AI fills out the forms and the healthcare professional reviews them, the forms can be sent to the insurance portal.
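The key property of this flow is the human approval gate in step 4: nothing reaches the insurance portal until the healthcare professional has reviewed, optionally edited, and approved the form. A minimal sketch of that gate follows. The field names, the `draft` callable (standing in for the LLM that writes the descriptive section), and the `submit` callable (standing in for the portal upload) are all hypothetical.

```python
from typing import Callable, Dict, Optional

# Hypothetical form layout: fields copied from the patient record, plus one
# descriptive field drafted by the LLM.
RECORD_FIELDS = ["patient_name", "date_of_birth", "policy_number", "procedure_code"]
DRAFTED_FIELD = "medical_necessity"

Form = Dict[str, Optional[str]]


def fill_form(patient: Dict[str, str], draft: Callable[[Dict[str, str]], str]) -> Form:
    """Copy record fields and let the (hypothetical) LLM draft the descriptive part."""
    form: Form = {name: patient.get(name) for name in RECORD_FIELDS}
    form[DRAFTED_FIELD] = draft(patient)
    return form


def submit_if_approved(form: Form, edits: Form, approved: bool,
                       submit: Callable[[Form], None]) -> bool:
    """Apply the professional's edits, then submit only after explicit approval."""
    final = {**form, **edits}
    missing = [name for name, value in final.items() if not value]
    if missing:
        raise ValueError(f"form incomplete: {missing}")
    if not approved:
        return False
    submit(final)
    return True
```

Edits are merged over the draft before validation, so a professional can both correct LLM-generated text and fill any field the patient record left empty.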

Cost estimate

The cost of deploying the base solution as a proof of concept is shown in the following table. Because the base solution is a proof of concept and the workload would not be in production, we used Amazon Kendra Developer Edition as a low-cost option. We assumed 730 active hours per month for Amazon Kendra Developer Edition.

For Amazon SageMaker, we assumed that the customer would use the ml.g4dn.2xlarge instance for real-time inference, with a single inference endpoint per instance. You can find more information about pricing and available inference instance types on the Amazon SageMaker pricing page.

| Service | Resources Consumed | Cost Estimate Per Month (USD) |
| --- | --- | --- |
| AWS Amplify | 150 build minutes; 1 GB of data served; 500,000 requests | 15.71 |
| Amazon API Gateway | 1M REST API calls | 3.50 |
| AWS Lambda | 1 million requests; 5 seconds duration per request; 2 GB memory allocated | 160.23 |
| Amazon DynamoDB | 1 million reads; 1 million writes; 100 GB storage | 26.38 |
| Amazon SageMaker | Real-time inference with ml.g4dn.2xlarge | 676.80 |
| Amazon Kendra | Developer Edition with 730 hours/month; 10,000 documents scanned; 5,000 queries/day | 821.25 |
| Total | | 1,703.87 |

*  Amazon Cognito has a free tier of 50,000 Monthly Active Users who use Cognito User Pools or 50 Monthly Active Users who use SAML 2.0 identity providers
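The table's arithmetic can be checked in a few lines. The line items below are copied from the table. The two hourly rates are our own back-calculation and are not stated in the post: the Kendra row equals $1.125/hour × 730 hours, and the SageMaker row is consistent with $0.94/hour × 720 hours. Actual rates vary by Region and over time, so confirm them on the AWS pricing pages.

```python
# Monthly line items from the cost table (USD).
LINE_ITEMS = {
    "AWS Amplify": 15.71,
    "Amazon API Gateway": 3.50,
    "AWS Lambda": 160.23,
    "Amazon DynamoDB": 26.38,
    "Amazon SageMaker": 676.80,
    "Amazon Kendra": 821.25,
}

# Back-calculated hourly rates implied by the table (assumptions, not quoted prices).
SAGEMAKER_HOURLY = 0.94    # 676.80 / 720 hours
KENDRA_DEV_HOURLY = 1.125  # 821.25 / 730 hours


def monthly_total(items: dict) -> float:
    """Sum the line items and round to cents to hide floating-point noise."""
    return round(sum(items.values()), 2)


if __name__ == "__main__":
    print(monthly_total(LINE_ITEMS))  # 1703.87
    print(round(SAGEMAKER_HOURLY * 720, 2), round(KENDRA_DEV_HOURLY * 730, 2))  # 676.8 821.25
```

SageMaker and Kendra together account for roughly 88% of the estimate, so the instance type and the Kendra edition are the main levers when adjusting this budget.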

Clean Up

To save costs, delete all the resources you deployed as part of this solution. You can delete any SageMaker endpoints you created via the SageMaker console. Remember that deleting an Amazon Kendra index doesn't remove the original documents from your storage.
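If you prefer to script the cleanup instead of using the console, the following sketch uses boto3 client calls that exist for this purpose (`describe_endpoint`, `describe_endpoint_config`, `delete_endpoint`, `delete_endpoint_config`, `delete_model`, and Kendra's `delete_index`). The functions take the clients as parameters so they can be exercised without AWS access; the endpoint name and index ID you pass are your own values, and this is a convenience sketch rather than part of the original solution.

```python
from typing import Any


def delete_sagemaker_endpoint(sm: Any, endpoint_name: str) -> None:
    """Delete a SageMaker endpoint together with its endpoint config and model(s).

    `sm` is expected to be a boto3 SageMaker client, e.g. boto3.client("sagemaker").
    """
    config_name = sm.describe_endpoint(EndpointName=endpoint_name)["EndpointConfigName"]
    variants = sm.describe_endpoint_config(EndpointConfigName=config_name)["ProductionVariants"]
    sm.delete_endpoint(EndpointName=endpoint_name)
    sm.delete_endpoint_config(EndpointConfigName=config_name)
    for variant in variants:
        # Some endpoint types do not attach a model directly to the variant.
        model_name = variant.get("ModelName")
        if model_name:
            sm.delete_model(ModelName=model_name)


def delete_kendra_index(kendra: Any, index_id: str) -> None:
    """Delete an Amazon Kendra index; `kendra` is e.g. boto3.client("kendra").

    Source documents in Amazon S3 or connected repositories are not affected.
    """
    kendra.delete_index(Id=index_id)
```

For example, `delete_sagemaker_endpoint(boto3.client("sagemaker"), "<your-endpoint-name>")`. The configuration and model names are looked up before the endpoint is deleted, because they can no longer be read from the endpoint once it is gone.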

Conclusion

In this post, we showed you how to simplify access to internal information by summarizing content from multiple repositories in real time. With recent developments in commercially available LLMs, the possibilities of generative AI have become more apparent. We showcased ways to use AWS services to create a serverless chatbot that uses generative AI to answer questions. This approach incorporates an authentication layer and Amazon Comprehend's PII detection to filter out any sensitive information provided in the user's query. Whether it is individuals in healthcare understanding the nuances of filing insurance claims or HR understanding specific company-wide regulations, many industries and verticals can benefit from this approach. An Amazon SageMaker JumpStart foundation model is the engine behind the chatbot, while a context-stuffing approach using the retrieval augmented generation (RAG) technique helps ensure that the responses more accurately reference internal documents.

To learn more about working with generative AI on AWS, refer to Announcing New Tools for Building with Generative AI on AWS. For more in-depth guidance on using the RAG technique with AWS services, refer to Quickly build high-accuracy Generative AI applications on enterprise data using Amazon Kendra, LangChain, and large language models. Because the approach in this post is LLM-agnostic, any LLM can be used for inference. In our next post, we'll outline ways to implement this solution using Amazon Bedrock and the Amazon Titan LLM.


About the Authors

Abhishek Maligehalli Shivalingaiah is a Senior AI Services Solution Architect at AWS. He is passionate about building applications using Generative AI, Amazon Kendra and NLP. He has around 10 years of experience in building Data & AI solutions to create value for customers and enterprises. He has even built a personal chatbot for fun that answers questions about his career and professional journey. Outside of work, he enjoys making portraits of family and friends and loves creating artwork.

Medha Aiyah is an Associate Solutions Architect at AWS, based in Austin, Texas. She graduated from the University of Texas at Dallas in December 2022 with a Master of Science in Computer Science, specializing in Intelligent Systems with a focus on AI/ML. She is interested in learning more about AI/ML and using AWS services to discover solutions customers can benefit from.

Hugo Tse is an Associate Solutions Architect at AWS based in Seattle, Washington. He holds a Master's degree in Information Technology from Arizona State University and a bachelor's degree in Economics from the University of Chicago. He is a member of the Information Systems Audit and Control Association (ISACA) and International Information System Security Certification Consortium (ISC)². He enjoys helping customers benefit from technology.

Ayman Ishimwe is an Associate Solutions Architect at AWS based in Seattle, Washington. He holds a Master's degree in Software Engineering and IT from Oakland University. He has prior experience in software development, specifically in building microservices for distributed web applications. He is passionate about helping customers build robust and scalable solutions on AWS cloud services following best practices.

Shervin Suresh is an Associate Solutions Architect at AWS based in Austin, Texas. He graduated with a Master's in Software Engineering with a concentration in Cloud Computing and Virtualization and a Bachelor's in Computer Engineering from San Jose State University. He is passionate about leveraging technology to help improve the lives of people from all backgrounds.
