Skip to content

Commit 31fc433

Browse files
authored
Add intent classifier example (#268)
1 parent 515b5ac commit 31fc433

26 files changed

+140068
-0
lines changed

intent-classifier/.babelrc

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
{
2+
"presets": [
3+
[
4+
"env",
5+
{
6+
"esmodules": false,
7+
"targets": {
8+
"browsers": [
9+
"> 3%"
10+
]
11+
}
12+
}
13+
]
14+
],
15+
}

intent-classifier/.eslintrc.json

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
{
2+
"env": {
3+
"browser": true,
4+
"es6": true,
5+
"node": true
6+
},
7+
"extends": "google",
8+
"globals": {
9+
"Atomics": "readonly",
10+
"SharedArrayBuffer": "readonly"
11+
},
12+
"parserOptions": {
13+
"ecmaVersion": 2018,
14+
"sourceType": "module"
15+
},
16+
"rules": {
17+
"arrow-parens": [
18+
2,
19+
"as-needed"
20+
],
21+
"max-len": [
22+
2,
23+
{
24+
"code": 80,
25+
"tabWidth": 2,
26+
"ignoreUrls": true,
27+
"ignorePattern": "^import |^export"
28+
}
29+
],
30+
"new-parens": 2,
31+
"no-debugger": 2,
32+
"no-throw-literal": 2,
33+
"no-unused-expression": true,
34+
"radix": 2,
35+
"switch-default": true,
36+
"use-isnan": 2,
37+
"require-jsdoc": 0
38+
}
39+
}

intent-classifier/.gitignore

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
node_modules
2+
dist*
3+
.cache
4+
yarn-error.log
5+
*.tgz
6+
*-ubyte
7+
*.pyc
8+
.yalc
9+
.DS_STORE
10+
yalc.lock
11+
model.json
12+
weights.bin
13+
training/models
14+
training/data

intent-classifier/README.md

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
# Text Classifer Example
2+
3+
This example uses the universal sentence encoder to train two text
4+
classification models.
5+
6+
1. An 'intent' classifier that classifies sentences into categories representing
7+
user intent for a query.
8+
2. A token tagger, that classifies tokens within a weather releated query to
9+
identify location related tokens.
10+
11+
## Setup and Installation
12+
13+
Note: These instructions use `yarn`, but you can use `npm run` instead if you
14+
do not have `yarn` installed.
15+
16+
Install dependencies
17+
18+
```
19+
yarn
20+
```
21+
22+
## Preparing training data
23+
24+
There are four npm/yarn scripts listed in package.json for preparing the training data. Each writes out one of more new files.
25+
26+
The two scripts needed to train the intent classifier are:
27+
28+
1. `yarn convert-raw-to-csv`: Converts the raw data into a csv format
29+
2. `yarn convert-csv-to-tensors`: Converts the strings in the CSV created in step 1 into tensors.
30+
31+
The two scripts needed to train the token tagger are:
32+
33+
1. `yarn convert-raw-to-tagged-tokens`: Extracts tokens from sentences in the original data and tags each token with a category
34+
2. `yarn convert-tokens-to-embeddings`: embeds the tokens from the queries using the universal sentence encoder and writes out a look-up-table.
35+
36+
You can run all four of these commands with
37+
38+
```
39+
yarn prep-data
40+
```
41+
42+
You only need to do this once. This process can take 15-25 mins. The output of these scripts will be written to the `training/data` folder.
43+
44+
## Train the models
45+
46+
To train the intent classifier model run:
47+
48+
```
49+
yarn train-intent
50+
```
51+
52+
To train the token tagging model run:
53+
54+
```
55+
yarn train-tagger
56+
```
57+
58+
Each of these scripts take multiple options, look at `training/train-intent.js` and `training/train-tagger.js` for details.
59+
60+
These scripts will output model artifacts in the `training/models` folder.
61+
62+
## Run the apps
63+
64+
Once the models are trained you can use the following commands
65+
to run the demo apps for each model.
66+
67+
68+
```
69+
yarn intent-app
70+
```
71+
72+
73+
```
74+
yarn tagger-app
75+
```

intent-classifier/app/index.html

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
<!-- Copyright 2019 Google LLC. All Rights Reserved.
2+
Licensed under the Apache License, Version 2.0 (the "License");
3+
you may not use this file except in compliance with the License.
4+
You may obtain a copy of the License at
5+
http://www.apache.org/licenses/LICENSE-2.0
6+
Unless required by applicable law or agreed to in writing, software
7+
distributed under the License is distributed on an "AS IS" BASIS,
8+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
See the License for the specific language governing permissions and
10+
limitations under the License.
11+
============================================================================== -->
12+
13+
14+
<html>
15+
16+
<head>
17+
<meta charset="UTF-8">
18+
<meta name="viewport" content="width=device-width, initial-scale=1">
19+
20+
<link rel="stylesheet" href="./index.scss">
21+
<style>
22+
</style>
23+
</head>
24+
25+
<body>
26+
<div class="container">
27+
<div class="app">
28+
<div class="title">
29+
<p>Chatty McChatterson</p>
30+
</div>
31+
<div id='message-area' class="messages">
32+
<div class="message bot">
33+
👋! Hello, I can:
34+
<b>get the weather (⛅)</b>,
35+
<!-- <br /> -->
36+
&nbsp;<b>play Music (🎵🎺🎵)</b>,
37+
<!-- <br /> -->
38+
&nbsp;and <b>put things on your playlist (💿➡️📇)</b>
39+
</div>
40+
</div>
41+
<div class="controls">
42+
<form id="textentry">
43+
<input id="textbox" type="text" />
44+
<input id="submit" type="submit" value="💬">
45+
</form>
46+
</div>
47+
</div>
48+
</div>
49+
50+
51+
<script src="./index.js"></script>
52+
</body>
53+
54+
</html>

intent-classifier/app/index.js

Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,157 @@
1+
/**
2+
* @license
3+
* Copyright 2019 Google LLC. All Rights Reserved.
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
* =============================================================================
16+
*/
17+
18+
import * as useLoader from '@tensorflow-models/universal-sentence-encoder';
19+
import * as tf from '@tensorflow/tfjs';
20+
21+
tf.ENV.set('WEBGL_PACK', false);
22+
23+
const DENSE_MODEL_URL = './models/intent/model.json';
24+
const METADATA_URL = './models/intent/intent_metadata.json';
25+
26+
let use;
27+
async function loadUSE() {
28+
if (use == null) {
29+
use = await useLoader.load();
30+
}
31+
return use;
32+
}
33+
34+
let intent;
35+
async function loadIntentClassifer(url) {
36+
if (intent == null) {
37+
intent = await tf.loadLayersModel(url);
38+
}
39+
return intent;
40+
}
41+
42+
let metadata;
43+
async function loadMetadata() {
44+
if (metadata == null) {
45+
const resp = await fetch(METADATA_URL);
46+
metadata = resp.json();
47+
}
48+
return metadata;
49+
}
50+
51+
async function classify(sentences) {
52+
const [use, intent, metadata] = await Promise.all(
53+
loadUSE(), loadIntentClassifer(DENSE_MODEL_URL), loadMetadata());
54+
55+
const {labels} = metadata;
56+
console.log('classifying', sentences);
57+
console.time(`Embedding ${sentences.length} sentences`);
58+
const activations = await use.embed(sentences);
59+
console.timeEnd(`Embedding ${sentences.length} sentences`);
60+
61+
const prediction = intent.predict(activations);
62+
63+
const predsArr = await prediction.array();
64+
const preview = [predsArr[0].slice()];
65+
preview.unshift(labels);
66+
console.table(preview);
67+
68+
tf.dispose([activations, prediction]);
69+
70+
return predsArr[0];
71+
}
72+
73+
const THRESHOLD = 0.90;
74+
async function getClassificationMessage(softmaxArr) {
75+
const {labels} = await loadMetadata();
76+
const max = Math.max(...softmaxArr);
77+
const maxIndex = softmaxArr.indexOf(max);
78+
const intentLabel = labels[maxIndex];
79+
80+
if (max < THRESHOLD) {
81+
return '¯\\_(ツ)_/¯';
82+
} else {
83+
let response;
84+
switch (intentLabel) {
85+
case 'AddToPlaylist':
86+
response = '💿➡️📇';
87+
break;
88+
case 'GetWeather':
89+
response = '⛅';
90+
break;
91+
case 'PlayMusic':
92+
response = '🎵🎺🎵';
93+
break;
94+
default:
95+
response = '?';
96+
break;
97+
}
98+
return response;
99+
}
100+
}
101+
102+
async function onSendMessage(inputText) {
103+
if (inputText != null && inputText.length > 0) {
104+
// Add the input text to the chat window
105+
const msgId = appendMessage(inputText, 'input');
106+
// Classify the text
107+
const classification = await classify([inputText]);
108+
// Add the response to the chat window
109+
const response = await getClassificationMessage(classification);
110+
appendMessage(response, 'bot', msgId);
111+
}
112+
}
113+
114+
let messageId = 0;
115+
function appendMessage(message, sender, appendAfter) {
116+
const messageDiv = document.createElement('div');
117+
messageDiv.classList = `message ${sender}`;
118+
messageDiv.innerHTML = message;
119+
messageDiv.dataset.messageId = messageId++;
120+
121+
const messageArea = document.getElementById('message-area');
122+
if (appendAfter == null) {
123+
messageArea.appendChild(messageDiv);
124+
} else {
125+
const inputMsg =
126+
document.querySelector(`.message[data-message-id="${appendAfter}"]`);
127+
inputMsg.parentNode.insertBefore(messageDiv, inputMsg.nextElementSibling);
128+
}
129+
130+
// Scroll the message area to the bottom.
131+
messageArea.scroll({top: messageArea.scrollHeight, behavior: 'smooth'});
132+
133+
// Return this message id so that a reply can be posted to it later
134+
return messageDiv.dataset.messageId;
135+
}
136+
137+
function setupListeners() {
138+
const form = document.getElementById('textentry');
139+
const textbox = document.getElementById('textbox');
140+
form.addEventListener('submit', event => {
141+
event.preventDefault();
142+
event.stopPropagation();
143+
144+
const inputText = textbox.value;
145+
onSendMessage(inputText);
146+
textbox.value = '';
147+
}, false);
148+
}
149+
150+
function warmup() {
151+
classify('hello there');
152+
}
153+
154+
window.addEventListener('load', function() {
155+
setupListeners();
156+
warmup();
157+
});

0 commit comments

Comments
 (0)