{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "XJy9QoDC7XA7" }, "source": [ "<img width=\"800px\" src=\"../fidle/img/header.svg\"></img>\n", "\n", "# <!-- TITLE --> [DRL2] - RL Baselines3 Zoo: Training in Colab\n", "<!-- DESC --> Demo of Stable-Baselines3 with Colab\n", "<!-- AUTHOR : Nathan Cassereau (IDRIS) and Bertrand Cabot (IDRIS) -->\n", "\n", "\n", "Demo of Stable-Baselines3, adapted by Nathan Cassereau (IDRIS) and Bertrand Cabot (IDRIS)\n", "\n", "\n", "RL Baselines3 Zoo Repo: [https://github.com/DLR-RM/rl-baselines3-zoo](https://github.com/DLR-RM/rl-baselines3-zoo)\n", "\n", "Stable-Baselines3 Repo: [https://github.com/DLR-RM/stable-baselines3](https://github.com/DLR-RM/stable-baselines3)\n", "\n", "\n", "# Install Dependencies\n", "\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "AXVDDlTn02M9" }, "outputs": [], "source": [ "!apt-get install -y swig cmake ffmpeg freeglut3-dev xvfb" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "f1s9_3b1Tsq9" }, "outputs": [], "source": [ "!apt-get install -y \\\n", " libgl1-mesa-dev \\\n", " libgl1-mesa-glx \\\n", " libglew-dev \\\n", " libosmesa6-dev \\\n", " software-properties-common\n", "\n", "!apt-get install -y patchelf" ] }, { "cell_type": "markdown", "metadata": { "id": "kDjF3qRg7oGH" }, "source": [ "## Clone RL Baselines3 Zoo Repo" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "SCjGikdT1DFy" }, "outputs": [], "source": [ "!git clone --recursive https://github.com/DLR-RM/rl-baselines3-zoo" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "REMQlh-ezyVt" }, "outputs": [], "source": [ "%cd /content/rl-baselines3-zoo/" ] }, { "cell_type": "markdown", "metadata": { "id": "5tmD_QTBqTMb" }, "source": [ "### Install pip dependencies" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "OWIDzgJTqShY" }, "outputs": [], "source": [ "!pip install -r requirements.txt" ] }, { "cell_type": "code", 
"execution_count": null, "metadata": { "id": "VlOD-cImRMwW" }, "outputs": [], "source": [ "!pip install free-mujoco-py" ] }, { "cell_type": "markdown", "metadata": { "id": "kYPotFDF0Noa" }, "source": [ "## Pretrained model\n", "\n", "gym environments: https://gym.openai.com/envs/" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "kgOp2XIklaf2" }, "outputs": [], "source": [ "%cd /content/rl-baselines3-zoo/" ] }, { "cell_type": "markdown", "metadata": { "id": "xVm9QPNVwKXN" }, "source": [ "### Record a Video" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "MPyfQxD5z26J" }, "outputs": [], "source": [ "# Set up display; otherwise rendering will fail\n", "import os\n", "os.system(\"Xvfb :1 -screen 0 1024x768x24 &\")\n", "os.environ['DISPLAY'] = ':1'" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ZC3OTfpf8CXu" }, "outputs": [], "source": [ "import base64\n", "from pathlib import Path\n", "\n", "from IPython import display as ipythondisplay\n", "\n", "def show_videos(video_path='', prefix=''):\n", " \"\"\"\n", " Taken from https://github.com/eleurent/highway-env\n", "\n", " :param video_path: (str) Path to the folder containing videos\n", " :param prefix: (str) Filter the video, showing only the only starting with this prefix\n", " \"\"\"\n", " html = []\n", " for mp4 in Path(video_path).glob(\"**/*{}*.mp4\".format(prefix)):\n", " video_b64 = base64.b64encode(mp4.read_bytes())\n", " html.append('''{} <br> <video alt=\"{}\" autoplay \n", " loop controls style=\"height: 400px;\">\n", " <source src=\"data:video/mp4;base64,{}\" type=\"video/mp4\" />\n", " </video>'''.format(mp4, mp4, video_b64.decode('ascii')))\n", " ipythondisplay.display(ipythondisplay.HTML(data=\"<br>\".join(html)))" ] }, { "cell_type": "markdown", "metadata": { "id": "LW-7EWA50550" }, "source": [ "### Discrete environments" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ArN0B6YX0m4z" }, "outputs": [], 
"source": [ "%run scripts/all_plots.py -a dqn qrdqn a2c ppo --env PongNoFrameskip-v4 -f rl-trained-agents/" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "5zmVA6JADKvV" }, "outputs": [], "source": [ "%run scripts/plot_train.py -a dqn -e PongNoFrameskip-v4 -f rl-trained-agents/ -x time" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "BkcI91hABKUm" }, "outputs": [], "source": [ "%run scripts/plot_train.py -a qrdqn -e PongNoFrameskip-v4 -f rl-trained-agents/ -x time" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "E5s7UHn2DmqQ" }, "outputs": [], "source": [ "%run scripts/plot_train.py -a a2c -e PongNoFrameskip-v4 -f rl-trained-agents/ -x time" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "0aP2A-NqCQSC" }, "outputs": [], "source": [ "%run scripts/plot_train.py -a ppo -e PongNoFrameskip-v4 -f rl-trained-agents/ -x time" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "itP336W6HdeC" }, "outputs": [], "source": [ "!python enjoy.py --algo dqn --env PongNoFrameskip-v4 --no-render --n-timesteps 5000" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ip3AauLzwNGP" }, "outputs": [], "source": [ "!python -m utils.record_video --algo dqn --env PongNoFrameskip-v4" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "oKOjFuwK9HI0" }, "outputs": [], "source": [ "show_videos(video_path='rl-trained-agents/dqn', prefix='PongNoFrameskip-v4')" ] }, { "cell_type": "markdown", "metadata": { "id": "8kD_C440-xvw" }, "source": [ "### Continuous environments" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "2yJ0eOyjLM5O" }, "outputs": [], "source": [ "%run scripts/all_plots.py -a ppo trpo sac td3 tqc --env Ant-v3 -f rl-trained-agents/" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "iQAJmndtNCvN" }, "outputs": [], "source": [ "%run scripts/plot_train.py -a ppo -e 
Ant-v3 -f rl-trained-agents/ -x time" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "O1L_ze67NDf9" }, "outputs": [], "source": [ "%run scripts/plot_train.py -a trpo -e Ant-v3 -f rl-trained-agents/ -x time" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "FPn0iySFNzCQ" }, "outputs": [], "source": [ "%run scripts/plot_train.py -a tqc -e Ant-v3 -f rl-trained-agents/ -x time" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "3vDHeNNTNwsT" }, "outputs": [], "source": [ "%run scripts/plot_train.py -a td3 -e Ant-v3 -f rl-trained-agents/ -x time" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "dQDNhdlgOzFs" }, "outputs": [], "source": [ "%run scripts/plot_train.py -a sac -e Ant-v3 -f rl-trained-agents/ -x time" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "s9aT46m8QXzt" }, "outputs": [], "source": [ "!python enjoy.py --algo td3 --env Ant-v3 --no-render --n-timesteps 5000" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "zx7gZ23UQYQ4" }, "outputs": [], "source": [ "!python -m utils.record_video --algo td3 --env Ant-v3" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "yvETH_M_Unw3" }, "outputs": [], "source": [ "show_videos(video_path='rl-trained-agents/td3', prefix='Ant-v3')" ] }, { "cell_type": "markdown", "metadata": { "id": "6gJ-pAbF7zRZ" }, "source": [ "## Train an RL Agent\n", "\n", "\n", "The trained agent is saved in the `logs/` folder.\n", "\n", "Here we train DQN on the Atari Pong environment (`PongNoFrameskip-v4`) for 1 000 000 steps.\n", "\n", "\n", "To train on another environment, pass its ID with `--env`, for example `--env CartPole-v1`.\n", "\n", "Note: You need to update `hyperparams/<algo>.yml` (for example `hyperparams/dqn.yml`) to support new environments. You can access it in the side panel of Google Colab. 
(see https://stackoverflow.com/questions/46986398/import-data-into-google-colaboratory)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "background_save": true }, "id": "9bIR_N7R11XI" }, "outputs": [], "source": [ "!python train.py --algo dqn --env PongNoFrameskip-v4 --n-timesteps 1000000" ] }, { "cell_type": "markdown", "metadata": { "id": "-fHBq73665yD" }, "source": [ "#### Evaluate trained agent\n", "\n", "\n", "Remove `--folder logs/` to evaluate the pretrained agent from `rl-trained-agents/` instead." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Bw8YuEgU6bT3" }, "outputs": [], "source": [ "!python enjoy.py --algo dqn --env PongNoFrameskip-v4 --no-render --n-timesteps 5000 --folder logs/" ] }, { "cell_type": "markdown", "metadata": { "id": "w5Il2J0VHPLC" }, "source": [ "#### Tune Hyperparameters\n", "\n", "We use [Optuna](https://optuna.org/) to optimize the hyperparameters.\n", "\n", "Tune the hyperparameters for DQN on Pong, using a TPE sampler and a median pruner, 5 parallel jobs,\n", "with a budget of 10 trials and a maximum of 5000 steps per trial. The command is commented out; remove the leading `#` to run it." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "w2sC22eGHTH-" }, "outputs": [], "source": [ "#!python train.py --algo dqn --env PongNoFrameskip-v4 -n 5000 -optimize --n-trials 10 --n-jobs 5 --sampler tpe --pruner median" ] }, { "cell_type": "markdown", "metadata": { "id": "qBuUfnzI8DN6" }, "source": [ "### Display the video" ] }, { "cell_type": "markdown", "metadata": { "id": "RjdpP0HE8D2p" }, "source": [ "### Continue Training\n", "\n", "Here, we continue training the model trained above, loading it with `-i`. These cells are commented out; remove the leading `#` to run them." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "zgMZQJJF6u1C" }, "outputs": [], "source": [ "#!python train.py --algo dqn --env PongNoFrameskip-v4 --n-timesteps 50000 -i logs/dqn/PongNoFrameskip-v4_1/PongNoFrameskip-v4.zip" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "GSaoyiAE8cVj" }, "outputs": [], "source": [ "#!python 
enjoy.py --algo dqn --env PongNoFrameskip-v4 --no-render --n-timesteps 1000 --folder logs/" ] } ], "metadata": { "accelerator": "GPU", "colab": { "collapsed_sections": [], "name": "FIDLE rl-baselines-zoo.ipynb", "provenance": [] }, "kernelspec": { "display_name": "Python 3", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 0 }