commit acaa2d3ee4 (parent f635a31992)
Harrison Chase, 2023-03-01 16:44:45 -08:00
13 changed files with 532 additions and 32 deletions


@@ -0,0 +1,127 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "319061e6",
"metadata": {},
"outputs": [],
"source": [
"from langchain.chat.conversation import ConversationChain\n",
"from langchain.chat_models.openai import OpenAI"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "4dbbb98f",
"metadata": {},
"outputs": [],
"source": [
"model = OpenAI()"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "164af02b",
"metadata": {},
"outputs": [],
"source": [
"chain = ConversationChain.from_model(model=model)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "11830ac7",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'\\n\\nHello there! How may I assist you today?'"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"chain.run(\"hi!\")"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "8182f6c8",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'As an AI language model, I do not have access to your physical location. However, since you are interacting with me through this platform, I believe you are connected to the internet and currently using a device such as a computer, phone or tablet.'"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"chain.run('where am i?')"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "ecc2fbbb",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'I said that as an AI language model, I do not have access to your physical location. However, since you are interacting with me through this platform, I believe you are connected to the internet and currently using a device such as a computer, phone, or tablet.'"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"chain.run('what did you say?')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "614977d8",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.1"
}
},
"nbformat": 4,
"nbformat_minor": 5
}


@@ -20,7 +20,7 @@ from langchain.llms.base import BaseLLM
from langchain.memory.chat_memory import ChatMemory
from langchain.memory.utils import get_buffer_string
from langchain.prompts.base import BasePromptTemplate
from langchain.schema import ChatGeneration
from langchain.schema import ChatMessage
def _get_prompt_input_key(inputs: Dict[str, Any], memory_variables: List[str]) -> str:
@@ -134,14 +134,14 @@ class ConversationBufferMemory(ChatMemoryMixin, BaseModel):
return {self.memory_key: self.buffer}
class ConversationBufferWindowMemory(Memory, ChatMemoryMixin, BaseModel):
class ConversationBufferWindowMemory(ChatMemoryMixin, BaseModel):
"""Buffer for storing conversation memory."""
memory_key: str = "history" #: :meta private:
k: int = 5
@property
def buffer(self) -> List[ChatGeneration]:
def buffer(self) -> List[ChatMessage]:
"""String buffer of memory."""
return self.chat_memory.messages
@@ -162,7 +162,7 @@ class ConversationBufferWindowMemory(Memory, ChatMemoryMixin, BaseModel):
ConversationalBufferWindowMemory = ConversationBufferWindowMemory
class ConversationSummaryMemory(Memory, ChatMemoryMixin, BaseModel):
class ConversationSummaryMemory(ChatMemoryMixin, BaseModel):
"""Conversation summarizer to memory."""
buffer: str = ""
@@ -208,7 +208,7 @@ class ConversationSummaryMemory(Memory, ChatMemoryMixin, BaseModel):
self.buffer = ""
class ConversationEntityMemory(Memory, ChatMemoryMixin, BaseModel):
class ConversationEntityMemory(ChatMemoryMixin, BaseModel):
"""Entity extractor & summarizer to memory."""
"""Prefix to use for AI generated responses."""
@@ -221,7 +221,7 @@ class ConversationEntityMemory(Memory, ChatMemoryMixin, BaseModel):
chat_history_key: str = "history"
@property
def buffer(self) -> List[ChatGeneration]:
def buffer(self) -> List[ChatMessage]:
return self.chat_memory.messages
@property
@@ -281,7 +281,7 @@ class ConversationEntityMemory(Memory, ChatMemoryMixin, BaseModel):
self.store = {}
class ConversationSummaryBufferMemory(Memory, ChatMemoryMixin, BaseModel):
class ConversationSummaryBufferMemory(ChatMemoryMixin, BaseModel):
"""Buffer with summarizer for storing conversation memory."""
max_token_limit: int = 2000
@@ -291,7 +291,7 @@ class ConversationSummaryBufferMemory(Memory, ChatMemoryMixin, BaseModel):
memory_key: str = "history"
@property
def buffer(self) -> List[ChatGeneration]:
def buffer(self) -> List[ChatMessage]:
return self.chat_memory.messages
@property
@@ -321,7 +321,7 @@ class ConversationSummaryBufferMemory(Memory, ChatMemoryMixin, BaseModel):
)
return values
def get_num_tokens_list(self, arr: List[ChatGeneration]) -> List[int]:
def get_num_tokens_list(self, arr: List[ChatMessage]) -> List[int]:
"""Get list of number of tokens in each string in the input array."""
return [self.llm.get_num_tokens(get_buffer_string([x])) for x in arr]
@@ -348,7 +348,7 @@ class ConversationSummaryBufferMemory(Memory, ChatMemoryMixin, BaseModel):
self.moving_summary_buffer = ""
class ConversationKGMemory(Memory, ChatMemoryMixin, BaseModel):
class ConversationKGMemory(ChatMemoryMixin, BaseModel):
"""Knowledge graph memory for storing conversation memory.
Integrates with external knowledge graph to store and retrieve
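The common thread in the hunks above is that every memory buffer now holds ChatMessage objects rather than ChatGeneration objects. A minimal, purely illustrative sketch of what that buffer format looks like when rendered with get_buffer_string (not part of the commit itself):

from langchain.memory.utils import get_buffer_string
from langchain.schema import ChatMessage

# Two messages of the kind the memory classes above now keep in their buffers.
messages = [
    ChatMessage(text="hi!", role="user"),
    ChatMessage(text="Hello there! How may I assist you today?", role="assistant"),
]

# get_buffer_string joins them as "<role>: <text>" lines:
#   user: hi!
#   assistant: Hello there! How may I assist you today?
print(get_buffer_string(messages))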

langchain/chat/base.py (new file, 28 additions)

@@ -0,0 +1,28 @@
from langchain.chains.base import Chain
from abc import ABC
from langchain.memory.chat_memory import ChatMemory
from pydantic import root_validator
from typing import Dict
class BaseChatChain(Chain, ABC):
human_prefix: str = "user"
ai_prefix: str = "assistant"
@root_validator(pre=True)
def validate_memory_keys(cls, values: Dict) -> Dict:
"""Validate that the human and ai prefixes line up."""
memory = values["memory"]
if isinstance(memory, ChatMemory):
if memory.human_prefix != values["human_prefix"]:
raise ValueError(
f"Memory human_prefix ({memory.human_prefix}) must "
f"match chain human_prefix ({values['human_prefix']})"
)
if memory.ai_prefix != values["ai_prefix"]:
raise ValueError(
f"Memory ai_prefix ({memory.ai_prefix}) must "
f"match chain ai_prefix ({values['ai_prefix']})"
)
return values
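To make the intent of validate_memory_keys concrete, here is an illustrative sketch of the mismatch it guards against. ConversationChain and SimpleChatMemory come from files added later in this commit, and an OPENAI_API_KEY in the environment is assumed; the snippet itself is not part of the change:

from langchain.chat.conversation import ConversationChain
from langchain.chat.memory import SimpleChatMemory
from langchain.chat_models.openai import OpenAI

# The chain defaults to human_prefix="user", so a memory configured with a
# different human prefix is rejected by the root validator.
mismatched_memory = SimpleChatMemory(human_prefix="human")
try:
    ConversationChain(model=OpenAI(), memory=mismatched_memory)
except ValueError as err:
    # The error explains that the memory's human_prefix must match the chain's.
    print(err)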


@@ -0,0 +1,66 @@
"""Chain that carries on a conversation and calls an LLM."""
from typing import Dict, List, Any
from pydantic import BaseModel, Extra, Field, root_validator
from langchain.chat.base import BaseChatChain
from langchain.chains.conversation.prompt import PROMPT
from langchain.chains.llm import LLMChain
from langchain.prompts.base import BasePromptTemplate
from langchain.chat.memory import SimpleChatMemory
from langchain.chat_models.base import BaseChat
from langchain.schema import ChatMessage
class ConversationChain(BaseChatChain, BaseModel):
"""Chain to have a conversation and load context from memory.
Example:
.. code-block:: python
from langchain.chat.conversation import ConversationChain
from langchain.chat_models.openai import OpenAI
conversation = ConversationChain(model=OpenAI())
"""
model: BaseChat
memory: SimpleChatMemory = Field(default_factory=SimpleChatMemory)
"""Default memory store."""
prompt: BasePromptTemplate = PROMPT
"""Default conversation prompt to use."""
input_key: str = "input" #: :meta private:
output_key: str = "response" #: :meta private:
starter_messages: List[ChatMessage] = Field(default_factory=list)
@classmethod
def from_model(cls, model: BaseModel, **kwargs: Any):
"""From model. Future proofing."""
return cls(model=model, **kwargs)
@property
def _chain_type(self) -> str:
return "chat:conversation"
@property
def output_keys(self) -> List[str]:
return [self.output_key]
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
new_message = ChatMessage(text=inputs[self.input_key], role=self.human_prefix)
messages = self.starter_messages + self.memory.messages + [new_message]
output = self.model.run(messages)
return {self.output_key: output.text}
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
arbitrary_types_allowed = True
@property
def input_keys(self) -> List[str]:
"""Use this since so some prompt vars come from history."""
return [self.input_key]
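One thing the notebook at the top of this commit does not show is starter_messages, which lets a caller prepend priming messages ahead of the memory and the new user input. A hedged sketch of that usage; the "system" role string is an assumption here, not something this diff defines:

from langchain.chat.conversation import ConversationChain
from langchain.chat_models.openai import OpenAI
from langchain.schema import ChatMessage

# starter_messages are sent before the memory and the new input on every call.
chain = ConversationChain.from_model(
    model=OpenAI(),
    starter_messages=[ChatMessage(text="You are a terse assistant.", role="system")],
)
print(chain.run("hi!"))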

langchain/chat/memory.py (new file, 28 additions)

@@ -0,0 +1,28 @@
from typing import Dict, Any, List
from langchain.chains.base import Memory
from langchain.memory.chat_memory import ChatMemory
from abc import ABC, abstractmethod
def _get_prompt_input_key(inputs: Dict[str, Any]) -> str:
# "stop" is a special key that can be passed as input but is not used to
# format the prompt.
prompt_input_keys = list(set(inputs).difference(["stop"]))
if len(prompt_input_keys) != 1:
raise ValueError(f"One input key expected got {prompt_input_keys}")
return prompt_input_keys[0]
class SimpleChatMemory(Memory, ChatMemory):
def clear(self) -> None:
# Reset the stored messages (calling self.clear() here would recurse forever).
self.messages = []
@property
def memory_variables(self) -> List[str]:
return []
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]:
return {}
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
self.add_user_message(inputs[_get_prompt_input_key(inputs)])
self.add_ai_message(outputs[_get_prompt_input_key(outputs)])
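A small illustrative sketch of how SimpleChatMemory accumulates turns when a chain calls save_context; the key names "input" and "response" are arbitrary, since each dict only needs exactly one non-"stop" key:

from langchain.chat.memory import SimpleChatMemory

memory = SimpleChatMemory()
memory.save_context({"input": "hi!"}, {"response": "Hello there!"})
memory.save_context({"input": "where am i?"}, {"response": "I cannot see your location."})

# Each call appended one user message and one assistant message.
print([(m.role, m.text) for m in memory.messages])
memory.clear()  # resets the buffer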


@@ -0,0 +1,66 @@
"""Question Answering."""
from typing import Dict, List, Any
from pydantic import BaseModel, Extra, Field, root_validator
from langchain.chat.base import BaseChatChain
from langchain.chains.conversation.prompt import PROMPT
from langchain.chains.llm import LLMChain
from langchain.prompts.base import BasePromptTemplate
from langchain.chat.memory import SimpleChatMemory
from langchain.chat_models.base import BaseChat
from langchain.schema import ChatMessage
class QAChain(BaseChatChain, BaseModel):
"""Chain to have a conversation and load context from memory.
Example:
.. code-block:: python
from langchain import ConversationChain, OpenAI
conversation = ConversationChain(llm=OpenAI())
"""
model: BaseChat
"""Default memory store."""
prompt: BasePromptTemplate = PROMPT
"""Default conversation prompt to use."""
question_key: str = "question" #: :meta private:
documents_key: str = "input_documents" #: :meta private:
output_key: str = "response" #: :meta private:
starter_messages: List[ChatMessage] = Field(default_factory=list)
@classmethod
def from_model(cls, model: BaseModel, **kwargs: Any):
"""From model. Future proofing."""
return cls(model=model, **kwargs)
@property
def _chain_type(self) -> str:
return "chat:qa"
@property
def output_keys(self) -> List[str]:
return [self.output_key]
def _call(self, inputs: Dict[str, Any]) -> Dict[str, str]:
new_message = ChatMessage(text=inputs[self.question_key], role=self.human_prefix)
docs = inputs[self.documents_key]
doc_messages = [ChatMessage(text=doc.page_content, role=self.human_prefix) for doc in docs]
messages = self.starter_messages + doc_messages + [new_message]
output = self.model.run(messages)
return {self.output_key: output.text}
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
arbitrary_types_allowed = True
@property
def input_keys(self) -> List[str]:
"""Use this since so some prompt vars come from history."""
return [self.question_key, self.documents_key]
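An illustrative sketch of driving QAChain directly. This file's module path is not shown in the diff, so QAChain is left unimported here; the Document import and the sample texts are assumptions for illustration:

from langchain.chat_models.openai import OpenAI
from langchain.docstore.document import Document

docs = [
    Document(page_content="The new chat chains pass role-tagged messages to the model."),
    Document(page_content="QAChain feeds each input document to the model as a user message."),
]
qa_chain = QAChain(model=OpenAI())  # assumes QAChain is importable from this module
result = qa_chain({"question": "How does QAChain pass documents to the model?", "input_documents": docs})
print(result["response"])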


@@ -0,0 +1,170 @@
"""Chain for question-answering against a vector database."""
from __future__ import annotations
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Extra, Field, root_validator
from langchain.chains.base import Chain
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
from langchain.chains.llm import LLMChain
from langchain.chains.question_answering import load_qa_chain
from langchain.chains.vector_db_qa.prompt import PROMPT
from langchain.llms.base import BaseLLM
from langchain.prompts import PromptTemplate
from langchain.vectorstores.base import VectorStore
class VectorDBQA(Chain, BaseModel):
"""Chain for question-answering against a vector database.
Example:
.. code-block:: python
from langchain import OpenAI, VectorDBQA
from langchain.faiss import FAISS
vectordb = FAISS(...)
vectordbQA = VectorDBQA(llm=OpenAI(), vectorstore=vectordb)
"""
vectorstore: VectorStore = Field(exclude=True)
"""Vector Database to connect to."""
k: int = 4
"""Number of documents to query for."""
combine_documents_chain: BaseCombineDocumentsChain
"""Chain to use to combine the documents."""
input_key: str = "query" #: :meta private:
output_key: str = "result" #: :meta private:
return_source_documents: bool = False
"""Return the source documents."""
search_kwargs: Dict[str, Any] = Field(default_factory=dict)
"""Extra search args."""
search_type: str = "similarity"
"""Search type to use over vectorstore. `similarity` or `mmr`."""
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
arbitrary_types_allowed = True
@property
def input_keys(self) -> List[str]:
"""Return the input keys.
:meta private:
"""
return [self.input_key]
@property
def output_keys(self) -> List[str]:
"""Return the output keys.
:meta private:
"""
_output_keys = [self.output_key]
if self.return_source_documents:
_output_keys = _output_keys + ["source_documents"]
return _output_keys
# TODO: deprecate this
@root_validator(pre=True)
def load_combine_documents_chain(cls, values: Dict) -> Dict:
"""Validate question chain."""
if "combine_documents_chain" not in values:
if "llm" not in values:
raise ValueError(
"If `combine_documents_chain` not provided, `llm` should be."
)
prompt = values.pop("prompt", PROMPT)
llm = values.pop("llm")
llm_chain = LLMChain(llm=llm, prompt=prompt)
document_prompt = PromptTemplate(
input_variables=["page_content"], template="Context:\n{page_content}"
)
combine_documents_chain = StuffDocumentsChain(
llm_chain=llm_chain,
document_variable_name="context",
document_prompt=document_prompt,
)
values["combine_documents_chain"] = combine_documents_chain
return values
@root_validator()
def validate_search_type(cls, values: Dict) -> Dict:
"""Validate search type."""
if "search_type" in values:
search_type = values["search_type"]
if search_type not in ("similarity", "mmr"):
raise ValueError(f"search_type of {search_type} not allowed.")
return values
@classmethod
def from_llm(
cls, llm: BaseLLM, prompt: PromptTemplate = PROMPT, **kwargs: Any
) -> VectorDBQA:
"""Initialize from LLM."""
llm_chain = LLMChain(llm=llm, prompt=prompt)
document_prompt = PromptTemplate(
input_variables=["page_content"], template="Context:\n{page_content}"
)
combine_documents_chain = StuffDocumentsChain(
llm_chain=llm_chain,
document_variable_name="context",
document_prompt=document_prompt,
)
return cls(combine_documents_chain=combine_documents_chain, **kwargs)
@classmethod
def from_chain_type(
cls,
llm: BaseLLM,
chain_type: str = "stuff",
chain_type_kwargs: Optional[dict] = None,
**kwargs: Any,
) -> VectorDBQA:
"""Load chain from chain type."""
_chain_type_kwargs = chain_type_kwargs or {}
combine_documents_chain = load_qa_chain(
llm, chain_type=chain_type, **_chain_type_kwargs
)
return cls(combine_documents_chain=combine_documents_chain, **kwargs)
def _call(self, inputs: Dict[str, str]) -> Dict[str, Any]:
"""Run similarity search and llm on input query.
If chain has 'return_source_documents' as 'True', returns
the retrieved documents as well under the key 'source_documents'.
Example:
.. code-block:: python
res = vectordbqa({'query': 'This is my query'})
answer, docs = res['result'], res['source_documents']
"""
question = inputs[self.input_key]
if self.search_type == "similarity":
docs = self.vectorstore.similarity_search(
question, k=self.k, **self.search_kwargs
)
elif self.search_type == "mmr":
docs = self.vectorstore.max_marginal_relevance_search(
question, k=self.k, **self.search_kwargs
)
else:
raise ValueError(f"search_type of {self.search_type} not allowed.")
answer, _ = self.combine_documents_chain.combine_docs(docs, question=question)
if self.return_source_documents:
return {self.output_key: answer, "source_documents": docs}
else:
return {self.output_key: answer}
@property
def _chain_type(self) -> str:
"""Return the chain type."""
return "chat:vector_db_qa"


@@ -1,12 +1,12 @@
from abc import ABC, abstractmethod
from typing import Dict, List, Optional
from langchain.schema import ChatGeneration, ChatResult
from langchain.schema import ChatGeneration, ChatResult, ChatMessage
class BaseChat(ABC):
def generate(
self, messages: List[Dict], stop: Optional[List[str]] = None
self, messages: List[ChatMessage], stop: Optional[List[str]] = None
) -> ChatResult:
"""Top Level call"""
# Nothing here now, but future proofing.
@@ -14,21 +14,29 @@ class BaseChat(ABC):
@abstractmethod
def _generate(
self, messages: List[Dict], stop: Optional[List[str]] = None
self, messages: List[ChatMessage], stop: Optional[List[str]] = None
) -> ChatResult:
"""Top Level call"""
def run(
self, messages: List[ChatMessage], stop: Optional[List[str]] = None
) -> ChatMessage:
res = self.generate(messages, stop=stop)
return res.generations[0].message
class SimpleChat(BaseChat):
role: str = "assistant"
def _generate(
self, messages: List[Dict], stop: Optional[List[str]] = None
self, messages: List[ChatMessage], stop: Optional[List[str]] = None
) -> ChatResult:
output_str = self._call(messages, stop=stop)
generation = ChatGeneration(text=output_str, role=self.role)
message = ChatMessage(text=output_str, role=self.role)
generation = ChatGeneration(message=message)
return ChatResult(generations=[generation])
@abstractmethod
def _call(self, messages: List[Dict], stop: Optional[List[str]] = None) -> str:
def _call(self, messages: List[ChatMessage], stop: Optional[List[str]] = None) -> str:
"""Simpler interface."""


@@ -12,13 +12,13 @@ from tenacity import (
)
from langchain.chat_models.base import BaseChat
from langchain.schema import ChatGeneration, ChatResult
from langchain.schema import ChatGeneration, ChatResult, ChatMessage
from langchain.utils import get_from_dict_or_env
logger = logging.getLogger(__file__)
class OpenAIChat(BaseChat, BaseModel):
class OpenAI(BaseChat, BaseModel):
"""Wrapper around OpenAI Chat large language models.
To use, you should have the ``openai`` python package installed, and the
@@ -124,19 +124,19 @@ class OpenAIChat(BaseChat, BaseModel):
return _completion_with_retry(**kwargs)
def _generate(
self, messages: List[Dict], stop: Optional[List[str]] = None
self, messages: List[ChatMessage], stop: Optional[List[str]] = None
) -> ChatResult:
params: Dict[str, Any] = {**{"model": self.model_name}, **self._default_params}
if stop is not None:
if "stop" in params:
raise ValueError("`stop` found in both the input and default params.")
params["stop"] = stop
response = self.completion_with_retry(messages=messages, **params)
message_dicts = [{"role": m.role, "content": m.text} for m in messages]
response = self.completion_with_retry(messages=message_dicts, **params)
generations = []
for res in response["choices"]:
gen = ChatGeneration(
text=res["message"]["content"], role=res["message"]["role"]
)
message = ChatMessage(text=res["message"]["content"], role=res["message"]["role"])
gen = ChatGeneration(message=message)
generations.append(gen)
return ChatResult(generations=generations)
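End to end, the renamed wrapper now takes ChatMessage objects and converts them to the role/content dicts the OpenAI API expects. A minimal sketch of a direct call, assuming the openai package is installed and OPENAI_API_KEY is set:

from langchain.chat_models.openai import OpenAI
from langchain.schema import ChatMessage

model = OpenAI()
# run() returns the first generation's ChatMessage.
reply = model.run([ChatMessage(text="hi!", role="user")])
print(reply.role, reply.text)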


@@ -2,20 +2,20 @@ from typing import List
from pydantic import BaseModel, Field
from langchain.schema import ChatGeneration
from langchain.schema import ChatMessage
class ChatMemory(BaseModel):
human_prefix: str = "user"
ai_prefix: str = "assistant"
messages: List[ChatGeneration] = Field(default_factory=list)
messages: List[ChatMessage] = Field(default_factory=list)
def add_user_message(self, message: str) -> None:
gen = ChatGeneration(text=message, role=self.human_prefix)
gen = ChatMessage(text=message, role=self.human_prefix)
self.messages.append(gen)
def add_ai_message(self, message: str) -> None:
gen = ChatGeneration(text=message, role=self.ai_prefix)
gen = ChatMessage(text=message, role=self.ai_prefix)
self.messages.append(gen)
def clear(self):
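A short illustrative sketch of the renamed helpers: they now append ChatMessage objects tagged with whatever prefixes the memory was configured with.

from langchain.memory.chat_memory import ChatMemory

memory = ChatMemory(human_prefix="human", ai_prefix="bot")
memory.add_user_message("hi!")
memory.add_ai_message("hello!")
print([(m.role, m.text) for m in memory.messages])  # [('human', 'hi!'), ('bot', 'hello!')]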


@@ -1,8 +1,8 @@
from typing import List
from langchain.schema import ChatGeneration
from langchain.schema import ChatMessage
def get_buffer_string(messages: List[ChatGeneration]):
def get_buffer_string(messages: List[ChatMessage]):
"""Get buffer string of messages."""
return "\n".join([f"{gen.role}: {gen.text}" for gen in messages])


@@ -46,11 +46,10 @@ class LLMResult:
llm_output: Optional[dict] = None
"""For arbitrary LLM provider specific output."""
@dataclass_json
@dataclass
class ChatGeneration:
"""Output of a single generation."""
class ChatMessage:
"""Message object."""
text: str
"""Generated text output."""
@@ -58,6 +57,14 @@ class ChatGeneration:
role: str
"""Role of the chatter."""
@dataclass_json
@dataclass
class ChatGeneration:
"""Output of a single generation."""
message: ChatMessage
generation_info: Optional[Dict[str, Any]] = None
"""Raw generation info response from the provider"""
"""May include things like reason for finishing (e.g. in OpenAI)"""