|
| 1 | +#define NN_IMPLEMENTATION |
| 2 | +#define NN_ENABLE_GYM |
| 3 | +#include "nn.h" |
| 4 | + |
| 5 | +size_t arch[] = {2, 2, 1}; |
| 6 | +size_t max_epoch = 100*1000; |
| 7 | +size_t epochs_per_frame = 103; |
| 8 | +float rate = 1.0f; |
| 9 | +bool paused = true; |
| 10 | + |
| 11 | +void verify_nn_gate(Font font, NN nn, float rx, float ry, float rw, float rh) |
| 12 | +{ |
| 13 | + (void) rw; |
| 14 | + char buffer[256]; |
| 15 | + float s = rh*0.06; |
| 16 | + float pad = rh*0.03; |
| 17 | + for (size_t i = 0; i < 2; ++i) { |
| 18 | + for (size_t j = 0; j < 2; ++j) { |
| 19 | + MAT_AT(NN_INPUT(nn), 0, 0) = i; |
| 20 | + MAT_AT(NN_INPUT(nn), 0, 1) = j; |
| 21 | + nn_forward(nn); |
| 22 | + snprintf(buffer, sizeof(buffer), "%zu @ %zu == %f", i, j, MAT_AT(NN_OUTPUT(nn), 0, 0)); |
| 23 | + DrawTextEx(font, buffer, CLITERAL(Vector2){rx, ry + (i*2 + j)*(s + pad)}, s, 0, WHITE); |
| 24 | + } |
| 25 | + } |
| 26 | +} |
| 27 | + |
| 28 | +int main(void) |
| 29 | +{ |
| 30 | + Mat t = mat_alloc(4, 3); |
| 31 | + for (size_t i = 0; i < 2; ++i) { |
| 32 | + for (size_t j = 0; j < 2; ++j) { |
| 33 | + size_t row = i*2 + j; |
| 34 | + MAT_AT(t, row, 0) = i; |
| 35 | + MAT_AT(t, row, 1) = j; |
| 36 | + MAT_AT(t, row, 2) = i^j; |
| 37 | + } |
| 38 | + } |
| 39 | + |
| 40 | + Mat ti = { |
| 41 | + .rows = t.rows, |
| 42 | + .cols = 2, |
| 43 | + .stride = t.stride, |
| 44 | + .es = &MAT_AT(t, 0, 0), |
| 45 | + }; |
| 46 | + |
| 47 | + Mat to = { |
| 48 | + .rows = t.rows, |
| 49 | + .cols = 1, |
| 50 | + .stride = t.stride, |
| 51 | + .es = &MAT_AT(t, 0, ti.cols), |
| 52 | + }; |
| 53 | + |
| 54 | + |
| 55 | + NN nn = nn_alloc(arch, ARRAY_LEN(arch)); |
| 56 | + NN g = nn_alloc(arch, ARRAY_LEN(arch)); |
| 57 | + nn_rand(nn, -1, 1); |
| 58 | + |
| 59 | + size_t WINDOW_FACTOR = 80; |
| 60 | + size_t WINDOW_WIDTH = (16*WINDOW_FACTOR); |
| 61 | + size_t WINDOW_HEIGHT = (9*WINDOW_FACTOR); |
| 62 | + |
| 63 | + SetConfigFlags(FLAG_WINDOW_RESIZABLE); |
| 64 | + InitWindow(WINDOW_WIDTH, WINDOW_HEIGHT, "xor"); |
| 65 | + SetTargetFPS(60); |
| 66 | + |
| 67 | + Font font = LoadFontEx("./fonts/iosevka-regular.ttf", 72, NULL, 0); |
| 68 | + SetTextureFilter(font.texture, TEXTURE_FILTER_BILINEAR); |
| 69 | + |
| 70 | + Plot plot = {0}; |
| 71 | + |
| 72 | + size_t epoch = 0; |
| 73 | + while (!WindowShouldClose()) { |
| 74 | + if (IsKeyPressed(KEY_SPACE)) { |
| 75 | + paused = !paused; |
| 76 | + } |
| 77 | + if (IsKeyPressed(KEY_R)) { |
| 78 | + epoch = 0; |
| 79 | + nn_rand(nn, -1, 1); |
| 80 | + plot.count = 0; |
| 81 | + } |
| 82 | + |
| 83 | + for (size_t i = 0; i < epochs_per_frame && !paused && epoch < max_epoch; ++i) { |
| 84 | + nn_backprop(nn, g, ti, to); |
| 85 | + nn_learn(nn, g, rate); |
| 86 | + epoch += 1; |
| 87 | + da_append(&plot, nn_cost(nn, ti, to)); |
| 88 | + } |
| 89 | + |
| 90 | + BeginDrawing(); |
| 91 | + Color background_color = {0x18, 0x18, 0x18, 0xFF}; |
| 92 | + ClearBackground(background_color); |
| 93 | + { |
| 94 | + int w = GetRenderWidth(); |
| 95 | + int h = GetRenderHeight(); |
| 96 | + |
| 97 | + int rw = w/3; |
| 98 | + int rh = h*2/3; |
| 99 | + int rx = 0; |
| 100 | + int ry = h/2 - rh/2; |
| 101 | + |
| 102 | + gym_plot(plot, rx, ry, rw, rh); |
| 103 | + rx += rw; |
| 104 | + gym_render_nn(nn, rx, ry, rw, rh); |
| 105 | + rx += rw; |
| 106 | + verify_nn_gate(font, nn, rx, ry, rw, rh); |
| 107 | + |
| 108 | + char buffer[256]; |
| 109 | + snprintf(buffer, sizeof(buffer), "Epoch: %zu/%zu, Rate: %f, Cost: %f", epoch, max_epoch, rate, nn_cost(nn, ti, to)); |
| 110 | + DrawTextEx(font, buffer, CLITERAL(Vector2){}, h*0.04, 0, WHITE); |
| 111 | + } |
| 112 | + EndDrawing(); |
| 113 | + } |
| 114 | + |
| 115 | + return 0; |
| 116 | +} |
0 commit comments