-
Notifications
You must be signed in to change notification settings - Fork 1.4k
/
Copy pathcodelab_rebuild.yaml
609 lines (571 loc) · 23.1 KB
/
codelab_rebuild.yaml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
name: tfagents-flutter rebuild script
steps:
- name: step0
steps:
- name: Remove generated code
rmdir: step0/frontend
- name: Create project
flutter: create frontend
- name: Add deps
path: frontend
flutter: pub add http
- name: Strip DEVELOPMENT_TEAM
strip-lines-containing: DEVELOPMENT_TEAM =
path: frontend/ios/Runner.xcodeproj/project.pbxproj
- name: Configure analysis_options.yaml
path: frontend/analysis_options.yaml
replace-contents: |
include: ../../../analysis_options.yaml
analyzer:
errors:
unused_import: ignore
unused_field: ignore
unused_local_variable: ignore
- name: Replace lib/main.dart
path: frontend/lib/main.dart
replace-contents: |
import 'dart:async';
import 'dart:math';
import 'package:flutter/material.dart';
import 'game_agent.dart';
// Cell values shared by all boards. They are doubles because the boards are
// stored as List<List<double>>.

/// Hidden board: the cell is part of the plane.
const hiddenBoardCellOccupied = 1.0;

/// Hidden board: the cell is empty.
const hiddenBoardCellUnoccupied = 0.0;

/// Visible board: a strike on this cell hit the plane.
const visibleBoardCellHit = 1.0;

/// Visible board: a strike on this cell missed.
const visibleBoardCellMiss = -1.0;

/// Visible board: this cell has not been struck yet.
const visibleBoardCellUntried = 0.0;
/// Launches the Plane Strike sample app.
void main() => runApp(const PlaneStrike());
/// Root widget of the application, mounted by [main].
///
/// It holds no state of its own; the boards, scores and agent all live in
/// the [State] object returned by [createState].
class PlaneStrike extends StatefulWidget {
  const PlaneStrike({super.key});

  @override
  State<PlaneStrike> createState() {
    return _PlaneStrikeState();
  }
}
class _PlaneStrikeState extends State<PlaneStrike>
    with SingleTickerProviderStateMixin {
  // The board is square, so a single size describes both dimensions.
  final _boardSize = 8;

  // Number of cells that make up one 'plane': a 3-cell tail plus a 5-cell
  // cross (see _setHiddenBoardState).
  final _planePieceCount = 8;

  // Cells of the player's plane hit by the agent.
  late int _agentHitCount;
  // Cells of the agent's plane hit by the player.
  late int _playerHitCount;
  late TFAgentsAgent _policyGradientAgent;
  late List<List<double>> _agentVisibleBoardState;
  late List<List<double>> _agentHiddenBoardState;
  late List<List<double>> _playerVisibleBoardState;
  late List<List<double>> _playerHiddenBoardState;
  // Coordinates of the agent's latest strike; not assigned until the agent
  // starts taking actions (see the TODO in _gridItemTapped).
  late int _agentActionX;
  late int _agentActionY;

  // Non-null while a finished game is waiting for its automatic reset.
  // Taps are ignored during that window, and the timer is cancelled by any
  // other reset and by dispose(), so it can neither clear a newer game nor
  // call setState() on a disposed State.
  Timer? _pendingResetTimer;

  @override
  void initState() {
    super.initState();
    _resetGame();
  }

  @override
  void dispose() {
    _pendingResetTimer?.cancel();
    super.dispose();
  }

  @override
  Widget build(BuildContext context) {
    return MaterialApp(
      title: 'TFAgents Flutter Sample App',
      theme: ThemeData(
        colorScheme: ColorScheme.fromSeed(seedColor: Colors.deepPurple),
      ),
      home: _buildGameBody(),
    );
  }

  /// Returns a new `_boardSize` x `_boardSize` board with every cell set to 0.
  List<List<double>> _initEmptyBoard() =>
      List.generate(_boardSize, (_) => List.filled(_boardSize, 0));

  /// Starts a new game: clears both hit counters, creates a new agent and
  /// places a fresh, randomly positioned plane on each hidden board.
  void _resetGame() {
    // Drop any automatic reset still pending from the previous game.
    _pendingResetTimer?.cancel();
    _pendingResetTimer = null;
    _agentHitCount = 0;
    _playerHitCount = 0;
    _policyGradientAgent = TFAgentsAgent();
    // We keep track of 4 boards (2 for each player):
    // - _*VisibleBoardState tracks the strikes made so far (hit/miss/untried)
    // - _*HiddenBoardState is the secret board with the true plane location
    _agentVisibleBoardState = _initEmptyBoard();
    _agentHiddenBoardState = _setHiddenBoardState();
    _playerVisibleBoardState = _initEmptyBoard();
    _playerHiddenBoardState = _setHiddenBoardState();
  }

  /// Returns a new hidden board holding one randomly oriented plane: a
  /// 5-cell cross centred on a core cell plus a 3-cell tail.
  List<List<double>> _setHiddenBoardState() {
    final hiddenBoardState = _initEmptyBoard();
    // Place the plane on the board.
    // First, decide the plane's orientation:
    // 0: heading right
    // 1: heading up
    // 2: heading left
    // 3: heading down
    final rng = Random();
    final planeOrientation = rng.nextInt(4);
    // Figure out the location of the plane core as the '*' below
    // | | | | | ---
    // |-*- -*- -*-| |
    // | | | | | -*-
    // --- |
    // Each random range keeps every cross and tail cell inside the board.
    int planeCoreX, planeCoreY;
    switch (planeOrientation) {
      case 0:
        planeCoreX = rng.nextInt(_boardSize - 2) + 1;
        planeCoreY = rng.nextInt(_boardSize - 3) + 2;
        // Populate the tail
        hiddenBoardState[planeCoreX][planeCoreY - 2] = hiddenBoardCellOccupied;
        hiddenBoardState[planeCoreX - 1][planeCoreY - 2] =
            hiddenBoardCellOccupied;
        hiddenBoardState[planeCoreX + 1][planeCoreY - 2] =
            hiddenBoardCellOccupied;
      case 1:
        planeCoreX = rng.nextInt(_boardSize - 3) + 1;
        planeCoreY = rng.nextInt(_boardSize - 2) + 1;
        // Populate the tail
        hiddenBoardState[planeCoreX + 2][planeCoreY] = hiddenBoardCellOccupied;
        hiddenBoardState[planeCoreX + 2][planeCoreY + 1] =
            hiddenBoardCellOccupied;
        hiddenBoardState[planeCoreX + 2][planeCoreY - 1] =
            hiddenBoardCellOccupied;
      case 2:
        planeCoreX = rng.nextInt(_boardSize - 2) + 1;
        planeCoreY = rng.nextInt(_boardSize - 3) + 1;
        // Populate the tail
        hiddenBoardState[planeCoreX][planeCoreY + 2] = hiddenBoardCellOccupied;
        hiddenBoardState[planeCoreX - 1][planeCoreY + 2] =
            hiddenBoardCellOccupied;
        hiddenBoardState[planeCoreX + 1][planeCoreY + 2] =
            hiddenBoardCellOccupied;
      default:
        planeCoreX = rng.nextInt(_boardSize - 3) + 2;
        planeCoreY = rng.nextInt(_boardSize - 2) + 1;
        // Populate the tail
        hiddenBoardState[planeCoreX - 2][planeCoreY] = hiddenBoardCellOccupied;
        hiddenBoardState[planeCoreX - 2][planeCoreY + 1] =
            hiddenBoardCellOccupied;
        hiddenBoardState[planeCoreX - 2][planeCoreY - 1] =
            hiddenBoardCellOccupied;
    }
    // Populate the 'cross' in the plane
    hiddenBoardState[planeCoreX][planeCoreY] = hiddenBoardCellOccupied;
    hiddenBoardState[planeCoreX + 1][planeCoreY] = hiddenBoardCellOccupied;
    hiddenBoardState[planeCoreX - 1][planeCoreY] = hiddenBoardCellOccupied;
    hiddenBoardState[planeCoreX][planeCoreY + 1] = hiddenBoardCellOccupied;
    hiddenBoardState[planeCoreX][planeCoreY - 1] = hiddenBoardCellOccupied;
    return hiddenBoardState;
  }

  /// Builds the game screen: the agent's board, both hit counters, the
  /// player's board and a reset button.
  Widget _buildGameBody() {
    return Scaffold(
      appBar: AppBar(
        title: const Text('Plane Strike game based on TF Agents and Flutter'),
      ),
      body: SingleChildScrollView(
        child: Column(
          mainAxisSize: MainAxisSize.max,
          mainAxisAlignment: MainAxisAlignment.spaceBetween,
          children: [
            // The agent's board: the player strikes it by tapping a cell.
            _buildBoard(
              _buildAgentBoardItems,
              margin: const EdgeInsets.only(top: 10),
            ),
            Text(
              "Agent's board (hits: $_playerHitCount)",
              style: const TextStyle(
                fontSize: 18,
                color: Colors.blue,
                fontWeight: FontWeight.bold,
              ),
            ),
            const Divider(height: 20, thickness: 5, indent: 20, endIndent: 20),
            Text(
              'Your board (hits: $_agentHitCount)',
              style: const TextStyle(
                fontSize: 18,
                color: Colors.purple,
                fontWeight: FontWeight.bold,
              ),
            ),
            // The player's own board; untried plane cells are shown in green.
            _buildBoard(_buildPlayerBoardItems),
            Padding(
              padding: const EdgeInsets.only(top: 20),
              child: FilledButton(
                onPressed: () => setState(_resetGame),
                child: const Text('Reset game'),
              ),
            ),
          ],
        ),
      ),
    );
  }

  /// Builds one bordered, non-scrolling `_boardSize` x `_boardSize` grid
  /// whose cells are produced by [itemBuilder].
  Widget _buildBoard(
    IndexedWidgetBuilder itemBuilder, {
    EdgeInsetsGeometry? margin,
  }) {
    return Container(
      width: 265,
      height: 265,
      margin: margin,
      decoration: BoxDecoration(
        border: Border.all(color: Colors.black, width: 2.0),
      ),
      child: GridView.builder(
        gridDelegate: SliverGridDelegateWithFixedCrossAxisCount(
          crossAxisCount: _boardSize,
        ),
        itemBuilder: itemBuilder,
        itemCount: _boardSize * _boardSize,
        physics: const NeverScrollableScrollPhysics(),
      ),
    );
  }

  /// Builds cell [index] of the agent's board; tapping it is the player's
  /// strike on that cell.
  Widget _buildAgentBoardItems(BuildContext context, int index) {
    final x = index ~/ _boardSize;
    final y = index % _boardSize;
    return GestureDetector(
      onTap: () => unawaited(_gridItemTapped(context, x, y)),
      child: _buildBoardTile(x, y, 'agent'),
    );
  }

  /// Builds cell [index] of the player's board (display only).
  Widget _buildPlayerBoardItems(BuildContext context, int index) =>
      _buildBoardTile(index ~/ _boardSize, index % _boardSize, 'player');

  /// Wraps the colored cell at ([x], [y]) in a thin-bordered grid tile.
  Widget _buildBoardTile(int x, int y, String agentOrPlayer) {
    return GridTile(
      child: Container(
        decoration: BoxDecoration(
          border: Border.all(color: Colors.black, width: 0.5),
        ),
        child: Center(child: _buildGridItem(x, y, agentOrPlayer)),
      ),
    );
  }

  /// Returns the colored square for cell ([x], [y]) of the board selected by
  /// [agentOrPlayer] ('player' selects the player's boards, anything else
  /// the agent's).
  ///
  /// Red marks a hit and yellow a miss. An untried cell is green only when
  /// it belongs to the player's own plane; otherwise it is white.
  Widget _buildGridItem(int x, int y, String agentOrPlayer) {
    final isPlayer = agentOrPlayer == 'player';
    final boardState =
        isPlayer ? _playerVisibleBoardState : _agentVisibleBoardState;
    final hiddenBoardState =
        isPlayer ? _playerHiddenBoardState : _agentHiddenBoardState;
    final gridItemColor = switch (boardState[x][y].toInt()) {
      1 => Colors.red, // hit
      -1 => Colors.yellow, // miss
      _ => isPlayer && hiddenBoardState[x][y] == hiddenBoardCellOccupied
          ? Colors.green
          : Colors.white,
    };
    return Container(color: gridItemColor);
  }

  /// Applies the player's strike on cell ([x], [y]) of the agent's board,
  /// then checks whether the game has been decided.
  ///
  /// When it has, a snack bar announces the result and the game resets
  /// automatically after 2 seconds. Taps made while that reset is pending
  /// are ignored.
  Future<void> _gridItemTapped(BuildContext context, int x, int y) async {
    // The result is already decided; wait for the scheduled reset.
    if (_pendingResetTimer != null) return;
    if (_agentHiddenBoardState[x][y] == hiddenBoardCellOccupied) {
      // Non-repeat move
      if (_agentVisibleBoardState[x][y] == visibleBoardCellUntried) {
        _playerHitCount++;
      }
      _agentVisibleBoardState[x][y] = visibleBoardCellHit;
    } else {
      _agentVisibleBoardState[x][y] = visibleBoardCellMiss;
    }
    // TODO: add code for the agent to take an action
    String userPrompt = '';
    if (_playerHitCount == _planePieceCount &&
        _agentHitCount == _planePieceCount) {
      userPrompt = 'Draw game!';
    } else if (_agentHitCount == _planePieceCount) {
      userPrompt = 'Agent wins!';
    } else if (_playerHitCount == _planePieceCount) {
      userPrompt = 'You win!';
    }
    // Once an agent move is awaited above, the widget may have been disposed
    // in the meantime; never touch state or start a timer in that case.
    if (!mounted) return;
    // Repaint so the strike and the updated hit counter are visible.
    setState(() {});
    if (userPrompt.isNotEmpty) {
      _pendingResetTimer = Timer(
        const Duration(seconds: 2),
        () => setState(_resetGame),
      );
      if (!context.mounted) return;
      ScaffoldMessenger.of(context).showSnackBar(
        SnackBar(
          content: Text(userPrompt, textAlign: TextAlign.center),
          duration: const Duration(seconds: 2),
        ),
      );
    }
  }
}
- name: Add lib/game_agent.dart
path: frontend/lib/game_agent.dart
replace-contents: |
import 'dart:convert';
import 'dart:io' show Platform;
import 'package:flutter/foundation.dart' show kIsWeb;
import 'package:http/http.dart' as http;
// TODO: add class definition for inputs
class TFAgentsAgent {
TFAgentsAgent();
/// Returns the board cell the agent strikes next, for [boardState].
///
/// In the later codelab step the caller passes the player's visible board
/// and decodes the result as `row = action ~/ boardSize`,
/// `column = action % boardSize`.
///
/// Placeholder: the model-server request is not written yet (see the TODO
/// below), so this currently always returns 0.
Future<int> predict(List<List<double>> boardState) async {
// Host of the model server, as reachable from the running app.
// kIsWeb is tested first so Platform (dart:io) is never queried on web.
String server = '';
if (!kIsWeb && Platform.isAndroid) {
// For Android emulator
server = '10.0.2.2';
} else {
// For iOS emulator, desktop and web platforms
server = '127.0.0.1';
}
// TODO: add code to predict next strike position
return 0;
}
}
- name: Replace test/widget_test.dart
path: frontend/test/widget_test.dart
replace-contents: |
// This is a basic Flutter widget test.
//
// To perform an interaction with a widget in your test, use the WidgetTester
// utility that Flutter provides. For example, you can send tap and scroll
// gestures. You can also use WidgetTester to find child widgets in the widget
// tree, read text, and verify that the values of widget properties are correct.
import 'package:flutter_test/flutter_test.dart';
import 'package:frontend/main.dart';
void main() {
  testWidgets('Smoke test', (tester) async {
    // Build our app and trigger a frame.
    await tester.pumpWidget(const PlaneStrike());
    // A fresh game shows both boards with no hits recorded yet.
    expect(find.text("Agent's board (hits: 0)"), findsOneWidget);
    expect(find.text('Your board (hits: 0)'), findsOneWidget);
    // The reset control is available.
    expect(find.text('Reset game'), findsOneWidget);
  });
}
- name: Patch macos/Runner/DebugProfile.entitlements
path: frontend/macos/Runner/DebugProfile.entitlements
patch-u: |
--- b/frontend/finished/macos/Runner/DebugProfile.entitlements
+++ a/frontend/finished/macos/Runner/DebugProfile.entitlements
@@ -8,5 +8,7 @@
<true/>
<key>com.apple.security.network.server</key>
<true/>
+ <key>com.apple.security.network.client</key>
+ <true/>
</dict>
</plist>
- name: Patch macos/Runner/Release.entitlements
path: frontend/macos/Runner/Release.entitlements
patch-u: |
--- b/frontend/finished/macos/Runner/Release.entitlements
+++ a/frontend/finished/macos/Runner/Release.entitlements
@@ -4,5 +4,7 @@
<dict>
<key>com.apple.security.app-sandbox</key>
<true/>
+ <key>com.apple.security.network.client</key>
+ <true/>
</dict>
</plist>
- name: Copy step0
copydir:
from: frontend
to: step0/frontend
- name: Flutter clean
path: step0/frontend
flutter: clean
- name: step1
steps:
- name: Remove generated code
rmdir: step1/frontend
- name: Copy step1
copydir:
from: frontend
to: step1/frontend
- name: Flutter clean
path: step1/frontend
flutter: clean
- name: step2
steps:
- name: Remove generated code
rmdir: step2/frontend
- name: Copy step2
copydir:
from: frontend
to: step2/frontend
- name: Flutter clean
path: step2/frontend
flutter: clean
- name: step3
steps:
- name: Remove generated code
rmdir: step3/frontend
- name: Copy step3
copydir:
from: frontend
to: step3/frontend
- name: Flutter clean
path: step3/frontend
flutter: clean
- name: step4
steps:
- name: Remove generated code
rmdir: step4/frontend
- name: Copy step4
copydir:
from: frontend
to: step4/frontend
- name: Flutter clean
path: step4/frontend
flutter: clean
- name: step5
steps:
- name: Remove generated code
rmdir: step5/frontend
- name: Patch lib/main.dart
path: frontend/lib/main.dart
patch-u: |
--- b/tfagents-flutter/step5/frontend/lib/main.dart
+++ a/tfagents-flutter/step5/frontend/lib/main.dart
@@ -297,7 +297,26 @@ class _PlaneStrikeState extends State<PlaneStrike>
_agentVisibleBoardState[x][y] = visibleBoardCellMiss;
}
- // TODO: add code for the agent to take an action
+ // Agent takes action
+ int agentAction = await _policyGradientAgent.predict(
+ _playerVisibleBoardState,
+ );
+ _agentActionX = agentAction ~/ _boardSize;
+ _agentActionY = agentAction % _boardSize;
+ if (_playerHiddenBoardState[_agentActionX][_agentActionY] ==
+ hiddenBoardCellOccupied) {
+ // Non-repeat move
+ if (_playerVisibleBoardState[_agentActionX][_agentActionY] ==
+ visibleBoardCellUntried) {
+ _agentHitCount++;
+ }
+ _playerVisibleBoardState[_agentActionX][_agentActionY] =
+ visibleBoardCellHit;
+ } else {
+ _playerVisibleBoardState[_agentActionX][_agentActionY] =
+ visibleBoardCellMiss;
+ }
+ setState(() {});
String userPrompt = '';
if (_playerHitCount == _planePieceCount &&
- name: Patch lib/game_agent.dart
path: frontend/lib/game_agent.dart
patch-u: |
--- b/tfagents-flutter/step5/frontend/lib/game_agent.dart
+++ a/tfagents-flutter/step5/frontend/lib/game_agent.dart
@@ -4,7 +4,19 @@ import 'dart:io' show Platform;
import 'package:flutter/foundation.dart' show kIsWeb;
import 'package:http/http.dart' as http;
-// TODO: add class definition for inputs
+class Inputs {
+ final List<double> _boardState;
+ Inputs(this._boardState);
+
+ Map<String, dynamic> toJson() {
+ final Map<String, dynamic> data = <String, dynamic>{};
+ data['0/discount'] = [0.0];
+ data['0/observation'] = [_boardState];
+ data['0/reward'] = [0.0];
+ data['0/step_type'] = [0];
+ return data;
+ }
+}
class TFAgentsAgent {
TFAgentsAgent();
@@ -19,7 +31,22 @@ class TFAgentsAgent {
server = '127.0.0.1';
}
- // TODO: add code to predict next strike position
- return 0;
+ var flattenedBoardState = boardState.expand((i) => i).toList();
+ final response = await http.post(
+ Uri.parse('http://$server:8501/v1/models/policy_model:predict'),
+ body: jsonEncode(<String, dynamic>{
+ 'signature_name': 'action',
+ 'instances': [Inputs(flattenedBoardState)],
+ }),
+ );
+
+ if (response.statusCode == 200) {
+ var output = List<int>.from(
+ jsonDecode(response.body)['predictions'] as List<dynamic>,
+ );
+ return output[0];
+ } else {
+ throw Exception('Error response');
+ }
}
}
- name: Copy step5
copydir:
from: frontend
to: step5/frontend
- name: Flutter clean
path: step5/frontend
flutter: clean
- name: step6
steps:
- name: Remove generated code
rmdir: step6/frontend
- name: Copy step6
copydir:
from: frontend
to: step6/frontend
- name: Flutter clean
path: step6/frontend
flutter: clean
- name: finished
steps:
- name: Remove generated code
rmdir: finished/frontend
- name: Patch analysis_options.yaml
path: frontend/analysis_options.yaml
patch-u: |
--- b/tfagents-flutter/finished/frontend/analysis_options.yaml
+++ a/tfagents-flutter/finished/frontend/analysis_options.yaml
@@ -1,7 +1 @@
include: ../../../analysis_options.yaml
-
-analyzer:
- errors:
- unused_import: ignore
- unused_field: ignore
- unused_local_variable: ignore
- name: Copy finished
copydir:
from: frontend
to: finished/frontend
- name: Flutter clean
path: finished/frontend
flutter: clean
- name: Cleanup
rmdir: frontend