Scaling Ensemble Distribution Distillation to Many Classes with Proxy Targets
Conference Name
Neural Information Processing Systems (NeurIPS 2021)
Type
Conference Object
This Version
AM (Accepted Manuscript)
Metadata
Citation
Gales, M., Malinin, A., & Ryabinin, M. Scaling Ensemble Distribution Distillation to Many Classes with Proxy Targets. Neural Information Processing Systems (NeurIPS 2021). https://doi.org/10.17863/CAM.78106
Abstract
Ensembles of machine learning models yield improved system performance as well as robust and interpretable uncertainty estimates; however, their inference costs can be prohibitively high. Ensemble Distribution Distillation (EnD2) is an approach that allows a single model to efficiently capture both the predictive performance and uncertainty estimates of an ensemble. For classification, this is achieved by training a Dirichlet distribution over the ensemble members’ output distributions via the maximum likelihood criterion. Although this criterion is theoretically principled, this work shows that it exhibits poor convergence when applied to large-scale tasks where the number of classes is very high. Specifically, we show that for the Dirichlet log-likelihood criterion, classes with low probability induce larger gradients than high-probability classes. Hence, during training, the model focuses on the distribution of the ensemble's tail-class probabilities rather than on the probability of the correct and closely related classes. We propose a new training objective which minimizes the reverse KL-divergence to a Proxy-Dirichlet target derived from the ensemble. This loss resolves the gradient issues of EnD2, as we demonstrate both theoretically and empirically on the ImageNet, LibriSpeech, and WMT17 En-De datasets, which contain 1000, 5000, and 40,000 classes, respectively.
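To make the maximum-likelihood criterion concrete, the sketch below computes the Dirichlet negative log-likelihood of the ensemble members' categorical outputs and its gradient with respect to each concentration parameter alpha_k, namely psi(alpha_k) - psi(alpha_0) - mean_m log pi_k^(m). It is an illustration under stated assumptions, not the authors' implementation: the model's concentrations are taken as a plain list of positive floats, and the function names (`digamma`, `dirichlet_nll`, `dirichlet_nll_grad`) are hypothetical helpers written with the Python standard library only.

```python
import math
from typing import List, Sequence


def digamma(x: float) -> float:
    """psi(x) for x > 0: shift x up with psi(x) = psi(x + 1) - 1/x, then apply
    the asymptotic series, which is accurate to roughly 1e-11 for x >= 6."""
    result = 0.0
    while x < 6.0:
        result -= 1.0 / x
        x += 1.0
    f = 1.0 / (x * x)
    result += math.log(x) - 0.5 / x - f * (
        1 / 12 - f * (1 / 120 - f * (1 / 252 - f * (1 / 240 - f / 132)))
    )
    return result


def dirichlet_nll(alphas: Sequence[float],
                  ensemble_probs: Sequence[Sequence[float]]) -> float:
    """Mean negative log-likelihood of each member's categorical output under Dir(alphas)."""
    # log of the Dirichlet normalizer: lnGamma(alpha_0) - sum_k lnGamma(alpha_k)
    log_norm = math.lgamma(sum(alphas)) - sum(math.lgamma(a) for a in alphas)
    total = 0.0
    for probs in ensemble_probs:
        total -= log_norm + sum((a - 1.0) * math.log(p) for a, p in zip(alphas, probs))
    return total / len(ensemble_probs)


def dirichlet_nll_grad(alphas: Sequence[float],
                       ensemble_probs: Sequence[Sequence[float]]) -> List[float]:
    """Analytic gradient: dNLL/dalpha_k = psi(alpha_k) - psi(alpha_0) - mean_m log pi_k^(m)."""
    psi_alpha0 = digamma(sum(alphas))
    n_members = len(ensemble_probs)
    grads = []
    for k, a in enumerate(alphas):
        mean_log_p = sum(math.log(probs[k]) for probs in ensemble_probs) / n_members
        grads.append(digamma(a) - psi_alpha0 - mean_log_p)
    return grads


if __name__ == "__main__":
    # Toy 3-member, 3-class ensemble (made-up numbers for illustration only).
    ensemble = [[0.7, 0.2, 0.1], [0.6, 0.3, 0.1], [0.8, 0.1, 0.1]]
    alphas = [7.0, 2.0, 1.0]
    print(dirichlet_nll(alphas, ensemble))
    print(dirichlet_nll_grad(alphas, ensemble))
```

If the network instead outputs logits z with alpha = exp(z), a common parameterization that is assumed rather than taken from the record above, the chain rule multiplies each gradient entry by alpha_k. Comparing per-class gradient magnitudes for head and tail classes on a real ensemble is one way to probe the imbalance the abstract describes; the paper's own analysis is the authoritative account.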
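The Proxy-Dirichlet objective can be sketched as follows: build a Dirichlet target whose mean is the ensemble's mean prediction, then score the model's Dirichlet with the reverse KL-divergence KL(Dir(alpha_model) || Dir(beta_target)), using the closed form for the KL between two Dirichlets. Several details here are assumptions rather than the paper's exact construction: the target precision uses Minka's closed-form approximation (K - 1) / (2 * sum_k m_k (log m_k - mean_m log pi_k^(m))), the target concentrations are simply precision times mean, and all function names (`proxy_dirichlet_target`, `dirichlet_kl`, `reverse_kl_loss`) are hypothetical.

```python
import math
from typing import List, Sequence


def digamma(x: float) -> float:
    """psi(x) for x > 0 via the recurrence psi(x) = psi(x + 1) - 1/x plus an asymptotic series."""
    result = 0.0
    while x < 6.0:
        result -= 1.0 / x
        x += 1.0
    f = 1.0 / (x * x)
    result += math.log(x) - 0.5 / x - f * (
        1 / 12 - f * (1 / 120 - f * (1 / 252 - f * (1 / 240 - f / 132)))
    )
    return result


def proxy_dirichlet_target(ensemble_probs: Sequence[Sequence[float]]) -> List[float]:
    """Hypothetical proxy target: mean = ensemble mean, precision from Minka's approximation."""
    n_members = len(ensemble_probs)
    n_classes = len(ensemble_probs[0])
    mean = [sum(p[k] for p in ensemble_probs) / n_members for k in range(n_classes)]
    mean_log = [sum(math.log(p[k]) for p in ensemble_probs) / n_members
                for k in range(n_classes)]
    # By Jensen's inequality log(mean_k) >= mean_log_k, so spread >= 0; it is 0 only
    # when all members agree exactly, so clamp it to avoid dividing by zero.
    spread = sum(mu * (math.log(mu) - ml) for mu, ml in zip(mean, mean_log))
    precision = (n_classes - 1) / (2.0 * max(spread, 1e-12))
    return [precision * mu for mu in mean]


def dirichlet_kl(alpha: Sequence[float], beta: Sequence[float]) -> float:
    """Closed-form KL(Dir(alpha) || Dir(beta))."""
    a0, b0 = sum(alpha), sum(beta)
    kl = math.lgamma(a0) - math.lgamma(b0)
    kl += sum(math.lgamma(b) - math.lgamma(a) for a, b in zip(alpha, beta))
    kl += sum((a - b) * (digamma(a) - digamma(a0)) for a, b in zip(alpha, beta))
    return kl


def reverse_kl_loss(model_alphas: Sequence[float],
                    ensemble_probs: Sequence[Sequence[float]]) -> float:
    """Reverse KL: the model's Dirichlet is the first argument, the proxy target the second."""
    return dirichlet_kl(model_alphas, proxy_dirichlet_target(ensemble_probs))


if __name__ == "__main__":
    # Toy ensemble (made-up numbers for illustration only).
    ensemble = [[0.7, 0.2, 0.1], [0.6, 0.3, 0.1], [0.8, 0.1, 0.1]]
    print(proxy_dirichlet_target(ensemble))
    print(reverse_kl_loss([1.0, 1.0, 1.0], ensemble))
```

Because the target depends only on the ensemble's mean and a single precision, rather than on every member's per-class log-probabilities, the per-class terms of this loss do not directly involve mean_m log pi_k^(m); whether this sketch reproduces the gradient behaviour reported in the paper has not been checked here.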
Sponsorship
Andrey
Funder references
Cambridge Assessment (Unknown)
Embargo Lift Date
2022-11-15
Identifiers
External DOI: https://doi.org/10.17863/CAM.78106
This record's URL: https://www.repository.cam.ac.uk/handle/1810/330661