mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-01 17:17:05 +00:00
[shardformer] support SAM (#4231)
* 1.support sam 2.add fused qkv for nn.Linear * update utils support set element in list * overtwrite SamVisionAttention foward to use DropoutForParallelInput * remove unused code
This commit is contained in:
@@ -1,25 +1,57 @@
|
||||
import re
|
||||
|
||||
|
||||
def get_obj_list_element(obj, a):
|
||||
def get_obj_list_element(obj, attr: str):
|
||||
r"""
|
||||
Get the element of the list in the object
|
||||
|
||||
If the attr is a normal attribute, return the attribute of the object.
|
||||
If the attr is a index type, return the element of the index in the list, like `layers[0]`.
|
||||
|
||||
Args:
|
||||
obj (Object): The object to get
|
||||
attr (str): The suffix of the attribute to get
|
||||
|
||||
"""
|
||||
re_pattern = r'\[\d+\]'
|
||||
prog = re.compile(re_pattern)
|
||||
result = prog.search(a)
|
||||
result = prog.search(attr)
|
||||
if result:
|
||||
matched_brackets = result.group()
|
||||
matched_index = matched_brackets.replace('[', '')
|
||||
matched_index = matched_index.replace(']', '')
|
||||
a_ = a.replace(matched_brackets, '')
|
||||
container_obj = getattr(obj, a_)
|
||||
attr_ = attr.replace(matched_brackets, '')
|
||||
container_obj = getattr(obj, attr_)
|
||||
obj = container_obj[int(matched_index)]
|
||||
else:
|
||||
obj = getattr(obj, a)
|
||||
obj = getattr(obj, attr)
|
||||
return obj
|
||||
|
||||
|
||||
def set_obj_list_element(obj, attr: str, value):
|
||||
r"""
|
||||
Set the element to value of a list object
|
||||
|
||||
It used like set_obj_list_element(obj, 'lyaers[0]', new_layer), it will set obj.layers[0] to value
|
||||
|
||||
Args:
|
||||
obj (object): The object to set
|
||||
attr (str): the string including a list index like `layers[0]`
|
||||
"""
|
||||
re_pattern = r'\[\d+\]'
|
||||
prog = re.compile(re_pattern)
|
||||
result = prog.search(attr)
|
||||
if result:
|
||||
matched_brackets = result.group()
|
||||
matched_index = matched_brackets.replace('[', '')
|
||||
matched_index = matched_index.replace(']', '')
|
||||
attr_ = attr.replace(matched_brackets, '')
|
||||
container_obj = getattr(obj, attr_)
|
||||
container_obj[int(matched_index)] = value
|
||||
else:
|
||||
setattr(obj, attr, value)
|
||||
|
||||
|
||||
def hasattr_(obj, attr: str):
|
||||
r"""
|
||||
Check whether the object has the multi sublevel attr
|
||||
@@ -56,7 +88,7 @@ def setattr_(obj, attr: str, value, ignore: bool = False):
|
||||
if ignore:
|
||||
return
|
||||
raise AttributeError(f"Object {obj.__class__.__name__} has no attribute {attr}")
|
||||
setattr(obj, attrs[-1], value)
|
||||
set_obj_list_element(obj, attrs[-1], value)
|
||||
|
||||
|
||||
def getattr_(obj, attr: str, ignore: bool = False):
|
||||
|
Reference in New Issue
Block a user