felarof
9 months ago
Hey HN, we recently fine-tuned the Llama 3.1 405B model on 8xAMD MI300x GPUs using JAX instead of PyTorch. JAX's advanced sharding APIs allowed us to achieve great performance. Check out our blog post to learn about the cool sharding tricks we used. We've also open-sourced the code: https://github.com/felafax/felafax
We're a small startup building AI infra for fine-tuning and serving LLMs on non-NVIDIA hardware (TPUs, AMD, Trainium).
Problem: Many companies are trying to get PyTorch working on AMD GPUs, but we believe this is a treacherous path. PyTorch is deeply intertwined with the NVIDIA ecosystem in a lot of ways (e.g., `torch.cuda` or scaled_dot_product_attention is an NVIDIA CUDA kernel exposed as a PyTorch function). So, to get PyTorch code running on non-NVIDIA hardware, there's a lot of "de-NVIDIAfying" that needs to be done.
Solution: We believe JAX is a better fit for non-NVIDIA hardware. In JAX, ML model code compiles to hardware-independent HLO graphs, which are then optimized by the XLA compiler before hardware-specific optimization. This clean separation allowed us to run the same LLaMA3 JAX code both on Google TPUs and AMD GPUs with no changes.
Our strategy as a company is to invest upfront in porting models to JAX, then leverage its framework and XLA kernels to extract maximum performance from non-NVIDIA backends. This is why we first ported Llama 3.1 from PyTorch to JAX, and now the same JAX model works great on TPUs and runs perfectly on AMD GPUs.
We'd love to hear your thoughts on our vision and repo!
ipsum2
9 months ago
I, and several others, have had no problem running PyTorch on AMD GPUs, with no code changes from CUDA. Check out MosaicML's blog posts: https://www.databricks.com/blog/training-llms-scale-amd-mi25...
mistymountains
9 months ago
Again, the problem is custom kernels in CUDA. It’s not straightforward for many applications (LLMs are probably the most straightforward).
felarof
9 months ago
Ahh, interesting, will take a look!
Curious what are the steps to run PyTorch on AMD (does it work out-of-box with PyTorch+rocm docker image)? Does torch.compile work smoothly?
lhl
9 months ago
While your project is neat and I'd like to see how the performance compares, for LLM training, PyTorch, including torch.compile, works completely OOTB on AMD.
All you have to do is pip install the ROCm version of PyTorch (or run the docker image) and it's seamless (the ROCm version just treats torch.cuda as calling ROCm).
I've used axolotl (trl/accelerate based), torchtune, and LLaMA-Factory, which are all PyTorch-based without any issues for training.
anthonix1
9 months ago
Yeah I would suggest taking a look at PyTorch on AMD before saying stuff like "scaled_dot_product_attention is an NVIDIA CUDA kernel exposed as a PyTorch function", because that is demonstrably false.
Also, FWIW, I would suggest getting a small Llama 3.1 model training fast before trying to do a big 405B model -- faster to iterate and almost everything you'll learn on the small models will scale to the 405B.
felarof
9 months ago
Thanks for the feedback! I appreciate you pointing that out. My understanding was based on the PyTorch documentation for scaled_dot_product_attention (https://pytorch.org/docs/stable/generated/torch.nn.functiona...). - "The function may call optimized kernels for improved performance when using the CUDA backend. For all other backends, the PyTorch implementation will be used."
And was trying to make a broader point about the lack of transparency (in performance, lower-level impl) in PyTorch when running on NVIDIA vs. non-NVIDIA hardware.
chillee
9 months ago
> And was trying to make a broader point about the lack of transparency (in performance, lower-level impl) in PyTorch when running on NVIDIA vs. non-NVIDIA hardware.
I don't quite understand this argument. Lack of transparency from running PyTorch so instead we're gonna leave it all to XLA? How does this solve the "transparency" issue?
orf
9 months ago
Having a common library function that is either lightning fast or dog slow depending on the hardware is not a great position to be in.
Moreover, this will get worse as more CUDA specific features are added to PyTorch with ad-hoc fallback functions.
I guess OP is saying that XLA is more transparent in this regard, because it wouldn’t use functions like these and the generated comparable code would be on par performance-wise?
chillee
9 months ago
> it wouldn’t use functions like these and the generated comparable code would be on par performance-wise
Perhaps if XLA generated all functions from scratch, this would be more compelling. But XLA relies very heavily on pattern-matching to common library functions (e.g. cuDNN), and these patterns will certainly work better on Nvidia GPUs than AMD GPUs.
In this way, I think explicitly calling the common library functions is actually much more transparent.
anthonix1
9 months ago
[flagged]
unlikelymordant
9 months ago
are you at all confident that this isn't hallucinated? I'd never trust an answer like this from an LLM
WithinReason
9 months ago
Did you verify everything else it said is true?
germanjoey
9 months ago
How are you verifying accuracy for your JAX port of Llama 3.1?
IMHO, the main reason to use PyTorch is actually that the original model used PyTorch. What can seem to be identical logic between different model versions may actually cause model drift when infinitesimal floating point errors accumulate due to the huge scale of the data. My experience is that debugging accuracy mismatches like this in a big model is a torturous ordeal beyond the 10th circle of hell.
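A minimal sketch of the effect described above, in plain Python rather than PyTorch or JAX: floating-point addition is not associative, so two mathematically identical reductions can disagree once the order of operations changes. The values are made up for illustration and are not taken from any real model.

```python
import math

# Floating-point addition is not associative: regrouping changes the result.
left = (0.1 + 0.2) + 0.3
right = 0.1 + (0.2 + 0.3)
print(left == right)  # False

# The same effect with a larger dynamic range: identical values, different order.
values = [1e16, 1.0, -1e16, 1.0]
forward = sum(values)            # naive left-to-right accumulation
reordered = sum(sorted(values))  # same numbers, summed in sorted order
exact = math.fsum(values)        # correctly rounded reference sum
print(forward, reordered, exact)
```

A ported model that fuses or reorders reductions differently from the original can therefore produce slightly different activations even when the math on paper is identical.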
felarof
9 months ago
Good question. We used a new AI+math-based testing tool (benchify.com) to run comparison tests, but we are working on building more robust infrastructure for this. Translating models from PyTorch to JAX is core to our strategy.
That said, this path is not uncommon (translating from one framework to another). HuggingFace translates Google's Gemma family models from JAX to PyTorch, and a ton of people use it.
credit_guy
9 months ago
When you say "model versions", do you mean different quantizations of the model? Then it's not floating point errors that accumulate. Different quantizations of the model are different models. People will call such a model something like Meta-Llama-3.1-8B-Instruct--q4_0, claiming that it's just a "version" of the Meta-Llama-3.1-8B-Instruct. But it's just a lie. It's not the same model, and you should not expect the same results. There is no reason to debug the differences, what exactly would you expect to find, and what action would you envision to take once you find what you are looking for? However, is the quantized version still a useful LLM? Absolutely. Most people don't have an A100 to run the original model, so a quantized version is better than nothing.
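To illustrate why a quantized checkpoint is a different model rather than a noisy copy, here is a rough sketch of a symmetric 4-bit round trip in plain Python. This is a simplified scheme with a single shared scale, not the actual q4_0 format, and the weights and function names are made up.

```python
def quantize_4bit(weights):
    """Symmetric 4-bit quantization with one shared scale (illustrative only)."""
    scale = max(abs(w) for w in weights) / 7  # signed range used here: -7..7
    codes = [max(-7, min(7, round(w / scale))) for w in weights]
    return codes, scale


def dequantize(codes, scale):
    """Map integer codes back to approximate float weights."""
    return [c * scale for c in codes]


weights = [0.013, -0.402, 0.250, 0.091, -0.177]  # made-up values
codes, scale = quantize_4bit(weights)
restored = dequantize(codes, scale)
errors = [abs(w - r) for w, r in zip(weights, restored)]
print(codes)
print(max(errors))
```

The restored weights differ from the originals by up to half a quantization step, which is a deliberate change to the parameters, not accumulated rounding noise.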
srcreigh
9 months ago
Very fascinating, can you explain more about a time when this happened?
Like what area was affected by fp errors, why were they introduced (was it like refactoring of pytorch code?), how was this determined to be the cause?
anthonix1
9 months ago
Does JAX have its own implementations of matmul, flash attention etc? Or does it use the ROCm implementations like PyTorch does? (e.g., hipBLASLt, Composable Kernel FA etc)
Not too familiar with JAX, but the abysmal PyTorch training perf on MI300x is in large part attributable to the slow perf of the ROCm libraries it is using under the hood.
jdeaton
9 months ago
JAX has a sub-system called Pallas[1] with a Triton-like programming model and an example implementation of Flash Attention [2]. It is quite fast. On TPUs I've heard that the XLA compiler already emits a flash-attention-like computation graph for a regular JAX implementation of attention so there's no need to have some specialized kernel in that case.
1. https://jax.readthedocs.io/en/latest/pallas/index.html
2. https://github.com/jax-ml/jax/blob/main/jax/experimental/pal...
llm_trw
9 months ago
Does this work on consumer-grade cards like the 7900 XTX?
And by work I don't mean: spend two weeks trying to get the drivers set up and never update the server again.
lhl
9 months ago
A couple months ago I did some testing on some consumer cards. [1] I think you should be able to use torchtune or axolotl without anything besides installing the ROCm version of PyTorch.
[1] https://wandb.ai/augmxnt/train-bench/reports/Trainer-perform...
tommiegannert
9 months ago
Am I reading that right that the 7900 XTX is on a par with 3090, and 4090 is twice as fast?
lhl
9 months ago
Yeah, those numbers are correct as of their testing (in June) although people who are really interested should check out the linked repo and do their own runs as software/optimizations have continued to change a lot and the RDNA3 side has a lot of untapped potential. E.g., the 7900 XTX has a huge theoretical FLOPS advantage over the 3090 but the results totally don't reflect that. One example of this hobbling is that RDNA3 only recently got backward-pass FA via a still under-optimized aotriton implementation: https://github.com/ROCm/aotriton/pull/39
There are also still ongoing optimizations on the Nvidia side as well. In the beginning of the year the 7900 XTX and 3090 were pretty close on llama.cpp inference performance, but a few months ago llama.cpp got CUDA graph and FA support implemented that boosted perf significantly for both my 3090 and 4090.
(For AI/ML, a used 3090 remains I think the best bang/buck for both inference and small training runs. You can pay twice as much for the twice as fast 4090, but at the end of the day you'll still wish you had more VRAM, so it's hard to really recommend unless you're going to use mixed precision. The RDNA3 cards are not as bad to work with as the Internet would have you believe, but they'd have to be a lot cheaper if your main use case was AI/ML for both the PITA factor and just from pure real-world performance.)
itsTyrion
9 months ago
Damn, ML/AI performance is that different? In games, 4090 -> 7900 XTX is more like -20%
woodrowbarlow
9 months ago
i've been running inference on the 7900xtx using pytorch and rocm (installed directly from package managers, no manual fiddling) with great performance. no problem running the full flux1.dev model, for instance. haven't looked at training or fine-tuning yet.
ngcc_hk
9 months ago
Given it is a migration, is there an actual comparison of the same model on PyTorch vs your version? The comparison table there seems to be on the technical side.
Also any technical issues encountered?
felarof
9 months ago
We have a few technical issues that we still need to address:
1) This entire fine-tuning run was done in JAX eager mode. I kept running out of memory (OOM) when trying to `jax.jit` the entire training step. Even gradual `jax.jit` didn't work.
2) The current version doesn't have gradient accumulation, and with a batch size of just 16, that’s not ideal. I'm working on implementing gradient accumulation next.
3) We still haven't found a good way to load large sequence-length data (like 32k sequence length). Currently, before sharding the training batch across GPUs, it ends up loading the entire batch onto a single GPU’s VRAM and causes OOM issues.
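For readers unfamiliar with the gradient accumulation mentioned in point 2, a minimal sketch of the idea in plain Python follows. It is not the repo's JAX code; the toy model, data, and function name are made up. With equal-sized micro-batches, averaging the per-micro-batch gradients gives the same gradient as one large batch, while only one micro-batch needs to be in memory at a time.

```python
def grad_mse(w, batch):
    """Gradient of mean squared error for the 1-D model y = w * x."""
    return sum(2 * (w * x - y) * x for x, y in batch) / len(batch)


data = [(1.0, 2.0), (2.0, 4.1), (3.0, 5.9), (4.0, 8.2)]  # made-up (x, y) pairs
w = 0.5

# Gradient computed on the whole batch at once.
full_grad = grad_mse(w, data)

# Same gradient from equal-sized micro-batches, accumulated and then averaged.
micro_batches = [data[0:2], data[2:4]]
accumulated = 0.0
for micro_batch in micro_batches:
    accumulated += grad_mse(w, micro_batch)
accumulated /= len(micro_batches)

print(full_grad, accumulated)
```

The optimizer step is applied once, after the loop, using the accumulated gradient.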
logicchains
9 months ago
> I kept running out of memory (OOM) when trying to `jax.jit` the entire training step. Even gradual `jax.jit` didn't work.
Were you using activation checkpointing? https://jax.readthedocs.io/en/latest/_autosummary/jax.checkp... is very important for keeping memory usage reasonable when training large models.
cameron_b
9 months ago
I'm glad to see a full implementation on AMD hardware.
I'm not familiar with JAX, but the idea of providing an abstraction layer to more easily get to work on what hardware is available seems really valuable. Bringing back some competitiveness to the ecosystem will be a big win for workload mobility.
I suspect that price/performance across implementations will be highly dependent on contract details, but do you intend to publish some comparisons in the future?
anthonix1
9 months ago
Any direct comparisons to 8xH100? 2 toks/sec seems very slow!
I haven't done any LoRA training on MI300x myself, but I have done LLama 3.1 full training on 8xMI300x and got pretty close to 8xH100 performance with my own kernels (ROCm is just too slow).
felarof
9 months ago
Oops, my calculation was wrong. Let me add an edit to the blog, thanks for pointing it out!
My train step was taking 30s.
And I was using a batch size of 16 and seq length of 64, which makes the training speed (16*64/30) ≈ 34 tokens per second (for fine-tuning in JAX eager mode).
(I haven't done comparison with 8XH100)
gdiamos
9 months ago
That’s approximately 0.8% MFU - an H100 would get more like 30% or 40% MFU if well tuned
405e9 parameters
2 flops per matrix multiply per parameter
3 matrix multiplies for (forward, backward param, and backward activation) passes
batch size 16
seq length 64
1.3 petaflops per second per GPU in bfloat16
8 GPUs
30 seconds
So that’s 0.8% = (405e9 * 2 * 3 * 16 * 64 / 30) / (1.3e15 * 8)
Note that I’m ignoring the attention flops in this simplified calculation, but they would be a second order effect at this sequence length
Also note that I’m assuming full weight training, not LoRA. The result would be lower MFU if using LoRA
These MI300X results are promising functionally (it's tough to get any model this big running) but they have a long way to go on perf. It's also single node. The biggest issues I've seen on MI300X are related to scaling to multiple nodes.
EDIT: The blog seems to indicate it is using LoRA. So we should remove the backward param pass from the equation above. Backward param only applies to adapter weights, which are much more than 10x smaller, so we set it to 0 in the approximation. So we get
0.53% = (405e9 * 2 * 2 * 16 * 64 / 30) / (1.3e15 * 8)
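The two estimates above can be reproduced with a short Python sketch. The function and argument names are made up for illustration; the inputs are the ones listed in this comment (405e9 parameters, 2 FLOPs per parameter per matmul pass, batch size 16, sequence length 64, 30 seconds per step, 1.3e15 peak bf16 FLOPs per GPU, 8 GPUs), and attention FLOPs are ignored as in the original calculation.

```python
def mfu(params, matmul_passes, batch_size, seq_len, step_seconds,
        peak_flops_per_gpu, num_gpus):
    """Model FLOPs utilization as a fraction, ignoring attention FLOPs."""
    tokens = batch_size * seq_len
    flops_per_step = params * 2 * matmul_passes * tokens
    achieved_flops_per_sec = flops_per_step / step_seconds
    return achieved_flops_per_sec / (peak_flops_per_gpu * num_gpus)


common = dict(params=405e9, batch_size=16, seq_len=64, step_seconds=30,
              peak_flops_per_gpu=1.3e15, num_gpus=8)

full = mfu(matmul_passes=3, **common)  # forward + backward param + backward activation
lora = mfu(matmul_passes=2, **common)  # backward param pass approximated as 0
tokens_per_sec = 16 * 64 / 30

print(f"full: {full:.2%}  lora: {lora:.2%}  throughput: {tokens_per_sec:.1f} tok/s")
```

Multiplying the LoRA figure by the 1.3e15 peak gives roughly 7 TFLOPs per GPU of achieved throughput, versus roughly 10 TFLOPs per GPU under the full-training assumption.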
darrick_horton
9 months ago
We've significantly optimized multinode on AMD MI300X for both stability and performance at TensorWave. There were certainly a lot of challenges, but we've become experts at multinode on AMD. We'd be happy to show you the latest results! They are quite compelling
latchkey
9 months ago
Promises... promises... https://x.com/DarrickHorton/status/1828902979875017003
Maybe challenging for Darrick, but in reality, you just need to follow the directions and configure things correctly...
https://techdocs.broadcom.com/us/en/storage-and-ethernet-con...
heyoni
9 months ago
If I feed your reply to an llm will it be able to tell me what you’re talking about?
/edit I think it did alright…
gdiamos
9 months ago
I tried with Llama 3.1 and it hallucinated MFU. It also completely misinterpreted multi-node scaling.
==========================================================
Let's break down the results described in the post.
Context: The post is discussing the performance of a large language model on a MI300X GPU, which is a high-performance computing (HPC) system. The model has approximately 405 billion parameters and is trained using a batch size of 16 and sequence length of 64.
Key metrics:
MFU (Million Floating-Point Operations per Second): This is a measure of the model's performance, specifically the number of floating-point operations (FLOPs) it can perform per second.
FLOPs: The number of floating-point operations required to perform a matrix multiplication, which is a fundamental operation in deep learning.
GPU performance: The MI300X GPU is capable of 1.3 petaflops (1.3 x 10^15 FLOPs) per second in bfloat16 (a 16-bit floating-point format).
Calculations:
The author provides two calculations to estimate the MFU of the model:
Initial calculation: Assuming full weight training (not LoRA), the author estimates the MFU as:
405 billion parameters
2 FLOPs per matrix multiply per parameter
3 matrix multiplies (forward, backward parameter, and backward activation)
Batch size 16
Sequence length 64
30 seconds to complete the calculation
1.3 petaflops per second per GPU
8 GPUs
The calculation yields an MFU of approximately 0.8%.
Revised calculation: After correcting the assumption to use LoRA (a technique that reduces the number of FLOPs), the author revises the calculation by removing the backward parameter pass, which is only applied to adaptor weights (much smaller than the main weights). This yields an MFU of approximately 0.53%.
Interpretation:
The results indicate that the MI300X GPU is not yet optimized for this large language model, with an MFU of only 0.53% (or 0.8% in the initial calculation). This is a relatively low performance compared to the theoretical maximum of 1.3 petaflops per second per GPU. The author notes that the biggest issues are related to scaling to multiple nodes, suggesting that the performance may improve when running on a larger cluster.
The revised calculation using LoRA reduces the MFU by about 33%, indicating that using this technique can lead to a more efficient use of the GPU resources.
==========================================================
MFU means model flops utilization. It is a measure of efficiency from 0% to 100%. 100% means that the model is running at maximum possible efficiency, i.e. 1.3 petaflops per GPU.
==========================================================
In that case, the results indicate that the MI300X GPU is running the large language model at a relatively low efficiency, with an MFU of 0.53% (or 0.8% in the initial calculation).
This means that the model is only utilizing a tiny fraction of the GPU's maximum theoretical performance of 1.3 petaflops per second. In other words, the model is not fully utilizing the GPU's capabilities, and there is a significant amount of headroom for optimization.
To put this into perspective, an MFU of 100% would mean that the model is running at the maximum possible efficiency, using 1.3 petaflops per second per GPU. An MFU of 0.53% or 0.8% is extremely low, indicating that the model is running at a tiny fraction of its potential performance.
The author's comment that the MI300X results are "promising functionally" suggests that the model is able to run, but the low MFU indicates that there are significant opportunities for optimization and performance improvement.
heyoni
9 months ago
MFU being maximal floating point utilization right? I used chatgpt 4o and it otherwise looks very similar to your llama result
YetAnotherNick
9 months ago
> The blog seems to indicate it is using LoRA. So we should remove the backward param pass from the equation above. Backward param only applies to adaptor weights
Backward pass still runs on the non-adapter weights. But yeah, 10 TFLOPS/GPU, especially on a tiny sequence size, is very bad compared to what you can get on Nvidia. And I believe the difference would be even higher with a large sequence length.
gdiamos
9 months ago
The backward activations pass does, but typically not the backward weight gradients.
Why compute gradients with respect to weights that aren't going to be updated?
jgalt212
9 months ago
Is there some cost rule of thumb to compare Nvidia, AMD, and Google TPU?
felarof
9 months ago
Good question. There's no good metric, given that performance depends on the software stack (JAX vs PyTorch) + optimizations.
But my take: performance per dollar is TPU > AMD > NVIDIA.
CuriouslyC
9 months ago
TPUs are slow but efficient and AMD has bugs but for some things works quite well. Nvidia is obviously the gold standard.
felarof
9 months ago
Haha, TPUs are not slow :) All of Google's training (including Gemini models) is done on TPUs.
There are good 1p [a] and 3p [b] benchmarks comparing TPUs vs NVIDIA GPUs.
[a] - https://github.com/GoogleCloudPlatform/vertex-ai-samples/blo...
logicchains
9 months ago
Did you consider using https://github.com/AI-Hypercomputer/maxtext ? It has a Jax llama implementation, and gets decent MFU on TPU and GPU (I've only tried it on NVidia GPU, not AMD).
lostmsu
9 months ago
Could you share performance so we could compare?
ewalk153
9 months ago
Do you see tinygrad as a useful lower level abstraction or is JAX sufficient to get perf out of AMD GPUs?
felarof
9 months ago
Tinygrad is great, but still in its early stages, I believe.
JAX has matured a lot over the last 6 years and XLA has been around for a lot longer. We believe we can extract good perf from AMD with JAX + XLA kernels.
upbeat_general
9 months ago
scaled_dot_product_attention isn’t CUDA specific, it even works on TPUs.