038、Transformer进军视频:VSRT视频超分Transformer的结构创新与复现 038、Transformer进军视频VSRT视频超分Transformer的结构创新与复现上周调一个视频超分模型训练到第3个epoch突然loss炸了——从0.021直接跳到0.89我盯着终端愣了三秒。检查梯度发现是时间维度上的注意力权重在某个帧上出现了NaN。这种问题在视频超分里太典型了尤其是当你试图把图像Transformer直接搬到视频上时时序信息的处理稍有不慎就会崩掉。VSRTVideo Super-Resolution Transformer就是为解决这类问题而生的。它不像ViT那样把视频帧当成独立图像处理而是设计了一套专门针对视频时序特性的Transformer架构。今天我们就从结构创新到代码复现把VSRT掰开揉碎讲清楚。为什么视频超分需要专门的Transformer先看一个直观问题假设你有一段10帧的低分辨率视频每帧64x64。如果用ViT处理你得把每帧切成16x16的patch10帧就是10x(64/16)^2160个patch。但问题在于——这些patch之间没有任何时序关联信息。模型不知道第1帧的某个patch和第5帧的对应patch有什么关系这就等于把视频退化成了独立图像序列。VSRT的核心洞察是视频超分不仅要利用空间纹理更要挖掘时序上的运动补偿信息。它把Transformer的注意力机制拆成了两个维度——空间注意力和时间注意力分别处理帧内纹理和帧间运动。VSRT的结构创新三阶段设计VSRT的架构可以理解为三个串联的模块浅层特征提取、深层时序-空间特征融合、上采样重建。其中第二个模块是重点它包含多个VSRT Block每个Block内部又分为时间注意力Temporal Attention和空间注意力Spatial Attention。时间注意力模块的设计很有意思。它没有直接用全帧注意力而是采用了“窗口化”策略——只在相邻帧之间做注意力。比如处理第t帧时只关注[t-2, t-1, t, t1, t2]这5帧。这样做有两个好处一是计算量从O(N^2)降到O(N*W)W是窗口大小二是避免了远距离帧之间的无效关联因为视频中相隔太远的帧往往没有直接的运动补偿关系。空间注意力模块则相对常规但有一个关键细节——它引入了相对位置编码。这里踩过坑如果你直接用绝对位置编码模型会对帧内patch的绝对坐标过度敏感导致平移不变性丢失。VSRT用的是2D相对位置偏置让模型关注patch之间的相对空间关系而非绝对位置。代码复现从零搭建VSRT Block下面我们直接上手写VSRT的核心模块。注意这里我跳过了数据加载和预处理部分那些在之前的博客里讲过直接复用就行。importtorchimporttorch.nnasnnimporttorch.nn.functionalasFimportmathclassTemporalAttention(nn.Module):def__init__(self,dim,num_heads8,window_size5):super().__init__()self.num_headsnum_heads self.window_sizewindow_size self.scale(dim//num_heads)**-0.5# QKV投影注意这里dim是输入特征维度self.qkvnn.Linear(dim,dim*3,biasFalse)self.projnn.Linear(dim,dim)# 相对时间位置编码窗口大小为5时相对位置范围是[-2,2]# 别这样写直接硬编码relative_position_bias_table nn.Parameter(torch.zeros(5, num_heads))# 应该用2*window_size-1来覆盖所有可能的相对位置self.relative_position_bias_tablenn.Parameter(torch.zeros(2*window_size-1,num_heads))# 预计算相对位置索引coordstorch.arange(window_size)relative_coordscoords[:,None]-coords[None,:]# [window_size, window_size]relative_coordswindow_size-1# 偏移到非负self.register_buffer(relative_position_index,relative_coords)# 初始化偏置参数nn.init.trunc_normal_(self.relative_position_bias_table,std.02)defforward(self,x):# x shape: [B, T, H*W, C] T是帧数这里假设Twindow_sizeB,T,N,Cx.shape# 生成QKVqkvself.qkv(x).reshape(B,T,N,3,self.num_heads,C//self.num_heads)qkvqkv.permute(3,0,4,1,2,5)# [3, B, num_heads, T, N, head_dim]q,k,vqkv[0],qkv[1],qkv[2]# 计算时间维度上的注意力# 这里只在T维度做注意力N维度保持不变attn(q k.transpose(-2,-1))*self.scale# [B, num_heads, T, N, T]# 加入相对时间位置偏置relative_position_biasself.relative_position_bias_table[self.relative_position_index.view(-1)].view(self.window_size,self.window_size,-1)# [T, T, num_heads]relative_position_biasrelative_position_bias.permute(2,0,1).unsqueeze(0).unsqueeze(2)# 这里踩过坑unsqueeze(2)是为了匹配N维度因为attn的shape是[B, num_heads, T, N, T]attnattnrelative_position_bias attnattn.softmax(dim-1)x(attn v).permute(0,2,3,1,4).reshape(B,T,N,C)xself.proj(x)returnx这段代码里有个容易忽略的细节时间注意力是在T维度上做的但每个时间步的N个空间patch是独立处理的。这意味着模型可以捕捉到不同帧之间对应patch的运动信息但不会混入空间上的干扰。接下来是空间注意力模块它处理的是帧内的patch关系classSpatialAttention(nn.Module):def__init__(self,dim,num_heads8,window_size7):super().__init__()self.num_headsnum_heads self.window_sizewindow_size# 空间窗口大小比如7x7self.scale(dim//num_heads)**-0.5self.qkvnn.Linear(dim,dim*3,biasFalse)self.projnn.Linear(dim,dim)# 2D相对位置偏置# 别这样写self.relative_position_bias_table nn.Parameter(torch.zeros(window_size**2, num_heads))# 应该用(2*window_size-1)^2来覆盖所有相对位置self.relative_position_bias_tablenn.Parameter(torch.zeros((2*window_size-1)**2,num_heads))# 计算相对位置索引coords_htorch.arange(window_size)coords_wtorch.arange(window_size)coordstorch.stack(torch.meshgrid([coords_h,coords_w]))# [2, window_size, window_size]coords_flattencoords.flatten(1)# [2, window_size^2]relative_coordscoords_flatten[:,:,None]-coords_flatten[:,None,:]# [2, window_size^2, window_size^2]relative_coordsrelative_coords.permute(1,2,0).contiguous()# [window_size^2, window_size^2, 2]relative_coords[:,:,0]window_size-1relative_coords[:,:,1]window_size-1relative_coords[:,:,0]*2*window_size-1relative_position_indexrelative_coords.sum(-1)# [window_size^2, window_size^2]self.register_buffer(relative_position_index,relative_position_index)nn.init.trunc_normal_(self.relative_position_bias_table,std.02)defforward(self,x):# x shape: [B, T, N, C] N H*WB,T,N,Cx.shape qkvself.qkv(x).reshape(B,T,N,3,self.num_heads,C//self.num_heads)qkvqkv.permute(3,0,1,4,2,5)# [3, B, T, num_heads, N, head_dim]q,k,vqkv[0],qkv[1],qkv[2]attn(q k.transpose(-2,-1))*self.scale# [B, T, num_heads, N, N]# 加入2D相对位置偏置relative_position_biasself.relative_position_bias_table[self.relative_position_index.view(-1)].view(self.window_size**2,self.window_size**2,-1)# [N, N, num_heads]relative_position_biasrelative_position_bias.permute(2,0,1).unsqueeze(0).unsqueeze(0)# 这里踩过坑unsqueeze(0)两次分别对应B和T维度attnattnrelative_position_bias attnattn.softmax(dim-1)x(attn v).permute(0,1,3,2,4).reshape(B,T,N,C)xself.proj(x)returnx把这两个模块组合成VSRT Block注意顺序先做时间注意力再做空间注意力。这个顺序不是随便定的——先做时间注意力可以让模型先建立帧间对应关系再做空间注意力时就能利用到时间维度的信息来指导空间纹理重建。classVSRTBlock(nn.Module):def__init__(self,dim,num_heads8,temporal_window5,spatial_window7):super().__init__()self.temporal_attnTemporalAttention(dim,num_heads,temporal_window)self.spatial_attnSpatialAttention(dim,num_heads,spatial_window)self.norm1nn.LayerNorm(dim)self.norm2nn.LayerNorm(dim)self.ffnnn.Sequential(nn.Linear(dim,dim*4),nn.GELU(),nn.Linear(dim*4,dim))self.norm3nn.LayerNorm(dim)defforward(self,x):# x shape: [B, T, H*W, C]# 时间注意力 残差xxself.temporal_attn(self.norm1(x))# 空间注意力 残差xxself.spatial_attn(self.norm2(x))# FFN 残差xxself.ffn(self.norm3(x))returnx训练中的坑与经验回到开头那个loss爆炸的问题。排查后发现是时间注意力模块的softmax在某个帧上出现了数值不稳定。原因是当某些帧之间运动过大时注意力分数会集中在少数几个帧上导致梯度爆炸。解决办法有两个方向在时间注意力中加入dropout让注意力分布更平滑使用梯度裁剪我一般设max_norm1.0另外VSRT对batch size比较敏感。我试过batch size4时模型收敛很慢提到8之后效果明显改善。但显存不够怎么办可以尝试梯度累积或者减小空间窗口大小——从7x7降到5x5能省不少显存。还有一个容易被忽视的点数据增强。视频超分的数据增强和图像不同你不能随机裁剪单帧因为会破坏时序一致性。我建议的做法是先对视频片段做统一的随机裁剪再对每帧做相同的旋转或翻转。别这样写对每帧独立做随机翻转那样会让运动方向混乱。个人经验性建议如果你打算用VSRT做实际项目有几点值得注意第一VSRT在运动剧烈的视频上效果会打折扣。它的时间窗口机制假设相邻帧之间有较强的相关性如果场景切换太快窗口内的帧可能没有有效信息。这种情况下可以考虑增大时间窗口或者引入光流引导。第二推理速度是个瓶颈。VSRT的参数量不算大大约5M左右但注意力计算在长序列上很慢。我试过处理30帧的视频单帧推理时间接近100ms。优化方向包括用FlashAttention加速、减少空间窗口的重叠、或者只在关键帧上做时间注意力。第三如果你要做视频超分的落地应用建议先在小规模数据集上验证VSRT的时序建模能力是否真的优于图像级超分后处理。有些场景下简单的帧间对齐单帧超分反而更稳定。最后代码里那个相对位置索引的计算我建议你手算一遍确认维度。我第一次写的时候就在那里卡了半天debug发现索引算错了导致注意力偏置全乱套。这种细节问题跑一次前向传播就能暴露出来。