1、代码
from numpy import *
def createVocabList(dataSet):
    """Build the vocabulary: every distinct token found in the documents.

    Args:
        dataSet: Iterable of documents, each an iterable of hashable tokens.

    Returns:
        A list containing each distinct token exactly once, ordered by
        first appearance.
    """
    # dict keys are unique and remember insertion order, so the vocabulary
    # (and therefore the column order of every feature vector built from it)
    # is the same on every run; iterating a set of strings is not.
    return list(dict.fromkeys(word for document in dataSet for word in document))
def setOfWords2Vec(vocabList, inputSet):
    """Convert a document into a set-of-words (presence/absence) vector.

    Args:
        vocabList: List of vocabulary tokens; position i of the result
            corresponds to vocabList[i].
        inputSet: Iterable of hashable tokens making up one document.

    Returns:
        A list of 0/1 ints with the same length as vocabList, where 1 marks
        a vocabulary token that occurs in inputSet. Tokens that are not in
        the vocabulary are reported on stdout and otherwise ignored.
    """
    # Map each token to its first position once, instead of scanning the
    # vocabulary list twice (``in`` + ``.index``) for every input word.
    # setdefault keeps the first index, matching list.index on duplicates.
    position = {}
    for idx, token in enumerate(vocabList):
        position.setdefault(token, idx)

    returnVec = [0] * len(vocabList)
    for word in inputSet:
        idx = position.get(word)
        if idx is None:
            print("the word: %s is not in my Vocabulary!" % word)
        else:
            returnVec[idx] = 1
    return returnVec
def trainNB0(trainMatrix, trainCategory):
    """Train a two-class naive Bayes model on word-occurrence vectors.

    Args:
        trainMatrix: 2-D array-like, one row per document and one column
            per vocabulary word.
        trainCategory: 1-D array-like of labels, one per document. Rows
            whose label equals 1 form class 1; every other row forms the
            second class.

    Returns:
        A tuple ``(p1vects, p2vects, pbsive)``:
        p1vects -- log of the smoothed per-word frequencies in class 1,
        p2vects -- log of the smoothed per-word frequencies in the other class,
        pbsive  -- sum of the labels divided by the number of documents
                   (the class-1 prior when labels are 0/1).

    Raises:
        ValueError: If there are no training documents, or the number of
            labels differs from the number of documents.
    """
    matrix = asarray(trainMatrix)
    labels = asarray(trainCategory)
    numtraindocs = len(matrix)
    if numtraindocs == 0:
        raise ValueError("trainMatrix must contain at least one document")
    if len(labels) != numtraindocs:
        raise ValueError("trainCategory must have one label per document")
    numwords = len(matrix[0])

    pbsive = labels.sum() / float(numtraindocs)

    # Select the rows of each class with a boolean mask and sum them in one
    # vectorised step instead of looping over documents in Python.
    is_class1 = labels == 1
    class1_rows = matrix[is_class1]
    other_rows = matrix[~is_class1]

    # Smoothing: every word count starts at 1 and every denominator at 2.0,
    # so a word never seen in a class does not produce log(0).
    p1nums = ones(numwords) + class1_rows.sum(axis=0)
    p2nums = ones(numwords) + other_rows.sum(axis=0)
    p1 = 2.0 + class1_rows.sum()
    p2 = 2.0 + other_rows.sum()

    # Work in log space so the classifier can add terms instead of
    # multiplying many small probabilities.
    p1vects = log(p1nums / p1)
    p2vects = log(p2nums / p2)
    return p1vects, p2vects, pbsive
def classifyNB(testvect, p1vect, p2vect, pbsive):
    """Classify one word vector with a trained two-class naive Bayes model.

    Args:
        testvect: 1-D array-like word vector of the document to classify.
        p1vect: Per-word log probabilities for class 1.
        p2vect: Per-word log probabilities for the other class.
        pbsive: Prior probability of class 1.

    Returns:
        1 if the class-1 log score is strictly greater than the other
        class's score, otherwise 0 (ties go to 0).
    """
    # Convert once so plain Python lists are accepted as well as arrays.
    features = asarray(testvect)

    # ``*`` on arrays is element-wise: each word's log probability is
    # weighted by its entry in the document vector, then the terms are
    # summed and the log prior is added.
    score_class1 = (features * asarray(p1vect)).sum() + log(pbsive)
    score_other = (features * asarray(p2vect)).sum() + log(1 - pbsive)

    return 1 if score_class1 > score_other else 0
def textparse(bigstring):
    """Split raw text into lower-cased word tokens longer than two characters.

    Args:
        bigstring: The text to tokenize.

    Returns:
        A list of lower-cased tokens, in order of appearance, keeping only
        tokens with more than two characters.
    """
    import re
    # Split on runs of one or more non-word characters. The previous pattern
    # r'\W*' can match the empty string, and on Python 3.7+ re.split also
    # splits on empty matches, which cuts the text into single characters so
    # that the length filter below discards everything.
    listwords = re.split(r'\W+', bigstring)
    return [tok.lower() for tok in listwords if len(tok) > 2]
def spamtest():
    """Run one hold-out evaluation of the naive Bayes spam classifier.

    Reads email/spam/1.txt .. 25.txt (label 1) and email/ham/1.txt .. 25.txt
    (label 0), randomly holds out 10 documents for testing, trains on the
    rest and prints the error rate measured on the held-out documents.
    The parsed word list of every spam email is also printed while loading.
    """
    doclist = []
    classlist = []
    for i in range(1, 26):
        # Use an explicit encoding instead of the platform default (which is
        # what triggers the 'gbk' UnicodeDecodeError on some systems), and
        # replace undecodable bytes rather than aborting on a stray byte.
        # ``with`` guarantees the file handle is closed.
        with open('email/spam/%d.txt' % i, encoding='utf-8', errors='replace') as spam_file:
            wordlist = textparse(spam_file.read())
        print(wordlist)
        doclist.append(wordlist)
        classlist.append(1)

        with open('email/ham/%d.txt' % i, encoding='utf-8', errors='replace') as ham_file:
            wordlist = textparse(ham_file.read())
        doclist.append(wordlist)
        classlist.append(0)

    vocablist = createVocabList(doclist)

    # Must be a real list: a range object does not support item deletion.
    trainingset = list(range(len(doclist)))
    testset = []
    for _ in range(10):
        # Move a randomly chosen document index from the training pool to
        # the test set so the two sets never overlap.
        randindex = int(random.uniform(0, len(trainingset)))
        testset.append(trainingset.pop(randindex))

    trainmat = [setOfWords2Vec(vocablist, doclist[i]) for i in trainingset]
    trainclasses = [classlist[i] for i in trainingset]

    # trainNB0 returns the class-1 (spam) vector first and the class-0 (ham)
    # vector second, which is the order classifyNB expects.
    spam_logp, ham_logp, pbsive = trainNB0(array(trainmat), array(trainclasses))

    errorcount = 0
    for docindex in testset:
        wordvect = setOfWords2Vec(vocablist, doclist[docindex])
        predicted = classifyNB(array(wordvect), spam_logp, ham_logp, pbsive)
        if predicted != classlist[docindex]:
            errorcount += 1
    print('the error is:', float(errorcount) / len(testset))
注意事项:
python3报错解决办法:UnicodeDecodeError: 'gbk' codec can't decode byte 0xae in position 199: illegal multibyte sequence
解决办法:打开email\ham\23.txt,找到SciFinance?,把?替换成空格即可。
'range' object doesn't support item deletion
python3.x , 出现错误 'range' object doesn't support item deletion
原因:python3.x range返回的是range对象,不返回数组对象
解决方法:
把 trainingset = range(50) 改为 trainingset = list(range(50))
注意:array对象用 * 相乘是逐元素相乘(即对应元素相乘),而matrix对象用 * 相乘是矩阵乘法。