@@ -306,13 +306,35 @@ def predict(example):
306
306
# ______________________________________________________________________________
307
307
308
308
309
- def NaiveBayesLearner (dataset , continuous = True ):
309
+ def NaiveBayesLearner (dataset , continuous = True , simple = False ):
310
+ if simple :
311
+ return NaiveBayesSimple (dataset )
310
312
if (continuous ):
311
313
return NaiveBayesContinuous (dataset )
312
314
else :
313
315
return NaiveBayesDiscrete (dataset )
314
316
315
317
318
+ def NaiveBayesSimple (distribution ):
319
+ """A simple naive bayes classifier that takes as input a dictionary of
320
+ CountingProbDist objects and classifies items according to these distributions.
321
+ The input dictionary is in the following form:
322
+ (ClassName, ClassProb): CountingProbDist"""
323
+ target_dist = {c_name : prob for c_name , prob in distribution .keys ()}
324
+ attr_dists = {c_name : count_prob for (c_name , _ ), count_prob in distribution .items ()}
325
+
326
+ def predict (example ):
327
+ """Predict the target value for example. Calculate probabilities for each
328
+ class and pick the max."""
329
+ def class_probability (targetval ):
330
+ attr_dist = attr_dists [targetval ]
331
+ return target_dist [targetval ] * product (attr_dist [a ] for a in example )
332
+
333
+ return argmax (target_dist .keys (), key = class_probability )
334
+
335
+ return predict
336
+
337
+
316
338
def NaiveBayesDiscrete (dataset ):
317
339
"""Just count how many times each value of each input attribute
318
340
occurs, conditional on the target value. Count the different
0 commit comments