Thanks for visiting codestin.com
Credit goes to github.com

Skip to content

Commit 2f5299a

Browse files
Matthew Inkawhich authored and JoelMarcey committed
Update chatbot model loading; serialize voc.__dict__ instead of entire object
1 parent 6636d6f commit 2f5299a

1 file changed

Lines changed: 29 additions & 38 deletions

File tree

beginner_source/chatbot_tutorial.py

File mode changed: 100755 → 100644
Lines changed: 29 additions & 38 deletions
Original file line number | Diff line number | Diff line change
@@ -258,7 +258,6 @@ def extractSentencePairs(conversations):
258258
print("\nWriting newly formatted file...")
259259
with open(datafile, 'w', encoding='utf-8') as outputfile:
260260
writer = csv.writer(outputfile, delimiter=delimiter)
261-
262261
for pair in extractSentencePairs(conversations):
263262
writer.writerow(pair)
264263

@@ -351,11 +350,6 @@ def trim(self, min_count):
351350
# filter out sentences with length greater than the ``MAX_LENGTH``
352351
# threshold (``filterPairs``).
353352
#
354-
# For the sake of efficient re-usability, in our ``loadPrepareData``
355-
# function, we will save our clean ``voc`` and ``pairs`` to .tar files.
356-
# That way, if we run this code again, we can simply load this data
357-
# instead of having to redo the tedious preprocessing.
358-
#
359353

360354
MAX_LENGTH = 10 # Maximum sentence length to consider
361355

@@ -400,27 +394,17 @@ def filterPairs(pairs):
400394

401395

402396
def loadPrepareData(corpus, corpus_name, datafile, save_dir):
403-
try:
404-
print("Start loading training data ...")
405-
voc = torch.load(os.path.join(save_dir, 'training_data', corpus_name, 'voc.tar'))
406-
pairs = torch.load(os.path.join(save_dir, 'training_data', corpus_name, 'pairs.tar'))
407-
except FileNotFoundError:
408-
print("Saved data not found, start preparing training data ...")
409-
voc, pairs = readVocs(datafile, corpus_name)
410-
print("Read {!s} sentence pairs".format(len(pairs)))
411-
pairs = filterPairs(pairs)
412-
print("Trimmed to {!s} sentence pairs".format(len(pairs)))
413-
print("Counting words...")
414-
for pair in pairs:
415-
voc.addSentence(pair[0])
416-
voc.addSentence(pair[1])
417-
print("Counted words:", voc.num_words)
418-
# Save filtered & trimmed voc and pairs to file for later use
419-
directory = os.path.join(save_dir, 'training_data', corpus_name)
420-
if not os.path.exists(directory):
421-
os.makedirs(directory)
422-
torch.save(voc, os.path.join(directory, '{!s}.tar'.format('voc')))
423-
torch.save(pairs, os.path.join(directory, '{!s}.tar'.format('pairs')))
397+
print("Saved data not found, start preparing training data ...")
398+
voc, pairs = readVocs(datafile, corpus_name)
399+
print("Read {!s} sentence pairs".format(len(pairs)))
400+
pairs = filterPairs(pairs)
401+
print("Trimmed to {!s} sentence pairs".format(len(pairs)))
402+
print("Counting words...")
403+
# Add words from both query and response sentences
404+
for pair in pairs:
405+
voc.addSentence(pair[0])
406+
voc.addSentence(pair[1])
407+
print("Counted words:", voc.num_words)
424408
return voc, pairs
425409

426410

@@ -744,7 +728,7 @@ def forward(self, input_seq, input_lengths, hidden=None):
744728
# we calculate attention weights, or energies, using the hidden state of
745729
# the decoder from the current time step only. Bahdanau et al.’s attention
746730
# calculation requires knowledge of the decoder’s state from the previous
747-
# time step. Also, Luong et al. provides various methods to calculate the
731+
# time step. Also, Luong et al. provides various methods to calculate the
748732
# attention energies between the encoder output and decoder output which
749733
# are called “score functions”:
750734
#
@@ -1064,7 +1048,7 @@ def train(input_variable, lengths, target_variable, mask, max_target_len, encode
10641048
# training right where we left off.
10651049
#
10661050

1067-
def trainIters(model_name, voc, pairs, encoder, decoder, encoder_optimizer, decoder_optimizer, embedding, encoder_n_layers, decoder_n_layers, save_dir, n_iteration, batch_size, print_every, save_every, clip, corpus_name):
1051+
def trainIters(model_name, voc, pairs, encoder, decoder, encoder_optimizer, decoder_optimizer, embedding, encoder_n_layers, decoder_n_layers, save_dir, n_iteration, batch_size, print_every, save_every, clip, corpus_name, loadFilename):
10681052

10691053
# Load batches for each iteration
10701054
training_batches = [batch2TrainData(voc, [random.choice(pairs) for _ in range(batch_size)])
@@ -1107,7 +1091,7 @@ def trainIters(model_name, voc, pairs, encoder, decoder, encoder_optimizer, deco
11071091
'en_opt': encoder_optimizer.state_dict(),
11081092
'de_opt': decoder_optimizer.state_dict(),
11091093
'loss': loss,
1110-
'voc': voc,
1094+
'voc_dict': voc.__dict__,
11111095
'embedding': embedding.state_dict()
11121096
}, os.path.join(directory, '{}_{}.tar'.format(iteration, 'checkpoint')))
11131097

@@ -1356,24 +1340,31 @@ def evaluateInput(encoder, decoder, voc, beam_size):
13561340
# '{}_checkpoint.tar'.format(checkpoint_iter))
13571341

13581342

1359-
# Initialize Model
1343+
# Load model if a loadFilename is provided
13601344
if loadFilename:
1361-
#checkpoint = torch.load(loadFilename)
1362-
checkpoint = torch.load(loadFilename, map_location=torch.device('cpu'))
1345+
# If loading on same machine the model was trained on
1346+
checkpoint = torch.load(loadFilename)
1347+
# If loading a model trained on GPU to CPU
1348+
#checkpoint = torch.load(loadFilename, map_location=torch.device('cpu'))
13631349
encoder_sd = checkpoint['en']
13641350
decoder_sd = checkpoint['de']
1351+
encoder_optimizer_sd = checkpoint['en_opt']
1352+
decoder_optimizer_sd = checkpoint['de_opt']
13651353
embedding_sd = checkpoint['embedding']
1354+
voc.__dict__ = checkpoint['voc_dict']
13661355

13671356

1368-
checkpoint = None
13691357
print('Building encoder and decoder ...')
1358+
# Initialize word embeddings
13701359
embedding = nn.Embedding(voc.num_words, hidden_size)
1360+
if loadFilename:
1361+
embedding.load_state_dict(embedding_sd)
1362+
# Initialize encoder & decoder models
13711363
encoder = EncoderRNN(hidden_size, embedding, encoder_n_layers, dropout)
13721364
decoder = LuongAttnDecoderRNN(attn_model, embedding, hidden_size, voc.num_words, decoder_n_layers, dropout)
13731365
if loadFilename:
13741366
encoder.load_state_dict(encoder_sd)
13751367
decoder.load_state_dict(decoder_sd)
1376-
embedding.load_state_dict(embedding_sd)
13771368
# use cuda
13781369
encoder = encoder.to(device)
13791370
decoder = decoder.to(device)
@@ -1409,14 +1400,14 @@ def evaluateInput(encoder, decoder, voc, beam_size):
14091400
encoder_optimizer = optim.Adam(encoder.parameters(), lr=learning_rate)
14101401
decoder_optimizer = optim.Adam(decoder.parameters(), lr=learning_rate * decoder_learning_ratio)
14111402
if loadFilename:
1412-
encoder_optimizer.load_state_dict(checkpoint['en_opt'])
1413-
decoder_optimizer.load_state_dict(checkpoint['de_opt'])
1403+
encoder_optimizer.load_state_dict(encoder_optimizer_sd)
1404+
decoder_optimizer.load_state_dict(decoder_optimizer_sd)
14141405

14151406
# Run training iterations
14161407
print("Starting Training!")
14171408
trainIters(model_name, voc, pairs, encoder, decoder, encoder_optimizer, decoder_optimizer,
14181409
embedding, encoder_n_layers, decoder_n_layers, save_dir, n_iteration, batch_size,
1419-
print_every, save_every, clip, corpus_name)
1410+
print_every, save_every, clip, corpus_name, loadFilename)
14201411

14211412

14221413
######################################################################

0 commit comments

Comments
 (0)