Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-04-28 19:55:29 +00:00)
fix CI bugs

parent 2a73e828eb
commit fab294c7f4
@@ -191,7 +191,14 @@ class InferenceEngine:
                 prompt = None
             else:
                 prompt = prompts[i]
-            block_table = torch.full([self.inference_config.max_seq_len], -1, device=self.device)
+            max_blocks_per_sequence = (
+                self.inference_config.max_input_len
+                + self.inference_config.max_output_len
+                + self.inference_config.block_size
+                - 1
+            ) // self.inference_config.block_size
+            block_table = torch.full([max_blocks_per_sequence], -1, device=self.device)
             sequence = Sequence(
                 request_id,
                 prompt,
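The removed max_seq_len sizing is replaced by a ceiling division of the per-sequence token budget over the block size. A minimal sketch of that arithmetic, using illustrative config values rather than anything taken from the commit:

import torch

# Illustrative values only (not the commit's config):
max_input_len = 256
max_output_len = 128
block_size = 16

# ceil((max_input_len + max_output_len) / block_size) done with integer math:
max_blocks_per_sequence = (max_input_len + max_output_len + block_size - 1) // block_size
assert max_blocks_per_sequence == 24   # 384 tokens / 16 tokens per block

# Every slot starts at -1, i.e. "no physical KV-cache block assigned yet".
block_table = torch.full([max_blocks_per_sequence], -1)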
@@ -7,7 +7,7 @@ from colossalai.inference.config import InferenceConfig
 from colossalai.inference.kv_cache import KVCacheManager
 from colossalai.inference.logit_processors import logit_processor
 from colossalai.inference.sampler import *
-from colossalai.inference.struct import BatchInfo, Sequence
+from colossalai.inference.struct import BatchInfo, RequestStatus, Sequence
 from colossalai.logging import get_dist_logger
 
 logger = get_dist_logger(__name__)
@@ -104,7 +104,7 @@ class RequestHandler:
                     f"the prompt(Request id = {seq.request_id}) length is longer than max_input_len, abort this sequence."
                 )
                 self.abort_sequence(seq.request_id)
-                remove_list.append(seq)
+                break
             # Try to allocate cache blocks for the sequence.
             if self.cache_manager.check_allocation(seq):
                 # If succeed, add the sequence to running list.
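The scan over waiting sequences now stops as soon as an over-long prompt has been aborted instead of collecting it into remove_list. A rough, self-contained sketch of that control flow with stand-in types (none of the names below are the handler's real attributes):

from dataclasses import dataclass, field

@dataclass
class FakeSeq:
    request_id: int
    tokens: list = field(default_factory=list)

def schedule_round(bucket, max_input_len):
    admitted = []
    for seq in bucket:
        if len(seq.tokens) > max_input_len:
            bucket.remove(seq)   # what aborting a WAITING sequence boils down to
            break                # the bucket was just mutated, so stop this scan
        admitted.append(seq)
    return admitted

bucket = [FakeSeq(0, [1] * 4), FakeSeq(1, [1] * 99), FakeSeq(2, [1] * 4)]
print([s.request_id for s in schedule_round(bucket, max_input_len=8)])  # [0]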
@@ -139,9 +139,10 @@ class RequestHandler:
         """
         Abort the request.
         """
-        seq, _ = self._find_sequence(request_id)
-        if seq.status.is_waiting:
+        seq, priority = self._find_sequence(request_id)
+        if seq.status == RequestStatus.WAITING:
             seq.mark_aborted()
+            self.waiting_list[priority].remove(seq)
         elif seq.status.is_running():
             self.cache_manager.free_block_table(seq.block_table)
             self.running_list.remove(seq)
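The abort path now distinguishes a WAITING sequence, which also has to leave its priority bucket, from a running one, which frees its cache blocks. A condensed sketch with a stand-in status enum and plain containers, not the project's API:

from enum import Enum, auto

class Status(Enum):          # stand-in for a RequestStatus-style enum
    WAITING = auto()
    RUNNING = auto()
    ABORTED = auto()

def abort(seq, priority, waiting_list, running_list, free_block_table):
    if seq["status"] == Status.WAITING:
        seq["status"] = Status.ABORTED        # mark_aborted()
        waiting_list[priority].remove(seq)    # drop it from its priority bucket
    elif seq["status"] == Status.RUNNING:
        free_block_table(seq["block_table"])  # give the KV-cache blocks back
        running_list.remove(seq)

waiting = {1: [{"status": Status.WAITING, "block_table": None}]}
abort(waiting[1][0], 1, waiting, [], lambda bt: None)
print(waiting[1])  # [] -> the aborted sequence no longer sits in the queue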
@@ -217,6 +217,8 @@ class PagedAttention:
 
         attn_weights = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(head_size)
 
+        padding_mask = None
+
         if attn_mask is not None:
             padding_mask = AttentionMaskConverter._expand_mask(attn_mask, q.dtype, seq_len)
 
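Assigning padding_mask = None up front keeps a later "if padding_mask is not None" check safe even when no attention mask was passed. A tiny torch-only illustration of the pattern, independent of the PagedAttention code:

import torch

def masked_scores(q, k, attn_mask=None):
    head_size = q.size(-1)
    attn_weights = torch.matmul(q, k.transpose(2, 3)) / head_size ** 0.5

    padding_mask = None                     # defined on every path
    if attn_mask is not None:
        # expand a [bsz, seq_len] 0/1 mask to additive form
        padding_mask = (1.0 - attn_mask[:, None, None, :]) * torch.finfo(q.dtype).min

    if padding_mask is not None:            # safe even when attn_mask was None
        attn_weights = attn_weights + padding_mask
    return attn_weights

q = torch.randn(1, 2, 4, 8)
k = torch.randn(1, 2, 4, 8)
print(masked_scores(q, k).shape)            # torch.Size([1, 2, 4, 4])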
@@ -279,11 +281,12 @@ class PagedAttention:
         if attn_weights.size() != (bsz, num_heads, 1, seq_len):
             raise ValueError(f"Got wrong attn_weights, should be in shape {(bsz,num_heads,1,seq_len)}.")
 
+        padding_mask = None
         if attn_mask is not None:
-            padding_mask = AttentionMaskConverter._expand_mask(attn_mask, q.dtype, query_length)
+            padding_mask = AttentionMaskConverter._expand_mask(attn_mask, q.dtype, q_length)
 
         attn_mask = AttentionMaskConverter._make_causal_mask(
-            (bsz, q_length), q.dtype, q.device, past_key_values_length=seq_len - query_length
+            (bsz, q_length), q.dtype, q.device, past_key_values_length=seq_len - q_length
         )
 
         if padding_mask is not None:
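In the decode path the causal mask has to be shifted by the tokens already held in the KV cache, which is what past_key_values_length=seq_len - q_length expresses. A small torch-only sketch of that offset (not the AttentionMaskConverter helper used above):

import torch

def decode_step_causal_mask(bsz, q_length, seq_len, dtype=torch.float32):
    # seq_len = cached tokens + new tokens; the new q_length queries may attend
    # to everything up to their own position, so with past_kv = seq_len - q_length
    # the mask over all seq_len keys is fully open for a 1-token decode step.
    past_kv = seq_len - q_length
    mask = torch.full((q_length, q_length), torch.finfo(dtype).min)
    mask = torch.triu(mask, diagonal=1)                       # causal part for new tokens
    mask = torch.cat([torch.zeros(q_length, past_kv), mask], dim=-1)
    return mask[None, None, :, :].expand(bsz, 1, q_length, past_kv + q_length)

print(decode_step_causal_mask(bsz=2, q_length=1, seq_len=5).shape)  # (2, 1, 1, 5)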
@@ -11,6 +11,7 @@ from colossalai.inference.config import InferenceConfig
 from colossalai.inference.core.engine import InferenceEngine
 from colossalai.testing import rerun_if_address_is_in_use, spawn
 
+
 def setup_seed(seed):
     torch.manual_seed(seed)
     torch.cuda.manual_seed_all(seed)
@@ -34,7 +35,7 @@ def check_inference_engine(test_cai=False):
         "介绍一下武汉,",
     ]
 
-    output_len = 128
+    output_len = 38
     do_sample = True
     top_p = 0.5
     top_k = 50
@@ -57,7 +57,7 @@ def check_request_handler():
         block_size=16,
         eos_token_id=0,
         sample_params=None,
-        block_table=torch.tensor([0, 0]),
+        block_table=torch.tensor([-1, -1]),
     )
     request_handler.add_sequence(seq1)
     # the priority should be 1
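Seeding the test's block_table with -1 follows the engine-side convention of -1 meaning "no physical cache block assigned yet". A quick illustration with a stand-in allocation, not the KVCacheManager API:

import torch

block_table = torch.full([2], -1)          # no blocks allocated yet
assert bool((block_table < 0).all())       # every slot is still unassigned

# a stand-in "allocation": write a real block id into the first free slot
block_table[0] = 7
print(block_table.tolist())                # [7, -1] -> one block mapped, one slot free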