Compare commits

...

3 Commits

Author    SHA1        Message  Date
Dev 2049  6e38bc7f2c  simple   2023-05-23 11:29:52 -07:00
Dev 2049  38b3fc80fe  nit      2023-05-18 16:11:45 -07:00
Dev 2049  4e9f6de31a  rfc      2023-05-18 16:06:03 -07:00
38 changed files with 130 additions and 110 deletions
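
These three commits rename the LLM base class to SimpleLLM and its per-prompt hook _call to _generate_single, add an async counterpart _agenerate_single, and keep _call/_acall as backwards-compatible shims that delegate to the new hooks. The provider wrappers and test fixtures below are updated mechanically to the new names; illustrative sketches of the resulting subclass API follow the langchain.llms.base diff and close out the listing.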

View File

@@ -37,7 +37,7 @@ class JsonFormer(HuggingFacePipeline):
         import_jsonformer()
         return values
 
-    def _call(
+    def _generate_single(
         self,
         prompt: str,
         stop: Optional[List[str]] = None,

View File

@@ -42,7 +42,7 @@ class RELLM(HuggingFacePipeline):
         import_rellm()
         return values
 
-    def _call(
+    def _generate_single(
         self,
         prompt: str,
         stop: Optional[List[str]] = None,

View File

@@ -5,7 +5,7 @@ import requests
 from pydantic import BaseModel, Extra, root_validator
 
 from langchain.callbacks.manager import CallbackManagerForLLMRun
-from langchain.llms.base import LLM
+from langchain.llms.base import SimpleLLM
 from langchain.utils import get_from_dict_or_env
@@ -20,7 +20,7 @@ class AI21PenaltyData(BaseModel):
     applyToEmojis: bool = True
 
 
-class AI21(LLM):
+class AI21(SimpleLLM):
     """Wrapper around AI21 large language models.
 
     To use, you should have the environment variable ``AI21_API_KEY``
@@ -107,7 +107,7 @@ class AI21(LLM):
         """Return type of llm."""
         return "ai21"
 
-    def _call(
+    def _generate_single(
         self,
         prompt: str,
         stop: Optional[List[str]] = None,

View File

@@ -4,12 +4,12 @@ from typing import Any, Dict, List, Optional, Sequence
 from pydantic import Extra, root_validator
 
 from langchain.callbacks.manager import CallbackManagerForLLMRun
-from langchain.llms.base import LLM
+from langchain.llms.base import SimpleLLM
 from langchain.llms.utils import enforce_stop_tokens
 from langchain.utils import get_from_dict_or_env
 
 
-class AlephAlpha(LLM):
+class AlephAlpha(SimpleLLM):
     """Wrapper around Aleph Alpha large language models.
 
     To use, you should have the ``aleph_alpha_client`` python package installed, and the
@@ -201,7 +201,7 @@ class AlephAlpha(LLM):
         """Return type of llm."""
         return "alpeh_alpha"
 
-    def _call(
+    def _generate_single(
         self,
         prompt: str,
         stop: Optional[List[str]] = None,

View File

@@ -9,7 +9,7 @@ from langchain.callbacks.manager import (
     AsyncCallbackManagerForLLMRun,
     CallbackManagerForLLMRun,
 )
-from langchain.llms.base import LLM
+from langchain.llms.base import SimpleLLM
 from langchain.utils import get_from_dict_or_env
@@ -104,7 +104,7 @@ class _AnthropicCommon(BaseModel):
         return self.count_tokens(text)
 
 
-class Anthropic(LLM, _AnthropicCommon):
+class Anthropic(SimpleLLM, _AnthropicCommon):
     r"""Wrapper around Anthropic's large language models.
 
     To use, you should have the ``anthropic`` python package installed, and the
@@ -162,7 +162,7 @@ class Anthropic(LLM, _AnthropicCommon):
         # As a last resort, wrap the prompt ourselves to emulate instruct-style.
         return f"{self.HUMAN_PROMPT} {prompt}{self.AI_PROMPT} Sure, here you go:\n"
 
-    def _call(
+    def _generate_single(
         self,
         prompt: str,
         stop: Optional[List[str]] = None,

View File

@@ -5,12 +5,12 @@ import requests
 from pydantic import Extra, root_validator
 
 from langchain.callbacks.manager import CallbackManagerForLLMRun
-from langchain.llms.base import LLM
+from langchain.llms.base import SimpleLLM
 from langchain.llms.utils import enforce_stop_tokens
 from langchain.utils import get_from_dict_or_env
 
 
-class Anyscale(LLM):
+class Anyscale(SimpleLLM):
     """Wrapper around Anyscale Services.
 
     To use, you should have the environment variable ``ANYSCALE_SERVICE_URL``,
     ``ANYSCALE_SERVICE_ROUTE`` and ``ANYSCALE_SERVICE_TOKEN`` set with your Anyscale
@@ -82,7 +82,7 @@ class Anyscale(LLM):
         """Return type of llm."""
         return "anyscale"
 
-    def _call(
+    def _generate_single(
         self,
         prompt: str,
         stop: Optional[List[str]] = None,

View File

@@ -5,14 +5,14 @@ from typing import Any, Dict, List, Mapping, Optional
 from pydantic import Extra, Field, root_validator
 
 from langchain.callbacks.manager import CallbackManagerForLLMRun
-from langchain.llms.base import LLM
+from langchain.llms.base import SimpleLLM
 from langchain.llms.utils import enforce_stop_tokens
 from langchain.utils import get_from_dict_or_env
 
 logger = logging.getLogger(__name__)
 
 
-class Banana(LLM):
+class Banana(SimpleLLM):
     """Wrapper around Banana large language models.
 
     To use, you should have the ``banana-dev`` python package installed,
@@ -81,7 +81,7 @@ class Banana(LLM):
         """Return type of llm."""
         return "banana"
 
-    def _call(
+    def _generate_single(
         self,
         prompt: str,
         stop: Optional[List[str]] = None,

View File

@@ -365,14 +365,32 @@ class BaseLLM(BaseLanguageModel, ABC):
             raise ValueError(f"{save_path} must be json or yaml")
 
 
-class LLM(BaseLLM):
+class SimpleLLM(BaseLLM):
     """LLM class that expect subclasses to implement a simpler call method.
 
     The purpose of this class is to expose a simpler interface for working
    with LLMs, rather than expect the user to implement the full _generate method.
     """
 
+    @abstractmethod
+    def _generate_single(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+    ) -> str:
+        """Run the LLM on a single input string and return the output as a string."""
+        raise NotImplementedError
+
+    async def _agenerate_single(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+    ) -> str:
+        """Run the LLM on a single input string and return the output as a string."""
+        raise NotImplementedError
+
+    # Kept for backwards compatibility
     def _call(
         self,
         prompt: str,
@@ -380,7 +398,9 @@ class LLM(BaseLLM):
         run_manager: Optional[CallbackManagerForLLMRun] = None,
     ) -> str:
         """Run the LLM on the given prompt and input."""
+        return self._generate_single(prompt, stop=stop, run_manager=run_manager)
 
+    # Kept for backwards compatibility
     async def _acall(
         self,
         prompt: str,
@@ -388,7 +408,7 @@ class LLM(BaseLLM):
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
     ) -> str:
         """Run the LLM on the given prompt and input."""
-        raise NotImplementedError("Async generation not implemented for this LLM.")
+        return await self._agenerate_single(prompt, stop=stop, run_manager=run_manager)
 
     def _generate(
         self,
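
The langchain.llms.base hunks above are the heart of the change: subclasses now implement _generate_single (and optionally _agenerate_single), while the retained _call/_acall shims delegate to the new hooks so existing call sites keep working. A minimal sketch of a custom model written against this branch; EchoLLM and its echoing behavior are hypothetical, for illustration only:

    from typing import List, Optional

    from langchain.callbacks.manager import CallbackManagerForLLMRun
    from langchain.llms.base import SimpleLLM


    class EchoLLM(SimpleLLM):
        """Toy model that returns the prompt unchanged."""

        @property
        def _llm_type(self) -> str:
            return "echo"

        def _generate_single(
            self,
            prompt: str,
            stop: Optional[List[str]] = None,
            run_manager: Optional[CallbackManagerForLLMRun] = None,
        ) -> str:
            # Before this change the hook was named _call; only the name differs.
            return prompt


    llm = EchoLLM()
    print(llm("hello"))  # __call__ -> _generate -> _call -> _generate_single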

View File

@@ -5,14 +5,14 @@ from typing import Any, Dict, List, Mapping, Optional
 from pydantic import Extra, Field, root_validator
 
 from langchain.callbacks.manager import CallbackManagerForLLMRun
-from langchain.llms.base import LLM
+from langchain.llms.base import SimpleLLM
 from langchain.llms.utils import enforce_stop_tokens
 from langchain.utils import get_from_dict_or_env
 
 logger = logging.getLogger(__name__)
 
 
-class CerebriumAI(LLM):
+class CerebriumAI(SimpleLLM):
     """Wrapper around CerebriumAI large language models.
 
     To use, you should have the ``cerebrium`` python package installed, and the
@@ -82,7 +82,7 @@ class CerebriumAI(LLM):
         """Return type of llm."""
         return "cerebriumai"
 
-    def _call(
+    def _generate_single(
         self,
         prompt: str,
         stop: Optional[List[str]] = None,

View File

@@ -5,14 +5,14 @@ from typing import Any, Dict, List, Optional
 from pydantic import Extra, root_validator
 
 from langchain.callbacks.manager import CallbackManagerForLLMRun
-from langchain.llms.base import LLM
+from langchain.llms.base import SimpleLLM
 from langchain.llms.utils import enforce_stop_tokens
 from langchain.utils import get_from_dict_or_env
 
 logger = logging.getLogger(__name__)
 
 
-class Cohere(LLM):
+class Cohere(SimpleLLM):
     """Wrapper around Cohere large language models.
 
     To use, you should have the ``cohere`` python package installed, and the
@@ -101,7 +101,7 @@ class Cohere(LLM):
         """Return type of llm."""
         return "cohere"
 
-    def _call(
+    def _generate_single(
         self,
         prompt: str,
         stop: Optional[List[str]] = None,

View File

@@ -5,14 +5,14 @@ import requests
 from pydantic import Extra, root_validator
 
 from langchain.callbacks.manager import CallbackManagerForLLMRun
-from langchain.llms.base import LLM
+from langchain.llms.base import SimpleLLM
 from langchain.llms.utils import enforce_stop_tokens
 from langchain.utils import get_from_dict_or_env
 
 DEFAULT_MODEL_ID = "google/flan-t5-xl"
 
 
-class DeepInfra(LLM):
+class DeepInfra(SimpleLLM):
     """Wrapper around DeepInfra deployed models.
 
     To use, you should have the ``requests`` python package installed, and the
@@ -61,7 +61,7 @@ class DeepInfra(LLM):
         """Return type of llm."""
         return "deepinfra"
 
-    def _call(
+    def _generate_single(
         self,
         prompt: str,
         stop: Optional[List[str]] = None,

View File

@@ -2,10 +2,10 @@
 from typing import Any, List, Mapping, Optional
 
 from langchain.callbacks.manager import CallbackManagerForLLMRun
-from langchain.llms.base import LLM
+from langchain.llms.base import SimpleLLM
 
 
-class FakeListLLM(LLM):
+class FakeListLLM(SimpleLLM):
     """Fake LLM wrapper for testing purposes."""
 
     responses: List
@@ -16,7 +16,7 @@ class FakeListLLM(LLM):
         """Return type of llm."""
         return "fake-list"
 
-    def _call(
+    def _generate_single(
         self,
         prompt: str,
         stop: Optional[List[str]] = None,

View File

@@ -5,12 +5,12 @@ import requests
 from pydantic import Extra, root_validator
 
 from langchain.callbacks.manager import CallbackManagerForLLMRun
-from langchain.llms.base import LLM
+from langchain.llms.base import SimpleLLM
 from langchain.llms.utils import enforce_stop_tokens
 from langchain.utils import get_from_dict_or_env
 
 
-class ForefrontAI(LLM):
+class ForefrontAI(SimpleLLM):
     """Wrapper around ForefrontAI large language models.
 
     To use, you should have the environment variable ``FOREFRONTAI_API_KEY``
@@ -82,7 +82,7 @@ class ForefrontAI(LLM):
         """Return type of llm."""
         return "forefrontai"
 
-    def _call(
+    def _generate_single(
         self,
         prompt: str,
         stop: Optional[List[str]] = None,

View File

@@ -5,13 +5,13 @@ from typing import Any, Dict, List, Mapping, Optional
 from pydantic import Extra, Field, root_validator
 
 from langchain.callbacks.manager import CallbackManagerForLLMRun
-from langchain.llms.base import LLM
+from langchain.llms.base import SimpleLLM
 from langchain.utils import get_from_dict_or_env
 
 logger = logging.getLogger(__name__)
 
 
-class GooseAI(LLM):
+class GooseAI(SimpleLLM):
     """Wrapper around OpenAI large language models.
 
     To use, you should have the ``openai`` python package installed, and the
@@ -131,7 +131,7 @@ class GooseAI(LLM):
         """Return type of llm."""
         return "gooseai"
 
-    def _call(
+    def _generate_single(
         self,
         prompt: str,
         stop: Optional[List[str]] = None,

View File

@@ -5,11 +5,11 @@ from typing import Any, Dict, List, Mapping, Optional, Set
 from pydantic import Extra, Field, root_validator
 
 from langchain.callbacks.manager import CallbackManagerForLLMRun
-from langchain.llms.base import LLM
+from langchain.llms.base import SimpleLLM
 from langchain.llms.utils import enforce_stop_tokens
 
 
-class GPT4All(LLM):
+class GPT4All(SimpleLLM):
     r"""Wrapper around GPT4All language models.
 
     To use, you should have the ``gpt4all`` python package installed, the
@@ -167,7 +167,7 @@ class GPT4All(LLM):
         """Return the type of llm."""
         return "gpt4all"
 
-    def _call(
+    def _generate_single(
         self,
         prompt: str,
         stop: Optional[List[str]] = None,

View File

@@ -5,14 +5,14 @@ import requests
 from pydantic import Extra, root_validator
 
 from langchain.callbacks.manager import CallbackManagerForLLMRun
-from langchain.llms.base import LLM
+from langchain.llms.base import SimpleLLM
 from langchain.llms.utils import enforce_stop_tokens
 from langchain.utils import get_from_dict_or_env
 
 VALID_TASKS = ("text2text-generation", "text-generation", "summarization")
 
 
-class HuggingFaceEndpoint(LLM):
+class HuggingFaceEndpoint(SimpleLLM):
     """Wrapper around HuggingFaceHub Inference Endpoints.
 
     To use, you should have the ``huggingface_hub`` python package installed, and the
@@ -91,7 +91,7 @@ class HuggingFaceEndpoint(LLM):
         """Return type of llm."""
         return "huggingface_endpoint"
 
-    def _call(
+    def _generate_single(
         self,
         prompt: str,
         stop: Optional[List[str]] = None,

View File

@@ -4,7 +4,7 @@ from typing import Any, Dict, List, Mapping, Optional
 from pydantic import Extra, root_validator
 
 from langchain.callbacks.manager import CallbackManagerForLLMRun
-from langchain.llms.base import LLM
+from langchain.llms.base import SimpleLLM
 from langchain.llms.utils import enforce_stop_tokens
 from langchain.utils import get_from_dict_or_env
@@ -12,7 +12,7 @@ DEFAULT_REPO_ID = "gpt2"
 VALID_TASKS = ("text2text-generation", "text-generation", "summarization")
 
 
-class HuggingFaceHub(LLM):
+class HuggingFaceHub(SimpleLLM):
     """Wrapper around HuggingFaceHub models.
 
     To use, you should have the ``huggingface_hub`` python package installed, and the
@@ -86,7 +86,7 @@ class HuggingFaceHub(LLM):
         """Return type of llm."""
         return "huggingface_hub"
 
-    def _call(
+    def _generate_single(
         self,
         prompt: str,
         stop: Optional[List[str]] = None,

View File

@@ -6,7 +6,7 @@ from typing import Any, List, Mapping, Optional
 from pydantic import Extra
 
 from langchain.callbacks.manager import CallbackManagerForLLMRun
-from langchain.llms.base import LLM
+from langchain.llms.base import SimpleLLM
 from langchain.llms.utils import enforce_stop_tokens
 
 DEFAULT_MODEL_ID = "gpt2"
@@ -16,7 +16,7 @@ VALID_TASKS = ("text2text-generation", "text-generation", "summarization")
 logger = logging.getLogger(__name__)
 
 
-class HuggingFacePipeline(LLM):
+class HuggingFacePipeline(SimpleLLM):
     """Wrapper around HuggingFace Pipeline API.
 
     To use, you should have the ``transformers`` python package installed.
@@ -64,7 +64,7 @@ class HuggingFacePipeline(LLM):
         device: int = -1,
         model_kwargs: Optional[dict] = None,
         **kwargs: Any,
-    ) -> LLM:
+    ) -> SimpleLLM:
         """Construct the pipeline object from model_id and task."""
         try:
             from transformers import (
@@ -150,7 +150,7 @@ class HuggingFacePipeline(LLM):
     def _llm_type(self) -> str:
         return "huggingface_pipeline"
 
-    def _call(
+    def _generate_single(
         self,
         prompt: str,
         stop: Optional[List[str]] = None,

View File

@@ -5,10 +5,10 @@ from typing import Any, Dict, List, Optional
 from pydantic import Extra, Field, root_validator
 
 from langchain.callbacks.manager import CallbackManagerForLLMRun
-from langchain.llms.base import LLM
+from langchain.llms.base import SimpleLLM
 
 
-class HuggingFaceTextGenInference(LLM):
+class HuggingFaceTextGenInference(SimpleLLM):
     """
     HuggingFace text generation inference API.
@@ -108,7 +108,7 @@ class HuggingFaceTextGenInference(LLM):
         """Return type of llm."""
         return "hf_textgen_inference"
 
-    def _call(
+    def _generate_single(
         self,
         prompt: str,
         stop: Optional[List[str]] = None,

View File

@@ -3,7 +3,7 @@ from typing import Any, Callable, List, Mapping, Optional
 from pydantic import Field
 
 from langchain.callbacks.manager import CallbackManagerForLLMRun
-from langchain.llms.base import LLM
+from langchain.llms.base import SimpleLLM
 from langchain.llms.utils import enforce_stop_tokens
@@ -32,7 +32,7 @@ def _collect_user_input(
     return multi_line_input
 
 
-class HumanInputLLM(LLM):
+class HumanInputLLM(SimpleLLM):
     """
     A LLM wrapper which returns user input as the response.
     """
@@ -55,7 +55,7 @@ class HumanInputLLM(LLM):
         """Returns the type of LLM."""
         return "human-input"
 
-    def _call(
+    def _generate_single(
         self,
         prompt: str,
         stop: Optional[List[str]] = None,

View File

@@ -5,12 +5,12 @@ from typing import Any, Dict, Generator, List, Optional
 from pydantic import Field, root_validator
 
 from langchain.callbacks.manager import CallbackManagerForLLMRun
-from langchain.llms.base import LLM
+from langchain.llms.base import SimpleLLM
 
 logger = logging.getLogger(__name__)
 
 
-class LlamaCpp(LLM):
+class LlamaCpp(SimpleLLM):
     """Wrapper around the llama.cpp model.
 
     To use, you should have the llama-cpp-python library installed, and provide the
@@ -195,7 +195,7 @@ class LlamaCpp(LLM):
         return params
 
-    def _call(
+    def _generate_single(
         self,
         prompt: str,
         stop: Optional[List[str]] = None,

View File

@@ -4,10 +4,10 @@ from typing import Any, Dict, List, Mapping, Optional
 from pydantic import Extra, root_validator
 
 from langchain.callbacks.manager import CallbackManagerForLLMRun
-from langchain.llms.base import LLM
+from langchain.llms.base import SimpleLLM
 
 
-class ManifestWrapper(LLM):
+class ManifestWrapper(SimpleLLM):
     """Wrapper around HazyResearch's Manifest library."""
 
     client: Any  #: :meta private:
@@ -43,7 +43,7 @@ class ManifestWrapper(LLM):
         """Return type of llm."""
         return "manifest"
 
-    def _call(
+    def _generate_single(
         self,
         prompt: str,
         stop: Optional[List[str]] = None,

View File

@@ -6,13 +6,13 @@ import requests
 from pydantic import Extra, Field, root_validator
 
 from langchain.callbacks.manager import CallbackManagerForLLMRun
-from langchain.llms.base import LLM
+from langchain.llms.base import SimpleLLM
 from langchain.llms.utils import enforce_stop_tokens
 
 logger = logging.getLogger(__name__)
 
 
-class Modal(LLM):
+class Modal(SimpleLLM):
     """Wrapper around Modal large language models.
 
     To use, you should have the ``modal-client`` python package installed.
@@ -70,7 +70,7 @@ class Modal(LLM):
         """Return type of llm."""
         return "modal"
 
-    def _call(
+    def _generate_single(
         self,
         prompt: str,
         stop: Optional[List[str]] = None,

View File

@@ -4,11 +4,11 @@ from typing import Any, Dict, List, Mapping, Optional
 from pydantic import Extra, root_validator
 
 from langchain.callbacks.manager import CallbackManagerForLLMRun
-from langchain.llms.base import LLM
+from langchain.llms.base import SimpleLLM
 from langchain.utils import get_from_dict_or_env
 
 
-class NLPCloud(LLM):
+class NLPCloud(SimpleLLM):
     """Wrapper around NLPCloud large language models.
 
     To use, you should have the ``nlpcloud`` python package installed, and the
@@ -112,7 +112,7 @@ class NLPCloud(LLM):
         """Return type of llm."""
         return "nlpcloud"
 
-    def _call(
+    def _generate_single(
         self,
         prompt: str,
         stop: Optional[List[str]] = None,

View File

@@ -5,14 +5,14 @@ from typing import Any, Dict, List, Mapping, Optional
 from pydantic import Extra, Field, root_validator
 
 from langchain.callbacks.manager import CallbackManagerForLLMRun
-from langchain.llms.base import LLM
+from langchain.llms.base import SimpleLLM
 from langchain.llms.utils import enforce_stop_tokens
 from langchain.utils import get_from_dict_or_env
 
 logger = logging.getLogger(__name__)
 
 
-class Petals(LLM):
+class Petals(SimpleLLM):
     """Wrapper around Petals Bloom models.
 
     To use, you should have the ``petals`` python package installed, and the
@@ -131,7 +131,7 @@ class Petals(LLM):
         """Return type of llm."""
         return "petals"
 
-    def _call(
+    def _generate_single(
         self,
         prompt: str,
         stop: Optional[List[str]] = None,

View File

@@ -5,14 +5,14 @@ from typing import Any, Dict, List, Mapping, Optional
 from pydantic import BaseModel, Extra, Field, root_validator
 
 from langchain.callbacks.manager import CallbackManagerForLLMRun
-from langchain.llms.base import LLM
+from langchain.llms.base import SimpleLLM
 from langchain.llms.utils import enforce_stop_tokens
 from langchain.utils import get_from_dict_or_env
 
 logger = logging.getLogger(__name__)
 
 
-class PipelineAI(LLM, BaseModel):
+class PipelineAI(SimpleLLM, BaseModel):
     """Wrapper around PipelineAI large language models.
 
     To use, you should have the ``pipeline-ai`` python package installed,
@@ -81,7 +81,7 @@ class PipelineAI(LLM, BaseModel):
         """Return type of llm."""
         return "pipeline_ai"
 
-    def _call(
+    def _generate_single(
         self,
         prompt: str,
         stop: Optional[List[str]] = None,

View File

@@ -5,14 +5,14 @@ from typing import Any, Dict, List, Optional
 from pydantic import Extra, root_validator
 
 from langchain.callbacks.manager import CallbackManagerForLLMRun
-from langchain.llms.base import LLM
+from langchain.llms.base import SimpleLLM
 from langchain.llms.utils import enforce_stop_tokens
 from langchain.utils import get_from_dict_or_env
 
 logger = logging.getLogger(__name__)
 
 
-class PredictionGuard(LLM):
+class PredictionGuard(SimpleLLM):
     """Wrapper around Prediction Guard large language models.
 
     To use, you should have the ``predictionguard`` python package installed, and the
     environment variable ``PREDICTIONGUARD_TOKEN`` set with your access token, or pass
@@ -74,7 +74,7 @@ class PredictionGuard(LLM):
         """Return type of llm."""
         return "predictionguard"
 
-    def _call(
+    def _generate_single(
         self,
         prompt: str,
         stop: Optional[List[str]] = None,

View File

@@ -5,13 +5,13 @@ from typing import Any, Dict, List, Mapping, Optional
 from pydantic import Extra, Field, root_validator
 
 from langchain.callbacks.manager import CallbackManagerForLLMRun
-from langchain.llms.base import LLM
+from langchain.llms.base import SimpleLLM
 from langchain.utils import get_from_dict_or_env
 
 logger = logging.getLogger(__name__)
 
 
-class Replicate(LLM):
+class Replicate(SimpleLLM):
     """Wrapper around Replicate models.
 
     To use, you should have the ``replicate`` python package installed,
@@ -79,7 +79,7 @@ class Replicate(LLM):
         """Return type of model."""
         return "replicate"
 
-    def _call(
+    def _generate_single(
         self,
         prompt: str,
         stop: Optional[List[str]] = None,

View File

@@ -8,11 +8,11 @@ from typing import Any, Dict, List, Mapping, Optional, Set
 from pydantic import BaseModel, Extra, root_validator
 
 from langchain.callbacks.manager import CallbackManagerForLLMRun
-from langchain.llms.base import LLM
+from langchain.llms.base import SimpleLLM
 from langchain.llms.utils import enforce_stop_tokens
 
 
-class RWKV(LLM, BaseModel):
+class RWKV(SimpleLLM, BaseModel):
     r"""Wrapper around RWKV language models.
 
     To use, you should have the ``rwkv`` python package installed, the
@@ -205,7 +205,7 @@ class RWKV(LLM, BaseModel):
         return decoded
 
-    def _call(
+    def _generate_single(
         self,
         prompt: str,
         stop: Optional[List[str]] = None,

View File

@@ -5,7 +5,7 @@ from typing import Any, Dict, Generic, List, Mapping, Optional, TypeVar, Union
 from pydantic import Extra, root_validator
 
 from langchain.callbacks.manager import CallbackManagerForLLMRun
-from langchain.llms.base import LLM
+from langchain.llms.base import SimpleLLM
 from langchain.llms.utils import enforce_stop_tokens
 
 INPUT_TYPE = TypeVar("INPUT_TYPE", bound=Union[str, List[str]])
@@ -61,7 +61,7 @@ class LLMContentHandler(ContentHandlerBase[str, str]):
     """Content handler for LLM class."""
 
 
-class SagemakerEndpoint(LLM):
+class SagemakerEndpoint(SimpleLLM):
     """Wrapper around custom Sagemaker Inference Endpoints.
 
     To use, you must supply the endpoint name from your deployed
@@ -202,7 +202,7 @@ class SagemakerEndpoint(LLM):
         """Return type of llm."""
         return "sagemaker_endpoint"
 
-    def _call(
+    def _generate_single(
         self,
         prompt: str,
         stop: Optional[List[str]] = None,

View File

@@ -7,7 +7,7 @@ from typing import Any, Callable, List, Mapping, Optional
 from pydantic import Extra
 
 from langchain.callbacks.manager import CallbackManagerForLLMRun
-from langchain.llms.base import LLM
+from langchain.llms.base import SimpleLLM
 from langchain.llms.utils import enforce_stop_tokens
 
 logger = logging.getLogger(__name__)
@@ -62,7 +62,7 @@ def _send_pipeline_to_device(pipeline: Any, device: int) -> Any:
     return pipeline
 
 
-class SelfHostedPipeline(LLM):
+class SelfHostedPipeline(SimpleLLM):
     """Run model inference on self-hosted remote hardware.
 
     Supported hardware includes auto-launched instances on AWS, GCP, Azure,
@@ -178,7 +178,7 @@ class SelfHostedPipeline(LLM):
         model_reqs: Optional[List[str]] = None,
         device: int = 0,
         **kwargs: Any,
-    ) -> LLM:
+    ) -> SimpleLLM:
         """Init the SelfHostedPipeline from a pipeline object or string."""
         if not isinstance(pipeline, str):
             logger.warning(
@@ -209,7 +209,7 @@ class SelfHostedPipeline(LLM):
     def _llm_type(self) -> str:
         return "self_hosted_llm"
 
-    def _call(
+    def _generate_single(
         self,
         prompt: str,
         stop: Optional[List[str]] = None,

View File

@@ -202,7 +202,7 @@ class SelfHostedHuggingFaceLLM(SelfHostedPipeline):
     def _llm_type(self) -> str:
         return "selfhosted_huggingface_pipeline"
 
-    def _call(
+    def _generate_single(
         self,
         prompt: str,
         stop: Optional[List[str]] = None,

View File

@@ -7,14 +7,14 @@ import requests
 from pydantic import Extra, Field, root_validator
 
 from langchain.callbacks.manager import CallbackManagerForLLMRun
-from langchain.llms.base import LLM
+from langchain.llms.base import SimpleLLM
 from langchain.llms.utils import enforce_stop_tokens
 from langchain.utils import get_from_dict_or_env
 
 logger = logging.getLogger(__name__)
 
 
-class StochasticAI(LLM):
+class StochasticAI(SimpleLLM):
     """Wrapper around StochasticAI large language models.
 
     To use, you should have the environment variable ``STOCHASTICAI_API_KEY``
@@ -81,7 +81,7 @@ class StochasticAI(LLM):
         """Return type of llm."""
         return "stochasticai"
 
-    def _call(
+    def _generate_single(
         self,
         prompt: str,
         stop: Optional[List[str]] = None,

View File

@@ -5,12 +5,12 @@ import requests
 from pydantic import Extra, root_validator
 
 from langchain.callbacks.manager import CallbackManagerForLLMRun
-from langchain.llms.base import LLM
+from langchain.llms.base import SimpleLLM
 from langchain.llms.utils import enforce_stop_tokens
 from langchain.utils import get_from_dict_or_env
 
 
-class Writer(LLM):
+class Writer(SimpleLLM):
     """Wrapper around Writer large language models.
 
     To use, you should have the environment variable ``WRITER_API_KEY`` and
@@ -113,7 +113,7 @@ class Writer(LLM):
         """Return type of llm."""
         return "writer"
 
-    def _call(
+    def _generate_single(
         self,
         prompt: str,
         stop: Optional[List[str]] = None,

View File

@@ -5,17 +5,17 @@ from typing import Any, List, Mapping, Optional
 from langchain.agents import AgentExecutor, AgentType, initialize_agent
 from langchain.agents.tools import Tool
 from langchain.callbacks.manager import CallbackManagerForLLMRun
-from langchain.llms.base import LLM
+from langchain.llms.base import SimpleLLM
 from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
 
 
-class FakeListLLM(LLM):
+class FakeListLLM(SimpleLLM):
     """Fake LLM for testing that outputs elements of a list."""
 
     responses: List[str]
     i: int = -1
 
-    def _call(
+    def _generate_single(
         self,
         prompt: str,
         stop: Optional[List[str]] = None,

View File

@@ -7,7 +7,7 @@ from langchain.agents.tools import Tool
 from langchain.callbacks.manager import CallbackManagerForLLMRun
 from langchain.docstore.base import Docstore
 from langchain.docstore.document import Document
-from langchain.llms.base import LLM
+from langchain.llms.base import SimpleLLM
 from langchain.prompts.prompt import PromptTemplate
 from langchain.schema import AgentAction
@@ -22,7 +22,7 @@ Made in 2022."""
 _FAKE_PROMPT = PromptTemplate(input_variables=["input"], template="{input}")
 
 
-class FakeListLLM(LLM):
+class FakeListLLM(SimpleLLM):
     """Fake LLM for testing that outputs elements of a list."""
 
     responses: List[str]
@@ -33,7 +33,7 @@ class FakeListLLM(LLM):
         """Return type of llm."""
         return "fake_list"
 
-    def _call(
+    def _generate_single(
         self,
         prompt: str,
         stop: Optional[List[str]] = None,

View File

@@ -4,13 +4,13 @@ from typing import Any, List, Mapping, Optional
 from langchain.callbacks.manager import CallbackManagerForLLMRun
 from langchain.chains.natbot.base import NatBotChain
-from langchain.llms.base import LLM
+from langchain.llms.base import SimpleLLM
 
 
-class FakeLLM(LLM):
+class FakeLLM(SimpleLLM):
     """Fake LLM wrapper for testing purposes."""
 
-    def _call(
+    def _generate_single(
         self,
         prompt: str,
         stop: Optional[List[str]] = None,

View File

@@ -4,10 +4,10 @@ from typing import Any, List, Mapping, Optional, cast
 from pydantic import validator
 
 from langchain.callbacks.manager import CallbackManagerForLLMRun
-from langchain.llms.base import LLM
+from langchain.llms.base import SimpleLLM
 
 
-class FakeLLM(LLM):
+class FakeLLM(SimpleLLM):
     """Fake LLM wrapper for testing purposes."""
 
     queries: Optional[Mapping] = None
@@ -29,7 +29,7 @@ class FakeLLM(LLM):
         """Return type of llm."""
         return "fake"
 
-    def _call(
+    def _generate_single(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
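
The same pattern extends to async use. A hedged sketch, assuming agenerate reaches the _acall shim shown in the langchain.llms.base diff; AsyncEchoLLM is hypothetical, for illustration only:

    import asyncio
    from typing import List, Optional

    from langchain.callbacks.manager import (
        AsyncCallbackManagerForLLMRun,
        CallbackManagerForLLMRun,
    )
    from langchain.llms.base import SimpleLLM


    class AsyncEchoLLM(SimpleLLM):
        """Toy model implementing both single-prompt hooks."""

        @property
        def _llm_type(self) -> str:
            return "async-echo"

        def _generate_single(
            self,
            prompt: str,
            stop: Optional[List[str]] = None,
            run_manager: Optional[CallbackManagerForLLMRun] = None,
        ) -> str:
            return prompt

        async def _agenerate_single(
            self,
            prompt: str,
            stop: Optional[List[str]] = None,
            run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        ) -> str:
            # Overriding this replaces the default raise NotImplementedError,
            # so the backwards-compatible _acall shim becomes usable.
            return prompt


    result = asyncio.run(AsyncEchoLLM().agenerate(["hello"]))
    print(result.generations[0][0].text)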