[llama] fix training and inference scripts (#5384)

* [llama] refactor inference example to fit sft

* [llama] fix training script to fit gemini

* [llama] fix inference script
Author: Hongxin Liu (committed via GitHub)
Date:   2024-02-19 16:41:04 +08:00
Commit: 7303801854 (parent: adae123df3)
3 changed files with 52 additions and 30 deletions
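
For context on "fix training script to fit gemini": a Gemini-based training script wraps the model with ColossalAI's Booster and GeminiPlugin, which shards parameters into chunks and backs the half-precision running weights with paired fp32 master weights. The sketch below is illustrative only, not the repository's actual script; the toy model, hyperparameters, and the launch config argument are assumptions, and exact plugin arguments may differ across releases.

```python
import torch
import torch.nn as nn

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin

# Assumes a torchrun-style launch; the config argument has varied across versions.
colossalai.launch_from_torch(config={})

model = nn.Linear(1024, 1024)  # toy stand-in for the actual LLaMA model
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)  # illustrative lr

# GeminiPlugin shards parameters into chunks; running weights stay in half
# precision while paired fp32 master weights back the optimizer update.
plugin = GeminiPlugin(precision="bf16")
booster = Booster(plugin=plugin)
model, optimizer, *_ = booster.boost(model, optimizer)

x = torch.randn(4, 1024, device="cuda")
loss = model(x).mean()
booster.backward(loss, optimizer)  # Gemini routes backward through the booster
optimizer.step()
optimizer.zero_grad()
```

Loading a checkpoint with booster.load_model() eventually reaches GeminiDDP's state-dict loading, which is the code path touched in the diff below.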

@@ -726,11 +726,13 @@ class GeminiDDP(ModelWrapper):
                 chunk.cpu_shard.copy_(temp_chunk[chunk.shard_begin : chunk.shard_end])
             del temp_chunk
-        if self.reuse_fp16_chunk:
-            for chunk_32 in chunk_list:
-                chunk_16 = chunk_32.paired_chunk
-                assert chunk_16 is not None
-                chunk_16.payload.copy_(chunk_32.payload)
+        # sync running weights and master weights
+        if self.master_weights:
+            for loaded_chunk in chunk_list:
+                paired_chunk = loaded_chunk.paired_chunk
+                assert paired_chunk is not None
+                paired_chunk.payload.copy_(loaded_chunk.payload)
         for name, buf in persistent_buffers.items():
             if buf is not None:
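
The change above replaces the reuse_fp16_chunk branch with a master_weights branch: after a checkpoint is loaded into the running chunks, each loaded chunk's payload is copied into its paired master chunk so the running and master copies agree. Below is a minimal standalone sketch of that pairing idea, using a hypothetical Chunk class rather than the real GeminiDDP/ChunkManager types.

```python
import torch


class Chunk:
    """Hypothetical stand-in for a Gemini chunk: a flat payload tensor plus a
    link to its paired chunk (half-precision running weights <-> fp32 master weights)."""

    def __init__(self, payload: torch.Tensor, paired_chunk: "Chunk | None" = None):
        self.payload = payload
        self.paired_chunk = paired_chunk


def sync_master_weights(chunk_list: list[Chunk], master_weights: bool = True) -> None:
    """Mirror freshly loaded running-weight payloads into the paired master
    chunks, matching the copy performed in the diff above."""
    if not master_weights:
        return
    for loaded_chunk in chunk_list:
        paired_chunk = loaded_chunk.paired_chunk
        assert paired_chunk is not None
        # copy_ casts bf16 -> fp32, so the master copy tracks the loaded weights
        paired_chunk.payload.copy_(loaded_chunk.payload)


# usage: running weights in bf16, master weights in fp32
running = Chunk(torch.randn(8, dtype=torch.bfloat16))
master = Chunk(torch.zeros(8, dtype=torch.float32))
running.paired_chunk = master
sync_master_weights([running])
assert torch.equal(master.payload, running.payload.float())
```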