RFC: more complete return (#313)

Co-authored-by: Andrew Williamson <awilliamson10@indstate.edu>
Co-authored-by: awilliamson10 <aw.williamson10@gmail.com>
Harrison Chase 2022-12-13 05:50:03 -08:00 committed by GitHub
parent 482611f426
commit 595cc1ae1a
2 changed files with 73 additions and 13 deletions

langchain/llms/base.py

@@ -1,11 +1,39 @@
 """Base interface for large language models to expose."""
 from abc import ABC, abstractmethod
-from typing import Any, List, Mapping, Optional
+from typing import Any, List, Mapping, NamedTuple, Optional
+
+
+class Generation(NamedTuple):
+    """Output of a single generation."""
+
+    text: str
+    """Generated text output."""
+    # TODO: add log probs
+
+
+class LLMResult(NamedTuple):
+    """Class that contains all relevant information for an LLM Result."""
+
+    generations: List[List[Generation]]
+    """List of the things generated. This is List[List[]] because
+    each input could have multiple generations."""
+    llm_output: Optional[dict] = None
+    """For arbitrary LLM provider specific output."""
 
 
 class LLM(ABC):
     """LLM wrapper should take in a prompt and return a string."""
 
+    def generate(
+        self, prompts: List[str], stop: Optional[List[str]] = None
+    ) -> LLMResult:
+        """Run the LLM on the given prompt and input."""
+        generations = []
+        for prompt in prompts:
+            text = self(prompt, stop=stop)
+            generations.append([Generation(text=text)])
+        return LLMResult(generations=generations)
+
     def get_num_tokens(self, text: str) -> int:
         """Get the number of tokens present in the text."""
         # TODO: this method may not be exact.
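For reference, a minimal sketch (not part of the diff) of the shape the new result type carries. It uses only `Generation` and `LLMResult` from the hunk above; the sample texts and token counts are made up:

```python
from langchain.llms.base import Generation, LLMResult

# Build an LLMResult by hand to show the structure generate() returns:
# one inner list of Generation objects per input prompt, plus an optional
# provider-specific llm_output dict.
result = LLMResult(
    generations=[
        [Generation(text="Why did the chicken cross the road?")],
        [Generation(text="To get to the other side.")],
    ],
    llm_output={"token_usage": {"total_tokens": 42}},  # made-up numbers
)

first_text = result.generations[0][0].text  # completion for the first prompt
```

Because both types are NamedTuples, results are immutable and can be unpacked positionally as well as accessed by field name.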

langchain/llms/openai.py

@@ -3,7 +3,7 @@ from typing import Any, Dict, List, Mapping, Optional
 from pydantic import BaseModel, Extra, Field, root_validator
 
-from langchain.llms.base import LLM
+from langchain.llms.base import LLM, Generation, LLMResult
 from langchain.utils import get_from_dict_or_env
@@ -97,6 +97,48 @@ class OpenAI(LLM, BaseModel):
         }
         return {**normal_params, **self.model_kwargs}
 
+    def generate(
+        self, prompts: List[str], stop: Optional[List[str]] = None
+    ) -> LLMResult:
+        """Call out to OpenAI's endpoint with k unique prompts.
+
+        Args:
+            prompts: The prompts to pass into the model.
+            stop: Optional list of stop words to use when generating.
+
+        Returns:
+            The full LLM output.
+
+        Example:
+            .. code-block:: python
+
+                response = openai.generate(["Tell me a joke."])
+        """
+        params = self._default_params
+        if stop is not None:
+            if "stop" in params:
+                raise ValueError("`stop` found in both the input and default params.")
+            params["stop"] = stop
+        if params["max_tokens"] == -1:
+            if len(prompts) != 1:
+                raise ValueError(
+                    "max_tokens set to -1 not supported for multiple inputs."
+                )
+            params["max_tokens"] = self.max_tokens_for_prompt(prompts[0])
+        response = self.client.create(model=self.model_name, prompt=prompts, **params)
+        generations = []
+        for i, prompt in enumerate(prompts):
+            choices = response["choices"][i * self.n : (i + 1) * self.n]
+            generations.append([Generation(text=choice["text"]) for choice in choices])
+        # Get the token usage from the response.
+        # Includes prompt, completion, and total tokens used.
+        token_usage = response["usage"]
+        return LLMResult(
+            generations=generations, llm_output={"token_usage": token_usage}
+        )
+
     @property
     def _identifying_params(self) -> Mapping[str, Any]:
         """Get the identifying parameters."""
@@ -117,17 +159,7 @@ class OpenAI(LLM, BaseModel):
                 response = openai("Tell me a joke.")
         """
-        params = self._default_params
-        if params["max_tokens"] == -1:
-            params["max_tokens"] = self.max_tokens_for_prompt(prompt)
-        if stop is not None:
-            if "stop" in params:
-                raise ValueError("`stop` found in both the input and default params.")
-            params["stop"] = stop
-        response = self.client.create(model=self.model_name, prompt=prompt, **params)
-        return response["choices"][0]["text"]
+        return self.generate([prompt], stop=stop).generations[0][0].text
 
     def modelname_to_contextsize(self, modelname: str) -> int:
         """Calculate the maximum number of tokens possible to generate for a model.