Merge branch 'master' into harrison/chain_pipeline

Harrison Chase, 2022-11-19 09:34:05 -08:00
commit 3fcc803880
31 changed files with 300 additions and 191 deletions

View File

@@ -1,6 +1,6 @@
 name: lint
 
-on: [push, pull_request_target]
+on: [push, pull_request]
 
 jobs:
   build:

View File

@@ -1,6 +1,6 @@
 name: test
 
-on: [push, pull_request_target]
+on: [push, pull_request]
 
 jobs:
   build:

View File

@@ -37,7 +37,7 @@ This project was largely inspired by a few projects seen on Twitter for which we
 
 **[Self-ask-with-search](https://ofir.io/self-ask.pdf)**
 
-To recreate this paper, use the following code snippet or checkout the [example notebook](https://github.com/hwchase17/langchain/blob/master/examples/self_ask_with_search.ipynb).
+To recreate this paper, use the following code snippet or checkout the [example notebook](https://github.com/hwchase17/langchain/blob/master/docs/examples/demos/self_ask_with_search.ipynb).
 
 ```python
 from langchain import SelfAskWithSearchChain, OpenAI, SerpAPIChain
@@ -52,7 +52,7 @@ self_ask_with_search.run("What is the hometown of the reigning men's U.S. Open c
 
 **[LLM Math](https://twitter.com/amasad/status/1568824744367259648?s=20&t=-7wxpXBJinPgDuyHLouP1w)**
 
-To recreate this example, use the following code snippet or check out the [example notebook](https://github.com/hwchase17/langchain/blob/master/examples/llm_math.ipynb).
+To recreate this example, use the following code snippet or check out the [example notebook](https://github.com/hwchase17/langchain/blob/master/docs/examples/demos/llm_math.ipynb).
 
 ```python
 from langchain import OpenAI, LLMMathChain
@@ -65,7 +65,7 @@ llm_math.run("How many of the integers between 0 and 99 inclusive are divisible
 
 **Generic Prompting**
 
-You can also use this for simple prompting pipelines, as in the below example and this [example notebook](https://github.com/hwchase17/langchain/blob/master/examples/simple_prompts.ipynb).
+You can also use this for simple prompting pipelines, as in the below example and this [example notebook](https://github.com/hwchase17/langchain/blob/master/docs/examples/demos/simple_prompts.ipynb).
 
 ```python
 from langchain import Prompt, OpenAI, LLMChain
@@ -84,7 +84,7 @@ llm_chain.predict(question=question)
 
 **Embed & Search Documents**
 
-We support two vector databases to store and search embeddings -- FAISS and Elasticsearch. Here's a code snippet showing how to use FAISS to store embeddings and search for text similar to a query. Both database backends are featured in this [example notebook](https://github.com/hwchase17/langchain/blob/master/examples/embeddings.ipynb).
+We support two vector databases to store and search embeddings -- FAISS and Elasticsearch. Here's a code snippet showing how to use FAISS to store embeddings and search for text similar to a query. Both database backends are featured in this [example notebook](https://github.com/hwchase17/langchain/blob/master/docs/examples/integrations/embeddings.ipynb).
 
 ```python
 from langchain.embeddings.openai import OpenAIEmbeddings

View File

@@ -42,7 +42,7 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"model_lab = ModelLaboratory(llms)"
+"model_lab = ModelLaboratory.from_llms(llms)"
 ]
 },
 {
@@ -60,19 +60,19 @@
 "\n",
 "\u001b[1mOpenAI\u001b[0m\n",
 "Params: {'model': 'text-davinci-002', 'temperature': 0.0, 'max_tokens': 256, 'top_p': 1, 'frequency_penalty': 0, 'presence_penalty': 0, 'n': 1, 'best_of': 1}\n",
-"\u001b[104m\n",
+"\u001b[36;1m\u001b[1;3m\n",
 "\n",
 "Flamingos are pink.\u001b[0m\n",
 "\n",
 "\u001b[1mCohere\u001b[0m\n",
 "Params: {'model': 'command-xlarge-20221108', 'max_tokens': 20, 'temperature': 0.0, 'k': 0, 'p': 1, 'frequency_penalty': 0, 'presence_penalty': 0}\n",
-"\u001b[103m\n",
+"\u001b[33;1m\u001b[1;3m\n",
 "\n",
 "Pink\u001b[0m\n",
 "\n",
 "\u001b[1mHuggingFaceHub\u001b[0m\n",
 "Params: {'repo_id': 'google/flan-t5-xl', 'temperature': 1}\n",
-"\u001b[101mpink\u001b[0m\n",
+"\u001b[38;5;200m\u001b[1;3mpink\u001b[0m\n",
 "\n"
 ]
 }
@@ -89,7 +89,7 @@
 "outputs": [],
 "source": [
 "prompt = Prompt(template=\"What is the capital of {state}?\", input_variables=[\"state\"])\n",
-"model_lab_with_prompt = ModelLaboratory(llms, prompt=prompt)"
+"model_lab_with_prompt = ModelLaboratory.from_llms(llms, prompt=prompt)"
 ]
 },
 {
@@ -107,19 +107,19 @@
 "\n",
 "\u001b[1mOpenAI\u001b[0m\n",
 "Params: {'model': 'text-davinci-002', 'temperature': 0.0, 'max_tokens': 256, 'top_p': 1, 'frequency_penalty': 0, 'presence_penalty': 0, 'n': 1, 'best_of': 1}\n",
-"\u001b[104m\n",
+"\u001b[36;1m\u001b[1;3m\n",
 "\n",
 "The capital of New York is Albany.\u001b[0m\n",
 "\n",
 "\u001b[1mCohere\u001b[0m\n",
 "Params: {'model': 'command-xlarge-20221108', 'max_tokens': 20, 'temperature': 0.0, 'k': 0, 'p': 1, 'frequency_penalty': 0, 'presence_penalty': 0}\n",
-"\u001b[103m\n",
+"\u001b[33;1m\u001b[1;3m\n",
 "\n",
 "The capital of New York is Albany.\u001b[0m\n",
 "\n",
 "\u001b[1mHuggingFaceHub\u001b[0m\n",
 "Params: {'repo_id': 'google/flan-t5-xl', 'temperature': 1}\n",
-"\u001b[101mst john s\u001b[0m\n",
+"\u001b[38;5;200m\u001b[1;3mst john s\u001b[0m\n",
 "\n"
 ]
 }
@@ -130,10 +130,103 @@
 },
 {
 "cell_type": "code",
-"execution_count": null,
+"execution_count": 7,
 "id": "54336dbf",
 "metadata": {},
 "outputs": [],
+"source": [
+"from langchain import SelfAskWithSearchChain, SerpAPIChain\n",
+"\n",
+"open_ai_llm = OpenAI(temperature=0)\n",
+"search = SerpAPIChain()\n",
+"self_ask_with_search_openai = SelfAskWithSearchChain(llm=open_ai_llm, search_chain=search, verbose=True)\n",
+"\n",
+"cohere_llm = Cohere(temperature=0, model=\"command-xlarge-20221108\")\n",
+"search = SerpAPIChain()\n",
+"self_ask_with_search_cohere = SelfAskWithSearchChain(llm=cohere_llm, search_chain=search, verbose=True)"
+]
+},
+{
+"cell_type": "code",
+"execution_count": 8,
+"id": "6a50a9f1",
+"metadata": {},
+"outputs": [],
+"source": [
+"chains = [self_ask_with_search_openai, self_ask_with_search_cohere]\n",
+"names = [str(open_ai_llm), str(cohere_llm)]"
+]
+},
+{
+"cell_type": "code",
+"execution_count": 9,
+"id": "d3549e99",
+"metadata": {},
+"outputs": [],
+"source": [
+"model_lab = ModelLaboratory(chains, names=names)"
+]
+},
+{
+"cell_type": "code",
+"execution_count": 10,
+"id": "362f7f57",
+"metadata": {},
+"outputs": [
+{
+"name": "stdout",
+"output_type": "stream",
+"text": [
+"\u001b[1mInput:\u001b[0m\n",
+"What is the hometown of the reigning men's U.S. Open champion?\n",
+"\n",
+"\u001b[1mOpenAI\u001b[0m\n",
+"Params: {'model': 'text-davinci-002', 'temperature': 0.0, 'max_tokens': 256, 'top_p': 1, 'frequency_penalty': 0, 'presence_penalty': 0, 'n': 1, 'best_of': 1}\n",
+"\n",
+"\n",
+"\u001b[1m> Entering new chain...\u001b[0m\n",
+"What is the hometown of the reigning men's U.S. Open champion?\n",
+"Are follow up questions needed here:\u001b[32;1m\u001b[1;3m Yes.\n",
+"Follow up: Who is the reigning men's U.S. Open champion?\u001b[0m\n",
+"Intermediate answer: \u001b[33;1m\u001b[1;3mCarlos Alcaraz.\u001b[0m\u001b[32;1m\u001b[1;3m\n",
+"Follow up: Where is Carlos Alcaraz from?\u001b[0m\n",
+"Intermediate answer: \u001b[33;1m\u001b[1;3mEl Palmar, Spain.\u001b[0m\u001b[32;1m\u001b[1;3m\n",
+"So the final answer is: El Palmar, Spain\u001b[0m\n",
+"\u001b[1m> Finished chain.\u001b[0m\n",
+"\u001b[36;1m\u001b[1;3m\n",
+"So the final answer is: El Palmar, Spain\u001b[0m\n",
+"\n",
+"\u001b[1mCohere\u001b[0m\n",
+"Params: {'model': 'command-xlarge-20221108', 'max_tokens': 256, 'temperature': 0.0, 'k': 0, 'p': 1, 'frequency_penalty': 0, 'presence_penalty': 0}\n",
+"\n",
+"\n",
+"\u001b[1m> Entering new chain...\u001b[0m\n",
+"What is the hometown of the reigning men's U.S. Open champion?\n",
+"Are follow up questions needed here:\u001b[32;1m\u001b[1;3m Yes.\n",
+"Follow up: Who is the reigning men's U.S. Open champion?\u001b[0m\n",
+"Intermediate answer: \u001b[33;1m\u001b[1;3mCarlos Alcaraz.\u001b[0m\u001b[32;1m\u001b[1;3m\n",
+"So the final answer is:\n",
+"\n",
+"Carlos Alcaraz\u001b[0m\n",
+"\u001b[1m> Finished chain.\u001b[0m\n",
+"\u001b[33;1m\u001b[1;3m\n",
+"So the final answer is:\n",
+"\n",
+"Carlos Alcaraz\u001b[0m\n",
+"\n"
+]
+}
+],
+"source": [
+"model_lab.compare(\"What is the hometown of the reigning men's U.S. Open champion?\")"
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"id": "94159131",
+"metadata": {},
+"outputs": [],
 "source": []
 }
 ],

View File

@@ -1 +1 @@
-0.0.13
+0.0.16

View File

@@ -9,7 +9,7 @@ class Chain(BaseModel, ABC):
     """Base interface that all chains should implement."""
 
     verbose: bool = False
-    """Whether to print out the code that was executed."""
+    """Whether to print out response text."""
 
     @property
     @abstractmethod
@@ -49,6 +49,10 @@ class Chain(BaseModel, ABC):
         self._validate_outputs(outputs)
         return {**inputs, **outputs}
 
+    def apply(self, input_list: List[Dict[str, Any]]) -> List[Dict[str, str]]:
+        """Call the chain on all inputs in the list."""
+        return [self(inputs) for inputs in input_list]
+
     def run(self, text: str) -> str:
         """Run text in, text out (if applicable)."""
         if len(self.input_keys) != 1:
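For orientation, a minimal sketch of the new `apply` method in use; the prompt and chain below are illustrative, not part of this commit:

```python
from langchain import LLMChain, OpenAI, Prompt

# Illustrative one-input chain; apply() works the same way on any Chain subclass.
prompt = Prompt(
    input_variables=["product"],
    template="What is a good name for a company that makes {product}?",
)
chain = LLMChain(llm=OpenAI(temperature=0.9), prompt=prompt)

# One __call__ per input dict; each result contains the inputs plus the chain's outputs.
results = chain.apply([{"product": "shoes"}, {"product": "socks"}])
```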

View File

@@ -59,16 +59,13 @@ class MapReduceChain(Chain, BaseModel):
     def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
         # Split the larger text into smaller chunks.
-        docs = self.text_splitter.split_text(
-            inputs[self.input_key],
-        )
+        docs = self.text_splitter.split_text(inputs[self.input_key])
 
         # Now that we have the chunks, we send them to the LLM and track results.
         # This is the "map" part.
-        summaries = []
-        for d in docs:
-            inputs = {self.map_llm.prompt.input_variables[0]: d}
-            res = self.map_llm.predict(**inputs)
-            summaries.append(res)
+        input_list = [{self.map_llm.prompt.input_variables[0]: d} for d in docs]
+        summary_results = self.map_llm.apply(input_list)
+        summaries = [res[self.map_llm.output_key] for res in summary_results]
 
         # We then need to combine these individual parts into one.
         # This is the reduce part.

View File

@@ -28,14 +28,7 @@ class Crawler:
                 "Could not import playwright python package. "
                 "Please it install it with `pip install playwright`."
             )
-        self.browser = (
-            sync_playwright()
-            .start()
-            .chromium.launch(
-                headless=False,
-            )
-        )
+        self.browser = sync_playwright().start().chromium.launch(headless=False)
         self.page = self.browser.new_page()
         self.page.set_viewport_size({"width": 1280, "height": 1080})

View File

@@ -109,8 +109,4 @@ Action 3: Finish[yes]""",
 ]
 
 SUFFIX = """\n\nQuestion: {input}"""
-PROMPT = Prompt.from_examples(
-    EXAMPLES,
-    SUFFIX,
-    ["input"],
-)
+PROMPT = Prompt.from_examples(EXAMPLES, SUFFIX, ["input"])

View File

@@ -38,7 +38,4 @@ Intermediate Answer: New Zealand.
 So the final answer is: No
 
 Question: {input}"""
-PROMPT = Prompt(
-    input_variables=["input"],
-    template=_DEFAULT_TEMPLATE,
-)
+PROMPT = Prompt(input_variables=["input"], template=_DEFAULT_TEMPLATE)

View File

@@ -9,6 +9,7 @@ from typing import Any, Dict, List, Optional
 from pydantic import BaseModel, Extra, root_validator
 
 from langchain.chains.base import Chain
+from langchain.utils import get_from_dict_or_env
 
 
 class HiddenPrints:
@@ -43,7 +44,7 @@ class SerpAPIChain(Chain, BaseModel):
     input_key: str = "search_query"  #: :meta private:
     output_key: str = "search_result"  #: :meta private:
 
-    serpapi_api_key: Optional[str] = os.environ.get("SERPAPI_API_KEY")
+    serpapi_api_key: Optional[str] = None
 
     class Config:
         """Configuration for this pydantic object."""
@@ -69,14 +70,10 @@ class SerpAPIChain(Chain, BaseModel):
     @root_validator()
     def validate_environment(cls, values: Dict) -> Dict:
         """Validate that api key and python package exists in environment."""
-        serpapi_api_key = values.get("serpapi_api_key")
-        if serpapi_api_key is None or serpapi_api_key == "":
-            raise ValueError(
-                "Did not find SerpAPI API key, please add an environment variable"
-                " `SERPAPI_API_KEY` which contains it, or pass `serpapi_api_key` "
-                "as a named parameter to the constructor."
-            )
+        serpapi_api_key = get_from_dict_or_env(
+            values, "serpapi_api_key", "SERPAPI_API_KEY"
+        )
+        values["serpapi_api_key"] = serpapi_api_key
         try:
             from serpapi import GoogleSearch

View File

@@ -15,6 +15,5 @@ Only use the following tables:
 
 Question: {input}"""
 PROMPT = Prompt(
-    input_variables=["input", "table_info", "dialect"],
-    template=_DEFAULT_TEMPLATE,
+    input_variables=["input", "table_info", "dialect"], template=_DEFAULT_TEMPLATE
 )

View File

@@ -27,6 +27,8 @@ class VectorDBQA(Chain, BaseModel):
     """LLM wrapper to use."""
     vectorstore: VectorStore
     """Vector Database to connect to."""
+    k: int = 4
+    """Number of documents to query for."""
     input_key: str = "query"  #: :meta private:
     output_key: str = "result"  #: :meta private:
@@ -55,7 +57,7 @@ class VectorDBQA(Chain, BaseModel):
     def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
         question = inputs[self.input_key]
         llm_chain = LLMChain(llm=self.llm, prompt=prompt)
-        docs = self.vectorstore.similarity_search(question)
+        docs = self.vectorstore.similarity_search(question, k=self.k)
         contexts = []
         for j, doc in enumerate(docs):
             contexts.append(f"Context {j}:\n{doc.page_content}")
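A hedged sketch of the new `k` knob; the vector store setup and import paths below are assumptions for this era of the codebase, not shown in this diff:

```python
from langchain import OpenAI
from langchain.chains.vector_db_qa.base import VectorDBQA
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.vectorstores.faiss import FAISS

# Placeholder corpus; any VectorStore implementation should behave the same.
docsearch = FAISS.from_texts(["Harrison worked at Kensho."], OpenAIEmbeddings())

# k controls how many similar documents get pulled into the prompt (default 4).
qa = VectorDBQA(llm=OpenAI(temperature=0), vectorstore=docsearch, k=1)
qa.run("Where did Harrison work?")
```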

View File

@@ -1,7 +1,7 @@
 """Interface for interacting with a document."""
 from typing import List
 
-from pydantic import BaseModel
+from pydantic import BaseModel, Field
 
 
 class Document(BaseModel):
@@ -10,6 +10,7 @@ class Document(BaseModel):
     page_content: str
     lookup_str: str = ""
     lookup_index = 0
+    metadata: dict = Field(default_factory=dict)
 
     @property
     def paragraphs(self) -> List[str]:
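A one-line sketch of the new field; the metadata key below is illustrative:

```python
from langchain.docstore.document import Document

# default_factory gives every Document its own empty dict instead of a shared
# mutable default; callers can attach arbitrary metadata alongside the content.
doc = Document(page_content="Flamingos are pink.", metadata={"source": "birds.txt"})
```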

View File

@@ -1,10 +1,10 @@
 """Wrapper around Cohere embedding models."""
-import os
 from typing import Any, Dict, List, Optional
 
 from pydantic import BaseModel, Extra, root_validator
 
 from langchain.embeddings.base import Embeddings
+from langchain.utils import get_from_dict_or_env
 
 
 class CohereEmbeddings(BaseModel, Embeddings):
@@ -25,7 +25,7 @@ class CohereEmbeddings(BaseModel, Embeddings):
     model: str = "medium"
     """Model name to use."""
 
-    cohere_api_key: Optional[str] = os.environ.get("COHERE_API_KEY")
+    cohere_api_key: Optional[str] = None
 
     class Config:
         """Configuration for this pydantic object."""
@@ -35,14 +35,9 @@ class CohereEmbeddings(BaseModel, Embeddings):
     @root_validator()
     def validate_environment(cls, values: Dict) -> Dict:
         """Validate that api key and python package exists in environment."""
-        cohere_api_key = values.get("cohere_api_key")
-        if cohere_api_key is None or cohere_api_key == "":
-            raise ValueError(
-                "Did not find Cohere API key, please add an environment variable"
-                " `COHERE_API_KEY` which contains it, or pass `cohere_api_key` as a"
-                " named parameter."
-            )
+        cohere_api_key = get_from_dict_or_env(
+            values, "cohere_api_key", "COHERE_API_KEY"
+        )
         try:
             import cohere

View File

@@ -1,10 +1,10 @@
 """Wrapper around OpenAI embedding models."""
-import os
 from typing import Any, Dict, List, Optional
 
 from pydantic import BaseModel, Extra, root_validator
 
 from langchain.embeddings.base import Embeddings
+from langchain.utils import get_from_dict_or_env
 
 
 class OpenAIEmbeddings(BaseModel, Embeddings):
@@ -25,7 +25,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
     model_name: str = "babbage"
     """Model name to use."""
 
-    openai_api_key: Optional[str] = os.environ.get("OPENAI_API_KEY")
+    openai_api_key: Optional[str] = None
 
     class Config:
         """Configuration for this pydantic object."""
@@ -35,14 +35,9 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
     @root_validator()
     def validate_environment(cls, values: Dict) -> Dict:
         """Validate that api key and python package exists in environment."""
-        openai_api_key = values.get("openai_api_key")
-        if openai_api_key is None or openai_api_key == "":
-            raise ValueError(
-                "Did not find OpenAI API key, please add an environment variable"
-                " `OPENAI_API_KEY` which contains it, or pass `openai_api_key` as a"
-                " named parameter."
-            )
+        openai_api_key = get_from_dict_or_env(
+            values, "openai_api_key", "OPENAI_API_KEY"
+        )
         try:
             import openai

View File

@@ -1,14 +1,19 @@
 """Handle chained inputs."""
 from typing import Dict, List, Optional
 
-_COLOR_MAPPING = {"blue": 104, "yellow": 103, "red": 101, "green": 102}
+_TEXT_COLOR_MAPPING = {
+    "blue": "36;1",
+    "yellow": "33;1",
+    "pink": "38;5;200",
+    "green": "32;1",
+}
 
 
 def get_color_mapping(
     items: List[str], excluded_colors: Optional[List] = None
 ) -> Dict[str, str]:
     """Get mapping for items to a support color."""
-    colors = list(_COLOR_MAPPING.keys())
+    colors = list(_TEXT_COLOR_MAPPING.keys())
     if excluded_colors is not None:
         colors = [c for c in colors if c not in excluded_colors]
     color_mapping = {item: colors[i % len(colors)] for i, item in enumerate(items)}
@@ -20,8 +25,8 @@ def print_text(text: str, color: Optional[str] = None, end: str = "") -> None:
     if color is None:
         print(text, end=end)
     else:
-        color_str = _COLOR_MAPPING[color]
-        print(f"\x1b[{color_str}m{text}\x1b[0m", end=end)
+        color_str = _TEXT_COLOR_MAPPING[color]
+        print(f"\u001b[{color_str}m\033[1;3m{text}\u001b[0m", end=end)
 
 
 class ChainedInput:
@@ -29,14 +34,14 @@ class ChainedInput:
     def __init__(self, text: str, verbose: bool = False):
         """Initialize with verbose flag and initial text."""
-        self.verbose = verbose
-        if self.verbose:
+        self._verbose = verbose
+        if self._verbose:
             print_text(text, None)
         self._input = text
 
     def add(self, text: str, color: Optional[str] = None) -> None:
         """Add text to input, print if in verbose mode."""
-        if self.verbose:
+        if self._verbose:
             print_text(text, color)
         self._input += text
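A small sketch of what the remapped colors emit; the escape sequences match the notebook output above:

```python
from langchain.input import print_text

# "pink" is new in this mapping; the call wraps the text in
# "\u001b[38;5;200m" (256-color pink) plus "\033[1;3m" (bold italic),
# then resets with "\u001b[0m" before printing the end string.
print_text("pink", color="pink", end="\n")
```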

View File

@@ -1,11 +1,11 @@
 """Wrapper around AI21 APIs."""
-import os
 from typing import Any, Dict, List, Mapping, Optional
 
 import requests
 from pydantic import BaseModel, Extra, root_validator
 
 from langchain.llms.base import LLM
+from langchain.utils import get_from_dict_or_env
 
 
 class AI21PenaltyData(BaseModel):
@@ -62,7 +62,7 @@ class AI21(BaseModel, LLM):
     logitBias: Optional[Dict[str, float]] = None
     """Adjust the probability of specific tokens being generated."""
 
-    ai21_api_key: Optional[str] = os.environ.get("AI21_API_KEY")
+    ai21_api_key: Optional[str] = None
 
     class Config:
         """Configuration for this pydantic object."""
@@ -72,14 +72,8 @@ class AI21(BaseModel, LLM):
     @root_validator()
     def validate_environment(cls, values: Dict) -> Dict:
         """Validate that api key exists in environment."""
-        ai21_api_key = values.get("ai21_api_key")
-        if ai21_api_key is None or ai21_api_key == "":
-            raise ValueError(
-                "Did not find AI21 API key, please add an environment variable"
-                " `AI21_API_KEY` which contains it, or pass `ai21_api_key`"
-                " as a named parameter."
-            )
+        ai21_api_key = get_from_dict_or_env(values, "ai21_api_key", "AI21_API_KEY")
+        values["ai21_api_key"] = ai21_api_key
         return values
 
     @property
@@ -122,11 +116,7 @@ class AI21(BaseModel, LLM):
         response = requests.post(
             url=f"https://api.ai21.com/studio/v1/{self.model}/complete",
             headers={"Authorization": f"Bearer {self.ai21_api_key}"},
-            json={
-                "prompt": prompt,
-                "stopSequences": stop,
-                **self._default_params,
-            },
+            json={"prompt": prompt, "stopSequences": stop, **self._default_params},
         )
         if response.status_code != 200:
             optional_detail = response.json().get("error")

View File

@@ -1,11 +1,11 @@
 """Wrapper around Cohere APIs."""
-import os
 from typing import Any, Dict, List, Mapping, Optional
 
 from pydantic import BaseModel, Extra, root_validator
 
 from langchain.llms.base import LLM
 from langchain.llms.utils import enforce_stop_tokens
+from langchain.utils import get_from_dict_or_env
 
 
 class Cohere(LLM, BaseModel):
@@ -44,7 +44,7 @@ class Cohere(LLM, BaseModel):
     presence_penalty: int = 0
     """Penalizes repeated tokens."""
 
-    cohere_api_key: Optional[str] = os.environ.get("COHERE_API_KEY")
+    cohere_api_key: Optional[str] = None
 
     class Config:
         """Configuration for this pydantic object."""
@@ -54,14 +54,9 @@ class Cohere(LLM, BaseModel):
     @root_validator()
     def validate_environment(cls, values: Dict) -> Dict:
         """Validate that api key and python package exists in environment."""
-        cohere_api_key = values.get("cohere_api_key")
-        if cohere_api_key is None or cohere_api_key == "":
-            raise ValueError(
-                "Did not find Cohere API key, please add an environment variable"
-                " `COHERE_API_KEY` which contains it, or pass `cohere_api_key`"
-                " as a named parameter."
-            )
+        cohere_api_key = get_from_dict_or_env(
+            values, "cohere_api_key", "COHERE_API_KEY"
+        )
         try:
             import cohere

View File

@@ -1,11 +1,11 @@
 """Wrapper around HuggingFace APIs."""
-import os
 from typing import Any, Dict, List, Mapping, Optional
 
 from pydantic import BaseModel, Extra, root_validator
 
 from langchain.llms.base import LLM
 from langchain.llms.utils import enforce_stop_tokens
+from langchain.utils import get_from_dict_or_env
 
 DEFAULT_REPO_ID = "gpt2"
 VALID_TASKS = ("text2text-generation", "text-generation")
@@ -18,7 +18,7 @@ class HuggingFaceHub(LLM, BaseModel):
     environment variable ``HUGGINGFACEHUB_API_TOKEN`` set with your API token, or pass
     it as a named parameter to the constructor.
 
-    Only supports task `text-generation` for now.
+    Only supports `text-generation` and `text2text-generation` for now.
 
     Example:
         .. code-block:: python
@@ -35,7 +35,7 @@ class HuggingFaceHub(LLM, BaseModel):
     model_kwargs: Optional[dict] = None
     """Key word arguments to pass to the model."""
 
-    huggingfacehub_api_token: Optional[str] = os.environ.get("HUGGINGFACEHUB_API_TOKEN")
+    huggingfacehub_api_token: Optional[str] = None
 
     class Config:
         """Configuration for this pydantic object."""
@@ -45,13 +45,9 @@ class HuggingFaceHub(LLM, BaseModel):
     @root_validator()
     def validate_environment(cls, values: Dict) -> Dict:
         """Validate that api key and python package exists in environment."""
-        huggingfacehub_api_token = values.get("huggingfacehub_api_token")
-        if huggingfacehub_api_token is None or huggingfacehub_api_token == "":
-            raise ValueError(
-                "Did not find HuggingFace API token, please add an environment variable"
-                " `HUGGINGFACEHUB_API_TOKEN` which contains it, or pass"
-                " `huggingfacehub_api_token` as a named parameter."
-            )
+        huggingfacehub_api_token = get_from_dict_or_env(
+            values, "huggingfacehub_api_token", "HUGGINGFACEHUB_API_TOKEN"
+        )
         try:
             from huggingface_hub.inference_api import InferenceApi
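With `text2text-generation` now accepted, a flan-t5 model like the one in the notebook above should load; a hedged sketch, where the import path and default kwargs are assumptions:

```python
from langchain.llms.huggingface_hub import HuggingFaceHub

# google/flan-t5-xl is served under the text2text-generation task, which the
# validator now allows alongside text-generation.
llm = HuggingFaceHub(repo_id="google/flan-t5-xl")
print(llm("What color are flamingos?"))
```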

View File

@@ -1,10 +1,10 @@
 """Wrapper around NLPCloud APIs."""
-import os
 from typing import Any, Dict, List, Mapping, Optional
 
 from pydantic import BaseModel, Extra, root_validator
 
 from langchain.llms.base import LLM
+from langchain.utils import get_from_dict_or_env
 
 
 class NLPCloud(LLM, BaseModel):
@@ -54,7 +54,7 @@ class NLPCloud(LLM, BaseModel):
     num_return_sequences: int = 1
     """How many completions to generate for each prompt."""
 
-    nlpcloud_api_key: Optional[str] = os.environ.get("NLPCLOUD_API_KEY")
+    nlpcloud_api_key: Optional[str] = None
 
     class Config:
         """Configuration for this pydantic object."""
@@ -64,14 +64,9 @@ class NLPCloud(LLM, BaseModel):
     @root_validator()
     def validate_environment(cls, values: Dict) -> Dict:
         """Validate that api key and python package exists in environment."""
-        nlpcloud_api_key = values.get("nlpcloud_api_key")
-        if nlpcloud_api_key is None or nlpcloud_api_key == "":
-            raise ValueError(
-                "Did not find NLPCloud API key, please add an environment variable"
-                " `NLPCLOUD_API_KEY` which contains it, or pass `nlpcloud_api_key`"
-                " as a named parameter."
-            )
+        nlpcloud_api_key = get_from_dict_or_env(
+            values, "nlpcloud_api_key", "NLPCLOUD_API_KEY"
+        )
         try:
             import nlpcloud

View File

@@ -1,10 +1,10 @@
 """Wrapper around OpenAI APIs."""
-import os
 from typing import Any, Dict, List, Mapping, Optional
 
 from pydantic import BaseModel, Extra, root_validator
 
 from langchain.llms.base import LLM
+from langchain.utils import get_from_dict_or_env
 
 
 class OpenAI(LLM, BaseModel):
@@ -38,7 +38,7 @@ class OpenAI(LLM, BaseModel):
     best_of: int = 1
     """Generates best_of completions server-side and returns the "best"."""
 
-    openai_api_key: Optional[str] = os.environ.get("OPENAI_API_KEY")
+    openai_api_key: Optional[str] = None
 
     class Config:
         """Configuration for this pydantic object."""
@@ -48,14 +48,9 @@ class OpenAI(LLM, BaseModel):
     @root_validator()
     def validate_environment(cls, values: Dict) -> Dict:
         """Validate that api key and python package exists in environment."""
-        openai_api_key = values.get("openai_api_key")
-        if openai_api_key is None or openai_api_key == "":
-            raise ValueError(
-                "Did not find OpenAI API key, please add an environment variable"
-                " `OPENAI_API_KEY` which contains it, or pass `openai_api_key`"
-                " as a named parameter."
-            )
+        openai_api_key = get_from_dict_or_env(
+            values, "openai_api_key", "OPENAI_API_KEY"
+        )
         try:
             import openai

View File

@@ -1,6 +1,7 @@
 """Experiment with different models."""
-from typing import List, Optional
+from typing import List, Optional, Sequence
 
+from langchain.chains.base import Chain
 from langchain.chains.llm import LLMChain
 from langchain.input import get_color_mapping, print_text
 from langchain.llms.base import LLM
@@ -10,7 +11,41 @@ from langchain.prompts.prompt import Prompt
 class ModelLaboratory:
     """Experiment with different models."""
 
-    def __init__(self, llms: List[LLM], prompt: Optional[Prompt] = None):
+    def __init__(self, chains: Sequence[Chain], names: Optional[List[str]] = None):
+        """Initialize with chains to experiment with.
+
+        Args:
+            chains: list of chains to experiment with.
+        """
+        if not isinstance(chains[0], Chain):
+            raise ValueError(
+                "ModelLaboratory should now be initialized with Chains. "
+                "If you want to initialize with LLMs, use the `from_llms` method "
+                "instead (`ModelLaboratory.from_llms(...)`)"
+            )
+        for chain in chains:
+            if len(chain.input_keys) != 1:
+                raise ValueError(
+                    "Currently only support chains with one input variable, "
+                    f"got {chain.input_keys}"
+                )
+            if len(chain.output_keys) != 1:
+                raise ValueError(
+                    "Currently only support chains with one output variable, "
+                    f"got {chain.output_keys}"
+                )
+        if names is not None:
+            if len(names) != len(chains):
+                raise ValueError("Length of chains does not match length of names.")
+        self.chains = chains
+        chain_range = [str(i) for i in range(len(self.chains))]
+        self.chain_colors = get_color_mapping(chain_range)
+        self.names = names
+
+    @classmethod
+    def from_llms(
+        cls, llms: List[LLM], prompt: Optional[Prompt] = None
+    ) -> "ModelLaboratory":
         """Initialize with LLMs to experiment with and optional prompt.
 
         Args:
@@ -18,18 +53,11 @@ class ModelLaboratory:
             prompt: Optional prompt to use to prompt the LLMs. Defaults to None.
                 If a prompt was provided, it should only have one input variable.
         """
-        self.llms = llms
-        llm_range = [str(i) for i in range(len(self.llms))]
-        self.llm_colors = get_color_mapping(llm_range)
         if prompt is None:
-            self.prompt = Prompt(input_variables=["_input"], template="{_input}")
-        else:
-            if len(prompt.input_variables) != 1:
-                raise ValueError(
-                    "Currently only support prompts with one input variable, "
-                    f"got {prompt}"
-                )
-            self.prompt = prompt
+            prompt = Prompt(input_variables=["_input"], template="{_input}")
+        chains = [LLMChain(llm=llm, prompt=prompt) for llm in llms]
+        names = [str(llm) for llm in llms]
+        return cls(chains, names=names)
 
     def compare(self, text: str) -> None:
         """Compare model outputs on an input text.
@@ -42,9 +70,11 @@ class ModelLaboratory:
             text: input text to run all models on.
         """
         print(f"\033[1mInput:\033[0m\n{text}\n")
-        for i, llm in enumerate(self.llms):
-            print_text(str(llm), end="\n")
-            chain = LLMChain(llm=llm, prompt=self.prompt)
-            llm_inputs = {self.prompt.input_variables[0]: text}
-            output = chain.predict(**llm_inputs)
-            print_text(output, color=self.llm_colors[str(i)], end="\n\n")
+        for i, chain in enumerate(self.chains):
+            if self.names is not None:
+                name = self.names[i]
+            else:
+                name = str(chain)
+            print_text(name, end="\n")
+            output = chain.run(text)
+            print_text(output, color=self.chain_colors[str(i)], end="\n\n")

View File

@@ -94,8 +94,7 @@ class Prompt(BaseModel, BasePrompt):
         Returns:
             The final prompt generated.
         """
-        example_str = example_separator.join(examples)
-        template = prefix + example_str + suffix
+        template = example_separator.join([prefix, *examples, suffix])
         return cls(input_variables=input_variables, template=template)
 
     @classmethod
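The separator is now joined around the prefix and suffix as well, so callers drop the hand-padded `\n\n` (see the test_prompt change near the bottom of this diff). A minimal sketch, with the argument order inferred from the react prompt usage above:

```python
from langchain.prompts.prompt import Prompt

examples = ["Question: who are you?\nAnswer: foo"]
# Joins to:
# "Test Prompt:\n\nQuestion: who are you?\nAnswer: foo\n\nQuestion: {question}\nAnswer:"
prompt = Prompt.from_examples(
    examples, "Question: {question}\nAnswer:", ["question"], prefix="Test Prompt:"
)
```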

View File

@@ -1,4 +1,6 @@
 """SQLAlchemy wrapper around a database."""
+from typing import Any, Iterable, List, Optional
+
 from sqlalchemy import create_engine, inspect
 from sqlalchemy.engine import Engine
@@ -6,29 +8,57 @@ from sqlalchemy.engine import Engine
 class SQLDatabase:
     """SQLAlchemy wrapper around a database."""
 
-    def __init__(self, engine: Engine):
+    def __init__(
+        self,
+        engine: Engine,
+        ignore_tables: Optional[List[str]] = None,
+        include_tables: Optional[List[str]] = None,
+    ):
         """Create engine from database URI."""
         self._engine = engine
+        if include_tables and ignore_tables:
+            raise ValueError("Cannot specify both include_tables and ignore_tables")
+        self._inspector = inspect(self._engine)
+        self._all_tables = self._inspector.get_table_names()
+        self._include_tables = include_tables or []
+        if self._include_tables:
+            missing_tables = set(self._include_tables).difference(self._all_tables)
+            if missing_tables:
+                raise ValueError(
+                    f"include_tables {missing_tables} not found in database"
+                )
+        self._ignore_tables = ignore_tables or []
+        if self._ignore_tables:
+            missing_tables = set(self._ignore_tables).difference(self._all_tables)
+            if missing_tables:
+                raise ValueError(
+                    f"ignore_tables {missing_tables} not found in database"
+                )
 
     @classmethod
-    def from_uri(cls, database_uri: str) -> "SQLDatabase":
+    def from_uri(cls, database_uri: str, **kwargs: Any) -> "SQLDatabase":
         """Construct a SQLAlchemy engine from URI."""
-        return cls(create_engine(database_uri))
+        return cls(create_engine(database_uri), **kwargs)
 
     @property
     def dialect(self) -> str:
         """Return string representation of dialect to use."""
         return self._engine.dialect.name
 
+    def _get_table_names(self) -> Iterable[str]:
+        if self._include_tables:
+            return self._include_tables
+        return set(self._all_tables) - set(self._ignore_tables)
+
     @property
     def table_info(self) -> str:
         """Information about all tables in the database."""
-        template = "The '{table_name}' table has columns: {columns}."
+        template = "Table '{table_name}' has columns: {columns}."
         tables = []
-        inspector = inspect(self._engine)
-        for table_name in inspector.get_table_names():
+        for table_name in self._get_table_names():
             columns = []
-            for column in inspector.get_columns(table_name):
+            for column in self._inspector.get_columns(table_name):
                 columns.append(f"{column['name']} ({str(column['type'])})")
             column_str = ", ".join(columns)
             table_str = template.format(table_name=table_name, columns=column_str)
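A hedged sketch of the new table filtering; the SQLite URI and table name are illustrative:

```python
from langchain import SQLDatabase

# include_tables restricts table_info (and thus the prompt) to a whitelist;
# ignore_tables is the inverse, and passing both raises a ValueError.
db = SQLDatabase.from_uri("sqlite:///example.db", include_tables=["user"])
print(db.table_info)
```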

langchain/utils.py (new file, 17 lines)
View File

@@ -0,0 +1,17 @@
+"""Generic utility functions."""
+import os
+from typing import Any, Dict
+
+
+def get_from_dict_or_env(data: Dict[str, Any], key: str, env_key: str) -> str:
+    """Get a value from a dictionary or an environment variable."""
+    if key in data and data[key]:
+        return data[key]
+    elif env_key in os.environ and os.environ[env_key]:
+        return os.environ[env_key]
+    else:
+        raise ValueError(
+            f"Did not find {key}, please add an environment variable"
+            f" `{env_key}` which contains it, or pass"
+            f" `{key}` as a named parameter."
+        )
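A quick sketch of the helper's precedence, which all the validators above now share (`MY_API_KEY` is a made-up name):

```python
import os

from langchain.utils import get_from_dict_or_env

os.environ["MY_API_KEY"] = "from-env"

# A truthy dict value wins; falsy or missing values fall back to the
# environment; if both are empty, a ValueError explains what to set.
assert get_from_dict_or_env({"my_api_key": "explicit"}, "my_api_key", "MY_API_KEY") == "explicit"
assert get_from_dict_or_env({"my_api_key": None}, "my_api_key", "MY_API_KEY") == "from-env"
```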

View File

@@ -1,10 +1,10 @@
 """Wrapper around Elasticsearch vector database."""
-import os
 import uuid
 from typing import Any, Callable, Dict, List
 
 from langchain.docstore.document import Document
 from langchain.embeddings.base import Embeddings
+from langchain.utils import get_from_dict_or_env
 from langchain.vectorstores.base import VectorStore
@@ -45,10 +45,7 @@ class ElasticVectorSearch(VectorStore):
     """
 
     def __init__(
-        self,
-        elasticsearch_url: str,
-        index_name: str,
-        embedding_function: Callable,
+        self, elasticsearch_url: str, index_name: str, embedding_function: Callable
     ):
         """Initialize with necessary components."""
         try:
@@ -110,16 +107,9 @@ class ElasticVectorSearch(VectorStore):
                     elasticsearch_url="http://localhost:9200"
                 )
         """
-        elasticsearch_url = kwargs.get("elasticsearch_url")
-        if not elasticsearch_url:
-            elasticsearch_url = os.environ.get("ELASTICSEARCH_URL")
-        if elasticsearch_url is None or elasticsearch_url == "":
-            raise ValueError(
-                "Did not find Elasticsearch URL, please add an environment variable"
-                " `ELASTICSEARCH_URL` which contains it, or pass"
-                " `elasticsearch_url` as a named parameter."
-            )
+        elasticsearch_url = get_from_dict_or_env(
+            kwargs, "elasticsearch_url", "ELASTICSEARCH_URL"
+        )
         try:
             import elasticsearch
             from elasticsearch.helpers import bulk

View File

@@ -6,9 +6,7 @@ def test_manifest_wrapper() -> None:
     """Test manifest wrapper."""
     from manifest import Manifest
 
-    manifest = Manifest(
-        client_name="openai",
-    )
+    manifest = Manifest(client_name="openai")
     llm = ManifestWrapper(client=manifest, llm_kwargs={"temperature": 0})
     output = llm("The capital of New York is:")
     assert output == "Albany"

View File

@@ -48,7 +48,7 @@ def test_chained_input_verbose() -> None:
     chained_input.add("baz", color="blue")
     sys.stdout = old_stdout
     output = mystdout.getvalue()
-    assert output == "\x1b[104mbaz\x1b[0m"
+    assert output == "\x1b[36;1m\x1b[1;3mbaz\x1b[0m"
     assert chained_input.input == "foobarbaz"
@@ -70,5 +70,5 @@ def test_get_color_mapping_excluded_colors() -> None:
     """Test getting of color mapping with excluded colors."""
     items = ["foo", "bar"]
     output = get_color_mapping(items, excluded_colors=["blue"])
-    expected_output = {"foo": "yellow", "bar": "red"}
+    expected_output = {"foo": "yellow", "bar": "pink"}
     assert output == expected_output

View File

@@ -51,8 +51,8 @@ Question: {question}
 Answer:"""
     input_variables = ["question"]
     example_separator = "\n\n"
-    prefix = """Test Prompt:\n\n"""
-    suffix = """\n\nQuestion: {question}\nAnswer:"""
+    prefix = """Test Prompt:"""
+    suffix = """Question: {question}\nAnswer:"""
     examples = [
         """Question: who are you?\nAnswer: foo""",
         """Question: what are you?\nAnswer: bar""",

View File

@@ -28,11 +28,11 @@ def test_table_info() -> None:
     db = SQLDatabase(engine)
     output = db.table_info
     expected_output = (
-        "The 'company' table has columns: company_id (INTEGER), "
-        "company_location (VARCHAR).\n"
-        "The 'user' table has columns: user_id (INTEGER), user_name (VARCHAR(16))."
+        "Table 'company' has columns: company_id (INTEGER), "
+        "company_location (VARCHAR).",
+        "Table 'user' has columns: user_id (INTEGER), user_name (VARCHAR(16)).",
     )
-    assert output == expected_output
+    assert sorted(output.split("\n")) == sorted(expected_output)
 
 
 def test_sql_database_run() -> None: