Update chat model output type (#11833)

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
Authored by Nuno Campos on 2023-10-19 08:55:15 +01:00; committed by GitHub
parent ed62984cb2
commit 7db6aabf65
2 changed files with 27 additions and 33 deletions


@@ -11,7 +11,6 @@ from typing import (
     List,
     Optional,
     Sequence,
-    Union,
     cast,
 )
@@ -38,12 +37,10 @@ from langchain.schema import (
 from langchain.schema.language_model import BaseLanguageModel, LanguageModelInput
 from langchain.schema.messages import (
     AIMessage,
+    AnyMessage,
     BaseMessage,
     BaseMessageChunk,
-    ChatMessage,
-    FunctionMessage,
     HumanMessage,
-    SystemMessage,
 )
 from langchain.schema.output import ChatGenerationChunk
 from langchain.schema.runnable import RunnableConfig
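
For reference, AnyMessage (newly imported above) is a plain Union alias over the concrete message classes. A hedged sketch of its shape, since the real definition lives in langchain.schema.messages; the argument order shown matches the anyOf ordering in the snapshot hunk at the end of this diff:

    from typing import Union

    from langchain.schema.messages import (
        AIMessage,
        ChatMessage,
        FunctionMessage,
        HumanMessage,
        SystemMessage,
    )

    # Approximation of the alias exported by langchain.schema.messages:
    AnyMessage = Union[
        AIMessage, HumanMessage, ChatMessage, SystemMessage, FunctionMessage
    ]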
@@ -79,7 +76,7 @@ async def _agenerate_from_stream(
     return ChatResult(generations=[generation])


-class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
+class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
     """Base class for Chat models."""

     cache: Optional[bool] = None
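
The generic parameter swap above is the heart of the change: BaseLanguageModel is generic over its output type, and that parameter is what invoke/ainvoke advertise to static type checkers. A minimal self-contained sketch of the mechanism, with stand-in names rather than the real classes:

    from typing import Generic, TypeVar

    OutputT = TypeVar("OutputT")

    class LanguageModelSketch(Generic[OutputT]):
        # Stand-in for BaseLanguageModel: subclasses pick the output type.
        def invoke(self, text: str) -> OutputT:
            raise NotImplementedError

    class MessageSketch:
        """Stand-in for BaseMessage."""

    class ChatModelSketch(LanguageModelSketch[MessageSketch]):
        def invoke(self, text: str) -> MessageSketch:
            return MessageSketch()

    result = ChatModelSketch().invoke("hi")  # checkers now infer MessageSketch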
@@ -116,9 +113,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
     @property
     def OutputType(self) -> Any:
         """Get the output type for this runnable."""
-        return Union[
-            HumanMessage, AIMessage, ChatMessage, FunctionMessage, SystemMessage
-        ]
+        return AnyMessage

     def _convert_input(self, input: LanguageModelInput) -> PromptValue:
         if isinstance(input, PromptValue):
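
OutputType feeds runnable introspection (output schemas and the like), so it now reports the message union instead of the old hand-rolled Union. A hedged usage sketch, assuming a concrete model such as ChatOpenAI with OPENAI_API_KEY set in the environment:

    from langchain.chat_models import ChatOpenAI

    model = ChatOpenAI()
    print(model.OutputType)  # the AnyMessage union, no longer BaseMessageChunk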
@@ -140,11 +135,9 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
         *,
         stop: Optional[List[str]] = None,
         **kwargs: Any,
-    ) -> BaseMessageChunk:
+    ) -> BaseMessage:
         config = config or {}
         return cast(
-            BaseMessageChunk,
-            cast(
             ChatGeneration,
             self.generate_prompt(
                 [self._convert_input(input)],
@@ -155,8 +148,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
             run_name=config.get("run_name"),
             **kwargs,
         ).generations[0][0],
-        ).message,
-        )
+        ).message

     async def ainvoke(
         self,
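
Dropping the outer cast is safe because cast() never does anything at runtime; it only asserts a type to the checker, and with the class now parameterized over BaseMessage the extra BaseMessageChunk assertion no longer matched what invoke actually returns. A quick illustration of cast semantics:

    from typing import cast

    value: object = "hello"
    # cast() returns its argument unchanged; it exists purely for static analysis.
    text = cast(str, value)
    assert text is value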
@@ -165,7 +157,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
         *,
         stop: Optional[List[str]] = None,
         **kwargs: Any,
-    ) -> BaseMessageChunk:
+    ) -> BaseMessage:
         config = config or {}
         llm_result = await self.agenerate_prompt(
             [self._convert_input(input)],
@@ -176,9 +168,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
             run_name=config.get("run_name"),
             **kwargs,
         )
-        return cast(
-            BaseMessageChunk, cast(ChatGeneration, llm_result.generations[0][0]).message
-        )
+        return cast(ChatGeneration, llm_result.generations[0][0]).message

     def stream(
         self,
@@ -190,7 +180,9 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
     ) -> Iterator[BaseMessageChunk]:
         if type(self)._stream == BaseChatModel._stream:
             # model doesn't implement streaming, so use default implementation
-            yield self.invoke(input, config=config, stop=stop, **kwargs)
+            yield cast(
+                BaseMessageChunk, self.invoke(input, config=config, stop=stop, **kwargs)
+            )
         else:
             config = config or {}
             messages = self._convert_input(input).to_messages()
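
stream() still declares Iterator[BaseMessageChunk] while invoke() now returns BaseMessage, so the fallback path needs an explicit cast to satisfy the generator signature; at runtime the yielded object is simply whatever message invoke() produced. A condensed sketch of the pattern with hypothetical stand-in types:

    from typing import Iterator, cast

    class Message:
        """Stand-in for BaseMessage."""

    class MessageChunk(Message):
        """Stand-in for BaseMessageChunk."""

    def invoke() -> Message:
        return Message()

    def stream() -> Iterator[MessageChunk]:
        # No native streaming: yield the whole invoke() result as one "chunk".
        # The cast is a type-level assertion only; nothing is converted.
        yield cast(MessageChunk, invoke())

    for chunk in stream():
        print(type(chunk).__name__)  # Message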
@@ -241,7 +233,9 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
     ) -> AsyncIterator[BaseMessageChunk]:
         if type(self)._astream == BaseChatModel._astream:
             # model doesn't implement streaming, so use default implementation
-            yield self.invoke(input, config=config, stop=stop, **kwargs)
+            yield cast(
+                BaseMessageChunk, self.invoke(input, config=config, stop=stop, **kwargs)
+            )
         else:
             config = config or {}
             messages = self._convert_input(input).to_messages()


@@ -2163,19 +2163,19 @@
       dict({
         'anyOf': list([
           dict({
-            '$ref': '#/definitions/HumanMessage',
+            '$ref': '#/definitions/AIMessage',
           }),
           dict({
-            '$ref': '#/definitions/AIMessage',
+            '$ref': '#/definitions/HumanMessage',
           }),
           dict({
             '$ref': '#/definitions/ChatMessage',
           }),
           dict({
-            '$ref': '#/definitions/FunctionMessage',
+            '$ref': '#/definitions/SystemMessage',
           }),
           dict({
-            '$ref': '#/definitions/SystemMessage',
+            '$ref': '#/definitions/FunctionMessage',
           }),
         ]),
         'definitions': dict({
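
The snapshot churn above is pure reordering: pydantic v1 (which langchain used at the time) serializes a Union as an anyOf list in the Union's declared argument order, so switching OutputType to AnyMessage reorders the $refs without changing the set of accepted messages. A small repro sketch with two hypothetical models:

    from typing import Union

    from pydantic import BaseModel  # pydantic v1 API, as langchain used then

    class AIMessage(BaseModel):
        content: str

    class HumanMessage(BaseModel):
        content: str

    class Output(BaseModel):
        __root__: Union[AIMessage, HumanMessage]

    # anyOf follows the Union's argument order:
    print(Output.schema()["anyOf"])
    # [{'$ref': '#/definitions/AIMessage'}, {'$ref': '#/definitions/HumanMessage'}]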