YOLO26 自定义损失函数 分类任务自定义损失的接口约定 YOLO26 自定义损失函数 分类任务自定义损失的接口约定flyfish这个约定是 分类训练循环中调用损失函数的固定调用契约自定义损失类必须完全符合这个契约才能被框架正常识别、调用不会出现参数不匹配、返回值解包失败等报错。分别约束了「调用形式、入参格式、返回值格式、挂载位置」1. 调用形式约定必须实现__call__方法实例可直接调用框架在训练迭代中会以函数调用的方式直接使用损失函数实例框架内部的调用逻辑简化版为# 训练循环中框架的固定调用写法loss,loss_itemsself.model.criterion(preds,batch)因此自定义损失类必须实现__call__方法让类的实例可以像函数一样被直接调用。普通 PyTorch 损失函数继承nn.Module通过forward方法实现计算本质也是依赖nn.Module自带的__call__机制。YOLO 分类的损失不强制要求继承nn.Module只要类实现了__call__即可正常工作。2. 入参格式约定固定接收 2 个参数顺序不可修改损失函数的__call__方法必须固定接收两个入参顺序为(模型预测结果, 批次数据字典)不可调换、不可增减参数。第一个参数preds模型前向输出分类模型前向传播的输出结果格式兼容不同版本 Ultralytics 中分类模型的输出有两种形式直接返回形状为[batch_size, 类别数]的分类 logits 张量返回元组/列表通常结构为[中间特征图, 最终分类logits]有效预测值在第二个位置对应代码中的兼容处理# 兼容两种输出格式提取最终的分类预测张量predspreds[1]ifisinstance(preds,(list,tuple))elsepreds第二个参数batch批次数据字典数据加载器DataLoader返回的单批次数据固定为字典格式分类任务固定包含键cls对应形状为[batch_size]的类别索引标签不是 one-hot 编码损失计算时必须通过batch[cls]取出真实标签示例batch8 的二分类任务中batch[cls]是形如tensor([0, 1, 0, 1, 1, 0, 0, 1])的一维张量3. 返回值约定必须返回二元组必须返回 2 个值框架会自动解包少返回/多返回都会直接触发报错。返回值顺序作用格式要求第一个值用于反向传播更新模型权重必须是带计算图的标量张量可求导通常是批次内所有样本损失求均值后的结果第二个值用于训练日志统计、进度条打印、指标文件记录必须是分离梯度后的损失值.detach()不参与计算图避免显存泄漏对应代码中的标准实现lossfocal_loss.mean()# 第一个值带梯度的标量损失用于反向传播更新参数returnloss,loss.detach()# 第二个值脱梯度的损失值仅用于日志打印和统计目标检测任务中会返回多个损失项的字典但分类任务只有单损失直接返回脱梯度的标量即可。4. 挂载位置约定必须挂载为模型的criterion属性框架是通过self.model.criterion来定位损失函数的因此无论你用哪种注入方式最终都要把自定义损失的实例赋值给模型实例的.criterion属性。子类化标准法模型初始化时自动调用init_criterion()生成实例并赋值给self.criterion属于框架原生的标准流程示例下面是一个最小化的、完全符合接口约定的自定义损失包装原生交叉熵可以直接接入YOLO分类训练fromtorch.nnimportCrossEntropyLossclassSimpleCustomLoss:def__init__(self,label_smoothing0.1):self.ceCrossEntropyLoss(label_smoothinglabel_smoothing)# 约定1实现 __call__ 方法# 约定2固定入参顺序 preds, batchdef__call__(self,preds,batch):# 兼容模型输出格式predspreds[1]ifisinstance(preds,(list,tuple))elsepreds# 从 batch 字典中取出分类标签lossself.ce(preds,batch[cls])# 约定3返回 (带梯度损失, 脱梯度损失) 二元组returnloss,loss.detach()