src.utils.use_npu module

src.utils.use_npu.npu_attention_fallback(model)

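Force every MultiheadAttention module in the model into train mode and set its dropout to zero. PyTorch only dispatches to the fused _native_multi_head_attention operator in eval mode, and that operator is unsupported on NPU; zeroing dropout keeps the output deterministic while train mode is on. The original training and dropout states are captured so that recover_npu_attention can restore them.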

Parameters:

model – PyTorch model.

src.utils.use_npu.recover_npu_attention(model, original_states)

Restore the original dropout and training state of every MultiheadAttention module changed by npu_attention_fallback.

Parameters:
  • model – PyTorch model.

  • original_states – State dictionary captured by npu_attention_fallback.
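The capture-and-restore pairing described above can be sketched as follows. This is an illustration under stated assumptions, not the module's actual source: the names attention_fallback_sketch, recover_attention_sketch, StandInAttention and StandInModel are hypothetical, and the layout of the state dictionary (module mapped to a (training, dropout) tuple) is a guess. The stand-in classes only mimic the `training` flag, the `dropout` attribute and the `modules()` iterator, so the snippet runs without PyTorch or an NPU.

```python
class StandInAttention:
    """Hypothetical stand-in for torch.nn.MultiheadAttention.

    Assumption: only the `training` flag and the float `dropout`
    attribute matter for this workaround.
    """

    def __init__(self, dropout=0.1):
        self.dropout = dropout
        self.training = False  # starts in eval mode

    def train(self, mode=True):
        self.training = mode
        return self


class StandInModel:
    """Hypothetical stand-in for a model; mimics nn.Module.modules()."""

    def __init__(self, layers):
        self._layers = list(layers)

    def modules(self):
        return iter(self._layers)


def attention_fallback_sketch(model, attention_type=StandInAttention):
    """Switch attention modules to train mode with zero dropout.

    Returns a dict mapping each touched module to its previous
    (training, dropout) pair. The dict layout is an assumption.
    """
    original_states = {}
    for module in model.modules():
        if isinstance(module, attention_type):
            original_states[module] = (module.training, module.dropout)
            module.train()
            module.dropout = 0.0
    return original_states


def recover_attention_sketch(model, original_states):
    """Put back the training flag and dropout value saved above."""
    for module in model.modules():
        if module in original_states:
            training, dropout = original_states[module]
            module.train(training)
            module.dropout = dropout


attn = StandInAttention(dropout=0.1)
model = StandInModel([attn])
states = attention_fallback_sketch(model)  # attn is now in train mode, dropout 0.0
recover_attention_sketch(model, states)    # attn is back in eval mode, dropout 0.1
```

With a real model, one would presumably pass attention_type=torch.nn.MultiheadAttention, which exposes the same `training` and `dropout` attributes.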

src.utils.use_npu.npu_attention_fallback_context(model, enable=True)

Context manager that temporarily switches every MultiheadAttention module in the model to train mode and disables its dropout, avoiding the fused operator that is unsupported on NPU. The original states are restored on exit.

Parameters:
  • model – PyTorch model.

  • enable – Whether to apply this workaround (default True); usually controlled by args.
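One plausible shape for such a context manager uses contextlib.contextmanager with try/finally, so the saved states come back even if the wrapped code raises. This is a sketch under assumptions rather than the module's real code: attention_fallback_context_sketch, StandInAttention and StandInModel are hypothetical names, the stand-ins replace PyTorch classes so the snippet runs anywhere, and treating enable=False as a no-op is an inferred behavior.

```python
from contextlib import contextmanager


class StandInAttention:
    """Hypothetical stand-in for torch.nn.MultiheadAttention."""

    def __init__(self, dropout=0.1):
        self.dropout = dropout
        self.training = False  # starts in eval mode

    def train(self, mode=True):
        self.training = mode
        return self


class StandInModel:
    """Hypothetical stand-in for a model; mimics nn.Module.modules()."""

    def __init__(self, layers):
        self._layers = list(layers)

    def modules(self):
        return iter(self._layers)


@contextmanager
def attention_fallback_context_sketch(model, enable=True,
                                      attention_type=StandInAttention):
    if not enable:
        # Assumed behavior: with the workaround disabled, do nothing.
        yield model
        return

    targets = [m for m in model.modules() if isinstance(m, attention_type)]
    saved = [(m, m.training, m.dropout) for m in targets]
    try:
        for m in targets:
            m.train()        # train mode, so the fused path is not taken
            m.dropout = 0.0  # no dropout, so outputs stay deterministic
        yield model
    finally:
        # Runs on normal exit and on exceptions alike.
        for m, training, dropout in saved:
            m.train(training)
            m.dropout = dropout


attn = StandInAttention(dropout=0.1)
model = StandInModel([attn])
with attention_fallback_context_sketch(model, enable=True):
    inside = (attn.training, attn.dropout)  # state while the workaround is active
after = (attn.training, attn.dropout)       # state after leaving the block
```

Restoring in a finally clause is the main design point: an inference call that fails inside the with block would otherwise leave the model stuck in train mode.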