Source code for neo4j_graphrag.experimental.utils.schema

#  Copyright (c) "Neo4j"
#  Neo4j Sweden AB [https://neo4j.com]
#  #
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#  #
#      https://www.apache.org/licenses/LICENSE-2.0
#  #
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
from typing import Any, Union

try:
    from neo4j_viz import VisualizationGraph, Node, Relationship
except ImportError:
    VisualizationGraph = Node = Relationship = None  # type: ignore

from neo4j_graphrag.experimental.components.schema import (
    GraphSchema,
    NodeType,
    PropertyType,
)


[docs] def schema_visualization( schema: Union[dict[str, Any], GraphSchema], ) -> VisualizationGraph: """Helper function to visualize a GraphSchema using the neo4j-viz library. Usage: .. code:: python VG = schema_visualization(schema) html = VG.render() # in Jupyter: display(html) # to save the generated HTML with open("my_schema.html", "w") as f: f.write(html.data) """ if VisualizationGraph is None: raise ImportError( "Please install neo4j-viz to use the graph schema visualization feature: pip install neo4j-viz" ) schema_object = GraphSchema.model_validate(schema) def _format_property_name(p: PropertyType) -> str: """ Args: p (PropertyType): the property to be formatted Returns: str: the property name, suffixed with '*' if the property is required """ return p.name + ("*" if p.required else "") def _relationship_properties(rel_type: str) -> dict[str, str]: """Returns a dict {prop_name: prop_type} for all relationship properties. Args: rel_type (str): the relationship type Returns: dict[str, str]: the relationship properties {name: type} mapping for display """ for relationship_type in schema_object.relationship_types: if relationship_type.label != rel_type: continue return { _format_property_name(p): p.type for p in relationship_type.properties } return {} def _node_properties(node_type: NodeType) -> dict[str, str]: """Returns a dict {prop_name: prop_type} for all node properties. Args: node_type (NodeType): the node type object Returns: dict[str, str]: the node properties {name: type} mapping for display """ return {_format_property_name(p): p.type for p in node_type.properties} nodes = [ Node( # type: ignore id=node_type.label, caption=node_type.label, properties=_node_properties(node_type), ) for node_type in schema_object.node_types ] relationships = [ Relationship( # type: ignore source=pattern[0], target=pattern[2], caption=pattern[1], properties=_relationship_properties(pattern[1]), ) for pattern in schema_object.patterns ] VG = VisualizationGraph(nodes=nodes, relationships=relationships) VG.color_nodes(field="caption") return VG