-
Notifications
You must be signed in to change notification settings - Fork 235
Expand file tree
/
Copy pathrandom_shift_pipeline.py
More file actions
77 lines (56 loc) · 2.3 KB
/
random_shift_pipeline.py
File metadata and controls
77 lines (56 loc) · 2.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
# Copyright 2020-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
"""RandomShiftPipeline module."""
import numpy as np
from openfl.pipelines.pipeline import NumpyArrayToBytes, TransformationPipeline, Transformer
class RandomShiftTransformer(Transformer):
    """Transformer that masks an array by adding a random per-element shift.

    ``forward`` adds a shift drawn with ``np.random.uniform(-20, 20)`` (cast to
    float32) to every element and records each shift value in the metadata.
    ``backward`` rebuilds that shift from the metadata and subtracts it.
    """

    def __init__(self):
        """Initialize RandomShiftTransformer."""
        # The full shift is stored in the metadata, so the transform is flagged
        # as not lossy.
        # NOTE(review): (data + shift) - shift is subject to floating-point
        # rounding, so recovery may not be bit-exact.
        self.lossy = False

    def forward(self, data, **kwargs):
        """Forward pass - apply a random shift to the data.

        Args:
            data: The array to be transformed. It must expose ``.shape`` and
                support elementwise addition with a NumPy array.
            **kwargs: Unused; accepted for interface compatibility.

        Returns:
            transformed_data: ``data`` plus the random shift.
            metadata: Dict with ``"int_to_float"`` (flat C-order index ->
                float32 shift value) and ``"int_list"`` (the shape of ``data``
                as a list).
        """
        shape = data.shape
        # Uses NumPy's global random state; the shift has the same shape as data.
        random_shift = np.random.uniform(low=-20, high=20, size=shape).astype(np.float32)
        transformed_data = data + random_shift

        # The metadata only offers an int->float map, so the shift array is
        # stored element by element, keyed by its flat C-order index.
        metadata = {
            "int_to_float": dict(enumerate(random_shift.ravel(order="C"))),
            "int_list": list(shape),
        }
        return transformed_data, metadata

    def backward(self, data, metadata, **kwargs):
        """Backward pass - remove the random shift applied by ``forward``.

        Implements the transformation needed when going in the opposite
        direction to the ``forward`` method.

        Args:
            data: The transformed data.
            metadata: The metadata produced by ``forward``: ``"int_list"``
                holds the shape and ``"int_to_float"`` maps flat C-order
                indices 0..N-1 to shift values.
            **kwargs: Unused; accepted for interface compatibility.

        Returns:
            The data with the recorded shift subtracted.
        """
        shape = tuple(metadata["int_list"])
        int_to_float = metadata["int_to_float"]
        # The int->float dict is used as an array here: read it back in index
        # order to rebuild the flat shift.
        flat_shift = np.array([int_to_float[idx] for idx in range(len(int_to_float))])
        # Shape is passed positionally: the ``newshape=`` keyword of
        # np.reshape is deprecated since NumPy 2.1.
        shift = np.reshape(flat_shift, shape, order="C")
        return data - shift
class RandomShiftPipeline(TransformationPipeline):
    """Pipeline built from a RandomShiftTransformer followed by NumpyArrayToBytes."""

    def __init__(self, **kwargs):
        """Assemble the transformer list and hand it to the base pipeline.

        Args:
            **kwargs: Accepted but ignored.
        """
        # Listed in the same order as before; presumably the base pipeline
        # applies transformers in list order (shift first, then the
        # array-to-bytes step) — confirm against TransformationPipeline.
        stages = [
            RandomShiftTransformer(),
            NumpyArrayToBytes(),
        ]
        super().__init__(transformers=stages)