[NFC] polish applications/Chat/coati/trainer/strategies/base.py code style (#4278)

This commit is contained in:
Zirui Zhu 2023-07-19 22:18:08 +08:00 committed by binmakeswell
parent c972d65311
commit 9e512938f6

View File

@ -79,8 +79,7 @@ class Strategy(ABC):
model, optimizer = arg model, optimizer = arg
except ValueError: except ValueError:
raise RuntimeError(f'Expect (model, optimizer) pair, got a tuple with size "{len(arg)}"') raise RuntimeError(f'Expect (model, optimizer) pair, got a tuple with size "{len(arg)}"')
model, optimizer, *_ = self.booster.boost(model=model, model, optimizer, *_ = self.booster.boost(model=model, optimizer=optimizer)
optimizer=optimizer)
rets.append((model, optimizer)) rets.append((model, optimizer))
elif isinstance(arg, Dict): elif isinstance(arg, Dict):
model, optimizer, criterion, dataloader, lr_scheduler = self.booster.boost(**arg) model, optimizer, criterion, dataloader, lr_scheduler = self.booster.boost(**arg)
@ -90,10 +89,7 @@ class Strategy(ABC):
dataloader=dataloader, dataloader=dataloader,
lr_scheduler=lr_scheduler) lr_scheduler=lr_scheduler)
# remove None values # remove None values
boost_result = { boost_result = {key: value for key, value in boost_result.items() if value is not None}
key: value
for key, value in boost_result.items() if value is not None
}
rets.append(boost_result) rets.append(boost_result)
else: else:
raise RuntimeError(f'Type {type(arg)} is not supported') raise RuntimeError(f'Type {type(arg)} is not supported')
@ -112,23 +108,13 @@ class Strategy(ABC):
""" """
return model return model
def save_model(self, def save_model(self, model: nn.Module, path: str, only_rank0: bool = True, **kwargs) -> None:
model: nn.Module,
path: str,
only_rank0: bool = True,
**kwargs
) -> None:
self.booster.save_model(model, path, shard=not only_rank0, **kwargs) self.booster.save_model(model, path, shard=not only_rank0, **kwargs)
def load_model(self, model: nn.Module, path: str, strict: bool = True) -> None: def load_model(self, model: nn.Module, path: str, strict: bool = True) -> None:
self.booster.load_model(model, path, strict) self.booster.load_model(model, path, strict)
def save_optimizer(self, def save_optimizer(self, optimizer: Optimizer, path: str, only_rank0: bool = False, **kwargs) -> None:
optimizer: Optimizer,
path: str,
only_rank0: bool = False,
**kwargs
) -> None:
self.booster.save_optimizer(optimizer, path, shard=not only_rank0, **kwargs) self.booster.save_optimizer(optimizer, path, shard=not only_rank0, **kwargs)
def load_optimizer(self, optimizer: Optimizer, path: str) -> None: def load_optimizer(self, optimizer: Optimizer, path: str) -> None: