From da62789aa07f5cda20c874671b92945d80c24ff1 Mon Sep 17 00:00:00 2001 From: esavary Date: Fri, 5 Jan 2018 17:00:01 +0100 Subject: [PATCH] Add files via upload --- WGAN-GP/src/utils/layers.py | 29 +++++++++++++++-------------- 1 file changed, 15 insertions(+), 14 deletions(-) diff --git a/WGAN-GP/src/utils/layers.py b/WGAN-GP/src/utils/layers.py index 9fef396..68f102a 100644 --- a/WGAN-GP/src/utils/layers.py +++ b/WGAN-GP/src/utils/layers.py @@ -16,26 +16,27 @@ def linear(x, n_out, bias=True, name="linear"): with tf.variable_scope(name): - n_in = x.shape[-1] + #n_in = x.shape[-1] + n_in = x.get_shape()[-1] - # Initialize w - w_init_std = np.sqrt(1.0 / n_out) - w_init = tf.truncated_normal_initializer(0.0, w_init_std) - w = tf.get_variable('w', shape=[n_in,n_out], initializer=w_init) + # Initialize w + w_init_std = np.sqrt(1.0 / n_out) + w_init = tf.truncated_normal_initializer(0.0, w_init_std) + w = tf.get_variable('w', shape=[n_in,n_out], initializer=w_init) - # Dense mutliplication - x = tf.matmul(x, w) + # Dense mutliplication + x = tf.matmul(x, w) - if bias: + if bias: - # Initialize b - b_init = tf.constant_initializer(0.0) - b = tf.get_variable('b', shape=(n_out,), initializer=b_init) + # Initialize b + b_init = tf.constant_initializer(0.0) + b = tf.get_variable('b', shape=(n_out,), initializer=b_init) - # Add b - x = x + b + # Add b + x = x + b - return x + return x def phase_shift(x, upsampling_factor=2, data_format="NCHW", name="PhaseShift"):