fix typo applications/Chat/coati/ (#3947)

This commit is contained in:
digger yu
2023-06-15 10:43:11 +08:00
committed by GitHub
parent e8ad3c88f5
commit d4fb7bfda7
3 changed files with 12 additions and 12 deletions

@@ -205,15 +205,15 @@ class ExperienceMakerHolder:
                 self.experience_maker.actor.model.load_state_dict(new_actor_state_dict, strict=False)
             else:
                 new_actor_state_dict = state_dict_to(new_actor_state_dict, device=torch.cuda.current_device())
-                state_dict_increasae = self.actor_lora_constructor.reconstruct_increase(new_actor_state_dict, new_actor_lora_config_dict)
-                self.actor_lora_constructor.load_state_dict_increase(self.experience_maker.actor.model, state_dict_increasae)
+                state_dict_increase = self.actor_lora_constructor.reconstruct_increase(new_actor_state_dict, new_actor_lora_config_dict)
+                self.actor_lora_constructor.load_state_dict_increase(self.experience_maker.actor.model, state_dict_increase)

         if new_critic_state_dict is not None:
             if not self._update_lora_weights or fully_update:
                 self.experience_maker.critic.load_state_dict(new_critic_state_dict, strict=False)
             else:
                 new_critic_state_dict = state_dict_to(new_critic_state_dict, device=torch.cuda.current_device())
-                state_dict_increasae = self.critic_lora_constructor.reconstruct_increase(new_critic_state_dict, new_critic_lora_config_dict)
-                self.critic_lora_constructor.load_state_dict_increase(self.experience_maker.critic, state_dict_increasae)
+                state_dict_increase = self.critic_lora_constructor.reconstruct_increase(new_critic_state_dict, new_critic_lora_config_dict)
+                self.critic_lora_constructor.load_state_dict_increase(self.experience_maker.critic, state_dict_increase)
         # the lock must be released after both actor and critic being updated
         if chunk_end:

@@ -19,7 +19,7 @@ class LoRAConfig:
 class LoRAConstructor:
     '''
         Tools for reconstructing a model from a remote LoRA model.
-        (Transfering only LoRA data costs much less!)
+        (Transferring only LoRA data costs much less!)
         Usage:
             Step 1 (Sender):
                 filter_state_dict_lora()
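
The docstring above outlines a sender/receiver protocol: the sender strips the checkpoint down to the small lora_A / lora_B tensors, and the receiver rebuilds the dense weight increments and folds them into its copy of the model. A minimal sketch of that round trip follows; only the three LoRAConstructor methods named in this diff are taken from the source, while the import path, the no-argument constructor call, the helper names, and the way lora_config_dict is produced and shipped are assumptions for illustration.

# Hedged sketch of the usage pattern described in the docstring.
# The import path is an assumption; the class lives somewhere under applications/Chat/coati/.
from coati.ray.lora_constructor import LoRAConstructor

def sender_side(trained_model):
    # Keep only the small lora_A / lora_B tensors (the exact return value is assumed here).
    return LoRAConstructor.filter_state_dict_lora(trained_model.state_dict())

def receiver_side(base_model, lora_state_dict, lora_config_dict):
    # Rebuild the dense weight deltas and add them onto the base model,
    # mirroring how ExperienceMakerHolder uses the constructor in the first hunk.
    constructor = LoRAConstructor()
    state_dict_increase = constructor.reconstruct_increase(lora_state_dict, lora_config_dict)
    constructor.load_state_dict_increase(base_model, state_dict_increase)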
@@ -52,7 +52,7 @@ class LoRAConstructor:
         if lora_config_dict is not None:
             self.register_lora_config(lora_config_dict)
-        state_dict_increasae = OrderedDict()
+        state_dict_increase = OrderedDict()
         config_iter = iter(self.lora_config_dict.items())
         lora_A, lora_B, layer_prefix = None, None, None
         for k, v in state_dict_lora.items():
@@ -65,11 +65,11 @@ class LoRAConstructor:
                 assert layer_prefix_2 == layer_prefix, "unmatched (state_dict, config_dict) pair"
                 lora_B = v
                 weight_data_increase = self._compute(lora_A, lora_B, config)
-                state_dict_increasae[layer_prefix + '.weight'] = weight_data_increase
+                state_dict_increase[layer_prefix + '.weight'] = weight_data_increase
                 lora_A, lora_B, layer_prefix = None, None, None
             else:
                 raise ValueError('unexpected key')
-        return state_dict_increasae
+        return state_dict_increase

     def _compute(self, lora_A, lora_B, config=LoRAConfig()):
         def T(w):
@@ -80,12 +80,12 @@ class LoRAConstructor:
             return weight_data_increase
         return 0

-    def load_state_dict_increase(self, model: nn.Module, state_dict_increasae: Dict[str, Any]):
+    def load_state_dict_increase(self, model: nn.Module, state_dict_increase: Dict[str, Any]):
         '''
             The final reconstruction step
         '''
         # naive approach
-        model.load_state_dict({k: v + model.state_dict()[k] for k, v in state_dict_increasae.items()}, strict=False)
+        model.load_state_dict({k: v + model.state_dict()[k] for k, v in state_dict_increase.items()}, strict=False)

     @staticmethod
     def filter_state_dict_lora(state_dict: Dict[str, Any], keep_non_lora=False):
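
For readers less familiar with LoRA, the state_dict_increase renamed throughout this commit is the dense weight delta recovered from a low-rank pair. The _compute body is truncated in the hunk above, so the following is only a sketch of the standard LoRA arithmetic (rank r, scaling lora_alpha / r), not a copy of the repository's implementation:

import torch

def lora_weight_increase(lora_A: torch.Tensor, lora_B: torch.Tensor,
                         r: int, lora_alpha: float,
                         fan_in_fan_out: bool = False) -> torch.Tensor:
    # Standard LoRA: the dense delta is (B @ A) scaled by alpha / r.
    # lora_B: (out_features, r), lora_A: (r, in_features) -> delta: (out, in).
    delta = (lora_B @ lora_A) * (lora_alpha / r)
    # Some layers store their weight transposed (e.g. GPT-2-style Conv1D),
    # which is what a fan_in_fan_out flag typically accounts for.
    return delta.T if fan_in_fan_out else delta

load_state_dict_increase then adds each such delta onto the corresponding base weight, which is exactly the `v + model.state_dict()[k]` expression visible in the last hunk.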