Mirror of https://github.com/hwchase17/langchain.git, synced 2025-09-17 07:26:16 +00:00
openai[minor]: implement langchain-openai package (#15503)
Todo
- [x] copy over integration tests
- [x] update docs with new instructions in #15513
- [x] add linear ticket to bump core -> community, community -> langchain, and core -> openai deps
- [ ] (optional): add `pip install langchain-openai` command to each notebook using it
- [x] Update docstrings to not need `openai` install
- [x] Add serialization
- [x] deprecate old models

Contributor steps:
- [x] Add secret names to manual integrations workflow in .github/workflows/_integration_test.yml
- [x] Add secrets to release workflow (for pre-release testing) in .github/workflows/_release.yml

Maintainer steps (Contributors should not do these):
- [x] set up pypi and test pypi projects
- [x] add credential secrets to Github Actions
- [ ] add package to conda-forge

Functional changes to existing classes:
- now relies on openai client v1 (1.6.1) via concrete dep in langchain-openai package

Codebase organization:
- some function calling stuff moved to `langchain_core.utils.function_calling` in order to be used in both community and langchain-openai
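For context, here is a minimal usage sketch of the new package. The import paths match the `EXPECTED_ALL` lists in the unit tests below; the model name, the `OPENAI_API_KEY` environment variable, and the exact call pattern are illustrative assumptions rather than part of this diff.

```python
# pip install langchain-openai
#
# Minimal sketch (assumes OPENAI_API_KEY is set; the model name is illustrative).
from langchain_openai import ChatOpenAI, OpenAIEmbeddings

chat = ChatOpenAI(model="gpt-3.5-turbo", max_tokens=10)
response = chat.invoke("Hello")  # returns a chat message
print(response.content)

embeddings = OpenAIEmbeddings()
vector = embeddings.embed_query("foo bar")  # list of floats
print(len(vector))
```

Per the "deprecate old models" item above, the equivalent classes in `langchain`/`langchain-community` are slated for deprecation in favor of these imports.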
0  libs/partners/openai/tests/__init__.py  Normal file
@@ -0,0 +1,221 @@
|
||||
"""Test AzureChatOpenAI wrapper."""
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from langchain_core.callbacks import CallbackManager
|
||||
from langchain_core.messages import BaseMessage, HumanMessage
|
||||
from langchain_core.outputs import ChatGeneration, ChatResult, LLMResult
|
||||
|
||||
from langchain_openai import AzureChatOpenAI
|
||||
from tests.unit_tests.fake.callbacks import FakeCallbackHandler
|
||||
|
||||
OPENAI_API_VERSION = os.environ.get("AZURE_OPENAI_API_VERSION", "")
|
||||
OPENAI_API_BASE = os.environ.get("AZURE_OPENAI_API_BASE", "")
|
||||
OPENAI_API_KEY = os.environ.get("AZURE_OPENAI_API_KEY", "")
|
||||
DEPLOYMENT_NAME = os.environ.get(
|
||||
"AZURE_OPENAI_DEPLOYMENT_NAME",
|
||||
os.environ.get("AZURE_OPENAI_CHAT_DEPLOYMENT_NAME", ""),
|
||||
)
|
||||
|
||||
|
||||
def _get_llm(**kwargs: Any) -> AzureChatOpenAI:
|
||||
return AzureChatOpenAI(
|
||||
deployment_name=DEPLOYMENT_NAME,
|
||||
openai_api_version=OPENAI_API_VERSION,
|
||||
azure_endpoint=OPENAI_API_BASE,
|
||||
openai_api_key=OPENAI_API_KEY,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.scheduled
|
||||
@pytest.fixture
|
||||
def llm() -> AzureChatOpenAI:
|
||||
return _get_llm(
|
||||
max_tokens=10,
|
||||
)
|
||||
|
||||
|
||||
def test_chat_openai(llm: AzureChatOpenAI) -> None:
|
||||
"""Test AzureChatOpenAI wrapper."""
|
||||
message = HumanMessage(content="Hello")
|
||||
response = llm([message])
|
||||
assert isinstance(response, BaseMessage)
|
||||
assert isinstance(response.content, str)
|
||||
|
||||
|
||||
@pytest.mark.scheduled
|
||||
def test_chat_openai_generate() -> None:
|
||||
"""Test AzureChatOpenAI wrapper with generate."""
|
||||
chat = _get_llm(max_tokens=10, n=2)
|
||||
message = HumanMessage(content="Hello")
|
||||
response = chat.generate([[message], [message]])
|
||||
assert isinstance(response, LLMResult)
|
||||
assert len(response.generations) == 2
|
||||
for generations in response.generations:
|
||||
assert len(generations) == 2
|
||||
for generation in generations:
|
||||
assert isinstance(generation, ChatGeneration)
|
||||
assert isinstance(generation.text, str)
|
||||
assert generation.text == generation.message.content
|
||||
|
||||
|
||||
@pytest.mark.scheduled
|
||||
def test_chat_openai_multiple_completions() -> None:
|
||||
"""Test AzureChatOpenAI wrapper with multiple completions."""
|
||||
chat = _get_llm(max_tokens=10, n=5)
|
||||
message = HumanMessage(content="Hello")
|
||||
response = chat._generate([message])
|
||||
assert isinstance(response, ChatResult)
|
||||
assert len(response.generations) == 5
|
||||
for generation in response.generations:
|
||||
assert isinstance(generation.message, BaseMessage)
|
||||
assert isinstance(generation.message.content, str)
|
||||
|
||||
|
||||
@pytest.mark.scheduled
|
||||
def test_chat_openai_streaming() -> None:
|
||||
"""Test that streaming correctly invokes on_llm_new_token callback."""
|
||||
callback_handler = FakeCallbackHandler()
|
||||
callback_manager = CallbackManager([callback_handler])
|
||||
chat = _get_llm(
|
||||
max_tokens=10,
|
||||
streaming=True,
|
||||
temperature=0,
|
||||
callback_manager=callback_manager,
|
||||
verbose=True,
|
||||
)
|
||||
message = HumanMessage(content="Hello")
|
||||
response = chat([message])
|
||||
assert callback_handler.llm_streams > 0
|
||||
assert isinstance(response, BaseMessage)
|
||||
|
||||
|
||||
@pytest.mark.scheduled
|
||||
def test_chat_openai_streaming_generation_info() -> None:
|
||||
"""Test that generation info is preserved when streaming."""
|
||||
|
||||
class _FakeCallback(FakeCallbackHandler):
|
||||
saved_things: dict = {}
|
||||
|
||||
def on_llm_end(
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
# Save the generation
|
||||
self.saved_things["generation"] = args[0]
|
||||
|
||||
callback = _FakeCallback()
|
||||
callback_manager = CallbackManager([callback])
|
||||
chat = _get_llm(
|
||||
max_tokens=2,
|
||||
temperature=0,
|
||||
callback_manager=callback_manager,
|
||||
)
|
||||
list(chat.stream("hi"))
|
||||
generation = callback.saved_things["generation"]
|
||||
# `Hello!` is two tokens, assert that that is what is returned
|
||||
assert generation.generations[0][0].text == "Hello!"
|
||||
|
||||
|
||||
@pytest.mark.scheduled
|
||||
async def test_async_chat_openai() -> None:
|
||||
"""Test async generation."""
|
||||
chat = _get_llm(max_tokens=10, n=2)
|
||||
message = HumanMessage(content="Hello")
|
||||
response = await chat.agenerate([[message], [message]])
|
||||
assert isinstance(response, LLMResult)
|
||||
assert len(response.generations) == 2
|
||||
for generations in response.generations:
|
||||
assert len(generations) == 2
|
||||
for generation in generations:
|
||||
assert isinstance(generation, ChatGeneration)
|
||||
assert isinstance(generation.text, str)
|
||||
assert generation.text == generation.message.content
|
||||
|
||||
|
||||
@pytest.mark.scheduled
|
||||
async def test_async_chat_openai_streaming() -> None:
|
||||
"""Test that streaming correctly invokes on_llm_new_token callback."""
|
||||
callback_handler = FakeCallbackHandler()
|
||||
callback_manager = CallbackManager([callback_handler])
|
||||
chat = _get_llm(
|
||||
max_tokens=10,
|
||||
streaming=True,
|
||||
temperature=0,
|
||||
callback_manager=callback_manager,
|
||||
verbose=True,
|
||||
)
|
||||
message = HumanMessage(content="Hello")
|
||||
response = await chat.agenerate([[message], [message]])
|
||||
assert callback_handler.llm_streams > 0
|
||||
assert isinstance(response, LLMResult)
|
||||
assert len(response.generations) == 2
|
||||
for generations in response.generations:
|
||||
assert len(generations) == 1
|
||||
for generation in generations:
|
||||
assert isinstance(generation, ChatGeneration)
|
||||
assert isinstance(generation.text, str)
|
||||
assert generation.text == generation.message.content
|
||||
|
||||
|
||||
@pytest.mark.scheduled
|
||||
def test_openai_streaming(llm: AzureChatOpenAI) -> None:
|
||||
"""Test streaming tokens from OpenAI."""
|
||||
|
||||
for token in llm.stream("I'm Pickle Rick"):
|
||||
assert isinstance(token.content, str)
|
||||
|
||||
|
||||
@pytest.mark.scheduled
|
||||
async def test_openai_astream(llm: AzureChatOpenAI) -> None:
|
||||
"""Test streaming tokens from OpenAI."""
|
||||
async for token in llm.astream("I'm Pickle Rick"):
|
||||
assert isinstance(token.content, str)
|
||||
|
||||
|
||||
@pytest.mark.scheduled
|
||||
async def test_openai_abatch(llm: AzureChatOpenAI) -> None:
|
||||
"""Test streaming tokens from AzureChatOpenAI."""
|
||||
|
||||
result = await llm.abatch(["I'm Pickle Rick", "I'm not Pickle Rick"])
|
||||
for token in result:
|
||||
assert isinstance(token.content, str)
|
||||
|
||||
|
||||
@pytest.mark.scheduled
|
||||
async def test_openai_abatch_tags(llm: AzureChatOpenAI) -> None:
|
||||
"""Test batch tokens from AzureChatOpenAI."""
|
||||
|
||||
result = await llm.abatch(
|
||||
["I'm Pickle Rick", "I'm not Pickle Rick"], config={"tags": ["foo"]}
|
||||
)
|
||||
for token in result:
|
||||
assert isinstance(token.content, str)
|
||||
|
||||
|
||||
@pytest.mark.scheduled
|
||||
def test_openai_batch(llm: AzureChatOpenAI) -> None:
|
||||
"""Test batch tokens from AzureChatOpenAI."""
|
||||
|
||||
result = llm.batch(["I'm Pickle Rick", "I'm not Pickle Rick"])
|
||||
for token in result:
|
||||
assert isinstance(token.content, str)
|
||||
|
||||
|
||||
@pytest.mark.scheduled
|
||||
async def test_openai_ainvoke(llm: AzureChatOpenAI) -> None:
|
||||
"""Test invoke tokens from AzureChatOpenAI."""
|
||||
|
||||
result = await llm.ainvoke("I'm Pickle Rick", config={"tags": ["foo"]})
|
||||
assert isinstance(result.content, str)
|
||||
|
||||
|
||||
@pytest.mark.scheduled
|
||||
def test_openai_invoke(llm: AzureChatOpenAI) -> None:
|
||||
"""Test invoke tokens from AzureChatOpenAI."""
|
||||
|
||||
result = llm.invoke("I'm Pickle Rick", config=dict(tags=["foo"]))
|
||||
assert isinstance(result.content, str)
|
@@ -0,0 +1,393 @@
|
||||
"""Test ChatOpenAI chat model."""
|
||||
from typing import Any, Optional
|
||||
|
||||
import pytest
|
||||
from langchain_core.callbacks import CallbackManager
|
||||
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
|
||||
from langchain_core.outputs import (
|
||||
ChatGeneration,
|
||||
ChatResult,
|
||||
LLMResult,
|
||||
)
|
||||
from langchain_core.prompts import ChatPromptTemplate
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field
|
||||
|
||||
from langchain_openai import ChatOpenAI
|
||||
from tests.unit_tests.fake.callbacks import FakeCallbackHandler
|
||||
|
||||
|
||||
@pytest.mark.scheduled
|
||||
def test_chat_openai() -> None:
|
||||
"""Test ChatOpenAI wrapper."""
|
||||
chat = ChatOpenAI(
|
||||
temperature=0.7,
|
||||
base_url=None,
|
||||
organization=None,
|
||||
openai_proxy=None,
|
||||
timeout=10.0,
|
||||
max_retries=3,
|
||||
http_client=None,
|
||||
n=1,
|
||||
max_tokens=10,
|
||||
default_headers=None,
|
||||
default_query=None,
|
||||
)
|
||||
message = HumanMessage(content="Hello")
|
||||
response = chat([message])
|
||||
assert isinstance(response, BaseMessage)
|
||||
assert isinstance(response.content, str)
|
||||
|
||||
|
||||
def test_chat_openai_model() -> None:
|
||||
"""Test ChatOpenAI wrapper handles model_name."""
|
||||
chat = ChatOpenAI(model="foo")
|
||||
assert chat.model_name == "foo"
|
||||
chat = ChatOpenAI(model_name="bar")
|
||||
assert chat.model_name == "bar"
|
||||
|
||||
|
||||
def test_chat_openai_system_message() -> None:
|
||||
"""Test ChatOpenAI wrapper with system message."""
|
||||
chat = ChatOpenAI(max_tokens=10)
|
||||
system_message = SystemMessage(content="You are to chat with the user.")
|
||||
human_message = HumanMessage(content="Hello")
|
||||
response = chat([system_message, human_message])
|
||||
assert isinstance(response, BaseMessage)
|
||||
assert isinstance(response.content, str)
|
||||
|
||||
|
||||
@pytest.mark.scheduled
|
||||
def test_chat_openai_generate() -> None:
|
||||
"""Test ChatOpenAI wrapper with generate."""
|
||||
chat = ChatOpenAI(max_tokens=10, n=2)
|
||||
message = HumanMessage(content="Hello")
|
||||
response = chat.generate([[message], [message]])
|
||||
assert isinstance(response, LLMResult)
|
||||
assert len(response.generations) == 2
|
||||
assert response.llm_output
|
||||
for generations in response.generations:
|
||||
assert len(generations) == 2
|
||||
for generation in generations:
|
||||
assert isinstance(generation, ChatGeneration)
|
||||
assert isinstance(generation.text, str)
|
||||
assert generation.text == generation.message.content
|
||||
|
||||
|
||||
@pytest.mark.scheduled
|
||||
def test_chat_openai_multiple_completions() -> None:
|
||||
"""Test ChatOpenAI wrapper with multiple completions."""
|
||||
chat = ChatOpenAI(max_tokens=10, n=5)
|
||||
message = HumanMessage(content="Hello")
|
||||
response = chat._generate([message])
|
||||
assert isinstance(response, ChatResult)
|
||||
assert len(response.generations) == 5
|
||||
for generation in response.generations:
|
||||
assert isinstance(generation.message, BaseMessage)
|
||||
assert isinstance(generation.message.content, str)
|
||||
|
||||
|
||||
@pytest.mark.scheduled
|
||||
def test_chat_openai_streaming() -> None:
|
||||
"""Test that streaming correctly invokes on_llm_new_token callback."""
|
||||
callback_handler = FakeCallbackHandler()
|
||||
callback_manager = CallbackManager([callback_handler])
|
||||
chat = ChatOpenAI(
|
||||
max_tokens=10,
|
||||
streaming=True,
|
||||
temperature=0,
|
||||
callback_manager=callback_manager,
|
||||
verbose=True,
|
||||
)
|
||||
message = HumanMessage(content="Hello")
|
||||
response = chat([message])
|
||||
assert callback_handler.llm_streams > 0
|
||||
assert isinstance(response, BaseMessage)
|
||||
|
||||
|
||||
@pytest.mark.scheduled
|
||||
def test_chat_openai_streaming_generation_info() -> None:
|
||||
"""Test that generation info is preserved when streaming."""
|
||||
|
||||
class _FakeCallback(FakeCallbackHandler):
|
||||
saved_things: dict = {}
|
||||
|
||||
def on_llm_end(
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
# Save the generation
|
||||
self.saved_things["generation"] = args[0]
|
||||
|
||||
callback = _FakeCallback()
|
||||
callback_manager = CallbackManager([callback])
|
||||
chat = ChatOpenAI(
|
||||
max_tokens=2,
|
||||
temperature=0,
|
||||
callback_manager=callback_manager,
|
||||
)
|
||||
list(chat.stream("hi"))
|
||||
generation = callback.saved_things["generation"]
|
||||
# `Hello!` is two tokens, assert that that is what is returned
|
||||
assert generation.generations[0][0].text == "Hello!"
|
||||
|
||||
|
||||
def test_chat_openai_llm_output_contains_model_name() -> None:
|
||||
"""Test llm_output contains model_name."""
|
||||
chat = ChatOpenAI(max_tokens=10)
|
||||
message = HumanMessage(content="Hello")
|
||||
llm_result = chat.generate([[message]])
|
||||
assert llm_result.llm_output is not None
|
||||
assert llm_result.llm_output["model_name"] == chat.model_name
|
||||
|
||||
|
||||
def test_chat_openai_streaming_llm_output_contains_model_name() -> None:
|
||||
"""Test llm_output contains model_name."""
|
||||
chat = ChatOpenAI(max_tokens=10, streaming=True)
|
||||
message = HumanMessage(content="Hello")
|
||||
llm_result = chat.generate([[message]])
|
||||
assert llm_result.llm_output is not None
|
||||
assert llm_result.llm_output["model_name"] == chat.model_name
|
||||
|
||||
|
||||
def test_chat_openai_invalid_streaming_params() -> None:
|
||||
"""Test that streaming correctly invokes on_llm_new_token callback."""
|
||||
with pytest.raises(ValueError):
|
||||
ChatOpenAI(
|
||||
max_tokens=10,
|
||||
streaming=True,
|
||||
temperature=0,
|
||||
n=5,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.scheduled
|
||||
async def test_async_chat_openai() -> None:
|
||||
"""Test async generation."""
|
||||
chat = ChatOpenAI(max_tokens=10, n=2)
|
||||
message = HumanMessage(content="Hello")
|
||||
response = await chat.agenerate([[message], [message]])
|
||||
assert isinstance(response, LLMResult)
|
||||
assert len(response.generations) == 2
|
||||
assert response.llm_output
|
||||
for generations in response.generations:
|
||||
assert len(generations) == 2
|
||||
for generation in generations:
|
||||
assert isinstance(generation, ChatGeneration)
|
||||
assert isinstance(generation.text, str)
|
||||
assert generation.text == generation.message.content
|
||||
|
||||
|
||||
@pytest.mark.scheduled
|
||||
async def test_async_chat_openai_streaming() -> None:
|
||||
"""Test that streaming correctly invokes on_llm_new_token callback."""
|
||||
callback_handler = FakeCallbackHandler()
|
||||
callback_manager = CallbackManager([callback_handler])
|
||||
chat = ChatOpenAI(
|
||||
max_tokens=10,
|
||||
streaming=True,
|
||||
temperature=0,
|
||||
callback_manager=callback_manager,
|
||||
verbose=True,
|
||||
)
|
||||
message = HumanMessage(content="Hello")
|
||||
response = await chat.agenerate([[message], [message]])
|
||||
assert callback_handler.llm_streams > 0
|
||||
assert isinstance(response, LLMResult)
|
||||
assert len(response.generations) == 2
|
||||
for generations in response.generations:
|
||||
assert len(generations) == 1
|
||||
for generation in generations:
|
||||
assert isinstance(generation, ChatGeneration)
|
||||
assert isinstance(generation.text, str)
|
||||
assert generation.text == generation.message.content
|
||||
|
||||
|
||||
@pytest.mark.scheduled
|
||||
async def test_async_chat_openai_bind_functions() -> None:
|
||||
"""Test ChatOpenAI wrapper with multiple completions."""
|
||||
|
||||
class Person(BaseModel):
|
||||
"""Identifying information about a person."""
|
||||
|
||||
name: str = Field(..., title="Name", description="The person's name")
|
||||
age: int = Field(..., title="Age", description="The person's age")
|
||||
fav_food: Optional[str] = Field(
|
||||
default=None, title="Fav Food", description="The person's favorite food"
|
||||
)
|
||||
|
||||
chat = ChatOpenAI(
|
||||
max_tokens=30,
|
||||
n=1,
|
||||
streaming=True,
|
||||
).bind_functions(functions=[Person], function_call="Person")
|
||||
|
||||
prompt = ChatPromptTemplate.from_messages(
|
||||
[
|
||||
("system", "Use the provided Person function"),
|
||||
("user", "{input}"),
|
||||
]
|
||||
)
|
||||
|
||||
chain = prompt | chat
|
||||
|
||||
message = HumanMessage(content="Sally is 13 years old")
|
||||
response = await chain.abatch([{"input": message}])
|
||||
|
||||
assert isinstance(response, list)
|
||||
assert len(response) == 1
|
||||
for generation in response:
|
||||
assert isinstance(generation, AIMessage)
|
||||
|
||||
|
||||
def test_chat_openai_extra_kwargs() -> None:
|
||||
"""Test extra kwargs to chat openai."""
|
||||
# Check that foo is saved in model_kwargs.
|
||||
llm = ChatOpenAI(foo=3, max_tokens=10)
|
||||
assert llm.max_tokens == 10
|
||||
assert llm.model_kwargs == {"foo": 3}
|
||||
|
||||
# Test that explicitly provided model_kwargs are merged with extra kwargs.
|
||||
llm = ChatOpenAI(foo=3, model_kwargs={"bar": 2})
|
||||
assert llm.model_kwargs == {"foo": 3, "bar": 2}
|
||||
|
||||
# Test that if provided twice it errors
|
||||
with pytest.raises(ValueError):
|
||||
ChatOpenAI(foo=3, model_kwargs={"foo": 2})
|
||||
|
||||
# Test that if explicit param is specified in kwargs it errors
|
||||
with pytest.raises(ValueError):
|
||||
ChatOpenAI(model_kwargs={"temperature": 0.2})
|
||||
|
||||
# Test that "model" cannot be specified in kwargs
|
||||
with pytest.raises(ValueError):
|
||||
ChatOpenAI(model_kwargs={"model": "gpt-3.5-turbo-instruct"})
|
||||
|
||||
|
||||
@pytest.mark.scheduled
|
||||
def test_openai_streaming() -> None:
|
||||
"""Test streaming tokens from OpenAI."""
|
||||
llm = ChatOpenAI(max_tokens=10)
|
||||
|
||||
for token in llm.stream("I'm Pickle Rick"):
|
||||
assert isinstance(token.content, str)
|
||||
|
||||
|
||||
@pytest.mark.scheduled
|
||||
async def test_openai_astream() -> None:
|
||||
"""Test streaming tokens from OpenAI."""
|
||||
llm = ChatOpenAI(max_tokens=10)
|
||||
|
||||
async for token in llm.astream("I'm Pickle Rick"):
|
||||
assert isinstance(token.content, str)
|
||||
|
||||
|
||||
@pytest.mark.scheduled
|
||||
async def test_openai_abatch() -> None:
|
||||
"""Test streaming tokens from ChatOpenAI."""
|
||||
llm = ChatOpenAI(max_tokens=10)
|
||||
|
||||
result = await llm.abatch(["I'm Pickle Rick", "I'm not Pickle Rick"])
|
||||
for token in result:
|
||||
assert isinstance(token.content, str)
|
||||
|
||||
|
||||
@pytest.mark.scheduled
|
||||
async def test_openai_abatch_tags() -> None:
|
||||
"""Test batch tokens from ChatOpenAI."""
|
||||
llm = ChatOpenAI(max_tokens=10)
|
||||
|
||||
result = await llm.abatch(
|
||||
["I'm Pickle Rick", "I'm not Pickle Rick"], config={"tags": ["foo"]}
|
||||
)
|
||||
for token in result:
|
||||
assert isinstance(token.content, str)
|
||||
|
||||
|
||||
@pytest.mark.scheduled
|
||||
def test_openai_batch() -> None:
|
||||
"""Test batch tokens from ChatOpenAI."""
|
||||
llm = ChatOpenAI(max_tokens=10)
|
||||
|
||||
result = llm.batch(["I'm Pickle Rick", "I'm not Pickle Rick"])
|
||||
for token in result:
|
||||
assert isinstance(token.content, str)
|
||||
|
||||
|
||||
@pytest.mark.scheduled
|
||||
async def test_openai_ainvoke() -> None:
|
||||
"""Test invoke tokens from ChatOpenAI."""
|
||||
llm = ChatOpenAI(max_tokens=10)
|
||||
|
||||
result = await llm.ainvoke("I'm Pickle Rick", config={"tags": ["foo"]})
|
||||
assert isinstance(result.content, str)
|
||||
|
||||
|
||||
@pytest.mark.scheduled
|
||||
def test_openai_invoke() -> None:
|
||||
"""Test invoke tokens from ChatOpenAI."""
|
||||
llm = ChatOpenAI(max_tokens=10)
|
||||
|
||||
result = llm.invoke("I'm Pickle Rick", config=dict(tags=["foo"]))
|
||||
assert isinstance(result.content, str)
|
||||
|
||||
|
||||
def test_stream() -> None:
|
||||
"""Test streaming tokens from OpenAI."""
|
||||
llm = ChatOpenAI()
|
||||
|
||||
for token in llm.stream("I'm Pickle Rick"):
|
||||
assert isinstance(token.content, str)
|
||||
|
||||
|
||||
async def test_astream() -> None:
|
||||
"""Test streaming tokens from OpenAI."""
|
||||
llm = ChatOpenAI()
|
||||
|
||||
async for token in llm.astream("I'm Pickle Rick"):
|
||||
assert isinstance(token.content, str)
|
||||
|
||||
|
||||
async def test_abatch() -> None:
|
||||
"""Test streaming tokens from ChatOpenAI."""
|
||||
llm = ChatOpenAI()
|
||||
|
||||
result = await llm.abatch(["I'm Pickle Rick", "I'm not Pickle Rick"])
|
||||
for token in result:
|
||||
assert isinstance(token.content, str)
|
||||
|
||||
|
||||
async def test_abatch_tags() -> None:
|
||||
"""Test batch tokens from ChatOpenAI."""
|
||||
llm = ChatOpenAI()
|
||||
|
||||
result = await llm.abatch(
|
||||
["I'm Pickle Rick", "I'm not Pickle Rick"], config={"tags": ["foo"]}
|
||||
)
|
||||
for token in result:
|
||||
assert isinstance(token.content, str)
|
||||
|
||||
|
||||
def test_batch() -> None:
|
||||
"""Test batch tokens from ChatOpenAI."""
|
||||
llm = ChatOpenAI()
|
||||
|
||||
result = llm.batch(["I'm Pickle Rick", "I'm not Pickle Rick"])
|
||||
for token in result:
|
||||
assert isinstance(token.content, str)
|
||||
|
||||
|
||||
async def test_ainvoke() -> None:
|
||||
"""Test invoke tokens from ChatOpenAI."""
|
||||
llm = ChatOpenAI()
|
||||
|
||||
result = await llm.ainvoke("I'm Pickle Rick", config={"tags": ["foo"]})
|
||||
assert isinstance(result.content, str)
|
||||
|
||||
|
||||
def test_invoke() -> None:
|
||||
"""Test invoke tokens from ChatOpenAI."""
|
||||
llm = ChatOpenAI()
|
||||
|
||||
result = llm.invoke("I'm Pickle Rick", config=dict(tags=["foo"]))
|
||||
assert isinstance(result.content, str)
|
@@ -0,0 +1,132 @@
|
||||
"""Test azure openai embeddings."""
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import openai
|
||||
import pytest
|
||||
|
||||
from langchain_openai import AzureOpenAIEmbeddings
|
||||
|
||||
OPENAI_API_VERSION = os.environ.get("AZURE_OPENAI_API_VERSION", "")
|
||||
OPENAI_API_BASE = os.environ.get("AZURE_OPENAI_API_BASE", "")
|
||||
OPENAI_API_KEY = os.environ.get("AZURE_OPENAI_API_KEY", "")
|
||||
DEPLOYMENT_NAME = os.environ.get(
|
||||
"AZURE_OPENAI_DEPLOYMENT_NAME",
|
||||
os.environ.get("AZURE_OPENAI_EMBEDDINGS_DEPLOYMENT_NAME", ""),
|
||||
)
|
||||
|
||||
|
||||
|
||||
def _get_embeddings(**kwargs: Any) -> AzureOpenAIEmbeddings:
|
||||
return AzureOpenAIEmbeddings(
|
||||
azure_deployment=DEPLOYMENT_NAME,
|
||||
api_version=OPENAI_API_VERSION,
|
||||
openai_api_base=OPENAI_API_BASE,
|
||||
openai_api_key=OPENAI_API_KEY,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.scheduled
|
||||
def test_azure_openai_embedding_documents() -> None:
|
||||
"""Test openai embeddings."""
|
||||
documents = ["foo bar"]
|
||||
embedding = _get_embeddings()
|
||||
output = embedding.embed_documents(documents)
|
||||
assert len(output) == 1
|
||||
assert len(output[0]) == 1536
|
||||
|
||||
|
||||
@pytest.mark.scheduled
|
||||
def test_azure_openai_embedding_documents_multiple() -> None:
|
||||
"""Test openai embeddings."""
|
||||
documents = ["foo bar", "bar foo", "foo"]
|
||||
embedding = _get_embeddings(chunk_size=2)
|
||||
embedding.embedding_ctx_length = 8191
|
||||
output = embedding.embed_documents(documents)
|
||||
assert embedding.chunk_size == 2
|
||||
assert len(output) == 3
|
||||
assert len(output[0]) == 1536
|
||||
assert len(output[1]) == 1536
|
||||
assert len(output[2]) == 1536
|
||||
|
||||
|
||||
@pytest.mark.scheduled
|
||||
def test_azure_openai_embedding_documents_chunk_size() -> None:
|
||||
"""Test openai embeddings."""
|
||||
documents = ["foo bar"] * 20
|
||||
embedding = _get_embeddings()
|
||||
embedding.embedding_ctx_length = 8191
|
||||
output = embedding.embed_documents(documents)
|
||||
# Max 16 chunks per batch on Azure OpenAI embeddings
|
||||
assert embedding.chunk_size == 16
|
||||
assert len(output) == 20
|
||||
assert all([len(out) == 1536 for out in output])
|
||||
|
||||
|
||||
@pytest.mark.scheduled
|
||||
async def test_azure_openai_embedding_documents_async_multiple() -> None:
|
||||
"""Test openai embeddings."""
|
||||
documents = ["foo bar", "bar foo", "foo"]
|
||||
embedding = _get_embeddings(chunk_size=2)
|
||||
embedding.embedding_ctx_length = 8191
|
||||
output = await embedding.aembed_documents(documents)
|
||||
assert len(output) == 3
|
||||
assert len(output[0]) == 1536
|
||||
assert len(output[1]) == 1536
|
||||
assert len(output[2]) == 1536
|
||||
|
||||
|
||||
@pytest.mark.scheduled
|
||||
def test_azure_openai_embedding_query() -> None:
|
||||
"""Test openai embeddings."""
|
||||
document = "foo bar"
|
||||
embedding = _get_embeddings()
|
||||
output = embedding.embed_query(document)
|
||||
assert len(output) == 1536
|
||||
|
||||
|
||||
@pytest.mark.scheduled
|
||||
async def test_azure_openai_embedding_async_query() -> None:
|
||||
"""Test openai embeddings."""
|
||||
document = "foo bar"
|
||||
embedding = _get_embeddings()
|
||||
output = await embedding.aembed_query(document)
|
||||
assert len(output) == 1536
|
||||
|
||||
|
||||
@pytest.mark.scheduled
|
||||
def test_azure_openai_embedding_with_empty_string() -> None:
|
||||
"""Test openai embeddings with empty string."""
|
||||
|
||||
document = ["", "abc"]
|
||||
embedding = _get_embeddings()
|
||||
output = embedding.embed_documents(document)
|
||||
assert len(output) == 2
|
||||
assert len(output[0]) == 1536
|
||||
expected_output = (
|
||||
openai.AzureOpenAI(
|
||||
api_version=OPENAI_API_VERSION,
|
||||
api_key=OPENAI_API_KEY,
|
||||
base_url=embedding.openai_api_base,
|
||||
azure_deployment=DEPLOYMENT_NAME,
|
||||
) # type: ignore
|
||||
.embeddings.create(input="", model="text-embedding-ada-002")
|
||||
.data[0]
|
||||
.embedding
|
||||
)
|
||||
assert np.allclose(output[0], expected_output)
|
||||
assert len(output[1]) == 1536
|
||||
|
||||
|
||||
@pytest.mark.scheduled
|
||||
def test_embed_documents_normalized() -> None:
|
||||
output = _get_embeddings().embed_documents(["foo walked to the market"])
|
||||
assert np.isclose(np.linalg.norm(output[0]), 1.0)
|
||||
|
||||
|
||||
@pytest.mark.scheduled
|
||||
def test_embed_query_normalized() -> None:
|
||||
output = _get_embeddings().embed_query("foo walked to the market")
|
||||
assert np.isclose(np.linalg.norm(output), 1.0)
|
@@ -0,0 +1,19 @@
"""Test OpenAI embeddings."""
from langchain_openai.embeddings.base import OpenAIEmbeddings


def test_langchain_openai_embedding_documents() -> None:
    """Test openai embeddings."""
    documents = ["foo bar"]
    embedding = OpenAIEmbeddings()
    output = embedding.embed_documents(documents)
    assert len(output) == 1
    assert len(output[0]) > 0


def test_langchain_openai_embedding_query() -> None:
    """Test openai embeddings."""
    document = "foo bar"
    embedding = OpenAIEmbeddings()
    output = embedding.embed_query(document)
    assert len(output) > 0
176  libs/partners/openai/tests/integration_tests/llms/test_azure.py  Normal file
@@ -0,0 +1,176 @@
|
||||
"""Test AzureOpenAI wrapper."""
|
||||
import os
|
||||
from typing import Any, Generator
|
||||
|
||||
import pytest
|
||||
from langchain_core.callbacks import CallbackManager
|
||||
from langchain_core.outputs import LLMResult
|
||||
|
||||
from langchain_openai import AzureOpenAI
|
||||
from tests.unit_tests.fake.callbacks import FakeCallbackHandler
|
||||
|
||||
OPENAI_API_VERSION = os.environ.get("AZURE_OPENAI_API_VERSION", "")
|
||||
OPENAI_API_BASE = os.environ.get("AZURE_OPENAI_API_BASE", "")
|
||||
OPENAI_API_KEY = os.environ.get("AZURE_OPENAI_API_KEY", "")
|
||||
DEPLOYMENT_NAME = os.environ.get(
|
||||
"AZURE_OPENAI_DEPLOYMENT_NAME",
|
||||
os.environ.get("AZURE_OPENAI_LLM_DEPLOYMENT_NAME", ""),
|
||||
)
|
||||
|
||||
|
||||
def _get_llm(**kwargs: Any) -> AzureOpenAI:
|
||||
return AzureOpenAI(
|
||||
deployment_name=DEPLOYMENT_NAME,
|
||||
openai_api_version=OPENAI_API_VERSION,
|
||||
openai_api_base=OPENAI_API_BASE,
|
||||
openai_api_key=OPENAI_API_KEY,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def llm() -> AzureOpenAI:
|
||||
return _get_llm(
|
||||
max_tokens=10,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.scheduled
|
||||
def test_openai_call(llm: AzureOpenAI) -> None:
|
||||
"""Test valid call to openai."""
|
||||
output = llm("Say something nice:")
|
||||
assert isinstance(output, str)
|
||||
|
||||
|
||||
@pytest.mark.scheduled
|
||||
def test_openai_streaming(llm: AzureOpenAI) -> None:
|
||||
"""Test streaming tokens from AzureOpenAI."""
|
||||
generator = llm.stream("I'm Pickle Rick")
|
||||
|
||||
assert isinstance(generator, Generator)
|
||||
|
||||
full_response = ""
|
||||
for token in generator:
|
||||
assert isinstance(token, str)
|
||||
full_response += token
|
||||
assert full_response
|
||||
|
||||
|
||||
@pytest.mark.scheduled
|
||||
async def test_openai_astream(llm: AzureOpenAI) -> None:
|
||||
"""Test streaming tokens from AzureOpenAI."""
|
||||
async for token in llm.astream("I'm Pickle Rick"):
|
||||
assert isinstance(token, str)
|
||||
|
||||
|
||||
@pytest.mark.scheduled
|
||||
async def test_openai_abatch(llm: AzureOpenAI) -> None:
|
||||
"""Test streaming tokens from AzureOpenAI."""
|
||||
result = await llm.abatch(["I'm Pickle Rick", "I'm not Pickle Rick"])
|
||||
for token in result:
|
||||
assert isinstance(token, str)
|
||||
|
||||
|
||||
async def test_openai_abatch_tags(llm: AzureOpenAI) -> None:
|
||||
"""Test streaming tokens from AzureOpenAI."""
|
||||
result = await llm.abatch(
|
||||
["I'm Pickle Rick", "I'm not Pickle Rick"], config={"tags": ["foo"]}
|
||||
)
|
||||
for token in result:
|
||||
assert isinstance(token, str)
|
||||
|
||||
|
||||
@pytest.mark.scheduled
|
||||
def test_openai_batch(llm: AzureOpenAI) -> None:
|
||||
"""Test streaming tokens from AzureOpenAI."""
|
||||
result = llm.batch(["I'm Pickle Rick", "I'm not Pickle Rick"])
|
||||
for token in result:
|
||||
assert isinstance(token, str)
|
||||
|
||||
|
||||
@pytest.mark.scheduled
|
||||
async def test_openai_ainvoke(llm: AzureOpenAI) -> None:
|
||||
"""Test streaming tokens from AzureOpenAI."""
|
||||
result = await llm.ainvoke("I'm Pickle Rick", config={"tags": ["foo"]})
|
||||
assert isinstance(result, str)
|
||||
|
||||
|
||||
@pytest.mark.scheduled
|
||||
def test_openai_invoke(llm: AzureOpenAI) -> None:
|
||||
"""Test streaming tokens from AzureOpenAI."""
|
||||
result = llm.invoke("I'm Pickle Rick", config=dict(tags=["foo"]))
|
||||
assert isinstance(result, str)
|
||||
|
||||
|
||||
@pytest.mark.scheduled
|
||||
def test_openai_multiple_prompts(llm: AzureOpenAI) -> None:
|
||||
"""Test completion with multiple prompts."""
|
||||
output = llm.generate(["I'm Pickle Rick", "I'm Pickle Rick"])
|
||||
assert isinstance(output, LLMResult)
|
||||
assert isinstance(output.generations, list)
|
||||
assert len(output.generations) == 2
|
||||
|
||||
|
||||
def test_openai_streaming_best_of_error() -> None:
|
||||
"""Test validation for streaming fails if best_of is not 1."""
|
||||
with pytest.raises(ValueError):
|
||||
_get_llm(best_of=2, streaming=True)
|
||||
|
||||
|
||||
def test_openai_streaming_n_error() -> None:
|
||||
"""Test validation for streaming fails if n is not 1."""
|
||||
with pytest.raises(ValueError):
|
||||
_get_llm(n=2, streaming=True)
|
||||
|
||||
|
||||
def test_openai_streaming_multiple_prompts_error() -> None:
|
||||
"""Test validation for streaming fails if multiple prompts are given."""
|
||||
with pytest.raises(ValueError):
|
||||
_get_llm(streaming=True).generate(["I'm Pickle Rick", "I'm Pickle Rick"])
|
||||
|
||||
|
||||
@pytest.mark.scheduled
|
||||
def test_openai_streaming_call() -> None:
|
||||
"""Test valid call to openai."""
|
||||
llm = _get_llm(max_tokens=10, streaming=True)
|
||||
output = llm("Say foo:")
|
||||
assert isinstance(output, str)
|
||||
|
||||
|
||||
def test_openai_streaming_callback() -> None:
|
||||
"""Test that streaming correctly invokes on_llm_new_token callback."""
|
||||
callback_handler = FakeCallbackHandler()
|
||||
callback_manager = CallbackManager([callback_handler])
|
||||
llm = _get_llm(
|
||||
max_tokens=10,
|
||||
streaming=True,
|
||||
temperature=0,
|
||||
callback_manager=callback_manager,
|
||||
verbose=True,
|
||||
)
|
||||
llm("Write me a sentence with 100 words.")
|
||||
assert callback_handler.llm_streams == 11
|
||||
|
||||
|
||||
@pytest.mark.scheduled
|
||||
async def test_openai_async_generate() -> None:
|
||||
"""Test async generation."""
|
||||
llm = _get_llm(max_tokens=10)
|
||||
output = await llm.agenerate(["Hello, how are you?"])
|
||||
assert isinstance(output, LLMResult)
|
||||
|
||||
|
||||
async def test_openai_async_streaming_callback() -> None:
|
||||
"""Test that streaming correctly invokes on_llm_new_token callback."""
|
||||
callback_handler = FakeCallbackHandler()
|
||||
callback_manager = CallbackManager([callback_handler])
|
||||
llm = _get_llm(
|
||||
max_tokens=10,
|
||||
streaming=True,
|
||||
temperature=0,
|
||||
callback_manager=callback_manager,
|
||||
verbose=True,
|
||||
)
|
||||
result = await llm.agenerate(["Write me a sentence with 100 words."])
|
||||
assert callback_handler.llm_streams == 11
|
||||
assert isinstance(result, LLMResult)
|
280  libs/partners/openai/tests/integration_tests/llms/test_base.py  Normal file
@@ -0,0 +1,280 @@
|
||||
"""Test OpenAI llm."""
|
||||
from typing import Generator
|
||||
|
||||
import pytest
|
||||
from langchain_core.callbacks import CallbackManager
|
||||
from langchain_core.outputs import LLMResult
|
||||
|
||||
from langchain_openai import OpenAI
|
||||
from tests.unit_tests.fake.callbacks import (
|
||||
FakeCallbackHandler,
|
||||
)
|
||||
|
||||
|
||||
def test_stream() -> None:
|
||||
"""Test streaming tokens from OpenAI."""
|
||||
llm = OpenAI()
|
||||
|
||||
for token in llm.stream("I'm Pickle Rick"):
|
||||
assert isinstance(token, str)
|
||||
|
||||
|
||||
async def test_astream() -> None:
|
||||
"""Test streaming tokens from OpenAI."""
|
||||
llm = OpenAI()
|
||||
|
||||
async for token in llm.astream("I'm Pickle Rick"):
|
||||
assert isinstance(token, str)
|
||||
|
||||
|
||||
async def test_abatch() -> None:
|
||||
"""Test streaming tokens from OpenAI."""
|
||||
llm = OpenAI()
|
||||
|
||||
result = await llm.abatch(["I'm Pickle Rick", "I'm not Pickle Rick"])
|
||||
for token in result:
|
||||
assert isinstance(token, str)
|
||||
|
||||
|
||||
async def test_abatch_tags() -> None:
|
||||
"""Test batch tokens from OpenAI."""
|
||||
llm = OpenAI()
|
||||
|
||||
result = await llm.abatch(
|
||||
["I'm Pickle Rick", "I'm not Pickle Rick"], config={"tags": ["foo"]}
|
||||
)
|
||||
for token in result:
|
||||
assert isinstance(token, str)
|
||||
|
||||
|
||||
def test_batch() -> None:
|
||||
"""Test batch tokens from OpenAI."""
|
||||
llm = OpenAI()
|
||||
|
||||
result = llm.batch(["I'm Pickle Rick", "I'm not Pickle Rick"])
|
||||
for token in result:
|
||||
assert isinstance(token, str)
|
||||
|
||||
|
||||
async def test_ainvoke() -> None:
|
||||
"""Test invoke tokens from OpenAI."""
|
||||
llm = OpenAI()
|
||||
|
||||
result = await llm.ainvoke("I'm Pickle Rick", config={"tags": ["foo"]})
|
||||
assert isinstance(result, str)
|
||||
|
||||
|
||||
def test_invoke() -> None:
|
||||
"""Test invoke tokens from OpenAI."""
|
||||
llm = OpenAI()
|
||||
|
||||
result = llm.invoke("I'm Pickle Rick", config=dict(tags=["foo"]))
|
||||
assert isinstance(result, str)
|
||||
|
||||
|
||||
@pytest.mark.scheduled
|
||||
def test_openai_call() -> None:
|
||||
"""Test valid call to openai."""
|
||||
llm = OpenAI()
|
||||
output = llm("Say something nice:")
|
||||
assert isinstance(output, str)
|
||||
|
||||
|
||||
def test_openai_llm_output_contains_model_name() -> None:
|
||||
"""Test llm_output contains model_name."""
|
||||
llm = OpenAI(max_tokens=10)
|
||||
llm_result = llm.generate(["Hello, how are you?"])
|
||||
assert llm_result.llm_output is not None
|
||||
assert llm_result.llm_output["model_name"] == llm.model_name
|
||||
|
||||
|
||||
def test_openai_stop_valid() -> None:
|
||||
"""Test openai stop logic on valid configuration."""
|
||||
query = "write an ordered list of five items"
|
||||
first_llm = OpenAI(stop="3", temperature=0)
|
||||
first_output = first_llm(query)
|
||||
second_llm = OpenAI(temperature=0)
|
||||
second_output = second_llm(query, stop=["3"])
|
||||
# Stopping on "3" via the constructor or via the call kwarg should give the same output
|
||||
assert first_output == second_output
|
||||
|
||||
|
||||
def test_openai_stop_error() -> None:
|
||||
"""Test openai stop logic on bad configuration."""
|
||||
llm = OpenAI(stop="3", temperature=0)
|
||||
with pytest.raises(ValueError):
|
||||
llm("write an ordered list of five items", stop=["\n"])
|
||||
|
||||
|
||||
@pytest.mark.scheduled
|
||||
def test_openai_streaming() -> None:
|
||||
"""Test streaming tokens from OpenAI."""
|
||||
llm = OpenAI(max_tokens=10)
|
||||
generator = llm.stream("I'm Pickle Rick")
|
||||
|
||||
assert isinstance(generator, Generator)
|
||||
|
||||
for token in generator:
|
||||
assert isinstance(token, str)
|
||||
|
||||
|
||||
@pytest.mark.scheduled
|
||||
async def test_openai_astream() -> None:
|
||||
"""Test streaming tokens from OpenAI."""
|
||||
llm = OpenAI(max_tokens=10)
|
||||
|
||||
async for token in llm.astream("I'm Pickle Rick"):
|
||||
assert isinstance(token, str)
|
||||
|
||||
|
||||
@pytest.mark.scheduled
|
||||
async def test_openai_abatch() -> None:
|
||||
"""Test streaming tokens from OpenAI."""
|
||||
llm = OpenAI(max_tokens=10)
|
||||
|
||||
result = await llm.abatch(["I'm Pickle Rick", "I'm not Pickle Rick"])
|
||||
for token in result:
|
||||
assert isinstance(token, str)
|
||||
|
||||
|
||||
async def test_openai_abatch_tags() -> None:
|
||||
"""Test streaming tokens from OpenAI."""
|
||||
llm = OpenAI(max_tokens=10)
|
||||
|
||||
result = await llm.abatch(
|
||||
["I'm Pickle Rick", "I'm not Pickle Rick"], config={"tags": ["foo"]}
|
||||
)
|
||||
for token in result:
|
||||
assert isinstance(token, str)
|
||||
|
||||
|
||||
@pytest.mark.scheduled
|
||||
def test_openai_batch() -> None:
|
||||
"""Test streaming tokens from OpenAI."""
|
||||
llm = OpenAI(max_tokens=10)
|
||||
|
||||
result = llm.batch(["I'm Pickle Rick", "I'm not Pickle Rick"])
|
||||
for token in result:
|
||||
assert isinstance(token, str)
|
||||
|
||||
|
||||
@pytest.mark.scheduled
|
||||
async def test_openai_ainvoke() -> None:
|
||||
"""Test streaming tokens from OpenAI."""
|
||||
llm = OpenAI(max_tokens=10)
|
||||
|
||||
result = await llm.ainvoke("I'm Pickle Rick", config={"tags": ["foo"]})
|
||||
assert isinstance(result, str)
|
||||
|
||||
|
||||
@pytest.mark.scheduled
|
||||
def test_openai_invoke() -> None:
|
||||
"""Test streaming tokens from OpenAI."""
|
||||
llm = OpenAI(max_tokens=10)
|
||||
|
||||
result = llm.invoke("I'm Pickle Rick", config=dict(tags=["foo"]))
|
||||
assert isinstance(result, str)
|
||||
|
||||
|
||||
@pytest.mark.scheduled
|
||||
def test_openai_multiple_prompts() -> None:
|
||||
"""Test completion with multiple prompts."""
|
||||
llm = OpenAI(max_tokens=10)
|
||||
output = llm.generate(["I'm Pickle Rick", "I'm Pickle Rick"])
|
||||
assert isinstance(output, LLMResult)
|
||||
assert isinstance(output.generations, list)
|
||||
assert len(output.generations) == 2
|
||||
|
||||
|
||||
def test_openai_streaming_best_of_error() -> None:
|
||||
"""Test validation for streaming fails if best_of is not 1."""
|
||||
with pytest.raises(ValueError):
|
||||
OpenAI(best_of=2, streaming=True)
|
||||
|
||||
|
||||
def test_openai_streaming_n_error() -> None:
|
||||
"""Test validation for streaming fails if n is not 1."""
|
||||
with pytest.raises(ValueError):
|
||||
OpenAI(n=2, streaming=True)
|
||||
|
||||
|
||||
def test_openai_streaming_multiple_prompts_error() -> None:
|
||||
"""Test validation for streaming fails if multiple prompts are given."""
|
||||
with pytest.raises(ValueError):
|
||||
OpenAI(streaming=True).generate(["I'm Pickle Rick", "I'm Pickle Rick"])
|
||||
|
||||
|
||||
@pytest.mark.scheduled
|
||||
def test_openai_streaming_call() -> None:
|
||||
"""Test valid call to openai."""
|
||||
llm = OpenAI(max_tokens=10, streaming=True)
|
||||
output = llm("Say foo:")
|
||||
assert isinstance(output, str)
|
||||
|
||||
|
||||
def test_openai_streaming_callback() -> None:
|
||||
"""Test that streaming correctly invokes on_llm_new_token callback."""
|
||||
callback_handler = FakeCallbackHandler()
|
||||
callback_manager = CallbackManager([callback_handler])
|
||||
llm = OpenAI(
|
||||
max_tokens=10,
|
||||
streaming=True,
|
||||
temperature=0,
|
||||
callback_manager=callback_manager,
|
||||
verbose=True,
|
||||
)
|
||||
llm("Write me a sentence with 100 words.")
|
||||
|
||||
# new client sometimes passes 2 tokens at once
|
||||
assert callback_handler.llm_streams >= 5
|
||||
|
||||
|
||||
@pytest.mark.scheduled
|
||||
async def test_openai_async_generate() -> None:
|
||||
"""Test async generation."""
|
||||
llm = OpenAI(max_tokens=10)
|
||||
output = await llm.agenerate(["Hello, how are you?"])
|
||||
assert isinstance(output, LLMResult)
|
||||
|
||||
|
||||
async def test_openai_async_streaming_callback() -> None:
|
||||
"""Test that streaming correctly invokes on_llm_new_token callback."""
|
||||
callback_handler = FakeCallbackHandler()
|
||||
callback_manager = CallbackManager([callback_handler])
|
||||
llm = OpenAI(
|
||||
max_tokens=10,
|
||||
streaming=True,
|
||||
temperature=0,
|
||||
callback_manager=callback_manager,
|
||||
verbose=True,
|
||||
)
|
||||
result = await llm.agenerate(["Write me a sentence with 100 words."])
|
||||
|
||||
# new client sometimes passes 2 tokens at once
|
||||
assert callback_handler.llm_streams >= 5
|
||||
assert isinstance(result, LLMResult)
|
||||
|
||||
|
||||
def test_openai_modelname_to_contextsize_valid() -> None:
|
||||
"""Test model name to context size on a valid model."""
|
||||
assert OpenAI().modelname_to_contextsize("davinci") == 2049
|
||||
|
||||
|
||||
def test_openai_modelname_to_contextsize_invalid() -> None:
|
||||
"""Test model name to context size on an invalid model."""
|
||||
with pytest.raises(ValueError):
|
||||
OpenAI().modelname_to_contextsize("foobar")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_completion() -> dict:
|
||||
return {
|
||||
"id": "cmpl-3evkmQda5Hu7fcZavknQda3SQ",
|
||||
"object": "text_completion",
|
||||
"created": 1689989000,
|
||||
"model": "gpt-3.5-turbo-instruct",
|
||||
"choices": [
|
||||
{"text": "Bar Baz", "index": 0, "logprobs": None, "finish_reason": "length"}
|
||||
],
|
||||
"usage": {"prompt_tokens": 1, "completion_tokens": 2, "total_tokens": 3},
|
||||
}
|
@@ -0,0 +1,7 @@
import pytest


@pytest.mark.compile
def test_placeholder() -> None:
    """Used for compiling integration tests without running any real tests."""
    pass
0  libs/partners/openai/tests/unit_tests/__init__.py  Normal file
120  libs/partners/openai/tests/unit_tests/chat_models/test_base.py  Normal file
@@ -0,0 +1,120 @@
|
||||
"""Test OpenAI Chat API wrapper."""
|
||||
import json
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
FunctionMessage,
|
||||
HumanMessage,
|
||||
SystemMessage,
|
||||
)
|
||||
|
||||
from langchain_openai import ChatOpenAI
|
||||
from langchain_openai.chat_models.base import _convert_dict_to_message
|
||||
|
||||
|
||||
def test_openai_model_param() -> None:
|
||||
llm = ChatOpenAI(model="foo")
|
||||
assert llm.model_name == "foo"
|
||||
llm = ChatOpenAI(model_name="foo")
|
||||
assert llm.model_name == "foo"
|
||||
|
||||
|
||||
def test_function_message_dict_to_function_message() -> None:
|
||||
content = json.dumps({"result": "Example #1"})
|
||||
name = "test_function"
|
||||
result = _convert_dict_to_message(
|
||||
{
|
||||
"role": "function",
|
||||
"name": name,
|
||||
"content": content,
|
||||
}
|
||||
)
|
||||
assert isinstance(result, FunctionMessage)
|
||||
assert result.name == name
|
||||
assert result.content == content
|
||||
|
||||
|
||||
def test__convert_dict_to_message_human() -> None:
|
||||
message = {"role": "user", "content": "foo"}
|
||||
result = _convert_dict_to_message(message)
|
||||
expected_output = HumanMessage(content="foo")
|
||||
assert result == expected_output
|
||||
|
||||
|
||||
def test__convert_dict_to_message_ai() -> None:
|
||||
message = {"role": "assistant", "content": "foo"}
|
||||
result = _convert_dict_to_message(message)
|
||||
expected_output = AIMessage(content="foo")
|
||||
assert result == expected_output
|
||||
|
||||
|
||||
def test__convert_dict_to_message_system() -> None:
|
||||
message = {"role": "system", "content": "foo"}
|
||||
result = _convert_dict_to_message(message)
|
||||
expected_output = SystemMessage(content="foo")
|
||||
assert result == expected_output
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_completion() -> dict:
|
||||
return {
|
||||
"id": "chatcmpl-7fcZavknQda3SQ",
|
||||
"object": "chat.completion",
|
||||
"created": 1689989000,
|
||||
"model": "gpt-3.5-turbo-0613",
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "Bar Baz",
|
||||
},
|
||||
"finish_reason": "stop",
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
def test_openai_predict(mock_completion: dict) -> None:
|
||||
llm = ChatOpenAI()
|
||||
mock_client = MagicMock()
|
||||
completed = False
|
||||
|
||||
def mock_create(*args: Any, **kwargs: Any) -> Any:
|
||||
nonlocal completed
|
||||
completed = True
|
||||
return mock_completion
|
||||
|
||||
mock_client.create = mock_create
|
||||
with patch.object(
|
||||
llm,
|
||||
"client",
|
||||
mock_client,
|
||||
):
|
||||
res = llm.predict("bar")
|
||||
assert res == "Bar Baz"
|
||||
assert completed
|
||||
|
||||
|
||||
async def test_openai_apredict(mock_completion: dict) -> None:
|
||||
llm = ChatOpenAI()
|
||||
mock_client = MagicMock()
|
||||
completed = False
|
||||
|
||||
def mock_create(*args: Any, **kwargs: Any) -> Any:
|
||||
nonlocal completed
|
||||
completed = True
|
||||
return mock_completion
|
||||
|
||||
mock_client.create = mock_create
|
||||
with patch.object(
|
||||
llm,
|
||||
"client",
|
||||
mock_client,
|
||||
):
|
||||
res = llm.predict("bar")
|
||||
assert res == "Bar Baz"
|
||||
assert completed
|
@@ -0,0 +1,7 @@
from langchain_openai.chat_models import __all__

EXPECTED_ALL = ["ChatOpenAI", "AzureChatOpenAI"]


def test_all_imports() -> None:
    assert sorted(EXPECTED_ALL) == sorted(__all__)
@@ -0,0 +1,18 @@
import os

import pytest

from langchain_openai import OpenAIEmbeddings

os.environ["OPENAI_API_KEY"] = "foo"


def test_openai_invalid_model_kwargs() -> None:
    with pytest.raises(ValueError):
        OpenAIEmbeddings(model_kwargs={"model": "foo"})


def test_openai_incorrect_field() -> None:
    with pytest.warns(match="not default parameter"):
        llm = OpenAIEmbeddings(foo="bar")
    assert llm.model_kwargs == {"foo": "bar"}
@@ -0,0 +1,7 @@
from langchain_openai.embeddings import __all__

EXPECTED_ALL = ["OpenAIEmbeddings", "AzureOpenAIEmbeddings"]


def test_all_imports() -> None:
    assert sorted(EXPECTED_ALL) == sorted(__all__)
393  libs/partners/openai/tests/unit_tests/fake/callbacks.py  Normal file
@@ -0,0 +1,393 @@
|
||||
"""A fake callback handler for testing purposes."""
|
||||
from itertools import chain
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from uuid import UUID
|
||||
|
||||
from langchain_core.callbacks.base import AsyncCallbackHandler, BaseCallbackHandler
|
||||
from langchain_core.messages import BaseMessage
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
|
||||
|
||||
class BaseFakeCallbackHandler(BaseModel):
|
||||
"""Base fake callback handler for testing."""
|
||||
|
||||
starts: int = 0
|
||||
ends: int = 0
|
||||
errors: int = 0
|
||||
errors_args: List[Any] = []
|
||||
text: int = 0
|
||||
ignore_llm_: bool = False
|
||||
ignore_chain_: bool = False
|
||||
ignore_agent_: bool = False
|
||||
ignore_retriever_: bool = False
|
||||
ignore_chat_model_: bool = False
|
||||
|
||||
# to allow for similar callback handlers that are not technically equal
|
||||
fake_id: Union[str, None] = None
|
||||
|
||||
# add finer-grained counters for easier debugging of failing tests
|
||||
chain_starts: int = 0
|
||||
chain_ends: int = 0
|
||||
llm_starts: int = 0
|
||||
llm_ends: int = 0
|
||||
llm_streams: int = 0
|
||||
tool_starts: int = 0
|
||||
tool_ends: int = 0
|
||||
agent_actions: int = 0
|
||||
agent_ends: int = 0
|
||||
chat_model_starts: int = 0
|
||||
retriever_starts: int = 0
|
||||
retriever_ends: int = 0
|
||||
retriever_errors: int = 0
|
||||
retries: int = 0
|
||||
|
||||
|
||||
class BaseFakeCallbackHandlerMixin(BaseFakeCallbackHandler):
|
||||
"""Base fake callback handler mixin for testing."""
|
||||
|
||||
def on_llm_start_common(self) -> None:
|
||||
self.llm_starts += 1
|
||||
self.starts += 1
|
||||
|
||||
def on_llm_end_common(self) -> None:
|
||||
self.llm_ends += 1
|
||||
self.ends += 1
|
||||
|
||||
def on_llm_error_common(self, *args: Any, **kwargs: Any) -> None:
|
||||
self.errors += 1
|
||||
self.errors_args.append({"args": args, "kwargs": kwargs})
|
||||
|
||||
def on_llm_new_token_common(self) -> None:
|
||||
self.llm_streams += 1
|
||||
|
||||
def on_retry_common(self) -> None:
|
||||
self.retries += 1
|
||||
|
||||
def on_chain_start_common(self) -> None:
|
||||
self.chain_starts += 1
|
||||
self.starts += 1
|
||||
|
||||
def on_chain_end_common(self) -> None:
|
||||
self.chain_ends += 1
|
||||
self.ends += 1
|
||||
|
||||
def on_chain_error_common(self) -> None:
|
||||
self.errors += 1
|
||||
|
||||
def on_tool_start_common(self) -> None:
|
||||
self.tool_starts += 1
|
||||
self.starts += 1
|
||||
|
||||
def on_tool_end_common(self) -> None:
|
||||
self.tool_ends += 1
|
||||
self.ends += 1
|
||||
|
||||
def on_tool_error_common(self) -> None:
|
||||
self.errors += 1
|
||||
|
||||
def on_agent_action_common(self) -> None:
|
||||
self.agent_actions += 1
|
||||
self.starts += 1
|
||||
|
||||
def on_agent_finish_common(self) -> None:
|
||||
self.agent_ends += 1
|
||||
self.ends += 1
|
||||
|
||||
def on_chat_model_start_common(self) -> None:
|
||||
self.chat_model_starts += 1
|
||||
self.starts += 1
|
||||
|
||||
def on_text_common(self) -> None:
|
||||
self.text += 1
|
||||
|
||||
def on_retriever_start_common(self) -> None:
|
||||
self.starts += 1
|
||||
self.retriever_starts += 1
|
||||
|
||||
def on_retriever_end_common(self) -> None:
|
||||
self.ends += 1
|
||||
self.retriever_ends += 1
|
||||
|
||||
def on_retriever_error_common(self) -> None:
|
||||
self.errors += 1
|
||||
self.retriever_errors += 1
|
||||
|
||||
|
||||
class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
|
||||
"""Fake callback handler for testing."""
|
||||
|
||||
@property
|
||||
def ignore_llm(self) -> bool:
|
||||
"""Whether to ignore LLM callbacks."""
|
||||
return self.ignore_llm_
|
||||
|
||||
@property
|
||||
def ignore_chain(self) -> bool:
|
||||
"""Whether to ignore chain callbacks."""
|
||||
return self.ignore_chain_
|
||||
|
||||
@property
|
||||
def ignore_agent(self) -> bool:
|
||||
"""Whether to ignore agent callbacks."""
|
||||
return self.ignore_agent_
|
||||
|
||||
@property
|
||||
def ignore_retriever(self) -> bool:
|
||||
"""Whether to ignore retriever callbacks."""
|
||||
return self.ignore_retriever_
|
||||
|
||||
def on_llm_start(
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
self.on_llm_start_common()
|
||||
|
||||
def on_llm_new_token(
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
self.on_llm_new_token_common()
|
||||
|
||||
def on_llm_end(
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
self.on_llm_end_common()
|
||||
|
||||
def on_llm_error(
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
self.on_llm_error_common(*args, **kwargs)
|
||||
|
||||
def on_retry(
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
self.on_retry_common()
|
||||
|
||||
def on_chain_start(
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
self.on_chain_start_common()
|
||||
|
||||
def on_chain_end(
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
self.on_chain_end_common()
|
||||
|
||||
def on_chain_error(
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
self.on_chain_error_common()
|
||||
|
||||
def on_tool_start(
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
self.on_tool_start_common()
|
||||
|
||||
def on_tool_end(
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
self.on_tool_end_common()
|
||||
|
||||
def on_tool_error(
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
self.on_tool_error_common()
|
||||
|
||||
def on_agent_action(
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
self.on_agent_action_common()
|
||||
|
||||
def on_agent_finish(
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
self.on_agent_finish_common()
|
||||
|
||||
def on_text(
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
self.on_text_common()
|
||||
|
||||
def on_retriever_start(
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
self.on_retriever_start_common()
|
||||
|
||||
def on_retriever_end(
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
self.on_retriever_end_common()
|
||||
|
||||
def on_retriever_error(
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
self.on_retriever_error_common()
|
||||
|
||||
def __deepcopy__(self, memo: dict) -> "FakeCallbackHandler":
|
||||
return self
|
||||
|
||||
|
||||
class FakeCallbackHandlerWithChatStart(FakeCallbackHandler):
|
||||
def on_chat_model_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
messages: List[List[BaseMessage]],
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
assert all(isinstance(m, BaseMessage) for m in chain(*messages))
|
||||
self.on_chat_model_start_common()
|
||||
|
||||
|
||||
class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixin):
|
||||
"""Fake async callback handler for testing."""
|
||||
|
||||
@property
|
||||
def ignore_llm(self) -> bool:
|
||||
"""Whether to ignore LLM callbacks."""
|
||||
return self.ignore_llm_
|
||||
|
||||
@property
|
||||
def ignore_chain(self) -> bool:
|
||||
"""Whether to ignore chain callbacks."""
|
||||
return self.ignore_chain_
|
||||
|
||||
@property
|
||||
def ignore_agent(self) -> bool:
|
||||
"""Whether to ignore agent callbacks."""
|
||||
return self.ignore_agent_
|
||||
|
||||
async def on_retry(
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
self.on_retry_common()
|
||||
|
||||
async def on_llm_start(
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
self.on_llm_start_common()
|
||||
|
||||
async def on_llm_new_token(
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
self.on_llm_new_token_common()
|
||||
|
||||
async def on_llm_end(
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
self.on_llm_end_common()
|
||||
|
||||
async def on_llm_error(
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
self.on_llm_error_common(*args, **kwargs)
|
||||
|
||||
async def on_chain_start(
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
self.on_chain_start_common()
|
||||
|
||||
async def on_chain_end(
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
self.on_chain_end_common()
|
||||
|
||||
async def on_chain_error(
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
self.on_chain_error_common()
|
||||
|
||||
async def on_tool_start(
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
self.on_tool_start_common()
|
||||
|
||||
async def on_tool_end(
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
self.on_tool_end_common()
|
||||
|
||||
async def on_tool_error(
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
self.on_tool_error_common()
|
||||
|
||||
async def on_agent_action(
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
self.on_agent_action_common()
|
||||
|
||||
async def on_agent_finish(
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
self.on_agent_finish_common()
|
||||
|
||||
async def on_text(
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
self.on_text_common()
|
||||
|
||||
def __deepcopy__(self, memo: dict) -> "FakeAsyncCallbackHandler":
|
||||
return self
|
48  libs/partners/openai/tests/unit_tests/llms/test_base.py  Normal file
@@ -0,0 +1,48 @@
import os

import pytest

from langchain_openai import OpenAI

os.environ["OPENAI_API_KEY"] = "foo"


@pytest.mark.requires("openai")
def test_openai_model_param() -> None:
    llm = OpenAI(model="foo")
    assert llm.model_name == "foo"
    llm = OpenAI(model_name="foo")
    assert llm.model_name == "foo"


@pytest.mark.requires("openai")
def test_openai_model_kwargs() -> None:
    llm = OpenAI(model_kwargs={"foo": "bar"})
    assert llm.model_kwargs == {"foo": "bar"}


@pytest.mark.requires("openai")
def test_openai_invalid_model_kwargs() -> None:
    with pytest.raises(ValueError):
        OpenAI(model_kwargs={"model_name": "foo"})


@pytest.mark.requires("openai")
def test_openai_incorrect_field() -> None:
    with pytest.warns(match="not default parameter"):
        llm = OpenAI(foo="bar")
    assert llm.model_kwargs == {"foo": "bar"}


@pytest.fixture
def mock_completion() -> dict:
    return {
        "id": "cmpl-3evkmQda5Hu7fcZavknQda3SQ",
        "object": "text_completion",
        "created": 1689989000,
        "model": "text-davinci-003",
        "choices": [
            {"text": "Bar Baz", "index": 0, "logprobs": None, "finish_reason": "length"}
        ],
        "usage": {"prompt_tokens": 1, "completion_tokens": 2, "total_tokens": 3},
    }
@@ -0,0 +1,7 @@
from langchain_openai.llms import __all__

EXPECTED_ALL = ["OpenAI", "AzureOpenAI"]


def test_all_imports() -> None:
    assert sorted(EXPECTED_ALL) == sorted(__all__)
14  libs/partners/openai/tests/unit_tests/test_imports.py  Normal file
@@ -0,0 +1,14 @@
from langchain_openai import __all__

EXPECTED_ALL = [
    "OpenAI",
    "ChatOpenAI",
    "OpenAIEmbeddings",
    "AzureOpenAI",
    "AzureChatOpenAI",
    "AzureOpenAIEmbeddings",
]


def test_all_imports() -> None:
    assert sorted(EXPECTED_ALL) == sorted(__all__)
39  libs/partners/openai/tests/unit_tests/test_token_counts.py  Normal file
@@ -0,0 +1,39 @@
import pytest

from langchain_openai import ChatOpenAI, OpenAI

_EXPECTED_NUM_TOKENS = {
    "ada": 17,
    "babbage": 17,
    "curie": 17,
    "davinci": 17,
    "gpt-4": 12,
    "gpt-4-32k": 12,
    "gpt-3.5-turbo": 12,
}

_MODELS = [
    "ada",
    "babbage",
    "curie",
    "davinci",
]
_CHAT_MODELS = [
    "gpt-4",
    "gpt-4-32k",
    "gpt-3.5-turbo",
]


@pytest.mark.parametrize("model", _MODELS)
def test_openai_get_num_tokens(model: str) -> None:
    """Test get_tokens."""
    llm = OpenAI(model=model)
    assert llm.get_num_tokens("表情符号是\n🦜🔗") == _EXPECTED_NUM_TOKENS[model]


@pytest.mark.parametrize("model", _CHAT_MODELS)
def test_chat_openai_get_num_tokens(model: str) -> None:
    """Test get_tokens."""
    llm = ChatOpenAI(model=model)
    assert llm.get_num_tokens("表情符号是\n🦜🔗") == _EXPECTED_NUM_TOKENS[model]