// Full MNIST Supervised Training Example

// train a basic classifier (a single dense softmax layer, not a CNN) on the MNIST dataset

// all training should be done with tensorflow
import * as tf from '@tensorflow/tfjs-node';
// import any tool functions you need here
import { markAllTasksDone } from "../../src/tools/server.js";
// import SupervisedTrainingSession if you'd like, or make your own (but you must implement all required methods in sessions/train.js)
import { SupervisedTrainingSession } from '../../src/sessions/supervised.js';
// required to activate the training session
import { linkTrainingSession } from "../../src/tools/server.js";
// import your dataset
import { loadMnistData } from "../../src/tools/mnist.js"

// Number of distinct labels (digits 0-9) in MNIST.
const NUM_LABEL_CLASSES = 10;
// Units in the final softmax layer: one per label class. Derived from
// NUM_LABEL_CLASSES rather than repeating the literal so the two cannot drift apart.
const OUTPUT_UNITS = NUM_LABEL_CLASSES;
// 28x28 MNIST images, flattened to one feature vector per sample (the model's inputShape).
const IMAGE_SIZE = 28 * 28;
// Number of samples loaded from the training split.
const TOTAL_SAMPLES = 25000;
// Schedule / batching settings, forwarded unchanged to SupervisedTrainingSession.
const TOTAL_ROUNDS = 2;
const EPOCHS_PER_ROUND = 1;
const BATCH_SIZE = 100;
const MIN_CLIENTS_TO_START = 1;
// Number of samples loaded from the test split.
const TEST_SAMPLES = 1000;

// Load both splits up front. NOTE(review): the boolean second argument appears to
// select the training (true) vs. test (false) split — confirm in src/tools/mnist.js.
const { images: fullImages, labels: fullLabels } = loadMnistData(TOTAL_SAMPLES, true);
const { images: testImages, labels: testLabels } = loadMnistData(TEST_SAMPLES, false);

// The shared global model: one fully connected softmax layer that maps each flattened
// IMAGE_SIZE input vector to OUTPUT_UNITS class probabilities.
const globalModel = tf.sequential({
    layers: [
        tf.layers.dense({ inputShape: [IMAGE_SIZE], units: OUTPUT_UNITS, activation: 'softmax' }),
    ],
});

// Loss and metrics must be built-in tf.js names: custom loss/metric functions are NOT
// supported because clients can't run untrusted code (contact us to have a metric added
// to the client manually). An optimizer instance is fine because it is an object.
const compileConfig = {
    optimizer: tf.train.adam(),
    loss: 'categoricalCrossentropy',
    metrics: ['accuracy'],
};
globalModel.compile(compileConfig);

// Hand the settings, both data splits and the model to the session. The constructor
// takes everything positionally, so the argument order below must not change.
const trainingSession = new SupervisedTrainingSession(
    OUTPUT_UNITS, TOTAL_SAMPLES, IMAGE_SIZE, NUM_LABEL_CLASSES,
    TOTAL_ROUNDS, BATCH_SIZE, EPOCHS_PER_ROUND, MIN_CLIENTS_TO_START,
    fullImages, fullLabels,
    globalModel,
    testImages, testLabels,
);

// Mandatory: this registers the session's routes; without it the session is never active.
linkTrainingSession(trainingSession);

// Optional hook for the metrics (from TensorFlow) reported for a single user during
// training: forward them to your own server or aggregate them here.
// Left as a no-op to avoid console spam; log `metrics` inside the hook to inspect them.
// NOTE(review): an earlier note said each report covers "5 seconds of computing";
// confirm the reporting cadence in SupervisedTrainingSession.
const handleUserMetrics = (metrics) => {
    // intentionally empty
};
trainingSession.setMetricsFunction(handleUserMetrics);

// Once training completes, mark all tasks done and persist the trained global model
// under "mnist_model". NOTE(review): markAllTasksDone()'s return value is not awaited,
// i.e. it is assumed to be synchronous — confirm in src/tools/server.js.
// The trailing .catch() reports a failure in completedTraining(), markAllTasksDone() or
// save() instead of leaving an unhandled promise rejection, which by default terminates
// the Node.js process (Node 15+) and would take the server down with it.
trainingSession
    .completedTraining()
    .then(async () => {
        markAllTasksDone();
        await trainingSession.save("mnist_model");
    })
    .catch((error) => {
        console.error("MNIST training session failed to complete or save the model:", error);
    });

// Descriptive metadata for this training job, exposed as the module's default export.
// Fallback payloads are NOT supported in this file — do not add one here.
const jobMetadata = {
    type: "train",
    action: "test training with mnist",
    officialName: "train-obitws-prod1",
    organization: "Obit Web Services",
};

export default jobMetadata;