
Learning from Mistakes: Using Mis-predictions as Harm Alerts in Language Pre-Training

Fitting complex patterns in the training data, such as reasoning and commonsense, is a key challenge for language pre-training. According to recent studies and our empirical observations, one possible reason is that some easy-to-fit patterns in the training data, such as frequently co-occurring word combinations, dominate and harm pre-training, making it hard for the model to fit more complex information. We argue that mis-predictions can help locate such dominating patterns that harm language understanding. When a mis-prediction occurs, there should be patterns that frequently co-occur with the mis-predicted word, already fitted by the model, that lead to the mis-prediction. If we add regularization that trains the model to rely less on such dominating patterns when a mis-prediction occurs and to focus more on the remaining, more subtle patterns, more information can be fitted efficiently during pre-training. Following this motivation, we propose a new language pre-training method, Mis-Predictions as Harm Alerts (MPA). In MPA, when a mis-prediction occurs during pre-training, we use its co-occurrence information to guide several heads of the self-attention modules: these heads are optimized to assign lower attention weights to the words in the input sentence that frequently co-occur with the mis-prediction and higher weights to the other words. By doing so, the Transformer model is trained to rely less on the dominating, frequently co-occurring patterns when mis-predictions occur and to focus more on the remaining, more complex information. Our experiments show that MPA expedites the pre-training of BERT and ELECTRA and improves their performance on downstream tasks.
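The abstract does not say how co-occurrence statistics become a training signal for the guided heads. The following is a minimal, hypothetical sketch of one possible reading, not the authors' implementation. It assumes a precomputed co-occurrence score between each input token and the mis-predicted word (e.g. a normalized count or PMI), a target attention distribution of the form softmax(-score / temperature), and a KL-divergence penalty on the guided heads that applies only when the prediction is wrong. The names `target_attention`, `mpa_aux_loss`, `cooc_table` and `temperature` are illustrative assumptions.

```python
import math
from typing import Dict, List, Sequence, Tuple


def softmax(xs: Sequence[float]) -> List[float]:
    # Numerically stable softmax over a plain list of floats.
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def target_attention(cooc_scores: Sequence[float], temperature: float = 1.0) -> List[float]:
    # Hypothetical target: tokens that co-occur more with the mis-predicted
    # word receive lower weight, all other tokens receive higher weight.
    return softmax([-c / temperature for c in cooc_scores])


def kl_divergence(p: Sequence[float], q: Sequence[float], eps: float = 1e-12) -> float:
    # KL(p || q); eps guards against log(0) for zero attention weights.
    return sum(pi * math.log((pi + eps) / (qi + eps)) for pi, qi in zip(p, q) if pi > 0)


def mpa_aux_loss(
    pred_logits: Sequence[float],
    true_id: int,
    input_ids: Sequence[int],
    guided_heads: Sequence[Sequence[float]],
    cooc_table: Dict[Tuple[int, int], float],
    temperature: float = 1.0,
) -> float:
    # guided_heads: attention distributions over input_ids from the heads
    # chosen for guidance (an assumption; the abstract says "several heads").
    predicted = max(range(len(pred_logits)), key=lambda i: pred_logits[i])
    if predicted == true_id:
        return 0.0  # No mis-prediction, so no harm alert and no extra loss.
    # Co-occurrence of each input token with the (wrong) predicted word.
    scores = [cooc_table.get((tok, predicted), 0.0) for tok in input_ids]
    target = target_attention(scores, temperature)
    return sum(kl_divergence(target, head) for head in guided_heads) / len(guided_heads)
```

For example, with `pred_logits=[0.1, 2.0, 0.3]` and `true_id=2`, the model wrongly predicts word 1. If input token 5 co-occurs strongly with word 1, a head that puts most of its weight on token 5 is penalized, while a head already matching the target is not. Using a soft target rather than masking the tokens out is a design choice of this sketch: the guided heads are nudged away from the dominating pattern without being forbidden from using it.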
