Merge branch 'grpo-latest-rebase-main' of https://github.com/hpcaitech/ColossalAI into grpo-latest-rebase-main

This commit is contained in:
YeAnbang 2025-08-14 19:03:04 +08:00
commit 99ba48fc40
4 changed files with 12 additions and 4 deletions

View File

@@ -21,7 +21,7 @@ jobs:
     container:
       image: image-cloud.luchentech.com/hpcaitech/pytorch-cuda:2.2.2-12.1.0
       options: --gpus all --rm -v /data/scratch/examples-data:/data/scratch/examples-data --shm-size=10.24gb
-    timeout-minutes: 60
+    timeout-minutes: 180
     defaults:
       run:
         shell: bash
@@ -34,6 +34,10 @@ jobs:
           pip install --no-cache-dir -v -e .
       - name: Install ChatGPT
+        env:
+          CFLAGS: "-O1"
+          CXXFLAGS: "-O1"
+          MAX_JOBS: 4
         run: |
           pip install flash-attn --no-build-isolation
           cd applications/ColossalChat

View File

@@ -21,7 +21,7 @@ jobs:
     container:
       image: image-cloud.luchentech.com/hpcaitech/pytorch-cuda:2.2.2-12.1.0
       options: --gpus all --rm -v /data/scratch/examples-data:/data/scratch/examples-data
-    timeout-minutes: 30
+    timeout-minutes: 180
     defaults:
       run:
         shell: bash
@@ -30,6 +30,10 @@ jobs:
         uses: actions/checkout@v2
       - name: Install ChatGPT
+        env:
+          CFLAGS: "-O1"
+          CXXFLAGS: "-O1"
+          MAX_JOBS: 4
         run: |
           pip install flash-attn --no-build-isolation
           cd applications/ColossalChat

View File

@@ -530,4 +530,4 @@ class GRPOConsumer(BaseConsumer):
         model = self.policy_model.unwrap()
         state_dict = model.state_dict()
         state_dict["consumer_global_step"] = torch.tensor([self.global_step], device=self.device)
         return state_dict

View File

@@ -273,7 +273,7 @@ class Qwen3PipelineForwards:
         hidden_states: Optional[torch.FloatTensor] = None,
         stage_index: Optional[List[int]] = None,
         shard_config: ShardConfig = None,
-        **kwargs
+        **kwargs,
     ):
         r"""
         Args: