Error al descargar el conjunto de datos mnist en tensorflow2.0, importar el conjunto de datos mnist desde local

1. Al principio, se utilizó el sitio web oficial input_datapara cargar el conjunto de datos locales, pero se informaría del siguiente error

No module named 'tensorflow.examples.tutorials'

Y el sitio web oficial input_data.py
no se puede descargar 2. Usando keras, también fue debido a la imposibilidad de acceder a googlesource al principio que el conjunto de datos mnist no se pudo cargar.
Solución: modifique mnist.py, utilice el conjunto de datos mnist descargado localmente y cambie directamente la ruta en mnist.py por la ruta del conjunto de datos mnist local.
Código adjunto:
main.py

from __future__ import absolute_import, division, print_function, unicode_literals
import tensorflow as tf

mnist = tf.keras.datasets.mnist

(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

model = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(input_shape=(28, 28)),
  tf.keras.layers.Dense(128, activation='relu'),
  tf.keras.layers.Dropout(0.2),
  tf.keras.layers.Dense(10, activation='softmax')
])

model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])


model.fit(x_train, y_train, epochs=5)

model.evaluate(x_test,  y_test, verbose=2)

mnist.py


"""MNIST handwritten digits dataset.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.python.keras.utils.data_utils import get_file
from tensorflow.python.util.tf_export import keras_export


@keras_export('keras.datasets.mnist.load_data')
def load_data(path='mnist.npz'):
  """Loads the MNIST dataset.

  Arguments:
      path: path where to cache the dataset locally
          (relative to ~/.keras/datasets).

  Returns:
      Tuple of Numpy arrays: `(x_train, y_train), (x_test, y_test)`.

  License:
      Yann LeCun and Corinna Cortes hold the copyright of MNIST dataset,
      which is a derivative work from original NIST datasets.
      MNIST dataset is made available under the terms of the
      [Creative Commons Attribution-Share Alike 3.0 license.](
      https://creativecommons.org/licenses/by-sa/3.0/)
  """

  path = "./mnist.npz"
  with np.load(path) as f:
    x_train, y_train = f['x_train'], f['y_train']
    x_test, y_test = f['x_test'], f['y_test']

    return (x_train, y_train), (x_test, y_test)

Supongo que te gusta

Origin blog.csdn.net/qq_45465526/article/details/103125997
Recomendado
Clasificación