From 595cc1ae1ac786e7a9ae940e4ffc1a65871a948f Mon Sep 17 00:00:00 2001
From: Harrison Chase
Date: Tue, 13 Dec 2022 05:50:03 -0800
Subject: [PATCH] RFC: more complete return (#313)

Co-authored-by: Andrew Williamson
Co-authored-by: awilliamson10
---
 langchain/llms/base.py   | 30 ++++++++++++++++++++-
 langchain/llms/openai.py | 56 +++++++++++++++++++++++++++++++---------
 2 files changed, 73 insertions(+), 13 deletions(-)

diff --git a/langchain/llms/base.py b/langchain/llms/base.py
index 60728841bed..367aa503c43 100644
--- a/langchain/llms/base.py
+++ b/langchain/llms/base.py
@@ -1,11 +1,39 @@
 """Base interface for large language models to expose."""
 from abc import ABC, abstractmethod
-from typing import Any, List, Mapping, Optional
+from typing import Any, List, Mapping, NamedTuple, Optional
+
+
+class Generation(NamedTuple):
+    """Output of a single generation."""
+
+    text: str
+    """Generated text output."""
+    # TODO: add log probs
+
+
+class LLMResult(NamedTuple):
+    """Class that contains all relevant information for an LLM Result."""
+
+    generations: List[List[Generation]]
+    """List of the things generated. This is List[List[]] because
+    each input could have multiple generations."""
+    llm_output: Optional[dict] = None
+    """For arbitrary LLM provider specific output."""
 
 
 class LLM(ABC):
     """LLM wrapper should take in a prompt and return a string."""
 
+    def generate(
+        self, prompts: List[str], stop: Optional[List[str]] = None
+    ) -> LLMResult:
+        """Run the LLM on the given prompt and input."""
+        generations = []
+        for prompt in prompts:
+            text = self(prompt, stop=stop)
+            generations.append([Generation(text=text)])
+        return LLMResult(generations=generations)
+
     def get_num_tokens(self, text: str) -> int:
         """Get the number of tokens present in the text."""
         # TODO: this method may not be exact.
diff --git a/langchain/llms/openai.py b/langchain/llms/openai.py
index 9de56f0ce8c..fd88d873fee 100644
--- a/langchain/llms/openai.py
+++ b/langchain/llms/openai.py
@@ -3,7 +3,7 @@ from typing import Any, Dict, List, Mapping, Optional
 
 from pydantic import BaseModel, Extra, Field, root_validator
 
-from langchain.llms.base import LLM
+from langchain.llms.base import LLM, Generation, LLMResult
 from langchain.utils import get_from_dict_or_env
 
 
@@ -97,6 +97,48 @@ class OpenAI(LLM, BaseModel):
         }
         return {**normal_params, **self.model_kwargs}
 
+    def generate(
+        self, prompts: List[str], stop: Optional[List[str]] = None
+    ) -> LLMResult:
+        """Call out to OpenAI's endpoint with k unique prompts.
+
+        Args:
+            prompts: The prompts to pass into the model.
+            stop: Optional list of stop words to use when generating.
+
+        Returns:
+            The full LLM output.
+
+        Example:
+            .. code-block:: python
+
+                response = openai.generate(["Tell me a joke."])
+        """
+        params = self._default_params
+        if stop is not None:
+            if "stop" in params:
+                raise ValueError("`stop` found in both the input and default params.")
+            params["stop"] = stop
+
+        if params["max_tokens"] == -1:
+            if len(prompts) != 1:
+                raise ValueError(
+                    "max_tokens set to -1 not supported for multiple inputs."
+                )
+            params["max_tokens"] = self.max_tokens_for_prompt(prompts[0])
+
+        response = self.client.create(model=self.model_name, prompt=prompts, **params)
+        generations = []
+        for i, prompt in enumerate(prompts):
+            choices = response["choices"][i * self.n : (i + 1) * self.n]
+            generations.append([Generation(text=choice["text"]) for choice in choices])
+        # Get the token usage from the response.
+        # Includes prompt, completion, and total tokens used.
+        token_usage = response["usage"]
+        return LLMResult(
+            generations=generations, llm_output={"token_usage": token_usage}
+        )
+
     @property
     def _identifying_params(self) -> Mapping[str, Any]:
         """Get the identifying parameters."""
@@ -117,17 +159,7 @@ class OpenAI(LLM, BaseModel):
 
             response = openai("Tell me a joke.")
         """
-        params = self._default_params
-
-        if params["max_tokens"] == -1:
-            params["max_tokens"] = self.max_tokens_for_prompt(prompt)
-
-        if stop is not None:
-            if "stop" in params:
-                raise ValueError("`stop` found in both the input and default params.")
-            params["stop"] = stop
-        response = self.client.create(model=self.model_name, prompt=prompt, **params)
-        return response["choices"][0]["text"]
+        return self.generate([prompt], stop=stop).generations[0][0].text
 
     def modelname_to_contextsize(self, modelname: str) -> int:
         """Calculate the maximum number of tokens possible to generate for a model.
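A minimal usage sketch of the `generate` interface this patch adds, assuming a valid OPENAI_API_KEY is available in the environment; the model name, the `n`/`best_of` values, and the prompts below are illustrative and not taken from the patch:

    from langchain.llms.openai import OpenAI

    # One batched call to the OpenAI endpoint; `n` completions come back per
    # prompt, and generate() groups them per input prompt in `generations`.
    llm = OpenAI(model_name="text-davinci-003", n=2, best_of=2)
    result = llm.generate(["Tell me a joke.", "Tell me a poem."])

    len(result.generations)        # 2 -- one inner list per input prompt
    result.generations[0][0].text  # first completion for the first prompt
    result.llm_output              # {"token_usage": {...}} reported by the API

The single-prompt `_call` path keeps its old string-returning behavior; after this change it simply unwraps `generations[0][0].text` from the same `LLMResult`.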