cr
127 docs/modules/chat/conversation.ipynb Normal file
@@ -0,0 +1,127 @@
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "319061e6",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain.chat.conversation import ConversationChain\n",
    "from langchain.chat_models.openai import OpenAI"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "4dbbb98f",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = OpenAI()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "164af02b",
   "metadata": {},
   "outputs": [],
   "source": [
    "chain = ConversationChain.from_model(model=model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "11830ac7",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'\\n\\nHello there! How may I assist you today?'"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "chain.run(\"hi!\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "8182f6c8",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'As an AI language model, I do not have access to your physical location. However, since you are interacting with me through this platform, I believe you are connected to the internet and currently using a device such as a computer, phone or tablet.'"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "chain.run('where am i?')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "ecc2fbbb",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'I said that as an AI language model, I do not have access to your physical location. However, since you are interacting with me through this platform, I believe you are connected to the internet and currently using a device such as a computer, phone, or tablet.'"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "chain.run('what did you say?')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "614977d8",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
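For reference, the notebook cells above distill to the short script below. It is only a sketch: it assumes OPENAI_API_KEY is set in the environment and that the langchain.chat and langchain.chat_models modules added later in this commit are importable; the printed replies will of course vary.

from langchain.chat.conversation import ConversationChain
from langchain.chat_models.openai import OpenAI

model = OpenAI()
chain = ConversationChain.from_model(model=model)

print(chain.run("hi!"))
print(chain.run("where am i?"))        # earlier turns are replayed from memory
print(chain.run("what did you say?"))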
@@ -20,7 +20,7 @@ from langchain.llms.base import BaseLLM
from langchain.memory.chat_memory import ChatMemory
from langchain.memory.utils import get_buffer_string
from langchain.prompts.base import BasePromptTemplate
-from langchain.schema import ChatGeneration
+from langchain.schema import ChatMessage


def _get_prompt_input_key(inputs: Dict[str, Any], memory_variables: List[str]) -> str:
@@ -134,14 +134,14 @@ class ConversationBufferMemory(ChatMemoryMixin, BaseModel):
        return {self.memory_key: self.buffer}


-class ConversationBufferWindowMemory(Memory, ChatMemoryMixin, BaseModel):
+class ConversationBufferWindowMemory(ChatMemoryMixin, BaseModel):
    """Buffer for storing conversation memory."""

    memory_key: str = "history"  #: :meta private:
    k: int = 5

    @property
-    def buffer(self) -> List[ChatGeneration]:
+    def buffer(self) -> List[ChatMessage]:
        """String buffer of memory."""
        return self.chat_memory.messages
@@ -162,7 +162,7 @@ class ConversationBufferWindowMemory(Memory, ChatMemoryMixin, BaseModel):
ConversationalBufferWindowMemory = ConversationBufferWindowMemory


-class ConversationSummaryMemory(Memory, ChatMemoryMixin, BaseModel):
+class ConversationSummaryMemory(ChatMemoryMixin, BaseModel):
    """Conversation summarizer to memory."""

    buffer: str = ""
@@ -208,7 +208,7 @@ class ConversationSummaryMemory(Memory, ChatMemoryMixin, BaseModel):
        self.buffer = ""


-class ConversationEntityMemory(Memory, ChatMemoryMixin, BaseModel):
+class ConversationEntityMemory(ChatMemoryMixin, BaseModel):
    """Entity extractor & summarizer to memory."""

    """Prefix to use for AI generated responses."""
@@ -221,7 +221,7 @@ class ConversationEntityMemory(Memory, ChatMemoryMixin, BaseModel):
    chat_history_key: str = "history"

    @property
-    def buffer(self) -> List[ChatGeneration]:
+    def buffer(self) -> List[ChatMessage]:
        return self.chat_memory.messages

    @property
@@ -281,7 +281,7 @@ class ConversationEntityMemory(Memory, ChatMemoryMixin, BaseModel):
        self.store = {}


-class ConversationSummaryBufferMemory(Memory, ChatMemoryMixin, BaseModel):
+class ConversationSummaryBufferMemory(ChatMemoryMixin, BaseModel):
    """Buffer with summarizer for storing conversation memory."""

    max_token_limit: int = 2000
@@ -291,7 +291,7 @@ class ConversationSummaryBufferMemory(Memory, ChatMemoryMixin, BaseModel):
    memory_key: str = "history"

    @property
-    def buffer(self) -> List[ChatGeneration]:
+    def buffer(self) -> List[ChatMessage]:
        return self.chat_memory.messages

    @property
@@ -321,7 +321,7 @@ class ConversationSummaryBufferMemory(Memory, ChatMemoryMixin, BaseModel):
            )
        return values

-    def get_num_tokens_list(self, arr: List[ChatGeneration]) -> List[int]:
+    def get_num_tokens_list(self, arr: List[ChatMessage]) -> List[int]:
        """Get list of number of tokens in each string in the input array."""
        return [self.llm.get_num_tokens(get_buffer_string([x])) for x in arr]
@@ -348,7 +348,7 @@ class ConversationSummaryBufferMemory(Memory, ChatMemoryMixin, BaseModel):
        self.moving_summary_buffer = ""


-class ConversationKGMemory(Memory, ChatMemoryMixin, BaseModel):
+class ConversationKGMemory(ChatMemoryMixin, BaseModel):
    """Knowledge graph memory for storing conversation memory.

    Integrates with external knowledge graph to store and retrieve
28 langchain/chat/base.py Normal file
@@ -0,0 +1,28 @@
from langchain.chains.base import Chain
from abc import ABC
from langchain.memory.chat_memory import ChatMemory
from pydantic import root_validator
from typing import Dict


class BaseChatChain(Chain, ABC):

    human_prefix: str = "user"
    ai_prefix: str = "assistant"

    @root_validator(pre=True)
    def validate_memory_keys(cls, values: Dict) -> Dict:
        """Validate that the human and ai prefixes line up."""
        memory = values["memory"]
        if isinstance(memory, ChatMemory):
            if memory.human_prefix != values["human_prefix"]:
                raise ValueError(
                    f"Memory human_prefix ({memory.human_prefix}) must "
                    f"match chain human_prefix ({values['human_prefix']})"
                )
            if memory.ai_prefix != values["ai_prefix"]:
                raise ValueError(
                    f"Memory ai_prefix ({memory.ai_prefix}) must "
                    f"match chain ai_prefix ({values['ai_prefix']})"
                )
        return values
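A sketch of what the prefix check above is meant to enforce, using the SimpleChatMemory and ConversationChain classes added elsewhere in this commit. The explicit human_prefix/ai_prefix keyword arguments and the mismatched "bot" value are illustrative assumptions, and an OpenAI API key is assumed for constructing the model.

from langchain.chat.conversation import ConversationChain
from langchain.chat.memory import SimpleChatMemory
from langchain.chat_models.openai import OpenAI

# Memory whose ai_prefix disagrees with the chain's ai_prefix; the
# root_validator on BaseChatChain should reject the combination.
memory = SimpleChatMemory(ai_prefix="bot")
try:
    ConversationChain(
        model=OpenAI(),
        memory=memory,
        human_prefix="user",
        ai_prefix="assistant",
    )
except ValueError as err:
    # pydantic surfaces the validator's mismatch message here
    print(err)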
66 langchain/chat/conversation.py Normal file
@@ -0,0 +1,66 @@
"""Chain that carries on a conversation and calls an LLM."""
from typing import Dict, List, Any

from pydantic import BaseModel, Extra, Field, root_validator

from langchain.chat.base import BaseChatChain
from langchain.chains.conversation.prompt import PROMPT
from langchain.chains.llm import LLMChain
from langchain.prompts.base import BasePromptTemplate
from langchain.chat.memory import SimpleChatMemory
from langchain.chat_models.base import BaseChat
from langchain.schema import ChatMessage


class ConversationChain(BaseChatChain, BaseModel):
    """Chain to have a conversation and load context from memory.

    Example:
        .. code-block:: python

            from langchain import ConversationChain, OpenAI
            conversation = ConversationChain(llm=OpenAI())
    """

    model: BaseChat
    memory: SimpleChatMemory = Field(default_factory=SimpleChatMemory)
    """Default memory store."""
    prompt: BasePromptTemplate = PROMPT
    """Default conversation prompt to use."""
    input_key: str = "input"  #: :meta private:
    output_key: str = "response"  #: :meta private:
    starter_messages: List[ChatMessage] = Field(default_factory=list)

    @classmethod
    def from_model(cls, model: BaseModel, **kwargs: Any):
        """From model. Future proofing."""
        return cls(model=model, **kwargs)

    @property
    def _chain_type(self) -> str:
        return "chat:conversation"

    @property
    def output_keys(self) -> List[str]:
        return [self.output_key]

    def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
        new_message = ChatMessage(text=inputs[self.input_key], role=self.human_prefix)
        messages = self.starter_messages + self.memory.messages + [new_message]
        output = self.model.run(messages)
        return {self.output_key: output.text}

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True

    @property
    def input_keys(self) -> List[str]:
        """Use this since some prompt vars come from history."""
        return [self.input_key]
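A usage sketch for the starter_messages field above. The "system" role string is an assumption (the diff does not constrain role values), from_model is used exactly as in the notebook earlier in this commit, and an OpenAI API key is assumed.

from langchain.chat.conversation import ConversationChain
from langchain.chat_models.openai import OpenAI
from langchain.schema import ChatMessage

# Each call sends starter_messages + memory.messages + [new user message]
# to the chat model, so a starter message acts as a standing instruction.
chain = ConversationChain.from_model(
    model=OpenAI(),
    starter_messages=[ChatMessage(text="You are a terse assistant.", role="system")],
)
print(chain.run("What do you remember between calls?"))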
28 langchain/chat/memory.py Normal file
@@ -0,0 +1,28 @@
from typing import Dict, Any, List

from langchain.chains.base import Memory
from langchain.memory.chat_memory import ChatMemory
from abc import ABC, abstractmethod

def _get_prompt_input_key(inputs: Dict[str, Any]) -> str:
    # "stop" is a special key that can be passed as input but is not used to
    # format the prompt.
    prompt_input_keys = list(set(inputs).difference(["stop"]))
    if len(prompt_input_keys) != 1:
        raise ValueError(f"One input key expected got {prompt_input_keys}")
    return prompt_input_keys[0]

class SimpleChatMemory(Memory, ChatMemory):
    def clear(self) -> None:
        self.clear()

    @property
    def memory_variables(self) -> List[str]:
        return []

    def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]:
        return {}

    def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
        self.add_user_message(inputs[_get_prompt_input_key(inputs)])
        self.add_ai_message(outputs[_get_prompt_input_key(outputs)])
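A small sketch of how save_context above records a turn. The "input" and "response" key names are arbitrary here, since _get_prompt_input_key simply picks the single non-"stop" key from each dict.

from langchain.chat.memory import SimpleChatMemory

memory = SimpleChatMemory()
memory.save_context({"input": "hi!"}, {"response": "Hello there!"})
for message in memory.messages:
    print(f"{message.role}: {message.text}")
# user: hi!
# assistant: Hello there!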
66 langchain/chat/question_answering.py Normal file
@@ -0,0 +1,66 @@
"""Question Answering."""
from typing import Dict, List, Any

from pydantic import BaseModel, Extra, Field, root_validator

from langchain.chat.base import BaseChatChain
from langchain.chains.conversation.prompt import PROMPT
from langchain.chains.llm import LLMChain
from langchain.prompts.base import BasePromptTemplate
from langchain.chat.memory import SimpleChatMemory
from langchain.chat_models.base import BaseChat
from langchain.schema import ChatMessage


class QAChain(BaseChatChain, BaseModel):
    """Chain to have a conversation and load context from memory.

    Example:
        .. code-block:: python

            from langchain import ConversationChain, OpenAI
            conversation = ConversationChain(llm=OpenAI())
    """

    model: BaseChat
    """Default memory store."""
    prompt: BasePromptTemplate = PROMPT
    """Default conversation prompt to use."""
    question_key: str = "question"  #: :meta private:
    documents_key: str = "input_documents"  #: :meta private:
    output_key: str = "response"  #: :meta private:
    starter_messages: List[ChatMessage] = Field(default_factory=list)

    @classmethod
    def from_model(cls, model: BaseModel, **kwargs: Any):
        """From model. Future proofing."""
        return cls(model=model, **kwargs)

    @property
    def _chain_type(self) -> str:
        return "chat:qa"

    @property
    def output_keys(self) -> List[str]:
        return [self.output_key]

    def _call(self, inputs: Dict[str, Any]) -> Dict[str, str]:
        new_message = ChatMessage(text=inputs[self.question_key], role=self.human_prefix)
        docs = inputs[self.documents_key]
        doc_messages = [ChatMessage(text=doc.page_content, role=self.human_prefix) for doc in docs]
        messages = self.starter_messages + doc_messages + [new_message]
        output = self.model.run(messages)
        return {self.output_key: output.text}

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True

    @property
    def input_keys(self) -> List[str]:
        """Use this since some prompt vars come from history."""
        return [self.question_key, self.documents_key]
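A sketch of calling QAChain with in-memory documents. The Document import path is the usual one in this repo at the time and is an assumption, as is the availability of an OpenAI API key; the chain stuffs each document in as a user-role message before the question.

from langchain.chat.question_answering import QAChain
from langchain.chat_models.openai import OpenAI
from langchain.docstore.document import Document

chain = QAChain.from_model(model=OpenAI())
result = chain(
    {
        "question": "What colour was the sky at dusk?",
        "input_documents": [Document(page_content="Field notes: the sky was deep violet at dusk.")],
    }
)
print(result["response"])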
170 langchain/chat/vectordbqa.py Normal file
@@ -0,0 +1,170 @@
"""Chain for question-answering against a vector database."""
from __future__ import annotations

from typing import Any, Dict, List, Optional

from pydantic import BaseModel, Extra, Field, root_validator

from langchain.chains.base import Chain
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
from langchain.chains.llm import LLMChain
from langchain.chains.question_answering import load_qa_chain
from langchain.chains.vector_db_qa.prompt import PROMPT
from langchain.llms.base import BaseLLM
from langchain.prompts import PromptTemplate
from langchain.vectorstores.base import VectorStore


class VectorDBQA(Chain, BaseModel):
    """Chain for question-answering against a vector database.

    Example:
        .. code-block:: python

            from langchain import OpenAI, VectorDBQA
            from langchain.faiss import FAISS
            vectordb = FAISS(...)
            vectordbQA = VectorDBQA(llm=OpenAI(), vectorstore=vectordb)

    """

    vectorstore: VectorStore = Field(exclude=True)
    """Vector Database to connect to."""
    k: int = 4
    """Number of documents to query for."""
    combine_documents_chain: BaseCombineDocumentsChain
    """Chain to use to combine the documents."""
    input_key: str = "query"  #: :meta private:
    output_key: str = "result"  #: :meta private:
    return_source_documents: bool = False
    """Return the source documents."""
    search_kwargs: Dict[str, Any] = Field(default_factory=dict)
    """Extra search args."""
    search_type: str = "similarity"
    """Search type to use over vectorstore. `similarity` or `mmr`."""

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True

    @property
    def input_keys(self) -> List[str]:
        """Return the input keys.

        :meta private:
        """
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        """Return the output keys.

        :meta private:
        """
        _output_keys = [self.output_key]
        if self.return_source_documents:
            _output_keys = _output_keys + ["source_documents"]
        return _output_keys

    # TODO: deprecate this
    @root_validator(pre=True)
    def load_combine_documents_chain(cls, values: Dict) -> Dict:
        """Validate question chain."""
        if "combine_documents_chain" not in values:
            if "llm" not in values:
                raise ValueError(
                    "If `combine_documents_chain` not provided, `llm` should be."
                )
            prompt = values.pop("prompt", PROMPT)
            llm = values.pop("llm")
            llm_chain = LLMChain(llm=llm, prompt=prompt)
            document_prompt = PromptTemplate(
                input_variables=["page_content"], template="Context:\n{page_content}"
            )
            combine_documents_chain = StuffDocumentsChain(
                llm_chain=llm_chain,
                document_variable_name="context",
                document_prompt=document_prompt,
            )
            values["combine_documents_chain"] = combine_documents_chain
        return values

    @root_validator()
    def validate_search_type(cls, values: Dict) -> Dict:
        """Validate search type."""
        if "search_type" in values:
            search_type = values["search_type"]
            if search_type not in ("similarity", "mmr"):
                raise ValueError(f"search_type of {search_type} not allowed.")
        return values

    @classmethod
    def from_llm(
        cls, llm: BaseLLM, prompt: PromptTemplate = PROMPT, **kwargs: Any
    ) -> VectorDBQA:
        """Initialize from LLM."""
        llm_chain = LLMChain(llm=llm, prompt=prompt)
        document_prompt = PromptTemplate(
            input_variables=["page_content"], template="Context:\n{page_content}"
        )
        combine_documents_chain = StuffDocumentsChain(
            llm_chain=llm_chain,
            document_variable_name="context",
            document_prompt=document_prompt,
        )

        return cls(combine_documents_chain=combine_documents_chain, **kwargs)

    @classmethod
    def from_chain_type(
        cls,
        llm: BaseLLM,
        chain_type: str = "stuff",
        chain_type_kwargs: Optional[dict] = None,
        **kwargs: Any,
    ) -> VectorDBQA:
        """Load chain from chain type."""
        _chain_type_kwargs = chain_type_kwargs or {}
        combine_documents_chain = load_qa_chain(
            llm, chain_type=chain_type, **_chain_type_kwargs
        )
        return cls(combine_documents_chain=combine_documents_chain, **kwargs)

    def _call(self, inputs: Dict[str, str]) -> Dict[str, Any]:
        """Run similarity search and llm on input query.

        If chain has 'return_source_documents' as 'True', returns
        the retrieved documents as well under the key 'source_documents'.

        Example:
        .. code-block:: python

            res = vectordbqa({'query': 'This is my query'})
            answer, docs = res['result'], res['source_documents']
        """
        question = inputs[self.input_key]

        if self.search_type == "similarity":
            docs = self.vectorstore.similarity_search(
                question, k=self.k, **self.search_kwargs
            )
        elif self.search_type == "mmr":
            docs = self.vectorstore.max_marginal_relevance_search(
                question, k=self.k, **self.search_kwargs
            )
        else:
            raise ValueError(f"search_type of {self.search_type} not allowed.")
        answer, _ = self.combine_documents_chain.combine_docs(docs, question=question)

        if self.return_source_documents:
            return {self.output_key: answer, "source_documents": docs}
        else:
            return {self.output_key: answer}

    @property
    def _chain_type(self) -> str:
        """Return the chain type."""
        return "chat:vector_db_qa"
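A usage sketch for the from_chain_type constructor above. The FAISS and OpenAIEmbeddings imports are the standard ones in this repo and assume the faiss package plus an OpenAI API key are available; the two indexed sentences are made up.

from langchain.chat.vectordbqa import VectorDBQA
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.llms.openai import OpenAI
from langchain.vectorstores.faiss import FAISS

# Build a tiny index, then answer a question over it with the "stuff" chain.
vectorstore = FAISS.from_texts(
    ["LangChain chains can be composed.", "FAISS stores dense vectors."],
    OpenAIEmbeddings(),
)
qa = VectorDBQA.from_chain_type(llm=OpenAI(), chain_type="stuff", vectorstore=vectorstore)
result = qa({"query": "What does FAISS store?"})
print(result["result"])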
@@ -1,12 +1,12 @@
from abc import ABC, abstractmethod
from typing import Dict, List, Optional

-from langchain.schema import ChatGeneration, ChatResult
+from langchain.schema import ChatGeneration, ChatResult, ChatMessage


class BaseChat(ABC):
    def generate(
-        self, messages: List[Dict], stop: Optional[List[str]] = None
+        self, messages: List[ChatMessage], stop: Optional[List[str]] = None
    ) -> ChatResult:
        """Top Level call"""
        # Nothing here now, but future proofing.
@@ -14,21 +14,29 @@ class BaseChat(ABC):

    @abstractmethod
    def _generate(
-        self, messages: List[Dict], stop: Optional[List[str]] = None
+        self, messages: List[ChatMessage], stop: Optional[List[str]] = None
    ) -> ChatResult:
        """Top Level call"""

+    def run(
+        self, messages: List[ChatMessage], stop: Optional[List[str]] = None
+    ) -> ChatMessage:
+        res = self.generate(messages, stop=stop)
+        return res.generations[0].message
+

class SimpleChat(BaseChat):
    role: str = "assistant"

    def _generate(
-        self, messages: List[Dict], stop: Optional[List[str]] = None
+        self, messages: List[ChatMessage], stop: Optional[List[str]] = None
    ) -> ChatResult:
        output_str = self._call(messages, stop=stop)
-        generation = ChatGeneration(text=output_str, role=self.role)
+        message = ChatMessage(text=output_str, role=self.role)
+        generation = ChatGeneration(message=message)
        return ChatResult(generations=[generation])

    @abstractmethod
-    def _call(self, messages: List[Dict], stop: Optional[List[str]] = None) -> str:
+    def _call(self, messages: List[ChatMessage], stop: Optional[List[str]] = None) -> str:
        """Simpler interface."""
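A minimal sketch of implementing the SimpleChat interface above with a toy model. It assumes BaseChat.generate simply delegates to _generate, as the surrounding context lines suggest; EchoChat is a made-up name.

from typing import List, Optional

from langchain.chat_models.base import SimpleChat
from langchain.schema import ChatMessage


class EchoChat(SimpleChat):
    """Toy chat model that replies with the last message, uppercased."""

    def _call(self, messages: List[ChatMessage], stop: Optional[List[str]] = None) -> str:
        return messages[-1].text.upper()


reply = EchoChat().run([ChatMessage(text="hello there", role="user")])
print(reply.role, reply.text)  # assistant HELLO THERE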
@@ -12,13 +12,13 @@ from tenacity import (
)

from langchain.chat_models.base import BaseChat
-from langchain.schema import ChatGeneration, ChatResult
+from langchain.schema import ChatGeneration, ChatResult, ChatMessage
from langchain.utils import get_from_dict_or_env

logger = logging.getLogger(__file__)


-class OpenAIChat(BaseChat, BaseModel):
+class OpenAI(BaseChat, BaseModel):
    """Wrapper around OpenAI Chat large language models.

    To use, you should have the ``openai`` python package installed, and the
@@ -124,19 +124,19 @@ class OpenAIChat(BaseChat, BaseModel):
        return _completion_with_retry(**kwargs)

    def _generate(
-        self, messages: List[Dict], stop: Optional[List[str]] = None
+        self, messages: List[ChatMessage], stop: Optional[List[str]] = None
    ) -> ChatResult:
        params: Dict[str, Any] = {**{"model": self.model_name}, **self._default_params}
        if stop is not None:
            if "stop" in params:
                raise ValueError("`stop` found in both the input and default params.")
            params["stop"] = stop
-        response = self.completion_with_retry(messages=messages, **params)
+        message_dicts = [{"role": m.role, "content": m.text} for m in messages]
+        response = self.completion_with_retry(messages=message_dicts, **params)
        generations = []
        for res in response["choices"]:
-            gen = ChatGeneration(
-                text=res["message"]["content"], role=res["message"]["role"]
-            )
+            message = ChatMessage(text=res["message"]["content"], role=res["message"]["role"])
+            gen = ChatGeneration(message=message)
            generations.append(gen)
        return ChatResult(generations=generations)
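A sketch of calling the renamed OpenAI chat wrapper directly with ChatMessage inputs, assuming OPENAI_API_KEY is set in the environment; the role strings mirror the OpenAI chat API.

from langchain.chat_models.openai import OpenAI
from langchain.schema import ChatMessage

chat = OpenAI()
result = chat.generate(
    [
        ChatMessage(text="You translate English to French.", role="system"),
        ChatMessage(text="Good morning", role="user"),
    ]
)
print(result.generations[0].message.text)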
@@ -2,20 +2,20 @@ from typing import List

from pydantic import BaseModel, Field

-from langchain.schema import ChatGeneration
+from langchain.schema import ChatMessage


class ChatMemory(BaseModel):
    human_prefix: str = "user"
    ai_prefix: str = "assistant"
-    messages: List[ChatGeneration] = Field(default_factory=list)
+    messages: List[ChatMessage] = Field(default_factory=list)

    def add_user_message(self, message: str) -> None:
-        gen = ChatGeneration(text=message, role=self.human_prefix)
+        gen = ChatMessage(text=message, role=self.human_prefix)
        self.messages.append(gen)

    def add_ai_message(self, message: str) -> None:
-        gen = ChatGeneration(text=message, role=self.ai_prefix)
+        gen = ChatMessage(text=message, role=self.ai_prefix)
        self.messages.append(gen)

    def clear(self):
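A small sketch of the updated ChatMemory together with get_buffer_string from the utils hunk that follows; both now operate on ChatMessage objects.

from langchain.memory.chat_memory import ChatMemory
from langchain.memory.utils import get_buffer_string

memory = ChatMemory()
memory.add_user_message("where am i?")
memory.add_ai_message("I do not have access to your physical location.")
print(get_buffer_string(memory.messages))
# user: where am i?
# assistant: I do not have access to your physical location.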
@@ -1,8 +1,8 @@
from typing import List

-from langchain.schema import ChatGeneration
+from langchain.schema import ChatMessage


-def get_buffer_string(messages: List[ChatGeneration]):
+def get_buffer_string(messages: List[ChatMessage]):
    """Get buffer string of messages."""
    return "\n".join([f"{gen.role}: {gen.text}" for gen in messages])
@@ -46,11 +46,10 @@ class LLMResult:
    llm_output: Optional[dict] = None
    """For arbitrary LLM provider specific output."""


@dataclass_json
@dataclass
-class ChatGeneration:
-    """Output of a single generation."""
+class ChatMessage:
+    """Message object."""

    text: str
    """Generated text output."""
@@ -58,6 +57,14 @@ class ChatGeneration:
    role: str
    """Role of the chatter."""

+
+@dataclass_json
+@dataclass
+class ChatGeneration:
+    """Output of a single generation."""
+
+    message: ChatMessage
+
    generation_info: Optional[Dict[str, Any]] = None
    """Raw generation info response from the provider"""
    """May include things like reason for finishing (e.g. in OpenAI)"""
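A quick sketch of the reshaped schema: a ChatGeneration now wraps a ChatMessage rather than carrying text and role itself. The finish_reason entry is only illustrative.

from langchain.schema import ChatGeneration, ChatMessage

message = ChatMessage(text="Hello there!", role="assistant")
generation = ChatGeneration(message=message, generation_info={"finish_reason": "stop"})
print(generation.message.role, generation.message.text)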