1
0
mirror of https://github.com/hpcaitech/ColossalAI.git synced 2025-05-03 05:58:09 +00:00
ColossalAI/applications/Chat/coati/kernels/wrapper.py
Hongxin Liu 7bd0bee8ea
[chat] add opt attn kernel ()
* [chat] add opt attn kernel

* [chat] disable xformer during fwd
2023-05-04 16:03:33 +08:00

19 lines
533 B
Python

import torch.nn as nn
from transformers.models.opt.modeling_opt import OPTAttention
from .opt_attn import XOPTAttention
def convert_to_xformer_model(model: nn.Module) -> nn.Module:
    """Swap every plain ``OPTAttention`` in ``model`` for ``XOPTAttention``, in place.

    The swap reassigns ``__class__`` on the existing module objects. Their
    instance state (parameters, buffers, registered hooks) is therefore kept,
    and ``__init__`` is not re-run.
    NOTE(review): this assumes ``XOPTAttention`` needs no instance attributes
    beyond those ``OPTAttention.__init__`` creates — confirm against ``.opt_attn``.

    Only modules whose type is *exactly* ``OPTAttention`` are converted. Other
    subclasses of ``OPTAttention`` are left untouched, for two reasons:

    * overwriting their class would silently drop their own method overrides;
    * ``recover_from_xformer_model`` would later turn them into plain
      ``OPTAttention`` instead of restoring the original subclass.

    Modules that are already ``XOPTAttention`` do not match either, so calling
    this twice is harmless.

    Args:
        model: Module tree to patch. Every module yielded by ``model.modules()``
            is inspected, including ``model`` itself.

    Returns:
        The same ``model`` object, so the call can be chained.
    """
    for module in model.modules():
        # Exact type check (not isinstance): subclasses keep their own forward,
        # and the conversion stays exactly reversible.
        if type(module) is OPTAttention:
            module.__class__ = XOPTAttention
    return model
def recover_from_xformer_model(model: nn.Module) -> nn.Module:
    """Turn every ``XOPTAttention`` in ``model`` back into ``OPTAttention``, in place.

    This undoes the class swap performed by ``convert_to_xformer_model``. Each
    matching module has its ``__class__`` reassigned, so the object (and its
    instance state) is reused rather than rebuilt.

    Args:
        model: Module tree to restore. Every module yielded by
            ``model.modules()`` is inspected, including ``model`` itself.

    Returns:
        The very same ``model`` object that was passed in.
    """
    # Lazily pick out the patched attention layers. Changing a module's class
    # does not alter the module tree being walked.
    patched_layers = (
        candidate for candidate in model.modules()
        if isinstance(candidate, XOPTAttention)
    )
    for layer in patched_layers:
        layer.__class__ = OPTAttention
    return model