RL Transformers

Contrastive BERT

Introduced by Banino et al. in CoBERL: Contrastive BERT for Reinforcement Learning

Contrastive BERT (CoBERL) is a reinforcement learning agent that combines a new contrastive loss with a hybrid LSTM-transformer architecture to improve data efficiency in RL. It uses bidirectional masked prediction together with a generalization of recent contrastive methods to learn better representations for transformers in RL, without the need for hand-engineered data augmentations.
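
As a rough illustration, the contrastive objective can be thought of as an InfoNCE-style loss between the transformer's predictions at masked positions and the corresponding input embeddings, so the two "views" come from masking rather than from hand-crafted augmentations. The sketch below is a simplified stand-in for the paper's actual objective, not the authors' implementation; the function name, tensor shapes, and temperature value are illustrative assumptions.

```python
# Simplified InfoNCE-style sketch of the idea (not the paper's exact objective);
# names, shapes, and the temperature are illustrative assumptions.
import torch
import torch.nn.functional as F

def masked_contrastive_loss(predictions, targets, temperature=0.1):
    """predictions: transformer outputs at masked positions, [num_masked, dim].
    targets: the original input embeddings at those positions, [num_masked, dim].
    Each prediction should be most similar to its own target (the positive pair)."""
    p = F.normalize(predictions, dim=-1)
    t = F.normalize(targets, dim=-1)
    logits = p @ t.T / temperature        # pairwise cosine similarities
    labels = torch.arange(p.size(0))      # the matching target is the positive
    return F.cross_entropy(logits, labels)
```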

For the architecture, a residual network encodes observations into embeddings $Y_{t}$. $Y_{t}$ is fed through a causally masked GTrXL transformer, which computes the predicted masked inputs $X_{t}$ and passes them, together with $Y_{t}$, to a learnt gate. The output of the gate is passed through a single LSTM layer to produce the values used to compute the RL loss. A contrastive loss is computed using the predicted masked inputs $X_{t}$, with $Y_{t}$ as targets; for this loss, the causal mask of the transformer is not used.
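
A minimal sketch of this pipeline is shown below, assuming a standard TransformerEncoder as a stand-in for GTrXL and a plain sigmoid gate for the learnt gate; all layer sizes and module names are illustrative assumptions rather than the authors' implementation.

```python
# Minimal sketch of the pipeline described above (not the authors' code).
# A standard TransformerEncoder stands in for GTrXL and the learnt gate is a
# plain sigmoid gate; all sizes and names are illustrative assumptions.
import torch
import torch.nn as nn

class CoBERLSketch(nn.Module):
    def __init__(self, obs_dim=128, d_model=256, n_actions=6):
        super().__init__()
        # Stand-in for the residual-network encoder: observations -> embeddings Y_t.
        self.encoder = nn.Sequential(
            nn.Linear(obs_dim, d_model), nn.ReLU(), nn.Linear(d_model, d_model))
        # Transformer stand-in for GTrXL; causally masked on the RL path.
        layer = nn.TransformerEncoderLayer(d_model, nhead=8, batch_first=True)
        self.transformer = nn.TransformerEncoder(layer, num_layers=4)
        # Learnt gate combining transformer output X_t with encoder output Y_t.
        self.gate = nn.Linear(2 * d_model, d_model)
        # Single LSTM layer producing the values used for the RL loss.
        self.lstm = nn.LSTM(d_model, d_model, batch_first=True)
        self.head = nn.Linear(d_model, n_actions)

    def forward(self, obs):                              # obs: [batch, time, obs_dim]
        y = self.encoder(obs)                            # Y_t
        t = obs.size(1)
        causal = torch.triu(torch.full((t, t), float('-inf')), diagonal=1)
        x = self.transformer(y, mask=causal)             # X_t
        g = torch.sigmoid(self.gate(torch.cat([x, y], dim=-1)))
        z, _ = self.lstm(g * x + (1 - g) * y)            # gated mix -> LSTM
        return self.head(z), x, y                        # RL outputs plus (X_t, Y_t)
```

In the full model, the contrastive loss would be computed from the returned $X_{t}$ and $Y_{t}$ using BERT-style masking without the causal mask, while the RL loss uses the LSTM output.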

Source: CoBERL: Contrastive BERT for Reinforcement Learning

Tasks


Task                          Papers    Share
Reinforcement Learning (RL)   1         100.00%
