Files
swissknife/compiler.py
2026-03-29 12:01:11 +02:00

222 lines
7.0 KiB
Python

from __future__ import annotations
from dataclasses import dataclass
from typing import Any, Literal
import yaml
# Graph nodes fall into exactly two classes: "downstream" sources and
# "upstream" targets (relations must flow downstream -> upstream).
NodeClass = Literal["downstream", "upstream"]
# Closed set of connector types. The first two belong to the "downstream"
# class, the last two to "upstream" (pairing enforced in
# _node_type_from_litegraph_type).
NodeType = Literal[
    "sharepoint",
    "confluence",
    "azure_ai_search",
    "azure_vector_store",
]
@dataclass(frozen=True)
class NormalizedNode:
    """One validated LiteGraph node, ready for YAML emission."""
    id: int  # LiteGraph node id; unique within a graph (normalize_graph enforces)
    node_class: NodeClass  # "downstream" or "upstream"
    node_type: NodeType  # connector type, consistent with node_class
    name: str  # required display name; non-empty and stripped (_require_str)
    fields: dict[str, str]  # remaining string properties, excluding "name"
@dataclass(frozen=True)
class NormalizedRelation:
    """A directed edge between two normalized nodes (from_node -> to_node)."""
    from_node: NormalizedNode  # link origin
    to_node: NormalizedNode  # link target
class CompileError(ValueError):
    """Raised when a submitted graph cannot be normalized, validated or compiled."""
    pass
def _require_str(value: Any, *, field: str) -> str:
if not isinstance(value, str):
raise CompileError(f"{field} must be a string")
s = value.strip()
if not s:
raise CompileError(f"{field} must be non-empty")
return s
def _node_type_from_litegraph_type(lg_type: Any) -> tuple[NodeClass, NodeType]:
if not isinstance(lg_type, str):
raise CompileError("node.type must be a string")
# Frontend will set node.type like: "downstream.sharepoint"
parts = lg_type.split(".")
if len(parts) != 2:
raise CompileError(
"node.type must be in the form '<class>.<type>' (e.g. downstream.sharepoint)"
)
node_class, node_type = parts[0], parts[1]
if node_class not in ("downstream", "upstream"):
raise CompileError("node.type class must be 'downstream' or 'upstream'")
allowed_types: set[str]
if node_class == "downstream":
allowed_types = {"sharepoint", "confluence"}
else:
allowed_types = {"azure_ai_search", "azure_vector_store"}
if node_type not in allowed_types:
raise CompileError(f"invalid node.type '{lg_type}'")
return node_class, node_type # type: ignore[return-value]
def _normalize_node(raw: Any) -> NormalizedNode:
    """Validate one raw LiteGraph node dict and convert it to a NormalizedNode.

    Raises:
        CompileError: if the node is not an object, has a non-int id, a bad
            type string, or non-string property values.
    """
    if not isinstance(raw, dict):
        raise CompileError("each node must be an object")
    node_id = raw.get("id")
    # bool is a subclass of int, so a JSON true/false would otherwise be
    # silently accepted as node id 1/0 — reject it explicitly.
    if not isinstance(node_id, int) or isinstance(node_id, bool):
        raise CompileError("node.id must be an integer")
    node_class, node_type = _node_type_from_litegraph_type(raw.get("type"))
    props = raw.get("properties")
    if not isinstance(props, dict):
        # Missing/malformed properties degrade to empty; "name" check below
        # will then produce the specific error.
        props = {}
    name = _require_str(props.get("name"), field=f"node[{node_id}].properties.name")
    fields: dict[str, str] = {}
    for key, value in props.items():
        # "name" is stored separately; None means "unset" and is dropped.
        if key == "name" or value is None:
            continue
        if not isinstance(value, str):
            raise CompileError(f"node[{node_id}].properties.{key} must be a string")
        fields[key] = value
    return NormalizedNode(
        id=node_id,
        node_class=node_class,
        node_type=node_type,
        name=name,
        fields=fields,
    )


def _link_endpoints(link: Any) -> tuple[int, int]:
    """Extract (origin_id, target_id) from a raw LiteGraph link.

    LiteGraph can export links as arrays or objects depending on version.
    We support both:
      - array:  [id, origin_id, origin_slot, target_id, target_slot, type]
      - object: { origin_id, target_id, ... }

    Raises:
        CompileError: if the link has an unsupported shape or non-integer ids.
    """
    origin_id: Any = None
    target_id: Any = None
    if isinstance(link, list) and len(link) >= 5:
        origin_id = link[1]
        target_id = link[3]
    elif isinstance(link, dict):
        origin_id = link.get("origin_id")
        target_id = link.get("target_id")
    else:
        raise CompileError("each link must be an array or object")
    # As with node ids, reject bools (bool subclasses int).
    if (
        not isinstance(origin_id, int)
        or not isinstance(target_id, int)
        or isinstance(origin_id, bool)
        or isinstance(target_id, bool)
    ):
        raise CompileError("link origin_id/target_id must be integers")
    return origin_id, target_id


def normalize_graph(graph: dict[str, Any]) -> tuple[list[NormalizedNode], list[NormalizedRelation]]:
    """Convert a raw LiteGraph export into normalized nodes and relations.

    Args:
        graph: Parsed LiteGraph JSON; must contain "nodes" and "links" arrays.

    Returns:
        (nodes, relations) — nodes in first-seen order, relations in link order.

    Raises:
        CompileError: on any structural problem (non-object graph, duplicate
            node ids, malformed nodes/links, links referencing unknown ids).
    """
    if not isinstance(graph, dict):
        raise CompileError("graph must be an object")
    nodes_raw = graph.get("nodes")
    links_raw = graph.get("links")
    if not isinstance(nodes_raw, list):
        raise CompileError("graph.nodes must be an array")
    if not isinstance(links_raw, list):
        raise CompileError("graph.links must be an array")

    nodes_by_id: dict[int, NormalizedNode] = {}
    for raw in nodes_raw:
        node = _normalize_node(raw)
        if node.id in nodes_by_id:
            raise CompileError(f"duplicate node id {node.id}")
        nodes_by_id[node.id] = node

    relations: list[NormalizedRelation] = []
    for link in links_raw:
        origin_id, target_id = _link_endpoints(link)
        from_node = nodes_by_id.get(origin_id)
        to_node = nodes_by_id.get(target_id)
        if from_node is None or to_node is None:
            raise CompileError("link references unknown node id")
        relations.append(NormalizedRelation(from_node=from_node, to_node=to_node))
    return list(nodes_by_id.values()), relations
def validate_graph(nodes: list[NormalizedNode], relations: list[NormalizedRelation]) -> None:
    """Enforce the structural rules of a compiled graph.

    Rules:
      * at least one downstream node, one upstream node, and one relation;
      * every relation points downstream -> upstream;
      * names are unique within each (class, type) group.

    Raises:
        CompileError: when any rule is violated.
    """
    if not any(n.node_class == "downstream" for n in nodes):
        raise CompileError("graph must contain at least one Downstream entity")
    if not any(n.node_class == "upstream" for n in nodes):
        raise CompileError("graph must contain at least one Upstream entity")
    if not relations:
        raise CompileError("graph must contain at least one relation")
    for rel in relations:
        points_downstream_to_upstream = (
            rel.from_node.node_class == "downstream"
            and rel.to_node.node_class == "upstream"
        )
        if not points_downstream_to_upstream:
            raise CompileError("relations must be Downstream -> Upstream only")
    # Name uniqueness within each (class, type) group.
    seen: set[tuple[NodeClass, NodeType, str]] = set()
    for node in nodes:
        key = (node.node_class, node.node_type, node.name)
        if key in seen:
            raise CompileError(
                f"duplicate name '{node.name}' within {node.node_class}.{node.node_type}"
            )
        seen.add(key)
def to_yaml(nodes: list[NormalizedNode], relations: list[NormalizedRelation]) -> str:
    """Render normalized nodes and relations as a YAML document string.

    Output is deterministic: entities are grouped by class/type and sorted by
    name (with each entity's extra fields in sorted key order after "name"),
    and relations are sorted by their endpoint identities.
    """
    entities: dict[str, Any] = {
        "downstream": {"sharepoint": [], "confluence": []},
        "upstream": {"azure_ai_search": [], "azure_vector_store": []},
    }
    # Stable ordering: by class, type, then name.
    for node in sorted(nodes, key=lambda n: (n.node_class, n.node_type, n.name)):
        entry: dict[str, Any] = {"name": node.name}
        entry.update((field, node.fields[field]) for field in sorted(node.fields))
        entities[node.node_class][node.node_type].append(entry)

    def _relation_key(rel: NormalizedRelation) -> tuple[str, str, str, str, str, str]:
        # Sort relations by the fully-qualified identity of each endpoint.
        return (
            rel.from_node.node_class,
            rel.from_node.node_type,
            rel.from_node.name,
            rel.to_node.node_class,
            rel.to_node.node_type,
            rel.to_node.name,
        )

    rel_items: list[dict[str, str]] = [
        {
            "from": f"{r.from_node.node_class}.{r.from_node.node_type}.{r.from_node.name}",
            "to": f"{r.to_node.node_class}.{r.to_node.node_type}.{r.to_node.name}",
        }
        for r in sorted(relations, key=_relation_key)
    ]
    doc = {"entities": entities, "relations": rel_items}
    return yaml.safe_dump(doc, sort_keys=False)
def compile_graph(graph: dict[str, Any]) -> str:
    """Compile a raw LiteGraph export into the deployment YAML document.

    Pipeline: normalize -> validate -> render.

    Raises:
        CompileError: if the graph cannot be normalized or fails validation.
    """
    normalized_nodes, normalized_relations = normalize_graph(graph)
    validate_graph(normalized_nodes, normalized_relations)
    return to_yaml(normalized_nodes, normalized_relations)