@@ -61,16 +61,10 @@ def _get_pytorch_state_dict_weight_kwargs(architecture, model_kwargs, root):
61
61
tmp_archtecture = None
62
62
weight_kwargs = {"kwargs" : model_kwargs } if model_kwargs else {}
63
63
if ":" in architecture :
64
- arch_file , callable_name = architecture .replace ("::" , ":" ).split (":" )
65
-
66
- # this goes haywire if we pass an absolute path, so need to copt to a tmp relative path
67
- if os .path .isabs (arch_file ):
68
- tmp_archtecture = Path ("this_model_architecture.py" )
69
- copyfile (arch_file , root / tmp_archtecture )
70
- arch = ImportableSourceFile (callable_name , tmp_archtecture )
71
- else :
72
- arch = ImportableSourceFile (callable_name , Path (arch_file ))
73
-
64
+ # note: path itself might include : for absolute paths in windows
65
+ * arch_file_parts , callable_name = architecture .replace ("::" , ":" ).split (":" )
66
+ arch_file = _ensure_local (":" .join (arch_file_parts ), root )
67
+ arch = ImportableSourceFile (callable_name , arch_file )
74
68
arch_hash = _get_hash (root / arch .source_file )
75
69
weight_kwargs ["architecture_sha256" ] = arch_hash
76
70
else :
@@ -123,30 +117,21 @@ def _get_weights(
123
117
if tensorflow_version is None :
124
118
raise ValueError ("tensorflow_version needs to be passed for building a keras model" )
125
119
weights = model_spec .raw_nodes .KerasHdf5WeightsEntry (
126
- source = weight_source ,
127
- sha256 = weight_hash ,
128
- tensorflow_version = tensorflow_version ,
129
- ** attachments ,
120
+ source = weight_source , sha256 = weight_hash , tensorflow_version = tensorflow_version , ** attachments
130
121
)
131
122
132
123
elif weight_type == "tensorflow_saved_model_bundle" :
133
124
if tensorflow_version is None :
134
125
raise ValueError ("tensorflow_version needs to be passed for building a tensorflow model" )
135
126
weights = model_spec .raw_nodes .TensorflowSavedModelBundleWeightsEntry (
136
- source = weight_source ,
137
- sha256 = weight_hash ,
138
- tensorflow_version = tensorflow_version ,
139
- ** attachments ,
127
+ source = weight_source , sha256 = weight_hash , tensorflow_version = tensorflow_version , ** attachments
140
128
)
141
129
142
130
elif weight_type == "tensorflow_js" :
143
131
if tensorflow_version is None :
144
132
raise ValueError ("tensorflow_version needs to be passed for building a tensorflow_js model" )
145
133
weights = model_spec .raw_nodes .TensorflowJsWeightsEntry (
146
- source = weight_source ,
147
- sha256 = weight_hash ,
148
- tensorflow_version = tensorflow_version ,
149
- ** attachments ,
134
+ source = weight_source , sha256 = weight_hash , tensorflow_version = tensorflow_version , ** attachments
150
135
)
151
136
152
137
elif weight_type in weight_types :
@@ -519,9 +504,8 @@ def _ensure_local_or_url(source: Union[Path, URI, str, list], root: Path) -> Uni
519
504
return [_ensure_local_or_url (s , root ) for s in source ]
520
505
521
506
local_source = resolve_local_source (source , root )
522
- local_source = resolve_local_source (
523
- local_source , root , None if isinstance (local_source , URI ) else root / local_source .name
524
- )
507
+ if not isinstance (local_source , URI ):
508
+ local_source = resolve_local_source (local_source , root , root / local_source .name )
525
509
return local_source .relative_to (root )
526
510
527
511
@@ -654,6 +638,7 @@ def build_model(
654
638
Only requred for models with onnx weight format.
655
639
weight_kwargs: additional keyword arguments for this weight type.
656
640
"""
641
+ assert architecture is None or isinstance (architecture , str )
657
642
if root is None :
658
643
root = "."
659
644
root = Path (root )
0 commit comments