mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-18 18:53:10 +00:00
Update chat model output type (#11833)
--------- Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
parent
ed62984cb2
commit
7db6aabf65
@ -11,7 +11,6 @@ from typing import (
|
|||||||
List,
|
List,
|
||||||
Optional,
|
Optional,
|
||||||
Sequence,
|
Sequence,
|
||||||
Union,
|
|
||||||
cast,
|
cast,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -38,12 +37,10 @@ from langchain.schema import (
|
|||||||
from langchain.schema.language_model import BaseLanguageModel, LanguageModelInput
|
from langchain.schema.language_model import BaseLanguageModel, LanguageModelInput
|
||||||
from langchain.schema.messages import (
|
from langchain.schema.messages import (
|
||||||
AIMessage,
|
AIMessage,
|
||||||
|
AnyMessage,
|
||||||
BaseMessage,
|
BaseMessage,
|
||||||
BaseMessageChunk,
|
BaseMessageChunk,
|
||||||
ChatMessage,
|
|
||||||
FunctionMessage,
|
|
||||||
HumanMessage,
|
HumanMessage,
|
||||||
SystemMessage,
|
|
||||||
)
|
)
|
||||||
from langchain.schema.output import ChatGenerationChunk
|
from langchain.schema.output import ChatGenerationChunk
|
||||||
from langchain.schema.runnable import RunnableConfig
|
from langchain.schema.runnable import RunnableConfig
|
||||||
@ -79,7 +76,7 @@ async def _agenerate_from_stream(
|
|||||||
return ChatResult(generations=[generation])
|
return ChatResult(generations=[generation])
|
||||||
|
|
||||||
|
|
||||||
class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
|
class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||||
"""Base class for Chat models."""
|
"""Base class for Chat models."""
|
||||||
|
|
||||||
cache: Optional[bool] = None
|
cache: Optional[bool] = None
|
||||||
@ -116,9 +113,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
|
|||||||
@property
|
@property
|
||||||
def OutputType(self) -> Any:
|
def OutputType(self) -> Any:
|
||||||
"""Get the output type for this runnable."""
|
"""Get the output type for this runnable."""
|
||||||
return Union[
|
return AnyMessage
|
||||||
HumanMessage, AIMessage, ChatMessage, FunctionMessage, SystemMessage
|
|
||||||
]
|
|
||||||
|
|
||||||
def _convert_input(self, input: LanguageModelInput) -> PromptValue:
|
def _convert_input(self, input: LanguageModelInput) -> PromptValue:
|
||||||
if isinstance(input, PromptValue):
|
if isinstance(input, PromptValue):
|
||||||
@ -140,11 +135,9 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
|
|||||||
*,
|
*,
|
||||||
stop: Optional[List[str]] = None,
|
stop: Optional[List[str]] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> BaseMessageChunk:
|
) -> BaseMessage:
|
||||||
config = config or {}
|
config = config or {}
|
||||||
return cast(
|
return cast(
|
||||||
BaseMessageChunk,
|
|
||||||
cast(
|
|
||||||
ChatGeneration,
|
ChatGeneration,
|
||||||
self.generate_prompt(
|
self.generate_prompt(
|
||||||
[self._convert_input(input)],
|
[self._convert_input(input)],
|
||||||
@ -155,8 +148,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
|
|||||||
run_name=config.get("run_name"),
|
run_name=config.get("run_name"),
|
||||||
**kwargs,
|
**kwargs,
|
||||||
).generations[0][0],
|
).generations[0][0],
|
||||||
).message,
|
).message
|
||||||
)
|
|
||||||
|
|
||||||
async def ainvoke(
|
async def ainvoke(
|
||||||
self,
|
self,
|
||||||
@ -165,7 +157,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
|
|||||||
*,
|
*,
|
||||||
stop: Optional[List[str]] = None,
|
stop: Optional[List[str]] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> BaseMessageChunk:
|
) -> BaseMessage:
|
||||||
config = config or {}
|
config = config or {}
|
||||||
llm_result = await self.agenerate_prompt(
|
llm_result = await self.agenerate_prompt(
|
||||||
[self._convert_input(input)],
|
[self._convert_input(input)],
|
||||||
@ -176,9 +168,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
|
|||||||
run_name=config.get("run_name"),
|
run_name=config.get("run_name"),
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
return cast(
|
return cast(ChatGeneration, llm_result.generations[0][0]).message
|
||||||
BaseMessageChunk, cast(ChatGeneration, llm_result.generations[0][0]).message
|
|
||||||
)
|
|
||||||
|
|
||||||
def stream(
|
def stream(
|
||||||
self,
|
self,
|
||||||
@ -190,7 +180,9 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
|
|||||||
) -> Iterator[BaseMessageChunk]:
|
) -> Iterator[BaseMessageChunk]:
|
||||||
if type(self)._stream == BaseChatModel._stream:
|
if type(self)._stream == BaseChatModel._stream:
|
||||||
# model doesn't implement streaming, so use default implementation
|
# model doesn't implement streaming, so use default implementation
|
||||||
yield self.invoke(input, config=config, stop=stop, **kwargs)
|
yield cast(
|
||||||
|
BaseMessageChunk, self.invoke(input, config=config, stop=stop, **kwargs)
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
config = config or {}
|
config = config or {}
|
||||||
messages = self._convert_input(input).to_messages()
|
messages = self._convert_input(input).to_messages()
|
||||||
@ -241,7 +233,9 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
|
|||||||
) -> AsyncIterator[BaseMessageChunk]:
|
) -> AsyncIterator[BaseMessageChunk]:
|
||||||
if type(self)._astream == BaseChatModel._astream:
|
if type(self)._astream == BaseChatModel._astream:
|
||||||
# model doesn't implement streaming, so use default implementation
|
# model doesn't implement streaming, so use default implementation
|
||||||
yield self.invoke(input, config=config, stop=stop, **kwargs)
|
yield cast(
|
||||||
|
BaseMessageChunk, self.invoke(input, config=config, stop=stop, **kwargs)
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
config = config or {}
|
config = config or {}
|
||||||
messages = self._convert_input(input).to_messages()
|
messages = self._convert_input(input).to_messages()
|
||||||
|
@ -2163,19 +2163,19 @@
|
|||||||
dict({
|
dict({
|
||||||
'anyOf': list([
|
'anyOf': list([
|
||||||
dict({
|
dict({
|
||||||
'$ref': '#/definitions/HumanMessage',
|
'$ref': '#/definitions/AIMessage',
|
||||||
}),
|
}),
|
||||||
dict({
|
dict({
|
||||||
'$ref': '#/definitions/AIMessage',
|
'$ref': '#/definitions/HumanMessage',
|
||||||
}),
|
}),
|
||||||
dict({
|
dict({
|
||||||
'$ref': '#/definitions/ChatMessage',
|
'$ref': '#/definitions/ChatMessage',
|
||||||
}),
|
}),
|
||||||
dict({
|
dict({
|
||||||
'$ref': '#/definitions/FunctionMessage',
|
'$ref': '#/definitions/SystemMessage',
|
||||||
}),
|
}),
|
||||||
dict({
|
dict({
|
||||||
'$ref': '#/definitions/SystemMessage',
|
'$ref': '#/definitions/FunctionMessage',
|
||||||
}),
|
}),
|
||||||
]),
|
]),
|
||||||
'definitions': dict({
|
'definitions': dict({
|
||||||
|
Loading…
Reference in New Issue
Block a user