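// Full MNIST Differential Privacy Example
// MNIST example with differential privacy.
// NOTE(review): no differential-privacy step (e.g. gradient clipping or added noise) is
// visible in this file; confirm that the training session applies it.
// Train a basic classifier on the MNIST dataset (note: the model defined below is a single
// dense softmax layer, not a CNN).
// All training should be done with TensorFlow.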
import * as tf from '@tensorflow/tfjs-node';
// call any tools functions you like here
import { markAllTasksDone } from "../../src/tools/server.js";
// import SupervisedTrainingSession if you'd like, or make your own (but you must implement all required methods in sessions/train.js)
import { SupervisedTrainingSession } from '../../src/sessions/supervised.js';
// required to activate the training session
import { linkTrainingSession } from "../../src/tools/server.js";
// import your dataset
import { loadMnistData } from "../../src/tools/mnist.js"
// Number of distinct label classes.
const NUM_LABEL_CLASSES = 10;
// final-layer units: one softmax output per label class
const OUTPUT_UNITS = NUM_LABEL_CLASSES;
// 28x28 MNIST-like images, fed to the model as a flat vector
const IMAGE_SIZE = 28 * 28;

// Dataset sizes.
const TOTAL_SAMPLES = 25000;
const TEST_SAMPLES = 1000;

// Training schedule.
const TOTAL_ROUNDS = 2;
const EPOCHS_PER_ROUND = 1;
const BATCH_SIZE = 100;
const MIN_CLIENTS_TO_START = 1;

// NOTE(review): the boolean flag presumably selects the training (true) vs. test (false)
// split inside loadMnistData — confirm against that helper.
const { images: fullImages, labels: fullLabels } = loadMnistData(TOTAL_SAMPLES, true);
// load test info
const { images: testImages, labels: testLabels } = loadMnistData(TEST_SAMPLES, false);
// The global model: one dense softmax layer that maps a 1-D input of IMAGE_SIZE
// values to OUTPUT_UNITS class probabilities.
const globalModel = tf.sequential({
  layers: [
    tf.layers.dense({
      inputShape: [IMAGE_SIZE],
      units: OUTPUT_UNITS,
      activation: 'softmax',
    }),
  ],
});

// loss and metric fns are NOT supported here, because we can't run untrusted code.
// instead please use tf metrics, or contact us to add one to the client manually
// optimizer functions are supported because they return an object
globalModel.compile({
  optimizer: tf.train.adam(),
  loss: 'categoricalCrossentropy',
  metrics: ['accuracy'],
});
// Bundle the hyper-parameters, data, and model into a supervised training session.
// The constructor is positional, so the argument order below is significant.
const trainingSession = new SupervisedTrainingSession(
  // model / data dimensions
  OUTPUT_UNITS,
  TOTAL_SAMPLES,
  IMAGE_SIZE,
  NUM_LABEL_CLASSES,
  // schedule
  TOTAL_ROUNDS,
  BATCH_SIZE,
  EPOCHS_PER_ROUND,
  MIN_CLIENTS_TO_START,
  // training data
  fullImages,
  fullLabels,
  // model to train
  globalModel,
  // held-out evaluation data
  testImages,
  testLabels,
);

// need to run this to activate the training session routes
linkTrainingSession(trainingSession);
// Register a callback that receives the metrics reported by the training session.
// When and how often it fires is decided by SupervisedTrainingSession, not this file.
// NOTE(review): this was previously described as "metrics for a single user for 5 seconds
// of computing" — confirm that cadence in SupervisedTrainingSession before relying on it.
trainingSession.setMetricsFunction((metrics) => {
  // send these metrics (from tensorflow) to your server or aggregate them
  console.log(metrics);
});

/**
 * Wait for the session to report completion, mark every task done, then persist
 * the global model under the name "mnist_model".
 * @returns {Promise<void>} rejects if completion or saving fails
 */
async function finalizeTraining() {
  await trainingSession.completedTraining();
  markAllTasksDone();
  await trainingSession.save("mnist_model");
}

// Attach a rejection handler so that a failed completion or save is reported
// explicitly (and the process is flagged as failed) instead of surfacing as an
// unhandled promise rejection from a floating chain.
finalizeTraining().catch((err) => {
  console.error("MNIST training did not finish cleanly:", err);
  process.exitCode = 1;
});
// Task metadata exported as this module's default.
// fallback payloads are NOT supported here!! — this object is descriptive only.
const taskMetadata = {
  type: "train",
  action: "test training with mnist",
  officialName: "train-obitws-prod1",
  organization: "Obit Web Services",
};

export default taskMetadata;
// Last updated