As accelerated Numpy
首先Jax有类似 numpy
的函数库,API使用基本一致:
和数学中求导一致,Jax可以自动对Python中的标量纯函数进行求导计算,默认对函数传入中的第一个变量求导,还可以指定求导对象jax.grad(func, argnums=(0,1,...))
:
求导函数要求输出为标量(scalar)不能是向量,也就是只能对以下可微函数 f 求导:
f:Rn→R,(∇f)(x)i=∂xi∂f(x),(i=1,2,⋯,n)
不能是 f:Rn→Rm,这样的 ∇f 是 m×n 的Jacobii矩阵。
所以对于机器学习中的损失函数,我们只用想下面这样写就行:
这里直接调用 jax.grad(...)
不会重复编译函数,如果之前编译过相同输入的函数则直接从缓存中读取。
一个例子
带额外参数的 grad
求导
上述的求导要求只能有一个标量输出,如果我们期望在求解过程中将过程量(例如,loss值)返回出来,那么可以使用参数 has_aux
,使用方法如下:
jax.grad(func, has_aux=True)
:要求输出为2-tuples,其中第一个为函数的输出(标量),第二个可以是函数计算中的过程量(Auxiliary data, 任意类型,例如,字典)
与Numpy的区别
Jax特点就是函数式编程(Functional programming),也就是不要在函数中使用带有副作用 (side-effect)的代码,即与当前函数输出无关的任何代码。
在Jax对数组进行修改的方法是 x.at[0].set(new_value)
但这样会生成一个新的数组,原先的并不会进行修改:
由于Jax先编译后运行的特点,如果一个函数中的旧数组被修改后,没有再被使用到,则编译器就会进行原地修改,而非创建一个新的数组,验证如下,使用colab的免费GPU显存大小有11Gb,于是我创建了三个数组,大小合起来正好11Gb(如果前面两次修改创建了新的数组,则会导致显存溢出):
线性回归
尝试用Jax实现简单的线性回归,设数据集大小为 N=106。
- 构建数据:
y=wx+b+ε,ε∼N(0,1)
- 定义模型和损失函数:
y^(x;θ)=y^(x:w,b)=wx+bL(x,y;θ)=∣y−y^(x;θ)∣2
- 更新参数:
θ←θ−α∇θL(θ)
JIT
jaxpr语法转化
Jax的底层和TF是相同的,均使用XLA对数据进行并行计算加速,并且有类似 @tf.function
的图执行功能,在Jax中就是 @jit
(just in time),他将函数中不包含side-effect的部分先转化为 jaxpr
再用XLA编译,从而可以将编译后的函数部署在CPU、GPU或TPU上。
注意:只有在一次调用函数时才会根据传入的参数进行转化。
使用 jax.make_jaxpr()
先转化为显示 jaxpr
代码的函数,然后传入参数,查看转化后的 jaxpr
代码:
输出 jaxpr
可以用于调试代码,函数中side-effect部分的代码虽然不会被编译到XLA中,但是在生成 jaxpr
过程中会执行其一次,所以可以认为所有的side-effect在编译函数的过程中只会被执行一次。
Jax是通过对每个参数用 tracer
类进行包装(跟踪),然后重建生成 jaxpr
代码,所以上述输出中可以看到 x
被 Traced
类包装。
这篇文章讲解了如何理解 jaxpr
:Understanding Jaxprs
jit无法使用的情况
在函数中包含和输入的具体值相关而函数都是无法使用 jit
的,因为 jaxpr
的需要依赖于输入的具体值生成对应的代码,如果输入的具体值有限,则可以将其设为常量
在 jit
中执行任何和输入值相关的条件 if, while
都会报错,只有将条件中的变量设置常量,或者在输入的时候能确定下来,然后就能编译出来,Jax 的默认输入是 ShapedArray
类型,也就是默认其是数组,所以和维度相关的信息是可以作为条件的:
Pytree
Jax将Python中的字典或者递归式构造的数据结构统称为Pytree,每个字典中的 key
或者 list
中的一个索引对应树上的一个分支,例如:
就是一个包含6个叶子节点的树,jax中常用的树上操作有:
jax.tree_map(func, pytree1, pytree2, ...)
:对 pytree
中每个叶子节点作用函数 func
,并且可以对多颗结构相同的 pytree
的对应元素作用 func
函数,func
函数包含多个输入参数即可。
jax.tree_util.tree_leavs(pytree)
:显示 pytree
的所有叶子节点。
jax.tree_map
常用于更新梯度:
Vectorization
在Jax如果要对Batch中每个样本执行某个函数,例如将样本的特征由类别标签转化为one-hot向量,直接执行 for
循环效率太低,Jax提供了一个效率很高且易于使用的构造函数 jax.vmap
(Vector map)解决该问题,在 jax.vmap
外部套上 jax.jit
就可以并行执行向量化操作:
jax.vmap(func, in_axes=0 | Sequence[int], out_axes=0)
:返回一个函数向量化执行函数,函数的输入按照 in_axes
给定的维度进行展开,第 i
个 in_axes
值对应的第 i
个入参的展开维度,如果对应展开维度为 None
,则不进行展开,直接传入;out_axes
表示 func
函数的输出结果按照第 out_axes
维度进行堆叠,默认为 0
。
PRNGKey(pseudo-random number generator key)
在Jax中所有的伪随机数(pseudo random number, PRN)都是基于key的二元组生成的,key的生成方法如下:
所有使用随机数相关的函数均需要消耗一个key,所以为了保证实验的可重复性,每次消耗key前需要对其进行分解(至少分解成俩)我们保留其中一个,另一个用于生成随机数,使用过的key就不用再被使用了,下次再分解就去用新的key:
MNIST数据集训练
Jax + Flax + Optax
模型搭建
Flax主要负责深度网络模型搭建,通过继承父类 nn.Module
实现,具体有两种搭建方式 官方解释 - setup vs compact:
@nn.compact
:类似TF2的函数式构建方法,只需重构 __call__(self, inputs)
,其余只需通过调用函数(nn.Dense
, nn.relu
)即可,这些层都是 nn.Module
的子类
setup
:类似Pytorch的构建方法,需要重构 setup(self)
,并在其中先初始化好模型中带参数的层,例如全链接层,然后在 __call__(self, inputs)
中建立层之间的计算关系
模型初始化及结构显示
在搭建完模型之后通过给定初始化 key
完成参数构建,并且可以通过 clu 中 clu.parameter_overview
优化器
optax
包提供了很多常用优化器(当然基于 jax
这些优化器都可以自己实现,只需要记录下每个权重对应的动量一阶矩和二阶矩还有当前更新的次数,就可以计算 Adam
优化器的结果了),创建一个优化器及其直接更新梯度方法如下:
然而有胡 tx.update()
和 optax.apply_updates()
这两个操作是在给定 grads
和 params
后就可以直接更新,所以 flax.training.train_state
中类 TrainState
就是通过给定参数,直接一步更新梯度:
想要再将 TrainState
中加入其他参数,例如 metrics
那就有点复杂了,可以参考 Flax - quick start,其实 metrics
可以自行通过函数的输出结果自行计算,无需使用 clu.metrics
中的度量器进行更新(较为复杂)。
TrainState
这个类包含的参数有:
- step:模型更新次数。
apply_fn
:一般存储模型的预测函数,例如 model.apply(params, X)
,也可以不存储。
params
:模型的权重,是一种 pytree
。
tx
:模型所用的优化器,是 optax.GradientTransformation
的子类。
opt_state
:优化器的状态,再确定 tx
后会进行创建。
以 Adam
为例,可以通过 opt_state[0].mu['params']
查看一阶矩的参数,同理 opt_state[0].nu['params']
是二阶矩参数:
速度测试
这里比较了Jax和TF的训练速度(使用CPU计算,锐龙R7 4800U),每个epoch,Jax用时2~3s,TF用时5s。
Jax
在MNIST数据集上进行训练的方法如下:
- 首先通过
(X_train, y_train), (X_test, y_test) = keras.datasets.mnist.load_data()
获取数据集。
- 利用
jax.vmap
将标签构建为 one_hot
向量。
- 使用
flax
搭建自定义模型 nn.Module
,并定义 TrainState
类用于纪录参数。
- 定义
train_step
函数,每次将划分好的 batch_idxs
传入(数据集),并在其中定义损失函数 loss_fn
,利用 jax.value_and_grad()
计算损失值及其导数,最后用 state.apply_gradient(grads=grads)
更新状态。
- 实现主函数中的
epoch
循环和 batch
索引通过排列随机生成。
这里我还额外加上了准确率计算函数 accuracy(params, X, y)
用于计算训练集和测试集上模型的准确率,完整代码:
训练结果
TF
实现上明显更简单,但是速度不如Jax。
训练结果
利用tensorboar和wandb可视化训练过程
这里利用 tensorboardX
在 tensorboard
上进行图像绘制,并利用 wandb
的云存储功能记录训练结果,并且有更好的效果图,tensorboardX
和 wandb
配合的使用方法如下:
训练曲线图:wanb - mnist_test__42__20230822_175254,完整代码