This commit is contained in:
Tong Li
2025-03-06 16:26:14 +08:00
parent 0f566cc2d4
commit 0cc0c843ed
3 changed files with 18 additions and 4 deletions

View File

@@ -1,6 +1,6 @@
from contextlib import nullcontext
from typing import Any, Dict, Optional
import os
import ray
import ray.util.collective as cc
import torch
@@ -33,6 +33,8 @@ class BaseConsumer:
model_config: Dict[str, Any],
plugin_config: Dict[str, Any],
microbatch_size: int = 1,
save_interval: int = 100,
save_dir: str = "./model"
):
self.num_producers = num_producers
self.num_episodes = num_episodes
@@ -44,6 +46,8 @@ class BaseConsumer:
self.num_recv_per_update = num_recv_per_update
self.batch_size = batch_size
self.microbatch_size = microbatch_size
self.save_interval = save_interval
self.save_dir = save_dir
assert batch_size % microbatch_size == 0, "batch_size should be divisible by microbatch_size"
self.num_microbatches = batch_size // microbatch_size
@@ -116,6 +120,14 @@ class BaseConsumer:
pbar.set_postfix({"loss": loss})
i += 1
assert len(self.buffer) == 0
if (step + 1) % self.save_interval == 0:
if self.rank == 0:
print(f"Start saving policy model at step {step + 1}.")
save_path = os.path.join(self.save_dir, f"modeling-step-{step + 1}")
self.booster.save_model(self.policy_model, save_path, shard=True)
if self.rank == 0:
print(f"Saved model checkpoint at step {step + 1} in folder {save_path}")
if episode != self.num_episodes - 1 or step != self.num_update_per_episode - 1:
print(f"[T{dist.get_rank()}] Sync model episode {episode} step {step}")
state_dict = self.state_dict()