chore: minor clean up / formatting (#233)

to get familiarize with the project
This commit is contained in:
Xupeng (Tony) Tong 2022-12-02 02:50:36 +08:00 committed by GitHub
parent 473943643e
commit bb4bf9d6d0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
18 changed files with 51 additions and 32 deletions

View File

@ -1,4 +1,6 @@
"""Chain that takes in an input and produces an action and action input.""" """Chain that takes in an input and produces an action and action input."""
from __future__ import annotations
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any, ClassVar, Dict, List, NamedTuple, Optional, Tuple from typing import Any, ClassVar, Dict, List, NamedTuple, Optional, Tuple
@ -91,7 +93,7 @@ class Agent(Chain, BaseModel, ABC):
pass pass
@classmethod @classmethod
def from_llm_and_tools(cls, llm: LLM, tools: List[Tool], **kwargs: Any) -> "Agent": def from_llm_and_tools(cls, llm: LLM, tools: List[Tool], **kwargs: Any) -> Agent:
"""Construct an agent from an LLM and tools.""" """Construct an agent from an LLM and tools."""
cls._validate_tools(tools) cls._validate_tools(tools)
llm_chain = LLMChain(llm=llm, prompt=cls.create_prompt(tools)) llm_chain = LLMChain(llm=llm, prompt=cls.create_prompt(tools))

View File

@ -1,4 +1,6 @@
"""Attempt to implement MRKL systems as described in arxiv.org/pdf/2205.00445.pdf.""" """Attempt to implement MRKL systems as described in arxiv.org/pdf/2205.00445.pdf."""
from __future__ import annotations
from typing import Any, Callable, List, NamedTuple, Optional, Tuple from typing import Any, Callable, List, NamedTuple, Optional, Tuple
from langchain.agents.agent import Agent from langchain.agents.agent import Agent
@ -114,7 +116,7 @@ class MRKLChain(ZeroShotAgent):
""" """
@classmethod @classmethod
def from_chains(cls, llm: LLM, chains: List[ChainConfig], **kwargs: Any) -> "Agent": def from_chains(cls, llm: LLM, chains: List[ChainConfig], **kwargs: Any) -> Agent:
"""User friendly way to initialize the MRKL chain. """User friendly way to initialize the MRKL chain.
This is intended to be an easy way to get up and running with the This is intended to be an easy way to get up and running with the

View File

@ -27,10 +27,7 @@ class SelfAskWithSearchAgent(Agent):
def _extract_tool_and_input(self, text: str) -> Optional[Tuple[str, str]]: def _extract_tool_and_input(self, text: str) -> Optional[Tuple[str, str]]:
followup = "Follow up:" followup = "Follow up:"
if "\n" not in text: last_line = text.split("\n")[-1]
last_line = text
else:
last_line = text.split("\n")[-1]
if followup not in last_line: if followup not in last_line:
finish_string = "So the final answer is: " finish_string = "So the final answer is: "
@ -38,10 +35,7 @@ class SelfAskWithSearchAgent(Agent):
return None return None
return "Final Answer", last_line[len(finish_string) :] return "Final Answer", last_line[len(finish_string) :]
if ":" not in last_line: after_colon = text.split(":")[-1]
after_colon = last_line
else:
after_colon = text.split(":")[-1]
if " " == after_colon[0]: if " " == after_colon[0]:
after_colon = after_colon[1:] after_colon = after_colon[1:]
@ -49,7 +43,7 @@ class SelfAskWithSearchAgent(Agent):
return "Intermediate Answer", after_colon return "Intermediate Answer", after_colon
def _fix_text(self, text: str) -> str: def _fix_text(self, text: str) -> str:
return text + "\nSo the final answer is:" return f"{text}\nSo the final answer is:"
@property @property
def observation_prefix(self) -> str: def observation_prefix(self) -> str:

View File

@ -84,8 +84,8 @@ class ConversationSummaryMemory(Memory, BaseModel):
prompt_input_key = _get_prompt_input_key(inputs, self.memory_variables) prompt_input_key = _get_prompt_input_key(inputs, self.memory_variables)
if len(outputs) != 1: if len(outputs) != 1:
raise ValueError(f"One output key expected, got {outputs.keys()}") raise ValueError(f"One output key expected, got {outputs.keys()}")
human = "Human: " + inputs[prompt_input_key] human = f"Human: {inputs[prompt_input_key]}"
ai = "AI: " + list(outputs.values())[0] ai = f"AI: {list(outputs.values())[0]}"
new_lines = "\n".join([human, ai]) new_lines = "\n".join([human, ai])
chain = LLMChain(llm=self.llm, prompt=self.prompt) chain = LLMChain(llm=self.llm, prompt=self.prompt)
self.buffer = chain.predict(summary=self.buffer, new_lines=new_lines) self.buffer = chain.predict(summary=self.buffer, new_lines=new_lines)

View File

@ -3,6 +3,7 @@
Splits up a document, sends the smaller parts to the LLM with one prompt, Splits up a document, sends the smaller parts to the LLM with one prompt,
then combines the results with another one. then combines the results with another one.
""" """
from __future__ import annotations
from typing import Dict, List from typing import Dict, List
@ -32,7 +33,7 @@ class MapReduceChain(Chain, BaseModel):
@classmethod @classmethod
def from_params( def from_params(
cls, llm: LLM, prompt: BasePromptTemplate, text_splitter: TextSplitter cls, llm: LLM, prompt: BasePromptTemplate, text_splitter: TextSplitter
) -> "MapReduceChain": ) -> MapReduceChain:
"""Construct a map-reduce chain that uses the chain for map and reduce.""" """Construct a map-reduce chain that uses the chain for map and reduce."""
llm_chain = LLMChain(llm=llm, prompt=prompt) llm_chain = LLMChain(llm=llm, prompt=prompt)
return cls(map_llm=llm_chain, reduce_llm=llm_chain, text_splitter=text_splitter) return cls(map_llm=llm_chain, reduce_llm=llm_chain, text_splitter=text_splitter)

View File

@ -1,4 +1,6 @@
"""Implement an LLM driven browser.""" """Implement an LLM driven browser."""
from __future__ import annotations
from typing import Dict, List from typing import Dict, List
from pydantic import BaseModel, Extra from pydantic import BaseModel, Extra
@ -36,7 +38,7 @@ class NatBotChain(Chain, BaseModel):
arbitrary_types_allowed = True arbitrary_types_allowed = True
@classmethod @classmethod
def from_default(cls, objective: str) -> "NatBotChain": def from_default(cls, objective: str) -> NatBotChain:
"""Load with default LLM.""" """Load with default LLM."""
llm = OpenAI(temperature=0.5, best_of=10, n=3, max_tokens=50) llm = OpenAI(temperature=0.5, best_of=10, n=3, max_tokens=50)
return cls(llm=llm, objective=objective) return cls(llm=llm, objective=objective)

View File

@ -2,6 +2,8 @@
As in https://arxiv.org/pdf/2211.10435.pdf. As in https://arxiv.org/pdf/2211.10435.pdf.
""" """
from __future__ import annotations
from typing import Any, Dict, List from typing import Any, Dict, List
from pydantic import BaseModel, Extra from pydantic import BaseModel, Extra
@ -57,7 +59,7 @@ class PALChain(Chain, BaseModel):
return {self.output_key: res.strip()} return {self.output_key: res.strip()}
@classmethod @classmethod
def from_math_prompt(cls, llm: LLM, **kwargs: Any) -> "PALChain": def from_math_prompt(cls, llm: LLM, **kwargs: Any) -> PALChain:
"""Load PAL from math prompt.""" """Load PAL from math prompt."""
return cls( return cls(
llm=llm, llm=llm,
@ -68,7 +70,7 @@ class PALChain(Chain, BaseModel):
) )
@classmethod @classmethod
def from_colored_object_prompt(cls, llm: LLM, **kwargs: Any) -> "PALChain": def from_colored_object_prompt(cls, llm: LLM, **kwargs: Any) -> PALChain:
"""Load PAL from colored object prompt.""" """Load PAL from colored object prompt."""
return cls( return cls(
llm=llm, llm=llm,

View File

@ -1,5 +1,7 @@
"""Question answering with sources over documents.""" """Question answering with sources over documents."""
from __future__ import annotations
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any, Dict, List from typing import Any, Dict, List
@ -40,7 +42,7 @@ class BaseQAWithSourcesChain(Chain, BaseModel, ABC):
question_prompt: BasePromptTemplate = QUESTION_PROMPT, question_prompt: BasePromptTemplate = QUESTION_PROMPT,
combine_prompt: BasePromptTemplate = COMBINE_PROMPT, combine_prompt: BasePromptTemplate = COMBINE_PROMPT,
**kwargs: Any, **kwargs: Any,
) -> "BaseQAWithSourcesChain": ) -> BaseQAWithSourcesChain:
"""Construct the chain from an LLM.""" """Construct the chain from an LLM."""
llm_question_chain = LLMChain(llm=llm, prompt=question_prompt) llm_question_chain = LLMChain(llm=llm, prompt=question_prompt)
llm_combine_chain = LLMChain(llm=llm, prompt=combine_prompt) llm_combine_chain = LLMChain(llm=llm, prompt=combine_prompt)

View File

@ -54,7 +54,7 @@ class SQLDatabaseChain(Chain, BaseModel):
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]: def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
llm_chain = LLMChain(llm=self.llm, prompt=PROMPT) llm_chain = LLMChain(llm=self.llm, prompt=PROMPT)
chained_input = ChainedInput( chained_input = ChainedInput(
inputs[self.input_key] + "\nSQLQuery:", verbose=self.verbose f"{inputs[self.input_key]} \nSQLQuery:", verbose=self.verbose
) )
llm_inputs = { llm_inputs = {
"input": chained_input.input, "input": chained_input.input,

View File

@ -1,4 +1,6 @@
"""Experiment with different models.""" """Experiment with different models."""
from __future__ import annotations
from typing import List, Optional, Sequence, Union from typing import List, Optional, Sequence, Union
from langchain.agents.agent import Agent from langchain.agents.agent import Agent
@ -49,7 +51,7 @@ class ModelLaboratory:
@classmethod @classmethod
def from_llms( def from_llms(
cls, llms: List[LLM], prompt: Optional[PromptTemplate] = None cls, llms: List[LLM], prompt: Optional[PromptTemplate] = None
) -> "ModelLaboratory": ) -> ModelLaboratory:
"""Initialize with LLMs to experiment with and optional prompt. """Initialize with LLMs to experiment with and optional prompt.
Args: Args:

View File

@ -1,4 +1,6 @@
"""Example selector that selects examples based on SemanticSimilarity.""" """Example selector that selects examples based on SemanticSimilarity."""
from __future__ import annotations
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Extra from pydantic import BaseModel, Extra
@ -55,7 +57,7 @@ class SemanticSimilarityExampleSelector(BaseExampleSelector, BaseModel):
vectorstore_cls: VectorStore, vectorstore_cls: VectorStore,
k: int = 4, k: int = 4,
**vectorstore_cls_kwargs: Any, **vectorstore_cls_kwargs: Any,
) -> "SemanticSimilarityExampleSelector": ) -> SemanticSimilarityExampleSelector:
"""Create k-shot example selector using example list and embeddings. """Create k-shot example selector using example list and embeddings.
Reshuffles examples dynamically based on query similarity. Reshuffles examples dynamically based on query similarity.

View File

@ -1,4 +1,6 @@
"""Prompt schema definition.""" """Prompt schema definition."""
from __future__ import annotations
from typing import Any, Dict, List from typing import Any, Dict, List
from pydantic import BaseModel, Extra, root_validator from pydantic import BaseModel, Extra, root_validator
@ -67,7 +69,7 @@ class PromptTemplate(BasePromptTemplate, BaseModel):
input_variables: List[str], input_variables: List[str],
example_separator: str = "\n\n", example_separator: str = "\n\n",
prefix: str = "", prefix: str = "",
) -> "PromptTemplate": ) -> PromptTemplate:
"""Take examples in list format with prefix and suffix to create a prompt. """Take examples in list format with prefix and suffix to create a prompt.
Intended be used as a way to dynamically create a prompt from examples. Intended be used as a way to dynamically create a prompt from examples.
@ -92,7 +94,7 @@ class PromptTemplate(BasePromptTemplate, BaseModel):
@classmethod @classmethod
def from_file( def from_file(
cls, template_file: str, input_variables: List[str] cls, template_file: str, input_variables: List[str]
) -> "PromptTemplate": ) -> PromptTemplate:
"""Load a prompt from a file. """Load a prompt from a file.
Args: Args:

View File

@ -1,4 +1,6 @@
"""SQLAlchemy wrapper around a database.""" """SQLAlchemy wrapper around a database."""
from __future__ import annotations
from typing import Any, Iterable, List, Optional from typing import Any, Iterable, List, Optional
from sqlalchemy import create_engine, inspect from sqlalchemy import create_engine, inspect
@ -37,7 +39,7 @@ class SQLDatabase:
) )
@classmethod @classmethod
def from_uri(cls, database_uri: str, **kwargs: Any) -> "SQLDatabase": def from_uri(cls, database_uri: str, **kwargs: Any) -> SQLDatabase:
"""Construct a SQLAlchemy engine from URI.""" """Construct a SQLAlchemy engine from URI."""
return cls(create_engine(database_uri), **kwargs) return cls(create_engine(database_uri), **kwargs)

View File

@ -1,4 +1,6 @@
"""Functionality for splitting text.""" """Functionality for splitting text."""
from __future__ import annotations
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any, Callable, Iterable, List from typing import Any, Callable, Iterable, List
@ -46,9 +48,7 @@ class TextSplitter(ABC):
return docs return docs
@classmethod @classmethod
def from_huggingface_tokenizer( def from_huggingface_tokenizer(cls, tokenizer: Any, **kwargs: Any) -> TextSplitter:
cls, tokenizer: Any, **kwargs: Any
) -> "TextSplitter":
"""Text splitter than uses HuggingFace tokenizer to count length.""" """Text splitter than uses HuggingFace tokenizer to count length."""
try: try:
from transformers import PreTrainedTokenizerBase from transformers import PreTrainedTokenizerBase

View File

@ -1,4 +1,6 @@
"""Interface for vector stores.""" """Interface for vector stores."""
from __future__ import annotations
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any, Iterable, List, Optional from typing import Any, Iterable, List, Optional
@ -26,6 +28,6 @@ class VectorStore(ABC):
texts: List[str], texts: List[str],
embedding: Embeddings, embedding: Embeddings,
metadatas: Optional[List[dict]] = None, metadatas: Optional[List[dict]] = None,
**kwargs: Any **kwargs: Any,
) -> "VectorStore": ) -> VectorStore:
"""Return VectorStore initialized from texts and embeddings.""" """Return VectorStore initialized from texts and embeddings."""

View File

@ -1,4 +1,6 @@
"""Wrapper around Elasticsearch vector database.""" """Wrapper around Elasticsearch vector database."""
from __future__ import annotations
import uuid import uuid
from typing import Any, Callable, Dict, Iterable, List, Optional from typing import Any, Callable, Dict, Iterable, List, Optional
@ -117,7 +119,7 @@ class ElasticVectorSearch(VectorStore):
embedding: Embeddings, embedding: Embeddings,
metadatas: Optional[List[dict]] = None, metadatas: Optional[List[dict]] = None,
**kwargs: Any, **kwargs: Any,
) -> "ElasticVectorSearch": ) -> ElasticVectorSearch:
"""Construct ElasticVectorSearch wrapper from raw documents. """Construct ElasticVectorSearch wrapper from raw documents.
This is a user-friendly interface that: This is a user-friendly interface that:

View File

@ -1,4 +1,6 @@
"""Wrapper around FAISS vector database.""" """Wrapper around FAISS vector database."""
from __future__ import annotations
import uuid import uuid
from typing import Any, Callable, Dict, Iterable, List, Optional from typing import Any, Callable, Dict, Iterable, List, Optional
@ -96,7 +98,7 @@ class FAISS(VectorStore):
embedding: Embeddings, embedding: Embeddings,
metadatas: Optional[List[dict]] = None, metadatas: Optional[List[dict]] = None,
**kwargs: Any, **kwargs: Any,
) -> "FAISS": ) -> FAISS:
"""Construct FAISS wrapper from raw documents. """Construct FAISS wrapper from raw documents.
This is a user friendly interface that: This is a user friendly interface that:

View File

@ -28,7 +28,7 @@ class FakeChain(Chain, BaseModel):
outputs = {} outputs = {}
for var in self.output_variables: for var in self.output_variables:
variables = [inputs[k] for k in self.input_variables] variables = [inputs[k] for k in self.input_variables]
outputs[var] = " ".join(variables) + "foo" outputs[var] = f"{' '.join(variables)}foo"
return outputs return outputs