Skip to content

Commit 5032eb0

Browse files
committed
fix comments and add warnings
1 parent bb93381 commit 5032eb0

File tree

1 file changed

+76
-40
lines changed

1 file changed

+76
-40
lines changed

py/torch_tensorrt/dynamo/conversion/_conversion.py

Lines changed: 76 additions & 40 deletions
Original file line number | Diff line number | Diff line change
@@ -48,30 +48,34 @@ def infer_module_output_dtypes(
4848

4949

5050
def insert_engine_to_cache(
51-
hash_val: str,
51+
hash_val: Optional[str],
5252
interpreter_result: TRTInterpreterResult,
5353
engine_cache: BaseEngineCache,
5454
settings: CompilationSettings,
5555
inputs: Sequence[Input],
5656
) -> bool:
57+
if hash_val is None:
58+
logger.warning("Hash value is not provided, so the engine will not be cached")
59+
return False
60+
5761
if not ENABLED_FEATURES.refit:
58-
logger.info("Refit feature is not available, so the engine is not cached")
62+
logger.warning(
63+
"Refit feature is not available, so the engine cache will not be used"
64+
)
5965
return False
6066

6167
# Cache the weight-stripped engine regardless of the `strip_engine_weights` setting
6268
if engine_cache.check(hash_val) is not None:
63-
logger.info(f"The engine already exists in cache for hash: {hash_val}")
64-
return False
65-
66-
if not settings.strip_engine_weights:
67-
# set EXCLUDE_WEIGHTS flag to strip weights
68-
serialization_config = interpreter_result.engine.create_serialization_config()
69-
serialization_config.set_flag(trt.SerializationFlag.EXCLUDE_WEIGHTS)
70-
weight_stripped_serialized_engine = (
71-
interpreter_result.engine.serialize_with_config(serialization_config)
69+
logger.info(
70+
f"Detected that the engine with hash: {hash_val} exists in cache. It will be refreshed"
7271
)
73-
else:
74-
weight_stripped_serialized_engine = interpreter_result.engine.serialize()
72+
73+
# set EXCLUDE_WEIGHTS flag to strip weights
74+
serialization_config = interpreter_result.engine.create_serialization_config()
75+
serialization_config.set_flag(trt.SerializationFlag.EXCLUDE_WEIGHTS)
76+
weight_stripped_serialized_engine = interpreter_result.engine.serialize_with_config(
77+
serialization_config
78+
)
7579

7680
# Insert weight-stripped engine to cache
7781
engine_cache.insert(
@@ -86,20 +90,26 @@ def insert_engine_to_cache(
8690
interpreter_result.requires_output_allocator,
8791
),
8892
)
89-
logger.info(f"Engine was successfully inserted into cache for hash: {hash_val}")
93+
logger.info(f"Engine with hash: {hash_val} was successfully inserted into cache")
9094
return True
9195

9296

9397
def pull_cached_engine(
94-
hash_val: str,
98+
hash_val: Optional[str],
9599
module: torch.fx.GraphModule,
96100
engine_cache: BaseEngineCache,
97101
settings: CompilationSettings,
98102
inputs: Sequence[Input],
99103
) -> Optional[SerializedInterpreterResult]:
104+
if hash_val is None:
105+
logger.warning(
106+
"Hash value is not provided, so the engine cache will not be used"
107+
)
108+
return None
109+
100110
if not ENABLED_FEATURES.refit:
101-
logger.info(
102-
"Refit feature is not available, so the engine is not loaded from cache"
111+
logger.warning(
112+
"Refit feature is not available, so the engine cache will not be used"
103113
)
104114
return None
105115

@@ -131,7 +141,7 @@ def pull_cached_engine(
131141
), f"Attempted to refit a cached engine built for a different input size (input: {i}, cached size: {cached_engine_inputs[i]}, new size: {inputs[i]}"
132142

133143
logger.info(
134-
"Found the cached engine that corresponds to this graph. It is directly loaded."
144+
f"Found the cached engine with hash {hash_val} that corresponds to this graph. It is directly loaded."
135145
)
136146

137147
# refit the cached engine with the new graph module
@@ -194,20 +204,39 @@ def interpret_module_to_result(
194204
# engine_cache could be None if:
195205
# 1) engine_cache is not passed in when calling this function like convert_exported_program_to_serialized_trt_engine etc., or
196206
# 2) both cache_built_engines and reuse_cached_engines are False
197-
if (
198-
ENABLED_FEATURES.refit
199-
and engine_cache is not None
207+
208+
is_engine_caching_supported = (
209+
engine_cache is not None
210+
and ENABLED_FEATURES.refit
200211
and not settings.immutable_weights
201-
):
202-
if settings.cache_built_engines or settings.reuse_cached_engines:
203-
hash_val = engine_cache.get_hash(module, inputs, settings)
212+
)
213+
# calculate the hash only once. It will be used in pulling and inserting the engine.
214+
hash_val = (
215+
engine_cache.get_hash(module, inputs, settings) # type: ignore
216+
if is_engine_caching_supported
217+
and (settings.cache_built_engines or settings.reuse_cached_engines)
218+
else None
219+
)
204220

205-
if settings.reuse_cached_engines:
206-
serialized_interpreter_result = pull_cached_engine(
207-
hash_val, module, engine_cache, settings, inputs
208-
)
209-
if serialized_interpreter_result is not None: # hit the cache
210-
return serialized_interpreter_result
221+
if settings.reuse_cached_engines:
222+
if engine_cache is None:
223+
logger.warning(
224+
"Engine cache is not provided, so the engine will not be reused from cache"
225+
)
226+
elif not ENABLED_FEATURES.refit:
227+
logger.warning(
228+
"Refit feature is not available, so the engine will not be reused from cache"
229+
)
230+
elif settings.immutable_weights:
231+
logger.warning(
232+
"The engine weights are immutable, so the engine will not be reused from cache"
233+
)
234+
else:
235+
serialized_interpreter_result = pull_cached_engine(
236+
hash_val, module, engine_cache, settings, inputs
237+
)
238+
if serialized_interpreter_result is not None: # hit the cache
239+
return serialized_interpreter_result
211240

212241
output_dtypes = infer_module_output_dtypes(
213242
module, truncate_double=settings.truncate_double
@@ -232,16 +261,23 @@ def interpret_module_to_result(
232261
f"CPU memory usage after clearing frozen parameters and building memory in conversion: {get_cpu_memory_usage()} MB"
233262
)
234263

235-
# Engine caching only for refittable engines
236-
if (
237-
ENABLED_FEATURES.refit
238-
and not settings.immutable_weights
239-
and settings.cache_built_engines
240-
and engine_cache is not None
241-
):
242-
_ = insert_engine_to_cache(
243-
hash_val, interpreter_result, engine_cache, settings, inputs
244-
)
264+
if settings.cache_built_engines:
265+
if engine_cache is None:
266+
logger.warning(
267+
"Engine cache is not provided, so the engine will not be cached"
268+
)
269+
elif not ENABLED_FEATURES.refit:
270+
logger.warning(
271+
"Refit feature is not available, so the engine will not be cached"
272+
)
273+
elif settings.immutable_weights:
274+
logger.warning(
275+
"The engine weights are immutable, so the engine will not be cached"
276+
)
277+
else:
278+
_ = insert_engine_to_cache(
279+
hash_val, interpreter_result, engine_cache, settings, inputs
280+
)
245281

246282
serialized_engine = interpreter_result.engine.serialize()
247283
with io.BytesIO() as engine_bytes:

0 commit comments

Comments
 (0)