|
| 1 | +from __future__ import annotations |
| 2 | +import typing as ty |
| 3 | +import json |
| 4 | +import tempfile |
| 5 | +from pathlib import Path |
| 6 | +import subprocess as sp |
| 7 | +from collections import defaultdict |
| 8 | +import black |
| 9 | +from nipype.interfaces.base import isdefined |
| 10 | +from .utils import load_class_or_func |
| 11 | + |
| 12 | + |
| 13 | +class WorkflowConverter: |
| 14 | + # creating the wf |
| 15 | + def __init__(self, spec): |
| 16 | + self.spec = spec |
| 17 | + |
| 18 | + self.wf = load_class_or_func(self.spec["function"])( |
| 19 | + **self._parse_workflow_args(self.spec["args"]) |
| 20 | + ) |
| 21 | + # loads the 'function' in smriprep.yaml, and implement the args (creates a |
| 22 | + # dictionary) |
| 23 | + |
| 24 | + def node_connections( |
| 25 | + self, |
| 26 | + workflow, |
| 27 | + functions: dict[str, dict], |
| 28 | + # wf_inputs: dict[str, str], |
| 29 | + # wf_outputs: dict[str, str], |
| 30 | + ): |
| 31 | + connections: defaultdict = defaultdict(dict) |
| 32 | + |
| 33 | + # iterates over wf graph, Get connections from workflow graph, store connections |
| 34 | + # in a dictionary |
| 35 | + for edge, props in workflow._graph.edges.items(): |
| 36 | + src_node = edge[0].name |
| 37 | + dest_node = edge[1].name |
| 38 | + dest_node_fullname = workflow.get_node(dest_node).fullname |
| 39 | + for node_conn in props["connect"]: |
| 40 | + src_field = node_conn[0] |
| 41 | + dest_field = node_conn[1] |
| 42 | + if src_field[1].startswith("def"): |
| 43 | + functions[dest_node_fullname][dest_field] = src_field[1] |
| 44 | + else: |
| 45 | + connections[dest_node_fullname][ |
| 46 | + dest_field |
| 47 | + ] = f"{src_node}.lzout.{src_field}" |
| 48 | + |
| 49 | + for nested_wf in workflow._nested_workflows_cache: |
| 50 | + connections.update(self.node_connections(nested_wf, functions=functions)) |
| 51 | + return connections |
| 52 | + |
| 53 | + def generate(self, format_with_black: bool = False): |
| 54 | + |
| 55 | + functions = defaultdict(dict) |
| 56 | + connections = self.node_connections(self.wf, functions=functions) |
| 57 | + out_text = "" |
| 58 | + for node_name in self.wf.list_node_names(): |
| 59 | + node = self.wf.get_node(node_name) |
| 60 | + |
| 61 | + interface_type = type(node.interface) |
| 62 | + |
| 63 | + task_type = interface_type.__module__ + "." + interface_type.__name__ |
| 64 | + node_args = "" |
| 65 | + for arg in node.inputs.visible_traits(): |
| 66 | + val = getattr(node.inputs, arg) # Enclose strings in quotes |
| 67 | + if isdefined(val): |
| 68 | + try: |
| 69 | + val = json.dumps(val) |
| 70 | + except TypeError: |
| 71 | + pass |
| 72 | + if isinstance(val, str) and "\n" in val: |
| 73 | + val = '"""' + val + '""""' |
| 74 | + node_args += f",\n {arg}={val}" |
| 75 | + |
| 76 | + for arg, val in connections[node.fullname].items(): |
| 77 | + node_args += f",\n {arg}=wf.{val}" |
| 78 | + |
| 79 | + out_text += f""" |
| 80 | + wf.add( |
| 81 | + {task_type}( |
| 82 | + name="{node.name}"{node_args} |
| 83 | + ) |
| 84 | + )""" |
| 85 | + |
| 86 | + if format_with_black: |
| 87 | + out_text = black.format_file_contents( |
| 88 | + out_text, fast=False, mode=black.FileMode() |
| 89 | + ) |
| 90 | + return out_text |
| 91 | + |
| 92 | + @classmethod |
| 93 | + def _parse_workflow_args(cls, args): |
| 94 | + dct = {} |
| 95 | + for name, val in args.items(): |
| 96 | + if isinstance(val, dict) and sorted(val.keys()) == ["args", "type"]: |
| 97 | + val = load_class_or_func(val["type"])( |
| 98 | + **cls._parse_workflow_args(val["args"]) |
| 99 | + ) |
| 100 | + dct[name] = val |
| 101 | + return dct |
| 102 | + |
| 103 | + def save_graph( |
| 104 | + self, out_path: Path, format: str = "svg", work_dir: ty.Optional[Path] = None |
| 105 | + ): |
| 106 | + if work_dir is None: |
| 107 | + work_dir = Path(tempfile.mkdtemp()) |
| 108 | + work_dir = Path(work_dir) |
| 109 | + graph_dot_path = work_dir / "wf-graph.dot" |
| 110 | + self.wf.write_hierarchical_dotfile(graph_dot_path) |
| 111 | + dot_path = sp.check_output("which dot", shell=True).decode("utf-8").strip() |
| 112 | + sp.check_call( |
| 113 | + f"{dot_path} -T{format} {str(graph_dot_path)} > {str(out_path)}", shell=True |
| 114 | + ) |
0 commit comments