Skip to content

Commit

Permalink
mult state
Browse files Browse the repository at this point in the history
  • Loading branch information
huailiang committed Dec 8, 2018
1 parent 9038a4a commit e935f1a
Show file tree
Hide file tree
Showing 13 changed files with 1,244 additions and 115 deletions.
1,290 changes: 1,222 additions & 68 deletions Assets/Resources/Prefabs/Birds/Bird_1.prefab

Large diffs are not rendered by default.

Binary file modified Assets/Resources/tfModels/ppo.bytes
Binary file not shown.
2 changes: 1 addition & 1 deletion Assets/Scene/3DBird.unity
Original file line number Diff line number Diff line change
Expand Up @@ -903,7 +903,7 @@ Transform:
m_PrefabInternal: {fileID: 0}
m_GameObject: {fileID: 402347489}
m_LocalRotation: {x: 0, y: 0, z: 0, w: 1}
m_LocalPosition: {x: 0, y: 2, z: 0}
m_LocalPosition: {x: 0, y: 1, z: 0}
m_LocalScale: {x: 1, y: 1, z: 1}
m_Children:
- {fileID: 335634046}
Expand Down
6 changes: 3 additions & 3 deletions Assets/Scripts/Core/BaseEnv.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ public abstract class BaseEnv : ScriptableObject
protected int[] last_state;
protected int total_r = 0;
protected BirdAction last_action = BirdAction.PAD;

public int Score { get { return total_r; } }

protected abstract bool birdFly { get; }
Expand All @@ -35,7 +35,7 @@ void OnScore(object arg)

void OnDied(object arg)
{
last_r = -100;
last_r = -2000;
}

public virtual void OnApplicationQuit() { }
Expand Down Expand Up @@ -71,7 +71,7 @@ public virtual void OnTick()
public abstract BirdAction choose_action(int[] state);

public abstract void UpdateState(int[] state, int[] state_, int rewd, BirdAction action);

public virtual void OnRestart(int[] state) { }

public virtual void OnInspector() { }
Expand Down
16 changes: 8 additions & 8 deletions Assets/Scripts/Core/InternalEnv.cs
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@ public override void OnTick()
{
UpdateState(last_state, state, last_r, last_action);
}

//do next loop
BirdAction action = choose_action(state);
GameMgr.S.RespondByDecision(action);
Expand All @@ -63,10 +62,12 @@ public override BirdAction choose_action(int[] state)
{
#if TensorFlow
var runner = session.GetRunner();
float[,] sample = new float[1, 1];
// sample[0, 0] = state;
TFTensor t = new TFTensor(sample);
runner.AddInput(graph["state"][0], t);
float[,] fstate = new float[1, 3];
for (int i = 0; i < state.Length; i++)
{
fstate[0, i] = state[i];
}
runner.AddInput(graph["state"][0], fstate);
runner.Fetch(graph["pi/probweights"][0]);
TFTensor[] networkOutput;
try
Expand All @@ -82,12 +83,11 @@ public override BirdAction choose_action(int[] state)
}
finally
{
throw new System.Exception(errorMessage);
throw new System.Exception(errorMessage + " \n" + e.StackTrace);
}
}
// Debug.Log(networkOutput.Length);
float[,] output = (float[,])networkOutput[0].GetValue();
Debug.Log("choice action: " + output[0, 0] + " " + output[0, 1]);
Debug.Log(string.Format("pi/probweights:{0},{1} ", output[0, 0], output[0, 1]));
int rand = Random.Range(0, 100);
return rand < (int)(output[0, 0] * 100) ? BirdAction.FLY : BirdAction.PAD;
#else
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@


[CustomEditor(typeof(GameMgr))]
public class GameManagerEditor : Editor
public class GameMgrEditor : Editor
{
public override void OnInspectorGUI()
{
Expand Down
File renamed without changes.
2 changes: 1 addition & 1 deletion Assets/Scripts/Env/PillarManager.cs
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ void CreatePillar()
currPillar = pillar;
pillar.transform.position = new Vector3(EnvGlobalValue.PillarBornX, 0, 0);
pillar.transform.localScale = Vector3.one;
int state = Random.Range(0, 2);
int state = 1;//Random.Range(0, 2);
pillar.SetState(state);
run_pool.Add(pillar);
}
Expand Down
29 changes: 2 additions & 27 deletions Assets/Scripts/Gamer/GameMgr.cs
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,6 @@ public class GameMgr : MonoBehaviour
public Bird mainBird;
public PillarMgr pillMgr;

const float fpsMeasurePeriod = 0.5f;
private int m_FpsAccumulator = 0;
private float m_FpsNextPeriod = 0;
private float m_TotalTime = 0f;
private float m_CurrentFps;
private float m_SignTime = 0;

public bool IsGameOver { get { return isGameOver; } }

public bool IsGameStart { get { return isGameStart; } }
Expand Down Expand Up @@ -70,17 +63,10 @@ void Awake()
env.Init();
}


void OnGUI()
{
string str = string.Format("frame:{0} ", m_CurrentFps.ToString("f2"));
GUI.Label(new Rect(20, 20, 100, 30), str, style);
str = string.Format("runer:{0}", (Time.time - resetTime).ToString("f2"));
GUI.Label(new Rect(20, 50, 100, 30), str, style);
str = string.Format("epsln:{0}", epsilon);
GUI.Label(new Rect(20, 80, 100, 30), str, style);
str = string.Format("score:{0}", env.Score);
GUI.Label(new Rect(20, 110, 100, 30), str, style);
string str = string.Format("round:{0} timer:{1}", epsilon, (Time.time - resetTime).ToString("f2"));
GUI.Label(new Rect(30, 30, 100, 30), str, style);
}

void Update()
Expand All @@ -96,17 +82,6 @@ void Update()
}
env.OnUpdate(delta);
pillMgr.Update(delta);

m_FpsAccumulator++;
m_TotalTime += Time.realtimeSinceStartup - m_SignTime;
m_SignTime = Time.realtimeSinceStartup;
if (Time.realtimeSinceStartup > m_FpsNextPeriod)
{
m_CurrentFps = m_FpsAccumulator / m_TotalTime;
m_TotalTime = 0;
m_FpsAccumulator = 0;
m_FpsNextPeriod = Time.realtimeSinceStartup + fpsMeasurePeriod;
}
}

public void ManuControl(bool fly)
Expand Down
4 changes: 2 additions & 2 deletions ppo/brain.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,21 +85,21 @@ def _build_anet(self, name, trainable):
with tf.variable_scope(name):
l_1 = tf.layers.dense(self.tfs, 256, tf.nn.relu, trainable=trainable)
a_prob = tf.layers.dense(l_1, A_DIM, tf.nn.softmax, trainable=trainable)
tf.identity(a_prob, name='probweights')
params = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=name)
return a_prob, params

def choose_action(self, s):
prob_weights = self.sess.run(self.pi, feed_dict={self.tfs: s[None, :]})
action = np.random.choice(range(prob_weights.shape[1]), p=prob_weights.ravel())
tf.identity(prob_weights, name='probweights')
logger.info("action:{0} prob:{1}".format(str(action), str(prob_weights)))
return action

def get_v(self, s):
return self.sess.run(self.v, {self.tfs: s})[0, 0]

def output_nodes(self):
return ["state", "action", "advantage", "critic/discounted_r", "probweights"]
return ["state", "action", "advantage", "critic/discounted_r", "pi/probweights"]

def freeze_graph(self):
logger.info('**** Saved Model ****')
Expand Down
2 changes: 1 addition & 1 deletion ppo/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
BATCH = 8
EP_LEN = 200
all_ep_r = []
Train = True
Train = False


class UnityEnvironment(object):
Expand Down
2 changes: 1 addition & 1 deletion ppo/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def __init__(self):
self.graph_def.ParseFromString(f.read())
self.output = tf.import_graph_def(self.graph_def,
input_map={'state:0': self.xstate},
return_elements=['probweights:0'])
return_elements=['pi/probweights:0'])

def update(self, s, a, r):
self.sess.run(self.update_oldpi_op)
Expand Down
4 changes: 2 additions & 2 deletions readme.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ git clone https://github.com/huailiang/bird
#切换到PolicyGradient
git checkout PolicyGradient

#切换到ppo分支
git checkout ppo
#切换到mulstate分支
git checkout mulstate

```

Expand Down

0 comments on commit e935f1a

Please sign in to comment.