# attention.py
"""Attention weights for the seq2seq chatbot decoder.

Implements three score functions between the current decoder hidden state
H_t and every encoder output H_s:

* ``dot``     : H_t^T * H_s
* ``general`` : H_t^T * W * H_s
* ``concat``  : V * tanh(W * [H_t, H_s])
"""
import torch
import torch.nn as nn
import torch.nn.functional as F

import config


class Attention(nn.Module):
    """Compute softmax-normalised attention weights over the encoder outputs.

    Shape conventions used below (taken from the original comments; the
    encoder itself is not visible in this file, so treat them as assumptions
    to confirm):

    * ``decoder_hidden``  : [1, batch_size, chatbot_decoder_hidden_size]
    * ``encoder_outputs`` : [batch_size, encoder_max_len,
      chatbot_encoder_hidden_size * 2]
    """

    def __init__(self, method="general"):
        """Create the projection layers required by ``method``.

        :param method: one of ``"dot"``, ``"general"`` or ``"concat"``.
        :raises AssertionError: if ``method`` is not a supported name.
        """
        super(Attention, self).__init__()
        assert method in ["dot", "general", "concat"], "attention method error"
        self.method = method

        decoder_dim = config.chatbot_decoder_hidden_size
        # The last dimension of encoder_outputs is assumed to be
        # encoder_hidden_size * 2, as the original layer sizes imply.
        encoder_dim = config.chatbot_encoder_hidden_size * 2

        if method == "general":
            # W maps the decoder state into the encoder-output space, so its
            # input width is the decoder size. This is identical to the
            # previous (encoder_dim, encoder_dim) layer whenever both sizes
            # are equal, which is the only case in which that layer worked.
            self.W = nn.Linear(decoder_dim, encoder_dim, bias=False)
        if method == "concat":
            # W consumes [H_t, H_s], whose width is decoder_dim + encoder_dim.
            # The previous hard-coded decoder_dim * 4 only matched that width
            # when encoder_dim == 3 * decoder_dim.
            self.W = nn.Linear(decoder_dim + encoder_dim, decoder_dim * 2, bias=False)
            self.V = nn.Linear(decoder_dim * 2, 1, bias=False)

    def forward(self, decoder_hidden, encoder_outputs):
        """Dispatch to the score function selected at construction time.

        :param decoder_hidden: current decoder hidden state.
        :param encoder_outputs: all encoder outputs.
        :return: attention weights, [batch_size, encoder_max_len].
        :raises ValueError: if ``self.method`` is not a supported name
            (previously this case silently returned ``None``).
        """
        if self.method == "dot":
            return self.dot_score(decoder_hidden, encoder_outputs)
        if self.method == "general":
            return self.general_socre(decoder_hidden, encoder_outputs)
        if self.method == "concat":
            return self.concat_socre(decoder_hidden, encoder_outputs)
        raise ValueError("attention method error")

    def dot_score(self, decoder_hidden, encoder_outputs):
        """Score with H_t^T * H_s.

        :param decoder_hidden: [1, batch_size, hidden]
        :param encoder_outputs: [batch_size, encoder_max_len, hidden]
        :return: attention weights, [batch_size, encoder_max_len]
        """
        query = decoder_hidden.squeeze(0).unsqueeze(-1)  # [batch_size, hidden, 1]
        scores = torch.bmm(encoder_outputs, query).squeeze(-1)  # [batch_size, encoder_max_len]
        return F.softmax(scores, dim=-1)

    def general_socre(self, decoder_hidden, encoder_outputs):
        """Score with H_t^T * W * H_s.

        :param decoder_hidden: [1, batch_size, decoder_hidden_size]
        :param encoder_outputs: [batch_size, encoder_max_len, encoder_dim]
        :return: attention weights, [batch_size, encoder_max_len]
        """
        # Project the decoder state first: [batch_size, encoder_dim, 1].
        query = self.W(decoder_hidden.squeeze(0)).unsqueeze(-1)
        scores = torch.bmm(encoder_outputs, query).squeeze(-1)
        return F.softmax(scores, dim=-1)

    def concat_socre(self, decoder_hidden, encoder_outputs):
        """Score with V * tanh(W * [H_t, H_s]).

        :param decoder_hidden: [1, batch_size, decoder_hidden_size]
        :param encoder_outputs: [batch_size, encoder_max_len, encoder_dim]
        :return: attention weights, [batch_size, encoder_max_len]
        """
        encoder_max_len = encoder_outputs.size(1)
        # Broadcast the single decoder state along the source-length axis:
        # [batch_size, decoder_hidden_size] -> [batch_size, encoder_max_len, decoder_hidden_size].
        query = decoder_hidden.squeeze(0).unsqueeze(1).expand(-1, encoder_max_len, -1)
        h_cated = torch.cat([query, encoder_outputs], dim=-1)
        # nn.Linear acts on the last dimension, so no flattening is needed.
        # torch.tanh replaces the deprecated F.tanh.
        scores = self.V(torch.tanh(self.W(h_cated))).squeeze(-1)  # [batch_size, encoder_max_len]
        return F.softmax(scores, dim=-1)

    # Correctly spelled aliases; the original misspelled names stay available
    # for existing callers.
    general_score = general_socre
    concat_score = concat_socre
# decoder.py
"""GRU decoder with attention for the seq2seq chatbot."""
import random

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

import config
from chatbot.attention import Attention


class Decoder(nn.Module):
    """Decode a target sequence step by step, attending over encoder outputs."""

    def __init__(self):
        """Build the embedding, GRU, attention and output projection layers.

        All sizes are read from ``config``.
        """
        super(Decoder, self).__init__()
        self.embedding = nn.Embedding(num_embeddings=len(config.target_ws),
                                      embedding_dim=config.chatbot_decoder_embedding_dim,
                                      padding_idx=config.target_ws.PAD)

        # The GRU is fed one token per call (see forward_step), batch first.
        self.gru = nn.GRU(input_size=config.chatbot_decoder_embedding_dim,
                          hidden_size=config.chatbot_decoder_hidden_size,
                          num_layers=config.chatbot_decoder_number_layer,
                          bidirectional=False,
                          batch_first=True,
                          dropout=config.chatbot_decoder_dropout)

        # Projects the attention result onto the target vocabulary.
        self.fc = nn.Linear(config.chatbot_decoder_hidden_size, len(config.target_ws))
        self.attn = Attention(method="general")
        # Fuses [context_vector, gru_output] back to decoder_hidden_size.
        self.fc_attn = nn.Linear(config.chatbot_decoder_hidden_size * 2,
                                 config.chatbot_decoder_hidden_size, bias=False)

    def _initial_step_tensors(self, batch_size):
        """Return the SOS input column and a zeroed output buffer on config.device."""
        decoder_input = torch.full((batch_size, 1), config.target_ws.SOS,
                                   dtype=torch.long, device=config.device)  # [batch_size, 1]
        decoder_outputs = torch.zeros(
            [batch_size, config.chatbot_target_max_len, len(config.target_ws)],
            device=config.device)  # [batch_size, max_len, vocab_size]
        return decoder_input, decoder_outputs

    def forward(self, encoder_hidden, target, encoder_outputs, teacher_forcing_ratio=0.5):
        """Decode ``config.chatbot_target_max_len`` steps for training.

        One random draw per call decides whether the whole sequence is decoded
        with teacher forcing (ground-truth token as next input) or free-running
        (the model's own prediction as next input).

        :param encoder_hidden: initial decoder hidden state; dimension 1 is
            used as the batch size.
        :param target: ground-truth token ids, indexed as ``target[:, t]``.
        :param encoder_outputs: encoder outputs passed on to the attention.
        :param teacher_forcing_ratio: probability of using teacher forcing;
            the default 0.5 reproduces the previously hard-coded behaviour.
        :return: ``(decoder_outputs, decoder_hidden)`` where decoder_outputs is
            [batch_size, max_len, vocab_size] log-probabilities.
        """
        decoder_hidden = encoder_hidden
        batch_size = encoder_hidden.size(1)
        decoder_input, decoder_outputs = self._initial_step_tensors(batch_size)

        # NOTE: the original comment labelled this branch "teacher forcing",
        # but it is the opposite: it feeds back the model's own prediction.
        free_running = random.random() > teacher_forcing_ratio

        for t in range(config.chatbot_target_max_len):
            decoder_output_t, decoder_hidden = self.forward_step(
                decoder_input, decoder_hidden, encoder_outputs)
            decoder_outputs[:, t, :] = decoder_output_t
            if free_running:
                # Use the current prediction as the next input.
                _, index = decoder_output_t.max(dim=-1)
                decoder_input = index.unsqueeze(-1)  # [batch_size, 1]
            else:
                # Teacher forcing: use the ground-truth token as the next input.
                decoder_input = target[:, t].unsqueeze(-1)  # [batch_size, 1]
        return decoder_outputs, decoder_hidden

    def forward_step(self, decoder_input, decoder_hidden, encoder_outputs):
        """Compute the result of a single decoding time step.

        :param decoder_input: [batch_size, 1] token ids.
        :param decoder_hidden: GRU hidden state from the previous step.
        :param encoder_outputs: encoder outputs attended over.
        :return: ``(out_fc, decoder_hidden)`` where out_fc is
            [batch_size, vocab_size] log-probabilities.
        """
        decoder_input_embeded = self.embedding(decoder_input)
        # out: [batch_size, 1, decoder_hidden_size]
        out, decoder_hidden = self.gru(decoder_input_embeded, decoder_hidden)

        # 1. Attention weights over the source positions: [batch_size, encoder_max_len].
        attn_weight = self.attn(decoder_hidden, encoder_outputs)
        # 2. Context vector = weighted sum of the encoder outputs.
        context_vector = torch.bmm(attn_weight.unsqueeze(1), encoder_outputs).squeeze(1)
        # 3. Fuse the context with the GRU output and squash it to
        #    [batch_size, decoder_hidden_size].
        attention_result = torch.tanh(
            self.fc_attn(torch.cat([context_vector, out.squeeze(1)], dim=-1)))

        out_fc = F.log_softmax(self.fc(attention_result), dim=-1)  # [batch_size, vocab_size]
        return out_fc, decoder_hidden

    def evaluate(self, encoder_hidden, encoder_outputs):
        """Greedy decoding without ground truth.

        :param encoder_hidden: initial decoder hidden state; dimension 1 is
            used as the batch size.
        :param encoder_outputs: encoder outputs passed on to the attention.
        :return: ``(decoder_outputs, predict_result)``; predict_result is a
            numpy array in which each row is one predicted sequence.
        """
        decoder_hidden = encoder_hidden
        batch_size = encoder_hidden.size(1)
        decoder_input, decoder_outputs = self._initial_step_tensors(batch_size)

        predict_result = []
        for t in range(config.chatbot_target_max_len):
            decoder_output_t, decoder_hidden = self.forward_step(
                decoder_input, decoder_hidden, encoder_outputs)
            decoder_outputs[:, t, :] = decoder_output_t
            # Greedy choice for this step, fed back as the next input.
            _, index = decoder_output_t.max(dim=-1)
            predict_result.append(index.cpu().detach().numpy())  # one [batch_size] array per step
            decoder_input = index.unsqueeze(-1)  # [batch_size, 1]

        # [max_len, batch_size] -> [batch_size, max_len]: one row per sequence.
        predict_result = np.array(predict_result).transpose()
        return decoder_outputs, predict_result
# seq2seq.py
"""Seq2seq model: wires the encoder to the attention decoder."""
import torch.nn as nn

from chatbot.encoder import Encoder
from chatbot.decoder import Decoder


class Seq2Seq(nn.Module):
    """Encoder-decoder wrapper exposing a training and an inference entry point."""

    def __init__(self):
        """Instantiate the encoder and the decoder with their default settings."""
        super(Seq2Seq, self).__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()

    def forward(self, input, input_len, target):
        """Run the training pass.

        :param input: source batch, forwarded unchanged to the encoder.
        :param input_len: source lengths, forwarded unchanged to the encoder.
        :param target: ground-truth target batch, forwarded to the decoder.
        :return: the decoder outputs (the final decoder hidden state is
            discarded).
        """
        encoder_outputs, encoder_hidden = self.encoder(input, input_len)
        decoder_outputs, _ = self.decoder(encoder_hidden, target, encoder_outputs)
        return decoder_outputs

    def evaluate(self, input, input_len):
        """Run inference without ground truth.

        :param input: source batch, forwarded unchanged to the encoder.
        :param input_len: source lengths, forwarded unchanged to the encoder.
        :return: ``(decoder_outputs, predict_result)`` as returned by
            ``Decoder.evaluate``.
        """
        encoder_outputs, encoder_hidden = self.encoder(input, input_len)
        decoder_outputs, predict_result = self.decoder.evaluate(encoder_hidden, encoder_outputs)
        return decoder_outputs, predict_result
本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:pytorch seq2seq闲聊机器人加入attention机制 - Python技术站