|
| 1 | +/** |
| 2 | + * @license |
| 3 | + * Copyright 2019 Google LLC. All Rights Reserved. |
| 4 | + * Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | + * you may not use this file except in compliance with the License. |
| 6 | + * You may obtain a copy of the License at |
| 7 | + * |
| 8 | + * http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | + * |
| 10 | + * Unless required by applicable law or agreed to in writing, software |
| 11 | + * distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | + * See the License for the specific language governing permissions and |
| 14 | + * limitations under the License. |
| 15 | + * ============================================================================= |
| 16 | + */ |
| 17 | + |
| 18 | +import 'babel-polyfill'; |
| 19 | +import * as tf from '@tensorflow/tfjs'; |
| 20 | +import {IMAGENET_CLASSES} from './imagenet_classes'; |
| 21 | + |
| 22 | +// Where to load the model from. |
| 23 | +const MOBILENET_MODEL_TFHUB_URL = |
| 24 | + 'https://tfhub.dev/google/imagenet/mobilenet_v2_100_224/classification/2' |
| 25 | +// Size of the image expected by mobilenet. |
| 26 | +const IMAGE_SIZE = 224; |
| 27 | +// The minimum image size to consider classifying. Below this limit the |
| 28 | +// extension will refuse to classify the image. |
| 29 | +const MIN_IMG_SIZE = 128; |
| 30 | + |
| 31 | +// How many predictions to take. |
| 32 | +const TOPK_PREDICTIONS = 2; |
| 33 | +const FIVE_SECONDS_IN_MS = 5000; |
| 34 | +/** |
| 35 | + * What action to take when someone clicks the right-click menu option. |
| 36 | + * Here it takes the url of the right-clicked image and the current tabId |
| 37 | + * and forwards it to the imageClassifier's analyzeImage method. |
| 38 | + */ |
| 39 | +function clickMenuCallback(info, tab) { |
| 40 | + imageClassifier.analyzeImage(info.srcUrl, tab.id); |
| 41 | +} |
| 42 | + |
| 43 | +/** |
| 44 | + * Adds a right-click menu option to trigger classifying the image. |
| 45 | + * The menu option should only appear when right-clicking an image. |
| 46 | + */ |
| 47 | +chrome.contextMenus.create({ |
| 48 | + title: 'Classify image with TensorFlow.js ', |
| 49 | + contexts: ['image'], |
| 50 | + onclick: clickMenuCallback |
| 51 | +}); |
| 52 | + |
| 53 | +/** |
| 54 | + * Async loads a mobilenet on construction. Subsequently handles |
| 55 | + * requests to classify images through the .analyzeImage API. |
| 56 | + * Successful requests will post a chrome message with |
| 57 | + * 'IMAGE_CLICK_PROCESSED' action, which the content.js can |
| 58 | + * hear and use to manipulate the DOM. |
| 59 | + */ |
| 60 | +class ImageClassifier { |
| 61 | + constructor() { |
| 62 | + this.loadModel(); |
| 63 | + } |
| 64 | + |
| 65 | + /** |
| 66 | + * Loads mobilenet from URL and keeps a reference to it in the object. |
| 67 | + */ |
| 68 | + async loadModel() { |
| 69 | + console.log('Loading model...'); |
| 70 | + const startTime = performance.now(); |
| 71 | + try { |
| 72 | + this.model = |
| 73 | + await tf.loadGraphModel(MOBILENET_MODEL_TFHUB_URL, {fromTFHub: true}); |
| 74 | + // Warms up the model by causing intermediate tensor values |
| 75 | + // to be built and pushed to GPU. |
| 76 | + tf.tidy(() => { |
| 77 | + this.model.predict(tf.zeros([1, IMAGE_SIZE, IMAGE_SIZE, 3])); |
| 78 | + }); |
| 79 | + const totalTime = Math.floor(performance.now() - startTime); |
| 80 | + console.log(`Model loaded and initialized in ${totalTime} ms...`); |
| 81 | + } catch { |
| 82 | + console.error( |
| 83 | + `Unable to load model from URL: ${MOBILENET_MODEL_TFHUB_URL}`); |
| 84 | + } |
| 85 | + } |
| 86 | + |
| 87 | + /** |
| 88 | + * Triggers the model to make a prediction on the image referenced by url. |
| 89 | + * After a successful prediction a IMAGE_CLICK_PROCESSED message when |
| 90 | + * complete, for the content.js script to hear and update the DOM with the |
| 91 | + * results of the prediction. |
| 92 | + * |
| 93 | + * @param {string} url url of image to analyze. |
| 94 | + * @param {number} tabId which tab the request comes from. |
| 95 | + */ |
| 96 | + async analyzeImage(url, tabId) { |
| 97 | + if (!tabId) { |
| 98 | + console.error('No tab. No prediction.'); |
| 99 | + return; |
| 100 | + } |
| 101 | + if (!this.model) { |
| 102 | + console.log('Waiting for model to load...'); |
| 103 | + setTimeout(() => {this.analyzeImage(url)}, FIVE_SECONDS_IN_MS); |
| 104 | + return; |
| 105 | + } |
| 106 | + let message; |
| 107 | + this.loadImage(url).then( |
| 108 | + async (img) => { |
| 109 | + if (!img) { |
| 110 | + console.error( |
| 111 | + 'Could not load image. Either too small or unavailable.'); |
| 112 | + return; |
| 113 | + } |
| 114 | + const predictions = await this.predict(img); |
| 115 | + message = {action: 'IMAGE_CLICK_PROCESSED', url, predictions}; |
| 116 | + chrome.tabs.sendMessage(tabId, message); |
| 117 | + }, |
| 118 | + (reason) => { |
| 119 | + console.error(`Failed to analyze: ${reason}`); |
| 120 | + }); |
| 121 | + } |
| 122 | + |
| 123 | + /** |
| 124 | + * Creates a dom element and loads the image pointed to by the provided src. |
| 125 | + * @param {string} src URL of the image to load. |
| 126 | + */ |
| 127 | + async loadImage(src) { |
| 128 | + return new Promise((resolve, reject) => { |
| 129 | + const img = document.createElement('img'); |
| 130 | + img.crossOrigin = 'anonymous'; |
| 131 | + img.onerror = function(e) { |
| 132 | + reject(`Could not load image from external source ${src}.`); |
| 133 | + }; |
| 134 | + img.onload = function(e) { |
| 135 | + if ((img.height && img.height > MIN_IMG_SIZE) || |
| 136 | + (img.width && img.width > MIN_IMG_SIZE)) { |
| 137 | + img.width = IMAGE_SIZE; |
| 138 | + img.height = IMAGE_SIZE; |
| 139 | + resolve(img); |
| 140 | + } |
| 141 | + // Fail out if either dimension is less than MIN_IMG_SIZE. |
| 142 | + reject(`Image size too small. [${img.height} x ${ |
| 143 | + img.width}] vs. minimum [${MIN_IMG_SIZE} x ${MIN_IMG_SIZE}]`); |
| 144 | + }; |
| 145 | + img.src = src; |
| 146 | + }); |
| 147 | + } |
| 148 | + |
| 149 | + /** |
| 150 | + * Sorts predictions by score and keeps only topK |
| 151 | + * @param {Tensor} logits A tensor with one element per predicatable class |
| 152 | + * type of mobilenet. Return of executing model.predict on an Image. |
| 153 | + * @param {number} topK how many to keep. |
| 154 | + */ |
| 155 | + async getTopKClasses(logits, topK) { |
| 156 | + const {values, indices} = tf.topk(logits, topK, true); |
| 157 | + const valuesArr = await values.data(); |
| 158 | + const indicesArr = await indices.data(); |
| 159 | + console.log(`indicesArr ${indicesArr}`); |
| 160 | + const topClassesAndProbs = []; |
| 161 | + for (let i = 0; i < topK; i++) { |
| 162 | + topClassesAndProbs.push({ |
| 163 | + className: IMAGENET_CLASSES[indicesArr[i]], |
| 164 | + probability: valuesArr[i] |
| 165 | + }) |
| 166 | + } |
| 167 | + return topClassesAndProbs; |
| 168 | + } |
| 169 | + |
| 170 | + /** |
| 171 | + * Executes the model on the input image, and returns the top predicted |
| 172 | + * classes. |
| 173 | + * @param {HTMLElement} imgElement HTML element holding the image to predict |
| 174 | + * from. Should have the correct size ofr mobilenet. |
| 175 | + */ |
| 176 | + async predict(imgElement) { |
| 177 | + console.log('Predicting...'); |
| 178 | + // The first start time includes the time it takes to extract the image |
| 179 | + // from the HTML and preprocess it, in additon to the predict() call. |
| 180 | + const startTime1 = performance.now(); |
| 181 | + // The second start time excludes the extraction and preprocessing and |
| 182 | + // includes only the predict() call. |
| 183 | + let startTime2; |
| 184 | + const logits = tf.tidy(() => { |
| 185 | + // Mobilenet expects images to be normalized between -1 and 1. |
| 186 | + const img = tf.browser.fromPixels(imgElement).toFloat(); |
| 187 | + // const offset = tf.scalar(127.5); |
| 188 | + // const normalized = img.sub(offset).div(offset); |
| 189 | + const normalized = img.div(tf.scalar(256.0)); |
| 190 | + const batched = normalized.reshape([1, IMAGE_SIZE, IMAGE_SIZE, 3]); |
| 191 | + startTime2 = performance.now(); |
| 192 | + const output = this.model.predict(batched); |
| 193 | + if (output.shape[output.shape.length - 1] === 1001) { |
| 194 | + // Remove the very first logit (background noise). |
| 195 | + return output.slice([0, 1], [-1, 1000]); |
| 196 | + } else if (output.shape[output.shape.length - 1] === 1000) { |
| 197 | + return output; |
| 198 | + } else { |
| 199 | + throw new Error('Unexpected shape...'); |
| 200 | + } |
| 201 | + }); |
| 202 | + |
| 203 | + // Convert logits to probabilities and class names. |
| 204 | + const classes = await this.getTopKClasses(logits, TOPK_PREDICTIONS); |
| 205 | + const totalTime1 = performance.now() - startTime1; |
| 206 | + const totalTime2 = performance.now() - startTime2; |
| 207 | + console.log( |
| 208 | + `Done in ${totalTime1.toFixed(1)} ms ` + |
| 209 | + `(not including preprocessing: ${Math.floor(totalTime2)} ms)`); |
| 210 | + return classes; |
| 211 | + } |
| 212 | +} |
| 213 | + |
| 214 | +const imageClassifier = new ImageClassifier(); |
0 commit comments