// Full Pyodide Example

// Train a PPO agent on the Pendulum-v1 environment using federated reinforcement learning
// Note: This uses Pyodide to run Gymnasium's Pendulum environment with discretized actions

import * as tf from '@tensorflow/tfjs-node';
import { markAllTasksDone } from "../../src/tools/server.js";
// Import the federated reinforcement learning session.
import { ReinforcementLearningSession } from '../../src/sessions/reinforce.js';
// Pyodide Pendulum environment.
import { Environment } from '../../src/simulations/pyodide-pendulum.js';
// PPO algorithm (compatible with RLAlgorithm).
import { PPOAlgorithm as RLAlgorithm } from '../../src/algorithms/ppo.js';
// Required to activate the training session routes.
import { linkTrainingSession } from "../../src/tools/server.js";

// Static description of the Pendulum-v1 task, used below to size the model.
// Observations have 3 components: [cos(theta), sin(theta), angular_velocity].
// Pendulum-v1 natively uses continuous actions; this example discretizes the
// action space so that it works with the discrete PPO implementation.
const dummyEnvironment = {
    info: {
        name: "Pendulum-v1",
        stateSize: 3,
        // Discretized torque actions: [-2, -1, 0, 1, 2]
        actionCount: 5,
    },
};

// Hyperparameters for the federated PPO run.

// Number of discrete actions (5 for Pendulum).
const FINAL_OUTPUT_UNITS = dummyEnvironment.info.actionCount;
// Total episodes expected (used for partitioning).
const TOTAL_SAMPLES = 500;
// Total federated training rounds.
const TOTAL_ROUNDS = 3;
// Batch size for training updates (client-side computation).
const BATCH_SIZE = 32;
// Epochs per round.
const EPOCHS_PER_ROUND = 1;
// Minimum clients required to start the training session.
const MIN_CLIENTS_TO_START = 1;

/**
 * Builds the actor–critic network for PPO on Pendulum: two stacked 32-unit
 * ReLU layers shared by two linear heads.
 *
 * @param {number} stateSize - Length of the observation vector fed to the network.
 * @param {number} actionCount - Number of discrete actions (width of the policy head).
 * @returns {tf.LayersModel} Model whose outputs are, in order, the "policy"
 *     head (action logits) and the "value" head (state value).
 */
function buildActorCriticModel(stateSize, actionCount) {
    const observation = tf.input({ shape: [stateSize] });

    // Shared base: both heads read from the second hidden layer.
    const hiddenA = tf.layers.dense({ units: 32, activation: 'relu' }).apply(observation);
    const hiddenB = tf.layers.dense({ units: 32, activation: 'relu' }).apply(hiddenA);

    const policyHead = tf.layers
        .dense({ units: actionCount, activation: 'linear', name: 'policy' })
        .apply(hiddenB);
    const valueHead = tf.layers
        .dense({ units: 1, activation: 'linear', name: 'value' })
        .apply(hiddenB);

    return tf.model({ inputs: observation, outputs: [policyHead, valueHead] });
}

const globalModel = buildActorCriticModel(dummyEnvironment.info.stateSize, FINAL_OUTPUT_UNITS);

// PPO runs its own training loops, but compiling the model may still be useful.
// One loss is listed per output, in output order (policy, value).
globalModel.compile({
    optimizer: tf.train.adam(),
    loss: ['categoricalCrossentropy', 'meanSquaredError'],
});

// Simulation-specific options handed to the environment when it is initialized.
const initializationOptions = {
    pyodide: {
        enabled: true,
        // Arguments the environment is initialized with.
        initialization: '"Pendulum-v1", render_mode="rgb_array", g=9.81',
        // Parameters used when the environment is reset.
        reset: 'seed=42, options={"low": -0.7, "high": 0.5}',
    },
};

// Options consumed by the RL algorithm itself.
const algorithmOptions = {
    numActions: dummyEnvironment.info.actionCount,
    // Total on-policy samples to collect; this is not the number of federated rounds.
    numEpisodes: 5000,
    logEpisodeRewards: false,
    logInterval: 10,
};

// Auto-vectorization settings (disabled for this example).
const autovectorizeInfo = {
    enabled: false,
    // Used only for testing, to determine the user's hardware parallelization limits.
    // Nothing is passed into step(input) during that probe, i.e. input will be null,
    // so you may want the Environment's step function to default its input.
    dummyData: 1,
    initialCount: 1,
    steps: 1000,
    slowdownFactor: 1.5,
    maximumEnvironments: 16,
};

// Initialize the federated reinforcement learning session for PPO.
// The session offloads heavy computations (simulation, training) to the clients,
// while the server serves model segments and aggregates updates.
// NOTE(review): the arguments are positional; their order presumably mirrors the
// ReinforcementLearningSession constructor in ../../src/sessions/reinforce.js,
// which is not visible from this file — confirm before reordering.
const trainingSession = new ReinforcementLearningSession(
    FINAL_OUTPUT_UNITS,     // number of discrete actions
    TOTAL_SAMPLES,          // total episodes expected (used for partitioning)
    TOTAL_ROUNDS,           // total federated training rounds
    BATCH_SIZE,             // batch size for client-side training updates
    EPOCHS_PER_ROUND,       // epochs per round
    MIN_CLIENTS_TO_START,   // minimum clients required to start the session
    globalModel,            // actor–critic model with "policy" and "value" outputs
    Environment,            // Pyodide Pendulum environment (imported above)
    initializationOptions,  // simulation-specific initialization/reset options
    RLAlgorithm,            // PPOAlgorithm, imported under the RLAlgorithm alias
    algorithmOptions,       // algorithm-specific options (numActions, numEpisodes, ...)
    autovectorizeInfo       // auto-vectorization settings (disabled here)
);

// Expose the session through the server routes so clients can register and interact.
linkTrainingSession(trainingSession);

/**
 * Prints the client metrics reported during training to the console.
 *
 * @param {*} metrics - Metrics payload supplied by the training session.
 */
function logClientMetrics(metrics) {
    console.log(metrics);
}

// Install the metrics logger on the session.
trainingSession.setMetricsFunction(logClientMetrics);

// Once the training session completes, mark all tasks as done and then save
// the trained model under the name "ppo_model".
// The chain ends in a rejection handler so that a failure in training
// completion, markAllTasksDone(), or save() is logged instead of surfacing as
// an unhandled promise rejection.
trainingSession
    .completedTraining()
    .then(async () => {
        markAllTasksDone();
        await trainingSession.save("ppo_model");
    })
    .catch((error) => {
        console.error("PPO training session failed to complete or save the model:", error);
    });

// Export metadata for this plugin.
// The action text names Pendulum-v1, the environment this example actually
// trains on (it previously said "CartPole").
export default {
    type: "rl_train",
    action: "PPO reinforcement training on Pendulum-v1",
    officialName: "train-obitws-prod1-ppo",
    organization: "Obit Web Services",
};