from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Literal

import yaml
# A LiteGraph node type string is "<class>.<type>". These Literals pin the
# accepted vocabulary for each half: sharepoint/confluence are downstream
# types, azure_ai_search/azure_vector_store are upstream types (see
# _node_type_from_litegraph_type).
NodeClass = Literal["downstream", "upstream"]

NodeType = Literal[
    "sharepoint",
    "confluence",
    "azure_ai_search",
    "azure_vector_store",
]
@dataclass(frozen=True)
class NormalizedNode:
    """A single validated node extracted from a LiteGraph export."""

    id: int  # LiteGraph node id; unique within one graph (enforced in normalize_graph)
    node_class: NodeClass  # "downstream" or "upstream"
    node_type: NodeType  # concrete connector type within the class
    name: str  # required non-empty "name" property, stripped of whitespace
    fields: dict[str, str]  # remaining string properties ("name" excluded, None values skipped)
@dataclass(frozen=True)
class NormalizedRelation:
    """A directed edge from one normalized node to another (from_node -> to_node)."""

    from_node: NormalizedNode
    to_node: NormalizedNode
class CompileError(ValueError):
    """Raised when a LiteGraph export cannot be compiled into a valid config."""

    pass
def _require_str(value: Any, *, field: str) -> str:
|
|
if not isinstance(value, str):
|
|
raise CompileError(f"{field} must be a string")
|
|
s = value.strip()
|
|
if not s:
|
|
raise CompileError(f"{field} must be non-empty")
|
|
return s
|
|
|
|
|
|
def _node_type_from_litegraph_type(lg_type: Any) -> tuple[NodeClass, NodeType]:
|
|
if not isinstance(lg_type, str):
|
|
raise CompileError("node.type must be a string")
|
|
|
|
# Frontend will set node.type like: "downstream.sharepoint"
|
|
parts = lg_type.split(".")
|
|
if len(parts) != 2:
|
|
raise CompileError(
|
|
"node.type must be in the form '<class>.<type>' (e.g. downstream.sharepoint)"
|
|
)
|
|
|
|
node_class, node_type = parts[0], parts[1]
|
|
if node_class not in ("downstream", "upstream"):
|
|
raise CompileError("node.type class must be 'downstream' or 'upstream'")
|
|
|
|
allowed_types: set[str]
|
|
if node_class == "downstream":
|
|
allowed_types = {"sharepoint", "confluence"}
|
|
else:
|
|
allowed_types = {"azure_ai_search", "azure_vector_store"}
|
|
|
|
if node_type not in allowed_types:
|
|
raise CompileError(f"invalid node.type '{lg_type}'")
|
|
|
|
return node_class, node_type # type: ignore[return-value]
|
|
|
|
|
|
def _parse_node(n: Any) -> NormalizedNode:
    """Validate one raw LiteGraph node dict and convert it to a NormalizedNode.

    Raises CompileError on any structural problem (non-object node, bad id,
    bad type, missing/empty name, non-string extra properties).
    """
    if not isinstance(n, dict):
        raise CompileError("each node must be an object")

    node_id = n.get("id")
    # bool is a subclass of int, so a JSON `true` would otherwise slip through
    # and silently collide with node id 1 — reject it explicitly.
    if not isinstance(node_id, int) or isinstance(node_id, bool):
        raise CompileError("node.id must be an integer")

    node_class, node_type = _node_type_from_litegraph_type(n.get("type"))

    props = n.get("properties")
    if not isinstance(props, dict):
        # Missing/malformed properties degrade to empty; the name check below
        # will then fail with a precise error message.
        props = {}

    name = _require_str(props.get("name"), field=f"node[{node_id}].properties.name")

    # Every remaining property becomes a string field; None means "unset".
    fields: dict[str, str] = {}
    for k, v in props.items():
        if k == "name" or v is None:
            continue
        if not isinstance(v, str):
            raise CompileError(f"node[{node_id}].properties.{k} must be a string")
        fields[k] = v

    return NormalizedNode(
        id=node_id,
        node_class=node_class,
        node_type=node_type,
        name=name,
        fields=fields,
    )


def _link_endpoints(link: Any) -> tuple[int, int]:
    """Extract (origin_id, target_id) from a raw LiteGraph link.

    LiteGraph can export links as arrays or objects depending on version:
      - array: [id, origin_id, origin_slot, target_id, target_slot, type]
      - object: { origin_id, target_id, ... }
    Raises CompileError for any other shape or non-integer endpoint ids.
    """
    origin_id: Any = None
    target_id: Any = None

    if isinstance(link, list) and len(link) >= 5:
        origin_id, target_id = link[1], link[3]
    elif isinstance(link, dict):
        origin_id = link.get("origin_id")
        target_id = link.get("target_id")
    else:
        raise CompileError("each link must be an array or object")

    # Reject bools for the same reason as node ids (bool subclasses int).
    if (
        not isinstance(origin_id, int)
        or not isinstance(target_id, int)
        or isinstance(origin_id, bool)
        or isinstance(target_id, bool)
    ):
        raise CompileError("link origin_id/target_id must be integers")
    return origin_id, target_id


def normalize_graph(graph: dict[str, Any]) -> tuple[list[NormalizedNode], list[NormalizedRelation]]:
    """Convert a raw LiteGraph serialization into normalized nodes and relations.

    Returns (nodes, relations). Node ids must be unique and every link must
    reference known node ids; any violation raises CompileError.
    """
    if not isinstance(graph, dict):
        raise CompileError("graph must be an object")

    nodes_raw = graph.get("nodes")
    links_raw = graph.get("links")

    if not isinstance(nodes_raw, list):
        raise CompileError("graph.nodes must be an array")
    if not isinstance(links_raw, list):
        raise CompileError("graph.links must be an array")

    nodes_by_id: dict[int, NormalizedNode] = {}
    for raw in nodes_raw:
        node = _parse_node(raw)
        # Duplicate check stays after full node validation so malformed nodes
        # report their own error first, matching the original behavior.
        if node.id in nodes_by_id:
            raise CompileError(f"duplicate node id {node.id}")
        nodes_by_id[node.id] = node

    relations: list[NormalizedRelation] = []
    for link in links_raw:
        origin_id, target_id = _link_endpoints(link)
        from_node = nodes_by_id.get(origin_id)
        to_node = nodes_by_id.get(target_id)
        if from_node is None or to_node is None:
            raise CompileError("link references unknown node id")
        relations.append(NormalizedRelation(from_node=from_node, to_node=to_node))

    return list(nodes_by_id.values()), relations
def validate_graph(nodes: list[NormalizedNode], relations: list[NormalizedRelation]) -> None:
    """Enforce graph-level invariants; raises CompileError when any is violated.

    Requires at least one downstream node, one upstream node, and one relation;
    every relation must run downstream -> upstream; names must be unique within
    each (class, type) group.
    """
    if not any(n.node_class == "downstream" for n in nodes):
        raise CompileError("graph must contain at least one Downstream entity")
    if not any(n.node_class == "upstream" for n in nodes):
        raise CompileError("graph must contain at least one Upstream entity")
    if not relations:
        raise CompileError("graph must contain at least one relation")

    for rel in relations:
        directed_ok = (
            rel.from_node.node_class == "downstream"
            and rel.to_node.node_class == "upstream"
        )
        if not directed_ok:
            raise CompileError("relations must be Downstream -> Upstream only")

    # Name uniqueness within each (class, type) group.
    seen_keys: set[tuple[NodeClass, NodeType, str]] = set()
    for node in nodes:
        key = (node.node_class, node.node_type, node.name)
        if key in seen_keys:
            raise CompileError(
                f"duplicate name '{node.name}' within {node.node_class}.{node.node_type}"
            )
        seen_keys.add(key)
def to_yaml(nodes: list[NormalizedNode], relations: list[NormalizedRelation]) -> str:
    """Render normalized nodes and relations as the YAML config document.

    Output is deterministic: entities sorted by (class, type, name) with extra
    fields in key order, relations sorted by the same tuple on both endpoints.
    """
    entities: dict[str, Any] = {
        "downstream": {"sharepoint": [], "confluence": []},
        "upstream": {"azure_ai_search": [], "azure_vector_store": []},
    }

    # Stable ordering: by class, type, then name; "name" always leads each item.
    for node in sorted(nodes, key=lambda n: (n.node_class, n.node_type, n.name)):
        entry: dict[str, Any] = {"name": node.name}
        entry.update((field, node.fields[field]) for field in sorted(node.fields))
        entities[node.node_class][node.node_type].append(entry)

    def qualified(n: NormalizedNode) -> str:
        # Fully-qualified entity reference used on both ends of a relation.
        return f"{n.node_class}.{n.node_type}.{n.name}"

    ordered_relations = sorted(
        relations,
        key=lambda r: (
            r.from_node.node_class,
            r.from_node.node_type,
            r.from_node.name,
            r.to_node.node_class,
            r.to_node.node_type,
            r.to_node.name,
        ),
    )
    rel_items: list[dict[str, str]] = [
        {"from": qualified(r.from_node), "to": qualified(r.to_node)}
        for r in ordered_relations
    ]

    return yaml.safe_dump({"entities": entities, "relations": rel_items}, sort_keys=False)
def compile_graph(graph: dict[str, Any]) -> str:
    """Compile a raw LiteGraph serialization into the YAML config string.

    Pipeline: normalize -> validate -> render; any stage may raise CompileError.
    """
    normalized_nodes, normalized_relations = normalize_graph(graph)
    validate_graph(normalized_nodes, normalized_relations)
    return to_yaml(normalized_nodes, normalized_relations)