fix:update awel embedding examples and delete unuseful code. (#1073)

This commit is contained in:
Aries-ckt
2024-01-15 23:22:52 +08:00
committed by GitHub
parent cb9c34abb9
commit 3a54d1ef9a
11 changed files with 45 additions and 1437 deletions

View File

@@ -1,40 +1,32 @@
import asyncio
import os
from typing import Dict, List
from pydantic import BaseModel, Field
from dbgpt.configs.model_config import MODEL_PATH, PILOT_PATH
from dbgpt.core.awel import DAG, InputOperator, MapOperator, SimpleCallDataInputSource
from dbgpt.rag.chunk import Chunk
from dbgpt.core.awel import DAG, HttpTrigger, MapOperator
from dbgpt.rag.embedding.embedding_factory import DefaultEmbeddingFactory
from dbgpt.rag.knowledge.base import KnowledgeType
from dbgpt.rag.operator.knowledge import KnowledgeOperator
from dbgpt.serve.rag.operators.embedding import EmbeddingAssemblerOperator
from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig
from dbgpt.storage.vector_store.connector import VectorStoreConnector
"""AWEL: Simple rag embedding operator example
pre-requirements:
set your file path in your example code.
Examples:
pre-requirements:
python examples/awel/simple_rag_embedding_example.py
..code-block:: shell
python examples/awel/simple_rag_embedding_example.py
curl --location --request POST 'http://127.0.0.1:5555/api/v1/awel/trigger/examples/rag/embedding' \
--header 'Content-Type: application/json' \
--data-raw '{
"url": "https://docs.dbgpt.site/docs/awel"
}'
"""
def _context_join_fn(context_dict: Dict, chunks: List[Chunk]) -> Dict:
"""context Join function for JoinOperator.
Args:
context_dict (Dict): context dict
chunks (List[Chunk]): chunks
Returns:
Dict: context dict
"""
context_dict["context"] = "\n".join([chunk.content for chunk in chunks])
return context_dict
def _create_vector_connector():
def _create_vector_connector() -> VectorStoreConnector:
"""Create vector connector."""
return VectorStoreConnector.from_default(
"Chroma",
@@ -48,6 +40,22 @@ def _create_vector_connector():
)
class TriggerReqBody(BaseModel):
url: str = Field(..., description="url")
class RequestHandleOperator(MapOperator[TriggerReqBody, Dict]):
def __init__(self, **kwargs):
super().__init__(**kwargs)
async def map(self, input_value: TriggerReqBody) -> Dict:
params = {
"url": input_value.url,
}
print(f"Receive input value: {input_value}")
return params
class ResultOperator(MapOperator):
"""The Result Operator."""
@@ -61,26 +69,31 @@ class ResultOperator(MapOperator):
with DAG("simple_sdk_rag_embedding_example") as dag:
knowledge_operator = KnowledgeOperator()
trigger = HttpTrigger(
"/examples/rag/embedding", methods="POST", request_body=TriggerReqBody
)
request_handle_task = RequestHandleOperator()
knowledge_operator = KnowledgeOperator(knowledge_type=KnowledgeType.URL)
vector_connector = _create_vector_connector()
input_task = InputOperator(input_source=SimpleCallDataInputSource())
file_path_parser = MapOperator(map_function=lambda x: x["file_path"])
url_parser_operator = MapOperator(map_function=lambda x: x["url"])
embedding_operator = EmbeddingAssemblerOperator(
vector_store_connector=vector_connector,
)
output_task = ResultOperator()
(
input_task
>> file_path_parser
trigger
>> request_handle_task
>> url_parser_operator
>> knowledge_operator
>> embedding_operator
>> output_task
)
if __name__ == "__main__":
input_data = {
"data": {
"file_path": "docs/docs/awel.md",
}
}
output = asyncio.run(output_task.call(call_data=input_data))
if dag.leaf_nodes[0].dev_mode:
# Development mode, you can run the dag locally for debugging.
from dbgpt.core.awel import setup_dev_environment
setup_dev_environment([dag], port=5555)
else:
pass