diff --git a/applications/ColossalChat/coati/distributed/consumer.py b/applications/ColossalChat/coati/distributed/consumer.py
index 1e85cccb3..de2897383 100644
--- a/applications/ColossalChat/coati/distributed/consumer.py
+++ b/applications/ColossalChat/coati/distributed/consumer.py
@@ -57,6 +57,7 @@ class BaseConsumer:
assert self.plugin_config.get("pp_size", 1) == 1, "pp_size > 1 is not supported now"
self.device = get_current_device()
+ self.lr_scheduler = None
def setup(self) -> None:
for i in range(self.num_producers):
@@ -121,6 +122,8 @@ class BaseConsumer:
pbar.set_postfix({"loss": loss})
i += 1
assert len(self.buffer) == 0
+ if self.lr_scheduler is not None:
+ self.lr_scheduler.step()
if (step + 1) % self.save_interval == 0:
if self.rank == 0:
print(f"Start saving policy model at step {step + 1}.")
diff --git a/applications/ColossalChat/coati/distributed/grpo_consumer.py b/applications/ColossalChat/coati/distributed/grpo_consumer.py
index b1edb89bb..1c0773f4e 100644
--- a/applications/ColossalChat/coati/distributed/grpo_consumer.py
+++ b/applications/ColossalChat/coati/distributed/grpo_consumer.py
@@ -1,8 +1,11 @@
+import json
+import os
from contextlib import nullcontext
from typing import Optional
import ray
import torch
+import torch.distributed as dist
import wandb
from coati.distributed.consumer import BaseConsumer
from coati.distributed.loss import PolicyLoss
@@ -12,6 +15,7 @@ from coati.distributed.utils import calc_action_log_probs
from coati.trainer.utils import all_reduce_mean
from transformers import AutoModelForCausalLM, AutoTokenizer
+from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
from colossalai.nn.optimizer import HybridAdam
@@ -31,8 +35,10 @@ class GRPOConsumer(BaseConsumer):
model_config,
plugin_config,
microbatch_size=1,
- num_generations=4,
+ num_generations=8,
use_wandb=True,
+ generate_config=None,
+ training_config={},
):
super().__init__(
num_producers,
@@ -52,7 +58,7 @@ class GRPOConsumer(BaseConsumer):
self.policy_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
self.policy_model.train()
self.policy_model.gradient_checkpointing_enable()
- self.optimizer = HybridAdam(self.policy_model.parameters(), lr=1e-6)
+ self.optimizer = HybridAdam(self.policy_model.parameters(), lr=training_config.get("lr", 1e-6))
self.accum_loss = torch.zeros(1, device=self.device)
self.accum_reward = torch.zeros(1, device=self.device)
self.accum_kl = torch.zeros(1, device=self.device)
@@ -61,6 +67,7 @@ class GRPOConsumer(BaseConsumer):
self.accum_advantages = torch.zeros(1, device=self.device)
self.accum_response_length = torch.zeros(1, device=self.device)
self.accum_count = 0
+ self.generate_config = generate_config
# Reference model is initialized from policy model.
self.reference_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
@@ -69,6 +76,9 @@ class GRPOConsumer(BaseConsumer):
self.tokenizer = AutoTokenizer.from_pretrained(path)
self.pad_token_id = self.tokenizer.pad_token_id
self.num_generations = num_generations
+ self.filter_range = training_config.get("filter_range", None)
+ if self.filter_range is not None:
+ assert len(self.filter_range) == 2, "Filter range should have 2 values."
# Initialize verifiable reward.
response_format_tags = {
@@ -84,11 +94,21 @@ class GRPOConsumer(BaseConsumer):
self.policy_loss_fn = PolicyLoss()
self.global_step = 0
if use_wandb and self.rank == 0:
- self.wandb_run = wandb.init(project="GRPO-V1", sync_tensorboard=True)
+ name = f"{generate_config['backend']}_bs_{self.batch_size*self.world_size}_temp_{generate_config['temperature']:.01f}_top_p_{generate_config['top_p']:.02f}"
+ self.wandb_run = wandb.init(project="GRPO-V1", sync_tensorboard=True, dir="./wandb", name=name)
+
+ self.lr_scheduler = CosineAnnealingWarmupLR(
+ optimizer=self.optimizer,
+ total_steps=min(self.num_episodes, 4) * self.num_update_per_episode,
+ warmup_steps=0,
+ eta_min=0.1 * training_config.get("lr", 1e-6),
+ )
def setup(self):
super().setup()
- self.policy_model, self.optimizer, *_ = self.booster.boost(self.policy_model, self.optimizer)
+ self.policy_model, self.optimizer, _, _, self.lr_scheduler = self.booster.boost(
+ self.policy_model, self.optimizer, lr_scheduler=self.lr_scheduler
+ )
self.reference_model, *_ = self.booster.boost(self.reference_model)
def step(self, step_idx: int, **kwargs) -> Optional[float]:
@@ -113,7 +133,6 @@ class GRPOConsumer(BaseConsumer):
response_length = torch.sum(action_mask, dim=1).to(torch.float32)
need_update = (step_idx + 1) % self.num_microbatches == 0
-
ctx = nullcontext() if need_update else self.booster.no_sync(self.policy_model, self.optimizer)
with ctx:
policy_model_logits = self.policy_model(
@@ -121,7 +140,10 @@ class GRPOConsumer(BaseConsumer):
attention_mask=data["attention_mask"],
)["logits"]
action_log_probs = calc_action_log_probs(
- policy_model_logits, data["input_ids"], num_action, self.plugin.shard_config
+ policy_model_logits / self.generate_config["temperature"],
+ data["input_ids"],
+ num_action,
+ self.plugin.shard_config,
)
with torch.no_grad():
@@ -130,7 +152,10 @@ class GRPOConsumer(BaseConsumer):
attention_mask=data["attention_mask"],
)["logits"]
reference_action_log_probs = calc_action_log_probs(
- reference_model_logits, data["input_ids"], num_action, self.plugin.shard_config
+ reference_model_logits / self.generate_config["temperature"],
+ data["input_ids"],
+ num_action,
+ self.plugin.shard_config,
)
per_token_kl = (
@@ -149,21 +174,31 @@ class GRPOConsumer(BaseConsumer):
acc_reward = torch.tensor([value[2] for value in reward_group]).to(data["input_ids"].device)
# [batch_size, num_generations]
+
group_reward = reward.view(-1, self.num_generations)
+ reward_mean = group_reward.mean(dim=1)
+ # filter out the reward that is too high (all sample gets full score) or too low (all sample gets 0 score),
+ loss_mask = (
+ None
+ if self.filter_range is None
+ else torch.logical_and(
+ reward_mean > self.filter_range[0], reward_mean < self.filter_range[1]
+ ).repeat_interleave(self.num_generations, dim=0)
+ )
# [batch_size x num_generations]
- reward_mean = group_reward.mean(dim=1).repeat_interleave(self.num_generations, dim=0)
+ reward_mean = reward_mean.repeat_interleave(self.num_generations, dim=0)
reward_std = group_reward.std(dim=1).repeat_interleave(self.num_generations, dim=0)
# [batch_size x num_generations]
advantages = (reward - reward_mean) / (reward_std + 1e-4)
- # Calculate Loss
loss, skip_update, _ = self.policy_loss_fn(
action_log_probs,
old_action_log_probs,
advantages.unsqueeze(dim=-1).repeat_interleave(action_log_probs.size(-1), dim=-1),
per_token_kl,
action_mask,
+ loss_mask=loss_mask,
)
if not skip_update:
@@ -207,13 +242,15 @@ class GRPOConsumer(BaseConsumer):
)
self.wandb_run.log(
{
+ "metrics/reward": self.accum_reward.item() / self.accum_count,
+ "metrics/format_reward": self.accum_format_reward.item() / self.accum_count,
+ "metrics/acc_reward": self.accum_acc_reward.item() / self.accum_count,
+ "metrics/response_length": self.accum_response_length.item() / self.accum_count,
"train/loss": self.accum_loss.item() / self.accum_count,
- "train/reward": self.accum_reward.item() / self.accum_count,
- "train/format_reward": self.accum_format_reward.item() / self.accum_count,
- "train/acc_reward": self.accum_acc_reward.item() / self.accum_count,
"train/kl": self.accum_kl.item() / self.accum_count,
"train/advantages": self.accum_advantages.item() / self.accum_count,
- "train/response_length": self.accum_response_length.item() / self.accum_count,
+ "train/learning_rate": self.lr_scheduler.get_last_lr()[0],
+ "rollout/temperature": data["temperature"].cpu().numpy()[0][0],
}
)
self.accum_loss.zero_()
@@ -232,3 +269,125 @@ class GRPOConsumer(BaseConsumer):
model = self.policy_model.unwrap()
state_dict = model.state_dict()
return state_dict
+
+
+@ray.remote
+class GRPOEvalConsumer(BaseConsumer):
+ def __init__(
+ self,
+ num_producers,
+ num_episodes,
+ rank,
+ world_size,
+ master_addr,
+ master_port,
+ num_update_per_episode,
+ num_recv_per_update,
+ batch_size,
+ model_config,
+ plugin_config,
+ microbatch_size=1,
+ num_generations=4,
+ use_wandb=True,
+ log_dir="./results",
+ ):
+ super().__init__(
+ num_producers,
+ num_episodes,
+ rank,
+ world_size,
+ master_addr,
+ master_port,
+ num_update_per_episode,
+ num_recv_per_update,
+ batch_size,
+ model_config,
+ plugin_config,
+ microbatch_size,
+ )
+ path = model_config.pop("path")
+ self.policy_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
+ self.policy_model.train()
+ self.accum_reward = torch.zeros(1, device=self.device)
+ self.accum_format_reward = torch.zeros(1, device=self.device)
+ self.accum_acc_reward = torch.zeros(1, device=self.device)
+ self.accum_response_length = torch.zeros(1, device=self.device)
+ self.accum_count = torch.zeros(1, device=self.device)
+
+ self.tokenizer = AutoTokenizer.from_pretrained(path)
+ self.pad_token_id = self.tokenizer.pad_token_id
+ self.num_generations = num_generations
+
+ # Initialize verifiable reward.
+ response_format_tags = {
+ "think_start": {"text": "", "num_occur": 1},
+ "think_end": {"text": "", "num_occur": 1},
+ "answer_start": {"text": "", "num_occur": 1},
+ "answer_end": {"text": "", "num_occur": 1},
+ }
+ self.reward_model = VerifiableReward(
+ reward_fns=[math_reward_fn], tokenizer=self.tokenizer, tags=response_format_tags
+ )
+
+ self.log_dir = log_dir
+ if not os.path.exists(self.log_dir):
+ os.makedirs(self.log_dir)
+ else:
+ os.system(f"rm -rf {self.log_dir}/*")
+
+ def setup(self):
+ super().setup()
+ self.policy_model, _, *_ = self.booster.boost(self.policy_model)
+
+ def step(self, step_idx: int, **kwargs) -> Optional[float]:
+ rank = dist.get_rank()
+ data = {k: v.view(-1, v.size(-1)).cpu() for k, v in kwargs.items()}
+ kwargs["input_ids"].size(0)
+ reward_group = self.reward_model(
+ data["input_ids"], gt_answer=data["gt_answer"], response_idx=data["response_idx"]
+ )
+ reward = [value[0].item() for value in reward_group]
+ format_reward = [value[1].item() for value in reward_group]
+ acc_reward = [value[2].item() for value in reward_group]
+ response_length = [(data["response_idx"][i][1] - data["response_idx"][i][0]).item() for i in range(len(reward))]
+
+ response = self.tokenizer.batch_decode(data["input_ids"], skip_special_tokens=True)
+ with open(f"{self.log_dir}/eval_results_rank_{rank}.jsonl", "a", encoding="utf8") as f:
+ for i in range(len(response)):
+ f.write(
+ json.dumps(
+ {
+ "response": response[i],
+ "reward": reward[i],
+ "format_reward": format_reward[i],
+ "acc_reward": acc_reward[i],
+ "response_length": response_length[i],
+ },
+ ensure_ascii=False,
+ )
+ + "\n"
+ )
+
+ self.accum_reward += sum(reward)
+ self.accum_format_reward += sum(format_reward)
+ self.accum_acc_reward += sum(acc_reward)
+ self.accum_response_length += sum(response_length)
+ self.accum_count += len(reward)
+
+ # print results
+ total_count = all_reduce_mean(self.accum_count, self.plugin)
+ mean_reward = all_reduce_mean(self.accum_reward, self.plugin) / total_count
+ mean_format_reward = all_reduce_mean(self.accum_format_reward, self.plugin) / total_count
+ mean_acc_reward = all_reduce_mean(self.accum_acc_reward, self.plugin) / total_count
+ mean_response_length = all_reduce_mean(self.accum_response_length, self.plugin) / total_count
+ if rank == 0:
+ print(
+ f"Step {step_idx}: Mean Reward: {mean_reward}, Mean Format Reward: {mean_format_reward}, Mean Acc Reward: {mean_acc_reward}, Mean Response Length: {mean_response_length}"
+ )
+ return None
+
+ def state_dict(self):
+ self.policy_model._force_wait_all_gather()
+ model = self.policy_model.unwrap()
+ state_dict = model.state_dict()
+ return state_dict
diff --git a/applications/ColossalChat/coati/distributed/inference_backend.py b/applications/ColossalChat/coati/distributed/inference_backend.py
index 58414b29f..17c71c8a8 100644
--- a/applications/ColossalChat/coati/distributed/inference_backend.py
+++ b/applications/ColossalChat/coati/distributed/inference_backend.py
@@ -53,7 +53,13 @@ class TransformersInferenceBackend(BaseInferenceBackend):
)
FORCE_GENERATE_CONFIG = dict(output_logits=True, return_dict_in_generate=True)
- def __init__(self, model_config: Dict[str, Any], generate_config: Dict[str, Any], tokenizer: PreTrainedTokenizer):
+ def __init__(
+ self,
+ model_config: Dict[str, Any],
+ generate_config: Dict[str, Any],
+ tokenizer: PreTrainedTokenizer,
+ num_generations: int = 8,
+ ):
model_config = update_by_default(model_config, self.DEFAULT_MODEL_CONFIG)
model_config.update(self.FORCE_MODEL_CONFIG)
path = model_config.pop("path")
@@ -61,7 +67,7 @@ class TransformersInferenceBackend(BaseInferenceBackend):
self.generate_config = generate_config.copy()
self.generate_config.update(self.FORCE_GENERATE_CONFIG)
self.tokenizer = tokenizer
- self.num_generations = 8
+ self.num_generations = num_generations
@torch.no_grad()
def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
@@ -120,7 +126,13 @@ class TransformersInferenceBackend(BaseInferenceBackend):
class SGLangInferenceBackend(BaseInferenceBackend):
- def __init__(self, model_config: Dict[str, Any], generate_config: Dict[str, Any], tokenizer: PreTrainedTokenizer):
+ def __init__(
+ self,
+ model_config: Dict[str, Any],
+ generate_config: Dict[str, Any],
+ tokenizer: PreTrainedTokenizer,
+ num_generations: int = 8,
+ ):
if sgl is None:
raise ImportError("sglang is not installed")
path = model_config.pop("path")
@@ -175,27 +187,38 @@ class VLLMInferenceBackend(BaseInferenceBackend):
)
FORCE_GENERATE_CONFIG = dict(
logprobs=0,
- n=8,
)
- def __init__(self, model_config: Dict[str, Any], generate_config: Dict[str, Any], tokenizer: PreTrainedTokenizer):
+ def __init__(
+ self,
+ model_config: Dict[str, Any],
+ generate_config: Dict[str, Any],
+ tokenizer: PreTrainedTokenizer,
+ num_generations: int = 8,
+ ):
if LLM is None:
raise ImportError("vllm is not installed")
model_config = update_by_default(model_config, self.DEFAULT_MODEL_CONFIG)
path = model_config.pop("path")
- self.llm = LLM(path, **model_config)
+ self.llm = LLM(model=path, **model_config)
generate_config = generate_config.copy()
generate_config.update(self.FORCE_GENERATE_CONFIG)
+ generate_config.update({"n": num_generations})
self.generate_config = SamplingParams(**generate_config)
self.tokenizer = tokenizer
- self.num_generations = self.FORCE_GENERATE_CONFIG["n"]
+ self.num_generations = num_generations
@torch.no_grad()
def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
micro_batch_size = input_ids.size(0)
response_start_idx = input_ids.size(1)
+ first_non_padding_token_idx = (input_ids != self.tokenizer.pad_token_id).int().argmax(dim=1)
+ micro_batch_input_ids = input_ids.tolist()
+ micro_batch_input_ids_no_padding = [
+ micro_batch_input_ids[i][first_non_padding_token_idx[i] :] for i in range(micro_batch_size)
+ ]
outputs = self.llm.generate(
- prompt_token_ids=input_ids.tolist(), sampling_params=self.generate_config, use_tqdm=False
+ prompt_token_ids=micro_batch_input_ids_no_padding, sampling_params=self.generate_config, use_tqdm=False
)
out_tokens = []
out_len = []
diff --git a/applications/ColossalChat/coati/distributed/launch.py b/applications/ColossalChat/coati/distributed/launch.py
index 8581ff586..c50db1378 100644
--- a/applications/ColossalChat/coati/distributed/launch.py
+++ b/applications/ColossalChat/coati/distributed/launch.py
@@ -1,15 +1,13 @@
+import copy
from typing import Any, Dict, Optional
import ray
from .consumer import SimpleConsumer
-from .grpo_consumer import GRPOConsumer
+from .grpo_consumer import GRPOConsumer, GRPOEvalConsumer
from .producer import SimpleProducer
-ALGO_MAP = {
- "Simple": SimpleConsumer,
- "GRPO": GRPOConsumer,
-}
+ALGO_MAP = {"Simple": SimpleConsumer, "GRPO": GRPOConsumer, "EvalGRPO": GRPOEvalConsumer}
def get_jsonl_size_fast(path: str) -> int:
@@ -44,6 +42,7 @@ def launch_distributed(
plugin_config: Dict[str, Any],
tokenizer_config: Optional[Dict[str, Any]] = None,
inference_backend: str = "transformers",
+ num_generations: int = 8,
master_addr: str = "localhost",
master_port: int = 29500,
core_algo: str = "GRPO",
@@ -78,8 +77,15 @@ def launch_distributed(
tokenizer_config=tokenizer_config,
microbatch_size=inference_microbatch_size,
backend=inference_backend,
+ num_generations=num_generations,
)
procs.append(producer)
+ generate_config_consumer = copy.deepcopy(generate_config)
+ generate_config_consumer.update(
+ dict(
+ backend=inference_backend,
+ )
+ )
for i in range(num_consumer_procs):
consumer = core_consumer.options(num_gpus=1).remote(
num_producers=num_producers,
@@ -94,6 +100,9 @@ def launch_distributed(
model_config=train_model_config,
plugin_config=plugin_config,
microbatch_size=train_microbatch_size,
+ generate_config=generate_config_consumer,
+ training_config={"filter_range": [0.05, 9.0], "lr": 1e-6},
+ num_generations=num_generations,
)
procs.append(consumer)
ray.get([p.setup.remote() for p in procs])
diff --git a/applications/ColossalChat/coati/distributed/loss.py b/applications/ColossalChat/coati/distributed/loss.py
index af5776731..90ad09736 100644
--- a/applications/ColossalChat/coati/distributed/loss.py
+++ b/applications/ColossalChat/coati/distributed/loss.py
@@ -23,6 +23,7 @@ class PolicyLoss(nn.Module):
advantages: torch.Tensor,
per_token_kl: torch.Tensor,
action_mask: Optional[torch.Tensor] = None,
+ loss_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
skip = False
if action_mask is None:
@@ -38,5 +39,7 @@ class PolicyLoss(nn.Module):
loss = masked_mean(loss, action_mask)
else:
loss = loss.mean(dim=1)
+ if loss_mask is not None:
+ loss = loss * loss_mask
loss = loss.mean()
return loss, skip, ratio.max()
diff --git a/applications/ColossalChat/coati/distributed/producer.py b/applications/ColossalChat/coati/distributed/producer.py
index a3ae22a79..51a1af332 100644
--- a/applications/ColossalChat/coati/distributed/producer.py
+++ b/applications/ColossalChat/coati/distributed/producer.py
@@ -101,6 +101,9 @@ class BaseProducer:
break
outputs = self.rollout(**batch)
print(f"[P{self.producer_idx}] Send data {[(k, v.shape) for k, v in outputs.items()]}")
+ outputs["temperature"] = torch.tensor(
+ [self.model.generate_config.temperature] * outputs["input_ids"].size(0)
+ ).to(outputs["input_ids"].device)
outputs = pre_send(outputs)
ray_broadcast_tensor_dict(
outputs, src=0, device=self.device, group_name=f"sync_data_{self.producer_idx}"
@@ -117,6 +120,12 @@ class BaseProducer:
None, self.num_producers, device=self.device, group_name="sync_model"
)
self.load_state_dict(state_dict)
+ # linear annealing for 1 episode, temperature from initial to 0.7
+ if episode <= 0:
+ ratio = 1 - (len(self.dataloader) - i) / len(self.dataloader)
+ self.model.generate_config.temperature = (
+ ratio * self.generate_config["temperature"] + (1 - ratio) * 0.7
+ )
@ray.remote
@@ -135,6 +144,7 @@ class SimpleProducer(BaseProducer):
tokenizer_config=None,
microbatch_size=1,
backend="transformers",
+ num_generations: int = 8,
):
super().__init__(
producer_idx,
@@ -150,7 +160,7 @@ class SimpleProducer(BaseProducer):
microbatch_size,
backend,
)
- self.model = self.backend_cls(model_config, generate_config, self.tokenizer)
+ self.model = self.backend_cls(model_config, generate_config, self.tokenizer, num_generations)
@torch.no_grad()
def rollout(self, input_ids, attention_mask, **kwargs):
diff --git a/applications/ColossalChat/rl_example.py b/applications/ColossalChat/rl_example.py
index 1de8b649d..a67a10bc5 100644
--- a/applications/ColossalChat/rl_example.py
+++ b/applications/ColossalChat/rl_example.py
@@ -15,18 +15,14 @@ if __name__ == "__main__":
parser.add_argument("-tbs", "--train-batch-size", type=int, default=32)
parser.add_argument("-tmbs", "--train-microbatch-size", type=int, default=1)
parser.add_argument("-b", "--backend", type=str, default="transformers")
- parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["Simple, GRPO"])
+ parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["Simple", "GRPO", "EvalGRPO"])
args = parser.parse_args()
ray.init(address="local", namespace="ray-example")
inference_model_config = dict(path=args.model)
train_model_config = dict(path=args.model)
- generate_config = dict(
- top_k=50,
- top_p=0.9,
- temperature=1.0,
- )
+ generate_config = dict(top_k=50, top_p=0.75, temperature=0.9)
if args.backend == "transformers":
inference_model_config.update(
@@ -52,19 +48,13 @@ if __name__ == "__main__":
)
)
elif args.backend == "vllm":
- inference_model_config.update(
- dict(
- gpu_memory_utilization=0.7,
- )
- )
+ inference_model_config.update(dict(gpu_memory_utilization=0.7, enforce_eager=True, enable_chunked_prefill=True))
generate_config.update(
dict(
max_tokens=2048,
ignore_eos=True,
include_stop_str_in_output=True,
stop=[""],
- temperature=0.7,
- top_p=0.95,
)
)
else:
@@ -97,6 +87,6 @@ if __name__ == "__main__":
plugin_config={},
inference_backend=args.backend,
master_addr="localhost",
- master_port=29504,
+ master_port=29503,
core_algo=args.algo,
)
diff --git a/colossalai/shardformer/layer/loss.py b/colossalai/shardformer/layer/loss.py
index 51419a38a..a9bb76fc7 100644
--- a/colossalai/shardformer/layer/loss.py
+++ b/colossalai/shardformer/layer/loss.py
@@ -387,7 +387,7 @@ def dist_log_prob(
dtype=dtype,
)
else:
- log_prob = log_softmax(logits)
+ log_prob = log_softmax(logits, dim=-1)
log_prob = log_prob.gather(dim=-1, index=labels.unsqueeze(-1))
return log_prob