Skip to content

Commit

Permalink
Add default values for metric/optimizer in Fitters (brian-team#5) (br…
Browse files Browse the repository at this point in the history
  • Loading branch information
Eslam Khaled committed Apr 7, 2021
1 parent cd05601 commit a109617
Showing 1 changed file with 12 additions and 3 deletions.
15 changes: 12 additions & 3 deletions brian2modelfitting/fitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,15 @@

from brian2 import (NeuronGroup, defaultclock, get_device, Network,
StateMonitor, SpikeMonitor, second, get_local_namespace,
Quantity, get_logger)
Quantity, get_logger, ms)
from brian2.input import TimedArray
from brian2.equations.equations import Equations, SUBEXPRESSION
from brian2.devices import set_device, reset_device, device
from brian2.devices.cpp_standalone.device import CPPStandaloneDevice
from brian2.core.functions import Function
from .simulator import RuntimeSimulator, CPPStandaloneSimulator
from .metric import Metric, SpikeMetric, TraceMetric, MSEMetric, normalize_weights
from .optimizer import Optimizer
from .metric import Metric, SpikeMetric, TraceMetric, MSEMetric, GammaFactor, normalize_weights
from .optimizer import Optimizer, NevergradOptimizer
from .utils import callback_setup, make_dic


Expand Down Expand Up @@ -577,6 +577,9 @@ def fit(self, optimizer, metric=None, n_rounds=1, callback='text',
if penalty is None:
penalty = self.penalty

if optimizer is None:
optimizer = NevergradOptimizer()

if self.optimizer is None or restart:
if start_iteration is None:
self.iteration = 0
Expand Down Expand Up @@ -862,6 +865,9 @@ def calc_errors(self, metric):
def fit(self, optimizer, metric=None, n_rounds=1, callback='text',
restart=False, start_iteration=None, penalty=None,
level=0, **params):
if metric is None:
metric = MSEMetric()

if not isinstance(metric, TraceMetric):
raise TypeError("You can only use TraceMetric child metric with "
"TraceFitter")
Expand Down Expand Up @@ -1149,6 +1155,9 @@ def calc_errors(self, metric):
def fit(self, optimizer, metric=None, n_rounds=1, callback='text',
restart=False, start_iteration=None, penalty=None,
level=0, **params):
if metric is None:
metric = GammaFactor(delta=1*ms, time=60*ms)

if not isinstance(metric, SpikeMetric):
raise TypeError("You can only use SpikeMetric child metric with "
"SpikeFitter")
Expand Down

0 comments on commit a109617

Please sign in to comment.