用Numpy写一个Softmax

softmax

  • 每一行先减去该行的最大值(防止指数运算溢出)
  • 对结果逐元素计算指数(exponential)
  • 按行求和
  • 每一行的每个元素都除以该行的和
import numpy as np  # every snippet in this walkthrough relies on the `np` alias

# Build a 10x10 test matrix whose entries are close to 1000. Values this
# large make a direct np.exp(m) overflow, which motivates the row-max shift
# applied before exponentiating.
m = np.random.rand(10, 10) * 10 + 1000
print(m)
[[ 1008.64304012  1001.25079229  1006.81896868  1005.89015258  1008.8915297
   1001.84923866  1005.53509734  1005.34075305  1008.93404709
   1006.94897664]
 [ 1003.24267825  1003.72710741  1000.28354398  1000.32012105  1004.3690361
   1007.18390602  1002.49741606  1005.83510332  1009.19678396
   1002.32098566]
 [ 1002.32824002  1006.2813999   1009.27645662  1002.57259159
   1006.30743627  1000.35201323  1003.94430099  1008.79056869
   1007.40485841  1006.38239542]
 [ 1007.06228714  1006.01325352  1007.96901864  1002.34269542
   1000.75563221  1005.26357317  1006.14861174  1005.68119044
   1000.69006453  1007.21834125]
 [ 1004.15770428  1003.0554848   1005.55619032  1003.04000025
   1005.54338468  1002.23952638  1008.86317857  1006.96983789
   1005.84232318  1009.28833837]
 [ 1008.47151667  1006.30354927  1006.69274016  1004.12418543
   1007.17550972  1004.31758292  1007.27760499  1007.45250445
   1000.02943239  1002.25886446]
 [ 1000.63764781  1003.39894276  1008.26298759  1001.89295012
   1007.85388369  1004.67565255  1004.58872708  1003.24488815
   1000.39528914  1007.20964465]
 [ 1005.21815308  1007.42651355  1006.32407717  1003.0096329   1005.03545902
   1008.85925437  1009.57634418  1003.74546024  1003.40512867  1004.4437606 ]
 [ 1001.78786625  1008.73282377  1003.98906267  1008.17533941
   1002.79957584  1000.89332666  1007.64343999  1003.88248211
   1005.75517566  1008.27556001]
 [ 1002.05916059  1007.25663392  1009.48655775  1009.56831564
   1008.28488062  1004.92593854  1008.0468565   1007.53278621
   1001.94935121  1007.01473574]]
# Calling np.exp(m) directly overflows because the values are too large, so
# each row's maximum is subtracted from that row before exponentiating.
# keepdims=True retains the reduced axis, so the result is a column of shape
# (n_rows, 1) that broadcasts against m without hard-coding the row count.
m_row_max = m.max(axis=1, keepdims=True)  # maximum of every row
print(m_row_max, m_row_max.shape)
[[ 1008.93404709]
 [ 1009.19678396]
 [ 1009.27645662]
 [ 1007.96901864]
 [ 1009.28833837]
 [ 1008.47151667]
 [ 1008.26298759]
 [ 1009.57634418]
 [ 1008.73282377]
 [ 1009.56831564]] (10, 1)
# Broadcasting: the column of row maxima is stretched across the columns of m,
# so every row has its own maximum subtracted (its largest entry becomes 0).
m = np.subtract(m, m_row_max)
print(m)
[[-0.29100696 -7.6832548  -2.11507841 -3.04389451 -0.04251738 -7.08480843
  -3.39894975 -3.59329403  0.         -1.98507045]
 [-5.95410571 -5.46967655 -8.91323998 -8.87666291 -4.82774786 -2.01287794
  -6.6993679  -3.36168065  0.         -6.8757983 ]
 [-6.9482166  -2.99505673  0.         -6.70386503 -2.96902035 -8.9244434
  -5.33215563 -0.48588793 -1.87159821 -2.8940612 ]
 [-0.9067315  -1.95576512  0.         -5.62632322 -7.21338643 -2.70544547
  -1.82040689 -2.2878282  -7.27895411 -0.75067739]
 [-5.13063409 -6.23285357 -3.73214805 -6.24833812 -3.74495369 -7.04881199
  -0.4251598  -2.31850048 -3.44601518  0.        ]
 [ 0.         -2.16796739 -1.77877651 -4.34733123 -1.29600694 -4.15393374
  -1.19391168 -1.01901222 -8.44208427 -6.21265221]
 [-7.62533977 -4.86404483  0.         -6.37003747 -0.4091039  -3.58733504
  -3.67426051 -5.01809944 -7.86769844 -1.05334294]
 [-4.3581911  -2.14983063 -3.25226701 -6.56671127 -4.54088515 -0.71708981
   0.         -5.83088393 -6.1712155  -5.13258358]
 [-6.94495753  0.         -4.7437611  -0.55748437 -5.93324793 -7.83949711
  -1.08938379 -4.85034166 -2.97764811 -0.45726376]
 [-7.50915505 -2.31168172 -0.08175789  0.         -1.28343501 -4.6423771
  -1.52145914 -2.03552943 -7.61896443 -2.5535799 ]]
# Exponentiate element-wise. After the row-max shift every entry is <= 0, so
# each result is at most 1 and the exponential can no longer overflow.
m_exp = np.exp(m)
print(m_exp, m_exp.shape)
[[  7.47510474e-01   4.60473707e-04   1.20623832e-01   4.76489585e-02
    9.58373807e-01   8.37735258e-04   3.34083387e-02   2.75075701e-02
    1.00000000e+00   1.37370936e-01]
 [  2.59516363e-03   4.21259451e-03   1.34595041e-04   1.39609277e-04
    8.00452828e-03   1.33603617e-01   1.23169021e-03   3.46769303e-02
    1.00000000e+00   1.03247308e-03]
 [  9.60346309e-04   5.00337887e-02   1.00000000e+00   1.22616357e-03
    5.13535942e-02   1.33095533e-04   4.83363924e-03   6.15150742e-01
    1.53877536e-01   5.53509641e-02]
 [  4.03842028e-01   1.41456204e-01   1.00000000e+00   3.60179403e-03
    7.36658284e-04   6.68405420e-02   1.61959837e-01   1.01486632e-01
    6.89906753e-04   4.72046684e-01]
 [  5.91281003e-03   1.96383995e-03   2.39413532e-02   1.93366498e-03
    2.36367236e-02   8.68440058e-04   6.53665320e-01   9.84210597e-02
    3.18723893e-02   1.00000000e+00]
 [  1.00000000e+00   1.14409931e-01   1.68844601e-01   1.29413038e-02
    2.73622203e-01   1.57025251e-02   3.03033573e-01   3.60951306e-01
    2.15600312e-04   2.00391561e-03]
 [  4.87929429e-04   7.71919784e-03   1.00000000e+00   1.71209508e-03
    6.64245216e-01   2.76719769e-02   2.53681582e-02   6.61709096e-03
    3.82914649e-04   3.48769883e-01]
 [  1.28015234e-02   1.16503889e-01   3.86864061e-02   1.40641513e-03
    1.06639631e-02   4.88170860e-01   1.00000000e+00   2.93548106e-03
    2.08869564e-03   5.90129429e-03]
 [  9.63481254e-04   1.00000000e+00   8.70584098e-03   5.72647825e-01
    2.64986143e-03   3.93867063e-04   3.36423738e-01   7.82570334e-03
    5.09124334e-02   6.33013352e-01]
 [  5.48043962e-04   9.90944624e-02   9.21495038e-01   1.00000000e+00
    2.77083877e-01   9.63476759e-03   2.18392989e-01   1.30611314e-01
    4.91050084e-04   7.78026413e-02]] (10, 10)
# Sum the exponentials of every row. keepdims=True retains the reduced axis,
# giving a column of shape (n_rows, 1) that broadcasts in the division step
# without hard-coding the row count in a reshape.
m_exp_row_sum = m_exp.sum(axis=1, keepdims=True)
print(m_exp_row_sum, m_exp_row_sum.shape)
[[ 3.07374213]
 [ 1.1856312 ]
 [ 1.93291987]
 [ 2.35266029]
 [ 1.8422156 ]
 [ 2.25172496]
 [ 2.08297446]
 [ 1.67915853]
 [ 2.6135361 ]
 [ 2.73515418]] (10, 1)
# Normalise: divide every exponential by the sum of its own row (the column of
# row sums is broadcast across the columns), yielding the row-wise softmax.
m_softmax = np.divide(m_exp, m_exp_row_sum)
print(m_softmax)
[[  2.43192319e-01   1.49808829e-04   3.92433154e-02   1.55019376e-02
    3.11793823e-01   2.72545719e-04   1.08689465e-02   8.94921207e-03
    3.25336336e-01   4.46917571e-02]
 [  2.18884559e-03   3.55303952e-03   1.13521845e-04   1.17751015e-04
    6.75128005e-03   1.12685645e-01   1.03884767e-03   2.92476533e-02
    8.43432594e-01   8.70821452e-04]
 [  4.96837103e-04   2.58850817e-02   5.17352020e-01   6.34358200e-04
    2.65678857e-02   6.88572427e-05   2.50069303e-03   3.18249479e-01
    7.96088542e-02   2.86359331e-02]
 [  1.71653353e-01   6.01260646e-02   4.25050742e-01   1.53094522e-03
    3.13117150e-04   2.84106220e-02   6.88411489e-02   4.31369681e-02
    2.93245377e-04   2.00643793e-01]
 [  3.20961891e-03   1.06602069e-03   1.29959562e-02   1.04964098e-03
    1.28305957e-02   4.71410652e-04   3.54825635e-01   5.34253752e-02
    1.73011179e-02   5.42824629e-01]
 [  4.44103973e-01   5.08099049e-02   7.49845582e-02   5.74728445e-03
    1.21516707e-01   6.97355379e-03   1.34578414e-01   1.60299909e-01
    9.57489550e-05   8.89946885e-04]
 [  2.34246477e-04   3.70585333e-03   4.80082698e-01   8.21947224e-04
    3.18892636e-01   1.32848374e-02   1.21788138e-02   3.17675088e-03
    1.83830698e-04   1.67438386e-01]
 [  7.62377296e-03   6.93823048e-02   2.30391624e-02   8.37571382e-04
    6.35077805e-03   2.90723509e-01   5.95536385e-01   1.74818578e-03
    1.24389425e-03   3.51443547e-03]
 [  3.68650448e-04   3.82623373e-01   3.33105824e-03   2.19108442e-01
    1.01389892e-03   1.50702744e-04   1.28723586e-01   2.99429701e-03
    1.94802870e-02   2.42205704e-01]
 [  2.00370409e-04   3.62299365e-02   3.36907895e-01   3.65610102e-01
    1.01304664e-01   3.52256836e-03   7.98466830e-02   4.77528160e-02
    1.79532871e-04   2.84454316e-02]]
# Sanity check: the entries of every row should add up to 1.
print(np.sum(m_softmax, axis=1))
[ 1.  1.  1.  1.  1.  1.  1.  1.  1.  1.]

猜你喜欢

转载自blog.csdn.net/douhh_sisy/article/details/80557138