felarof
7 hours ago
Hey HN, we recently fine-tuned the llama3.1 405B model on 8xAMD MI300x GPUs using JAX instead of PyTorch. JAX's advanced sharding APIs allowed us to achieve great performance. Check out our blog post to learn about the cool sharding tricks we used. We've also open-sourced the code: https://github.com/felafax/felafax
We're a small startup building AI infra for fine-tuning and serving LLMs on non-NVIDIA hardware (TPUs, AMD, Trainium).
Problem: Many companies are trying to get PyTorch working on AMD GPUs, but we believe this is a treacherous path. PyTorch is deeply intertwined with the NVIDIA ecosystem in a lot of ways (e.g., the `torch.cuda` namespace, or the fact that scaled_dot_product_attention is an NVIDIA CUDA kernel exposed as a PyTorch function). So, to get PyTorch code running on non-NVIDIA hardware, there's a lot of "de-NVIDIAfying" that needs to be done.
Solution: We believe JAX is a better fit for non-NVIDIA hardware. In JAX, ML model code compiles to hardware-independent HLO graphs, which are then optimized by the XLA compiler before hardware-specific optimization. This clean separation allowed us to run the same LLaMA3 JAX code both on Google TPUs and AMD GPUs with no changes.
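To give a flavor of what that looks like (a minimal sketch, not our actual training code), the same jitted JAX function can be sharded across whatever devices are visible, whether TPUs or AMD GPUs:

```python
# Minimal sketch (illustrative, not the Felafax repo code): the same jitted
# function runs on TPU or ROCm backends; only the visible devices change.
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), axis_names=("data",))

@jax.jit
def forward(w, x):
    return jnp.dot(x, w)

x = jnp.ones((jax.device_count() * 8, 512))
w = jnp.ones((512, 512))

# Shard the batch across devices; XLA lowers one HLO graph per backend.
x = jax.device_put(x, NamedSharding(mesh, P("data", None)))
w = jax.device_put(w, NamedSharding(mesh, P()))  # replicated weights
y = forward(w, x)
```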
Our strategy as a company is to invest upfront in porting models to JAX, then leverage the JAX framework and XLA kernels to extract maximum performance from non-NVIDIA backends. This is why we first ported Llama 3.1 from PyTorch to JAX; the same JAX model now runs well on both TPUs and AMD GPUs.
We'd love to hear your thoughts on our vision and repo!
ipsum2
5 hours ago
I, and several others, had no problem running PyTorch on AMD GPUs, with no code changes from CUDA. Check out MosaicML's blog posts: https://www.databricks.com/blog/training-llms-scale-amd-mi25...
felarof
5 hours ago
Ahh, interesting, will take a look!
Curious what the steps are to run PyTorch on AMD (does it work out of the box with the PyTorch+ROCm docker image)? Does torch.compile work smoothly?
anthonix1
4 hours ago
Yeah I would suggest taking a look at PyTorch on AMD before saying stuff like "scaled_dot_product_attention is an NVIDIA CUDA kernel exposed as a PyTorch function", because that is demonstrably false.
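For what it's worth, a minimal sketch (not tied to any particular repo) of calling it on whichever backend is available; the ROCm build reuses the `cuda` device name:

```python
# Minimal illustration: F.scaled_dot_product_attention dispatches per backend,
# so the same call runs on CPU, CUDA, or ROCm (which reuses the "cuda" device).
import torch
import torch.nn.functional as F

device = "cuda" if torch.cuda.is_available() else "cpu"
q = torch.randn(1, 8, 128, 64, device=device)
k = torch.randn(1, 8, 128, 64, device=device)
v = torch.randn(1, 8, 128, 64, device=device)

out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(out.shape)  # torch.Size([1, 8, 128, 64])
```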
Also, FWIW, I would suggest getting a small Llama 3.1 model training fast before trying to do a big 405B model -- faster to iterate and almost everything you'll learn on the small models will scale to the 405B.
felarof
4 hours ago
Thanks for the feedback! I appreciate you pointing that out. My understanding was based on the PyTorch documentation for scaled_dot_product_attention (https://pytorch.org/docs/stable/generated/torch.nn.functiona...): "The function may call optimized kernels for improved performance when using the CUDA backend. For all other backends, the PyTorch implementation will be used."
And I was trying to make a broader point about the lack of transparency (in performance and lower-level implementation) in PyTorch when running on NVIDIA vs. non-NVIDIA hardware.
anthonix1
3 hours ago
I just asked an instance of Hermes 3 Llama 3.1 405B running on AMD GPUs: "does pytorch scaled dot product attention run on AMD GPUs?":
"Yes, PyTorch's scaled dot product attention can run on AMD GPUs. PyTorch supports AMD GPUs through the ROCm (Radeon Open Compute) platform. To use PyTorch with an AMD GPU, you need to install the ROCm version of PyTorch, which is specifically built for AMD GPUs [...]"
And it proceeded to give the steps to follow to install and run, with example Python code to demonstrate it. One slight nitpick is that it referred to an older URL with the --index-url to install torch with pip, but otherwise it was correct.
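For reference, a quick way to verify that a ROCm build of PyTorch actually sees the AMD GPUs (illustrative only; the exact install index URL and ROCm version vary by release):

```python
# Quick check that a ROCm build of PyTorch sees AMD GPUs.
# (The ROCm build reuses the torch.cuda namespace via HIP.)
import torch

print(torch.version.hip)              # set on ROCm builds, None on CUDA builds
print(torch.cuda.is_available())      # True when ROCm and AMD GPUs are present
print(torch.cuda.get_device_name(0))  # e.g. "AMD Instinct MI300X"
```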
mistymountains
4 hours ago
Again, the problem is custom kernels in CUDA. It’s not straightforward for many applications (LLMs are probably the most straightforward).
anthonix1
5 hours ago
Does JAX have its own implementations of matmul, flash attention, etc.? Or does it use the ROCm implementations like PyTorch does (e.g., hipBLASLt, Composable Kernel FA, etc.)?
Not too familiar with JAX, but the abysmal PyTorch training perf on MI300x is in large part attributable to the slow perf of the ROCm libraries it is using under the hood.
germanjoey
6 hours ago
How are you verifying accuracy for your JAX port of Llama 3.1?
IMHO, the main reason to use PyTorch is actually that the original model used PyTorch. What can seem to be identical logic between different model versions may actually cause model drift when infinitesimal floating point errors accumulate due to the huge scale of the data. My experience is that debugging an accuracy mismatch like this in a big model is a torturous ordeal beyond the 10th circle of hell.
felarof
3 hours ago
Good question. We used a new AI+math-based testing tool (benchify.com) to run comparison tests, but we are working on building more robust infrastructure for this. Translating models from PyTorch to JAX is core to our strategy.
That said, translating models from one framework to another is not an uncommon path. HuggingFace translates Google's Gemma family of models from JAX to PyTorch, and a ton of people use it.
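As a rough illustration of the kind of parity check we want to automate (names here are placeholders, not our actual test suite), you can compare intermediate outputs of the PyTorch reference and the JAX port under a tolerance:

```python
# Illustrative parity check between a PyTorch reference layer and a JAX port.
# Function and parameter names are placeholders, not the Felafax test suite.
import numpy as np
import torch
import jax.numpy as jnp

def check_layer(torch_layer, jax_layer_fn, jax_params, x_np, atol=1e-3, rtol=1e-3):
    torch_out = torch_layer(torch.from_numpy(x_np)).detach().cpu().numpy()
    jax_out = np.asarray(jax_layer_fn(jax_params, jnp.asarray(x_np)))
    np.testing.assert_allclose(torch_out, jax_out, atol=atol, rtol=rtol)
```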
credit_guy
5 hours ago
When you say "model versions", do you mean different quantizations of the model? Then it's not floating point errors that accumulate. Different quantizations of the model are different models. People will call such a model something like Meta-Llama-3.1-8B-Instruct--q4_0, claiming that it's just a "version" of Meta-Llama-3.1-8B-Instruct. But it's just a lie. It's not the same model, and you should not expect the same results. There is no reason to debug the differences: what exactly would you expect to find, and what action would you take once you found what you are looking for? However, is the quantized version still a useful LLM? Absolutely. Most people don't have an A100 to run the original model, so a quantized version is better than nothing.
srcreigh
5 hours ago
Very fascinating, can you explain more about a time when this happened?
Like what area was affected by fp errors, why were they introduced (was it like refactoring of pytorch code?), how was this determined to be the cause?
llm_trw
6 hours ago
Does this work on consumer-grade cards like the 7900 XTX?
And by work I don't mean: spend two weeks trying to get the drivers set up and never update the server again.
cameron_b
6 hours ago
I'm glad to see a full implementation on AMD hardware.
I'm not familiar with JAX, but the idea of providing an abstraction layer to more easily get to work on what hardware is available seems really valuable. Bringing back some competitiveness to the ecosystem will be a big win for workload mobility.
I suspect that price/performance across implementations will be highly dependent on contract details, but do you intend to publish some comparisons in the future?
anthonix1
5 hours ago
Any direct comparisons to 8xH100? 2 toks/sec seems very slow!
I haven't done any LoRA training on MI300x myself, but I have done Llama 3.1 full training on 8xMI300x and got pretty close to 8xH100 performance with my own kernels (ROCm is just too slow).
felarof
3 hours ago
Oops, my calculation was wrong. Let me add an edit to the blog, thanks for pointing it out!
My train step was taking 30s.
And I was using a batch size of 16 and a sequence length of 64, making the training speed (16*64/30) ≈ 34 tokens per second (for fine-tuning in JAX eager mode).
(I haven't done comparison with 8XH100)
jgalt212
7 hours ago
Is there some cost rule of thumb to compare Nvidia, AMD, and Google TPU?
felarof
5 hours ago
Good question. There's no good rule of thumb, since performance depends on the software stack (JAX vs PyTorch) plus optimizations.
But my take is that on performance per dollar, TPU > AMD > NVIDIA.
CuriouslyC
6 hours ago
TPUs are slow but efficient and AMD has bugs but for some things works quite well. Nvidia is obviously the gold standard.
felarof
5 hours ago
Haha, TPUs are not slow :) All of Google's training (including Gemini models) is done on TPUs.
There are good first-party [a] and third-party [b] benchmarks comparing TPUs vs NVIDIA GPUs.
[a] - https://github.com/GoogleCloudPlatform/vertex-ai-samples/blo...
ngcc_hk
7 hours ago
Given it is a migration, is there an actual comparison of the same model on PyTorch vs. your version? The comparison table there seems to focus on the technical side.
Also any technical issues encountered?
felarof
3 hours ago
We have a few technical issues that we still need to address:
1) This entire fine-tuning run was done in JAX eager mode. I kept running out of memory (OOM) when trying to `jax.jit` the entire training step. Even gradual `jax.jit` didn't work.
2) The current version doesn't have gradient accumulation, and with a batch size of just 16, that’s not ideal. I'm working on implementing gradient accumulation next (a rough sketch follows this list).
3) We still haven't found a good way to load large sequence-length data (like 32k sequence length). Currently, before sharding the training batch across GPUs, we end up loading the entire batch onto a single GPU’s VRAM, which causes OOM issues.
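For point 2, the shape of the gradient-accumulation loop I have in mind looks roughly like this (a sketch with placeholder names, not the final implementation):

```python
# Sketch of gradient accumulation in JAX (placeholder loss/params, not the
# Felafax code): average grads over micro-batches, then do one optimizer step.
import jax
import jax.numpy as jnp

def loss_fn(params, batch):
    logits = batch["inputs"] @ params["w"]           # stand-in for the model forward
    return jnp.mean((logits - batch["targets"]) ** 2)

@jax.jit
def accumulated_grads(params, micro_batches):
    grad_fn = jax.grad(loss_fn)

    def body(acc, micro_batch):
        grads = grad_fn(params, micro_batch)
        return jax.tree_util.tree_map(jnp.add, acc, grads), None

    zeros = jax.tree_util.tree_map(jnp.zeros_like, params)
    acc, _ = jax.lax.scan(body, zeros, micro_batches)  # leading axis = micro-batches
    n = micro_batches["inputs"].shape[0]
    return jax.tree_util.tree_map(lambda g: g / n, acc)
```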