Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 20 additions & 27 deletions burr/core/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -2065,7 +2065,6 @@ def __init__(self):
self.loaded_from_fork: bool = False
self.tracker = None
self.graph_builder = None
self.prebuilt_graph = None
self.typing_system = None
self.parallel_executor_factory = None
self.state_persister = None
Expand Down Expand Up @@ -2143,9 +2142,21 @@ def with_state(
self.state = State(kwargs)
return self

def with_graphs(self, *graphs) -> "ApplicationBuilder[StateType]":
"""Adds multiple prebuilt graphs -- this just calls :py:meth:`with_graph <burr.core.application.ApplicationBuilder.with_graph>`
in a loop! See caveats in :py:meth:`with_graph <burr.core.application.ApplicationBuilder.with_graph>`.

:param graphs: Graphs to add to the application
:return: The application builder for future chaining.
"""
for graph in graphs:
self.with_graph(graph)
return self

def with_graph(self, graph: Graph) -> "ApplicationBuilder[StateType]":
"""Adds a prebuilt graph -- this is an alternative to using the with_actions and with_transitions methods.
While you will likely use with_actions and with_transitions, you may want this in a few cases:
"""Adds a prebuilt graph -- this can work in addition to using with_actions and with_transitions methods.
This will add all nodes + edges from a prebuilt graph to the current graph. Note that if you add two
graphs (or a combination of graphs/nodes/edges), you will need to ensure that there are no node name conflicts.

1. You want to reuse the same graph object for different applications
2. You want the logic that constructs the graph to be separate from that which constructs the application
Expand All @@ -2154,13 +2165,8 @@ def with_graph(self, graph: Graph) -> "ApplicationBuilder[StateType]":
:param graph: Graph object built with the :py:class:`GraphBuilder <burr.core.graph.GraphBuilder>`
:return: The application builder for future chaining.
"""
if self.graph_builder is not None:
raise ValueError(
BASE_ERROR_MESSAGE
+ "You have already called `with_actions`, or `with_transitions` -- you currently "
"cannot use the with_graph method along with that. Use `with_graph` or the other methods, not both"
)
self.prebuilt_graph = graph
self._initialize_graph_builder()
self.graph_builder = self.graph_builder.with_graph(graph)
return self

def with_parallel_executor(self, executor_factory: lambda: Executor):
Expand Down Expand Up @@ -2190,15 +2196,6 @@ def with_parallel_executor(self, executor_factory: lambda: Executor):
self.parallel_executor_factory = executor_factory
return self

def _ensure_no_prebuilt_graph(self):
if self.prebuilt_graph is not None:
raise ValueError(
BASE_ERROR_MESSAGE + "You have already called `with_graph` -- you currently "
"cannot use the with_actions, or with_transitions method along with that. "
"Use `with_graph` or the other methods, not both."
)
return self

def _initialize_graph_builder(self):
if self.graph_builder is None:
self.graph_builder = GraphBuilder()
Expand Down Expand Up @@ -2233,7 +2230,6 @@ def with_actions(
:param action_dict: Actions to add, keyed by name
:return: The application builder for future chaining.
"""
self._ensure_no_prebuilt_graph()
self._initialize_graph_builder()
self.graph_builder = self.graph_builder.with_actions(*action_list, **action_dict)
return self
Expand All @@ -2256,7 +2252,6 @@ def with_transitions(
:param transitions: Transitions to add
:return: The application builder for future chaining.
"""
self._ensure_no_prebuilt_graph()
self._initialize_graph_builder()
self.graph_builder = self.graph_builder.with_transitions(*transitions)
return self
Expand Down Expand Up @@ -2583,15 +2578,13 @@ def reset_to_entrypoint(self):
self.state = self.state.wipe(delete=[PRIOR_STEP])

def _get_built_graph(self) -> Graph:
if self.graph_builder is None and self.prebuilt_graph is None:
if self.graph_builder is None:
raise ValueError(
BASE_ERROR_MESSAGE
+ "You must set the graph using with_graph, or use with_entrypoint, with_actions, and with_transitions"
" to build the graph."
+ "No graph constructs exist. You must call some combination of with_graph, with_entrypoint, "
"with_actions, and with_transitions"
)
if self.graph_builder is not None:
return self.graph_builder.build()
return self.prebuilt_graph
return self.graph_builder.build()

def _build_common(self) -> Application:
graph = self._get_built_graph()
Expand Down
29 changes: 29 additions & 0 deletions burr/core/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,16 @@ def _validate_actions(actions: Optional[List[Action]]):
assert_set(actions, "_actions", "with_actions")
if len(actions) == 0:
raise ValueError("Must have at least one action in the application!")
seen_action_names = set()
for action in actions:
if action.name in seen_action_names:
raise ValueError(
f"Action: {action.name} is duplicated in the actions list. "
"actions have unique names. This could happen"
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider improving the formatting of the error message here. The concatenated strings lack proper spacing between sentences, e.g. add a space or punctuation after happen to separate it from the subsequent clause.

Suggested change
"actions have unique names. This could happen"
"actions have unique names. This could happen. "

"if you add two actions with the same name or add a graph that"
"has actions with the same name as any that already exist."
)
seen_action_names.add(action.name)


def _validate_transitions(
Expand Down Expand Up @@ -321,6 +331,25 @@ def with_transitions(
self.transitions.append((action, to_, condition))
return self

def with_graph(self, graph: Graph) -> "GraphBuilder":
"""Adds an existing graph to the builder. Note that if you have any name clashes
this will error out. This would happen if you add actions with the same name as actions
that already exist.

:param graph: The graph to add
:return: The application builder for future chaining.
"""
if self.actions is None:
self.actions = []
if self.transitions is None:
self.transitions = []
self.actions.extend(graph.actions)
self.transitions.extend(
(transition.from_.name, transition.to.name, transition.condition)
for transition in graph.transitions
)
return self

def build(self) -> Graph:
"""Builds/finalizes the graph.

Expand Down
32 changes: 32 additions & 0 deletions tests/core/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,11 @@ def test__validate_actions_empty():
_validate_actions([])


def test__validate_actions_duplicated():
with pytest.raises(ValueError, match="duplicated"):
_validate_actions([Result("test"), Result("test")])


base_counter_action = PassedInAction(
reads=["count"],
writes=["count"],
Expand All @@ -110,6 +115,33 @@ def test_graph_builder_builds():
assert len(graph.transitions) == 2


def test_graph_builder_with_graph():
graph1 = (
GraphBuilder()
.with_actions(counter=base_counter_action)
.with_transitions(("counter", "counter", Condition.expr("count < 10")))
.build()
)
graph2 = (
GraphBuilder()
.with_actions(counter2=base_counter_action)
.with_transitions(("counter2", "counter2", Condition.expr("count < 20")))
.build()
)
graph = (
GraphBuilder()
.with_graph(graph1)
.with_graph(graph2)
.with_actions(result=Result("count"))
.with_transitions(
("counter", "counter2"),
("counter2", "result"),
)
)
assert len(graph.actions) == 3
assert len(graph.transitions) == 4


def test_graph_builder_get_next_node():
graph = (
GraphBuilder()
Expand Down