We fine-tuned Llama 405B on AMD GPUs

252 points, posted 8 hours ago
by felarof

46 Comments

felarof

7 hours ago

Hey HN, we recently fine-tuned the llama3.1 405B model on 8xAMD MI300x GPUs using JAX instead of PyTorch. JAX's advanced sharding APIs allowed us to achieve great performance. Check out our blog post to learn about the cool sharding tricks we used. We've also open-sourced the code: https://github.com/felafax/felafax
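For readers new to JAX, here is a minimal sketch of what its sharding API looks like. It assumes a 1-D device mesh with a hypothetical axis name `model` and a made-up 512x1024 weight split by columns. It is an illustration only, not the sharding scheme used in the felafax repo.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# A 1-D device mesh over every visible device: 8 entries on an 8x MI300X node,
# a single entry on a CPU-only laptop.
mesh = Mesh(np.array(jax.devices()), axis_names=("model",))

# Hypothetical dense-layer weight. Its output columns are split across the
# "model" axis, so each device stores d_out / num_devices columns.
d_in, d_out = 512, 1024  # d_out must be divisible by the device count
w = jnp.ones((d_in, d_out), dtype=jnp.float32)
w_sharded = jax.device_put(w, NamedSharding(mesh, P(None, "model")))

# The activations are not explicitly sharded. When the jitted function runs,
# XLA's partitioner decides how to compute with the sharded weight and where
# any cross-device communication is needed.
x = jnp.ones((8, d_in), dtype=jnp.float32)

@jax.jit
def dense(x, w):
    return x @ w

y = dense(x, w_sharded)
print(y.shape, y.sharding)
```

The model code (`dense`) never mentions devices. The placement decision lives entirely in the `NamedSharding` passed to `jax.device_put`.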

We're a small startup building AI infra for fine-tuning and serving LLMs on non-NVIDIA hardware (TPUs, AMD, Trainium).

Problem: Many companies are trying to get PyTorch working on AMD GPUs, but we believe this is a treacherous path. PyTorch is deeply intertwined with the NVIDIA ecosystem in a lot of ways (e.g., `torch.cuda`, or `scaled_dot_product_attention`, which is an NVIDIA CUDA kernel exposed as a PyTorch function). So, to get PyTorch code running on non-NVIDIA hardware, there's a lot of "de-NVIDIAfying" that needs to be done.

Solution: We believe JAX is a better fit for non-NVIDIA hardware. In JAX, ML model code compiles to hardware-independent HLO graphs, which are then optimized by the XLA compiler before hardware-specific optimization. This clean separation allowed us to run the same LLaMA3 JAX code both on Google TPUs and AMD GPUs with no changes.
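The two-stage split described above can be seen in JAX's ahead-of-time API. In this sketch, the toy `mlp` function and its shapes are made up for illustration. The lowered module is produced before any backend-specific compilation happens, and only `compile()` targets the hardware JAX detected.

```python
import jax
import jax.numpy as jnp

def mlp(x, w1, w2):
    # Ordinary JAX code: nothing here names a vendor, a device, or a kernel.
    return jax.nn.gelu(x @ w1) @ w2

x = jnp.ones((4, 16))
w1 = jnp.ones((16, 32))
w2 = jnp.ones((32, 16))

# Step 1: trace and lower to a backend-independent module
# (StableHLO text in recent JAX releases).
lowered = jax.jit(mlp).lower(x, w1, w2)
hlo_text = lowered.as_text()
print(hlo_text[:300])

# Step 2: only compilation targets whatever backend JAX found
# (CPU, a CUDA or ROCm GPU, or a TPU).
compiled = lowered.compile()
out = compiled(x, w1, w2)
print(jax.default_backend(), out.shape)
```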

Our strategy as a company is to invest upfront in porting models to JAX, then leverage its framework and XLA kernels to extract maximum performance from non-NVIDIA backends. This is why we first ported Llama 3.1 from PyTorch to JAX, and now the same JAX model works great on TPUs and runs perfectly on AMD GPUs.

We'd love to hear your thoughts on our vision and repo!

ipsum2

5 hours ago

I, and several others, had no problem running PyTorch on AMD GPUs, with no code changes from CUDA. Check out MosaicML's blog posts: https://www.databricks.com/blog/training-llms-scale-amd-mi25...

felarof

5 hours ago

Ahh, interesting, will take a look!

Curious what the steps are to run PyTorch on AMD (does it work out of the box with the PyTorch+ROCm Docker image)? Does torch.compile work smoothly?

anthonix1

4 hours ago

Yeah I would suggest taking a look at PyTorch on AMD before saying stuff like "scaled_dot_product_attention is an NVIDIA CUDA kernel exposed as a PyTorch function", because that is demonstrably false.

Also, FWIW, I would suggest getting a small Llama 3.1 model training fast before trying to do a big 405B model -- faster to iterate and almost everything you'll learn on the small models will scale to the 405B.

felarof

4 hours ago

Thanks for the feedback! I appreciate you pointing that out. My understanding was based on the PyTorch documentation for scaled_dot_product_attention (https://pytorch.org/docs/stable/generated/torch.nn.functiona...): "The function may call optimized kernels for improved performance when using the CUDA backend. For all other backends, the PyTorch implementation will be used."

And I was trying to make a broader point about the lack of transparency (in performance, lower-level impl) in PyTorch when running on NVIDIA vs. non-NVIDIA hardware.
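For reference, every backend computes the same thing for scaled dot-product attention: softmax(QKᵀ/√d)V. The fused kernels are faster implementations of that formula. Below is a naive JAX sketch of the math. It is written for illustration and is not PyTorch's fallback code. The (batch, heads, seq, head_dim) layout is an assumption.

```python
import math
import jax
import jax.numpy as jnp

def sdpa_reference(q, k, v, causal=False):
    """Naive attention: softmax(q @ k^T / sqrt(head_dim)) @ v.

    q, k, v have shape (batch, heads, seq, head_dim).
    """
    scale = 1.0 / math.sqrt(q.shape[-1])
    scores = jnp.einsum("bhqd,bhkd->bhqk", q, k) * scale
    if causal:
        # Query i may only attend to keys 0..i.
        mask = jnp.tril(jnp.ones(scores.shape[-2:], dtype=bool))
        scores = jnp.where(mask, scores, -jnp.inf)
    weights = jax.nn.softmax(scores, axis=-1)
    return jnp.einsum("bhqk,bhkd->bhqd", weights, v)

kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(kq, (2, 4, 8, 16))
k = jax.random.normal(kk, (2, 4, 8, 16))
v = jax.random.normal(kv, (2, 4, 8, 16))
out = sdpa_reference(q, k, v, causal=True)
print(out.shape)
```

With `causal=True`, the first query can only see the first key, so its output row equals `v[:, :, 0, :]`. This makes a quick sanity check when comparing a fused kernel against the reference.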

anthonix1

3 hours ago

I just asked an instance of Hermes 3 Llama 3.1 405B running on AMD GPUs: "does pytorch scaled dot product attention run on AMD GPUs?":

"Yes, PyTorch's scaled dot product attention can run on AMD GPUs. PyTorch supports AMD GPUs through the ROCm (Radeon Open Compute) platform. To use PyTorch with an AMD GPU, you need to install the ROCm version of PyTorch, which is specifically built for AMD GPUs [...]"

And it proceeded to give the steps to follow to install and run, with example Python code to demonstrate it. One slight nitpick is that it referred to an older URL with the --index-url to install torch with pip, but otherwise it was correct.

mistymountains

4 hours ago

Again, the problem is custom kernels in CUDA. It’s not straightforward for many applications (LLMs are probably the most straightforward).

anthonix1

5 hours ago

Does JAX have its own implementations of matmul, flash attention etc? Or does it use the ROCm implementations like PyTorch does? (e.g., hipblaslt, Composable Kernel FA etc)

Not too familiar with JAX, but the abysmal PyTorch training perf on MI300x is in large part attributable to the slow perf of the ROCm libraries it is using under the hood.
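One way to answer this empirically on a given machine is to inspect what XLA actually compiled. This is a sketch using a made-up 256x256 matmul. The printed lines differ by backend. On GPU backends, a matmul is often rewritten into a custom-call to a vendor BLAS library, but XLA can also emit its own generated kernels. Treat the filter strings as a rough heuristic, not a complete classifier.

```python
import jax
import jax.numpy as jnp

def matmul(a, b):
    return a @ b

a = jnp.ones((256, 256), dtype=jnp.float32)
b = jnp.ones((256, 256), dtype=jnp.float32)

compiled = jax.jit(matmul).lower(a, b).compile()

# Optimized, backend-specific HLO. On GPU backends a matmul is often rewritten
# into a custom-call to the vendor BLAS library, although XLA can also emit
# its own generated kernels; CPU output looks different again.
optimized_hlo = compiled.as_text()

print("backend:", jax.default_backend())
for line in optimized_hlo.splitlines():
    if "custom-call" in line or " dot(" in line:
        print(line.strip())
```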

germanjoey

6 hours ago

How are you verifying accuracy for your JAX port of Llama 3.1?

IMHO, the main reason to use pytorch is actually that the original model used pytorch. What can seem to be identical logic between different model versions may actually cause model drift when infinitesimal floating point errors accumulate due to the huge scale of the data. My experience is that debugging accuracy mismatches like this in a big model is a torturous ordeal beyond the 10th circle of hell.
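The accumulation effect described above can be reproduced with a toy example (no model involved). In this sketch the same "logic" adds 1.0 to 256.0 a thousand times. The only difference between the two runs is the accumulator dtype. In bfloat16, the spacing between representable values near 256 is 2, so 257 rounds back to 256 on every step.

```python
import jax.numpy as jnp

def running_sum(start, step, n, dtype):
    """Add `step` to `start` n times, rounding to `dtype` after every add."""
    acc = jnp.asarray(start, dtype=dtype)
    inc = jnp.asarray(step, dtype=dtype)
    for _ in range(n):
        acc = acc + inc  # each eager op rounds its result back to `dtype`
    return float(acc)

exact = 256.0 + 1000 * 1.0
f32 = running_sum(256.0, 1.0, 1000, jnp.float32)
bf16 = running_sum(256.0, 1.0, 1000, jnp.bfloat16)
print(exact, f32, bf16)
```

Real drift in ported models is far subtler than this, but the mechanism is the same: precision and the order of operations, not the high-level logic, decide the result.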

felarof

3 hours ago

Good question. We used a new AI+math-based testing tool (benchify.com) to run comparison tests, but we are working on building more robust infrastructure for this. Translating models from PyTorch to JAX is core to our strategy.

That said, this path is not uncommon (translating from one framework to another). HuggingFace translates Google's Gemma family models from JAX to PyTorch, and a ton of people use it.
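A common mechanical pitfall in PyTorch-to-JAX ports is weight layout. `torch.nn.Linear` stores its weight as (out_features, in_features) and computes `x @ W.T + b`, while a JAX kernel is usually kept as (in_features, out_features). This sketch uses NumPy arrays as a stand-in for a PyTorch checkpoint (torch is not imported). It converts one layer and runs a tolerance-based comparison against a float64 reference. The tolerances assume float32 matmuls on CPU, and none of this describes the benchify tooling mentioned above.

```python
import numpy as np
import jax.numpy as jnp

rng = np.random.default_rng(0)
in_features, out_features = 64, 32

# Stand-in for a torch.nn.Linear checkpoint: weight is (out_features, in_features).
torch_weight = rng.standard_normal((out_features, in_features)).astype(np.float32)
torch_bias = rng.standard_normal(out_features).astype(np.float32)

def torch_linear_reference(x):
    # torch.nn.Linear's formula, evaluated in float64 as the reference.
    return x.astype(np.float64) @ torch_weight.T.astype(np.float64) + torch_bias

# Converted JAX parameters: kernel laid out as (in_features, out_features).
jax_kernel = jnp.asarray(torch_weight.T)
jax_bias = jnp.asarray(torch_bias)

def jax_linear(x):
    return x @ jax_kernel + jax_bias

x = rng.standard_normal((8, in_features)).astype(np.float32)
ref = torch_linear_reference(x)
out = np.asarray(jax_linear(jnp.asarray(x)))

max_abs_err = float(np.max(np.abs(out - ref)))
np.testing.assert_allclose(out, ref, rtol=1e-3, atol=1e-3)
print(f"max abs error: {max_abs_err:.2e}")
```

For a full model, the same comparison would be run layer by layer on the logits. Tolerances would need loosening for bfloat16 or TF32 matmuls.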

credit_guy

5 hours ago

When you say "model versions", do you mean different quantizations of the model? Then it's not floating point errors that accumulate. Different quantizations of the model are different models. People will call such a model something like Meta-Llama-3.1-8B-Instruct--q4_0, claiming that it's just a "version" of the Meta-Llama-3.1-8B-Instruct. But it's just a lie. It's not the same model, and you should not expect the same results. There is no reason to debug the differences, what exactly would you expect to find, and what action would you envision to take once you find what you are looking for? However, is the quantized version still a useful LLM? Absolutely. Most people don't have an A100 to run the original model, so a quantized version is better than nothing.
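The point that a quantized checkpoint holds different numbers, not a "version" of the same ones, is easy to see with a simplified scheme. This sketch assumes symmetric 4-bit blockwise quantization with a block size of 32. It is loosely inspired by formats like q4_0 but is not GGML's actual encoding.

```python
import jax
import jax.numpy as jnp

BLOCK = 32  # values per quantization block (assumed for illustration)

def quantize_4bit(w):
    """Simplified symmetric 4-bit blockwise quantization (not GGML's q4_0 layout)."""
    blocks = w.reshape(-1, BLOCK)
    # One float scale per block, chosen so the largest |value| maps to 7.
    scale = jnp.max(jnp.abs(blocks), axis=1, keepdims=True) / 7.0
    scale = jnp.where(scale == 0, 1.0, scale)  # avoid dividing by zero on all-zero blocks
    q = jnp.clip(jnp.round(blocks / scale), -8, 7).astype(jnp.int8)
    return q, scale

def dequantize_4bit(q, scale, shape):
    return (q.astype(jnp.float32) * scale).reshape(shape)

w = jax.random.normal(jax.random.PRNGKey(0), (64, 64))
q, scale = quantize_4bit(w)
w_hat = dequantize_4bit(q, scale, w.shape)

err = jnp.abs(w - w_hat)
print("max abs error:", float(err.max()))
```

Each weight can be off by up to half its block's scale, so the dequantized matrix is a genuinely different set of parameters.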

srcreigh

5 hours ago

Very fascinating, can you explain more about a time when this happened?

Like what area was affected by fp errors, why were they introduced (was it like refactoring of pytorch code?), how was this determined to be the cause?

llm_trw

6 hours ago

Does this work on the consumer grade cards like the 7900 XTX?

And by work I don't mean: spend two weeks trying to get the drivers set up and never update the server again.

cameron_b

6 hours ago

I'm glad to see a full implementation on AMD hardware.

I'm not familiar with JAX, but the idea of providing an abstraction layer to more easily get to work on what hardware is available seems really valuable. Bringing back some competitiveness to the ecosystem will be a big win for workload mobility.

I suspect that price/performance across implementations will be highly dependent on contract details, but do you intend to publish some comparisons in the future?

anthonix1

5 hours ago

Any direct comparisons to 8xH100? 2 toks/sec seems very slow!

I haven't done any LoRA training on MI300x myself, but I have done LLama 3.1 full training on 8xMI300x and got pretty close to 8xH100 performance with my own kernels (ROCm is just too slow).

felarof

3 hours ago

Oops, my calculation was wrong. Let me add an edit to the blog, thanks for pointing it out!

My train step was taking 30s.

And I was using a batch size of 16 and seq length of 64, making the training speed (16*64/30) tokens per sec ≈ 34 tokens per second (for fine-tuning in JAX eager mode).

(I haven't done a comparison with 8xH100.)

jgalt212

7 hours ago

Is there some cost rule of thumb to compare Nvidia, AMD, and Google TPU?

felarof

5 hours ago

Good question. There's no good metric, given that performance depends on the software stack (JAX vs. PyTorch) plus optimizations.

But my take: in performance per dollar, TPU > AMD > NVIDIA.

ngcc_hk

7 hours ago

Given it is a migration, is there an actual comparison of the same model on PyTorch vs. your version? The comparison table there seems to be on the technical side.

Also any technical issues encountered?

felarof

3 hours ago

We have a few technical issues that we still need to address:

1) This entire fine-tuning run was done in JAX eager mode. I kept running out of memory (OOM) when trying to `jax.jit` the entire training step. Even gradual `jax.jit` didn't work.

2) The current version doesn't have gradient accumulation, and with a batch size of just 16, that’s not ideal. I'm working on implementing gradient accumulation next.

3) We still haven't found a good way to load large sequence-length data (like 32k sequence length). Currently, before sharding the training batch across GPUs, it ends up loading the entire batch onto a single GPU’s VRAM and causes OOM issues.
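Here is a sketch of standard JAX techniques for the three issues above, under stated assumptions. It is not the felafax training loop, and the toy model, `read_rows` reader, and shapes are hypothetical.

- For item 3, `jax.make_array_from_callback` builds a sharded array by asking a callback for each device's slice, so the full batch is never staged on one GPU.
- For item 2, gradients are accumulated over micro-batches with `jax.lax.scan` inside one jitted step.
- For item 1, `jax.checkpoint` (rematerialization) and `donate_argnums` are common memory levers under `jax.jit`. Whether they would avoid the OOM described is untested here.

```python
from functools import partial

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), axis_names=("data",))
batch_sharding = NamedSharding(mesh, P("data", None))  # split rows across devices


def load_sharded_batch(global_shape, read_rows):
    """Assemble a sharded batch without staging the whole batch on one device.

    `read_rows(start, stop)` is a hypothetical host-side reader returning rows
    [start, stop) as a NumPy array. JAX calls `callback` once per local device
    with the index (a tuple of slices) that device owns.
    """
    def callback(index):
        rows = index[0]
        start = 0 if rows.start is None else rows.start
        stop = global_shape[0] if rows.stop is None else rows.stop
        return read_rows(start, stop)  # columns are not sharded in this layout

    return jax.make_array_from_callback(global_shape, batch_sharding, callback)


def loss_fn(params, x):
    # Hypothetical toy model: one linear layer trained to reproduce its input.
    return jnp.mean((x @ params["w"] - x) ** 2)


@partial(jax.jit, static_argnames=("num_micro",), donate_argnums=(0,))
def train_step(params, batch, lr, num_micro):
    # Split the global batch into (num_micro, rows_per_micro, features).
    micro = batch.reshape(num_micro, -1, batch.shape[-1])
    # jax.checkpoint recomputes activations in the backward pass instead of
    # storing them; in a real model you would wrap each transformer block.
    grad_fn = jax.grad(jax.checkpoint(loss_fn))

    def accumulate(grad_sum, mb):
        grads = grad_fn(params, mb)
        return jax.tree_util.tree_map(jnp.add, grad_sum, grads), None

    zeros = jax.tree_util.tree_map(jnp.zeros_like, params)
    grad_sum, _ = jax.lax.scan(accumulate, zeros, micro)
    # Averaging equal-sized micro-batch gradients reproduces the full-batch gradient.
    grads = jax.tree_util.tree_map(lambda g: g / num_micro, grad_sum)
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)


d = 8
data = np.arange(16 * d, dtype=np.float32).reshape(16, d) / 100.0
batch = load_sharded_batch((16, d), lambda start, stop: data[start:stop])
params = {"w": jnp.zeros((d, d))}
params = train_step(params, batch, 0.1, num_micro=4)  # old params buffer is donated
```

Because `params` is donated, the caller must use only the returned pytree afterwards. The row count (16 here) has to divide evenly by both the device count and `num_micro`.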

3abiton

7 hours ago

Firstly, great work! I dabbled with AMD GPUs and ROCm support a year ago, and it was obvious AMD was still a long way from catching up with Nvidia. While opting for JAX is an interesting approach, what were the challenges for you in deviating from PyTorch (being the standard library for ML)?

felarof

5 hours ago

A few weeks ago, I did a Show HN explaining our journey: https://news.ycombinator.com/item?id=41512142.

We initially started with the goal of fine-tuning LLaMA 3 on TPUs, but PyTorch XLA was clunky, so we decided to rewrite the model in JAX. That said, as mentioned earlier in the thread, we also believe JAX is a better platform for non-NVIDIA GPUs and want to build our infra for non-NVIDIA GPUs on JAX+OpenXLA.

6y56h56

6 hours ago

I cannot get AMD ROCm running on my Debian 12 system, which is what I think is causing Ollama to use the CPU instead of the GPU. So I guess there is still a long way to go.

jchw

4 hours ago

At the risk of pissing people off, I think you may be better served by a distribution that provides a more up-to-date kernel. Debian 12 will give you Linux 6.1 LTS, which is probably OK if you're using an older Radeon card, but I've heard support for the 7900 XT/X series is a bit dicey and beyond that (e.g. Radeon 890M) non-existent.

If there were improvements on the AMDGPU DRM driver side, you would not see them in Debian any time soon, as the 6.1 LTS kernel will be stuck with roughly whatever shipped January of last year. This is just a shortcoming in the Linux kernel, due to its lack of any kind of stable ABI for drivers.

Of course it is possible this would help nothing or even hurt. My experience running stable (or even newer) kernels has been quite good, though. I run stable or newer across a few devices and run into hiccups not more than once every few years, which is definitely worth it to be able to get new driver improvements years in advance.

(FWIW Debian is not even supported by ROCm[1]... although distros with even older kernels are. But, even if ROCm works, I can't imagine you will get ideal hardware support when running older kernels. I am not sure if ROCm has some workaround for enterprise Linux distributions specifically, but it feels like they must, given how many of their customers in the datacenter are likely to want to use them.)

[1]: https://rocm.docs.amd.com/en/latest/compatibility/compatibil...

slavik81

an hour ago

> I've heard support for the 7900 XT/X series is a bit dicey

The firmware-amd-graphics package in stable is too old to properly support RDNA 3. It kind of works, but it is quite buggy. All RDNA 3 users on Debian 12 should be sure to install the kernel and firmware from bookworm-backports.

There is full support for RDNA 3 hardware enabled on Debian Testing (both in the drivers and runtime libraries). The Debian ROCm Team intended to backport all the ROCm packages from Testing into Bookworm, but have been held up as LLVM 17 is not available in bookworm-backports (yet?).

> FWIW Debian is not even supported by ROCm

ROCm does not support Debian, but Debian supports ROCm. Most of the libraries that comprise ROCm have been directly packaged by the distribution.

llm_trw

4 hours ago

Like everything in machine learning it only really runs on Ubuntu 22.04. Anything else is unsupported and you need to spend weeks tinkering to get it to work, then never upgrade.

ants_everywhere

6 hours ago

I've had more luck with the ROCm docker container. I run it via k8s. It was pretty painless to set up and has been mostly painless since. Prior to that it was nearly impossible to get JAX running reliably on ROCm.

Even with the container, you have to be careful installing Python libraries because they can still break things.

lenova

6 hours ago

I just recently went down the AMD GPU + ROCm rabbit hole as well. ROCm 6.2 was just released in August of this year and introduces much better support, though, as the poster above mentioned, it isn't merged into most recent OSes.

This GitHub repo is good for tracking the latest Ubuntu + ROCm install process: https://github.com/nktice/AMD-AI

latchkey

3 hours ago

That's a nice repo of random installation notes. Very helpful, thanks!

superkuh

6 hours ago

You'd probably have a lot better luck using Vulkan acceleration (not ROCm) of llama.cpp as the backend to Ollama. It is incomparably easier to set up and maintain compared to ROCm. You can actually do it on your computer's normal OS instead of inside a bunch of containers/VMs where the system libs are entirely customized to running just that one application.

AMD's support of consumer cards is very, very short. By the time it's stable enough for a new card to run, the card is no longer supported. In 2021 I bought an AMD GPU that had come out 3 years earlier, and 1 year after I bought it (4 years after its release) they dropped ROCm support.

coppsilgold

17 minutes ago

ROCm is not even worth the effort for inference workloads. Vulkan is much more convenient and performs fine.

llama.cpp and stable-diffusion.cpp offer Vulkan backends but generally you can run most models on Vulkan if you use IREE[1].

[1] <https://iree.dev/guides/ml-frameworks/>

latchkey

7 hours ago

Nice work! I was just playing with the inference side of things with 405B myself this weekend [0].

I'm not convinced that 'torch.cuda' is really that bad, since the AMD version of PyTorch just translates that for you. It's more of a naming problem than anything. Fact is that it is just as easy to grab the rocm:pytorch container as it is the rocm:jax container.

I don't see very many numbers posted. What MFU did you get?

[0] https://x.com/HotAisle/status/1837580046732874026

felarof

5 hours ago

Nice!

I need to calculate MFU. GPU, VRAM details can be found in the repo: https://dub.sh/amd-405b-res.

I plan to reattempt the training run next weekend and JIT the entire training step to calculate MFU then.
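For anyone who wants to compute MFU for a run like this, here is a minimal sketch. It assumes the common approximation of about 6 FLOPs per parameter per token for a dense transformer's forward and backward pass, and it ignores attention-score FLOPs. The hardware peak is deliberately left as an input, since it has to come from the vendor datasheet (bf16, dense) multiplied by the GPU count.

```python
def tokens_per_second(batch_size: int, seq_len: int, step_seconds: float) -> float:
    """Throughput: tokens processed per training step divided by step time."""
    return batch_size * seq_len / step_seconds


def model_flops_utilization(n_params: float, tokens_per_sec: float,
                            peak_flops_per_sec: float,
                            flops_per_token_per_param: float = 6.0) -> float:
    """MFU = achieved model FLOP/s / peak hardware FLOP/s.

    The default of 6 FLOPs per parameter per token is the usual dense-transformer
    approximation for forward + backward, ignoring attention-score FLOPs.
    """
    achieved = flops_per_token_per_param * n_params * tokens_per_sec
    return achieved / peak_flops_per_sec


tps = tokens_per_second(16, 64, 30.0)  # numbers from this thread: 1024 tokens / 30 s
achieved_flops = 6.0 * 405e9 * tps     # model FLOP/s under the 6N approximation
print(f"{tps:.1f} tokens/s, {achieved_flops:.2e} model FLOP/s")
```

Dividing `achieved_flops` by the node's total peak gives the MFU. For LoRA, frozen base weights need no weight-gradient matmuls, so a factor below 6 may be more appropriate. Treat the constant as an assumption.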

yeahwhatever10

7 hours ago

Where is the performance data?

felarof

5 hours ago

(author here, sorry for the delay in replying, was stuck in back-to-back meetings)

I updated our github repo to include GPU, VRAM utilization data (https://github.com/felafax/felafax?tab=readme-ov-file#amd-40...)

Note: we couldn't run the JIT-compiled version of the 405B model due to our code/VRAM constraints (we need to investigate this further). The entire training run was executed in JAX eager mode, so there is significant potential for performance improvements.

GPU utilization across the board was still ~30-40% even with eager mode, which is quite good! With JIT, I think the GPU util can easily shoot up to ~50-60%.

manojlds

6 hours ago

Thought this was a post from Obsidian at first. Why haven't they done the GitHub.com vs GitHub.io thing yet?

codetrotter

5 hours ago

Looking at the URL has me thinking that this confusion would be resolved if HN adds a small piece of logic to treat the domain publish.obsidian.md specially, just like how HN already does for pages served under forbes.com/sites, which are not written by Forbes staff themselves.

So instead of showing the domain as obsidian.md, HN would show the domain for this link as publish.obsidian.md

Maybe something for dang to consider if he sees this comment?
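The proposed rule is small enough to sketch. HN's real implementation is not shown in this thread, so the following is hypothetical throughout. The rule tables (`KEEP_FULL_HOST`, `PATH_PREFIX_SEGMENTS`) and the naive "last two labels" default are assumptions made for illustration.

```python
from urllib.parse import urlsplit

# Hypothetical rule tables, not HN's actual code.
KEEP_FULL_HOST = {"publish.obsidian.md"}
# Sites where the first N path segments identify the author,
# e.g. forbes.com/sites/<author>.
PATH_PREFIX_SEGMENTS = {"forbes.com": 2}


def display_site(url: str) -> str:
    parts = urlsplit(url)
    host = (parts.hostname or "").lower()
    if host.startswith("www."):
        host = host[4:]
    if host in KEEP_FULL_HOST:
        site = host
    else:
        # Naive default: keep the last two labels (ignores suffixes like .co.uk).
        site = ".".join(host.split(".")[-2:])
    n = PATH_PREFIX_SEGMENTS.get(site, 0)
    if n:
        segments = [s for s in parts.path.split("/") if s][:n]
        if len(segments) == n:
            site = site + "/" + "/".join(segments)
    return site


print(display_site("https://publish.obsidian.md/some-user/some-post"))
```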

gbraad

6 hours ago

Same thought here. Why would Obsidian bother with AI? Oh wait, this is publish? So this is what $8 per month gets you? I am amazed, as I would have at least expected a subhost: [username].publish.obsidian.md

felarof

5 hours ago

Yeah, used Obsidian Publish.

But struggling to get custom domain to work with it (have emailed support).

abalaji

7 hours ago

@dang: could we get url to include the username since this isn't about Obsidian itself, but rather a user generated blog?

m00x

5 hours ago

It's strange that HN didn't include the full domain "publish.obsidian.md".

dang

2 hours ago

That's not turned on by default but I've done it for this domain now.

meiraleal

6 hours ago

That's something obsidian should fix if they care about not looking like they are being impersonated on HN.

viraptor

3 hours ago

Obsidian can't do anything about it. It's HN chopping up the url