diff --git a/pipeline/src/collection.py b/pipeline/src/collection.py index 123d51f8..11dfa7df 100644 --- a/pipeline/src/collection.py +++ b/pipeline/src/collection.py @@ -220,3 +220,29 @@ def validate(self, ignore=None): def is_valid(self): failures = self.validate() return len(failures) == 0 + + def sort_nodes_for_upload(self): + """ + Return a list of nodes, sorted so that they can be uploaded to a graph database safely, + i.e., child nodes will be saved before their parents. + + The upload code is assumed to generate @ids and update the Python instances accordingly. + """ + unsorted = set(self.nodes.keys()) + sorted = [] + # initial step: move nodes with no children (downstream links) directly to `sorted` + for node_id in unsorted: + if len(self.nodes[node_id].links) == 0: + sorted.append(node_id) + unsorted -= set(sorted) + # now iteratively add nodes to `sorted` if all their children are already in `sorted` + while len(unsorted) > 0: + newly_sorted = [] + for node_id in unsorted: + child_ids = set(child.id for child in self.nodes[node_id].links) + if not child_ids.difference(sorted): + sorted.append(node_id) + newly_sorted.append(node_id) + unsorted -= set(newly_sorted) + return [self.nodes[node_id] for node_id in sorted] + \ No newline at end of file