def forward(self, x):
    """Fuse five parallel branches with an uncertainty-weighted residual.

    The branch outputs are averaged (``out_m``), their spread is turned into an
    ``uncertainty`` map, and a kernel predicted from ``out_m * uncertainty`` is
    applied to ``out_m`` with ``F.conv1d`` as a residual.

    Args:
        x: tensor unpacked as ``(B, T, N)``; permuted to ``(B, N, T)`` before
            the ``inception*`` branches and ``F.conv1d``.
            NOTE(review): assumes every branch returns a tensor shaped like
            its input, ``(B, N, T)`` — confirm.

    Returns:
        ``(out, loss)``: ``out`` is ``out_m`` plus the convolution residual,
        permuted back with ``permute(0, 2, 1)``; ``loss`` is a scalar
        uncertainty-weighted consistency loss across the five branches.
    """
    _, T, N = x.size()
    # (B, T, N) -> (B, N, T). permute returns a view with swapped strides,
    # so x is not contiguous from here on.
    x = x.permute(0, 2, 1)

    out1 = self.inception1(x)
    out2 = self.inception2(x)
    out3 = self.inception3(x)
    out4 = self.inception4(x)
    out5 = self.inception5(x)

    # UEB: statistics across the five branches, stacked once along dim 0.
    stacked = torch.stack([out1, out2, out3, out5, out4])
    out_m = torch.mean(stacked, dim=0)
    # dim=0 is what F.softmax implicitly picked for this 3-D variance; spelling
    # it out keeps the behaviour and removes the deprecation warning.
    # NOTE(review): dim 0 is the batch axis, so samples are normalised against
    # each other — confirm that is intended.
    out_v = 1 - F.softmax(torch.var(stacked, dim=0, unbiased=False), dim=0)
    uncertainty = self.conv_u(out_v)
    # Mean absolute deviation of the branches from their mean; broadcasting
    # replaces the explicit unsqueeze(0).repeat(5, 1, 1, 1).
    loss_ae = torch.mean(torch.abs(stacked - out_m))
    loss = torch.abs(torch.sum(0.5 * torch.exp(-uncertainty) * loss_ae ** 2
                               + 0.5 * uncertainty))

    # reshape rather than view: flattening must not depend on the product
    # happening to be contiguous (reshape copies only when it has to).
    kernel = self.linear_k((out_m * uncertainty).reshape(-1, T * N))
    # Row 0 of the linear output becomes a conv1d weight of shape
    # (out_channels, N, kernel_lenth), tiled N times along out_channels.
    # NOTE(review): only the first sample's kernel is used for the whole batch.
    kernel = kernel[0, :].view(-1, N, self.kernel_lenth).repeat(N, 1, 1)
    # padding (k - 1) // 2 keeps the length T only for odd kernel_lenth.
    res = F.conv1d(out_m, kernel, padding=(self.kernel_lenth - 1) // 2)
    out = out_m + res
    return out.permute(0, 2, 1), loss
The code above runs without raising an error.
def forward(self, x):
    """Fuse five parallel branches, using the input itself as the residual base.

    The branch outputs are averaged (``out_m``) and their spread is turned into
    an ``uncertainty`` map. A kernel predicted from ``x * uncertainty`` is then
    applied to the permuted input ``x`` with ``F.conv1d`` and added back to it.

    Args:
        x: tensor unpacked as ``(B, T, N)``; permuted to ``(B, N, T)`` before
            the ``inception*`` branches and ``F.conv1d``.
            NOTE(review): assumes every branch returns a tensor shaped like
            its input, ``(B, N, T)`` — confirm.

    Returns:
        ``(out, loss)``: ``out`` is ``x`` plus the convolution residual,
        permuted back with ``permute(0, 2, 1)``; ``loss`` is a scalar
        uncertainty-weighted consistency loss across the five branches.
    """
    _, T, N = x.size()
    # (B, T, N) -> (B, N, T). permute returns a view with swapped strides,
    # so x is not contiguous from here on.
    x = x.permute(0, 2, 1)

    out1 = self.inception1(x)
    out2 = self.inception2(x)
    out3 = self.inception3(x)
    out4 = self.inception4(x)
    out5 = self.inception5(x)

    # UEB: statistics across the five branches, stacked once along dim 0.
    stacked = torch.stack([out1, out2, out3, out5, out4])
    out_m = torch.mean(stacked, dim=0)
    # dim=0 is what F.softmax implicitly picked for this 3-D variance; spelling
    # it out keeps the behaviour and removes the deprecation warning.
    # NOTE(review): dim 0 is the batch axis, so samples are normalised against
    # each other — confirm that is intended.
    out_v = 1 - F.softmax(torch.var(stacked, dim=0, unbiased=False), dim=0)
    uncertainty = self.conv_u(out_v)
    # Mean absolute deviation of the branches from their mean (broadcast over
    # the stacking dimension instead of an explicit repeat).
    loss_ae = torch.mean(torch.abs(stacked - out_m))
    loss = torch.abs(torch.sum(0.5 * torch.exp(-uncertainty) * loss_ae ** 2
                               + 0.5 * uncertainty))

    # x is a non-contiguous permuted view, and element-wise ops typically keep
    # the memory layout of their inputs, so x * uncertainty is non-contiguous
    # too and .view(-1, T * N) raises a RuntimeError. reshape gives the same
    # result as .contiguous().view(...) and copies only when needed.
    kernel = self.linear_k((x * uncertainty).reshape(-1, T * N))
    # Row 0 of the linear output becomes a conv1d weight of shape
    # (out_channels, N, kernel_lenth), tiled N times along out_channels.
    # NOTE(review): only the first sample's kernel is used for the whole batch.
    kernel = kernel[0, :].view(-1, N, self.kernel_lenth).repeat(N, 1, 1)
    # padding (k - 1) // 2 keeps the length T only for odd kernel_lenth.
    res = F.conv1d(x, kernel, padding=(self.kernel_lenth - 1) // 2)
    out = x + res
    return out.permute(0, 2, 1), loss
Replacing `out_m` with `x` makes the code raise an error, even though the printed shapes are identical.
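The root cause is that `view()` requires the tensor's elements to be laid out contiguously in memory (more precisely, with strides compatible with the requested shape), and the tensor here is not contiguous. The fix is to call `.contiguous()` before `.view()`, which copies the data into a contiguous layout. Alternatively, `.reshape()` does this automatically and copies only when necessary.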
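Why is `x` not contiguous? `x.permute(0, 2, 1)` returns a view with swapped strides rather than a copy, so `x` is no longer contiguous in memory. Element-wise operations such as `x * uncertainty` generally preserve the memory layout of their inputs, so the product is non-contiguous as well, and `.view(-1, T*N)` cannot merge its last two dimensions. In the first version, `out_m` is newly computed by `torch.stack` and `torch.mean`, which is why `view` succeeds there.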
Adding `.contiguous()` does fix the error.