Model Architecture
DALL-E is an extension of the game-changing GPT-3 model released by the OpenAI team in 2020. GPT-3 has a staggering 175 billion parameters and was trained on hundreds of billions of words.
From the preprint released by the team, we know that the backbone of this network is a simple decoder-only transformer that receives the text and image as a single stream of data. Since their introduction, transformers have been revolutionary in the field of Natural Language Processing.
However, training the network is a challenge in itself. The difficulty arises because, when modeling an image pixel by pixel, the transformer tends to focus on short-range dependencies between pixels rather than the low-frequency structure. It is this low-frequency structure that makes an image recognizable to the human eye.
To overcome this challenge, the researchers used a two-phase training process that aims to maximize the evidence lower bound (ELB) on the joint likelihood of the model's distribution over images, captions, and the tokens of the encoded image.
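Written out, the bound from the preprint takes roughly the following form, where x denotes images, y captions, and z the image tokens; q_φ is the dVAE encoder, p_θ the dVAE decoder, and p_ψ the transformer prior over text and image tokens:

$$\ln p_{\theta,\psi}(x, y) \;\ge\; \mathbb{E}_{z \sim q_\phi(z \mid x)}\Big[\ln p_\theta(x \mid y, z) \;-\; \beta\, D_{\mathrm{KL}}\big(q_\phi(y, z \mid x)\,\|\, p_\psi(y, z)\big)\Big]$$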
First training phase:
In the first training phase, the researchers trained a discrete variational auto-encoder (dVAE) that is responsible for generating the image tokens. Using every pixel of a 256×256 image directly is impractical because it would blow up the model's size. Instead, the auto-encoder compresses each image into a 32×32 grid, where each element of the grid can take one of 8192 possible values. This grid is then used as the image token data, as sketched below.
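To make the idea concrete, here is a toy PyTorch sketch (not the repository's actual DiscreteVAE; the layer sizes are made-up assumptions) of how an encoder can turn a 256×256 image into a 32×32 grid of discrete token ids drawn from an 8192-entry codebook:

```python
import torch
import torch.nn as nn

# Toy sketch of the dVAE encoder idea: downsample a 256x256 RGB image by a
# factor of 8 to a 32x32 grid, then pick one of 8192 codebook entries per cell.
NUM_TOKENS = 8192  # size of the visual codebook

encoder = nn.Sequential(
    nn.Conv2d(3, 64, 4, stride=2, padding=1),     # 256 -> 128
    nn.ReLU(),
    nn.Conv2d(64, 128, 4, stride=2, padding=1),   # 128 -> 64
    nn.ReLU(),
    nn.Conv2d(128, 256, 4, stride=2, padding=1),  # 64 -> 32
    nn.ReLU(),
    nn.Conv2d(256, NUM_TOKENS, 1),                # per-cell logits over the codebook
)

image = torch.randn(1, 3, 256, 256)               # a dummy image batch
logits = encoder(image)                           # (1, 8192, 32, 32)
image_tokens = logits.argmax(dim=1).flatten(1)    # (1, 1024) discrete token ids
print(image_tokens.shape)                         # torch.Size([1, 1024])
```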
Second training phase:
This phase is where the huge 12-billion-parameter transformer model comes into play. The 32×32 = 1024 image tokens are concatenated with 256 text tokens, where the image tokens are obtained by argmax sampling of the dVAE encoder's output logits. An autoregressive transformer is then trained to model the joint distribution over the combined token stream.
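Conceptually, this phase looks something like the sketch below, where `transformer` stands in for the 12-billion-parameter model and the vocabulary sizes and offsets are illustrative assumptions:

```python
import torch
import torch.nn.functional as F

# Illustrative only: a combined text+image token stream modeled autoregressively.
text_tokens  = torch.randint(0, 10000, (1, 256))    # 256 text token ids (toy vocab)
image_tokens = torch.randint(0, 8192, (1, 1024))    # 1024 dVAE token ids

# offset image ids so text and image tokens live in one shared vocabulary
tokens = torch.cat([text_tokens, image_tokens + 10000], dim=1)   # (1, 1280)

# next-token prediction: each position predicts the token that follows it
inputs, targets = tokens[:, :-1], tokens[:, 1:]
# logits = transformer(inputs)                       # (1, 1279, vocab_size)
# loss = F.cross_entropy(logits.transpose(1, 2), targets)
```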
After the model is trained, it is used to generate images from captions. For a given caption, the model is asked to generate several candidate images at once. An additional CLIP model is then used to rank these candidates and keep the ones that best match the caption.
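A CLIP-style reranking step can be sketched as a simple cosine-similarity ranking between one caption embedding and many image embeddings; the embeddings below are random placeholders, not the output of the real CLIP model:

```python
import torch
import torch.nn.functional as F

# Sketch of CLIP-style reranking: score each generated image against the
# caption by cosine similarity of their embeddings and keep the top matches.
def rerank(text_embedding, image_embeddings, top_k=8):
    text_embedding = F.normalize(text_embedding, dim=-1)      # (1, d)
    image_embeddings = F.normalize(image_embeddings, dim=-1)  # (n, d)
    scores = image_embeddings @ text_embedding.t()            # (n, 1)
    return scores.squeeze(1).topk(top_k).indices              # indices of best images

# toy usage with random "embeddings"
best = rerank(torch.randn(1, 512), torch.randn(64, 512), top_k=8)
print(best)
```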
Set up your own model
OpenAI has not yet open-sourced its code base for DALL-E because the team is still studying the social implications of the new model. Others, however, have attempted to replicate the work and have open-sourced their code. The replica achieves most of what OpenAI's model does, but because it is much smaller, it does not produce results as realistic as the original DALL-E.
You can find the open-source implementation of DALL-E here. In the following section, we will discuss the code of the DALL-E model and how the various important functions work. You can read more about the discrete VAE here.
DALL-E Model Code
The author of this work has implemented the DALL-E model in PyTorch. If you are not familiar with PyTorch, then you may check out its documentation.
- The constructor (__init__ function)
Like any other constructor, the __init__ function of the DALLE class is responsible for initializing and assigning the essential parameters of the model. First, it checks whether the vae argument is an instance of the DiscreteVAE class. Then it assigns embedding layers to the image and text tokens.
Second, it adds up the number of tokens in the image and text vocabularies and initializes a Transformer object whose output size equals this combined number of tokens, as in the sketch below.
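A minimal sketch of such a constructor might look like this; the class name, keyword arguments, and the commented-out DiscreteVAE and Transformer references are illustrative assumptions rather than the repository's exact signatures:

```python
import torch
from torch import nn

class DALLESketch(nn.Module):
    def __init__(self, vae, num_text_tokens=10000, text_seq_len=256, dim=512):
        super().__init__()
        # assert isinstance(vae, DiscreteVAE)        # sanity-check the vae argument
        self.vae = vae
        self.text_seq_len = text_seq_len
        num_image_tokens = 8192                      # dVAE codebook size
        image_seq_len = 32 * 32                      # 1024 image positions

        # separate embedding tables for text and image tokens
        self.text_emb = nn.Embedding(num_text_tokens, dim)
        self.image_emb = nn.Embedding(num_image_tokens, dim)

        # the transformer predicts over the combined text + image vocabulary
        self.total_tokens = num_text_tokens + num_image_tokens
        self.total_seq_len = text_seq_len + image_seq_len
        # self.transformer = Transformer(dim=dim, seq_len=self.total_seq_len)
        self.to_logits = nn.Linear(dim, self.total_tokens)
```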
- Generating images
This member function is decorated with torch.no_grad() to disable gradient computation, since the model's weights do not need to be updated while generating images. The function is used at test time, when images are generated from captions.
It first slices the text data to the maximum sequence length that can be fed into the DALL-E model, which in this case equals text_seq_len.
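In outline, the entry point looks something like this fragment (the attribute names are assumed from the constructor sketch above, and the body is elided):

```python
import torch

# Illustrative fragment: gradient-free generation entry point that trims the
# caption to the model's maximum text length.
@torch.no_grad()
def generate_images(self, text, img=None):
    text = text[:, :self.text_seq_len]   # keep at most text_seq_len tokens
    ...
```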
Next, it checks whether an image has been fed into the model to be used as a baseline. Earlier in this blog, I showed you the results for a "teapot with a heart". To generate that example, the researchers fed the model a black mug so that it would have an idea of what they were expecting. This img argument is optional and can be used to produce more realistic images if needed.
If there is any such image, the dVAE is called and the image tokens are appended to the text tokens.
The final step is feeding the token data through the forward pass of the DALL-E model inside a for loop, sampling one new image token at a time until the full sequence is produced. The image token data is then sliced from the model output and fed into the dVAE's decoder, which returns a decoded image from the given token data.
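A conceptual version of that sampling loop, with `model`, `vae.decode`, and the greedy sampling strategy standing in for the actual implementation, could look like this:

```python
import torch

# Conceptual sampling loop (not the repository's exact code): starting from the
# text tokens, repeatedly run the model forward, pick the next image token,
# and append it until all 1024 image positions are filled.
def sample_image_tokens(model, text_tokens, image_seq_len=1024):
    tokens = text_tokens
    for _ in range(image_seq_len):
        logits = model(tokens)                        # (batch, seq, vocab)
        next_token = logits[:, -1].argmax(dim=-1)     # greedy pick for simplicity
        tokens = torch.cat([tokens, next_token.unsqueeze(1)], dim=1)
    image_tokens = tokens[:, -image_seq_len:]         # slice off the image part
    return image_tokens                               # pass these to vae.decode(...)
```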
Optionally, CLIP is used to rank the generated images by how well they match the caption.
- Forward pass of the model
A forward pass of the DALL-E model takes text, image, mask, and return_loss as arguments. Note that only the text argument is compulsory, because at test time we do not necessarily feed an image into the model.
If an image has been provided as an argument, we run the pretrained dVAE on it to generate the encoded image tokens, which are then concatenated with the text tokens.
Once the combined token sequence is ready, we feed it into the transformer object (the one we initialized in the __init__ function), which models the data autoregressively.
If we are not training, return_loss is set to false, and the function simply returns the logits output by the transformer. During training, return_loss is true, and the function computes and returns the total loss for one forward pass, roughly as in the sketch below.
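Putting the pieces together, a condensed sketch of such a forward pass might read as follows; the attribute and method names (vae.get_image_tokens, self.transformer, self.to_logits) are assumptions for illustration, not the repository's exact API:

```python
import torch
import torch.nn.functional as F

def forward(self, text, image=None, mask=None, return_loss=False):
    tokens = text                                            # (batch, text_seq_len)
    if image is not None:
        image_tokens = self.vae.get_image_tokens(image)      # encode image with the dVAE
        tokens = torch.cat([text, image_tokens], dim=1)      # one combined token stream

    hidden = self.transformer(tokens, mask=mask)             # autoregressive transformer
    logits = self.to_logits(hidden)                          # (batch, seq, total_tokens)

    if not return_loss:
        return logits                                        # used at generation time

    # next-token prediction loss over the combined sequence
    labels = tokens[:, 1:]
    return F.cross_entropy(logits[:, :-1].transpose(1, 2), labels)
```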