Google Cloud TPU Multislice Training (cloud.google.com)
Summary from Bard: "This article is about training large language models (LLMs) on Google Cloud TPUs. It discusses the challenges of training LLMs at scale, and how Google Cloud TPU Multislice Training addresses these challenges. The article also details the results of a recent experiment in which Google trained a 128B parameter LLM on 50,944 TPU v5e chips. This experiment is the largest publicly disclosed LLM distributed training job to date."
More details: https://cloud.google.com/tpu/docs/multislice-introduction
(P.S. - contributor on blog post, Google employee, all thoughts my own)
The term "pod" originated in early data center design and occasionally crosses over from HPC to broad use, e.g. NVIDIA calls the set of DGX machines a "pod" https://blogs.nvidia.com/blog/2021/03/05/what-is-a-cluster-p....
Kubernetes chose "pod" to represent a set of co-scheduled containers, like a "pod of whales". Other systems like Mesos and Google's Borg https://storage.googleapis.com/pub-tools-public-publication-... use "task" to refer to a single container but didn't have a concept for heterogeneous co-scheduled tasks at the time.
Somewhat ironically, this now makes TPUs on GKE confusing, because we have TPU hosts organized into "pods", and "pods" for the software using the TPUs.
A Kubernetes pod using a TPU lands on a host which is part of a slice of a TPU pod.
Well, actually, Google Cloud is just an abstraction on top of internal Google infra, so this isn't the right question. So, it depends on what you want to infer/compare.
I see a per-device batch size of 6 for the 16B model. With 256x199 = 50944 TPUs and a sequence length of 2048, this works out to 104M tokens per batch. This is much larger than typical for training runs of dense LMs of this size, which are usually closer to ~4M tokens per batch.
Was your critical batch size really this large? In other words, did you really see a benefit as compared to a much smaller batch size (and probably many fewer TPUs)? Did you use some special learning rate schedule or optimizer to achieve this?
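The tokens-per-batch arithmetic above can be checked with a short sketch. It assumes the per-device batch size counts sequences per chip; the function and variable names are illustrative and not taken from the blog post or MaxText. Worth noting: 50,944 x 2,048 on its own gives about 104M tokens, which corresponds to one sequence per chip, while 6 sequences per chip gives about 626M.

```python
def tokens_per_batch(num_chips: int, per_device_batch: int, seq_len: int) -> int:
    """Global tokens per optimizer step, assuming pure data parallelism
    where every chip processes `per_device_batch` sequences of `seq_len` tokens."""
    return num_chips * per_device_batch * seq_len


# Figures quoted in the comment above: 199 slices of 256 chips each.
NUM_CHIPS = 256 * 199
SEQ_LEN = 2048

one_seq_per_chip = tokens_per_batch(NUM_CHIPS, 1, SEQ_LEN)
six_seq_per_chip = tokens_per_batch(NUM_CHIPS, 6, SEQ_LEN)

print(f"chips:                   {NUM_CHIPS:,}")
print(f"tokens/batch (1 seq):    {one_seq_per_chip:,}")
print(f"tokens/batch (6 seqs):   {six_seq_per_chip:,}")
```

Either way, the figure is far above the ~4M tokens per batch cited as typical, so the question about critical batch size stands.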
It turns out that what matters is step count. Increasing batch size makes the model train faster, but increasing seq length from 1024 to 2048 doesn’t make it train twice as fast. So saying 104M tokens rather than batch size 50k x 6 is misleading to yourself. (One of the most surprising aspects of learning ML was how easy it is for me to trick myself in various silly ways like that.)
The mental model to make this easy to remember: progress happens in discrete quantities called steps. Training on 104M tokens per step means it’s embedding 104M tokens of knowledge every step. This isn’t the same thing as requiring a special-case optimizer due to large batch sizes: sequence length adds "knowledge bandwidth", but otherwise doesn’t mess with the training dynamics.
As far as large-batch optimizers go, there’s LARS, which Google used for their MLPerf results. I imagine they stuck with that. It creates a per-layer confidence metric, so that when the massive batch size makes a massive change, it dampens the change to smooth out the effect across the network. And since it’s a multiply, the shape of the gradient (by "shape" I mean in 3D space, where the Z axis is the intensity of the gradient) remains the same, so it doesn’t harm any knowledge transfer. It’s purely a stabilization aid.
Kind of weird I remember that after three years.
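The per-layer scaling described above can be sketched roughly as follows. This is a minimal, momentum-free illustration of a LARS-style trust ratio for a single layer, written in plain Python; the function names and default coefficients are assumptions for illustration, not the MLPerf or MaxText implementation.

```python
import math


def l2_norm(values):
    """Euclidean norm of a flat list of floats."""
    return math.sqrt(sum(v * v for v in values))


def lars_step(weights, grads, base_lr=1.0, trust_coeff=0.001,
              weight_decay=0.0, eps=1e-9):
    """One LARS-style update for one layer (no momentum).

    The trust ratio is a single scalar per layer, so the update keeps the
    direction of the gradient and only rescales its magnitude.
    """
    w_norm = l2_norm(weights)
    g_norm = l2_norm(grads)
    if w_norm == 0.0 or g_norm == 0.0:
        trust_ratio = 1.0  # fall back to plain SGD scaling
    else:
        trust_ratio = trust_coeff * w_norm / (g_norm + weight_decay * w_norm + eps)
    step = [base_lr * trust_ratio * (g + weight_decay * w)
            for w, g in zip(weights, grads)]
    new_weights = [w - s for w, s in zip(weights, step)]
    return new_weights, step, trust_ratio


# A gradient 100x larger produces (almost exactly) the same step size,
# because the trust ratio shrinks by the same factor.
w = [3.0, 4.0]
_, small_step, _ = lars_step(w, [0.6, 0.8])
_, large_step, _ = lars_step(w, [60.0, 80.0])
print(small_step, large_step)
```

Because the scaling is a scalar multiply per layer, the relative proportions of the gradient components are unchanged, which matches the "shape stays the same" point above.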
The article stated that the startup time came from the throughput of scheduling the pods on the clusters (from unrelated benchmarks, that's usually ~300 pods/sec for the kube scheduler today) and then doing XLA compilation at pod launch, rather than amortizing it once for all jobs.
Optimizing throughput of kube scheduler is a good general opportunity and something I believe we would like to see.
I believe AOT compilation was just not a critical optimization for the test. For large, long-running training jobs we would recommend AOT compiling to keep pod start latency low after hardware failures and job restarts (from checkpoints).
> The start times we observed were impressive, but we believe we can improve these even further. We are working on areas such as optimizing scheduling in GKE to increase throughput and enabling ahead-of-time compilation in MaxText to avoid just-in-time compilations on the full cluster.
Actually now that I am searching for that it seems Baidu has a number of papers on workload orchestration at scale specifically for learning.
Agreed with you and we definitely weren't trying to brag! This is fast compared to people's expectations in the space but slow compared to what we should be able to accomplish and will accomplish in the future.
They just showed they could indeed make 50k TPUs do some flops.
With no paper this is just a marketing press release - the only takeaway is that existing tech stacks can probably utilize it.
You're asking the right question but I think the math is off by a bit. The equivalent number on the H100's is 989 TFLOP/s/chip so the equivalent job is ~10K H100's = (10 * 10^18) / (989 * 10^12). (Both chips also have 8-bit acceleration!)
I believe this is the largest ML job both by exaflops and number of chips ever demonstrated. Other companies own more chips or exaflops than we show in this job but getting all the hardware working at once on a single job is a different matter! :-)
989 is TF32 core, for 16 bit it is 1979, so I guess around 5000 H100’s in a single training job would be equivalent to the training job mentioned in this article.
Either way I actually would not be surprised if OpenAI has launched a single job on more than 10k GPU’s, but I also am not very knowledgeable on practical scaling. Congrats on the feat!
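For reference, the chip-equivalence arithmetic in the two comments above can be reproduced with a short sketch. It uses only the figures quoted in the thread (a ~10 exaFLOP/s job, and 989 or 1979 TFLOP/s per H100); which per-chip figure is the fair comparison is exactly what is being debated, and the sketch does not settle that. The function name is illustrative.

```python
def equivalent_chips(job_flops_per_s: float, chip_tflops: float) -> float:
    """Number of chips needed to match a job's aggregate peak throughput,
    given a per-chip peak in TFLOP/s. Peak-to-peak only; ignores utilization."""
    return job_flops_per_s / (chip_tflops * 1e12)


JOB_FLOPS_PER_S = 10 * 10**18  # ~10 exaFLOP/s, as quoted above

for label, tflops in [("989 TFLOP/s per chip", 989), ("1979 TFLOP/s per chip", 1979)]:
    chips = equivalent_chips(JOB_FLOPS_PER_S, tflops)
    print(f"{label}: ~{chips:,.0f} H100s")
```

This is a peak-FLOPs comparison only; as noted elsewhere in the thread, real throughput also depends on the compiler and software stack.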
I'm not certain but I think part of this is that XLA (for example) is a mountain of chip-specific optimizations between your code and the actual operations. So comparing your throughput between GPU and TPU is not just flops-to-flops.
The PaLM paper linked in the blog post is about how to get something actually useful out of that compute.
https://inflection.ai/inflection-ai-announces-1-3-billion-of...
As of 2 months ago, they had at least 7000 up and running, fwiw
On what number or op for the h100?
I didn't say otherwise. Of course Google Cloud runs on internal Google infrastructure. They wouldn't have an entirely different stack to build Google Cloud on. The problem is that Googlers don't use Google Cloud.
Amazonians use AWS. https://courses.cs.washington.edu/courses/cse452/23wi/papers...
Microsofties use Azure. https://www.zdnet.com/article/microsoft-moves-closer-to-runn...
So up until late 2018 when I left, very little of Cloud ran on "proper" Google3 infra. This may have shifted slightly (cloud has been fishing for good Google infra to externalize a lot), but in general cloud!=google3 infra.
At the same time, as you note, we have functional models for container execution that are approaching millions of dispatches for highly partitionable work, especially in data engineering and ETL.
https://youtu.be/z3hmfSVmyqg?feature=shared&t=3328
I'm curious why it is so hard for them to deploy the compute. They seem to be fairly behind schedule.
We work with customers interested in training models who run their own ablations, including batch size and learning rates.
Based on that, we demonstrate workloads that we think will be interesting to potential customers! Absolutely agreed that this workload has a larger batch size than the public literature suggests.
I agree that raw throughput is the metric to aim for, since figuring out how to put it to use is an exercise for the user. But it’s probably best to be straightforward about that. The reason MLPerf measures "time until loss reaches X for a ResNet classifier on ImageNet" is precisely because it gives information about performance at scale: if you didn’t train anything, you haven’t actually achieved "fastest training run". You’ve achieved largest throughput, which is similar but not the same.
And I don’t think this is a pedantic distinction. Just throw LARS on it (the MLPerf code you used in 2019 is at https://github.com/shawwn/daxx-lightning fwiw, and it runs on pods last time I tried) and see how it performs in practice.
EDIT: reading over https://github.com/google/maxtext, it looks pretty delightful. I was in the TPU scene back in 2020, and there was no way to do ahead of time compilation. Restarting training runs was a major pain point once LLMs became the focus, and I kept pestering James Bradbury to please add it. Happy to see that it finally made its way in.
It sounds like MaxText is the right approach, but until you try to actually train a model (to achieve a low loss on a specific dataset) you can’t know whether the code works. This isn’t theoretical. I spent over a year debugging Google’s public BigGAN code (compare_gan) and discovered why it never worked: the batch norm gamma parameter was initialized to zero instead of one, so everything was being multiplied by zero to start off, which severely crippled the model.
A bug like that could easily be lurking in MaxText. You can’t know until you try to train a useful LLM. Note that compare_gan seemed to work; the authors noted that they couldn’t replicate the performance of the official BigGAN paper, but the samples looked sort of reasonable. But the model was screwy, and no one knew why until the rigorous debugging process.
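To illustrate why a zero-initialized gamma is so damaging, here is a minimal plain-Python sketch of batch normalization over a single feature. It is not the compare_gan code, just the standard formula; the function name and sample inputs are made up for illustration.

```python
import math


def batch_norm(xs, gamma, beta=0.0, eps=1e-5):
    """Standard batch norm over one feature: normalize, then scale and shift."""
    mean = sum(xs) / len(xs)
    var = sum((x - mean) ** 2 for x in xs) / len(xs)
    return [gamma * (x - mean) / math.sqrt(var + eps) + beta for x in xs]


activations = [0.5, -1.2, 3.3, 0.1]

# gamma = 1 (the usual init): the normalized signal passes through.
print(batch_norm(activations, gamma=1.0))

# gamma = 0: every output collapses to beta, whatever the input was.
print(batch_norm(activations, gamma=0.0))
```

With gamma at zero the layer's output carries no information about its input at the start of training, which fits the "multiplied by zero to start off" description above.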
If you need help with this, let me know. There are challenges when training an actual LLM that aren’t present in theoretical runs like these. For example, you need a big dataset. The Pile is a good starting point for that, and it gives a nice comparison baseline, e.g. to GPT-J.
Alternatively, post a link to a tensorboard.dev showing the loss curves for your training runs. I suspect the reason you didn’t is because you didn’t have a real dataset. That’s ok, but it doesn’t prove that MaxText works until there’s empirical evidence.
In other words, DavidSJ was precisely right: it’s an impressive-looking benchmark, which doesn’t actually help your customers train LLMs in practice. They’ll need to solve this problem eventually, and the optimizer is certainly one aspect. The other is the quantized INT8 training. It may sound impressive to say it gives a 1.4x step count speedup, but that’s useless if it harms loss convergence. How do you know it doesn’t? This isn’t an easy question to answer unless you run MLPerf or some other known stable baseline, which I’m a little shocked no one has done yet.
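The concern about quantized training can be made concrete with a small sketch. It assumes wall-clock time to a target loss is simply steps needed times seconds per step; the 1.4x figure is the one quoted above, while the "extra steps" factors are hypothetical values chosen for illustration, not measurements.

```python
def net_speedup(step_speedup: float, extra_steps_factor: float) -> float:
    """Net time-to-target-loss speedup.

    step_speedup:       how much faster each step runs (e.g. 1.4 for INT8).
    extra_steps_factor: how many times more steps are needed to reach the
                        same loss (1.0 means convergence is unaffected).
    """
    return step_speedup / extra_steps_factor


# Hypothetical convergence penalties; only a real training run can supply these.
for extra in (1.0, 1.2, 1.4, 2.0):
    print(f"extra steps x{extra}: net speedup {net_speedup(1.4, extra):.2f}x")
```

A per-step speedup only translates into faster training if the steps-to-target-loss factor stays close to 1, which is what a convergence benchmark would establish.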
Totally agreed with all your comments! MaxText mainline is right now a reference implementation for users who have their own scientific opinions on model architecture and convergence. We're additionally hoping to let MaxText run compatibly with some open-source models for customers who want to use known good configurations.