mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-28 13:05:26 +00:00
[fx] refactored the file structure of patched function and module (#1238)
* [fx] refactored the file structure of patched function and module * polish code
This commit is contained in:
@@ -0,0 +1,24 @@
|
||||
import operator
|
||||
import torch
|
||||
from ..registry import meta_patched_function
|
||||
|
||||
|
||||
@meta_patched_function.register(operator.getitem)
|
||||
def operator_getitem(a, b):
|
||||
# copied from huggingface.utils.fx
|
||||
def to_concrete(t):
|
||||
if isinstance(t, torch.Tensor):
|
||||
concrete = torch.ones_like(t, device="cpu")
|
||||
if concrete.dtype in [torch.float16, torch.float32, torch.float64, torch.int32]:
|
||||
concrete = concrete.to(torch.int64)
|
||||
return concrete
|
||||
return t
|
||||
|
||||
if isinstance(a, torch.Tensor):
|
||||
# TODO: infer shape without performing the computation.
|
||||
if isinstance(b, tuple):
|
||||
b = tuple(map(to_concrete, b))
|
||||
else:
|
||||
b = to_concrete(b)
|
||||
return operator.getitem(torch.empty_like(a, device="cpu"), b).to("meta")
|
||||
return operator.getitem(a, b)
|
Reference in New Issue
Block a user