Compare commits

...

8 Commits

Author SHA1 Message Date
Ankush Gola
fe353c610f more lint 2023-03-09 21:41:36 -08:00
Ankush Gola
f897e7102b lint 2023-03-09 21:40:44 -08:00
Ankush Gola
366de1bd58 add type 2023-03-09 21:39:52 -08:00
Ankush Gola
4553f64b4b get everything working 2023-03-09 21:29:29 -08:00
Ankush Gola
3c7a559e77 fix tests 2023-03-09 20:44:50 -08:00
Ankush Gola
7d465cbc2f refactor 2023-03-09 20:41:50 -08:00
Ankush Gola
ecdfbfe1c7 cr 2023-03-09 20:18:05 -08:00
Ankush Gola
4d366eeea3 add on_llm_new_token_prompt_value 2023-03-09 19:27:42 -08:00
21 changed files with 317 additions and 154 deletions

View File

@@ -11,11 +11,20 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 1,
"id": "5268c7fa",
"metadata": {},
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"import os\n",
"os.environ[\"LANGCHAIN_HANDLER\"] = \"langchain\"\n",
"\n",
"## Uncomment this if using hosted setup.\n",
"\n",
"os.environ[\"LANGCHAIN_ENDPOINT\"] = \"http://127.0.0.1:8000\" \n",
"\n",
"from langchain.agents import ZeroShotAgent, Tool, AgentExecutor\n",
"from langchain.chains import LLMChain\n",
"from langchain.utilities import SerpAPIWrapper"
@@ -23,9 +32,11 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 2,
"id": "fbaa4dbe",
"metadata": {},
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"search = SerpAPIWrapper()\n",
@@ -40,9 +51,11 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 3,
"id": "f3ba6f08",
"metadata": {},
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"prefix = \"\"\"Answer the following questions as best you can, but speaking as a pirate might speak. You have access to the following tools:\"\"\"\n",
@@ -58,9 +71,11 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 4,
"id": "3547a37d",
"metadata": {},
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"from langchain.chat_models import ChatOpenAI\n",
@@ -79,9 +94,11 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 5,
"id": "a78f886f",
"metadata": {},
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"messages = [\n",
@@ -94,9 +111,11 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 6,
"id": "dadadd70",
"metadata": {},
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"prompt = ChatPromptTemplate.from_messages(messages)"
@@ -104,9 +123,11 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 7,
"id": "b7180182",
"metadata": {},
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"llm_chain = LLMChain(llm=ChatOpenAI(temperature=0), prompt=prompt)"
@@ -114,9 +135,11 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 8,
"id": "ddddb07b",
"metadata": {},
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"tool_names = [tool.name for tool in tools]\n",
@@ -125,9 +148,11 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 9,
"id": "36aef054",
"metadata": {},
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"agent_executor = AgentExecutor.from_agent_and_tools(agent=agent, tools=tools, verbose=True)"
@@ -135,9 +160,11 @@
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": 10,
"id": "33a4d6cc",
"metadata": {},
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
@@ -146,16 +173,16 @@
"\n",
"\n",
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
"\u001b[32;1m\u001b[1;3mArrr, ye be in luck, matey! I'll find ye the answer to yer question.\n",
"\u001b[32;1m\u001b[1;3mArrr, ye be in luck, matey! I'll help ye find the answer ye seek.\n",
"\n",
"Thought: I need to search for the current population of Canada.\n",
"Thought: Hmm, I don't have this information off the top of me head.\n",
"Action: Search\n",
"Action Input: \"current population of Canada 2023\"\n",
"Action Input: \"Canada population 2023\"\n",
"\u001b[0m\n",
"Observation: \u001b[36;1m\u001b[1;3mThe current population of Canada is 38,623,091 as of Saturday, March 4, 2023, based on Worldometer elaboration of the latest United Nations data.\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3mAhoy, me hearties! I've found the answer to yer question.\n",
"Observation: \u001b[36;1m\u001b[1;3mThe current population of Canada is 38,627,607 as of Thursday, March 9, 2023, based on Worldometer elaboration of the latest United Nations data.\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3mAhoy, me hearties! I have found the answer to yer question.\n",
"\n",
"Final Answer: As of March 4, 2023, the population of Canada be 38,623,091. Arrr!\u001b[0m\n",
"Final Answer: As of March 9, 2023, the population of Canada is 38,627,607. Arrr!\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
@@ -163,10 +190,10 @@
{
"data": {
"text/plain": [
"'As of March 4, 2023, the population of Canada be 38,623,091. Arrr!'"
"'As of March 9, 2023, the population of Canada is 38,627,607. Arrr!'"
]
},
"execution_count": 13,
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
@@ -200,7 +227,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.1"
"version": "3.10.9"
}
},
"nbformat": 4,

View File

@@ -372,7 +372,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.1"
"version": "3.10.9"
}
},
"nbformat": 4,

View File

@@ -151,7 +151,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.1"
"version": "3.10.9"
}
},
"nbformat": 4,

View File

@@ -22,7 +22,7 @@
"\n",
"## Uncomment this if using hosted setup.\n",
"\n",
"# os.environ[\"LANGCHAIN_ENDPOINT\"] = \"https://langchain-api-gateway-57eoxz8z.uc.gateway.dev\" \n",
"os.environ[\"LANGCHAIN_ENDPOINT\"] = \"http://127.0.0.1:8000\" \n",
"\n",
"## Uncomment this if you want traces to be recorded to \"my_session\" instead of default.\n",
"\n",
@@ -89,9 +89,30 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 4,
"id": "25addd7f",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'\\n\\nQ: What did the fish say when it hit the wall?\\nA: Dam!'"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"llm(\"tell me a joke\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7afdd321-796f-4dcc-816c-c1160aa943af",
"metadata": {},
"outputs": [],
"source": []
}

View File

@@ -3,6 +3,7 @@ from typing import Any, List, Optional, Sequence, Tuple
from langchain.agents.agent import Agent
from langchain.agents.chat.prompt import FORMAT_INSTRUCTIONS, PREFIX, SUFFIX
from langchain.base_language_model import BaseLanguageModel
from langchain.callbacks.base import BaseCallbackManager
from langchain.chains.llm import LLMChain
from langchain.prompts.base import BasePromptTemplate
@@ -11,7 +12,7 @@ from langchain.prompts.chat import (
HumanMessagePromptTemplate,
SystemMessagePromptTemplate,
)
from langchain.schema import AgentAction, BaseLanguageModel
from langchain.schema import AgentAction
from langchain.tools import BaseTool
FINAL_ANSWER_ACTION = "Final Answer:"

View File

@@ -0,0 +1,124 @@
"""Base class for language models."""
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import List, Optional
from pydantic import BaseModel, Extra, Field, validator
import langchain
from langchain.callbacks import BaseCallbackManager, get_callback_manager
from langchain.schema import LLMResult, PromptValue
def _get_verbosity() -> bool:
return langchain.verbose
class BaseLanguageModel(BaseModel, ABC):
"""Base class for language models."""
verbose: bool = Field(default_factory=_get_verbosity)
"""Whether to print out response text."""
callback_manager: BaseCallbackManager = Field(default_factory=get_callback_manager)
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
arbitrary_types_allowed = True
@validator("callback_manager", pre=True, always=True)
def set_callback_manager(
cls, callback_manager: Optional[BaseCallbackManager]
) -> BaseCallbackManager:
"""If callback manager is None, set it.
This allows users to pass in None as callback manager, which is a nice UX.
"""
return callback_manager or get_callback_manager()
@validator("verbose", pre=True, always=True)
def set_verbose(cls, verbose: Optional[bool]) -> bool:
"""If verbose is None, set it.
This allows users to pass in None as verbose to access the global setting.
"""
if verbose is None:
return _get_verbosity()
else:
return verbose
def generate_prompt(
self, prompts: List[PromptValue], stop: Optional[List[str]] = None
) -> LLMResult:
"""Take in a list of prompt values and return an LLMResult."""
self.callback_manager.on_llm_start_prompt_value(
{"name": self.__class__.__name__}, prompts, verbose=self.verbose
)
try:
output = self._generate_prompt(prompts, stop=stop)
except (KeyboardInterrupt, Exception) as e:
self.callback_manager.on_llm_error(e, verbose=self.verbose)
raise e
self.callback_manager.on_llm_end(output, verbose=self.verbose)
return output
async def agenerate_prompt(
self, prompts: List[PromptValue], stop: Optional[List[str]] = None
) -> LLMResult:
"""Take in a list of prompt values and return an LLMResult."""
if self.callback_manager.is_async:
await self.callback_manager.on_llm_start_prompt_value(
{"name": self.__class__.__name__}, prompts, verbose=self.verbose
)
else:
self.callback_manager.on_llm_start_prompt_value(
{"name": self.__class__.__name__}, prompts, verbose=self.verbose
)
try:
output = await self._agenerate_prompt(prompts, stop=stop)
except (KeyboardInterrupt, Exception) as e:
if self.callback_manager.is_async:
await self.callback_manager.on_llm_error(e, verbose=self.verbose)
else:
self.callback_manager.on_llm_error(e, verbose=self.verbose)
raise e
if self.callback_manager.is_async:
await self.callback_manager.on_llm_end(output, verbose=self.verbose)
else:
self.callback_manager.on_llm_end(output, verbose=self.verbose)
return output
@abstractmethod
def _generate_prompt(
self, prompts: List[PromptValue], stop: Optional[List[str]] = None
) -> LLMResult:
"""Take in a list of prompt values and return an LLMResult."""
@abstractmethod
async def _agenerate_prompt(
self, prompts: List[PromptValue], stop: Optional[List[str]] = None
) -> LLMResult:
"""Take in a list of prompt values and return an LLMResult."""
def get_num_tokens(self, text: str) -> int:
"""Get the number of tokens present in the text."""
# TODO: this method may not be exact.
# TODO: this method may differ based on model (eg codex).
try:
from transformers import GPT2TokenizerFast
except ImportError:
raise ValueError(
"Could not import transformers python package. "
"This is needed in order to calculate get_num_tokens. "
"Please it install it with `pip install transformers`."
)
# create a GPT-3 tokenizer instance
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
# tokenize the text using the GPT-3 tokenizer
tokenized_text = tokenizer.tokenize(text)
# calculate the number of tokens in the tokenized text
return len(tokenized_text)

View File

@@ -4,7 +4,7 @@ import functools
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Union
from langchain.schema import AgentAction, AgentFinish, LLMResult
from langchain.schema import AgentAction, AgentFinish, LLMResult, PromptValue
class BaseCallbackHandler(ABC):
@@ -30,6 +30,12 @@ class BaseCallbackHandler(ABC):
"""Whether to ignore agent callbacks."""
return False
def on_llm_start_prompt_value(
self, serialized: Dict[str, Any], prompts: List[PromptValue], **kwargs: Any
) -> Any:
"""Run when LLM starts running."""
pass
@abstractmethod
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
@@ -127,6 +133,19 @@ class CallbackManager(BaseCallbackManager):
"""Initialize callback manager."""
self.handlers: List[BaseCallbackHandler] = handlers
def on_llm_start_prompt_value(
self,
serialized: Dict[str, Any],
prompts: List[PromptValue],
verbose: bool = False,
**kwargs: Any
) -> Any:
"""Run when LLM starts running."""
for handler in self.handlers:
if not handler.ignore_llm:
if verbose or handler.always_verbose:
handler.on_llm_start_prompt_value(serialized, prompts, **kwargs)
def on_llm_start(
self,
serialized: Dict[str, Any],
@@ -276,6 +295,11 @@ class CallbackManager(BaseCallbackManager):
class AsyncCallbackHandler(BaseCallbackHandler):
"""Async callback handler that can be used to handle callbacks from langchain."""
async def on_llm_start_prompt_value(
self, serialized: Dict[str, Any], prompts: List[PromptValue], **kwargs: Any
) -> Any:
"""Run when LLM starts running."""
async def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:
@@ -340,6 +364,32 @@ class AsyncCallbackManager(BaseCallbackManager):
"""Initialize callback manager."""
self.handlers: List[BaseCallbackHandler] = handlers
async def on_llm_start_prompt_value(
self,
serialized: Dict[str, Any],
prompts: List[PromptValue],
verbose: bool = False,
**kwargs: Any
) -> Any:
"""Run when LLM starts running."""
for handler in self.handlers:
if not handler.ignore_llm:
if verbose or handler.always_verbose:
if asyncio.iscoroutinefunction(handler.on_llm_start_prompt_value):
return await handler.on_llm_start_prompt_value(
serialized, prompts, **kwargs
)
else:
return await asyncio.get_event_loop().run_in_executor(
None,
functools.partial(
handler.on_llm_start_prompt_value,
serialized,
prompts,
**kwargs
),
)
async def on_llm_start(
self,
serialized: Dict[str, Any],

View File

@@ -8,7 +8,7 @@ from langchain.callbacks.base import (
BaseCallbackManager,
CallbackManager,
)
from langchain.schema import AgentAction, AgentFinish, LLMResult
from langchain.schema import AgentAction, AgentFinish, LLMResult, PromptValue
class Singleton:
@@ -34,6 +34,15 @@ class SharedCallbackManager(Singleton, BaseCallbackManager):
_callback_manager: CallbackManager = CallbackManager(handlers=[])
def on_llm_start_prompt_value(
self, serialized: Dict[str, Any], prompts: List[PromptValue], **kwargs: Any
) -> Any:
"""Run when LLM starts running."""
with self._lock:
self._callback_manager.on_llm_start_prompt_value(
serialized, prompts, **kwargs
)
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:

View File

@@ -16,7 +16,7 @@ from langchain.callbacks.tracers.schemas import (
TracerSession,
TracerSessionCreate,
)
from langchain.schema import AgentAction, AgentFinish, LLMResult
from langchain.schema import AgentAction, AgentFinish, LLMResult, PromptValue
class TracerException(Exception):
@@ -109,9 +109,9 @@ class BaseTracer(BaseCallbackHandler, ABC):
self._execution_order = 1
self._persist_run(run)
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:
def on_llm_start_prompt_value(
self, serialized: Dict[str, Any], prompts: List[PromptValue], **kwargs: Any
) -> Any:
"""Start a trace for an LLM run."""
if self._session is None:
raise TracerException(
@@ -129,6 +129,12 @@ class BaseTracer(BaseCallbackHandler, ABC):
)
self._start_trace(llm_run)
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:
"""Start a trace for an LLM run."""
pass
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
"""Handle a new token for an LLM run."""
pass

View File

@@ -6,7 +6,7 @@ from typing import Any, Dict, List, Optional, Union
from pydantic import BaseModel, Field
from langchain.schema import LLMResult
from langchain.schema import LLMResult, PromptValue
class TracerSessionBase(BaseModel):
@@ -45,7 +45,7 @@ class BaseRun(BaseModel):
class LLMRun(BaseRun):
"""Class for LLMRun."""
prompts: List[str]
prompts: List[PromptValue]
response: Optional[LLMResult] = None

View File

@@ -6,13 +6,13 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from pydantic import BaseModel
from langchain.base_language_model import BaseLanguageModel
from langchain.chains.base import Chain
from langchain.chains.chat_vector_db.prompts import CONDENSE_QUESTION_PROMPT
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
from langchain.chains.llm import LLMChain
from langchain.chains.question_answering import load_qa_chain
from langchain.prompts.base import BasePromptTemplate
from langchain.schema import BaseLanguageModel
from langchain.vectorstores.base import VectorStore

View File

@@ -5,11 +5,12 @@ from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
from pydantic import BaseModel, Extra
from langchain.base_language_model import BaseLanguageModel
from langchain.chains.base import Chain
from langchain.input import get_colored_text
from langchain.prompts.base import BasePromptTemplate
from langchain.prompts.prompt import PromptTemplate
from langchain.schema import BaseLanguageModel, LLMResult, PromptValue
from langchain.schema import LLMResult, PromptValue
class LLMChain(Chain, BaseModel):

View File

@@ -3,10 +3,10 @@ from typing import Callable, List, Tuple
from pydantic import BaseModel, Field
from langchain.base_language_model import BaseLanguageModel
from langchain.chat_models.base import BaseChatModel
from langchain.llms.base import BaseLLM
from langchain.prompts.base import BasePromptTemplate
from langchain.schema import BaseLanguageModel
class BasePromptSelector(BaseModel, ABC):

View File

@@ -1,6 +1,7 @@
"""Load question answering chains."""
from typing import Any, Mapping, Optional, Protocol
from langchain.base_language_model import BaseLanguageModel
from langchain.callbacks.base import BaseCallbackManager
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain
@@ -15,7 +16,6 @@ from langchain.chains.question_answering import (
stuff_prompt,
)
from langchain.prompts.base import BasePromptTemplate
from langchain.schema import BaseLanguageModel
class LoadingCallable(Protocol):

View File

@@ -1,14 +1,12 @@
from abc import ABC, abstractmethod
from typing import List, Optional
from pydantic import BaseModel, Extra, Field, validator
from pydantic import BaseModel
import langchain
from langchain.callbacks import get_callback_manager
from langchain.callbacks.base import BaseCallbackManager
from langchain.base_language_model import BaseLanguageModel
from langchain.schema import (
AIMessage,
BaseLanguageModel,
BaseMessage,
ChatGeneration,
ChatResult,
@@ -22,25 +20,7 @@ def _get_verbosity() -> bool:
class BaseChatModel(BaseLanguageModel, BaseModel, ABC):
verbose: bool = Field(default_factory=_get_verbosity)
"""Whether to print out response text."""
callback_manager: BaseCallbackManager = Field(default_factory=get_callback_manager)
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
arbitrary_types_allowed = True
@validator("callback_manager", pre=True, always=True)
def set_callback_manager(
cls, callback_manager: Optional[BaseCallbackManager]
) -> BaseCallbackManager:
"""If callback manager is None, set it.
This allows users to pass in None as callback manager, which is a nice UX.
"""
return callback_manager or get_callback_manager()
"""Base class for chat models."""
def generate(
self, messages: List[List[BaseMessage]], stop: Optional[List[str]] = None
@@ -56,13 +36,13 @@ class BaseChatModel(BaseLanguageModel, BaseModel, ABC):
results = [await self._agenerate(m, stop=stop) for m in messages]
return LLMResult(generations=[res.generations for res in results])
def generate_prompt(
def _generate_prompt(
self, prompts: List[PromptValue], stop: Optional[List[str]] = None
) -> LLMResult:
prompt_messages = [p.to_messages() for p in prompts]
return self.generate(prompt_messages, stop=stop)
async def agenerate_prompt(
async def _agenerate_prompt(
self, prompts: List[PromptValue], stop: Optional[List[str]] = None
) -> LLMResult:
prompt_messages = [p.to_messages() for p in prompts]

View File

@@ -5,16 +5,11 @@ from pathlib import Path
from typing import Any, Dict, List, Mapping, Optional, Tuple, Union
import yaml
from pydantic import BaseModel, Extra, Field, validator
from pydantic import BaseModel, Extra
import langchain
from langchain.callbacks import get_callback_manager
from langchain.callbacks.base import BaseCallbackManager
from langchain.schema import BaseLanguageModel, Generation, LLMResult, PromptValue
def _get_verbosity() -> bool:
return langchain.verbose
from langchain.base_language_model import BaseLanguageModel
from langchain.schema import Generation, LLMResult, PromptValue
def get_prompts(
@@ -57,9 +52,6 @@ class BaseLLM(BaseLanguageModel, BaseModel, ABC):
"""LLM wrapper should take in a prompt and return a string."""
cache: Optional[bool] = None
verbose: bool = Field(default_factory=_get_verbosity)
"""Whether to print out response text."""
callback_manager: BaseCallbackManager = Field(default_factory=get_callback_manager)
class Config:
"""Configuration for this pydantic object."""
@@ -67,27 +59,6 @@ class BaseLLM(BaseLanguageModel, BaseModel, ABC):
extra = Extra.forbid
arbitrary_types_allowed = True
@validator("callback_manager", pre=True, always=True)
def set_callback_manager(
cls, callback_manager: Optional[BaseCallbackManager]
) -> BaseCallbackManager:
"""If callback manager is None, set it.
This allows users to pass in None as callback manager, which is a nice UX.
"""
return callback_manager or get_callback_manager()
@validator("verbose", pre=True, always=True)
def set_verbose(cls, verbose: Optional[bool]) -> bool:
"""If verbose is None, set it.
This allows users to pass in None as verbose to access the global setting.
"""
if verbose is None:
return _get_verbosity()
else:
return verbose
@abstractmethod
def _generate(
self, prompts: List[str], stop: Optional[List[str]] = None
@@ -100,13 +71,13 @@ class BaseLLM(BaseLanguageModel, BaseModel, ABC):
) -> LLMResult:
"""Run the LLM on the given prompts."""
def generate_prompt(
def _generate_prompt(
self, prompts: List[PromptValue], stop: Optional[List[str]] = None
) -> LLMResult:
prompt_strings = [p.to_string() for p in prompts]
return self.generate(prompt_strings, stop=stop)
async def agenerate_prompt(
async def _agenerate_prompt(
self, prompts: List[PromptValue], stop: Optional[List[str]] = None
) -> LLMResult:
prompt_strings = [p.to_string() for p in prompts]
@@ -138,7 +109,6 @@ class BaseLLM(BaseLanguageModel, BaseModel, ABC):
except (KeyboardInterrupt, Exception) as e:
self.callback_manager.on_llm_error(e, verbose=self.verbose)
raise e
self.callback_manager.on_llm_end(output, verbose=self.verbose)
return output
params = self.dict()
params["stop"] = stop
@@ -157,7 +127,6 @@ class BaseLLM(BaseLanguageModel, BaseModel, ABC):
except (KeyboardInterrupt, Exception) as e:
self.callback_manager.on_llm_error(e, verbose=self.verbose)
raise e
self.callback_manager.on_llm_end(new_results, verbose=self.verbose)
llm_output = update_cache(
existing_prompts, llm_string, missing_prompt_idxs, new_results, prompts
)

View File

@@ -2,6 +2,7 @@ from typing import Any, Dict, List, Optional
from pydantic import BaseModel
from langchain.base_language_model import BaseLanguageModel
from langchain.chains.llm import LLMChain
from langchain.memory.chat_memory import BaseChatMemory
from langchain.memory.prompt import (
@@ -10,7 +11,7 @@ from langchain.memory.prompt import (
)
from langchain.memory.utils import get_buffer_string, get_prompt_input_key
from langchain.prompts.base import BasePromptTemplate
from langchain.schema import BaseLanguageModel, BaseMessage
from langchain.schema import BaseMessage
class ConversationEntityMemory(BaseChatMemory, BaseModel):

View File

@@ -2,6 +2,7 @@ from typing import Any, Dict, List
from pydantic import BaseModel, Field
from langchain.base_language_model import BaseLanguageModel
from langchain.chains.llm import LLMChain
from langchain.graphs import NetworkxEntityGraph
from langchain.graphs.networkx_graph import KnowledgeTriple, get_entities, parse_triples
@@ -12,7 +13,7 @@ from langchain.memory.prompt import (
)
from langchain.memory.utils import get_buffer_string, get_prompt_input_key
from langchain.prompts.base import BasePromptTemplate
from langchain.schema import BaseLanguageModel, SystemMessage
from langchain.schema import SystemMessage
class ConversationKGMemory(BaseChatMemory, BaseModel):

View File

@@ -2,12 +2,13 @@ from typing import Any, Dict, List
from pydantic import BaseModel, root_validator
from langchain.base_language_model import BaseLanguageModel
from langchain.chains.llm import LLMChain
from langchain.memory.chat_memory import BaseChatMemory
from langchain.memory.prompt import SUMMARY_PROMPT
from langchain.memory.utils import get_buffer_string
from langchain.prompts.base import BasePromptTemplate
from langchain.schema import BaseLanguageModel, BaseMessage, SystemMessage
from langchain.schema import BaseMessage, SystemMessage
class SummarizerMixin(BaseModel):

View File

@@ -44,19 +44,26 @@ class BaseMessage(BaseModel):
class HumanMessage(BaseMessage):
"""Type of message that is spoken by the human."""
_type = "human"
class AIMessage(BaseMessage):
"""Type of message that is spoken by the AI."""
_type = "ai"
class SystemMessage(BaseMessage):
"""Type of message that is a system message."""
_type = "system"
class ChatMessage(BaseMessage):
"""Type of message with arbitrary speaker."""
role: str
_type = "chat"
class ChatGeneration(Generation):
@@ -100,41 +107,6 @@ class PromptValue(BaseModel, ABC):
"""Return prompt as messages."""
class BaseLanguageModel(BaseModel, ABC):
@abstractmethod
def generate_prompt(
self, prompts: List[PromptValue], stop: Optional[List[str]] = None
) -> LLMResult:
"""Take in a list of prompt values and return an LLMResult."""
@abstractmethod
async def agenerate_prompt(
self, prompts: List[PromptValue], stop: Optional[List[str]] = None
) -> LLMResult:
"""Take in a list of prompt values and return an LLMResult."""
def get_num_tokens(self, text: str) -> int:
"""Get the number of tokens present in the text."""
# TODO: this method may not be exact.
# TODO: this method may differ based on model (eg codex).
try:
from transformers import GPT2TokenizerFast
except ImportError:
raise ValueError(
"Could not import transformers python package. "
"This is needed in order to calculate get_num_tokens. "
"Please it install it with `pip install transformers`."
)
# create a GPT-3 tokenizer instance
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
# tokenize the text using the GPT-3 tokenizer
tokenized_text = tokenizer.tokenize(text)
# calculate the number of tokens in the tokenized text
return len(tokenized_text)
class BaseMemory(BaseModel, ABC):
"""Base interface for memory in chains."""

View File

@@ -85,10 +85,10 @@ def _perform_nested_run(tracer: BaseTracer) -> None:
"""Perform a nested run."""
tracer.on_chain_start(serialized={}, inputs={})
tracer.on_tool_start(serialized={}, input_str="test")
tracer.on_llm_start(serialized={}, prompts=[])
tracer.on_llm_start_prompt_value(serialized={}, prompts=[])
tracer.on_llm_end(response=LLMResult(generations=[[]]))
tracer.on_tool_end("test")
tracer.on_llm_start(serialized={}, prompts=[])
tracer.on_llm_start_prompt_value(serialized={}, prompts=[])
tracer.on_llm_end(response=LLMResult(generations=[[]]))
tracer.on_chain_end(outputs={})
@@ -216,7 +216,7 @@ def test_tracer_llm_run() -> None:
tracer = FakeTracer()
tracer.new_session()
tracer.on_llm_start(serialized={}, prompts=[])
tracer.on_llm_start_prompt_value(serialized={}, prompts=[])
tracer.on_llm_end(response=LLMResult(generations=[[]]))
assert tracer.runs == [compare_run]
@@ -227,7 +227,7 @@ def test_tracer_llm_run_errors_no_session() -> None:
tracer = FakeTracer()
with pytest.raises(TracerException):
tracer.on_llm_start(serialized={}, prompts=[])
tracer.on_llm_start_prompt_value(serialized={}, prompts=[])
@freeze_time("2023-01-01")
@@ -260,7 +260,7 @@ def test_tracer_multiple_llm_runs() -> None:
tracer.new_session()
num_runs = 10
for _ in range(num_runs):
tracer.on_llm_start(serialized={}, prompts=[])
tracer.on_llm_start_prompt_value(serialized={}, prompts=[])
tracer.on_llm_end(response=LLMResult(generations=[[]]))
assert tracer.runs == [compare_run] * num_runs
@@ -342,7 +342,7 @@ def test_tracer_llm_run_on_error() -> None:
tracer = FakeTracer()
tracer.new_session()
tracer.on_llm_start(serialized={}, prompts=[])
tracer.on_llm_start_prompt_value(serialized={}, prompts=[])
tracer.on_llm_error(exception)
assert tracer.runs == [compare_run]
@@ -408,12 +408,12 @@ def test_tracer_nested_runs_on_error() -> None:
for _ in range(3):
tracer.on_chain_start(serialized={}, inputs={})
tracer.on_llm_start(serialized={}, prompts=[])
tracer.on_llm_start_prompt_value(serialized={}, prompts=[])
tracer.on_llm_end(response=LLMResult(generations=[[]]))
tracer.on_llm_start(serialized={}, prompts=[])
tracer.on_llm_start_prompt_value(serialized={}, prompts=[])
tracer.on_llm_end(response=LLMResult(generations=[[]]))
tracer.on_tool_start(serialized={}, input_str="test")
tracer.on_llm_start(serialized={}, prompts=[])
tracer.on_llm_start_prompt_value(serialized={}, prompts=[])
tracer.on_llm_error(exception)
tracer.on_tool_error(exception)
tracer.on_chain_error(exception)