mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-04-28 19:55:29 +00:00
* add infer_struct and infer_config * update codes * change InferConfig * Add hf_model_config to the engine * rm _get_hf_model_config * update codes * made adjustments according to the feedback from the reviewer. * update codes * add ci test for config and struct * Add the logic of the inference engine * update engine and test * Recover cache_manager.py * add logger * fix conflict * update codes * update codes * update model and tokenizer * fix add the logic about shardformer * change kvcache_manager docstring * add policy * fix ci bug in test_kvcache_manager.py * remove codes related to tokenizer and move model_policy * fix code style * add ordered_set to requirements-infer.txt * Delete extra empty lines * add ordered_set to requirements-test.txt
41 lines
1.2 KiB
Python
Executable File
41 lines
1.2 KiB
Python
Executable File
import copy
|
|
|
|
from colossalai.shardformer import ShardConfig, ShardFormer
|
|
|
|
|
|
def build_model(
    model_fn,
    enable_fused_normalization=False,
    enable_tensor_parallelism=False,
    enable_flash_attention=False,
    enable_jit_fused=False,
):
    """Build a model and a ShardFormer-optimized copy of it, both on CUDA.

    Args:
        model_fn: Zero-argument callable that constructs a fresh model.
        enable_fused_normalization: Forwarded to ``ShardConfig``.
        enable_tensor_parallelism: Forwarded to ``ShardConfig``.
        enable_flash_attention: Forwarded to ``ShardConfig``.
        enable_jit_fused: Forwarded to ``ShardConfig``.

    Returns:
        Tuple ``(org_model, sharded_model)``; each has been moved with ``.cuda()``.
    """
    # create new model
    org_model = model_fn()

    # shard model
    shard_config = ShardConfig(
        enable_fused_normalization=enable_fused_normalization,
        enable_tensor_parallelism=enable_tensor_parallelism,
        enable_flash_attention=enable_flash_attention,
        enable_jit_fused=enable_jit_fused,
    )
    # Shard a deep copy so the sharding pass cannot alter ``org_model``.
    model_copy = copy.deepcopy(org_model)
    shard_former = ShardFormer(shard_config=shard_config)
    # optimize() returns a pair; the second element (shared parameters) is not
    # needed by callers of this helper.
    sharded_model, _ = shard_former.optimize(model_copy)

    return org_model.cuda(), sharded_model.cuda()
|
|
|
|
|
|
def run_infer(original_model, sharded_model, data_gen_fn, output_transform_fn):
    """Run the same generated inputs through the original and the sharded model.

    Args:
        original_model: Reference model, called with the generated inputs as
            keyword arguments.
        sharded_model: Sharded counterpart, called with the same keyword arguments.
        data_gen_fn: Zero-argument callable returning a dict of model inputs.
        output_transform_fn: Callable applied to each model's raw output.

    Returns:
        Tuple ``(org_output, shard_output)`` of the transformed outputs.
    """
    import torch  # local import keeps this helper self-contained

    # prepare input: move every value that supports ``.cuda()`` (e.g. tensors)
    # to the GPU; plain Python values such as boolean flags are passed through
    # unchanged instead of raising AttributeError.
    data = {
        k: v.cuda() if hasattr(v, "cuda") else v
        for k, v in data_gen_fn().items()
    }

    # run forward without autograd: this is inference only, so recording the
    # computation graph would just keep activations alive in memory.
    with torch.no_grad():
        org_output = original_model(**data)
        org_output = output_transform_fn(org_output)

        shard_output = sharded_model(**data)
        shard_output = output_transform_fn(shard_output)

    return org_output, shard_output
|