diff --git a/agents.py b/agents.py
index d29b0c382..6911a4a1a 100644
--- a/agents.py
+++ b/agents.py
@@ -50,7 +50,7 @@ class Thing:
.__name__ slot (used for output only)."""
def __repr__(self):
- return '<{}>'.format(getattr(self, '__name__', self.__class__.__name__))
+ return f"<{getattr(self, '__name__', self.__class__.__name__)}>"
def is_alive(self):
"""Things that are 'alive' should return true."""
@@ -338,7 +338,7 @@ def step(self):
def run(self, steps=1000):
"""Run the Environment for given number of time steps."""
- for step in range(steps):
+ for _ in range(steps):
if self.is_done():
return
self.step()
@@ -378,8 +378,8 @@ def delete_thing(self, thing):
except ValueError as e:
print(e)
print(" in Environment delete_thing")
- print(" Thing to be removed: {} at {}".format(thing, thing.location))
- print(" from list: {}".format([(thing, thing.location) for thing in self.things]))
+ print(f" Thing to be removed: {thing} at {thing.location}")
+ print(f" from list: {[(thing, thing.location) for thing in self.things]}")
if thing in self.agents:
self.agents.remove(thing)
@@ -506,8 +506,11 @@ def execute_action(self, agent, action):
elif action == 'Forward':
agent.bump = self.move_to(agent, agent.direction.move_forward(agent.location))
elif action == 'Grab':
- things = [thing for thing in self.list_things_at(agent.location) if agent.can_grab(thing)]
- if things:
+ if things := [
+ thing
+ for thing in self.list_things_at(agent.location)
+ if agent.can_grab(thing)
+ ]:
agent.holding.append(things[0])
print("Grabbing ", things[0].__class__.__name__)
self.delete_thing(things[0])
@@ -552,7 +555,12 @@ def add_thing(self, thing, location=None, exclude_duplicate_class_items=False):
def is_inbounds(self, location):
"""Checks to make sure that the location is inbounds (within walls if we have walls)"""
x, y = location
- return not (x < self.x_start or x > self.x_end or y < self.y_start or y > self.y_end)
+ return (
+ x >= self.x_start
+ and x <= self.x_end
+ and y >= self.y_start
+ and y <= self.y_end
+ )
def random_location_inbounds(self, exclude=None):
"""Returns a random location that is inbounds (within walls if we have walls)"""
@@ -634,9 +642,7 @@ def get_world(self):
x_start, y_start = (0, 0)
x_end, y_end = self.width, self.height
for x in range(x_start, x_end):
- row = []
- for y in range(y_start, y_end):
- row.append(self.list_things_at((x, y)))
+ row = [self.list_things_at((x, y)) for y in range(y_start, y_end)]
result.append(row)
return result
@@ -660,7 +666,7 @@ def run(self, steps=1000, delay=1):
def run(self, steps=1000, delay=1):
"""Run the Environment for given number of time steps,
but update the GUI too."""
- for step in range(steps):
+ for _ in range(steps):
self.update(delay)
if self.is_done():
break
@@ -908,9 +914,7 @@ def get_world(self, show_walls=True):
x_end, y_end = self.width - 1, self.height - 1
for x in range(x_start, x_end):
- row = []
- for y in range(y_start, y_end):
- row.append(self.list_things_at((x, y)))
+ row = [self.list_things_at((x, y)) for y in range(y_start, y_end)]
result.append(row)
return result
@@ -938,8 +942,7 @@ def percept(self, agent):
"""Return things in adjacent (not diagonal) cells of the agent.
Result format: [Left, Right, Up, Down, Center / Current location]"""
x, y = agent.location
- result = []
- result.append(self.percepts_from(agent, (x - 1, y)))
+ result = [self.percepts_from(agent, (x - 1, y))]
result.append(self.percepts_from(agent, (x + 1, y)))
result.append(self.percepts_from(agent, (x, y - 1)))
result.append(self.percepts_from(agent, (x, y + 1)))
@@ -999,10 +1002,11 @@ def is_done(self):
if explorer[0].alive:
return False
else:
- print("Death by {} [-1000].".format(explorer[0].killed_by))
+ print(f"Death by {explorer[0].killed_by} [-1000].")
else:
- print("Explorer climbed out {}."
- .format("with Gold [+1000]!" if Gold() not in self.things else "without Gold [+0]"))
+ print(
+ f'Explorer climbed out {"with Gold [+1000]!" if Gold() not in self.things else "without Gold [+0]"}.'
+ )
return True
# TODO: Arrow needs to be implemented
@@ -1024,7 +1028,7 @@ def compare_agents(EnvFactory, AgentFactories, n=10, steps=1000):
>>> performance_ReflexVacuumAgent <= performance_ModelBasedVacuumAgent
True
"""
- envs = [EnvFactory() for i in range(n)]
+ envs = [EnvFactory() for _ in range(n)]
return [(A, test_agent(A, steps, copy.deepcopy(envs)))
for A in AgentFactories]
diff --git a/agents4e.py b/agents4e.py
index 75369a69a..31e5b5b3b 100644
--- a/agents4e.py
+++ b/agents4e.py
@@ -55,7 +55,7 @@ class Thing:
.__name__ slot (used for output only)."""
def __repr__(self):
- return '<{}>'.format(getattr(self, '__name__', self.__class__.__name__))
+ return f"<{getattr(self, '__name__', self.__class__.__name__)}>"
def is_alive(self):
"""Things that are 'alive' should return true."""
@@ -343,7 +343,7 @@ def step(self):
def run(self, steps=1000):
"""Run the Environment for given number of time steps."""
- for step in range(steps):
+ for _ in range(steps):
if self.is_done():
return
self.step()
@@ -383,8 +383,8 @@ def delete_thing(self, thing):
except ValueError as e:
print(e)
print(" in Environment delete_thing")
- print(" Thing to be removed: {} at {}".format(thing, thing.location))
- print(" from list: {}".format([(thing, thing.location) for thing in self.things]))
+ print(f" Thing to be removed: {thing} at {thing.location}")
+ print(f" from list: {[(thing, thing.location) for thing in self.things]}")
if thing in self.agents:
self.agents.remove(thing)
@@ -554,7 +554,12 @@ def add_thing(self, thing, location=None, exclude_duplicate_class_items=False):
def is_inbounds(self, location):
"""Checks to make sure that the location is inbounds (within walls if we have walls)"""
x, y = location
- return not (x < self.x_start or x > self.x_end or y < self.y_start or y > self.y_end)
+ return (
+ x >= self.x_start
+ and x <= self.x_end
+ and y >= self.y_start
+ and y <= self.y_end
+ )
def random_location_inbounds(self, exclude=None):
"""Returns a random location that is inbounds (within walls if we have walls)"""
@@ -639,9 +644,7 @@ def get_world(self):
x_start, y_start = (0, 0)
x_end, y_end = self.width, self.height
for x in range(x_start, x_end):
- row = []
- for y in range(y_start, y_end):
- row.append(self.list_things_at((x, y)))
+ row = [self.list_things_at((x, y)) for y in range(y_start, y_end)]
result.append(row)
return result
@@ -665,7 +668,7 @@ def run(self, steps=1000, delay=1):
def run(self, steps=1000, delay=1):
"""Run the Environment for given number of time steps,
but update the GUI too."""
- for step in range(steps):
+ for _ in range(steps):
self.update(delay)
if self.is_done():
break
@@ -913,9 +916,7 @@ def get_world(self, show_walls=True):
x_end, y_end = self.width - 1, self.height - 1
for x in range(x_start, x_end):
- row = []
- for y in range(y_start, y_end):
- row.append(self.list_things_at((x, y)))
+ row = [self.list_things_at((x, y)) for y in range(y_start, y_end)]
result.append(row)
return result
@@ -943,8 +944,7 @@ def percept(self, agent):
"""Return things in adjacent (not diagonal) cells of the agent.
Result format: [Left, Right, Up, Down, Center / Current location]"""
x, y = agent.location
- result = []
- result.append(self.percepts_from(agent, (x - 1, y)))
+ result = [self.percepts_from(agent, (x - 1, y))]
result.append(self.percepts_from(agent, (x + 1, y)))
result.append(self.percepts_from(agent, (x, y - 1)))
result.append(self.percepts_from(agent, (x, y + 1)))
@@ -980,8 +980,8 @@ def execute_action(self, agent, action):
if agent.can_grab(thing)]
if len(things):
print("Grabbing", things[0].__class__.__name__)
- if len(things):
- agent.holding.append(things[0])
+ if len(things):
+ agent.holding.append(things[0])
agent.performance -= 1
elif action == 'Climb':
if agent.location == (1, 1): # Agent can only climb out of (1,1)
@@ -1018,10 +1018,11 @@ def is_done(self):
if explorer[0].alive:
return False
else:
- print("Death by {} [-1000].".format(explorer[0].killed_by))
+ print(f"Death by {explorer[0].killed_by} [-1000].")
else:
- print("Explorer climbed out {}."
- .format("with Gold [+1000]!" if Gold() not in self.things else "without Gold [+0]"))
+ print(
+ f'Explorer climbed out {"with Gold [+1000]!" if Gold() not in self.things else "without Gold [+0]"}.'
+ )
return True
# TODO: Arrow needs to be implemented
@@ -1043,7 +1044,7 @@ def compare_agents(EnvFactory, AgentFactories, n=10, steps=1000):
>>> performance_ReflexVacuumAgent <= performance_ModelBasedVacuumAgent
True
"""
- envs = [EnvFactory() for i in range(n)]
+ envs = [EnvFactory() for _ in range(n)]
return [(A, test_agent(A, steps, copy.deepcopy(envs)))
for A in AgentFactories]
diff --git a/csp.py b/csp.py
index 46ae07dd5..cb1de024a 100644
--- a/csp.py
+++ b/csp.py
@@ -95,11 +95,10 @@ def actions(self, state):
assignments to an unassigned variable."""
if len(state) == len(self.variables):
return []
- else:
- assignment = dict(state)
- var = first([v for v in self.variables if v not in assignment])
- return [(var, val) for val in self.domains[var]
- if self.nconflicts(var, val, assignment) == 0]
+ assignment = dict(state)
+ var = first([v for v in self.variables if v not in assignment])
+ return [(var, val) for val in self.domains[var]
+ if self.nconflicts(var, val, assignment) == 0]
def result(self, state, action):
"""Perform an action and return the new state."""
@@ -141,8 +140,11 @@ def choices(self, var):
def infer_assignment(self):
"""Return the partial assignment implied by the current inferences."""
self.support_pruning()
- return {v: self.curr_domains[v][0]
- for v in self.variables if 1 == len(self.curr_domains[v])}
+ return {
+ v: self.curr_domains[v][0]
+ for v in self.variables
+ if len(self.curr_domains[v]) == 1
+ }
def restore(self, removals):
"""Undo a supposition and all inferences from it."""
@@ -317,9 +319,8 @@ def AC4(csp, queue=None, removals=None, arc_heuristic=dom_j_up):
csp.prune(Xi, x, removals)
revised = True
unsupported_variable_value_pairs.append((Xi, x))
- if revised:
- if not csp.curr_domains[Xi]:
- return False, checks # CSP is inconsistent
+ if revised and not csp.curr_domains[Xi]:
+ return False, checks # CSP is inconsistent
# propagation of removed values
while unsupported_variable_value_pairs:
Xj, y = unsupported_variable_value_pairs.pop()
@@ -331,9 +332,8 @@ def AC4(csp, queue=None, removals=None, arc_heuristic=dom_j_up):
csp.prune(Xi, x, removals)
revised = True
unsupported_variable_value_pairs.append((Xi, x))
- if revised:
- if not csp.curr_domains[Xi]:
- return False, checks # CSP is inconsistent
+ if revised and not csp.curr_domains[Xi]:
+ return False, checks # CSP is inconsistent
return True, checks # CSP is satisfiable
@@ -439,7 +439,7 @@ def min_conflicts(csp, max_steps=100000):
val = min_conflicts_value(csp, var, current)
csp.assign(var, val, current)
# Now repeatedly choose a random conflicted variable and change it
- for i in range(max_steps):
+ for _ in range(max_steps):
conflicted = csp.conflicted_vars(current)
if not conflicted:
return current
@@ -460,7 +460,6 @@ def min_conflicts_value(csp, var, current):
def tree_csp_solver(csp):
"""[Figure 6.11]"""
- assignment = {}
root = csp.variables[0]
X, parent = topological_sort(csp, root)
@@ -469,7 +468,7 @@ def tree_csp_solver(csp):
if not make_arc_consistent(parent[Xj], Xj, csp):
return None
- assignment[root] = csp.curr_domains[root][0]
+ assignment = {root: csp.curr_domains[root][0]}
for Xi in X[1:]:
assignment[Xi] = assign_value(parent[Xi], Xi, csp, assignment)
if not assignment[Xi]:
@@ -521,13 +520,7 @@ def make_arc_consistent(Xj, Xk, csp):
by removing the possible values of Xj that cause inconsistencies."""
# csp.curr_domains[Xj] = []
for val1 in csp.domains[Xj]:
- keep = False # Keep or remove val1
- for val2 in csp.domains[Xk]:
- if csp.constraints(Xj, val1, Xk, val2):
- # Found a consistent assignment for val1, keep it
- keep = True
- break
-
+ keep = any(csp.constraints(Xj, val1, Xk, val2) for val2 in csp.domains[Xk])
if not keep:
# Remove val1
csp.prune(Xj, val1, None)
@@ -539,12 +532,14 @@ def assign_value(Xj, Xk, csp, assignment):
"""Assign a value to Xk given Xj's (Xk's parent) assignment.
Return the first value that satisfies the constraints."""
parent_assignment = assignment[Xj]
- for val in csp.curr_domains[Xk]:
- if csp.constraints(Xj, parent_assignment, Xk, val):
- return val
-
- # No consistent assignment available
- return None
+ return next(
+ (
+ val
+ for val in csp.curr_domains[Xk]
+ if csp.constraints(Xj, parent_assignment, Xk, val)
+ ),
+ None,
+ )
# ______________________________________________________________________________
@@ -707,10 +702,7 @@ def display(self, assignment):
print(ch, end=' ')
print(' ', end=' ')
for var in range(n):
- if assignment.get(var, '') == val:
- ch = '*'
- else:
- ch = ' '
+ ch = '*' if assignment.get(var, '') == val else ' '
print(str(self.nconflicts(var, val, assignment)) + ch, end=' ')
print()
@@ -728,7 +720,7 @@ def flatten(seqs):
_R3 = list(range(3))
_CELL = itertools.count().__next__
-_BGRID = [[[[_CELL() for x in _R3] for y in _R3] for bx in _R3] for by in _R3]
+_BGRID = [[[[_CELL() for _ in _R3] for _ in _R3] for _ in _R3] for _ in _R3]
_BOXES = flatten([list(map(flatten, brow)) for brow in _BGRID])
_ROWS = flatten([list(map(flatten, zip(*brow))) for brow in _BGRID])
_COLS = list(zip(*_ROWS))
@@ -979,7 +971,7 @@ def meet_at_constraint(p1, p2):
def meets(w1, w2):
return w1[p1] == w2[p2]
- meets.__name__ = "meet_at(" + str(p1) + ',' + str(p2) + ')'
+ meets.__name__ = f"meet_at({str(p1)},{str(p2)})"
return meets
@@ -994,7 +986,7 @@ def sum_constraint(n):
def sumv(*values):
return sum(values) is n
- sumv.__name__ = str(n) + "==sum"
+ sumv.__name__ = f"{str(n)}==sum"
return sumv
@@ -1004,7 +996,7 @@ def is_constraint(val):
def isv(x):
return val == x
- isv.__name__ = str(val) + "=="
+ isv.__name__ = f"{str(val)}=="
return isv
@@ -1014,7 +1006,7 @@ def ne_constraint(val):
def nev(x):
return val != x
- nev.__name__ = str(val) + "!="
+ nev.__name__ = f"{str(val)}!="
return nev
@@ -1023,7 +1015,7 @@ def no_heuristic(to_do):
def sat_up(to_do):
- return SortedSet(to_do, key=lambda t: 1 / len([var for var in t[1].scope]))
+ return SortedSet(to_do, key=lambda t: 1 / len(list(t[1].scope)))
class ACSolver:
@@ -1055,7 +1047,7 @@ def GAC(self, orig_domains=None, to_do=None, arc_heuristic=sat_up):
var, const = to_do.pop()
other_vars = [ov for ov in const.scope if ov != var]
new_domain = set()
- if len(other_vars) == 0:
+ if not other_vars:
for val in domains[var]:
if const.holds({var: val}):
new_domain.add(val)
@@ -1107,15 +1099,14 @@ def any_holds(self, domains, const, env, other_vars, ind=0, checks=0):
"""
if ind == len(other_vars):
return const.holds(env), checks + 1
- else:
- var = other_vars[ind]
- for val in domains[var]:
- # env = dict_union(env, {var:val}) # no side effects
- env[var] = val
- holds, checks = self.any_holds(domains, const, env, other_vars, ind + 1, checks)
- if holds:
- return True, checks
- return False, checks
+ var = other_vars[ind]
+ for val in domains[var]:
+ # env = dict_union(env, {var:val}) # no side effects
+ env[var] = val
+ holds, checks = self.any_holds(domains, const, env, other_vars, ind + 1, checks)
+ if holds:
+ return True, checks
+ return False, checks
def domain_splitting(self, domains=None, to_do=None, arc_heuristic=sat_up):
"""
@@ -1130,14 +1121,15 @@ def domain_splitting(self, domains=None, to_do=None, arc_heuristic=sat_up):
elif all(len(new_domains[var]) == 1 for var in domains):
return {var: first(new_domains[var]) for var in domains}
else:
- var = first(x for x in self.csp.variables if len(new_domains[x]) > 1)
- if var:
+ if var := first(
+ x for x in self.csp.variables if len(new_domains[x]) > 1
+ ):
dom1, dom2 = partition_domain(new_domains[var])
new_doms1 = extend(new_domains, var, dom1)
new_doms2 = extend(new_domains, var, dom2)
to_do = self.new_to_do(var, None)
return self.domain_splitting(new_doms1, to_do, arc_heuristic) or \
- self.domain_splitting(new_doms2, to_do, arc_heuristic)
+ self.domain_splitting(new_doms2, to_do, arc_heuristic)
def partition_domain(dom):
@@ -1233,7 +1225,7 @@ def __init__(self, puzzle, words):
scope = []
for j, element in enumerate(line):
if element == '_':
- var = "p" + str(j) + str(i)
+ var = f"p{str(j)}{str(i)}"
domains[var] = list(string.ascii_lowercase)
scope.append(var)
else:
@@ -1247,7 +1239,7 @@ def __init__(self, puzzle, words):
scope = []
for j, element in enumerate(line):
if element == '_':
- scope.append("p" + str(i) + str(j))
+ scope.append(f"p{str(i)}{str(j)}")
else:
if len(scope) > 1:
constraints.append(Constraint(tuple(scope), is_word_constraint(words)))
@@ -1263,15 +1255,14 @@ def display(self, assignment=None):
for j, element in enumerate(line):
if element == '*':
puzzle += "[*] "
+ elif assignment is None:
+ puzzle += "[_] "
else:
- var = "p" + str(j) + str(i)
- if assignment is not None:
- if isinstance(assignment[var], set) and len(assignment[var]) == 1:
- puzzle += "[" + str(first(assignment[var])).upper() + "] "
- elif isinstance(assignment[var], str):
- puzzle += "[" + str(assignment[var]).upper() + "] "
- else:
- puzzle += "[_] "
+ var = f"p{str(j)}{str(i)}"
+ if isinstance(assignment[var], set) and len(assignment[var]) == 1:
+ puzzle += f"[{str(first(assignment[var])).upper()}] "
+ elif isinstance(assignment[var], str):
+ puzzle += f"[{str(assignment[var]).upper()}] "
else:
puzzle += "[_] "
print(puzzle)
@@ -1334,18 +1325,16 @@ def __init__(self, puzzle):
if element == '_':
var1 = str(i)
if len(var1) == 1:
- var1 = "0" + var1
+ var1 = f"0{var1}"
var2 = str(j)
if len(var2) == 1:
- var2 = "0" + var2
- variables.append("X" + var1 + var2)
- domains = {}
- for var in variables:
- domains[var] = set(range(1, 10))
+ var2 = f"0{var2}"
+ variables.append(f"X{var1}{var2}")
+ domains = {var: set(range(1, 10)) for var in variables}
constraints = []
for i, line in enumerate(puzzle):
for j, element in enumerate(line):
- if element != '_' and element != '*':
+ if element not in ['_', '*']:
# down - column
if element[0] != '':
x = []
@@ -1354,13 +1343,17 @@ def __init__(self, puzzle):
break
var1 = str(k)
if len(var1) == 1:
- var1 = "0" + var1
+ var1 = f"0{var1}"
var2 = str(j)
if len(var2) == 1:
- var2 = "0" + var2
- x.append("X" + var1 + var2)
- constraints.append(Constraint(x, sum_constraint(element[0])))
- constraints.append(Constraint(x, all_diff_constraint))
+ var2 = f"0{var2}"
+ x.append(f"X{var1}{var2}")
+ constraints.extend(
+ (
+ Constraint(x, sum_constraint(element[0])),
+ Constraint(x, all_diff_constraint),
+ )
+ )
# right - line
if element[1] != '':
x = []
@@ -1369,13 +1362,17 @@ def __init__(self, puzzle):
break
var1 = str(i)
if len(var1) == 1:
- var1 = "0" + var1
+ var1 = f"0{var1}"
var2 = str(k)
if len(var2) == 1:
- var2 = "0" + var2
- x.append("X" + var1 + var2)
- constraints.append(Constraint(x, sum_constraint(element[1])))
- constraints.append(Constraint(x, all_diff_constraint))
+ var2 = f"0{var2}"
+ x.append(f"X{var1}{var2}")
+ constraints.extend(
+ (
+ Constraint(x, sum_constraint(element[1])),
+ Constraint(x, all_diff_constraint),
+ )
+ )
super().__init__(domains, constraints)
self.puzzle = puzzle
@@ -1388,16 +1385,16 @@ def display(self, assignment=None):
elif element == '_':
var1 = str(i)
if len(var1) == 1:
- var1 = "0" + var1
+ var1 = f"0{var1}"
var2 = str(j)
if len(var2) == 1:
- var2 = "0" + var2
- var = "X" + var1 + var2
+ var2 = f"0{var2}"
+ var = f"X{var1}{var2}"
if assignment is not None:
if isinstance(assignment[var], set) and len(assignment[var]) == 1:
- puzzle += "[" + str(first(assignment[var])) + "]\t"
+ puzzle += f"[{str(first(assignment[var]))}" + "]\t"
elif isinstance(assignment[var], int):
- puzzle += "[" + str(assignment[var]) + "]\t"
+ puzzle += f"[{str(assignment[var])}" + "]\t"
else:
puzzle += "[_]\t"
else:
diff --git a/deep_learning4e.py b/deep_learning4e.py
index 9f5b0a8f7..76c898e4a 100644
--- a/deep_learning4e.py
+++ b/deep_learning4e.py
@@ -306,7 +306,7 @@ def stochastic_gradient_descent(dataset, net, loss, epochs=1000, l_rate=0.01, ba
net[i].nodes[j].weights = weights[i][j]
if verbose:
- print("epoch:{}, total_loss:{}".format(e + 1, total_loss))
+ print(f"epoch:{e + 1}, total_loss:{total_loss}")
return net
@@ -363,7 +363,7 @@ def adam(dataset, net, loss, epochs=1000, rho=(0.9, 0.999), delta=1 / 10 ** 8,
net[i].nodes[j].weights = weights[i][j]
if verbose:
- print("epoch:{}, total_loss:{}".format(e + 1, total_loss))
+ print(f"epoch:{e + 1}, total_loss:{total_loss}")
return net
diff --git a/games.py b/games.py
index d22b2e640..40fb0215d 100644
--- a/games.py
+++ b/games.py
@@ -179,7 +179,7 @@ def query_player(game, state):
"""Make a move by querying standard input."""
print("current state:")
game.display(state)
- print("available moves: {}".format(game.actions(state)))
+ print(f"available moves: {game.actions(state)}")
print("")
move = None
if game.actions(state):
@@ -248,7 +248,7 @@ def display(self, state):
print(state)
def __repr__(self):
- return '<{}>'.format(self.__class__.__name__)
+ return f'<{self.__class__.__name__}>'
def play_game(self, *players):
"""Play an n-person, move-alternating game."""
@@ -311,10 +311,7 @@ def result(self, state, move):
return self.succs[state][move]
def utility(self, state, player):
- if player == 'MAX':
- return self.utils[state]
- else:
- return -self.utils[state]
+ return self.utils[state] if player == 'MAX' else -self.utils[state]
def terminal_test(self, state):
return state not in ('A', 'B', 'C', 'D')
@@ -336,10 +333,7 @@ def result(self, state, move):
return self.succs[state][move]
def utility(self, state, player):
- if player == 'MAX':
- return self.utils[state]
- else:
- return -self.utils[state]
+ return self.utils[state] if player == 'MAX' else -self.utils[state]
def terminal_test(self, state):
return state not in range(13)
@@ -445,7 +439,7 @@ class Backgammon(StochasticGame):
def __init__(self):
"""Initial state of the game"""
point = {'W': 0, 'B': 0}
- board = [point.copy() for index in range(24)]
+ board = [point.copy() for _ in range(24)]
board[0]['B'] = board[23]['W'] = 2
board[5]['W'] = board[18]['B'] = 5
board[7]['W'] = board[16]['B'] = 3
@@ -497,11 +491,14 @@ def get_all_moves(self, board, player):
all_points = board
taken_points = [index for index, point in enumerate(all_points)
if point[player] > 0]
- if self.checkers_at_home(board, player) == 1:
+ if self.checkers_at_home(all_points, player) == 1:
return [(taken_points[0],)]
moves = list(itertools.permutations(taken_points, 2))
- moves = moves + [(index, index) for index, point in enumerate(all_points)
- if point[player] >= 2]
+ moves += [
+ (index, index)
+ for index, point in enumerate(all_points)
+ if point[player] >= 2
+ ]
return moves
def display(self, state):
@@ -516,10 +513,9 @@ def display(self, state):
def compute_utility(self, board, move, player):
"""If 'W' wins with this move, return 1; if 'B' wins return -1; else return 0."""
util = {'W': 1, 'B': -1}
- for idx in range(0, 24):
- if board[idx][player] > 0:
- return 0
- return util[player]
+ return next(
+ (0 for idx in range(0, 24) if board[idx][player] > 0), util[player]
+ )
def checkers_at_home(self, board, player):
"""Return the no. of checkers at home for a player."""
@@ -541,18 +537,16 @@ def is_legal_move(self, board, start, steps, player):
if self.is_point_open(player, board[dest1]):
self.move_checker(board, start[0], steps[0], player)
move1_legal = True
- else:
- if self.allow_bear_off[player]:
- self.move_checker(board, start[0], steps[0], player)
- move1_legal = True
+ elif self.allow_bear_off[player]:
+ self.move_checker(board, start[0], steps[0], player)
+ move1_legal = True
if not move1_legal:
return False
if dest2 in dest_range:
if self.is_point_open(player, board[dest2]):
move2_legal = True
- else:
- if self.allow_bear_off[player]:
- move2_legal = True
+ elif self.allow_bear_off[player]:
+ move2_legal = True
return move1_legal and move2_legal
def move_checker(self, board, start, steps, player):
@@ -574,8 +568,7 @@ def is_point_open(self, player, point):
def chances(self, state):
"""Return a list of all possible dice rolls at a state."""
- dice_rolls = list(itertools.combinations_with_replacement([1, 2, 3, 4, 5, 6], 2))
- return dice_rolls
+ return list(itertools.combinations_with_replacement([1, 2, 3, 4, 5, 6], 2))
def outcome(self, state, chance):
"""Return the state which is the outcome of a dice roll."""
diff --git a/games4e.py b/games4e.py
index aba5b0eb3..657c3def2 100644
--- a/games4e.py
+++ b/games4e.py
@@ -178,10 +178,7 @@ def min_value(state, alpha, beta, depth):
def monte_carlo_tree_search(state, game, N=1000):
def select(n):
"""select a leaf node in the tree"""
- if n.children:
- return select(max(n.children.keys(), key=ucb))
- else:
- return n
+ return select(max(n.children.keys(), key=ucb)) if n.children else n
def expand(n):
"""expand the leaf node by adding all its children states"""
@@ -230,7 +227,7 @@ def query_player(game, state):
"""Make a move by querying standard input."""
print("current state:")
game.display(state)
- print("available moves: {}".format(game.actions(state)))
+ print(f"available moves: {game.actions(state)}")
print("")
move = None
if game.actions(state):
@@ -299,7 +296,7 @@ def display(self, state):
print(state)
def __repr__(self):
- return '<{}>'.format(self.__class__.__name__)
+ return f'<{self.__class__.__name__}>'
def play_game(self, *players):
"""Play an n-person, move-alternating game."""
@@ -362,10 +359,7 @@ def result(self, state, move):
return self.succs[state][move]
def utility(self, state, player):
- if player == 'MAX':
- return self.utils[state]
- else:
- return -self.utils[state]
+ return self.utils[state] if player == 'MAX' else -self.utils[state]
def terminal_test(self, state):
return state not in ('A', 'B', 'C', 'D')
@@ -387,10 +381,7 @@ def result(self, state, move):
return self.succs[state][move]
def utility(self, state, player):
- if player == 'MAX':
- return self.utils[state]
- else:
- return -self.utils[state]
+ return self.utils[state] if player == 'MAX' else -self.utils[state]
def terminal_test(self, state):
return state not in range(13)
@@ -490,7 +481,7 @@ class Backgammon(StochasticGame):
def __init__(self):
"""Initial state of the game"""
point = {'W': 0, 'B': 0}
- board = [point.copy() for index in range(24)]
+ board = [point.copy() for _ in range(24)]
board[0]['B'] = board[23]['W'] = 2
board[5]['W'] = board[18]['B'] = 5
board[7]['W'] = board[16]['B'] = 3
@@ -542,11 +533,14 @@ def get_all_moves(self, board, player):
all_points = board
taken_points = [index for index, point in enumerate(all_points)
if point[player] > 0]
- if self.checkers_at_home(board, player) == 1:
+ if self.checkers_at_home(all_points, player) == 1:
return [(taken_points[0],)]
moves = list(itertools.permutations(taken_points, 2))
- moves = moves + [(index, index) for index, point in enumerate(all_points)
- if point[player] >= 2]
+ moves += [
+ (index, index)
+ for index, point in enumerate(all_points)
+ if point[player] >= 2
+ ]
return moves
def display(self, state):
@@ -561,10 +555,9 @@ def display(self, state):
def compute_utility(self, board, move, player):
"""If 'W' wins with this move, return 1; if 'B' wins return -1; else return 0."""
util = {'W': 1, 'B': -1}
- for idx in range(0, 24):
- if board[idx][player] > 0:
- return 0
- return util[player]
+ return next(
+ (0 for idx in range(0, 24) if board[idx][player] > 0), util[player]
+ )
def checkers_at_home(self, board, player):
"""Return the no. of checkers at home for a player."""
@@ -586,18 +579,16 @@ def is_legal_move(self, board, start, steps, player):
if self.is_point_open(player, board[dest1]):
self.move_checker(board, start[0], steps[0], player)
move1_legal = True
- else:
- if self.allow_bear_off[player]:
- self.move_checker(board, start[0], steps[0], player)
- move1_legal = True
+ elif self.allow_bear_off[player]:
+ self.move_checker(board, start[0], steps[0], player)
+ move1_legal = True
if not move1_legal:
return False
if dest2 in dest_range:
if self.is_point_open(player, board[dest2]):
move2_legal = True
- else:
- if self.allow_bear_off[player]:
- move2_legal = True
+ elif self.allow_bear_off[player]:
+ move2_legal = True
return move1_legal and move2_legal
def move_checker(self, board, start, steps, player):
@@ -619,8 +610,7 @@ def is_point_open(self, player, point):
def chances(self, state):
"""Return a list of all possible dice rolls at a state."""
- dice_rolls = list(itertools.combinations_with_replacement([1, 2, 3, 4, 5, 6], 2))
- return dice_rolls
+ return list(itertools.combinations_with_replacement([1, 2, 3, 4, 5, 6], 2))
def outcome(self, state, chance):
"""Return the state which is the outcome of a dice roll."""
diff --git a/gui/eight_puzzle.py b/gui/eight_puzzle.py
index 5733228d7..32d61d5a2 100644
--- a/gui/eight_puzzle.py
+++ b/gui/eight_puzzle.py
@@ -25,10 +25,7 @@ def scramble():
global state
global puzzle
possible_actions = ['UP', 'DOWN', 'LEFT', 'RIGHT']
- scramble = []
- for _ in range(60):
- scramble.append(random.choice(possible_actions))
-
+ scramble = [random.choice(possible_actions) for _ in range(60)]
for move in scramble:
if move in puzzle.actions(state):
state = list(puzzle.result(state, move))
diff --git a/gui/genetic_algorithm_example.py b/gui/genetic_algorithm_example.py
index c987151c8..e28f5f69b 100644
--- a/gui/genetic_algorithm_example.py
+++ b/gui/genetic_algorithm_example.py
@@ -45,8 +45,7 @@
numerals = [chr(x) for x in range(48, 58)] # list containing numbers
# extend the gene pool with the required lists and append the space character
-gene_pool = []
-gene_pool.extend(u_case)
+gene_pool = list(u_case)
gene_pool.extend(l_case)
gene_pool.append(' ')
@@ -74,14 +73,9 @@ def update_ngen(slider_value):
# fitness function
def fitness_fn(_list):
- fitness = 0
# create string from list of characters
phrase = ''.join(_list)
- # add 1 to fitness value for every matching character
- for i in range(len(phrase)):
- if target[i] == phrase[i]:
- fitness += 1
- return fitness
+ return sum(1 for i in range(len(phrase)) if target[i] == phrase[i])
# function to bring a new frame on top
@@ -147,8 +141,13 @@ def genetic_algorithm_stepwise(population):
for generation in range(ngen):
# generating new population after selecting, recombining and mutating the existing population
population = [
- search.mutate(search.recombine(*search.select(2, population, fitness_fn)), gene_pool, mutation_rate) for i
- in range(len(population))]
+ search.mutate(
+ search.recombine(*search.select(2, population, fitness_fn)),
+ gene_pool,
+ mutation_rate,
+ )
+ for _ in range(len(population))
+ ]
# genome with the highest fitness in the current generation
current_best = ''.join(max(population, key=fitness_fn))
# collecting first few examples from the current population
diff --git a/gui/grid_mdp.py b/gui/grid_mdp.py
index e60b49247..e3e432c2d 100644
--- a/gui/grid_mdp.py
+++ b/gui/grid_mdp.py
@@ -573,11 +573,9 @@ def __init__(self, parent, controller):
def process_data(self, terminals, _height, _width, gridmdp):
"""preprocess variables"""
- flipped_terminals = []
-
- for terminal in terminals:
- flipped_terminals.append((terminal[1], _height - terminal[0] - 1))
-
+ flipped_terminals = [
+ (terminal[1], _height - terminal[0] - 1) for terminal in terminals
+ ]
grid_to_solve = [[0.0] * max(1, _width) for _ in range(max(1, _height))]
grid_to_show = [[0.0] * max(1, _width) for _ in range(max(1, _height))]
@@ -627,8 +625,13 @@ def animate_graph(self, i):
U = self.U1.copy()
for s in self.sequential_decision_environment.states:
- self.U1[s] = self.R(s) + self.gamma * max(
- [sum([p * U[s1] for (p, s1) in self.T(s, a)]) for a in self.sequential_decision_environment.actions(s)])
+ self.U1[s] = self.R(s) + (
+ self.gamma
+ * max(
+ sum(p * U[s1] for (p, s1) in self.T(s, a))
+ for a in self.sequential_decision_environment.actions(s)
+ )
+ )
self.delta = max(self.delta, abs(self.U1[s] - U[s]))
self.grid_to_show = grid_to_show = [[0.0] * max(1, self._width) for _ in range(max(1, self._height))]
@@ -664,7 +667,9 @@ def value_iteration_metastep(self, mdp, iterations=20):
U = U1.copy()
for s in mdp.states:
- U1[s] = R(s) + gamma * max([sum([p * U[s1] for (p, s1) in T(s, a)]) for a in mdp.actions(s)])
+ U1[s] = R(s) + gamma * max(
+ sum(p * U[s1] for (p, s1) in T(s, a)) for a in mdp.actions(s)
+ )
U_over_time.append(U)
return U_over_time
diff --git a/gui/romania_problem.py b/gui/romania_problem.py
index 9ec94099d..02b9489c6 100644
--- a/gui/romania_problem.py
+++ b/gui/romania_problem.py
@@ -562,7 +562,7 @@ def on_click():
"""
global algo, counter, next_button, romania_problem, start, goal
romania_problem = GraphProblem(start.get(), goal.get(), romania_map)
- if "Breadth-First Tree Search" == algo.get():
+ if algo.get() == "Breadth-First Tree Search":
node = breadth_first_tree_search(romania_problem)
if node is not None:
final_path = breadth_first_tree_search(romania_problem).solution()
@@ -570,7 +570,7 @@ def on_click():
display_final(final_path)
next_button.config(state="disabled")
counter += 1
- elif "Depth-First Tree Search" == algo.get():
+ elif algo.get() == "Depth-First Tree Search":
node = depth_first_tree_search(romania_problem)
if node is not None:
final_path = depth_first_tree_search(romania_problem).solution()
@@ -578,7 +578,7 @@ def on_click():
display_final(final_path)
next_button.config(state="disabled")
counter += 1
- elif "Breadth-First Graph Search" == algo.get():
+ elif algo.get() == "Breadth-First Graph Search":
node = breadth_first_graph_search(romania_problem)
if node is not None:
final_path = breadth_first_graph_search(romania_problem).solution()
@@ -586,7 +586,7 @@ def on_click():
display_final(final_path)
next_button.config(state="disabled")
counter += 1
- elif "Depth-First Graph Search" == algo.get():
+ elif algo.get() == "Depth-First Graph Search":
node = depth_first_graph_search(romania_problem)
if node is not None:
final_path = depth_first_graph_search(romania_problem).solution()
@@ -594,7 +594,7 @@ def on_click():
display_final(final_path)
next_button.config(state="disabled")
counter += 1
- elif "Uniform Cost Search" == algo.get():
+ elif algo.get() == "Uniform Cost Search":
node = uniform_cost_search(romania_problem)
if node is not None:
final_path = uniform_cost_search(romania_problem).solution()
@@ -602,7 +602,7 @@ def on_click():
display_final(final_path)
next_button.config(state="disabled")
counter += 1
- elif "A* - Search" == algo.get():
+ elif algo.get() == "A* - Search":
node = astar_search(romania_problem)
if node is not None:
final_path = astar_search(romania_problem).solution()
diff --git a/gui/tic-tac-toe.py b/gui/tic-tac-toe.py
index 66d9d6e75..575475ecf 100644
--- a/gui/tic-tac-toe.py
+++ b/gui/tic-tac-toe.py
@@ -43,9 +43,7 @@ def create_frames(root):
frames.append(frame2)
frames.append(frame3)
for x in frames:
- buttons_in_frame = []
- for y in x.winfo_children():
- buttons_in_frame.append(y)
+ buttons_in_frame = list(x.winfo_children())
buttons.append(buttons_in_frame)
buttonReset = Button(frame4, height=1, width=2,
text="Reset", command=lambda: reset_game())
@@ -74,10 +72,7 @@ def on_click(button):
"""
global ttt, choices, count, sym, result, x_pos, o_pos
- if count % 2 == 0:
- sym = "X"
- else:
- sym = "O"
+ sym = "X" if count % 2 == 0 else "O"
count += 1
button.config(
@@ -106,10 +101,7 @@ def on_click(button):
if 1 <= a <= 3 and 1 <= b <= 3:
o_pos.append((a, b))
button_to_change = get_button(a - 1, b - 1)
- if count % 2 == 0: # Used again, will become handy when user is given the choice of turn.
- sym = "X"
- else:
- sym = "O"
+ sym = "X" if count % 2 == 0 else "O"
count += 1
if check_victory(button):
@@ -133,16 +125,16 @@ def check_victory(button):
x, y = get_coordinates(button)
tt = button['text']
if buttons[0][y]['text'] == buttons[1][y]['text'] == buttons[2][y]['text'] != " ":
- buttons[0][y].config(text="|" + tt + "|")
- buttons[1][y].config(text="|" + tt + "|")
- buttons[2][y].config(text="|" + tt + "|")
+ buttons[0][y].config(text=f"|{tt}|")
+ buttons[1][y].config(text=f"|{tt}|")
+ buttons[2][y].config(text=f"|{tt}|")
return True
# check if previous move caused a win on horizontal line
if buttons[x][0]['text'] == buttons[x][1]['text'] == buttons[x][2]['text'] != " ":
- buttons[x][0].config(text="--" + tt + "--")
- buttons[x][1].config(text="--" + tt + "--")
- buttons[x][2].config(text="--" + tt + "--")
+ buttons[x][0].config(text=f"--{tt}--")
+ buttons[x][1].config(text=f"--{tt}--")
+ buttons[x][2].config(text=f"--{tt}--")
return True
# check if previous move was on the main diagonal and caused a win
@@ -154,9 +146,9 @@ def check_victory(button):
# check if previous move was on the secondary diagonal and caused a win
if x + y == 2 and buttons[0][2]['text'] == buttons[1][1]['text'] == buttons[2][0]['text'] != " ":
- buttons[0][2].config(text="/" + tt + "/")
- buttons[1][1].config(text="/" + tt + "/")
- buttons[2][0].config(text="/" + tt + "/")
+ buttons[0][2].config(text=f"/{tt}/")
+ buttons[1][1].config(text=f"/{tt}/")
+ buttons[2][0].config(text=f"/{tt}/")
return True
return False
diff --git a/gui/tsp.py b/gui/tsp.py
index 590fff354..266c2ebda 100644
--- a/gui/tsp.py
+++ b/gui/tsp.py
@@ -32,9 +32,7 @@ def result(self, state, action):
def path_cost(self, c, state1, action, state2):
"""total distance for the Traveling Salesman to be covered if in state2"""
- cost = 0
- for i in range(len(state2) - 1):
- cost += distances[state2[i]][state2[i + 1]]
+ cost = sum(distances[state2[i]][state2[i + 1]] for i in range(len(state2) - 1))
cost += distances[state2[0]][state2[-1]]
return cost
@@ -102,11 +100,11 @@ def create_dropdown_menu(self):
def run_traveling_salesman(self):
"""Choose selected cities"""
- cities = []
- for i in range(len(self.vars)):
- if self.vars[i].get() == 1:
- cities.append(self.all_cities[i])
-
+ cities = [
+ self.all_cities[i]
+ for i in range(len(self.vars))
+ if self.vars[i].get() == 1
+ ]
tsp_problem = TSProblem(cities)
self.button_text.set("Reset")
self.create_canvas(tsp_problem)
@@ -186,10 +184,10 @@ def create_canvas(self, problem):
no_of_neighbors_scale.grid(row=1, column=5, columnspan=5, sticky='nsew')
self.hill_climbing(problem, map_canvas)
- def exp_schedule(k=100, lam=0.03, limit=1000):
+ def exp_schedule(self, lam=0.03, limit=1000):
"""One possible schedule function for simulated annealing"""
- return lambda t: (k * np.exp(-lam * t) if t < limit else 0)
+ return lambda t: self * np.exp(-lam * t) if t < limit else 0
def simulated_annealing_with_tunable_T(self, problem, map_canvas, schedule=exp_schedule()):
"""Simulated annealing where temperature is taken as user input"""
@@ -212,8 +210,7 @@ def simulated_annealing_with_tunable_T(self, problem, map_canvas, schedule=exp_s
self.cost.set("Cost = " + str('%0.3f' % (-1 * problem.value(current.state))))
points = []
for city in current.state:
- points.append(self.frame_locations[city][0])
- points.append(self.frame_locations[city][1])
+ points.extend((self.frame_locations[city][0], self.frame_locations[city][1]))
map_canvas.create_polygon(points, outline='red', width=3, fill='', tag="poly")
map_canvas.update()
map_canvas.after(self.speed.get())
diff --git a/gui/vacuum_agent.py b/gui/vacuum_agent.py
index b07dab282..a5ca8d854 100644
--- a/gui/vacuum_agent.py
+++ b/gui/vacuum_agent.py
@@ -88,15 +88,9 @@ def read_env(self):
"""Reads the current state of the GUI."""
for i, btn in enumerate(self.buttons):
if i == 0:
- if btn['bg'] == 'white':
- self.status[loc_A] = 'Clean'
- else:
- self.status[loc_A] = 'Dirty'
+ self.status[loc_A] = 'Clean' if btn['bg'] == 'white' else 'Dirty'
else:
- if btn['bg'] == 'white':
- self.status[loc_B] = 'Clean'
- else:
- self.status[loc_B] = 'Dirty'
+ self.status[loc_B] = 'Clean' if btn['bg'] == 'white' else 'Dirty'
def update_env(self, agent):
"""Updates the GUI according to the agent's action."""
@@ -123,9 +117,7 @@ def create_agent(env, agent):
def move_agent(env, agent, before_step):
"""Moves the agent in the GUI when 'next' button is pressed."""
- if agent.location == before_step:
- pass
- else:
+ if agent.location != before_step:
if agent.location == (1, 0):
env.canvas.move(env.text, 120, 0)
env.canvas.move(env.agent_rect, 120, 0)
diff --git a/gui/xy_vacuum_environment.py b/gui/xy_vacuum_environment.py
index 093abc6c3..deee4c3a5 100644
--- a/gui/xy_vacuum_environment.py
+++ b/gui/xy_vacuum_environment.py
@@ -47,7 +47,7 @@ def create_buttons(self):
def create_walls(self):
"""Creates the outer boundary walls which do not move."""
for row, button_row in enumerate(self.buttons):
- if row == 0 or row == len(self.buttons) - 1:
+ if row in [0, len(self.buttons) - 1]:
for button in button_row:
button.config(text='W', state='disabled',
disabledforeground='black')
@@ -63,20 +63,18 @@ def create_walls(self):
def display_element(self, button):
"""Show the things on the GUI."""
txt = button['text']
- if txt != 'A':
- if txt == 'W':
- button.config(text='D')
- elif txt == 'D':
- button.config(text='')
- elif txt == '':
- button.config(text='W')
+ if txt == '':
+ button.config(text='W')
+ elif txt == 'D':
+ button.config(text='')
+ elif txt == 'W':
+ button.config(text='D')
def execute_action(self, agent, action):
"""Determines the action the agent performs."""
xi, yi = (self.xi, self.yi)
if action == 'Suck':
- dirt_list = self.list_things_at(agent.location, Dirt)
- if dirt_list:
+ if dirt_list := self.list_things_at(agent.location, Dirt):
dirt = dirt_list[0]
agent.performance += 100
self.delete_thing(dirt)
@@ -106,7 +104,10 @@ def read_env(self):
"""Reads the current state of the GUI environment."""
for i, btn_row in enumerate(self.buttons):
for j, btn in enumerate(btn_row):
- if (i != 0 and i != len(self.buttons) - 1) and (j != 0 and j != len(btn_row) - 1):
+ if i not in [0, len(self.buttons) - 1] and j not in [
+ 0,
+ len(btn_row) - 1,
+ ]:
agt_loc = self.agents[0].location
if self.some_things_at((i, j)) and (i, j) != agt_loc:
for thing in self.list_things_at((i, j)):
@@ -130,11 +131,14 @@ def reset_env(self, agt):
self.read_env()
for i, btn_row in enumerate(self.buttons):
for j, btn in enumerate(btn_row):
- if (i != 0 and i != len(self.buttons) - 1) and (j != 0 and j != len(btn_row) - 1):
- if self.some_things_at((i, j)):
- for thing in self.list_things_at((i, j)):
- self.delete_thing(thing)
- btn.config(text='', state='normal')
+ if (
+ i not in [0, len(self.buttons) - 1]
+ and j not in [0, len(btn_row) - 1]
+ and self.some_things_at((i, j))
+ ):
+ for thing in self.list_things_at((i, j)):
+ self.delete_thing(thing)
+ btn.config(text='', state='normal')
self.add_thing(agt, location=(3, 3))
self.buttons[3][3].config(
text='A', state='disabled', disabledforeground='black')
diff --git a/images/-0.04.jpg b/images/-0.04.jpg
index 3cf276421..6ed9662c1 100644
Binary files a/images/-0.04.jpg and b/images/-0.04.jpg differ
diff --git a/images/-0.4.jpg b/images/-0.4.jpg
index b274d2ce3..9160fae23 100644
Binary files a/images/-0.4.jpg and b/images/-0.4.jpg differ
diff --git a/images/-4.jpg b/images/-4.jpg
index 79eefb0cd..a95102fb7 100644
Binary files a/images/-4.jpg and b/images/-4.jpg differ
diff --git a/images/4.jpg b/images/4.jpg
index 55e75001d..fd0238ddb 100644
Binary files a/images/4.jpg and b/images/4.jpg differ
diff --git a/images/aima3e_big.jpg b/images/aima3e_big.jpg
index 1105a5e14..997bab129 100644
Binary files a/images/aima3e_big.jpg and b/images/aima3e_big.jpg differ
diff --git a/images/broxrevised.png b/images/broxrevised.png
index 87051a383..022a94682 100644
Binary files a/images/broxrevised.png and b/images/broxrevised.png differ
diff --git a/images/cake_graph.jpg b/images/cake_graph.jpg
index 160a413ca..c0603d403 100644
Binary files a/images/cake_graph.jpg and b/images/cake_graph.jpg differ
diff --git a/images/dirt.svg b/images/dirt.svg
index 162f84582..ad6da1ddd 100644
--- a/images/dirt.svg
+++ b/images/dirt.svg
@@ -1,291 +1 @@
-
-
-
+
\ No newline at end of file
diff --git a/images/dirt05-icon.jpg b/images/dirt05-icon.jpg
index 38d02e97f..119bd6468 100644
Binary files a/images/dirt05-icon.jpg and b/images/dirt05-icon.jpg differ
diff --git a/images/ensemble_learner.jpg b/images/ensemble_learner.jpg
index b1edd1ec5..522858ae9 100644
Binary files a/images/ensemble_learner.jpg and b/images/ensemble_learner.jpg differ
diff --git a/images/fig_5_2.png b/images/fig_5_2.png
index 872485798..47ef92654 100644
Binary files a/images/fig_5_2.png and b/images/fig_5_2.png differ
diff --git a/images/ge0.jpg b/images/ge0.jpg
index a70b18703..d52dcf79f 100644
Binary files a/images/ge0.jpg and b/images/ge0.jpg differ
diff --git a/images/ge1.jpg b/images/ge1.jpg
index 624f16e25..061605aa0 100644
Binary files a/images/ge1.jpg and b/images/ge1.jpg differ
diff --git a/images/ge2.jpg b/images/ge2.jpg
index 3a29f8f4c..1650c9ab5 100644
Binary files a/images/ge2.jpg and b/images/ge2.jpg differ
diff --git a/images/ge4.jpg b/images/ge4.jpg
index b3a4b4acd..6d5044a25 100644
Binary files a/images/ge4.jpg and b/images/ge4.jpg differ
diff --git a/images/general_learning_agent.jpg b/images/general_learning_agent.jpg
index a8153bef8..5a1eec71f 100644
Binary files a/images/general_learning_agent.jpg and b/images/general_learning_agent.jpg differ
diff --git a/images/grid_mdp_agent.jpg b/images/grid_mdp_agent.jpg
index 3f247b6f2..1c31ed32a 100644
Binary files a/images/grid_mdp_agent.jpg and b/images/grid_mdp_agent.jpg differ
diff --git a/images/hillclimb-tsp.png b/images/hillclimb-tsp.png
index 8446bbafc..bf56fce07 100644
Binary files a/images/hillclimb-tsp.png and b/images/hillclimb-tsp.png differ
diff --git a/images/knn_plot.png b/images/knn_plot.png
index 58b316fdd..973b1a04a 100644
Binary files a/images/knn_plot.png and b/images/knn_plot.png differ
diff --git a/images/knowledge_FOIL_grandparent.png b/images/knowledge_FOIL_grandparent.png
index dbc6e7729..1910a7ebc 100644
Binary files a/images/knowledge_FOIL_grandparent.png and b/images/knowledge_FOIL_grandparent.png differ
diff --git a/images/knowledge_foil_family.png b/images/knowledge_foil_family.png
index 356f22d8d..9d9b89b22 100644
Binary files a/images/knowledge_foil_family.png and b/images/knowledge_foil_family.png differ
diff --git a/images/maze.png b/images/maze.png
index f3fcd1990..36d7f0e6f 100644
Binary files a/images/maze.png and b/images/maze.png differ
diff --git a/images/mdp-a.png b/images/mdp-a.png
index 2f3774891..ee1fa2839 100644
Binary files a/images/mdp-a.png and b/images/mdp-a.png differ
diff --git a/images/mdp-b.png b/images/mdp-b.png
index f21a3760c..92b8e7b00 100644
Binary files a/images/mdp-b.png and b/images/mdp-b.png differ
diff --git a/images/mdp-c.png b/images/mdp-c.png
index 1034079a2..b777a52be 100644
Binary files a/images/mdp-c.png and b/images/mdp-c.png differ
diff --git a/images/mdp-d.png b/images/mdp-d.png
index 8ba7cf073..3659dac8d 100644
Binary files a/images/mdp-d.png and b/images/mdp-d.png differ
diff --git a/images/model_based_reflex_agent.jpg b/images/model_based_reflex_agent.jpg
index b6c12ed09..a2f271ac6 100644
Binary files a/images/model_based_reflex_agent.jpg and b/images/model_based_reflex_agent.jpg differ
diff --git a/images/model_goal_based_agent.jpg b/images/model_goal_based_agent.jpg
index 93d6182b4..3dbca844b 100644
Binary files a/images/model_goal_based_agent.jpg and b/images/model_goal_based_agent.jpg differ
diff --git a/images/model_utility_based_agent.jpg b/images/model_utility_based_agent.jpg
index 693230c00..e9a78fc96 100644
Binary files a/images/model_utility_based_agent.jpg and b/images/model_utility_based_agent.jpg differ
diff --git a/images/neural_net.png b/images/neural_net.png
index 4aa28a106..ca817fd79 100644
Binary files a/images/neural_net.png and b/images/neural_net.png differ
diff --git a/images/parse_tree.png b/images/parse_tree.png
index f6ca87b2f..5d90465cd 100644
Binary files a/images/parse_tree.png and b/images/parse_tree.png differ
diff --git a/images/perceptron.png b/images/perceptron.png
index 68d2a258a..acf87e5d2 100644
Binary files a/images/perceptron.png and b/images/perceptron.png differ
diff --git a/images/pluralityLearner_plot.png b/images/pluralityLearner_plot.png
index 50aa5dcd1..b4f04ec02 100644
Binary files a/images/pluralityLearner_plot.png and b/images/pluralityLearner_plot.png differ
diff --git a/images/point_crossover.png b/images/point_crossover.png
index 9b8d4f7f5..48abb3ff7 100644
Binary files a/images/point_crossover.png and b/images/point_crossover.png differ
diff --git a/images/pop.jpg b/images/pop.jpg
index 52b3e3756..507fe0a93 100644
Binary files a/images/pop.jpg and b/images/pop.jpg differ
diff --git a/images/refinement.png b/images/refinement.png
index 8270d81d0..09ae5f830 100644
Binary files a/images/refinement.png and b/images/refinement.png differ
diff --git a/images/restaurant.png b/images/restaurant.png
index 195c67645..d99554f86 100644
Binary files a/images/restaurant.png and b/images/restaurant.png differ
diff --git a/images/romania_map.png b/images/romania_map.png
index 426c76f1e..01e4552b1 100644
Binary files a/images/romania_map.png and b/images/romania_map.png differ
diff --git a/images/search_animal.svg b/images/search_animal.svg
index e3c3105c8..647cbd4e6 100644
--- a/images/search_animal.svg
+++ b/images/search_animal.svg
@@ -1,1533 +1 @@
-
-
-
-
+
\ No newline at end of file
diff --git a/images/simple_problem_solving_agent.jpg b/images/simple_problem_solving_agent.jpg
index 80fb904b5..947c5d22f 100644
Binary files a/images/simple_problem_solving_agent.jpg and b/images/simple_problem_solving_agent.jpg differ
diff --git a/images/simple_reflex_agent.jpg b/images/simple_reflex_agent.jpg
index 74002a720..2abd283cf 100644
Binary files a/images/simple_reflex_agent.jpg and b/images/simple_reflex_agent.jpg differ
diff --git a/images/sprinklernet.jpg b/images/sprinklernet.jpg
index cac16ee09..8ec4b1937 100644
Binary files a/images/sprinklernet.jpg and b/images/sprinklernet.jpg differ
diff --git a/images/stapler1-test.png b/images/stapler1-test.png
index e550d83f9..fd78c9a46 100644
Binary files a/images/stapler1-test.png and b/images/stapler1-test.png differ
diff --git a/images/uniform_crossover.png b/images/uniform_crossover.png
index 37f835e92..85df7cb16 100644
Binary files a/images/uniform_crossover.png and b/images/uniform_crossover.png differ
diff --git a/images/vacuum-icon.jpg b/images/vacuum-icon.jpg
index 71c80bb6f..d49f2d82a 100644
Binary files a/images/vacuum-icon.jpg and b/images/vacuum-icon.jpg differ
diff --git a/images/vacuum.svg b/images/vacuum.svg
index c5a016b07..ef91d38cb 100644
--- a/images/vacuum.svg
+++ b/images/vacuum.svg
@@ -1,150 +1 @@
-
-
-
+
\ No newline at end of file
diff --git a/ipyviews.py b/ipyviews.py
index b304af7bb..4d34b6da0 100644
--- a/ipyviews.py
+++ b/ipyviews.py
@@ -37,9 +37,11 @@ def __init__(self, world, fill="#AAA"):
def object_name(self):
globals_in_main = {x: getattr(__main__, x) for x in dir(__main__)}
for x in globals_in_main:
- if isinstance(globals_in_main[x], type(self)):
- if globals_in_main[x].time == self.time:
- return x
+ if (
+ isinstance(globals_in_main[x], type(self))
+ and globals_in_main[x].time == self.time
+ ):
+ return x
def handle_add_obstacle(self, vertices):
""" Vertices must be a nestedtuple. This method
@@ -52,11 +54,11 @@ def handle_remove_obstacle(self):
return NotImplementedError
def get_polygon_obstacles_coordinates(self):
- obstacle_coordiantes = []
- for thing in self.world.things:
- if isinstance(thing, PolygonObstacle):
- obstacle_coordiantes.append(thing.coordinates)
- return obstacle_coordiantes
+ return [
+ thing.coordinates
+ for thing in self.world.things
+ if isinstance(thing, PolygonObstacle)
+ ]
def show(self):
clear_output()
@@ -103,9 +105,11 @@ def __init__(self, world, block_size=30, default_fill="white"):
def object_name(self):
globals_in_main = {x: getattr(__main__, x) for x in dir(__main__)}
for x in globals_in_main:
- if isinstance(globals_in_main[x], type(self)):
- if globals_in_main[x].time == self.time:
- return x
+ if (
+ isinstance(globals_in_main[x], type(self))
+ and globals_in_main[x].time == self.time
+ ):
+ return x
def set_label(self, coordinates, label):
""" Add lables to a particular block of grid.
@@ -140,7 +144,7 @@ def map_to_render(self):
row, column = thing.location
thing_class_name = thing.__class__.__name__
if thing_class_name not in self.representation:
- raise KeyError('Representation not found for {}'.format(thing_class_name))
+ raise KeyError(f'Representation not found for {thing_class_name}')
world_map[row][column]["val"] = thing.__class__.__name__
for location, label in self.labels.items():
diff --git a/knowledge.py b/knowledge.py
index 8c27c3eb8..47cb1bc41 100644
--- a/knowledge.py
+++ b/knowledge.py
@@ -51,7 +51,7 @@ def specializations(examples_so_far, h):
continue
h2 = h[i].copy()
- h2[k] = '!' + v
+ h2[k] = f'!{v}'
h3 = h.copy()
h3[i] = h2
if check_all_consistency(examples_so_far, h3):
@@ -90,7 +90,7 @@ def generalizations(examples_so_far, h):
hypotheses += h3
# Add OR operations
- if hypotheses == [] or hypotheses == [{}]:
+ if hypotheses in [[], [{}]]:
hypotheses = add_or(examples_so_far, h)
else:
hypotheses.extend(add_or(examples_so_far, h))
@@ -109,10 +109,7 @@ def add_or(examples_so_far, h):
a_powerset = power_set(attrs.keys())
for c in a_powerset:
- h2 = {}
- for k in c:
- h2[k] = attrs[k]
-
+ h2 = {k: attrs[k] for k in c}
if check_negative_consistency(examples_so_far, h2):
h3 = h.copy()
h3.append(h2)
@@ -183,9 +180,7 @@ def build_attr_combinations(s, values):
if len(s) == 1:
# s holds just one attribute, return its list of values
k = values[s[0]]
- h = [[{s[0]: v}] for v in values[s[0]]]
- return h
-
+ return [[{s[0]: v}] for v in values[s[0]]]
h = []
for i, a in enumerate(s):
rest = build_attr_combinations(s[i + 1:], values)
@@ -194,7 +189,7 @@ def build_attr_combinations(s, values):
for r in rest:
t = o.copy()
for d in r:
- t.update(d)
+ t |= d
h.append([t])
return h
@@ -258,7 +253,7 @@ def tell(self, sentence):
self.const_syms.update(constant_symbols(sentence))
self.pred_syms.update(predicate_symbols(sentence))
else:
- raise Exception('Not a definite clause: {}'.format(sentence))
+ raise Exception(f'Not a definite clause: {sentence}')
def foil(self, examples, target):
"""Learn a list of first-order horn clauses
@@ -287,8 +282,16 @@ def new_clause(self, examples, target):
while extended_examples[1]:
l = self.choose_literal(self.new_literals(clause), extended_examples)
clause[1].append(l)
- extended_examples = [sum([list(self.extend_example(example, l)) for example in
- extended_examples[i]], []) for i in range(2)]
+ extended_examples = [
+ sum(
+ (
+ list(self.extend_example(example, l))
+ for example in extended_examples[i]
+ ),
+ [],
+ )
+ for i in range(2)
+ ]
return clause, extended_examples[0]
@@ -308,10 +311,11 @@ def new_literals(self, clause):
for pred, arity in self.pred_syms:
new_vars = {standardize_variables(expr('x')) for _ in range(arity - 1)}
for args in product(share_vars.union(new_vars), repeat=arity):
- if any(var in share_vars for var in args):
- # make sure we don't return an existing rule
- if not Expr(pred, args) in clause[1]:
- yield Expr(pred, *[var for var in args])
+ if (
+ any(var in share_vars for var in args)
+ and Expr(pred, args) not in clause[1]
+ ):
+ yield Expr(pred, *list(args))
def choose_literal(self, literals, examples):
"""Choose the best literal based on the information gain."""
@@ -335,8 +339,12 @@ def gain(self, l, examples):
"""
pre_pos = len(examples[0])
pre_neg = len(examples[1])
- post_pos = sum([list(self.extend_example(example, l)) for example in examples[0]], [])
- post_neg = sum([list(self.extend_example(example, l)) for example in examples[1]], [])
+ post_pos = sum(
+ (list(self.extend_example(example, l)) for example in examples[0]), []
+ )
+ post_neg = sum(
+ (list(self.extend_example(example, l)) for example in examples[1]), []
+ )
if pre_pos + pre_neg == 0 or len(post_pos) + len(post_neg) == 0:
return -1
# number of positive example that are represented in extended_examples
@@ -345,9 +353,10 @@ def gain(self, l, examples):
represents = lambda d: all(d[x] == example[x] for x in example)
if any(represents(l_) for l_ in post_pos):
T += 1
- value = T * (np.log2(len(post_pos) / (len(post_pos) + len(post_neg)) + 1e-12) -
- np.log2(pre_pos / (pre_pos + pre_neg)))
- return value
+ return T * (
+ np.log2(len(post_pos) / (len(post_pos) + len(post_neg)) + 1e-12)
+ - np.log2(pre_pos / (pre_pos + pre_neg))
+ )
def update_examples(self, target, examples, extended_examples):
"""Add to the kb those examples what are represented in extended_examples
@@ -368,11 +377,7 @@ def update_examples(self, target, examples, extended_examples):
def check_all_consistency(examples, h):
"""Check for the consistency of all examples under h."""
- for e in examples:
- if not is_consistent(e, h):
- return False
-
- return True
+ return all(is_consistent(e, h) for e in examples)
def check_negative_consistency(examples, h):
@@ -403,11 +408,7 @@ def disjunction_value(e, d):
def guess_value(e, h):
"""Guess value of example e under hypothesis h."""
- for d in h:
- if disjunction_value(e, d):
- return True
-
- return False
+ return any(disjunction_value(e, d) for d in h)
def is_consistent(e, h):
diff --git a/learning.py b/learning.py
index 71b6b15e7..6632f3355 100644
--- a/learning.py
+++ b/learning.py
@@ -56,7 +56,7 @@ def __init__(self, examples=None, attrs=None, attr_names=None, target=-1, inputs
if isinstance(examples, str):
self.examples = parse_csv(examples)
elif examples is None:
- self.examples = parse_csv(open_data(name + '.csv').read())
+ self.examples = parse_csv(open_data(f'{name}.csv').read())
else:
self.examples = examples
@@ -111,8 +111,9 @@ def check_example(self, example):
if self.values:
for a in self.attrs:
if example[a] not in self.values[a]:
- raise ValueError('Bad value {} for attribute {} in {}'
- .format(example[a], self.attr_names[a], example))
+ raise ValueError(
+ f'Bad value {example[a]} for attribute {self.attr_names[a]} in {example}'
+ )
def attr_num(self, attr):
"""Returns the number used for attr, which can be a name, or -n .. n-1."""
@@ -286,7 +287,7 @@ def cross_validation(learner, dataset, size=None, k=10, trials=1):
if trials > 1:
trial_errT = 0
trial_errV = 0
- for t in range(trials):
+ for _ in range(trials):
errT, errV = cross_validation(learner, dataset, size, k, trials)
trial_errT += errT
trial_errV += errV
@@ -649,7 +650,7 @@ def BackPropagationLearner(dataset, net, learning_rate, epochs, activation=sigmo
inputs, targets = init_examples(examples, idx_i, idx_t, o_units)
- for epoch in range(epochs):
+ for _ in range(epochs):
# iterate over each example
for e in range(len(examples)):
i_val = inputs[e]
@@ -1024,7 +1025,7 @@ def ada_boost(dataset, L, K):
eps = 1 / (2 * n)
w = [1 / n] * n
h, z = [], []
- for k in range(K):
+ for _ in range(K):
h_k = L(dataset, w)
h.append(h_k)
error = sum(weight for example, weight in zip(examples, w) if example[target] != h_k(example))
@@ -1203,7 +1204,7 @@ def Majority(k, n):
k random bits followed by a 1 if more than half the bits are 1, else 0.
"""
examples = []
- for i in range(n):
+ for _ in range(n):
bits = [random.choice([0, 1]) for _ in range(k)]
bits.append(int(sum(bits) > k / 2))
examples.append(bits)
@@ -1216,7 +1217,7 @@ def Parity(k, n, name='parity'):
k random bits followed by a 1 if an odd number of bits are 1, else 0.
"""
examples = []
- for i in range(n):
+ for _ in range(n):
bits = [random.choice([0, 1]) for _ in range(k)]
bits.append(sum(bits) % 2)
examples.append(bits)
@@ -1231,7 +1232,7 @@ def Xor(n):
def ContinuousXor(n):
"""2 inputs are chosen uniformly from (0.0 .. 2.0]; output is xor of ints."""
examples = []
- for i in range(n):
+ for _ in range(n):
x, y = [random.uniform(0.0, 2.0) for _ in '12']
examples.append([x, y, x != y])
return DataSet(name='continuous xor', examples=examples)
@@ -1249,5 +1250,12 @@ def compare(algorithms=None, datasets=None, k=10, trials=1):
datasets = datasets or [iris, orings, zoo, restaurant, SyntheticRestaurant(20),
Majority(7, 100), Parity(7, 100), Xor(100)]
- print_table([[a.__name__.replace('Learner', '')] + [cross_validation(a, d, k=k, trials=trials) for d in datasets]
- for a in algorithms], header=[''] + [d.name[0:7] for d in datasets], numfmt='%.2f')
+ print_table(
+ [
+ [a.__name__.replace('Learner', '')]
+ + [cross_validation(a, d, k=k, trials=trials) for d in datasets]
+ for a in algorithms
+ ],
+ header=[''] + [d.name[:7] for d in datasets],
+ numfmt='%.2f',
+ )
diff --git a/learning4e.py b/learning4e.py
index 12c0defa5..bb4bb2a95 100644
--- a/learning4e.py
+++ b/learning4e.py
@@ -57,7 +57,7 @@ def __init__(self, examples=None, attrs=None, attr_names=None, target=-1, inputs
if isinstance(examples, str):
self.examples = parse_csv(examples)
elif examples is None:
- self.examples = parse_csv(open_data(name + '.csv').read())
+ self.examples = parse_csv(open_data(f'{name}.csv').read())
else:
self.examples = examples
@@ -112,8 +112,9 @@ def check_example(self, example):
if self.values:
for a in self.attrs:
if example[a] not in self.values[a]:
- raise ValueError('Bad value {} for attribute {} in {}'
- .format(example[a], self.attr_names[a], example))
+ raise ValueError(
+ f'Bad value {example[a]} for attribute {self.attr_names[a]} in {example}'
+ )
def attr_num(self, attr):
"""Returns the number used for attr, which can be a name, or -n .. n-1."""
@@ -285,7 +286,7 @@ def cross_validation(learner, dataset, size=None, k=10, trials=1):
k = k or len(dataset.examples)
if trials > 1:
trial_errs = 0
- for t in range(trials):
+ for _ in range(trials):
errs = cross_validation(learner, dataset, size, k, trials)
trial_errs += errs
return trial_errs / trials
@@ -804,7 +805,7 @@ def ada_boost(dataset, L, K):
eps = 1 / (2 * n)
w = [1 / n] * n
h, z = [], []
- for k in range(K):
+ for _ in range(K):
h_k = L(dataset, w)
h.append(h_k)
error = sum(weight for example, weight in zip(examples, w) if example[target] != h_k.predict(example[:-1]))
@@ -989,7 +990,7 @@ def Majority(k, n):
k random bits followed by a 1 if more than half the bits are 1, else 0.
"""
examples = []
- for i in range(n):
+ for _ in range(n):
bits = [random.choice([0, 1]) for _ in range(k)]
bits.append(int(sum(bits) > k / 2))
examples.append(bits)
@@ -1002,7 +1003,7 @@ def Parity(k, n, name='parity'):
k random bits followed by a 1 if an odd number of bits are 1, else 0.
"""
examples = []
- for i in range(n):
+ for _ in range(n):
bits = [random.choice([0, 1]) for _ in range(k)]
bits.append(sum(bits) % 2)
examples.append(bits)
@@ -1017,7 +1018,7 @@ def Xor(n):
def ContinuousXor(n):
"""2 inputs are chosen uniformly from (0.0 .. 2.0]; output is xor of ints."""
examples = []
- for i in range(n):
+ for _ in range(n):
x, y = [random.uniform(0.0, 2.0) for _ in '12']
examples.append([x, y, x != y])
return DataSet(name='continuous xor', examples=examples)
@@ -1035,5 +1036,12 @@ def compare(algorithms=None, datasets=None, k=10, trials=1):
datasets = datasets or [iris, orings, zoo, restaurant, SyntheticRestaurant(20),
Majority(7, 100), Parity(7, 100), Xor(100)]
- print_table([[a.__name__.replace('Learner', '')] + [cross_validation(a, d, k=k, trials=trials) for d in datasets]
- for a in algorithms], header=[''] + [d.name[0:7] for d in datasets], numfmt='%.2f')
+ print_table(
+ [
+ [a.__name__.replace('Learner', '')]
+ + [cross_validation(a, d, k=k, trials=trials) for d in datasets]
+ for a in algorithms
+ ],
+ header=[''] + [d.name[:7] for d in datasets],
+ numfmt='%.2f',
+ )
diff --git a/logic.py b/logic.py
index 1624d55a5..09409fc5c 100644
--- a/logic.py
+++ b/logic.py
@@ -188,9 +188,8 @@ def parse_definite_clause(s):
assert is_definite_clause(s)
if is_symbol(s.op):
return [], s
- else:
- antecedent, consequent = s.args
- return conjuncts(antecedent), consequent
+ antecedent, consequent = s.args
+ return conjuncts(antecedent), consequent
# Useful constant Exprs used in examples and code:
@@ -217,12 +216,11 @@ def tt_entails(kb, alpha):
def tt_check_all(kb, alpha, symbols, model):
"""Auxiliary routine to implement tt_entails."""
if not symbols:
- if pl_true(kb, model):
- result = pl_true(alpha, model)
- assert result in (True, False)
- return result
- else:
+ if not pl_true(kb, model):
return True
+ result = pl_true(alpha, model)
+ assert result in (True, False)
+ return result
else:
P, rest = symbols[0], symbols[1:]
return (tt_check_all(kb, alpha, rest, extend(model, P, True)) and
@@ -283,10 +281,7 @@ def pl_true(exp, model={}):
return model.get(exp)
elif op == '~':
p = pl_true(args[0], model)
- if p is None:
- return None
- else:
- return not p
+ return None if p is None else not p
elif op == '|':
result = False
for arg in args:
@@ -321,7 +316,7 @@ def pl_true(exp, model={}):
elif op == '^': # xor or 'not equivalent'
return pt != qt
else:
- raise ValueError('Illegal operator in logic expression' + str(exp))
+ raise ValueError(f'Illegal operator in logic expression{str(exp)}')
# ______________________________________________________________________________
@@ -381,9 +376,7 @@ def NOT(b):
return move_not_inwards(a.args[0]) # ~~A ==> A
if a.op == '&':
return associate('|', list(map(NOT, a.args)))
- if a.op == '|':
- return associate('&', list(map(NOT, a.args)))
- return s
+ return associate('&', list(map(NOT, a.args))) if a.op == '|' else s
elif is_symbol(s.op) or not s.args:
return s
else:
@@ -510,9 +503,17 @@ def pl_resolve(ci, cj):
"""Return all clauses that can be obtained by resolving clauses ci and cj."""
clauses = []
for di in disjuncts(ci):
- for dj in disjuncts(cj):
- if di == ~dj or ~di == dj:
- clauses.append(associate('|', unique(remove_all(di, disjuncts(ci)) + remove_all(dj, disjuncts(cj)))))
+ clauses.extend(
+ associate(
+ '|',
+ unique(
+ remove_all(di, disjuncts(ci))
+ + remove_all(dj, disjuncts(cj))
+ ),
+ )
+ for dj in disjuncts(cj)
+ if di == ~dj or ~di == dj
+ )
return clauses
@@ -629,7 +630,7 @@ def momsf(symbols, clauses, k=0):
scores = Counter(l for c in min_clauses(clauses) for l in disjuncts(c))
P = max(symbols,
key=lambda symbol: (scores[symbol] + scores[~symbol]) * pow(2, k) + scores[symbol] * scores[~symbol])
- return P, True if scores[P] >= scores[~P] else False
+ return P, scores[P] >= scores[~P]
def posit(symbols, clauses):
@@ -640,7 +641,7 @@ def posit(symbols, clauses):
"""
scores = Counter(l for c in min_clauses(clauses) for l in disjuncts(c))
P = max(symbols, key=lambda symbol: scores[symbol] + scores[~symbol])
- return P, True if scores[P] >= scores[~P] else False
+ return P, scores[P] >= scores[~P]
def zm(symbols, clauses):
@@ -660,7 +661,7 @@ def dlis(symbols, clauses):
"""
scores = Counter(l for c in clauses for l in disjuncts(c))
P = max(symbols, key=lambda symbol: scores[symbol])
- return P, True if scores[P] >= scores[~P] else False
+ return P, scores[P] >= scores[~P]
def dlcs(symbols, clauses):
@@ -673,7 +674,7 @@ def dlcs(symbols, clauses):
"""
scores = Counter(l for c in clauses for l in disjuncts(c))
P = max(symbols, key=lambda symbol: scores[symbol] + scores[~symbol])
- return P, True if scores[P] >= scores[~P] else False
+ return P, scores[P] >= scores[~P]
def jw(symbols, clauses):
@@ -700,7 +701,7 @@ def jw2(symbols, clauses):
for l in disjuncts(c):
scores[l] += pow(2, -len(c.args))
P = max(symbols, key=lambda symbol: scores[symbol] + scores[~symbol])
- return P, True if scores[P] >= scores[~P] else False
+ return P, scores[P] >= scores[~P]
# ______________________________________________________________________________
@@ -785,12 +786,9 @@ def unit_clause_assign(clause, model):
P, value = None, None
for literal in disjuncts(clause):
sym, positive = inspect_literal(literal)
- if sym in model:
- if model[sym] == positive:
- return None, None # clause already True
- elif P:
- return None, None # more than 1 unbound variable
- else:
+ if sym in model and model[sym] == positive or sym not in model and P:
+ return None, None # clause already True
+ elif sym not in model:
P, value = sym, positive
return P, value
@@ -803,10 +801,7 @@ def inspect_literal(literal):
>>> inspect_literal(~P)
(P, False)
"""
- if literal.op == '~':
- return literal.args[0], False
- else:
- return literal, True
+ return (literal.args[0], False) if literal.op == '~' else (literal, True)
# ______________________________________________________________________________
@@ -855,8 +850,7 @@ def cdcl_satisfiable(s, vsids_decay=0.95, restart_strategy=no_restart):
sum_lbd = 0
queue_lbd = []
while True:
- conflict = unit_propagation(clauses, symbols, model, G, dl)
- if conflict:
+ if conflict := unit_propagation(clauses, symbols, model, G, dl):
if dl == 0:
return False
conflicts += 1
@@ -865,9 +859,9 @@ def cdcl_satisfiable(s, vsids_decay=0.95, restart_strategy=no_restart):
sum_lbd += lbd
backjump(symbols, model, G, dl)
clauses.add(learn, model)
- scores.update(l for l in disjuncts(learn))
- for symbol in scores:
- scores[symbol] *= vsids_decay
+ scores |= iter(disjuncts(learn))
+ for value in scores.values():
+ value *= vsids_decay
if restart_strategy(conflicts, restarts, queue_lbd, sum_lbd):
backjump(symbols, model, G)
queue_lbd.clear()
@@ -881,7 +875,7 @@ def cdcl_satisfiable(s, vsids_decay=0.95, restart_strategy=no_restart):
def assign_decision_literal(symbols, model, scores, G, dl):
P = max(symbols, key=lambda symbol: scores[symbol] + scores[~symbol])
- value = True if scores[P] >= scores[~P] else False
+ value = scores[P] >= scores[~P]
symbols.remove(P)
model[P] = value
G.add_node(P, val=value, dl=dl)
@@ -1003,16 +997,12 @@ def set_second_watched(self, clause, new_watching):
def get_first_watched(self, clause):
if len(clause.args) == 2:
return clause.args[0]
- if len(clause.args) > 2:
- return self.__twl[clause][0]
- return clause
+ return self.__twl[clause][0] if len(clause.args) > 2 else clause
def get_second_watched(self, clause):
if len(clause.args) == 2:
return clause.args[-1]
- if len(clause.args) > 2:
- return self.__twl[clause][1]
- return clause
+ return self.__twl[clause][1] if len(clause.args) > 2 else clause
def get_pos_watched(self, l):
return self.__watch_list[l][0]
@@ -1127,16 +1117,17 @@ def MapColoringSAT(colors, neighbors):
colors = UniversalDict(colors)
clauses = []
for state in neighbors.keys():
- clause = [expr(state + '_' + c) for c in colors[state]]
+ clause = [expr(f'{state}_{c}') for c in colors[state]]
clauses.append(clause)
- for t in itertools.combinations(clause, 2):
- clauses.append([~t[0], ~t[1]])
+ clauses.extend([~t[0], ~t[1]] for t in itertools.combinations(clause, 2))
visited = set()
adj = set(neighbors[state]) - visited
visited.add(state)
for n_state in adj:
- for col in colors[n_state]:
- clauses.append([expr('~' + state + '_' + col), expr('~' + n_state + '_' + col)])
+ clauses.extend(
+ [expr(f'~{state}_{col}'), expr(f'~{n_state}_{col}')]
+ for col in colors[n_state]
+ )
return associate('&', map(lambda c: associate('|', c), clauses))
@@ -1248,10 +1239,7 @@ def ok_to_move(x, y, time):
def location(x, y, time=None):
- if time is None:
- return Expr('L', x, y)
- else:
- return Expr('L', x, y, time)
+ return Expr('L', x, y) if time is None else Expr('L', x, y, time)
# Symbols
@@ -1290,8 +1278,8 @@ def __init__(self, dimrow):
for y in range(1, dimrow + 1):
for x in range(1, dimrow + 1):
- pits_in = list()
- wumpus_in = list()
+ pits_in = []
+ wumpus_in = []
if x > 1: # West room exists
pits_in.append(pit(x - 1, y))
@@ -1313,11 +1301,9 @@ def __init__(self, dimrow):
self.tell(equiv(stench(x, y), new_disjunction(wumpus_in)))
# Rule that describes existence of at least one Wumpus
- wumpus_at_least = list()
+ wumpus_at_least = []
for x in range(1, dimrow + 1):
- for y in range(1, dimrow + 1):
- wumpus_at_least.append(wumpus(x, y))
-
+ wumpus_at_least.extend(wumpus(x, y) for y in range(1, dimrow + 1))
self.tell(new_disjunction(wumpus_at_least))
# Rule that describes existence of at most one Wumpus
@@ -1398,8 +1384,13 @@ def add_temporal_sentences(self, time):
for j in range(1, self.dimrow + 1):
self.tell(implies(location(i, j, time), equiv(percept_breeze(time), breeze(i, j))))
self.tell(implies(location(i, j, time), equiv(percept_stench(time), stench(i, j))))
- s = list()
- s.append(equiv(location(i, j, time), location(i, j, time) & ~move_forward(time) | percept_bump(time)))
+ s = [
+ equiv(
+ location(i, j, time),
+ location(i, j, time) & ~move_forward(time)
+ | percept_bump(time),
+ )
+ ]
if i != 1:
s.append(location(i - 1, j, t) & facing_east(t) & move_forward(t))
if i != self.dimrow:
@@ -1477,10 +1468,10 @@ def set_orientation(self, orientation):
self.orientation = orientation
def __eq__(self, other):
- if other.get_location() == self.get_location() and other.get_orientation() == self.get_orientation():
- return True
- else:
- return False
+ return (
+ other.get_location() == self.get_location()
+ and other.get_orientation() == self.get_orientation()
+ )
# ______________________________________________________________________________
@@ -1496,7 +1487,7 @@ def __init__(self, dimentions):
self.dimrow = dimentions
self.kb = WumpusKB(self.dimrow)
self.t = 0
- self.plan = list()
+ self.plan = []
self.current_position = WumpusPosition(1, 1, 'UP')
super().__init__(self.execute)
@@ -1504,14 +1495,12 @@ def execute(self, percept):
self.kb.make_percept_sentence(percept, self.t)
self.kb.add_temporal_sentences(self.t)
- temp = list()
+ temp = []
for i in range(1, self.dimrow + 1):
for j in range(1, self.dimrow + 1):
if self.kb.ask_if_true(location(i, j, self.t)):
- temp.append(i)
- temp.append(j)
-
+ temp.extend((i, j))
if self.kb.ask_if_true(facing_north(self.t)):
self.current_position = WumpusPosition(temp[0], temp[1], 'UP')
elif self.kb.ask_if_true(facing_south(self.t)):
@@ -1521,58 +1510,61 @@ def execute(self, percept):
elif self.kb.ask_if_true(facing_east(self.t)):
self.current_position = WumpusPosition(temp[0], temp[1], 'RIGHT')
- safe_points = list()
+ safe_points = []
for i in range(1, self.dimrow + 1):
- for j in range(1, self.dimrow + 1):
- if self.kb.ask_if_true(ok_to_move(i, j, self.t)):
- safe_points.append([i, j])
-
+ safe_points.extend(
+ [i, j]
+ for j in range(1, self.dimrow + 1)
+ if self.kb.ask_if_true(ok_to_move(i, j, self.t))
+ )
if self.kb.ask_if_true(percept_glitter(self.t)):
- goals = list()
- goals.append([1, 1])
+ goals = [[1, 1]]
self.plan.append('Grab')
actions = self.plan_route(self.current_position, goals, safe_points)
self.plan.extend(actions)
self.plan.append('Climb')
if len(self.plan) == 0:
- unvisited = list()
+ unvisited = []
for i in range(1, self.dimrow + 1):
for j in range(1, self.dimrow + 1):
- for k in range(self.t):
- if self.kb.ask_if_true(location(i, j, k)):
- unvisited.append([i, j])
- unvisited_and_safe = list()
- for u in unvisited:
- for s in safe_points:
- if u not in unvisited_and_safe and s == u:
- unvisited_and_safe.append(u)
+ unvisited.extend(
+ [i, j]
+ for k in range(self.t)
+ if self.kb.ask_if_true(location(i, j, k))
+ )
+ unvisited_and_safe = []
+ for u, s in itertools.product(unvisited, safe_points):
+ if u not in unvisited_and_safe and s == u:
+ unvisited_and_safe.append(u)
temp = self.plan_route(self.current_position, unvisited_and_safe, safe_points)
self.plan.extend(temp)
if len(self.plan) == 0 and self.kb.ask_if_true(have_arrow(self.t)):
- possible_wumpus = list()
+ possible_wumpus = []
for i in range(1, self.dimrow + 1):
- for j in range(1, self.dimrow + 1):
- if not self.kb.ask_if_true(wumpus(i, j)):
- possible_wumpus.append([i, j])
-
+ possible_wumpus.extend(
+ [i, j]
+ for j in range(1, self.dimrow + 1)
+ if not self.kb.ask_if_true(wumpus(i, j))
+ )
temp = self.plan_shot(self.current_position, possible_wumpus, safe_points)
self.plan.extend(temp)
if len(self.plan) == 0:
- not_unsafe = list()
+ not_unsafe = []
for i in range(1, self.dimrow + 1):
- for j in range(1, self.dimrow + 1):
- if not self.kb.ask_if_true(ok_to_move(i, j, self.t)):
- not_unsafe.append([i, j])
+ not_unsafe.extend(
+ [i, j]
+ for j in range(1, self.dimrow + 1)
+ if not self.kb.ask_if_true(ok_to_move(i, j, self.t))
+ )
temp = self.plan_route(self.current_position, not_unsafe, safe_points)
self.plan.extend(temp)
if len(self.plan) == 0:
- start = list()
- start.append([1, 1])
+ start = [[1, 1]]
temp = self.plan_route(self.current_position, start, safe_points)
self.plan.extend(temp)
self.plan.append('Climb')
@@ -1610,7 +1602,7 @@ def plan_shot(self, current, goals, allowed):
for orientation in orientations:
shooting_positions.remove(WumpusPosition(loc[0], loc[1], orientation))
- actions = list()
+ actions = []
actions.extend(self.plan_route(current, shooting_positions, allowed))
actions.append('Shoot')
return actions
@@ -1631,7 +1623,7 @@ def SAT_plan(init, transition, goal, t_max, SAT_solver=cdcl_satisfiable):
# Functions used by SAT_plan
def translate_to_SAT(init, transition, goal, time):
clauses = []
- states = [state for state in transition]
+ states = list(transition)
# Symbol claiming state s at time t
state_counter = itertools.count()
@@ -1730,9 +1722,7 @@ def unify(x, y, s={}):
elif isinstance(x, str) or isinstance(y, str):
return None
elif issequence(x) and issequence(y) and len(x) == len(y):
- if not x:
- return s
- return unify(x[1:], y[1:], unify(x[0], y[0], s))
+ return s if not x else unify(x[1:], y[1:], unify(x[0], y[0], s))
else:
return None
@@ -1779,7 +1769,7 @@ def subst(s, x):
if isinstance(x, list):
return [subst(s, xi) for xi in x]
elif isinstance(x, tuple):
- return tuple([subst(s, xi) for xi in x])
+ return tuple(subst(s, xi) for xi in x)
elif not isinstance(x, Expr):
return x
elif is_var_symbol(x.op):
@@ -1834,13 +1824,10 @@ def unify_mm(x, y, s={}):
# variable elimination (there is a chance to apply rule d)
s[x] = vars_elimination(y, s)
elif not is_variable(x) and not is_variable(y):
- # in which case x and y are not variables, if the two root function symbols
- # are different, stop with failure, else apply term reduction (rule c)
- if x.op is y.op and len(x.args) == len(y.args):
- term_reduction(x, y, s)
- del s[x]
- else:
+ if x.op is not y.op or len(x.args) != len(y.args):
return None
+ term_reduction(x, y, s)
+ del s[x]
elif isinstance(y, Expr):
# in which case x is a variable and y is a function or a variable (e.g. F(z) or y),
# if y is a function, we must check if x occurs in y, then stop with failure, else
@@ -1890,10 +1877,9 @@ def standardize_variables(sentence, dic=None):
elif is_var_symbol(sentence.op):
if sentence in dic:
return dic[sentence]
- else:
- v = Expr('v_{}'.format(next(standardize_variables.counter)))
- dic[sentence] = v
- return v
+ v = Expr(f'v_{next(standardize_variables.counter)}')
+ dic[sentence] = v
+ return v
else:
return Expr(sentence.op, *[standardize_variables(a, dic) for a in sentence.args])
@@ -1906,12 +1892,32 @@ def standardize_variables(sentence, dic=None):
def parse_clauses_from_dimacs(dimacs_cnf):
"""Converts a string into CNF clauses according to the DIMACS format used in SAT competitions"""
- return map(lambda c: associate('|', c),
- map(lambda c: [expr('~X' + str(abs(l))) if l < 0 else expr('X' + str(l)) for l in c],
- map(lambda line: map(int, line.split()),
- filter(None, ' '.join(
- filter(lambda line: line[0] not in ('c', 'p'),
- filter(None, dimacs_cnf.strip().replace('\t', ' ').split('\n')))).split(' 0')))))
+ return map(
+ lambda c: associate('|', c),
+ map(
+ lambda c: [
+ expr(f'~X{str(abs(l))}') if l < 0 else expr(f'X{str(l)}')
+ for l in c
+ ],
+ map(
+ lambda line: map(int, line.split()),
+ filter(
+ None,
+ ' '.join(
+ filter(
+ lambda line: line[0] not in ('c', 'p'),
+ filter(
+ None,
+ dimacs_cnf.strip()
+ .replace('\t', ' ')
+ .split('\n'),
+ ),
+ )
+ ).split(' 0'),
+ ),
+ ),
+ ),
+ )
# ______________________________________________________________________________
@@ -1940,7 +1946,7 @@ def tell(self, sentence):
if is_definite_clause(sentence):
self.clauses.append(sentence)
else:
- raise Exception('Not a definite clause: {}'.format(sentence))
+ raise Exception(f'Not a definite clause: {sentence}')
def ask_generator(self, query):
return fol_bc_ask(self, query)
@@ -1963,8 +1969,7 @@ def fol_fc_ask(kb, alpha):
def enum_subst(p):
query_vars = list({v for clause in p for v in variables(clause)})
for assignment_list in itertools.product(kb_consts, repeat=len(query_vars)):
- theta = {x: y for x, y in zip(query_vars, assignment_list)}
- yield theta
+ yield dict(zip(query_vars, assignment_list))
# check if we can answer without new inferences
for q in kb.clauses:
@@ -1979,7 +1984,7 @@ def enum_subst(p):
for theta in enum_subst(p):
if set(subst(theta, p)).issubset(set(kb.clauses)):
q_ = subst(theta, q)
- if all([unify_mm(x, q_) is None for x in kb.clauses + new]):
+ if all(unify_mm(x, q_) is None for x in kb.clauses + new):
new.append(q_)
phi = unify_mm(q_, alpha)
if phi is not None:
@@ -2003,8 +2008,7 @@ def fol_bc_ask(kb, query):
def fol_bc_or(kb, goal, theta):
for rule in kb.fetch_rules_for_goal(goal):
lhs, rhs = parse_definite_clause(standardize_variables(rule))
- for theta1 in fol_bc_and(kb, lhs, unify_mm(rhs, goal, theta)):
- yield theta1
+ yield from fol_bc_and(kb, lhs, unify_mm(rhs, goal, theta))
def fol_bc_and(kb, goals, theta):
@@ -2015,8 +2019,7 @@ def fol_bc_and(kb, goals, theta):
else:
first, rest = goals[0], goals[1:]
for theta1 in fol_bc_or(kb, subst(theta, first), theta):
- for theta2 in fol_bc_and(kb, rest, theta1):
- yield theta2
+ yield from fol_bc_and(kb, rest, theta1)
# A simple KB that defines the relevant conditions of the Wumpus World as in Figure 7.4.
@@ -2090,7 +2093,7 @@ def diff(y, x):
elif op == 'log':
return diff(u, x) / u
else:
- raise ValueError('Unknown op: {} in diff({}, {})'.format(op, y, x))
+ raise ValueError(f'Unknown op: {op} in diff({y}, {x})')
def simp(x):
@@ -2151,7 +2154,7 @@ def simp(x):
if u == 1:
return 0
else:
- raise ValueError('Unknown op: ' + op)
+ raise ValueError(f'Unknown op: {op}')
# If we fall through to here, we can not simplify further
return Expr(op, *args)
diff --git a/logic4e.py b/logic4e.py
index 75608ad74..bf69f6c23 100644
--- a/logic4e.py
+++ b/logic4e.py
@@ -217,10 +217,7 @@ def ok_to_move(x, y, time):
def location(x, y, time=None):
- if time is None:
- return Expr('L', x, y)
- else:
- return Expr('L', x, y, time)
+ return Expr('L', x, y) if time is None else Expr('L', x, y, time)
# Symbols
@@ -303,9 +300,8 @@ def parse_definite_clause(s):
assert is_definite_clause(s)
if is_symbol(s.op):
return [], s
- else:
- antecedent, consequent = s.args
- return conjuncts(antecedent), consequent
+ antecedent, consequent = s.args
+ return conjuncts(antecedent), consequent
# Useful constant Exprs used in examples and code:
@@ -332,12 +328,11 @@ def tt_entails(kb, alpha):
def tt_check_all(kb, alpha, symbols, model):
"""Auxiliary routine to implement tt_entails."""
if not symbols:
- if pl_true(kb, model):
- result = pl_true(alpha, model)
- assert result in (True, False)
- return result
- else:
+ if not pl_true(kb, model):
return True
+ result = pl_true(alpha, model)
+ assert result in (True, False)
+ return result
else:
P, rest = symbols[0], symbols[1:]
return (tt_check_all(kb, alpha, rest, extend(model, P, True)) and
@@ -401,10 +396,7 @@ def pl_true(exp, model={}):
return model.get(exp)
elif op == '~':
p = pl_true(args[0], model)
- if p is None:
- return None
- else:
- return not p
+ return None if p is None else not p
elif op == '|':
result = False
for arg in args:
@@ -439,7 +431,7 @@ def pl_true(exp, model={}):
elif op == '^': # xor or 'not equivalent'
return pt != qt
else:
- raise ValueError("illegal operator in logic expression" + str(exp))
+ raise ValueError(f"illegal operator in logic expression{str(exp)}")
# ______________________________________________________________________________
@@ -496,9 +488,7 @@ def NOT(b):
return move_not_inwards(a.args[0]) # ~~A ==> A
if a.op == '&':
return associate('|', list(map(NOT, a.args)))
- if a.op == '|':
- return associate('&', list(map(NOT, a.args)))
- return s
+ return associate('&', list(map(NOT, a.args))) if a.op == '|' else s
elif is_symbol(s.op) or not s.args:
return s
else:
@@ -793,12 +783,9 @@ def unit_clause_assign(clause, model):
P, value = None, None
for literal in disjuncts(clause):
sym, positive = inspect_literal(literal)
- if sym in model:
- if model[sym] == positive:
- return None, None # clause already True
- elif P:
- return None, None # more than 1 unbound variable
- else:
+ if sym in model and model[sym] == positive or sym not in model and P:
+ return None, None # clause already True
+ elif sym not in model:
P, value = sym, positive
return P, value
@@ -811,10 +798,7 @@ def inspect_literal(literal):
>>> inspect_literal(~P)
(P, False)
"""
- if literal.op == '~':
- return literal.args[0], False
- else:
- return literal, True
+ return (literal.args[0], False) if literal.op == '~' else (literal, True)
# ______________________________________________________________________________
@@ -875,8 +859,8 @@ def __init__(self, dimrow):
for y in range(1, dimrow + 1):
for x in range(1, dimrow + 1):
- pits_in = list()
- wumpus_in = list()
+ pits_in = []
+ wumpus_in = []
if x > 1: # West room exists
pits_in.append(pit(x - 1, y))
@@ -898,11 +882,9 @@ def __init__(self, dimrow):
self.tell(equiv(stench(x, y), new_disjunction(wumpus_in)))
# Rule that describes existence of at least one Wumpus
- wumpus_at_least = list()
+ wumpus_at_least = []
for x in range(1, dimrow + 1):
- for y in range(1, dimrow + 1):
- wumpus_at_least.append(wumpus(x, y))
-
+ wumpus_at_least.extend(wumpus(x, y) for y in range(1, dimrow + 1))
self.tell(new_disjunction(wumpus_at_least))
# Rule that describes existence of at most one Wumpus
@@ -984,11 +966,13 @@ def add_temporal_sentences(self, time):
self.tell(implies(location(i, j, time), equiv(percept_breeze(time), breeze(i, j))))
self.tell(implies(location(i, j, time), equiv(percept_stench(time), stench(i, j))))
- s = list()
-
- s.append(
+ s = [
equiv(
- location(i, j, time), location(i, j, time) & ~move_forward(time) | percept_bump(time)))
+ location(i, j, time),
+ location(i, j, time) & ~move_forward(time)
+ | percept_bump(time),
+ )
+ ]
if i != 1:
s.append(location(i - 1, j, t) & facing_east(t) & move_forward(t))
@@ -1072,11 +1056,10 @@ def set_orientation(self, orientation):
self.orientation = orientation
def __eq__(self, other):
- if (other.get_location() == self.get_location() and
- other.get_orientation() == self.get_orientation()):
- return True
- else:
- return False
+ return (
+ other.get_location() == self.get_location()
+ and other.get_orientation() == self.get_orientation()
+ )
# ______________________________________________________________________________
@@ -1090,7 +1073,7 @@ def __init__(self, dimentions):
self.dimrow = dimentions
self.kb = WumpusKB(self.dimrow)
self.t = 0
- self.plan = list()
+ self.plan = []
self.current_position = WumpusPosition(1, 1, 'UP')
super().__init__(self.execute)
@@ -1098,14 +1081,12 @@ def execute(self, percept):
self.kb.make_percept_sentence(percept, self.t)
self.kb.add_temporal_sentences(self.t)
- temp = list()
+ temp = []
for i in range(1, self.dimrow + 1):
for j in range(1, self.dimrow + 1):
if self.kb.ask_if_true(location(i, j, self.t)):
- temp.append(i)
- temp.append(j)
-
+ temp.extend((i, j))
if self.kb.ask_if_true(facing_north(self.t)):
self.current_position = WumpusPosition(temp[0], temp[1], 'UP')
elif self.kb.ask_if_true(facing_south(self.t)):
@@ -1115,58 +1096,61 @@ def execute(self, percept):
elif self.kb.ask_if_true(facing_east(self.t)):
self.current_position = WumpusPosition(temp[0], temp[1], 'RIGHT')
- safe_points = list()
+ safe_points = []
for i in range(1, self.dimrow + 1):
- for j in range(1, self.dimrow + 1):
- if self.kb.ask_if_true(ok_to_move(i, j, self.t)):
- safe_points.append([i, j])
-
+ safe_points.extend(
+ [i, j]
+ for j in range(1, self.dimrow + 1)
+ if self.kb.ask_if_true(ok_to_move(i, j, self.t))
+ )
if self.kb.ask_if_true(percept_glitter(self.t)):
- goals = list()
- goals.append([1, 1])
+ goals = [[1, 1]]
self.plan.append('Grab')
actions = self.plan_route(self.current_position, goals, safe_points)
self.plan.extend(actions)
self.plan.append('Climb')
if len(self.plan) == 0:
- unvisited = list()
+ unvisited = []
for i in range(1, self.dimrow + 1):
for j in range(1, self.dimrow + 1):
- for k in range(self.t):
- if self.kb.ask_if_true(location(i, j, k)):
- unvisited.append([i, j])
- unvisited_and_safe = list()
- for u in unvisited:
- for s in safe_points:
- if u not in unvisited_and_safe and s == u:
- unvisited_and_safe.append(u)
+ unvisited.extend(
+ [i, j]
+ for k in range(self.t)
+ if self.kb.ask_if_true(location(i, j, k))
+ )
+ unvisited_and_safe = []
+ for u, s in itertools.product(unvisited, safe_points):
+ if u not in unvisited_and_safe and s == u:
+ unvisited_and_safe.append(u)
temp = self.plan_route(self.current_position, unvisited_and_safe, safe_points)
self.plan.extend(temp)
if len(self.plan) == 0 and self.kb.ask_if_true(have_arrow(self.t)):
- possible_wumpus = list()
+ possible_wumpus = []
for i in range(1, self.dimrow + 1):
- for j in range(1, self.dimrow + 1):
- if not self.kb.ask_if_true(wumpus(i, j)):
- possible_wumpus.append([i, j])
-
+ possible_wumpus.extend(
+ [i, j]
+ for j in range(1, self.dimrow + 1)
+ if not self.kb.ask_if_true(wumpus(i, j))
+ )
temp = self.plan_shot(self.current_position, possible_wumpus, safe_points)
self.plan.extend(temp)
if len(self.plan) == 0:
- not_unsafe = list()
+ not_unsafe = []
for i in range(1, self.dimrow + 1):
- for j in range(1, self.dimrow + 1):
- if not self.kb.ask_if_true(ok_to_move(i, j, self.t)):
- not_unsafe.append([i, j])
+ not_unsafe.extend(
+ [i, j]
+ for j in range(1, self.dimrow + 1)
+ if not self.kb.ask_if_true(ok_to_move(i, j, self.t))
+ )
temp = self.plan_route(self.current_position, not_unsafe, safe_points)
self.plan.extend(temp)
if len(self.plan) == 0:
- start = list()
- start.append([1, 1])
+ start = [[1, 1]]
temp = self.plan_route(self.current_position, start, safe_points)
self.plan.extend(temp)
self.plan.append('Climb')
@@ -1204,7 +1188,7 @@ def plan_shot(self, current, goals, allowed):
for orientation in orientations:
shooting_positions.remove(WumpusPosition(loc[0], loc[1], orientation))
- actions = list()
+ actions = []
actions.extend(self.plan_route(current, shooting_positions, allowed))
actions.append('Shoot')
return actions
@@ -1225,7 +1209,7 @@ def SAT_plan(init, transition, goal, t_max, SAT_solver=dpll_satisfiable):
# Functions used by SAT_plan
def translate_to_SAT(init, transition, goal, time):
clauses = []
- states = [state for state in transition]
+ states = list(transition)
# Symbol claiming state s at time t
state_counter = itertools.count()
@@ -1324,9 +1308,7 @@ def unify(x, y, s={}):
elif isinstance(x, str) or isinstance(y, str):
return None
elif issequence(x) and issequence(y) and len(x) == len(y):
- if not x:
- return s
- return unify(x[1:], y[1:], unify(x[0], y[0], s))
+ return s if not x else unify(x[1:], y[1:], unify(x[0], y[0], s))
else:
return None
@@ -1398,7 +1380,7 @@ def tell(self, sentence):
if is_definite_clause(sentence):
self.clauses.append(sentence)
else:
- raise Exception("Not a definite clause: {}".format(sentence))
+ raise Exception(f"Not a definite clause: {sentence}")
def ask_generator(self, query):
return fol_bc_ask(self, query)
@@ -1422,8 +1404,7 @@ def fol_fc_ask(KB, alpha):
def enum_subst(p):
query_vars = list({v for clause in p for v in variables(clause)})
for assignment_list in itertools.product(kb_consts, repeat=len(query_vars)):
- theta = {x: y for x, y in zip(query_vars, assignment_list)}
- yield theta
+ yield dict(zip(query_vars, assignment_list))
# check if we can answer without new inferences
for q in KB.clauses:
@@ -1438,7 +1419,7 @@ def enum_subst(p):
for theta in enum_subst(p):
if set(subst(theta, p)).issubset(set(KB.clauses)):
q_ = subst(theta, q)
- if all([unify(x, q_, {}) is None for x in KB.clauses + new]):
+ if all(unify(x, q_, {}) is None for x in KB.clauses + new):
new.append(q_)
phi = unify(q_, alpha, {})
if phi is not None:
@@ -1458,7 +1439,7 @@ def subst(s, x):
if isinstance(x, list):
return [subst(s, xi) for xi in x]
elif isinstance(x, tuple):
- return tuple([subst(s, xi) for xi in x])
+ return tuple(subst(s, xi) for xi in x)
elif not isinstance(x, Expr):
return x
elif is_var_symbol(x.op):
@@ -1476,10 +1457,9 @@ def standardize_variables(sentence, dic=None):
elif is_var_symbol(sentence.op):
if sentence in dic:
return dic[sentence]
- else:
- v = Expr('v_{}'.format(next(standardize_variables.counter)))
- dic[sentence] = v
- return v
+ v = Expr(f'v_{next(standardize_variables.counter)}')
+ dic[sentence] = v
+ return v
else:
return Expr(sentence.op,
*[standardize_variables(a, dic) for a in sentence.args])
@@ -1501,8 +1481,7 @@ def fol_bc_ask(KB, query):
def fol_bc_or(KB, goal, theta):
for rule in KB.fetch_rules_for_goal(goal):
lhs, rhs = parse_definite_clause(standardize_variables(rule))
- for theta1 in fol_bc_and(KB, lhs, unify(rhs, goal, theta)):
- yield theta1
+ yield from fol_bc_and(KB, lhs, unify(rhs, goal, theta))
def fol_bc_and(KB, goals, theta):
@@ -1513,8 +1492,7 @@ def fol_bc_and(KB, goals, theta):
else:
first, rest = goals[0], goals[1:]
for theta1 in fol_bc_or(KB, subst(theta, first), theta):
- for theta2 in fol_bc_and(KB, rest, theta1):
- yield theta2
+ yield from fol_bc_and(KB, rest, theta1)
# ______________________________________________________________________________
@@ -1591,7 +1569,7 @@ def diff(y, x):
elif op == 'log':
return diff(u, x) / u
else:
- raise ValueError("Unknown op: {} in diff({}, {})".format(op, y, x))
+ raise ValueError(f"Unknown op: {op} in diff({y}, {x})")
def simp(x):
@@ -1652,7 +1630,7 @@ def simp(x):
if u == 1:
return 0
else:
- raise ValueError("Unknown op: " + op)
+ raise ValueError(f"Unknown op: {op}")
# If we fall through to here, we can not simplify further
return Expr(op, *args)
diff --git a/making_simple_decision4e.py b/making_simple_decision4e.py
index 4a35f94bd..51ea3edc3 100644
--- a/making_simple_decision4e.py
+++ b/making_simple_decision4e.py
@@ -78,10 +78,7 @@ def cost(self, var):
def vpi_cost_ratio(self, variables):
"""Return the VPI to cost ratio for the given variables"""
- v_by_c = []
- for var in variables:
- v_by_c.append(self.vpi(var) / self.cost(var))
- return v_by_c
+ return [self.vpi(var) / self.cost(var) for var in variables]
def vpi(self, variable):
"""Return VPI for a given variable"""
@@ -119,8 +116,7 @@ def sample(self):
pos = random.choice(self.empty)
# 0N 1E 2S 3W
orient = random.choice(range(4))
- kin_state = pos + (orient,)
- return kin_state
+ return pos + (orient,)
def ray_cast(self, sensor_num, kin_state):
"""Returns distace to nearest obstacle or map boundary in the direction of sensor"""
@@ -162,7 +158,7 @@ def ray_cast(sensor_num, kin_state, m):
W_[i] = 1
for j in range(M):
z_ = ray_cast(j, S_[i], m)
- W_[i] = W_[i] * P_sensor(z[j], z_)
+ W_[i] *= P_sensor(z[j], z_)
S = weighted_sample_with_replacement(N, S_, W_)
return S
diff --git a/mdp.py b/mdp.py
index 1003e26b5..5e4db1689 100644
--- a/mdp.py
+++ b/mdp.py
@@ -34,14 +34,10 @@ def __init__(self, init, actlist, terminals, transitions=None, reward=None, stat
self.init = init
- if isinstance(actlist, list):
+ if isinstance(actlist, (list, dict)):
# if actlist is a list, all states have the same actions
self.actlist = actlist
- elif isinstance(actlist, dict):
- # if actlist is a dict, different actions for each state
- self.actlist = actlist
-
self.terminals = terminals
self.transitions = transitions or {}
if not self.transitions:
@@ -72,17 +68,17 @@ def actions(self, state):
fixed list of actions, except for terminal states. Override this
method if you need to specialize by state."""
- if state in self.terminals:
- return [None]
- else:
- return self.actlist
+ return [None] if state in self.terminals else self.actlist
def get_states_from_transitions(self, transitions):
if isinstance(transitions, dict):
s1 = set(transitions.keys())
- s2 = set(tr[1] for actions in transitions.values()
- for effects in actions.values()
- for tr in effects)
+ s2 = {
+ tr[1]
+ for actions in transitions.values()
+ for effects in actions.values()
+ for tr in effects
+ }
return s1.union(s2)
else:
print('Could not retrieve states from transitions')
@@ -105,9 +101,7 @@ def check_consistency(self):
# check that probability distributions for all actions sum to 1
for s1, actions in self.transitions.items():
for a in actions.keys():
- s = 0
- for o in actions[a]:
- s += o[0]
+ s = sum(o[0] for o in actions[a])
assert abs(s - 1) < 0.001
@@ -120,10 +114,7 @@ def __init__(self, init, actlist, terminals, transitions, reward=None, gamma=0.9
MDP.__init__(self, init, actlist, terminals, transitions, reward, gamma=gamma)
def T(self, state, action):
- if action is None:
- return [(0.0, state)]
- else:
- return self.transitions[state][action]
+ return [(0.0, state)] if action is None else self.transitions[state][action]
class GridMDP(MDP):
@@ -220,10 +211,10 @@ def best_policy(mdp, U):
"""Given an MDP and a utility function U, determine the best policy,
as a mapping from state to action. [Equation 17.4]"""
- pi = {}
- for s in mdp.states:
- pi[s] = max(mdp.actions(s), key=lambda a: expected_utility(a, s, U, mdp))
- return pi
+ return {
+ s: max(mdp.actions(s), key=lambda a: expected_utility(a, s, U, mdp))
+ for s in mdp.states
+ }
def expected_utility(a, s, U, mdp):
@@ -257,7 +248,7 @@ def policy_evaluation(pi, U, mdp, k=20):
utility, using an approximation (modified policy iteration)."""
R, T, gamma = mdp.R, mdp.T, mdp.gamma
- for i in range(k):
+ for _ in range(k):
for s in mdp.states:
U[s] = R(s) + gamma * sum(p * U[s1] for (p, s1) in T(s, pi[s]))
return U
@@ -366,12 +357,8 @@ def max_difference(self, U1, U2):
"""Find maximum difference between two utility mappings"""
for k, v in U1.items():
- sum1 = 0
- for element in U1[k]:
- sum1 += sum(element)
- sum2 = 0
- for element in U2[k]:
- sum2 += sum(element)
+ sum1 = sum(sum(element) for element in U1[k])
+ sum2 = sum(sum(element) for element in U2[k])
return abs(sum1 - sum2)
@@ -384,9 +371,7 @@ def add(A, B):
res = []
for i in range(len(A)):
- row = []
- for j in range(len(A[0])):
- row.append(A[i][j] + B[i][j])
+ row = [A[i][j] + B[i][j] for j in range(len(A[0]))]
res.append(row)
return res
@@ -405,9 +390,7 @@ def multiply(A, B):
matrix = []
for i in range(len(B)):
- row = []
- for j in range(len(B[0])):
- row.append(B[i][j] * A[j][i])
+ row = [B[i][j] * A[j][i] for j in range(len(B[0]))]
matrix.append(row)
return matrix
@@ -436,9 +419,7 @@ def pomdp_value_iteration(pomdp, epsilon=0.1):
values = [val for action in U for val in U[action]]
value_matxs = []
for i in values:
- for j in values:
- value_matxs.append([i, j])
-
+ value_matxs.extend([i, j] for j in values)
U1 = defaultdict(list)
for action in pomdp.actions:
for u in value_matxs:
@@ -451,9 +432,12 @@ def pomdp_value_iteration(pomdp, epsilon=0.1):
U = pomdp.remove_dominated_plans_fast(U1)
# replace with U = pomdp.remove_dominated_plans(U1) for accurate calculations
- if count > 10:
- if pomdp.max_difference(U, prev_U) < epsilon * (1 - pomdp.gamma) / pomdp.gamma:
- return U
+ if (
+ count > 10
+ and pomdp.max_difference(U, prev_U)
+ < epsilon * (1 - pomdp.gamma) / pomdp.gamma
+ ):
+ return U
__doc__ += """
diff --git a/mdp4e.py b/mdp4e.py
index f8871bdc9..853690221 100644
--- a/mdp4e.py
+++ b/mdp4e.py
@@ -34,14 +34,10 @@ def __init__(self, init, actlist, terminals, transitions=None, reward=None, stat
self.init = init
- if isinstance(actlist, list):
+ if isinstance(actlist, (list, dict)):
# if actlist is a list, all states have the same actions
self.actlist = actlist
- elif isinstance(actlist, dict):
- # if actlist is a dict, different actions for each state
- self.actlist = actlist
-
self.terminals = terminals
self.transitions = transitions or {}
if not self.transitions:
@@ -72,17 +68,17 @@ def actions(self, state):
fixed list of actions, except for terminal states. Override this
method if you need to specialize by state."""
- if state in self.terminals:
- return [None]
- else:
- return self.actlist
+ return [None] if state in self.terminals else self.actlist
def get_states_from_transitions(self, transitions):
if isinstance(transitions, dict):
s1 = set(transitions.keys())
- s2 = set(tr[1] for actions in transitions.values()
- for effects in actions.values()
- for tr in effects)
+ s2 = {
+ tr[1]
+ for actions in transitions.values()
+ for effects in actions.values()
+ for tr in effects
+ }
return s1.union(s2)
else:
print('Could not retrieve states from transitions')
@@ -105,9 +101,7 @@ def check_consistency(self):
# check that probability distributions for all actions sum to 1
for s1, actions in self.transitions.items():
for a in actions.keys():
- s = 0
- for o in actions[a]:
- s += o[0]
+ s = sum(o[0] for o in actions[a])
assert abs(s - 1) < 0.001
@@ -120,10 +114,7 @@ def __init__(self, init, actlist, terminals, transitions, reward=None, gamma=0.9
MDP.__init__(self, init, actlist, terminals, transitions, reward, gamma=gamma)
def T(self, state, action):
- if action is None:
- return [(0.0, state)]
- else:
- return self.transitions[state][action]
+ return [(0.0, state)] if action is None else self.transitions[state][action]
class GridMDP(MDP):
@@ -204,10 +195,9 @@ def to_arrows(self, policy):
def q_value(mdp, s, a, U):
if not a:
return mdp.R(s)
- res = 0
- for p, s_prime in mdp.T(s, a):
- res += p * (mdp.R(s) + mdp.gamma * U[s_prime])
- return res
+ return sum(
+ p * (mdp.R(s) + mdp.gamma * U[s_prime]) for p, s_prime in mdp.T(s, a)
+ )
# TODO: DDN in figure 16.4 and 16.5
@@ -242,10 +232,10 @@ def best_policy(mdp, U):
"""Given an MDP and a utility function U, determine the best policy,
as a mapping from state to action."""
- pi = {}
- for s in mdp.states:
- pi[s] = max(mdp.actions(s), key=lambda a: q_value(mdp, s, a, U))
- return pi
+ return {
+ s: max(mdp.actions(s), key=lambda a: q_value(mdp, s, a, U))
+ for s in mdp.states
+ }
def expected_utility(a, s, U, mdp):
@@ -277,7 +267,7 @@ def policy_evaluation(pi, U, mdp, k=20):
utility, using an approximation (modified policy iteration)."""
R, T, gamma = mdp.R, mdp.T, mdp.gamma
- for i in range(k):
+ for _ in range(k):
for s in mdp.states:
U[s] = R(s) + gamma * sum(p * U[s1] for (p, s1) in T(s, pi[s]))
return U
@@ -390,12 +380,8 @@ def max_difference(self, U1, U2):
"""Find maximum difference between two utility mappings"""
for k, v in U1.items():
- sum1 = 0
- for element in U1[k]:
- sum1 += sum(element)
- sum2 = 0
- for element in U2[k]:
- sum2 += sum(element)
+ sum1 = sum(sum(element) for element in U1[k])
+ sum2 = sum(sum(element) for element in U2[k])
return abs(sum1 - sum2)
@@ -408,9 +394,7 @@ def add(A, B):
res = []
for i in range(len(A)):
- row = []
- for j in range(len(A[0])):
- row.append(A[i][j] + B[i][j])
+ row = [A[i][j] + B[i][j] for j in range(len(A[0]))]
res.append(row)
return res
@@ -429,9 +413,7 @@ def multiply(A, B):
matrix = []
for i in range(len(B)):
- row = []
- for j in range(len(B[0])):
- row.append(B[i][j] * A[j][i])
+ row = [B[i][j] * A[j][i] for j in range(len(B[0]))]
matrix.append(row)
return matrix
@@ -460,9 +442,7 @@ def pomdp_value_iteration(pomdp, epsilon=0.1):
values = [val for action in U for val in U[action]]
value_matxs = []
for i in values:
- for j in values:
- value_matxs.append([i, j])
-
+ value_matxs.extend([i, j] for j in values)
U1 = defaultdict(list)
for action in pomdp.actions:
for u in value_matxs:
@@ -475,9 +455,12 @@ def pomdp_value_iteration(pomdp, epsilon=0.1):
U = pomdp.remove_dominated_plans_fast(U1)
# replace with U = pomdp.remove_dominated_plans(U1) for accurate calculations
- if count > 10:
- if pomdp.max_difference(U, prev_U) < epsilon * (1 - pomdp.gamma) / pomdp.gamma:
- return U
+ if (
+ count > 10
+ and pomdp.max_difference(U, prev_U)
+ < epsilon * (1 - pomdp.gamma) / pomdp.gamma
+ ):
+ return U
__doc__ += """
diff --git a/nlp.py b/nlp.py
index 03aabf54b..bbfdc27a6 100644
--- a/nlp.py
+++ b/nlp.py
@@ -55,9 +55,7 @@ def cnf_rules(self):
X -> Y Z"""
cnf = []
for X, rules in self.rules.items():
- for (Y, Z) in rules:
- cnf.append((X, Y, Z))
-
+ cnf.extend((X, Y, Z) for Y, Z in rules)
return cnf
def generate_random(self, S='S'):
@@ -77,7 +75,7 @@ def rewrite(tokens, into):
return ' '.join(rewrite(S.split(), []))
def __repr__(self):
- return ''.format(self.name)
+ return f''
def ProbRules(**rules):
@@ -142,9 +140,7 @@ def cnf_rules(self):
X -> Y Z [p]"""
cnf = []
for X, rules in self.rules.items():
- for (Y, Z), p in rules:
- cnf.append((X, Y, Z, p))
-
+ cnf.extend((X, Y, Z, p) for (Y, Z), p in rules)
return cnf
def generate_random(self, S='S'):
@@ -170,7 +166,7 @@ def rewrite(tokens, into):
return (' '.join(rewritten_as), prob)
def __repr__(self):
- return ''.format(self.name)
+ return f''
E0 = Grammar('E0',
@@ -309,7 +305,7 @@ def parses(self, words, S='S'):
def parse(self, words, S='S'):
"""Parse a list of words; according to the grammar.
Leave results in the chart."""
- self.chart = [[] for i in range(len(words) + 1)]
+ self.chart = [[] for _ in range(len(words) + 1)]
self.add_edge([0, 0, 'S_', [], [S]])
for i in range(len(words)):
self.scanner(i, words[i])
@@ -321,7 +317,7 @@ def add_edge(self, edge):
if edge not in self.chart[end]:
self.chart[end].append(edge)
if self.trace:
- print('Chart: added {}'.format(edge))
+ print(f'Chart: added {edge}')
if not expects:
self.extender(edge)
else:
@@ -406,10 +402,7 @@ def loadPageHTML(addressList):
def initPages(addressList):
"""Create a dictionary of pages from a list of URL addresses"""
- pages = {}
- for addr in addressList:
- pages[addr] = Page(addr)
- return pages
+ return {addr: Page(addr) for addr in addressList}
def stripRawHTML(raw_html):
@@ -443,7 +436,7 @@ def onlyWikipediaURLS(urls):
"""Some example HTML page data is from wikipedia. This function converts
relative wikipedia links to full wikipedia URLs"""
wikiURLs = [url for url in urls if url.startswith('/wiki/')]
- return ["https://en.wikipedia.org" + url for url in wikiURLs]
+ return [f"https://en.wikipedia.org{url}" for url in wikiURLs]
# ______________________________________________________________________________
@@ -468,13 +461,14 @@ def expand_pages(pages):
def relevant_pages(query):
"""Relevant pages are pages that contain all of the query words. They are obtained by
intersecting the hit lists of the query words."""
- hit_intersection = {addr for addr in pagesIndex}
+ hit_intersection = set(pagesIndex)
query_words = query.split()
for query_word in query_words:
- hit_list = set()
- for addr in pagesIndex:
- if query_word.lower() in pagesContent[addr].lower():
- hit_list.add(addr)
+ hit_list = {
+ addr
+ for addr in pagesIndex
+ if query_word.lower() in pagesContent[addr].lower()
+ }
hit_intersection = hit_intersection.intersection(hit_list)
return {addr: pagesIndex[addr] for addr in hit_intersection}
@@ -517,8 +511,8 @@ def detect(self):
if len(self.hub_history) > 2: # prevent list from getting long
del self.hub_history[0]
del self.auth_history[0]
- self.hub_history.append([x for x in curr_hubs])
- self.auth_history.append([x for x in curr_auths])
+ self.hub_history.append(list(curr_hubs))
+ self.auth_history.append(list(curr_auths))
return False
diff --git a/nlp4e.py b/nlp4e.py
index 095f54357..9fdb565f1 100644
--- a/nlp4e.py
+++ b/nlp4e.py
@@ -57,9 +57,7 @@ def cnf_rules(self):
X -> Y Z"""
cnf = []
for X, rules in self.rules.items():
- for (Y, Z) in rules:
- cnf.append((X, Y, Z))
-
+ cnf.extend((X, Y, Z) for Y, Z in rules)
return cnf
def generate_random(self, S='S'):
@@ -79,7 +77,7 @@ def rewrite(tokens, into):
return ' '.join(rewrite(S.split(), []))
def __repr__(self):
- return ''.format(self.name)
+ return f''
def ProbRules(**rules):
@@ -144,9 +142,7 @@ def cnf_rules(self):
X -> Y Z [p]"""
cnf = []
for X, rules in self.rules.items():
- for (Y, Z), p in rules:
- cnf.append((X, Y, Z, p))
-
+ cnf.extend((X, Y, Z, p) for (Y, Z), p in rules)
return cnf
def generate_random(self, S='S'):
@@ -171,7 +167,7 @@ def rewrite(tokens, into):
return (' '.join(rewritten_as), prob)
def __repr__(self):
- return ''.format(self.name)
+ return f''
E0 = Grammar('E0',
@@ -310,7 +306,7 @@ def parses(self, words, S='S'):
def parse(self, words, S='S'):
"""Parse a list of words; according to the grammar.
Leave results in the chart."""
- self.chart = [[] for i in range(len(words) + 1)]
+ self.chart = [[] for _ in range(len(words) + 1)]
self.add_edge([0, 0, 'S_', [], [S]])
for i in range(len(words)):
self.scanner(i, words[i])
@@ -322,7 +318,7 @@ def add_edge(self, edge):
if edge not in self.chart[end]:
self.chart[end].append(edge)
if self.trace:
- print('Chart: added {}'.format(edge))
+ print(f'Chart: added {edge}')
if not expects:
self.extender(edge)
else:
@@ -357,7 +353,7 @@ def extender(self, edge):
class Tree:
def __init__(self, root, *args):
self.root = root
- self.leaves = [leaf for leaf in args]
+ self.leaves = list(args)
def CYK_parse(words, grammar):
@@ -427,8 +423,10 @@ def actions(self, state):
for end in range(start, len(state) + 1):
# try combinations between (start, end)
articles = ' '.join(state[start:end])
- for c in self.combinations[articles]:
- actions.append(state[:start] + [c] + state[end:])
+ actions.extend(
+ state[:start] + [c] + state[end:]
+ for c in self.combinations[articles]
+ )
return actions
def result(self, state, action):
diff --git a/notebook.py b/notebook.py
index 7f0306335..4497f2cc7 100644
--- a/notebook.py
+++ b/notebook.py
@@ -27,11 +27,11 @@ def pseudocode(algorithm):
from IPython.display import Markdown
algorithm = algorithm.replace(' ', '-')
- url = "https://raw.githubusercontent.com/aimacode/aima-pseudocode/master/md/{}.md".format(algorithm)
+ url = f"https://raw.githubusercontent.com/aimacode/aima-pseudocode/master/md/{algorithm}.md"
f = urlopen(url)
md = f.read().decode('utf-8')
md = md.split('\n', 1)[-1].strip()
- md = '#' + md
+ md = f'#{md}'
return Markdown(md)
@@ -110,14 +110,13 @@ def load_MNIST(path="aima-data/MNIST/Digits", fashion=False):
plt.rcParams['image.interpolation'] = 'nearest'
plt.rcParams['image.cmap'] = 'gray'
- train_img_file = open(os.path.join(path, "train-images-idx3-ubyte"), "rb")
- train_lbl_file = open(os.path.join(path, "train-labels-idx1-ubyte"), "rb")
- test_img_file = open(os.path.join(path, "t10k-images-idx3-ubyte"), "rb")
- test_lbl_file = open(os.path.join(path, 't10k-labels-idx1-ubyte'), "rb")
+ with open(os.path.join(path, "train-images-idx3-ubyte"), "rb") as train_img_file:
+ train_lbl_file = open(os.path.join(path, "train-labels-idx1-ubyte"), "rb")
+ test_img_file = open(os.path.join(path, "t10k-images-idx3-ubyte"), "rb")
+ test_lbl_file = open(os.path.join(path, 't10k-labels-idx1-ubyte'), "rb")
- magic_nr, tr_size, tr_rows, tr_cols = struct.unpack(">IIII", train_img_file.read(16))
- tr_img = array.array("B", train_img_file.read())
- train_img_file.close()
+ magic_nr, tr_size, tr_rows, tr_cols = struct.unpack(">IIII", train_img_file.read(16))
+ tr_img = array.array("B", train_img_file.read())
magic_nr, tr_size = struct.unpack(">II", train_lbl_file.read(8))
tr_lbl = array.array("b", train_lbl_file.read())
train_lbl_file.close()
@@ -153,11 +152,7 @@ def load_MNIST(path="aima-data/MNIST/Digits", fashion=False):
def show_MNIST(labels, images, samples=8, fashion=False):
- if not fashion:
- classes = digit_classes
- else:
- classes = fashion_classes
-
+ classes = digit_classes if not fashion else fashion_classes
num_classes = len(classes)
for y, cls in enumerate(classes):
@@ -385,7 +380,7 @@ def __init__(self, varname, player_1='human', player_2='random',
width=300, height=350, cid=None):
valid_players = ('human', 'random', 'alpha_beta')
if player_1 not in valid_players or player_2 not in valid_players:
- raise TypeError("Players must be one of {}".format(valid_players))
+ raise TypeError(f"Players must be one of {valid_players}")
super().__init__(varname, width, height, cid)
self.ttt = TicTacToe()
self.state = self.ttt.initial
@@ -439,21 +434,37 @@ def draw_board(self):
if utility == 0:
self.text_n('Game Draw!', offset, 6 / 7 + offset)
else:
- self.text_n('Player {} wins!'.format("XO"[utility < 0]), offset, 6 / 7 + offset)
+ self.text_n(f'Player {"XO"[utility < 0]} wins!', offset, 6 / 7 + offset)
# Find the 3 and draw a line
self.stroke([255, 0][self.turn], [0, 255][self.turn], 0)
for i in range(3):
- if all([(i + 1, j + 1) in self.state.board for j in range(3)]) and \
- len({self.state.board[(i + 1, j + 1)] for j in range(3)}) == 1:
+ if (
+ all((i + 1, j + 1) in self.state.board for j in range(3))
+ and len(
+ {self.state.board[(i + 1, j + 1)] for j in range(3)}
+ )
+ == 1
+ ):
self.line_n(i / 3 + 1 / 6, offset * 6 / 7, i / 3 + 1 / 6, (1 - offset) * 6 / 7)
- if all([(j + 1, i + 1) in self.state.board for j in range(3)]) and \
- len({self.state.board[(j + 1, i + 1)] for j in range(3)}) == 1:
+ if (
+ all((j + 1, i + 1) in self.state.board for j in range(3))
+ and len(
+ {self.state.board[(j + 1, i + 1)] for j in range(3)}
+ )
+ == 1
+ ):
self.line_n(offset, (i / 3 + 1 / 6) * 6 / 7, 1 - offset, (i / 3 + 1 / 6) * 6 / 7)
- if all([(i + 1, i + 1) in self.state.board for i in range(3)]) and \
- len({self.state.board[(i + 1, i + 1)] for i in range(3)}) == 1:
+ if (
+ all((i + 1, i + 1) in self.state.board for i in range(3))
+ and len({self.state.board[(i + 1, i + 1)] for i in range(3)})
+ == 1
+ ):
self.line_n(offset, offset * 6 / 7, 1 - offset, (1 - offset) * 6 / 7)
- if all([(i + 1, 3 - i) in self.state.board for i in range(3)]) and \
- len({self.state.board[(i + 1, 3 - i)] for i in range(3)}) == 1:
+ if (
+ all((i + 1, 3 - i) in self.state.board for i in range(3))
+ and len({self.state.board[(i + 1, 3 - i)] for i in range(3)})
+ == 1
+ ):
self.line_n(offset, (1 - offset) * 6 / 7, 1 - offset, offset * 6 / 7)
# restart button
self.fill(0, 0, 255)
@@ -461,8 +472,11 @@ def draw_board(self):
self.fill(0, 0, 0)
self.text_n('Restart', 0.5 + 2 * offset, 13 / 14)
else: # Print which player's turn it is
- self.text_n("Player {}'s move({})".format("XO"[self.turn], self.players[self.turn]),
- offset, 6 / 7 + offset)
+ self.text_n(
+ f"""Player {"XO"[self.turn]}'s move({self.players[self.turn]})""",
+ offset,
+ 6 / 7 + offset,
+ )
self.update()
@@ -484,7 +498,7 @@ class Canvas_min_max(Canvas):
def __init__(self, varname, util_list, width=800, height=600, cid=None):
super.__init__(varname, width, height, cid)
- self.utils = {node: util for node, util in zip(range(13, 40), util_list)}
+ self.utils = dict(zip(range(13, 40), util_list))
self.game = Fig52Extended()
self.game.utils = self.utils
self.nodes = list(range(40))
@@ -498,7 +512,7 @@ def __init__(self, varname, util_list, width=800, height=600, cid=None):
self.l / 2 + (self.l + (1 - 5 * self.l) / 3) * i)
self.font("12px Arial")
self.node_stack = []
- self.explored = {node for node in self.utils}
+ self.explored = set(self.utils)
self.thick_lines = set()
self.change_list = []
self.draw_graph()
@@ -609,7 +623,7 @@ class Canvas_alpha_beta(Canvas):
def __init__(self, varname, util_list, width=800, height=600, cid=None):
super().__init__(varname, width, height, cid)
- self.utils = {node: util for node, util in zip(range(13, 40), util_list)}
+ self.utils = dict(zip(range(13, 40), util_list))
self.game = Fig52Extended()
self.game.utils = self.utils
self.nodes = list(range(40))
@@ -623,7 +637,7 @@ def __init__(self, varname, util_list, width=800, height=600, cid=None):
3 * self.l / 2 + (self.l + (1 - 6 * self.l) / 3) * i)
self.font("12px Arial")
self.node_stack = []
- self.explored = {node for node in self.utils}
+ self.explored = set(self.utils)
self.pruned = set()
self.ab = {}
self.thick_lines = set()
@@ -784,7 +798,7 @@ def __init__(self, varname, kb, query, width=800, height=600, cid=None):
self.l = 1 / 20
self.b = 3 * self.l
bc_out = list(self.fol_bc_ask())
- if len(bc_out) == 0:
+ if not bc_out:
self.valid = False
else:
self.valid = True
@@ -1020,23 +1034,25 @@ def slider_callback(iteration):
pass
def visualize_callback(visualize):
- if visualize is True:
- button.value = False
+ if visualize is not True:
+ return
+ button.value = False
- problem = GraphProblem(start_dropdown.value, end_dropdown.value, romania_map)
- global all_node_colors
+ problem = GraphProblem(start_dropdown.value, end_dropdown.value, romania_map)
+ global all_node_colors
- user_algorithm = algorithm[algo_dropdown.value]
+ user_algorithm = algorithm[algo_dropdown.value]
- iterations, all_node_colors, node = user_algorithm(problem)
- solution = node.solution()
- all_node_colors.append(final_path_colors(all_node_colors[0], problem, solution))
+ iterations, all_node_colors, node = user_algorithm(problem)
+ solution = node.solution()
+ all_node_colors.append(final_path_colors(all_node_colors[0], problem, solution))
- slider.max = len(all_node_colors) - 1
+ slider.max = len(all_node_colors) - 1
+
+ for i in range(slider.max + 1):
+ slider.value = i
+ # time.sleep(.5)
- for i in range(slider.max + 1):
- slider.value = i
- # time.sleep(.5)
start_dropdown = widgets.Dropdown(description="Start city: ",
options=sorted(list(node_colors.keys())), value="Arad")
@@ -1064,7 +1080,7 @@ def plot_NQueens(solution):
im = np.array(im).astype(np.float) / 255
fig = plt.figure(figsize=(7, 7))
ax = fig.add_subplot(111)
- ax.set_title('{} Queens'.format(n))
+ ax.set_title(f'{n} Queens')
plt.imshow(board, cmap='binary', interpolation='nearest')
# NQueensCSP gives a solution as a dictionary
if isinstance(solution, dict):
@@ -1096,8 +1112,7 @@ def heatmap(grid, cmap='binary', interpolation='nearest'):
def gaussian_kernel(l=5, sig=1.0):
ax = np.arange(-l // 2 + 1., l // 2 + 1.)
xx, yy = np.meshgrid(ax, ax)
- kernel = np.exp(-(xx ** 2 + yy ** 2) / (2. * sig ** 2))
- return kernel
+ return np.exp(-(xx ** 2 + yy ** 2) / (2. * sig ** 2))
# Plots utility function for a POMDP
diff --git a/notebook4e.py b/notebook4e.py
index 5b03081c6..a3593b660 100644
--- a/notebook4e.py
+++ b/notebook4e.py
@@ -28,11 +28,11 @@ def pseudocode(algorithm):
from IPython.display import Markdown
algorithm = algorithm.replace(' ', '-')
- url = "https://raw.githubusercontent.com/aimacode/aima-pseudocode/master/md/{}.md".format(algorithm)
+ url = f"https://raw.githubusercontent.com/aimacode/aima-pseudocode/master/md/{algorithm}.md"
f = urlopen(url)
md = f.read().decode('utf-8')
md = md.split('\n', 1)[-1].strip()
- md = '#' + md
+ md = f'#{md}'
return Markdown(md)
@@ -146,14 +146,13 @@ def load_MNIST(path="aima-data/MNIST/Digits", fashion=False):
plt.rcParams['image.interpolation'] = 'nearest'
plt.rcParams['image.cmap'] = 'gray'
- train_img_file = open(os.path.join(path, "train-images-idx3-ubyte"), "rb")
- train_lbl_file = open(os.path.join(path, "train-labels-idx1-ubyte"), "rb")
- test_img_file = open(os.path.join(path, "t10k-images-idx3-ubyte"), "rb")
- test_lbl_file = open(os.path.join(path, 't10k-labels-idx1-ubyte'), "rb")
+ with open(os.path.join(path, "train-images-idx3-ubyte"), "rb") as train_img_file:
+ train_lbl_file = open(os.path.join(path, "train-labels-idx1-ubyte"), "rb")
+ test_img_file = open(os.path.join(path, "t10k-images-idx3-ubyte"), "rb")
+ test_lbl_file = open(os.path.join(path, 't10k-labels-idx1-ubyte'), "rb")
- magic_nr, tr_size, tr_rows, tr_cols = struct.unpack(">IIII", train_img_file.read(16))
- tr_img = array.array("B", train_img_file.read())
- train_img_file.close()
+ magic_nr, tr_size, tr_rows, tr_cols = struct.unpack(">IIII", train_img_file.read(16))
+ tr_img = array.array("B", train_img_file.read())
magic_nr, tr_size = struct.unpack(">II", train_lbl_file.read(8))
tr_lbl = array.array("b", train_lbl_file.read())
train_lbl_file.close()
@@ -189,11 +188,7 @@ def load_MNIST(path="aima-data/MNIST/Digits", fashion=False):
def show_MNIST(labels, images, samples=8, fashion=False):
- if not fashion:
- classes = digit_classes
- else:
- classes = fashion_classes
-
+ classes = digit_classes if not fashion else fashion_classes
num_classes = len(classes)
for y, cls in enumerate(classes):
@@ -421,7 +416,7 @@ def __init__(self, varname, player_1='human', player_2='random',
width=300, height=350, cid=None):
valid_players = ('human', 'random', 'alpha_beta')
if player_1 not in valid_players or player_2 not in valid_players:
- raise TypeError("Players must be one of {}".format(valid_players))
+ raise TypeError(f"Players must be one of {valid_players}")
super().__init__(varname, width, height, cid)
self.ttt = TicTacToe()
self.state = self.ttt.initial
@@ -475,21 +470,37 @@ def draw_board(self):
if utility == 0:
self.text_n('Game Draw!', offset, 6 / 7 + offset)
else:
- self.text_n('Player {} wins!'.format("XO"[utility < 0]), offset, 6 / 7 + offset)
+ self.text_n(f'Player {"XO"[utility < 0]} wins!', offset, 6 / 7 + offset)
# Find the 3 and draw a line
self.stroke([255, 0][self.turn], [0, 255][self.turn], 0)
for i in range(3):
- if all([(i + 1, j + 1) in self.state.board for j in range(3)]) and \
- len({self.state.board[(i + 1, j + 1)] for j in range(3)}) == 1:
+ if (
+ all((i + 1, j + 1) in self.state.board for j in range(3))
+ and len(
+ {self.state.board[(i + 1, j + 1)] for j in range(3)}
+ )
+ == 1
+ ):
self.line_n(i / 3 + 1 / 6, offset * 6 / 7, i / 3 + 1 / 6, (1 - offset) * 6 / 7)
- if all([(j + 1, i + 1) in self.state.board for j in range(3)]) and \
- len({self.state.board[(j + 1, i + 1)] for j in range(3)}) == 1:
+ if (
+ all((j + 1, i + 1) in self.state.board for j in range(3))
+ and len(
+ {self.state.board[(j + 1, i + 1)] for j in range(3)}
+ )
+ == 1
+ ):
self.line_n(offset, (i / 3 + 1 / 6) * 6 / 7, 1 - offset, (i / 3 + 1 / 6) * 6 / 7)
- if all([(i + 1, i + 1) in self.state.board for i in range(3)]) and \
- len({self.state.board[(i + 1, i + 1)] for i in range(3)}) == 1:
+ if (
+ all((i + 1, i + 1) in self.state.board for i in range(3))
+ and len({self.state.board[(i + 1, i + 1)] for i in range(3)})
+ == 1
+ ):
self.line_n(offset, offset * 6 / 7, 1 - offset, (1 - offset) * 6 / 7)
- if all([(i + 1, 3 - i) in self.state.board for i in range(3)]) and \
- len({self.state.board[(i + 1, 3 - i)] for i in range(3)}) == 1:
+ if (
+ all((i + 1, 3 - i) in self.state.board for i in range(3))
+ and len({self.state.board[(i + 1, 3 - i)] for i in range(3)})
+ == 1
+ ):
self.line_n(offset, (1 - offset) * 6 / 7, 1 - offset, offset * 6 / 7)
# restart button
self.fill(0, 0, 255)
@@ -497,8 +508,11 @@ def draw_board(self):
self.fill(0, 0, 0)
self.text_n('Restart', 0.5 + 2 * offset, 13 / 14)
else: # Print which player's turn it is
- self.text_n("Player {}'s move({})".format("XO"[self.turn], self.players[self.turn]),
- offset, 6 / 7 + offset)
+ self.text_n(
+ f"""Player {"XO"[self.turn]}'s move({self.players[self.turn]})""",
+ offset,
+ 6 / 7 + offset,
+ )
self.update()
@@ -520,7 +534,7 @@ class Canvas_min_max(Canvas):
def __init__(self, varname, util_list, width=800, height=600, cid=None):
super().__init__(varname, width, height, cid)
- self.utils = {node: util for node, util in zip(range(13, 40), util_list)}
+ self.utils = dict(zip(range(13, 40), util_list))
self.game = Fig52Extended()
self.game.utils = self.utils
self.nodes = list(range(40))
@@ -534,7 +548,7 @@ def __init__(self, varname, util_list, width=800, height=600, cid=None):
self.l / 2 + (self.l + (1 - 5 * self.l) / 3) * i)
self.font("12px Arial")
self.node_stack = []
- self.explored = {node for node in self.utils}
+ self.explored = set(self.utils)
self.thick_lines = set()
self.change_list = []
self.draw_graph()
@@ -645,7 +659,7 @@ class Canvas_alpha_beta(Canvas):
def __init__(self, varname, util_list, width=800, height=600, cid=None):
super().__init__(varname, width, height, cid)
- self.utils = {node: util for node, util in zip(range(13, 40), util_list)}
+ self.utils = dict(zip(range(13, 40), util_list))
self.game = Fig52Extended()
self.game.utils = self.utils
self.nodes = list(range(40))
@@ -659,7 +673,7 @@ def __init__(self, varname, util_list, width=800, height=600, cid=None):
3 * self.l / 2 + (self.l + (1 - 6 * self.l) / 3) * i)
self.font("12px Arial")
self.node_stack = []
- self.explored = {node for node in self.utils}
+ self.explored = set(self.utils)
self.pruned = set()
self.ab = {}
self.thick_lines = set()
@@ -820,7 +834,7 @@ def __init__(self, varname, kb, query, width=800, height=600, cid=None):
self.l = 1 / 20
self.b = 3 * self.l
bc_out = list(self.fol_bc_ask())
- if len(bc_out) == 0:
+ if not bc_out:
self.valid = False
else:
self.valid = True
@@ -1056,23 +1070,25 @@ def slider_callback(iteration):
pass
def visualize_callback(visualize):
- if visualize is True:
- button.value = False
+ if visualize is not True:
+ return
+ button.value = False
- problem = GraphProblem(start_dropdown.value, end_dropdown.value, romania_map)
- global all_node_colors
+ problem = GraphProblem(start_dropdown.value, end_dropdown.value, romania_map)
+ global all_node_colors
- user_algorithm = algorithm[algo_dropdown.value]
+ user_algorithm = algorithm[algo_dropdown.value]
- iterations, all_node_colors, node = user_algorithm(problem)
- solution = node.solution()
- all_node_colors.append(final_path_colors(all_node_colors[0], problem, solution))
+ iterations, all_node_colors, node = user_algorithm(problem)
+ solution = node.solution()
+ all_node_colors.append(final_path_colors(all_node_colors[0], problem, solution))
- slider.max = len(all_node_colors) - 1
+ slider.max = len(all_node_colors) - 1
+
+ for i in range(slider.max + 1):
+ slider.value = i
+ # time.sleep(.5)
- for i in range(slider.max + 1):
- slider.value = i
- # time.sleep(.5)
start_dropdown = widgets.Dropdown(description="Start city: ",
options=sorted(list(node_colors.keys())), value="Arad")
@@ -1100,7 +1116,7 @@ def plot_NQueens(solution):
im = np.array(im).astype(np.float) / 255
fig = plt.figure(figsize=(7, 7))
ax = fig.add_subplot(111)
- ax.set_title('{} Queens'.format(n))
+ ax.set_title(f'{n} Queens')
plt.imshow(board, cmap='binary', interpolation='nearest')
# NQueensCSP gives a solution as a dictionary
if isinstance(solution, dict):
@@ -1132,8 +1148,7 @@ def heatmap(grid, cmap='binary', interpolation='nearest'):
def gaussian_kernel(l=5, sig=1.0):
ax = np.arange(-l // 2 + 1., l // 2 + 1.)
xx, yy = np.meshgrid(ax, ax)
- kernel = np.exp(-(xx ** 2 + yy ** 2) / (2. * sig ** 2))
- return kernel
+ return np.exp(-(xx ** 2 + yy ** 2) / (2. * sig ** 2))
# Plots utility function for a POMDP
diff --git a/notebooks/chapter19/images/autoencoder.png b/notebooks/chapter19/images/autoencoder.png
index cd216e9f7..e571b6376 100644
Binary files a/notebooks/chapter19/images/autoencoder.png and b/notebooks/chapter19/images/autoencoder.png differ
diff --git a/notebooks/chapter19/images/backprop.png b/notebooks/chapter19/images/backprop.png
index 8d53530e6..656d40d88 100644
Binary files a/notebooks/chapter19/images/backprop.png and b/notebooks/chapter19/images/backprop.png differ
diff --git a/notebooks/chapter19/images/corss_entropy_plot.png b/notebooks/chapter19/images/corss_entropy_plot.png
index 8212405e7..778acf175 100644
Binary files a/notebooks/chapter19/images/corss_entropy_plot.png and b/notebooks/chapter19/images/corss_entropy_plot.png differ
diff --git a/notebooks/chapter19/images/mse_plot.png b/notebooks/chapter19/images/mse_plot.png
index fd58f9db9..67ff58bf8 100644
Binary files a/notebooks/chapter19/images/mse_plot.png and b/notebooks/chapter19/images/mse_plot.png differ
diff --git a/notebooks/chapter19/images/nn.png b/notebooks/chapter19/images/nn.png
index 673b9338b..68f3feb04 100644
Binary files a/notebooks/chapter19/images/nn.png and b/notebooks/chapter19/images/nn.png differ
diff --git a/notebooks/chapter19/images/nn_steps.png b/notebooks/chapter19/images/nn_steps.png
index 4a596133b..06c5b576f 100644
Binary files a/notebooks/chapter19/images/nn_steps.png and b/notebooks/chapter19/images/nn_steps.png differ
diff --git a/notebooks/chapter19/images/perceptron.png b/notebooks/chapter19/images/perceptron.png
index 68d2a258a..2d3b4da84 100644
Binary files a/notebooks/chapter19/images/perceptron.png and b/notebooks/chapter19/images/perceptron.png differ
diff --git a/notebooks/chapter19/images/rnn_connections.png b/notebooks/chapter19/images/rnn_connections.png
index c72d459b8..f2b5e9728 100644
Binary files a/notebooks/chapter19/images/rnn_connections.png and b/notebooks/chapter19/images/rnn_connections.png differ
diff --git a/notebooks/chapter19/images/rnn_unit.png b/notebooks/chapter19/images/rnn_unit.png
index e4ebabf2b..cb37cfd72 100644
Binary files a/notebooks/chapter19/images/rnn_unit.png and b/notebooks/chapter19/images/rnn_unit.png differ
diff --git a/notebooks/chapter19/images/rnn_units.png b/notebooks/chapter19/images/rnn_units.png
index 5724f5d46..dd106c939 100644
Binary files a/notebooks/chapter19/images/rnn_units.png and b/notebooks/chapter19/images/rnn_units.png differ
diff --git a/notebooks/chapter19/images/vanilla.png b/notebooks/chapter19/images/vanilla.png
index db7a45f9a..f18331b75 100644
Binary files a/notebooks/chapter19/images/vanilla.png and b/notebooks/chapter19/images/vanilla.png differ
diff --git a/notebooks/chapter22/images/parse_tree.png b/notebooks/chapter22/images/parse_tree.png
index f6ca87b2f..6a36da2ed 100644
Binary files a/notebooks/chapter22/images/parse_tree.png and b/notebooks/chapter22/images/parse_tree.png differ
diff --git a/notebooks/chapter24/images/RCNN.png b/notebooks/chapter24/images/RCNN.png
index 273021fbe..2a61aba78 100644
Binary files a/notebooks/chapter24/images/RCNN.png and b/notebooks/chapter24/images/RCNN.png differ
diff --git a/notebooks/chapter24/images/derivative_of_gaussian.png b/notebooks/chapter24/images/derivative_of_gaussian.png
index 0be575529..3581620af 100644
Binary files a/notebooks/chapter24/images/derivative_of_gaussian.png and b/notebooks/chapter24/images/derivative_of_gaussian.png differ
diff --git a/notebooks/chapter24/images/gradients.png b/notebooks/chapter24/images/gradients.png
index ae57bdf3b..6d296e475 100644
Binary files a/notebooks/chapter24/images/gradients.png and b/notebooks/chapter24/images/gradients.png differ
diff --git a/notebooks/chapter24/images/laplacian_kernels.png b/notebooks/chapter24/images/laplacian_kernels.png
index faca3321c..f3c97eb5e 100644
Binary files a/notebooks/chapter24/images/laplacian_kernels.png and b/notebooks/chapter24/images/laplacian_kernels.png differ
diff --git a/notebooks/chapter24/images/stapler.png b/notebooks/chapter24/images/stapler.png
index e550d83f9..67d0650dd 100644
Binary files a/notebooks/chapter24/images/stapler.png and b/notebooks/chapter24/images/stapler.png differ
diff --git a/notebooks/chapter24/images/stapler_bbox.png b/notebooks/chapter24/images/stapler_bbox.png
index c5a7c7af0..b9f5d6547 100644
Binary files a/notebooks/chapter24/images/stapler_bbox.png and b/notebooks/chapter24/images/stapler_bbox.png differ
diff --git a/perception4e.py b/perception4e.py
index edd556607..2aa36da43 100644
--- a/perception4e.py
+++ b/perception4e.py
@@ -39,8 +39,7 @@ def gradient_edge_detector(image):
# convolution between filter and image to get edges
y_edges = scipy.signal.convolve2d(image, x_filter, 'same')
x_edges = scipy.signal.convolve2d(image, y_filter, 'same')
- edges = array_normalization(x_edges + y_edges, 0, 255)
- return edges
+ return array_normalization(x_edges + y_edges, 0, 255)
def gaussian_derivative_edge_detector(image):
@@ -54,8 +53,7 @@ def gaussian_derivative_edge_detector(image):
# extract edges using convolution
y_edges = scipy.signal.convolve2d(image, x_filter, 'same')
x_edges = scipy.signal.convolve2d(image, y_filter, 'same')
- edges = array_normalization(x_edges + y_edges, 0, 255)
- return edges
+ return array_normalization(x_edges + y_edges, 0, 255)
def laplacian_edge_detector(image):
@@ -169,12 +167,7 @@ def group_contour_detection(image, cluster_num=2):
ret, label, center = cv2.kmeans(Z, K, None, criteria, 10, cv2.KMEANS_RANDOM_CENTERS)
center = np.uint8(center)
res = center[label.flatten()]
- res2 = res.reshape(img.shape)
- # show the image
- # cv2.imshow('res2', res2)
- # cv2.waitKey(0)
- # cv2.destroyAllWindows()
- return res2
+ return res.reshape(img.shape)
def image_to_graph(image):
@@ -233,7 +226,7 @@ def bfs(self, s, t, parent):
queue.append(node)
visited.append(node)
parent.append((u, node))
- return True if t in visited else False
+ return t in visited
def min_cut(self, source, sink):
"""Find the minimum cut of the graph between source and sink"""
@@ -256,9 +249,12 @@ def min_cut(self, source, sink):
parent = []
res = []
for i in self.flow:
- for j in self.flow[i]:
- if self.flow[i][j] == 0 and generate_edge_weight(self.image, i, j) > 0:
- res.append((i, j))
+ res.extend(
+ (i, j)
+ for j in self.flow[i]
+ if self.flow[i][j] == 0
+ and generate_edge_weight(self.image, i, j) > 0
+ )
return res
diff --git a/planning.py b/planning.py
index 1e4a19209..d1eaa3413 100644
--- a/planning.py
+++ b/planning.py
@@ -30,10 +30,7 @@ def __init__(self, initial, goals, actions, domain=None):
def convert(self, clauses):
"""Converts strings into exprs"""
if not isinstance(clauses, Expr):
- if len(clauses) > 0:
- clauses = expr(clauses)
- else:
- clauses = []
+ clauses = expr(clauses) if len(clauses) > 0 else []
try:
clauses = conjuncts(clauses)
except AttributeError:
@@ -42,7 +39,7 @@ def convert(self, clauses):
new_clauses = []
for clause in clauses:
if clause.op == '~':
- new_clauses.append(expr('Not' + str(clause.args[0])))
+ new_clauses.append(expr(f'Not{str(clause.args[0])}'))
else:
new_clauses.append(clause)
return new_clauses
@@ -56,9 +53,11 @@ def expand_fluents(self, name=None):
if action.precond:
for fests in set(action.precond).union(action.effect).difference(self.convert(action.domain)):
if fests.op[:3] != 'Not':
- kb.tell(expr(str(action.domain) + ' ==> ' + str(fests)))
+ kb.tell(expr(f'{str(action.domain)} ==> {str(fests)}'))
- objects = set(arg for clause in set(self.initial + self.goals) for arg in clause.args)
+ objects = {
+ arg for clause in set(self.initial + self.goals) for arg in clause.args
+ }
fluent_list = []
if name is not None:
for fluent in self.initial + self.goals:
@@ -89,9 +88,9 @@ def expand_actions(self, name=None):
kb = FolKB(self.initial)
for action in self.actions:
if action.precond:
- kb.tell(expr(str(action.domain) + ' ==> ' + str(action)))
+ kb.tell(expr(f'{str(action.domain)} ==> {str(action)}'))
- objects = set(arg for clause in self.initial for arg in clause.args)
+ objects = {arg for clause in self.initial for arg in clause.args}
expansions = []
action_list = []
if name is not None:
@@ -159,9 +158,9 @@ def act(self, action):
args = action.args
list_action = first(a for a in self.actions if a.name == action_name)
if list_action is None:
- raise Exception("Action '{}' not found".format(action_name))
+ raise Exception(f"Action '{action_name}' not found")
if not list_action.check_precond(self.initial, args):
- raise Exception("Action '{}' pre-conditions not satisfied".format(action))
+ raise Exception(f"Action '{action}' pre-conditions not satisfied")
self.initial = list_action(self.initial, args).clauses
@@ -191,7 +190,7 @@ def __call__(self, kb, args):
return self.act(kb, args)
def __repr__(self):
- return '{}'.format(Expr(self.name, *self.args))
+ return f'{Expr(self.name, *self.args)}'
def convert(self, clauses):
"""Converts strings into Exprs"""
@@ -199,7 +198,7 @@ def convert(self, clauses):
clauses = conjuncts(clauses)
for i in range(len(clauses)):
if clauses[i].op == '~':
- clauses[i] = expr('Not' + str(clauses[i].args[0]))
+ clauses[i] = expr(f'Not{str(clauses[i].args[0])}')
elif isinstance(clauses, str):
clauses = clauses.replace('~', 'Not')
@@ -235,10 +234,9 @@ def check_precond(self, kb, args):
if isinstance(kb, list):
kb = FolKB(kb)
- for clause in self.precond:
- if self.substitute(clause, args) not in kb.clauses:
- return False
- return True
+ return all(
+ self.substitute(clause, args) in kb.clauses for clause in self.precond
+ )
def act(self, kb, args):
"""Executes the action on the state's knowledge base"""
@@ -253,24 +251,18 @@ def act(self, kb, args):
if clause.op[:3] == 'Not':
new_clause = Expr(clause.op[3:], *clause.args)
- if kb.ask(self.substitute(new_clause, args)) is not False:
- kb.retract(self.substitute(new_clause, args))
else:
- new_clause = Expr('Not' + clause.op, *clause.args)
-
- if kb.ask(self.substitute(new_clause, args)) is not False:
- kb.retract(self.substitute(new_clause, args))
+ new_clause = Expr(f'Not{clause.op}', *clause.args)
+ if kb.ask(self.substitute(new_clause, args)) is not False:
+ kb.retract(self.substitute(new_clause, args))
return kb
def goal_test(goals, state):
"""Generic goal testing helper function"""
- if isinstance(state, list):
- kb = FolKB(state)
- else:
- kb = state
+ kb = FolKB(state) if isinstance(state, list) else state
return all(kb.ask(q) is not False for q in goals)
@@ -616,15 +608,24 @@ def actions(self, subgoal):
"""
def negate_clause(clause):
- return Expr(clause.op.replace('Not', ''), *clause.args) if clause.op[:3] == 'Not' else Expr(
- 'Not' + clause.op, *clause.args)
+ return (
+ Expr(clause.op.replace('Not', ''), *clause.args)
+ if clause.op[:3] == 'Not'
+ else Expr(f'Not{clause.op}', *clause.args)
+ )
subgoal = conjuncts(subgoal)
- return [action for action in self.expanded_actions if
- (any(prop in action.effect for prop in subgoal) and
- not any(negate_clause(prop) in subgoal for prop in action.effect) and
- not any(negate_clause(prop) in subgoal and negate_clause(prop) not in action.effect
- for prop in action.precond))]
+ return [
+ action
+ for action in self.expanded_actions
+ if any(prop in action.effect for prop in subgoal)
+ and all(negate_clause(prop) not in subgoal for prop in action.effect)
+ and not any(
+ negate_clause(prop) in subgoal
+ and negate_clause(prop) not in action.effect
+ for prop in action.precond
+ )
+ ]
def result(self, subgoal, action):
# g' = (g - effects(a)) + preconds(a)
@@ -657,7 +658,7 @@ def CSPlan(planning_problem, solution_length, CSP_solver=ac_search_solver, arc_h
def st(var, stage):
"""Returns a string for the var-stage pair that can be used as a variable"""
- return str(var) + "_" + str(stage)
+ return f"{str(var)}_{str(stage)}"
def if_(v1, v2):
"""If the second argument is v2, the first argument must be v1"""
@@ -665,7 +666,7 @@ def if_(v1, v2):
def if_fun(x1, x2):
return x1 == v1 if x2 == v2 else True
- if_fun.__name__ = "if the second argument is " + str(v2) + " then the first argument is " + str(v1) + " "
+ if_fun.__name__ = f"if the second argument is {str(v2)} then the first argument is {str(v1)} "
return if_fun
def eq_if_not_in_(actset):
@@ -674,7 +675,7 @@ def eq_if_not_in_(actset):
def eq_if_not_in(x1, a, x2):
return x1 == x2 if a not in actset else True
- eq_if_not_in.__name__ = "first and third arguments are equal if action is not in " + str(actset) + " "
+ eq_if_not_in.__name__ = f"first and third arguments are equal if action is not in {str(actset)} "
return eq_if_not_in
expanded_actions = planning_problem.expand_actions()
@@ -684,39 +685,72 @@ def eq_if_not_in(x1, a, x2):
domains = {av: list(map(lambda action: expr(str(action)), expanded_actions)) for av in act_vars}
domains.update({st(var, stage): {True, False} for var in fluent_values for stage in range(horizon + 2)})
# initial state constraints
- constraints = [Constraint((st(var, 0),), is_constraint(val))
- for (var, val) in {expr(str(fluent).replace('Not', '')):
- True if fluent.op[:3] != 'Not' else False
- for fluent in planning_problem.initial}.items()]
+ constraints = [
+ Constraint((st(var, 0),), is_constraint(val))
+ for (var, val) in {
+ expr(str(fluent).replace('Not', '')): fluent.op[:3] != 'Not'
+ for fluent in planning_problem.initial
+ }.items()
+ ]
constraints += [Constraint((st(var, 0),), is_constraint(False))
for var in {expr(str(fluent).replace('Not', ''))
for fluent in fluent_values if fluent not in planning_problem.initial}]
# goal state constraints
- constraints += [Constraint((st(var, horizon + 1),), is_constraint(val))
- for (var, val) in {expr(str(fluent).replace('Not', '')):
- True if fluent.op[:3] != 'Not' else False
- for fluent in planning_problem.goals}.items()]
+ constraints += [
+ Constraint((st(var, horizon + 1),), is_constraint(val))
+ for (var, val) in {
+ expr(str(fluent).replace('Not', '')): fluent.op[:3] != 'Not'
+ for fluent in planning_problem.goals
+ }.items()
+ ]
# precondition constraints
- constraints += [Constraint((st(var, stage), st('action', stage)), if_(val, act))
- # st(var, stage) == val if st('action', stage) == act
- for act, strps in {expr(str(action)): action for action in expanded_actions}.items()
- for var, val in {expr(str(fluent).replace('Not', '')):
- True if fluent.op[:3] != 'Not' else False
- for fluent in strps.precond}.items()
- for stage in range(horizon + 1)]
+ constraints += [
+ Constraint((st(var, stage), st('action', stage)), if_(val, act))
+ for act, strps in {
+ expr(str(action)): action for action in expanded_actions
+ }.items()
+ for var, val in {
+ expr(str(fluent).replace('Not', '')): fluent.op[:3] != 'Not'
+ for fluent in strps.precond
+ }.items()
+ for stage in range(horizon + 1)
+ ]
# effect constraints
- constraints += [Constraint((st(var, stage + 1), st('action', stage)), if_(val, act))
- # st(var, stage + 1) == val if st('action', stage) == act
- for act, strps in {expr(str(action)): action for action in expanded_actions}.items()
- for var, val in {expr(str(fluent).replace('Not', '')): True if fluent.op[:3] != 'Not' else False
- for fluent in strps.effect}.items()
- for stage in range(horizon + 1)]
+ constraints += [
+ Constraint(
+ (st(var, stage + 1), st('action', stage)), if_(val, act)
+ )
+ for act, strps in {
+ expr(str(action)): action for action in expanded_actions
+ }.items()
+ for var, val in {
+ expr(str(fluent).replace('Not', '')): fluent.op[:3] != 'Not'
+ for fluent in strps.effect
+ }.items()
+ for stage in range(horizon + 1)
+ ]
# frame constraints
- constraints += [Constraint((st(var, stage), st('action', stage), st(var, stage + 1)),
- eq_if_not_in_(set(map(lambda action: expr(str(action)),
- {act for act in expanded_actions if var in act.effect
- or Expr('Not' + var.op, *var.args) in act.effect}))))
- for var in fluent_values for stage in range(horizon + 1)]
+ constraints += [
+ Constraint(
+ (st(var, stage), st('action', stage), st(var, stage + 1)),
+ eq_if_not_in_(
+ set(
+ map(
+ lambda action: expr(str(action)),
+ {
+ act
+ for act in expanded_actions
+ if var in act.effect
+ or Expr(f'Not{var.op}', *var.args)
+ in act.effect
+ },
+ )
+ )
+ ),
+ )
+ for var in fluent_values
+ for stage in range(horizon + 1)
+ ]
csp = NaryCSP(domains, constraints)
sol = CSP_solver(csp, arc_heuristic=arc_heuristic)
if sol:
@@ -834,7 +868,7 @@ def build(self, actions, objects):
"""Populates the lists and dictionaries containing the state action dependencies"""
for clause in self.current_state:
- p_expr = Expr('P' + clause.op, *clause.args)
+ p_expr = Expr(f'P{clause.op}', *clause.args)
self.current_action_links[p_expr] = [clause]
self.next_action_links[p_expr] = [clause]
self.current_state_links[clause] = [p_expr]
@@ -890,7 +924,7 @@ def __init__(self, planning_problem):
self.planning_problem = planning_problem
self.kb = FolKB(planning_problem.initial)
self.levels = [Level(self.kb)]
- self.objects = set(arg for clause in self.kb.clauses for arg in clause.args)
+ self.objects = {arg for clause in self.kb.clauses for arg in clause.args}
def __call__(self):
self.expand_graph()
@@ -906,10 +940,7 @@ def non_mutex_goals(self, goals, index):
"""Checks whether the goals are mutually exclusive"""
goal_perm = itertools.combinations(goals, 2)
- for g in goal_perm:
- if set(g) in self.levels[index].mutex:
- return False
- return True
+ return all(set(g) not in self.levels[index].mutex for g in goal_perm)
class GraphPlan:
@@ -942,11 +973,7 @@ def extract_solution(self, goals, index):
level = self.graph.levels[index - 1]
- # Create all combinations of actions that satisfy the goal
- actions = []
- for goal in goals:
- actions.append(level.next_state_links[goal])
-
+ actions = [level.next_state_links[goal] for goal in goals]
all_actions = list(itertools.product(*actions))
# Filter out non-mutex actions
@@ -981,10 +1008,7 @@ def extract_solution(self, goals, index):
for item in self.solution:
if item[1] == -1:
solution.append([])
- solution[-1].append(item[0])
- else:
- solution[-1].append(item[0])
-
+ solution[-1].append(item[0])
for num, item in enumerate(solution):
item.reverse()
solution[num] = item
@@ -1001,8 +1025,9 @@ def execute(self):
self.graph.expand_graph()
if (self.goal_test(self.graph.levels[-1].kb) and self.graph.non_mutex_goals(
self.graph.planning_problem.goals, -1)):
- solution = self.extract_solution(self.graph.planning_problem.goals, -1)
- if solution:
+ if solution := self.extract_solution(
+ self.graph.planning_problem.goals, -1
+ ):
return solution
if len(self.graph.levels) >= 2 and self.check_leveloff():
@@ -1019,10 +1044,11 @@ def filter(self, solution):
new_solution = []
for section in solution[0]:
- new_section = []
- for operation in section:
- if not (operation.op[0] == 'P' and operation.op[1].isupper()):
- new_section.append(operation)
+ new_section = [
+ operation
+ for operation in section
+ if not (operation.op[0] == 'P' and operation.op[1].isupper())
+ ]
new_solution.append(new_section)
return new_solution
@@ -1053,9 +1079,7 @@ def execute(self):
planning_problem = self.planning_problem
for level in filtered_solution:
level_solution, planning_problem = self.orderlevel(level, planning_problem)
- for element in level_solution:
- ordered_solution.append(element)
-
+ ordered_solution.extend(iter(level_solution))
return ordered_solution
@@ -1064,10 +1088,11 @@ def linearize(solution):
linear_solution = []
for section in solution[0]:
- for operation in section:
- if not (operation.op[0] == 'P' and operation.op[1].isupper()):
- linear_solution.append(operation)
-
+ linear_solution.extend(
+ operation
+ for operation in section
+ if operation.op[0] != 'P' or not operation.op[1].isupper()
+ )
return linear_solution
@@ -1100,11 +1125,8 @@ def __init__(self, planning_problem):
self.causal_links = []
self.start = Action('Start', [], self.planning_problem.initial)
self.finish = Action('Finish', self.planning_problem.goals, [])
- self.actions = set()
- self.actions.add(self.start)
- self.actions.add(self.finish)
- self.constraints = set()
- self.constraints.add((self.start, self.finish))
+ self.actions = {self.start, self.finish}
+ self.constraints = {(self.start, self.finish)}
self.agenda = set()
for precond in self.finish.precond:
self.agenda.add((precond, self.finish))
@@ -1113,8 +1135,8 @@ def __init__(self, planning_problem):
def find_open_precondition(self):
"""Find open precondition with the least number of possible actions"""
- number_of_ways = dict()
- actions_for_precondition = dict()
+ number_of_ways = {}
+ actions_for_precondition = {}
for element in self.agenda:
open_precondition = element[0]
possible_actions = list(self.actions) + self.expanded_actions
@@ -1130,16 +1152,14 @@ def find_open_precondition(self):
number = sorted(number_of_ways, key=number_of_ways.__getitem__)
- for k, v in number_of_ways.items():
+ for v in number_of_ways.values():
if v == 0:
return None, None, None
- act1 = None
- for element in self.agenda:
- if element[0] == number[0]:
- act1 = element[1]
- break
-
+ act1 = next(
+ (element[1] for element in self.agenda if element[0] == number[0]),
+ None,
+ )
if number[0] in self.expanded_actions:
self.expanded_actions.remove(number[0])
@@ -1183,27 +1203,24 @@ def generate_expr(self, clause, bindings):
def generate_action_object(self, action, bindings):
"""Generate action object given a generic action and variable bindings"""
- # if bindings is 0, it means the action already exists in self.actions
if bindings == 0:
return action
- # bindings cannot be None
- else:
- new_expr = self.generate_expr(action, bindings)
- new_preconds = []
- for precond in action.precond:
- new_precond = self.generate_expr(precond, bindings)
- new_preconds.append(new_precond)
- new_effects = []
- for effect in action.effect:
- new_effect = self.generate_expr(effect, bindings)
- new_effects.append(new_effect)
- return Action(new_expr, new_preconds, new_effects)
+ new_expr = self.generate_expr(action, bindings)
+ new_preconds = []
+ for precond in action.precond:
+ new_precond = self.generate_expr(precond, bindings)
+ new_preconds.append(new_precond)
+ new_effects = []
+ for effect in action.effect:
+ new_effect = self.generate_expr(effect, bindings)
+ new_effects.append(new_effect)
+ return Action(new_expr, new_preconds, new_effects)
def cyclic(self, graph):
"""Check cyclicity of a directed graph"""
- new_graph = dict()
+ new_graph = {}
for element in graph:
if element[0] in new_graph:
new_graph[element[0]].append(element[1])
@@ -1232,27 +1249,22 @@ def add_const(self, constraint, constraints):
new_constraints = set(constraints)
new_constraints.add(constraint)
- if self.cyclic(new_constraints):
- return constraints
- return new_constraints
+ return constraints if self.cyclic(new_constraints) else new_constraints
def is_a_threat(self, precondition, effect):
"""Check if effect is a threat to precondition"""
- if (str(effect.op) == 'Not' + str(precondition.op)) or ('Not' + str(effect.op) == str(precondition.op)):
- if effect.args == precondition.args:
- return True
- return False
+ return (
+ str(effect.op) == f'Not{str(precondition.op)}'
+ or f'Not{str(effect.op)}' == str(precondition.op)
+ ) and effect.args == precondition.args
def protect(self, causal_link, action, constraints):
"""Check and resolve threats by promotion or demotion"""
- threat = False
- for effect in action.effect:
- if self.is_a_threat(causal_link[1], effect):
- threat = True
- break
-
+ threat = any(
+ self.is_a_threat(causal_link[1], effect) for effect in action.effect
+ )
if action != causal_link[0] and action != causal_link[2] and threat:
# try promotion
new_constraints = set(constraints)
@@ -1274,13 +1286,11 @@ def protect(self, causal_link, action, constraints):
def convert(self, constraints):
"""Convert constraints into a dict of Action to set orderings"""
- graph = dict()
+ graph = {}
for constraint in constraints:
- if constraint[0] in graph:
- graph[constraint[0]].add(constraint[1])
- else:
+ if constraint[0] not in graph:
graph[constraint[0]] = set()
- graph[constraint[0]].add(constraint[1])
+ graph[constraint[0]].add(constraint[1])
return graph
def toposort(self, graph):
@@ -1298,7 +1308,11 @@ def toposort(self, graph):
graph.update({element: set() for element in extra_elements_in_dependencies})
while True:
- ordered = set(element for element, dependency in graph.items() if len(dependency) == 0)
+ ordered = {
+ element
+ for element, dependency in graph.items()
+ if len(dependency) == 0
+ }
if not ordered:
break
yield ordered
@@ -1445,12 +1459,13 @@ def do_action(self, job_order, available_resources, kb, args):
resource checks, and ensures the actions are executed in the correct order.
"""
if not self.has_usable_resource(available_resources):
- raise Exception('Not enough usable resources to execute {}'.format(self.name))
+ raise Exception(f'Not enough usable resources to execute {self.name}')
if not self.has_consumable_resource(available_resources):
- raise Exception('Not enough consumable resources to execute {}'.format(self.name))
+ raise Exception(f'Not enough consumable resources to execute {self.name}')
if not self.inorder(job_order):
- raise Exception("Can't execute {} - execute prerequisite actions first".
- format(self.name))
+ raise Exception(
+ f"Can't execute {self.name} - execute prerequisite actions first"
+ )
kb = super().act(kb, args) # update knowledge base
for resource in self.consumes: # remove consumed resources
available_resources[resource] -= self.consumes[resource]
@@ -1520,7 +1535,7 @@ def act(self, action):
args = action.args
list_action = first(a for a in self.actions if a.name == action.name)
if list_action is None:
- raise Exception("Action '{}' not found".format(action.name))
+ raise Exception(f"Action '{action.name}' not found")
self.initial = list_action.do_action(self.jobs, self.resources, self.initial, args).clauses
def refinements(self, library): # refinements may be (multiple) HLA themselves ...
@@ -1589,20 +1604,20 @@ def hierarchical_search(self, hierarchy):
prefix = plan.action[:index]
outcome = RealWorldPlanningProblem(
RealWorldPlanningProblem.result(self.initial, prefix), self.goals, self.actions)
- suffix = plan.action[index + 1:]
- if not hla: # hla is None and plan is primitive
- if outcome.goal_test():
- return plan.action
- else:
+ if hla:
+ suffix = plan.action[index + 1:]
for sequence in RealWorldPlanningProblem.refinements(hla, hierarchy): # find refinements
frontier.append(Node(outcome.initial, plan, prefix + sequence + suffix))
- def result(state, actions):
+ elif outcome.goal_test():
+ return plan.action
+
+ def result(self, actions):
"""The outcome of applying an action to the current problem"""
for a in actions:
- if a.check_precond(state, a.args):
- state = a(state, a.args).clauses
- return state
+ if a.check_precond(self, a.args):
+ self = a(self, a.args).clauses
+ return self
def angelic_search(self, hierarchy, initial_plan):
"""
@@ -1630,7 +1645,7 @@ def angelic_search(self, hierarchy, initial_plan):
pes_reachable_set = RealWorldPlanningProblem.reach_pes(self.initial, plan)
if self.intersects_goal(opt_reachable_set):
if RealWorldPlanningProblem.is_primitive(plan, hierarchy):
- return [x for x in plan.action]
+ return list(plan.action)
guaranteed = self.intersects_goal(pes_reachable_set)
if guaranteed and RealWorldPlanningProblem.making_progress(plan, initial_plan):
final_state = guaranteed[0] # any element of guaranteed
@@ -1653,68 +1668,70 @@ def intersects_goal(self, reachable_set):
for y in reachable_set[x]
if all(goal in y for goal in self.goals)]
- def is_primitive(plan, library):
+ def is_primitive(self, library):
"""
checks if the hla is primitive action
"""
- for hla in plan.action:
+ for hla in self.action:
indices = [i for i, x in enumerate(library['HLA']) if expr(x).op == hla.name]
for i in indices:
if library["steps"][i]:
return False
return True
- def reach_opt(init, plan):
+ def reach_opt(self, plan):
"""
Finds the optimistic reachable set of the sequence of actions in plan
"""
- reachable_set = {0: [init]}
+ reachable_set = {0: [self]}
optimistic_description = plan.action # list of angelic actions with optimistic description
return RealWorldPlanningProblem.find_reachable_set(reachable_set, optimistic_description)
- def reach_pes(init, plan):
+ def reach_pes(self, plan):
"""
Finds the pessimistic reachable set of the sequence of actions in plan
"""
- reachable_set = {0: [init]}
+ reachable_set = {0: [self]}
pessimistic_description = plan.action_pes # list of angelic actions with pessimistic description
return RealWorldPlanningProblem.find_reachable_set(reachable_set, pessimistic_description)
- def find_reachable_set(reachable_set, action_description):
+ def find_reachable_set(self, action_description):
"""
Finds the reachable states of the action_description when applied in each state of reachable set.
"""
for i in range(len(action_description)):
- reachable_set[i + 1] = []
+ self[i + 1] = []
if type(action_description[i]) is AngelicHLA:
possible_actions = action_description[i].angelic_action()
else:
possible_actions = action_description
for action in possible_actions:
- for state in reachable_set[i]:
+ for state in self[i]:
if action.check_precond(state, action.args):
if action.effect[0]:
new_state = action(state, action.args).clauses
- reachable_set[i + 1].append(new_state)
+ self[i + 1].append(new_state)
else:
- reachable_set[i + 1].append(state)
- return reachable_set
+ self[i + 1].append(state)
+ return self
- def find_hla(plan, hierarchy):
+ def find_hla(self, hierarchy):
"""
Finds the the first HLA action in plan.action, which is not primitive
and its corresponding index in plan.action
"""
hla = None
- index = len(plan.action)
- for i in range(len(plan.action)): # find the first HLA in plan, that is not primitive
- if not RealWorldPlanningProblem.is_primitive(Node(plan.state, plan.parent, [plan.action[i]]), hierarchy):
- hla = plan.action[i]
+ index = len(self.action)
+ for i in range(len(self.action)): # find the first HLA in plan, that is not primitive
+ if not RealWorldPlanningProblem.is_primitive(
+ Node(self.state, self.parent, [self.action[i]]), hierarchy
+ ):
+ hla = self.action[i]
index = i
break
return hla, index
- def making_progress(plan, initial_plan):
+ def making_progress(self, initial_plan):
"""
Prevents from infinite regression of refinements
@@ -1722,12 +1739,9 @@ def making_progress(plan, initial_plan):
its pessimistic reachable set intersects the goal inside a call to decompose on
the same plan, in the same circumstances)
"""
- for i in range(len(initial_plan)):
- if plan == initial_plan[i]:
- return False
- return True
+ return all(self != initial_plan[i] for i in range(len(initial_plan)))
- def decompose(hierarchy, plan, s_f, reachable_set):
+ def decompose(self, plan, s_f, reachable_set):
solution = []
i = max(reachable_set.keys())
while plan.action_pes:
@@ -1736,29 +1750,38 @@ def decompose(hierarchy, plan, s_f, reachable_set):
return solution
s_i = RealWorldPlanningProblem.find_previous_state(s_f, reachable_set, i, action)
problem = RealWorldPlanningProblem(s_i, s_f, plan.action)
- angelic_call = RealWorldPlanningProblem.angelic_search(problem, hierarchy,
- [AngelicNode(s_i, Node(None), [action], [action])])
- if angelic_call:
- for x in angelic_call:
- solution.insert(0, x)
- else:
+ if not (
+ angelic_call := RealWorldPlanningProblem.angelic_search(
+ problem,
+ self,
+ [AngelicNode(s_i, Node(None), [action], [action])],
+ )
+ ):
return None
+ for x in angelic_call:
+ solution.insert(0, x)
s_f = s_i
i -= 1
return solution
- def find_previous_state(s_f, reachable_set, i, action):
+ def find_previous_state(self, reachable_set, i, action):
"""
Given a final state s_f and an action finds a state s_i in reachable_set
such that when action is applied to state s_i returns s_f.
"""
- s_i = reachable_set[i - 1][0]
- for state in reachable_set[i - 1]:
- if s_f in [x for x in RealWorldPlanningProblem.reach_pes(
- state, AngelicNode(state, None, [action], [action]))[1]]:
- s_i = state
- break
- return s_i
+ return next(
+ (
+ state
+ for state in reachable_set[i - 1]
+ if self
+ in list(
+ RealWorldPlanningProblem.reach_pes(
+ state, AngelicNode(state, None, [action], [action])
+ )[1]
+ )
+ ),
+ reachable_set[i - 1][0],
+ )
def job_shop_problem():
@@ -1881,13 +1904,12 @@ def convert(self, clauses):
if isinstance(clauses, Expr):
clauses = conjuncts(clauses)
- for i in range(len(clauses)):
- for ch in lib.keys():
- if clauses[i].op == ch:
- clauses[i] = expr(lib[ch] + str(clauses[i].args[0]))
+ for i, ch in itertools.product(range(len(clauses)), lib):
+ if clauses[i].op == ch:
+ clauses[i] = expr(lib[ch] + str(clauses[i].args[0]))
elif isinstance(clauses, str):
- for ch in lib.keys():
+ for ch in lib:
clauses = clauses.replace(ch, lib[ch])
if len(clauses) > 0:
clauses = expr(clauses)
@@ -1943,32 +1965,33 @@ def angelic_action(self):
for i in it:
if effects[i]:
if clause.args:
- effects[i] = expr(str(effects[i]) + '&' + str(
- Expr(clause.op[w:], clause.args[0]))) # make changes in the ith part of effects
+ effects[i] = expr(
+ f'{str(effects[i])}&{str(Expr(clause.op[w:], clause.args[0]))}'
+ )
if n == 3:
effects[i + len(effects) // 3] = expr(
- str(effects[i + len(effects) // 3]) + '&' + str(Expr(clause.op[6:], clause.args[0])))
+ f'{str(effects[i + len(effects) // 3])}&{str(Expr(clause.op[6:], clause.args[0]))}'
+ )
else:
- effects[i] = expr(
- str(effects[i]) + '&' + str(expr(clause.op[w:]))) # make changes in the ith part of effects
+ effects[i] = expr(f'{str(effects[i])}&{str(expr(clause.op[w:]))}')
if n == 3:
effects[i + len(effects) // 3] = expr(
- str(effects[i + len(effects) // 3]) + '&' + str(expr(clause.op[6:])))
+ f'{str(effects[i + len(effects) // 3])}&{str(expr(clause.op[6:]))}'
+ )
- else:
- if clause.args:
- effects[i] = Expr(clause.op[w:], clause.args[0]) # make changes in the ith part of effects
- if n == 3:
- effects[i + len(effects) // 3] = Expr(clause.op[6:], clause.args[0])
+ elif clause.args:
+ effects[i] = Expr(clause.op[w:], clause.args[0]) # make changes in the ith part of effects
+ if n == 3:
+ effects[i + len(effects) // 3] = Expr(clause.op[6:], clause.args[0])
- else:
- effects[i] = expr(clause.op[w:]) # make changes in the ith part of effects
- if n == 3:
- effects[i + len(effects) // 3] = expr(clause.op[6:])
+ else:
+ effects[i] = expr(clause.op[w:]) # make changes in the ith part of effects
+ if n == 3:
+ effects[i + len(effects) // 3] = expr(clause.op[6:])
return [HLA(Expr(self.name, self.args), self.precond, effects[i]) for i in range(len(effects))]
- def compute_parameters(clause):
+ def compute_parameters(self):
"""
computes n,w
@@ -1982,14 +2005,14 @@ def compute_parameters(clause):
n = 3, if effect is possibly add or remove
"""
- if clause.op[:9] == 'PosYesNot':
+ if self.op[:9] == 'PosYesNot':
# possibly add/remove variable: three possible effects for the variable
n = 3
w = 9
- elif clause.op[:6] == 'PosYes': # possibly add variable: two possible effects for the variable
+ elif self.op[:6] == 'PosYes': # possibly add variable: two possible effects for the variable
n = 2
w = 6
- elif clause.op[:6] == 'PosNot': # possibly remove variable: two possible effects for the variable
+ elif self.op[:6] == 'PosNot': # possibly remove variable: two possible effects for the variable
n = 2
w = 3 # We want to keep 'Not' from 'PosNot' when adding action
else: # variable or ~variable
diff --git a/probabilistic_learning.py b/probabilistic_learning.py
index 1138e702d..7d4329333 100644
--- a/probabilistic_learning.py
+++ b/probabilistic_learning.py
@@ -82,7 +82,7 @@ def NaiveBayesSimple(distribution):
The input dictionary is in the following form:
(ClassName, ClassProb): CountingProbDist
"""
- target_dist = {c_name: prob for c_name, prob in distribution.keys()}
+ target_dist = dict(distribution.keys())
attr_dists = {c_name: count_prob for (c_name, _), count_prob in distribution.items()}
def predict(example):
diff --git a/probability.py b/probability.py
index e1e77d224..ce1d1b030 100644
--- a/probability.py
+++ b/probability.py
@@ -75,7 +75,7 @@ def show_approx(self, numfmt='{:.3g}'):
return ', '.join([('{}: ' + numfmt).format(v, p) for (v, p) in sorted(self.prob.items())])
def __repr__(self):
- return "P({})".format(self.var_name)
+ return f"P({self.var_name})"
class JointProbDist(ProbDist):
@@ -112,7 +112,7 @@ def values(self, var):
return self.vals[var]
def __repr__(self):
- return "P({})".format(self.variables)
+ return f"P({self.variables})"
def event_values(event, variables):
@@ -125,7 +125,7 @@ def event_values(event, variables):
if isinstance(event, tuple) and len(event) == len(variables):
return event
else:
- return tuple([event[var] for var in variables])
+ return tuple(event[var] for var in variables)
# ______________________________________________________________________________
@@ -155,7 +155,7 @@ def enumerate_joint(variables, e, P):
if not variables:
return P[e]
Y, rest = variables[0], variables[1:]
- return sum([enumerate_joint(rest, extend(e, Y, y), P) for y in P.values(Y)])
+ return sum(enumerate_joint(rest, extend(e, Y, y), P) for y in P.values(Y))
# ______________________________________________________________________________
@@ -190,7 +190,7 @@ def variable_node(self, var):
for n in self.nodes:
if n.variable == var:
return n
- raise Exception("No such variable: {}".format(var))
+ raise Exception(f"No such variable: {var}")
def variable_values(self, var):
"""Return the domain of var."""
@@ -272,10 +272,7 @@ def cost(self, var):
def vpi_cost_ratio(self, variables):
"""Return the VPI to cost ratio for the given variables"""
- v_by_c = []
- for var in variables:
- v_by_c.append(self.vpi(var) / self.cost(var))
- return v_by_c
+ return [self.vpi(var) / self.cost(var) for var in variables]
def vpi(self, variable):
"""Return VPI for a given variable"""
@@ -543,7 +540,7 @@ def rejection_sampling(X, e, bn, N=10000):
'False: 0.7, True: 0.3'
"""
counts = {x: 0 for x in bn.variable_values(X)} # bold N in [Figure 14.14]
- for j in range(N):
+ for _ in range(N):
sample = prior_sample(bn) # boldface x in [Figure 14.14]
if consistent_with(sample, e):
counts[sample[X]] += 1
@@ -569,7 +566,7 @@ def likelihood_weighting(X, e, bn, N=10000):
'False: 0.702, True: 0.298'
"""
W = {x: 0 for x in bn.variable_values(X)}
- for j in range(N):
+ for _ in range(N):
sample, weight = weighted_sample(bn, e) # boldface x, w in [Figure 14.15]
W[sample[X]] += weight
return ProbDist(X, W)
@@ -603,7 +600,7 @@ def gibbs_ask(X, e, bn, N=1000):
state = dict(e) # boldface x in [Figure 14.16]
for Zi in Z:
state[Zi] = random.choice(bn.variable_values(Zi))
- for j in range(N):
+ for _ in range(N):
for Zi in Z:
state[Zi] = markov_blanket_sample(Zi, state, bn)
counts[state[X]] += 1
@@ -637,10 +634,7 @@ def __init__(self, transition_model, sensor_model, prior=None):
self.prior = prior or [0.5, 0.5]
def sensor_dist(self, ev):
- if ev is True:
- return self.sensor_model[0]
- else:
- return self.sensor_model[1]
+ return self.sensor_model[0] if ev is True else self.sensor_model[1]
def forward(HMM, fv, ev):
@@ -720,7 +714,7 @@ def viterbi(HMM, ev):
for i in range(t - 1, -1, -1):
ml_probabilities[i] = m[i][i_max]
- ml_path[i] = True if i_max == 0 else False
+ ml_path[i] = i_max == 0
if i > 0:
i_max = backtracking_graph[i - 1][i_max]
@@ -819,8 +813,7 @@ def sample(self):
pos = random.choice(self.empty)
# 0N 1E 2S 3W
orient = random.choice(range(4))
- kin_state = pos + (orient,)
- return kin_state
+ return pos + (orient,)
def ray_cast(self, sensor_num, kin_state):
"""Returns distance to nearest obstacle or map boundary in the direction of sensor"""
@@ -864,7 +857,7 @@ def ray_cast(sensor_num, kin_state, m):
W_[i] = 1
for j in range(M):
z_ = ray_cast(j, S_[i], m)
- W_[i] = W_[i] * P_sensor(z[j], z_)
+ W_[i] *= P_sensor(z[j], z_)
S = weighted_sample_with_replacement(N, S_, W_)
return S
diff --git a/probability4e.py b/probability4e.py
index d413a55ae..dccfda72a 100644
--- a/probability4e.py
+++ b/probability4e.py
@@ -82,7 +82,7 @@ def show_approx(self, numfmt='{:.3g}'):
for (v, p) in sorted(self.prob.items())])
def __repr__(self):
- return "P({})".format(self.varname)
+ return f"P({self.varname})"
# ______________________________________________________________________________
@@ -123,7 +123,7 @@ def values(self, var):
return self.vals[var]
def __repr__(self):
- return "P({})".format(self.variables)
+ return f"P({self.variables})"
def event_values(event, variables):
@@ -136,7 +136,7 @@ def event_values(event, variables):
if isinstance(event, tuple) and len(event) == len(variables):
return event
else:
- return tuple([event[var] for var in variables])
+ return tuple(event[var] for var in variables)
def enumerate_joint_ask(X, e, P):
@@ -161,8 +161,7 @@ def enumerate_joint(variables, e, P):
if not variables:
return P[e]
Y, rest = variables[0], variables[1:]
- return sum([enumerate_joint(rest, extend(e, Y, y), P)
- for y in P.values(Y)])
+ return sum(enumerate_joint(rest, extend(e, Y, y), P) for y in P.values(Y))
# ______________________________________________________________________________
@@ -255,7 +254,7 @@ def variable_node(self, var):
for n in self.nodes:
if n.variable == var:
return n
- raise Exception("No such variable: {}".format(var))
+ raise Exception(f"No such variable: {var}")
def variable_values(self, var):
"""Return the domain of var."""
@@ -382,12 +381,12 @@ def gaussian_probability(param, event, value):
assert isinstance(event, dict)
assert isinstance(param, dict)
- buff = 0
- for k, v in event.items():
- # buffer varianle to calculate h1*a_h1 + h2*a_h2
- buff += param['a'][k] * v
- res = 1 / (param['sigma'] * np.sqrt(2 * np.pi)) * np.exp(-0.5 * ((value - buff - param['b']) / param['sigma']) ** 2)
- return res
+ buff = sum(param['a'][k] * v for k, v in event.items())
+ return (
+ 1
+ / (param['sigma'] * np.sqrt(2 * np.pi))
+ * np.exp(-0.5 * ((value - buff - param['b']) / param['sigma']) ** 2)
+ )
def logistic_probability(param, event, value):
@@ -640,7 +639,7 @@ def rejection_sampling(X, e, bn, N=10000):
'False: 0.7, True: 0.3'
"""
counts = {x: 0 for x in bn.variable_values(X)} # bold N in [Figure 13.16]
- for j in range(N):
+ for _ in range(N):
sample = prior_sample(bn) # boldface x in [Figure 13.16]
if consistent_with(sample, e):
counts[sample[X]] += 1
@@ -668,7 +667,7 @@ def likelihood_weighting(X, e, bn, N=10000):
"""
W = {x: 0 for x in bn.variable_values(X)}
- for j in range(N):
+ for _ in range(N):
sample, weight = weighted_sample(bn, e) # boldface x, w in [Figure 14.15]
W[sample[X]] += weight
return ProbDist(X, W)
@@ -704,7 +703,7 @@ def gibbs_ask(X, e, bn, N=1000):
state = dict(e) # boldface x in [Figure 14.16]
for Zi in Z:
state[Zi] = random.choice(bn.variable_values(Zi))
- for j in range(N):
+ for _ in range(N):
for Zi in Z:
state[Zi] = markov_blanket_sample(Zi, state, bn)
counts[state[X]] += 1
@@ -738,39 +737,21 @@ class complied_burglary:
def Burglary(self, sample):
if sample['Alarm']:
- if sample['Earthquake']:
- return probability(0.00327)
- else:
- return probability(0.485)
+ return probability(0.00327) if sample['Earthquake'] else probability(0.485)
else:
- if sample['Earthquake']:
- return probability(7.05e-05)
- else:
- return probability(6.01e-05)
+ return probability(7.05e-05) if sample['Earthquake'] else probability(6.01e-05)
def Earthquake(self, sample):
if sample['Alarm']:
- if sample['Burglary']:
- return probability(0.0020212)
- else:
- return probability(0.36755)
+ return probability(0.0020212) if sample['Burglary'] else probability(0.36755)
else:
- if sample['Burglary']:
- return probability(0.0016672)
- else:
- return probability(0.0014222)
+ return probability(0.0016672) if sample['Burglary'] else probability(0.0014222)
def MaryCalls(self, sample):
- if sample['Alarm']:
- return probability(0.7)
- else:
- return probability(0.01)
+ return probability(0.7) if sample['Alarm'] else probability(0.01)
def JongCalls(self, sample):
- if sample['Alarm']:
- return probability(0.9)
- else:
- return probability(0.05)
+ return probability(0.9) if sample['Alarm'] else probability(0.05)
def Alarm(self, sample):
raise NotImplementedError
diff --git a/reinforcement_learning.py b/reinforcement_learning.py
index 4cb91af0f..40ceacc40 100644
--- a/reinforcement_learning.py
+++ b/reinforcement_learning.py
@@ -64,11 +64,8 @@ def estimate_U(self):
self.s_history, self.r_history = [], []
# setting the new utilities to the average of the previous
# iteration and this one
- for k in U2.keys():
- if k in self.U.keys():
- self.U[k] = (self.U[k] + U2[k]) / 2
- else:
- self.U[k] = U2[k]
+ for k in U2:
+ self.U[k] = (self.U[k] + U2[k]) / 2 if k in self.U.keys() else U2[k]
return self.U
def update_state(self, percept):
@@ -197,10 +194,7 @@ def __init__(self, pi, mdp, alpha=None):
self.gamma = mdp.gamma
self.terminals = mdp.terminals
- if alpha:
- self.alpha = alpha
- else:
- self.alpha = lambda n: 1 / (1 + n) # udacity video
+ self.alpha = alpha if alpha else (lambda n: 1 / (1 + n))
def __call__(self, percept):
s1, r1 = self.update_state(percept)
@@ -261,27 +255,18 @@ def __init__(self, mdp, Ne, Rplus, alpha=None):
self.a = None
self.r = None
- if alpha:
- self.alpha = alpha
- else:
- self.alpha = lambda n: 1. / (1 + n) # udacity video
+ self.alpha = alpha if alpha else (lambda n: 1. / (1 + n))
def f(self, u, n):
"""Exploration function. Returns fixed Rplus until
agent has visited state, action a Ne number of times.
Same as ADP agent in book."""
- if n < self.Ne:
- return self.Rplus
- else:
- return u
+ return self.Rplus if n < self.Ne else u
def actions_in_state(self, state):
"""Return actions possible in given state.
Useful for max and argmax."""
- if state in self.terminals:
- return [None]
- else:
- return self.all_act
+ return [None] if state in self.terminals else self.all_act
def __call__(self, percept):
s1, r1 = self.update_state(percept)
diff --git a/reinforcement_learning4e.py b/reinforcement_learning4e.py
index eaaba3e5a..69fd54d83 100644
--- a/reinforcement_learning4e.py
+++ b/reinforcement_learning4e.py
@@ -69,11 +69,8 @@ def estimate_U(self):
self.s_history, self.r_history = [], []
# setting the new utilities to the average of the previous
# iteration and this one
- for k in U2.keys():
- if k in self.U.keys():
- self.U[k] = (self.U[k] + U2[k]) / 2
- else:
- self.U[k] = U2[k]
+ for k in U2:
+ self.U[k] = (self.U[k] + U2[k]) / 2 if k in self.U.keys() else U2[k]
return self.U
def update_state(self, percept):
@@ -208,10 +205,7 @@ def __init__(self, pi, mdp, alpha=None):
self.gamma = mdp.gamma
self.terminals = mdp.terminals
- if alpha:
- self.alpha = alpha
- else:
- self.alpha = lambda n: 1 / (1 + n) # udacity video
+ self.alpha = alpha if alpha else (lambda n: 1 / (1 + n))
def __call__(self, percept):
s1, r1 = self.update_state(percept)
@@ -277,27 +271,18 @@ def __init__(self, mdp, Ne, Rplus, alpha=None):
self.a = None
self.r = None
- if alpha:
- self.alpha = alpha
- else:
- self.alpha = lambda n: 1. / (1 + n) # udacity video
+ self.alpha = alpha if alpha else (lambda n: 1. / (1 + n))
def f(self, u, n):
"""Exploration function. Returns fixed Rplus until
agent has visited state, action a Ne number of times.
Same as ADP agent in book."""
- if n < self.Ne:
- return self.Rplus
- else:
- return u
+ return self.Rplus if n < self.Ne else u
def actions_in_state(self, state):
"""Return actions possible in given state.
Useful for max and argmax."""
- if state in self.terminals:
- return [None]
- else:
- return self.all_act
+ return [None] if state in self.terminals else self.all_act
def __call__(self, percept):
s1, r1 = self.update_state(percept)
diff --git a/renovate.json b/renovate.json
new file mode 100644
index 000000000..ef2a3cf22
--- /dev/null
+++ b/renovate.json
@@ -0,0 +1,5 @@
+{
+ "extends": [
+ "github>sarvex/renovate-configs:python",
+ ]
+}
diff --git a/reviewpad.yml b/reviewpad.yml
new file mode 100644
index 000000000..1a013a135
--- /dev/null
+++ b/reviewpad.yml
@@ -0,0 +1,126 @@
+# This file is used to configure Reviewpad.
+# The configuration is a proposal to help you get started.
+# You can use it as a starting point and customize it to your needs.
+# For more details see https://docs.reviewpad.com/guides/syntax.
+
+# Define the list of labels to be used by Reviewpad.
+# For more details see https://docs.reviewpad.com/guides/syntax#label.
+labels:
+ small:
+ description: Pull request is small
+ color: "#76dbbe"
+ medium:
+ description: Pull request is medium
+ color: "#2986cc"
+ large:
+ description: Pull request is large
+ color: "#c90076"
+
+# Define the list of workflows to be run by Reviewpad.
+# A workflow is a list of actions that will be executed based on the defined rules.
+# For more details see https://docs.reviewpad.com/guides/syntax#workflow.
+workflows:
+ # This workflow calls Reviewpad AI agent to summarize the pull request.
+ - name: summarize
+ description: Summarize the pull request
+ run:
+ # Summarize the pull request on pull request synchronization.
+ - if: ($eventType() == "synchronize" || $eventType() == "opened") && $state() == "open"
+ then: $summarize()
+
+ # This workflow assigns the most relevant reviewer to pull requests.
+ # This helps guarantee that most pull requests are reviewed by at least one person.
+ - name: reviewer-assignment
+ description: Assign the most relevant reviewer to pull requests
+ run:
+ # Automatically assign reviewer when the pull request is ready for review;
+ - if: $isDraft() == false
+ then: $assignCodeAuthorReviewers()
+
+ # This workflow praises contributors on their pull request contributions.
+ # This helps contributors feel appreciated.
+ - name: praise-contributors-on-milestones
+ description: Praise contributors based on their contributions
+ run:
+ # Praise contributors on their first pull request.
+ - if: $pullRequestCountBy($author()) == 1
+ then: $commentOnce($sprintf("Thank you @%s for this first contribution!", [$author()]))
+
+ # This workflow validates that pull requests follow the conventional commits specification.
+ # This helps developers automatically generate changelogs.
+ # For more details, see https://www.conventionalcommits.org/en/v1.0.0/.
+ - name: check-conventional-commits
+ description: Validate that pull requests follow the conventional commits
+ run:
+ - if: $isDraft() == false
+ then:
+ # Check commits messages against the conventional commits specification
+ - $commitLint()
+ # Check pull request title against the conventional commits specification.
+ - $titleLint()
+
+ # This workflow validates best practices for pull request management.
+ # This helps developers follow best practices.
+ - name: best-practices
+ description: Validate best practices for pull request management
+ run:
+ # Warn pull requests that do not have an associated GitHub issue.
+ - if: $hasLinkedIssues() == false
+ then: $warn("Please link an issue to the pull request")
+ # Warn pull requests if their description is empty.
+ - if: $description() == ""
+ then: $warn("Please provide a description for the pull request")
+ # Warn pull request do not have a clean linear history.
+ - if: $hasLinearHistory() == false
+ then: $warn("Please rebase your pull request on the latest changes")
+
+ # This workflow labels pull requests based on the total number of lines changed.
+ # This helps pick pull requests based on their size and to incentivize small pull requests.
+ - name: size-labeling
+ description: Label pull request based on the number of lines changed
+ run:
+ - if: $size() < 100
+ then: $addLabel("small")
+ else: $removeLabel("small")
+ - if: $size() >= 100 && $size() < 300
+ then: $addLabel("medium")
+ else: $removeLabel("medium")
+ - if: $size() >= 300
+ then: $addLabel("large")
+ else: $removeLabel("large")
+
+ # This workflow signals pull requests waiting for reviews.
+ # This helps guarantee that pull requests are reviewed and approved by at least one person.
+ - name: check-approvals
+ description: Check that pull requests have the required number of approvals
+ run:
+ # Label pull requests with `waiting-for-review` if there are no approvals;
+ - if: $isDraft() == false && $approvalsCount() < 1
+ then: $addLabel("waiting-for-review")
+
+ # This workflow labels pull requests based on the pull request change type.
+ # This helps pick pull requests based on their change type.
+ - name: change-type-labelling
+ description: Label pull requests based on the type of changes
+ run:
+ # Label pull requests with `docs` if they only modify Markdown or txt files.
+ - if: $hasFileExtensions([".md", ".txt"])
+ then: $addLabel("docs")
+ else: $removeLabel("docs")
+ # Label pull requests with `infra` if they modify Terraform files.
+ - if: $hasFileExtensions([".tf"])
+ then: $addLabel("infra")
+ else: $removeLabel("infra")
+ # Label pull requests with `dependencies` if they only modify `package.json` and `package.lock` files.
+ - if: $hasFileExtensions(["package.json", "package-lock.json"])
+ then: $addLabel("dependencies")
+ else: $removeLabel("dependencies")
+
+ # This workflow validates that pull requests do not contain changes to the license.
+ # This helps avoid unwanted license modifications.
+ - name: license-validation
+ description: Validate that licenses are not modified
+ run:
+ # Fail Reviewpad check on pull requests that modify any LICENSE;
+ - if: $hasFilePattern("**/LICENSE*")
+ then: $fail("License files cannot be modified")
diff --git a/search.py b/search.py
index 5012c1a18..a0968b1f5 100644
--- a/search.py
+++ b/search.py
@@ -86,7 +86,7 @@ def __init__(self, state, parent=None, action=None, path_cost=0):
self.depth = parent.depth + 1
def __repr__(self):
- return "".format(self.state)
+ return f""
def __lt__(self, node):
return self.state < node.state
@@ -99,8 +99,12 @@ def expand(self, problem):
def child_node(self, problem, action):
"""[Figure 3.10]"""
next_state = problem.result(self.state, action)
- next_node = Node(next_state, self, action, problem.path_cost(self.path_cost, self.state, action, next_state))
- return next_node
+ return Node(
+ next_state,
+ self,
+ action,
+ problem.path_cost(self.path_cost, self.state, action, next_state),
+ )
def solution(self):
"""Return the sequence of actions to go from the root to this node."""
@@ -154,9 +158,7 @@ def __call__(self, percept):
goal = self.formulate_goal(self.state)
problem = self.formulate_problem(self.state, goal)
self.seq = self.search(problem)
- if not self.seq:
- return None
- return self.seq.pop(0)
+ return None if not self.seq else self.seq.pop(0)
def update_state(self, state, percept):
raise NotImplementedError
@@ -377,10 +379,9 @@ def find_key(pr_min, open_dir, g):
node = Node(-1)
for n in open_dir:
pr = max(g[n] + problem.h(n), 2 * g[n])
- if pr == pr_min:
- if g[n] < m:
- m = g[n]
- node = n
+ if pr == pr_min and g[n] < m:
+ m = g[n]
+ node = n
return node
@@ -516,18 +517,22 @@ def actions(self, state):
orientation = state.get_orientation()
# Prevent Bumps
- if x == 1 and orientation == 'LEFT':
- if 'Forward' in possible_actions:
- possible_actions.remove('Forward')
- if y == 1 and orientation == 'DOWN':
- if 'Forward' in possible_actions:
- possible_actions.remove('Forward')
- if x == self.dimrow and orientation == 'RIGHT':
- if 'Forward' in possible_actions:
- possible_actions.remove('Forward')
- if y == self.dimrow and orientation == 'UP':
- if 'Forward' in possible_actions:
- possible_actions.remove('Forward')
+ if x == 1 and orientation == 'LEFT' and 'Forward' in possible_actions:
+ possible_actions.remove('Forward')
+ if y == 1 and orientation == 'DOWN' and 'Forward' in possible_actions:
+ possible_actions.remove('Forward')
+ if (
+ x == self.dimrow
+ and orientation == 'RIGHT'
+ and 'Forward' in possible_actions
+ ):
+ possible_actions.remove('Forward')
+ if (
+ y == self.dimrow
+ and orientation == 'UP'
+ and 'Forward' in possible_actions
+ ):
+ possible_actions.remove('Forward')
return possible_actions
@@ -535,7 +540,7 @@ def result(self, state, action):
""" Given state and action, return a new state that is the result of the action.
Action is assumed to be a valid action in the state """
x, y = state.get_location()
- proposed_loc = list()
+ proposed_loc = []
# Move Forward
if action == 'Forward':
@@ -618,10 +623,7 @@ def RBFS(problem, node, flimit):
best = successors[0]
if best.f > flimit:
return None, best.f
- if len(successors) > 1:
- alternative = successors[1].f
- else:
- alternative = np.inf
+ alternative = successors[1].f if len(successors) > 1 else np.inf
result, best.f = RBFS(problem, best, min(flimit, alternative))
if result is not None:
return result, best.f
@@ -729,8 +731,12 @@ def and_search(states, problem, path):
# Pre-defined actions for PeakFindingProblem
directions4 = {'W': (-1, 0), 'N': (0, 1), 'E': (1, 0), 'S': (0, -1)}
-directions8 = dict(directions4)
-directions8.update({'NW': (-1, 1), 'NE': (1, 1), 'SE': (1, -1), 'SW': (-1, -1)})
+directions8 = directions4 | {
+ 'NW': (-1, 1),
+ 'NE': (1, 1),
+ 'SE': (1, -1),
+ 'SW': (-1, -1),
+}
class PeakFindingProblem(Problem):
@@ -781,8 +787,8 @@ def __init__(self, problem):
self.problem = problem
self.s = None
self.a = None
- self.untried = dict()
- self.unbacktracked = dict()
+ self.untried = {}
+ self.unbacktracked = {}
self.result = {}
def __call__(self, percept):
@@ -792,10 +798,9 @@ def __call__(self, percept):
else:
if s1 not in self.untried.keys():
self.untried[s1] = self.problem.actions(s1)
- if self.s is not None:
- if s1 != self.result[(self.s, self.a)]:
- self.result[(self.s, self.a)] = s1
- self.unbacktracked[s1].insert(0, self.s)
+ if self.s is not None and s1 != self.result[(self.s, self.a)]:
+ self.result[(self.s, self.a)] = s1
+ self.unbacktracked[s1].insert(0, self.s)
if len(self.untried[s1]) == 0:
if len(self.unbacktracked[s1]) == 0:
self.a = None
@@ -848,9 +853,7 @@ def update_state(self, percept):
raise NotImplementedError
def goal_test(self, state):
- if state == self.goal:
- return True
- return False
+ return state == self.goal
class LRTAStarAgent:
@@ -871,7 +874,6 @@ def __init__(self, problem):
def __call__(self, s1): # as of now s1 is a state rather than a percept
if self.problem.goal_test(s1):
self.a = None
- return self.a
else:
if s1 not in self.H:
self.H[s1] = self.problem.h(s1)
@@ -887,7 +889,8 @@ def __call__(self, s1): # as of now s1 is a state rather than a percept
key=lambda b: self.LRTA_cost(s1, b, self.problem.output(s1, b), self.H))
self.s = s1
- return self.a
+
+ return self.a
def LRTA_cost(self, s, a, s1, H):
"""Returns cost to move from state 's' to state 's1' plus
@@ -895,13 +898,12 @@ def LRTA_cost(self, s, a, s1, H):
print(s, a, s1)
if s1 is None:
return self.problem.h(s)
- else:
- # sometimes we need to get H[s1] which we haven't yet added to H
- # to replace this try, except: we can initialize H with values from problem.h
- try:
- return self.problem.c(s, a, s1) + self.H[s1]
- except:
- return self.problem.c(s, a, s1) + self.problem.h(s1)
+ # sometimes we need to get H[s1] which we haven't yet added to H
+ # to replace this try, except: we can initialize H with values from problem.h
+ try:
+ return self.problem.c(s, a, s1) + self.H[s1]
+ except:
+ return self.problem.c(s, a, s1) + self.problem.h(s1)
# ______________________________________________________________________________
@@ -924,12 +926,17 @@ def genetic_search(problem, ngen=1000, pmut=0.1, n=20):
def genetic_algorithm(population, fitness_fn, gene_pool=[0, 1], f_thres=None, ngen=1000, pmut=0.1):
"""[Figure 4.8]"""
- for i in range(ngen):
- population = [mutate(recombine(*select(2, population, fitness_fn)), gene_pool, pmut)
- for i in range(len(population))]
-
- fittest_individual = fitness_threshold(fitness_fn, f_thres, population)
- if fittest_individual:
+ for _ in range(ngen):
+ population = [
+ mutate(
+ recombine(*select(2, population, fitness_fn)), gene_pool, pmut
+ )
+ for _ in range(len(population))
+ ]
+
+ if fittest_individual := fitness_threshold(
+ fitness_fn, f_thres, population
+ ):
return fittest_individual
return max(population, key=fitness_fn)
@@ -953,8 +960,10 @@ def init_population(pop_number, gene_pool, state_length):
state_length: The length of each individual"""
g = len(gene_pool)
population = []
- for i in range(pop_number):
- new_individual = [gene_pool[random.randrange(0, g)] for j in range(state_length)]
+ for _ in range(pop_number):
+ new_individual = [
+ gene_pool[random.randrange(0, g)] for _ in range(state_length)
+ ]
population.append(new_individual)
return population
@@ -963,7 +972,7 @@ def init_population(pop_number, gene_pool, state_length):
def select(r, population, fitness_fn):
fitnesses = map(fitness_fn, population)
sampler = weighted_sampler(population, fitnesses)
- return [sampler() for i in range(r)]
+ return [sampler() for _ in range(r)]
def recombine(x, y):
@@ -1045,15 +1054,12 @@ def get(self, a, b=None):
.get(a,b) returns the distance or None;
.get(a) returns a dict of {node: distance} entries, possibly {}."""
links = self.graph_dict.setdefault(a, {})
- if b is None:
- return links
- else:
- return links.get(b)
+ return links if b is None else links.get(b)
def nodes(self):
"""Return a list of nodes in the graph."""
- s1 = set([k for k in self.graph_dict.keys()])
- s2 = set([k2 for v in self.graph_dict.values() for k2, v2 in v.items()])
+ s1 = set(list(self.graph_dict.keys()))
+ s2 = {k2 for v in self.graph_dict.values() for k2, v2 in v.items()}
nodes = s1.union(s2)
return list(nodes)
@@ -1096,6 +1102,7 @@ def distance_to_node(n):
""" [Figure 3.2]
Simplified road map of Romania
"""
+
romania_map = UndirectedGraph(dict(
Arad=dict(Zerind=75, Sibiu=140, Timisoara=118),
Bucharest=dict(Urziceni=85, Pitesti=101, Giurgiu=90, Fagaras=211),
@@ -1166,11 +1173,14 @@ def distance_to_node(n):
""" [Figure 6.1]
Principal states and territories of Australia
"""
-australia_map = UndirectedGraph(dict(
- T=dict(),
- SA=dict(WA=1, NT=1, Q=1, NSW=1, V=1),
- NT=dict(WA=1, Q=1),
- NSW=dict(Q=1, V=1)))
+australia_map = UndirectedGraph(
+ dict(
+ T={},
+ SA=dict(WA=1, NT=1, Q=1, NSW=1, V=1),
+ NT=dict(WA=1, Q=1),
+ NSW=dict(Q=1, V=1),
+ )
+)
australia_map.locations = dict(WA=(120, 24), NT=(135, 20), SA=(135, 30),
Q=(145, 20), NSW=(145, 32), T=(145, 42),
V=(145, 37))
@@ -1205,8 +1215,7 @@ def find_min_edge(self):
def h(self, node):
"""h function is straight-line distance from a node's state to goal."""
- locs = getattr(self.graph, 'locations', None)
- if locs:
+ if locs := getattr(self.graph, 'locations', None):
if type(node) is str:
return int(distance(locs[node], locs[self.goal]))
@@ -1252,10 +1261,9 @@ def actions(self, state):
"""In the leftmost empty column, try all non-conflicting rows."""
if state[-1] != -1:
return [] # All columns filled; no successors
- else:
- col = state.index(-1)
- return [row for row in range(self.N)
- if not self.conflicted(state, row, col)]
+ col = state.index(-1)
+ return [row for row in range(self.N)
+ if not self.conflicted(state, row, col)]
def result(self, state, row):
"""Place the next queen at the given row."""
@@ -1333,7 +1341,7 @@ def print_boggle(board):
if board[i] == 'Q':
print('Qu', end=' ')
else:
- print(str(board[i]) + ' ', end=' ')
+ print(f'{str(board[i])} ', end=' ')
print()
@@ -1373,7 +1381,7 @@ def boggle_neighbors(n2, cache={}):
def exact_sqrt(n2):
"""If n2 is a perfect square, return its square root, else raise error."""
n = int(np.sqrt(n2))
- assert n * n == n2
+ assert n**2 == n2
return n
@@ -1470,7 +1478,7 @@ def words(self):
def score(self):
"""The total score for the words found, according to the rules."""
- return sum([self.scores[len(w)] for w in self.words()])
+ return sum(self.scores[len(w)] for w in self.words())
def __len__(self):
"""The number of words found."""
diff --git a/tests/test_agents.py b/tests/test_agents.py
index d1a669486..82570c0cb 100644
--- a/tests/test_agents.py
+++ b/tests/test_agents.py
@@ -38,8 +38,8 @@ def test_move_forward():
def test_add():
d = Direction(Direction.U)
- l1 = d + "right"
- l2 = d + "left"
+ l1 = f"{d}right"
+ l2 = f"{d}left"
assert l1.direction == Direction.R
assert l2.direction == Direction.L
@@ -313,7 +313,12 @@ def constant_prog(percept):
assert not any(map(lambda x: not isinstance(x, Thing), w.things))
# check that gold and wumpus are not present on (1,1)
- assert not any(map(lambda x: isinstance(x, Gold) or isinstance(x, WumpusEnvironment), w.list_things_at((1, 1))))
+ assert not any(
+ map(
+ lambda x: isinstance(x, (Gold, WumpusEnvironment)),
+ w.list_things_at((1, 1)),
+ )
+ )
# check if w.get_world() segments objects correctly
assert len(w.get_world()) == 6
diff --git a/tests/test_agents4e.py b/tests/test_agents4e.py
index 295a1ee47..5ef182318 100644
--- a/tests/test_agents4e.py
+++ b/tests/test_agents4e.py
@@ -38,8 +38,8 @@ def test_move_forward():
def test_add():
d = Direction(Direction.U)
- l1 = d + "right"
- l2 = d + "left"
+ l1 = f"{d}right"
+ l2 = f"{d}left"
assert l1.direction == Direction.R
assert l2.direction == Direction.L
@@ -312,7 +312,12 @@ def constant_prog(percept):
assert not any(map(lambda x: not isinstance(x, Thing), w.things))
# check that gold and wumpus are not present on (1,1)
- assert not any(map(lambda x: isinstance(x, Gold) or isinstance(x, WumpusEnvironment), w.list_things_at((1, 1))))
+ assert not any(
+ map(
+ lambda x: isinstance(x, (Gold, WumpusEnvironment)),
+ w.list_things_at((1, 1)),
+ )
+ )
# check if w.get_world() segments objects correctly
assert len(w.get_world()) == 6
diff --git a/tests/test_csp.py b/tests/test_csp.py
index a070cd531..4dfaf53bf 100644
--- a/tests/test_csp.py
+++ b/tests/test_csp.py
@@ -51,8 +51,10 @@ def test_csp_actions():
assert map_coloring_test.actions(state) == [('C', '2')]
state = {'A': '1'}
- assert (map_coloring_test.actions(state) == [('C', '2'), ('C', '3')] or
- map_coloring_test.actions(state) == [('B', '2'), ('B', '3')])
+ assert map_coloring_test.actions(state) in [
+ [('C', '2'), ('C', '3')],
+ [('B', '2'), ('B', '3')],
+ ]
def test_csp_result():
@@ -162,7 +164,7 @@ def test_csp_conflicted_vars():
conflicted_vars = map_coloring_test.conflicted_vars(current)
- assert (conflicted_vars == ['B', 'C'] or conflicted_vars == ['C', 'B'])
+ assert conflicted_vars in [['B', 'C'], ['C', 'B']]
def test_revise():
@@ -178,7 +180,7 @@ def test_revise():
consistency, _ = revise(csp, Xi, Xj, removals)
assert not consistency
- assert len(removals) == 0
+ assert not removals
domains = {'A': [0, 1, 2, 3, 4], 'B': [0, 1, 2, 3, 4]}
csp = CSP(variables=None, domains=domains, neighbors=neighbors, constraints=constraints)
@@ -204,8 +206,10 @@ def test_AC3():
csp = CSP(variables=None, domains=domains, neighbors=neighbors, constraints=constraints)
assert AC3(csp, removals=removals)
- assert (removals == [('A', 1), ('A', 3), ('B', 1), ('B', 3)] or
- removals == [('B', 1), ('B', 3), ('A', 1), ('A', 3)])
+ assert removals in [
+ [('A', 1), ('A', 3), ('B', 1), ('B', 3)],
+ [('B', 1), ('B', 3), ('A', 1), ('A', 3)],
+ ]
domains = {'A': [2, 4], 'B': [3, 5]}
constraints = lambda X, x, Y, y: (X == 'A' and Y == 'B') or (X == 'B' and Y == 'A') and x > y
@@ -231,8 +235,10 @@ def test_AC3b():
csp = CSP(variables=None, domains=domains, neighbors=neighbors, constraints=constraints)
assert AC3b(csp, removals=removals)
- assert (removals == [('A', 1), ('A', 3), ('B', 1), ('B', 3)] or
- removals == [('B', 1), ('B', 3), ('A', 1), ('A', 3)])
+ assert removals in [
+ [('A', 1), ('A', 3), ('B', 1), ('B', 3)],
+ [('B', 1), ('B', 3), ('A', 1), ('A', 3)],
+ ]
domains = {'A': [2, 4], 'B': [3, 5]}
constraints = lambda X, x, Y, y: (X == 'A' and Y == 'B') or (X == 'B' and Y == 'A') and x > y
@@ -258,8 +264,10 @@ def test_AC4():
csp = CSP(variables=None, domains=domains, neighbors=neighbors, constraints=constraints)
assert AC4(csp, removals=removals)
- assert (removals == [('A', 1), ('A', 3), ('B', 1), ('B', 3)] or
- removals == [('B', 1), ('B', 3), ('A', 1), ('A', 3)])
+ assert removals in [
+ [('A', 1), ('A', 3), ('B', 1), ('B', 3)],
+ [('B', 1), ('B', 3), ('A', 1), ('A', 3)],
+ ]
domains = {'A': [2, 4], 'B': [3, 5]}
constraints = lambda X, x, Y, y: (X == 'A' and Y == 'B') or (X == 'B' and Y == 'A') and x > y
@@ -275,8 +283,7 @@ def test_first_unassigned_variable():
assert first_unassigned_variable(assignment, map_coloring_test) == 'C'
assignment = {'B': '1'}
- assert (first_unassigned_variable(assignment, map_coloring_test) == 'A' or
- first_unassigned_variable(assignment, map_coloring_test) == 'C')
+ assert first_unassigned_variable(assignment, map_coloring_test) in ['A', 'C']
def test_num_legal_values():
@@ -306,8 +313,7 @@ def test_mrv():
domains = {'A': [0, 1, 2, 3, 4], 'B': [0, 1, 2, 3, 4], 'C': [0, 1, 2, 3, 4]}
csp = CSP(variables=None, domains=domains, neighbors=neighbors, constraints=constraints)
- assert (mrv(assignment, csp) == 'B' or
- mrv(assignment, csp) == 'C')
+ assert mrv(assignment, csp) in ['B', 'C']
domains = {'A': [0, 1, 2, 3, 4], 'B': [0, 1, 2, 3, 4, 5, 6], 'C': [0, 1, 2, 3, 4]}
csp = CSP(variables=None, domains=domains, neighbors=neighbors, constraints=constraints)
@@ -484,34 +490,63 @@ def test_tree_csp_solver():
def test_ac_solver():
- assert ac_solver(csp_crossword) == {'one_across': 'has',
- 'one_down': 'hold',
- 'two_down': 'syntax',
- 'three_across': 'land',
- 'four_across': 'ant'} or {'one_across': 'bus',
- 'one_down': 'buys',
- 'two_down': 'search',
- 'three_across': 'year',
- 'four_across': 'car'}
- assert ac_solver(two_two_four) == {'T': 7, 'F': 1, 'W': 6, 'O': 5, 'U': 3, 'R': 0, 'C1': 1, 'C2': 1, 'C3': 1} or \
- {'T': 9, 'F': 1, 'W': 2, 'O': 8, 'U': 5, 'R': 6, 'C1': 1, 'C2': 0, 'C3': 1}
+ assert (
+ ac_solver(csp_crossword)
+ == {
+ 'one_across': 'has',
+ 'one_down': 'hold',
+ 'two_down': 'syntax',
+ 'three_across': 'land',
+ 'four_across': 'ant',
+ }
+ or True
+ )
+ assert (
+ ac_solver(two_two_four)
+ == {
+ 'T': 7,
+ 'F': 1,
+ 'W': 6,
+ 'O': 5,
+ 'U': 3,
+ 'R': 0,
+ 'C1': 1,
+ 'C2': 1,
+ 'C3': 1,
+ }
+ or True
+ )
assert ac_solver(send_more_money) == \
{'S': 9, 'M': 1, 'E': 5, 'N': 6, 'D': 7, 'O': 0, 'R': 8, 'Y': 2, 'C1': 1, 'C2': 1, 'C3': 0, 'C4': 1}
def test_ac_search_solver():
- assert ac_search_solver(csp_crossword) == {'one_across': 'has',
- 'one_down': 'hold',
- 'two_down': 'syntax',
- 'three_across': 'land',
- 'four_across': 'ant'} or {'one_across': 'bus',
- 'one_down': 'buys',
- 'two_down': 'search',
- 'three_across': 'year',
- 'four_across': 'car'}
- assert ac_search_solver(two_two_four) == {'T': 7, 'F': 1, 'W': 6, 'O': 5, 'U': 3, 'R': 0,
- 'C1': 1, 'C2': 1, 'C3': 1} or \
- {'T': 9, 'F': 1, 'W': 2, 'O': 8, 'U': 5, 'R': 6, 'C1': 1, 'C2': 0, 'C3': 1}
+ assert (
+ ac_search_solver(csp_crossword)
+ == {
+ 'one_across': 'has',
+ 'one_down': 'hold',
+ 'two_down': 'syntax',
+ 'three_across': 'land',
+ 'four_across': 'ant',
+ }
+ or True
+ )
+ assert (
+ ac_search_solver(two_two_four)
+ == {
+ 'T': 7,
+ 'F': 1,
+ 'W': 6,
+ 'O': 5,
+ 'U': 3,
+ 'R': 0,
+ 'C1': 1,
+ 'C2': 1,
+ 'C3': 1,
+ }
+ or True
+ )
assert ac_search_solver(send_more_money) == {'S': 9, 'M': 1, 'E': 5, 'N': 6, 'D': 7, 'O': 0, 'R': 8, 'Y': 2,
'C1': 1, 'C2': 1, 'C3': 0, 'C4': 1}
diff --git a/tests/test_games.py b/tests/test_games.py
index b7541ee93..668677d42 100644
--- a/tests/test_games.py
+++ b/tests/test_games.py
@@ -15,11 +15,13 @@ def gen_state(to_move='X', x_positions=[], o_positions=[], h=3, v=3):
and how many consecutive X's or O's required to win, return the corresponding
game state"""
- moves = set([(x, y) for x in range(1, h + 1) for y in range(1, v + 1)]) - set(x_positions) - set(o_positions)
+ moves = (
+ {(x, y) for x in range(1, h + 1) for y in range(1, v + 1)}
+ - set(x_positions)
+ - set(o_positions)
+ )
moves = list(moves)
- board = {}
- for pos in x_positions:
- board[pos] = 'X'
+ board = {pos: 'X' for pos in x_positions}
for pos in o_positions:
board[pos] = 'O'
return GameState(to_move=to_move, utility=0, board=board, moves=moves)
diff --git a/tests/test_games4e.py b/tests/test_games4e.py
index 7dfa47f11..de140776a 100644
--- a/tests/test_games4e.py
+++ b/tests/test_games4e.py
@@ -16,11 +16,13 @@ def gen_state(to_move='X', x_positions=[], o_positions=[], h=3, v=3):
and how many consecutive X's or O's required to win, return the corresponding
game state"""
- moves = set([(x, y) for x in range(1, h + 1) for y in range(1, v + 1)]) - set(x_positions) - set(o_positions)
+ moves = (
+ {(x, y) for x in range(1, h + 1) for y in range(1, v + 1)}
+ - set(x_positions)
+ - set(o_positions)
+ )
moves = list(moves)
- board = {}
- for pos in x_positions:
- board[pos] = 'X'
+ board = {pos: 'X' for pos in x_positions}
for pos in o_positions:
board[pos] = 'O'
return GameState(to_move=to_move, utility=0, board=board, moves=moves)
diff --git a/tests/test_knowledge.py b/tests/test_knowledge.py
index d3829de02..efd1353fb 100644
--- a/tests/test_knowledge.py
+++ b/tests/test_knowledge.py
@@ -79,12 +79,7 @@ def test_version_space_learning():
V = version_space_learning(party)
results = []
for e in party:
- guess = False
- for h in V:
- if guess_value(e, h):
- guess = True
- break
-
+ guess = any(guess_value(e, h) for h in V)
results.append(guess)
assert results == [True, True, False]
@@ -178,7 +173,9 @@ def test_extend_example():
with each possible constant value for each new variable in literal.)
"""
assert len(list(small_family.extend_example({x: expr('Andrew')}, expr('Father(x, y)')))) == 2
- assert len(list(small_family.extend_example({x: expr('Andrew')}, expr('Mother(x, y)')))) == 0
+ assert not list(
+ small_family.extend_example({x: expr('Andrew')}, expr('Mother(x, y)'))
+ )
assert len(list(small_family.extend_example({x: expr('Andrew')}, expr('Female(y)')))) == 6
diff --git a/tests/test_logic.py b/tests/test_logic.py
index 2ead21746..18eb86fdb 100644
--- a/tests/test_logic.py
+++ b/tests/test_logic.py
@@ -141,7 +141,11 @@ def test_dpll_satisfiable():
(B | ~C | D) & (A | ~E | F) & (~A | E | D)) == \
{B: False, C: True, A: True, F: False, D: True, E: False}
assert dpll_satisfiable(A & B & ~C & D) == {C: False, A: True, D: True, B: True}
- assert dpll_satisfiable((A | (B & C)) | '<=>' | ((A | B) & (A | C))) == {C: True, A: True} or {C: True, B: True}
+ assert (
+ dpll_satisfiable((A | (B & C)) | '<=>' | ((A | B) & (A | C)))
+ == {C: True, A: True}
+ or True
+ )
assert dpll_satisfiable(A | '<=>' | B) == {A: True, B: True}
assert dpll_satisfiable(A & ~B) == {A: True, B: False}
assert dpll_satisfiable(P & ~P) is False
@@ -152,7 +156,11 @@ def test_cdcl_satisfiable():
(B | ~C | D) & (A | ~E | F) & (~A | E | D)) == \
{B: False, C: True, A: True, F: False, D: True, E: False}
assert cdcl_satisfiable(A & B & ~C & D) == {C: False, A: True, D: True, B: True}
- assert cdcl_satisfiable((A | (B & C)) | '<=>' | ((A | B) & (A | C))) == {C: True, A: True} or {C: True, B: True}
+ assert (
+ cdcl_satisfiable((A | (B & C)) | '<=>' | ((A | B) & (A | C)))
+ == {C: True, A: True}
+ or True
+ )
assert cdcl_satisfiable(A | '<=>' | B) == {A: True, B: True}
assert cdcl_satisfiable(A & ~B) == {A: True, B: False}
assert cdcl_satisfiable(P & ~P) is False
@@ -318,8 +326,13 @@ def test_fol_bc_ask():
def test_ask(query, kb=None):
q = expr(query)
answers = fol_bc_ask(kb or test_kb, q)
- return sorted([dict((x, v) for x, v in list(a.items()) if x in variables(q))
- for a in answers], key=repr)
+ return sorted(
+ [
+ {x: v for x, v in list(a.items()) if x in variables(q)}
+ for a in answers
+ ],
+ key=repr,
+ )
assert repr(test_ask('Farmer(x)')) == '[{x: Mac}]'
assert repr(test_ask('Human(x)')) == '[{x: Mac}, {x: MrsMac}]'
@@ -331,8 +344,13 @@ def test_fol_fc_ask():
def test_ask(query, kb=None):
q = expr(query)
answers = fol_fc_ask(kb or test_kb, q)
- return sorted([dict((x, v) for x, v in list(a.items()) if x in variables(q))
- for a in answers], key=repr)
+ return sorted(
+ [
+ {x: v for x, v in list(a.items()) if x in variables(q)}
+ for a in answers
+ ],
+ key=repr,
+ )
assert repr(test_ask('Criminal(x)', crime_kb)) == '[{x: West}]'
assert repr(test_ask('Enemy(x, America)', crime_kb)) == '[{x: Nono}]'
diff --git a/tests/test_logic4e.py b/tests/test_logic4e.py
index 5a7399281..2d79f2f93 100644
--- a/tests/test_logic4e.py
+++ b/tests/test_logic4e.py
@@ -139,7 +139,11 @@ def test_dpll():
& (~D | ~F) & (B | ~C | D) & (A | ~E | F) & (~A | E | D))
== {B: False, C: True, A: True, F: False, D: True, E: False})
assert dpll_satisfiable(A & B & ~C & D) == {C: False, A: True, D: True, B: True}
- assert dpll_satisfiable((A | (B & C)) | '<=>' | ((A | B) & (A | C))) == {C: True, A: True} or {C: True, B: True}
+ assert (
+ dpll_satisfiable((A | (B & C)) | '<=>' | ((A | B) & (A | C)))
+ == {C: True, A: True}
+ or True
+ )
assert dpll_satisfiable(A | '<=>' | B) == {A: True, B: True}
assert dpll_satisfiable(A & ~B) == {A: True, B: False}
assert dpll_satisfiable(P & ~P) is False
@@ -288,8 +292,12 @@ def test_ask(query, kb=None):
test_variables = variables(q)
answers = fol_bc_ask(kb or test_kb, q)
return sorted(
- [dict((x, v) for x, v in list(a.items()) if x in test_variables)
- for a in answers], key=repr)
+ [
+ {x: v for x, v in list(a.items()) if x in test_variables}
+ for a in answers
+ ],
+ key=repr,
+ )
assert repr(test_ask('Farmer(x)')) == '[{x: Mac}]'
assert repr(test_ask('Human(x)')) == '[{x: Mac}, {x: MrsMac}]'
@@ -303,8 +311,12 @@ def test_ask(query, kb=None):
test_variables = variables(q)
answers = fol_fc_ask(kb or test_kb, q)
return sorted(
- [dict((x, v) for x, v in list(a.items()) if x in test_variables)
- for a in answers], key=repr)
+ [
+ {x: v for x, v in list(a.items()) if x in test_variables}
+ for a in answers
+ ],
+ key=repr,
+ )
assert repr(test_ask('Criminal(x)', crime_kb)) == '[{x: West}]'
assert repr(test_ask('Enemy(x, America)', crime_kb)) == '[{x: Nono}]'
diff --git a/tests/test_mdp.py b/tests/test_mdp.py
index 979b4ba85..96fcee793 100644
--- a/tests/test_mdp.py
+++ b/tests/test_mdp.py
@@ -137,10 +137,7 @@ def test_pomdp_value_iteration():
utility = pomdp_value_iteration(pomdp, epsilon=5)
for _, v in utility.items():
- sum_ = 0
- for element in v:
- sum_ += sum(element)
-
+ sum_ = sum(sum(element) for element in v)
assert -9.76 < sum_ < -9.70 or 246.5 < sum_ < 248.5 or 0 < sum_ < 1
@@ -157,10 +154,7 @@ def test_pomdp_value_iteration2():
utility = pomdp_value_iteration(pomdp, epsilon=100)
for _, v in utility.items():
- sum_ = 0
- for element in v:
- sum_ += sum(element)
-
+ sum_ = sum(sum(element) for element in v)
assert -77.31 < sum_ < -77.25 or 799 < sum_ < 800
diff --git a/tests/test_mdp4e.py b/tests/test_mdp4e.py
index e51bda5d6..d2b4433f8 100644
--- a/tests/test_mdp4e.py
+++ b/tests/test_mdp4e.py
@@ -145,10 +145,7 @@ def test_pomdp_value_iteration():
utility = pomdp_value_iteration(pomdp, epsilon=5)
for _, v in utility.items():
- sum_ = 0
- for element in v:
- sum_ += sum(element)
-
+ sum_ = sum(sum(element) for element in v)
assert -9.76 < sum_ < -9.70 or 246.5 < sum_ < 248.5 or 0 < sum_ < 1
@@ -165,10 +162,7 @@ def test_pomdp_value_iteration2():
utility = pomdp_value_iteration(pomdp, epsilon=100)
for _, v in utility.items():
- sum_ = 0
- for element in v:
- sum_ += sum(element)
-
+ sum_ = sum(sum(element) for element in v)
assert -77.31 < sum_ < -77.25 or 799 < sum_ < 800
diff --git a/tests/test_nlp.py b/tests/test_nlp.py
index 85d246dfa..3358360ef 100644
--- a/tests/test_nlp.py
+++ b/tests/test_nlp.py
@@ -54,10 +54,10 @@ def test_generation():
sentence = grammar.generate_random('S')
for token in sentence.split():
- found = False
- for non_terminal, terminals in grammar.lexicon.items():
- if token in terminals:
- found = True
+ found = any(
+ token in terminals
+ for non_terminal, terminals in grammar.lexicon.items()
+ )
assert found
@@ -184,9 +184,9 @@ def test_stripRawHTML(html_mock):
def test_determineInlinks():
- assert set(determineInlinks(pA)) == set(['B', 'C', 'E'])
+ assert set(determineInlinks(pA)) == {'B', 'C', 'E'}
assert set(determineInlinks(pE)) == set([])
- assert set(determineInlinks(pF)) == set(['E'])
+ assert set(determineInlinks(pF)) == {'E'}
def test_findOutlinks_wiki():
diff --git a/tests/test_nlp4e.py b/tests/test_nlp4e.py
index 2d16a3196..c9a2d1c18 100644
--- a/tests/test_nlp4e.py
+++ b/tests/test_nlp4e.py
@@ -49,10 +49,10 @@ def test_generation():
sentence = grammar.generate_random('S')
for token in sentence.split():
- found = False
- for non_terminal, terminals in grammar.lexicon.items():
- if token in terminals:
- found = True
+ found = any(
+ token in terminals
+ for non_terminal, terminals in grammar.lexicon.items()
+ )
assert found
diff --git a/tests/test_planning.py b/tests/test_planning.py
index a39152adc..fc239f015 100644
--- a/tests/test_planning.py
+++ b/tests/test_planning.py
@@ -586,7 +586,7 @@ def test_job_shop_problem():
def test_refinements():
- result = [i for i in RealWorldPlanningProblem.refinements(go_SFO, library_1)]
+ result = list(RealWorldPlanningProblem.refinements(go_SFO, library_1))
assert (result[0][0].name == drive_SFOLongTermParking.name)
assert (result[0][0].args == drive_SFOLongTermParking.args)
diff --git a/tests/test_probability.py b/tests/test_probability.py
index 8def79c68..176b61114 100644
--- a/tests/test_probability.py
+++ b/tests/test_probability.py
@@ -165,7 +165,7 @@ def test_elimination_ask():
def test_prior_sample():
random.seed(42)
- all_obs = [prior_sample(burglary) for x in range(1000)]
+ all_obs = [prior_sample(burglary) for _ in range(1000)]
john_calls_true = [observation for observation in all_obs if observation['JohnCalls']]
mary_calls_true = [observation for observation in all_obs if observation['MaryCalls']]
burglary_and_john = [observation for observation in john_calls_true if observation['Burglary']]
@@ -178,13 +178,13 @@ def test_prior_sample():
def test_prior_sample2():
random.seed(128)
- all_obs = [prior_sample(sprinkler) for x in range(1000)]
+ all_obs = [prior_sample(sprinkler) for _ in range(1000)]
rain_true = [observation for observation in all_obs if observation['Rain']]
sprinkler_true = [observation for observation in all_obs if observation['Sprinkler']]
rain_and_cloudy = [observation for observation in rain_true if observation['Cloudy']]
sprinkler_and_cloudy = [observation for observation in sprinkler_true if observation['Cloudy']]
- assert len(rain_true) / 1000 == 0.476
- assert len(sprinkler_true) / 1000 == 0.291
+ assert len(rain_true) == 0.476 * 1000
+ assert len(sprinkler_true) == 0.291 * 1000
assert len(rain_and_cloudy) / len(rain_true) == 376 / 476
assert len(sprinkler_and_cloudy) / len(sprinkler_true) == 39 / 291
diff --git a/tests/test_probability4e.py b/tests/test_probability4e.py
index d07954e0a..b25a3f568 100644
--- a/tests/test_probability4e.py
+++ b/tests/test_probability4e.py
@@ -200,7 +200,7 @@ def test_elimination_ask():
def test_prior_sample():
random.seed(42)
- all_obs = [prior_sample(burglary) for x in range(1000)]
+ all_obs = [prior_sample(burglary) for _ in range(1000)]
john_calls_true = [observation for observation in all_obs if observation['JohnCalls'] is True]
mary_calls_true = [observation for observation in all_obs if observation['MaryCalls'] is True]
burglary_and_john = [observation for observation in john_calls_true if observation['Burglary'] is True]
@@ -213,13 +213,13 @@ def test_prior_sample():
def test_prior_sample2():
random.seed(128)
- all_obs = [prior_sample(sprinkler) for x in range(1000)]
+ all_obs = [prior_sample(sprinkler) for _ in range(1000)]
rain_true = [observation for observation in all_obs if observation['Rain'] is True]
sprinkler_true = [observation for observation in all_obs if observation['Sprinkler'] is True]
rain_and_cloudy = [observation for observation in rain_true if observation['Cloudy'] is True]
sprinkler_and_cloudy = [observation for observation in sprinkler_true if observation['Cloudy'] is True]
- assert len(rain_true) / 1000 == 0.476
- assert len(sprinkler_true) / 1000 == 0.291
+ assert len(rain_true) == 0.476 * 1000
+ assert len(sprinkler_true) == 0.291 * 1000
assert len(rain_and_cloudy) / len(rain_true) == 376 / 476
assert len(sprinkler_and_cloudy) / len(sprinkler_true) == 39 / 291
diff --git a/tests/test_reinforcement_learning.py b/tests/test_reinforcement_learning.py
index d80ad3baf..b5d66245c 100644
--- a/tests/test_reinforcement_learning.py
+++ b/tests/test_reinforcement_learning.py
@@ -19,7 +19,7 @@
def test_PassiveDUEAgent():
agent = PassiveDUEAgent(policy, sequential_decision_environment)
- for i in range(200):
+ for _ in range(200):
run_single_trial(agent, sequential_decision_environment)
agent.estimate_U()
# Agent does not always produce same results.
@@ -32,7 +32,7 @@ def test_PassiveDUEAgent():
def test_PassiveADPAgent():
agent = PassiveADPAgent(policy, sequential_decision_environment)
- for i in range(100):
+ for _ in range(100):
run_single_trial(agent, sequential_decision_environment)
# Agent does not always produce same results.
@@ -45,7 +45,7 @@ def test_PassiveADPAgent():
def test_PassiveTDAgent():
agent = PassiveTDAgent(policy, sequential_decision_environment, alpha=lambda n: 60. / (59 + n))
- for i in range(200):
+ for _ in range(200):
run_single_trial(agent, sequential_decision_environment)
# Agent does not always produce same results.
@@ -58,7 +58,7 @@ def test_PassiveTDAgent():
def test_QLearning():
q_agent = QLearningAgent(sequential_decision_environment, Ne=5, Rplus=2, alpha=lambda n: 60. / (59 + n))
- for i in range(200):
+ for _ in range(200):
run_single_trial(q_agent, sequential_decision_environment)
# Agent does not always produce same results.
diff --git a/tests/test_reinforcement_learning4e.py b/tests/test_reinforcement_learning4e.py
index 287ec397b..b002e3153 100644
--- a/tests/test_reinforcement_learning4e.py
+++ b/tests/test_reinforcement_learning4e.py
@@ -17,7 +17,7 @@
def test_PassiveDUEAgent():
agent = PassiveDUEAgent(policy, sequential_decision_environment)
- for i in range(200):
+ for _ in range(200):
run_single_trial(agent, sequential_decision_environment)
agent.estimate_U()
# Agent does not always produce same results.
@@ -30,7 +30,7 @@ def test_PassiveDUEAgent():
def test_PassiveADPAgent():
agent = PassiveADPAgent(policy, sequential_decision_environment)
- for i in range(100):
+ for _ in range(100):
run_single_trial(agent, sequential_decision_environment)
# Agent does not always produce same results.
@@ -43,7 +43,7 @@ def test_PassiveADPAgent():
def test_PassiveTDAgent():
agent = PassiveTDAgent(policy, sequential_decision_environment, alpha=lambda n: 60. / (59 + n))
- for i in range(200):
+ for _ in range(200):
run_single_trial(agent, sequential_decision_environment)
# Agent does not always produce same results.
@@ -56,7 +56,7 @@ def test_PassiveTDAgent():
def test_QLearning():
q_agent = QLearningAgent(sequential_decision_environment, Ne=5, Rplus=2, alpha=lambda n: 60. / (59 + n))
- for i in range(200):
+ for _ in range(200):
run_single_trial(q_agent, sequential_decision_environment)
# Agent does not always produce same results.
diff --git a/tests/test_search.py b/tests/test_search.py
index 9be3e4a47..a7882ac21 100644
--- a/tests/test_search.py
+++ b/tests/test_search.py
@@ -204,7 +204,7 @@ def test_simulated_annealing():
prob = PeakFindingProblem((0, 0), [[0, 5, 10, 8],
[-3, 7, 9, 999],
[1, 2, 5, 11]], directions8)
- sols = {prob.value(simulated_annealing(prob)) for i in range(100)}
+ sols = {prob.value(simulated_annealing(prob)) for _ in range(100)}
assert max(sols) == 999
@@ -235,7 +235,7 @@ def run_plan(state, problem, plan):
def test_online_dfs_agent():
odfs_agent = OnlineDFSAgent(LRTA_problem)
- keys = [key for key in odfs_agent('State_3')]
+ keys = list(odfs_agent('State_3'))
assert keys[0] in ['Right', 'Left']
assert keys[1] in ['Right', 'Left']
assert odfs_agent('State_5') is None
@@ -267,13 +267,16 @@ def fitness(c):
return sum(c[n1] != c[n2] for (n1, n2) in edges.values())
solution_chars = GA_GraphColoringChars(edges, fitness)
- assert solution_chars == ['R', 'G', 'R', 'G'] or solution_chars == ['G', 'R', 'G', 'R']
+ assert solution_chars in [['R', 'G', 'R', 'G'], ['G', 'R', 'G', 'R']]
solution_bools = GA_GraphColoringBools(edges, fitness)
- assert solution_bools == [True, False, True, False] or solution_bools == [False, True, False, True]
+ assert solution_bools in [
+ [True, False, True, False],
+ [False, True, False, True],
+ ]
solution_ints = GA_GraphColoringInts(edges, fitness)
- assert solution_ints == [0, 1, 0, 1] or solution_ints == [1, 0, 1, 0]
+ assert solution_ints in [[0, 1, 0, 1], [1, 0, 1, 0]]
# Queens Problem
gene_pool = range(8)
@@ -318,17 +321,18 @@ def GA_GraphColoringInts(edges, fitness):
def test_simpleProblemSolvingAgent():
+
+
+
class vacuumAgent(SimpleProblemSolvingAgentProgram):
def update_state(self, state, percept):
return percept
def formulate_goal(self, state):
- goal = [state7, state8]
- return goal
+ return [state7, state8]
def formulate_problem(self, state, goal):
- problem = state
- return problem
+ return state
def search(self, problem):
if problem == state1:
@@ -345,6 +349,7 @@ def search(self, problem):
seq = ["Left", "Suck"]
return seq
+
state1 = [(0, 0), [(0, 0), "Dirty"], [(1, 0), ["Dirty"]]]
state2 = [(1, 0), [(0, 0), "Dirty"], [(1, 0), ["Dirty"]]]
state3 = [(0, 0), [(0, 0), "Clean"], [(1, 0), ["Dirty"]]]
diff --git a/tests/test_text.py b/tests/test_text.py
index 3aaa007f6..16847a2dd 100644
--- a/tests/test_text.py
+++ b/tests/test_text.py
@@ -194,7 +194,7 @@ def test_rot13_decoding():
def test_counting_probability_distribution():
D = CountingProbDist()
- for i in range(10000):
+ for _ in range(10000):
D.add(random.choice('123456'))
ps = [D[n] for n in '123456']
@@ -279,7 +279,7 @@ def test_canonicalize():
def test_translate():
text = 'orange apple lemon '
- func = lambda x: ('s ' + x) if x == ' ' else x
+ func = lambda x: f's {x}' if x == ' ' else x
assert translate(text, func) == 'oranges apples lemons '
diff --git a/tests/test_utils.py b/tests/test_utils.py
index 6c2a50808..75679724d 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -38,7 +38,7 @@ def test_count():
assert count([1, 2, 3, 4, 2, 3, 4]) == 7
assert count("aldpeofmhngvia") == 14
assert count([True, False, True, True, False]) == 3
- assert count([5 > 1, len("abc") == 3, 3 + 1 == 5]) == 2
+ assert count([5 > 1, len("abc") == 3, 3 == 4]) == 2
assert count("aima") == 4
@@ -186,7 +186,7 @@ def test_weighted_choice():
def compare_list(x, y):
- return all([elm_x == y[i] for i, elm_x in enumerate(x)])
+ return all(elm_x == y[i] for i, elm_x in enumerate(x))
def test_distance():
diff --git a/text.py b/text.py
index 11a5731f1..53c2837bb 100644
--- a/text.py
+++ b/text.py
@@ -31,7 +31,7 @@ def __init__(self, observations, default=0):
def samples(self, n):
"""Return a string of n words, random according to the model."""
- return ' '.join(self.sample() for i in range(n))
+ return ' '.join(self.sample() for _ in range(n))
class NgramWordModel(CountingProbDist):
@@ -74,7 +74,7 @@ def samples(self, nwords):
n = self.n
output = list(self.sample())
- for i in range(n, nwords):
+ for _ in range(n, nwords):
last = output[-n + 1:]
next_word = self.cond_prob[tuple(last)].sample()
output.append(next_word)
@@ -86,7 +86,7 @@ class NgramCharModel(NgramWordModel):
def add_sequence(self, words):
"""Add an empty space to every word to catch the beginning of words."""
for word in words:
- super().add_sequence(' ' + word)
+ super().add_sequence(f' {word}')
class UnigramCharModel(NgramCharModel):
@@ -125,7 +125,7 @@ def viterbi_segment(text, P):
sequence = []
i = len(words) - 1
while i > 0:
- sequence[0:0] = [words[i]]
+ sequence[:0] = [words[i]]
i = i - len(words[i])
# Return sequence of best words and overall probability
return sequence, best[-1]
@@ -275,18 +275,12 @@ def rot13(plaintext):
def translate(plaintext, function):
"""Translate chars of a plaintext with the given function."""
- result = ""
- for char in plaintext:
- result += function(char)
- return result
+ return "".join(function(char) for char in plaintext)
def maketrans(from_, to_):
"""Create a translation table and return the proper function."""
- trans_table = {}
- for n, char in enumerate(from_):
- trans_table[char] = to_[n]
-
+ trans_table = {char: to_[n] for n, char in enumerate(from_)}
return lambda char: trans_table.get(char, char)
diff --git a/utils.py b/utils.py
index 3158e3793..ac570e27f 100644
--- a/utils.py
+++ b/utils.py
@@ -20,7 +20,11 @@
def sequence(iterable):
"""Converts iterable to sequence, if it is not already one."""
- return iterable if isinstance(iterable, collections.abc.Sequence) else tuple([iterable])
+ return (
+ iterable
+ if isinstance(iterable, collections.abc.Sequence)
+ else (iterable,)
+ )
def remove_all(item, seq):
@@ -190,8 +194,7 @@ def weighted_sample_with_replacement(n, seq, weights):
def weighted_sampler(seq, weights):
"""Return a random-sample function that picks from seq weighted by weights."""
totals = []
- for w in weights:
- totals.append(w + totals[-1] if totals else w)
+ totals.extend(w + totals[-1] if totals else w for w in weights)
return lambda: seq[bisect.bisect(totals, random.uniform(0, totals[-1]))]
@@ -212,9 +215,8 @@ def rounder(numbers, d=4):
"""Round a single number, or sequence of numbers, to d decimal places."""
if isinstance(numbers, (int, float)):
return round(numbers, d)
- else:
- constructor = type(numbers) # Can be list, set, tuple, etc.
- return constructor(rounder(n, d) for n in numbers)
+ constructor = type(numbers) # Can be list, set, tuple, etc.
+ return constructor(rounder(n, d) for n in numbers)
def num_or_str(x): # TODO: rename as `atom`
@@ -413,10 +415,10 @@ def memoize(fn, slot=None, maxsize=32):
def memoized_fn(obj, *args):
if hasattr(obj, slot):
return getattr(obj, slot)
- else:
- val = fn(obj, *args)
- setattr(obj, slot, val)
- return val
+ val = fn(obj, *args)
+ setattr(obj, slot, val)
+ return val
+
else:
@functools.lru_cache(maxsize=maxsize)
def memoized_fn(*args):
@@ -612,12 +614,12 @@ def __repr__(self):
op = self.op
args = [str(arg) for arg in self.args]
if op.isidentifier(): # f(x) or f(x, y)
- return '{}({})'.format(op, ', '.join(args)) if args else op
+ return f"{op}({', '.join(args)})" if args else op
elif len(args) == 1: # -x or -(x + 1)
return op + args[0]
else: # (x - y)
- opp = (' ' + op + ' ')
- return '(' + opp.join(args) + ')'
+ opp = f' {op} '
+ return f'({opp.join(args)})'
# An 'Expression' is either an Expr or a Number.
@@ -648,10 +650,7 @@ def subexpressions(x):
def arity(expression):
"""The number of sub-expressions in this expression."""
- if isinstance(expression, Expr):
- return len(expression.args)
- else: # expression is a number
- return 0
+ return len(expression.args) if isinstance(expression, Expr) else 0
# For operators that are not defined in Python, we allow new InfixOps:
@@ -667,7 +666,7 @@ def __or__(self, rhs):
return Expr(self.op, self.lhs, rhs)
def __repr__(self):
- return "PartialExpr('{}', {})".format(self.op, self.lhs)
+ return f"PartialExpr('{self.op}', {self.lhs})"
def expr(x):
@@ -690,7 +689,7 @@ def expr_handle_infix_ops(x):
"P |'==>'| Q"
"""
for op in infix_ops:
- x = x.replace(op, '|' + repr(op) + '|')
+ x = x.replace(op, f'|{repr(op)}|')
return x
@@ -758,7 +757,7 @@ def __len__(self):
def __contains__(self, key):
"""Return True if the key is in PriorityQueue."""
- return any([item == key for _, item in self.heap])
+ return any(item == key for _, item in self.heap)
def __getitem__(self, key):
"""Returns the first value associated with key in PriorityQueue.
@@ -766,14 +765,14 @@ def __getitem__(self, key):
for value, item in self.heap:
if item == key:
return value
- raise KeyError(str(key) + " is not in the priority queue")
+ raise KeyError(f"{str(key)} is not in the priority queue")
def __delitem__(self, key):
"""Delete the first occurrence of key."""
try:
del self.heap[[item == key for _, item in self.heap].index(True)]
except ValueError:
- raise KeyError(str(key) + " is not in the priority queue")
+ raise KeyError(f"{str(key)} is not in the priority queue")
heapq.heapify(self.heap)
diff --git a/utils4e.py b/utils4e.py
index 65cb9026f..fed307fd4 100644
--- a/utils4e.py
+++ b/utils4e.py
@@ -59,7 +59,7 @@ def __len__(self):
def __contains__(self, key):
"""Return True if the key is in PriorityQueue."""
- return any([item == key for _, item in self.heap])
+ return any(item == key for _, item in self.heap)
def __getitem__(self, key):
"""Returns the first value associated with key in PriorityQueue.
@@ -67,14 +67,14 @@ def __getitem__(self, key):
for value, item in self.heap:
if item == key:
return value
- raise KeyError(str(key) + " is not in the priority queue")
+ raise KeyError(f"{str(key)} is not in the priority queue")
def __delitem__(self, key):
"""Delete the first occurrence of key."""
try:
del self.heap[[item == key for _, item in self.heap].index(True)]
except ValueError:
- raise KeyError(str(key) + " is not in the priority queue")
+ raise KeyError(f"{str(key)} is not in the priority queue")
heapq.heapify(self.heap)
@@ -84,8 +84,11 @@ def __delitem__(self, key):
def sequence(iterable):
"""Converts iterable to sequence, if it is not already one."""
- return (iterable if isinstance(iterable, collections.abc.Sequence)
- else tuple([iterable]))
+ return (
+ iterable
+ if isinstance(iterable, collections.abc.Sequence)
+ else (iterable,)
+ )
def remove_all(item, seq):
@@ -260,9 +263,7 @@ def weighted_sample_with_replacement(n, seq, weights):
def weighted_sampler(seq, weights):
"""Return a random-sample function that picks from seq weighted by weights."""
totals = []
- for w in weights:
- totals.append(w + totals[-1] if totals else w)
-
+ totals.extend(w + totals[-1] if totals else w for w in weights)
return lambda: seq[bisect.bisect(totals, random.uniform(0, totals[-1]))]
@@ -283,9 +284,8 @@ def rounder(numbers, d=4):
"""Round a single number, or sequence of numbers, to d decimal places."""
if isinstance(numbers, (int, float)):
return round(numbers, d)
- else:
- constructor = type(numbers) # Can be list, set, tuple, etc.
- return constructor(rounder(n, d) for n in numbers)
+ constructor = type(numbers) # Can be list, set, tuple, etc.
+ return constructor(rounder(n, d) for n in numbers)
def num_or_str(x): # TODO: rename as `atom`
@@ -471,10 +471,10 @@ def memoize(fn, slot=None, maxsize=32):
def memoized_fn(obj, *args):
if hasattr(obj, slot):
return getattr(obj, slot)
- else:
- val = fn(obj, *args)
- setattr(obj, slot, val)
- return val
+ val = fn(obj, *args)
+ setattr(obj, slot, val)
+ return val
+
else:
@functools.lru_cache(maxsize=maxsize)
def memoized_fn(*args):
@@ -673,12 +673,12 @@ def __repr__(self):
op = self.op
args = [str(arg) for arg in self.args]
if op.isidentifier(): # f(x) or f(x, y)
- return '{}({})'.format(op, ', '.join(args)) if args else op
+ return f"{op}({', '.join(args)})" if args else op
elif len(args) == 1: # -x or -(x + 1)
return op + args[0]
else: # (x - y)
- opp = (' ' + op + ' ')
- return '(' + opp.join(args) + ')'
+ opp = f' {op} '
+ return f'({opp.join(args)})'
# An 'Expression' is either an Expr or a Number.
@@ -709,10 +709,7 @@ def subexpressions(x):
def arity(expression):
"""The number of sub-expressions in this expression."""
- if isinstance(expression, Expr):
- return len(expression.args)
- else: # expression is a number
- return 0
+ return len(expression.args) if isinstance(expression, Expr) else 0
# For operators that are not defined in Python, we allow new InfixOps:
@@ -728,7 +725,7 @@ def __or__(self, rhs):
return Expr(self.op, self.lhs, rhs)
def __repr__(self):
- return "PartialExpr('{}', {})".format(self.op, self.lhs)
+ return f"PartialExpr('{self.op}', {self.lhs})"
def expr(x):
@@ -754,7 +751,7 @@ def expr_handle_infix_ops(x):
"P |'==>'| Q"
"""
for op in infix_ops:
- x = x.replace(op, '|' + repr(op) + '|')
+ x = x.replace(op, f'|{repr(op)}|')
return x