diff --git a/golem/core/dag/graph.py b/golem/core/dag/graph.py
index 2bcbe062..e6d2ca5e 100644
--- a/golem/core/dag/graph.py
+++ b/golem/core/dag/graph.py
@@ -1,4 +1,5 @@
 from abc import ABC, abstractmethod
+from enum import Enum
 from os import PathLike
 from typing import Dict, List, Optional, Sequence, Union, Tuple, TypeVar

@@ -8,6 +9,13 @@
 NodeType = TypeVar('NodeType', bound=GraphNode, covariant=False, contravariant=False)


+class ReconnectType(Enum):
+    """Defines how the parents of a removed node are reconnected to its children. Used by mutations."""
+    none = 'none'  # do not reconnect predecessors
+    single = 'single'  # reconnect all predecessors only if the removed node has a single successor
+    all = 'all'  # reconnect all predecessors to all successors
+
+
 class Graph(ABC):
     """Defines abstract graph interface that's required by graph optimisation process.
     """
@@ -41,12 +49,13 @@ def update_subtree(self, old_subtree: GraphNode, new_subtree: GraphNode):
         raise NotImplementedError()

     @abstractmethod
-    def delete_node(self, node: GraphNode):
+    def delete_node(self, node: GraphNode, reconnect: ReconnectType = ReconnectType.single):
         """Removes ``node`` from the graph.
-        If ``node`` has only one child, then connects all of the ``node`` parents to it.
+        By default, if ``node`` has only one child, then connects all of the ``node`` parents to it.

         Args:
             node: node of the graph to be deleted
+            reconnect: defines how the parents of ``node`` are reconnected to its children (see ``ReconnectType``)
         """
         raise NotImplementedError()

@@ -84,7 +93,7 @@ def connect_nodes(self, node_parent: GraphNode, node_child: GraphNode):

     @abstractmethod
     def disconnect_nodes(self, node_parent: GraphNode, node_child: GraphNode,
-                         clean_up_leftovers: bool = True):
+                         clean_up_leftovers: bool = False):
         """Removes an edge between two nodes

         Args:
diff --git a/golem/core/dag/graph_delegate.py b/golem/core/dag/graph_delegate.py
index a0905434..765ab19e 100644
--- a/golem/core/dag/graph_delegate.py
+++ b/golem/core/dag/graph_delegate.py
@@ -1,6 +1,6 @@
 from typing import Union, Sequence, List, Optional, Tuple, Type

-from golem.core.dag.graph import Graph
+from golem.core.dag.graph import Graph, ReconnectType
 from golem.core.dag.graph_node import GraphNode
 from golem.core.dag.linked_graph import LinkedGraph

@@ -26,8 +26,8 @@ def update_node(self, old_node: GraphNode, new_node: GraphNode):
     def update_subtree(self, old_subtree: GraphNode, new_subtree: GraphNode):
         self.operator.update_subtree(old_subtree, new_subtree)

-    def delete_node(self, node: GraphNode):
-        self.operator.delete_node(node)
+    def delete_node(self, node: GraphNode, reconnect: ReconnectType = ReconnectType.single):
+        self.operator.delete_node(node, reconnect)

     def delete_subtree(self, subtree: GraphNode):
         self.operator.delete_subtree(subtree)
@@ -39,7 +39,7 @@ def connect_nodes(self, node_parent: GraphNode, node_child: GraphNode):
         self.operator.connect_nodes(node_parent, node_child)

     def disconnect_nodes(self, node_parent: GraphNode, node_child: GraphNode,
-                         clean_up_leftovers: bool = True):
+                         clean_up_leftovers: bool = False):
         self.operator.disconnect_nodes(node_parent, node_child, clean_up_leftovers)

     def get_edges(self) -> Sequence[Tuple[GraphNode, GraphNode]]:
diff --git a/golem/core/dag/linked_graph.py b/golem/core/dag/linked_graph.py
index fc0b6d50..ee1af54d 100644
--- a/golem/core/dag/linked_graph.py
+++ b/golem/core/dag/linked_graph.py
@@ -3,7 +3,7 @@

 from networkx import graph_edit_distance, set_node_attributes

-from golem.core.dag.graph import Graph
+from golem.core.dag.graph import Graph, ReconnectType
 from golem.core.dag.graph_node import GraphNode
 from golem.core.dag.graph_utils import ordered_subnodes_hierarchy, node_depth
 from golem.core.dag.convert import graph_structure_as_nx_graph
@@ -34,19 +34,25 @@ def _empty_postprocess(*args):
         pass

     @copy_doc(Graph.delete_node)
-    def delete_node(self, node: GraphNode):
+    def delete_node(self, node: GraphNode, reconnect: ReconnectType = ReconnectType.single):
+        reconnect = ReconnectType(reconnect)  # fail fast on unknown values, before the graph is modified
         node_children_cached = self.node_children(node)
         self._nodes.remove(node)

         for node_child in node_children_cached:
             node_child.nodes_from.remove(node)

-        # if removed node had a single child
-        # then reconnect it to preceding parent nodes.
-        if node.nodes_from and len(node_children_cached) == 1:
-            child = node_children_cached[0]
-            for node_from in node.nodes_from:
-                child.nodes_from.append(node_from)
+        if reconnect == ReconnectType.single:
+            # if removed node had a single child
+            # then reconnect it to preceding parent nodes.
+            if node.nodes_from and len(node_children_cached) == 1:
+                child = node_children_cached[0]
+                child.nodes_from.extend(node.nodes_from)
+        elif reconnect == ReconnectType.all:
+            if node.nodes_from:
+                for child in node_children_cached:
+                    child.nodes_from.extend(node.nodes_from)
+        # with ReconnectType.none, parents and children of the removed node stay disconnected

         self._postprocess_nodes(self, self._nodes)

@@ -126,7 +132,7 @@ def _clean_up_leftovers(self, node: GraphNode):

     @copy_doc(Graph.disconnect_nodes)
     def disconnect_nodes(self, node_parent: GraphNode, node_child: GraphNode,
-                         clean_up_leftovers: bool = True):
+                         clean_up_leftovers: bool = False):
         if node_parent not in node_child.nodes_from:
             return
         if node_parent not in self._nodes or node_child not in self._nodes:
diff --git a/golem/core/optimisers/advisor.py b/golem/core/optimisers/advisor.py
index dbc2905f..0d276265 100644
--- a/golem/core/optimisers/advisor.py
+++ b/golem/core/optimisers/advisor.py
@@ -6,10 +6,11 @@

 class RemoveType(Enum):
     """Defines allowed kinds of removals in Graph. Used by mutations."""
+    forbidden = 'forbidden'  # the node must not be removed
-    node_only = 'node_only'
+    node_only = 'node_only'  # remove only the node, leaving its parents and children disconnected
+    node_rewire = 'node_rewire'  # remove the node and connect its parents to all of its children
     with_direct_children = 'with_direct_children'
     with_parents = 'with_parents'
-    forbidden = 'forbidden'


 class DefaultChangeAdvisor:
@@ -24,7 +25,7 @@ def propose_change(self, node: OptNode, possible_operations: List[Any]) -> List[
         return possible_operations

     def can_be_removed(self, node: OptNode) -> RemoveType:
-        return RemoveType.node_only
+        return RemoveType.node_rewire

     def propose_parent(self, node: OptNode, possible_operations: List[Any]) -> List[Any]:
         return possible_operations
diff --git a/golem/core/optimisers/genetic/operators/base_mutations.py b/golem/core/optimisers/genetic/operators/base_mutations.py
index c680dc26..9030bb26 100644
--- a/golem/core/optimisers/genetic/operators/base_mutations.py
+++ b/golem/core/optimisers/genetic/operators/base_mutations.py
@@ -4,6 +4,7 @@
 from typing import TYPE_CHECKING

 from golem.core.adapter import register_native
+from golem.core.dag.graph import ReconnectType
 from golem.core.dag.graph_node import GraphNode
 from golem.core.dag.graph_utils import distance_to_root_level, ordered_subnodes_hierarchy, distance_to_primary_level
 from golem.core.optimisers.advisor import RemoveType
@@ -171,7 +172,8 @@ def add_as_child(graph: OptGraph,
     graph.connect_nodes(node_parent=node_to_mutate, node_child=new_node)
     if new_node_child:
         graph.connect_nodes(node_parent=new_node, node_child=new_node_child)
-        graph.disconnect_nodes(node_parent=node_to_mutate, node_child=new_node_child)
+        graph.disconnect_nodes(node_parent=node_to_mutate, node_child=new_node_child,
+                               clean_up_leftovers=True)
     return graph


@@ -246,18 +248,17 @@ def single_drop_mutation(graph: OptGraph,
              if n.descriptive_id.count('data_source') == 1 and node_name in n.descriptive_id]

         for child_node in nodes_to_delete:
-            graph.delete_node(child_node)
+            graph.delete_node(child_node, reconnect=ReconnectType.all)
     elif removal_type == RemoveType.with_parents:
         graph.delete_subtree(node_to_del)
-    elif removal_type != RemoveType.forbidden:
-        graph.delete_node(node_to_del)
-        if node_to_del.nodes_from:
-            children = graph.node_children(node_to_del)
-            for child in children:
-                if child.nodes_from:
-                    child.nodes_from.extend(node_to_del.nodes_from)
-                else:
-                    child.nodes_from = node_to_del.nodes_from
+    elif removal_type == RemoveType.node_rewire:
+        graph.delete_node(node_to_del, reconnect=ReconnectType.all)
+    elif removal_type == RemoveType.node_only:
+        graph.delete_node(node_to_del, reconnect=ReconnectType.none)
+    elif removal_type == RemoveType.forbidden:
+        pass
+    else:
+        raise ValueError(f"Unknown advice (RemoveType) returned by Advisor: {removal_type}")
     return graph


diff --git a/test/unit/dag/test_graph.py b/test/unit/dag/test_graph.py
index fccbee97..34412cad 100644
--- a/test/unit/dag/test_graph.py
+++ b/test/unit/dag/test_graph.py
@@ -4,7 +4,7 @@
 import numpy as np
 import pytest

-from golem.core.dag.graph import Graph
+from golem.core.dag.graph import Graph, ReconnectType
 from golem.core.dag.graph_delegate import GraphDelegate
 from golem.core.dag.linked_graph import LinkedGraph
 from golem.core.dag.linked_graph_node import LinkedGraphNode
@@ -118,6 +118,51 @@ def test_delete_intermediate_node():
     assert graph.depth == 2


+def test_delete_leave_cycle():
+    first = GraphNode(content='n1')
+    second = GraphNode(content='n2', nodes_from=[first])
+    third = GraphNode(content='n3', nodes_from=[second])
+    final = GraphNode(content='n4', nodes_from=[third])
+    graph = GraphImpl(final)
+    graph.connect_nodes(final, first)
+
+    assert len(graph.get_edges()) == 4
+
+    graph.delete_node(third, reconnect=ReconnectType.single)
+
+    assert third not in graph.nodes
+    assert len(graph.get_edges()) == 3
+    assert (second, final) in graph.get_edges()
+
+
+def test_delete_break_cycle():
+    first = GraphNode(content='n1')
+    second = GraphNode(content='n2', nodes_from=[first])
+    third = GraphNode(content='n3', nodes_from=[second])
+    final = GraphNode(content='n4', nodes_from=[third])
+    graph = GraphImpl(final)
+    graph.connect_nodes(final, first)
+
+    assert len(graph.get_edges()) == 4
+
+    graph.delete_node(third, reconnect=ReconnectType.none)
+
+    assert third not in graph.nodes
+    assert len(graph.get_edges()) == 2
+    assert not final.nodes_from
+
+
+def test_delete_node_invalid_reconnect():
+    first = GraphNode(content='n1')
+    second = GraphNode(content='n2', nodes_from=[first])
+    graph = GraphImpl(second)
+
+    with pytest.raises(ValueError):
+        graph.delete_node(first, reconnect='unknown')
+
+    assert first in graph.nodes
+
+
 def test_delete_node_with_duplicated_edges():
     ok_primary_node = GraphNode('n1')
     bad_primary_node = GraphNode('n2')
diff --git a/test/unit/dag/test_graph_operator.py b/test/unit/dag/test_graph_operator.py
index 9b1d6be0..d855b59b 100644
--- a/test/unit/dag/test_graph_operator.py
+++ b/test/unit/dag/test_graph_operator.py
@@ -184,7 +184,7 @@ def test_disconnect_nodes_method_first():
     node_e = graph.nodes[4]
     node_e_root = graph.nodes[0]

-    graph.disconnect_nodes(node_e, node_e_root)
+    graph.disconnect_nodes(node_e, node_e_root, clean_up_leftovers=True)

     assert res_graph == graph

@@ -197,7 +197,7 @@ def test_disconnect_nodes_method_second():
     node_b = graph.nodes[5]
     node_e = graph.nodes[4]

-    graph.disconnect_nodes(node_b, node_e)
+    graph.disconnect_nodes(node_b, node_e, clean_up_leftovers=True)

     assert res_graph == graph

@@ -210,7 +210,7 @@ def test_disconnect_nodes_method_third():
     node_d = graph.nodes[1]
     root_node_e = graph.nodes[0]

-    graph.disconnect_nodes(node_d, root_node_e)
+    graph.disconnect_nodes(node_d, root_node_e, clean_up_leftovers=True)

     assert res_graph == graph

@@ -224,7 +224,7 @@ def test_disconnect_nodes_method_fourth():
     node_c = res_graph.nodes[2]
     root_node_e = res_graph.nodes[0]

-    res_graph.disconnect_nodes(node_c, root_node_e)
+    res_graph.disconnect_nodes(node_c, root_node_e, clean_up_leftovers=True)

     assert res_graph == graph

@@ -237,7 +237,7 @@ def test_disconnect_nodes_method_fifth():
     node_k = LinkedGraphNode('k')
     node_m = LinkedGraphNode('m', nodes_from=[node_k])

-    res_graph.disconnect_nodes(node_k, node_m)
+    res_graph.disconnect_nodes(node_k, node_m, clean_up_leftovers=True)

     assert res_graph == graph
