Source code for src.utils.use_npu

import os
import torch
from contextlib import contextmanager
from .tag2ansi import tag2ansi


## Run on a Huawei Ascend NPU when the USE_NPU environment variable is set to an
## affirmative value. The variable counts as "off" when it is unset, empty, or one
## of '0', 'no', 'false', 'off' (case-insensitive, surrounding whitespace ignored),
## so that e.g. `USE_NPU=0` or `USE_NPU=FALSE` does not silently enable the NPU path.
USE_NPU = os.environ.get('USE_NPU', '').strip().lower() not in ('', '0', 'no', 'false', 'off')
if USE_NPU:
    print(tag2ansi("[yellow bold]Using Huawei Ascend NPU for training.[reset]"))
    # Imported lazily so that machines without the Ascend stack can still load this module.
    import torch_npu
    from torch_npu.contrib import transfer_to_npu  # Transparent migration to NPU.
    # Disable the flash and memory-efficient scaled-dot-product-attention backends
    # and keep only the math implementation enabled.
    # NOTE(review): presumably the fused backends are unavailable on NPU -- confirm.
    torch.backends.cuda.enable_flash_sdp(False)
    torch.backends.cuda.enable_mem_efficient_sdp(False)
    torch.backends.cuda.enable_math_sdp(True)


def npu_attention_fallback(model):
    """
    Put every `MultiheadAttention` submodule into train mode with dropout set to 0.

    The intent, per the original author, is to steer PyTorch away from the fused
    `_native_multi_head_attention` operator, which is unsupported on NPU.

    Args:
        model: PyTorch model whose submodules are scanned via `named_modules()`.

    Returns:
        A dict mapping each patched module's qualified name to
        `{'dropout': <previous value>, 'training': <previous flag>}`, suitable
        for passing to `recover_npu_attention`.
    """
    saved = {}
    for path, layer in model.named_modules():
        # Only attention layers are touched; everything else keeps its state.
        if not isinstance(layer, torch.nn.MultiheadAttention):
            continue
        # Snapshot the current settings before overwriting them.
        saved[path] = dict(dropout=layer.dropout, training=layer.training)
        # Train mode forces the fallback to the MatMul path.
        layer.train()
        # With dropout at zero the forced train mode stays deterministic.
        layer.dropout = 0.0
    return saved
def recover_npu_attention(model, original_states):
    """
    Undo `npu_attention_fallback` by restoring saved dropout and training flags.

    Modules are matched by their qualified name from `named_modules()`; names
    that are absent from `original_states` are left untouched.

    Args:
        model: PyTorch model.
        original_states: Mapping of module name to a dict with 'dropout' and
            'training' entries, as returned by `npu_attention_fallback`.
    """
    for path, layer in model.named_modules():
        if path not in original_states:
            continue
        saved = original_states[path]
        # Put the dropout probability back first, then the train/eval mode.
        layer.dropout = saved['dropout']
        layer.train(saved['training'])
@contextmanager
def npu_attention_fallback_context(model, enable=True):
    """
    Context manager applying the `MultiheadAttention` NPU workaround temporarily.

    While the context is active, `npu_attention_fallback` has been applied to
    `model`; on exit (including when the body raises) the saved states are put
    back through `recover_npu_attention`.

    Args:
        model: PyTorch model.
        enable: When falsy the context does nothing at all; usually controlled
            by command-line args.
    """
    if enable:
        saved = npu_attention_fallback(model)
        try:
            yield
        finally:
            # Always restore, even if the managed body raised.
            recover_npu_attention(model, saved)
    else:
        # Workaround switched off: hand control straight to the body.
        yield