Detailed explanation of TokenFlow

https://github.com/omerbt/TokenFlow/issues/25
https://github.com/omerbt/TokenFlow/issues/31
https://github.com/omerbt/TokenFlow/issues/32
https://github.com/eps696/SDfu

This article mainly explains how the Model part of TokenFlow is constructed. The code is excerpted from TokenFlow/tokenflow_utils.py.

TokenFlow's model construction logic is to first load the original Stable Diffusion pipeline and then re-register the UNet modules that need to be modified. The modification is triggered in run_tokenflow.py:

self.init_method(conv_injection_t=pnp_f_t, qk_injection_t=pnp_attn_t)

To understand the code below, you first need to be familiar with the Stable Diffusion BasicTransformerBlock source code. It also helps to read the PnP source code (pnp-diffusers), because TokenFlow builds on PnP.

    def init_method(self, conv_injection_t, qk_injection_t):
        self.qk_injection_timesteps = self.scheduler.timesteps[:qk_injection_t] if qk_injection_t >= 0 else []
        self.conv_injection_timesteps = self.scheduler.timesteps[:conv_injection_t] if conv_injection_t >= 0 else []
        
        register_extended_attention_pnp(self, self.qk_injection_timesteps)
        register_conv_injection(self, self.conv_injection_timesteps)
        set_tokenflow(self.unet)

The init_method function does three things:
(1) register_extended_attention_pnp: replaces UNet's self-attention with extended attention (K and V are taken from multiple frames) and handles the PnP injection at the same time.
(2) register_conv_injection: replaces the forward of one ResnetBlock2D (the second resnet block of up_blocks[1]) to add PnP injection.
(3) set_tokenflow: replaces UNet's 16 BasicTransformerBlock modules with TokenFlowBlock.
Among them, qk_injection_timesteps and conv_injection_timesteps are two timestep lists that restrict the PnP injection to the first few denoising steps.
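
The counts qk_injection_t and conv_injection_t come from the config. A minimal sketch of how they are typically derived (the 50 / 0.5 / 0.8 numbers are assumed example values, not taken from the repo):

n_timesteps = 50                      # number of sampling steps (assumed)
pnp_attn_t = int(n_timesteps * 0.5)   # -> inject attention Q/K during the first 25 steps
pnp_f_t = int(n_timesteps * 0.8)      # -> inject conv features during the first 40 steps

# init_method then slices the scheduler's timestep list:
# qk_injection_timesteps = scheduler.timesteps[:pnp_attn_t]
# conv_injection_timesteps = scheduler.timesteps[:pnp_f_t]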

In addition to these modifications to the UNet, the source code also calls register_pivotal in batched_denoise_step to mark the keyframe (pivot) pass so the keyframes are edited first, register_batch_idx to set the batch id before editing each batch, and register_time to set the current step t on the relevant UNet layers before predicting noise.
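
A minimal sketch of what these three helpers look like (simplified; the real versions live in tokenflow_utils.py and target the specific patched modules). They all just stash a value on the modules via setattr so the patched forward functions can read it later (isinstance_str is defined later in this article):

def register_pivotal(diffusion_model, is_pivotal):
    # mark whether the current denoising pass is over the keyframes (pivots)
    for _, module in diffusion_model.unet.named_modules():
        if isinstance_str(module, "BasicTransformerBlock"):
            setattr(module, "pivotal_pass", is_pivotal)

def register_batch_idx(diffusion_model, batch_idx):
    # tell each TokenFlowBlock which batch of frames is currently being edited
    for _, module in diffusion_model.unet.named_modules():
        if isinstance_str(module, "BasicTransformerBlock"):
            setattr(module, "batch_idx", batch_idx)

def register_time(diffusion_model, t):
    # store the current timestep t on every module that has an injection_schedule,
    # so the `self.t in self.injection_schedule` check in the patched forwards works
    for _, module in diffusion_model.unet.named_modules():
        if hasattr(module, "injection_schedule"):
            setattr(module, "t", t)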

Next, we will analyze one by one in order, the modifications made by tokenflow to the original Stable Diffusion model during the inference process.

register_extended_attention_pnp

The role of the register_extended_attention_pnp function: rebuild the forward function of attn1 in all 16 BasicTransformerBlock layers of the UNet into the extended-attention sa_forward, but register an injection_schedule (and thus PnP injection) for only 8 of those attn1 layers.

As can be seen from BasicTransformerBlock: although attn1 is implemented with the CrossAttention class, no context is passed in for K and V during inference, so it is essentially self-attention.

class BasicTransformerBlock(nn.Module):
    def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True):
        super().__init__()
        self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout)  # is a self-attention
        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
        self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim,
                                    heads=n_heads, dim_head=d_head, dropout=dropout)  # cross attention
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.norm3 = nn.LayerNorm(dim)

    def forward(self, x, context=None):
        x = self.attn1(self.norm1(x)) + x
        x = self.attn2(self.norm2(x), context=context) + x
        x = self.ff(self.norm3(x)) + x
        return x

Input parameters: the two required parameters of register_extended_attention_pnp are the unet model and injection_schedule. injection_schedule controls at which timesteps the PnP injection is executed during inference, because we only want to perform the PnP operation in the first few timesteps.

Rebuilding the extended attention: because the original TokenFlow sa_forward performs the attention matrix computation over all frames, which consumes too many resources, I added sa_3frame_forward, which only computes attention over 3 adjacent frames. My reconstructed register_extended_attention_pnp is shown below; based on the original function, I added an is_3_frame parameter to choose whether to use sa_3frame_forward.

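A sketch of that reconstructed registration function (the original post showed it as an image; this skeleton is reconstructed from the code excerpts quoted later in this section, with the is_3_frame switch being the author's addition):

def register_extended_attention_pnp(model, injection_schedule, is_3_frame=False):
    # choose which forward to install: full extended attention, or the 3-adjacent-frame variant
    register_forward_fun = sa_3frame_forward if is_3_frame else sa_forward

    # 1. install the chosen forward on attn1 of every BasicTransformerBlock, with an empty injection_schedule
    module_names = []
    for module_name, module in model.unet.named_modules():
        if isinstance_str(module, "BasicTransformerBlock"):
            module_names.append(module_name)
            module.attn1.forward = register_forward_fun(module.attn1)
            setattr(module.attn1, 'injection_schedule', [])

    # 2. for the 8 up_block attn1 layers listed in res_dict, register the real injection_schedule
    res_dict = {1: [1, 2], 2: [0, 1, 2], 3: [0, 1, 2]}
    for res in res_dict:
        for block in res_dict[res]:
            module = model.unet.up_blocks[res].attentions[block].transformer_blocks[0].attn1
            module.forward = sa_forward(module)
            setattr(module, 'injection_schedule', injection_schedule)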

First, let us skip the two functions sa_forward and sa_3frame_forward and look at how the corresponding UNet modules are found and their forward methods replaced.

1. Reconstruct forward for attn1 of all BasicTransformerBlock layers

register_forward_fun is chosen based on is_3_frame (either sa_forward or sa_3frame_forward). Then every module of the UNet is traversed to check whether it inherits from BasicTransformerBlock; if so, its attn1.forward is replaced, but injection_schedule is left empty (that is, PnP is not executed for it).

    module_names = []
    register_forward_fun = sa_3frame_forward if is_3_frame else sa_forward
    for module_name, module in model.unet.named_modules():
        if isinstance_str(module, "BasicTransformerBlock"):
            module_names.append(module_name)
            # replace BasicTransformerBlock.attn1's forward with sa_forward
            module.attn1.forward = register_forward_fun(module.attn1)
            # set injection_schedule empty[] for BasicTransformerBlock.attn1
            setattr(module.attn1, 'injection_schedule', [])
    print(f"all change {
      
      len(module_names)} layer's BasicTransformerBlock.attn1.forward() for extended_attention_pnp...")
    print(module_names)  # up_blocks.1.attentions.0.transformer_blocks.0

isinstance_str checks whether the class name cls_name appears in x's inheritance chain (MRO):

def isinstance_str(x: object, cls_name: str):
    for _cls in x.__class__.__mro__:
        if _cls.__name__ == cls_name:
            return True
    return False
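
A quick usage example (MyBlock is a hypothetical class, only for illustration): the check walks the class's MRO, so subclasses of a block still match by the parent class's name.

import torch

class MyBlock(torch.nn.Linear):            # hypothetical subclass
    pass

blk = MyBlock(4, 4)
print(isinstance_str(blk, "Linear"))       # True: "Linear" appears in MyBlock.__mro__
print(isinstance_str(blk, "Conv2d"))       # False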

The first pass rebuilds the forward of the following 16 attn1 layers in the UNet: 6 in down_blocks, 9 in up_blocks, and 1 in mid_block.

down_blocks.0.attentions.0.transformer_blocks.0.attn1
down_blocks.0.attentions.1.transformer_blocks.0.attn1
down_blocks.1.attentions.0.transformer_blocks.0.attn1
down_blocks.1.attentions.1.transformer_blocks.0.attn1
down_blocks.2.attentions.0.transformer_blocks.0.attn1
down_blocks.2.attentions.1.transformer_blocks.0.attn1

up_blocks.1.attentions.0.transformer_blocks.0.attn1
up_blocks.1.attentions.1.transformer_blocks.0.attn1
up_blocks.1.attentions.2.transformer_blocks.0.attn1
up_blocks.2.attentions.0.transformer_blocks.0.attn1
up_blocks.2.attentions.1.transformer_blocks.0.attn1
up_blocks.2.attentions.2.transformer_blocks.0.attn1
up_blocks.3.attentions.0.transformer_blocks.0.attn1
up_blocks.3.attentions.1.transformer_blocks.0.attn1
up_blocks.3.attentions.2.transformer_blocks.0.attn1

mid_block.attentions.0.transformer_blocks.0.attn1

2. Register injection_schedule and enable the PnP operation on some of the attn1 layers (8)

Guided by res_dict, the forward of the specified attn1 modules is replaced again, and this time injection_schedule is registered for them so the PnP operation is actually used.

    res_dict = {1: [1, 2], 2: [0, 1, 2], 3: [0, 1, 2]}  # upblock's self-attention layers
    # we are injecting attention in blocks 4 - 11 of the Unet UpBlock, so not in the first block of the lowest resolution
    for res in res_dict:  # res = 1
        for block in res_dict[res]:  # res_dict[res] = [1, 2]
            module = model.unet.up_blocks[res].attentions[block].transformer_blocks[0].attn1
            module.forward = sa_forward(module)
            setattr(module, 'injection_schedule', injection_schedule)

The second pass re-registers these 8 up_blocks attn1 layers:

model.unet.up_blocks.1.attentions.1.transformer_blocks.0.attn1
model.unet.up_blocks.1.attentions.2.transformer_blocks.0.attn1
model.unet.up_blocks.2.attentions.0.transformer_blocks.0.attn1
model.unet.up_blocks.2.attentions.1.transformer_blocks.0.attn1
model.unet.up_blocks.2.attentions.2.transformer_blocks.0.attn1
model.unet.up_blocks.3.attentions.0.transformer_blocks.0.attn1
model.unet.up_blocks.3.attentions.1.transformer_blocks.0.attn1
model.unet.up_blocks.3.attentions.2.transformer_blocks.0.attn1

3. sa_forward

PnP operation: because TokenFlow's input accounts for both PnP and classifier-free guidance, the single latent originally fed to the UNet becomes three copies, source_latents + x + x: source_latents corresponds to the source_prompt, one x corresponds to the null_prompt (unconditional), and the other x corresponds to the edit_prompt (conditional).

latent_model_input = torch.cat([source_latents] + ([x] * 2))  

In this way, during UNet inference the input latents can be directly sliced into three parts, source_latents, uncond_latents, and cond_latents, and the source features are injected into the unconditional and conditional parts (PnP injection is a direct replacement). For self-attention, only Q and K are replaced.

source_latents = x[:n_frames]
uncond_latents = x[n_frames:2*n_frames]
cond_latents = x[2*n_frames:]
# source inject uncond
q[n_frames:2*n_frames] = q[:n_frames]
k[n_frames:2*n_frames] = k[:n_frames]
# source inject cond
q[2*n_frames:] = q[:n_frames]
k[2*n_frames:] = k[:n_frames]
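
For reference, this triple batch is also what makes classifier-free guidance possible downstream. A minimal sketch of how the three noise predictions are typically combined (standard CFG; guidance_scale is an assumed value, and this is not a verbatim excerpt from the repo):

# noise_pred has shape (3 * n_frames, ...) in the same source / uncond / cond order as the input latents
_, noise_pred_uncond, noise_pred_cond = noise_pred.chunk(3)
guidance_scale = 7.5  # assumed value
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)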

Extended attention: TokenFlow uses extended self-attention. When computing self-attention for the i-th frame, Q is the feature of the i-th frame while K and V come from all frames, so K and V need to be reshaped and repeated to facilitate the subsequent computation.
$T_{base}=\text{Softmax}\left(\frac{Q^i\,[K^{i_1},\dots,K^{i_k}]^\top}{\sqrt{d}}\right)\cdot[V^{i_1},\dots,V^{i_k}]$

# KV reshape and repeat for extend_attention: Softmax(Q_i_frame @ K_all_frame) @ V_all_frame
# (n_frames, seq_len, dim) -> (1, n_frames * seq_len, dim) -> (n_frames, n_frames * seq_len, dim)
k_source = k[:n_frames]
k_uncond = k[n_frames:2 * n_frames].reshape(1, n_frames * sequence_length, -1).repeat(n_frames, 1, 1)
k_cond = k[2 * n_frames:].reshape(1, n_frames * sequence_length, -1).repeat(n_frames, 1, 1)
v_source = v[:n_frames]
v_uncond = v[n_frames:2 * n_frames].reshape(1, n_frames * sequence_length, -1).repeat(n_frames, 1, 1)
v_cond = v[2 * n_frames:].reshape(1, n_frames * sequence_length, -1).repeat(n_frames, 1, 1)

Because the computation is done frame by frame, and multi-head attention is computed head by head, a double for loop computes the attention output of the j-th head of the i-th frame; finally the outputs are concatenated along the frame dimension and the head dimension to obtain the final attention result:

Q @ K -> sim:
(b, 1, seq_len, dim//head) @ (b, 1, dim//head, frame*seq_len) -> (b, 1, seq_len, frame*seq_len)

sim @ V -> out:
(b, 1, seq_len, frame*seq_len) @ (b, 1, frame*seq_len, dim//head) -> (b, 1, seq_len, dim//head)

cat each head's out:
(b->n_frames, 1, seq_len, dim//head) -> (n_frames, 1, seq_len, dim//head)

cat each frame's out:
(n_frames, 1, seq_len, dim//head) -> (n_frames, heads, seq_len, dim//heads)

The complete code of sa_forward is as follows:

def sa_forward(self):
        to_out = self.to_out  # self.to_out = [linear, dropout]
        if type(to_out) is torch.nn.modules.container.ModuleList:
            to_out = self.to_out[0]
        else:
            to_out = self.to_out
        
        def forward(x, encoder_hidden_states=None, attention_mask=None):  
            is_cross = encoder_hidden_states is not None  # cross-attention or self-attention

            h = self.heads
            batch_size, sequence_length, dim = x.shape  # (3*n_frames, seq_len, dim)
            # batch: the first n_frames samples are source features, the middle n_frames are uncond features, the last n_frames are cond features
            n_frames = batch_size // 3
            # source_latents = x[:n_frames], uncond_latents = x[n_frames:2*n_frames], cond_latents = x[2*n_frames:]
                        
            encoder_hidden_states = encoder_hidden_states if is_cross else x
            q = self.to_q(x)
            k = self.to_k(encoder_hidden_states)
            v = self.to_v(encoder_hidden_states)

            # PnP injection of Q/K: only inject in the first few timesteps of sampling (check whether t qualifies), and only in the UpBlocks
            if self.injection_schedule is not None and (self.t in self.injection_schedule or self.t == 1000):
                # source inject into unconditional
                q[n_frames:2 * n_frames] = q[:n_frames]
                k[n_frames:2 * n_frames] = k[:n_frames]
                # source inject into conditional
                q[2 * n_frames:] = q[:n_frames]
                k[2 * n_frames:] = k[:n_frames]

            # KV reshape and repeat for extend_attention: Softmax(Q_i_frame @ K_all_frame) @ V_all_frame
            # (n_frames, seq_len, dim) -> (1, n_frames * seq_len, dim) -> (n_frames, n_frames * seq_len, dim)
            k_source = k[:n_frames]
            k_uncond = k[n_frames:2 * n_frames].reshape(1, n_frames * sequence_length, -1).repeat(n_frames, 1, 1)
            k_cond = k[2 * n_frames:].reshape(1, n_frames * sequence_length, -1).repeat(n_frames, 1, 1)
            v_source = v[:n_frames]
            v_uncond = v[n_frames:2 * n_frames].reshape(1, n_frames * sequence_length, -1).repeat(n_frames, 1, 1)
            v_cond = v[2 * n_frames:].reshape(1, n_frames * sequence_length, -1).repeat(n_frames, 1, 1)
            
            # project QKV's source, cond and uncond, respectively 
            q_source = self.reshape_heads_to_batch_dim(q[:n_frames])  # q (n_frames*heads, seq_len, dim//heads)
            q_uncond = self.reshape_heads_to_batch_dim(q[n_frames:2 * n_frames])
            q_cond = self.reshape_heads_to_batch_dim(q[2 * n_frames:])
            k_source = self.reshape_heads_to_batch_dim(k_source)  # kv (n_frames*heads, n_frames * seq_len, dim//heads)
            k_uncond = self.reshape_heads_to_batch_dim(k_uncond)
            k_cond = self.reshape_heads_to_batch_dim(k_cond)
            v_source = self.reshape_heads_to_batch_dim(v_source)
            v_uncond = self.reshape_heads_to_batch_dim(v_uncond)
            v_cond = self.reshape_heads_to_batch_dim(v_cond)
            
            # split heads
            q_src = q_source.view(n_frames, h, sequence_length, dim // h)
            k_src = k_source.view(n_frames, h, sequence_length, dim // h)
            v_src = v_source.view(n_frames, h, sequence_length, dim // h)
            q_uncond = q_uncond.view(n_frames, h, sequence_length, dim // h)
            k_uncond = k_uncond.view(n_frames, h, sequence_length * n_frames, dim // h)
            v_uncond = v_uncond.view(n_frames, h, sequence_length * n_frames, dim // h)
            q_cond = q_cond.view(n_frames, h, sequence_length, dim // h)
            k_cond = k_cond.view(n_frames, h, sequence_length * n_frames, dim // h)
            v_cond = v_cond.view(n_frames, h, sequence_length * n_frames, dim // h)

            out_source_all = []
            out_uncond_all = []
            out_cond_all = []
            
            # each frame or single_batch frames
            single_batch = n_frames<=12
            b = n_frames if single_batch else 1  # b=1
            # do attention for each frame respectively. frames [frame:frame=b]
            for frame in range(0, n_frames, b):
                out_source = []
                out_uncond = []
                out_cond = []
                # do attention for each head respectively. head j
                for j in range(h):
                    # do attention for source, cond and uncond respectively, (b, 1, seq_len, dim//head) @ (b, 1, dim//head, frame*seq_len) -> (b, 1, seq_len, frame*seq_len)
                    sim_source_b = torch.bmm(q_src[frame: frame+ b, j], k_src[frame: frame+ b, j].transpose(-1, -2)) * self.scale
                    sim_uncond_b = torch.bmm(q_uncond[frame: frame+ b, j], k_uncond[frame: frame+ b, j].transpose(-1, -2)) * self.scale
                    sim_cond = torch.bmm(q_cond[frame: frame+ b, j], k_cond[frame: frame+ b, j].transpose(-1, -2)) * self.scale
                    # append each head's out, (b, 1, seq_len, frame*seq_len) @ (b, 1, frame*seq_len, dim//head) -> (b, 1, seq_len, dim//head)
                    out_source.append(torch.bmm(sim_source_b.softmax(dim=-1), v_src[frame: frame+ b, j]))
                    out_uncond.append(torch.bmm(sim_uncond_b.softmax(dim=-1), v_uncond[frame: frame+ b, j]))
                    out_cond.append(torch.bmm(sim_cond.softmax(dim=-1), v_cond[frame: frame+ b, j]))
                # cat each head's out, (b->n_frames, 1, seq_len, dim//head) -> (n_frames, 1, seq_len, dim//head)
                out_source = torch.cat(out_source, dim=0)
                out_uncond = torch.cat(out_uncond, dim=0) 
                out_cond = torch.cat(out_cond, dim=0) 
                if single_batch: # if use single_batch, view single_batch frame's out
                    out_source = out_source.view(h, n_frames,sequence_length, dim // h).permute(1, 0, 2, 3).reshape(h * n_frames, sequence_length, -1)
                    out_uncond = out_uncond.view(h, n_frames,sequence_length, dim // h).permute(1, 0, 2, 3).reshape(h * n_frames, sequence_length, -1)
                    out_cond = out_cond.view(h, n_frames,sequence_length, dim // h).permute(1, 0, 2, 3).reshape(h * n_frames, sequence_length, -1)
                # append each frame's out
                out_source_all.append(out_source)
                out_uncond_all.append(out_uncond)
                out_cond_all.append(out_cond)
            # cat each frame's out, (n_frames, 1, seq_len, dim//head) -> (n_frames, heads, seq_len, dim//heads)
            out_source = torch.cat(out_source_all, dim=0)
            out_uncond = torch.cat(out_uncond_all, dim=0)
            out_cond = torch.cat(out_cond_all, dim=0)
            # cat source, cond and uncond's out, (n_frames, heads, seq_len, dim//heads) -> (3*n_frames, heads, seq_len, dim//heads)
            out = torch.cat([out_source, out_uncond, out_cond], dim=0)
            out = self.reshape_batch_dim_to_heads(out)
            return to_out(out)
        return forward
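
The per-frame, per-head double loop above trades memory for many small matmuls. Conceptually it is just one batched matrix multiplication; the following is a minimal sketch of the equivalent batched form (a hypothetical helper for illustration only, ignoring the source/uncond/cond split and the single_batch trick):

import torch

def extended_attention_batched(q, k, v, heads, scale):
    """q: (n_frames, seq_len, dim); k, v: (n_frames, n_frames * seq_len, dim) after repeat."""
    n_frames, seq_len, dim = q.shape
    head_dim = dim // heads
    # split heads: (n_frames, heads, tokens, head_dim)
    q = q.view(n_frames, seq_len, heads, head_dim).transpose(1, 2)
    k = k.view(n_frames, -1, heads, head_dim).transpose(1, 2)
    v = v.view(n_frames, -1, heads, head_dim).transpose(1, 2)
    sim = q @ k.transpose(-1, -2) * scale   # (n_frames, heads, seq_len, n_frames * seq_len)
    out = sim.softmax(dim=-1) @ v           # (n_frames, heads, seq_len, head_dim)
    return out.transpose(1, 2).reshape(n_frames, seq_len, dim)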

4. sa_3frame_forward

In sa_3frame_forward, self-attention no longer repeats K and V over all n_frames as in sa_forward. Instead, source_latents (used for PnP) goes through normal single-frame self-attention (forward_original), while uncond_latents and cond_latents go through self-attention whose K and V come from the 3 adjacent frames (forward_extended).

Each time the attention of the i-th frame is computed (window_size=3), the 3 frames with indices [i-1, i, i+1], centered on the i-th frame, are taken as K and V; a small demonstration of the index computation is shown below, followed by the complete code.
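
Here is a tiny standalone demonstration of the sliding-window index computation (including the clamping at the sequence boundaries); the values are illustrative:

n_frames, window_size = 5, 3
for frame in range(n_frames):
    window = range(max(0, frame - window_size // 2),
                   min(n_frames, frame + window_size // 2 + 1))
    print(frame, list(window))
# 0 [0, 1]
# 1 [0, 1, 2]
# 2 [1, 2, 3]
# 3 [2, 3, 4]
# 4 [3, 4]

The complete sa_3frame_forward code is as follows: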

def sa_3frame_forward(self):  # self-attention is only extended to 3 consecutive frames, not to all keyframes
        to_out = self.to_out
        if type(to_out) is torch.nn.modules.container.ModuleList:
            to_out = self.to_out[0]
        else:
            to_out = self.to_out
        
        # the original single-frame UNet attention forward
        def forward_original(q, k, v):
            n_frames, seq_len, dim = q.shape
            h = self.heads
            head_dim = dim // h
            
            q = self.head_to_batch_dim(q).reshape(n_frames, h, seq_len, head_dim)
            k = self.head_to_batch_dim(k).reshape(n_frames, h, seq_len, head_dim)
            v = self.head_to_batch_dim(v).reshape(n_frames, h, seq_len, head_dim)

            out_all = []
            
            for frame in range(n_frames):
                out = []
                for j in range(h):
                    sim = torch.matmul(q[frame, j], k[frame, j].transpose(-1, -2)) * self.scale # (seq_len, seq_len)                                            
                    out.append(torch.matmul(sim.softmax(dim=-1), v[frame, j])) # h * (seq_len, head_dim)

                out = torch.cat(out, dim=0).reshape(-1, seq_len, head_dim) # (h, seq_len, head_dim)
                out_all.append(out) # n_frames * (h, seq_len, head_dim)
            
            out = torch.cat(out_all, dim=0) # (n_frames * h, seq_len, head_dim)
            out = self.batch_to_head_dim(out) # (n_frames, seq_len, h * head_dim)
            return out
            
        # extended UNet attention forward (K/V from a 3-frame sliding window)
        def forward_extended(q, k, v):
            n_frames, seq_len, dim = q.shape
            h = self.heads
            head_dim = dim // h
            
            q = self.head_to_batch_dim(q).reshape(n_frames, h, seq_len, head_dim)
            k = self.head_to_batch_dim(k).reshape(n_frames, h, seq_len, head_dim)
            v = self.head_to_batch_dim(v).reshape(n_frames, h, seq_len, head_dim)

            out_all = []
            window_size = 3
            
            for frame in range(n_frames):  # e.g. frame=16, window_size=3: window=[15, 16, 17]
                out = []
                # sliding window to improve speed: take window_size frames centered on the current frame, e.g. frame=1 gives window=[0, 1, 2]
                window = range(max(0, frame-window_size // 2), min(n_frames, frame+window_size//2+1))  
                
                for j in range(h):
                    sim_all = []  # store the sims between the current frame and the 3 frames in the window, len(sim_all)=3
                    
                    for kframe in window:  # (1, 1, seq_len, head_dim) @ (1, 1, head_dim, seq_len) -> (1, 1, seq_len, seq_len)
                        # compute sim between the current frame and each frame kframe in the window, and append it to sim_all
                        sim_all.append(torch.matmul(q[frame, j], k[kframe, j].transpose(-1, -2)) * self.scale) # window * (seq_len, seq_len)
                        
                    sim_all = torch.cat(sim_all).reshape(len(window), seq_len, seq_len).transpose(0, 1) # (seq_len, window, seq_len)
                    sim_all = sim_all.reshape(seq_len, len(window) * seq_len) # (seq_len, window * seq_len)
                    out.append(torch.matmul(sim_all.softmax(dim=-1), v[window, j].reshape(len(window) * seq_len, head_dim))) # h * (seq_len, head_dim)

                out = torch.cat(out, dim=0).reshape(-1, seq_len, head_dim) # (h, seq_len, head_dim)
                out_all.append(out) # n_frames * (h, seq_len, head_dim)
            
            out = torch.cat(out_all, dim=0) # (n_frames * h, seq_len, head_dim)
            out = self.batch_to_head_dim(out) # (n_frames, seq_len, h * head_dim)

            return out
            
        def forward(x, encoder_hidden_states=None, attention_mask=None):
            batch_size, sequence_length, dim = x.shape
            h = self.heads
            n_frames = batch_size // 3
            
            is_cross = encoder_hidden_states is not None
            encoder_hidden_states = encoder_hidden_states if is_cross else x
            q = self.to_q(x)
            k = self.to_k(encoder_hidden_states)
            v = self.to_v(encoder_hidden_states)

            if self.injection_schedule is not None and (self.t in self.injection_schedule or self.t == 1000):
                # inject unconditional
                q[n_frames:2 * n_frames] = q[:n_frames]
                k[n_frames:2 * n_frames] = k[:n_frames]
                # inject conditional
                q[2 * n_frames:] = q[:n_frames]
                k[2 * n_frames:] = k[:n_frames]

            # source_latents: normal self-attention; uncond and cond: self-attention with K/V from the 3 adjacent frames
            out_source = forward_original(q[:n_frames], k[:n_frames], v[:n_frames])  
            out_uncond = forward_extended(q[n_frames:2 * n_frames], k[n_frames:2 * n_frames], v[n_frames:2 * n_frames])
            out_cond = forward_extended(q[2 * n_frames:], k[2 * n_frames:], v[2 * n_frames:])
                            
            out = torch.cat([out_source, out_uncond, out_cond], dim=0) # (3 * n_frames, seq_len, dim)

            return to_out(out)

        return forward

register_conv_injection

After registering the self-attention forwards for the UNet, a new forward is registered for one ResnetBlock2D of the UNet, unet.up_blocks[1].resnets[1], and an injection_schedule is registered on it at the same time to control the PnP injection timesteps.


conv_forward only adds one PnP injection step compared with the ordinary ResnetBlock2D forward:

if self.injection_schedule is not None and (self.t in self.injection_schedule or self.t == 1000):
                source_batch_size = int(hidden_states.shape[0] // 3)
                # inject unconditional
                hidden_states[source_batch_size:2 * source_batch_size] = hidden_states[:source_batch_size]
                # inject conditional
                hidden_states[2 * source_batch_size:] = hidden_states[:source_batch_size]
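
For context, a minimal sketch of the registration wrapper itself (simplified; in the repo, conv_forward re-implements the full ResnetBlock2D.forward with the injection snippet above inserted near the end, so treat this as a sketch rather than the exact code):

def register_conv_injection(model, injection_schedule):
    # only the second resnet block of up_blocks[1] receives PnP feature injection
    conv_module = model.unet.up_blocks[1].resnets[1]
    conv_module.forward = conv_forward(conv_module)  # conv_forward: the patched ResnetBlock2D forward
    setattr(conv_module, 'injection_schedule', injection_schedule)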

set_tokenflow

Assigning to __class__ replaces the class of an instance. set_tokenflow finds all modules in the UNet whose class inherits from BasicTransformerBlock and swaps their class for a dynamically created TokenFlowBlock class that keeps BasicTransformerBlock as its parent class.
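
A toy illustration of this class-swapping trick (hypothetical classes, not from the repo): the instance keeps all of its weights and attributes; only its method resolution changes to go through the new wrapper class first.

class Base:
    def forward(self):
        return "base"

def make_wrapper(block_class):
    class Wrapper(block_class):        # dynamically subclass the original class
        def forward(self):
            return "wrapped " + super().forward()
    return Wrapper

obj = Base()
obj.__class__ = make_wrapper(Base)     # swap the instance's class in place
print(obj.forward())                   # "wrapped base"

set_tokenflow itself is: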

def set_tokenflow(model: torch.nn.Module):
    """
    Sets the tokenflow attention blocks in a model.
    """
    for _, module in model.named_modules():
        if isinstance_str(module, "BasicTransformerBlock"):
            # 16 modules: module.__class__ = <class 'diffusers.models.attention.BasicTransformerBlock'>
            make_tokenflow_block_fn = make_tokenflow_attention_block 
            # use BasicTransformerBlock as the parent class and wrap it in a TokenFlowBlock class
            module.__class__ = make_tokenflow_block_fn(module.__class__)

            # Something needed for older versions of diffusers
            if not hasattr(module, "use_ada_layer_norm_zero"):
                module.use_ada_layer_norm = False
                module.use_ada_layer_norm_zero = False
    return model

make_tokenflow_attention_block

This function defines a TokenFlowBlock class and returns it. TokenFlowBlock inherits from the BasicTransformerBlock class and only overrides the forward function.

First, pivotal_pass determines whether the current pass is over keyframes:

  • Keyframes: simply save the normalized hidden states as pivot_hidden_states.
  • Non-keyframes: take the source_latents part of the non-keyframes and of the keyframes, compute the cosine similarity cosine_sim between them (shape = (n_frames * seq_len, len(batch_idxs) * seq_len)), find the index idx of the most similar keyframe token, and then stack 3 copies for source, uncond, and cond.
    • If the current batch is not the first batch, len(batch_idxs) = 2: save the most similar indices into idx1 and idx2 respectively.
    • If it is the first batch, len(batch_idxs) = 1: save the most similar indices into idx1.
            batch_size, sequence_length, dim = hidden_states.shape  # (batch, seq_len, dim)
            n_frames = batch_size // 3  # batch = 3 * n_frames: source + uncond + cond
            mid_idx = n_frames // 2
            hidden_states = hidden_states.view(3, n_frames, sequence_length, dim)  # (source + uncond + cond, n_frames, seq_len, dim)

            norm_hidden_states = self.norm1(hidden_states).view(3, n_frames, sequence_length, dim)

            if self.pivotal_pass:  # is_pivotal = True: keyframes, just save them
                self.pivot_hidden_states = norm_hidden_states  # (3, n_frames, sequence_length, dim); for the keyframe pass here, n_frames = 5
            else:  # is_pivotal = False: non-keyframes, compute cosine_sim of source_latents against the keyframes
                idx1 = []
                idx2 = []
                batch_idxs = [self.batch_idx]  # frames are processed batch_size at a time; batch_idx is the index of the current batch, e.g. with 32 frames and batch_size=8, batch_idx can be 0, 1, 2, or 3
                if self.batch_idx > 0:  # if this is not the first batch
                    batch_idxs.append(self.batch_idx - 1)  # also add the previous batch's idx, e.g. when batch_idx=1, add 0 so batch_idxs=[1, 0]
                
                # compute cosine_sim between the non-keyframes' source_latents and the keyframes' source_latents; if batch_idxs=[1, 0], only the keyframes of batch 0 and batch 1 are compared with norm_hidden_states
                sim = batch_cosine_sim(norm_hidden_states[0].reshape(-1, dim),  # (n_frames*sequence_length, dim)
                                        self.pivot_hidden_states[0][batch_idxs].reshape(-1, dim))  # (len(batch_idxs)*sequence_length, dim)
                if len(batch_idxs) == 2:  # if this is not the first batch, save the most similar indices into idx1 and idx2 respectively
                    # sim: (n_frames * seq_len, len(batch_idxs) * seq_len),  len(batch_idxs)=2
                    sim1, sim2 = sim.chunk(2, dim=1) 
                    idx1.append(sim1.argmax(dim=-1))  # (n_frames * seq_len) values, each in [0, seq_len-1]
                    idx2.append(sim2.argmax(dim=-1))  # (n_frames * seq_len) values, each in [0, seq_len-1]
                else:  # if it is the first batch, save the most similar indices into idx1
                    idx1.append(sim.argmax(dim=-1))

                # stack 3 copies for source, uncond, and cond
                idx1 = torch.stack(idx1 * 3, dim=0) # (3, n_frames * seq_len)
                idx1 = idx1.squeeze(1)
                if len(batch_idxs) == 2:
                    idx2 = torch.stack(idx2 * 3, dim=0) # (3, n_frames * seq_len)
                    idx2 = idx2.squeeze(1)
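
batch_cosine_sim is a small utility in the TokenFlow code; reproduced here roughly from memory (treat it as a sketch): it L2-normalizes both token sets and takes their matrix product, giving the (n_frames * seq_len, len(batch_idxs) * seq_len) similarity matrix used above.

import torch

def batch_cosine_sim(x, y):
    # x: (N, dim), y: (M, dim) -> (N, M) cosine similarity matrix
    x = x / x.norm(dim=-1, keepdim=True)
    y = y / y.norm(dim=-1, keepdim=True)
    return x @ y.T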

Next, Self-Attention (attn1), Cross-Attention (attn2), and Feed-forward (ff) are executed in sequence. Cross-Attention and Feed-forward are unchanged; the only change is in the Self-Attention step:

  • For keyframes, compute the self-attention result and save it (kf_attn_output).
  • For non-keyframes, reuse the saved keyframe attention results: each non-keyframe token gathers the keyframe token it is most similar to, using idx1/idx2 computed above. If the non-keyframe belongs to the first batch, the tokens gathered from its single keyframe are used directly. Otherwise, tokens are gathered from both the current batch's keyframe and the previous batch's keyframe, and the two results are blended by a weighted average whose weight depends on the distance between the frame and the two keyframes. Matching the code below, the weight applied to the current keyframe's result is
    $w_1 = \frac{|s - p_2|}{|s - p_1| + |s - p_2|}$
    where s is the global index of the frame, p1 is the index of the current batch's keyframe, and p2 is the index of the previous batch's keyframe; the previous keyframe's result gets weight $1 - w_1$. A numeric example is given after the code block below.
			# 1. Self-Attention
            cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
            if self.pivotal_pass:
                # norm_hidden_states.shape = 3, n_frames * seq_len, dim
                self.attn_output = self.attn1(
                        norm_hidden_states.view(batch_size, sequence_length, dim),
                        encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
                        **cross_attention_kwargs,
                    )
                # 3, n_frames * seq_len, dim - > 3 * n_frames, seq_len, dim
                self.kf_attn_output = self.attn_output 
            else:
                batch_kf_size, _, _ = self.kf_attn_output.shape
                self.attn_output = self.kf_attn_output.view(3, batch_kf_size // 3, sequence_length, dim)[:,
                                   batch_idxs]  # 3, n_frames, seq_len, dim --> 3, len(batch_idxs), seq_len, dim

            # gather values from attn_output, using idx as indices, and get a tensor of shape 3, n_frames, seq_len, dim
            if not self.pivotal_pass:
                if len(batch_idxs) == 2:
                    attn_1, attn_2 = self.attn_output[:, 0], self.attn_output[:, 1]
                    attn_output1 = attn_1.gather(dim=1, index=idx1.unsqueeze(-1).repeat(1, 1, dim))
                    attn_output2 = attn_2.gather(dim=1, index=idx2.unsqueeze(-1).repeat(1, 1, dim))

                    s = torch.arange(0, n_frames).to(idx1.device) + batch_idxs[0] * n_frames
                    # distance from the pivot
                    p1 = batch_idxs[0] * n_frames + n_frames // 2
                    p2 = batch_idxs[1] * n_frames + n_frames // 2
                    d1 = torch.abs(s - p1)
                    d2 = torch.abs(s - p2)
                    # weight
                    w1 = d2 / (d1 + d2)
                    w1 = torch.sigmoid(w1)
                    
                    w1 = w1.unsqueeze(0).unsqueeze(-1).unsqueeze(-1).repeat(3, 1, sequence_length, dim)
                    attn_output1 = attn_output1.view(3, n_frames, sequence_length, dim)
                    attn_output2 = attn_output2.view(3, n_frames, sequence_length, dim)
                    attn_output = w1 * attn_output1 + (1 - w1) * attn_output2
                else:
                    attn_output = self.attn_output[:,0].gather(dim=1, index=idx1.unsqueeze(-1).repeat(1, 1, dim))

                attn_output = attn_output.reshape(
                        batch_size, sequence_length, dim)  # 3 * n_frames, seq_len, dim
            else:
                attn_output = self.attn_output
                
            hidden_states = hidden_states.reshape(batch_size, sequence_length, dim)  # 3 * n_frames, seq_len, dim
            hidden_states = attn_output + hidden_states  # res_connect
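
A small numeric example of the weighting above (assumed values: n_frames = 8 frames per batch, current batch_idx = 1, so the current pivot sits at p1 = 12 and the previous pivot at p2 = 4; note that the repo code additionally passes w1 through a sigmoid before blending):

import torch

n_frames, batch_idx = 8, 1
batch_idxs = [batch_idx, batch_idx - 1]

s = torch.arange(0, n_frames) + batch_idxs[0] * n_frames  # global frame indices: 8..15
p1 = batch_idxs[0] * n_frames + n_frames // 2             # current pivot index: 12
p2 = batch_idxs[1] * n_frames + n_frames // 2             # previous pivot index: 4
d1, d2 = torch.abs(s - p1), torch.abs(s - p2)
w1 = d2 / (d1 + d2)  # weight applied to the current pivot's gathered attention
print(w1)            # tensor([0.5000, 0.6250, 0.7500, 0.8750, 1.0000, 0.9000, 0.8333, 0.7857])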

Origin: blog.csdn.net/weixin_54338498/article/details/135074600