Composing Ensembles of Pre-trained Models via Iterative Consensus

20 Oct 2022 · Shuang Li, Yilun Du, Joshua B. Tenenbaum, Antonio Torralba, Igor Mordatch

Large pre-trained models exhibit distinct and complementary capabilities that depend on the data they are trained on. Language models such as GPT-3 are capable of textual reasoning but cannot understand visual information, while vision models such as DALL-E can generate photorealistic images but fail to understand complex language descriptions. In this work, we propose a unified framework for composing ensembles of different pre-trained models, combining the strengths of each individual model to solve various multimodal problems in a zero-shot manner. We use pre-trained models as "generators" or "scorers" and compose them via closed-loop iterative consensus optimization. The generator constructs proposals, and the scorers iteratively provide feedback to refine the generated result. Such closed-loop communication enables models to correct errors caused by other models, significantly boosting performance on downstream tasks (e.g., improving accuracy on grade school math problems by 7.5%) without requiring any model fine-tuning. We demonstrate that, by leveraging the strengths of each expert model, consensus achieved by an ensemble of scorers outperforms the feedback of a single scorer. Results show that the proposed method can be used as a general-purpose framework for a wide range of zero-shot multimodal tasks, such as image generation, video question answering, mathematical reasoning, and robotic manipulation. Project page: https://energy-based-model.github.io/composing-pretrained-models.
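The abstract describes the generator/scorer loop only at a high level. The following is a minimal sketch of one gradient-free reading of it, under stated assumptions: a toy generator samples numeric candidates around the current best answer, each scorer returns a real-valued score where higher is better, the ensemble's consensus is taken as the plain sum of the scorers' outputs, and each round keeps the best-scoring proposal. The functions `toy_generator`, `consensus_score` and `iterative_consensus`, and the two quadratic scorers, are hypothetical stand-ins for illustration, not the authors' models or code.

```python
import random
from typing import Callable, List, Sequence

# A scorer maps a candidate to a real-valued score; higher means better.
# (Hypothetical interface, for illustration only.)
Scorer = Callable[[float], float]


def consensus_score(candidate: float, scorers: Sequence[Scorer]) -> float:
    """Combine ensemble feedback by summing every scorer's score (one assumed choice)."""
    return sum(scorer(candidate) for scorer in scorers)


def toy_generator(center: float, num_proposals: int, spread: float,
                  rng: random.Random) -> List[float]:
    """Hypothetical generator: proposes candidates by perturbing the current best one."""
    return [center + rng.gauss(0.0, spread) for _ in range(num_proposals)]


def iterative_consensus(init: float, scorers: Sequence[Scorer], num_iters: int = 50,
                        num_proposals: int = 16, spread: float = 1.0,
                        seed: int = 0) -> float:
    """Closed loop: the generator proposes, the scorer ensemble judges, the best is kept."""
    rng = random.Random(seed)
    best = init
    best_score = consensus_score(best, scorers)
    for _ in range(num_iters):
        # Proposals are conditioned on the current best result (the feedback path).
        for candidate in toy_generator(best, num_proposals, spread, rng):
            score = consensus_score(candidate, scorers)
            if score > best_score:
                best, best_score = candidate, score
    return best


if __name__ == "__main__":
    # Two hypothetical scorers that disagree: one prefers 2.0, the other 4.0.
    def prefers_two(x: float) -> float:
        return -(x - 2.0) ** 2

    def prefers_four(x: float) -> float:
        return -(x - 4.0) ** 2

    print(iterative_consensus(0.0, [prefers_two, prefers_four]))
```

With one scorer preferring 2.0 and the other 4.0, the summed score -(x - 2)^2 - (x - 4)^2 is maximized at x = 3.0, so the loop should settle close to 3.0, a value neither scorer would pick on its own. Summing scores is only one simple way to form a consensus in this toy setting.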

Benchmark results (global leaderboard rank in parentheses):

Video Question Answering: ActivityNet-QA
Model                                            Accuracy
GPT-2 + CLIP-14 + CLIP-multilingual (Zero-Shot)  61.2 (#1)
GPT-2 + CLIP-32 (Zero-Shot)                      58.4 (#2)

Arithmetic Reasoning: GSM8K
Model                                                    Accuracy     Parameters (billion)
GPT-2-Medium 355M (fine-tuned, BS=5)                     18.3 (#139)  0.355 (#2)
GPT-2-Medium 355M (BS=5)                                 12.2 (#146)  0.355 (#2)
GPT-2-Medium 355M + question-solution classifier (BS=5)  20.8 (#137)  0.355 (#2)
GPT-2-Medium 355M + question-solution classifier (BS=1)  16.8 (#144)  0.355 (#2)

Image Generation: ImageNet 64x64
Model                          Inception Score  FID           KID
GLIDE + CLS                    22.077 (#8)      30.871 (#17)  7.952 (#4)
GLIDE + CLIP + CLS + CLS-FREE  34.952 (#5)      29.184 (#14)  3.766 (#1)
GLIDE + CLS-FREE               25.926 (#6)      29.219 (#15)  5.325 (#2)
GLIDE + CLIP                   25.017 (#7)      30.462 (#16)  6.174 (#3)
