Compare commits

...

5 Commits

Author SHA1 Message Date
Eugene Yurtsev
914e7b55bf qxqx 2024-08-14 13:08:57 -04:00
Fernando de Oliveira
6899f5896e Merge branch 'master' into feature/graph-mermaid-playwright-runnable 2024-07-24 19:59:05 -03:00
Fernando de Oliveira
84b51e41a2 chore: unused method removed 2024-07-23 17:02:52 -03:00
Fernando de Oliveira
ea81891946 fix: playwright imports inside run_mermaid_using_playwright method 2024-07-23 16:57:02 -03:00
Fernando de Oliveira
f76364fd0f feature: graph image rendering with playwright 2024-07-23 16:26:05 -03:00
3 changed files with 2356 additions and 10 deletions

File diff suppressed because one or more lines are too long

View File

@@ -173,8 +173,21 @@ class NodeStyles:
class MermaidDrawMethod(Enum):
"""Enum for different draw methods supported by Mermaid"""
PYPPETEER = "pyppeteer" # Uses Pyppeteer to render the graph
API = "api" # Uses Mermaid.INK API to render the graph
API = "api"
"""Use Mermaid.INK API to render the graph. This is the default method."""
PLAYWRIGHT = "playwright"
"""Use playwright to render the graph.
Playwright is a library that allows using different browser engines to render
the graph.
See: https://playwright.dev/python/docs/intro
"""
PYPPETEER = "pyppeteer"
"""Use Pyppeteer to render the graph.
This is DEPRECATED as the pyppetter library is no longer maintained.
"""
def node_data_str(id: str, data: Union[Type[BaseModel], RunnableType]) -> str:
@@ -631,3 +644,46 @@ class Graph:
background_color=background_color,
padding=padding,
)
async def adraw_mermaid_png(
self,
*,
curve_style: CurveStyle = CurveStyle.LINEAR,
node_colors: NodeStyles = NodeStyles(),
wrap_label_n_words: int = 9,
output_file_path: Optional[str] = None,
draw_method: MermaidDrawMethod = MermaidDrawMethod.API,
background_color: str = "white",
padding: int = 10,
) -> bytes:
"""Draw the graph as a PNG image using Mermaid (async variant).
Args:
curve_style: The style of the edges. Defaults to CurveStyle.LINEAR.
node_colors: The colors of the nodes. Defaults to NodeStyles().
wrap_label_n_words: The number of words to wrap the node labels at.
Defaults to 9.
output_file_path: The path to save the image to. If None, the image
is not saved. Defaults to None.
draw_method: The method to use to draw the graph.
Defaults to MermaidDrawMethod.API.
background_color: The color of the background. Defaults to "white".
padding: The padding around the graph. Defaults to 10.
Returns:
The PNG image as bytes.
"""
from langchain_core.runnables.graph_mermaid import adraw_mermaid_png
mermaid_syntax = self.draw_mermaid(
curve_style=curve_style,
node_colors=node_colors,
wrap_label_n_words=wrap_label_n_words,
)
return await adraw_mermaid_png(
mermaid_syntax=mermaid_syntax,
output_file_path=output_file_path,
draw_method=draw_method,
background_color=background_color,
padding=padding,
)

View File

@@ -1,7 +1,9 @@
import base64
import re
from dataclasses import asdict
from typing import Dict, List, Optional
from functools import lru_cache
from pathlib import Path
from typing import Dict, List, Literal, Optional, Union
from langchain_core.runnables.graph import (
CurveStyle,
@@ -12,6 +14,15 @@ from langchain_core.runnables.graph import (
)
@lru_cache()
def _get_vendored_mermaid_js() -> str:
"""Get vendored Mermaid JS file contents."""
HERE = Path(__file__).parent
path = HERE / "_vendored/mermaid.min.js"
with open(path, "r") as file:
return file.read()
def draw_mermaid(
nodes: Dict[str, Node],
edges: List[Edge],
@@ -142,7 +153,9 @@ def _generate_mermaid_graph_styles(node_colors: NodeStyles) -> str:
def draw_mermaid_png(
mermaid_syntax: str,
output_file_path: Optional[str] = None,
draw_method: MermaidDrawMethod = MermaidDrawMethod.API,
draw_method: Union[
MermaidDrawMethod, Literal["api", "playwright", "pyppeteer"]
] = MermaidDrawMethod.API,
background_color: Optional[str] = "white",
padding: int = 10,
) -> bytes:
@@ -152,7 +165,7 @@ def draw_mermaid_png(
mermaid_syntax (str): Mermaid graph syntax.
output_file_path (str, optional): Path to save the PNG image.
Defaults to None.
draw_method (MermaidDrawMethod, optional): Method to draw the graph.
draw_method: Method to draw the graph.
Defaults to MermaidDrawMethod.API.
background_color (str, optional): Background color of the image.
Defaults to "white".
@@ -164,7 +177,20 @@ def draw_mermaid_png(
Raises:
ValueError: If an invalid draw method is provided.
"""
if draw_method == MermaidDrawMethod.PYPPETEER:
if draw_method == "api":
draw_method_ = MermaidDrawMethod.API
elif draw_method == "playwright":
draw_method_ = MermaidDrawMethod.PLAYWRIGHT
elif draw_method == "pyppeteer":
draw_method_ = MermaidDrawMethod.PYPPETEER
else:
draw_method_ = draw_method
if draw_method_ == MermaidDrawMethod.PLAYWRIGHT:
img_bytes = _render_mermaid_using_playwright_sync(
mermaid_syntax, output_file_path, background_color, padding
)
elif draw_method_ == MermaidDrawMethod.PYPPETEER:
import asyncio
img_bytes = asyncio.run(
@@ -172,20 +198,255 @@ def draw_mermaid_png(
mermaid_syntax, output_file_path, background_color, padding
)
)
elif draw_method == MermaidDrawMethod.API:
elif draw_method_ == MermaidDrawMethod.API:
img_bytes = _render_mermaid_using_api(
mermaid_syntax, output_file_path, background_color
)
else:
supported_methods = ", ".join([m.value for m in MermaidDrawMethod])
raise ValueError(
f"Invalid draw method: {draw_method}. "
f"Invalid draw method: {draw_method_}. "
f"Supported draw methods are: {supported_methods}"
)
return img_bytes
async def adraw_mermaid_png(
mermaid_syntax: str,
output_file_path: Optional[str] = None,
draw_method: Union[
MermaidDrawMethod, Literal["api", "playwright", "pyppeteer"]
] = MermaidDrawMethod.API,
background_color: Optional[str] = "white",
padding: int = 10,
) -> bytes:
"""Draws a Mermaid graph as PNG using provided syntax.
Args:
mermaid_syntax (str): Mermaid graph syntax.
output_file_path (str, optional): Path to save the PNG image.
Defaults to None.
draw_method: Method to draw the graph.
Defaults to MermaidDrawMethod.API.
background_color (str, optional): Background color of the image.
Defaults to "white".
padding (int, optional): Padding around the image. Defaults to 10.
Returns:
bytes: PNG image bytes.
Raises:
ValueError: If an invalid draw method is provided.
"""
if draw_method == "api":
draw_method_ = MermaidDrawMethod.API
elif draw_method == "playwright":
draw_method_ = MermaidDrawMethod.PLAYWRIGHT
elif draw_method == "pyppeteer":
draw_method_ = MermaidDrawMethod.PYPPETEER
else:
draw_method_ = draw_method
if draw_method_ == MermaidDrawMethod.PLAYWRIGHT:
img_bytes = await _render_mermaid_using_playwright(
mermaid_syntax, output_file_path, background_color, padding
)
elif draw_method_ == MermaidDrawMethod.PYPPETEER:
img_bytes = await _render_mermaid_using_pyppeteer(
mermaid_syntax, output_file_path, background_color, padding
)
elif draw_method_ == MermaidDrawMethod.API:
img_bytes = _render_mermaid_using_api(
mermaid_syntax, output_file_path, background_color
)
else:
supported_methods = ", ".join([m.value for m in MermaidDrawMethod])
raise ValueError(
f"Invalid draw method: {draw_method_}. "
f"Supported draw methods are: {supported_methods}"
)
return img_bytes
def _render_mermaid_using_playwright_sync(
mermaid_syntax: str,
output_file_path: Optional[str] = None,
background_color: Optional[str] = "white",
padding: int = 10,
device_scale_factor: int = 3,
) -> bytes:
try:
from playwright.sync_api import ViewportSize, sync_playwright
except ImportError as e:
raise ImportError(
"Install Playwright to use the Playwright method: `pip install playwright`."
) from e
with sync_playwright() as p:
img_bytes: bytes = b""
for browser_type in [p.chromium, p.firefox, p.webkit]:
try:
browser = browser_type.launch()
except Exception:
continue
page = browser.new_page()
# Setup Mermaid JS
page.goto("about:blank")
page.add_script_tag(content=_get_vendored_mermaid_js())
page.evaluate(
"""() => {
mermaid.initialize({startOnLoad:true});
}"""
)
# Render SVG
svg_code = page.evaluate(
"""(mermaidGraph) => {
return mermaid.mermaidAPI.render('mermaid', mermaidGraph);
}""",
mermaid_syntax,
)
# Set the page background to white
page.evaluate(
"""([svg, background_color]) => {
document.body.innerHTML = svg;
document.body.style.background = background_color;
}""",
[svg_code["svg"], background_color],
)
# Take a screenshot
dimensions = page.evaluate(
"""() => {
const svgElement = document.querySelector('svg');
const rect = svgElement.getBoundingClientRect();
return { width: rect.width, height: rect.height };
}"""
)
viewport_size = ViewportSize(
width=int(dimensions["width"] + padding),
height=int(dimensions["height"] + padding),
)
browser.new_context(
viewport=viewport_size,
device_scale_factor=device_scale_factor,
)
img_bytes = page.screenshot(full_page=False)
browser.close()
break
if len(img_bytes) == 0:
raise Exception(
"Install a Playwright supported browser with `playwright install`."
)
if output_file_path is not None:
with open(output_file_path, "wb") as file:
file.write(img_bytes)
return img_bytes
async def _render_mermaid_using_playwright(
mermaid_syntax: str,
output_file_path: Optional[str] = None,
background_color: Optional[str] = "white",
padding: int = 10,
device_scale_factor: int = 3,
) -> bytes:
"""Renders Mermaid graph using Playwright."""
try:
from playwright.async_api import ViewportSize, async_playwright
except ImportError as e:
raise ImportError(
"Install Playwright to use the Playwright method: `pip install playwright`."
) from e
async with async_playwright() as p:
img_bytes: bytes = b""
for browser_type in [p.chromium, p.firefox, p.webkit]:
try:
browser = await browser_type.launch()
except Exception:
continue
page = await browser.new_page()
# Setup Mermaid JS
await page.goto("about:blank")
await page.add_script_tag(content=_get_vendored_mermaid_js())
await page.evaluate(
"""() => {
mermaid.initialize({startOnLoad:true});
}"""
)
# Render SVG
svg_code = await page.evaluate(
"""(mermaidGraph) => {
return mermaid.mermaidAPI.render('mermaid', mermaidGraph);
}""",
mermaid_syntax,
)
# Set the page background to white
await page.evaluate(
"""([svg, background_color]) => {
document.body.innerHTML = svg;
document.body.style.background = background_color;
}""",
[svg_code["svg"], background_color],
)
# Take a screenshot
dimensions = await page.evaluate(
"""() => {
const svgElement = document.querySelector('svg');
const rect = svgElement.getBoundingClientRect();
return { width: rect.width, height: rect.height };
}"""
)
viewport_size = ViewportSize(
width=int(dimensions["width"] + padding),
height=int(dimensions["height"] + padding),
)
await browser.new_context(
viewport=viewport_size,
device_scale_factor=device_scale_factor,
)
img_bytes = await page.screenshot(full_page=False)
await browser.close()
break
if len(img_bytes) == 0:
raise Exception(
"Install a Playwright supported browser with `playwright install`."
)
if output_file_path is not None:
with open(output_file_path, "wb") as file:
file.write(img_bytes)
return img_bytes
async def _render_mermaid_using_pyppeteer(
mermaid_syntax: str,
output_file_path: Optional[str] = None,
@@ -206,8 +467,8 @@ async def _render_mermaid_using_pyppeteer(
# Setup Mermaid JS
await page.goto("about:blank")
await page.addScriptTag(
{"url": "https://cdn.jsdelivr.net/npm/mermaid/dist/mermaid.min.js"}
await page.add_script_tag(
content=_get_vendored_mermaid_js(),
)
await page.evaluate(
"""() => {