Skip to content

Commit 9faaeec

Browse files
committed
fix run on cpu bug
1 parent bb88ce4 commit 9faaeec

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

pvnet/app.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
from pvnet.data.datamodule import batch_to_tensor
4141
from pvnet.models.base_model import BaseModel
4242

43+
4344
# ---------------------------------------------------------------------------
4445
# GLOBAL SETTINGS
4546

@@ -49,8 +50,7 @@
4950
data_config_filename = f"{this_dir}/../configs/datamodule/configuration/app_configuration.yaml"
5051

5152
# Model will use GPU if available
52-
if torch.cuda.is_available():
53-
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
53+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
5454

5555
# Use multiple workers for data loading
5656
num_workers = min(os.cpu_count() - 1, 16)
@@ -161,7 +161,7 @@ def convert_df_to_forecasts(
161161

162162
def app(t0=None, apply_adjuster=False, gsp_ids=gsp_ids):
163163
"""Inference function for production
164-
164+
165165
This app expects these evironmental variables to be available:
166166
- DB_URL
167167
- NWP_ZARR_PATH
@@ -202,6 +202,7 @@ def app(t0=None, apply_adjuster=False, gsp_ids=gsp_ids):
202202
logger.info("Downloading zipped satellite data")
203203
fs = fsspec.open(os.environ["SATELLITE_ZARR_PATH"]).fs
204204
fs.get(os.environ["SATELLITE_ZARR_PATH"], "latest.zarr.zip")
205+
205206

206207
# ---------------------------------------------------------------------------
207208
# 2. Set up data loader

0 commit comments

Comments
 (0)