Skip to content

Commit 9afaadf

Browse files
committed
fixed validation fun in bern/poiss
1 parent 223d3c0 commit 9afaadf

File tree

2 files changed

+11
-5
lines changed

2 files changed

+11
-5
lines changed

ngclearn/components/input_encoders/bernoulliCell.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,9 +101,12 @@ def __init__(self, name, n_units, target_freq=0., batch_size=1, **kwargs):
101101
self.outputs = Compartment(restVals, display_name="Spikes") # output compartment
102102
self.tols = Compartment(restVals, display_name="Time-of-Last-Spike", units="ms") # time of last spike
103103

104-
def validate(self, dt, **validation_kwargs):
105-
## check for unstable combinations of dt and target-frequency meta-params
104+
def validate(self, dt=None, **validation_kwargs):
106105
valid = super().validate(**validation_kwargs)
106+
if dt is None:
107+
warn(f"{self.name} requires a validation kwarg of `dt`")
108+
return False
109+
## check for unstable combinations of dt and target-frequency meta-params
107110
events_per_timestep = (dt/1000.) * self.target_freq ## compute scaled probability
108111
if events_per_timestep > 1.:
109112
valid = False

ngclearn/components/input_encoders/poissonCell.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,10 +56,13 @@ def __init__(self, name, n_units, target_freq=63.75, batch_size=1,
5656
random.uniform(subkey, (self.batch_size, self.n_units), minval=0.,
5757
maxval=1.))
5858

59-
def validate(self, dt, **validation_kwargs):
60-
## check for unstable combinations of dt and target-frequency meta-params
59+
def validate(self, dt=None, **validation_kwargs):
6160
valid = super().validate(**validation_kwargs)
62-
events_per_timestep = (dt/1000.) * self.target_freq ## compute scaled probability
61+
if dt is None:
62+
warn(f"{self.name} requires a validation kwarg of `dt`")
63+
return False
64+
## check for unstable combinations of dt and target-frequency meta-params
65+
events_per_timestep = (dt / 1000.) * self.target_freq ## compute scaled probability
6366
if events_per_timestep > 1.:
6467
valid = False
6568
warn(

0 commit comments

Comments
 (0)