mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-08 20:40:34 +00:00
[inference]fix import bug and delete down useless init (#4830)
* fix import bug and release useless init * fix * fix * fix
This commit is contained in:
@@ -20,30 +20,6 @@ MAX_OUTPUT_LEN = 100
|
||||
CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.5")
|
||||
|
||||
|
||||
def init_to_get_rotary(self, base=10000):
|
||||
self.config.head_dim_ = self.config.hidden_size // self.config.num_attention_heads
|
||||
if not hasattr(self.config, "rope_scaling"):
|
||||
rope_scaling_factor = 1.0
|
||||
else:
|
||||
rope_scaling_factor = self.config.rope_scaling.factor if self.config.rope_scaling is not None else 1.0
|
||||
if hasattr(self.config, "max_sequence_length"):
|
||||
max_seq_len = self.config.max_sequence_length
|
||||
elif hasattr(self.config, "max_position_embeddings"):
|
||||
max_seq_len = self.config.max_position_embeddings * rope_scaling_factor
|
||||
else:
|
||||
max_seq_len = 2048 * rope_scaling_factor
|
||||
base = float(base)
|
||||
inv_freq = 1.0 / (
|
||||
base ** (torch.arange(0, self.config.head_dim_, 2, device="cpu", dtype=torch.float32) / self.config.head_dim_)
|
||||
)
|
||||
t = torch.arange(max_seq_len + 1024 * 64, device="cpu", dtype=torch.float32) / rope_scaling_factor
|
||||
freqs = torch.outer(t, inv_freq)
|
||||
|
||||
self._cos_cached = torch.cos(freqs).to(torch.float16).cuda()
|
||||
self._sin_cached = torch.sin(freqs).to(torch.float16).cuda()
|
||||
return
|
||||
|
||||
|
||||
@parameterize(
|
||||
"test_config",
|
||||
[
|
||||
@@ -56,7 +32,6 @@ def run_llama_test(test_config):
|
||||
sub_model_zoo = model_zoo.get_sub_registry("transformers_llama_for_casual_lm")
|
||||
for name, (model_fn, data_gen_fn, _, _, _) in sub_model_zoo.items():
|
||||
orig_model = model_fn()
|
||||
init_to_get_rotary(orig_model.model, base=10000)
|
||||
orig_model = orig_model.half()
|
||||
data = data_gen_fn()
|
||||
|
||||
|
Reference in New Issue
Block a user