Introducción al aprendizaje por refuerzo. Parte 3: Q-Learning con redes neuronales, algoritmo DQN.
En la parte 2 vimos que el algoritmo Q-Learning funciona muy bien cuando el entorno es simple y la función Q(s,a) se puede representar como una tabla o matriz de valores. Pero cuando hay miles de millones de estados diferentes y cientos de acciones distintas, la tabla se vuelve enorme, y no es viable su utilización. Por ello, Mnih et al. [1] inventaron el algoritmo Deep Q-Network o DQN. Este algoritmo combina el algoritmo Q-learning con redes neuronales profundas (Deep Neural Networks). Como es sabido en el campo de la IA, las redes neuronales son una fantástica manera de aproximar funciones no lineales. Por lo tanto, este algoritmo usa una red neuronal para aproximar la función Q, evitando así utilizar una tabla para representar la misma. En realidad, utiliza dos redes neuronales para estabilizar el proceso de aprendizaje. La primera, la red neuronal principal (main Neural Network), representada por los parámetros θ, se utiliza para estimar los valores-Q del estado s y acción a actuales: Q(s, a; θ). La segunda, la red neuronal objetivo (target Neural Network), parametrizada por θ´, tendrá la misma arquitectura que la red principal pero se usará para aproximar los valores-Q del siguiente estado s´ y la siguiente acción a´. El aprendizaje ocurre en la red principal y no en la objetivo. La red objetivo se congela (sus parámetros no se cambian) durante varias iteraciones (normalmente alrededor de 10000), y después los parámetros de la red principal se copian a la red objetivo, transmitiendo así el aprendizaje de una a otra, haciendo que las estimaciones calculadas por la red objetivo sean más precisas.
Ecuación de Bellman en DQN
La ecuación de Bellman tiene esta forma ahora. Para poder entrenar una red neuronal, necesitamos una función de pérdida o coste (loss or cost function), la cual definimos como el cuadrado de la diferencia entre ambos lados de la ecuación de Bellman:
Por lo tanto, ésta será la función que minimizaremos usando el algoritmo de descenso de gradientes (gradient descent), el cuál se ejecuta automáticamente si usamos una librería de diferenciación automática con redes neuronales, como TensorFlow o PyTorch.
El código, resolviendo el problema CartPole con TensorFlow
Ejecuta tú mismo el código que usa TensorFlow paso a paso en este enlace (o la versión que usa PyTorch en este enlace), o sigue leyendo para ver el código sin ejecutarlo. Como en esta parte el código es algo más largo que en las anteriores, aquí sólo mostraré el código más importante. Para verlo entero, visita el enlace mencionado anterioremente.
Aquí tenemos el entorno (environment) conocido como CartPole. He utilizado la librería OpenAI Gym para visualizar y ejecutar este entorno. En este entorno, el objetivo es mover el carro hacia la derecha y la izquierda, con el objetivo de equilibrar el palo. Y si el palo se tuerce más de 15 grados respecto al eje vertical, el episodio terminará y volveremos a empezar.
Para implementar el algoritmo DQN, empezaremos creando las dos redes neuronales, la principal (main_nn) y la objetivo (target_nn). Ésta última será una copia de la principal, pero con sus propios pesos. También necesitaremos un optimizador (optimizer) y una función de pérdida (loss function).
Ahora crearemos el buffer donde guardaremos la experiencia recogida para usarla después y entrenar la red neuronal.
También crearemos una función auxiliar para ejecutar la política ε-voraz, y otra para entrenar la red neuronal principal usando los datos guardados en el buffer.
Tras esto, definiremos los hiperparámetros y empezaremos a entrenar el algoritmo. Para ello, primero usaremos la política ε-voraz para jugar al juego y recoger experiencia para poder aprender de esos datos. Después de terminar un episodio jugado, llamaremos a la función que entrena la red neuronal. Cada 2000 pasos de entrenamiento, copiaremos los pesos de la red neuronal principal a la red neuronal objetivo. También reduciremos el valor de epsilon (ε), para empezar con un valor de exploración alto y bajarlo poco a poco. Así, veremos cómo el algoritmo empieza a aprender a jugar al juego y la recompensa obtenida jugando al juego irá mejorando poco a poco.
El resultado que observamos es el siguiente:
Episode 0/1000. Epsilon: 0.99. Reward in last 100 episodes: 14.0 Episode 50/1000. Epsilon: 0.94. Reward in last 100 episodes: 22.2 Episode 100/1000. Epsilon: 0.89. Reward in last 100 episodes: 23.3 Episode 150/1000. Epsilon: 0.84. Reward in last 100 episodes: 23.4 Episode 200/1000. Epsilon: 0.79. Reward in last 100 episodes: 24.9 Episode 250/1000. Epsilon: 0.74. Reward in last 100 episodes: 30.4 Episode 300/1000. Epsilon: 0.69. Reward in last 100 episodes: 38.4 Episode 350/1000. Epsilon: 0.64. Reward in last 100 episodes: 51.4 Episode 400/1000. Epsilon: 0.59. Reward in last 100 episodes: 68.2 Episode 450/1000. Epsilon: 0.54. Reward in last 100 episodes: 82.4 Episode 500/1000. Epsilon: 0.49. Reward in last 100 episodes: 102.1 Episode 550/1000. Epsilon: 0.44. Reward in last 100 episodes: 129.7 Episode 600/1000. Epsilon: 0.39. Reward in last 100 episodes: 151.7 Episode 650/1000. Epsilon: 0.34. Reward in last 100 episodes: 173.0 Episode 700/1000. Epsilon: 0.29. Reward in last 100 episodes: 187.3 Episode 750/1000. Epsilon: 0.24. Reward in last 100 episodes: 190.9 Episode 800/1000. Epsilon: 0.19. Reward in last 100 episodes: 194.6 Episode 850/1000. Epsilon: 0.14. Reward in last 100 episodes: 195.9 Episode 900/1000. Epsilon: 0.09. Reward in last 100 episodes: 197.9 Episode 950/1000. Epsilon: 0.05. Reward in last 100 episodes: 200.0 Episode 1000/1000. Epsilon: 0.05. Reward in last 100 episodes: 200.0
Ahora que el agente ha aprendido a maximizar la recompensa para el entorno CartPole, haremos que el agente interactúe con el entorno una vez más, y visualizamos el resultado, viendo como ahora es capaz de mantener el palo equilibrado durante 200 frames.
Puedes ejecutar tú mismo el código que usa TensorFlow paso a paso en este enlace (o la versión que usa PyTorch en este enlace).
Referencias:
Sutton, R. S., & Barto, A. G. (2018). Reinforcement learning: An introduction. MIT press.
La serie completa de introducción al aprendizaje por refuerzo:
- Parte 1: el problema del bandido multibrazo
- Parte 2: Q-Learning
- Parte 3: Q-Learning con redes neuronales, algoritmo DQN
- Parte 4: Double DQN y Dueling DQN
- Parte 5: Políticas de gradiente
Mi repositorio de GitHub con algoritmos frecuentes de aprendizaje por refuerzo profundo (en desarrollo): https://github.com/markelsanz14/independent-rl-agents