Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion graphrag/config/create_graphrag_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,7 +397,7 @@ def hydrate_parallelization_params(
reader.use(values.get("snapshots")),
):
snapshots_model = SnapshotsConfig(
graphml=reader.bool("graphml") or defs.SNAPSHOTS_GRAPHML,
graphml = defs.SNAPSHOTS_GRAPHML,
raw_entities=reader.bool("raw_entities") or defs.SNAPSHOTS_RAW_ENTITIES,
top_level_nodes=reader.bool("top_level_nodes")
or defs.SNAPSHOTS_TOP_LEVEL_NODES,
Expand Down
2 changes: 1 addition & 1 deletion graphrag/config/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@
NODE2VEC_RANDOM_SEED = 597832
REPORTING_TYPE = ReportingType.file
REPORTING_BASE_DIR = "output/${timestamp}/reports"
SNAPSHOTS_GRAPHML = False
SNAPSHOTS_GRAPHML = True
SNAPSHOTS_RAW_ENTITIES = False
SNAPSHOTS_TOP_LEVEL_NODES = False
STORAGE_BASE_DIR = "output/${timestamp}/artifacts"
Expand Down
1 change: 0 additions & 1 deletion graphrag/index/init_content.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,6 @@
enabled: false # if true, will generate UMAP embeddings for nodes

snapshots:
graphml: false
raw_entities: false
top_level_nodes: false

Expand Down
4 changes: 3 additions & 1 deletion graphrag/index/utils/load_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@

"""Networkx load_graph utility definition."""

import io

import networkx as nx


def load_graph(graphml: str | nx.Graph) -> nx.Graph:
"""Load a graph from a graphml file or a networkx graph."""
return nx.parse_graphml(graphml) if isinstance(graphml, str) else graphml
return nx.read_graphml(io.StringIO(graphml)) if isinstance(graphml, str) else graphml
2 changes: 2 additions & 0 deletions graphrag/index/verbs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
unpack_graph,
)
from .overrides import aggregate, concat, merge
from .restore_snapshot_rows import restore_snapshot_rows
from .snapshot import snapshot
from .snapshot_rows import snapshot_rows
from .spread_json import spread_json
Expand All @@ -37,6 +38,7 @@
"layout_graph",
"merge",
"merge_graphs",
"restore_snapshot_rows",
"snapshot",
"snapshot_rows",
"spread_json",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ async def summarize_descriptions(
{
"verb": "",
"args": {
"column": "the_document_text_column_to_extract_descriptions_from", /* Required: This will be a graphml graph in string form which represents the entities and their relationships */
"column": "the_document_text_column_to_extract_descriptions_from", /* Required: This will be a graphml graph filepath in string form which represents the entities and their relationships */
"to": "the_column_to_output_the_summarized_descriptions_to", /* Required: This will be a graphml graph in string form which represents the entities and their relationships after being summarized */
"strategy": {...} <strategy_config>, see strategies section below
}
Expand Down
3 changes: 2 additions & 1 deletion graphrag/index/verbs/graph/clustering/cluster_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

"""A module containing cluster_graph, apply_clustering and run_layout methods definition."""

import io
import logging
from enum import Enum
from random import Random
Expand Down Expand Up @@ -115,7 +116,7 @@ def apply_clustering(
) -> nx.Graph:
"""Apply clustering to a graphml string."""
random = Random(seed) # noqa S311
graph = nx.parse_graphml(graphml)
graph = nx.read_graphml(io.StringIO(graphml))
for community_level, community_id, nodes in communities:
if level == community_level:
for node in nodes:
Expand Down
58 changes: 58 additions & 0 deletions graphrag/index/verbs/restore_snapshot_rows.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""A module containing the restore_snapshot_rows verb definition."""

import json
from pathlib import Path
from typing import Any, cast

from datashaper import TableContainer, VerbInput, verb
from pandas.core.frame import DataFrame
from pandas.core.groupby import DataFrameGroupBy

from graphrag.index.storage import PipelineStorage
from graphrag.index.verbs.snapshot_rows import _parse_formats


@verb(name="restore_snapshot_rows")
async def restore_snapshot_rows(
    input: VerbInput,
    column: str,
    to: str,
    storage: PipelineStorage,
    formats: list[str | dict[str, Any]],
    **_kwargs: dict,
) -> TableContainer:
    """Load per-row snapshot files from storage back into a table column.

    Args:
        input: Verb input wrapping the table to restore into.
        column: Name of the column holding each row's snapshot file path.
        to: Name of the column that receives the restored content.
        storage: Pipeline storage the snapshot files are read from.
        formats: Format specifiers, parsed by ``_parse_formats``. A row is
            restored only by a specifier whose extension matches the suffix
            of that row's file path.
        **_kwargs: Ignored.

    Returns:
        A TableContainer holding a copy of the input table with ``to``
        populated; rows whose ``to`` value is still null are dropped.

    Raises:
        TypeError: If the input is a grouped DataFrame.
        ValueError: If a matching specifier's format is neither "json" nor
            "text".
    """
    source = input.get_input()
    if isinstance(source, DataFrameGroupBy):
        msg = "Cannot restore snapshot rows of a grouped DataFrame"
        raise TypeError(msg)

    data = cast(DataFrame, source)
    parsed_formats = _parse_formats(formats)

    # do not modify the original data
    new_data = data.copy()
    new_data[to] = None
    for row_idx, row in data.iterrows():
        # The path does not depend on the format, so resolve it once per row.
        row_path = Path(row[column])
        extension = row_path.suffix[1:]

        # Only integer row labels are written back; decide this before
        # reading from storage so skipped rows cost no I/O.
        if not isinstance(row_idx, int):
            continue

        # for each row, load only the data in the specified formats
        for fmt in parsed_formats:
            if fmt.extension != extension:
                continue
            if fmt.format not in ("json", "text"):
                msg = f"Unsupported format: {fmt.format}"
                raise ValueError(msg)

            # Only the file name (no directory part) is used as the storage key.
            # NOTE(review): assumes storage.get returns bytes for an existing
            # key -- confirm what it returns when the file is missing.
            data_bytes: bytes | bytearray = await storage.get(
                row_path.name, as_bytes=True
            )
            if fmt.format == "json":
                restored = json.loads(data_bytes)
            else:
                restored = data_bytes.decode("utf-8")

            # .at addresses exactly one cell, so a parsed JSON list/dict is
            # stored as a single value instead of being aligned as an iterable.
            new_data.at[row_idx, to] = restored

    return TableContainer(table=new_data.dropna(subset=[to]))
26 changes: 21 additions & 5 deletions graphrag/index/verbs/snapshot_rows.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@

import json
from dataclasses import dataclass
from typing import Any
from typing import Any, cast

from datashaper import TableContainer, VerbInput, verb
from pandas.core.frame import DataFrame
from pandas.core.groupby import DataFrameGroupBy

from graphrag.index.storage import PipelineStorage

Expand All @@ -24,14 +26,19 @@ class FormatSpecifier:
async def snapshot_rows(
input: VerbInput,
column: str | None,
to: str | None,
base_name: str,
storage: PipelineStorage,
formats: list[str | dict[str, Any]],
row_name_column: str | None = None,
**_kwargs: dict,
) -> TableContainer:
"""Take a by-row snapshot of the tabular data."""
data = input.get_input()
if isinstance(input.get_input(), DataFrameGroupBy):
msg = "Cannot snapshot rows of a grouped DataFrame"
raise TypeError(msg)

data = cast(DataFrame, input.get_input())
parsed_formats = _parse_formats(formats)
num_rows = len(data)

Expand All @@ -42,13 +49,19 @@ def get_row_name(row: Any, row_idx: Any):
return f"{base_name}.{row_idx}"
return f"{base_name}.{row[row_name_column]}"

if to is not None:
# init the table column where the filenames will be stored
data[to] = None

for row_idx, row in data.iterrows():
# for each row, save the data in the specified formats
for fmt in parsed_formats:
row_name = get_row_name(row, row_idx)
extension = fmt.extension
filename = f"{row_name}.{extension}"
if fmt.format == "json":
await storage.set(
f"{row_name}.{extension}",
filename,
(
json.dumps(row[column], ensure_ascii=False)
if column is not None
Expand All @@ -59,8 +72,11 @@ def get_row_name(row: Any, row_idx: Any):
if column is None:
msg = "column must be specified for text format"
raise ValueError(msg)
await storage.set(f"{row_name}.{extension}", str(row[column]))

await storage.set(filename, str(row[column]))

if to is not None and isinstance(row_idx, int):
data.loc[row_idx, to] = filename

return TableContainer(table=data)


Expand Down
22 changes: 16 additions & 6 deletions graphrag/index/workflows/v1/create_base_entity_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

workflow_name = "create_base_entity_graph"


def build_steps(
config: PipelineWorkflowConfig,
) -> list[PipelineWorkflowStep]:
Expand Down Expand Up @@ -35,26 +34,36 @@ def build_steps(
},
)

graphml_snapshot_enabled = config.get("graphml_snapshot", False) or False
graphml_snapshot_enabled = True
embed_graph_enabled = config.get("embed_graph_enabled", False) or False

return [
{
"verb": "restore_snapshot_rows",
"enabled": graphml_snapshot_enabled,
"args": {
"column": "filepath",
"to": "entity_graph",
"formats": [{"format": "text", "extension": "graphml"}],
},
"input": ({"source": "workflow:create_summarized_entities"}),
},
{
"verb": "cluster_graph",
"args": {
**clustering_config,
"column": "entity_graph",
"to": "clustered_graph",
"level_to": "level",
},
"input": ({"source": "workflow:create_summarized_entities"}),
}
},
{
"verb": "snapshot_rows",
"enabled": graphml_snapshot_enabled,
"args": {
"base_name": "clustered_graph",
"column": "clustered_graph",
"to": "clustered_graph_filepath",
"formats": [{"format": "text", "extension": "graphml"}],
},
},
Expand All @@ -73,6 +82,7 @@ def build_steps(
"args": {
"base_name": "embedded_graph",
"column": "entity_graph",
"to": "embedded_graph_filepath",
"formats": [{"format": "text", "extension": "graphml"}],
},
},
Expand All @@ -82,9 +92,9 @@ def build_steps(
# only selecting for documentation sake, so we know what is contained in
# this workflow
"columns": (
["level", "clustered_graph", "embeddings"]
["level", "clustered_graph_filepath", "embeddings"]
if embed_graph_enabled
else ["level", "clustered_graph"]
else ["level", "clustered_graph_filepath"]
),
},
},
Expand Down
11 changes: 10 additions & 1 deletion graphrag/index/workflows/v1/create_base_extracted_entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def build_steps(
* `workflow:create_base_text_units`
"""
entity_extraction_config = config.get("entity_extract", {})
graphml_snapshot_enabled = config.get("graphml_snapshot", False) or False
graphml_snapshot_enabled = True
raw_entity_snapshot_enabled = config.get("raw_entity_snapshot", False) or False

return [
Expand Down Expand Up @@ -84,12 +84,21 @@ def build_steps(
},
},
{
# To-Do: update the snapshot_rows verb to include a "filename" column that
# we can use to locate the stored graphml file in the future.
"verb": "snapshot_rows",
"enabled": graphml_snapshot_enabled,
"args": {
"base_name": "merged_graph",
"column": "entity_graph",
"to": "filepath",
"formats": [{"format": "text", "extension": "graphml"}],
},
},
{
"verb": "select",
"args": {
"columns": (["filepath"]),
},
},
]
15 changes: 13 additions & 2 deletions graphrag/index/workflows/v1/create_final_communities.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,25 @@ def build_steps(
* `workflow:create_base_entity_graph`
"""
return [
{
"id": "local:create_base_entity_graph",
"verb": "restore_snapshot_rows",
"enabled": True,
"args": {
"column": "clustered_graph_filepath",
"to": "clustered_graph",
"formats": [{"format": "text", "extension": "graphml"}],
},
"input": {"source": "workflow:create_base_entity_graph"},
},
{
"id": "graph_nodes",
"verb": "unpack_graph",
"args": {
"column": "clustered_graph",
"type": "nodes",
},
"input": {"source": "workflow:create_base_entity_graph"},
"input": {"source": "local:create_base_entity_graph"},
},
{
"id": "graph_edges",
Expand All @@ -34,7 +45,7 @@ def build_steps(
"column": "clustered_graph",
"type": "edges",
},
"input": {"source": "workflow:create_base_entity_graph"},
"input": {"source": "local:create_base_entity_graph"},
},
{
"id": "source_clusters",
Expand Down
12 changes: 11 additions & 1 deletion graphrag/index/workflows/v1/create_final_entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,23 @@ def build_steps(
)

return [
{
"verb": "restore_snapshot_rows",
"enabled": True,
"args": {
"column": "clustered_graph_filepath",
"to": "clustered_graph",
"formats": [{"format": "text", "extension": "graphml"}],
},
"input": {"source": "workflow:create_base_entity_graph"},
},
{
"verb": "unpack_graph",
"args": {
"column": "clustered_graph",
"type": "nodes",
},
"input": {"source": "workflow:create_base_entity_graph"},

},
{"verb": "rename", "args": {"columns": {"label": "title"}}},
{
Expand Down
12 changes: 11 additions & 1 deletion graphrag/index/workflows/v1/create_final_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,16 @@ def build_steps(
},
)
return [
{
"verb": "restore_snapshot_rows",
"enabled": True,
"args": {
"column": "clustered_graph_filepath",
"to": "clustered_graph",
"formats": [{"format": "text", "extension": "graphml"}],
},
"input": {"source": "workflow:create_base_entity_graph"},
},
{
"id": "laid_out_entity_graph",
"verb": "layout_graph",
Expand All @@ -87,7 +97,7 @@ def build_steps(
"graph_to": "positioned_graph",
**layout_graph_config,
},
"input": {"source": "workflow:create_base_entity_graph"},
# "input": {"source": "workflow:create_base_entity_graph"},
},
{
"verb": "unpack_graph",
Expand Down
Loading