mirror of https://github.com/hwchase17/langchain.git
synced 2026-02-21 06:33:41 +00:00
Compare commits
8 Commits
langchain-
...
ankush/cal
| Author | SHA1 | Date |
|---|---|---|
| | fe353c610f | |
| | f897e7102b | |
| | 366de1bd58 | |
| | 4553f64b4b | |
| | 3c7a559e77 | |
| | 7d465cbc2f | |
| | ecdfbfe1c7 | |
| | 4d366eeea3 | |
@@ -11,11 +11,20 @@
 },
 {
 "cell_type": "code",
-"execution_count": 3,
+"execution_count": 1,
 "id": "5268c7fa",
-"metadata": {},
+"metadata": {
+"tags": []
+},
 "outputs": [],
 "source": [
+"import os\n",
+"os.environ[\"LANGCHAIN_HANDLER\"] = \"langchain\"\n",
+"\n",
+"## Uncomment this if using hosted setup.\n",
+"\n",
+"os.environ[\"LANGCHAIN_ENDPOINT\"] = \"http://127.0.0.1:8000\" \n",
+"\n",
 "from langchain.agents import ZeroShotAgent, Tool, AgentExecutor\n",
 "from langchain.chains import LLMChain\n",
 "from langchain.utilities import SerpAPIWrapper"
@@ -23,9 +32,11 @@
 },
 {
 "cell_type": "code",
-"execution_count": 4,
+"execution_count": 2,
 "id": "fbaa4dbe",
-"metadata": {},
+"metadata": {
+"tags": []
+},
 "outputs": [],
 "source": [
 "search = SerpAPIWrapper()\n",
@@ -40,9 +51,11 @@
 },
 {
 "cell_type": "code",
-"execution_count": 6,
+"execution_count": 3,
 "id": "f3ba6f08",
-"metadata": {},
+"metadata": {
+"tags": []
+},
 "outputs": [],
 "source": [
 "prefix = \"\"\"Answer the following questions as best you can, but speaking as a pirate might speak. You have access to the following tools:\"\"\"\n",
@@ -58,9 +71,11 @@
 },
 {
 "cell_type": "code",
-"execution_count": 7,
+"execution_count": 4,
 "id": "3547a37d",
-"metadata": {},
+"metadata": {
+"tags": []
+},
 "outputs": [],
 "source": [
 "from langchain.chat_models import ChatOpenAI\n",
@@ -79,9 +94,11 @@
 },
 {
 "cell_type": "code",
-"execution_count": 8,
+"execution_count": 5,
 "id": "a78f886f",
-"metadata": {},
+"metadata": {
+"tags": []
+},
 "outputs": [],
 "source": [
 "messages = [\n",
@@ -94,9 +111,11 @@
 },
 {
 "cell_type": "code",
-"execution_count": 9,
+"execution_count": 6,
 "id": "dadadd70",
-"metadata": {},
+"metadata": {
+"tags": []
+},
 "outputs": [],
 "source": [
 "prompt = ChatPromptTemplate.from_messages(messages)"
@@ -104,9 +123,11 @@
 },
 {
 "cell_type": "code",
-"execution_count": 10,
+"execution_count": 7,
 "id": "b7180182",
-"metadata": {},
+"metadata": {
+"tags": []
+},
 "outputs": [],
 "source": [
 "llm_chain = LLMChain(llm=ChatOpenAI(temperature=0), prompt=prompt)"
@@ -114,9 +135,11 @@
 },
 {
 "cell_type": "code",
-"execution_count": 11,
+"execution_count": 8,
 "id": "ddddb07b",
-"metadata": {},
+"metadata": {
+"tags": []
+},
 "outputs": [],
 "source": [
 "tool_names = [tool.name for tool in tools]\n",
@@ -125,9 +148,11 @@
 },
 {
 "cell_type": "code",
-"execution_count": 12,
+"execution_count": 9,
 "id": "36aef054",
-"metadata": {},
+"metadata": {
+"tags": []
+},
 "outputs": [],
 "source": [
 "agent_executor = AgentExecutor.from_agent_and_tools(agent=agent, tools=tools, verbose=True)"
@@ -135,9 +160,11 @@
 },
 {
 "cell_type": "code",
-"execution_count": 13,
+"execution_count": 10,
 "id": "33a4d6cc",
-"metadata": {},
+"metadata": {
+"tags": []
+},
 "outputs": [
 {
 "name": "stdout",
@@ -146,16 +173,16 @@
 "\n",
 "\n",
 "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
-"\u001b[32;1m\u001b[1;3mArrr, ye be in luck, matey! I'll find ye the answer to yer question.\n",
+"\u001b[32;1m\u001b[1;3mArrr, ye be in luck, matey! I'll help ye find the answer ye seek.\n",
 "\n",
-"Thought: I need to search for the current population of Canada.\n",
+"Thought: Hmm, I don't have this information off the top of me head.\n",
 "Action: Search\n",
-"Action Input: \"current population of Canada 2023\"\n",
+"Action Input: \"Canada population 2023\"\n",
 "\u001b[0m\n",
-"Observation: \u001b[36;1m\u001b[1;3mThe current population of Canada is 38,623,091 as of Saturday, March 4, 2023, based on Worldometer elaboration of the latest United Nations data.\u001b[0m\n",
+"Observation: \u001b[36;1m\u001b[1;3mThe current population of Canada is 38,627,607 as of Thursday, March 9, 2023, based on Worldometer elaboration of the latest United Nations data.\u001b[0m\n",
-"Thought:\u001b[32;1m\u001b[1;3mAhoy, me hearties! I've found the answer to yer question.\n",
+"Thought:\u001b[32;1m\u001b[1;3mAhoy, me hearties! I have found the answer to yer question.\n",
 "\n",
-"Final Answer: As of March 4, 2023, the population of Canada be 38,623,091. Arrr!\u001b[0m\n",
+"Final Answer: As of March 9, 2023, the population of Canada is 38,627,607. Arrr!\u001b[0m\n",
 "\n",
 "\u001b[1m> Finished chain.\u001b[0m\n"
 ]
@@ -163,10 +190,10 @@
 {
 "data": {
 "text/plain": [
-"'As of March 4, 2023, the population of Canada be 38,623,091. Arrr!'"
+"'As of March 9, 2023, the population of Canada is 38,627,607. Arrr!'"
 ]
 },
-"execution_count": 13,
+"execution_count": 10,
 "metadata": {},
 "output_type": "execute_result"
 }
@@ -200,7 +227,7 @@
 "name": "python",
 "nbconvert_exporter": "python",
 "pygments_lexer": "ipython3",
-"version": "3.9.1"
+"version": "3.10.9"
 }
 },
 "nbformat": 4,
@@ -372,7 +372,7 @@
 "name": "python",
 "nbconvert_exporter": "python",
 "pygments_lexer": "ipython3",
-"version": "3.9.1"
+"version": "3.10.9"
 }
 },
 "nbformat": 4,
@@ -151,7 +151,7 @@
 "name": "python",
 "nbconvert_exporter": "python",
 "pygments_lexer": "ipython3",
-"version": "3.9.1"
+"version": "3.10.9"
 }
 },
 "nbformat": 4,
@@ -22,7 +22,7 @@
 "\n",
 "## Uncomment this if using hosted setup.\n",
 "\n",
-"# os.environ[\"LANGCHAIN_ENDPOINT\"] = \"https://langchain-api-gateway-57eoxz8z.uc.gateway.dev\" \n",
+"os.environ[\"LANGCHAIN_ENDPOINT\"] = \"http://127.0.0.1:8000\" \n",
 "\n",
 "## Uncomment this if you want traces to be recorded to \"my_session\" instead of default.\n",
 "\n",
@@ -89,9 +89,30 @@
 },
 {
 "cell_type": "code",
-"execution_count": null,
+"execution_count": 4,
 "id": "25addd7f",
 "metadata": {},
+"outputs": [
+{
+"data": {
+"text/plain": [
+"'\\n\\nQ: What did the fish say when it hit the wall?\\nA: Dam!'"
+]
+},
+"execution_count": 4,
+"metadata": {},
+"output_type": "execute_result"
+}
+],
+"source": [
+"llm(\"tell me a joke\")"
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"id": "7afdd321-796f-4dcc-816c-c1160aa943af",
+"metadata": {},
 "outputs": [],
 "source": []
 }
@@ -3,6 +3,7 @@ from typing import Any, List, Optional, Sequence, Tuple
 
 from langchain.agents.agent import Agent
 from langchain.agents.chat.prompt import FORMAT_INSTRUCTIONS, PREFIX, SUFFIX
+from langchain.base_language_model import BaseLanguageModel
 from langchain.callbacks.base import BaseCallbackManager
 from langchain.chains.llm import LLMChain
 from langchain.prompts.base import BasePromptTemplate
@@ -11,7 +12,7 @@ from langchain.prompts.chat import (
     HumanMessagePromptTemplate,
     SystemMessagePromptTemplate,
 )
-from langchain.schema import AgentAction, BaseLanguageModel
+from langchain.schema import AgentAction
 from langchain.tools import BaseTool
 
 FINAL_ANSWER_ACTION = "Final Answer:"
langchain/base_language_model.py (new file, 124 lines)
@@ -0,0 +1,124 @@
+"""Base class for language models."""
+from __future__ import annotations
+
+from abc import ABC, abstractmethod
+from typing import List, Optional
+
+from pydantic import BaseModel, Extra, Field, validator
+
+import langchain
+from langchain.callbacks import BaseCallbackManager, get_callback_manager
+from langchain.schema import LLMResult, PromptValue
+
+
+def _get_verbosity() -> bool:
+    return langchain.verbose
+
+
+class BaseLanguageModel(BaseModel, ABC):
+    """Base class for language models."""
+
+    verbose: bool = Field(default_factory=_get_verbosity)
+    """Whether to print out response text."""
+    callback_manager: BaseCallbackManager = Field(default_factory=get_callback_manager)
+
+    class Config:
+        """Configuration for this pydantic object."""
+
+        extra = Extra.forbid
+        arbitrary_types_allowed = True
+
+    @validator("callback_manager", pre=True, always=True)
+    def set_callback_manager(
+        cls, callback_manager: Optional[BaseCallbackManager]
+    ) -> BaseCallbackManager:
+        """If callback manager is None, set it.
+
+        This allows users to pass in None as callback manager, which is a nice UX.
+        """
+        return callback_manager or get_callback_manager()
+
+    @validator("verbose", pre=True, always=True)
+    def set_verbose(cls, verbose: Optional[bool]) -> bool:
+        """If verbose is None, set it.
+
+        This allows users to pass in None as verbose to access the global setting.
+        """
+        if verbose is None:
+            return _get_verbosity()
+        else:
+            return verbose
+
+    def generate_prompt(
+        self, prompts: List[PromptValue], stop: Optional[List[str]] = None
+    ) -> LLMResult:
+        """Take in a list of prompt values and return an LLMResult."""
+        self.callback_manager.on_llm_start_prompt_value(
+            {"name": self.__class__.__name__}, prompts, verbose=self.verbose
+        )
+        try:
+            output = self._generate_prompt(prompts, stop=stop)
+        except (KeyboardInterrupt, Exception) as e:
+            self.callback_manager.on_llm_error(e, verbose=self.verbose)
+            raise e
+        self.callback_manager.on_llm_end(output, verbose=self.verbose)
+        return output
+
+    async def agenerate_prompt(
+        self, prompts: List[PromptValue], stop: Optional[List[str]] = None
+    ) -> LLMResult:
+        """Take in a list of prompt values and return an LLMResult."""
+        if self.callback_manager.is_async:
+            await self.callback_manager.on_llm_start_prompt_value(
+                {"name": self.__class__.__name__}, prompts, verbose=self.verbose
+            )
+        else:
+            self.callback_manager.on_llm_start_prompt_value(
+                {"name": self.__class__.__name__}, prompts, verbose=self.verbose
+            )
+        try:
+            output = await self._agenerate_prompt(prompts, stop=stop)
+        except (KeyboardInterrupt, Exception) as e:
+            if self.callback_manager.is_async:
+                await self.callback_manager.on_llm_error(e, verbose=self.verbose)
+            else:
+                self.callback_manager.on_llm_error(e, verbose=self.verbose)
+            raise e
+        if self.callback_manager.is_async:
+            await self.callback_manager.on_llm_end(output, verbose=self.verbose)
+        else:
+            self.callback_manager.on_llm_end(output, verbose=self.verbose)
+        return output
+
+    @abstractmethod
+    def _generate_prompt(
+        self, prompts: List[PromptValue], stop: Optional[List[str]] = None
+    ) -> LLMResult:
+        """Take in a list of prompt values and return an LLMResult."""
+
+    @abstractmethod
+    async def _agenerate_prompt(
+        self, prompts: List[PromptValue], stop: Optional[List[str]] = None
+    ) -> LLMResult:
+        """Take in a list of prompt values and return an LLMResult."""
+
+    def get_num_tokens(self, text: str) -> int:
+        """Get the number of tokens present in the text."""
+        # TODO: this method may not be exact.
+        # TODO: this method may differ based on model (eg codex).
+        try:
+            from transformers import GPT2TokenizerFast
+        except ImportError:
+            raise ValueError(
+                "Could not import transformers python package. "
+                "This is needed in order to calculate get_num_tokens. "
+                "Please it install it with `pip install transformers`."
+            )
+        # create a GPT-3 tokenizer instance
+        tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
+
+        # tokenize the text using the GPT-3 tokenizer
+        tokenized_text = tokenizer.tokenize(text)
+
+        # calculate the number of tokens in the tokenized text
+        return len(tokenized_text)
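For orientation, a minimal sketch of how a concrete model could sit on top of this new base class: only the two abstract hooks need implementing. The `FakeLanguageModel` name and its canned response are invented for illustration and are not part of this commit.

```python
from typing import List, Optional

from langchain.base_language_model import BaseLanguageModel
from langchain.schema import Generation, LLMResult, PromptValue


class FakeLanguageModel(BaseLanguageModel):
    """Hypothetical subclass: returns a fixed string for every prompt."""

    response: str = "hello"

    def _generate_prompt(
        self, prompts: List[PromptValue], stop: Optional[List[str]] = None
    ) -> LLMResult:
        # One list of generations per incoming PromptValue.
        return LLMResult(
            generations=[[Generation(text=self.response)] for _ in prompts]
        )

    async def _agenerate_prompt(
        self, prompts: List[PromptValue], stop: Optional[List[str]] = None
    ) -> LLMResult:
        # Reuse the sync path; a real model would await its client here.
        return self._generate_prompt(prompts, stop=stop)
```

Callers would then go through the public `generate_prompt` / `agenerate_prompt`, which add the callback-manager bookkeeping shown in the file above.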
@@ -4,7 +4,7 @@ import functools
 from abc import ABC, abstractmethod
 from typing import Any, Dict, List, Union
 
-from langchain.schema import AgentAction, AgentFinish, LLMResult
+from langchain.schema import AgentAction, AgentFinish, LLMResult, PromptValue
 
 
 class BaseCallbackHandler(ABC):
@@ -30,6 +30,12 @@ class BaseCallbackHandler(ABC):
         """Whether to ignore agent callbacks."""
         return False
 
+    def on_llm_start_prompt_value(
+        self, serialized: Dict[str, Any], prompts: List[PromptValue], **kwargs: Any
+    ) -> Any:
+        """Run when LLM starts running."""
+        pass
+
     @abstractmethod
     def on_llm_start(
         self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
@@ -127,6 +133,19 @@ class CallbackManager(BaseCallbackManager):
         """Initialize callback manager."""
         self.handlers: List[BaseCallbackHandler] = handlers
 
+    def on_llm_start_prompt_value(
+        self,
+        serialized: Dict[str, Any],
+        prompts: List[PromptValue],
+        verbose: bool = False,
+        **kwargs: Any
+    ) -> Any:
+        """Run when LLM starts running."""
+        for handler in self.handlers:
+            if not handler.ignore_llm:
+                if verbose or handler.always_verbose:
+                    handler.on_llm_start_prompt_value(serialized, prompts, **kwargs)
+
     def on_llm_start(
         self,
         serialized: Dict[str, Any],
@@ -276,6 +295,11 @@ class CallbackManager(BaseCallbackManager):
 class AsyncCallbackHandler(BaseCallbackHandler):
     """Async callback handler that can be used to handle callbacks from langchain."""
 
+    async def on_llm_start_prompt_value(
+        self, serialized: Dict[str, Any], prompts: List[PromptValue], **kwargs: Any
+    ) -> Any:
+        """Run when LLM starts running."""
+
     async def on_llm_start(
         self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
     ) -> None:
@@ -340,6 +364,32 @@ class AsyncCallbackManager(BaseCallbackManager):
         """Initialize callback manager."""
         self.handlers: List[BaseCallbackHandler] = handlers
 
+    async def on_llm_start_prompt_value(
+        self,
+        serialized: Dict[str, Any],
+        prompts: List[PromptValue],
+        verbose: bool = False,
+        **kwargs: Any
+    ) -> Any:
+        """Run when LLM starts running."""
+        for handler in self.handlers:
+            if not handler.ignore_llm:
+                if verbose or handler.always_verbose:
+                    if asyncio.iscoroutinefunction(handler.on_llm_start_prompt_value):
+                        return await handler.on_llm_start_prompt_value(
+                            serialized, prompts, **kwargs
+                        )
+                    else:
+                        return await asyncio.get_event_loop().run_in_executor(
+                            None,
+                            functools.partial(
+                                handler.on_llm_start_prompt_value,
+                                serialized,
+                                prompts,
+                                **kwargs
+                            ),
+                        )
+
     async def on_llm_start(
         self,
         serialized: Dict[str, Any],
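Both manager methods added above follow the same dispatch rule: skip handlers that ignore LLM events, and fire only when either the call or the handler is verbose. A self-contained toy version of that rule, with all names invented and no langchain imports:

```python
from dataclasses import dataclass
from typing import Any, Dict, List


@dataclass
class ToyHandler:
    """Stand-in for a callback handler; real handlers expose the same two flags."""

    name: str
    ignore_llm: bool = False
    always_verbose: bool = False

    def on_llm_start_prompt_value(self, serialized: Dict[str, Any], prompts: List[Any]) -> None:
        print(f"{self.name}: LLM starting with {len(prompts)} prompt value(s)")


def dispatch_llm_start(
    handlers: List[ToyHandler],
    serialized: Dict[str, Any],
    prompts: List[Any],
    verbose: bool = False,
) -> None:
    # Same shape as CallbackManager.on_llm_start_prompt_value in the hunk above.
    for handler in handlers:
        if not handler.ignore_llm:
            if verbose or handler.always_verbose:
                handler.on_llm_start_prompt_value(serialized, prompts)


dispatch_llm_start([ToyHandler("tracer", always_verbose=True)], {"name": "FakeLLM"}, ["ahoy"])
```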
@@ -8,7 +8,7 @@ from langchain.callbacks.base import (
     BaseCallbackManager,
     CallbackManager,
 )
-from langchain.schema import AgentAction, AgentFinish, LLMResult
+from langchain.schema import AgentAction, AgentFinish, LLMResult, PromptValue
 
 
 class Singleton:
@@ -34,6 +34,15 @@ class SharedCallbackManager(Singleton, BaseCallbackManager):
 
     _callback_manager: CallbackManager = CallbackManager(handlers=[])
 
+    def on_llm_start_prompt_value(
+        self, serialized: Dict[str, Any], prompts: List[PromptValue], **kwargs: Any
+    ) -> Any:
+        """Run when LLM starts running."""
+        with self._lock:
+            self._callback_manager.on_llm_start_prompt_value(
+                serialized, prompts, **kwargs
+            )
+
     def on_llm_start(
         self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
     ) -> None:
@@ -16,7 +16,7 @@ from langchain.callbacks.tracers.schemas import (
     TracerSession,
     TracerSessionCreate,
 )
-from langchain.schema import AgentAction, AgentFinish, LLMResult
+from langchain.schema import AgentAction, AgentFinish, LLMResult, PromptValue
 
 
 class TracerException(Exception):
@@ -109,9 +109,9 @@ class BaseTracer(BaseCallbackHandler, ABC):
         self._execution_order = 1
         self._persist_run(run)
 
-    def on_llm_start(
-        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
-    ) -> None:
+    def on_llm_start_prompt_value(
+        self, serialized: Dict[str, Any], prompts: List[PromptValue], **kwargs: Any
+    ) -> Any:
         """Start a trace for an LLM run."""
         if self._session is None:
             raise TracerException(
@@ -129,6 +129,12 @@ class BaseTracer(BaseCallbackHandler, ABC):
             )
         self._start_trace(llm_run)
 
+    def on_llm_start(
+        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
+    ) -> None:
+        """Start a trace for an LLM run."""
+        pass
+
     def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
         """Handle a new token for an LLM run."""
         pass
@@ -6,7 +6,7 @@ from typing import Any, Dict, List, Optional, Union
 
 from pydantic import BaseModel, Field
 
-from langchain.schema import LLMResult
+from langchain.schema import LLMResult, PromptValue
 
 
 class TracerSessionBase(BaseModel):
@@ -45,7 +45,7 @@ class BaseRun(BaseModel):
 class LLMRun(BaseRun):
     """Class for LLMRun."""
 
-    prompts: List[str]
+    prompts: List[PromptValue]
     response: Optional[LLMResult] = None
 
 
@@ -6,13 +6,13 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
 from pydantic import BaseModel
 
+from langchain.base_language_model import BaseLanguageModel
 from langchain.chains.base import Chain
 from langchain.chains.chat_vector_db.prompts import CONDENSE_QUESTION_PROMPT
 from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
 from langchain.chains.llm import LLMChain
 from langchain.chains.question_answering import load_qa_chain
 from langchain.prompts.base import BasePromptTemplate
-from langchain.schema import BaseLanguageModel
 from langchain.vectorstores.base import VectorStore
 
 
@@ -5,11 +5,12 @@ from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
 
 from pydantic import BaseModel, Extra
 
+from langchain.base_language_model import BaseLanguageModel
 from langchain.chains.base import Chain
 from langchain.input import get_colored_text
 from langchain.prompts.base import BasePromptTemplate
 from langchain.prompts.prompt import PromptTemplate
-from langchain.schema import BaseLanguageModel, LLMResult, PromptValue
+from langchain.schema import LLMResult, PromptValue
 
 
 class LLMChain(Chain, BaseModel):
@@ -3,10 +3,10 @@ from typing import Callable, List, Tuple
 
 from pydantic import BaseModel, Field
 
+from langchain.base_language_model import BaseLanguageModel
 from langchain.chat_models.base import BaseChatModel
 from langchain.llms.base import BaseLLM
 from langchain.prompts.base import BasePromptTemplate
-from langchain.schema import BaseLanguageModel
 
 
 class BasePromptSelector(BaseModel, ABC):
@@ -1,6 +1,7 @@
 """Load question answering chains."""
 from typing import Any, Mapping, Optional, Protocol
 
+from langchain.base_language_model import BaseLanguageModel
 from langchain.callbacks.base import BaseCallbackManager
 from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
 from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain
@@ -15,7 +16,6 @@ from langchain.chains.question_answering import (
     stuff_prompt,
 )
 from langchain.prompts.base import BasePromptTemplate
-from langchain.schema import BaseLanguageModel
 
 
 class LoadingCallable(Protocol):
@@ -1,14 +1,12 @@
 from abc import ABC, abstractmethod
 from typing import List, Optional
 
-from pydantic import BaseModel, Extra, Field, validator
+from pydantic import BaseModel
 
 import langchain
-from langchain.callbacks import get_callback_manager
-from langchain.callbacks.base import BaseCallbackManager
+from langchain.base_language_model import BaseLanguageModel
 from langchain.schema import (
     AIMessage,
-    BaseLanguageModel,
     BaseMessage,
     ChatGeneration,
     ChatResult,
@@ -22,25 +20,7 @@ def _get_verbosity() -> bool:
 
 
 class BaseChatModel(BaseLanguageModel, BaseModel, ABC):
-    verbose: bool = Field(default_factory=_get_verbosity)
-    """Whether to print out response text."""
-    callback_manager: BaseCallbackManager = Field(default_factory=get_callback_manager)
-
-    class Config:
-        """Configuration for this pydantic object."""
-
-        extra = Extra.forbid
-        arbitrary_types_allowed = True
-
-    @validator("callback_manager", pre=True, always=True)
-    def set_callback_manager(
-        cls, callback_manager: Optional[BaseCallbackManager]
-    ) -> BaseCallbackManager:
-        """If callback manager is None, set it.
-
-        This allows users to pass in None as callback manager, which is a nice UX.
-        """
-        return callback_manager or get_callback_manager()
+    """Base class for chat models."""
 
     def generate(
         self, messages: List[List[BaseMessage]], stop: Optional[List[str]] = None
@@ -56,13 +36,13 @@ class BaseChatModel(BaseLanguageModel, BaseModel, ABC):
         results = [await self._agenerate(m, stop=stop) for m in messages]
         return LLMResult(generations=[res.generations for res in results])
 
-    def generate_prompt(
+    def _generate_prompt(
         self, prompts: List[PromptValue], stop: Optional[List[str]] = None
     ) -> LLMResult:
         prompt_messages = [p.to_messages() for p in prompts]
         return self.generate(prompt_messages, stop=stop)
 
-    async def agenerate_prompt(
+    async def _agenerate_prompt(
         self, prompts: List[PromptValue], stop: Optional[List[str]] = None
     ) -> LLMResult:
         prompt_messages = [p.to_messages() for p in prompts]
@@ -5,16 +5,11 @@ from pathlib import Path
 from typing import Any, Dict, List, Mapping, Optional, Tuple, Union
 
 import yaml
-from pydantic import BaseModel, Extra, Field, validator
+from pydantic import BaseModel, Extra
 
 import langchain
-from langchain.callbacks import get_callback_manager
-from langchain.callbacks.base import BaseCallbackManager
-from langchain.schema import BaseLanguageModel, Generation, LLMResult, PromptValue
+from langchain.base_language_model import BaseLanguageModel
+from langchain.schema import Generation, LLMResult, PromptValue
 
 
-def _get_verbosity() -> bool:
-    return langchain.verbose
-
-
 def get_prompts(
@@ -57,9 +52,6 @@ class BaseLLM(BaseLanguageModel, BaseModel, ABC):
     """LLM wrapper should take in a prompt and return a string."""
 
     cache: Optional[bool] = None
-    verbose: bool = Field(default_factory=_get_verbosity)
-    """Whether to print out response text."""
-    callback_manager: BaseCallbackManager = Field(default_factory=get_callback_manager)
 
     class Config:
         """Configuration for this pydantic object."""
@@ -67,27 +59,6 @@ class BaseLLM(BaseLanguageModel, BaseModel, ABC):
         extra = Extra.forbid
         arbitrary_types_allowed = True
 
-    @validator("callback_manager", pre=True, always=True)
-    def set_callback_manager(
-        cls, callback_manager: Optional[BaseCallbackManager]
-    ) -> BaseCallbackManager:
-        """If callback manager is None, set it.
-
-        This allows users to pass in None as callback manager, which is a nice UX.
-        """
-        return callback_manager or get_callback_manager()
-
-    @validator("verbose", pre=True, always=True)
-    def set_verbose(cls, verbose: Optional[bool]) -> bool:
-        """If verbose is None, set it.
-
-        This allows users to pass in None as verbose to access the global setting.
-        """
-        if verbose is None:
-            return _get_verbosity()
-        else:
-            return verbose
-
     @abstractmethod
     def _generate(
         self, prompts: List[str], stop: Optional[List[str]] = None
@@ -100,13 +71,13 @@ class BaseLLM(BaseLanguageModel, BaseModel, ABC):
     ) -> LLMResult:
         """Run the LLM on the given prompts."""
 
-    def generate_prompt(
+    def _generate_prompt(
        self, prompts: List[PromptValue], stop: Optional[List[str]] = None
     ) -> LLMResult:
         prompt_strings = [p.to_string() for p in prompts]
         return self.generate(prompt_strings, stop=stop)
 
-    async def agenerate_prompt(
+    async def _agenerate_prompt(
         self, prompts: List[PromptValue], stop: Optional[List[str]] = None
     ) -> LLMResult:
         prompt_strings = [p.to_string() for p in prompts]
@@ -138,7 +109,6 @@ class BaseLLM(BaseLanguageModel, BaseModel, ABC):
         except (KeyboardInterrupt, Exception) as e:
             self.callback_manager.on_llm_error(e, verbose=self.verbose)
             raise e
-        self.callback_manager.on_llm_end(output, verbose=self.verbose)
         return output
         params = self.dict()
         params["stop"] = stop
@@ -157,7 +127,6 @@ class BaseLLM(BaseLanguageModel, BaseModel, ABC):
         except (KeyboardInterrupt, Exception) as e:
             self.callback_manager.on_llm_error(e, verbose=self.verbose)
             raise e
-        self.callback_manager.on_llm_end(new_results, verbose=self.verbose)
         llm_output = update_cache(
             existing_prompts, llm_string, missing_prompt_idxs, new_results, prompts
         )
@@ -2,6 +2,7 @@ from typing import Any, Dict, List, Optional
 
 from pydantic import BaseModel
 
+from langchain.base_language_model import BaseLanguageModel
 from langchain.chains.llm import LLMChain
 from langchain.memory.chat_memory import BaseChatMemory
 from langchain.memory.prompt import (
@@ -10,7 +11,7 @@ from langchain.memory.prompt import (
 )
 from langchain.memory.utils import get_buffer_string, get_prompt_input_key
 from langchain.prompts.base import BasePromptTemplate
-from langchain.schema import BaseLanguageModel, BaseMessage
+from langchain.schema import BaseMessage
 
 
 class ConversationEntityMemory(BaseChatMemory, BaseModel):
@@ -2,6 +2,7 @@ from typing import Any, Dict, List
 
 from pydantic import BaseModel, Field
 
+from langchain.base_language_model import BaseLanguageModel
 from langchain.chains.llm import LLMChain
 from langchain.graphs import NetworkxEntityGraph
 from langchain.graphs.networkx_graph import KnowledgeTriple, get_entities, parse_triples
@@ -12,7 +13,7 @@ from langchain.memory.prompt import (
 )
 from langchain.memory.utils import get_buffer_string, get_prompt_input_key
 from langchain.prompts.base import BasePromptTemplate
-from langchain.schema import BaseLanguageModel, SystemMessage
+from langchain.schema import SystemMessage
 
 
 class ConversationKGMemory(BaseChatMemory, BaseModel):
@@ -2,12 +2,13 @@ from typing import Any, Dict, List
 
 from pydantic import BaseModel, root_validator
 
+from langchain.base_language_model import BaseLanguageModel
 from langchain.chains.llm import LLMChain
 from langchain.memory.chat_memory import BaseChatMemory
 from langchain.memory.prompt import SUMMARY_PROMPT
 from langchain.memory.utils import get_buffer_string
 from langchain.prompts.base import BasePromptTemplate
-from langchain.schema import BaseLanguageModel, BaseMessage, SystemMessage
+from langchain.schema import BaseMessage, SystemMessage
 
 
 class SummarizerMixin(BaseModel):
@@ -44,19 +44,26 @@ class BaseMessage(BaseModel):
 class HumanMessage(BaseMessage):
     """Type of message that is spoken by the human."""
 
+    _type = "human"
+
 
 class AIMessage(BaseMessage):
     """Type of message that is spoken by the AI."""
 
+    _type = "ai"
+
 
 class SystemMessage(BaseMessage):
     """Type of message that is a system message."""
 
+    _type = "system"
+
 
 class ChatMessage(BaseMessage):
     """Type of message with arbitrary speaker."""
 
     role: str
+    _type = "chat"
 
 
 class ChatGeneration(Generation):
@@ -100,41 +107,6 @@ class PromptValue(BaseModel, ABC):
         """Return prompt as messages."""
 
 
-class BaseLanguageModel(BaseModel, ABC):
-    @abstractmethod
-    def generate_prompt(
-        self, prompts: List[PromptValue], stop: Optional[List[str]] = None
-    ) -> LLMResult:
-        """Take in a list of prompt values and return an LLMResult."""
-
-    @abstractmethod
-    async def agenerate_prompt(
-        self, prompts: List[PromptValue], stop: Optional[List[str]] = None
-    ) -> LLMResult:
-        """Take in a list of prompt values and return an LLMResult."""
-
-    def get_num_tokens(self, text: str) -> int:
-        """Get the number of tokens present in the text."""
-        # TODO: this method may not be exact.
-        # TODO: this method may differ based on model (eg codex).
-        try:
-            from transformers import GPT2TokenizerFast
-        except ImportError:
-            raise ValueError(
-                "Could not import transformers python package. "
-                "This is needed in order to calculate get_num_tokens. "
-                "Please it install it with `pip install transformers`."
-            )
-        # create a GPT-3 tokenizer instance
-        tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
-
-        # tokenize the text using the GPT-3 tokenizer
-        tokenized_text = tokenizer.tokenize(text)
-
-        # calculate the number of tokens in the tokenized text
-        return len(tokenized_text)
-
-
 class BaseMemory(BaseModel, ABC):
     """Base interface for memory in chains."""
 
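A small usage sketch of the message classes after this hunk (illustrative only; the driver code below is not from the repository, and the printed list reflects the `_type` values added above):

```python
from langchain.schema import AIMessage, ChatMessage, HumanMessage, SystemMessage

msgs = [
    SystemMessage(content="You are a pirate."),
    HumanMessage(content="How many people live in Canada?"),
    AIMessage(content="Arrr, a great many, matey."),
    ChatMessage(role="narrator", content="The crowd goes wild."),
]
# Each subclass now carries the discriminator introduced in this diff.
print([m._type for m in msgs])  # expected: ['system', 'human', 'ai', 'chat']
```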
@@ -85,10 +85,10 @@ def _perform_nested_run(tracer: BaseTracer) -> None:
     """Perform a nested run."""
     tracer.on_chain_start(serialized={}, inputs={})
     tracer.on_tool_start(serialized={}, input_str="test")
-    tracer.on_llm_start(serialized={}, prompts=[])
+    tracer.on_llm_start_prompt_value(serialized={}, prompts=[])
     tracer.on_llm_end(response=LLMResult(generations=[[]]))
     tracer.on_tool_end("test")
-    tracer.on_llm_start(serialized={}, prompts=[])
+    tracer.on_llm_start_prompt_value(serialized={}, prompts=[])
     tracer.on_llm_end(response=LLMResult(generations=[[]]))
     tracer.on_chain_end(outputs={})
 
@@ -216,7 +216,7 @@ def test_tracer_llm_run() -> None:
     tracer = FakeTracer()
 
     tracer.new_session()
-    tracer.on_llm_start(serialized={}, prompts=[])
+    tracer.on_llm_start_prompt_value(serialized={}, prompts=[])
     tracer.on_llm_end(response=LLMResult(generations=[[]]))
     assert tracer.runs == [compare_run]
 
@@ -227,7 +227,7 @@ def test_tracer_llm_run_errors_no_session() -> None:
     tracer = FakeTracer()
 
     with pytest.raises(TracerException):
-        tracer.on_llm_start(serialized={}, prompts=[])
+        tracer.on_llm_start_prompt_value(serialized={}, prompts=[])
 
 
 @freeze_time("2023-01-01")
@@ -260,7 +260,7 @@ def test_tracer_multiple_llm_runs() -> None:
     tracer.new_session()
     num_runs = 10
     for _ in range(num_runs):
-        tracer.on_llm_start(serialized={}, prompts=[])
+        tracer.on_llm_start_prompt_value(serialized={}, prompts=[])
         tracer.on_llm_end(response=LLMResult(generations=[[]]))
 
     assert tracer.runs == [compare_run] * num_runs
@@ -342,7 +342,7 @@ def test_tracer_llm_run_on_error() -> None:
     tracer = FakeTracer()
 
     tracer.new_session()
-    tracer.on_llm_start(serialized={}, prompts=[])
+    tracer.on_llm_start_prompt_value(serialized={}, prompts=[])
     tracer.on_llm_error(exception)
     assert tracer.runs == [compare_run]
 
@@ -408,12 +408,12 @@ def test_tracer_nested_runs_on_error() -> None:
 
     for _ in range(3):
         tracer.on_chain_start(serialized={}, inputs={})
-        tracer.on_llm_start(serialized={}, prompts=[])
+        tracer.on_llm_start_prompt_value(serialized={}, prompts=[])
         tracer.on_llm_end(response=LLMResult(generations=[[]]))
-        tracer.on_llm_start(serialized={}, prompts=[])
+        tracer.on_llm_start_prompt_value(serialized={}, prompts=[])
         tracer.on_llm_end(response=LLMResult(generations=[[]]))
         tracer.on_tool_start(serialized={}, input_str="test")
-        tracer.on_llm_start(serialized={}, prompts=[])
+        tracer.on_llm_start_prompt_value(serialized={}, prompts=[])
         tracer.on_llm_error(exception)
         tracer.on_tool_error(exception)
         tracer.on_chain_error(exception)