-
Notifications
You must be signed in to change notification settings - Fork 307
/
Copy pathbase.py
151 lines (120 loc) · 5.15 KB
/
base.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
"""BaseSynthesizer module."""
import contextlib
import functools

import numpy as np
import torch
@contextlib.contextmanager
def set_random_states(random_state, set_model_random_state):
    """Context manager for managing the random state.

    On entry, the global NumPy RNG state and the global torch RNG state
    (``torch.get_rng_state``) are replaced by the states held in ``random_state``.
    On exit, whether or not the body raised:

    1. The advanced global states are copied into fresh generators.
    2. Those generators are passed to ``set_model_random_state``, so the model's
       random stream continues where the body left it.
    3. The caller's original global states are restored.

    Args:
        random_state (int or tuple):
            The random seed or a tuple of (numpy.random.RandomState, torch.Generator).
            An int is expanded into a freshly seeded pair of generators.
        set_model_random_state (function):
            Function to set the random state on the model. It is called with a
            ``(numpy.random.RandomState, torch.Generator)`` tuple.
    """
    if isinstance(random_state, int):
        # Honour the documented int form by seeding both generators the same
        # way ``BaseSynthesizer.set_random_state`` does.
        random_state = (
            np.random.RandomState(seed=random_state),
            torch.Generator().manual_seed(random_state),
        )

    original_np_state = np.random.get_state()
    original_torch_state = torch.get_rng_state()
    random_np_state, random_torch_state = random_state
    np.random.set_state(random_np_state.get_state())
    torch.set_rng_state(random_torch_state.get_state())
    try:
        yield
    finally:
        try:
            # Snapshot the advanced global streams into standalone generators
            # and hand them back to the model.
            current_np_state = np.random.RandomState()
            current_np_state.set_state(np.random.get_state())
            current_torch_state = torch.Generator()
            current_torch_state.set_state(torch.get_rng_state())
            set_model_random_state((current_np_state, current_torch_state))
        finally:
            # Always put the caller's global RNG state back, even if the
            # callback above raised.
            np.random.set_state(original_np_state)
            torch.set_rng_state(original_torch_state)
def random_state(function):
    """Set the random state before calling the function.

    The returned wrapper checks ``self.random_states``:

    - If it is ``None``, the wrapped method runs unchanged.
    - Otherwise the method runs inside ``set_random_states``. The instance's
      generators drive the global NumPy/torch RNGs for the duration of the call,
      and the advanced state is stored back through ``self.set_random_state``.

    Args:
        function (Callable):
            The function to wrap around. Its first positional argument is the
            synthesizer instance.

    Returns:
        Callable:
            The wrapper. Its ``__name__``, ``__doc__`` and ``__wrapped__`` are
            copied from ``function``.
    """

    @functools.wraps(function)
    def wrapper(self, *args, **kwargs):
        if self.random_states is None:
            return function(self, *args, **kwargs)

        with set_random_states(self.random_states, self.set_random_state):
            return function(self, *args, **kwargs)

    return wrapper
class BaseSynthesizer:
    """Base class for all default synthesizers of ``CTGAN``."""

    # Either None or a ``(np.random.RandomState, torch.Generator)`` pair set via
    # ``set_random_state``.
    random_states = None

    @staticmethod
    def _is_random_states_pair(value):
        """Return whether ``value`` is a ``(np.random.RandomState, torch.Generator)`` 2-tuple."""
        return (
            isinstance(value, tuple)
            and len(value) == 2
            and isinstance(value[0], np.random.RandomState)
            and isinstance(value[1], torch.Generator)
        )

    def __getstate__(self):
        """Improve pickling state for ``BaseSynthesizer``.

        Convert to ``cpu`` device before starting the pickling process in order to be able to
        load the model even when used from an external tool such as ``SDV``. Also, if
        ``random_states`` are set, store their states as dictionaries rather than generators.

        The original device is restored afterwards, even if copying the state fails.

        Returns:
            dict:
                Python dict representing the object.
        """
        device_backup = self._device
        self.set_device(torch.device('cpu'))
        try:
            # NOTE(review): this is a shallow copy. If a subclass's ``set_device``
            # moves sub-modules in place, restoring the device below also moves
            # the objects referenced by ``state``. Confirm the CPU conversion
            # really survives into the pickle.
            state = self.__dict__.copy()
        finally:
            self.set_device(device_backup)

        if self._is_random_states_pair(self.random_states):
            state['_numpy_random_state'] = self.random_states[0].get_state()
            state['_torch_random_state'] = self.random_states[1].get_state()
            # The generators themselves are not stored; ``__setstate__`` rebuilds them.
            state.pop('random_states', None)

        return state

    def __setstate__(self, state):
        """Restore the state of a ``BaseSynthesizer``.

        Restore the ``random_states`` from the state dict if those are present and then
        set the device according to the current hardware.
        """
        if '_numpy_random_state' in state and '_torch_random_state' in state:
            np_state = state.pop('_numpy_random_state')
            torch_state = state.pop('_torch_random_state')

            current_torch_state = torch.Generator()
            current_torch_state.set_state(torch_state)

            current_numpy_state = np.random.RandomState()
            current_numpy_state.set_state(np_state)
            state['random_states'] = (current_numpy_state, current_torch_state)

        self.__dict__ = state
        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        self.set_device(device)

    def save(self, path):
        """Save the model in the passed `path`.

        The model is moved to CPU for serialization and moved back to its previous
        device afterwards, even if ``torch.save`` raises.
        """
        device_backup = self._device
        self.set_device(torch.device('cpu'))
        try:
            torch.save(self, path)
        finally:
            self.set_device(device_backup)

    @classmethod
    def load(cls, path):
        """Load the model stored in the passed `path`.

        The loaded model is moved to ``cuda:0`` when CUDA is available, otherwise to CPU.

        Security: ``weights_only=False`` performs a full unpickle, which can execute
        arbitrary code. Only load files from trusted sources.
        """
        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        model = torch.load(path, weights_only=False)
        model.set_device(device)
        return model

    def set_random_state(self, random_state):
        """Set the random state.

        Args:
            random_state (int, tuple, or None):
                Either a tuple containing the (numpy.random.RandomState, torch.Generator)
                or an int representing the random seed to use for both random states.
                ``None`` clears the random state.

        Raises:
            TypeError:
                If ``random_state`` is neither None, an int, nor a 2-tuple of
                (``np.random.RandomState``, ``torch.Generator``).
        """
        if random_state is None:
            self.random_states = None
        elif isinstance(random_state, int):
            self.random_states = (
                np.random.RandomState(seed=random_state),
                torch.Generator().manual_seed(random_state),
            )
        elif self._is_random_states_pair(random_state):
            self.random_states = random_state
        else:
            raise TypeError(
                f'`random_state` {random_state} expected to be an int or a tuple of '
                '(`np.random.RandomState`, `torch.Generator`)'
            )