mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-23 23:29:21 +00:00
Refactor some loops into list comprehensions (#1185)
This commit is contained in:
parent
926c121b98
commit
159c560c95
@ -75,9 +75,7 @@ class SQLAlchemyCache(BaseCache):
|
|||||||
.order_by(self.cache_schema.idx)
|
.order_by(self.cache_schema.idx)
|
||||||
)
|
)
|
||||||
with Session(self.engine) as session:
|
with Session(self.engine) as session:
|
||||||
generations = []
|
generations = [Generation(text=row[0]) for row in session.execute(stmt)]
|
||||||
for row in session.execute(stmt):
|
|
||||||
generations.append(Generation(text=row[0]))
|
|
||||||
if len(generations) > 0:
|
if len(generations) > 0:
|
||||||
return generations
|
return generations
|
||||||
return None
|
return None
|
||||||
|
@ -124,12 +124,11 @@ class LLMChain(Chain, BaseModel):
|
|||||||
|
|
||||||
def create_outputs(self, response: LLMResult) -> List[Dict[str, str]]:
|
def create_outputs(self, response: LLMResult) -> List[Dict[str, str]]:
|
||||||
"""Create outputs from response."""
|
"""Create outputs from response."""
|
||||||
outputs = []
|
return [
|
||||||
for generation in response.generations:
|
|
||||||
# Get the text of the top generated string.
|
# Get the text of the top generated string.
|
||||||
response_str = generation[0].text
|
{self.output_key: generation[0].text}
|
||||||
outputs.append({self.output_key: response_str})
|
for generation in response.generations
|
||||||
return outputs
|
]
|
||||||
|
|
||||||
def _call(self, inputs: Dict[str, Any]) -> Dict[str, str]:
|
def _call(self, inputs: Dict[str, Any]) -> Dict[str, str]:
|
||||||
return self.apply([inputs])[0]
|
return self.apply([inputs])[0]
|
||||||
@ -188,11 +187,9 @@ class LLMChain(Chain, BaseModel):
|
|||||||
self, result: List[Dict[str, str]]
|
self, result: List[Dict[str, str]]
|
||||||
) -> Sequence[Union[str, List[str], Dict[str, str]]]:
|
) -> Sequence[Union[str, List[str], Dict[str, str]]]:
|
||||||
if self.prompt.output_parser is not None:
|
if self.prompt.output_parser is not None:
|
||||||
new_result = []
|
return [
|
||||||
for res in result:
|
self.prompt.output_parser.parse(res[self.output_key]) for res in result
|
||||||
text = res[self.output_key]
|
]
|
||||||
new_result.append(self.prompt.output_parser.parse(text))
|
|
||||||
return new_result
|
|
||||||
else:
|
else:
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
@ -116,22 +116,19 @@ class GoogleDriveLoader(BaseLoader, BaseModel):
|
|||||||
)
|
)
|
||||||
items = results.get("files", [])
|
items = results.get("files", [])
|
||||||
|
|
||||||
docs = []
|
return [
|
||||||
for item in items:
|
self._load_document_from_id(item["id"])
|
||||||
|
for item in items
|
||||||
# Only support Google Docs for now
|
# Only support Google Docs for now
|
||||||
if item["mimeType"] == "application/vnd.google-apps.document":
|
if item["mimeType"] == "application/vnd.google-apps.document"
|
||||||
docs.append(self._load_document_from_id(item["id"]))
|
]
|
||||||
return docs
|
|
||||||
|
|
||||||
def _load_documents_from_ids(self) -> List[Document]:
|
def _load_documents_from_ids(self) -> List[Document]:
|
||||||
"""Load documents from a list of IDs."""
|
"""Load documents from a list of IDs."""
|
||||||
if not self.document_ids:
|
if not self.document_ids:
|
||||||
raise ValueError("document_ids must be set")
|
raise ValueError("document_ids must be set")
|
||||||
|
|
||||||
docs = []
|
return [self._load_document_from_id(doc_id) for doc_id in self.document_ids]
|
||||||
for doc_id in self.document_ids:
|
|
||||||
docs.append(self._load_document_from_id(doc_id))
|
|
||||||
return docs
|
|
||||||
|
|
||||||
def load(self) -> List[Document]:
|
def load(self) -> List[Document]:
|
||||||
"""Load documents."""
|
"""Load documents."""
|
||||||
|
@ -30,12 +30,13 @@ class HNLoader(WebBaseLoader):
|
|||||||
"""Load comments from a HN post."""
|
"""Load comments from a HN post."""
|
||||||
comments = soup_info.select("tr[class='athing comtr']")
|
comments = soup_info.select("tr[class='athing comtr']")
|
||||||
title = soup_info.select_one("tr[id='pagespace']").get("title")
|
title = soup_info.select_one("tr[id='pagespace']").get("title")
|
||||||
documents = []
|
return [
|
||||||
for comment in comments:
|
Document(
|
||||||
text = comment.text.strip()
|
page_content=comment.text.strip(),
|
||||||
metadata = {"source": self.web_path, "title": title}
|
metadata={"source": self.web_path, "title": title},
|
||||||
documents.append(Document(page_content=text, metadata=metadata))
|
)
|
||||||
return documents
|
for comment in comments
|
||||||
|
]
|
||||||
|
|
||||||
def load_results(self, soup: Any) -> List[Document]:
|
def load_results(self, soup: Any) -> List[Document]:
|
||||||
"""Load items from an HN page."""
|
"""Load items from an HN page."""
|
||||||
|
@ -25,12 +25,12 @@ class PagedPDFSplitter(BaseLoader):
|
|||||||
"""Load given path as pages."""
|
"""Load given path as pages."""
|
||||||
import pypdf
|
import pypdf
|
||||||
|
|
||||||
pdf_file_obj = open(self._file_path, "rb")
|
with open(self._file_path, "rb") as pdf_file_obj:
|
||||||
pdf_reader = pypdf.PdfReader(pdf_file_obj)
|
pdf_reader = pypdf.PdfReader(pdf_file_obj)
|
||||||
docs = []
|
return [
|
||||||
for i, page in enumerate(pdf_reader.pages):
|
Document(
|
||||||
text = page.extract_text()
|
page_content=page.extract_text(),
|
||||||
metadata = {"source": self._file_path, "page": i}
|
metadata={"source": self._file_path, "page": i},
|
||||||
docs.append(Document(page_content=text, metadata=metadata))
|
)
|
||||||
pdf_file_obj.close()
|
for i, page in enumerate(pdf_reader.pages)
|
||||||
return docs
|
]
|
||||||
|
@ -121,9 +121,7 @@ class HuggingFaceInstructEmbeddings(BaseModel, Embeddings):
|
|||||||
Returns:
|
Returns:
|
||||||
List of embeddings, one for each text.
|
List of embeddings, one for each text.
|
||||||
"""
|
"""
|
||||||
instruction_pairs = []
|
instruction_pairs = [[self.embed_instruction, text] for text in texts]
|
||||||
for text in texts:
|
|
||||||
instruction_pairs.append([self.embed_instruction, text])
|
|
||||||
embeddings = self.client.encode(instruction_pairs)
|
embeddings = self.client.encode(instruction_pairs)
|
||||||
return embeddings.tolist()
|
return embeddings.tolist()
|
||||||
|
|
||||||
|
@ -48,13 +48,13 @@ class QAEvalChain(LLMChain):
|
|||||||
prediction_key: str = "result",
|
prediction_key: str = "result",
|
||||||
) -> List[dict]:
|
) -> List[dict]:
|
||||||
"""Evaluate question answering examples and predictions."""
|
"""Evaluate question answering examples and predictions."""
|
||||||
inputs = []
|
inputs = [
|
||||||
for i, example in enumerate(examples):
|
{
|
||||||
_input = {
|
|
||||||
"query": example[question_key],
|
"query": example[question_key],
|
||||||
"answer": example[answer_key],
|
"answer": example[answer_key],
|
||||||
"result": predictions[i][prediction_key],
|
"result": predictions[i][prediction_key],
|
||||||
}
|
}
|
||||||
inputs.append(_input)
|
for i, example in enumerate(examples)
|
||||||
|
]
|
||||||
|
|
||||||
return self.apply(inputs)
|
return self.apply(inputs)
|
||||||
|
@ -329,7 +329,7 @@ class BaseOpenAI(BaseLLM, BaseModel):
|
|||||||
) -> LLMResult:
|
) -> LLMResult:
|
||||||
"""Create the LLMResult from the choices and prompts."""
|
"""Create the LLMResult from the choices and prompts."""
|
||||||
generations = []
|
generations = []
|
||||||
for i, prompt in enumerate(prompts):
|
for i, _ in enumerate(prompts):
|
||||||
sub_choices = choices[i * self.n : (i + 1) * self.n]
|
sub_choices = choices[i * self.n : (i + 1) * self.n]
|
||||||
generations.append(
|
generations.append(
|
||||||
[
|
[
|
||||||
|
@ -304,7 +304,6 @@ class SearxSearchWrapper(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
metadata_results = []
|
|
||||||
_params = {
|
_params = {
|
||||||
"q": query,
|
"q": query,
|
||||||
}
|
}
|
||||||
@ -314,14 +313,14 @@ class SearxSearchWrapper(BaseModel):
|
|||||||
results = self._searx_api_query(params).results[:num_results]
|
results = self._searx_api_query(params).results[:num_results]
|
||||||
if len(results) == 0:
|
if len(results) == 0:
|
||||||
return [{"Result": "No good Search Result was found"}]
|
return [{"Result": "No good Search Result was found"}]
|
||||||
for result in results:
|
|
||||||
metadata_result = {
|
return [
|
||||||
|
{
|
||||||
"snippet": result.get("content", ""),
|
"snippet": result.get("content", ""),
|
||||||
"title": result["title"],
|
"title": result["title"],
|
||||||
"link": result["url"],
|
"link": result["url"],
|
||||||
"engines": result["engines"],
|
"engines": result["engines"],
|
||||||
"category": result["category"],
|
"category": result["category"],
|
||||||
}
|
}
|
||||||
metadata_results.append(metadata_result)
|
for result in results
|
||||||
|
]
|
||||||
return metadata_results
|
|
||||||
|
Loading…
Reference in New Issue
Block a user