Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-01 09:07:51 +00:00
[format] applied code formatting on changed files in pull request 4926 (#5007)
Co-authored-by: github-actions <github-actions@github.com>
Commit: c36e782d80
Parent: 1a3315e336
Committed by: GitHub
@@ -218,7 +218,7 @@ class TPInferEngine:
        ), "Discrepancy between the tp size of TPInferEngine and the tp size of shard config"
        model_name = model.__class__.__name__
        assert model_name in self.supported_models, f"Unsupported model cls {model_name} for TP inference."

        model = model.model if self.shard_config.inference_gptq else model
        policy = get_autopolicy(model, shard_config=self.shard_config)
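For context, below is a minimal, self-contained sketch of the guard-and-unwrap pattern in the hunk above: check the model class name against a whitelist, then unwrap a GPTQ-quantized wrapper before resolving a sharding policy. SUPPORTED_MODELS, select_policy, and the stub class are illustrative stand-ins, not the engine's actual supported_models list or get_autopolicy.

class LlamaForCausalLM:
    # Minimal stub so the sketch runs without transformers installed.
    pass

# Illustrative whitelist; the real engine keeps its own supported_models list.
SUPPORTED_MODELS = {"LlamaForCausalLM", "BloomForCausalLM"}

def select_policy(model):
    # Stand-in for get_autopolicy: map the model class to a policy tag.
    return f"policy_for_{model.__class__.__name__}"

def prepare_model(model, inference_gptq: bool = False):
    model_name = model.__class__.__name__
    assert model_name in SUPPORTED_MODELS, f"Unsupported model cls {model_name} for TP inference."
    # A GPTQ-quantized checkpoint wraps the actual transformer one attribute deeper.
    model = model.model if inference_gptq else model
    return model, select_policy(model)

print(prepare_model(LlamaForCausalLM())[1])  # policy_for_LlamaForCausalLM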
@@ -311,7 +311,7 @@ class TPInferEngine:
            seq_start_indexes[i] = start_index
            start_index += curr_seq_len
            max_len_in_batch = curr_seq_len if curr_seq_len > max_len_in_batch else max_len_in_batch

        block_loc = torch.empty((batch_size, self.max_input_len + self.max_output_len), dtype=torch.long, device="cuda")
        batch_infer_state = BatchInferState(batch_size, max_len_in_batch)
        batch_infer_state.seq_len = seq_lengths.to("cuda")
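For context, below is a minimal, self-contained sketch of the batching bookkeeping in the hunk above: given per-sequence lengths, record each sequence's start offset in a packed buffer and track the longest sequence in the batch. The tensor values and the CPU device are illustrative; the engine itself allocates block_loc on CUDA and wraps the results in BatchInferState.

import torch

seq_lengths = torch.tensor([5, 3, 7, 2])  # illustrative per-sequence token counts
batch_size = seq_lengths.numel()

seq_start_indexes = torch.zeros(batch_size, dtype=torch.long)
start_index = 0
max_len_in_batch = 0
for i, curr_seq_len in enumerate(seq_lengths.tolist()):
    seq_start_indexes[i] = start_index  # offset of sequence i in the packed buffer
    start_index += curr_seq_len         # advance past sequence i
    max_len_in_batch = curr_seq_len if curr_seq_len > max_len_in_batch else max_len_in_batch

print(seq_start_indexes)  # tensor([ 0,  5,  8, 15])
print(max_len_in_batch)   # 7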