
Commit a1a61d9

Implement Proximal Policy Optimization (#655)
* Copy current version of Colab notebook
* Copy code from Colab with zero gradient bug
* Refactor code
* Correct loss1 calculation
* Separate actor and critic optimization
* Fix zero-gradient issue
* Refactor PPO.update()
* Refactor main.swift
* Plot learning curve with matplotlib
* Find good hyperparameters for CartPole-v0
* Refactor code
* Add copyright statements
* Document ActorCritic.swift
* Remove memory update in ActorCritic.act()
* Convert ActorCritic to a Layer
* Fix break condition and solved definition
* Refactor PPOMemory for Swift-like function names
* Convert PPOMemory to struct
* Remove unneeded NumPy operations
* Document PPOMemory and PPOAgent
* Document hyperparameters
* Save figure in /tmp/
* Resolve TODO in Categorical.swift
* Remove unneeded parts in Categorical.swift
* Remove PPOAgent.updateOldActorCritic
* Minor fixes suggested by dan-zheng
* Move PPOMemory inside PPOAgent
* Create agent.step() that captures env.step() and memory.append()
* Attribute Categorical.swift to eaplatanios/swift-rl
* Add assert to ActorCritic.callAsFunction()
1 parent f775dea commit a1a61d9

File tree

7 files changed: +607 -0 lines changed


Gym/PPO/ActorCritic.swift

Lines changed: 118 additions & 0 deletions
@@ -0,0 +1,118 @@
// Copyright 2020 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

import TensorFlow

/// The actor network that returns a probability for each action.
///
/// Actor-Critic methods have an actor network and a critic network. The actor network is the
/// policy of the agent: it is used to select actions.
struct ActorNetwork: Layer {
    typealias Input = Tensor<Float>
    typealias Output = Tensor<Float>

    var l1, l2, l3: Dense<Float>

    init(observationSize: Int, hiddenSize: Int, actionCount: Int) {
        l1 = Dense<Float>(
            inputSize: observationSize,
            outputSize: hiddenSize,
            activation: tanh,
            weightInitializer: heNormal()
        )
        l2 = Dense<Float>(
            inputSize: hiddenSize,
            outputSize: hiddenSize,
            activation: tanh,
            weightInitializer: heNormal()
        )
        l3 = Dense<Float>(
            inputSize: hiddenSize,
            outputSize: actionCount,
            activation: softmax,
            weightInitializer: heNormal()
        )
    }

    @differentiable
    func callAsFunction(_ input: Input) -> Output {
        return input.sequenced(through: l1, l2, l3)
    }
}

/// The critic network that returns the estimated value of a state.
///
/// Actor-Critic methods have an actor network and a critic network. The critic network is used to
/// estimate the value of a state. With this value estimate, the critic can evaluate the actions
/// made by the actor.
struct CriticNetwork: Layer {
    typealias Input = Tensor<Float>
    typealias Output = Tensor<Float>

    var l1, l2, l3: Dense<Float>

    init(observationSize: Int, hiddenSize: Int) {
        l1 = Dense<Float>(
            inputSize: observationSize,
            outputSize: hiddenSize,
            activation: relu,
            weightInitializer: heNormal()
        )
        l2 = Dense<Float>(
            inputSize: hiddenSize,
            outputSize: hiddenSize,
            activation: relu,
            weightInitializer: heNormal()
        )
        l3 = Dense<Float>(
            inputSize: hiddenSize,
            outputSize: 1,
            weightInitializer: heNormal()
        )
    }

    @differentiable
    func callAsFunction(_ input: Input) -> Output {
        return input.sequenced(through: l1, l2, l3)
    }
}

/// The actor-critic that contains actor and critic networks for action selection and evaluation.
///
/// Weights are often shared between the actor network and the critic network, but in this
/// example, they are separate networks.
struct ActorCritic: Layer {
    var actorNetwork: ActorNetwork
    var criticNetwork: CriticNetwork

    init(observationSize: Int, hiddenSize: Int, actionCount: Int) {
        self.actorNetwork = ActorNetwork(
            observationSize: observationSize,
            hiddenSize: hiddenSize,
            actionCount: actionCount
        )
        self.criticNetwork = CriticNetwork(
            observationSize: observationSize,
            hiddenSize: hiddenSize
        )
    }

    @differentiable
    func callAsFunction(_ state: Tensor<Float>) -> Categorical<Int32> {
        precondition(state.rank == 2, "The input must be 2-D ([batch size, state size]).")
        let actionProbs = self.actorNetwork(state).flattened()
        let dist = Categorical<Int32>(probabilities: actionProbs)
        return dist
    }
}
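
For orientation, a minimal usage sketch of ActorCritic, assuming CartPole-style sizes (4 observation features, 2 actions); the actual sizes and hyperparameters are configured in Gym/PPO/main.swift.

import TensorFlow

// Illustrative sizes; the real values live in main.swift.
let model = ActorCritic(observationSize: 4, hiddenSize: 64, actionCount: 2)

// The model expects a 2-D input of shape [batch size, state size].
let state = Tensor<Float>([[0.01, -0.02, 0.03, 0.04]])
let dist = model(state)                      // Categorical<Int32> over the 2 actions
let action = dist.sample()                   // sampled action index
let stateValue = model.criticNetwork(state)  // estimated state value, shape [1, 1]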

Gym/PPO/Agent.swift

Lines changed: 148 additions & 0 deletions
@@ -0,0 +1,148 @@
// Copyright 2020 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

import PythonKit
import TensorFlow

/// Agent that uses Proximal Policy Optimization (PPO).
///
/// Proximal Policy Optimization is an algorithm that trains an actor (policy) and a critic (value
/// function) using a clipped objective function. The clipped objective function simplifies the
/// update equation from its predecessor Trust Region Policy Optimization (TRPO). For more
/// information, see "Proximal Policy Optimization Algorithms" (Schulman et al., 2017).
class PPOAgent {
    /// Cache of trajectory segments used for minibatch updates.
    var memory: PPOMemory
    /// The learning rate for both the actor and the critic.
    let learningRate: Float
    /// The discount factor that measures how much weight to give to future
    /// rewards when calculating the action value.
    let discount: Float
    /// Number of epochs to run minibatch updates once enough trajectory segments are collected.
    let epochs: Int
    /// Parameter to clip the probability ratio.
    let clipEpsilon: Float
    /// Coefficient for the entropy bonus added to the objective.
    let entropyCoefficient: Float

    var actorCritic: ActorCritic
    var oldActorCritic: ActorCritic
    var actorOptimizer: Adam<ActorNetwork>
    var criticOptimizer: Adam<CriticNetwork>

    init(
        observationSize: Int,
        hiddenSize: Int,
        actionCount: Int,
        learningRate: Float,
        discount: Float,
        epochs: Int,
        clipEpsilon: Float,
        entropyCoefficient: Float
    ) {
        self.learningRate = learningRate
        self.discount = discount
        self.epochs = epochs
        self.clipEpsilon = clipEpsilon
        self.entropyCoefficient = entropyCoefficient

        self.memory = PPOMemory()

        self.actorCritic = ActorCritic(
            observationSize: observationSize,
            hiddenSize: hiddenSize,
            actionCount: actionCount
        )
        self.oldActorCritic = self.actorCritic
        self.actorOptimizer = Adam(for: actorCritic.actorNetwork, learningRate: learningRate)
        self.criticOptimizer = Adam(for: actorCritic.criticNetwork, learningRate: learningRate)
    }

    /// Samples an action from the old policy, steps the environment, and stores the transition.
    func step(env: PythonObject, state: PythonObject) -> (PythonObject, Bool, Float) {
        let tfState: Tensor<Float> = Tensor<Float>(numpy: np.array([state], dtype: np.float32))!
        let dist: Categorical<Int32> = oldActorCritic(tfState)
        let action: Int32 = dist.sample().scalarized()
        let (newState, reward, isDone, _) = env.step(action).tuple4

        memory.append(
            state: Array(state)!,
            action: action,
            reward: Float(reward)!,
            logProb: dist.logProbabilities[Int(action)].scalarized(),
            isDone: Bool(isDone)!
        )

        return (newState, Bool(isDone)!, Float(reward)!)
    }

    /// Runs `epochs` of updates on the collected trajectory segments, then clears the memory.
    func update() {
        // Discount rewards for advantage estimation.
        var rewards: [Float] = []
        var discountedReward: Float = 0
        for i in (0..<memory.rewards.count).reversed() {
            if memory.isDones[i] {
                discountedReward = 0
            }
            discountedReward = memory.rewards[i] + (discount * discountedReward)
            rewards.insert(discountedReward, at: 0)
        }
        var tfRewards = Tensor<Float>(rewards)
        tfRewards = (tfRewards - tfRewards.mean()) / (tfRewards.standardDeviation() + 1e-5)

        // Retrieve stored states, actions, and log probabilities.
        let oldStates: Tensor<Float> = Tensor<Float>(numpy: np.array(memory.states, dtype: np.float32))!
        let oldActions: Tensor<Int32> = Tensor<Int32>(numpy: np.array(memory.actions, dtype: np.int32))!
        let oldLogProbs: Tensor<Float> = Tensor<Float>(numpy: np.array(memory.logProbs, dtype: np.float32))!

        // Optimize actor and critic.
        var actorLosses: [Float] = []
        var criticLosses: [Float] = []
        for _ in 0..<epochs {
            // Optimize policy network (actor).
            let (actorLoss, actorGradients) = valueWithGradient(at: self.actorCritic.actorNetwork) { actorNetwork -> Tensor<Float> in
                let npIndices = np.stack([np.arange(oldActions.shape[0], dtype: np.int32), oldActions.makeNumpyArray()], axis: 1)
                let tfIndices = Tensor<Int32>(numpy: npIndices)!
                let actionProbs = actorNetwork(oldStates).dimensionGathering(atIndices: tfIndices)

                let dist = Categorical<Int32>(probabilities: actionProbs)
                let stateValues = self.actorCritic.criticNetwork(oldStates).flattened()
                let ratios: Tensor<Float> = exp(dist.logProbabilities - oldLogProbs)

                let advantages: Tensor<Float> = tfRewards - stateValues
                let surrogateObjective = Tensor(stacking: [
                    ratios * advantages,
                    ratios.clipped(min: 1 - self.clipEpsilon, max: 1 + self.clipEpsilon) * advantages
                ]).min(alongAxes: 0).flattened()
                let entropyBonus: Tensor<Float> = Tensor<Float>(self.entropyCoefficient * dist.entropy())
                let loss: Tensor<Float> = -1 * (surrogateObjective + entropyBonus)

                return loss.mean()
            }
            self.actorOptimizer.update(&self.actorCritic.actorNetwork, along: actorGradients)
            actorLosses.append(actorLoss.scalarized())

            // Optimize value network (critic).
            let (criticLoss, criticGradients) = valueWithGradient(at: self.actorCritic.criticNetwork) { criticNetwork -> Tensor<Float> in
                let stateValues = criticNetwork(oldStates).flattened()
                let loss: Tensor<Float> = 0.5 * pow(stateValues - tfRewards, 2)

                return loss.mean()
            }
            self.criticOptimizer.update(&self.actorCritic.criticNetwork, along: criticGradients)
            criticLosses.append(criticLoss.scalarized())
        }
        self.oldActorCritic = self.actorCritic
        memory.removeAll()
    }
}
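
For context, PPOAgent.update() maximizes the clipped surrogate objective min(r * A, clip(r, 1 - epsilon, 1 + epsilon) * A) plus an entropy bonus, where r is the probability ratio between the new and old policies and A is the advantage estimate. A rough sketch of the loop that drives the agent follows; the real driver is Gym/PPO/main.swift (changed in this commit but not shown in this diff), so the environment, hyperparameter values, and update cadence below are illustrative assumptions.

import PythonKit

let gym = Python.import("gym")
let np = Python.import("numpy")  // Agent.swift refers to a module-level `np` like this one.

let env = gym.make("CartPole-v0")
let agent = PPOAgent(
    observationSize: 4,      // CartPole observation size
    hiddenSize: 64,          // illustrative
    actionCount: 2,
    learningRate: 3e-4,      // illustrative
    discount: 0.99,
    epochs: 10,
    clipEpsilon: 0.2,
    entropyCoefficient: 0.01
)

var stepCount = 0
for _ in 0..<1000 {
    var state = env.reset()
    var isDone = false
    while !isDone {
        let (nextState, done, _) = agent.step(env: env, state: state)
        state = nextState
        isDone = done
        stepCount += 1
        // Run minibatch updates once enough transitions have been collected.
        if stepCount % 1000 == 0 {
            agent.update()
        }
    }
}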

Gym/PPO/Categorical.swift

Lines changed: 122 additions & 0 deletions
@@ -0,0 +1,122 @@
// Copyright 2020 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

import TensorFlow

// Below code comes from eaplatanios/swift-rl:
// https://github.com/eaplatanios/swift-rl/blob/master/Sources/ReinforcementLearning/Utilities/Protocols.swift
public protocol Batchable {
    func flattenedBatch(outerDimCount: Int) -> Self
    func unflattenedBatch(outerDims: [Int]) -> Self
}

public protocol DifferentiableBatchable: Batchable, Differentiable {
    @differentiable(wrt: self)
    func flattenedBatch(outerDimCount: Int) -> Self

    @differentiable(wrt: self)
    func unflattenedBatch(outerDims: [Int]) -> Self
}

extension Tensor: Batchable {
    public func flattenedBatch(outerDimCount: Int) -> Tensor {
        if outerDimCount == 1 {
            return self
        }
        var newShape = [-1]
        for i in outerDimCount..<rank {
            newShape.append(shape[i])
        }
        return reshaped(to: TensorShape(newShape))
    }

    public func unflattenedBatch(outerDims: [Int]) -> Tensor {
        if rank > 1 {
            return reshaped(to: TensorShape(outerDims + shape.dimensions[1...]))
        }
        return reshaped(to: TensorShape(outerDims))
    }
}

extension Tensor: DifferentiableBatchable where Scalar: TensorFlowFloatingPoint {
    @differentiable(wrt: self)
    public func flattenedBatch(outerDimCount: Int) -> Tensor {
        if outerDimCount == 1 {
            return self
        }
        var newShape = [-1]
        for i in outerDimCount..<rank {
            newShape.append(shape[i])
        }
        return reshaped(to: TensorShape(newShape))
    }

    @differentiable(wrt: self)
    public func unflattenedBatch(outerDims: [Int]) -> Tensor {
        if rank > 1 {
            return reshaped(to: TensorShape(outerDims + shape.dimensions[1...]))
        }
        return reshaped(to: TensorShape(outerDims))
    }
}

// Below code comes from eaplatanios/swift-rl:
// https://github.com/eaplatanios/swift-rl/blob/master/Sources/ReinforcementLearning/Distributions/Distribution.swift
public protocol Distribution {
    associatedtype Value

    func entropy() -> Tensor<Float>

    /// Returns a random sample drawn from this distribution.
    func sample() -> Value
}

public protocol DifferentiableDistribution: Distribution, Differentiable {
    @differentiable(wrt: self)
    func entropy() -> Tensor<Float>
}

// Below code comes from eaplatanios/swift-rl:
// https://github.com/eaplatanios/swift-rl/blob/master/Sources/ReinforcementLearning/Distributions/Categorical.swift
public struct Categorical<Scalar: TensorFlowIndex>: DifferentiableDistribution, KeyPathIterable {
    /// Log-probabilities of this categorical distribution.
    public var logProbabilities: Tensor<Float>

    @inlinable
    @differentiable(wrt: probabilities)
    public init(probabilities: Tensor<Float>) {
        self.logProbabilities = log(probabilities)
    }

    @inlinable
    @differentiable(wrt: self)
    public func entropy() -> Tensor<Float> {
        -(logProbabilities * exp(logProbabilities)).sum(squeezingAxes: -1)
    }

    @inlinable
    public func sample() -> Tensor<Scalar> {
        let seed = Context.local.randomSeed
        let outerDimCount = self.logProbabilities.rank - 1
        let logProbabilities = self.logProbabilities.flattenedBatch(outerDimCount: outerDimCount)
        let multinomial: Tensor<Scalar> = _Raw.multinomial(
            logits: logProbabilities,
            numSamples: Tensor<Int32>(1),
            seed: Int64(seed.graph),
            seed2: Int64(seed.op))
        let flattenedSamples = multinomial.gathering(atIndices: Tensor<Int32>(0), alongAxis: 1)
        return flattenedSamples.unflattenedBatch(
            outerDims: [Int](self.logProbabilities.shape.dimensions[0..<outerDimCount]))
    }
}
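
A minimal, self-contained example of the Categorical distribution above, as it is used by ActorCritic and PPOAgent; the probabilities are made up for illustration.

import TensorFlow

// A distribution over 3 actions with probabilities 0.2, 0.3, 0.5.
let dist = Categorical<Int32>(probabilities: Tensor<Float>([0.2, 0.3, 0.5]))

let action = dist.sample()            // scalar Tensor<Int32> index in 0..<3
let logProb = dist.logProbabilities[Int(action.scalarized())]
let entropy = dist.entropy()          // scalar Tensor<Float>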
