Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-06-25 15:01:43 +00:00).

Commit: [hotfix] add destructor for stateful tensor (#848)
* add a destructor (`__del__`) for StatefulTensor
* fix ColoInitContext
This commit is contained in:
parent
0f7ed8c192
commit
0dea140760
@ -6,7 +6,7 @@ class GeminiMemoryManager(object):
|
|||||||
def __init__(self, states_cls: EnumMeta):
|
def __init__(self, states_cls: EnumMeta):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.states_cls = states_cls
|
self.states_cls = states_cls
|
||||||
self._cnter = 0 # the counter of instances
|
self._cnter = 0 # the counter of instances
|
||||||
|
|
||||||
self.total_mem = dict()
|
self.total_mem = dict()
|
||||||
self.state_mem = dict()
|
self.state_mem = dict()
|
||||||
@ -20,10 +20,10 @@ class GeminiMemoryManager(object):
|
|||||||
return self._cnter
|
return self._cnter
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
self._cnter = 0 # the counter of instances
|
self._cnter = 0 # the counter of instances
|
||||||
|
|
||||||
self.total_mem['cpu'] = 0 # memory occupation of instances in cpu
|
self.total_mem['cpu'] = 0 # memory occupation of instances in cpu
|
||||||
self.total_mem['cuda'] = 0 # memory of occupation of instances in cuda
|
self.total_mem['cuda'] = 0 # memory of occupation of instances in cuda
|
||||||
|
|
||||||
# memory conditions for all states
|
# memory conditions for all states
|
||||||
for state in self.states_cls:
|
for state in self.states_cls:
|
||||||
@ -33,13 +33,16 @@ class GeminiMemoryManager(object):
|
|||||||
def register_new_instance(self):
    """Record that one more StatefulTensor instance is being tracked.

    Increments the instance counter; paired with delete_instance(),
    which is called from the tensor's destructor.
    """
    self._cnter += 1
||||||
|
def delete_instance(self):
    """Record that one tracked StatefulTensor instance has been destroyed.

    Decrements the instance counter; the inverse of register_new_instance().
    Called from StatefulTensor.__del__.
    """
    self._cnter -= 1
||||||
def print_info(self):
    """Print a human-readable summary of tracked memory usage.

    Emits the total instance count and total CPU/CUDA occupation,
    then a per-state CPU/CUDA breakdown for every state in
    ``self.states_cls``.  Output goes to stdout, one field per line
    (``sep='\n'``).
    """
    print(f"Total number: {self.total_number}",
          f"Total CPU memory occupation: {self.total_mem['cpu']}",
          f"Total CUDA memory occupation: {self.total_mem['cuda']}\n",
          sep='\n')

    for state in self.states_cls:
        print(f"{state}: CPU memory occupation: {self.state_mem['cpu'][state]}",
              f"{state}: CUDA memory occupation: {self.state_mem['cuda'][state]}\n",
              sep='\n')
|
@ -202,3 +202,8 @@ class StatefulTensor(object):
|
|||||||
# update the information of each state
|
# update the information of each state
|
||||||
manager.state_mem[from_type][state] -= size
|
manager.state_mem[from_type][state] -= size
|
||||||
manager.state_mem[to_type][state] += size
|
manager.state_mem[to_type][state] += size
|
||||||
|
|
||||||
|
def __del__(self):
    """Destructor: release the payload and deregister from the global manager.

    Calls ``set_null()`` to drop the tensor state, then tells the
    class-level manager (``StatefulTensor.GST_MGR``) that one fewer
    instance exists.

    NOTE(review): ``__del__`` may also run during interpreter shutdown,
    when module globals such as ``StatefulTensor`` can already be
    cleared — confirm the intended lifecycle with the caller.
    """
    self.set_null()
    StatefulTensor.GST_MGR.delete_instance()
    # The original ended with `del self`, which only unbinds the local
    # name and has no effect on the object — removed as a no-op.
|
@ -12,7 +12,7 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self._lazy_memory_allocate = lazy_memory_allocate
|
self._lazy_memory_allocate = lazy_memory_allocate
|
||||||
|
|
||||||
def _post_init_method(self, module: torch.nn.Module):
|
def _post_init_method(self, module: torch.nn.Module, *args, **kwargs):
|
||||||
"""
|
"""
|
||||||
The function to call at the end of the constructor of each module.
|
The function to call at the end of the constructor of each module.
|
||||||
FIXME(fjr) The module may be passed to this function multiple times?
|
FIXME(fjr) The module may be passed to this function multiple times?
|
||||||
|
Loading…
Reference in New Issue
Block a user