
ApplyAdamWithAmsgradV2【免费下载链接】ops-nn本项目是CANN提供的神经网络类计算算子库实现网络在NPU上加速计算。项目地址: https://gitcode.com/cann/ops-nn产品支持情况产品是否支持Ascend 950PR/Ascend 950DT√Atlas A3 训练系列产品/Atlas A3 推理系列产品√Atlas A2 训练系列产品/Atlas A2 推理系列产品√Atlas 200I/500 A2 推理产品×Atlas 推理系列产品√Atlas 训练系列产品√功能说明算子功能Adam AMSGrad变体单步原地更新优化器。维护二阶矩历史最大值vhat用max(v, vhat)替代v抑制学习率振荡。var/m/v/vhat原地更新到var_out/m_out/v_out/vhat_out。计算公式$$ lr_{t}lr\cdot\sqrt{1-\beta_{2}^{t}},/,\left(1-\beta_{1}^{t}\right) $$ $$ m_{t}m_{t-1}\left(1-\beta_{1}\right)\left(g_{t}-m_{t-1}\right) $$ $$ v_{t}v_{t-1}\left(1-\beta_{2}\right)\left(g_{t}^{2}-v_{t-1}\right) $$ $$ \hat{v}{t}\max\left(\hat{v}{t-1}, v_{t}\right) $$ $$ \theta_{t}\theta_{t-1}-lr_{t}\cdot\frac{m_{t}}{\sqrt{\hat{v}_{t}}\epsilon} $$其中beta1_power、beta2_power为$\beta_{1}^{t}$、$\beta_{2}^{t}$标量epsilon加法顺序锁定为$\sqrt{\hat{v}_{t}}\epsilon$。参数说明参数名输入/输出描述数据类型数据格式var计算输入/计算输出待更新权重原地更新到var_outFLOAT32NDm计算输入/计算输出一阶矩估计原地更新到m_outFLOAT32NDv计算输入/计算输出二阶矩估计原地更新到v_outFLOAT32NDvhat计算输入/计算输出二阶矩历史最大值原地更新到vhat_outFLOAT32NDbeta1_power计算输入β1的t次幂标量 (1,)FLOAT32NDbeta2_power计算输入β2的t次幂标量 (1,)FLOAT32NDlr计算输入学习率标量 (1,)FLOAT32NDbeta1计算输入一阶矩衰减系数标量 (1,)FLOAT32NDbeta2计算输入二阶矩衰减系数标量 (1,)FLOAT32NDepsilon计算输入防止除零标量 (1,)FLOAT32NDgrad计算输入梯度与var同shapeFLOAT32FLOAT32NDvar_out计算输出更新后权重inplace varFLOAT32NDm_out计算输出更新后一阶矩inplace mFLOAT32NDv_out计算输出更新后二阶矩inplace vFLOAT32NDvhat_out计算输出更新后历史最大值inplace vhatFLOAT32ND约束说明var/m/v/vhat/grad 数据类型须一致仅支持FLOAT32标量输入固定FLOAT32。beta1_power、beta2_power、lr、beta1、beta2、epsilon 的shape须为 (1,)。调用说明调用方式调用样例说明图模式调用test_geir_apply_adam_with_amsgrad_v2通过图模式方式调用ApplyAdamWithAmsgradV2 算子。【免费下载链接】ops-nn本项目是CANN提供的神经网络类计算算子库实现网络在NPU上加速计算。项目地址: https://gitcode.com/cann/ops-nn创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考