The poster is on display and will be presented at the poster pitch session.
AI-based methods have rapidly revolutionized atmospheric modeling, with recent successes in medium-range forecasting spurring the development of foundation models. However, accurately modeling complex atmospheric dynamics at high spatial resolution requires billions of neural network parameters and gigabyte-sized data samples, making accelerator memory and I/O bandwidth the bottlenecks for model training. To overcome these limitations, we introduce Jigsaw, a distributed training and inference scheme that leverages domain and tensor parallelism to eliminate memory redundancy across model-parallel processes and reduce I/O demands. We apply Jigsaw to WeatherMixer, an MLP-Mixer architecture: a multi-layer-perceptron-based model with a global receptive field that is well suited to learning weather phenomena. Using Jigsaw, we train WeatherMixer models with up to 3.2B parameters, achieving predictive performance competitive with numerical weather prediction and state-of-the-art AI models. To evaluate computational performance, we perform scaling experiments on global 0.25° (~30 km) ERA5 data on two HPC systems. Anticipating that future reanalysis datasets will offer even higher resolutions, we also demonstrate, for the first time, training on 0.125° data.
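To make the combination of domain and tensor parallelism more concrete, here is a minimal NumPy sketch. It is not the Jigsaw implementation. It assumes a single token-mixing layer of an MLP-Mixer, where the N tokens stand in for patches of the global grid, and it simulates P processes with Python lists instead of real communication. Each simulated rank holds only its own block of tokens, the matching column block of the first weight matrix, and the matching row block of the second, so no input or weight is stored twice. A plain sum over the per-rank partial products stands in for an all-reduce. The function names, the toy sizes, and the tanh-approximated GELU are all illustrative assumptions.

```python
import numpy as np

def gelu(x):
    # tanh approximation of GELU (illustrative choice of nonlinearity)
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))

def token_mixing_reference(x, w1, w2):
    """Unsharded token-mixing MLP that mixes information across the N tokens.
    x: (N, C), w1: (H, N), w2: (N, H) -> (N, C)"""
    return w2 @ gelu(w1 @ x)

def token_mixing_sharded(x_shards, w1_shards, w2_shards):
    """Hypothetical domain + tensor parallel token mixing over P simulated ranks.
    Rank p holds only its token block (N/P, C), a column block of w1 (H, N/P)
    and a row block of w2 (N/P, H)."""
    # Each rank computes a partial product from its local tokens only.
    partials = [w1_p @ x_p for w1_p, x_p in zip(w1_shards, x_shards)]
    # Summing the partials stands in for an all-reduce across ranks.
    hidden = gelu(sum(partials))  # (H, C), replicated on every rank here
    # Each rank computes the output only for its own tokens.
    return [w2_p @ hidden for w2_p in w2_shards]

rng = np.random.default_rng(0)
N, C, H, P = 64, 8, 32, 4  # tokens, channels, hidden width, ranks (toy sizes)
x = rng.standard_normal((N, C))
w1 = rng.standard_normal((H, N)) / np.sqrt(N)
w2 = rng.standard_normal((N, H)) / np.sqrt(H)

# Domain split of the input and matching tensor split of the weights.
x_shards = np.split(x, P, axis=0)
w1_shards = np.split(w1, P, axis=1)
w2_shards = np.split(w2, P, axis=0)

out_sharded = np.concatenate(token_mixing_sharded(x_shards, w1_shards, w2_shards), axis=0)
out_ref = token_mixing_reference(x, w1, w2)
print(np.allclose(out_sharded, out_ref))  # True
```

In this sketch the hidden activations are still replicated after the reduction. The abstract does not describe how Jigsaw handles intermediate activations, so that part should not be read as its design.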
The scaling experiments show that high-resolution input data samples benefit from domain parallelism, which improves per-GPU computational throughput by reducing data-loading bottlenecks. In compute- and communication-limited regimes, Jigsaw achieves state-of-the-art performance in distributed model training, reaching 97% of theoretical peak performance on 4 GPUs and a strong-scaling speedup of 6.4 on 8 GPUs. By combining domain, tensor, and data parallelism at larger scales, training on 256 GPUs reaches 11 PFLOP/s with a scaling efficiency of 72%, compared to 51% without Jigsaw.
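The reported numbers can be related through two standard definitions: strong-scaling efficiency E = S / n for a speedup S on n GPUs, and average per-GPU throughput as the aggregate FLOP/s divided by the GPU count. The short Python sketch below assumes that the 6.4x speedup is measured against a single-GPU baseline, which the abstract does not state, and treats 11 PFLOP/s as a sustained aggregate rate. The function names are hypothetical.

```python
def strong_scaling_efficiency(speedup: float, n_gpus: int) -> float:
    """Parallel efficiency E = S / n.
    Assumes the speedup S is relative to a 1-GPU baseline (not stated in the abstract)."""
    return speedup / n_gpus

def per_gpu_throughput(total_flops_per_s: float, n_gpus: int) -> float:
    """Average sustained throughput per GPU, in FLOP/s."""
    return total_flops_per_s / n_gpus

eff_8 = strong_scaling_efficiency(6.4, 8)
tflops_per_gpu = per_gpu_throughput(11e15, 256) / 1e12  # convert FLOP/s to TFLOP/s

print(f"{eff_8:.0%}")  # 80%
print(f"{tflops_per_gpu:.1f} TFLOP/s per GPU")  # 43.0 TFLOP/s per GPU
```

Under these assumptions, the 8-GPU result corresponds to about 80% strong-scaling efficiency, and the 256-GPU run averages roughly 43 TFLOP/s per GPU.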