Thanks for visiting codestin.com
Credit goes to github.com

Skip to content

Commit 63458a9

Browse files
antmarakis authored and norvig committed
Learning: Naive Bayes Classifier (aimacode#618)
* add a simple naive bayes classifier
* Update test_learning.py
* spacing
* minor fix
* lists to strings
1 parent a58fe90 commit 63458a9

File tree

2 files changed

+37
-1
lines changed

2 files changed

+37
-1
lines changed

learning.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -306,13 +306,35 @@ def predict(example):
306306
# ______________________________________________________________________________
307307

308308

309-
def NaiveBayesLearner(dataset, continuous=True):
309+
def NaiveBayesLearner(dataset, continuous=True, simple=False):
    """Return a naive Bayes predictor of the requested flavor.

    The `simple` flag takes precedence: when it is truthy, `dataset` is
    handed to NaiveBayesSimple. Otherwise `continuous` selects between
    NaiveBayesContinuous (truthy) and NaiveBayesDiscrete (falsy).
    """
    # Pick the concrete learner first, then build it with a single call.
    if simple:
        learner = NaiveBayesSimple
    elif continuous:
        learner = NaiveBayesContinuous
    else:
        learner = NaiveBayesDiscrete
    return learner(dataset)
314316

315317

318+
def NaiveBayesSimple(distribution):
    """Return a naive Bayes predictor built from ready-made distributions.

    `distribution` is a dictionary of the form
        (ClassName, ClassProb): CountingProbDist
    i.e. each key pairs a class name with that class's probability, and each
    value is the distribution used to score items for that class.
    """
    target_dist = {}  # class name -> class probability
    attr_dists = {}   # class name -> distribution object for that class
    for (c_name, prob), count_prob in distribution.items():
        target_dist[c_name] = prob
        attr_dists[c_name] = count_prob

    def predict(example):
        """Return the class name that scores highest for `example`.

        A class's score is its class probability multiplied by the product
        of that class's distribution looked up at every item of `example`.
        """
        def class_probability(targetval):
            class_dist = attr_dists[targetval]
            return (target_dist[targetval] *
                    product(class_dist[a] for a in example))

        return argmax(target_dist.keys(), key=class_probability)

    return predict
336+
337+
316338
def NaiveBayesDiscrete(dataset):
317339
"""Just count how many times each value of each input attribute
318340
occurs, conditional on the target value. Count the different

tests/test_learning.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,20 @@ def test_naive_bayes():
105105
assert nBC([6, 5, 3, 1.5]) == "versicolor"
106106
assert nBC([7, 3, 6.5, 2]) == "virginica"
107107

108+
# Simple
109+
data1 = 'a'*50 + 'b'*30 + 'c'*15
110+
dist1 = CountingProbDist(data1)
111+
data2 = 'a'*30 + 'b'*45 + 'c'*20
112+
dist2 = CountingProbDist(data2)
113+
data3 = 'a'*20 + 'b'*20 + 'c'*35
114+
dist3 = CountingProbDist(data3)
115+
116+
dist = {('First', 0.5): dist1, ('Second', 0.3): dist2, ('Third', 0.2): dist3}
117+
nBS = NaiveBayesLearner(dist, simple=True)
118+
assert nBS('aab') == 'First'
119+
assert nBS(['b', 'b']) == 'Second'
120+
assert nBS('ccbcc') == 'Third'
121+
108122

109123
def test_k_nearest_neighbors():
110124
iris = DataSet(name="iris")

0 commit comments

Comments
 (0)