transformer.h
#ifndef TRANSFORMER_H
#define TRANSFORMER_H
#include "ternary.h"
// ---------------------------------------------------------------------------
// Model configuration
// ---------------------------------------------------------------------------
typedef struct {
    int dim;          // model / embedding dimension
    int hidden_dim;   // FFN intermediate dimension (typically 4 * dim or 8/3 * dim)
    int n_heads;      // number of attention heads
    int head_dim;     // dim / n_heads
    int n_layers;     // transformer layers
    int vocab_size;   // vocabulary size
    int max_seq_len;  // maximum sequence length
} TransformerConfig;
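
// Illustrative configuration (the values below are placeholders chosen to
// satisfy the field comments above, not taken from any shipped model):
//
//   TransformerConfig cfg = {
//       .dim = 256, .hidden_dim = 1024, .n_heads = 8, .head_dim = 32,
//       .n_layers = 4, .vocab_size = 4096, .max_seq_len = 512,
//   };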
// ---------------------------------------------------------------------------
// Model weights — all linear layers are ternary
// ---------------------------------------------------------------------------
typedef struct {
    // Per-layer ternary weight matrices
    TernaryMatrix *wq;      // [n_layers] dim → dim
    TernaryMatrix *wk;      // [n_layers] dim → dim
    TernaryMatrix *wv;      // [n_layers] dim → dim
    TernaryMatrix *wo;      // [n_layers] dim → dim
    TernaryMatrix *w_up;    // [n_layers] dim → hidden_dim
    TernaryMatrix *w_down;  // [n_layers] hidden_dim → dim

    // Full-precision parameters
    float *token_emb;    // [vocab_size × dim]
    float *rms_att_w;    // [n_layers × dim] attention RMSNorm weights
    float *rms_ffn_w;    // [n_layers × dim] FFN RMSNorm weights
    float *rms_final_w;  // [dim] final RMSNorm weights
    float *output_w;     // [vocab_size × dim] output projection (FP32)
} TransformerWeights;
// ---------------------------------------------------------------------------
// Runtime state (activations + KV cache)
// ---------------------------------------------------------------------------
typedef struct {
    float *x;        // [dim] current activation
    float *xb;       // [dim] after RMSNorm
    float *xb2;      // [dim] second buffer
    float *q;        // [dim] query
    float *k;        // [dim] key
    float *v;        // [dim] value
    float *att;      // [n_heads × seq] attention scores
    float *hid;      // [hidden_dim] FFN intermediate
    float *logits;   // [vocab_size] output logits
    int8_t *xq_buf;  // [padded_dim] pre-quantised activation buffer

    // KV cache: [n_layers × max_seq_len × dim]
    float *key_cache;
    float *value_cache;
} RunState;
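
// Cache indexing, assuming the row-major [n_layers × max_seq_len × dim]
// layout noted above: the key vector for layer l at position pos starts at
//
//   key_cache + (l * max_seq_len + pos) * dim
//
// and value_cache is laid out identically.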
// ---------------------------------------------------------------------------
// Operations
// ---------------------------------------------------------------------------
void rmsnorm(float *out, const float *x, const float *weight, int dim);
void apply_rope(float *q, float *k, int head_dim, int pos, int n_heads);
void softmax(float *x, int n);
void relu_squared(float *x, int n);
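
// Reference semantics (typical definitions; the epsilon shown is an
// assumption, the actual value is an implementation detail):
//
//   rmsnorm:      out[i] = weight[i] * x[i] / sqrt(mean_j(x[j]^2) + 1e-5)
//   softmax:      x[i]   = exp(x[i] - max(x)) / sum_j exp(x[j] - max(x))
//   relu_squared: x[i]   = max(x[i], 0)^2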
// Full forward pass — returns pointer to logits
float *forward(TransformerConfig *cfg, TransformerWeights *w,
               RunState *s, int token, int pos);
// Sampling
int argmax(const float *logits, int n);
// Allocation / free
RunState alloc_run_state(TransformerConfig *cfg);
void free_run_state(RunState *s);
// Initialise random ternary weights for benchmarking
TransformerWeights alloc_random_weights(TransformerConfig *cfg, unsigned seed);
void free_weights(TransformerWeights *w, TransformerConfig *cfg);
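
// Typical benchmarking loop (sketch only; the start token and greedy argmax
// decoding are assumptions, not prescribed by this header):
//
//   TransformerConfig cfg = { /* see example above */ };
//   TransformerWeights w = alloc_random_weights(&cfg, 42);
//   RunState s = alloc_run_state(&cfg);
//   int token = 0;                       // hypothetical start token
//   for (int pos = 0; pos < cfg.max_seq_len; pos++) {
//       float *logits = forward(&cfg, &w, &s, token, pos);
//       token = argmax(logits, cfg.vocab_size);
//   }
//   free_run_state(&s);
//   free_weights(&w, &cfg);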
#endif // TRANSFORMER_H