Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-04-28 03:43:01 +00:00)
Compare commits
10 Commits
Author | SHA1 | Date
---|---|---
 | 46ed5d856b |
 | 7ecdf9a211 |
 | 44d4053fec |
 | 6d676ee0e9 |
 | 56fe130b15 |
 | f32861ccc5 |
 | b9e60559b8 |
 | 7595c453a5 |
 | 53834b74b9 |
 | 0171884664 |
@@ -1,3 +1,3 @@
-2.2.2-12.1.0
 2.3.0-12.1.0
 2.4.0-12.4.1
+2.5.1-12.4.1
@@ -1,11 +1,11 @@
 {
   "build": [
     {
-      "torch_command": "pip install torch==2.1.0 torchvision==0.16.0 torchaudio==2.1.0 --index-url https://download.pytorch.org/whl/cu121",
+      "torch_command": "pip install torch==2.3.0 torchvision==0.18.0 torchaudio==2.3.0 --index-url https://download.pytorch.org/whl/cu121",
       "cuda_image": "hpcaitech/cuda-conda:12.1"
     },
     {
-      "torch_command": "pip install torch==2.4.0 torchvision==0.19.0 torchaudio==2.4.0 --index-url https://download.pytorch.org/whl/cu124",
+      "torch_command": "pip install torch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 --index-url https://download.pytorch.org/whl/cu124",
       "cuda_image": "hpcaitech/cuda-conda:12.4"
     }
   ]
.github/workflows/build_on_pr.yml (vendored, 6 changed lines)

@@ -87,10 +87,10 @@ jobs:
     name: Build and Test Colossal-AI
     needs: detect
     if: needs.detect.outputs.anyLibraryFileChanged == 'true'
-    runs-on: [self-hosted, gpu]
+    runs-on: ubuntu-latest
     container:
-      image: hpcaitech/pytorch-cuda:2.2.2-12.1.0
-      options: --gpus all --rm -v /dev/shm -v /data/scratch:/data/scratch
+      image: image-cloud.luchentech.com/hpcaitech/pytorch-cuda:2.2.2-12.1.0
+      options: --gpus all --shm-size=2g --rm -v /dev/shm -v /data/scratch:/data/scratch
     timeout-minutes: 90
     defaults:
       run:
@@ -38,7 +38,7 @@ Limited Academic Bonuses:

 <div align="center">
   <a href="https://hpc-ai.com/?utm_source=github&utm_medium=social&utm_campaign=promotion-colossalai">
-    <img src="https://github.com/hpcaitech/public_assets/blob/main/colossalai/img/2.gif" width="850" />
+    <img src="https://github.com/hpcaitech/public_assets/blob/main/colossalai/img/2-2.gif" width="850" />
   </a>
 </div>

@@ -140,7 +140,7 @@ class NaiveExperienceMaker(ExperienceMaker):
         num_actions = 0

         for inference_mini_batch_id in range(0, input_ids.size(0), self.inference_batch_size):
-            s, e = inference_mini_batch_id, (inference_mini_batch_id + 1) * self.inference_batch_size
+            s, e = inference_mini_batch_id, inference_mini_batch_id + self.inference_batch_size
             if input_ids[s:e].size(0) == 0:
                 break
             sequences = generate(self.actor, input_ids[s:e], self.tokenizer, **generate_kwargs)
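The hunk above changes the end index of each inference mini-batch from `(start + 1) * batch_size` to `start + batch_size`; since the loop variable already advances in steps of `inference_batch_size`, the old formula produced slice bounds that drift past the intended mini-batch. A standalone sketch with made-up sizes:

```python
# Standalone sketch of the slice-bound fix above (illustrative sizes only).
batch_size = 4
num_samples = 10

old_slices = [(s, (s + 1) * batch_size) for s in range(0, num_samples, batch_size)]
new_slices = [(s, s + batch_size) for s in range(0, num_samples, batch_size)]

print(old_slices)  # [(0, 4), (4, 20), (8, 36)] -- end indices grow past the data
print(new_slices)  # [(0, 4), (4, 8), (8, 12)]  -- contiguous, batch-sized slices
```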
@@ -380,8 +380,8 @@ class DPOTrainer(SLTrainer):
                        self.accumulative_meter.get("accuracy"),
                        global_step,
                    )
-                    self.num_train_step += 1
                self.accumulative_meter.reset()
+                self.num_train_step += 1

            if self.save_dir is not None and self.num_train_step > 0 and self.num_train_step % self.save_interval == 0:
                # save checkpoint
@@ -231,7 +231,6 @@ class GRPOTrainer(OLTrainer):
            experience:
                sequences: [batch_size, prompt_length + response_length] --- <PAD>...<PAD><PROMPT>...<PROMPT><RESPONSE>...<RESPONSE><PAD>...<PAD>
        """
-        self.num_train_step += 1
        self.actor.train()
        num_actions = experience.action_log_probs.size(1)
        # policy loss

@@ -294,7 +293,7 @@ class GRPOTrainer(OLTrainer):
            self.temperature_annealing_scheduler.step_forward()

        # preparing logging model output and corresponding rewards.
-        if self.num_train_step % 10 == 1:
+        if self.num_train_step % 10 == 0:
            response_text = self.experience_maker.tokenizer.batch_decode(
                experience.sequences, skip_special_tokens=True
            )

@@ -327,6 +326,7 @@ class GRPOTrainer(OLTrainer):
            self.writer.add_scalar("approx_kl", self.accumulative_meter.get("kl"), global_step)
            self.writer.add_scalar("advantages", self.accumulative_meter.get("advantages"), global_step)
        self.accumulative_meter.reset()
+        self.num_train_step += 1

    def _learn(self, update_step: int):
        """
@@ -256,7 +256,7 @@ class KTOTrainer(SLTrainer):
                    self.coordinator.print_on_master(
                        f"Saved checkpoint at epoch {epoch} step {self.save_interval} at folder {self.save_dir}"
                    )
-                    self.num_train_step += 1
+                self.num_train_step += 1

        step_bar.close()


@@ -233,7 +233,7 @@ class ORPOTrainer(SLTrainer):
                    self.coordinator.print_on_master(
                        f"Saved checkpoint at epoch {epoch} step {self.save_interval} at folder {self.save_dir}"
                    )
-                    self.num_train_step += 1
+                self.num_train_step += 1

        step_bar.close()

@@ -220,7 +220,6 @@ class PPOTrainer(OLTrainer):
            experience:
                sequences: [batch_size, prompt_length + response_length] --- <PAD>...<PAD><PROMPT>...<PROMPT><RESPONSE>...<RESPONSE><PAD>...<PAD>
        """
-        self.num_train_step += 1
        self.actor.train()
        self.critic.train()
        num_actions = experience.action_log_probs.size(1)

@@ -294,7 +293,7 @@ class PPOTrainer(OLTrainer):
            self.critic_scheduler.step()

        # preparing logging model output and corresponding rewards.
-        if self.num_train_step % 10 == 1:
+        if self.num_train_step % 10 == 0:
            response_text = self.experience_maker.tokenizer.batch_decode(
                experience.sequences, skip_special_tokens=True
            )

@@ -336,6 +335,7 @@ class PPOTrainer(OLTrainer):
            self.writer.add_scalar("value", self.accumulative_meter.get("value"), self.num_train_step)
            self.writer.add_scalar("advantages", self.accumulative_meter.get("advantages"), self.num_train_step)
        self.accumulative_meter.reset()
+        self.num_train_step += 1

    def _learn(self, update_step: int):
        """
@@ -193,7 +193,7 @@ class RewardModelTrainer(SLTrainer):
                    self.coordinator.print_on_master(
                        f"Saved checkpoint at epoch {epoch} step {(i + 1)/self.accumulation_steps} at folder {self.save_dir}"
                    )
-                    self.num_train_step += 1
+                self.num_train_step += 1
        step_bar.close()

    def _eval(self, epoch):
@@ -152,9 +152,9 @@ class SFTTrainer(SLTrainer):
                if self.writer:
                    self.writer.add_scalar("train/loss", self.accumulative_meter.get("loss"), global_step)
                    self.writer.add_scalar("train/lr", self.scheduler.get_last_lr()[0], global_step)
-                    self.num_train_step += 1
                self.accumulative_meter.reset()
                step_bar.update()
+                self.num_train_step += 1

            # Save checkpoint
            if (
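Taken together, the trainer hunks above move `self.num_train_step += 1` out of conditional logging/saving branches and to the end of each optimizer step (in several hunks the removed increment sits inside the writer block). A toy loop with made-up values, illustrating why a counter that only advances inside such a branch can starve a save check keyed on it:

```python
# Toy illustration of the step-counting fix in the trainer hunks above.
# Values and the "writer" stand-in are made up for the example.
save_interval, total_steps, writer = 4, 8, None  # e.g. no TensorBoard writer configured

def count_saves(increment_only_when_logging: bool) -> int:
    num_train_step, saves = 0, 0
    for _ in range(total_steps):
        if increment_only_when_logging:
            if writer is not None:       # old placement: tied to logging
                num_train_step += 1
        else:
            num_train_step += 1          # new placement: every optimizer step
        if num_train_step > 0 and num_train_step % save_interval == 0:
            saves += 1
    return saves

print(count_saves(True), count_saves(False))  # 0 2 -- no checkpoints ever fire before the fix
```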
@@ -892,6 +892,63 @@ The dialogues can be multiple turns and it can contain system prompt. For more d

 We use bf16 weights for finetuning. If you downloaded fp8 DeepSeek V3/R1 weights, you can use the [script](https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/fp8_cast_bf16.py) to convert the weights to bf16 via GPU. For Ascend NPU, you can use this [script](https://gitee.com/ascend/ModelZoo-PyTorch/blob/master/MindIE/LLM/DeepSeek/DeepSeek-V2/NPU_inference/fp8_cast_bf16.py).

+We have also added details on how to load and run inference with LoRA models.
+```python
+from transformers import (
+    AutoModelForCausalLM,
+    AutoTokenizer,
+)
+from peft import (
+    PeftModel
+)
+import torch
+
+# Set model paths
+model_name = "Qwen/Qwen2.5-3B"
+lora_adapter = "Qwen2.5-3B_lora"  # your LoRA adapter path
+merged_model_path = "Qwen2.5-3B_merged"
+
+######
+# How to load a LoRA model
+######
+# 1. Load the base model
+base_model = AutoModelForCausalLM.from_pretrained(
+    model_name,
+    torch_dtype=torch.bfloat16,
+    device_map="auto",
+    trust_remote_code=True
+)
+
+# 2. Load the LoRA adapter
+peft_model = PeftModel.from_pretrained(
+    base_model,
+    lora_adapter,
+    torch_dtype=torch.bfloat16
+)
+
+# 3. Merge the LoRA weights into the base model
+merged_model = peft_model.merge_and_unload()
+
+# 4. Load the tokenizer
+tokenizer = AutoTokenizer.from_pretrained(
+    model_name,
+    trust_remote_code=True,
+    pad_token="<|endoftext|>"
+)
+
+# 5. Save the merged model
+merged_model.save_pretrained(
+    merged_model_path,
+    safe_serialization=True
+)
+tokenizer.save_pretrained(merged_model_path)
+
+# 6. Run inference
+test_input = tokenizer("Instruction: Finding prime numbers up to 100\nAnswer:", return_tensors="pt").to("cuda")
+output = merged_model.generate(**test_input, max_new_tokens=100)
+print(tokenizer.decode(output[0], skip_special_tokens=True))
+```
+
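An aside on the merged checkpoint produced above: once saved, it can be reloaded directly with transformers, without peft. A minimal sketch reusing the placeholder paths from the example:

```python
# Reload the merged checkpoint on its own (no peft needed); the path is the
# placeholder "Qwen2.5-3B_merged" from the example above.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

merged_model_path = "Qwen2.5-3B_merged"
model = AutoModelForCausalLM.from_pretrained(
    merged_model_path,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(merged_model_path)

prompt = tokenizer("Instruction: Finding prime numbers up to 100\nAnswer:", return_tensors="pt").to(model.device)
output = model.generate(**prompt, max_new_tokens=100)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```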
 #### Usage

 After preparing the dataset and model weights, you can run the script with the following command:
@@ -257,7 +257,7 @@ def train(args) -> None:
    )

    torch.set_default_dtype(torch.float)
-    booster.load_model(model, args.pretrained)
+    booster.load_model(model, args.pretrained, low_cpu_mem_mode=False, num_threads=8)

    coordinator.print_on_master(
        f"Booster init max device memory: {accelerator.max_memory_allocated() / 1024 ** 2:.2f} MB"
@@ -85,11 +85,11 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
        if use_async:
            from colossalai.utils.safetensors import save

-            if id(model) not in self.pinned_state_dicts:
-                self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict)
+            if hash(model) not in self.pinned_state_dicts:
+                self.pinned_state_dicts[hash(model)] = create_pinned_state_dict(state_dict)
            for k, v in state_dict.items():
-                self.pinned_state_dicts[id(model)][k].copy_(v)
-                state_dict[k] = self.pinned_state_dicts[id(model)][k]
+                self.pinned_state_dicts[hash(model)][k].copy_(v)
+                state_dict[k] = self.pinned_state_dicts[hash(model)][k]
            writer = save(checkpoint, state_dict)
            self.async_writers.append(writer)
        else:

@@ -172,9 +172,9 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
        Path(checkpoint_path).mkdir(parents=True, exist_ok=True)

        if use_async and self.coordinator.is_master():
-            if id(model) not in self.pinned_state_dicts:
-                self.pinned_state_dicts[id(model)] = {}
-            pinned_state_dicts = self.pinned_state_dicts[id(model)]
+            if hash(model) not in self.pinned_state_dicts:
+                self.pinned_state_dicts[hash(model)] = {}
+            pinned_state_dicts = self.pinned_state_dicts[hash(model)]
        else:
            pinned_state_dicts = None
        state_dict_shard = model.state_dict_shard(
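Throughout the checkpoint IO hunks here and below, the cache of pinned state dicts is keyed by `hash(model)` instead of `id(model)`. The new `PeftUnwrapMixin` (see the `colossalai/interface/model.py` diff further down) defines `__hash__` as `hash(self.base_model)`, so a freshly created unwrap wrapper can map to the same cache slot as the model it wraps, which an `id()`-based key cannot do. A minimal standalone sketch of that behaviour; the class names are illustrative, not ColossalAI APIs:

```python
# Why hash(model) works as a cache key across wrappers while id(model) does not.
# Wrapper/Inner are illustrative stand-ins; only the __hash__ delegation mirrors
# PeftUnwrapMixin in the interface/model.py diff further down.
class Inner:
    pass

class Wrapper:
    def __init__(self, inner):
        self.inner = inner

    def __hash__(self):
        return hash(self.inner)  # delegate hashing to the wrapped object

inner = Inner()

pinned_by_id = {id(inner): "pinned buffers"}
pinned_by_hash = {hash(inner): "pinned buffers"}

# Each save call may build a brand-new wrapper around the same model:
print(id(Wrapper(inner)) in pinned_by_id)      # False -> cache miss, buffers re-pinned
print(hash(Wrapper(inner)) in pinned_by_hash)  # True  -> cache hit, buffers reused
```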
@@ -26,6 +26,7 @@ from colossalai.amp.naive_amp.mixed_precision_optimizer import MixedPrecisionOpt
 from colossalai.checkpoint_io import CheckpointIO, HybridParallelCheckpointIO
 from colossalai.cluster import ProcessGroupMesh
 from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
+from colossalai.interface.model import PeftUnwrapMixin
 from colossalai.interface.optimizer import DistributedOptim
 from colossalai.logging import get_dist_logger
 from colossalai.nn.optimizer import DistGaloreAwamW, cast_to_distributed

@@ -225,7 +226,7 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin):
        if isinstance(model, DDP):
            model = model.module
        if unwrap_peft and isinstance(model, PeftModel):
-            model = model.get_base_model()
+            model = PeftUnwrapMixin(model)
        return model

    def _force_wait_all_gather(self):

@@ -12,6 +12,7 @@ from torch.utils.data import DataLoader
 from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO
 from colossalai.cluster import DistCoordinator
 from colossalai.interface import ModelWrapper, OptimizerWrapper
+from colossalai.interface.model import PeftUnwrapMixin
 from colossalai.logging import get_dist_logger
 from colossalai.quantization import BnbQuantizationConfig, quantize_model
 from colossalai.utils import get_current_device

@@ -201,7 +202,7 @@ class TorchDDPModel(ModelWrapper):
    def unwrap(self, unwrap_peft: bool = True) -> nn.Module:
        model = self.module.module
        if unwrap_peft and isinstance(model, PeftModel):
-            model = model.get_base_model()
+            model = PeftUnwrapMixin(model)
        return model

@@ -103,11 +103,11 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
        if use_async:
            from colossalai.utils.safetensors import save

-            if id(model) not in self.pinned_state_dicts:
-                self.pinned_state_dicts[id(model)] = create_pinned_state_dict(full_model_state)
+            if hash(model) not in self.pinned_state_dicts:
+                self.pinned_state_dicts[hash(model)] = create_pinned_state_dict(full_model_state)
            for k, v in full_model_state.items():
-                self.pinned_state_dicts[id(model)][k].copy_(v)
-                full_model_state[k] = self.pinned_state_dicts[id(model)][k]
+                self.pinned_state_dicts[hash(model)][k].copy_(v)
+                full_model_state[k] = self.pinned_state_dicts[hash(model)][k]
            writer = save(checkpoint, full_model_state)
            self.async_writers.append(writer)
        else:

@@ -186,9 +186,9 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
        state_dict = model.unwrap().state_dict()

        if use_async and self.coordinator.is_master():
-            if id(model) not in self.pinned_state_dicts:
-                self.pinned_state_dicts[id(model)] = {}
-            pinned_state_dicts = self.pinned_state_dicts[id(model)]
+            if hash(model) not in self.pinned_state_dicts:
+                self.pinned_state_dicts[hash(model)] = {}
+            pinned_state_dicts = self.pinned_state_dicts[hash(model)]
        else:
            pinned_state_dicts = None
        state_dict_shard = utils.shard_model_checkpoint(
@@ -60,9 +60,9 @@ class GeneralCheckpointIO(CheckpointIO):
        if use_async:
            from colossalai.utils.safetensors import move_and_save

-            if id(model) not in self.pinned_state_dicts:
-                self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict)
-            writer = move_and_save(checkpoint, state_dict, self.pinned_state_dicts[id(model)])
+            if hash(model) not in self.pinned_state_dicts:
+                self.pinned_state_dicts[hash(model)] = create_pinned_state_dict(state_dict)
+            writer = move_and_save(checkpoint, state_dict, self.pinned_state_dicts[hash(model)])
            self.async_writers.append(writer)
        else:
            # save the checkpoint

@@ -234,7 +234,7 @@ class GeneralCheckpointIO(CheckpointIO):
        index_file = CheckpointIndexFile(checkpoint_path)

        if use_async:
-            pinned_state_dict = self.pinned_state_dicts.get(id(model), None)
+            pinned_state_dict = self.pinned_state_dicts.get(hash(model), None)
            total_size, new_pinned_state_dict, writers = async_move_save_state_dict_shards(
                sharded_state_dict=state_dict_shard,
                checkpoint=checkpoint_path,

@@ -243,7 +243,7 @@ class GeneralCheckpointIO(CheckpointIO):
                is_master=True,
                pinned_state_dict=pinned_state_dict,
            )
-            self.pinned_state_dicts[id(model)] = new_pinned_state_dict
+            self.pinned_state_dicts[hash(model)] = new_pinned_state_dict
            self.async_writers.extend(writers)
        else:
            # Save shards of optimizer states.
@@ -249,9 +249,9 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
        # Only devices with tp_rank == 0 are responsible for model saving.
        control_saving = self.tp_rank == 0 and self.sp_rank == 0
        if control_saving and use_async:
-            if id(model) not in self.pinned_state_dicts:
-                self.pinned_state_dicts[id(model)] = {}
-            pinned_state_dicts = self.pinned_state_dicts[id(model)]
+            if hash(model) not in self.pinned_state_dicts:
+                self.pinned_state_dicts[hash(model)] = {}
+            pinned_state_dicts = self.pinned_state_dicts[hash(model)]
        else:
            pinned_state_dicts = None
        state_dict_shard = HybridParallelCheckpointIO._model_sharder(

@@ -789,11 +789,11 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
        if use_async:
            from colossalai.utils.safetensors import save

-            if id(model) not in self.pinned_state_dicts:
-                self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict)
+            if hash(model) not in self.pinned_state_dicts:
+                self.pinned_state_dicts[hash(model)] = create_pinned_state_dict(state_dict)
            for name, param in state_dict.items():
-                self.pinned_state_dicts[id(model)][name].copy_(param)
-                state_dict[name] = self.pinned_state_dicts[id(model)][name]
+                self.pinned_state_dicts[hash(model)][name].copy_(param)
+                state_dict[name] = self.pinned_state_dicts[hash(model)][name]
            writer = save(path=checkpoint, state_dict=state_dict)
            self.async_writers.append(writer)
        else:

@@ -811,11 +811,11 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
        if use_async:
            from colossalai.utils.safetensors import save

-            if id(model) not in self.pinned_state_dicts:
-                self.pinned_state_dicts[id(model)] = create_pinned_state_dict(complete_state_dict)
+            if hash(model) not in self.pinned_state_dicts:
+                self.pinned_state_dicts[hash(model)] = create_pinned_state_dict(complete_state_dict)
            for name, param in complete_state_dict.items():
-                self.pinned_state_dicts[id(model)][name].copy_(param)
-                complete_state_dict[name] = self.pinned_state_dicts[id(model)][name]
+                self.pinned_state_dicts[hash(model)][name].copy_(param)
+                complete_state_dict[name] = self.pinned_state_dicts[hash(model)][name]
            writer = save(path=checkpoint, state_dict=complete_state_dict)
            self.async_writers.append(writer)
        else:
@@ -701,15 +701,18 @@ class MoECheckpointIO(HybridParallelCheckpointIO):
                    all_param = None
                    # gather param from every ep rank
                    # dist.all_gather(all_param, param, group=ep_group)
-                    dist.gather(param, all_param, group=ep_group)
+                    dist.gather(param, all_param, dst=dist.get_global_rank(ep_group, 0), group=ep_group)
                    if ep_rank == 0:
                        all_param = torch.cat(all_param, dim=0)
                        state_dict[name] = all_param.cpu()

        if self.pp_size > 1:
            if self.dp_rank == 0:
-                out = [None for _ in range(self.pp_size)]
-                dist.gather_object(state_dict, out, group=self.pp_group)
+                if self.pp_rank == 0:
+                    out = [None for _ in range(self.pp_size)]
+                else:
+                    out = None
+                dist.gather_object(state_dict, out, dst=dist.get_global_rank(self.pp_group, 0), group=self.pp_group)
                if self.pp_rank == 0:
                    new_state_dict = {}
                    for o in out:
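The MoE hunk above passes an explicit `dst` to `dist.gather` / `dist.gather_object`; the destination must be a global rank, so the group-local root 0 is translated with `dist.get_global_rank(group, 0)`, and only the root supplies a gather list. A minimal runnable sketch of the same pattern on CPU with the gloo backend; the 2-process setup and port are illustrative, not the ColossalAI EP topology:

```python
# Gather into the root of a process subgroup: dst must be the *global* rank
# of the group's root, hence dist.get_global_rank(group, 0).
import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def worker(rank, world_size):
    dist.init_process_group("gloo", init_method="tcp://127.0.0.1:29511",
                            rank=rank, world_size=world_size)
    group = dist.new_group(list(range(world_size)))
    param = torch.tensor([float(rank)])
    # Only the destination rank provides a gather list; others pass None.
    gather_list = [torch.zeros(1) for _ in range(world_size)] if rank == 0 else None
    dist.gather(param, gather_list, dst=dist.get_global_rank(group, 0), group=group)
    if rank == 0:
        print(torch.cat(gather_list))  # tensor([0., 1.])
    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(worker, args=(2,), nprocs=2)
```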
@@ -20,6 +20,7 @@ from torch.optim import Optimizer
 from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten

 from colossalai.accelerator import get_accelerator
+from colossalai.interface.model import PeftUnwrapMixin
 from colossalai.tensor.d_tensor import (
     is_customized_distributed_tensor,
     is_distributed_tensor,

@@ -554,6 +555,8 @@ def save_config_file(model: nn.Module, checkpoint_path: str, is_master: bool = T
        from transformers.modeling_utils import unwrap_model as unwrap_huggingface_model
    except ImportError:
        return
+    if isinstance(model, PeftUnwrapMixin):
+        model = model.base_model
    if not isinstance(model, PreTrainedModel):
        return

@@ -692,6 +695,9 @@ def load_state_dict_into_model(
        state_dict (dict): a dict containing parameters and
            persistent buffers.
    """
+    if isinstance(model, PeftUnwrapMixin):
+        state_dict = model.patch_state_dict(state_dict)
+        model = model.base_model
    if not isinstance(state_dict, Mapping):
        raise TypeError("Expected state_dict to be dict-like, got {}.".format(type(state_dict)))
@@ -1,5 +1,102 @@
+import re
+from typing import Dict, Set
+
+import torch
 import torch.nn as nn
-from peft import PeftModel
+from peft import PeftModel, PeftType
+
+
+def extract_lora_layers(model: PeftModel, names: Set[str], adapter_name: str = "default"):
+    config = model.peft_config[adapter_name]
+    if config.peft_type != PeftType.LORA:
+        raise ValueError(f"Adapter {adapter_name} is not a LORA adapter.")
+    # to_return = lora_state_dict(model, bias=model.peft_config.bias)
+    # adapted from `https://github.com/microsoft/LoRA/blob/main/loralib/utils.py`
+    # to be used directly with the state dict which is necessary when using DeepSpeed or FSDP
+    bias = config.bias
+    if bias == "none":
+        to_return = {k for k in names if "lora_" in k}
+    elif bias == "all":
+        to_return = {k for k in names if "lora_" in k or "bias" in k}
+    elif bias == "lora_only":
+        to_return = set()
+        for k in names:
+            if "lora_" in k:
+                to_return.add(k)
+                bias_name = k.split("lora_")[0] + "bias"
+                if bias_name in names:
+                    to_return.add(bias_name)
+    else:
+        raise NotImplementedError
+    to_return = {k for k in to_return if (("lora_" in k and adapter_name in k) or ("bias" in k))}
+    if config.use_dora:
+        # Here we take care of a refactor of DoRA which changed lora_magnitude_vector from a ParameterDict to a
+        # ModuleDict with a DoraLayer instance. The old parameter is now the "weight" attribute of that layer. Since
+        # we want the state_dict format not to change, we remove the "weight" part.
+        new_dora_suffix = f"lora_magnitude_vector.{adapter_name}.weight"
+
+        def renamed_dora_weights(k):
+            if k.endswith(new_dora_suffix):
+                k = k[:-7]  # remove ".weight"
+            return k
+
+        to_return = {renamed_dora_weights(k) for k in to_return}
+
+    to_return = {re.sub(f"lora_\S\.{adapter_name}\.(weight|bias)", "base_layer", k) for k in to_return}
+    return to_return
+
+
+class PeftUnwrapMixin:
+    def __init__(self, peft_model: PeftModel):
+        self.base_model = peft_model.get_base_model()
+        # peft does not affect buffers
+        self.lora_layers = extract_lora_layers(peft_model, set(n for n, p in self.base_model.named_parameters()))
+        potential_lora_weights = set()
+        for n in self.lora_layers:
+            potential_lora_weights.add(f"{n}.weight")
+            potential_lora_weights.add(f"{n}.bias")
+        self.lora_param_to_origin_param = {n: n.replace("base_layer.", "") for n in potential_lora_weights}
+        self.origin_param_to_lora_param = {v: k for k, v in self.lora_param_to_origin_param.items()}
+
+    def named_parameters(self):
+        for n, p in self.base_model.named_parameters():
+            if n in self.lora_param_to_origin_param:
+                n = self.lora_param_to_origin_param[n]
+            yield n, p
+
+    def named_buffers(self):
+        return self.base_model.named_buffers()
+
+    @property
+    def _modules(self):
+        return self.base_model._modules
+
+    @property
+    def _non_persistent_buffers_set(self):
+        return self.base_model._non_persistent_buffers_set
+
+    def patch_state_dict(self, state_dict: Dict[str, torch.Tensor]):
+        new_state_dict = {}
+        for k, v in state_dict.items():
+            if k in self.origin_param_to_lora_param:
+                k = self.origin_param_to_lora_param[k]
+            new_state_dict[k] = v
+        return new_state_dict
+
+    def state_dict(self):
+        state_dict = {}
+        for k, v in self.base_model.state_dict().items():
+            if k in self.lora_param_to_origin_param:
+                k = self.lora_param_to_origin_param[k]
+            state_dict[k] = v
+        return state_dict
+
+    def load_state_dict(self, state_dict, strict: bool = True, assign: bool = False):
+        state_dict = self.patch_state_dict(state_dict)
+        self.base_model.load_state_dict(state_dict, strict=strict, assign=assign)
+
+    def __hash__(self):
+        return hash(self.base_model)
+
+
 class ModelWrapper(nn.Module):

@@ -23,7 +120,7 @@ class ModelWrapper(nn.Module):
        else:
            model = self.module
        if unwrap_peft and isinstance(model, PeftModel):
-            model = model.get_base_model()
+            model = PeftUnwrapMixin(model)
        return model

    def forward(self, *args, **kwargs):
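To make the renaming done by `extract_lora_layers` and `PeftUnwrapMixin` above concrete, a standalone sketch using hand-written parameter names in the usual peft LoRA layout (the names are examples, not read from a real model):

```python
# Standalone sketch of the name mapping used by PeftUnwrapMixin above.
import re

adapter_name = "default"
lora_keys = {
    "model.layers.0.self_attn.q_proj.lora_A.default.weight",
    "model.layers.0.self_attn.q_proj.lora_B.default.weight",
}

# extract_lora_layers() collapses each LoRA key to the wrapped module's
# "base_layer" entry:
layers = {re.sub(rf"lora_\S\.{adapter_name}\.(weight|bias)", "base_layer", k) for k in lora_keys}
print(layers)  # {'model.layers.0.self_attn.q_proj.base_layer'}

# The mixin then maps "<layer>.base_layer.weight" back to the name the plain
# (non-peft) model would use, so checkpoints keep their original keys:
lora_param = "model.layers.0.self_attn.q_proj.base_layer.weight"
print(lora_param.replace("base_layer.", ""))  # model.layers.0.self_attn.q_proj.weight
```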
@@ -8,7 +8,7 @@ click
 fabric
 contexttimer
 ninja
-torch>=2.2.0,<=2.4.1
+torch>=2.2.0,<=2.5.1
 safetensors
 einops
 pydantic
@@ -1,7 +1,7 @@
 from colossalai.cluster.device_mesh_manager import DeviceMeshInfo, DeviceMeshManager
 from colossalai.initialize import launch
 from colossalai.logging import disable_existing_loggers
-from colossalai.testing import spawn
+from colossalai.testing import rerun_if_address_is_in_use, spawn


 def check_device_mesh_manager(rank, world_size, port):

@@ -24,6 +24,7 @@ def check_device_mesh_manager(rank, world_size, port):
    assert device_mesh_with_shape._logical_mesh_id.tolist() == [[0, 1], [2, 3]]


+@rerun_if_address_is_in_use()
 def test_device_mesh_manager():
    spawn(check_device_mesh_manager, 4)

@@ -6,9 +6,10 @@ from torch.testing import assert_close
 from colossalai import launch
 from colossalai.accelerator import get_accelerator
 from colossalai.quantization.fp8 import all_to_all_single_fp8
-from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
+from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn


+@clear_cache_before_run()
 @parameterize("shape", [(4,), (1, 8, 16), (4, 8, 16)])
 @parameterize("dtype", [torch.bfloat16, torch.float16])
 @parameterize("async_op", [True, False])

@@ -24,6 +25,7 @@ def check_all2all(shape, dtype, async_op):
    assert_close(output, output_fp8, rtol=0.1, atol=0.1)


+@clear_cache_before_run()
 @parameterize("shape", [(8, 8, 16)])
 @parameterize("dtype", [torch.bfloat16, torch.float16])
 @parameterize("async_op", [True, False])

@@ -6,9 +6,10 @@ from torch.testing import assert_close
 from colossalai import launch
 from colossalai.accelerator import get_accelerator
 from colossalai.quantization.fp8 import _all_to_all_fp8
-from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
+from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn


+@clear_cache_before_run()
 @parameterize("shape", [(16, 8, 4)])
 @parameterize("scatter_dim", [0, 1, 2])
 @parameterize("dtype", [torch.bfloat16, torch.float16])

@@ -6,11 +6,12 @@ from torch.testing import assert_close
 from colossalai import launch
 from colossalai.accelerator import get_accelerator
 from colossalai.quantization.fp8 import all_to_all_single_fp8
-from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
+from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn

 dist.all_to_all_single


+@clear_cache_before_run()
 @parameterize("shape", [(4), (8, 7), (4, 8, 16)])
 @parameterize("dtype", [torch.bfloat16, torch.float16])
 @parameterize("fp8_format", ["e4m3", "e5m2"])

@@ -6,9 +6,10 @@ from torch.testing import assert_close
 from colossalai import launch
 from colossalai.accelerator import get_accelerator
 from colossalai.quantization.fp8 import _all_gather_fp8
-from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
+from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn


+@clear_cache_before_run()
 @parameterize(
     "shape",
     [(3, 7, 16)],

@@ -5,7 +5,7 @@ from torch.testing import assert_close
 from colossalai import launch
 from colossalai.accelerator import get_accelerator
 from colossalai.quantization.fp8 import all_reduce_fp8
-from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
+from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn


 @parameterize(

@@ -20,6 +20,7 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
         (8,),
     ],
 )
+@clear_cache_before_run()
 @parameterize("dtype", [torch.float16, torch.bfloat16])
 @parameterize("fp8_format", ["e4m3", "e5m2"])
 @parameterize("async_op", [True, False])

@@ -3,9 +3,10 @@ from torch.testing import assert_close

 from colossalai.accelerator import get_accelerator
 from colossalai.quantization.fp8 import cast_from_fp8, cast_from_fp8_pipeline, cast_to_fp8, cast_to_fp8_pipeline
-from colossalai.testing import parameterize
+from colossalai.testing import clear_cache_before_run, parameterize


+@clear_cache_before_run()
 @parameterize("shape", [(100, 10), (10, 100), (3, 7), (2, 1), (1, 2), (2, 2), (4, 2), (5,), (4,), (2,)])
 @parameterize("dtype", [torch.bfloat16, torch.float16, torch.float32])
 @parameterize("fp8_format", ["e4m3", "e5m2"])

@@ -8,7 +8,7 @@ from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.testing import assert_close

 from colossalai import launch
-from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
+from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn

 # example modified from https://pytorch.org/tutorials/intermediate/ddp_tutorial.html

@@ -28,6 +28,7 @@ class ToyModel(nn.Module):
        return self.net2(self.relu(self.net1(x)))


+@clear_cache_before_run()
 @parameterize("mode", ["grad", "params"])
 def run_model(mode):
    rank = dist.get_rank()

@@ -6,9 +6,10 @@ from torch.testing import assert_close
 from colossalai import launch
 from colossalai.accelerator import get_accelerator
 from colossalai.quantization.fp8 import reduce_scatter_fp8
-from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
+from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn


+@clear_cache_before_run()
 @parameterize("shape", [(16, 8, 4)])
 @parameterize("scatter_dim", [0, 1, 2])
 @parameterize("dtype", [torch.bfloat16, torch.float16])
@@ -51,7 +51,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
    if test_config["precision"] == "fp32":
        atol, rtol = 1e-5, 1e-3
    else:
-        atol, rtol = 5e-2, 5e-2
+        atol, rtol = 9e-2, 0
    if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0:
        row_layer_grads = get_grad_tensors_for_check(
            t5, sharded_t5, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0
@@ -1 +1 @@
-0.4.8
+0.4.9