Stable Diffusion Principle Introduction and Source Code Analysis (1)


Foreword (unrelated to the main text; feel free to skip)

Stable Diffusion is an open-source text-to-image diffusion model from Stability AI. In an earlier article, Diffusion Model Brief Introduction and Source Code Analysis, I introduced the principle of the diffusion model and some of its algorithm code, and after satisfying my basic curiosity I shelved the topic. I did not expect AIGC to develop much faster than I anticipated; in particular, after generating the AI image below, Stable Diffusion finally returned to my field of vision:

As an algorithm engineer, you need a pair of eyes that can see through to the essence of things. What first attracted me to this picture was not its content but the quality of the generation: the image is high-definition and rich in detail, far beyond the rough toys I had seen before. Of course, the incongruities marked with red boxes are still flaws. So further analysis of the principles behind the whole Stable Diffusion engineering framework became imperative. I look forward to fixing the incongruities in the red boxes someday and making a due contribution to the further development of AIGC.

Overview

The source code of the whole Stable Diffusion framework runs to tens of thousands of lines, so there is no need to analyze all of it. This article takes "text to image" as its main thread and examines how Stable Diffusion runs together with each of its important component modules. The introduction takes a general-to-specific form: first a summary of the overall framework, then an analysis of each component (such as DDPM, DDIM, etc.), plus my views on some non-mainstream logic in the code, such as predict_cids and return_ids. The article is long and will be split into multiple parts.

Source address: Stable Diffusion

Notes

I have written many code-analysis articles before, but when I ran into a problem and read them again, I found it was still hard to quickly locate the target spot and accurately grasp the intent of the code at first sight, because the excerpts included too many implementation details, which lowers the efficiency of conveying information.

After some thought, I stopped trying to save effort and decided to use pseudo-code to record the core principles. I usually use this approach when analyzing code in depth; abstracting the code takes some time, but I think it pays off. For example, writing the forward diffusion code of the DDPM model as pseudo-code gives roughly the following effect:
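Here is a minimal sketch of that pseudo-code in PyTorch form (a simplified illustration rather than the repo's exact code; the name q_sample and the linear beta schedule follow common DDPM implementations). It shows the closed-form forward diffusion x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps, with eps drawn from a standard Gaussian:

```python
import torch

T = 1000
betas = torch.linspace(1e-4, 2e-2, T)                # linear noise schedule beta_t
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)   # alpha_bar_t = prod(1 - beta_s)

def q_sample(x0, t, noise=None):
    """Diffuse a clean latent x0 to timestep t in a single closed-form step."""
    if noise is None:
        noise = torch.randn_like(x0)
    a_bar = alphas_cumprod[t].view(-1, 1, 1, 1)      # broadcast over [B, C, H, W]
    return a_bar.sqrt() * x0 + (1.0 - a_bar).sqrt() * noise

# During training the model is asked to predict `noise` from (x_t, t).
x0 = torch.randn(2, 4, 64, 64)        # a batch of clean latents
t = torch.randint(0, T, (2,))         # a random timestep per sample
x_t = q_sample(x0, t)
```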

As you can see, once the irrelevant implementation details are stripped away, the implementation of DDPM turns out to be this simple. Combined with a few annotations, it can be understood easily and quickly and gives an overall, comprehensive sense of command over the code. In addition, I will add more block diagrams, model diagrams, and so on, to show the implementation details of the code more intuitively.

You can search for "Jenny's Algorithm Road" or "world4458" on WeChat and follow my official account to get the latest original technical articles in time.

You can also take a look at my Zhihu column PoorMemory-Machine Learning; future articles will be published there as well.

Stable Diffusion Overall Framework

First, look at the overall framework of Stable Diffusion text-to-image generation (drawing this figure nearly made me cough up blood... I hope AI can assist with this one day):

There are many modules in the framework above, divided into three blocks from top to bottom, which I have marked Part 1, 2, and 3 in the figure. The framework covers two stages, training and sampling, where:

  • The training phase (see Part 1 and Part 2 in the figure) mainly includes:

    1. Use the AutoEncoderKL autoencoder to map the image Image from pixel space to latent space and learn a latent representation of the image. Note that AutoEncoderKL is trained in advance and its parameters are fixed here. In this step the Image is converted from size [B, C, H, W] to [B, Z, H/8, W/8], where Z is the number of channels of the image in latent space. In the Stable Diffusion code this step is called encode_first_stage;
    2. Use the FrozenCLIPEmbedder text encoder to encode the Prompt into an embedding representation of size [B, K, E] (i.e. context), where K is the maximum text encoding length (max length) and E is the embedding size. In the Stable Diffusion code this step is called get_learned_conditioning;
    3. Carry out the forward diffusion process (Diffusion Process), continuously adding noise to the latent representation of the image. This step calls UNetModel; UNetModel receives both the latent image and the text embedding context, and during training uses the Attention mechanism with context as the condition to better learn the matching relationship between text and image;
    4. Compute the error between the noise ϵ_θ output by the diffusion model and the real noise as the Loss, and update the parameters of UNetModel through backpropagation. Note that the parameters of AutoEncoderKL and FrozenCLIPEmbedder are not updated during this process.
  • The sampling phase (see Part 2 and Part 3 in the figure), i.e. after loading the model parameters, we input a prompt and get an image. It mainly includes the following steps (a code sketch of this flow is given right after this list):

    1. Use the FrozenCLIPEmbedder text encoder to encode the Prompt into an embedding representation of size [B, K, E] (i.e. context);
    2. Randomly generate noise of size [B, Z, H/8, W/8], then use the trained UNetModel and iterate T times following DDPM/DDIM/PLMS or another sampling algorithm, continuously removing noise to recover the latent representation of the image;
    3. Use AutoEncoderKL to decode the latent representation of the image (of size [B, Z, H/8, W/8]) and finally recover the pixel-space image of size [B, C, H, W]; in the Stable Diffusion code this step is called decode_first_stage.
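To make the sampling flow concrete, here is a rough sketch of how the three steps above fit together in code. It assumes `model` is the LatentDiffusion object from the Stable Diffusion repo and `sampler` is a DDIM/PLMS sampler; the exact argument names of `sampler.sample` may differ slightly from the repo, so treat this as an outline rather than a drop-in script.

```python
import torch

@torch.no_grad()
def text_to_image(model, sampler, prompt, steps=50, shape=(4, 64, 64)):
    # 1. Encode the prompt into a text embedding `context` of shape [B, K, E]
    context = model.get_learned_conditioning([prompt])

    # 2. Start from Gaussian noise of shape [B, Z, H/8, W/8] and iteratively
    #    denoise it with the trained UNetModel (the sampler wraps the
    #    DDPM/DDIM/PLMS iteration loop)
    latents, _ = sampler.sample(S=steps, conditioning=context,
                                batch_size=1, shape=shape, verbose=False)

    # 3. Decode the denoised latent back to pixel space with AutoEncoderKL
    images = model.decode_first_stage(latents)       # [B, C, H, W] in [-1, 1]
    return ((images + 1.0) / 2.0).clamp(0.0, 1.0)    # map to [0, 1] for saving
```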

After the introduction above, you should have a clearer overall picture of Stable Diffusion. Next, you can follow the diagram and do your best to understand each key module. Limited by personal energy and spare time, apart from FrozenCLIPEmbedder and the DPM algorithm (not drawn in the figure), I have roughly gone through the other modules of Stable Diffusion, including:

  • UNetModel
  • AutoEncoderKL & VQModelInterface (also a variational autoencoder, not pictured)
  • DDPM, DDIM, PLMS algorithms

I will briefly introduce them later and record the learning process.

Important Papers

While reading the code, I found that some heavyweight papers simply must be read. The theoretical derivation of the diffusion model is still somewhat involved, and combining the formula derivations with the code implementation can deepen understanding. Here is a list of the papers that helped me most while reading the code:

Analysis of important components

The following is a brief analysis of the important components in Stable Diffusion, mainly including:

  • UNetModel
  • DDPM, DDIM, PLMS algorithms
  • AutoEncoderKL
  • Some non-mainstream logic in the code, such as predict_cids and return_ids

First introduce the UNetModel structure, so that subsequent articles can be directly quoted.

Introduction to UNetModel

Having drawn the UNetModel used in Stable Diffusion, I will not analyze its code line by line here; it is easy to write the code just by looking at the picture. Stable Diffusion uses UNetModel's Encoder-Decoder structure to implement the diffusion process and predict the noise. The network structure is as follows:

The input to the model consists of three parts:

  • An image of size [B, C, H, W]; note that the symbols used for the sizes do not matter in themselves and should be treated as an interface: for example, when UNetModel receives a noisy latent image of size [B, Z, H/8, W/8] as input, C here equals Z, H equals H/8, and W equals W/8;
  • timesteps of size [B,];
  • The text embedding representation context of size [B, K, E], where K is the maximum encoding length and E is the embedding size.

The model uses DownSample and UpSample modules to downsample and upsample the feature maps, and the modules that appear most often are ResBlock and SpatialTransformer. Each ResBlock in the figure receives the input from the previous module together with the embedding timestep_emb corresponding to the timesteps (of size [B, 4*M], where M is a configurable parameter); each SpatialTransformer in the figure receives the input from the previous module together with context (the embedding representation of the Prompt text), and uses Cross Attention with context as the condition to learn the matching relationship between the Prompt and the image. (For clarity, the figure only marks the extra inputs of two modules with dotted boxes; the other modules are not drawn.)

It can be seen that the final output of the model has the same size [B, C, H, W] as the input, which means that UNetModel does not change the size between its input and output.
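As a quick sanity check of these interfaces, the snippet below sketches how the repo's UNetModel could be instantiated and called; the hyperparameters are recalled from the v1 inference config and may need adjustment, so treat this as an illustration of the input/output shapes rather than a verified configuration.

```python
import torch
from ldm.modules.diffusionmodules.openaimodel import UNetModel  # path in the Stable Diffusion repo

unet = UNetModel(
    image_size=32, in_channels=4, out_channels=4, model_channels=320,
    attention_resolutions=[4, 2, 1], num_res_blocks=2, channel_mult=[1, 2, 4, 4],
    num_heads=8, use_spatial_transformer=True, transformer_depth=1, context_dim=768,
)

x = torch.randn(1, 4, 64, 64)       # [B, Z, H/8, W/8] noisy latent image
t = torch.randint(0, 1000, (1,))    # [B] diffusion timesteps
context = torch.randn(1, 77, 768)   # [B, K, E] text embedding
eps = unet(x, t, context)           # predicted noise, same shape as x
print(eps.shape)                    # torch.Size([1, 4, 64, 64])
```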

Let's look at the implementations of ResBlock, timestep_embedding, context, and SpatialTransformer in turn.

Implementation of ResBlock

The ResBlock network structure diagram is as follows; it accepts two inputs, the image x and the embedding corresponding to the timestep:
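Since the structure is simple, here is a stripped-down re-implementation of the idea (my own sketch: the real ResBlock in the repo additionally supports up/down-sampling, dropout, gradient checkpointing, and scale-shift normalization, which are omitted here):

```python
import torch
import torch.nn as nn

class ResBlock(nn.Module):
    """Simplified ResBlock: a conv path on the image x, with the timestep
    embedding injected in the middle, plus a skip connection."""
    def __init__(self, channels, emb_channels, out_channels=None):
        super().__init__()
        out_channels = out_channels or channels
        self.in_layers = nn.Sequential(
            nn.GroupNorm(32, channels), nn.SiLU(),
            nn.Conv2d(channels, out_channels, 3, padding=1))
        self.emb_layers = nn.Sequential(
            nn.SiLU(), nn.Linear(emb_channels, out_channels))
        self.out_layers = nn.Sequential(
            nn.GroupNorm(32, out_channels), nn.SiLU(),
            nn.Conv2d(out_channels, out_channels, 3, padding=1))
        self.skip = (nn.Identity() if out_channels == channels
                     else nn.Conv2d(channels, out_channels, 1))

    def forward(self, x, emb):
        h = self.in_layers(x)
        h = h + self.emb_layers(emb)[:, :, None, None]  # add timestep embedding per channel
        h = self.out_layers(h)
        return self.skip(x) + h                         # residual connection

block = ResBlock(channels=320, emb_channels=1280)
out = block(torch.randn(1, 320, 64, 64), torch.randn(1, 1280))  # -> [1, 320, 64, 64]
```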

timestep_embedding implementation

timestep_embedding is generated as follows, using the method from the Transformer paper (Attention Is All You Need):
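Below is a sketch of this sinusoidal embedding (it mirrors the standard implementation: half of the dimensions use cosine, the other half sine, with geometrically spaced frequencies):

```python
import math
import torch

def timestep_embedding(timesteps, dim, max_period=10000):
    """Sinusoidal timestep embedding in the style of 'Attention Is All You Need'.
    timesteps: 1-D tensor of shape [B]; returns a tensor of shape [B, dim]."""
    half = dim // 2
    freqs = torch.exp(
        -math.log(max_period) * torch.arange(half, dtype=torch.float32) / half)
    args = timesteps[:, None].float() * freqs[None]               # [B, half]
    emb = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)   # [B, 2*half]
    if dim % 2 == 1:                                              # zero-pad if dim is odd
        emb = torch.cat([emb, torch.zeros_like(emb[:, :1])], dim=-1)
    return emb

emb = timestep_embedding(torch.tensor([0, 10, 999]), dim=320)     # [3, 320]
```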

Implementation of Prompt text embedding

That is, the implementation of context. The Prompt is encoded with the CLIP model. I have not studied the CLIP model in detail and do not plan to read it in depth for now; I will fill this in later if I get the chance. The code uses a pre-trained CLIP to generate context:
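The sketch below shows the general idea using the Hugging Face transformers CLIP classes, which is what FrozenCLIPEmbedder wraps; the checkpoint name and the 77-token max length follow the common v1 setup and are assumptions here, not verbatim repo code.

```python
import torch
from transformers import CLIPTokenizer, CLIPTextModel

version = "openai/clip-vit-large-patch14"     # assumed checkpoint for the v1 models
tokenizer = CLIPTokenizer.from_pretrained(version)
text_model = CLIPTextModel.from_pretrained(version).eval()

@torch.no_grad()
def encode_prompt(prompts, max_length=77):
    """Turn a list of prompts into the `context` tensor of shape [B, K, E]."""
    tokens = tokenizer(prompts, truncation=True, max_length=max_length,
                       padding="max_length", return_tensors="pt")
    outputs = text_model(input_ids=tokens["input_ids"])
    return outputs.last_hidden_state          # [B, 77, 768] for this checkpoint

context = encode_prompt(["a white horse grazing in the desert"])
```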

Implementation of SpatialTransformer

Finally, look at the implementation of SpatialTransformer, which contains quite a few modules. While receiving the image as input, it also uses the context text as conditioning information, and the two are modeled with Cross Attention. Expanding SpatialTransformer further, we find that BasicTransformerBlock actually calls the Cross Attention module, and in the Cross Attention module the image information serves as the Query while the text information serves as the Key & Value. The model attends to the correlation between each part of the image and the text:
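To make the Query/Key/Value roles concrete, here is a single-head toy version of the Cross Attention idea (the repo's CrossAttention is multi-head and more heavily optimized; this is only an illustration of the data flow):

```python
import torch
import torch.nn as nn

class CrossAttention(nn.Module):
    """Single-head cross attention: image features are the Query,
    text features (context) are the Key and Value."""
    def __init__(self, query_dim, context_dim, inner_dim=None):
        super().__init__()
        inner_dim = inner_dim or query_dim
        self.scale = inner_dim ** -0.5
        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_out = nn.Linear(inner_dim, query_dim)

    def forward(self, x, context=None):
        context = x if context is None else context     # no context -> self-attention
        q, k, v = self.to_q(x), self.to_k(context), self.to_v(context)
        attn = torch.softmax(q @ k.transpose(-2, -1) * self.scale, dim=-1)  # [B, N, K]
        return self.to_out(attn @ v)                     # back to [B, N, query_dim]

# N = (H/8)*(W/8) flattened image tokens attend over K text tokens.
x = torch.randn(1, 64 * 64, 320)             # image features as Query
context = torch.randn(1, 77, 768)            # text embedding as Key & Value
out = CrossAttention(320, 768)(x, context)   # -> [1, 4096, 320]
```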

I think a simple example helps to understand the role of Cross Attention here. Suppose during training we have a picture of a horse grazing and the text prompt "A white horse grazing in the desert". When doing Attention, the keyword "horse" in the text is the most relevant to the animal (also a "horse") in the image, so its attention weight is larger, while words such as "a", "white", "desert", and "grass" get lower weights. Once the model is well trained, it will not only learn the matching relationship between images and text, but also learn through Attention which keywords in the text should be highlighted as the main subject of the image.

When we then input a prompt and use the model to generate an image, say "a horse is grazing", the model has already learned to capture the correlation between image and text as well as the key information in the text. When it sees the word "horse", under the operation of black-box magic it will emphasize generating the "horse" in the image; when it sees "grass", it will emphasize generating the "grass", and thus produce an image that matches the text as closely as possible.

So far, the basic introduction of each important component of UNetModel has been completed.

Summary

Since the structure of the UNetModel is not complicated, you can basically write the code just by looking at the picture; a picture is worth a thousand words. In addition, I marked the output size of each module in the figure, which makes it very convenient to run the model in your head, hahaha.

This article briefly introduced the overall framework of the Stable Diffusion text-to-image code, listed some core papers on diffusion models, and briefly analyzed UNetModel. Follow-up articles will analyze the other core components.

I find that AIGC is developing too fast for me to keep up with... I feel more and more that what Zhuangzi said is true: life is finite, but knowledge is infinite; to pursue the infinite with the finite is perilous indeed!
