diff --git a/tensorrt_llm/layers/attention.py b/tensorrt_llm/layers/attention.py
index 871ccbb11..0264c9b45 100755
--- a/tensorrt_llm/layers/attention.py
+++ b/tensorrt_llm/layers/attention.py
@@ -872,8 +872,13 @@ def forward(
         assert kv_cache_params is None or kv_cache_params.is_valid(
             default_net().plugin_config.gpt_attention_plugin)

-        past_key_value = None if kv_cache_params is None else kv_cache_params.get_first_past_key_value(
-        )
+        if kv_cache_params is None:
+            past_key_value = None
+        else:
+            if paged_kv_cache:
+                past_key_value = kv_cache_params.get_first_past_key_value()
+            else:
+                past_key_value = kv_cache_params.past_key_value[self.local_layer_idx]

         # if cross attention, cross QKV only needs to be calculated once in the
         # 1st decoding step --> write to cross KV cache --> remains constant
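
The hunk above splits one lookup into two paths: with a paged KV cache every layer reads from the same pooled handle, while a contiguous cache holds one tensor per local layer, indexed by `self.local_layer_idx`. Below is a minimal standalone sketch of that selection logic under assumed simplifications: the `KVCacheParams` body shown here is hypothetical (the real class lives in TensorRT-LLM), and `select_past_key_value` is an illustrative helper, not a function in the patch.

```python
from typing import Any, List, Optional


class KVCacheParams:
    """Hypothetical stand-in for TensorRT-LLM's KVCacheParams."""

    def __init__(self, past_key_value: Optional[List[Any]] = None):
        # Assumed layout: one entry per local (pipeline-parallel) layer.
        self.past_key_value = past_key_value

    def get_first_past_key_value(self):
        # In paged mode the first entry is assumed to be the shared pool handle.
        return self.past_key_value[0]


def select_past_key_value(kv_cache_params: Optional[KVCacheParams],
                          paged_kv_cache: bool,
                          local_layer_idx: int):
    """Mirrors the branch added to Attention.forward() in the diff."""
    if kv_cache_params is None:
        return None
    if paged_kv_cache:
        # Paged KV cache: all layers share one pooled cache tensor.
        return kv_cache_params.get_first_past_key_value()
    # Contiguous KV cache: each layer owns its own cache tensor.
    return kv_cache_params.past_key_value[local_layer_idx]
```

The original one-liner always took the first entry, which is only correct when the cache is paged (or when the params carry just the current layer's tensor); the patched code makes the per-layer indexing explicit for the contiguous case.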