Skip to content

Commit e935f1a

Browse files
committed
mult state
1 parent 9038a4a commit e935f1a

File tree

13 files changed

+1244
-115
lines changed

13 files changed

+1244
-115
lines changed

Assets/Resources/Prefabs/Birds/Bird_1.prefab

Lines changed: 1222 additions & 68 deletions
Large diffs are not rendered by default.

Assets/Resources/tfModels/ppo.bytes

2 KB
Binary file not shown.

Assets/Scene/3DBird.unity

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -903,7 +903,7 @@ Transform:
903903
m_PrefabInternal: {fileID: 0}
904904
m_GameObject: {fileID: 402347489}
905905
m_LocalRotation: {x: 0, y: 0, z: 0, w: 1}
906-
m_LocalPosition: {x: 0, y: 2, z: 0}
906+
m_LocalPosition: {x: 0, y: 1, z: 0}
907907
m_LocalScale: {x: 1, y: 1, z: 1}
908908
m_Children:
909909
- {fileID: 335634046}

Assets/Scripts/Core/BaseEnv.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ public abstract class BaseEnv : ScriptableObject
99
protected int[] last_state;
1010
protected int total_r = 0;
1111
protected BirdAction last_action = BirdAction.PAD;
12-
12+
1313
public int Score { get { return total_r; } }
1414

1515
protected abstract bool birdFly { get; }
@@ -35,7 +35,7 @@ void OnScore(object arg)
3535

3636
void OnDied(object arg)
3737
{
38-
last_r = -100;
38+
last_r = -2000;
3939
}
4040

4141
public virtual void OnApplicationQuit() { }
@@ -71,7 +71,7 @@ public virtual void OnTick()
7171
public abstract BirdAction choose_action(int[] state);
7272

7373
public abstract void UpdateState(int[] state, int[] state_, int rewd, BirdAction action);
74-
74+
7575
public virtual void OnRestart(int[] state) { }
7676

7777
public virtual void OnInspector() { }

Assets/Scripts/Core/InternalEnv.cs

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,6 @@ public override void OnTick()
4949
{
5050
UpdateState(last_state, state, last_r, last_action);
5151
}
52-
5352
//do next loop
5453
BirdAction action = choose_action(state);
5554
GameMgr.S.RespondByDecision(action);
@@ -63,10 +62,12 @@ public override BirdAction choose_action(int[] state)
6362
{
6463
#if TensorFlow
6564
var runner = session.GetRunner();
66-
float[,] sample = new float[1, 1];
67-
// sample[0, 0] = state;
68-
TFTensor t = new TFTensor(sample);
69-
runner.AddInput(graph["state"][0], t);
65+
float[,] fstate = new float[1, 3];
66+
for (int i = 0; i < state.Length; i++)
67+
{
68+
fstate[0, i] = state[i];
69+
}
70+
runner.AddInput(graph["state"][0], fstate);
7071
runner.Fetch(graph["pi/probweights"][0]);
7172
TFTensor[] networkOutput;
7273
try
@@ -82,12 +83,11 @@ public override BirdAction choose_action(int[] state)
8283
}
8384
finally
8485
{
85-
throw new System.Exception(errorMessage);
86+
throw new System.Exception(errorMessage + " \n" + e.StackTrace);
8687
}
8788
}
88-
// Debug.Log(networkOutput.Length);
8989
float[,] output = (float[,])networkOutput[0].GetValue();
90-
Debug.Log("choice action: " + output[0, 0] + " " + output[0, 1]);
90+
Debug.Log(string.Format("pi/probweights:{0},{1} ", output[0, 0], output[0, 1]));
9191
int rand = Random.Range(0, 100);
9292
return rand < (int)(output[0, 0] * 100) ? BirdAction.FLY : BirdAction.PAD;
9393
#else

Assets/Scripts/Editor/GameManagerEditor.cs renamed to Assets/Scripts/Editor/GameMgrEditor.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44

55
[CustomEditor(typeof(GameMgr))]
6-
public class GameManagerEditor : Editor
6+
public class GameMgrEditor : Editor
77
{
88
public override void OnInspectorGUI()
99
{

Assets/Scripts/Env/PillarManager.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ void CreatePillar()
4848
currPillar = pillar;
4949
pillar.transform.position = new Vector3(EnvGlobalValue.PillarBornX, 0, 0);
5050
pillar.transform.localScale = Vector3.one;
51-
int state = Random.Range(0, 2);
51+
int state = 1;//Random.Range(0, 2);
5252
pillar.SetState(state);
5353
run_pool.Add(pillar);
5454
}

Assets/Scripts/Gamer/GameMgr.cs

Lines changed: 2 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -29,13 +29,6 @@ public class GameMgr : MonoBehaviour
2929
public Bird mainBird;
3030
public PillarMgr pillMgr;
3131

32-
const float fpsMeasurePeriod = 0.5f;
33-
private int m_FpsAccumulator = 0;
34-
private float m_FpsNextPeriod = 0;
35-
private float m_TotalTime = 0f;
36-
private float m_CurrentFps;
37-
private float m_SignTime = 0;
38-
3932
public bool IsGameOver { get { return isGameOver; } }
4033

4134
public bool IsGameStart { get { return isGameStart; } }
@@ -70,17 +63,10 @@ void Awake()
7063
env.Init();
7164
}
7265

73-
7466
void OnGUI()
7567
{
76-
string str = string.Format("frame:{0} ", m_CurrentFps.ToString("f2"));
77-
GUI.Label(new Rect(20, 20, 100, 30), str, style);
78-
str = string.Format("runer:{0}", (Time.time - resetTime).ToString("f2"));
79-
GUI.Label(new Rect(20, 50, 100, 30), str, style);
80-
str = string.Format("epsln:{0}", epsilon);
81-
GUI.Label(new Rect(20, 80, 100, 30), str, style);
82-
str = string.Format("score:{0}", env.Score);
83-
GUI.Label(new Rect(20, 110, 100, 30), str, style);
68+
string str = string.Format("round:{0} timer:{1}", epsilon, (Time.time - resetTime).ToString("f2"));
69+
GUI.Label(new Rect(30, 30, 100, 30), str, style);
8470
}
8571

8672
void Update()
@@ -96,17 +82,6 @@ void Update()
9682
}
9783
env.OnUpdate(delta);
9884
pillMgr.Update(delta);
99-
100-
m_FpsAccumulator++;
101-
m_TotalTime += Time.realtimeSinceStartup - m_SignTime;
102-
m_SignTime = Time.realtimeSinceStartup;
103-
if (Time.realtimeSinceStartup > m_FpsNextPeriod)
104-
{
105-
m_CurrentFps = m_FpsAccumulator / m_TotalTime;
106-
m_TotalTime = 0;
107-
m_FpsAccumulator = 0;
108-
m_FpsNextPeriod = Time.realtimeSinceStartup + fpsMeasurePeriod;
109-
}
11085
}
11186

11287
public void ManuControl(bool fly)

ppo/brain.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,21 +85,21 @@ def _build_anet(self, name, trainable):
8585
with tf.variable_scope(name):
8686
l_1 = tf.layers.dense(self.tfs, 256, tf.nn.relu, trainable=trainable)
8787
a_prob = tf.layers.dense(l_1, A_DIM, tf.nn.softmax, trainable=trainable)
88+
tf.identity(a_prob, name='probweights')
8889
params = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=name)
8990
return a_prob, params
9091

9192
def choose_action(self, s):
9293
prob_weights = self.sess.run(self.pi, feed_dict={self.tfs: s[None, :]})
9394
action = np.random.choice(range(prob_weights.shape[1]), p=prob_weights.ravel())
94-
tf.identity(prob_weights, name='probweights')
9595
logger.info("action:{0} prob:{1}".format(str(action), str(prob_weights)))
9696
return action
9797

9898
def get_v(self, s):
9999
return self.sess.run(self.v, {self.tfs: s})[0, 0]
100100

101101
def output_nodes(self):
102-
return ["state", "action", "advantage", "critic/discounted_r", "probweights"]
102+
return ["state", "action", "advantage", "critic/discounted_r", "pi/probweights"]
103103

104104
def freeze_graph(self):
105105
logger.info('**** Saved Model ****')

ppo/environment.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
BATCH = 8
3535
EP_LEN = 200
3636
all_ep_r = []
37-
Train = True
37+
Train = False
3838

3939

4040
class UnityEnvironment(object):

ppo/model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def __init__(self):
2828
self.graph_def.ParseFromString(f.read())
2929
self.output = tf.import_graph_def(self.graph_def,
3030
input_map={'state:0': self.xstate},
31-
return_elements=['probweights:0'])
31+
return_elements=['pi/probweights:0'])
3232

3333
def update(self, s, a, r):
3434
self.sess.run(self.update_oldpi_op)

readme.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,8 @@ git clone https://github.com/huailiang/bird
2929
#切换到PolicyGradient
3030
git checkout PolicyGradient
3131

32-
#切换到ppo分支
33-
git checkout ppo
32+
#切换到mulstate分支
33+
git checkout mulstate
3434

3535
```
3636

0 commit comments

Comments
 (0)