MTraining: Distributed Dynamic Sparse Attention for Efficient Ultra-Long Context Training
Wenxuan Li, Chengruidong Zhang, Huiqiang Jiang, Yucheng Li, Yuqing Yang, Lili Qiu
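AI summary: This paper proposes MTraining, which uses dynamic sparse attention to address the computational imbalance and communication overhead of ultra-long-context training. It extends the context window of Qwen2.5-3B from 32K to 512K tokens and achieves up to 6x higher training throughput while maintaining model accuracy across multiple downstream tasks.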
Long context windows have become a standard feature of Large Language Models (LLMs), as extended contexts significantly enhance their capacity for complex reasoning and broaden their applicability across diverse scenarios. Dynamic sparse attention is a promising approach for reducing the computational cost of long contexts. However, efficiently training LLMs with dynamic sparse attention on ultra-long contexts, especially in distributed settings, remains a significant challenge, due in large part to worker-level and step-level imbalance. This paper introduces MTraining, a novel distributed methodology that leverages dynamic sparse attention to enable efficient training of LLMs with ultra-long contexts. Specifically, MTraining integrates three key components: a dynamic sparse training pattern, balanced sparse ring attention, and hierarchical sparse ring attention. These components are designed to jointly address the computational imbalance and communication overhead inherent in dynamic sparse attention mechanisms when training models with extensive context lengths. We demonstrate the efficacy of MTraining by training Qwen2.5-3B, successfully expanding its context window from 32K to 512K tokens on a cluster of 32 A100 GPUs. Our evaluations on a comprehensive suite of downstream tasks, including RULER, PG-19, InfiniteBench, and Needle In A Haystack, show that MTraining achieves up to 6x higher training throughput while preserving model accuracy. Our code is available at https://github.com/microsoft/MInference/tree/main/MTraining.
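As a rough illustration of the two problems named above (worker-level compute imbalance and ring-communication overhead), the Python sketch below uses a simplified block-level cost model. It is not MTraining's algorithm. The zigzag split is a common load-balancing scheme for dense causal ring attention, and the two-level ring is a generic intra-node/inter-node rotation. All function names, the "active block pairs" cost proxy, and the 4 nodes x 8 GPUs topology are hypothetical choices for illustration.

```python
import numpy as np

# --- Worker-level compute balance (hypothetical block-level cost model) ---

def causal_block_mask(n_blocks):
    # Dense causal attention at block granularity: query block i sees key blocks 0..i.
    return np.tril(np.ones((n_blocks, n_blocks), dtype=bool))

def contiguous_assignment(n_blocks, n_workers):
    # Worker w owns one contiguous run of query blocks (naive sequence sharding).
    per = n_blocks // n_workers
    return [list(range(w * per, (w + 1) * per)) for w in range(n_workers)]

def zigzag_assignment(n_blocks, n_workers):
    # Cut the sequence into 2 * n_workers chunks; worker w takes chunk w and its
    # mirror chunk 2 * n_workers - 1 - w, pairing cheap rows with expensive ones.
    chunk = n_blocks // (2 * n_workers)
    out = []
    for w in range(n_workers):
        mirror = 2 * n_workers - 1 - w
        out.append(list(range(w * chunk, (w + 1) * chunk))
                   + list(range(mirror * chunk, (mirror + 1) * chunk)))
    return out

def worker_loads(mask, assignment):
    # Cost proxy (assumption): number of active (query block, key block) pairs per worker.
    return [int(mask[rows].sum()) for rows in assignment]

# --- Communication: flat vs. two-level ring (hypothetical schedule) ---

def two_level_ring_schedule(n_nodes, gpus_per_node):
    # Returns a list of (holding, inter_node) pairs; holding[r] is the KV shard held
    # by rank r at that step (ranks numbered node-major), inter_node marks steps
    # whose transfer crosses a node boundary.
    G = gpus_per_node
    holding = list(range(n_nodes * G))
    steps = [(holding, False)]
    for outer in range(n_nodes):
        for _ in range(G - 1):
            # Intra-node hop: GPU g receives from GPU (g - 1) mod G on the same node.
            holding = [holding[(r // G) * G + (r % G - 1) % G] for r in range(len(holding))]
            steps.append((holding, False))
        if outer < n_nodes - 1:
            # Inter-node hop: GPU g on node n receives from GPU g on node n - 1.
            holding = [holding[((r // G - 1) % n_nodes) * G + r % G] for r in range(len(holding))]
            steps.append((holding, True))
    return steps

def inter_node_steps_flat_ring(n_nodes, gpus_per_node):
    # In a flat node-major ring, every shift has some rank sending across a node
    # boundary (when n_nodes > 1), so all P - 1 shifts involve the slower link.
    P = n_nodes * gpus_per_node
    return P - 1 if n_nodes > 1 else 0

if __name__ == "__main__":
    mask = causal_block_mask(16)
    print(worker_loads(mask, contiguous_assignment(16, 4)))  # [10, 26, 42, 58]
    print(worker_loads(mask, zigzag_assignment(16, 4)))      # [34, 34, 34, 34]
    steps = two_level_ring_schedule(n_nodes=4, gpus_per_node=8)
    print(len(steps), sum(1 for _, inter in steps if inter))  # 32 3
    print(inter_node_steps_flat_ring(4, 8))                    # 31
```

With a dense causal mask, contiguous sharding gives the last worker 5.8x the work of the first (58 vs. 10 block pairs), while the zigzag split equalizes it. `worker_loads` accepts any boolean block mask, so the same check can be run on a dynamic sparse mask, where a fixed zigzag split is no longer guaranteed to balance; that mask-dependent imbalance is the kind of problem balanced sparse ring attention targets, and this sketch does not reproduce the paper's scheme. On the communication side, the two-level schedule crosses a node boundary in only 3 of its 31 shifts versus all 31 for a flat ring, under the assumption that each step is paced by its slowest link.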