mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-07 12:01:39 +00:00
[fix] revert reward update and evaluation (#6295)
* Revert "rewrite reward fn" This reverts commitd06042b434
. * Revert "upgrade reward math verification" This reverts commita6085ff676
. * Revert "fix bug" This reverts commit01640ebd65
. * Revert "reuse comm-group" This reverts commitbd61918dcf
. * Revert "Support evaluation during training" This reverts commit57a88395fe
.
This commit is contained in:
@@ -36,7 +36,6 @@ class BaseConsumer:
|
||||
minibatch_size: int = 1,
|
||||
save_interval: int = 100,
|
||||
save_dir: str = "./model",
|
||||
eval_interval: int = -1,
|
||||
):
|
||||
self.num_producers = num_producers
|
||||
self.num_episodes = num_episodes
|
||||
@@ -52,7 +51,6 @@ class BaseConsumer:
|
||||
self.save_dir = save_dir
|
||||
assert batch_size % minibatch_size == 0, "batch_size should be divisible by microbatch_size"
|
||||
self.num_microbatches = batch_size // minibatch_size
|
||||
self.eval_interval = eval_interval
|
||||
|
||||
self.model_config = model_config
|
||||
self.plugin_config = plugin_config
|
||||
@@ -95,6 +93,7 @@ class BaseConsumer:
|
||||
cc.init_collective_group(self.num_producers + 1, self.num_producers, group_name="sync_model")
|
||||
|
||||
self.buffer = []
|
||||
|
||||
self.recv_cnt = 0
|
||||
|
||||
def state_dict(self) -> Dict[str, torch.Tensor]:
|
||||
@@ -111,27 +110,6 @@ class BaseConsumer:
|
||||
with tqdm(range(self.num_update_per_episode), desc=f"Episode {episode}", disable=self.rank != 0) as pbar:
|
||||
for step in pbar:
|
||||
i = 0
|
||||
if self.eval_interval > 0 and step % self.eval_interval == 0:
|
||||
eval_statistics = None
|
||||
eval_global_step = None
|
||||
for r in range(self.num_producers):
|
||||
print(f"[T{dist.get_rank()}] Recv eval result episode {episode} step {step} from {r}")
|
||||
local_eval_result = ray_broadcast_tensor_dict(
|
||||
None, src=0, device=self.device, group_name=f"sync_data_{r}"
|
||||
)
|
||||
assert "consumer_global_step" in local_eval_result
|
||||
eval_global_step = local_eval_result.pop("consumer_global_step").item()
|
||||
if eval_statistics is None:
|
||||
eval_statistics = local_eval_result
|
||||
else:
|
||||
eval_statistics = {
|
||||
k: eval_statistics[k] + local_eval_result[k] for k in eval_statistics
|
||||
}
|
||||
eval_statistics = {"eval/" + k: (v[0] / v[1]).item() for k, v in eval_statistics.items()}
|
||||
if dist.get_rank() == 0:
|
||||
if hasattr(self, "wandb_run"):
|
||||
self.wandb_run.log(eval_statistics, step=eval_global_step)
|
||||
print(f"Eval statistics: {eval_statistics}")
|
||||
for _ in range(self.num_recv_per_update):
|
||||
# receive data from producers
|
||||
for r in range(self.num_producers):
|
||||
@@ -217,7 +195,6 @@ class SimpleConsumer(BaseConsumer):
|
||||
minibatch_size=1,
|
||||
save_interval: int = 100,
|
||||
save_dir="./model",
|
||||
eval_interval: int = -1,
|
||||
):
|
||||
super().__init__(
|
||||
num_producers,
|
||||
@@ -232,9 +209,6 @@ class SimpleConsumer(BaseConsumer):
|
||||
model_config,
|
||||
plugin_config,
|
||||
minibatch_size,
|
||||
save_interval,
|
||||
save_dir,
|
||||
eval_interval,
|
||||
)
|
||||
path = model_config.pop("path")
|
||||
self.model = AutoModelForCausalLM.from_pretrained(path, **model_config)
|
||||
|
Reference in New Issue
Block a user