import * as tf from '@tensorflow/tfjs';
import RLAlgorithm from './rl.js';
class DQNAlgorithm extends RLAlgorithm {
/**
* Constructs a DQN agent.
* @param {tf.Sequential} model - The Q-network model. It must be compiled (e.g. with an MSE loss) so fit() can be called on it.
* @param {Object[]} env - Array of environment instances (same environment class); each must implement reset() and step(action).
* @param {number} pzn - Maximum number of environments stepped per training episode (degree of parallelism).
* @param {Object} config - Algorithm-specific hyperparameters.
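* @param {number} [config.epsilon=1.0] - Initial exploration rate.
* @param {number} [config.epsilonDecay=0.995] - Multiplicative epsilon decay applied after each update.
* @param {number} [config.minEpsilon=0.05] - Lower bound on epsilon.
* @param {number} [config.bufferSize=20000] - Replay buffer capacity.
* @param {number} [config.numActions=2] - Size of the discrete action space.
* @param {number} [config.numEpisodes=10] - Number of training episodes run by train().
* @param {number} [config.batchSize=32] - Mini-batch size sampled from the replay buffer.
* @param {number} [config.discountFactor=0.99] - Discount factor (gamma).
* @param {number} [config.targetUpdateFrequency=100] - Number of updates between target-network syncs.
* @param {number} [config.maxSteps] - Maximum steps per episode (defaults to 1000 in training, 200 in test()).
* @param {number} [config.logInterval] - If set, train() logs the average reward every logInterval episodes.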
*/
constructor(model, env, pzn, config) {
super(model, env, pzn, config);
// Use nullish coalescing so explicit 0 values in config (e.g. epsilon: 0) are respected.
this.epsilon = this.config.epsilon ?? 1.0;
this.epsilonDecay = this.config.epsilonDecay ?? 0.995;
this.minEpsilon = this.config.minEpsilon ?? 0.05;
this.buffer = [];
this.bufferSize = this.config.bufferSize ?? 20000;
this.numActions = this.config.numActions ?? 2;
this.numEpisodes = this.config.numEpisodes ?? 10;
this.batchSize = this.config.batchSize ?? 32;
this.discountFactor = this.config.discountFactor ?? 0.99;
this.targetUpdateFrequency = this.config.targetUpdateFrequency ?? 100;
this.updateCounter = 0;
}
/**
* Initializes the agent. This hook exists so async setup can run outside the constructor.
* It must be present, even if empty.
*/
async initialize() {
await this.cloneModelToTarget();
}
/**
* Clones the model to the target model.
*/
async cloneModelToTarget() {
// Serialize the model to memory via tf.io.withSaveHandler, then load the artifacts back as an independent copy.
let artifacts;
await this.model.save(tf.io.withSaveHandler(async (data) => {
artifacts = data;
return { modelArtifactsInfo: { dateSaved: new Date(), modelTopologyType: 'JSON' } };
}));
this.targetModel = await tf.loadLayersModel(tf.io.fromMemory(artifacts));
}
/**
* Selects an action using an epsilon-greedy policy.
* @param {Array} state - The current state.
* @returns {Number} The selected action.
*/
selectAction(state) {
// Explore with probability epsilon.
if (Math.random() < this.epsilon) {
return Math.floor(Math.random() * this.numActions);
}
// Otherwise act greedily; tf.tidy disposes both the input tensor and the prediction tensor.
const qValues = tf.tidy(() => {
const stateTensor = tf.tensor2d([state], [1, state.length]);
return this.model.predict(stateTensor).arraySync()[0];
});
return qValues.indexOf(Math.max(...qValues));
}
/**
* Stores an experience tuple in the replay buffer.
* @param {Object} exp - The experience { state, action, reward, nextState, done }.
*/
storeExperience(exp) {
this.buffer.push(exp);
if (this.buffer.length > this.bufferSize) {
this.buffer.shift();
}
}
/**
* Updates the Q-network by sampling a mini-batch from the replay buffer.
* Implements the DQN update rule.
*/
async update() {
if (this.buffer.length < this.batchSize) return; // Not enough samples to update.
// Sample a random mini-batch.
const batch = [];
for (let i = 0; i < this.batchSize; i++) {
const index = Math.floor(Math.random() * this.buffer.length);
batch.push(this.buffer[index]);
}
const states = batch.map(exp => exp.state);
const nextStates = batch.map(exp => exp.nextState);
const actions = batch.map(exp => exp.action);
const rewards = batch.map(exp => exp.reward);
const dones = batch.map(exp => exp.done);
if (this.model && !this.targetModel) {
await this.cloneModelToTarget();
console.warn("Target model was not set; cloned the current model to the target model.");
}
// Compute Q-values for next states using the target network.
const nextStatesTensor = tf.tensor2d(nextStates, [this.batchSize, states[0].length]);
const targetQTensor = this.targetModel.predict(nextStatesTensor);
const targetQArray = targetQTensor.arraySync();
nextStatesTensor.dispose();
targetQTensor.dispose();
// Compute Q-values for current states using the current model.
const statesTensor = tf.tensor2d(states, [this.batchSize, states[0].length]);
const qTensor = this.model.predict(statesTensor);
const qArray = qTensor.arraySync();
statesTensor.dispose();
qTensor.dispose();
// For each sample in the mini-batch, compute the target value.
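// Bellman target: y_i = r_i + discountFactor * max_a Q_target(s'_i, a), or y_i = r_i for terminal transitions.
// Only the entry for the taken action is overwritten, so all other actions contribute zero error to the loss.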
for (let i = 0; i < this.batchSize; i++) {
let target = rewards[i];
if (!dones[i]) {
target += this.discountFactor * Math.max(...targetQArray[i]);
}
qArray[i][actions[i]] = target;
}
// Prepare training tensors.
const xTrain = tf.tensor2d(states, [this.batchSize, states[0].length]);
const yTrain = tf.tensor2d(qArray, [this.batchSize, this.numActions]);
// Perform one gradient descent step.
await this.model.fit(xTrain, yTrain, { epochs: 1, verbose: 0 });
xTrain.dispose();
yTrain.dispose();
// Decay epsilon.
this.epsilon = Math.max(this.minEpsilon, this.epsilon * this.epsilonDecay);
// Periodically update the target network.
this.updateCounter++;
if (this.updateCounter % this.targetUpdateFrequency === 0) {
const weights = this.model.getWeights();
this.targetModel.setWeights(weights);
console.log("Target network updated.");
}
}
/**
* Runs a single rollout (one full episode) in one environment, storing each transition in the replay buffer.
* @param {Number} env_index - Index of the environment to use.
* @returns {Promise<Number>} The total reward collected during the episode.
*/
async step(env_index) {
let state = this.env[env_index].reset();
let done = false;
let totalReward = 0;
const maxSteps = this.config.maxSteps || 1000;
for (let i = 0; i < maxSteps && !done; i++) {
const action = this.selectAction(state);
const result = this.env[env_index].step(action);
totalReward += result.reward;
this.storeExperience({
state: state,
action: action,
reward: result.reward,
nextState: result.state,
done: result.done
});
state = result.state;
done = result.done;
}
return totalReward;
}
/**
* Runs one episode per environment via step() (gathered with Promise.all), then performs a single update().
* @returns {Promise<Number[]>} The total reward from each environment's episode.
*/
async trainEpisode() {
const episodePromises = [];
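// Launch one rollout per environment. With synchronous env implementations these run back-to-back on the
// single JS thread; the Promise.all structure keeps the interface compatible with asynchronous environments.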
for (let i = 0; i < this.pzn; i++) {
episodePromises.push(this.step(i));
}
const rewards = await Promise.all(episodePromises);
await this.update();
return rewards;
}
/**
* Trains the agent over a number of episodes.
* @returns {Promise<void>}
*/
async train() {
for (let ep = 0; ep < this.numEpisodes; ep++) {
const rewards = await this.trainEpisode();
if (this.config && this.config.logInterval && ((ep + 1) % this.config.logInterval === 0)) {
const avg = rewards.reduce((a, b) => a + b, 0) / rewards.length;
console.log(`Episode ${ep + 1}: avgReward=${avg.toFixed(2)}`);
}
}
}
/**
* Tests the current policy by running evaluation episodes without exploration.
* Temporarily sets epsilon to 0 so the agent acts greedily.
* Computes several metrics: average reward, max reward, average episode length,
* max episode length, standard deviation of episode lengths, and success rate.
* Evaluation is not parallelized; every episode runs in env[0].
* @param {Number} numEpisodes - Number of test episodes (default is 10).
* @param {Number} successThreshold - Reward threshold to count an episode as a success (e.g., 195).
* @returns {Promise<Object>} - An object containing the test metrics.
*/
async test(numEpisodes = 10, successThreshold = 195) {
const originalEpsilon = this.epsilon;
this.epsilon = 0; // disable exploration
const rewardsArray = [];
const lengthsArray = [];
let successCount = 0;
for (let ep = 0; ep < numEpisodes; ep++) {
let state = this.env[0].reset();
let done = false;
let episodeReward = 0;
let steps = 0;
const maxSteps = this.config.maxSteps || 200;
while (steps < maxSteps && !done) {
const action = this.selectAction(state);
const result = this.env[0].step(action);
episodeReward += result.reward;
state = result.state;
done = result.done;
steps++;
}
console.log(`Test Episode ${ep + 1}: Reward = ${episodeReward}, Steps = ${steps}`);
rewardsArray.push(episodeReward);
lengthsArray.push(steps);
if (episodeReward >= successThreshold) {
successCount++;
}
}
const avgReward = rewardsArray.reduce((a, b) => a + b, 0) / rewardsArray.length;
const maxReward = Math.max(...rewardsArray);
const avgLength = lengthsArray.reduce((a, b) => a + b, 0) / lengthsArray.length;
const maxLength = Math.max(...lengthsArray);
const variance = lengthsArray.reduce((sum, x) => sum + Math.pow(x - avgLength, 2), 0) / lengthsArray.length;
const stdDev = Math.sqrt(variance);
const successRate = (successCount / numEpisodes) * 100;
console.log(`Test Results over ${numEpisodes} episodes:`);
console.log(`Average Reward: ${avgReward}`);
console.log(`Max Reward: ${maxReward}`);
console.log(`Average Episode Length: ${avgLength}`);
console.log(`Max Episode Length: ${maxLength}`);
console.log(`Standard Deviation of Episode Length: ${stdDev}`);
console.log(`Success Rate (>= ${successThreshold} reward): ${successRate}%`);
this.epsilon = originalEpsilon; // restore original epsilon
return {
averageReward: avgReward,
maxReward: maxReward,
averageLength: avgLength,
maxLength: maxLength,
stdDevLength: stdDev,
successRate: successRate
};
}
}
export { DQNAlgorithm };
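/*
* Usage sketch (illustrative only). The import path, CartPoleEnv class, layer sizes, and hyperparameter
* values below are placeholders, not part of this module; the environment is assumed to return its initial
* state from reset() and an object { state, reward, done } from step(action).
*
*   import * as tf from '@tensorflow/tfjs';
*   import { DQNAlgorithm } from './dqn.js';
*
*   const model = tf.sequential();
*   model.add(tf.layers.dense({ inputShape: [4], units: 24, activation: 'relu' }));
*   model.add(tf.layers.dense({ units: 2 })); // one linear output per action
*   model.compile({ optimizer: tf.train.adam(1e-3), loss: 'meanSquaredError' });
*
*   const envs = [new CartPoleEnv(), new CartPoleEnv()];
*   const agent = new DQNAlgorithm(model, envs, envs.length, {
*     numActions: 2,
*     numEpisodes: 200,
*     batchSize: 64,
*     maxSteps: 500,
*     logInterval: 10
*   });
*   await agent.initialize();
*   await agent.train();
*   const metrics = await agent.test(10, 195);
*/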