
Making Dropout Work with PPO

Vanilla dropout doesn't work with policy gradient methods like PPO, so let's add some flavor to it.

Dropout is a widely used regularization method for neural networks. Despite its use in convolutional and transformer vision models, language models, and even off-policy reinforcement learning, it doesn’t work out of the box in on-policy reinforcement learning (RL) with policy gradient methods. However, there is no inherent incompatibility between dropout and policy gradients: we just need to handle the dropout noise more carefully. By applying the same dropout noise pattern during model updates as was used during rollouts, we’ll see that dropout works just fine with policy gradient RL.

Background

Dropout is a regularization method that randomly zeros some hidden activations during training of a neural network. This prevents individual activations from specializing (overfitting) too much to peculiarities of the input data and thereby improves generalization. Dropout has a single parameter, the dropout probability, which controls the rate at which individual elements are zeroed; a typical dropout rate is 10%.

Conventional dropout implementations such as nn.Dropout work by sampling a fresh binary noise mask on every forward pass and multiplying the incoming activations pointwise with the noise. Because dropout is only active during training, the output is also rescaled so that the signal level stays the same between training and inference.
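As a concrete reference, here is a minimal sketch of this “inverted dropout” scheme, roughly what nn.Dropout does in training mode (the function name and signature are illustrative, not PyTorch API):

import torch

def inverted_dropout(x: torch.Tensor, rate: float, training: bool) -> torch.Tensor:
    # dropout is a no-op at inference time
    if not training or rate == 0.0:
        return x
    # sample a fresh binary mask: 1 with probability (1 - rate), 0 otherwise
    mask = torch.bernoulli(torch.full_like(x, 1 - rate))
    # rescale so the expected activation matches the inference-time signal level
    return x * mask / (1 - rate)
Sketch of conventional (inverted) dropout.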

Proximal policy optimization (PPO) is an on-policy reinforcement learning method based on policy gradient optimization. PPO trains a policy network, which is a neural network that, given an observation from the environment, returns a probability distribution over actions. (Technically, PPO also includes a value network, but we subsume it within the loss function, as it isn’t needed after training.)

Briefly, PPO works in two alternating phases:

  1. rollout: generate data by sampling actions from the policy and acting in the environment, and
  2. update: improve the policy by optimizing the PPO loss function on the data.

In fact, PPO is not exactly an on-policy method but rather a nearly-on-policy one, because the policy is updated several times in minibatches on each batch of rollout data. The PPO loss function is designed to allow multiple model updates without causing training to diverge, and one key component of this is importance sampling: the loss of each action is weighted by the probability ratio between the new and old policies, which corrects for the off-policy bias as the updated policy drifts away from the rollout policy that was used to sample the actions.
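To make the role of this probability ratio concrete, here is a minimal sketch of the PPO clipped policy loss (not a full PPO implementation; new_logp and old_logp are assumed to be per-action log-probabilities under the current and rollout policies, and advantages are the corresponding advantage estimates):

import torch

def ppo_policy_loss(new_logp, old_logp, advantages, clip_eps=0.2):
    # importance weight: pi_new(a|s) / pi_old(a|s)
    ratio = torch.exp(new_logp - old_logp)
    # clipped surrogate objective; the minus sign turns maximization into a loss
    clipped_ratio = torch.clamp(ratio, 1.0 - clip_eps, 1.0 + clip_eps)
    return -torch.min(ratio * advantages, clipped_ratio * advantages).mean()
Sketch of the PPO clipped policy loss.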

What’s the problem?

Let’s add a dropout layer to a small policy network, as would be typical in supervised learning applications like image classification. Since reinforcement learning networks tend to be tiny, one 64-dimensional hidden layer is enough.

import torch.nn as nn

policy = nn.Sequential(
    nn.Linear(observation_dim, 64),
    nn.ReLU(),
    nn.Dropout(0.1),
    nn.Linear(64, action_dim),
)
Policy network with naive dropout.

If we tried to train this policy network with PPO, we would find that training doesn’t proceed very well, if at all. The only way to make this network train is to turn the dropout rate down to zero. When the above network is trained with the PPO algorithm, it is called twice for each observation:

  1. during rollouts, the policy is run forward to sample the action for the environment, and
  2. during updates, the policy is run forward and backward to optimize the loss function on the minibatch.

The reason conventional dropout layers don’t work here is that the two invocations use two independent noise samples, and therefore produce different action probabilities. This wreaks havoc in the importance weighting term of the loss. The purpose of the importance weighting is to correct the loss by the ratio of the new and old action probabilities, where the two probabilities come from the two policy invocations. But because the invocations use different dropout noise, the probability ratio becomes just noise, and so does the loss function.
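We can see this with a quick sketch on top of the policy defined above (the random observation is only for illustration): two forward passes in training mode draw independent dropout masks, so the “old” and “new” probabilities already disagree before any parameter update has happened.

import torch

policy.train()  # dropout active
obs = torch.randn(1, observation_dim)

logp_rollout = torch.log_softmax(policy(obs), dim=-1)  # rollout-time pass, mask A
logp_update = torch.log_softmax(policy(obs), dim=-1)   # update-time pass, mask B
ratio = torch.exp(logp_update - logp_rollout)          # deviates from 1 purely due to dropout noise
Two passes through naive dropout disagree even without any update.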

PPO-flavored dropout

To make dropout work with PPO, we need to store the dropout noise vector that was sampled during the rollout phase and then apply the same noise sample during model updates. The simplest way to achieve this in a reinforcement learning context is to make the dropout noise part of the observation! That way, we can reuse all the existing data processing code and guarantee that the dropout noise seen during model updates is exactly the same as the noise seen during rollouts.

Concretely, this means that the dropout implementation gets split into two parts: a preprocessor that runs only during rollouts, and a network module that runs as part of the policy net. Here’s an example implementation of the preprocessor using the Gym API. This code assumes that there is a single dropout layer with a dimension of 64; obviously, ask your doctor if this is right for you.

import gymnasium as gym  # or `import gym` (>= 0.26), which shares the same reset/step API
import numpy as np


class DropoutPreprocessor(gym.Wrapper):
    def __init__(self, env: gym.Env, rate: float):
        super().__init__(env)
        self.rate = rate

    def _sample_dropout_noise(self):
        # sample a binary dropout mask for the 64-dimensional hidden layer
        mask = np.random.binomial(1, 1 - self.rate, size=64)
        # calibrate scaling between training and inference
        noise = mask / (1 - self.rate)
        return noise

    def reset(self, **kwargs):
        observation, info = self.env.reset(**kwargs)
        observation["dropout_noise"] = self._sample_dropout_noise()
        return observation, info

    def step(self, action):
        observation, reward, terminated, truncated, info = self.env.step(action)
        observation["dropout_noise"] = self._sample_dropout_noise()
        return observation, reward, terminated, truncated, info
Dropout noise generator.

The dropout network component with preset noise cannot easily be implemented as a generic nn.Module, because it needs side-channel information. However, it is simple enough to embed directly in the policy network.

import torch.nn as nn


class PolicyNet(nn.Module):
    def __init__(self, observation_dim: int, action_dim: int):
        super().__init__()
        self.fc_in = nn.Linear(observation_dim, 64)
        self.relu = nn.ReLU()
        self.fc_out = nn.Linear(64, action_dim)

    def forward(self, observation):
        # observation entries are assumed to have been converted to torch tensors
        # by the training pipeline before reaching the network
        features = observation["features"]
        x = self.relu(self.fc_in(features))
        if self.training:
            # apply dropout with the preset noise carried in the observation;
            # the module must be in training mode during both rollouts and updates
            # so that the same noise is applied in both passes
            noise = observation["dropout_noise"]
            x = x * noise
        x = self.fc_out(x)
        return x
Policy network with preset dropout noise.

We have essentially reimplemented the logic from nn.Dropout and split it into two halves: noise generation and activation modulation. That’s all there is to making dropout work with PPO!
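To show how the pieces fit together, here is a minimal usage sketch. It assumes the underlying environment already returns dict observations with a "features" key, as PolicyNet expects, and has a discrete action space; the environment id is hypothetical.

# hypothetical environment with dict observations {"features": ...}
env = DropoutPreprocessor(gym.make("MyDictObsEnv-v0"), rate=0.1)
policy = PolicyNet(
    observation_dim=env.observation_space["features"].shape[0],
    action_dim=env.action_space.n,
)
policy.train()  # keep dropout active during rollouts so the preset noise is applied

observation, info = env.reset()
# ... convert the observation entries to torch tensors, sample an action from
# policy(observation), step the environment, and store the observation
# (including "dropout_noise") in the rollout buffer for the PPO update
Putting the preprocessor and the policy network together.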

Bonus: Gaussian-flavored dropout

The dropout menu has one more flavor available: Gaussian dropout. Whereas binary dropout noise is sampled from a Bernoulli distribution, Gaussian dropout samples its noise from a normal distribution centered at one. Choosing the standard deviation as sqrt(rate / (1 - rate)) gives the Gaussian noise the same variance as inverted binary dropout at the same rate. We can even keep our PolicyNet module unchanged and only switch to a different noise generator to get a taste.

class GaussianDropoutPreprocessor(DropoutPreprocessor):
    def __init__(self, env: gym.Env, rate: float):
        super().__init__(env, rate)
        # convert the dropout rate into a noise standard deviation that matches
        # the variance of inverted binary dropout at the same rate
        self.scale = (rate / (1 - rate)) ** 0.5

    def _sample_dropout_noise(self):
        # multiplicative Gaussian noise with mean 1 for the 64-dimensional hidden layer
        return np.random.normal(1, self.scale, size=64)

    # reset() and step() are inherited unchanged from DropoutPreprocessor
Gaussian dropout noise generator.