qust 朴素贝叶斯垃圾邮件分类_cipherxxx的博客-爱代码爱编程
import math
import os
import random
import re
import string
# Root directory of the Enron-Spam corpus; get_data() reads the
# subfolders enron1 ... enron6 inside it.
DATA_DIR = "enron"

# Human-readable class names indexed by label: 0 -> ham, 1 -> spam
# (the labels get_data() assigns).
target_names = ["ham", "spam"]
def get_data(DATA_DIR, subfolders=None):
    """Load the Enron-Spam corpus from disk.

    Args:
        DATA_DIR: Root directory containing the corpus subfolders.
        subfolders: Names of the subfolders to read; each must contain a
            ``spam`` and a ``ham`` directory. Defaults to ``enron1`` ...
            ``enron6``.

    Returns:
        A ``(data, target)`` pair of parallel lists: ``data`` holds the raw
        message texts and ``target`` the labels (1 = spam, 0 = ham). Within
        each subfolder the spam messages come first, then the ham messages,
        each group in sorted file-name order.
    """
    if subfolders is None:
        subfolders = ['enron%d' % i for i in range(1, 7)]
    data = []
    target = []
    for subfolder in subfolders:
        # Spam before ham within each subfolder; labels are spam=1, ham=0.
        for label_dir, label in (('spam', 1), ('ham', 0)):
            directory = os.path.join(DATA_DIR, subfolder, label_dir)
            # os.listdir() returns entries in arbitrary order; sort so the
            # loaded order (and any slice-based train/test split) is reproducible.
            for file_name in sorted(os.listdir(directory)):
                # latin-1 maps every byte value, so no file fails to decode.
                with open(os.path.join(directory, file_name), encoding="latin-1") as f:
                    data.append(f.read())
                target.append(label)
    return data, target
# Load every message of the corpus: X holds the raw texts, y the 0/1 labels.
(X, y) = get_data(DATA_DIR=DATA_DIR)
class SpamDetector_1(object):
    """Implementation of Naive Bayes for binary classification.

    This base class holds only the text-preprocessing helpers shared by the
    subclasses: punctuation stripping, tokenization and word counting.
    """

    # Translation table that deletes ASCII punctuation; built once at class
    # definition instead of on every clean() call.
    _PUNCTUATION_TABLE = str.maketrans("", "", string.punctuation)

    # Compiled once. Raw string so "\w" is a regex class rather than an
    # invalid string escape (a SyntaxWarning on Python 3.12+).
    _WORD_RE = re.compile(r"\w+")

    def clean(self, s):
        """Return *s* with every character of ``string.punctuation`` removed."""
        return s.translate(self._PUNCTUATION_TABLE)

    def tokenize(self, text):
        """Split *text* into lower-cased word tokens.

        Punctuation is deleted first (so "don't" becomes "dont"), then the
        text is broken into maximal runs of word characters. Unlike
        ``re.split(r"\\W+", ...)``, this never yields empty strings for
        leading/trailing separators, so ``""`` is not counted as a word.
        """
        return self._WORD_RE.findall(self.clean(text).lower())

    def get_word_counts(self, words):
        """Return a dict mapping each word in *words* to its count (as a float)."""
        word_counts = {}
        for word in words:
            word_counts[word] = word_counts.get(word, 0.0) + 1.0
        return word_counts
class SpamDetector_2(SpamDetector_1):
    """Adds training (``fit``) to the preprocessing helpers of SpamDetector_1."""

    def fit(self, X, Y):
        """Estimate class priors and per-class word counts from labelled data.

        Args:
            X: Iterable of raw message texts.
            Y: Iterable of labels parallel to *X*; 1 marks spam and any other
                value is treated as ham (the same mapping is used for the
                priors and the word counts).

        Sets:
            num_messages: ``{'spam': n, 'ham': n}`` message counts.
            log_class_priors: natural-log prior probability of each class.
            word_counts: ``{'spam': {...}, 'ham': {...}}`` mapping each word to
                its total number of occurrences across that class's messages.
            vocab: set of every token seen in training.

        Raises:
            ValueError: if the labels contain no spam or no ham message, since
                a zero prior has no logarithm.
        """
        # Materialize once so Y may be any iterable (it is consumed twice below).
        classes = ['spam' if label == 1 else 'ham' for label in Y]
        self.num_messages = {
            'spam': classes.count('spam'),
            'ham': classes.count('ham'),
        }
        total = len(classes)
        if self.num_messages['spam'] == 0 or self.num_messages['ham'] == 0:
            raise ValueError("fit() needs at least one spam and one ham message")
        self.log_class_priors = {
            c: math.log(n / total) for c, n in self.num_messages.items()
        }

        self.word_counts = {'spam': {}, 'ham': {}}
        self.vocab = set()
        for text, c in zip(X, classes):
            class_counts = self.word_counts[c]
            for word, count in self.get_word_counts(self.tokenize(text)).items():
                self.vocab.add(word)
                class_counts[word] = class_counts.get(word, 0.0) + count
# No model is trained at this point. MNB is created from the complete
# SpamDetector class (which adds predict) and fitted further down the script.
# Fitting a SpamDetector_2 here as well would repeat the entire training pass,
# only to be thrown away when MNB is reassigned.
class SpamDetector(SpamDetector_2):
    """Naive Bayes spam detector: SpamDetector_2 training plus prediction."""

    def predict(self, X):
        """Classify each message in *X* as spam (1) or ham (0).

        Each distinct in-vocabulary word of a message contributes its
        add-one-smoothed log-likelihood once per class (repeated occurrences
        within the message are not weighted); words never seen in training
        are ignored. A message is labelled spam only when its spam score is
        strictly greater than its ham score.

        Args:
            X: Iterable of raw message texts.

        Returns:
            A list of 0/1 predictions, one per message.
        """
        vocab_size = len(self.vocab)
        spam_counts = self.word_counts['spam']
        ham_counts = self.word_counts['ham']
        # The smoothing denominators are the same for every word, so compute
        # them once instead of re-summing each class's counts per word.
        spam_denominator = sum(spam_counts.values()) + vocab_size
        ham_denominator = sum(ham_counts.values()) + vocab_size

        result = []
        for x in X:
            spam_score = 0
            ham_score = 0
            for word in self.get_word_counts(self.tokenize(x)):
                if word not in self.vocab:
                    continue
                # Add-one smoothing: a word absent from one class counts as 0
                # there, giving 1 / denominator for that class.
                spam_score += math.log((spam_counts.get(word, 0.0) + 1) / spam_denominator)
                ham_score += math.log((ham_counts.get(word, 0.0) + 1) / ham_denominator)
            spam_score += self.log_class_priors['spam']
            ham_score += self.log_class_priors['ham']
            result.append(1 if spam_score > ham_score else 0)
        return result
# Hold out TEST_SIZE messages for evaluation and train on the rest.
# get_data() returns messages grouped by subfolder and class (each
# subfolder's spam before its ham). Slicing the raw lists would therefore take
# the whole test set from the first subfolder's spam files (assuming it holds
# at least TEST_SIZE of them) and measure spam recall only. Shuffle the indices
# with a fixed seed so the split is not ordered by class and stays reproducible.
TEST_SIZE = 100
indices = list(range(len(X)))
random.Random(0).shuffle(indices)
test_idx, train_idx = indices[:TEST_SIZE], indices[TEST_SIZE:]

MNB = SpamDetector()
MNB.fit([X[i] for i in train_idx], [y[i] for i in train_idx])
pred = MNB.predict([X[i] for i in test_idx])
true = [y[i] for i in test_idx]

# Fraction of held-out messages classified correctly. zip/len replace a
# hard-coded range(100) so a smaller corpus does not raise IndexError.
correct = sum(1 for p, t in zip(pred, true) if p == t)
accuracy = correct / len(true) if true else 0.0
print("Accuracy: %.2f%% (%d/%d)" % (100 * accuracy, correct, len(true)))
运行结果:
分类正确率为98%