Annuaire d'articles
Régression logistique réalisée avec TensorFlow
TensorFlow est un framework d'apprentissage automatique open source développé par l'équipe Google Brain en 2015. Il est largement utilisé dans la reconnaissance d’images et de la parole, le traitement du langage naturel, les systèmes de recommandation et d’autres domaines.
À la base, TensorFlow est un graphe de flux de données pour le calcul. Dans un graphe de flux de données, les nœuds représentent des opérations mathématiques et les arêtes représentent des tenseurs (tableaux multidimensionnels). Les graphiques Dataflow qui combinent opérations et données permettent à TensorFlow d'optimiser des modèles mathématiques complexes tout en prenant en charge l'informatique distribuée.
TensorFlow fournit des interfaces de plusieurs langages de programmation tels que Python, C++, Java, Go, etc., permettant aux développeurs d'utiliser TensorFlow pour créer et former plus facilement des modèles d'apprentissage en profondeur. En outre, TensorFlow dispose également d'une multitude d'outils et de bibliothèques, notamment l'outil de visualisation TensorBoard, le service de modèle TensorFlow Serving pour l'environnement de production, l'API d'empaquetage de haut niveau Keras, etc.
TensorFlow a développé de nombreux excellents modèles, tels que les réseaux de neurones convolutifs, les réseaux de neurones récurrents et les réseaux de confrontation générative. Ces modèles ont obtenu d’excellents résultats dans de nombreux domaines, comme la reconnaissance d’images, la reconnaissance vocale, le traitement du langage naturel, etc.
En plus de TensorFlow open source, Google a également lancé Google Cloud ML, une plate-forme d'apprentissage automatique dans le cloud basée sur TensorFlow, qui offre aux utilisateurs des services plus pratiques pour la formation et le déploiement de modèles d'apprentissage automatique.
Le modèle de base le plus courant pour résoudre les problèmes de classification est la régression logistique. Il est simple et interprétable, ce qui le rend très populaire. Utilisons Tensorflow pour terminer la construction de ce modèle.
1. Paramètres d'environnement
import os
os.environ['TF_CPP_MIN_LOG_LEVEL']='2'
import warnings
warnings.filterwarnings("ignore")
import numpy as np
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import time
2. Lecture des données
#使用tensorflow自带的工具加载MNIST手写数字集合
mnist = input_data.read_data_sets('./data/mnist', one_hot=True)
Extracting ./data/mnist/train-images-idx3-ubyte.gz
Extracting ./data/mnist/train-labels-idx1-ubyte.gz
Extracting ./data/mnist/t10k-images-idx3-ubyte.gz
Extracting ./data/mnist/t10k-labels-idx1-ubyte.gz
#查看一下数据维度
mnist.train.images.shape
(55000, 784)
#查看target维度
mnist.train.labels.shape
(55000, 10)
3. Préparez l'espace réservé
batch_size = 128
X = tf.placeholder(tf.float32, [batch_size, 784], name='X_placeholder')
Y = tf.placeholder(tf.int32, [batch_size, 10], name='Y_placeholder')
4. Préparer les paramètres/poids
w = tf.Variable(tf.random_normal(shape=[784, 10], stddev=0.01), name='weights')
b = tf.Variable(tf.zeros([1, 10]), name="bias")
logits = tf.matmul(X, w) + b
5. Calculer la fonction de perte du softmax multicatégorie
# 求交叉熵损失
entropy = tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=Y, name='loss')
# 求平均
loss = tf.reduce_mean(entropy)
6. Préparez l'optimiseur
L'optimisation ici utilise la descente de gradient stochastique, et nous pouvons choisir un optimiseur tel que AdamOptimizer
learning_rate = 0.01
optimizer = tf.train.AdamOptimizer(learning_rate).minimize(loss)
7. Exécuter les opérations définies dans le graphique dans la session
#迭代总轮次
n_epochs = 30
with tf.Session() as sess:
# 在Tensorboard里可以看到图的结构
writer = tf.summary.FileWriter('../graphs/logistic_reg', sess.graph)
start_time = time.time()
sess.run(tf.global_variables_initializer())
n_batches = int(mnist.train.num_examples/batch_size)
for i in range(n_epochs): # 迭代这么多轮
total_loss = 0
for _ in range(n_batches):
X_batch, Y_batch = mnist.train.next_batch(batch_size)
_, loss_batch = sess.run([optimizer, loss], feed_dict={
X: X_batch, Y:Y_batch})
total_loss += loss_batch
print('Average loss epoch {0}: {1}'.format(i, total_loss/n_batches))
print('Total time: {0} seconds'.format(time.time() - start_time))
print('Optimization Finished!')
# 测试模型
preds = tf.nn.softmax(logits)
correct_preds = tf.equal(tf.argmax(preds, 1), tf.argmax(Y, 1))
accuracy = tf.reduce_sum(tf.cast(correct_preds, tf.float32))
n_batches = int(mnist.test.num_examples/batch_size)
total_correct_preds = 0
for i in range(n_batches):
X_batch, Y_batch = mnist.test.next_batch(batch_size)
accuracy_batch = sess.run([accuracy], feed_dict={
X: X_batch, Y:Y_batch})
total_correct_preds += accuracy_batch[0]
print('Accuracy {0}'.format(total_correct_preds/mnist.test.num_examples))
writer.close()
Average loss epoch 0: 0.36748782022571785
Average loss epoch 1: 0.2978815356126198
Average loss epoch 2: 0.27840628396797845
Average loss epoch 3: 0.2783186247437706
Average loss epoch 4: 0.2783641471138923
Average loss epoch 5: 0.2750668214473413
Average loss epoch 6: 0.2687560408126502
Average loss epoch 7: 0.2713795114126239
Average loss epoch 8: 0.2657588795522154
Average loss epoch 9: 0.26322007090686916
Average loss epoch 10: 0.26289192279735646
Average loss epoch 11: 0.26248606019989873
Average loss epoch 12: 0.2604622903056356
Average loss epoch 13: 0.26015280702939403
Average loss epoch 14: 0.2581879366319496
Average loss epoch 15: 0.2590309207117085
Average loss epoch 16: 0.2630510463581219
Average loss epoch 17: 0.25501730025578767
Average loss epoch 18: 0.2547102673000945
Average loss epoch 19: 0.258298404375851
Average loss epoch 20: 0.2549241428330784
Average loss epoch 21: 0.2546788509283866
Average loss epoch 22: 0.259556887067837
Average loss epoch 23: 0.25428259843365575
Average loss epoch 24: 0.25442713139565676
Average loss epoch 25: 0.2553852511383159
Average loss epoch 26: 0.2503043229415978
Average loss epoch 27: 0.25468004046828596
Average loss epoch 28: 0.2552785321479633
Average loss epoch 29: 0.2506257003663859
Total time: 28.603315353393555 seconds
Optimization Finished!
Accuracy 0.9187
Pièce jointe : série d'articles
numéro de série | Annuaire d'articles | lien direct |
---|---|---|
1 | Prévisions de prix immobilier à Boston | https://want595.blog.csdn.net/article/details/132181950 |
2 | Analyse de l'ensemble de données sur l'iris | https://want595.blog.csdn.net/article/details/132182057 |
3 | traitement des fonctionnalités | https://want595.blog.csdn.net/article/details/132182165 |
4 | Validation croisée | https://want595.blog.csdn.net/article/details/132182238 |
5 | Construction d'un exemple de réseau neuronal | https://want595.blog.csdn.net/article/details/132182341 |
6 | Utiliser TensorFlow pour effectuer une régression linéaire | https://want595.blog.csdn.net/article/details/132182417 |
7 | Régression logistique réalisée avec TensorFlow | https://want595.blog.csdn.net/article/details/132182496 |
8 | Cas TensorBoard | https://want595.blog.csdn.net/article/details/132182584 |
9 | Régression linéaire réalisée avec Keras | https://want595.blog.csdn.net/article/details/132182723 |
dix | Régression logistique réalisée avec Keras | https://want595.blog.csdn.net/article/details/132182795 |
11 | Utilisez le modèle de pré-formation Keras pour compléter la reconnaissance des chats et des chiens | https://want595.blog.csdn.net/article/details/132243928 |
12 | Entraîner le modèle avec PyTorch | https://want595.blog.csdn.net/article/details/132243989 |
13 | Utilisez Dropout pour supprimer le surapprentissage | https://want595.blog.csdn.net/article/details/132244111 |
14 | Utiliser CNN pour compléter la reconnaissance de l'écriture manuscrite MNIST (TensorFlow) | https://want595.blog.csdn.net/article/details/132244499 |
15 | Utiliser CNN pour compléter la reconnaissance de l'écriture manuscrite MNIST (Keras) | https://want595.blog.csdn.net/article/details/132244552 |
16 | Utiliser CNN pour compléter la reconnaissance de l'écriture manuscrite MNIST (PyTorch) | https://want595.blog.csdn.net/article/details/132244641 |
17 | Générer des échantillons de chiffres manuscrits à l'aide de GAN | https://want595.blog.csdn.net/article/details/132244764 |
18 | traitement du langage naturel | https://want595.blog.csdn.net/article/details/132276591 |