From af5ae24af2b32e962adf23d78e59ed505d17fff7 Mon Sep 17 00:00:00 2001 From: Bagatur <22008038+baskaryan@users.noreply.github.com> Date: Mon, 5 Feb 2024 12:37:27 -0800 Subject: [PATCH] community[patch]: callbacks mypy fixes (#17058) Related to #17048 --- .../callbacks/aim_callback.py | 32 ++++++++++++------- .../callbacks/argilla_callback.py | 13 +++----- .../callbacks/arthur_callback.py | 2 +- .../callbacks/clearml_callback.py | 10 ++---- .../callbacks/comet_ml_callback.py | 2 +- .../callbacks/flyte_callback.py | 10 ++---- .../callbacks/infino_callback.py | 4 +-- .../callbacks/mlflow_callback.py | 14 +++----- .../callbacks/streamlit/__init__.py | 2 +- .../callbacks/streamlit/mutable_expander.py | 2 +- .../callbacks/tracers/wandb.py | 7 ++-- .../callbacks/wandb_callback.py | 10 ++---- 12 files changed, 49 insertions(+), 59 deletions(-) diff --git a/libs/community/langchain_community/callbacks/aim_callback.py b/libs/community/langchain_community/callbacks/aim_callback.py index da194b34b3d..e36cfcceafe 100644 --- a/libs/community/langchain_community/callbacks/aim_callback.py +++ b/libs/community/langchain_community/callbacks/aim_callback.py @@ -416,15 +416,25 @@ class AimCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler): self._run.close() self.reset_callback_meta() if reset: - self.__init__( # type: ignore - repo=repo if repo else self.repo, - experiment_name=experiment_name - if experiment_name - else self.experiment_name, - system_tracking_interval=system_tracking_interval - if system_tracking_interval - else self.system_tracking_interval, - log_system_params=log_system_params - if log_system_params - else self.log_system_params, + aim = import_aim() + self.repo = repo if repo else self.repo + self.experiment_name = ( + experiment_name if experiment_name else self.experiment_name ) + self.system_tracking_interval = ( + system_tracking_interval + if system_tracking_interval + else self.system_tracking_interval + ) + self.log_system_params = ( + log_system_params if log_system_params else self.log_system_params + ) + + self._run = aim.Run( + repo=self.repo, + experiment=self.experiment_name, + system_tracking_interval=self.system_tracking_interval, + log_system_params=self.log_system_params, + ) + self._run_hash = self._run.hash + self.action_records = [] diff --git a/libs/community/langchain_community/callbacks/argilla_callback.py b/libs/community/langchain_community/callbacks/argilla_callback.py index 942f9858b51..157075a2832 100644 --- a/libs/community/langchain_community/callbacks/argilla_callback.py +++ b/libs/community/langchain_community/callbacks/argilla_callback.py @@ -1,6 +1,6 @@ import os import warnings -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, cast from langchain_core.agents import AgentAction, AgentFinish from langchain_core.callbacks import BaseCallbackHandler @@ -269,8 +269,8 @@ class ArgillaCallbackHandler(BaseCallbackHandler): for key in [str(kwargs["parent_run_id"]), str(kwargs["run_id"])] ): return - prompts = self.prompts.get(str(kwargs["parent_run_id"])) or self.prompts.get( - str(kwargs["run_id"]) + prompts: List = self.prompts.get(str(kwargs["parent_run_id"])) or cast( + List, self.prompts.get(str(kwargs["run_id"]), []) ) for chain_output_key, chain_output_val in outputs.items(): if isinstance(chain_output_val, list): @@ -283,10 +283,7 @@ class ArgillaCallbackHandler(BaseCallbackHandler): "response": output["text"].strip(), }, } - for prompt, output in zip( - prompts, # type: ignore - chain_output_val, - ) + for prompt, output in zip(prompts, chain_output_val) ] ) else: @@ -295,7 +292,7 @@ class ArgillaCallbackHandler(BaseCallbackHandler): records=[ { "fields": { - "prompt": " ".join(prompts), # type: ignore + "prompt": " ".join(prompts), "response": chain_output_val.strip(), }, } diff --git a/libs/community/langchain_community/callbacks/arthur_callback.py b/libs/community/langchain_community/callbacks/arthur_callback.py index a5fce582ed1..b0b28328ab2 100644 --- a/libs/community/langchain_community/callbacks/arthur_callback.py +++ b/libs/community/langchain_community/callbacks/arthur_callback.py @@ -162,7 +162,7 @@ class ArthurCallbackHandler(BaseCallbackHandler): def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: """On LLM end, send data to Arthur.""" try: - import pytz # type: ignore[import] + import pytz except ImportError as e: raise ImportError( "Could not import pytz. Please install it with 'pip install pytz'." diff --git a/libs/community/langchain_community/callbacks/clearml_callback.py b/libs/community/langchain_community/callbacks/clearml_callback.py index 71f30fccdd8..7c4b66e6146 100644 --- a/libs/community/langchain_community/callbacks/clearml_callback.py +++ b/libs/community/langchain_community/callbacks/clearml_callback.py @@ -83,7 +83,7 @@ class ClearMLCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler): if clearml.Task.current_task(): self.task = clearml.Task.current_task() else: - self.task = clearml.Task.init( # type: ignore + self.task = clearml.Task.init( task_type=self.task_type, project_name=self.project_name, tags=self.tags, @@ -361,17 +361,13 @@ class ClearMLCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler): if self.visualize and self.nlp and self.temp_dir.name is not None: doc = self.nlp(text) - dep_out = spacy.displacy.render( # type: ignore - doc, style="dep", jupyter=False, page=True - ) + dep_out = spacy.displacy.render(doc, style="dep", jupyter=False, page=True) dep_output_path = Path( self.temp_dir.name, hash_string(f"dep-{text}") + ".html" ) dep_output_path.open("w", encoding="utf-8").write(dep_out) - ent_out = spacy.displacy.render( # type: ignore - doc, style="ent", jupyter=False, page=True - ) + ent_out = spacy.displacy.render(doc, style="ent", jupyter=False, page=True) ent_output_path = Path( self.temp_dir.name, hash_string(f"ent-{text}") + ".html" ) diff --git a/libs/community/langchain_community/callbacks/comet_ml_callback.py b/libs/community/langchain_community/callbacks/comet_ml_callback.py index 5493c947ae8..499e93a07b0 100644 --- a/libs/community/langchain_community/callbacks/comet_ml_callback.py +++ b/libs/community/langchain_community/callbacks/comet_ml_callback.py @@ -37,7 +37,7 @@ def _get_experiment( ) -> Any: comet_ml = import_comet_ml() - experiment = comet_ml.Experiment( # type: ignore + experiment = comet_ml.Experiment( workspace=workspace, project_name=project_name, ) diff --git a/libs/community/langchain_community/callbacks/flyte_callback.py b/libs/community/langchain_community/callbacks/flyte_callback.py index 23a8f473430..9fd953443f2 100644 --- a/libs/community/langchain_community/callbacks/flyte_callback.py +++ b/libs/community/langchain_community/callbacks/flyte_callback.py @@ -79,12 +79,8 @@ def analyze_text( if nlp is not None: spacy = import_spacy() doc = nlp(text) - dep_out = spacy.displacy.render( # type: ignore - doc, style="dep", jupyter=False, page=True - ) - ent_out = spacy.displacy.render( # type: ignore - doc, style="ent", jupyter=False, page=True - ) + dep_out = spacy.displacy.render(doc, style="dep", jupyter=False, page=True) + ent_out = spacy.displacy.render(doc, style="ent", jupyter=False, page=True) text_visualizations = { "dependency_tree": dep_out, "entities": ent_out, @@ -199,7 +195,7 @@ class FlyteCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler): complexity_metrics: Dict[str, float] = generation_resp.pop( "text_complexity_metrics" - ) # type: ignore # noqa: E501 + ) self.deck.append( self.markdown_renderer().to_html("#### Text Complexity Metrics") ) diff --git a/libs/community/langchain_community/callbacks/infino_callback.py b/libs/community/langchain_community/callbacks/infino_callback.py index 57d756948ef..106a9a5a083 100644 --- a/libs/community/langchain_community/callbacks/infino_callback.py +++ b/libs/community/langchain_community/callbacks/infino_callback.py @@ -4,7 +4,7 @@ from typing import Any, Dict, List, Optional, cast from langchain_core.agents import AgentAction, AgentFinish from langchain_core.callbacks import BaseCallbackHandler from langchain_core.messages import BaseMessage -from langchain_core.outputs import LLMResult +from langchain_core.outputs import ChatGeneration, LLMResult def import_infino() -> Any: @@ -146,7 +146,7 @@ class InfinoCallbackHandler(BaseCallbackHandler): # Track completion token usage (for openai chat models). if self.is_chat_openai_model: messages = " ".join( - generation.message.content # type: ignore[attr-defined] + cast(str, cast(ChatGeneration, generation).message.content) for generation in generations ) completion_tokens = get_num_tokens( diff --git a/libs/community/langchain_community/callbacks/mlflow_callback.py b/libs/community/langchain_community/callbacks/mlflow_callback.py index ea862c23fd2..6bcfb3aa330 100644 --- a/libs/community/langchain_community/callbacks/mlflow_callback.py +++ b/libs/community/langchain_community/callbacks/mlflow_callback.py @@ -109,13 +109,9 @@ def analyze_text( spacy = import_spacy() doc = nlp(text) - dep_out = spacy.displacy.render( # type: ignore - doc, style="dep", jupyter=False, page=True - ) + dep_out = spacy.displacy.render(doc, style="dep", jupyter=False, page=True) - ent_out = spacy.displacy.render( # type: ignore - doc, style="ent", jupyter=False, page=True - ) + ent_out = spacy.displacy.render(doc, style="ent", jupyter=False, page=True) text_visualizations = { "dependency_tree": dep_out, @@ -233,7 +229,7 @@ class MlflowLogger: data, os.path.join(self.dir, f"{filename}.json"), run_id=self.run_id ) - def table(self, name: str, dataframe) -> None: # type: ignore + def table(self, name: str, dataframe: Any) -> None: """To log the input pandas dataframe as a html table""" self.html(dataframe.to_html(), f"table_{name}") @@ -411,7 +407,7 @@ class MlflowCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler): ) complexity_metrics: Dict[str, float] = generation_resp.pop( "text_complexity_metrics" - ) # type: ignore # noqa: E501 + ) self.mlflg.metrics( complexity_metrics, step=self.metrics["step"], @@ -723,7 +719,7 @@ class MlflowCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler): ) return session_analysis_df - def _contain_llm_records(self): # type: ignore[no-untyped-def] + def _contain_llm_records(self) -> bool: return bool(self.records["on_llm_start_records"]) def flush_tracker(self, langchain_asset: Any = None, finish: bool = False) -> None: diff --git a/libs/community/langchain_community/callbacks/streamlit/__init__.py b/libs/community/langchain_community/callbacks/streamlit/__init__.py index 7a0fadb059d..4ee2ea5a9fb 100644 --- a/libs/community/langchain_community/callbacks/streamlit/__init__.py +++ b/libs/community/langchain_community/callbacks/streamlit/__init__.py @@ -62,7 +62,7 @@ def StreamlitCallbackHandler( # guaranteed to support the same set of kwargs. try: from streamlit.external.langchain import ( - StreamlitCallbackHandler as OfficialStreamlitCallbackHandler, # type: ignore # noqa: 501 + StreamlitCallbackHandler as OfficialStreamlitCallbackHandler, ) return OfficialStreamlitCallbackHandler( diff --git a/libs/community/langchain_community/callbacks/streamlit/mutable_expander.py b/libs/community/langchain_community/callbacks/streamlit/mutable_expander.py index 7de1e9873fe..0bb73f571ad 100644 --- a/libs/community/langchain_community/callbacks/streamlit/mutable_expander.py +++ b/libs/community/langchain_community/callbacks/streamlit/mutable_expander.py @@ -108,7 +108,7 @@ class MutableExpander: ) -> int: """Add a Markdown element to the container and return its index.""" kwargs = {"body": body, "unsafe_allow_html": unsafe_allow_html, "help": help} - new_dg = self._get_dg(index).markdown(**kwargs) # type: ignore[arg-type] + new_dg = self._get_dg(index).markdown(**kwargs) record = ChildRecord(ChildType.MARKDOWN, kwargs, new_dg) return self._add_record(record, index) diff --git a/libs/community/langchain_community/callbacks/tracers/wandb.py b/libs/community/langchain_community/callbacks/tracers/wandb.py index 31df3352734..00753a5af1a 100644 --- a/libs/community/langchain_community/callbacks/tracers/wandb.py +++ b/libs/community/langchain_community/callbacks/tracers/wandb.py @@ -489,11 +489,10 @@ class WandbTracer(BaseTracer): If not, will start a new run with the provided run_args. """ if self._wandb.run is None: - run_args = self._run_args or {} # type: ignore - run_args: dict = {**run_args} # type: ignore + run_args: Dict = {**(self._run_args or {})} - if "settings" not in run_args: # type: ignore - run_args["settings"] = {"silent": True} # type: ignore + if "settings" not in run_args: + run_args["settings"] = {"silent": True} self._wandb.init(**run_args) if self._wandb.run is not None: diff --git a/libs/community/langchain_community/callbacks/wandb_callback.py b/libs/community/langchain_community/callbacks/wandb_callback.py index 35559ea539c..9f968d3c671 100644 --- a/libs/community/langchain_community/callbacks/wandb_callback.py +++ b/libs/community/langchain_community/callbacks/wandb_callback.py @@ -92,15 +92,11 @@ def analyze_text( if visualize and nlp and output_dir is not None: doc = nlp(text) - dep_out = spacy.displacy.render( # type: ignore - doc, style="dep", jupyter=False, page=True - ) + dep_out = spacy.displacy.render(doc, style="dep", jupyter=False, page=True) dep_output_path = Path(output_dir, hash_string(f"dep-{text}") + ".html") dep_output_path.open("w", encoding="utf-8").write(dep_out) - ent_out = spacy.displacy.render( # type: ignore - doc, style="ent", jupyter=False, page=True - ) + ent_out = spacy.displacy.render(doc, style="ent", jupyter=False, page=True) ent_output_path = Path(output_dir, hash_string(f"ent-{text}") + ".html") ent_output_path.open("w", encoding="utf-8").write(ent_out) @@ -193,7 +189,7 @@ class WandbCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler): self.stream_logs = stream_logs self.temp_dir = tempfile.TemporaryDirectory() - self.run: wandb.sdk.wandb_run.Run = wandb.init( # type: ignore + self.run = wandb.init( job_type=self.job_type, project=self.project, entity=self.entity,