Skip to content

Commit da04a98

Browse files
authored
Add deserialize module (#489)
1 parent 31eaf56 commit da04a98

File tree

5 files changed

+686
-0
lines changed

5 files changed

+686
-0
lines changed

docs/api_reference.rst

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,17 @@ Prior
6161
Censored
6262
Scaled
6363

64+
Deserialize
65+
===========
66+
67+
.. currentmodule:: pymc_extras.deserialize
68+
.. autosummary::
69+
:toctree: generated/
70+
71+
deserialize
72+
register_deserialization
73+
Deserializer
74+
6475

6576
Transforms
6677
==========

pymc_extras/deserialize.py

Lines changed: 224 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,224 @@
1+
"""Deserialize dictionaries into Python objects.
2+
3+
This is a two step process:
4+
5+
1. Determine if the data is of the correct type.
6+
2. Deserialize the data into a python object.
7+
8+
Examples
9+
--------
10+
Make use of the already registered deserializers:
11+
12+
.. code-block:: python
13+
14+
from pymc_extras.deserialize import deserialize
15+
16+
prior_class_data = {
17+
"dist": "Normal",
18+
"kwargs": {"mu": 0, "sigma": 1}
19+
}
20+
prior = deserialize(prior_class_data)
21+
# Prior("Normal", mu=0, sigma=1)
22+
23+
Register custom class deserialization:
24+
25+
.. code-block:: python
26+
27+
from pymc_extras.deserialize import register_deserialization
28+
29+
class MyClass:
30+
def __init__(self, value: int):
31+
self.value = value
32+
33+
def to_dict(self) -> dict:
34+
# Example of what the to_dict method might look like.
35+
return {"value": self.value}
36+
37+
register_deserialization(
38+
is_type=lambda data: data.keys() == {"value"} and isinstance(data["value"], int),
39+
deserialize=lambda data: MyClass(value=data["value"]),
40+
)
41+
42+
Deserialize data into that custom class:
43+
44+
.. code-block:: python
45+
46+
from pymc_extras.deserialize import deserialize
47+
48+
data = {"value": 42}
49+
obj = deserialize(data)
50+
assert isinstance(obj, MyClass)
51+
52+
53+
"""
54+
55+
from collections.abc import Callable
56+
from dataclasses import dataclass
57+
from typing import Any
58+
59+
IsType = Callable[[Any], bool]
60+
Deserialize = Callable[[Any], Any]
61+
62+
63+
@dataclass
64+
class Deserializer:
65+
"""Object to store information required for deserialization.
66+
67+
All deserializers should be stored via the :func:`register_deserialization` function
68+
instead of creating this object directly.
69+
70+
Attributes
71+
----------
72+
is_type : IsType
73+
Function to determine if the data is of the correct type.
74+
deserialize : Deserialize
75+
Function to deserialize the data.
76+
77+
Examples
78+
--------
79+
.. code-block:: python
80+
81+
from typing import Any
82+
83+
class MyClass:
84+
def __init__(self, value: int):
85+
self.value = value
86+
87+
from pymc_extras.deserialize import Deserializer
88+
89+
def is_type(data: Any) -> bool:
90+
return data.keys() == {"value"} and isinstance(data["value"], int)
91+
92+
def deserialize(data: dict) -> MyClass:
93+
return MyClass(value=data["value"])
94+
95+
deserialize_logic = Deserializer(is_type=is_type, deserialize=deserialize)
96+
97+
"""
98+
99+
is_type: IsType
100+
deserialize: Deserialize
101+
102+
103+
DESERIALIZERS: list[Deserializer] = []
104+
105+
106+
class DeserializableError(Exception):
107+
"""Error raised when data cannot be deserialized."""
108+
109+
def __init__(self, data: Any):
110+
self.data = data
111+
super().__init__(
112+
f"Couldn't deserialize {data}. Use register_deserialization to add a deserialization mapping."
113+
)
114+
115+
116+
def deserialize(data: Any) -> Any:
117+
"""Deserialize a dictionary into a Python object.
118+
119+
Use the :func:`register_deserialization` function to add custom deserializations.
120+
121+
Deserialization is a two step process due to the dynamic nature of the data:
122+
123+
1. Determine if the data is of the correct type.
124+
2. Deserialize the data into a Python object.
125+
126+
Each registered deserialization is checked in order until one is found that can
127+
deserialize the data. If no deserialization is found, a :class:`DeserializableError` is raised.
128+
129+
A :class:`DeserializableError` is raised when the data fails to be deserialized
130+
by any of the registered deserializers.
131+
132+
Parameters
133+
----------
134+
data : Any
135+
The data to deserialize.
136+
137+
Returns
138+
-------
139+
Any
140+
The deserialized object.
141+
142+
Raises
143+
------
144+
DeserializableError
145+
Raised when the data doesn't match any registered deserializations
146+
or fails to be deserialized.
147+
148+
Examples
149+
--------
150+
Deserialize a :class:`pymc_extras.prior.Prior` object:
151+
152+
.. code-block:: python
153+
154+
from pymc_extras.deserialize import deserialize
155+
156+
data = {"dist": "Normal", "kwargs": {"mu": 0, "sigma": 1}}
157+
prior = deserialize(data)
158+
# Prior("Normal", mu=0, sigma=1)
159+
160+
"""
161+
for mapping in DESERIALIZERS:
162+
try:
163+
is_type = mapping.is_type(data)
164+
except Exception:
165+
is_type = False
166+
167+
if not is_type:
168+
continue
169+
170+
try:
171+
return mapping.deserialize(data)
172+
except Exception as e:
173+
raise DeserializableError(data) from e
174+
else:
175+
raise DeserializableError(data)
176+
177+
178+
def register_deserialization(is_type: IsType, deserialize: Deserialize) -> None:
179+
"""Register an arbitrary deserialization.
180+
181+
Use the :func:`deserialize` function to then deserialize data using all registered
182+
deserialize functions.
183+
184+
Parameters
185+
----------
186+
is_type : Callable[[Any], bool]
187+
Function to determine if the data is of the correct type.
188+
deserialize : Callable[[dict], Any]
189+
Function to deserialize the data of that type.
190+
191+
Examples
192+
--------
193+
Register a custom class deserialization:
194+
195+
.. code-block:: python
196+
197+
from pymc_extras.deserialize import register_deserialization
198+
199+
class MyClass:
200+
def __init__(self, value: int):
201+
self.value = value
202+
203+
def to_dict(self) -> dict:
204+
# Example of what the to_dict method might look like.
205+
return {"value": self.value}
206+
207+
register_deserialization(
208+
is_type=lambda data: data.keys() == {"value"} and isinstance(data["value"], int),
209+
deserialize=lambda data: MyClass(value=data["value"]),
210+
)
211+
212+
Use that custom class deserialization:
213+
214+
.. code-block:: python
215+
216+
from pymc_extras.deserialize import deserialize
217+
218+
data = {"value": 42}
219+
obj = deserialize(data)
220+
assert isinstance(obj, MyClass)
221+
222+
"""
223+
mapping = Deserializer(is_type=is_type, deserialize=deserialize)
224+
DESERIALIZERS.append(mapping)

0 commit comments

Comments
 (0)