In this post, we collaborate with the team working on PyTorch at Meta to show how the torchtitan library accelerates and simplifies the pre-training of Meta Llama 3-like model architectures. We highlight key torchtitan features, such as FSDP2, torch.compile integration, and FP8 support, that improve training efficiency.

Initiate the model training on SageMaker:

    estimator.fit(inputs=data_channels)

Performance numbers

The following table summarizes the throughput of the training runs with different optimizations enabled.

Setup: Meta Llama 3 8B pre-training on 4 x p5.48xlarge instances (32 NVIDIA H100 GPUs)

| Configuration | TOML Configuration | Throughput (Tokens per Second) | Speedup Over Baseline |
|---|---|---|---|
| Baseline | Default configuration | 6475 | - |
| torch.compile | compile = true | 7166 | 10.67% |
| FP8 linear | compile = true; enable_float8_linear = true | 8624 | 33.19% |
| FP8 all-gather | compile = true; enable_float8_linear = true; enable_fsdp_float8_all_gather = true; precompute_float8_dynamic_scale_for_fsdp = true | 8950 | 38.23% |

The results show a steady improvement in Meta Llama 3 8B pre-training throughput as each optimization is added. torch.compile() delivered a 10.67% speedup, and adding FP8 linear operations roughly tripled that gain to 33.19%. Adding FP8 all-gather increased the speedup further, to 38.23% over the baseline. This progression demonstrates how combining optimization strategies significantly enhances training efficiency.

The following figure illustrates the stepwise performance gains for Meta Llama 3 8B pre-training on torchtitan with the optimizations.

These optimizations didn’t affect the model’s training quality. The loss curves for all optimization levels, including the baseline, torch.compile(), FP8 linear, and FP8 all-gather configurations, remained consistent throughout the training process, as shown in the following figure.

The following table shows that the loss values are consistent across the different configurations.
| Configuration | Loss After 2,000 Steps |
|---|---|
| Baseline | 3.602 |
| Plus torch.compile | 3.601 |
| Plus FP8 | 3.612 |
| Plus FP8 all-gather | 3.607 |

Clean up

After you complete your training experiments, clean up your resources to avoid unnecessary charges. Start by deleting any unused SageMaker Studio resources. Next, remove the custom container image from Amazon ECR by deleting the repository you created. If you ran the optional step to use your own dataset, delete the S3 bucket where that data was stored.

Conclusion

In this post, we demonstrated how to efficiently pre-train Meta Llama 3 models using the torchtitan library on SageMaker. With torchtitan’s optimizations, including torch.compile, FP8 linear operations, and FP8 all-gather, we achieved a 38.23% acceleration in Meta Llama 3 8B pre-training without compromising the model’s accuracy.

SageMaker simplified large-scale training by offering integration with custom containers, scaling across multiple instances, built-in support for distributed training, and integration with TensorBoard for real-time monitoring.

Pre-training is a crucial step in developing powerful and adaptable LLMs that can effectively tackle a wide range of tasks and applications. By combining the latest PyTorch distributed training features in torchtitan with the scalability and flexibility of SageMaker, organizations can use their proprietary data and domain expertise to create robust, high-performance AI models. Get started by visiting the GitHub repository for the complete code example and optimize your LLM pre-training workflow.

Special thanks

Special thanks to Gokul Nadathur (Engineering Manager at Meta), Gal Oshri (Principal Product Manager Technical at AWS), and Janosch Woschitz (Sr. ML Solution Architect at AWS) for their support with the launch of this post.
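As a footnote to the performance numbers, the speedup column can be recomputed from the throughput figures reported in this post. The following is a minimal Python sketch of that arithmetic. It is not part of the post's code example, and the names `THROUGHPUT`, `speedup_percent`, and `speedup_table` are illustrative assumptions, not identifiers from torchtitan or SageMaker.

```python
# Throughput in tokens per second, copied from the performance table in this post.
# The dictionary name and keys are illustrative, not part of torchtitan.
THROUGHPUT = {
    "Baseline": 6475,
    "torch.compile": 7166,
    "FP8 linear": 8624,
    "FP8 all-gather": 8950,
}


def speedup_percent(tokens_per_sec: float, baseline: float) -> float:
    """Return the relative throughput gain over the baseline, in percent."""
    return (tokens_per_sec / baseline - 1.0) * 100.0


def speedup_table(throughput: dict, baseline_key: str = "Baseline") -> dict:
    """Compute the speedup of every non-baseline configuration, rounded to two decimals."""
    baseline = throughput[baseline_key]
    return {
        name: round(speedup_percent(value, baseline), 2)
        for name, value in throughput.items()
        if name != baseline_key
    }


if __name__ == "__main__":
    for name, pct in speedup_table(THROUGHPUT).items():
        print(f"{name}: {pct:.2f}% over baseline")
```

With the rounded throughput values from the table, this yields 10.67%, 33.19%, and 38.22%. The last value differs from the reported 38.23% in the second decimal place, presumably because the published figure was computed from unrounded throughput measurements.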
About the Authors

Roy Allela is a Senior AI/ML Specialist Solutions Architect at AWS. He helps AWS customers, from small startups to large enterprises, train and deploy foundation models efficiently on AWS. He is passionate about computational optimization problems and improving the performance of AI workloads.

Kanwaljit Khurmi is a Principal Solutions Architect at Amazon Web Services. He works with AWS customers to provide guidance and technical assistance, helping them improve the value of their solutions when using AWS. Kanwaljit specializes in helping customers with containerized and machine learning applications.

Trevor Harvey is a Principal Specialist in Generative AI at Amazon Web Services (AWS) and an AWS Certified Solutions Architect Professional. He serves as a voting member of the PyTorch Foundation Governing Board, where he contributes to the strategic advancement of open-source deep learning frameworks. At AWS, Trevor works with customers to design and implement machine learning solutions and leads go-to-market strategies for generative AI services.

Less Wright is an AI/Partner Engineer in PyTorch. He works on Triton/CUDA kernels (Accelerating Dequant with SplitK work decomposition); paged, streaming, and quantized optimizers; and PyTorch Distributed (PyTorch FSDP).

Wei Feng is a Software Engineer on the PyTorch distributed team. He has worked on float8 all-gather for FSDP2, TP (Tensor Parallel) in TorchTitan, and 4-bit quantization for distributed QLoRA in TorchTune. He is also a core maintainer of FSDP2.
Published: 2024-10-08T22:10:44