community[patch]: callbacks mypy fixes (#17058)

Related to #17048
This commit is contained in:
Bagatur 2024-02-05 12:37:27 -08:00 committed by GitHub
parent 75b6fa1134
commit af5ae24af2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 49 additions and 59 deletions

View File

@ -416,15 +416,25 @@ class AimCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
self._run.close() self._run.close()
self.reset_callback_meta() self.reset_callback_meta()
if reset: if reset:
self.__init__( # type: ignore aim = import_aim()
repo=repo if repo else self.repo, self.repo = repo if repo else self.repo
experiment_name=experiment_name self.experiment_name = (
if experiment_name experiment_name if experiment_name else self.experiment_name
else self.experiment_name,
system_tracking_interval=system_tracking_interval
if system_tracking_interval
else self.system_tracking_interval,
log_system_params=log_system_params
if log_system_params
else self.log_system_params,
) )
self.system_tracking_interval = (
system_tracking_interval
if system_tracking_interval
else self.system_tracking_interval
)
self.log_system_params = (
log_system_params if log_system_params else self.log_system_params
)
self._run = aim.Run(
repo=self.repo,
experiment=self.experiment_name,
system_tracking_interval=self.system_tracking_interval,
log_system_params=self.log_system_params,
)
self._run_hash = self._run.hash
self.action_records = []

View File

@ -1,6 +1,6 @@
import os import os
import warnings import warnings
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional, cast
from langchain_core.agents import AgentAction, AgentFinish from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.callbacks import BaseCallbackHandler from langchain_core.callbacks import BaseCallbackHandler
@ -269,8 +269,8 @@ class ArgillaCallbackHandler(BaseCallbackHandler):
for key in [str(kwargs["parent_run_id"]), str(kwargs["run_id"])] for key in [str(kwargs["parent_run_id"]), str(kwargs["run_id"])]
): ):
return return
prompts = self.prompts.get(str(kwargs["parent_run_id"])) or self.prompts.get( prompts: List = self.prompts.get(str(kwargs["parent_run_id"])) or cast(
str(kwargs["run_id"]) List, self.prompts.get(str(kwargs["run_id"]), [])
) )
for chain_output_key, chain_output_val in outputs.items(): for chain_output_key, chain_output_val in outputs.items():
if isinstance(chain_output_val, list): if isinstance(chain_output_val, list):
@ -283,10 +283,7 @@ class ArgillaCallbackHandler(BaseCallbackHandler):
"response": output["text"].strip(), "response": output["text"].strip(),
}, },
} }
for prompt, output in zip( for prompt, output in zip(prompts, chain_output_val)
prompts, # type: ignore
chain_output_val,
)
] ]
) )
else: else:
@ -295,7 +292,7 @@ class ArgillaCallbackHandler(BaseCallbackHandler):
records=[ records=[
{ {
"fields": { "fields": {
"prompt": " ".join(prompts), # type: ignore "prompt": " ".join(prompts),
"response": chain_output_val.strip(), "response": chain_output_val.strip(),
}, },
} }

View File

@ -162,7 +162,7 @@ class ArthurCallbackHandler(BaseCallbackHandler):
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
"""On LLM end, send data to Arthur.""" """On LLM end, send data to Arthur."""
try: try:
import pytz # type: ignore[import] import pytz
except ImportError as e: except ImportError as e:
raise ImportError( raise ImportError(
"Could not import pytz. Please install it with 'pip install pytz'." "Could not import pytz. Please install it with 'pip install pytz'."

View File

@ -83,7 +83,7 @@ class ClearMLCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
if clearml.Task.current_task(): if clearml.Task.current_task():
self.task = clearml.Task.current_task() self.task = clearml.Task.current_task()
else: else:
self.task = clearml.Task.init( # type: ignore self.task = clearml.Task.init(
task_type=self.task_type, task_type=self.task_type,
project_name=self.project_name, project_name=self.project_name,
tags=self.tags, tags=self.tags,
@ -361,17 +361,13 @@ class ClearMLCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
if self.visualize and self.nlp and self.temp_dir.name is not None: if self.visualize and self.nlp and self.temp_dir.name is not None:
doc = self.nlp(text) doc = self.nlp(text)
dep_out = spacy.displacy.render( # type: ignore dep_out = spacy.displacy.render(doc, style="dep", jupyter=False, page=True)
doc, style="dep", jupyter=False, page=True
)
dep_output_path = Path( dep_output_path = Path(
self.temp_dir.name, hash_string(f"dep-{text}") + ".html" self.temp_dir.name, hash_string(f"dep-{text}") + ".html"
) )
dep_output_path.open("w", encoding="utf-8").write(dep_out) dep_output_path.open("w", encoding="utf-8").write(dep_out)
ent_out = spacy.displacy.render( # type: ignore ent_out = spacy.displacy.render(doc, style="ent", jupyter=False, page=True)
doc, style="ent", jupyter=False, page=True
)
ent_output_path = Path( ent_output_path = Path(
self.temp_dir.name, hash_string(f"ent-{text}") + ".html" self.temp_dir.name, hash_string(f"ent-{text}") + ".html"
) )

View File

@ -37,7 +37,7 @@ def _get_experiment(
) -> Any: ) -> Any:
comet_ml = import_comet_ml() comet_ml = import_comet_ml()
experiment = comet_ml.Experiment( # type: ignore experiment = comet_ml.Experiment(
workspace=workspace, workspace=workspace,
project_name=project_name, project_name=project_name,
) )

View File

@ -79,12 +79,8 @@ def analyze_text(
if nlp is not None: if nlp is not None:
spacy = import_spacy() spacy = import_spacy()
doc = nlp(text) doc = nlp(text)
dep_out = spacy.displacy.render( # type: ignore dep_out = spacy.displacy.render(doc, style="dep", jupyter=False, page=True)
doc, style="dep", jupyter=False, page=True ent_out = spacy.displacy.render(doc, style="ent", jupyter=False, page=True)
)
ent_out = spacy.displacy.render( # type: ignore
doc, style="ent", jupyter=False, page=True
)
text_visualizations = { text_visualizations = {
"dependency_tree": dep_out, "dependency_tree": dep_out,
"entities": ent_out, "entities": ent_out,
@ -199,7 +195,7 @@ class FlyteCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
complexity_metrics: Dict[str, float] = generation_resp.pop( complexity_metrics: Dict[str, float] = generation_resp.pop(
"text_complexity_metrics" "text_complexity_metrics"
) # type: ignore # noqa: E501 )
self.deck.append( self.deck.append(
self.markdown_renderer().to_html("#### Text Complexity Metrics") self.markdown_renderer().to_html("#### Text Complexity Metrics")
) )

View File

@ -4,7 +4,7 @@ from typing import Any, Dict, List, Optional, cast
from langchain_core.agents import AgentAction, AgentFinish from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.callbacks import BaseCallbackHandler from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.messages import BaseMessage from langchain_core.messages import BaseMessage
from langchain_core.outputs import LLMResult from langchain_core.outputs import ChatGeneration, LLMResult
def import_infino() -> Any: def import_infino() -> Any:
@ -146,7 +146,7 @@ class InfinoCallbackHandler(BaseCallbackHandler):
# Track completion token usage (for openai chat models). # Track completion token usage (for openai chat models).
if self.is_chat_openai_model: if self.is_chat_openai_model:
messages = " ".join( messages = " ".join(
generation.message.content # type: ignore[attr-defined] cast(str, cast(ChatGeneration, generation).message.content)
for generation in generations for generation in generations
) )
completion_tokens = get_num_tokens( completion_tokens = get_num_tokens(

View File

@ -109,13 +109,9 @@ def analyze_text(
spacy = import_spacy() spacy = import_spacy()
doc = nlp(text) doc = nlp(text)
dep_out = spacy.displacy.render( # type: ignore dep_out = spacy.displacy.render(doc, style="dep", jupyter=False, page=True)
doc, style="dep", jupyter=False, page=True
)
ent_out = spacy.displacy.render( # type: ignore ent_out = spacy.displacy.render(doc, style="ent", jupyter=False, page=True)
doc, style="ent", jupyter=False, page=True
)
text_visualizations = { text_visualizations = {
"dependency_tree": dep_out, "dependency_tree": dep_out,
@ -233,7 +229,7 @@ class MlflowLogger:
data, os.path.join(self.dir, f"{filename}.json"), run_id=self.run_id data, os.path.join(self.dir, f"{filename}.json"), run_id=self.run_id
) )
def table(self, name: str, dataframe) -> None: # type: ignore def table(self, name: str, dataframe: Any) -> None:
"""To log the input pandas dataframe as a html table""" """To log the input pandas dataframe as a html table"""
self.html(dataframe.to_html(), f"table_{name}") self.html(dataframe.to_html(), f"table_{name}")
@ -411,7 +407,7 @@ class MlflowCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
) )
complexity_metrics: Dict[str, float] = generation_resp.pop( complexity_metrics: Dict[str, float] = generation_resp.pop(
"text_complexity_metrics" "text_complexity_metrics"
) # type: ignore # noqa: E501 )
self.mlflg.metrics( self.mlflg.metrics(
complexity_metrics, complexity_metrics,
step=self.metrics["step"], step=self.metrics["step"],
@ -723,7 +719,7 @@ class MlflowCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
) )
return session_analysis_df return session_analysis_df
def _contain_llm_records(self): # type: ignore[no-untyped-def] def _contain_llm_records(self) -> bool:
return bool(self.records["on_llm_start_records"]) return bool(self.records["on_llm_start_records"])
def flush_tracker(self, langchain_asset: Any = None, finish: bool = False) -> None: def flush_tracker(self, langchain_asset: Any = None, finish: bool = False) -> None:

View File

@ -62,7 +62,7 @@ def StreamlitCallbackHandler(
# guaranteed to support the same set of kwargs. # guaranteed to support the same set of kwargs.
try: try:
from streamlit.external.langchain import ( from streamlit.external.langchain import (
StreamlitCallbackHandler as OfficialStreamlitCallbackHandler, # type: ignore # noqa: 501 StreamlitCallbackHandler as OfficialStreamlitCallbackHandler,
) )
return OfficialStreamlitCallbackHandler( return OfficialStreamlitCallbackHandler(

View File

@ -108,7 +108,7 @@ class MutableExpander:
) -> int: ) -> int:
"""Add a Markdown element to the container and return its index.""" """Add a Markdown element to the container and return its index."""
kwargs = {"body": body, "unsafe_allow_html": unsafe_allow_html, "help": help} kwargs = {"body": body, "unsafe_allow_html": unsafe_allow_html, "help": help}
new_dg = self._get_dg(index).markdown(**kwargs) # type: ignore[arg-type] new_dg = self._get_dg(index).markdown(**kwargs)
record = ChildRecord(ChildType.MARKDOWN, kwargs, new_dg) record = ChildRecord(ChildType.MARKDOWN, kwargs, new_dg)
return self._add_record(record, index) return self._add_record(record, index)

View File

@ -489,11 +489,10 @@ class WandbTracer(BaseTracer):
If not, will start a new run with the provided run_args. If not, will start a new run with the provided run_args.
""" """
if self._wandb.run is None: if self._wandb.run is None:
run_args = self._run_args or {} # type: ignore run_args: Dict = {**(self._run_args or {})}
run_args: dict = {**run_args} # type: ignore
if "settings" not in run_args: # type: ignore if "settings" not in run_args:
run_args["settings"] = {"silent": True} # type: ignore run_args["settings"] = {"silent": True}
self._wandb.init(**run_args) self._wandb.init(**run_args)
if self._wandb.run is not None: if self._wandb.run is not None:

View File

@ -92,15 +92,11 @@ def analyze_text(
if visualize and nlp and output_dir is not None: if visualize and nlp and output_dir is not None:
doc = nlp(text) doc = nlp(text)
dep_out = spacy.displacy.render( # type: ignore dep_out = spacy.displacy.render(doc, style="dep", jupyter=False, page=True)
doc, style="dep", jupyter=False, page=True
)
dep_output_path = Path(output_dir, hash_string(f"dep-{text}") + ".html") dep_output_path = Path(output_dir, hash_string(f"dep-{text}") + ".html")
dep_output_path.open("w", encoding="utf-8").write(dep_out) dep_output_path.open("w", encoding="utf-8").write(dep_out)
ent_out = spacy.displacy.render( # type: ignore ent_out = spacy.displacy.render(doc, style="ent", jupyter=False, page=True)
doc, style="ent", jupyter=False, page=True
)
ent_output_path = Path(output_dir, hash_string(f"ent-{text}") + ".html") ent_output_path = Path(output_dir, hash_string(f"ent-{text}") + ".html")
ent_output_path.open("w", encoding="utf-8").write(ent_out) ent_output_path.open("w", encoding="utf-8").write(ent_out)
@ -193,7 +189,7 @@ class WandbCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
self.stream_logs = stream_logs self.stream_logs = stream_logs
self.temp_dir = tempfile.TemporaryDirectory() self.temp_dir = tempfile.TemporaryDirectory()
self.run: wandb.sdk.wandb_run.Run = wandb.init( # type: ignore self.run = wandb.init(
job_type=self.job_type, job_type=self.job_type,
project=self.project, project=self.project,
entity=self.entity, entity=self.entity,