A GPT in 60 Lines of NumPy (jaykmody.com)
One counterpoint would be that verbosity, especially in the heavy syntax style of languages such as C++, distracts the reader and helps bugs hide in plain sight. For a silly example, imagine trying to read and verify the correctness of an academic paper from its uncompiled LaTeX source.
Thank you for all the nice and constructive comments!
For clarity, this is ONLY the forward pass of the model. There's no training code, batching, KV cache for efficiency, GPU support, etc.
The goal here was to provide a simple yet complete technical introduction to the GPT as an educational tool. Tried to make the first two sections something any programmer can understand, but yeah, beyond that you're gonna need to know some deep learning.
Btw, I tried to make the implementation as hackable as possible. For example, if you change the import from `import numpy as np` to `import jax.numpy as np`, the code becomes end-to-end differentiable:
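  import jax
  import jax.numpy as np

  def lm_loss(params, inputs, n_head) -> float:
      x, y = inputs[:-1], inputs[1:]  # inputs: 1-D array of token ids; y[i] is the token after x[i]
      logits = gpt2(x, **params, n_head=n_head)  # [len(x), n_vocab], unnormalized logits
      log_probs = jax.nn.log_softmax(logits, axis=-1)
      loss = np.mean(-log_probs[np.arange(len(y)), y])  # log-prob of each true next token
      return loss

  grads = jax.grad(lm_loss)(params, inputs, n_head)  # same pytree structure as params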
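For readers without JAX handy, here is a NumPy-only sketch of the same next-token loss on a hand-checkable input. The helper names `log_softmax` and `next_token_loss` are illustrative, not from the post; the point is that the log-probabilities have to be indexed by (position, token) pairs, since `logits[targets]` alone would select whole rows.

```python
import numpy as np

def log_softmax(logits):
    # Numerically stable log-softmax over the vocabulary (last) axis.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))

def next_token_loss(logits, targets):
    # logits: [seq_len, n_vocab]; targets: [seq_len] token ids.
    # Pair row i with targets[i] to get the log-prob of each true next token.
    log_probs = log_softmax(logits)
    return -log_probs[np.arange(len(targets)), targets].mean()

# Uniform logits over a 4-token vocab: every token has probability 1/4.
logits = np.zeros((3, 4))
targets = np.array([2, 0, 3])
loss = next_token_loss(logits, targets)
print(loss)  # log(4), about 1.3863
```

With uniform predictions the loss is exactly log(vocab size), a handy sanity check for an untrained model.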
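You can even support batching with `jax.vmap` (https://jax.readthedocs.io/en/latest/_autosummary/jax.vmap.h...):
  # close over params and n_head so vmap maps only over the batch axis of the token ids
  gpt2_batched = jax.vmap(lambda ids: gpt2(ids, **params, n_head=n_head))
  gpt2_batched(batched_inputs)  # [batch, seq_len] -> [batch, seq_len, vocab]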
Of course, with JAX comes built-in GPU and even TPU support! As for training code and a KV cache for inference efficiency, I leave those as an exercise for the reader lol
Music to my ears, well done and don't worry too much about the negative comments! They'll come out for anything you do I think.
I saw a tweet from someone the other day about how they massively increased their training speed by changing part of their architecture to have dimensions that were a multiple of 64 rather than some prime-like number.
One of the comments below it? ~"Seems very architecture specific."
lol.
So don't sweat it! <3 Great work and thanks for putting yourself out there, super job! :D :D :D :D :)))))) <3 :D :D :fireworks:
If you haven't tried cuNumeric [1], you really ought to. It's a drop-in NumPy wrapper for distributed GPU acceleration. Would be interesting to see if it works for this.
I'd be curious how that library compares to other numeric python GPU libraries
Neat, but please add one-line comments/docstrings where these missing bits would go.
I want to commend you for one of the best-written introductions in this space that I've seen, especially the excellent use of hyperlinks that point to really good resources exactly at the right time!
Then it got to the training section, which starts "We train a GPT like any other neural network, using gradient descent with respect to some loss function".
It's still good from that point on, but it's not as valuable as a beginner's introduction.
I thought it was a great post, and many kudos to the author for putting themselves out there like that! I really appreciated this. Any work that puts this kind of effort into onboarding people and giving them tools to understand something well has, I think, some of the most long-term impact on the field.
Lowering barriers to entry, making resources accessible to all, and decreasing experimentation cycle time I think are some of the most critical components to making any progress at all in the field beyond a basic pittance. Imagine if everyone had easy access to, knowledge about, and rapid experimentation results in things like quantum mechanics, large-algorithm testing, painting arts, musical arts, etc. It would drive things so much further forward at an individual and field-based level so quickly. <3 :)))) :D :D ;D :D :D :))))))))
"Bartender! A half-pint of your finest Combinatorial Hopf, if you please!"
But this is the same for all residual streams, not just those in transformers.
Join my discord to discuss this further https://discord.gg/mr9TAhpyBW
I immediately thought it would be nice to do something in the middle: taking full advantage of a reasonably modern multicore CPU with AVX support, a humble yet again reasonably modern OpenCL-capable GPU and some 32 Gigabytes of RAM.
  def load_gpt2_params_from_tf_ckpt(tf_ckpt_path, hparams):
      [...]
      #name = name.removeprefix("model/")  # str.removeprefix requires Python 3.9+
      name = name[len('model/'):]  # 3.8-compatible; assumes every name starts with "model/"
and your cool example will run in Google Colab under Python 3.8; otherwise, the Python 3.9 Jupyter patching is a headache.
> GPT-3 is 175 billion parameters
Total newbie here. What do these two numbers mean?
If we run a huge number of texts through BPE, do we get an array of length 300B?
What's the number if we de-dup these tokens? (The size of the vocab?)
Does 175B parameters mean there are 175B somewhat useful floats in the pre-trained neural network?
Number of params is the number of weights. Basically the number of learnable variables.
Number of tokens is how many tokens it saw during training.
Vocab size is the number of distinct tokens.
The relationship between params, tokens, and compute, and how it affects model performance, has been studied a good deal: https://arxiv.org/pdf/2203.15556.pdf
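To make the three quantities in the answer above concrete, here is a small sketch with toy, hypothetical numbers. Note that the vocab is fixed by the tokenizer (50,257 tokens for GPT-2's BPE), so de-duplicating the training tokens gives at most that many. The compute estimate uses the common ~6 × params × tokens approximation and GPT-3's reported ~300B training tokens; the 20-tokens-per-parameter figure is only a rough reading of the linked Chinchilla paper, not an exact law.

```python
import numpy as np

# Hypothetical toy model: the parameter count is just the total number of floats.
params = {"w1": np.zeros((8, 16)), "b1": np.zeros(16), "w2": np.zeros((16, 4))}
n_params = sum(p.size for p in params.values())  # 8*16 + 16 + 16*4 = 208

# Hypothetical tokenized training text: token count vs. distinct tokens.
token_ids = [5, 3, 5, 9, 3, 5]
n_tokens = len(token_ids)         # tokens seen during training
n_distinct = len(set(token_ids))  # distinct tokens, bounded by the tokenizer's vocab size

def training_flops(n_params, n_tokens):
    # Rule of thumb: ~6 FLOPs per parameter per training token
    # (~2 for the forward pass, ~4 for the backward pass).
    return 6 * n_params * n_tokens

gpt3_flops = training_flops(175e9, 300e9)  # GPT-3: 175B params, ~300B training tokens
chinchilla_tokens = 20 * 175e9             # rough heuristic: ~20 tokens per parameter
print(n_params, n_tokens, n_distinct)      # 208 6 3
print(f"{gpt3_flops:.2e} FLOPs")           # 3.15e+23 FLOPs
print(f"{chinchilla_tokens:.1e} tokens")   # 3.5e+12 tokens
```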
  #name = name.removeprefix("model/")
  name = name[len('model/'):]
in the function `load_gpt2_params_from_tf_ckpt` in the utils.py module.
"Of course, you need a sufficiently large model to be able to learn from all this data, which is why GPT-3 is 175 billion parameters and probably cost between $1m-10m in compute cost to train.[2]"
So perhaps a better title would be "GPT in 60 Lines of Numpy (and $1m-$10m)"
It doesn't go into the math but I don't think that's a bad thing for beginners.
If you want the math, 3blue1brown has a great series of videos [3] on the topic.
[1] https://www.youtube.com/watch?v=hBBOjCiFcuo&t=1932s
[2] https://github.com/fastai/fastbook/blob/master/04_mnist_basi...
[3] https://www.youtube.com/watch?v=aircAruvnKk
* I've been messing around with this stuff since 2016 and have done a few different courses like the original Andrew Ng course and more.
[1] https://youtu.be/Wo5dMEP_BbI?list=PLQVvvaa0QuDcjD5BAw2DxE6OF... https://nnfs.io
Didn’t do the full course, but after the first few chapters I was able to write a very basic implementation in raw Python (emphasis on “very basic”).
After all, there was a disclaimer up front in the blog post that you might have missed! "This post assumes familiarity with Python, NumPy, and some basic experience training neural networks." So it is in there! But with the firehose of info we all get, it's not that hard to miss.
However, I'm here to help! Thankfully the concept is not too terribly difficult, I believe.
Effectively, the loss function compresses the task we've described with our labels from our training dataset into our neural network. This includes (ideally, at least), 'all' the information the neural network needs to perform that task well, according to the data we have, at least. If you'd like to know more about the specifics of this, I'd refer you to the original Shannon-Weaver paper on information theory -- Weaver's introduction to the topic is in plain English and accessible to (I believe) nearly anyone off of the street with enough time and energy to think through and parse some of the concepts. Very good stuff! An initial read-through should take no more than half an hour to an hour or so, and should change the way you think about the world if you've not been introduced to the topic before. You can read a scan of the book at a university hosted link here: https://raley.english.ucsb.edu/wp-content/Engl800/Shannon-We...
Using some of the concepts of Shannon's theory, we can see that anything that minimizes an information-theoretic loss function should also learn the prerequisites to the task at hand (features that identify xyz, features that move information about xyz from place A to B in the neural network, etc.). In this case, even though it appears we do not have labels, we certainly do! We are training on predicting the _next words_ in a sequence, so humans have already created a very, _very_ richly labeled dataset for free! This makes getting the data much easier and the barrier to entry for high performance from a neural network very low, especially if we want to pivot and 'fine-tune' to other tasks. This is because, to learn the task of predicting the next word, we have to learn tons of other sub-tasks inside the neural network, which overlap with the tasks that we want to perform. And because of the nature of spoken/written language, to truly perform incredibly well we sometimes have to learn all of these alternative tasks well enough that little to no fine-tuning on human-labeled data for this 'secondary' task (for example, question answering) is required! Very cool stuff.
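Here is a minimal sketch of the "labels for free" point: every position in raw text yields a (context, next-token) training example, with no human annotation. It uses whitespace-split words as a stand-in for real BPE tokens, and `next_token_pairs` and `context_size` are illustrative names, not from the post.

```python
def next_token_pairs(tokens, context_size):
    # Each position i gives one supervised example: the preceding tokens
    # (truncated to context_size) are the input, tokens[i] is the label.
    pairs = []
    for i in range(1, len(tokens)):
        context = tokens[max(0, i - context_size):i]
        pairs.append((context, tokens[i]))
    return pairs

tokens = "the cat sat on the mat".split()
pairs = next_token_pairs(tokens, context_size=3)
for context, target in pairs:
    print(context, "->", target)
# ['the'] -> cat
# ['the', 'cat'] -> sat
# ['the', 'cat', 'sat'] -> on
# ['cat', 'sat', 'on'] -> the
# ['sat', 'on', 'the'] -> mat
```

A six-token sentence already gives five labeled examples; a web-scale corpus gives billions.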
This is a very rough introduction; I have not condensed it as much as it could be, and there are certainly more words here than there should be. But it's an internet comment, so this is probably the most I should put into it for now. I hope this helps set you forward a bit on your journey of neural network explanation! :D :D <3 <3 :)))))))))) :fireworks:
For reference, I'm very much interested in what I refer to as Kolmogorov-minimal explanations (look up 'Kolmogorov complexity' on Wikipedia once you chew through some of that paper, if you're interested! I am still very much a student of it, but it is a fun explanation). In fact (though this repo performs several functions), I made https://github.com/tysam-code/hlb-CIFAR10 as beginner-friendly as possible. One does have to make some decisions to keep verbosity down, and I assume a very basic understanding of what's happening in neural networks here too.
I have yet to find a good go-to explanation of neural networks as a conceptual intro (I started with Hinton: love the man, but extremely mathematically technical for a foundation! D:). Karpathy might have a really good one; I think I saw a zero-to-hero course from him a little while back that seemed really good.
Andrej (practically) got me into deep learning via some of his earlier work, and I really love basically everything that I've seen the man put out. I skimmed the first video of his from this series and it seems pretty darn good, I trust his content. You should take a look! (Github and first video: https://github.com/karpathy/nn-zero-to-hero, https://youtu.be/VMj-3S1tku0)
For reference, he is the person that's made a lot of cool things recently, including his own minimal GPT (https://github.com/karpathy/minGPT), and the much smaller version of it (https://github.com/karpathy/nanoGPT). But of course, since we are in this blog post I would refer you to this 60 line numpy GPT first (A. to keep us on track, B. because I skimmed it and it seemed very helpful! I'd recommend taking a look at outside sources if you're feeling particularly voracious in expanding your knowledge here.)
I hope this helps give you a solid introduction to the basics of this concept, and/or for anyone else reading this, feel free to let me know if you have any technically (or-otherwise) appropriate questions here, many thanks and much love! <3 <3 <3 <3 :DDDDDDDD :)))))))) :)))) :))))
  In [1]: %time import transformers
  CPU times: user 3.21 s, sys: 7.8 s, total: 11 s
  Wall time: 1.91 s
1) For demonstrative purposes. The title of the post is `A GPT in 60 Lines of NumPy`, so I kinda wanted to show "hey, it's just NumPy, nothing to be scared of!" Also, if an import is ONLY used in a single function, I find it visually helps show that "hey, this import is only used in this function", whereas when it's at the top of the file you're not really sure when, where, and how many times an import is used.
2) Scoping. `load_encoder_hparams_and_params` imports tensorflow, which is really slow to import. When I was testing, I used randomly initialized weights instead of loading the checkpoint (which is slower), so I was only making use of the `gpt2` function. If I had kept the import at the top level, it would've slowed things down unnecessarily.
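A minimal sketch of that scoping pattern, assuming a module where only the checkpoint loader needs TensorFlow. `load_params_from_checkpoint` and `random_params` are hypothetical names, not picoGPT's functions; the point is that the module imports and the NumPy-only path runs even when TensorFlow is slow to import or not installed at all.

```python
import numpy as np

def load_params_from_checkpoint(path):
    # Deferred import: tensorflow is only loaded (and only needs to be
    # installed) if this function is actually called.
    import tensorflow as tf
    return {name: tf.train.load_variable(path, name)
            for name, _ in tf.train.list_variables(path)}

def random_params(n_vocab, n_embd, seed=0):
    # Randomly initialized weights for quick testing; needs only NumPy.
    rng = np.random.default_rng(seed)
    return {"wte": rng.standard_normal((n_vocab, n_embd)) * 0.02}

params = random_params(n_vocab=10, n_embd=4)
print(params["wte"].shape)  # (10, 4)
```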
Here - they put `import fire` only in the `if __name__ == "__main__":` - that seems reasonable to me as anyone pulling in the library from elsewhere doesn't need the pollution.
This makes even more sense for a non-standard-library dependency like fire, because you won't even need it if you're going to import the module and write your own interface instead.
The import in main doesn't seem particularly useful in context on a quick read, but considering the line
> utils.py contains the code to download and load the GPT-2 model weights, tokenizer, and hyper-parameters.
it seems possible that some downloads are happening on import, so it does make sense to defer until actually needed, as suggested in sibling comments.
I've seen a lot of Python people sprinkle imports all over the place in their code. I suspect this is a bad habit learned from too much time working in notebooks where you often have an "oh right, I need XXX library now" and just import it as you need it.
The aggressive aliasing I do get since in DS/ML work it's very common to have the same function do slightly different things depending on the library (standard deviation between numpy and pandas is a good example)
But I personally like all of my imports at the top so I know what the code I'm about to read is going to be doing. I do seem to be in the minority on this (and would be glad to be corrected if I'm making some major error).
Of course, "don't do circular imports". But if my Orders model has OrderLines, and my OrderLines point to their Order, it's damn hard to avoid without putting everything in one huge file.
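The standard deviation example is easy to check: NumPy's `np.std` defaults to the population formula (`ddof=0`), while pandas' `Series.std` defaults to the sample formula (`ddof=1`). A minimal sketch using only NumPy, passing `ddof=1` to reproduce the pandas default:

```python
import numpy as np

x = np.array([1.0, 2.0, 3.0, 4.0])

# NumPy default: population standard deviation (divide by n, ddof=0).
pop_std = np.std(x)
# pandas' Series.std default: sample standard deviation (divide by n - 1, ddof=1).
sample_std = np.std(x, ddof=1)
print(round(pop_std, 4), round(sample_std, 4))  # 1.118 1.291
```

Same name, same data, different answers, which is one reason explicit aliases like `np.` and `pd.` matter in this kind of code.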
1) Circular dependencies (and you don't want your house of cards falling down if your IDE/isort decides to reorder a few things); 2) (slow/expensive) expressions that are evaluated on import; 3) startup time required for the module loader to resolve everything at start.
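A minimal sketch of point 1, using two hypothetical modules (`orders`, `order_lines`) written to a temp directory. With `from ... import ...` at the top of both files, importing either would fail with a partially-initialized-module `ImportError`; deferring each import into the function that needs it breaks the cycle.

```python
import sys
import tempfile
from pathlib import Path

ORDERS_SRC = """
class Order:
    def __init__(self):
        self.lines = []

    def add_line(self, sku):
        from order_lines import OrderLine  # deferred import breaks the cycle
        line = OrderLine(self, sku)
        self.lines.append(line)
        return line
"""

ORDER_LINES_SRC = """
class OrderLine:
    def __init__(self, order, sku):
        self.order = order
        self.sku = sku

    def is_in(self, order):
        from orders import Order  # deferred as well
        return isinstance(order, Order) and order is self.order
"""

with tempfile.TemporaryDirectory() as tmp:
    Path(tmp, "orders.py").write_text(ORDERS_SRC)
    Path(tmp, "order_lines.py").write_text(ORDER_LINES_SRC)
    sys.path.insert(0, tmp)
    try:
        import orders
        order = orders.Order()
        line = order.add_line("SKU-1")  # triggers the deferred import of order_lines
        print(line.is_in(order))        # True
    finally:
        sys.path.remove(tmp)
```

The trade-off is the one raised in the thread: the dependency is no longer visible at the top of the file.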
Only Big Tech giants like Microsoft, Google, etc. can afford to foot the bill and throw away millions on training LLMs, whilst we celebrate and hype ChatGPT and LLMs getting bigger and significantly more expensive to train, even though they get confused, hallucinate over silly inputs, and confidently generate bullshit.
That can't be a good thing. OpenAI's ClosedAI model needs to be disrupted, like how Stable Diffusion challenged DALL-E 2 with an open source AI model.
Based on that group's success, they've recently proposed a mini project inspired by GPT that I am considering funding; the data it's trained on is all publicly available for free, and most of it comes from Common Crawl. I suspect that it will also yield similar results, where you can tailor your own version of GPT and get reasonably good models for a fraction of the price as well. We're nowhere close to the scale of Big Tech giants, but I've noticed for the better part of 15 years that small companies can actually derive a great deal of the benefits that larger companies have for a fraction of the cost if they play it smart and keep things tight.
That said, other organizations that can afford to foot the bill for it are the governments. This is hardly ideal, since such models will also come with plenty of strings attached - indeed, probably more than the private ones - but at least these policies are somewhat checked by democratic mechanisms.
Long-term I think the demand for more AI compute power will lead to much more investment in GPU design and manufacture, driving the prices down. Since the underlying tech itself is well-understood, I fully expect to see the day when one can train and run a customized GPT-3 instance for one's private use, although the major players will likely be far ahead by then.
1. https://youtu.be/rDke29MbKQA?list=PLyrlk8Xaylp7NvZ1r-eTIUHdy...
How large is the model on disk(s) once it is trained?
This one doesn't use any frameworks. The next book by the author (on GANs) uses PyTorch. The math is relatively easy to follow I think.
Andrew Ng's courses on Coursera can be viewed for free and have slightly more rigorous math, but are still okay.
You don't have to understand every mathematical detail, same as you don't need every mathematical detail for 3d graphics. But knowing the basics should be good I think!
"Best practices" are incredibly unevenly distributed, and I suspect this is only more true for data/ML-heavy python code.
On one hand, you can explain it to a 5-year-old: Go in the direction which improves things.
On the other hand, we have more than a half-century of research on sophisticated mathematical methods for doing it well.
The latter isn't really helpful for beginners, and the former is easy to explain. Beginners can't use the sophisticated algorithms anyway, so you can go with something as dumb as tweaking in all directions and moving wherever things improve most. It will work fine for dummy examples.
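The "tweak in all directions" idea can be sketched in a few lines. This is a toy illustration on a hypothetical two-parameter loss, not how real networks are trained: each step costs two loss evaluations per parameter, which is exactly why real training uses gradients computed by backprop instead.

```python
def f(p):
    x, y = p
    return (x - 3) ** 2 + (y + 1) ** 2  # toy loss, minimum at (3, -1)

def tweak_search(f, p, step=0.1, iters=200):
    # Try +/- step on every coordinate, move to the best neighbor,
    # and stop as soon as no neighbor improves the loss.
    for _ in range(iters):
        candidates = []
        for i in range(len(p)):
            for delta in (step, -step):
                q = list(p)
                q[i] += delta
                candidates.append(q)
        best = min(candidates, key=f)
        if f(best) >= f(p):
            break
        p = best
    return p

p = tweak_search(f, [0.0, 0.0])
print([round(v, 2) for v in p])  # [3.0, -1.0]
```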
"Autodidax: JAX core from scratch" walks you through it in detail.
JAX is able to differentiate arbitrary python code (so long as it uses JAX for the numeric stuff) automatically so the backprop is abstracted away.
If you have the forward model written, to train it all you have to do is wrap it in whatever loss function you want, then use JAX's `grad` with respect to the model parameters, and you can use that to find the optimum with your favorite gradient optimization algorithm.
This is why JAX is so awesome. Differentiable programming means you only have to think about problems in terms of the forward pass and then you can trivially get the derivative of that function without having to worry about the implementation details.
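A minimal sketch of that workflow on a toy linear model (not a GPT): write only the forward pass and the loss, let `jax.grad` produce the gradient with respect to the first argument, and plug it into plain gradient descent. The learning rate and step count are arbitrary choices for this toy data.

```python
import jax
import jax.numpy as jnp

def predict(params, x):
    w, b = params
    return w * x + b  # the "forward pass": a toy linear model

def loss(params, x, y):
    return jnp.mean((predict(params, x) - y) ** 2)  # mean squared error

grad_loss = jax.grad(loss)  # differentiates w.r.t. the first argument (params)

x = jnp.array([0.0, 1.0, 2.0, 3.0])
y = 2.0 * x + 1.0  # data generated with w=2, b=1
params = (0.0, 0.0)
lr = 0.1
for _ in range(500):
    gw, gb = grad_loss(params, x, y)  # gradient has the same (w, b) structure
    params = (params[0] - lr * gw, params[1] - lr * gb)

print(params)  # approximately (2.0, 1.0)
```

No derivative was written by hand; swapping the model or the loss needs no changes to the training loop.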
GPT-3's weights alone are roughly 350GB in half precision (175B parameters × 2 bytes each). You need a really, really big machine to run LLMs; the first L stands for Large.
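For the on-disk size question asked earlier in the thread, a back-of-the-envelope estimate is parameters × bytes per parameter. This minimal sketch ignores optimizer state, activations, and file-format overhead, so real checkpoints can be somewhat larger.

```python
def weights_size_gb(n_params: float, bytes_per_param: int) -> float:
    # Raw weight storage only, in decimal gigabytes.
    return n_params * bytes_per_param / 1e9

for dtype, nbytes in [("float32", 4), ("float16/bfloat16", 2), ("int8", 1)]:
    print(f"{dtype}: {weights_size_gb(175e9, nbytes):.0f} GB")
# float32: 700 GB
# float16/bfloat16: 350 GB
# int8: 175 GB
```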
Edit: ok, Discord it is.
Can you share any good link on the subject?
This isn't something that should matter even a little in typical ML code. But in generic Python libraries, there are cases when this kind of micro-optimization can help. Similar tricks include turning methods into pre-bound attributes in __init__ to skip all the descriptor machinery on every call.
(It's much faster if you implement the callback in native code, but then that doesn't work on IronPython, Jython etc.)
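A minimal sketch of the pre-binding trick, on a hypothetical `Handler` class. Plain methods are non-data descriptors, so once the bound method is stored in the instance `__dict__`, attribute lookup finds it there and no new bound-method object is created per call. One side effect: the instance and its bound method now reference each other, so the object is freed by the cycle collector rather than by refcounting alone.

```python
class Handler:
    def __init__(self):
        self.count = 0
        # Store the bound method on the instance; it now shadows the
        # class-level function during attribute lookup.
        self.handle = self.handle

    def handle(self, event):
        self.count += 1

h = Handler()
for event in range(3):
    h.handle(event)
print(h.count)                # 3
print(h.handle is h.handle)   # True: the same stored bound method every time
```

Without the assignment in `__init__`, `h.handle is h.handle` is `False`, because each access builds a fresh bound method.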
It is absolutely worthwhile to avoid unnecessary imports if possible.
In some (many?) cases it's probably premature optimization, but it doesn't hurt, so I don't see why anyone would get up in arms over it.
For example, the AI doesn't have enough information about a company's process, or a regulation. It chats with an expert to fill in the gaps.
I have no understanding of AI
This is how the new Bing Assistant works. It's also how search engines like https://you.com/ and https://www.perplexity.ai/ work - as exposed by a prompt leak attack against Perplexity a few weeks ago: https://simonwillison.net/2023/Jan/22/perplexityai/
I wrote a tutorial about one way of implementing this pattern yourself here: https://simonwillison.net/2023/Jan/13/semantic-search-answer...
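The general retrieve-then-prompt pattern can be sketched without any external services. This is not the tutorial's implementation: a bag-of-words vector stands in for a real embedding model, the documents are hypothetical, and the resulting prompt would still need to be sent to an LLM.

```python
import math
from collections import Counter

def embed(text):
    # Stand-in for a real embedding model: a bag-of-words term-count vector.
    return Counter(text.lower().split())

def cosine(a, b):
    dot = sum(a[t] * b[t] for t in a)
    norm = math.sqrt(sum(v * v for v in a.values())) * math.sqrt(sum(v * v for v in b.values()))
    return dot / norm if norm else 0.0

def build_prompt(question, documents, k=1):
    # Retrieve the k most similar documents and paste them into the prompt.
    q = embed(question)
    ranked = sorted(documents, key=lambda d: cosine(q, embed(d)), reverse=True)
    context = "\n".join(ranked[:k])
    return f"Answer the question using only this context:\n{context}\n\nQuestion: {question}\nAnswer:"

docs = [
    "refunds are processed within 14 days of the return",
    "our office is closed on public holidays",
]
prompt = build_prompt("how many days until refunds are processed", docs)
print(prompt)
```

The retrieval happens entirely outside the model, which is the trade-off discussed in the reply below.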
It’s a small difference, perhaps, but one with some significance, since retrieval and incorporation occurring outside the model come with a different set of trade-offs. I’m not specifically aware of any work where model architectures are being extended to perform this function directly, but I am keen to learn of such efforts.
It’s trained on completing the text.
If an expert writes a long text and you add "In summary:" at the end, the model will complete it with something approximating the truth (depending on the size of the model, training, etc.).
Humans do a similar thing. We have a model in our heads of the subject discussed and can summarize it, but we will forget some parts, make errors, etc. GPT is very similar.
*pun intended :)
Side effects in imports are, in my opinion, unnecessary: they cost you some of the benefits of static analysis, make it harder to run with different parameters during tests or to compile to native code (if those tools exist), slow things down, and more.
Libraries could have an initializer function and the problem would go away.
The momentum experiment they made also does not seem related. E.g. it just adds past values to V, which extends the effective context length.
Such is the nature of early theories.