-
Notifications
You must be signed in to change notification settings - Fork 346
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[MRG] add TimeSeriesScaleMeanMaxVariance() #333
Open
tonylee2016
wants to merge
5
commits into
tslearn-team:main
Choose a base branch
from
tonylee2016:TimeSeriesScaleMeanMaxVariance
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
5 commits
Select commit
Hold shift + click to select a range
72d67e3
add TimeSeriesScaleMeanMaxVariance
tonylee2016 e7eaea5
some update and fix
tonylee2016 58ac80a
Add compile.py
galaxie500 2b98526
add .whl for docker-airflow build
galaxie500 2819398
Merge pull request #1 from galaxie500/TimeSeriesScaleMeanMaxVariance
tonylee2016 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -16,7 +16,7 @@ __pycache__/ | |
env/ | ||
build/ | ||
develop-eggs/ | ||
dist/ | ||
#dist/ | ||
downloads/ | ||
eggs/ | ||
.eggs/ | ||
|
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
# coding: utf-8 | ||
import os | ||
import fnmatch | ||
import sysconfig | ||
from setuptools import setup, find_packages | ||
from setuptools.command.build_py import build_py as _build_py | ||
from Cython.Build import cythonize | ||
|
||
EXCLUDE_FILES = [ | ||
'tslearn/main.py' | ||
] | ||
|
||
|
||
def get_ext_paths(root_dir, exclude_files): | ||
"""get filepaths for compilation""" | ||
paths = [] | ||
|
||
for root, dirs, files in os.walk(root_dir): | ||
for filename in files: | ||
if os.path.splitext(filename)[1] != '.py': | ||
continue | ||
|
||
file_path = os.path.join(root, filename) | ||
if file_path in exclude_files: | ||
continue | ||
|
||
paths.append(file_path) | ||
return paths | ||
|
||
|
||
class build_py(_build_py): | ||
|
||
def find_package_modules(self, package, package_dir): | ||
ext_suffix = sysconfig.get_config_var('EXT_SUFFIX') | ||
modules = super().find_package_modules(package, package_dir) | ||
filtered_modules = [] | ||
for (pkg, mod, filepath) in modules: | ||
if os.path.exists(filepath.replace('.py', ext_suffix)): | ||
continue | ||
filtered_modules.append((pkg, mod, filepath, )) | ||
return filtered_modules | ||
|
||
|
||
setup( | ||
name='tslearn', | ||
version='0.0.0', | ||
packages=find_packages(), | ||
ext_modules=cythonize( | ||
get_ext_paths('tslearn', EXCLUDE_FILES), | ||
compiler_directives={'language_level': 3} | ||
), | ||
cmdclass={ | ||
'build_py':build_py | ||
} | ||
) | ||
|
||
# to compile working directory to wheel | ||
#$ python setup.py bdist_wheel | ||
#$ unzip tslearn-0.0.0-cp38-cp38-linux_x86_64.whl |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,16 +1,18 @@ | ||
""" | ||
The :mod:`tslearn.preprocessing` module gathers time series scalers and | ||
The :mod:`tslearn.preprocessing` module gathers time series scalers and | ||
resamplers. | ||
""" | ||
|
||
from .preprocessing import ( | ||
TimeSeriesScalerMeanVariance, | ||
TimeSeriesScalerMinMax, | ||
TimeSeriesResampler | ||
TimeSeriesResampler, | ||
TimeSeriesScaleMeanMaxVariance | ||
) | ||
|
||
__all__ = [ | ||
"TimeSeriesResampler", | ||
"TimeSeriesScalerMinMax", | ||
"TimeSeriesScalerMeanVariance" | ||
"TimeSeriesScalerMeanVariance", | ||
"TimeSeriesScaleMeanMaxVariance" | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -296,3 +296,57 @@ def transform(self, X, y=None, **kwargs): | |
|
||
def _more_tags(self): | ||
return {'allow_nan': True} | ||
|
||
|
||
class TimeSeriesScaleMeanMaxVariance(TimeSeriesScalerMeanVariance): | ||
"""Scaler for time series. Scales time series so that their mean (resp. | ||
standard deviation) in the signal with the max amplitue is | ||
mu (resp. std). The scaling relationships between each signal are preserved | ||
This is supplement to the TimeSeriesScalerMeanVariance method | ||
|
||
Parameters | ||
---------- | ||
mu : float (default: 0.) | ||
Mean of the output time series. | ||
std : float (default: 1.) | ||
Standard deviation of the output time series. | ||
|
||
Notes | ||
----- | ||
This method requires a dataset of equal-sized time series. | ||
|
||
NaNs within a time series are ignored when calculating mu and std. | ||
""" | ||
|
||
def transform(self, X, y=None, **kwargs): | ||
"""Fit to data, then transform it. | ||
|
||
Parameters | ||
---------- | ||
X : array-like of shape (n_ts, sz, d) | ||
Time series dataset to be rescaled | ||
|
||
Returns | ||
------- | ||
numpy.ndarray | ||
Rescaled time series dataset | ||
""" | ||
check_is_fitted(self, '_X_fit_dims') | ||
X = check_array(X, allow_nd=True, force_all_finite=False) | ||
X_ = to_time_series_dataset(X) | ||
X_ = check_dims(X_, X_fit_dims=self._X_fit_dims, extend=False) | ||
mean_t = numpy.nanmean(X_, axis=1, keepdims=True) | ||
std_t = numpy.nanstd(X_, axis=1, keepdims=True) | ||
# retain the scaling relation cross the signals, | ||
# the max std_t is set to self.std | ||
max_std = max(std_t) | ||
if max_std ==0.: max_std = 1 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Split up over multiple lines. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. And add space to right and left of |
||
X_ = (X_ - mean_t) * self.std / max_std + self.mu | ||
|
||
return X_ | ||
|
||
def _more_tags(self): | ||
return {'allow_nan': True, '_skip_test': True} | ||
|
||
|
||
|
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Are these comments still relevant? Does it need to be stateful? I.e. should there be a fit() and transform() where in the fit() the self.std is stored and used in the transform() method?