// Full MNIST Unsupervised Training Example

// train an autoencoder on the MNIST dataset using unsupervised federated learning

import * as tf from '@tensorflow/tfjs-node';
// call any tools functions you like here
import { markAllTasksDone } from "../../modules/tools.js";
// import UnsupervisedTrainingSession from our unsupervised training plugin
import { UnsupervisedTrainingSession } from '../../src/sessions/unsupervised.js';
// required to activate the training session
import { linkTrainingSession } from "../../modules/train.js";
// import your dataset
import { loadMnistData } from "../../src/tools/mnist.js"
// In this unsupervised example, we train a simple autoencoder where the model learns to reconstruct its input.
// The autoencoder is implemented as a single dense layer, which is sufficient to showcase unsupervised federated learning.

// Each image is fed to the model as a flat vector of IMAGE_SIZE values (784 = 28 x 28 pixels).
const IMAGE_SIZE = 784;
// Final output units: for an autoencoder, the reconstruction dimension equals the input
// dimension, so derive it from IMAGE_SIZE instead of repeating the literal.
const RECONSTRUCTION_UNITS = IMAGE_SIZE;
// Number of training images loaded from MNIST.
const TOTAL_SAMPLES = 25000;
// Federated training schedule, passed to UnsupervisedTrainingSession below.
const TOTAL_ROUNDS = 2;
const EPOCHS_PER_ROUND = 1;
const BATCH_SIZE = 100;
// Presumably the number of connected clients required before training starts — confirm in the session class.
const MIN_CLIENTS_TO_START = 1;

// Number of held-out images loaded for evaluation.
const TEST_SAMPLES = 1000;

// Load the MNIST dataset. Only the images are kept, because unsupervised training needs no labels.
// As the original comments indicate, the boolean argument selects the training split (true)
// or the test split (false).
// NOTE(review): loadMnistData is consumed synchronously here. If it ever returns a Promise,
// these destructurings would yield undefined — confirm it is synchronous.
const { images: fullImages } = loadMnistData(TOTAL_SAMPLES, true);
const { images: testImages } = loadMnistData(TEST_SAMPLES, false);

// A deliberately minimal "autoencoder": one sigmoid dense layer that maps each
// IMAGE_SIZE-wide input back to RECONSTRUCTION_UNITS outputs. The output width equals
// the input width, so there is no bottleneck. It exists only to demonstrate the
// federated unsupervised flow.
const reconstructionLayer = tf.layers.dense({
    inputShape: [IMAGE_SIZE],
    units: RECONSTRUCTION_UNITS,
    activation: 'sigmoid',
});
const autoencoder = tf.sequential({ layers: [reconstructionLayer] });

// Reconstruction objective: mean squared error, minimized with an Adam optimizer
// constructed with no arguments.
autoencoder.compile({
    loss: 'meanSquaredError',
    optimizer: tf.train.adam(),
});

// Positional constructor arguments for UnsupervisedTrainingSession. Order matters:
// it must match the constructor's parameter list in ../../src/sessions/unsupervised.js.
// The session receives input images only (no labels).
const sessionArgs = [
    RECONSTRUCTION_UNITS,
    TOTAL_SAMPLES,
    IMAGE_SIZE,
    TOTAL_ROUNDS,
    BATCH_SIZE,
    EPOCHS_PER_ROUND,
    MIN_CLIENTS_TO_START,
    fullImages,   // training inputs
    autoencoder,  // compiled model
    testImages,   // held-out inputs (presumably used for evaluation — confirm)
];
const trainingSession = new UnsupervisedTrainingSession(...sessionArgs);

// Register the session with the training routes; the import comment above notes
// this call is required to activate the session.
linkTrainingSession(trainingSession);

// Optional metrics hook: receives the metrics the session reports during training.
// Replace the body to forward them to your own server or aggregate them.
function logTrainingMetrics(metrics) {
    console.log(metrics);
}

trainingSession.setMetricsFunction(logTrainingMetrics);

// Once training is completed, mark all tasks as done and save the trained model
// as "mnist_unsupervised_model".
// The chain ends in .catch() so that a failure in completedTraining() or save() is
// reported. Otherwise it would surface as an unhandled promise rejection, which
// terminates the Node process by default.
trainingSession
    .completedTraining()
    .then(async () => {
        markAllTasksDone();
        await trainingSession.save("mnist_unsupervised_model");
    })
    .catch((err) => {
        console.error("MNIST unsupervised training did not complete or the model could not be saved:", err);
    });

// Descriptive metadata for this training task, exposed as the module's default export.
// Fallback payloads are NOT supported by this task!!
const taskMetadata = {
    type: "train",
    action: "unsupervised training with mnist autoencoder",
    officialName: "train-obitws-prod1-unsupervised",
    organization: "Obit Web Services",
};

export default taskMetadata;