Run text classification with Amazon SageMaker JumpStart using TensorFlow Hub and Hugging Face models

In December 2020, AWS announced the general availability of Amazon SageMaker JumpStart, a capability of Amazon SageMaker that helps you quickly and easily get started with machine learning (ML). JumpStart provides one-click fine-tuning and deployment of a wide variety of pre-trained models across popular ML tasks, as well as a selection of end-to-end solutions that solve common business problems. These features remove the heavy lifting from each step of the ML process, making it easier to develop high-quality models and reducing time to deployment.

All JumpStart content was previously available only through Amazon SageMaker Studio, which provides a user-friendly graphical interface to interact with the feature. Recently, we also announced the launch of easy-to-use JumpStart APIs as an extension of the SageMaker Python SDK, allowing you to programmatically deploy and fine-tune a vast selection of JumpStart-supported pre-trained models on your own datasets. This launch unlocks the usage of JumpStart capabilities in your code workflows, MLOps pipelines, and anywhere else you are interacting with SageMaker via SDK.

This post is the second in a series on using JumpStart for specific tasks. In the first post, we showed how you can run image classification use cases on JumpStart. In this post, we provide a step-by-step walkthrough on how to fine-tune and deploy a text classification model, using pre-trained models from TensorFlow Hub and Hugging Face. We explore two ways of obtaining the same result: via JumpStart’s graphical interface on Studio, and programmatically through JumpStart APIs. Although not used in this post, Amazon Comprehend is a natural language processing (NLP) service that uses machine learning (ML) to find insights and relationships like people, places, sentiments, and topics in text. In a fully managed experience, you can use Comprehend’s custom classification API to train a custom classification model to recognize the classes of interest, and then use that model to classify documents into your own categories.

If you want to jump straight into the JumpStart API code we explain in this post, you can refer to the following sample Jupyter notebooks:

  • Introduction to JumpStart – Text Classification
  • Introduction to JumpStart – Sentence Pair Classification

JumpStart overview

JumpStart is a multi-faceted product that includes different capabilities to help get you quickly started with ML on SageMaker. At the time of writing, JumpStart enables you to do the following:

  • Deploy pre-trained models for common ML tasks – JumpStart enables you to address common ML tasks with no development effort by providing easy deployment of models pre-trained on large, publicly available datasets. The ML research community has put a large amount of effort into making a majority of recently developed models publicly available for use. JumpStart hosts a collection of over 300 models, spanning the 15 most popular ML tasks such as object detection, text classification, and text generation, making it easy for beginners to use them. These models are drawn from popular model hubs such as TensorFlow Hub, PyTorch Hub, Hugging Face, and MXNet.
  • Fine-tune pre-trained models – JumpStart allows you to fine-tune pre-trained models with no need to write your own training algorithm. In ML, the ability to transfer the knowledge learned in one domain to another domain is called transfer learning. You can use transfer learning to produce accurate models on your smaller datasets, with much lower training costs than the ones involved in training the original model. JumpStart also includes popular training algorithms based on LightGBM, CatBoost, XGBoost, and Scikit-learn, which you can train from scratch for tabular regression and classification.
  • Use pre-built solutions – JumpStart provides a set of 17 solutions for common ML use cases, such as demand forecasting and industrial and financial applications, which you can deploy with just a few clicks. Solutions are end-to-end ML applications that string together various AWS services to solve a particular business use case. They use AWS CloudFormation templates and reference architectures for quick deployment, which means they’re fully customizable.
  • Refer to notebook examples for SageMaker algorithms – SageMaker provides a suite of built-in algorithms to help data scientists and ML practitioners get started with training and deploying ML models quickly. JumpStart provides sample notebooks that help you get started with these algorithms quickly.
  • Review training videos and blogs – JumpStart also provides numerous blog posts and videos that teach you how to use different functionalities within SageMaker.

JumpStart accepts custom VPC settings and AWS Key Management Service (AWS KMS) encryption keys, so you can use the available models and solutions securely within your enterprise environment. You can pass your security settings to JumpStart within Studio or through the SageMaker Python SDK.
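For example, once you reach the SDK workflow later in this post, these security settings map to standard SageMaker Estimator arguments. The following is a minimal sketch; the subnet, security group, and KMS key values are hypothetical placeholders for your own resources.

# Hypothetical VPC and KMS settings; replace with your own resource IDs
security_kwargs = {
    "subnets": ["subnet-0123456789abcdef0"],
    "security_group_ids": ["sg-0123456789abcdef0"],
    "output_kms_key": "arn:aws:kms:us-west-2:111122223333:key/example-key-id",
    "volume_kms_key": "arn:aws:kms:us-west-2:111122223333:key/example-key-id",
}

# These keyword arguments can be merged into the Estimator call shown later in
# this post, for example: Estimator(..., **security_kwargs)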

Transformer models and the importance of fine-tuning

The attention-based Transformer architecture has become the de facto standard for state-of-the-art NLP models. In 2018, the famous BERT model was born from an adapted Transformer encoder, pre-trained on gigabytes of unlabeled English text from Wikipedia and other public resources. BERT was incredibly useful for creating contextualized text representations, which you could then use in many downstream tasks by fine-tuning the model. Since then, many variants of the BERT model have been developed by changing the architecture, the pre-training scheme, or the pre-training dataset.

Fine-tuning means taking a model that has been pre-trained on a given task and training it again, this time for a specific task that is different but related (and on your specific data). This practice is also typically referred to as transfer learning—literally, transferring knowledge gained on one task to another. Typically, Transformer-based models are pre-trained on massive amounts of unlabeled data, and comparatively much smaller labeled datasets are then used for fine-tuning. Being able to leverage the large computational investments made to pre-train such models (which are commonly open-sourced) for downstream tasks has been one of the most important factors in the growth of NLP as a field in recent years. However, as new models get larger and more complex, so does the development effort required to fine-tune and deploy them efficiently. In this post, we show you how to use JumpStart to fine-tune a BERT model with little to no development effort involved.

Text classification

Sentiment analysis is one of the many tasks under the umbrella of text classification. It consists of predicting what sentiment should be assigned to a specific passage of text, with varying degrees of granularity. Typical applications include social media monitoring, customer support management, and analyzing customer feedback.

The input is a directory containing a data.csv file. The first column is the label (an integer from 0 to the number of classes minus 1), and the second column is the corresponding passage of text. This means that you could even use a dataset with more degrees of sentiment than the original—for example: very negative (0), negative (1), neutral (2), positive (3), very positive (4). The following is an example of a data.csv file corresponding to the SST2 (Stanford Sentiment Treebank) dataset, showing values in its first two columns. Note that the file shouldn’t have any header.

Column 1 Column 2
0 hide new secretions from the parental units
0 contains no wit , only labored gags
1 that loves its characters and communicates something rather beautiful about human nature
0 remains utterly satisfied to remain the same throughout
0 on the worst revenge-of-the-nerds clichés the filmmakers could dredge up
0 that 's far too tragic to merit such superficial treatment
1 demonstrates that the director of such hollywood blockbusters as patriot games can still turn out a small , personal film with an emotional wallop .

The SST2 [1, 2] dataset is downloaded from TensorFlow. Apache 2.0 License. Dataset Homepage.
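To make this format concrete, the following sketch builds a headerless two-column data.csv with pandas and uploads it to Amazon S3; the bucket name and prefix are hypothetical placeholders.

import boto3
import pandas as pd

# Two columns, no header: integer label first, text passage second
examples = [
    (0, "contains no wit , only labored gags"),
    (1, "that loves its characters and communicates something rather beautiful about human nature"),
]
pd.DataFrame(examples).to_csv("data.csv", header=False, index=False)

# Upload the file so the training job can read it from S3
s3 = boto3.client("s3")
s3.upload_file("data.csv", "my-example-bucket", "sentiment-data/data.csv")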

Sentence pair classification

Sentence pair classification consists of predicting a class for a pair of sentences, which forces the model to learn semantic dependencies between sentence pairs. Typical examples include textual entailment (does the first sentence logically imply the second?), paraphrase detection (are both sentences differently worded versions of one another?), and others.

The input is a directory containing a data.csv file. The first column in the file should have integer class labels from 0 to the number of classes minus 1. The second and third columns should contain the first and second sentence corresponding to that row. The following is an example of a data.csv file for the QNLI dataset, showing values in its first three columns. Note that the file shouldn’t have any header.

Column 1 Column 2 Column 3
0 What is the Grotto at Notre Dame? Immediately behind the basilica is the Grotto, a Marian place of prayer and reflection.
1 What is the Grotto at Notre Dame? It is a replica of the grotto at Lourdes, France where the Virgin Mary reputedly appeared to Saint Bernadette Soubirous in 1858.
0 What sits on top of the Main Building at Notre Dame? Atop the Main Building’s gold dome is a golden statue of the Virgin Mary.
1 What sits on top of the Main Building at Notre Dame? Next to the Main Building is the Basilica of the Sacred Heart.
0 When did the Scholastic Magazine of Notre dame begin publishing? Begun as a one-page journal in September 1876, the Scholastic magazine is issued twice monthly and claims to be the oldest continuous collegiate publication in the United States.
1 When did the Scholastic Magazine of Notre dame begin publishing? The newspapers have varying publication interests, with The Observer published daily and mainly reporting university and other news, and staffed by students from both Notre Dame and Saint Mary’s College.
0 In what year did the student paper Common Sense begin publication at Notre Dame? In 1987, when some students believed that The Observer began to show a conservative bias, a liberal newspaper, Common Sense was published.

The QNLI [3, 4] dataset is downloaded from TensorFlow. Apache 2.0 License. Dataset Homepage. CC BY-SA 4.0 License.

Solution overview

The following sections provide a step-by-step demo to perform sentiment analysis with JumpStart, both via the Studio UI and via JumpStart APIs. The workflow for sentence pair classification is almost identical, and we describe the changes required for that task.

We walk through the following steps:

  1. Access JumpStart through the Studio UI:
    1. Fine-tune the pre-trained model.
    2. Deploy the fine-tuned model.
  2. Use JumpStart programmatically with the SageMaker Python SDK:
    1. Fine-tune the pre-trained model.
    2. Deploy the fine-tuned model.

Access JumpStart through the Studio UI

In this section, we demonstrate how to train and deploy JumpStart models through the Studio UI.

Fine-tune the pre-trained model

The following video shows you how to find a pre-trained text classification model on JumpStart and fine-tune it on the task of sentiment analysis. The model page contains valuable information about the model, how to use it, expected data format, and some fine-tuning details.

For demonstration purposes, we fine-tune the model using the dataset provided by default, which is the Stanford Sentiment Treebank v2 (SST2) dataset. Fine-tuning on your own dataset involves taking the correct formatting of data (as explained on the model page), uploading it to Amazon Simple Storage Service (Amazon S3), and specifying its location in the data source configuration.

We use the same hyperparameter values set by default (number of epochs, learning rate, and batch size). We also use a GPU-backed ml.p3.2xlarge as our SageMaker training instance.

You can monitor the training job directly from the Studio console and are notified when it completes.

For sentence pair classification, instead of searching for text classification models in the JumpStart search bar, search for sentence pair classification. For example, choose Bert Base Uncased from the resulting model selection. Apart from the dataset format, the rest of the workflow is identical between text classification and sentence pair classification.

Deploy the fine-tuned model

After training is complete, you can deploy the fine-tuned model from the same page that holds the training job details. To deploy our model, we pick a different instance type, ml.g4dn.xlarge. It still provides the GPU acceleration needed for low inference latency, but at a lower price point. After you configure the SageMaker hosting instance, choose Deploy. It may take 5–10 minutes until your persistent endpoint is up and running.

After a few minutes, your endpoint is operational and ready to respond to inference requests!

To accelerate your time to inference, JumpStart provides a sample notebook that shows you how to run inference on your freshly deployed endpoint. Choose Open Notebook under Use Endpoint from Studio.

Use JumpStart programmatically with the SageMaker SDK

In the preceding sections, we showed how you can use the JumpStart UI to fine-tune and deploy a model interactively, in a matter of a few clicks. However, you can also use JumpStart’s models and easy fine-tuning programmatically by using APIs that are integrated into the SageMaker SDK. We now go over a quick example of how you can replicate the preceding process. All the steps in this demo are available in the accompanying notebooks Introduction to JumpStart – Text Classification and Introduction to JumpStart – Sentence Pair Classification.

Fine-tune the pre-trained model

To fine-tune a selected model, we need to get that model’s URI, as well as that of the training script and the container image used for training. Thankfully, these three inputs depend solely on the model name, version (for a list of the available models, see JumpStart Available Model Table), and the type of instance you want to train on. This is demonstrated in the following code snippet:

from sagemaker import image_uris, model_uris, script_uris

model_id, model_version = "tensorflow-tc-bert-en-uncased-L-12-H-768-A-12-2", "1.0.0"
training_instance_type = "ml.p3.2xlarge"

# Retrieve the docker image
train_image_uri = image_uris.retrieve(
    region=None,
    framework=None,
    model_id=model_id,
    model_version=model_version,
    image_scope="training",
    instance_type=training_instance_type,
)
# Retrieve the training script
train_source_uri = script_uris.retrieve(model_id=model_id, model_version=model_version, script_scope="training")

# Retrieve the pre-trained model tarball to further fine-tune
train_model_uri = model_uris.retrieve(model_id=model_id, model_version=model_version, model_scope="training")

We retrieve the model_id corresponding to the same model we used previously (the dimensions in the name, 12 layers with a hidden size of 768 and 12 attention heads, are characteristic of the base version of BERT). The tc in the identifier corresponds to text classification.

For sentence pair classification, we can set model_id to huggingface-spc-bert-base-uncased. The spc in the identifier corresponds to sentence pair classification.
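If you prefer to discover model IDs programmatically instead of browsing the model table, the SageMaker SDK includes a listing helper. The following is a sketch; the filter strings for the tc and spc tasks are assumptions and may differ across SDK versions.

from sagemaker.jumpstart.notebook_utils import list_jumpstart_models

# List JumpStart models for text classification (tc) and sentence pair classification (spc)
tc_models = list_jumpstart_models(filter="task == tc")
spc_models = list_jumpstart_models(filter="task == spc")
print(tc_models[:5], spc_models[:5])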

You can now fine-tune this JumpStart model on your own custom dataset using the SageMaker SDK. We use a dataset that is publicly hosted on Amazon S3, conveniently focused on sentiment analysis. The dataset should be structured for fine-tuning as explained in the previous section. See the following example code:

from sagemaker.estimator import Estimator
from sagemaker.utils import name_from_base

# aws_role, s3_output_location, and hyperparameters are defined in the
# accompanying notebook; the default hyperparameters come from
# sagemaker.hyperparameters.retrieve_default (see the sketch below)

# URI of your training dataset
training_dataset_s3_path = "s3://jumpstart-cache-prod-us-west-2/training-datasets/tc/data.csv"
training_job_name = name_from_base(f"jumpstart-example-{model_id}-transfer-learning")

# Create SageMaker Estimator instance
tc_estimator = Estimator(
    role=aws_role,
    image_uri=train_image_uri,
    source_dir=train_source_uri,
    model_uri=train_model_uri,
    entry_point="transfer_learning.py",
    instance_count=1,
    instance_type=training_instance_type,
    max_run=360000,
    hyperparameters=hyperparameters,
    output_path=s3_output_location,
)

# Launch a SageMaker Training job by passing the S3 path of the training data
tc_estimator.fit({"training": training_dataset_s3_path}, logs=True, job_name=training_job_name)

We obtain the same default hyperparameters for our selected model as the ones we saw in the previous section, using sagemaker.hyperparameters.retrieve_default(). We then instantiate a SageMaker estimator, and call the .fit method to start fine-tuning our model, passing it the Amazon S3 URI for our training data. As you can see, the entry_point script provided is named transfer_learning.py (the same for other tasks and models), and the input data channel passed to .fit must be named training.
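As a brief sketch of that hyperparameter step (the epochs key shown here is an assumption; the available keys depend on the model):

from sagemaker import hyperparameters as hp

# Retrieve the default hyperparameters for fine-tuning this model
hyperparameters = hp.retrieve_default(model_id=model_id, model_version=model_version)

# [Optional] Override a default before passing the dictionary to the Estimator
hyperparameters["epochs"] = "3"
print(hyperparameters)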

Deploy the fine-tuned model

When training is complete, you can deploy your fine-tuned model. To do so, all we need to obtain is the inference script URI (the code that determines how the model is used for inference once deployed) and the inference container image URI, which includes an appropriate model server to host the model we chose. See the following code:

# Instance type for hosting; ml.g4dn.xlarge matches the choice made in the Studio walkthrough
inference_instance_type = "ml.g4dn.xlarge"

# Retrieve the inference docker container uri
deploy_image_uri = image_uris.retrieve(
    region=None,
    framework=None,
    image_scope="inference",
    model_id=model_id,
    model_version=model_version,
    instance_type=inference_instance_type,
)
# Retrieve the inference script uri
deploy_source_uri = script_uris.retrieve(
    model_id=model_id, model_version=model_version, script_scope="inference"
)

endpoint_name = name_from_base(f"jumpstart-example-FT-{model_id}-")

# Use the estimator from the previous step to deploy to a SageMaker endpoint
finetuned_predictor = tc_estimator.deploy(
    initial_instance_count=1,
    instance_type=inference_instance_type,
    entry_point="inference.py",
    image_uri=deploy_image_uri,
    source_dir=deploy_source_uri,
    endpoint_name=endpoint_name,
)

After a few minutes, our model is deployed and we can get predictions from it in real time!

Next, we invoke the endpoint to predict the sentiment of the example text. We use the query_endpoint and parse_response helper functions, which are defined in the accompanying notebook:

text = "simply stupid , irrelevant and deeply , truly , bottomlessly cynical "
query_response = query_endpoint(text.encode("utf-8"))
probabilities, labels, predicted_label = parse_response(query_response)
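Those helpers are defined in the accompanying notebook; a minimal sketch of what they might look like follows, assuming the endpoint accepts raw text and returns a JSON body with probabilities, labels, and predicted_label fields.

import json
import boto3

sm_runtime = boto3.client("sagemaker-runtime")

def query_endpoint(encoded_text):
    # Send the UTF-8 encoded text to the endpoint deployed above
    return sm_runtime.invoke_endpoint(
        EndpointName=endpoint_name,
        ContentType="application/x-text",
        Body=encoded_text,
    )

def parse_response(query_response):
    # Assumes the model container returns a JSON document with these keys
    model_predictions = json.loads(query_response["Body"].read())
    return (
        model_predictions["probabilities"],
        model_predictions["labels"],
        model_predictions["predicted_label"],
    )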

Conclusion

JumpStart is a capability in SageMaker that allows you to quickly get started with ML. JumpStart uses open-source pre-trained models to solve common ML problems like image classification, object detection, text classification, sentence pair classification, and question answering.

In this post, we showed you how to fine-tune and deploy a pre-trained text classification model for sentiment analysis. We also mentioned the changes required to adapt the demo for sentence pair classification. With JumpStart, you can easily complete this process without needing to write any code. Try out the solution on your own and let us know how it goes in the comments. To learn more about JumpStart, check out the AWS re:Invent 2020 video Get started with ML in minutes with Amazon SageMaker JumpStart.

References

  1. Socher et al., 2013
  2. Wang et al., 2018a
  3. Wang et al., 2018a

About the Authors

João Moura is an AI/ML Specialist Solutions Architect at Amazon Web Services. He is mostly focused on NLP use cases and helping customers optimize Deep Learning model training and deployment. He is also an active proponent of low-code ML solutions and ML-specialized hardware.

Dr. Vivek Madan is an Applied Scientist with the Amazon SageMaker JumpStart team. He got his PhD from the University of Illinois at Urbana-Champaign and was a postdoctoral researcher at Georgia Tech. He is an active researcher in machine learning and algorithm design and has published papers at EMNLP, ICLR, COLT, FOCS, and SODA.

Dr. Ashish Khetan is a Senior Applied Scientist with Amazon SageMaker JumpStart and Amazon SageMaker built-in algorithms and helps develop machine learning algorithms. He is an active researcher in machine learning and statistical inference and has published many papers in NeurIPS, ICML, ICLR, JMLR, ACL, and EMNLP conferences.


Solving the World’s Biggest Challenges, Together

Gamers know NVIDIA powers great gaming experiences. Researchers know NVIDIA speeds world-changing breakthroughs. Businesses know us for the AI engines transforming their industries.

And NVIDIA employees know the company as one of the best places to work on the planet.

More people than ever have a piece of NVIDIA. Roboticists, visual artists, data scientists — all sorts of innovators and creators rely on the company’s technology. And that’s only natural: NVIDIA’s the largest startup on Earth, growing to 25,000 employees from 10,000 a few years ago.

But as NVIDIA spills out in all directions it’s more important than ever to connect all these pieces, these people who may know our products, but don’t know one another.

That’s why we’re launching a campaign this week to bring all these elements together. To reflect back to the entertainers and entrepreneurs, researchers and scientists, developers and designers the staggering body of work we’ve built together.

It’s going to be quite a conversation. Not just because it’s comprehensive, but because it’s coherent.

The same GPU technology that powers the Nintendo Switch has proven the existence of the gravitational waves Einstein predicted a century ago.

The parallel computing power harnessed by NVIDIA’s CUDA platform is key not just to Oscar-winning special effects, but to a new generation of medical breakthroughs.

And the huge leaps in computing power unleashed by innovations in silicon, software and systems created at NVIDIA are turning data centers into engines of business innovation and imbuing supercomputers with the power to simulate the planet itself, for the benefit of those of us who would become its stewards.

These stories don’t just cover the breadth of what’s been accomplished. They point to possibilities, new places where each of these endeavors intersect. These spillovers are anything but happy accidents; they’re by design, and they’ve always been the soul of NVIDIA.

So think of this as an introduction. Not just to the NVIDIA story, but to each other. And with more people than ever contributing to this story we all share, this body of work, there can be no doubt that the best part of the NVIDIA story is still to come.

Visit NVIDIA’s “About Us” page, or click here for more on what NVIDIA, developers and customers have built together.

 




OCR in the browser using TensorFlow.js

A guest post by Charles Gaillard, Mindee

Introduction

Optical Character Recognition (OCR) refers to technologies capable of capturing text elements from images or documents and converting them into a machine-readable text format. If you want to learn more on that topic, this article is a good introduction.

At Mindee, we have developed an open-source Python-based OCR called DocTR; however, we also wanted to deploy it in the browser to ensure that it is accessible to all developers, especially since roughly 70% of developers choose to use JavaScript.

We managed to achieve this using the TensorFlow.js API, which resulted in a web demo that you can now try for yourself using images of your own.

The demo interface with a picture of 2 receipts being parsed by the OCR: 89 words were found here

This demo is designed to be very simple to use and to run quickly on most computers, so we provide a single pretrained model trained with a small (512 x 512) input size to save memory. Images are resized to squares, so the model generalizes well to most documents whose aspect ratio is close to 1: cards, smaller receipts, tickets, A4 pages, and so on. For rectangles with a very high aspect ratio, segmentation results might not be as good, because we don’t preserve the aspect ratio (with padding) at the text detection step. The model is optimized for documents with a significant word size relative to the document (for example, receipts and cards). Keep in mind that these models were designed to perform well while running in the browser, so results might not be optimal on documents with very small writing relative to the document size or on images with a very high aspect ratio.

Dive into the architecture

OCR models can be divided into 2 parts: A detection model and a text recognition model. In DocTR, the detection model is a CNN (convolutional neural network) which segments the input image to find text areas, then text boxes are cropped around each detected word and sent to a recognition model. The second model is a convolutional recurrent neural network (CRNN), which extracts features from word-images and then decodes the sequence of letters on the image with recurrent layers (LSTM).

Global architecture of the OCR model used in this Demo
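For reference, the same two-stage pipeline can be run server-side with the Python docTR package. The following is a sketch assuming the ocr_predictor entry point with its default pretrained detection and recognition models:

from doctr.io import DocumentFile
from doctr.models import ocr_predictor

# Load an image as a list of pages
doc = DocumentFile.from_images("receipt.jpg")

# Build the two-stage pipeline: a detection model segments text areas,
# then a recognition model (CRNN) decodes each cropped word
model = ocr_predictor(pretrained=True)

result = model(doc)
print(result.export())  # nested dict of pages, blocks, lines, and words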

Detection model

We have different architectures implemented in DocTR, but we chose a very light one for use on the client side, because device hardware varies from person to person. Here we used a mobilenetV2 backbone with a DB (Differentiable Binarization) head. The implementation details can be found in the DocTR GitHub repository. We trained this model with an input size of (512, 512, 3) to decrease latency and memory usage. We have a private dataset composed of 130,000 annotated documents that was used to train this model.

Recognition model

The recognition model we used is also our lighter architecture: a CRNN (convolutional recurrent neural network) with a mobilenetV2 backbone. More information on this architecture can be found here. It is basically composed of the first half of the mobilenetV2 layers to extract features, followed by 2 bi-LSTMs to decode visual features as character sequences (words). It uses the CTC loss, introduced by Alex Graves, to decode a sequence efficiently. We have an input size of (32, 128, 3) for word images in this model, and we use padding to preserve the aspect ratio of crops. It is trained on our private dataset, composed of 11 million text boxes extracted from different documents. This dataset has a wide variety of fonts, since it is composed of documents which come from many different data sources. We used data augmentation so that it generalizes well on different fonts, backgrounds, and renderings. It should also give decent results on handwritten text as long as it is human-readable.
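To make the preprocessing concrete, here is a small sketch of the resize-and-pad step described above, using NumPy and Pillow; the target size (32, 128) matches the recognition input, while the zero padding value and interpolation choice are assumptions.

import numpy as np
from PIL import Image

def resize_and_pad(crop, target_h=32, target_w=128):
    # Resize a word crop to fit (target_h, target_w) while preserving its
    # aspect ratio, then pad the remaining area with zeros
    scale = min(target_h / crop.height, target_w / crop.width)
    new_w = max(1, int(crop.width * scale))
    new_h = max(1, int(crop.height * scale))
    resized = crop.resize((new_w, new_h), Image.BILINEAR)

    padded = np.zeros((target_h, target_w, 3), dtype=np.uint8)
    padded[:new_h, :new_w, :] = np.asarray(resized.convert("RGB"))
    return padded

# Example usage with a hypothetical word crop saved on disk
word = Image.open("word_crop.png")
batch = np.expand_dims(resize_and_pad(word), axis=0)  # shape (1, 32, 128, 3)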

Model conversion & code implementation

As our models were originally implemented in TensorFlow (Python), conversion was required to run them in the web browser at scale. To do this, we exported a TensorFlow SavedModel for each trained Python model and used the tensorflowjs_converter command line tool to quickly convert our saved models to the TensorFlow.js JSON format required for execution in the browser.
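A sketch of that export-and-convert step is shown below; detection_model stands in for the trained TensorFlow model, and the converter flags are the commonly used ones (check tensorflowjs_converter --help for your installed version).

import subprocess
import tensorflow as tf

# Export the trained detection model as a TensorFlow SavedModel
# (detection_model is a placeholder for the trained tf.keras.Model)
tf.saved_model.save(detection_model, "export/detection_savedmodel")

# Convert the SavedModel to the TensorFlow.js graph model format
subprocess.run(
    [
        "tensorflowjs_converter",
        "--input_format=tf_saved_model",
        "--output_format=tfjs_graph_model",
        "export/detection_savedmodel",
        "web_models/detection",
    ],
    check=True,
)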

The resulting converted models were then integrated into our React.js front-end application that powered the user interface of the demo. More precisely, we used MUI to design the components of the interface, our in-house front-end SDK react-mindee-js (which provides computer vision tools), and OpenCV.js for the detection model post-processing. This post-processing step took the raw binarized segmentation map and converted it to a list of polygons with OpenCV.js functions. We could then crop those boxes from the source image to finally obtain word images ready to be sent to the recognition model.

Speed & performance

We had to manage the tradeoff between speed and performance efficiently. OCR models are quite slow because you have 2 tasks (text areas segmentation + words recognition) that can’t be parallelized, so we had to use lightweight models to ensure speedy execution on most devices.

On a modern computer with an RTX 2060 and an i7 9th Gen, the detection task takes around 750 milliseconds per image, and the recognition model around 170 milliseconds per batch of 32 crops (words) with the WebGL backend, benchmarked with the TensorFlow.js benchmarking tool.

Wrapping up the 2 models and the vision operations (detection post-processing), the end-to-end OCR runs in less than 2 seconds on small documents (fewer than 100 words), and the prediction takes only a few seconds more on very dense documents with a lot of words.

A screenshot of the demo interface with a very dense old A4 document being parsed by the OCR: 738 words are identified.

Conclusion

This demo, powered entirely by TensorFlow.js in the browser, gives almost everyone access to a relatively fast and robust online document OCR, one of the first of its kind.

As we are executing the model on the client side, exact performance will vary depending on the hardware of the device it runs on. However, the goal here is to demonstrate that even complex, state-of-the-art deep learning models can be deployed in the browser and run efficiently on almost every machine. This can be very useful, especially for potentially sensitive documents, where you do not want to send the document to the cloud for analysis.

We are excited to offer this solution for all to use, and keen to follow the future of the Web ML industry, where things will no doubt get faster with time as new web standards like WebGPU become mainstream and enabled by default on modern web browsers.


In bias we trust?

When the stakes are high, machine-learning models are sometimes used to aid human decision-makers. For instance, a model could predict which law school applicants are most likely to pass the bar exam to help an admissions officer determine which students should be accepted.

These models often have millions of parameters, so how they make predictions is nearly impossible for researchers to fully understand, let alone an admissions officer with no machine-learning experience. Researchers sometimes employ explanation methods that mimic a larger model by creating simple approximations of its predictions. These approximations, which are far easier to understand, help users determine whether to trust the model’s predictions.

But are these explanation methods fair? If an explanation method provides better approximations for men than for women, or for white people than for Black people, it may encourage users to trust the model’s predictions for some people but not for others.

MIT researchers took a hard look at the fairness of some widely used explanation methods. They found that the approximation quality of these explanations can vary dramatically between subgroups and that the quality is often significantly lower for minoritized subgroups.

In practice, this means that if the approximation quality is lower for female applicants, there is a mismatch between the explanations and the model’s predictions that could lead the admissions officer to wrongly reject more women than men.

Once the MIT researchers saw how pervasive these fairness gaps are, they tried several techniques to level the playing field. They were able to shrink some gaps, but couldn’t eradicate them.

“What this means in the real-world is that people might incorrectly trust predictions more for some subgroups than for others. So, improving explanation models is important, but communicating the details of these models to end users is equally important. These gaps exist, so users may want to adjust their expectations as to what they are getting when they use these explanations,” says lead author Aparna Balagopalan, a graduate student in the Healthy ML group of the MIT Computer Science and Artificial Intelligence Laboratory (CSAIL).

Balagopalan wrote the paper with CSAIL graduate students Haoran Zhang and Kimia Hamidieh; CSAIL postdoc Thomas Hartvigsen; Frank Rudzicz, associate professor of computer science at the University of Toronto; and senior author Marzyeh Ghassemi, an assistant professor and head of the Healthy ML Group. The research will be presented at the ACM Conference on Fairness, Accountability, and Transparency.

High fidelity

Simplified explanation models can approximate predictions of a more complex machine-learning model in a way that humans can grasp. An effective explanation model maximizes a property known as fidelity, which measures how well it matches the larger model’s predictions.

Rather than focusing on average fidelity for the overall explanation model, the MIT researchers studied fidelity for subgroups of people in the model’s dataset. In a dataset with men and women, the fidelity should be very similar for each group, and both groups should have fidelity close to that of the overall explanation model.

“When you are just looking at the average fidelity across all instances, you might be missing out on artifacts that could exist in the explanation model,” Balagopalan says.

They developed two metrics to measure fidelity gaps, or disparities in fidelity between subgroups. One is the difference between the average fidelity across the entire explanation model and the fidelity for the worst-performing subgroup. The second calculates the absolute difference in fidelity between all possible pairs of subgroups and then computes the average.
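As a concrete reading of those two metrics, the following sketch computes them from per-instance agreement between the explanation model and the original model, grouped by a protected attribute; the data at the bottom is made up purely for illustration.

import numpy as np

def fidelity_gaps(expl_preds, model_preds, groups):
    # Per-instance fidelity: does the explanation match the original model's prediction?
    expl_preds, model_preds, groups = map(np.asarray, (expl_preds, model_preds, groups))
    agree = (expl_preds == model_preds).astype(float)

    overall = agree.mean()
    per_group = {g: agree[groups == g].mean() for g in np.unique(groups)}

    # Metric 1: overall fidelity minus the fidelity of the worst-performing subgroup
    worst_group_gap = overall - min(per_group.values())

    # Metric 2: average absolute fidelity difference over all pairs of subgroups
    vals = list(per_group.values())
    diffs = [abs(a - b) for i, a in enumerate(vals) for b in vals[i + 1:]]
    mean_pairwise_gap = float(np.mean(diffs)) if diffs else 0.0

    return worst_group_gap, mean_pairwise_gap

# Illustrative (made-up) data
print(fidelity_gaps([1, 0, 1, 1, 0, 1], [1, 0, 0, 1, 1, 1], ["a", "a", "a", "b", "b", "b"]))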

With these metrics, they searched for fidelity gaps using two types of explanation models that were trained on four real-world datasets for high-stakes situations, such as predicting whether a patient dies in the ICU, whether a defendant reoffends, or whether a law school applicant will pass the bar exam. Each dataset contained protected attributes, like the sex and race of individual people. Protected attributes are features that may not be used for decisions, often due to laws or organizational policies. The definition for these can vary based on the task specific to each decision setting.

The researchers found clear fidelity gaps for all datasets and explanation models. The fidelity for disadvantaged groups was often much lower, up to 21 percent in some instances. The law school dataset had a fidelity gap of 7 percent between race subgroups, meaning the approximations for some subgroups were wrong 7 percent more often on average. If there are 10,000 applicants from these subgroups in the dataset, for example, a significant portion could be wrongly rejected, Balagopalan explains.

“I was surprised by how pervasive these fidelity gaps are in all the datasets we evaluated. It is hard to overemphasize how commonly explanations are used as a ‘fix’ for black-box machine-learning models. In this paper, we are showing that the explanation methods themselves are imperfect approximations that may be worse for some subgroups,” says Ghassemi.

Narrowing the gaps

After identifying fidelity gaps, the researchers tried some machine-learning approaches to fix them. They trained the explanation models to identify regions of a dataset that could be prone to low fidelity and then focus more on those samples. They also tried using balanced datasets with an equal number of samples from all subgroups.

These robust training strategies did reduce some fidelity gaps, but they didn’t eliminate them.

The researchers then modified the explanation models to explore why fidelity gaps occur in the first place. Their analysis revealed that an explanation model might indirectly use protected group information, like sex or race, that it could learn from the dataset, even if group labels are hidden.

They want to explore this conundrum more in future work. They also plan to further study the implications of fidelity gaps in the context of real-world decision making.

Balagopalan is excited to see that concurrent work on explanation fairness from an independent lab has arrived at similar conclusions, highlighting the importance of understanding this problem well.

As she looks to the next phase in this research, she has some words of warning for machine-learning users.

“Choose the explanation model carefully. But even more importantly, think carefully about the goals of using an explanation model and who it eventually affects,” she says.

This work was funded, in part, by the MIT-IBM Watson AI Lab, the Quanta Research Institute, a Canadian Institute for Advanced Research AI Chair, and Microsoft Research.
