[application] Update README (#6196)

* remove unused ray

* remove unused readme

* update readme

* update readme

* update

* update

* add link

* update readme

* update readme

* fix link

* update code

* update cititaion

* update

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update readme

* update project

* add images

* update link

* update note

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Tong Li
2025-02-18 20:17:56 +08:00
committed by GitHub
parent d54642a263
commit f8b9e88484
10 changed files with 87 additions and 1132 deletions

View File

@@ -1,26 +0,0 @@
import openai
from openai.types.chat.chat_completion import ChatCompletion
from openai.types.chat.chat_completion_message_param import ChatCompletionMessageParam
API_KEY = "Dummy API Key"
def get_client(base_url: str | None = None) -> openai.Client:
return openai.Client(api_key=API_KEY, base_url=base_url)
def chat_completion(
messages: list[ChatCompletionMessageParam],
model: str,
base_url: str | None = None,
temperature: float = 0.8,
**kwargs,
) -> ChatCompletion:
client = get_client(base_url)
response = client.chat.completions.create(
model=model,
messages=messages,
temperature=temperature,
**kwargs,
)
return response

View File

@@ -1,250 +0,0 @@
"""
Implementation of MCTS + Self-refine algorithm.
Reference:
1. "Accessing GPT-4 level Mathematical Olympiad Solutions via Monte
Carlo Tree Self-refine with LLaMa-3 8B: A Technical Report"
2. https://github.com/BrendanGraham14/mcts-llm/
3. https://github.com/trotsky1997/MathBlackBox/
4. https://github.com/openreasoner/openr/blob/main/reason/guided_search/tree.py
"""
from __future__ import annotations
import math
from collections import deque
import numpy as np
import tqdm
from coati.reasoner.guided_search.llm import chat_completion
from coati.reasoner.guided_search.prompt_store.base import PromptCFG
from pydantic import BaseModel
class MCTSNode(BaseModel):
"""
Node for MCTS.
"""
answer: str
parent: MCTSNode = None
children: list[MCTSNode] = []
num_visits: int = 0
Q: int = 0
rewards: list[int] = []
def expand_node(self, node) -> None:
self.children.append(node)
def add_reward(self, reward: int) -> None:
self.rewards.append(reward)
self.Q = (np.min(self.rewards) + np.mean(self.rewards)) / 2
class MCTS(BaseModel):
"""
Simulation of MCTS process.
"""
problem: str
max_simulations: int
cfg: PromptCFG
C: float = 1.4
max_children: int = 2
epsilon: float = 1e-5
root: MCTSNode = None
def initialization(self):
"""
Root Initiation.
"""
# Simple answer as root. You can also use negative response such as "I do not know" as a response.
base_answer = self.sample_base_answer()
self.root = MCTSNode(answer=base_answer)
self.self_evaluate(self.root)
def is_fully_expanded(self, node: MCTSNode):
return len(node.children) >= self.max_children or any(child.Q > node.Q for child in node.children)
def select_node(self) -> MCTSNode:
"""
Select next node to explore.
"""
candidates: list[MCTSNode] = []
to_explore = deque([self.root])
while to_explore:
current_node = to_explore.popleft()
if not self.is_fully_expanded(current_node):
candidates.append(current_node)
to_explore.extend(current_node.children)
if not candidates:
return self.root
return max(candidates, key=self.compute_uct)
def self_evaluate(self, node: MCTSNode):
"""
Sample reward of the answer.
"""
reward = self.sample_reward(node)
node.add_reward(reward)
def back_propagation(self, node: MCTSNode):
"""
Back propagate the value of the refined answer.
"""
parent = node.parent
while parent:
best_child_Q = max(child.Q for child in parent.children)
parent.Q = (parent.Q + best_child_Q) / 2
parent.num_visits += 1
parent = parent.parent
def compute_uct(self, node: MCTSNode):
"""
Compute UCT.
"""
if node.parent is None:
return -100
return node.Q + self.C * math.sqrt(math.log(node.parent.num_visits + 1) / (node.num_visits + self.epsilon))
def simulate(self):
self.initialization()
for _ in tqdm.tqdm(range(self.max_simulations)):
node = self.select_node()
child = self.self_refine(node)
node.expand_node(child)
self.self_evaluate(child)
self.back_propagation(child)
return self.get_best_answer()
def get_best_answer(self):
to_visit = deque([self.root])
best_node = self.root
while to_visit:
current_node = to_visit.popleft()
if current_node.Q > best_node.Q:
best_node = current_node
to_visit.extend(current_node.children)
return best_node.answer
def self_refine(self, node: MCTSNode):
"""
Refine node.
"""
critique_response = chat_completion(
messages=[
{
"role": "system",
"content": self.cfg.critic_system_prompt,
},
{
"role": "user",
"content": "\n\n".join(
[
f"<problem>\n{self.problem}\n</problem>",
f"<current_answer>\n{node.answer}\n</current_answer>",
]
),
},
],
model=self.cfg.model,
base_url=self.cfg.base_url,
max_tokens=self.cfg.max_tokens,
)
critique = critique_response.choices[0].message.content
assert critique is not None
refined_answer_response = chat_completion(
messages=[
{
"role": "system",
"content": self.cfg.refine_system_prompt,
},
{
"role": "user",
"content": "\n\n".join(
[
f"<problem>\n{self.problem}\n</problem>",
f"<current_answer>\n{node.answer}\n</current_answer>",
f"<critique>\n{critique}\n</critique>",
]
),
},
],
model=self.cfg.model,
base_url=self.cfg.base_url,
max_tokens=self.cfg.max_tokens,
)
refined_answer = refined_answer_response.choices[0].message.content
assert refined_answer is not None
return MCTSNode(answer=refined_answer, parent=node)
def sample_base_answer(self):
response = chat_completion(
messages=[
{
"role": "system",
"content": self.cfg.base_system_prompt,
},
{
"role": "user",
"content": f"<problem>\n {self.problem} \n</problem> \nLet's think step by step",
},
],
model=self.cfg.model,
base_url=self.cfg.base_url,
max_tokens=self.cfg.max_tokens,
)
assert response.choices[0].message.content is not None
return response.choices[0].message.content
def sample_reward(self, node: MCTSNode):
"""
Calculate reward.
"""
messages = [
{
"role": "system",
"content": self.cfg.evaluate_system_prompt,
},
{
"role": "user",
"content": "\n\n".join(
[
f"<problem>\n{self.problem}\n</problem>",
f"<answer>\n{node.answer}\n</answer>",
]
),
},
]
for attempt in range(3):
try:
response = chat_completion(
messages=messages,
model=self.cfg.model,
base_url=self.cfg.base_url,
max_tokens=self.cfg.max_tokens,
)
assert response.choices[0].message.content is not None
return int(response.choices[0].message.content)
except ValueError:
messages.extend(
[
{
"role": "assistant",
"content": response.choices[0].message.content,
},
{
"role": "user",
"content": "Failed to parse reward as an integer.",
},
]
)
if attempt == 2:
raise

View File

@@ -1,11 +0,0 @@
from pydantic import BaseModel
class PromptCFG(BaseModel):
model: str
base_url: str
max_tokens: int = 4096
base_system_prompt: str
critic_system_prompt: str
refine_system_prompt: str
evaluate_system_prompt: str

View File

@@ -1,22 +0,0 @@
"""
Prompts for Qwen Series.
"""
from coati.reasoner.guided_search.prompt_store.base import PromptCFG
Qwen32B_prompt_CFG = PromptCFG(
base_url="http://0.0.0.0:8008/v1",
model="Qwen2.5-32B-Instruct",
base_system_prompt="The user will present a problem. Analyze and solve the problem in the following structure:\n"
"Begin with [Reasoning Process] to explain the approach. \n Proceed with [Verification] to confirm the solution. \n Conclude with [Final Answer] in the format: 'Answer: [answer]'",
critic_system_prompt="Provide a detailed and constructive critique of the answer, focusing on ways to improve its clarity, accuracy, and relevance."
"Highlight specific areas that need refinement or correction, and offer concrete suggestions for enhancing the overall quality and effectiveness of the response.",
refine_system_prompt="""# Instruction
Refine the answer based on the critique. The response should begin with [reasoning process]...[Verification]... and end with [Final Answer].
""",
evaluate_system_prompt=(
"Critically analyze this answer and provide a reward score between -100 and 100 based on strict standards."
"The score should clearly reflect the quality of the answer."
"Make sure the reward score is an integer. You should only return the score. If the score is greater than 95, return 95."
),
)