From 76cecf8165b5c9d6c2b0b81a81361a9d550cb4c7 Mon Sep 17 00:00:00 2001 From: Delip Rao Date: Mon, 14 Nov 2022 11:34:01 -0500 Subject: [PATCH] A fix for Jupyter environment variable issue (#135) - fixes the Jupyter environment variable issues mentioned in issue #134 - fixes format/lint issues in some unrelated files (from make format/lint) ![image](https://user-images.githubusercontent.com/347398/201599322-090af858-362d-4d69-bf59-208aea65419a.png) --- langchain/chains/mapreduce.py | 4 +--- langchain/chains/natbot/crawler.py | 8 +------- langchain/chains/react/prompt.py | 6 +----- langchain/chains/self_ask_with_search/prompt.py | 5 +---- langchain/chains/sql_database/prompt.py | 3 +-- langchain/embeddings/cohere.py | 8 +++++--- langchain/embeddings/openai.py | 8 +++++--- langchain/llms/ai21.py | 13 ++++--------- langchain/llms/cohere.py | 9 +++++---- langchain/llms/huggingface_hub.py | 11 ++++++----- langchain/llms/nlpcloud.py | 8 +++++--- langchain/llms/openai.py | 8 +++++--- langchain/llms/utils.py | 10 +++++++++- langchain/vectorstores/elastic_vector_search.py | 5 +---- tests/integration_tests/llms/test_manifest.py | 4 +--- 15 files changed, 51 insertions(+), 59 deletions(-) diff --git a/langchain/chains/mapreduce.py b/langchain/chains/mapreduce.py index 0a88f945c61..4a2cc6417e0 100644 --- a/langchain/chains/mapreduce.py +++ b/langchain/chains/mapreduce.py @@ -59,9 +59,7 @@ class MapReduceChain(Chain, BaseModel): def _call(self, inputs: Dict[str, str]) -> Dict[str, str]: # Split the larger text into smaller chunks. - docs = self.text_splitter.split_text( - inputs[self.input_key], - ) + docs = self.text_splitter.split_text(inputs[self.input_key],) # Now that we have the chunks, we send them to the LLM and track results. # This is the "map" part. summaries = [] diff --git a/langchain/chains/natbot/crawler.py b/langchain/chains/natbot/crawler.py index 341b890b2fa..ec14d24973b 100644 --- a/langchain/chains/natbot/crawler.py +++ b/langchain/chains/natbot/crawler.py @@ -28,13 +28,7 @@ class Crawler: "Could not import playwright python package. " "Please it install it with `pip install playwright`." ) - self.browser = ( - sync_playwright() - .start() - .chromium.launch( - headless=False, - ) - ) + self.browser = sync_playwright().start().chromium.launch(headless=False,) self.page = self.browser.new_page() self.page.set_viewport_size({"width": 1280, "height": 1080}) diff --git a/langchain/chains/react/prompt.py b/langchain/chains/react/prompt.py index e0e16299f86..8f7f55d20f6 100644 --- a/langchain/chains/react/prompt.py +++ b/langchain/chains/react/prompt.py @@ -109,8 +109,4 @@ Action 3: Finish[yes]""", ] SUFFIX = """\n\nQuestion: {input}""" -PROMPT = Prompt.from_examples( - EXAMPLES, - SUFFIX, - ["input"], -) +PROMPT = Prompt.from_examples(EXAMPLES, SUFFIX, ["input"],) diff --git a/langchain/chains/self_ask_with_search/prompt.py b/langchain/chains/self_ask_with_search/prompt.py index 003e68dd7fd..4c7fff87b1a 100644 --- a/langchain/chains/self_ask_with_search/prompt.py +++ b/langchain/chains/self_ask_with_search/prompt.py @@ -38,7 +38,4 @@ Intermediate Answer: New Zealand. So the final answer is: No Question: {input}""" -PROMPT = Prompt( - input_variables=["input"], - template=_DEFAULT_TEMPLATE, -) +PROMPT = Prompt(input_variables=["input"], template=_DEFAULT_TEMPLATE,) diff --git a/langchain/chains/sql_database/prompt.py b/langchain/chains/sql_database/prompt.py index c35c92e4b57..4532ad24596 100644 --- a/langchain/chains/sql_database/prompt.py +++ b/langchain/chains/sql_database/prompt.py @@ -15,6 +15,5 @@ Only use the following tables: Question: {input}""" PROMPT = Prompt( - input_variables=["input", "table_info", "dialect"], - template=_DEFAULT_TEMPLATE, + input_variables=["input", "table_info", "dialect"], template=_DEFAULT_TEMPLATE, ) diff --git a/langchain/embeddings/cohere.py b/langchain/embeddings/cohere.py index 3ff641bfb02..f98622491ef 100644 --- a/langchain/embeddings/cohere.py +++ b/langchain/embeddings/cohere.py @@ -1,10 +1,10 @@ """Wrapper around Cohere embedding models.""" -import os from typing import Any, Dict, List, Optional from pydantic import BaseModel, Extra, root_validator from langchain.embeddings.base import Embeddings +from langchain.llms.utils import get_from_dict_or_env class CohereEmbeddings(BaseModel, Embeddings): @@ -25,7 +25,7 @@ class CohereEmbeddings(BaseModel, Embeddings): model: str = "medium" """Model name to use.""" - cohere_api_key: Optional[str] = os.environ.get("COHERE_API_KEY") + cohere_api_key: Optional[str] = None class Config: """Configuration for this pydantic object.""" @@ -35,7 +35,9 @@ class CohereEmbeddings(BaseModel, Embeddings): @root_validator() def validate_environment(cls, values: Dict) -> Dict: """Validate that api key and python package exists in environment.""" - cohere_api_key = values.get("cohere_api_key") + cohere_api_key = get_from_dict_or_env( + values, "cohere_api_key", "COHERE_API_KEY" + ) if cohere_api_key is None or cohere_api_key == "": raise ValueError( diff --git a/langchain/embeddings/openai.py b/langchain/embeddings/openai.py index 0bb97e58ae0..a7366ab614b 100644 --- a/langchain/embeddings/openai.py +++ b/langchain/embeddings/openai.py @@ -1,10 +1,10 @@ """Wrapper around OpenAI embedding models.""" -import os from typing import Any, Dict, List, Optional from pydantic import BaseModel, Extra, root_validator from langchain.embeddings.base import Embeddings +from langchain.llms.utils import get_from_dict_or_env class OpenAIEmbeddings(BaseModel, Embeddings): @@ -25,7 +25,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings): model_name: str = "babbage" """Model name to use.""" - openai_api_key: Optional[str] = os.environ.get("OPENAI_API_KEY") + openai_api_key: Optional[str] = None class Config: """Configuration for this pydantic object.""" @@ -35,7 +35,9 @@ class OpenAIEmbeddings(BaseModel, Embeddings): @root_validator() def validate_environment(cls, values: Dict) -> Dict: """Validate that api key and python package exists in environment.""" - openai_api_key = values.get("openai_api_key") + openai_api_key = get_from_dict_or_env( + values, "openai_api_key", "OPENAI_API_KEY" + ) if openai_api_key is None or openai_api_key == "": raise ValueError( diff --git a/langchain/llms/ai21.py b/langchain/llms/ai21.py index b1dac08b32f..8c4ead45974 100644 --- a/langchain/llms/ai21.py +++ b/langchain/llms/ai21.py @@ -1,11 +1,11 @@ """Wrapper around AI21 APIs.""" -import os from typing import Any, Dict, List, Mapping, Optional import requests from pydantic import BaseModel, Extra, root_validator from langchain.llms.base import LLM +from langchain.llms.utils import get_from_dict_or_env class AI21PenaltyData(BaseModel): @@ -62,7 +62,7 @@ class AI21(BaseModel, LLM): logitBias: Optional[Dict[str, float]] = None """Adjust the probability of specific tokens being generated.""" - ai21_api_key: Optional[str] = os.environ.get("AI21_API_KEY") + ai21_api_key: Optional[str] = None class Config: """Configuration for this pydantic object.""" @@ -72,8 +72,7 @@ class AI21(BaseModel, LLM): @root_validator() def validate_environment(cls, values: Dict) -> Dict: """Validate that api key exists in environment.""" - ai21_api_key = values.get("ai21_api_key") - + ai21_api_key = get_from_dict_or_env(values, "ai21_api_key", "AI21_API_KEY") if ai21_api_key is None or ai21_api_key == "": raise ValueError( "Did not find AI21 API key, please add an environment variable" @@ -122,11 +121,7 @@ class AI21(BaseModel, LLM): response = requests.post( url=f"https://api.ai21.com/studio/v1/{self.model}/complete", headers={"Authorization": f"Bearer {self.ai21_api_key}"}, - json={ - "prompt": prompt, - "stopSequences": stop, - **self._default_params, - }, + json={"prompt": prompt, "stopSequences": stop, **self._default_params,}, ) if response.status_code != 200: optional_detail = response.json().get("error") diff --git a/langchain/llms/cohere.py b/langchain/llms/cohere.py index 2a41b807ea3..fed3b56e8ed 100644 --- a/langchain/llms/cohere.py +++ b/langchain/llms/cohere.py @@ -1,11 +1,10 @@ """Wrapper around Cohere APIs.""" -import os from typing import Any, Dict, List, Mapping, Optional from pydantic import BaseModel, Extra, root_validator from langchain.llms.base import LLM -from langchain.llms.utils import enforce_stop_tokens +from langchain.llms.utils import enforce_stop_tokens, get_from_dict_or_env class Cohere(LLM, BaseModel): @@ -44,7 +43,7 @@ class Cohere(LLM, BaseModel): presence_penalty: int = 0 """Penalizes repeated tokens.""" - cohere_api_key: Optional[str] = os.environ.get("COHERE_API_KEY") + cohere_api_key: Optional[str] = None class Config: """Configuration for this pydantic object.""" @@ -54,7 +53,9 @@ class Cohere(LLM, BaseModel): @root_validator() def validate_environment(cls, values: Dict) -> Dict: """Validate that api key and python package exists in environment.""" - cohere_api_key = values.get("cohere_api_key") + cohere_api_key = get_from_dict_or_env( + values, "cohere_api_key", "COHERE_API_KEY" + ) if cohere_api_key is None or cohere_api_key == "": raise ValueError( diff --git a/langchain/llms/huggingface_hub.py b/langchain/llms/huggingface_hub.py index 8d584558ae9..1cea3d555cb 100644 --- a/langchain/llms/huggingface_hub.py +++ b/langchain/llms/huggingface_hub.py @@ -1,11 +1,10 @@ """Wrapper around HuggingFace APIs.""" -import os from typing import Any, Dict, List, Mapping, Optional from pydantic import BaseModel, Extra, root_validator from langchain.llms.base import LLM -from langchain.llms.utils import enforce_stop_tokens +from langchain.llms.utils import enforce_stop_tokens, get_from_dict_or_env DEFAULT_REPO_ID = "gpt2" VALID_TASKS = ("text2text-generation", "text-generation") @@ -18,7 +17,7 @@ class HuggingFaceHub(LLM, BaseModel): environment variable ``HUGGINGFACEHUB_API_TOKEN`` set with your API token, or pass it as a named parameter to the constructor. - Only supports task `text-generation` for now. + Only supports `text-generation` and `text2text-generation` for now. Example: .. code-block:: python @@ -35,7 +34,7 @@ class HuggingFaceHub(LLM, BaseModel): model_kwargs: Optional[dict] = None """Key word arguments to pass to the model.""" - huggingfacehub_api_token: Optional[str] = os.environ.get("HUGGINGFACEHUB_API_TOKEN") + huggingfacehub_api_token: Optional[str] = None class Config: """Configuration for this pydantic object.""" @@ -45,7 +44,9 @@ class HuggingFaceHub(LLM, BaseModel): @root_validator() def validate_environment(cls, values: Dict) -> Dict: """Validate that api key and python package exists in environment.""" - huggingfacehub_api_token = values.get("huggingfacehub_api_token") + huggingfacehub_api_token = get_from_dict_or_env( + values, "huggingfacehub_api_token", "HUGGINGFACEHUB_API_TOKEN" + ) if huggingfacehub_api_token is None or huggingfacehub_api_token == "": raise ValueError( "Did not find HuggingFace API token, please add an environment variable" diff --git a/langchain/llms/nlpcloud.py b/langchain/llms/nlpcloud.py index 771cc815a56..e4c37ff7032 100644 --- a/langchain/llms/nlpcloud.py +++ b/langchain/llms/nlpcloud.py @@ -1,10 +1,10 @@ """Wrapper around NLPCloud APIs.""" -import os from typing import Any, Dict, List, Mapping, Optional from pydantic import BaseModel, Extra, root_validator from langchain.llms.base import LLM +from langchain.llms.utils import get_from_dict_or_env class NLPCloud(LLM, BaseModel): @@ -54,7 +54,7 @@ class NLPCloud(LLM, BaseModel): num_return_sequences: int = 1 """How many completions to generate for each prompt.""" - nlpcloud_api_key: Optional[str] = os.environ.get("NLPCLOUD_API_KEY") + nlpcloud_api_key: Optional[str] = None class Config: """Configuration for this pydantic object.""" @@ -64,7 +64,9 @@ class NLPCloud(LLM, BaseModel): @root_validator() def validate_environment(cls, values: Dict) -> Dict: """Validate that api key and python package exists in environment.""" - nlpcloud_api_key = values.get("nlpcloud_api_key") + nlpcloud_api_key = get_from_dict_or_env( + values, "nlpcloud_api_key", "NLPCLOUD_API_KEY" + ) if nlpcloud_api_key is None or nlpcloud_api_key == "": raise ValueError( diff --git a/langchain/llms/openai.py b/langchain/llms/openai.py index 2355015b64d..f0127efb518 100644 --- a/langchain/llms/openai.py +++ b/langchain/llms/openai.py @@ -1,10 +1,10 @@ """Wrapper around OpenAI APIs.""" -import os from typing import Any, Dict, List, Mapping, Optional from pydantic import BaseModel, Extra, root_validator from langchain.llms.base import LLM +from langchain.llms.utils import get_from_dict_or_env class OpenAI(LLM, BaseModel): @@ -38,7 +38,7 @@ class OpenAI(LLM, BaseModel): best_of: int = 1 """Generates best_of completions server-side and returns the "best".""" - openai_api_key: Optional[str] = os.environ.get("OPENAI_API_KEY") + openai_api_key: Optional[str] = None class Config: """Configuration for this pydantic object.""" @@ -48,7 +48,9 @@ class OpenAI(LLM, BaseModel): @root_validator() def validate_environment(cls, values: Dict) -> Dict: """Validate that api key and python package exists in environment.""" - openai_api_key = values.get("openai_api_key") + openai_api_key = get_from_dict_or_env( + values, "openai_api_key", "OPENAI_API_KEY" + ) if openai_api_key is None or openai_api_key == "": raise ValueError( diff --git a/langchain/llms/utils.py b/langchain/llms/utils.py index a42fd130ee6..ef7407ab621 100644 --- a/langchain/llms/utils.py +++ b/langchain/llms/utils.py @@ -1,8 +1,16 @@ """Common utility functions for working with LLM APIs.""" +import os import re -from typing import List +from typing import Any, Dict, List def enforce_stop_tokens(text: str, stop: List[str]) -> str: """Cut off the text as soon as any stop words occur.""" return re.split("|".join(stop), text)[0] + + +def get_from_dict_or_env(data: Dict[str, Any], key: str, env_key: str) -> Any: + """Get a value from a dictionary or an environment variable.""" + if key in data: + return data[key] + return os.environ.get(env_key, None) diff --git a/langchain/vectorstores/elastic_vector_search.py b/langchain/vectorstores/elastic_vector_search.py index d8aed5f67d2..a0a623467a8 100644 --- a/langchain/vectorstores/elastic_vector_search.py +++ b/langchain/vectorstores/elastic_vector_search.py @@ -45,10 +45,7 @@ class ElasticVectorSearch(VectorStore): """ def __init__( - self, - elasticsearch_url: str, - index_name: str, - embedding_function: Callable, + self, elasticsearch_url: str, index_name: str, embedding_function: Callable, ): """Initialize with necessary components.""" try: diff --git a/tests/integration_tests/llms/test_manifest.py b/tests/integration_tests/llms/test_manifest.py index 3416a332862..41c54cc21b1 100644 --- a/tests/integration_tests/llms/test_manifest.py +++ b/tests/integration_tests/llms/test_manifest.py @@ -6,9 +6,7 @@ def test_manifest_wrapper() -> None: """Test manifest wrapper.""" from manifest import Manifest - manifest = Manifest( - client_name="openai", - ) + manifest = Manifest(client_name="openai",) llm = ManifestWrapper(client=manifest, llm_kwargs={"temperature": 0}) output = llm("The capital of New York is:") assert output == "Albany"