Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 718224a

Browse files
Chipe1norvig
authored andcommitted
* Added predicate_symbols * Added FOIL * Updated README
1 parent a065c3b commit 718224a

File tree

5 files changed

+335
-18
lines changed

5 files changed

+335
-18
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ Here is a table of algorithms, the figure, name of the algorithm in the book and
112112
| 19.2 | Current-Best-Learning | `current_best_learning` | [`knowledge.py`](knowledge.py) | Done |
113113
| 19.3 | Version-Space-Learning | `version_space_learning` | [`knowledge.py`](knowledge.py) | Done |
114114
| 19.8 | Minimal-Consistent-Det | `minimal_consistent_det` | [`knowledge.py`](knowledge.py) | Done |
115-
| 19.12 | FOIL | | |
115+
| 19.12 | FOIL | `FOIL_container` | [`knowledge.py`](knowledge.py) | Done |
116116
| 21.2 | Passive-ADP-Agent | `PassiveADPAgent` | [`rl.py`][rl] | Done |
117117
| 21.4 | Passive-TD-Agent | `PassiveTDAgent` | [`rl.py`][rl] | Done |
118118
| 21.8 | Q-Learning-Agent | `QLearningAgent` | [`rl.py`][rl] | Done |

knowledge.py

Lines changed: 115 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
"""Knowledge in learning, Chapter 19"""
22

33
from random import shuffle
4+
from math import log
45
from utils import powerset
56
from collections import defaultdict
6-
from itertools import combinations
7+
from itertools import combinations, product
8+
from logic import (FolKB, constant_symbols, predicate_symbols, standardize_variables,
9+
variables, is_definite_clause, subst, expr, Expr)
710

811
# ______________________________________________________________________________
912

@@ -231,6 +234,117 @@ def consistent_det(A, E):
231234
# ______________________________________________________________________________
232235

233236

237+
class FOIL_container(FolKB):
238+
"""Holds the kb and other necessary elements required by FOIL"""
239+
240+
def __init__(self, clauses=[]):
241+
self.const_syms = set()
242+
self.pred_syms = set()
243+
FolKB.__init__(self, clauses)
244+
245+
def tell(self, sentence):
246+
if is_definite_clause(sentence):
247+
self.clauses.append(sentence)
248+
self.const_syms.update(constant_symbols(sentence))
249+
self.pred_syms.update(predicate_symbols(sentence))
250+
else:
251+
raise Exception("Not a definite clause: {}".format(sentence))
252+
253+
def foil(self, examples, target):
254+
"""Learns a list of first-order horn clauses
255+
'examples' is a tuple: (positive_examples, negative_examples).
256+
positive_examples and negative_examples are both lists which contain substitutions."""
257+
clauses = []
258+
259+
pos_examples = examples[0]
260+
neg_examples = examples[1]
261+
262+
while pos_examples:
263+
clause, extended_pos_examples = self.new_clause((pos_examples, neg_examples), target)
264+
# remove positive examples covered by clause
265+
pos_examples = self.update_examples(target, pos_examples, extended_pos_examples)
266+
clauses.append(clause)
267+
268+
return clauses
269+
270+
def new_clause(self, examples, target):
271+
"""Finds a horn clause which satisfies part of the positive
272+
examples but none of the negative examples.
273+
The horn clause is specified as [consequent, list of antecedents]
274+
Return value is the tuple (horn_clause, extended_positive_examples)"""
275+
clause = [target, []]
276+
# [positive_examples, negative_examples]
277+
extended_examples = examples
278+
while extended_examples[1]:
279+
l = self.choose_literal(self.new_literals(clause), extended_examples)
280+
clause[1].append(l)
281+
extended_examples = [sum([list(self.extend_example(example, l)) for example in
282+
extended_examples[i]], []) for i in range(2)]
283+
284+
return (clause, extended_examples[0])
285+
286+
def extend_example(self, example, literal):
287+
"""Generates extended examples which satisfy the literal"""
288+
# find all substitutions that satisfy literal
289+
for s in self.ask_generator(subst(example, literal)):
290+
s.update(example)
291+
yield s
292+
293+
def new_literals(self, clause):
294+
"""Generates new literals based on known predicate symbols.
295+
Generated literal must share atleast one variable with clause"""
296+
share_vars = variables(clause[0])
297+
for l in clause[1]:
298+
share_vars.update(variables(l))
299+
300+
for pred, arity in self.pred_syms:
301+
new_vars = {standardize_variables(expr('x')) for _ in range(arity - 1)}
302+
for args in product(share_vars.union(new_vars), repeat=arity):
303+
if any(var in share_vars for var in args):
304+
yield Expr(pred, *[var for var in args])
305+
306+
def choose_literal(self, literals, examples):
307+
"""Chooses the best literal based on the information gain"""
308+
def gain(l):
309+
pre_pos = len(examples[0])
310+
pre_neg = len(examples[1])
311+
extended_examples = [sum([list(self.extend_example(example, l)) for example in
312+
examples[i]], []) for i in range(2)]
313+
post_pos = len(extended_examples[0])
314+
post_neg = len(extended_examples[1])
315+
if pre_pos + pre_neg == 0 or post_pos + post_neg == 0:
316+
return -1
317+
318+
# number of positive example that are represented in extended_examples
319+
T = 0
320+
for example in examples[0]:
321+
def represents(d):
322+
return all(d[x] == example[x] for x in example)
323+
if any(represents(l_) for l_ in extended_examples[0]):
324+
T += 1
325+
326+
return T * log((post_pos*(pre_pos + pre_neg) + 1e-4) / ((post_pos + post_neg)*pre_pos))
327+
328+
return max(literals, key=gain)
329+
330+
def update_examples(self, target, examples, extended_examples):
331+
"""Adds to the kb those examples what are represented in extended_examples
332+
List of omitted examples is returned"""
333+
uncovered = []
334+
for example in examples:
335+
def represents(d):
336+
return all(d[x] == example[x] for x in example)
337+
if any(represents(l) for l in extended_examples):
338+
self.tell(subst(example, target))
339+
else:
340+
uncovered.append(example)
341+
342+
return uncovered
343+
344+
345+
# ______________________________________________________________________________
346+
347+
234348
def check_all_consistency(examples, h):
235349
"""Check for the consistency of all examples under h"""
236350
for e in examples:

logic.py

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,7 @@ def tt_entails(kb, alpha):
196196
True
197197
"""
198198
assert not variables(alpha)
199-
symbols = prop_symbols(kb & alpha)
199+
symbols = list(prop_symbols(kb & alpha))
200200
return tt_check_all(kb, alpha, symbols, {})
201201

202202

@@ -216,23 +216,33 @@ def tt_check_all(kb, alpha, symbols, model):
216216

217217

218218
def prop_symbols(x):
219-
"""Return a list of all propositional symbols in x."""
219+
"""Return the set of all propositional symbols in x."""
220220
if not isinstance(x, Expr):
221-
return []
221+
return set()
222222
elif is_prop_symbol(x.op):
223-
return [x]
223+
return {x}
224224
else:
225-
return list(set(symbol for arg in x.args for symbol in prop_symbols(arg)))
225+
return {symbol for arg in x.args for symbol in prop_symbols(arg)}
226226

227227

228228
def constant_symbols(x):
229-
"""Return a list of all constant symbols in x."""
229+
"""Return the set of all constant symbols in x."""
230230
if not isinstance(x, Expr):
231-
return []
231+
return set()
232232
elif is_prop_symbol(x.op) and not x.args:
233-
return [x]
233+
return {x}
234234
else:
235-
return list({symbol for arg in x.args for symbol in constant_symbols(arg)})
235+
return {symbol for arg in x.args for symbol in constant_symbols(arg)}
236+
237+
238+
def predicate_symbols(x):
239+
"""Return a set of (symbol_name, arity) in x.
240+
All symbols (even functional) with arity > 0 are considered."""
241+
if not isinstance(x, Expr) or not x.args:
242+
return set()
243+
pred_set = {(x.op, len(x.args))} if is_prop_symbol(x.op) else set()
244+
pred_set.update({symbol for arg in x.args for symbol in predicate_symbols(arg)})
245+
return pred_set
236246

237247

238248
def tt_true(s):
@@ -549,7 +559,7 @@ def dpll_satisfiable(s):
549559
function find_pure_symbol is passed a list of unknown clauses, rather
550560
than a list of all clauses and the model; this is more efficient."""
551561
clauses = conjuncts(to_cnf(s))
552-
symbols = prop_symbols(s)
562+
symbols = list(prop_symbols(s))
553563
return dpll(clauses, symbols, {})
554564

555565

@@ -652,7 +662,7 @@ def WalkSAT(clauses, p=0.5, max_flips=10000):
652662
"""Checks for satisfiability of all clauses by randomly flipping values of variables
653663
"""
654664
# Set of all symbols in all clauses
655-
symbols = set(sym for clause in clauses for sym in prop_symbols(clause))
665+
symbols = {sym for clause in clauses for sym in prop_symbols(clause)}
656666
# model is a random assignment of true/false to the symbols in clauses
657667
model = {s: random.choice([True, False]) for s in symbols}
658668
for i in range(max_flips):
@@ -663,7 +673,7 @@ def WalkSAT(clauses, p=0.5, max_flips=10000):
663673
return model
664674
clause = random.choice(unsatisfied)
665675
if probability(p):
666-
sym = random.choice(prop_symbols(clause))
676+
sym = random.choice(list(prop_symbols(clause)))
667677
else:
668678
# Flip the symbol in clause that maximizes number of sat. clauses
669679
def sat_count(sym):

0 commit comments

Comments
 (0)