Branch-Train-MiX: Mixing Expert LLMs into a Mixture-of-Experts LLM

We investigate efficient methods for training Large Language Models (LLMs) to possess capabilities in multiple specialized domains, such as coding, math reasoning, and world knowledge. Our method, named Branch-Train-MiX (BTX), starts from a seed model, which is branched to train experts in an embarrassingly parallel fashion with high throughput and reduced communication cost. After the individual experts are asynchronously trained, BTX brings together their feedforward parameters as experts in Mixture-of-Experts (MoE) layers and averages the remaining parameters, followed by an MoE-finetuning stage to learn token-level routing. BTX generalizes two special cases: the Branch-Train-Merge method, which lacks the MoE finetuning stage to learn routing, and sparse upcycling, which omits the stage of training experts asynchronously. Compared to alternative approaches, BTX achieves the best accuracy-efficiency tradeoff.
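For concreteness, below is a minimal PyTorch sketch of the merge step on a toy two-layer model: the feedforward blocks of the branched experts become the experts of an MoE layer behind a freshly initialized router, while every remaining parameter is element-wise averaged. The ToyLM/ToyBlock modules, the btx_merge helper, and the routing loop are illustrative assumptions, not the paper's implementation; the actual BTX models are full Transformer LLMs, and the router weights are learned in the subsequent MoE-finetuning stage.

```python
import copy
import torch
import torch.nn as nn

class ToyBlock(nn.Module):
    """Stand-in Transformer block: a linear layer replaces attention for brevity."""
    def __init__(self, d_model):
        super().__init__()
        self.attn = nn.Linear(d_model, d_model)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, 4 * d_model), nn.GELU(), nn.Linear(4 * d_model, d_model)
        )
    def forward(self, x):
        return x + self.ffn(self.attn(x))

class ToyLM(nn.Module):
    def __init__(self, d_model=16, n_layers=2):
        super().__init__()
        self.layers = nn.ModuleList([ToyBlock(d_model) for _ in range(n_layers)])
    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

class MoEFeedForward(nn.Module):
    """MoE layer assembled from the feedforward blocks of the branched experts.
    The router is freshly initialized and learned during MoE finetuning."""
    def __init__(self, expert_ffns, d_model, top_k=2):
        super().__init__()
        self.experts = nn.ModuleList(expert_ffns)
        self.router = nn.Linear(d_model, len(expert_ffns), bias=False)
        self.top_k = top_k
    def forward(self, x):                               # x: (num_tokens, d_model)
        probs = torch.softmax(self.router(x), dim=-1)
        weights, idx = probs.topk(self.top_k, dim=-1)   # token-level top-k routing
        out = torch.zeros_like(x)
        for k in range(self.top_k):
            for e, expert in enumerate(self.experts):
                mask = idx[:, k] == e
                if mask.any():
                    out[mask] += weights[mask, k].unsqueeze(-1) * expert(x[mask])
        return out

def btx_merge(expert_models, d_model, top_k=2):
    """Turn N branched dense experts into one MoE model: feedforward blocks become
    MoE experts, every other parameter is element-wise averaged (illustrative sketch)."""
    merged = copy.deepcopy(expert_models[0])
    for i, layer in enumerate(merged.layers):
        ffns = [copy.deepcopy(m.layers[i].ffn) for m in expert_models]
        layer.ffn = MoEFeedForward(ffns, d_model, top_k)
    expert_params = [dict(m.named_parameters()) for m in expert_models]
    for name, param in merged.named_parameters():
        if ".experts." in name or ".router." in name:
            continue                                    # MoE-layer params are kept per expert
        param.data.copy_(torch.stack([p[name] for p in expert_params]).mean(dim=0))
    return merged

# Usage: four stand-ins for the branched domain experts (e.g. code, math, wiki, seed).
experts = [ToyLM() for _ in range(4)]
moe_model = btx_merge(experts, d_model=16, top_k=2)
print(moe_model(torch.randn(8, 16)).shape)              # torch.Size([8, 16])
```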


Results from the Paper


| Task | Dataset | Model | Metric Name | Metric Value | Global Rank |
|---|---|---|---|---|---|
| Arithmetic Reasoning | GSM8K | Branch-Train-MiX 4x7B (sampling top-2 experts) | Accuracy | 37.1 | # 129 |
| Code Generation | HumanEval | Branch-Train-Merge 4x7B (top-1) | Pass@1 | 30.8 | # 77 |
| Code Generation | HumanEval | Branch-Train-MiX 4x7B (sampling top-2 experts) | Pass@1 | 28.7 | # 84 |
| Math Word Problem Solving | MATH | Branch-Train-MiX 4x7B (sampling top-2 experts) | Accuracy | 17.8 | # 80 |
| Code Generation | MBPP | Branch-Train-Merge 4x7B (top-2) | Accuracy | 42.6 | # 67 |
| Code Generation | MBPP | Branch-Train-MiX 4x7B (sampling top-2 experts) | Accuracy | 39.4 | # 71 |
| Multi-task Language Understanding | MMLU | Branch-Train-MiX 4x7B (sampling top-1 expert) | Average (%) | 53.2 | # 61 |
| Question Answering | TriviaQA | Branch-Train-MiX 4x7B (sampling top-2 experts) | EM | 57.1 | # 30 |
| Common Sense Reasoning | WinoGrande | Branch-Train-MiX 4x7B (sampling top-1 expert) | Accuracy | 70.6 | # 34 |
