Improved Group Robustness via Classifier Retraining on Independent Splits

20 Apr 2022  ·  Thien Hang Nguyen, Hongyang R. Zhang, Huy Le Nguyen

Deep neural networks trained by minimizing the average risk can achieve strong average performance. Still, their performance for a subgroup may degrade if the subgroup is underrepresented in the overall data population. Group distributionally robust optimization (Sagawa et al., 2020a), or group DRO for short, is a widely used baseline for learning models with strong worst-group performance. We note that this method requires group labels for every example at training time and can overfit to small groups, requiring strong regularization. Given a limited amount of group labels at training time, Just Train Twice (Liu et al., 2021), or JTT for short, is a two-stage method that infers a pseudo group label for every unlabeled example first, then applies group DRO based on the inferred group labels. The inference process is also sensitive to overfitting, sometimes involving additional hyperparameters. This paper designs a simple method based on the idea of classifier retraining on independent splits of the training data. We find that using a novel sample-splitting procedure achieves robust worst-group performance in the fine-tuning step. When evaluated on benchmark image and text classification tasks, our approach consistently compares favorably with group DRO, JTT, and other strong baselines when group labels are either available during training or given only in validation sets. Importantly, our method relies on only a single hyperparameter, which adjusts the fraction of labels used for training feature extractors vs. training classification layers. We justify the rationale of our splitting scheme with a generalization-bound analysis of the worst-group loss.
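
As a rough illustration of the splitting idea described in the abstract, the sketch below partitions example indices into two disjoint parts, one intended for training the feature extractor and one for retraining the classification layer, and computes worst-group accuracy from predictions. It is a minimal sketch under stated assumptions, not the authors' implementation: the function names `split_independent` and `worst_group_accuracy`, the parameter `frac_features`, and the uniform random shuffle are all hypothetical choices made for illustration. The abstract only states that a single hyperparameter controls the fraction of data used for each stage.

```python
import random
from collections import defaultdict


def split_independent(n_examples, frac_features, seed=0):
    """Partition indices 0..n_examples-1 into two disjoint splits.

    frac_features is a hypothetical name for the single hyperparameter
    mentioned in the abstract: the fraction of examples assigned to
    feature-extractor training. The remainder is reserved for
    classifier retraining. Uniform shuffling is an assumption.
    """
    if not 0.0 < frac_features < 1.0:
        raise ValueError("frac_features must lie strictly between 0 and 1")
    indices = list(range(n_examples))
    random.Random(seed).shuffle(indices)
    cut = int(frac_features * n_examples)
    return indices[:cut], indices[cut:]


def worst_group_accuracy(preds, labels, groups):
    """Return the minimum per-group accuracy over all groups present."""
    correct = defaultdict(int)
    total = defaultdict(int)
    for pred, label, group in zip(preds, labels, groups):
        total[group] += 1
        correct[group] += int(pred == label)
    return min(correct[g] / total[g] for g in total)


if __name__ == "__main__":
    feature_idx, classifier_idx = split_independent(10, 0.5, seed=0)
    print(len(feature_idx), len(classifier_idx))
    print(worst_group_accuracy([1, 0, 1, 1], [1, 0, 0, 1], ["a", "a", "b", "b"]))
```

In a full pipeline, the first split would be used to train the whole network, and the second split would then be used to retrain only the final classification layer with a robust objective such as group DRO; those training steps are omitted here because the abstract does not specify them.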
