// Full DQN (Deep Q-Network) algorithm.

import * as tf from '@tensorflow/tfjs';
import RLAlgorithm from './rl.js';

/**
 * Deep Q-Network (DQN) agent with an experience-replay buffer and a periodically
 * synchronized target network.
 *
 * Hyperparameters are read from `this.config` (presumably stored by the RLAlgorithm
 * base constructor). Keys that are missing, null or undefined fall back to the defaults
 * in parentheses; explicit falsy values such as 0 are respected:
 *   - epsilon (1.0), epsilonDecay (0.995), minEpsilon (0.05): epsilon-greedy schedule,
 *     decayed once per call to update().
 *   - bufferSize (20000): maximum number of stored transitions.
 *   - numActions (2): size of the discrete action space.
 *   - numEpisodes (10): number of trainEpisode() calls made by train().
 *   - batchSize (32): mini-batch size sampled by update().
 *   - discountFactor (0.99): reward discount factor (gamma).
 *   - targetUpdateFrequency (100): copy online weights into the target every N updates.
 *   - maxSteps: per-episode step cap (defaults to 1000 in step(), 200 in test()).
 *   - logInterval: if set, train() logs the average reward every N episodes.
 */
class DQNAlgorithm extends RLAlgorithm {
    /**
     * Constructs a DQN agent.
     * @param {tf.Sequential} model - The Q-network model. It must be compiled before training,
     *     because update() calls model.fit().
     * @param {Object[]} env - Array of environment instances (same environment class). Each must
     *     implement reset() -> state and step(action) -> { state, reward, done }; either method
     *     may also return a Promise of that value.
     * @param {number} pzn - Max parallelization: number of environments run per training episode.
     * @param {Object} config - Algorithm-specific hyperparameters (see the class documentation).
     */
    constructor(model, env, pzn, config) {
        super(model, env, pzn, config);
        // `??` instead of `||` so explicit zeros (e.g. epsilon: 0, discountFactor: 0) are honored.
        this.epsilon = this.config.epsilon ?? 1.0;
        this.epsilonDecay = this.config.epsilonDecay ?? 0.995;
        this.minEpsilon = this.config.minEpsilon ?? 0.05;
        this.buffer = [];
        this.bufferSize = this.config.bufferSize ?? 20000;
        this.numActions = this.config.numActions ?? 2;
        this.numEpisodes = this.config.numEpisodes ?? 10;
        this.batchSize = this.config.batchSize ?? 32;
        this.discountFactor = this.config.discountFactor ?? 0.99;
        this.targetUpdateFrequency = this.config.targetUpdateFrequency ?? 100;
        this.updateCounter = 0;
    }

    /**
     * Class initialization function. This is called so async functions can run outside of the constructor.
     * This must be present, even if empty. Here it creates the target network.
     */
    async initialize() {
        await this.cloneModelToTarget();
    }

    /**
     * Replaces `this.targetModel` with an independent copy of `this.model` (topology and weights).
     * Any previously held target model is disposed to release its tensors.
     */
    async cloneModelToTarget() {
        let artifacts = null;
        // Serialize the online model in memory; the handler only captures the artifacts and
        // returns a well-formed SaveResult, as tf.io save handlers are expected to.
        await this.model.save(tf.io.withSaveHandler(async (modelArtifacts) => {
            artifacts = modelArtifacts;
            return {
                modelArtifactsInfo: {
                    dateSaved: new Date(),
                    modelTopologyType: 'JSON',
                },
            };
        }));

        const previousTarget = this.targetModel;
        this.targetModel = await tf.loadLayersModel(tf.io.fromMemory(artifacts));
        if (previousTarget) {
            previousTarget.dispose();
        }
    }

    /**
     * Selects an action using an epsilon-greedy policy.
     * @param {Array} state - The current state (flat feature vector).
     * @returns {Number} A uniformly random action with probability epsilon, otherwise the index of
     *     the largest predicted Q-value (first index on ties).
     */
    selectAction(state) {
        if (Math.random() < this.epsilon) {
            return Math.floor(Math.random() * this.numActions);
        }
        // tf.tidy disposes both the input tensor and the prediction tensor; only the plain
        // JS array escapes.
        const qValues = tf.tidy(() => {
            const stateTensor = tf.tensor2d([state], [1, state.length]);
            return this.model.predict(stateTensor).arraySync()[0];
        });
        return qValues.indexOf(Math.max(...qValues));
    }

    /**
     * Stores an experience tuple in the replay buffer, evicting the oldest entries once the
     * buffer exceeds `bufferSize`.
     * @param {Object} exp - The experience { state, action, reward, nextState, done }.
     */
    storeExperience(exp) {
        this.buffer.push(exp);
        // `while` (not `if`) so that a reduced bufferSize is enforced immediately.
        while (this.buffer.length > this.bufferSize) {
            this.buffer.shift();
        }
    }

    /**
     * Performs one DQN gradient step on a mini-batch sampled uniformly (with replacement) from the
     * replay buffer. Targets are r for terminal transitions and
     * r + discountFactor * max_a' Q_target(s', a') otherwise. It also decays epsilon and, every
     * `targetUpdateFrequency` calls, copies the online weights into the target network.
     * Does nothing while the buffer holds fewer than `batchSize` transitions.
     */
    async update() {
        if (this.buffer.length < this.batchSize) return; // Not enough samples to update.

        // Sample a random mini-batch (with replacement).
        const batch = [];
        for (let i = 0; i < this.batchSize; i++) {
            batch.push(this.buffer[Math.floor(Math.random() * this.buffer.length)]);
        }

        const states = batch.map(exp => exp.state);
        const nextStates = batch.map(exp => exp.nextState);
        const actions = batch.map(exp => exp.action);
        const rewards = batch.map(exp => exp.reward);
        const dones = batch.map(exp => exp.done);

        if (!this.targetModel) {
            await this.cloneModelToTarget();
            console.log("Target model was not set, cloning model to target model now.");
        }

        const stateDim = states[0].length;

        // Q-values for next states from the target network (tensors released by tidy).
        const targetQArray = tf.tidy(() =>
            this.targetModel.predict(tf.tensor2d(nextStates, [this.batchSize, stateDim])).arraySync());

        // The states tensor is reused as the training input, so it is managed manually.
        const statesTensor = tf.tensor2d(states, [this.batchSize, stateDim]);
        let yTrain = null;
        try {
            const qArray = tf.tidy(() => this.model.predict(statesTensor).arraySync());

            // Only the taken action's Q-value is moved toward the TD target; the other
            // outputs keep their current predictions, so they contribute no error.
            for (let i = 0; i < this.batchSize; i++) {
                let target = rewards[i];
                if (!dones[i]) {
                    target += this.discountFactor * Math.max(...targetQArray[i]);
                }
                qArray[i][actions[i]] = target;
            }

            yTrain = tf.tensor2d(qArray, [this.batchSize, this.numActions]);
            // Perform one gradient descent step.
            await this.model.fit(statesTensor, yTrain, { epochs: 1, verbose: 0 });
        } finally {
            statesTensor.dispose();
            if (yTrain) {
                yTrain.dispose();
            }
        }

        // Decay epsilon (once per update, not per environment step).
        this.epsilon = Math.max(this.minEpsilon, this.epsilon * this.epsilonDecay);

        // Periodically update the target network.
        this.updateCounter++;
        if (this.updateCounter % this.targetUpdateFrequency === 0) {
            this.targetModel.setWeights(this.model.getWeights());
            console.log("Target network updated.");
        }
    }

    /**
     * Runs one full episode (despite the name, not a single environment step) on
     * `this.env[env_index]`, storing every transition in the replay buffer.
     * The episode ends when the environment reports `done` or after `config.maxSteps`
     * steps (default 1000).
     * @param {Number} env_index - Index of the environment to use.
     * @returns {Promise<Number>} The undiscounted total reward of the episode.
     */
    async step(env_index) {
        const env = this.env[env_index];
        const maxSteps = this.config.maxSteps ?? 1000;
        // `await` accepts both synchronous and Promise-returning environments, and lets
        // concurrent episodes started by trainEpisode() interleave.
        let state = await env.reset();
        let totalReward = 0;
        for (let t = 0; t < maxSteps; t++) {
            const action = this.selectAction(state);
            const result = await env.step(action);
            totalReward += result.reward;
            this.storeExperience({
                state: state,
                action: action,
                reward: result.reward,
                nextState: result.state,
                done: result.done
            });
            state = result.state;
            if (result.done) {
                break;
            }
        }
        return totalReward;
    }

    /**
     * Runs one episode on each of the first min(pzn, env.length) environments concurrently,
     * then performs a single update(). Episodes interleave at each `await`; they genuinely
     * overlap only if the environments' reset()/step() are asynchronous.
     * @returns {Promise<Number[]>} Total reward of each episode, in environment order.
     */
    async trainEpisode() {
        // Clamp so a pzn larger than the number of environments does not index past the array.
        const numParallel = Math.min(this.pzn, this.env.length);
        const episodePromises = [];
        for (let i = 0; i < numParallel; i++) {
            episodePromises.push(this.step(i));
        }

        const rewards = await Promise.all(episodePromises);
        await this.update();
        return rewards;
    }

    /**
     * Trains the agent for `numEpisodes` calls to trainEpisode(), optionally logging the
     * average episode reward every `config.logInterval` episodes.
     * @returns {Promise<void>}
     */
    async train() {
        const logInterval = this.config ? this.config.logInterval : undefined;
        for (let ep = 0; ep < this.numEpisodes; ep++) {
            const rewards = await this.trainEpisode();
            // Skip logging when no episode ran, which would otherwise print NaN.
            if (logInterval && (ep + 1) % logInterval === 0 && rewards.length > 0) {
                const avg = rewards.reduce((a, b) => a + b, 0) / rewards.length;
                console.log(`Episode ${ep + 1}: avgReward=${avg.toFixed(2)}`);
            }
        }
    }

    /**
     * Tests the current policy by running evaluation episodes without exploration.
     * Temporarily sets epsilon to 0 so the agent acts greedily; the original epsilon is
     * restored even if an episode throws.
     * Computes several metrics: average reward, max reward, average episode length,
     * max episode length, standard deviation of episode lengths, and success rate.
     * This is not parallelized: only `this.env[0]` is used, and nothing is stored in the
     * replay buffer. Episodes are capped at `config.maxSteps` steps (default 200, which
     * differs from the training default of 1000).
     * @param {Number} numEpisodes - Number of test episodes (default is 10).
     * @param {Number} successThreshold - Reward threshold to count an episode as a success (e.g., 195).
     * @returns {Promise<Object>} - An object containing the test metrics.
     */
    async test(numEpisodes = 10, successThreshold = 195) {
        const originalEpsilon = this.epsilon;
        this.epsilon = 0; // disable exploration
        try {
            const env = this.env[0];
            const maxSteps = this.config.maxSteps ?? 200;
            const rewardsArray = [];
            const lengthsArray = [];
            let successCount = 0;

            for (let ep = 0; ep < numEpisodes; ep++) {
                let state = await env.reset();
                let done = false;
                let episodeReward = 0;
                let steps = 0;
                while (steps < maxSteps && !done) {
                    const action = this.selectAction(state);
                    const result = await env.step(action);
                    episodeReward += result.reward;
                    state = result.state;
                    done = result.done;
                    steps++;
                }
                console.log(`Test Episode ${ep + 1}: Reward = ${episodeReward}, Steps = ${steps}`);
                rewardsArray.push(episodeReward);
                lengthsArray.push(steps);
                if (episodeReward >= successThreshold) {
                    successCount++;
                }
            }

            const avgReward = rewardsArray.reduce((a, b) => a + b, 0) / rewardsArray.length;
            const maxReward = Math.max(...rewardsArray);
            const avgLength = lengthsArray.reduce((a, b) => a + b, 0) / lengthsArray.length;
            const maxLength = Math.max(...lengthsArray);
            // Population variance of episode lengths.
            const variance = lengthsArray.reduce((sum, x) => sum + Math.pow(x - avgLength, 2), 0) / lengthsArray.length;
            const stdDev = Math.sqrt(variance);
            const successRate = (successCount / numEpisodes) * 100;

            console.log(`Test Results over ${numEpisodes} episodes:`);
            console.log(`Average Reward: ${avgReward}`);
            console.log(`Max Reward: ${maxReward}`);
            console.log(`Average Episode Length: ${avgLength}`);
            console.log(`Max Episode Length: ${maxLength}`);
            console.log(`Standard Deviation of Episode Length: ${stdDev}`);
            console.log(`Success Rate (>= ${successThreshold} reward): ${successRate}%`);

            return {
                averageReward: avgReward,
                maxReward: maxReward,
                averageLength: avgLength,
                maxLength: maxLength,
                stdDevLength: stdDev,
                successRate: successRate
            };
        } finally {
            this.epsilon = originalEpsilon; // restore original epsilon
        }
    }
}

export { DQNAlgorithm };