@@ -33,16 +33,16 @@ def __init__(self, agents, queens=8, master_agents=[]):
33
33
self .queens = queens
34
34
self .tables = [Table (a , queens ) for a in agents ]
35
35
self .master_agents = master_agents
36
- self .stats = StatsModule (len ( agents ) )
36
+ self .stats = StatsModule (agents )
37
37
38
38
def step (self ):
39
39
for t in self .tables :
40
40
mes , new_state = t .agent (t .board )
41
41
if isinstance (mes , str ) and mes == "Success" :
42
- self .stats .addWin (t )
42
+ self .stats .add_win (t )
43
43
t .randomize_board (self .queens )
44
44
elif isinstance (mes , str ) and mes == "NoOp" :
45
- self .stats .addLoss (t )
45
+ self .stats .add_loss (t )
46
46
t .randomize_board (self .queens )
47
47
else :
48
48
t .perf += 1
@@ -56,30 +56,52 @@ def run(self, steps):
56
56
self .step ()
57
57
58
58
def find_sol (self , how_many ):
59
- while len ( self . stats . solutions ) < how_many :
59
+ while True :
60
60
self .step ()
61
+ for t in self .tables :
62
+ if len (self .stats .solutions [t .agent ]) >= how_many :
63
+ self .print_stats ()
64
+ return
65
+
66
+ def print_stats (self ):
67
+ self .stats .print_stats (self .tables , self .queens )
61
68
62
69
63
70
class StatsModule :
64
71
""" Class to measure agent's performance. Should be called on every win and loss which occurs
65
72
in the environment. """
66
- def __init__ (self , count ):
67
- self .count = count
68
- self .solutions = set ()
69
- self .win_times , self .loss_times = [], []
70
-
71
- def addWin (self , table ):
72
- self .win_times .append (table .perf )
73
+ def __init__ (self , agents ):
74
+ self .count = len (agents )
75
+ self .solutions = dict ([(k , set ()) for k in agents ])
76
+ self .win_times = dict ([(k ,[]) for k in agents ])
77
+ self .loss_times = dict ([(k , []) for k in agents ])
78
+
79
+ def add_win (self , table ):
80
+ self .win_times [table .agent ].append (table .perf )
73
81
if not tuple (table .board .tolist ()) in self .solutions :
74
- self .solutions .add (tuple (table .board .tolist ()))
75
- # print(len(self.solutions))
82
+ self .solutions [table .agent ].add (tuple (table .board .tolist ()))
76
83
77
- def addLoss (self , table ):
78
- self .loss_times .append (table .perf )
84
+ def add_loss (self , table ):
85
+ self .loss_times [ table . agent ] .append (table .perf )
79
86
80
- def printStats (self ):
81
- total = len (self .win_times ) + len (self .loss_times )
82
- print ("Found {} solutions" .format (len (self .solutions )))
83
- print ("Win ratio:" , len (self .win_times )/ total )
84
- print ("Avg. win time:" , np .average (self .win_times ))
85
- print ("Avg. loss time:" , np .average (self .loss_times ))
87
+ def print_stats (self , tables , queens ):
88
+ for t in tables :
89
+ print ("Agent" , t .agent )
90
+ self .print_table_stats (t )
91
+
92
+ def print_table_stats (self , table ):
93
+ win = len (self .win_times [table .agent ])
94
+ loss = len (self .loss_times [table .agent ])
95
+ print ("Found {} solutions" .format (len (self .solutions [table .agent ])))
96
+ if win + loss > 0 :
97
+ print ("Win ratio:" , win / (win + loss ))
98
+ print ("Avg. win time:" , np .average (self .win_times [table .agent ]))
99
+ print ("Avg. loss time:" , np .average (self .loss_times [table .agent ]))
100
+
101
+
102
+
103
+
104
+
105
+
106
+
107
+
0 commit comments