参考文献:Atteetion Is All You Need[ 1 ] ^{[1]} [ 1 ] , On Layer Normalization in the Transformer Architecture[ 2 ] ^{[2]} [ 2 ] , Improving Language Understanding by Generative Pre-Training[ 3 ] ^{[3]} [ 3 ]
参考Blog:The Illustrated Transformer - Jay Alammar
参考Code:minGPT-pytorch , decision-transformer , tinyGPT-jax
概述
Transformer 由论文[1]提出,这篇文章的核心框架就是 Self-Attention 和 Multi-Head Attention 架构,基于 Multi-Head Attention 本文给出了 Transformer 的 Encoder-Decoder 架构,但是当前流行的 GPT 模型只使用 Decoder 部分。首先我们将分析 Attention 部分,再分析 GPT1 模型的架构,最后在小文本数据集上进行训练,并可以进行简单句子扩写。
模型介绍
前置芝士
首先简单介绍 NLP 任务的前置芝士,Token
表示对文字的 D D D 维编码 ,不同语言 Token
的对象不同:中文一般为一个字,英文可以是一个单词、也可以是一个阿拉伯字母,每个单词或者字母都有其对应的 D D D 维编码。在 NLP 任务中,输入样本的维度一般为 R B × T × D \mathbb{R}^{B\times T\times D} R B × T × D ,其中 B B B 表示 Batch Size 大小,T T T 表示 Token
数量,D D D 表示对每个 Token
进行编码后的维度。(不失一般性,我们下面讨论的时候都省略第一个维度 B B B )
对字母进行编码 (Embedding) 成 D D D 维 Token
就是一个 hash
过程,假设我们的字符集大小为 N N N ,通过创建一个 R N × D \mathbb{R}^{N\times D} R N × D 矩阵 W = { w 1 , w 2 , ⋯ , w N } T W = \{w_1,w_2,\cdots,w_N\}^T W = { w 1 , w 2 , ⋯ , w N } T ,则字符集中第 i i i 个字符的对应 Token
就是 W i W_i W i 。在机器学习中,该编码矩阵 W W W 可以在梯度下降中自动更新,无需自己手动设定。
Self-Attention 机制
Self-Attention 是一种自监督机制(此处也成为交叉注意力机制 Cross-Attention ),本质上就是一种基于协方差矩阵对另一个向量进行加权平均 的结果。定义如下:(默认声明 w i w_i w i 表示矩阵 W W W 的第一个维度中的第 i i i 个元素)
设 Q , K ∈ R T × d k , V ∈ R T × d v Q, K \in \mathbb{R}^{T\times d_k}, V\in\mathbb{R}^{T\times d_v} Q , K ∈ R T × d k , V ∈ R T × d v 分别表示第 i i i 个 Token
的询问值 q i q_i q i (Query),键 k i k_i k i (Key) 和值 v i v_i v i (Value),其中我们用 ⟨ q i , k j ⟩ \langle q_i, k_j\rangle ⟨ q i , k j ⟩ 衡量第 i i i 个 Query 与第 j j j 个 Key 的相关性大小,我们想要求出每个 Query 对所有的 Key 计算平均后,对 Value 进行加权求和得到的结果,该过程可以表示如下:
z i = ∑ j = 1 T softmax ( { ⟨ q i , k l ⟩ } l = 1 T ) j ⋅ v j z_i = \sum_{j=1}^{T}\text{softmax}(\{\langle q_i, k_l\rangle\}_{l=1}^{T})_j\cdot v_j
z i = j = 1 ∑ T softmax ( { ⟨ q i , k l ⟩ } l = 1 T ) j ⋅ v j
写成矩阵形式如下:
Z = softmax ( Q K T d k ) V , 其中softmax作用在最后一个维度上 Z = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V,\qquad \text{其中softmax作用在最后一个维度上}
Z = softmax ( d k Q K T ) V , 其中 softmax 作用在最后一个维度上
在这里 1 / d k 1/\sqrt{d_k} 1 / d k 的原因:∀ i ∈ { 1 , ⋯ , T } \forall i\in \{1,\cdots, T\} ∀ i ∈ { 1 , ⋯ , T } 有 q i j , k i j ∼ N ( 0 , σ 2 ) , ( j = 1 , ⋯ , d k ) q_{ij},k_{ij}\sim \mathcal{N}(0,\sigma^2), (j = 1,\cdots, d_k) q i j , k i j ∼ N ( 0 , σ 2 ) , ( j = 1 , ⋯ , d k ) ,则 ∑ j = 1 d k q i j k i j ∼ N ( 0 , d k σ 4 ) \sum_{j=1}^{d_k}q_{ij}k_{ij}\sim \mathcal{N}(0,d_k\sigma^4) ∑ j = 1 d k q i j k i j ∼ N ( 0 , d k σ 4 ) ,由于 σ ≈ 1 \sigma\approx 1 σ ≈ 1 ,所以系数 1 / d k 1/\sqrt{d_k} 1 / d k 可以保持输出的方差在 1 1 1 左右,避免发散。
注意力矩阵的修正
Causal Self-Attention(因果自注意力机制) :这里 Σ : = Q K T \Sigma:=QK^T Σ : = Q K T 就是协方差矩阵(注意力矩阵),可以被用来衡量两组随机变量的相关性,其中 Σ i j \Sigma_{ij} Σ i j 表示第 i i i 个 Query 和第 j j j 个 Key 之间的相关性大小,如果我们只期望每个 token
只考虑它及它前面的相关性 ,我们只需要令 Σ i j = 0 , ( i > j ) \Sigma_{ij}=0, (i>j) Σ i j = 0 , ( i > j ) (只保留下三角部分,其余部分用 0 0 0 替代)。通过这样变换后的协方差矩阵作用在 V V V 上得到的结果就是因果注意力机制的输出。
文本续写 :在文本续写时候,我们输入的文本 Token
数量通常会小于最大 Token
数量 T T T ,我们就需要用 0 0 0 对输入进行填充,即 x = { x 1 , ⋯ , x t , 0 , ⋯ , 0 } T x = \{x_1,\cdots, x_t, 0,\cdots, 0\}^T x = { x 1 , ⋯ , x t , 0 , ⋯ , 0 } T ,那么我们在计算协方差矩阵时候就不要对填充部分计算相关性,即令 Σ i j = 0 , ( i > t , j > t ) \Sigma_{ij} = 0, (i>t, j>t) Σ i j = 0 , ( i > t , j > t )
那么什么是自注意力中的自 从哪来?我们还没有介绍 Q , K , V Q,K,V Q , K , V 如何获得,假设我们输入的样本维度为 x ∈ R T × D x\in\mathbb{R}^{T\times D} x ∈ R T × D ,有意思的是自监督中的三个值就是通过三个全连接 W Q , W K ∈ R T × d k , W V ∈ R T × d v W^{Q},W^{K}\in\mathbb{R}^{T\times d_k}, W^{V}\in\mathbb{R}^{T\times d_v} W Q , W K ∈ R T × d k , W V ∈ R T × d v 分别输出得到的,即 Q = x W Q , K = x W K , V = x W V Q = xW^{Q}, K = xW^{K}, V = xW^{V} Q = x W Q , K = x W K , V = x W V 。综上,自注意力机制表示如下:
Attn ( x ; W Q , W K , W V ) = softmax ( x W Q ( W K ) T x T d k ) ( x W V ) \text{Attn}(x; W^{Q},W^{K},W^{V}) = \text{softmax}\left(\frac{xW^Q(W^K)^Tx^T}{\sqrt{d_k}}\right)(xW^V)
Attn ( x ; W Q , W K , W V ) = softmax ( d k x W Q ( W K ) T x T ) ( x W V )
所以说模型到底是怎么理解 Q , K , V Q,K,V Q , K , V 的真实含义其实我们无法知道,Query,Key,Value 这只是我们赋予其的概念。
排列等变性
设变换 T : X ∈ R N → Y ∈ R N T: X\in\mathbb{R}^{N}\to Y\in\mathbb{R}^{N} T : X ∈ R N → Y ∈ R N ,∀ x ∈ X \forall x\in X ∀ x ∈ X ,对于任意的排列变换 p : I n → I n p: \mathcal{I}_{n}\to\mathcal{I}_{n} p : I n → I n ,I n \mathcal{I}_n I n 表示大小为 n n n 的指标集(排列变换满足 p ( x ) = p ( x 1 , x 2 , ⋯ , x N ) = ( x i 1 , x i 2 , ⋯ , x i N ) , ( i j ≠ i k , i j , i k ∈ { 1 , ⋯ , N } ) p(x) = p(x_1,x_2,\cdots,x_N) = (x_{i_1},x_{i_2},\cdots,x_{i_N}), (i_j\neq i_k, i_j,i_k\in\{1,\cdots, N\}) p ( x ) = p ( x 1 , x 2 , ⋯ , x N ) = ( x i 1 , x i 2 , ⋯ , x i N ) , ( i j = i k , i j , i k ∈ { 1 , ⋯ , N } ) ),若有 T ( p ( x ) ) = p ( T ( x ) ) T(p(x)) = p(T(x)) T ( p ( x ) ) = p ( T ( x ) ) 则称变换 T T T 具有排列等变性 。(简单来说就是把输入 x x x 的下标重新排列下,再经过 T T T 变换后的结果 T ( p ( x ) ) T(p(x)) T ( p ( x ) ) ,和直接把 x x x 经过 T T T 变换后再进行相同重新排列 p ( T ( x ) ) p(T(x)) p ( T ( x ) ) 结果一致)
下面我们证明 Attn ( p ( x ) ) = p ( Attn ( x ) ) \text{Attn}(p(x)) = p(\text{Attn}(x)) Attn ( p ( x ) ) = p ( Attn ( x ) ) ,设 X ∈ R L × d = ( x 1 , ⋯ , x L ) T X\in \mathbb{R}^{L\times d} = (x_1,\cdots,x_L)^T X ∈ R L × d = ( x 1 , ⋯ , x L ) T ,忽略掉 softmax
和 1 / d k 1/\sqrt{d_k} 1 / d k 系数,我们可以得到自注意力变换为 X W Q ( W K ) T X T X W V XW^Q(W^K)^TX^TXW^V X W Q ( W K ) T X T X W V ,由于最后的 W V W^V W V 是对第二维度进行变换的矩阵,满足对第一维度的排列等变性,简化为 f ( X ) = X A X T X B f(X) = XAX^TXB f ( X ) = X A X T X B ,只需证 f f f 具有排列等变性。由于
f ( X ) = X A X T X B = [ x 1 T ⋮ x L T ] A [ x 1 ⋯ x T ] [ x 1 T ⋮ x L T ] = [ x 1 T ⋮ x L T ] A ∑ i = 1 L x i x i T f(X) = XAX^TXB = \begin{bmatrix}x_1^T\\\vdots\\x_L^T\end{bmatrix}A\begin{bmatrix}x_1&\cdots&x_T\end{bmatrix}\begin{bmatrix}x_1^T\\\vdots\\x_L^T\end{bmatrix}
= \begin{bmatrix}
x_1^T\\\vdots\\x_L^T
\end{bmatrix}A\sum_{i=1}^Lx_ix_i^T
f ( X ) = X A X T X B = ⎣ ⎢ ⎢ ⎡ x 1 T ⋮ x L T ⎦ ⎥ ⎥ ⎤ A [ x 1 ⋯ x T ] ⎣ ⎢ ⎢ ⎡ x 1 T ⋮ x L T ⎦ ⎥ ⎥ ⎤ = ⎣ ⎢ ⎢ ⎡ x 1 T ⋮ x L T ⎦ ⎥ ⎥ ⎤ A i = 1 ∑ L x i x i T
注意到 A ∑ i = 1 L x i x i T A\sum_{i=1}^Lx_ix_i^T A ∑ i = 1 L x i x i T 是与排列变换无关的常量,当我们将输入 X X X 中的 i , j i,j i , j 行交换后,输出的 i , j i, j i , j 行也会相应进行交换,所以变换 f f f 具有排列等变性。 QED
形象理解:由于自注意力是对 Query 进行的查询,当 Query 位置发生变换时,自注意力输出的结果也会相应发生变换。上式中的体现:结果中最左边的 x x x 唯一确定输出的行排列顺序,而这个 x x x 也同时确定 Query 的位置。
位置信息嵌入 (Position Embedding)
正是由于自注意力机制关于排列具有不变性,也即每个 Token
的位置信息无法被模型获取 ,所以我们传入样本 x x x 时候,需要嵌入位置信息 PE ∈ R T × D \text{PE}\in\mathbb{R}^{T\times D} PE ∈ R T × D ,其中 PE i \text{PE}_{i} PE i 表示处于当前输入中第 i i i 个 Token
的位置信息编码,我们可以将其直接加到传入样本上:x ← x + PE x\gets x + \text{PE} x ← x + PE ,从而引入位置信息。
在论文[1]中是一个 PE \text{PE} PE 由 s i n , c o s sin, cos s i n , c o s 交错形成的固定矩阵(位置和频率相关),而更通用的做法则是将 PE \text{PE} PE 作为可学习参数,让模型自己学习得到(初始化为全零)。
Multihead-Attention 模型
多头注意力模型(Multihead-Attention)就是将 Self-Attention 进行堆叠得到的,我们将上面的自注意力过程简记为 Attn ( x ) \text{Attn}(x) Attn ( x ) ,设自注意力头数目为 h ∈ N + h\in \mathbb{N}_{+} h ∈ N + ,则 Multi-Attention 过程表示如下:
Multi ( x ) = [ Concate i = 1 h ( Attn i ( x ) ) ] W 0 ∈ R T × d e \text{Multi}(x) = \left[\overset{h}{\underset{i=1}\text{Concate}}(\text{Attn}_i(x))\right]W^0\in \mathbb{R}^{T\times d_e}
Multi ( x ) = [ i = 1 Concate h ( Attn i ( x ) ) ] W 0 ∈ R T × d e
其中 Concate ( W 1 , W 2 , ⋯ , W n ) \text{Concate}(W_1,W_2,\cdots,W_n) Concate ( W 1 , W 2 , ⋯ , W n ) 表示将矩阵 W i W_i W i 按照最后一个维度进行连接,W 0 ∈ R ( h d v ) × d e W^0\in \mathbb{R}^{(hd_v)\times d_e} W 0 ∈ R ( h d v ) × d e ,其中 d e d_e d e 表示 Multi-Attention 最终输出的每个 Token
的编码 (Embed) 维度。
代码实现
Multihead-Attention 的代码实现上需要注意一下细节:
一般有 h h h 能够整除 d e d_e d e ,即 h ∣ d e h | d_e h ∣ d e 且 d v = d k = d e / h d_v = d_k = d_e / h d v = d k = d e / h ,所以我们无需定义 d v , d k d_v,d_k d v , d k 。
可以将计算 Q , V , K Q,V,K Q , V , K 的三个神经网络合为一个大网络,输出维度为 3 d e 3d_e 3 d e ,将输出的结果先按照 h h h 头数目进行划分,再将最后特征维度平均划分为三个维度为 d e / h d_e / h d e / h 的 Q , K , V Q,K,V Q , K , V 。
在计算完自注意力矩阵后需要通过一个 Dropout
,避免过拟合。
class CausalSelfAttention ( nn. Module) :
n_embd: int
n_head: int
p_drop_attn: float
@nn. compact
def __call__ ( self, x: jnp. ndarray, train: bool , mask_len: int = None ) :
D = self. n_embd // self. n_head
B, L, _ = x. shape
mask = jnp. tri( L)
if mask_len is not None :
mask = jnp. where( jnp. arange( L) . reshape( L, 1 ) >= mask_len, 0 , mask)
x = nn. Dense( 3 * self. n_embd) ( x)
q, k, v = jnp. array_split( x. reshape( B, L, self. n_head, - 1 ) . transpose( 0 , 2 , 1 , 3 ) , 3 , - 1 )
attn = q @ jnp. swapaxes( k, - 1 , - 2 ) / jnp. sqrt( D)
attn = jnp. where( mask == 0 , - 1e18 , attn)
attn = jax. nn. softmax( attn)
attn = nn. Dropout( self. p_drop_attn) ( attn, deterministic= not train)
y = ( attn @ v) . transpose( 0 , 2 , 1 , 3 ) . reshape( B, L, self. n_embd)
y = nn. Dense( self. n_embd) ( y)
return y
Transformer Block (Transformer Layer)就是一种带有残差连接和 Layer Normalization 的 Attention 架构,在论文 中介绍了两种 Layer Norm 的放置位置,如下图所示:
左图为论文[1]中所提出的原始 Transformer Layer 结构,右图为该论文给出将 Layer Norm 前置的结构,通过实验验证了前置的效果优于后置的。这里的残差连接和 ResNet 理念一致,因为一个 Transformer 模型需要非常多的 Transformer Block (Transformer Layer),残差连接就是为了深度提升的同时保持原始特征不丢失。
后置 Layer Norm 的 JAX 代码实现如下:
class GPTConfig ( MainCLS) :
n_embd = 768
n_head = 12
n_block = 12
p_drop_embd = 0.1
p_drop_resid = 0.1
p_drop_attn = 0.1
def __init__ ( self, n_vocab, n_token, ** kwargs) :
self. n_vocab = n_vocab
self. n_token = n_token
for k, v in kwargs. items( ) :
setattr ( self, k, v)
assert self. n_embd % self. n_head == 0 , "n_embd must be devided by n_head"
class TransformerBlock ( nn. Module) :
cfg: GPTConfig
@nn. compact
def __call__ ( self, x: jnp. ndarray, train: bool , mask_len: int = None ) :
attn_cfg = { key: getattr ( self. cfg, key) for key in [ 'n_embd' , 'n_head' , 'p_drop_attn' ] }
z = nn. LayerNorm( ) ( x)
z = CausalSelfAttention( ** attn_cfg) ( z, train, mask_len)
x = x + nn. Dropout( self. cfg. p_drop_resid) ( z, deterministic= not train)
z = nn. Sequential( [
nn. LayerNorm( ) ,
nn. Dense( 4 * self. cfg. n_embd) , nn. selu,
nn. Dense( self. cfg. n_embd) ,
] ) ( x)
x = x + nn. Dropout( self. cfg. p_drop_resid) ( z, deterministic= not train)
return x
GPT模型
在论文[1]中最开始提出的是一种Encoder-Decoder的形式,而论文[3]中给出的GPT-1模型则是只用Decoder进行编码,方法非常简单,只需将 Transformer Block 进行堆叠,最后连接一个全连接网络以及Softmax给出每个 Token
的下一个 Token
预测。例如训练集中有“我很好。”这句话,那么输入 x x x 可以为“我很好”的编码,维度为 R 3 × d e \mathbb{R}^{3\times d_e} R 3 × d e ,模型的预测目标为“很好。”,然而模型需要给出对每个 Token
的下一个 Token
的概率分布预测,也就是从整个词库中选出下一个词,即输出维度为 R 3 × n v \mathbb{R}^{3\times n_v} R 3 × n v ,其中 n v n_v n v 表示整个词库大小。
从代码上理解非常容易:
class GPT ( nn. Module) :
cfg: GPTConfig
@nn. compact
def __call__ ( self, x: jnp. ndarray, train: bool , mask_len: int = None ) :
cfg = self. cfg
pos_embd = self. param( 'pos_embd' , lambda _, shape: jnp. zeros( shape) , ( 1 , cfg. n_token, cfg. n_embd) )
x = pos_embd + nn. Embed( cfg. n_vocab, cfg. n_embd) ( x)
x = nn. Dropout( cfg. p_drop_embd) ( x, deterministic= not train)
for _ in range ( cfg. n_block) :
x = TransformerBlock( cfg) ( x, train, mask_len)
x = nn. LayerNorm( ) ( x)
x = nn. Dense( cfg. n_vocab) ( x)
return x
模型训练及预测代码可以见下文(用JAX实现,PyTorch实现类似)。
模型训练
def model_step ( state: TrainState, x: jnp. ndarray, y: jnp. ndarray, train: bool ) :
dropout_rng, base_rng = jax. random. split( state. dropout_rng)
def loss_fn ( params) :
logits = state. apply_fn( { 'params' : params} , x, train= train, rngs= { 'dropout' : dropout_rng} )
tmp = - jax. nn. log_softmax( logits) . reshape( - 1 , logits. shape[ - 1 ] )
loss = tmp[ jnp. arange( tmp. shape[ 0 ] ) , y. reshape( - 1 ) ] . mean( )
acc = ( jnp. argmax( logits, - 1 ) . reshape( - 1 ) == y. reshape( - 1 ) ) . mean( )
return loss, acc
( loss, acc) , grads = jax. value_and_grad( loss_fn, has_aux= True ) ( state. params)
state = state. apply_gradients( grads= grads)
state = state. replace( dropout_rng= base_rng)
return state, ( loss, acc)
模型预测
def predict ( state: TrainState, x: jnp. ndarray, rng: jax. Array, mask_len: int = None ) :
logits = state. apply_fn( { 'params' : state. params} , x, train= False , mask_len= mask_len)
if mask_len is not None :
logits = logits[ jnp. arange( logits. shape[ 0 ] ) , mask_len- 1 , : ]
pred = jax. random. categorical( rng, logits, - 1 )
return pred
简单文本训练
数据集搭建
训练数据集来自 tinyGPT-jax ,文本包含四大名著以及几篇莎士比亚的文章,这里我采用 torch.utils.data.Dataset
做数据读入,值得注意的是数据集划分方式(文本数据集不像图像数据集直接随机采样就行,文本还要求每个样本的字符保持连续性),由于数据量非常小,所以我们可以将全部文本字符串读入到内存中,存储到字符串 text
中,简单起见,我们将每个 token
定义为一个中文字符和一个阿拉伯字母(而非一个英文单词)。
设整个文本数据集大小为 N N N ,训练集与验证集大小占比为 ( 1 − r ) : r (1-r):r ( 1 − r ) : r ,我们首先将其平均划分为 n d n_d n d 块,每块大小为 n c = ⌊ N n d ⌋ n_c = \left\lfloor\frac{N}{n_d}\right\rfloor n c = ⌊ n d N ⌋ ,然后再将每个块中的前 ( 1 − r ) n c (1-r)n_c ( 1 − r ) n c 个字符划分给训练集,后面剩余的字符划分给验证集。
text = . . .
data = self. encode( text)
block_size = len ( data) // n_divide
train_block_size = int ( block_size * ( 1 - val_ratio) )
self. data = {
'train' : np. concatenate( [ data[ i: i+ train_block_size] for i in range ( 0 , len ( data) , block_size) ] ) ,
'val' : np. concatenate( [ data[ i+ train_block_size: i+ block_size] for i in range ( 0 , len ( data) , block_size) ] )
}
我们的数据集大小为 8650026 8650026 8 6 5 0 0 2 6 ,词库大小为 5840 5840 5 8 4 0 ,划分的训练集大小为 6920026 6920026 6 9 2 0 0 2 6 ,验证集大小为 1730000 1730000 1 7 3 0 0 0 0 ,如果把每个字符开始的一段句子都作为样本进行训练,大约要 ( 6920026 − 128 ) / 64 × 0.37 ≈ 11 hours (6920026-128)/64 \times 0.37 \approx 11 \text{hours} ( 6 9 2 0 0 2 6 − 1 2 8 ) / 6 4 × 0 . 3 7 ≈ 1 1 hours (在 RTX4060-Laptop 上训练一个 batch_size=64, n_token=128
的时间为 0.37s
)。
所以为了简短时间,我将每个训练集大小设置为 512 × 128 512\times 128 5 1 2 × 1 2 8 个 Token
也就是训练 512 × 0.37 ≈ 3 mins 512\times 0.37\approx 3 \text{mins} 5 1 2 × 0 . 3 7 ≈ 3 mins ,验证集大小设置为 32 × 128 32\times 128 3 2 × 1 2 8 个 Token
也就是训练 10 seconds 10 \text{seconds} 1 0 seconds 左右,而 Token
的开始位置就是从对应的数据集中随机采样获取。具体实现如下:
class TextDataset ( Dataset) :
def __init__ ( self, data, n_token, datasize) :
self. data, self. n_token, self. datasize = data, n_token, datasize
def __len__ ( self) :
return self. datasize
def __getitem__ ( self, idx) :
idx = random. randint( 0 , len ( self. data) - 2 - self. n_token)
d = self. data[ idx: idx+ self. n_token+ 1 ]
x, y = d[ : self. n_token] , d[ 1 : ]
return x, y
代码框架
源代码 中包含5个代码文件:
ckpt_manager.py
dataset.py
miniGPT.py
predict.py
train.py
使用方法:直接运行 train.py
等待训练完成(RTX 4060 Laptop)WandB训练结果 。
predict.py
执行效果:
具体文本: