Compare commits

...

8 Commits

Author       SHA1        Message                            Date
Ankush Gola  fe353c610f  more lint                          2023-03-09 21:41:36 -08:00
Ankush Gola  f897e7102b  lint                               2023-03-09 21:40:44 -08:00
Ankush Gola  366de1bd58  add type                           2023-03-09 21:39:52 -08:00
Ankush Gola  4553f64b4b  get everything working             2023-03-09 21:29:29 -08:00
Ankush Gola  3c7a559e77  fix tests                          2023-03-09 20:44:50 -08:00
Ankush Gola  7d465cbc2f  refactor                           2023-03-09 20:41:50 -08:00
Ankush Gola  ecdfbfe1c7  cr                                 2023-03-09 20:18:05 -08:00
Ankush Gola  4d366eeea3  add on_llm_new_token_prompt_value  2023-03-09 19:27:42 -08:00
21 changed files with 317 additions and 154 deletions

View File

@@ -11,11 +11,20 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 3, "execution_count": 1,
"id": "5268c7fa", "id": "5268c7fa",
"metadata": {}, "metadata": {
"tags": []
},
"outputs": [], "outputs": [],
"source": [ "source": [
"import os\n",
"os.environ[\"LANGCHAIN_HANDLER\"] = \"langchain\"\n",
"\n",
"## Uncomment this if using hosted setup.\n",
"\n",
"os.environ[\"LANGCHAIN_ENDPOINT\"] = \"http://127.0.0.1:8000\" \n",
"\n",
"from langchain.agents import ZeroShotAgent, Tool, AgentExecutor\n", "from langchain.agents import ZeroShotAgent, Tool, AgentExecutor\n",
"from langchain.chains import LLMChain\n", "from langchain.chains import LLMChain\n",
"from langchain.utilities import SerpAPIWrapper" "from langchain.utilities import SerpAPIWrapper"
@@ -23,9 +32,11 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 4, "execution_count": 2,
"id": "fbaa4dbe", "id": "fbaa4dbe",
"metadata": {}, "metadata": {
"tags": []
},
"outputs": [], "outputs": [],
"source": [ "source": [
"search = SerpAPIWrapper()\n", "search = SerpAPIWrapper()\n",
@@ -40,9 +51,11 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 6, "execution_count": 3,
"id": "f3ba6f08", "id": "f3ba6f08",
"metadata": {}, "metadata": {
"tags": []
},
"outputs": [], "outputs": [],
"source": [ "source": [
"prefix = \"\"\"Answer the following questions as best you can, but speaking as a pirate might speak. You have access to the following tools:\"\"\"\n", "prefix = \"\"\"Answer the following questions as best you can, but speaking as a pirate might speak. You have access to the following tools:\"\"\"\n",
@@ -58,9 +71,11 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 7, "execution_count": 4,
"id": "3547a37d", "id": "3547a37d",
"metadata": {}, "metadata": {
"tags": []
},
"outputs": [], "outputs": [],
"source": [ "source": [
"from langchain.chat_models import ChatOpenAI\n", "from langchain.chat_models import ChatOpenAI\n",
@@ -79,9 +94,11 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 8, "execution_count": 5,
"id": "a78f886f", "id": "a78f886f",
"metadata": {}, "metadata": {
"tags": []
},
"outputs": [], "outputs": [],
"source": [ "source": [
"messages = [\n", "messages = [\n",
@@ -94,9 +111,11 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 9, "execution_count": 6,
"id": "dadadd70", "id": "dadadd70",
"metadata": {}, "metadata": {
"tags": []
},
"outputs": [], "outputs": [],
"source": [ "source": [
"prompt = ChatPromptTemplate.from_messages(messages)" "prompt = ChatPromptTemplate.from_messages(messages)"
@@ -104,9 +123,11 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 10, "execution_count": 7,
"id": "b7180182", "id": "b7180182",
"metadata": {}, "metadata": {
"tags": []
},
"outputs": [], "outputs": [],
"source": [ "source": [
"llm_chain = LLMChain(llm=ChatOpenAI(temperature=0), prompt=prompt)" "llm_chain = LLMChain(llm=ChatOpenAI(temperature=0), prompt=prompt)"
@@ -114,9 +135,11 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 11, "execution_count": 8,
"id": "ddddb07b", "id": "ddddb07b",
"metadata": {}, "metadata": {
"tags": []
},
"outputs": [], "outputs": [],
"source": [ "source": [
"tool_names = [tool.name for tool in tools]\n", "tool_names = [tool.name for tool in tools]\n",
@@ -125,9 +148,11 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 12, "execution_count": 9,
"id": "36aef054", "id": "36aef054",
"metadata": {}, "metadata": {
"tags": []
},
"outputs": [], "outputs": [],
"source": [ "source": [
"agent_executor = AgentExecutor.from_agent_and_tools(agent=agent, tools=tools, verbose=True)" "agent_executor = AgentExecutor.from_agent_and_tools(agent=agent, tools=tools, verbose=True)"
@@ -135,9 +160,11 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 13, "execution_count": 10,
"id": "33a4d6cc", "id": "33a4d6cc",
"metadata": {}, "metadata": {
"tags": []
},
"outputs": [ "outputs": [
{ {
"name": "stdout", "name": "stdout",
@@ -146,16 +173,16 @@
"\n", "\n",
"\n", "\n",
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n", "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
"\u001b[32;1m\u001b[1;3mArrr, ye be in luck, matey! I'll find ye the answer to yer question.\n", "\u001b[32;1m\u001b[1;3mArrr, ye be in luck, matey! I'll help ye find the answer ye seek.\n",
"\n", "\n",
"Thought: I need to search for the current population of Canada.\n", "Thought: Hmm, I don't have this information off the top of me head.\n",
"Action: Search\n", "Action: Search\n",
"Action Input: \"current population of Canada 2023\"\n", "Action Input: \"Canada population 2023\"\n",
"\u001b[0m\n", "\u001b[0m\n",
"Observation: \u001b[36;1m\u001b[1;3mThe current population of Canada is 38,623,091 as of Saturday, March 4, 2023, based on Worldometer elaboration of the latest United Nations data.\u001b[0m\n", "Observation: \u001b[36;1m\u001b[1;3mThe current population of Canada is 38,627,607 as of Thursday, March 9, 2023, based on Worldometer elaboration of the latest United Nations data.\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3mAhoy, me hearties! I've found the answer to yer question.\n", "Thought:\u001b[32;1m\u001b[1;3mAhoy, me hearties! I have found the answer to yer question.\n",
"\n", "\n",
"Final Answer: As of March 4, 2023, the population of Canada be 38,623,091. Arrr!\u001b[0m\n", "Final Answer: As of March 9, 2023, the population of Canada is 38,627,607. Arrr!\u001b[0m\n",
"\n", "\n",
"\u001b[1m> Finished chain.\u001b[0m\n" "\u001b[1m> Finished chain.\u001b[0m\n"
] ]
@@ -163,10 +190,10 @@
{ {
"data": { "data": {
"text/plain": [ "text/plain": [
"'As of March 4, 2023, the population of Canada be 38,623,091. Arrr!'" "'As of March 9, 2023, the population of Canada is 38,627,607. Arrr!'"
] ]
}, },
"execution_count": 13, "execution_count": 10,
"metadata": {}, "metadata": {},
"output_type": "execute_result" "output_type": "execute_result"
} }
@@ -200,7 +227,7 @@
"name": "python", "name": "python",
"nbconvert_exporter": "python", "nbconvert_exporter": "python",
"pygments_lexer": "ipython3", "pygments_lexer": "ipython3",
"version": "3.9.1" "version": "3.10.9"
} }
}, },
"nbformat": 4, "nbformat": 4,

View File

@@ -372,7 +372,7 @@
"name": "python", "name": "python",
"nbconvert_exporter": "python", "nbconvert_exporter": "python",
"pygments_lexer": "ipython3", "pygments_lexer": "ipython3",
"version": "3.9.1" "version": "3.10.9"
} }
}, },
"nbformat": 4, "nbformat": 4,

View File

@@ -151,7 +151,7 @@
"name": "python", "name": "python",
"nbconvert_exporter": "python", "nbconvert_exporter": "python",
"pygments_lexer": "ipython3", "pygments_lexer": "ipython3",
"version": "3.9.1" "version": "3.10.9"
} }
}, },
"nbformat": 4, "nbformat": 4,

View File

@@ -22,7 +22,7 @@
"\n", "\n",
"## Uncomment this if using hosted setup.\n", "## Uncomment this if using hosted setup.\n",
"\n", "\n",
"# os.environ[\"LANGCHAIN_ENDPOINT\"] = \"https://langchain-api-gateway-57eoxz8z.uc.gateway.dev\" \n", "os.environ[\"LANGCHAIN_ENDPOINT\"] = \"http://127.0.0.1:8000\" \n",
"\n", "\n",
"## Uncomment this if you want traces to be recorded to \"my_session\" instead of default.\n", "## Uncomment this if you want traces to be recorded to \"my_session\" instead of default.\n",
"\n", "\n",
@@ -89,9 +89,30 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": 4,
"id": "25addd7f", "id": "25addd7f",
"metadata": {}, "metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'\\n\\nQ: What did the fish say when it hit the wall?\\nA: Dam!'"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"llm(\"tell me a joke\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7afdd321-796f-4dcc-816c-c1160aa943af",
"metadata": {},
"outputs": [], "outputs": [],
"source": [] "source": []
} }
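
Note: taken together, the edits to this notebook amount to the setup below, collected here as a single sketch. The local endpoint value is the one introduced in this diff; the OpenAI wrapper is a stand-in for whatever the earlier (unchanged) cells construct, and an OPENAI_API_KEY in the environment is assumed:

import os

# Enable the langchain tracing handler and point it at a locally running
# tracing server, mirroring the environment setup in these notebooks.
# Use the hosted endpoint instead if you are on a hosted setup.
os.environ["LANGCHAIN_HANDLER"] = "langchain"
os.environ["LANGCHAIN_ENDPOINT"] = "http://127.0.0.1:8000"

from langchain.llms import OpenAI

llm = OpenAI()          # stand-in for the LLM built earlier in the notebook
llm("tell me a joke")   # same call as the newly added cell above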

View File

@@ -3,6 +3,7 @@ from typing import Any, List, Optional, Sequence, Tuple
from langchain.agents.agent import Agent from langchain.agents.agent import Agent
from langchain.agents.chat.prompt import FORMAT_INSTRUCTIONS, PREFIX, SUFFIX from langchain.agents.chat.prompt import FORMAT_INSTRUCTIONS, PREFIX, SUFFIX
from langchain.base_language_model import BaseLanguageModel
from langchain.callbacks.base import BaseCallbackManager from langchain.callbacks.base import BaseCallbackManager
from langchain.chains.llm import LLMChain from langchain.chains.llm import LLMChain
from langchain.prompts.base import BasePromptTemplate from langchain.prompts.base import BasePromptTemplate
@@ -11,7 +12,7 @@ from langchain.prompts.chat import (
HumanMessagePromptTemplate, HumanMessagePromptTemplate,
SystemMessagePromptTemplate, SystemMessagePromptTemplate,
) )
from langchain.schema import AgentAction, BaseLanguageModel from langchain.schema import AgentAction
from langchain.tools import BaseTool from langchain.tools import BaseTool
FINAL_ANSWER_ACTION = "Final Answer:" FINAL_ANSWER_ACTION = "Final Answer:"

View File

@@ -0,0 +1,124 @@
"""Base class for language models."""
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import List, Optional
from pydantic import BaseModel, Extra, Field, validator
import langchain
from langchain.callbacks import BaseCallbackManager, get_callback_manager
from langchain.schema import LLMResult, PromptValue
def _get_verbosity() -> bool:
return langchain.verbose
class BaseLanguageModel(BaseModel, ABC):
"""Base class for language models."""
verbose: bool = Field(default_factory=_get_verbosity)
"""Whether to print out response text."""
callback_manager: BaseCallbackManager = Field(default_factory=get_callback_manager)
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
arbitrary_types_allowed = True
@validator("callback_manager", pre=True, always=True)
def set_callback_manager(
cls, callback_manager: Optional[BaseCallbackManager]
) -> BaseCallbackManager:
"""If callback manager is None, set it.
This allows users to pass in None as callback manager, which is a nice UX.
"""
return callback_manager or get_callback_manager()
@validator("verbose", pre=True, always=True)
def set_verbose(cls, verbose: Optional[bool]) -> bool:
"""If verbose is None, set it.
This allows users to pass in None as verbose to access the global setting.
"""
if verbose is None:
return _get_verbosity()
else:
return verbose
def generate_prompt(
self, prompts: List[PromptValue], stop: Optional[List[str]] = None
) -> LLMResult:
"""Take in a list of prompt values and return an LLMResult."""
self.callback_manager.on_llm_start_prompt_value(
{"name": self.__class__.__name__}, prompts, verbose=self.verbose
)
try:
output = self._generate_prompt(prompts, stop=stop)
except (KeyboardInterrupt, Exception) as e:
self.callback_manager.on_llm_error(e, verbose=self.verbose)
raise e
self.callback_manager.on_llm_end(output, verbose=self.verbose)
return output
async def agenerate_prompt(
self, prompts: List[PromptValue], stop: Optional[List[str]] = None
) -> LLMResult:
"""Take in a list of prompt values and return an LLMResult."""
if self.callback_manager.is_async:
await self.callback_manager.on_llm_start_prompt_value(
{"name": self.__class__.__name__}, prompts, verbose=self.verbose
)
else:
self.callback_manager.on_llm_start_prompt_value(
{"name": self.__class__.__name__}, prompts, verbose=self.verbose
)
try:
output = await self._agenerate_prompt(prompts, stop=stop)
except (KeyboardInterrupt, Exception) as e:
if self.callback_manager.is_async:
await self.callback_manager.on_llm_error(e, verbose=self.verbose)
else:
self.callback_manager.on_llm_error(e, verbose=self.verbose)
raise e
if self.callback_manager.is_async:
await self.callback_manager.on_llm_end(output, verbose=self.verbose)
else:
self.callback_manager.on_llm_end(output, verbose=self.verbose)
return output
@abstractmethod
def _generate_prompt(
self, prompts: List[PromptValue], stop: Optional[List[str]] = None
) -> LLMResult:
"""Take in a list of prompt values and return an LLMResult."""
@abstractmethod
async def _agenerate_prompt(
self, prompts: List[PromptValue], stop: Optional[List[str]] = None
) -> LLMResult:
"""Take in a list of prompt values and return an LLMResult."""
def get_num_tokens(self, text: str) -> int:
"""Get the number of tokens present in the text."""
# TODO: this method may not be exact.
# TODO: this method may differ based on model (eg codex).
try:
from transformers import GPT2TokenizerFast
except ImportError:
raise ValueError(
"Could not import transformers python package. "
"This is needed in order to calculate get_num_tokens. "
"Please it install it with `pip install transformers`."
)
# create a GPT-3 tokenizer instance
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
# tokenize the text using the GPT-3 tokenizer
tokenized_text = tokenizer.tokenize(text)
# calculate the number of tokens in the tokenized text
return len(tokenized_text)
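
For orientation, here is a minimal sketch of a concrete subclass of the new BaseLanguageModel above. The class name and echo behaviour are invented for illustration; the point is that subclasses implement only the private _generate_prompt/_agenerate_prompt hooks, while the inherited generate_prompt/agenerate_prompt wrappers handle the on_llm_start_prompt_value, on_llm_end, and on_llm_error callbacks:

from typing import List, Optional

from langchain.base_language_model import BaseLanguageModel
from langchain.schema import Generation, LLMResult, PromptValue


class EchoLanguageModel(BaseLanguageModel):
    """Toy model that echoes each prompt back; illustration only."""

    def _generate_prompt(
        self, prompts: List[PromptValue], stop: Optional[List[str]] = None
    ) -> LLMResult:
        # One Generation per PromptValue, using its string rendering.
        return LLMResult(
            generations=[[Generation(text=p.to_string())] for p in prompts]
        )

    async def _agenerate_prompt(
        self, prompts: List[PromptValue], stop: Optional[List[str]] = None
    ) -> LLMResult:
        # Reuse the sync path; a real model would await its client here.
        return self._generate_prompt(prompts, stop=stop)

Calling EchoLanguageModel().generate_prompt([some_prompt_value]) would then fire the prompt-value callbacks before returning the echoed LLMResult.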

View File

@@ -4,7 +4,7 @@ import functools
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any, Dict, List, Union from typing import Any, Dict, List, Union
from langchain.schema import AgentAction, AgentFinish, LLMResult from langchain.schema import AgentAction, AgentFinish, LLMResult, PromptValue
class BaseCallbackHandler(ABC): class BaseCallbackHandler(ABC):
@@ -30,6 +30,12 @@ class BaseCallbackHandler(ABC):
"""Whether to ignore agent callbacks.""" """Whether to ignore agent callbacks."""
return False return False
def on_llm_start_prompt_value(
self, serialized: Dict[str, Any], prompts: List[PromptValue], **kwargs: Any
) -> Any:
"""Run when LLM starts running."""
pass
@abstractmethod @abstractmethod
def on_llm_start( def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
@@ -127,6 +133,19 @@ class CallbackManager(BaseCallbackManager):
"""Initialize callback manager.""" """Initialize callback manager."""
self.handlers: List[BaseCallbackHandler] = handlers self.handlers: List[BaseCallbackHandler] = handlers
def on_llm_start_prompt_value(
self,
serialized: Dict[str, Any],
prompts: List[PromptValue],
verbose: bool = False,
**kwargs: Any
) -> Any:
"""Run when LLM starts running."""
for handler in self.handlers:
if not handler.ignore_llm:
if verbose or handler.always_verbose:
handler.on_llm_start_prompt_value(serialized, prompts, **kwargs)
def on_llm_start( def on_llm_start(
self, self,
serialized: Dict[str, Any], serialized: Dict[str, Any],
@@ -276,6 +295,11 @@ class CallbackManager(BaseCallbackManager):
class AsyncCallbackHandler(BaseCallbackHandler): class AsyncCallbackHandler(BaseCallbackHandler):
"""Async callback handler that can be used to handle callbacks from langchain.""" """Async callback handler that can be used to handle callbacks from langchain."""
async def on_llm_start_prompt_value(
self, serialized: Dict[str, Any], prompts: List[PromptValue], **kwargs: Any
) -> Any:
"""Run when LLM starts running."""
async def on_llm_start( async def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None: ) -> None:
@@ -340,6 +364,32 @@ class AsyncCallbackManager(BaseCallbackManager):
"""Initialize callback manager.""" """Initialize callback manager."""
self.handlers: List[BaseCallbackHandler] = handlers self.handlers: List[BaseCallbackHandler] = handlers
async def on_llm_start_prompt_value(
self,
serialized: Dict[str, Any],
prompts: List[PromptValue],
verbose: bool = False,
**kwargs: Any
) -> Any:
"""Run when LLM starts running."""
for handler in self.handlers:
if not handler.ignore_llm:
if verbose or handler.always_verbose:
if asyncio.iscoroutinefunction(handler.on_llm_start_prompt_value):
return await handler.on_llm_start_prompt_value(
serialized, prompts, **kwargs
)
else:
return await asyncio.get_event_loop().run_in_executor(
None,
functools.partial(
handler.on_llm_start_prompt_value,
serialized,
prompts,
**kwargs
),
)
async def on_llm_start( async def on_llm_start(
self, self,
serialized: Dict[str, Any], serialized: Dict[str, Any],
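
A quick way to exercise the new hook directly (the CallbackManager(handlers=[]) constructor usage mirrors the shared-manager default in the next file); the serialized name and empty prompt list are placeholders:

from langchain.callbacks.base import CallbackManager

manager = CallbackManager(handlers=[])  # no handlers registered, so nothing fires
# Per the dispatch loop above, the event only reaches handlers whose ignore_llm
# is False and that are verbose (or always_verbose).
manager.on_llm_start_prompt_value(
    serialized={"name": "FakeLLM"}, prompts=[], verbose=True
)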

View File

@@ -8,7 +8,7 @@ from langchain.callbacks.base import (
BaseCallbackManager, BaseCallbackManager,
CallbackManager, CallbackManager,
) )
from langchain.schema import AgentAction, AgentFinish, LLMResult from langchain.schema import AgentAction, AgentFinish, LLMResult, PromptValue
class Singleton: class Singleton:
@@ -34,6 +34,15 @@ class SharedCallbackManager(Singleton, BaseCallbackManager):
_callback_manager: CallbackManager = CallbackManager(handlers=[]) _callback_manager: CallbackManager = CallbackManager(handlers=[])
def on_llm_start_prompt_value(
self, serialized: Dict[str, Any], prompts: List[PromptValue], **kwargs: Any
) -> Any:
"""Run when LLM starts running."""
with self._lock:
self._callback_manager.on_llm_start_prompt_value(
serialized, prompts, **kwargs
)
def on_llm_start( def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None: ) -> None:

View File

@@ -16,7 +16,7 @@ from langchain.callbacks.tracers.schemas import (
TracerSession, TracerSession,
TracerSessionCreate, TracerSessionCreate,
) )
from langchain.schema import AgentAction, AgentFinish, LLMResult from langchain.schema import AgentAction, AgentFinish, LLMResult, PromptValue
class TracerException(Exception): class TracerException(Exception):
@@ -109,9 +109,9 @@ class BaseTracer(BaseCallbackHandler, ABC):
self._execution_order = 1 self._execution_order = 1
self._persist_run(run) self._persist_run(run)
def on_llm_start( def on_llm_start_prompt_value(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any self, serialized: Dict[str, Any], prompts: List[PromptValue], **kwargs: Any
) -> None: ) -> Any:
"""Start a trace for an LLM run.""" """Start a trace for an LLM run."""
if self._session is None: if self._session is None:
raise TracerException( raise TracerException(
@@ -129,6 +129,12 @@ class BaseTracer(BaseCallbackHandler, ABC):
) )
self._start_trace(llm_run) self._start_trace(llm_run)
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:
"""Start a trace for an LLM run."""
pass
def on_llm_new_token(self, token: str, **kwargs: Any) -> None: def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
"""Handle a new token for an LLM run.""" """Handle a new token for an LLM run."""
pass pass

View File

@@ -6,7 +6,7 @@ from typing import Any, Dict, List, Optional, Union
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from langchain.schema import LLMResult from langchain.schema import LLMResult, PromptValue
class TracerSessionBase(BaseModel): class TracerSessionBase(BaseModel):
@@ -45,7 +45,7 @@ class BaseRun(BaseModel):
class LLMRun(BaseRun): class LLMRun(BaseRun):
"""Class for LLMRun.""" """Class for LLMRun."""
prompts: List[str] prompts: List[PromptValue]
response: Optional[LLMResult] = None response: Optional[LLMResult] = None

View File

@@ -6,13 +6,13 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from pydantic import BaseModel from pydantic import BaseModel
from langchain.base_language_model import BaseLanguageModel
from langchain.chains.base import Chain from langchain.chains.base import Chain
from langchain.chains.chat_vector_db.prompts import CONDENSE_QUESTION_PROMPT from langchain.chains.chat_vector_db.prompts import CONDENSE_QUESTION_PROMPT
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
from langchain.chains.llm import LLMChain from langchain.chains.llm import LLMChain
from langchain.chains.question_answering import load_qa_chain from langchain.chains.question_answering import load_qa_chain
from langchain.prompts.base import BasePromptTemplate from langchain.prompts.base import BasePromptTemplate
from langchain.schema import BaseLanguageModel
from langchain.vectorstores.base import VectorStore from langchain.vectorstores.base import VectorStore

View File

@@ -5,11 +5,12 @@ from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
from pydantic import BaseModel, Extra from pydantic import BaseModel, Extra
from langchain.base_language_model import BaseLanguageModel
from langchain.chains.base import Chain from langchain.chains.base import Chain
from langchain.input import get_colored_text from langchain.input import get_colored_text
from langchain.prompts.base import BasePromptTemplate from langchain.prompts.base import BasePromptTemplate
from langchain.prompts.prompt import PromptTemplate from langchain.prompts.prompt import PromptTemplate
from langchain.schema import BaseLanguageModel, LLMResult, PromptValue from langchain.schema import LLMResult, PromptValue
class LLMChain(Chain, BaseModel): class LLMChain(Chain, BaseModel):

View File

@@ -3,10 +3,10 @@ from typing import Callable, List, Tuple
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from langchain.base_language_model import BaseLanguageModel
from langchain.chat_models.base import BaseChatModel from langchain.chat_models.base import BaseChatModel
from langchain.llms.base import BaseLLM from langchain.llms.base import BaseLLM
from langchain.prompts.base import BasePromptTemplate from langchain.prompts.base import BasePromptTemplate
from langchain.schema import BaseLanguageModel
class BasePromptSelector(BaseModel, ABC): class BasePromptSelector(BaseModel, ABC):

View File

@@ -1,6 +1,7 @@
"""Load question answering chains.""" """Load question answering chains."""
from typing import Any, Mapping, Optional, Protocol from typing import Any, Mapping, Optional, Protocol
from langchain.base_language_model import BaseLanguageModel
from langchain.callbacks.base import BaseCallbackManager from langchain.callbacks.base import BaseCallbackManager
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain
@@ -15,7 +16,6 @@ from langchain.chains.question_answering import (
stuff_prompt, stuff_prompt,
) )
from langchain.prompts.base import BasePromptTemplate from langchain.prompts.base import BasePromptTemplate
from langchain.schema import BaseLanguageModel
class LoadingCallable(Protocol): class LoadingCallable(Protocol):

View File

@@ -1,14 +1,12 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import List, Optional from typing import List, Optional
from pydantic import BaseModel, Extra, Field, validator from pydantic import BaseModel
import langchain import langchain
from langchain.callbacks import get_callback_manager from langchain.base_language_model import BaseLanguageModel
from langchain.callbacks.base import BaseCallbackManager
from langchain.schema import ( from langchain.schema import (
AIMessage, AIMessage,
BaseLanguageModel,
BaseMessage, BaseMessage,
ChatGeneration, ChatGeneration,
ChatResult, ChatResult,
@@ -22,25 +20,7 @@ def _get_verbosity() -> bool:
class BaseChatModel(BaseLanguageModel, BaseModel, ABC): class BaseChatModel(BaseLanguageModel, BaseModel, ABC):
verbose: bool = Field(default_factory=_get_verbosity) """Base class for chat models."""
"""Whether to print out response text."""
callback_manager: BaseCallbackManager = Field(default_factory=get_callback_manager)
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
arbitrary_types_allowed = True
@validator("callback_manager", pre=True, always=True)
def set_callback_manager(
cls, callback_manager: Optional[BaseCallbackManager]
) -> BaseCallbackManager:
"""If callback manager is None, set it.
This allows users to pass in None as callback manager, which is a nice UX.
"""
return callback_manager or get_callback_manager()
def generate( def generate(
self, messages: List[List[BaseMessage]], stop: Optional[List[str]] = None self, messages: List[List[BaseMessage]], stop: Optional[List[str]] = None
@@ -56,13 +36,13 @@ class BaseChatModel(BaseLanguageModel, BaseModel, ABC):
results = [await self._agenerate(m, stop=stop) for m in messages] results = [await self._agenerate(m, stop=stop) for m in messages]
return LLMResult(generations=[res.generations for res in results]) return LLMResult(generations=[res.generations for res in results])
def generate_prompt( def _generate_prompt(
self, prompts: List[PromptValue], stop: Optional[List[str]] = None self, prompts: List[PromptValue], stop: Optional[List[str]] = None
) -> LLMResult: ) -> LLMResult:
prompt_messages = [p.to_messages() for p in prompts] prompt_messages = [p.to_messages() for p in prompts]
return self.generate(prompt_messages, stop=stop) return self.generate(prompt_messages, stop=stop)
async def agenerate_prompt( async def _agenerate_prompt(
self, prompts: List[PromptValue], stop: Optional[List[str]] = None self, prompts: List[PromptValue], stop: Optional[List[str]] = None
) -> LLMResult: ) -> LLMResult:
prompt_messages = [p.to_messages() for p in prompts] prompt_messages = [p.to_messages() for p in prompts]
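
With this change, the public entry point for chat models is the inherited generate_prompt; the class itself only converts each PromptValue to messages. A usage sketch, assuming an OPENAI_API_KEY is set and using the chat prompt-template helpers from elsewhere in the library:

from langchain.chat_models import ChatOpenAI
from langchain.prompts.chat import ChatPromptTemplate, HumanMessagePromptTemplate

chat = ChatOpenAI(temperature=0)
prompt = ChatPromptTemplate.from_messages(
    [HumanMessagePromptTemplate.from_template("What is the population of {country}?")]
)
prompt_value = prompt.format_prompt(country="Canada")  # a chat-style PromptValue

# generate_prompt is inherited from BaseLanguageModel: it fires
# on_llm_start_prompt_value, then _generate_prompt (above) calls
# to_messages() on each PromptValue and delegates to generate().
result = chat.generate_prompt([prompt_value])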

View File

@@ -5,16 +5,11 @@ from pathlib import Path
from typing import Any, Dict, List, Mapping, Optional, Tuple, Union from typing import Any, Dict, List, Mapping, Optional, Tuple, Union
import yaml import yaml
from pydantic import BaseModel, Extra, Field, validator from pydantic import BaseModel, Extra
import langchain import langchain
from langchain.callbacks import get_callback_manager from langchain.base_language_model import BaseLanguageModel
from langchain.callbacks.base import BaseCallbackManager from langchain.schema import Generation, LLMResult, PromptValue
from langchain.schema import BaseLanguageModel, Generation, LLMResult, PromptValue
def _get_verbosity() -> bool:
return langchain.verbose
def get_prompts( def get_prompts(
@@ -57,9 +52,6 @@ class BaseLLM(BaseLanguageModel, BaseModel, ABC):
"""LLM wrapper should take in a prompt and return a string.""" """LLM wrapper should take in a prompt and return a string."""
cache: Optional[bool] = None cache: Optional[bool] = None
verbose: bool = Field(default_factory=_get_verbosity)
"""Whether to print out response text."""
callback_manager: BaseCallbackManager = Field(default_factory=get_callback_manager)
class Config: class Config:
"""Configuration for this pydantic object.""" """Configuration for this pydantic object."""
@@ -67,27 +59,6 @@ class BaseLLM(BaseLanguageModel, BaseModel, ABC):
extra = Extra.forbid extra = Extra.forbid
arbitrary_types_allowed = True arbitrary_types_allowed = True
@validator("callback_manager", pre=True, always=True)
def set_callback_manager(
cls, callback_manager: Optional[BaseCallbackManager]
) -> BaseCallbackManager:
"""If callback manager is None, set it.
This allows users to pass in None as callback manager, which is a nice UX.
"""
return callback_manager or get_callback_manager()
@validator("verbose", pre=True, always=True)
def set_verbose(cls, verbose: Optional[bool]) -> bool:
"""If verbose is None, set it.
This allows users to pass in None as verbose to access the global setting.
"""
if verbose is None:
return _get_verbosity()
else:
return verbose
@abstractmethod @abstractmethod
def _generate( def _generate(
self, prompts: List[str], stop: Optional[List[str]] = None self, prompts: List[str], stop: Optional[List[str]] = None
@@ -100,13 +71,13 @@ class BaseLLM(BaseLanguageModel, BaseModel, ABC):
) -> LLMResult: ) -> LLMResult:
"""Run the LLM on the given prompts.""" """Run the LLM on the given prompts."""
def generate_prompt( def _generate_prompt(
self, prompts: List[PromptValue], stop: Optional[List[str]] = None self, prompts: List[PromptValue], stop: Optional[List[str]] = None
) -> LLMResult: ) -> LLMResult:
prompt_strings = [p.to_string() for p in prompts] prompt_strings = [p.to_string() for p in prompts]
return self.generate(prompt_strings, stop=stop) return self.generate(prompt_strings, stop=stop)
async def agenerate_prompt( async def _agenerate_prompt(
self, prompts: List[PromptValue], stop: Optional[List[str]] = None self, prompts: List[PromptValue], stop: Optional[List[str]] = None
) -> LLMResult: ) -> LLMResult:
prompt_strings = [p.to_string() for p in prompts] prompt_strings = [p.to_string() for p in prompts]
@@ -138,7 +109,6 @@ class BaseLLM(BaseLanguageModel, BaseModel, ABC):
except (KeyboardInterrupt, Exception) as e: except (KeyboardInterrupt, Exception) as e:
self.callback_manager.on_llm_error(e, verbose=self.verbose) self.callback_manager.on_llm_error(e, verbose=self.verbose)
raise e raise e
self.callback_manager.on_llm_end(output, verbose=self.verbose)
return output return output
params = self.dict() params = self.dict()
params["stop"] = stop params["stop"] = stop
@@ -157,7 +127,6 @@ class BaseLLM(BaseLanguageModel, BaseModel, ABC):
except (KeyboardInterrupt, Exception) as e: except (KeyboardInterrupt, Exception) as e:
self.callback_manager.on_llm_error(e, verbose=self.verbose) self.callback_manager.on_llm_error(e, verbose=self.verbose)
raise e raise e
self.callback_manager.on_llm_end(new_results, verbose=self.verbose)
llm_output = update_cache( llm_output = update_cache(
existing_prompts, llm_string, missing_prompt_idxs, new_results, prompts existing_prompts, llm_string, missing_prompt_idxs, new_results, prompts
) )
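
Plain LLMs follow the same pattern, except _generate_prompt renders each PromptValue with to_string(). A parallel sketch, again assuming an OpenAI key:

from langchain.llms import OpenAI
from langchain.prompts import PromptTemplate

llm = OpenAI(temperature=0)
prompt = PromptTemplate(
    input_variables=["topic"], template="Tell me a joke about {topic}"
)
prompt_value = prompt.format_prompt(topic="fish")  # a string-style PromptValue

# The inherited generate_prompt wrapper now emits the start/end callbacks;
# _generate_prompt just converts each PromptValue to a string and calls generate().
result = llm.generate_prompt([prompt_value])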

View File

@@ -2,6 +2,7 @@ from typing import Any, Dict, List, Optional
from pydantic import BaseModel from pydantic import BaseModel
from langchain.base_language_model import BaseLanguageModel
from langchain.chains.llm import LLMChain from langchain.chains.llm import LLMChain
from langchain.memory.chat_memory import BaseChatMemory from langchain.memory.chat_memory import BaseChatMemory
from langchain.memory.prompt import ( from langchain.memory.prompt import (
@@ -10,7 +11,7 @@ from langchain.memory.prompt import (
) )
from langchain.memory.utils import get_buffer_string, get_prompt_input_key from langchain.memory.utils import get_buffer_string, get_prompt_input_key
from langchain.prompts.base import BasePromptTemplate from langchain.prompts.base import BasePromptTemplate
from langchain.schema import BaseLanguageModel, BaseMessage from langchain.schema import BaseMessage
class ConversationEntityMemory(BaseChatMemory, BaseModel): class ConversationEntityMemory(BaseChatMemory, BaseModel):

View File

@@ -2,6 +2,7 @@ from typing import Any, Dict, List
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from langchain.base_language_model import BaseLanguageModel
from langchain.chains.llm import LLMChain from langchain.chains.llm import LLMChain
from langchain.graphs import NetworkxEntityGraph from langchain.graphs import NetworkxEntityGraph
from langchain.graphs.networkx_graph import KnowledgeTriple, get_entities, parse_triples from langchain.graphs.networkx_graph import KnowledgeTriple, get_entities, parse_triples
@@ -12,7 +13,7 @@ from langchain.memory.prompt import (
) )
from langchain.memory.utils import get_buffer_string, get_prompt_input_key from langchain.memory.utils import get_buffer_string, get_prompt_input_key
from langchain.prompts.base import BasePromptTemplate from langchain.prompts.base import BasePromptTemplate
from langchain.schema import BaseLanguageModel, SystemMessage from langchain.schema import SystemMessage
class ConversationKGMemory(BaseChatMemory, BaseModel): class ConversationKGMemory(BaseChatMemory, BaseModel):

View File

@@ -2,12 +2,13 @@ from typing import Any, Dict, List
from pydantic import BaseModel, root_validator from pydantic import BaseModel, root_validator
from langchain.base_language_model import BaseLanguageModel
from langchain.chains.llm import LLMChain from langchain.chains.llm import LLMChain
from langchain.memory.chat_memory import BaseChatMemory from langchain.memory.chat_memory import BaseChatMemory
from langchain.memory.prompt import SUMMARY_PROMPT from langchain.memory.prompt import SUMMARY_PROMPT
from langchain.memory.utils import get_buffer_string from langchain.memory.utils import get_buffer_string
from langchain.prompts.base import BasePromptTemplate from langchain.prompts.base import BasePromptTemplate
from langchain.schema import BaseLanguageModel, BaseMessage, SystemMessage from langchain.schema import BaseMessage, SystemMessage
class SummarizerMixin(BaseModel): class SummarizerMixin(BaseModel):

View File

@@ -44,19 +44,26 @@ class BaseMessage(BaseModel):
class HumanMessage(BaseMessage): class HumanMessage(BaseMessage):
"""Type of message that is spoken by the human.""" """Type of message that is spoken by the human."""
_type = "human"
class AIMessage(BaseMessage): class AIMessage(BaseMessage):
"""Type of message that is spoken by the AI.""" """Type of message that is spoken by the AI."""
_type = "ai"
class SystemMessage(BaseMessage): class SystemMessage(BaseMessage):
"""Type of message that is a system message.""" """Type of message that is a system message."""
_type = "system"
class ChatMessage(BaseMessage): class ChatMessage(BaseMessage):
"""Type of message with arbitrary speaker.""" """Type of message with arbitrary speaker."""
role: str role: str
_type = "chat"
class ChatGeneration(Generation): class ChatGeneration(Generation):
@@ -100,41 +107,6 @@ class PromptValue(BaseModel, ABC):
"""Return prompt as messages.""" """Return prompt as messages."""
class BaseLanguageModel(BaseModel, ABC):
@abstractmethod
def generate_prompt(
self, prompts: List[PromptValue], stop: Optional[List[str]] = None
) -> LLMResult:
"""Take in a list of prompt values and return an LLMResult."""
@abstractmethod
async def agenerate_prompt(
self, prompts: List[PromptValue], stop: Optional[List[str]] = None
) -> LLMResult:
"""Take in a list of prompt values and return an LLMResult."""
def get_num_tokens(self, text: str) -> int:
"""Get the number of tokens present in the text."""
# TODO: this method may not be exact.
# TODO: this method may differ based on model (eg codex).
try:
from transformers import GPT2TokenizerFast
except ImportError:
raise ValueError(
"Could not import transformers python package. "
"This is needed in order to calculate get_num_tokens. "
"Please it install it with `pip install transformers`."
)
# create a GPT-3 tokenizer instance
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
# tokenize the text using the GPT-3 tokenizer
tokenized_text = tokenizer.tokenize(text)
# calculate the number of tokens in the tokenized text
return len(tokenized_text)
class BaseMemory(BaseModel, ABC): class BaseMemory(BaseModel, ABC):
"""Base interface for memory in chains.""" """Base interface for memory in chains."""

View File

@@ -85,10 +85,10 @@ def _perform_nested_run(tracer: BaseTracer) -> None:
"""Perform a nested run.""" """Perform a nested run."""
tracer.on_chain_start(serialized={}, inputs={}) tracer.on_chain_start(serialized={}, inputs={})
tracer.on_tool_start(serialized={}, input_str="test") tracer.on_tool_start(serialized={}, input_str="test")
tracer.on_llm_start(serialized={}, prompts=[]) tracer.on_llm_start_prompt_value(serialized={}, prompts=[])
tracer.on_llm_end(response=LLMResult(generations=[[]])) tracer.on_llm_end(response=LLMResult(generations=[[]]))
tracer.on_tool_end("test") tracer.on_tool_end("test")
tracer.on_llm_start(serialized={}, prompts=[]) tracer.on_llm_start_prompt_value(serialized={}, prompts=[])
tracer.on_llm_end(response=LLMResult(generations=[[]])) tracer.on_llm_end(response=LLMResult(generations=[[]]))
tracer.on_chain_end(outputs={}) tracer.on_chain_end(outputs={})
@@ -216,7 +216,7 @@ def test_tracer_llm_run() -> None:
tracer = FakeTracer() tracer = FakeTracer()
tracer.new_session() tracer.new_session()
tracer.on_llm_start(serialized={}, prompts=[]) tracer.on_llm_start_prompt_value(serialized={}, prompts=[])
tracer.on_llm_end(response=LLMResult(generations=[[]])) tracer.on_llm_end(response=LLMResult(generations=[[]]))
assert tracer.runs == [compare_run] assert tracer.runs == [compare_run]
@@ -227,7 +227,7 @@ def test_tracer_llm_run_errors_no_session() -> None:
tracer = FakeTracer() tracer = FakeTracer()
with pytest.raises(TracerException): with pytest.raises(TracerException):
tracer.on_llm_start(serialized={}, prompts=[]) tracer.on_llm_start_prompt_value(serialized={}, prompts=[])
@freeze_time("2023-01-01") @freeze_time("2023-01-01")
@@ -260,7 +260,7 @@ def test_tracer_multiple_llm_runs() -> None:
tracer.new_session() tracer.new_session()
num_runs = 10 num_runs = 10
for _ in range(num_runs): for _ in range(num_runs):
tracer.on_llm_start(serialized={}, prompts=[]) tracer.on_llm_start_prompt_value(serialized={}, prompts=[])
tracer.on_llm_end(response=LLMResult(generations=[[]])) tracer.on_llm_end(response=LLMResult(generations=[[]]))
assert tracer.runs == [compare_run] * num_runs assert tracer.runs == [compare_run] * num_runs
@@ -342,7 +342,7 @@ def test_tracer_llm_run_on_error() -> None:
tracer = FakeTracer() tracer = FakeTracer()
tracer.new_session() tracer.new_session()
tracer.on_llm_start(serialized={}, prompts=[]) tracer.on_llm_start_prompt_value(serialized={}, prompts=[])
tracer.on_llm_error(exception) tracer.on_llm_error(exception)
assert tracer.runs == [compare_run] assert tracer.runs == [compare_run]
@@ -408,12 +408,12 @@ def test_tracer_nested_runs_on_error() -> None:
for _ in range(3): for _ in range(3):
tracer.on_chain_start(serialized={}, inputs={}) tracer.on_chain_start(serialized={}, inputs={})
tracer.on_llm_start(serialized={}, prompts=[]) tracer.on_llm_start_prompt_value(serialized={}, prompts=[])
tracer.on_llm_end(response=LLMResult(generations=[[]])) tracer.on_llm_end(response=LLMResult(generations=[[]]))
tracer.on_llm_start(serialized={}, prompts=[]) tracer.on_llm_start_prompt_value(serialized={}, prompts=[])
tracer.on_llm_end(response=LLMResult(generations=[[]])) tracer.on_llm_end(response=LLMResult(generations=[[]]))
tracer.on_tool_start(serialized={}, input_str="test") tracer.on_tool_start(serialized={}, input_str="test")
tracer.on_llm_start(serialized={}, prompts=[]) tracer.on_llm_start_prompt_value(serialized={}, prompts=[])
tracer.on_llm_error(exception) tracer.on_llm_error(exception)
tracer.on_tool_error(exception) tracer.on_tool_error(exception)
tracer.on_chain_error(exception) tracer.on_chain_error(exception)