
Hands-On Reinforcement Learning with Python

By: Sudharsan Ravichandiran

Overview of this book

Reinforcement Learning (RL) is one of the most active and promising branches of artificial intelligence. Hands-On Reinforcement Learning with Python will help you master not only the basic reinforcement learning algorithms but also advanced deep reinforcement learning algorithms. The book starts with an introduction to reinforcement learning, followed by OpenAI Gym and TensorFlow. You will then explore various RL algorithms and concepts, such as Markov decision processes, Monte Carlo methods, and dynamic programming, including value and policy iteration. This example-rich guide will introduce you to deep reinforcement learning algorithms such as Dueling DQN, DRQN, A3C, PPO, and TRPO. You will also learn about imagination-augmented agents, learning from human preference, DQfD, HER, and many other recent advances in reinforcement learning. By the end of the book, you will have the knowledge and experience needed to implement reinforcement learning and deep reinforcement learning in your projects, and you will be all set to enter the world of artificial intelligence.
Table of Contents (16 chapters)

Neural networks in TensorFlow

Now, we will see how to build a basic neural network in TensorFlow that predicts handwritten digits. We will use the popular MNIST dataset, which contains labeled images of handwritten digits split into training and test sets.
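To make the goal concrete: a basic digit classifier maps the pixels of one image to 10 scores (one per digit), turns those scores into probabilities with softmax, and predicts the digit with the highest probability. The sketch below shows only that prediction step in plain Python, assuming a single linear layer and made-up toy weights; the names `softmax` and `predict_digit` are hypothetical, this is not the TensorFlow network built in this section, and it performs no training.

```python
import math

def softmax(scores):
    """Convert raw scores into probabilities that sum to 1."""
    largest = max(scores)  # subtract the max score for numerical stability
    exps = [math.exp(s - largest) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def predict_digit(pixels, weights, biases):
    """Hypothetical single linear layer: score[d] = dot(weights[d], pixels) + biases[d]."""
    scores = [sum(w * p for w, p in zip(row, pixels)) + b
              for row, b in zip(weights, biases)]
    probs = softmax(scores)
    # The predicted digit is the index of the largest probability
    digit = max(range(len(probs)), key=probs.__getitem__)
    return digit, probs

# Toy setup: 784 pixel inputs, 10 digit outputs; real weights would come from training
pixels = [0.5] * 784
weights = [[0.0] * 784 for _ in range(10)]
weights[7] = [0.01] * 784  # pretend the model favours digit 7
biases = [0.0] * 10

digit, probs = predict_digit(pixels, weights, biases)
print(digit)  # 7
```

Training, which the TensorFlow version handles, is the process of adjusting `weights` and `biases` so that the highest probability lands on the correct label.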

First, we import TensorFlow and load the dataset with the input_data module from tensorflow.examples.tutorials.mnist:

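import tensorflow as tf
# input_data ships with TensorFlow 1.x; it is not included in TensorFlow 2.x
from tensorflow.examples.tutorials.mnist import input_data
# Download MNIST into /tmp/data/ (if not already there) and one-hot encode the labels
mnist = input_data.read_data_sets("/tmp/data/", one_hot=True)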
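Passing `one_hot=True` asks `read_data_sets` to return each label as a vector of length 10 with a single 1 at the position of the digit, rather than as a plain integer. The sketch below illustrates that encoding in plain Python; the helper `one_hot` is hypothetical and is not part of `input_data`.

```python
def one_hot(label, num_classes=10):
    """Hypothetical helper: a list of length num_classes with 1.0 at index `label`."""
    if not 0 <= label < num_classes:
        raise ValueError("label out of range")
    vector = [0.0] * num_classes
    vector[label] = 1.0
    return vector

# The digit 3 becomes a 10-element vector with a 1 in position 3
print(one_hot(3))  # [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
```

This form matches a network with 10 outputs: each position of the label vector lines up with one output score.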

Now, we will see what we have in our data:

print("No of images in training set {}".format(mnist.train.images.shape))
print("No of labels in training set {}".format(mnist.train.labels.shape))

print("No of images in test set {}".format(mnist.test.images.shape))
print("No of labels in test set {}".format(mnist.test.labels.shape))

It will print the following:

No of...
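Each MNIST image is a 28 × 28 grayscale picture. `input_data` stores every image flattened into a single row of 28 * 28 = 784 pixel values, so each images array has one row per image and 784 columns. The sketch below recovers the 2-D grid from one flattened row in plain Python, assuming row-major order (the first 28 values are the top row of pixels); the helper `unflatten` is hypothetical.

```python
def unflatten(pixels, width=28):
    """Hypothetical helper: split a flat, row-major list of pixels into rows of `width`."""
    if len(pixels) % width != 0:
        raise ValueError("pixel count is not a multiple of width")
    return [pixels[i:i + width] for i in range(0, len(pixels), width)]

flat_image = [0.0] * 784       # stand-in for one row of mnist.train.images
grid = unflatten(flat_image)
print(len(grid), len(grid[0]))  # 28 28
```

A simple fully connected network can consume the flat 784-value rows directly; the 2-D view is mainly useful for displaying an image.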