Skip to content

Commit 1dbd6a2

Browse files
authored
Fixes for broken examples (#202)
* Try to fix gaas-map * Added retrying logic for hamiltonian example * Fix linters * Try to fix pet-mad-uq
1 parent 886be4a commit 1dbd6a2

File tree

3 files changed

+54
-17
lines changed

3 files changed

+54
-17
lines changed

examples/gaas-map/gaas-map.py

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
from sklearn.linear_model import RidgeCV
3636
from skmatter.decomposition import PCovR
3737
from skmatter.preprocessing import StandardFlexibleScaler
38+
from urllib3.util.retry import Retry
3839

3940

4041
# %%
@@ -45,13 +46,31 @@
4546
# train a ML potential to describe the full composition range.
4647
#
4748

48-
filename = "gaas_training.xyz"
49-
if not os.path.exists(filename):
50-
url = f"https://zenodo.org/records/10566825/files/{filename}"
51-
response = requests.get(url)
49+
50+
def fetch_dataset(filename, base_url, local_path=""):
51+
"""Helper function to load the pre-computed examples"""
52+
53+
local_file = local_path + filename
54+
if os.path.isfile(local_file):
55+
return
56+
57+
# Retry strategy: wait 1s, 2s, 4s, 8s, 16s on 429/5xx errors
58+
retry_strategy = Retry(
59+
total=5, backoff_factor=1, status_forcelist=[429, 500, 502, 503, 504]
60+
)
61+
session = requests.Session()
62+
session.mount("https://", requests.adapters.HTTPAdapter(max_retries=retry_strategy))
63+
64+
# Fetch with automatic retry and error raising
65+
response = session.get(base_url + filename)
5266
response.raise_for_status()
53-
with open(filename, "wb") as f:
54-
f.write(response.content)
67+
68+
with open(local_file, "wb") as file:
69+
file.write(response.content)
70+
71+
72+
filename = "gaas_training.xyz"
73+
fetch_dataset(filename, "https://zenodo.org/records/10566825/files/")
5574

5675
structures = ase.io.read(filename, ":")
5776
energy = np.array([f.info["energy"] for f in structures])

examples/hamiltonian-qm7/hamiltonian-qm7.py

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@
8787
from mlelec.features.acdc import compute_features_for_target
8888
from mlelec.targets import drop_zero_blocks # noqa: F401
8989
from mlelec.utils.plot_utils import plot_losses
90+
from urllib3.util.retry import Retry
9091

9192

9293
os.environ["PYSCFAD_BACKEND"] = "torch"
@@ -188,15 +189,33 @@ def save_parameters(file_path, **params):
188189
# We first download the data for the two examples from Zenodo
189190
# and unzip the downloaded datafile.
190191

191-
if not os.path.exists("hamiltonian-qm7-data"):
192-
url = r"https://zenodo.org/records/15524259/files/hamiltonian-qm7-data.zip"
193-
response = requests.get(url)
192+
193+
def fetch_dataset(filename, base_url, local_path=""):
194+
"""Helper function to load the pre-computed examples"""
195+
196+
local_file = local_path + filename
197+
if os.path.isfile(local_file):
198+
return
199+
200+
# Retry strategy: wait 1s, 2s, 4s, 8s, 16s on 429/5xx errors
201+
retry_strategy = Retry(
202+
total=5, backoff_factor=1, status_forcelist=[429, 500, 502, 503, 504]
203+
)
204+
session = requests.Session()
205+
session.mount("https://", requests.adapters.HTTPAdapter(max_retries=retry_strategy))
206+
207+
# Fetch with automatic retry and error raising
208+
response = session.get(base_url + filename)
194209
response.raise_for_status()
195-
with open("hamiltonian-qm7-data.zip", "wb") as f:
196-
f.write(response.content)
197210

198-
with ZipFile("hamiltonian-qm7-data.zip", "r") as zObject:
199-
zObject.extractall(path=".")
211+
with open(local_file, "wb") as file:
212+
file.write(response.content)
213+
214+
215+
fetch_dataset("hamiltonian-qm7-data.zip", "https://zenodo.org/records/15524259/files/")
216+
217+
with ZipFile("hamiltonian-qm7-data.zip", "r") as zObject:
218+
zObject.extractall(path=".")
200219

201220
# %%
202221
# Prepare the Dataset for ML Training

examples/pet-mad-uq/pet-mad-uq.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -87,8 +87,7 @@
8787
import subprocess
8888
from urllib.request import urlretrieve
8989

90-
import ase.cell
91-
import ase.ga.utilities
90+
import ase.geometry.rdf
9291
import matplotlib.pyplot as plt
9392
import numpy as np
9493
import pandas as pd
@@ -363,7 +362,7 @@
363362
for atoms in frames:
364363

365364
# Compute H-H distances
366-
bins, xs = ase.ga.utilities.get_rdf( # type: ignore
365+
bins, xs = ase.geometry.rdf.get_rdf( # type: ignore
367366
atoms, 4.5, num_bins, elements=[1, 1]
368367
)
369368

@@ -378,7 +377,7 @@
378377
rdfs_hh.append(bins * 3.0 / 2.0) # correct ASE normalization
379378

380379
# Compute O-O distances
381-
bins, xs = ase.ga.utilities.get_rdf( # type: ignore
380+
bins, xs = ase.geometry.rdf.get_rdf( # type: ignore
382381
atoms, 4.5, num_bins, elements=[8, 8]
383382
)
384383
bins[2:-2] = (

0 commit comments

Comments
 (0)