
Commit bfce4aa

Add training script and classifier
1 parent f63545b commit bfce4aa

File tree

8 files changed: +519 -6 lines changed

integrations/slack/package.json (+9 -4)
@@ -3,20 +3,25 @@
     "version": "2.5.3",
     "private": true,
     "dependencies": {
-        "@gitbook/runtime": "*",
         "@gitbook/api": "*",
+        "@gitbook/runtime": "*",
         "itty-router": "^4.0.26",
         "js-sha256": "^0.9.0",
-        "remove-markdown": "^0.5.0"
+        "remove-markdown": "^0.5.0",
+        "toygrad": "^2.6.0"
     },
     "devDependencies": {
         "@gitbook/cli": "workspace:*",
-        "@gitbook/tsconfig": "workspace:*"
+        "@gitbook/tsconfig": "workspace:*",
+        "@vanillaes/csv": "^3.0.4",
+        "commander": "^14.0.2"
     },
     "scripts": {
         "typecheck": "tsc --noEmit",
         "check": "gitbook check",
+        "test": "bun test",
         "publish-integrations": "gitbook publish .",
-        "publish-integrations-staging": "gitbook publish ."
+        "publish-integrations-staging": "gitbook publish .",
+        "train-classifier": "bun run scripts/train-classifier.ts --"
     }
 }
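With these dependencies and the new package script in place, retraining runs from integrations/slack with something like the following (the CSV path is a hypothetical example; --model defaults to ../src/actions/intent/classifier-model.json, resolved relative to the script):

    bun run train-classifier --csv ./data/intents.csv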
integrations/slack/scripts/train-classifier.ts (new file: +190 -0)
import fs from 'fs';
import path from 'path';
import { Command } from 'commander';
import { NeuralNetwork, Trainers, Tensor } from 'toygrad';
import { Logger } from '@gitbook/runtime';
import { parse as csvParse } from '@vanillaes/csv';

const logger = Logger('slack:scripts:train-classifier');

const __dirname = path.dirname(new URL(import.meta.url).pathname);

/**
 * Clean text: remove mentions and punctuation, then lowercase.
 */
function cleanText(text: string): string {
    return text
        .replace(/@\w+/g, '') // remove mentions
        .replace(/[^\w\s]/g, '') // remove punctuation
        .toLowerCase()
        .trim();
}

/**
 * Load CSV training data.
 */
function loadTrainingData(filePath: string): { text: string; intent: string }[] {
    if (!fs.existsSync(filePath)) {
        throw new Error(`CSV file not found: ${filePath}`);
    }

    const csvContent = fs.readFileSync(filePath, 'utf-8');
    const rows = csvParse(csvContent) as string[][];
    const [, ...data] = rows; // drop the header row

    return data.map(([text, intent]) => ({
        text: text.trim(),
        intent: intent.trim(),
    }));
}
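For reference, loadTrainingData drops the first row as a header and reads the first two columns as text and intent, so a compatible CSV looks like the sketch below. The rows and label names are illustrative, not taken from the commit:

    text,intent
    "Can you publish the docs site?",publish
    "@bot summarize this thread",summarize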
/**
 * Build the vocabulary: the set of unique words across all training texts.
 */
function buildVocabulary(records: { text: string; intent: string }[]): string[] {
    const vocabSet = new Set<string>();
    for (const r of records) {
        const words = cleanText(r.text).match(/\b\w+\b/g) || [];
        for (const word of words) {
            vocabSet.add(word);
        }
    }
    return Array.from(vocabSet);
}

/**
 * Convert text to a bag-of-words count vector over the vocabulary.
 */
function textToWordVector(text: string, vocabulary: string[]): Float32Array {
    const vector = new Float32Array(vocabulary.length);
    const words = cleanText(text).match(/\b\w+\b/g) || [];
    for (const word of words) {
        const idx = vocabulary.indexOf(word);
        if (idx !== -1) {
            vector[idx] += 1;
        }
    }
    return vector;
}
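To make the encoding concrete, here is a worked example of textToWordVector with an assumed three-word vocabulary; out-of-vocabulary words ('please', 'now') are silently dropped:

    // Hypothetical vocabulary and message, for illustration only.
    const vocabulary = ['publish', 'the', 'docs'];
    const vec = textToWordVector('Please publish the docs, publish now!', vocabulary);
    // cleanText -> 'please publish the docs publish now'
    // counts: publish = 2, the = 1, docs = 1 -> Float32Array [2, 1, 1]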
/**
 * Build the neural network model: bag-of-words input, two hidden
 * dense layers with ReLU, and a softmax over the intent labels.
 */
function buildModel(inputSize: number, outputSize: number): NeuralNetwork {
    const options: NeuralNetwork['options'] = {
        layers: [
            { type: 'input', sx: 1, sy: 1, sz: inputSize },
            { type: 'dense', filters: 32 },
            { type: 'relu' },
            { type: 'dense', filters: 16 },
            { type: 'relu' },
            { type: 'dense', filters: outputSize },
            { type: 'softmax' },
        ],
    };

    return new NeuralNetwork(options);
}
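The network stays small: for a vocabulary of size V and L intent labels, the dense layers hold roughly (32·V + 32) + (16·32 + 16) + (L·16 + L) weights and biases. With assumed values V = 500 and L = 5, that is 16,032 + 528 + 85, about 16.6k parameters, small enough to train in-process and ship as JSON.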
async function trainModel(
    nn: NeuralNetwork,
    records: { text: string; intent: string }[],
    vocabulary: string[],
    outputLabels: string[],
    epochs = 50,
    batchSize = 4,
) {
    const trainingInputs: Tensor[] = [];
    const trainingTargets: number[] = []; // target label indices

    for (const r of records) {
        const vec = textToWordVector(r.text, vocabulary);
        trainingInputs.push(new Tensor(1, 1, vec.length, vec));

        const targetIdx = outputLabels.indexOf(r.intent);
        if (targetIdx === -1) {
            throw new Error(`Unknown intent label: ${r.intent}`);
        }
        trainingTargets.push(targetIdx);
    }

    const trainer = new Trainers.Adadelta(nn, {
        batchSize,
    });

    logger.info(`🚀 Training model on ${records.length} examples for ${epochs} epochs...`);

    for (let epoch = 0; epoch < epochs; epoch++) {
        for (let i = 0; i < trainingInputs.length; i++) {
            trainer.train(trainingInputs[i], trainingTargets[i]);
        }
        if ((epoch + 1) % 10 === 0) {
            logger.info(`Epoch ${epoch + 1}/${epochs} done`);
        }
    }

    logger.info('✅ Training complete');
}

/**
 * Save the model, vocabulary, and output labels as a JSON file for the classifier to load.
 */
function saveModel(
    nn: NeuralNetwork,
    vocabulary: string[],
    outputLabels: string[],
    filePath: string,
) {
    const options = nn.getAsOptions('f32');
    const serialized = {
        model: options,
        vocabulary,
        outputLabels,
    };
    fs.writeFileSync(filePath, JSON.stringify(serialized, null, 2));
    logger.info(`💾 Saved updated classifier to ${filePath}`);
}

async function main() {
    const program = new Command();

    program
        .name('train-classifier')
        .description('Train or update the action intent classifier from a CSV file')
        .requiredOption('-c, --csv <path>', 'Path to the training CSV file')
        .option(
            '-m, --model <path>',
            'Path to the serialized model JSON',
            '../src/actions/intent/classifier-model.json',
        )
        .parse(process.argv);

    const opts = program.opts();
    const csvPath = path.resolve(opts.csv);
    const modelPath = path.resolve(__dirname, opts.model);

    try {
        const records = loadTrainingData(csvPath);
        const vocabulary = buildVocabulary(records);
        const outputLabels = Array.from(new Set(records.map((r) => r.intent)));

        const inputSize = vocabulary.length;
        const outputSize = outputLabels.length;

        logger.info(`Vocabulary size: ${inputSize}`);
        logger.info(`Output labels: ${outputLabels.join(', ')}`);

        const nn = buildModel(inputSize, outputSize);

        await trainModel(nn, records, vocabulary, outputLabels, 50, 4);

        saveModel(nn, vocabulary, outputLabels, modelPath);
    } catch (err) {
        logger.error('❌ Error:', (err as Error).message);
        process.exit(1);
    }
}

main();
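The commit title also mentions a classifier, whose diff is not shown above. The saved JSON carries everything needed to rebuild the network at inference time. Below is a minimal sketch of that consuming side, assuming toygrad accepts the exported options object back in its constructor, that forward(input, isTraining) returns a Tensor exposing its values as .w (as in convnetjs-style libraries), and that textToWordVector is shared with the training script; the file path and helper name are assumptions, not from the commit:

    import { NeuralNetwork, Tensor } from 'toygrad';
    import serialized from '../src/actions/intent/classifier-model.json';

    // Hypothetical helper: classify a message with the saved model.
    function classifyIntent(text: string): string {
        const nn = new NeuralNetwork(serialized.model); // rebuild from exported options
        const vec = textToWordVector(text, serialized.vocabulary);
        const probs = nn.forward(new Tensor(1, 1, vec.length, vec), false); // inference mode
        // Pick the label with the highest softmax probability.
        let best = 0;
        for (let i = 1; i < serialized.outputLabels.length; i++) {
            if (probs.w[i] > probs.w[best]) best = i;
        }
        return serialized.outputLabels[best];
    }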
