From 76cecf8165b5c9d6c2b0b81a81361a9d550cb4c7 Mon Sep 17 00:00:00 2001 From: Delip Rao Date: Mon, 14 Nov 2022 11:34:01 -0500 Subject: [PATCH 01/16] A fix for Jupyter environment variable issue (#135) - fixes the Jupyter environment variable issues mentioned in issue #134 - fixes format/lint issues in some unrelated files (from make format/lint) ![image](https://user-images.githubusercontent.com/347398/201599322-090af858-362d-4d69-bf59-208aea65419a.png) --- langchain/chains/mapreduce.py | 4 +--- langchain/chains/natbot/crawler.py | 8 +------- langchain/chains/react/prompt.py | 6 +----- langchain/chains/self_ask_with_search/prompt.py | 5 +---- langchain/chains/sql_database/prompt.py | 3 +-- langchain/embeddings/cohere.py | 8 +++++--- langchain/embeddings/openai.py | 8 +++++--- langchain/llms/ai21.py | 13 ++++--------- langchain/llms/cohere.py | 9 +++++---- langchain/llms/huggingface_hub.py | 11 ++++++----- langchain/llms/nlpcloud.py | 8 +++++--- langchain/llms/openai.py | 8 +++++--- langchain/llms/utils.py | 10 +++++++++- langchain/vectorstores/elastic_vector_search.py | 5 +---- tests/integration_tests/llms/test_manifest.py | 4 +--- 15 files changed, 51 insertions(+), 59 deletions(-) diff --git a/langchain/chains/mapreduce.py b/langchain/chains/mapreduce.py index 0a88f945c61..4a2cc6417e0 100644 --- a/langchain/chains/mapreduce.py +++ b/langchain/chains/mapreduce.py @@ -59,9 +59,7 @@ class MapReduceChain(Chain, BaseModel): def _call(self, inputs: Dict[str, str]) -> Dict[str, str]: # Split the larger text into smaller chunks. - docs = self.text_splitter.split_text( - inputs[self.input_key], - ) + docs = self.text_splitter.split_text(inputs[self.input_key],) # Now that we have the chunks, we send them to the LLM and track results. # This is the "map" part. summaries = [] diff --git a/langchain/chains/natbot/crawler.py b/langchain/chains/natbot/crawler.py index 341b890b2fa..ec14d24973b 100644 --- a/langchain/chains/natbot/crawler.py +++ b/langchain/chains/natbot/crawler.py @@ -28,13 +28,7 @@ class Crawler: "Could not import playwright python package. " "Please install it with `pip install playwright`." ) - self.browser = ( - sync_playwright() - .start() - .chromium.launch( - headless=False, - ) - ) + self.browser = sync_playwright().start().chromium.launch(headless=False,) self.page = self.browser.new_page() self.page.set_viewport_size({"width": 1280, "height": 1080}) diff --git a/langchain/chains/react/prompt.py b/langchain/chains/react/prompt.py index e0e16299f86..8f7f55d20f6 100644 --- a/langchain/chains/react/prompt.py +++ b/langchain/chains/react/prompt.py @@ -109,8 +109,4 @@ Action 3: Finish[yes]""", ] SUFFIX = """\n\nQuestion: {input}""" -PROMPT = Prompt.from_examples( - EXAMPLES, - SUFFIX, - ["input"], -) +PROMPT = Prompt.from_examples(EXAMPLES, SUFFIX, ["input"],) diff --git a/langchain/chains/self_ask_with_search/prompt.py b/langchain/chains/self_ask_with_search/prompt.py index 003e68dd7fd..4c7fff87b1a 100644 --- a/langchain/chains/self_ask_with_search/prompt.py +++ b/langchain/chains/self_ask_with_search/prompt.py @@ -38,7 +38,4 @@ Intermediate Answer: New Zealand.
So the final answer is: No Question: {input}""" -PROMPT = Prompt( - input_variables=["input"], - template=_DEFAULT_TEMPLATE, -) +PROMPT = Prompt(input_variables=["input"], template=_DEFAULT_TEMPLATE,) diff --git a/langchain/chains/sql_database/prompt.py b/langchain/chains/sql_database/prompt.py index c35c92e4b57..4532ad24596 100644 --- a/langchain/chains/sql_database/prompt.py +++ b/langchain/chains/sql_database/prompt.py @@ -15,6 +15,5 @@ Only use the following tables: Question: {input}""" PROMPT = Prompt( - input_variables=["input", "table_info", "dialect"], - template=_DEFAULT_TEMPLATE, + input_variables=["input", "table_info", "dialect"], template=_DEFAULT_TEMPLATE, ) diff --git a/langchain/embeddings/cohere.py b/langchain/embeddings/cohere.py index 3ff641bfb02..f98622491ef 100644 --- a/langchain/embeddings/cohere.py +++ b/langchain/embeddings/cohere.py @@ -1,10 +1,10 @@ """Wrapper around Cohere embedding models.""" -import os from typing import Any, Dict, List, Optional from pydantic import BaseModel, Extra, root_validator from langchain.embeddings.base import Embeddings +from langchain.llms.utils import get_from_dict_or_env class CohereEmbeddings(BaseModel, Embeddings): @@ -25,7 +25,7 @@ class CohereEmbeddings(BaseModel, Embeddings): model: str = "medium" """Model name to use.""" - cohere_api_key: Optional[str] = os.environ.get("COHERE_API_KEY") + cohere_api_key: Optional[str] = None class Config: """Configuration for this pydantic object.""" @@ -35,7 +35,9 @@ class CohereEmbeddings(BaseModel, Embeddings): @root_validator() def validate_environment(cls, values: Dict) -> Dict: """Validate that api key and python package exists in environment.""" - cohere_api_key = values.get("cohere_api_key") + cohere_api_key = get_from_dict_or_env( + values, "cohere_api_key", "COHERE_API_KEY" + ) if cohere_api_key is None or cohere_api_key == "": raise ValueError( diff --git a/langchain/embeddings/openai.py b/langchain/embeddings/openai.py index 0bb97e58ae0..a7366ab614b 100644 --- a/langchain/embeddings/openai.py +++ b/langchain/embeddings/openai.py @@ -1,10 +1,10 @@ """Wrapper around OpenAI embedding models.""" -import os from typing import Any, Dict, List, Optional from pydantic import BaseModel, Extra, root_validator from langchain.embeddings.base import Embeddings +from langchain.llms.utils import get_from_dict_or_env class OpenAIEmbeddings(BaseModel, Embeddings): @@ -25,7 +25,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings): model_name: str = "babbage" """Model name to use.""" - openai_api_key: Optional[str] = os.environ.get("OPENAI_API_KEY") + openai_api_key: Optional[str] = None class Config: """Configuration for this pydantic object.""" @@ -35,7 +35,9 @@ class OpenAIEmbeddings(BaseModel, Embeddings): @root_validator() def validate_environment(cls, values: Dict) -> Dict: """Validate that api key and python package exists in environment.""" - openai_api_key = values.get("openai_api_key") + openai_api_key = get_from_dict_or_env( + values, "openai_api_key", "OPENAI_API_KEY" + ) if openai_api_key is None or openai_api_key == "": raise ValueError( diff --git a/langchain/llms/ai21.py b/langchain/llms/ai21.py index b1dac08b32f..8c4ead45974 100644 --- a/langchain/llms/ai21.py +++ b/langchain/llms/ai21.py @@ -1,11 +1,11 @@ """Wrapper around AI21 APIs.""" -import os from typing import Any, Dict, List, Mapping, Optional import requests from pydantic import BaseModel, Extra, root_validator from langchain.llms.base import LLM +from langchain.llms.utils import get_from_dict_or_env class 
AI21PenaltyData(BaseModel): @@ -62,7 +62,7 @@ class AI21(BaseModel, LLM): logitBias: Optional[Dict[str, float]] = None """Adjust the probability of specific tokens being generated.""" - ai21_api_key: Optional[str] = os.environ.get("AI21_API_KEY") + ai21_api_key: Optional[str] = None class Config: """Configuration for this pydantic object.""" @@ -72,8 +72,7 @@ class AI21(BaseModel, LLM): @root_validator() def validate_environment(cls, values: Dict) -> Dict: """Validate that api key exists in environment.""" - ai21_api_key = values.get("ai21_api_key") - + ai21_api_key = get_from_dict_or_env(values, "ai21_api_key", "AI21_API_KEY") if ai21_api_key is None or ai21_api_key == "": raise ValueError( "Did not find AI21 API key, please add an environment variable" @@ -122,11 +121,7 @@ class AI21(BaseModel, LLM): response = requests.post( url=f"https://api.ai21.com/studio/v1/{self.model}/complete", headers={"Authorization": f"Bearer {self.ai21_api_key}"}, - json={ - "prompt": prompt, - "stopSequences": stop, - **self._default_params, - }, + json={"prompt": prompt, "stopSequences": stop, **self._default_params,}, ) if response.status_code != 200: optional_detail = response.json().get("error") diff --git a/langchain/llms/cohere.py b/langchain/llms/cohere.py index 2a41b807ea3..fed3b56e8ed 100644 --- a/langchain/llms/cohere.py +++ b/langchain/llms/cohere.py @@ -1,11 +1,10 @@ """Wrapper around Cohere APIs.""" -import os from typing import Any, Dict, List, Mapping, Optional from pydantic import BaseModel, Extra, root_validator from langchain.llms.base import LLM -from langchain.llms.utils import enforce_stop_tokens +from langchain.llms.utils import enforce_stop_tokens, get_from_dict_or_env class Cohere(LLM, BaseModel): @@ -44,7 +43,7 @@ class Cohere(LLM, BaseModel): presence_penalty: int = 0 """Penalizes repeated tokens.""" - cohere_api_key: Optional[str] = os.environ.get("COHERE_API_KEY") + cohere_api_key: Optional[str] = None class Config: """Configuration for this pydantic object.""" @@ -54,7 +53,9 @@ class Cohere(LLM, BaseModel): @root_validator() def validate_environment(cls, values: Dict) -> Dict: """Validate that api key and python package exists in environment.""" - cohere_api_key = values.get("cohere_api_key") + cohere_api_key = get_from_dict_or_env( + values, "cohere_api_key", "COHERE_API_KEY" + ) if cohere_api_key is None or cohere_api_key == "": raise ValueError( diff --git a/langchain/llms/huggingface_hub.py b/langchain/llms/huggingface_hub.py index 8d584558ae9..1cea3d555cb 100644 --- a/langchain/llms/huggingface_hub.py +++ b/langchain/llms/huggingface_hub.py @@ -1,11 +1,10 @@ """Wrapper around HuggingFace APIs.""" -import os from typing import Any, Dict, List, Mapping, Optional from pydantic import BaseModel, Extra, root_validator from langchain.llms.base import LLM -from langchain.llms.utils import enforce_stop_tokens +from langchain.llms.utils import enforce_stop_tokens, get_from_dict_or_env DEFAULT_REPO_ID = "gpt2" VALID_TASKS = ("text2text-generation", "text-generation") @@ -18,7 +17,7 @@ class HuggingFaceHub(LLM, BaseModel): environment variable ``HUGGINGFACEHUB_API_TOKEN`` set with your API token, or pass it as a named parameter to the constructor. - Only supports task `text-generation` for now. + Only supports `text-generation` and `text2text-generation` for now. Example: .. 
code-block:: python @@ -35,7 +34,7 @@ class HuggingFaceHub(LLM, BaseModel): model_kwargs: Optional[dict] = None """Key word arguments to pass to the model.""" - huggingfacehub_api_token: Optional[str] = os.environ.get("HUGGINGFACEHUB_API_TOKEN") + huggingfacehub_api_token: Optional[str] = None class Config: """Configuration for this pydantic object.""" @@ -45,7 +44,9 @@ class HuggingFaceHub(LLM, BaseModel): @root_validator() def validate_environment(cls, values: Dict) -> Dict: """Validate that api key and python package exists in environment.""" - huggingfacehub_api_token = values.get("huggingfacehub_api_token") + huggingfacehub_api_token = get_from_dict_or_env( + values, "huggingfacehub_api_token", "HUGGINGFACEHUB_API_TOKEN" + ) if huggingfacehub_api_token is None or huggingfacehub_api_token == "": raise ValueError( "Did not find HuggingFace API token, please add an environment variable" diff --git a/langchain/llms/nlpcloud.py b/langchain/llms/nlpcloud.py index 771cc815a56..e4c37ff7032 100644 --- a/langchain/llms/nlpcloud.py +++ b/langchain/llms/nlpcloud.py @@ -1,10 +1,10 @@ """Wrapper around NLPCloud APIs.""" -import os from typing import Any, Dict, List, Mapping, Optional from pydantic import BaseModel, Extra, root_validator from langchain.llms.base import LLM +from langchain.llms.utils import get_from_dict_or_env class NLPCloud(LLM, BaseModel): @@ -54,7 +54,7 @@ class NLPCloud(LLM, BaseModel): num_return_sequences: int = 1 """How many completions to generate for each prompt.""" - nlpcloud_api_key: Optional[str] = os.environ.get("NLPCLOUD_API_KEY") + nlpcloud_api_key: Optional[str] = None class Config: """Configuration for this pydantic object.""" @@ -64,7 +64,9 @@ class NLPCloud(LLM, BaseModel): @root_validator() def validate_environment(cls, values: Dict) -> Dict: """Validate that api key and python package exists in environment.""" - nlpcloud_api_key = values.get("nlpcloud_api_key") + nlpcloud_api_key = get_from_dict_or_env( + values, "nlpcloud_api_key", "NLPCLOUD_API_KEY" + ) if nlpcloud_api_key is None or nlpcloud_api_key == "": raise ValueError( diff --git a/langchain/llms/openai.py b/langchain/llms/openai.py index 2355015b64d..f0127efb518 100644 --- a/langchain/llms/openai.py +++ b/langchain/llms/openai.py @@ -1,10 +1,10 @@ """Wrapper around OpenAI APIs.""" -import os from typing import Any, Dict, List, Mapping, Optional from pydantic import BaseModel, Extra, root_validator from langchain.llms.base import LLM +from langchain.llms.utils import get_from_dict_or_env class OpenAI(LLM, BaseModel): @@ -38,7 +38,7 @@ class OpenAI(LLM, BaseModel): best_of: int = 1 """Generates best_of completions server-side and returns the "best".""" - openai_api_key: Optional[str] = os.environ.get("OPENAI_API_KEY") + openai_api_key: Optional[str] = None class Config: """Configuration for this pydantic object.""" @@ -48,7 +48,9 @@ class OpenAI(LLM, BaseModel): @root_validator() def validate_environment(cls, values: Dict) -> Dict: """Validate that api key and python package exists in environment.""" - openai_api_key = values.get("openai_api_key") + openai_api_key = get_from_dict_or_env( + values, "openai_api_key", "OPENAI_API_KEY" + ) if openai_api_key is None or openai_api_key == "": raise ValueError( diff --git a/langchain/llms/utils.py b/langchain/llms/utils.py index a42fd130ee6..ef7407ab621 100644 --- a/langchain/llms/utils.py +++ b/langchain/llms/utils.py @@ -1,8 +1,16 @@ """Common utility functions for working with LLM APIs.""" +import os import re -from typing import List +from typing import Any, 
Dict, List def enforce_stop_tokens(text: str, stop: List[str]) -> str: """Cut off the text as soon as any stop words occur.""" return re.split("|".join(stop), text)[0] + + +def get_from_dict_or_env(data: Dict[str, Any], key: str, env_key: str) -> Any: + """Get a value from a dictionary or an environment variable.""" + if key in data: + return data[key] + return os.environ.get(env_key, None) diff --git a/langchain/vectorstores/elastic_vector_search.py b/langchain/vectorstores/elastic_vector_search.py index d8aed5f67d2..a0a623467a8 100644 --- a/langchain/vectorstores/elastic_vector_search.py +++ b/langchain/vectorstores/elastic_vector_search.py @@ -45,10 +45,7 @@ class ElasticVectorSearch(VectorStore): """ def __init__( - self, - elasticsearch_url: str, - index_name: str, - embedding_function: Callable, + self, elasticsearch_url: str, index_name: str, embedding_function: Callable, ): """Initialize with necessary components.""" try: diff --git a/tests/integration_tests/llms/test_manifest.py b/tests/integration_tests/llms/test_manifest.py index 3416a332862..41c54cc21b1 100644 --- a/tests/integration_tests/llms/test_manifest.py +++ b/tests/integration_tests/llms/test_manifest.py @@ -6,9 +6,7 @@ def test_manifest_wrapper() -> None: """Test manifest wrapper.""" from manifest import Manifest - manifest = Manifest( - client_name="openai", - ) + manifest = Manifest(client_name="openai",) llm = ManifestWrapper(client=manifest, llm_kwargs={"temperature": 0}) output = llm("The capital of New York is:") assert output == "Albany" From 9f223e6ccce8c1911efb8319d5c844c9f0f8ec9f Mon Sep 17 00:00:00 2001 From: Harrison Chase Date: Mon, 14 Nov 2022 08:55:59 -0800 Subject: [PATCH 02/16] Harrison/fix lint (#138) --- langchain/chains/mapreduce.py | 2 +- langchain/chains/natbot/crawler.py | 3 +-- langchain/chains/react/prompt.py | 2 +- langchain/chains/self_ask_with_search/prompt.py | 2 +- langchain/chains/sql_database/prompt.py | 2 +- langchain/llms/ai21.py | 2 +- langchain/vectorstores/elastic_vector_search.py | 2 +- tests/integration_tests/llms/test_manifest.py | 2 +- 8 files changed, 8 insertions(+), 9 deletions(-) diff --git a/langchain/chains/mapreduce.py b/langchain/chains/mapreduce.py index 4a2cc6417e0..8286e49cca1 100644 --- a/langchain/chains/mapreduce.py +++ b/langchain/chains/mapreduce.py @@ -59,7 +59,7 @@ class MapReduceChain(Chain, BaseModel): def _call(self, inputs: Dict[str, str]) -> Dict[str, str]: # Split the larger text into smaller chunks. - docs = self.text_splitter.split_text(inputs[self.input_key],) + docs = self.text_splitter.split_text(inputs[self.input_key]) # Now that we have the chunks, we send them to the LLM and track results. # This is the "map" part. summaries = [] diff --git a/langchain/chains/natbot/crawler.py b/langchain/chains/natbot/crawler.py index ec14d24973b..b15e0eace97 100644 --- a/langchain/chains/natbot/crawler.py +++ b/langchain/chains/natbot/crawler.py @@ -28,8 +28,7 @@ class Crawler: "Could not import playwright python package. " "Please install it with `pip install playwright`."
) - self.browser = sync_playwright().start().chromium.launch(headless=False,) - + self.browser = sync_playwright().start().chromium.launch(headless=False) self.page = self.browser.new_page() self.page.set_viewport_size({"width": 1280, "height": 1080}) diff --git a/langchain/chains/react/prompt.py b/langchain/chains/react/prompt.py index 8f7f55d20f6..8a3b2cfe811 100644 --- a/langchain/chains/react/prompt.py +++ b/langchain/chains/react/prompt.py @@ -109,4 +109,4 @@ Action 3: Finish[yes]""", ] SUFFIX = """\n\nQuestion: {input}""" -PROMPT = Prompt.from_examples(EXAMPLES, SUFFIX, ["input"],) +PROMPT = Prompt.from_examples(EXAMPLES, SUFFIX, ["input"]) diff --git a/langchain/chains/self_ask_with_search/prompt.py b/langchain/chains/self_ask_with_search/prompt.py index 4c7fff87b1a..02f7ab3f51f 100644 --- a/langchain/chains/self_ask_with_search/prompt.py +++ b/langchain/chains/self_ask_with_search/prompt.py @@ -38,4 +38,4 @@ Intermediate Answer: New Zealand. So the final answer is: No Question: {input}""" -PROMPT = Prompt(input_variables=["input"], template=_DEFAULT_TEMPLATE,) +PROMPT = Prompt(input_variables=["input"], template=_DEFAULT_TEMPLATE) diff --git a/langchain/chains/sql_database/prompt.py b/langchain/chains/sql_database/prompt.py index 4532ad24596..43bb3fcfb67 100644 --- a/langchain/chains/sql_database/prompt.py +++ b/langchain/chains/sql_database/prompt.py @@ -15,5 +15,5 @@ Only use the following tables: Question: {input}""" PROMPT = Prompt( - input_variables=["input", "table_info", "dialect"], template=_DEFAULT_TEMPLATE, + input_variables=["input", "table_info", "dialect"], template=_DEFAULT_TEMPLATE ) diff --git a/langchain/llms/ai21.py b/langchain/llms/ai21.py index 8c4ead45974..3967ca65332 100644 --- a/langchain/llms/ai21.py +++ b/langchain/llms/ai21.py @@ -121,7 +121,7 @@ class AI21(BaseModel, LLM): response = requests.post( url=f"https://api.ai21.com/studio/v1/{self.model}/complete", headers={"Authorization": f"Bearer {self.ai21_api_key}"}, - json={"prompt": prompt, "stopSequences": stop, **self._default_params,}, + json={"prompt": prompt, "stopSequences": stop, **self._default_params}, ) if response.status_code != 200: optional_detail = response.json().get("error") diff --git a/langchain/vectorstores/elastic_vector_search.py b/langchain/vectorstores/elastic_vector_search.py index a0a623467a8..32dd4843b13 100644 --- a/langchain/vectorstores/elastic_vector_search.py +++ b/langchain/vectorstores/elastic_vector_search.py @@ -45,7 +45,7 @@ class ElasticVectorSearch(VectorStore): """ def __init__( - self, elasticsearch_url: str, index_name: str, embedding_function: Callable, + self, elasticsearch_url: str, index_name: str, embedding_function: Callable ): """Initialize with necessary components.""" try: diff --git a/tests/integration_tests/llms/test_manifest.py b/tests/integration_tests/llms/test_manifest.py index 41c54cc21b1..eca4a94b0fa 100644 --- a/tests/integration_tests/llms/test_manifest.py +++ b/tests/integration_tests/llms/test_manifest.py @@ -6,7 +6,7 @@ def test_manifest_wrapper() -> None: """Test manifest wrapper.""" from manifest import Manifest - manifest = Manifest(client_name="openai",) + manifest = Manifest(client_name="openai") llm = ManifestWrapper(client=manifest, llm_kwargs={"temperature": 0}) output = llm("The capital of New York is:") assert output == "Albany" From 1a95252f000e7289c6e2d27d2eb6ee409f867bf5 Mon Sep 17 00:00:00 2001 From: Predrag Gruevski <2348618+obi1kenobi@users.noreply.github.com> Date: Mon, 14 Nov 2022 14:34:08 -0500 Subject: [PATCH 03/16] Use 
`pull_request` not `pull_request_target` in GitHub Actions. (#139) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit `pull_request` runs on the merge commit between the opened PR and the target branch where the PR is to be merged — `master` in this case. This is desirable because that way the new changes get linted and tested. The existing `pull_request_target` specifier causes lint and test to run _on the target branch itself_ (i.e. `master` in this case). That way the new code in the PR doesn't get linted and tested at all. This can also lead to security vulnerabilities, as described in the GitHub docs: ![image](https://user-images.githubusercontent.com/2348618/201735153-c5dd0c03-2490-45e9-b7f9-f0d47eb0109f.png) Screenshot from here: https://docs.github.com/en/actions/using-workflows/events-that-trigger-workflows#pull_request_target Link from the screenshot: https://securitylab.github.com/research/github-actions-preventing-pwn-requests/ --- .github/workflows/lint.yml | 2 +- .github/workflows/test.yml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 8d56deda34a..68e023356b7 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -1,6 +1,6 @@ name: lint -on: [push, pull_request_target] +on: [push, pull_request] jobs: build: diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index ad0def9e912..54b61276fb8 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -1,6 +1,6 @@ name: test -on: [push, pull_request_target] +on: [push, pull_request] jobs: build: From bbb405a492542f6c987dc138b53319f855a04d22 Mon Sep 17 00:00:00 2001 From: Harrison Chase Date: Mon, 14 Nov 2022 20:27:36 -0800 Subject: [PATCH 04/16] update colors (#140) --- langchain/input.py | 4 ++-- tests/unit_tests/test_input.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/langchain/input.py b/langchain/input.py index 94fad908280..21c7333e56e 100644 --- a/langchain/input.py +++ b/langchain/input.py @@ -1,7 +1,7 @@ """Handle chained inputs.""" from typing import Dict, List, Optional -_COLOR_MAPPING = {"blue": 104, "yellow": 103, "red": 101, "green": 102} +_COLOR_MAPPING = {"blue": 51, "yellow": 229, "pink": 219, "green": 85} def get_color_mapping( @@ -21,7 +21,7 @@ def print_text(text: str, color: Optional[str] = None, end: str = "") -> None: print(text, end=end) else: color_str = _COLOR_MAPPING[color] - print(f"\x1b[{color_str}m{text}\x1b[0m", end=end) + print(f"\u001b[48;5;{color_str}m{text}\x1b[0m", end=end) class ChainedInput: diff --git a/tests/unit_tests/test_input.py b/tests/unit_tests/test_input.py index dd17bfc5deb..cc837cbabc8 100644 --- a/tests/unit_tests/test_input.py +++ b/tests/unit_tests/test_input.py @@ -48,7 +48,7 @@ def test_chained_input_verbose() -> None: chained_input.add("baz", color="blue") sys.stdout = old_stdout output = mystdout.getvalue() - assert output == "\x1b[104mbaz\x1b[0m" + assert output == "\x1b[48;5;51mbaz\x1b[0m" assert chained_input.input == "foobarbaz" @@ -70,5 +70,5 @@ def test_get_color_mapping_excluded_colors() -> None: """Test getting of color mapping with excluded colors.""" items = ["foo", "bar"] output = get_color_mapping(items, excluded_colors=["blue"]) - expected_output = {"foo": "yellow", "bar": "red"} + expected_output = {"foo": "yellow", "bar": "pink"} assert output == expected_output From 1835e8a6812e52aacdc50cfe60e2192a7f172b3a Mon Sep 17 00:00:00 2001 From: Harrison Chase Date: 
Mon, 14 Nov 2022 21:30:33 -0800 Subject: [PATCH 05/16] prompt nit (#141) doing some cleanup, and i think this just simplifies things... --- langchain/prompts/prompt.py | 3 +-- tests/unit_tests/test_prompt.py | 4 ++-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/langchain/prompts/prompt.py b/langchain/prompts/prompt.py index b84b7718149..f27c678f04d 100644 --- a/langchain/prompts/prompt.py +++ b/langchain/prompts/prompt.py @@ -94,8 +94,7 @@ class Prompt(BaseModel, BasePrompt): Returns: The final prompt generated. """ - example_str = example_separator.join(examples) - template = prefix + example_str + suffix + template = example_separator.join([prefix, *examples, suffix]) return cls(input_variables=input_variables, template=template) @classmethod diff --git a/tests/unit_tests/test_prompt.py b/tests/unit_tests/test_prompt.py index 80a7fca4a75..7265cae3470 100644 --- a/tests/unit_tests/test_prompt.py +++ b/tests/unit_tests/test_prompt.py @@ -51,8 +51,8 @@ Question: {question} Answer:""" input_variables = ["question"] example_separator = "\n\n" - prefix = """Test Prompt:\n\n""" - suffix = """\n\nQuestion: {question}\nAnswer:""" + prefix = """Test Prompt:""" + suffix = """Question: {question}\nAnswer:""" examples = [ """Question: who are you?\nAnswer: foo""", """Question: what are you?\nAnswer: bar""", From a4b502d92ffeee34f1638ae08e1887e33bc82be3 Mon Sep 17 00:00:00 2001 From: Harrison Chase Date: Mon, 14 Nov 2022 21:42:43 -0800 Subject: [PATCH 06/16] fix env var loader (#143) --- langchain/llms/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/langchain/llms/utils.py b/langchain/llms/utils.py index ef7407ab621..29c69d925e7 100644 --- a/langchain/llms/utils.py +++ b/langchain/llms/utils.py @@ -11,6 +11,6 @@ def enforce_stop_tokens(text: str, stop: List[str]) -> str: def get_from_dict_or_env(data: Dict[str, Any], key: str, env_key: str) -> Any: """Get a value from a dictionary or an environment variable.""" - if key in data: + if key in data and data[key]: return data[key] return os.environ.get(env_key, None) From b504cd739fc04f1169316203c972f553c58826f0 Mon Sep 17 00:00:00 2001 From: Harrison Chase Date: Mon, 14 Nov 2022 22:05:41 -0800 Subject: [PATCH 07/16] Harrison/cleanup env check (#144) --- langchain/chains/serpapi.py | 15 ++++++--------- langchain/embeddings/cohere.py | 9 +-------- langchain/embeddings/openai.py | 9 +-------- langchain/llms/ai21.py | 9 ++------- langchain/llms/cohere.py | 10 ++-------- langchain/llms/huggingface_hub.py | 9 ++------- langchain/llms/nlpcloud.py | 9 +-------- langchain/llms/openai.py | 9 +-------- langchain/llms/utils.py | 10 +--------- langchain/utils.py | 17 +++++++++++++++++ langchain/vectorstores/elastic_vector_search.py | 15 ++++----------- 11 files changed, 38 insertions(+), 83 deletions(-) create mode 100644 langchain/utils.py diff --git a/langchain/chains/serpapi.py b/langchain/chains/serpapi.py index dbac148b898..50e086e13be 100644 --- a/langchain/chains/serpapi.py +++ b/langchain/chains/serpapi.py @@ -9,6 +9,7 @@ from typing import Any, Dict, List, Optional from pydantic import BaseModel, Extra, root_validator from langchain.chains.base import Chain +from langchain.utils import get_from_dict_or_env class HiddenPrints: @@ -43,7 +44,7 @@ class SerpAPIChain(Chain, BaseModel): input_key: str = "search_query" #: :meta private: output_key: str = "search_result" #: :meta private: - serpapi_api_key: Optional[str] = os.environ.get("SERPAPI_API_KEY") + serpapi_api_key: Optional[str] = None class Config: 
"""Configuration for this pydantic object.""" @@ -69,14 +70,10 @@ class SerpAPIChain(Chain, BaseModel): @root_validator() def validate_environment(cls, values: Dict) -> Dict: """Validate that api key and python package exists in environment.""" - serpapi_api_key = values.get("serpapi_api_key") - - if serpapi_api_key is None or serpapi_api_key == "": - raise ValueError( - "Did not find SerpAPI API key, please add an environment variable" - " `SERPAPI_API_KEY` which contains it, or pass `serpapi_api_key` " - "as a named parameter to the constructor." - ) + serpapi_api_key = get_from_dict_or_env( + values, "serpapi_api_key", "SERPAPI_API_KEY" + ) + values["serpapi_api_key"] = serpapi_api_key try: from serpapi import GoogleSearch diff --git a/langchain/embeddings/cohere.py b/langchain/embeddings/cohere.py index f98622491ef..9a4f2ffe6e1 100644 --- a/langchain/embeddings/cohere.py +++ b/langchain/embeddings/cohere.py @@ -4,7 +4,7 @@ from typing import Any, Dict, List, Optional from pydantic import BaseModel, Extra, root_validator from langchain.embeddings.base import Embeddings -from langchain.llms.utils import get_from_dict_or_env +from langchain.utils import get_from_dict_or_env class CohereEmbeddings(BaseModel, Embeddings): @@ -38,13 +38,6 @@ class CohereEmbeddings(BaseModel, Embeddings): cohere_api_key = get_from_dict_or_env( values, "cohere_api_key", "COHERE_API_KEY" ) - - if cohere_api_key is None or cohere_api_key == "": - raise ValueError( - "Did not find Cohere API key, please add an environment variable" - " `COHERE_API_KEY` which contains it, or pass `cohere_api_key` as a" - " named parameter." - ) try: import cohere diff --git a/langchain/embeddings/openai.py b/langchain/embeddings/openai.py index a7366ab614b..864e7758f37 100644 --- a/langchain/embeddings/openai.py +++ b/langchain/embeddings/openai.py @@ -4,7 +4,7 @@ from typing import Any, Dict, List, Optional from pydantic import BaseModel, Extra, root_validator from langchain.embeddings.base import Embeddings -from langchain.llms.utils import get_from_dict_or_env +from langchain.utils import get_from_dict_or_env class OpenAIEmbeddings(BaseModel, Embeddings): @@ -38,13 +38,6 @@ class OpenAIEmbeddings(BaseModel, Embeddings): openai_api_key = get_from_dict_or_env( values, "openai_api_key", "OPENAI_API_KEY" ) - - if openai_api_key is None or openai_api_key == "": - raise ValueError( - "Did not find OpenAI API key, please add an environment variable" - " `OPENAI_API_KEY` which contains it, or pass `openai_api_key` as a" - " named parameter." - ) try: import openai diff --git a/langchain/llms/ai21.py b/langchain/llms/ai21.py index 3967ca65332..a870d9e4bfd 100644 --- a/langchain/llms/ai21.py +++ b/langchain/llms/ai21.py @@ -5,7 +5,7 @@ import requests from pydantic import BaseModel, Extra, root_validator from langchain.llms.base import LLM -from langchain.llms.utils import get_from_dict_or_env +from langchain.utils import get_from_dict_or_env class AI21PenaltyData(BaseModel): @@ -73,12 +73,7 @@ class AI21(BaseModel, LLM): def validate_environment(cls, values: Dict) -> Dict: """Validate that api key exists in environment.""" ai21_api_key = get_from_dict_or_env(values, "ai21_api_key", "AI21_API_KEY") - if ai21_api_key is None or ai21_api_key == "": - raise ValueError( - "Did not find AI21 API key, please add an environment variable" - " `AI21_API_KEY` which contains it, or pass `ai21_api_key`" - " as a named parameter." 
- ) + values["ai21_api_key"] = ai21_api_key return values @property diff --git a/langchain/llms/cohere.py b/langchain/llms/cohere.py index fed3b56e8ed..e051ba47ff7 100644 --- a/langchain/llms/cohere.py +++ b/langchain/llms/cohere.py @@ -4,7 +4,8 @@ from typing import Any, Dict, List, Mapping, Optional from pydantic import BaseModel, Extra, root_validator from langchain.llms.base import LLM -from langchain.llms.utils import enforce_stop_tokens, get_from_dict_or_env +from langchain.llms.utils import enforce_stop_tokens +from langchain.utils import get_from_dict_or_env class Cohere(LLM, BaseModel): @@ -56,13 +57,6 @@ class Cohere(LLM, BaseModel): cohere_api_key = get_from_dict_or_env( values, "cohere_api_key", "COHERE_API_KEY" ) - - if cohere_api_key is None or cohere_api_key == "": - raise ValueError( - "Did not find Cohere API key, please add an environment variable" - " `COHERE_API_KEY` which contains it, or pass `cohere_api_key`" - " as a named parameter." - ) try: import cohere diff --git a/langchain/llms/huggingface_hub.py b/langchain/llms/huggingface_hub.py index 1cea3d555cb..c67c9720a4e 100644 --- a/langchain/llms/huggingface_hub.py +++ b/langchain/llms/huggingface_hub.py @@ -4,7 +4,8 @@ from typing import Any, Dict, List, Mapping, Optional from pydantic import BaseModel, Extra, root_validator from langchain.llms.base import LLM -from langchain.llms.utils import enforce_stop_tokens, get_from_dict_or_env +from langchain.llms.utils import enforce_stop_tokens +from langchain.utils import get_from_dict_or_env DEFAULT_REPO_ID = "gpt2" VALID_TASKS = ("text2text-generation", "text-generation") @@ -47,12 +48,6 @@ class HuggingFaceHub(LLM, BaseModel): huggingfacehub_api_token = get_from_dict_or_env( values, "huggingfacehub_api_token", "HUGGINGFACEHUB_API_TOKEN" ) - if huggingfacehub_api_token is None or huggingfacehub_api_token == "": - raise ValueError( - "Did not find HuggingFace API token, please add an environment variable" - " `HUGGINGFACEHUB_API_TOKEN` which contains it, or pass" - " `huggingfacehub_api_token` as a named parameter." - ) try: from huggingface_hub.inference_api import InferenceApi diff --git a/langchain/llms/nlpcloud.py b/langchain/llms/nlpcloud.py index e4c37ff7032..d9e4c54e206 100644 --- a/langchain/llms/nlpcloud.py +++ b/langchain/llms/nlpcloud.py @@ -4,7 +4,7 @@ from typing import Any, Dict, List, Mapping, Optional from pydantic import BaseModel, Extra, root_validator from langchain.llms.base import LLM -from langchain.llms.utils import get_from_dict_or_env +from langchain.utils import get_from_dict_or_env class NLPCloud(LLM, BaseModel): @@ -67,13 +67,6 @@ class NLPCloud(LLM, BaseModel): nlpcloud_api_key = get_from_dict_or_env( values, "nlpcloud_api_key", "NLPCLOUD_API_KEY" ) - - if nlpcloud_api_key is None or nlpcloud_api_key == "": - raise ValueError( - "Did not find NLPCloud API key, please add an environment variable" - " `NLPCLOUD_API_KEY` which contains it, or pass `nlpcloud_api_key`" - " as a named parameter." 
- ) try: import nlpcloud diff --git a/langchain/llms/openai.py b/langchain/llms/openai.py index f0127efb518..2affb86de5e 100644 --- a/langchain/llms/openai.py +++ b/langchain/llms/openai.py @@ -4,7 +4,7 @@ from typing import Any, Dict, List, Mapping, Optional from pydantic import BaseModel, Extra, root_validator from langchain.llms.base import LLM -from langchain.llms.utils import get_from_dict_or_env +from langchain.utils import get_from_dict_or_env class OpenAI(LLM, BaseModel): @@ -51,13 +51,6 @@ class OpenAI(LLM, BaseModel): openai_api_key = get_from_dict_or_env( values, "openai_api_key", "OPENAI_API_KEY" ) - - if openai_api_key is None or openai_api_key == "": - raise ValueError( - "Did not find OpenAI API key, please add an environment variable" - " `OPENAI_API_KEY` which contains it, or pass `openai_api_key`" - " as a named parameter." - ) try: import openai diff --git a/langchain/llms/utils.py b/langchain/llms/utils.py index 29c69d925e7..a42fd130ee6 100644 --- a/langchain/llms/utils.py +++ b/langchain/llms/utils.py @@ -1,16 +1,8 @@ """Common utility functions for working with LLM APIs.""" -import os import re -from typing import Any, Dict, List +from typing import List def enforce_stop_tokens(text: str, stop: List[str]) -> str: """Cut off the text as soon as any stop words occur.""" return re.split("|".join(stop), text)[0] - - -def get_from_dict_or_env(data: Dict[str, Any], key: str, env_key: str) -> Any: - """Get a value from a dictionary or an environment variable.""" - if key in data and data[key]: - return data[key] - return os.environ.get(env_key, None) diff --git a/langchain/utils.py b/langchain/utils.py new file mode 100644 index 00000000000..8588f4e940e --- /dev/null +++ b/langchain/utils.py @@ -0,0 +1,17 @@ +"""Generic utility functions.""" +import os +from typing import Any, Dict + + +def get_from_dict_or_env(data: Dict[str, Any], key: str, env_key: str) -> str: + """Get a value from a dictionary or an environment variable.""" + if key in data and data[key]: + return data[key] + elif env_key in os.environ and os.environ[env_key]: + return os.environ[env_key] + else: + raise ValueError( + f"Did not find {key}, please add an environment variable" + f" `{env_key}` which contains it, or pass" + f" `{key}` as a named parameter." + ) diff --git a/langchain/vectorstores/elastic_vector_search.py b/langchain/vectorstores/elastic_vector_search.py index 32dd4843b13..549277b3f94 100644 --- a/langchain/vectorstores/elastic_vector_search.py +++ b/langchain/vectorstores/elastic_vector_search.py @@ -1,10 +1,10 @@ """Wrapper around Elasticsearch vector database.""" -import os import uuid from typing import Any, Callable, Dict, List from langchain.docstore.document import Document from langchain.embeddings.base import Embeddings +from langchain.utils import get_from_dict_or_env from langchain.vectorstores.base import VectorStore @@ -107,16 +107,9 @@ class ElasticVectorSearch(VectorStore): elasticsearch_url="http://localhost:9200" ) """ - elasticsearch_url = kwargs.get("elasticsearch_url") - if not elasticsearch_url: - elasticsearch_url = os.environ.get("ELASTICSEARCH_URL") - - if elasticsearch_url is None or elasticsearch_url == "": - raise ValueError( - "Did not find Elasticsearch URL, please add an environment variable" - " `ELASTICSEARCH_URL` which contains it, or pass" - " `elasticsearch_url` as a named parameter." 
- ) + elasticsearch_url = get_from_dict_or_env( + kwargs, "elasticsearch_url", "ELASTICSEARCH_URL" + ) try: import elasticsearch from elasticsearch.helpers import bulk From 4f1bf159f4c6bfc9044e42b5e000981d419d1c26 Mon Sep 17 00:00:00 2001 From: Harrison Chase Date: Mon, 14 Nov 2022 22:07:54 -0800 Subject: [PATCH 08/16] bump version to 0.0.14 (#145) --- langchain/VERSION | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/langchain/VERSION b/langchain/VERSION index 43b29618309..9789c4ccb0c 100644 --- a/langchain/VERSION +++ b/langchain/VERSION @@ -1 +1 @@ -0.0.13 +0.0.14 From 47e35d7d0ed872c0046f5917921a799f03143a22 Mon Sep 17 00:00:00 2001 From: thesved <2893181+thesved@users.noreply.github.com> Date: Thu, 17 Nov 2022 00:13:12 +0100 Subject: [PATCH 09/16] Fix notebook links (#149) Example notebook links were broken. --- README.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index b195840a220..a7bded5508f 100644 --- a/README.md +++ b/README.md @@ -37,7 +37,7 @@ This project was largely inspired by a few projects seen on Twitter for which we **[Self-ask-with-search](https://ofir.io/self-ask.pdf)** -To recreate this paper, use the following code snippet or checkout the [example notebook](https://github.com/hwchase17/langchain/blob/master/examples/self_ask_with_search.ipynb). +To recreate this paper, use the following code snippet or checkout the [example notebook](https://github.com/hwchase17/langchain/blob/master/docs/examples/demos/self_ask_with_search.ipynb). ```python from langchain import SelfAskWithSearchChain, OpenAI, SerpAPIChain @@ -52,7 +52,7 @@ self_ask_with_search.run("What is the hometown of the reigning men's U.S. Open c **[LLM Math](https://twitter.com/amasad/status/1568824744367259648?s=20&t=-7wxpXBJinPgDuyHLouP1w)** -To recreate this example, use the following code snippet or check out the [example notebook](https://github.com/hwchase17/langchain/blob/master/examples/llm_math.ipynb). +To recreate this example, use the following code snippet or check out the [example notebook](https://github.com/hwchase17/langchain/blob/master/docs/examples/demos/llm_math.ipynb). ```python from langchain import OpenAI, LLMMathChain @@ -65,7 +65,7 @@ llm_math.run("How many of the integers between 0 and 99 inclusive are divisible **Generic Prompting** -You can also use this for simple prompting pipelines, as in the below example and this [example notebook](https://github.com/hwchase17/langchain/blob/master/examples/simple_prompts.ipynb). +You can also use this for simple prompting pipelines, as in the below example and this [example notebook](https://github.com/hwchase17/langchain/blob/master/docs/examples/demos/simple_prompts.ipynb). ```python from langchain import Prompt, OpenAI, LLMChain @@ -84,7 +84,7 @@ llm_chain.predict(question=question) **Embed & Search Documents** -We support two vector databases to store and search embeddings -- FAISS and Elasticsearch. Here's a code snippet showing how to use FAISS to store embeddings and search for text similar to a query. Both database backends are featured in this [example notebook](https://github.com/hwchase17/langchain/blob/master/examples/embeddings.ipynb). +We support two vector databases to store and search embeddings -- FAISS and Elasticsearch. Here's a code snippet showing how to use FAISS to store embeddings and search for text similar to a query. 
Both database backends are featured in this [example notebook](https://github.com/hwchase17/langchain/blob/master/docs/examples/integrations/embeddings.ipynb). ```python from langchain.embeddings.openai import OpenAIEmbeddings From d775ddd749a545f53490f75c3601b23965b4ee55 Mon Sep 17 00:00:00 2001 From: Harrison Chase Date: Wed, 16 Nov 2022 21:39:02 -0800 Subject: [PATCH 10/16] add apply functionality (#150) --- langchain/chains/base.py | 4 ++++ langchain/chains/mapreduce.py | 9 ++++----- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/langchain/chains/base.py b/langchain/chains/base.py index 15827f6577c..53befcff134 100644 --- a/langchain/chains/base.py +++ b/langchain/chains/base.py @@ -49,6 +49,10 @@ class Chain(BaseModel, ABC): self._validate_outputs(outputs) return {**inputs, **outputs} + def apply(self, input_list: List[Dict[str, Any]]) -> List[Dict[str, str]]: + """Call the chain on all inputs in the list.""" + return [self(inputs) for inputs in input_list] + def run(self, text: str) -> str: """Run text in, text out (if applicable).""" if len(self.input_keys) != 1: diff --git a/langchain/chains/mapreduce.py b/langchain/chains/mapreduce.py index 8286e49cca1..623b95cdfcd 100644 --- a/langchain/chains/mapreduce.py +++ b/langchain/chains/mapreduce.py @@ -60,13 +60,12 @@ class MapReduceChain(Chain, BaseModel): def _call(self, inputs: Dict[str, str]) -> Dict[str, str]: # Split the larger text into smaller chunks. docs = self.text_splitter.split_text(inputs[self.input_key]) + # Now that we have the chunks, we send them to the LLM and track results. # This is the "map" part. - summaries = [] - for d in docs: - inputs = {self.map_llm.prompt.input_variables[0]: d} - res = self.map_llm.predict(**inputs) - summaries.append(res) + input_list = [{self.map_llm.prompt.input_variables[0]: d} for d in docs] + summary_results = self.map_llm.apply(input_list) + summaries = [res[self.map_llm.output_key] for res in summary_results] # We then need to combine these individual parts into one. # This is the reduce part. From d2f9288be6c61b305ac3f7f8c52d53a7db3b1ded Mon Sep 17 00:00:00 2001 From: Harrison Chase Date: Wed, 16 Nov 2022 21:58:05 -0800 Subject: [PATCH 11/16] add metadata to documents (#153) add concept of metadata to document --- langchain/docstore/document.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/langchain/docstore/document.py b/langchain/docstore/document.py index 2c6e04bb0ba..cd6349d5312 100644 --- a/langchain/docstore/document.py +++ b/langchain/docstore/document.py @@ -1,7 +1,7 @@ """Interface for interacting with a document.""" from typing import List -from pydantic import BaseModel +from pydantic import BaseModel, Field class Document(BaseModel): @@ -10,6 +10,7 @@ class Document(BaseModel): page_content: str lookup_str: str = "" lookup_index = 0 + metadata: dict = Field(default_factory=dict) @property def paragraphs(self) -> List[str]: From ca4b10bb74400c0489777e59ef0ca244830a53a1 Mon Sep 17 00:00:00 2001 From: Nicholas Larus-Stone Date: Wed, 16 Nov 2022 22:04:50 -0800 Subject: [PATCH 12/16] feat: add option to ignore or restrict to SQL tables (#151) `SQLDatabase` now accepts two `init` arguments: 1. `ignore_tables` to pass in a list of tables to not search over 2. 
`include_tables` to restrict to a list of tables to consider --- langchain/sql_database.py | 44 ++++++++++++++++++++++----- tests/unit_tests/test_sql_database.py | 8 ++--- 2 files changed, 41 insertions(+), 11 deletions(-) diff --git a/langchain/sql_database.py b/langchain/sql_database.py index 138839bb3dc..a04ab15aafa 100644 --- a/langchain/sql_database.py +++ b/langchain/sql_database.py @@ -1,4 +1,6 @@ """SQLAlchemy wrapper around a database.""" +from typing import Any, Iterable, List, Optional + from sqlalchemy import create_engine, inspect from sqlalchemy.engine import Engine @@ -6,29 +8,57 @@ from sqlalchemy.engine import Engine class SQLDatabase: """SQLAlchemy wrapper around a database.""" - def __init__(self, engine: Engine): + def __init__( + self, + engine: Engine, + ignore_tables: Optional[List[str]] = None, + include_tables: Optional[List[str]] = None, + ): """Create engine from database URI.""" self._engine = engine + if include_tables and ignore_tables: + raise ValueError("Cannot specify both include_tables and ignore_tables") + + self._inspector = inspect(self._engine) + self._all_tables = self._inspector.get_table_names() + self._include_tables = include_tables or [] + if self._include_tables: + missing_tables = set(self._include_tables).difference(self._all_tables) + if missing_tables: + raise ValueError( + f"include_tables {missing_tables} not found in database" + ) + self._ignore_tables = ignore_tables or [] + if self._ignore_tables: + missing_tables = set(self._ignore_tables).difference(self._all_tables) + if missing_tables: + raise ValueError( + f"ignore_tables {missing_tables} not found in database" + ) @classmethod - def from_uri(cls, database_uri: str) -> "SQLDatabase": + def from_uri(cls, database_uri: str, **kwargs: Any) -> "SQLDatabase": """Construct a SQLAlchemy engine from URI.""" - return cls(create_engine(database_uri)) + return cls(create_engine(database_uri), **kwargs) @property def dialect(self) -> str: """Return string representation of dialect to use.""" return self._engine.dialect.name + def _get_table_names(self) -> Iterable[str]: + if self._include_tables: + return self._include_tables + return set(self._all_tables) - set(self._ignore_tables) + @property def table_info(self) -> str: """Information about all tables in the database.""" - template = "The '{table_name}' table has columns: {columns}." + template = "Table '{table_name}' has columns: {columns}." tables = [] - inspector = inspect(self._engine) - for table_name in inspector.get_table_names(): + for table_name in self._get_table_names(): columns = [] - for column in inspector.get_columns(table_name): + for column in self._inspector.get_columns(table_name): columns.append(f"{column['name']} ({str(column['type'])})") column_str = ", ".join(columns) table_str = template.format(table_name=table_name, columns=column_str) diff --git a/tests/unit_tests/test_sql_database.py b/tests/unit_tests/test_sql_database.py index c18d5deb63a..1a536fe5d1d 100644 --- a/tests/unit_tests/test_sql_database.py +++ b/tests/unit_tests/test_sql_database.py @@ -28,11 +28,11 @@ def test_table_info() -> None: db = SQLDatabase(engine) output = db.table_info expected_output = ( - "The 'company' table has columns: company_id (INTEGER), " - "company_location (VARCHAR).\n" - "The 'user' table has columns: user_id (INTEGER), user_name (VARCHAR(16))." 
+ "Table 'company' has columns: company_id (INTEGER), " + "company_location (VARCHAR).", + "Table 'user' has columns: user_id (INTEGER), user_name (VARCHAR(16)).", ) - assert output == expected_output + assert sorted(output.split("\n")) == sorted(expected_output) def test_sql_database_run() -> None: From 0c3ae78ec18303857bb6007699d8141f6acde3d3 Mon Sep 17 00:00:00 2001 From: Nicholas Larus-Stone Date: Wed, 16 Nov 2022 22:05:28 -0800 Subject: [PATCH 13/16] chore: update ascii colors to work with dark mode (#152) --- langchain/chains/base.py | 2 +- langchain/input.py | 19 ++++++++++++------- tests/unit_tests/test_input.py | 2 +- 3 files changed, 14 insertions(+), 9 deletions(-) diff --git a/langchain/chains/base.py b/langchain/chains/base.py index 53befcff134..0f9edecb057 100644 --- a/langchain/chains/base.py +++ b/langchain/chains/base.py @@ -9,7 +9,7 @@ class Chain(BaseModel, ABC): """Base interface that all chains should implement.""" verbose: bool = False - """Whether to print out the code that was executed.""" + """Whether to print out response text.""" @property @abstractmethod diff --git a/langchain/input.py b/langchain/input.py index 21c7333e56e..ef7053ad315 100644 --- a/langchain/input.py +++ b/langchain/input.py @@ -1,14 +1,19 @@ """Handle chained inputs.""" from typing import Dict, List, Optional -_COLOR_MAPPING = {"blue": 51, "yellow": 229, "pink": 219, "green": 85} +_TEXT_COLOR_MAPPING = { + "blue": "36;1", + "yellow": "33;1", + "pink": "38;5;200", + "green": "32;1", +} def get_color_mapping( items: List[str], excluded_colors: Optional[List] = None ) -> Dict[str, str]: """Get mapping for items to a support color.""" - colors = list(_COLOR_MAPPING.keys()) + colors = list(_TEXT_COLOR_MAPPING.keys()) if excluded_colors is not None: colors = [c for c in colors if c not in excluded_colors] color_mapping = {item: colors[i % len(colors)] for i, item in enumerate(items)} @@ -20,8 +25,8 @@ def print_text(text: str, color: Optional[str] = None, end: str = "") -> None: if color is None: print(text, end=end) else: - color_str = _COLOR_MAPPING[color] - print(f"\u001b[48;5;{color_str}m{text}\x1b[0m", end=end) + color_str = _TEXT_COLOR_MAPPING[color] + print(f"\u001b[{color_str}m\033[1;3m{text}\u001b[0m", end=end) class ChainedInput: @@ -29,14 +34,14 @@ class ChainedInput: def __init__(self, text: str, verbose: bool = False): """Initialize with verbose flag and initial text.""" - self.verbose = verbose - if self.verbose: + self._verbose = verbose + if self._verbose: print_text(text, None) self._input = text def add(self, text: str, color: Optional[str] = None) -> None: """Add text to input, print if in verbose mode.""" - if self.verbose: + if self._verbose: print_text(text, color) self._input += text diff --git a/tests/unit_tests/test_input.py b/tests/unit_tests/test_input.py index cc837cbabc8..43dd3b080b8 100644 --- a/tests/unit_tests/test_input.py +++ b/tests/unit_tests/test_input.py @@ -48,7 +48,7 @@ def test_chained_input_verbose() -> None: chained_input.add("baz", color="blue") sys.stdout = old_stdout output = mystdout.getvalue() - assert output == "\x1b[48;5;51mbaz\x1b[0m" + assert output == "\x1b[36;1m\x1b[1;3mbaz\x1b[0m" assert chained_input.input == "foobarbaz" From 0ac08bbca6a74fb5ecaddd4ebe65ddcf57e112e5 Mon Sep 17 00:00:00 2001 From: Harrison Chase Date: Wed, 16 Nov 2022 23:22:05 -0800 Subject: [PATCH 14/16] bump version to 0.0.15 (#154) --- langchain/VERSION | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/langchain/VERSION b/langchain/VERSION index 
9789c4ccb0c..ceddfb28f4f 100644 --- a/langchain/VERSION +++ b/langchain/VERSION @@ -1 +1 @@ -0.0.14 +0.0.15 From b15c84e19d31722693f6388b1b59418e88d32d3d Mon Sep 17 00:00:00 2001 From: Harrison Chase Date: Fri, 18 Nov 2022 05:50:02 -0800 Subject: [PATCH 15/16] Harrison/chain lab (#156) --- docs/examples/model_laboratory.ipynb | 111 +++++++++++++++++++++++--- langchain/chains/vector_db_qa/base.py | 4 +- langchain/model_laboratory.py | 68 +++++++++++----- 3 files changed, 154 insertions(+), 29 deletions(-) diff --git a/docs/examples/model_laboratory.ipynb b/docs/examples/model_laboratory.ipynb index 0646386e56b..8c5af92f172 100644 --- a/docs/examples/model_laboratory.ipynb +++ b/docs/examples/model_laboratory.ipynb @@ -42,7 +42,7 @@ "metadata": {}, "outputs": [], "source": [ - "model_lab = ModelLaboratory(llms)" + "model_lab = ModelLaboratory.from_llms(llms)" ] }, { @@ -60,19 +60,19 @@ "\n", "\u001b[1mOpenAI\u001b[0m\n", "Params: {'model': 'text-davinci-002', 'temperature': 0.0, 'max_tokens': 256, 'top_p': 1, 'frequency_penalty': 0, 'presence_penalty': 0, 'n': 1, 'best_of': 1}\n", - "\u001b[104m\n", + "\u001b[36;1m\u001b[1;3m\n", "\n", "Flamingos are pink.\u001b[0m\n", "\n", "\u001b[1mCohere\u001b[0m\n", "Params: {'model': 'command-xlarge-20221108', 'max_tokens': 20, 'temperature': 0.0, 'k': 0, 'p': 1, 'frequency_penalty': 0, 'presence_penalty': 0}\n", - "\u001b[103m\n", + "\u001b[33;1m\u001b[1;3m\n", "\n", "Pink\u001b[0m\n", "\n", "\u001b[1mHuggingFaceHub\u001b[0m\n", "Params: {'repo_id': 'google/flan-t5-xl', 'temperature': 1}\n", - "\u001b[101mpink\u001b[0m\n", + "\u001b[38;5;200m\u001b[1;3mpink\u001b[0m\n", "\n" ] } @@ -89,7 +89,7 @@ "outputs": [], "source": [ "prompt = Prompt(template=\"What is the capital of {state}?\", input_variables=[\"state\"])\n", - "model_lab_with_prompt = ModelLaboratory(llms, prompt=prompt)" + "model_lab_with_prompt = ModelLaboratory.from_llms(llms, prompt=prompt)" ] }, { @@ -107,19 +107,19 @@ "\n", "\u001b[1mOpenAI\u001b[0m\n", "Params: {'model': 'text-davinci-002', 'temperature': 0.0, 'max_tokens': 256, 'top_p': 1, 'frequency_penalty': 0, 'presence_penalty': 0, 'n': 1, 'best_of': 1}\n", - "\u001b[104m\n", + "\u001b[36;1m\u001b[1;3m\n", "\n", "The capital of New York is Albany.\u001b[0m\n", "\n", "\u001b[1mCohere\u001b[0m\n", "Params: {'model': 'command-xlarge-20221108', 'max_tokens': 20, 'temperature': 0.0, 'k': 0, 'p': 1, 'frequency_penalty': 0, 'presence_penalty': 0}\n", - "\u001b[103m\n", + "\u001b[33;1m\u001b[1;3m\n", "\n", "The capital of New York is Albany.\u001b[0m\n", "\n", "\u001b[1mHuggingFaceHub\u001b[0m\n", "Params: {'repo_id': 'google/flan-t5-xl', 'temperature': 1}\n", - "\u001b[101mst john s\u001b[0m\n", + "\u001b[38;5;200m\u001b[1;3mst john s\u001b[0m\n", "\n" ] } @@ -130,10 +130,103 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 7, "id": "54336dbf", "metadata": {}, "outputs": [], + "source": [ + "from langchain import SelfAskWithSearchChain, SerpAPIChain\n", + "\n", + "open_ai_llm = OpenAI(temperature=0)\n", + "search = SerpAPIChain()\n", + "self_ask_with_search_openai = SelfAskWithSearchChain(llm=open_ai_llm, search_chain=search, verbose=True)\n", + "\n", + "cohere_llm = Cohere(temperature=0, model=\"command-xlarge-20221108\")\n", + "search = SerpAPIChain()\n", + "self_ask_with_search_cohere = SelfAskWithSearchChain(llm=cohere_llm, search_chain=search, verbose=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "6a50a9f1", + "metadata": {}, + "outputs": [], + "source": [ + "chains = 
[self_ask_with_search_openai, self_ask_with_search_cohere]\n", + "names = [str(open_ai_llm), str(cohere_llm)]" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "d3549e99", + "metadata": {}, + "outputs": [], + "source": [ + "model_lab = ModelLaboratory(chains, names=names)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "362f7f57", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[1mInput:\u001b[0m\n", + "What is the hometown of the reigning men's U.S. Open champion?\n", + "\n", + "\u001b[1mOpenAI\u001b[0m\n", + "Params: {'model': 'text-davinci-002', 'temperature': 0.0, 'max_tokens': 256, 'top_p': 1, 'frequency_penalty': 0, 'presence_penalty': 0, 'n': 1, 'best_of': 1}\n", + "\n", + "\n", + "\u001b[1m> Entering new chain...\u001b[0m\n", + "What is the hometown of the reigning men's U.S. Open champion?\n", + "Are follow up questions needed here:\u001b[32;1m\u001b[1;3m Yes.\n", + "Follow up: Who is the reigning men's U.S. Open champion?\u001b[0m\n", + "Intermediate answer: \u001b[33;1m\u001b[1;3mCarlos Alcaraz.\u001b[0m\u001b[32;1m\u001b[1;3m\n", + "Follow up: Where is Carlos Alcaraz from?\u001b[0m\n", + "Intermediate answer: \u001b[33;1m\u001b[1;3mEl Palmar, Spain.\u001b[0m\u001b[32;1m\u001b[1;3m\n", + "So the final answer is: El Palmar, Spain\u001b[0m\n", + "\u001b[1m> Finished chain.\u001b[0m\n", + "\u001b[36;1m\u001b[1;3m\n", + "So the final answer is: El Palmar, Spain\u001b[0m\n", + "\n", + "\u001b[1mCohere\u001b[0m\n", + "Params: {'model': 'command-xlarge-20221108', 'max_tokens': 256, 'temperature': 0.0, 'k': 0, 'p': 1, 'frequency_penalty': 0, 'presence_penalty': 0}\n", + "\n", + "\n", + "\u001b[1m> Entering new chain...\u001b[0m\n", + "What is the hometown of the reigning men's U.S. Open champion?\n", + "Are follow up questions needed here:\u001b[32;1m\u001b[1;3m Yes.\n", + "Follow up: Who is the reigning men's U.S. Open champion?\u001b[0m\n", + "Intermediate answer: \u001b[33;1m\u001b[1;3mCarlos Alcaraz.\u001b[0m\u001b[32;1m\u001b[1;3m\n", + "So the final answer is:\n", + "\n", + "Carlos Alcaraz\u001b[0m\n", + "\u001b[1m> Finished chain.\u001b[0m\n", + "\u001b[33;1m\u001b[1;3m\n", + "So the final answer is:\n", + "\n", + "Carlos Alcaraz\u001b[0m\n", + "\n" + ] + } + ], + "source": [ + "model_lab.compare(\"What is the hometown of the reigning men's U.S. 
Open champion?\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "94159131", + "metadata": {}, + "outputs": [], "source": [] } ], diff --git a/langchain/chains/vector_db_qa/base.py b/langchain/chains/vector_db_qa/base.py index 3e010710576..d54de11ca22 100644 --- a/langchain/chains/vector_db_qa/base.py +++ b/langchain/chains/vector_db_qa/base.py @@ -27,6 +27,8 @@ class VectorDBQA(Chain, BaseModel): """LLM wrapper to use.""" vectorstore: VectorStore """Vector Database to connect to.""" + k: int = 4 + """Number of documents to query for.""" input_key: str = "query" #: :meta private: output_key: str = "result" #: :meta private: @@ -55,7 +57,7 @@ class VectorDBQA(Chain, BaseModel): def _call(self, inputs: Dict[str, str]) -> Dict[str, str]: question = inputs[self.input_key] llm_chain = LLMChain(llm=self.llm, prompt=prompt) - docs = self.vectorstore.similarity_search(question) + docs = self.vectorstore.similarity_search(question, k=self.k) contexts = [] for j, doc in enumerate(docs): contexts.append(f"Context {j}:\n{doc.page_content}") diff --git a/langchain/model_laboratory.py b/langchain/model_laboratory.py index 0243f70e889..d61265c01df 100644 --- a/langchain/model_laboratory.py +++ b/langchain/model_laboratory.py @@ -1,6 +1,7 @@ """Experiment with different models.""" -from typing import List, Optional +from typing import List, Optional, Sequence +from langchain.chains.base import Chain from langchain.chains.llm import LLMChain from langchain.input import get_color_mapping, print_text from langchain.llms.base import LLM @@ -10,7 +11,41 @@ from langchain.prompts.prompt import Prompt class ModelLaboratory: """Experiment with different models.""" - def __init__(self, llms: List[LLM], prompt: Optional[Prompt] = None): + def __init__(self, chains: Sequence[Chain], names: Optional[List[str]] = None): + """Initialize with chains to experiment with. + + Args: + chains: list of chains to experiment with. + """ + if not isinstance(chains[0], Chain): + raise ValueError( + "ModelLaboratory should now be initialized with Chains. " + "If you want to initialize with LLMs, use the `from_llms` method " + "instead (`ModelLaboratory.from_llms(...)`)" + ) + for chain in chains: + if len(chain.input_keys) != 1: + raise ValueError( + "Currently only support chains with one input variable, " + f"got {chain.input_keys}" + ) + if len(chain.output_keys) != 1: + raise ValueError( + "Currently only support chains with one output variable, " + f"got {chain.output_keys}" + ) + if names is not None: + if len(names) != len(chains): + raise ValueError("Length of chains does not match length of names.") + self.chains = chains + chain_range = [str(i) for i in range(len(self.chains))] + self.chain_colors = get_color_mapping(chain_range) + self.names = names + + @classmethod + def from_llms( + cls, llms: List[LLM], prompt: Optional[Prompt] = None + ) -> "ModelLaboratory": """Initialize with LLMs to experiment with and optional prompt. Args: @@ -18,18 +53,11 @@ class ModelLaboratory: prompt: Optional prompt to use to prompt the LLMs. Defaults to None. If a prompt was provided, it should only have one input variable. 
""" - self.llms = llms - llm_range = [str(i) for i in range(len(self.llms))] - self.llm_colors = get_color_mapping(llm_range) if prompt is None: - self.prompt = Prompt(input_variables=["_input"], template="{_input}") - else: - if len(prompt.input_variables) != 1: - raise ValueError( - "Currently only support prompts with one input variable, " - f"got {prompt}" - ) - self.prompt = prompt + prompt = Prompt(input_variables=["_input"], template="{_input}") + chains = [LLMChain(llm=llm, prompt=prompt) for llm in llms] + names = [str(llm) for llm in llms] + return cls(chains, names=names) def compare(self, text: str) -> None: """Compare model outputs on an input text. @@ -42,9 +70,11 @@ class ModelLaboratory: text: input text to run all models on. """ print(f"\033[1mInput:\033[0m\n{text}\n") - for i, llm in enumerate(self.llms): - print_text(str(llm), end="\n") - chain = LLMChain(llm=llm, prompt=self.prompt) - llm_inputs = {self.prompt.input_variables[0]: text} - output = chain.predict(**llm_inputs) - print_text(output, color=self.llm_colors[str(i)], end="\n\n") + for i, chain in enumerate(self.chains): + if self.names is not None: + name = self.names[i] + else: + name = str(chain) + print_text(name, end="\n") + output = chain.run(text) + print_text(output, color=self.chain_colors[str(i)], end="\n\n") From 8869b0ab0e94af5e950b2a935c53507d5ddd359b Mon Sep 17 00:00:00 2001 From: Harrison Chase Date: Fri, 18 Nov 2022 06:09:03 -0800 Subject: [PATCH 16/16] bump version to 0.0.16 (#157) --- langchain/VERSION | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/langchain/VERSION b/langchain/VERSION index ceddfb28f4f..e3b86dd9cc1 100644 --- a/langchain/VERSION +++ b/langchain/VERSION @@ -1 +1 @@ -0.0.15 +0.0.16