python smote算法实现理解

代码参考:
https://blog.csdn.net/Yaphat/article/details/52463304?locationNum=7

import random
from sklearn.neighbors import NearestNeighbors
import numpy as np
class Smote:
    def __init__(self,samples,N=10,k=5):
        self.n_samples,self.n_attrs=samples.shape
        self.N=N  #确定综合样本的数量
        self.k=k
        self.samples=samples
        self.newindex=0
       # self.synthetic=np.zeros((self.n_samples*N,self.n_attrs))
 def over_sampling(self):
        N=int(self.N/100)
        self.synthetic = np.zeros((self.n_samples * N, self.n_attrs))
        neighbors=NearestNeighbors(n_neighbors=self.k).fit(self.samples)
        print 'neighbors',neighbors
        for i in range(len(self.samples)):
            nnarray=neighbors.kneighbors(self.samples[i].reshape(1,-1),return_distance=False)[0]
            #print nnarray
            self._populate(N,i,nnarray)
        return self.synthetic
         # for each minority class samples,choose N of the k nearest neighbors and generate N synthetic samples.
    def _populate(self,N,i,nnarray):
        for j in range(N):
            nn=random.randint(0,self.k-1)
            dif=self.samples[nnarray[nn]]-self.samples[i] //邻居样本与该样本的距离差
            gap=random.random() //生成0和1之间的随机浮点数float
            self.synthetic[self.newindex]=self.samples[i]+gap*dif //生成新样本,原样本加上距离乘以一个百分比
            self.newindex+=1
a=np.array([[1,2,3],[4,5,6],[2,3,1],[2,1,2],[2,3,4],[2,3,4]])
s=Smote(a,N=100)
from sklearn.neighbors import NearestNeighbors

官方文档:https://scikit-learn.org/stable/modules/generated/sklearn.neighbors.NearestNeighbors.html
class sklearn.neighbors.NearestNeighbors(n_neighbors=5, radius=1.0, algorithm=’auto’, leaf_size=30, metric=’minkowski’, p=2, metric_params=None, n_jobs=None, **kwargs)
搜索邻居

nnarray=neighbors.kneighbors(self.samples[i].reshape(1,-1),return_distance=False)[0]

作用:搜索k个邻居
当return_distance=False时,返回该点邻居所在的位置
当return_distance=True时,返回该点邻居的距离

>>> samples = [[0., 0., 0.], [0., .5, 0.], [1., 1., .5]]
>>> from sklearn.neighbors import NearestNeighbors
>>> neigh = NearestNeighbors(n_neighbors=1)
>>> neigh.fit(samples) 
NearestNeighbors(algorithm='auto', leaf_size=30, ...)
>>> print(neigh.kneighbors([[1., 1., 1.]])) 
(array([[0.5]]), array([[2]]))

[[0.5]]的意思是距离为0.5,[[2]]的意思是样本中的第三个元素(索引从0开始)

reshape(1,-1)

参考:https://blog.csdn.net/weixin_39449570/article/details/78619196
新数组的shape属性应该要与原来数组的一致,即新数组元素数量与原数组元素数量要相等。一个参数为-1时,那么reshape函数会根据另一个参数的维度计算出数组的另外一个shape属性值
即本代码中,根据行确定列的个数

算法缺陷:
(1)该算法主要存在两方面的问题:一是在近邻选择时,存在一定的盲目性。从上面的算法流程可以看出,在算法执行过程中,需要确定K值,即选择多少个近邻样本,这需要用户自行解决。从K值的定义可以看出,K值的下限是M值(M值为从K个近邻中随机挑选出的近邻样本的个数,且有M<
K),M的大小可以根据负类样本数量、正类样本数量和数据集最后需要达到的平衡率决定。但K值的上限没有办法确定,只能根据具体的数据集去反复测试。因此如何确定K值,才能使算法达到最优这是未知的。
(2)另外,该算法无法克服非平衡数据集的数据分布问题,容易产生分布边缘化问题。由于负类样本的分布决定了其可选择的近邻,如果一个负类样本处在负类样本集的分布边缘,则由此负类样本和相邻样本产生的“人造”样本也会处在这个边缘,且会越来越边缘化,从而模糊了正类样本和负类样本的边界,而且使边界变得越来越模糊。这种边界模糊性,虽然使数据集的平衡性得到了改善,但加大了分类算法进行分类的难度。
参考博客:
https://blog.csdn.net/Yaphat/article/details/52463304?locationNum=7

猜你喜欢

转载自blog.csdn.net/weixin_37353303/article/details/84229255
今日推荐