mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-13 16:36:06 +00:00
fix linting errors
This commit is contained in:
parent
631289a38d
commit
248db75cd6
@ -19,19 +19,20 @@ from typing import (
|
|||||||
from langchain.callbacks.manager import CallbackManagerForChainRun
|
from langchain.callbacks.manager import CallbackManagerForChainRun
|
||||||
from langchain.chains.base import Chain
|
from langchain.chains.base import Chain
|
||||||
from langchain.chains.llm import LLMChain
|
from langchain.chains.llm import LLMChain
|
||||||
from langchain_experimental.rl_chain.metrics import (
|
|
||||||
MetricsTrackerAverage,
|
|
||||||
MetricsTrackerRollingWindow,
|
|
||||||
)
|
|
||||||
from langchain_experimental.rl_chain.model_repository import ModelRepository
|
|
||||||
from langchain_experimental.rl_chain.vw_logger import VwLogger
|
|
||||||
from langchain.prompts import (
|
from langchain.prompts import (
|
||||||
BasePromptTemplate,
|
BasePromptTemplate,
|
||||||
ChatPromptTemplate,
|
ChatPromptTemplate,
|
||||||
HumanMessagePromptTemplate,
|
HumanMessagePromptTemplate,
|
||||||
SystemMessagePromptTemplate,
|
SystemMessagePromptTemplate,
|
||||||
)
|
)
|
||||||
|
|
||||||
from langchain_experimental.pydantic_v1 import BaseModel, Extra, root_validator
|
from langchain_experimental.pydantic_v1 import BaseModel, Extra, root_validator
|
||||||
|
from langchain_experimental.rl_chain.metrics import (
|
||||||
|
MetricsTrackerAverage,
|
||||||
|
MetricsTrackerRollingWindow,
|
||||||
|
)
|
||||||
|
from langchain_experimental.rl_chain.model_repository import ModelRepository
|
||||||
|
from langchain_experimental.rl_chain.vw_logger import VwLogger
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
import vowpal_wabbit_next as vw
|
import vowpal_wabbit_next as vw
|
||||||
|
@ -3,12 +3,13 @@ from __future__ import annotations
|
|||||||
import logging
|
import logging
|
||||||
from typing import Any, Dict, List, Optional, Tuple, Type, Union
|
from typing import Any, Dict, List, Optional, Tuple, Type, Union
|
||||||
|
|
||||||
import langchain_experimental.rl_chain.base as base
|
|
||||||
from langchain.base_language import BaseLanguageModel
|
from langchain.base_language import BaseLanguageModel
|
||||||
from langchain.callbacks.manager import CallbackManagerForChainRun
|
from langchain.callbacks.manager import CallbackManagerForChainRun
|
||||||
from langchain.chains.llm import LLMChain
|
from langchain.chains.llm import LLMChain
|
||||||
from langchain.prompts import BasePromptTemplate
|
from langchain.prompts import BasePromptTemplate
|
||||||
|
|
||||||
|
import langchain_experimental.rl_chain.base as base
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# sentinel object used to distinguish between
|
# sentinel object used to distinguish between
|
||||||
|
@ -1,12 +1,12 @@
|
|||||||
from typing import Any, Dict
|
from typing import Any, Dict
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
from langchain.chat_models import FakeListChatModel
|
||||||
|
from langchain.prompts.prompt import PromptTemplate
|
||||||
from test_utils import MockEncoder, MockEncoderReturnsList
|
from test_utils import MockEncoder, MockEncoderReturnsList
|
||||||
|
|
||||||
import langchain_experimental.rl_chain.base as rl_chain
|
import langchain_experimental.rl_chain.base as rl_chain
|
||||||
import langchain_experimental.rl_chain.pick_best_chain as pick_best_chain
|
import langchain_experimental.rl_chain.pick_best_chain as pick_best_chain
|
||||||
from langchain.chat_models import FakeListChatModel
|
|
||||||
from langchain.prompts.prompt import PromptTemplate
|
|
||||||
|
|
||||||
encoded_keyword = "[encoded]"
|
encoded_keyword = "[encoded]"
|
||||||
|
|
||||||
@ -90,11 +90,13 @@ def test_update_with_delayed_score_with_auto_validator_throws() -> None:
|
|||||||
User=rl_chain.BasedOn("Context"),
|
User=rl_chain.BasedOn("Context"),
|
||||||
action=rl_chain.ToSelectFrom(actions),
|
action=rl_chain.ToSelectFrom(actions),
|
||||||
)
|
)
|
||||||
assert response["response"] == "hey" # type: ignore
|
assert response["response"] == "hey" # type: ignore
|
||||||
selection_metadata = response["selection_metadata"] # type: ignore
|
selection_metadata = response["selection_metadata"] # type: ignore
|
||||||
assert selection_metadata.selected.score == 3.0 # type: ignore
|
assert selection_metadata.selected.score == 3.0 # type: ignore
|
||||||
with pytest.raises(RuntimeError):
|
with pytest.raises(RuntimeError):
|
||||||
chain.update_with_delayed_score(chain_response=response, score=100) # type: ignore
|
chain.update_with_delayed_score(
|
||||||
|
chain_response=response, score=100 # type: ignore
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
||||||
@ -209,7 +211,7 @@ def test_everything_embedded() -> None:
|
|||||||
action=rl_chain.EmbedAndKeep(rl_chain.ToSelectFrom(actions)),
|
action=rl_chain.EmbedAndKeep(rl_chain.ToSelectFrom(actions)),
|
||||||
)
|
)
|
||||||
selection_metadata = response["selection_metadata"] # type: ignore
|
selection_metadata = response["selection_metadata"] # type: ignore
|
||||||
vw_str = feature_embedder.format(selection_metadata) # type: ignore
|
vw_str = feature_embedder.format(selection_metadata) # type: ignore
|
||||||
assert vw_str == expected
|
assert vw_str == expected
|
||||||
|
|
||||||
|
|
||||||
@ -237,7 +239,7 @@ def test_default_auto_embedder_is_off() -> None:
|
|||||||
action=pick_best_chain.base.ToSelectFrom(actions),
|
action=pick_best_chain.base.ToSelectFrom(actions),
|
||||||
)
|
)
|
||||||
selection_metadata = response["selection_metadata"] # type: ignore
|
selection_metadata = response["selection_metadata"] # type: ignore
|
||||||
vw_str = feature_embedder.format(selection_metadata) # type: ignore
|
vw_str = feature_embedder.format(selection_metadata) # type: ignore
|
||||||
assert vw_str == expected
|
assert vw_str == expected
|
||||||
|
|
||||||
|
|
||||||
@ -264,8 +266,8 @@ def test_default_w_embeddings_off() -> None:
|
|||||||
User=rl_chain.BasedOn(ctx_str_1),
|
User=rl_chain.BasedOn(ctx_str_1),
|
||||||
action=rl_chain.ToSelectFrom(actions),
|
action=rl_chain.ToSelectFrom(actions),
|
||||||
)
|
)
|
||||||
selection_metadata = response["selection_metadata"] # type: ignore
|
selection_metadata = response["selection_metadata"] # type: ignore
|
||||||
vw_str = feature_embedder.format(selection_metadata) # type: ignore
|
vw_str = feature_embedder.format(selection_metadata) # type: ignore
|
||||||
assert vw_str == expected
|
assert vw_str == expected
|
||||||
|
|
||||||
|
|
||||||
@ -292,8 +294,8 @@ def test_default_w_embeddings_on() -> None:
|
|||||||
User=rl_chain.BasedOn(ctx_str_1),
|
User=rl_chain.BasedOn(ctx_str_1),
|
||||||
action=rl_chain.ToSelectFrom(actions),
|
action=rl_chain.ToSelectFrom(actions),
|
||||||
)
|
)
|
||||||
selection_metadata = response["selection_metadata"] # type: ignore
|
selection_metadata = response["selection_metadata"] # type: ignore
|
||||||
vw_str = feature_embedder.format(selection_metadata) # type: ignore
|
vw_str = feature_embedder.format(selection_metadata) # type: ignore
|
||||||
assert vw_str == expected
|
assert vw_str == expected
|
||||||
|
|
||||||
|
|
||||||
@ -324,8 +326,8 @@ def test_default_embeddings_mixed_w_explicit_user_embeddings() -> None:
|
|||||||
User2=rl_chain.BasedOn(ctx_str_2),
|
User2=rl_chain.BasedOn(ctx_str_2),
|
||||||
action=rl_chain.ToSelectFrom(actions),
|
action=rl_chain.ToSelectFrom(actions),
|
||||||
)
|
)
|
||||||
selection_metadata = response["selection_metadata"] # type: ignore
|
selection_metadata = response["selection_metadata"] # type: ignore
|
||||||
vw_str = feature_embedder.format(selection_metadata) # type: ignore
|
vw_str = feature_embedder.format(selection_metadata) # type: ignore
|
||||||
assert vw_str == expected
|
assert vw_str == expected
|
||||||
|
|
||||||
|
|
||||||
@ -345,9 +347,9 @@ def test_default_no_scorer_specified() -> None:
|
|||||||
action=rl_chain.ToSelectFrom(["0", "1", "2"]),
|
action=rl_chain.ToSelectFrom(["0", "1", "2"]),
|
||||||
)
|
)
|
||||||
# chain llm used for both basic prompt and for scoring
|
# chain llm used for both basic prompt and for scoring
|
||||||
assert response["response"] == "hey" # type: ignore
|
assert response["response"] == "hey" # type: ignore
|
||||||
selection_metadata = response["selection_metadata"] # type: ignore
|
selection_metadata = response["selection_metadata"] # type: ignore
|
||||||
assert selection_metadata.selected.score == 100.0 # type: ignore
|
assert selection_metadata.selected.score == 100.0 # type: ignore
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
||||||
@ -366,9 +368,9 @@ def test_explicitly_no_scorer() -> None:
|
|||||||
action=rl_chain.ToSelectFrom(["0", "1", "2"]),
|
action=rl_chain.ToSelectFrom(["0", "1", "2"]),
|
||||||
)
|
)
|
||||||
# chain llm used for both basic prompt and for scoring
|
# chain llm used for both basic prompt and for scoring
|
||||||
assert response["response"] == "hey" # type: ignore
|
assert response["response"] == "hey" # type: ignore
|
||||||
selection_metadata = response["selection_metadata"] # type: ignore
|
selection_metadata = response["selection_metadata"] # type: ignore
|
||||||
assert selection_metadata.selected.score is None # type: ignore
|
assert selection_metadata.selected.score is None # type: ignore
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
||||||
@ -388,9 +390,9 @@ def test_auto_scorer_with_user_defined_llm() -> None:
|
|||||||
action=rl_chain.ToSelectFrom(["0", "1", "2"]),
|
action=rl_chain.ToSelectFrom(["0", "1", "2"]),
|
||||||
)
|
)
|
||||||
# chain llm used for both basic prompt and for scoring
|
# chain llm used for both basic prompt and for scoring
|
||||||
assert response["response"] == "hey" # type: ignore
|
assert response["response"] == "hey" # type: ignore
|
||||||
selection_metadata = response["selection_metadata"] # type: ignore
|
selection_metadata = response["selection_metadata"] # type: ignore
|
||||||
assert selection_metadata.selected.score == 300.0 # type: ignore
|
assert selection_metadata.selected.score == 300.0 # type: ignore
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
@pytest.mark.requires("vowpal_wabbit_next", "sentence_transformers")
|
||||||
@ -434,24 +436,24 @@ def test_activate_and_deactivate_scorer() -> None:
|
|||||||
action=pick_best_chain.base.ToSelectFrom(["0", "1", "2"]),
|
action=pick_best_chain.base.ToSelectFrom(["0", "1", "2"]),
|
||||||
)
|
)
|
||||||
# chain llm used for both basic prompt and for scoring
|
# chain llm used for both basic prompt and for scoring
|
||||||
assert response["response"] == "hey1" # type: ignore
|
assert response["response"] == "hey1" # type: ignore
|
||||||
selection_metadata = response["selection_metadata"] # type: ignore
|
selection_metadata = response["selection_metadata"] # type: ignore
|
||||||
assert selection_metadata.selected.score == 300.0 # type: ignore
|
assert selection_metadata.selected.score == 300.0 # type: ignore
|
||||||
|
|
||||||
chain.deactivate_selection_scorer()
|
chain.deactivate_selection_scorer()
|
||||||
response = chain.run(
|
response = chain.run(
|
||||||
User=pick_best_chain.base.BasedOn("Context"),
|
User=pick_best_chain.base.BasedOn("Context"),
|
||||||
action=pick_best_chain.base.ToSelectFrom(["0", "1", "2"]),
|
action=pick_best_chain.base.ToSelectFrom(["0", "1", "2"]),
|
||||||
)
|
)
|
||||||
assert response["response"] == "hey2" # type: ignore
|
assert response["response"] == "hey2" # type: ignore
|
||||||
selection_metadata = response["selection_metadata"] # type: ignore
|
selection_metadata = response["selection_metadata"] # type: ignore
|
||||||
assert selection_metadata.selected.score is None # type: ignore
|
assert selection_metadata.selected.score is None # type: ignore
|
||||||
|
|
||||||
chain.activate_selection_scorer()
|
chain.activate_selection_scorer()
|
||||||
response = chain.run(
|
response = chain.run(
|
||||||
User=pick_best_chain.base.BasedOn("Context"),
|
User=pick_best_chain.base.BasedOn("Context"),
|
||||||
action=pick_best_chain.base.ToSelectFrom(["0", "1", "2"]),
|
action=pick_best_chain.base.ToSelectFrom(["0", "1", "2"]),
|
||||||
)
|
)
|
||||||
assert response["response"] == "hey3" # type: ignore
|
assert response["response"] == "hey3" # type: ignore
|
||||||
selection_metadata = response["selection_metadata"] # type: ignore
|
selection_metadata = response["selection_metadata"] # type: ignore
|
||||||
assert selection_metadata.selected.score == 400.0 # type: ignore
|
assert selection_metadata.selected.score == 400.0 # type: ignore
|
||||||
|
Loading…
Reference in New Issue
Block a user