src.utils.use_npu module
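- src.utils.use_npu.npu_attention_fallback(model)
Switch every MultiheadAttention module in the model to train mode and set its dropout to zero, so that PyTorch does not dispatch to the fused _native_multi_head_attention operator, which is unsupported on NPU. PyTorch only takes the fused path in eval mode; zeroing dropout keeps the output deterministic while the module is in train mode.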
- Parameters:
model – PyTorch model.
- src.utils.use_npu.recover_npu_attention(model, original_states)
Restore the original dropout value and training state of each MultiheadAttention module.
- Parameters:
model – PyTorch model.
original_states – State dictionary captured by npu_attention_fallback.
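The capture-and-restore pairing described above can be sketched as follows. This is a minimal illustration, not the module's actual source: the names fallback_sketch and recover_sketch and the layout of the state dictionary (module name mapped to a (training, dropout) tuple) are assumptions made for the example.

```python
import torch.nn as nn


def fallback_sketch(model: nn.Module) -> dict:
    """Hypothetical stand-in for npu_attention_fallback.

    Returns a dict mapping module name -> (training flag, dropout value).
    """
    original_states = {}
    for name, module in model.named_modules():
        if isinstance(module, nn.MultiheadAttention):
            original_states[name] = (module.training, module.dropout)
            module.train()        # train mode skips the fused inference path
            module.dropout = 0.0  # keep outputs deterministic, as in eval mode
    return original_states


def recover_sketch(model: nn.Module, original_states: dict) -> None:
    """Hypothetical stand-in for recover_npu_attention."""
    for name, module in model.named_modules():
        if name in original_states:
            training, dropout = original_states[name]
            module.train(training)
            module.dropout = dropout


# Toy container holding a single attention layer, put into eval mode.
model = nn.ModuleDict(
    {"attn": nn.MultiheadAttention(embed_dim=16, num_heads=2, dropout=0.1)}
).eval()
attn = model["attn"]

states = fallback_sketch(model)
patched = (attn.training, attn.dropout)   # state while the workaround is active

recover_sketch(model, states)
restored = (attn.training, attn.dropout)  # state after recovery
```

Keying the saved state by module name is one possible choice; it keeps the dictionary free of module references and works as long as the model structure does not change between the two calls.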
- src.utils.use_npu.npu_attention_fallback_context(model, enable=True)
Context manager that temporarily switches every MultiheadAttention module to train mode with dropout disabled, avoiding the fused operator that is unsupported on NPU.
- Parameters:
model – PyTorch model.
enable – Whether to apply the workaround. Defaults to True; usually controlled by args.
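A context manager with this behaviour might look roughly like the sketch below. It is an assumption-laden illustration rather than the real implementation: attention_fallback_context_sketch is a hypothetical name, and restoring the saved state in a finally block is a design choice made here so that the model is restored even if the wrapped code raises.

```python
from contextlib import contextmanager

import torch
import torch.nn as nn


@contextmanager
def attention_fallback_context_sketch(model: nn.Module, enable: bool = True):
    """Hypothetical stand-in for npu_attention_fallback_context."""
    if not enable:
        yield model  # workaround disabled: leave the model untouched
        return
    attn_modules = [m for m in model.modules() if isinstance(m, nn.MultiheadAttention)]
    saved = [(m, m.training, m.dropout) for m in attn_modules]
    for m in attn_modules:
        m.train()        # avoid the fused eval-mode path
        m.dropout = 0.0  # no randomness despite train mode
    try:
        yield model
    finally:
        # Restore the saved state even if the wrapped code raised.
        for m, training, dropout in saved:
            m.train(training)
            m.dropout = dropout


attn = nn.MultiheadAttention(embed_dim=16, num_heads=2, dropout=0.1).eval()
x = torch.randn(4, 3, 16)  # (sequence, batch, embedding)

with torch.no_grad(), attention_fallback_context_sketch(attn):
    inside = (attn.training, attn.dropout)
    out, _ = attn(x, x, x)  # self-attention forward pass under the workaround

after = (attn.training, attn.dropout)

with attention_fallback_context_sketch(attn, enable=False):
    disabled = (attn.training, attn.dropout)
```

Passing enable=False turns the context manager into a no-op, so the same inference code can run unchanged on devices that support the fused operator.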