Skip to content

Commit 83815c1

Browse files
committed
Convert xor_gen example into gym application
1 parent 25cf0b0 commit 83815c1

File tree

8 files changed

+119
-31
lines changed

8 files changed

+119
-31
lines changed

Diff for: .gitignore

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
xor_gen
1+
xor
22
adder_gen
33
gym
44
img2nn

Diff for: build.sh

+1-1
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,6 @@ CFLAGS="-O3 -Wall -Wextra -I./thirdparty/"
66
LIBS="-lm"
77

88
clang $CFLAGS -o adder_gen adder_gen.c $LIBS
9-
clang $CFLAGS -o xor_gen xor_gen.c $LIBS
9+
clang $CFLAGS `pkg-config --cflags raylib` -o xor xor.c $LIBS `pkg-config --libs raylib` -lglfw -ldl -lpthread
1010
clang $CFLAGS `pkg-config --cflags raylib` -o gym gym.c $LIBS `pkg-config --libs raylib` -lglfw -ldl -lpthread
1111
clang $CFLAGS `pkg-config --cflags raylib` -o img2nn img2nn.c $LIBS `pkg-config --libs raylib` -lglfw -ldl -lpthread

Diff for: fonts/iosevka-regular.ttf

4.55 MB
Binary file not shown.

Diff for: nn.h

+1
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ void nn_backprop(NN nn, NN g, Mat ti, Mat to);
6666
void nn_learn(NN nn, NN g, float rate);
6767

6868
#ifdef NN_ENABLE_GYM
69+
#include <float.h>
6970
#include "raylib.h"
7071

7172
typedef struct {

Diff for: xor.arch

-1
This file was deleted.

Diff for: xor.c

+116
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
#define NN_IMPLEMENTATION
2+
#define NN_ENABLE_GYM
3+
#include "nn.h"
4+
5+
size_t arch[] = {2, 2, 1};
6+
size_t max_epoch = 100*1000;
7+
size_t epochs_per_frame = 103;
8+
float rate = 1.0f;
9+
bool paused = true;
10+
11+
void verify_nn_gate(Font font, NN nn, float rx, float ry, float rw, float rh)
12+
{
13+
(void) rw;
14+
char buffer[256];
15+
float s = rh*0.06;
16+
float pad = rh*0.03;
17+
for (size_t i = 0; i < 2; ++i) {
18+
for (size_t j = 0; j < 2; ++j) {
19+
MAT_AT(NN_INPUT(nn), 0, 0) = i;
20+
MAT_AT(NN_INPUT(nn), 0, 1) = j;
21+
nn_forward(nn);
22+
snprintf(buffer, sizeof(buffer), "%zu @ %zu == %f", i, j, MAT_AT(NN_OUTPUT(nn), 0, 0));
23+
DrawTextEx(font, buffer, CLITERAL(Vector2){rx, ry + (i*2 + j)*(s + pad)}, s, 0, WHITE);
24+
}
25+
}
26+
}
27+
28+
int main(void)
29+
{
30+
Mat t = mat_alloc(4, 3);
31+
for (size_t i = 0; i < 2; ++i) {
32+
for (size_t j = 0; j < 2; ++j) {
33+
size_t row = i*2 + j;
34+
MAT_AT(t, row, 0) = i;
35+
MAT_AT(t, row, 1) = j;
36+
MAT_AT(t, row, 2) = i^j;
37+
}
38+
}
39+
40+
Mat ti = {
41+
.rows = t.rows,
42+
.cols = 2,
43+
.stride = t.stride,
44+
.es = &MAT_AT(t, 0, 0),
45+
};
46+
47+
Mat to = {
48+
.rows = t.rows,
49+
.cols = 1,
50+
.stride = t.stride,
51+
.es = &MAT_AT(t, 0, ti.cols),
52+
};
53+
54+
55+
NN nn = nn_alloc(arch, ARRAY_LEN(arch));
56+
NN g = nn_alloc(arch, ARRAY_LEN(arch));
57+
nn_rand(nn, -1, 1);
58+
59+
size_t WINDOW_FACTOR = 80;
60+
size_t WINDOW_WIDTH = (16*WINDOW_FACTOR);
61+
size_t WINDOW_HEIGHT = (9*WINDOW_FACTOR);
62+
63+
SetConfigFlags(FLAG_WINDOW_RESIZABLE);
64+
InitWindow(WINDOW_WIDTH, WINDOW_HEIGHT, "xor");
65+
SetTargetFPS(60);
66+
67+
Font font = LoadFontEx("./fonts/iosevka-regular.ttf", 72, NULL, 0);
68+
SetTextureFilter(font.texture, TEXTURE_FILTER_BILINEAR);
69+
70+
Plot plot = {0};
71+
72+
size_t epoch = 0;
73+
while (!WindowShouldClose()) {
74+
if (IsKeyPressed(KEY_SPACE)) {
75+
paused = !paused;
76+
}
77+
if (IsKeyPressed(KEY_R)) {
78+
epoch = 0;
79+
nn_rand(nn, -1, 1);
80+
plot.count = 0;
81+
}
82+
83+
for (size_t i = 0; i < epochs_per_frame && !paused && epoch < max_epoch; ++i) {
84+
nn_backprop(nn, g, ti, to);
85+
nn_learn(nn, g, rate);
86+
epoch += 1;
87+
da_append(&plot, nn_cost(nn, ti, to));
88+
}
89+
90+
BeginDrawing();
91+
Color background_color = {0x18, 0x18, 0x18, 0xFF};
92+
ClearBackground(background_color);
93+
{
94+
int w = GetRenderWidth();
95+
int h = GetRenderHeight();
96+
97+
int rw = w/3;
98+
int rh = h*2/3;
99+
int rx = 0;
100+
int ry = h/2 - rh/2;
101+
102+
gym_plot(plot, rx, ry, rw, rh);
103+
rx += rw;
104+
gym_render_nn(nn, rx, ry, rw, rh);
105+
rx += rw;
106+
verify_nn_gate(font, nn, rx, ry, rw, rh);
107+
108+
char buffer[256];
109+
snprintf(buffer, sizeof(buffer), "Epoch: %zu/%zu, Rate: %f, Cost: %f", epoch, max_epoch, rate, nn_cost(nn, ti, to));
110+
DrawTextEx(font, buffer, CLITERAL(Vector2){}, h*0.04, 0, WHITE);
111+
}
112+
EndDrawing();
113+
}
114+
115+
return 0;
116+
}

Diff for: xor.mat

-72 Bytes
Binary file not shown.

Diff for: xor_gen.c

-28
This file was deleted.

0 commit comments

Comments
 (0)