Skip to content

Commit 25cf0b0

Browse files
committed
Make gym part of nn.h
1 parent 89aa3e7 commit 25cf0b0

File tree

2 files changed

+102
-90
lines changed

2 files changed

+102
-90
lines changed

img2nn.c

+5-90
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#include "stb_image_write.h"
99

1010
#define NN_IMPLEMENTATION
11+
#define NN_ENABLE_GYM
1112
#include "nn.h"
1213

1314
char *args_shift(int *argc, char ***argv)
@@ -19,92 +20,6 @@ char *args_shift(int *argc, char ***argv)
1920
return result;
2021
}
2122

22-
typedef struct {
23-
float *items;
24-
size_t count;
25-
size_t capacity;
26-
} Cost_Plot;
27-
28-
#define DA_INIT_CAP 256
29-
#define da_append(da, item) \
30-
do { \
31-
if ((da)->count >= (da)->capacity) { \
32-
(da)->capacity = (da)->capacity == 0 ? DA_INIT_CAP : (da)->capacity*2; \
33-
(da)->items = realloc((da)->items, (da)->capacity*sizeof(*(da)->items)); \
34-
assert((da)->items != NULL && "Buy more RAM lol"); \
35-
} \
36-
\
37-
(da)->items[(da)->count++] = (item); \
38-
} while (0)
39-
40-
void nn_render_raylib(NN nn, float rx, float ry, float rw, float rh)
41-
{
42-
Color low_color = {0xFF, 0x00, 0xFF, 0xFF};
43-
Color high_color = {0x00, 0xFF, 0x00, 0xFF};
44-
45-
float neuron_radius = rh*0.03;
46-
float layer_border_vpad = rh*0.08;
47-
float layer_border_hpad = rw*0.06;
48-
float nn_width = rw - 2*layer_border_hpad;
49-
float nn_height = rh - 2*layer_border_vpad;
50-
float nn_x = rx + rw/2 - nn_width/2;
51-
float nn_y = ry + rh/2 - nn_height/2;
52-
size_t arch_count = nn.count + 1;
53-
float layer_hpad = nn_width / arch_count;
54-
for (size_t l = 0; l < arch_count; ++l) {
55-
float layer_vpad1 = nn_height / nn.as[l].cols;
56-
for (size_t i = 0; i < nn.as[l].cols; ++i) {
57-
float cx1 = nn_x + l*layer_hpad + layer_hpad/2;
58-
float cy1 = nn_y + i*layer_vpad1 + layer_vpad1/2;
59-
if (l+1 < arch_count) {
60-
float layer_vpad2 = nn_height / nn.as[l+1].cols;
61-
for (size_t j = 0; j < nn.as[l+1].cols; ++j) {
62-
// i - rows of ws
63-
// j - cols of ws
64-
float cx2 = nn_x + (l+1)*layer_hpad + layer_hpad/2;
65-
float cy2 = nn_y + j*layer_vpad2 + layer_vpad2/2;
66-
float value = sigmoidf(MAT_AT(nn.ws[l], i, j));
67-
high_color.a = floorf(255.f*value);
68-
float thick = rh*0.004f;
69-
Vector2 start = {cx1, cy1};
70-
Vector2 end = {cx2, cy2};
71-
DrawLineEx(start, end, thick, ColorAlphaBlend(low_color, high_color, WHITE));
72-
}
73-
}
74-
if (l > 0) {
75-
high_color.a = floorf(255.f*sigmoidf(MAT_AT(nn.bs[l-1], 0, i)));
76-
DrawCircle(cx1, cy1, neuron_radius, ColorAlphaBlend(low_color, high_color, WHITE));
77-
} else {
78-
DrawCircle(cx1, cy1, neuron_radius, GRAY);
79-
}
80-
}
81-
}
82-
}
83-
84-
void plot_cost(Cost_Plot plot, int rx, int ry, int rw, int rh)
85-
{
86-
float min = FLT_MAX, max = FLT_MIN;
87-
for (size_t i = 0; i < plot.count; ++i) {
88-
if (max < plot.items[i]) max = plot.items[i];
89-
if (min > plot.items[i]) min = plot.items[i];
90-
}
91-
92-
if (min > 0) min = 0;
93-
size_t n = plot.count;
94-
if (n < 1000) n = 1000;
95-
for (size_t i = 0; i+1 < plot.count; ++i) {
96-
float x1 = rx + (float)rw/n*i;
97-
float y1 = ry + (1 - (plot.items[i] - min)/(max - min))*rh;
98-
float x2 = rx + (float)rw/n*(i+1);
99-
float y2 = ry + (1 - (plot.items[i+1] - min)/(max - min))*rh;
100-
DrawLineEx((Vector2){x1, y1}, (Vector2){x2, y2}, rh*0.005, RED);
101-
}
102-
103-
float y0 = ry + (1 - (0 - min)/(max - min))*rh;
104-
DrawLineEx((Vector2){rx + 0, y0}, (Vector2){rx + rw - 1, y0}, rh*0.005, WHITE);
105-
DrawText("0", rx + 0, y0 - rh*0.04, rh*0.04, WHITE);
106-
}
107-
10823
int main(int argc, char **argv)
10924
{
11025
const char *program = args_shift(&argc, &argv);
@@ -156,7 +71,7 @@ int main(int argc, char **argv)
15671
// MAT_PRINT(ti);
15772
// MAT_PRINT(to);
15873

159-
size_t arch[] = {2, 7, 5, 1};
74+
size_t arch[] = {2, 7, 7, 1};
16075
NN nn = nn_alloc(arch, ARRAY_LEN(arch));
16176
NN g = nn_alloc(arch, ARRAY_LEN(arch));
16277
nn_rand(nn, -1, 1);
@@ -169,7 +84,7 @@ int main(int argc, char **argv)
16984
InitWindow(WINDOW_WIDTH, WINDOW_HEIGHT, "gym");
17085
SetTargetFPS(60);
17186

172-
Cost_Plot plot = {0};
87+
Plot plot = {0};
17388

17489
Image preview_image = GenImageColor(img_width, img_height, BLACK);
17590
Texture2D preview_texture = LoadTextureFromImage(preview_image);
@@ -217,9 +132,9 @@ int main(int argc, char **argv)
217132
int rx = 0;
218133
int ry = h/2 - rh/2;
219134

220-
plot_cost(plot, rx, ry, rw, rh);
135+
gym_plot(plot, rx, ry, rw, rh);
221136
rx += rw;
222-
nn_render_raylib(nn, rx, ry, rw, rh);
137+
gym_render_nn(nn, rx, ry, rw, rh);
223138
rx += rw;
224139

225140
float scale = 10;

nn.h

+97
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,32 @@ void nn_finite_diff(NN nn, NN g, float eps, Mat ti, Mat to);
6565
void nn_backprop(NN nn, NN g, Mat ti, Mat to);
6666
void nn_learn(NN nn, NN g, float rate);
6767

68+
#ifdef NN_ENABLE_GYM
69+
#include "raylib.h"
70+
71+
typedef struct {
72+
float *items;
73+
size_t count;
74+
size_t capacity;
75+
} Plot;
76+
77+
#define DA_INIT_CAP 256
78+
#define da_append(da, item) \
79+
do { \
80+
if ((da)->count >= (da)->capacity) { \
81+
(da)->capacity = (da)->capacity == 0 ? DA_INIT_CAP : (da)->capacity*2; \
82+
(da)->items = realloc((da)->items, (da)->capacity*sizeof(*(da)->items)); \
83+
assert((da)->items != NULL && "Buy more RAM lol"); \
84+
} \
85+
\
86+
(da)->items[(da)->count++] = (item); \
87+
} while (0)
88+
89+
void gym_render_nn(NN nn, float rx, float ry, float rw, float rh);
90+
void gym_plot(Plot plot, int rx, int ry, int rw, int rh);
91+
92+
#endif // NN_ENABLE_GYM
93+
6894
#endif // NN_H_
6995

7096
#ifdef NN_IMPLEMENTATION
@@ -402,4 +428,75 @@ void nn_learn(NN nn, NN g, float rate)
402428
}
403429
}
404430

431+
#ifdef NN_ENABLE_GYM
432+
433+
void gym_render_nn(NN nn, float rx, float ry, float rw, float rh)
434+
{
435+
Color low_color = {0xFF, 0x00, 0xFF, 0xFF};
436+
Color high_color = {0x00, 0xFF, 0x00, 0xFF};
437+
438+
float neuron_radius = rh*0.03;
439+
float layer_border_vpad = rh*0.08;
440+
float layer_border_hpad = rw*0.06;
441+
float nn_width = rw - 2*layer_border_hpad;
442+
float nn_height = rh - 2*layer_border_vpad;
443+
float nn_x = rx + rw/2 - nn_width/2;
444+
float nn_y = ry + rh/2 - nn_height/2;
445+
size_t arch_count = nn.count + 1;
446+
float layer_hpad = nn_width / arch_count;
447+
for (size_t l = 0; l < arch_count; ++l) {
448+
float layer_vpad1 = nn_height / nn.as[l].cols;
449+
for (size_t i = 0; i < nn.as[l].cols; ++i) {
450+
float cx1 = nn_x + l*layer_hpad + layer_hpad/2;
451+
float cy1 = nn_y + i*layer_vpad1 + layer_vpad1/2;
452+
if (l+1 < arch_count) {
453+
float layer_vpad2 = nn_height / nn.as[l+1].cols;
454+
for (size_t j = 0; j < nn.as[l+1].cols; ++j) {
455+
// i - rows of ws
456+
// j - cols of ws
457+
float cx2 = nn_x + (l+1)*layer_hpad + layer_hpad/2;
458+
float cy2 = nn_y + j*layer_vpad2 + layer_vpad2/2;
459+
float value = sigmoidf(MAT_AT(nn.ws[l], i, j));
460+
high_color.a = floorf(255.f*value);
461+
float thick = rh*0.004f;
462+
Vector2 start = {cx1, cy1};
463+
Vector2 end = {cx2, cy2};
464+
DrawLineEx(start, end, thick, ColorAlphaBlend(low_color, high_color, WHITE));
465+
}
466+
}
467+
if (l > 0) {
468+
high_color.a = floorf(255.f*sigmoidf(MAT_AT(nn.bs[l-1], 0, i)));
469+
DrawCircle(cx1, cy1, neuron_radius, ColorAlphaBlend(low_color, high_color, WHITE));
470+
} else {
471+
DrawCircle(cx1, cy1, neuron_radius, GRAY);
472+
}
473+
}
474+
}
475+
}
476+
477+
void gym_plot(Plot plot, int rx, int ry, int rw, int rh)
478+
{
479+
float min = FLT_MAX, max = FLT_MIN;
480+
for (size_t i = 0; i < plot.count; ++i) {
481+
if (max < plot.items[i]) max = plot.items[i];
482+
if (min > plot.items[i]) min = plot.items[i];
483+
}
484+
485+
if (min > 0) min = 0;
486+
size_t n = plot.count;
487+
if (n < 1000) n = 1000;
488+
for (size_t i = 0; i+1 < plot.count; ++i) {
489+
float x1 = rx + (float)rw/n*i;
490+
float y1 = ry + (1 - (plot.items[i] - min)/(max - min))*rh;
491+
float x2 = rx + (float)rw/n*(i+1);
492+
float y2 = ry + (1 - (plot.items[i+1] - min)/(max - min))*rh;
493+
DrawLineEx((Vector2){x1, y1}, (Vector2){x2, y2}, rh*0.005, RED);
494+
}
495+
496+
float y0 = ry + (1 - (0 - min)/(max - min))*rh;
497+
DrawLineEx((Vector2){rx + 0, y0}, (Vector2){rx + rw - 1, y0}, rh*0.005, WHITE);
498+
DrawText("0", rx + 0, y0 - rh*0.04, rh*0.04, WHITE);
499+
}
500+
#endif // NN_ENABLE_GYM
501+
405502
#endif // NN_IMPLEMENTATION

0 commit comments

Comments
 (0)