Compare commits


7 Commits

Author       SHA1        Message                                                    Date
Erick Friis  ed1bb3d8d0  proposal (#14729)                                          2023-12-18 13:50:39 -08:00
Erick Friis  251c81b6a9  Merge branch 'master' into bagatur/core_update_ruff_mypy   2023-12-14 11:37:51 -08:00
Bagatur      e37f5508be  more                                                       2023-12-14 11:21:27 -08:00
Bagatur      6081a99e2b  more                                                       2023-12-14 11:05:01 -08:00
Bagatur      54e00ea36e  fix                                                        2023-12-14 10:28:37 -08:00
Bagatur      c4c79faabb  fix                                                        2023-12-14 10:20:46 -08:00
Bagatur      98e090a02f  infra: use latest ruff and mypy in core                    2023-12-13 19:03:36 -08:00
33 changed files with 226 additions and 180 deletions

View File

@@ -41,7 +41,7 @@ lint lint_diff lint_package lint_tests:
poetry run ruff .
[ "$(PYTHON_FILES)" = "" ] || poetry run ruff format $(PYTHON_FILES) --diff
[ "$(PYTHON_FILES)" = "" ] || poetry run ruff --select I $(PYTHON_FILES)
[ "$(PYTHON_FILES)" = "" ] || mkdir -p $(MYPY_CACHE) || poetry run mypy $(PYTHON_FILES) --cache-dir $(MYPY_CACHE)
[ "$(PYTHON_FILES)" = "" ] || (mkdir -p $(MYPY_CACHE); poetry run mypy $(PYTHON_FILES) --cache-dir $(MYPY_CACHE))
format format_diff:
poetry run ruff format $(PYTHON_FILES)
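
Note on this Makefile hunk (the same fix recurs in the other Makefiles below): in a POSIX shell, `a || b` runs `b` only when `a` fails, so the old recipe `mkdir -p $(MYPY_CACHE) || poetry run mypy ...` skipped mypy whenever mkdir succeeded. Grouping the commands as `(mkdir -p ...; mypy ...)` runs both unconditionally. A minimal Python sketch of the corrected control flow, assuming poetry and mypy are on PATH (paths here are illustrative, not the repo's):

import subprocess

MYPY_CACHE = ".mypy_cache"         # illustrative cache directory
PYTHON_FILES = ["langchain_core"]  # illustrative target

# Equivalent of the fixed recipe: both steps always run, in order.
subprocess.run(["mkdir", "-p", MYPY_CACHE], check=False)
subprocess.run(
    ["poetry", "run", "mypy", *PYTHON_FILES, "--cache-dir", MYPY_CACHE],
    check=False,
)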

View File

@@ -2,7 +2,7 @@ import os
import warnings
from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
from uuid import UUID
from langchain_core.agents import AgentAction, AgentFinish
@@ -224,7 +224,7 @@ class LabelStudioCallbackHandler(BaseCallbackHandler):
raise ValueError(error_message)
def add_prompts_generations(
self, run_id: str, generations: List[List[Generation]]
self, run_id: str, generations: Sequence[Sequence[Generation]]
) -> None:
# Create tasks in Label Studio
tasks = []
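
The List -> Sequence widening in this hunk repeats throughout the diff (output parsers, LLMResult, ChatResult): parameters that are only read accept the widest type, so callers may pass lists or tuples, while values a function builds and returns stay concrete. A self-contained sketch of the rule, with made-up functions:

from typing import List, Sequence

def parse_first(result: Sequence[str]) -> str:
    # Read-only access: Sequence admits both lists and tuples.
    return result[0]

def flatten(batches: Sequence[Sequence[int]]) -> List[int]:
    # The return value is freshly built and owned here, so List is fine.
    return [x for batch in batches for x in batch]

print(parse_first(("a", "b")))    # a tuple now type-checks
print(flatten([[1, 2], (3, 4)]))  # outer sequence may mix list and tuple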

View File

@@ -34,7 +34,6 @@ if TYPE_CHECKING:
# This is for backwards compatibility
# We can remove after `langchain` stops importing it
_response_to_generation = None
completion_with_retry = None
stream_completion_with_retry = None

View File

@@ -42,7 +42,7 @@ lint lint_diff lint_package lint_tests:
[ "$(PYTHON_FILES)" = "" ] || poetry run ruff format $(PYTHON_FILES) --diff
[ "$(PYTHON_FILES)" = "" ] || poetry run ruff --select I $(PYTHON_FILES)
[ "$(PYTHON_FILES)" = "" ] || poetry run mypy $(PYTHON_FILES)
[ "$(PYTHON_FILES)" = "" ] || mkdir $(MYPY_CACHE) || poetry run mypy $(PYTHON_FILES) --cache-dir $(MYPY_CACHE)
[ "$(PYTHON_FILES)" = "" ] || (mkdir -p $(MYPY_CACHE); poetry run mypy $(PYTHON_FILES) --cache-dir $(MYPY_CACHE))
format format_diff:
poetry run ruff format $(PYTHON_FILES)

View File

@@ -377,7 +377,10 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
run_managers[i].on_llm_error(e, response=LLMResult(generations=[]))
raise e
flattened_outputs = [
LLMResult(generations=[res.generations], llm_output=res.llm_output)
LLMResult(
generations=[res.generations],
llm_output=res.llm_output,
)
for res in results
]
llm_output = self._combine_llm_outputs([res.llm_output for res in results])
@@ -425,7 +428,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
batch_size=len(messages),
)
results = await asyncio.gather(
results_and_exceptions = await asyncio.gather(
*[
self._agenerate_with_cache(
m,
@@ -437,42 +440,29 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
],
return_exceptions=True,
)
exceptions = []
for i, res in enumerate(results):
if isinstance(res, BaseException):
if run_managers:
await run_managers[i].on_llm_error(
res, response=LLMResult(generations=[])
)
exceptions.append(res)
if exceptions:
if run_managers:
await asyncio.gather(
*[
run_manager.on_llm_end(
LLMResult(
generations=[res.generations], llm_output=res.llm_output
)
)
for run_manager, res in zip(run_managers, results)
if not isinstance(res, Exception)
]
)
raise exceptions[0]
flattened_outputs = [
LLMResult(generations=[res.generations], llm_output=res.llm_output)
for res in results
]
llm_output = self._combine_llm_outputs([res.llm_output for res in results])
generations = [res.generations for res in results]
output = LLMResult(generations=generations, llm_output=llm_output)
await asyncio.gather(
*[
run_manager.on_llm_end(flattened_output)
for run_manager, flattened_output in zip(
run_managers, flattened_outputs
# report results and errors
if run_managers:
jobs = [
run_manager.on_llm_error(res, response=LLMResult(generations=[]))
if isinstance(res, BaseException)
else run_manager.on_llm_end(
LLMResult(generations=[res.generations], llm_output=res.llm_output)
)
for run_manager, res in zip(run_managers, results_and_exceptions)
]
await asyncio.gather(*jobs)
# raise first exception, if any
for res in results_and_exceptions:
if isinstance(res, BaseException):
raise res
# compute return value
results = cast(List[ChatResult], results_and_exceptions)
output = LLMResult(
generations=[res.generations for res in results],
llm_output=self._combine_llm_outputs([res.llm_output for res in results]),
)
if run_managers:
output.run = [
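
A standalone sketch of the refactor above, with made-up task names rather than LangChain's API: gather results and exceptions in one pass with return_exceptions=True, report each one, then re-raise the first failure.

import asyncio
from typing import List, Union

async def work(i: int) -> int:
    if i == 2:
        raise ValueError("boom")
    return i * 10

async def main() -> None:
    # return_exceptions=True yields a mixed list of results and exceptions.
    results: List[Union[int, BaseException]] = await asyncio.gather(
        *(work(i) for i in range(4)), return_exceptions=True
    )
    # Report results and errors (mirrors the on_llm_end / on_llm_error split).
    for i, res in enumerate(results):
        if isinstance(res, BaseException):
            print(f"task {i} failed: {res}")
        else:
            print(f"task {i} ok: {res}")
    # Raise the first exception, if any.
    for res in results:
        if isinstance(res, BaseException):
            raise res

try:
    asyncio.run(main())
except ValueError as exc:
    print(f"re-raised: {exc}")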

View File

@@ -143,7 +143,7 @@ def update_cache(
"""Update the cache and get the LLM output."""
llm_cache = get_llm_cache()
for i, result in enumerate(new_results.generations):
existing_prompts[missing_prompt_idxs[i]] = result
existing_prompts[missing_prompt_idxs[i]] = list(result)
prompt = prompts[missing_prompt_idxs[i]]
if llm_cache is not None:
llm_cache.update(prompt, llm_string, result)
@@ -819,7 +819,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
raise ValueError(
"Asked to cache, but no cache found at `langchain.cache`."
)
run_managers = await asyncio.gather(
run_managers_gather = await asyncio.gather(
*[
callback_manager.on_llm_start(
dumpd(self),
@@ -834,13 +834,13 @@ class BaseLLM(BaseLanguageModel[str], ABC):
)
]
)
run_managers = [r[0] for r in run_managers]
run_managers = [r[0] for r in run_managers_gather]
output = await self._agenerate_helper(
prompts, stop, run_managers, bool(new_arg_supported), **kwargs
)
return output
if len(missing_prompts) > 0:
run_managers = await asyncio.gather(
run_managers_gather = await asyncio.gather(
*[
callback_managers[idx].on_llm_start(
dumpd(self),
@@ -853,7 +853,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
for idx in missing_prompt_idxs
]
)
run_managers = [r[0] for r in run_managers]
run_managers = [r[0] for r in run_managers_gather]
new_results = await self._agenerate_helper(
missing_prompts, stop, run_managers, bool(new_arg_supported), **kwargs
)
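
The run_managers -> run_managers_gather renames above look cosmetic but appease mypy: it infers one type per local, so reusing the same name first for the gathered tuples and then for the unpacked managers produces a type conflict. A toy sketch with made-up data:

from typing import List, Tuple

pairs: List[Tuple[int, str]] = [(1, "a"), (2, "b")]
# A distinct name for each binding keeps the inferred types consistent.
firsts: List[int] = [pair[0] for pair in pairs]
print(firsts)  # [1, 2]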

View File

@@ -39,7 +39,7 @@ class BaseMessage(Serializable):
def __add__(self, other: Any) -> ChatPromptTemplate:
from langchain_core.prompts.chat import ChatPromptTemplate
prompt = ChatPromptTemplate(messages=[self])
prompt = ChatPromptTemplate(messages=[self], input_variables=[])
return prompt + other
@@ -136,6 +136,7 @@ class BaseMessageChunk(BaseMessage):
additional_kwargs=self._merge_kwargs_dict(
self.additional_kwargs, other.additional_kwargs
),
type=self.type,
)
else:
raise TypeError(

View File

@@ -8,8 +8,8 @@ from typing import (
Any,
Dict,
Generic,
List,
Optional,
Sequence,
Type,
TypeVar,
Union,
@@ -31,7 +31,7 @@ class BaseLLMOutputParser(Generic[T], ABC):
"""Abstract base class for parsing the outputs of a model."""
@abstractmethod
def parse_result(self, result: List[Generation], *, partial: bool = False) -> T:
def parse_result(self, result: Sequence[Generation], *, partial: bool = False) -> T:
"""Parse a list of candidate model Generations into a specific format.
Args:
@@ -43,7 +43,7 @@ class BaseLLMOutputParser(Generic[T], ABC):
"""
async def aparse_result(
self, result: List[Generation], *, partial: bool = False
self, result: Sequence[Generation], *, partial: bool = False
) -> T:
"""Parse a list of candidate model Generations into a specific format.
@@ -206,7 +206,7 @@ class BaseOutputParser(
run_type="parser",
)
def parse_result(self, result: List[Generation], *, partial: bool = False) -> T:
def parse_result(self, result: Sequence[Generation], *, partial: bool = False) -> T:
"""Parse a list of candidate model Generations into a specific format.
The return value is parsed from only the first Generation in the result, which
@@ -233,7 +233,7 @@ class BaseOutputParser(
"""
async def aparse_result(
self, result: List[Generation], *, partial: bool = False
self, result: Sequence[Generation], *, partial: bool = False
) -> T:
"""Parse a list of candidate model Generations into a specific format.

View File

@@ -18,7 +18,7 @@ class ChatGeneration(Generation):
type: Literal["ChatGeneration"] = "ChatGeneration" # type: ignore[assignment]
"""Type is used exclusively for serialization purposes."""
@root_validator
@root_validator()
def set_text(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""Set the text attribute to be the contents of the message."""
try:

View File

@@ -1,15 +1,20 @@
from typing import List, Optional
from typing import Optional, Sequence
from langchain_core.outputs.chat_generation import ChatGeneration
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.pydantic_v1 import BaseModel, root_validator
class ChatResult(BaseModel):
"""Class that contains all results for a single chat model call."""
generations: List[ChatGeneration]
generations: Sequence[ChatGeneration]
"""List of the chat generations. This is a List because an input can have multiple
candidate generations.
"""
llm_output: Optional[dict] = None
"""For arbitrary LLM provider specific output."""
@root_validator(pre=True)
def validate_environment(cls, values: dict) -> dict:
values["generations"] = list(values.get("generations", ()))
return values
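
Pattern note for this hunk (LLMResult below gets the same treatment): the field is declared Sequence so construction stays flexible, and a pre=True root validator normalizes the input to plain lists so downstream code can index and mutate. A minimal sketch using pydantic v1's API, which langchain_core.pydantic_v1 mirrors:

from typing import Sequence
from pydantic import BaseModel, root_validator  # pydantic v1 API

class Result(BaseModel):
    generations: Sequence[Sequence[str]]

    @root_validator(pre=True)
    def _coerce_generations(cls, values: dict) -> dict:
        # Accept tuples (or any sequence) and store lists.
        values["generations"] = [list(g) for g in values.get("generations", ())]
        return values

print(Result(generations=(("a", "b"), ("c",))).generations)  # [['a', 'b'], ['c']]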

View File

@@ -1,17 +1,17 @@
from __future__ import annotations
from copy import deepcopy
from typing import List, Optional
from typing import List, Optional, Sequence
from langchain_core.outputs.generation import Generation
from langchain_core.outputs.run_info import RunInfo
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.pydantic_v1 import BaseModel, root_validator
class LLMResult(BaseModel):
"""Class that contains all results for a batched LLM call."""
generations: List[List[Generation]]
generations: Sequence[Sequence[Generation]]
"""List of generated outputs. This is a List[List[]] because
each input could have multiple candidate generations."""
llm_output: Optional[dict] = None
@@ -19,6 +19,11 @@ class LLMResult(BaseModel):
run: Optional[List[RunInfo]] = None
"""List of metadata info for model call for each input."""
@root_validator(pre=True)
def validate_environment(cls, values: dict) -> dict:
values["generations"] = [list(g) for g in values.get("generations", ())]
return values
def flatten(self) -> List[LLMResult]:
"""Flatten generations into a single list.

View File

@@ -40,7 +40,7 @@ class BasePromptTemplate(RunnableSerializable[Dict, PromptValue], ABC):
If not provided, all variables are assumed to be strings."""
output_parser: Optional[BaseOutputParser] = None
"""How to parse the output of calling an LLM on this formatted prompt."""
partial_variables: Mapping[str, Union[str, Callable[[], str]]] = Field(
partial_variables: Mapping[str, Union[Any, Callable[[], Any]]] = Field(
default_factory=dict
)
@@ -137,8 +137,7 @@ class BasePromptTemplate(RunnableSerializable[Dict, PromptValue], ABC):
def _merge_partial_and_user_variables(self, **kwargs: Any) -> Dict[str, Any]:
# Get partial params:
partial_kwargs = {
k: v if isinstance(v, str) else v()
for k, v in self.partial_variables.items()
k: v() if callable(v) else v for k, v in self.partial_variables.items()
}
return {**partial_kwargs, **kwargs}
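
Behavior change in this hunk: the old code called every non-str partial (so a plain int value would be invoked and crash), while the new code calls a value only if it is callable and passes everything else through. A small self-contained sketch of the widened logic:

from typing import Any, Callable, Dict, Mapping, Union

def merge_partials(
    partials: Mapping[str, Union[Any, Callable[[], Any]]], **kwargs: Any
) -> Dict[str, Any]:
    # Invoke only actual callables; other values pass through unchanged.
    partial_kwargs = {k: v() if callable(v) else v for k, v in partials.items()}
    return {**partial_kwargs, **kwargs}

print(merge_partials({"n": 3, "today": lambda: "2023-12-18"}, user="x"))
# {'n': 3, 'today': '2023-12-18', 'user': 'x'}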

View File

@@ -8,6 +8,7 @@ from typing import (
Callable,
Dict,
List,
Literal,
Optional,
Sequence,
Set,
@@ -151,7 +152,7 @@ class BaseStringMessagePromptTemplate(BaseMessagePromptTemplate, ABC):
def from_template(
cls: Type[MessagePromptTemplateT],
template: str,
template_format: str = "f-string",
template_format: Literal["f-string", "jinja2"] = "f-string",
partial_variables: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> MessagePromptTemplateT:
@@ -396,9 +397,9 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
)
"""
input_variables: List[str]
input_variables: List[str] = Field(default_factory=list)
"""List of input variables in template messages. Used for validation."""
messages: List[MessageLike]
messages: Sequence[MessageLike]
"""List of messages consisting of either message prompt templates or messages."""
validate_template: bool = False
"""Whether or not to try validating the template."""
@@ -418,18 +419,19 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
Combined prompt template.
"""
# Allow for easy combining
messages = list(self.messages)
if isinstance(other, ChatPromptTemplate):
return ChatPromptTemplate(messages=self.messages + other.messages)
return ChatPromptTemplate(messages=messages + list(other.messages))
elif isinstance(
other, (BaseMessagePromptTemplate, BaseMessage, BaseChatPromptTemplate)
):
return ChatPromptTemplate(messages=self.messages + [other])
return ChatPromptTemplate(messages=messages + [other])
elif isinstance(other, (list, tuple)):
_other = ChatPromptTemplate.from_messages(other)
return ChatPromptTemplate(messages=self.messages + _other.messages)
return ChatPromptTemplate(messages=messages + list(_other.messages))
elif isinstance(other, str):
prompt = HumanMessagePromptTemplate.from_template(other)
return ChatPromptTemplate(messages=self.messages + [prompt])
return ChatPromptTemplate(messages=messages + [prompt])
else:
raise NotImplementedError(f"Unsupported operand type for +: {type(other)}")
@@ -446,7 +448,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
Returns:
Validated values.
"""
messages = values["messages"]
messages = list(values["messages"])
input_vars = set()
input_types: Dict[str, Any] = values.get("input_types", {})
for message in messages:
@@ -656,11 +658,13 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
Args:
message: representation of a message to append.
"""
self.messages.append(_convert_to_message(message))
self.messages = list(self.messages) + [_convert_to_message(message)]
def extend(self, messages: Sequence[MessageLikeRepresentation]) -> None:
"""Extend the chat template with a sequence of messages."""
self.messages.extend([_convert_to_message(message) for message in messages])
self.messages = list(self.messages) + [
_convert_to_message(message) for message in messages
]
@overload
def __getitem__(self, index: int) -> MessageLike:
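
The append/extend rewrites above switch from in-place mutation to copy-then-rebind: once `messages` is typed Sequence, it may legally be a tuple, which has no .append. A hedged, minimal sketch of the idea:

from typing import Sequence

class Template:
    def __init__(self, messages: Sequence[str]) -> None:
        self.messages: Sequence[str] = messages

    def append(self, message: str) -> None:
        # Rebuild as a list and rebind; works for list and tuple alike.
        self.messages = list(self.messages) + [message]

t = Template(("system", "human"))  # a tuple is now acceptable
t.append("ai")
print(t.messages)  # ['system', 'human', 'ai']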

View File

@@ -3,7 +3,7 @@ from typing import Any, Dict, List, Tuple
from langchain_core.prompt_values import PromptValue
from langchain_core.prompts.base import BasePromptTemplate
from langchain_core.prompts.chat import BaseChatPromptTemplate
from langchain_core.pydantic_v1 import root_validator
from langchain_core.pydantic_v1 import Field, root_validator
def _get_inputs(inputs: dict, input_variables: List[str]) -> dict:
@@ -27,6 +27,8 @@ class PipelinePromptTemplate(BasePromptTemplate):
"""The final prompt that is returned."""
pipeline_prompts: List[Tuple[str, BasePromptTemplate]]
"""A list of tuples, consisting of a string (`name`) and a Prompt Template."""
input_variables: List[str] = Field(default_factory=list)
"""A list of the names of the variables the prompt template expects."""
@classmethod
def get_lc_namespace(cls) -> List[str]:

View File

@@ -212,7 +212,7 @@ class PromptTemplate(StringPromptTemplate):
cls,
template: str,
*,
template_format: str = "f-string",
template_format: Literal["f-string", "jinja2"] = "f-string",
partial_variables: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> PromptTemplate:
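
Both from_template signatures narrow template_format from str to Literal["f-string", "jinja2"], so an invalid format name is caught by mypy at the call site instead of at runtime. A toy sketch (the rendering itself is made up):

from typing import Literal

def render(template: str, fmt: Literal["f-string", "jinja2"] = "f-string") -> str:
    if fmt == "f-string":
        return template.format(name="world")
    return template  # jinja2 branch elided in this sketch

print(render("hello {name}"))
# render("hi", fmt="bar")  # mypy: "bar" is not a valid Literal value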

View File

@@ -2245,12 +2245,12 @@ class RunnableGenerator(Runnable[Input, Output]):
] = None,
) -> None:
if atransform is not None:
self._atransform = atransform
self._atransform: Callable = atransform
if inspect.isasyncgenfunction(transform):
self._atransform = transform
elif inspect.isgeneratorfunction(transform):
self._transform = transform
self._transform: Callable = transform
else:
raise TypeError(
"Expected a generator function type for `transform`."

View File

@@ -382,6 +382,7 @@ class RunnableConfigurableAlternatives(DynamicRunnable[Input, Output]):
which=self.which,
default=self.default.configurable_fields(**kwargs),
alternatives=self.alternatives,
prefix_keys=self.prefix_keys,
)
def _prepare(

View File

@@ -7,7 +7,18 @@ import warnings
from abc import abstractmethod
from functools import partial
from inspect import signature
from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple, Type, Union
from typing import (
Any,
Awaitable,
Callable,
Dict,
List,
Optional,
Tuple,
Type,
Union,
cast,
)
from langchain_core.callbacks import (
AsyncCallbackManager,
@@ -547,7 +558,8 @@ class Tool(BaseTool):
self, name: str, func: Optional[Callable], description: str, **kwargs: Any
) -> None:
"""Initialize tool."""
super(Tool, self).__init__(
# TODO: fix typing issue
super(Tool, self).__init__( # type: ignore[call-arg]
name=name, func=func, description=description, **kwargs
)
@@ -722,7 +734,7 @@ class StructuredTool(BaseTool):
name=name,
func=func,
coroutine=coroutine,
args_schema=_args_schema,
args_schema=cast(Type[BaseModel], _args_schema),
description=description,
return_direct=return_direct,
**kwargs,

View File

@@ -133,7 +133,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
child_execution_order=execution_order,
run_type="llm",
tags=tags or [],
name=name,
name=name, # type: ignore[arg-type] # TODO: Fix typing
)
self._start_trace(llm_run)
self._on_llm_start(llm_run)
@@ -258,7 +258,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
child_execution_order=execution_order,
child_runs=[],
run_type=run_type or "chain",
name=name,
name=name, # type: ignore[arg-type] # TODO: fix typing
tags=tags or [],
)
self._start_trace(chain_run)
@@ -336,7 +336,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
child_runs=[],
run_type="tool",
tags=tags or [],
name=name,
name=name, # type: ignore[arg-type] # TODO: fix typing
)
self._start_trace(tool_run)
self._on_tool_start(tool_run)

View File

@@ -131,7 +131,7 @@ class LangChainTracer(BaseTracer):
child_execution_order=execution_order,
run_type="llm",
tags=tags,
name=name,
name=name, # type: ignore[arg-type] # TODO: fix typing
)
self._start_trace(chat_model_run)
self._on_chat_model_start(chat_model_run)

View File

@@ -1,8 +1,9 @@
from __future__ import annotations
import datetime
import logging
import os
from typing import Any, Dict, Optional, Union
from typing import Any, Dict, Optional, Union, cast
import requests
@@ -60,29 +61,30 @@ class LangChainTracerV1(BaseTracer):
else:
raise ValueError("No prompts found in LLM run inputs")
return LLMRun(
uuid=str(run.id) if run.id else None,
uuid=str(run.id) if run.id else "",
parent_uuid=str(run.parent_run_id) if run.parent_run_id else None,
start_time=run.start_time,
end_time=run.end_time,
end_time=cast(datetime.datetime, run.end_time),
extra=run.extra,
execution_order=run.execution_order,
child_execution_order=run.child_execution_order,
serialized=run.serialized,
serialized=run.serialized or {},
session_id=session.id,
error=run.error,
prompts=prompts,
response=run.outputs if run.outputs else None,
# TODO: Fix type error.
response=run.outputs if run.outputs else None, # type: ignore[arg-type]
)
if run.run_type == "chain":
child_runs = [self._convert_to_v1_run(run) for run in run.child_runs]
return ChainRun(
uuid=str(run.id) if run.id else None,
uuid=str(run.id) if run.id else "",
parent_uuid=str(run.parent_run_id) if run.parent_run_id else None,
start_time=run.start_time,
end_time=run.end_time,
end_time=cast(datetime.datetime, run.end_time),
execution_order=run.execution_order,
child_execution_order=run.child_execution_order,
serialized=run.serialized,
serialized=run.serialized or {},
session_id=session.id,
inputs=run.inputs,
outputs=run.outputs,
@@ -97,13 +99,13 @@ class LangChainTracerV1(BaseTracer):
if run.run_type == "tool":
child_runs = [self._convert_to_v1_run(run) for run in run.child_runs]
return ToolRun(
uuid=str(run.id) if run.id else None,
uuid=str(run.id) if run.id else "",
parent_uuid=str(run.parent_run_id) if run.parent_run_id else None,
start_time=run.start_time,
end_time=run.end_time,
end_time=cast(datetime.datetime, run.end_time),
execution_order=run.execution_order,
child_execution_order=run.child_execution_order,
serialized=run.serialized,
serialized=run.serialized or {},
session_id=session.id,
action=str(run.serialized),
tool_input=run.inputs.get("input", ""),
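
The cast(datetime.datetime, run.end_time) calls above silence mypy's Optional complaints without adding a runtime check; cast is a no-op at runtime, which is why nearby sites still carry TODO comments. A minimal sketch of what cast does (and does not do):

import datetime
from typing import Optional, cast

def finish_time(end_time: Optional[datetime.datetime]) -> datetime.datetime:
    # cast only informs the type checker; it never validates, so a None
    # passed here would flow through despite the declared return type.
    return cast(datetime.datetime, end_time)

print(finish_time(datetime.datetime(2023, 12, 14, 10, 28)))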

View File

@@ -8,7 +8,7 @@ class StrictFormatter(Formatter):
def check_unused_args(
self,
used_args: Sequence[Union[int, str]],
used_args: Sequence[Union[int, str]], # type: ignore[override] # TODO: fix
args: Sequence,
kwargs: Mapping[str, Any],
) -> None:

libs/core/poetry.lock (generated; 105 lines changed)
View File

@@ -1216,52 +1216,49 @@ files = [
[[package]]
name = "mypy"
version = "0.991"
version = "1.7.1"
description = "Optional static typing for Python"
optional = false
python-versions = ">=3.7"
python-versions = ">=3.8"
files = [
{file = "mypy-0.991-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:7d17e0a9707d0772f4a7b878f04b4fd11f6f5bcb9b3813975a9b13c9332153ab"},
{file = "mypy-0.991-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:0714258640194d75677e86c786e80ccf294972cc76885d3ebbb560f11db0003d"},
{file = "mypy-0.991-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:0c8f3be99e8a8bd403caa8c03be619544bc2c77a7093685dcf308c6b109426c6"},
{file = "mypy-0.991-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc9ec663ed6c8f15f4ae9d3c04c989b744436c16d26580eaa760ae9dd5d662eb"},
{file = "mypy-0.991-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:4307270436fd7694b41f913eb09210faff27ea4979ecbcd849e57d2da2f65305"},
{file = "mypy-0.991-cp310-cp310-win_amd64.whl", hash = "sha256:901c2c269c616e6cb0998b33d4adbb4a6af0ac4ce5cd078afd7bc95830e62c1c"},
{file = "mypy-0.991-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:d13674f3fb73805ba0c45eb6c0c3053d218aa1f7abead6e446d474529aafc372"},
{file = "mypy-0.991-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:1c8cd4fb70e8584ca1ed5805cbc7c017a3d1a29fb450621089ffed3e99d1857f"},
{file = "mypy-0.991-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:209ee89fbb0deed518605edddd234af80506aec932ad28d73c08f1400ef80a33"},
{file = "mypy-0.991-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:37bd02ebf9d10e05b00d71302d2c2e6ca333e6c2a8584a98c00e038db8121f05"},
{file = "mypy-0.991-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:26efb2fcc6b67e4d5a55561f39176821d2adf88f2745ddc72751b7890f3194ad"},
{file = "mypy-0.991-cp311-cp311-win_amd64.whl", hash = "sha256:3a700330b567114b673cf8ee7388e949f843b356a73b5ab22dd7cff4742a5297"},
{file = "mypy-0.991-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:1f7d1a520373e2272b10796c3ff721ea1a0712288cafaa95931e66aa15798813"},
{file = "mypy-0.991-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:641411733b127c3e0dab94c45af15fea99e4468f99ac88b39efb1ad677da5711"},
{file = "mypy-0.991-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:3d80e36b7d7a9259b740be6d8d906221789b0d836201af4234093cae89ced0cd"},
{file = "mypy-0.991-cp37-cp37m-win_amd64.whl", hash = "sha256:e62ebaad93be3ad1a828a11e90f0e76f15449371ffeecca4a0a0b9adc99abcef"},
{file = "mypy-0.991-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:b86ce2c1866a748c0f6faca5232059f881cda6dda2a893b9a8373353cfe3715a"},
{file = "mypy-0.991-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:ac6e503823143464538efda0e8e356d871557ef60ccd38f8824a4257acc18d93"},
{file = "mypy-0.991-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:0cca5adf694af539aeaa6ac633a7afe9bbd760df9d31be55ab780b77ab5ae8bf"},
{file = "mypy-0.991-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a12c56bf73cdab116df96e4ff39610b92a348cc99a1307e1da3c3768bbb5b135"},
{file = "mypy-0.991-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:652b651d42f155033a1967739788c436491b577b6a44e4c39fb340d0ee7f0d70"},
{file = "mypy-0.991-cp38-cp38-win_amd64.whl", hash = "sha256:4175593dc25d9da12f7de8de873a33f9b2b8bdb4e827a7cae952e5b1a342e243"},
{file = "mypy-0.991-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:98e781cd35c0acf33eb0295e8b9c55cdbef64fcb35f6d3aa2186f289bed6e80d"},
{file = "mypy-0.991-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:6d7464bac72a85cb3491c7e92b5b62f3dcccb8af26826257760a552a5e244aa5"},
{file = "mypy-0.991-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:c9166b3f81a10cdf9b49f2d594b21b31adadb3d5e9db9b834866c3258b695be3"},
{file = "mypy-0.991-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b8472f736a5bfb159a5e36740847808f6f5b659960115ff29c7cecec1741c648"},
{file = "mypy-0.991-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:5e80e758243b97b618cdf22004beb09e8a2de1af481382e4d84bc52152d1c476"},
{file = "mypy-0.991-cp39-cp39-win_amd64.whl", hash = "sha256:74e259b5c19f70d35fcc1ad3d56499065c601dfe94ff67ae48b85596b9ec1461"},
{file = "mypy-0.991-py3-none-any.whl", hash = "sha256:de32edc9b0a7e67c2775e574cb061a537660e51210fbf6006b0b36ea695ae9bb"},
{file = "mypy-0.991.tar.gz", hash = "sha256:3c0165ba8f354a6d9881809ef29f1a9318a236a6d81c690094c5df32107bde06"},
{file = "mypy-1.7.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:12cce78e329838d70a204293e7b29af9faa3ab14899aec397798a4b41be7f340"},
{file = "mypy-1.7.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:1484b8fa2c10adf4474f016e09d7a159602f3239075c7bf9f1627f5acf40ad49"},
{file = "mypy-1.7.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:31902408f4bf54108bbfb2e35369877c01c95adc6192958684473658c322c8a5"},
{file = "mypy-1.7.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:f2c2521a8e4d6d769e3234350ba7b65ff5d527137cdcde13ff4d99114b0c8e7d"},
{file = "mypy-1.7.1-cp310-cp310-win_amd64.whl", hash = "sha256:fcd2572dd4519e8a6642b733cd3a8cfc1ef94bafd0c1ceed9c94fe736cb65b6a"},
{file = "mypy-1.7.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:4b901927f16224d0d143b925ce9a4e6b3a758010673eeded9b748f250cf4e8f7"},
{file = "mypy-1.7.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:2f7f6985d05a4e3ce8255396df363046c28bea790e40617654e91ed580ca7c51"},
{file = "mypy-1.7.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:944bdc21ebd620eafefc090cdf83158393ec2b1391578359776c00de00e8907a"},
{file = "mypy-1.7.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:9c7ac372232c928fff0645d85f273a726970c014749b924ce5710d7d89763a28"},
{file = "mypy-1.7.1-cp311-cp311-win_amd64.whl", hash = "sha256:f6efc9bd72258f89a3816e3a98c09d36f079c223aa345c659622f056b760ab42"},
{file = "mypy-1.7.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:6dbdec441c60699288adf051f51a5d512b0d818526d1dcfff5a41f8cd8b4aaf1"},
{file = "mypy-1.7.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:4fc3d14ee80cd22367caaaf6e014494415bf440980a3045bf5045b525680ac33"},
{file = "mypy-1.7.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2c6e4464ed5f01dc44dc9821caf67b60a4e5c3b04278286a85c067010653a0eb"},
{file = "mypy-1.7.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:d9b338c19fa2412f76e17525c1b4f2c687a55b156320acb588df79f2e6fa9fea"},
{file = "mypy-1.7.1-cp312-cp312-win_amd64.whl", hash = "sha256:204e0d6de5fd2317394a4eff62065614c4892d5a4d1a7ee55b765d7a3d9e3f82"},
{file = "mypy-1.7.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:84860e06ba363d9c0eeabd45ac0fde4b903ad7aa4f93cd8b648385a888e23200"},
{file = "mypy-1.7.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:8c5091ebd294f7628eb25ea554852a52058ac81472c921150e3a61cdd68f75a7"},
{file = "mypy-1.7.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:40716d1f821b89838589e5b3106ebbc23636ffdef5abc31f7cd0266db936067e"},
{file = "mypy-1.7.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:5cf3f0c5ac72139797953bd50bc6c95ac13075e62dbfcc923571180bebb662e9"},
{file = "mypy-1.7.1-cp38-cp38-win_amd64.whl", hash = "sha256:78e25b2fd6cbb55ddfb8058417df193f0129cad5f4ee75d1502248e588d9e0d7"},
{file = "mypy-1.7.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:75c4d2a6effd015786c87774e04331b6da863fc3fc4e8adfc3b40aa55ab516fe"},
{file = "mypy-1.7.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:2643d145af5292ee956aa0a83c2ce1038a3bdb26e033dadeb2f7066fb0c9abce"},
{file = "mypy-1.7.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:75aa828610b67462ffe3057d4d8a4112105ed211596b750b53cbfe182f44777a"},
{file = "mypy-1.7.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:ee5d62d28b854eb61889cde4e1dbc10fbaa5560cb39780c3995f6737f7e82120"},
{file = "mypy-1.7.1-cp39-cp39-win_amd64.whl", hash = "sha256:72cf32ce7dd3562373f78bd751f73c96cfb441de147cc2448a92c1a308bd0ca6"},
{file = "mypy-1.7.1-py3-none-any.whl", hash = "sha256:f7c5d642db47376a0cc130f0de6d055056e010debdaf0707cd2b0fc7e7ef30ea"},
{file = "mypy-1.7.1.tar.gz", hash = "sha256:fcb6d9afb1b6208b4c712af0dafdc650f518836065df0d4fb1d800f5d6773db2"},
]
[package.dependencies]
mypy-extensions = ">=0.4.3"
mypy-extensions = ">=1.0.0"
tomli = {version = ">=1.1.0", markers = "python_version < \"3.11\""}
typing-extensions = ">=3.10"
typing-extensions = ">=4.1.0"
[package.extras]
dmypy = ["psutil (>=4.0)"]
install-types = ["pip"]
python2 = ["typed-ast (>=1.4.0,<2)"]
mypyc = ["setuptools (>=50)"]
reports = ["lxml"]
[[package]]
@@ -2287,28 +2284,28 @@ files = [
[[package]]
name = "ruff"
version = "0.1.6"
version = "0.1.8"
description = "An extremely fast Python linter and code formatter, written in Rust."
optional = false
python-versions = ">=3.7"
files = [
{file = "ruff-0.1.6-py3-none-macosx_10_12_x86_64.macosx_11_0_arm64.macosx_10_12_universal2.whl", hash = "sha256:88b8cdf6abf98130991cbc9f6438f35f6e8d41a02622cc5ee130a02a0ed28703"},
{file = "ruff-0.1.6-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:5c549ed437680b6105a1299d2cd30e4964211606eeb48a0ff7a93ef70b902248"},
{file = "ruff-0.1.6-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1cf5f701062e294f2167e66d11b092bba7af6a057668ed618a9253e1e90cfd76"},
{file = "ruff-0.1.6-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:05991ee20d4ac4bb78385360c684e4b417edd971030ab12a4fbd075ff535050e"},
{file = "ruff-0.1.6-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:87455a0c1f739b3c069e2f4c43b66479a54dea0276dd5d4d67b091265f6fd1dc"},
{file = "ruff-0.1.6-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:683aa5bdda5a48cb8266fcde8eea2a6af4e5700a392c56ea5fb5f0d4bfdc0240"},
{file = "ruff-0.1.6-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:137852105586dcbf80c1717facb6781555c4e99f520c9c827bd414fac67ddfb6"},
{file = "ruff-0.1.6-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:bd98138a98d48a1c36c394fd6b84cd943ac92a08278aa8ac8c0fdefcf7138f35"},
{file = "ruff-0.1.6-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3a0cd909d25f227ac5c36d4e7e681577275fb74ba3b11d288aff7ec47e3ae745"},
{file = "ruff-0.1.6-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:e8fd1c62a47aa88a02707b5dd20c5ff20d035d634aa74826b42a1da77861b5ff"},
{file = "ruff-0.1.6-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:fd89b45d374935829134a082617954120d7a1470a9f0ec0e7f3ead983edc48cc"},
{file = "ruff-0.1.6-py3-none-musllinux_1_2_i686.whl", hash = "sha256:491262006e92f825b145cd1e52948073c56560243b55fb3b4ecb142f6f0e9543"},
{file = "ruff-0.1.6-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:ea284789861b8b5ca9d5443591a92a397ac183d4351882ab52f6296b4fdd5462"},
{file = "ruff-0.1.6-py3-none-win32.whl", hash = "sha256:1610e14750826dfc207ccbcdd7331b6bd285607d4181df9c1c6ae26646d6848a"},
{file = "ruff-0.1.6-py3-none-win_amd64.whl", hash = "sha256:4558b3e178145491e9bc3b2ee3c4b42f19d19384eaa5c59d10acf6e8f8b57e33"},
{file = "ruff-0.1.6-py3-none-win_arm64.whl", hash = "sha256:03910e81df0d8db0e30050725a5802441c2022ea3ae4fe0609b76081731accbc"},
{file = "ruff-0.1.6.tar.gz", hash = "sha256:1b09f29b16c6ead5ea6b097ef2764b42372aebe363722f1605ecbcd2b9207184"},
{file = "ruff-0.1.8-py3-none-macosx_10_12_x86_64.macosx_11_0_arm64.macosx_10_12_universal2.whl", hash = "sha256:7de792582f6e490ae6aef36a58d85df9f7a0cfd1b0d4fe6b4fb51803a3ac96fa"},
{file = "ruff-0.1.8-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:c8e3255afd186c142eef4ec400d7826134f028a85da2146102a1172ecc7c3696"},
{file = "ruff-0.1.8-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ff78a7583020da124dd0deb835ece1d87bb91762d40c514ee9b67a087940528b"},
{file = "ruff-0.1.8-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:bd8ee69b02e7bdefe1e5da2d5b6eaaddcf4f90859f00281b2333c0e3a0cc9cd6"},
{file = "ruff-0.1.8-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a05b0ddd7ea25495e4115a43125e8a7ebed0aa043c3d432de7e7d6e8e8cd6448"},
{file = "ruff-0.1.8-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:e6f08ca730f4dc1b76b473bdf30b1b37d42da379202a059eae54ec7fc1fbcfed"},
{file = "ruff-0.1.8-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f35960b02df6b827c1b903091bb14f4b003f6cf102705efc4ce78132a0aa5af3"},
{file = "ruff-0.1.8-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7d076717c67b34c162da7c1a5bda16ffc205e0e0072c03745275e7eab888719f"},
{file = "ruff-0.1.8-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b6a21ab023124eafb7cef6d038f835cb1155cd5ea798edd8d9eb2f8b84be07d9"},
{file = "ruff-0.1.8-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:ce697c463458555027dfb194cb96d26608abab920fa85213deb5edf26e026664"},
{file = "ruff-0.1.8-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:db6cedd9ffed55548ab313ad718bc34582d394e27a7875b4b952c2d29c001b26"},
{file = "ruff-0.1.8-py3-none-musllinux_1_2_i686.whl", hash = "sha256:05ffe9dbd278965271252704eddb97b4384bf58b971054d517decfbf8c523f05"},
{file = "ruff-0.1.8-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:5daaeaf00ae3c1efec9742ff294b06c3a2a9db8d3db51ee4851c12ad385cda30"},
{file = "ruff-0.1.8-py3-none-win32.whl", hash = "sha256:e49fbdfe257fa41e5c9e13c79b9e79a23a79bd0e40b9314bc53840f520c2c0b3"},
{file = "ruff-0.1.8-py3-none-win_amd64.whl", hash = "sha256:f41f692f1691ad87f51708b823af4bb2c5c87c9248ddd3191c8f088e66ce590a"},
{file = "ruff-0.1.8-py3-none-win_arm64.whl", hash = "sha256:aa8ee4f8440023b0a6c3707f76cadce8657553655dcbb5fc9b2f9bb9bee389f6"},
{file = "ruff-0.1.8.tar.gz", hash = "sha256:f7ee467677467526cfe135eab86a40a0e8db43117936ac4f9b469ce9cdb3fb62"},
]
[[package]]
@@ -2734,4 +2731,4 @@ extended-testing = ["jinja2"]
[metadata]
lock-version = "2.0"
python-versions = ">=3.8.1,<4.0"
content-hash = "64fa7ef31713835d12d5213f04b52adf7423299d023f9558b8b4e65ce1e5262f"
content-hash = "ef8b5b1a303d9cd9b5740991c9ae5d37727663548373324565312a1d33cfda5b"

View File

@@ -24,15 +24,15 @@ jinja2 = {version = "^3", optional = true}
optional = true
[tool.poetry.group.lint.dependencies]
ruff = "^0.1.5"
ruff = "^0.1.8"
[tool.poetry.group.typing]
optional = true
[tool.poetry.group.typing.dependencies]
mypy = "^0.991"
types-pyyaml = "^6.0.12.2"
types-requests = "^2.28.11.5"
mypy = "^1.7.1"
types-pyyaml = ">=5.3"
types-requests = "^2"
types-jinja2 = "^2.11.9"
[tool.poetry.group.dev]

View File

@@ -118,8 +118,11 @@ def test_prompt_invalid_template_format() -> None:
template = "This is a {foo} test."
input_variables = ["foo"]
with pytest.raises(ValueError):
# Intentional bad argument.
PromptTemplate(
input_variables=input_variables, template=template, template_format="bar"
input_variables=input_variables,
template=template,
template_format="bar", # type: ignore[arg-type]
)
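
With the Literal signature in place, this test's intentionally bad template_format no longer type-checks, so the runtime assertion keeps a targeted ignore. A sketch of the idiom, assuming pytest is installed (names are made up):

from typing import Literal

import pytest

def set_format(fmt: Literal["f-string", "jinja2"]) -> str:
    if fmt not in ("f-string", "jinja2"):
        raise ValueError(f"unsupported format: {fmt}")
    return fmt

def test_invalid_format_rejected() -> None:
    with pytest.raises(ValueError):
        # Intentional bad argument, silenced for mypy by error code.
        set_format("bar")  # type: ignore[arg-type]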

View File

@@ -62,7 +62,7 @@ lint lint_diff lint_package lint_tests:
poetry run ruff .
[ "$(PYTHON_FILES)" = "" ] || poetry run ruff format $(PYTHON_FILES) --diff
[ "$(PYTHON_FILES)" = "" ] || poetry run ruff --select I $(PYTHON_FILES)
[ "$(PYTHON_FILES)" = "" ] || mkdir $(MYPY_CACHE) || poetry run mypy $(PYTHON_FILES) --cache-dir $(MYPY_CACHE)
[ "$(PYTHON_FILES)" = "" ] || (mkdir -p $(MYPY_CACHE); poetry run mypy $(PYTHON_FILES) --cache-dir $(MYPY_CACHE))
format format_diff:
[ "$(PYTHON_FILES)" = "" ] || poetry run ruff format $(PYTHON_FILES)

View File

@@ -1,7 +1,7 @@
import asyncio
import json
from json import JSONDecodeError
from typing import List, Union
from typing import Sequence, Union
from langchain_core.agents import AgentAction, AgentActionMessageLog, AgentFinish
from langchain_core.exceptions import OutputParserException
@@ -78,7 +78,7 @@ class OpenAIFunctionsAgentOutputParser(AgentOutputParser):
)
def parse_result(
self, result: List[Generation], *, partial: bool = False
self, result: Sequence[Generation], *, partial: bool = False
) -> Union[AgentAction, AgentFinish]:
if not isinstance(result[0], ChatGeneration):
raise ValueError("This output parser only works on ChatGeneration output")
@@ -86,7 +86,7 @@ class OpenAIFunctionsAgentOutputParser(AgentOutputParser):
return self._parse_ai_message(message)
async def aparse_result(
self, result: List[Generation], *, partial: bool = False
self, result: Sequence[Generation], *, partial: bool = False
) -> Union[AgentAction, AgentFinish]:
return await asyncio.get_running_loop().run_in_executor(
None, self.parse_result, result

View File

@@ -1,7 +1,7 @@
import asyncio
import json
from json import JSONDecodeError
from typing import List, Union
from typing import List, Sequence, Union
from langchain_core.agents import AgentAction, AgentActionMessageLog, AgentFinish
from langchain_core.exceptions import OutputParserException
@@ -85,7 +85,7 @@ class OpenAIToolsAgentOutputParser(MultiActionAgentOutputParser):
return "openai-tools-agent-output-parser"
def parse_result(
self, result: List[Generation], *, partial: bool = False
self, result: Sequence[Generation], *, partial: bool = False
) -> Union[List[AgentAction], AgentFinish]:
if not isinstance(result[0], ChatGeneration):
raise ValueError("This output parser only works on ChatGeneration output")
@@ -93,7 +93,7 @@ class OpenAIToolsAgentOutputParser(MultiActionAgentOutputParser):
return parse_ai_message_to_openai_tool_action(message)
async def aparse_result(
self, result: List[Generation], *, partial: bool = False
self, result: Sequence[Generation], *, partial: bool = False
) -> Union[List[AgentAction], AgentFinish]:
return await asyncio.get_running_loop().run_in_executor(
None, self.parse_result, result

View File

@@ -48,7 +48,7 @@ class _ResponseChain(LLMChain):
@abstractmethod
def _extract_tokens_and_log_probs(
self, generations: List[Generation]
self, generations: Sequence[Generation]
) -> Tuple[Sequence[str], Sequence[float]]:
"""Extract tokens and log probs from response."""
@@ -63,7 +63,7 @@ class _OpenAIResponseChain(_ResponseChain):
)
def _extract_tokens_and_log_probs(
self, generations: List[Generation]
self, generations: Sequence[Generation]
) -> Tuple[Sequence[str], Sequence[float]]:
tokens = []
log_probs = []

View File

@@ -1,6 +1,6 @@
import copy
import json
from typing import Any, Dict, List, Optional, Type, Union
from typing import Any, Dict, Optional, Sequence, Type, Union
import jsonpatch
@@ -23,7 +23,9 @@ class OutputFunctionsParser(BaseGenerationOutputParser[Any]):
args_only: bool = True
"""Whether to only return the arguments to the function call."""
def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
def parse_result(
self, result: Sequence[Generation], *, partial: bool = False
) -> Any:
generation = result[0]
if not isinstance(generation, ChatGeneration):
raise OutputParserException(
@@ -61,7 +63,9 @@ class JsonOutputFunctionsParser(BaseCumulativeTransformOutputParser[Any]):
def _diff(self, prev: Optional[Any], next: Any) -> Any:
return jsonpatch.make_patch(prev, next).patch
def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
def parse_result(
self, result: Sequence[Generation], *, partial: bool = False
) -> Any:
if len(result) != 1:
raise OutputParserException(
f"Expected exactly one result, but got {len(result)}"
@@ -131,7 +135,9 @@ class JsonKeyOutputFunctionsParser(JsonOutputFunctionsParser):
key_name: str
"""The name of the key to return."""
def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
def parse_result(
self, result: Sequence[Generation], *, partial: bool = False
) -> Any:
res = super().parse_result(result, partial=partial)
if partial and res is None:
return None
@@ -158,7 +164,9 @@ class PydanticOutputFunctionsParser(OutputFunctionsParser):
)
return values
def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
def parse_result(
self, result: Sequence[Generation], *, partial: bool = False
) -> Any:
_result = super().parse_result(result)
if self.args_only:
pydantic_args = self.pydantic_schema.parse_raw(_result) # type: ignore
@@ -175,6 +183,8 @@ class PydanticAttrOutputFunctionsParser(PydanticOutputFunctionsParser):
attr_name: str
"""The name of the attribute to return."""
def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
def parse_result(
self, result: Sequence[Generation], *, partial: bool = False
) -> Any:
result = super().parse_result(result)
return getattr(result, self.attr_name)

View File

@@ -1,6 +1,6 @@
import copy
import json
from typing import Any, Dict, List, Optional, Type, Union
from typing import Any, Dict, Optional, Sequence, Type, Union
import jsonpatch
from langchain_core.exceptions import OutputParserException
@@ -20,7 +20,9 @@ class OutputFunctionsParser(BaseGenerationOutputParser[Any]):
args_only: bool = True
"""Whether to only return the arguments to the function call."""
def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
def parse_result(
self, result: Sequence[Generation], *, partial: bool = False
) -> Any:
generation = result[0]
if not isinstance(generation, ChatGeneration):
raise OutputParserException(
@@ -58,7 +60,9 @@ class JsonOutputFunctionsParser(BaseCumulativeTransformOutputParser[Any]):
def _diff(self, prev: Optional[Any], next: Any) -> Any:
return jsonpatch.make_patch(prev, next).patch
def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
def parse_result(
self, result: Sequence[Generation], *, partial: bool = False
) -> Any:
if len(result) != 1:
raise OutputParserException(
f"Expected exactly one result, but got {len(result)}"
@@ -126,7 +130,9 @@ class JsonKeyOutputFunctionsParser(JsonOutputFunctionsParser):
key_name: str
"""The name of the key to return."""
def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
def parse_result(
self, result: Sequence[Generation], *, partial: bool = False
) -> Any:
res = super().parse_result(result, partial=partial)
if partial and res is None:
return None
@@ -153,7 +159,9 @@ class PydanticOutputFunctionsParser(OutputFunctionsParser):
)
return values
def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
def parse_result(
self, result: Sequence[Generation], *, partial: bool = False
) -> Any:
_result = super().parse_result(result)
if self.args_only:
pydantic_args = self.pydantic_schema.parse_raw(_result) # type: ignore
@@ -170,6 +178,8 @@ class PydanticAttrOutputFunctionsParser(PydanticOutputFunctionsParser):
attr_name: str
"""The name of the attribute to return."""
def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
def parse_result(
self, result: Sequence[Generation], *, partial: bool = False
) -> Any:
result = super().parse_result(result)
return getattr(result, self.attr_name)

View File

@@ -1,6 +1,6 @@
import copy
import json
from typing import Any, List, Type
from typing import Any, List, Sequence, Type
from langchain_core.exceptions import OutputParserException
from langchain_core.output_parsers import (
@@ -13,7 +13,9 @@ from langchain_core.pydantic_v1 import BaseModel
class JsonOutputToolsParser(BaseGenerationOutputParser[Any]):
"""Parse tools from OpenAI response."""
def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
def parse_result(
self, result: Sequence[Generation], *, partial: bool = False
) -> Any:
generation = result[0]
if not isinstance(generation, ChatGeneration):
raise OutputParserException(
@@ -45,7 +47,9 @@ class JsonOutputKeyToolsParser(JsonOutputToolsParser):
key_name: str
"""The type of tools to return."""
def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
def parse_result(
self, result: Sequence[Generation], *, partial: bool = False
) -> Any:
results = super().parse_result(result)
return [res["args"] for res in results if res["type"] == self.key_name]
@@ -55,7 +59,9 @@ class PydanticToolsParser(JsonOutputToolsParser):
tools: List[Type[BaseModel]]
def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
def parse_result(
self, result: Sequence[Generation], *, partial: bool = False
) -> Any:
results = super().parse_result(result)
name_dict = {tool.__name__: tool for tool in self.tools}
return [name_dict[res["type"]](**res["args"]) for res in results]

View File

@@ -36,7 +36,7 @@ lint lint_diff lint_package lint_tests:
./scripts/lint_imports.sh
poetry run ruff .
[ "$(PYTHON_FILES)" = "" ] || poetry run ruff format $(PYTHON_FILES) --diff
[ "$(PYTHON_FILES)" = "" ] || poetry run mypy $(PYTHON_FILES)
[ "$(PYTHON_FILES)" = "" ] || (mkdir -p $(MYPY_CACHE); poetry run mypy $(PYTHON_FILES) --cache-dir $(MYPY_CACHE))
format format_diff:
poetry run ruff format $(PYTHON_FILES)