@@ -48,30 +48,34 @@ def infer_module_output_dtypes(
4848
4949
5050def insert_engine_to_cache (
51- hash_val : str ,
51+ hash_val : Optional [ str ] ,
5252 interpreter_result : TRTInterpreterResult ,
5353 engine_cache : BaseEngineCache ,
5454 settings : CompilationSettings ,
5555 inputs : Sequence [Input ],
5656) -> bool :
57+ if hash_val is None :
58+ logger .warning ("Hash value is not provided, so the engine will not be cached" )
59+ return False
60+
5761 if not ENABLED_FEATURES .refit :
58- logger .info ("Refit feature is not available, so the engine is not cached" )
62+ logger .warning (
63+ "Refit feature is not available, so the engine cache will not be used"
64+ )
5965 return False
6066
6167 # Cache the weight-stripped engine regardless of the `strip_engine_weights` setting
6268 if engine_cache .check (hash_val ) is not None :
63- logger .info (f"The engine already exists in cache for hash: { hash_val } " )
64- return False
65-
66- if not settings .strip_engine_weights :
67- # set EXCLUDE_WEIGHTS flag to strip weights
68- serialization_config = interpreter_result .engine .create_serialization_config ()
69- serialization_config .set_flag (trt .SerializationFlag .EXCLUDE_WEIGHTS )
70- weight_stripped_serialized_engine = (
71- interpreter_result .engine .serialize_with_config (serialization_config )
69+ logger .info (
70+ f"Detected that the engine with hash: { hash_val } exists in cache. It will be refreshed"
7271 )
73- else :
74- weight_stripped_serialized_engine = interpreter_result .engine .serialize ()
72+
73+ # set EXCLUDE_WEIGHTS flag to strip weights
74+ serialization_config = interpreter_result .engine .create_serialization_config ()
75+ serialization_config .set_flag (trt .SerializationFlag .EXCLUDE_WEIGHTS )
76+ weight_stripped_serialized_engine = interpreter_result .engine .serialize_with_config (
77+ serialization_config
78+ )
7579
7680 # Insert weight-stripped engine to cache
7781 engine_cache .insert (
@@ -86,20 +90,26 @@ def insert_engine_to_cache(
8690 interpreter_result .requires_output_allocator ,
8791 ),
8892 )
89- logger .info (f"Engine was successfully inserted into cache for hash: { hash_val } " )
93+ logger .info (f"Engine with hash: { hash_val } was successfully inserted into cache" )
9094 return True
9195
9296
9397def pull_cached_engine (
94- hash_val : str ,
98+ hash_val : Optional [ str ] ,
9599 module : torch .fx .GraphModule ,
96100 engine_cache : BaseEngineCache ,
97101 settings : CompilationSettings ,
98102 inputs : Sequence [Input ],
99103) -> Optional [SerializedInterpreterResult ]:
104+ if hash_val is None :
105+ logger .warning (
106+ "Hash value is not provided, so the engine cache will not be used"
107+ )
108+ return None
109+
100110 if not ENABLED_FEATURES .refit :
101- logger .info (
102- "Refit feature is not available, so the engine is not loaded from cache "
111+ logger .warning (
112+ "Refit feature is not available, so the engine cache will not be used "
103113 )
104114 return None
105115
@@ -131,7 +141,7 @@ def pull_cached_engine(
131141 ), f"Attempted to refit a cached engine built for a different input size (input: { i } , cached size: { cached_engine_inputs [i ]} , new size: { inputs [i ]} "
132142
133143 logger .info (
134- "Found the cached engine that corresponds to this graph. It is directly loaded."
144+ f "Found the cached engine with hash { hash_val } that corresponds to this graph. It is directly loaded."
135145 )
136146
137147 # refit the cached engine with the new graph module
@@ -194,20 +204,39 @@ def interpret_module_to_result(
194204 # engine_cache could be None if:
195205 # 1) engine_cache is not passed in when calling this function like convert_exported_program_to_serialized_trt_engine etc., or
196206 # 2) both cache_built_engines and reuse_cached_engines are False
197- if (
198- ENABLED_FEATURES .refit
199- and engine_cache is not None
207+
208+ is_engine_caching_supported = (
209+ engine_cache is not None
210+ and ENABLED_FEATURES .refit
200211 and not settings .immutable_weights
201- ):
202- if settings .cache_built_engines or settings .reuse_cached_engines :
203- hash_val = engine_cache .get_hash (module , inputs , settings )
212+ )
213+ # calculate the hash only once. It will be used in pulling and inserting the engine.
214+ hash_val = (
215+ engine_cache .get_hash (module , inputs , settings ) # type: ignore
216+ if is_engine_caching_supported
217+ and (settings .cache_built_engines or settings .reuse_cached_engines )
218+ else None
219+ )
204220
205- if settings .reuse_cached_engines :
206- serialized_interpreter_result = pull_cached_engine (
207- hash_val , module , engine_cache , settings , inputs
208- )
209- if serialized_interpreter_result is not None : # hit the cache
210- return serialized_interpreter_result
221+ if settings .reuse_cached_engines :
222+ if engine_cache is None :
223+ logger .warning (
224+ "Engine cache is not provided, so the engine will not be reused from cache"
225+ )
226+ elif not ENABLED_FEATURES .refit :
227+ logger .warning (
228+ "Refit feature is not available, so the engine will not be reused from cache"
229+ )
230+ elif settings .immutable_weights :
231+ logger .warning (
232+ "The engine weights are immutable, so the engine will not be reused from cache"
233+ )
234+ else :
235+ serialized_interpreter_result = pull_cached_engine (
236+ hash_val , module , engine_cache , settings , inputs
237+ )
238+ if serialized_interpreter_result is not None : # hit the cache
239+ return serialized_interpreter_result
211240
212241 output_dtypes = infer_module_output_dtypes (
213242 module , truncate_double = settings .truncate_double
@@ -232,16 +261,23 @@ def interpret_module_to_result(
232261 f"CPU memory usage after clearing frozen parameters and building memory in conversion: { get_cpu_memory_usage ()} MB"
233262 )
234263
235- # Engine caching only for refittable engines
236- if (
237- ENABLED_FEATURES .refit
238- and not settings .immutable_weights
239- and settings .cache_built_engines
240- and engine_cache is not None
241- ):
242- _ = insert_engine_to_cache (
243- hash_val , interpreter_result , engine_cache , settings , inputs
244- )
264+ if settings .cache_built_engines :
265+ if engine_cache is None :
266+ logger .warning (
267+ "Engine cache is not provided, so the engine will not be cached"
268+ )
269+ elif not ENABLED_FEATURES .refit :
270+ logger .warning (
271+ "Refit feature is not available, so the engine will not be cached"
272+ )
273+ elif settings .immutable_weights :
274+ logger .warning (
275+ "The engine weights are immutable, so the engine will not be cached"
276+ )
277+ else :
278+ _ = insert_engine_to_cache (
279+ hash_val , interpreter_result , engine_cache , settings , inputs
280+ )
245281
246282 serialized_engine = interpreter_result .engine .serialize ()
247283 with io .BytesIO () as engine_bytes :
0 commit comments