https://github.com/omerbt/TokenFlow/issues/25
https://github.com/omerbt/TokenFlow/issues/31
https://github.com/omerbt/TokenFlow/issues/32
https://github.com/eps696/SDfu
This article mainly explains how the model part of TokenFlow is constructed. The code is excerpted from TokenFlow/tokenflow_utils.py.
TokenFlow's model construction logic is to first load the original Stable Diffusion and then re-register the UNet modules that need to be modified. The modification is first invoked in run_tokenflow.py:
self.init_method(conv_injection_t=pnp_f_t, qk_injection_t=pnp_attn_t)
To understand the following code, you first need to understand the source code of SD's BasicTransformerBlock. It is also best to read the PnP source code (pnp-diffusers), because TokenFlow builds on PnP.
def init_method(self, conv_injection_t, qk_injection_t):
    self.qk_injection_timesteps = self.scheduler.timesteps[:qk_injection_t] if qk_injection_t >= 0 else []
    self.conv_injection_timesteps = self.scheduler.timesteps[:conv_injection_t] if conv_injection_t >= 0 else []
    register_extended_attention_pnp(self, self.qk_injection_timesteps)
    register_conv_injection(self, self.conv_injection_timesteps)
    set_tokenflow(self.unet)
The init_method function does 3 things:
(1) register_extended_attention_pnp: replaces the UNet's self-attention with extended attention (K and V come from multiple frames) and sets up PnP injection at the same time.
(2) register_conv_injection: replaces the forward of a conv (ResNet) block, the second ResNet block of an UpBlock, and sets up PnP injection there.
(3) set_tokenflow: replaces the UNet's 16 BasicTransformerBlock modules with TokenFlowBlock.
Here, qk_injection_timesteps and conv_injection_timesteps are two timestep lists, used to restrict the PnP injection to the first few steps.
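The slicing rule used for both lists can be tried in isolation. The sketch below is a minimal illustration, assuming a made-up schedule of 50 descending timesteps; the helper name injection_timesteps is hypothetical and not part of TokenFlow.

```python
def injection_timesteps(timesteps, n_inject):
    """Return the first n_inject timesteps, or an empty list when n_inject is negative."""
    return timesteps[:n_inject] if n_inject >= 0 else []


# Hypothetical schedule: 50 timesteps in descending order, as a sampler would visit them.
timesteps = list(range(981, 0, -20))

qk_steps = injection_timesteps(timesteps, 3)     # inject during the first 3 steps
conv_steps = injection_timesteps(timesteps, -1)  # negative value disables injection

# Injection is active only while the current step t is in the list.
active = [t in qk_steps for t in timesteps[:5]]
print(qk_steps, conv_steps, active)
```

Because sampling visits the timesteps from noisy to clean, taking a prefix of the list means injection happens only at the start of denoising.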
In addition to these modifications to the UNet, the batched_denoise_step function in the source code uses register_pivotal to mark the key-frame pass, so that the key frames are edited first. register_batch_idx sets the batch id before each batch is edited, and register_time sets the step t on certain UNet layers before the noise is predicted.
Next, we go through, in order, the modifications TokenFlow makes to the original Stable Diffusion model during inference.
register_extended_attention_pnp
The role of register_extended_attention_pnp: replace the forward function of attn1 in all 16 BasicTransformerBlock layers of the UNet with sa_forward, which implements extended attention, but register an injection_schedule (that is, enable PnP injection) for only 8 of those attn1 layers.
The structure of BasicTransformerBlock shows that, although attn1 is implemented by the CrossAttention class, no context is passed in for K and V during inference, so it is essentially self-attention.
class BasicTransformerBlock(nn.Module):
    def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True):
        super().__init__()
        self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout)  # is a self-attention
        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
        self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim,
                                    heads=n_heads, dim_head=d_head, dropout=dropout)  # cross attention
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.norm3 = nn.LayerNorm(dim)

    def forward(self, x, context=None):
        x = self.attn1(self.norm1(x)) + x
        x = self.attn2(self.norm2(x), context=context) + x
        x = self.ff(self.norm3(x)) + x
        return x
Input parameters: the two required parameters of register_extended_attention_pnp are the model (the object that holds the UNet) and injection_schedule. injection_schedule controls the time steps at which PnP injection is executed during inference, because we want to perform PnP operations only in the first few timesteps.
Rewriting extended attention: the original TokenFlow sa_forward computes the attention matrix over all frames, which consumes too many resources, so I added sa_3frame_forward, which computes attention over only 3 adjacent frames. My rewritten register_extended_attention_pnp is shown below. On top of the original function, I added an is_3_frame parameter to choose whether to use sa_3frame_forward.
First, let's set aside the two functions sa_forward and sa_3frame_forward and see how to find the corresponding UNet modules and modify their forward methods.
1. Reconstruct forward for attn1 of all BasicTransformerBlock layers
Choose register_forward_fun (sa_forward or sa_3frame_forward) based on is_3_frame, then traverse every module of the UNet and check whether it inherits from BasicTransformerBlock. If so, replace the forward of its attn1, but leave injection_schedule empty (that is, do not execute PnP).
module_names = []
register_forward_fun = sa_3frame_forward if is_3_frame else sa_forward
for module_name, module in model.unet.named_modules():
    if isinstance_str(module, "BasicTransformerBlock"):
        module_names.append(module_name)
        # replace BasicTransformerBlock.attn1's forward with the selected forward (sa_forward or sa_3frame_forward)
        module.attn1.forward = register_forward_fun(module.attn1)
        # set an empty injection_schedule ([]) for BasicTransformerBlock.attn1
        setattr(module.attn1, 'injection_schedule', [])
print(f"all change {len(module_names)} layer's BasicTransformerBlock.attn1.forward() for extended_attention_pnp...")
print(module_names)  # e.g. up_blocks.1.attentions.0.transformer_blocks.0
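The registration above relies on a simple Python pattern: a factory receives the module, builds a closure over it, and the closure is assigned to the instance's forward attribute. The sketch below reproduces only that pattern with a toy class; ToyAttention and make_forward are hypothetical stand-ins, not diffusers or TokenFlow code, and the "+ 1" is a placeholder for the injection step.

```python
class ToyAttention:
    """Stand-in for an attention module (hypothetical, not the diffusers class)."""

    def __init__(self, scale):
        self.scale = scale

    def forward(self, x):
        return x * self.scale


def make_forward(module):
    """Factory in the style of sa_forward: returns a closure bound to `module`."""

    def forward(x):
        # The closure reads attributes from the module at call time,
        # including ones that were attached after registration.
        bonus = 1 if module.injection_schedule else 0
        return x * module.scale + bonus

    return forward


attn = ToyAttention(scale=2)
before = attn.forward(3)

attn.forward = make_forward(attn)        # the instance attribute shadows the class method
setattr(attn, "injection_schedule", [])  # empty schedule: the injection branch stays off
after_empty = attn.forward(3)

attn.injection_schedule = [981]          # a later pass can switch injection on
after_set = attn.forward(3)
print(before, after_empty, after_set)
```

Since the closure looks up injection_schedule when it is called, a second registration pass can enable injection on selected layers simply by overwriting the attribute.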
isinstance_str checks whether the class hierarchy (MRO) of x contains a class named cls_name:
def isinstance_str(x: object, cls_name: str):
    for _cls in x.__class__.__mro__:
        if _cls.__name__ == cls_name:
            return True
    return False
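A short usage sketch of the name-based check follows. The three classes are empty placeholders defined only for this example; they merely share names with the real diffusers and TokenFlow classes.

```python
def isinstance_str(x: object, cls_name: str) -> bool:
    """Return True if any class in x's MRO is named cls_name."""
    for _cls in x.__class__.__mro__:
        if _cls.__name__ == cls_name:
            return True
    return False


# Placeholder classes (assumptions for illustration only).
class BasicTransformerBlock:
    pass


class TokenFlowBlock(BasicTransformerBlock):
    pass


class ResnetBlock2D:
    pass


checks = {
    "plain block": isinstance_str(BasicTransformerBlock(), "BasicTransformerBlock"),
    "subclass": isinstance_str(TokenFlowBlock(), "BasicTransformerBlock"),
    "unrelated": isinstance_str(ResnetBlock2D(), "BasicTransformerBlock"),
}
print(checks)
```

Matching on the class name rather than the class object means the check needs no import of the target class, and it still matches a module after its class has been swapped for a subclass.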
In this first pass, the forward of the following 16 attention layers in the UNet is replaced: 6 in down_blocks, 9 in up_blocks, and 1 in mid_block.
down_blocks.0.attentions.0.transformer_blocks.0.attn1
down_blocks.0.attentions.1.transformer_blocks.0.attn1
down_blocks.1.attentions.0.transformer_blocks.0.attn1
down_blocks.1.attentions.1.transformer_blocks.0.attn1
down_blocks.2.attentions.0.transformer_blocks.0.attn1
down_blocks.2.attentions.1.transformer_blocks.0.attn1
up_blocks.1.attentions.0.transformer_blocks.0.attn1
up_blocks.1.attentions.1.transformer_blocks.0.attn1
up_blocks.1.attentions.2.transformer_blocks.0.attn1
up_blocks.2.attentions.0.transformer_blocks.0.attn1
up_blocks.2.attentions.1.transformer_blocks.0.attn1
up_blocks.2.attentions.2.transformer_blocks.0.attn1
up_blocks.3.attentions.0.transformer_blocks.0.attn1
up_blocks.3.attentions.1.transformer_blocks.0.attn1
up_blocks.3.attentions.2.transformer_blocks.0.attn1
mid_block.attentions.0.transformer_blocks.0.attn1
2. Register injection_schedule and enable PnP injection on 8 of the attn1 layers
Guided by res_dict, replace the forward function of the specific attn1 layers again and register injection_schedule on them, so that they perform the PnP operation.
res_dict = {1: [1, 2], 2: [0, 1, 2], 3: [0, 1, 2]}  # upblock's self-attention layers
# we are injecting attention in blocks 4 - 11 of the Unet UpBlock, so not in the first block of the lowest resolution
for res in res_dict:  # e.g. res = 1
    for block in res_dict[res]:  # e.g. res_dict[1] = [1, 2]
        module = model.unet.up_blocks[res].attentions[block].transformer_blocks[0].attn1
        module.forward = sa_forward(module)
        setattr(module, 'injection_schedule', injection_schedule)
This second pass re-registers the following 8 up_blocks layers:
model.unet.up_blocks.1.attentions.1.transformer_blocks.0.attn1
model.unet.up_blocks.1.attentions.2.transformer_blocks.0.attn1
model.unet.up_blocks.2.attentions.0.transformer_blocks.0.attn1
model.unet.up_blocks.2.attentions.1.transformer_blocks.0.attn1
model.unet.up_blocks.2.attentions.2.transformer_blocks.0.attn1
model.unet.up_blocks.3.attentions.0.transformer_blocks.0.attn1
model.unet.up_blocks.3.attentions.1.transformer_blocks.0.attn1
model.unet.up_blocks.3.attentions.2.transformer_blocks.0.attn1
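As a quick cross-check, the 8 names can be generated from res_dict alone. This is a small sketch that only builds strings; it does not touch a real UNet.

```python
# Mapping from up_blocks index to the attention indices that receive PnP injection.
res_dict = {1: [1, 2], 2: [0, 1, 2], 3: [0, 1, 2]}

names = [
    f"up_blocks.{res}.attentions.{block}.transformer_blocks.0.attn1"
    for res, blocks in res_dict.items()
    for block in blocks
]

for name in names:
    print(name)
```

The enumeration makes it easy to see that up_blocks.1.attentions.0 is the one up_blocks attention layer that gets extended attention but no injection.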
3. sa_forward
PnP operation: because TokenFlow's input accounts for both PnP and classifier-free guidance, the single latent originally fed to the UNet becomes three copies, source_latents + x + x: source_latents is paired with source_prompt, one x with null_prompt (unconditional), and the other x with edit_prompt (conditional).
latent_model_input = torch.cat([source_latents] + ([x] * 2))
In this way, during UNet inference the input x can be sliced directly into 3 parts, and the source_latents part is injected into uncond_latents and cond_latents (PnP injection is a direct replacement). For self-attention, only Q and K are replaced.
source_latents = x[:n_frames]
uncond_latents = x[n_frames:2*n_frames]
cond_latents = x[2*n_frames:]
# source inject uncond
q[n_frames:2*n_frames] = q[:n_frames]
k[n_frames:2*n_frames] = k[:n_frames]
# source inject cond
q[2*n_frames:] = q[:n_frames]
k[2*n_frames:] = k[:n_frames]
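The effect of these slice assignments can be seen without tensors. In the sketch below, plain Python lists of labels stand in for the batch dimension of q and v (an assumption made purely for readability); list slice assignment of equal-length slices behaves like the tensor version here.

```python
n_frames = 2

# Batch layout: [source frames | unconditional frames | conditional frames].
q = ["src0", "src1", "unc0", "unc1", "cond0", "cond1"]
v = list(q)  # V is never overwritten by the injection

# source -> unconditional
q[n_frames:2 * n_frames] = q[:n_frames]
# source -> conditional
q[2 * n_frames:] = q[:n_frames]

print(q)  # all three thirds now hold the source entries
print(v)  # unchanged
```

After injection, the attention maps of the edited branches are computed from the source's Q and K, while V still carries the edited content.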
Extended attention: TokenFlow uses extended self-attention. For the i-th frame, Q comes from the features of frame i, while K and V come from all frames in the batch, so K and V need to be repeated to facilitate the subsequent calculations.
$$T_{base}=\mathrm{Softmax}\left(\frac{Q^i;\,[K^{i1},\dots,K^{ik}]}{\sqrt{d}}\right)\cdot[V^{i1},\dots,V^{ik}]$$
# KV reshape and repeat for extend_attention: Softmax(Q_i_frame @ K_all_frame) @ V_all_frame
# (n_frames, seq_len, dim) -> (1, n_frames * seq_len, dim) -> (n_frames, n_frames * seq_len, dim)
k_source = k[:n_frames]
k_uncond = k[n_frames:2 * n_frames].reshape(1, n_frames * sequence_length, -1).repeat(n_frames, 1, 1)
k_cond = k[2 * n_frames:].reshape(1, n_frames * sequence_length, -1).repeat(n_frames, 1, 1)
v_source = v[:n_frames]
v_uncond = v[n_frames:2 * n_frames].reshape(1, n_frames * sequence_length, -1).repeat(n_frames, 1, 1)
v_cond = v[2 * n_frames:].reshape(1, n_frames * sequence_length, -1).repeat(n_frames, 1, 1)
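To make the "Q from one frame, K and V from all frames" idea concrete, here is a minimal single-head sketch in pure Python. It is not the TokenFlow implementation: nested lists replace tensors, and the helper names softmax, attend and extended_attention are assumptions for this example.

```python
import math


def softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def attend(q_tokens, k_tokens, v_tokens):
    """Single-head scaled dot-product attention over lists of vectors."""
    d = len(q_tokens[0])
    out = []
    for q in q_tokens:
        scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in k_tokens]
        weights = softmax(scores)
        out.append([sum(w * v[i] for w, v in zip(weights, v_tokens))
                    for i in range(len(v_tokens[0]))])
    return out


def extended_attention(q_frames, k_frames, v_frames):
    """Each frame's queries attend to the concatenated keys/values of all frames."""
    k_all = [tok for frame in k_frames for tok in frame]  # n_frames * seq_len keys
    v_all = [tok for frame in v_frames for tok in frame]  # n_frames * seq_len values
    return [attend(q, k_all, v_all) for q in q_frames]


# Two identical frames, two tokens each, feature dim 2 (toy values).
frame = [[1.0, 0.0], [0.0, 1.0]]
frames = [frame, frame]

per_frame = attend(frame, frame, frame)
extended = extended_attention(frames, frames, frames)
print(per_frame)
print(extended[0])
```

With identical frames the extended result matches ordinary per-frame attention, which is a convenient sanity check: the frames only influence each other when their keys and values differ.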
Because the computation is done frame by frame, and multi-head attention is computed head by head, a double for loop computes the attention output of head j of frame i. The outputs are then concatenated along the frame dimension and the head dimension to obtain the final attention result:
Q @ K -> sim:
(b, 1, seq_len, dim//head) @ (b, 1, dim//head, frame*seq_len) -> (b, 1, seq_len, frame*seq_len)
sim @ V -> out:
(b, 1, seq_len, frame*seq_len) @ (b, 1, frame*seq_len, dim//head) -> (b, 1, seq_len, dim//head)
cat each head's out:
(b->n_frames, 1, seq_len, dim//head) -> (n_frames, 1, seq_len, dim//head)
cat each frame's out:
(n_frames, 1, seq_len, dim//head) -> (n_frames, heads, seq_len, dim//heads)
The complete code of sa_forward is as follows:
def sa_forward(self):
to_out = self.to_out # self.to_out = [linear, dropout]
if type(to_out) is torch.nn.modules.container.ModuleList:
to_out = self.to_out[0]
else:
to_out = self.to_out
def forward(x, encoder_hidden_states=None, attention_mask=None):
is_cross = encoder_hidden_states is not None # cross-attention or self-attention
h = self.heads
batch_size, sequence_length, dim = x.shape # (3*n_frames, seq_len, dim)
# batch layout: the first n_frames samples are source features, the middle n_frames are uncond features, the last n_frames are cond features
n_frames = batch_size // 3
# source_latents = x[:n_frames], uncond_latents = x[n_frames:2*n_frames], cond_latents = x[2*n_frames:]
encoder_hidden_states = encoder_hidden_states if is_cross else x
q = self.to_q(x)
k = self.to_k(encoder_hidden_states)
v = self.to_v(encoder_hidden_states)
# PnP injection of Q and K: performed only during the first few sampling timesteps (checks whether t is in the schedule), and only in the UpBlock layers
if self.injection_schedule is not None and (self.t in self.injection_schedule or self.t == 1000):
# source inject into unconditional
q[n_frames:2 * n_frames] = q[:n_frames]
k[n_frames:2 * n_frames] = k[:n_frames]
# source inject into conditional
q[2 * n_frames:] = q[:n_frames]
k[2 * n_frames:] = k[:n_frames]
# KV reshape and repeat for extend_attention: Softmax(Q_i_frame @ K_all_frame) @ V_all_frame
# (n_frames, seq_len, dim) -> (1, n_frames * seq_len, dim) -> (n_frames, n_frames * seq_len, dim)
k_source = k[:n_frames]
k_uncond = k[n_frames:2 * n_frames].reshape(1, n_frames * sequence_length, -1).repeat(n_frames, 1, 1)
k_cond = k[2 * n_frames:].reshape(1, n_frames * sequence_length, -1).repeat(n_frames, 1, 1)
v_source = v[:n_frames]
v_uncond = v[n_frames:2 * n_frames].reshape(1, n_frames * sequence_length, -1).repeat(n_frames, 1, 1)
v_cond = v[2 * n_frames:].reshape(1, n_frames * sequence_length, -1).repeat(n_frames, 1, 1)
# project QKV's source, cond and uncond, respectively
q_source = self.reshape_heads_to_batch_dim(q[:n_frames]) # q (n_frames*heads, seq_len, dim//heads)
q_uncond = self.reshape_heads_to_batch_dim(q[n_frames:2 * n_frames])
q_cond = self.reshape_heads_to_batch_dim(q[2 * n_frames:])
k_source = self.reshape_heads_to_batch_dim(k_source) # kv (n_frames*heads, n_frames * seq_len, dim//heads)
k_uncond = self.reshape_heads_to_batch_dim(k_uncond)
k_cond = self.reshape_heads_to_batch_dim(k_cond)
v_source = self.reshape_heads_to_batch_dim(v_source)
v_uncond = self.reshape_heads_to_batch_dim(v_uncond)
v_cond = self.reshape_heads_to_batch_dim(v_cond)
# split heads
q_src = q_source.view(n_frames, h, sequence_length, dim // h)
k_src = k_source.view(n_frames, h, sequence_length, dim // h)
v_src = v_source.view(n_frames, h, sequence_length, dim // h)
q_uncond = q_uncond.view(n_frames, h, sequence_length, dim // h)
k_uncond = k_uncond.view(n_frames, h, sequence_length * n_frames, dim // h)
v_uncond = v_uncond.view(n_frames, h, sequence_length * n_frames, dim // h)
q_cond = q_cond.view(n_frames, h, sequence_length, dim // h)
k_cond = k_cond.view(n_frames, h, sequence_length * n_frames, dim // h)
v_cond = v_cond.view(n_frames, h, sequence_length * n_frames, dim // h)
out_source_all = []
out_uncond_all = []
out_cond_all = []
# each frame or single_batch frames
single_batch = n_frames<=12
b = n_frames if single_batch else 1 # b=1
# do attention for each frame respectively. frames [frame:frame=b]
for frame in range(0, n_frames, b):
out_source = []
out_uncond = []
out_cond = []
# do attention for each head respectively. head j
for j in range(h):
# do attention for source, cond and uncond respectively, (b, 1, seq_len, dim//head) @ (b, 1, dim//head, frame*seq_len) -> (b, 1, seq_len, frame*seq_len)
sim_source_b = torch.bmm(q_src[frame: frame+ b, j], k_src[frame: frame+ b, j].transpose(-1, -2)) * self.scale
sim_uncond_b = torch.bmm(q_uncond[frame: frame+ b, j], k_uncond[frame: frame+ b, j].transpose(-1, -2)) * self.scale
sim_cond = torch.bmm(q_cond[frame: frame+ b, j], k_cond[frame: frame+ b, j].transpose(-1, -2)) * self.scale
# append each head's out, (b, 1, seq_len, frame*seq_len) @ (b, 1, frame*seq_len, dim//head) -> (b, 1, seq_len, dim//head)
out_source.append(torch.bmm(sim_source_b.softmax(dim=-1), v_src[frame: frame+ b, j]))
out_uncond.append(torch.bmm(sim_uncond_b.softmax(dim=-1), v_uncond[frame: frame+ b, j]))
out_cond.append(torch.bmm(sim_cond.softmax(dim=-1), v_cond[frame: frame+ b, j]))
# cat each head's out, (b->n_frames, 1, seq_len, dim//head) -> (n_frames, 1, seq_len, dim//head)
out_source = torch.cat(out_source, dim=0)
out_uncond = torch.cat(out_uncond, dim=0)
out_cond = torch.cat(out_cond, dim=0)
if single_batch: # if use single_batch, view single_batch frame's out
out_source = out_source.view(h, n_frames,sequence_length, dim // h).permute(1, 0, 2, 3).reshape(h * n_frames, sequence_length, -1)
out_uncond = out_uncond.view(h, n_frames,sequence_length, dim // h).permute(1, 0, 2, 3).reshape(h * n_frames, sequence_length, -1)
out_cond = out_cond.view(h, n_frames,sequence_length, dim // h).permute(1, 0, 2, 3).reshape(h * n_frames, sequence_length, -1)
# append each frame's out
out_source_all.append(out_source)
out_uncond_all.append(out_uncond)
out_cond_all.append(out_cond)
# cat each frame's out, (n_frames, 1, seq_len, dim//head) -> (n_frames, heads, seq_len, dim//heads)
out_source = torch.cat(out_source_all, dim=0)
out_uncond = torch.cat(out_uncond_all, dim=0)
out_cond = torch.cat(out_cond_all, dim=0)
# cat source, cond and uncond's out, (n_frames, heads, seq_len, dim//heads) -> (3*n_frames, heads, seq_len, dim//heads)
out = torch.cat([out_source, out_uncond, out_cond], dim=0)
out = self.reshape_batch_dim_to_heads(out)
return to_out(out)
return forward
4. sa_3frame_forward
Because PnP is used, self-attention here no longer repeats the input n_frames times as sa_forward does. Instead, source_latent goes through forward_original, ordinary self-attention whose K and V come from a single frame, while uncond_latent and cond_latent go through forward_extended, self-attention whose K and V come from 3 adjacent frames.
Each time the attention of the i-th frame is computed (window_size=3), the 3 frames with indices [i-1, i, i+1], centered on frame i, are taken as K and V:
def sa_3frame_forward(self): # self-attention is extended only to 3 consecutive key frames, not to all key frames
to_out = self.to_out
if type(to_out) is torch.nn.modules.container.ModuleList:
to_out = self.to_out[0]
else:
to_out = self.to_out
# original UNet attention forward
def forward_original(q, k, v):
n_frames, seq_len, dim = q.shape
h = self.heads
head_dim = dim // h
q = self.head_to_batch_dim(q).reshape(n_frames, h, seq_len, head_dim)
k = self.head_to_batch_dim(k).reshape(n_frames, h, seq_len, head_dim)
v = self.head_to_batch_dim(v).reshape(n_frames, h, seq_len, head_dim)
out_all = []
for frame in range(n_frames):
out = []
for j in range(h):
sim = torch.matmul(q[frame, j], k[frame, j].transpose(-1, -2)) * self.scale # (seq_len, seq_len)
out.append(torch.matmul(sim.softmax(dim=-1), v[frame, j])) # h * (seq_len, head_dim)
out = torch.cat(out, dim=0).reshape(-1, seq_len, head_dim) # (h, seq_len, head_dim)
out_all.append(out) # n_frames * (h, seq_len, head_dim)
out = torch.cat(out_all, dim=0) # (n_frames * h, seq_len, head_dim)
out = self.batch_to_head_dim(out) # (n_frames, seq_len, h * head_dim)
return out
# extended UNet attention forward (K and V come from a sliding window of frames)
def forward_extended(q, k, v):
n_frames, seq_len, dim = q.shape
h = self.heads
head_dim = dim // h
q = self.head_to_batch_dim(q).reshape(n_frames, h, seq_len, head_dim)
k = self.head_to_batch_dim(k).reshape(n_frames, h, seq_len, head_dim)
v = self.head_to_batch_dim(v).reshape(n_frames, h, seq_len, head_dim)
out_all = []
window_size = 3
for frame in range(n_frames): # e.g. frame=16, window_size=3: window=[15, 16, 17]
out = []
# sliding window to improve speed: take window_size frames centered on the current frame, e.g. for frame=1, window: [0, 1, 2]
window = range(max(0, frame-window_size // 2), min(n_frames, frame+window_size//2+1))
for j in range(h):
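sim_all = [] # similarities between the current frame and each frame in the window; len(sim_all) = len(window)
for kframe in window: # (1, 1, seq_len, head_dim) @ (1, 1, head_dim, seq_len) -> (1, 1, seq_len, seq_len)
# compute the similarity between the current frame and each frame kframe in the window, and store it in sim_all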
sim_all.append(torch.matmul(q[frame, j], k[kframe, j].transpose(-1, -2)) * self.scale) # window * (seq_len, seq_len)
sim_all = torch.cat(sim_all).reshape(len(window), seq_len, seq_len).transpose(0, 1) # (seq_len, window, seq_len)
sim_all = sim_all.reshape(seq_len, len(window) * seq_len) # (seq_len, window * seq_len)
out.append(torch.matmul(sim_all.softmax(dim=-1), v[window, j].reshape(len(window) * seq_len, head_dim))) # h * (seq_len, head_dim)
out = torch.cat(out, dim=0).reshape(-1, seq_len, head_dim) # (h, seq_len, head_dim)
out_all.append(out) # n_frames * (h, seq_len, head_dim)
out = torch.cat(out_all, dim=0) # (n_frames * h, seq_len, head_dim)
out = self.batch_to_head_dim(out) # (n_frames, seq_len, h * head_dim)
return out
def forward(x, encoder_hidden_states=None, attention_mask=None):
batch_size, sequence_length, dim = x.shape
h = self.heads
n_frames = batch_size // 3
is_cross = encoder_hidden_states is not None
encoder_hidden_states = encoder_hidden_states if is_cross else x
q = self.to_q(x)
k = self.to_k(encoder_hidden_states)
v = self.to_v(encoder_hidden_states)
if self.injection_schedule is not None and (self.t in self.injection_schedule or self.t == 1000):
# inject unconditional
q[n_frames:2 * n_frames] = q[:n_frames]
k[n_frames:2 * n_frames] = k[:n_frames]
# inject conditional
q[2 * n_frames:] = q[:n_frames]
k[2 * n_frames:] = k[:n_frames]
# source_latent uses ordinary self-attention; uncond and cond use self-attention whose K and V come from 3 adjacent frames
out_source = forward_original(q[:n_frames], k[:n_frames], v[:n_frames])
out_uncond = forward_extended(q[n_frames:2 * n_frames], k[n_frames:2 * n_frames], v[n_frames:2 * n_frames])
out_cond = forward_extended(q[2 * n_frames:], k[2 * n_frames:], v[2 * n_frames:])
out = torch.cat([out_source, out_uncond, out_cond], dim=0) # (3 * n_frames, seq_len, dim)
return to_out(out)
return forward
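The window expression inside forward_extended clips at the sequence boundaries, so the first and last frames see only 2 frames. A minimal sketch of just that index computation, with the hypothetical helper name frame_window:

```python
def frame_window(frame, n_frames, window_size=3):
    """Indices of the frames used as K/V for `frame`, clipped to [0, n_frames)."""
    half = window_size // 2
    return list(range(max(0, frame - half), min(n_frames, frame + half + 1)))


n_frames = 4
windows = [frame_window(f, n_frames) for f in range(n_frames)]
for f, w in enumerate(windows):
    print(f, w)
```

Per frame and head, the similarity matrix therefore has at most seq_len x (3 * seq_len) entries instead of seq_len x (n_frames * seq_len), which is where the memory saving over sa_forward comes from.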
register_conv_injection
After registering the self-attention forward for the UNet, register a new forward for the UNet's ResnetBlock2D at unet.up_blocks[1].resnets[1], and register injection_schedule on it at the same time to control the PnP injection time steps.
conv_forward has only one step more than the ordinary ResnetBlock2D forward, the PnP injection:
if self.injection_schedule is not None and (self.t in self.injection_schedule or self.t == 1000):
    source_batch_size = int(hidden_states.shape[0] // 3)
    # inject unconditional
    hidden_states[source_batch_size:2 * source_batch_size] = hidden_states[:source_batch_size]
    # inject conditional
    hidden_states[2 * source_batch_size:] = hidden_states[:source_batch_size]
set_tokenflow
__class__ returns the class of an object. set_tokenflow therefore finds all modules in the UNet whose class is (or inherits from) BasicTransformerBlock and replaces each module's class with a TokenFlowBlock subclass that uses the module's original class as its parent.
def set_tokenflow(model: torch.nn.Module):
    """
    Sets the tokenflow attention blocks in a model.
    """
    for _, module in model.named_modules():
        if isinstance_str(module, "BasicTransformerBlock"):
            # for all 16 matches, module.__class__ = <class 'diffusers.models.attention.BasicTransformerBlock'>
            make_tokenflow_block_fn = make_tokenflow_attention_block
            # use BasicTransformerBlock as the parent class and wrap it in a TokenFlowBlock subclass
            module.__class__ = make_tokenflow_block_fn(module.__class__)
            # Something needed for older versions of diffusers
            if not hasattr(module, "use_ada_layer_norm_zero"):
                module.use_ada_layer_norm = False
                module.use_ada_layer_norm_zero = False
    return model
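The class swap can be reproduced with plain Python objects. In this sketch, BasicBlock and make_patched_block are hypothetical stand-ins for BasicTransformerBlock and make_tokenflow_attention_block, and the "+ 1" replaces the real TokenFlow forward logic.

```python
class BasicBlock:
    """Stand-in for BasicTransformerBlock (hypothetical)."""

    def __init__(self, dim):
        self.dim = dim

    def forward(self, x):
        return x


def make_patched_block(block_class):
    """Build a subclass of block_class that overrides only forward."""

    class PatchedBlock(block_class):
        def forward(self, x):
            # Placeholder for the TokenFlow-specific forward.
            return super().forward(x) + 1

    return PatchedBlock


block = BasicBlock(dim=320)
block.__class__ = make_patched_block(block.__class__)  # swap the class in place

print(type(block).__name__, block.dim, block.forward(1))
```

Reassigning __class__ keeps the existing instance, and with it the attributes set in __init__ (in the real model, the loaded weights), while method lookup now goes through the subclass.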
make_tokenflow_attention_block
This function defines a TokenFlowBlock class and returns it. TokenFlowBlock inherits from the BasicTransformerBlock class and overrides only the forward function.
First, pivotal_pass determines whether the current pass processes key frames:
- For key frames, just save the normalized hidden states to pivot_hidden_states.
- For non-key frames, take the source_latent part of the non-key frames and compute its cosine similarity cosine_sim with the key frames, shape=(n_frames * seq_len, len(batch_idxs) * seq_len), find the index idx of the most similar key-frame token, and then stack 3 copies for source, uncond, and cond.
  - If the current batch is not the first batch, len(batch_idxs) = 2: save the most similar indices to idx1 and idx2 respectively.
  - If it is the first batch, len(batch_idxs) = 1: save the most similar index to idx1.
batch_size, sequence_length, dim = hidden_states.shape  # (batch, seq_len, dim)
n_frames = batch_size // 3  # batch = 3 * n_frames: source + uncond + cond
mid_idx = n_frames // 2
hidden_states = hidden_states.view(3, n_frames, sequence_length, dim)  # (source + uncond + cond, n_frames, seq_len, dim)
norm_hidden_states = self.norm1(hidden_states).view(3, n_frames, sequence_length, dim)
if self.pivotal_pass:  # key-frame pass: store the normalized hidden states
    self.pivot_hidden_states = norm_hidden_states  # (3, n_frames, sequence_length, dim); e.g. n_frames=5 in the key-frame pass
else:  # non-key frames: compute cosine_sim of source_latent against the key frames
    idx1 = []
    idx2 = []
    # frames are processed batch_size at a time; batch_idx is the index of the current batch,
    # e.g. with 32 frames and batch_size=8, batch_idx is 0, 1, 2, or 3
    batch_idxs = [self.batch_idx]
    if self.batch_idx > 0:  # not the first batch
        batch_idxs.append(self.batch_idx - 1)  # also add the previous batch's index, e.g. batch_idx=1 gives batch_idxs=[1, 0]
    # compare the source_latent of the non-key frames with the key frames;
    # with batch_idxs=[1, 0], only the key frames of batch 0 and batch 1 are used
    sim = batch_cosine_sim(norm_hidden_states[0].reshape(-1, dim),  # (n_frames*sequence_length, dim)
                           self.pivot_hidden_states[0][batch_idxs].reshape(-1, dim))  # (len(batch_idxs)*sequence_length, dim)
    if len(batch_idxs) == 2:  # not the first batch: save the most similar indices to idx1 and idx2
        # sim: (n_frames * seq_len, len(batch_idxs) * seq_len), len(batch_idxs)=2
        sim1, sim2 = sim.chunk(2, dim=1)
        idx1.append(sim1.argmax(dim=-1))  # (n_frames * seq_len) values, each an index in [0, seq_len - 1]
        idx2.append(sim2.argmax(dim=-1))  # (n_frames * seq_len) values, each an index in [0, seq_len - 1]
    else:  # first batch: save the most similar index to idx1
        idx1.append(sim.argmax(dim=-1))
    # stack 3 copies for source, uncond, and cond
    idx1 = torch.stack(idx1 * 3, dim=0)  # (3, n_frames * seq_len)
    idx1 = idx1.squeeze(1)
    if len(batch_idxs) == 2:
        idx2 = torch.stack(idx2 * 3, dim=0)  # (3, n_frames * seq_len)
        idx2 = idx2.squeeze(1)
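The nearest-neighbour lookup can be illustrated on a handful of 2-D vectors. The sketch below is pure Python and only mirrors the idea of "cosine similarity, then argmax, then gather"; cosine_sim and nearest_token_indices are hypothetical helpers, not the batch_cosine_sim used by TokenFlow.

```python
import math


def cosine_sim(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def nearest_token_indices(frame_tokens, key_tokens):
    """For every token of a non-key frame, index of the most similar key-frame token."""
    idx = []
    for tok in frame_tokens:
        sims = [cosine_sim(tok, key) for key in key_tokens]
        idx.append(max(range(len(sims)), key=sims.__getitem__))
    return idx


# Toy features: 3 key-frame tokens and 3 tokens of a non-key frame.
key_tokens = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
frame_tokens = [[0.1, 0.9], [2.0, 0.1], [3.0, 3.1]]

idx1 = nearest_token_indices(frame_tokens, key_tokens)

# The indices are later used to gather key-frame attention outputs.
key_attn_out = ["a0", "a1", "a2"]
gathered = [key_attn_out[i] for i in idx1]
print(idx1, gathered)
```

Cosine similarity ignores vector length, which is why [2.0, 0.1] still maps to [1.0, 0.0]: only the direction of the feature matters for the match.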
Next, Self-Attention (attn1), Cross-Attention (attn2), and Feed-forward (ff) are performed in sequence. Cross-Attention and Feed-forward are unchanged; the only change is in the Self-Attention step:
- For keyframes, calculate the self-attention result and save it.
- For non-key frames, fuse the attention results of the key frames. If the frame belongs to the first batch, the attention result of that batch's key frame is used directly. Otherwise, gather the attention results from the key frame of the current batch and the key frame of the previous batch and take a weighted average. The weight is determined by the frame's distance to the two key frames and, in the code, is additionally passed through a sigmoid:
$$w_1 = \sigma\!\left(\frac{|s - p_2|}{|s - p_1| + |s - p_2|}\right), \qquad \text{attn\_output} = w_1 \cdot \text{attn\_output}_1 + (1 - w_1) \cdot \text{attn\_output}_2$$
where s is the index of the frame, p1 is the index of the key frame of the current batch, and p2 is the index of the key frame of the previous batch.
# 1. Self-Attention
cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
if self.pivotal_pass:
    # norm_hidden_states.shape = 3, n_frames * seq_len, dim
    self.attn_output = self.attn1(
        norm_hidden_states.view(batch_size, sequence_length, dim),
        encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
        **cross_attention_kwargs,
    )
    # 3, n_frames * seq_len, dim -> 3 * n_frames, seq_len, dim
    self.kf_attn_output = self.attn_output
else:
    batch_kf_size, _, _ = self.kf_attn_output.shape
    self.attn_output = self.kf_attn_output.view(3, batch_kf_size // 3, sequence_length, dim)[:,
                       batch_idxs]  # 3, n_frames, seq_len, dim --> 3, len(batch_idxs), seq_len, dim

# gather values from attn_output, using idx as indices, and get a tensor of shape 3, n_frames, seq_len, dim
if not self.pivotal_pass:
    if len(batch_idxs) == 2:
        attn_1, attn_2 = self.attn_output[:, 0], self.attn_output[:, 1]
        attn_output1 = attn_1.gather(dim=1, index=idx1.unsqueeze(-1).repeat(1, 1, dim))
        attn_output2 = attn_2.gather(dim=1, index=idx2.unsqueeze(-1).repeat(1, 1, dim))

        s = torch.arange(0, n_frames).to(idx1.device) + batch_idxs[0] * n_frames
        # distance from the pivot
        p1 = batch_idxs[0] * n_frames + n_frames // 2
        p2 = batch_idxs[1] * n_frames + n_frames // 2
        d1 = torch.abs(s - p1)
        d2 = torch.abs(s - p2)
        # weight
        w1 = d2 / (d1 + d2)
        w1 = torch.sigmoid(w1)

        w1 = w1.unsqueeze(0).unsqueeze(-1).unsqueeze(-1).repeat(3, 1, sequence_length, dim)
        attn_output1 = attn_output1.view(3, n_frames, sequence_length, dim)
        attn_output2 = attn_output2.view(3, n_frames, sequence_length, dim)
        attn_output = w1 * attn_output1 + (1 - w1) * attn_output2
    else:
        attn_output = self.attn_output[:, 0].gather(dim=1, index=idx1.unsqueeze(-1).repeat(1, 1, dim))

    attn_output = attn_output.reshape(
        batch_size, sequence_length, dim)  # 3 * n_frames, seq_len, dim
else:
    attn_output = self.attn_output
hidden_states = hidden_states.reshape(batch_size, sequence_length, dim)  # 3 * n_frames, seq_len, dim
hidden_states = attn_output + hidden_states  # residual connection
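The blending weights from the code above can be computed by hand for a small case. This sketch follows the same arithmetic (p1, p2, d1, d2, then a sigmoid) in pure Python; blend_weights is a hypothetical helper, and n_frames=8 with batch_idxs=[1, 0] is an assumed example configuration.

```python
import math


def blend_weights(n_frames, batch_idxs):
    """w1 for every frame of the current batch, following w1 = sigmoid(d2 / (d1 + d2))."""
    cur, prev = batch_idxs
    p1 = cur * n_frames + n_frames // 2   # pivot index derived from the current batch
    p2 = prev * n_frames + n_frames // 2  # pivot index derived from the previous batch
    weights = []
    for s in range(cur * n_frames, (cur + 1) * n_frames):
        d1, d2 = abs(s - p1), abs(s - p2)
        ratio = d2 / (d1 + d2)
        weights.append(1.0 / (1.0 + math.exp(-ratio)))  # sigmoid
    return weights


w1 = blend_weights(n_frames=8, batch_idxs=[1, 0])
for s, w in zip(range(8, 16), w1):
    print(s, round(w, 4))
```

One consequence worth noting: the ratio d2 / (d1 + d2) lies in (0, 1], so after the sigmoid w1 stays between 0.5 and about 0.731. Under this formula attn_output1 always receives at least half of the weight, and the weight peaks at the frame whose index equals p1.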