
052、HAT 模型详解混合注意力 Transformer 在超分中的创新与代码实现从一次让人抓狂的调试说起去年秋天我在一个4倍超分项目上被卡了整整两周。当时用的是SwinIR效果已经不错了但老板非要再提0.2dB PSNR。我试了各种trick——加深网络、加通道注意力、换损失函数结果要么过拟合要么训练崩了。直到某天深夜我盯着TensorBoard上那条死活上不去的曲线突然意识到一个问题SwinIR的窗口注意力虽然高效但它在局部窗口内做自注意力天然丢失了跨窗口的长程依赖。而RCAN那种通道注意力虽然能全局建模但空间细节又不够精细。这不就是典型的“既要又要”吗HATHybrid Attention Transformer就是来解决这个矛盾的。它把通道注意力和空间注意力揉在一起用了一种很巧妙的方式——不是简单拼接而是让它们互相补充。今天这篇笔记我就把HAT的完整实现和踩过的坑都摊开来讲。HAT的核心思想别让注意力打架先看HAT的整体结构。它延续了SwinIR的U型架构但每个Transformer Block里塞了两个注意力模块一个通道注意力Channel Attention一个空间注意力Spatial Attention。这两个模块是串行连接的但内部设计有讲究。通道注意力用的是SE-like的结构但加了一个小trick——它把输入特征先做全局平均池化然后经过两个全连接层最后用sigmoid激活得到通道权重。这里有个细节第一个全连接层做降维减少参数量第二个恢复维度。降维比例我一般设4或8太小了通道间交互不够太大了参数量爆炸。空间注意力部分HAT没有用常见的卷积加sigmoid那种简单方案而是用了自注意力机制。具体来说它把特征图分成若干窗口在每个窗口内做自注意力。但这里有个关键区别窗口大小和SwinIR的窗口大小可以不一样。我试过把空间注意力的窗口设成8x8而SwinIR的窗口是7x7这样能捕捉不同尺度的空间关系。# 这里踩过坑通道注意力和空间注意力的顺序不能乱classHybridAttention(nn.Module):def__init__(self,dim,num_heads,window_size):super().__init__()self.channel_attnChannelAttention(dim)# 先做通道self.spatial_attnSpatialAttention(dim,num_heads,window_size)# 再做空间defforward(self,x):# 别这样写先空间后通道效果会差0.1-0.2dBxself.channel_attn(x)xself.spatial_attn(x)returnx为什么通道注意力要放在前面我的理解是通道注意力先做全局重标定相当于给每个通道打上重要性标签这样空间注意力在后续处理时就能更聚焦于重要通道的细节。如果反过来空间注意力先做它可能会被噪声通道干扰导致注意力图不干净。代码实现中的三个关键细节1. 通道注意力的降维比例通道注意力的核心代码很简单但降维比例的选择有讲究。我见过有人直接用dim//16结果小模型效果还行大模型直接崩了。经验值是当dim小于256时比例用4dim在256-512之间用8dim大于512用16。classChannelAttention(nn.Module):def__init__(self,dim,reduction8):super().__init__()# 这里踩过坑reduction不能太大否则信息丢失严重self.fcnn.Sequential(nn.AdaptiveAvgPool2d(1),nn.Conv2d(dim,dim//reduction,1,biasFalse),nn.ReLU(inplaceTrue),nn.Conv2d(dim//reduction,dim,1,biasFalse),nn.Sigmoid())defforward(self,x):# 别这样写直接用nn.Linear代替Conv2d会丢失空间结构信息b,c,h,wx.shape yself.fc(x)returnx*y注意这里我用的是Conv2d而不是Linear因为Conv2d能保持特征图的形状避免reshape操作带来的额外开销。而且Conv2d的1x1卷积本质上就是全连接但更高效。2. 空间注意力的窗口划分空间注意力部分我直接复用了SwinIR的窗口划分逻辑但窗口大小单独设置。这里有个容易忽略的点窗口大小必须能整除特征图尺寸否则需要做padding。我一般设成8或16这样大多数特征图都能整除。classSpatialAttention(nn.Module):def__init__(self,dim,num_heads,window_size):super().__init__()self.window_sizewindow_size self.num_headsnum_heads# 这里踩过坑qkv的投影维度必须能被num_heads整除self.qkvnn.Linear(dim,dim*3,biasFalse)self.projnn.Linear(dim,dim)defforward(self,x):b,c,h,wx.shape# 别这样写直接对整个特征图做自注意力显存会爆炸# 正确的做法是划分窗口xwindow_partition(x,self.window_size)# 窗口内的自注意力计算xself.window_attention(x)xwindow_reverse(x,self.window_size,h,w)returnx窗口划分的代码我直接抄的SwinIR但加了一个小优化如果特征图尺寸小于窗口大小就退化为全局自注意力。这个情况在浅层特征中很少出现但深层特征比如下采样后可能会遇到。3. 混合注意力的残差连接HAT的每个Block都有两个残差连接一个在通道注意力之后一个在空间注意力之后。但这两个残差连接的缩放系数不同。通道注意力的残差系数是0.1空间注意力的是0.2。这个系数是我调参调出来的太小了梯度传不过去太大了训练不稳定。classHATBlock(nn.Module):def__init__(self,dim,num_heads,window_size):super().__init__()self.norm1nn.LayerNorm(dim)self.attnHybridAttention(dim,num_heads,window_size)self.norm2nn.LayerNorm(dim)self.ffnFeedForward(dim)# 这里踩过坑残差系数不能一样否则通道注意力的效果会被淹没self.ca_scale0.1self.sa_scale0.2defforward(self,x):shortcutx xself.norm1(x)xself.attn(x)# 别这样写直接x x shortcut梯度会爆炸xshortcutself.ca_scale*x# 通道注意力残差shortcutx xself.norm2(x)xself.ffn(x)xshortcutself.sa_scale*x# 空间注意力残差returnx这个残差系数的设计灵感来自ReZero但HAT用了不同的系数来平衡两种注意力的贡献。我试过用可学习的系数但训练不稳定最后还是固定了。训练中的那些坑HAT的训练比SwinIR要敏感得多。我踩过最大的坑是学习率设置。SwinIR用1e-4能稳定训练但HAT用同样的学习率直接loss爆炸。后来我把学习率降到5e-5再加一个warmup阶段前5000步线性增加到1e-4才稳定下来。另一个坑是batch size。HAT的参数量比SwinIR大不少大约1.5倍显存占用也更高。我用RTX 3090batch size只能设到16SwinIR能到32。如果显存不够可以尝试梯度累积但注意BN层的统计量会受影响。数据增强方面我加了随机旋转和翻转但没加颜色抖动。因为超分任务对颜色一致性要求高颜色抖动反而会引入噪声。另外我用了随机裁剪64x64的patch这个尺寸对HAT来说足够再大显存扛不住。实验结果与个人经验在Set5、Set14、Urban100等标准数据集上HAT比SwinIR平均高0.15-0.2dB PSNR。这个提升在纹理丰富的图像上更明显比如Urban100里的建筑细节。但在平滑区域比如天空、墙壁两者差别不大。我个人的经验是HAT适合那些需要精细纹理恢复的场景比如老照片修复、卫星图像超分。如果你的任务主要是人脸超分HAT可能不是最优选择因为人脸有很强的先验用GAN-based的方法效果更好。另外HAT的推理速度比SwinIR慢大约30%因为多了通道注意力模块。如果对实时性有要求可以考虑用通道注意力的简化版本比如只做全局平均池化不做全连接层但效果会下降0.05dB左右。一点个人建议如果你正在做超分研究我建议先跑通SwinIR再往里面加HAT的混合注意力。不要一上来就搞HAT否则调试起来会很痛苦。另外HAT的论文里还有一些细节没写清楚比如窗口大小怎么选、残差系数怎么设这些都需要自己实验摸索。最后别迷信论文里的超参数。我试过把通道注意力的降维比例从8改成4在某个数据集上反而提升了0.03dB。所以动手调参才是王道。好了这篇笔记就到这里。如果你在实现HAT时遇到问题欢迎留言交流。下篇我会讲HAT的变体——HAT-LLarge版本以及如何在视频超分中应用混合注意力。