mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-02 01:28:31 +00:00
[fx] added testing for all bert variants (#1207)
* [fx] added testing for all bert variants * polish code
This commit is contained in:
@@ -1,6 +1,8 @@
|
||||
import operator
|
||||
import torch
|
||||
from torch.fx.proxy import Proxy, Attribute
|
||||
from typing import List, Union
|
||||
from torch.utils._pytree import PyTree
|
||||
|
||||
__all__ = ['ColoProxy']
|
||||
|
||||
@@ -26,8 +28,12 @@ class ColoProxy(Proxy):
|
||||
return self._meta_tensor
|
||||
|
||||
@meta_tensor.setter
|
||||
def meta_tensor(self, tensor: torch.Tensor):
|
||||
assert tensor is None or tensor.is_meta, 'Expected to receive a meta tensor, but got a non-meta tensor'
|
||||
def meta_tensor(self, tensor: Union[List[torch.Tensor], torch.Tensor]):
|
||||
|
||||
def _is_meta(item):
|
||||
assert torch.is_tensor(item) and item.is_meta
|
||||
|
||||
torch.fx.node.map_aggregate(tensor, _is_meta)
|
||||
self._meta_tensor = tensor
|
||||
|
||||
@property
|
||||
@@ -83,6 +89,14 @@ class ColoProxy(Proxy):
|
||||
def __setitem__(self, indices, values):
|
||||
return self.tracer.create_proxy("call_function", operator.setitem, (self, indices, values), {})
|
||||
|
||||
def __contains__(self, key):
|
||||
if self.node.op == "placeholder":
|
||||
# this is used to handle like
|
||||
# if x in kwargs
|
||||
# we don't handle this case for now
|
||||
return False
|
||||
return super().__contains__(key)
|
||||
|
||||
|
||||
class ColoAttribute(ColoProxy):
|
||||
|
||||
|
Reference in New Issue
Block a user