diff --git a/tests/test_text.py b/tests/test_text.py index e0ee71e2c..ac1f9c996 100644 --- a/tests/test_text.py +++ b/tests/test_text.py @@ -74,6 +74,72 @@ def test_text_models(): assert len(P3.dictionary) == 3 +def test_char_models(): + test_string = 'unigram' + wordseq = words(test_string) + P1 = NgramCharModel(1, wordseq) + + assert len(P1.dictionary) == len(test_string) + for char in test_string: + assert tuple(char) in P1.dictionary + + test_string = 'a b c' + wordseq = words(test_string) + P1 = NgramCharModel(1, wordseq) + + assert len(P1.dictionary) == len(test_string.split()) + for char in test_string.split(): + assert tuple(char) in P1.dictionary + + test_string = 'bigram' + wordseq = words(test_string) + P2 = NgramCharModel(2, wordseq) + + expected_bigrams = {(' ', 'b'): 1, ('b', 'i'): 1, ('i', 'g'): 1, ('g', 'r'): 1, ('r', 'a'): 1, ('a', 'm'): 1} + + assert len(P2.dictionary) == len(expected_bigrams) + for bigram, count in expected_bigrams.items(): + assert bigram in P2.dictionary + assert P2.dictionary[bigram] == count + + test_string = 'bigram bigram' + wordseq = words(test_string) + P2 = NgramCharModel(2, wordseq) + + expected_bigrams = {(' ', 'b'): 2, ('b', 'i'): 2, ('i', 'g'): 2, ('g', 'r'): 2, ('r', 'a'): 2, ('a', 'm'): 2} + + assert len(P2.dictionary) == len(expected_bigrams) + for bigram, count in expected_bigrams.items(): + assert bigram in P2.dictionary + assert P2.dictionary[bigram] == count + + test_string = 'trigram' + wordseq = words(test_string) + P3 = NgramCharModel(3, wordseq) + + expected_trigrams = {(' ', ' ', 't'): 1, (' ', 't', 'r'): 1, ('t', 'r', 'i'): 1, + ('r', 'i', 'g'): 1, ('i', 'g', 'r'): 1, ('g', 'r', 'a'): 1, + ('r', 'a', 'm'): 1} + + assert len(P3.dictionary) == len(expected_trigrams) + for bigram, count in expected_trigrams.items(): + assert bigram in P3.dictionary + assert P3.dictionary[bigram] == count + + test_string = 'trigram trigram trigram' + wordseq = words(test_string) + P3 = NgramCharModel(3, wordseq) + + expected_trigrams = {(' ', ' ', 't'): 3, (' ', 't', 'r'): 3, ('t', 'r', 'i'): 3, + ('r', 'i', 'g'): 3, ('i', 'g', 'r'): 3, ('g', 'r', 'a'): 3, + ('r', 'a', 'm'): 3} + + assert len(P3.dictionary) == len(expected_trigrams) + for bigram, count in expected_trigrams.items(): + assert bigram in P3.dictionary + assert P3.dictionary[bigram] == count + + def test_viterbi_segmentation(): flatland = DataFile("EN-text/flatland.txt").read() wordseq = words(flatland)