// Train a PPO agent on the Pendulum-v1 environment using federated reinforcement learning
// Note: This uses Pyodide to run Gymnasium's Pendulum environment with discretized actions
import * as tf from '@tensorflow/tfjs-node';
import { markAllTasksDone } from "../../src/tools/server.js";
// Import the federated reinforcement learning session.
import { ReinforcementLearningSession } from '../../src/sessions/reinforce.js';
// Pyodide Pendulum environment.
import { Environment } from '../../src/simulations/pyodide-pendulum.js';
// PPO algorithm (compatible with RLAlgorithm).
import { PPOAlgorithm as RLAlgorithm } from '../../src/algorithms/ppo.js';
// Required to activate the training session routes.
import { linkTrainingSession } from "../../src/tools/server.js";
// Define some information for Pendulum.
// Pendulum-v1 has a state size of 3: [cos(theta), sin(theta), angular_velocity]
// Note: Pendulum-v1 actually uses continuous actions, but this example
// discretizes the action space for compatibility with the discrete PPO implementation
const dummyEnvironment = {
  info: {
    name: "Pendulum-v1",
    stateSize: 3,
    actionCount: 5 // Discretized torque actions: [-2, -1, 0, 1, 2]
  }
};
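// Illustrative sketch only (not used by the session): one plausible mapping from the
// 5 discrete action indices to the continuous torques listed above. The helper name and
// the exact mapping are assumptions; the real conversion lives in the client-side
// Pendulum environment wrapper.
const DISCRETE_TORQUES = [-2, -1, 0, 1, 2];
const discreteToTorque = (actionIndex) => [DISCRETE_TORQUES[actionIndex]]; // Pendulum expects a 1-element torque array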
// PPO training parameters.
const FINAL_OUTPUT_UNITS = dummyEnvironment.info.actionCount; // Number of discrete actions (5 for Pendulum)
const TOTAL_SAMPLES = 500; // Total episodes expected (used for partitioning)
const TOTAL_ROUNDS = 3; // Total federated training rounds
const BATCH_SIZE = 32; // Batch size for training updates (client-side computation)
const EPOCHS_PER_ROUND = 1; // Epochs per round
const MIN_CLIENTS_TO_START = 1; // Minimum clients required to start the training session
// Build an actor–critic model for PPO on Pendulum.
// The model has a shared base and two outputs:
// - "policy": outputting logits for action probabilities.
// - "value": outputting the state value.
const input = tf.input({ shape: [dummyEnvironment.info.stateSize] });
const dense1 = tf.layers.dense({ units: 32, activation: 'relu' }).apply(input);
const dense2 = tf.layers.dense({ units: 32, activation: 'relu' }).apply(dense1);
const policyOutput = tf.layers.dense({ units: FINAL_OUTPUT_UNITS, activation: 'linear', name: 'policy' }).apply(dense2);
const valueOutput = tf.layers.dense({ units: 1, activation: 'linear', name: 'value' }).apply(dense2);
const globalModel = tf.model({ inputs: input, outputs: [policyOutput, valueOutput] });
// Although PPO uses custom training loops, compiling the model may be useful.
globalModel.compile({ optimizer: tf.train.adam(), loss: ['categoricalCrossentropy', 'meanSquaredError'] });
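// Minimal sketch of how a client might turn the "policy" logits into a sampled discrete
// action (an assumption for illustration; the actual sampling belongs to the client-side
// PPO implementation, and this helper is never called here).
const sampleActionSketch = (state) => tf.tidy(() => {
  const [logits] = globalModel.predict(tf.tensor2d([state])); // logits shape: [1, actionCount]
  return tf.multinomial(logits, 1).dataSync()[0]; // action index in [0, actionCount - 1]
});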
// Options for the environment initialization (specific to the simulation).
const initializationOptions = {
  pyodide: {
    enabled: true,
    // Arguments the environment is constructed with.
    initialization: `"Pendulum-v1", render_mode="rgb_array", g=9.81`,
    // Arguments passed to the environment's reset call.
    reset: `seed=42, options={"low": -0.7, "high": 0.5}`
  }
};
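// For reference (an assumption about how the Pyodide wrapper interpolates these strings),
// the options above correspond roughly to the Python calls:
//   env = gymnasium.make("Pendulum-v1", render_mode="rgb_array", g=9.81)
//   env.reset(seed=42, options={"low": -0.7, "high": 0.5})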
// Algorithm-specific options.
// Note: "numEpisodes" here refers to the total on-policy samples collected rather than federated rounds.
const algorithmOptions = {
  numActions: dummyEnvironment.info.actionCount,
  numEpisodes: 5000,
  logEpisodeRewards: false,
  logInterval: 10
};
const autovectorizeInfo = {
  enabled: false,
  // Used purely for testing, to probe the client's hardware parallelization limits.
  // During the probe nothing is passed into step(input), i.e. input will be null,
  // so the Environment's step function may want to define a default value for input.
  dummyData: 1,
  initialCount: 1,
  steps: 1000,
  slowdownFactor: 1.5,
  maximumEnvironments: 16
};
// Initialize the federated reinforcement learning session for PPO.
// The session offloads heavy computations (simulation, training) to the clients,
// while the server serves model segments and aggregates updates.
const trainingSession = new ReinforcementLearningSession(
  FINAL_OUTPUT_UNITS,
  TOTAL_SAMPLES,
  TOTAL_ROUNDS,
  BATCH_SIZE,
  EPOCHS_PER_ROUND,
  MIN_CLIENTS_TO_START,
  globalModel,
  Environment,
  initializationOptions,
  RLAlgorithm,
  algorithmOptions,
  autovectorizeInfo
);
// Activate the training session routes so that clients can register and interact.
linkTrainingSession(trainingSession);
// Set a metrics function to log client metrics during training.
trainingSession.setMetricsFunction((metrics) => {
  console.log(metrics);
});
// Once the training session completes, mark all tasks as done and save the trained model.
trainingSession.completedTraining().then(async () => {
  markAllTasksDone();
  await trainingSession.save("ppo_model");
});
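// If save() writes a standard tf.js LayersModel artifact (an assumption about the
// session's save helper), the trained model could later be reloaded with, e.g.:
//   const model = await tf.loadLayersModel('file://ppo_model/model.json');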
// Export metadata for this plugin.
export default {
  type: "rl_train",
  action: "PPO reinforcement training on Pendulum-v1",
  officialName: "train-obitws-prod1-ppo",
  organization: "Obit Web Services",
};