Compare commits

...

2 Commits

Author SHA1 Message Date
Harrison Chase
859446b75a cr 2023-03-27 14:43:06 -07:00
Anish Shah
a9c2e31dcb Update wandb_callback.py (#1961)
- Fix bugs mentioned in
https://github.com/hwchase17/langchain/issues/1884
- Better logic flow to prevent errors on flush
- Better viz for text
2023-03-27 14:41:00 -07:00

View File

@@ -3,9 +3,12 @@ import json
import tempfile
from copy import deepcopy
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union
from typing import Any, Dict, Iterable, List, Optional, Sequence, Set, Tuple, Union
from langchain.callbacks.base import BaseCallbackHandler
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.llms.openai import BaseOpenAI, OpenAIChat
from langchain.schema import AgentAction, AgentFinish, LLMResult
@@ -203,14 +206,69 @@ def construct_html_from_prompt_and_generation(prompt: str, generation: str) -> A
formatted_prompt = prompt.replace("\n", "<br>")
formatted_generation = generation.replace("\n", "<br>")
style = """<style type="text/css">
body {
font-family: Arial, sans-serif;
background-color: #f4f4f4;
display: flex;
justify-content: center;
align-items: center;
height: 100vh;
margin: 0;
}
.chat-container {
background-color: #ffffff;
border-radius: 8px;
padding: 20px;
max-width: 600px;
box-shadow: 0 2px 6px rgba(0, 0, 0, 0.1);
display: flex;
flex-direction: column;
align-items: flex-start;
}
.chat-message {
display: inline-block;
padding: 10px 15px;
border-radius: 18px;
margin-bottom: 8px;
max-width: 80%;
word-wrap: break-word;
}
.user {
background-color: #007bff;
color: #ffffff;
align-self: flex-end;
border-bottom-right-radius: 4px;
}
.chatbot {
background-color: #e6e6e6;
color: #333333;
align-self: flex-start;
border-bottom-left-radius: 4px;
}
</style>
"""
body = f"""<body>
<div class="chat-container">
<div class="chat-message user">{formatted_prompt}</div>
<div class="chat-message chatbot">{formatted_generation}</div>
</div>
</body>
"""
return wandb.Html(
f"""
<p style="color:black;">{formatted_prompt}:</p>
<blockquote>
<p style="color:green;">
{formatted_generation}
</p>
</blockquote>
<html>
<head>
{style}
</head>
{body}
</html>
""",
inject=False,
)
@@ -442,6 +500,7 @@ class WandbCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
self.complexity_metrics = complexity_metrics
self.visualize = visualize
self.nlp = spacy.load("en_core_web_sm")
self.input_keys: Set[str] = set()
def _init_resp(self) -> Dict:
    """Build a blank response record: one key per callback column, each None."""
    # dict.fromkeys defaults every value to None, in the same key order
    # as iterating self.callback_columns.
    return dict.fromkeys(self.callback_columns)
@@ -528,27 +587,15 @@ class WandbCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
resp = self._init_resp()
resp.update({"action": "on_chain_start"})
resp.update(flatten_dict(serialized))
resp.update(flatten_dict(inputs))
resp.update(self.get_custom_callback_meta())
chain_input = inputs["input"]
self.on_chain_start_records.append(resp)
self.action_records.append(resp)
if self.stream_logs:
self.run.log(resp)
if isinstance(chain_input, str):
input_resp = deepcopy(resp)
input_resp["input"] = chain_input
self.on_chain_start_records.append(input_resp)
self.action_records.append(input_resp)
if self.stream_logs:
self.run.log(input_resp)
elif isinstance(chain_input, list):
for inp in chain_input:
input_resp = deepcopy(resp)
input_resp.update(inp)
self.on_chain_start_records.append(input_resp)
self.action_records.append(input_resp)
if self.stream_logs:
self.run.log(input_resp)
else:
raise ValueError("Unexpected data format provided!")
self.input_keys.update(inputs.keys())
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
"""Run when chain ends running."""
@@ -557,7 +604,8 @@ class WandbCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
self.ends += 1
resp = self._init_resp()
resp.update({"action": "on_chain_end", "outputs": outputs["output"]})
resp.update({"action": "on_chain_end"})
resp.update(flatten_dict(outputs))
resp.update(self.get_custom_callback_meta())
self.on_chain_end_records.append(resp)
@@ -670,7 +718,7 @@ class WandbCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
if self.stream_logs:
self.run.log(resp)
def _create_session_analysis_df(self) -> Any:
def _create_common_session_analysis_df(self) -> Any:
"""Create a dataframe with all the information from the session."""
pd = import_pandas()
on_llm_start_records_df = pd.DataFrame(self.on_llm_start_records)
@@ -712,9 +760,6 @@ class WandbCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
[
"step",
"text",
"token_usage_total_tokens",
"token_usage_prompt_tokens",
"token_usage_completion_tokens",
]
+ complexity_metrics_columns
+ visualizations_columns
@@ -769,33 +814,91 @@ class WandbCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
"""
pd = import_pandas()
wandb = import_wandb()
action_records_table = wandb.Table(dataframe=pd.DataFrame(self.action_records))
session_analysis_table = wandb.Table(
dataframe=self._create_session_analysis_df()
)
self.run.log(
{
"action_records": action_records_table,
"session_analysis": session_analysis_table,
}
)
table_records = {}
if len(self.action_records) > 0:
action_records_table = wandb.Table(
dataframe=pd.DataFrame(self.action_records)
)
table_records["action_records"] = action_records_table
else:
wandb.termlog("No actions were run.")
if (
len(self.on_llm_start_records) == len(self.on_llm_end_records)
and len(self.on_llm_start_records) > 0
):
session_analysis_df = self._create_common_session_analysis_df()
if isinstance(langchain_asset, Chain):
if len(self.on_chain_start_records) == len(self.on_llm_start_records):
chain_inputs = pd.DataFrame(self.on_chain_start_records)[
list(self.input_keys)
]
session_analysis_df = pd.concat(
[chain_inputs, session_analysis_df], axis=1
)
if isinstance(langchain_asset, LLMChain):
if isinstance(langchain_asset.llm, (BaseOpenAI, OpenAIChat)):
openai_token_info = pd.DataFrame(self.on_llm_end_records)[
[
"token_usage_total_tokens",
"token_usage_prompt_tokens",
"token_usage_completion_tokens",
]
]
session_analysis_df = pd.concat(
[session_analysis_df, openai_token_info], axis=1
)
# TODO: DRY
if isinstance(langchain_asset.llm, (BaseOpenAI, OpenAIChat)):
openai_token_info = pd.DataFrame(self.on_llm_end_records)[
[
"token_usage_total_tokens",
"token_usage_prompt_tokens",
"token_usage_completion_tokens",
]
]
session_analysis_df = pd.concat(
[session_analysis_df, openai_token_info], axis=1
)
session_analysis_table = wandb.Table(dataframe=session_analysis_df)
table_records["session_analysis"] = session_analysis_table
else:
wandb.termlog(
f"Record mismatch: (on_llm_start_records - "
f"{len(self.on_llm_start_records)}) != (on_llm_end_records "
f"- {len(self.on_llm_end_records)})"
)
wandb.termlog("or")
wandb.termlog("No records run.")
self.run.log(table_records)
if langchain_asset:
langchain_asset_path = Path(self.temp_dir.name, "model.json")
model_artifact = wandb.Artifact(name="model", type="model")
model_artifact.add(action_records_table, name="action_records")
model_artifact.add(session_analysis_table, name="session_analysis")
for tab_name, tab in table_records.items():
model_artifact.add(tab, name=tab_name)
try:
langchain_asset.save(langchain_asset_path)
model_artifact.add_file(str(langchain_asset_path))
model_artifact.metadata = load_json_to_dict(langchain_asset_path)
model_meta = load_json_to_dict(langchain_asset_path)
model_artifact.metadata = model_meta
self.run.config.update(model_meta)
except ValueError:
# TODO: Replace with check of agent as opposed to this try catch
langchain_asset.save_agent(langchain_asset_path)
model_artifact.add_file(str(langchain_asset_path))
model_artifact.metadata = load_json_to_dict(langchain_asset_path)
model_meta = load_json_to_dict(langchain_asset_path)
model_artifact.metadata = model_meta
self.run.config.update(model_meta)
except NotImplementedError as e:
print("Could not save model.")
print(repr(e))
wandb.termlog("Could not save model.")
wandb.termlog(repr(e))
pass
self.run.log_artifact(model_artifact)
@@ -803,7 +906,7 @@ class WandbCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
self.run.finish()
self.temp_dir.cleanup()
self.reset_callback_meta()
if reset:
if reset and not finish:
self.__init__( # type: ignore
job_type=job_type if job_type else self.job_type,
project=project if project else self.project,