mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-24 15:43:54 +00:00
parent
75b6fa1134
commit
af5ae24af2
@ -416,15 +416,25 @@ class AimCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
|
|||||||
self._run.close()
|
self._run.close()
|
||||||
self.reset_callback_meta()
|
self.reset_callback_meta()
|
||||||
if reset:
|
if reset:
|
||||||
self.__init__( # type: ignore
|
aim = import_aim()
|
||||||
repo=repo if repo else self.repo,
|
self.repo = repo if repo else self.repo
|
||||||
experiment_name=experiment_name
|
self.experiment_name = (
|
||||||
if experiment_name
|
experiment_name if experiment_name else self.experiment_name
|
||||||
else self.experiment_name,
|
|
||||||
system_tracking_interval=system_tracking_interval
|
|
||||||
if system_tracking_interval
|
|
||||||
else self.system_tracking_interval,
|
|
||||||
log_system_params=log_system_params
|
|
||||||
if log_system_params
|
|
||||||
else self.log_system_params,
|
|
||||||
)
|
)
|
||||||
|
self.system_tracking_interval = (
|
||||||
|
system_tracking_interval
|
||||||
|
if system_tracking_interval
|
||||||
|
else self.system_tracking_interval
|
||||||
|
)
|
||||||
|
self.log_system_params = (
|
||||||
|
log_system_params if log_system_params else self.log_system_params
|
||||||
|
)
|
||||||
|
|
||||||
|
self._run = aim.Run(
|
||||||
|
repo=self.repo,
|
||||||
|
experiment=self.experiment_name,
|
||||||
|
system_tracking_interval=self.system_tracking_interval,
|
||||||
|
log_system_params=self.log_system_params,
|
||||||
|
)
|
||||||
|
self._run_hash = self._run.hash
|
||||||
|
self.action_records = []
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
import os
|
import os
|
||||||
import warnings
|
import warnings
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional, cast
|
||||||
|
|
||||||
from langchain_core.agents import AgentAction, AgentFinish
|
from langchain_core.agents import AgentAction, AgentFinish
|
||||||
from langchain_core.callbacks import BaseCallbackHandler
|
from langchain_core.callbacks import BaseCallbackHandler
|
||||||
@ -269,8 +269,8 @@ class ArgillaCallbackHandler(BaseCallbackHandler):
|
|||||||
for key in [str(kwargs["parent_run_id"]), str(kwargs["run_id"])]
|
for key in [str(kwargs["parent_run_id"]), str(kwargs["run_id"])]
|
||||||
):
|
):
|
||||||
return
|
return
|
||||||
prompts = self.prompts.get(str(kwargs["parent_run_id"])) or self.prompts.get(
|
prompts: List = self.prompts.get(str(kwargs["parent_run_id"])) or cast(
|
||||||
str(kwargs["run_id"])
|
List, self.prompts.get(str(kwargs["run_id"]), [])
|
||||||
)
|
)
|
||||||
for chain_output_key, chain_output_val in outputs.items():
|
for chain_output_key, chain_output_val in outputs.items():
|
||||||
if isinstance(chain_output_val, list):
|
if isinstance(chain_output_val, list):
|
||||||
@ -283,10 +283,7 @@ class ArgillaCallbackHandler(BaseCallbackHandler):
|
|||||||
"response": output["text"].strip(),
|
"response": output["text"].strip(),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
for prompt, output in zip(
|
for prompt, output in zip(prompts, chain_output_val)
|
||||||
prompts, # type: ignore
|
|
||||||
chain_output_val,
|
|
||||||
)
|
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@ -295,7 +292,7 @@ class ArgillaCallbackHandler(BaseCallbackHandler):
|
|||||||
records=[
|
records=[
|
||||||
{
|
{
|
||||||
"fields": {
|
"fields": {
|
||||||
"prompt": " ".join(prompts), # type: ignore
|
"prompt": " ".join(prompts),
|
||||||
"response": chain_output_val.strip(),
|
"response": chain_output_val.strip(),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
@ -162,7 +162,7 @@ class ArthurCallbackHandler(BaseCallbackHandler):
|
|||||||
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
|
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
|
||||||
"""On LLM end, send data to Arthur."""
|
"""On LLM end, send data to Arthur."""
|
||||||
try:
|
try:
|
||||||
import pytz # type: ignore[import]
|
import pytz
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
raise ImportError(
|
raise ImportError(
|
||||||
"Could not import pytz. Please install it with 'pip install pytz'."
|
"Could not import pytz. Please install it with 'pip install pytz'."
|
||||||
|
@ -83,7 +83,7 @@ class ClearMLCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
|
|||||||
if clearml.Task.current_task():
|
if clearml.Task.current_task():
|
||||||
self.task = clearml.Task.current_task()
|
self.task = clearml.Task.current_task()
|
||||||
else:
|
else:
|
||||||
self.task = clearml.Task.init( # type: ignore
|
self.task = clearml.Task.init(
|
||||||
task_type=self.task_type,
|
task_type=self.task_type,
|
||||||
project_name=self.project_name,
|
project_name=self.project_name,
|
||||||
tags=self.tags,
|
tags=self.tags,
|
||||||
@ -361,17 +361,13 @@ class ClearMLCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
|
|||||||
if self.visualize and self.nlp and self.temp_dir.name is not None:
|
if self.visualize and self.nlp and self.temp_dir.name is not None:
|
||||||
doc = self.nlp(text)
|
doc = self.nlp(text)
|
||||||
|
|
||||||
dep_out = spacy.displacy.render( # type: ignore
|
dep_out = spacy.displacy.render(doc, style="dep", jupyter=False, page=True)
|
||||||
doc, style="dep", jupyter=False, page=True
|
|
||||||
)
|
|
||||||
dep_output_path = Path(
|
dep_output_path = Path(
|
||||||
self.temp_dir.name, hash_string(f"dep-{text}") + ".html"
|
self.temp_dir.name, hash_string(f"dep-{text}") + ".html"
|
||||||
)
|
)
|
||||||
dep_output_path.open("w", encoding="utf-8").write(dep_out)
|
dep_output_path.open("w", encoding="utf-8").write(dep_out)
|
||||||
|
|
||||||
ent_out = spacy.displacy.render( # type: ignore
|
ent_out = spacy.displacy.render(doc, style="ent", jupyter=False, page=True)
|
||||||
doc, style="ent", jupyter=False, page=True
|
|
||||||
)
|
|
||||||
ent_output_path = Path(
|
ent_output_path = Path(
|
||||||
self.temp_dir.name, hash_string(f"ent-{text}") + ".html"
|
self.temp_dir.name, hash_string(f"ent-{text}") + ".html"
|
||||||
)
|
)
|
||||||
|
@ -37,7 +37,7 @@ def _get_experiment(
|
|||||||
) -> Any:
|
) -> Any:
|
||||||
comet_ml = import_comet_ml()
|
comet_ml = import_comet_ml()
|
||||||
|
|
||||||
experiment = comet_ml.Experiment( # type: ignore
|
experiment = comet_ml.Experiment(
|
||||||
workspace=workspace,
|
workspace=workspace,
|
||||||
project_name=project_name,
|
project_name=project_name,
|
||||||
)
|
)
|
||||||
|
@ -79,12 +79,8 @@ def analyze_text(
|
|||||||
if nlp is not None:
|
if nlp is not None:
|
||||||
spacy = import_spacy()
|
spacy = import_spacy()
|
||||||
doc = nlp(text)
|
doc = nlp(text)
|
||||||
dep_out = spacy.displacy.render( # type: ignore
|
dep_out = spacy.displacy.render(doc, style="dep", jupyter=False, page=True)
|
||||||
doc, style="dep", jupyter=False, page=True
|
ent_out = spacy.displacy.render(doc, style="ent", jupyter=False, page=True)
|
||||||
)
|
|
||||||
ent_out = spacy.displacy.render( # type: ignore
|
|
||||||
doc, style="ent", jupyter=False, page=True
|
|
||||||
)
|
|
||||||
text_visualizations = {
|
text_visualizations = {
|
||||||
"dependency_tree": dep_out,
|
"dependency_tree": dep_out,
|
||||||
"entities": ent_out,
|
"entities": ent_out,
|
||||||
@ -199,7 +195,7 @@ class FlyteCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
|
|||||||
|
|
||||||
complexity_metrics: Dict[str, float] = generation_resp.pop(
|
complexity_metrics: Dict[str, float] = generation_resp.pop(
|
||||||
"text_complexity_metrics"
|
"text_complexity_metrics"
|
||||||
) # type: ignore # noqa: E501
|
)
|
||||||
self.deck.append(
|
self.deck.append(
|
||||||
self.markdown_renderer().to_html("#### Text Complexity Metrics")
|
self.markdown_renderer().to_html("#### Text Complexity Metrics")
|
||||||
)
|
)
|
||||||
|
@ -4,7 +4,7 @@ from typing import Any, Dict, List, Optional, cast
|
|||||||
from langchain_core.agents import AgentAction, AgentFinish
|
from langchain_core.agents import AgentAction, AgentFinish
|
||||||
from langchain_core.callbacks import BaseCallbackHandler
|
from langchain_core.callbacks import BaseCallbackHandler
|
||||||
from langchain_core.messages import BaseMessage
|
from langchain_core.messages import BaseMessage
|
||||||
from langchain_core.outputs import LLMResult
|
from langchain_core.outputs import ChatGeneration, LLMResult
|
||||||
|
|
||||||
|
|
||||||
def import_infino() -> Any:
|
def import_infino() -> Any:
|
||||||
@ -146,7 +146,7 @@ class InfinoCallbackHandler(BaseCallbackHandler):
|
|||||||
# Track completion token usage (for openai chat models).
|
# Track completion token usage (for openai chat models).
|
||||||
if self.is_chat_openai_model:
|
if self.is_chat_openai_model:
|
||||||
messages = " ".join(
|
messages = " ".join(
|
||||||
generation.message.content # type: ignore[attr-defined]
|
cast(str, cast(ChatGeneration, generation).message.content)
|
||||||
for generation in generations
|
for generation in generations
|
||||||
)
|
)
|
||||||
completion_tokens = get_num_tokens(
|
completion_tokens = get_num_tokens(
|
||||||
|
@ -109,13 +109,9 @@ def analyze_text(
|
|||||||
spacy = import_spacy()
|
spacy = import_spacy()
|
||||||
doc = nlp(text)
|
doc = nlp(text)
|
||||||
|
|
||||||
dep_out = spacy.displacy.render( # type: ignore
|
dep_out = spacy.displacy.render(doc, style="dep", jupyter=False, page=True)
|
||||||
doc, style="dep", jupyter=False, page=True
|
|
||||||
)
|
|
||||||
|
|
||||||
ent_out = spacy.displacy.render( # type: ignore
|
ent_out = spacy.displacy.render(doc, style="ent", jupyter=False, page=True)
|
||||||
doc, style="ent", jupyter=False, page=True
|
|
||||||
)
|
|
||||||
|
|
||||||
text_visualizations = {
|
text_visualizations = {
|
||||||
"dependency_tree": dep_out,
|
"dependency_tree": dep_out,
|
||||||
@ -233,7 +229,7 @@ class MlflowLogger:
|
|||||||
data, os.path.join(self.dir, f"{filename}.json"), run_id=self.run_id
|
data, os.path.join(self.dir, f"{filename}.json"), run_id=self.run_id
|
||||||
)
|
)
|
||||||
|
|
||||||
def table(self, name: str, dataframe) -> None: # type: ignore
|
def table(self, name: str, dataframe: Any) -> None:
|
||||||
"""To log the input pandas dataframe as a html table"""
|
"""To log the input pandas dataframe as a html table"""
|
||||||
self.html(dataframe.to_html(), f"table_{name}")
|
self.html(dataframe.to_html(), f"table_{name}")
|
||||||
|
|
||||||
@ -411,7 +407,7 @@ class MlflowCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
|
|||||||
)
|
)
|
||||||
complexity_metrics: Dict[str, float] = generation_resp.pop(
|
complexity_metrics: Dict[str, float] = generation_resp.pop(
|
||||||
"text_complexity_metrics"
|
"text_complexity_metrics"
|
||||||
) # type: ignore # noqa: E501
|
)
|
||||||
self.mlflg.metrics(
|
self.mlflg.metrics(
|
||||||
complexity_metrics,
|
complexity_metrics,
|
||||||
step=self.metrics["step"],
|
step=self.metrics["step"],
|
||||||
@ -723,7 +719,7 @@ class MlflowCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
|
|||||||
)
|
)
|
||||||
return session_analysis_df
|
return session_analysis_df
|
||||||
|
|
||||||
def _contain_llm_records(self): # type: ignore[no-untyped-def]
|
def _contain_llm_records(self) -> bool:
|
||||||
return bool(self.records["on_llm_start_records"])
|
return bool(self.records["on_llm_start_records"])
|
||||||
|
|
||||||
def flush_tracker(self, langchain_asset: Any = None, finish: bool = False) -> None:
|
def flush_tracker(self, langchain_asset: Any = None, finish: bool = False) -> None:
|
||||||
|
@ -62,7 +62,7 @@ def StreamlitCallbackHandler(
|
|||||||
# guaranteed to support the same set of kwargs.
|
# guaranteed to support the same set of kwargs.
|
||||||
try:
|
try:
|
||||||
from streamlit.external.langchain import (
|
from streamlit.external.langchain import (
|
||||||
StreamlitCallbackHandler as OfficialStreamlitCallbackHandler, # type: ignore # noqa: 501
|
StreamlitCallbackHandler as OfficialStreamlitCallbackHandler,
|
||||||
)
|
)
|
||||||
|
|
||||||
return OfficialStreamlitCallbackHandler(
|
return OfficialStreamlitCallbackHandler(
|
||||||
|
@ -108,7 +108,7 @@ class MutableExpander:
|
|||||||
) -> int:
|
) -> int:
|
||||||
"""Add a Markdown element to the container and return its index."""
|
"""Add a Markdown element to the container and return its index."""
|
||||||
kwargs = {"body": body, "unsafe_allow_html": unsafe_allow_html, "help": help}
|
kwargs = {"body": body, "unsafe_allow_html": unsafe_allow_html, "help": help}
|
||||||
new_dg = self._get_dg(index).markdown(**kwargs) # type: ignore[arg-type]
|
new_dg = self._get_dg(index).markdown(**kwargs)
|
||||||
record = ChildRecord(ChildType.MARKDOWN, kwargs, new_dg)
|
record = ChildRecord(ChildType.MARKDOWN, kwargs, new_dg)
|
||||||
return self._add_record(record, index)
|
return self._add_record(record, index)
|
||||||
|
|
||||||
|
@ -489,11 +489,10 @@ class WandbTracer(BaseTracer):
|
|||||||
If not, will start a new run with the provided run_args.
|
If not, will start a new run with the provided run_args.
|
||||||
"""
|
"""
|
||||||
if self._wandb.run is None:
|
if self._wandb.run is None:
|
||||||
run_args = self._run_args or {} # type: ignore
|
run_args: Dict = {**(self._run_args or {})}
|
||||||
run_args: dict = {**run_args} # type: ignore
|
|
||||||
|
|
||||||
if "settings" not in run_args: # type: ignore
|
if "settings" not in run_args:
|
||||||
run_args["settings"] = {"silent": True} # type: ignore
|
run_args["settings"] = {"silent": True}
|
||||||
|
|
||||||
self._wandb.init(**run_args)
|
self._wandb.init(**run_args)
|
||||||
if self._wandb.run is not None:
|
if self._wandb.run is not None:
|
||||||
|
@ -92,15 +92,11 @@ def analyze_text(
|
|||||||
if visualize and nlp and output_dir is not None:
|
if visualize and nlp and output_dir is not None:
|
||||||
doc = nlp(text)
|
doc = nlp(text)
|
||||||
|
|
||||||
dep_out = spacy.displacy.render( # type: ignore
|
dep_out = spacy.displacy.render(doc, style="dep", jupyter=False, page=True)
|
||||||
doc, style="dep", jupyter=False, page=True
|
|
||||||
)
|
|
||||||
dep_output_path = Path(output_dir, hash_string(f"dep-{text}") + ".html")
|
dep_output_path = Path(output_dir, hash_string(f"dep-{text}") + ".html")
|
||||||
dep_output_path.open("w", encoding="utf-8").write(dep_out)
|
dep_output_path.open("w", encoding="utf-8").write(dep_out)
|
||||||
|
|
||||||
ent_out = spacy.displacy.render( # type: ignore
|
ent_out = spacy.displacy.render(doc, style="ent", jupyter=False, page=True)
|
||||||
doc, style="ent", jupyter=False, page=True
|
|
||||||
)
|
|
||||||
ent_output_path = Path(output_dir, hash_string(f"ent-{text}") + ".html")
|
ent_output_path = Path(output_dir, hash_string(f"ent-{text}") + ".html")
|
||||||
ent_output_path.open("w", encoding="utf-8").write(ent_out)
|
ent_output_path.open("w", encoding="utf-8").write(ent_out)
|
||||||
|
|
||||||
@ -193,7 +189,7 @@ class WandbCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
|
|||||||
self.stream_logs = stream_logs
|
self.stream_logs = stream_logs
|
||||||
|
|
||||||
self.temp_dir = tempfile.TemporaryDirectory()
|
self.temp_dir = tempfile.TemporaryDirectory()
|
||||||
self.run: wandb.sdk.wandb_run.Run = wandb.init( # type: ignore
|
self.run = wandb.init(
|
||||||
job_type=self.job_type,
|
job_type=self.job_type,
|
||||||
project=self.project,
|
project=self.project,
|
||||||
entity=self.entity,
|
entity=self.entity,
|
||||||
|
Loading…
Reference in New Issue
Block a user