diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index f7217a8f1..bee6387a7 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -15,21 +15,21 @@ repos:
args: ["--profile", "black"] # avoid conflict with black
- repo: https://github.com/psf/black-pre-commit-mirror
- rev: 24.8.0
+ rev: 24.10.0
hooks:
- id: black
name: black formatter
args: ['--line-length=120', '--target-version=py37', '--target-version=py38', '--target-version=py39','--target-version=py310']
- repo: https://github.com/pre-commit/mirrors-clang-format
- rev: v18.1.8
+ rev: v19.1.2
hooks:
- id: clang-format
name: clang formatter
types_or: [c++, c]
- repo: https://github.com/pre-commit/pre-commit-hooks
- rev: v4.6.0
+ rev: v5.0.0
hooks:
- id: check-yaml
- id: check-merge-conflict
diff --git a/README.md b/README.md
index 22c565b50..7b2004dfc 100644
--- a/README.md
+++ b/README.md
@@ -25,16 +25,36 @@
+## GPU Cloud HPC-AI.COM Is Coming!
+
+For a limited time, you can access an H100 Server for just $1! This is your chance to leverage premium GPU power at an unbeatable price.
+Plus, when you refer a friend, you’ll receive 20% cashback or compute credits equal to 100% of their top-up!
+
+Our platform offers on-demand premium compute, ensuring safe, permanent data storage even after stopping your instance.
+Don’t miss this incredible opportunity to accelerate your AI projects!
+
+Unlock premium GPUs and register now at [HPC-AI.COM](https://hpc-ai.com) to receive $10!
+
+Special Bonuses:
+
+* Top up $1,000 and receive 300 credits
+* Top up $500 and receive 100 credits
+
+
+
+
## Latest News
+* [2024/10] [How to build a low-cost Sora-like app? Solutions for you](https://company.hpc-ai.com/blog/how-to-build-a-low-cost-sora-like-app-solutions-for-you)
+* [2024/09] [Singapore Startup HPC-AI Tech Secures 50 Million USD in Series A Funding to Build the Video Generation AI Model and GPU Platform](https://company.hpc-ai.com/blog/singapore-startup-hpc-ai-tech-secures-50-million-usd-in-series-a-funding-to-build-the-video-generation-ai-model-and-gpu-platform)
+* [2024/09] [Reducing AI Large Model Training Costs by 30% Requires Just a Single Line of Code From FP8 Mixed Precision Training Upgrades](https://company.hpc-ai.com/blog/reducing-ai-large-model-training-costs-by-30-requires-just-a-single-line-of-code-from-fp8-mixed-precision-training-upgrades)
* [2024/06] [Open-Sora Continues Open Source: Generate Any 16-Second 720p HD Video with One Click, Model Weights Ready to Use](https://hpc-ai.com/blog/open-sora-from-hpc-ai-tech-team-continues-open-source-generate-any-16-second-720p-hd-video-with-one-click-model-weights-ready-to-use)
* [2024/05] [Large AI Models Inference Speed Doubled, Colossal-Inference Open Source Release](https://hpc-ai.com/blog/colossal-inference)
* [2024/04] [Open-Sora Unveils Major Upgrade: Embracing Open Source with Single-Shot 16-Second Video Generation and 720p Resolution](https://hpc-ai.com/blog/open-soras-comprehensive-upgrade-unveiled-embracing-16-second-video-generation-and-720p-resolution-in-open-source)
* [2024/04] [Most cost-effective solutions for inference, fine-tuning and pretraining, tailored to LLaMA3 series](https://hpc-ai.com/blog/most-cost-effective-solutions-for-inference-fine-tuning-and-pretraining-tailored-to-llama3-series)
-* [2024/03] [314 Billion Parameter Grok-1 Inference Accelerated by 3.8x, Efficient and Easy-to-Use PyTorch+HuggingFace version is Here](https://hpc-ai.com/blog/314-billion-parameter-grok-1-inference-accelerated-by-3.8x-efficient-and-easy-to-use-pytorchhuggingface-version-is-here)
-* [2024/03] [Open-Sora: Revealing Complete Model Parameters, Training Details, and Everything for Sora-like Video Generation Models](https://hpc-ai.com/blog/open-sora-v1.0)
-* [2024/03] [Open-Sora:Sora Replication Solution with 46% Cost Reduction, Sequence Expansion to Nearly a Million](https://hpc-ai.com/blog/open-sora)
-* [2024/01] [Inference Performance Improved by 46%, Open Source Solution Breaks the Length Limit of LLM for Multi-Round Conversations](https://hpc-ai.com/blog/Colossal-AI-SwiftInfer)
-* [2023/07] [HPC-AI Tech Raises 22 Million USD in Series A Funding](https://www.hpc-ai.tech/blog/hpc-ai-tech-raises-22-million-usd-in-series-a-funding-to-fuel-team-expansion-and-business-growth)
## Table of Contents
diff --git a/applications/ColossalChat/README.md b/applications/ColossalChat/README.md
index 100cc5ece..ef904b864 100755
--- a/applications/ColossalChat/README.md
+++ b/applications/ColossalChat/README.md
@@ -27,11 +27,11 @@
- [Alternative Option For RLHF: SimPO](#alternative-option-for-rlhf-simple-preference-optimization-simpo)
- [Alternative Option For RLHF: ORPO](#alternative-option-for-rlhf-odds-ratio-preference-optimization-orpo)
- [Alternative Option For RLHF: KTO](#alternative-option-for-rlhf-kahneman-tversky-optimization-kto)
+- [O1 Journey](#o1-journey)
+ - [Inference with Self-refined MCTS](#inference-with-self-refined-mcts)
- [FAQ](#faq)
- [How to save/load checkpoint](#faq)
- [How to train with limited resources](#faq)
-- [The Plan](#the-plan)
- - [Real-time progress](#real-time-progress)
- [Invitation to open-source contribution](#invitation-to-open-source-contribution)
- [Quick Preview](#quick-preview)
- [Authors](#authors)
@@ -272,7 +272,7 @@ Odds Ratio Preference Optimization (ORPO) from this [paper](https://arxiv.org/pd
## Alternative Option For RLHF: Kahneman-Tversky Optimization (KTO)
We support the method introduced in the paper [KTO: Model Alignment as Prospect Theoretic Optimization](https://arxiv.org/pdf/2402.01306) (KTO), an alignment method that directly maximizes the "human utility" of generation results. Read this [README](./examples/README.md) for more information.
-### Inference Quantization and Serving - After Training
+## Inference Quantization and Serving - After Training
We provide an online inference server and a benchmark. We aim to run inference on a single GPU, so quantization is essential when using large models.
@@ -281,6 +281,21 @@ We support 8-bit quantization (RTN), 4-bit quantization (GPTQ), and FP16 inferen
Online inference server scripts can help you deploy your own services.
For more details, see [`inference/`](https://github.com/hpcaitech/ColossalAI/tree/main/applications/Chat/inference).
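
As a rough illustration of why quantization matters for single-GPU serving, here is a minimal sketch of generic 8-bit loading with Hugging Face `transformers` and `bitsandbytes`; the model path is a placeholder and this is not the project's own inference code, which lives in the `inference/` directory linked above.

```python
# Hedged sketch: generic 8-bit loading with transformers + bitsandbytes.
# "your-org/your-coati-checkpoint" is a placeholder, not an official model name.
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

model_path = "your-org/your-coati-checkpoint"  # placeholder: point at your own checkpoint
quant_config = BitsAndBytesConfig(load_in_8bit=True)

tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    quantization_config=quant_config,  # weights are quantized to int8 on load
    device_map="auto",                 # place layers on the available GPU(s)
)

inputs = tokenizer("Hello, how are you?", return_tensors="pt").to(model.device)
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=32)[0]))
```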
+## O1 Journey
+### Inference with Self-refined MCTS
+We provide an implementation of the MCT Self-Refine (MCTSr) algorithm, which integrates Large Language Models with Monte Carlo Tree Search.
+To run inference with MCTS, use the following script.
+```python
+from coati.reasoner.guided_search.mcts import MCTS
+from coati.reasoner.guided_search.prompt_store.qwen import Qwen32B_prompt_CFG
+
+problem = "How many 'r's are there in 'strawberry'?"
+
+search_tree = MCTS(problem=problem, max_simulations=8, cfg=Qwen32B_prompt_CFG)
+answer = search_tree.simulate()
+print(answer)
+```
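
The bundled `Qwen32B_prompt_CFG` assumes an OpenAI-compatible server exposing `Qwen2.5-32B-Instruct` at `http://0.0.0.0:8008/v1`. Below is a minimal sketch of pointing the search at your own endpoint instead; the address, model name, and prompts are placeholders.

```python
# Hedged sketch: use a custom prompt config against your own OpenAI-compatible endpoint.
from coati.reasoner.guided_search.mcts import MCTS
from coati.reasoner.guided_search.prompt_store.base import PromptCFG

my_cfg = PromptCFG(
    base_url="http://localhost:8000/v1",  # placeholder: your served endpoint
    model="my-served-model",              # placeholder: model name the server exposes
    max_tokens=4096,
    critic_system_prompt="Provide a detailed and constructive critique to improve the answer.",
    refine_system_prompt="Refine the answer based on the critique.",
    evaluate_system_prompt=(
        "Provide an integer reward score between -100 and 100 for the answer quality. "
        "Return *ONLY* the score."
    ),
)

answer = MCTS(problem="How many 'r's are there in 'strawberry'?", max_simulations=8, cfg=my_cfg).simulate()
print(answer)
```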
+
## Coati7B examples
### Generation
diff --git a/applications/ColossalChat/coati/models/loss.py b/applications/ColossalChat/coati/models/loss.py
index bd0bbd36b..927dfd5a8 100755
--- a/applications/ColossalChat/coati/models/loss.py
+++ b/applications/ColossalChat/coati/models/loss.py
@@ -153,10 +153,11 @@ class DpoLoss(nn.Module):
else:
# If no reference model is provided
ref_logratios = 0.0
+
pi_logratios = logprob_actor_chosen.sum(-1) - logprob_actor_reject.sum(-1)
logits = pi_logratios - ref_logratios - self.gamma / self.beta
losses = -torch.nn.functional.logsigmoid(self.beta * logits)
-
+ loss = losses.mean()
# Calculate rewards for logging
if logprob_ref_chosen is not None:
chosen_rewards = self.beta * (logprob_actor_chosen.sum(-1) - logprob_ref_chosen.sum(-1)).detach()
@@ -167,7 +168,7 @@ class DpoLoss(nn.Module):
else:
rejected_rewards = self.beta * logprob_actor_reject.sum(-1).detach()
- return losses, chosen_rewards, rejected_rewards
+ return loss, chosen_rewards, rejected_rewards
class LogSigLoss(nn.Module):
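
For reference, the objective computed above is `-logsigmoid(beta * (pi_logratios - ref_logratios) - gamma)`, now reduced to a scalar mean before being returned. A minimal toy-tensor sketch of that computation, using made-up summed log-probabilities and assumed `beta`/`gamma` values (not the trainer's code path):

```python
# Hedged toy-tensor sketch of the DPO objective above.
import torch
import torch.nn.functional as F

beta, gamma = 0.1, 0.0  # assumed hyperparameters for illustration

# Summed per-sequence log-probs for two preference pairs (made-up numbers).
logprob_actor_chosen = torch.tensor([-12.0, -9.5])
logprob_actor_reject = torch.tensor([-14.0, -9.0])
logprob_ref_chosen = torch.tensor([-12.5, -10.0])
logprob_ref_reject = torch.tensor([-13.5, -9.2])

pi_logratios = logprob_actor_chosen - logprob_actor_reject
ref_logratios = logprob_ref_chosen - logprob_ref_reject
logits = pi_logratios - ref_logratios - gamma / beta
loss = -F.logsigmoid(beta * logits).mean()  # scalar, matching the new return value
print(loss.item())
```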
diff --git a/applications/ColossalChat/coati/models/utils.py b/applications/ColossalChat/coati/models/utils.py
index c583f057a..fe7ab2098 100755
--- a/applications/ColossalChat/coati/models/utils.py
+++ b/applications/ColossalChat/coati/models/utils.py
@@ -50,8 +50,8 @@ def _log_probs_from_logits(logits: torch.Tensor, labels: torch.Tensor) -> torch.
torch.Tensor: The log probabilities corresponding to the labels.
"""
log_probs = F.log_softmax(logits, dim=-1)
- log_probs_labels = log_probs.gather(dim=-1, index=labels.unsqueeze(-1))
- return log_probs_labels.squeeze(-1)
+ per_label_logps = log_probs.gather(dim=-1, index=labels.unsqueeze(-1))
+ return per_label_logps.squeeze(-1)
def calc_action_log_probs(logits: torch.Tensor, sequences: torch.LongTensor, num_actions: int) -> torch.Tensor:
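
The renamed helper uses the standard gather trick to pull each label's log-probability out of the log-softmax. A small standalone sketch with toy shapes:

```python
# Hedged standalone sketch of the gather pattern used above (toy shapes).
import torch
import torch.nn.functional as F

logits = torch.randn(2, 5, 10)           # (batch, seq_len, vocab)
labels = torch.randint(0, 10, (2, 5))    # token ids to score

log_probs = F.log_softmax(logits, dim=-1)
# gather picks log_probs[b, t, labels[b, t]] for every position.
per_label_logps = log_probs.gather(dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)
print(per_label_logps.shape)  # torch.Size([2, 5])
```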
diff --git a/applications/ColossalChat/coati/reasoner/guided_search/llm.py b/applications/ColossalChat/coati/reasoner/guided_search/llm.py
new file mode 100644
index 000000000..5025a98ea
--- /dev/null
+++ b/applications/ColossalChat/coati/reasoner/guided_search/llm.py
@@ -0,0 +1,26 @@
+import openai
+from openai.types.chat.chat_completion import ChatCompletion
+from openai.types.chat.chat_completion_message_param import ChatCompletionMessageParam
+
+API_KEY = "Dummy API Key"
+
+
+def get_client(base_url: str | None = None) -> openai.Client:
+ return openai.Client(api_key=API_KEY, base_url=base_url)
+
+
+def chat_completion(
+ messages: list[ChatCompletionMessageParam],
+ model: str,
+ base_url: str | None = None,
+ temperature: float = 0.8,
+ **kwargs,
+) -> ChatCompletion:
+ client = get_client(base_url)
+ response = client.chat.completions.create(
+ model=model,
+ messages=messages,
+ temperature=temperature,
+ **kwargs,
+ )
+ return response
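
`chat_completion` is a thin wrapper around the OpenAI client pointed at a custom `base_url`. A minimal usage sketch, assuming an OpenAI-compatible server (for example the Qwen endpoint configured in `prompt_store/qwen.py`) is already running at that address:

```python
# Hedged usage sketch: assumes an OpenAI-compatible server is reachable at base_url
# (the address and model name mirror prompt_store/qwen.py).
from coati.reasoner.guided_search.llm import chat_completion

response = chat_completion(
    messages=[{"role": "user", "content": "Say hello in one short sentence."}],
    model="Qwen2.5-32B-Instruct",
    base_url="http://0.0.0.0:8008/v1",
    max_tokens=32,
)
print(response.choices[0].message.content)
```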
diff --git a/applications/ColossalChat/coati/reasoner/guided_search/mcts.py b/applications/ColossalChat/coati/reasoner/guided_search/mcts.py
new file mode 100644
index 000000000..693e2b750
--- /dev/null
+++ b/applications/ColossalChat/coati/reasoner/guided_search/mcts.py
@@ -0,0 +1,250 @@
+"""
+Implementation of MCTS + Self-refine algorithm.
+
+Reference:
+1. "Accessing GPT-4 level Mathematical Olympiad Solutions via Monte
+Carlo Tree Self-refine with LLaMa-3 8B: A Technical Report"
+2. https://github.com/BrendanGraham14/mcts-llm/
+3. https://github.com/trotsky1997/MathBlackBox/
+4. https://github.com/openreasoner/openr/blob/main/reason/guided_search/tree.py
+"""
+
+from __future__ import annotations
+
+import math
+from collections import deque
+
+import numpy as np
+import tqdm
+from coati.reasoner.guided_search.llm import chat_completion
+from coati.reasoner.guided_search.prompt_store.base import PromptCFG
+from pydantic import BaseModel
+
+
+class MCTSNode(BaseModel):
+ """
+ Node for MCTS.
+ """
+
+ answer: str
+ parent: MCTSNode = None
+ children: list[MCTSNode] = []
+ num_visits: int = 0
+ Q: int = 0
+ rewards: list[int] = []
+
+ def expand_node(self, node) -> None:
+ self.children.append(node)
+
+ def add_reward(self, reward: int) -> None:
+ self.rewards.append(reward)
+ self.Q = (np.min(self.rewards) + np.mean(self.rewards)) / 2
+
+
+class MCTS(BaseModel):
+ """
+ Simulation of MCTS process.
+ """
+
+ problem: str
+ max_simulations: int
+ cfg: PromptCFG
+ C: float = 1.4
+ max_children: int = 2
+ epsilon: float = 1e-5
+ root: MCTSNode = None
+
+ def initialization(self):
+ """
+        Root initialization.
+ """
+ # Dummy answer as root.
+ base_answer = self.sample_base_answer()
+ self.root = MCTSNode(answer=base_answer)
+ self.self_evaluate(self.root)
+
+ def is_fully_expanded(self, node: MCTSNode):
+ return len(node.children) >= self.max_children or any(child.Q > node.Q for child in node.children)
+
+ def select_node(self) -> MCTSNode:
+ """
+ Select next node to explore.
+ """
+ candidates: list[MCTSNode] = []
+ to_explore = deque([self.root])
+
+ while to_explore:
+ current_node = to_explore.popleft()
+ if not self.is_fully_expanded(current_node):
+ candidates.append(current_node)
+ to_explore.extend(current_node.children)
+
+ if not candidates:
+ return self.root
+
+ return max(candidates, key=self.compute_uct)
+
+ def self_evaluate(self, node: MCTSNode):
+ """
+ Sample reward of the answer.
+ """
+ reward = self.sample_reward(node)
+ node.add_reward(reward)
+
+ def back_propagation(self, node: MCTSNode):
+ """
+ Back propagate the value of the refined answer.
+ """
+ parent = node.parent
+ while parent:
+ best_child_Q = max(child.Q for child in parent.children)
+ parent.Q = (parent.Q + best_child_Q) / 2
+ parent.num_visits += 1
+ parent = parent.parent
+
+ def compute_uct(self, node: MCTSNode):
+ """
+ Compute UCT.
+ """
+ if node.parent is None:
+ return -100
+ return node.Q + self.C * math.sqrt(math.log(node.parent.num_visits + 1) / (node.num_visits + self.epsilon))
+
+ def simulate(self):
+ self.initialization()
+ for _ in tqdm.tqdm(range(self.max_simulations)):
+ node = self.select_node()
+ child = self.self_refine(node)
+ node.expand_node(child)
+ self.self_evaluate(child)
+ self.back_propagation(child)
+
+ return self.get_best_answer()
+
+ def get_best_answer(self):
+ to_visit = deque([self.root])
+ best_node = self.root
+
+ while to_visit:
+ current_node = to_visit.popleft()
+ if current_node.Q > best_node.Q:
+ best_node = current_node
+ to_visit.extend(current_node.children)
+
+ return best_node.answer
+
+ def self_refine(self, node: MCTSNode):
+ """
+ Refine node.
+ """
+ critique_response = chat_completion(
+ messages=[
+ {
+ "role": "system",
+ "content": self.cfg.critic_system_prompt,
+ },
+ {
+ "role": "user",
+ "content": "\n\n".join(
+ [
+ f"\n{self.problem}\n",
+ f"\n{node.answer}\n",
+ ]
+ ),
+ },
+ ],
+ model=self.cfg.model,
+ base_url=self.cfg.base_url,
+ max_tokens=self.cfg.max_tokens,
+ )
+ critique = critique_response.choices[0].message.content
+ assert critique is not None
+ refined_answer_response = chat_completion(
+ messages=[
+ {
+ "role": "system",
+ "content": self.cfg.refine_system_prompt,
+ },
+ {
+ "role": "user",
+ "content": "\n\n".join(
+ [
+ f"\n{self.problem}\n",
+ f"\n{node.answer}\n",
+ f"\n{critique}\n",
+ ]
+ ),
+ },
+ ],
+ model=self.cfg.model,
+ base_url=self.cfg.base_url,
+ max_tokens=self.cfg.max_tokens,
+ )
+ refined_answer = refined_answer_response.choices[0].message.content
+ assert refined_answer is not None
+
+ return MCTSNode(answer=refined_answer, parent=node)
+
+ def sample_base_answer(self):
+ response = chat_completion(
+ messages=[
+ {
+ "role": "system",
+ "content": "The user will provide a problem. Solve the problem. The response should begin with [reasoning process]...[Verification]... and end with [Final Answer]. \nThe answer is [answer] \n#### [answer].",
+ },
+ {
+ "role": "user",
+ "content": f"\n {self.problem} \n \nLet's think step by step",
+ },
+ ],
+ model=self.cfg.model,
+ base_url=self.cfg.base_url,
+ max_tokens=self.cfg.max_tokens,
+ )
+ assert response.choices[0].message.content is not None
+ return response.choices[0].message.content
+
+ def sample_reward(self, node: MCTSNode):
+ """
+ Calculate reward.
+ """
+ messages = [
+ {
+ "role": "system",
+ "content": self.cfg.evaluate_system_prompt,
+ },
+ {
+ "role": "user",
+ "content": "\n\n".join(
+ [
+ f"\n{self.problem}\n",
+ f"\n{node.answer}\n",
+ ]
+ ),
+ },
+ ]
+ for attempt in range(3):
+ try:
+ response = chat_completion(
+ messages=messages,
+ model=self.cfg.model,
+ base_url=self.cfg.base_url,
+ max_tokens=self.cfg.max_tokens,
+ )
+ assert response.choices[0].message.content is not None
+ return int(response.choices[0].message.content)
+ except ValueError:
+ messages.extend(
+ [
+ {
+ "role": "assistant",
+ "content": response.choices[0].message.content,
+ },
+ {
+ "role": "user",
+ "content": "Failed to parse reward as an integer.",
+ },
+ ]
+ )
+ if attempt == 2:
+ raise
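
For intuition on how `compute_uct` trades off exploitation (`Q`) against exploration (the visit-count term), here is a small sketch that builds a toy tree by hand and prints the UCT scores; the `PromptCFG` values are placeholders and no LLM request is made because `simulate()` is never called:

```python
# Hedged sketch: inspect UCT scores on a hand-built toy tree, without any LLM calls.
from coati.reasoner.guided_search.mcts import MCTS, MCTSNode
from coati.reasoner.guided_search.prompt_store.base import PromptCFG

cfg = PromptCFG(  # placeholder values; never used because we do not call simulate()
    base_url="http://localhost:8000/v1",
    model="placeholder-model",
    critic_system_prompt="",
    refine_system_prompt="",
    evaluate_system_prompt="",
)
tree = MCTS(problem="toy problem", max_simulations=1, cfg=cfg)

root = MCTSNode(answer="draft", num_visits=4)
a = MCTSNode(answer="refinement A", parent=root, num_visits=3)
b = MCTSNode(answer="refinement B", parent=root, num_visits=1)
a.Q, b.Q = 10, 8
root.expand_node(a)
root.expand_node(b)
tree.root = root

# Exploitation favors A (higher Q); the sqrt(log(parent visits)/visits) term
# gives B a larger exploration bonus because it has been visited less often.
for node in (a, b):
    print(node.answer, round(tree.compute_uct(node), 2))
```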
diff --git a/applications/ColossalChat/coati/reasoner/guided_search/prompt_store/base.py b/applications/ColossalChat/coati/reasoner/guided_search/prompt_store/base.py
new file mode 100644
index 000000000..b325b8fa2
--- /dev/null
+++ b/applications/ColossalChat/coati/reasoner/guided_search/prompt_store/base.py
@@ -0,0 +1,10 @@
+from pydantic import BaseModel
+
+
+class PromptCFG(BaseModel):
+ model: str
+ base_url: str
+ max_tokens: int = 4096
+ critic_system_prompt: str
+ refine_system_prompt: str
+ evaluate_system_prompt: str
diff --git a/applications/ColossalChat/coati/reasoner/guided_search/prompt_store/qwen.py b/applications/ColossalChat/coati/reasoner/guided_search/prompt_store/qwen.py
new file mode 100644
index 000000000..8bf0fa959
--- /dev/null
+++ b/applications/ColossalChat/coati/reasoner/guided_search/prompt_store/qwen.py
@@ -0,0 +1,20 @@
+"""
+Prompts for Qwen Series.
+"""
+
+from coati.reasoner.guided_search.prompt_store.base import PromptCFG
+
+Qwen32B_prompt_CFG = PromptCFG(
+ base_url="http://0.0.0.0:8008/v1",
+ model="Qwen2.5-32B-Instruct",
+ critic_system_prompt="Provide a detailed and constructive critique to improve the answer. "
+ "Highlight specific areas that need refinement or correction.",
+ refine_system_prompt="""# Instruction
+ Refine the answer based on the critique. The response should begin with [reasoning process]...[Verification]... and end with [Final Answer].
+ """,
+ evaluate_system_prompt=(
+        "Analyze this answer strictly and critically, and provide a reward score between -100 and 100 for the answer quality, using very strict standards. "
+ "Do not give a full score above 95. Make sure the reward score is an integer. "
+ "Return *ONLY* the score."
+ ),
+)
diff --git a/applications/ColossalChat/coati/trainer/dpo.py b/applications/ColossalChat/coati/trainer/dpo.py
index faa7a90d9..499113e96 100755
--- a/applications/ColossalChat/coati/trainer/dpo.py
+++ b/applications/ColossalChat/coati/trainer/dpo.py
@@ -6,6 +6,7 @@ import os
from typing import Any, Optional
import torch
+import torch.distributed as dist
from coati.models.loss import DpoLoss
from coati.models.utils import calc_masked_log_probs
from coati.trainer.utils import all_reduce_mean
@@ -13,10 +14,11 @@ from coati.utils import AccumulativeMeanMeter, save_checkpoint
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import DataLoader
-from tqdm import trange
+from tqdm import tqdm, trange
from transformers import PreTrainedTokenizerBase
from colossalai.booster import Booster, Plugin
+from colossalai.booster.plugin import HybridParallelPlugin
from colossalai.cluster import DistCoordinator
from colossalai.utils import get_current_device
@@ -96,18 +98,25 @@ class DPOTrainer(SLTrainer):
self.train_dataloader = train_preference_dataloader
self.eval_dataloader = eval_preference_dataloader
self.writer = None
- if use_wandb and is_rank_0():
+
+ init_criterion = (
+ dist.get_rank() == dist.get_world_size() - 1
+ if isinstance(self.plugin, HybridParallelPlugin) and self.plugin.pp_size > 1
+ else is_rank_0()
+ )
+
+ if use_wandb and init_criterion:
assert log_dir is not None, "log_dir must be provided when use_wandb is True"
import wandb
self.wandb_run = wandb.init(project="Coati-dpo", sync_tensorboard=True)
- if log_dir is not None and is_rank_0():
+ if log_dir is not None and init_criterion:
import os
import time
from torch.utils.tensorboard import SummaryWriter
- log_dir = os.path.join(log_dir, "dpo")
+ log_dir = os.path.join(log_dir, "DPO")
log_dir = os.path.join(log_dir, time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime()))
self.writer = SummaryWriter(log_dir=log_dir)
@@ -117,166 +126,147 @@ class DPOTrainer(SLTrainer):
epoch int: the number of current epoch
"""
self.model.train()
- self.accumulative_meter.reset()
- step_bar = trange(
- len(self.train_dataloader) // self.accumulation_steps,
- desc=f"Epoch {epoch + 1}/{self.max_epochs}",
- disable=not is_rank_0(),
- )
- for i, batch in enumerate(self.train_dataloader):
- batch = to_device(batch, self.device)
- (
- chosen_input_ids,
- chosen_attention_mask,
- chosen_loss_mask,
- reject_input_ids,
- reject_attention_mask,
- reject_loss_mask,
- ) = (
- batch["chosen_input_ids"],
- batch["chosen_attention_mask"],
- batch["chosen_loss_mask"],
- batch["reject_input_ids"],
- batch["reject_attention_mask"],
- batch["reject_loss_mask"],
+ if isinstance(self.plugin, HybridParallelPlugin) and self.plugin.pp_size > 1:
+ step_bar = tqdm(
+ range(len(self.train_dataloader)),
+ desc="Step",
+ disable=not (dist.get_rank() == dist.get_world_size() - 1),
)
- if not self.apply_loss_mask:
- chosen_loss_mask = chosen_loss_mask.fill_(1.0)
- reject_loss_mask = reject_loss_mask.fill_(1.0)
+ for i, batch in enumerate(self.train_dataloader):
+ batch = to_device(batch, self.device)
+ (
+ chosen_input_ids,
+ chosen_attention_mask,
+ chosen_loss_mask,
+ reject_input_ids,
+ reject_attention_mask,
+ reject_loss_mask,
+ ) = (
+ batch["chosen_input_ids"],
+ batch["chosen_attention_mask"],
+ batch["chosen_loss_mask"],
+ batch["reject_input_ids"],
+ batch["reject_attention_mask"],
+ batch["reject_loss_mask"],
+ )
+ batch_size = chosen_input_ids.size()[0]
+ # Calculate logits from reference model.
+ if self.ref_model is not None:
+ self.ref_model.eval()
+ with torch.no_grad():
+ ref_all_logits = self.ref_model(
+ input_ids=torch.cat([chosen_input_ids, reject_input_ids]),
+ attention_mask=torch.cat([chosen_attention_mask, reject_attention_mask]),
+ )["logits"]
+ ref_chosen_logits = ref_all_logits[:batch_size]
+ ref_reject_logits = ref_all_logits[batch_size:]
+ logprob_ref_chosen = calc_masked_log_probs(
+ ref_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:], self.length_normalization
+ )
+ logprob_ref_reject = calc_masked_log_probs(
+ ref_reject_logits, reject_input_ids, reject_loss_mask[:, 1:], self.length_normalization
+ )
+ else:
+ logprob_ref_chosen = None
+ logprob_ref_reject = None
- batch_size = chosen_input_ids.size()[0]
+ # Merge chosen and reject
+ inputs_ids = torch.stack([item for tup in zip(chosen_input_ids, reject_input_ids) for item in tup])
+ attention_mask = torch.stack(
+ [item for tup in zip(chosen_attention_mask, reject_attention_mask) for item in tup]
+ )
+ loss_mask = torch.stack([item for tup in zip(chosen_loss_mask, reject_loss_mask) for item in tup])
+ logprob_ref = torch.stack([item for tup in zip(logprob_ref_chosen, logprob_ref_reject) for item in tup])
- actor_all_logits = self.model(
- input_ids=torch.cat([chosen_input_ids, reject_input_ids]),
- attention_mask=torch.cat([chosen_attention_mask, reject_attention_mask]),
- )["logits"]
- actor_chosen_logits = actor_all_logits[:batch_size]
- actor_reject_logits = actor_all_logits[batch_size:]
- logprob_actor_chosen = calc_masked_log_probs(
- actor_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:], self.length_normalization
- )
+ data_iter = iter(
+ [
+ {
+ "input_ids": inputs_ids,
+ "attention_mask": attention_mask,
+ "loss_mask": loss_mask,
+ "logprob_ref": logprob_ref,
+ }
+ ]
+ )
+ rewards = []
- logprob_actor_reject = calc_masked_log_probs(
- actor_reject_logits, reject_input_ids, reject_loss_mask[:, 1:], self.length_normalization
- )
-
- if self.ref_model is not None:
- self.ref_model.eval()
- with torch.no_grad():
- ref_all_logits = self.ref_model(
- input_ids=torch.cat([chosen_input_ids, reject_input_ids]),
- attention_mask=torch.cat([chosen_attention_mask, reject_attention_mask]),
- )["logits"]
- ref_chosen_logits = ref_all_logits[:batch_size]
- ref_reject_logits = ref_all_logits[batch_size:]
- logprob_ref_chosen = calc_masked_log_probs(
- ref_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:], self.length_normalization
+ def _criterion(outputs, inputs):
+ loss, chosen_rewards, rejected_rewards = self.actor_loss_fn(
+ calc_masked_log_probs(
+ outputs["logits"][0::2],
+ inputs["input_ids"][0::2],
+ inputs["loss_mask"][0::2][:, 1:],
+ self.length_normalization,
+ ),
+ calc_masked_log_probs(
+ outputs["logits"][1::2],
+ inputs["input_ids"][1::2],
+ inputs["loss_mask"][1::2][:, 1:],
+ self.length_normalization,
+ ),
+ inputs["logprob_ref"][0::2] if inputs["logprob_ref"] is not None else None,
+ inputs["logprob_ref"][1::2] if inputs["logprob_ref"] is not None else None,
+ inputs["loss_mask"][0::2][:, 1:],
+ inputs["loss_mask"][1::2][:, 1:],
)
- logprob_ref_reject = calc_masked_log_probs(
- ref_reject_logits, reject_input_ids, reject_loss_mask[:, 1:], self.length_normalization
- )
- else:
- logprob_ref_chosen = None
- logprob_ref_reject = None
+ rewards.append(chosen_rewards)
+ rewards.append(rejected_rewards)
+ return loss
- losses, chosen_rewards, rejected_rewards = self.actor_loss_fn(
- logprob_actor_chosen,
- logprob_actor_reject,
- logprob_ref_chosen if logprob_ref_chosen is not None else None,
- logprob_ref_reject if logprob_ref_reject is not None else None,
- chosen_loss_mask[:, 1:],
- reject_loss_mask[:, 1:],
- )
- reward_accuracies = (chosen_rewards > rejected_rewards).float().mean()
+ outputs = self.booster.execute_pipeline(
+ data_iter,
+ self.model,
+ criterion=_criterion,
+ optimizer=self.optimizer,
+ return_loss=True,
+ )
+ loss = outputs["loss"]
+ if self.booster.plugin.stage_manager.is_last_stage():
+ chosen_rewards, rejected_rewards = rewards[0], rewards[1]
+ global_loss = all_reduce_mean(loss, self.plugin)
+ if dist.get_rank() == dist.get_world_size() - 1:
+ step_bar.set_postfix(
+ {
+ "train/loss": global_loss.item(),
+ "train/lr": self.actor_scheduler.get_last_lr()[0],
+ "train/chosen_rewards": chosen_rewards.to(torch.float16).mean().item(),
+ "train/rejected_rewards": rejected_rewards.to(torch.float16).mean().item(),
+ }
+ )
+ step_bar.update()
+ self.accumulative_meter.add("loss", global_loss.item())
+ self.accumulative_meter.add("chosen_rewards", chosen_rewards.to(torch.float16).mean().item())
+ self.accumulative_meter.add(
+ "rejected_rewards", rejected_rewards.to(torch.float16).mean().item()
+ )
+ if self.writer is not None:
+ self.writer.add_scalar("train/loss", self.accumulative_meter.get("loss"), i)
+ self.writer.add_scalar(
+ "train/chosen_rewards", self.accumulative_meter.get("chosen_rewards"), i
+ )
+ self.writer.add_scalar(
+ "train/rejected_rewards",
+ self.accumulative_meter.get("rejected_rewards"),
+ i,
+ )
+ self.writer.add_scalar(
+ "train/margin",
+ self.accumulative_meter.get("chosen_rewards")
+ - self.accumulative_meter.get("rejected_rewards"),
+ i,
+ )
- # DPO Loss
- loss = losses.mean()
-
- self.booster.backward(loss=loss, optimizer=self.optimizer)
- if self.num_train_step % self.accumulation_steps == self.accumulation_steps - 1:
self.optimizer.step()
self.optimizer.zero_grad()
self.actor_scheduler.step()
-
- # sync
- loss_mean = all_reduce_mean(tensor=loss)
- chosen_rewards_mean = all_reduce_mean(tensor=chosen_rewards)
- rejected_rewards_mean = all_reduce_mean(tensor=rejected_rewards)
- reward_accuracies_mean = all_reduce_mean(tensor=reward_accuracies)
- self.accumulative_meter.add("chosen_rewards", chosen_rewards_mean.to(torch.float16).mean().item())
- self.accumulative_meter.add("rejected_rewards", rejected_rewards_mean.to(torch.float16).mean().item())
- self.accumulative_meter.add("loss", loss_mean.to(torch.float16).item())
- self.accumulative_meter.add("accuracy", reward_accuracies_mean.to(torch.float16).item())
-
- if i % self.accumulation_steps == self.accumulation_steps - 1:
- self.num_train_step += 1
- step_bar.update()
- # logging
- if self.writer and is_rank_0():
- self.writer.add_scalar("train/loss", self.accumulative_meter.get("loss"), self.num_train_step)
- self.writer.add_scalar("train/lr", self.optimizer.param_groups[0]["lr"], self.num_train_step)
- self.writer.add_scalar(
- "train/chosen_rewards", self.accumulative_meter.get("chosen_rewards"), self.num_train_step
- )
- self.writer.add_scalar(
- "train/rejected_rewards",
- self.accumulative_meter.get("rejected_rewards"),
- self.num_train_step,
- )
- self.writer.add_scalar(
- "train/margin",
- self.accumulative_meter.get("chosen_rewards") - self.accumulative_meter.get("rejected_rewards"),
- self.num_train_step,
- )
- self.writer.add_scalar(
- "train/accuracy",
- self.accumulative_meter.get("accuracy"),
- self.num_train_step,
- )
- self.accumulative_meter.reset()
-
- if self.save_dir is not None and (self.num_train_step + 1) % self.save_interval == 0:
- # save checkpoint
- self.coordinator.print_on_master("\nStart saving model checkpoint with running states")
- save_checkpoint(
- save_dir=self.save_dir,
- booster=self.booster,
- model=self.model,
- optimizer=self.optimizer,
- lr_scheduler=self.actor_scheduler,
- epoch=epoch,
- step=i + 1,
- batch_size=batch_size,
- coordinator=self.coordinator,
- )
- self.coordinator.print_on_master(
- f"Saved checkpoint at epoch {epoch} step {self.save_interval} at folder {self.save_dir}"
- )
-
- step_bar.close()
-
- def _eval(self, epoch: int):
- """
- Args:
- epoch int: the number of current epoch
- """
- if self.eval_dataloader is None:
- self.coordinator.print_on_master("No eval dataloader is provided, skip evaluation")
- return
- self.model.eval()
- self.ref_model.eval()
- self.coordinator.print_on_master("\nStart evaluation...")
-
- step_bar = trange(
- len(self.eval_dataloader),
- desc=f"Epoch {epoch + 1}/{self.max_epochs}",
- disable=not is_rank_0(),
- )
-
- self.accumulative_meter.reset()
-
- with torch.no_grad():
- for i, batch in enumerate(self.eval_dataloader):
+ else:
+ self.accumulative_meter.reset()
+ step_bar = trange(
+ len(self.train_dataloader) // self.accumulation_steps,
+ desc=f"Epoch {epoch + 1}/{self.max_epochs}",
+ disable=not is_rank_0(),
+ )
+ for i, batch in enumerate(self.train_dataloader):
batch = to_device(batch, self.device)
(
chosen_input_ids,
@@ -300,12 +290,11 @@ class DPOTrainer(SLTrainer):
batch_size = chosen_input_ids.size()[0]
actor_all_logits = self.model(
- torch.cat([chosen_input_ids, reject_input_ids]),
- torch.cat([chosen_attention_mask, reject_attention_mask]),
+ input_ids=torch.cat([chosen_input_ids, reject_input_ids]),
+ attention_mask=torch.cat([chosen_attention_mask, reject_attention_mask]),
)["logits"]
actor_chosen_logits = actor_all_logits[:batch_size]
actor_reject_logits = actor_all_logits[batch_size:]
-
logprob_actor_chosen = calc_masked_log_probs(
actor_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:], self.length_normalization
)
@@ -314,22 +303,26 @@ class DPOTrainer(SLTrainer):
actor_reject_logits, reject_input_ids, reject_loss_mask[:, 1:], self.length_normalization
)
- self.ref_model.eval()
+ if self.ref_model is not None:
+ self.ref_model.eval()
+ with torch.no_grad():
+ ref_all_logits = self.ref_model(
+ input_ids=torch.cat([chosen_input_ids, reject_input_ids]),
+ attention_mask=torch.cat([chosen_attention_mask, reject_attention_mask]),
+ )["logits"]
+ ref_chosen_logits = ref_all_logits[:batch_size]
+ ref_reject_logits = ref_all_logits[batch_size:]
+ logprob_ref_chosen = calc_masked_log_probs(
+ ref_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:], self.length_normalization
+ )
+ logprob_ref_reject = calc_masked_log_probs(
+ ref_reject_logits, reject_input_ids, reject_loss_mask[:, 1:], self.length_normalization
+ )
+ else:
+ logprob_ref_chosen = None
+ logprob_ref_reject = None
- ref_all_logits = self.ref_model(
- torch.cat([chosen_input_ids, reject_input_ids]),
- torch.cat([chosen_attention_mask, reject_attention_mask]),
- )["logits"]
- ref_chosen_logits = ref_all_logits[:batch_size]
- ref_reject_logits = ref_all_logits[batch_size:]
- logprob_ref_chosen = calc_masked_log_probs(
- ref_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:], self.length_normalization
- )
- logprob_ref_reject = calc_masked_log_probs(
- ref_reject_logits, reject_input_ids, reject_loss_mask[:, 1:], self.length_normalization
- )
-
- losses, chosen_rewards, rejected_rewards = self.actor_loss_fn(
+ loss, chosen_rewards, rejected_rewards = self.actor_loss_fn(
logprob_actor_chosen,
logprob_actor_reject,
logprob_ref_chosen if logprob_ref_chosen is not None else None,
@@ -338,7 +331,9 @@ class DPOTrainer(SLTrainer):
reject_loss_mask[:, 1:],
)
reward_accuracies = (chosen_rewards > rejected_rewards).float().mean()
- loss = losses.mean()
+
+ self.booster.backward(loss=loss, optimizer=self.optimizer)
+ # sync
loss_mean = all_reduce_mean(tensor=loss)
chosen_rewards_mean = all_reduce_mean(tensor=chosen_rewards)
rejected_rewards_mean = all_reduce_mean(tensor=rejected_rewards)
@@ -347,16 +342,301 @@ class DPOTrainer(SLTrainer):
self.accumulative_meter.add("rejected_rewards", rejected_rewards_mean.to(torch.float16).mean().item())
self.accumulative_meter.add("loss", loss_mean.to(torch.float16).item())
self.accumulative_meter.add("accuracy", reward_accuracies_mean.to(torch.float16).item())
- self.accumulative_meter.add(
- "margin", (chosen_rewards_mean - rejected_rewards_mean).to(torch.float16).mean().item()
- )
- step_bar.update()
- msg = "Evaluation Result:\n"
- for tag in ["loss", "chosen_rewards", "rejected_rewards", "accuracy", "margin"]:
- msg = msg + f"{tag}: {self.accumulative_meter.get(tag)}\n"
- self.coordinator.print_on_master(msg)
- os.makedirs(self.save_dir, exist_ok=True)
- with open(os.path.join(self.save_dir, f"eval_result_epoch{epoch}.txt"), "w") as f:
- f.write(msg)
+ if (i + 1) % self.accumulation_steps == 0:
+ self.optimizer.step()
+ self.optimizer.zero_grad()
+ self.actor_scheduler.step()
+
+ step_bar.set_postfix(
+ {
+ "train/loss": self.accumulative_meter.get("loss"),
+ "train/chosen_rewards": self.accumulative_meter.get("chosen_rewards"),
+ "train/rejected_rewards": self.accumulative_meter.get("rejected_rewards"),
+ "train/accuracy": self.accumulative_meter.get("accuracy"),
+ }
+ )
+ step_bar.update()
+ if self.writer and is_rank_0():
+ self.writer.add_scalar("train/loss", self.accumulative_meter.get("loss"), self.num_train_step)
+ self.writer.add_scalar("train/lr", self.optimizer.param_groups[0]["lr"], self.num_train_step)
+ self.writer.add_scalar(
+ "train/chosen_rewards", self.accumulative_meter.get("chosen_rewards"), self.num_train_step
+ )
+ self.writer.add_scalar(
+ "train/rejected_rewards",
+ self.accumulative_meter.get("rejected_rewards"),
+ self.num_train_step,
+ )
+ self.writer.add_scalar(
+ "train/margin",
+ self.accumulative_meter.get("chosen_rewards")
+ - self.accumulative_meter.get("rejected_rewards"),
+ self.num_train_step,
+ )
+ self.writer.add_scalar(
+ "train/accuracy",
+ self.accumulative_meter.get("accuracy"),
+ self.num_train_step,
+ )
+ self.num_train_step += 1
+ self.accumulative_meter.reset()
+
+ if self.save_dir is not None and self.num_train_step > 0 and self.num_train_step % self.save_interval == 0:
+ # save checkpoint
+ self.coordinator.print_on_master("\nStart saving model checkpoint with running states")
+ save_checkpoint(
+ save_dir=self.save_dir,
+ booster=self.booster,
+ model=self.model,
+ optimizer=self.optimizer,
+ lr_scheduler=self.actor_scheduler,
+ epoch=epoch,
+ step=self.num_train_step,
+ batch_size=batch_size,
+ coordinator=self.coordinator,
+ )
+ self.coordinator.print_on_master(
+                        f"Saved checkpoint at epoch {epoch} step {self.num_train_step} at folder {self.save_dir}"
+ )
+
+ step_bar.close()
+
+ def _eval(self, epoch: int):
+ """
+ Args:
+ epoch int: the number of current epoch
+ """
+ if self.eval_dataloader is None:
+ self.coordinator.print_on_master("No eval dataloader is provided, skip evaluation")
+ return
+ self.model.eval()
+ self.ref_model.eval()
+ self.accumulative_meter.reset()
+ self.coordinator.print_on_master("\nStart evaluation...")
+
+ if isinstance(self.plugin, HybridParallelPlugin) and self.plugin.pp_size > 1:
+ step_bar = tqdm(
+ range(len(self.eval_dataloader)),
+ desc="Step",
+ disable=not (dist.get_rank() == dist.get_world_size() - 1),
+ )
+ with torch.no_grad():
+ for _, batch in enumerate(self.eval_dataloader):
+ batch = to_device(batch, self.device)
+ (
+ chosen_input_ids,
+ chosen_attention_mask,
+ chosen_loss_mask,
+ reject_input_ids,
+ reject_attention_mask,
+ reject_loss_mask,
+ ) = (
+ batch["chosen_input_ids"],
+ batch["chosen_attention_mask"],
+ batch["chosen_loss_mask"],
+ batch["reject_input_ids"],
+ batch["reject_attention_mask"],
+ batch["reject_loss_mask"],
+ )
+ batch_size = chosen_input_ids.size()[0]
+ # Calculate logits from reference model.
+ if self.ref_model is not None:
+ self.ref_model.eval()
+ with torch.no_grad():
+ ref_all_logits = self.ref_model(
+ input_ids=torch.cat([chosen_input_ids, reject_input_ids]),
+ attention_mask=torch.cat([chosen_attention_mask, reject_attention_mask]),
+ )["logits"]
+ ref_chosen_logits = ref_all_logits[:batch_size]
+ ref_reject_logits = ref_all_logits[batch_size:]
+ logprob_ref_chosen = calc_masked_log_probs(
+ ref_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:], self.length_normalization
+ )
+ logprob_ref_reject = calc_masked_log_probs(
+ ref_reject_logits, reject_input_ids, reject_loss_mask[:, 1:], self.length_normalization
+ )
+ else:
+ logprob_ref_chosen = None
+ logprob_ref_reject = None
+
+ # Merge chosen and reject
+ inputs_ids = torch.stack([item for tup in zip(chosen_input_ids, reject_input_ids) for item in tup])
+ attention_mask = torch.stack(
+ [item for tup in zip(chosen_attention_mask, reject_attention_mask) for item in tup]
+ )
+ loss_mask = torch.stack([item for tup in zip(chosen_loss_mask, reject_loss_mask) for item in tup])
+ logprob_ref = torch.stack(
+ [item for tup in zip(logprob_ref_chosen, logprob_ref_reject) for item in tup]
+ )
+
+ data_iter = iter(
+ [
+ {
+ "input_ids": inputs_ids,
+ "attention_mask": attention_mask,
+ "loss_mask": loss_mask,
+ "logprob_ref": logprob_ref,
+ }
+ ]
+ )
+ rewards = []
+
+ def _criterion(outputs, inputs):
+ loss, chosen_rewards, rejected_rewards = self.actor_loss_fn(
+ calc_masked_log_probs(
+ outputs["logits"][0::2],
+ inputs["input_ids"][0::2],
+ inputs["loss_mask"][0::2][:, 1:],
+ self.length_normalization,
+ ),
+ calc_masked_log_probs(
+ outputs["logits"][1::2],
+ inputs["input_ids"][1::2],
+ inputs["loss_mask"][1::2][:, 1:],
+ self.length_normalization,
+ ),
+ inputs["logprob_ref"][0::2] if inputs["logprob_ref"] is not None else None,
+ inputs["logprob_ref"][1::2] if inputs["logprob_ref"] is not None else None,
+ inputs["loss_mask"][0::2][:, 1:],
+ inputs["loss_mask"][1::2][:, 1:],
+ )
+ rewards.append(chosen_rewards)
+ rewards.append(rejected_rewards)
+ return loss
+
+ outputs = self.booster.execute_pipeline(
+ data_iter,
+ self.model,
+ criterion=_criterion,
+ optimizer=self.optimizer,
+ return_loss=True,
+ )
+ loss = outputs["loss"]
+ if self.booster.plugin.stage_manager.is_last_stage():
+ chosen_rewards, rejected_rewards = rewards[0], rewards[1]
+ global_loss = all_reduce_mean(loss, self.plugin)
+ chosen_rewards_mean = all_reduce_mean(chosen_rewards, self.plugin)
+ rejected_rewards_mean = all_reduce_mean(rejected_rewards, self.plugin)
+ if dist.get_rank() == dist.get_world_size() - 1:
+ step_bar.set_postfix(
+ {
+ "eval/loss": global_loss.item(),
+ "eval/lr": self.actor_scheduler.get_last_lr()[0],
+ "eval/chosen_rewards": chosen_rewards.to(torch.float16).mean().item(),
+ "eval/rejected_rewards": rejected_rewards.to(torch.float16).mean().item(),
+ }
+ )
+ self.accumulative_meter.add(
+ "chosen_rewards", chosen_rewards_mean.to(torch.float16).mean().item()
+ )
+ self.accumulative_meter.add(
+ "rejected_rewards", rejected_rewards_mean.to(torch.float16).mean().item()
+ )
+ self.accumulative_meter.add("loss", global_loss.to(torch.float16).item())
+ step_bar.update()
+ if self.booster.plugin.stage_manager.is_last_stage():
+ msg = "\nEvaluation Result:\n"
+ for tag in ["loss", "chosen_rewards", "rejected_rewards"]:
+ msg = msg + f"{tag}: {self.accumulative_meter.get(tag)}\n"
+ if dist.get_rank() == dist.get_world_size() - 1:
+ print(msg)
+ else:
+ step_bar = trange(
+ len(self.eval_dataloader),
+ desc=f"Epoch {epoch + 1}/{self.max_epochs}",
+ disable=not is_rank_0(),
+ )
+ with torch.no_grad():
+ for i, batch in enumerate(self.eval_dataloader):
+ batch = to_device(batch, self.device)
+ (
+ chosen_input_ids,
+ chosen_attention_mask,
+ chosen_loss_mask,
+ reject_input_ids,
+ reject_attention_mask,
+ reject_loss_mask,
+ ) = (
+ batch["chosen_input_ids"],
+ batch["chosen_attention_mask"],
+ batch["chosen_loss_mask"],
+ batch["reject_input_ids"],
+ batch["reject_attention_mask"],
+ batch["reject_loss_mask"],
+ )
+ if not self.apply_loss_mask:
+ chosen_loss_mask = chosen_loss_mask.fill_(1.0)
+ reject_loss_mask = reject_loss_mask.fill_(1.0)
+
+ batch_size = chosen_input_ids.size()[0]
+
+ actor_all_logits = self.model(
+ torch.cat([chosen_input_ids, reject_input_ids]),
+ torch.cat([chosen_attention_mask, reject_attention_mask]),
+ )["logits"]
+ actor_chosen_logits = actor_all_logits[:batch_size]
+ actor_reject_logits = actor_all_logits[batch_size:]
+
+ logprob_actor_chosen = calc_masked_log_probs(
+ actor_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:], self.length_normalization
+ )
+
+ logprob_actor_reject = calc_masked_log_probs(
+ actor_reject_logits, reject_input_ids, reject_loss_mask[:, 1:], self.length_normalization
+ )
+ ref_all_logits = self.ref_model(
+ torch.cat([chosen_input_ids, reject_input_ids]),
+ torch.cat([chosen_attention_mask, reject_attention_mask]),
+ )["logits"]
+ ref_chosen_logits = ref_all_logits[:batch_size]
+ ref_reject_logits = ref_all_logits[batch_size:]
+ logprob_ref_chosen = calc_masked_log_probs(
+ ref_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:], self.length_normalization
+ )
+ logprob_ref_reject = calc_masked_log_probs(
+ ref_reject_logits, reject_input_ids, reject_loss_mask[:, 1:], self.length_normalization
+ )
+
+ losses, chosen_rewards, rejected_rewards = self.actor_loss_fn(
+ logprob_actor_chosen,
+ logprob_actor_reject,
+ logprob_ref_chosen if logprob_ref_chosen is not None else None,
+ logprob_ref_reject if logprob_ref_reject is not None else None,
+ chosen_loss_mask[:, 1:],
+ reject_loss_mask[:, 1:],
+ )
+ reward_accuracies = (chosen_rewards > rejected_rewards).float().mean()
+ loss = losses.mean()
+ loss_mean = all_reduce_mean(tensor=loss)
+ chosen_rewards_mean = all_reduce_mean(tensor=chosen_rewards)
+ rejected_rewards_mean = all_reduce_mean(tensor=rejected_rewards)
+ reward_accuracies_mean = all_reduce_mean(tensor=reward_accuracies)
+ self.accumulative_meter.add("chosen_rewards", chosen_rewards_mean.to(torch.float16).mean().item())
+ self.accumulative_meter.add(
+ "rejected_rewards", rejected_rewards_mean.to(torch.float16).mean().item()
+ )
+ self.accumulative_meter.add("loss", loss_mean.to(torch.float16).item())
+ self.accumulative_meter.add("accuracy", reward_accuracies_mean.to(torch.float16).item())
+ self.accumulative_meter.add(
+ "margin", (chosen_rewards_mean - rejected_rewards_mean).to(torch.float16).mean().item()
+ )
+ step_bar.set_postfix(
+ {
+ "eval/loss": self.accumulative_meter.get("loss"),
+ "eval/chosen_rewards": self.accumulative_meter.get("chosen_rewards"),
+ "eval/rejected_rewards": self.accumulative_meter.get("rejected_rewards"),
+ "eval/accuracy": self.accumulative_meter.get("accuracy"),
+ }
+ )
+ step_bar.update()
+
+ msg = "\nEvaluation Result:\n"
+ for tag in ["loss", "chosen_rewards", "rejected_rewards", "accuracy", "margin"]:
+ msg = msg + f"{tag}: {self.accumulative_meter.get(tag)}\n"
+ self.coordinator.print_on_master(msg)
+ if self.save_dir is not None:
+ os.makedirs(self.save_dir, exist_ok=True)
+ with open(os.path.join(self.save_dir, f"eval_result_epoch{epoch}.txt"), "w") as f:
+ f.write(msg)
step_bar.close()
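
In the new pipeline-parallel branch, chosen and rejected samples are interleaved into a single batch and split back apart inside `_criterion` with even/odd striding. A toy sketch of that interleave/de-interleave pattern, separate from the trainer itself:

```python
# Hedged toy sketch of the interleave pattern used by the PP branch above.
import torch

chosen = torch.tensor([[1, 1, 1], [2, 2, 2]])   # two "chosen" sequences
reject = torch.tensor([[9, 9, 9], [8, 8, 8]])   # two "reject" sequences

# Interleave as chosen_0, reject_0, chosen_1, reject_1, ...
merged = torch.stack([item for pair in zip(chosen, reject) for item in pair])
print(merged)

# Inside the criterion, even rows recover the chosen half, odd rows the rejected half.
assert torch.equal(merged[0::2], chosen)
assert torch.equal(merged[1::2], reject)
```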
diff --git a/applications/ColossalChat/examples/data_preparation_scripts/prepare_dataset.py b/applications/ColossalChat/examples/data_preparation_scripts/prepare_dataset.py
index a35f2bf52..b551497b9 100644
--- a/applications/ColossalChat/examples/data_preparation_scripts/prepare_dataset.py
+++ b/applications/ColossalChat/examples/data_preparation_scripts/prepare_dataset.py
@@ -73,8 +73,7 @@ def main():
"--conversation_template_config",
type=str,
default="conversation_template_config",
- help="Path \
- to save conversation template config files.",
+ help="Path to save conversation template config files.",
)
parser.add_argument("--data_cache_dir", type=str, default="cache", help="Data cache directory")
parser.add_argument(
diff --git a/applications/ColossalChat/examples/training_scripts/train_dpo.py b/applications/ColossalChat/examples/training_scripts/train_dpo.py
index 3b324ee78..ad81db73a 100755
--- a/applications/ColossalChat/examples/training_scripts/train_dpo.py
+++ b/applications/ColossalChat/examples/training_scripts/train_dpo.py
@@ -13,7 +13,7 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
import colossalai
from colossalai.booster import Booster
-from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin
+from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.cluster import DistCoordinator
from colossalai.logging import get_dist_logger
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
@@ -29,8 +29,6 @@ def train(args):
# check lora compatibility
if "gemini" in args.plugin and lora_config is not None and lora_config.r > 0:
raise ValueError("LoRA is not supported in GeminiPlugin. Please use other plugin")
- if args.plugin == "gemini_auto" and args.accumulation_steps > 1:
- raise ValueError("Gradient accumulation is not supported in GeminiPlugin. Please use other plugin")
# ==============================
# Initialize Distributed Training
@@ -46,7 +44,7 @@ def train(args):
        Default torch ddp plugin without any acceleration, for
        debugging purposes
"""
- plugin = TorchDDPPlugin(find_unused_parameters=True)
+ plugin = TorchDDPPlugin(find_unused_parameters=not args.grad_checkpoint)
elif args.plugin == "gemini":
plugin = GeminiPlugin(
precision=args.mixed_precision,
@@ -56,14 +54,6 @@ def train(args):
enable_gradient_accumulation=True,
enable_flash_attention=args.use_flash_attn,
)
- elif args.plugin == "gemini_auto":
- plugin = GeminiPlugin(
- precision=args.mixed_precision,
- placement_policy="auto",
- initial_scale=2**16,
- max_norm=args.grad_clip,
- enable_flash_attention=args.use_flash_attn,
- )
elif args.plugin == "zero2":
plugin = LowLevelZeroPlugin(
stage=2,
@@ -92,20 +82,24 @@ def train(args):
parallel_output=False,
max_norm=args.grad_clip,
precision=args.mixed_precision,
+ microbatch_size=args.microbatch_size,
)
else:
raise ValueError(f"Unknown plugin {args.plugin}")
booster = Booster(plugin=plugin)
- ref_booster = Booster(plugin=plugin)
- # ======================================================
- # Initialize Model, Objective, Optimizer and LR Scheduler
- # ======================================================
- # Temp Fix: Disable lazy init due to version conflict
- # init_ctx = (
- # LazyInitContext(default_device=get_current_device()) if isinstance(plugin, (GeminiPlugin,)) else nullcontext()
- # )
+ ref_plugin = HybridParallelPlugin(
+ tp_size=args.ref_tp,
+ pp_size=1,
+ zero_stage=args.zero_stage,
+ enable_flash_attention=args.use_flash_attn,
+ cpu_offload=True if args.zero_stage >= 1 and args.zero_cpu_offload else False,
+ parallel_output=False,
+ max_norm=args.grad_clip,
+ precision=args.mixed_precision,
+ )
+ ref_booster = Booster(plugin=ref_plugin)
init_ctx = nullcontext()
with init_ctx:
@@ -130,6 +124,7 @@ def train(args):
ref_model = AutoModelForCausalLM.from_pretrained(args.pretrain)
else:
ref_model = None
+
if args.lora_config is not None:
model = convert_to_lora_module(model, lora_config=lora_config)
for name, module in model.named_modules():
@@ -139,7 +134,9 @@ def train(args):
disable_dropout(ref_model)
if args.grad_checkpoint:
- # Note, for some models, lora may not be compatible with gradient checkpointing
+ # Make sure gradient checkpointing can be activated.
+ model.train()
+ # Note, for some models, lora may not be compatible with gradient checkpointing.
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
coordinator.print_on_master(msg="Gradient checkpointing enabled successfully")
@@ -169,7 +166,7 @@ def train(args):
adamw_mode=True,
)
- # configure dataset
+ # Configure dataset
coordinator.print_on_master(f"Load dataset: {args.dataset}")
mode_map = {"train": "train", "valid": "validation", "test": "test"}
train_dataset = load_tokenized_dataset(dataset_paths=args.dataset, mode="train", mode_map=mode_map)
@@ -213,14 +210,15 @@ def train(args):
default_dtype = torch.float16 if args.mixed_precision == "fp16" else torch.bfloat16
torch.set_default_dtype(default_dtype)
+
model, optim, _, train_dataloader, lr_scheduler = booster.boost(
model=model,
optimizer=optim,
lr_scheduler=lr_scheduler,
dataloader=train_dataloader,
)
- if ref_model is not None:
- ref_model, _, _, _, _ = ref_booster.boost(model=ref_model, dataloader=train_dataloader)
+ ref_model, _, _, _, _ = ref_booster.boost(model=ref_model)
+
torch.set_default_dtype(torch.float)
coordinator.print_on_master(f"Booster init max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB")
@@ -312,7 +310,7 @@ if __name__ == "__main__":
"--plugin",
type=str,
default="gemini",
- choices=["gemini", "gemini_auto", "zero2", "zero2_cpu", "3d"],
+ choices=["gemini", "zero2", "zero2_cpu", "3d", "ddp"],
help="Choose which plugin to use",
)
parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping value")
@@ -342,22 +340,35 @@ if __name__ == "__main__":
parser.add_argument("--max_length", type=int, default=2048, help="Model max length")
parser.add_argument("--max_epochs", type=int, default=3)
parser.add_argument("--batch_size", type=int, default=4)
+ parser.add_argument("--disable_loss_mask", default=False, action="store_true")
+ parser.add_argument("--mixed_precision", type=str, default="fp16", choices=["fp16", "bf16"], help="Mixed precision")
+ parser.add_argument("--lora_config", type=str, default=None, help="low-rank adaptation config file path")
+ parser.add_argument("--save_interval", type=int, default=1000, help="number of step between two checkpoints")
+ parser.add_argument("--lr", type=float, default=5e-6)
+ parser.add_argument("--accumulation_steps", type=int, default=1)
+ parser.add_argument("--log_dir", default=None, type=str)
+ parser.add_argument("--use_wandb", default=False, action="store_true")
+ parser.add_argument("--grad_checkpoint", default=False, action="store_true")
+ parser.add_argument("--use_flash_attn", default=False, action="store_true")
+ parser.add_argument(
+ "--microbatch_size",
+ type=int,
+ default=2,
+        help="Micro batch size for PP training. To enable PP training for DPO-like algorithms, the micro batch size must be an even number of at least 2.",
+ )
+ # Parameter for reference model
parser.add_argument(
"--disable_reference_model",
action="store_true",
default=False,
help="Disable the reference model (enabled by default)",
)
- parser.add_argument("--disable_loss_mask", default=False, action="store_true")
- parser.add_argument("--mixed_precision", type=str, default="fp16", choices=["fp16", "bf16"], help="Mixed precision")
- parser.add_argument("--lora_config", type=str, default=None, help="low-rank adaptation config file path")
- parser.add_argument("--save_interval", type=int, default=1000, help="number of step between two checkpoints")
- parser.add_argument("--lr", type=float, default=5e-6)
- parser.add_argument("--accumulation_steps", type=int, default=8)
- parser.add_argument("--log_dir", default=None, type=str)
- parser.add_argument("--use_wandb", default=False, action="store_true")
- parser.add_argument("--grad_checkpoint", default=False, action="store_true")
- parser.add_argument("--use_flash_attn", default=False, action="store_true")
+ parser.add_argument(
+ "--ref_tp",
+ type=int,
+ default=1,
+        help="TP size for the reference model; used only when the reference model is too large.",
+ )
args = parser.parse_args()
# fool proof hyperparameter setup
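
The help text for `--microbatch_size` states that PP training for DPO-like algorithms needs an even micro batch size of at least 2, presumably so each pipeline micro-batch keeps whole chosen/reject pairs. A sanity-check sketch of that assumed constraint (not part of the script itself):

```python
# Hedged sketch: sanity-check the assumed micro batch size constraint for PP DPO training.
def check_microbatch_size(microbatch_size: int) -> None:
    if microbatch_size < 2 or microbatch_size % 2 != 0:
        raise ValueError(
            "microbatch_size must be an even number >= 2 so that chosen/reject "
            "pairs stay together within each pipeline micro-batch"
        )

check_microbatch_size(2)  # ok
check_microbatch_size(4)  # ok
```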
diff --git a/applications/ColossalChat/examples/training_scripts/train_sft.py b/applications/ColossalChat/examples/training_scripts/train_sft.py
index 62acad32f..e319340c3 100755
--- a/applications/ColossalChat/examples/training_scripts/train_sft.py
+++ b/applications/ColossalChat/examples/training_scripts/train_sft.py
@@ -68,7 +68,7 @@ def train(args):
        Default torch ddp plugin without any acceleration, for
        debugging purposes
"""
- plugin = TorchDDPPlugin(find_unused_parameters=True if args.grad_checkpoint is False else False)
+ plugin = TorchDDPPlugin(find_unused_parameters=not args.grad_checkpoint)
elif args.plugin == "gemini":
plugin = GeminiPlugin(
precision=args.mixed_precision,
diff --git a/applications/ColossalChat/tests/test_templating.sh b/applications/ColossalChat/tests/test_templating.sh
index 6ee10e8be..defe6f71b 100755
--- a/applications/ColossalChat/tests/test_templating.sh
+++ b/applications/ColossalChat/tests/test_templating.sh
@@ -4,7 +4,7 @@ BASE_TEMP_DIR=$BASE_DIR/temp
EXAMPLES_DIR=$BASE_DIR/examples
TEST_DATA_DIR=$BASE_DIR/tests/test_data
DATA_SAVE_PATH=$BASE_TEMP_DIR/tests
-CONFIG_DIR=$BASE_DIR/config
+CONFIG_DIR=$BASE_DIR/conversation_template
# MODELS=("colossal-llama2" "llama2" "mistral" "chatGLM2" "chatGLM3" "deepseek" "Yi" "baichuan") # for local test
MODELS=("colossal-llama2" "llama2" "chatGLM2" "chatGLM3" "deepseek" "Yi")
@@ -39,23 +39,23 @@ get_pretrain() {
get_conversation_template_config() {
local model=$1
if [[ $model == "colossal-llama2" ]]; then
- echo "$CONFIG_DIR/conversation_template/colossal-llama2.json"
+ echo "$CONFIG_DIR/colossal-llama2.json"
elif [[ $model == "llama2" ]]; then
- echo "$CONFIG_DIR/conversation_template/llama2.json"
+ echo "$CONFIG_DIR/llama2.json"
elif [[ $model == "deepseek" ]]; then
- echo "$CONFIG_DIR/conversation_template/deepseek-ai_DeepSeek-V2-Lite.json"
+ echo "$CONFIG_DIR/deepseek-ai_DeepSeek-V2-Lite.json"
elif [[ $model == "mistral" ]]; then
- echo "$CONFIG_DIR/conversation_template/mistralai_Mixtral-8x7B-Instruct-v0.1.json"
+ echo "$CONFIG_DIR/mistralai_Mixtral-8x7B-Instruct-v0.1.json"
elif [[ $model == "chatGLM2" ]]; then
- echo "$CONFIG_DIR/conversation_template/THUDM_chatglm2-6b.json"
+ echo "$CONFIG_DIR/THUDM_chatglm2-6b.json"
elif [[ $model == "chatGLM3" ]]; then
- echo "$CONFIG_DIR/conversation_template/THUDM_chatglm3-6b.json"
+ echo "$CONFIG_DIR/THUDM_chatglm3-6b.json"
elif [[ $model == "phi" ]]; then
- echo "$CONFIG_DIR/conversation_template/microsoft_phi-2.json"
+ echo "$CONFIG_DIR/microsoft_phi-2.json"
elif [[ $model == "Yi" ]]; then
- echo "$CONFIG_DIR/conversation_template/01-ai_Yi-1.5-9B-Chat.json"
+ echo "$CONFIG_DIR/01-ai_Yi-1.5-9B-Chat.json"
elif [[ $model == "baichuan" ]]; then
- echo "$CONFIG_DIR/conversation_template/baichuan-inc_Baichuan2-13B-Chat.json"
+ echo "$CONFIG_DIR/baichuan-inc_Baichuan2-13B-Chat.json"
else
echo "Unknown model $model"
exit 1
@@ -71,6 +71,7 @@ for model in ${MODELS[@]}; do
rm -rf $SAVE_DIR/arrow
pretrain=$(get_pretrain $model)
conversation_template_config=$(get_conversation_template_config $model)
+ echo $conversation_template_config
python $EXAMPLES_DIR/data_preparation_scripts/prepare_dataset.py --type sft --data_input_dirs $TEST_DATA_DIR/sft \
--tokenizer_dir $pretrain \
--conversation_template_config $conversation_template_config \
diff --git a/colossalai/accelerator/cuda_accelerator.py b/colossalai/accelerator/cuda_accelerator.py
index f1ab487d4..32e62b33f 100644
--- a/colossalai/accelerator/cuda_accelerator.py
+++ b/colossalai/accelerator/cuda_accelerator.py
@@ -279,4 +279,4 @@ class CudaAccelerator(BaseAccelerator):
"""
Return autocast function
"""
- return torch.cuda.amp.autocast(enabled=enabled, dtype=dtype, cache_enabled=cache_enabled)
+ return torch.amp.autocast(device_type="cuda", enabled=enabled, dtype=dtype, cache_enabled=cache_enabled)
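
The accelerator now uses the device-agnostic `torch.amp.autocast` entry point instead of `torch.cuda.amp.autocast`. A minimal usage sketch, assuming a CUDA device is available:

```python
# Hedged sketch of the device-agnostic autocast entry point (requires a CUDA device).
import torch

x = torch.randn(8, 8, device="cuda")
w = torch.randn(8, 8, device="cuda")
with torch.amp.autocast(device_type="cuda", dtype=torch.float16):
    y = x @ w  # matmul runs in fp16 under autocast
print(y.dtype)  # torch.float16
```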
diff --git a/colossalai/booster/plugin/gemini_plugin.py b/colossalai/booster/plugin/gemini_plugin.py
index ae49aa8b1..4c8258113 100644
--- a/colossalai/booster/plugin/gemini_plugin.py
+++ b/colossalai/booster/plugin/gemini_plugin.py
@@ -322,7 +322,6 @@ class GeminiPlugin(DPPluginBase):
enable_flash_attention (bool, optional): Whether to switch on flash attention in Shardformer. Defaults to False.
enable_jit_fused (bool, optional): Whether to switch on JIT in Shardformer. Default to False.
enable_sequence_parallelism (bool): Whether to turn on sequence parallelism in Shardformer. Defaults to False.
- enable_sequence_overlap (bool): Whether to turn on sequence overlap in Shardformer. Defaults to False.
use_fp8 (bool, optional): Whether to enable fp8 mixed precision training. Defaults to False.
verbose (bool, optional): verbose mode. Debug info including chunk search result will be printed. Defaults to False.
fp8_communication (bool, optional): Whether to enable fp8 communication. Defaults to False.
@@ -366,7 +365,6 @@ class GeminiPlugin(DPPluginBase):
enable_flash_attention: bool = False,
enable_sequence_parallelism: bool = False,
enable_jit_fused: bool = False,
- enable_sequence_overlap: bool = False,
enable_async_reduce: bool = True,
use_fp8: bool = False,
verbose: bool = False,
@@ -428,7 +426,6 @@ class GeminiPlugin(DPPluginBase):
self.enable_flash_attention = enable_flash_attention
self.enable_sequence_parallelism = enable_sequence_parallelism if self.enable_tensor_parallelism else False
self.enable_jit_fused = enable_jit_fused
- self.enable_sequence_overlap = enable_sequence_overlap
self.verbose = verbose
self.tp_size = tp_size
@@ -455,7 +452,6 @@ class GeminiPlugin(DPPluginBase):
enable_flash_attention=self.enable_flash_attention,
enable_jit_fused=self.enable_jit_fused,
enable_sequence_parallelism=self.enable_sequence_parallelism,
- enable_sequence_overlap=self.enable_sequence_overlap,
)
def __del__(self):
diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py
index caeed5457..58d055bb0 100644
--- a/colossalai/booster/plugin/hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -116,10 +116,15 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin):
super().__init__(module)
self.op_hooks = []
+ if use_fp8:
+ self.op_hooks.append(FP8Hook())
if overlap_allgather:
self.op_hooks.append(ZeroOpHook())
if use_fp8 or overlap_allgather:
for p in module.parameters():
if p.requires_grad and type(p) is not ColoParameter:
@@ -232,6 +237,9 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin):
+ def _hook_context(self):
+ return ColoParamOpHookManager.use_hooks(*self.op_hooks) if len(self.op_hooks) > 0 else nullcontext()
+
def get_param_info(optim: Optimizer):
# Get a backup of necessary information of parameters for future use, which includes:
@@ -951,7 +959,6 @@ class HybridParallelPlugin(PipelinePluginBase):
enable_jit_fused (bool, optional): Whether to switch on JIT in Shardformer. Default to False.
enable_sequence_parallelism (bool): Whether to turn on sequence parallelism in Shardformer. Defaults to False.
sequence_parallelism_mode (str): The sequence parallelism mode. Can only be chosen from ["split_gather", "ring", "all_to_all"]. Defaults to "split_gather".
- enable_sequence_overlap (bool): Whether to turn on sequence overlap in Shardformer. Defaults to False.
parallel_output (bool): Whether to keep the output parallel when enabling tensor parallelism. Default to True.
num_microbatches (int, optional): Number of microbatches when using pipeline parallelism. Defaults to None.
microbatch_size (int, optional): Microbatch size when using pipeline parallelism.
@@ -983,6 +990,8 @@ class HybridParallelPlugin(PipelinePluginBase):
make_vocab_size_divisible_by (int, optional): it's used when padding the vocabulary size, to make it choose a faster kernel. Defaults to 64.
+ fp8_communication (bool, optional): Whether to enable fp8 communication. Defaults to False.
+ use_fp8 (bool, optional): Whether to enable fp8 mixed precision training. Defaults to False.
overlap_p2p (bool, optional): Whether to overlap the p2p communication in pipeline parallelism
inner_ring_size (int, optional): The inner ring size of 2D Ring Attention when sp mode is "ring_attn".
It's advisable to not tune this (especially in single-node settings) and let it be heuristically set based on topology by default.
@@ -1002,7 +1011,6 @@ class HybridParallelPlugin(PipelinePluginBase):
enable_jit_fused: bool = False,
enable_sequence_parallelism: bool = False,
sequence_parallelism_mode: str = None,
- enable_sequence_overlap: bool = False,
parallel_output: bool = True,
num_microbatches: Optional[int] = None,
microbatch_size: Optional[int] = None,
@@ -1092,6 +1100,7 @@ class HybridParallelPlugin(PipelinePluginBase):
self.use_fp8 = use_fp8
if dp_outside:
self.dp_axis, self.pp_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3
+ self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size, self.sp_size)
if sequence_parallelism_mode == "ring_attn":
# Swap tp and sp since 2D Ring has better inter-node latency
self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.sp_size, self.tp_size)
@@ -1195,13 +1204,15 @@ class HybridParallelPlugin(PipelinePluginBase):
enable_jit_fused=self.enable_jit_fused,
enable_sequence_parallelism=enable_sequence_parallelism,
sequence_parallelism_mode=sequence_parallelism_mode,
- enable_sequence_overlap=enable_sequence_overlap,
parallel_output=parallel_output,
make_vocab_size_divisible_by=make_vocab_size_divisible_by,
gradient_checkpoint_config=gradient_checkpoint_config,
fp8_communication=fp8_communication,
inner_ring_size=inner_ring_size,
+ pg_mesh=self.pg_mesh,
+ sp_axis=self.sp_axis,
)
+
self.amp_config = dict(
initial_scale=initial_scale,
growth_factor=growth_factor,
@@ -1293,6 +1304,7 @@ class HybridParallelPlugin(PipelinePluginBase):
self.dp_size == 1 and self.pp_size == 1
)
+ # sync gradients across DP * SP ranks
# Apply Hybrid ZeRO across DP * SP ranks
if self.enable_sequence_parallelism and not is_share_sp_tp(self.sequence_parallelism_mode):
dp_group = self.pg_mesh.create_group_along_axis([self.dp_axis, self.sp_axis])
diff --git a/colossalai/booster/plugin/low_level_zero_plugin.py b/colossalai/booster/plugin/low_level_zero_plugin.py
index b167b5c7a..f3a6901ad 100644
--- a/colossalai/booster/plugin/low_level_zero_plugin.py
+++ b/colossalai/booster/plugin/low_level_zero_plugin.py
@@ -290,7 +290,11 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
assert isinstance(
peft_model, PeftModel
), "The model doesn't have lora adapters, please enable lora before saving."
- return peft_model.save_pretrained(checkpoint, safe_serialization=use_safetensors)
+ return peft_model.save_pretrained(
+ checkpoint,
+ safe_serialization=use_safetensors,
+ state_dict=tree_map(lambda x: x.data if torch.is_tensor(x) else x, peft_model.state_dict()),
+ )
class LowLevelZeroPlugin(DPPluginBase):
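The LoRA saving change above (repeated for the DDP and hybrid-parallel checkpoint IO below) passes an explicit `state_dict` to `save_pretrained`, with every tensor leaf replaced by its plain `.data` view, presumably so that raw, grad-free tensors are handed to serialization instead of live parameter wrappers. A minimal standalone sketch of the pattern; the tiny `nn.Linear` model is illustrative, not ColossalAI code:

    import torch
    import torch.nn as nn
    from torch.utils._pytree import tree_map

    model = nn.Linear(4, 4)

    # state_dict(keep_vars=True) keeps the live nn.Parameter objects (requires_grad=True);
    # mapping `.data` over the tree yields plain, grad-free tensors that still share
    # storage with the parameters, so no extra copy is made.
    state = model.state_dict(keep_vars=True)
    plain_state = tree_map(lambda x: x.data if torch.is_tensor(x) else x, state)

    assert isinstance(state["weight"], nn.Parameter) and state["weight"].requires_grad
    assert type(plain_state["weight"]) is torch.Tensor and not plain_state["weight"].requires_grad
    assert plain_state["weight"].data_ptr() == model.weight.data_ptr()
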
diff --git a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
index 8b62a1e2b..96531a04f 100644
--- a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
@@ -141,7 +141,6 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
enable_jit_fused (bool, optional): Whether to switch on JIT in Shardformer. Default to False.
enable_sequence_parallelism (bool): Whether to turn on sequence parallelism in Shardformer. Defaults to False.
sequence_parallelism_mode (str): The sequence parallelism mode. Can only be chosen from ["split_gather", "ring", "all_to_all"]. Defaults to "split_gather".
- enable_sequence_overlap (bool): Whether to turn on sequence overlap in Shardformer. Defaults to False.
parallel_output (bool): Whether to keep the output parallel when enabling tensor parallelism. Default to True.
num_microbatches (int, optional): Number of microbatches when using pipeline parallelism. Defaults to None.
microbatch_size (int, optional): Microbatch size when using pipeline parallelism.
@@ -190,7 +189,6 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
enable_jit_fused: bool = False,
enable_sequence_parallelism: bool = False,
sequence_parallelism_mode: str = None,
- enable_sequence_overlap: bool = False,
parallel_output: bool = True,
num_microbatches: Optional[int] = None,
microbatch_size: Optional[int] = None,
@@ -368,7 +366,6 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
enable_jit_fused=self.enable_jit_fused,
enable_sequence_parallelism=enable_sequence_parallelism,
sequence_parallelism_mode=sequence_parallelism_mode,
- enable_sequence_overlap=enable_sequence_overlap,
parallel_output=parallel_output,
make_vocab_size_divisible_by=make_vocab_size_divisible_by,
gradient_checkpoint_config=gradient_checkpoint_config,
diff --git a/colossalai/booster/plugin/torch_ddp_plugin.py b/colossalai/booster/plugin/torch_ddp_plugin.py
index ec7ce7f9a..156a4acf9 100644
--- a/colossalai/booster/plugin/torch_ddp_plugin.py
+++ b/colossalai/booster/plugin/torch_ddp_plugin.py
@@ -1,9 +1,11 @@
from typing import Callable, Dict, Iterator, List, Optional, Tuple, Union
+import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
+from torch.utils._pytree import tree_map
from torch.utils.data import DataLoader
from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO
@@ -134,7 +136,11 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
assert isinstance(
peft_model, PeftModel
), "The model doesn't have lora adapters, please enable lora before saving."
- peft_model.save_pretrained(save_directory=checkpoint, safe_serialization=use_safetensors)
+ return peft_model.save_pretrained(
+ checkpoint,
+ safe_serialization=use_safetensors,
+ state_dict=tree_map(lambda x: x.data if torch.is_tensor(x) else x, peft_model.state_dict()),
+ )
class TorchDDPModel(ModelWrapper):
diff --git a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
index 3b6917d32..79bb33dca 100644
--- a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
+++ b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
@@ -11,6 +11,7 @@ import torch.distributed as dist
import torch.nn as nn
from torch.distributed import ProcessGroup
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
+from torch.utils._pytree import tree_map
from colossalai.cluster import DistCoordinator
from colossalai.interface import ModelWrapper, OptimizerWrapper
@@ -20,7 +21,7 @@ from colossalai.tensor.padded_tensor import (
to_padded_tensor,
to_unpadded_tensor,
)
-from colossalai.utils import get_current_device
+from colossalai.utils import get_current_device, get_non_persistent_buffers_set
from .general_checkpoint_io import GeneralCheckpointIO
from .index_file import CheckpointIndexFile
@@ -104,8 +105,9 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
yield block, block_size
# Save buffers.
+ non_persist_buffers_set = get_non_persistent_buffers_set(model)
for name, buf in model.named_buffers():
- if buf is not None and name not in model._non_persistent_buffers_set:
+ if buf is not None and name not in non_persist_buffers_set:
buffer = buf if keep_vars else buf.detach()
block, block_size = state_dict_sharder.append_param(prefix + name, buffer)
if block is not None:
@@ -351,9 +353,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
_load(name)
# Load buffers.
- non_persistent_buffers = set()
- for n, m in model.named_modules():
- non_persistent_buffers |= set(".".join((n, b)) for b in m._non_persistent_buffers_set)
+ non_persistent_buffers = get_non_persistent_buffers_set(model)
for name, buf in model.named_buffers():
if buf is not None and name not in non_persistent_buffers:
_load(name)
@@ -956,4 +956,8 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
assert isinstance(
peft_model, PeftModel
), "The model doesn't have lora adapters, please enable lora before saving."
- return peft_model.save_pretrained(checkpoint, safe_serialization=use_safetensors)
+ return peft_model.save_pretrained(
+ checkpoint,
+ safe_serialization=use_safetensors,
+ state_dict=tree_map(lambda x: x.data if torch.is_tensor(x) else x, peft_model.state_dict()),
+ )
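`get_non_persistent_buffers_set` replaces the inline loop the load path used before (and fixes the save path, which previously consulted only the top-level module's `_non_persistent_buffers_set` and therefore missed non-persistent buffers registered on submodules). A sketch of what such a helper does, following the removed inline code; the actual `colossalai.utils` implementation may differ in details:

    from typing import Set

    import torch
    import torch.nn as nn

    def collect_non_persistent_buffer_names(model: nn.Module) -> Set[str]:
        """Fully-qualified names of buffers registered with persistent=False."""
        names: Set[str] = set()
        for module_name, module in model.named_modules():
            for buf_name in module._non_persistent_buffers_set:
                names.add(f"{module_name}.{buf_name}" if module_name else buf_name)
        return names

    class Demo(nn.Module):
        def __init__(self):
            super().__init__()
            self.inner = nn.Linear(2, 2)
            self.inner.register_buffer("scratch", torch.zeros(2), persistent=False)
            self.register_buffer("kept", torch.ones(2))  # persistent: should still be saved

    assert collect_non_persistent_buffer_names(Demo()) == {"inner.scratch"}
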
diff --git a/colossalai/inference/modeling/policy/nopadding_baichuan.py b/colossalai/inference/modeling/policy/nopadding_baichuan.py
index 37b5062e8..8528de75c 100644
--- a/colossalai/inference/modeling/policy/nopadding_baichuan.py
+++ b/colossalai/inference/modeling/policy/nopadding_baichuan.py
@@ -57,7 +57,9 @@ class NoPaddingBaichuanModelInferPolicy(LlamaForCausalLMPolicy, RPC_PARAM):
target_module=NopadBaichuanMLP,
),
SubModuleReplacementDescription(
- suffix="self_attn.W_pack", target_module=FusedLinear1D_Col, kwargs={"n_fused": 3}
+ suffix="self_attn.W_pack",
+ target_module=FusedLinear1D_Col,
+ kwargs={"split_sizes": [self.model.config.hidden_size] * 3},
),
SubModuleReplacementDescription(
suffix="self_attn.o_proj",
diff --git a/colossalai/kernel/jit/option.py b/colossalai/kernel/jit/option.py
index d392649a6..1ee93e4e0 100644
--- a/colossalai/kernel/jit/option.py
+++ b/colossalai/kernel/jit/option.py
@@ -1,7 +1,6 @@
import torch
from colossalai.accelerator import get_accelerator
-from colossalai.legacy.nn.layer.colossalai_layer import Embedding, Linear
from .bias_dropout_add import bias_dropout_add_fused_train
from .bias_gelu import bias_gelu_impl
@@ -45,6 +44,7 @@ def warmup_jit_fusion(
dtype: torch.dtype = torch.float32,
):
"""Compile JIT functions before the main training steps"""
+ from colossalai.legacy.nn.layer.colossalai_layer import Embedding, Linear
embed = Embedding(vocab_size, hidden_size).to(get_accelerator().get_current_device())
linear_1 = Linear(hidden_size, hidden_size * 4, skip_bias_add=True).to(get_accelerator().get_current_device())
diff --git a/colossalai/pipeline/schedule/_utils.py b/colossalai/pipeline/schedule/_utils.py
index b641eb364..ded75d968 100644
--- a/colossalai/pipeline/schedule/_utils.py
+++ b/colossalai/pipeline/schedule/_utils.py
@@ -3,8 +3,9 @@ from typing import Any, List, Optional, Tuple
import torch
import torch.cuda
+from packaging.version import Version
from torch.nn import Module
-from torch.utils._pytree import SUPPORTED_NODES, TreeSpec, _register_pytree_node, tree_flatten, tree_map, tree_unflatten
+from torch.utils._pytree import SUPPORTED_NODES, TreeSpec, tree_flatten, tree_map, tree_unflatten
# this registration is for torch 1.13.1 and earlier; it may be removed in the future
@@ -16,7 +17,12 @@ def _odict_unflatten(values: List[Any], context: Any) -> "OrderedDict[Any, Any]"
return OrderedDict((key, value) for key, value in zip(context, values))
-_register_pytree_node(OrderedDict, _odict_flatten, _odict_unflatten)
+if Version(torch.__version__) <= Version("1.13.1"):
+ try:
+ from torch.utils._pytree import register_pytree_node as _register_pytree_node
+ except ImportError:
+ from torch.utils._pytree import _register_pytree_node
+ _register_pytree_node(OrderedDict, _odict_flatten, _odict_unflatten)
def tree_map_hf(fn: Any, pytree: Any):
diff --git a/colossalai/pipeline/schedule/interleaved_pp.py b/colossalai/pipeline/schedule/interleaved_pp.py
index c538ee071..5da98364d 100644
--- a/colossalai/pipeline/schedule/interleaved_pp.py
+++ b/colossalai/pipeline/schedule/interleaved_pp.py
@@ -351,15 +351,16 @@ class InterleavedSchedule(PipelineSchedule):
if output_obj_grad is None:
optimizer.backward(output_obj)
else:
- if "backward_tensor_keys" not in output_obj:
- for k, grad in output_obj_grad.items():
- optimizer.backward_by_grad(output_obj[k], grad)
+ keys = output_obj.get("backward_tensor_keys", output_obj_grad.keys())
+ tensors_to_backward = []
+ grads_to_backward = []
+ for k in keys:
+ tensors_to_backward.append(output_obj[k])
+ grads_to_backward.append(output_obj_grad[k])
+ if len(tensors_to_backward) == 1:
+ optimizer.backward_by_grad(tensors_to_backward[0], grads_to_backward[0])
else:
- for k, grad in output_obj_grad.items():
- output_obj[k].grad = grad
- for k in output_obj["backward_tensor_keys"]:
- tensor_to_backward = output_obj[k]
- optimizer.backward_by_grad(tensor_to_backward, tensor_to_backward.grad)
+ optimizer.backward_by_grad(tensors_to_backward, grads_to_backward)
# Collect the grad of the input_obj.
input_obj_grad = None
diff --git a/colossalai/pipeline/schedule/one_f_one_b.py b/colossalai/pipeline/schedule/one_f_one_b.py
index 0fc90995a..224d63688 100644
--- a/colossalai/pipeline/schedule/one_f_one_b.py
+++ b/colossalai/pipeline/schedule/one_f_one_b.py
@@ -305,15 +305,16 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
if output_obj_grad is None:
optimizer.backward(output_obj)
else:
- if "backward_tensor_keys" not in output_obj:
- for k, grad in output_obj_grad.items():
- optimizer.backward_by_grad(output_obj[k], grad)
+ keys = output_obj.get("backward_tensor_keys", output_obj_grad.keys())
+ tensors_to_backward = []
+ grads_to_backward = []
+ for k in keys:
+ tensors_to_backward.append(output_obj[k])
+ grads_to_backward.append(output_obj_grad[k])
+ if len(tensors_to_backward) == 1:
+ optimizer.backward_by_grad(tensors_to_backward[0], grads_to_backward[0])
else:
- for k, grad in output_obj_grad.items():
- output_obj[k].grad = grad
- for k in output_obj["backward_tensor_keys"]:
- tensor_to_backward = output_obj[k]
- optimizer.backward_by_grad(tensor_to_backward, tensor_to_backward.grad)
+ optimizer.backward_by_grad(tensors_to_backward, grads_to_backward)
# Collect the grad of the input_obj.
input_obj_grad = None
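Both pipeline schedules now collapse the per-key loop into one `backward_by_grad` call: the keys come from `backward_tensor_keys` when the stage output declares them, otherwise from the received grad dict, and the tensors/grads are passed as lists so the shared graph is traversed once. A plain-PyTorch sketch of why a single multi-output backward is preferable to several sequential ones (the toy graph is illustrative; `backward_by_grad` itself is an optimizer-wrapper method, not shown here):

    import torch

    x = torch.randn(4, requires_grad=True)
    shared = x * 2                      # subgraph shared by both outputs
    out_a, out_b = shared.sum(), (shared ** 2).sum()

    # One call with lists: gradients from both outputs are accumulated into x.grad
    # in a single traversal. Two sequential .backward() calls on the same graph
    # would additionally need retain_graph=True on the first call.
    torch.autograd.backward([out_a, out_b], [torch.tensor(1.0), torch.tensor(1.0)])
    print(x.grad)
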
diff --git a/colossalai/quantization/fp8.py b/colossalai/quantization/fp8.py
index 8243a29ac..e23da5ccc 100644
--- a/colossalai/quantization/fp8.py
+++ b/colossalai/quantization/fp8.py
@@ -8,6 +8,8 @@ import torch.nn.functional as F
from packaging.version import Version
from torch.distributed import ReduceOp
+from .fp8_config import dynamic_kernel
+
SUPPORT_TORCH_COMPILE = Version(torch.__version__) >= Version("2.4.0")
SCALE_BYTES = 4
try:
@@ -832,11 +834,13 @@ class _LinearFp8(torch.autograd.Function):
return x_grad.reshape(ctx.x_shape), w_grad, bias_grad
-@torch.compile(mode="max-autotune-no-cudagraphs", disable=not SUPPORT_TORCH_COMPILE, dynamic=False)
+@torch.compile(mode="max-autotune-no-cudagraphs", disable=not SUPPORT_TORCH_COMPILE, dynamic=dynamic_kernel)
def _linear_fp8(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
return _LinearFp8.apply(input, weight, bias)
def linear_fp8(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+ if input.shape[-1] % 16 != 0 or np.prod(input.shape[:-1]) % 16 != 0:
+ return F.linear(input, weight, bias)
out = _linear_fp8(input, weight, bias)
return out
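The added guard falls back to a plain `F.linear` whenever the flattened input is not 16-aligned, since the FP8 GEMM path (ultimately a scaled matmul on cuBLASLt) only accepts matrix dimensions that are multiples of 16. A small sketch of the same dispatch in isolation; `fp8_matmul_stub` is a stand-in, not the real kernel:

    import math
    from typing import Optional

    import torch
    import torch.nn.functional as F

    def fp8_matmul_stub(x: torch.Tensor, w: torch.Tensor, b: Optional[torch.Tensor]) -> torch.Tensor:
        # Placeholder for the quantize -> scaled-GEMM -> dequantize path.
        return F.linear(x, w, b)

    def linear_with_fp8_fallback(x, w, b=None):
        rows = math.prod(x.shape[:-1])       # tokens after flattening the leading dims
        if x.shape[-1] % 16 != 0 or rows % 16 != 0:
            return F.linear(x, w, b)         # shapes the FP8 kernel cannot handle
        return fp8_matmul_stub(x, w, b)

    y = linear_with_fp8_fallback(torch.randn(2, 7, 30), torch.randn(5, 30))
    assert y.shape == (2, 7, 5)
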
diff --git a/colossalai/quantization/fp8_config.py b/colossalai/quantization/fp8_config.py
new file mode 100644
index 000000000..efa625185
--- /dev/null
+++ b/colossalai/quantization/fp8_config.py
@@ -0,0 +1 @@
+dynamic_kernel: bool = False
diff --git a/colossalai/shardformer/layer/__init__.py b/colossalai/shardformer/layer/__init__.py
index 4fc714e57..da5363840 100644
--- a/colossalai/shardformer/layer/__init__.py
+++ b/colossalai/shardformer/layer/__init__.py
@@ -6,7 +6,7 @@ from .linear import Linear1D_Col, Linear1D_Row, LinearWithGradAccum, PaddingLMHe
from .loss import cross_entropy_1d, dist_cross_entropy
from .normalization import FusedLayerNorm, FusedRMSNorm, LayerNorm, RMSNorm
from .parallel_module import ParallelModule
-from .qkv_fused_linear import FusedLinear1D_Col, GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row
+from .qkv_fused_linear import FusedLinear1D_Col, FusedLinear1D_Row, GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row
__all__ = [
"Embedding1D",
@@ -35,4 +35,5 @@ __all__ = [
"RingAttention",
"get_pad_info",
"all_to_all_comm",
+ "FusedLinear1D_Row",
]
diff --git a/colossalai/shardformer/layer/_operation.py b/colossalai/shardformer/layer/_operation.py
index 8a068b78c..8c2e6e7c5 100644
--- a/colossalai/shardformer/layer/_operation.py
+++ b/colossalai/shardformer/layer/_operation.py
@@ -106,7 +106,7 @@ class MatmulWithAsyncCommunication(torch.autograd.Function):
grad_output = grad_output.view(-1, grad_output.shape[-1])
total_input = total_input.view(-1, total_input.shape[-1])
- if ctx.async_grad_allreduce and fp8_communication:
+ if fp8_communication or not ctx.async_grad_allreduce:
_reduce(grad_input, group=ctx.process_group, fp8_communication=fp8_communication, fp8_format="e5m2")
elif ctx.async_grad_allreduce:
# Asynchronous all-reduce
@@ -364,10 +364,12 @@ def _ring_as_gather(func, input_to_gather=None, input_local=None, process_group=
for k in recv_tensors:
send_tensors[k], recv_tensors[k] = recv_tensors[k], send_tensors[k]
+ input_tensors = []
output_tensors = []
handles = communicate_step()
# first round: special case, retrieve from local tensor
+ input_tensors.append(input_to_gather)
output_tensors.append(func(**input_to_gather, **input_local))
for i in range(group_size - 2):
for handle in handles:
@@ -378,14 +380,25 @@ def _ring_as_gather(func, input_to_gather=None, input_local=None, process_group=
handles = communicate_step()
# actual computation
+ input_tensors.append(send_tensors)
output_tensors.append(func(**send_tensors, **input_local))
# final round: special case, no need to send/recv again
for handle in handles:
handle.wait()
+ input_tensors.append(send_tensors)
output_tensors.append(func(**recv_tensors, **input_local))
- return torch.cat(output_tensors[group_size - cur_rank :] + output_tensors[: group_size - cur_rank], dim=gather_dim)
+ gathered_input = {}
+ for k in input_to_gather:
+ input_shards = [d[k] for d in input_tensors[group_size - cur_rank :] + input_tensors[: group_size - cur_rank]]
+ gathered_input[k] = torch.cat(input_shards, dim=gather_dim)
+
+ gathered_output = torch.cat(
+ output_tensors[group_size - cur_rank :] + output_tensors[: group_size - cur_rank], dim=gather_dim
+ )
+
+ return gathered_output, gathered_input
class _GatherForwardReduceScatterBackward(torch.autograd.Function):
@@ -441,29 +454,30 @@ class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function):
"""
@staticmethod
- def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap=True, ring=False):
+ def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, ring=False):
ctx.save_for_backward(input_, weight, bias)
ctx.use_bias = bias is not None
ctx.process_group = process_group
ctx.async_grad_reduce_scatter = async_grad_reduce_scatter
ctx.dim = dim
- ctx.overlap = overlap
if ring is True:
input_to_gather = {"input": input_}
input_local = {"weight": weight}
- output = _ring_as_gather(
+ output, input_dict = _ring_as_gather(
F.linear,
input_to_gather=input_to_gather,
input_local=input_local,
process_group=process_group,
)
+ ctx.gathered_input = input_dict["input"]
if bias is not None:
output += bias
else:
input_parallel = _gather(input_, dim, process_group)
+ ctx.gathered_input = input_parallel
if bias is not None:
output = F.linear(input_parallel, weight, bias)
else:
@@ -477,100 +491,50 @@ class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function):
use_bias = ctx.use_bias
dim = ctx.dim
process_group = ctx.process_group
- overlap = ctx.overlap
# In order to be hooked into Gemini's '__torch_function__', adding a view operation to weight and bias. Used in FusedLayerNorm
if use_bias:
bias = bias.view(bias.shape)
- if not overlap:
- input_parallel = _gather(input_, dim, process_group)
+ input_parallel = ctx.gathered_input
- total_input = input_parallel
- grad_input = grad_output.matmul(weight)
- grad_output = grad_output.contiguous()
- # Convert the tensor shapes to 2D for execution compatibility
- if len(grad_output.shape) > 2:
- grad_output = grad_output.view(-1, grad_output.shape[-1])
- total_input = total_input.view(-1, total_input.shape[-1])
+ total_input = input_parallel
+ grad_input = grad_output.matmul(weight)
+ grad_output = grad_output.contiguous()
+ # Convert the tensor shapes to 2D for execution compatibility
+ if len(grad_output.shape) > 2:
+ grad_output = grad_output.view(-1, grad_output.shape[-1])
+ total_input = total_input.view(-1, total_input.shape[-1])
- if ctx.async_grad_reduce_scatter:
- # Asynchronous reduce-scatter
- input_list = [
- item.contiguous() for item in torch.chunk(grad_input, dist.get_world_size(process_group), dim=dim)
- ]
- output = torch.empty(
- input_.shape, dtype=input_parallel.dtype, device=input_parallel.device
- ).contiguous()
- handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True)
- # Rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to have
- # all-reduce scheduled first and have GPU resources allocated, CUDA_DEVICE_MAX_CONNECTIONS=1 is set in shardformer.py
-
- if _grad_accum_fusion_available and weight.grad is not None:
- grad = weight.grad
- if grad.dtype == torch.float32:
- fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, grad)
- grad_weight = None
- elif grad.dtype == torch.float16:
- fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(total_input, grad_output, grad)
- grad_weight = None
- else:
- grad_weight = grad_output.t().matmul(total_input)
- else:
- grad_weight = grad_output.t().matmul(total_input)
-
- grad_bias = grad_output.sum(dim=0) if use_bias else None
-
- if ctx.async_grad_reduce_scatter:
- handle.wait()
-
- else:
- input_ = input_.contiguous()
- world_size = dist.get_world_size(process_group)
- tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
-
- # do all gather in is async way
- gather_handle = dist.all_gather(tensor_list, input_, group=process_group, async_op=True)
- # calculate gradient and prepare data asynchronously with all-gather
- # calculate
- grad_input = grad_output.matmul(weight)
- grad_output = grad_output.contiguous()
- # Convert the tensor shapes to 2D for execution compatibility
- if len(grad_output.shape) > 2:
- grad_output = grad_output.view(-1, grad_output.shape[-1])
- grad_bias = grad_output.sum(dim=0) if use_bias else None
- # prepare data
+ if ctx.async_grad_reduce_scatter:
+ # Asynchronous reduce-scatter
input_list = [
item.contiguous() for item in torch.chunk(grad_input, dist.get_world_size(process_group), dim=dim)
]
- output = torch.empty(input_.shape, dtype=input_.dtype, device=input_.device).contiguous()
- # wait until all-gather finished
- gather_handle.wait()
+ output = torch.empty(input_.shape, dtype=input_parallel.dtype, device=input_parallel.device).contiguous()
+ handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True)
+ # Rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to have
+ # all-reduce scheduled first and have GPU resources allocated, CUDA_DEVICE_MAX_CONNECTIONS=1 is set in shardformer.py
- # do reduce-scatter in async way
- reducescatter_handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True)
- input_parallel = torch.cat(tensor_list, dim=dim).contiguous()
- # calculate gradient
- if len(input_parallel.shape) > 2:
- input_parallel = input_parallel.view(-1, input_parallel.shape[-1])
-
- if _grad_accum_fusion_available and weight.grad is not None:
- grad = weight.grad
- if grad.dtype == torch.float32:
- fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(input_parallel, grad_output, grad)
- grad_weight = None
- elif grad.dtype == torch.float16:
- fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(input_parallel, grad_output, grad)
- grad_weight = None
- else:
- grad_weight = grad_output.t().matmul(input_parallel)
+ if _grad_accum_fusion_available and weight.grad is not None:
+ grad = weight.grad
+ if grad.dtype == torch.float32:
+ fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, grad)
+ grad_weight = None
+ elif grad.dtype == torch.float16:
+ fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(total_input, grad_output, grad)
+ grad_weight = None
else:
- grad_weight = grad_output.t().matmul(input_parallel)
- # grad_weight = grad_output.t().matmul(input_parallel)
- # wait until reduce-scatter finished
- reducescatter_handle.wait()
+ grad_weight = grad_output.t().matmul(total_input)
+ else:
+ grad_weight = grad_output.t().matmul(total_input)
- return output, grad_weight, grad_bias, None, None, None, None, None
+ grad_bias = grad_output.sum(dim=0) if use_bias else None
+
+ if ctx.async_grad_reduce_scatter:
+ handle.wait()
+
+ return output, grad_weight, grad_bias, None, None, None, None
def _ring_as_reducescatter(
@@ -701,7 +665,7 @@ class _LinearWithReduceScatterForwardGatherBackward(torch.autograd.Function):
# Convert the tensor shapes to 2D for execution compatibility
if len(grad_output.shape) > 2:
grad_output = grad_output.view(-1, grad_output.shape[-1])
- total_input = total_input.view(-1, total_input.shape[-1])
+ total_input = total_input.reshape(-1, total_input.shape[-1])
grad_weight = grad_output.t().matmul(total_input)
grad_bias = grad_output.sum(dim=0) if use_bias else None
@@ -759,34 +723,30 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function):
"""
@staticmethod
- def forward(
- ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring, fp8_communication
- ):
+ def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, ring, fp8_communication):
ctx.save_for_backward(input_, weight, bias)
ctx.use_bias = bias is not None
ctx.process_group = process_group
ctx.async_grad_reduce_scatter = async_grad_reduce_scatter
ctx.dim = dim
- ctx.overlap = overlap
ctx.fp8_communication = fp8_communication
if ring is True:
- input_to_gather = {}
- input_local = {}
- input_to_gather["input"] = input_
- input_local["other"] = weight
+ input_to_gather = {"input": input_}
+ input_local = {"other": weight}
- output = _ring_as_gather(
+ output, input_dict = _ring_as_gather(
torch.matmul,
input_to_gather=input_to_gather,
input_local=input_local,
process_group=process_group,
gather_dim=dim,
)
+ ctx.gathered_input = input_dict["input"]
else:
input_parallel = _gather(input_, dim, process_group, fp8_communication, fp8_format="e4m3")
-
+ ctx.gathered_input = input_parallel
output = torch.matmul(input_parallel, weight)
if bias is not None:
@@ -799,76 +759,39 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function):
use_bias = ctx.use_bias
dim = ctx.dim
process_group = ctx.process_group
- overlap = ctx.overlap
- fp8_communication = ctx.fp8_communication
# In order to be hooked into Gemini's '__torch_function__', adding a view operation to weight and bias. Used in FusedLayerNorm
weight = weight.view(weight.shape)
if use_bias:
bias = bias.view(bias.shape)
- if not overlap:
- input_parallel = _gather(input_, dim, process_group, fp8_communication, fp8_format="e5m2")
+ input_parallel = ctx.gathered_input
- total_input = input_parallel
- grad_input = grad_output.matmul(weight.T)
- grad_output = grad_output.contiguous()
- # Convert the tensor shapes to 2D for execution compatibility
- if len(grad_output.shape) > 2:
- grad_output = grad_output.view(-1, grad_output.shape[-1])
- total_input = total_input.view(-1, total_input.shape[-1])
+ total_input = input_parallel
+ grad_input = grad_output.matmul(weight.T)
+ grad_output = grad_output.contiguous()
+ # Convert the tensor shapes to 2D for execution compatibility
+ if len(grad_output.shape) > 2:
+ grad_output = grad_output.view(-1, grad_output.shape[-1])
+ total_input = total_input.view(-1, total_input.shape[-1])
- if ctx.async_grad_reduce_scatter:
- # Asynchronous reduce-scatter
- input_list = [
- item.contiguous() for item in torch.chunk(grad_input, dist.get_world_size(process_group), dim=dim)
- ]
- output = torch.empty(
- input_.shape, dtype=input_parallel.dtype, device=input_parallel.device
- ).contiguous()
- handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True)
- # Rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to have
- # all-reduce scheduled first and have GPU resources allocated
-
- grad_weight = total_input.t().matmul(grad_output)
- grad_bias = grad_output.sum(dim=0) if use_bias else None
-
- if ctx.async_grad_reduce_scatter:
- handle.wait()
-
- else:
- world_size = dist.get_world_size(process_group)
- tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
-
- # do all gather in is async way
- gather_handle = dist.all_gather(tensor_list, input_, group=process_group, async_op=True)
- # calculate gradient and prepare data asynchronously with all-gather
- # calculate
- grad_input = grad_output.matmul(weight.T)
- grad_output = grad_output.contiguous()
- # Convert the tensor shapes to 2D for execution compatibility
- if len(grad_output.shape) > 2:
- grad_output = grad_output.view(-1, grad_output.shape[-1])
- grad_bias = grad_output.sum(dim=0) if use_bias else None
- # prepare data
+ if ctx.async_grad_reduce_scatter:
+ # Asynchronous reduce-scatter
input_list = [
item.contiguous() for item in torch.chunk(grad_input, dist.get_world_size(process_group), dim=dim)
]
- output = torch.empty(input_.shape, dtype=input_.dtype, device=input_.device).contiguous()
- # wait until all-gather finished
- gather_handle.wait()
+ output = torch.empty(input_.shape, dtype=input_parallel.dtype, device=input_parallel.device).contiguous()
+ handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True)
+ # Rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to have
+ # all-reduce scheduled first and have GPU resources allocated
- # do reduce-scatter in async way
- reducescatter_handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True)
- input_parallel = torch.cat(tensor_list, dim=dim).contiguous()
- # calculate gradient
- if len(input_parallel.shape) > 2:
- input_parallel = input_parallel.view(-1, input_parallel.shape[-1])
- grad_weight = input_parallel.t().matmul(grad_output)
- # wait until reduce-scatter finished
- reducescatter_handle.wait()
+ grad_weight = total_input.t().matmul(grad_output)
+ grad_bias = grad_output.sum(dim=0) if use_bias else None
- return output, grad_weight, grad_bias, None, None, None, None, None, None
+ if ctx.async_grad_reduce_scatter:
+ handle.wait()
+
+ return output, grad_weight, grad_bias, None, None, None, None, None
class _SplitForwardGatherBackward(torch.autograd.Function):
@@ -988,7 +911,7 @@ class _AllToAll(torch.autograd.Function):
ctx.gather_dim = gather_dim
ctx.fp8_communication = fp8_communication
world_size = dist.get_world_size(process_group)
- bsz, _, _ = input_.shape
+ bsz = input_.shape[0]
# using all_to_all_single when batch size is 1
if bsz == 1:
@@ -1019,7 +942,7 @@ class _AllToAll(torch.autograd.Function):
gather_dim = ctx.scatter_dim
fp8_communication = ctx.fp8_communication
world_size = dist.get_world_size(process_group)
- bsz, _, _ = grad_output.shape
+ bsz = grad_output.shape[0]
if bsz == 1:
return_grad = _all_to_all_single(
@@ -1204,10 +1127,10 @@ def linear_with_grad_accum(input_, weight, bias, async_grad_allreduce, use_zbv=F
def linear_gather_forward_reducescatter_backward(
- input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring=False
+ input_, weight, bias, process_group, async_grad_reduce_scatter, dim, ring=False
):
return _LinearWithGatherForwardReduceScatterBackward.apply(
- input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring
+ input_, weight, bias, process_group, async_grad_reduce_scatter, dim, ring
)
@@ -1224,10 +1147,10 @@ def linear_reducescatter_forward_gather_backward(input_, weight, bias=None, proc
def matmul_gather_forward_reducescatter_backward(
- input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring=False, fp8_communication=False
+ input_, weight, bias, process_group, async_grad_reduce_scatter, dim, ring=False, fp8_communication=False
):
return _MatmulWithGatherForwardReduceScatterBackward.apply(
- input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring, fp8_communication
+ input_, weight, bias, process_group, async_grad_reduce_scatter, dim, ring, fp8_communication
)
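The recurring change in `_operation.py` is that the forward pass now keeps the gathered (full) input on `ctx` so backward can reuse it for the weight gradient instead of re-gathering, which is also why the old `overlap` all-gather-in-backward branch and its flag are gone. A single-process sketch of the pattern, with the collectives replaced by local stand-ins so it runs without a process group (the real code uses `dist.all_gather`/`dist.reduce_scatter`):

    import torch

    WORLD_SIZE = 2

    def fake_all_gather(shard: torch.Tensor) -> torch.Tensor:
        # Stand-in for all-gather + concat along the sequence dim.
        return torch.cat([shard] * WORLD_SIZE, dim=0)

    class GatherLinear(torch.autograd.Function):
        @staticmethod
        def forward(ctx, shard, weight):
            gathered = fake_all_gather(shard)
            ctx.save_for_backward(weight)
            ctx.gathered_input = gathered        # cached: backward skips a second all-gather
            ctx.shard_len = shard.shape[0]
            return gathered @ weight.t()

        @staticmethod
        def backward(ctx, grad_out):
            (weight,) = ctx.saved_tensors
            total_input = ctx.gathered_input     # reuse, trading memory for one less collective
            grad_weight = grad_out.t() @ total_input
            grad_full = grad_out @ weight        # grad w.r.t. the gathered input
            # local analogue of reduce-scatter over the replicated copies
            grad_shard = grad_full.view(-1, ctx.shard_len, weight.shape[1]).sum(0)
            return grad_shard, grad_weight

    shard = torch.randn(3, 8, requires_grad=True)
    w = torch.randn(4, 8, requires_grad=True)
    GatherLinear.apply(shard, w).sum().backward()
    assert shard.grad.shape == (3, 8) and w.grad.shape == (4, 8)
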
diff --git a/colossalai/shardformer/layer/attn.py b/colossalai/shardformer/layer/attn.py
index 5f0e9261c..3202ebf25 100644
--- a/colossalai/shardformer/layer/attn.py
+++ b/colossalai/shardformer/layer/attn.py
@@ -422,16 +422,21 @@ class RingAttention(torch.autograd.Function):
ATTN_DONE: torch.cuda.Event = None
SP_STREAM: torch.cuda.Stream = None
SP_GROUP: dist.ProcessGroup = None
- # duplicate process group for concurrent NCCL streams
- # while both PyTorch and NCCL warns(https://github.com/pytorch/pytorch/commit/2dbe5cb979f674f0052a8eea1f7b6c3c0ba441d7)
- # against this, in practice it seems to work fine.
+
+ # NOTE: Duplicating PGs for concurrent NCCL streams is a risky hack -- while it may increase throughput,
+ # both PyTorch and NCCL warn against this. (https://github.com/pytorch/pytorch/commit/2dbe5cb979f674f0052a8eea1f7b6c3c0ba441d7)
+ # LoongTrain's original double ring impl. uses concurrent PGs
+ # (https://github.com/InternLM/InternEvo/blob/e52f2ffc9acf818e8f2b1f97dfc69ceb2f06e154/internlm/model/ops/ring_flash_attn/zigzag_ring_flash_attn_with_sliding_window.py#L192)
+ # but I confirmed with PyTorch developers that this can cause obscure "Software caused connection abort" errors.
+ # (https://github.com/pytorch/pytorch/issues/132852)
+ # NOTE: In general, a smarter idea is to put as many P2P calls as possible into one `batch_isend_irecv`.
INNER_RING_GROUP: dist.ProcessGroup = None
- INNER_RING_GROUP_COPY: dist.ProcessGroup = None
+ # INNER_RING_GROUP_COPY: dist.ProcessGroup = None
INTER_RING_GROUP: dist.ProcessGroup = None
- INTER_RING_GROUP_COPY: dist.ProcessGroup = None
+ # INTER_RING_GROUP_COPY: dist.ProcessGroup = None
@staticmethod
- def get_double_ring_groups(sp_group, inner_ring_size=None):
+ def get_double_ring_groups(sp_axis, pg_mesh, inner_ring_size=None):
"""
Get 2D ring groups for the given process group. Generally, to avoid congestion, the inner ring size
shouldn't be larger than the number of NICs on each node.
@@ -441,21 +446,17 @@ class RingAttention(torch.autograd.Function):
Returns:
Tuple[dist.ProcessGroup, dist.ProcessGroup]: Inner-ring process group and inter-ring process group.
"""
+ assert pg_mesh is not None, "Error: The pg_mesh is None! Please check the process group initialization."
+
+ sp_group = pg_mesh.get_group_along_axis(sp_axis)
sp_size = dist.get_world_size(sp_group)
sp_rank = dist.get_rank(sp_group)
- if inner_ring_size is None:
- if torch.cuda.device_count() >= dist.get_world_size():
- # single node, no need to consider NICs
- return sp_group, sp_group
- if sp_size <= 4:
- inner_ring_size = min(2, sp_size)
- else:
- inner_ring_size = min(4, sp_size)
- else:
- assert (
- inner_ring_size <= sp_size and sp_size % inner_ring_size == 0
- ), f"Error: sp_size {sp_size} should be divisible by inner_ring_size {inner_ring_size}"
+ assert inner_ring_size is not None
+
+ assert (
+ inner_ring_size <= sp_size and sp_size % inner_ring_size == 0
+ ), f"Error: sp_size {sp_size} should be divisible by inner_ring_size {inner_ring_size}"
if inner_ring_size == sp_size:
return sp_group, sp_group
@@ -474,14 +475,14 @@ class RingAttention(torch.autograd.Function):
# Create inner ring groups
for i in range(inner_ring_size):
ranks = list(range(i * inner_ring_size, (i + 1) * inner_ring_size))
- group = dist.new_group(ranks)
+ group = pg_mesh.get_group_along_axis(sp_axis, ranks)
if sp_rank in ranks:
inner_ring_group = group
# Create inter ring groups
for i in range(num_rings):
ranks = list(range(i, sp_size, num_rings))
- group = dist.new_group(ranks)
+ group = pg_mesh.get_group_along_axis(sp_axis, ranks)
if sp_rank in ranks:
inter_ring_group = group
@@ -492,7 +493,7 @@ class RingAttention(torch.autograd.Function):
q, # (B, H, Sq, D)
k,
v,
- sp_group,
+ sp_axis,
attention_mask_type,
cu_seqlens=None,
max_seqlen=None,
@@ -502,6 +503,7 @@ class RingAttention(torch.autograd.Function):
deterministic=False,
return_softmax=False,
inner_ring_size=None,
+ pg_mesh=None,
**kwargs,
):
"""
@@ -512,7 +514,7 @@ class RingAttention(torch.autograd.Function):
q (torch.Tensor): Query tensor. Shape should be [B, nHeads, Sq, D]
k (torch.Tensor): Key tensor. Shape should be [B, nHeads, Sq, D]
v (torch.Tensor): Value tensor. Shape should be [B, nHeads, Sq, D]
- sp_group (Optional[dist.ProcessGroup]): Process group for sequence parallelism
+ sp_axis (Optional[int]): The sequence-parallel axis of the global process group mesh.
sp_stream (torch.cuda.Stream): A different stream for output correction.
cu_seqlens (Optional[torch.Tensor], optional): The cumulative sequence lengths
of the sequences in the batch, used to index into q.
@@ -537,7 +539,6 @@ class RingAttention(torch.autograd.Function):
RingAttention.ATTN_DONE = torch.cuda.Event()
if RingAttention.SP_STREAM is None:
RingAttention.SP_STREAM = torch.cuda.Stream()
-
assert (
q.shape[2] == k.shape[2]
), "Q, K and V having different sequence lengths (inference or cross-attn)\
@@ -546,11 +547,13 @@ class RingAttention(torch.autograd.Function):
attention_mask_type in RingAttention.SUPPORTED_MASK_TYPES
), f"Mask type {attention_mask_type} is not supported yet."
- clone_pg = lambda pg: dist.new_group(dist.get_process_group_ranks(pg))
+ assert pg_mesh is not None, "Error: The pg_mesh is None! Please check the process group initialization."
- if RingAttention.SP_GROUP is not sp_group:
+ clone_pg = lambda pg: dist.new_group(dist.get_process_group_ranks(pg))
+ sp_group = pg_mesh.get_group_along_axis(sp_axis)
+ if inner_ring_size is not None:
RingAttention.SP_GROUP = sp_group
- inner_ring_group, inter_ring_group = RingAttention.get_double_ring_groups(sp_group, inner_ring_size)
+ inner_ring_group, inter_ring_group = RingAttention.get_double_ring_groups(sp_axis, pg_mesh, inner_ring_size)
RingAttention.INNER_RING_GROUP = inner_ring_group
RingAttention.INTER_RING_GROUP = inter_ring_group
else:
@@ -628,7 +631,13 @@ class RingAttention(torch.autograd.Function):
inner_ring_group: Optional[dist.ProcessGroup] = None,
inter_ring_group: Optional[dist.ProcessGroup] = None,
):
-
+ """
+ Forward pass supporting both packed (varlen) and batched (fixed-length, no padding) sequences.
+ No separate version is kept for batched sequences (hard to maintain), which incurs
+ some overhead in sequence splitting due to Python for loops.
+ Uses two CUDA streams to overlap softmax denominator correction with next flash attn
+ (see comments below).
+ """
cu_seqlens_q = cu_seqlens_kv = cu_seqlens
max_seqlen_q = max_seqlen_kv = max_seqlen
cu_seqlens_half = cu_seqlens // 2
@@ -670,7 +679,8 @@ class RingAttention(torch.autograd.Function):
sp_size = dist.get_world_size(sp_group)
sp_rank = dist.get_rank(sp_group)
- # Attempt to achieve concurrent comm in the two-stream forward
+
+ # Create communicators corresponding to two CUDA streams
local_kv_comms = [RingComm(inner_ring_group) for _ in range(2)]
inter_ring_comm = RingComm(inter_ring_group)
local_sp_size = dist.get_world_size(inner_ring_group)
@@ -678,7 +688,7 @@ class RingAttention(torch.autograd.Function):
inter_ring_rank = dist.get_rank(inter_ring_group) if inter_ring_group is not sp_group else 0
num_rings = dist.get_world_size(inter_ring_group) if inter_ring_group is not sp_group else 1
- # Non-contiguous indexing copies to a new contiguous tensor,
+ # Any type of indexing(but not slicing) copies to a new contiguous tensor,
# so only do it once
if sp_rank != sp_size - 1:
q1 = q[half_idx_back]
@@ -695,6 +705,7 @@ class RingAttention(torch.autograd.Function):
rng_states = [None for _ in range(sp_size)]
sp_streams = [torch.cuda.current_stream(), sp_stream]
+ # Helper to pass args to FA
def _forward(q, k, v, causal):
(
_,
@@ -725,6 +736,7 @@ class RingAttention(torch.autograd.Function):
if i < local_sp_size - 1:
local_kv_comms[i % 2].send_recv(kv_buffers[i % 2], kv_buffers[(i + 1) % 2])
+ # Forward within a node
def _local_ring_forward():
# (Hopefully) overlap output correction with next flash attn
for i in range(local_sp_size):
@@ -733,6 +745,8 @@ class RingAttention(torch.autograd.Function):
# NOTE: waiting outside the current stream will NOT correctly synchronize.
if i > 0:
local_kv_comms[(i + 1) % 2].wait()
+
+ # Prefetch
if i == 0:
_kv_comm(i)
@@ -766,15 +780,22 @@ class RingAttention(torch.autograd.Function):
) = _forward(q_block, kv_block[0], kv_block[1], causal=False)
RingAttention.ATTN_DONE.record()
# Pipeline the next KV comm with output correction instead of the next flash attn
- # to minimize idle time when comm takes longer than attn.
+ # kernel, to minimize bubble when comm takes longer than attn.
_kv_comm(i + 1)
block_softmax_lse[i % 2] = (
block_softmax_lse[i % 2].transpose(0, 1).unsqueeze(-1).contiguous().float()
) # (H, T) -> (T, H, 1)
assert block_out[i % 2].shape[:-1] == block_softmax_lse[i % 2].shape[:-1]
- # Output and log sum exp correction. Ideally overlap this with the next flash attn kernel.
- # In reality this always finishes before next flash attn; no need for extra sync.
+
+ # Output and log sum exp correction.
+ # Ideally overlap this with the next flash attn kernel,
+ # since attn uses Tensor Core and rescale is element-wise, memory-bound and uses CUDA cores.
+ # (NOTE that this is the same as ping-pong scheduling idea in FA3)
+ # TODO: However, sometimes even when the GPU has scheduled the next kernel,
+ # it is reluctant to launch it in overlap. Some potential causes:
+ # 1. need lower-level CUDA scheduling 2. further benchmark against Megatron-LM
+ # 3. register spilling by FA kernel.
if i == 0:
out = block_out[0]
softmax_lse = block_softmax_lse[0]
@@ -790,15 +811,17 @@ class RingAttention(torch.autograd.Function):
torch.cuda.current_stream().wait_stream(sp_stream)
return out, softmax_lse
+ # Forward for inter-node (the outer ring in 2D ring)
def _other_ring_forward(ring_num_idx, out, softmax_lse):
# Loop through the inner ring after receiving
- # all new KVs from the previous inner ring
+ # all new KVs from another ring
for i in range(local_sp_size):
with torch.cuda.stream(sp_streams[i % 2]):
# Send & recv KV
if i > 0:
local_kv_comms[(i + 1) % 2].wait()
+ # Prefetch
if i == 0:
_kv_comm(i)
@@ -895,7 +918,8 @@ class RingAttention(torch.autograd.Function):
def backward(ctx, dout, _):
"""
During backward, we accumulate q grads on each rank locally, but iterate kv and their grads
- over all ranks for accumulation.
+ over all ranks for accumulation. We avoid using two streams due to backward using doubled
+ buffers and more comm cost.
"""
(q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_kv, half_idx_front, half_idx_back) = ctx.saved_tensors[:9]
rng_states = ctx.saved_tensors[9:]
@@ -927,7 +951,7 @@ class RingAttention(torch.autograd.Function):
local_sp_rank = dist.get_rank(sp_group)
sp_size = dist.get_world_size(sp_group)
- # Using separate streams (pg) for concurrent kv and dkv comm may
+ # NOTE: Using separate streams (PG) for concurrent kv and dkv comm may
# cause NCCL "software caused connection abort" here...
local_kv_comm = RingComm(local_kv_group)
local_dkv_comm = RingComm(local_kv_group)
@@ -959,6 +983,7 @@ class RingAttention(torch.autograd.Function):
dkv_buffers = [torch.empty_like(kv, dtype=torch.float32) for kv in kv_buffers] # (T, H, D)
del k, v
+ # Helper to pass args to FA
def _backward(dout, q, k, v, out, softmax_lse, dq, dk, dv, rng_state, causal):
_flash_attn_backward(
dout,
@@ -979,8 +1004,7 @@ class RingAttention(torch.autograd.Function):
**misc_kwargs,
)
- # NOTE: We avoid using two streams due to doubled buffers
- # and that backward is more communication intensive.
+ # Backward within a node
def _local_ring_backward():
for i in range(local_sp_size):
if i > 0:
@@ -1043,6 +1067,7 @@ class RingAttention(torch.autograd.Function):
dkv_send = dkv_buffers[(local_sp_size - 1) % 2]
return dq, dkv_recv, dkv_send
+ # Backward for inter-node (the outer ring in 2D ring)
def _other_ring_backward(ring_num_idx, dq):
if ring_num_idx > inter_ring_rank:
# Indexing is expensive
@@ -1127,34 +1152,34 @@ class RingAttention(torch.autograd.Function):
@staticmethod
def prepare_varlen_batch(
- attention_mask: torch.Tensor,
+ padding_mask: torch.Tensor,
sp_group: dist.ProcessGroup,
inputs_embeds: torch.Tensor = None,
position_ids: Optional[torch.Tensor] = None,
is_label: bool = False,
- is_2d: bool = True,
+ is_batched_seq: bool = True,
):
+ # TODO: support setting a batch dim (fixed packing length) for packed mode, so that
+ # DP can be used (needs to modify dataloader too)
"""
Preprocess a batch of padded sequence by splitting input sequence by sp_size
- sequence-wise and packing them into one sequence. Updates the mask info accordingly.
+ seq-wise and packing them into one sequence. Updates the mask info accordingly.
Args:
- attention_mask (torch.Tensor): Contains the mask [B, Sq], where True means the token is NOT masked.
+ padding_mask (torch.Tensor): Contains the mask [B, Sq], where True means the token is NOT masked.
sp_group (dist.ProcessGroup): Process group for sequence parallelism
inputs_embeds (torch.Tensor): Input embeddings. Shape should be [B, Sq, ...]
position_ids (Optional[torch.Tensor], optional): Position ids of shape [Sq] or [1, Sq]. Defaults to None.
is_label (bool, optional): Whether inputs_embeds is instead a label tensor. If True, mask out the first
token of each sequence.
- is_2d (bool, optional): Whether to return 2D outputs padded to max_seqlen // sp_size or flatten
- the batch dim to a packed 1d sequence. Contingent on model forward shape definitions.
+ is_batched_seq (bool, optional): If True, then the input is a batch of (potentially padded) sequences
+ of shape [B, Sq, ...]; else a packed sequence of shape [T, ...].
Returns:
- torch.Tensor:
- Packed input embeddings of shape [B, Sq // sp_size, ...].
-
- Dict[str, Any]:
+ inputs_embeds (torch.Tensor):
+ Packed input embeddings of shape [B, Sq // sp_size, ...] if is_batched_seq, else [T, ...].
+ mask_info (Dict[str, Any]):
A dictionary containing mask info.
-
- torch.Tensor:
+ position_ids (torch.Tensor):
Packed position ids of shape [..., Sq // sp_size].
"""
@@ -1162,12 +1187,11 @@ class RingAttention(torch.autograd.Function):
sp_size = dist.get_world_size(group=sp_group)
sp_rank = dist.get_rank(group=sp_group)
mask_info = {}
- mask_info["max_seqlen"], mask_info["cu_seqlens"] = get_pad_info(attention_mask, return_indices=False)
+ mask_info["max_seqlen"], mask_info["cu_seqlens"] = get_pad_info(padding_mask, return_indices=False)
- # Unpad, split seq-wise, then pad back to (B, max_seqlen // sp_size)
- # Split mask to compute local nonzero position indices
+ # Unpad, split seq-wise, then pad to (B, max_seqlen // sp_size)
# (B, Sq) -> (B, max_seqlen // sp_size)
- attention_mask = attention_mask[:, : mask_info["max_seqlen"]]
+ padding_mask = padding_mask[:, : mask_info["max_seqlen"]]
if inputs_embeds is not None:
inputs_embeds = inputs_embeds[:, : mask_info["max_seqlen"]]
inputs_embeds = split_varlen_zigzag(
@@ -1175,11 +1199,12 @@ class RingAttention(torch.autograd.Function):
mask_info["cu_seqlens"],
sp_group,
mask_info["max_seqlen"],
- is_2d=is_2d,
+ is_batched_seq=is_batched_seq,
is_label=is_label,
)
- attention_mask = split_varlen_zigzag(
- attention_mask, mask_info["cu_seqlens"], sp_group, mask_info["max_seqlen"], is_2d=is_2d
+ # Split mask to get local nonzero seq positions
+ padding_mask = split_varlen_zigzag(
+ padding_mask, mask_info["cu_seqlens"], sp_group, mask_info["max_seqlen"], is_batched_seq=is_batched_seq
)
if position_ids is not None:
@@ -1192,7 +1217,7 @@ class RingAttention(torch.autograd.Function):
)
mask_info["max_seqlen"] //= sp_size
- mask_info["valid_indices"] = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
+ mask_info["valid_indices"] = torch.nonzero(padding_mask.flatten(), as_tuple=False).flatten()
mask_info["cu_seqlens"] //= sp_size
mask_info["attention_mask_type"] = AttnMaskType.PADDED_CAUSAL
return inputs_embeds, mask_info, position_ids
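`get_double_ring_groups` now derives both ring levels from the global `ProcessGroupMesh` instead of calling `dist.new_group`, which avoids creating extra NCCL communicators. Conceptually, inner rings are contiguous blocks of `inner_ring_size` ranks (ideally sharing a node), and inter rings take one rank from each block. A pure-Python sketch of that intended layout, assuming `sp_size` is a multiple of `inner_ring_size`; the indexing in the actual helper is expressed slightly differently:

    from typing import List, Tuple

    def double_ring_ranks(sp_size: int, inner_ring_size: int) -> Tuple[List[List[int]], List[List[int]]]:
        assert sp_size % inner_ring_size == 0
        num_rings = sp_size // inner_ring_size
        # Inner rings: contiguous blocks, e.g. [0,1,2,3] and [4,5,6,7] -> cheap intra-node P2P.
        inner = [list(range(r * inner_ring_size, (r + 1) * inner_ring_size)) for r in range(num_rings)]
        # Inter rings: the same offset within every block, e.g. [0,4], [1,5], ... -> the cross-node ring.
        inter = [list(range(offset, sp_size, inner_ring_size)) for offset in range(inner_ring_size)]
        return inner, inter

    inner, inter = double_ring_ranks(sp_size=8, inner_ring_size=4)
    assert inner == [[0, 1, 2, 3], [4, 5, 6, 7]]
    assert inter == [[0, 4], [1, 5], [2, 6], [3, 7]]
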
diff --git a/colossalai/shardformer/layer/linear.py b/colossalai/shardformer/layer/linear.py
index 040a93e5a..d39d6e997 100644
--- a/colossalai/shardformer/layer/linear.py
+++ b/colossalai/shardformer/layer/linear.py
@@ -23,18 +23,16 @@ from colossalai.tensor.d_tensor.api import (
)
from ._operation import (
- gather_forward_reducescatter_backward,
gather_forward_split_backward,
linear_gather_forward_reducescatter_backward,
linear_reducescatter_forward_gather_backward,
linear_with_async_comm,
linear_with_grad_accum,
reduce_forward,
- reducescatter_forward_gather_backward,
split_forward_gather_backward,
)
from .parallel_module import PaddingParallelModule, ParallelModule
-from .utils import create_randomizer_with_offset
+from .utils import create_randomizer_with_offset, is_share_sp_tp
__all__ = ["LinearWithGradAccum", "Linear1D_Col", "Linear1D_Row"]
@@ -197,7 +195,6 @@ class Linear1D_Col(ParallelModule):
to all GPUs, otherwise, every GPU will have its output
which is :math:`Y_i = XA_i`, defaults to False
seq_parallel (`bool`): If set to ``True``, it will use sequence parallel, defaults to False.
- overlap (`bool`): If set to ``True``, it will overlap input all-gather with gradient computation during backward, defaults to False.
skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer,
which is preserved for kernel fusion, defaults to False
weight_initializer (`typing.Callable`):
@@ -220,7 +217,6 @@ class Linear1D_Col(ParallelModule):
gather_output: bool = False,
seq_parallel_mode: str = None,
seq_parallel_dim: int = 1,
- overlap: torch.cuda.Stream = None,
skip_bias_add: bool = False,
weight: Optional[Parameter] = None,
bias_: Optional[Parameter] = None,
@@ -238,7 +234,6 @@ class Linear1D_Col(ParallelModule):
self.gather_output = gather_output
self.seq_parallel_mode = seq_parallel_mode
self.seq_parallel_dim = seq_parallel_dim
- self.overlap = overlap
self.skip_bias_add = skip_bias_add
self.device = device
self.process_group = process_group
@@ -345,22 +340,16 @@ class Linear1D_Col(ParallelModule):
# Matrix multiply.
bias = self.bias if not self.skip_bias_add else None
- if self.seq_parallel_mode == "split_gather":
- input_parallel = gather_forward_reducescatter_backward(
- input_parallel, self.process_group, self.seq_parallel_dim, fp8_communication=self.fp8_communication
- )
- output_parallel = linear_with_async_comm(
+
+ if is_share_sp_tp(self.seq_parallel_mode):
+ output_parallel = linear_gather_forward_reducescatter_backward(
input_parallel,
self.weight,
bias,
self.process_group,
- False,
- fp8_communication=self.fp8_communication,
- use_zbv=self.use_zbv,
- )
- elif self.seq_parallel_mode == "ring":
- output_parallel = linear_gather_forward_reducescatter_backward(
- input_parallel, self.weight, bias, self.process_group, True, self.seq_parallel_dim, self.overlap, True
+ True,
+ self.seq_parallel_dim,
+ ring=self.seq_parallel_mode == "ring",
)
else:
output_parallel = linear_with_async_comm(
@@ -584,31 +573,17 @@ class Linear1D_Row(ParallelModule):
handle.wait()
output = torch.cat(output_parallel_list, dim=-1)
else:
- if self.seq_parallel_mode is None:
- output_parallel = linear_with_async_comm(
- input_, self.weight, None, self.process_group, False, use_zbv=self.use_zbv
- )
- output = reduce_forward(output_parallel, self.process_group, fp8_communication=self.fp8_communication)
- elif self.seq_parallel_mode == "split_gather":
- output_parallel = linear_with_async_comm(
- input_, self.weight, None, self.process_group, False, use_zbv=self.use_zbv
- )
- output = reducescatter_forward_gather_backward(
- output_parallel, self.process_group, self.seq_parallel_dim, fp8_communication=self.fp8_communication
- )
- elif self.seq_parallel_mode == "ring":
+ if is_share_sp_tp(self.seq_parallel_mode):
output = linear_reducescatter_forward_gather_backward(
input_,
self.weight,
process_group=self.process_group,
dim=self.seq_parallel_dim,
- ring=True,
+ ring=self.seq_parallel_mode == "ring",
)
else:
- output_parallel = linear_with_async_comm(
- input_, self.weight, None, self.process_group, False, use_zbv=self.use_zbv
- )
- output = reduce_forward(output_parallel, self.process_group)
+ output_parallel = F.linear(input_, self.weight)
+ output = reduce_forward(output_parallel, self.process_group, fp8_communication=self.fp8_communication)
if not self.skip_bias_add:
if self.bias is not None:
@@ -716,7 +691,6 @@ class VocabParallelLMHead1D(Linear1D_Col, PaddingParallelModule):
to all GPUs, otherwise, every GPU will have its output
which is :math:`Y_i = XA_i`, defaults to False
seq_parallel (`bool`): If set to ``True``, it will use sequence parallel, defaults to False.
- overlap (`bool`): If set to ``True``, it will overlap input all-gather with gradient computation during backward, defaults to False.
skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer,
which is preserved for kernel fusion, defaults to False
weight_initializer (`typing.Callable`):
diff --git a/colossalai/shardformer/layer/qkv_fused_linear.py b/colossalai/shardformer/layer/qkv_fused_linear.py
index 6fd689908..6e469686b 100644
--- a/colossalai/shardformer/layer/qkv_fused_linear.py
+++ b/colossalai/shardformer/layer/qkv_fused_linear.py
@@ -7,6 +7,7 @@ from typing import Callable, List, Optional, Tuple, Union
import torch
import torch.distributed as dist
import torch.nn as nn
+import torch.nn.functional as F
from torch import Tensor
from torch.distributed import ProcessGroup
from torch.nn.parameter import Parameter
@@ -24,17 +25,17 @@ from colossalai.tensor.d_tensor.api import (
)
from ._operation import (
- gather_forward_split_backward,
+ linear_gather_forward_reducescatter_backward,
+ linear_reducescatter_forward_gather_backward,
linear_with_async_comm,
matmul_gather_forward_reducescatter_backward,
matmul_with_async_comm,
- reduce_backward,
reduce_forward,
reducescatter_forward_gather_backward,
split_forward_gather_backward,
)
from .parallel_module import ParallelModule
-from .utils import create_randomizer_with_offset
+from .utils import create_randomizer_with_offset, is_share_sp_tp
__all__ = ["FusedLinear1D_Col", "FusedLinear1D_Row", "GPT2FusedLinearConv1D_Col", "GPT2FusedLinearConv1D_Row"]
@@ -44,21 +45,25 @@ __all__ = ["FusedLinear1D_Col", "FusedLinear1D_Row", "GPT2FusedLinearConv1D_Col"
def split_fused_qkv_in_gpt2_style(
- qkv: torch.Tensor, n_fused: int, process_group: ProcessGroup, is_transposed: bool = False
+ qkv: torch.Tensor, split_sizes: List[int], process_group: ProcessGroup, is_transposed: bool = False
):
"""
The fused qkv tensor looks like [Q1, Q2, K1, K2, V1, V2], this function will split them into [Q1, K1, V1] and [Q2, K2, V2].
Args:
qkv (torch.Tensor): The fused qkv tensor.
- n_fused (int): The number items fused together, defaults to 3 (query, key and value).
+ split_sizes (List[int]): The sizes of the split tensor.
process_group (ProcessGroup): The process group for distributed communication.
is_transposed (bool): generally the tensor is the shape of (out_features, in_features). Set this to True if the tensor is in the shape (in_features, out_features).
"""
# get the number of slice for the fused qkv
rank = dist.get_rank(group=process_group)
world_size = dist.get_world_size(group=process_group)
- order = torch.arange(world_size * n_fused)
+ order = torch.arange(world_size * len(split_sizes))
+ new_split_sizes = []
+ for sz in split_sizes:
+ assert sz % world_size == 0, f"size {sz} is not divisible by world_size {world_size}"
+ new_split_sizes.extend([sz // world_size] * world_size)
# split the fused qkv
# from
@@ -66,9 +71,9 @@ def split_fused_qkv_in_gpt2_style(
# to
# [Q1, Q2, K1, K2, V1, V2]
if is_transposed:
- weight_chunks = torch.chunk(qkv, world_size * n_fused, dim=-1)
+ weight_chunks = torch.split(qkv, new_split_sizes, dim=-1)
else:
- weight_chunks = torch.chunk(qkv, world_size * n_fused, dim=0)
+ weight_chunks = torch.split(qkv, new_split_sizes, dim=0)
# rearrange the slice into the final order
# from
@@ -85,18 +90,23 @@ def split_fused_qkv_in_gpt2_style(
def gather_fused_qkv_in_gpt2_style(
- qkv: torch.Tensor, n_fused: int, process_group: ProcessGroup, is_transposed: bool = False
+ qkv: torch.Tensor, split_sizes: List[int], process_group: ProcessGroup, is_transposed: bool = False
):
"""
The splitted qkv tensor looks like [Q1, K1, V1] and [Q2, K2, V2], this function will gather them into [Q1, Q2, K1, K2, V1, V2].
Args:
qkv (torch.Tensor): The fused qkv tensor.
- n_fused (int): The number items fused together, defaults to 3 (query, key and value).
+ split_sizes (List[int]): The sizes of the split tensor.
process_group (ProcessGroup): The process group for distributed communication.
is_transposed (bool): generally the tensor is the shape of (out_features, in_features). Set this to True if the tensor is in the shape (in_features, out_features).
"""
world_size = dist.get_world_size(group=process_group)
+ new_split_sizes = []
+ for sz in split_sizes:
+ assert sz % world_size == 0, f"size {sz} is not divisible by world_size {world_size}"
+ new_split_sizes.append(sz // world_size)
+ new_split_sizes = new_split_sizes * world_size
# gather the tensors
# from
@@ -121,13 +131,13 @@ def gather_fused_qkv_in_gpt2_style(
# to
# [Q1, Q2, K1, K2, V1, V2]
if is_transposed:
- weight_chunks = torch.chunk(gather_weight, world_size * n_fused, dim=-1)
+ weight_chunks = torch.split(gather_weight, new_split_sizes, dim=-1)
else:
- weight_chunks = torch.chunk(gather_weight, world_size * n_fused, dim=0)
+ weight_chunks = torch.split(gather_weight, new_split_sizes, dim=0)
reordered_chunk_list = []
- for i in range(n_fused):
- reordered_chunk_list.extend(weight_chunks[i::n_fused])
+ for i in range(len(split_sizes)):
+ reordered_chunk_list.extend(weight_chunks[i :: len(split_sizes)])
if is_transposed:
reordered_gather_weight = torch.cat(reordered_chunk_list, dim=-1)
@@ -136,6 +146,42 @@ def gather_fused_qkv_in_gpt2_style(
return reordered_gather_weight
+class _SplitForwardGatherBackwardFusedQKV(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, qkv: torch.Tensor, split_sizes: List[int], process_group: ProcessGroup):
+ ctx.split_sizes = split_sizes
+ ctx.process_group = process_group
+ return split_fused_qkv_in_gpt2_style(qkv, split_sizes, process_group, is_transposed=True)
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ grad_output = gather_fused_qkv_in_gpt2_style(
+ grad_output, ctx.split_sizes, ctx.process_group, is_transposed=True
+ )
+ return grad_output, None, None
+
+
+def split_forward_gather_backward_fused_qkv(qkv: torch.Tensor, split_sizes: List[int], process_group: ProcessGroup):
+ return _SplitForwardGatherBackwardFusedQKV.apply(qkv, split_sizes, process_group)
+
+
+class _GatherForwardSplitBackwardFusedQKV(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, qkv: torch.Tensor, split_sizes: List[int], process_group: ProcessGroup):
+ ctx.split_sizes = split_sizes
+ ctx.process_group = process_group
+ return gather_fused_qkv_in_gpt2_style(qkv, split_sizes, process_group, is_transposed=True)
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ grad_output = split_fused_qkv_in_gpt2_style(grad_output, ctx.split_sizes, ctx.process_group, is_transposed=True)
+ return grad_output, None, None
+
+
+def gather_forward_split_backward_fused_qkv(qkv: torch.Tensor, split_sizes: List[int], process_group: ProcessGroup):
+ return _GatherForwardSplitBackwardFusedQKV.apply(qkv, split_sizes, process_group)
+
+
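To make the new split_sizes interface concrete, here is a hypothetical single-process sketch of the GPT2-style fused-QKV split performed above; the real split_fused_qkv_in_gpt2_style additionally handles process groups, the non-transposed layout, and the inverse gather.

import torch

def split_fused_qkv_sketch(qkv, split_sizes, world_size, rank):
    # Slice each of Q, K, V into world_size pieces, then keep this rank's
    # pieces: [Q1, Q2, K1, K2, V1, V2] -> [Qr, Kr, Vr].
    new_split_sizes = []
    for sz in split_sizes:
        assert sz % world_size == 0
        new_split_sizes.extend([sz // world_size] * world_size)
    chunks = torch.split(qkv, new_split_sizes, dim=-1)
    return torch.cat(chunks[rank::world_size], dim=-1)

w = torch.arange(12.0).unsqueeze(0)  # (1, 12): Q = 0..3, K = 4..7, V = 8..11
print(split_fused_qkv_sketch(w, [4, 4, 4], world_size=2, rank=0))
# tensor([[0., 1., 4., 5., 8., 9.]])  -> this rank's [Q1, K1, V1]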
class GPT2FusedLinearConv1D_Col(ParallelModule):
r"""Linear layer with column parallelism.
@@ -145,10 +191,10 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
Args:
in_features (int): size of each input sample.
out_features (int): size of each output sample.
+ split_sizes (List[int]): The sizes of the split tensor.
bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``.
dtype (`torch.dtype`): The dtype of parameters, defaults to None.
device (`torch.device`): The device of parameters, defaults to None.
- n_fused (int): The number items fused, defaults to 3 (QKV).
process_group (`torch.distributed.ProcessGroup`): The process group to be used for weight sharding and communication, defaults to None.
seq_parallel_mode (str): If set to ``None``, it will not use sequence parallel, otherwise will use corresponding mode of sequence parallel, defaults to None.
gather_output (bool, optional): If true, call all-gather on output and make Y available
@@ -169,16 +215,14 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
self,
in_features: int,
out_features: int,
+ split_sizes: List[int],
bias: bool = True,
dtype: torch.dtype = None,
device: torch.device = None,
process_group: ProcessGroup = None,
- async_communication: bool = False,
gather_output: bool = False,
seq_parallel_mode: str = None,
- overlap: bool = False,
skip_bias_add: bool = False,
- n_fused: int = 3,
weight: Optional[Parameter] = None,
bias_: Optional[Parameter] = None,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
@@ -192,14 +236,16 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
self.out_features = out_features
self.gather_output = gather_output
self.seq_parallel_mode = seq_parallel_mode
- self.overlap = overlap
self.skip_bias_add = skip_bias_add
self.device = device
- self.n_fused = n_fused
+ self.split_sizes = split_sizes
self.process_group = process_group
- self.async_communication = async_communication
self.fp8_communication = fp8_communication
+ assert (
+ sum(split_sizes) == out_features
+ ), f"The sum of split_sizes({sum(split_sizes)}) should be equal to out_features({out_features})."
+
if skip_bias_add and not bias:
raise ValueError("cannot skip bias addition if bias is None")
@@ -223,10 +269,10 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
self.weight = weight
def shard_fn(tensor):
- return split_fused_qkv_in_gpt2_style(tensor, self.n_fused, self.process_group, True)
+ return split_fused_qkv_in_gpt2_style(tensor, self.split_sizes, self.process_group, True)
def gather_fn(tensor):
- return gather_fused_qkv_in_gpt2_style(tensor, self.n_fused, self.process_group, True)
+ return gather_fused_qkv_in_gpt2_style(tensor, self.split_sizes, self.process_group, True)
if not is_customized_distributed_tensor(self.weight):
with torch.no_grad():
@@ -252,7 +298,11 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
@staticmethod
def from_native_module(
- module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
+ module: nn.Module,
+ process_group: Union[ProcessGroup, List[ProcessGroup]],
+ split_sizes: List[int],
+ *args,
+ **kwargs,
) -> ParallelModule:
r"""
Convert a huggingface layer `Conv1D` in gpt2 to a parallelized linear layer.
@@ -260,7 +310,7 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
Args:
module (`nn.Linear`): The module to be converted.
process_group (`Union[ProcessGroup, List[ProcessGroup]]`): The process group to be used for weight sharding and communication.
- n_fused (int): The number of layers to be fused. In GPT2, Q,K,V are fused in one weight.
+ split_sizes (List[int]): The sizes of the split tensor. In GPT2, Q, K and V are fused in one weight.
"""
LazyInitContext.materialize(module)
# get the attributes
@@ -291,6 +341,7 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
process_group=process_group,
weight=module.weight,
bias_=module.bias,
+ split_sizes=split_sizes,
*args,
**kwargs,
)
@@ -313,7 +364,7 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
# Matrix multiply.
bias = self.bias if not self.skip_bias_add else None
- if self.seq_parallel_mode == "split_gather":
+ if is_share_sp_tp(self.seq_parallel_mode):
input_parallel = input_
output_parallel = matmul_gather_forward_reducescatter_backward(
input_parallel,
@@ -322,31 +373,18 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
self.process_group,
True,
1,
- self.overlap,
- fp8_communication=self.fp8_communication,
- )
- elif self.seq_parallel_mode == "ring":
- input_parallel = input_
- output_parallel = matmul_gather_forward_reducescatter_backward(
- input_parallel,
- self.weight,
- bias,
- self.process_group,
- True,
- 1,
- self.overlap,
- True,
+ ring=self.seq_parallel_mode == "ring",
fp8_communication=self.fp8_communication,
)
elif self.seq_parallel_mode is None or self.seq_parallel_mode == "ring_attn":
# Set up backprop all-reduce.
- input_parallel = reduce_backward(input_, self.process_group)
+ input_parallel = input_
output_parallel = matmul_with_async_comm(
input_parallel,
self.weight,
bias,
self.process_group,
- self.async_communication,
+ True,
fp8_communication=self.fp8_communication,
)
else:
@@ -354,9 +392,7 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
if self.gather_output:
# All-gather across the partitions.
- output = gather_forward_split_backward(
- output_parallel, dim=-1, process_group=self.process_group, fp8_communication=self.fp8_communication
- )
+ output = gather_forward_split_backward_fused_qkv(output_parallel, self.split_sizes, self.process_group)
else:
output = output_parallel
@@ -565,7 +601,7 @@ class GPT2FusedLinearConv1D_Row(ParallelModule):
if self.seq_parallel_mode is None or self.seq_parallel_mode == "ring_attn":
output_parallel = torch.matmul(input_, self.weight)
output = reduce_forward(output_parallel, self.process_group, fp8_communication=self.fp8_communication)
- elif self.seq_parallel_mode == "split_gather":
+ elif is_share_sp_tp(self.seq_parallel_mode):
output_parallel = torch.matmul(input_, self.weight)
output = reducescatter_forward_gather_backward(
output_parallel,
@@ -573,13 +609,6 @@ class GPT2FusedLinearConv1D_Row(ParallelModule):
1,
self.fp8_communication,
)
- elif self.seq_parallel_mode == "ring":
- output_parallel = torch.matmul(input_, self.weight)
- output = reducescatter_forward_gather_backward(
- output_parallel,
- self.process_group,
- 1,
- )
else:
raise NotImplementedError(f"seq_parallel_mode={self.seq_parallel_mode} is not supported!")
@@ -605,10 +634,10 @@ class FusedLinear1D_Col(ParallelModule):
Args:
in_features (int): size of each input sample.
out_features (int): size of each output sample.
+ split_sizes (List[int]): The sizes of the split tensor.
bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``.
dtype (`torch.dtype`): The dtype of parameters, defaults to None.
device (`torch.device`): The device of parameters, defaults to None.
- n_fused (int): The number items fused, defaults to 3 (QKV).
process_group (`torch.distributed.ProcessGroup`): The process group to be used for weight sharding and communication, defaults to None.
gather_output (bool, optional): If true, call all-gather on output and make Y available
to all GPUs, otherwise, every GPU will have its output
@@ -628,14 +657,15 @@ class FusedLinear1D_Col(ParallelModule):
self,
in_features: int,
out_features: int,
+ split_sizes: List[int],
bias: bool = True,
dtype: torch.dtype = None,
device: torch.device = None,
process_group: ProcessGroup = None,
- async_communication: bool = False,
gather_output: bool = False,
+ seq_parallel_mode: str = None,
+ seq_parallel_dim: int = 1,
skip_bias_add: bool = False,
- n_fused: int = 3,
weight: Optional[Parameter] = None,
bias_: Optional[Parameter] = None,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
@@ -647,13 +677,18 @@ class FusedLinear1D_Col(ParallelModule):
self.in_features = in_features
self.out_features = out_features
self.gather_output = gather_output
+ self.seq_parallel_mode = seq_parallel_mode
+ self.seq_parallel_dim = seq_parallel_dim
self.skip_bias_add = skip_bias_add
self.device = device
- self.n_fused = n_fused
+ self.split_sizes = split_sizes
self.process_group = process_group
- self.async_communication = async_communication
self.fp8_communication = fp8_communication
+ assert (
+ sum(split_sizes) == out_features
+ ), f"The sum of split_sizes({sum(split_sizes)}) should be equal to out_features({out_features})."
+
if skip_bias_add and not bias:
raise ValueError("cannot skip bias addition if bias is None")
@@ -677,10 +712,10 @@ class FusedLinear1D_Col(ParallelModule):
self.weight = weight
def shard_fn(tensor):
- return split_fused_qkv_in_gpt2_style(tensor, self.n_fused, self.process_group, False)
+ return split_fused_qkv_in_gpt2_style(tensor, self.split_sizes, self.process_group, False)
def gather_fn(tensor):
- return gather_fused_qkv_in_gpt2_style(tensor, self.n_fused, self.process_group, False)
+ return gather_fused_qkv_in_gpt2_style(tensor, self.split_sizes, self.process_group, False)
if not is_customized_distributed_tensor(self.weight):
with torch.no_grad():
@@ -706,7 +741,11 @@ class FusedLinear1D_Col(ParallelModule):
@staticmethod
def from_native_module(
- module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], n_fused: int, *args, **kwargs
+ module: nn.Module,
+ process_group: Union[ProcessGroup, List[ProcessGroup]],
+ split_sizes: List[int],
+ *args,
+ **kwargs,
) -> ParallelModule:
r"""
Convert a fused `torch.nn.linear` layer to a parallelized linear layer.
@@ -714,7 +753,7 @@ class FusedLinear1D_Col(ParallelModule):
Args:
module (`nn.Linear`): The module to be converted.
process_group (`Union[ProcessGroup, List[ProcessGroup]]`): The process group to be used for weight sharding and communication.
- n_fused (int): The number of layers to be fused. In common, Q,K,V are fused in one weight.
+ split_sizes (List[int]): The sizes of the split tensor. Commonly, Q, K and V are fused in one weight.
"""
LazyInitContext.materialize(module)
@@ -737,25 +776,11 @@ class FusedLinear1D_Col(ParallelModule):
process_group=process_group,
weight=module.weight,
bias_=module.bias,
- n_fused=n_fused,
+ split_sizes=split_sizes,
*args,
**kwargs,
)
- # # TODO: copy the sharded weights
- # with torch.no_grad():
- # sharded_weight = split_fused_qkv_in_gpt2_style(module.weight.data,
- # n_fused=n_fused,
- # process_group=process_group,
- # is_transposed=False)
- # linear_1d.weight.data.copy_(sharded_weight.data)
-
- # if bias:
- # sharded_bias = split_fused_qkv_in_gpt2_style(module.bias.data,
- # n_fused=n_fused,
- # process_group=process_group,
- # is_transposed=False)
- # linear_1d.bias.data.copy_(sharded_bias.data)
return linear_1d
def reset_parameters(self, weight_initializer, bias_initializer) -> None:
@@ -772,19 +797,29 @@ class FusedLinear1D_Col(ParallelModule):
input_.shape, self.weight.shape, self.weight.shape[-1]
)
# Set up backprop all-reduce.
- # input_parallel = reduce_backward(input_, self.process_group)
input_parallel = input_
# Matrix multiply.
bias = self.bias if not self.skip_bias_add else None
- output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, True)
+ if is_share_sp_tp(self.seq_parallel_mode):
+ output_parallel = linear_gather_forward_reducescatter_backward(
+ input_parallel,
+ self.weight,
+ bias,
+ self.process_group,
+ True,
+ self.seq_parallel_dim,
+ ring=self.seq_parallel_mode == "ring",
+ )
+ else:
+ output_parallel = linear_with_async_comm(
+ input_parallel, self.weight, bias, self.process_group, True, fp8_communication=self.fp8_communication
+ )
if self.gather_output:
# All-gather across the partitions.
- output = gather_forward_split_backward(
- output_parallel, dim=-1, process_group=self.process_group, fp8_communication=self.fp8_communication
- )
+ output = gather_forward_split_backward_fused_qkv(output_parallel, self.split_sizes, self.process_group)
else:
output = output_parallel
@@ -792,3 +827,196 @@ class FusedLinear1D_Col(ParallelModule):
return output, self.bias
else:
return output
+
+
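A numeric sketch of why the fused column-parallel forward above gathers with a reordering (gather_forward_split_backward_fused_qkv) rather than a plain concatenation when gather_output is set: each emulated rank produces [Qr, Kr, Vr] slices, and interleaving them recovers the fused [Q, K, V] output of the unsharded layer. Plain indexing stands in for the distributed calls; all names are illustrative.

import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(4, 8)
w = torch.randn(12, 8)                                        # fused rows: Q(0..3), K(4..7), V(8..11)
full = F.linear(x, w)
rank0 = F.linear(x, torch.cat([w[0:2], w[4:6], w[8:10]]))     # [Q1, K1, V1]
rank1 = F.linear(x, torch.cat([w[2:4], w[6:8], w[10:12]]))    # [Q2, K2, V2]
# interleave per-rank outputs back to [Q1, Q2, K1, K2, V1, V2], i.e. the fused order
restored = torch.cat([rank0[:, 0:2], rank1[:, 0:2],
                      rank0[:, 2:4], rank1[:, 2:4],
                      rank0[:, 4:6], rank1[:, 4:6]], dim=-1)
assert torch.allclose(full, restored, atol=1e-5)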
+class FusedLinear1D_Row(ParallelModule):
+ r"""Linear layer with row parallelism
+
+ Args:
+ in_features (int): size of each input sample.
+ out_features (int): size of each output sample.
+ split_sizes (List[int]): The sizes of the split tensor.
+ bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``.
+ dtype (`torch.dtype`): The dtype of parameters, defaults to None.
+ parallel_input (bool): If set to ``True``, it's assumed that the input is already split across ranks, defaults to True.
+ process_group (`torch.distributed.ProcessGroup`): The process group to be used for weight sharding and communication, defaults to None.
+ seq_parallel_mode (`str`): The type of sp mode, it will use sequence parallel when `seq_parallel_mode` is not None. Defaults to None.
+ seq_parallel_dim (`int`): The dimension along which sequence parallelism splits and gathers the sequence.
+ skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer,
+ which is preserved for kernel fusion, defaults to False
+ weight_initializer (:class:`typing.Callable`, optional):
+ The initializer of weight, defaults to kaiming uniform initializer.
+ bias_initializer (:class:`typing.Callable`, optional):
+ The initializer of bias, defaults to xavier uniform initializer.
+
+ For more details about ``initializer``, please refer to
+ `init `_.
+ """
+
+ def __init__(
+ self,
+ in_features: int,
+ out_features: int,
+ split_sizes: List[int],
+ bias: bool = True,
+ dtype: torch.dtype = None,
+ device: torch.device = None,
+ process_group: ProcessGroup = None,
+ seq_parallel_mode: str = None,
+ seq_parallel_dim: int = 1,
+ parallel_input: bool = True,
+ skip_bias_add: bool = False,
+ weight: Optional[Parameter] = None,
+ bias_: Optional[Parameter] = None,
+ weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
+ bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
+ fp8_communication: bool = False,
+ ):
+ super().__init__()
+ # Keep input parameters
+ self.in_features = in_features
+ self.out_features = out_features
+ self.split_sizes = split_sizes
+ self.parallel_input = parallel_input
+ self.skip_bias_add = skip_bias_add
+ self.process_group = process_group
+ self.seq_parallel_mode = seq_parallel_mode
+ self.seq_parallel_dim = seq_parallel_dim
+ self.num_partitions = dist.get_world_size(self.process_group)
+ self.fp8_communication = fp8_communication
+
+ assert (
+ sum(split_sizes) == in_features
+ ), f"The sum of split_sizes({sum(split_sizes)}) should be equal to in_features({in_features})."
+
+ if skip_bias_add and not bias:
+ raise ValueError("cannot skip bias addition if bias is None")
+
+ # offset the seed with randomizer index and rank
+ seed = torch.random.initial_seed()
+ self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group)
+
+ # sanity check
+ if weight is not None:
+ assert not bias or bias_ is not None, "bias_ must be provided if bias is True when weight is not None"
+ else:
+ assert bias_ is None, "bias_ must be None if weight is None"
+
+ # Parameters.
+ if weight is None:
+ # Initialize weight.
+ factory_kwargs = {"device": device, "dtype": dtype}
+ self.weight = Parameter(torch.empty(self.out_features, self.in_features, **factory_kwargs))
+ else:
+ weight.data = weight.data.to(device=device, dtype=dtype)
+ self.weight = weight
+
+ def shard_fn(tensor):
+ return split_fused_qkv_in_gpt2_style(tensor, self.split_sizes, self.process_group, True)
+
+ def gather_fn(tensor):
+ return gather_fused_qkv_in_gpt2_style(tensor, self.split_sizes, self.process_group, True)
+
+ if not is_customized_distributed_tensor(self.weight):
+ with torch.no_grad():
+ sharded_weight = distribute_tensor_with_customization(self.weight.data, shard_fn, gather_fn)
+ customized_distributed_tensor_to_existing_param(sharded_weight, self.weight)
+
+ if bias:
+ if bias_ is None:
+ self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs))
+ else:
+ bias_.data = bias_.data.to(device=device, dtype=dtype)
+ self.bias = bias_
+ else:
+ self.bias = None
+
+ if weight is None:
+ with self.randomizer.fork_rng(enable_cpu=True):
+ self.reset_parameters(weight_initializer, bias_initializer)
+
+ @staticmethod
+ def from_native_module(
+ module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], split_sizes: List[int], **kwargs
+ ) -> ParallelModule:
+ r"""
+ Convert a native PyTorch linear layer to a parallelized linear layer.
+ """
+ LazyInitContext.materialize(module)
+ # get the attributes
+ in_features = module.in_features
+ out_features = module.out_features
+ bias = module.bias is not None
+ device = module.weight.device
+
+ # ensure only one process group is passed
+ if isinstance(process_group, (list, tuple)):
+ assert len(process_group) == 1, f"Expected only one process group, got {len(process_group)}."
+ process_group = process_group[0]
+
+ linear_1d = FusedLinear1D_Row(
+ in_features=in_features,
+ out_features=out_features,
+ bias=bias,
+ device=device,
+ process_group=process_group,
+ weight=module.weight,
+ bias_=module.bias,
+ split_sizes=split_sizes,
+ **kwargs,
+ )
+
+ return linear_1d
+
+ @torch.no_grad()
+ def reset_parameters(self, weight_initializer, bias_initializer) -> None:
+ fan_in, fan_out = self.in_features, self.out_features
+ weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
+
+ if self.bias is not None:
+ bias_initializer(self.bias, fan_in=fan_in)
+ if self.process_group is None:
+ src_rank = 0
+ else:
+ src_rank = dist.distributed_c10d._get_global_rank(self.process_group, 0)
+
+ origin_device = self.bias.device
+ bias = self.bias.cuda()
+ dist.broadcast(bias, src=src_rank, group=self.process_group)
+ bias = bias.to(origin_device)
+ self.bias.copy_(bias)
+
+ def forward(self, input_: Tensor) -> Tensor:
+ # Set up backprop all-reduce.
+ if self.parallel_input:
+ assert (
+ input_.shape[-1] == self.weight.shape[-1]
+ ), "Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.".format(
+ input_.shape, self.weight.shape, self.weight.shape[-1]
+ )
+ input_ = input_
+ else:
+ assert (
+ divide(input_.shape[-1], self.num_partitions) == self.weight.shape[-1]
+ ), "Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.".format(
+ input_.shape, self.weight.shape, self.weight.shape[-1] * self.num_partitions
+ )
+ input_ = split_forward_gather_backward_fused_qkv(input_, self.split_sizes, self.process_group)
+
+ if is_share_sp_tp(self.seq_parallel_mode):
+ output = linear_reducescatter_forward_gather_backward(
+ input_,
+ self.weight,
+ process_group=self.process_group,
+ dim=self.seq_parallel_dim,
+ ring=self.seq_parallel_mode == "ring",
+ )
+ else:
+ output_parallel = F.linear(input_, self.weight)
+ output = reduce_forward(output_parallel, self.process_group, fp8_communication=self.fp8_communication)
+
+ if not self.skip_bias_add:
+ if self.bias is not None:
+ output = output + self.bias
+ return output
+ else:
+ return output, self.bias
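And the row-parallel counterpart used by FusedLinear1D_Row above, sketched numerically in a single process: each emulated rank holds a slice of the input features and of the weight's in_features, and summing the partial products (what reduce_forward's all-reduce does) recovers the full linear output.

import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(4, 8)
w = torch.randn(6, 8)                                    # (out_features, in_features)
full = F.linear(x, w)
partials = [F.linear(x[:, :4], w[:, :4]),                # "rank 0" partial product
            F.linear(x[:, 4:], w[:, 4:])]                # "rank 1" partial product
assert torch.allclose(full, partials[0] + partials[1], atol=1e-5)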
diff --git a/colossalai/shardformer/layer/utils.py b/colossalai/shardformer/layer/utils.py
index 4512e0c68..2df68e18c 100644
--- a/colossalai/shardformer/layer/utils.py
+++ b/colossalai/shardformer/layer/utils.py
@@ -295,8 +295,8 @@ def split_batch_zigzag(
batch: Union[torch.Tensor, List[torch.Tensor]], sp_group: ProcessGroup, seq_dim: int = 1, is_label: bool = False
) -> Union[torch.Tensor, List[torch.Tensor]]:
"""
- Split the input along the sequence dimension for Ring Attention. Naively spliting the attention mask
- in the causal setting will result in the preceding ranks having much less workload.
+ Split the input batch along the sequence dimension. Naively splitting the attention mask in the causal setting
+ will result in the preceding ranks having much less workload.
We split after "folding" the 2D attention mask in half (https://github.com/zhuzilin/ring-flash-attention/issues/2).
For example, for sp_size = 4 and seq_len = 8, we get | s0, s7 | s1, s6 | s2, s5 | s3, s4 |.
@@ -346,40 +346,42 @@ def split_varlen_zigzag(
cu_seqlens: torch.Tensor,
sp_group: ProcessGroup,
max_seqlen: int = 0,
- is_2d: bool = False,
+ is_batched_seq: bool = False,
is_label: bool = False,
) -> Union[List[torch.Tensor], torch.Tensor]:
- """Split each sequence in a batch of packed sequences in a zigzag fashion.
- For each tensor in batch, return packed sequences if is_2d is False;
- else return a padded batch of sequences.
-
+ """Split a packed seq/batch of padded sequences in a Zigzag fashion.
+ Different from split_batch_zigzag, inputs here have variable sequence lengths.
Args:
- batch (List[torch.Tensor]): Packed sequences of shape (B * Sq, ...), or (B, Sq, ...) if is_2d.
+ batch (List[torch.Tensor]): Packed sequences of shape (T, ...), or (B, Sq, ...) if is_batched_seq,
+ where T is the total number of tokens.
cu_seqlens (torch.Tensor): Cumulative sequence lengths of shape (B + 1) before splitting.
sp_group (ProcessGroup): The process group for sequence parallelism.
max_seqlen (int): The maximum sequence length in the batch before splitting.
- is_2d (bool): If True, then input has batch size and sequence length split into two dimensions.
+ is_batched_seq (bool): If True, the input is a batch of sequences padded to the same length.
is_label (bool): If True, mask out the first token in each sequence ().
Returns:
- batch (List[torch.Tensor]): Packed sequences of shape (B * max_seqlen // sp_size)
- or (B, max_seqlen // sp_size, ...) if is_2d
+ batch (List[torch.Tensor]): Packed sequences of shape (T, ...)
+ or (B, max_seqlen // sp_size, ...) if is_batched_seq
"""
sp_size = dist.get_world_size(sp_group)
sp_rank = dist.get_rank(sp_group)
if sp_size == 1:
return batch
- if is_2d:
+ if is_batched_seq:
assert max_seqlen > 0, "max_seqlen must be provided for 2D input"
if isinstance(batch, torch.Tensor):
batch = [batch]
for i, packed_seq in enumerate(batch):
device = packed_seq.device
dtype = packed_seq.dtype
- if is_2d:
+ if is_batched_seq:
assert max_seqlen % (sp_size * 2) == 0
# Recreate a padded tensor with the new max seqlen
shape = (packed_seq.shape[0], max_seqlen // sp_size, *packed_seq.shape[2:])
@@ -398,7 +400,7 @@ def split_varlen_zigzag(
seqlen % (2 * sp_size) == 0
), f"batch {i} seq {j}'s length ({seqlen}) must be divisible by 2 * sp_size = {2 * sp_size} for splitting"
- if is_2d:
+ if is_batched_seq:
seq = packed_seq[j][:seqlen]
if is_label:
# Shift one position to the right for next token prediction
@@ -415,7 +417,7 @@ def split_varlen_zigzag(
seq = seq.chunk(sp_size * 2)
local_seq.extend([seq[sp_rank], seq[2 * sp_size - 1 - sp_rank]])
- if is_2d:
+ if is_batched_seq:
batch[i] = local_seq.contiguous()
else:
batch[i] = torch.cat(local_seq, dim=0)
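A minimal single-process sketch of the zigzag split documented above, assuming the sequence length is divisible by 2 * sp_size: rank r keeps chunks r and 2 * sp_size - 1 - r, which balances causal-attention work across ranks.

import torch

def zigzag_split_sketch(seq, sp_size, rank, dim=0):
    # Keep the r-th chunk and its mirror chunk from the end.
    chunks = seq.chunk(sp_size * 2, dim=dim)
    return torch.cat([chunks[rank], chunks[2 * sp_size - 1 - rank]], dim=dim)

s = torch.arange(8)  # tokens s0..s7
print([zigzag_split_sketch(s, sp_size=4, rank=r).tolist() for r in range(4)])
# [[0, 7], [1, 6], [2, 5], [3, 4]]  -> | s0, s7 | s1, s6 | s2, s5 | s3, s4 |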
diff --git a/colossalai/shardformer/modeling/gpt2.py b/colossalai/shardformer/modeling/gpt2.py
index 798fca88f..d550484da 100644
--- a/colossalai/shardformer/modeling/gpt2.py
+++ b/colossalai/shardformer/modeling/gpt2.py
@@ -857,17 +857,17 @@ def get_gpt2_flash_attention_forward(shard_config: Optional[ShardConfig] = None)
dropout_p = self.attn_dropout.p if self.training else 0.0
sp_mode = shard_config.sequence_parallelism_mode
- sp_group = shard_config.sequence_parallel_process_group
if sp_mode == "ring_attn":
attn_output = RingAttention.attention(
query,
key,
value,
- sp_group,
+ sp_axis=shard_config.sp_axis,
**attention_mask,
dropout_p=dropout_p,
scale=scale,
inner_ring_size=shard_config.inner_ring_size,
+ pg_mesh=shard_config.pg_mesh,
)
else:
attn_output = ColoAttention.attention(query, key, value, **attention_mask, dropout_p=dropout_p, scale=scale)
diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py
index 7a04c5451..a51a1df9f 100644
--- a/colossalai/shardformer/modeling/llama.py
+++ b/colossalai/shardformer/modeling/llama.py
@@ -271,6 +271,7 @@ class LlamaPipelineForwards:
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None,
shard_config: ShardConfig = None,
+ **kwargs,
):
r"""
Args:
@@ -568,9 +569,10 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s
query_states,
key_states,
value_states,
- sp_group,
+ sp_axis=shard_config.sp_axis,
**attention_mask,
inner_ring_size=shard_config.inner_ring_size,
+ pg_mesh=shard_config.pg_mesh,
)
elif shard_config.enable_flash_attention:
diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py
index 4c33e14bc..09673d396 100644
--- a/colossalai/shardformer/policies/bert.py
+++ b/colossalai/shardformer/policies/bert.py
@@ -73,7 +73,6 @@ class BertPolicy(Policy):
)
sp_mode = "split_gather"
- overlap = self.shard_config.enable_sequence_overlap
sp_partial_derived = sp_mode == "split_gather"
if self.shard_config.enable_tensor_parallelism:
@@ -97,7 +96,6 @@ class BertPolicy(Policy):
target_module=col_nn.Linear1D_Col,
kwargs={
"seq_parallel_mode": sp_mode,
- "overlap": overlap,
"fp8_communication": self.shard_config.fp8_communication,
},
),
@@ -106,7 +104,6 @@ class BertPolicy(Policy):
target_module=col_nn.Linear1D_Col,
kwargs={
"seq_parallel_mode": sp_mode,
- "overlap": overlap,
"fp8_communication": self.shard_config.fp8_communication,
},
),
@@ -115,7 +112,6 @@ class BertPolicy(Policy):
target_module=col_nn.Linear1D_Col,
kwargs={
"seq_parallel_mode": sp_mode,
- "overlap": overlap,
"fp8_communication": self.shard_config.fp8_communication,
},
),
@@ -140,7 +136,6 @@ class BertPolicy(Policy):
target_module=col_nn.Linear1D_Col,
kwargs={
"seq_parallel_mode": sp_mode,
- "overlap": overlap,
"skip_bias_add": self.enable_bias_gelu_fused,
"fp8_communication": self.shard_config.fp8_communication,
},
diff --git a/colossalai/shardformer/policies/blip2.py b/colossalai/shardformer/policies/blip2.py
index da798f6a0..2e73d5c2a 100644
--- a/colossalai/shardformer/policies/blip2.py
+++ b/colossalai/shardformer/policies/blip2.py
@@ -71,7 +71,7 @@ class BlipPolicy(Policy):
suffix="self_attn.qkv",
target_module=col_nn.FusedLinear1D_Col,
kwargs={
- "n_fused": 3,
+ "split_sizes": [self.model.config.vision_config.hidden_size] * 3,
"fp8_communication": self.shard_config.fp8_communication,
},
),
diff --git a/colossalai/shardformer/policies/bloom.py b/colossalai/shardformer/policies/bloom.py
index a43ac02d0..7c6259e85 100644
--- a/colossalai/shardformer/policies/bloom.py
+++ b/colossalai/shardformer/policies/bloom.py
@@ -57,7 +57,6 @@ class BloomPolicy(Policy):
)
sp_mode = "split_gather"
- overlap = self.shard_config.enable_sequence_overlap
sp_partial_derived = sp_mode == "split_gather"
if self.shard_config.enable_tensor_parallelism:
@@ -78,7 +77,6 @@ class BloomPolicy(Policy):
target_module=col_nn.Linear1D_Col,
kwargs={
"seq_parallel_mode": sp_mode,
- "overlap": overlap,
"fp8_communication": self.shard_config.fp8_communication,
},
),
@@ -99,7 +97,6 @@ class BloomPolicy(Policy):
target_module=col_nn.Linear1D_Col,
kwargs={
"seq_parallel_mode": sp_mode,
- "overlap": overlap,
"fp8_communication": self.shard_config.fp8_communication,
},
),
diff --git a/colossalai/shardformer/policies/chatglm2.py b/colossalai/shardformer/policies/chatglm2.py
index 1b7d2db85..c003570a0 100644
--- a/colossalai/shardformer/policies/chatglm2.py
+++ b/colossalai/shardformer/policies/chatglm2.py
@@ -67,7 +67,6 @@ class ChatGLMPolicy(Policy):
f"For ChatGLM2, sequence parallelism doesn't support mode {sp_mode} yet, will set to be split_gather"
)
sp_mode = "split_gather"
- overlap = self.shard_config.enable_sequence_overlap
sp_partial_derived = sp_mode in ["split_gather"]
if sp_mode == "all_to_all":
@@ -127,7 +126,6 @@ class ChatGLMPolicy(Policy):
kwargs={
"seq_parallel_mode": sp_mode,
"seq_parallel_dim": 0,
- "overlap": overlap,
"fp8_communication": self.shard_config.fp8_communication,
},
),
diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py
index d9233be9a..08accaaea 100644
--- a/colossalai/shardformer/policies/gpt2.py
+++ b/colossalai/shardformer/policies/gpt2.py
@@ -65,7 +65,6 @@ class GPT2Policy(Policy):
f"For GPT2, sequence parallelism is currently not support mode {sp_mode}, will set to be split_gather"
)
self.shard_config.sequence_parallelism_mode = sp_mode = "split_gather"
- overlap = self.shard_config.enable_sequence_overlap
sp_partial_derived = sp_mode in ["split_gather", "ring"]
use_flash_attention = self.shard_config.enable_flash_attention
if self.shard_config.enable_tensor_parallelism:
@@ -92,9 +91,8 @@ class GPT2Policy(Policy):
suffix="attn.c_attn",
target_module=col_nn.GPT2FusedLinearConv1D_Col,
kwargs={
- "n_fused": 3,
+ "split_sizes": [self.model.config.hidden_size] * 3,
"seq_parallel_mode": sp_mode,
- "overlap": overlap,
"fp8_communication": self.shard_config.fp8_communication,
},
),
@@ -107,9 +105,8 @@ class GPT2Policy(Policy):
suffix="mlp.c_fc",
target_module=col_nn.GPT2FusedLinearConv1D_Col,
kwargs={
- "n_fused": 1,
+ "split_sizes": [self.model.config.n_inner or 4 * self.model.config.hidden_size],
"seq_parallel_mode": sp_mode,
- "overlap": overlap,
"skip_bias_add": self.enable_bias_gelu_fused,
"fp8_communication": self.shard_config.fp8_communication,
},
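For context on how the policy derives the new split_sizes kwargs above, a small sketch using a Hugging Face GPT2Config; the values shown are for the default gpt2 configuration and are purely illustrative.

from transformers import GPT2Config

config = GPT2Config()
qkv_split_sizes = [config.hidden_size] * 3                    # attn.c_attn -> [768, 768, 768]
mlp_split_sizes = [config.n_inner or 4 * config.hidden_size]  # mlp.c_fc -> [3072]
print(qkv_split_sizes, mlp_split_sizes)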
diff --git a/colossalai/shardformer/policies/gptj.py b/colossalai/shardformer/policies/gptj.py
index 6f0c8803c..9fcca1385 100644
--- a/colossalai/shardformer/policies/gptj.py
+++ b/colossalai/shardformer/policies/gptj.py
@@ -51,7 +51,6 @@ class GPTJPolicy(Policy):
self.shard_config.enable_sequence_parallelism = False
warnings.warn("GPTJ doesn't support sequence parallelism now, will ignore the sequence parallelism flag.")
- overlap = self.shard_config.enable_sequence_overlap
if self.shard_config.enable_tensor_parallelism:
assert (
self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
@@ -76,7 +75,6 @@ class GPTJPolicy(Policy):
suffix="attn.k_proj",
target_module=col_nn.Linear1D_Col,
kwargs={
- "overlap": overlap,
"fp8_communication": self.shard_config.fp8_communication,
},
),
@@ -84,7 +82,6 @@ class GPTJPolicy(Policy):
suffix="attn.q_proj",
target_module=col_nn.Linear1D_Col,
kwargs={
- "overlap": overlap,
"fp8_communication": self.shard_config.fp8_communication,
},
),
@@ -92,7 +89,6 @@ class GPTJPolicy(Policy):
suffix="attn.v_proj",
target_module=col_nn.Linear1D_Col,
kwargs={
- "overlap": overlap,
"fp8_communication": self.shard_config.fp8_communication,
},
),
diff --git a/colossalai/shardformer/policies/sam.py b/colossalai/shardformer/policies/sam.py
index 674fe5e58..a94cc9119 100644
--- a/colossalai/shardformer/policies/sam.py
+++ b/colossalai/shardformer/policies/sam.py
@@ -42,7 +42,7 @@ class SamPolicy(Policy):
suffix="attn.qkv",
target_module=col_nn.FusedLinear1D_Col,
kwargs={
- "n_fused": 3,
+ "split_sizes": [self.model.config.vision_config.hidden_size] * 3,
"fp8_communication": self.shard_config.fp8_communication,
},
),
diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py
index 1219119bb..4d4a1803b 100644
--- a/colossalai/shardformer/shard/shard_config.py
+++ b/colossalai/shardformer/shard/shard_config.py
@@ -26,7 +26,6 @@ class ShardConfig:
enable_flash_attention (bool, optional): Whether to switch on flash attention. Defaults to False.
enable_jit_fused (bool, optional): Whether to switch on JIT fused operators. Defaults to False.
enable_sequence_parallelism (bool): Whether to turn on sequence parallelism, which partitions non-tensor-parallel regions along the sequence dimension. Defaults to False.
- enable_sequence_overlap (bool): Whether to turn on sequence overlap, which overlap the computation and communication in sequence parallelism. It can only be used when enable_sequence_parallelism is True. Defaults to False.
gradient_checkpoint_config (Optional[GradientCheckpointConfig]): The gradient checkpoint config. Defaults to None.
enable_all_optimization (bool): Whether to turn on all optimization tools including 'fused normalization', 'flash attention', 'JIT fused operators', 'sequence parallelism' and 'sequence overlap'. Defaults to False.
fp8_communication (bool, optional): Whether to enable fp8 communication in model parallelism. Defaults to False.
@@ -44,13 +43,14 @@ class ShardConfig:
enable_jit_fused: bool = False
enable_sequence_parallelism: bool = False
sequence_parallelism_mode: str = None
- enable_sequence_overlap: bool = False
parallel_output: bool = True
make_vocab_size_divisible_by: int = 64
gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None
extra_kwargs: Dict[str, Any] = field(default_factory=dict)
# For ring attention
+ sp_axis: Optional[int] = None
+ pg_mesh: Optional[int] = None
inner_ring_size: Optional[int] = None
# for moe related
moe_dp_group: Optional[ProcessGroup] = None
@@ -84,24 +84,12 @@ class ShardConfig:
assert (
self.enable_tensor_parallelism
), f"sequence parallelism mode {self.sequence_parallelism_mode} can only be used when enable_tensor_parallelism is True"
- elif self.sequence_parallelism_mode in ["all_to_all"]:
- # assert (
- # not self.enable_tensor_parallelism
- # ), f"sequence parallelism mode {self.sequence_parallelism_mode} can only be used when enable_tensor_parallelism is False"
- if self.enable_sequence_overlap:
- self.enable_sequence_overlap = False
- warnings.warn(
- f"The enable_sequence_overlap flag will be ignored in sequence parallelism mode {self.sequence_parallelism_mode}"
- )
else:
if self.sequence_parallelism_mode:
self.sequence_parallelism_mode = None
warnings.warn(
f"The sequence_parallelism_mode will be ignored when enable_sequence_parallelism is False"
)
- assert (
- not self.enable_sequence_overlap
- ), f"enable_sequence_overlap can only be set to True when enable_sequence_parallelism is True"
# get the tensor parallel size
if not self.enable_tensor_parallelism:
@@ -134,4 +122,3 @@ class ShardConfig:
# This can cause non-in-place param sharding when used without ZeRO.
# It may also slow down training when seq len is small. Plz enable manually.
# self.enable_sequence_parallelism = True
- # self.enable_sequence_overlap = True
diff --git a/colossalai/utils/__init__.py b/colossalai/utils/__init__.py
index cdba46709..1605a5f4e 100644
--- a/colossalai/utils/__init__.py
+++ b/colossalai/utils/__init__.py
@@ -5,6 +5,7 @@ from .common import (
ensure_path_exists,
free_storage,
get_current_device,
+ get_non_persistent_buffers_set,
is_ddp_ignored,
set_seed,
)
@@ -25,4 +26,5 @@ __all__ = [
"set_seed",
"get_current_device",
"is_ddp_ignored",
+ "get_non_persistent_buffers_set",
]
diff --git a/colossalai/utils/common.py b/colossalai/utils/common.py
index 4a1889eb5..0863a812b 100644
--- a/colossalai/utils/common.py
+++ b/colossalai/utils/common.py
@@ -5,10 +5,11 @@ import os
import random
from contextlib import contextmanager
from pathlib import Path
-from typing import Callable
+from typing import Callable, Optional, Set
import numpy as np
import torch
+import torch.nn as nn
from colossalai.accelerator import get_accelerator
@@ -76,3 +77,34 @@ def set_seed(seed):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
+
+
+def get_non_persistent_buffers_set(
+ module: nn.Module, memo: Optional[Set[nn.Module]] = None, prefix: str = "", remove_duplicate: bool = True
+):
+ r"""
+ Args:
+ memo: a memo to store the set of modules already added to the result
+ prefix: a prefix that will be added to the name of the module
+ remove_duplicate: whether to remove the duplicated module instances in the result
+ or not
+ """
+
+ if memo is None:
+ memo = set()
+ self_non_persistent_set = set()
+ if module not in memo:
+ if remove_duplicate:
+ memo.add(module)
+ self_non_persistent_set = set(
+ map(lambda key: prefix + ("." if prefix else "") + key, module._non_persistent_buffers_set)
+ )
+ for name, sub_module in module._modules.items():
+ if sub_module is None:
+ continue
+ submodule_prefix = prefix + ("." if prefix else "") + name
+ child_non_persistent_set = get_non_persistent_buffers_set(
+ sub_module, memo, submodule_prefix, remove_duplicate
+ )
+ self_non_persistent_set = set.union(self_non_persistent_set, child_non_persistent_set)
+ return self_non_persistent_set
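A short usage sketch of the relocated helper, assuming this change is applied so it is importable from colossalai.utils: non-persistent buffers are returned with their fully qualified names, while persistent buffers are ignored.

import torch
import torch.nn as nn
from colossalai.utils import get_non_persistent_buffers_set

class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("running_stat", torch.zeros(4), persistent=False)
        self.register_buffer("step", torch.zeros(1))  # persistent, not collected

model = nn.Sequential(Block(), Block())
print(get_non_persistent_buffers_set(model))  # {'0.running_stat', '1.running_stat'}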
diff --git a/colossalai/utils/safetensors.py b/colossalai/utils/safetensors.py
new file mode 100644
index 000000000..9aa3558d9
--- /dev/null
+++ b/colossalai/utils/safetensors.py
@@ -0,0 +1,64 @@
+# a python safetensors serializer modified from https://github.com/huggingface/safetensors/blob/41bd1acf38ad28ac559522d40596c6c802f79453/safetensors/src/tensor.rs#L214
+import json
+from dataclasses import asdict, dataclass
+from typing import Dict, List, Tuple
+
+import torch
+from safetensors.torch import _TYPES
+
+try:
+ from tensornvme.async_file_io import AsyncFileWriter
+except ModuleNotFoundError:
+ raise ModuleNotFoundError("Please install tensornvme to use the async safetensors writer")
+_TYPES_INV = {v: k for k, v in _TYPES.items()}
+
+
+@dataclass
+class TensorInfo:
+ dtype: str
+ shape: List[int]
+ data_offsets: Tuple[int, int]
+
+
+@dataclass
+class PreparedData:
+ n: int
+ header_bytes: bytes
+ offset: int
+
+
+def prepare(data: Dict[str, torch.Tensor]) -> Tuple[PreparedData, List[torch.Tensor]]:
+ sorted_data = sorted(data.items(), key=lambda x: (x[1].dtype, x[0]))
+
+ tensors = []
+ metadata = {}
+ offset = 0
+
+ for name, tensor in sorted_data:
+ n = tensor.numel() * tensor.element_size()
+ tensor_info = TensorInfo(
+ dtype=_TYPES_INV[tensor.dtype], shape=list(tensor.shape), data_offsets=(offset, offset + n)
+ )
+ offset += n
+ metadata[name] = asdict(tensor_info)
+ tensors.append(tensor)
+
+ metadata_buf = json.dumps(metadata).encode("utf-8")
+
+ extra = (8 - len(metadata_buf) % 8) % 8
+ metadata_buf += b" " * extra
+
+ n = len(metadata_buf)
+
+ return PreparedData(n=n, header_bytes=metadata_buf, offset=offset), tensors
+
+
+def save(f_writer: AsyncFileWriter, state_dict: Dict[str, torch.Tensor]) -> None:
+ prepared_data, tensors = prepare(state_dict)
+ n, header_bytes, _ = prepared_data.n, prepared_data.header_bytes, prepared_data.offset
+
+ f_writer.write(n.to_bytes(8, byteorder="little"))
+ f_writer.write(header_bytes)
+
+ for tensor in tensors:
+ f_writer.write_raw(tensor, tensor.data_ptr(), tensor.numel() * tensor.element_size(), f_writer.offset)
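For orientation, the bytes that save() streams through AsyncFileWriter follow the standard safetensors layout: an 8-byte little-endian header length, the JSON header padded with spaces to an 8-byte boundary, then the raw tensor data at the recorded data_offsets. A standalone sketch of that layout for a single tensor, with illustrative values:

import json
import torch

t = torch.arange(4, dtype=torch.float32)  # 4 * 4 bytes = 16 bytes of data
header = {"w": {"dtype": "F32", "shape": [4], "data_offsets": [0, 16]}}
header_buf = json.dumps(header).encode("utf-8")
header_buf += b" " * ((8 - len(header_buf) % 8) % 8)  # pad to an 8-byte boundary
blob = len(header_buf).to_bytes(8, "little") + header_buf + t.numpy().tobytes()
# `blob` mirrors what save() writes to disk for the state dict {"w": t}.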
diff --git a/colossalai/zero/gemini/gemini_ddp.py b/colossalai/zero/gemini/gemini_ddp.py
index 9111c3b5d..fe5fb82ca 100644
--- a/colossalai/zero/gemini/gemini_ddp.py
+++ b/colossalai/zero/gemini/gemini_ddp.py
@@ -35,7 +35,7 @@ from colossalai.tensor.padded_tensor import (
to_unpadded_tensor,
)
from colossalai.tensor.param_op_hook import ColoParamOpHookManager
-from colossalai.utils import _cast_float, free_storage, is_ddp_ignored
+from colossalai.utils import _cast_float, free_storage, get_non_persistent_buffers_set, is_ddp_ignored
from .chunk import Chunk, ChunkManager, TensorState, init_chunk_manager
from .gemini_hook import GeminiZeROHook
@@ -187,7 +187,7 @@ class GeminiDDP(ModelWrapper):
pin_memory=pin_memory,
)
super().__init__(module)
- self._non_persistent_buffers_set = self._get_non_persistent_buffers_set(module)
+ self._non_persistent_buffers_set = get_non_persistent_buffers_set(module)
self._cast_buffers()
# register grad hook
@@ -257,36 +257,6 @@ class GeminiDDP(ModelWrapper):
for p in params_to_ignore:
p._ddp_to_ignore = True
- def _get_non_persistent_buffers_set(
- self, module, memo: Optional[Set[nn.Module]] = None, prefix: str = "", remove_duplicate: bool = True
- ):
- r"""
- Args:
- memo: a memo to store the set of modules already added to the result
- prefix: a prefix that will be added to the name of the module
- remove_duplicate: whether to remove the duplicated module instances in the result
- or not
- """
-
- if memo is None:
- memo = set()
- self_non_persistent_set = set()
- if module not in memo:
- if remove_duplicate:
- memo.add(module)
- self_non_persistent_set = set(
- map(lambda key: prefix + ("." if prefix else "") + key, module._non_persistent_buffers_set)
- )
- for name, sub_module in module._modules.items():
- if sub_module is None:
- continue
- submodule_prefix = prefix + ("." if prefix else "") + name
- child_non_persistent_set = self._get_non_persistent_buffers_set(
- sub_module, memo, submodule_prefix, remove_duplicate
- )
- self_non_persistent_set = set.union(self_non_persistent_set, child_non_persistent_set)
- return self_non_persistent_set
-
def _post_forward(self):
"""This function is only triggered for inference."""
access_list = list(self.chunk_manager.accessed_chunks)
diff --git a/colossalai/zero/gemini/memory_tracer/runtime_mem_tracer.py b/colossalai/zero/gemini/memory_tracer/runtime_mem_tracer.py
index b0d258824..81520326f 100644
--- a/colossalai/zero/gemini/memory_tracer/runtime_mem_tracer.py
+++ b/colossalai/zero/gemini/memory_tracer/runtime_mem_tracer.py
@@ -1,10 +1,5 @@
import torch.nn
-from colossalai.legacy.zero.gemini.ophooks.runtime_mem_tracer_hook import (
- GradMemStats,
- GradMemTracerHook,
- ParamMemTracerHook,
-)
from colossalai.tensor.param_op_hook import ColoParamOpHookManager
from colossalai.utils import _cast_float
@@ -27,6 +22,12 @@ class RuntimeMemTracer:
def __init__(self, module: torch.nn.Module, dtype: torch.dtype = torch.half):
super().__init__()
+ from colossalai.legacy.zero.gemini.ophooks.runtime_mem_tracer_hook import (
+ GradMemStats,
+ GradMemTracerHook,
+ ParamMemTracerHook,
+ )
+
self.module = module
self.dtype = dtype
self._gradstat = GradMemStats()
diff --git a/colossalai/zero/gemini/placement_policy.py b/colossalai/zero/gemini/placement_policy.py
index 178755d03..2aa8dc3f6 100644
--- a/colossalai/zero/gemini/placement_policy.py
+++ b/colossalai/zero/gemini/placement_policy.py
@@ -8,7 +8,6 @@ import torch
import torch.distributed as dist
from colossalai.accelerator import get_accelerator
-from colossalai.legacy.utils.memory import colo_device_memory_capacity
from colossalai.zero.gemini.chunk import Chunk
from .chunk import Chunk, ChunkManager
@@ -172,6 +171,8 @@ class AutoPlacementPolicy(PlacementPolicy):
Returns:
int: the volume of memory that is evicted
"""
+ from colossalai.legacy.utils.memory import colo_device_memory_capacity
+
start = time()
cuda_capacity = colo_device_memory_capacity(get_accelerator().get_current_device())
used_cuda_model_data = self.chunk_manager.total_mem["cuda"]
diff --git a/docs/README-zh-Hans.md b/docs/README-zh-Hans.md
index ef121d348..c32e5f13c 100644
--- a/docs/README-zh-Hans.md
+++ b/docs/README-zh-Hans.md
@@ -25,15 +25,13 @@
## 新闻
+* [2024/10] [How to build a low-cost Sora-like app? Solutions for you](https://company.hpc-ai.com/blog/how-to-build-a-low-cost-sora-like-app-solutions-for-you)
+* [2024/09] [Singapore Startup HPC-AI Tech Secures 50 Million USD in Series A Funding to Build the Video Generation AI Model and GPU Platform](https://company.hpc-ai.com/blog/singapore-startup-hpc-ai-tech-secures-50-million-usd-in-series-a-funding-to-build-the-video-generation-ai-model-and-gpu-platform)
+* [2024/09] [Reducing AI Large Model Training Costs by 30% Requires Just a Single Line of Code From FP8 Mixed Precision Training Upgrades](https://company.hpc-ai.com/blog/reducing-ai-large-model-training-costs-by-30-requires-just-a-single-line-of-code-from-fp8-mixed-precision-training-upgrades)
* [2024/06] [Open-Sora Continues Open Source: Generate Any 16-Second 720p HD Video with One Click, Model Weights Ready to Use](https://hpc-ai.com/blog/open-sora-from-hpc-ai-tech-team-continues-open-source-generate-any-16-second-720p-hd-video-with-one-click-model-weights-ready-to-use)
* [2024/05] [Large AI Models Inference Speed Doubled, Colossal-Inference Open Source Release](https://hpc-ai.com/blog/colossal-inference)
* [2024/04] [Open-Sora Unveils Major Upgrade: Embracing Open Source with Single-Shot 16-Second Video Generation and 720p Resolution](https://hpc-ai.com/blog/open-soras-comprehensive-upgrade-unveiled-embracing-16-second-video-generation-and-720p-resolution-in-open-source)
* [2024/04] [Most cost-effective solutions for inference, fine-tuning and pretraining, tailored to LLaMA3 series](https://hpc-ai.com/blog/most-cost-effective-solutions-for-inference-fine-tuning-and-pretraining-tailored-to-llama3-series)
-* [2024/03] [314 Billion Parameter Grok-1 Inference Accelerated by 3.8x, Efficient and Easy-to-Use PyTorch+HuggingFace version is Here](https://hpc-ai.com/blog/314-billion-parameter-grok-1-inference-accelerated-by-3.8x-efficient-and-easy-to-use-pytorchhuggingface-version-is-here)
-* [2024/03] [Open-Sora: Revealing Complete Model Parameters, Training Details, and Everything for Sora-like Video Generation Models](https://hpc-ai.com/blog/open-sora-v1.0)
-* [2024/03] [Open-Sora:Sora Replication Solution with 46% Cost Reduction, Sequence Expansion to Nearly a Million](https://hpc-ai.com/blog/open-sora)
-* [2024/01] [Inference Performance Improved by 46%, Open Source Solution Breaks the Length Limit of LLM for Multi-Round Conversations](https://hpc-ai.com/blog/Colossal-AI-SwiftInfer)
-* [2023/07] [HPC-AI Tech Raises 22 Million USD in Series A Funding](https://www.hpc-ai.tech/blog/hpc-ai-tech-raises-22-million-usd-in-series-a-funding-to-fuel-team-expansion-and-business-growth)
## 目录
diff --git a/docs/source/en/features/mixed_precision_training_with_booster.md b/docs/source/en/features/mixed_precision_training_with_booster.md
index 65304b1f4..1e17c2bb5 100644
--- a/docs/source/en/features/mixed_precision_training_with_booster.md
+++ b/docs/source/en/features/mixed_precision_training_with_booster.md
@@ -16,7 +16,7 @@ Author: [Mingyan Jiang](https://github.com/jiangmingyan)
AMP stands for automatic mixed precision training.
In Colossal-AI, we have incorporated different implementations of mixed precision training:
-1. torch.cuda.amp
+1. torch.amp
2. apex.amp
3. naive amp
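For readers updating older snippets, a minimal sketch of the renamed torch.amp API referenced above; it requires a CUDA device and a recent PyTorch, and the model and shapes are illustrative only.

import torch

model = torch.nn.Linear(8, 8).cuda()
scaler = torch.amp.GradScaler("cuda")
with torch.amp.autocast(device_type="cuda", dtype=torch.float16):
    out = model(torch.randn(4, 8, device="cuda"))
scaler.scale(out.sum()).backward()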
diff --git a/docs/source/zh-Hans/features/mixed_precision_training_with_booster.md b/docs/source/zh-Hans/features/mixed_precision_training_with_booster.md
index da377ceb2..93a69830c 100644
--- a/docs/source/zh-Hans/features/mixed_precision_training_with_booster.md
+++ b/docs/source/zh-Hans/features/mixed_precision_training_with_booster.md
@@ -16,7 +16,7 @@
AMP 代表自动混合精度训练。
在 Colossal-AI 中, 我们结合了混合精度训练的不同实现:
-1. torch.cuda.amp
+1. torch.amp
2. apex.amp
3. naive amp
diff --git a/examples/language/llama/benchmark.py b/examples/language/llama/benchmark.py
index 4976f0c37..ad5d35161 100644
--- a/examples/language/llama/benchmark.py
+++ b/examples/language/llama/benchmark.py
@@ -163,6 +163,8 @@ def main():
enable_async_reduce=not args.disable_async_reduce,
use_fp8=args.use_fp8,
fp8_communication=args.use_fp8_comm,
)
elif args.plugin == "gemini_auto":
plugin = GeminiPlugin(
@@ -177,6 +179,8 @@ def main():
enable_flash_attention=args.xformers,
use_fp8=args.use_fp8,
fp8_communication=args.use_fp8_comm,
)
elif args.plugin == "fsdp":
if use_empty_init:
@@ -188,6 +192,7 @@ def main():
),
param_init_fn=empty_init(),
fp8_communication=args.use_fp8_comm,
)
else:
plugin = TorchFSDPPlugin(
@@ -209,6 +214,7 @@ def main():
cpu_offload=CPUOffload(offload_params=True),
param_init_fn=empty_init(),
fp8_communication=args.use_fp8_comm,
)
else:
plugin = TorchFSDPPlugin(
@@ -219,6 +225,7 @@ def main():
),
cpu_offload=CPUOffload(offload_params=True),
fp8_communication=args.use_fp8_comm,
)
elif args.plugin == "3d":
if args.pp_style == "zbv":
diff --git a/extensions/cpp_extension.py b/extensions/cpp_extension.py
index aaa43f964..a92195dda 100644
--- a/extensions/cpp_extension.py
+++ b/extensions/cpp_extension.py
@@ -79,7 +79,7 @@ class _CppExtension(_Extension):
# check if the kernel has been built
compiled_before = False
- kernel_file_path = build_directory.joinpath(f"{self.name}.o")
+ kernel_file_path = build_directory.joinpath(f"{self.name}.so")
if kernel_file_path.exists():
compiled_before = True
diff --git a/extensions/cuda_extension.py b/extensions/cuda_extension.py
index da15bcd57..214b83ec8 100644
--- a/extensions/cuda_extension.py
+++ b/extensions/cuda_extension.py
@@ -74,7 +74,7 @@ class _CudaExtension(_CppExtension):
# check if the kernel has been built
compiled_before = False
- kernel_file_path = build_directory.joinpath(f"{self.name}.o")
+ kernel_file_path = build_directory.joinpath(f"{self.name}.so")
if kernel_file_path.exists():
compiled_before = True
diff --git a/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py b/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py
index 5aa8584a0..a45beb771 100644
--- a/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py
+++ b/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py
@@ -41,22 +41,7 @@ class Conv1D(nn.Module):
return x
-def rearrange(tensor: torch.Tensor, dim: int):
- tensor = tensor.clone()
- world_size = 2
- order = torch.arange(world_size * 3)
- new_order = []
- for i in range(world_size):
- new_order.append(order[i::world_size])
- new_order = torch.cat(new_order)
-
- tensor_chunks = torch.chunk(tensor, world_size * 3, dim=dim)
- rearanged_tensor_chunks = [tensor_chunks[i] for i in new_order]
- rearanged_tensor = torch.cat(rearanged_tensor_chunks, dim=dim)
- return rearanged_tensor
-
-
-def check_linear_conv_1d_col(lazy_init: bool, seq_parallel_mode: str, overlap: bool):
+def check_linear_conv_1d_col(lazy_init: bool, seq_parallel_mode: str):
ctx = LazyInitContext() if lazy_init else nullcontext()
linear = Conv1D(192, 48).cuda()
with ctx:
@@ -66,8 +51,7 @@ def check_linear_conv_1d_col(lazy_init: bool, seq_parallel_mode: str, overlap: b
process_group=None,
gather_output=True,
seq_parallel_mode=seq_parallel_mode,
- n_fused=3,
- overlap=overlap,
+ split_sizes=[64] * 3,
)
assert linear.weight.shape == torch.Size([48, 192])
@@ -88,13 +72,13 @@ def check_linear_conv_1d_col(lazy_init: bool, seq_parallel_mode: str, overlap: b
x.expand_as(x.clone()) if seq_parallel_mode is None else torch.chunk(x.clone(), 2, dim=1)[dist.get_rank()]
)
gather_out = linear_conv_col(x_for_shard)
- assert_close(rearrange(out, -1), gather_out)
+ assert_close(out, gather_out)
# check backward correctness
out.sum().backward()
gather_out.sum().backward()
- target_grad = split_fused_qkv_in_gpt2_style(linear.weight.grad, 3, None, True)
+ target_grad = split_fused_qkv_in_gpt2_style(linear.weight.grad, [64] * 3, None, True)
assert_close(target_grad, linear_conv_col.weight.grad)
@@ -136,9 +120,8 @@ def check_linear_conv_1d_row(lazy_init: bool, seq_parallel_mode: bool):
@parameterize("lazy_init", [False, True])
@parameterize("seq_parallel_mode", ["split_gather", None])
-@parameterize("overlap", [True])
-def check_gpt2_qkv_fused_linear_1d(lazy_init: bool, seq_parallel_mode: bool, overlap: bool):
- check_linear_conv_1d_col(lazy_init, seq_parallel_mode, overlap)
+def check_gpt2_qkv_fused_linear_1d(lazy_init: bool, seq_parallel_mode: bool):
+ check_linear_conv_1d_col(lazy_init, seq_parallel_mode)
check_linear_conv_1d_row(lazy_init, seq_parallel_mode)
diff --git a/tests/test_shardformer/test_layer/test_qkv_fused_linear_1d.py b/tests/test_shardformer/test_layer/test_qkv_fused_linear_1d.py
index dc14fd591..fccba564f 100644
--- a/tests/test_shardformer/test_layer/test_qkv_fused_linear_1d.py
+++ b/tests/test_shardformer/test_layer/test_qkv_fused_linear_1d.py
@@ -2,13 +2,12 @@ import os
from contextlib import nullcontext
import torch
-import torch.distributed as dist
import torch.nn as nn
from torch.testing import assert_close
import colossalai
from colossalai.lazy import LazyInitContext
-from colossalai.shardformer.layer import GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row
+from colossalai.shardformer.layer import FusedLinear1D_Col, FusedLinear1D_Row
from colossalai.shardformer.layer.qkv_fused_linear import split_fused_qkv_in_gpt2_style
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
@@ -16,93 +15,55 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"
-class Conv1D(nn.Module):
- """
- 1D-convolutional layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2).
-
- Basically works like a linear layer but the weights are transposed.
-
- Args:
- nf (`int`): The number of output features.
- nx (`int`): The number of input features.
- """
-
- def __init__(self, nf, nx):
- super().__init__()
- self.nf = nf
- self.weight = nn.Parameter(torch.empty(nx, nf))
- self.bias = nn.Parameter(torch.zeros(nf))
- nn.init.normal_(self.weight, std=0.02)
-
- def forward(self, x):
- size_out = x.size()[:-1] + (self.nf,)
- x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight)
- x = x.view(size_out)
- return x
-
-
-def rearrange(tensor: torch.Tensor, dim: int):
- tensor = tensor.clone()
- world_size = 2
- order = torch.arange(world_size * 3)
- new_order = []
- for i in range(world_size):
- new_order.append(order[i::world_size])
- new_order = torch.cat(new_order)
-
- tensor_chunks = torch.chunk(tensor, world_size * 3, dim=dim)
- rearanged_tensor_chunks = [tensor_chunks[i] for i in new_order]
- rearanged_tensor = torch.cat(rearanged_tensor_chunks, dim=dim)
- return rearanged_tensor
-
-
@parameterize("lazy_init", [False, True])
-def check_linear_conv_1d_col(lazy_init: bool):
+def check_linear_1d_col(lazy_init: bool):
ctx = LazyInitContext() if lazy_init else nullcontext()
- linear = Conv1D(192, 48).cuda()
+ linear = nn.Linear(8, 80).cuda()
with ctx:
- linear_copy = Conv1D(192, 48).cuda()
- linear_conv_col = GPT2FusedLinearConv1D_Col.from_native_module(
- linear_copy, process_group=None, gather_output=True, n_fused=3
+ linear_copy = nn.Linear(8, 80).cuda()
+ linear_col = FusedLinear1D_Col.from_native_module(
+ linear_copy, process_group=None, gather_output=True, split_sizes=[32, 32, 16]
)
- assert linear.weight.shape == torch.Size([48, 192])
- assert linear.bias.shape == torch.Size([192])
- assert linear_conv_col.weight.shape == torch.Size([48, 96])
- assert linear_conv_col.bias.shape == torch.Size([96])
- assert linear_copy.weight is linear_conv_col.weight
- assert linear_copy.bias is linear_conv_col.bias
+ assert linear.weight.shape == torch.Size([80, 8])
+ assert linear.bias.shape == torch.Size([80])
+ assert linear_col.weight.shape == torch.Size([40, 8])
+ assert linear_col.bias.shape == torch.Size([40])
+ assert linear_copy.weight is linear_col.weight
+ assert linear_copy.bias is linear_col.bias
# ensure weights are reversibly loadable
- linear_conv_col.load_state_dict(linear.state_dict())
- linear.load_state_dict(linear_conv_col.state_dict())
+ linear_col.load_state_dict(linear.state_dict())
+ linear.load_state_dict(linear_col.state_dict())
# check computation correctness
- x = torch.rand(4, 48).cuda()
+ x = torch.rand(4, 8).cuda()
out = linear(x)
- gather_out = linear_conv_col(x)
- assert_close(rearrange(out, 1), gather_out)
+ gather_out = linear_col(x)
+ assert_close(out, gather_out)
# check backward correctness
out.sum().backward()
gather_out.sum().backward()
- target_grad = split_fused_qkv_in_gpt2_style(linear.weight.grad, 3, None, True)
- assert_close(target_grad, linear_conv_col.weight.grad)
+ target_grad = split_fused_qkv_in_gpt2_style(linear.weight.grad, [32, 32, 16], None, False)
+ assert_close(target_grad, linear_col.weight.grad)
@parameterize("lazy_init", [False, True])
-def check_linear_conv_1d_row(lazy_init: bool):
+def check_linear_1d_row(lazy_init: bool):
ctx = LazyInitContext() if lazy_init else nullcontext()
- linear = Conv1D(192, 48).cuda()
+ linear = nn.Linear(80, 8).cuda()
with ctx:
- linear_copy = Conv1D(192, 48).cuda()
- linear_row = GPT2FusedLinearConv1D_Row.from_native_module(linear_copy, process_group=None, parallel_input=False)
+ linear_copy = nn.Linear(80, 8).cuda()
+ linear_row = FusedLinear1D_Row.from_native_module(
+ linear_copy, process_group=None, split_sizes=[32, 32, 16], parallel_input=False
+ )
- assert linear.weight.shape == torch.Size([48, 192])
- assert linear_row.weight.shape == torch.Size([24, 192])
- assert linear_row.bias.shape == torch.Size([192])
+ assert linear.weight.shape == torch.Size([8, 80])
+ assert linear_row.weight.shape == torch.Size([8, 40])
+ assert linear_row.bias.shape == torch.Size([8])
assert linear_copy.weight is linear_row.weight
assert linear_copy.bias is linear_row.bias
@@ -111,7 +72,7 @@ def check_linear_conv_1d_row(lazy_init: bool):
linear.load_state_dict(linear_row.state_dict())
# check computation correctness
- x = torch.rand(4, 48).cuda()
+ x = torch.rand(4, 80).cuda()
out = linear(x)
gather_out = linear_row(x)
assert_close(out, gather_out)
@@ -120,17 +81,51 @@ def check_linear_conv_1d_row(lazy_init: bool):
out.sum().backward()
gather_out.sum().backward()
- rank = dist.get_rank()
- target_grad = torch.chunk(linear.weight.grad, 2, dim=0)[rank]
+ target_grad = split_fused_qkv_in_gpt2_style(linear.weight.grad, [32, 32, 16], None, True)
assert_close(target_grad, linear_row.weight.grad)
+@parameterize("lazy_init", [False, True])
+def check_linear_1d_col_row(lazy_init: bool):
+ ctx = LazyInitContext() if lazy_init else nullcontext()
+
+ linear1 = nn.Linear(8, 80).cuda()
+ linear2 = nn.Linear(80, 8).cuda()
+ with ctx:
+ linear1_copy = nn.Linear(8, 80).cuda()
+ linear2_copy = nn.Linear(80, 8).cuda()
+ linear_col = FusedLinear1D_Col.from_native_module(linear1_copy, process_group=None, split_sizes=[32, 32, 16])
+ linear_row = FusedLinear1D_Row.from_native_module(
+ linear2_copy,
+ process_group=None,
+ split_sizes=[32, 32, 16],
+ )
+ # ensure weights are reversibly loadable
+ linear_col.load_state_dict(linear1.state_dict())
+ linear_row.load_state_dict(linear2.state_dict())
+
+ # check computation correctness
+ x = torch.rand(4, 8).cuda()
+ target_out = linear2(linear1(x))
+ out = linear_row(linear_col(x))
+ assert_close(out, target_out)
+
+ # check backward correctness
+ target_out.sum().backward()
+ out.sum().backward()
+
+ target_grad1 = split_fused_qkv_in_gpt2_style(linear1.weight.grad, [32, 32, 16], None, False)
+ assert_close(target_grad1, linear_col.weight.grad)
+ target_grad2 = split_fused_qkv_in_gpt2_style(linear2.weight.grad, [32, 32, 16], None, True)
+ assert_close(target_grad2, linear_row.weight.grad)
+
+
def run_dist(rank, world_size, port):
colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
- # test for linear conv
- check_linear_conv_1d_col()
- check_linear_conv_1d_row()
+ check_linear_1d_col()
+ check_linear_1d_row()
+ check_linear_1d_col_row()
@rerun_if_address_is_in_use()
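The rewritten test pairs `FusedLinear1D_Col` with `FusedLinear1D_Row` using uneven `split_sizes=[32, 32, 16]`. The single-process sketch below (no `torch.distributed`; names and helper are illustrative only) shows why the column-split/row-split pairing reproduces the dense two-layer result once the rank-local partial outputs are summed.

```python
# Single-process illustration: shard linear1's output dim and linear2's input
# dim with the same [32, 32, 16] block pattern, then sum the per-rank partials
# in place of the all-reduce that the real row-parallel layer performs.
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
world_size = 2
split_sizes = [32, 32, 16]
linear1, linear2 = nn.Linear(8, 80), nn.Linear(80, 8)
x = torch.rand(4, 8)


def shard(param: torch.Tensor, dim: int, rank: int) -> torch.Tensor:
    # Split each fused block along `dim`, keep this rank's slice of each block.
    blocks = torch.split(param, split_sizes, dim=dim)
    return torch.cat([torch.chunk(b, world_size, dim=dim)[rank] for b in blocks], dim=dim)


partials = []
for rank in range(world_size):
    w1, b1 = shard(linear1.weight, 0, rank), shard(linear1.bias, 0, rank)  # column parallel: shard out dim
    w2 = shard(linear2.weight, 1, rank)                                    # row parallel: shard in dim
    h = F.linear(x, w1, b1)                                                # rank-local activation, shape [4, 40]
    partials.append(F.linear(h, w2))                                       # partial output, awaits reduction
out = sum(partials) + linear2.bias                                         # stands in for the all-reduce
torch.testing.assert_close(out, linear2(linear1(x)))
```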
diff --git a/tests/test_shardformer/test_layer/test_ring_attn.py b/tests/test_shardformer/test_layer/test_ring_attn.py
index 1c7647a7d..6ebd8da73 100644
--- a/tests/test_shardformer/test_layer/test_ring_attn.py
+++ b/tests/test_shardformer/test_layer/test_ring_attn.py
@@ -5,6 +5,7 @@ from flash_attn import flash_attn_qkvpacked_func, flash_attn_varlen_qkvpacked_fu
from torch.testing import assert_close
import colossalai
+from colossalai.cluster import ProcessGroupMesh
from colossalai.shardformer.layer import AttnMaskType
from colossalai.shardformer.layer.attn import AttnMaskType, RingAttention
from colossalai.shardformer.layer.utils import split_batch_zigzag, split_varlen_zigzag
@@ -17,11 +18,14 @@ from colossalai.utils import get_current_device
@parameterize("nheads", [5])
@parameterize("d", [128])
@parameterize("dtype", [torch.bfloat16, torch.float16])
-def check_ring_attn(seq_len, bs, nheads, d, dtype):
+def check_ring_attn(seq_len, bs, nheads, d, dtype, inner_ring_size):
torch.cuda.manual_seed(2)
device = get_current_device()
sp_group = dist.group.WORLD
+ dp_size, pp_size, tp_size = 1, 1, 1
sp_size = dist.get_world_size()
+ sp_axis = 2
+ pg_mesh = ProcessGroupMesh(dp_size, pp_size, sp_size, tp_size)
# Some outliers may seem large, but our errors are still lower than
# Megatron-LM context parallel's

# (https://github.com/NVIDIA/TransformerEngine/blob/33a3d02f81c56e6f7b542c09bfa86657078d57fb/tests/pytorch/fused_attn/run_fused_attn_with_cp.py#L215)
@@ -40,11 +44,11 @@ def check_ring_attn(seq_len, bs, nheads, d, dtype):
q,
k,
v,
- sp_group,
+ sp_axis,
AttnMaskType.CAUSAL,
return_softmax=True,
- inner_ring_size=max(2, sp_size // 2),
- # inner_ring_size=4
+ inner_ring_size=inner_ring_size,
+ pg_mesh=pg_mesh,
)
ring_out = ring_out.transpose(1, 2)
out, lse, _ = flash_attn_qkvpacked_func(
@@ -83,6 +87,7 @@ def check_packed_seq(seqlen, bs, nheads, d, dtype):
device = get_current_device()
sp_group = dist.group.WORLD
sp_size = dist.get_world_size()
+ sp_axis = 2
atol = rtol = 7e-3
torch.cuda.manual_seed(2)
# Prepare varlen attention mask
@@ -123,10 +128,11 @@ def check_packed_seq(seqlen, bs, nheads, d, dtype):
q_ring,
k_ring,
v_ring,
- sp_group,
+ sp_axis,
**mask_info,
pad_output=False,
return_softmax=True,
+ pg_mesh=ProcessGroupMesh(1, 1, sp_size, 1),
# deterministic=True
)
ring_out = ring_out.transpose(1, 2).reshape(-1, nheads, d)
@@ -161,12 +167,12 @@ def check_packed_seq(seqlen, bs, nheads, d, dtype):
def launch_single_ring(rank, world_size, port):
colossalai.launch(rank, world_size, "localhost", port)
check_packed_seq()
- check_ring_attn()
+ check_ring_attn(inner_ring_size=None)
def launch_double_ring(rank, world_size, port):
colossalai.launch(rank, world_size, "localhost", port)
- check_ring_attn()
+ check_ring_attn(inner_ring_size=2)
@rerun_if_address_is_in_use()
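The ring-attention test now builds a `ProcessGroupMesh(dp_size, pp_size, sp_size, tp_size)` and addresses the sequence-parallel group by `sp_axis=2` (the third axis of that mesh) instead of passing `sp_group` directly. The standalone sketch below illustrates only the zigzag sequence split behind the test's `split_batch_zigzag` import; it is a re-implementation for intuition, not the library code.

```python
# Illustrative zigzag split for causal load balancing in ring attention:
# each sequence-parallel rank gets one chunk from the front of the sequence
# and its mirror from the back, so every rank sees a similar causal workload.
import torch


def zigzag_split(batch: torch.Tensor, sp_size: int, sp_rank: int, seq_dim: int = 1) -> torch.Tensor:
    chunks = torch.chunk(batch, 2 * sp_size, dim=seq_dim)
    return torch.cat([chunks[sp_rank], chunks[2 * sp_size - 1 - sp_rank]], dim=seq_dim)


x = torch.arange(16).view(1, 16)               # toy batch: bs=1, seq_len=16
print(zigzag_split(x, sp_size=4, sp_rank=0))   # -> tensor([[ 0,  1, 14, 15]])
```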
diff --git a/version.txt b/version.txt
index 6f2743d65..0bfccb080 100644
--- a/version.txt
+++ b/version.txt
@@ -1 +1 @@
-0.4.4
+0.4.5