mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-07-09 05:13:29 +00:00
'fix/format' (#573)
This commit is contained in:
parent
b0f708dfc1
commit
cfb41297ff
@ -106,7 +106,7 @@ class MemTracerOpHook(BaseOpHook):
|
|||||||
# output file info
|
# output file info
|
||||||
self._logger.info(f"dump a memory statistics as pickle to {self._data_prefix}-{self._rank}.pkl")
|
self._logger.info(f"dump a memory statistics as pickle to {self._data_prefix}-{self._rank}.pkl")
|
||||||
home_dir = Path.home()
|
home_dir = Path.home()
|
||||||
with open (home_dir.joinpath(f".cache/colossal/mem-{self._rank}.pkl"), "wb") as f:
|
with open(home_dir.joinpath(f".cache/colossal/mem-{self._rank}.pkl"), "wb") as f:
|
||||||
pickle.dump(self.async_mem_monitor.state_dict, f)
|
pickle.dump(self.async_mem_monitor.state_dict, f)
|
||||||
self._count += 1
|
self._count += 1
|
||||||
self._logger.debug(f"data file has been refreshed {self._count} times")
|
self._logger.debug(f"data file has been refreshed {self._count} times")
|
||||||
@ -115,4 +115,4 @@ class MemTracerOpHook(BaseOpHook):
|
|||||||
|
|
||||||
def save_results(self, data_file: Union[str, Path]):
|
def save_results(self, data_file: Union[str, Path]):
|
||||||
with open(data_file, "w") as f:
|
with open(data_file, "w") as f:
|
||||||
f.write(json.dumps(self.async_mem_monitor.state_dict))
|
f.write(json.dumps(self.async_mem_monitor.state_dict))
|
||||||
|
@ -85,8 +85,7 @@ class BaseSchedule(ABC):
|
|||||||
data_iter: Iterable,
|
data_iter: Iterable,
|
||||||
forward_only: bool,
|
forward_only: bool,
|
||||||
return_loss: bool = True,
|
return_loss: bool = True,
|
||||||
return_output_label: bool = True
|
return_output_label: bool = True):
|
||||||
):
|
|
||||||
"""The process function over a batch of dataset for training or evaluation.
|
"""The process function over a batch of dataset for training or evaluation.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -107,8 +106,9 @@ class BaseSchedule(ABC):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _call_engine_criterion(engine, outputs, labels):
|
def _call_engine_criterion(engine, outputs, labels):
|
||||||
assert isinstance(outputs, (torch.Tensor, list, tuple)
|
assert isinstance(
|
||||||
), f'Expect output of model is (torch.Tensor, list, tuple), got {type(outputs)}'
|
outputs,
|
||||||
|
(torch.Tensor, list, tuple)), f'Expect output of model is (torch.Tensor, list, tuple), got {type(outputs)}'
|
||||||
if isinstance(outputs, torch.Tensor):
|
if isinstance(outputs, torch.Tensor):
|
||||||
outputs = (outputs,)
|
outputs = (outputs,)
|
||||||
if isinstance(labels, torch.Tensor):
|
if isinstance(labels, torch.Tensor):
|
||||||
|
Loading…
Reference in New Issue
Block a user