-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathrpsTemplate.js
109 lines (97 loc) · 1.97 KB
/
rpsTemplate.js
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
// Shared definitions for the rock-paper-scissors agent.

// The three playable moves. The order matches the one-hot positions produced
// by convertToOneHot and is used as the x-axis of the probability plots.
// Declared explicitly: a bare assignment would create an implicit global and
// throws a ReferenceError in strict mode / ES modules.
const moves = ["ROCK", "PAPER", "SCISSORS"];

// Move state shared between functions; nothing in this file assigns these yet.
// NOTE(review): presumably myMove is the user's move and opMove / opMoveIdx
// are the agent's move and its index into `moves` - confirm against callers.
var myMove;
var opMove;
var opMoveIdx;

// TensorFlow.js sequential model used by plotProbs; no layers are added here.
const model = tf.sequential();
/**
 * Set up the model parameters and draw the initial probability charts.
 *
 * Template stub: currently does nothing and returns undefined.
 *
 * Params: None
 * Return: None
 */
function init(){
  // TODO: create model class
  // TODO: plot probs
}
/**
 * Pick the agent's move for the current phase.
 *
 * Intended behavior:
 *   - Training phase: pick an action at random.
 *   - Evaluation phase: pick the action with the highest probability (argmax).
 *
 * Template stub: currently does nothing and returns undefined.
 *
 * Params: move - My Move
 * Return: Agent Move
 */
function chooseMove(move){
  // TODO: check if training or evaluate
  // TODO: if training choose randomly
  // TODO: if eval get argmax of the move
}
/**
 * Plot, for each of my possible moves, the model's probability of choosing
 * each move in response.
 *
 * One bar chart is drawn per move, into the elements 'div1', 'div2' and
 * 'div3' (in the order of `moves`).
 *
 * Params: None
 * Return: None
 */
function plotProbs(){
  const divs = ['div1', 'div2', 'div3'];
  divs.forEach((div, i) => {
    const move = moves[i];
    // tf.tidy disposes every tensor created inside the callback (input,
    // logits, softmax output) so repeated plotting does not leak memory;
    // only the plain number array leaves the callback.
    const probs = tf.tidy(() => {
      const xs = tf.tensor2d(convertToOneHot(move), [1, 3]);
      const logits = model.predict(xs);
      // Softmax runs over the last axis; [0] selects the single batch row.
      return tf.softmax(logits).arraySync()[0];
    });
    const data = [
      {
        x: moves,
        y: probs,
        type: 'bar'
      }
    ];
    const layout = {
      title: `What should I play against ${move}?`,
      width: 450,
      height: 300
    };
    Plotly.newPlot(div, data, layout);
  });
}
/**
 * Update the model using the reward supplied by the user.
 *
 * Template stub: currently does nothing and returns undefined.
 *
 * Params: reward
 * Return: None
 */
function train(reward){
  // TODO: check phase
  // TODO: if phase is train
  //   - convert my move to one hot
  //   - pass through network
  //   - update the model
  //   - plot probs after await is done
}
/**
 * Converts a move into a one-hot vector.
 *
 * Positions are [ROCK, PAPER, SCISSORS]. A new array is returned on every
 * call, so callers may mutate the result.
 *
 * Params: move - one of "ROCK", "PAPER", "SCISSORS"
 * Return: One-Hot-Vector (array of three numbers)
 * Throws: Error if the move is not one of the three known moves
 */
function convertToOneHot(move){
  // switch compares with strict equality, so only the exact strings match.
  switch (move) {
    case "ROCK":
      return [1, 0, 0];
    case "PAPER":
      return [0, 1, 0];
    case "SCISSORS":
      return [0, 0, 1];
    default:
      // Fail loudly instead of returning undefined, which would only surface
      // later as a confusing error when the vector is used.
      throw new Error(`Unknown move: ${String(move)}`);
  }
}
/**
 * Find the position of the largest entry in an array.
 *
 * When several entries share the largest value, the first one wins.
 * An empty array yields 0.
 *
 * Params: values - Array of values
 * Return: Index of the max value
 */
function getMaxIndex(values){
  let bestIdx = 0;
  for (let idx = 1; idx < values.length; idx += 1) {
    // Strict ">" keeps the earliest index on ties.
    if (values[idx] > values[bestIdx]) {
      bestIdx = idx;
    }
  }
  return bestIdx;
}