Mirror of https://github.com/hwchase17/langchain.git
	Harrison/prediction guard update (#5404)
Co-authored-by: Daniel Whitenack <whitenack.daniel@gmail.com>
@@ -14,41 +14,85 @@ There exists a Prediction Guard LLM wrapper, which you can access with
from langchain.llms import PredictionGuard
```

You can provide the name of your Prediction Guard "proxy" as an argument when initializing the LLM:
You can provide the name of the Prediction Guard model as an argument when initializing the LLM:
```python
pgllm = PredictionGuard(name="your-text-gen-proxy")
```

Alternatively, you can use Prediction Guard's default proxy for SOTA LLMs:
```python
pgllm = PredictionGuard(name="default-text-gen")
pgllm = PredictionGuard(model="MPT-7B-Instruct")
```

You can also provide your access token directly as an argument:
```python
pgllm = PredictionGuard(name="default-text-gen", token="<your access token>")
pgllm = PredictionGuard(model="MPT-7B-Instruct", token="<your access token>")
```
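
The access token can also come from the `PREDICTIONGUARD_TOKEN` environment variable instead of the `token` argument (see the wrapper's docstring further down in this commit). A minimal sketch, assuming the variable is set before the LLM is constructed:

```python
import os

from langchain.llms import PredictionGuard

# Sketch: with PREDICTIONGUARD_TOKEN set, no `token` argument is needed,
# so the secret never has to appear in code.
os.environ["PREDICTIONGUARD_TOKEN"] = "<your Prediction Guard access token>"

pgllm = PredictionGuard(model="MPT-7B-Instruct")
```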

Finally, you can provide an "output" argument that is used to structure/ control the output of the LLM:
```python
pgllm = PredictionGuard(model="MPT-7B-Instruct", output={"type": "boolean"})
```
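
The `output` argument covers other control types as well; per the comments in the examples below, Prediction Guard can constrain output to integer, float, boolean, JSON, and categorical structures. A hedged sketch, assuming an integer constraint follows the same `{"type": ...}` pattern as the boolean example above (check https://docs.predictionguard.com for the exact schema):

```python
from langchain.llms import PredictionGuard

# Assumption: integer constraints use the same {"type": ...} shape as the
# boolean example; verify the schema against the Prediction Guard docs.
pgllm_int = PredictionGuard(model="MPT-7B-Instruct", output={"type": "integer"})
```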

## Example usage

Basic usage of the LLM wrapper:
Basic usage of the controlled or guarded LLM wrapper:
```python
from langchain.llms import PredictionGuard
import os

pgllm = PredictionGuard(name="default-text-gen")
pgllm("Tell me a joke")
import predictionguard as pg
from langchain.llms import PredictionGuard
from langchain import PromptTemplate, LLMChain

# Your Prediction Guard API key. Get one at predictionguard.com
os.environ["PREDICTIONGUARD_TOKEN"] = "<your Prediction Guard access token>"

# Define a prompt template
template = """Respond to the following query based on the context.

Context: EVERY comment, DM + email suggestion has led us to this EXCITING announcement! 🎉 We have officially added TWO new candle subscription box options! 📦
Exclusive Candle Box - $80
Monthly Candle Box - $45 (NEW!)
Scent of The Month Box - $28 (NEW!)
Head to stories to get ALLL the deets on each box! 👆 BONUS: Save 50% on your first box with code 50OFF! 🎉

Query: {query}

Result: """
prompt = PromptTemplate(template=template, input_variables=["query"])

# With "guarding" or controlling the output of the LLM. See the
# Prediction Guard docs (https://docs.predictionguard.com) to learn how to
# control the output with integer, float, boolean, JSON, and other types and
# structures.
pgllm = PredictionGuard(model="MPT-7B-Instruct",
                        output={
                                "type": "categorical",
                                "categories": [
                                    "product announcement",
                                    "apology",
                                    "relational"
                                    ]
                                })
pgllm(prompt.format(query="What kind of post is this?"))
```
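
For comparison, the notebook further down in this commit also runs the same prompt without any `output` constraint. A minimal sketch of that unguarded call, reusing the `prompt` defined above (`pgllm_unguarded` is just an illustrative name):

```python
# Sketch: same template and prompt, but with no `output` constraint the raw
# completion is returned instead of one of the allowed categories.
pgllm_unguarded = PredictionGuard(model="MPT-7B-Instruct")
pgllm_unguarded(prompt.format(query="What kind of post is this?"))
```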

Basic LLM Chaining with the Prediction Guard wrapper:
```python
import os

from langchain import PromptTemplate, LLMChain
from langchain.llms import PredictionGuard

# Optional, add your OpenAI API Key. This is optional, as Prediction Guard allows
# you to access all the latest open access models (see https://docs.predictionguard.com)
os.environ["OPENAI_API_KEY"] = "<your OpenAI api key>"

# Your Prediction Guard API key. Get one at predictionguard.com
os.environ["PREDICTIONGUARD_TOKEN"] = "<your Prediction Guard access token>"

pgllm = PredictionGuard(model="OpenAI-text-davinci-003")

template = """Question: {question}

Answer: Let's think step by step."""
prompt = PromptTemplate(template=template, input_variables=["question"])
llm_chain = LLMChain(prompt=prompt, llm=PredictionGuard(name="default-text-gen"), verbose=True)
llm_chain = LLMChain(prompt=prompt, llm=pgllm, verbose=True)

question = "What NFL team won the Super Bowl in the year Justin Beiber was born?"

@@ -1,14 +1,19 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# PredictionGuard\n",
    "\n",
    "How to use PredictionGuard wrapper"
   ]
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
@@ -22,75 +27,156 @@
    },
    {
      "cell_type": "code",
   "execution_count": 1,
      "source": [
        "import os\n",
        "\n",
        "import predictionguard as pg\n",
        "from langchain.llms import PredictionGuard\n",
        "from langchain import PromptTemplate, LLMChain"
      ],
      "metadata": {
        "id": "2xe8JEUwA7_y"
      },
   "outputs": [],
   "source": [
    "import predictionguard as pg\n",
    "from langchain.llms import PredictionGuard"
   ]
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Basic LLM usage\n",
        "\n"
      ],
      "metadata": {
        "id": "mesCTyhnJkNS"
   },
   "source": [
    "## Basic LLM usage\n",
    "\n"
   ]
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Optional, add your OpenAI API Key. This is optional, as Prediction Guard allows\n",
        "# you to access all the latest open access models (see https://docs.predictionguard.com)\n",
        "os.environ[\"OPENAI_API_KEY\"] = \"<your OpenAI api key>\"\n",
        "\n",
        "# Your Prediction Guard API key. Get one at predictionguard.com\n",
        "os.environ[\"PREDICTIONGUARD_TOKEN\"] = \"<your Prediction Guard access token>\""
      ],
      "metadata": {
        "id": "kp_Ymnx1SnDG"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "pgllm = PredictionGuard(model=\"OpenAI-text-davinci-003\")"
      ],
      "metadata": {
        "id": "Ua7Mw1N4HcER"
      },
   "outputs": [],
   "source": [
    "pgllm = PredictionGuard(name=\"default-text-gen\", token=\"<your access token>\")"
   ]
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
   "execution_count": null,
      "source": [
        "pgllm(\"Tell me a joke\")"
      ],
      "metadata": {
        "id": "Qo2p5flLHxrB"
      },
   "outputs": [],
   "source": [
    "pgllm(\"Tell me a joke\")"
   ]
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
   "metadata": {
    "id": "v3MzIUItJ8kV"
   },
      "source": [
    "## Chaining"
   ]
        "# Control the output structure/ type of LLMs"
      ],
      "metadata": {
        "id": "EyBYaP_xTMXH"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "template = \"\"\"Respond to the following query based on the context.\n",
        "\n",
        "Context: EVERY comment, DM + email suggestion has led us to this EXCITING announcement! 🎉 We have officially added TWO new candle subscription box options! 📦\n",
        "Exclusive Candle Box - $80 \n",
        "Monthly Candle Box - $45 (NEW!)\n",
        "Scent of The Month Box - $28 (NEW!)\n",
        "Head to stories to get ALLL the deets on each box! 👆 BONUS: Save 50% on your first box with code 50OFF! 🎉\n",
        "\n",
        "Query: {query}\n",
        "\n",
        "Result: \"\"\"\n",
        "prompt = PromptTemplate(template=template, input_variables=[\"query\"])"
      ],
      "metadata": {
        "id": "55uxzhQSTPqF"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Without \"guarding\" or controlling the output of the LLM.\n",
        "pgllm(prompt.format(query=\"What kind of post is this?\"))"
      ],
      "metadata": {
        "id": "yersskWbTaxU"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# With \"guarding\" or controlling the output of the LLM. See the \n",
        "# Prediction Guard docs (https://docs.predictionguard.com) to learn how to \n",
        "# control the output with integer, float, boolean, JSON, and other types and\n",
        "# structures.\n",
        "pgllm = PredictionGuard(model=\"OpenAI-text-davinci-003\", \n",
        "                        output={\n",
        "                                \"type\": \"categorical\",\n",
        "                                \"categories\": [\n",
        "                                    \"product announcement\", \n",
        "                                    \"apology\", \n",
        "                                    \"relational\"\n",
        "                                    ]\n",
        "                                })\n",
        "pgllm(prompt.format(query=\"What kind of post is this?\"))"
      ],
      "metadata": {
        "id": "PzxSbYwqTm2w"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Chaining"
      ],
      "metadata": {
        "id": "v3MzIUItJ8kV"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "pgllm = PredictionGuard(model=\"OpenAI-text-davinci-003\")"
      ],
      "metadata": {
        "id": "pPegEZExILrT"
      },
   "outputs": [],
   "source": [
    "from langchain import PromptTemplate, LLMChain"
   ]
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "suxw62y-J-bg"
   },
   "outputs": [],
      "source": [
        "template = \"\"\"Question: {question}\n",
        "\n",
@@ -101,55 +187,36 @@
        "question = \"What NFL team won the Super Bowl in the year Justin Beiber was born?\"\n",
        "\n",
        "llm_chain.predict(question=question)"
   ]
      ],
      "metadata": {
        "id": "suxw62y-J-bg"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "l2bc26KHKr7n"
   },
   "outputs": [],
      "source": [
        "template = \"\"\"Write a {adjective} poem about {subject}.\"\"\"\n",
        "prompt = PromptTemplate(template=template, input_variables=[\"adjective\", \"subject\"])\n",
        "llm_chain = LLMChain(prompt=prompt, llm=pgllm, verbose=True)\n",
        "\n",
        "llm_chain.predict(adjective=\"sad\", subject=\"ducks\")"
   ]
      ],
      "metadata": {
        "id": "l2bc26KHKr7n"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
   "execution_count": null,
      "source": [],
      "metadata": {
        "id": "I--eSa2PLGqq"
      },
   "outputs": [],
   "source": []
      "execution_count": null,
      "outputs": []
    }
 ],
 "metadata": {
  "colab": {
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
  ]
}
@@ -16,15 +16,24 @@ class PredictionGuard(LLM):
    """Wrapper around Prediction Guard large language models.
    To use, you should have the ``predictionguard`` python package installed, and the
    environment variable ``PREDICTIONGUARD_TOKEN`` set with your access token, or pass
    it as a named parameter to the constructor.
    it as a named parameter to the constructor. To use Prediction Guard's API along
    with OpenAI models, set the environment variable ``OPENAI_API_KEY`` with your
    OpenAI API key as well.
    Example:
        .. code-block:: python
            pgllm = PredictionGuard(name="text-gen-proxy-name", token="my-access-token")
            pgllm = PredictionGuard(model="MPT-7B-Instruct",
                                    token="my-access-token",
                                    output={
                                        "type": "boolean"
                                    })
    """

    client: Any  #: :meta private:
    name: Optional[str] = "default-text-gen"
    """Proxy name to use."""
    model: Optional[str] = "MPT-7B-Instruct"
    """Model name to use."""

    output: Optional[Dict[str, Any]] = None
    """The output type or structure for controlling the LLM output."""

    max_tokens: int = 256
    """Denotes the number of tokens to predict per generation."""
@@ -33,6 +42,7 @@ class PredictionGuard(LLM):
    """A non-negative float that tunes the degree of randomness in generation."""

    token: Optional[str] = None
    """Your Prediction Guard access token."""

    stop: Optional[List[str]] = None

@@ -58,7 +68,7 @@ class PredictionGuard(LLM):

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Get the default parameters for calling Cohere API."""
        """Get the default parameters for calling the Prediction Guard API."""
        return {
            "max_tokens": self.max_tokens,
            "temperature": self.temperature,
@@ -67,7 +77,7 @@ class PredictionGuard(LLM):
    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Get the identifying parameters."""
        return {**{"name": self.name}, **self._default_params}
        return {**{"model": self.model}, **self._default_params}

    @property
    def _llm_type(self) -> str:
@@ -80,7 +90,7 @@ class PredictionGuard(LLM):
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
    ) -> str:
        """Call out to Prediction Guard's model proxy.
        """Call out to Prediction Guard's model API.
        Args:
            prompt: The prompt to pass into the model.
        Returns:
@@ -89,6 +99,8 @@ class PredictionGuard(LLM):
            .. code-block:: python
                response = pgllm("Tell me a joke.")
        """
        import predictionguard as pg

        params = self._default_params
        if self.stop is not None and stop is not None:
            raise ValueError("`stop` found in both the input and default params.")
@@ -97,15 +109,14 @@ class PredictionGuard(LLM):
        else:
            params["stop_sequences"] = stop

        response = self.client.predict(
            name=self.name,
            data={
                "prompt": prompt,
                "max_tokens": params["max_tokens"],
                "temperature": params["temperature"],
            },
        response = pg.Completion.create(
            model=self.model,
            prompt=prompt,
            output=self.output,
            temperature=params["temperature"],
            max_tokens=params["max_tokens"],
        )
        text = response["text"]
        text = response["choices"][0]["text"]

        # If stop tokens are provided, Prediction Guard's endpoint returns them.
        # In order to make this consistent with other endpoints, we strip them.

@@ -5,6 +5,6 @@ from langchain.llms.predictionguard import PredictionGuard

def test_predictionguard_call() -> None:
    """Test valid call to prediction guard."""
    llm = PredictionGuard(name="default-text-gen")
    llm = PredictionGuard(model="OpenAI-text-davinci-003")
    output = llm("Say foo:")
    assert isinstance(output, str)
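
A natural companion test for the new `output` argument would look something like the sketch below. It is not part of this commit; the function name and prompt are illustrative only, and it assumes the same integration-test setup (token available via `PREDICTIONGUARD_TOKEN`):

```python
def test_predictionguard_output_control() -> None:
    """Sketch: a boolean-constrained call should still return a string."""
    llm = PredictionGuard(model="OpenAI-text-davinci-003", output={"type": "boolean"})
    output = llm("Is the sky blue?")
    assert isinstance(output, str)
```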