42
42
import torch
43
43
from common_utils import TestCase , run_tests
44
44
45
+ def tensor_N (shape , dtype = float ):
46
+ numel = np .prod (shape )
47
+ x = (np .arange (numel , dtype = dtype )).reshape (shape )
48
+ return x
45
49
46
50
class BaseTestCase (TestCase ):
47
51
""" Base class used for all TensorBoard tests """
@@ -315,31 +319,31 @@ def test_empty_input(self):
315
319
316
320
def test_image_with_boxes (self ):
317
321
self .assertTrue (compare_proto (summary .image_boxes ('dummy' ,
318
- np . random . rand ( 3 , 32 , 32 ). astype ( np . float32 ),
322
+ tensor_N ( shape = ( 3 , 32 , 32 )),
319
323
np .array ([[10 , 10 , 40 , 40 ]])),
320
324
self ))
321
325
322
326
def test_image_with_one_channel (self ):
323
327
self .assertTrue (compare_proto (summary .image ('dummy' ,
324
- np . random . rand ( 1 , 8 , 8 ). astype ( np . float32 ),
328
+ tensor_N ( shape = ( 1 , 8 , 8 )),
325
329
dataformats = 'CHW' ),
326
330
self )) # noqa E127
327
331
328
332
def test_image_with_one_channel_batched (self ):
329
333
self .assertTrue (compare_proto (summary .image ('dummy' ,
330
- np . random . rand ( 2 , 1 , 8 , 8 ). astype ( np . float32 ),
334
+ tensor_N ( shape = ( 2 , 1 , 8 , 8 )),
331
335
dataformats = 'NCHW' ),
332
336
self )) # noqa E127
333
337
334
338
def test_image_with_3_channel_batched (self ):
335
339
self .assertTrue (compare_proto (summary .image ('dummy' ,
336
- np . random . rand ( 2 , 3 , 8 , 8 ). astype ( np . float32 ),
340
+ tensor_N ( shape = ( 2 , 3 , 8 , 8 )),
337
341
dataformats = 'NCHW' ),
338
342
self )) # noqa E127
339
343
340
344
def test_image_without_channel (self ):
341
345
self .assertTrue (compare_proto (summary .image ('dummy' ,
342
- np . random . rand ( 8 , 8 ). astype ( np . float32 ),
346
+ tensor_N ( shape = ( 8 , 8 )),
343
347
dataformats = 'HW' ),
344
348
self )) # noqa E127
345
349
@@ -348,56 +352,57 @@ def test_video(self):
348
352
import moviepy # noqa F401
349
353
except ImportError :
350
354
return
351
- self .assertTrue (compare_proto (summary .video ('dummy' , np . random . rand ( 4 , 3 , 1 , 8 , 8 ). astype ( np . float32 )), self ))
352
- summary .video ('dummy' , np .random .rand (16 , 48 , 1 , 28 , 28 ). astype ( np . float32 ) )
353
- summary .video ('dummy' , np .random .rand (20 , 7 , 1 , 8 , 8 ). astype ( np . float32 ) )
355
+ self .assertTrue (compare_proto (summary .video ('dummy' , tensor_N ( shape = ( 4 , 3 , 1 , 8 , 8 ))), self ))
356
+ summary .video ('dummy' , np .random .rand (16 , 48 , 1 , 28 , 28 ))
357
+ summary .video ('dummy' , np .random .rand (20 , 7 , 1 , 8 , 8 ))
354
358
355
359
def test_audio (self ):
356
- self .assertTrue (compare_proto (summary .audio ('dummy' , np . random . rand ( 42 )), self ))
360
+ self .assertTrue (compare_proto (summary .audio ('dummy' , tensor_N ( shape = ( 42 ,) )), self ))
357
361
358
362
def test_text (self ):
359
363
self .assertTrue (compare_proto (summary .text ('dummy' , 'text 123' ), self ))
360
364
361
365
def test_histogram_auto (self ):
362
- self .assertTrue (compare_proto (summary .histogram ('dummy' , np . random . rand ( 1024 ), bins = 'auto' , max_bins = 5 ), self ))
366
+ self .assertTrue (compare_proto (summary .histogram ('dummy' , tensor_N ( shape = ( 1024 ,) ), bins = 'auto' , max_bins = 5 ), self ))
363
367
364
368
def test_histogram_fd (self ):
365
- self .assertTrue (compare_proto (summary .histogram ('dummy' , np . random . rand ( 1024 ), bins = 'fd' , max_bins = 5 ), self ))
369
+ self .assertTrue (compare_proto (summary .histogram ('dummy' , tensor_N ( shape = ( 1024 ,) ), bins = 'fd' , max_bins = 5 ), self ))
366
370
367
371
def test_histogram_doane (self ):
368
- self .assertTrue (compare_proto (summary .histogram ('dummy' , np .random .rand (1024 ), bins = 'doane' , max_bins = 5 ), self ))
372
+ self .assertTrue (compare_proto (summary .histogram ('dummy' , tensor_N (shape = (1024 ,)), bins = 'doane' , max_bins = 5 ), self ))
373
+
374
+ def test_custom_scalars (self ):
375
+ layout = {'Taiwan' : {'twse' : ['Multiline' , ['twse/0050' , 'twse/2330' ]]},
376
+ 'USA' : {'dow' : ['Margin' , ['dow/aaa' , 'dow/bbb' , 'dow/ccc' ]],
377
+ 'nasdaq' : ['Margin' , ['nasdaq/aaa' , 'nasdaq/bbb' , 'nasdaq/ccc' ]]}}
378
+ summary .custom_scalars (layout ) # only smoke test. Because protobuf in python2/3 serialize dictionary differently.
369
379
370
380
def remove_whitespace (string ):
371
381
return string .replace (' ' , '' ).replace ('\t ' , '' ).replace ('\n ' , '' )
372
382
373
383
def compare_proto (str_to_compare , function_ptr ):
374
- # TODO: enable test after tensorboard is ready.
375
- return True
376
- if 'histogram' in function_ptr .id ():
377
- return # numpy.histogram has slight difference between versions
378
-
379
- if 'pr_curve' in function_ptr .id ():
380
- return # pr_curve depends on numpy.histogram
381
384
382
385
module_id = function_ptr .__class__ .__module__
386
+ test_dir = os .path .dirname (sys .modules [module_id ].__file__ )
383
387
functionName = function_ptr .id ().split ('.' )[- 1 ]
384
- test_file = os .path .realpath (sys .modules [module_id ].__file__ )
385
- expected_file = os .path .join (os .path .dirname (test_file ),
388
+ expected_file = os .path .join (test_dir ,
386
389
"expect" ,
387
- module_id .split ('.' )[- 1 ] + '.' + functionName + ".expect" )
390
+ 'TestTensorBoard.' + functionName + ".expect" )
391
+
388
392
assert os .path .exists (expected_file )
389
393
with open (expected_file ) as f :
390
394
expected = f .read ()
391
395
str_to_compare = str (str_to_compare )
396
+ # if not remove_whitespace(str_to_compare) == remove_whitespace(expected):
392
397
return remove_whitespace (str_to_compare ) == remove_whitespace (expected )
393
398
394
399
def write_proto (str_to_compare , function_ptr ):
395
400
module_id = function_ptr .__class__ .__module__
401
+ test_dir = os .path .dirname (sys .modules [module_id ].__file__ )
396
402
functionName = function_ptr .id ().split ('.' )[- 1 ]
397
- test_file = os .path .realpath (sys .modules [module_id ].__file__ )
398
- expected_file = os .path .join (os .path .dirname (test_file ),
403
+ expected_file = os .path .join (test_dir ,
399
404
"expect" ,
400
- module_id . split ( '.' )[ - 1 ] + ' .' + functionName + ".expect" )
405
+ 'TestTensorBoard .' + functionName + ".expect" )
401
406
with open (expected_file , 'w' ) as f :
402
407
f .write (str (str_to_compare ))
403
408
@@ -414,7 +419,7 @@ def forward(self, x):
414
419
return self .l (x )
415
420
416
421
with SummaryWriter (comment = 'LinearModel' ) as w :
417
- w .add_graph (myLinear (), dummy_input , True )
422
+ w .add_graph (myLinear (), dummy_input )
418
423
419
424
def test_mlp_graph (self ):
420
425
dummy_input = (torch .zeros (2 , 1 , 28 , 28 ),)
@@ -442,7 +447,7 @@ def forward(self, x, update_batch_stats=True):
442
447
return h
443
448
444
449
with SummaryWriter (comment = 'MLPModel' ) as w :
445
- w .add_graph (myMLP (), dummy_input , True )
450
+ w .add_graph (myMLP (), dummy_input )
446
451
447
452
def test_wrong_input_size (self ):
448
453
with self .assertRaises (RuntimeError ) as e_info :
@@ -527,7 +532,7 @@ def test_scalar(self):
527
532
528
533
@skipIfNoCaffe2
529
534
def test_caffe2_np (self ):
530
- workspace .FeedBlob ("testBlob" , np . random . randn ( 1 , 3 , 64 , 64 ). astype ( np . float32 ))
535
+ workspace .FeedBlob ("testBlob" , tensor_N ( shape = ( 1 , 3 , 64 , 64 )))
531
536
self .assertIsInstance (make_np ('testBlob' ), np .ndarray )
532
537
533
538
@skipIfNoCaffe2
0 commit comments