From 65f6d450fd26f3f5b600c593f27ee0eeccdbc557 Mon Sep 17 00:00:00 2001 From: Hesam Salehipour Date: Thu, 2 Jan 2025 21:56:55 -0500 Subject: [PATCH] Moved omega from an attribute of the collision and stepper operations to an input of the methods. --- examples/cfd/flow_past_sphere_3d.py | 8 ++++-- examples/cfd/lid_driven_cavity_2d.py | 7 +++-- .../cfd/lid_driven_cavity_2d_distributed.py | 1 - examples/cfd/turbulent_channel_3d.py | 7 +++-- examples/cfd/windtunnel_3d.py | 7 +++-- examples/performance/mlups_3d.py | 7 +++-- .../collision/test_bgk_collision_jax.py | 4 +-- .../collision/test_bgk_collision_warp.py | 4 +-- xlb/operator/collision/bgk.py | 17 +++++------ xlb/operator/collision/collision.py | 8 ------ xlb/operator/collision/forced_collision.py | 18 ++++++------ xlb/operator/collision/kbc.py | 28 +++++++++++-------- .../equilibrium/quadratic_equilibrium.py | 2 +- xlb/operator/stepper/nse_stepper.py | 18 ++++++------ 14 files changed, 76 insertions(+), 60 deletions(-) diff --git a/examples/cfd/flow_past_sphere_3d.py b/examples/cfd/flow_past_sphere_3d.py index 755088ef..a955176f 100644 --- a/examples/cfd/flow_past_sphere_3d.py +++ b/examples/cfd/flow_past_sphere_3d.py @@ -31,6 +31,11 @@ def __init__(self, omega, grid_shape, velocity_set, backend, precision_policy): self.backend = backend self.precision_policy = precision_policy self.omega = omega + + # Ensure warp type is set correctly for Omega + if self.backend == ComputeBackend.WARP: + self.omega = wp.static(self.precision_policy.compute_precision.wp_dtype(self.omega)) + self.boundary_conditions = [] self.u_max = 0.04 @@ -75,7 +80,6 @@ def setup_boundary_conditions(self): def setup_stepper(self): self.stepper = IncompressibleNavierStokesStepper( - omega=self.omega, grid=self.grid, boundary_conditions=self.boundary_conditions, collision_type="BGK", @@ -127,7 +131,7 @@ def bc_profile_jax(): def run(self, num_steps, post_process_interval=100): start_time = time.time() for i in range(num_steps): - self.f_0, self.f_1 = self.stepper(self.f_0, self.f_1, self.bc_mask, self.missing_mask, i) + self.f_0, self.f_1 = self.stepper(self.f_0, self.f_1, self.bc_mask, self.missing_mask, self.omega, i) self.f_0, self.f_1 = self.f_1, self.f_0 if i % post_process_interval == 0 or i == num_steps - 1: diff --git a/examples/cfd/lid_driven_cavity_2d.py b/examples/cfd/lid_driven_cavity_2d.py index 17b59159..c167939a 100644 --- a/examples/cfd/lid_driven_cavity_2d.py +++ b/examples/cfd/lid_driven_cavity_2d.py @@ -29,6 +29,10 @@ def __init__(self, omega, prescribed_vel, grid_shape, velocity_set, backend, pre self.boundary_conditions = [] self.prescribed_vel = prescribed_vel + # Ensure warp type is set correctly for Omega + if self.backend == ComputeBackend.WARP: + self.omega = wp.static(self.precision_policy.compute_precision.wp_dtype(self.omega)) + # Create grid using factory self.grid = grid_factory(grid_shape, compute_backend=backend) @@ -57,7 +61,6 @@ def setup_boundary_conditions(self): def setup_stepper(self): self.stepper = IncompressibleNavierStokesStepper( - omega=self.omega, grid=self.grid, boundary_conditions=self.boundary_conditions, collision_type="BGK", @@ -65,7 +68,7 @@ def setup_stepper(self): def run(self, num_steps, post_process_interval=100): for i in range(num_steps): - self.f_0, self.f_1 = self.stepper(self.f_0, self.f_1, self.bc_mask, self.missing_mask, i) + self.f_0, self.f_1 = self.stepper(self.f_0, self.f_1, self.bc_mask, self.missing_mask, self.omega, i) self.f_0, self.f_1 = self.f_1, self.f_0 if i % post_process_interval == 0 or i == num_steps - 1: diff --git a/examples/cfd/lid_driven_cavity_2d_distributed.py b/examples/cfd/lid_driven_cavity_2d_distributed.py index 1018ec58..d06d314a 100644 --- a/examples/cfd/lid_driven_cavity_2d_distributed.py +++ b/examples/cfd/lid_driven_cavity_2d_distributed.py @@ -13,7 +13,6 @@ def __init__(self, omega, prescribed_vel, grid_shape, velocity_set, backend, pre def setup_stepper(self): # Create the base stepper stepper = IncompressibleNavierStokesStepper( - omega=self.omega, grid=self.grid, boundary_conditions=self.boundary_conditions, collision_type="BGK", diff --git a/examples/cfd/turbulent_channel_3d.py b/examples/cfd/turbulent_channel_3d.py index 5a6484f3..204f4012 100644 --- a/examples/cfd/turbulent_channel_3d.py +++ b/examples/cfd/turbulent_channel_3d.py @@ -54,6 +54,10 @@ def __init__(self, channel_half_width, Re_tau, u_tau, grid_shape, velocity_set, self.precision_policy = precision_policy self.boundary_conditions = [] + # Ensure warp type is set correctly for Omega + if self.backend == ComputeBackend.WARP: + self.omega = wp.static(self.precision_policy.compute_precision.wp_dtype(self.omega)) + # Create grid using factory self.grid = grid_factory(grid_shape, compute_backend=backend) @@ -98,7 +102,6 @@ def initialize_fields(self): def setup_stepper(self): self.stepper = IncompressibleNavierStokesStepper( - omega=self.omega, grid=self.grid, boundary_conditions=self.boundary_conditions, collision_type="KBC", @@ -108,7 +111,7 @@ def setup_stepper(self): def run(self, num_steps, print_interval, post_process_interval=100): start_time = time.time() for i in range(num_steps): - self.f_0, self.f_1 = self.stepper(self.f_0, self.f_1, self.bc_mask, self.missing_mask, i) + self.f_0, self.f_1 = self.stepper(self.f_0, self.f_1, self.bc_mask, self.missing_mask, self.omega, i) self.f_0, self.f_1 = self.f_1, self.f_0 if (i + 1) % print_interval == 0: diff --git a/examples/cfd/windtunnel_3d.py b/examples/cfd/windtunnel_3d.py index 133c4f3e..819b5dad 100644 --- a/examples/cfd/windtunnel_3d.py +++ b/examples/cfd/windtunnel_3d.py @@ -37,6 +37,10 @@ def __init__(self, omega, wind_speed, grid_shape, velocity_set, backend, precisi self.boundary_conditions = [] self.wind_speed = wind_speed + # Ensure warp type is set correctly for Omega + if self.backend == ComputeBackend.WARP: + self.omega = wp.static(self.precision_policy.compute_precision.wp_dtype(self.omega)) + # Create grid using factory self.grid = grid_factory(grid_shape, compute_backend=backend) @@ -98,7 +102,6 @@ def setup_boundary_conditions(self): def setup_stepper(self): self.stepper = IncompressibleNavierStokesStepper( - omega=self.omega, grid=self.grid, boundary_conditions=self.boundary_conditions, collision_type="KBC", @@ -111,7 +114,7 @@ def run(self, num_steps, print_interval, post_process_interval=100): start_time = time.time() for i in range(num_steps): - self.f_0, self.f_1 = self.stepper(self.f_0, self.f_1, self.bc_mask, self.missing_mask, i) + self.f_0, self.f_1 = self.stepper(self.f_0, self.f_1, self.bc_mask, self.missing_mask, self.omega, i) self.f_0, self.f_1 = self.f_1, self.f_0 if (i + 1) % print_interval == 0: diff --git a/examples/performance/mlups_3d.py b/examples/performance/mlups_3d.py index f22dd94d..8d864ae2 100644 --- a/examples/performance/mlups_3d.py +++ b/examples/performance/mlups_3d.py @@ -53,7 +53,7 @@ def run(backend, precision_policy, grid_shape, num_steps): boundary_conditions = [EquilibriumBC(rho=1.0, u=(0.02, 0.0, 0.0), indices=lid), FullwayBounceBackBC(indices=walls)] # Create stepper - stepper = IncompressibleNavierStokesStepper(omega=1.0, grid=grid, boundary_conditions=boundary_conditions, collision_type="BGK") + stepper = IncompressibleNavierStokesStepper(grid=grid, boundary_conditions=boundary_conditions, collision_type="BGK") # Distribute if using JAX backend if backend == ComputeBackend.JAX: @@ -64,11 +64,14 @@ def run(backend, precision_policy, grid_shape, num_steps): ) # Initialize fields and run simulation + omega = 1.0 + if backend == ComputeBackend.WARP: + omega = wp.static(precision_policy.compute_precision.wp_dtype(omega)) f_0, f_1, bc_mask, missing_mask = stepper.prepare_fields() start_time = time.time() for i in range(num_steps): - f_0, f_1 = stepper(f_0, f_1, bc_mask, missing_mask, i) + f_0, f_1 = stepper(f_0, f_1, bc_mask, missing_mask, omega, i) f_0, f_1 = f_1, f_0 wp.synchronize() diff --git a/tests/kernels/collision/test_bgk_collision_jax.py b/tests/kernels/collision/test_bgk_collision_jax.py index 1672cd5c..72c2ec99 100644 --- a/tests/kernels/collision/test_bgk_collision_jax.py +++ b/tests/kernels/collision/test_bgk_collision_jax.py @@ -41,11 +41,11 @@ def test_bgk_ollision(dim, velocity_set, grid_shape, omega): # Compute collision - compute_collision = BGK(omega=omega) + compute_collision = BGK() f_orig = my_grid.create_field(cardinality=DefaultConfig.velocity_set.q) - f_out = compute_collision(f_orig, f_eq, rho, u) + f_out = compute_collision(f_orig, f_eq, rho, u, omega) assert jnp.allclose(f_out, f_orig - omega * (f_orig - f_eq)) diff --git a/tests/kernels/collision/test_bgk_collision_warp.py b/tests/kernels/collision/test_bgk_collision_warp.py index 382e3684..3c8436c6 100644 --- a/tests/kernels/collision/test_bgk_collision_warp.py +++ b/tests/kernels/collision/test_bgk_collision_warp.py @@ -40,11 +40,11 @@ def test_bgk_collision_warp(dim, velocity_set, grid_shape, omega): f_eq = my_grid.create_field(cardinality=DefaultConfig.velocity_set.q) f_eq = compute_macro(rho, u, f_eq) - compute_collision = BGK(omega=omega) + compute_collision = BGK() f_orig = my_grid.create_field(cardinality=DefaultConfig.velocity_set.q) f_out = my_grid.create_field(cardinality=DefaultConfig.velocity_set.q) - f_out = compute_collision(f_orig, f_eq, f_out, rho, u) + f_out = compute_collision(f_orig, f_eq, f_out, rho, u, omega) f_eq = f_eq.numpy() f_out = f_out.numpy() diff --git a/xlb/operator/collision/bgk.py b/xlb/operator/collision/bgk.py index 115ed9a5..479fac25 100644 --- a/xlb/operator/collision/bgk.py +++ b/xlb/operator/collision/bgk.py @@ -15,23 +15,22 @@ class BGK(Collision): """ @Operator.register_backend(ComputeBackend.JAX) - @partial(jit, static_argnums=(0,)) - def jax_implementation(self, f: jnp.ndarray, feq: jnp.ndarray, rho, u): + @partial(jit, static_argnums=(0, 5)) + def jax_implementation(self, f: jnp.ndarray, feq: jnp.ndarray, rho, u, omega): fneq = f - feq - fout = f - self.compute_dtype(self.omega) * fneq + fout = f - self.compute_dtype(omega) * fneq return fout def _construct_warp(self): # Set local constants TODO: This is a hack and should be fixed with warp update _w = self.velocity_set.w - _omega = wp.constant(self.compute_dtype(self.omega)) _f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype) # Construct the functional @wp.func - def functional(f: Any, feq: Any, rho: Any, u: Any): + def functional(f: Any, feq: Any, rho: Any, u: Any, omega: Any): fneq = f - feq - fout = f - _omega * fneq + fout = f - omega * fneq return fout # Construct the warp kernel @@ -42,6 +41,7 @@ def kernel( fout: wp.array4d(dtype=Any), rho: wp.array4d(dtype=Any), u: wp.array4d(dtype=Any), + omega: Any, ): # Get the global index i, j, k = wp.tid() @@ -55,7 +55,7 @@ def kernel( _feq[l] = feq[l, index[0], index[1], index[2]] # Compute the collision - _fout = functional(_f, _feq, rho, u) + _fout = functional(_f, _feq, rho, u, omega) # Write the result for l in range(self.velocity_set.q): @@ -64,7 +64,7 @@ def kernel( return functional, kernel @Operator.register_backend(ComputeBackend.WARP) - def warp_implementation(self, f, feq, fout, rho, u): + def warp_implementation(self, f, feq, fout, rho, u, omega): # Launch the warp kernel wp.launch( self.warp_kernel, @@ -74,6 +74,7 @@ def warp_implementation(self, f, feq, fout, rho, u): fout, rho, u, + omega, ], dim=f.shape[1:], ) diff --git a/xlb/operator/collision/collision.py b/xlb/operator/collision/collision.py index 00a8dfd4..a9ded59e 100644 --- a/xlb/operator/collision/collision.py +++ b/xlb/operator/collision/collision.py @@ -12,20 +12,12 @@ class Collision(Operator): This class defines the collision step for the Lattice Boltzmann Method. - Parameters - ---------- - omega : float - Relaxation parameter for collision step. Default value is 0.6. - shear : bool - Flag to indicate whether the collision step requires the shear stress. """ def __init__( self, - omega: float, velocity_set: VelocitySet = None, precision_policy=None, compute_backend=None, ): - self.omega = omega super().__init__(velocity_set, precision_policy, compute_backend) diff --git a/xlb/operator/collision/forced_collision.py b/xlb/operator/collision/forced_collision.py index 2036bab3..c209283f 100644 --- a/xlb/operator/collision/forced_collision.py +++ b/xlb/operator/collision/forced_collision.py @@ -23,7 +23,7 @@ def __init__( ): assert collision_operator is not None self.collision_operator = collision_operator - super().__init__(self.collision_operator.omega) + super().__init__() assert forcing_scheme == "exact_difference", NotImplementedError(f"Force model {forcing_scheme} not implemented!") assert force_vector.shape[0] == self.velocity_set.d, "Check the dimensions of the input force!" @@ -32,9 +32,9 @@ def __init__( self.forcing_operator = ExactDifference(force_vector) @Operator.register_backend(ComputeBackend.JAX) - @partial(jit, static_argnums=(0,)) - def jax_implementation(self, f: jnp.ndarray, feq: jnp.ndarray, rho, u): - fout = self.collision_operator(f, feq, rho, u) + @partial(jit, static_argnums=(0, 5)) + def jax_implementation(self, f: jnp.ndarray, feq: jnp.ndarray, rho, u, omega): + fout = self.collision_operator(f, feq, rho, u, omega) fout = self.forcing_operator(fout, feq, rho, u) return fout @@ -45,8 +45,8 @@ def _construct_warp(self): # Construct the functional @wp.func - def functional(f: Any, feq: Any, rho: Any, u: Any): - fout = self.collision_operator.warp_functional(f, feq, rho, u) + def functional(f: Any, feq: Any, rho: Any, u: Any, omega: Any): + fout = self.collision_operator.warp_functional(f, feq, rho, u, omega) fout = self.forcing_operator.warp_functional(fout, feq, rho, u) return fout @@ -58,6 +58,7 @@ def kernel( fout: wp.array4d(dtype=Any), rho: wp.array4d(dtype=Any), u: wp.array4d(dtype=Any), + omega: Any, ): # Get the global index i, j, k = wp.tid() @@ -76,7 +77,7 @@ def kernel( _rho = rho[0, index[0], index[1], index[2]] # Compute the collision - _fout = functional(_f, _feq, _rho, _u) + _fout = functional(_f, _feq, _rho, _u, omega) # Write the result for l in range(self.velocity_set.q): @@ -85,7 +86,7 @@ def kernel( return functional, kernel @Operator.register_backend(ComputeBackend.WARP) - def warp_implementation(self, f, feq, fout, rho, u): + def warp_implementation(self, f, feq, fout, rho, u, omega): # Launch the warp kernel wp.launch( self.warp_kernel, @@ -95,6 +96,7 @@ def warp_implementation(self, f, feq, fout, rho, u): fout, rho, u, + omega, ], dim=f.shape[1:], ) diff --git a/xlb/operator/collision/kbc.py b/xlb/operator/collision/kbc.py index b841baa5..e84c935e 100644 --- a/xlb/operator/collision/kbc.py +++ b/xlb/operator/collision/kbc.py @@ -24,31 +24,28 @@ class KBC(Collision): def __init__( self, - omega: float, velocity_set: VelocitySet = None, precision_policy=None, compute_backend=None, ): self.momentum_flux = MomentumFlux() self.epsilon = 1e-32 - self.beta = omega * 0.5 - self.inv_beta = 1.0 / self.beta super().__init__( - omega=omega, velocity_set=velocity_set, precision_policy=precision_policy, compute_backend=compute_backend, ) @Operator.register_backend(ComputeBackend.JAX) - @partial(jit, static_argnums=(0,), donate_argnums=(1, 2, 3)) + @partial(jit, static_argnums=(0, 5), donate_argnums=(1, 2, 3)) def jax_implementation( self, f: jnp.ndarray, feq: jnp.ndarray, rho: jnp.ndarray, u: jnp.ndarray, + omega, ): """ KBC collision step for lattice. @@ -74,13 +71,17 @@ def jax_implementation( else: raise NotImplementedError("Velocity set not supported: {}".format(type(self.velocity_set))) + # Compute required constants based on the input omega (omega is the inverse relaxation time) + beta = omega * 0.5 + inv_beta = 1.0 / beta + # Perform collision delta_h = fneq - delta_s - gamma = self.inv_beta - (2.0 - self.inv_beta) * self.entropic_scalar_product(delta_s, delta_h, feq) / ( + gamma = inv_beta - (2.0 - inv_beta) * self.entropic_scalar_product(delta_s, delta_h, feq) / ( self.epsilon + self.entropic_scalar_product(delta_h, delta_h, feq) ) - fout = f - self.beta * (2.0 * delta_s + gamma[None, ...] * delta_h) + fout = f - beta * (2.0 * delta_s + gamma[None, ...] * delta_h) return fout @@ -185,8 +186,6 @@ def _construct_warp(self): _u_vec = wp.vec(self.velocity_set.d, dtype=self.compute_dtype) _f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype) _epsilon = wp.constant(self.compute_dtype(self.epsilon)) - _beta = wp.constant(self.compute_dtype(self.beta)) - _inv_beta = wp.constant(self.compute_dtype(1.0 / self.beta)) @wp.func def decompose_shear_d2q9(fneq: Any): @@ -268,6 +267,7 @@ def functional( feq: Any, rho: Any, u: Any, + omega: Any, ): # Compute shear and delta_s fneq = f - feq @@ -278,6 +278,10 @@ def functional( shear = decompose_shear_d2q9(fneq) delta_s = shear * rho / self.compute_dtype(4.0) + # Compute required constants based on the input omega (omega is the inverse relaxation time) + _beta = omega * self.compute_dtype(0.5) + _inv_beta = self.compute_dtype(1.0) / _beta + # Perform collision delta_h = fneq - delta_s two = self.compute_dtype(2.0) @@ -296,6 +300,7 @@ def kernel( fout: wp.array4d(dtype=Any), rho: wp.array4d(dtype=Any), u: wp.array4d(dtype=Any), + omega: Any, ): # Get the global index i, j, k = wp.tid() @@ -314,7 +319,7 @@ def kernel( _rho = rho[0, index[0], index[1], index[2]] # Compute the collision - _fout = functional(_f, _feq, _rho, _u) + _fout = functional(_f, _feq, _rho, _u, omega) # Write the result for l in range(self.velocity_set.q): @@ -323,7 +328,7 @@ def kernel( return functional, kernel @Operator.register_backend(ComputeBackend.WARP) - def warp_implementation(self, f, feq, fout, rho, u): + def warp_implementation(self, f, feq, fout, rho, u, omega): # Launch the warp kernel wp.launch( self.warp_kernel, @@ -333,6 +338,7 @@ def warp_implementation(self, f, feq, fout, rho, u): fout, rho, u, + omega, ], dim=f.shape[1:], ) diff --git a/xlb/operator/equilibrium/quadratic_equilibrium.py b/xlb/operator/equilibrium/quadratic_equilibrium.py index 133d30c6..62cc0414 100644 --- a/xlb/operator/equilibrium/quadratic_equilibrium.py +++ b/xlb/operator/equilibrium/quadratic_equilibrium.py @@ -20,7 +20,7 @@ class QuadraticEquilibrium(Equilibrium): def jax_implementation(self, rho, u): cu = 3.0 * jnp.tensordot(self.velocity_set.c, u, axes=(0, 0)) usqr = 1.5 * jnp.sum(jnp.square(u), axis=0, keepdims=True) - w = self.velocity_set.w.reshape((-1,) + (1,) * self.velocity_set.d) + w = self.velocity_set.w.reshape((-1,) + (1,) * (len(rho.shape) - 1)) feq = rho * w * (1.0 + cu * (1.0 + 0.5 * cu) - usqr) return feq diff --git a/xlb/operator/stepper/nse_stepper.py b/xlb/operator/stepper/nse_stepper.py index 3db20cc8..8cd80fe4 100644 --- a/xlb/operator/stepper/nse_stepper.py +++ b/xlb/operator/stepper/nse_stepper.py @@ -25,7 +25,6 @@ class IncompressibleNavierStokesStepper(Stepper): def __init__( self, - omega, grid, boundary_conditions=[], collision_type="BGK", @@ -36,9 +35,9 @@ def __init__( # Construct the collision operator if collision_type == "BGK": - self.collision = BGK(omega, self.velocity_set, self.precision_policy, self.compute_backend) + self.collision = BGK(self.velocity_set, self.precision_policy, self.compute_backend) elif collision_type == "KBC": - self.collision = KBC(omega, self.velocity_set, self.precision_policy, self.compute_backend) + self.collision = KBC(self.velocity_set, self.precision_policy, self.compute_backend) if force_vector is not None: self.collision = ForcedCollision(collision_operator=self.collision, forcing_scheme=forcing_scheme, force_vector=force_vector) @@ -128,8 +127,8 @@ def _initialize_auxiliary_data(boundary_conditions, f_0, f_1, bc_mask, missing_m return f_0, f_1 @Operator.register_backend(ComputeBackend.JAX) - @partial(jit, static_argnums=(0)) - def jax_implementation(self, f_0, f_1, bc_mask, missing_mask, timestep): + @partial(jit, static_argnums=(0, 5)) + def jax_implementation(self, f_0, f_1, bc_mask, missing_mask, omega, timestep): """ Perform a single step of the lattice boltzmann method """ @@ -157,7 +156,7 @@ def jax_implementation(self, f_0, f_1, bc_mask, missing_mask, timestep): feq = self.equilibrium(rho, u) # Apply collision - f_post_collision = self.collision(f_post_stream, feq, rho, u) + f_post_collision = self.collision(f_post_stream, feq, rho, u, omega) # Apply collision type boundary conditions for bc in self.boundary_conditions: @@ -280,6 +279,7 @@ def kernel( f_1: wp.array4d(dtype=Any), bc_mask: wp.array4d(dtype=Any), missing_mask: wp.array4d(dtype=Any), + omega: Any, timestep: int, ): i, j, k = wp.tid() @@ -300,7 +300,7 @@ def kernel( _rho, _u = self.macroscopic.warp_functional(_f_post_stream) _feq = self.equilibrium.warp_functional(_rho, _u) - _f_post_collision = self.collision.warp_functional(_f_post_stream, _feq, _rho, _u) + _f_post_collision = self.collision.warp_functional(_f_post_stream, _feq, _rho, _u, omega) # Apply post-collision boundary conditions _f_post_collision = apply_bc(index, timestep, _boundary_id, _missing_mask, f_0, f_1, _f_post_stream, _f_post_collision, False) @@ -315,10 +315,10 @@ def kernel( return None, kernel @Operator.register_backend(ComputeBackend.WARP) - def warp_implementation(self, f_0, f_1, bc_mask, missing_mask, timestep): + def warp_implementation(self, f_0, f_1, bc_mask, missing_mask, omega, timestep): wp.launch( self.warp_kernel, - inputs=[f_0, f_1, bc_mask, missing_mask, timestep], + inputs=[f_0, f_1, bc_mask, missing_mask, omega, timestep], dim=f_0.shape[1:], ) return f_0, f_1