mirror of
https://github.com/hwchase17/langchain.git
synced 2026-01-22 11:11:48 +00:00
Compare commits
3 Commits
langchain-
...
dev2049/ch
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6e38bc7f2c | ||
|
|
38b3fc80fe | ||
|
|
4e9f6de31a |
@@ -37,7 +37,7 @@ class JsonFormer(HuggingFacePipeline):
|
||||
import_jsonformer()
|
||||
return values
|
||||
|
||||
def _call(
|
||||
def _generate_single(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
|
||||
@@ -42,7 +42,7 @@ class RELLM(HuggingFacePipeline):
|
||||
import_rellm()
|
||||
return values
|
||||
|
||||
def _call(
|
||||
def _generate_single(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
|
||||
@@ -5,7 +5,7 @@ import requests
|
||||
from pydantic import BaseModel, Extra, root_validator
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.llms.base import SimpleLLM
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
|
||||
|
||||
@@ -20,7 +20,7 @@ class AI21PenaltyData(BaseModel):
|
||||
applyToEmojis: bool = True
|
||||
|
||||
|
||||
class AI21(LLM):
|
||||
class AI21(SimpleLLM):
|
||||
"""Wrapper around AI21 large language models.
|
||||
|
||||
To use, you should have the environment variable ``AI21_API_KEY``
|
||||
@@ -107,7 +107,7 @@ class AI21(LLM):
|
||||
"""Return type of llm."""
|
||||
return "ai21"
|
||||
|
||||
def _call(
|
||||
def _generate_single(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
|
||||
@@ -4,12 +4,12 @@ from typing import Any, Dict, List, Optional, Sequence
|
||||
from pydantic import Extra, root_validator
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.llms.base import SimpleLLM
|
||||
from langchain.llms.utils import enforce_stop_tokens
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
|
||||
|
||||
class AlephAlpha(LLM):
|
||||
class AlephAlpha(SimpleLLM):
|
||||
"""Wrapper around Aleph Alpha large language models.
|
||||
|
||||
To use, you should have the ``aleph_alpha_client`` python package installed, and the
|
||||
@@ -201,7 +201,7 @@ class AlephAlpha(LLM):
|
||||
"""Return type of llm."""
|
||||
return "alpeh_alpha"
|
||||
|
||||
def _call(
|
||||
def _generate_single(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
|
||||
@@ -9,7 +9,7 @@ from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
CallbackManagerForLLMRun,
|
||||
)
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.llms.base import SimpleLLM
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
|
||||
|
||||
@@ -104,7 +104,7 @@ class _AnthropicCommon(BaseModel):
|
||||
return self.count_tokens(text)
|
||||
|
||||
|
||||
class Anthropic(LLM, _AnthropicCommon):
|
||||
class Anthropic(SimpleLLM, _AnthropicCommon):
|
||||
r"""Wrapper around Anthropic's large language models.
|
||||
|
||||
To use, you should have the ``anthropic`` python package installed, and the
|
||||
@@ -162,7 +162,7 @@ class Anthropic(LLM, _AnthropicCommon):
|
||||
# As a last resort, wrap the prompt ourselves to emulate instruct-style.
|
||||
return f"{self.HUMAN_PROMPT} {prompt}{self.AI_PROMPT} Sure, here you go:\n"
|
||||
|
||||
def _call(
|
||||
def _generate_single(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
|
||||
@@ -5,12 +5,12 @@ import requests
|
||||
from pydantic import Extra, root_validator
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.llms.base import SimpleLLM
|
||||
from langchain.llms.utils import enforce_stop_tokens
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
|
||||
|
||||
class Anyscale(LLM):
|
||||
class Anyscale(SimpleLLM):
|
||||
"""Wrapper around Anyscale Services.
|
||||
To use, you should have the environment variable ``ANYSCALE_SERVICE_URL``,
|
||||
``ANYSCALE_SERVICE_ROUTE`` and ``ANYSCALE_SERVICE_TOKEN`` set with your Anyscale
|
||||
@@ -82,7 +82,7 @@ class Anyscale(LLM):
|
||||
"""Return type of llm."""
|
||||
return "anyscale"
|
||||
|
||||
def _call(
|
||||
def _generate_single(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
|
||||
@@ -5,14 +5,14 @@ from typing import Any, Dict, List, Mapping, Optional
|
||||
from pydantic import Extra, Field, root_validator
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.llms.base import SimpleLLM
|
||||
from langchain.llms.utils import enforce_stop_tokens
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Banana(LLM):
|
||||
class Banana(SimpleLLM):
|
||||
"""Wrapper around Banana large language models.
|
||||
|
||||
To use, you should have the ``banana-dev`` python package installed,
|
||||
@@ -81,7 +81,7 @@ class Banana(LLM):
|
||||
"""Return type of llm."""
|
||||
return "banana"
|
||||
|
||||
def _call(
|
||||
def _generate_single(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
|
||||
@@ -365,14 +365,32 @@ class BaseLLM(BaseLanguageModel, ABC):
|
||||
raise ValueError(f"{save_path} must be json or yaml")
|
||||
|
||||
|
||||
class LLM(BaseLLM):
|
||||
class SimpleLLM(BaseLLM):
|
||||
"""LLM class that expect subclasses to implement a simpler call method.
|
||||
|
||||
The purpose of this class is to expose a simpler interface for working
|
||||
with LLMs, rather than expect the user to implement the full _generate method.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def _generate_single(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
) -> str:
|
||||
"""Run the LLM on a single input string and return the output as a string."""
|
||||
raise NotImplementedError
|
||||
|
||||
async def _agenerate_single(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
) -> str:
|
||||
"""Run the LLM on a single input string and return the output as a string."""
|
||||
raise NotImplementedError
|
||||
|
||||
# Kept for backwards compatibility
|
||||
def _call(
|
||||
self,
|
||||
prompt: str,
|
||||
@@ -380,7 +398,9 @@ class LLM(BaseLLM):
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
) -> str:
|
||||
"""Run the LLM on the given prompt and input."""
|
||||
return self._generate_single(prompt, stop=stop, run_manager=run_manager)
|
||||
|
||||
# Kept for backwards compatibility
|
||||
async def _acall(
|
||||
self,
|
||||
prompt: str,
|
||||
@@ -388,7 +408,7 @@ class LLM(BaseLLM):
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
) -> str:
|
||||
"""Run the LLM on the given prompt and input."""
|
||||
raise NotImplementedError("Async generation not implemented for this LLM.")
|
||||
return await self._agenerate_single(prompt, stop=stop, run_manager=run_manager)
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
|
||||
@@ -5,14 +5,14 @@ from typing import Any, Dict, List, Mapping, Optional
|
||||
from pydantic import Extra, Field, root_validator
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.llms.base import SimpleLLM
|
||||
from langchain.llms.utils import enforce_stop_tokens
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CerebriumAI(LLM):
|
||||
class CerebriumAI(SimpleLLM):
|
||||
"""Wrapper around CerebriumAI large language models.
|
||||
|
||||
To use, you should have the ``cerebrium`` python package installed, and the
|
||||
@@ -82,7 +82,7 @@ class CerebriumAI(LLM):
|
||||
"""Return type of llm."""
|
||||
return "cerebriumai"
|
||||
|
||||
def _call(
|
||||
def _generate_single(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
|
||||
@@ -5,14 +5,14 @@ from typing import Any, Dict, List, Optional
|
||||
from pydantic import Extra, root_validator
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.llms.base import SimpleLLM
|
||||
from langchain.llms.utils import enforce_stop_tokens
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Cohere(LLM):
|
||||
class Cohere(SimpleLLM):
|
||||
"""Wrapper around Cohere large language models.
|
||||
|
||||
To use, you should have the ``cohere`` python package installed, and the
|
||||
@@ -101,7 +101,7 @@ class Cohere(LLM):
|
||||
"""Return type of llm."""
|
||||
return "cohere"
|
||||
|
||||
def _call(
|
||||
def _generate_single(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
|
||||
@@ -5,14 +5,14 @@ import requests
|
||||
from pydantic import Extra, root_validator
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.llms.base import SimpleLLM
|
||||
from langchain.llms.utils import enforce_stop_tokens
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
|
||||
DEFAULT_MODEL_ID = "google/flan-t5-xl"
|
||||
|
||||
|
||||
class DeepInfra(LLM):
|
||||
class DeepInfra(SimpleLLM):
|
||||
"""Wrapper around DeepInfra deployed models.
|
||||
|
||||
To use, you should have the ``requests`` python package installed, and the
|
||||
@@ -61,7 +61,7 @@ class DeepInfra(LLM):
|
||||
"""Return type of llm."""
|
||||
return "deepinfra"
|
||||
|
||||
def _call(
|
||||
def _generate_single(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
|
||||
@@ -2,10 +2,10 @@
|
||||
from typing import Any, List, Mapping, Optional
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.llms.base import SimpleLLM
|
||||
|
||||
|
||||
class FakeListLLM(LLM):
|
||||
class FakeListLLM(SimpleLLM):
|
||||
"""Fake LLM wrapper for testing purposes."""
|
||||
|
||||
responses: List
|
||||
@@ -16,7 +16,7 @@ class FakeListLLM(LLM):
|
||||
"""Return type of llm."""
|
||||
return "fake-list"
|
||||
|
||||
def _call(
|
||||
def _generate_single(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
|
||||
@@ -5,12 +5,12 @@ import requests
|
||||
from pydantic import Extra, root_validator
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.llms.base import SimpleLLM
|
||||
from langchain.llms.utils import enforce_stop_tokens
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
|
||||
|
||||
class ForefrontAI(LLM):
|
||||
class ForefrontAI(SimpleLLM):
|
||||
"""Wrapper around ForefrontAI large language models.
|
||||
|
||||
To use, you should have the environment variable ``FOREFRONTAI_API_KEY``
|
||||
@@ -82,7 +82,7 @@ class ForefrontAI(LLM):
|
||||
"""Return type of llm."""
|
||||
return "forefrontai"
|
||||
|
||||
def _call(
|
||||
def _generate_single(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
|
||||
@@ -5,13 +5,13 @@ from typing import Any, Dict, List, Mapping, Optional
|
||||
from pydantic import Extra, Field, root_validator
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.llms.base import SimpleLLM
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class GooseAI(LLM):
|
||||
class GooseAI(SimpleLLM):
|
||||
"""Wrapper around OpenAI large language models.
|
||||
|
||||
To use, you should have the ``openai`` python package installed, and the
|
||||
@@ -131,7 +131,7 @@ class GooseAI(LLM):
|
||||
"""Return type of llm."""
|
||||
return "gooseai"
|
||||
|
||||
def _call(
|
||||
def _generate_single(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
|
||||
@@ -5,11 +5,11 @@ from typing import Any, Dict, List, Mapping, Optional, Set
|
||||
from pydantic import Extra, Field, root_validator
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.llms.base import SimpleLLM
|
||||
from langchain.llms.utils import enforce_stop_tokens
|
||||
|
||||
|
||||
class GPT4All(LLM):
|
||||
class GPT4All(SimpleLLM):
|
||||
r"""Wrapper around GPT4All language models.
|
||||
|
||||
To use, you should have the ``gpt4all`` python package installed, the
|
||||
@@ -167,7 +167,7 @@ class GPT4All(LLM):
|
||||
"""Return the type of llm."""
|
||||
return "gpt4all"
|
||||
|
||||
def _call(
|
||||
def _generate_single(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
|
||||
@@ -5,14 +5,14 @@ import requests
|
||||
from pydantic import Extra, root_validator
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.llms.base import SimpleLLM
|
||||
from langchain.llms.utils import enforce_stop_tokens
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
|
||||
VALID_TASKS = ("text2text-generation", "text-generation", "summarization")
|
||||
|
||||
|
||||
class HuggingFaceEndpoint(LLM):
|
||||
class HuggingFaceEndpoint(SimpleLLM):
|
||||
"""Wrapper around HuggingFaceHub Inference Endpoints.
|
||||
|
||||
To use, you should have the ``huggingface_hub`` python package installed, and the
|
||||
@@ -91,7 +91,7 @@ class HuggingFaceEndpoint(LLM):
|
||||
"""Return type of llm."""
|
||||
return "huggingface_endpoint"
|
||||
|
||||
def _call(
|
||||
def _generate_single(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
|
||||
@@ -4,7 +4,7 @@ from typing import Any, Dict, List, Mapping, Optional
|
||||
from pydantic import Extra, root_validator
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.llms.base import SimpleLLM
|
||||
from langchain.llms.utils import enforce_stop_tokens
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
|
||||
@@ -12,7 +12,7 @@ DEFAULT_REPO_ID = "gpt2"
|
||||
VALID_TASKS = ("text2text-generation", "text-generation", "summarization")
|
||||
|
||||
|
||||
class HuggingFaceHub(LLM):
|
||||
class HuggingFaceHub(SimpleLLM):
|
||||
"""Wrapper around HuggingFaceHub models.
|
||||
|
||||
To use, you should have the ``huggingface_hub`` python package installed, and the
|
||||
@@ -86,7 +86,7 @@ class HuggingFaceHub(LLM):
|
||||
"""Return type of llm."""
|
||||
return "huggingface_hub"
|
||||
|
||||
def _call(
|
||||
def _generate_single(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
|
||||
@@ -6,7 +6,7 @@ from typing import Any, List, Mapping, Optional
|
||||
from pydantic import Extra
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.llms.base import SimpleLLM
|
||||
from langchain.llms.utils import enforce_stop_tokens
|
||||
|
||||
DEFAULT_MODEL_ID = "gpt2"
|
||||
@@ -16,7 +16,7 @@ VALID_TASKS = ("text2text-generation", "text-generation", "summarization")
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class HuggingFacePipeline(LLM):
|
||||
class HuggingFacePipeline(SimpleLLM):
|
||||
"""Wrapper around HuggingFace Pipeline API.
|
||||
|
||||
To use, you should have the ``transformers`` python package installed.
|
||||
@@ -64,7 +64,7 @@ class HuggingFacePipeline(LLM):
|
||||
device: int = -1,
|
||||
model_kwargs: Optional[dict] = None,
|
||||
**kwargs: Any,
|
||||
) -> LLM:
|
||||
) -> SimpleLLM:
|
||||
"""Construct the pipeline object from model_id and task."""
|
||||
try:
|
||||
from transformers import (
|
||||
@@ -150,7 +150,7 @@ class HuggingFacePipeline(LLM):
|
||||
def _llm_type(self) -> str:
|
||||
return "huggingface_pipeline"
|
||||
|
||||
def _call(
|
||||
def _generate_single(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
|
||||
@@ -5,10 +5,10 @@ from typing import Any, Dict, List, Optional
|
||||
from pydantic import Extra, Field, root_validator
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.llms.base import SimpleLLM
|
||||
|
||||
|
||||
class HuggingFaceTextGenInference(LLM):
|
||||
class HuggingFaceTextGenInference(SimpleLLM):
|
||||
"""
|
||||
HuggingFace text generation inference API.
|
||||
|
||||
@@ -108,7 +108,7 @@ class HuggingFaceTextGenInference(LLM):
|
||||
"""Return type of llm."""
|
||||
return "hf_textgen_inference"
|
||||
|
||||
def _call(
|
||||
def _generate_single(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
|
||||
@@ -3,7 +3,7 @@ from typing import Any, Callable, List, Mapping, Optional
|
||||
from pydantic import Field
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.llms.base import SimpleLLM
|
||||
from langchain.llms.utils import enforce_stop_tokens
|
||||
|
||||
|
||||
@@ -32,7 +32,7 @@ def _collect_user_input(
|
||||
return multi_line_input
|
||||
|
||||
|
||||
class HumanInputLLM(LLM):
|
||||
class HumanInputLLM(SimpleLLM):
|
||||
"""
|
||||
A LLM wrapper which returns user input as the response.
|
||||
"""
|
||||
@@ -55,7 +55,7 @@ class HumanInputLLM(LLM):
|
||||
"""Returns the type of LLM."""
|
||||
return "human-input"
|
||||
|
||||
def _call(
|
||||
def _generate_single(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
|
||||
@@ -5,12 +5,12 @@ from typing import Any, Dict, Generator, List, Optional
|
||||
from pydantic import Field, root_validator
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.llms.base import SimpleLLM
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LlamaCpp(LLM):
|
||||
class LlamaCpp(SimpleLLM):
|
||||
"""Wrapper around the llama.cpp model.
|
||||
|
||||
To use, you should have the llama-cpp-python library installed, and provide the
|
||||
@@ -195,7 +195,7 @@ class LlamaCpp(LLM):
|
||||
|
||||
return params
|
||||
|
||||
def _call(
|
||||
def _generate_single(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
|
||||
@@ -4,10 +4,10 @@ from typing import Any, Dict, List, Mapping, Optional
|
||||
from pydantic import Extra, root_validator
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.llms.base import SimpleLLM
|
||||
|
||||
|
||||
class ManifestWrapper(LLM):
|
||||
class ManifestWrapper(SimpleLLM):
|
||||
"""Wrapper around HazyResearch's Manifest library."""
|
||||
|
||||
client: Any #: :meta private:
|
||||
@@ -43,7 +43,7 @@ class ManifestWrapper(LLM):
|
||||
"""Return type of llm."""
|
||||
return "manifest"
|
||||
|
||||
def _call(
|
||||
def _generate_single(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
|
||||
@@ -6,13 +6,13 @@ import requests
|
||||
from pydantic import Extra, Field, root_validator
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.llms.base import SimpleLLM
|
||||
from langchain.llms.utils import enforce_stop_tokens
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Modal(LLM):
|
||||
class Modal(SimpleLLM):
|
||||
"""Wrapper around Modal large language models.
|
||||
|
||||
To use, you should have the ``modal-client`` python package installed.
|
||||
@@ -70,7 +70,7 @@ class Modal(LLM):
|
||||
"""Return type of llm."""
|
||||
return "modal"
|
||||
|
||||
def _call(
|
||||
def _generate_single(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
|
||||
@@ -4,11 +4,11 @@ from typing import Any, Dict, List, Mapping, Optional
|
||||
from pydantic import Extra, root_validator
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.llms.base import SimpleLLM
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
|
||||
|
||||
class NLPCloud(LLM):
|
||||
class NLPCloud(SimpleLLM):
|
||||
"""Wrapper around NLPCloud large language models.
|
||||
|
||||
To use, you should have the ``nlpcloud`` python package installed, and the
|
||||
@@ -112,7 +112,7 @@ class NLPCloud(LLM):
|
||||
"""Return type of llm."""
|
||||
return "nlpcloud"
|
||||
|
||||
def _call(
|
||||
def _generate_single(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
|
||||
@@ -5,14 +5,14 @@ from typing import Any, Dict, List, Mapping, Optional
|
||||
from pydantic import Extra, Field, root_validator
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.llms.base import SimpleLLM
|
||||
from langchain.llms.utils import enforce_stop_tokens
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Petals(LLM):
|
||||
class Petals(SimpleLLM):
|
||||
"""Wrapper around Petals Bloom models.
|
||||
|
||||
To use, you should have the ``petals`` python package installed, and the
|
||||
@@ -131,7 +131,7 @@ class Petals(LLM):
|
||||
"""Return type of llm."""
|
||||
return "petals"
|
||||
|
||||
def _call(
|
||||
def _generate_single(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
|
||||
@@ -5,14 +5,14 @@ from typing import Any, Dict, List, Mapping, Optional
|
||||
from pydantic import BaseModel, Extra, Field, root_validator
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.llms.base import SimpleLLM
|
||||
from langchain.llms.utils import enforce_stop_tokens
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PipelineAI(LLM, BaseModel):
|
||||
class PipelineAI(SimpleLLM, BaseModel):
|
||||
"""Wrapper around PipelineAI large language models.
|
||||
|
||||
To use, you should have the ``pipeline-ai`` python package installed,
|
||||
@@ -81,7 +81,7 @@ class PipelineAI(LLM, BaseModel):
|
||||
"""Return type of llm."""
|
||||
return "pipeline_ai"
|
||||
|
||||
def _call(
|
||||
def _generate_single(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
|
||||
@@ -5,14 +5,14 @@ from typing import Any, Dict, List, Optional
|
||||
from pydantic import Extra, root_validator
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.llms.base import SimpleLLM
|
||||
from langchain.llms.utils import enforce_stop_tokens
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PredictionGuard(LLM):
|
||||
class PredictionGuard(SimpleLLM):
|
||||
"""Wrapper around Prediction Guard large language models.
|
||||
To use, you should have the ``predictionguard`` python package installed, and the
|
||||
environment variable ``PREDICTIONGUARD_TOKEN`` set with your access token, or pass
|
||||
@@ -74,7 +74,7 @@ class PredictionGuard(LLM):
|
||||
"""Return type of llm."""
|
||||
return "predictionguard"
|
||||
|
||||
def _call(
|
||||
def _generate_single(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
|
||||
@@ -5,13 +5,13 @@ from typing import Any, Dict, List, Mapping, Optional
|
||||
from pydantic import Extra, Field, root_validator
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.llms.base import SimpleLLM
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Replicate(LLM):
|
||||
class Replicate(SimpleLLM):
|
||||
"""Wrapper around Replicate models.
|
||||
|
||||
To use, you should have the ``replicate`` python package installed,
|
||||
@@ -79,7 +79,7 @@ class Replicate(LLM):
|
||||
"""Return type of model."""
|
||||
return "replicate"
|
||||
|
||||
def _call(
|
||||
def _generate_single(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
|
||||
@@ -8,11 +8,11 @@ from typing import Any, Dict, List, Mapping, Optional, Set
|
||||
from pydantic import BaseModel, Extra, root_validator
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.llms.base import SimpleLLM
|
||||
from langchain.llms.utils import enforce_stop_tokens
|
||||
|
||||
|
||||
class RWKV(LLM, BaseModel):
|
||||
class RWKV(SimpleLLM, BaseModel):
|
||||
r"""Wrapper around RWKV language models.
|
||||
|
||||
To use, you should have the ``rwkv`` python package installed, the
|
||||
@@ -205,7 +205,7 @@ class RWKV(LLM, BaseModel):
|
||||
|
||||
return decoded
|
||||
|
||||
def _call(
|
||||
def _generate_single(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
|
||||
@@ -5,7 +5,7 @@ from typing import Any, Dict, Generic, List, Mapping, Optional, TypeVar, Union
|
||||
from pydantic import Extra, root_validator
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.llms.base import SimpleLLM
|
||||
from langchain.llms.utils import enforce_stop_tokens
|
||||
|
||||
INPUT_TYPE = TypeVar("INPUT_TYPE", bound=Union[str, List[str]])
|
||||
@@ -61,7 +61,7 @@ class LLMContentHandler(ContentHandlerBase[str, str]):
|
||||
"""Content handler for LLM class."""
|
||||
|
||||
|
||||
class SagemakerEndpoint(LLM):
|
||||
class SagemakerEndpoint(SimpleLLM):
|
||||
"""Wrapper around custom Sagemaker Inference Endpoints.
|
||||
|
||||
To use, you must supply the endpoint name from your deployed
|
||||
@@ -202,7 +202,7 @@ class SagemakerEndpoint(LLM):
|
||||
"""Return type of llm."""
|
||||
return "sagemaker_endpoint"
|
||||
|
||||
def _call(
|
||||
def _generate_single(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
|
||||
@@ -7,7 +7,7 @@ from typing import Any, Callable, List, Mapping, Optional
|
||||
from pydantic import Extra
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.llms.base import SimpleLLM
|
||||
from langchain.llms.utils import enforce_stop_tokens
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -62,7 +62,7 @@ def _send_pipeline_to_device(pipeline: Any, device: int) -> Any:
|
||||
return pipeline
|
||||
|
||||
|
||||
class SelfHostedPipeline(LLM):
|
||||
class SelfHostedPipeline(SimpleLLM):
|
||||
"""Run model inference on self-hosted remote hardware.
|
||||
|
||||
Supported hardware includes auto-launched instances on AWS, GCP, Azure,
|
||||
@@ -178,7 +178,7 @@ class SelfHostedPipeline(LLM):
|
||||
model_reqs: Optional[List[str]] = None,
|
||||
device: int = 0,
|
||||
**kwargs: Any,
|
||||
) -> LLM:
|
||||
) -> SimpleLLM:
|
||||
"""Init the SelfHostedPipeline from a pipeline object or string."""
|
||||
if not isinstance(pipeline, str):
|
||||
logger.warning(
|
||||
@@ -209,7 +209,7 @@ class SelfHostedPipeline(LLM):
|
||||
def _llm_type(self) -> str:
|
||||
return "self_hosted_llm"
|
||||
|
||||
def _call(
|
||||
def _generate_single(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
|
||||
@@ -202,7 +202,7 @@ class SelfHostedHuggingFaceLLM(SelfHostedPipeline):
|
||||
def _llm_type(self) -> str:
|
||||
return "selfhosted_huggingface_pipeline"
|
||||
|
||||
def _call(
|
||||
def _generate_single(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
|
||||
@@ -7,14 +7,14 @@ import requests
|
||||
from pydantic import Extra, Field, root_validator
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.llms.base import SimpleLLM
|
||||
from langchain.llms.utils import enforce_stop_tokens
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class StochasticAI(LLM):
|
||||
class StochasticAI(SimpleLLM):
|
||||
"""Wrapper around StochasticAI large language models.
|
||||
|
||||
To use, you should have the environment variable ``STOCHASTICAI_API_KEY``
|
||||
@@ -81,7 +81,7 @@ class StochasticAI(LLM):
|
||||
"""Return type of llm."""
|
||||
return "stochasticai"
|
||||
|
||||
def _call(
|
||||
def _generate_single(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
|
||||
@@ -5,12 +5,12 @@ import requests
|
||||
from pydantic import Extra, root_validator
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.llms.base import SimpleLLM
|
||||
from langchain.llms.utils import enforce_stop_tokens
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
|
||||
|
||||
class Writer(LLM):
|
||||
class Writer(SimpleLLM):
|
||||
"""Wrapper around Writer large language models.
|
||||
|
||||
To use, you should have the environment variable ``WRITER_API_KEY`` and
|
||||
@@ -113,7 +113,7 @@ class Writer(LLM):
|
||||
"""Return type of llm."""
|
||||
return "writer"
|
||||
|
||||
def _call(
|
||||
def _generate_single(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
|
||||
@@ -5,17 +5,17 @@ from typing import Any, List, Mapping, Optional
|
||||
from langchain.agents import AgentExecutor, AgentType, initialize_agent
|
||||
from langchain.agents.tools import Tool
|
||||
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.llms.base import SimpleLLM
|
||||
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
|
||||
|
||||
|
||||
class FakeListLLM(LLM):
|
||||
class FakeListLLM(SimpleLLM):
|
||||
"""Fake LLM for testing that outputs elements of a list."""
|
||||
|
||||
responses: List[str]
|
||||
i: int = -1
|
||||
|
||||
def _call(
|
||||
def _generate_single(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
|
||||
@@ -7,7 +7,7 @@ from langchain.agents.tools import Tool
|
||||
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
||||
from langchain.docstore.base import Docstore
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.llms.base import SimpleLLM
|
||||
from langchain.prompts.prompt import PromptTemplate
|
||||
from langchain.schema import AgentAction
|
||||
|
||||
@@ -22,7 +22,7 @@ Made in 2022."""
|
||||
_FAKE_PROMPT = PromptTemplate(input_variables=["input"], template="{input}")
|
||||
|
||||
|
||||
class FakeListLLM(LLM):
|
||||
class FakeListLLM(SimpleLLM):
|
||||
"""Fake LLM for testing that outputs elements of a list."""
|
||||
|
||||
responses: List[str]
|
||||
@@ -33,7 +33,7 @@ class FakeListLLM(LLM):
|
||||
"""Return type of llm."""
|
||||
return "fake_list"
|
||||
|
||||
def _call(
|
||||
def _generate_single(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
|
||||
@@ -4,13 +4,13 @@ from typing import Any, List, Mapping, Optional
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
||||
from langchain.chains.natbot.base import NatBotChain
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.llms.base import SimpleLLM
|
||||
|
||||
|
||||
class FakeLLM(LLM):
|
||||
class FakeLLM(SimpleLLM):
|
||||
"""Fake LLM wrapper for testing purposes."""
|
||||
|
||||
def _call(
|
||||
def _generate_single(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
|
||||
@@ -4,10 +4,10 @@ from typing import Any, List, Mapping, Optional, cast
|
||||
from pydantic import validator
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.llms.base import SimpleLLM
|
||||
|
||||
|
||||
class FakeLLM(LLM):
|
||||
class FakeLLM(SimpleLLM):
|
||||
"""Fake LLM wrapper for testing purposes."""
|
||||
|
||||
queries: Optional[Mapping] = None
|
||||
@@ -29,7 +29,7 @@ class FakeLLM(LLM):
|
||||
"""Return type of llm."""
|
||||
return "fake"
|
||||
|
||||
def _call(
|
||||
def _generate_single(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
|
||||
Reference in New Issue
Block a user