diff --git a/graphrag/config/create_graphrag_config.py b/graphrag/config/create_graphrag_config.py index a883658c..5a63dbc7 100644 --- a/graphrag/config/create_graphrag_config.py +++ b/graphrag/config/create_graphrag_config.py @@ -397,7 +397,7 @@ def hydrate_parallelization_params( reader.use(values.get("snapshots")), ): snapshots_model = SnapshotsConfig( - graphml=reader.bool("graphml") or defs.SNAPSHOTS_GRAPHML, + graphml = defs.SNAPSHOTS_GRAPHML, raw_entities=reader.bool("raw_entities") or defs.SNAPSHOTS_RAW_ENTITIES, top_level_nodes=reader.bool("top_level_nodes") or defs.SNAPSHOTS_TOP_LEVEL_NODES, diff --git a/graphrag/config/defaults.py b/graphrag/config/defaults.py index 4d648914..b07579c9 100644 --- a/graphrag/config/defaults.py +++ b/graphrag/config/defaults.py @@ -75,7 +75,7 @@ NODE2VEC_RANDOM_SEED = 597832 REPORTING_TYPE = ReportingType.file REPORTING_BASE_DIR = "output/${timestamp}/reports" -SNAPSHOTS_GRAPHML = False +SNAPSHOTS_GRAPHML = True SNAPSHOTS_RAW_ENTITIES = False SNAPSHOTS_TOP_LEVEL_NODES = False STORAGE_BASE_DIR = "output/${timestamp}/artifacts" diff --git a/graphrag/index/init_content.py b/graphrag/index/init_content.py index c63a6578..7c72f639 100644 --- a/graphrag/index/init_content.py +++ b/graphrag/index/init_content.py @@ -137,7 +137,6 @@ enabled: false # if true, will generate UMAP embeddings for nodes snapshots: - graphml: false raw_entities: false top_level_nodes: false diff --git a/graphrag/index/utils/load_graph.py b/graphrag/index/utils/load_graph.py index 57992a04..37f05ce1 100644 --- a/graphrag/index/utils/load_graph.py +++ b/graphrag/index/utils/load_graph.py @@ -3,9 +3,11 @@ """Networkx load_graph utility definition.""" +import io + import networkx as nx def load_graph(graphml: str | nx.Graph) -> nx.Graph: """Load a graph from a graphml file or a networkx graph.""" - return nx.parse_graphml(graphml) if isinstance(graphml, str) else graphml + return nx.read_graphml(io.StringIO(graphml)) if isinstance(graphml, str) else graphml diff --git a/graphrag/index/verbs/__init__.py b/graphrag/index/verbs/__init__.py index 379c2a37..4e2123a9 100644 --- a/graphrag/index/verbs/__init__.py +++ b/graphrag/index/verbs/__init__.py @@ -16,6 +16,7 @@ unpack_graph, ) from .overrides import aggregate, concat, merge +from .restore_snapshot_rows import restore_snapshot_rows from .snapshot import snapshot from .snapshot_rows import snapshot_rows from .spread_json import spread_json @@ -37,6 +38,7 @@ "layout_graph", "merge", "merge_graphs", + "restore_snapshot_rows", "snapshot", "snapshot_rows", "spread_json", diff --git a/graphrag/index/verbs/entities/summarize/description_summarize.py b/graphrag/index/verbs/entities/summarize/description_summarize.py index 5b7feb41..944883ee 100644 --- a/graphrag/index/verbs/entities/summarize/description_summarize.py +++ b/graphrag/index/verbs/entities/summarize/description_summarize.py @@ -66,7 +66,7 @@ async def summarize_descriptions( { "verb": "", "args": { - "column": "the_document_text_column_to_extract_descriptions_from", /* Required: This will be a graphml graph in string form which represents the entities and their relationships */ + "column": "the_document_text_column_to_extract_descriptions_from", /* Required: This will be a graphml graph filepath in string form which represents the entities and their relationships */ "to": "the_column_to_output_the_summarized_descriptions_to", /* Required: This will be a graphml graph in string form which represents the entities and their relationships after being summarized */ "strategy": {...} , see strategies section below } diff --git a/graphrag/index/verbs/graph/clustering/cluster_graph.py b/graphrag/index/verbs/graph/clustering/cluster_graph.py index e8be50e8..885aaf34 100644 --- a/graphrag/index/verbs/graph/clustering/cluster_graph.py +++ b/graphrag/index/verbs/graph/clustering/cluster_graph.py @@ -3,6 +3,7 @@ """A module containing cluster_graph, apply_clustering and run_layout methods definition.""" +import io import logging from enum import Enum from random import Random @@ -115,7 +116,7 @@ def apply_clustering( ) -> nx.Graph: """Apply clustering to a graphml string.""" random = Random(seed) # noqa S311 - graph = nx.parse_graphml(graphml) + graph = nx.read_graphml(io.StringIO(graphml)) for community_level, community_id, nodes in communities: if level == community_level: for node in nodes: diff --git a/graphrag/index/verbs/restore_snapshot_rows.py b/graphrag/index/verbs/restore_snapshot_rows.py new file mode 100644 index 00000000..4ada0a59 --- /dev/null +++ b/graphrag/index/verbs/restore_snapshot_rows.py @@ -0,0 +1,58 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""A module containing 'FormatSpecifier' model.""" + +import json +from pathlib import Path +from typing import Any, cast + +from datashaper import TableContainer, VerbInput, verb +from pandas.core.frame import DataFrame +from pandas.core.groupby import DataFrameGroupBy + +from graphrag.index.storage import PipelineStorage +from graphrag.index.verbs.snapshot_rows import _parse_formats + + +@verb(name="restore_snapshot_rows") +async def restore_snapshot_rows( + input: VerbInput, + column: str, + to: str, + storage: PipelineStorage, + formats: list[str | dict[str, Any]], + **_kwargs: dict, +) -> TableContainer: + """Take a by-row snapshot of the tabular data.""" + if isinstance(input.get_input(), DataFrameGroupBy): + msg = "Cannot snapshot rows of a grouped DataFrame" + raise TypeError(msg) + + data = cast(DataFrame, input.get_input()) + parsed_formats = _parse_formats(formats) + + # do not modify the original data + new_data = data.copy() + new_data[to] = None + for row_idx, row in data.iterrows(): + # for each row, load only the data in the specified formats + for fmt in parsed_formats: + row_name = Path(row[column]) + if row_name.suffix[1:] != fmt.extension: + continue + + filename = row_name.name + data_bytes:bytes|bytearray = await storage.get(filename, as_bytes=True) + if not isinstance(row_idx, int): + continue + + if fmt.format == "json": + new_data.loc[row_idx, to] = json.loads(data_bytes) + elif fmt.format == "text": + new_data.loc[row_idx, to] = data_bytes.decode("utf-8") + else: + msg = f"Unsupported format: {fmt.format}" + raise ValueError(msg) + + return TableContainer(table=new_data.dropna(subset=[to])) diff --git a/graphrag/index/verbs/snapshot_rows.py b/graphrag/index/verbs/snapshot_rows.py index 6c6c1665..de439f22 100644 --- a/graphrag/index/verbs/snapshot_rows.py +++ b/graphrag/index/verbs/snapshot_rows.py @@ -5,9 +5,11 @@ import json from dataclasses import dataclass -from typing import Any +from typing import Any, cast from datashaper import TableContainer, VerbInput, verb +from pandas.core.frame import DataFrame +from pandas.core.groupby import DataFrameGroupBy from graphrag.index.storage import PipelineStorage @@ -24,6 +26,7 @@ class FormatSpecifier: async def snapshot_rows( input: VerbInput, column: str | None, + to: str | None, base_name: str, storage: PipelineStorage, formats: list[str | dict[str, Any]], @@ -31,7 +34,11 @@ async def snapshot_rows( **_kwargs: dict, ) -> TableContainer: """Take a by-row snapshot of the tabular data.""" - data = input.get_input() + if isinstance(input.get_input(), DataFrameGroupBy): + msg = "Cannot snapshot rows of a grouped DataFrame" + raise TypeError(msg) + + data = cast(DataFrame, input.get_input()) parsed_formats = _parse_formats(formats) num_rows = len(data) @@ -42,13 +49,19 @@ def get_row_name(row: Any, row_idx: Any): return f"{base_name}.{row_idx}" return f"{base_name}.{row[row_name_column]}" + if to is not None: + # init the table column where the filenames will be stored + data[to] = None + for row_idx, row in data.iterrows(): + # for each row, save the data in the specified formats for fmt in parsed_formats: row_name = get_row_name(row, row_idx) extension = fmt.extension + filename = f"{row_name}.{extension}" if fmt.format == "json": await storage.set( - f"{row_name}.{extension}", + filename, ( json.dumps(row[column], ensure_ascii=False) if column is not None @@ -59,8 +72,11 @@ def get_row_name(row: Any, row_idx: Any): if column is None: msg = "column must be specified for text format" raise ValueError(msg) - await storage.set(f"{row_name}.{extension}", str(row[column])) - + await storage.set(filename, str(row[column])) + + if to is not None and isinstance(row_idx, int): + data.loc[row_idx, to] = filename + return TableContainer(table=data) diff --git a/graphrag/index/workflows/v1/create_base_entity_graph.py b/graphrag/index/workflows/v1/create_base_entity_graph.py index b001aad2..71904135 100644 --- a/graphrag/index/workflows/v1/create_base_entity_graph.py +++ b/graphrag/index/workflows/v1/create_base_entity_graph.py @@ -7,7 +7,6 @@ workflow_name = "create_base_entity_graph" - def build_steps( config: PipelineWorkflowConfig, ) -> list[PipelineWorkflowStep]: @@ -35,10 +34,20 @@ def build_steps( }, ) - graphml_snapshot_enabled = config.get("graphml_snapshot", False) or False + graphml_snapshot_enabled = True embed_graph_enabled = config.get("embed_graph_enabled", False) or False return [ + { + "verb": "restore_snapshot_rows", + "enabled": graphml_snapshot_enabled, + "args": { + "column": "filepath", + "to": "entity_graph", + "formats": [{"format": "text", "extension": "graphml"}], + }, + "input": ({"source": "workflow:create_summarized_entities"}), + }, { "verb": "cluster_graph", "args": { @@ -46,8 +55,7 @@ def build_steps( "column": "entity_graph", "to": "clustered_graph", "level_to": "level", - }, - "input": ({"source": "workflow:create_summarized_entities"}), + } }, { "verb": "snapshot_rows", @@ -55,6 +63,7 @@ def build_steps( "args": { "base_name": "clustered_graph", "column": "clustered_graph", + "to": "clustered_graph_filepath", "formats": [{"format": "text", "extension": "graphml"}], }, }, @@ -73,6 +82,7 @@ def build_steps( "args": { "base_name": "embedded_graph", "column": "entity_graph", + "to": "embedded_graph_filepath", "formats": [{"format": "text", "extension": "graphml"}], }, }, @@ -82,9 +92,9 @@ def build_steps( # only selecting for documentation sake, so we know what is contained in # this workflow "columns": ( - ["level", "clustered_graph", "embeddings"] + ["level", "clustered_graph_filepath", "embeddings"] if embed_graph_enabled - else ["level", "clustered_graph"] + else ["level", "clustered_graph_filepath"] ), }, }, diff --git a/graphrag/index/workflows/v1/create_base_extracted_entities.py b/graphrag/index/workflows/v1/create_base_extracted_entities.py index 30d608e9..79c28c1a 100644 --- a/graphrag/index/workflows/v1/create_base_extracted_entities.py +++ b/graphrag/index/workflows/v1/create_base_extracted_entities.py @@ -20,7 +20,7 @@ def build_steps( * `workflow:create_base_text_units` """ entity_extraction_config = config.get("entity_extract", {}) - graphml_snapshot_enabled = config.get("graphml_snapshot", False) or False + graphml_snapshot_enabled = True raw_entity_snapshot_enabled = config.get("raw_entity_snapshot", False) or False return [ @@ -84,12 +84,21 @@ def build_steps( }, }, { + # To-Do: update the snapshot_rows verb to include a "filename" column that + # we can use to located the stored graphml file in the future. "verb": "snapshot_rows", "enabled": graphml_snapshot_enabled, "args": { "base_name": "merged_graph", "column": "entity_graph", + "to": "filepath", "formats": [{"format": "text", "extension": "graphml"}], }, }, + { + "verb": "select", + "args": { + "columns": (["filepath"]), + }, + }, ] diff --git a/graphrag/index/workflows/v1/create_final_communities.py b/graphrag/index/workflows/v1/create_final_communities.py index f8949dfc..1c213bc0 100644 --- a/graphrag/index/workflows/v1/create_final_communities.py +++ b/graphrag/index/workflows/v1/create_final_communities.py @@ -18,6 +18,17 @@ def build_steps( * `workflow:create_base_entity_graph` """ return [ + { + "id": "local:create_base_entity_graph", + "verb": "restore_snapshot_rows", + "enabled": True, + "args": { + "column": "clustered_graph_filepath", + "to": "clustered_graph", + "formats": [{"format": "text", "extension": "graphml"}], + }, + "input": {"source": "workflow:create_base_entity_graph"}, + }, { "id": "graph_nodes", "verb": "unpack_graph", @@ -25,7 +36,7 @@ def build_steps( "column": "clustered_graph", "type": "nodes", }, - "input": {"source": "workflow:create_base_entity_graph"}, + "input": {"source": "local:create_base_entity_graph"}, }, { "id": "graph_edges", @@ -34,7 +45,7 @@ def build_steps( "column": "clustered_graph", "type": "edges", }, - "input": {"source": "workflow:create_base_entity_graph"}, + "input": {"source": "local:create_base_entity_graph"}, }, { "id": "source_clusters", diff --git a/graphrag/index/workflows/v1/create_final_entities.py b/graphrag/index/workflows/v1/create_final_entities.py index 9d8b962b..d36fdf18 100644 --- a/graphrag/index/workflows/v1/create_final_entities.py +++ b/graphrag/index/workflows/v1/create_final_entities.py @@ -30,13 +30,23 @@ def build_steps( ) return [ + { + "verb": "restore_snapshot_rows", + "enabled": True, + "args": { + "column": "clustered_graph_filepath", + "to": "clustered_graph", + "formats": [{"format": "text", "extension": "graphml"}], + }, + "input": {"source": "workflow:create_base_entity_graph"}, + }, { "verb": "unpack_graph", "args": { "column": "clustered_graph", "type": "nodes", }, - "input": {"source": "workflow:create_base_entity_graph"}, + }, {"verb": "rename", "args": {"columns": {"label": "title"}}}, { diff --git a/graphrag/index/workflows/v1/create_final_nodes.py b/graphrag/index/workflows/v1/create_final_nodes.py index 31277e7b..0c940414 100644 --- a/graphrag/index/workflows/v1/create_final_nodes.py +++ b/graphrag/index/workflows/v1/create_final_nodes.py @@ -77,6 +77,16 @@ def build_steps( }, ) return [ + { + "verb": "restore_snapshot_rows", + "enabled": True, + "args": { + "column": "clustered_graph_filepath", + "to": "clustered_graph", + "formats": [{"format": "text", "extension": "graphml"}], + }, + "input": {"source": "workflow:create_base_entity_graph"}, + }, { "id": "laid_out_entity_graph", "verb": "layout_graph", @@ -87,7 +97,7 @@ def build_steps( "graph_to": "positioned_graph", **layout_graph_config, }, - "input": {"source": "workflow:create_base_entity_graph"}, + # "input": {"source": "workflow:create_base_entity_graph"}, }, { "verb": "unpack_graph", diff --git a/graphrag/index/workflows/v1/create_final_relationships.py b/graphrag/index/workflows/v1/create_final_relationships.py index a58c2a45..9228df86 100644 --- a/graphrag/index/workflows/v1/create_final_relationships.py +++ b/graphrag/index/workflows/v1/create_final_relationships.py @@ -24,13 +24,23 @@ def build_steps( skip_description_embedding = config.get("skip_description_embedding", False) return [ + { + "verb": "restore_snapshot_rows", + "enabled": True, + "args": { + "column": "clustered_graph_filepath", + "to": "clustered_graph", + "formats": [{"format": "text", "extension": "graphml"}], + }, + "input": {"source": "workflow:create_base_entity_graph"}, + }, { "verb": "unpack_graph", "args": { "column": "clustered_graph", "type": "edges", }, - "input": {"source": "workflow:create_base_entity_graph"}, + # "input": {"source": "workflow:create_base_entity_graph"}, }, { "verb": "rename", diff --git a/graphrag/index/workflows/v1/create_summarized_entities.py b/graphrag/index/workflows/v1/create_summarized_entities.py index 8f8d7f00..19a2b92a 100644 --- a/graphrag/index/workflows/v1/create_summarized_entities.py +++ b/graphrag/index/workflows/v1/create_summarized_entities.py @@ -20,9 +20,19 @@ def build_steps( * `workflow:create_base_text_units` """ summarize_descriptions_config = config.get("summarize_descriptions", {}) - graphml_snapshot_enabled = config.get("graphml_snapshot", False) or False + graphml_snapshot_enabled = True return [ + { + "verb": "restore_snapshot_rows", + "enabled": graphml_snapshot_enabled, + "args": { + "column": "filepath", + "to": "entity_graph", + "formats": [{"format": "text", "extension": "graphml"}], + }, + "input": {"source": "workflow:create_base_extracted_entities"}, + }, { "verb": "summarize_descriptions", "args": { @@ -33,7 +43,6 @@ def build_steps( "async_mode", AsyncType.AsyncIO ), }, - "input": {"source": "workflow:create_base_extracted_entities"}, }, { "verb": "snapshot_rows", @@ -41,7 +50,14 @@ def build_steps( "args": { "base_name": "summarized_graph", "column": "entity_graph", + "to": "filepath", "formats": [{"format": "text", "extension": "graphml"}], }, }, + { + "verb": "select", + "args": { + "columns": (["filepath"]), + }, + }, ]