fx_importer NotImplementedError: MultiheadAttention layer with NeedWeight = false #4158

Open
alaa-ali opened this issue Apr 24, 2025 · 0 comments


This issue describes a bug in torch-mlir/blob/main/python/torch_mlir/extras/fx_importer.py.
We found it while importing an exported model into MLIR. It occurs for an exported MultiheadAttention layer with need_weights=False, which means the attention weights are not returned by the layer, so the second output, attn_output_weights, is None in this case.

The following error is raised:
Python Error: NotImplementedError: OutputKind.USER_OUTPUT for <class 'torch.export.graph_signature.ConstantArgument'>: ConstantArgument(name='', value=None)

[As a side note, I also could not visualize the exported .pt2 model with a tool like https://netron.app/.
With need_weights=True (i.e. attn_output_weights is not None), both importing and visualizing the exported model work.]

doc: https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html
parameters:
    need_weights (bool): if True, returns attn_output_weights in addition to attn_output.
outputs:
    attn_output_weights: only returned when need_weights=True.
A quick check of this behavior is shown below.
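Here is a minimal sketch (assuming only a stock PyTorch install) confirming that with need_weights=False the layer returns an (attn_output, None) tuple, i.e. the second output is a plain Python None:

import torch
import torch.nn as nn

mha = nn.MultiheadAttention(embed_dim=8, num_heads=1, batch_first=True)
x = torch.rand(1, 4, 8)
# With need_weights=False the attention weights are not returned.
attn_output, attn_output_weights = mha(x, x, x, need_weights=False)
print(attn_output.shape)      # torch.Size([1, 4, 8])
print(attn_output_weights)    # None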

Source code to reproduce the exported model with attn_output_weights = None

import torch
import torch.nn as nn
import torch.nn.functional as F
class CustomModel(nn.Module):
    def __init__(self, kwargs):
        super().__init__()
        self.kwargs = kwargs
        # Build the attention layer from the supplied configuration.
        self.attn = nn.MultiheadAttention(
            embed_dim=kwargs['embedding_dim'], num_heads=kwargs['num_heads'],
            dropout=kwargs['dropout'], add_bias_kv=kwargs['add_bias_kv'],
            add_zero_attn=kwargs['add_zero_attn'], kdim=kwargs['kdim'],
            vdim=kwargs['vdim'], batch_first=kwargs['batch_first'])

    def forward(self, *args):
        query, key, value, attn_mask, kp_mask = args
        return self.attn(query, key, value, attn_mask=attn_mask,
                         key_padding_mask=kp_mask,
                         need_weights=self.kwargs['need_weights'],
                         average_attn_weights=self.kwargs['average_attn_weights'],
                         is_causal=self.kwargs['is_causal'])

# Create model instance
model = CustomModel(kwargs = {
    'embedding_dim': 64,
    'num_heads': 1,
    'dropout': 0.1,
    'add_bias_kv': True,
    'add_zero_attn': False,
    'kdim': 16,
    'vdim': None,  # used None instead of a string ("missing")
    'batch_first': True,
    'need_weights': False,
    'average_attn_weights': True,
    'is_causal': False
})

# Dummy input tensors
query = torch.rand(1, 50, 64)         # (batch, seq_len, embedding_dim)
key = torch.rand(1, 10, 16)
value = torch.rand(1, 10, 64)
attn_mask = torch.zeros(50, 10)        # (tgt_seq_len, src_seq_len)
key_padding_mask = torch.zeros(1, 10)  # (batch, src_seq_len)

# Export the model
exported_model = torch.export.export(
    model, args=(query, key, value, attn_mask, key_padding_mask))

# use exported_model.graph to inspect the exported FX graph
print(exported_model)
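To reproduce the importer error itself, the model and its example inputs (or the exported program) can be passed to torch-mlir's FX importer. The sketch below assumes a recent torch-mlir wheel that exposes torch_mlir.fx.export_and_import; the exact entry point may differ by version:

from torch_mlir import fx

# Importing the model with need_weights=False goes through
# python/torch_mlir/extras/fx_importer.py and raises
# NotImplementedError: OutputKind.USER_OUTPUT for ConstantArgument(name='', value=None)
module = fx.export_and_import(model, query, key, value, attn_mask, key_padding_mask)
print(module)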

The error occurs due to a missing case at lines # 661, 662 of the file below: torch.export.graph_signature.ConstantArgument is not handled.
torch-mlir/blob/main/python/torch_mlir/extras/fx_importer.py
[screenshot of the relevant fx_importer.py lines]

Before proposing code changes to solve this issue, we wanted to check the expected behavior and confirm whether the OutputSpec is intentionally handled this way in the source code or whether this is an actual bug that needs to be fixed.

This is a snippet from the exported program

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, p_attn_q_proj_weight: "f32[64, 64]", p_attn_k_proj_weight: "f32[64, 16]", p_attn_v_proj_weight: "f32[64, 64]", p_attn_in_proj_bias: "f32[192]", p_attn_bias_k: "f32[1, 1, 64]", p_attn_bias_v: "f32[1, 1, 64]", p_attn_out_proj_weight: "f32[64, 64]", p_attn_out_proj_bias: "f32[64]", args_0: "f32[1, 50, 64]", args_1: "f32[1, 10, 16]", args_2: "f32[1, 10, 64]", args_3: "f32[50, 10]", args_4: "f32[1, 10]"):
            transpose: "f32[50, 1, 64]" = torch.ops.aten.transpose.int(args_0, 1, 0);  args_0 = None
            ...
            view_8: "f32[50, 1, 64]" = torch.ops.aten.view.default(linear_3, [50, 1, 64]);  linear_3 = None
            transpose_6: "f32[1, 50, 64]" = torch.ops.aten.transpose.int(view_8, 1, 0);  view_8 = None
            return (transpose_6, None)

Graph signature: ExportGraphSignature(
    input_specs=[InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_attn_q_proj_weight'), target='attn.q_proj_weight', persistent=None), ...],
    output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='transpose_6'), target=None),
                  OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=ConstantArgument(name='', value=None), target=None)])

We noticed that an OutputSpec's arg can be any of the argument types listed in the documentation below, while the source code handles only two of them (TensorArgument and SymIntArgument):
https://pytorch.org/docs/stable/export.html#torch.export.graph_signature.OutputSpec
[screenshot of the torch.export.graph_signature argument types]
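For illustration only, here is a standalone sketch (not the fx_importer code; it only uses the public torch.export.graph_signature types and the exported_model from the repro above) that walks the output specs and makes the unhandled case visible:

from torch.export.graph_signature import (ConstantArgument, SymIntArgument,
                                          TensorArgument)

for spec in exported_model.graph_signature.output_specs:
    arg = spec.arg
    if isinstance(arg, TensorArgument):
        print("tensor output:", arg.name)            # handled by fx_importer
    elif isinstance(arg, SymIntArgument):
        print("symint output:", arg.name)            # handled by fx_importer
    elif isinstance(arg, ConstantArgument):
        # Hit with need_weights=False: ConstantArgument(name='', value=None).
        print("constant output:", repr(arg.value))   # currently unhandled
    else:
        print("other argument type:", type(arg).__name__)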
