8
8
#include "stb_image_write.h"
9
9
10
10
#define NN_IMPLEMENTATION
11
+ #define NN_ENABLE_GYM
11
12
#include "nn.h"
12
13
13
14
char * args_shift (int * argc , char * * * argv )
@@ -19,92 +20,6 @@ char *args_shift(int *argc, char ***argv)
19
20
return result ;
20
21
}
21
22
22
- typedef struct {
23
- float * items ;
24
- size_t count ;
25
- size_t capacity ;
26
- } Cost_Plot ;
27
-
28
- #define DA_INIT_CAP 256
29
- #define da_append (da , item ) \
30
- do { \
31
- if ((da)->count >= (da)->capacity) { \
32
- (da)->capacity = (da)->capacity == 0 ? DA_INIT_CAP : (da)->capacity*2; \
33
- (da)->items = realloc((da)->items, (da)->capacity*sizeof(*(da)->items)); \
34
- assert((da)->items != NULL && "Buy more RAM lol"); \
35
- } \
36
- \
37
- (da)->items[(da)->count++] = (item); \
38
- } while (0)
39
-
40
- void nn_render_raylib (NN nn , float rx , float ry , float rw , float rh )
41
- {
42
- Color low_color = {0xFF , 0x00 , 0xFF , 0xFF };
43
- Color high_color = {0x00 , 0xFF , 0x00 , 0xFF };
44
-
45
- float neuron_radius = rh * 0.03 ;
46
- float layer_border_vpad = rh * 0.08 ;
47
- float layer_border_hpad = rw * 0.06 ;
48
- float nn_width = rw - 2 * layer_border_hpad ;
49
- float nn_height = rh - 2 * layer_border_vpad ;
50
- float nn_x = rx + rw /2 - nn_width /2 ;
51
- float nn_y = ry + rh /2 - nn_height /2 ;
52
- size_t arch_count = nn .count + 1 ;
53
- float layer_hpad = nn_width / arch_count ;
54
- for (size_t l = 0 ; l < arch_count ; ++ l ) {
55
- float layer_vpad1 = nn_height / nn .as [l ].cols ;
56
- for (size_t i = 0 ; i < nn .as [l ].cols ; ++ i ) {
57
- float cx1 = nn_x + l * layer_hpad + layer_hpad /2 ;
58
- float cy1 = nn_y + i * layer_vpad1 + layer_vpad1 /2 ;
59
- if (l + 1 < arch_count ) {
60
- float layer_vpad2 = nn_height / nn .as [l + 1 ].cols ;
61
- for (size_t j = 0 ; j < nn .as [l + 1 ].cols ; ++ j ) {
62
- // i - rows of ws
63
- // j - cols of ws
64
- float cx2 = nn_x + (l + 1 )* layer_hpad + layer_hpad /2 ;
65
- float cy2 = nn_y + j * layer_vpad2 + layer_vpad2 /2 ;
66
- float value = sigmoidf (MAT_AT (nn .ws [l ], i , j ));
67
- high_color .a = floorf (255.f * value );
68
- float thick = rh * 0.004f ;
69
- Vector2 start = {cx1 , cy1 };
70
- Vector2 end = {cx2 , cy2 };
71
- DrawLineEx (start , end , thick , ColorAlphaBlend (low_color , high_color , WHITE ));
72
- }
73
- }
74
- if (l > 0 ) {
75
- high_color .a = floorf (255.f * sigmoidf (MAT_AT (nn .bs [l - 1 ], 0 , i )));
76
- DrawCircle (cx1 , cy1 , neuron_radius , ColorAlphaBlend (low_color , high_color , WHITE ));
77
- } else {
78
- DrawCircle (cx1 , cy1 , neuron_radius , GRAY );
79
- }
80
- }
81
- }
82
- }
83
-
84
- void plot_cost (Cost_Plot plot , int rx , int ry , int rw , int rh )
85
- {
86
- float min = FLT_MAX , max = FLT_MIN ;
87
- for (size_t i = 0 ; i < plot .count ; ++ i ) {
88
- if (max < plot .items [i ]) max = plot .items [i ];
89
- if (min > plot .items [i ]) min = plot .items [i ];
90
- }
91
-
92
- if (min > 0 ) min = 0 ;
93
- size_t n = plot .count ;
94
- if (n < 1000 ) n = 1000 ;
95
- for (size_t i = 0 ; i + 1 < plot .count ; ++ i ) {
96
- float x1 = rx + (float )rw /n * i ;
97
- float y1 = ry + (1 - (plot .items [i ] - min )/(max - min ))* rh ;
98
- float x2 = rx + (float )rw /n * (i + 1 );
99
- float y2 = ry + (1 - (plot .items [i + 1 ] - min )/(max - min ))* rh ;
100
- DrawLineEx ((Vector2 ){x1 , y1 }, (Vector2 ){x2 , y2 }, rh * 0.005 , RED );
101
- }
102
-
103
- float y0 = ry + (1 - (0 - min )/(max - min ))* rh ;
104
- DrawLineEx ((Vector2 ){rx + 0 , y0 }, (Vector2 ){rx + rw - 1 , y0 }, rh * 0.005 , WHITE );
105
- DrawText ("0" , rx + 0 , y0 - rh * 0.04 , rh * 0.04 , WHITE );
106
- }
107
-
108
23
int main (int argc , char * * argv )
109
24
{
110
25
const char * program = args_shift (& argc , & argv );
@@ -156,7 +71,7 @@ int main(int argc, char **argv)
156
71
// MAT_PRINT(ti);
157
72
// MAT_PRINT(to);
158
73
159
- size_t arch [] = {2 , 7 , 5 , 1 };
74
+ size_t arch [] = {2 , 7 , 7 , 1 };
160
75
NN nn = nn_alloc (arch , ARRAY_LEN (arch ));
161
76
NN g = nn_alloc (arch , ARRAY_LEN (arch ));
162
77
nn_rand (nn , -1 , 1 );
@@ -169,7 +84,7 @@ int main(int argc, char **argv)
169
84
InitWindow (WINDOW_WIDTH , WINDOW_HEIGHT , "gym" );
170
85
SetTargetFPS (60 );
171
86
172
- Cost_Plot plot = {0 };
87
+ Plot plot = {0 };
173
88
174
89
Image preview_image = GenImageColor (img_width , img_height , BLACK );
175
90
Texture2D preview_texture = LoadTextureFromImage (preview_image );
@@ -217,9 +132,9 @@ int main(int argc, char **argv)
217
132
int rx = 0 ;
218
133
int ry = h /2 - rh /2 ;
219
134
220
- plot_cost (plot , rx , ry , rw , rh );
135
+ gym_plot (plot , rx , ry , rw , rh );
221
136
rx += rw ;
222
- nn_render_raylib (nn , rx , ry , rw , rh );
137
+ gym_render_nn (nn , rx , ry , rw , rh );
223
138
rx += rw ;
224
139
225
140
float scale = 10 ;
0 commit comments