mirror of
https://github.com/hwchase17/langchain.git
synced 2025-05-22 07:27:45 +00:00
OpenAI LLM: update modelname_to_contextsize with new models (#2843)
Token counts pulled from https://openai.com/pricing
This commit is contained in:
parent 82d1d5f24e
commit be4fb24b32
@ -476,13 +476,6 @@ class BaseOpenAI(BaseLLM):
|
|||||||
def modelname_to_contextsize(self, modelname: str) -> int:
|
def modelname_to_contextsize(self, modelname: str) -> int:
|
||||||
"""Calculate the maximum number of tokens possible to generate for a model.
|
"""Calculate the maximum number of tokens possible to generate for a model.
|
||||||
|
|
||||||
text-davinci-003: 4,097 tokens
|
|
||||||
text-curie-001: 2,048 tokens
|
|
||||||
text-babbage-001: 2,048 tokens
|
|
||||||
text-ada-001: 2,048 tokens
|
|
||||||
code-davinci-002: 8,000 tokens
|
|
||||||
code-cushman-001: 2,048 tokens
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
modelname: The modelname we want to know the context size for.
|
modelname: The modelname we want to know the context size for.
|
||||||
|
|
||||||
@ -494,20 +487,37 @@ class BaseOpenAI(BaseLLM):
|
|||||||
|
|
||||||
max_tokens = openai.modelname_to_contextsize("text-davinci-003")
|
max_tokens = openai.modelname_to_contextsize("text-davinci-003")
|
||||||
"""
|
"""
|
||||||
if modelname == "text-davinci-003":
|
model_token_mapping = {
|
||||||
return 4097
|
"gpt-4": 8192,
|
||||||
elif modelname == "text-curie-001":
|
"gpt-4-0314": 8192,
|
||||||
return 2048
|
"gpt-4-32k": 32768,
|
||||||
elif modelname == "text-babbage-001":
|
"gpt-4-32k-0314": 32768,
|
||||||
return 2048
|
"gpt-3.5-turbo": 4096,
|
||||||
elif modelname == "text-ada-001":
|
"gpt-3.5-turbo-0301": 4096,
|
||||||
return 2048
|
"text-ada-001": 2049,
|
||||||
elif modelname == "code-davinci-002":
|
"ada": 2049,
|
||||||
return 8000
|
"text-babbage-001": 2040,
|
||||||
elif modelname == "code-cushman-001":
|
"babbage": 2049,
|
||||||
return 2048
|
"text-curie-001": 2049,
|
||||||
else:
|
"curie": 2049,
|
||||||
return 4097
|
"davinci": 2049,
|
||||||
|
"text-davinci-003": 4097,
|
||||||
|
"text-davinci-002": 4097,
|
||||||
|
"code-davinci-002": 8001,
|
||||||
|
"code-davinci-001": 8001,
|
||||||
|
"code-cushman-002": 2048,
|
||||||
|
"code-cushman-001": 2048,
|
||||||
|
}
|
||||||
|
|
||||||
|
context_size = model_token_mapping.get(modelname, None)
|
||||||
|
|
||||||
|
if context_size is None:
|
||||||
|
raise ValueError(
|
||||||
|
f"Unknown model: {modelname}. Please provide a valid OpenAI model name."
|
||||||
|
"Known models are: " + ", ".join(model_token_mapping.keys())
|
||||||
|
)
|
||||||
|
|
||||||
|
return context_size
|
||||||
|
|
||||||
def max_tokens_for_prompt(self, prompt: str) -> int:
|
def max_tokens_for_prompt(self, prompt: str) -> int:
|
||||||
"""Calculate the maximum number of tokens possible to generate for a prompt.
|
"""Calculate the maximum number of tokens possible to generate for a prompt.
|
||||||
|
@ -211,3 +211,14 @@ async def test_openai_chat_async_streaming_callback() -> None:
|
|||||||
result = await llm.agenerate(["Write me a sentence with 100 words."])
|
result = await llm.agenerate(["Write me a sentence with 100 words."])
|
||||||
assert callback_handler.llm_streams != 0
|
assert callback_handler.llm_streams != 0
|
||||||
assert isinstance(result, LLMResult)
|
assert isinstance(result, LLMResult)
|
||||||
|
|
||||||
|
|
||||||
|
def test_openai_modelname_to_contextsize_valid() -> None:
|
||||||
|
"""Test model name to context size on a valid model."""
|
||||||
|
assert OpenAI().modelname_to_contextsize("davinci") == 2049
|
||||||
|
|
||||||
|
|
||||||
|
def test_openai_modelname_to_contextsize_invalid() -> None:
|
||||||
|
"""Test model name to context size on an invalid model."""
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
OpenAI().modelname_to_contextsize("foobar")
|
||||||
|
Loading…
Reference in New Issue
Block a user