# ------------
#
# To start, download the data ZIP file
# `here <https://zissou.infosci.cornell.edu/convokit/datasets/movie-corpus/movie-corpus.zip>`__
# and put it in a ``data/`` directory under the current directory.
#
# After that, let’s import some necessities.
from io import open
import itertools
import json
import math

114116
# Prefer the GPU when one is available; tensors/modules are moved accordingly below.
USE_CUDA = torch.cuda.is_available()
# original format.
#
# Location of the ConvoKit "movie-corpus" dataset, extracted under ./data/
corpus_name = "movie-corpus"
corpus = os.path.join("data", corpus_name)
146148def printLines (file , n = 10 ):
@@ -149,7 +151,7 @@ def printLines(file, n=10):
149151 for line in lines [:n ]:
150152 print (line )
151153
# Preview the raw corpus file (one JSON object per line).
printLines(os.path.join(corpus, "utterances.jsonl"))
153155
154156
######################################################################
# contains a tab-separated *query sentence* and a *response sentence* pair.
#
# The following functions facilitate the parsing of the raw
# *utterances.jsonl* data file.
#
# - ``loadLinesAndConversations`` splits each line of the file into a dictionary of
#   lines with fields: lineID, characterID, and text and then groups them
#   into conversations with fields: conversationID, movieID, and lines.
# - ``extractSentencePairs`` extracts pairs of sentences from
#   conversations
#
# Splits each line of the file to create lines and conversations
def loadLinesAndConversations(fileName):
    """Parse a ConvoKit ``utterances.jsonl`` file into lines and conversations.

    Each input line is one JSON object with fields ``id``, ``speaker``,
    ``text``, ``conversation_id`` and ``meta.movie_id``.

    Returns a tuple ``(lines, conversations)`` where
    ``lines`` maps lineID -> {"lineID", "characterID", "text"} and
    ``conversations`` maps conversationID -> {"conversationID", "movieID", "lines"}.
    """
    lines = {}
    conversations = {}
    with open(fileName, 'r', encoding='iso-8859-1') as f:
        for line in f:
            lineJson = json.loads(line)
            # Extract fields for line object
            lineObj = {}
            lineObj["lineID"] = lineJson["id"]
            lineObj["characterID"] = lineJson["speaker"]
            lineObj["text"] = lineJson["text"]
            lines[lineObj['lineID']] = lineObj

            # Extract fields for conversation object
            if lineJson["conversation_id"] not in conversations:
                convObj = {}
                convObj["conversationID"] = lineJson["conversation_id"]
                convObj["movieID"] = lineJson["meta"]["movie_id"]
                convObj["lines"] = [lineObj]
            else:
                # NOTE(review): later utterances are prepended, which assumes the
                # file lists each conversation's utterances newest-first — confirm
                # against the dataset ordering.
                convObj = conversations[lineJson["conversation_id"]]
                convObj["lines"].insert(0, lineObj)
            conversations[convObj["conversationID"]] = convObj

    return lines, conversations
206200
207201
208202# Extracts pairs of sentences from conversations
209203def extractSentencePairs (conversations ):
210204 qa_pairs = []
211- for conversation in conversations :
205+ for conversation in conversations . values () :
212206 # Iterate over all the lines of the conversation
213207 for i in range (len (conversation ["lines" ]) - 1 ): # We ignore the last line (no answer for it)
214208 inputLine = conversation ["lines" ][i ]["text" ].strip ()
@@ -231,18 +225,12 @@ def extractSentencePairs(conversations):
# Unescape the delimiter
delimiter = str(codecs.decode(delimiter, "unicode_escape"))

# Initialize lines dict and conversations dict
lines = {}
conversations = {}

# Load lines and conversations
print("\nProcessing corpus into lines and conversations...")
lines, conversations = loadLinesAndConversations(os.path.join(corpus, "utterances.jsonl"))

# Write new csv file
print("\nWriting newly formatted file...")
@@ -1341,7 +1329,7 @@ def evaluateInput(encoder, decoder, searcher, voc):
13411329 for k , v in state .items ():
13421330 if isinstance (v , torch .Tensor ):
13431331 state [k ] = v .cuda ()
1344-
1332+
13451333# Run training iterations
13461334print ("Starting Training!" )
13471335trainIters (model_name , voc , pairs , encoder , decoder , encoder_optimizer , decoder_optimizer ,
0 commit comments