StableGrad: Backward Scale Control without Batch Normalization
Jose I. Mestre, Alberto Fernández-Hernández, Cristian Pérez-Corral, Manuel F. Dolz, Enrique S. Quintana-Ortí
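AI summary: This paper proposes StableGrad, a method that stabilizes the training of deep neural networks without Batch Normalization by controlling weight-gradient scale at the optimizer level. It is particularly suited to settings such as physics-informed neural networks.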
Training very deep neural networks requires controlling the propagation of magnitudes across depth. Without such control, activations and gradients may vanish, explode, or enter unstable regimes that make optimization fail. Modern architectures often mitigate this problem through Batch Normalization, residual connections, or other normalization layers, which repeatedly re-scale or bypass intermediate representations. However, these mechanisms are not always appropriate. In Physics-Informed Neural Networks (PINNs), the network represents a continuous physical field and its input derivatives define the training objective, making batch-dependent normalization problematic because it can introduce non-local dependencies into the predicted field and its derivatives. We propose StableGrad, an optimizer-level scale-control mechanism that corrects layer-wise weight-gradient imbalances without modifying the forward model. Because the normalization is applied only after backpropagation and before the optimizer update, the network output, its derivatives, and the physical residual remain unchanged. We analyze the effective training dynamics induced by this rescaling and evaluate StableGrad on deep PINNs as the target application, with BatchNorm-free convolutional networks serving as a diagnostic stress test. On PINN benchmarks, StableGrad improves matched-depth solution accuracy and makes deeper models more reliable under standard optimization. On ResNet and EfficientNet architectures, where removing Batch Normalization normally leads to training collapse, StableGrad stabilizes optimization without introducing any other architectural change. These results show that optimizer-level control of weight-gradient scale can provide a practical alternative when forward normalization is unavailable or undesirable.
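The abstract does not state StableGrad's exact rescaling rule. The following is only a minimal sketch under one assumption: a LARS-style rule in which each layer's gradient is rescaled so that every layer ends up with the same gradient-to-weight norm ratio, with the common target set to the geometric mean of the raw per-layer ratios. It illustrates the placement the abstract describes, after backpropagation and before the optimizer update. The names `rescale_layer_grads`, `sgd_step`, and `training_step` are hypothetical. Each layer's weights and gradients are flattened into a list of floats.

```python
import math
from typing import List

Layer = List[float]  # one layer's parameters (or gradients), flattened


def l2_norm(v: Layer) -> float:
    """Euclidean norm of a flattened layer."""
    return math.sqrt(sum(x * x for x in v))


def rescale_layer_grads(weights: List[Layer], grads: List[Layer],
                        eps: float = 1e-12) -> List[Layer]:
    """Hypothetical layer-wise rescaling (an assumption, not the paper's rule).

    Every layer's gradient is scaled so that ||g_l|| / ||W_l|| equals a
    common target, the geometric mean of the raw per-layer ratios.
    Gradient directions are preserved and the weights are only read.
    """
    # Raw gradient-to-weight ratio per layer; eps guards against zero norms.
    ratios = [(l2_norm(g) + eps) / (l2_norm(w) + eps)
              for w, g in zip(weights, grads)]
    # Common target: geometric mean of the ratios (an illustrative choice).
    target = math.exp(sum(math.log(r) for r in ratios) / len(ratios))
    # Scale each layer's gradient by target / ratio; returns new lists.
    return [[x * (target / r) for x in g] for g, r in zip(grads, ratios)]


def sgd_step(weights: List[Layer], grads: List[Layer],
             lr: float) -> List[Layer]:
    """Plain SGD update, W <- W - lr * g, applied layer by layer."""
    return [[w - lr * g for w, g in zip(wl, gl)]
            for wl, gl in zip(weights, grads)]


def training_step(weights: List[Layer], grads: List[Layer],
                  lr: float = 0.1) -> List[Layer]:
    """grads are assumed to come from backprop through the unchanged forward
    model; the rescaling sits between backprop and the optimizer update."""
    return sgd_step(weights, rescale_layer_grads(weights, grads), lr)


# Usage: three layers whose raw gradient scales differ by many orders of magnitude.
weights = [[0.5, -0.5], [1.0, 2.0, -2.0], [3.0, 4.0]]
grads = [[1e-6, 0.0], [0.3, 0.0, 0.4], [300.0, 400.0]]
scaled = rescale_layer_grads(weights, grads)
for w, g in zip(weights, scaled):
    print(l2_norm(g) / l2_norm(w))  # approximately the same value for every layer
new_weights = training_step(weights, grads, lr=0.1)
```

Because `rescale_layer_grads` only reads the weights and returns new gradient lists, the forward computation that produced the gradients is untouched; only the size of each layer's update changes. The geometric-mean target is an illustrative choice that keeps the common ratio between the smallest and largest raw ratios.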