Distilling Conditional Diffusion Models for Offline Reinforcement Learning through Trajectory Stitching

1 Feb 2024 · Shangzhe Li, Xinhua Zhang

Deep generative models have recently emerged as an effective approach to offline reinforcement learning. However, their large model size poses challenges in computation. We address this issue by proposing a knowledge distillation method based on data augmentation. In particular, high-return trajectories are generated from a conditional diffusion model, and they are blended with the original trajectories through a novel stitching algorithm that leverages a new reward generator. When the resulting dataset is used for behavioral cloning, the learned shallow policy, despite its much smaller size, outperforms or nearly matches deep generative planners on several D4RL benchmarks.
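The pipeline described above can be summarized in three stages: sample high-return trajectories from a conditional diffusion model, stitch them into the offline dataset, and behavior-clone a small policy on the blended data. The sketch below illustrates this flow under loose assumptions; the function names, network sizes, and the concatenation-based stitching stub are placeholders for illustration, not the authors' actual algorithm or API.

```python
# Hypothetical sketch of the distillation pipeline: (1) sample trajectories from
# a return-conditioned diffusion model, (2) stitch them with the offline data,
# (3) behavior-clone a shallow student policy. All names here are illustrative.
import torch
import torch.nn as nn

OBS_DIM, ACT_DIM = 17, 6  # e.g. dimensions of a D4RL locomotion task (assumed)

def sample_high_return_trajectories(n, horizon):
    """Placeholder for sampling from a pretrained return-conditioned diffusion model."""
    return torch.randn(n, horizon, OBS_DIM), torch.randn(n, horizon, ACT_DIM)

def stitch(data_obs, data_act, gen_obs, gen_act):
    """Placeholder stitching step: here the generated trajectories are simply
    concatenated with the originals; the paper's stitching algorithm instead
    splices segments guided by a learned reward generator."""
    return (torch.cat([data_obs, gen_obs], dim=0),
            torch.cat([data_act, gen_act], dim=0))

# Shallow student policy distilled via behavioral cloning.
policy = nn.Sequential(nn.Linear(OBS_DIM, 256), nn.ReLU(), nn.Linear(256, ACT_DIM))
optim = torch.optim.Adam(policy.parameters(), lr=3e-4)

# Toy stand-in for the offline dataset (random tensors, illustration only).
data_obs, data_act = torch.randn(64, 100, OBS_DIM), torch.randn(64, 100, ACT_DIM)
gen_obs, gen_act = sample_high_return_trajectories(16, 100)
obs, act = stitch(data_obs, data_act, gen_obs, gen_act)
obs, act = obs.reshape(-1, OBS_DIM), act.reshape(-1, ACT_DIM)

for step in range(1000):
    idx = torch.randint(0, obs.shape[0], (256,))
    loss = ((policy(obs[idx]) - act[idx]) ** 2).mean()  # MSE behavioral cloning loss
    optim.zero_grad()
    loss.backward()
    optim.step()
```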
