diff --git a/burr/core/application.py b/burr/core/application.py index 66dc465c8..a4c15a5a0 100644 --- a/burr/core/application.py +++ b/burr/core/application.py @@ -2065,7 +2065,6 @@ def __init__(self): self.loaded_from_fork: bool = False self.tracker = None self.graph_builder = None - self.prebuilt_graph = None self.typing_system = None self.parallel_executor_factory = None self.state_persister = None @@ -2143,9 +2142,21 @@ def with_state( self.state = State(kwargs) return self + def with_graphs(self, *graphs) -> "ApplicationBuilder[StateType]": + """Adds multiple prebuilt graphs -- this just calls :py:meth:`with_graph ` + in a loop! See caveats in :py:meth:`with_graph `. + + :param graphs: Graphs to add to the application + :return: The application builder for future chaining. + """ + for graph in graphs: + self.with_graph(graph) + return self + def with_graph(self, graph: Graph) -> "ApplicationBuilder[StateType]": - """Adds a prebuilt graph -- this is an alternative to using the with_actions and with_transitions methods. - While you will likely use with_actions and with_transitions, you may want this in a few cases: + """Adds a prebuilt graph -- this can work in addition to using with_actions and with_transitions methods. + This will add all nodes + edges from a prebuilt graph to the current graph. Note that if you add two + graphs (or a combination of graphs/nodes/edges), you will need to ensure that there are no node name conflicts. 1. You want to reuse the same graph object for different applications 2. You want the logic that constructs the graph to be separate from that which constructs the application @@ -2154,13 +2165,8 @@ def with_graph(self, graph: Graph) -> "ApplicationBuilder[StateType]": :param graph: Graph object built with the :py:class:`GraphBuilder ` :return: The application builder for future chaining. """ - if self.graph_builder is not None: - raise ValueError( - BASE_ERROR_MESSAGE - + "You have already called `with_actions`, or `with_transitions` -- you currently " - "cannot use the with_graph method along with that. Use `with_graph` or the other methods, not both" - ) - self.prebuilt_graph = graph + self._initialize_graph_builder() + self.graph_builder = self.graph_builder.with_graph(graph) return self def with_parallel_executor(self, executor_factory: lambda: Executor): @@ -2190,15 +2196,6 @@ def with_parallel_executor(self, executor_factory: lambda: Executor): self.parallel_executor_factory = executor_factory return self - def _ensure_no_prebuilt_graph(self): - if self.prebuilt_graph is not None: - raise ValueError( - BASE_ERROR_MESSAGE + "You have already called `with_graph` -- you currently " - "cannot use the with_actions, or with_transitions method along with that. " - "Use `with_graph` or the other methods, not both." - ) - return self - def _initialize_graph_builder(self): if self.graph_builder is None: self.graph_builder = GraphBuilder() @@ -2233,7 +2230,6 @@ def with_actions( :param action_dict: Actions to add, keyed by name :return: The application builder for future chaining. """ - self._ensure_no_prebuilt_graph() self._initialize_graph_builder() self.graph_builder = self.graph_builder.with_actions(*action_list, **action_dict) return self @@ -2256,7 +2252,6 @@ def with_transitions( :param transitions: Transitions to add :return: The application builder for future chaining. """ - self._ensure_no_prebuilt_graph() self._initialize_graph_builder() self.graph_builder = self.graph_builder.with_transitions(*transitions) return self @@ -2583,15 +2578,13 @@ def reset_to_entrypoint(self): self.state = self.state.wipe(delete=[PRIOR_STEP]) def _get_built_graph(self) -> Graph: - if self.graph_builder is None and self.prebuilt_graph is None: + if self.graph_builder is None: raise ValueError( BASE_ERROR_MESSAGE - + "You must set the graph using with_graph, or use with_entrypoint, with_actions, and with_transitions" - " to build the graph." + + "No graph constructs exist. You must call some combination of with_graph, with_entrypoint, " + "with_actions, and with_transitions" ) - if self.graph_builder is not None: - return self.graph_builder.build() - return self.prebuilt_graph + return self.graph_builder.build() def _build_common(self) -> Application: graph = self._get_built_graph() diff --git a/burr/core/graph.py b/burr/core/graph.py index 6b1ebba6b..a3a227dba 100644 --- a/burr/core/graph.py +++ b/burr/core/graph.py @@ -26,6 +26,16 @@ def _validate_actions(actions: Optional[List[Action]]): assert_set(actions, "_actions", "with_actions") if len(actions) == 0: raise ValueError("Must have at least one action in the application!") + seen_action_names = set() + for action in actions: + if action.name in seen_action_names: + raise ValueError( + f"Action: {action.name} is duplicated in the actions list. " + "actions have unique names. This could happen" + "if you add two actions with the same name or add a graph that" + "has actions with the same name as any that already exist." + ) + seen_action_names.add(action.name) def _validate_transitions( @@ -321,6 +331,25 @@ def with_transitions( self.transitions.append((action, to_, condition)) return self + def with_graph(self, graph: Graph) -> "GraphBuilder": + """Adds an existing graph to the builder. Note that if you have any name clashes + this will error out. This would happen if you add actions with the same name as actions + that already exist. + + :param graph: The graph to add + :return: The application builder for future chaining. + """ + if self.actions is None: + self.actions = [] + if self.transitions is None: + self.transitions = [] + self.actions.extend(graph.actions) + self.transitions.extend( + (transition.from_.name, transition.to.name, transition.condition) + for transition in graph.transitions + ) + return self + def build(self) -> Graph: """Builds/finalizes the graph. diff --git a/tests/core/test_graph.py b/tests/core/test_graph.py index c2013ee6f..1dbba42e6 100644 --- a/tests/core/test_graph.py +++ b/tests/core/test_graph.py @@ -87,6 +87,11 @@ def test__validate_actions_empty(): _validate_actions([]) +def test__validate_actions_duplicated(): + with pytest.raises(ValueError, match="duplicated"): + _validate_actions([Result("test"), Result("test")]) + + base_counter_action = PassedInAction( reads=["count"], writes=["count"], @@ -110,6 +115,33 @@ def test_graph_builder_builds(): assert len(graph.transitions) == 2 +def test_graph_builder_with_graph(): + graph1 = ( + GraphBuilder() + .with_actions(counter=base_counter_action) + .with_transitions(("counter", "counter", Condition.expr("count < 10"))) + .build() + ) + graph2 = ( + GraphBuilder() + .with_actions(counter2=base_counter_action) + .with_transitions(("counter2", "counter2", Condition.expr("count < 20"))) + .build() + ) + graph = ( + GraphBuilder() + .with_graph(graph1) + .with_graph(graph2) + .with_actions(result=Result("count")) + .with_transitions( + ("counter", "counter2"), + ("counter2", "result"), + ) + ) + assert len(graph.actions) == 3 + assert len(graph.transitions) == 4 + + def test_graph_builder_get_next_node(): graph = ( GraphBuilder()