A GPT in 60 Lines of NumPy

by squidhunter on 2/9/2023, 4:08 PM with 146 comments

by jaykmody on 2/9/2023, 8:25 PM

Hey y'all, author here!

Thank you for all the nice and constructive comments!

For clarity, this is ONLY the forward pass of the model. There's no training code, batching, kv cache for efficiency, GPU support, etc ...

The goal here was to provide a simple yet complete technical introduction to the GPT as an educational tool. Tried to make the first two sections something any programmer can understand, but yeah, beyond that you're gonna need to know some deep learning.

Btw, I tried to make the implementation as hackable as possible. For example, if you change the import from `import numpy as np` to `import jax.numpy as np`, the code becomes end-to-end differentiable:

    import jax
    import jax.numpy as np  # swapping in JAX's NumPy makes the whole forward pass traceable

    def lm_loss(params, inputs, n_head) -> float:
        x, y = inputs[:-1], inputs[1:]  # shift by one to get (context, next-token) pairs
        output = gpt(x, **params, n_head=n_head)  # [n_seq] -> [n_seq, n_vocab]
        loss = np.mean(-np.log(output[y]))  # (simplified) next-token negative log-likelihood
        return loss

    grads = jax.grad(lm_loss)(params, inputs, n_head)  # gradients w.r.t. every parameter
You can even support batching with `jax.vmap` (https://jax.readthedocs.io/en/latest/_autosummary/jax.vmap.h...):

    gpt2_batched = jax.vmap(gpt2, in_axes=0)
    gpt2_batched(batched_inputs) # [batch, seq_len] -> [batch, seq_len, vocab]
Of course, with JAX comes in-built GPU and even TPU support!

As for training code and a KV cache for inference efficiency, I leave those as an exercise for the reader lol
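For the curious, the bare-bones version of a training step is just the `grads` from above plus a tree-map over the parameters. Illustrative sketch only: arbitrary learning rate, no batching, no real optimizer:

    lr = 1e-4  # arbitrary; a real setup would use something like Adam
    params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)

Loop that over chunks of training tokens and you have (extremely naive) training.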

by simonw on 2/9/2023, 6:23 PM

This article is an absolutely fantastic introduction to GPT models - I think the clearest I've seen anywhere, at least for the first section that talks about generating text and sampling.

Then it got to the training section, which starts "We train a GPT like any other neural network, using gradient descent with respect to some loss function".

It's still good from that point on, but it's not as valuable as a beginner's introduction.

by barbazoo on 2/9/2023, 5:40 PM

So much criticism in the comments. I appreciated the write-up and the code samples. For people not in ML like myself, it's hard to understand the concepts behind GPT, and this made them a little bit clearer.

by lspears on 2/9/2023, 6:03 PM

For those interested, I would also check out Andrej Karpathy's YouTube video on building GPT from scratch:

https://youtu.be/kCc8FmEb1nY

by ultrasounder on 2/9/2023, 8:39 PM

I also learnt a ton from NLP Demystified (https://www.nlpdemystified.org). In fact, I used this resource first before attempting Andrej Karpathy's https://karpathy.ai/zero-to-hero.html. I find Nitin's voice soothing and am able to focus more. I also found the pacing good; the course introduces a lot of concepts at a beginner level and points to appropriate resources along the way (spaCy, for instance). Overall, an exciting time to be a total beginner looking to grok NLP concepts.

by adamnemecek on 2/9/2023, 5:02 PM

It turns out that transformers have a learning mechanism similar to autodiff but better, since it happens mostly within individual layers as opposed to over the whole graph. I recently wrote a paper on this: https://arxiv.org/abs/2302.01834v1. The math is crazy.

by eddsh1994 on 2/9/2023, 5:11 PM

Why do people in ML put imports inside function definitions?
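E.g. the pattern I mean, where the import lives inside the function instead of at the top of the file (names purely illustrative):

    def download_model(url, dest):
        # the import only runs if/when this function is called
        import requests
        with open(dest, "wb") as f:
            f.write(requests.get(url).content)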

by teaearlgraycold on 2/10/2023, 1:49 AM

Reminds me of the scene from Westworld where they explain that their failed prototypes of the human mind ran to millions of lines of code. The version that finally worked was only a few dozen.

by qwerty456127 on 2/10/2023, 9:38 AM

How powerful/heavy is it? Some time ago there was a post here about implementing a GPT on a very constrained computer (under a gigabyte of RAM, an old CPU, no GPU?) as opposed to an ordinary kind of GPT requiring terabytes of RAM.

I immediately thought it would be nice to do something in the middle: taking full advantage of a reasonably modern multicore CPU with AVX support, a humble yet likewise reasonably modern OpenCL-capable GPU, and some 32 gigabytes of RAM.

by lvwarren on 2/11/2023, 8:04 AM

Make this change in utils.py:

  def load_gpt2_params_from_tf_ckpt(tf_ckpt_path, hparams):
      [...]
      # str.removeprefix() was added in Python 3.9; slicing works on 3.8:
      # name = name.removeprefix("model/")
      name = name[len('model/'):]
and your cool example will run in Google Colab under Python 3.8; otherwise the 3.9 Jupyter patching is a headache.

by est on 2/11/2023, 6:02 AM

> GPT-3 was trained on 300 billion tokens of text from the internet and books:

> GPT-3 is 175 billion parameters

Total newbie here. What do these two numbers mean?

If we run a huge amount of text through BPE, do we get an array of length 300B?

What's the number if we de-dup these tokens (i.e. the vocab size)?

Does 175B parameters mean there are 175B (somewhat useful) floats in the pre-trained neural network?
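In other words, is my mental model roughly this toy sketch (tiny numbers standing in for the real ones)?

    import numpy as np

    corpus_token_ids = np.array([5, 17, 5, 42, 99, 17, 5])  # BPE output over the whole corpus
    n_training_tokens = len(corpus_token_ids)                # the "300 billion tokens" count?
    vocab_size = len(np.unique(corpus_token_ids))            # de-duped token ids, i.e. the vocab?
    weights = [np.zeros((4, 4)), np.zeros(4)]                # stand-in for the model's weight arrays
    n_params = sum(w.size for w in weights)                  # the "175 billion parameters" count?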

by eslaught on 2/9/2023, 9:04 PM

I know this probably isn't intended for performance, but it would be fun to run this in cuNumeric [1] and see how it scales.

[1]: https://github.com/nv-legate/cunumeric
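If it really is a drop-in NumPy replacement, the change would presumably be as small as the JAX swap the author describes (untested guess):

    # import numpy as np
    import cunumeric as np  # rest of the code unchanged, in principle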

by voz_ on 2/9/2023, 7:01 PM

Wonderfully written, I love the amount of detail put into the diagrams. Would love breakdowns like this for more stuff :)

by durdn on 2/10/2023, 1:07 PM

Very impressive. I recently watched this really amazing lecture by Karpathy on building GPT from scratch and was blown away: https://www.youtube.com/watch?v=kCc8FmEb1nY&t=642s

by est on 2/10/2023, 3:31 AM

If I maintain an open source project, could I build a doc page using a small GPT allowing users to query FAQ and common methods using natural language?

by thomasfromcdnjs on 2/9/2023, 6:50 PM

This reads really well, thank you very much.

by lvwarren on 2/11/2023, 7:57 AM

Make this change and it will run under Python 3.8 in Google Colab:

        #name = name.removeprefix("model/")
        name = name[len('model/'):]
in the function load_gpt2_params_from_tf_ckpt in the utils.py module.

by sva_ on 2/9/2023, 4:38 PM

Impressive, but only forward pass.

by insane_dreamer on 2/9/2023, 7:11 PM

Nice and clear. A worthy contribution to the subject.

by terran57 on 2/9/2023, 5:13 PM

From the article:

"Of course, you need a sufficiently large model to be able to learn from all this data, which is why GPT-3 is 175 billion parameters and probably cost between $1m-10m in compute cost to train.[2]"

So perhaps a better title would be "GPT in 60 Lines of NumPy (and $1m-$10m)".

by eric_hui on 2/13/2023, 8:00 AM

Fantastic article about GPT. Thank you for sharing.

by freecodyx on 2/9/2023, 5:43 PM

Since most models require little code compared to big software projects, why not use C++ or another compiled language directly? Python, with its magic functions and shortcuts, just hides too much complexity, which can lead to bugs and performance issues. Plus, the code is harder to maintain.