Skip to content

Commit

Permalink
condition strategy
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidtronix committed Dec 16, 2024
1 parent b146aa6 commit b7d5afb
Showing 1 changed file with 16 additions and 9 deletions.
25 changes: 16 additions & 9 deletions ml4h/models/diffusion_blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@

from tensorflow import keras
from keras import layers
from torch.nn.quantized.functional import upsample

from ml4h.defines import IMAGE_EXT
from ml4h.models.Block import Block
Expand Down Expand Up @@ -107,8 +106,12 @@ def condition_layer_film(input_tensor, control_vector, filters):
beta = layers.Dense(filters, activation="linear")(control_vector)

# Reshape gamma and beta to match the spatial dimensions
#gamma = tf.reshape(gamma, (-1,) + input_tensor.shape[1:-1] + (filters,))
#beta = tf.reshape(beta, (-1,) + input_tensor.shape[1:-1] + (filters,))
if 4 == len(input_tensor.shape):
gamma = tf.reshape(gamma, (-1, 1, 1, filters))
beta = tf.reshape(beta, (-1, 1, 1, filters))
elif 3 == len(input_tensor.shape):
gamma = tf.reshape(gamma, (-1, 1, filters))
beta = tf.reshape(beta, (-1, 1, filters))
# Apply FiLM (Feature-wise Linear Modulation)
return input_tensor * gamma + beta

Expand Down Expand Up @@ -157,12 +160,10 @@ def up_block_control(width, block_depth, conv, upsample, kernel_size, attention_
def apply(x):
x, skips, control = x
# x = upsample(size=2, interpolation="bilinear")(x)

#control = upsample(size=2)(control)
x = upsample(size=2)(x)
for _ in range(block_depth):
x = layers.Concatenate()([x, skips.pop()])
x = residual_block_control(width, conv, kernel_size, attention_heads, condition_strategy)([x, control])
x = upsample(size=2)(x)
return x

return apply
Expand Down Expand Up @@ -193,7 +194,9 @@ def get_control_network(input_shape, widths, block_depth, kernel_size, control_s
skips = []
for i, width in enumerate(widths[:-1]):
if attention_modulo > 1 and (i + 1) % attention_modulo == 0:
if len(input_shape) > 2:
if condition_strategy == 'film':
c2 = control
elif len(input_shape) > 2:
c2 = upsample(size=x.shape[1:-1])(control[control_idxs])
else:
c2 = upsample(size=x.shape[-2])(control[control_idxs])
Expand All @@ -202,7 +205,9 @@ def get_control_network(input_shape, widths, block_depth, kernel_size, control_s
else:
x = down_block(width, block_depth, conv, pool, kernel_size)([x, skips])

if len(input_shape) > 2:
if condition_strategy == 'film':
c2 = control
elif len(input_shape) > 2:
c2 = upsample(size=x.shape[1:-1])(control[control_idxs])
else:
c2 = upsample(size=x.shape[-2])(control[control_idxs])
Expand All @@ -212,7 +217,9 @@ def get_control_network(input_shape, widths, block_depth, kernel_size, control_s

for i, width in enumerate(reversed(widths[:-1])):
if attention_modulo > 1 and i % attention_modulo == 0:
if len(input_shape) > 2:
if condition_strategy == 'film':
c2 = control
elif len(input_shape) > 2:
c2 = upsample(size=x.shape[1:-1])(control[control_idxs])
else:
c2 = upsample(size=x.shape[-2])(control[control_idxs])
Expand Down

0 comments on commit b7d5afb

Please sign in to comment.