Skip to content

Commit acebcc9

Browse files
Full support for ShardTensor and Stormscope utils in physicsnemo.diffusion (#1547)
* Support FSDP/ShardTensor in diffusion * Full ShardTensor, sigma_data, sigma bin support * per-channel sigma in preconditioners, domain parallel unit tests * Minor cleanup * update diffusion import rules * dict casting * revert dict handling * condition check * fix typo * Promote domain parallel wrapper to NoiseScheduler * Address feedback * loguniform doctest fix * Fix minor bugs * Fix formatting and clarity in README.md Corrected capitalization in section headers and improved clarity in several sentences throughout the README. * Small fixes --------- Co-authored-by: megnvidia <mmiranda@nvidia.com>
1 parent ab8bc24 commit acebcc9

29 files changed

+2479
-1063
lines changed

docs/api/diffusion/noise_schedulers.rst

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,11 +99,13 @@ available at three levels:
9999
- the inverse mapping :math:`\sigma^{-1}(\sigma) = t` from noise level back to time
100100
- the discretization of the diffusion time grid
101101

102-
- **Ready-to-use schedules**: Five concrete implementations that work out of
102+
- **Ready-to-use schedules**: Multiple concrete implementations that work out of
103103
the box:
104104

105105
- :class:`EDMNoiseScheduler` --- :math:`\alpha(t)=1`,
106106
:math:`\sigma(t)=t`. The recommended default for most applications.
107+
- :class:`EDMLogUniformNoiseScheduler` --- EDM variant that samples
108+
training times uniformly in log-space instead of from a log-normal.
107109
- :class:`VENoiseScheduler` --- Variance Exploding schedule.
108110
- :class:`VPNoiseScheduler` --- Variance Preserving schedule.
109111
- :class:`IDDPMNoiseScheduler` --- Improved DDPM schedule.
@@ -137,6 +139,14 @@ API Reference
137139
:members:
138140
:exclude-members: __init__
139141

142+
:code:`EDMLogUniformNoiseScheduler`
143+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
144+
145+
.. autoclass:: physicsnemo.diffusion.noise_schedulers.EDMLogUniformNoiseScheduler
146+
:show-inheritance:
147+
:members:
148+
:exclude-members: __init__
149+
140150
:code:`VENoiseScheduler`
141151
~~~~~~~~~~~~~~~~~~~~~~~~
142152

@@ -168,3 +178,10 @@ API Reference
168178
:show-inheritance:
169179
:members:
170180
:exclude-members: __init__
181+
182+
:code:`DomainParallelNoiseScheduler`
183+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
184+
185+
.. autoclass:: physicsnemo.diffusion.noise_schedulers.DomainParallelNoiseScheduler
186+
:members:
187+
:exclude-members: __init__

examples/weather/stormcast/README.md

Lines changed: 85 additions & 57 deletions
Large diffs are not rendered by default.

examples/weather/stormcast/config/inference/stormcast.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ experiment_name: 'stormcast-inference' # Name for the inference experiment being
2020
run_id: 0 # Unique identifier for the inference run
2121
rundir: ./${inference.outdir}/${inference.experiment_name}/${inference.run_id} # Path where experiment outputs will be saved
2222
regression_checkpoint: stormcast_checkpoints/regression/StormCastUNet.0.0.mdlus
23-
diffusion_checkpoint: stormcast_checkpoints/diffusion/EDMPrecond.0.0.mdlus
23+
diffusion_checkpoint: stormcast_checkpoints/diffusion/EDMPreconditioner.0.0.mdlus
2424

2525
# Initial and lead times
2626
initial_time: "2022-11-04T21:00:00" # datetime to initialize forecast with (YYYY-MM-DDTHH:MM:SS)

examples/weather/stormcast/config/sampler/edm_deterministic.yaml

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,9 @@
1414
# See the License for the specific language governing permissions and
1515
# limitations under the License.
1616

17-
# Sampler args
18-
# Below are passed as kwargs to physicsnemo.utils.diffusion.determinisitic_sampler
19-
# Also supports stochastic sampling via S_churn and related args.
17+
# Sampler args for physicsnemo.diffusion.samplers.sample
18+
# Uses EDMNoiseScheduler for timestep generation and Heun solver for ODE integration.
19+
# Set S_churn > 0 for stochastic sampling (EDMStochasticHeunSolver).
2020
# See EDM paper for details (https://arxiv.org/abs/2206.00364)
2121

2222
name: 'EDM Deterministic'
@@ -25,6 +25,7 @@ args:
2525
sigma_min: 0.002
2626
sigma_max: 800
2727
rho: 7
28+
solver: heun # "heun" (2nd order) or "euler" (1st order, faster)
2829
S_churn: 0.
2930
S_min: 0.
3031
S_max: .inf

examples/weather/stormcast/config/training/default.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ force_sharding: False
4747
# Performance and optimization (like corrdiff perf section)
4848
perf:
4949
fp_optimizations: fp32 # Floating point mode: "fp32", "amp-fp16", "amp-bf16"
50-
torch_compile: False # Use torch.compile to compile model
50+
torch_compile: False # torch.compile the training loss forward (skipped with domain parallelism)
5151
use_apex_gn: False # Use Apex GroupNorm (enables channels_last memory format)
5252
allow_tf32: False # Allow TF32 for matmul and cuDNN (faster but less precise)
5353
allow_fp16_reduced_precision: False # Allow reduced precision reductions in fp16

examples/weather/stormcast/inference.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323
from omegaconf import DictConfig
2424
from physicsnemo.core import Module
2525

26+
from physicsnemo.diffusion.noise_schedulers import EDMNoiseScheduler
27+
2628
from datasets import dataset_classes
2729
from utils.io import (
2830
init_inference_results_zarr,
@@ -81,6 +83,13 @@ def main(cfg: DictConfig):
8183
net = Module.from_checkpoint(cfg.inference.diffusion_checkpoint)
8284
diffusion_model = net.to(device)
8385

86+
sa = dict(cfg.sampler.args)
87+
sampling_scheduler = EDMNoiseScheduler(
88+
sigma_min=sa.get("sigma_min", 0.002),
89+
sigma_max=sa.get("sigma_max", 80.0),
90+
rho=sa.get("rho", 7.0),
91+
)
92+
8493
# initialize zarr
8594
(
8695
group,
@@ -150,7 +159,8 @@ def main(cfg: DictConfig):
150159
diffusion_model,
151160
condition,
152161
state_pred.shape,
153-
sampler_args=dict(cfg.sampler.args),
162+
scheduler=sampling_scheduler,
163+
sampler_args=sa,
154164
lead_time_label=lead_time_label,
155165
)
156166

examples/weather/stormcast/test_training.py

Lines changed: 24 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -162,8 +162,9 @@ def test_training(
162162
if dist.world_size > 1:
163163
torch.distributed.barrier()
164164

165-
net_cls = "EDMPrecond" if net_architecture == "unet" else "EDMPreconditioner"
166-
ckpt_path = os.path.join(rundir, "checkpoints_diffusion", f"{net_cls}.0.10.mdlus")
165+
ckpt_path = os.path.join(
166+
rundir, "checkpoints_diffusion", "EDMPreconditioner.0.10.mdlus"
167+
)
167168
assert os.path.isfile(ckpt_path), "Diffusion checkpoint not found"
168169

169170

@@ -232,8 +233,9 @@ def test_checkpointing(
232233
if num_procs > 1:
233234
torch.distributed.barrier()
234235

235-
net_cls = "EDMPrecond" if net_architecture == "unet" else "EDMPreconditioner"
236-
ckpt_path = os.path.join(rundir, "checkpoints_diffusion", f"{net_cls}.0.20.mdlus")
236+
ckpt_path = os.path.join(
237+
rundir, "checkpoints_diffusion", "EDMPreconditioner.0.20.mdlus"
238+
)
237239
assert os.path.isfile(ckpt_path), (
238240
f"Diffusion checkpoint not found on rank {dist.rank}"
239241
)
@@ -359,7 +361,7 @@ def test_seeding(
359361
360362
- Domain (model-parallel) groups are {0, 1} and {2, 3}.
361363
Ranks within the same domain group must see **identical** sigma
362-
(enforced by ``replicate_in_mesh`` broadcast).
364+
(enforced by ``DomainParallelNoiseScheduler`` broadcast).
363365
- DDP (data-parallel) groups are {0, 2} and {1, 3}.
364366
Ranks in different DDP groups must see **different** sigma
365367
(they process different data and have distinct RNG seeds).
@@ -391,17 +393,26 @@ def test_seeding(
391393

392394
t = trainer.Trainer(cfg)
393395

394-
# -- instrument the loss to capture post-broadcast sigma values ----------
396+
# -- instrument the loss to capture sigma values -------------------------
397+
from physicsnemo.diffusion.noise_schedulers import DomainParallelNoiseScheduler
398+
399+
scheduler = t.train_noise_scheduler
400+
if domain_parallel_size > 1 and not isinstance(
401+
scheduler, DomainParallelNoiseScheduler
402+
):
403+
raise ValueError(
404+
"test_seeding requires a DomainParallelNoiseScheduler on the "
405+
"loss when domain_parallel_size > 1"
406+
)
395407
captured_sigmas: list[torch.Tensor] = []
396-
_orig_replicate = t.loss_fn.replicate_in_mesh
408+
_orig_sample_time = scheduler.sample_time
397409

398-
def _capturing_replicate(x, y):
399-
result = _orig_replicate(x, y)
400-
local = result.to_local() if hasattr(result, "to_local") else result
401-
captured_sigmas.append(local.detach().cpu())
410+
def _capturing_sample_time(*args, **kwargs):
411+
result = _orig_sample_time(*args, **kwargs)
412+
captured_sigmas.append(result.detach().cpu())
402413
return result
403414

404-
t.loss_fn.replicate_in_mesh = _capturing_replicate
415+
scheduler.sample_time = _capturing_sample_time
405416

406417
# -- helper: gather sigmas and assert the expected pattern ---------------
407418
def _check_sigma_pattern(label: str) -> None:
@@ -520,8 +531,7 @@ def test_model_types(
520531
if dist.world_size > 1:
521532
torch.distributed.barrier()
522533

523-
net_cls = "EDMPrecond" if net_architecture == "unet" else "EDMPreconditioner"
524534
ckpt_path = os.path.join(
525-
rundir, "checkpoints_diffusion", f"{net_cls}.0.10.mdlus"
535+
rundir, "checkpoints_diffusion", "EDMPreconditioner.0.10.mdlus"
526536
)
527537
assert os.path.isfile(ckpt_path), "Diffusion checkpoint not found"

examples/weather/stormcast/utils/config.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,9 @@ class PerfConfig:
5454
fp_optimizations: Literal["fp32", "amp-fp16", "amp-bf16"] = (
5555
"fp32" # Floating point mode: "fp32", "amp-fp16", "amp-bf16"
5656
)
57-
torch_compile: bool = False # Use torch.compile to compile model
57+
torch_compile: bool = (
58+
False # torch.compile training loss forward (skipped with domain parallelism)
59+
)
5860
use_apex_gn: bool = (
5961
False # Use Apex GroupNorm (enables channels_last memory format)
6062
)

0 commit comments

Comments
 (0)