|
117 | 117 | import unicodedata |
118 | 118 | import numpy as np |
119 | 119 |
|
| 120 | +device = torch.device("cpu") |
120 | 121 |
|
121 | | -USE_CUDA = torch.cuda.is_available() |
122 | | -device = torch.device("cuda" if USE_CUDA else "cpu") |
123 | 122 |
|
124 | 123 | MAX_LENGTH = 10 # Maximum sentence length |
125 | 124 |
|
@@ -677,7 +676,7 @@ def evaluateExample(sentence, encoder, decoder, searcher, voc): |
677 | 676 | # |
678 | 677 | # To load the hosted model: |
679 | 678 | # |
680 | | -# 1) Download the model `here <>`__. |
| 679 | +# 1) Download the model `here <https://download.pytorch.org/models/tutorials/4000_checkpoint.tar>`__. |
681 | 680 | # |
682 | 681 | # 2) Set the ``loadFilename`` variable to the path to the downloaded |
683 | 682 | # checkpoint file. |
@@ -728,18 +727,19 @@ def evaluateExample(sentence, encoder, decoder, searcher, voc): |
728 | 727 | dropout = 0.1 |
729 | 728 | batch_size = 64 |
730 | 729 |
|
| 730 | +# If you're loading your own model |
731 | 731 | # Set checkpoint to load from |
732 | 732 | checkpoint_iter = 4000 |
733 | | -loadFilename = os.path.join(save_dir, model_name, corpus_name, |
734 | | - '{}-{}_{}'.format(encoder_n_layers, decoder_n_layers, hidden_size), |
735 | | - '{}_checkpoint.tar'.format(checkpoint_iter)) |
| 733 | +# loadFilename = os.path.join(save_dir, model_name, corpus_name, |
| 734 | +# '{}-{}_{}'.format(encoder_n_layers, decoder_n_layers, hidden_size), |
| 735 | +# '{}_checkpoint.tar'.format(checkpoint_iter)) |
736 | 736 |
|
| 737 | +# If you're loading the hosted model |
| 738 | +loadFilename = '4000_checkpoint.tar' |
737 | 739 |
|
738 | 740 | # Load model |
739 | | -# If loading on same machine the model was trained on |
740 | | -checkpoint = torch.load(loadFilename) |
741 | | -# If loading a model trained on GPU to CPU |
742 | | -#checkpoint = torch.load(loadFilename, map_location=torch.device('cpu')) |
| 741 | +# Remap all tensors onto the CPU (to match the device used in this tutorial) |
| 742 | +checkpoint = torch.load(loadFilename, map_location=torch.device('cpu')) |
743 | 743 | encoder_sd = checkpoint['en'] |
744 | 744 | decoder_sd = checkpoint['de'] |
745 | 745 | encoder_optimizer_sd = checkpoint['en_opt'] |
@@ -874,7 +874,8 @@ def evaluateExample(sentence, encoder, decoder, searcher, voc): |
874 | 874 | # will serialize it for use in a non-Python deployment environment. To do |
875 | 875 | # this, we can simply save our ``scripted_searcher`` module, as this is |
876 | 876 | # the user-facing interface for running inference against the chatbot |
877 | | -# model. |
| 877 | +# model. When saving a ``ScriptModule``, use ``script_module.save(PATH)`` |
| 878 | +# instead of ``torch.save(model, PATH)``. |
878 | 879 | # |
879 | 880 |
|
880 | | -torch.save(scripted_searcher.state_dict(), "scripted_chatbot.pth") |
| 881 | +scripted_searcher.save("scripted_chatbot.pth") |
0 commit comments