import tensorflow as tf
from tensorflow.keras.layers import Dense, LSTM, Embedding, Concatenate
from tensorflow.keras.models import Model
class BahdanauAttention(tf.keras.layers.Layer):
    """Additive (Bahdanau-style) attention.

    Scores every encoder timestep against a decoder query with
    ``V . tanh(W1 q + W2 v)`` and returns the softmax-weighted sum of the
    values as a context vector, along with the attention weights.
    """

    def __init__(self, units):
        super(BahdanauAttention, self).__init__()
        # Learned projections: one each for the query and the values,
        # plus a single-unit scoring layer.
        self.W1 = Dense(units)
        self.W2 = Dense(units)
        self.V = Dense(1)

    def call(self, query, values):
        # Insert a time axis so the query broadcasts against every
        # encoder timestep in `values`.
        expanded_query = tf.expand_dims(query, 1)
        # Unnormalized alignment scores, shape (batch, time, 1).
        scores = self.V(tf.nn.tanh(self.W1(expanded_query) + self.W2(values)))
        # Normalize over the time axis.
        attention_weights = tf.nn.softmax(scores, axis=1)
        # Collapse the weighted values into a single context vector.
        context_vector = tf.reduce_sum(attention_weights * values, axis=1)
        return context_vector, attention_weights
# Define the encoder
class Encoder(tf.keras.Model):
    """Embedding + LSTM encoder.

    Embeds integer token ids and runs them through an LSTM, returning the
    per-timestep outputs together with the final (h, c) states.
    """

    def __init__(self, vocab_size, embedding_dim, enc_units, batch_sz):
        super(Encoder, self).__init__()
        self.batch_sz = batch_sz
        self.enc_units = enc_units
        self.embedding = Embedding(vocab_size, embedding_dim)
        # return_sequences: attention needs every timestep's output;
        # return_state: the final states seed the decoder.
        self.lstm = LSTM(self.enc_units, return_sequences=True, return_state=True)

    def call(self, x, hidden):
        embedded = self.embedding(x)
        sequence_output, final_h, final_c = self.lstm(embedded, initial_state=hidden)
        return sequence_output, final_h, final_c
# Define the decoder
class Decoder(tf.keras.Model):
    """Attention decoder.

    For each input token: attend over the encoder outputs, concatenate the
    resulting context vector with the embedded token, run the LSTM, and
    project to vocabulary logits.
    """

    def __init__(self, vocab_size, embedding_dim, dec_units, batch_sz):
        super(Decoder, self).__init__()
        self.batch_sz = batch_sz
        self.dec_units = dec_units
        self.embedding = Embedding(vocab_size, embedding_dim)
        self.lstm = LSTM(self.dec_units, return_sequences=True, return_state=True)
        self.fc = Dense(vocab_size)
        self.attention = BahdanauAttention(self.dec_units)

    def call(self, x, hidden, enc_output):
        """Run one decoding step.

        Args:
            x: (batch, 1) target token ids for this step.
            hidden: previous decoder state — an (h, c) pair as supplied by
                the caller, or a single h tensor.
            enc_output: (batch, src_time, enc_units) encoder outputs.

        Returns:
            (logits, state_h, state_c, attention_weights).
        """
        # BUG FIX: the attention query must be a single tensor. The caller
        # passes the LSTM state as an (h, c) pair; feeding the pair to
        # tf.expand_dims stacks it into a rank-3 tensor and the additive
        # score computation then fails to broadcast. Use the h state.
        query = hidden[0] if isinstance(hidden, (list, tuple)) else hidden
        context_vector, attention_weights = self.attention(query, enc_output)
        x = self.embedding(x)
        # Prepend the context vector to the embedded input along the
        # feature axis so the LSTM sees both at once.
        x = tf.concat([tf.expand_dims(context_vector, 1), x], axis=-1)
        # BUG FIX: thread the previous decoder state into the LSTM. The
        # original called self.lstm(x) and ignored `hidden`, so recurrent
        # state was silently reset on every step.
        initial_state = hidden if isinstance(hidden, (list, tuple)) else None
        output, state_h, state_c = self.lstm(x, initial_state=initial_state)
        # Flatten (batch, 1, units) -> (batch, units) for the dense head.
        output = tf.reshape(output, (-1, output.shape[2]))
        logits = self.fc(output)
        return logits, state_h, state_c, attention_weights
# Define the model
def build_model(vocab_size, embedding_dim, units, batch_size):
encoder = Encoder(vocab_size, embedding_dim, units, batch_size)
decoder = Decoder(vocab_size, embedding_dim, units, batch_size)
# Define input placeholders
input_sequence = tf.keras.Input(shape=(None,))
target_sequence = tf.keras.Input(shape=(None,))
# Initialize the hidden state for the LSTM
hidden = (tf.zeros([batch_size, units]), tf.zeros([batch_size, units]))
# Encode the input sequence
enc_output, enc_hidden_h, enc_hidden_c = encoder(input_sequence, hidden)
# Set the decoder's initial hidden state to the encoder's final hidden state
dec_hidden = (enc_hidden_h, enc_hidden_c)
# Initialize the attention weights for visualization
attention_weights = []
# Iterate over the target sequence to generate the output sequence
loss = 0
for t in range(target_sequence.shape