AutoBNN: Probabilistic time series forecasting with compositional bayesian neural networks

Time series problems are ubiquitous, from forecasting weather and traffic patterns to understanding economic trends. Bayesian approaches start with an assumption about the data’s patterns (prior probability), collect evidence (e.g., new time series data), and continuously update that assumption to form a posterior probability distribution. Traditional Bayesian approaches like Gaussian processes (GPs) and Structural Time Series are extensively used for modeling time series data, e.g., the commonly used Mauna Loa CO2 dataset. However, they often rely on domain experts to painstakingly select appropriate model components and may be computationally expensive. Alternatives such as neural networks lack interpretability, making it difficult to understand how they generate forecasts, and don’t produce reliable confidence intervals.

To that end, we introduce AutoBNN, a new open-source package written in JAX. AutoBNN automates the discovery of interpretable time series forecasting models, provides high-quality uncertainty estimates, and scales effectively for use on large datasets. We describe how AutoBNN combines the interpretability of traditional probabilistic approaches with the scalability and flexibility of neural networks.

AutoBNN

AutoBNN is based on a line of research that over the past decade has yielded improved predictive accuracy by modeling time series using GPs with learned kernel structures. The kernel function of a GP encodes assumptions about the function being modeled, such as the presence of trends, periodicity or noise. With learned GP kernels, the kernel function is defined compositionally: it is either a base kernel (such as Linear, Quadratic, Periodic, Matérn or ExponentiatedQuadratic) or a composite that combines two or more kernel functions using operators such as Addition, Multiplication, or ChangePoint. This compositional kernel structure serves two related purposes. First, it is simple enough that a user who is an expert about their data, but not necessarily about GPs, can construct a reasonable prior for their time series. Second, techniques like Sequential Monte Carlo can be used for discrete searches over small structures and can output interpretable results.

AutoBNN improves upon these ideas, replacing the GP with Bayesian neural networks (BNNs) while retaining the compositional kernel structure. A BNN is a neural network with a probability distribution over weights rather than a fixed set of weights. This induces a distribution over outputs, capturing uncertainty in the predictions. BNNs bring the following advantages over GPs: First, training large GPs is computationally expensive, and traditional training algorithms scale as the cube of the number of data points in the time series. In contrast, for a fixed width, training a BNN will often be approximately linear in the number of data points. Second, BNNs lend themselves better to GPU and TPU hardware acceleration than GP training operations. Third, compositional BNNs can be easily combined with traditional deep BNNs, which have the ability to do feature discovery. One could imagine “hybrid” architectures, in which users specify a top-level structure of Add(Linear, Periodic, Deep), and the deep BNN is left to learn the contributions from potentially high-dimensional covariate information.
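To make the idea of a distribution over weights concrete, here is a minimal NumPy sketch. It is an illustration only and not AutoBNN's implementation: the standard normal priors, the tanh activation, and the helper name sample_bnn_outputs are all assumptions. Each draw of the weights yields one function, and the spread across draws gives an uncertainty band.

```python
import numpy as np

def sample_bnn_outputs(xs, width=50, num_samples=200, seed=0):
    """Draw functions from a one-hidden-layer network with Gaussian weight priors.

    Hypothetical helper: priors and activation are illustrative assumptions.
    """
    rng = np.random.default_rng(seed)
    outputs = []
    for _ in range(num_samples):
        # One sample of all weights corresponds to one candidate function.
        w1 = rng.normal(size=(1, width))
        b1 = rng.normal(size=width)
        w2 = rng.normal(size=(width, 1)) / np.sqrt(width)
        hidden = np.tanh(xs[:, None] @ w1 + b1)
        outputs.append((hidden @ w2)[:, 0])
    return np.stack(outputs)  # shape: (num_samples, len(xs))

xs = np.linspace(-2.0, 2.0, 25)
draws = sample_bnn_outputs(xs)

# The distribution over weights induces a distribution over outputs.
mean = draws.mean(axis=0)
low, high = np.quantile(draws, [0.025, 0.975], axis=0)
```

Here the draws come from the prior; after training, the same summary would be computed from posterior samples of the weights.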

How might one translate a GP with compositional kernels into a BNN then? A single-layer neural network will typically converge to a GP as the number of neurons (or “width”) goes to infinity. More recently, researchers have discovered a correspondence in the other direction: many popular GP kernels (such as Matern, ExponentiatedQuadratic, Polynomial or Periodic) can be obtained as infinite-width BNNs with appropriately chosen activation functions and weight distributions. Furthermore, these BNNs remain close to the corresponding GP even when the width is far from infinite. For example, the figures below show the difference in the covariance between pairs of observations, and regression results of the true GPs and their corresponding width-10 neural network versions.
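One well-known instance of this correspondence can be checked numerically. The sketch below uses random cosine features, a standard construction that is not necessarily the one AutoBNN uses: hidden units cos(w x + b) with Gaussian w and uniform b have a covariance that approaches the ExponentiatedQuadratic (RBF) kernel as the width grows. The function names and the length scale are illustrative assumptions.

```python
import numpy as np

def random_feature_gram(xs, width, length_scale=1.0, seed=0):
    """Gram matrix implied by a random single-hidden-layer cosine network."""
    rng = np.random.default_rng(seed)
    w = rng.normal(scale=1.0 / length_scale, size=width)
    b = rng.uniform(0.0, 2.0 * np.pi, size=width)
    features = np.sqrt(2.0 / width) * np.cos(np.outer(xs, w) + b)
    return features @ features.T

def rbf_gram(xs, length_scale=1.0):
    """Exact ExponentiatedQuadratic (RBF) Gram matrix."""
    diffs = xs[:, None] - xs[None, :]
    return np.exp(-0.5 * (diffs / length_scale) ** 2)

xs = np.linspace(-1.0, 1.0, 5)
exact = rbf_gram(xs)

# Largest entrywise gap between the network covariance and the true kernel.
err_narrow = np.abs(random_feature_gram(xs, width=10) - exact).max()
err_wide = np.abs(random_feature_gram(xs, width=50000) - exact).max()
print(f"width 10: {err_narrow:.3f}, width 50000: {err_wide:.3f}")
```

Comparing the two printed gaps shows how the approximation tightens with width.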

Comparison of Gram matrices between true GP kernels (top row) and their width 10 neural network approximations (bottom row).
Comparison of regression results between true GP kernels (top row) and their width 10 neural network approximations (bottom row).

Finally, the translation is completed with BNN analogues of the Addition and Multiplication operators over GPs, and input warping to produce periodic kernels. BNN addition is straightforwardly given by adding the outputs of the component BNNs. BNN multiplication is achieved by multiplying the activations of the hidden layers of the BNNs and then applying a shared dense layer. We are therefore limited to only multiplying BNNs with the same hidden width.
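The two operators can be sketched in a few lines of NumPy. This is a simplified illustration of the description above, not the package's code: the tanh activation, the parameter layout, and the names make_params, add_bnns and multiply_bnns are assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
width = 8  # both factors of a product must share this hidden width

def make_params():
    """Hypothetical parameters for a one-hidden-layer network."""
    return {"w": rng.normal(size=width),
            "b": rng.normal(size=width),
            "out": rng.normal(size=width) / np.sqrt(width)}

def hidden_layer(xs, p):
    return np.tanh(np.outer(xs, p["w"]) + p["b"])  # (len(xs), width)

def bnn_output(xs, p):
    return hidden_layer(xs, p) @ p["out"]

def add_bnns(xs, p1, p2):
    # Addition: sum the outputs of the component networks.
    return bnn_output(xs, p1) + bnn_output(xs, p2)

def multiply_bnns(xs, p1, p2, shared_out):
    # Multiplication: multiply hidden activations elementwise,
    # then apply one shared dense output layer.
    return (hidden_layer(xs, p1) * hidden_layer(xs, p2)) @ shared_out

xs = np.linspace(0.0, 1.0, 5)
a, b = make_params(), make_params()
shared_out = rng.normal(size=width) / np.sqrt(width)

summed = add_bnns(xs, a, b)
product = multiply_bnns(xs, a, b, shared_out)
```

The elementwise product in multiply_bnns is what forces the equal-width restriction: activations of different widths could not be multiplied unit by unit.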

Using AutoBNN

The AutoBNN package is available within TensorFlow Probability. It is implemented in JAX and uses the flax.linen neural network library. It implements all of the base kernels and operators discussed so far (Linear, Quadratic, Matern, ExponentiatedQuadratic, Periodic, Addition, Multiplication) plus one new kernel and three new operators:

  • a OneLayer kernel, a single hidden layer ReLU BNN,
  • a ChangePoint operator that allows smoothly switching between two kernels,
  • a LearnableChangePoint operator which is the same as ChangePoint except position and slope are given prior distributions and can be learnt from the data, and
  • a WeightedSum operator.

WeightedSum combines two or more BNNs with learnable mixing weights, where the learnable weights follow a Dirichlet prior. By default, a flat Dirichlet distribution with concentration 1.0 is used.
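As a rough illustration of the mixing step, the sketch below draws weights from a flat Dirichlet and forms the weighted combination. The three component functions are stand-ins for BNN outputs, and the code is an assumption-laden sketch rather than the WeightedSum implementation itself.

```python
import numpy as np

rng = np.random.default_rng(0)
xs = np.linspace(0.0, 1.0, 6)

# Stand-ins for the outputs of three component BNNs (illustrative only).
component_outputs = np.stack([xs, xs ** 2, np.sin(2.0 * np.pi * xs)])

# Flat Dirichlet prior with concentration 1.0: weights are nonnegative
# and sum to one, so the result is a convex combination of components.
weights = rng.dirichlet(alpha=np.ones(3))

mixture = weights @ component_outputs  # shape: (len(xs),)
```

During training these weights would be learned alongside the network weights instead of being sampled once.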

WeightedSums allow a “soft” version of structure discovery, i.e., training a linear combination of many possible models at once. In contrast to structure discovery with discrete structures, such as in AutoGP, this allows us to use standard gradient methods to learn structures, rather than using expensive discrete optimization. Instead of evaluating potential combinatorial structures in series, WeightedSum allows us to evaluate them in parallel.

To easily enable exploration, AutoBNN defines a number of model structures that contain either top-level or internal WeightedSums. The names of these models can be used as the first parameter in any of the estimator constructors, and include things like sum_of_stumps (the WeightedSum over all the base kernels) and sum_of_shallow (which adds all possible combinations of base kernels with all operators).

Illustration of the sum_of_stumps model. The bars in the top row show the amount by which each base kernel contributes, and the bottom row shows the function represented by the base kernel. The resulting weighted sum is shown on the right.

The figure below demonstrates the technique of structure discovery on N374 (a time series of yearly financial data starting from 1949) from the M3 dataset. The six base structures were ExponentiatedQuadratic (which is the same as the Radial Basis Function kernel, or RBF for short), Matern, Linear, Quadratic, OneLayer and Periodic kernels. The figure shows the MAP estimates of their weights over an ensemble of 32 particles. All of the high likelihood particles gave a large weight to the Periodic component, low weights to Linear, Quadratic and OneLayer, and a large weight to either RBF or Matern.

Parallel coordinates plot of the MAP estimates of the base kernel weights over 32 particles. The sum_of_stumps model was trained on the N374 series from the M3 dataset (inset in blue). Darker lines correspond to particles with higher likelihoods.

By using WeightedSums as the inputs to other operators, it is possible to express rich combinatorial structures, while keeping models compact and the number of learnable weights small. As an example, we include the sum_of_products model (illustrated in the figure below) which first creates a pairwise product of two WeightedSums, and then a sum of the two products. By setting some of the weights to zero, we can create many different discrete structures. The total number of possible structures in this model is 2^16 (65,536), since there are 16 base kernels that can be turned on or off. All these structures are explored implicitly by training just this one model.

Illustration of the “sum_of_products” model. Each of the four WeightedSums has the same structure as the “sum_of_stumps” model.

We have found, however, that certain combinations of kernels (e.g., the product of Periodic and either the Matern or ExponentiatedQuadratic) lead to overfitting on many datasets. To prevent this, we have defined model classes like sum_of_safe_shallow that exclude such products when performing structure discovery with WeightedSums.

For training, AutoBNN provides AutoBnnMapEstimator and AutoBnnMCMCEstimator to perform MAP and MCMC inference, respectively. Either estimator can be combined with any of the six likelihood functions, including four based on normal distributions with different noise characteristics for continuous data and two based on the negative binomial distribution for count data.

Result from running AutoBNN on the Mauna Loa CO2 dataset in our example colab. The model captures the trend and seasonal component in the data. Extrapolating into the future, the mean prediction slightly underestimates the actual trend, while the 95% confidence interval gradually increases.

To fit a model like the one in the figure above, all it takes is a few lines of code, using the scikit-learn–inspired estimator interface:

import jax
import autobnn as ab

# Top-level structure: a sum of periodic, linear and Matern components.
model = ab.operators.Add(
    bnns=(ab.kernels.PeriodicBNN(width=50),
          ab.kernels.LinearBNN(width=50),
          ab.kernels.MaternBNN(width=50)))

# MAP estimator with a normal likelihood; periods=[12] sets the seasonal period.
estimator = ab.estimators.AutoBnnMapEstimator(
    model, 'normal_likelihood_logistic_noise', jax.random.PRNGKey(42),
    periods=[12])

# Fit, then request lower, middle and upper predictive quantiles.
estimator.fit(my_training_data_xs, my_training_data_ys)
low, mid, high = estimator.predict_quantiles(my_training_data_xs)

Conclusion

AutoBNN provides a powerful and flexible framework for building sophisticated time series prediction models. By combining the strengths of BNNs and GPs with compositional kernels, AutoBNN opens a world of possibilities for understanding and forecasting complex data. We invite the community to try the colab, and leverage this library to innovate and solve real-world challenges.

Acknowledgements

AutoBNN was written by Colin Carroll, Thomas Colthurst, Urs Köster and Srinivas Vasudevan. We would like to thank Kevin Murphy, Brian Patton and Feras Saad for their advice and feedback.

Computer-aided diagnosis for lung cancer screening

Lung cancer is the leading cause of cancer-related deaths globally with 1.8 million deaths reported in 2020. Late diagnosis dramatically reduces the chances of survival. Lung cancer screening via computed tomography (CT), which provides a detailed 3D image of the lungs, has been shown to reduce mortality in high-risk populations by at least 20% by detecting potential signs of cancers earlier. In the US, screening involves annual scans, with some countries or cases recommending more or less frequent scans.

The United States Preventive Services Task Force recently expanded lung cancer screening recommendations by roughly 80%, which is expected to increase screening access for women and racial and ethnic minority groups. However, false positives (i.e., incorrectly reporting a potential cancer in a cancer-free patient) can cause anxiety and lead to unnecessary procedures for patients while increasing costs for the healthcare system. Moreover, efficiency in screening a large number of individuals can be challenging depending on healthcare infrastructure and radiologist availability.

At Google we have previously developed machine learning (ML) models for lung cancer detection, and have evaluated their ability to automatically detect and classify regions that show signs of potential cancer. Performance has been shown to be comparable to that of specialists in detecting possible cancer. While they have achieved high performance, effectively communicating findings in realistic environments is necessary to realize their full potential.

To that end, in “Assistive AI in Lung Cancer Screening: A Retrospective Multinational Study in the US and Japan”, published in Radiology AI, we investigate how ML models can effectively communicate findings to radiologists. We also introduce a generalizable user-centric interface to help radiologists leverage such models for lung cancer screening. The system takes CT imaging as input and outputs a cancer suspicion rating using four categories (no suspicion, probably benign, suspicious, highly suspicious) along with the corresponding regions of interest. We evaluate the system’s utility in improving clinician performance through randomized reader studies in both the US and Japan, using the local cancer scoring systems (Lung-RADS V1.1 and Sendai Score) and image viewers that mimic realistic settings. We found that reader specificity increases with model assistance in both reader studies. To accelerate progress in conducting similar studies with ML models, we have open-sourced code to process CT images and generate images compatible with the picture archiving and communication system (PACS) used by radiologists.

Developing an interface to communicate model results

Integrating ML models into radiologist workflows involves understanding the nuances and goals of their tasks to meaningfully support them. In the case of lung cancer screening, hospitals follow various country-specific guidelines that are regularly updated. For example, in the US, Lung-RADS V1.1 assigns an alpha-numeric score to indicate the lung cancer risk and follow-up recommendations. When assessing patients, radiologists load the CT in their workstation to read the case, find lung nodules or lesions, and apply set guidelines to determine follow-up decisions.

Our first step was to improve the previously developed ML models through additional training data and architectural improvements, including self-attention. Then, instead of targeting specific guidelines, we experimented with a complementary way of communicating AI results independent of guidelines or their particular versions. Specifically, the system output offers a suspicion rating and localization (regions of interest) for the user to consider in conjunction with their own specific guidelines. The interface produces output images directly associated with the CT study, requiring no changes to the user’s workstation. The radiologist only needs to review a small set of additional images. There is no other change to their system or interaction with the system.

Example of the assistive lung cancer screening system outputs. Results for the radiologist’s evaluation are visualized on the location of the CT volume where the suspicious lesion is found. The overall suspicion is displayed at the top of the CT images. Circles highlight the suspicious lesions while squares show a rendering of the same lesion from a different perspective, called a sagittal view.

The assistive lung cancer screening system comprises 13 models and has a high-level architecture similar to the end-to-end system used in prior work. The models coordinate with each other to first segment the lungs, obtain an overall assessment, locate three suspicious regions, then use the information to assign a suspicion rating to each region. The system was deployed on Google Cloud using Google Kubernetes Engine (GKE), which pulled the images, ran the ML models, and provided results. This allows scalability and directly connects to servers where the images are stored in DICOM stores.

Outline of the Google Cloud deployment of the assistive lung cancer screening system and the directional calling flow for the individual components that serve the images and compute results. Images are served to the viewer and to the system using Google Cloud services. The system is run on Google Kubernetes Engine, which pulls the images, processes them, and writes them back into the DICOM store.

Reader studies

To evaluate the system’s utility in improving clinical performance, we conducted two reader studies (i.e., experiments designed to assess clinical performance comparing expert performance with and without the aid of a technology) with 12 radiologists using pre-existing, de-identified CT scans. We presented 627 challenging cases to 6 US-based and 6 Japan-based radiologists. In the experimental setup, readers were divided into two groups that read each case twice, with and without assistance from the model. Readers were asked to apply scoring guidelines they typically use in their clinical practice and report their overall suspicion of cancer for each case. We then compared the readers’ responses to measure the impact of the model on their workflow and decisions. The score and suspicion level were judged against the actual cancer outcomes of the individuals to measure sensitivity, specificity, and area under the ROC curve (AUC) values. These were compared with and without assistance.
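For readers less familiar with these metrics, the following sketch shows how sensitivity and specificity are computed from reader decisions and true outcomes. The six cases and both sets of reader decisions are invented for illustration and are not data from the study.

```python
import numpy as np

def sensitivity_specificity(has_cancer, flagged):
    """Return (sensitivity, specificity) for binary outcomes and decisions."""
    has_cancer = np.asarray(has_cancer, dtype=bool)
    flagged = np.asarray(flagged, dtype=bool)
    true_pos = np.sum(has_cancer & flagged)
    false_neg = np.sum(has_cancer & ~flagged)
    true_neg = np.sum(~has_cancer & ~flagged)
    false_pos = np.sum(~has_cancer & flagged)
    sensitivity = true_pos / (true_pos + false_neg)
    specificity = true_neg / (true_neg + false_pos)
    return sensitivity, specificity

# Made-up example: 2 cancer-positive and 4 cancer-negative cases.
has_cancer = [1, 1, 0, 0, 0, 0]
unassisted = [1, 0, 1, 1, 0, 0]  # reader flags two negatives
assisted = [1, 0, 1, 0, 0, 0]    # reader flags one negative

sens_u, spec_u = sensitivity_specificity(has_cancer, unassisted)
sens_a, spec_a = sensitivity_specificity(has_cancer, assisted)
```

In this toy example sensitivity is unchanged while specificity rises, because one fewer cancer-free case is flagged for follow-up.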

A multi-case multi-reader study involves each case being reviewed by each reader twice, once with ML system assistance and once without. In this visualization one reader first reviews Set A without assistance (blue) and then with assistance (orange) after a wash-out period. A second reader group follows the opposite path by reading the same set of cases Set A with assistance first. Readers are randomized to these groups to remove the effect of ordering.

The ability to conduct these studies using the same interface highlights its generalizability to completely different cancer scoring systems, and the generalization of the model and assistive capability to different patient populations. Our study results demonstrated that when radiologists used the system in their clinical evaluation, they had an increased ability to correctly identify lung images without actionable lung cancer findings (i.e., specificity) by an absolute 5–7% compared to when they didn’t use the assistive system. This potentially means that for every 15–20 patients screened, one may be able to avoid unnecessary follow-up procedures, thus reducing their anxiety and the burden on the health care system. This can, in turn, help improve the sustainability of lung cancer screening programs, particularly as more people become eligible for screening.
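The step from a 5–7% absolute specificity gain to roughly one avoided follow-up per 15–20 patients is a reciprocal. The sketch below spells out that arithmetic under the simplifying assumption that the gain applies to cancer-free patients, who make up the large majority of a screening population; the function name is hypothetical.

```python
def negatives_per_avoided_false_positive(specificity_gain):
    """Cancer-free patients screened per avoided false positive.

    Assumes an absolute specificity gain expressed as a fraction (0.05 = 5%).
    """
    return 1.0 / specificity_gain

for gain in (0.05, 0.07):
    per_patient = negatives_per_avoided_false_positive(gain)
    print(f"gain {gain:.0%}: about 1 in {per_patient:.1f} cancer-free patients")
```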

Reader specificity increases with ML model assistance in both the US-based and Japan-based reader studies. Specificity values were derived from reader scores from actionable findings (something suspicious was found) versus no actionable findings, compared against the true cancer outcome of the individual. Under model assistance, readers flagged fewer cancer-negative individuals for follow-up visits. Sensitivity for cancer positive individuals remained the same.

Translating this into real-world impact through partnership

The system results demonstrate the potential for fewer follow-up visits, reduced anxiety, as well as lower overall costs for lung cancer screening. In an effort to translate this research into real-world clinical impact, we are working with DeepHealth, a leading AI-powered health informatics provider, and Apollo Radiology International, a leading provider of radiology services in India, to explore paths for incorporating this system into future products. In addition, we are looking to help other researchers studying how best to integrate ML model results into clinical workflows by open sourcing code used for the reader study and incorporating the insights described in this blog. We hope that this will help accelerate medical imaging researchers looking to conduct reader studies for their AI models, and catalyze translational research in the field.

Acknowledgements

Key contributors to this project include Corbin Cunningham, Zaid Nabulsi, Ryan Najafi, Jie Yang, Charles Lau, Joseph R. Ledsam, Wenxing Ye, Diego Ardila, Scott M. McKinney, Rory Pilgrim, Hiroaki Saito, Yasuteru Shimamura, Mozziyar Etemadi, Yun Liu, David Melnick, Sunny Jansen, Nadia Harhen, David P. Nadich, Mikhail Fomitchev, Ziyad Helali, Shabir Adeel, Greg S. Corrado, Lily Peng, Daniel Tse, Shravya Shetty, Shruthi Prabhakara, Neeral Beladia, and Krish Eswaran. Thanks to Arnav Agharwal and Andrew Sellergren for their open sourcing support and Vivek Natarajan and Michael D. Howell for their feedback. Sincere appreciation also goes to the radiologists who enabled this work with their image interpretation and annotation efforts throughout the study, and Jonny Wong and Carli Sampson for coordinating the reader studies.

Using AI to expand global access to reliable flood forecasts

Floods are the most common natural disaster, and are responsible for roughly $50 billion in annual financial damages worldwide. The rate of flood-related disasters has more than doubled since the year 2000 partly due to climate change. Nearly 1.5 billion people, making up 19% of the world’s population, are exposed to substantial risks from severe flood events. Upgrading early warning systems to make accurate and timely information accessible to these populations can save thousands of lives per year.

Driven by the potential impact of reliable flood forecasting on people’s lives globally, we started our flood forecasting effort in 2017. Over this multi-year journey, we have advanced our research hand-in-hand with building a real-time operational flood forecasting system that provides alerts on Google Search, Maps, Android notifications and through the Flood Hub. However, in order to scale globally, especially in places where accurate local data is not available, more research advances were required.

In “Global prediction of extreme floods in ungauged watersheds”, published in Nature, we demonstrate how machine learning (ML) technologies can significantly improve global-scale flood forecasting relative to the current state-of-the-art for countries where flood-related data is scarce. With these AI-based technologies we extended the reliability of currently-available global nowcasts, on average, from zero to five days, and improved forecasts across regions in Africa and Asia to be similar to those currently available in Europe. The evaluation of the models was conducted in collaboration with the European Centre for Medium-Range Weather Forecasts (ECMWF).

These technologies also enable Flood Hub to provide real-time river forecasts up to seven days in advance, covering river reaches across over 80 countries. This information can be used by people, communities, governments and international organizations to take anticipatory action to help protect vulnerable populations.

Flood forecasting at Google

The ML models that power the Flood Hub tool are the product of many years of research, conducted in collaboration with several partners, including academics, governments, international organizations, and NGOs.

In 2018, we launched a pilot early warning system in the Ganges-Brahmaputra river basin in India, with the hypothesis that ML could help address the challenging problem of reliable flood forecasting at scale. The pilot was further expanded the following year via the combination of an inundation model, real-time water level measurements, the creation of an elevation map and hydrologic modeling.

In collaboration with academics, in particular the JKU Institute for Machine Learning, we explored ML-based hydrologic models, showing that LSTM-based models could produce more accurate simulations than traditional conceptual and physics-based hydrology models. This research led to flood forecasting improvements that enabled the expansion of our forecasting coverage to include all of India and Bangladesh. We also worked with researchers at Yale University to test technological interventions that increase the reach and impact of flood warnings.

Our hydrological models predict river floods by processing publicly available weather data like precipitation and physical watershed information. Such models must be calibrated to long data records from streamflow gauging stations in individual rivers. A low percentage of global river watersheds (basins) have streamflow gauges, which are expensive but necessary to supply relevant data, and it’s challenging for hydrological simulation and forecasting to provide predictions in basins that lack this infrastructure. Lower gross domestic product (GDP) is correlated with increased vulnerability to flood risks, and there is an inverse correlation between national GDP and the amount of publicly available data in a country. ML helps to address this problem by allowing a single model to be trained on all available river data and to be applied to ungauged basins where no data are available. In this way, models can be trained globally, and can make predictions for any river location.

There is an inverse (log-log) correlation between the amount of publicly available streamflow data in a country and national GDP. Streamflow data from the Global Runoff Data Center.

Our academic collaborations led to ML research that developed methods to estimate uncertainty in river forecasts and showed how ML river forecast models synthesize information from multiple data sources. They demonstrated that these models can simulate extreme events reliably, even when those events are not part of the training data. In an effort to contribute to open science, in 2023 we open-sourced a community-driven dataset for large-sample hydrology in Nature Scientific Data.

The river forecast model

Most hydrology models used by national and international agencies for flood forecasting and river modeling are state-space models, which depend only on daily inputs (e.g., precipitation, temperature, etc.) and the current state of the system (e.g., soil moisture, snowpack, etc.). LSTMs are a variant of state-space models and work by defining a neural network that represents a single time step, where input data (such as current weather conditions) are processed to produce updated state information and output values (streamflow) for that time step. LSTMs are applied sequentially to make time-series predictions, and in this sense, behave similarly to how scientists typically conceptualize hydrologic systems. Empirically, we have found that LSTMs perform well on the task of river forecasting.

A diagram of the LSTM, which is a neural network that operates sequentially in time. An accessible primer can be found here.

Our river forecast model uses two LSTMs applied sequentially: (1) a “hindcast” LSTM ingests historical weather data (dynamic hindcast features) up to the present time (or rather, the issue time of a forecast), and (2) a “forecast” LSTM ingests states from the hindcast LSTM along with forecasted weather data (dynamic forecast features) to make future predictions. One year of historical weather data are input into the hindcast LSTM, and seven days of forecasted weather data are input into the forecast LSTM. Static features include geographical and geophysical characteristics of watersheds that are input into both the hindcast and forecast LSTMs and allow the model to learn different hydrological behaviors and responses in various types of watersheds.
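The hindcast-to-forecast handoff can be sketched with a small NumPy LSTM. This is a structural illustration with random, untrained weights, not the production model: the hidden size, feature counts, and the direct copy of hidden and cell state between the two LSTMs are assumptions made for this example.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def init_lstm(input_size, hidden_size, rng):
    """Random (untrained) parameters for one LSTM; illustrative only."""
    return {"W": rng.normal(scale=0.1,
                            size=(input_size + hidden_size, 4 * hidden_size)),
            "b": np.zeros(4 * hidden_size)}

def lstm_step(params, x, h, c):
    """One time step: inputs plus current state give the updated state."""
    gates = np.concatenate([x, h]) @ params["W"] + params["b"]
    i, f, g, o = np.split(gates, 4)
    c = sigmoid(f) * c + sigmoid(i) * np.tanh(g)
    h = sigmoid(o) * np.tanh(c)
    return h, c

def run_lstm(params, inputs, h, c):
    states = []
    for x in inputs:
        h, c = lstm_step(params, x, h, c)
        states.append(h)
    return np.stack(states), h, c

rng = np.random.default_rng(0)
hidden_size, n_dynamic, n_static = 16, 3, 2
static = rng.normal(size=n_static)                   # watershed attributes
hindcast_weather = rng.normal(size=(365, n_dynamic)) # one year of history
forecast_weather = rng.normal(size=(7, n_dynamic))   # seven forecast days

def with_static(dynamic):
    # Static features are appended to the inputs at every time step.
    return np.hstack([dynamic, np.tile(static, (len(dynamic), 1))])

hindcast = init_lstm(n_dynamic + n_static, hidden_size, rng)
forecast = init_lstm(n_dynamic + n_static, hidden_size, rng)

zeros = np.zeros(hidden_size)
_, h, c = run_lstm(hindcast, with_static(hindcast_weather), zeros, zeros)
# The forecast LSTM starts from the state reached by the hindcast LSTM.
forecast_states, _, _ = run_lstm(forecast, with_static(forecast_weather), h, c)
```

The forecast_states array has one hidden vector per forecast day, which is what a downstream output layer would consume.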

Output from the forecast LSTM is fed into a “head” layer that uses mixture density networks to produce a probabilistic forecast (i.e., predicted parameters of a probability distribution over streamflow). Specifically, the model predicts the parameters of a mixture of heavy-tailed probability density functions, called asymmetric Laplacian distributions, at each forecast time step. The result is a mixture density function, called a Countable Mixture of Asymmetric Laplacians (CMAL) distribution, which represents a probabilistic prediction of the volumetric flow rate in a particular river at a particular time.
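To show what such a mixture looks like, here is a sketch of a three-component mixture of asymmetric Laplace densities. It assumes the common quantile-style parameterization (location, scale, asymmetry tau in (0, 1)), which may differ in detail from the one used in the model, and all parameter values are made up.

```python
import numpy as np

def asymmetric_laplace_pdf(y, loc, scale, tau):
    """Asymmetric Laplace density; smaller tau gives a heavier upper tail."""
    u = (y - loc) / scale
    check = u * (tau - (u < 0))  # "check" loss: tau*u if u >= 0, else (tau-1)*u
    return tau * (1.0 - tau) / scale * np.exp(-check)

def mixture_pdf(y, weights, locs, scales, taus):
    """Weighted sum of component densities evaluated at each y."""
    y = np.asarray(y, dtype=float)[..., None]
    return np.sum(weights * asymmetric_laplace_pdf(y, locs, scales, taus), axis=-1)

# Hypothetical head outputs for one forecast time step.
weights = np.array([0.6, 0.3, 0.1])   # mixture weights, sum to one
locs = np.array([10.0, 25.0, 60.0])   # component locations
scales = np.array([2.0, 5.0, 15.0])
taus = np.array([0.5, 0.3, 0.2])

grid = np.linspace(-500.0, 2000.0, 500001)
density = mixture_pdf(grid, weights, locs, scales, taus)

# Sanity check: a valid density integrates to approximately one.
total_mass = np.sum(density) * (grid[1] - grid[0])
```

In a mixture density network, the four parameter arrays would be produced by the head layer at every forecast step instead of being fixed constants.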

LSTM-based river forecast model architecture. Two LSTMs are applied in sequence, one ingesting historical weather data and one ingesting forecasted weather data. The model outputs are the parameters of a probability distribution over streamflow at each forecasted timestep.

Input and training data

The model uses three types of publicly available data inputs, mostly from governmental sources:

  1. Static watershed attributes representing geographical and geophysical variables: From the HydroATLAS project, including data like long-term climate indexes (precipitation, temperature, snow fractions), land cover, and anthropogenic attributes (e.g., a nighttime lights index as a proxy for human development).
  2. Historical meteorological time-series data: Used to spin up the model for one year prior to the issue time of a forecast. The data comes from NASA IMERG, NOAA CPC Global Unified Gauge-Based Analysis of Daily Precipitation, and the ECMWF ERA5-land reanalysis. Variables include daily total precipitation, air temperature, solar and thermal radiation, snowfall, and surface pressure.
  3. Forecasted meteorological time series over a seven-day forecast horizon: Used as input for the forecast LSTM. These data are the same meteorological variables listed above, and come from the ECMWF HRES atmospheric model.

Training data are daily streamflow values from the Global Runoff Data Center over the time period 1980 – 2023. A single streamflow forecast model is trained using data from 5,680 diverse watershed streamflow gauges (shown below) to improve accuracy.

Location of 5,680 streamflow gauges that supply training data for the river forecast model from the Global Runoff Data Center.

Improving on the current state-of-the-art

We compared our river forecast model with GloFAS version 4, the current state-of-the-art global flood forecasting system. These experiments showed that ML can provide accurate warnings earlier and over larger and more impactful events.

The figure below shows the distribution of F1 scores when predicting different severity events at river locations around the world, with timing accurate to within plus or minus one day. The F1 score is the harmonic mean of precision and recall, and event severity is measured by return period. For example, a 2-year return period event is a volume of streamflow that is expected to be exceeded on average once every two years. Our model achieves reliability scores at up to 4-day or 5-day lead times that are similar to or better, on average, than the reliability of GloFAS nowcasts (0-day lead time).
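As a quick reference, the sketch below computes an F1 score from event counts. The counts are invented for illustration, and the helper name is hypothetical.

```python
def f1_score(true_positives, false_positives, false_negatives):
    """F1 = harmonic mean of precision and recall."""
    if true_positives == 0:
        return 0.0
    precision = true_positives / (true_positives + false_positives)
    recall = true_positives / (true_positives + false_negatives)
    return 2.0 * precision * recall / (precision + recall)

# Made-up example: 6 events correctly predicted, 2 false alarms, 4 missed.
score = f1_score(true_positives=6, false_positives=2, false_negatives=4)
print(round(score, 3))  # precision 0.75, recall 0.6
```

The harmonic mean is pulled toward the smaller of the two values, so a forecast cannot earn a high F1 score through high precision or high recall alone.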

Distributions of F1 scores over 2-year return period events in 2,092 watersheds globally during the time period 2014-2023 from GloFAS (blue) and our model (orange) at different lead times. On average, our model is statistically as accurate as GloFAS nowcasts (0-day lead time) up to 5 days in advance over 2-year (shown) and 1-year, 5-year, and 10-year events (not shown).

Additionally (not shown), our model achieves accuracies over larger and rarer extreme events, with precision and recall scores over 5-year return period events that are similar to or better than GloFAS accuracies over 1-year return period events. See the paper for more information.

Looking into the future

The flood forecasting initiative is part of our Adaptation and Resilience efforts and reflects Google’s commitment to address climate change while helping global communities become more resilient. We believe that AI and ML will continue to play a critical role in helping advance science and research towards climate action.

We actively collaborate with several international aid organizations (e.g., the Centre for Humanitarian Data and the Red Cross) to provide actionable flood forecasts. Additionally, in an ongoing collaboration with the World Meteorological Organization (WMO) to support early warning systems for climate hazards, we are conducting a study to help understand how AI can help address real-world challenges faced by national flood forecasting agencies.

While the work presented here demonstrates a significant step forward in flood forecasting, future work is needed to further expand flood forecasting coverage to more locations globally and other types of flood-related events and disasters, including flash floods and urban floods. We are looking forward to continuing collaborations with our partners in the academic and expert communities, local governments and the industry to reach these goals.


ScreenAI: A visual language model for UI and visually-situated language understanding

Screen user interfaces (UIs) and infographics, such as charts, diagrams and tables, play important roles in human communication and human-machine interaction as they facilitate rich and interactive user experiences. UIs and infographics share similar design principles and visual language (e.g., icons and layouts), that offer an opportunity to build a single model that can understand, reason, and interact with these interfaces. However, because of their complexity and varied presentation formats, infographics and UIs present a unique modeling challenge.

To that end, we introduce “ScreenAI: A Vision-Language Model for UI and Infographics Understanding”. ScreenAI improves upon the PaLI architecture with the flexible patching strategy from pix2struct. We train ScreenAI on a unique mixture of datasets and tasks, including a novel Screen Annotation task that requires the model to identify UI element information (i.e., type, location and description) on a screen. These text annotations provide large language models (LLMs) with screen descriptions, enabling them to automatically generate question-answering (QA), UI navigation, and summarization training datasets at scale. At only 5B parameters, ScreenAI achieves state-of-the-art results on UI- and infographic-based tasks (WebSRC and MoTIF), and best-in-class performance on ChartQA, DocVQA, and InfographicVQA compared to models of similar size. We are also releasing three new datasets: Screen Annotation to evaluate the layout understanding capability of the model, as well as ScreenQA Short and Complex ScreenQA for a more comprehensive evaluation of its QA capability.

ScreenAI

ScreenAI’s architecture is based on PaLI, composed of a multimodal encoder block and an autoregressive decoder. The PaLI encoder uses a vision transformer (ViT) that creates image embeddings and a multimodal encoder that takes the concatenation of the image and text embeddings as input. This flexible architecture allows ScreenAI to solve vision tasks that can be recast as text+image-to-text problems.

On top of the PaLI architecture, we employ a flexible patching strategy introduced in pix2struct. Instead of using a fixed-grid pattern, the grid dimensions are selected such that they preserve the native aspect ratio of the input image. This enables ScreenAI to work well across images of various aspect ratios.
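As a rough illustration of aspect-ratio-preserving patching, the following Python sketch picks a patch grid whose shape follows the image shape while staying within a patch budget. It is modeled on the general idea described for pix2struct; the function name and the `patch_size` and `max_patches` values are assumptions for illustration, not ScreenAI's actual settings.

```python
import math


def aspect_preserving_grid(image_height, image_width, patch_size=16, max_patches=1024):
    """Return (rows, cols) of a patch grid that roughly keeps the image's aspect ratio.

    The image would be resized to (rows * patch_size, cols * patch_size) before
    being cut into patches, so rows * cols never exceeds max_patches.
    """
    # Scale factor that makes the resized image fill the patch budget.
    scale = math.sqrt(
        max_patches * (patch_size / image_height) * (patch_size / image_width)
    )
    rows = max(min(math.floor(scale * image_height / patch_size), max_patches), 1)
    cols = max(min(math.floor(scale * image_width / patch_size), max_patches), 1)
    return rows, cols


# A tall, phone-like screenshot gets more rows than columns,
# and a wide, desktop-like screenshot gets the opposite.
tall_grid = aspect_preserving_grid(2000, 1000)
wide_grid = aspect_preserving_grid(1000, 2000)
```

With a fixed square grid, both screenshots would be squeezed into the same shape. Here the grid follows the image, which is why text and icons are less distorted on very tall or very wide screens.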

The ScreenAI model is trained in two stages: a pre-training stage followed by a fine-tuning stage. First, self-supervised learning is applied to automatically generate data labels, which are then used to train ViT and the language model. ViT is frozen during the fine-tuning stage, where most data used is manually labeled by human raters.

ScreenAI model architecture.

Data generation

To create a pre-training dataset for ScreenAI, we first compile an extensive collection of screenshots from various devices, including desktops, mobile, and tablets. This is achieved by using publicly accessible web pages and following the programmatic exploration approach used for the RICO dataset for mobile apps. We then apply a layout annotator, based on the DETR model, that identifies and labels a wide range of UI elements (e.g., image, pictogram, button, text) and their spatial relationships. Pictograms undergo further analysis using an icon classifier capable of distinguishing 77 different icon types. This detailed classification is essential for interpreting the subtle information conveyed through icons. For icons that are not covered by the classifier, and for infographics and images, we use the PaLI image captioning model to generate descriptive captions that provide contextual information. We also apply an optical character recognition (OCR) engine to extract and annotate textual content on screen. We combine the OCR text with the previous annotations to create a detailed description of each screen.

A mobile app screenshot with generated annotations that include UI elements and their descriptions, e.g., TEXT elements also contain the text content from OCR, IMAGE elements contain image captions, LIST_ITEMs contain all their child elements.
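One way to picture such an annotation is as a small tree of UI elements that is flattened into text. The Python sketch below is purely hypothetical: the `UIElement` class, its field names, the bounding-box convention and the text layout are assumptions made for illustration, and the actual ScreenAI schema may differ.

```python
from dataclasses import dataclass, field
from typing import List, Optional, Tuple


@dataclass
class UIElement:
    kind: str                          # e.g. "TEXT", "IMAGE", "PICTOGRAM", "LIST_ITEM"
    bbox: Tuple[int, int, int, int]    # assumed (x_min, y_min, x_max, y_max) in pixels
    content: Optional[str] = None      # OCR text, image caption, or icon class
    children: List["UIElement"] = field(default_factory=list)


def to_schema(element, indent=0):
    """Flatten an element tree into an indented, one-element-per-line description."""
    x_min, y_min, x_max, y_max = element.bbox
    line = f"{'  ' * indent}{element.kind} {x_min} {y_min} {x_max} {y_max}"
    if element.content:
        line += f" {element.content}"
    lines = [line]
    for child in element.children:
        lines.append(to_schema(child, indent + 1))
    return "\n".join(lines)


# A LIST_ITEM that contains a captioned image and an OCR'd text element.
list_item = UIElement(
    "LIST_ITEM",
    (0, 100, 1080, 300),
    children=[
        UIElement("IMAGE", (20, 120, 200, 280), "a plate of pasta"),
        UIElement("TEXT", (220, 140, 1000, 200), "Opens at 9 AM"),
    ],
)
screen_schema = to_schema(list_item)
```

A plain-text description like this is what makes the annotations usable by a text-only LLM in the data generation step.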

LLM-based data generation

We enhance the pre-training data’s diversity using PaLM 2 to generate input-output pairs in a two-step process. First, screen annotations are generated using the technique outlined above. Then, we craft a prompt around this schema for the LLM to create synthetic data. This process requires prompt engineering and iterative refinement to find an effective prompt. We assess the generated data’s quality through human validation against a quality threshold.

You only speak JSON. Do not write text that isn’t JSON.
You are given the following mobile screenshot, described in words. Can you generate 5 questions regarding the content of the screenshot as well as the corresponding short answers to them? 

The answer should be as short as possible, containing only the necessary information. Your answer should be structured as follows:
questions: [
{{question: the question,
    answer: the answer
}},
 ...
]

{THE SCREEN SCHEMA}

A sample prompt for QA data generation.
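The doubled braces in the sample prompt are consistent with a format-string template, although the post does not say how the prompt is filled in. The Python sketch below assumes exactly that: it fills a template with a screen schema and parses a JSON reply into question-answer pairs. The placeholder names `screen_schema` and `num_questions`, both helper functions, and the assumption that the LLM returns strictly valid JSON are illustrative, not part of the original workflow.

```python
import json

# Doubled braces are literal braces under str.format; single-brace names are placeholders.
PROMPT_TEMPLATE = """You only speak JSON. Do not write text that isn’t JSON.
You are given the following mobile screenshot, described in words. Can you generate {num_questions} questions regarding the content of the screenshot as well as the corresponding short answers to them?

The answer should be as short as possible, containing only the necessary information. Your answer should be structured as follows:
questions: [
{{question: the question,
    answer: the answer
}},
 ...
]

{screen_schema}"""


def build_prompt(screen_schema, num_questions=5):
    """Insert the textual screen description into the prompt template."""
    return PROMPT_TEMPLATE.format(
        screen_schema=screen_schema, num_questions=num_questions
    )


def parse_qa_response(raw_response):
    """Return (question, answer) pairs, or an empty list if the reply is unusable."""
    try:
        data = json.loads(raw_response)
    except json.JSONDecodeError:
        return []
    if not isinstance(data, dict):
        return []
    questions = data.get("questions")
    if not isinstance(questions, list):
        return []
    pairs = []
    for item in questions:
        if isinstance(item, dict) and item.get("question") and item.get("answer"):
            pairs.append((item["question"], item["answer"]))
    return pairs


prompt = build_prompt("TEXT 220 140 1000 200 Opens at 9 AM")
```

Replies that fail to parse are simply dropped here. In practice, the human validation step described above would still be needed to judge whether well-formed pairs are actually correct.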

By combining the natural language capabilities of LLMs with a structured schema, we simulate a wide range of user interactions and scenarios to generate synthetic, realistic tasks. In particular, we generate three categories of tasks:

  • Question answering: The model is asked to answer questions regarding the content of the screenshots, e.g., “When does the restaurant open?”
  • Screen navigation: The model is asked to convert a natural language utterance into an executable action on a screen, e.g., “Click the search button.”
  • Screen summarization: The model is asked to summarize the screen content in one or two sentences.

Block diagram of our workflow for generating data for QA, summarization and navigation tasks using existing ScreenAI models and LLMs. Each task uses a custom prompt to emphasize desired aspects, like questions related to counting, involving reasoning, etc.

LLM-generated data. Examples for screen QA, navigation and summarization. For navigation, the action bounding box is displayed in red on the screenshot.

Experiments and results

As previously mentioned, ScreenAI is trained in two stages: pre-training and fine-tuning. Pre-training data labels are obtained using self-supervised learning, and fine-tuning data labels come from human raters.

We fine-tune ScreenAI using public QA, summarization, and navigation datasets and a variety of tasks related to UIs. For QA, we use well-established benchmarks in the multimodal and document understanding field, such as ChartQA, DocVQA, Multi page DocVQA, InfographicVQA, OCR-VQA, WebSRC and ScreenQA. For navigation, datasets used include Referring Expressions, MoTIF, Mug, and Android in the Wild. Finally, we use Screen2Words for screen summarization and Widget Captioning for describing specific UI elements. Along with the fine-tuning datasets, we evaluate the fine-tuned ScreenAI model using three novel benchmarks:

  1. Screen Annotation: Enables evaluation of the model’s layout annotation and spatial understanding capabilities.
  2. ScreenQA Short: A variation of ScreenQA whose ground truth answers have been shortened to contain only the relevant information, which better aligns with other QA tasks.
  3. Complex ScreenQA: Complements ScreenQA Short with more difficult questions (counting, arithmetic, comparison, and non-answerable questions) and contains screens with various aspect ratios.

The fine-tuned ScreenAI model achieves state-of-the-art results on various UI and infographic-based tasks (WebSRC and MoTIF) and best-in-class performance on ChartQA, DocVQA, and InfographicVQA compared to models of similar size. ScreenAI achieves competitive performance on Screen2Words and OCR-VQA. Additionally, we report results on the new benchmark datasets introduced to serve as a baseline for further research.

Comparing model performance of ScreenAI with state-of-the-art (SOTA) models of similar size.

Next, we examine ScreenAI’s scaling capabilities and observe that, across all tasks, increasing the model size improves performance, and the improvements have not saturated at the largest size.

Model performance increases with size, and the performance has not saturated even at the largest size of 5B params.

Conclusion

We introduce the ScreenAI model along with a unified representation that enables us to develop self-supervised learning tasks leveraging data from all these domains. We also illustrate the impact of data generation using LLMs and investigate improving model performance on specific aspects by modifying the training mixture. We apply all of these techniques to build multi-task trained models that perform competitively with state-of-the-art approaches on a number of public benchmarks. However, we also note that our approach still lags behind large models and further research is needed to bridge this gap.

Acknowledgements

This project is the result of joint work with Maria Wang, Fedir Zubach, Hassan Mansoor, Vincent Etter, Victor Carbune, Jason Lin, Jindong Chen and Abhanshu Sharma. We thank Fangyu Liu, Xi Chen, Efi Kokiopoulou, Jesse Berent, Gabriel Barcik, Lukas Zilka, Oriana Riva, Gang Li, Yang Li, Radu Soricut, and Tania Bedrax-Weiss for their insightful feedback and discussions, along with Rahul Aralikatte, Hao Cheng and Daniel Kim for their support in data preparation. We also thank Jay Yagnik, Blaise Aguera y Arcas, Ewa Dominowska, David Petrou, and Matt Sharifi for their leadership, vision and support. We are very grateful to Tom Small for helping us create the animation in this post.


SCIN: A new resource for representative dermatology images

Health datasets play a crucial role in research and medical education, but it can be challenging to create a dataset that represents the real world. For example, dermatology conditions are diverse in their appearance and severity and manifest differently across skin tones. Yet, existing dermatology image datasets often lack representation of everyday conditions (like rashes, allergies and infections) and skew towards lighter skin tones. Furthermore, race and ethnicity information is frequently missing, hindering our ability to assess disparities or create solutions.

To address these limitations, we are releasing the Skin Condition Image Network (SCIN) dataset in collaboration with physicians at Stanford Medicine. We designed SCIN to reflect the broad range of concerns that people search for online, supplementing the types of conditions typically found in clinical datasets. It contains images across various skin tones and body parts, helping to ensure that future AI tools work effectively for all. We’ve made the SCIN dataset freely available as an open-access resource for researchers, educators, and developers, and have taken careful steps to protect contributor privacy.

Example set of images and metadata from the SCIN dataset.

Dataset composition

The SCIN dataset currently contains over 10,000 images of skin, nail, or hair conditions, directly contributed by individuals experiencing them. All contributions were made voluntarily with informed consent by individuals in the US, under an Institutional Review Board-approved study. To provide context for retrospective dermatologist labeling, contributors were asked to take images both close-up and from slightly further away. They were given the option to self-report demographic information and tanning propensity (self-reported Fitzpatrick Skin Type, i.e., sFST), and to describe the texture, duration and symptoms related to their concern.

One to three dermatologists labeled each contribution with up to five dermatology conditions, along with a confidence score for each label. The SCIN dataset contains these individual labels, as well as an aggregated and weighted differential diagnosis derived from them that could be useful for model testing or training. These labels were assigned retrospectively and are not equivalent to a clinical diagnosis, but they allow us to compare the distribution of dermatology conditions in the SCIN dataset with existing datasets.
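To illustrate what an aggregated, weighted differential diagnosis could look like, here is a minimal Python sketch. The post does not describe the weighting scheme used for SCIN, so the scheme below (each dermatologist's confidences are normalized to sum to one, then averaged across dermatologists) is a hypothetical stand-in, as are the function name and the example labels.

```python
from collections import defaultdict


def aggregate_differential(labels_by_dermatologist):
    """Combine per-dermatologist labels into one weighted differential.

    labels_by_dermatologist: list of dicts mapping condition name to a
    confidence score (assumed 1-5). Returns a dict of condition -> weight,
    sorted by descending weight, with weights summing to 1.
    """
    totals = defaultdict(float)
    for labels in labels_by_dermatologist:
        total_confidence = sum(labels.values())
        if total_confidence == 0:
            continue  # skip a dermatologist who gave no usable labels
        for condition, confidence in labels.items():
            # Each dermatologist contributes one unit of weight in total.
            totals[condition] += confidence / total_confidence
    normalizer = sum(totals.values())
    if normalizer == 0:
        return {}
    ranked = sorted(totals.items(), key=lambda item: item[1], reverse=True)
    return {condition: weight / normalizer for condition, weight in ranked}


# Two hypothetical dermatologists labeling the same contribution.
differential = aggregate_differential([
    {"Eczema": 4, "Contact dermatitis": 4},
    {"Eczema": 5},
])
```

Normalizing per dermatologist keeps a labeler who lists five conditions from outweighing one who lists a single confident diagnosis. Other schemes, such as rank-based weights, would be equally plausible.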

The SCIN dataset contains largely allergic, inflammatory and infectious conditions while datasets from clinical sources focus on benign and malignant neoplasms.

While many existing dermatology datasets focus on malignant and benign tumors and are intended to assist with skin cancer diagnosis, the SCIN dataset consists largely of common allergic, inflammatory, and infectious conditions. The majority of images in the SCIN dataset show early-stage concerns — more than half arose less than a week before the photo, and 30% arose less than a day before the image was taken. Conditions within this time window are seldom seen within the health system and therefore are underrepresented in existing dermatology datasets.

We also obtained dermatologist estimates of Fitzpatrick Skin Type (estimated FST or eFST) and layperson labeler estimates of Monk Skin Tone (eMST) for the images. This allowed comparison of the skin condition and skin type distributions to those in existing dermatology datasets. Although we did not selectively target any skin types or skin tones, the SCIN dataset has a more balanced Fitzpatrick skin type distribution (with more of Types 3, 4, 5, and 6) than similar datasets from clinical sources.

Self-reported and dermatologist-estimated Fitzpatrick Skin Type distribution in the SCIN dataset compared with existing un-enriched dermatology datasets (Fitzpatrick17k, PH², SKINL2, and PAD-UFES-20).

The Fitzpatrick Skin Type scale was originally developed as a photo-typing scale to measure the response of skin types to UV radiation, and it is widely used in dermatology research. The Monk Skin Tone scale is a newer 10-shade scale that measures skin tone rather than skin phototype, capturing more nuanced differences between the darker skin tones. While neither scale was intended for retrospective estimation using images, the inclusion of these labels is intended to enable future research into skin type and tone representation in dermatology. For example, the SCIN dataset provides an initial benchmark for the distribution of these skin types and tones in the US population.

The SCIN dataset has a high representation of women and younger individuals, likely reflecting a combination of factors. These could include differences in skin condition incidence, propensity to seek health information online, and variations in willingness to contribute to research across demographics.

Crowdsourcing method

To create the SCIN dataset, we used a novel crowdsourcing method, which we describe in the accompanying research paper co-authored with investigators at Stanford Medicine. This approach empowers individuals to play an active role in healthcare research. It allows us to reach people at earlier stages of their health concerns, potentially before they seek formal care. Crucially, this method uses advertisements on web search result pages — the starting point for many people’s health journey — to connect with participants.

Our results demonstrate that crowdsourcing can yield a high-quality dataset with a low spam rate. Over 97.5% of contributions were genuine images of skin conditions. After performing further filtering steps to exclude images that were out of scope for the SCIN dataset and to remove duplicates, we were able to release nearly 90% of the contributions received over the 8-month study period. Most images were sharp and well-exposed. Approximately half of the contributions include self-reported demographics, and 80% contain self-reported information relating to the skin condition, such as texture, duration, or other symptoms. We found that dermatologists’ ability to retrospectively assign a differential diagnosis depended more on the availability of self-reported information than on image quality.

Dermatologist confidence in their labels (scale from 1-5) depended on the availability of self-reported demographic and symptom information.

While perfect image de-identification can never be guaranteed, protecting the privacy of individuals who contributed their images was a top priority when creating the SCIN dataset. Through informed consent, contributors were made aware of potential re-identification risks and advised to avoid uploading images with identifying features. Post-submission privacy protection measures included manual redaction or cropping to exclude potentially identifying areas, reverse image searches to exclude publicly available copies and metadata removal or aggregation. The SCIN Data Use License prohibits attempts to re-identify contributors.

We hope the SCIN dataset will be a helpful resource for those working to advance inclusive dermatology research, education, and AI tool development. By demonstrating an alternative to traditional dataset creation methods, SCIN paves the way for more representative datasets in areas where self-reported data or retrospective labeling is feasible.

Acknowledgements

We are grateful to all our co-authors Abbi Ward, Jimmy Li, Julie Wang, Sriram Lakshminarasimhan, Ashley Carrick, Bilson Campana, Jay Hartford, Pradeep Kumar S, Tiya Tiyasirisokchai, Sunny Virmani, Renee Wong, Yossi Matias, Greg S. Corrado, Dale R. Webster, Dawn Siegel (Stanford Medicine), Steven Lin (Stanford Medicine), Justin Ko (Stanford Medicine), Alan Karthikesalingam and Christopher Semturs. We also thank Yetunde Ibitoye, Sami Lachgar, Lisa Lehmann, Javier Perez, Margaret Ann Smith (Stanford Medicine), Rachelle Sico, Amit Talreja, Annisah Um’rani and Wayne Westerlind for their essential contributions to this work. Finally, we are grateful to Heather Cole-Lewis, Naama Hammel, Ivor Horn, Michael Howell, Yun Liu, and Eric Teasley for their insightful comments on the study design and manuscript.
