diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py index d9920a877112..408a0e575337 100644 --- a/src/diffusers/models/attention_dispatch.py +++ b/src/diffusers/models/attention_dispatch.py @@ -3042,7 +3042,7 @@ def _flash_attention_3_varlen_hub( value_packed = torch.cat(value_valid, dim=0) func = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_VARLEN_HUB].kernel_fn - out, lse, *_ = func( + out = func( q=query_packed, k=key_packed, v=value_packed, @@ -3053,6 +3053,9 @@ def _flash_attention_3_varlen_hub( softmax_scale=scale, causal=is_causal, ) + lse = None + if isinstance(out, tuple): + out, lse, *_ = out out = out.unflatten(0, (batch_size, -1)) return (out, lse) if return_lse else out diff --git a/tests/models/test_attention_dispatch.py b/tests/models/test_attention_dispatch.py new file mode 100644 index 000000000000..b41cb6eaa114 --- /dev/null +++ b/tests/models/test_attention_dispatch.py @@ -0,0 +1,43 @@ +import torch + +from diffusers.models.attention_dispatch import ( + AttentionBackendName, + _HUB_KERNELS_REGISTRY, + _flash_attention_3_varlen_hub, +) + + +def test_flash_attention_3_varlen_hub_handles_tensor_return(monkeypatch): + def flash_attention_3_varlen_func( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + softmax_scale, + causal, + ): + return q + 1000 + + monkeypatch.setattr( + _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_VARLEN_HUB], + "kernel_fn", + flash_attention_3_varlen_func, + ) + + batch_size = 2 + seq_len = 4 + heads = 2 + dim = 5 + query = torch.arange(batch_size * seq_len * heads * dim, dtype=torch.float32).reshape( + batch_size, seq_len, heads, dim + ) + key = query.clone() + value = query.clone() + + out = _flash_attention_3_varlen_hub(query, key, value) + + assert out.shape == query.shape + assert torch.equal(out, query + 1000)