Computation Analysis of the Stable Diffusion Model

Preparation

The parameter counts and computation (MACs) figures below come from ThanatosShinji/onnx-tool: ONNX model shape inference and MACs(FLOPs) counting (github.com). These four ONNX models are the most important parts of Stable Diffusion 1.4; they are analyzed one by one below.

The Baidu Netdisk link in the GitHub repository provides these models with the intermediate tensor shapes already inferred.
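For reference, a minimal sketch of how these numbers can be reproduced with the onnx_tool Python API (assuming the model_profile entry point shown in the repo README; the exact signature may differ between versions):

    import onnx_tool

    # Infer the intermediate tensor shapes and count per-node MACs/parameters,
    # then print the profiling table. The path is a placeholder.
    onnx_tool.model_profile('text_encoder.onnx')

The command-line form at the end of this post does the same thing and writes the per-node table to a CSV file.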

TextEncoder

This model is very similar to BERT-Base: 12 transformer layers. Its total computation is 6.7 GMACs.

Like BERT-Base, about 98% of the computation is concentrated in MatMul.

 

The input tokens produce a hidden state of 1x77x768, which is then fed to UNet2DCondition.
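As a sanity check on the 6.7 GMACs figure, here is a back-of-the-envelope count of the MatMul MACs in a BERT-Base-style encoder; the dimensions (12 layers, hidden 768, FFN 3072, sequence length 77) are the standard BERT-Base/CLIP values assumed for this estimate rather than read from the graph:

    # Rough MatMul MAC count for the text encoder, one transformer layer at a time.
    seq, hidden, ffn, layers = 77, 768, 3072, 12

    qkv_proj = 3 * seq * hidden * hidden     # Q, K, V projections
    attn     = 2 * seq * seq * hidden        # QK^T and attention x V
    out_proj = seq * hidden * hidden         # attention output projection
    ffn_macs = 2 * seq * hidden * ffn        # the two feed-forward MatMuls

    total = layers * (qkv_proj + attn + out_proj + ffn_macs)
    print(f'{total / 1e9:.2f} GMACs')        # ~6.65 GMACs, close to the measured 6.7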

UNet2DCondition

This is UNet + Transformer. It is the part of Stable Diffusion responsible for image generation. The parameter count is as high as 859M, and the model file exceeds 3 GB. You can think of these parameters as memorizing a large amount of image texture information, which can be dynamically adjusted, according to the input text description, to produce textures for the various parts of the image.

The model's input resolution is 64x64. Following the standard UNet design, conv2d with stride 2 downsamples it by 2x at each stage, to 32x32, 16x16, and 8x8:
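A minimal sketch of that downsampling path (PyTorch is used here purely for illustration; the channel counts 320/640/1280 follow the SD 1.4 UNet configuration, and the real graph changes channels inside ResNet blocks rather than in the stride-2 convs):

    import torch
    import torch.nn as nn

    # Each stride-2 conv halves the spatial resolution of the latent features.
    x = torch.randn(2, 320, 64, 64)                        # features at the 64x64 input resolution
    down1 = nn.Conv2d(320, 640, 3, stride=2, padding=1)    # 64x64 -> 32x32
    down2 = nn.Conv2d(640, 1280, 3, stride=2, padding=1)   # 32x32 -> 16x16
    down3 = nn.Conv2d(1280, 1280, 3, stride=2, padding=1)  # 16x16 -> 8x8
    print(down3(down2(down1(x))).shape)                    # torch.Size([2, 1280, 8, 8])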

Unlike a plain UNet's stacked convolutions, it uses the following structure to flatten the spatial dimensions (HxW) into a matrix and run MHA (multi-head attention) on it:

The input is reshaped from 2x640x32x32 (the typical UNet layout) to 2x1024x640 (the typical BERT layout).
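The reshape itself is just a flatten of HxW into a token dimension plus a transpose; roughly (a sketch of the effect, not the exact node sequence in the graph):

    import torch

    # 2x640x32x32 (N, C, H, W)  ->  2x1024x640 (N, H*W, C)
    feat = torch.randn(2, 640, 32, 32)
    tokens = feat.flatten(2).transpose(1, 2)
    print(tokens.shape)                      # torch.Size([2, 1024, 640]): 1024 tokens of width 640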

This BERT-like tensor is then passed through a transformer encoder:

Here the QKV MatMuls are expressed as Einsum ops.
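Functionally the Einsum nodes compute the same thing as the usual attention MatMuls; a minimal single-head sketch (head splitting omitted):

    import torch

    q = k = v = torch.randn(2, 1024, 640)            # the reshaped UNet features
    scores = torch.einsum('bid,bjd->bij', q, k)      # QK^T expressed as Einsum
    attn = torch.softmax(scores / 640 ** 0.5, dim=-1)
    out = torch.einsum('bij,bjd->bid', attn, v)      # attention x V as Einsum
    print(out.shape)                                  # torch.Size([2, 1024, 640])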

So UNet2DCondition essentially replaces the CNN blocks that follow the UNet downsampling stages with transformer blocks.

Overall, convolution accounts for 49% of the total computation and MatMul for 31% (45% when Einsum is included).

VAE Encoder+Decoder

These two models are used as a pair, so they are analyzed together. Both are quite elegant, combining a CNN backbone with a transformer block.

The Encoder downsamples the input tensor from 3x512x512 to 512x64x64 through a CNN, then immediately applies an MHA block:

After the MHA and a few more convolution layers, compressed features of 8x64x64 are output.

The Decoder receives the compressed 4x64x64 features, first convolves them up to 512x64x64, and then applies a transformer MHA block:

After this MHA, the tensor is restored to 3x512x512 through Conv + Resize layers.

The overall layout is roughly:

Encoder:
    Conv    3x512x512
    ...
    Conv    512x64x64
    MHA     4096x512
    Conv    8x64x64

Decoder:
    Conv    4x64x64
    Conv    512x64x64
    MHA     4096x512
    ...
    Conv    3x512x512

These two models have the largest computation of the four because they are the ones that process the full 512x512 resolution: Encoder 566 GMACs, Decoder 1271 GMACs. Conv accounts for 95% and 97% of their computation respectively; the transformer MatMuls are basically negligible.
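A quick back-of-the-envelope comparison shows why conv dominates here. The 128-channel count at full resolution is an assumption about the SD 1.4 VAE configuration, not a value read from the graph:

    # One 3x3 conv layer at the full 512x512 resolution vs. the single 4096x512 MHA block.
    h = w = 512
    c = 128                                  # assumed channel count at full resolution
    conv_macs = h * w * c * c * 9            # one 3x3 conv: ~38.7 GMACs

    tokens, width = 64 * 64, 512             # the 4096x512 attention block
    proj_macs = 4 * tokens * width * width   # Q, K, V and output projections
    attn_macs = 2 * tokens * tokens * width  # QK^T and attention x V
    print(conv_macs / 1e9, (proj_macs + attn_macs) / 1e9)   # ~38.7 vs ~21.5 GMACs

A single conv layer at 512x512 already costs more than the entire attention block, and there are many such layers, which is why conv ends up at 95-97% of the total.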

Summary

The transformer really is a good structure for giving a model expressive power.

Increasing the model's 64x64 input resolution in order to raise the resolution of the generated image is not recommended; resolution has a strong impact on the total computation. Instead, consider generating at a lower resolution, for example a 32x32 latent producing a 256x256 output, and then using a separate super-resolution network to upscale the result.
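In numbers: conv work per layer grows linearly with the number of latent pixels, while self-attention over the flattened feature map grows quadratically with it, so halving each side of the latent pays off disproportionately (a rough scaling illustration, not a measured profile):

    # Relative per-layer cost when the latent shrinks from 64x64 to 32x32.
    pixels_64, pixels_32 = 64 * 64, 32 * 32
    print(pixels_64 / pixels_32)             # conv: ~4x cheaper
    print(pixels_64 ** 2 / pixels_32 ** 2)   # self-attention: ~16x cheaper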

To get the detailed computation report, you can generate a CSV file yourself from the command line:

 python -m onnx_tool -i .\vae_encoder.onnx -f vae_encoder.csv

 


Origin: blog.csdn.net/luoyu510183/article/details/127695184