Mamba: A replacement for transformers in neural networks?
Not to be confused with Mamba, the package manager for Python (intended to be a replacement for Conda), which is something entirely different.
So first, remember that "transformer" is the non-intuitive name for an "attention" mechanism in neural networks. It was originally invented in the context of language translation, where words in the target language can come out in a different order from the words in the source language, so the neural network had to learn to pay "attention" to words in the input sequence out of order.
It works by constructing a neural network that takes "query", "key", and "value" inputs, which is supposed to remind you of doing a lookup in a key-value table.
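To make the key-value-table analogy concrete, here is a minimal sketch of scaled dot-product attention in plain NumPy. The query/key/value names follow the standard formulation; the random projection matrices are just placeholders for learned weights.

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def attention(Q, K, V):
    # Compare every query against every key, scale, and turn the scores
    # into weights that sum to 1 for each query.
    scores = Q @ K.T / np.sqrt(K.shape[-1])
    weights = softmax(scores, axis=-1)
    # Each output is a weighted mix of the values -- a "soft" table lookup.
    return weights @ V

# Toy example: 4 tokens, 8-dimensional embeddings.
rng = np.random.default_rng(0)
x = rng.normal(size=(4, 8))                                # token embeddings
Wq, Wk, Wv = (rng.normal(size=(8, 8)) for _ in range(3))   # stand-ins for learned weights
out = attention(x @ Wq, x @ Wk, x @ Wv)
print(out.shape)  # (4, 8): one output vector per input token
```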
What sort of alternative approach could do the same sort of thing?
Mamba comes from a completely different line of development. In digital signal processing, you have the concept of "linear time invariant" processors. They take inputs, combine them in linear ways, and their behavior doesn't change over time based on prior input.
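A linear time-invariant system can be written as a simple recurrence: a hidden state is updated by fixed matrices that never change, no matter what the input has been. A minimal discrete-time sketch, with A, B, and C chosen arbitrarily for illustration:

```python
import numpy as np

# Discrete-time LTI state space model:
#   x[k+1] = A @ x[k] + B * u[k]
#   y[k]   = C @ x[k]
# A, B, C are fixed -- "time invariant" means they never depend on k or on u.
A = np.array([[0.9, 0.1],
              [0.0, 0.8]])
B = np.array([1.0, 0.5])
C = np.array([1.0, 0.0])

def run_lti(inputs):
    x = np.zeros(2)            # internal state
    outputs = []
    for u in inputs:
        x = A @ x + B * u      # linear update, same rule at every step
        outputs.append(C @ x)
    return outputs

print(run_lti([1.0, 0.0, 0.0, 0.0]))  # impulse response of the system
```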
The next developmental step from that is selective state space models. I had studied two of these, Gated Recurrent Units (GRUs) and Long Short-Term Memory (LSTM) networks, without realizing they were just two examples of a whole family of neural networks. The "state space" part of the name implies that the network maintains internal state, which is what recurrent neural networks (RNNs) do. The "selective" part of the name means that special neurons act as "gates", controlling whether the hidden state gets updated by any particular input or not.
At this point I have to pause and inject some background information. Normally, a neural network is a series of layers: information goes in one end and results come out the other. If you have, say, 5 layers, data goes in at layer 1 and results come out of layer 5. What changes when you go to "recurrent" neural networks is that you add some additional "output", but instead of including it in your results, you loop it around and add it to your "input". This gives the neural network a form of "memory", since it can now send information to itself in the future.
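In code, the "loop it back around" idea looks like this: the hidden state computed at one step is fed back in as an extra input at the next step. A minimal sketch, with made-up weight shapes:

```python
import numpy as np

rng = np.random.default_rng(0)
W_in = rng.normal(size=(16, 8)) * 0.1    # input -> hidden (illustrative sizes)
W_rec = rng.normal(size=(16, 16)) * 0.1  # hidden -> hidden (the "loop around")

def rnn(sequence):
    h = np.zeros(16)                      # the network's "memory"
    for x in sequence:
        # The previous hidden state is combined with the new input,
        # so information can be carried forward to future steps.
        h = np.tanh(W_in @ x + W_rec @ h)
    return h

final_state = rnn(rng.normal(size=(5, 8)))  # 5 time steps of 8-dim input
print(final_state.shape)  # (16,)
```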
To understand how you get from this to GRUs and LSTMs, you add "gates" that decide whether the information that gets looped around gets changed or kept the same. That's the basic idea; there are twists on it, but I think we can skip those here.
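The gating idea fits in a few lines: a gate value between 0 and 1 decides how much of the looped-around state is kept versus overwritten. This is a simplified, GRU-flavored sketch, not the full GRU or LSTM equations:

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def gated_update(h_prev, x, W_gate, W_cand):
    # The gate looks at the current input and previous state and outputs
    # a number between 0 and 1 for each element of the state.
    z = sigmoid(W_gate @ np.concatenate([x, h_prev]))
    # A candidate new state, computed the same way a plain RNN would.
    h_cand = np.tanh(W_cand @ np.concatenate([x, h_prev]))
    # Keep (1 - z) of the old state, take z of the candidate.
    return (1 - z) * h_prev + z * h_cand

rng = np.random.default_rng(0)
W_gate = rng.normal(size=(16, 24)) * 0.1  # 24 = 8 (input) + 16 (state)
W_cand = rng.normal(size=(16, 24)) * 0.1
h = gated_update(np.zeros(16), rng.normal(size=8), W_gate, W_cand)
print(h.shape)  # (16,)
```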
To understand how they got from those to the Mamba system, you have to completely shift gears and start over with "linear time invariant". That's because the Mamba system is built on something called H3. H3 stands for "Hungry Hungry Hippos". I'm not joking, that's really what it's called. (Maybe someone played the game a lot as a kid?) That's because it's built on something called "HiPPO", which in turn stands for "high-order polynomial projection operators". As the name implies, it has something to do with polynomials -- what they were doing was using polynomials to decide what information should be remembered and what should be forgotten in a continuous stream of numerical input -- a problem analogous to the "gating" problem I described above, but for "linear time invariant" processors in digital signal processing. H3 takes HiPPO and adds neural networks to improve the calculations -- essentially it uses neural networks to improve the "gating" function. And it expands the scope beyond digital signal processing to include sequences of tokens, such as text.
The neural networks H3 used for this were convolutional neural networks. It wouldn't make sense to delve into the complete history of convolutional neural networks here -- I think it suffices to say that convolutional networks take a small number of parameters and apply them across all of the input. For example, if you had a convolutional neural network for image recognition and trained it to recognize vertical edges, it wouldn't make sense to recognize edges only in the left corner of the image -- you want it to recognize vertical edges anywhere. So you apply the same parameters across the whole input.
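Weight sharing is the key point: the same small filter slides over every position of the input, so whatever pattern it detects, it detects everywhere. A tiny 1-D example, with a hand-picked edge-detecting filter rather than a learned one:

```python
import numpy as np

signal = np.array([0, 0, 0, 1, 1, 1, 0, 0], dtype=float)
edge_filter = np.array([-1.0, 1.0])   # responds to a step up in the signal

# Apply the *same* two parameters at every position of the input.
response = np.correlate(signal, edge_filter, mode="valid")
print(response)  # spikes where the signal steps up, dips where it steps down
```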
To go from H3 to Mamba, another "selection" mechanism was added. Essentially, Mamba is H3 plus a gated layer (a fully connected layer). (How exactly these two systems interface is unclear to me, but what is clear is that the combination of both the H3 gating system and the Mamba gating system is required for the combined system to be competitive with attention via transformers.) The gated layer is also parallelized for performance.
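I can only guess at the exact wiring, but the general shape described for this kind of gated block is two parallel branches whose outputs are multiplied together: one branch runs the state-space computation, the other is a plain fully connected layer whose output acts as a gate. A purely illustrative sketch, with a stand-in function where the real state-space computation would go:

```python
import numpy as np

def silu(x):
    return x / (1.0 + np.exp(-x))

def gated_block(x, W_ssm_in, W_gate, ssm_fn):
    # Branch 1: project the input and run it through the state-space part
    # (ssm_fn stands in for the H3-style selective state space computation).
    u = ssm_fn(x @ W_ssm_in)
    # Branch 2: a plain fully connected layer whose output acts as a gate.
    g = silu(x @ W_gate)
    # The gate multiplies the state-space branch element by element.
    return u * g

rng = np.random.default_rng(0)
x = rng.normal(size=(4, 8))                  # 4 tokens, 8-dim embeddings
W_ssm_in = rng.normal(size=(8, 16)) * 0.1
W_gate = rng.normal(size=(8, 16)) * 0.1
out = gated_block(x, W_ssm_in, W_gate, ssm_fn=np.tanh)  # tanh as a placeholder SSM
print(out.shape)  # (4, 16)
```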
The researchers took great pains to optimize their algorithm for modern GPUs, paying close attention to when data is moved between GPU high-bandwidth memory (HBM) and SRAM. They designed the algorithm so it can process all the input tokens in parallel, and tested that it can handle up to 1 million input tokens without forgetting anything.
The parallelism is the reason they call their new "selection" mechanism a "scan" -- it's part of a parallel "scan" of the input. They report that this typically results in a 40x speedup compared with regular "attention"-driven models.
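A "scan" here is the same kind of operation as a prefix sum: combining elements left to right with an associative operator, which is what lets the work be split across parallel hardware. A toy illustration of the idea (a plain prefix sum done in a logarithmic number of parallel-friendly passes, not Mamba's actual kernel):

```python
import numpy as np

def parallel_prefix_sum(values):
    # Hillis-Steele style scan: each pass adds in the values from 2**step away.
    # Every element within a pass can be computed independently (i.e. in
    # parallel), so only log2(n) passes are needed instead of a fully
    # sequential left-to-right loop.
    x = np.array(values, dtype=float)
    step = 1
    while step < len(x):
        shifted = np.concatenate([np.zeros(step), x[:-step]])
        x = x + shifted
        step *= 2
    return x

print(parallel_prefix_sum([1, 2, 3, 4, 5]))  # [ 1.  3.  6. 10. 15.]
```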
In addition to language text, they tested it on DNA sequences. It was able to outperform Pythia, RWKV, OPT, and GPT-Neo models of comparable sizes.
What does this mean for you, in terms of what you should expect from large language models? Well, first it has to be spelled out that Mamba is not a drop-in replacement for attention via transformers. It doesn't work the same way as the "query", "key", "value" system of transformers. So existing models like GPT (which, remember, stands for "generative pre-trained transformer") won't switch to it.
Rather, new models that use it will have to be built from scratch. That could take some time, so we probably won't see these right away. But at some point we probably will. The fact that they are vastly faster and can handle such huge context windows will make them appealing.
https://medium.com/@jelkhoury880/what-is-mamba-845987734ffc
Direct link to paper: