Train a computer vision model (handwritten digit classifier) in the browser with TensorFlow.js


We already got a general sense of how to use TensorFlow.js in "Running TensorFlow.js in the browser to train the model and give the prediction result (Iris data set)". Now let's take a closer look at how to train on an image dataset and do some visualization along the way. The code for this article can be found in the GitHub repository for the book "AI and Machine Learning for Coders".

When using a browser, every resource we open at a URL requires an HTTP request: we use the connection to send a request to the server, and the server sends back the result. When it comes to computer vision, we usually have a lot of training data. MNIST and Fashion MNIST, for example, although they are already very small image datasets, still contain 70,000 images each; downloading them one at a time would mean 70,000 HTTP requests! That obviously creates a lot of overhead, and we'll see how to deal with it later.

Building a CNN in JavaScript

Let's see how the following CNN model for the handwritten digit dataset, defined here in Keras, would be written in JavaScript:

model = tf.keras.models.Sequential([
    tf.keras.layers.Conv2D(64, (3, 3),activation='relu', 
                           input_shape=(28, 28, 1)),
    tf.keras.layers.MaxPooling2D(2, 2),
    tf.keras.layers.Conv2D(64, (3, 3), activation='relu'),
    tf.keras.layers.MaxPooling2D(2, 2),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(128, activation=tf.nn.relu),
    tf.keras.layers.Dense(10, activation=tf.nn.softmax)
])

Let's take a look at how the convolutional layer, pooling layer, and fully connected layer are defined in JavaScript:

We first define the model as sequential:

model = tf.sequential();

The first convolutional layer:

model.add(tf.layers.conv2d({inputShape: [28, 28, 1],
                            kernelSize: 3,
                            filters: 64,
                            activation: 'relu'}));

The first pooling layer:

model.add(tf.layers.maxPooling2d({poolSize: [2, 2]}));

The first fully connected layer:

model.add(tf.layers.dense({units: 128, activation: 'relu'}));

So the full JavaScript definition is:

model = tf.sequential();

model.add(tf.layers.conv2d({inputShape: [28, 28, 1],
                            kernelSize: 3,
                            filters: 64,
                            activation: 'relu'}));

model.add(tf.layers.maxPooling2d({poolSize: [2, 2]}));

model.add(tf.layers.conv2d({kernelSize: 3,
                            filters: 64,
                            activation: 'relu'}));

model.add(tf.layers.maxPooling2d({poolSize: [2, 2]}));

model.add(tf.layers.flatten());

model.add(tf.layers.dense({units: 128, activation: 'relu'}));

model.add(tf.layers.dense({units: 10, activation: 'softmax'}));

The syntax for compiling the model:

model.compile({optimizer: tf.train.adam(),
               loss: 'categoricalCrossentropy',
               metrics: ['accuracy']});
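
Not in the original listing, but handy for checking the result: TensorFlow.js layers models, like Keras models, support summary(), which logs the layer stack and parameter counts:

model.summary();   // prints output shapes and parameter counts to the console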

Using Callbacks for Visualization

To demonstrate, we reuse the ready-made code from "Run TensorFlow.js in the browser to train the model and give the prediction result (Iris data set)":

<html>
<head>
    <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@latest"></script>

    <script lang="js">

        async function run(){
            const csvUrl = 'iris.csv';
            const trainingData = tf.data.csv(csvUrl, {
                columnConfigs: {
                    species: {
                        isLabel: true
                    }
                }
            });

            const convertedData = trainingData.map(({xs, ys}) => {
                const labels = [
                    ys.species == 'setosa' ? 1 : 0,
                    ys.species == 'virginica' ? 1 : 0,
                    ys.species == 'versicolor' ? 1 : 0
                ]
                return {xs: Object.values(xs), ys: Object.values(labels)};
            }).batch(10);

            const numOfFeatures = (await trainingData.columnNames()).length - 1;

            const model = tf.sequential();
            model.add(tf.layers.dense({inputShape: [numOfFeatures],
                                       activation: "sigmoid", units: 5}));

            model.add(tf.layers.dense({activation: "softmax", units: 3}));

            model.compile({loss: "categoricalCrossentropy",
                           optimizer: tf.train.adam(0.06)});

            await model.fitDataset(convertedData,
                                   {epochs: 100,
                                    callbacks: {
                                        onEpochEnd: async(epoch, logs) => {
                                            console.log("Epoch: " + epoch + " Loss: " + logs.loss);
                                        }
                                    }});

            const testVal = tf.tensor2d([4.4, 2.9, 1.4, 0.2], [1, 4]);
            alert(model.predict(testVal));
        }

        run();

    </script>
</head>
<body>
    <h1>Iris Classifier</h1>
</body>
</html>

In order to use the visualizer tfjs-vis, we need to add the following script tag:

<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs-vis"></script>

Then we use tfvis.show.fitCallbacks to define a callback that visualizes metrics during training:

const metrics = ['loss', 'accuracy'];
            
const container = {name: 'Model Training', 
                   styles: {height: '640px'},
                   tab: 'Training Progress'};
            
const fitCallbacks = tfvis.show.fitCallbacks(container, metrics);

Replacing the callback in the original code with fitCallbacks, our complete code is now:

<html>
<head>
    <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@latest"></script>
    <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs-vis"></script>

    <script lang="js">

        async function run(){
            const csvUrl = 'iris.csv';
            const trainingData = tf.data.csv(csvUrl, {
                columnConfigs: {
                    species: {
                        isLabel: true
                    }
                }
            });

            const convertedData = trainingData.map(({xs, ys}) => {
                const labels = [
                    ys.species == 'setosa' ? 1 : 0,
                    ys.species == 'virginica' ? 1 : 0,
                    ys.species == 'versicolor' ? 1 : 0
                ]
                return {xs: Object.values(xs), ys: Object.values(labels)};
            }).batch(10);

            const numOfFeatures = (await trainingData.columnNames()).length - 1;

            const model = tf.sequential();
            model.add(tf.layers.dense({inputShape: [numOfFeatures],
                                       activation: "sigmoid", units: 5}));

            model.add(tf.layers.dense({activation: "softmax", units: 3}));

            model.compile({loss: "categoricalCrossentropy",
                           optimizer: tf.train.adam(0.06)});

            const metrics = ['loss', 'accuracy'];

            const container = {name: 'Model Training',
                               styles: {height: '640px'},
                               tab: 'Training Progress'};

            const fitCallbacks = tfvis.show.fitCallbacks(container, metrics);

            await model.fitDataset(convertedData,
                                   {epochs: 50,
                                    callbacks: fitCallbacks});
        }

        run();

    </script>
</head>
<body>
    <h1>Iris Classifier</h1>
</body>
</html>

After running, tfjs-vis renders the training progress (per-epoch loss and accuracy curves) in a panel on the page.



Training with the MNIST Dataset

Let's first create a new project as in "Run TensorFlow.js in the browser to train the model and give the prediction result (Iris dataset)", make a copy of script.js, and rename the copy data.js. The code listed below goes into data.js. Of course, if you don't want to create a new file, you can simply put the code inside index.html in a

<script lang="js">

</script>

That is also perfectly fine, but it may make the index.html file too bloated.

In TensorFlow.js, a particular way of handling training data is to stitch all the images together into a single image, often called a sprite sheet, rather than downloading each one individually. This technique is common in game development, where a game's graphics are stored in one file instead of many small ones for more efficient loading. If we store all the training images in one file, we only need one HTTP connection to download them all at once.

MNIST's sprite sheet has dimensions 65000×784 (784 = 28×28): each row is one flattened 28×28 image, so we only need to read the file row by row to recover the images one by one.
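
To make the row-to-image mapping concrete, here is an illustrative helper (not part of the article's code): once the sprite has been decoded into a flat Float32Array, image i occupies the 784 consecutive values starting at i * 784.

const IMAGE_SIZE = 28 * 28;  // 784 pixels per flattened image

// Illustrative helper: extract the i-th image from the decoded sprite buffer.
function imageAt(datasetImages, i) {
    return datasetImages.slice(i * IMAGE_SIZE, (i + 1) * IMAGE_SIZE);
}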

We can first load the image in JavaScript, then define a canvas, draw each "line" (row) extracted from the original image onto it, and finally extract the bytes from the canvas into a dataset for training. Let's walk through the process:

The train/test split is 5:1 (TRAIN_TEST_RATIO = 5/6); with 65,000 images that works out to 54,166 training and 10,834 test images:

const IMAGE_SIZE = 784;
const NUM_CLASSES = 10;
const NUM_DATASET_ELEMENTS = 65000;
        
const TRAIN_TEST_RATIO = 5/6;
        
const NUM_TRAIN_ELEMENTS = Math.floor(TRAIN_TEST_RATIO * NUM_DATASET_ELEMENTS);
        
const NUM_TEST_ELEMENTS = NUM_DATASET_ELEMENTS - NUM_TRAIN_ELEMENTS;

Define the canvas:

const img = new Image();
const canvas = document.createElement('canvas');
const ctx = canvas.getContext('2d');

The sprite sheet's URL:

img.src = "https://storage.googleapis.com/learnjs-data/model-builder/mnist_images.png";

Once the image is loaded, we can set up a buffer to hold its decoded bytes. The image is decoded to 4 bytes per pixel (RGBA), and we will store each pixel as one 4-byte float, so 65,000 × 784 × 4 bytes need to be reserved for the buffer. And instead of reading the file image by image, we can read it in chunks: by specifying a chunkSize, we process 5,000 images at a time:

img.onload = () => {
    img.width = img.naturalWidth;
    img.height = img.naturalHeight;

    const datasetBytesBuffer =
          new ArrayBuffer(NUM_DATASET_ELEMENTS * IMAGE_SIZE * 4);

    const chunkSize = 5000;
    canvas.width = img.width;
    canvas.height = chunkSize;

Next, we use a for loop to read the image into the buffer chunk by chunk. Because the image is grayscale, the R/G/B channels all hold the same value, so we can read any one of them (the code below reads the red channel):

for (let i = 0; i < NUM_DATASET_ELEMENTS / chunkSize; i++) {
    const datasetBytesView = new Float32Array(
        datasetBytesBuffer, i * IMAGE_SIZE * chunkSize * 4,
        IMAGE_SIZE * chunkSize);
            
    ctx.drawImage(
        img, 0, i * chunkSize, img.width, chunkSize, 0, 0, img.width, chunkSize);
                
    const imageData = ctx.getImageData(0, 0, canvas.width, canvas.height);
                
    for (let j = 0; j < imageData.data.length / 4; j++) {
        datasetBytesView[j] = imageData.data[j * 4] / 255;
    }
}
            
this.datasetImages = new Float32Array(datasetBytesBuffer);

Like the images, the labels are stored in a separate file. It is a binary file in which each label is one-hot encoded: every label takes 10 bytes, exactly one of which has the value 1, marking the class. So, in addition to downloading and decoding the image bytes line by line, we also need to decode the labels; arrayBuffer lets us read them into an array of integers.

// imgRequest is the Promise that wraps img.onload above (see the full data.js listing below).
const labelsRequest = fetch("https://storage.googleapis.com/learnjs-data/model-builder/mnist_labels_uint8");

const [imgResponse, labelsResponse] =
       await Promise.all([imgRequest, labelsRequest]);

this.datasetLabels = new Uint8Array(await labelsResponse.arrayBuffer());
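
For illustration (this helper is not part of the article's code), reading the digit of image i back out of this array would look like:

// Each image's label occupies 10 consecutive bytes, one-hot encoded.
function labelOf(datasetLabels, i) {
    const oneHot = datasetLabels.slice(i * NUM_CLASSES, (i + 1) * NUM_CLASSES);
    return oneHot.indexOf(1);   // position of the single 1 byte = the digit
}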

Then we can divide the training set and test set:

this.trainImages = 
	this.datasetImages.slice(0, IMAGE_SIZE * NUM_TRAIN_ELEMENTS);
this.testImages = 
	this.datasetImages.slice(IMAGE_SIZE * NUM_TRAIN_ELEMENTS);
        
this.trainLabels = 
	this.datasetLabels.slice(0, NUM_CLASSES * NUM_TRAIN_ELEMENTS);
this.testLabels = 
	this.datasetLabels.slice(NUM_CLASSES * NUM_TRAIN_ELEMENTS);

As in the usual workflow, we can also batch the dataset:

nextBatch(batchSize, data, index) {
	const batchImagesArray = new Float32Array(batchSize * IMAGE_SIZE);
	const batchLabelsArray = new Uint8Array(batchSize * NUM_CLASSES);
	
	for (let i = 0; i < batchSize; i++) {
		const idx = index();
		
		const image =
			data[0].slice(idx * IMAGE_SIZE, idx * IMAGE_SIZE + IMAGE_SIZE);
		batchImagesArray.set(image, i * IMAGE_SIZE);
		
		const label =
			data[1].slice(idx * NUM_CLASSES, idx * NUM_CLASSES + NUM_CLASSES);
		batchLabelsArray.set(label, i * NUM_CLASSES);
	}
	
	const xs = tf.tensor2d(batchImagesArray, [batchSize, IMAGE_SIZE]);
	const labels = tf.tensor2d(batchLabelsArray, [batchSize, NUM_CLASSES]);
	
	return {xs, labels};
}

The training data can then use the following function to return shuffled training batches of the desired batch size:

nextTrainBatch(batchSize) {
	return this.nextBatch(
		batchSize, [this.trainImages, this.trainLabels], () => {
			this.shuffledTrainIndex =
				(this.shuffledTrainIndex + 1) % this.trainIndices.length;
			return this.trainIndices[this.shuffledTrainIndex];
		});
}

The test set data is handled in exactly the same way, which we give in the full code below.

The entire data.js file is shown below. We define a class, MnistData, that encapsulates all the methods we just wrote. Then, in index.html, we import the class from data.js, instantiate it with const data = new MnistData();, and call its methods directly.

const IMAGE_SIZE = 784;
const NUM_CLASSES = 10;
const NUM_DATASET_ELEMENTS = 65000;

const TRAIN_TEST_RATIO = 5 / 6;

const NUM_TRAIN_ELEMENTS = Math.floor(TRAIN_TEST_RATIO * NUM_DATASET_ELEMENTS);
const NUM_TEST_ELEMENTS = NUM_DATASET_ELEMENTS - NUM_TRAIN_ELEMENTS;

const MNIST_IMAGES_SPRITE_PATH =
    'https://storage.googleapis.com/learnjs-data/model-builder/mnist_images.png';
const MNIST_LABELS_PATH =
    'https://storage.googleapis.com/learnjs-data/model-builder/mnist_labels_uint8';

export class MnistData {
  constructor() {
    this.shuffledTrainIndex = 0;
    this.shuffledTestIndex = 0;
  }

  async load() {
    // Make a request for the MNIST sprited image.
    const img = new Image();
    const canvas = document.createElement('canvas');
    const ctx = canvas.getContext('2d');
    const imgRequest = new Promise((resolve, reject) => {
      img.crossOrigin = '';
      img.onload = () => {
        img.width = img.naturalWidth;
        img.height = img.naturalHeight;

        const datasetBytesBuffer =
            new ArrayBuffer(NUM_DATASET_ELEMENTS * IMAGE_SIZE * 4);

        const chunkSize = 5000;
        canvas.width = img.width;
        canvas.height = chunkSize;

        for (let i = 0; i < NUM_DATASET_ELEMENTS / chunkSize; i++) {
          const datasetBytesView = new Float32Array(
              datasetBytesBuffer, i * IMAGE_SIZE * chunkSize * 4,
              IMAGE_SIZE * chunkSize);
          ctx.drawImage(
              img, 0, i * chunkSize, img.width, chunkSize, 0, 0, img.width,
              chunkSize);

          const imageData = ctx.getImageData(0, 0, canvas.width, canvas.height);

          for (let j = 0; j < imageData.data.length / 4; j++) {
            // All channels hold an equal value since the image is grayscale, so
            // just read the red channel.
            datasetBytesView[j] = imageData.data[j * 4] / 255;
          }
        }
        this.datasetImages = new Float32Array(datasetBytesBuffer);

        resolve();
      };
      img.src = MNIST_IMAGES_SPRITE_PATH;
    });

    const labelsRequest = fetch(MNIST_LABELS_PATH);
    const [imgResponse, labelsResponse] =
        await Promise.all([imgRequest, labelsRequest]);

    this.datasetLabels = new Uint8Array(await labelsResponse.arrayBuffer());

    // Create shuffled indices into the train/test set for when we select a
    // random dataset element for training / validation.
    this.trainIndices = tf.util.createShuffledIndices(NUM_TRAIN_ELEMENTS);
    this.testIndices = tf.util.createShuffledIndices(NUM_TEST_ELEMENTS);

    // Slice the images and labels into train and test sets.
    this.trainImages =
        this.datasetImages.slice(0, IMAGE_SIZE * NUM_TRAIN_ELEMENTS);
    this.testImages = this.datasetImages.slice(IMAGE_SIZE * NUM_TRAIN_ELEMENTS);
    this.trainLabels =
        this.datasetLabels.slice(0, NUM_CLASSES * NUM_TRAIN_ELEMENTS);
    this.testLabels =
        this.datasetLabels.slice(NUM_CLASSES * NUM_TRAIN_ELEMENTS);
  }

  nextTrainBatch(batchSize) {
    return this.nextBatch(
        batchSize, [this.trainImages, this.trainLabels], () => {
          this.shuffledTrainIndex =
              (this.shuffledTrainIndex + 1) % this.trainIndices.length;
          return this.trainIndices[this.shuffledTrainIndex];
        });
  }

  nextTestBatch(batchSize) {
    return this.nextBatch(batchSize, [this.testImages, this.testLabels], () => {
      this.shuffledTestIndex =
          (this.shuffledTestIndex + 1) % this.testIndices.length;
      return this.testIndices[this.shuffledTestIndex];
    });
  }

  nextBatch(batchSize, data, index) {
    const batchImagesArray = new Float32Array(batchSize * IMAGE_SIZE);
    const batchLabelsArray = new Uint8Array(batchSize * NUM_CLASSES);

    for (let i = 0; i < batchSize; i++) {
      const idx = index();

      const image =
          data[0].slice(idx * IMAGE_SIZE, idx * IMAGE_SIZE + IMAGE_SIZE);
      batchImagesArray.set(image, i * IMAGE_SIZE);

      const label =
          data[1].slice(idx * NUM_CLASSES, idx * NUM_CLASSES + NUM_CLASSES);
      batchLabelsArray.set(label, i * NUM_CLASSES);
    }

    const xs = tf.tensor2d(batchImagesArray, [batchSize, IMAGE_SIZE]);
    const labels = tf.tensor2d(batchLabelsArray, [batchSize, NUM_CLASSES]);

    return {xs, labels};
  }
}
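
Before wiring this into training, here is a minimal usage sketch (assuming data.js sits next to your index.html and tfjs is loaded via the script tag):

import {MnistData} from './data.js';

async function smokeTest() {
    const data = new MnistData();
    await data.load();                            // one download for images, one for labels
    const {xs, labels} = data.nextTrainBatch(32);
    console.log(xs.shape);                        // [32, 784]
    console.log(labels.shape);                    // [32, 10]
}
smokeTest();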

The callback and training code defined below will be encapsulated into the train function in index.html.

Remember the visualization callback we just used? Let's define it again here:

const metrics = ['loss', 'val_loss', 'accuracy', 'val_accuracy'];
const container = { name: 'Model Training', styles: { height: '640px' },
					tab: 'Training Progress' };
const fitCallbacks = tfvis.show.fitCallbacks(container, metrics);

Call the function to generate training and test data sets:

const [trainXs, trainYs] = tf.tidy(() => {
	const d = data.nextTrainBatch(TRAIN_DATA_SIZE);
	return [
		d.xs.reshape([TRAIN_DATA_SIZE, 28, 28, 1]),
		d.labels
	];
});

const [testXs, testYs] = tf.tidy(() => {
	const d = data.nextTestBatch(TEST_DATA_SIZE);
	return [
		d.xs.reshape([TEST_DATA_SIZE, 28, 28, 1]),
		d.labels
	];
});

Note the use of tf.tidy. In TensorFlow.js, it cleans up all intermediate tensors created inside the function it wraps, except the ones the function returns. This is critical for preventing memory leaks in the browser.
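
To make this concrete, here is a minimal standalone sketch (illustrative values, not from the article):

// Without tf.tidy, `doubled` and `squared` would stay allocated
// (e.g. on the GPU) until disposed manually.
const result = tf.tidy(() => {
    const x = tf.tensor1d([1, 2, 3]);
    const doubled = x.mul(2);        // intermediate, cleaned up by tidy
    const squared = doubled.square();
    return squared;                  // the returned tensor survives
});
result.print();                      // [4, 16, 36]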

Now everything is ready, we can start training!

return model.fit(trainXs, trainYs, {
	batchSize: BATCH_SIZE,
	validationData: [testXs, testYs],
	epochs: 20,
	shuffle: true,
	callbacks: fitCallbacks
});
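
Incidentally, model.fit returns a Promise that resolves to a History object whose history property holds per-epoch metric arrays, so the caller can inspect them after training. A small illustrative sketch:

// The caller awaits training and reads the logged metrics afterwards.
const history = await train(model, data);
const losses = history.history.loss;             // one entry per epoch
console.log('final training loss:', losses[losses.length - 1]);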

You may not be able to run this training process just yet; don't worry, the complete index.html file is given below, and you can run everything then.



Running Inference on Images in TensorFlow.js

For inference we need a test image. We can create a canvas object and let the user write the digit to be classified on it with the mouse:

canvas = document.getElementById('canvas');       // drawing surface
rawImage = document.getElementById('canvasimg');  // hidden <img> that mirrors the canvas
ctx = canvas.getContext("2d");
ctx.fillStyle = "black";
ctx.fillRect(0,0,280,280);                        // start from a black 280x280 canvas

The user writes the digit via the draw function:

function draw(e) {
	if(e.buttons!=1) return;
	ctx.beginPath();
	ctx.lineWidth = 24;
	ctx.lineCap = 'round';
	ctx.strokeStyle = 'white';
	ctx.moveTo(pos.x, pos.y);
	setPosition(e);
	ctx.lineTo(pos.x, pos.y);
	ctx.stroke();
	rawImage.src = canvas.toDataURL('image/png');
}

We grab the pixels from the canvas and process them into an input tensor the model can handle:

var raw = tf.browser.fromPixels(rawImage,1);

var resized = tf.image.resizeBilinear(raw, [28,28]);

var tensor = resized.expandDims(0);
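
One thing to be aware of (not addressed in the code above): the training pixels were scaled to [0, 1] by the /255 in data.js, while tf.browser.fromPixels returns integer values in [0, 255]. If you want the inference input on the same scale as the training data, one option is:

// Optional: normalize to [0, 1] to match the training data's scale.
var tensor = resized.expandDims(0).toFloat().div(255.0);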

Then we can make predictions:

var prediction = model.predict(tensor);
var pIndex = tf.argMax(prediction, 1).dataSync();
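
tf.argMax(prediction, 1) gives the index of the highest probability for each row, and dataSync() synchronously copies it into a typed array, so for a single input image the predicted digit is simply the first element:

var digit = pIndex[0];                  // e.g. 7
console.log('Predicted digit:', digit);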

Full code of the index.html file:

<html>
<head>
    <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@latest"></script>
    <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs-vis"></script>
</head>
<body>
    <h1>Handwriting Classifier!</h1>
    <canvas id="canvas" width="280" height="280" style="position:absolute;top:100;left:100;border:8px solid;"></canvas>
    <img id="canvasimg" style="position:absolute;top:10%;left:52%;width=280;height=280;display:none;">
    <input type="button" value="classify" id="sb" size="48" style="position:absolute;top:400;left:100;">
    <input type="button" value="clear" id="cb" size="23" style="position:absolute;top:400;left:180;">
    <script src="data.js" type="module"></script>

    <script type="module">

        import {MnistData} from './data.js';
        var canvas, ctx, saveButton, clearButton;
        var pos = {x:0, y:0};
        var rawImage;
        var model;

        function getModel() {
            model = tf.sequential();

            model.add(tf.layers.conv2d({inputShape: [28, 28, 1], kernelSize: 3, filters: 8, activation: 'relu'}));
            model.add(tf.layers.maxPooling2d({poolSize: [2, 2]}));
            model.add(tf.layers.conv2d({filters: 16, kernelSize: 3, activation: 'relu'}));
            model.add(tf.layers.maxPooling2d({poolSize: [2, 2]}));
            model.add(tf.layers.flatten());
            model.add(tf.layers.dense({units: 128, activation: 'relu'}));
            model.add(tf.layers.dense({units: 10, activation: 'softmax'}));

            model.compile({optimizer: tf.train.adam(), loss: 'categoricalCrossentropy', metrics: ['accuracy']});

            return model;
        }

        async function train(model, data) {
            const metrics = ['loss', 'val_loss', 'accuracy', 'val_accuracy'];
            const container = {name: 'Model Training', styles: {height: '640px'}, tab: 'Training Progress'};
            const fitCallbacks = tfvis.show.fitCallbacks(container, metrics);

            const BATCH_SIZE = 512;
            const TRAIN_DATA_SIZE = 5500;
            const TEST_DATA_SIZE = 1000;

            const [trainXs, trainYs] = tf.tidy(() => {
                const d = data.nextTrainBatch(TRAIN_DATA_SIZE);
                return [
                    d.xs.reshape([TRAIN_DATA_SIZE, 28, 28, 1]),
                    d.labels
                ];
            });

            const [testXs, testYs] = tf.tidy(() => {
                const d = data.nextTestBatch(TEST_DATA_SIZE);
                return [
                    d.xs.reshape([TEST_DATA_SIZE, 28, 28, 1]),
                    d.labels
                ];
            });

            return model.fit(trainXs, trainYs, {
                batchSize: BATCH_SIZE,
                validationData: [testXs, testYs],
                epochs: 20,
                shuffle: true,
                callbacks: fitCallbacks
            });
        }

        function setPosition(e){
            pos.x = e.clientX-100;
            pos.y = e.clientY-100;
        }

        function draw(e) {
            if(e.buttons!=1) return;
            ctx.beginPath();
            ctx.lineWidth = 24;
            ctx.lineCap = 'round';
            ctx.strokeStyle = 'white';
            ctx.moveTo(pos.x, pos.y);
            setPosition(e);
            ctx.lineTo(pos.x, pos.y);
            ctx.stroke();
            rawImage.src = canvas.toDataURL('image/png');
        }

        function erase() {
            ctx.fillStyle = "black";
            ctx.fillRect(0,0,280,280);
        }

        function save() {
            var raw = tf.browser.fromPixels(rawImage,1);
            var resized = tf.image.resizeBilinear(raw, [28,28]);
            var tensor = resized.expandDims(0);
            var prediction = model.predict(tensor);
            var pIndex = tf.argMax(prediction, 1).dataSync();

            alert(pIndex);
        }

        function init() {
            canvas = document.getElementById('canvas');
            rawImage = document.getElementById('canvasimg');
            ctx = canvas.getContext("2d");
            ctx.fillStyle = "black";
            ctx.fillRect(0,0,280,280);
            canvas.addEventListener("mousemove", draw);
            canvas.addEventListener("mousedown", setPosition);
            canvas.addEventListener("mouseenter", setPosition);
            saveButton = document.getElementById('sb');
            saveButton.addEventListener("click", save);
            clearButton = document.getElementById('cb');
            clearButton.addEventListener("click", erase);
        }

        async function run() {
            const data = new MnistData();
            await data.load();
            const model = getModel();
            tfvis.show.modelSummary({name: 'Model Architecture'}, model);
            await train(model, data);
            init();
            alert("Training is done, try classifying your handwriting!");
        }

        run();

    </script>
</body>
</html>

After running, we wait for the model to finish training; the tfjs-vis visor, showing the model summary and the training curves, appears on the right of the page.

Then we use the mouse to write a digit on the black canvas and click classify; the predicted digit is shown in an alert.



References

Laurence Moroney, AI and Machine Learning for Coders, O'Reilly Media, 2020.
