2025-07-11 13:52:19 +00:00

117 lines
3.5 KiB
Python

from collections import defaultdict
from typing import Dict, Set, Type
from .basenode import BaseNode
from .baserelationship import BaseRelationship
from .graphconnection import GraphConnection
def get_node_types(base_type: Type[BaseNode] = BaseNode) -> Dict[str, Type[BaseNode]]:
    """Collect every labelled node class that descends from ``base_type``.

    The subclass tree is walked recursively. A class is registered under its
    ``__primarylabel__`` when that attribute exists and is not ``None``.
    Descendants are merged in after their parent, so when two classes share a
    label the one visited last is the one kept.

    Args:
        base_type: Root of the hierarchy to inspect. The root itself is never
            added to the result, only its descendants.

    Returns:
        A mapping of primary label to node class.
    """
    found: Dict[str, Type[BaseNode]] = {}

    for node_cls in base_type.__subclasses__():
        # 'Abstract' nodes carry no label: they only provide common
        # properties to their subclasses and must not be put in the graph.
        label = getattr(node_cls, "__primarylabel__", None)
        if label is not None:
            found[label] = node_cls

        # Descend whether or not this class was labelled, so concrete
        # children of abstract nodes are still discovered.
        if node_cls.__subclasses__():
            found.update(get_node_types(node_cls))

    return found
def get_rels_by_type(
    base_type: Type[BaseRelationship] = BaseRelationship,
) -> Dict[str, dict]:
    """Describe every typed relationship class that descends from ``base_type``.

    The subclass tree is walked recursively. Each class whose
    ``__relationshiptype__`` exists and is not ``None`` gets an entry holding
    the class itself plus the annotations of its ``source`` and ``target``
    fields as found in ``model_fields``. Descendants are merged in after their
    parent, so when two classes share a relationship type the one visited last
    is the one kept.

    Args:
        base_type: Root of the hierarchy to inspect. The root itself is never
            added to the result, only its descendants.

    Returns:
        A ``defaultdict(dict)`` keyed by relationship type; each value has the
        keys ``"rel_class"``, ``"source_class"`` and ``"target_class"``.
    """
    registry: dict = defaultdict(dict)

    for rel_cls in base_type.__subclasses__():
        # 'Abstract' relationships carry no type: they only provide common
        # properties to their subclasses and must not be put in the graph.
        rel_name = getattr(rel_cls, "__relationshiptype__", None)
        if rel_name is not None:
            fields = rel_cls.model_fields
            registry[rel_name] = {
                "rel_class": rel_cls,
                "source_class": fields["source"].annotation,
                "target_class": fields["target"].annotation,
            }

        # Descend whether or not this class was typed, so concrete children
        # of abstract relationships are still discovered.
        if rel_cls.__subclasses__():
            registry.update(get_rels_by_type(rel_cls))

    return registry
def all_subclasses(cls: type) -> set:
    """Return every direct and indirect subclass of ``cls``.

    Args:
        cls: The class whose descendants are wanted. ``cls`` itself is not
            part of the result.

    Returns:
        A set containing each descendant class exactly once, even when it is
        reachable through more than one parent.
    """
    descendants: set = set()
    pending = list(cls.__subclasses__())

    # Worklist traversal of the subclass tree.
    while pending:
        current = pending.pop()
        if current in descendants:
            # Already reached through another parent (multiple inheritance).
            continue
        descendants.add(current)
        pending.extend(current.__subclasses__())

    return descendants
def get_rels_by_node(
    base_type: Type[BaseRelationship] = BaseRelationship, by_source: bool = True
) -> Dict[str, Set[str]]:
    """Group relationship types by the label of the node at one end.

    For each relationship type found under ``base_type``, the class at the
    chosen end (source or target) and all of its subclasses are inspected.
    Every one that has a non-``None`` ``__primarylabel__`` gets the
    relationship type added to its set.

    Args:
        base_type: Root relationship class whose descendants are inspected.
        by_source: When ``True`` group by the source class, otherwise group
            by the target class.

    Returns:
        A ``defaultdict(set)`` mapping node primary label to the set of
        relationship types that may start (or end) at that node.
    """
    # Identity check kept on purpose: only the literal ``True`` selects the
    # source end, exactly as before.
    node_dir = "source_class" if by_source is True else "target_class"

    by_node: Dict[str, Set[str]] = defaultdict(set)

    for rel_type, entry in get_rels_by_type(base_type).items():
        endpoint = entry[node_dir]

        # The endpoint class and each of its descendants are handled the same
        # way. A class with no ``__primarylabel__`` at all is treated like an
        # unlabelled one and skipped, instead of raising AttributeError for
        # subclasses while tolerating it for the endpoint.
        for node_cls in (endpoint, *all_subclasses(endpoint)):
            label = getattr(node_cls, "__primarylabel__", None)
            if label is not None:
                by_node[label].add(rel_type)

    return by_node
def get_rels_by_source(
    base_type: Type[BaseRelationship] = BaseRelationship,
) -> Dict[str, Set[str]]:
    """Group relationship types by the primary label of their source node.

    Args:
        base_type: Root relationship class whose descendants are inspected.

    Returns:
        Mapping of source node primary label to the set of relationship
        types, as produced by ``get_rels_by_node`` with ``by_source=True``.
    """
    # Forward base_type: it used to be accepted but silently ignored, so the
    # result was always computed from the default root class.
    return get_rels_by_node(base_type=base_type, by_source=True)
def get_rels_by_target(
    base_type: Type[BaseRelationship] = BaseRelationship,
) -> Dict[str, Set[str]]:
    """Group relationship types by the primary label of their target node.

    Args:
        base_type: Root relationship class whose descendants are inspected.

    Returns:
        Mapping of target node primary label to the set of relationship
        types, as produced by ``get_rels_by_node`` with ``by_source=False``.
    """
    # Forward base_type: it used to be accepted but silently ignored, so the
    # result was always computed from the default root class.
    return get_rels_by_node(base_type=base_type, by_source=False)
def auto_constrain() -> None:
    """Automatically apply constraints for all defined node types.

    Opens a ``GraphConnection``, looks up every labelled node type via
    ``get_node_types()`` and calls ``apply_constraint`` once per type with
    the node's primary label and its ``__primaryproperty__``.
    """
    connection = GraphConnection()
    labelled_nodes = get_node_types()

    for label in labelled_nodes:
        primary_property = labelled_nodes[label].__primaryproperty__
        connection.apply_constraint(label, primary_property)