Knn算法智能识别验证码数字

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/coldence/article/details/78874037

1. 首先,需要写一个爬取图片的程序,获取大量验证码素材。下面用Python 2实现(代码使用了Python 2的print语句和urllib.urlopen接口)。

#coding=utf-8
import urllib
import re
import time
import socket
#USER_AGENT='Mozilla/5.0 (X11;Ubuntu;Linux x86_64;rv:40.0)Gecko/20100101 Firefox/40.0'
#HEADERS='User-agent:'+USER_AGENT
#print HEADERS
def getHtml(url):
    """Fetch ``url`` and return the raw response body.

    Uses the Python 2 ``urllib.urlopen`` API, like the rest of this script.
    """
    page = urllib.urlopen(url)
    try:
        return page.read()
    finally:
        # Release the underlying connection even if read() raises.
        page.close()


def Schedule(a, b, c):
    '''Progress hook for urllib.urlretrieve: print the percentage downloaded.

    a: number of blocks transferred so far
    b: size of one block in bytes
    c: total size of the remote file in bytes; urlretrieve may pass -1
       when the size is unknown
    '''
    if c <= 0:
        # Unknown (or zero) total size: no meaningful percentage exists, and
        # dividing would raise ZeroDivisionError or print a negative value.
        return
    # The last block may overshoot the file size, so clamp at 100%.
    per = min(100.0 * a * b / c, 100)
    print('%.2f%%' % per)

def auto_down(url, filename, Schedule, max_retries=5):
    '''Download ``url`` to ``filename``, retrying on transient network errors.

    url: address of the file to fetch
    filename: local path to write to
    Schedule: progress hook passed straight through to urllib.urlretrieve
    max_retries: number of extra attempts after the first failed one

    Returns True as soon as one download succeeds. Returns False once the
    first attempt plus ``max_retries`` retries have all failed with
    urllib.ContentTooShortError or socket.error. Any other exception
    propagates to the caller.
    '''
    # A loop (instead of recursion) keeps the attempt counter across retries,
    # so the retry limit is actually enforced and the final outcome is what
    # gets returned.
    for attempt in range(max(0, max_retries) + 1):
        try:
            urllib.urlretrieve(url, filename, Schedule)
            return True
        except urllib.ContentTooShortError:
            reason = 'Network conditions is not good.'
        except socket.error:
            reason = 'Socket error.'
        if attempt < max_retries:
            print('%sReloading...%s' % (reason, attempt))
    print('Download failed.Connecting to next image.')
    return False

def getImg(html, save_dir='C:/Joe_Workspace/reptile_workspace/jpg'):
    """Download every ``.jpg`` image referenced by a Tieba post page.

    html: page source; image URLs are taken from ``src="....jpg" pic_ext``
        attributes.
    save_dir: directory the images are written to, as ``<save_dir>/<n>.jpg``
        with ``n`` starting at 1.

    Prints per-image progress and a final summary, and returns the list of
    matched image URLs.
    """
    imgre = re.compile(r'src="(.+?\.jpg)" pic_ext')
    imglist = imgre.findall(html)
    if not imglist:
        # Without this guard the success-rate summary below had nothing to
        # report on (the original raised UnboundLocalError here).
        print('[result] No images found.')
        return imglist

    success_num = 0
    for x, imgurl in enumerate(imglist, 1):
        print('=============Image No.%s=============' % x)
        rst = auto_down(imgurl, '%s/%s.jpg' % (save_dir, x), Schedule)
        print('Image No.%s download finish.' % x)
        if rst:
            success_num += 1

    total = len(imglist)
    if success_num == total:
        print("[result] All %s images have been downloaded successfully." % success_num)
    else:
        # Both values must go into one tuple; chaining two % operators
        # raised TypeError.
        print("[result] %s/%s images have been downloaded successfully." % (success_num, total))
    return imglist

# Fetch one Tieba thread, download the .jpg images it references, and print
# the list of image URLs that getImg returns.
TIEBA_THREAD_URL = "http://tieba.baidu.com/p/2460150866"
html = getHtml(TIEBA_THREAD_URL)
print(getImg(html))

结果如下:

2. 从网上下载的验证码颜色和形状变化较多,为了使机器学习效果更显著,我们先用程序生成简单的验证码(白底黑字,每张图只有一个数字)来做。下面使用Java实现。

import java.awt.Color; 
import java.awt.Font; 
import java.awt.Graphics2D; 
import java.awt.image.BufferedImage; 
import java.io.File; 
import java.io.FileNotFoundException; 
import java.io.FileOutputStream; 
import java.io.IOException; 
import java.io.OutputStream; 
import java.util.Random; 

import javax.imageio.ImageIO; 

//import org.junit.Test; 

/**
 * Generates single-character captcha images used as kNN training data.
 *
 * <p>Each image is {@code w x h} (32x32) pixels: a white background with one
 * black character taken from {@link #codes}, drawn in a font picked from
 * {@link #fontNames}. {@link #main(String[])} renders every digit 0-9 once per
 * entry of {@link #fontNames_all} and saves each image as
 * {@code C:\Joe_Workspace\image\<digit>_<fontIndex>.jpg}.
 *
 * <p>NOTE(review): images are written as lossy JPEG, which may leave
 * near-black (not pure black) pixels around glyph edges, while the Python
 * image-to-text step in this article only counts pixels whose R+G+B == 0 as
 * foreground. A lossless format (e.g. "png") may be safer; confirm before
 * changing, since the ".jpg" names are part of the pipeline.
 */
public class VerificationCode {
    // Image size in pixels.
    private int w = 32;
    private int h = 32;
    // Text drawn by the most recent getImage() call.
    private String text;
    private Random r = new Random();
    // Every font main() cycles through; the array index becomes the "_<i>"
    // part of each file name. Names must match installed fonts exactly: AWT
    // typically substitutes its default (Dialog) font for names it cannot find.
    // NOTE(review): "Bookshelf Symbil 7" is misspelled and so falls back to
    // the default font. It is left as-is because the real "Bookshelf Symbol 7"
    // is presumably a symbol font that would not draw ordinary digit glyphs;
    // confirm, then fix or drop this entry.
    public static String[] fontNames_all = {"Arial", "BatangChe", "Bell MT", "Arial Narrow",
            "Arial Rounded MT Bold", "Bookman Old Style", "Bookshelf Symbil 7", "Calibri Light",
            "Calibri", "Arial Black", "Batang", "Bodoni MT Black"};
    // Background fill colour (white).
    Color bgColor = new Color(255, 255, 255);
    // Fonts a single image may use; main() sets this to one font per pass.
    public static String[] fontNames = new String[1];
    // Characters a single image may draw; main() sets this to the current digit.
    public static String codes = "";

    /**
     * Renders each digit 0-9 once in every font of {@link #fontNames_all}.
     */
    public static void main(String[] args) {
        for (int num = 0; num < 10; num++) {
            for (int i = 0; i < fontNames_all.length; i++) {
                fontNames[0] = fontNames_all[i];
                codes = "" + num;
                test_fun(num, i);
            }
        }
    }

    /**
     * Generates one image from the current {@link #fontNames} and
     * {@link #codes}, saves it as {@code <num>_<i>.jpg}, and prints the drawn
     * text.
     *
     * @param num digit being rendered (first part of the file name)
     * @param i   font index (second part of the file name)
     */
    public static void test_fun(int num, int i) {
        VerificationCode vc = new VerificationCode();
        BufferedImage image = vc.getImage();
        OutputStream out = null;
        try {
            out = new FileOutputStream(new File(
                    "C:\\Joe_Workspace\\image\\" + num + "_" + i + ".jpg"));
            VerificationCode.output(image, out);
        } catch (FileNotFoundException e) {
            e.printStackTrace();
        } finally {
            // ImageIO.write does not close the stream it is given, so close it
            // here to avoid leaking one file handle per generated image.
            if (out != null) {
                try {
                    out.close();
                } catch (IOException e) {
                    e.printStackTrace();
                }
            }
        }
        System.out.println(vc.getText());
    }

    /**
     * Draws a new image containing one random character from {@link #codes}.
     *
     * @return the rendered image; {@link #getText()} afterwards returns the
     *         character that was drawn
     */
    public BufferedImage getImage() {
        BufferedImage image = createImage();
        Graphics2D g2 = (Graphics2D) image.getGraphics();
        StringBuilder sb = new StringBuilder();
        try {
            // The loop bound of 1 means exactly one character per image.
            for (int i = 0; i < 1; ++i) {
                String s = randomChar() + "";
                sb.append(s);
                float x = i * 1.0F * w / 4 + 9;
                g2.setFont(randomFont());
                g2.setColor(randomColor());
                g2.drawString(s, x, h - 7);
            }
        } finally {
            g2.dispose();
        }
        this.text = sb.toString();
        return image;
    }

    /**
     * @return the text drawn by the most recent {@link #getImage()} call
     */
    public String getText() {
        return text;
    }

    /**
     * Writes {@code image} to {@code out} in JPEG format. Does not close
     * {@code out}; the caller owns the stream.
     *
     * @param image image to encode
     * @param out   destination stream
     */
    public static void output(BufferedImage image, OutputStream out) {
        try {
            ImageIO.write(image, "jpeg", out);
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    // Draws three random blue noise lines; currently not called anywhere.
    private void drawLine(BufferedImage image) {
        Graphics2D g2 = (Graphics2D) image.getGraphics();
        try {
            for (int i = 0; i < 3; ++i) {
                int x1 = r.nextInt(w);
                int y1 = r.nextInt(h);
                int x2 = r.nextInt(w);
                int y2 = r.nextInt(h);
                g2.setColor(Color.BLUE);
                g2.drawLine(x1, y1, x2, y2);
            }
        } finally {
            g2.dispose();
        }
    }

    // Glyph colour. Always pure black, which is the only colour the Python
    // image-to-text step treats as foreground.
    private Color randomColor() {
        return Color.BLACK;
    }

    // Random font from fontNames with a random style (0-3: plain, bold,
    // italic, bold|italic) and a size of 24-28 points.
    private Font randomFont() {
        int index = r.nextInt(fontNames.length);
        String fontName = fontNames[index];
        int style = r.nextInt(4);
        int size = r.nextInt(5) + 24;
        return new Font(fontName, style, size);
    }

    // Random character from codes.
    private char randomChar() {
        int index = r.nextInt(codes.length());
        return codes.charAt(index);
    }

    // Blank w x h RGB image filled with the background colour.
    private BufferedImage createImage() {
        BufferedImage image = new BufferedImage(w, h, BufferedImage.TYPE_INT_RGB);
        Graphics2D g2 = (Graphics2D) image.getGraphics();
        try {
            g2.setColor(this.bgColor);
            g2.fillRect(0, 0, w, h);
        } finally {
            g2.dispose();
        }
        return image;
    }

}

3. 然后,将生成的验证码图片转换成由0和1组成的字符矩阵(纯黑像素记为1,其余像素记为0,图片的每一行像素对应txt中的一行),存入txt文件。以下用Python实现。

from PIL import Image
import os
from os import listdir

def img2txt_func(img_path_1, txt_path_1, threshold=0):
    """Convert an image into a text matrix of '1' (dark) and '0' (light).

    img_path_1: path of the source image (any mode PIL can convert to RGB)
    txt_path_1: path of the text file to (over)write
    threshold: a pixel is written as '1' when R+G+B <= threshold. The default
        of 0 counts only pure-black pixels, as before. A small positive value
        tolerates near-black JPEG compression artifacts.

    The output has one line per pixel row (top to bottom), with one character
    per pixel (left to right).
    """
    # convert('RGB') lets grayscale/palette images work too (getpixel would
    # otherwise return an int, not a tuple); RGB/RGBA values are unchanged.
    im = Image.open(img_path_1).convert('RGB')
    width, height = im.size

    lines = []
    # getpixel takes (x, y): rows (y) outside, columns (x) inside. The
    # original swapped the ranges, which only worked for square images.
    for y in range(height):
        row = []
        for x in range(width):
            red, green, blue = im.getpixel((x, y))
            row.append('1' if red + green + blue <= threshold else '0')
        lines.append(''.join(row) + '\n')

    # Single handle, closed by 'with' (the original opened the file twice
    # and leaked the first handle).
    with open(txt_path_1, 'w') as fh:
        fh.write(''.join(lines))

# Convert every image in img_path into a same-named .txt matrix in txt_path.
# NOTE(review): the Java generator in this article writes to
# C:\Joe_Workspace\image (singular) while this reads .../images; confirm the
# directory is intended.
img_path = "c:/Joe_Workspace/images"
txt_path = "c:/Joe_Workspace/traindata"
imgs = listdir(img_path)
for img in imgs:
    source_file = img_path + "/" + os.path.basename(img)
    target_file = txt_path + "/" + os.path.splitext(img)[0] + ".txt"
    img2txt_func(source_file, target_file)

4. 最后,用Python实现KNN算法来识别数字:把每个32×32的0/1矩阵展开成1024维向量,计算测试样本与所有训练样本之间的欧氏距离,取距离最近的k个训练样本(代码中k=5),以它们标签中出现次数最多的那个作为识别结果。

from numpy import *
import operator
from os import listdir
import os

from numpy.matlib import zeros

def pow_func_2nd_arry(dif, sqnum, arr_len):
    """Element-wise power of a 2-D array, after truncating each entry to int.

    For every row of ``dif``, the first ``arr_len`` entries are truncated with
    ``int()`` and raised to ``sqnum``. Returns a ``len(dif) x arr_len``
    container built with ``zeros`` (imported from numpy.matlib in this file).
    """
    result = zeros((len(dif), arr_len))
    for row_idx, row in enumerate(dif):
        values = mat(row).tolist()[0]
        result[row_idx, :] = [int(values[col]) ** sqnum for col in range(arr_len)]
    return result

def pow_func_1nd_arry(dif, sqnum, arr_len):
    """Turn a column of values into a single row of their powers.

    Takes the first element of each of the first ``arr_len`` rows of ``dif``
    (a column vector), truncates it with ``int()`` and raises it to ``sqnum``.
    Returns a ``1 x arr_len`` container built with ``zeros`` (imported from
    numpy.matlib in this file).
    """
    powered_row = zeros((1, arr_len))
    column = mat(dif).tolist()
    idx = 0
    while idx < arr_len:
        powered_row[:, idx] = int(column[idx][0]) ** sqnum
        idx += 1
    return powered_row

def array2list(dif, arr_len):
    """Return the first ``arr_len`` entries of the first row of ``dif`` as ints.

    Handy for turning a ``1 x N`` matrix (such as argsort output) into a
    plain Python list of indices.
    """
    rows = mat(dif).tolist()
    return [int(rows[0][col]) for col in range(arr_len)]

def knn(k, testdata, traindata, labels):
    """Classify one sample by majority vote of its k nearest training samples.

    k: number of neighbours that vote; capped at the number of training rows.
    testdata: the sample to classify, flattened to one row (e.g. the list
        returned by data2array).
    traindata: 2-D array or matrix with one training sample per row, the
        same width as testdata (width is no longer hard-coded to 1024).
    labels: labels[i] is the label of traindata row i.

    Distance is Euclidean. Returns the label with the most votes; ties go to
    the label whose voting neighbour is nearest, and equally distant rows
    keep their training order, so the result is deterministic.

    Raises ValueError if k < 1 or traindata has no rows.
    """
    if k < 1:
        raise ValueError("k must be at least 1, got %r" % (k,))
    train = asarray(traindata, dtype=float)
    if train.shape[0] == 0:
        raise ValueError("traindata is empty")
    sample = asarray(testdata, dtype=float).ravel()

    # Vectorised Euclidean distance to every training row (replaces the
    # per-element Python loops over each row).
    distance = sqrt(((train - sample) ** 2).sum(axis=1))
    # Stable sort: equal distances keep training order. Slicing caps k at
    # the number of rows instead of raising IndexError.
    nearest = distance.argsort(kind='mergesort')[:k]

    votes = [labels[int(idx)] for idx in nearest]
    # max() returns the first element with the highest count, i.e. the tied
    # label whose voter is closest.
    return max(votes, key=votes.count)
def data2array(fname, rows=32, cols=32):
    """Read a 0/1 text-matrix file into a flat list of ints.

    Reads ``rows`` lines from ``fname`` and the first ``cols`` characters of
    each, row by row. The default 32 x 32 layout yields 1024 ints.

    Raises ValueError or IndexError if a line has fewer than ``cols`` digit
    characters, or if the file has fewer than ``rows`` lines.
    """
    arr = []
    # 'with' guarantees the handle is closed (the original never closed it).
    with open(fname) as fh:
        for _ in range(rows):
            thisline = fh.readline()
            arr.extend(int(thisline[j]) for j in range(cols))
    return arr
def seplabel(fname):
    """Extract the integer label encoded at the start of a sample file name.

    Names follow ``<label>_<index>.<ext>``, e.g. ``"7_3.txt"`` -> ``7``.
    Everything before the first ``.`` is kept, and the part of that before
    the first ``_`` is parsed as an int.
    """
    stem = fname.partition(".")[0]
    return int(stem.partition("_")[0])
def traindata(train_data_path):
    """Load every training text matrix in a directory.

    Each regular file's name must start with its digit label followed by
    ``_`` (e.g. ``7_3.txt``, parsed by seplabel). Its contents are read with
    data2array into a 1024-value row.

    Returns ``(trainarr, labels)``. ``trainarr`` has one row per file and is
    built with ``zeros`` (imported from numpy.matlib in this file).
    ``labels[i]`` is the label of row ``i``. Files are processed in sorted
    name order so runs are reproducible. Subdirectories are skipped, and the
    path may be given with or without a trailing slash.
    """
    trainfile = sorted(
        name for name in listdir(train_data_path)
        if os.path.isfile(os.path.join(train_data_path, name))
    )
    labels = []
    trainarr = zeros((len(trainfile), 1024))
    for i, thisname in enumerate(trainfile):
        labels.append(seplabel(thisname))
        trainarr[i, :] = data2array(os.path.join(train_data_path, thisname))
    return trainarr, labels

# Classify every test matrix with 5-NN against the training set and report
# the pass rate.
test_data_path = "c:/Joe_Workspace/testdata/"
train_data_path = "c:/Joe_Workspace/traindata/"
trainarr, labels = traindata(train_data_path)
testfiles = listdir(test_data_path)
pass_count = 0
for thistestfile in testfiles:
    testarr = data2array(test_data_path + thistestfile)
    rst = knn(5, testarr, trainarr, labels)
    expected = seplabel(thistestfile)
    print(thistestfile + ":" + str(expected) + "?=" + str(rst))
    if str(expected) == str(rst):
        pass_count += 1
if testfiles:
    # 100.0 forces float division; Python 2 integer division truncated the
    # rate.
    print("[Pass Rate:] %.2f%%" % (pass_count * 100.0 / len(testfiles)))
else:
    # Avoid ZeroDivisionError when the test directory is empty.
    print("[Pass Rate:] no test files found in %s" % test_data_path)

猜你喜欢

转载自blog.csdn.net/coldence/article/details/78874037