动手学Transformer

动手实现Transformer,所有代码基于tensorflow2.0,配合illustrated-transformer更香。

模型架构

Encoder+Decoder

Encoder

Decoder

Attention

Add&Norm

FeedForward

Embedding

Position Encoding

模型架构

transformer使用经典的编码器-解码器框架,编码器接受一个输入序列 \((x_1,…,x_n)\),经过Embedding转化为词向量,和位置编码相加作为Encoder的输入,在一顿操作后输入被映射到\(z=(z_1,…,z_n)\),Decoder基于\(z\)在一顿操作后生成输出序列\((y_1,…,y_m)\)

动手学Transformer

看图说话

左边是Encoder,输入为词ID序列,对应形状\([batch\ size,max\ input\ sentense\ length]\),如果embedding维度设置为512,输出形状为\([batch\ size, max\ input\ sentence\ length, 512]\),\(Nx\)表示将Encoder模块堆叠\(N\)次(论文中\(N=6\))

右边是Decoder,训练阶段,Decoder输入包括目标句子的词ID序列和最后一个Encoder部分的输出,测试阶段,Decoder的输入为上一次输出的词。Decoder同样被堆叠\(N\)次,最后一个Encoder的输出被接到每一个Decoder块的输入。Decoder输出下一个词的概率,输出形状为\([batch\ size, max\ output\ sentence\ length, \ vocabulary\ length])\)

我们先盖房子在装修

class Transformer(tf.keras.Model): ''' Transformer架构,Encoder-Decoder;softmax params: num_layers:堆叠层数 dim_model:embedding 维度 num_heads:multihead attention dim_ff:FeedForWard 维度 input_vocab_size:输入词典大小 target_vocab_size:输出词典大小 rate:dropout rate ''' def __init__(self,num_layers, dim_model, num_heads, dim_ff, input_vocab_size,target_vocab_size, rate=0.1): super(Transformer, self).__init__() self.encoder = Encoder(num_layers, dim_model, num_heads,dim_ff, input_vocab_size, rate)#Encoder self.decoder = Decoder(num_layers, dim_model, num_heads,dim_ff, target_vocab_size, rate)#Decoder self.output_layer = tf.keras.layers.Dense(target_vocab_size) def call(self, inputs, targets, training, enc_padding_mask, look_ahead_mask, dec_padding_mask): encoder_output = self.encoder(inputs, training, enc_padding_mask) # (batch_size, inputs_seq_len, d_model) decoder_output, attention_weights = self.decoder(targets, encoder_output, training, look_ahead_mask, dec_padding_mask) output = self.output_layer(decoder_output) # (batch_size, tar_seq_len, target_vocab_size) return output, attention_weights

Encoder

动手学Transformer

Encoder接受输入token的embedding和位置编码,经过N次Encoder layer 堆叠,代码如下所示

class Encoder(tf.keras.layers.Layer): ''' Encoder 部分,input embedding ;Encoder layer stack ''' def __init__(self, num_layers, dim_model, num_heads,dim_ff, input_vocab_size, rate=0.1): super(Encoder, self).__init__() self.dim_model = dim_model self.num_layers = num_layers self.embedding = tf.keras.layers.Embedding(input_vocab_size, self.dim_model)#输入Embedding self.pos_encoding = positional_encoding(input_vocab_size, self.dim_model)#位置编码 self.enc_layers = [EncoderLayer(dim_model, num_heads, dim_ff, rate) for _ in range(num_layers)]#创建Encoder layer self.dropout = tf.keras.layers.Dropout(rate) def call(self, x, training, mask): seq_len = tf.shape(x)[1] # adding embedding and position encoding. x = self.embedding(x) # (batch_size, input_seq_len, dim_model) # x *= tf.math.sqrt(tf.cast(self.dim_model, tf.float32)) x += self.pos_encoding[:, :seq_len, :] x = self.dropout(x, training=training) for layer in self.enc_layers: x = layer(x, training, mask) return x # (batch_size, input_seq_len, d_model)

单个Encoder layer 有两个子层: attention层和point wise feed forward network.

class EncoderLayer(tf.keras.layers.Layer): ''' Encoder layer: multihead attention;add&layer norm;FeedForward;add&layer norm ''' def __init__(self, dim_model, num_heads, dim_ff, rate=0.1): super(EncoderLayer, self).__init__() self.mha = MultiHeadAttention(dim_model, num_heads) self.ffn = point_wise_feed_forward_network(dim_model, dim_ff) self.layernorm1 = tf.keras.layers.LayerNormalization(epsilon=1e-6) self.layernorm2 = tf.keras.layers.LayerNormalization(epsilon=1e-6) self.dropout1 = tf.keras.layers.Dropout(rate) self.dropout2 = tf.keras.layers.Dropout(rate) def call(self, x, training, mask): attn_output, _ = self.mha(x, x, x, mask) # (batch_size, input_seq_len, d_model) attn_output = self.dropout1(attn_output, training=training) out1 = self.layernorm1(x + attn_output) # (batch_size, input_seq_len, d_model) ffn_output = self.ffn(out1) # (batch_size, input_seq_len, d_model) ffn_output = self.dropout2(ffn_output, training=training) out2 = self.layernorm2(out1 + ffn_output) # (batch_size, input_seq_len, d_model) return out2

Decoder

内容版权声明:除非注明,否则皆为本站原创文章。

转载注明出处:https://www.heiqu.com/zyyyfj.html