Fixed missing optional tags. Added default key value for Ollama (#12599)

Added missing Optional typings. Added default values for Ollama optional
keys.

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
Ackermann Yuriy 2023-10-31 12:30:10 +13:00 committed by GitHub
parent f6f3ca12e7
commit 99b69fe607
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 55 additions and 41 deletions

View File

@ -1481,8 +1481,8 @@ class CallbackManagerForChainGroup(CallbackManager):
def __init__(
self,
handlers: List[BaseCallbackHandler],
inheritable_handlers: List[BaseCallbackHandler] | None = None,
parent_run_id: UUID | None = None,
inheritable_handlers: Optional[List[BaseCallbackHandler]] = None,
parent_run_id: Optional[UUID] = None,
*,
parent_run_manager: CallbackManagerForChainRun,
**kwargs: Any,
@ -1817,8 +1817,8 @@ class AsyncCallbackManagerForChainGroup(AsyncCallbackManager):
def __init__(
self,
handlers: List[BaseCallbackHandler],
inheritable_handlers: List[BaseCallbackHandler] | None = None,
parent_run_id: UUID | None = None,
inheritable_handlers: Optional[List[BaseCallbackHandler]] = None,
parent_run_id: Optional[UUID] = None,
*,
parent_run_manager: AsyncCallbackManagerForChainRun,
**kwargs: Any,

View File

@ -40,62 +40,62 @@ class OllamaEmbeddings(BaseModel, Embeddings):
query_instruction: str = "query: "
"""Instruction used to embed the query."""
mirostat: Optional[int]
mirostat: Optional[int] = None
"""Enable Mirostat sampling for controlling perplexity.
(default: 0, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)"""
mirostat_eta: Optional[float]
mirostat_eta: Optional[float] = None
"""Influences how quickly the algorithm responds to feedback
from the generated text. A lower learning rate will result in
slower adjustments, while a higher learning rate will make
the algorithm more responsive. (Default: 0.1)"""
mirostat_tau: Optional[float]
mirostat_tau: Optional[float] = None
"""Controls the balance between coherence and diversity
of the output. A lower value will result in more focused and
coherent text. (Default: 5.0)"""
num_ctx: Optional[int]
num_ctx: Optional[int] = None
"""Sets the size of the context window used to generate the
next token. (Default: 2048) """
num_gpu: Optional[int]
num_gpu: Optional[int] = None
"""The number of GPUs to use. On macOS it defaults to 1 to
enable metal support, 0 to disable."""
num_thread: Optional[int]
num_thread: Optional[int] = None
"""Sets the number of threads to use during computation.
By default, Ollama will detect this for optimal performance.
It is recommended to set this value to the number of physical
CPU cores your system has (as opposed to the logical number of cores)."""
repeat_last_n: Optional[int]
repeat_last_n: Optional[int] = None
"""Sets how far back for the model to look back to prevent
repetition. (Default: 64, 0 = disabled, -1 = num_ctx)"""
repeat_penalty: Optional[float]
repeat_penalty: Optional[float] = None
"""Sets how strongly to penalize repetitions. A higher value (e.g., 1.5)
will penalize repetitions more strongly, while a lower value (e.g., 0.9)
will be more lenient. (Default: 1.1)"""
temperature: Optional[float]
temperature: Optional[float] = None
"""The temperature of the model. Increasing the temperature will
make the model answer more creatively. (Default: 0.8)"""
stop: Optional[List[str]]
stop: Optional[List[str]] = None
"""Sets the stop tokens to use."""
tfs_z: Optional[float]
tfs_z: Optional[float] = None
"""Tail free sampling is used to reduce the impact of less probable
tokens from the output. A higher value (e.g., 2.0) will reduce the
impact more, while a value of 1.0 disables this setting. (default: 1)"""
top_k: Optional[int]
top_k: Optional[int] = None
"""Reduces the probability of generating nonsense. A higher value (e.g. 100)
will give more diverse answers, while a lower value (e.g. 10)
will be more conservative. (Default: 40)"""
top_p: Optional[int]
top_p: Optional[int] = None
"""Works together with top-k. A higher value (e.g., 0.95) will lead
to more diverse text, while a lower value (e.g., 0.5) will
generate more focused and conservative text. (Default: 0.9)"""

View File

@ -29,62 +29,62 @@ class _OllamaCommon(BaseLanguageModel):
model: str = "llama2"
"""Model name to use."""
mirostat: Optional[int]
mirostat: Optional[int] = None
"""Enable Mirostat sampling for controlling perplexity.
(default: 0, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)"""
mirostat_eta: Optional[float]
mirostat_eta: Optional[float] = None
"""Influences how quickly the algorithm responds to feedback
from the generated text. A lower learning rate will result in
slower adjustments, while a higher learning rate will make
the algorithm more responsive. (Default: 0.1)"""
mirostat_tau: Optional[float]
mirostat_tau: Optional[float] = None
"""Controls the balance between coherence and diversity
of the output. A lower value will result in more focused and
coherent text. (Default: 5.0)"""
num_ctx: Optional[int]
num_ctx: Optional[int] = None
"""Sets the size of the context window used to generate the
next token. (Default: 2048) """
num_gpu: Optional[int]
num_gpu: Optional[int] = None
"""The number of GPUs to use. On macOS it defaults to 1 to
enable metal support, 0 to disable."""
num_thread: Optional[int]
num_thread: Optional[int] = None
"""Sets the number of threads to use during computation.
By default, Ollama will detect this for optimal performance.
It is recommended to set this value to the number of physical
CPU cores your system has (as opposed to the logical number of cores)."""
repeat_last_n: Optional[int]
repeat_last_n: Optional[int] = None
"""Sets how far back for the model to look back to prevent
repetition. (Default: 64, 0 = disabled, -1 = num_ctx)"""
repeat_penalty: Optional[float]
repeat_penalty: Optional[float] = None
"""Sets how strongly to penalize repetitions. A higher value (e.g., 1.5)
will penalize repetitions more strongly, while a lower value (e.g., 0.9)
will be more lenient. (Default: 1.1)"""
temperature: Optional[float]
temperature: Optional[float] = None
"""The temperature of the model. Increasing the temperature will
make the model answer more creatively. (Default: 0.8)"""
stop: Optional[List[str]]
stop: Optional[List[str]] = None
"""Sets the stop tokens to use."""
tfs_z: Optional[float]
tfs_z: Optional[float] = None
"""Tail free sampling is used to reduce the impact of less probable
tokens from the output. A higher value (e.g., 2.0) will reduce the
impact more, while a value of 1.0 disables this setting. (default: 1)"""
top_k: Optional[int]
top_k: Optional[int] = None
"""Reduces the probability of generating nonsense. A higher value (e.g. 100)
will give more diverse answers, while a lower value (e.g. 10)
will be more conservative. (Default: 40)"""
top_p: Optional[int]
top_p: Optional[int] = None
"""Works together with top-k. A higher value (e.g., 0.95) will lead
to more diverse text, while a lower value (e.g., 0.5) will
generate more focused and conservative text. (Default: 0.9)"""

View File

@ -248,7 +248,7 @@ class OpenLLM(LLM):
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: CallbackManagerForLLMRun | None = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
try:

View File

@ -36,7 +36,7 @@ class CassandraChatMessageHistory(BaseChatMessageHistory):
session: Session,
keyspace: str,
table_name: str = DEFAULT_TABLE_NAME,
ttl_seconds: int | None = DEFAULT_TTL_SECONDS,
ttl_seconds: typing.Optional[int] = DEFAULT_TTL_SECONDS,
) -> None:
try:
from cassio.history import StoredBlobHistory

View File

@ -101,7 +101,7 @@ class BaseGenerationOutputParser(
async def ainvoke(
self,
input: str | BaseMessage,
config: RunnableConfig | None = None,
config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any],
) -> T:
if isinstance(input, BaseMessage):
@ -190,7 +190,7 @@ class BaseOutputParser(
async def ainvoke(
self,
input: str | BaseMessage,
config: RunnableConfig | None = None,
config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any],
) -> T:
if isinstance(input, BaseMessage):

View File

@ -54,7 +54,9 @@ class BasePromptTemplate(RunnableSerializable[Dict, PromptValue], ABC):
**{k: (self.input_types.get(k, str), None) for k in self.input_variables},
)
def invoke(self, input: Dict, config: RunnableConfig | None = None) -> PromptValue:
def invoke(
self, input: Dict, config: Optional[RunnableConfig] = None
) -> PromptValue:
return self._call_with_config(
lambda inner_input: self.format_prompt(
**{key: inner_input[key] for key in self.input_variables}

View File

@ -304,7 +304,7 @@ class StringRunEvaluatorChain(Chain, RunEvaluator):
async def _acall(
self,
inputs: Dict[str, str],
run_manager: AsyncCallbackManagerForChainRun | None = None,
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
"""Call the evaluation chain."""
evaluate_strings_inputs = self._prepare_input(inputs)

View File

@ -923,8 +923,8 @@ class FAISS(VectorStore):
cls,
texts: list[str],
embedding: Embeddings,
metadatas: List[dict] | None = None,
ids: List[str] | None = None,
metadatas: Optional[List[dict]] = None,
ids: Optional[List[str]] = None,
**kwargs: Any,
) -> FAISS:
"""Construct FAISS wrapper from raw documents asynchronously.

View File

@ -145,7 +145,10 @@ class FakeRetrieverV2(BaseRetriever):
throw_error: bool = False
def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun | None
self,
query: str,
*,
run_manager: Optional[CallbackManagerForRetrieverRun] = None,
) -> List[Document]:
assert isinstance(self, FakeRetrieverV2)
assert run_manager is not None
@ -157,7 +160,10 @@ class FakeRetrieverV2(BaseRetriever):
]
async def _aget_relevant_documents(
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun | None
self,
query: str,
*,
run_manager: Optional[AsyncCallbackManagerForRetrieverRun] = None,
) -> List[Document]:
assert isinstance(self, FakeRetrieverV2)
assert run_manager is not None

View File

@ -1,3 +1,5 @@
from typing import Optional
FULL_PROMPT = """# Context
- Plate-based data is rectangular and could be situated anywhere within the dataset.
- The first item in every row is the row index
@ -54,7 +56,11 @@ AI_REPONSE_DICT = {
}
def create_prompt(num_plates: int | None, num_rows: int | None, num_cols: int | None):
def create_prompt(
num_plates: Optional[int] = None,
num_rows: Optional[int] = None,
num_cols: Optional[int] = None,
) -> str:
additional_prompts = []
if num_plates:
num_plates_str = f"are {num_plates} plates" if num_plates > 1 else "is 1 plate"