Fully-recurrent neural network
The basic concept of a recurrent neural network (RNN) is that each token (usually a word or word piece) in our sequence feeds forward into the representation of the next one. We start with the embedding for our first token, t0. For the next token, t1, we take some function (defined by the weights our neural network learns) of the two embeddings, f(t0, t1). Each new token then combines with the running representation of the sequence so far (so the next step is f(f(t0, t1), t2)), until we reach the final token, whose representation stands in for the whole sequence. The simplest version of this architecture is the fully-recurrent neural network (FRNN).
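To make the recurrence concrete, here is a minimal sketch in NumPy. The tanh nonlinearity and the two weight matrices `W_h` and `W_x` are my own illustrative choices for f, not anything prescribed above; a real RNN learns those weights during training.

```python
import numpy as np

def frnn_encode(token_embeddings, W_h, W_x):
    """Fold a sequence of token embeddings into a single vector, one token at a time.

    token_embeddings: array of shape (seq_len, d), one embedding per token.
    W_h, W_x: weight matrices of shape (d, d) that a real network would learn.
    """
    hidden = token_embeddings[0]                       # start from the first token, t0
    for token in token_embeddings[1:]:                 # each later token folds into the running state
        hidden = np.tanh(W_h @ hidden + W_x @ token)   # f(state so far, current token)
    return hidden                                      # the final state represents the whole sequence

# Toy usage: 5 tokens with 8-dimensional embeddings and random (untrained) weights.
rng = np.random.default_rng(0)
tokens = rng.normal(size=(5, 8))
W_h, W_x = rng.normal(size=(8, 8)), rng.normal(size=(8, 8))
sequence_vector = frnn_encode(tokens, W_h, W_x)
```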
This architecture has issues with vanishing gradients that limit the neural network training process. Remember, training a neural network works by making small updates to model parameters based on a loss function that expresses how close the model's prediction for a training item is to the true value. When the gradient for an early parameter has to pass back through a long chain of weights with magnitudes below one, it quickly approaches zero. The parameter's impact on the loss function becomes negligible, as do any updates to its value.
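A quick back-of-the-envelope illustration of why that happens: by the chain rule, the gradient that reaches an early token is a product of one factor per step back through the sequence, and if those factors are smaller than one the product collapses toward zero. The 0.5 below is just an assumed stand-in for a typical per-step factor.

```python
# Rough size of the gradient reaching the first token after flowing back through
# n recurrent steps, assuming each step contributes a factor of about 0.5.
per_step_factor = 0.5
for n in (5, 20, 50):
    print(n, per_step_factor ** n)
# 5  -> 0.03125
# 20 -> ~9.5e-07
# 50 -> ~8.9e-16  (effectively zero, so updates to early parameters vanish)
```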
This is a big problem for the long-distance relationships common in text. Consider the sentence "The dog that I adopted from the pound five years ago won the local pet competition." It's important to understand that it's the dog that won the competition, even though "dog" and "won" are nowhere near each other in the sequence.
Long short-term memory
The long short-term memory (LSTM) architecture addresses this vanishing gradient problem. The LSTM adds a long-term memory cell that passes information forward stably alongside the recurrent hidden state, while a set of learned gates moves information into and out of that cell.
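Here is one step of that gating machinery, sketched in NumPy with the standard LSTM equations; the variable names and the single stacked weight matrix are my own conveniences, not part of the description above.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_step(x, h, c, W, b):
    """One LSTM step: x is the current token embedding, h the short-term hidden
    state, c the long-term memory cell. W maps [h; x] to four gate pre-activations."""
    z = W @ np.concatenate([h, x]) + b
    f, i, o, g = np.split(z, 4)
    forget = sigmoid(f)                      # how much of the old memory to keep
    write = sigmoid(i)                       # how much new information to let in
    read = sigmoid(o)                        # how much of the memory to expose
    candidate = np.tanh(g)                   # the new information itself
    c_new = forget * c + write * candidate   # the memory cell rolls forward largely intact
    h_new = read * np.tanh(c_new)            # the hidden state reads from the cell
    return h_new, c_new

# Toy usage over a 5-token sequence of 8-dimensional embeddings.
d = 8
rng = np.random.default_rng(1)
W, b = rng.normal(size=(4 * d, 2 * d)), np.zeros(4 * d)
h, c = np.zeros(d), np.zeros(d)
for x in rng.normal(size=(5, d)):            # note: still strictly one token at a time
    h, c = lstm_step(x, h, c, W, b)
```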
Remember, though, that in the machine learning world a larger training set is almost always better. The fact that the LSTM has to calculate a value for each token sequentially before it can start on the next is a big bottleneck: these operations are impossible to parallelize, which limits how much data we can realistically push through the model.
Transformer
The transformer architecture, which is at the heart of the current generation of LLMs, is an evolution of the LSTM concept. Not only does it better capture the context and dependencies between words in a sequence, but it can run in parallel on the GPU with highly optimized tensor operations.
The transformer uses an attention mechanism to weigh the influence of each token in the sequence on every other token. Along with the value embedding of each token, the attention mechanism learns two more vectors for each token: a query vector and a key vector. How close a token's query vector is to another token's key vector determines how much of the second token's value gets added to the first.
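In code, "how close the query is to the key" is just a dot product, normalized with a softmax so that each token's incoming contributions sum to one. This is a minimal single-head sketch under my own naming; real transformers learn the projection matrices and run many attention heads in parallel.

```python
import numpy as np

def attention(X, W_q, W_k, W_v):
    """Single-head self-attention. X is (seq_len, d); the W matrices project each
    token's embedding into its query, key, and value vectors."""
    Q, K, V = X @ W_q, X @ W_k, X @ W_v
    scores = Q @ K.T / np.sqrt(K.shape[-1])          # query-key closeness for every pair of tokens
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)   # softmax: each row sums to 1
    return weights @ V                               # each token becomes a weighted mix of values

rng = np.random.default_rng(2)
X = rng.normal(size=(5, 8))                          # 5 tokens, 8-dimensional embeddings
W_q, W_k, W_v = (rng.normal(size=(8, 8)) for _ in range(3))
updated = attention(X, W_q, W_k, W_v)                # (5, 8): every token informed by every other
```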
Because we’ve loosened up the sequence bottleneck, we can afford to stack up multiple layers of attention—at each layer, the attention contributes a little meaning to each token from the others in the sequence before moving on to the next layer with the updated values.
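Stacking is then just a loop: each layer's output becomes the next layer's input. The sketch below reuses the `attention` function above and adds the residual "contribute a little meaning" step; real transformer layers also include normalization and a feed-forward sublayer, which I've left out for brevity.

```python
def encode(X, layers):
    """layers is a list of (W_q, W_k, W_v) triples, one per attention layer."""
    for W_q, W_k, W_v in layers:
        X = X + attention(X, W_q, W_k, W_v)   # each layer nudges every token's value a little
    return X
```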
If you've followed along well enough so far that we can cobble together a spatial intuition for this attention mechanism, I'll consider this article a success. Let's give it a try.
A token's value vector captures its semantic meaning in a high-dimensional embedding space, much like in our library analogy from earlier. The attention mechanism uses another embedding space for the key and query vectors, a sort of semantic plumbing in the floor between each level of the library. The key vector positions the output end of a pipe that draws some semantic value from the token and pumps it out into that embedding space. The query vector places the input end of a pipe that sucks up the semantic value that other tokens' key vectors pump into the space nearby and feeds all of it into the token's new representation on the floor above.
To capture an embedding for the full sequence, we just pick one of these tokens, grab its value vector, and use that in downstream tasks. (Exactly which token this is depends on the specific model. Masked models like BERT use a special [CLS] or [MASK] token, while the autoregressive GPT models use the last token in the sequence.)
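Continuing the sketches above, pulling a sequence embedding out of the encoded stack is just an indexing choice; which index is right depends on the model family, as noted.

```python
layers = [tuple(rng.normal(size=(8, 8)) for _ in range(3)) for _ in range(2)]  # two toy layers
encoded = encode(X, layers)             # X and encode() come from the earlier sketches
bert_style = encoded[0]                 # masked models: the special token at position 0
gpt_style = encoded[-1]                 # autoregressive models: the last token in the sequence
```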
So the transformer architecture can encode sequences really well, but if we want it to understand language well, how do we train it? Remember, when we start training, all these vectors are randomly initialized. Our tokens’ value vectors are distributed at random in their semantic embedding space as are our key and query vectors in theirs. We ask the model to predict a token given the rest of the encoded sequence. The great thing about this task is that we can gather as much text as we can find and turn it into training data. All we have to do is hide one of the tokens in a chunk of text from the model and encode what’s left. We already know what the missing token should be, so we can build a loss function based on how close the prediction is to this known value.
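Here's a toy version of that forward pass, with everything scaled down: the "encoder" is just an average of the remaining embeddings standing in for the transformer, and the vocabulary size, projection, and token IDs are made up for illustration. The point is the shape of the objective: hide a token, predict a distribution over the vocabulary, and take the cross-entropy against the token we hid.

```python
import numpy as np

rng = np.random.default_rng(3)
vocab_size, d = 100, 8

embedding_table = rng.normal(size=(vocab_size, d))      # one value vector per vocabulary item
output_projection = rng.normal(size=(d, vocab_size))    # maps an encoding back to vocabulary scores

token_ids = np.array([12, 7, 55, 3, 42])                # a tokenized chunk of text
hidden_position = 2                                     # hide the third token from the model
target_id = token_ids[hidden_position]                  # ...but we still know what it was

context = np.delete(token_ids, hidden_position)
context_vector = embedding_table[context].mean(axis=0)  # stand-in for encoding what's left

logits = context_vector @ output_projection             # a score for every vocabulary item
probs = np.exp(logits - logits.max())
probs /= probs.sum()                                    # softmax over the vocabulary
loss = -np.log(probs[target_id])                        # small when the model guesses right
```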
The other beautiful thing is that the difficulty of predicting the right word scales up smoothly. It goes from a general sense of topicality and word order—something even a simple predictive text model on your phone can do pretty well—up through complex syntax and semantics.
The incredible thing here is that as we scale up the number of parameters in these models—things like the size of the embeddings and number of transformer layers—and scale up the size of the training data, the models just keep getting better and smarter.
| Property \ Model | RNN | LSTM | Transformer |
| --- | --- | --- | --- |
| Memory | Short-term (hidden state only) | Long-term (thanks to the gating mechanism) | Long-term and complex |
| Trade-offs | Can't process data in parallel; vanishing gradient problem | More complex | Parallel processing; captures long-term dependencies |