Using torch's native flash attention gives better performance #80

@xphh

Description

In transformer_decoder.py, the attention call can be replaced with torch's native scaled_dot_product_attention function.

Line 247: output = self.attention(q, k, v, mask=mask)

Change it to: output = F.scaled_dot_product_attention(q, k, v, attn_mask=mask.bool())

Overall performance improves by roughly 10%.
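For reference, a minimal sketch of the proposed swap. The tensor shapes, the `sdpa_attention` wrapper name, and the mask convention are my assumptions, not the repository's actual code; `mask.bool()` only works if the existing mask uses nonzero entries for positions that are allowed to attend, since SDPA's boolean attn_mask treats True as "attend":

```python
import torch
import torch.nn.functional as F

def sdpa_attention(q, k, v, mask=None):
    # q, k, v: (batch, n_heads, seq_len, head_dim) -- an assumed layout.
    # F.scaled_dot_product_attention fuses softmax(q @ k^T / sqrt(d)) @ v and
    # dispatches to a FlashAttention kernel when the inputs allow it.
    # Assumption: the original mask uses nonzero for "may attend", so casting
    # to bool matches SDPA's convention (True = take part in attention).
    attn_mask = mask.bool() if mask is not None else None
    return F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)

# Quick shape check on CPU.
q = torch.randn(2, 4, 16, 64)
k = torch.randn(2, 4, 16, 64)
v = torch.randn(2, 4, 16, 64)
mask = torch.ones(2, 1, 16, 16)   # all positions visible; broadcasts over heads
out = sdpa_attention(q, k, v, mask)
print(out.shape)                  # torch.Size([2, 4, 16, 64])
```

Note that the actual flash kernel is only selected under certain conditions (CUDA, fp16/bf16 inputs, and in older torch versions no arbitrary attn_mask); with a mask it may fall back to the memory-efficient or math backend, which is still typically faster than an unfused implementation.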

The same should be possible in conformer_encoder.py, but the logic there is slightly different and I don't yet know how to change it. Could the author please take a look?
