Support evaluation during training

Author: YeAnbang
Date: 2025-04-30 18:13:40 +08:00
parent 5fd4bcb9d8
commit 57a88395fe
9 changed files with 234 additions and 65 deletions


@@ -36,6 +36,7 @@ class BaseConsumer:
minibatch_size: int = 1,
save_interval: int = 100,
save_dir: str = "./model",
eval_interval: int = -1,
):
self.num_producers = num_producers
self.num_episodes = num_episodes
@@ -51,6 +52,7 @@ class BaseConsumer:
self.save_dir = save_dir
assert batch_size % minibatch_size == 0, "batch_size should be divisible by microbatch_size"
self.num_microbatches = batch_size // minibatch_size
self.eval_interval = eval_interval
self.model_config = model_config
self.plugin_config = plugin_config
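The new eval_interval parameter follows the same convention as the other intervals: a positive value means evaluation results are collected every that many training steps, and the default of -1 keeps the feature disabled. A minimal standalone sketch of the gating check the training loop applies further down (the helper name is illustrative, not part of the diff):

```python
def should_collect_eval(step: int, eval_interval: int) -> bool:
    """Mirror of the check in BaseConsumer's training loop: evaluation fires
    on step 0 and then every `eval_interval` steps; a non-positive value
    disables it entirely."""
    return eval_interval > 0 and step % eval_interval == 0
```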
@@ -92,8 +94,10 @@ class BaseConsumer:
if self.rank == 0:
cc.init_collective_group(self.num_producers + 1, self.num_producers, group_name="sync_model")
self.buffer = []
for i in range(self.num_producers):
cc.init_collective_group(self.world_size + 1, self.rank + 1, group_name=f"sync_eval_statistics_{i}")
self.buffer = []
self.recv_cnt = 0
def state_dict(self) -> Dict[str, torch.Tensor]:
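Each producer gets a dedicated broadcast group, sync_eval_statistics_{i}, sized as the consumer world size plus one: the producer occupies rank 0 (which is why the receive below uses src=0) and consumer rank r joins as rank r + 1. The sending side is not part of this hunk; the following is a hedged sketch of what the producer counterpart presumably looks like, reusing the same ray_broadcast_tensor_dict helper and group naming. The import path, metric key, and producer_idx argument are assumptions.

```python
import torch

# Import path assumed from the consumer module; the diff above does not show imports.
from coati.distributed.comm import ray_broadcast_tensor_dict


def broadcast_eval_statistics(producer_idx: int, num_correct: float, num_samples: float, device: str = "cuda"):
    """Hypothetical producer-side sender: each metric is packed as a
    [sum, count] tensor so the consumer can form a global weighted mean."""
    eval_statistics = {
        "eval/accuracy": torch.tensor([num_correct, num_samples], device=device),  # illustrative key
    }
    # The producer is rank 0 of its sync_eval_statistics_{i} group,
    # matching src=0 in the consumer's receive call.
    ray_broadcast_tensor_dict(
        eval_statistics, src=0, device=device, group_name=f"sync_eval_statistics_{producer_idx}"
    )
```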
@@ -110,6 +114,24 @@ class BaseConsumer:
with tqdm(range(self.num_update_per_episode), desc=f"Episode {episode}", disable=self.rank != 0) as pbar:
for step in pbar:
i = 0
if self.eval_interval > 0 and step % self.eval_interval == 0:
eval_statistics = None
for r in range(self.num_producers):
print(f"[T{dist.get_rank()}] Recv eval result episode {episode} step {step} from {r}")
local_eval_result = ray_broadcast_tensor_dict(
None, src=0, device=self.device, group_name=f"sync_eval_statistics_{r}"
)
if eval_statistics is None:
eval_statistics = local_eval_result
else:
eval_statistics = {
k: eval_statistics[k] + local_eval_result[k] for k in eval_statistics
}
eval_statistics = {k: (v[0] / v[1]).item() for k, v in eval_statistics.items()}
if dist.get_rank() == 0:
if hasattr(self, "wandb_run") and hasattr(self, "global_step"):
self.wandb_run.log(eval_statistics, step=self.global_step)
print(f"Eval statistics: {eval_statistics}")
for _ in range(self.num_recv_per_update):
# receive data from producers
for r in range(self.num_producers):
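The aggregation above treats every statistic as a [sum, count] pair: adding the tensors elementwise across producers accumulates both entries, and the final v[0] / v[1] converts the pair into a sample-weighted global mean, which rank 0 then logs to wandb at the current global_step. A self-contained illustration of that reduction (metric name and numbers invented for the example):

```python
import torch

# What two producers might report: a [sum, count] tensor per metric.
local_results = [
    {"eval/accuracy": torch.tensor([45.0, 64.0])},  # producer 0: 45 correct out of 64 samples
    {"eval/accuracy": torch.tensor([30.0, 36.0])},  # producer 1: 30 correct out of 36 samples
]

eval_statistics = None
for local_eval_result in local_results:
    if eval_statistics is None:
        eval_statistics = local_eval_result
    else:
        # Elementwise add: sums and counts accumulate independently.
        eval_statistics = {k: eval_statistics[k] + local_eval_result[k] for k in eval_statistics}

# sum / count -> mean weighted by each producer's sample count
eval_statistics = {k: (v[0] / v[1]).item() for k, v in eval_statistics.items()}
print(eval_statistics)  # {'eval/accuracy': 0.75}, i.e. (45 + 30) / (64 + 36)
```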
@@ -195,6 +217,7 @@ class SimpleConsumer(BaseConsumer):
minibatch_size=1,
save_interval: int = 100,
save_dir="./model",
eval_interval: int = -1,
):
super().__init__(
num_producers,
@@ -209,6 +232,9 @@ class SimpleConsumer(BaseConsumer):
model_config,
plugin_config,
minibatch_size,
save_interval,
save_dir,
eval_interval,
)
path = model_config.pop("path")
self.model = AutoModelForCausalLM.from_pretrained(path, **model_config)
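SimpleConsumer only threads the new parameter through to BaseConsumer; because the arguments are forwarded positionally, their order has to match the tail of the base signature (minibatch_size, save_interval, save_dir, eval_interval). Any other consumer subclass needs the same pass-through, otherwise evaluation stays disabled at the default of -1. A stripped-down sketch of the pattern with a hypothetical subclass (most constructor arguments omitted):

```python
class MyConsumer(BaseConsumer):
    """Hypothetical subclass; real consumers take many more constructor arguments."""

    def __init__(self, *args, eval_interval: int = -1, **kwargs):
        # Forward eval_interval so the training loop in BaseConsumer can
        # periodically collect evaluation statistics from the producers.
        super().__init__(*args, eval_interval=eval_interval, **kwargs)
```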