Mirror of https://github.com/hwchase17/langchain.git, synced 2025-08-09 21:08:59 +00:00
chore(langchain): add ruff rules TC (#31921)
See https://docs.astral.sh/ruff/rules/#flake8-type-checking-tc
parent 5ecbb5f277
commit a2ad5aca41
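The flake8-type-checking (TC) rules flag imports and annotations that only the type checker needs. Most of this diff applies rule TC006, which quotes the type argument to cast() so it is never evaluated at runtime; that in turn lets typing-only imports move behind a TYPE_CHECKING guard. A minimal sketch of the pattern, using illustrative names that do not appear in this diff:

from typing import TYPE_CHECKING, cast

if TYPE_CHECKING:
    # Seen only by the type checker; nothing is imported at runtime.
    from collections.abc import Mapping


def first_value(data: object) -> object:
    # TC006: the quoted type is a plain string at runtime, so Mapping
    # does not have to be importable when this line executes.
    mapping = cast("Mapping[str, object]", data)
    return next(iter(mapping.values()))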
@@ -1185,7 +1185,7 @@ class AgentExecutor(Chain):
         to reflect the changes made in the root_validator.
         """
         if isinstance(self.agent, Runnable):
-            return cast(RunnableAgentType, self.agent)
+            return cast("RunnableAgentType", self.agent)
         return self.agent
 
     def save(self, file_path: Union[Path, str]) -> None:
@@ -73,7 +73,7 @@ class AsyncIteratorCallbackHandler(AsyncCallbackHandler):
                other.pop().cancel()
 
            # Extract the value of the first completed task
-            token_or_done = cast(Union[str, Literal[True]], done.pop().result())
+            token_or_done = cast("Union[str, Literal[True]]", done.pop().result())
 
            # If the extracted value is the boolean True, the done event was set
            if token_or_done is True:
@@ -411,7 +411,7 @@ class Chain(RunnableSerializable[dict[str, Any], dict[str, Any]], ABC):
 
         return self.invoke(
             inputs,
-            cast(RunnableConfig, {k: v for k, v in config.items() if v is not None}),
+            cast("RunnableConfig", {k: v for k, v in config.items() if v is not None}),
             return_only_outputs=return_only_outputs,
             include_run_info=include_run_info,
         )
@@ -461,7 +461,7 @@ class Chain(RunnableSerializable[dict[str, Any], dict[str, Any]], ABC):
         }
         return await self.ainvoke(
             inputs,
-            cast(RunnableConfig, {k: v for k, v in config.items() if k is not None}),
+            cast("RunnableConfig", {k: v for k, v in config.items() if k is not None}),
             return_only_outputs=return_only_outputs,
             include_run_info=include_run_info,
         )
@@ -229,7 +229,7 @@ class MapRerankDocumentsChain(BaseCombineDocumentsChain):
         docs: list[Document],
         results: Sequence[Union[str, list[str], dict[str, str]]],
     ) -> tuple[str, dict]:
-        typed_results = cast(list[dict], results)
+        typed_results = cast("list[dict]", results)
         sorted_res = sorted(
             zip(typed_results, docs),
             key=lambda x: -int(x[0][self.rank_key]),
@@ -145,7 +145,7 @@ class LLMChain(Chain):
             **self.llm_kwargs,
         )
         results = self.llm.bind(stop=stop, **self.llm_kwargs).batch(
-            cast(list, prompts),
+            cast("list", prompts),
             {"callbacks": callbacks},
         )
         generations: list[list[Generation]] = []
@@ -172,7 +172,7 @@ class LLMChain(Chain):
             **self.llm_kwargs,
         )
         results = await self.llm.bind(stop=stop, **self.llm_kwargs).abatch(
-            cast(list, prompts),
+            cast("list", prompts),
             {"callbacks": callbacks},
         )
         generations: list[list[Generation]] = []
@@ -76,11 +76,11 @@ def create_qa_with_structure_chain(
         raise ValueError(msg)
     if isinstance(schema, type) and is_basemodel_subclass(schema):
         if hasattr(schema, "model_json_schema"):
-            schema_dict = cast(dict, schema.model_json_schema())
+            schema_dict = cast("dict", schema.model_json_schema())
         else:
-            schema_dict = cast(dict, schema.schema())
+            schema_dict = cast("dict", schema.schema())
     else:
-        schema_dict = cast(dict, schema)
+        schema_dict = cast("dict", schema)
     function = {
         "name": schema_dict["title"],
         "description": schema_dict["description"],
@@ -91,7 +91,7 @@ class StructuredQueryOutputParser(BaseOutputParser[StructuredQuery]):
 
         def ast_parse(raw_filter: str) -> Optional[FilterDirective]:
             filter_directive = cast(
-                Optional[FilterDirective],
+                "Optional[FilterDirective]",
                 get_parser().parse(raw_filter),
             )
             return fix_filter_directive(
@@ -144,7 +144,7 @@ def fix_filter_directive(
         return None
     args = [
         cast(
-            FilterDirective,
+            "FilterDirective",
             fix_filter_directive(
                 arg,
                 allowed_comparators=allowed_comparators,
@@ -137,7 +137,7 @@ class LLMRouterChain(RouterChain):
 
         prediction = self.llm_chain.predict(callbacks=callbacks, **inputs)
         return cast(
-            dict[str, Any],
+            "dict[str, Any]",
             self.llm_chain.prompt.output_parser.parse(prediction),
         )
 
@@ -149,7 +149,7 @@ class LLMRouterChain(RouterChain):
         _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
         callbacks = _run_manager.get_child()
         return cast(
-            dict[str, Any],
+            "dict[str, Any]",
             await self.llm_chain.apredict_and_parse(callbacks=callbacks, **inputs),
         )
 
@@ -322,7 +322,7 @@ def init_chat_model(
 
     if not configurable_fields:
         return _init_chat_model_helper(
-            cast(str, model),
+            cast("str", model),
             model_provider=model_provider,
             **kwargs,
         )
@@ -632,7 +632,7 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
         **kwargs: Any,
     ) -> _ConfigurableModel:
         """Bind config to a Runnable, returning a new Runnable."""
-        config = RunnableConfig(**(config or {}), **cast(RunnableConfig, kwargs))
+        config = RunnableConfig(**(config or {}), **cast("RunnableConfig", kwargs))
         model_params = self._model_params(config)
         remaining_config = {k: v for k, v in config.items() if k != "configurable"}
         remaining_config["configurable"] = {
@@ -781,7 +781,7 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
         if config is None or isinstance(config, dict) or len(config) <= 1:
             if isinstance(config, list):
                 config = config[0]
-            yield from self._model(cast(RunnableConfig, config)).batch_as_completed(  # type: ignore[call-overload]
+            yield from self._model(cast("RunnableConfig", config)).batch_as_completed(  # type: ignore[call-overload]
                 inputs,
                 config=config,
                 return_exceptions=return_exceptions,
@@ -811,7 +811,7 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
             if isinstance(config, list):
                 config = config[0]
             async for x in self._model(
-                cast(RunnableConfig, config),
+                cast("RunnableConfig", config),
             ).abatch_as_completed(  # type: ignore[call-overload]
                 inputs,
                 config=config,
@@ -1,12 +1,9 @@
-from typing import TYPE_CHECKING, Any
+from typing import Any
 
 from langchain_core.document_loaders import Blob, BlobLoader
 
 from langchain._api import create_importer
 
-if TYPE_CHECKING:
-    pass
-
 # Create a way to dynamically look up deprecated imports.
 # Used to consolidate logic for raising deprecation warnings and
 # handling optional imports.
@@ -80,7 +80,7 @@ def _value_serializer(value: Sequence[float]) -> bytes:
 
 def _value_deserializer(serialized_value: bytes) -> list[float]:
     """Deserialize a value."""
-    return cast(list[float], json.loads(serialized_value.decode()))
+    return cast("list[float]", json.loads(serialized_value.decode()))
 
 
 # The warning is global; track emission, so it appears only once.
@@ -192,7 +192,7 @@ class CacheBackedEmbeddings(Embeddings):
             vectors[index] = updated_vector
 
         return cast(
-            list[list[float]],
+            "list[list[float]]",
             vectors,
         )  # Nones should have been resolved by now
 
@@ -230,7 +230,7 @@ class CacheBackedEmbeddings(Embeddings):
             vectors[index] = updated_vector
 
         return cast(
-            list[list[float]],
+            "list[list[float]]",
             vectors,
         )  # Nones should have been resolved by now
 
@@ -301,7 +301,7 @@ The following is the expected answer. Use this to measure correctness:
             chain_input,
             callbacks=_run_manager.get_child(),
         )
-        return cast(dict, self.output_parser.parse(raw_output))
+        return cast("dict", self.output_parser.parse(raw_output))
 
     async def _acall(
         self,
@@ -326,7 +326,7 @@ The following is the expected answer. Use this to measure correctness:
             chain_input,
             callbacks=_run_manager.get_child(),
         )
-        return cast(dict, self.output_parser.parse(raw_output))
+        return cast("dict", self.output_parser.parse(raw_output))
 
     @override
     def _evaluate_agent_trajectory(
@@ -166,7 +166,7 @@ class JsonEqualityEvaluator(StringEvaluator):
             dict: A dictionary containing the evaluation score.
         """
         parsed = self._parse_json(prediction)
-        label = self._parse_json(cast(str, reference))
+        label = self._parse_json(cast("str", reference))
         if isinstance(label, list):
             if not isinstance(parsed, list):
                 return {"score": 0}
@@ -82,7 +82,7 @@ class LLMChainExtractor(BaseDocumentCompressor):
             if len(output) == 0:
                 continue
             compressed_docs.append(
-                Document(page_content=cast(str, output), metadata=doc.metadata),
+                Document(page_content=cast("str", output), metadata=doc.metadata),
             )
         return compressed_docs
 
@@ -236,7 +236,7 @@ class EnsembleRetriever(BaseRetriever):
         # Enforce that retrieved docs are Documents for each list in retriever_docs
         for i in range(len(retriever_docs)):
             retriever_docs[i] = [
-                Document(page_content=cast(str, doc)) if isinstance(doc, str) else doc
+                Document(page_content=cast("str", doc)) if isinstance(doc, str) else doc
                 for doc in retriever_docs[i]
             ]
 
@@ -214,24 +214,24 @@ def _wrap_in_chain_factory(
         return lambda: lcf
     if callable(llm_or_chain_factory):
         if is_traceable_function(llm_or_chain_factory):
-            runnable_ = as_runnable(cast(Callable, llm_or_chain_factory))
+            runnable_ = as_runnable(cast("Callable", llm_or_chain_factory))
             return lambda: runnable_
         try:
             _model = llm_or_chain_factory()  # type: ignore[call-arg]
         except TypeError:
             # It's an arbitrary function, wrap it in a RunnableLambda
-            user_func = cast(Callable, llm_or_chain_factory)
+            user_func = cast("Callable", llm_or_chain_factory)
             sig = inspect.signature(user_func)
             logger.info("Wrapping function %s as RunnableLambda.", sig)
             wrapped = RunnableLambda(user_func)
             return lambda: wrapped
-        constructor = cast(Callable, llm_or_chain_factory)
+        constructor = cast("Callable", llm_or_chain_factory)
         if isinstance(_model, BaseLanguageModel):
             # It's not uncommon to do an LLM constructor instead of raw LLM,
             # so we'll unpack it for the user.
             return _model
-        if is_traceable_function(cast(Callable, _model)):
-            runnable_ = as_runnable(cast(Callable, _model))
+        if is_traceable_function(cast("Callable", _model)):
+            runnable_ = as_runnable(cast("Callable", _model))
             return lambda: runnable_
         if not isinstance(_model, Runnable):
             # This is unlikely to happen - a constructor for a model function
@@ -1104,7 +1104,7 @@ class _DatasetRunContainer:
     ) -> dict:
         results: dict = {}
         for example, output in zip(self.examples, batch_results):
-            row_result = cast(_RowResult, all_eval_results.get(str(example.id), {}))
+            row_result = cast("_RowResult", all_eval_results.get(str(example.id), {}))
             results[str(example.id)] = {
                 "input": example.inputs,
                 "feedback": row_result.get("feedback", []),
@@ -1131,7 +1131,7 @@ class _DatasetRunContainer:
                 result = evaluator(runs_list, self.examples)
                 if isinstance(result, EvaluationResult):
                     result = result.dict()
-                aggregate_feedback.append(cast(dict, result))
+                aggregate_feedback.append(cast("dict", result))
                 executor.submit(
                     self.client.create_feedback,
                     **result,
@@ -1148,7 +1148,7 @@ class _DatasetRunContainer:
         all_eval_results: dict = {}
         all_runs: dict = {}
         for c in self.configs:
-            for callback in cast(list, c["callbacks"]):
+            for callback in cast("list", c["callbacks"]):
                 if isinstance(callback, EvaluatorCallbackHandler):
                     eval_results = callback.logged_eval_results
                     for (_, example_id), v in eval_results.items():
@@ -1171,7 +1171,7 @@ class _DatasetRunContainer:
                 },
             )
             all_runs[str(callback.example_id)] = run
-        return cast(dict[str, _RowResult], all_eval_results), all_runs
+        return cast("dict[str, _RowResult]", all_eval_results), all_runs
 
     def _collect_test_results(
         self,
@@ -145,8 +145,8 @@ ignore-words-list = "momento,collison,ned,foor,reworkd,parth,whats,aapply,mysogy
 [tool.ruff.lint]
 select = [
     "A", # flake8-builtins
-    "B", # flake8-bugbear
     "ASYNC", # flake8-async
+    "B", # flake8-bugbear
     "C4", # flake8-comprehensions
     "COM", # flake8-commas
     "D1", # pydocstyle: missing docstring
@@ -174,11 +174,12 @@ select = [
     "RSE", # flake8-rst-docstrings
     "RUF", # ruff
     "S", # flake8-bandit
-    "SLOT", # flake8-slots
     "SIM", # flake8-simplify
     "SLF", # flake8-self
+    "SLOT", # flake8-slots
     "T10", # flake8-debugger
     "T20", # flake8-print
+    "TC", # flake8-type-checking
     "TID", # flake8-tidy-imports
     "TRY", # tryceratops
     "UP", # pyupgrade
@@ -192,6 +193,9 @@ ignore = [
     "PLR09", # Too many something (args, statements, etc)
     "S112", # Rarely useful
     "RUF012", # Doesn't play well with Pydantic
+    "TC001", # Doesn't play well with Pydantic
+    "TC002", # Doesn't play well with Pydantic
+    "TC003", # Doesn't play well with Pydantic
     "UP007", # pyupgrade: non-pep604-annotation-union
 
     # TODO rules
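The TC001-TC003 ignores above reflect a real constraint: Pydantic resolves field annotations at runtime, so an import moved behind TYPE_CHECKING can break model construction. A hedged sketch of the failure these ignores avoid, with an illustrative model:

from typing import TYPE_CHECKING

from pydantic import BaseModel

if TYPE_CHECKING:
    # The typing-only import that TC001-TC003 would normally demand.
    from decimal import Decimal


class Invoice(BaseModel):
    # Pydantic evaluates this annotation while building the class, so
    # the string "Decimal" cannot be resolved at runtime and class
    # creation typically fails with an undefined-annotation error.
    total: "Decimal"

Keeping such imports at module level, which the three ignores permit, sidesteps the problem.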
@@ -40,7 +40,7 @@ async def test_init_chat_model_chain() -> None:
 class TestStandard(ChatModelIntegrationTests):
     @property
     def chat_model_class(self) -> type[BaseChatModel]:
-        return cast(type[BaseChatModel], init_chat_model)
+        return cast("type[BaseChatModel]", init_chat_model)
 
     @property
     def chat_model_params(self) -> dict:
@@ -571,7 +571,7 @@ async def test_runnable_agent_with_function_calls() -> None:
 
     def fake_parse(inputs: dict) -> Union[AgentFinish, AgentAction]:
         """A parser."""
-        return cast(Union[AgentFinish, AgentAction], next(parser_responses))
+        return cast("Union[AgentFinish, AgentAction]", next(parser_responses))
 
     @tool
     def find_pet(pet: str) -> str:
@@ -683,7 +683,7 @@ async def test_runnable_with_multi_action_per_step() -> None:
 
     def fake_parse(inputs: dict) -> Union[AgentFinish, AgentAction]:
         """A parser."""
-        return cast(Union[AgentFinish, AgentAction], next(parser_responses))
+        return cast("Union[AgentFinish, AgentAction]", next(parser_responses))
 
     @tool
     def find_pet(pet: str) -> str:
@@ -82,7 +82,7 @@ def test_parse_disallowed_operator() -> None:
 
 
 def _test_parse_value(x: Any) -> None:
-    parsed = cast(Comparison, (DEFAULT_PARSER.parse(f'eq("x", {x})')))
+    parsed = cast("Comparison", (DEFAULT_PARSER.parse(f'eq("x", {x})')))
     actual = parsed.value
     assert actual == x
 
@@ -104,14 +104,14 @@ def test_parse_list_value(x: list) -> None:
 
 @pytest.mark.parametrize("x", ['""', '" "', '"foo"', "'foo'"])
 def test_parse_string_value(x: str) -> None:
-    parsed = cast(Comparison, DEFAULT_PARSER.parse(f'eq("x", {x})'))
+    parsed = cast("Comparison", DEFAULT_PARSER.parse(f'eq("x", {x})'))
     actual = parsed.value
     assert actual == x[1:-1]
 
 
 @pytest.mark.parametrize("x", ["true", "True", "TRUE", "false", "False", "FALSE"])
 def test_parse_bool_value(x: str) -> None:
-    parsed = cast(Comparison, DEFAULT_PARSER.parse(f'eq("x", {x})'))
+    parsed = cast("Comparison", DEFAULT_PARSER.parse(f'eq("x", {x})'))
     actual = parsed.value
     expected = x.lower() == "true"
     assert actual == expected
@@ -127,7 +127,7 @@ def test_parser_unpack_single_arg_operation(op: str, arg: str) -> None:
 
 @pytest.mark.parametrize("x", ['"2022-10-20"', "'2022-10-20'", "2022-10-20"])
 def test_parse_date_value(x: str) -> None:
-    parsed = cast(Comparison, DEFAULT_PARSER.parse(f'eq("x", {x})'))
+    parsed = cast("Comparison", DEFAULT_PARSER.parse(f'eq("x", {x})'))
     actual = parsed.value["date"]
     assert actual == x.strip("'\"")
 
@@ -152,7 +152,7 @@ def test_parse_date_value(x: str) -> None:
 def test_parse_datetime_value(x: str, expected: dict) -> None:
     """Test parsing of datetime values with ISO 8601 format."""
     try:
-        parsed = cast(Comparison, DEFAULT_PARSER.parse(f'eq("publishedAt", {x})'))
+        parsed = cast("Comparison", DEFAULT_PARSER.parse(f'eq("publishedAt", {x})'))
         actual = parsed.value
         assert actual == expected, f"Expected {expected}, got {actual}"
     except ValueError as e:
@@ -122,7 +122,7 @@ class GenericFakeChatModel(BaseChatModel):
         # Use a regular expression to split on whitespace with a capture group
         # so that we can preserve the whitespace in the output.
         assert isinstance(content, str)
-        content_chunks = cast(list[str], re.split(r"(\s)", content))
+        content_chunks = cast("list[str]", re.split(r"(\s)", content))
 
         for token in content_chunks:
             chunk = ChatGenerationChunk(
@@ -140,7 +140,7 @@ class GenericFakeChatModel(BaseChatModel):
                 for fkey, fvalue in value.items():
                     if isinstance(fvalue, str):
                         # Break function call by `,`
-                        fvalue_chunks = cast(list[str], re.split(r"(,)", fvalue))
+                        fvalue_chunks = cast("list[str]", re.split(r"(,)", fvalue))
                         for fvalue_chunk in fvalue_chunks:
                             chunk = ChatGenerationChunk(
                                 message=AIMessageChunk(
@@ -53,7 +53,7 @@ class FakeLLM(LLM):
 
     @property
     def _get_next_response_in_sequence(self) -> str:
-        queries = cast(Mapping, self.queries)
+        queries = cast("Mapping", self.queries)
         response = queries[list(queries.keys())[self.response_index]]
         self.response_index = self.response_index + 1
         return response
@@ -22,7 +22,7 @@ def test_create_lc_store(file_store: LocalFileStore) -> None:
     """Test that a docstore is created from a base store."""
     docstore = create_lc_store(file_store)
     docstore.mset([("key1", Document(page_content="hello", metadata={"key": "value"}))])
-    fetched_doc = cast(Document, docstore.mget(["key1"])[0])
+    fetched_doc = cast("Document", docstore.mget(["key1"])[0])
     assert fetched_doc.page_content == "hello"
     assert fetched_doc.metadata == {"key": "value"}
 