Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit c814d7f

Browse files
mkarikomboathitMatt Karikomi
authored
Corpus hierarchy (#9)
* fix deprecated warning for julia 0.4 * v1.3 compat v1.3 compat fixed lexicon * Refactoring 1) Type hierarchy for data: rooted at abstract corpus and document, which support subtypes representing fully-synthetic and real world data 2) Type hierarchy for MCMC: break struct model into "model" and "state" reflecting the scope (document locality) of latent variables vs model parameters and hyperpriors. This will facilitate clear cut testing in next PR based on Grosse and Duvenaud https://arxiv.org/abs/1412.5218 * Add unit test for Gibbs sampler, etc 1) Per-word topics: add a test for consistency (with the full joint) of the corresponding conditional 2) Get rid of mutability on structs in src/Data.jl in favor of in-place assignment * comments in gibbs tests Co-authored-by: Fineday <[email protected]> Co-authored-by: Matt Karikomi <[email protected]>
1 parent b447326 commit c814d7f

File tree

9 files changed

+585
-345
lines changed

9 files changed

+585
-345
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
.DS_Store

Manifest.toml

Lines changed: 36 additions & 137 deletions
Original file line numberDiff line numberDiff line change
@@ -1,136 +1,5 @@
11
# This file is machine-generated - editing it directly is not advised
22

3-
[[Arpack]]
4-
deps = ["Arpack_jll", "Libdl", "LinearAlgebra"]
5-
git-tree-sha1 = "2ff92b71ba1747c5fdd541f8fc87736d82f40ec9"
6-
uuid = "7d9fca2a-8960-54d3-9f78-7d1dccf2cb97"
7-
version = "0.4.0"
8-
9-
[[Arpack_jll]]
10-
deps = ["Libdl", "OpenBLAS_jll", "Pkg"]
11-
git-tree-sha1 = "68a90a692ddc0eb72d69a6993ca26e2a923bf195"
12-
uuid = "68821587-b530-5797-8361-c406ea357684"
13-
version = "3.5.0+2"
14-
15-
[[Base64]]
16-
uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
17-
18-
[[BinaryProvider]]
19-
deps = ["Libdl", "SHA"]
20-
git-tree-sha1 = "5b08ed6036d9d3f0ee6369410b830f8873d4024c"
21-
uuid = "b99e7846-7c00-51b0-8f62-c81ae34c0232"
22-
version = "0.5.8"
23-
24-
[[DataAPI]]
25-
git-tree-sha1 = "674b67f344687a88310213ddfa8a2b3c76cc4252"
26-
uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a"
27-
version = "1.1.0"
28-
29-
[[DataStructures]]
30-
deps = ["InteractiveUtils", "OrderedCollections"]
31-
git-tree-sha1 = "5a431d46abf2ef2a4d5d00bd0ae61f651cf854c8"
32-
uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
33-
version = "0.17.10"
34-
35-
[[Dates]]
36-
deps = ["Printf"]
37-
uuid = "ade2ca70-3891-5945-98fb-dc099432e06a"
38-
39-
[[Distributed]]
40-
deps = ["Random", "Serialization", "Sockets"]
41-
uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"
42-
43-
[[Distributions]]
44-
deps = ["FillArrays", "LinearAlgebra", "PDMats", "Printf", "QuadGK", "Random", "SpecialFunctions", "Statistics", "StatsBase", "StatsFuns"]
45-
git-tree-sha1 = "6b19601c0e98de3a8964ed33ad73e130c7165b1d"
46-
uuid = "31c24e10-a181-5473-b8eb-7969acd0382f"
47-
version = "0.22.4"
48-
49-
[[FillArrays]]
50-
deps = ["LinearAlgebra", "Random", "SparseArrays"]
51-
git-tree-sha1 = "85c6b57e2680fa28d5c8adc798967377646fbf66"
52-
uuid = "1a297f60-69ca-5386-bcde-b61e274b549b"
53-
version = "0.8.5"
54-
55-
[[InteractiveUtils]]
56-
deps = ["Markdown"]
57-
uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
58-
59-
[[LibGit2]]
60-
uuid = "76f85450-5226-5b5a-8eaa-529ad045b433"
61-
62-
[[Libdl]]
63-
uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
64-
65-
[[LinearAlgebra]]
66-
deps = ["Libdl"]
67-
uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
68-
69-
[[Logging]]
70-
uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"
71-
72-
[[Markdown]]
73-
deps = ["Base64"]
74-
uuid = "d6f4376e-aef5-505a-96c1-9c027394607a"
75-
76-
[[Missings]]
77-
deps = ["DataAPI"]
78-
git-tree-sha1 = "de0a5ce9e5289f27df672ffabef4d1e5861247d5"
79-
uuid = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28"
80-
version = "0.4.3"
81-
82-
[[OpenBLAS_jll]]
83-
deps = ["Libdl", "Pkg"]
84-
git-tree-sha1 = "e2551d7c25d52f35b76d86a50917a3ba8988f519"
85-
uuid = "4536629a-c528-5b80-bd46-f80d51c5b363"
86-
version = "0.3.7+5"
87-
88-
[[OpenSpecFun_jll]]
89-
deps = ["Libdl", "Pkg"]
90-
git-tree-sha1 = "65f672edebf3f4e613ddf37db9dcbd7a407e5e90"
91-
uuid = "efe28fd5-8261-553b-a9e1-b2916fc3738e"
92-
version = "0.5.3+1"
93-
94-
[[OrderedCollections]]
95-
deps = ["Random", "Serialization", "Test"]
96-
git-tree-sha1 = "c4c13474d23c60d20a67b217f1d7f22a40edf8f1"
97-
uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
98-
version = "1.1.0"
99-
100-
[[PDMats]]
101-
deps = ["Arpack", "LinearAlgebra", "SparseArrays", "SuiteSparse", "Test"]
102-
git-tree-sha1 = "5f303510529486bb02ac4d70da8295da38302194"
103-
uuid = "90014a1f-27ba-587c-ab20-58faa44d9150"
104-
version = "0.9.11"
105-
106-
[[Pkg]]
107-
deps = ["Dates", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Test", "UUIDs"]
108-
uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
109-
110-
[[Printf]]
111-
deps = ["Unicode"]
112-
uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7"
113-
114-
[[QuadGK]]
115-
deps = ["DataStructures", "LinearAlgebra"]
116-
git-tree-sha1 = "dc84e810393cfc6294248c9032a9cdacc14a3db4"
117-
uuid = "1fd47b50-473d-5c70-9696-f719f8f3bcdc"
118-
version = "2.3.1"
119-
120-
[[REPL]]
121-
deps = ["InteractiveUtils", "Markdown", "Sockets"]
122-
uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb"
123-
124-
[[Random]]
125-
deps = ["Serialization"]
126-
uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
127-
128-
[[Rmath]]
129-
deps = ["BinaryProvider", "Libdl", "Random", "Statistics"]
130-
git-tree-sha1 = "2bbddcb984a1d08612d0c4abb5b4774883f6fa98"
131-
uuid = "79098fc4-a85e-5d69-aa6a-4863f24498fa"
132-
version = "0.6.0"
133-
1343
[[SHA]]
1354
uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce"
1365

@@ -160,12 +29,6 @@ version = "0.10.0"
16029
deps = ["LinearAlgebra", "SparseArrays"]
16130
uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
16231

163-
[[StatsBase]]
164-
deps = ["DataAPI", "DataStructures", "LinearAlgebra", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics"]
165-
git-tree-sha1 = "be5c7d45daa449d12868f4466dbf5882242cf2d9"
166-
uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
167-
version = "0.32.1"
168-
16932
[[StatsFuns]]
17033
deps = ["Rmath", "SpecialFunctions"]
17134
git-tree-sha1 = "f290ddd5fdedeadd10e961eb3f4d3340f09d030a"
@@ -186,3 +49,39 @@ uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
18649

18750
[[Unicode]]
18851
uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5"
52+
53+
[[Zlib_jll]]
54+
deps = ["Libdl", "Pkg"]
55+
git-tree-sha1 = "fd36a6739e256527287c5444960d0266712cd49e"
56+
uuid = "83775a58-1f1d-513f-b197-d71354ab007a"
57+
version = "1.2.11+8"
58+
59+
[[libass_jll]]
60+
deps = ["Bzip2_jll", "FreeType2_jll", "FriBidi_jll", "Libdl", "Pkg", "Zlib_jll"]
61+
git-tree-sha1 = "3fd3ea3525f2e3d337c54a52b2ca78a5a272bbf5"
62+
uuid = "0ac62f75-1d6f-5e53-bd7c-93b484bb37c0"
63+
version = "0.14.0+0"
64+
65+
[[libfdk_aac_jll]]
66+
deps = ["Libdl", "Pkg"]
67+
git-tree-sha1 = "0e4ace600c20714a8dd67700c4502714d8473e8e"
68+
uuid = "f638f0a6-7fb0-5443-88ba-1cc74229b280"
69+
version = "0.1.6+1"
70+
71+
[[libvorbis_jll]]
72+
deps = ["Libdl", "Ogg_jll", "Pkg"]
73+
git-tree-sha1 = "71e54fb89ac3e0344c7185d1876fd96b0f246952"
74+
uuid = "f27f6e37-5d2b-51aa-960f-b287f2bc3b7a"
75+
version = "1.3.6+2"
76+
77+
[[x264_jll]]
78+
deps = ["Libdl", "Pkg"]
79+
git-tree-sha1 = "23664c0757c3740050ca0e22944c786c165ca25a"
80+
uuid = "1270edf5-f2f9-52d2-97e9-ab00b5d0237a"
81+
version = "2019.5.25+1"
82+
83+
[[x265_jll]]
84+
deps = ["Libdl", "Pkg"]
85+
git-tree-sha1 = "9345e417084421a8e91373d6196bc58e660eed2a"
86+
uuid = "dfaa095f-4041-5dcd-9319-2fabd8486b76"
87+
version = "3.0.0+0"

Project.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,12 @@ authors = ["Jonathan Chang <slycoder @gmail.com>"]
44
version = "0.1.0"
55

66
[deps]
7+
Clustering = "aaaa29a8-35af-508c-8bc3-b662a17a0fe5"
78
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
9+
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
810
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
911
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
12+
UMAP = "c4f8c510-2410-5be4-91d7-4fbaeb39457e"
1013

1114
[compat]
1215
julia = "1.3"

examples/LDA.jl

Lines changed: 39 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,46 @@
1-
using TopicModels
1+
using TopicModels, Plots, UMAP
22

3-
exdir = Pkg.dir("TopicModels", "examples")
3+
##################################################################################################################################
4+
# Fit and Visualize Real-World Text Data
45

5-
testDocuments = readDocuments(open(joinpath(exdir, "cora.documents")))
6+
exdir = joinpath(dirname(pathof(TopicModels)), "..", "examples")
7+
8+
testDocuments = readDocs(open(joinpath(exdir, "cora.documents")))
69
testLexicon = readLexicon(open(joinpath(exdir, "cora.lexicon")))
710

8-
corpus = Corpus(testDocuments)
11+
corpus = Corpus(testDocuments,testLexicon)
12+
model = Model(fill(0.1, 10), fill(0.01,length(testLexicon)), corpus)
13+
state = State(model,corpus)
14+
15+
#@time Juno.@run trainModel(model, state, 30)
16+
@time trainModel(model, state, 30)
17+
topWords = topTopicWords(model, state, 10)
18+
19+
# visualize the fit
20+
@time embedding = umap(state.topics, 2, n_neighbors=10)
21+
maxlabels = vec(map(i->i[1], findmax(state.topics,dims=1)[2]))
22+
scatter(embedding[1,:], embedding[2,:], zcolor=maxlabels, title="UMAP: Max-Likelihood Doc Topics on Learned", marker=(2, 2, :auto, stroke(0)))
23+
24+
##################################################################################################################################
25+
# Fit, Validate, and Visualize Synthetic Data Derived from a Fully-Generative Simulation (Poisson-distributed document-length)
26+
27+
k = 10
28+
lexLength = 1000
29+
corpLambda = 1000 # poisson parameter for random doc length
30+
corpLength = 100
31+
scaleK = 0.01
32+
scaleL = 0.01
33+
testCorpus = LdaCorpus(k, lexLength, corpLambda, corpLength, scaleK, scaleL)
934

10-
model = Model(fill(0.1, 10), 0.01, length(testLexicon), corpus)
35+
testModel = Model(testCorpus.alpha, testCorpus.beta, testCorpus)
36+
testState = State(testModel, testCorpus)
37+
@time trainModel(testModel, testState, 100)
1138

12-
@time trainModel(model, 30)
39+
# compute validation metrics on a single fit
40+
CorpusARI(testState,testModel,testCorpus) # ARI for max. likelihood. document topics
41+
DocsARI(testState,testCorpus) # ARI for actual word topics
1342

14-
topWords = topTopicWords(model, testLexicon, 21)
43+
# visualize the fit
44+
@time embedding = umap(testState.topics, 2;n_neighbors=10)
45+
maxlabels = vec(map(i->i[1], findmax(CorpusTopics(testCorpus),dims=1)[2]))
46+
scatter(embedding[1,:], embedding[2,:], zcolor=maxlabels, title="UMAP: True on Learned", marker=(2, 2, :auto, stroke(0)))

0 commit comments

Comments
 (0)