 
In this post, we collaborate with the team working on PyTorch at Meta to showcase how the torchtitan library accelerates and simplifies the pre-training of Meta Llama 3-like model architectures. We highlight key features and capabilities of torchtitan, such as FSDP2, torch.compile integration, and FP8 support, which optimize training efficiency.

Initiate the model training on SageMaker:

    estimator.fit(inputs=data_channels)
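The estimator object is the SageMaker training job definition constructed earlier in the post and not shown in this section. The following is a minimal sketch of what that setup might look like, assuming a custom torchtitan container image already pushed to Amazon ECR; the image URI, IAM role, entry point, and S3 paths are placeholders rather than the exact values used in this post.

```python
from sagemaker.pytorch import PyTorch

# Hypothetical values -- replace with your own ECR image URI, IAM role, and S3 paths.
estimator = PyTorch(
    entry_point="train.py",            # training launch script inside the container
    source_dir="./scripts",
    image_uri="<account-id>.dkr.ecr.<region>.amazonaws.com/torchtitan:latest",  # custom torchtitan image
    role="<sagemaker-execution-role-arn>",
    instance_type="ml.p5.48xlarge",    # 8 NVIDIA H100 GPUs per instance
    instance_count=4,                  # 32 GPUs total, matching the benchmark setup
    distribution={"torch_distributed": {"enabled": True}},  # launch the script with torchrun
)

# Map a named input channel to the tokenized training data in S3.
data_channels = {"training": "s3://<your-bucket>/tokenized-data/"}
estimator.fit(inputs=data_channels)
```

With the torch_distributed option enabled, SageMaker starts one container per instance and launches the entry point with torchrun, so the 32 GPUs run as a single distributed training job.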
Performance numbers

The following table summarizes the throughput for the training runs with different optimizations. All runs pre-train Meta Llama 3 8B on 4 x p5.48xlarge instances (32 NVIDIA H100 GPUs).

| Configuration | TOML Configuration | Throughput (Tokens per Second) | Speedup Over Baseline |
|---|---|---|---|
| Baseline | Default configuration | 6475 | – |
| torch.compile | compile = true | 7166 | 10.67% |
| FP8 linear | compile = true, enable_float8_linear = true | 8624 | 33.19% |
| FP8 all-gather | compile = true, enable_float8_linear = true, enable_fsdp_float8_all_gather = true, precompute_float8_dynamic_scale_for_fsdp = true | 8950 | 38.23% |

The performance results show clear optimization progress in Meta Llama 3 8B pre-training. torch.compile delivered a 10.67% speedup over the baseline, FP8 linear operations roughly tripled that gain to 33.19%, and adding FP8 all-gather further increased the speedup to 38.23%. This progression demonstrates how combining optimization strategies significantly enhances training efficiency.

The following figure illustrates the stepwise performance gains for Meta Llama 3 8B pre-training on torchtitan with these optimizations.

These optimizations didn't affect the model's training quality. The loss curves for all optimization levels, including the baseline, torch.compile, FP8 linear, and FP8 all-gather configurations, remained consistent throughout the training process, as shown in the following figure.

The following table shows the loss values after 2,000 steps for the different configurations.

| Configuration | Loss After 2,000 Steps |
|---|---|
| Baseline | 3.602 |
| Plus torch.compile | 3.601 |
| Plus FP8 linear | 3.612 |
| Plus FP8 all-gather | 3.607 |

Clean up

After you complete your training experiments, clean up your resources to avoid unnecessary charges. Start by deleting any unused SageMaker Studio resources. Next, remove the custom container image from Amazon ECR by deleting the repository you created. If you ran the optional step to use your own dataset, delete the S3 bucket where that data is stored.
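As one possible way to script the ECR and S3 cleanup described above, the following is a minimal sketch using boto3. The repository and bucket names are placeholders, and the sketch assumes your credentials have permission to delete these resources.

```python
import boto3

# Hypothetical resource names -- replace with the repository and bucket you created.
ECR_REPOSITORY = "torchtitan"
DATA_BUCKET = "<your-training-data-bucket>"

# Delete the custom container repository (force=True also removes the images it contains).
ecr = boto3.client("ecr")
ecr.delete_repository(repositoryName=ECR_REPOSITORY, force=True)

# Empty and delete the S3 bucket used for the optional custom dataset.
s3 = boto3.resource("s3")
bucket = s3.Bucket(DATA_BUCKET)
bucket.objects.all().delete()
bucket.delete()
```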
Conclusion

In this post, we demonstrated how to efficiently pre-train Meta Llama 3 models using the torchtitan library on SageMaker. With torchtitan's advanced optimizations, including torch.compile, FP8 linear operations, and FP8 all-gather, we achieved a 38.23% acceleration in Meta Llama 3 8B pre-training without compromising the model's accuracy.

SageMaker simplified large-scale training by offering seamless integration with custom containers, effortless scaling across multiple instances, built-in support for distributed training, and integration with TensorBoard for real-time monitoring.

Pre-training is a crucial step in developing powerful and adaptable LLMs that can effectively tackle a wide range of tasks and applications. By combining the latest PyTorch distributed training features in torchtitan with the scalability and flexibility of SageMaker, organizations can use their proprietary data and domain expertise to create robust and high-performance AI models. Get started by visiting the GitHub repository for the complete code example and optimize your LLM pre-training workflow.

Special thanks

Special thanks to Gokul Nadathur (Engineering Manager at Meta), Gal Oshri (Principal Product Manager Technical at AWS), and Janosch Woschitz (Sr. ML Solutions Architect at AWS) for their support in the launch of this post.

About the Authors

Roy Allela is a Senior AI/ML Specialist Solutions Architect at AWS. He helps AWS customers, from small startups to large enterprises, train and deploy foundation models efficiently on AWS. He is passionate about computational optimization problems and improving the performance of AI workloads.

Kanwaljit Khurmi is a Principal Solutions Architect at Amazon Web Services. He works with AWS customers to provide guidance and technical assistance, helping them improve the value of their solutions when using AWS. Kanwaljit specializes in helping customers with containerized and machine learning applications.

Trevor Harvey is a Principal Specialist in Generative AI at Amazon Web Services (AWS) and an AWS Certified Solutions Architect – Professional. He serves as a voting member of the PyTorch Foundation Governing Board, where he contributes to the strategic advancement of open-source deep learning frameworks. At AWS, Trevor works with customers to design and implement machine learning solutions and leads go-to-market strategies for generative AI services.

Less Wright is an AI/Partner Engineer in PyTorch. He works on Triton/CUDA kernels (accelerating dequant with SplitK work decomposition); paged, streaming, and quantized optimizers; and PyTorch Distributed (PyTorch FSDP).

Wei Feng is a Software Engineer on the PyTorch distributed team. He has worked on float8 all-gather for FSDP2, TP (Tensor Parallel) in TorchTitan, and 4-bit quantization for distributed QLoRA in TorchTune. He is also a core maintainer of FSDP2.