@@ -29,7 +29,7 @@ class TraceSTDPConvSynapse(ConvSynapse): ## trace-based STDP convolutional cable
2929 Args:
3030 name: the string name of this cell
3131
32- x_size: dimension of input signal (assuming a square input)
32+ x_shape: 2d shape of input map signal (component currently assumess a square input maps )
3333
3434 shape: tuple specifying shape of this synaptic cable (usually a 4-tuple
3535 with number `filter height x filter width x input channels x number output channels`);
@@ -55,7 +55,7 @@ class TraceSTDPConvSynapse(ConvSynapse): ## trace-based STDP convolutional cable
5555
5656 padding: pre-operator padding to use -- "VALID" (none), "SAME"
5757
58- resist_scale: aa fixed (resistance) scaling factor to apply to synaptic
58+ resist_scale: a fixed (resistance) scaling factor to apply to synaptic
5959 transform (Default: 1.), i.e., yields: out = ((K @ in) * resist_scale) + b
6060 where `@` denotes convolution
6161
@@ -69,10 +69,10 @@ class TraceSTDPConvSynapse(ConvSynapse): ## trace-based STDP convolutional cable
6969 """
7070
7171 # Define Functions
72- def __init__ (self , name , shape , x_size , A_plus , A_minus , eta = 0. ,
72+ def __init__ (self , name , shape , x_shape , A_plus , A_minus , eta = 0. ,
7373 pretrace_target = 0. , filter_init = None , stride = 1 , padding = None ,
7474 resist_scale = 1. , w_bound = 0. , w_decay = 0. , batch_size = 1 , ** kwargs ):
75- super ().__init__ (name , shape , x_size = x_size , filter_init = filter_init ,
75+ super ().__init__ (name , shape , x_shape = x_shape , filter_init = filter_init ,
7676 bias_init = None , resist_scale = resist_scale , stride = stride ,
7777 padding = padding , batch_size = batch_size , ** kwargs )
7878
@@ -97,6 +97,15 @@ def __init__(self, name, shape, x_size, A_plus, A_minus, eta=0.,
9797 ## Shape error correction -- do shape correction inference for local updates
9898 self ._init (self .batch_size , self .x_size , self .shape , self .stride ,
9999 self .padding , self .pad_args , self .weights )
100+ k_size , k_size , n_in_chan , n_out_chan = self .shape
101+ if padding == "SAME" :
102+ self .antiPad = _conv_same_transpose_padding (
103+ self .postSpike .value .shape [1 ],
104+ self .x_size , k_size , stride )
105+ elif padding == "VALID" :
106+ self .antiPad = _conv_valid_transpose_padding (
107+ self .postSpike .value .shape [1 ],
108+ self .x_size , k_size , stride )
100109 ########################################################################
101110
102111 def _init (self , batch_size , x_size , shape , stride , padding , pad_args ,
@@ -147,17 +156,17 @@ def evolve(self, weights, dWeights):
147156 self .dWeights .set (dWeights )
148157
149158 @staticmethod
150- def _backtransmit (x_size , shape , stride , padding , x_delta_shape ,
151- preSpike , postSpike , weights ): ## action-backpropagating routine
159+ def _backtransmit (x_size , shape , stride , padding , x_delta_shape , antiPad ,
160+ postSpike , weights ): ## action-backpropagating routine
152161 ## calc dInputs - adjustment w.r.t. input signal
153162 k_size , k_size , n_in_chan , n_out_chan = shape
154- antiPad = None
155- if padding == "SAME" :
156- antiPad = _conv_same_transpose_padding (postSpike .shape [1 ], x_size ,
157- k_size , stride )
158- elif padding == "VALID" :
159- antiPad = _conv_valid_transpose_padding (postSpike .shape [1 ], x_size ,
160- k_size , stride )
163+ # antiPad = None
164+ # if padding == "SAME":
165+ # antiPad = _conv_same_transpose_padding(postSpike.shape[1], x_size,
166+ # k_size, stride)
167+ # elif padding == "VALID":
168+ # antiPad = _conv_valid_transpose_padding(postSpike.shape[1], x_size,
169+ # k_size, stride)
161170 dInputs = calc_dX_conv (weights , postSpike , delta_shape = x_delta_shape ,
162171 stride_size = stride , anti_padding = antiPad )
163172 return dInputs
@@ -213,9 +222,12 @@ def help(self): ## component help function
213222 hyperparams = {
214223 "shape" : "Shape of synaptic filter value matrix; `kernel width` x `kernel height` "
215224 "x `number input channels` x `number output channels`" ,
225+ "x_shape" : "Shape of any single incoming/input feature map" ,
216226 "weight_init" : "Initialization conditions for synaptic filter (K) values" ,
217227 "bias_init" : "Initialization conditions for bias/base-rate (b) values" ,
218228 "resist_scale" : "Resistance level output scaling factor (R)" ,
229+ "stride" : "length / size of stride" ,
230+ "padding" : "pre-operator padding to use, i.e., `VALID` `SAME`" ,
219231 "A_plus" : "Strength of long-term potentiation (LTP)" ,
220232 "A_minus" : "Strength of long-term depression (LTD)" ,
221233 "eta" : "Global learning rate (multiplier beyond A_plus and A_minus)" ,
0 commit comments