mirror of https://github.com/hwchase17/langchain.git
synced 2025-06-24 23:54:14 +00:00
parent 75b6fa1134
commit af5ae24af2
@@ -416,15 +416,25 @@ class AimCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
         self._run.close()
         self.reset_callback_meta()
         if reset:
-            self.__init__(  # type: ignore
-                repo=repo if repo else self.repo,
-                experiment_name=experiment_name
-                if experiment_name
-                else self.experiment_name,
-                system_tracking_interval=system_tracking_interval
-                if system_tracking_interval
-                else self.system_tracking_interval,
-                log_system_params=log_system_params
-                if log_system_params
-                else self.log_system_params,
-            )
+            aim = import_aim()
+            self.repo = repo if repo else self.repo
+            self.experiment_name = (
+                experiment_name if experiment_name else self.experiment_name
+            )
+            self.system_tracking_interval = (
+                system_tracking_interval
+                if system_tracking_interval
+                else self.system_tracking_interval
+            )
+            self.log_system_params = (
+                log_system_params if log_system_params else self.log_system_params
+            )
+
+            self._run = aim.Run(
+                repo=self.repo,
+                experiment=self.experiment_name,
+                system_tracking_interval=self.system_tracking_interval,
+                log_system_params=self.log_system_params,
+            )
+            self._run_hash = self._run.hash
+            self.action_records = []
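Note: a minimal, hypothetical sketch of the pattern adopted in the hunk above — instead of re-invoking self.__init__ (which mypy flags), the handler now reassigns its attributes and rebuilds the tracked run explicitly. The Tracker class and its fields below are illustrative only, not the actual AimCallbackHandler.

    from typing import List, Optional


    class Tracker:
        def __init__(self, repo: Optional[str] = None) -> None:
            self.repo = repo
            self.records: List = []

        def flush(self, repo: Optional[str] = None, reset: bool = True) -> None:
            if reset:
                # Reassign fields instead of calling self.__init__ directly.
                self.repo = repo if repo else self.repo
                self.records = []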
@@ -1,6 +1,6 @@
 import os
 import warnings
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, cast

 from langchain_core.agents import AgentAction, AgentFinish
 from langchain_core.callbacks import BaseCallbackHandler
@@ -269,8 +269,8 @@ class ArgillaCallbackHandler(BaseCallbackHandler):
             for key in [str(kwargs["parent_run_id"]), str(kwargs["run_id"])]
         ):
             return
-        prompts = self.prompts.get(str(kwargs["parent_run_id"])) or self.prompts.get(
-            str(kwargs["run_id"])
+        prompts: List = self.prompts.get(str(kwargs["parent_run_id"])) or cast(
+            List, self.prompts.get(str(kwargs["run_id"]), [])
         )
         for chain_output_key, chain_output_val in outputs.items():
             if isinstance(chain_output_val, list):
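Note: a small standalone sketch of the cast() pattern used above; the store name and run ids are hypothetical. dict.get returns an Optional value, so the fallback supplies a default list and cast() narrows the type for mypy with no runtime effect.

    from typing import Dict, List, cast

    prompts_store: Dict[str, List[str]] = {"parent-1": ["Hello"]}


    def get_prompts(parent_run_id: str, run_id: str) -> List:
        # Prefer the parent run's prompts; fall back to the run's own entry.
        return prompts_store.get(parent_run_id) or cast(
            List, prompts_store.get(run_id, [])
        )


    print(get_prompts("parent-1", "run-1"))  # ['Hello']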
@@ -283,10 +283,7 @@ class ArgillaCallbackHandler(BaseCallbackHandler):
                                 "response": output["text"].strip(),
                             },
                         }
-                        for prompt, output in zip(
-                            prompts,  # type: ignore
-                            chain_output_val,
-                        )
+                        for prompt, output in zip(prompts, chain_output_val)
                     ]
                 )
             else:
@@ -295,7 +292,7 @@ class ArgillaCallbackHandler(BaseCallbackHandler):
                     records=[
                         {
                             "fields": {
-                                "prompt": " ".join(prompts),  # type: ignore
+                                "prompt": " ".join(prompts),
                                 "response": chain_output_val.strip(),
                             },
                         }
@@ -162,7 +162,7 @@ class ArthurCallbackHandler(BaseCallbackHandler):
     def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
         """On LLM end, send data to Arthur."""
         try:
-            import pytz  # type: ignore[import]
+            import pytz
         except ImportError as e:
             raise ImportError(
                 "Could not import pytz. Please install it with 'pip install pytz'."
@@ -83,7 +83,7 @@ class ClearMLCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
         if clearml.Task.current_task():
             self.task = clearml.Task.current_task()
         else:
-            self.task = clearml.Task.init(  # type: ignore
+            self.task = clearml.Task.init(
                 task_type=self.task_type,
                 project_name=self.project_name,
                 tags=self.tags,
@@ -361,17 +361,13 @@ class ClearMLCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
         if self.visualize and self.nlp and self.temp_dir.name is not None:
             doc = self.nlp(text)

-            dep_out = spacy.displacy.render(  # type: ignore
-                doc, style="dep", jupyter=False, page=True
-            )
+            dep_out = spacy.displacy.render(doc, style="dep", jupyter=False, page=True)
             dep_output_path = Path(
                 self.temp_dir.name, hash_string(f"dep-{text}") + ".html"
             )
             dep_output_path.open("w", encoding="utf-8").write(dep_out)

-            ent_out = spacy.displacy.render(  # type: ignore
-                doc, style="ent", jupyter=False, page=True
-            )
+            ent_out = spacy.displacy.render(doc, style="ent", jupyter=False, page=True)
             ent_output_path = Path(
                 self.temp_dir.name, hash_string(f"ent-{text}") + ".html"
             )
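Note: a minimal sketch of the displacy call that is collapsed onto one line in the hunks above and below. "en_core_web_sm" is an assumed example model that must be installed separately; the output path is illustrative.

    from pathlib import Path

    import spacy

    nlp = spacy.load("en_core_web_sm")
    doc = nlp("Callbacks can render dependency trees as HTML.")

    # With jupyter=False and page=True, render() returns a complete HTML page string.
    dep_html = spacy.displacy.render(doc, style="dep", jupyter=False, page=True)
    Path("dep.html").write_text(dep_html, encoding="utf-8")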
@@ -37,7 +37,7 @@ def _get_experiment(
 ) -> Any:
     comet_ml = import_comet_ml()

-    experiment = comet_ml.Experiment(  # type: ignore
+    experiment = comet_ml.Experiment(
         workspace=workspace,
         project_name=project_name,
     )
@@ -79,12 +79,8 @@ def analyze_text(
     if nlp is not None:
         spacy = import_spacy()
         doc = nlp(text)
-        dep_out = spacy.displacy.render(  # type: ignore
-            doc, style="dep", jupyter=False, page=True
-        )
-        ent_out = spacy.displacy.render(  # type: ignore
-            doc, style="ent", jupyter=False, page=True
-        )
+        dep_out = spacy.displacy.render(doc, style="dep", jupyter=False, page=True)
+        ent_out = spacy.displacy.render(doc, style="ent", jupyter=False, page=True)
         text_visualizations = {
             "dependency_tree": dep_out,
             "entities": ent_out,
@@ -199,7 +195,7 @@ class FlyteCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):

                 complexity_metrics: Dict[str, float] = generation_resp.pop(
                     "text_complexity_metrics"
-                )  # type: ignore  # noqa: E501
+                )
                 self.deck.append(
                     self.markdown_renderer().to_html("#### Text Complexity Metrics")
                 )
@@ -4,7 +4,7 @@ from typing import Any, Dict, List, Optional, cast
 from langchain_core.agents import AgentAction, AgentFinish
 from langchain_core.callbacks import BaseCallbackHandler
 from langchain_core.messages import BaseMessage
-from langchain_core.outputs import LLMResult
+from langchain_core.outputs import ChatGeneration, LLMResult


 def import_infino() -> Any:
@@ -146,7 +146,7 @@ class InfinoCallbackHandler(BaseCallbackHandler):
         # Track completion token usage (for openai chat models).
         if self.is_chat_openai_model:
             messages = " ".join(
-                generation.message.content  # type: ignore[attr-defined]
+                cast(str, cast(ChatGeneration, generation).message.content)
                 for generation in generations
             )
             completion_tokens = get_num_tokens(
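Note: an illustrative helper showing the double cast() used above; join_chat_contents is a hypothetical name. Generation has no message attribute, so the inner cast narrows to ChatGeneration, and message.content is typed as a string or list, so the outer cast narrows it to str for mypy.

    from typing import List, cast

    from langchain_core.outputs import ChatGeneration, Generation


    def join_chat_contents(generations: List[Generation]) -> str:
        # Assumes the caller only passes chat generations; cast() has no runtime cost.
        return " ".join(
            cast(str, cast(ChatGeneration, g).message.content) for g in generations
        )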
@@ -109,13 +109,9 @@ def analyze_text(
         spacy = import_spacy()
         doc = nlp(text)

-        dep_out = spacy.displacy.render(  # type: ignore
-            doc, style="dep", jupyter=False, page=True
-        )
+        dep_out = spacy.displacy.render(doc, style="dep", jupyter=False, page=True)

-        ent_out = spacy.displacy.render(  # type: ignore
-            doc, style="ent", jupyter=False, page=True
-        )
+        ent_out = spacy.displacy.render(doc, style="ent", jupyter=False, page=True)

         text_visualizations = {
             "dependency_tree": dep_out,
@@ -233,7 +229,7 @@ class MlflowLogger:
             data, os.path.join(self.dir, f"{filename}.json"), run_id=self.run_id
         )

-    def table(self, name: str, dataframe) -> None:  # type: ignore
+    def table(self, name: str, dataframe: Any) -> None:
        """To log the input pandas dataframe as a html table"""
        self.html(dataframe.to_html(), f"table_{name}")

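Note: a hypothetical standalone version of the table() helper above. Annotating the dataframe parameter as Any satisfies mypy's untyped-def check without importing pandas solely for the type hint; dataframe_to_html is an illustrative name, not part of the library.

    from typing import Any


    def dataframe_to_html(name: str, dataframe: Any) -> str:
        # Works with any object exposing a pandas-style to_html() method.
        return f"<!-- table_{name} -->\n" + dataframe.to_html()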
@@ -411,7 +407,7 @@ class MlflowCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
                 )
                 complexity_metrics: Dict[str, float] = generation_resp.pop(
                     "text_complexity_metrics"
-                )  # type: ignore  # noqa: E501
+                )
                 self.mlflg.metrics(
                     complexity_metrics,
                     step=self.metrics["step"],
@@ -723,7 +719,7 @@ class MlflowCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
         )
         return session_analysis_df

-    def _contain_llm_records(self):  # type: ignore[no-untyped-def]
+    def _contain_llm_records(self) -> bool:
        return bool(self.records["on_llm_start_records"])

    def flush_tracker(self, langchain_asset: Any = None, finish: bool = False) -> None:
@@ -62,7 +62,7 @@ def StreamlitCallbackHandler(
     # guaranteed to support the same set of kwargs.
     try:
         from streamlit.external.langchain import (
-            StreamlitCallbackHandler as OfficialStreamlitCallbackHandler,  # type: ignore # noqa: 501
+            StreamlitCallbackHandler as OfficialStreamlitCallbackHandler,
         )

        return OfficialStreamlitCallbackHandler(
@@ -108,7 +108,7 @@ class MutableExpander:
    ) -> int:
        """Add a Markdown element to the container and return its index."""
        kwargs = {"body": body, "unsafe_allow_html": unsafe_allow_html, "help": help}
-        new_dg = self._get_dg(index).markdown(**kwargs)  # type: ignore[arg-type]
+        new_dg = self._get_dg(index).markdown(**kwargs)
        record = ChildRecord(ChildType.MARKDOWN, kwargs, new_dg)
        return self._add_record(record, index)

@@ -489,11 +489,10 @@ class WandbTracer(BaseTracer):
        If not, will start a new run with the provided run_args.
        """
        if self._wandb.run is None:
-            run_args = self._run_args or {}  # type: ignore
-            run_args: dict = {**run_args}  # type: ignore
+            run_args: Dict = {**(self._run_args or {})}

-            if "settings" not in run_args:  # type: ignore
-                run_args["settings"] = {"silent": True}  # type: ignore
+            if "settings" not in run_args:
+                run_args["settings"] = {"silent": True}

            self._wandb.init(**run_args)
            if self._wandb.run is not None:
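Note: a minimal sketch of the run_args handling above; build_run_args is a hypothetical helper, not part of the tracer. mypy rejects annotating the same variable name twice, which is why the old two-step assignment needed "# type: ignore"; merging into one annotated assignment avoids that.

    from typing import Dict, Optional


    def build_run_args(stored: Optional[Dict] = None) -> Dict:
        # Single annotated assignment: copy stored args (or start empty) in one step.
        run_args: Dict = {**(stored or {})}
        if "settings" not in run_args:
            run_args["settings"] = {"silent": True}
        return run_args


    print(build_run_args())  # {'settings': {'silent': True}}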
@@ -92,15 +92,11 @@ def analyze_text(
    if visualize and nlp and output_dir is not None:
        doc = nlp(text)

-        dep_out = spacy.displacy.render(  # type: ignore
-            doc, style="dep", jupyter=False, page=True
-        )
+        dep_out = spacy.displacy.render(doc, style="dep", jupyter=False, page=True)
        dep_output_path = Path(output_dir, hash_string(f"dep-{text}") + ".html")
        dep_output_path.open("w", encoding="utf-8").write(dep_out)

-        ent_out = spacy.displacy.render(  # type: ignore
-            doc, style="ent", jupyter=False, page=True
-        )
+        ent_out = spacy.displacy.render(doc, style="ent", jupyter=False, page=True)
        ent_output_path = Path(output_dir, hash_string(f"ent-{text}") + ".html")
        ent_output_path.open("w", encoding="utf-8").write(ent_out)

@@ -193,7 +189,7 @@ class WandbCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
        self.stream_logs = stream_logs

        self.temp_dir = tempfile.TemporaryDirectory()
-        self.run: wandb.sdk.wandb_run.Run = wandb.init(  # type: ignore
+        self.run = wandb.init(
            job_type=self.job_type,
            project=self.project,
            entity=self.entity,