@@ -312,58 +312,50 @@ class SymmetricalLogScale(ScaleBase):
312
312
name = 'symlog'
313
313
314
314
class SymmetricalLogTransform (Transform ):
315
+ input_dims = 1
316
+ output_dims = 1
317
+ is_separable = True
318
+
319
+ def __init__ (self , base , linthresh ):
320
+ Transform .__init__ (self )
321
+ self .base = base
322
+ self .linthresh = linthresh
323
+ self ._log_base = np .log (base )
324
+ self ._linadjust = (np .log (linthresh ) / self ._log_base ) / linthresh
325
+
326
+ def transform (self , a ):
327
+ sign = np .sign (a )
328
+ masked = ma .masked_inside (a , - self .linthresh , self .linthresh , copy = False )
329
+ log = sign * self .linthresh * (1 + ma .log (np .abs (masked ) / self .linthresh ))
330
+ if masked .mask .any ():
331
+ return ma .where (masked .mask , a , log )
332
+ else :
333
+ return log
334
+
335
+ def inverted (self ):
336
+ return SymmetricalLogScale .InvertedSymmetricalLogTransform (self .base , self .linthresh )
337
+
338
+ class InvertedSymmetricalLogTransform (Transform ):
315
339
input_dims = 1
316
340
output_dims = 1
317
341
is_separable = True
318
342
319
343
def __init__ (self , base , linthresh ):
320
344
Transform .__init__ (self )
321
345
self .base = base
322
- self .linthresh = abs ( linthresh )
346
+ self .linthresh = linthresh
323
347
self ._log_base = np .log (base )
324
- logb_linthresh = np .log (linthresh ) / self ._log_base
325
- self ._linadjust = 1.0 - logb_linthresh
326
- self ._linscale = 1.0 / linthresh
348
+ self ._log_linthresh = np .log (linthresh ) / self ._log_base
349
+ self ._linadjust = linthresh / (np .log (linthresh ) / self ._log_base )
327
350
328
351
def transform (self , a ):
329
- a = np .asarray (a )
330
352
sign = np .sign (a )
331
353
masked = ma .masked_inside (a , - self .linthresh , self .linthresh , copy = False )
354
+ exp = sign * self .linthresh * ma .exp (sign * masked / self .linthresh - 1 )
332
355
if masked .mask .any ():
333
- log = sign * (ma .log (np .abs (masked )) / self ._log_base + self ._linadjust )
334
- return np .asarray (ma .where (masked .mask , a * self ._linscale , log ))
356
+ return ma .where (masked .mask , a , exp )
335
357
else :
336
- return sign * (np .log (np .abs (a )) / self ._log_base + self ._linadjust )
337
-
338
- def inverted (self ):
339
- return SymmetricalLogScale .InvertedSymmetricalLogTransform (
340
- self .base , self .linthresh )
341
-
342
- class InvertedSymmetricalLogTransform (Transform ):
343
- input_dims = 1
344
- output_dims = 1
345
- is_separable = True
346
-
347
- def __init__ (self , base , linthresh ):
348
- Transform .__init__ (self )
349
- self .base = base
350
- self .linthresh = linthresh
351
- log_base = np .log (base )
352
- logb_linthresh = np .log (linthresh ) / log_base
353
- self ._linadjust = 1.0 - logb_linthresh
354
-
355
- def transform (self , a ):
356
- a = np .asarray (a )
357
- sign = np .sign (a )
358
- masked = ma .masked_inside (a , - 1.0 , 1.0 , copy = False )
359
- result = np .where ((a >= - 1.0 ) & (a <= 1.0 ),
360
- a * self .linthresh ,
361
- sign * np .power (self .base , np .abs (a - sign * self ._linadjust )))
362
- return result
363
-
364
- def inverted (self ):
365
- return SymmetricalLogScale .SymmetricalLogTransform (
366
- self .base , self .linthresh )
358
+ return exp
367
359
368
360
def __init__ (self , axis , ** kwargs ):
369
361
"""
@@ -395,7 +387,7 @@ def __init__(self, axis, **kwargs):
395
387
396
388
assert base > 0.0
397
389
assert linthresh > 0.0
398
-
390
+
399
391
self .base = base
400
392
self .linthresh = linthresh
401
393
self .subs = subs
0 commit comments