feat: first functioning code
This commit is contained in:
221
compiler.py
Normal file
221
compiler.py
Normal file
@@ -0,0 +1,221 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Literal
|
||||
|
||||
import yaml
|
||||
|
||||
|
||||
# The two sides a node can belong to. Relations are validated elsewhere in
# this module to always run downstream -> upstream.
NodeClass = Literal["downstream", "upstream"]

# Concrete connector kinds. The first two are valid only for the "downstream"
# class, the last two only for "upstream" (enforced when parsing node types).
NodeType = Literal[
    "sharepoint",
    "confluence",
    "azure_ai_search",
    "azure_vector_store",
]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class NormalizedNode:
    """A validated, immutable graph node extracted from a LiteGraph export."""

    # LiteGraph-assigned integer id; unique within one graph.
    id: int
    # "downstream" or "upstream".
    node_class: NodeClass
    # Connector kind; must belong to node_class's allowed set.
    node_type: NodeType
    # Stripped, non-empty display name; unique per (class, type) group.
    name: str
    # All remaining string properties ("name" excluded, None values dropped).
    fields: dict[str, str]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class NormalizedRelation:
    """A directed edge between two normalized nodes (validated as downstream -> upstream)."""

    # Edge origin.
    from_node: NormalizedNode
    # Edge target.
    to_node: NormalizedNode
|
||||
|
||||
|
||||
class CompileError(ValueError):
    """Raised when a graph export is malformed or violates compilation rules."""

    pass
|
||||
|
||||
|
||||
def _require_str(value: Any, *, field: str) -> str:
|
||||
if not isinstance(value, str):
|
||||
raise CompileError(f"{field} must be a string")
|
||||
s = value.strip()
|
||||
if not s:
|
||||
raise CompileError(f"{field} must be non-empty")
|
||||
return s
|
||||
|
||||
|
||||
def _node_type_from_litegraph_type(lg_type: Any) -> tuple[NodeClass, NodeType]:
|
||||
if not isinstance(lg_type, str):
|
||||
raise CompileError("node.type must be a string")
|
||||
|
||||
# Frontend will set node.type like: "downstream.sharepoint"
|
||||
parts = lg_type.split(".")
|
||||
if len(parts) != 2:
|
||||
raise CompileError(
|
||||
"node.type must be in the form '<class>.<type>' (e.g. downstream.sharepoint)"
|
||||
)
|
||||
|
||||
node_class, node_type = parts[0], parts[1]
|
||||
if node_class not in ("downstream", "upstream"):
|
||||
raise CompileError("node.type class must be 'downstream' or 'upstream'")
|
||||
|
||||
allowed_types: set[str]
|
||||
if node_class == "downstream":
|
||||
allowed_types = {"sharepoint", "confluence"}
|
||||
else:
|
||||
allowed_types = {"azure_ai_search", "azure_vector_store"}
|
||||
|
||||
if node_type not in allowed_types:
|
||||
raise CompileError(f"invalid node.type '{lg_type}'")
|
||||
|
||||
return node_class, node_type # type: ignore[return-value]
|
||||
|
||||
|
||||
def normalize_graph(graph: dict[str, Any]) -> tuple[list[NormalizedNode], list[NormalizedRelation]]:
    """Convert a raw LiteGraph export into normalized nodes and relations.

    Expects ``graph`` to carry a ``nodes`` list (objects with integer ``id``,
    a ``type`` like "downstream.sharepoint", and a ``properties`` object that
    must include a non-empty string ``name``) and a ``links`` list. Links may
    be either LiteGraph arrays or objects, depending on the exporting version.

    Returns (nodes, relations); raises CompileError on any malformed input.
    """
    if not isinstance(graph, dict):
        raise CompileError("graph must be an object")

    nodes_raw = graph.get("nodes")
    links_raw = graph.get("links")

    if not isinstance(nodes_raw, list):
        raise CompileError("graph.nodes must be an array")
    if not isinstance(links_raw, list):
        raise CompileError("graph.links must be an array")

    nodes_by_id: dict[int, NormalizedNode] = {}

    for n in nodes_raw:
        if not isinstance(n, dict):
            raise CompileError("each node must be an object")

        node_id = n.get("id")
        # bool is a subclass of int in Python, so a bare isinstance(x, int)
        # would accept true/false from JSON as a node id — reject it explicitly.
        if not isinstance(node_id, int) or isinstance(node_id, bool):
            raise CompileError("node.id must be an integer")

        node_class, node_type = _node_type_from_litegraph_type(n.get("type"))

        props = n.get("properties")
        if not isinstance(props, dict):
            # Missing/malformed properties fall through to the "name must be
            # a string" error below via the empty dict.
            props = {}

        name = _require_str(props.get("name"), field=f"node[{node_id}].properties.name")

        # Collect every other string property; None means "unset" and is dropped.
        fields: dict[str, str] = {}
        for k, v in props.items():
            if k == "name":
                continue
            if v is None:
                continue
            if not isinstance(v, str):
                raise CompileError(f"node[{node_id}].properties.{k} must be a string")
            fields[k] = v

        if node_id in nodes_by_id:
            raise CompileError(f"duplicate node id {node_id}")

        nodes_by_id[node_id] = NormalizedNode(
            id=node_id,
            node_class=node_class,
            node_type=node_type,
            name=name,
            fields=fields,
        )

    relations: list[NormalizedRelation] = []

    for link in links_raw:
        # LiteGraph can export links as arrays or objects depending on version.
        # We support both:
        # - array: [id, origin_id, origin_slot, target_id, target_slot, type]
        # - object: { origin_id, target_id, ... }
        origin_id: Any = None
        target_id: Any = None

        if isinstance(link, list) and len(link) >= 5:
            origin_id = link[1]
            target_id = link[3]
        elif isinstance(link, dict):
            origin_id = link.get("origin_id")
            target_id = link.get("target_id")
        else:
            raise CompileError("each link must be an array or object")

        # Same bool-vs-int pitfall as node.id above.
        if (
            not isinstance(origin_id, int)
            or not isinstance(target_id, int)
            or isinstance(origin_id, bool)
            or isinstance(target_id, bool)
        ):
            raise CompileError("link origin_id/target_id must be integers")

        from_node = nodes_by_id.get(origin_id)
        to_node = nodes_by_id.get(target_id)
        if from_node is None or to_node is None:
            raise CompileError("link references unknown node id")

        relations.append(NormalizedRelation(from_node=from_node, to_node=to_node))

    return list(nodes_by_id.values()), relations
|
||||
|
||||
|
||||
def validate_graph(nodes: list[NormalizedNode], relations: list[NormalizedRelation]) -> None:
    """Enforce structural rules on a normalized graph.

    Requires at least one downstream node, one upstream node, and one
    relation; every relation must point downstream -> upstream; names must be
    unique within each (class, type) group. Raises CompileError otherwise.
    """
    if not any(node.node_class == "downstream" for node in nodes):
        raise CompileError("graph must contain at least one Downstream entity")
    if not any(node.node_class == "upstream" for node in nodes):
        raise CompileError("graph must contain at least one Upstream entity")
    if not relations:
        raise CompileError("graph must contain at least one relation")

    for rel in relations:
        correct_direction = (
            rel.from_node.node_class == "downstream"
            and rel.to_node.node_class == "upstream"
        )
        if not correct_direction:
            raise CompileError("relations must be Downstream -> Upstream only")

    # Name uniqueness within each (class,type) group
    seen_keys: set[tuple[NodeClass, NodeType, str]] = set()
    for node in nodes:
        group_key = (node.node_class, node.node_type, node.name)
        if group_key in seen_keys:
            raise CompileError(
                f"duplicate name '{node.name}' within {node.node_class}.{node.node_type}"
            )
        seen_keys.add(group_key)
|
||||
|
||||
|
||||
def to_yaml(nodes: list[NormalizedNode], relations: list[NormalizedRelation]) -> str:
    """Serialize normalized nodes and relations to the YAML config format.

    Output shape: {"entities": {class: {type: [items]}}, "relations": [...]}.
    Nodes and relations are emitted in a stable sorted order so the same
    graph always compiles to byte-identical YAML.
    """
    entities: dict[str, Any] = {
        "downstream": {"sharepoint": [], "confluence": []},
        "upstream": {"azure_ai_search": [], "azure_vector_store": []},
    }

    # Stable ordering: by class, type, then name
    for node in sorted(nodes, key=lambda x: (x.node_class, x.node_type, x.name)):
        # "name" first, then the remaining fields alphabetically.
        item: dict[str, Any] = {"name": node.name}
        item.update((field, node.fields[field]) for field in sorted(node.fields))
        entities[node.node_class][node.node_type].append(item)

    def _relation_sort_key(rel: NormalizedRelation) -> tuple[str, ...]:
        # Order by the fully qualified endpoint names, origin first.
        return (
            rel.from_node.node_class,
            rel.from_node.node_type,
            rel.from_node.name,
            rel.to_node.node_class,
            rel.to_node.node_type,
            rel.to_node.name,
        )

    rel_items: list[dict[str, str]] = [
        {
            "from": f"{rel.from_node.node_class}.{rel.from_node.node_type}.{rel.from_node.name}",
            "to": f"{rel.to_node.node_class}.{rel.to_node.node_type}.{rel.to_node.name}",
        }
        for rel in sorted(relations, key=_relation_sort_key)
    ]

    doc = {"entities": entities, "relations": rel_items}
    return yaml.safe_dump(doc, sort_keys=False)
|
||||
|
||||
|
||||
def compile_graph(graph: dict[str, Any]) -> str:
    """Compile a raw LiteGraph export dict into a YAML configuration string.

    Pipeline: normalize the export, validate the normalized graph, then
    serialize it. Raises CompileError at any stage on invalid input.
    """
    normalized_nodes, normalized_relations = normalize_graph(graph)
    validate_graph(normalized_nodes, normalized_relations)
    return to_yaml(normalized_nodes, normalized_relations)
|
||||
Reference in New Issue
Block a user