Compare commits

...

3 Commits

Author SHA1 Message Date
Erick Friis
01907e5418 x 2024-10-15 11:23:33 -04:00
Erick Friis
008e6b3e41 x 2024-10-08 15:49:45 -07:00
Erick Friis
078f73f3f8 core: ruff encoding rule preview 2024-10-08 15:45:21 -07:00
7 changed files with 13 additions and 10 deletions

View File

@@ -375,7 +375,7 @@ class BaseLanguageModel(
Returns:
The sum of the number of tokens across the messages.
"""
return sum([self.get_num_tokens(get_buffer_string([m])) for m in messages])
return sum(self.get_num_tokens(get_buffer_string([m])) for m in messages)
@classmethod
def _all_required_field_names(cls) -> set:

View File

@@ -1387,10 +1387,10 @@ class BaseLLM(BaseLanguageModel[str], ABC):
prompt_dict = self.dict()
if save_path.suffix == ".json":
with open(file_path, "w") as f:
with open(file_path, "w", encoding="utf-8") as f:
json.dump(prompt_dict, f, indent=4)
elif save_path.suffix.endswith((".yaml", ".yml")):
with open(file_path, "w") as f:
with open(file_path, "w", encoding="utf-8") as f:
yaml.dump(prompt_dict, f, default_flow_style=False)
else:
msg = f"{save_path} must be json or yaml"

View File

@@ -359,10 +359,10 @@ class BasePromptTemplate(
directory_path.mkdir(parents=True, exist_ok=True)
if save_path.suffix == ".json":
with open(file_path, "w") as f:
with open(file_path, "w", encoding="utf-8") as f:
json.dump(prompt_dict, f, indent=4)
elif save_path.suffix.endswith((".yaml", ".yml")):
with open(file_path, "w") as f:
with open(file_path, "w", encoding="utf-8") as f:
yaml.dump(prompt_dict, f, default_flow_style=False)
else:
msg = f"{save_path} must be json or yaml"

View File

@@ -588,7 +588,7 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate):
Returns:
A new instance of this class.
"""
with open(str(template_file)) as f:
with open(str(template_file), encoding="utf-8") as f:
template = f.read()
return cls.from_template(template, input_variables=input_variables, **kwargs)

View File

@@ -53,7 +53,7 @@ def _load_template(var_name: str, config: dict) -> dict:
template_path = Path(config.pop(f"{var_name}_path"))
# Load the template.
if template_path.suffix == ".txt":
with open(template_path) as f:
with open(template_path, encoding="utf-8") as f:
template = f.read()
else:
raise ValueError
@@ -67,7 +67,7 @@ def _load_examples(config: dict) -> dict:
if isinstance(config["examples"], list):
pass
elif isinstance(config["examples"], str):
with open(config["examples"]) as f:
with open(config["examples"], encoding="utf-8") as f:
if config["examples"].endswith(".json"):
examples = json.load(f)
elif config["examples"].endswith((".yaml", ".yml")):

View File

@@ -73,13 +73,16 @@ select = [
"TID",
"UP",
"W",
"YTT"
"YTT",
"PLW1514",
]
ignore = [
"COM812", # Messes with the formatter
"UP007", # Incompatible with pydantic + Python 3.9
"W293", #
]
preview = true
explicit-preview-rules = true
[tool.coverage.run]
omit = [ "tests/*",]

View File

@@ -314,7 +314,7 @@ async def test_runnable_sequence_parallel_trace_nesting(method: str) -> None:
"other_thing": "RunnableParallel<chain_result,other_thing>",
"after": "RunnableSequence",
}
assert len(posts) == sum([1 if isinstance(n, str) else len(n) for n in name_order])
assert len(posts) == sum(1 if isinstance(n, str) else len(n) for n in name_order)
prev_dotted_order = None
dotted_order_map = {}
id_map = {}