Mirror of https://github.com/hpcaitech/ColossalAI.git
[format] applied code formatting on changed files in pull request 4820 (#4886)
Co-authored-by: github-actions <github-actions@github.com>
Commit 486d06a2d5 (parent c7aa319ba0), committed by GitHub.
@@ -3,7 +3,7 @@ from typing import Dict, Tuple
 
 import torch
 
-__all__ = 'MicroBatchManager'
+__all__ = "MicroBatchManager"
 
 
 class Status(Enum):
@@ -13,7 +13,7 @@ class Status(Enum):
     COOLDOWN = 4
 
 
-class MicroBatchDescription():
+class MicroBatchDescription:
     """
     This is the class to record the information of each microbatch, and also do some update operation.
     This class is the base class of `HeadMicroBatchDescription` and `BodyMicroBatchDescription`, for more
@@ -30,14 +30,14 @@ class MicroBatchDescription():
         output_dict: Dict[str, torch.Tensor],
         new_length: int,
     ) -> None:
-        assert output_dict.get('hidden_states') is not None
-        self.mb_length = output_dict['hidden_states'].shape[-2]
+        assert output_dict.get("hidden_states") is not None
+        self.mb_length = output_dict["hidden_states"].shape[-2]
         self.target_length = self.mb_length + new_length
         self.kv_cache = ()
 
     def update(self, output_dict: Dict[str, torch.Tensor] = None, new_token: torch.Tensor = None):
         if output_dict is not None:
-            self._update_kvcache(output_dict['past_key_values'])
+            self._update_kvcache(output_dict["past_key_values"])
 
     def _update_kvcache(self, kv_cache: Tuple):
         assert type(kv_cache) == tuple
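To make the bookkeeping in this hunk concrete: `mb_length` is read from the second-to-last dimension of `hidden_states` (the sequence length), `target_length` marks where generation for this micro batch stops, and the KV cache starts out as an empty tuple. A minimal sketch with a dummy tensor (shapes are illustrative only, not taken from this commit):

import torch

# Dummy output from one pipeline stage: batch=2, seq_len=16, hidden=128.
output_dict = {"hidden_states": torch.randn(2, 16, 128)}

mb_length = output_dict["hidden_states"].shape[-2]  # 16: current sequence length
new_length = 8                                      # tokens still to be generated
target_length = mb_length + new_length              # 24: stop decoding once reached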
@@ -64,7 +64,6 @@ class MicroBatchDescription():
         Return the current sequence length of micro batch
-
         """
         pass
 
 
 class HeadMicroBatchDescription(MicroBatchDescription):
@@ -80,13 +79,14 @@ class HeadMicroBatchDescription(MicroBatchDescription):
 
     """
 
-    def __init__(self, inputs_dict: Dict[str, torch.Tensor], output_dict: Dict[str, torch.Tensor],
-                 new_length: int) -> None:
+    def __init__(
+        self, inputs_dict: Dict[str, torch.Tensor], output_dict: Dict[str, torch.Tensor], new_length: int
+    ) -> None:
         super().__init__(inputs_dict, output_dict, new_length)
         assert inputs_dict is not None
-        assert inputs_dict.get('input_ids') is not None and inputs_dict.get('attention_mask') is not None
-        self.input_ids = inputs_dict['input_ids']
-        self.attn_mask = inputs_dict['attention_mask']
+        assert inputs_dict.get("input_ids") is not None and inputs_dict.get("attention_mask") is not None
+        self.input_ids = inputs_dict["input_ids"]
+        self.attn_mask = inputs_dict["attention_mask"]
         self.new_tokens = None
 
     def update(self, output_dict: Dict[str, torch.Tensor] = None, new_token: torch.Tensor = None):
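The head-stage description asserts that `inputs_dict` carries both `input_ids` and `attention_mask`, which is exactly what a Hugging Face tokenizer returns. A short sketch of producing such a dict (the tokenizer choice is an assumption for illustration, not something this commit uses):

from transformers import AutoTokenizer

# Any HF tokenizer works for illustration; "gpt2" is an arbitrary choice here.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
inputs_dict = tokenizer(["an example prompt"], return_tensors="pt")

# The keys the head stage checks for before caching them.
assert inputs_dict.get("input_ids") is not None and inputs_dict.get("attention_mask") is not None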
@@ -104,7 +104,8 @@ class HeadMicroBatchDescription(MicroBatchDescription):
 
     def _update_attnmask(self):
         self.attn_mask = torch.cat(
-            (self.attn_mask, torch.ones((self.attn_mask.shape[0], 1), dtype=torch.int64, device='cuda')), dim=-1)
+            (self.attn_mask, torch.ones((self.attn_mask.shape[0], 1), dtype=torch.int64, device="cuda")), dim=-1
+        )
 
     @property
     def cur_length(self):
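The reformatted `_update_attnmask` call above is the standard incremental-decoding step: after each generated token, one column of ones is appended along the sequence dimension of the attention mask. A CPU-only sketch of the same operation (the real code places the new column on "cuda"):

import torch

attn_mask = torch.ones((2, 16), dtype=torch.int64)  # batch=2, current length=16

# One decoding step: the freshly generated token is always visible,
# so the mask grows by a single column of ones.
attn_mask = torch.cat(
    (attn_mask, torch.ones((attn_mask.shape[0], 1), dtype=torch.int64)), dim=-1
)
assert attn_mask.shape == (2, 17)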
@@ -127,8 +128,9 @@ class BodyMicroBatchDescription(MicroBatchDescription):
         output_dict (Dict[str, torch.Tensor]): the outputs of the previous stage. The key should have `hidden_states` and `past_key_values`.
     """
 
-    def __init__(self, inputs_dict: Dict[str, torch.Tensor], output_dict: Dict[str, torch.Tensor],
-                 new_length: int) -> None:
+    def __init__(
+        self, inputs_dict: Dict[str, torch.Tensor], output_dict: Dict[str, torch.Tensor], new_length: int
+    ) -> None:
         super().__init__(inputs_dict, output_dict, new_length)
 
     def update(self, output_dict: Dict[str, torch.Tensor] = None, new_token: torch.Tensor = None):
@@ -146,8 +148,8 @@ class BodyMicroBatchDescription(MicroBatchDescription):
         return self.kv_cache[0][0].shape[-2] + 1
 
 
-class MicroBatchManager():
-    '''
+class MicroBatchManager:
+    """
     MicroBatchManager is a class that manages the micro batch.
 
     Args:
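The unchanged context line `return self.kv_cache[0][0].shape[-2] + 1` reads the current length from the cache itself: `kv_cache[0][0]` is the key tensor of the first layer, its second-to-last dimension counts the cached positions, and the `+ 1` accounts for the token currently being decoded. A sketch with dummy tensors, assuming the usual (batch, heads, seq_len, head_dim) layout:

import torch

# Two layers of (key, value) pairs, each (batch=2, heads=4, seq_len=16, head_dim=32).
kv_cache = tuple(
    (torch.randn(2, 4, 16, 32), torch.randn(2, 4, 16, 32)) for _ in range(2)
)

cur_length = kv_cache[0][0].shape[-2] + 1  # 16 cached positions + the in-flight token
assert cur_length == 17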
@@ -156,7 +158,7 @@ class MicroBatchManager():
         micro_batch_size (int): the micro batch size.
         micro_batch_buffer_size (int): the buffer size for micro batch. Normally, it should be the same as the number of pipeline stages.
 
-    '''
+    """
 
     def __init__(self, stage: int, new_length: int, micro_batch_size: int, micro_batch_buffer_size: int):
         self.stage = stage
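Finally, the constructor signature in the last hunk suggests how a manager would be instantiated per pipeline stage. The sketch below is purely illustrative: the import path and argument values are assumptions, not taken from this commit.

# Hypothetical usage; import path and values are assumptions for illustration.
from colossalai.inference.pipeline.microbatch_manager import MicroBatchManager

mb_manager = MicroBatchManager(
    stage=0,                    # index of this pipeline stage
    new_length=32,              # number of new tokens to generate per sequence
    micro_batch_size=4,         # sequences per micro batch
    micro_batch_buffer_size=2,  # usually equal to the number of pipeline stages
)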