TensorFlow Learning Notes

Getting Started


I have to use it for work, so I have to learn it. After just a few days I'm still pretty lost, so I'll write down some of the pitfalls I ran into, in the hope that it helps other newcomers.

Introduction


A quick introduction anyway, just rambling a bit. TensorFlow is a Google product: a framework for machine learning and deep learning. Google really is impressive; they seem to be able to build anything. AlphaGo was apparently developed with it? It has plenty of applications too: speech recognition, natural language understanding, computer vision, and so on. For many companies and individuals it cuts costs; having a framework like this is far more convenient than writing everything by hand. It ships with many common machine-learning algorithms and can use the GPU to accelerate computation. Seriously impressive.

Main Content


All that rambling above is, honestly, because the actual content is short, haha.

First, some learning resources; I certainly can't cover all of this myself, haha.
The official English site (requires getting around the firewall from China):
http://tensorflow.org/
These two are Chinese sites, essentially translations, though parts are untranslated and parts of the translation are out of date:
http://www.tensorfly.cn/
http://wiki.jikexueyuan.com/project/tensorflow-zh/
There is also a fairly detailed series by a blogger, though some of it no longer applies either.

If you follow the Chinese tutorials above step by step, you will inevitably hit a pit. The install command the tutorial gives is:
pip install https://storage.googleapis.com/tensorflow/mac/tensorflow-0.5.0-py2-none-any.whl
However, that version is long outdated; with it installed, many of the later commands simply won't work. Use this URL instead:
sudo pip install https://storage.googleapis.com/tensorflow/mac/tensorflow-r1.0-py2-none-any.whl

At the time I'm writing this post, r1.0 is the latest. Still, you should check whether there is a newer release on GitHub: https://github.com/tensorflow/tensorflow/tree/r1.0
If there is, just substitute the new tag for r1.0. I guessed the install URL from the naming pattern, so if they switch to a different scheme later, there's nothing I can do.
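The URL guessing above can be captured as a tiny helper. This is only a sketch based on the naming pattern of the two URLs in this post; the platform and Python tags (mac, py2) are assumptions, and Google may change the scheme at any time:

```python
def wheel_url(release, platform="mac", py_tag="py2"):
    """Build a TensorFlow wheel URL by substituting the release tag.

    Assumes Google keeps the naming scheme seen in the r1.0 URL above;
    there is no guarantee this holds for later releases.
    """
    return ("https://storage.googleapis.com/tensorflow/{platform}/"
            "tensorflow-{release}-{py_tag}-none-any.whl").format(
                platform=platform, release=release, py_tag=py_tag)

print(wheel_url("r1.0"))
```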

My advice: when you run into a problem, dig around on GitHub. That's the upside of open source...

Another major pitfall is the MNIST dataset. You can't train without it, so the tutorial has you download it, but again some of the tutorial code is outdated... The most reliable approach is to go straight to the source code:

from tensorflow.examples.tutorials.mnist import input_data
# FLAGS.data_dir is a directory passed in as a parameter;
# set it to wherever you want the data saved
mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True)
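What does one_hot=True mean here? It encodes each integer label as a vector that is all zeros except for a 1 at the label's index. In plain Python, independent of TensorFlow, the transformation looks roughly like this:

```python
def one_hot(label, num_classes=10):
    # Encode an integer label as a one-hot vector of length num_classes.
    vec = [0] * num_classes
    vec[label] = 1
    return vec

print(one_hot(3))  # [0, 0, 0, 1, 0, 0, 0, 0, 0, 0]
```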

Then, if you can't get around the firewall, you'll find the dataset simply won't download. No problem, there's still a way.
Look at the definition of read_data_sets in the source:

def read_data_sets(train_dir,
                   fake_data=False,
                   one_hot=False,
                   dtype=dtypes.float32,
                   reshape=True,
                   validation_size=5000):
  if fake_data:

    def fake():
      return DataSet([], [], fake_data=True, one_hot=one_hot, dtype=dtype)

    train = fake()
    validation = fake()
    test = fake()
    return base.Datasets(train=train, validation=validation, test=test)

  TRAIN_IMAGES = 'train-images-idx3-ubyte.gz'
  TRAIN_LABELS = 'train-labels-idx1-ubyte.gz'
  TEST_IMAGES = 't10k-images-idx3-ubyte.gz'
  TEST_LABELS = 't10k-labels-idx1-ubyte.gz'

  local_file = base.maybe_download(TRAIN_IMAGES, train_dir,
                                   SOURCE_URL + TRAIN_IMAGES)
  with open(local_file, 'rb') as f:
    train_images = extract_images(f)

  local_file = base.maybe_download(TRAIN_LABELS, train_dir,
                                   SOURCE_URL + TRAIN_LABELS)
  with open(local_file, 'rb') as f:
    train_labels = extract_labels(f, one_hot=one_hot)

  local_file = base.maybe_download(TEST_IMAGES, train_dir,
                                   SOURCE_URL + TEST_IMAGES)
  with open(local_file, 'rb') as f:
    test_images = extract_images(f)

  local_file = base.maybe_download(TEST_LABELS, train_dir,
                                   SOURCE_URL + TEST_LABELS)
  with open(local_file, 'rb') as f:
    test_labels = extract_labels(f, one_hot=one_hot)

  if not 0 <= validation_size <= len(train_images):
    raise ValueError(
        'Validation size should be between 0 and {}. Received: {}.'
        .format(len(train_images), validation_size))

  validation_images = train_images[:validation_size]
  validation_labels = train_labels[:validation_size]
  train_images = train_images[validation_size:]
  train_labels = train_labels[validation_size:]

  train = DataSet(train_images, train_labels, dtype=dtype, reshape=reshape)
  validation = DataSet(validation_images,
                       validation_labels,
                       dtype=dtype,
                       reshape=reshape)
  test = DataSet(test_images, test_labels, dtype=dtype, reshape=reshape)

  return base.Datasets(train=train, validation=validation, test=test)
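Note how the last part of read_data_sets carves the validation set off the front of the training data: with the default validation_size=5000, MNIST's 60,000 training examples become 5,000 validation plus 55,000 training. The slicing logic, reduced to plain Python with stand-in data:

```python
def split_validation(train_images, train_labels, validation_size=5000):
    # Mirror of the slicing in read_data_sets: the first validation_size
    # examples become the validation set, the rest stay as training data.
    validation_images = train_images[:validation_size]
    validation_labels = train_labels[:validation_size]
    train_images = train_images[validation_size:]
    train_labels = train_labels[validation_size:]
    return (train_images, train_labels), (validation_images, validation_labels)

images = list(range(60000))   # stand-ins for MNIST's 60,000 training images
labels = list(range(60000))
(train_x, _), (val_x, _) = split_validation(images, labels)
print(len(train_x), len(val_x))  # 55000 5000
```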

And the definition of maybe_download that it calls:

def maybe_download(filename, work_directory, source_url):
  """Download the data from source url, unless it's already here.

  Args:
      filename: string, name of the file in the directory.
      work_directory: string, path to working directory.
      source_url: url to download from if file doesn't exist.

  Returns:
      Path to resulting file.
  """
  if not gfile.Exists(work_directory):
    gfile.MakeDirs(work_directory)
  filepath = os.path.join(work_directory, filename)
  if not gfile.Exists(filepath):
    temp_file_name, _ = urlretrieve_with_retry(source_url)
    gfile.Copy(temp_file_name, filepath)
    with gfile.GFile(filepath) as f:
      size = f.size()
    print('Successfully downloaded', filename, size, 'bytes.')
  return filepath
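The key point in maybe_download is the caching check: if the file already exists in work_directory, nothing is downloaded at all. Stripped of TensorFlow's gfile wrapper, the same logic in standard-library Python looks roughly like this (a sketch, not the real implementation):

```python
import os
import shutil
from urllib.request import urlretrieve

def maybe_download(filename, work_directory, source_url):
    """Download source_url into work_directory unless the file is already there."""
    if not os.path.exists(work_directory):
        os.makedirs(work_directory)
    filepath = os.path.join(work_directory, filename)
    if not os.path.exists(filepath):
        temp_file, _ = urlretrieve(source_url)  # fetch to a temporary file
        shutil.move(temp_file, filepath)        # move it into place
        print('Successfully downloaded', filename,
              os.path.getsize(filepath), 'bytes.')
    return filepath
```

This is exactly why manually placing the files in the right directory works: the existence check short-circuits the download.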

These two pieces of code make things clear: we need these four files:

train-images-idx3-ubyte.gz
train-labels-idx1-ubyte.gz
t10k-images-idx3-ubyte.gz
t10k-labels-idx1-ubyte.gz

Put them in the directory you chose and you're done; since the files already exist, maybe_download skips the download entirely.
The dataset can be found in this GitHub repo:
https://github.com/csuldw/MachineLearning/tree/master/dataset/MNIST
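To sanity-check your directory before running the sample, you can verify that all four files are in place (the directory path in the example is just a placeholder):

```python
import os

MNIST_FILES = [
    'train-images-idx3-ubyte.gz',
    'train-labels-idx1-ubyte.gz',
    't10k-images-idx3-ubyte.gz',
    't10k-labels-idx1-ubyte.gz',
]

def missing_mnist_files(data_dir):
    # Return the subset of the four MNIST files not yet present in data_dir.
    return [f for f in MNIST_FILES
            if not os.path.exists(os.path.join(data_dir, f))]

print(missing_mnist_files('/tmp/mnist_data'))
```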

At this point, your handwritten-digit recognition sample should be able to run.


Reposted from blog.csdn.net/qian99/article/details/61430824