[inference]fix import bug and delete down useless init (#4830)

* fix import bug and release useless init

* fix

* fix

* fix
This commit is contained in:
Jianghai
2023-10-04 09:18:45 +08:00
committed by GitHub
parent 573f270537
commit 013a4bedf0
9 changed files with 121 additions and 154 deletions

View File

@@ -20,30 +20,6 @@ MAX_OUTPUT_LEN = 100
CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.5")
def init_to_get_rotary(self, base=10000):
self.config.head_dim_ = self.config.hidden_size // self.config.num_attention_heads
if not hasattr(self.config, "rope_scaling"):
rope_scaling_factor = 1.0
else:
rope_scaling_factor = self.config.rope_scaling.factor if self.config.rope_scaling is not None else 1.0
if hasattr(self.config, "max_sequence_length"):
max_seq_len = self.config.max_sequence_length
elif hasattr(self.config, "max_position_embeddings"):
max_seq_len = self.config.max_position_embeddings * rope_scaling_factor
else:
max_seq_len = 2048 * rope_scaling_factor
base = float(base)
inv_freq = 1.0 / (
base ** (torch.arange(0, self.config.head_dim_, 2, device="cpu", dtype=torch.float32) / self.config.head_dim_)
)
t = torch.arange(max_seq_len + 1024 * 64, device="cpu", dtype=torch.float32) / rope_scaling_factor
freqs = torch.outer(t, inv_freq)
self._cos_cached = torch.cos(freqs).to(torch.float16).cuda()
self._sin_cached = torch.sin(freqs).to(torch.float16).cuda()
return
@parameterize(
"test_config",
[
@@ -56,7 +32,6 @@ def run_llama_test(test_config):
sub_model_zoo = model_zoo.get_sub_registry("transformers_llama_for_casual_lm")
for name, (model_fn, data_gen_fn, _, _, _) in sub_model_zoo.items():
orig_model = model_fn()
init_to_get_rotary(orig_model.model, base=10000)
orig_model = orig_model.half()
data = data_gen_fn()