res_board 残差神经网络训练

#coding=utf-8
import tensorflow as tf
import numpy as np
import pandas as pd
import os
import skimage
import skimage.data
import random
import residual_network
from dataset import randomize

def load_data(data_dir):
    directories = [d for d in os.listdir(data_dir)
                   if os.path.isdir(os.path.join(data_dir, d))]
    labels = []
    images = []
    for d in directories:
        label_dir = os.path.join(data_dir, d)
        file_names = [os.path.join(label_dir, f)
                      for f in os.listdir(label_dir) if f.endswith(".jpg")]
        for f in file_names:
            #print(f)

            images.append(skimage.data.imread(f))
            labels.append(int(d))
   # for i in range(0,10):
        #print(images[i])
    return images, labels

train_data_dir =  "result"

images,labels=load_data(train_data_dir)


_images=[skimage.transform.resize(image,(32,32),mode='reflect')
          for image in  images]


images_a = np.array(_images)
labels_a = np.array(labels)

#print(images_a[0])
print("labels: ", labels_a.shape, "\nimages: ", images_a.shape)

train_dataset,train_labels=randomize(images_a,labels_a)
#train_dataset,train_labels=images_a,labels_a

x=tf.placeholder(tf.float32,[None,32,32,3],name='x_input')
y=tf.placeholder(tf.int32,[None],name='y_input')

sess=tf.InteractiveSession()

logits=residual_network.ResNet(x,100)

cross_entropy=tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=logits,labels=y))
train_step=tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)
correct_pre=tf.equal(tf.arg_max(y,1),tf.arg_max(logits,1))
accuracy=tf.reduce_mean(tf.cast(correct_pre,tf.float32))

batch_size=10

sess.run(tf.global_variables_initializer())
print(train_dataset.shape)
print(train_labels.shape)
print("Start~")
for i in range(0,2000):
   # offset=(i*batch_size)%(train_labels.shape[0]-batch_size)
   # xs=train_dataset[offset:(offset+batch_size),:,:,:]
   # ys=train_labels[offset:(offset+batch_size)]
   # feed_dict={x:xs,y:ys}
    feed_dict={x:train_dataset,y:train_labels}
    acc=accuracy.eval(feed_dict=feed_dict)
    if i%10==0:
        print(acc)
    train_step.run(feed_dict=feed_dict)

Caused by op ‘ArgMax’, defined at:
File “res_board.py”, line 57, in
correct_pre=tf.equal(tf.arg_max(y,1),tf.arg_max(logits,1))
File “/home/sys-04/anaconda3/envs/python35/lib/python3.5/site-packages/tensorflow/python/util/deprecation.py”, line 128, in new_func
return func(*args, **kwargs)
File “/home/sys-04/anaconda3/envs/python35/lib/python3.5/site-packages/tensorflow/python/ops/gen_math_ops.py”, line 189, in arg_max
output_type=output_type, name=name)
File “/home/sys-04/anaconda3/envs/python35/lib/python3.5/site-packages/tensorflow/python/framework/op_def_library.py”, line 767, in apply_op
op_def=op_def)
File “/home/sys-04/anaconda3/envs/python35/lib/python3.5/site-packages/tensorflow/python/framework/ops.py”, line 2630, in create_op
original_op=self._default_original_op, op_def=op_def)
File “/home/sys-04/anaconda3/envs/python35/lib/python3.5/site-packages/tensorflow/python/framework/ops.py”, line 1204, in init
self._traceback = self._graph._extract_stack() # pylint: disable=protected-access

InvalidArgumentError (see above for traceback): Expected dimension in the range [-1, 1), but got 1
[[Node: ArgMax = ArgMax[T=DT_INT32, Tidx=DT_INT32,
output_type=DT_INT64, _device=”/job:localhost/replica:0/task:0/cpu:0”](_arg_y_input_0_1, ArgMax/dimension)]]

感谢逻辑没毛病 却总是抱错
jesus

猜你喜欢

转载自blog.csdn.net/hensonwells/article/details/80368270
今日推荐