This is very cool. I enjoyed going through the writeup and GitHub README.
I was wondering if these same optimizations can be brought to bear on training as well, rather than only inference. I guess the challenge here is fusing backward computations with gradient communication.
I also saw that this currently does not handle dynamic workloads such as MoE. I recently came across this paper that does exactly this:
FlashDMoE: Fast Distributed MoE in a Single Kernel - https://arxiv.org/pdf/2506.04667
Next step - compile straight to Verilog so I can buy some LLMs on AliExpress
After working pretty closely with vLLM and SGLang over the past few months, this is EXACTLY what I had envisioned a successor project would look like: analyzing an operation dependency graph and then fusing (or, at a minimum, scheduling tasks smarter). Congrats to the team.
The improvement is real!
And unlike a lot of research, the code actually runs well. I can reproduce the results using Modal GPUs, leaving the code here: https://github.com/mirage-project/mirage/pull/327/files
Triton + FlashInfer: Prompt length 39, generate length 264, per-token latency 19.189573345762312 ms
MPK: Prompt length 39, generate length 334, per-token latency 7.71875 ms
Somewhat relevant anecdote: we had a small CUDA competition (10-ish years ago). Some embarrassingly parallel CV algorithm.
I tried to be smart and cache intermediate results that were shared by multiple kernels.
When the results were published I was stumped to see that others were orders of magnitude faster than me.
Turns out they didn't bother with caching at all. The overhead of recalculating everything a thousand times was tiny compared to the overhead of doing roundtrips through RAM.
I assume it's the same thing here. By compiling into MegaKernels, layer boundaries are squashed. There will likely be _more_ calculations and fewer shared intermediate results. But overall it's still a win due to fewer memory round trips.
There has to be a sweet spot, especially for convolutional networks. No idea if the MegaKernel takes this into account.
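To make the trade-off concrete, here's the toy version (illustrative only, nothing to do with MPK's actual code): the unfused path bounces the intermediate through HBM, while the fused one keeps it in a register even if a real fusion has to redo some math per tile.

    // Toy illustration: y = relu(x * a) * b over n elements.
    // Unfused: two launches, the intermediate t round-trips through HBM.
    __global__ void scale_relu(const float* x, float a, float* t, int n) {
        int i = blockIdx.x * blockDim.x + threadIdx.x;
        if (i < n) t[i] = fmaxf(x[i] * a, 0.0f);   // t is written to global memory
    }
    __global__ void scale(const float* t, float b, float* y, int n) {
        int i = blockIdx.x * blockDim.x + threadIdx.x;
        if (i < n) y[i] = t[i] * b;                // t is read back from global memory
    }

    // Fused: one launch, the intermediate never leaves registers.
    __global__ void scale_relu_scale(const float* x, float a, float b, float* y, int n) {
        int i = blockIdx.x * blockDim.x + threadIdx.x;
        if (i < n) {
            float t = fmaxf(x[i] * a, 0.0f);       // lives in a register
            y[i] = t * b;
        }
    }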
This project is from CMU. Hazy Research at Stanford talked about the megakernel too: https://hazyresearch.stanford.edu/blog/2025-05-27-no-bubbles
Good to see the competition in this area.
(Edited): Related paper covering the larger "mirage" project, but this doesn't cover the "megakernel" approach: https://arxiv.org/abs/2405.05751
The Qwen 8B number, if verified, is very impressive. Much more practical than the previous megakernel one.
That being said, this one-persistent-kernel-per-SM design reminds me of Larrabee, and now I'm wondering what the world would look like if we had just taken the traditional process-thread-SIMD path rather than the CUDA path.
Does anyone have an intuition on why this offers significant gains over CUDA Graphs? The CPU launch cost of a graph is tiny, which implies most of the work has been offloaded to the GPU's own scheduler. I'd expect that some I/O marshalling at kernel boundaries could be avoided with megakernels. Maybe some loop fusion? Are there any more interesting optimizations they enable?
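To make my own guess concrete: even in a captured graph, every node is still a whole kernel, so intermediates go through HBM and a consumer can't start until every block of its producer has retired. Rough sketch with hypothetical kernels k1/k2 and buffers (not anyone's real code):

    __global__ void k1(const float* x, float* tmp, int n) { /* writes tmp to HBM */ }
    __global__ void k2(const float* tmp, float* y, int n) { /* reads tmp back */ }

    void run_with_graph(float* x, float* tmp, float* y, int n, cudaStream_t stream) {
        cudaGraph_t graph;
        cudaGraphExec_t exec;

        cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal);
        k1<<<(n + 255) / 256, 256, 0, stream>>>(x, tmp, n);
        k2<<<(n + 255) / 256, 256, 0, stream>>>(tmp, y, n);  // waits for ALL of k1
        cudaStreamEndCapture(stream, &graph);
        cudaGraphInstantiate(&exec, graph, 0);  // CUDA 12 signature; older
                                                // toolkits take error-buffer args
        cudaGraphLaunch(exec, stream);          // launch cost is tiny, as you say,
        cudaStreamSynchronize(stream);          // but the HBM round trip and the
                                                // kernel-boundary wait remain
    }

My understanding is that the megakernel's gains come less from launch cost and more from removing that kernel-wide barrier: a downstream tile-level task can start as soon as the specific tiles it reads are ready, and loads for the next task can be overlapped with compute for the current one.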
Certainly an important discovery for utilizing these models on scaled hardware. This approach could also be applied beyond LLMs to other types of neural networks; that would be an interesting space to explore.
Curious if anyone has thoughts on going even further: eschewing software-based inference in favor of a purely ASIC approach to a static LLM. Cost benefits? Additional software-level, fine-tunable layers to allow a degree of improvement and flexibility? We are quickly approaching ‘good enough’ for some tasks—at what point does that mean we’re comfortable locking something in for the ~2-4 year lifespan of a device if there _were_ advantages offered by a hyper-specialized chip?
People keep coming up with new metaphors for LLMs to explain their impact and functionality.
Maybe we should think of them like transistors? Right now, we are at the point where we have a room-sized computer that can do multiplication from punch-card input.
It is fun to imagine what we could do if we ran, say, 1 million coordinated o3-pro queries at once.
This is super interesting! We do something similar I think by taking a checkpoint after model initialization. I'm curious what you think about our approach, here's some benchmarks: https://docs.cedana.ai/articles/performance-of-cedanas-gpu-i...
We do some on-the-fly optimizations as well (like compiling into CUDA graphs or fusing calls together), which ends up resulting in faster token throughput for some inference engines too.
A couple of questions for the author(s), since they seem to be very responsive to this thread :).
1. How fine-grained is each task? In a traditional matrix multiplication kernel, for example, each thread block is responsible for a small output tile of the resulting matrix. In Mirage's megakernel, would there correspondingly be a task for each small output tile?
2. How does the Mirage compiler form the task graph? Does it have domain knowledge of every operator's data flow at the granularity of individual elements? Again taking matmul as an example: a given output tile requires the corresponding M_BLOCK rows of the A matrix. If the A matrix was itself the output of a prior matmul (+ nonlinearity), would the dependencies be all of the output-tile tasks of the operator that produced A, covering those M_BLOCK rows?
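To make question 2 concrete, here is roughly the bookkeeping I'm imagining, with made-up names and a two-layer matmul chain (my guess, not Mirage's actual task representation):

    // Chain: A[M,K] = X * W1, then B[M,N] = relu(A) * W2. One task per
    // M_BLOCK x N_BLOCK output tile; a tile of B needs full M_BLOCK rows of A,
    // i.e. every A-tile task in the same row block.
    #include <vector>

    struct Task {
        int layer, row_block, col_block;   // which output tile this task computes
        std::vector<int> deps;             // indices of producer tasks
    };

    std::vector<Task> build_tasks(int M, int K, int N, int M_BLOCK, int N_BLOCK) {
        std::vector<Task> tasks;
        int rows   = (M + M_BLOCK - 1) / M_BLOCK;
        int a_cols = (K + N_BLOCK - 1) / N_BLOCK;   // column tiles of A
        int b_cols = (N + N_BLOCK - 1) / N_BLOCK;   // column tiles of B

        // Layer 0: tiles of A depend only on graph inputs (X, W1).
        for (int r = 0; r < rows; ++r)
            for (int c = 0; c < a_cols; ++c)
                tasks.push_back({0, r, c, {}});

        // Layer 1: tile (r, c) of B depends on all A tiles in row block r.
        for (int r = 0; r < rows; ++r)
            for (int c = 0; c < b_cols; ++c) {
                Task t{1, r, c, {}};
                for (int ac = 0; ac < a_cols; ++ac)
                    t.deps.push_back(r * a_cols + ac);  // layer-0 task indices
                tasks.push_back(t);
            }
        return tasks;
    }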
If you want to try it on a 5090, it's not supported yet:
> Support for modern GPU architectures. One of our next milestones is extending MPK to support next-generation architectures such as NVIDIA Blackwell. A major challenge lies in integrating warp specialization — a key optimization for newer GPUs — with MPK’s megakernel execution model.
Isn’t fusing ops at a fine-grained level also the core benefit of JAX over TensorFlow? How does this work compare to JAX?
How is this possible? I mean, I thought that sometimes you had no choice but to separate computation into several kernels. But here they literally allow CUDA threads to dynamically perform tasks assigned by scheduler threads? I only have a little experience writing CUDA kernels, so my mind is blown.
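From what I can tell, it's the classic persistent-kernel pattern: launch roughly one block per SM and never exit; blocks pull work descriptors from a queue in global memory instead of being created per launch. A stripped-down sketch of the idea (my reading, not MPK's actual scheduler):

    struct TaskDesc {
        int kind;      // which fused operation to run (matmul tile, norm chunk, ...)
        int tile_id;   // which output tile it covers
    };

    __global__ void megakernel(const TaskDesc* tasks, int num_tasks, int* next) {
        __shared__ int my_task;
        while (true) {
            if (threadIdx.x == 0)
                my_task = atomicAdd(next, 1);   // one thread claims the next task
            __syncthreads();
            if (my_task >= num_tasks) return;   // queue drained: this block exits

            TaskDesc t = tasks[my_task];
            switch (t.kind) {                   // the whole block works on one task
                case 0: /* compute one matmul output tile */ break;
                case 1: /* compute one RMSNorm chunk      */ break;
                // a real scheduler would also wait on per-task dependency
                // counters so a task only runs once its inputs are ready
            }
            __syncthreads();                    // done; loop around for more work
        }
    }

The scheduler threads described in the post presumably replace this single atomic counter with logic that also tracks which tasks' dependencies have been satisfied.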
Probably should make this into a backend of torch.compile
Is this approach viable for training? What about kernels that require different grids?
really cool. would love to try it for our 3b model.
any detailed tutorial about how to use it?
ELI5
> Traditional LLM systems often rely on sequences of GPU kernel launches and external communication calls, resulting in underutilized hardware.
What? Why? This seems like an obvious optimization if it's possible.
Ollama integration?
Hi author(s), the on-GPU interpreter approach looks like a promising path forward, have you seen this strikingly similar concurrent work?
https://news.ycombinator.com/item?id=44111673
I find it curious that the fundamentals of the CUDA programming model (e.g., kernel launches) are being subverted in favor of fine-grained task-based parallelism that ends up using the hardware more effectively. Makes me wonder if CUDA has been holding us back in some ways.
What are the chances we see your work land in PyTorch as an experimental backend?
Awesome stuff thanks for sharing.
P.S. minor typo, your first two paragraphs under part 1 are nearly identical.