mirror of
https://github.com/hwchase17/langchain.git
synced 2025-05-21 06:58:02 +00:00
OpenAI LLM: update modelname_to_contextsize
with new models (#2843)
Token counts pulled from https://openai.com/pricing
This commit is contained in:
parent
82d1d5f24e
commit
be4fb24b32
@ -476,13 +476,6 @@ class BaseOpenAI(BaseLLM):
|
||||
def modelname_to_contextsize(self, modelname: str) -> int:
|
||||
"""Calculate the maximum number of tokens possible to generate for a model.
|
||||
|
||||
text-davinci-003: 4,097 tokens
|
||||
text-curie-001: 2,048 tokens
|
||||
text-babbage-001: 2,048 tokens
|
||||
text-ada-001: 2,048 tokens
|
||||
code-davinci-002: 8,000 tokens
|
||||
code-cushman-001: 2,048 tokens
|
||||
|
||||
Args:
|
||||
modelname: The modelname we want to know the context size for.
|
||||
|
||||
@ -494,20 +487,37 @@ class BaseOpenAI(BaseLLM):
|
||||
|
||||
max_tokens = openai.modelname_to_contextsize("text-davinci-003")
|
||||
"""
|
||||
if modelname == "text-davinci-003":
|
||||
return 4097
|
||||
elif modelname == "text-curie-001":
|
||||
return 2048
|
||||
elif modelname == "text-babbage-001":
|
||||
return 2048
|
||||
elif modelname == "text-ada-001":
|
||||
return 2048
|
||||
elif modelname == "code-davinci-002":
|
||||
return 8000
|
||||
elif modelname == "code-cushman-001":
|
||||
return 2048
|
||||
else:
|
||||
return 4097
|
||||
model_token_mapping = {
|
||||
"gpt-4": 8192,
|
||||
"gpt-4-0314": 8192,
|
||||
"gpt-4-32k": 32768,
|
||||
"gpt-4-32k-0314": 32768,
|
||||
"gpt-3.5-turbo": 4096,
|
||||
"gpt-3.5-turbo-0301": 4096,
|
||||
"text-ada-001": 2049,
|
||||
"ada": 2049,
|
||||
"text-babbage-001": 2040,
|
||||
"babbage": 2049,
|
||||
"text-curie-001": 2049,
|
||||
"curie": 2049,
|
||||
"davinci": 2049,
|
||||
"text-davinci-003": 4097,
|
||||
"text-davinci-002": 4097,
|
||||
"code-davinci-002": 8001,
|
||||
"code-davinci-001": 8001,
|
||||
"code-cushman-002": 2048,
|
||||
"code-cushman-001": 2048,
|
||||
}
|
||||
|
||||
context_size = model_token_mapping.get(modelname, None)
|
||||
|
||||
if context_size is None:
|
||||
raise ValueError(
|
||||
f"Unknown model: {modelname}. Please provide a valid OpenAI model name."
|
||||
"Known models are: " + ", ".join(model_token_mapping.keys())
|
||||
)
|
||||
|
||||
return context_size
|
||||
|
||||
def max_tokens_for_prompt(self, prompt: str) -> int:
|
||||
"""Calculate the maximum number of tokens possible to generate for a prompt.
|
||||
|
@ -211,3 +211,14 @@ async def test_openai_chat_async_streaming_callback() -> None:
|
||||
result = await llm.agenerate(["Write me a sentence with 100 words."])
|
||||
assert callback_handler.llm_streams != 0
|
||||
assert isinstance(result, LLMResult)
|
||||
|
||||
|
||||
def test_openai_modelname_to_contextsize_valid() -> None:
|
||||
"""Test model name to context size on a valid model."""
|
||||
assert OpenAI().modelname_to_contextsize("davinci") == 2049
|
||||
|
||||
|
||||
def test_openai_modelname_to_contextsize_invalid() -> None:
|
||||
"""Test model name to context size on an invalid model."""
|
||||
with pytest.raises(ValueError):
|
||||
OpenAI().modelname_to_contextsize("foobar")
|
||||
|
Loading…
Reference in New Issue
Block a user