11"""Module symbol-table generator"""
22
33from compiler import ast
4+ import types
45
5- module_scope = None
6+ MANGLE_LEN = 256
67
78class Scope :
89 # XXX how much information do I need about each name?
9- def __init__ (self , name ):
10+ def __init__ (self , name , module , klass = None ):
1011 self .name = name
12+ self .module = module
1113 self .defs = {}
1214 self .uses = {}
1315 self .globals = {}
1416 self .params = {}
17+ self .children = []
18+ self .klass = None
19+ if klass is not None :
20+ for i in range (len (klass )):
21+ if klass [i ] != '_' :
22+ self .klass = klass [i :]
23+ break
1524
1625 def __repr__ (self ):
1726 return "<%s: %s>" % (self .__class__ .__name__ , self .name )
1827
28+ def mangle (self , name ):
29+ if self .klass is None :
30+ return name
31+ if not name .startswith ('__' ):
32+ return name
33+ if len (name ) + 2 >= MANGLE_LEN :
34+ return name
35+ if name .endswith ('__' ):
36+ return name
37+ return "_%s%s" % (self .klass , name )
38+
1939 def add_def (self , name ):
20- self .defs [name ] = 1
40+ self .defs [self . mangle ( name ) ] = 1
2141
2242 def add_use (self , name ):
23- self .uses [name ] = 1
43+ self .uses [self . mangle ( name ) ] = 1
2444
2545 def add_global (self , name ):
46+ name = self .mangle (name )
2647 if self .uses .has_key (name ) or self .defs .has_key (name ):
2748 pass # XXX warn about global following def/use
2849 if self .params .has_key (name ):
2950 raise SyntaxError , "%s in %s is global and parameter" % \
3051 (name , self .name )
3152 self .globals [name ] = 1
32- module_scope .add_def (name )
53+ self . module .add_def (name )
3354
3455 def add_param (self , name ):
56+ name = self .mangle (name )
3557 self .defs [name ] = 1
3658 self .params [name ] = 1
3759
@@ -41,46 +63,53 @@ def get_names(self):
4163 d .update (self .uses )
4264 return d .keys ()
4365
66+ def add_child (self , child ):
67+ self .children .append (child )
68+
69+ def get_children (self ):
70+ return self .children
71+
4472class ModuleScope (Scope ):
4573 __super_init = Scope .__init__
4674
4775 def __init__ (self ):
48- self .__super_init ("global" )
49- global module_scope
50- assert module_scope is None
51- module_scope = self
76+ self .__super_init ("global" , self )
5277
5378class LambdaScope (Scope ):
5479 __super_init = Scope .__init__
5580
5681 __counter = 1
5782
58- def __init__ (self ):
83+ def __init__ (self , module , klass = None ):
5984 i = self .__counter
6085 self .__counter += 1
61- self .__super_init ("lambda.%d" % i )
86+ self .__super_init ("lambda.%d" % i , module , klass )
6287
6388class FunctionScope (Scope ):
6489 pass
6590
6691class ClassScope (Scope ):
67- pass
92+ __super_init = Scope .__init__
93+
94+ def __init__ (self , name , module ):
95+ self .__super_init (name , module , name )
6896
6997class SymbolVisitor :
7098 def __init__ (self ):
7199 self .scopes = {}
72-
100+ self .klass = None
101+
73102 # node that define new scopes
74103
75104 def visitModule (self , node ):
76- scope = self .scopes [node ] = ModuleScope ()
105+ scope = self .module = self . scopes [node ] = ModuleScope ()
77106 self .visit (node .node , scope )
78107
79108 def visitFunction (self , node , parent ):
80109 parent .add_def (node .name )
81110 for n in node .defaults :
82111 self .visit (n , parent )
83- scope = FunctionScope (node .name )
112+ scope = FunctionScope (node .name , self . module , self . klass )
84113 self .scopes [node ] = scope
85114 for name in node .argnames :
86115 scope .add_param (name )
@@ -89,7 +118,7 @@ def visitFunction(self, node, parent):
89118 def visitLambda (self , node , parent ):
90119 for n in node .defaults :
91120 self .visit (n , parent )
92- scope = LambdaScope ()
121+ scope = LambdaScope (self . module , self . klass )
93122 self .scopes [node ] = scope
94123 for name in node .argnames :
95124 scope .add_param (name )
@@ -99,9 +128,12 @@ def visitClass(self, node, parent):
99128 parent .add_def (node .name )
100129 for n in node .bases :
101130 self .visit (n , parent )
102- scope = ClassScope (node .name )
131+ scope = ClassScope (node .name , self . module )
103132 self .scopes [node ] = scope
133+ prev = self .klass
134+ self .klass = node .name
104135 self .visit (node .code , scope )
136+ self .klass = prev
105137
106138 # name can be a def or a use
107139
@@ -155,6 +187,21 @@ def visitGlobal(self, node, scope):
155187 for name in node .names :
156188 scope .add_global (name )
157189
190+ # prune if statements if tests are false
191+
192+ _const_types = types .StringType , types .IntType , types .FloatType
193+
194+ def visitIf (self , node , scope ):
195+ for test , body in node .tests :
196+ if isinstance (test , ast .Const ):
197+ if type (test .value ) in self ._const_types :
198+ if not test .value :
199+ continue
200+ self .visit (test , scope )
201+ self .visit (body , scope )
202+ if node .else_ :
203+ self .visit (node .else_ , scope )
204+
158205def sort (l ):
159206 l = l [:]
160207 l .sort ()
@@ -168,26 +215,47 @@ def list_eq(l1, l2):
168215 from compiler import parseFile , walk
169216 import symtable
170217
218+ def get_names (syms ):
219+ return [s for s in [s .get_name () for s in syms .get_symbols ()]
220+ if not s .startswith ('_[' )]
221+
171222 for file in sys .argv [1 :]:
172223 print file
173224 f = open (file )
174225 buf = f .read ()
175226 f .close ()
176227 syms = symtable .symtable (buf , file , "exec" )
177- mod_names = [s for s in [s .get_name ()
178- for s in syms .get_symbols ()]
179- if not s .startswith ('_[' )]
228+ mod_names = get_names (syms )
180229 tree = parseFile (file )
181230 s = SymbolVisitor ()
182231 walk (tree , s )
183- for node , scope in s .scopes .items ():
184- print node .__class__ .__name__ , id (node )
185- print scope
186- print scope .get_names ()
187232
233+ # compare module-level symbols
188234 names2 = s .scopes [tree ].get_names ()
235+
189236 if not list_eq (mod_names , names2 ):
237+ print
190238 print "oops" , file
191239 print sort (mod_names )
192240 print sort (names2 )
193241 sys .exit (- 1 )
242+
243+ d = {}
244+ d .update (s .scopes )
245+ del d [tree ]
246+ scopes = d .values ()
247+ del d
248+
249+ for s in syms .get_symbols ():
250+ if s .is_namespace ():
251+ l = [sc for sc in scopes
252+ if sc .name == s .get_name ()]
253+ if len (l ) > 1 :
254+ print "skipping" , s .get_name ()
255+ else :
256+ if not list_eq (get_names (s .get_namespace ()),
257+ l [0 ].get_names ()):
258+ print s .get_name ()
259+ print get_names (s .get_namespace ())
260+ print l [0 ].get_names ()
261+ sys .exit (- 1 )
0 commit comments