[shardformer] support SAM (#4231)

* 1.support sam 2.add fused qkv for nn.Linear

* update utils support set element in list

* overtwrite SamVisionAttention foward to use DropoutForParallelInput

* remove unused code
This commit is contained in:
FoolPlayer
2023-07-14 15:56:59 +08:00
committed by Hongxin Liu
parent c59d7aca09
commit dd2bf02679
10 changed files with 733 additions and 10 deletions

View File

@@ -1,25 +1,57 @@
import re
def get_obj_list_element(obj, a):
def get_obj_list_element(obj, attr: str):
r"""
Get the element of the list in the object
If the attr is a normal attribute, return the attribute of the object.
If the attr is a index type, return the element of the index in the list, like `layers[0]`.
Args:
obj (Object): The object to get
attr (str): The suffix of the attribute to get
"""
re_pattern = r'\[\d+\]'
prog = re.compile(re_pattern)
result = prog.search(a)
result = prog.search(attr)
if result:
matched_brackets = result.group()
matched_index = matched_brackets.replace('[', '')
matched_index = matched_index.replace(']', '')
a_ = a.replace(matched_brackets, '')
container_obj = getattr(obj, a_)
attr_ = attr.replace(matched_brackets, '')
container_obj = getattr(obj, attr_)
obj = container_obj[int(matched_index)]
else:
obj = getattr(obj, a)
obj = getattr(obj, attr)
return obj
def set_obj_list_element(obj, attr: str, value):
r"""
Set the element to value of a list object
It used like set_obj_list_element(obj, 'lyaers[0]', new_layer), it will set obj.layers[0] to value
Args:
obj (object): The object to set
attr (str): the string including a list index like `layers[0]`
"""
re_pattern = r'\[\d+\]'
prog = re.compile(re_pattern)
result = prog.search(attr)
if result:
matched_brackets = result.group()
matched_index = matched_brackets.replace('[', '')
matched_index = matched_index.replace(']', '')
attr_ = attr.replace(matched_brackets, '')
container_obj = getattr(obj, attr_)
container_obj[int(matched_index)] = value
else:
setattr(obj, attr, value)
def hasattr_(obj, attr: str):
r"""
Check whether the object has the multi sublevel attr
@@ -56,7 +88,7 @@ def setattr_(obj, attr: str, value, ignore: bool = False):
if ignore:
return
raise AttributeError(f"Object {obj.__class__.__name__} has no attribute {attr}")
setattr(obj, attrs[-1], value)
set_obj_list_element(obj, attrs[-1], value)
def getattr_(obj, attr: str, ignore: bool = False):