snntorch : torch を snn に導入するスパイキング ニューラル ネットワーク トレーニング フレームワーク (P1 がデータをスパイキング シーケンスに変換する方法)

この記事の下のコードは主に次の 3 つの関数を実装します。

  1. データセットをパルス列のデータセットに変換します。
  2. それらをどのように視覚化するか
  3. ランダムなパルス列を生成する方法

データセットはディープラーニングでよく使われるMNISTデータセットを採用


スパイク シーケンスを入力として使用する 3 つの利点: 3-Sスパイクニューラル ネットワークの入力は、 0 と 1 で構成される一連のスパイク シーケンスであり、これは人間の脳の軸索に沿って伝達される神経インパルスのデジタル表現でもあります。
スパーシティスパーシティとはスパース性を意味します。これは、前のポイントで説明したパルス シーケンスが通常、スパース行列の形式であることを意味します。つまり、これらの行列の要素のほとんどが 0 で、少数の意味のある部分だけが 1 になります。 。これら 2 つの特性により、スパイキング ニューラル ネットワークのハードウェア回路のエネルギー消費が非常に低くなります。将来のインテリジェント無人システムの脳チップでは、より多くのスパイキング ニューラル ネットワークが使用されると大胆に予測されています~ 静的抑制 静的抑制はイベント駆動型とも呼ばれ
ます, 非常に抽象的に聞こえます。これも人間の脳からインスピレーションを得たものです。私たちは一般に、物事の動きや変化に対してより敏感です。静的抑制とは、入力をパルス シーケンスに変換するプロセス中に、一部の不変部分が抑制されることを意味し、一部の部分が強調表示されます。新しい「イベント」がパルス列の生成を駆動します。

パルス ニューラル ネットワークは、第 3 世代のニューラル ネットワークとして、脳のような知能と低消費電力のハードウェア実装という点で、ここ数年で登場したばかりであり、現在、輝かしい輝きを放っている第 2 世代のニューラル ネットワークと比較して、パルス ニューラル ネットワークは、ニューラルネットワーク 応用面、理論研究面、普及面では若干劣りますが、まだまだ模索は必要です〜

以下はコード部分です: (pip install snntorch最初に snntorch パッケージをインストールしてください)

import snntorch as snn
import torch

# Training Parameters
batch_size=128
data_path='/data/mnist'
num_classes = 10  # MNIST has 10 output classes

# Torch Variables
dtype = torch.float # tensor中的数据全部存储为torch.float型

from torchvision import datasets, transforms

# Define a transform
# 图像变换,通过定义好一个transform对象实现
transform = transforms.Compose([
            transforms.Resize((28,28)),
            transforms.Grayscale(),
            transforms.ToTensor(),
            transforms.Normalize((0,), (1,))]) # 转为tensor,并归一化至[0-1]

mnist_train = datasets.MNIST(data_path, train=True, download=True, transform=transform)

from snntorch import utils
# 在我们实际开始训练网络之前,不需要太大的数据集,通过utils中的data_subset函数将MNIST数据集数目从60000减少到6000
subset = 10
mnist_train = utils.data_subset(mnist_train, subset)
print(f"The size of mnist_train is {
      
      len(mnist_train)}")



from torch.utils.data import DataLoader

train_loader = DataLoader(mnist_train, batch_size=batch_size, shuffle=True)





# Temporal Dynamics
num_steps = 10

# create vector filled with 0.5
raw_vector = torch.ones(num_steps)*0.5

# pass each sample through a Bernoulli trial
rate_coded_vector = torch.bernoulli(raw_vector)
print(f"Converted vector: {
      
      rate_coded_vector}")

print(f"The output is spiking {
      
      rate_coded_vector.sum()*100/len(rate_coded_vector):.2f}% of the time.")



num_steps = 100

# create vector filled with 0.5
raw_vector = torch.ones(num_steps)*0.5

# pass each sample through a Bernoulli trial
rate_coded_vector = torch.bernoulli(raw_vector)
print(f"The output is spiking {
      
      rate_coded_vector.sum()*100/len(rate_coded_vector):.2f}% of the time.")



from snntorch import spikegen

# Iterate through minibatches
data = iter(train_loader)
data_it, targets_it = next(data)

# Spiking Data
spike_data = spikegen.rate(data_it, num_steps=num_steps)


print(spike_data.size())



import matplotlib.pyplot as plt
import snntorch.spikeplot as splt
from IPython.display import HTML        ###


# To plot one sample of data, index into a single sample from the batch (B) dimension of `spike_data`, ``[T x B x 1 x 28 x 28]``:

# In[ ]:


spike_data_sample = spike_data[:, 0, 0]
print(spike_data_sample.size())


# `spikeplot.animator` makes it super simple to animate 2-D data.<br>
# Note: if you are running the notebook locally on your desktop, please uncomment the line below and modify the path to your ffmpeg.exe

# In[ ]:


fig, ax = plt.subplots()
anim = splt.animator(spike_data_sample, fig, ax)
# plt.rcParams['animation.ffmpeg_path'] = 'C:\\path\\to\\your\\ffmpeg.exe'

HTML(anim.to_html5_video())     ###



anim.save("spike_mnist_test.mp4")



print(f"The corresponding target is: {
      
      targets_it[0]}")



spike_data = spikegen.rate(data_it, num_steps=num_steps, gain=0.25)

spike_data_sample2 = spike_data[:, 0, 0]
fig, ax = plt.subplots()
anim = splt.animator(spike_data_sample2, fig, ax)
HTML(anim.to_html5_video())



plt.figure(facecolor="w")
plt.subplot(1,2,1)
plt.imshow(spike_data_sample.mean(axis=0).reshape((28,-1)).cpu(), cmap='binary')
plt.axis('off')
plt.title('Gain = 1')

plt.subplot(1,2,2)
plt.imshow(spike_data_sample2.mean(axis=0).reshape((28,-1)).cpu(), cmap='binary')
plt.axis('off')
plt.title('Gain = 0.25')

plt.show()


# Reshape
spike_data_sample2 = spike_data_sample2.reshape((num_steps, -1))

# raster plot
fig = plt.figure(facecolor="w", figsize=(10, 5))
ax = fig.add_subplot(111)
splt.raster(spike_data_sample2, ax, s=1.5, c="black")

plt.title("Input Layer")
plt.xlabel("Time step")
plt.ylabel("Neuron Number")
plt.show()



idx = 210  # index into 210th neuron

fig = plt.figure(facecolor="w", figsize=(8, 1))
ax = fig.add_subplot(111)

splt.raster(spike_data_sample.reshape(num_steps, -1)[:, idx].unsqueeze(1), ax, s=100, c="black", marker="|")

plt.title("Input Neuron")
plt.xlabel("Time step")
plt.yticks([])
plt.show()


def convert_to_time(data, tau=5, threshold=0.01):
  spike_time = tau * torch.log(data / (data - threshold))
  return spike_time 


raw_input = torch.arange(0, 5, 0.05) # tensor from 0 to 5
spike_times = convert_to_time(raw_input)

plt.plot(raw_input, spike_times)
plt.xlabel('Input Value')
plt.ylabel('Spike Time (s)')
plt.show()



spike_data = spikegen.latency(data_it, num_steps=100, tau=5, threshold=0.01)



fig = plt.figure(facecolor="w", figsize=(10, 5))
ax = fig.add_subplot(111)
splt.raster(spike_data[:, 0].view(num_steps, -1), ax, s=25, c="black")

plt.title("Input Layer")
plt.xlabel("Time step")
plt.ylabel("Neuron Number")
plt.show()


spike_data = spikegen.latency(data_it, num_steps=100, tau=5, threshold=0.01, linear=True)

fig = plt.figure(facecolor="w", figsize=(10, 5))
ax = fig.add_subplot(111)
splt.raster(spike_data[:, 0].view(num_steps, -1), ax, s=25, c="black")
plt.title("Input Layer")
plt.xlabel("Time step")
plt.ylabel("Neuron Number")
plt.show()



spike_data = spikegen.latency(data_it, num_steps=100, tau=5, threshold=0.01,
                              normalize=True, linear=True)

fig = plt.figure(facecolor="w", figsize=(10, 5))
ax = fig.add_subplot(111)
splt.raster(spike_data[:, 0].view(num_steps, -1), ax, s=25, c="black")

plt.title("Input Layer")
plt.xlabel("Time step")
plt.ylabel("Neuron Number")
plt.show()

spike_data = spikegen.latency(data_it, num_steps=100, tau=5, threshold=0.01, 
                              clip=True, normalize=True, linear=True)

fig = plt.figure(facecolor="w", figsize=(10, 5))
ax = fig.add_subplot(111)
splt.raster(spike_data[:, 0].view(num_steps, -1), ax, s=25, c="black")

plt.title("Input Layer")
plt.xlabel("Time step")
plt.ylabel("Neuron Number")
plt.show()



spike_data_sample = spike_data[:, 0, 0]
print(spike_data_sample.size())


# In[ ]:


fig, ax = plt.subplots()
anim = splt.animator(spike_data_sample, fig, ax)

HTML(anim.to_html5_video())


print(targets_it[0])



# Create a tensor with some fake time-series data
data = torch.Tensor([0, 1, 0, 2, 8, -20, 20, -5, 0, 1, 0])

# Plot the tensor
plt.plot(data)

plt.title("Some fake time-series data")
plt.xlabel("Time step")
plt.ylabel("Voltage (mV)")
plt.show()



# Convert data
spike_data = spikegen.delta(data, threshold=4)

# Create fig, ax
fig = plt.figure(facecolor="w", figsize=(8, 1))
ax = fig.add_subplot(111)

# Raster plot of delta converted data
splt.raster(spike_data, ax, c="black")

plt.title("Input Neuron")
plt.xlabel("Time step")
plt.yticks([])
plt.xlim(0, len(data))
plt.show()


# Convert data
spike_data = spikegen.delta(data, threshold=4, off_spike=True)

# Create fig, ax
fig = plt.figure(facecolor="w", figsize=(8, 1))
ax = fig.add_subplot(111)

# Raster plot of delta converted data
splt.raster(spike_data, ax, c="black")

plt.title("Input Neuron")
plt.xlabel("Time step")
plt.yticks([])
plt.xlim(0, len(data))
plt.show()



print(spike_data)


spike_prob = torch.rand((num_steps, 28, 28), dtype=dtype) * 0.5
spike_rand = spikegen.rate_conv(spike_prob)


fig, ax = plt.subplots()
anim = splt.animator(spike_rand, fig, ax)

HTML(anim.to_html5_video())

fig = plt.figure(facecolor="w", figsize=(10, 5))
ax = fig.add_subplot(111)
splt.raster(spike_rand[:, 0].view(num_steps, -1), ax, s=25, c="black")

plt.title("Input Layer")
plt.xlabel("Time step")
plt.ylabel("Neuron Number")
plt.show()

おすすめ

転載: blog.csdn.net/cyy0789/article/details/121351527