We’ve released a catalogue of ‘missense’ mutations where researchers can learn more about what effect they may have. Missense variants are genetic mutations that can affect the function of human proteins. In some cases, they can lead to diseases such as cystic fibrosis, sickle-cell anaemia, or cancer. The AlphaMissense catalogue was developed using AlphaMissense, our new AI model which classifies missense variants.
Orchestrate Ray-based machine learning workflows using Amazon SageMaker
Machine learning (ML) is becoming increasingly complex as customers try to solve more and more challenging problems. This complexity often leads to the need for distributed ML, where multiple machines are used to train a single model. Although this enables parallelization of tasks across multiple nodes, leading to accelerated training times, enhanced scalability, and improved performance, there are significant challenges in effectively using distributed hardware. Data scientists have to address challenges like data partitioning, load balancing, fault tolerance, and scalability. ML engineers must handle parallelization, scheduling, faults, and retries manually, requiring complex infrastructure code.
In this post, we discuss the benefits of using Ray and Amazon SageMaker for distributed ML, and provide a step-by-step guide on how to use these frameworks to build and deploy a scalable ML workflow.
Ray, an open-source distributed computing framework, provides a flexible framework for distributed training and serving of ML models. It abstracts away low-level distributed system details through simple, scalable libraries for common ML tasks such as data preprocessing, distributed training, hyperparameter tuning, reinforcement learning, and model serving.
SageMaker is a fully managed service for building, training, and deploying ML models. Ray seamlessly integrates with SageMaker features to build and deploy complex ML workloads that are both efficient and reliable. The combination of Ray and SageMaker provides end-to-end capabilities for scalable ML workflows, and has the following highlighted features:
- Distributed actors and parallelism constructs in Ray simplify developing distributed applications.
- Ray AI Runtime (AIR) reduces friction of going from development to production. With Ray and AIR, the same Python code can scale seamlessly from a laptop to a large cluster.
- The managed infrastructure of SageMaker and features like processing jobs, training jobs, and hyperparameter tuning jobs can use Ray libraries underneath for distributed computing.
- Amazon SageMaker Experiments allows rapidly iterating and keeping track of trials.
- Amazon SageMaker Feature Store provides a scalable repository for storing, retrieving, and sharing ML features for model training.
- Trained models can be stored, versioned, and tracked in Amazon SageMaker Model Registry for governance and management.
- Amazon SageMaker Pipelines allows orchestrating the end-to-end ML lifecycle from data preparation and training to model deployment as automated workflows.
Solution overview
This post focuses on the benefits of using Ray and SageMaker together. We set up an end-to-end Ray-based ML workflow, orchestrated using SageMaker Pipelines. The workflow includes parallel ingestion of data into the feature store using Ray actors, data preprocessing with Ray Data, training models and hyperparameter tuning at scale using Ray Train and hyperparameter optimization (HPO) tuning jobs, and finally model evaluation and registering the model into a model registry.
For our data, we use a synthetic housing dataset that consists of eight features (YEAR_BUILT, SQUARE_FEET, NUM_BEDROOM, NUM_BATHROOMS, LOT_ACRES, GARAGE_SPACES, FRONT_PORCH, and DECK), and our model will predict the PRICE of the house.
Each stage in the ML workflow is broken into discrete steps, with its own script that takes input and output parameters. In the next section, we highlight key code snippets from each step. The full code can be found on the aws-samples-for-ray GitHub repository.
Prerequisites
To use the SageMaker Python SDK and run the code associated with this post, you need the following prerequisites:
- An AWS account that contains all your AWS resources
- An AWS Identity and Access Management (IAM) role with access to Amazon SageMaker Studio notebooks, SageMaker Feature Store, SageMaker Model Registry, and SageMaker Pipelines
Ingest data into SageMaker Feature Store
The first step in the ML workflow is to read the source data file from Amazon Simple Storage Service (Amazon S3) in CSV format and ingest it into SageMaker Feature Store. SageMaker Feature Store is a purpose-built repository that makes it easy for teams to create, share, and manage ML features. It simplifies feature discovery, reuse, and sharing, leading to faster development, increased collaboration within customer teams, and reduced costs.
Ingesting features into the feature store involves the following steps:
1. Define a feature group and create the feature group in the feature store.
2. Prepare the source data for the feature store by adding an event time and record ID for each row of data.
3. Ingest the prepared data into the feature group by using the Boto3 SDK.
In this section, we only highlight Step 3, because this is the part that involves parallel processing of the ingestion task using Ray. You can review the full code for this process in the GitHub repo.
The ingest_features method is defined inside a class called Featurestore. Note that the Featurestore class is decorated with @ray.remote. This indicates that an instance of this class is a Ray actor, a stateful and concurrent computational unit within Ray. It’s a programming model that allows you to create distributed objects that maintain an internal state and can be accessed concurrently by multiple tasks running on different nodes in a Ray cluster. Actors provide a way to manage and encapsulate the mutable state, making them valuable for building complex, stateful applications in a distributed setting. You can specify resource requirements in actors too. In this case, each instance of the Featurestore class will require 0.5 CPUs. See the following code:
@ray.remote(num_cpus=0.5)
class Featurestore:
    def ingest_features(self, feature_group_name, df, region):
        """
        Ingest features to Feature Store Group
        Args:
            feature_group_name (str): Feature Group Name
            df (pandas.DataFrame): Partition of the prepared data to ingest
            region (str): AWS Region of the feature group
        """
        ...
You can interact with the actor by calling the remote operator. In the following code, the desired number of actors is passed in as an input argument to the script. The data is then partitioned based on the number of actors and passed to the remote parallel processes to be ingested into the feature store. You can call get on the object ref to block the execution of the current task until the remote computation is complete and the result is available. When the result is available, ray.get will return the result, and the execution of the current task will continue.
import modin.pandas as pd
import numpy as np
import ray

df = pd.read_csv(s3_path)
data = prepare_df_for_feature_store(df)
# Split into partitions
partitions = [ray.put(part) for part in np.array_split(data, args.num_actors)]
# Start actors and assign partitions in a loop
actors = [Featurestore.remote() for _ in range(args.num_actors)]
results = []
for actor, partition in zip(actors, partitions):
    results.append(actor.ingest_features.remote(
        args.feature_group_name,
        partition,
        args.region
    ))
ray.get(results)
Prepare data for training, validation, and testing
In this step, we use Ray Dataset to efficiently split, transform, and scale our dataset in preparation for machine learning. Ray Dataset provides a standard way to load distributed data into Ray, supporting various storage systems and file formats. It has APIs for common ML data preprocessing operations like parallel transformations, shuffling, grouping, and aggregations. Ray Dataset also handles operations needing stateful setup and GPU acceleration. It integrates smoothly with other data processing libraries like Spark, Pandas, NumPy, and more, as well as ML frameworks like TensorFlow and PyTorch. This allows building end-to-end data pipelines and ML workflows on top of Ray. The goal is to make distributed data processing and ML easier for practitioners and researchers.
Let’s look at sections of the scripts that perform this data preprocessing. We start by loading the data from the feature store:
def load_dataset(feature_group_name, region):
    """
    Loads the data as a ray dataset from the offline featurestore S3 location
    Args:
        feature_group_name (str): name of the feature group
        region (str): AWS Region of the feature group
    Returns:
        ds (ray.data.dataset): Ray dataset that contains the requested data from the feature store
    """
    session = sagemaker.Session(boto3.Session(region_name=region))
    fs_group = FeatureGroup(
        name=feature_group_name,
        sagemaker_session=session
    )
    fs_data_loc = fs_group.describe().get("OfflineStoreConfig").get("S3StorageConfig").get("ResolvedOutputS3Uri")
    # Drop columns added by the feature store
    # Since these are not related to the ML problem at hand
    cols_to_drop = ["record_id", "event_time", "write_time",
                    "api_invocation_time", "is_deleted",
                    "year", "month", "day", "hour"]
    ds = ray.data.read_parquet(fs_data_loc)
    ds = ds.drop_columns(cols_to_drop)
    print(f"{fs_data_loc} count is {ds.count()}")
    return ds
We then split and scale the data using the higher-level abstractions available from the ray.data library:
def split_dataset(dataset, train_size, val_size, test_size, random_state=None):
    """
    Split dataset into train, validation and test samples
    Args:
        dataset (ray.data.Dataset): input data
        train_size (float): ratio of data to use as training dataset
        val_size (float): ratio of data to use as validation dataset
        test_size (float): ratio of data to use as test dataset
        random_state (int): Pass an int for reproducible output across multiple function calls.
    Returns:
        train_set (ray.data.Dataset): train dataset
        val_set (ray.data.Dataset): validation dataset
        test_set (ray.data.Dataset): test dataset
    """
    # Shuffle this dataset with a fixed random seed.
    shuffled_ds = dataset.random_shuffle(seed=random_state)
    # Split the data into train, validation and test datasets
    train_set, val_set, test_set = shuffled_ds.split_proportionately([train_size, val_size])
    return train_set, val_set, test_set
def scale_dataset(train_set, val_set, test_set, target_col):
    """
    Fit StandardScaler to train_set and apply it to val_set and test_set
    Args:
        train_set (ray.data.Dataset): train dataset
        val_set (ray.data.Dataset): validation dataset
        test_set (ray.data.Dataset): test dataset
        target_col (str): target col
    Returns:
        train_transformed (ray.data.Dataset): train data scaled
        val_transformed (ray.data.Dataset): val data scaled
        test_transformed (ray.data.Dataset): test data scaled
    """
    transform_cols = train_set.columns()
    # Remove the target column from being scaled
    transform_cols.remove(target_col)
    # Set up a standard scaler
    standard_scaler = StandardScaler(transform_cols)
    # Fit the scaler to the training dataset
    print("Fitting scaling to training data and transforming dataset...")
    train_set_transformed = standard_scaler.fit_transform(train_set)
    # Apply the scaler to the validation and test datasets
    print("Transforming validation and test datasets...")
    val_set_transformed = standard_scaler.transform(val_set)
    test_set_transformed = standard_scaler.transform(test_set)
    return train_set_transformed, val_set_transformed, test_set_transformed
The processed train, validation, and test datasets are stored in Amazon S3 and will be passed as the input parameters to subsequent steps.
Perform model training and hyperparameter optimization
With our data preprocessed and ready for modeling, it’s time to train some ML models and fine-tune their hyperparameters to maximize predictive performance. We use XGBoost-Ray, a distributed backend for XGBoost built on Ray that enables training XGBoost models on large datasets by using multiple nodes and GPUs. It provides simple drop-in replacements for XGBoost’s train and predict APIs while handling the complexities of distributed data management and training under the hood.
To enable distribution of the training over multiple nodes, we utilize a helper class named RayHelper. As shown in the following code, we use the resource configuration of the training job and choose the first host as the head node:
class RayHelper():
    def __init__(self, ray_port: str = "9339", redis_pass: str = "redis_password"):
        ....
        self.resource_config = self.get_resource_config()
        self.head_host = self.resource_config["hosts"][0]
        self.n_hosts = len(self.resource_config["hosts"])
We can use the host information to decide how to initialize Ray on each of the training job instances:
def start_ray(self):
    head_ip = self._get_ip_from_host()
    # If the current host is the host chosen as the head node,
    # run `ray start` with the --head flag to make it the head node
    if self.resource_config["current_host"] == self.head_host:
        output = subprocess.run(['ray', 'start', '--head', '-vvv', '--port',
                                 self.ray_port, '--redis-password', self.redis_pass,
                                 '--include-dashboard', 'false'], stdout=subprocess.PIPE)
        print(output.stdout.decode("utf-8"))
        ray.init(address="auto", include_dashboard=False)
        self._wait_for_workers()
        print("All workers present and accounted for")
        print(ray.cluster_resources())
    else:
        # If the current host is not the head node,
        # run `ray start` and point it at the head node's address
        time.sleep(10)
        output = subprocess.run(['ray', 'start',
                                 f"--address={head_ip}:{self.ray_port}",
                                 '--redis-password', self.redis_pass, "--block"], stdout=subprocess.PIPE)
        print(output.stdout.decode("utf-8"))
        sys.exit(0)
When a training job is started, a Ray cluster can be initialized by calling the start_ray() method on an instance of RayHelper:
if __name__ == '__main__':
    ray_helper = RayHelper()
    ray_helper.start_ray()
    args = read_parameters()
    sess = sagemaker.Session(boto3.Session(region_name=args.region))
We then use the XGBoost trainer from XGBoost-Ray for training:
def train_xgboost(ds_train, ds_val, params, num_workers, target_col="price") -> Result:
    """
    Creates an XGBoost trainer, trains it, and returns the result.
    Args:
        ds_train (ray.data.dataset): Training dataset
        ds_val (ray.data.dataset): Validation dataset
        params (dict): Hyperparameters
        num_workers (int): number of workers to distribute the training across
        target_col (str): target column
    Returns:
        result (ray.air.result.Result): Result of the training job
    """
    train_set = RayDMatrix(ds_train, 'PRICE')
    val_set = RayDMatrix(ds_val, 'PRICE')
    evals_result = {}
    trainer = train(
        params=params,
        dtrain=train_set,
        evals_result=evals_result,
        evals=[(val_set, "validation")],
        verbose_eval=False,
        num_boost_round=100,
        ray_params=RayParams(num_actors=num_workers, cpus_per_actor=1),
    )
    output_path = os.path.join(args.model_dir, 'model.xgb')
    trainer.save_model(output_path)
    valMAE = evals_result["validation"]["mae"][-1]
    valRMSE = evals_result["validation"]["rmse"][-1]
    print('[3] #011validation-mae:{}'.format(valMAE))
    print('[4] #011validation-rmse:{}'.format(valRMSE))
    local_testing = False
    try:
        load_run(sagemaker_session=sess)
    except Exception:
        local_testing = True
    if not local_testing:  # Track experiment if using SageMaker Training
        with load_run(sagemaker_session=sess) as run:
            run.log_metric('validation-mae', valMAE)
            run.log_metric('validation-rmse', valRMSE)
Note that while instantiating the trainer, we pass RayParams, which takes the number of actors and the number of CPUs per actor. XGBoost-Ray uses this information to distribute the training across all the nodes attached to the Ray cluster.
We now create an XGBoost estimator object based on the SageMaker Python SDK and use that for the HPO job.
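The following is a minimal sketch of what that estimator and tuner setup could look like with the SageMaker Python SDK; the entry point script name, instance sizes, hyperparameter ranges, and metric regex are illustrative assumptions rather than the exact values used in the full example.

from sagemaker.xgboost.estimator import XGBoost
from sagemaker.tuner import ContinuousParameter, IntegerParameter, HyperparameterTuner

# Entry point name, instance configuration, and ranges below are assumptions for illustration
xgb_estimator = XGBoost(
    entry_point='train.py',            # script containing RayHelper and train_xgboost
    framework_version='1.5-1',
    role=role_arn,
    instance_count=2,                  # multiple instances form the Ray cluster
    instance_type='ml.m5.2xlarge',
    hyperparameters={'num_workers': 2, 'region': region},
)

tuner = HyperparameterTuner(
    estimator=xgb_estimator,
    objective_metric_name='validation-rmse',
    objective_type='Minimize',
    hyperparameter_ranges={
        'eta': ContinuousParameter(0.01, 0.3),
        'max_depth': IntegerParameter(3, 10),
    },
    metric_definitions=[
        {'Name': 'validation-rmse', 'Regex': 'validation-rmse:([0-9\\.]+)'}
    ],
    max_jobs=4,
    max_parallel_jobs=2,
)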
Orchestrate the preceding steps using SageMaker Pipelines
To build an end-to-end scalable and reusable ML workflow, we need to use a CI/CD tool to orchestrate the preceding steps into a pipeline. SageMaker Pipelines has direct integration with SageMaker, the SageMaker Python SDK, and SageMaker Studio. This integration allows you to create ML workflows with an easy-to-use Python SDK, and then visualize and manage your workflow using SageMaker Studio. You can also track the history of your data within the pipeline execution and designate steps for caching.
SageMaker Pipelines creates a Directed Acyclic Graph (DAG) that includes steps needed to build an ML workflow. Each pipeline is a series of interconnected steps orchestrated by data dependencies between steps, and can be parameterized, allowing you to provide input variables as parameters for each run of the pipeline. SageMaker Pipelines has four types of pipeline parameters: ParameterString, ParameterInteger, ParameterFloat, and ParameterBoolean. In this section, we parameterize some of the input variables and set up the step caching configuration:
processing_instance_count = ParameterInteger(
    name='ProcessingInstanceCount',
    default_value=1
)
feature_group_name = ParameterString(
    name='FeatureGroupName',
    default_value='fs-ray-synthetic-housing-data'
)
bucket_prefix = ParameterString(
    name='Bucket_Prefix',
    default_value='aws-ray-mlops-workshop/feature-store'
)
rmse_threshold = ParameterFloat(name="RMSEThreshold", default_value=15000.0)
train_size = ParameterString(
    name='TrainSize',
    default_value="0.6"
)
val_size = ParameterString(
    name='ValidationSize',
    default_value="0.2"
)
test_size = ParameterString(
    name='TestSize',
    default_value="0.2"
)
cache_config = CacheConfig(enable_caching=True, expire_after="PT12H")
cache_config = CacheConfig(enable_caching=True, expire_after="PT12H")
We define two processing steps: one for SageMaker Feature Store ingestion, the other for data preparation. These should look very similar to the steps described earlier. The only new line of code is the ProcessingStep after the steps’ definition, which allows us to take the processing job configuration and include it as a pipeline step. We further specify the dependency of the data preparation step on the SageMaker Feature Store ingestion step. See the following code:
feature_store_ingestion_step = ProcessingStep(
    name='FeatureStoreIngestion',
    step_args=fs_processor_args,
    cache_config=cache_config
)
preprocess_dataset_step = ProcessingStep(
    name='PreprocessData',
    step_args=processor_args,
    cache_config=cache_config
)
preprocess_dataset_step.add_depends_on([feature_store_ingestion_step])
Similarly, to build a model training and tuning step, we need to add a definition of TuningStep after the model training step’s code to allow us to run SageMaker hyperparameter tuning as a step in the pipeline:
tuning_step = TuningStep(
    name="HPTuning",
    tuner=tuner,
    inputs={
        "train": TrainingInput(
            s3_data=preprocess_dataset_step.properties.ProcessingOutputConfig.Outputs[
                "train"
            ].S3Output.S3Uri,
            content_type="text/csv"
        ),
        "validation": TrainingInput(
            s3_data=preprocess_dataset_step.properties.ProcessingOutputConfig.Outputs[
                "validation"
            ].S3Output.S3Uri,
            content_type="text/csv"
        )
    },
    cache_config=cache_config,
)
tuning_step.add_depends_on([preprocess_dataset_step])
After the tuning step, we choose to register the best model into SageMaker Model Registry. To control the model quality, we implement a minimum quality gate that compares the best model’s objective metric (RMSE) against a threshold defined as the pipeline’s input parameter rmse_threshold. To do this evaluation, we create another processing step to run an evaluation script. The model evaluation result will be stored as a property file. Property files are particularly useful when analyzing the results of a processing step to decide how other steps should be run. See the following code:
# Specify where we'll store the model evaluation results so that other steps can access those results
evaluation_report = PropertyFile(
    name='EvaluationReport',
    output_name='evaluation',
    path='evaluation.json',
)
# A ProcessingStep is used to evaluate the performance of a selected model from the HPO step.
# In this case, the top performing model is evaluated.
evaluation_step = ProcessingStep(
    name='EvaluateModel',
    processor=evaluation_processor,
    inputs=[
        ProcessingInput(
            source=tuning_step.get_top_model_s3_uri(
                top_k=0, s3_bucket=bucket, prefix=s3_prefix
            ),
            destination='/opt/ml/processing/model',
        ),
        ProcessingInput(
            source=preprocess_dataset_step.properties.ProcessingOutputConfig.Outputs['test'].S3Output.S3Uri,
            destination='/opt/ml/processing/test',
        ),
    ],
    outputs=[
        ProcessingOutput(
            output_name='evaluation', source='/opt/ml/processing/evaluation'
        ),
    ],
    code='./pipeline_scripts/evaluate/script.py',
    property_files=[evaluation_report],
)
We define a ModelStep to register the best model into SageMaker Model Registry in our pipeline. In case the best model doesn’t pass our predetermined quality check, we additionally specify a FailStep to output an error message:
register_step = ModelStep(
    name='RegisterTrainedModel',
    step_args=model_registry_args
)
metrics_fail_step = FailStep(
    name="RMSEFail",
    error_message=Join(on=" ", values=["Execution failed due to RMSE >", rmse_threshold]),
)
Next, we use a ConditionStep to evaluate whether the model registration step or failure step should be taken next in the pipeline. In our case, the best model will be registered if its RMSE score is lower than the threshold.
# Condition step for evaluating model quality and branching execution
cond_lte = ConditionLessThanOrEqualTo(
    left=JsonGet(
        step_name=evaluation_step.name,
        property_file=evaluation_report,
        json_path='regression_metrics.rmse.value',
    ),
    right=rmse_threshold,
)
condition_step = ConditionStep(
    name='CheckEvaluation',
    conditions=[cond_lte],
    if_steps=[register_step],
    else_steps=[metrics_fail_step],
)
Finally, we orchestrate all the defined steps into a pipeline:
pipeline_name = 'synthetic-housing-training-sm-pipeline-ray'
step_list = [
    feature_store_ingestion_step,
    preprocess_dataset_step,
    tuning_step,
    evaluation_step,
    condition_step
]
training_pipeline = Pipeline(
    name=pipeline_name,
    parameters=[
        processing_instance_count,
        feature_group_name,
        train_size,
        val_size,
        test_size,
        bucket_prefix,
        rmse_threshold
    ],
    steps=step_list
)
# Note: If an existing pipeline has the same name it will be overwritten.
training_pipeline.upsert(role_arn=role_arn)
The preceding pipeline can be visualized and executed directly in SageMaker Studio, or be executed by calling execution = training_pipeline.start(). The following figure illustrates the pipeline flow.
Additionally, we can review the lineage of artifacts generated by the pipeline execution.
import time
from sagemaker.lineage.visualizer import LineageTableVisualizer

viz = LineageTableVisualizer(sagemaker.session.Session())
for execution_step in reversed(execution.list_steps()):
    print(execution_step)
    display(viz.show(pipeline_execution_step=execution_step))
    time.sleep(5)
Deploy the model
After the best model is registered in SageMaker Model Registry via a pipeline run, we deploy the model to a real-time endpoint by using the fully managed model deployment capabilities of SageMaker. SageMaker has other model deployment options to meet the needs of different use cases. For details, refer to Deploy models for inference when choosing the right option for your use case. First, let’s get the model registered in SageMaker Model Registry:
xgb_regressor_model = ModelPackage(
    role_arn,
    model_package_arn=model_package_arn,
    name=model_name
)
The model’s current status is PendingApproval. We need to set its status to Approved prior to deployment:
sagemaker_client.update_model_package(
    ModelPackageArn=xgb_regressor_model.model_package_arn,
    ModelApprovalStatus='Approved'
)
xgb_regressor_model.deploy(
    initial_instance_count=1,
    instance_type='ml.m5.xlarge',
    endpoint_name=endpoint_name
)
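Once the endpoint is in service, we can send it a test record to verify predictions. The following is a minimal sketch; the feature values are placeholders and assume the same column order as the training data.

from sagemaker.predictor import Predictor
from sagemaker.serializers import CSVSerializer

predictor = Predictor(
    endpoint_name=endpoint_name,
    serializer=CSVSerializer(),
)
# Hypothetical record: YEAR_BUILT, SQUARE_FEET, NUM_BEDROOM, NUM_BATHROOMS,
# LOT_ACRES, GARAGE_SPACES, FRONT_PORCH, DECK
payload = "2006,2800,4,3.0,0.42,2,1,1"
print(predictor.predict(payload))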
Clean up
After you are done experimenting, remember to clean up the resources to avoid unnecessary charges. To clean up, delete the real-time endpoint, model group, pipeline, and feature group by calling the APIs DeleteEndpoint, DeleteModelPackageGroup, DeletePipeline, and DeleteFeatureGroup, respectively, and shut down all SageMaker Studio notebook instances.
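The following is a minimal cleanup sketch using Boto3; the resource names reuse the variables from earlier in this post, the endpoint config name is assumed to match the endpoint name, and model packages inside a group must be deleted before the group itself.

import boto3

sagemaker_client = boto3.client('sagemaker')

# Delete the real-time endpoint and its configuration
sagemaker_client.delete_endpoint(EndpointName=endpoint_name)
sagemaker_client.delete_endpoint_config(EndpointConfigName=endpoint_name)  # assumes the config shares the endpoint name

# Delete all model packages in the group, then the group itself
for package in sagemaker_client.list_model_packages(
        ModelPackageGroupName=model_package_group_name)['ModelPackageSummaryList']:
    sagemaker_client.delete_model_package(ModelPackageName=package['ModelPackageArn'])
sagemaker_client.delete_model_package_group(ModelPackageGroupName=model_package_group_name)

# Delete the pipeline and the feature group
sagemaker_client.delete_pipeline(PipelineName=pipeline_name)
sagemaker_client.delete_feature_group(FeatureGroupName=feature_group_name)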
Conclusion
This post provided a step-by-step walkthrough of using SageMaker Pipelines to orchestrate Ray-based ML workflows. We also demonstrated the capability of SageMaker Pipelines to integrate with third-party ML tools. Various AWS services support Ray workloads in a scalable and secure fashion to ensure performance excellence and operational efficiency. Now, it’s your turn to explore these powerful capabilities and start optimizing your machine learning workflows with Amazon SageMaker Pipelines and Ray. Take action today and unlock the full potential of your ML projects!
About the Author
Raju Rangan is a Senior Solutions Architect at Amazon Web Services (AWS). He works with government sponsored entities, helping them build AI/ML solutions using AWS. When not tinkering with cloud solutions, you’ll catch him hanging out with family or smashing birdies in a lively game of badminton with friends.
Sherry Ding is a senior AI/ML specialist solutions architect at Amazon Web Services (AWS). She has extensive experience in machine learning with a PhD degree in computer science. She mainly works with public sector customers on various AI/ML-related business challenges, helping them accelerate their machine learning journey on the AWS Cloud. When not helping customers, she enjoys outdoor activities.
Designing resilient cities at Arup using Amazon SageMaker geospatial capabilities
This post is co-authored with Richard Alexander and Mark Hallows from Arup.
Arup is a global collective of designers, consultants, and experts dedicated to sustainable development. Data underpins Arup’s consultancy for clients, with world-class collection and analysis providing the insight needed to make an impact.
The solution presented here helps direct decision-making processes for resilient city design. Informing design decisions towards more sustainable choices reduces the overall urban heat island (UHI) effect and improves quality of life metrics for air quality, water quality, urban acoustics, biodiversity, and thermal comfort. Identifying key areas within an urban environment for intervention allows Arup to provide the best guidance in the industry and create better quality of life for citizens around the planet.
Urban heat islands describe the effect urban areas have on temperature compared to surrounding rural environments. Understanding how UHI affects our cities leads to improved designs that reduce the impact of urban heat on residents. The UHI effect impacts human health, greenhouse gas emissions, and water quality, and leads to increased energy usage. For city authorities, asset owners, and developers, understanding the impact on the population is key to improving quality of life and natural ecosystems. Modeling UHI accurately is a complex challenge, which Arup is now solving with earth observation data and Amazon SageMaker.
This post shows how Arup partnered with AWS to perform earth observation analysis with Amazon SageMaker geospatial capabilities to unlock UHI insights from satellite imagery. SageMaker geospatial capabilities make it easy for data scientists and machine learning (ML) engineers to build, train, and deploy models using geospatial data. SageMaker geospatial capabilities allow you to efficiently transform and enrich large-scale geospatial datasets, accelerate product development and time to insight with pre-trained ML models, and explore model predictions and geospatial data on an interactive map using 3D accelerated graphics and built-in visualization tools.
Overview of solution
The initial solution focuses on London, where during a heatwave in the summer of 2022, the UK Health Security Agency estimated 2,803 excess deaths were caused by heat. Identifying areas within an urban environment where people may be more vulnerable to the UHI effect allows public services to direct assistance where it will have the greatest impact. This can even be forecast prior to high temperature events, reducing the impact of extreme weather and delivering a positive outcome for city dwellers.
Earth Observation (EO) data was used to perform the analysis at city scale. However, the total size poses challenges with traditional ways of storing, organizing, and querying data for large geographical areas. Arup addressed this challenge by partnering with AWS and using SageMaker geospatial capabilities to enable analysis at a city scale and beyond. As the geographic area grows to larger metropolitan areas like Los Angeles or Tokyo, more storage and compute are required for the analysis. The elasticity of AWS infrastructure is ideal for UHI analyses of urban environments of any size.
The solution: UHeat
Arup used SageMaker to develop UHeat, a digital solution that analyzes huge areas of cities to identify particular buildings, structures, and materials that are causing temperatures to rise. UHeat uses a combination of satellite imagery and open-source climate data to perform the analysis.
A small team at Arup undertook the initial analysis, during which additional data scientists needed to be trained on the SageMaker tooling and workflows. Onboarding data scientists to a new project used to take weeks using in-house tools. This now takes a matter of hours with SageMaker.
The first step of any EO analysis is the collection and preparation of the data. With SageMaker, Arup can access data from a catalog of geospatial data providers, including Sentinel-2 data, which was used for the London analysis. Built-in geospatial dataset access saves weeks of effort otherwise lost to collecting and preparing data from various data providers and vendors. EO imagery is frequently made up of small tiles which, to cover an area the size of London, need to be combined. This is known as a geomosaic, which can be created automatically using the managed geospatial operations in a SageMaker Geomosaic Earth Observation job.
After the EO data for the area of interest is compiled, the key influencing parameters for the analysis can be extracted. For UHI, Arup needed to be able to derive data on parameters for building geometry, building materials, anthropogenic heat sources, and coverage of existing and planned green spaces. Using optical imagery such as Sentinel-2, land cover classes including buildings, roads, water, vegetation cover, bare ground, and the albedo (measure of reflectiveness) of each of these surfaces can be calculated.
Calculating the values from the different bands in the satellite imagery allows them to be used as inputs into the SUEWS model, which provides a rigorous way of calculating UHI effect. The results of SUEWS are then visualized, in this case with Arup’s existing geospatial data platform. By adjusting values such as the albedo of a specific location, Arup are able to test the effect of mitigation strategies. Albedo performance can be further refined in simulations by modeling different construction materials, cladding, or roofing. Arup found that in one area of London, increasing albedo from 0.1 to 0.9 could decrease ambient temperature by 1.1°C during peak conditions. Over larger areas of interest, this modeling can also be used to forecast the UHI effect alongside climate projections to quantify the scale of the UHI effect.
With historical data from sources such as Sentinel-2, temporal studies can be completed. This enables Arup to visualize the UHI effect during periods of interest, such as the London summer 2022 heatwave. The Urban Heat Snapshot research Arup has completed reveals how the UHI effect is pushing up temperatures in cities like London, Madrid, Mumbai, and Los Angeles.
Collecting data for an area of interest
SageMaker eliminates the complexities in manually collecting data for Earth Observation jobs (EOJs) by providing a catalog of geospatial data providers. As of this writing, USGS Landsat, Sentinel-1, Copernicus DEM, NAIP: National Agriculture Imagery Program, and Sentinel-2 data is available directly from the catalog. You can also bring your own Planet Labs data when imagery at a higher resolution and frequency is required. Built-in geospatial dataset access saves weeks of effort otherwise lost to collecting data from various data providers and vendors. Coordinates for the polygon area of interest need to be provided as well as the time range for when EO imagery was collected.
Arup’s next step was to combine these images into a larger single raster covering the entire area of interest. This is known as mosaicking and is supported by passing GeoMosaicConfig to the SageMaker StartEarthObservationJob API.
We have provided some code samples representative of the steps Arup took:
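The sketch below shows what starting a geomosaic EOJ against the Sentinel-2 collection could look like with the sagemaker-geospatial Boto3 client; the collection ARN, execution role, polygon coordinates, time range, and mosaic algorithm name are placeholders, and the exact request shape may differ from what Arup used.

import boto3

geospatial_client = boto3.client('sagemaker-geospatial', region_name='us-west-2')

eoj_input_config = {
    "RasterDataCollectionQuery": {
        "RasterDataCollectionArn": sentinel2_collection_arn,  # Sentinel-2 entry from the data catalog
        "AreaOfInterest": {
            "AreaOfInterestGeometry": {
                "PolygonGeometry": {"Coordinates": [london_polygon]}  # [[lon, lat], ...] ring around London
            }
        },
        "TimeRangeFilter": {
            "StartTime": "2022-06-01T00:00:00Z",
            "EndTime": "2022-08-31T23:59:59Z",
        },
    }
}

response = geospatial_client.start_earth_observation_job(
    Name="london-uhi-geomosaic",
    ExecutionRoleArn=execution_role_arn,
    InputConfig=eoj_input_config,
    JobConfig={"GeoMosaicConfig": {"AlgorithmName": "NEAR"}},  # nearest-neighbor mosaicking (algorithm name assumed)
)
mosaic_eoj_arn = response["Arn"]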
This can take a while to complete. You can check the status of your jobs like so:
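For example, with a simple polling loop (illustrative; status values follow those returned by GetEarthObservationJob):

import time

while True:
    job = geospatial_client.get_earth_observation_job(Arn=mosaic_eoj_arn)
    print(f"Earth observation job status: {job['Status']}")
    if job["Status"] in ("COMPLETED", "FAILED"):
        break
    time.sleep(60)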
Resampling
Next, the raster is resampled to normalize the pixel size across the collected images. You can use ResamplingConfig to achieve this by providing the value of the length of a side of the pixel:
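A sketch of that resampling job, chained to the output of the mosaic job, might look like the following; the 10 m target resolution is an assumption based on Sentinel-2's native resolution for its visible and near-infrared bands.

resample_response = geospatial_client.start_earth_observation_job(
    Name="london-uhi-resample",
    ExecutionRoleArn=execution_role_arn,
    InputConfig={"PreviousEarthObservationJobArn": mosaic_eoj_arn},  # reuse the mosaic output
    JobConfig={
        "ResamplingConfig": {
            "OutputResolution": {"UserDefined": {"Value": 10, "Unit": "METERS"}}
        }
    },
)
resample_eoj_arn = resample_response["Arn"]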
Determining coverage
Determining land coverage such as vegetation is possible by applying a normalized difference vegetation index (NDVI). In practice, this can be calculated from the intensity of reflected red and near-infrared light. To apply such a calculation to EO data within SageMaker, the BandMathConfig can be supplied to the StartEarthObservationJob API:
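For example, an NDVI band math job over the resampled imagery could be started as follows; the band names in the equation assume Sentinel-2's near-infrared and red bands are exposed as nir and red.

ndvi_response = geospatial_client.start_earth_observation_job(
    Name="london-uhi-ndvi",
    ExecutionRoleArn=execution_role_arn,
    InputConfig={"PreviousEarthObservationJobArn": resample_eoj_arn},
    JobConfig={
        "BandMathConfig": {
            "CustomIndices": {
                "Operations": [
                    # Normalized difference vegetation index from near-infrared and red reflectance
                    {"Name": "ndvi", "Equation": "(nir - red) / (nir + red)"}
                ]
            }
        }
    },
)
ndvi_eoj_arn = ndvi_response["Arn"]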
We can visualize the result of the band math job output within the SageMaker geospatial capabilities visualization tool. SageMaker geospatial capabilities can help you overlay model predictions on a base map and provide layered visualization to make collaboration easier. The GPU-powered interactive visualizer and Python notebooks provide a seamless way to explore millions of data points in a single window as well as collaborate on the insights and results.
Preparing for visualization
As a final step, Arup prepares the various bands and calculated bands for visualization by combining them into a single GeoTIFF. For band stacking, SageMaker EOJs can be passed the StackConfig object, where the output resolution can be set based on the resolutions of the input images:
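A band-stacking sketch is shown below; the list of target bands is illustrative rather than the exact set Arup stacked.

stack_response = geospatial_client.start_earth_observation_job(
    Name="london-uhi-stack",
    ExecutionRoleArn=execution_role_arn,
    InputConfig={"PreviousEarthObservationJobArn": ndvi_eoj_arn},
    JobConfig={
        "StackConfig": {
            "OutputResolution": {"Predefined": "HIGHEST"},  # match the finest input resolution
            "TargetBands": ["red", "green", "blue", "nir", "ndvi"],  # bands assumed for illustration
        }
    },
)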
Finally, the output GeoTIFF can be stored for later use in Amazon Simple Storage Service (Amazon S3) or visualized using SageMaker geospatial capabilities. By storing the output in Amazon S3, Arup can use the analysis in new projects and incorporate the data into new inference jobs. In Arup’s case, they used the processed GeoTIFF in their existing geographic information system visualization tooling to produce visualizations consistent with their product design themes.
Conclusion
By utilizing the native functionality of SageMaker, Arup was able to conduct an analysis of the UHI effect at city scale in a few hours, a process that previously took weeks. This helps Arup enable their own clients to meet their sustainability targets faster and narrows the areas of focus where UHI effect mitigation strategies should be applied, saving precious resources and optimizing mitigation tactics. The analysis can also be integrated into future earth observation tooling as part of larger risk analysis projects, and helps Arup’s customers forecast the effect of UHI in different scenarios.
Companies such as Arup are unlocking sustainability through the cloud with earth observation data. Unlock the possibilities of earth observation data in your sustainability projects by exploring the SageMaker geospatial capabilities on the SageMaker console today. To find out more, refer to Amazon SageMaker geospatial capabilities, or get hands on with a guidance solution.
About the Authors
Richard Alexander is an Associate Geospatial Data Scientist at Arup, based in Bristol. He has a proven track record of building successful teams and leading and delivering earth observation and data science-related projects across multiple environmental sectors.
Mark Hallows is a Remote Sensing Specialist at Arup, based in London. Mark provides expertise in earth observation and geospatial data analysis to a broad range of clients and delivers insights and thought leadership using both traditional machine learning and deep learning techniques.
Thomas Attree is a Senior Solutions Architect at Amazon Web Services based in London. Thomas currently helps customers in the power and utilities industry and applies his passion for sustainability to help customers architect applications for energy efficiency, as well as advise on using cloud technology to empower sustainability projects.
Tamara Herbert is a Senior Application Developer with AWS Professional Services in the UK. She specializes in building modern and scalable applications for a wide variety of customers, currently focusing on those within the public sector. She is actively involved in building solutions and driving conversations that enable organizations to meet their sustainability goals both in and through the cloud.
Anirudh Viswanathan is a Sr Product Manager, Technical – External Services with the SageMaker geospatial ML team. He holds a Masters in Robotics from Carnegie Mellon University and an MBA from the Wharton School of Business, and is named inventor on over 50 patents. He enjoys long-distance running, visiting art galleries, and Broadway shows.
Ray Shines with NVIDIA AI: Anyscale Collaboration to Help Developers Build, Tune, Train and Scale Production LLMs
Large language model development is about to reach supersonic speed thanks to a collaboration between NVIDIA and Anyscale.
At its annual Ray Summit developers conference, Anyscale — the company behind the fast-growing open-source unified compute framework for scalable computing — announced today that it is bringing NVIDIA AI to Ray open source and the Anyscale Platform. It will also be integrated into Anyscale Endpoints, a new service announced today that makes it easy for application developers to cost-effectively embed LLMs in their applications using the most popular open source models.
These integrations can dramatically speed generative AI development and efficiency while boosting security for production AI, from proprietary LLMs to open models such as Code Llama, Falcon, Llama 2, SDXL and more.
Developers will have the flexibility to deploy open-source NVIDIA software with Ray or opt for NVIDIA AI Enterprise software running on the Anyscale Platform for a fully supported and secure production deployment.
Ray and the Anyscale Platform are widely used by developers building advanced LLMs for generative AI applications capable of powering intelligent chatbots, coding copilots and powerful search and summarization tools.
NVIDIA and Anyscale Deliver Speed, Savings and Efficiency
Generative AI applications are captivating the attention of businesses around the globe. Fine-tuning, augmenting and running LLMs requires significant investment and expertise. Together, NVIDIA and Anyscale can help reduce costs and complexity for generative AI development and deployment with a number of application integrations.
NVIDIA TensorRT-LLM, new open-source software announced last week, will support Anyscale offerings to supercharge LLM performance and efficiency to deliver cost savings. Also supported in the NVIDIA AI Enterprise software platform, TensorRT-LLM automatically scales inference to run models in parallel over multiple GPUs, which can provide up to 8x higher performance when running on NVIDIA H100 Tensor Core GPUs, compared to prior-generation GPUs.
TensorRT-LLM automatically scales inference to run models in parallel over multiple GPUs and includes custom GPU kernels and optimizations for a wide range of popular LLM models. It also implements the new FP8 numerical format available in the NVIDIA H100 Tensor Core GPU Transformer Engine and offers an easy-to-use and customizable Python interface.
NVIDIA Triton Inference Server software supports inference across cloud, data center, edge and embedded devices on GPUs, CPUs and other processors. Its integration can enable Ray developers to boost efficiency when deploying AI models from multiple deep learning and machine learning frameworks, including TensorRT, TensorFlow, PyTorch, ONNX, OpenVINO, Python, RAPIDS XGBoost and more.
With the NVIDIA NeMo framework, Ray users will be able to easily fine-tune and customize LLMs with business data, paving the way for LLMs that understand the unique offerings of individual businesses.
NeMo is an end-to-end, cloud-native framework to build, customize and deploy generative AI models anywhere. It features training and inferencing frameworks, guardrailing toolkits, data curation tools and pretrained models, offering enterprises an easy, cost-effective and fast way to adopt generative AI.
Options for Open-Source or Fully Supported Production AI
Ray open source and the Anyscale Platform enable developers to effortlessly move from open source to deploying production AI at scale in the cloud.
The Anyscale Platform provides fully managed, enterprise-ready unified computing that makes it easy to build, deploy and manage scalable AI and Python applications using Ray, helping customers bring AI products to market faster at significantly lower cost.
Whether developers use Ray open source or the supported Anyscale Platform, Anyscale’s core functionality helps them easily orchestrate LLM workloads. The NVIDIA AI integration can help developers build, train, tune and scale AI with even greater efficiency.
Ray and the Anyscale Platform run on accelerated computing from leading clouds, with the option to run on hybrid or multi-cloud computing. This helps developers easily scale up as they need more computing to power a successful LLM deployment.
The collaboration will also enable developers to begin building models on their workstations through NVIDIA AI Workbench and scale them easily across hybrid or multi-cloud accelerated computing once it’s time to move to production.
NVIDIA AI integrations with Anyscale are in development and expected to be available by the end of the year.
Developers can sign up to get the latest news on this integration as well as a free 90-day evaluation of NVIDIA AI Enterprise.
To learn more, attend the Ray Summit in San Francisco this week or watch the demo video below.
See this notice regarding NVIDIA’s software roadmap.
NeILF++: Inter-Reflectable Light Fields for Geometry and Material Estimation
We present a novel differentiable rendering framework for joint geometry, material, and lighting estimation from multi-view images. In contrast to previous methods which assume a simplified environment map or co-located flashlights, in this work, we formulate the lighting of a static scene as one neural incident light field (NeILF) and one outgoing neural radiance field (NeRF). The key insight of the proposed method is the union of the incident and outgoing light fields through physically-based rendering and inter-reflections between surfaces, making it possible to disentangle the scene… (Apple Machine Learning Research)
MediaPipe FaceStylizer: On-device real-time few-shot face stylization
In recent years, we have witnessed rising interest across consumers and researchers in integrated augmented reality (AR) experiences using real-time face feature generation and editing functions in mobile applications, including short videos, virtual reality, and gaming. As a result, there is a growing demand for lightweight, yet high-quality face generation and editing models, which are often based on generative adversarial network (GAN) techniques. However, the majority of GAN models suffer from high computational complexity and the need for a large training dataset. In addition, it is also important to employ GAN models responsibly.
In this post, we introduce MediaPipe FaceStylizer, an efficient design for few-shot face stylization that addresses the aforementioned model complexity and data efficiency challenges while being guided by Google’s responsible AI Principles. The model consists of a face generator and a face encoder used as GAN inversion to map the image into latent code for the generator. We introduce a mobile-friendly synthesis network for the face generator with an auxiliary head that converts features to RGB at each level of the generator to generate high quality images from coarse to fine granularities. We also carefully designed the loss functions for the aforementioned auxiliary heads and combined them with the common GAN loss functions to distill the student generator from the teacher StyleGAN model, resulting in a lightweight model that maintains high generation quality. The proposed solution is available in open source through MediaPipe. Users can fine-tune the generator to learn a style from one or a few images using MediaPipe Model Maker, and deploy to on-device face stylization applications with the customized model using MediaPipe FaceStylizer.
Few-shot on-device face stylization
An end-to-end pipeline
Our goal is to build a pipeline that supports users in adapting the MediaPipe FaceStylizer to different styles by fine-tuning the model with a few examples. To enable such a face stylization pipeline, we built the pipeline with a GAN inversion encoder and efficient face generator model (see below). The encoder and generator pipeline can then be adapted to different styles via a few-shot learning process. The user first sends a single or a few similar samples of the style images to MediaPipe Model Maker to fine-tune the model. The fine-tuning process freezes the encoder module and only fine-tunes the generator. The training process samples multiple latent codes close to the encoding output of the input style images as the input to the generator. The generator is then trained to reconstruct an image of a person’s face in the style of the input style image by optimizing a joint adversarial loss function that also accounts for style and content. With such a fine-tuning process, the MediaPipe FaceStylizer can adapt to the customized style, which approximates the user’s input. It can then be applied to stylize test images of real human faces.
Generator: BlazeStyleGAN
The StyleGAN model family has been widely adopted for face generation and various face editing tasks. To support efficient on-device face generation, we based the design of our generator on StyleGAN. This generator, which we call BlazeStyleGAN, is similar to StyleGAN in that it also contains a mapping network and synthesis network. However, since the synthesis network of StyleGAN is the major contributor to the model’s high computation complexity, we designed and employed a more efficient synthesis network. The improved efficiency and generation quality is achieved by:
- Reducing the latent feature dimension in the synthesis network to a quarter of the resolution of the counterpart layers in the teacher StyleGAN,
- Designing multiple auxiliary heads to transform the downscaled feature to the image domain to form a coarse-to-fine image pyramid to evaluate the perceptual quality of the reconstruction, and
- Skipping all but the final auxiliary head at inference time.
With the newly designed architecture, we train the BlazeStyleGAN model by distilling it from a teacher StyleGAN model. We use a multi-scale perceptual loss and adversarial loss in the distillation to transfer the high fidelity generation capability from the teacher model to the student BlazeStyleGAN model and also to mitigate the artifacts from the teacher model.
More details of the model architecture and training scheme can be found in our paper.
In the above figure, we demonstrate some sample results of our BlazeStyleGAN. By comparing with the face image generated by the teacher StyleGAN model (top row), the images generated by the student BlazeStyleGAN (bottom row) maintain high visual quality and further reduce artifacts produced by the teacher due to the loss function design in our distillation.
An encoder for efficient GAN inversion
To support image-to-image stylization, we also introduced an efficient GAN inversion as the encoder to map input images to the latent space of the generator. The encoder is defined by a MobileNet V2 backbone and trained with natural face images. The loss is defined as a combination of image perceptual quality loss, which measures the content difference, style similarity and embedding distance, as well as the L1 loss between the input images and reconstructed images.
On-device performance
We documented model complexities in terms of parameter numbers and computing FLOPs in the following table. Compared to the teacher StyleGAN (33.2M parameters), BlazeStyleGAN (generator) significantly reduces the model complexity, with only 2.01M parameters and 1.28G FLOPs for output resolution 256×256. Compared to StyleGAN-1024 (generating image size of 1024×1024), the BlazeStyleGAN-1024 can reduce both model size and computation complexity by 95% with no notable quality difference and can even suppress the artifacts from the teacher StyleGAN model.
Model | Image Size | #Params (M) | FLOPs (G)
StyleGAN | 1024 | 33.17 | 74.3
BlazeStyleGAN | 1024 | 2.07 | 4.70
BlazeStyleGAN | 512 | 2.05 | 1.57
BlazeStyleGAN | 256 | 2.01 | 1.28
Encoder | 256 | 1.44 | 0.60
Model complexity measured by parameter numbers and FLOPs.
We benchmarked the inference time of the MediaPipe FaceStylizer on various high-end mobile devices and demonstrated the results in the table below. From the results, both BlazeStyleGAN-256 and BlazeStyleGAN-512 achieved real-time performance on all GPU devices. It can run in less than 10 ms on a high-end phone’s GPU. BlazeStyleGAN-256 can also achieve real-time performance on the iOS devices’ CPU.
Device | BlazeStyleGAN-256 (ms) | Encoder-256 (ms)
iPhone 11 | 12.14 | 11.48
iPhone 12 | 11.99 | 12.25
iPhone 13 Pro | 7.22 | 5.41
Pixel 6 | 12.24 | 11.23
Samsung Galaxy S10 | 17.01 | 12.70
Samsung Galaxy S20 | 8.95 | 8.20
Latency benchmark of the BlazeStyleGAN, face encoder, and the end-to-end pipeline on various mobile devices.
Fairness evaluation
The model has been trained with a high-diversity dataset of human faces and is expected to be fair to different human faces. The fairness evaluation demonstrates that the model performs well and is balanced across gender, skin tone, and age.
Face stylization visualization
Some face stylization results are demonstrated in the following figure. The images in the top row (in orange boxes) represent the style images used to fine-tune the model. The images in the left column (in the green boxes) are the natural face images used for testing. The 2×4 matrix of images represents the output of the MediaPipe FaceStylizer, blending the natural faces in the left-most column with the corresponding face styles in the top row. The results demonstrate that our solution can achieve high-quality face stylization for several popular styles.
Sample results of our MediaPipe FaceStylizer.
MediaPipe Solutions
The MediaPipe FaceStylizer is going to be released to public users in MediaPipe Solutions. Users can leverage MediaPipe Model Maker to train a customized face stylization model using their own style images. After training, the exported bundle of TFLite model files can be deployed to applications across platforms (Android, iOS, Web, Python, etc.) using the MediaPipe Tasks FaceStylizer API in just a few lines of code.
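As a rough illustration of that deployment step, a Python sketch is shown below; the exact module paths, option classes, and bundle filename are assumptions patterned on other MediaPipe Tasks APIs, so check the official MediaPipe Solutions documentation for the released interface.

import mediapipe as mp
from mediapipe.tasks import python
from mediapipe.tasks.python import vision

# Load a customized face stylization bundle exported by MediaPipe Model Maker (filename assumed)
options = vision.FaceStylizerOptions(
    base_options=python.BaseOptions(model_asset_path="face_stylizer.task")
)
with vision.FaceStylizer.create_from_options(options) as stylizer:
    image = mp.Image.create_from_file("portrait.jpg")
    stylized = stylizer.stylize(image)  # returns the stylized face as an mp.Image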
Acknowledgements
This work is made possible through a collaboration spanning several teams across Google. We’d like to acknowledge contributions from Omer Tov, Yang Zhao, Andrey Vakunov, Fei Deng, Ariel Ephrat, Inbar Mosseri, Lu Wang, Chuo-Ling Chang, Tingbo Hou, and Matthias Grundmann.
Test-time Adaptation with Slot-Centric Models
TLDR: Current SOTA methods for scene understanding, though impressive, often fail to decompose out-of-distribution scenes. In our ICML paper, Slot-TTA (http://slot-tta.github.io), we find that optimizing over the reconstruction loss for each test sample improves scene decomposition accuracy.
Problem Statement: In machine learning, we often assume the train and test split are IID samples from the same distribution. However, this doesn’t hold true in reality. In fact, there is a distribution shift happening all the time!
For example on the left, we visualize images from the ImageNet Chair category, and on the right, we visualize the ObjectNet chair category. As you can see there are a variety of real-world distribution shifts happening all the time. For instance, camera pose changes, occlusions, and changes in scene configuration.
So what is the issue? The issue is that in machine learning we always assume there to be a fixed train and test split. However, in the real world, there is no such universal train and test split, instead, there are distribution shifts happening all the time.
Instead of freezing our models at test time, which is what we conventionally do, we should instead continuously adapt them to various distribution shifts.
Given these issues, there has been a lot of work in this domain, which is also referred to as test-time adaptation. Test-time adaptation can be broadly classified into supervised test-time adaptation, where you are given access to a few labeled examples, or unsupervised domain adaptation where you do not have access to any labels. In this work, we focus on unsupervised adaptation, as it is a more general setting.
Within unsupervised domain adaptation, there are various settings such as batch, online, or single-example test-time adaptation. In this work, we focus on the single-example setting. In this setting, the model adapts to each example in the test set independently. This is a much more general setting than batch or online adaptation, where you assume access to many unlabeled examples.
What is the prominent approach in this setting?
Sun et al. proposed encoding input data X into a shared encoding z, which is then passed to a supervised task decoder and a self-supervised decoder. The whole model is then trained jointly using supervised and self-supervised losses. This joint training helps to couple the self-supervised and supervised tasks. Coupling allows test-time adaptation using the self-supervised loss. Approaches vary based on the type of self-supervised loss used: TTT uses a rotation prediction loss, MT3 uses an instance prediction loss, and TTT-MAE uses a masked autoencoding loss.
However, all these approaches focus only on the task of image classification. In our work, we find that joint training with such losses alone is insufficient for scene understanding tasks, and that architectural biases are important for adaptation. Specifically, slot-centric biases, which strongly couple scene decomposition and the reconstruction loss, are a perfect fit.
Slot-centric generative models attempt to segment scenes into object entities in a completely unsupervised manner by optimizing a reconstruction objective [1,2,3]. Because this objective shares the end goal of scene decomposition, these models are good candidate architectures for TTA.
These methods differ in detail but share the notion of incorporating a fixed set of entities, also known as slots or object files. Each slot extracts information about a single entity during encoding and is “synthesized” back to the input domain during decoding.
In light of the above, we propose Test-Time Adaptation with Slot-Centric models (Slot-TTA), a semi-supervised model equipped with a slot-centric bottleneck that jointly segments and reconstructs scenes.
At training time, Slot-TTA is trained in a supervised manner to jointly segment and reconstruct 2D (multi-view or single-view) RGB images or 3D point clouds. At test time, the model adapts to a single test sample by updating its network parameters solely by optimizing the reconstruction objective through gradient descent, as shown in the above figure.
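As a rough, framework-agnostic sketch of this per-sample adaptation loop (PyTorch-style pseudocode; the model interface, step count, and learning rate are illustrative, not the paper's exact implementation):

import copy
import torch
import torch.nn.functional as F

def test_time_adapt(model, test_sample, n_steps=100, lr=1e-4):
    # Adapt a copy of the trained model to a single unlabeled test sample
    # by minimizing only the reconstruction objective.
    adapted = copy.deepcopy(model)
    optimizer = torch.optim.Adam(adapted.parameters(), lr=lr)
    for _ in range(n_steps):
        reconstruction, segmentation = adapted(test_sample)  # slots decoded back to the input domain
        loss = F.mse_loss(reconstruction, test_sample)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    # Read out the scene decomposition from the adapted model
    _, segmentation = adapted(test_sample)
    return segmentation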
Slot-TTA builds on top of slot-centric models by incorporating segmentation supervision during the training phase. Until now, slot-centric models have been neither designed nor utilized with the foresight of Test-Time Adaptation (TTA).
In particular, Engelcke et al. (2020) showed that TTA via reconstruction in slot-centric models fails due to a reconstruction segmentation trade-off: as the entity bottleneck loosens, there’s an improvement in reconstruction; however, segmentation subsequently deteriorates. We show that segmentation supervision aids in mitigating this trade-off and helps scale to scenes with complicated textures. We show that TTA in semi-supervised slot-centric models significantly improves scene decomposition.
Our contributions are as follows:
(i) We present an algorithm that significantly improves scene decomposition accuracy for out-of-distribution examples by performing test-time adaptation on each example in the test set independently.
(ii) We showcase the effectiveness of SSL-based TTA approaches for scene decomposition, while previous self-supervised test-time adaptation methods have primarily demonstrated results in classification tasks.
(iii) We introduce semi-supervised learning for slot-centric generative models, and show it can enable these methods to continue learning during test time. In contrast, previous works on slot-centric generative models have neither been trained with supervision nor been used for test-time adaptation.
(iv) Lastly, we devise numerous baselines and ablations, and evaluate them across multiple benchmarks and distribution shifts to offer valuable insights into test-time adaptation and object-centric learning.
Results: We evaluate Slot-TTA on the scene understanding tasks of novel view rendering and scene segmentation, across input modalities of multi-view posed images, single-view images, and 3D point clouds, using the PartNet, MultiShapeNet-Hard, and CLEVR datasets.
We compare Slot-TTA’s segmentation performance against state-of-the-art supervised feedforward RGB image and 3D point cloud segmentors (Mask2Former and Mask3D), state-of-the-art novel view rendering methods (Semantic-NeRF) that adapt per scene through RGB and segmentation rendering, and state-of-the-art test-time adaptation methods such as MT3.
We show that Slot-TTA outperforms SOTA feedforward segmenters in out-of-distribution scenes, dramatically outperforms alternative TTA methods and alternative semi-supervised scene decomposition methods, and better exploits multi-view information for improving segmentation over Semantic-NeRF-based multi-view fusion.
Below we show our multi-view RGB results on the MultiShapeNet dataset from Kubric.
We consider various distribution shifts throughout the paper; for the results below, we use the following distribution shift.
We train on Multi-ShapeNet-Easy and test on Multi-ShapeNet-Hard, with no overlap in object instances or in the number of objects per scene between the training and test sets. Specifically, scenes with 5-7 object instances are in the training set, and scenes with 16-30 objects are in the test set.
We consider the following baselines:
(i) Mask2Former (Cheng et al., 2021), a state-of-the-art 2D image segmentor that extends detection transformers (Carion et al., 2020) to the task of image segmentation by using multiscale segmentation decoders with masked attention.
(ii) Mask2Former-BYOL, which combines the segmentation model of Cheng et al. (2021) with test-time adaptation using the BYOL self-supervised loss of MT3 (Bartler et al., 2022).
(iii) Mask2Former-Recon, which combines the segmentation model of Cheng et al. (2021) with an RGB rendering module and an image reconstruction objective for test-time adaptation.
(iv) Semantic-NeRF (Zhi et al., 2021), a NeRF model that adds a segmentation rendering head to the multi-view RGB rendering head of traditional NeRFs. It is fit per scene using all 9 available posed RGB images and the corresponding segmentation maps from Mask2Former as input.
(v) Slot-TTA-w/o supervision, a variant of our model that does not use any segmentation supervision; rather, it is trained only for cross-view image synthesis, similar to OSRT (Sajjadi et al., 2022a).
Our conclusions are as follows:
(i) Slot-TTA with test-time adaptation outperforms Mask2Former in out-of-distribution scenes and has comparable performance within the training distribution.
(ii) Mask2Former-BYOL does not improve over Mask2Former, which suggests that adding self-supervised losses of SOTA image classification TTA methods (Bartler et al., 2022) to scene segmentation methods does not help.
(iii) Slot-TTA-w/o supervision (a model identical to Sajjadi et al. (2022a)) greatly underperforms the supervised segmentor Mask2Former, which suggests that unsupervised slot-centric models are still far from matching their supervised counterparts.
(iv) Slot-TTA-w/o supervision does not improve during test-time adaptation. This suggests segmentation supervision at training time is essential for effective TTA.
(v) Semantic-NeRF, which fuses segmentation masks across views in a geometrically consistent manner, outperforms the single-view segmentation performance of Mask2Former by 3%.
(vi) Slot-TTA which adapts model parameters of the segmentor at test time greatly outperforms Semantic-NeRF in OOD scenes.
(vii) Mask2Former-Recon performs worse with TTA, which suggests that the decoder’s design is very important for aligning the reconstruction and segmentation tasks.
For point clouds, we train the model using certain categories of PartNet and test it using a different set. For quantitative comparisons with the baselines please refer to our paper. As can be seen in the figure below, point cloud segmentation of Slot-TTA improves after optimizing over point cloud reconstruction loss.
For 2D RGB images, we train the model supervised on the CLEVR dataset and test it on CLEVR-Tex. For quantitative comparisons with the baselines please refer to our paper. As can be seen in the figure below, RGB segmentation of Slot-TTA improves after optimizing over RGB reconstruction loss.
Finally, we find that Slot-TTA doesn’t just improve the segmentation performance on out-of-distribution scenes, but also improves the performance on other downstream tasks such as novel view synthesis!
Conclusion: We presented Slot-TTA, a novel semi-supervised scene decomposition model equipped with a slot-centric image or point-cloud rendering component for test-time adaptation. We showed that Slot-TTA greatly improves instance segmentation on out-of-distribution scenes using test-time adaptation on reconstruction or novel view synthesis objectives. We compared with numerous baseline methods, ranging from state-of-the-art feedforward segmentors, to NeRF-based TTA for multi-view semantic fusion, to state-of-the-art TTA methods, to unsupervised or weakly supervised 2D and 3D generative models. We showed that Slot-TTA compares favorably against all of them for scene decomposition of OOD scenes, while still being competitive within distribution.
Paper authors: Mihir Prabhudesai, Anirudh Goyal, Sujoy Paul, Sjoerd van Steenkiste, Mehdi S. M. Sajjadi, Gaurav Aggarwal, Thomas Kipf, Deepak Pathak, Katerina Fragkiadaki.
Code: <https://github.com/mihirp1998/Slot-TTA>
Webpage: <https://slot-tta.github.io/>
Paper: <https://arxiv.org/abs/2203.11194>
Learn how to build and deploy tool-using LLM agents using AWS SageMaker JumpStart Foundation Models
Large language model (LLM) agents are programs that extend the capabilities of standalone LLMs with 1) access to external tools (APIs, functions, webhooks, plugins, and so on), and 2) the ability to plan and execute tasks in a self-directed fashion. Often, LLMs need to interact with other software, databases, or APIs to accomplish complex tasks. For example, an administrative chatbot that schedules meetings would require access to employees’ calendars and email. With access to tools, LLM agents can become more powerful—at the cost of additional complexity.
In this post, we introduce LLM agents and demonstrate how to build and deploy an e-commerce LLM agent using Amazon SageMaker JumpStart and AWS Lambda. The agent will use tools to provide new capabilities, such as answering questions about returns (“Is my return `rtn001` processed?”) and providing updates about orders (“Could you tell me if order `123456` has shipped?”). These new capabilities require LLMs to fetch data from multiple data sources (`orders`, `returns`) and perform retrieval augmented generation (RAG).
To power the LLM agent, we use a `Flan-UL2` model deployed as a SageMaker endpoint and use data retrieval tools built with AWS Lambda. The agent can subsequently be integrated with Amazon Lex and used as a chatbot inside websites or Amazon Connect. We conclude the post with items to consider before deploying LLM agents to production. For a fully managed experience for building LLM agents, AWS also provides the Agents for Amazon Bedrock feature (in preview).
A brief overview of LLM agent architectures
LLM agents are programs that use LLMs to decide when and how to use tools as necessary to complete complex tasks. With tools and task planning abilities, LLM agents can interact with outside systems and overcome traditional limitations of LLMs, such as knowledge cutoffs, hallucinations, and imprecise calculations. Tools can take a variety of forms, such as API calls, Python functions, or webhook-based plugins. For example, an LLM can use a “retrieval plugin” to fetch relevant context and perform RAG.
So what does it mean for an LLM to pick tools and plan tasks? There are numerous approaches (such as ReAct, MRKL, Toolformer, HuggingGPT, and Transformer Agents) to using LLMs with tools, and advancements are happening rapidly. But one simple way is to prompt an LLM with a list of tools and ask it to determine 1) if a tool is needed to satisfy the user query, and if so, 2) select the appropriate tool. Such a prompt typically looks like the following example and may include few-shot examples to improve the LLM’s reliability in picking the right tool.
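As a rough illustration (the tool names, descriptions, and formatting below are assumptions rather than the prompt from any specific framework), such a tool-selection prompt might look like the following:

```
You are an assistant that can use the following tools:

1. OrdersAPI – look up the status of an order, given an order ID.
2. ReturnsAPI – look up the status of a return, given a return ID.

Given the user input below, decide whether a tool is needed. If so, reply
with the name of the single most appropriate tool; otherwise reply "none".

User input: Where is my order 123456?
Tool:
```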
More complex approaches involve using a specialized LLM that can directly decode “API calls” or “tool use,” such as GorillaLLM. Such fine-tuned LLMs are trained on API specification datasets to recognize and predict API calls based on instructions. Often, these LLMs require some metadata about available tools (descriptions, YAML, or JSON schemas for their input parameters) in order to output tool invocations. This approach is taken by Agents for Amazon Bedrock and OpenAI function calling. Note that LLMs generally need to be sufficiently large and complex in order to show tool selection ability.
Assuming task planning and tool selection mechanisms are chosen, a typical LLM agent program works in the following sequence:
- User request – The program takes a user input such as “Where is my order `123456`?” from some client application.
- Plan next action(s) and select tool(s) to use – Next, the program uses a prompt to have the LLM generate the next action, for example, “Look up the orders table using `OrdersAPI`.” The LLM is prompted to suggest a tool name such as `OrdersAPI` from a predefined list of available tools and their descriptions. Alternatively, the LLM could be instructed to directly generate an API call with input parameters such as `OrdersAPI(12345)`.
  - Note that the next action may or may not involve using a tool or API. If not, the LLM would respond to the user input without incorporating additional context from tools, or simply return a canned response such as, “I cannot answer this question.”
- Parse tool request – Next, we need to parse out and validate the tool/action prediction suggested by the LLM. Validation is needed to ensure tool names, APIs, and request parameters aren’t hallucinated and that the tools are properly invoked according to specification. This parsing may require a separate LLM call.
- Invoke tool – Once valid tool name(s) and parameter(s) are ensured, we invoke the tool. This could be an HTTP request, function call, and so on.
- Parse output – The response from the tool may need additional processing. For example, an API call may result in a long JSON response, where only a subset of fields are of interest to the LLM. Extracting information in a clean, standardized format can help the LLM interpret the result more reliably.
- Interpret output – Given the output from the tool, the LLM is prompted again to make sense of it and decide whether it can generate the final answer back to the user or whether additional actions are required.
- Terminate or continue to step 2 – Either return a final answer or a default answer in the case of errors or timeouts.
Different agent frameworks execute the previous program flow differently. For example, ReAct combines tool selection and final answer generation into a single prompt, as opposed to using separate prompts for tool selection and answer generation. Also, this logic can be run in a single pass or in a while statement (the “agent loop”), which terminates when the final answer is generated, an exception is thrown, or a timeout occurs. What remains constant is that agents use the LLM as the centerpiece to orchestrate planning and tool invocations until the task terminates. Next, we show how to implement a simple agent loop using AWS services.
Solution overview
For this blog post, we implement an e-commerce support LLM agent that provides two functionalities powered by tools:
- Return status retrieval tool – Answer questions about the status of returns, such as, “What is happening to my return `rtn001`?”
- Order status retrieval tool – Track the status of orders, such as, “What’s the status of my order `123456`?”
The agent effectively uses the LLM as a query router. Given a query (“What is the status of order `123456`?”), the agent selects the appropriate retrieval tool to query across multiple data sources (that is, returns and orders). We accomplish query routing by having the LLM pick among multiple retrieval tools, which are responsible for interacting with a data source and fetching context. This extends the simple RAG pattern, which assumes a single data source.
Both retrieval tools are Lambda functions that take an ID (`orderId` or `returnId`) as input, fetch a JSON object from the data source, and convert the JSON into a human-friendly string suitable for use by the LLM. In a real-world scenario, the data source could be a highly scalable NoSQL database such as DynamoDB, but this solution employs a simple Python `dict` with sample data for demo purposes.
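As a minimal sketch of what one such retrieval tool could look like (the field names, sample record, and response format are illustrative assumptions), an order-status Lambda function might be written as follows:

```python
# Illustrative in-memory "data source"; a real deployment might use DynamoDB.
ORDERS = {
    "123456": {"item": "Herbal Handsoap", "status": "shipped", "eta": "2023-09-30"},
}

def lambda_handler(event, context):
    """Fetch an order record by ID and convert it into a short,
    human-friendly string that the LLM can consume as context."""
    order_id = event.get("orderId", "")
    order = ORDERS.get(order_id)
    if order is None:
        return {"body": "Order not found. Please check your Order ID."}
    summary = (
        f"Order {order_id}: item '{order['item']}', "
        f"status '{order['status']}', estimated delivery {order['eta']}."
    )
    return {"body": summary}
```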
Additional functionality can be added to the agent by adding retrieval tools and modifying prompts accordingly. The agent can be tested as a standalone service that integrates with any UI over HTTP, which can be done easily with Amazon Lex.
Here are some additional details about the key components:
- LLM inference endpoint – The core of an agent program is an LLM. We will use the SageMaker JumpStart foundation model hub to easily deploy the `Flan-UL2` model. SageMaker JumpStart makes it easy to deploy LLM inference endpoints to dedicated SageMaker instances.
- Agent orchestrator – The agent orchestrator orchestrates the interactions among the LLM, tools, and the client app. For our solution, we use an AWS Lambda function to drive this flow and employ the following helper functions.
  - Task (tool) planner – The task planner uses the LLM to suggest one of 1) returns inquiry, 2) order inquiry, or 3) no tool. We use prompt engineering only and the `Flan-UL2` model as-is, without fine-tuning.
  - Tool parser – The tool parser ensures that the tool suggestion from the task planner is valid. Notably, we ensure that a single `orderId` or `returnId` can be parsed. Otherwise, we respond with a default message.
  - Tool dispatcher – The tool dispatcher invokes tools (Lambda functions) using the valid parameters.
  - Output parser – The output parser cleans and extracts relevant items from JSON into a human-readable string. This task is done both by each retrieval tool and within the orchestrator.
  - Output interpreter – The output interpreter’s responsibility is to 1) interpret the output from the tool invocation and 2) determine whether the user request can be satisfied or additional steps are needed. If the latter, a final response is generated separately and returned to the user.
Now, let’s dive a bit deeper into the key components: agent orchestrator, task planner, and tool dispatcher.
Agent orchestrator
Below is an abbreviated sketch of the agent loop inside the agent orchestrator Lambda function. The loop uses helper functions such as `task_planner` or `tool_parser` to modularize the tasks, and is designed to run at most two times to prevent the LLM from being stuck in a loop for unnecessarily long.
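The helper callables in this sketch correspond to the task planner, tool parser, tool dispatcher, output parser, and output interpreter described above; their exact signatures are assumptions rather than the deployed implementation.

```python
MAX_LOOPS = 2  # cap the loop so the agent cannot get stuck indefinitely

def agent(user_input, task_planner, tool_parser, tool_dispatch,
          output_parser, output_interpreter):
    """Abbreviated agent loop driven by the orchestrator Lambda function."""
    context = ""
    for _ in range(MAX_LOOPS):
        # 1) Ask the LLM which tool (if any) should be used next.
        tool_prediction = task_planner(user_input, context)
        tool_name, tool_param = tool_parser(tool_prediction)
        if tool_name is None:
            # No valid tool applies; fall back to a canned response.
            return "Sorry, I cannot answer that question."
        # 2) Invoke the selected tool and clean up its raw response.
        raw_output = tool_dispatch(tool_name, tool_param)
        context = output_parser(raw_output)
        # 3) Let the LLM decide whether it can now answer the user.
        final_answer, done = output_interpreter(user_input, context)
        if done:
            return final_answer
    # Escape hatch if the loop budget is exhausted.
    return "Sorry, I could not complete the request."
```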
Task planner (tool prediction)
The agent orchestrator uses the `task_planner` helper function to predict a retrieval tool based on the user input. For our LLM agent, we simply use prompt engineering and few-shot prompting to teach the LLM this task in context. More sophisticated agents could use a fine-tuned LLM for tool prediction, which is beyond the scope of this post. The prompt follows a few-shot pattern along the lines sketched below.
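The tool labels (“returns inquiry”, “order inquiry”, “no tool”) match the task planner described above, but the wording and few-shot examples in this sketch are illustrative assumptions, not the exact prompt shipped with the solution:

```
Given a customer question, decide which tool is needed: "returns inquiry",
"order inquiry", or "no tool".

Question: Is my return rtn001 processed?
Tool: returns inquiry

Question: Could you tell me if order 123456 has shipped?
Tool: order inquiry

Question: How is the weather in Scotland right now?
Tool: no tool

Question: {user_input}
Tool:
```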
Tool dispatcher
The tool dispatch mechanism works via `if/else` logic to call the appropriate Lambda function depending on the tool’s name. The `tool_dispatch` helper function implements this logic; it’s used inside the agent loop and returns the raw response from the tool Lambda function, which is then cleaned by an `output_parser` function.
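A minimal sketch consistent with that description follows; the Lambda function names match those created by the stack, while the payload keys and tool labels are assumptions for illustration.

```python
import json

import boto3

lambda_client = boto3.client("lambda")

def tool_dispatch(tool_name, tool_param):
    """Invoke the Lambda function backing the selected tool and return
    its raw response for the output parser to clean up."""
    if tool_name == "order inquiry":
        function_name = "LLMAgentOrdersTool"
        payload = {"orderId": tool_param}
    elif tool_name == "returns inquiry":
        function_name = "LLMAgentReturnsTool"
        payload = {"returnId": tool_param}
    else:
        raise ValueError(f"Unknown tool: {tool_name}")

    response = lambda_client.invoke(
        FunctionName=function_name,
        InvocationType="RequestResponse",
        Payload=json.dumps(payload),
    )
    return json.loads(response["Payload"].read())
```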
Deploy the solution
Important prerequisites – To get started with the deployment, you need to fulfill the following prerequisites:
- Access to the AWS Management Console via a user who can launch AWS CloudFormation stacks
- Familiarity with navigating the AWS Lambda and Amazon Lex consoles
- `Flan-UL2` requires a single `ml.g5.12xlarge` instance for deployment, which may necessitate increasing resource limits via a support ticket. In our example, we use `us-east-1` as the Region, so please make sure to increase the service quota (if needed) in `us-east-1`.
Deploy using CloudFormation – You can deploy the solution to `us-east-1` by launching the provided AWS CloudFormation stack.
Deploying the solution will take about 20 minutes and will create an `LLMAgentStack` stack, which:
- deploys the SageMaker endpoint using the `Flan-UL2` model from SageMaker JumpStart;
- deploys three Lambda functions: `LLMAgentOrchestrator`, `LLMAgentReturnsTool`, `LLMAgentOrdersTool`; and
- deploys an Amazon Lex bot that can be used to test the agent: `Sagemaker-Jumpstart-Flan-LLM-Agent-Fallback-Bot`.
Test the solution
The stack deploys an Amazon Lex bot with the name `Sagemaker-Jumpstart-Flan-LLM-Agent-Fallback-Bot`. The bot can be used to test the agent end to end. There is an additional comprehensive guide for testing Amazon Lex bots with a Lambda integration and how the integration works at a high level. In short, the Amazon Lex bot is a resource that provides a quick UI to chat with the LLM agent running inside the Lambda function that we built (`LLMAgentOrchestrator`).
The sample test cases to consider are as follows:
- Valid order inquiry (for example, “Which item was ordered for `123456`?”)
  - Order `123456` is a valid order, so we should expect a reasonable answer (for example, “Herbal Handsoap”).
- Valid return inquiry (for example, “When is my return `rtn003` processed?”)
  - We should expect a reasonable answer about the return’s status.
- Irrelevant to both returns and orders (for example, “How is the weather in Scotland right now?”)
  - An irrelevant question for both returns and orders, so a default answer should be returned (“Sorry, I cannot answer that question.”).
- Invalid order inquiry (for example, “Which item was ordered for `383833`?”)
  - The ID `383833` does not exist in the orders dataset, so we should fail gracefully (for example, “Order not found. Please check your Order ID.”).
- Invalid return inquiry (for example, “When is my return `rtn123` processed?”)
  - Similarly, the ID `rtn123` does not exist in the returns dataset, and should also fail gracefully.
- Irrelevant return inquiry (for example, “What is the impact of return `rtn001` on world peace?”)
  - This question, while it seems to pertain to a valid return, is irrelevant. The LLM is used to filter questions with irrelevant context.
To run these tests yourself, here are the instructions.
- On the Amazon Lex console (AWS Console > Amazon Lex), navigate to the bot entitled `Sagemaker-Jumpstart-Flan-LLM-Agent-Fallback-Bot`. This bot has already been configured to call the `LLMAgentOrchestrator` Lambda function whenever the `FallbackIntent` is triggered.
- In the navigation pane, choose Intents.
- Choose Build at the top right corner.
- Wait for the build process to complete. When it’s done, you get a success message, as shown in the following screenshot.
- Test the bot by entering the test cases.
Cleanup
To avoid additional charges, delete the resources created by our solution by following these steps:
- On the AWS CloudFormation console, select the stack named `LLMAgentStack` (or the custom name you picked).
- Choose Delete.
- Check that the stack is deleted from the CloudFormation console.

Important: double-check that the stack is successfully deleted by ensuring that the `Flan-UL2` inference endpoint is removed.

- To check, go to AWS console > SageMaker > Endpoints > Inference page.
- The page should list all active endpoints.
- Make sure `sm-jumpstart-flan-bot-endpoint` does not exist, as shown in the screenshot below.
Considerations for production
Deploying LLM agents to production requires taking extra steps to ensure reliability, performance, and maintainability. Here are some considerations prior to deploying agents in production:
- Selecting the LLM model to power the agent loop: For the solution discussed in this post, we used a `Flan-UL2` model without fine-tuning to perform task planning and tool selection. In practice, using an LLM that is fine-tuned to directly output tool or API requests can increase reliability and performance, as well as simplify development. We could fine-tune an LLM on tool selection tasks or use a model that directly decodes tool tokens, like Toolformer.
  - Using fine-tuned models can also simplify adding, removing, and updating the tools available to an agent. With prompt-only approaches, updating tools requires modifying every prompt inside the agent orchestrator, such as those for task planning, tool parsing, and tool dispatch. This can be cumbersome, and performance may degrade if too many tools are provided in context to the LLM.
- Reliability and performance: LLM agents can be unreliable, especially for complex tasks that cannot be completed within a few loops. Adding output validations and retries, structuring outputs from LLMs into JSON or YAML, and enforcing timeouts to provide escape hatches for LLMs stuck in loops can all enhance reliability, as in the sketch below.
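As one example of such an escape hatch, the sketch below asks the LLM for a JSON-structured tool request and retries when the output cannot be parsed; the field names and retry policy are assumptions, not part of the deployed solution.

```python
import json

def call_llm_for_json(prompt, llm_call, max_retries=3):
    """Ask the LLM for a JSON tool request and retry on malformed output.
    llm_call is any function mapping a prompt string to a completion string."""
    for _ in range(max_retries):
        completion = llm_call(prompt)
        try:
            parsed = json.loads(completion)
            if "tool" in parsed and "parameters" in parsed:
                return parsed
        except json.JSONDecodeError:
            pass
        # Nudge the model toward valid output on the next attempt.
        prompt += "\nRespond only with valid JSON containing 'tool' and 'parameters'."
    return None  # caller falls back to a default answer
```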
Conclusion
In this post, we explored how to build an LLM agent that can utilize multiple tools from the ground up, using low-level prompt engineering, AWS Lambda functions, and SageMaker JumpStart as building blocks. We discussed the architecture of LLM agents and the agent loop in detail. The concepts and solution architecture introduced in this blog post may be appropriate for agents that use a small, predefined set of tools. We also discussed several strategies for using agents in production. Agents for Amazon Bedrock, which is in preview, also provides a managed experience for building agents with native support for agentic tool invocations.
About the Author
John Hwang is a Generative AI Architect at AWS with a special focus on large language model (LLM) applications, vector databases, and generative AI product strategy. He is passionate about helping companies with AI/ML product development, and the future of LLM agents and co-pilots. Prior to joining AWS, he was a Product Manager at Alexa, where he helped bring conversational AI to mobile devices, as well as a derivatives trader at Morgan Stanley. He holds a B.S. in computer science from Stanford University.
Making deep learning practical for Earth system forecasting
Novel “cuboid attention” helps transformers handle large-scale multidimensional data, while diffusion models enable probabilistic prediction. Read More
On-device content distillation with graph neural networks
In today’s digital age, smartphones and desktop web browsers serve as the primary tools for accessing news and information. However, the proliferation of website clutter — encompassing complex layouts, navigation elements, and extraneous links — significantly impairs both the reading experience and article navigation. This issue is particularly acute for individuals with accessibility requirements.
To improve the user experience and make reading more accessible, Android and Chrome users may leverage the Reading Mode feature, which enhances accessibility by processing webpages to allow customizable contrast, adjustable text size, more legible fonts, and to enable text-to-speech utilities. Additionally, Android’s Reading Mode is equipped to distill content from apps. Expanding Reading Mode to encompass a wide array of content and improving its performance, while still operating locally on the user’s device without transmitting data externally, poses a unique challenge.
To broaden Reading Mode capabilities without compromising privacy, we have developed a novel on-device content distillation model. Unlike early attempts using DOM Distiller — a heuristic approach limited to news articles — our model excels in both quality and versatility across various types of content. We ensure that article content doesn’t leave the confines of the local environment. Our on-device content distillation model smoothly transforms long-form content into a simple and customizable layout for a more pleasant reading journey while also outperforming the leading alternative approaches. Here we explore details of this research highlighting our approach, methodology, and results.
Graph neural networks
Instead of relying on complicated heuristics that are difficult to maintain and scale to a variety of article layouts, we approach this task as a fully supervised learning problem. This data-driven approach allows the model to generalize better across different layouts, without the constraints and fragility of heuristics. Previous work on optimizing the reading experience relied on parsing, filtering, and modeling of HTML or the document object model (DOM), a programming interface automatically generated by the user’s web browser from site HTML that represents the structure of a document and allows it to be manipulated.
The new Reading Mode model relies on accessibility trees, which provide a streamlined and more accessible representation of the DOM. Accessibility trees are automatically generated from the DOM tree and are utilized by assistive technologies to allow people with disabilities to interact with web content. These are available in the Chrome web browser and on Android through AccessibilityNodeInfo objects, which are provided for both WebView and native application content.
We started by manually collecting and annotating accessibility trees. The Android dataset used for this project comprises on the order of 10k labeled examples, while the Chrome dataset contains approximately 100k labeled examples. We developed a novel tool that uses graph neural networks (GNNs) to distill essential content from the accessibility trees using a multi-class supervised learning approach. The datasets consist of long-form articles sampled from the web and labeled with classes such as headline, paragraph, images, publication date, etc.
GNNs are a natural choice for dealing with tree-like data structures, because unlike traditional models that often demand detailed, hand-crafted features to understand the layout and links within such trees, GNNs learn these connections naturally. To illustrate this, consider the analogy of a family tree. In such a tree, each node represents a family member and the connections denote familial relationships. If one were to predict certain traits using conventional models, features like the “number of immediate family members with a trait” might be needed. However, with GNNs, such manual feature crafting becomes redundant. By directly feeding the tree structure into the model, GNNs utilize a message-passing mechanism where each node communicates with its neighbors. Over time, information gets shared and accumulated across the network, enabling the model to naturally discern intricate relationships.
Returning to the context of accessibility trees, this means that GNNs can efficiently distill content by understanding and leveraging the inherent structure and relationships within the tree. This capability allows them to identify and possibly omit non-essential sections based on the information flow within the tree, ensuring more accurate content distillation.
Our architecture heavily follows the encode-process-decode paradigm using a message-passing neural network to classify text nodes. The overall design is illustrated in the figure below. The tree representation of the article is the input to the model. We compute lightweight features based on bounding box information, text information, and accessibility roles. The GNN then propagates each node’s latent representation through the edges of the tree using a message-passing neural network. This propagation process allows nearby nodes, containers, and text elements to share contextual information with each other, enhancing the model’s understanding of the page’s structure and content. Each node then updates its current state based on the message received, providing a more informed basis for classifying the nodes. After a fixed number of message-passing steps, the now contextualized latent representations of the nodes are decoded into essential or non-essential classes. This approach enables the model to leverage both the inherent relationships in the tree and the hand-crafted features representing each node, thereby enriching the final classification.
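As a toy illustration of this encode-process-decode idea (not the production Reading Mode model), the sketch below propagates node states along tree edges for a fixed number of steps and then classifies each node; the feature sizes and weight matrices are placeholders that would be learned in practice.

```python
import numpy as np

def message_passing_classifier(node_features, edges, W_msg, W_update, W_out, num_steps=3):
    """node_features: (num_nodes, d) array of lightweight per-node features.
    edges: list of (parent, child) index pairs from the accessibility tree.
    Returns a predicted class index per node (e.g., essential vs. non-essential)."""
    h = node_features.astype(float)            # encode: start from hand-crafted node features
    for _ in range(num_steps):                 # process: fixed number of message-passing steps
        messages = np.zeros_like(h)
        for parent, child in edges:
            # Neighbouring nodes exchange information along tree edges (both directions).
            messages[child] += h[parent] @ W_msg
            messages[parent] += h[child] @ W_msg
        h = np.tanh(h @ W_update + messages)   # update node states with received messages
    logits = h @ W_out                         # decode: per-node class scores
    return logits.argmax(axis=1)
```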
We deliberately restrict the feature set used by the model to increase its broad generalization across languages and speed up inference latency on user devices. This was a unique challenge, as we needed to create an on-device lightweight model that could preserve privacy.
Our final lightweight Android model has 64k parameters and is 334kB in size with a median latency of 800ms, while the Chrome model has 241k parameters, is 928kB in size, and has a 378ms median latency. By employing such on-device processing, we ensure that user data never leaves the device, reinforcing our responsible approach and commitment to user privacy. The features used in the model can be grouped into intermediate node features, leaf-node text features, and element position features. We performed feature engineering and feature selection to optimize the set of features for model performance and model size. The final model was transformed into TensorFlow Lite format to deploy as an on-device model on Android or Chrome.
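The conversion step at the end can be sketched as follows; the tiny placeholder network stands in for the trained GNN classifier, and the quantization setting is illustrative rather than the exact configuration used for the production models.

```python
import tensorflow as tf

# Placeholder network standing in for the trained node classifier.
model = tf.keras.Sequential([
    tf.keras.layers.Dense(64, activation="relu", input_shape=(32,)),  # per-node feature vector
    tf.keras.layers.Dense(3, activation="softmax"),                   # e.g., non-essential / headline / main text
])

# Convert to TensorFlow Lite for on-device inference.
converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]  # default optimizations help keep the file small
tflite_model = converter.convert()

with open("reading_mode_distiller.tflite", "wb") as f:
    f.write(tflite_model)
```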
Results
We trained the GNN for about 50 epochs on a single GPU. The performance of the Android model on webpages and native application test sets is presented below:
The table presents the content distillation metrics in Android for webpages and native apps. We report precision, recall, and F1-score for three classes: non-essential content, headline, and main body text, including the macro average and the weighted average by number of instances in each class. Node metrics assess the classification performance at the granularity of the accessibility tree node, which is analogous to a paragraph level. In contrast, word metrics evaluate classification at an individual word level, meaning each word within a node gets the same classification.

| Node metrics | Webpages Precision | Webpages Recall | Webpages F1-score | Native apps Precision | Native apps Recall | Native apps F1-score |
|---|---|---|---|---|---|---|
| non-essential | 0.9842 | 0.9846 | 0.9844 | 0.9744 | 0.9350 | 0.9543 |
| headline | 0.9187 | 0.9784 | 0.9476 | 0.9183 | 0.8568 | 0.8865 |
| main-text | 0.9223 | 0.9172 | 0.9197 | 0.8443 | 0.9424 | 0.8907 |
| macro-average | 0.9417 | 0.9600 | 0.9506 | 0.9124 | 0.9114 | 0.9105 |
| weighted average | 0.9736 | 0.9736 | 0.9736 | 0.9392 | 0.9353 | 0.9363 |
| headline + main-text | 0.9510 | 0.9683 | 0.9595 | 0.9473 | 0.9507 | 0.9490 |
In assessing the results’ quality on commonly visited webpage articles, an F1-score exceeding 0.9 for main-text (essentially paragraphs) corresponds to 88% of these articles being processed without missing any paragraphs. Furthermore, in over 95% of cases, the distillation proves to be valuable for readers. Put simply, the vast majority of readers will perceive the distilled content as both pertinent and precise, with errors or omissions being an infrequent occurrence.
The comparison of Chrome content distillation with other models such as DOM Distiller or Mozilla Readability on a set of English language pages is presented in the table below. We reuse metrics from machine translation to compare the quality of these models, treating the ground-truth main content as the reference text and the text produced by each model as the hypothesis text. The results show the excellent performance of our models in comparison to other DOM-based approaches.
The table presents the comparison between DOM Distiller, Mozilla Readability, and the new Chrome model. We report text-based metrics, such as BLEU, CHRF, and ROUGE, by comparing the main body text distilled from each model to a ground-truth text manually labeled by raters using our annotation policy.

| Metric | DOM Distiller | Mozilla Readability | Our Chrome model |
|---|---|---|---|
| BLEU | 78.97 | 79.16 | 94.59 |
| CHRF | 0.92 | 0.92 | 0.98 |
| ROUGE1 | 84.10 | 84.62 | 95.13 |
| ROUGE2 | 81.84 | 82.66 | 94.81 |
| ROUGE3 | 80.21 | 81.45 | 94.60 |
| ROUGEL | 83.58 | 84.02 | 95.04 |
| ROUGEL-SUM | 83.46 | 84.03 | 95.04 |
The F1-score of the Chrome content distillation model for headline and main text content on the test sets of different widely spoken languages demonstrates that the Chrome model, in particular, is able to support a wide range of languages.
| F1-score | de | en | es | fr | it | fa | ja | ko | pt | vi | zh-Hans | zh-Hant | average |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| headline | 0.91 | 0.97 | 0.99 | 0.98 | 0.97 | 0.89 | 0.97 | 0.98 | 0.99 | 0.98 | 0.97 | 0.93 | 0.96 |
| main text | 0.84 | 0.90 | 0.93 | 0.91 | 0.93 | 0.87 | 0.88 | 0.91 | 0.91 | 0.90 | 0.90 | 0.90 | 0.90 |

The table presents per-language F1-scores of the Chrome model for the headline and main text classes. The language codes correspond to the following languages: German, English, Spanish, French, Italian, Persian, Japanese, Korean, Portuguese, Vietnamese, simplified Chinese, and traditional Chinese.
Conclusion
The digital age demands both streamlined content presentation and an unwavering commitment to user privacy. Our research highlights the effectiveness of Reading Mode in platforms like Android and Chrome, offering an innovative, data-driven approach to content parsing through Graph Neural Networks. Crucially, our lightweight on-device model ensures that content distillation occurs without compromising user data, with all processes executed locally. This not only enhances the reading experience but also reinforces our dedication to user privacy. As we navigate the evolving landscape of digital content consumption, our findings underscore the paramount importance of prioritizing the user in both experience and security.
Acknowledgements
This project is the result of joint work with Manuel Tragut, Mihai Popa, Abodunrinwa Toki, Abhanshu Sharma, Matt Sharifi, David Petrou and Blaise Aguera y Arcas. We sincerely thank our collaborators Gang Li and Yang Li. We are very grateful to Tom Small for assisting us in preparing the post.