mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-12 07:50:39 +00:00
langchain: Add ruff rules C4 (#31879)
All auto-fixes See https://docs.astral.sh/ruff/rules/#flake8-comprehensions-c4 --------- Co-authored-by: Mason Daugherty <mason@langchain.dev>
This commit is contained in:
parent
4134b36db8
commit
fceebbb387
@ -1126,9 +1126,9 @@ class AgentExecutor(Chain):
|
|||||||
agent = self.agent
|
agent = self.agent
|
||||||
tools = self.tools
|
tools = self.tools
|
||||||
allowed_tools = agent.get_allowed_tools() # type: ignore[union-attr]
|
allowed_tools = agent.get_allowed_tools() # type: ignore[union-attr]
|
||||||
if allowed_tools is not None and set(allowed_tools) != set(
|
if allowed_tools is not None and set(allowed_tools) != {
|
||||||
[tool.name for tool in tools]
|
tool.name for tool in tools
|
||||||
):
|
}:
|
||||||
msg = (
|
msg = (
|
||||||
f"Allowed tools ({allowed_tools}) different than "
|
f"Allowed tools ({allowed_tools}) different than "
|
||||||
f"provided tools ({[tool.name for tool in tools]})"
|
f"provided tools ({[tool.name for tool in tools]})"
|
||||||
@ -1318,16 +1318,15 @@ class AgentExecutor(Chain):
|
|||||||
run_manager: Optional[CallbackManagerForChainRun] = None,
|
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||||
) -> Union[AgentFinish, list[tuple[AgentAction, str]]]:
|
) -> Union[AgentFinish, list[tuple[AgentAction, str]]]:
|
||||||
return self._consume_next_step(
|
return self._consume_next_step(
|
||||||
[
|
list(
|
||||||
a
|
self._iter_next_step(
|
||||||
for a in self._iter_next_step(
|
|
||||||
name_to_tool_map,
|
name_to_tool_map,
|
||||||
color_mapping,
|
color_mapping,
|
||||||
inputs,
|
inputs,
|
||||||
intermediate_steps,
|
intermediate_steps,
|
||||||
run_manager,
|
run_manager,
|
||||||
)
|
)
|
||||||
]
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
def _iter_next_step(
|
def _iter_next_step(
|
||||||
|
@ -37,7 +37,7 @@ class SelfAskOutputParser(AgentOutputParser):
|
|||||||
|
|
||||||
def parse(self, text: str) -> Union[AgentAction, AgentFinish]:
|
def parse(self, text: str) -> Union[AgentAction, AgentFinish]:
|
||||||
last_line = text.split("\n")[-1]
|
last_line = text.split("\n")[-1]
|
||||||
if not any([follow in last_line for follow in self.followups]):
|
if not any(follow in last_line for follow in self.followups):
|
||||||
if self.finish_string not in last_line:
|
if self.finish_string not in last_line:
|
||||||
msg = f"Could not parse output: {text}"
|
msg = f"Could not parse output: {text}"
|
||||||
raise OutputParserException(msg)
|
raise OutputParserException(msg)
|
||||||
|
@ -24,7 +24,7 @@ class InvalidTool(BaseTool):
|
|||||||
run_manager: Optional[CallbackManagerForToolRun] = None,
|
run_manager: Optional[CallbackManagerForToolRun] = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Use the tool."""
|
"""Use the tool."""
|
||||||
available_tool_names_str = ", ".join([tool for tool in available_tool_names])
|
available_tool_names_str = ", ".join(list(available_tool_names))
|
||||||
return (
|
return (
|
||||||
f"{requested_tool_name} is not a valid tool, "
|
f"{requested_tool_name} is not a valid tool, "
|
||||||
f"try one of [{available_tool_names_str}]."
|
f"try one of [{available_tool_names_str}]."
|
||||||
@ -37,7 +37,7 @@ class InvalidTool(BaseTool):
|
|||||||
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
|
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Use the tool asynchronously."""
|
"""Use the tool asynchronously."""
|
||||||
available_tool_names_str = ", ".join([tool for tool in available_tool_names])
|
available_tool_names_str = ", ".join(list(available_tool_names))
|
||||||
return (
|
return (
|
||||||
f"{requested_tool_name} is not a valid tool, "
|
f"{requested_tool_name} is not a valid tool, "
|
||||||
f"try one of [{available_tool_names_str}]."
|
f"try one of [{available_tool_names_str}]."
|
||||||
|
@ -352,7 +352,7 @@ try:
|
|||||||
headers: Optional[dict] = None,
|
headers: Optional[dict] = None,
|
||||||
api_url_prompt: BasePromptTemplate = API_URL_PROMPT,
|
api_url_prompt: BasePromptTemplate = API_URL_PROMPT,
|
||||||
api_response_prompt: BasePromptTemplate = API_RESPONSE_PROMPT,
|
api_response_prompt: BasePromptTemplate = API_RESPONSE_PROMPT,
|
||||||
limit_to_domains: Optional[Sequence[str]] = tuple(),
|
limit_to_domains: Optional[Sequence[str]] = (),
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> APIChain:
|
) -> APIChain:
|
||||||
"""Load chain from just an LLM and the api docs."""
|
"""Load chain from just an LLM and the api docs."""
|
||||||
|
@ -112,13 +112,15 @@ class Chain(RunnableSerializable[dict[str, Any], dict[str, Any]], ABC):
|
|||||||
self, config: Optional[RunnableConfig] = None
|
self, config: Optional[RunnableConfig] = None
|
||||||
) -> type[BaseModel]:
|
) -> type[BaseModel]:
|
||||||
# This is correct, but pydantic typings/mypy don't think so.
|
# This is correct, but pydantic typings/mypy don't think so.
|
||||||
return create_model("ChainInput", **{k: (Any, None) for k in self.input_keys})
|
return create_model("ChainInput", **dict.fromkeys(self.input_keys, (Any, None)))
|
||||||
|
|
||||||
def get_output_schema(
|
def get_output_schema(
|
||||||
self, config: Optional[RunnableConfig] = None
|
self, config: Optional[RunnableConfig] = None
|
||||||
) -> type[BaseModel]:
|
) -> type[BaseModel]:
|
||||||
# This is correct, but pydantic typings/mypy don't think so.
|
# This is correct, but pydantic typings/mypy don't think so.
|
||||||
return create_model("ChainOutput", **{k: (Any, None) for k in self.output_keys})
|
return create_model(
|
||||||
|
"ChainOutput", **dict.fromkeys(self.output_keys, (Any, None))
|
||||||
|
)
|
||||||
|
|
||||||
@override
|
@override
|
||||||
def invoke(
|
def invoke(
|
||||||
|
@ -100,7 +100,7 @@ class MapRerankDocumentsChain(BaseCombineDocumentsChain):
|
|||||||
if self.return_intermediate_steps:
|
if self.return_intermediate_steps:
|
||||||
schema["intermediate_steps"] = (list[str], None)
|
schema["intermediate_steps"] = (list[str], None)
|
||||||
if self.metadata_keys:
|
if self.metadata_keys:
|
||||||
schema.update({key: (Any, None) for key in self.metadata_keys})
|
schema.update(dict.fromkeys(self.metadata_keys, (Any, None)))
|
||||||
|
|
||||||
return create_model("MapRerankOutput", **schema)
|
return create_model("MapRerankOutput", **schema)
|
||||||
|
|
||||||
|
@ -48,7 +48,7 @@ class SequentialChain(Chain):
|
|||||||
"""Validate that the correct inputs exist for all chains."""
|
"""Validate that the correct inputs exist for all chains."""
|
||||||
chains = values["chains"]
|
chains = values["chains"]
|
||||||
input_variables = values["input_variables"]
|
input_variables = values["input_variables"]
|
||||||
memory_keys = list()
|
memory_keys = []
|
||||||
if "memory" in values and values["memory"] is not None:
|
if "memory" in values and values["memory"] is not None:
|
||||||
"""Validate that prompt input variables are consistent."""
|
"""Validate that prompt input variables are consistent."""
|
||||||
memory_keys = values["memory"].memory_variables
|
memory_keys = values["memory"].memory_variables
|
||||||
|
@ -69,7 +69,7 @@ def load_dataset(uri: str) -> list[dict]:
|
|||||||
raise ImportError(msg)
|
raise ImportError(msg)
|
||||||
|
|
||||||
dataset = load_dataset(f"LangChainDatasets/{uri}")
|
dataset = load_dataset(f"LangChainDatasets/{uri}")
|
||||||
return [d for d in dataset["train"]]
|
return list(dataset["train"])
|
||||||
|
|
||||||
|
|
||||||
_EVALUATOR_MAP: dict[
|
_EVALUATOR_MAP: dict[
|
||||||
|
@ -311,10 +311,10 @@ class SQLRecordManager(RecordManager):
|
|||||||
).values(records_to_upsert)
|
).values(records_to_upsert)
|
||||||
stmt = sqlite_insert_stmt.on_conflict_do_update(
|
stmt = sqlite_insert_stmt.on_conflict_do_update(
|
||||||
[UpsertionRecord.key, UpsertionRecord.namespace],
|
[UpsertionRecord.key, UpsertionRecord.namespace],
|
||||||
set_=dict(
|
set_={
|
||||||
updated_at=sqlite_insert_stmt.excluded.updated_at,
|
"updated_at": sqlite_insert_stmt.excluded.updated_at,
|
||||||
group_id=sqlite_insert_stmt.excluded.group_id,
|
"group_id": sqlite_insert_stmt.excluded.group_id,
|
||||||
),
|
},
|
||||||
)
|
)
|
||||||
elif self.dialect == "postgresql":
|
elif self.dialect == "postgresql":
|
||||||
from sqlalchemy.dialects.postgresql import Insert as PgInsertType
|
from sqlalchemy.dialects.postgresql import Insert as PgInsertType
|
||||||
@ -327,10 +327,10 @@ class SQLRecordManager(RecordManager):
|
|||||||
)
|
)
|
||||||
stmt = pg_insert_stmt.on_conflict_do_update( # type: ignore[assignment]
|
stmt = pg_insert_stmt.on_conflict_do_update( # type: ignore[assignment]
|
||||||
"uix_key_namespace", # Name of constraint
|
"uix_key_namespace", # Name of constraint
|
||||||
set_=dict(
|
set_={
|
||||||
updated_at=pg_insert_stmt.excluded.updated_at,
|
"updated_at": pg_insert_stmt.excluded.updated_at,
|
||||||
group_id=pg_insert_stmt.excluded.group_id,
|
"group_id": pg_insert_stmt.excluded.group_id,
|
||||||
),
|
},
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
msg = f"Unsupported dialect {self.dialect}"
|
msg = f"Unsupported dialect {self.dialect}"
|
||||||
@ -393,10 +393,10 @@ class SQLRecordManager(RecordManager):
|
|||||||
).values(records_to_upsert)
|
).values(records_to_upsert)
|
||||||
stmt = sqlite_insert_stmt.on_conflict_do_update(
|
stmt = sqlite_insert_stmt.on_conflict_do_update(
|
||||||
[UpsertionRecord.key, UpsertionRecord.namespace],
|
[UpsertionRecord.key, UpsertionRecord.namespace],
|
||||||
set_=dict(
|
set_={
|
||||||
updated_at=sqlite_insert_stmt.excluded.updated_at,
|
"updated_at": sqlite_insert_stmt.excluded.updated_at,
|
||||||
group_id=sqlite_insert_stmt.excluded.group_id,
|
"group_id": sqlite_insert_stmt.excluded.group_id,
|
||||||
),
|
},
|
||||||
)
|
)
|
||||||
elif self.dialect == "postgresql":
|
elif self.dialect == "postgresql":
|
||||||
from sqlalchemy.dialects.postgresql import Insert as PgInsertType
|
from sqlalchemy.dialects.postgresql import Insert as PgInsertType
|
||||||
@ -409,10 +409,10 @@ class SQLRecordManager(RecordManager):
|
|||||||
)
|
)
|
||||||
stmt = pg_insert_stmt.on_conflict_do_update( # type: ignore[assignment]
|
stmt = pg_insert_stmt.on_conflict_do_update( # type: ignore[assignment]
|
||||||
"uix_key_namespace", # Name of constraint
|
"uix_key_namespace", # Name of constraint
|
||||||
set_=dict(
|
set_={
|
||||||
updated_at=pg_insert_stmt.excluded.updated_at,
|
"updated_at": pg_insert_stmt.excluded.updated_at,
|
||||||
group_id=pg_insert_stmt.excluded.group_id,
|
"group_id": pg_insert_stmt.excluded.group_id,
|
||||||
),
|
},
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
msg = f"Unsupported dialect {self.dialect}"
|
msg = f"Unsupported dialect {self.dialect}"
|
||||||
@ -432,7 +432,7 @@ class SQLRecordManager(RecordManager):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
records = filtered_query.all()
|
records = filtered_query.all()
|
||||||
found_keys = set(r.key for r in records)
|
found_keys = {r.key for r in records}
|
||||||
return [k in found_keys for k in keys]
|
return [k in found_keys for k in keys]
|
||||||
|
|
||||||
async def aexists(self, keys: Sequence[str]) -> list[bool]:
|
async def aexists(self, keys: Sequence[str]) -> list[bool]:
|
||||||
|
@ -8,7 +8,7 @@ class SimpleMemory(BaseMemory):
|
|||||||
ever change between prompts.
|
ever change between prompts.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
memories: dict[str, Any] = dict()
|
memories: dict[str, Any] = {}
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def memory_variables(self) -> list[str]:
|
def memory_variables(self) -> list[str]:
|
||||||
|
@ -49,7 +49,7 @@ class CombiningOutputParser(BaseOutputParser[dict[str, Any]]):
|
|||||||
def parse(self, text: str) -> dict[str, Any]:
|
def parse(self, text: str) -> dict[str, Any]:
|
||||||
"""Parse the output of an LLM call."""
|
"""Parse the output of an LLM call."""
|
||||||
texts = text.split("\n\n")
|
texts = text.split("\n\n")
|
||||||
output = dict()
|
output = {}
|
||||||
for txt, parser in zip(texts, self.parsers):
|
for txt, parser in zip(texts, self.parsers):
|
||||||
output.update(parser.parse(txt.strip()))
|
output.update(parser.parse(txt.strip()))
|
||||||
return output
|
return output
|
||||||
|
@ -82,19 +82,19 @@ class OutputFixingParser(BaseOutputParser[T]):
|
|||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
completion = self.retry_chain.invoke(
|
completion = self.retry_chain.invoke(
|
||||||
dict(
|
{
|
||||||
instructions=self.parser.get_format_instructions(),
|
"instructions": self.parser.get_format_instructions(), # noqa: E501
|
||||||
completion=completion,
|
"completion": completion,
|
||||||
error=repr(e),
|
"error": repr(e),
|
||||||
)
|
}
|
||||||
)
|
)
|
||||||
except (NotImplementedError, AttributeError):
|
except (NotImplementedError, AttributeError):
|
||||||
# Case: self.parser does not have get_format_instructions
|
# Case: self.parser does not have get_format_instructions
|
||||||
completion = self.retry_chain.invoke(
|
completion = self.retry_chain.invoke(
|
||||||
dict(
|
{
|
||||||
completion=completion,
|
"completion": completion,
|
||||||
error=repr(e),
|
"error": repr(e),
|
||||||
)
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
msg = "Failed to parse"
|
msg = "Failed to parse"
|
||||||
@ -120,19 +120,19 @@ class OutputFixingParser(BaseOutputParser[T]):
|
|||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
completion = await self.retry_chain.ainvoke(
|
completion = await self.retry_chain.ainvoke(
|
||||||
dict(
|
{
|
||||||
instructions=self.parser.get_format_instructions(),
|
"instructions": self.parser.get_format_instructions(), # noqa: E501
|
||||||
completion=completion,
|
"completion": completion,
|
||||||
error=repr(e),
|
"error": repr(e),
|
||||||
)
|
}
|
||||||
)
|
)
|
||||||
except (NotImplementedError, AttributeError):
|
except (NotImplementedError, AttributeError):
|
||||||
# Case: self.parser does not have get_format_instructions
|
# Case: self.parser does not have get_format_instructions
|
||||||
completion = await self.retry_chain.ainvoke(
|
completion = await self.retry_chain.ainvoke(
|
||||||
dict(
|
{
|
||||||
completion=completion,
|
"completion": completion,
|
||||||
error=repr(e),
|
"error": repr(e),
|
||||||
)
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
msg = "Failed to parse"
|
msg = "Failed to parse"
|
||||||
|
@ -116,10 +116,10 @@ class RetryOutputParser(BaseOutputParser[T]):
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
completion = self.retry_chain.invoke(
|
completion = self.retry_chain.invoke(
|
||||||
dict(
|
{
|
||||||
prompt=prompt_value.to_string(),
|
"prompt": prompt_value.to_string(),
|
||||||
completion=completion,
|
"completion": completion,
|
||||||
)
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
msg = "Failed to parse"
|
msg = "Failed to parse"
|
||||||
@ -153,10 +153,10 @@ class RetryOutputParser(BaseOutputParser[T]):
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
completion = await self.retry_chain.ainvoke(
|
completion = await self.retry_chain.ainvoke(
|
||||||
dict(
|
{
|
||||||
prompt=prompt_value.to_string(),
|
"prompt": prompt_value.to_string(),
|
||||||
completion=completion,
|
"completion": completion,
|
||||||
)
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
msg = "Failed to parse"
|
msg = "Failed to parse"
|
||||||
@ -244,11 +244,11 @@ class RetryWithErrorOutputParser(BaseOutputParser[T]):
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
completion = self.retry_chain.invoke(
|
completion = self.retry_chain.invoke(
|
||||||
dict(
|
{
|
||||||
completion=completion,
|
"completion": completion,
|
||||||
prompt=prompt_value.to_string(),
|
"prompt": prompt_value.to_string(),
|
||||||
error=repr(e),
|
"error": repr(e),
|
||||||
)
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
msg = "Failed to parse"
|
msg = "Failed to parse"
|
||||||
@ -273,11 +273,11 @@ class RetryWithErrorOutputParser(BaseOutputParser[T]):
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
completion = await self.retry_chain.ainvoke(
|
completion = await self.retry_chain.ainvoke(
|
||||||
dict(
|
{
|
||||||
prompt=prompt_value.to_string(),
|
"prompt": prompt_value.to_string(),
|
||||||
completion=completion,
|
"completion": completion,
|
||||||
error=repr(e),
|
"error": repr(e),
|
||||||
)
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
msg = "Failed to parse"
|
msg = "Failed to parse"
|
||||||
|
@ -43,7 +43,7 @@ class YamlOutputParser(BaseOutputParser[T]):
|
|||||||
|
|
||||||
def get_format_instructions(self) -> str:
|
def get_format_instructions(self) -> str:
|
||||||
# Copy schema to avoid altering original Pydantic schema.
|
# Copy schema to avoid altering original Pydantic schema.
|
||||||
schema = {k: v for k, v in self.pydantic_object.schema().items()}
|
schema = dict(self.pydantic_object.schema().items())
|
||||||
|
|
||||||
# Remove extraneous fields.
|
# Remove extraneous fields.
|
||||||
reduced_schema = schema
|
reduced_schema = schema
|
||||||
|
@ -24,7 +24,7 @@ class TimeWeightedVectorStoreRetriever(BaseRetriever):
|
|||||||
vectorstore: VectorStore
|
vectorstore: VectorStore
|
||||||
"""The vectorstore to store documents and determine salience."""
|
"""The vectorstore to store documents and determine salience."""
|
||||||
|
|
||||||
search_kwargs: dict = Field(default_factory=lambda: dict(k=100))
|
search_kwargs: dict = Field(default_factory=lambda: {"k": 100})
|
||||||
"""Keyword arguments to pass to the vectorstore similarity search."""
|
"""Keyword arguments to pass to the vectorstore similarity search."""
|
||||||
|
|
||||||
# TODO: abstract as a queue
|
# TODO: abstract as a queue
|
||||||
|
@ -623,7 +623,7 @@ def _load_run_evaluators(
|
|||||||
input_key, prediction_key, reference_key = None, None, None
|
input_key, prediction_key, reference_key = None, None, None
|
||||||
if config.evaluators or (
|
if config.evaluators or (
|
||||||
config.custom_evaluators
|
config.custom_evaluators
|
||||||
and any([isinstance(e, StringEvaluator) for e in config.custom_evaluators])
|
and any(isinstance(e, StringEvaluator) for e in config.custom_evaluators)
|
||||||
):
|
):
|
||||||
input_key, prediction_key, reference_key = _get_keys(
|
input_key, prediction_key, reference_key = _get_keys(
|
||||||
config, run_inputs, run_outputs, example_outputs
|
config, run_inputs, run_outputs, example_outputs
|
||||||
|
@ -143,7 +143,7 @@ ignore-regex = ".*(Stati Uniti|Tense=Pres).*"
|
|||||||
ignore-words-list = "momento,collison,ned,foor,reworkd,parth,whats,aapply,mysogyny,unsecure,damon,crate,aadd,symbl,precesses,accademia,nin"
|
ignore-words-list = "momento,collison,ned,foor,reworkd,parth,whats,aapply,mysogyny,unsecure,damon,crate,aadd,symbl,precesses,accademia,nin"
|
||||||
|
|
||||||
[tool.ruff.lint]
|
[tool.ruff.lint]
|
||||||
select = ["A", "E", "F", "I", "EM", "PGH003", "PIE", "SIM", "T201", "D", "UP", "S", "W"]
|
select = ["A", "C4", "E", "F", "I", "EM", "PGH003", "PIE", "SIM", "T201", "D", "UP", "S", "W"]
|
||||||
pydocstyle.convention = "google"
|
pydocstyle.convention = "google"
|
||||||
pyupgrade.keep-runtime-typing = true
|
pyupgrade.keep-runtime-typing = true
|
||||||
|
|
||||||
|
@ -205,7 +205,7 @@ def test_agent_stream() -> None:
|
|||||||
agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
|
agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
|
||||||
)
|
)
|
||||||
|
|
||||||
output = [a for a in agent.stream("when was langchain made")]
|
output = list(agent.stream("when was langchain made"))
|
||||||
assert output == [
|
assert output == [
|
||||||
{
|
{
|
||||||
"actions": [
|
"actions": [
|
||||||
|
@ -24,7 +24,7 @@ def test_resolve_criteria_enum(criterion: Criteria) -> None:
|
|||||||
def test_resolve_criteria_list_enum() -> None:
|
def test_resolve_criteria_list_enum() -> None:
|
||||||
val = resolve_pairwise_criteria(list(Criteria))
|
val = resolve_pairwise_criteria(list(Criteria))
|
||||||
assert isinstance(val, dict)
|
assert isinstance(val, dict)
|
||||||
assert set(val.keys()) == set(c.value for c in list(Criteria))
|
assert set(val.keys()) == {c.value for c in list(Criteria)}
|
||||||
|
|
||||||
|
|
||||||
def test_PairwiseStringResultOutputParser_parse() -> None:
|
def test_PairwiseStringResultOutputParser_parse() -> None:
|
||||||
|
@ -288,11 +288,11 @@ def test_index_simple_delete_full(
|
|||||||
"num_updated": 0,
|
"num_updated": 0,
|
||||||
}
|
}
|
||||||
|
|
||||||
doc_texts = set(
|
doc_texts = {
|
||||||
# Ignoring type since doc should be in the store and not a None
|
# Ignoring type since doc should be in the store and not a None
|
||||||
vector_store.store.get(uid).page_content # type: ignore[union-attr]
|
vector_store.store.get(uid).page_content # type: ignore[union-attr]
|
||||||
for uid in vector_store.store
|
for uid in vector_store.store
|
||||||
)
|
}
|
||||||
assert doc_texts == {"mutated document 1", "This is another document."}
|
assert doc_texts == {"mutated document 1", "This is another document."}
|
||||||
|
|
||||||
# Attempt to index again verify that nothing changes
|
# Attempt to index again verify that nothing changes
|
||||||
@ -364,11 +364,11 @@ async def test_aindex_simple_delete_full(
|
|||||||
"num_updated": 0,
|
"num_updated": 0,
|
||||||
}
|
}
|
||||||
|
|
||||||
doc_texts = set(
|
doc_texts = {
|
||||||
# Ignoring type since doc should be in the store and not a None
|
# Ignoring type since doc should be in the store and not a None
|
||||||
vector_store.store.get(uid).page_content # type: ignore[union-attr]
|
vector_store.store.get(uid).page_content # type: ignore[union-attr]
|
||||||
for uid in vector_store.store
|
for uid in vector_store.store
|
||||||
)
|
}
|
||||||
assert doc_texts == {"mutated document 1", "This is another document."}
|
assert doc_texts == {"mutated document 1", "This is another document."}
|
||||||
|
|
||||||
# Attempt to index again verify that nothing changes
|
# Attempt to index again verify that nothing changes
|
||||||
@ -657,11 +657,11 @@ def test_incremental_delete(
|
|||||||
"num_updated": 0,
|
"num_updated": 0,
|
||||||
}
|
}
|
||||||
|
|
||||||
doc_texts = set(
|
doc_texts = {
|
||||||
# Ignoring type since doc should be in the store and not a None
|
# Ignoring type since doc should be in the store and not a None
|
||||||
vector_store.store.get(uid).page_content # type: ignore[union-attr]
|
vector_store.store.get(uid).page_content # type: ignore[union-attr]
|
||||||
for uid in vector_store.store
|
for uid in vector_store.store
|
||||||
)
|
}
|
||||||
assert doc_texts == {"This is another document.", "This is a test document."}
|
assert doc_texts == {"This is another document.", "This is a test document."}
|
||||||
|
|
||||||
# Attempt to index again verify that nothing changes
|
# Attempt to index again verify that nothing changes
|
||||||
@ -716,11 +716,11 @@ def test_incremental_delete(
|
|||||||
"num_updated": 0,
|
"num_updated": 0,
|
||||||
}
|
}
|
||||||
|
|
||||||
doc_texts = set(
|
doc_texts = {
|
||||||
# Ignoring type since doc should be in the store and not a None
|
# Ignoring type since doc should be in the store and not a None
|
||||||
vector_store.store.get(uid).page_content # type: ignore[union-attr]
|
vector_store.store.get(uid).page_content # type: ignore[union-attr]
|
||||||
for uid in vector_store.store
|
for uid in vector_store.store
|
||||||
)
|
}
|
||||||
assert doc_texts == {
|
assert doc_texts == {
|
||||||
"mutated document 1",
|
"mutated document 1",
|
||||||
"mutated document 2",
|
"mutated document 2",
|
||||||
@ -784,11 +784,11 @@ def test_incremental_indexing_with_batch_size(
|
|||||||
"num_updated": 0,
|
"num_updated": 0,
|
||||||
}
|
}
|
||||||
|
|
||||||
doc_texts = set(
|
doc_texts = {
|
||||||
# Ignoring type since doc should be in the store and not a None
|
# Ignoring type since doc should be in the store and not a None
|
||||||
vector_store.store.get(uid).page_content # type: ignore[union-attr]
|
vector_store.store.get(uid).page_content # type: ignore[union-attr]
|
||||||
for uid in vector_store.store
|
for uid in vector_store.store
|
||||||
)
|
}
|
||||||
assert doc_texts == {"1", "2", "3", "4"}
|
assert doc_texts == {"1", "2", "3", "4"}
|
||||||
|
|
||||||
|
|
||||||
@ -834,11 +834,11 @@ def test_incremental_delete_with_batch_size(
|
|||||||
"num_updated": 0,
|
"num_updated": 0,
|
||||||
}
|
}
|
||||||
|
|
||||||
doc_texts = set(
|
doc_texts = {
|
||||||
# Ignoring type since doc should be in the store and not a None
|
# Ignoring type since doc should be in the store and not a None
|
||||||
vector_store.store.get(uid).page_content # type: ignore[union-attr]
|
vector_store.store.get(uid).page_content # type: ignore[union-attr]
|
||||||
for uid in vector_store.store
|
for uid in vector_store.store
|
||||||
)
|
}
|
||||||
assert doc_texts == {"1", "2", "3", "4"}
|
assert doc_texts == {"1", "2", "3", "4"}
|
||||||
|
|
||||||
# Attempt to index again verify that nothing changes
|
# Attempt to index again verify that nothing changes
|
||||||
@ -980,11 +980,11 @@ async def test_aincremental_delete(
|
|||||||
"num_updated": 0,
|
"num_updated": 0,
|
||||||
}
|
}
|
||||||
|
|
||||||
doc_texts = set(
|
doc_texts = {
|
||||||
# Ignoring type since doc should be in the store and not a None
|
# Ignoring type since doc should be in the store and not a None
|
||||||
vector_store.store.get(uid).page_content # type: ignore[union-attr]
|
vector_store.store.get(uid).page_content # type: ignore[union-attr]
|
||||||
for uid in vector_store.store
|
for uid in vector_store.store
|
||||||
)
|
}
|
||||||
assert doc_texts == {"This is another document.", "This is a test document."}
|
assert doc_texts == {"This is another document.", "This is a test document."}
|
||||||
|
|
||||||
# Attempt to index again verify that nothing changes
|
# Attempt to index again verify that nothing changes
|
||||||
@ -1039,11 +1039,11 @@ async def test_aincremental_delete(
|
|||||||
"num_updated": 0,
|
"num_updated": 0,
|
||||||
}
|
}
|
||||||
|
|
||||||
doc_texts = set(
|
doc_texts = {
|
||||||
# Ignoring type since doc should be in the store and not a None
|
# Ignoring type since doc should be in the store and not a None
|
||||||
vector_store.store.get(uid).page_content # type: ignore[union-attr]
|
vector_store.store.get(uid).page_content # type: ignore[union-attr]
|
||||||
for uid in vector_store.store
|
for uid in vector_store.store
|
||||||
)
|
}
|
||||||
assert doc_texts == {
|
assert doc_texts == {
|
||||||
"mutated document 1",
|
"mutated document 1",
|
||||||
"mutated document 2",
|
"mutated document 2",
|
||||||
|
@ -51,7 +51,7 @@ async def test_generic_fake_chat_model_stream() -> None:
|
|||||||
_AnyIdAIMessageChunk(content="goodbye"),
|
_AnyIdAIMessageChunk(content="goodbye"),
|
||||||
]
|
]
|
||||||
|
|
||||||
chunks = [chunk for chunk in model.stream("meow")]
|
chunks = list(model.stream("meow"))
|
||||||
assert chunks == [
|
assert chunks == [
|
||||||
_AnyIdAIMessageChunk(content="hello"),
|
_AnyIdAIMessageChunk(content="hello"),
|
||||||
_AnyIdAIMessageChunk(content=" "),
|
_AnyIdAIMessageChunk(content=" "),
|
||||||
|
@ -28,7 +28,7 @@ def test_required_dependencies(uv_conf: Mapping[str, Any]) -> None:
|
|||||||
"""
|
"""
|
||||||
# Get the dependencies from the [tool.poetry.dependencies] section
|
# Get the dependencies from the [tool.poetry.dependencies] section
|
||||||
dependencies = uv_conf["project"]["dependencies"]
|
dependencies = uv_conf["project"]["dependencies"]
|
||||||
required_dependencies = set(Requirement(dep).name for dep in dependencies)
|
required_dependencies = {Requirement(dep).name for dep in dependencies}
|
||||||
|
|
||||||
assert sorted(required_dependencies) == sorted(
|
assert sorted(required_dependencies) == sorted(
|
||||||
[
|
[
|
||||||
@ -54,7 +54,7 @@ def test_test_group_dependencies(uv_conf: Mapping[str, Any]) -> None:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
dependencies = uv_conf["dependency-groups"]["test"]
|
dependencies = uv_conf["dependency-groups"]["test"]
|
||||||
test_group_deps = set(Requirement(dep).name for dep in dependencies)
|
test_group_deps = {Requirement(dep).name for dep in dependencies}
|
||||||
|
|
||||||
assert sorted(test_group_deps) == sorted(
|
assert sorted(test_group_deps) == sorted(
|
||||||
[
|
[
|
||||||
|
Loading…
Reference in New Issue
Block a user