@@ -258,7 +258,6 @@ def extractSentencePairs(conversations):
 print("\nWriting newly formatted file...")
 with open(datafile, 'w', encoding='utf-8') as outputfile:
     writer = csv.writer(outputfile, delimiter=delimiter)
-
     for pair in extractSentencePairs(conversations):
         writer.writerow(pair)
 
@@ -351,11 +350,6 @@ def trim(self, min_count):
 # filter out sentences with length greater than the ``MAX_LENGTH``
 # threshold (``filterPairs``).
 #
-# For the sake of efficient re-usability, in our ``loadPrepareData``
-# function, we will save our clean ``voc`` and ``pairs`` to .tar files.
-# That way, if we run this code again, we can simply load this data
-# instead of having to redo the tedious preprocessing.
-#
 
 MAX_LENGTH = 10  # Maximum sentence length to consider
 
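The ``MAX_LENGTH`` threshold above is applied by the ``filterPair``/``filterPairs`` helpers named in the next hunk header; their bodies are not shown in this diff. As a reference only, a minimal sketch of what they typically look like in this tutorial:

# Returns True iff both sentences in a pair are below the MAX_LENGTH threshold
def filterPair(p):
    return len(p[0].split(' ')) < MAX_LENGTH and len(p[1].split(' ')) < MAX_LENGTH

# Keep only the pairs that satisfy the length filter
def filterPairs(pairs):
    return [pair for pair in pairs if filterPair(pair)]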
@@ -400,27 +394,17 @@ def filterPairs(pairs):
 
 
 def loadPrepareData(corpus, corpus_name, datafile, save_dir):
-    try:
-        print("Start loading training data ...")
-        voc = torch.load(os.path.join(save_dir, 'training_data', corpus_name, 'voc.tar'))
-        pairs = torch.load(os.path.join(save_dir, 'training_data', corpus_name, 'pairs.tar'))
-    except FileNotFoundError:
-        print("Saved data not found, start preparing training data ...")
-        voc, pairs = readVocs(datafile, corpus_name)
-        print("Read {!s} sentence pairs".format(len(pairs)))
-        pairs = filterPairs(pairs)
-        print("Trimmed to {!s} sentence pairs".format(len(pairs)))
-        print("Counting words...")
-        for pair in pairs:
-            voc.addSentence(pair[0])
-            voc.addSentence(pair[1])
-        print("Counted words:", voc.num_words)
-        # Save filtered & trimmed voc and pairs to file for later use
-        directory = os.path.join(save_dir, 'training_data', corpus_name)
-        if not os.path.exists(directory):
-            os.makedirs(directory)
-        torch.save(voc, os.path.join(directory, '{!s}.tar'.format('voc')))
-        torch.save(pairs, os.path.join(directory, '{!s}.tar'.format('pairs')))
+    print("Start preparing training data ...")
+    voc, pairs = readVocs(datafile, corpus_name)
+    print("Read {!s} sentence pairs".format(len(pairs)))
+    pairs = filterPairs(pairs)
+    print("Trimmed to {!s} sentence pairs".format(len(pairs)))
+    print("Counting words...")
+    # Add words from both query and response sentences
+    for pair in pairs:
+        voc.addSentence(pair[0])
+        voc.addSentence(pair[1])
+    print("Counted words:", voc.num_words)
     return voc, pairs
 
 
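With the try/except branch removed, ``loadPrepareData`` now always rebuilds the vocabulary and pairs from the formatted datafile. A typical call, following the signature above (the ``save_dir`` value and the preview loop are illustrative assumptions, not part of this diff):

# Assemble voc and pairs, then peek at a few formatted pairs
save_dir = os.path.join("data", "save")
voc, pairs = loadPrepareData(corpus, corpus_name, datafile, save_dir)
print("\npairs:")
for pair in pairs[:5]:
    print(pair)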
@@ -744,7 +728,7 @@ def forward(self, input_seq, input_lengths, hidden=None):
 # we calculate attention weights, or energies, using the hidden state of
 # the decoder from the current time step only. Bahdanau et al.’s attention
 # calculation requires knowledge of the decoder’s state from the previous
-# time step. Also, Luong et al. provides various methods to calculate the
+# time step. Also, Luong et al. provides various methods to calculate the
 # attention energies between the encoder output and decoder output which
 # are called “score functions”:
 #
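The simplest of these score functions, the "dot" score, multiplies the decoder hidden state against every encoder output and sums over the hidden dimension. A minimal sketch, assuming the shapes used in this tutorial (hidden: (1, batch, hidden_size); encoder_outputs: (max_len, batch, hidden_size)); it is illustrative and not part of this diff:

import torch

def dot_score(hidden, encoder_outputs):
    # Broadcasting multiplies the single decoder step against all encoder steps,
    # then summing over the hidden dimension gives one energy per time step.
    return torch.sum(hidden * encoder_outputs, dim=2)  # shape: (max_len, batch)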
@@ -1064,7 +1048,7 @@ def train(input_variable, lengths, target_variable, mask, max_target_len, encode
 # training right where we left off.
 #
 
-def trainIters(model_name, voc, pairs, encoder, decoder, encoder_optimizer, decoder_optimizer, embedding, encoder_n_layers, decoder_n_layers, save_dir, n_iteration, batch_size, print_every, save_every, clip, corpus_name):
+def trainIters(model_name, voc, pairs, encoder, decoder, encoder_optimizer, decoder_optimizer, embedding, encoder_n_layers, decoder_n_layers, save_dir, n_iteration, batch_size, print_every, save_every, clip, corpus_name, loadFilename):
 
     # Load batches for each iteration
     training_batches = [batch2TrainData(voc, [random.choice(pairs) for _ in range(batch_size)])
@@ -1107,7 +1091,7 @@ def trainIters(model_name, voc, pairs, encoder, decoder, encoder_optimizer, deco
                 'en_opt': encoder_optimizer.state_dict(),
                 'de_opt': decoder_optimizer.state_dict(),
                 'loss': loss,
-                'voc': voc,
+                'voc_dict': voc.__dict__,
                 'embedding': embedding.state_dict()
             }, os.path.join(directory, '{}_{}.tar'.format(iteration, 'checkpoint')))
 
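Saving ``voc.__dict__`` under ``'voc_dict'`` (instead of pickling the ``Voc`` object itself under ``'voc'``) keeps the checkpoint loadable even if the ``Voc`` class is moved or redefined. A minimal restore sketch, mirroring the loading code later in this commit; ``directory`` and ``iteration`` are the names used in the saving block above, and ``Voc(corpus_name)`` follows the tutorial's ``readVocs`` usage:

# Rebuild the vocabulary from the saved attribute dict rather than unpickling a Voc object
checkpoint = torch.load(os.path.join(directory, '{}_{}.tar'.format(iteration, 'checkpoint')))
voc = Voc(corpus_name)
voc.__dict__ = checkpoint['voc_dict']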
@@ -1356,24 +1340,31 @@ def evaluateInput(encoder, decoder, voc, beam_size):
 # '{}_checkpoint.tar'.format(checkpoint_iter))
 
 
-# Initialize Model
+# Load model if a loadFilename is provided
 if loadFilename:
-    #checkpoint = torch.load(loadFilename)
-    checkpoint = torch.load(loadFilename, map_location=torch.device('cpu'))
+    # If loading on same machine the model was trained on
+    checkpoint = torch.load(loadFilename)
+    # If loading a model trained on GPU to CPU
+    #checkpoint = torch.load(loadFilename, map_location=torch.device('cpu'))
     encoder_sd = checkpoint['en']
     decoder_sd = checkpoint['de']
+    encoder_optimizer_sd = checkpoint['en_opt']
+    decoder_optimizer_sd = checkpoint['de_opt']
     embedding_sd = checkpoint['embedding']
+    voc.__dict__ = checkpoint['voc_dict']
 
 
-checkpoint = None
 print('Building encoder and decoder ...')
+# Initialize word embeddings
 embedding = nn.Embedding(voc.num_words, hidden_size)
+if loadFilename:
+    embedding.load_state_dict(embedding_sd)
+# Initialize encoder & decoder models
 encoder = EncoderRNN(hidden_size, embedding, encoder_n_layers, dropout)
 decoder = LuongAttnDecoderRNN(attn_model, embedding, hidden_size, voc.num_words, decoder_n_layers, dropout)
 if loadFilename:
     encoder.load_state_dict(encoder_sd)
     decoder.load_state_dict(decoder_sd)
-    embedding.load_state_dict(embedding_sd)
 # use cuda
 encoder = encoder.to(device)
 decoder = decoder.to(device)
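The ``.to(device)`` calls rely on a ``device`` defined earlier in the full tutorial script (outside this diff). A minimal sketch of that setup, following the tutorial's ``USE_CUDA`` convention:

# Select GPU if available, otherwise fall back to CPU
USE_CUDA = torch.cuda.is_available()
device = torch.device("cuda" if USE_CUDA else "cpu")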
@@ -1409,14 +1400,14 @@ def evaluateInput(encoder, decoder, voc, beam_size):
 encoder_optimizer = optim.Adam(encoder.parameters(), lr=learning_rate)
 decoder_optimizer = optim.Adam(decoder.parameters(), lr=learning_rate * decoder_learning_ratio)
 if loadFilename:
-    encoder_optimizer.load_state_dict(checkpoint['en_opt'])
-    decoder_optimizer.load_state_dict(checkpoint['de_opt'])
+    encoder_optimizer.load_state_dict(encoder_optimizer_sd)
+    decoder_optimizer.load_state_dict(decoder_optimizer_sd)
 
 # Run training iterations
 print("Starting Training!")
 trainIters(model_name, voc, pairs, encoder, decoder, encoder_optimizer, decoder_optimizer,
            embedding, encoder_n_layers, decoder_n_layers, save_dir, n_iteration, batch_size,
-           print_every, save_every, clip, corpus_name)
+           print_every, save_every, clip, corpus_name, loadFilename)
 
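Since ``loadFilename`` is now passed through to ``trainIters``, resuming a run only requires pointing it at a previously saved checkpoint before this block. A hedged example, extrapolating the commented-out path template at the top of this hunk range (the ``checkpoint_iter`` value and the directory layout are assumptions):

# Resume from a hypothetical iteration-4000 checkpoint written by trainIters
checkpoint_iter = 4000
loadFilename = os.path.join(save_dir, model_name, corpus_name,
                            '{}-{}_{}'.format(encoder_n_layers, decoder_n_layers, hidden_size),
                            '{}_checkpoint.tar'.format(checkpoint_iter))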
 
 ######################################################################