Skip to content

Commit e02c54e

Browse files
committed
refactor some io to spec
1 parent 4fdbf4c commit e02c54e

File tree

3 files changed

+13
-242
lines changed

3 files changed

+13
-242
lines changed

bioimageio/core/resource_io/io_.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,6 @@ def load_raw_resource_description(
112112
return source
113113

114114
raw_rd = spec.load_raw_resource_description(source, update_to_format=update_to_format)
115-
raw_rd = _replace_relative_paths_for_remote_source(raw_rd, raw_rd.root_path)
116115
return raw_rd
117116

118117

bioimageio/core/resource_io/nodes.py

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -42,20 +42,6 @@ class Dependencies(Node, raw_nodes.Dependencies):
4242
file: pathlib.Path = missing
4343

4444

45-
@dataclass
46-
class LocalImportableModule(Node, raw_nodes.ImportableModule):
47-
"""intermediate between raw_nodes.ImportableModule and nodes.ImportedSource. Used by SourceNodeTransformer"""
48-
49-
root_path: pathlib.Path = missing
50-
51-
52-
@dataclass
53-
class ResolvedImportableSourceFile(Node, raw_nodes.ImportableSourceFile):
54-
"""intermediate between raw_nodes.ImportableSourceFile and nodes.ImportedSource. Used by SourceNodeTransformer"""
55-
56-
source_file: pathlib.Path = missing
57-
58-
5945
@dataclass
6046
class CiteEntry(Node, rdf_raw_nodes.CiteEntry):
6147
pass

bioimageio/core/resource_io/utils.py

Lines changed: 13 additions & 227 deletions
Original file line numberDiff line numberDiff line change
@@ -2,24 +2,23 @@
22
import importlib.util
33
import os
44
import pathlib
5-
import shutil
65
import sys
76
import typing
8-
import warnings
9-
from functools import singledispatch
107
from types import ModuleType
11-
from urllib.request import url2pathname
128

13-
import requests
14-
from marshmallow import ValidationError
15-
from tqdm import tqdm
16-
17-
from bioimageio.spec.shared import fields, raw_nodes
18-
from bioimageio.spec.shared.common import BIOIMAGEIO_CACHE_PATH
19-
from bioimageio.spec.shared.utils import GenericRawNode, GenericRawRD, NodeTransformer, NodeVisitor
9+
from bioimageio.spec.shared import raw_nodes
10+
from bioimageio.spec.shared.utils import (
11+
GenericRawNode,
12+
GenericRawRD,
13+
GenericResolvedNode,
14+
NodeTransformer,
15+
NodeVisitor,
16+
UriNodeTransformer,
17+
resolve_source,
18+
source_available,
19+
)
2020
from . import nodes
2121

22-
GenericResolvedNode = typing.TypeVar("GenericResolvedNode", bound=nodes.Node)
2322
GenericNode = typing.Union[GenericRawNode, GenericResolvedNode]
2423

2524

@@ -48,34 +47,6 @@ def visit_WindowsPath(self, leaf: pathlib.WindowsPath):
4847
self._visit_source(leaf)
4948

5049

51-
class UriNodeTransformer(NodeTransformer):
52-
def __init__(self, *, root_path: os.PathLike):
53-
self.root_path = pathlib.Path(root_path).resolve()
54-
55-
def transform_URI(self, node: raw_nodes.URI) -> pathlib.Path:
56-
local_path = resolve_source(node, root_path=self.root_path)
57-
return local_path
58-
59-
def transform_ImportableSourceFile(
60-
self, node: raw_nodes.ImportableSourceFile
61-
) -> nodes.ResolvedImportableSourceFile:
62-
return nodes.ResolvedImportableSourceFile(
63-
source_file=resolve_source(node.source_file, self.root_path), callable_name=node.callable_name
64-
)
65-
66-
def transform_ImportableModule(self, node: raw_nodes.ImportableModule) -> nodes.LocalImportableModule:
67-
return nodes.LocalImportableModule(**dataclasses.asdict(node), root_path=self.root_path)
68-
69-
def _transform_Path(self, leaf: pathlib.Path):
70-
return self.root_path / leaf
71-
72-
def transform_PosixPath(self, leaf: pathlib.PosixPath) -> pathlib.Path:
73-
return self._transform_Path(leaf)
74-
75-
def transform_WindowsPath(self, leaf: pathlib.WindowsPath) -> pathlib.Path:
76-
return self._transform_Path(leaf)
77-
78-
7950
class SourceNodeTransformer(NodeTransformer):
8051
"""
8152
Imports all source callables
@@ -92,7 +63,7 @@ def __enter__(self):
9263
def __exit__(self, exc_type, exc_value, traceback):
9364
sys.path.remove(self.path)
9465

95-
def transform_LocalImportableModule(self, node: nodes.LocalImportableModule) -> nodes.ImportedSource:
66+
def transform_LocalImportableModule(self, node: raw_nodes.LocalImportableModule) -> nodes.ImportedSource:
9667
with self.TemporaryInsertionIntoPythonPath(str(node.root_path)):
9768
module = importlib.import_module(node.module_name)
9869

@@ -105,7 +76,7 @@ def transform_ImportableModule(node):
10576
)
10677

10778
@staticmethod
108-
def transform_ResolvedImportableSourceFile(node: nodes.ResolvedImportableSourceFile) -> nodes.ImportedSource:
79+
def transform_ResolvedImportableSourceFile(node: raw_nodes.ResolvedImportableSourceFile) -> nodes.ImportedSource:
10980
module_path = resolve_source(node.source_file)
11081
module_name = f"module_from_source.{module_path.stem}"
11182
importlib_spec = importlib.util.spec_from_file_location(module_name, module_path)
@@ -137,156 +108,6 @@ def generic_transformer(self, node: GenericRawNode) -> GenericResolvedNode:
137108
return super().generic_transformer(node)
138109

139110

140-
@singledispatch # todo: fix type annotations
141-
def resolve_source(source, root_path: os.PathLike = pathlib.Path(), output=None):
142-
raise TypeError(type(source))
143-
144-
145-
@resolve_source.register
146-
def _resolve_source_uri_node(
147-
source: raw_nodes.URI, root_path: os.PathLike = pathlib.Path(), output: typing.Optional[os.PathLike] = None
148-
) -> pathlib.Path:
149-
path_or_remote_uri = resolve_local_source(source, root_path, output)
150-
if isinstance(path_or_remote_uri, raw_nodes.URI):
151-
local_path = _download_url(path_or_remote_uri, output)
152-
elif isinstance(path_or_remote_uri, pathlib.Path):
153-
local_path = path_or_remote_uri
154-
else:
155-
raise TypeError(path_or_remote_uri)
156-
157-
return local_path
158-
159-
160-
@resolve_source.register
161-
def _resolve_source_str(
162-
source: str, root_path: os.PathLike = pathlib.Path(), output: typing.Optional[os.PathLike] = None
163-
) -> pathlib.Path:
164-
return resolve_source(fields.Union([fields.URI(), fields.Path()]).deserialize(source), root_path, output)
165-
166-
167-
@resolve_source.register
168-
def _resolve_source_path(
169-
source: pathlib.Path, root_path: os.PathLike = pathlib.Path(), output: typing.Optional[os.PathLike] = None
170-
) -> pathlib.Path:
171-
if not source.is_absolute():
172-
source = pathlib.Path(root_path).absolute() / source
173-
174-
if output is None:
175-
return source
176-
else:
177-
try:
178-
shutil.copyfile(source, output)
179-
except shutil.SameFileError: # source and output are identical
180-
pass
181-
return pathlib.Path(output)
182-
183-
184-
@resolve_source.register
185-
def _resolve_source_resolved_importable_path(
186-
source: nodes.ResolvedImportableSourceFile,
187-
root_path: os.PathLike = pathlib.Path(),
188-
output: typing.Optional[os.PathLike] = None,
189-
) -> nodes.ResolvedImportableSourceFile:
190-
return nodes.ResolvedImportableSourceFile(
191-
callable_name=source.callable_name, source_file=resolve_source(source.source_file, root_path, output)
192-
)
193-
194-
195-
@resolve_source.register
196-
def _resolve_source_importable_path(
197-
source: raw_nodes.ImportableSourceFile,
198-
root_path: os.PathLike = pathlib.Path(),
199-
output: typing.Optional[os.PathLike] = None,
200-
) -> nodes.ResolvedImportableSourceFile:
201-
return nodes.ResolvedImportableSourceFile(
202-
callable_name=source.callable_name, source_file=resolve_source(source.source_file, root_path, output)
203-
)
204-
205-
206-
@resolve_source.register
207-
def _resolve_source_list(
208-
source: list,
209-
root_path: os.PathLike = pathlib.Path(),
210-
output: typing.Optional[typing.Sequence[typing.Optional[os.PathLike]]] = None,
211-
) -> typing.List[pathlib.Path]:
212-
assert output is None or len(output) == len(source)
213-
return [resolve_source(el, root_path, out) for el, out in zip(source, output or [None] * len(source))]
214-
215-
216-
def resolve_local_sources(
217-
sources: typing.Sequence[typing.Union[str, os.PathLike, raw_nodes.URI]],
218-
root_path: os.PathLike,
219-
outputs: typing.Optional[typing.Sequence[os.PathLike]] = None,
220-
) -> typing.List[typing.Union[pathlib.Path, raw_nodes.URI]]:
221-
assert outputs is None or len(outputs) == len(sources)
222-
return [resolve_local_source(src, root_path, out) for src, out in zip(sources, outputs)]
223-
224-
225-
def resolve_local_source(
226-
source: typing.Union[str, os.PathLike, raw_nodes.URI],
227-
root_path: os.PathLike,
228-
output: typing.Optional[os.PathLike] = None,
229-
) -> typing.Union[pathlib.Path, raw_nodes.URI]:
230-
if isinstance(source, (tuple, list)):
231-
return type(source)([resolve_local_source(s, root_path, output) for s in source])
232-
elif isinstance(source, os.PathLike) or isinstance(source, str):
233-
try: # source as path from cwd
234-
is_path_cwd = pathlib.Path(source).exists()
235-
except OSError:
236-
is_path_cwd = False
237-
238-
try: # source as relative path from root_path
239-
path_from_root = pathlib.Path(root_path) / source
240-
is_path_rp = (path_from_root).exists()
241-
except OSError:
242-
is_path_rp = False
243-
else:
244-
if not is_path_cwd and is_path_rp:
245-
source = path_from_root
246-
247-
if is_path_cwd or is_path_rp:
248-
source = pathlib.Path(source)
249-
if output is None:
250-
return source
251-
else:
252-
try:
253-
shutil.copyfile(source, output)
254-
except shutil.SameFileError:
255-
pass
256-
return pathlib.Path(output)
257-
258-
elif isinstance(source, os.PathLike):
259-
raise FileNotFoundError(f"Could neither find {source} nor {pathlib.Path(root_path) / source}")
260-
261-
if isinstance(source, str):
262-
uri = fields.URI().deserialize(source)
263-
else:
264-
uri = source
265-
266-
assert isinstance(uri, raw_nodes.URI), uri
267-
if uri.scheme == "file":
268-
local_path_or_remote_uri = pathlib.Path(url2pathname(uri.path))
269-
elif uri.scheme in ("https", "https"):
270-
local_path_or_remote_uri = uri
271-
else:
272-
raise ValueError(f"Unknown uri scheme {uri.scheme}")
273-
274-
return local_path_or_remote_uri
275-
276-
277-
def source_available(source: typing.Union[pathlib.Path, raw_nodes.URI], root_path: pathlib.Path) -> bool:
278-
local_path_or_remote_uri = resolve_local_source(source, root_path)
279-
if isinstance(local_path_or_remote_uri, raw_nodes.URI):
280-
response = requests.head(str(local_path_or_remote_uri))
281-
available = response.status_code == 200
282-
elif isinstance(local_path_or_remote_uri, pathlib.Path):
283-
available = local_path_or_remote_uri.exists()
284-
else:
285-
raise TypeError(local_path_or_remote_uri)
286-
287-
return available
288-
289-
290111
def all_sources_available(
291112
node: typing.Union[GenericNode, list, tuple, dict], root_path: os.PathLike = pathlib.Path()
292113
) -> bool:
@@ -298,41 +119,6 @@ def all_sources_available(
298119
return True
299120

300121

301-
def _download_url(uri: raw_nodes.URI, output: typing.Optional[os.PathLike] = None) -> pathlib.Path:
302-
if output is not None:
303-
local_path = pathlib.Path(output)
304-
else:
305-
# todo: proper caching
306-
local_path = BIOIMAGEIO_CACHE_PATH / uri.scheme / uri.authority / uri.path.strip("/") / uri.query
307-
308-
if local_path.exists():
309-
warnings.warn(f"found cached {local_path}. Skipping download of {uri}.")
310-
else:
311-
local_path.parent.mkdir(parents=True, exist_ok=True)
312-
313-
try:
314-
# download with tqdm adapted from:
315-
# https://github.com/shaypal5/tqdl/blob/189f7fd07f265d29af796bee28e0893e1396d237/tqdl/core.py
316-
# Streaming, so we can iterate over the response.
317-
r = requests.get(str(uri), stream=True)
318-
# Total size in bytes.
319-
total_size = int(r.headers.get("content-length", 0))
320-
block_size = 1024 # 1 Kibibyte
321-
t = tqdm(total=total_size, unit="iB", unit_scale=True, desc=local_path.name)
322-
with local_path.open("wb") as f:
323-
for data in r.iter_content(block_size):
324-
t.update(len(data))
325-
f.write(data)
326-
t.close()
327-
if total_size != 0 and t.n != total_size:
328-
# todo: check more carefully and raise on real issue
329-
warnings.warn("Download does not have expected size.")
330-
except Exception as e:
331-
raise RuntimeError(f"Failed to download {uri} ({e})")
332-
333-
return local_path
334-
335-
336122
def resolve_raw_resource_description(raw_rd: GenericRawRD, nodes_module: typing.Any) -> GenericResolvedNode:
337123
"""resolve all uris and sources"""
338124
rd = UriNodeTransformer(root_path=raw_rd.root_path).transform(raw_rd)

0 commit comments

Comments
 (0)