Support overall loss, update KTO logging

This commit is contained in:
YeAnbang
2024-08-02 06:51:38 +00:00
parent 75c963686f
commit 0b2d55c4ab
15 changed files with 119 additions and 119 deletions

View File

@@ -53,8 +53,8 @@ def load_model_and_tokenizer(model_path, tokenizer_path, device="cuda", **kwargs
tuple: A tuple containing the loaded model and tokenizer.
"""
model = AutoModelForCausalLM.from_pretrained(model_path, **kwargs)
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
model = AutoModelForCausalLM.from_pretrained(model_path, **kwargs, trust_remote_code=True).to(torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token
model.to(device)