Mirror of https://github.com/hpcaitech/ColossalAI.git
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
@@ -1,5 +1,6 @@
+from typing import Any, Deque, Hashable, List, Tuple
+
 import torch
-from typing import List, Deque, Tuple, Hashable, Any
 from energonai import BatchManager, SubmitEntry, TaskEntry
 
 
@@ -10,15 +11,15 @@ class BatchManagerForGeneration(BatchManager):
         self.pad_token_id = pad_token_id
 
     def _left_padding(self, batch_inputs):
-        max_len = max(len(inputs['input_ids']) for inputs in batch_inputs)
-        outputs = {'input_ids': [], 'attention_mask': []}
+        max_len = max(len(inputs["input_ids"]) for inputs in batch_inputs)
+        outputs = {"input_ids": [], "attention_mask": []}
         for inputs in batch_inputs:
-            input_ids, attention_mask = inputs['input_ids'], inputs['attention_mask']
+            input_ids, attention_mask = inputs["input_ids"], inputs["attention_mask"]
             padding_len = max_len - len(input_ids)
             input_ids = [self.pad_token_id] * padding_len + input_ids
             attention_mask = [0] * padding_len + attention_mask
-            outputs['input_ids'].append(input_ids)
-            outputs['attention_mask'].append(attention_mask)
+            outputs["input_ids"].append(input_ids)
+            outputs["attention_mask"].append(attention_mask)
         for k in outputs:
             outputs[k] = torch.tensor(outputs[k])
         return outputs, max_len
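Note that this hunk changes quote style only (the updated pre-commit formatter rewrites single quotes to double quotes); the padding logic is untouched. For readers skimming the diff, here is a minimal standalone sketch of what _left_padding computes, with a free function name and pad_token_id=0 assumed for illustration rather than taken from the constructor:

import torch

def left_padding(batch_inputs, pad_token_id=0):
    # Pad every sequence on the left to the batch's max length, mirroring
    # BatchManagerForGeneration._left_padding in this diff. Padded positions
    # get pad_token_id and an attention mask of 0.
    max_len = max(len(inputs["input_ids"]) for inputs in batch_inputs)
    outputs = {"input_ids": [], "attention_mask": []}
    for inputs in batch_inputs:
        padding_len = max_len - len(inputs["input_ids"])
        outputs["input_ids"].append([pad_token_id] * padding_len + inputs["input_ids"])
        outputs["attention_mask"].append([0] * padding_len + inputs["attention_mask"])
    return {k: torch.tensor(v) for k, v in outputs.items()}, max_len

batch = [
    {"input_ids": [5, 6, 7], "attention_mask": [1, 1, 1]},
    {"input_ids": [8], "attention_mask": [1]},
]
outputs, max_len = left_padding(batch)
print(outputs["input_ids"])       # tensor([[5, 6, 7], [0, 0, 8]])
print(outputs["attention_mask"])  # tensor([[1, 1, 1], [0, 0, 1]])
print(max_len)                    # 3

Left padding (rather than right padding) keeps all prompts ending at the same position, which is what lets generation append new tokens for the whole batch in lockstep.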
@@ -26,7 +27,7 @@ class BatchManagerForGeneration(BatchManager):
     @staticmethod
     def _make_batch_key(entry: SubmitEntry) -> tuple:
         data = entry.data
-        return (data['top_k'], data['top_p'], data['temperature'])
+        return (data["top_k"], data["top_p"], data["temperature"])
 
     def make_batch(self, q: Deque[SubmitEntry]) -> Tuple[TaskEntry, dict]:
         entry = q.popleft()
@@ -37,7 +38,7 @@ class BatchManagerForGeneration(BatchManager):
                 break
             if self._make_batch_key(entry) != self._make_batch_key(q[0]):
                 break
-            if q[0].data['max_tokens'] > entry.data['max_tokens']:
+            if q[0].data["max_tokens"] > entry.data["max_tokens"]:
                 break
             e = q.popleft()
             batch.append(e.data)
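Again, only the quoting changes here. As context for the loop being reformatted: make_batch keeps pulling entries off the queue while they share the head entry's sampling parameters (the _make_batch_key tuple) and do not request more new tokens than the head entry does. A rough standalone sketch of that grouping policy, where Entry is a hypothetical stand-in for energonai's SubmitEntry and the max_batch_size check from the real method is omitted:

from collections import deque

class Entry:
    # Hypothetical stand-in for SubmitEntry: only the fields used here.
    def __init__(self, uid, data):
        self.uid, self.data = uid, data

def batch_key(e):
    # Mirrors _make_batch_key: requests batch together only when their
    # sampling parameters match exactly.
    return (e.data["top_k"], e.data["top_p"], e.data["temperature"])

q = deque([
    Entry(0, {"top_k": 50, "top_p": 0.9, "temperature": 1.0, "max_tokens": 16}),
    Entry(1, {"top_k": 50, "top_p": 0.9, "temperature": 1.0, "max_tokens": 8}),
    Entry(2, {"top_k": 10, "top_p": 0.9, "temperature": 1.0, "max_tokens": 8}),
])
head = q.popleft()
batch = [head]
while q and batch_key(q[0]) == batch_key(head) and q[0].data["max_tokens"] <= head.data["max_tokens"]:
    batch.append(q.popleft())
print([e.uid for e in batch])  # [0, 1]; entry 2 has a different top_k

Capping followers at the head entry's max_tokens means the whole batch can stop generating when the head entry is done, so no request is held hostage by a longer one behind it.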
@@ -45,12 +46,12 @@ class BatchManagerForGeneration(BatchManager):
         inputs, max_len = self._left_padding(batch)
         trunc_lens = []
         for data in batch:
-            trunc_lens.append(max_len + data['max_tokens'])
-        inputs['top_k'] = entry.data['top_k']
-        inputs['top_p'] = entry.data['top_p']
-        inputs['temperature'] = entry.data['temperature']
-        inputs['max_tokens'] = max_len + entry.data['max_tokens']
-        return TaskEntry(tuple(uids), inputs), {'trunc_lens': trunc_lens}
+            trunc_lens.append(max_len + data["max_tokens"])
+        inputs["top_k"] = entry.data["top_k"]
+        inputs["top_p"] = entry.data["top_p"]
+        inputs["temperature"] = entry.data["temperature"]
+        inputs["max_tokens"] = max_len + entry.data["max_tokens"]
+        return TaskEntry(tuple(uids), inputs), {"trunc_lens": trunc_lens}
 
     def split_batch(self, task_entry: TaskEntry, trunc_lens: List[int] = []) -> List[Tuple[Hashable, Any]]:
         retval = []
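The trunc_lens bookkeeping in this last hunk deserves a note: because every prompt is left-padded to max_len, request i's sequence must be cut at max_len + its own max_tokens in split_batch, even though the batch as a whole runs until max_len plus the head entry's max_tokens. A minimal sketch of that arithmetic; the tensor shapes are illustrative assumptions, not energonai's actual output format:

import torch

# Why trunc_lens[i] = max_len + batch[i]["max_tokens"]: after left padding,
# every prompt ends at position max_len, so generated tokens for all requests
# start at the same index and each request keeps only its own budget.
max_len = 5
per_request_max_tokens = [8, 3]
trunc_lens = [max_len + m for m in per_request_max_tokens]  # [13, 8]

# The batch is generated out to the largest truncation length...
generated = torch.zeros(2, max_len + max(per_request_max_tokens))  # shape (2, 13)
# ...and split_batch-style truncation trims each row to its own length.
outputs = [generated[i, :trunc] for i, trunc in enumerate(trunc_lens)]
print([o.shape[0] for o in outputs])  # [13, 8]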