大数跨境
0
0

神经常微分方程(Neural ODEs):分步解释示例

神经常微分方程(Neural ODEs):分步解释示例 AI算法之道
2025-10-13
0
导读:神经常微分方程原理讲解

点击蓝字
关注我们










01


引言


宇宙中的许多现象都是连续过程,通常由常微分方程(ODE)建立模型。这在各科学领域都十分常见,具体案例包括天气预测、流体运动模拟与疾病传播等课题。

对于大多数此类问题,我们难以获得解析解,因此常采用数值方法求解这些常微分方程。但有时我们甚至无法准确列出方程表达式——可能仅掌握观测数据,且已知目标建模过程具有连续性。此时我们期望从数据中直接学习对应的常微分方程。

神经网络擅长从数据中学习,但传统神经网络通常在离散空间中运作。神经常微分方程(ODE)的核心思想由此诞生:通过学习系统随时间连续演化的规律,而非直接预测离散状态。该理论由陈等人于论文《Neural Ordinary Differential Equations》中首次提出,并成功应用于多种简单现象的建模。

论文:https://arxiv.org/pdf/1806.07366

本文将不再赘述原论文细节,而是聚焦于该技术的原理机制。





02

  核心思想

神经常微分方程(Neural ODEs)的关键在于构建一个预测导数(变化率)的函数,而非直接预测未来状态。一旦获得描述导数的模型,我们就能通过积分运算还原系统轨迹,这与传统常微分方程求解思路一脉相承。

欧拉法是最基础的数值积分方案:基于当前状态,利用估计的导数向前推进一小段时间步长。

在神经常微分方程中,我们沿用相同原理,但使用神经网络参数化导数函数,使其能灵活建模复杂的未知动力学系统。这种连续时间框架天然适配非均匀时间序列(非均匀间隔的测量点数据),并为系统演化注入了平滑连续的归纳偏置。正如卷积神经网络天生适合处理图像数据,神经常微分方程(Neural ODEs)对连续系统建模具有天然优势。

其数学表达如下所示,其中h代表神经网络:

若采用欧拉法对此进行积分,即可得到下一状态的预测值:

其中h代表步长,z表示系统在t时刻的状态。通过缩小步长可以获得更精确的解。或者也可以采用标准的常微分方程求解器,这类求解器会根据导数函数变化的"快慢"动态调整步长——当导数变化平缓时采用较大步长,变化剧烈时自动缩小步长。





03

如何训练该类模型?

最简易且通常最高效的方法是直接"展开"常微分方程的每个计算步骤(类似于循环神经网络的处理方式),然后通过神经网络进行反向传播(例如,若采用两步欧拉积分,可将其视为具有两层结构的神经网络)。然而,当网络规模较大时,这种方法会消耗大量内存。为此,陈发明了一种名为伴随方法的技术,它可以实现高效利用内存的反向传播。对于较小的网络,伴随方法的速度通常比这种“简单”的展开方法慢得多。





04

代码示例


以下代码是使用Pytorch实现的(简化版)欧拉法。完整代码示例可在此处查看。

代码:https://colab.research.google.com/drive/17kBvuIeXwVgrRwsi27_rqXjiBxZKYx--?usp=sharing

我们采用螺旋数据集,目标是在仅给定初始数据点(下图绿色点)的情况下预测后续的螺旋数据点(蓝色点)。

直观来说,这就像原点处有一个点,随时间推移受到某种神秘力的推动,而我们试图还原其运动轨迹。我们设置的损失函数是预测数据点与真实值之间的均方误差。

为简化实现,我们假设固定步长等于输入时间点之间的差值。然而,在实践中,能够明确控制步长是比较好的做法。你可能希望在不同时间点之间进行多步计算,不过这会增加额外的代码复杂度,所以我在这里省略了这部分内容。

class ODEFunc(nn.Module):    def __init__(self, hidden_dim, input_dim):        super(ODEFunc, self).__init__()        self.net = nn.Sequential(            nn.Linear(input_dim, hidden_dim),            nn.ELU(),            nn.Linear(hidden_dim, hidden_dim),            nn.ELU(),            nn.Linear(hidden_dim, input_dim)        )    def forward(self, t, x):        return self.net(x)class ODESolver(nn.Module):    def __init__(self, func):        super(ODESolver, self).__init__()        self.func = func        self.method = method    def forward(self, x0, t):        # t is vector of time points to evaluate (i.e. [1, 2, 3, ...])        h = t[1] - t[0]  # Assuming uniform time steps        trajectory = [x0]        x = x0        for i in range(len(t) - 1):            # Simple Euler method            dx = self.func(t[i], x) * h            x = x + dx            trajectory.append(x)        return torch.stack(trajectory)

训练的主函数如下:

def main():    n_points = 50    spiral_data = generate_spiral_data(n_points=n_points, noise=0.1)    spiral_data = torch.tensor(spiral_data, dtype=torch.float32)    t = torch.linspace(01, n_points)    func = ODEFunc(hidden_dim=100, input_dim=2)    ode_solver = ODESolver(func)    x0 = spiral_data[0]    optimizer = optim.Adam(func.parameters(), lr=0.01)    # Training loop    n_epochs = 300    losses = []    for epoch in tqdm(range(n_epochs)):        optimizer.zero_grad()        pred_trajectory = ode_solver(x0, t)        # Loss is MSE between predicted and true trajectory (only for original dimensions)        loss = nn.MSELoss()(pred_trajectory[:, :, :2].squeeze(), spiral_data)        loss.backward()        optimizer.step()        losses.append(loss.item())

下面可视化了数据集、训练损失以及向量场(即神经网络预测一个点应该移动的方向)。

这里有几点需要注意。其一,损失值有点不稳定,这在神经常微分方程(Neural ODEs)中很常见;其二,鉴于这只是个简单示例问题,训练却花费了很长时间。下一节将讨论这些问题以及其他一些需要考虑的因素。

正如我们将看到的,普通的神经常微分方程本身的效果其实并不理想。在很多情况下,需要对其进行扩展才能有效应用。





05

实用技巧与注意事项

在本节中,我们将介绍一些使用技巧总结如下:

  • 批处理(Batching)

在上述示例中,我们尝试根据初始点预测单个轨迹。然而,实际应用中通常需要从包含多条不同长度轨迹的数据集中学习,并希望在训练时对其进行批处理以加速训练。为了实现高效的批处理,我们需要在损失函数中引入掩码机制(masking)。此外,若轨迹长度差异显著,建议根据序列长度对数据进行分组,以最大化计算资源的利用率。

  • 激活函数选择(避免使用ReLU)

由于神经网络的目标是预测连续动态现象,因此使用具有急剧不连续性的激活函数(如修正线性单元 ReLU)往往会导致模型表现不佳。推荐选用平滑的激活函数(如  Swish)。

  • 课程学习(Curriculum Learning)

与循环神经网络(RNN)类似,神经微分方程可能因时序步长缩小或轨迹长度增加而面临梯度消失或爆炸问题,这会导致损失值剧烈震荡。缓解方法之一是逐步增加训练时的序列长度。例如:初始阶段仅预测相邻点之间的差异;随后每个训练周期逐步延长轨迹长度(如先训练长度为4的序列,再依次扩展到8、16),直至覆盖完整轨迹长度。

  • 增强型神经微分方程(Augmented Neural ODEs)

这是一种极其简洁的扩展方法。其核心思路是:若当前状态(假设仅包含粒子的位置信息)不足以预测下一状态(可能还需速度、加速度等信息),则可扩展状态维度。具体做法是在初始状态中追加若干零值维度,随着状态演化,这些维度将自动学习到对网络有用的特征表示,该方法通常还能加速训练进程。






06

扩展

  • 神经事件函数与神经跳跃ODE

许多现象是连续的,但具有明显的不连续性。这些现象也被称为混合系统。神经常微分方程(Neural ODEs)在学习这些现象时表现不佳。例如,想象一个弹跳的球。当它在空中时,我们可以轻松地建模它的下一个状态。然而,当它撞击地面时,它的速度方向会发生突变。准确知道它何时撞击地面将对神经常微分方程的预测产生重大影响。在这里的微小的时间误差会导致结果显著偏离

  • 神经事件函数

通过预设状态突变触发条件来解决这一问题,并精准定位事件触发时刻。在弹跳球场景中,可设定"球心接近地面"为触发条件,系统会自动搜索球体与地面的精确碰撞时间点,随后将球体状态更新为反弹上升状态。

论文:https://arxiv.org/abs/2011.03902

  • 神经跳跃ODE

采用类似思路解决该问题,但其特点在于能够自动学习状态转换时机,最初是为更具随机性的过程所设计。这类技术的典型应用场景包括:机器人运动中的地面接触检测、控制系统的突发状态切换(如温控器启停)等。

论文:https://arxiv.org/abs/2006.04727


  • 物理信息神经网络

物理信息神经网络(Physics-informed neural networks)将已知物理定律嵌入损失函数或神经ODE微分函数,驱使神经网络满足守恒定律或力平衡原则。

论文:https://www.sciencedirect.com/science/article/abs/pii/S0021999118307125

简而言之,物理约束可作为归纳偏置引入系统,从而提升模型的学习效能





07

应用案例

尽管神经常微分方程非常酷炫,但它目前仍是一个相对小众的研究领域。例如,在处理具有极不规则时间间隔的时间序列预测问题时,神经常微分方程取得了一定的成功。如今,它们最大的应用可能是在生成式图像模型中的流匹配(Flow Matching),在某些情况下,其表现优于扩散模型。话虽如此,仍有许多人在对其进行研究探索,或许未来情况会有所改变。





点击上方小卡片关注我




添加个人微信,进专属粉丝群!



【声明】内容源于网络
0
0
AI算法之道
一个专注于深度学习、计算机视觉和自动驾驶感知算法的公众号,涵盖视觉CV、神经网络、模式识别等方面,包括相应的硬件和软件配置,以及开源项目等。
内容 573
粉丝 0
AI算法之道 一个专注于深度学习、计算机视觉和自动驾驶感知算法的公众号,涵盖视觉CV、神经网络、模式识别等方面,包括相应的硬件和软件配置,以及开源项目等。
总阅读185
粉丝0
内容573