Introduction: Mnist Dataset -- From Training to Running With ESP32 / ESP32S3

In this post, I will demonstrate an experiment that starts with building a TensorFlow Lite Deep Learning (DL) model for the popular Mnist dataset, and ends with uploading an ESP32 / ESP32S3 sketch for handwriting recognition (of a single digit).

DumbDisplay will be used as the UI driven by the ESP32 / ESP32S3 microcontroller, both for getting handwriting input from you and for displaying the recognition result.

Hopefully, the UI is obvious enough:

  • You draw the digit to be recognized on the big black square canvas
  • You clear what is drawn by clicking the "clear" button at the top-left
  • The "center" option at the top-right auto-centers the handwritten digit (so that recognition can be more accurate)
  • You trigger recognition by clicking the ">>>" button at the bottom-middle
  • After recognition, the handwritten digit is copied to the smaller black square canvas at the bottom-left
  • The recognized digit is shown at the bottom-right

Step 1: Building TensorFlow Lite Model for the Mnist Dataset

The Python Jupyter notebook for building the TensorFlow Lite model (used by the sketch) can be downloaded and run locally, possibly in the same way as described in my previous post, Install Jupyter Server As Docker Container in Windows WSL, for DL Model Training, Possibly With VSCode.

Alternatively, the Python Jupyter notebook can be run in Google Colab for free.


The first cell imports the needed Python modules

import keras
import tensorflow as tf
...

Then the Mnist dataset is loaded

(X_train, y_train), (X_valid, y_valid) = mnist.load_data()

The dataset is divided into two subsets, one for training (train, 60,000 digits) and the other for validation (valid, 10,000 digits). Note that each of the two subsets (train and valid) actually consists of two components: X being the inputs (28x28 grayscale images), and y being the expected outputs (the digits).

After loading the dataset, the first 12 digits from the training subset are plotted

plt.figure(figsize=(5,5))
for k in range(12):
    plt.subplot(3, 4, k+1)
    plt.imshow(X_train[k], cmap='Greys')
    plt.axis('off')
plt.tight_layout()
plt.show()

The first digit from the validation subset is also plotted

plt.imshow(X_valid[0], cmap='Greys')

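Before proceeding further, the original validation subset is "saved" to some other variables for later use (for plotting the original 28x28 images, since X_valid itself is about to be flattened and normalized)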

ORI_X_valid = X_valid
ORI_y_valid = y_valid

Then, the training and validation inputs are transformed to be suitable for DL training.

X_train = X_train.reshape(60000, 784).astype('float32') / 255
X_valid = X_valid.reshape(10000, 784).astype('float32') / 255

Basically, each 28x28 image (2D) is flattened into a vector of 784 pixels, and each pixel (grayscale value from 0 to 255) is turned into a floating point number from 0 to 1.
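To see exactly what that transformation does to a single image, here is a minimal pure-Python sketch (no NumPy), assuming the image is a 28x28 list of rows of 0-255 values, as in X_train[k]. The function name flatten_and_scale is made up for illustration.

```python
def flatten_and_scale(image):
    """Flatten a 28x28 image (list of rows, 0-255 values) row by row into
    784 floats in [0, 1], like X.reshape(n, 784).astype('float32') / 255."""
    if len(image) != 28 or any(len(row) != 28 for row in image):
        raise ValueError("expected a 28x28 image")
    return [pixel / 255 for row in image for pixel in row]


# A blank image with one fully inked pixel at row 2, column 5.
img = [[0] * 28 for _ in range(28)]
img[2][5] = 255
vec = flatten_and_scale(img)
print(len(vec), vec[2 * 28 + 5])  # 784 1.0
```

Flattening is row by row, so the pixel at row r, column c ends up at index r * 28 + c.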

Next, the expected outputs are also transformed

n_classes = 10
y_train = keras.utils.to_categorical(y_train, n_classes)
y_valid = keras.utils.to_categorical(y_valid, n_classes)

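Basically, each output digit (0 to 9) is turned into a "one-hot" vector of 10 values: 1 at the position of that digit, and 0 everywhere else. It can be read as the probabilities of the 10 digits with all of the probability on the correct digit, which is what the model's softmax output is trained to match.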
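For a single label, the one-hot encoding amounts to the following. This is a minimal pure-Python sketch for illustration (one_hot is a made-up name), not how Keras implements to_categorical internally.

```python
def one_hot(digit, n_classes=10):
    """Return a list with 1.0 at index `digit` and 0.0 elsewhere,
    like one row of keras.utils.to_categorical(y, n_classes)."""
    if not 0 <= digit < n_classes:
        raise ValueError(f"digit must be in [0, {n_classes})")
    vec = [0.0] * n_classes
    vec[digit] = 1.0
    return vec


print(one_hot(3))  # [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
```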

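After the transformations, a simple DL model is configured: two fully connected (Dense) hidden layers of 64 ReLU units each, followed by a 10-unit softmax output layer (one unit per digit)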

model = Sequential()
model.add(Dense(64, activation='relu', input_shape=(784,)))
model.add(Dense(64, activation='relu'))
model.add(Dense(10, activation='softmax'))
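As a quick sanity check of the model's size, the number of trainable parameters can be worked out by hand: a Dense layer with n_in inputs and n_out units has n_in * n_out weights plus n_out biases. The plain-Python snippet below does this arithmetic; the total should match what model.summary() reports for this configuration. The byte figure assumes the weights are stored as 32-bit floats (no quantization) and ignores the overhead of the .tflite format.

```python
def dense_params(n_in, n_out):
    """Trainable parameters of a Dense layer: one weight per input per unit, plus one bias per unit."""
    return n_in * n_out + n_out


layers = [(784, 64), (64, 64), (64, 10)]  # (inputs, units) for the three Dense layers
counts = [dense_params(n_in, n_out) for n_in, n_out in layers]
total = sum(counts)
print(counts, total)  # [50240, 4160, 650] 55050
print(f"about {total * 4 / 1024:.0f} KB as float32")  # about 215 KB as float32
```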

Then, the model is compiled

model.compile(loss='categorical_crossentropy', optimizer=SGD(learning_rate=0.1), metrics=['accuracy'])  
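The loss being minimized can be written out in a few lines. Below is a minimal pure-Python sketch of softmax (what the model's last layer applies) and categorical cross-entropy for a single sample. It mirrors the math rather than Keras' actual implementation, and the clamping epsilon of 1e-7 is an assumption, used only to avoid log(0).

```python
import math


def softmax(logits):
    """Turn raw scores into probabilities that sum to 1 (shifted by the max for numerical stability)."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def categorical_crossentropy(y_true, y_pred, eps=1e-7):
    """-sum(t * log(p)); with a one-hot y_true this is -log(probability of the correct digit).
    Predictions are clamped to [eps, 1 - eps] to avoid log(0)."""
    return -sum(t * math.log(min(max(p, eps), 1 - eps)) for t, p in zip(y_true, y_pred))


y_true = [0, 0, 0, 1, 0, 0, 0, 0, 0, 0]  # one-hot for the digit 3
y_pred = [0.0, 0.1, 0.1, 0.7, 0.1, 0.0, 0.0, 0.0, 0.0, 0.0]
print(round(categorical_crossentropy(y_true, y_pred), 4))  # 0.3567
```

During training, SGD(learning_rate=0.1) repeatedly moves each weight a step against the gradient of this loss (averaged over a batch), scaled by 0.1.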

As the most important step, the model is trained

history = model.fit(X_train, y_train, batch_size=128, epochs=20, verbose=1, validation_data=(X_valid, y_valid))
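To put batch_size=128 and epochs=20 into numbers: each epoch goes through all 60,000 training images in batches of 128, with one SGD weight update per batch. The arithmetic below should match the 469/469 progress counter Keras prints per epoch with verbose=1.

```python
import math

n_train = 60000   # number of training images (X_train.shape[0])
batch_size = 128  # as passed to model.fit()
epochs = 20

steps_per_epoch = math.ceil(n_train / batch_size)  # the last batch holds the leftover images
last_batch = n_train - (steps_per_epoch - 1) * batch_size
total_updates = steps_per_epoch * epochs
print(steps_per_epoch, last_batch, total_updates)  # 469 96 9380
```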

After training, the training and validation losses of each epoch are plotted as a chart

loss = history.history['loss']
val_loss = history.history['val_loss']
epochs = range(1, len(loss) + 1)
plt.plot(epochs, loss, 'r.', label='Training loss')
plt.plot(epochs, val_loss, 'b', label='Validation loss')
plt.title('Training and validation loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.show()

To try out the model, the validation digit at index 6 is run through the model (predicted)

idx = 6
plt.imshow(ORI_X_valid[idx], cmap='Greys')
prediction = model.predict(X_valid[idx : (idx + 1)])[0]
ans = argmax(prediction)
print("ANS:", ans)

The final two cells convert the trained model to TensorFlow Lite format, and then turn it into the C header file mnist_model.h, suitable for including in the sketch.

Note that you will need to modify mnist_model.h, changing the declaration of the variable mnist_model_tflite to something like the following (a const array can be kept in flash, instead of taking up precious RAM)

const unsigned char mnist_model_tflite[] = { ... }
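If you prefer not to edit the file by hand, the header can also be generated with the const declaration in the first place. Below is a minimal sketch, assuming the converted model bytes are available (for example, read from a hypothetical mnist_model.tflite file). The function name tflite_to_c_header and the extra mnist_model_tflite_len variable are my own additions for illustration, not necessarily what the notebook produces.

```python
def tflite_to_c_header(model_bytes, var_name="mnist_model_tflite", per_line=12):
    """Render raw .tflite bytes as C source declaring a const array (so it can stay in flash)."""
    lines = [f"const unsigned char {var_name}[] = {{"]
    for i in range(0, len(model_bytes), per_line):
        chunk = model_bytes[i:i + per_line]
        lines.append("  " + ", ".join(f"0x{b:02x}" for b in chunk) + ",")
    lines.append("};")
    lines.append(f"const unsigned int {var_name}_len = {len(model_bytes)};")
    return "\n".join(lines) + "\n"


# Hypothetical usage with the converted model saved as mnist_model.tflite:
#   with open("mnist_model.tflite", "rb") as f:
#       header = tflite_to_c_header(f.read())
#   with open("mnist_model.h", "w") as f:
#       f.write(header)
print(tflite_to_c_header(b"\x1c\x00\x00\x00TFL3"), end="")
```

The demo bytes are just a placeholder, not a real model.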

Copy the C file to the sketch folder esp32_mnist, which will be described in a later section.

Step 2: Preparation for the Sketch

To compile and run the sketch shown in this post, you will need the following:

  • TensorFlow Lite ESP32 library. Open your Arduino IDE; go to the menu item Tools | Manage Libraries, and type "tensorflow lite esp32" in the search box there.
  • DumbDisplay Arduino library. Open your Arduino IDE; go to the menu item Tools | Manage Libraries, and type "dumbdisplay" in the search box there.
  • For your Android phone, you will need to install the DumbDisplay Android app.


Step 3: The Sketch

You can download the sketch here 🔗. (You get the actual URL to the file by clicking the "Raw" button.) Download it to a folder called esp32_mnist, and save the file as esp32_mnist.ino.

If you would rather use a pre-trained model, you can download one here 🔗. Download it to the folder esp32_mnist, and save the model C file as mnist_model.h.

The target microcontroller can be ESP32 or ESP32S3. As far as I can tell, the two microcontrollers are basically 100% compatible in terms of using the TensorFlow Lite library; the notable difference here is that ESP32S3 doesn't support Bluetooth Classic.

Hence, if ESP32 is the target microcontroller, it is strongly recommended that you use Bluetooth to connect the microcontroller to the DumbDisplay app, by uncommenting the #define line that defines the Bluetooth device name

#define BLUETOOTH "ESP32BT" 

If the BLUETOOTH macro is defined, the following code that instantiates a DumbDisplay object takes effect

  #include "esp32dumbdisplay.h"
  DumbDisplay dumbdisplay(new DDBluetoothSerialIO(BLUETOOTH));

Otherwise, if the target microcontroller is ESP32S3, please add lines that define the macros WIFI_SSID and WIFI_PASSWORD, in order to use WiFi to connect to the DumbDisplay app, like

#define WIFI_SSID           "<your-wifi-ssid>"
#define WIFI_PASSWORD       "<your-wifi-password>"

When WIFI_SSID is defined, the following code that instantiates a DumbDisplay object takes effect

  #include "wifidumbdisplay.h"
  DumbDisplay dumbdisplay(new DDWiFiServerIO(WIFI_SSID, WIFI_PASSWORD));

The following #include line includes the DL model

#include "mnist_model.h"

You may want to try out yet another DL model -- https://github.com/frogermcs/MNIST-TFLite -- by changing that #include line to

#include "frogermcs_mnist_model.h"

You can download the file frogermcs_mnist_model.h here 🔗.

Then, an "error reporter" object, which is needed by the TensorFlow Lite library, is created

tflite::ErrorReporter* error_reporter = new DDTFLErrorReporter();

Note that DDTFLErrorReporter is specifically for reporting errors to DumbDisplay app (as comments).

Then, a tflite::Model pointer is obtained from the DL model data mnist_model_tflite (GetModel() interprets the array in place, without copying it).

const tflite::Model* model = ::tflite::GetModel(mnist_model_tflite);


In the setup() block, DumbDisplay is configured first.

  drawLayer = dumbdisplay.createGraphicalLayer(28, 28);
  drawLayer->border(1, "lightgray", "round", 0.5);
  drawLayer->enableFeedback("fs:drag");
...
  copyLayer = dumbdisplay.createGraphicalLayer(28, 28);
...
  clearBtn = dumbdisplay.createLcdLayer(7, 1);  
...
  centerBtn = dumbdisplay.createLcdLayer(8, 1);
...
  inferenceBtn = dumbdisplay.createLcdLayer(3, 3);
...
  resultLayer = dumbdisplay.create7SegmentRowLayer();
...
  dumbdisplay.configAutoPin(
    DDAutoPinConfig('V')
...
      .build()
    );
...
  // set "idle callback" to restart ESP32 if idle (i.e. disconnected)
  dumbdisplay.setIdleCalback([](long idleForMillis) {
    ESP.restart();  // restart ESP32 if idle (i.e. disconnected)
  });

  • drawLayer is a GraphicalDDLayer layer that serves as the input for the digit to recognize; you drag to draw on the layer's big black canvas. Notice that the layer has "feedback" enabled with the "drag" option, which basically makes the layer send your drag movements back to the sketch as "feedbacks"
  • copyLayer is also a GraphicalDDLayer layer, acting as the target to which the handwritten digit is copied after recognition
  • clearBtn, centerBtn and inferenceBtn are LcdDDLayer layers that simulate the various buttons of the UI
  • resultLayer is a SevenSegmentRowDDLayer layer for showing the inference (recognition) result digit
  • the different layers are "auto pinned" into the desired layout by calling dumbdisplay.configAutoPin()
  • an "idle callback" is set up, which is called when the connection between the microcontroller and the DumbDisplay app is lost, to restart the ESP32 / ESP32S3 when idle (disconnected)

Then, TensorFlow Lite is prepared.

First, the schema version of the model is checked, to make sure it matches the schema version supported by the TensorFlow Lite library (TFLITE_SCHEMA_VERSION).

  // check version to make sure supported
  if (model->version() != TFLITE_SCHEMA_VERSION) {
    error_reporter->Report("Model provided is schema version %d not equal to supported version %d.",
    model->version(), TFLITE_SCHEMA_VERSION);
  }

Then, the memory needed for the "tensor arena" (81K) is allocated from the heap. This part is basically the same as in my previous experiment, ESP32-CAM Person Detection Experiment With TensorFlow Lite, and the 81K is simply the number I took from the original source of that "person detection" sketch.

  // allocate memory for tensor_arena ... in similar fashion as espcam_person.ino
  tensor_arena = (uint8_t *) heap_caps_malloc(tensor_arena_size, MALLOC_CAP_INTERNAL | MALLOC_CAP_8BIT);
  if (tensor_arena == NULL) {
    error_reporter->Report("heap_caps_malloc() failed");
    return;
  }

Then, an op resolver (tflite::AllOpsResolver) that pulls in the implementations of all the supported TensorFlow Lite operations is created.

  // pull in all the operation implementations
  tflite::AllOpsResolver* resolver = new tflite::AllOpsResolver();

Then, a tflite::MicroInterpreter object is created; you do inference (recognition) through this object

  // build an interpreter to run the model with
  interpreter = new tflite::MicroInterpreter(model, *resolver, tensor_arena, tensor_arena_size, error_reporter);

Then, AllocateTensors() is called to allocate the model's tensors from the previously allocated memory tensor_arena.

  // allocate memory from the tensor_arena for the model's tensors
  TfLiteStatus allocate_status = interpreter->AllocateTensors();
  if (allocate_status != kTfLiteOk) {
    error_reporter->Report("AllocateTensors() failed");
    return;
  }

Last, a pointer to the model's input tensor is obtained and assigned to the variable input.

  // obtain a pointer to the model's input tensor
  input = interpreter->input(0);

Step 4: The Sketch -- the Loop Block

With all the initialization described in the previous section done, the loop() block holds the core logic of the sketch.

Here I will briefly describe the logic.

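If the UI "center" button is clicked, toggle autoCenter, which turns on/off the option to auto-center the handwritten digit input. The handwritten digit is captured in the 2D array Pixels (uint8_t Pixels[28][28]). Note that the toggle is also forced once, on the very first pass through loop().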

  bool toogleAutoCenter = centerBtn->getFeedback() != NULL;
  if (!started) {
    started = true;
    toogleAutoCenter = true;
  }
  if (toogleAutoCenter) {
    autoCenter = !autoCenter;
....
  }

If the UI "clear" button is clicked, reset all handwriting pixels by calling the subroutine ResetPixels()

  if (clearBtn->getFeedback()) {
    ResetPixels();
  }

ResetPixels() simply clears the drawLayer and sets all cells of the 2D array Pixels to 0s.

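If there are "feedbacks" from drawLayer, i.e. you are dragging on the big black canvas, draw lines joining the successive positions along the drag path. A feedback with x equal to -1 (which, judging from the code, marks the end of a drag) resets lastX / lastY, so that the next drag starts a new stroke instead of being joined to the previous one.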

  const DDFeedback *feedback = drawLayer->getFeedback();
  if (feedback != NULL) {
    int x = feedback->x;
    int y = feedback->y;
...
    if (x == -1) {
      lastX = -1;
      lastY = -1;  
    } else {
      bool update = true;
      if (lastX == -1) {
        DrawPixel(x, y);
      } else {
        if (lastX != x || lastY != y) {
          update = DrawLine(lastX, lastY, x, y);
        }
      }
      if (update) {
        lastX = x;
        lastY = y;
      }
    }
  }

The drawn pixels certainly also go into Pixels, which stores the grayscale shade values (0 - 255) of the drawn pixels; please refer to DrawPixel() and DrawLine() in the sketch for the details
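Since drag feedbacks arrive only at discrete positions, a fast drag can skip cells; joining consecutive positions with a line fills the gaps. The sketch's own DrawLine() is not reproduced here, so the following is only an illustration using the classic Bresenham line algorithm in Python, with a hypothetical draw_line() that writes full-intensity (255) strokes into a grid indexed [x][y] like Pixels. The real sketch may shade pixels differently.

```python
def line_points(x0, y0, x1, y1):
    """Integer grid points from (x0, y0) to (x1, y1), endpoints included (Bresenham, all octants)."""
    points = []
    dx, sx = abs(x1 - x0), (1 if x0 < x1 else -1)
    dy, sy = -abs(y1 - y0), (1 if y0 < y1 else -1)
    err = dx + dy
    while True:
        points.append((x0, y0))
        if x0 == x1 and y0 == y1:
            break
        e2 = 2 * err
        if e2 >= dy:  # step along x
            err += dy
            x0 += sx
        if e2 <= dx:  # step along y
            err += dx
            y0 += sy
    return points


def draw_line(pixels, x0, y0, x1, y1, shade=255, size=28):
    """Hypothetical stand-in for DrawLine(): mark every point of the line in pixels[x][y]."""
    for x, y in line_points(x0, y0, x1, y1):
        if 0 <= x < size and 0 <= y < size:
            pixels[x][y] = max(pixels[x][y], shade)


print(line_points(0, 0, 3, 1))  # [(0, 0), (1, 0), (2, 1), (3, 1)]
```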

If the UI button ">>>" is clicked, invoke inference (recognition) of the handwritten digit with the DL model

  bool doInference = inferenceBtn->getFeedback() != NULL;
  if (doInference) {
...
  }

The first step is to auto-center the pixels, if the auto-center option is on

    if (autoCenter) {
...
    }
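The auto-centering code is elided above. As background, the original MNIST digits were centered in their 28x28 field by their center of mass, so a hand-drawn digit sitting off to one side looks unlike the training data. One plausible way to auto-center, sketched here in Python under that assumption (the sketch's actual method may differ, e.g. it may center the bounding box instead), is to shift the strokes so that the intensity-weighted center of mass lands in the middle of the grid. center_by_mass() is a hypothetical name, and the grid is indexed [x][y] like Pixels.

```python
def center_by_mass(pixels, size=28):
    """Return a copy of pixels (indexed [x][y]) shifted so that the intensity-weighted
    center of mass sits at the middle of the grid; anything shifted off the grid is dropped."""
    total = sum(sum(col) for col in pixels)
    if total == 0:
        return [col[:] for col in pixels]  # nothing drawn, nothing to center
    cx = sum(x * v for x, col in enumerate(pixels) for v in col) / total
    cy = sum(y * v for col in pixels for y, v in enumerate(col)) / total
    mid = (size - 1) / 2  # 13.5 for a 28x28 grid
    dx, dy = round(mid - cx), round(mid - cy)
    out = [[0] * size for _ in range(size)]
    for x in range(size):
        for y in range(size):
            nx, ny = x + dx, y + dy
            if 0 <= nx < size and 0 <= ny < size:
                out[nx][ny] = pixels[x][y]
    return out


# A 2x2 blob in the top-left corner moves to the middle (cells 13-14).
grid = [[0] * 28 for _ in range(28)]
for x in (0, 1):
    for y in (0, 1):
        grid[x][y] = 255
centered = center_by_mass(grid)
print(centered[13][13], centered[14][14], centered[0][0])  # 255 255 0
```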

Then, set up the DL model's input according to the values in Pixels. Remember that Pixels is a 28x28 array of grayscale shades (0 - 255), and the shades need to be converted to float values (0 to 1), just like the inputs were normalized during training

    int idx = 0;
    for (int y = 0; y < 28; y++) {
      for (int x = 0; x < 28; x++) {
        input->data.f[idx++] = ((float) Pixels[x][y]) / 255.0;
      }
    }
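The loop order here is worth double-checking, because a mix-up would feed the model a transposed digit. During training, reshape(..., 784) flattened each MNIST image row by row, i.e. image[y][x] ends up at index y * 28 + x. The sketch's loop (y outer, x inner, reading Pixels[x][y]) produces the same order, provided Pixels is indexed [x][y]. The small Python check below mirrors that loop; the helper name sketch_input is made up for illustration.

```python
def sketch_input(pixels):
    """Mirror the sketch's loop: y outer, x inner, reading pixels[x][y] and scaling to 0-1."""
    data = []
    for y in range(28):
        for x in range(28):
            data.append(pixels[x][y] / 255.0)
    return data


# A non-symmetric test image stored like MNIST: image[y][x] (row, column).
image = [[(7 * x + 3 * y) % 256 for x in range(28)] for y in range(28)]
pixels = [list(col) for col in zip(*image)]               # same image indexed [x][y], like Pixels
training_order = [v / 255 for row in image for v in row]  # what reshape(..., 784) / 255 gives

print(sketch_input(pixels) == training_order)  # True
print(sketch_input(image) == training_order)   # False: indexing the wrong way transposes the digit
```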

Then, run the model to do inference, by calling Invoke() on the interpreter object

    TfLiteStatus invoke_status = interpreter->Invoke();

After inference, get the output tensor, through which you can read the probabilities of each of the 10 digits (0 to 9)

    TfLiteTensor* output = interpreter->output(0);

Show the inferencing result -- the probabilities (float values) -- to the connected DumbDisplay app as comments

      for (int i = 0; i < 10; i++) {
        float p = output->data.f[i];
        dumbdisplay.writeComment(String(". ") + String(i) + ": " + String(p, 3));
      }

Find the digit with the highest probability, and treat it as the recognition result

    int best = -1;
    float bestProp;
    for (int i = 0; i < 10; i++) {
      float prop = output->data.f[i];
      if (i == 0) {
        best = 0;
        bestProp = prop;
      } else if (prop > bestProp) {
        best = i;
        bestProp = prop;
      }
    }
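The loop above is a hand-written argmax. For checking results on the Python side (e.g. against the notebook's argmax(prediction)), the same logic in Python is below. Like the sketch's strict ">" comparison (and NumPy's argmax), it keeps the earliest index when two digits tie. best_digit() is a made-up name.

```python
def best_digit(probs):
    """Index of the highest probability; on ties the earliest index wins (strict '>')."""
    best, best_prob = 0, probs[0]
    for i, p in enumerate(probs[1:], start=1):
        if p > best_prob:
            best, best_prob = i, p
    return best


probs = [0.01, 0.02, 0.05, 0.80, 0.02, 0.03, 0.01, 0.03, 0.02, 0.01]
print(best_digit(probs))  # 3
```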

Show the recognition result, and copy the handwritten digit to the smaller canvas copyLayer

    resultLayer->showDigit(best);
    DrawPixelsTo(copyLayer);

Step 5: Building the Sketch

You can use Arduino IDE to build the sketch and upload it to ESP32 / ESP32S3.

For ESP32

  • for Board, select ESP32 Dev Module
  • select the correct Port
  • make sure Partition Scheme is set to Huge APP

For ESP32S3

  • for Board, select ESP32S3 Dev Module
  • select the correct Port
  • make sure Partition Scheme is set to Huge APP

Alternatively, you can use VSCode with the PlatformIO extension (plugin) to build the sketch, in a way similar to that described in my previous post, A Way to Run Arduino Sketch With VSCode PlatformIO Directly.

For ESP32, platformio.ini will look like

[env:esp32]
monitor_speed = 115200
platform = espressif32
board = esp32dev
framework = arduino
board_build.partitions = huge_app.csv
lib_deps =
  https://github.com/trevorwslee/Arduino-DumbDisplay
  tanakamasayuki/TensorFlowLite_ESP32@^1.0.0
build_flags = -DFOR_ESP32

For ESP32S3, platformio.ini will look like

[env:esp32s3]
monitor_speed = 115200
platform = espressif32
board = esp32-s3-devkitc-1
framework = arduino
lib_deps =
  https://github.com/trevorwslee/Arduino-DumbDisplay
  tanakamasayuki/TensorFlowLite_ESP32@^1.0.0
build_flags = -DFOR_ESP32S3

Step 6: Enjoy!

Hope you have fun trying out Mnist Dataset DL models, possibly with some that you fine-tuned yourself. Enjoy!


Peace be with you. Jesus loves you. May God bless you!