This commit is contained in:
Erick Friis 2024-10-08 15:49:45 -07:00
parent 078f73f3f8
commit 008e6b3e41
7 changed files with 10 additions and 10 deletions

View File

@ -375,7 +375,7 @@ class BaseLanguageModel(
Returns: Returns:
The sum of the number of tokens across the messages. The sum of the number of tokens across the messages.
""" """
return sum([self.get_num_tokens(get_buffer_string([m])) for m in messages]) return sum(self.get_num_tokens(get_buffer_string([m])) for m in messages)
@classmethod @classmethod
def _all_required_field_names(cls) -> set: def _all_required_field_names(cls) -> set:

View File

@ -1387,10 +1387,10 @@ class BaseLLM(BaseLanguageModel[str], ABC):
prompt_dict = self.dict() prompt_dict = self.dict()
if save_path.suffix == ".json": if save_path.suffix == ".json":
with open(file_path, "w") as f: with open(file_path, "w", encoding="utf-8") as f:
json.dump(prompt_dict, f, indent=4) json.dump(prompt_dict, f, indent=4)
elif save_path.suffix.endswith((".yaml", ".yml")): elif save_path.suffix.endswith((".yaml", ".yml")):
with open(file_path, "w") as f: with open(file_path, "w", encoding="utf-8") as f:
yaml.dump(prompt_dict, f, default_flow_style=False) yaml.dump(prompt_dict, f, default_flow_style=False)
else: else:
msg = f"{save_path} must be json or yaml" msg = f"{save_path} must be json or yaml"

View File

@ -359,10 +359,10 @@ class BasePromptTemplate(
directory_path.mkdir(parents=True, exist_ok=True) directory_path.mkdir(parents=True, exist_ok=True)
if save_path.suffix == ".json": if save_path.suffix == ".json":
with open(file_path, "w") as f: with open(file_path, "w", encoding="utf-8") as f:
json.dump(prompt_dict, f, indent=4) json.dump(prompt_dict, f, indent=4)
elif save_path.suffix.endswith((".yaml", ".yml")): elif save_path.suffix.endswith((".yaml", ".yml")):
with open(file_path, "w") as f: with open(file_path, "w", encoding="utf-8") as f:
yaml.dump(prompt_dict, f, default_flow_style=False) yaml.dump(prompt_dict, f, default_flow_style=False)
else: else:
msg = f"{save_path} must be json or yaml" msg = f"{save_path} must be json or yaml"

View File

@ -588,7 +588,7 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate):
Returns: Returns:
A new instance of this class. A new instance of this class.
""" """
with open(str(template_file)) as f: with open(str(template_file), encoding="utf-8") as f:
template = f.read() template = f.read()
return cls.from_template(template, input_variables=input_variables, **kwargs) return cls.from_template(template, input_variables=input_variables, **kwargs)

View File

@ -53,7 +53,7 @@ def _load_template(var_name: str, config: dict) -> dict:
template_path = Path(config.pop(f"{var_name}_path")) template_path = Path(config.pop(f"{var_name}_path"))
# Load the template. # Load the template.
if template_path.suffix == ".txt": if template_path.suffix == ".txt":
with open(template_path) as f: with open(template_path, encoding="utf-8") as f:
template = f.read() template = f.read()
else: else:
raise ValueError raise ValueError
@ -67,7 +67,7 @@ def _load_examples(config: dict) -> dict:
if isinstance(config["examples"], list): if isinstance(config["examples"], list):
pass pass
elif isinstance(config["examples"], str): elif isinstance(config["examples"], str):
with open(config["examples"]) as f: with open(config["examples"], encoding="utf-8") as f:
if config["examples"].endswith(".json"): if config["examples"].endswith(".json"):
examples = json.load(f) examples = json.load(f)
elif config["examples"].endswith((".yaml", ".yml")): elif config["examples"].endswith((".yaml", ".yml")):

View File

@ -248,7 +248,7 @@ class InMemoryRateLimiter(BaseRateLimiter):
if not blocking: if not blocking:
return self._consume() return self._consume()
while not self._consume(): while not self._consume(): # noqa: ASYNC110
await asyncio.sleep(self.check_every_n_seconds) await asyncio.sleep(self.check_every_n_seconds)
return True return True

View File

@ -314,7 +314,7 @@ async def test_runnable_sequence_parallel_trace_nesting(method: str) -> None:
"other_thing": "RunnableParallel<chain_result,other_thing>", "other_thing": "RunnableParallel<chain_result,other_thing>",
"after": "RunnableSequence", "after": "RunnableSequence",
} }
assert len(posts) == sum([1 if isinstance(n, str) else len(n) for n in name_order]) assert len(posts) == sum(1 if isinstance(n, str) else len(n) for n in name_order)
prev_dotted_order = None prev_dotted_order = None
dotted_order_map = {} dotted_order_map = {}
id_map = {} id_map = {}