13
13
# limitations under the License.
14
14
15
15
import collections
16
- import functools
17
16
import itertools
18
17
import threading
19
18
import warnings
@@ -1903,14 +1902,22 @@ def _walk_up_rv(rv, formatting='plain'):
1903
1902
return all_rvs
1904
1903
1905
1904
1906
- def _repr_deterministic_rv (rv , formatting = 'plain' ):
1907
- """Make latex string for a Deterministic variable"""
1908
- if formatting == 'latex' :
1909
- return r"$\text{{{name}}} \sim \text{{Deterministic}}({args})$" .format (
1910
- name = rv .name , args = r",~" .join (_walk_up_rv (rv , formatting = formatting )))
1911
- else :
1912
- return "{name} ~ Deterministic({args})" .format (
1913
- name = rv .name , args = ", " .join (_walk_up_rv (rv , formatting = formatting )))
1905
+ class DeterministicWrapper (tt .TensorVariable ):
1906
+ def _str_repr (self , formatting = 'plain' ):
1907
+ if formatting == 'latex' :
1908
+ return r"$\text{{{name}}} \sim \text{{Deterministic}}({args})$" .format (
1909
+ name = self .name , args = r",~" .join (_walk_up_rv (self , formatting = formatting )))
1910
+ else :
1911
+ return "{name} ~ Deterministic({args})" .format (
1912
+ name = self .name , args = ", " .join (_walk_up_rv (self , formatting = formatting )))
1913
+
1914
+ def _repr_latex_ (self ):
1915
+ return self ._str_repr (formatting = 'latex' )
1916
+
1917
+ __latex__ = _repr_latex_
1918
+
1919
+ def __str__ (self ):
1920
+ return self ._str_repr (formatting = 'plain' )
1914
1921
1915
1922
1916
1923
def Deterministic (name , var , model = None , dims = None ):
@@ -1929,15 +1936,7 @@ def Deterministic(name, var, model=None, dims=None):
1929
1936
var = var .copy (model .name_for (name ))
1930
1937
model .deterministics .append (var )
1931
1938
model .add_random_variable (var , dims )
1932
- var ._repr_latex_ = functools .partial (_repr_deterministic_rv , var , formatting = 'latex' )
1933
- var .__latex__ = var ._repr_latex_
1934
-
1935
- # simply assigning var.__str__ is not enough, since str() will default to the class-
1936
- # defined __str__ anyway; see https://stackoverflow.com/a/5918210/1692028
1937
- old_type = type (var )
1938
- new_type = type (old_type .__name__ + '_pymc3_Deterministic' , (old_type ,),
1939
- {'__str__' : functools .partial (_repr_deterministic_rv , var , formatting = 'plain' )})
1940
- var .__class__ = new_type
1939
+ var .__class__ = DeterministicWrapper # adds str and latex functionality
1941
1940
1942
1941
return var
1943
1942
0 commit comments