[format] applied code formatting on changed files in pull request 4820 (#4886)

Co-authored-by: github-actions <github-actions@github.com>

Author: github-actions[bot]
Date: 2023-10-18 11:46:37 +08:00
Committed by: GitHub
Parent: c7aa319ba0
Commit: 486d06a2d5
13 changed files with 297 additions and 258 deletions
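The diff below is purely mechanical restyling: single quotes become double quotes, empty parentheses are dropped from class definitions, docstring quote style is unified, and long signatures are rewrapped. A minimal before/after sketch of these rules, assuming a black-style formatter (the commit message does not name the tool):

    # Before: single quotes, empty parentheses on the class definition,
    # and a hanging-indent signature wrap.
    class MicroBatchDescription():
        def __init__(self, inputs_dict, output_dict,
                     new_length) -> None:
            self.input_ids = inputs_dict['input_ids']

    # After: double quotes, bare class name, and the signature wrapped
    # with a single indented argument line between the parentheses.
    class MicroBatchDescription:
        def __init__(
            self, inputs_dict, output_dict, new_length
        ) -> None:
            self.input_ids = inputs_dict["input_ids"]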


@@ -3,7 +3,7 @@ from typing import Dict, Tuple

 import torch

-__all__ = 'MicroBatchManager'
+__all__ = "MicroBatchManager"


 class Status(Enum):
@@ -13,7 +13,7 @@ class Status(Enum):
     COOLDOWN = 4


-class MicroBatchDescription():
+class MicroBatchDescription:
     """
     This is the class to record the information of each microbatch, and also do some update operations.
     This class is the base class of `HeadMicroBatchDescription` and `BodyMicroBatchDescription`; for more
@@ -30,14 +30,14 @@ class MicroBatchDescription():
         output_dict: Dict[str, torch.Tensor],
         new_length: int,
     ) -> None:
-        assert output_dict.get('hidden_states') is not None
-        self.mb_length = output_dict['hidden_states'].shape[-2]
+        assert output_dict.get("hidden_states") is not None
+        self.mb_length = output_dict["hidden_states"].shape[-2]
         self.target_length = self.mb_length + new_length
         self.kv_cache = ()

     def update(self, output_dict: Dict[str, torch.Tensor] = None, new_token: torch.Tensor = None):
         if output_dict is not None:
-            self._update_kvcache(output_dict['past_key_values'])
+            self._update_kvcache(output_dict["past_key_values"])

     def _update_kvcache(self, kv_cache: Tuple):
         assert type(kv_cache) == tuple
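The `_update_kvcache` assertion above, combined with `cur_length` later in the file (which reads `self.kv_cache[0][0].shape[-2] + 1`), implies the usual HuggingFace-style layout for `past_key_values`: a tuple with one entry per layer, each entry a (key, value) pair of tensors whose second-to-last dimension is the cached sequence length. A minimal sketch of that assumed layout; the exact tensor shapes are illustrative, not taken from the commit:

    import torch

    # Assumed layout: one (key, value) pair per transformer layer.
    num_layers, batch, heads, seq_len, head_dim = 2, 1, 4, 10, 8
    past_key_values = tuple(
        (
            torch.zeros(batch, heads, seq_len, head_dim),  # key cache
            torch.zeros(batch, heads, seq_len, head_dim),  # value cache
        )
        for _ in range(num_layers)
    )

    assert type(past_key_values) == tuple  # the same check _update_kvcache performs
    # cur_length derives the next sequence length from the first layer's key tensor:
    print(past_key_values[0][0].shape[-2] + 1)  # prints 11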
@@ -64,7 +64,6 @@ class MicroBatchDescription():
         Return the current sequence length of the micro batch
         """
         pass
-


 class HeadMicroBatchDescription(MicroBatchDescription):
@@ -80,13 +79,14 @@ class HeadMicroBatchDescription(MicroBatchDescription):
     """

-    def __init__(self, inputs_dict: Dict[str, torch.Tensor], output_dict: Dict[str, torch.Tensor],
-                 new_length: int) -> None:
+    def __init__(
+        self, inputs_dict: Dict[str, torch.Tensor], output_dict: Dict[str, torch.Tensor], new_length: int
+    ) -> None:
         super().__init__(inputs_dict, output_dict, new_length)
         assert inputs_dict is not None
-        assert inputs_dict.get('input_ids') is not None and inputs_dict.get('attention_mask') is not None
-        self.input_ids = inputs_dict['input_ids']
-        self.attn_mask = inputs_dict['attention_mask']
+        assert inputs_dict.get("input_ids") is not None and inputs_dict.get("attention_mask") is not None
+        self.input_ids = inputs_dict["input_ids"]
+        self.attn_mask = inputs_dict["attention_mask"]
         self.new_tokens = None

     def update(self, output_dict: Dict[str, torch.Tensor] = None, new_token: torch.Tensor = None):
@@ -104,7 +104,8 @@ class HeadMicroBatchDescription(MicroBatchDescription):
def _update_attnmask(self):
self.attn_mask = torch.cat(
(self.attn_mask, torch.ones((self.attn_mask.shape[0], 1), dtype=torch.int64, device='cuda')), dim=-1)
(self.attn_mask, torch.ones((self.attn_mask.shape[0], 1), dtype=torch.int64, device="cuda")), dim=-1
)
@property
def cur_length(self):
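`_update_attnmask` grows the attention mask by one column per decoding step, so each newly generated token is visible to subsequent steps. A toy illustration of the same concatenation, run on CPU instead of the hard-coded `device="cuda"`:

    import torch

    # Start with a mask for a batch of 2 sequences of length 5.
    attn_mask = torch.ones((2, 5), dtype=torch.int64)
    # Append one column of ones, exactly as _update_attnmask does each step.
    attn_mask = torch.cat(
        (attn_mask, torch.ones((attn_mask.shape[0], 1), dtype=torch.int64)), dim=-1
    )
    print(attn_mask.shape)  # torch.Size([2, 6]): one column longer after the step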
@@ -127,8 +128,9 @@ class BodyMicroBatchDescription(MicroBatchDescription):
         output_dict (Dict[str, torch.Tensor]): the outputs of the previous stage. The keys should include `hidden_states` and `past_key_values`.
     """

-    def __init__(self, inputs_dict: Dict[str, torch.Tensor], output_dict: Dict[str, torch.Tensor],
-                 new_length: int) -> None:
+    def __init__(
+        self, inputs_dict: Dict[str, torch.Tensor], output_dict: Dict[str, torch.Tensor], new_length: int
+    ) -> None:
         super().__init__(inputs_dict, output_dict, new_length)

     def update(self, output_dict: Dict[str, torch.Tensor] = None, new_token: torch.Tensor = None):
@@ -146,8 +148,8 @@ class BodyMicroBatchDescription(MicroBatchDescription):
             return self.kv_cache[0][0].shape[-2] + 1


-class MicroBatchManager():
-    '''
+class MicroBatchManager:
+    """
     MicroBatchManager is a class that manages the micro batch.

     Args:
@@ -156,7 +158,7 @@ class MicroBatchManager():
         micro_batch_size (int): the micro batch size.
         micro_batch_buffer_size (int): the buffer size for micro batches. Normally, it should be the same as the number of pipeline stages.
-    '''
+    """

     def __init__(self, stage: int, new_length: int, micro_batch_size: int, micro_batch_buffer_size: int):
         self.stage = stage
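Based only on the constructor signature and the Args section above, a hypothetical instantiation for a 4-stage pipeline might look like the following; the values are illustrative and do not come from the commit:

    # stage: index of this pipeline stage
    # new_length: number of new tokens to generate per micro batch
    # micro_batch_buffer_size: typically equal to the number of pipeline stages
    mb_manager = MicroBatchManager(
        stage=0,
        new_length=32,
        micro_batch_size=4,
        micro_batch_buffer_size=4,
    )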