Intervening on early readouts for mitigating spurious features and simplicity bias

Intervening on early readouts for mitigating spurious features and simplicity bias

Machine learning models in the real world are often trained on limited data that may contain unintended statistical biases. For example, in the CELEBA celebrity image dataset, a disproportionate number of female celebrities have blond hair, leading to classifiers incorrectly predicting “blond” as the hair color for most female faces — here, gender is a spurious feature for predicting hair color. Such unfair biases could have significant consequences in critical applications such as medical diagnosis.

Surprisingly, recent work has also discovered an inherent tendency of deep networks to amplify such statistical biases, through the so-called simplicity bias of deep learning. This bias is the tendency of deep networks to identify weakly predictive features early in the training, and continue to anchor on these features, failing to identify more complex and potentially more accurate features.

With the above in mind, we propose simple and effective fixes to this dual challenge of spurious features and simplicity bias by applying early readouts and feature forgetting. First, in “Using Early Readouts to Mediate Featural Bias in Distillation”, we show that making predictions from early layers of a deep network (referred to as “early readouts”) can automatically signal issues with the quality of the learned representations. In particular, these predictions are more often wrong, and more confidently wrong, when the network is relying on spurious features. We use this erroneous confidence to improve outcomes in model distillation, a setting where a larger “teacher” model guides the training of a smaller “student” model. Then in “Overcoming Simplicity Bias in Deep Networks using a Feature Sieve”, we intervene directly on these indicator signals by making the network “forget” the problematic features and consequently look for better, more predictive features. This substantially improves the model’s ability to generalize to unseen domains compared to previous approaches. Our AI Principles and our Responsible AI practices guide how we research and develop these advanced applications and help us address the challenges posed by statistical biases.

Animation comparing hypothetical responses from two models trained with and without the feature sieve.

Early readouts for debiasing distillation

We first illustrate the diagnostic value of early readouts and their application in debiased distillation, i.e., making sure that the student model inherits the teacher model’s resilience to feature bias through distillation. We start with a standard distillation framework where the student is trained with a mixture of label matching (minimizing the cross-entropy loss between student outputs and the ground-truth labels) and teacher matching (minimizing the KL divergence loss between student and teacher outputs for any given input).

Suppose one trains a linear decoder, i.e., a small auxiliary neural network named as Aux, on top of an intermediate representation of the student model. We refer to the output of this linear decoder as an early readout of the network representation. Our finding is that early readouts make more errors on instances that contain spurious features, and further, the confidence on those errors is higher than the confidence associated with other errors. This suggests that confidence on errors from early readouts is a fairly strong, automated indicator of the model’s dependence on potentially spurious features.

Illustrating the usage of early readouts (i.e., output from the auxiliary layer) in debiasing distillation. Instances that are confidently mispredicted in the early readouts are upweighted in the distillation loss.

We used this signal to modulate the contribution of the teacher in the distillation loss on a per-instance basis, and found significant improvements in the trained student model as a result.

We evaluated our approach on standard benchmark datasets known to contain spurious correlations (Waterbirds, CelebA, CivilComments, MNLI). Each of these datasets contain groupings of data that share an attribute potentially correlated with the label in a spurious manner. As an example, the CelebA dataset mentioned above includes groups such as {blond male, blond female, non-blond male, non-blond female}, with models typically performing the worst on the {non-blond female} group when predicting hair color. Thus, a measure of model performance is its worst group accuracy, i.e., the lowest accuracy among all known groups present in the dataset. We improved the worst group accuracy of student models on all datasets; moreover, we also improved overall accuracy in three of the four datasets, showing that our improvement on any one group does not come at the expense of accuracy on other groups. More details are available in our paper.

Comparison of Worst Group Accuracies of different distillation techniques relative to that of the Teacher model. Our method outperforms other methods on all datasets.

Overcoming simplicity bias with a feature sieve

In a second, closely related project, we intervene directly on the information provided by early readouts, to improve feature learning and generalization. The workflow alternates between identifying problematic features and erasing identified features from the network. Our primary hypothesis is that early features are more prone to simplicity bias, and that by erasing (“sieving”) these features, we allow richer feature representations to be learned.

Training workflow with feature sieve. We alternate between identifying problematic features (using training iteration) and erasing them from the network (using forgetting iteration).

We describe the identification and erasure steps in more detail:

  • Identifying simple features: We train the primary model and the readout model (AUX above) in conventional fashion via forward- and back-propagation. Note that feedback from the auxiliary layer does not back-propagate to the main network. This is to force the auxiliary layer to learn from already-available features rather than create or reinforce them in the main network.
  • Applying the feature sieve: We aim to erase the identified features in the early layers of the neural network with the use of a novel forgetting loss, Lf , which is simply the cross-entropy between the readout and a uniform distribution over labels. Essentially, all information that leads to nontrivial readouts are erased from the primary network. In this step, the auxiliary network and upper layers of the main network are kept unchanged.

We can control specifically how the feature sieve is applied to a given dataset through a small number of configuration parameters. By changing the position and complexity of the auxiliary network, we control the complexity of the identified- and erased features. By modifying the mixing of learning and forgetting steps, we control the degree to which the model is challenged to learn more complex features. These choices, which are dataset-dependent, are made via hyperparameter search to maximize validation accuracy, a standard measure of generalization. Since we include “no-forgetting” (i.e., the baseline model) in the search space, we expect to find settings that are at least as good as the baseline.

Below we show features learned by the baseline model (middle row) and our model (bottom row) on two benchmark datasets — biased activity recognition (BAR) and animal categorization (NICO). Feature importance was estimated using post-hoc gradient-based importance scoring (GRAD-CAM), with the orange-red end of the spectrum indicating high importance, while green-blue indicates low importance. Shown below, our trained models focus on the primary object of interest, whereas the baseline model tends to focus on background features that are simpler and spuriously correlated with the label.

Feature importance scoring using GRAD-CAM on activity recognition (BAR) and animal categorization (NICO) generalization benchmarks. Our approach (last row) focuses on the relevant objects in the image, whereas the baseline (ERM; middle row) relies on background features that are spuriously correlated with the label.

Through this ability to learn better, generalizable features, we show substantial gains over a range of relevant baselines on real-world spurious feature benchmark datasets: BAR, CelebA Hair, NICO and ImagenetA, by margins up to 11% (see figure below). More details are available in our paper.

Our feature sieve method improves accuracy by significant margins relative to the nearest baseline for a range of feature generalization benchmark datasets.

Conclusion

We hope that our work on early readouts and their use in feature sieving for generalization will both spur the development of a new class of adversarial feature learning approaches and help improve the generalization capability and robustness of deep learning systems.

Acknowledgements

The work on applying early readouts to debiasing distillation was conducted in collaboration with our academic partners Durga Sivasubramanian, Anmol Reddy and Prof. Ganesh Ramakrishnan at IIT Bombay. We extend our sincere gratitude to Praneeth Netrapalli and Anshul Nasery for their feedback and recommendations. We are also grateful to Nishant Jain, Shreyas Havaldar, Rachit Bansal, Kartikeya Badola, Amandeep Kaur and the whole cohort of pre-doctoral researchers at Google Research India for taking part in research discussions. Special thanks to Tom Small for creating the animation used in this post.

Read More

MobileDiffusion: Rapid text-to-image generation on-device

MobileDiffusion: Rapid text-to-image generation on-device

Text-to-image diffusion models have shown exceptional capabilities in generating high-quality images from text prompts. However, leading models feature billions of parameters and are consequently expensive to run, requiring powerful desktops or servers (e.g., Stable Diffusion, DALL·E, and Imagen). While recent advancements in inference solutions on Android via MediaPipe and iOS via Core ML have been made in the past year, rapid (sub-second) text-to-image generation on mobile devices has remained out of reach.

To that end, in “MobileDiffusion: Subsecond Text-to-Image Generation on Mobile Devices”, we introduce a novel approach with the potential for rapid text-to-image generation on-device. MobileDiffusion is an efficient latent diffusion model specifically designed for mobile devices. We also adopt DiffusionGAN to achieve one-step sampling during inference, which fine-tunes a pre-trained diffusion model while leveraging a GAN to model the denoising step. We have tested MobileDiffusion on iOS and Android premium devices, and it can run in half a second to generate a 512×512 high-quality image. Its comparably small model size of just 520M parameters makes it uniquely suited for mobile deployment.

      
Rapid text-to-image generation on-device.

Background

The relative inefficiency of text-to-image diffusion models arises from two primary challenges. First, the inherent design of diffusion models requires iterative denoising to generate images, necessitating multiple evaluations of the model. Second, the complexity of the network architecture in text-to-image diffusion models involves a substantial number of parameters, regularly reaching into the billions and resulting in computationally expensive evaluations. As a result, despite the potential benefits of deploying generative models on mobile devices, such as enhancing user experience and addressing emerging privacy concerns, it remains relatively unexplored within the current literature.

The optimization of inference efficiency in text-to-image diffusion models has been an active research area. Previous studies predominantly concentrate on addressing the first challenge, seeking to reduce the number of function evaluations (NFEs). Leveraging advanced numerical solvers (e.g., DPM) or distillation techniques (e.g., progressive distillation, consistency distillation), the number of necessary sampling steps have significantly reduced from several hundreds to single digits. Some recent techniques, like DiffusionGAN and Adversarial Diffusion Distillation, even reduce to a single necessary step.

However, on mobile devices, even a small number of evaluation steps can be slow due to the complexity of model architecture. Thus far, the architectural efficiency of text-to-image diffusion models has received comparatively less attention. A handful of earlier works briefly touches upon this matter, involving the removal of redundant neural network blocks (e.g., SnapFusion). However, these efforts lack a comprehensive analysis of each component within the model architecture, thereby falling short of providing a holistic guide for designing highly efficient architectures.

MobileDiffusion

Effectively overcoming the challenges imposed by the limited computational power of mobile devices requires an in-depth and holistic exploration of the model’s architectural efficiency. In pursuit of this objective, our research undertakes a detailed examination of each constituent and computational operation within Stable Diffusion’s UNet architecture. We present a comprehensive guide for crafting highly efficient text-to-image diffusion models culminating in the MobileDiffusion.

The design of MobileDiffusion follows that of latent diffusion models. It contains three components: a text encoder, a diffusion UNet, and an image decoder. For the text encoder, we use CLIP-ViT/L14, which is a small model (125M parameters) suitable for mobile. We then turn our focus to the diffusion UNet and image decoder.

Diffusion UNet

As illustrated in the figure below, diffusion UNets commonly interleave transformer blocks and convolution blocks. We conduct a comprehensive investigation of these two fundamental building blocks. Throughout the study, we control the training pipeline (e.g., data, optimizer) to study the effects of different architectures.

In classic text-to-image diffusion models, a transformer block consists of a self-attention layer (SA) for modeling long-range dependencies among visual features, a cross-attention layer (CA) to capture interactions between text conditioning and visual features, and a feed-forward layer (FF) to post-process the output of attention layers. These transformer blocks hold a pivotal role in text-to-image diffusion models, serving as the primary components responsible for text comprehension. However, they also pose a significant efficiency challenge, given the computational expense of the attention operation, which is quadratic to the sequence length. We follow the idea of UViT architecture, which places more transformer blocks at the bottleneck of the UNet. This design choice is motivated by the fact that the attention computation is less resource-intensive at the bottleneck due to its lower dimensionality.

Our UNet architecture incorporates more transformers in the middle, and skips self-attention (SA) layers at higher resolutions.

Convolution blocks, in particular ResNet blocks, are deployed at each level of the UNet. While these blocks are instrumental for feature extraction and information flow, the associated computational costs, especially at high-resolution levels, can be substantial. One proven approach in this context is separable convolution. We observed that replacing regular convolution layers with lightweight separable convolution layers in the deeper segments of the UNet yields similar performance.

In the figure below, we compare the UNets of several diffusion models. Our MobileDiffusion exhibits superior efficiency in terms of FLOPs (floating-point operations) and number of parameters.

Comparison of some diffusion UNets.

Image decoder

In addition to the UNet, we also optimized the image decoder. We trained a variational autoencoder (VAE) to encode an RGB image to an 8-channel latent variable, with 8× smaller spatial size of the image. A latent variable can be decoded to an image and gets 8× larger in size. To further enhance efficiency, we design a lightweight decoder architecture by pruning the original’s width and depth. The resulting lightweight decoder leads to a significant performance boost, with nearly 50% latency improvement and better quality. For more details, please refer to our paper.

VAE reconstruction. Our VAE decoders have better visual quality than SD (Stable Diffusion).

Decoder   #Params (M)     PSNR↑     SSIM↑     LPIPS↓  
SD 49.5 26.7 0.76 0.037
Ours 39.3 30.0 0.83 0.032
Ours-Lite     9.8 30.2 0.84 0.032

Quality evaluation of VAE decoders. Our lite decoder is much smaller than SD, with better quality metrics, including peak signal-to-noise ratio (PSNR), structural similarity index measure (SSIM), and Learned Perceptual Image Patch Similarity (LPIPS).

One-step sampling

In addition to optimizing the model architecture, we adopt a DiffusionGAN hybrid to achieve one-step sampling. Training DiffusionGAN hybrid models for text-to-image generation encounters several intricacies. Notably, the discriminator, a classifier distinguishing real data and generated data, must make judgments based on both texture and semantics. Moreover, the cost of training text-to-image models can be extremely high, particularly in the case of GAN-based models, where the discriminator introduces additional parameters. Purely GAN-based text-to-image models (e.g., StyleGAN-T, GigaGAN) confront similar complexities, resulting in highly intricate and expensive training.

To overcome these challenges, we use a pre-trained diffusion UNet to initialize the generator and discriminator. This design enables seamless initialization with the pre-trained diffusion model. We postulate that the internal features within the diffusion model contain rich information of the intricate interplay between textual and visual data. This initialization strategy significantly streamlines the training.

The figure below illustrates the training procedure. After initialization, a noisy image is sent to the generator for one-step diffusion. The result is evaluated against ground truth with a reconstruction loss, similar to diffusion model training. We then add noise to the output and send it to the discriminator, whose result is evaluated with a GAN loss, effectively adopting the GAN to model a denoising step. By using pre-trained weights to initialize the generator and the discriminator, the training becomes a fine-tuning process, which converges in less than 10K iterations.

Illustration of DiffusionGAN fine-tuning.

Results

Below we show example images generated by our MobileDiffusion with DiffusionGAN one-step sampling. With such a compact model (520M parameters in total), MobileDiffusion can generate high-quality diverse images for various domains.

Images generated by our MobileDiffusion

We measured the performance of our MobileDiffusion on both iOS and Android devices, using different runtime optimizers. The latency numbers are reported below. We see that MobileDiffusion is very efficient and can run within half a second to generate a 512×512 image. This lightning speed potentially enables many interesting use cases on mobile devices.

Latency measurements (s) on mobile devices.

Conclusion

With superior efficiency in terms of latency and size, MobileDiffusion has the potential to be a very friendly option for mobile deployments given its capability to enable a rapid image generation experience while typing text prompts. And we will ensure any application of this technology will be in-line with Google’s responsible AI practices.

Acknowledgments

We like to thank our collaborators and contributors that helped bring MobileDiffusion to on-device: Zhisheng Xiao, Yanwu Xu, Jiuqiang Tang, Haolin Jia, Lutz Justen, Daniel Fenner, Ronald Wotzlaw, Jianing Wei, Raman Sarokin, Juhyun Lee, Andrei Kulik, Chuo-Ling Chang, and Matthias Grundmann.

Read More