# Copyright (c) "Neo4j"
# Neo4j Sweden AB [https://neo4j.com]
# #
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# #
# https://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Union
try:
from neo4j_viz import VisualizationGraph, Node, Relationship
except ImportError:
VisualizationGraph = Node = Relationship = None # type: ignore
from neo4j_graphrag.experimental.components.schema import (
GraphSchema,
NodeType,
PropertyType,
)
# [docs]
def schema_visualization(
    schema: Union[dict[str, Any], GraphSchema],
) -> VisualizationGraph:
    """Build a neo4j-viz graph that depicts a GraphSchema.

    Every node type of the schema becomes one node (identified and captioned
    by its label) and every pattern becomes one relationship captioned with
    the relationship type. Properties are attached as a ``{name: type}``
    mapping in which the names of required properties carry a trailing ``*``.
    Nodes are colored according to their caption.

    Args:
        schema: a ``GraphSchema`` instance, or a dict that is validated into
            one before the graph is built.

    Returns:
        VisualizationGraph: the graph describing the schema.

    Raises:
        ImportError: if neo4j-viz is not installed.

    Usage:

    .. code:: python

        VG = schema_visualization(schema)
        html = VG.render()

        # in Jupyter:
        display(html)

        # to save the generated HTML
        with open("my_schema.html", "w") as f:
            f.write(html.data)
    """
    # The module-level import falls back to None when neo4j-viz is missing.
    if VisualizationGraph is None:
        raise ImportError(
            "Please install neo4j-viz to use the graph schema visualization feature: pip install neo4j-viz"
        )

    validated_schema = GraphSchema.model_validate(schema)

    def _display_name(prop: PropertyType) -> str:
        """Return the property name, with a '*' appended when it is required."""
        if prop.required:
            return p_name_with_marker(prop)
        return prop.name

    def p_name_with_marker(prop: PropertyType) -> str:
        """Return the property name followed by the required marker '*'."""
        return prop.name + "*"

    def _properties_of_relationship(rel_type: str) -> dict[str, str]:
        """Return the {display name: type} mapping for a relationship type.

        Only the first relationship type whose label equals ``rel_type`` is
        used; an empty dict is returned when no relationship type matches.
        """
        first_match = next(
            (
                candidate
                for candidate in validated_schema.relationship_types
                if candidate.label == rel_type
            ),
            None,
        )
        if first_match is None:
            return {}
        return {_display_name(prop): prop.type for prop in first_match.properties}

    # One visual node per node type declared in the schema.
    nodes = []
    for node_type in validated_schema.node_types:
        node_label = node_type.label
        node_properties = {
            _display_name(prop): prop.type for prop in node_type.properties
        }
        nodes.append(
            Node(  # type: ignore
                id=node_label,
                caption=node_label,
                properties=node_properties,
            )
        )

    # One visual relationship per pattern; a pattern is read positionally as
    # (source label, relationship type, target label).
    relationships = []
    for pattern in validated_schema.patterns:
        relationship_label = pattern[1]
        relationships.append(
            Relationship(  # type: ignore
                source=pattern[0],
                target=pattern[2],
                caption=relationship_label,
                properties=_properties_of_relationship(relationship_label),
            )
        )

    graph = VisualizationGraph(nodes=nodes, relationships=relationships)
    graph.color_nodes(field="caption")
    return graph