From acaa2d3ee4af5c83ffa512dbdb22dfc08d1306bd Mon Sep 17 00:00:00 2001 From: Harrison Chase Date: Wed, 1 Mar 2023 16:44:45 -0800 Subject: [PATCH] cr --- docs/modules/chat/conversation.ipynb | 127 ++++++++++++++++++ langchain/chains/conversation/memory.py | 20 +-- langchain/{chains => }/chat/__init__.py | 0 langchain/chat/base.py | 28 ++++ langchain/chat/conversation.py | 66 +++++++++ langchain/chat/memory.py | 28 ++++ langchain/chat/question_answering.py | 66 +++++++++ langchain/chat/vectordbqa.py | 170 ++++++++++++++++++++++++ langchain/chat_models/base.py | 20 ++- langchain/chat_models/openai.py | 14 +- langchain/memory/chat_memory.py | 8 +- langchain/memory/utils.py | 4 +- langchain/schema.py | 13 +- 13 files changed, 532 insertions(+), 32 deletions(-) create mode 100644 docs/modules/chat/conversation.ipynb rename langchain/{chains => }/chat/__init__.py (100%) create mode 100644 langchain/chat/base.py create mode 100644 langchain/chat/conversation.py create mode 100644 langchain/chat/memory.py create mode 100644 langchain/chat/question_answering.py create mode 100644 langchain/chat/vectordbqa.py diff --git a/docs/modules/chat/conversation.ipynb b/docs/modules/chat/conversation.ipynb new file mode 100644 index 00000000000..f762f33025b --- /dev/null +++ b/docs/modules/chat/conversation.ipynb @@ -0,0 +1,127 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "319061e6", + "metadata": {}, + "outputs": [], + "source": [ + "from langchain.chat.conversation import ConversationChain\n", + "from langchain.chat_models.openai import OpenAI" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "4dbbb98f", + "metadata": {}, + "outputs": [], + "source": [ + "model = OpenAI()" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "164af02b", + "metadata": {}, + "outputs": [], + "source": [ + "chain = ConversationChain.from_model(model=model)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "11830ac7", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'\\n\\nHello there! How may I assist you today?'" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "chain.run(\"hi!\")" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "8182f6c8", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'As an AI language model, I do not have access to your physical location. However, since you are interacting with me through this platform, I believe you are connected to the internet and currently using a device such as a computer, phone or tablet.'" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "chain.run('where am i?')" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "ecc2fbbb", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'I said that as an AI language model, I do not have access to your physical location. However, since you are interacting with me through this platform, I believe you are connected to the internet and currently using a device such as a computer, phone, or tablet.'" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "chain.run('what did you say?')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "614977d8", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.1" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/langchain/chains/conversation/memory.py b/langchain/chains/conversation/memory.py index dc686439ea1..bacb06d7161 100644 --- a/langchain/chains/conversation/memory.py +++ b/langchain/chains/conversation/memory.py @@ -20,7 +20,7 @@ from langchain.llms.base import BaseLLM from langchain.memory.chat_memory import ChatMemory from langchain.memory.utils import get_buffer_string from langchain.prompts.base import BasePromptTemplate -from langchain.schema import ChatGeneration +from langchain.schema import ChatMessage def _get_prompt_input_key(inputs: Dict[str, Any], memory_variables: List[str]) -> str: @@ -134,14 +134,14 @@ class ConversationBufferMemory(ChatMemoryMixin, BaseModel): return {self.memory_key: self.buffer} -class ConversationBufferWindowMemory(Memory, ChatMemoryMixin, BaseModel): +class ConversationBufferWindowMemory(ChatMemoryMixin, BaseModel): """Buffer for storing conversation memory.""" memory_key: str = "history" #: :meta private: k: int = 5 @property - def buffer(self) -> List[ChatGeneration]: + def buffer(self) -> List[ChatMessage]: """String buffer of memory.""" return self.chat_memory.messages @@ -162,7 +162,7 @@ class ConversationBufferWindowMemory(Memory, ChatMemoryMixin, BaseModel): ConversationalBufferWindowMemory = ConversationBufferWindowMemory -class ConversationSummaryMemory(Memory, ChatMemoryMixin, BaseModel): +class ConversationSummaryMemory(ChatMemoryMixin, BaseModel): """Conversation summarizer to memory.""" buffer: str = "" @@ -208,7 +208,7 @@ class ConversationSummaryMemory(Memory, ChatMemoryMixin, BaseModel): self.buffer = "" -class ConversationEntityMemory(Memory, ChatMemoryMixin, BaseModel): +class ConversationEntityMemory(ChatMemoryMixin, BaseModel): """Entity extractor & summarizer to memory.""" """Prefix to use for AI generated responses.""" @@ -221,7 +221,7 @@ class ConversationEntityMemory(Memory, ChatMemoryMixin, BaseModel): chat_history_key: str = "history" @property - def buffer(self) -> List[ChatGeneration]: + def buffer(self) -> List[ChatMessage]: return self.chat_memory.messages @property @@ -281,7 +281,7 @@ class ConversationEntityMemory(Memory, ChatMemoryMixin, BaseModel): self.store = {} -class ConversationSummaryBufferMemory(Memory, ChatMemoryMixin, BaseModel): +class ConversationSummaryBufferMemory(ChatMemoryMixin, BaseModel): """Buffer with summarizer for storing conversation memory.""" max_token_limit: int = 2000 @@ -291,7 +291,7 @@ class ConversationSummaryBufferMemory(Memory, ChatMemoryMixin, BaseModel): memory_key: str = "history" @property - def buffer(self) -> List[ChatGeneration]: + def buffer(self) -> List[ChatMessage]: return self.chat_memory.messages @property @@ -321,7 +321,7 @@ class ConversationSummaryBufferMemory(Memory, ChatMemoryMixin, BaseModel): ) return values - def get_num_tokens_list(self, arr: List[ChatGeneration]) -> List[int]: + def get_num_tokens_list(self, arr: List[ChatMessage]) -> List[int]: """Get list of number of tokens in each string in the input array.""" return [self.llm.get_num_tokens(get_buffer_string([x])) for x in arr] @@ -348,7 +348,7 @@ class ConversationSummaryBufferMemory(Memory, ChatMemoryMixin, BaseModel): self.moving_summary_buffer = "" -class ConversationKGMemory(Memory, ChatMemoryMixin, BaseModel): +class ConversationKGMemory(ChatMemoryMixin, BaseModel): """Knowledge graph memory for storing conversation memory. Integrates with external knowledge graph to store and retrieve diff --git a/langchain/chains/chat/__init__.py b/langchain/chat/__init__.py similarity index 100% rename from langchain/chains/chat/__init__.py rename to langchain/chat/__init__.py diff --git a/langchain/chat/base.py b/langchain/chat/base.py new file mode 100644 index 00000000000..e61bed03101 --- /dev/null +++ b/langchain/chat/base.py @@ -0,0 +1,28 @@ +from langchain.chains.base import Chain +from abc import ABC +from langchain.memory.chat_memory import ChatMemory +from pydantic import root_validator +from typing import Dict + +class BaseChatChain(Chain, ABC): + + human_prefix: str = "user" + ai_prefix: str = "assistant" + + @root_validator(pre=True) + def validate_memory_keys(cls, values: Dict) -> Dict: + """Validate that the human and ai prefixes line up.""" + memory = values["memory"] + if isinstance(memory, ChatMemory): + if memory.human_prefix != values["human_prefix"]: + raise ValueError( + f"Memory human_prefix ({memory.human_prefix}) must " + f"match chain human_prefix ({values['human_prefix']})" + ) + if memory.ai_prefix != values["ai_prefix"]: + raise ValueError( + f"Memory ai_prefix ({memory.ai_prefix}) must " + f"match chain ai_prefix ({values['ai_prefix']})" + ) + return values + diff --git a/langchain/chat/conversation.py b/langchain/chat/conversation.py new file mode 100644 index 00000000000..4a2725e77f3 --- /dev/null +++ b/langchain/chat/conversation.py @@ -0,0 +1,66 @@ +"""Chain that carries on a conversation and calls an LLM.""" +from typing import Dict, List, Any + +from pydantic import BaseModel, Extra, Field, root_validator + +from langchain.chat.base import BaseChatChain +from langchain.chains.conversation.prompt import PROMPT +from langchain.chains.llm import LLMChain +from langchain.prompts.base import BasePromptTemplate +from langchain.chat.memory import SimpleChatMemory +from langchain.chat_models.base import BaseChat +from langchain.schema import ChatMessage + + +class ConversationChain(BaseChatChain, BaseModel): + """Chain to have a conversation and load context from memory. + + Example: + .. code-block:: python + + from langchain import ConversationChain, OpenAI + conversation = ConversationChain(llm=OpenAI()) + """ + + model: BaseChat + memory: SimpleChatMemory = Field(default_factory=SimpleChatMemory) + """Default memory store.""" + prompt: BasePromptTemplate = PROMPT + """Default conversation prompt to use.""" + input_key: str = "input" #: :meta private: + output_key: str = "response" #: :meta private: + starter_messages: List[ChatMessage] = Field(default_factory=list) + + @classmethod + def from_model(cls, model: BaseModel, **kwargs: Any): + """From model. Future proofing.""" + return cls(model=model, **kwargs) + + + @property + def _chain_type(self) -> str: + return "chat:conversation" + + @property + def output_keys(self) -> List[str]: + return [self.output_key] + + def _call(self, inputs: Dict[str, str]) -> Dict[str, str]: + new_message = ChatMessage(text = inputs[self.input_key], role=self.human_prefix) + messages = self.starter_messages + self.memory.messages + [new_message] + output = self.model.run(messages) + return {self.output_key: output.text} + + + + class Config: + """Configuration for this pydantic object.""" + + extra = Extra.forbid + arbitrary_types_allowed = True + + @property + def input_keys(self) -> List[str]: + """Use this since so some prompt vars come from history.""" + return [self.input_key] + diff --git a/langchain/chat/memory.py b/langchain/chat/memory.py new file mode 100644 index 00000000000..c7417d96f75 --- /dev/null +++ b/langchain/chat/memory.py @@ -0,0 +1,28 @@ +from typing import Dict, Any, List + +from langchain.chains.base import Memory +from langchain.memory.chat_memory import ChatMemory +from abc import ABC, abstractmethod + +def _get_prompt_input_key(inputs: Dict[str, Any]) -> str: + # "stop" is a special key that can be passed as input but is not used to + # format the prompt. + prompt_input_keys = list(set(inputs).difference(["stop"])) + if len(prompt_input_keys) != 1: + raise ValueError(f"One input key expected got {prompt_input_keys}") + return prompt_input_keys[0] + +class SimpleChatMemory(Memory, ChatMemory): + def clear(self) -> None: + self.clear() + + @property + def memory_variables(self) -> List[str]: + return [] + + def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]: + return {} + + def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None: + self.add_user_message(inputs[_get_prompt_input_key(inputs)]) + self.add_ai_message(outputs[_get_prompt_input_key(outputs)]) diff --git a/langchain/chat/question_answering.py b/langchain/chat/question_answering.py new file mode 100644 index 00000000000..b317d35ca18 --- /dev/null +++ b/langchain/chat/question_answering.py @@ -0,0 +1,66 @@ +"""Question Answering.""" +from typing import Dict, List, Any + +from pydantic import BaseModel, Extra, Field, root_validator + +from langchain.chat.base import BaseChatChain +from langchain.chains.conversation.prompt import PROMPT +from langchain.chains.llm import LLMChain +from langchain.prompts.base import BasePromptTemplate +from langchain.chat.memory import SimpleChatMemory +from langchain.chat_models.base import BaseChat +from langchain.schema import ChatMessage + + +class QAChain(BaseChatChain, BaseModel): + """Chain to have a conversation and load context from memory. + + Example: + .. code-block:: python + + from langchain import ConversationChain, OpenAI + conversation = ConversationChain(llm=OpenAI()) + """ + + model: BaseChat + """Default memory store.""" + prompt: BasePromptTemplate = PROMPT + """Default conversation prompt to use.""" + question_key: str = "question" #: :meta private: + documents_key: str = "input_documents" #: :meta private: + output_key: str = "response" #: :meta private: + starter_messages: List[ChatMessage] = Field(default_factory=list) + + @classmethod + def from_model(cls, model: BaseModel, **kwargs: Any): + """From model. Future proofing.""" + return cls(model=model, **kwargs) + + @property + def _chain_type(self) -> str: + return "chat:qa" + + @property + def output_keys(self) -> List[str]: + return [self.output_key] + + def _call(self, inputs: Dict[str, Any]) -> Dict[str, str]: + new_message = ChatMessage(text = inputs[self.question_key], role=self.human_prefix) + docs = inputs[self.documents_key] + doc_messages = [ChatMessage(text=doc.page_content, role=self.human_prefix) for doc in docs] + messages = self.starter_messages + doc_messages + [new_message] + output = self.model.run(messages) + return {self.output_key: output.text} + + + class Config: + """Configuration for this pydantic object.""" + + extra = Extra.forbid + arbitrary_types_allowed = True + + @property + def input_keys(self) -> List[str]: + """Use this since so some prompt vars come from history.""" + return [self.question_key, self.documents_key] + diff --git a/langchain/chat/vectordbqa.py b/langchain/chat/vectordbqa.py new file mode 100644 index 00000000000..12f86f5130e --- /dev/null +++ b/langchain/chat/vectordbqa.py @@ -0,0 +1,170 @@ +"""Chain for question-answering against a vector database.""" +from __future__ import annotations + +from typing import Any, Dict, List, Optional + +from pydantic import BaseModel, Extra, Field, root_validator + +from langchain.chains.base import Chain +from langchain.chains.combine_documents.base import BaseCombineDocumentsChain +from langchain.chains.combine_documents.stuff import StuffDocumentsChain +from langchain.chains.llm import LLMChain +from langchain.chains.question_answering import load_qa_chain +from langchain.chains.vector_db_qa.prompt import PROMPT +from langchain.llms.base import BaseLLM +from langchain.prompts import PromptTemplate +from langchain.vectorstores.base import VectorStore + + +class VectorDBQA(Chain, BaseModel): + """Chain for question-answering against a vector database. + + Example: + .. code-block:: python + + from langchain import OpenAI, VectorDBQA + from langchain.faiss import FAISS + vectordb = FAISS(...) + vectordbQA = VectorDBQA(llm=OpenAI(), vectorstore=vectordb) + + """ + + vectorstore: VectorStore = Field(exclude=True) + """Vector Database to connect to.""" + k: int = 4 + """Number of documents to query for.""" + combine_documents_chain: BaseCombineDocumentsChain + """Chain to use to combine the documents.""" + input_key: str = "query" #: :meta private: + output_key: str = "result" #: :meta private: + return_source_documents: bool = False + """Return the source documents.""" + search_kwargs: Dict[str, Any] = Field(default_factory=dict) + """Extra search args.""" + search_type: str = "similarity" + """Search type to use over vectorstore. `similarity` or `mmr`.""" + + class Config: + """Configuration for this pydantic object.""" + + extra = Extra.forbid + arbitrary_types_allowed = True + + @property + def input_keys(self) -> List[str]: + """Return the input keys. + + :meta private: + """ + return [self.input_key] + + @property + def output_keys(self) -> List[str]: + """Return the output keys. + + :meta private: + """ + _output_keys = [self.output_key] + if self.return_source_documents: + _output_keys = _output_keys + ["source_documents"] + return _output_keys + + # TODO: deprecate this + @root_validator(pre=True) + def load_combine_documents_chain(cls, values: Dict) -> Dict: + """Validate question chain.""" + if "combine_documents_chain" not in values: + if "llm" not in values: + raise ValueError( + "If `combine_documents_chain` not provided, `llm` should be." + ) + prompt = values.pop("prompt", PROMPT) + llm = values.pop("llm") + llm_chain = LLMChain(llm=llm, prompt=prompt) + document_prompt = PromptTemplate( + input_variables=["page_content"], template="Context:\n{page_content}" + ) + combine_documents_chain = StuffDocumentsChain( + llm_chain=llm_chain, + document_variable_name="context", + document_prompt=document_prompt, + ) + values["combine_documents_chain"] = combine_documents_chain + return values + + @root_validator() + def validate_search_type(cls, values: Dict) -> Dict: + """Validate search type.""" + if "search_type" in values: + search_type = values["search_type"] + if search_type not in ("similarity", "mmr"): + raise ValueError(f"search_type of {search_type} not allowed.") + return values + + @classmethod + def from_llm( + cls, llm: BaseLLM, prompt: PromptTemplate = PROMPT, **kwargs: Any + ) -> VectorDBQA: + """Initialize from LLM.""" + llm_chain = LLMChain(llm=llm, prompt=prompt) + document_prompt = PromptTemplate( + input_variables=["page_content"], template="Context:\n{page_content}" + ) + combine_documents_chain = StuffDocumentsChain( + llm_chain=llm_chain, + document_variable_name="context", + document_prompt=document_prompt, + ) + + return cls(combine_documents_chain=combine_documents_chain, **kwargs) + + @classmethod + def from_chain_type( + cls, + llm: BaseLLM, + chain_type: str = "stuff", + chain_type_kwargs: Optional[dict] = None, + **kwargs: Any, + ) -> VectorDBQA: + """Load chain from chain type.""" + _chain_type_kwargs = chain_type_kwargs or {} + combine_documents_chain = load_qa_chain( + llm, chain_type=chain_type, **_chain_type_kwargs + ) + return cls(combine_documents_chain=combine_documents_chain, **kwargs) + + def _call(self, inputs: Dict[str, str]) -> Dict[str, Any]: + """Run similarity search and llm on input query. + + If chain has 'return_source_documents' as 'True', returns + the retrieved documents as well under the key 'source_documents'. + + Example: + .. code-block:: python + + res = vectordbqa({'query': 'This is my query'}) + answer, docs = res['result'], res['source_documents'] + """ + question = inputs[self.input_key] + + if self.search_type == "similarity": + docs = self.vectorstore.similarity_search( + question, k=self.k, **self.search_kwargs + ) + elif self.search_type == "mmr": + docs = self.vectorstore.max_marginal_relevance_search( + question, k=self.k, **self.search_kwargs + ) + else: + raise ValueError(f"search_type of {self.search_type} not allowed.") + answer, _ = self.combine_documents_chain.combine_docs(docs, question=question) + + if self.return_source_documents: + return {self.output_key: answer, "source_documents": docs} + else: + return {self.output_key: answer} + + @property + def _chain_type(self) -> str: + """Return the chain type.""" + return "chat:vector_db_qa" diff --git a/langchain/chat_models/base.py b/langchain/chat_models/base.py index 94e66296a83..a66f213c5a3 100644 --- a/langchain/chat_models/base.py +++ b/langchain/chat_models/base.py @@ -1,12 +1,12 @@ from abc import ABC, abstractmethod from typing import Dict, List, Optional -from langchain.schema import ChatGeneration, ChatResult +from langchain.schema import ChatGeneration, ChatResult, ChatMessage class BaseChat(ABC): def generate( - self, messages: List[Dict], stop: Optional[List[str]] = None + self, messages: List[ChatMessage], stop: Optional[List[str]] = None ) -> ChatResult: """Top Level call""" # Nothing here now, but future proofing. @@ -14,21 +14,29 @@ class BaseChat(ABC): @abstractmethod def _generate( - self, messages: List[Dict], stop: Optional[List[str]] = None + self, messages: List[ChatMessage], stop: Optional[List[str]] = None ) -> ChatResult: """Top Level call""" + def run( + self, messages: List[ChatMessage], stop: Optional[List[str]] = None + ) -> ChatMessage: + res = self.generate(messages, stop=stop) + return res.generations[0].message + + class SimpleChat(BaseChat): role: str = "assistant" def _generate( - self, messages: List[Dict], stop: Optional[List[str]] = None + self, messages: List[ChatMessage], stop: Optional[List[str]] = None ) -> ChatResult: output_str = self._call(messages, stop=stop) - generation = ChatGeneration(text=output_str, role=self.role) + message = ChatMessage(text=output_str, role=self.role) + generation = ChatGeneration(message=message) return ChatResult(generations=[generation]) @abstractmethod - def _call(self, messages: List[Dict], stop: Optional[List[str]] = None) -> str: + def _call(self, messages: List[ChatMessage], stop: Optional[List[str]] = None) -> str: """Simpler interface.""" diff --git a/langchain/chat_models/openai.py b/langchain/chat_models/openai.py index 2587d9cbf8e..776a0c7fa33 100644 --- a/langchain/chat_models/openai.py +++ b/langchain/chat_models/openai.py @@ -12,13 +12,13 @@ from tenacity import ( ) from langchain.chat_models.base import BaseChat -from langchain.schema import ChatGeneration, ChatResult +from langchain.schema import ChatGeneration, ChatResult, ChatMessage from langchain.utils import get_from_dict_or_env logger = logging.getLogger(__file__) -class OpenAIChat(BaseChat, BaseModel): +class OpenAI(BaseChat, BaseModel): """Wrapper around OpenAI Chat large language models. To use, you should have the ``openai`` python package installed, and the @@ -124,19 +124,19 @@ class OpenAIChat(BaseChat, BaseModel): return _completion_with_retry(**kwargs) def _generate( - self, messages: List[Dict], stop: Optional[List[str]] = None + self, messages: List[ChatMessage], stop: Optional[List[str]] = None ) -> ChatResult: params: Dict[str, Any] = {**{"model": self.model_name}, **self._default_params} if stop is not None: if "stop" in params: raise ValueError("`stop` found in both the input and default params.") params["stop"] = stop - response = self.completion_with_retry(messages=messages, **params) + message_dicts = [{"role": m.role, "content": m.text} for m in messages] + response = self.completion_with_retry(messages=message_dicts, **params) generations = [] for res in response["choices"]: - gen = ChatGeneration( - text=res["message"]["content"], role=res["message"]["role"] - ) + message = ChatMessage(text=res["message"]["content"], role=res["message"]["role"]) + gen = ChatGeneration(message=message) generations.append(gen) return ChatResult(generations=generations) diff --git a/langchain/memory/chat_memory.py b/langchain/memory/chat_memory.py index 20654cb5bc4..ae5e745981d 100644 --- a/langchain/memory/chat_memory.py +++ b/langchain/memory/chat_memory.py @@ -2,20 +2,20 @@ from typing import List from pydantic import BaseModel, Field -from langchain.schema import ChatGeneration +from langchain.schema import ChatMessage class ChatMemory(BaseModel): human_prefix: str = "user" ai_prefix: str = "assistant" - messages: List[ChatGeneration] = Field(default_factory=list) + messages: List[ChatMessage] = Field(default_factory=list) def add_user_message(self, message: str) -> None: - gen = ChatGeneration(text=message, role=self.human_prefix) + gen = ChatMessage(text=message, role=self.human_prefix) self.messages.append(gen) def add_ai_message(self, message: str) -> None: - gen = ChatGeneration(text=message, role=self.ai_prefix) + gen = ChatMessage(text=message, role=self.ai_prefix) self.messages.append(gen) def clear(self): diff --git a/langchain/memory/utils.py b/langchain/memory/utils.py index 896fbbb8f48..3d48023557c 100644 --- a/langchain/memory/utils.py +++ b/langchain/memory/utils.py @@ -1,8 +1,8 @@ from typing import List -from langchain.schema import ChatGeneration +from langchain.schema import ChatMessage -def get_buffer_string(messages: List[ChatGeneration]): +def get_buffer_string(messages: List[ChatMessage]): """Get buffer string of messages.""" return "\n".join([f"{gen.role}: {gen.text}" for gen in messages]) diff --git a/langchain/schema.py b/langchain/schema.py index 2c6542d4ea2..efb4f7095cc 100644 --- a/langchain/schema.py +++ b/langchain/schema.py @@ -46,11 +46,10 @@ class LLMResult: llm_output: Optional[dict] = None """For arbitrary LLM provider specific output.""" - @dataclass_json @dataclass -class ChatGeneration: - """Output of a single generation.""" +class ChatMessage: + """Message object.""" text: str """Generated text output.""" @@ -58,6 +57,14 @@ class ChatGeneration: role: str """Role of the chatter.""" + +@dataclass_json +@dataclass +class ChatGeneration: + """Output of a single generation.""" + + message: ChatMessage + generation_info: Optional[Dict[str, Any]] = None """Raw generation info response from the provider""" """May include things like reason for finishing (e.g. in OpenAI)"""