Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

hotfix pipefusion using flash_attn #411

Merged
merged 6 commits into from
Dec 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def get_cuda_version():
],
extras_require={
"diffusers": [
"diffusers>=0.32.0", # NOTE: diffusers>=0.32.0.dev is necessary for CogVideoX and Flux
"diffusers>=0.31.0", # NOTE: diffusers>=0.32.0.dev is necessary for CogVideoX and Flux
"flash_attn>=2.6.3",
]
},
Expand Down
50 changes: 30 additions & 20 deletions xfuser/core/long_ctx_attention/ring/ring_flash_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
except ImportError:
flash_attn = None
_flash_attn_forward = None
from yunchang.kernels.attention import pytorch_attn_forward

def xdit_ring_flash_attn_forward(
process_group,
Expand Down Expand Up @@ -85,34 +86,43 @@ def xdit_ring_flash_attn_forward(
key, value = k, v

if not causal or step <= comm.rank:
assert flash_attn is not None, f"FlashAttention is not available, please install flash_attn"
if flash_attn.__version__ <= "2.6.3":
block_out, _, _, _, _, block_lse, _, _ = _flash_attn_forward(
if flash_attn is None:
block_out, block_lse = pytorch_attn_forward(
q,
key,
value,
dropout_p,
softmax_scale,
causal=causal and step == 0,
window_size=window_size,
softcap=0.0,
alibi_slopes=alibi_slopes,
return_softmax=True and dropout_p > 0,
)
else:
block_out, block_lse, _, _ = _flash_attn_forward(
q,
key,
value,
dropout_p,
softmax_scale,
causal=causal and step == 0,
window_size_left=window_size[0],
window_size_right=window_size[1],
softcap=0.0,
alibi_slopes=alibi_slopes,
return_softmax=True and dropout_p > 0,
)
if flash_attn.__version__ <= "2.6.3":
block_out, _, _, _, _, block_lse, _, _ = _flash_attn_forward(
q,
key,
value,
dropout_p,
softmax_scale,
causal=causal and step == 0,
window_size=window_size,
softcap=0.0,
alibi_slopes=alibi_slopes,
return_softmax=True and dropout_p > 0,
)
else:
block_out, block_lse, _, _ = _flash_attn_forward(
q,
key,
value,
dropout_p,
softmax_scale,
causal=causal and step == 0,
window_size_left=window_size[0],
window_size_right=window_size[1],
softcap=0.0,
alibi_slopes=alibi_slopes,
return_softmax=True and dropout_p > 0,
)
out, lse = update_out_and_lse(out, lse, block_out, block_lse)

if step + 1 != comm.world_size:
Expand Down
Loading