Mrbean/support timeout (#398)

Add support for passing in a request timeout to the API
This commit is contained in:
mrbean 2022-12-21 23:39:07 -05:00 committed by GitHub
parent 6b60c509ac
commit 136f759492
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 5 additions and 2 deletions

View File

@ -37,7 +37,7 @@ class HuggingFacePipeline(LLM, BaseModel):
pipe = pipeline( pipe = pipeline(
"text-generation", model=model, tokenizer=tokenizer, max_new_tokens=10 "text-generation", model=model, tokenizer=tokenizer, max_new_tokens=10
) )
hf = HuggingFacePipeline(pipeline=pipe hf = HuggingFacePipeline(pipeline=pipe)
""" """
pipeline: Any #: :meta private: pipeline: Any #: :meta private:

View File

@ -1,6 +1,6 @@
"""Wrapper around OpenAI APIs.""" """Wrapper around OpenAI APIs."""
import sys import sys
from typing import Any, Dict, Generator, List, Mapping, Optional from typing import Any, Dict, Generator, List, Mapping, Optional, Tuple, Union
from pydantic import BaseModel, Extra, Field, root_validator from pydantic import BaseModel, Extra, Field, root_validator
@ -49,6 +49,8 @@ class BaseOpenAI(BaseLLM, BaseModel):
openai_api_key: Optional[str] = None openai_api_key: Optional[str] = None
batch_size: int = 20 batch_size: int = 20
"""Batch size to use when passing multiple documents to generate.""" """Batch size to use when passing multiple documents to generate."""
request_timeout: Optional[Union[float, Tuple[float, float]]] = None
"""Timeout for requests to OpenAI completion API. Default is 600 seconds."""
class Config: class Config:
"""Configuration for this pydantic object.""" """Configuration for this pydantic object."""
@ -98,6 +100,7 @@ class BaseOpenAI(BaseLLM, BaseModel):
"presence_penalty": self.presence_penalty, "presence_penalty": self.presence_penalty,
"n": self.n, "n": self.n,
"best_of": self.best_of, "best_of": self.best_of,
"request_timeout": self.request_timeout,
} }
return {**normal_params, **self.model_kwargs} return {**normal_params, **self.model_kwargs}