TensorFlow使用指南

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/ZWX2445205419/article/details/88667774

数据读取

官方教程:
【1】导入数据
【2】数据输入流水线性能

import tensorflow as tf

test_file = 'test.txt'
train_file = 'train.txt'


def get_items(filename):
    filenames = []
    labels = []
    for line in open(filename):
        filename, label = line.strip('\n').split()
        filenames.append(filename)
        labels.append(label)
    return tf.constant(filenames), tf.constant(labels)


def _parse_function(filename, label):
    image_string = tf.read_file(filename)
    image_decoded = tf.image.decode_jpeg(image_string)
    image_resized = tf.image.resize_images(image_decoded, [28, 28])
    return image_resized, label


test_filenames, test_labels = get_items(test_file)
train_filenames, train_labels = get_items(train_file)

train_dataset = tf.data.Dataset.from_tensor_slices((train_filenames, train_labels))
train_dataset = train_dataset.map(_parse_function, num_parallel_calls=4).repeat().shuffle(buffer_size=10000).batch(32).prefetch(buffer_size=1000)
test_dataset = tf.data.Dataset.from_tensor_slices((test_filenames, test_labels))
test_dataset = test_dataset.map(_parse_function, num_parallel_calls=4).repeat().batch(32).prefetch(buffer_size=1000)

handle = tf.placeholder(tf.string, shape=[])
iterator = tf.data.Iterator.from_string_handle(handle, train_dataset.output_types, train_dataset.output_shapes)
next_element = iterator.get_next()

train_iterator = train_dataset.make_one_shot_iterator()
test_iterator = test_dataset.make_initializable_iterator()

sess = tf.Session()
train_handle = sess.run(train_iterator.string_handle())
test_handle = sess.run(test_iterator.string_handle())

while True:
    print('train')
    for _ in range(20):
        inputs, labels = sess.run(next_element, feed_dict={handle: train_handle})
        print(inputs.shape, labels.shape)

    print('test')
    sess.run(test_iterator.initializer)
    for _ in range(5):
        inputs, labels = sess.run(next_element, feed_dict={handle: test_handle})
        print(inputs.shape, labels.shape)

Tensorflow导入数据探究

#! -*- coding: utf-8 -*-
import tensorflow as tf


train_x = tf.range(0, 1000)
train_y = tf.range(0, 1000)

test_x = tf.range(0, 100)
test_y = tf.range(0, 100)

train_dataset = tf.data.Dataset.from_tensor_slices((train_x, train_y))
train_dataset = train_dataset.shuffle(buffer_size=100).batch(batch_size=10).repeat().prefetch(buffer_size=20)
train_iterator = train_dataset.make_one_shot_iterator()
train_next_element = train_iterator.get_next()

test_dataset = tf.data.Dataset.from_tensor_slices((test_x, test_y))
test_dataset = test_dataset.batch(10).prefetch(buffer_size=20)
test_iterator = test_dataset.make_initializable_iterator()
test_next_element = test_iterator.get_next()

with tf.Session() as sess:
    for _ in range(5):
        for _ in range(10):
            train_x, train_y = sess.run(train_next_element)
            print('train: ', train_x, train_y)

        sess.run(test_iterator.initializer)
        for _ in range(5):
            test_x, test_y = sess.run(test_next_element)
            print('test: ', test_x, test_y)

训练集使用make_one_shot_iterator(),且进行了shuffle,batch,repeat,prefetch等操作
测试集使用make_initializable_iterator(),且没有进行shufflerepeat操作
其结果为:

train:  [ 39  64  37  86   1  23  31 100 101 107] [ 39  64  37  86   1  23  31 100 101 107]
train:  [73 32 70 67 81 44 30 50  7 12] [73 32 70 67 81 44 30 50  7 12]
train:  [114 116 119  36 121 103  63  80  28   2] [114 116 119  36 121 103  63  80  28   2]
train:  [ 58 104 102  46  10  95  41 133  62  96] [ 58 104 102  46  10  95  41 133  62  96]
train:  [  5   4 130 132  15  89  43  54  99 126] [  5   4 130 132  15  89  43  54  99 126]
train:  [ 17  51 123  48 113  57  61  59  93 142] [ 17  51 123  48 113  57  61  59  93 142]
train:  [ 52 128   6 148 158  20   9 161  71  53] [ 52 128   6 148 158  20   9 161  71  53]
train:  [150 149 156  55  77 166 172 124 122 153] [150 149 156  55  77 166 172 124 122 153]
train:  [134  79 118 136 115 127   0  16 164 180] [134  79 118 136 115 127   0  16 164 180]
train:  [ 26 137 179  49 129  11  38 195  25 197] [ 26 137 179  49 129  11  38 195  25 197]
test:  [0 1 2 3 4 5 6 7 8 9] [0 1 2 3 4 5 6 7 8 9]
test:  [10 11 12 13 14 15 16 17 18 19] [10 11 12 13 14 15 16 17 18 19]
test:  [20 21 22 23 24 25 26 27 28 29] [20 21 22 23 24 25 26 27 28 29]
test:  [30 31 32 33 34 35 36 37 38 39] [30 31 32 33 34 35 36 37 38 39]
test:  [40 41 42 43 44 45 46 47 48 49] [40 41 42 43 44 45 46 47 48 49]
train:  [171 184  91 189 155  35 181  83  24 203] [171 184  91 189 155  35 181  83  24 203]
train:  [160 174  33  75 204  82 170  85 177 154] [160 174  33  75 204  82 170  85 177 154]
train:  [109 151  22 196 198  18 192  40 186 215] [109 151  22 196 198  18 192  40 186 215]
train:  [147  88  66   3 182 205 223 229 145 237] [147  88  66   3 182 205 223 229 145 237]
train:  [210 218 178  68 162  19 243 222 138 165] [210 218 178  68 162  19 243 222 138 165]
train:  [ 42 167 221  13  92 131  72 239 236 110] [ 42 167 221  13  92 131  72 239 236 110]
train:  [ 14 106 233 163 191  90  98 230 251 235] [ 14 106 233 163 191  90  98 230 251 235]
train:  [225 270 211  45 255 274 259  56 265 252] [225 270 211  45 255 274 259  56 265 252]
train:  [185 246 208  78  74  27 268 245 261 202] [185 246 208  78  74  27 268 245 261 202]
train:  [152 213 258 256 176 292 284 286 248 281] [152 213 258 256 176 292 284 286 248 281]
test:  [0 1 2 3 4 5 6 7 8 9] [0 1 2 3 4 5 6 7 8 9]
test:  [10 11 12 13 14 15 16 17 18 19] [10 11 12 13 14 15 16 17 18 19]
test:  [20 21 22 23 24 25 26 27 28 29] [20 21 22 23 24 25 26 27 28 29]
test:  [30 31 32 33 34 35 36 37 38 39] [30 31 32 33 34 35 36 37 38 39]
test:  [40 41 42 43 44 45 46 47 48 49] [40 41 42 43 44 45 46 47 48 49]
train:  [105 257 146  87 226 157 183 287 244 273] [105 257 146  87 226 157 183 287 244 273]
train:  [253 289 190 283 173 250 234   8  65 217] [253 289 190 283 173 250 234   8  65 217]
train:  [117 305 269 201 282 288 315 309 140  76] [117 305 269 201 282 288 315 309 140  76]
train:  [188 141 111 175 326 300 240 187 321  97] [188 141 111 175 326 300 240 187 321  97]
train:  [329 302 249 220  69  84 307 200 267 320] [329 302 249 220  69  84 307 200 267 320]
train:  [216 209 304 135 314 232 308 334 347 357] [216 209 304 135 314 232 308 334 347 357]
train:  [335 264 351 272 263 313 199 299 341 194] [335 264 351 272 263 313 199 299 341 194]
train:  [168 298 144 193 323 290 296 346  94 277] [168 298 144 193 323 290 296 346  94 277]
train:  [ 60 139 339 366 356 280 332 348 247 291] [ 60 139 339 366 356 280 332 348 247 291]
train:  [112 262 324 125 333 353 227 231 393 297] [112 262 324 125 333 353 227 231 393 297]
test:  [0 1 2 3 4 5 6 7 8 9] [0 1 2 3 4 5 6 7 8 9]
test:  [10 11 12 13 14 15 16 17 18 19] [10 11 12 13 14 15 16 17 18 19]
test:  [20 21 22 23 24 25 26 27 28 29] [20 21 22 23 24 25 26 27 28 29]
test:  [30 31 32 33 34 35 36 37 38 39] [30 31 32 33 34 35 36 37 38 39]
test:  [40 41 42 43 44 45 46 47 48 49] [40 41 42 43 44 45 46 47 48 49]
train:  [391 337 368 344 212 241 359 260 159 325] [391 337 368 344 212 241 359 260 159 325]
train:  [303 238 362 349 364 322 345 405 409 417] [303 238 362 349 364 322 345 405 409 417]
train:  [343 404 310 358 384  29 418 407 412 271] [343 404 310 358 384  29 418 407 412 271]
train:  [413 120 381 336 398 294 389 228 376 328] [413 120 381 336 398 294 389 228 376 328]
train:  [436 206 169 385 420 383 395 372 367 396] [436 206 169 385 420 383 395 372 367 396]
train:  [438 275 449 399 371 439 433  34 327 388] [438 275 449 399 371 439 433  34 327 388]
train:  [440 295 427 432 301 401 431 458 442 278] [440 295 427 432 301 401 431 458 442 278]
train:  [370 468 414 279 403 276 456 378 459 463] [370 468 414 279 403 276 456 378 459 463]
train:  [410 435 350 471 406 479 415 214  21 352] [410 435 350 471 406 479 415 214  21 352]
train:  [424 312 361 489 360 457 316 402 447 428] [424 312 361 489 360 457 316 402 447 428]
test:  [0 1 2 3 4 5 6 7 8 9] [0 1 2 3 4 5 6 7 8 9]
test:  [10 11 12 13 14 15 16 17 18 19] [10 11 12 13 14 15 16 17 18 19]
test:  [20 21 22 23 24 25 26 27 28 29] [20 21 22 23 24 25 26 27 28 29]
test:  [30 31 32 33 34 35 36 37 38 39] [30 31 32 33 34 35 36 37 38 39]
test:  [40 41 42 43 44 45 46 47 48 49] [40 41 42 43 44 45 46 47 48 49]
train:  [397 317 451 379 470 491 318 469 421 494] [397 317 451 379 470 491 318 469 421 494]
train:  [455 475 464 426 444 365 500 474 502 293] [455 475 464 426 444 365 500 474 502 293]
train:  [485 434 330 481 386 480 430 319 266 411] [485 434 330 481 386 480 430 319 266 411]
train:  [505 373 511 108 478 461 482 311 495 504] [505 373 511 108 478 461 482 311 495 504]
train:  [462 539 536 448 219 527 531 499 453 533] [462 539 536 448 219 527 531 499 453 533]
train:  [416 465 497 375 331 518 355 419 450 369] [416 465 497 375 331 518 355 419 450 369]
train:  [390 549 496 306 392 467 374 422 477 508] [390 549 496 306 392 467 374 422 477 508]
train:  [254 460 473 547 572 445 425 570 554 525] [254 460 473 547 572 445 425 570 554 525]
train:  [377 285 569 517 566 542 560 529 544 564] [377 285 569 517 566 542 560 529 544 564]
train:  [552 143 503 476 576 573 512 591 571 488] [552 143 503 476 576 573 512 591 571 488]
test:  [0 1 2 3 4 5 6 7 8 9] [0 1 2 3 4 5 6 7 8 9]
test:  [10 11 12 13 14 15 16 17 18 19] [10 11 12 13 14 15 16 17 18 19]
test:  [20 21 22 23 24 25 26 27 28 29] [20 21 22 23 24 25 26 27 28 29]
test:  [30 31 32 33 34 35 36 37 38 39] [30 31 32 33 34 35 36 37 38 39]
test:  [40 41 42 43 44 45 46 47 48 49] [40 41 42 43 44 45 46 47 48 49]

可以看到,训练集的每个batch都进行了shuffle,而测试集每次都从头开始重新取50个数据。
我们改动一下,使得训练时随机获取,且每训练10个batch进行一次测试,测试时从测试集中随机获取5个batch

#! -*- coding: utf-8 -*-
import tensorflow as tf


train_x = tf.range(0, 1000)
train_y = tf.range(0, 1000)

test_x = tf.range(0, 100)
test_y = tf.range(0, 100)

train_dataset = tf.data.Dataset.from_tensor_slices((train_x, train_y))
train_dataset = train_dataset.shuffle(buffer_size=100).batch(batch_size=10).repeat().prefetch(buffer_size=20)
train_iterator = train_dataset.make_one_shot_iterator()
train_next_element = train_iterator.get_next()

test_dataset = tf.data.Dataset.from_tensor_slices((test_x, test_y))
test_dataset = test_dataset.shuffle(buffer_size=100).batch(10).prefetch(buffer_size=20)
test_iterator = test_dataset.make_initializable_iterator()
test_next_element = test_iterator.get_next()

with tf.Session() as sess:
    for _ in range(5):
        for _ in range(10):
            train_x, train_y = sess.run(train_next_element)
            print('train: ', train_x, train_y)

        sess.run(test_iterator.initializer)
        for _ in range(5):
            test_x, test_y = sess.run(test_next_element)
            print('test: ', test_x, test_y)

其结果为:

train:  [ 83   6  18  16  90  54  12  57  29 103] [ 83   6  18  16  90  54  12  57  29 103]
train:  [104  88  35  55  43  63 108  38   8  61] [104  88  35  55  43  63 108  38   8  61]
train:  [  4  91  28  42  77  30  37  31 114   1] [  4  91  28  42  77  30  37  31 114   1]
train:  [ 68  78   7  34  87 125  73  62  60  24] [ 68  78   7  34  87 125  73  62  60  24]
train:  [  3 138 100 139 133  17  48  11 130  39] [  3 138 100 139 133  17  48  11 130  39]
train:  [ 74  51  44  46 153 126 142 143  82  50] [ 74  51  44  46 153 126 142 143  82  50]
train:  [148  22 140 136  14   0  33 113  20 109] [148  22 140 136  14   0  33 113  20 109]
train:  [168 169 155 122   5 117 166  69 115 176] [168 169 155 122   5 117 166  69 115 176]
train:  [ 21  27  72 162  49   2 111  99 127  70] [ 21  27  72 162  49   2 111  99 127  70]
train:  [ 56 188 184 175 152 150 164 196 158 187] [ 56 188 184 175 152 150 164 196 158 187]
test:  [43 17 83 36 90 96 62 95  0 63] [43 17 83 36 90 96 62 95  0 63]
test:  [16 28 87 34 15 22 68 42 35 25] [16 28 87 34 15 22 68 42 35 25]
test:  [ 9 65 30 31 33 73 70 92 40 80] [ 9 65 30 31 33 73 70 92 40 80]
test:  [23  2 52 85 20 75 97  4 61 91] [23  2 52 85 20 75 97  4 61 91]
test:  [ 8 71 58 46 55 64 37  6 76 11] [ 8 71 58 46 55 64 37  6 76 11]
train:  [ 96  95 186  19 181 203 191 144  32  75] [ 96  95 186  19 181 203 191 144  32  75]
train:  [156 118 146 178  59  85 132 193 157 159] [156 118 146 178  59  85 132 193 157 159]
train:  [ 45  89 198  86 201 207 225 105 149 151] [ 45  89 198  86 201 207 225 105 149 151]
train:  [218 183 214 123  94 216 121 102 227 180] [218 183 214 123  94 216 121 102 227 180]
train:  [219 106 167  15 228  92 230 231 141 245] [219 106 167  15 228  92 230 231 141 245]
train:  [239 170 200 217 213  84 120 107 163  81] [239 170 200 217 213  84 120 107 163  81]
train:  [ 13 241 195 173 248  25  64 194 254  98] [ 13 241 195 173 248  25  64 194 254  98]
train:  [267 256 223 234 202 274 154  76 179 259] [267 256 223 234 202 274 154  76 179 259]
train:  [211 275 222 171  97 137 182 220 185 277] [211 275 222 171  97 137 182 220 185 277]
train:  [272 265 290 232  36 161 260 255 128 205] [272 265 290 232  36 161 260 255 128 205]
test:  [14 78 79 66 73 16 31 18 91 48] [14 78 79 66 73 16 31 18 91 48]
test:  [68 34 50 17 70 92 28 94 27  6] [68 34 50 17 70 92 28 94 27  6]
test:  [62 46 51 32 99 44 81 59 25 54] [62 46 51 32 99 44 81 59 25 54]
test:  [58  5 95 64 63 10  4 89 67 98] [58  5 95 64 63 10  4 89 67 98]
test:  [76 65 15 33 19 74 22 45  7 13] [76 65 15 33 19 74 22 45  7 13]
train:  [247 293 270 302 172 129 199   9 306  10] [247 293 270 302 172 129 199   9 306  10]
train:  [262 257 280 235 208 309 282 124 263 252] [262 257 280 235 208 309 282 124 263 252]
train:  [221 314 271 147 289 224  47  79 287 313] [221 314 271 147 289 224  47  79 287 313]
train:  [192 261 209 269 112 330 204 316 298  71] [192 261 209 269 112 330 204 316 298  71]
train:  [110 276 131 165 334 310 329 336 229 331] [110 276 131 165 334 310 329 336 229 331]
train:  [339 300 341 284 295 338 285  53 342 305] [339 300 341 284 295 338 285  53 342 305]
train:  [212 174  52 297 312 317 322 320 244 237] [212 174  52 297 312 317 322 320 244 237]
train:  [296 299  93 286 246 249  67 366 324  26] [296 299  93 286 246 249  67 366 324  26]
train:  [332 251 337 266 315 283 101 351 358 236] [332 251 337 266 315 283 101 351 358 236]
train:  [145 354 363 273 373 250  65 352 393 340] [145 354 363 273 373 250  65 352 393 340]
test:  [38 13 53 74 64 94 61 91 56 59] [38 13 53 74 64 94 61 91 56 59]
test:  [82 57 54 11 26 66 92  0  1 60] [82 57 54 11 26 66 92  0  1 60]
test:  [70 12 17 15 31 37 41  3 99 80] [70 12 17 15 31 37 41  3 99 80]
test:  [77 19 63 72 89 43 81 97 50 85] [77 19 63 72 89 43 81 97 50 85]
test:  [75 18 36 47  4 33 24 22 39 46] [75 18 36 47  4 33 24 22 39 46]
train:  [226 370 346 362 304 375 402 374 376  58] [226 370 346 362 304 375 402 374 376  58]
train:  [383  80 190 233 258 206 360 367 344 307] [383  80 190 233 258 206 360 367 344 307]
train:  [391  23 323 328 368 395 197 409 410 372] [391  23 323 328 368 395 197 409 410 372]
train:  [243 160 430 268 429 353 343 414 326 294] [243 160 430 268 429 353 343 414 326 294]
train:  [279 407 428 419 397 406 253 436 119 303] [279 407 428 419 397 406 253 436 119 303]
train:  [347 398 442 382 377 421 288 345 386 432] [347 398 442 382 377 421 288 345 386 432]
train:  [349 318 447 333 451 462 444 308 418 403] [349 318 447 333 451 462 444 308 418 403]
train:  [426 401 457 413 423 189 371 327 459 468] [426 401 457 413 423 189 371 327 459 468]
train:  [350 458 292  41 479 399 465 454 238 387] [350 458 292  41 479 399 465 454 238 387]
train:  [427 455 134 435 461  66 477 378 416 456] [427 455 134 435 461  66 477 378 416 456]
test:  [23 46 59 53 10 81 13 48 43 27] [23 46 59 53 10 81 13 48 43 27]
test:  [25 70 69 98 55 68 74 92 29  3] [25 70 69 98 55 68 74 92 29  3]
test:  [11 39 80  1 73  6 32 41 75 96] [11 39 80  1 73  6 32 41 75 96]
test:  [17 24 97 95  9 51 67 76 44 58] [17 24 97 95  9 51 67 76 44 58]
test:  [21 45 71 18 15 72  0 88 91 31] [21 45 71 18 15 72  0 88 91 31]
train:  [215 311 483 446 396 440 503 364 496 450] [215 311 483 446 396 440 503 364 496 450]
train:  [385 392 493 361 473 480 443 408 400 507] [385 392 493 361 473 480 443 408 400 507]
train:  [501 481 500 445 449 489 475 335 488 412] [501 481 500 445 449 489 475 335 488 412]
train:  [510 291 514 469 434 476 422 487 498 490] [510 291 514 469 434 476 422 487 498 490]
train:  [357 474 264 321 516 384 452 380 531 495] [357 474 264 321 516 384 452 380 531 495]
train:  [542 509 431 532 135 453 554 524 301 539] [542 509 431 532 135 453 554 524 301 539]
train:  [441 319 278 521 448 460 553 558 520 379] [441 319 278 521 448 460 553 558 520 379]
train:  [411 505 562 537 388 437 325 116 560 389] [411 505 562 537 388 437 325 116 560 389]
train:  [471 552 544 569 484 526 499 548 525 390] [471 552 544 569 484 526 499 548 525 390]
train:  [551 355 513 497 438 466 540 579 369 550] [551 355 513 497 438 466 540 579 369 550]
test:  [25 64 39 42 90 23 86 20 55 60] [25 64 39 42 90 23 86 20 55 60]
test:  [94 53 14 73 16 81 84 13 92 24] [94 53 14 73 16 81 84 13 92 24]
test:  [79 67  7  2 61 10 99 36 28 95] [79 67  7  2 61 10 99 36 28 95]
test:  [31 38 76 33 15 41 48 59 11  4] [31 38 76 33 15 41 48 59 11  4]
test:  [68 88 37 96 70  6 93 78 62 26] [68 88 37 96 70  6 93 78 62 26]

这样的设置是比较合理的,但测试集比较大时,进行全量测试是比较耗时的,我们设置每隔多少个训练batch,随意抽取一些测试集的batch查看模型效果。

设置GPU使用资源

import tensorflow as tf
gpu_config = tf.GPUOptions(
    allow_growth=True,  # 刚开始会分配少量的GPU容量,然后按需慢慢地增加,由于不会释放内存,所以会导致碎片
    per_process_gpu_memory_fraction=0.7,  # 给GPU分配固定大小的计算资源
)
config = tf.ConfigProto(
    log_device_placement=True,  # 是否打印设备分配日志
    allow_soft_placement=True,  # 如果指定的设备不存在,允许TF自动分配设备
    gpu_options=gpu_config,  # 设置GPU使用资源
)

with tf.Session(config=config) as sess:
    with tf.device("/gpu:0"):  # 指定GPU运算
        a = tf.placeholder(tf.int16)
        b = tf.placeholder(tf.int16)
        add = tf.add(a, b)
        print(sess.run(add, feed_dict={a: 1, b: 2}))

使用Estimator

参考:https://www.jianshu.com/p/e343758a185e

#! -*- coding: utf-8 -*-
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import numpy as np
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

tf.logging.set_verbosity(tf.logging.INFO)  # 设定输出日志的模式


# 我们的程序代码将放在这里
def cnn_model_fn(features, labels, mode):
    # 输入层,-1表示自动计算,这里是图片批次大小,宽高各28,最后1表示颜色单色
    input_layer = tf.reshape(features["x"], [-1, 28, 28, 1])

    # 1号卷积层,过滤32次,核心区域5x5,激活函数relu
    conv1 = tf.layers.conv2d(
        inputs=input_layer,  # 接收上面创建的输入层输出的张量
        filters=32,
        kernel_size=[5, 5],
        padding="same",
        activation=tf.nn.relu)

    # 1号池化层,接收1号卷积层输出的张量
    pool1 = tf.layers.max_pooling2d(inputs=conv1, pool_size=[2, 2], strides=2)

    # 2号卷积层
    conv2 = tf.layers.conv2d(
        inputs=pool1,  # 继续1号池化层的输出
        filters=64,
        kernel_size=[5, 5],
        padding="same",
        activation=tf.nn.relu)

    # 2号池化层
    pool2 = tf.layers.max_pooling2d(inputs=conv2, pool_size=[2, 2], strides=2)

    # 对2号池化层的输入变换张量形状
    pool2_flat = tf.reshape(pool2, [-1, 7 * 7 * 64])

    # 密度层
    dense = tf.layers.dense(inputs=pool2_flat, units=1024, activation=tf.nn.relu)

    # 丢弃层进行简化
    dropout = tf.layers.dropout(
        inputs=dense, rate=0.4, training=mode == tf.estimator.ModeKeys.TRAIN)

    # 使用密度层作为最终输出,unit可能的分类数量
    logits = tf.layers.dense(inputs=dropout, units=10)

    # 预测和评价使用的输出数据内容
    predictions = {
        # 产生预测,argmax输出第一个轴向的最大数值
        "classes": tf.argmax(input=logits, axis=1),
        # 输出可能性
        "probabilities": tf.nn.softmax(logits, name="softmax_tensor")
    }

    # 以下是根据mode切换的三个不同的方法,都返回tf.estimator.EstimatorSpec对象

    # 预测
    if mode == tf.estimator.ModeKeys.PREDICT:
        return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)

    # 损失函数(训练与评价使用),稀疏柔性最大值交叉熵
    loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)

    # 训练,使用梯度下降优化器,
    if mode == tf.estimator.ModeKeys.TRAIN:
        optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.001)
        train_op = optimizer.minimize(
            loss=loss,
            global_step=tf.train.get_global_step())
        return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)

    # 评价函数(上面两个mode之外else)添加评价度量(for EVAL mode)
    eval_metric_ops = {
        "accuracy": tf.metrics.accuracy(
            labels=labels, predictions=predictions["classes"])}
    return tf.estimator.EstimatorSpec(
        mode=mode, loss=loss, eval_metric_ops=eval_metric_ops)


dir_path = os.path.dirname(os.path.realpath(__file__))
data_path = os.path.join(dir_path, 'MNIST_data')


def main(args):
    # 载入训练和测试数据
    mnist = input_data.read_data_sets(data_path)
    train_data = mnist.train.images  # 得到np.array
    train_labels = np.asarray(mnist.train.labels, dtype=np.int32)
    eval_data = mnist.test.images  # 得到np.array
    eval_labels = np.asarray(mnist.test.labels, dtype=np.int32)

    # 创建估算器
    mnist_classifier = tf.estimator.Estimator(
        model_fn=cnn_model_fn, model_dir="/tmp/mnist_convnet_model")

    # 设置输出预测的日志
    tensors_to_log = {"probabilities": "softmax_tensor"}
    logging_hook = tf.train.LoggingTensorHook(
        tensors=tensors_to_log, every_n_iter=50)

    # 训练喂食函数
    train_input_fn = tf.estimator.inputs.numpy_input_fn(
        x={"x": train_data},
        y=train_labels,
        batch_size=100,
        num_epochs=None,
        shuffle=True)

    # 启动训练
    mnist_classifier.train(
        input_fn=train_input_fn,
        steps=20000,
        hooks=[logging_hook])

    # 评价喂食函数
    eval_input_fn = tf.estimator.inputs.numpy_input_fn(
        x={"x": eval_data},
        y=eval_labels,
        num_epochs=1,
        shuffle=False)

    # 启动评价并输出结果
    eval_results = mnist_classifier.evaluate(input_fn=eval_input_fn)
    print(eval_results)


# 这个文件能够直接运行,也可以作为模块被其他文件载入
if __name__ == "__main__":
    tf.app.run()

TensorFlow搭建网络模型

tensorflow搭建网络的库: tf.keras, tf.nn, tf.layers

  • tf.nn
    最底层的函数,其他各种库基本都是基于这个底层库来进行扩展的
  • tf.layers
    比tf.nn更高级的库,对tf.nn进行了多方位功能扩展,就是用tf.nn造的轮子。最大的特点就是,库中每个函数都有相应的类
  • tf.keras
    tf.keras是基于tf.layers和tf.nn的高度封装

猜你喜欢

转载自blog.csdn.net/ZWX2445205419/article/details/88667774
今日推荐