1313* ==========================================================================
1414*/
1515
16- import { convertToNumericTensor1D , convertToNumericTensor2D } from '../utils'
16+ import {
17+ convertToNumericTensor1D_2D ,
18+ convertToNumericTensor2D
19+ } from '../utils'
1720import {
1821 Scikit2D ,
1922 Scikit1D ,
@@ -23,8 +26,7 @@ import {
2326 Tensor2D ,
2427 Tensor ,
2528 ModelCompileArgs ,
26- ModelFitArgs ,
27- RecursiveArray
29+ ModelFitArgs
2830} from '../types'
2931import { OneHotEncoder } from '../preprocessing/OneHotEncoder'
3032import { assert } from '../typesUtils'
@@ -103,6 +105,7 @@ export class SGDClassifier extends ClassifierMixin {
103105 lossType : LossTypes
104106 oneHot : OneHotEncoder
105107 tf : any
108+ isMultiOutput : boolean
106109
107110 constructor ( {
108111 modelFitArgs,
@@ -119,6 +122,7 @@ export class SGDClassifier extends ClassifierMixin {
119122 this . denseLayerArgs = denseLayerArgs
120123 this . optimizerType = optimizerType
121124 this . lossType = lossType
125+ this . isMultiOutput = false
122126 // Next steps: Implement "drop" mechanics for OneHotEncoder
123127 // There is a possibility to do a drop => if_binary which would
124128 // squash down on the number of variables that we'd have to learn
@@ -200,12 +204,17 @@ export class SGDClassifier extends ClassifierMixin {
200204 * // lr model weights have been updated
201205 */
202206
203- public async fit ( X : Scikit2D , y : Scikit1D ) : Promise < SGDClassifier > {
207+ public async fit (
208+ X : Scikit2D ,
209+ y : Scikit1D | Scikit2D
210+ ) : Promise < SGDClassifier > {
204211 let XTwoD = convertToNumericTensor2D ( X )
205- let yOneD = convertToNumericTensor1D ( y )
212+ let yOneD = convertToNumericTensor1D_2D ( y )
206213
207214 const yTwoD = this . initializeModelForClassification ( yOneD )
208-
215+ if ( yOneD . shape . length > 1 ) {
216+ this . isMultiOutput = true
217+ }
209218 if ( this . model . layers . length === 0 ) {
210219 this . initializeModel ( XTwoD , yTwoD )
211220 }
@@ -344,6 +353,9 @@ export class SGDClassifier extends ClassifierMixin {
344353 public predict ( X : Scikit2D ) : Tensor1D {
345354 assert ( this . model . layers . length > 0 , 'Need to call "fit" before "predict"' )
346355 const y2D = this . predictProba ( X )
356+ if ( this . isMultiOutput ) {
357+ return this . tf . oneHot ( y2D . argMax ( 1 ) , y2D . shape [ 1 ] )
358+ }
347359 return this . tf . tensor1d ( this . oneHot . inverseTransform ( y2D ) )
348360 }
349361
@@ -418,10 +430,4 @@ export class SGDClassifier extends ClassifierMixin {
418430
419431 return intercept
420432 }
421-
422- private getModelWeight ( ) : Promise < RecursiveArray < number > > {
423- return Promise . all (
424- this . model . getWeights ( ) . map ( ( weight : any ) => weight . array ( ) )
425- )
426- }
427433}
0 commit comments