mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-09 21:09:18 +00:00
fix logprob, add filtering, temperature annealing, lr descent
This commit is contained in:
@@ -57,6 +57,7 @@ class BaseConsumer:
|
||||
assert self.plugin_config.get("pp_size", 1) == 1, "pp_size > 1 is not supported now"
|
||||
|
||||
self.device = get_current_device()
|
||||
self.lr_scheduler = None
|
||||
|
||||
def setup(self) -> None:
|
||||
for i in range(self.num_producers):
|
||||
@@ -121,6 +122,8 @@ class BaseConsumer:
|
||||
pbar.set_postfix({"loss": loss})
|
||||
i += 1
|
||||
assert len(self.buffer) == 0
|
||||
if self.lr_scheduler is not None:
|
||||
self.lr_scheduler.step()
|
||||
if (step + 1) % self.save_interval == 0:
|
||||
if self.rank == 0:
|
||||
print(f"Start saving policy model at step {step + 1}.")
|
||||
|
Reference in New Issue
Block a user