谷歌最新推出的JAX,官方定义为CPU、GPU和TPU上的NumPy。它具有出色的自动微分(differentiation)功能,是可用于高性能机器学习研究的python库。Numpy在科学计算领域十分普及,但是在深度学习领域,由于它不支持自动微分和GPU加速,所以更多的是使用Tensorflow或Pytorch这样的深度学习框架。然而谷歌之前推出的Tensorflow API有一些比较混乱的情况,在1.x的迭代中,就存在如原子op、layers等不同层次的API。面对不同类型的用户,使用粒度不同的多层API本身并不是什么问题。但同层次的API也有多种竞品,如slim和layers等实则提高了学习成本和迁移成本。而JAX使用 XLA 在诸如GPU和TPU的加速器上编译和运行NumPy。它与 NumPy API 非常相似, numpy 完成的事情几乎都可以用 jax.numpy 完成,从而避免了直接定义API这件事。
下面简要介绍JAX的几个特性,并同时给出一些示例让读者能够快速入门上手。最后我们将结合科学计算的实例,展现google JAX在科学计算方面的巨大威力。
在深度学习领域,网络参数的优化是通过基于梯度的反向传播算法实现的。因此能够实现任意数值函数的微分对于机器学习有着十分重要的意义。下面结合官方文档的例子简要介绍这一特性。
首先介绍最简单的grad求一阶微分:可以直接通过grad函数求某一函数在某位置的梯度值
import jax.numpy as jnpfrom jax import grad, jit, vmapgrad_tanh = grad(jnp.tanh)print(grad_tanh(2.0))[OUT]:0.070650816
当然如果想对双切正弦函数继续求二阶,三阶导数,也可以这样做:
print(grad(grad(jnp.tanh))(2.0))print(grad(grad(grad(jnp.tanh)))(2.0))[OUT]:-0.136218680.25265405
除此之外,还可以利用hessian、jacfwd 和 jacrev 等方法实现函数转换,它们的功能分别是求解海森矩阵,以及利用前向或反向模式求解雅克比矩阵。Jacfwd和jacrev可以得到一样的结果,但是在不同的情形下求解效率不同,这是因为两者背后对应的微分几何中的push forward和pull back方法。而前面提到的grad则是基于反向模式。
在一些拟牛顿法的优化算法中,常常需要利用二阶的海森矩阵。为了实现海森矩阵的求解。为了实现这一目标,我们可以使用jacfwd(jacrev(f))或者jacrev(jacfwd(f))。但是前者的效率更高,因为内层的雅克比矩阵计算是通过类似于一个1维损失函数对n维向量的求导,明显使用反向模式更为合适。外层则通常是n维函数对n维向量的求导,正向模式更有优势。
无论是科学计算或者机器学习的研究中,我们都会将定义的优化目标函数应用到大量数据中,例如在神经网络中我们去计算每一个批次的损失函数值。JAX 通过 vmap 转换实现自动向量化,简化了这种形式的编程。
-
-
in_axes:输入格式为元组,代表fun中每个输入参数中,使用哪一个维度进行向量化;
-
out_axes: 经过fun计算后,每组输出在哪个维度输出。
import jax.numpy as jnpimport numpy as npimport jax
a = np.array(([1,3],[23, 5]))print(a)[out]: [[ 1 3][23 5]]b = np.array(([11,7],[19,13]))print(b)[OUT]: [[11 7][19 13]]
(2)正常的两个矩阵element-wise的相加
print(jnp.add(a,b))#[[1+11, 3+7]]# [[23+19, 5+13]][OUT]: [[12 10][42 18]]
(3)矩阵a的行 + 矩阵b的行,然后根据out_axes=0输出,0表示行输出
print(jax.vmap(jnp.add, in_axes=(0,0), out_axes=0)(a,b))#[[1+11, 3+7]]#[[23+19, 5+13]][OUT]: [[12 10][42 18]]
(4)矩阵a的行 + 矩阵b的行,然后根据out_axes=1输出,1表示列输出
print(jax.vmap(jnp.add, in_axes=(0,0), out_axes=1)(a,b))# [[1+11, 3+7]]#[[23+19, 5+13]] 再以列转置输出[OUT]: [[12 42][10 18]]
理解了上面的例子之后,现在开始增加难度,换成三维的例子:
from jax.numpy import jnpA, B, C, D = 2, 3, 4, 5def foo(tree_arg):x, (y, z) = tree_argreturn jnp.dot(x, jnp.dot(y, z))from jax import vmapK = 6 # batch sizex = jnp.ones((K, A, B)) # batch axis in different locationsy = jnp.ones((B, K, C))z = jnp.ones((C, D, K))tree = (x, (y, z))vfoo = vmap(foo, in_axes=((0, (1, 2)),))print(vfoo(tree).shape)
让我们一起来分析一下。在这段代码中分别定义了三个全1矩阵x,y,z,他们的维度分别是6*2*3,3*6*4,4*5*6。而tree则控制了foo函数中矩阵连续点积的顺序。根据in_axes可知,y和z的点积最后结果为6个3*5的子矩阵,这是由于y和z此时相当于6个y的子矩阵(3*4维)和6个z的子矩阵(4*5维)点积。再与x点积,得到的最终结果为(6,2,5)。
XLA是TensorFlow底层做JIT编译优化的工具,XLA可以对计算图做算子Fusion,将多个GPU Kernel合并成少量的GPU Kernel,用以减少调用次数,可以大量节省GPU Memory IO时间。Jax本身并没有重新做执行引擎层面的东西,而是直接复用TensorFlow中的XLA Backend进行静态编译,以此实现加速。
jit的基本使用方法非常简单,直接调用jax.jit()或使用@jax.jit装饰函数即可:
import jax.numpy as jnpfrom jax import jitdef slow_f(x):# Element-wise ops see a large benefit from fusionreturn x * x + x * 2.0x = jnp.ones((5000, 5000))fast_f = jax.jit(slow_f) # 静态编译slow_f;%timeit -n10 -r3 fast_f(x)%timeit -n10 -r3 slow_f(x)10 loops, best of 3: 24.2 ms per loop10 loops, best of 3: 82.8 ms per loop
运行时间结果:fast_f(x)是slow_f(x) 在CPU上运行速度的3.5倍!静态编译大大加速了程序的运行速度。如图1 所示。