mirror of
https://github.com/hwchase17/langchain.git
synced 2026-01-23 21:31:02 +00:00
Compare commits
8 Commits
vwp/charac
...
ankush/cal
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
fe353c610f | ||
|
|
f897e7102b | ||
|
|
366de1bd58 | ||
|
|
4553f64b4b | ||
|
|
3c7a559e77 | ||
|
|
7d465cbc2f | ||
|
|
ecdfbfe1c7 | ||
|
|
4d366eeea3 |
@@ -11,11 +11,20 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"execution_count": 1,
|
||||
"id": "5268c7fa",
|
||||
"metadata": {},
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"os.environ[\"LANGCHAIN_HANDLER\"] = \"langchain\"\n",
|
||||
"\n",
|
||||
"## Uncomment this if using hosted setup.\n",
|
||||
"\n",
|
||||
"os.environ[\"LANGCHAIN_ENDPOINT\"] = \"http://127.0.0.1:8000\" \n",
|
||||
"\n",
|
||||
"from langchain.agents import ZeroShotAgent, Tool, AgentExecutor\n",
|
||||
"from langchain.chains import LLMChain\n",
|
||||
"from langchain.utilities import SerpAPIWrapper"
|
||||
@@ -23,9 +32,11 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"execution_count": 2,
|
||||
"id": "fbaa4dbe",
|
||||
"metadata": {},
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"search = SerpAPIWrapper()\n",
|
||||
@@ -40,9 +51,11 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"execution_count": 3,
|
||||
"id": "f3ba6f08",
|
||||
"metadata": {},
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"prefix = \"\"\"Answer the following questions as best you can, but speaking as a pirate might speak. You have access to the following tools:\"\"\"\n",
|
||||
@@ -58,9 +71,11 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"execution_count": 4,
|
||||
"id": "3547a37d",
|
||||
"metadata": {},
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.chat_models import ChatOpenAI\n",
|
||||
@@ -79,9 +94,11 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"execution_count": 5,
|
||||
"id": "a78f886f",
|
||||
"metadata": {},
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"messages = [\n",
|
||||
@@ -94,9 +111,11 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"execution_count": 6,
|
||||
"id": "dadadd70",
|
||||
"metadata": {},
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"prompt = ChatPromptTemplate.from_messages(messages)"
|
||||
@@ -104,9 +123,11 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"execution_count": 7,
|
||||
"id": "b7180182",
|
||||
"metadata": {},
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"llm_chain = LLMChain(llm=ChatOpenAI(temperature=0), prompt=prompt)"
|
||||
@@ -114,9 +135,11 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"execution_count": 8,
|
||||
"id": "ddddb07b",
|
||||
"metadata": {},
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"tool_names = [tool.name for tool in tools]\n",
|
||||
@@ -125,9 +148,11 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 12,
|
||||
"execution_count": 9,
|
||||
"id": "36aef054",
|
||||
"metadata": {},
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"agent_executor = AgentExecutor.from_agent_and_tools(agent=agent, tools=tools, verbose=True)"
|
||||
@@ -135,9 +160,11 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 13,
|
||||
"execution_count": 10,
|
||||
"id": "33a4d6cc",
|
||||
"metadata": {},
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
@@ -146,16 +173,16 @@
|
||||
"\n",
|
||||
"\n",
|
||||
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
|
||||
"\u001b[32;1m\u001b[1;3mArrr, ye be in luck, matey! I'll find ye the answer to yer question.\n",
|
||||
"\u001b[32;1m\u001b[1;3mArrr, ye be in luck, matey! I'll help ye find the answer ye seek.\n",
|
||||
"\n",
|
||||
"Thought: I need to search for the current population of Canada.\n",
|
||||
"Thought: Hmm, I don't have this information off the top of me head.\n",
|
||||
"Action: Search\n",
|
||||
"Action Input: \"current population of Canada 2023\"\n",
|
||||
"Action Input: \"Canada population 2023\"\n",
|
||||
"\u001b[0m\n",
|
||||
"Observation: \u001b[36;1m\u001b[1;3mThe current population of Canada is 38,623,091 as of Saturday, March 4, 2023, based on Worldometer elaboration of the latest United Nations data.\u001b[0m\n",
|
||||
"Thought:\u001b[32;1m\u001b[1;3mAhoy, me hearties! I've found the answer to yer question.\n",
|
||||
"Observation: \u001b[36;1m\u001b[1;3mThe current population of Canada is 38,627,607 as of Thursday, March 9, 2023, based on Worldometer elaboration of the latest United Nations data.\u001b[0m\n",
|
||||
"Thought:\u001b[32;1m\u001b[1;3mAhoy, me hearties! I have found the answer to yer question.\n",
|
||||
"\n",
|
||||
"Final Answer: As of March 4, 2023, the population of Canada be 38,623,091. Arrr!\u001b[0m\n",
|
||||
"Final Answer: As of March 9, 2023, the population of Canada is 38,627,607. Arrr!\u001b[0m\n",
|
||||
"\n",
|
||||
"\u001b[1m> Finished chain.\u001b[0m\n"
|
||||
]
|
||||
@@ -163,10 +190,10 @@
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"'As of March 4, 2023, the population of Canada be 38,623,091. Arrr!'"
|
||||
"'As of March 9, 2023, the population of Canada is 38,627,607. Arrr!'"
|
||||
]
|
||||
},
|
||||
"execution_count": 13,
|
||||
"execution_count": 10,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
@@ -200,7 +227,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.1"
|
||||
"version": "3.10.9"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
@@ -372,7 +372,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.1"
|
||||
"version": "3.10.9"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
@@ -151,7 +151,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.1"
|
||||
"version": "3.10.9"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
@@ -22,7 +22,7 @@
|
||||
"\n",
|
||||
"## Uncomment this if using hosted setup.\n",
|
||||
"\n",
|
||||
"# os.environ[\"LANGCHAIN_ENDPOINT\"] = \"https://langchain-api-gateway-57eoxz8z.uc.gateway.dev\" \n",
|
||||
"os.environ[\"LANGCHAIN_ENDPOINT\"] = \"http://127.0.0.1:8000\" \n",
|
||||
"\n",
|
||||
"## Uncomment this if you want traces to be recorded to \"my_session\" instead of default.\n",
|
||||
"\n",
|
||||
@@ -89,9 +89,30 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 4,
|
||||
"id": "25addd7f",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"'\\n\\nQ: What did the fish say when it hit the wall?\\nA: Dam!'"
|
||||
]
|
||||
},
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"llm(\"tell me a joke\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "7afdd321-796f-4dcc-816c-c1160aa943af",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ from typing import Any, List, Optional, Sequence, Tuple
|
||||
|
||||
from langchain.agents.agent import Agent
|
||||
from langchain.agents.chat.prompt import FORMAT_INSTRUCTIONS, PREFIX, SUFFIX
|
||||
from langchain.base_language_model import BaseLanguageModel
|
||||
from langchain.callbacks.base import BaseCallbackManager
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
@@ -11,7 +12,7 @@ from langchain.prompts.chat import (
|
||||
HumanMessagePromptTemplate,
|
||||
SystemMessagePromptTemplate,
|
||||
)
|
||||
from langchain.schema import AgentAction, BaseLanguageModel
|
||||
from langchain.schema import AgentAction
|
||||
from langchain.tools import BaseTool
|
||||
|
||||
FINAL_ANSWER_ACTION = "Final Answer:"
|
||||
|
||||
124
langchain/base_language_model.py
Normal file
124
langchain/base_language_model.py
Normal file
@@ -0,0 +1,124 @@
|
||||
"""Base class for language models."""
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Optional
|
||||
|
||||
from pydantic import BaseModel, Extra, Field, validator
|
||||
|
||||
import langchain
|
||||
from langchain.callbacks import BaseCallbackManager, get_callback_manager
|
||||
from langchain.schema import LLMResult, PromptValue
|
||||
|
||||
|
||||
def _get_verbosity() -> bool:
|
||||
return langchain.verbose
|
||||
|
||||
|
||||
class BaseLanguageModel(BaseModel, ABC):
|
||||
"""Base class for language models."""
|
||||
|
||||
verbose: bool = Field(default_factory=_get_verbosity)
|
||||
"""Whether to print out response text."""
|
||||
callback_manager: BaseCallbackManager = Field(default_factory=get_callback_manager)
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
extra = Extra.forbid
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@validator("callback_manager", pre=True, always=True)
|
||||
def set_callback_manager(
|
||||
cls, callback_manager: Optional[BaseCallbackManager]
|
||||
) -> BaseCallbackManager:
|
||||
"""If callback manager is None, set it.
|
||||
|
||||
This allows users to pass in None as callback manager, which is a nice UX.
|
||||
"""
|
||||
return callback_manager or get_callback_manager()
|
||||
|
||||
@validator("verbose", pre=True, always=True)
|
||||
def set_verbose(cls, verbose: Optional[bool]) -> bool:
|
||||
"""If verbose is None, set it.
|
||||
|
||||
This allows users to pass in None as verbose to access the global setting.
|
||||
"""
|
||||
if verbose is None:
|
||||
return _get_verbosity()
|
||||
else:
|
||||
return verbose
|
||||
|
||||
def generate_prompt(
|
||||
self, prompts: List[PromptValue], stop: Optional[List[str]] = None
|
||||
) -> LLMResult:
|
||||
"""Take in a list of prompt values and return an LLMResult."""
|
||||
self.callback_manager.on_llm_start_prompt_value(
|
||||
{"name": self.__class__.__name__}, prompts, verbose=self.verbose
|
||||
)
|
||||
try:
|
||||
output = self._generate_prompt(prompts, stop=stop)
|
||||
except (KeyboardInterrupt, Exception) as e:
|
||||
self.callback_manager.on_llm_error(e, verbose=self.verbose)
|
||||
raise e
|
||||
self.callback_manager.on_llm_end(output, verbose=self.verbose)
|
||||
return output
|
||||
|
||||
async def agenerate_prompt(
|
||||
self, prompts: List[PromptValue], stop: Optional[List[str]] = None
|
||||
) -> LLMResult:
|
||||
"""Take in a list of prompt values and return an LLMResult."""
|
||||
if self.callback_manager.is_async:
|
||||
await self.callback_manager.on_llm_start_prompt_value(
|
||||
{"name": self.__class__.__name__}, prompts, verbose=self.verbose
|
||||
)
|
||||
else:
|
||||
self.callback_manager.on_llm_start_prompt_value(
|
||||
{"name": self.__class__.__name__}, prompts, verbose=self.verbose
|
||||
)
|
||||
try:
|
||||
output = await self._agenerate_prompt(prompts, stop=stop)
|
||||
except (KeyboardInterrupt, Exception) as e:
|
||||
if self.callback_manager.is_async:
|
||||
await self.callback_manager.on_llm_error(e, verbose=self.verbose)
|
||||
else:
|
||||
self.callback_manager.on_llm_error(e, verbose=self.verbose)
|
||||
raise e
|
||||
if self.callback_manager.is_async:
|
||||
await self.callback_manager.on_llm_end(output, verbose=self.verbose)
|
||||
else:
|
||||
self.callback_manager.on_llm_end(output, verbose=self.verbose)
|
||||
return output
|
||||
|
||||
@abstractmethod
|
||||
def _generate_prompt(
|
||||
self, prompts: List[PromptValue], stop: Optional[List[str]] = None
|
||||
) -> LLMResult:
|
||||
"""Take in a list of prompt values and return an LLMResult."""
|
||||
|
||||
@abstractmethod
|
||||
async def _agenerate_prompt(
|
||||
self, prompts: List[PromptValue], stop: Optional[List[str]] = None
|
||||
) -> LLMResult:
|
||||
"""Take in a list of prompt values and return an LLMResult."""
|
||||
|
||||
def get_num_tokens(self, text: str) -> int:
|
||||
"""Get the number of tokens present in the text."""
|
||||
# TODO: this method may not be exact.
|
||||
# TODO: this method may differ based on model (eg codex).
|
||||
try:
|
||||
from transformers import GPT2TokenizerFast
|
||||
except ImportError:
|
||||
raise ValueError(
|
||||
"Could not import transformers python package. "
|
||||
"This is needed in order to calculate get_num_tokens. "
|
||||
"Please it install it with `pip install transformers`."
|
||||
)
|
||||
# create a GPT-3 tokenizer instance
|
||||
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
|
||||
|
||||
# tokenize the text using the GPT-3 tokenizer
|
||||
tokenized_text = tokenizer.tokenize(text)
|
||||
|
||||
# calculate the number of tokens in the tokenized text
|
||||
return len(tokenized_text)
|
||||
@@ -4,7 +4,7 @@ import functools
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List, Union
|
||||
|
||||
from langchain.schema import AgentAction, AgentFinish, LLMResult
|
||||
from langchain.schema import AgentAction, AgentFinish, LLMResult, PromptValue
|
||||
|
||||
|
||||
class BaseCallbackHandler(ABC):
|
||||
@@ -30,6 +30,12 @@ class BaseCallbackHandler(ABC):
|
||||
"""Whether to ignore agent callbacks."""
|
||||
return False
|
||||
|
||||
def on_llm_start_prompt_value(
|
||||
self, serialized: Dict[str, Any], prompts: List[PromptValue], **kwargs: Any
|
||||
) -> Any:
|
||||
"""Run when LLM starts running."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def on_llm_start(
|
||||
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
|
||||
@@ -127,6 +133,19 @@ class CallbackManager(BaseCallbackManager):
|
||||
"""Initialize callback manager."""
|
||||
self.handlers: List[BaseCallbackHandler] = handlers
|
||||
|
||||
def on_llm_start_prompt_value(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
prompts: List[PromptValue],
|
||||
verbose: bool = False,
|
||||
**kwargs: Any
|
||||
) -> Any:
|
||||
"""Run when LLM starts running."""
|
||||
for handler in self.handlers:
|
||||
if not handler.ignore_llm:
|
||||
if verbose or handler.always_verbose:
|
||||
handler.on_llm_start_prompt_value(serialized, prompts, **kwargs)
|
||||
|
||||
def on_llm_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
@@ -276,6 +295,11 @@ class CallbackManager(BaseCallbackManager):
|
||||
class AsyncCallbackHandler(BaseCallbackHandler):
|
||||
"""Async callback handler that can be used to handle callbacks from langchain."""
|
||||
|
||||
async def on_llm_start_prompt_value(
|
||||
self, serialized: Dict[str, Any], prompts: List[PromptValue], **kwargs: Any
|
||||
) -> Any:
|
||||
"""Run when LLM starts running."""
|
||||
|
||||
async def on_llm_start(
|
||||
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
|
||||
) -> None:
|
||||
@@ -340,6 +364,32 @@ class AsyncCallbackManager(BaseCallbackManager):
|
||||
"""Initialize callback manager."""
|
||||
self.handlers: List[BaseCallbackHandler] = handlers
|
||||
|
||||
async def on_llm_start_prompt_value(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
prompts: List[PromptValue],
|
||||
verbose: bool = False,
|
||||
**kwargs: Any
|
||||
) -> Any:
|
||||
"""Run when LLM starts running."""
|
||||
for handler in self.handlers:
|
||||
if not handler.ignore_llm:
|
||||
if verbose or handler.always_verbose:
|
||||
if asyncio.iscoroutinefunction(handler.on_llm_start_prompt_value):
|
||||
return await handler.on_llm_start_prompt_value(
|
||||
serialized, prompts, **kwargs
|
||||
)
|
||||
else:
|
||||
return await asyncio.get_event_loop().run_in_executor(
|
||||
None,
|
||||
functools.partial(
|
||||
handler.on_llm_start_prompt_value,
|
||||
serialized,
|
||||
prompts,
|
||||
**kwargs
|
||||
),
|
||||
)
|
||||
|
||||
async def on_llm_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
|
||||
@@ -8,7 +8,7 @@ from langchain.callbacks.base import (
|
||||
BaseCallbackManager,
|
||||
CallbackManager,
|
||||
)
|
||||
from langchain.schema import AgentAction, AgentFinish, LLMResult
|
||||
from langchain.schema import AgentAction, AgentFinish, LLMResult, PromptValue
|
||||
|
||||
|
||||
class Singleton:
|
||||
@@ -34,6 +34,15 @@ class SharedCallbackManager(Singleton, BaseCallbackManager):
|
||||
|
||||
_callback_manager: CallbackManager = CallbackManager(handlers=[])
|
||||
|
||||
def on_llm_start_prompt_value(
|
||||
self, serialized: Dict[str, Any], prompts: List[PromptValue], **kwargs: Any
|
||||
) -> Any:
|
||||
"""Run when LLM starts running."""
|
||||
with self._lock:
|
||||
self._callback_manager.on_llm_start_prompt_value(
|
||||
serialized, prompts, **kwargs
|
||||
)
|
||||
|
||||
def on_llm_start(
|
||||
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
|
||||
) -> None:
|
||||
|
||||
@@ -16,7 +16,7 @@ from langchain.callbacks.tracers.schemas import (
|
||||
TracerSession,
|
||||
TracerSessionCreate,
|
||||
)
|
||||
from langchain.schema import AgentAction, AgentFinish, LLMResult
|
||||
from langchain.schema import AgentAction, AgentFinish, LLMResult, PromptValue
|
||||
|
||||
|
||||
class TracerException(Exception):
|
||||
@@ -109,9 +109,9 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
||||
self._execution_order = 1
|
||||
self._persist_run(run)
|
||||
|
||||
def on_llm_start(
|
||||
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
|
||||
) -> None:
|
||||
def on_llm_start_prompt_value(
|
||||
self, serialized: Dict[str, Any], prompts: List[PromptValue], **kwargs: Any
|
||||
) -> Any:
|
||||
"""Start a trace for an LLM run."""
|
||||
if self._session is None:
|
||||
raise TracerException(
|
||||
@@ -129,6 +129,12 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
||||
)
|
||||
self._start_trace(llm_run)
|
||||
|
||||
def on_llm_start(
|
||||
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
|
||||
) -> None:
|
||||
"""Start a trace for an LLM run."""
|
||||
pass
|
||||
|
||||
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
|
||||
"""Handle a new token for an LLM run."""
|
||||
pass
|
||||
|
||||
@@ -6,7 +6,7 @@ from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from langchain.schema import LLMResult
|
||||
from langchain.schema import LLMResult, PromptValue
|
||||
|
||||
|
||||
class TracerSessionBase(BaseModel):
|
||||
@@ -45,7 +45,7 @@ class BaseRun(BaseModel):
|
||||
class LLMRun(BaseRun):
|
||||
"""Class for LLMRun."""
|
||||
|
||||
prompts: List[str]
|
||||
prompts: List[PromptValue]
|
||||
response: Optional[LLMResult] = None
|
||||
|
||||
|
||||
|
||||
@@ -6,13 +6,13 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from langchain.base_language_model import BaseLanguageModel
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.chat_vector_db.prompts import CONDENSE_QUESTION_PROMPT
|
||||
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.chains.question_answering import load_qa_chain
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
from langchain.schema import BaseLanguageModel
|
||||
from langchain.vectorstores.base import VectorStore
|
||||
|
||||
|
||||
|
||||
@@ -5,11 +5,12 @@ from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
|
||||
|
||||
from pydantic import BaseModel, Extra
|
||||
|
||||
from langchain.base_language_model import BaseLanguageModel
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.input import get_colored_text
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
from langchain.prompts.prompt import PromptTemplate
|
||||
from langchain.schema import BaseLanguageModel, LLMResult, PromptValue
|
||||
from langchain.schema import LLMResult, PromptValue
|
||||
|
||||
|
||||
class LLMChain(Chain, BaseModel):
|
||||
|
||||
@@ -3,10 +3,10 @@ from typing import Callable, List, Tuple
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from langchain.base_language_model import BaseLanguageModel
|
||||
from langchain.chat_models.base import BaseChatModel
|
||||
from langchain.llms.base import BaseLLM
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
from langchain.schema import BaseLanguageModel
|
||||
|
||||
|
||||
class BasePromptSelector(BaseModel, ABC):
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""Load question answering chains."""
|
||||
from typing import Any, Mapping, Optional, Protocol
|
||||
|
||||
from langchain.base_language_model import BaseLanguageModel
|
||||
from langchain.callbacks.base import BaseCallbackManager
|
||||
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
|
||||
from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain
|
||||
@@ -15,7 +16,6 @@ from langchain.chains.question_answering import (
|
||||
stuff_prompt,
|
||||
)
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
from langchain.schema import BaseLanguageModel
|
||||
|
||||
|
||||
class LoadingCallable(Protocol):
|
||||
|
||||
@@ -1,14 +1,12 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Optional
|
||||
|
||||
from pydantic import BaseModel, Extra, Field, validator
|
||||
from pydantic import BaseModel
|
||||
|
||||
import langchain
|
||||
from langchain.callbacks import get_callback_manager
|
||||
from langchain.callbacks.base import BaseCallbackManager
|
||||
from langchain.base_language_model import BaseLanguageModel
|
||||
from langchain.schema import (
|
||||
AIMessage,
|
||||
BaseLanguageModel,
|
||||
BaseMessage,
|
||||
ChatGeneration,
|
||||
ChatResult,
|
||||
@@ -22,25 +20,7 @@ def _get_verbosity() -> bool:
|
||||
|
||||
|
||||
class BaseChatModel(BaseLanguageModel, BaseModel, ABC):
|
||||
verbose: bool = Field(default_factory=_get_verbosity)
|
||||
"""Whether to print out response text."""
|
||||
callback_manager: BaseCallbackManager = Field(default_factory=get_callback_manager)
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
extra = Extra.forbid
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@validator("callback_manager", pre=True, always=True)
|
||||
def set_callback_manager(
|
||||
cls, callback_manager: Optional[BaseCallbackManager]
|
||||
) -> BaseCallbackManager:
|
||||
"""If callback manager is None, set it.
|
||||
|
||||
This allows users to pass in None as callback manager, which is a nice UX.
|
||||
"""
|
||||
return callback_manager or get_callback_manager()
|
||||
"""Base class for chat models."""
|
||||
|
||||
def generate(
|
||||
self, messages: List[List[BaseMessage]], stop: Optional[List[str]] = None
|
||||
@@ -56,13 +36,13 @@ class BaseChatModel(BaseLanguageModel, BaseModel, ABC):
|
||||
results = [await self._agenerate(m, stop=stop) for m in messages]
|
||||
return LLMResult(generations=[res.generations for res in results])
|
||||
|
||||
def generate_prompt(
|
||||
def _generate_prompt(
|
||||
self, prompts: List[PromptValue], stop: Optional[List[str]] = None
|
||||
) -> LLMResult:
|
||||
prompt_messages = [p.to_messages() for p in prompts]
|
||||
return self.generate(prompt_messages, stop=stop)
|
||||
|
||||
async def agenerate_prompt(
|
||||
async def _agenerate_prompt(
|
||||
self, prompts: List[PromptValue], stop: Optional[List[str]] = None
|
||||
) -> LLMResult:
|
||||
prompt_messages = [p.to_messages() for p in prompts]
|
||||
|
||||
@@ -5,16 +5,11 @@ from pathlib import Path
|
||||
from typing import Any, Dict, List, Mapping, Optional, Tuple, Union
|
||||
|
||||
import yaml
|
||||
from pydantic import BaseModel, Extra, Field, validator
|
||||
from pydantic import BaseModel, Extra
|
||||
|
||||
import langchain
|
||||
from langchain.callbacks import get_callback_manager
|
||||
from langchain.callbacks.base import BaseCallbackManager
|
||||
from langchain.schema import BaseLanguageModel, Generation, LLMResult, PromptValue
|
||||
|
||||
|
||||
def _get_verbosity() -> bool:
|
||||
return langchain.verbose
|
||||
from langchain.base_language_model import BaseLanguageModel
|
||||
from langchain.schema import Generation, LLMResult, PromptValue
|
||||
|
||||
|
||||
def get_prompts(
|
||||
@@ -57,9 +52,6 @@ class BaseLLM(BaseLanguageModel, BaseModel, ABC):
|
||||
"""LLM wrapper should take in a prompt and return a string."""
|
||||
|
||||
cache: Optional[bool] = None
|
||||
verbose: bool = Field(default_factory=_get_verbosity)
|
||||
"""Whether to print out response text."""
|
||||
callback_manager: BaseCallbackManager = Field(default_factory=get_callback_manager)
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
@@ -67,27 +59,6 @@ class BaseLLM(BaseLanguageModel, BaseModel, ABC):
|
||||
extra = Extra.forbid
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@validator("callback_manager", pre=True, always=True)
|
||||
def set_callback_manager(
|
||||
cls, callback_manager: Optional[BaseCallbackManager]
|
||||
) -> BaseCallbackManager:
|
||||
"""If callback manager is None, set it.
|
||||
|
||||
This allows users to pass in None as callback manager, which is a nice UX.
|
||||
"""
|
||||
return callback_manager or get_callback_manager()
|
||||
|
||||
@validator("verbose", pre=True, always=True)
|
||||
def set_verbose(cls, verbose: Optional[bool]) -> bool:
|
||||
"""If verbose is None, set it.
|
||||
|
||||
This allows users to pass in None as verbose to access the global setting.
|
||||
"""
|
||||
if verbose is None:
|
||||
return _get_verbosity()
|
||||
else:
|
||||
return verbose
|
||||
|
||||
@abstractmethod
|
||||
def _generate(
|
||||
self, prompts: List[str], stop: Optional[List[str]] = None
|
||||
@@ -100,13 +71,13 @@ class BaseLLM(BaseLanguageModel, BaseModel, ABC):
|
||||
) -> LLMResult:
|
||||
"""Run the LLM on the given prompts."""
|
||||
|
||||
def generate_prompt(
|
||||
def _generate_prompt(
|
||||
self, prompts: List[PromptValue], stop: Optional[List[str]] = None
|
||||
) -> LLMResult:
|
||||
prompt_strings = [p.to_string() for p in prompts]
|
||||
return self.generate(prompt_strings, stop=stop)
|
||||
|
||||
async def agenerate_prompt(
|
||||
async def _agenerate_prompt(
|
||||
self, prompts: List[PromptValue], stop: Optional[List[str]] = None
|
||||
) -> LLMResult:
|
||||
prompt_strings = [p.to_string() for p in prompts]
|
||||
@@ -138,7 +109,6 @@ class BaseLLM(BaseLanguageModel, BaseModel, ABC):
|
||||
except (KeyboardInterrupt, Exception) as e:
|
||||
self.callback_manager.on_llm_error(e, verbose=self.verbose)
|
||||
raise e
|
||||
self.callback_manager.on_llm_end(output, verbose=self.verbose)
|
||||
return output
|
||||
params = self.dict()
|
||||
params["stop"] = stop
|
||||
@@ -157,7 +127,6 @@ class BaseLLM(BaseLanguageModel, BaseModel, ABC):
|
||||
except (KeyboardInterrupt, Exception) as e:
|
||||
self.callback_manager.on_llm_error(e, verbose=self.verbose)
|
||||
raise e
|
||||
self.callback_manager.on_llm_end(new_results, verbose=self.verbose)
|
||||
llm_output = update_cache(
|
||||
existing_prompts, llm_string, missing_prompt_idxs, new_results, prompts
|
||||
)
|
||||
|
||||
@@ -2,6 +2,7 @@ from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from langchain.base_language_model import BaseLanguageModel
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.memory.chat_memory import BaseChatMemory
|
||||
from langchain.memory.prompt import (
|
||||
@@ -10,7 +11,7 @@ from langchain.memory.prompt import (
|
||||
)
|
||||
from langchain.memory.utils import get_buffer_string, get_prompt_input_key
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
from langchain.schema import BaseLanguageModel, BaseMessage
|
||||
from langchain.schema import BaseMessage
|
||||
|
||||
|
||||
class ConversationEntityMemory(BaseChatMemory, BaseModel):
|
||||
|
||||
@@ -2,6 +2,7 @@ from typing import Any, Dict, List
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from langchain.base_language_model import BaseLanguageModel
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.graphs import NetworkxEntityGraph
|
||||
from langchain.graphs.networkx_graph import KnowledgeTriple, get_entities, parse_triples
|
||||
@@ -12,7 +13,7 @@ from langchain.memory.prompt import (
|
||||
)
|
||||
from langchain.memory.utils import get_buffer_string, get_prompt_input_key
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
from langchain.schema import BaseLanguageModel, SystemMessage
|
||||
from langchain.schema import SystemMessage
|
||||
|
||||
|
||||
class ConversationKGMemory(BaseChatMemory, BaseModel):
|
||||
|
||||
@@ -2,12 +2,13 @@ from typing import Any, Dict, List
|
||||
|
||||
from pydantic import BaseModel, root_validator
|
||||
|
||||
from langchain.base_language_model import BaseLanguageModel
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.memory.chat_memory import BaseChatMemory
|
||||
from langchain.memory.prompt import SUMMARY_PROMPT
|
||||
from langchain.memory.utils import get_buffer_string
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
from langchain.schema import BaseLanguageModel, BaseMessage, SystemMessage
|
||||
from langchain.schema import BaseMessage, SystemMessage
|
||||
|
||||
|
||||
class SummarizerMixin(BaseModel):
|
||||
|
||||
@@ -44,19 +44,26 @@ class BaseMessage(BaseModel):
|
||||
class HumanMessage(BaseMessage):
|
||||
"""Type of message that is spoken by the human."""
|
||||
|
||||
_type = "human"
|
||||
|
||||
|
||||
class AIMessage(BaseMessage):
|
||||
"""Type of message that is spoken by the AI."""
|
||||
|
||||
_type = "ai"
|
||||
|
||||
|
||||
class SystemMessage(BaseMessage):
|
||||
"""Type of message that is a system message."""
|
||||
|
||||
_type = "system"
|
||||
|
||||
|
||||
class ChatMessage(BaseMessage):
|
||||
"""Type of message with arbitrary speaker."""
|
||||
|
||||
role: str
|
||||
_type = "chat"
|
||||
|
||||
|
||||
class ChatGeneration(Generation):
|
||||
@@ -100,41 +107,6 @@ class PromptValue(BaseModel, ABC):
|
||||
"""Return prompt as messages."""
|
||||
|
||||
|
||||
class BaseLanguageModel(BaseModel, ABC):
|
||||
@abstractmethod
|
||||
def generate_prompt(
|
||||
self, prompts: List[PromptValue], stop: Optional[List[str]] = None
|
||||
) -> LLMResult:
|
||||
"""Take in a list of prompt values and return an LLMResult."""
|
||||
|
||||
@abstractmethod
|
||||
async def agenerate_prompt(
|
||||
self, prompts: List[PromptValue], stop: Optional[List[str]] = None
|
||||
) -> LLMResult:
|
||||
"""Take in a list of prompt values and return an LLMResult."""
|
||||
|
||||
def get_num_tokens(self, text: str) -> int:
|
||||
"""Get the number of tokens present in the text."""
|
||||
# TODO: this method may not be exact.
|
||||
# TODO: this method may differ based on model (eg codex).
|
||||
try:
|
||||
from transformers import GPT2TokenizerFast
|
||||
except ImportError:
|
||||
raise ValueError(
|
||||
"Could not import transformers python package. "
|
||||
"This is needed in order to calculate get_num_tokens. "
|
||||
"Please it install it with `pip install transformers`."
|
||||
)
|
||||
# create a GPT-3 tokenizer instance
|
||||
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
|
||||
|
||||
# tokenize the text using the GPT-3 tokenizer
|
||||
tokenized_text = tokenizer.tokenize(text)
|
||||
|
||||
# calculate the number of tokens in the tokenized text
|
||||
return len(tokenized_text)
|
||||
|
||||
|
||||
class BaseMemory(BaseModel, ABC):
|
||||
"""Base interface for memory in chains."""
|
||||
|
||||
|
||||
@@ -85,10 +85,10 @@ def _perform_nested_run(tracer: BaseTracer) -> None:
|
||||
"""Perform a nested run."""
|
||||
tracer.on_chain_start(serialized={}, inputs={})
|
||||
tracer.on_tool_start(serialized={}, input_str="test")
|
||||
tracer.on_llm_start(serialized={}, prompts=[])
|
||||
tracer.on_llm_start_prompt_value(serialized={}, prompts=[])
|
||||
tracer.on_llm_end(response=LLMResult(generations=[[]]))
|
||||
tracer.on_tool_end("test")
|
||||
tracer.on_llm_start(serialized={}, prompts=[])
|
||||
tracer.on_llm_start_prompt_value(serialized={}, prompts=[])
|
||||
tracer.on_llm_end(response=LLMResult(generations=[[]]))
|
||||
tracer.on_chain_end(outputs={})
|
||||
|
||||
@@ -216,7 +216,7 @@ def test_tracer_llm_run() -> None:
|
||||
tracer = FakeTracer()
|
||||
|
||||
tracer.new_session()
|
||||
tracer.on_llm_start(serialized={}, prompts=[])
|
||||
tracer.on_llm_start_prompt_value(serialized={}, prompts=[])
|
||||
tracer.on_llm_end(response=LLMResult(generations=[[]]))
|
||||
assert tracer.runs == [compare_run]
|
||||
|
||||
@@ -227,7 +227,7 @@ def test_tracer_llm_run_errors_no_session() -> None:
|
||||
tracer = FakeTracer()
|
||||
|
||||
with pytest.raises(TracerException):
|
||||
tracer.on_llm_start(serialized={}, prompts=[])
|
||||
tracer.on_llm_start_prompt_value(serialized={}, prompts=[])
|
||||
|
||||
|
||||
@freeze_time("2023-01-01")
|
||||
@@ -260,7 +260,7 @@ def test_tracer_multiple_llm_runs() -> None:
|
||||
tracer.new_session()
|
||||
num_runs = 10
|
||||
for _ in range(num_runs):
|
||||
tracer.on_llm_start(serialized={}, prompts=[])
|
||||
tracer.on_llm_start_prompt_value(serialized={}, prompts=[])
|
||||
tracer.on_llm_end(response=LLMResult(generations=[[]]))
|
||||
|
||||
assert tracer.runs == [compare_run] * num_runs
|
||||
@@ -342,7 +342,7 @@ def test_tracer_llm_run_on_error() -> None:
|
||||
tracer = FakeTracer()
|
||||
|
||||
tracer.new_session()
|
||||
tracer.on_llm_start(serialized={}, prompts=[])
|
||||
tracer.on_llm_start_prompt_value(serialized={}, prompts=[])
|
||||
tracer.on_llm_error(exception)
|
||||
assert tracer.runs == [compare_run]
|
||||
|
||||
@@ -408,12 +408,12 @@ def test_tracer_nested_runs_on_error() -> None:
|
||||
|
||||
for _ in range(3):
|
||||
tracer.on_chain_start(serialized={}, inputs={})
|
||||
tracer.on_llm_start(serialized={}, prompts=[])
|
||||
tracer.on_llm_start_prompt_value(serialized={}, prompts=[])
|
||||
tracer.on_llm_end(response=LLMResult(generations=[[]]))
|
||||
tracer.on_llm_start(serialized={}, prompts=[])
|
||||
tracer.on_llm_start_prompt_value(serialized={}, prompts=[])
|
||||
tracer.on_llm_end(response=LLMResult(generations=[[]]))
|
||||
tracer.on_tool_start(serialized={}, input_str="test")
|
||||
tracer.on_llm_start(serialized={}, prompts=[])
|
||||
tracer.on_llm_start_prompt_value(serialized={}, prompts=[])
|
||||
tracer.on_llm_error(exception)
|
||||
tracer.on_tool_error(exception)
|
||||
tracer.on_chain_error(exception)
|
||||
|
||||
Reference in New Issue
Block a user