mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-26 08:33:49 +00:00
Merge branch 'master' into harrison/prompts_take_2
This commit is contained in:
commit
bf3a9973f0
@ -1 +1 @@
|
||||
0.0.13
|
||||
0.0.14
|
||||
|
@ -9,6 +9,7 @@ from typing import Any, Dict, List, Optional
|
||||
from pydantic import BaseModel, Extra, root_validator
|
||||
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
|
||||
|
||||
class HiddenPrints:
|
||||
@ -43,7 +44,7 @@ class SerpAPIChain(Chain, BaseModel):
|
||||
input_key: str = "search_query" #: :meta private:
|
||||
output_key: str = "search_result" #: :meta private:
|
||||
|
||||
serpapi_api_key: Optional[str] = os.environ.get("SERPAPI_API_KEY")
|
||||
serpapi_api_key: Optional[str] = None
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
@ -69,14 +70,10 @@ class SerpAPIChain(Chain, BaseModel):
|
||||
@root_validator()
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate that api key and python package exists in environment."""
|
||||
serpapi_api_key = values.get("serpapi_api_key")
|
||||
|
||||
if serpapi_api_key is None or serpapi_api_key == "":
|
||||
raise ValueError(
|
||||
"Did not find SerpAPI API key, please add an environment variable"
|
||||
" `SERPAPI_API_KEY` which contains it, or pass `serpapi_api_key` "
|
||||
"as a named parameter to the constructor."
|
||||
)
|
||||
serpapi_api_key = get_from_dict_or_env(
|
||||
values, "serpapi_api_key", "SERPAPI_API_KEY"
|
||||
)
|
||||
values["serpapi_api_key"] = serpapi_api_key
|
||||
try:
|
||||
from serpapi import GoogleSearch
|
||||
|
||||
|
@ -4,7 +4,7 @@ from typing import Any, Dict, List, Optional
|
||||
from pydantic import BaseModel, Extra, root_validator
|
||||
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.llms.utils import get_from_dict_or_env
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
|
||||
|
||||
class CohereEmbeddings(BaseModel, Embeddings):
|
||||
@ -38,13 +38,6 @@ class CohereEmbeddings(BaseModel, Embeddings):
|
||||
cohere_api_key = get_from_dict_or_env(
|
||||
values, "cohere_api_key", "COHERE_API_KEY"
|
||||
)
|
||||
|
||||
if cohere_api_key is None or cohere_api_key == "":
|
||||
raise ValueError(
|
||||
"Did not find Cohere API key, please add an environment variable"
|
||||
" `COHERE_API_KEY` which contains it, or pass `cohere_api_key` as a"
|
||||
" named parameter."
|
||||
)
|
||||
try:
|
||||
import cohere
|
||||
|
||||
|
@ -4,7 +4,7 @@ from typing import Any, Dict, List, Optional
|
||||
from pydantic import BaseModel, Extra, root_validator
|
||||
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.llms.utils import get_from_dict_or_env
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
|
||||
|
||||
class OpenAIEmbeddings(BaseModel, Embeddings):
|
||||
@ -38,13 +38,6 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
|
||||
openai_api_key = get_from_dict_or_env(
|
||||
values, "openai_api_key", "OPENAI_API_KEY"
|
||||
)
|
||||
|
||||
if openai_api_key is None or openai_api_key == "":
|
||||
raise ValueError(
|
||||
"Did not find OpenAI API key, please add an environment variable"
|
||||
" `OPENAI_API_KEY` which contains it, or pass `openai_api_key` as a"
|
||||
" named parameter."
|
||||
)
|
||||
try:
|
||||
import openai
|
||||
|
||||
|
@ -5,7 +5,7 @@ import requests
|
||||
from pydantic import BaseModel, Extra, root_validator
|
||||
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.llms.utils import get_from_dict_or_env
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
|
||||
|
||||
class AI21PenaltyData(BaseModel):
|
||||
@ -73,12 +73,7 @@ class AI21(BaseModel, LLM):
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate that api key exists in environment."""
|
||||
ai21_api_key = get_from_dict_or_env(values, "ai21_api_key", "AI21_API_KEY")
|
||||
if ai21_api_key is None or ai21_api_key == "":
|
||||
raise ValueError(
|
||||
"Did not find AI21 API key, please add an environment variable"
|
||||
" `AI21_API_KEY` which contains it, or pass `ai21_api_key`"
|
||||
" as a named parameter."
|
||||
)
|
||||
values["ai21_api_key"] = ai21_api_key
|
||||
return values
|
||||
|
||||
@property
|
||||
|
@ -4,7 +4,8 @@ from typing import Any, Dict, List, Mapping, Optional
|
||||
from pydantic import BaseModel, Extra, root_validator
|
||||
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.llms.utils import enforce_stop_tokens, get_from_dict_or_env
|
||||
from langchain.llms.utils import enforce_stop_tokens
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
|
||||
|
||||
class Cohere(LLM, BaseModel):
|
||||
@ -56,13 +57,6 @@ class Cohere(LLM, BaseModel):
|
||||
cohere_api_key = get_from_dict_or_env(
|
||||
values, "cohere_api_key", "COHERE_API_KEY"
|
||||
)
|
||||
|
||||
if cohere_api_key is None or cohere_api_key == "":
|
||||
raise ValueError(
|
||||
"Did not find Cohere API key, please add an environment variable"
|
||||
" `COHERE_API_KEY` which contains it, or pass `cohere_api_key`"
|
||||
" as a named parameter."
|
||||
)
|
||||
try:
|
||||
import cohere
|
||||
|
||||
|
@ -4,7 +4,8 @@ from typing import Any, Dict, List, Mapping, Optional
|
||||
from pydantic import BaseModel, Extra, root_validator
|
||||
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.llms.utils import enforce_stop_tokens, get_from_dict_or_env
|
||||
from langchain.llms.utils import enforce_stop_tokens
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
|
||||
DEFAULT_REPO_ID = "gpt2"
|
||||
VALID_TASKS = ("text2text-generation", "text-generation")
|
||||
@ -47,12 +48,6 @@ class HuggingFaceHub(LLM, BaseModel):
|
||||
huggingfacehub_api_token = get_from_dict_or_env(
|
||||
values, "huggingfacehub_api_token", "HUGGINGFACEHUB_API_TOKEN"
|
||||
)
|
||||
if huggingfacehub_api_token is None or huggingfacehub_api_token == "":
|
||||
raise ValueError(
|
||||
"Did not find HuggingFace API token, please add an environment variable"
|
||||
" `HUGGINGFACEHUB_API_TOKEN` which contains it, or pass"
|
||||
" `huggingfacehub_api_token` as a named parameter."
|
||||
)
|
||||
try:
|
||||
from huggingface_hub.inference_api import InferenceApi
|
||||
|
||||
|
@ -4,7 +4,7 @@ from typing import Any, Dict, List, Mapping, Optional
|
||||
from pydantic import BaseModel, Extra, root_validator
|
||||
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.llms.utils import get_from_dict_or_env
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
|
||||
|
||||
class NLPCloud(LLM, BaseModel):
|
||||
@ -67,13 +67,6 @@ class NLPCloud(LLM, BaseModel):
|
||||
nlpcloud_api_key = get_from_dict_or_env(
|
||||
values, "nlpcloud_api_key", "NLPCLOUD_API_KEY"
|
||||
)
|
||||
|
||||
if nlpcloud_api_key is None or nlpcloud_api_key == "":
|
||||
raise ValueError(
|
||||
"Did not find NLPCloud API key, please add an environment variable"
|
||||
" `NLPCLOUD_API_KEY` which contains it, or pass `nlpcloud_api_key`"
|
||||
" as a named parameter."
|
||||
)
|
||||
try:
|
||||
import nlpcloud
|
||||
|
||||
|
@ -4,7 +4,7 @@ from typing import Any, Dict, List, Mapping, Optional
|
||||
from pydantic import BaseModel, Extra, root_validator
|
||||
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.llms.utils import get_from_dict_or_env
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
|
||||
|
||||
class OpenAI(LLM, BaseModel):
|
||||
@ -51,13 +51,6 @@ class OpenAI(LLM, BaseModel):
|
||||
openai_api_key = get_from_dict_or_env(
|
||||
values, "openai_api_key", "OPENAI_API_KEY"
|
||||
)
|
||||
|
||||
if openai_api_key is None or openai_api_key == "":
|
||||
raise ValueError(
|
||||
"Did not find OpenAI API key, please add an environment variable"
|
||||
" `OPENAI_API_KEY` which contains it, or pass `openai_api_key`"
|
||||
" as a named parameter."
|
||||
)
|
||||
try:
|
||||
import openai
|
||||
|
||||
|
@ -1,16 +1,8 @@
|
||||
"""Common utility functions for working with LLM APIs."""
|
||||
import os
|
||||
import re
|
||||
from typing import Any, Dict, List
|
||||
from typing import List
|
||||
|
||||
|
||||
def enforce_stop_tokens(text: str, stop: List[str]) -> str:
    """Cut off the text as soon as any stop words occur.

    Args:
        text: The generated text to truncate.
        stop: Stop sequences; the text is cut at the earliest occurrence
            of any of them.

    Returns:
        The prefix of ``text`` before the first stop sequence (the whole
        text if no stop sequence occurs).
    """
    # Escape each stop word so regex metacharacters (e.g. "?", ".", "(")
    # are matched literally instead of being interpreted as regex syntax.
    return re.split("|".join(re.escape(s) for s in stop), text)[0]
|
||||
|
||||
|
||||
def get_from_dict_or_env(data: Dict[str, Any], key: str, env_key: str) -> Any:
    """Get a value from a dictionary or an environment variable.

    Args:
        data: Mapping to look ``key`` up in first.
        key: Dictionary key to prefer.
        env_key: Environment variable consulted when ``key`` is absent.

    Returns:
        ``data[key]`` when present (even if falsy); otherwise the value of
        the ``env_key`` environment variable, or ``None`` when unset.
    """
    try:
        return data[key]
    except KeyError:
        return os.environ.get(env_key)
|
||||
|
@ -94,8 +94,7 @@ class Prompt(BaseModel, BasePrompt):
|
||||
Returns:
|
||||
The final prompt generated.
|
||||
"""
|
||||
example_str = example_separator.join(examples)
|
||||
template = prefix + example_str + suffix
|
||||
template = example_separator.join([prefix, *examples, suffix])
|
||||
return cls(input_variables=input_variables, template=template)
|
||||
|
||||
@classmethod
|
||||
|
17
langchain/utils.py
Normal file
17
langchain/utils.py
Normal file
@ -0,0 +1,17 @@
|
||||
"""Generic utility functions."""
|
||||
import os
|
||||
from typing import Any, Dict
|
||||
|
||||
|
||||
def get_from_dict_or_env(data: Dict[str, Any], key: str, env_key: str) -> str:
    """Get a value from a dictionary or an environment variable.

    Args:
        data: Mapping to look ``key`` up in first.
        key: Dictionary key to prefer.
        env_key: Environment variable used as a fallback.

    Returns:
        The first truthy value found: ``data[key]``, then the ``env_key``
        environment variable.

    Raises:
        ValueError: If neither source yields a truthy value.
    """
    value = data.get(key) or os.environ.get(env_key)
    if value:
        return value
    raise ValueError(
        f"Did not find {key}, please add an environment variable"
        f" `{env_key}` which contains it, or pass"
        f" `{key}` as a named parameter."
    )
|
@ -1,10 +1,10 @@
|
||||
"""Wrapper around Elasticsearch vector database."""
|
||||
import os
|
||||
import uuid
|
||||
from typing import Any, Callable, Dict, List
|
||||
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
from langchain.vectorstores.base import VectorStore
|
||||
|
||||
|
||||
@ -107,16 +107,9 @@ class ElasticVectorSearch(VectorStore):
|
||||
elasticsearch_url="http://localhost:9200"
|
||||
)
|
||||
"""
|
||||
elasticsearch_url = kwargs.get("elasticsearch_url")
|
||||
if not elasticsearch_url:
|
||||
elasticsearch_url = os.environ.get("ELASTICSEARCH_URL")
|
||||
|
||||
if elasticsearch_url is None or elasticsearch_url == "":
|
||||
raise ValueError(
|
||||
"Did not find Elasticsearch URL, please add an environment variable"
|
||||
" `ELASTICSEARCH_URL` which contains it, or pass"
|
||||
" `elasticsearch_url` as a named parameter."
|
||||
)
|
||||
elasticsearch_url = get_from_dict_or_env(
|
||||
kwargs, "elasticsearch_url", "ELASTICSEARCH_URL"
|
||||
)
|
||||
try:
|
||||
import elasticsearch
|
||||
from elasticsearch.helpers import bulk
|
||||
|
@ -51,8 +51,8 @@ Question: {question}
|
||||
Answer:"""
|
||||
input_variables = ["question"]
|
||||
example_separator = "\n\n"
|
||||
prefix = """Test Prompt:\n\n"""
|
||||
suffix = """\n\nQuestion: {question}\nAnswer:"""
|
||||
prefix = """Test Prompt:"""
|
||||
suffix = """Question: {question}\nAnswer:"""
|
||||
examples = [
|
||||
"""Question: who are you?\nAnswer: foo""",
|
||||
"""Question: what are you?\nAnswer: bar""",
|
||||
|
Loading…
Reference in New Issue
Block a user