Source code for syconn.proc.graphs

# -*- coding: utf-8 -*-
# SyConn - Synaptic connectivity inference toolkit
#
# Copyright (c) 2016 - now
# Max Planck Institute of Neurobiology, Martinsried, Germany
# Authors: Philipp Schubert, Joergen Kornfeld
import itertools
from typing import List, Any, Optional, TYPE_CHECKING

import networkx as nx
import numpy as np
import tqdm
from knossos_utils.skeleton import Skeleton, SkeletonAnnotation, SkeletonNode
from scipy import spatial

if TYPE_CHECKING:
    from ..reps.super_segmentation import SuperSegmentationObject
from .. import global_params
from ..mp.mp_utils import start_multiprocess_imap as start_multiprocess


def bfs_smoothing(vertices, vertex_labels, max_edge_length=120, n_voting=40):
    """
    Smooth vertex labels by applying a majority vote on a BFS subset of nodes
    for every node in the graph.

    Args:
        vertices: np.array of shape (N, 3).
        vertex_labels: np.array of shape (N, 1).
        max_edge_length: Maximum distance between vertices to consider them
            connected in the graph.
        n_voting: Number of collected nodes during BFS used for the majority vote.

    Returns:
        np.array: Smoothed vertex labels.
    """
    G = create_graph_from_coords(vertices, max_dist=max_edge_length, mst=False,
                                 force_single_cc=False)
    # create BFS subset
    bfs_nn = split_subcc(G, max_nb=n_voting, verbose=False)
    new_vertex_labels = np.zeros_like(vertex_labels)
    for ii in range(len(vertex_labels)):
        curr_labels = vertex_labels[bfs_nn[ii]]
        labels, counts = np.unique(curr_labels, return_counts=True)
        majority_label = labels[np.argmax(counts)]
        new_vertex_labels[ii] = majority_label
    return new_vertex_labels

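# Usage sketch for `bfs_smoothing` (synthetic inputs; assumes a SyConn
# environment so the module imports resolve):
# >>> import numpy as np
# >>> from syconn.proc.graphs import bfs_smoothing
# >>> vertices = np.random.rand(100, 3) * 500  # 100 vertices in a 500 nm cube
# >>> labels = np.random.randint(0, 2, size=100)  # noisy binary labels
# >>> smoothed = bfs_smoothing(vertices, labels, max_edge_length=120, n_voting=20)
# >>> smoothed.shape
# (100,)
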
def split_subcc(g, max_nb, verbose=False, start_nodes=None):
    """
    Creates a subgraph for each node, consisting of nodes collected via BFS
    until the maximum number of nodes is reached.

    Args:
        g: nx.Graph
        max_nb: Maximum number of BFS-collected nodes per subgraph.
        verbose: bool
        start_nodes: Iterable of node IDs used as BFS seeds. Defaults to all
            nodes in `g`.

    Returns:
        dict: Seed node ID -> list of subgraph node IDs.
    """
    subnodes = {}
    if verbose:
        nb_nodes = g.number_of_nodes()
        pbar = tqdm.tqdm(total=nb_nodes, leave=False)
    if start_nodes is None:
        iter_ixs = g.nodes()
    else:
        iter_ixs = start_nodes
    for n in iter_ixs:
        n_subgraph = [n]
        nb_edges = 0
        for e in nx.bfs_edges(g, n):
            n_subgraph.append(e[1])
            nb_edges += 1
            if nb_edges == max_nb:
                break
        subnodes[n] = n_subgraph
        if verbose:
            pbar.update(1)
    if verbose:
        pbar.close()
    return subnodes

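# Minimal example for `split_subcc` on the path graph 0-1-2-3-4: with
# ``max_nb=2``, each seed keeps itself plus two BFS-reached nodes:
# >>> import networkx as nx
# >>> from syconn.proc.graphs import split_subcc
# >>> sub = split_subcc(nx.path_graph(5), max_nb=2)
# >>> sorted(sub[0])
# [0, 1, 2]
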
def chunkify_contiguous(l, n):
    """Yield successive n-sized chunks from l.
    https://stackoverflow.com/questions/312443/how-do-you-split-a-list-into-evenly-sized-chunks
    """
    for i in range(0, len(l), n):
        yield l[i:i + n]

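# Example:
# >>> from syconn.proc.graphs import chunkify_contiguous
# >>> list(chunkify_contiguous(list(range(7)), 3))
# [[0, 1, 2], [3, 4, 5], [6]]
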
def split_subcc_join(g: nx.Graph, subgraph_size: int, lo_first_n: int = 1) -> List[List[Any]]:
    """
    Creates a subgraph for each node, consisting of nodes collected via
    depth-first traversal until the maximum number of nodes is reached.

    Args:
        g: Supervoxel graph.
        subgraph_size: Size of subgraphs. The difference between `subgraph_size`
            and `lo_first_n` defines the supervoxel overlap.
        lo_first_n: Leave out first n nodes: `subgraph_size` nodes are collected
            starting from a center node, but the first `lo_first_n` nodes are
            not used as new starting nodes.

    Returns:
        List of node lists, one per subgraph (including context nodes).
    """
    start_node = list(g.nodes())[0]
    for n, d in dict(g.degree).items():
        if d == 1:
            start_node = n
            break
    dfs_nodes = list(nx.dfs_preorder_nodes(g, start_node))
    # Get subgraphs by splitting the traversed node list into equally sized
    # fragments. If branch sizes mod `lo_first_n` != 0, a chunk may contain
    # multiple connected components.
    chunks = list(chunkify_contiguous(dfs_nodes, lo_first_n))
    sub_graphs = []
    for ch in chunks:
        # collect all connected component subgraphs
        sg = g.subgraph(ch).copy()
        sub_graphs += list(sg.subgraph(c) for c in nx.connected_components(sg))
    # add more context to subgraphs
    subgraphs_withcontext = []
    for sg in sub_graphs:
        # add context but omit artificial start node
        context_nodes = []
        for n in list(sg.nodes()):
            subgraph_nodes_with_context = []
            nb_edges = sg.number_of_nodes()
            for e in nx.bfs_edges(g, n):
                subgraph_nodes_with_context += list(e)
                nb_edges += 1
                if nb_edges == subgraph_size:
                    break
            context_nodes += subgraph_nodes_with_context
        # add original nodes
        context_nodes = list(set(context_nodes))
        for n in list(sg.nodes()):
            if n in context_nodes:
                context_nodes.remove(n)
        subgraph_nodes_with_context = list(sg.nodes()) + context_nodes
        subgraphs_withcontext.append(subgraph_nodes_with_context)
    return subgraphs_withcontext

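# Sketch of `split_subcc_join` on a path graph: the DFS node list 0..9 is cut
# into ceil(10 / lo_first_n) = 4 chunks, each extended by BFS context up to
# `subgraph_size`:
# >>> import networkx as nx
# >>> from syconn.proc.graphs import split_subcc_join
# >>> parts = split_subcc_join(nx.path_graph(10), subgraph_size=5, lo_first_n=3)
# >>> len(parts)
# 4
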
def merge_nodes(G, nodes, new_node):
    """
    Merges `nodes` into a single node `new_node`, preserving edges to the rest
    of the graph. FOR UNWEIGHTED, UNDIRECTED GRAPHS ONLY.
    """
    if G.is_directed():
        raise ValueError('Method "merge_nodes" is only valid for undirected graphs.')
    G.add_node(new_node)
    for n in nodes:
        for e in list(G.edges(n)):
            # add edge between new node and original partner node
            edge = list(e)
            edge.remove(n)
            paired_node = edge[0]
            if paired_node == new_node:
                # skip edges already rewired to the new node; avoids self-loops
                # when merged nodes are adjacent
                continue
            G.add_edge(new_node, paired_node)
    for n in nodes:
        # remove the merged nodes
        G.remove_node(n)

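# Example for `merge_nodes`: collapsing the two inner nodes of the path
# 0-1-2-3 into a new node 99 rewires the outer edges:
# >>> import networkx as nx
# >>> from syconn.proc.graphs import merge_nodes
# >>> g = nx.path_graph(4)
# >>> merge_nodes(g, [1, 2], 99)
# >>> sorted(g.edges())
# [(0, 99), (3, 99)]
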
def split_glia_graph(nx_g, thresh, clahe=False, nb_cpus=1, pred_key_appendix=""):
    """
    Split graph into glia and non-glia connected components.

    Args:
        nx_g: nx.Graph
        thresh: float
        clahe: bool
        nb_cpus: int
        pred_key_appendix: str

    Returns:
        list, list: Neuron and glia connected components.
    """
    glia_key = "glia_probas"
    if clahe:
        glia_key += "_clahe"
    glia_key += pred_key_appendix
    glianess, size = get_glianess_dict(list(nx_g.nodes()), thresh, glia_key,
                                       nb_cpus=nb_cpus)
    return remove_glia_nodes(nx_g, size, glianess, return_removed_nodes=True)

def split_glia(sso, thresh, clahe=False, pred_key_appendix=""):
    """
    Split SuperSegmentationObject into glia and non-glia SegmentationObjects.

    Args:
        sso: SuperSegmentationObject
        thresh: float
        clahe: bool
        pred_key_appendix: Defines the type of glia predictions.

    Returns:
        list, list: Neuron and glia SegmentationObject nodes.
    """
    nx_G = sso.rag
    nonglia_ccs, glia_ccs = split_glia_graph(
        nx_G, thresh=thresh, clahe=clahe, nb_cpus=sso.nb_cpus,
        pred_key_appendix=pred_key_appendix)
    return nonglia_ccs, glia_ccs

def create_ccsize_dict(g: nx.Graph, bbs: dict,
                       is_connected_components: bool = False) -> dict:
    """
    Calculate the bounding box size of connected components.

    Args:
        g: Supervoxel graph.
        bbs: Bounding boxes (physical units).
        is_connected_components: Set to True if `g` already is a collection of
            connected components. If False, ``nx.connected_components`` is applied.

    Returns:
        Look-up which stores the bounding box diagonal of the corresponding
        connected component for every single node in the input graph `g`.
    """
    if not is_connected_components:
        ccs = nx.connected_components(g)
    else:
        ccs = g
    node2cssize_dict = {}
    for cc in ccs:
        # if an ID is not in bbs, it was skipped due to low voxel count
        curr_bbs = [bbs[n] for n in cc if n in bbs]
        if len(curr_bbs) == 0:
            raise ValueError(f'Could not find a single bounding box for connected '
                             f'component with IDs: {cc}.')
        curr_bbs = np.concatenate(curr_bbs)
        cc_size = np.linalg.norm(np.max(curr_bbs, axis=0) - np.min(curr_bbs, axis=0), ord=2)
        for n in cc:
            node2cssize_dict[n] = cc_size
    return node2cssize_dict

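# Example for `create_ccsize_dict` with two components and per-node bounding
# boxes (hypothetical physical coordinates):
# >>> import networkx as nx
# >>> import numpy as np
# >>> from syconn.proc.graphs import create_ccsize_dict
# >>> g = nx.Graph([(0, 1)]); g.add_node(2)  # components {0, 1} and {2}
# >>> bbs = {0: np.array([[0, 0, 0], [10, 0, 0]]),
# ...        1: np.array([[10, 0, 0], [20, 0, 0]]),
# ...        2: np.array([[0, 0, 0], [5, 0, 0]])}
# >>> cc_sizes = create_ccsize_dict(g, bbs)
# >>> float(cc_sizes[0]), float(cc_sizes[2])
# (20.0, 5.0)
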
def get_glianess_dict(seg_objs, thresh, glia_key, nb_cpus=1, use_sv_volume=False,
                      verbose=False):
    glianess = {}
    sizes = {}
    params = [[so, glia_key, thresh, use_sv_volume] for so in seg_objs]
    res = start_multiprocess(glia_loader_helper, params, nb_cpus=nb_cpus,
                             verbose=verbose, show_progress=verbose)
    for ii, el in enumerate(res):
        so = seg_objs[ii]
        glianess[so] = el[0]
        sizes[so] = el[1]
    return glianess, sizes

def glia_loader_helper(args):
    so, glia_key, thresh, use_sv_volume = args
    if glia_key not in so.attr_dict.keys():
        so.load_attr_dict()
    curr_glianess = so.glia_pred(thresh)
    if not use_sv_volume:
        curr_size = so.mesh_bb
    else:
        curr_size = so.size
    return curr_glianess, curr_size

def remove_glia_nodes(g, size_dict, glia_dict, return_removed_nodes=False):
    """
    Removes glia nodes from the graph `g` based on the glia and size node
    properties and returns the remaining connected components of type neuron.

    Args:
        g: nx.Graph
        size_dict: Bounding boxes for every node.
        glia_dict: Glia predictions for every node.
        return_removed_nodes: Also return the removed (glia) connected components.

    Returns:
        List of lists of nodes: Remaining connected components of type neuron
        and, if `return_removed_nodes` is True, the glia connected components.
    """
    # set up node weights based on glia prediction and size
    # weights = {}
    # e_weights = {}
    # for n in g.nodes():
    #     weights[n] = np.linalg.norm(size_dict[n][1] - size_dict[n][0], ord=2) \
    #                  * glia_dict[n]
    # # set up edge weights based on sum of node weights
    # for e in g.edges():
    #     e_weights[e] = weights[list(e)[0]] + weights[list(e)[1]]
    # nx.set_node_attributes(g, weights, 'weight')
    # nx.set_edge_attributes(g, e_weights, 'weights')

    # get neuron type connected component sizes
    g_neuron = g.copy()
    for n in g.nodes():
        if glia_dict[n] != 0:
            g_neuron.remove_node(n)
    neuron2ccsize_dict = create_ccsize_dict(g_neuron, size_dict)
    if np.all(np.array(list(neuron2ccsize_dict.values()))
              <= global_params.config['min_cc_size_ssv']):  # no significant neuron SV
        if return_removed_nodes:
            return [], [list(g.nodes())]
        return []

    # get glia type connected component sizes
    g_glia = g.copy()
    for n in g.nodes():
        if glia_dict[n] == 0:
            g_glia.remove_node(n)
    glia2ccsize_dict = create_ccsize_dict(g_glia, size_dict)
    if np.all(np.array(list(glia2ccsize_dict.values()))
              <= global_params.config['min_cc_size_ssv']):  # no significant glia SV
        if return_removed_nodes:
            return [list(g.nodes())], []
        return [list(g.nodes())]

    tiny_glia_fragments = []
    for n in g_glia.nodes():
        if glia2ccsize_dict[n] < global_params.config['min_cc_size_ssv']:
            tiny_glia_fragments += [n]

    # create new neuron graph without sufficiently big glia connected components
    g_neuron = g.copy()
    for n in g.nodes():
        if glia_dict[n] != 0 and n not in tiny_glia_fragments:
            g_neuron.remove_node(n)

    # find orphaned neuron SVs and add them to the glia graph
    neuron2ccsize_dict = create_ccsize_dict(g_neuron, size_dict)
    g_tmp = g_neuron.copy()
    for n in g_tmp.nodes():
        if neuron2ccsize_dict[n] < global_params.config['min_cc_size_ssv']:
            g_neuron.remove_node(n)

    # create new glia graph with the remaining nodes
    # (the complementary set of sufficiently big neuron connected components)
    g_glia = g.copy()
    for n in g_neuron.nodes():
        g_glia.remove_node(n)

    neuron_ccs = list(nx.connected_components(g_neuron))
    if return_removed_nodes:
        glia_ccs = list(nx.connected_components(g_glia))
        assert len(g_glia) + len(g_neuron) == len(g)
        return neuron_ccs, glia_ccs
    return neuron_ccs

def glia_path_length(glia_path, glia_dict, write_paths=None):
    """
    Get the path length of glia SVs within `glia_path`. Assumes a single
    connected glia component within this path. Uses the mesh property of each
    SegmentationObject to build a graph from all vertices to find the shortest
    path through (or more precisely: along the surface of) glia. Edges between
    non-glia vertices have negligible distance (0.0001) to ensure shortest
    paths along non-glia surfaces.

    Args:
        glia_path: List of SegmentationObjects.
        glia_dict: Dictionary which keys the SegmentationObjects in `glia_path`
            and returns their glia prediction.
        write_paths: Destination folder of the shortest path skeleton (k.zip).
            If None, no file is written.

    Returns:
        float: Shortest path between neuron type nodes in nm.
    """
    g = nx.Graph()
    col = {}
    curr_ind = 0
    if write_paths is not None:
        all_vert = np.zeros((0, 3))
    for so in glia_path:
        is_glia_sv = int(glia_dict[so] > 0)
        ind, vert = so.mesh
        # connect meshes of different SVs, starts after the first SV
        if curr_ind > 0:
            # build kd-tree from vertices of the previous SV
            kd_tree = spatial.cKDTree(vert_resh)
            # get indices of vertices of the previous SV (= indices of graph nodes)
            ind_offset_before = curr_ind - len(vert_resh)
            # query vertices of the current mesh to find close connections
            next_vert_resh = vert.reshape((-1, 3))
            dists, ixs = kd_tree.query(next_vert_resh, distance_upper_bound=500)
            for kk, ix in enumerate(ixs):
                if dists[kk] > 500:
                    continue
                if is_glia_sv:
                    edge_weight = eucl_dist(next_vert_resh[kk], vert_resh[ix])
                else:
                    edge_weight = 0.0001
                g.add_edge(curr_ind + kk, ind_offset_before + ix,
                           weights=edge_weight)
        vert_resh = vert.reshape((-1, 3))
        # save all vertices for writing the shortest path skeleton
        if write_paths is not None:
            all_vert = np.concatenate([all_vert, vert_resh])
        # connect fragments of the SV mesh
        kd_tree = spatial.cKDTree(vert_resh)
        dists, ixs = kd_tree.query(vert_resh, k=20, distance_upper_bound=500)
        for kk in range(len(ixs)):
            nn_ixs = ixs[kk]
            nn_dists = dists[kk]
            col[curr_ind + kk] = glia_dict[so]
            for curr_ix, curr_dist in zip(nn_ixs, nn_dists):
                col[curr_ind + curr_ix] = glia_dict[so]
                if is_glia_sv:
                    dist = curr_dist
                else:  # only take the path through glia into account
                    dist = 0
                g.add_edge(kk + curr_ind, curr_ix + curr_ind, weights=dist)
        curr_ind += len(vert_resh)
    start_ix = 0  # choose any index of the first mesh
    end_ix = curr_ind - 1  # choose any index of the last mesh
    shortest_path_length = nx.dijkstra_path_length(g, start_ix, end_ix,
                                                   weight="weights")
    if write_paths is not None:
        shortest_path = nx.dijkstra_path(g, start_ix, end_ix, weight="weights")
        anno = coordpath2anno([all_vert[ix] for ix in shortest_path])
        anno.setComment("{0:.4}".format(shortest_path_length))
        skel = Skeleton()
        skel.add_annotation(anno)
        skel.to_kzip("{}/{:.4}_vertpath.k.zip".format(write_paths,
                                                      shortest_path_length))
    return shortest_path_length

def eucl_dist(a, b):
    return np.linalg.norm(a - b)

def get_glia_paths(g, glia_dict, node2ccsize_dict, min_cc_size_neuron,
                   node2ccsize_dict_glia, min_cc_size_glia):
    """
    Find paths between neuron type SV graph nodes which contain glia nodes.
    Currently not in use; refactoring needed.

    Args:
        g: nx.Graph
        glia_dict:
        node2ccsize_dict:
        min_cc_size_neuron:
        node2ccsize_dict_glia:
        min_cc_size_glia:

    Returns:
        Paths (lists of nodes) between neuron type end nodes which contain
        glia nodes.
    """
    end_nodes = []
    paths = dict(nx.all_pairs_dijkstra_path(g, weight="weights"))
    for n, d in dict(g.degree).items():
        if d == 1 and glia_dict[n] == 0 and node2ccsize_dict[n] > min_cc_size_neuron:
            end_nodes.append(n)

    # find all nodes along these paths and store them as mandatory nodes
    glia_paths = []
    glia_svixs_in_paths = []
    for a, b in itertools.combinations(end_nodes, 2):
        glia_nodes = [n for n in paths[a][b] if glia_dict[n] != 0]
        if len(glia_nodes) == 0:
            continue
        sv_ccsizes = [node2ccsize_dict_glia[n] for n in glia_nodes]
        if np.max(sv_ccsizes) <= min_cc_size_glia:  # check minimum glia size
            continue
        sv_ixs = np.array([n.id for n in glia_nodes])
        glia_nodes_already_exist = False
        for el_ixs in glia_svixs_in_paths:
            if np.all(sv_ixs == el_ixs):
                glia_nodes_already_exist = True
                break
        if glia_nodes_already_exist:  # check if the same glia path already exists
            continue
        glia_paths.append(paths[a][b])
        glia_svixs_in_paths.append(sv_ixs)
    return glia_paths

def write_sopath2skeleton(so_path, dest_path, scaling=None, comment=None):
    """
    Writes a very simple skeleton in which each node represents the center of
    mass of a SV; edges are created in list order.

    Args:
        so_path: List of SegmentationObjects.
        dest_path: str
        scaling: np.ndarray or tuple
        comment: str
    """
    if scaling is None:
        scaling = np.array(global_params.config['scaling'])
    skel = Skeleton()
    anno = SkeletonAnnotation()
    anno.scaling = scaling
    rep_nodes = []
    for so in so_path:
        vert = so.mesh[1].reshape((-1, 3))
        com = np.mean(vert, axis=0)
        kd_tree = spatial.cKDTree(vert)
        dist, nn_ix = kd_tree.query([com])
        nn = vert[nn_ix[0]] / scaling
        n = SkeletonNode().from_scratch(anno, nn[0], nn[1], nn[2])
        anno.addNode(n)
        rep_nodes.append(n)
    for i in range(1, len(rep_nodes)):
        anno.addEdge(rep_nodes[i - 1], rep_nodes[i])
    if comment is not None:
        anno.setComment(comment)
    skel.add_annotation(anno)
    skel.to_kzip(dest_path)

def coordpath2anno(coords: np.ndarray,
                   scaling: Optional[np.ndarray] = None) -> SkeletonAnnotation:
    """
    Creates a skeleton from scaled coordinates; assumes coords are ordered for
    edge creation.

    Args:
        coords: np.array
        scaling: np.ndarray

    Returns:
        SkeletonAnnotation
    """
    if scaling is None:
        scaling = global_params.config['scaling']
    anno = SkeletonAnnotation()
    anno.scaling = scaling
    rep_nodes = []
    for c in coords:
        n = SkeletonNode().from_scratch(anno, c[0] / scaling[0],
                                        c[1] / scaling[1], c[2] / scaling[2])
        anno.addNode(n)
        rep_nodes.append(n)
    for i in range(1, len(rep_nodes)):
        anno.addEdge(rep_nodes[i - 1], rep_nodes[i])
    return anno

def create_graph_from_coords(coords: np.ndarray, max_dist: float = 6000,
                             force_single_cc: bool = True,
                             mst: bool = False) -> nx.Graph:
    """
    Generate a skeleton from sample locations by adding edges between points
    within a maximum distance and optionally pruning the result via its
    minimum spanning tree. Nodes will have a 'position' attribute.

    Args:
        coords: Coordinates.
        max_dist: Add edges between nodes that are within this distance.
        force_single_cc: Force the graph generated from `coords` to be a single
            connected component.
        mst: Compute the minimum spanning tree.

    Returns:
        Networkx graph with nodes corresponding to coordinate indices using the
        ordering of `coords`, i.e. the edge (1, 2) connects coords[1] and
        coords[2].
    """
    g = nx.Graph()
    if len(coords) == 1:
        g.add_node(0)
        g.add_weighted_edges_from([[0, 0, 0]])  # single node: zero-weight self-edge
        return g
    kd_t = spatial.cKDTree(coords)
    pairs = kd_t.query_pairs(r=max_dist, output_type="ndarray")
    g.add_nodes_from([(ix, dict(position=coord)) for ix, coord in enumerate(coords)])
    weights = np.linalg.norm(coords[pairs[:, 0]] - coords[pairs[:, 1]], axis=1)
    g.add_weighted_edges_from([[pairs[i][0], pairs[i][1], weights[i]]
                               for i in range(len(pairs))])
    if force_single_cc:  # make sure it is a single connected component
        g = stitch_skel_nx(g)
    if mst:
        g = nx.minimum_spanning_tree(g)
    return g

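# Example for `create_graph_from_coords`: three collinear points, 100 nm and
# 150 nm apart; the 250 nm pair exceeds `max_dist`, so the MST keeps the chain:
# >>> import numpy as np
# >>> from syconn.proc.graphs import create_graph_from_coords
# >>> coords = np.array([[0, 0, 0], [100, 0, 0], [250, 0, 0]])
# >>> g = create_graph_from_coords(coords, max_dist=200, mst=True)
# >>> sorted(g.edges())
# [(0, 1), (1, 2)]
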
def draw_glia_graph(G, dest_path, min_sv_size=0, ext_glia=None, iterations=150,
                    seed=0, glia_key="glia_probas", node_size_cap=np.inf,
                    mcmp=None, pos=None, glia_thresh=0.161):
    """
    Draw the graph with nodes colored in red (glia) or blue (neuron) depending
    on their class and write the drawing to `dest_path`.

    Args:
        G: nx.Graph
        dest_path: str
        min_sv_size: int
        ext_glia: dict; keys: node in G, values: number indicating class
        iterations: Number of layout iterations.
        seed: Random seed for layout generation. Default: 0.
        glia_key: str
        node_size_cap: int
        mcmp: color palette
        pos: Node positions for the layout.
        glia_thresh: Glia probability threshold.

    Returns:
        Node positions used for the layout.
    """
    import matplotlib.pyplot as plt
    import seaborn as sns
    if mcmp is None:
        mcmp = sns.diverging_palette(250, 15, s=99, l=60, center="dark",
                                     as_cmap=True)
    np.random.seed(seed)
    seg_objs = list(G.nodes())
    glianess, size = get_glianess_dict(seg_objs, glia_thresh, glia_key, 5,
                                       use_sv_volume=True)
    if ext_glia is not None:
        for n in G.nodes():
            glianess[n] = ext_glia[n.id]
    plt.figure()
    # reduce the cubic relation to a linear one
    n_size = np.array([size[n] ** (1. / 3) for n in G.nodes()]).astype(np.float32)
    # n_size = np.array([np.linalg.norm(size[n][1] - size[n][0]) for n in G.nodes()])
    if node_size_cap == "max":
        node_size_cap = np.max(n_size)
    n_size[n_size > node_size_cap] = node_size_cap
    col = np.array([glianess[n] for n in G.nodes()])
    col = col[n_size >= min_sv_size]
    nodelist = list(np.array(list(G.nodes()))[n_size >= min_sv_size])
    n_size = n_size[n_size >= min_sv_size]
    n_size = n_size / np.max(n_size) * 25.
    if pos is None:
        pos = nx.spring_layout(G, weight="weight", iterations=iterations, seed=seed)
    nx.draw(G, nodelist=nodelist, node_color=col, node_size=n_size, cmap=mcmp,
            width=0.15, pos=pos, linewidths=0)
    plt.savefig(dest_path)
    plt.close()
    return pos

def nxGraph2kzip(g, coords, kzip_path):
    """Write the graph `g` as a skeleton to a kzip file. `coords` are expected
    in physical units and are converted to voxel coordinates."""
    scaling = global_params.config['scaling']
    coords = coords / scaling
    skel = Skeleton()
    anno = SkeletonAnnotation()
    anno.scaling = scaling
    node_mapping = {}
    pbar = tqdm.tqdm(total=len(coords) + len(g.edges()), leave=False)
    for v in g.nodes():
        c = coords[v]
        n = SkeletonNode().from_scratch(anno, c[0], c[1], c[2])
        node_mapping[v] = n
        anno.addNode(n)
        pbar.update(1)
    for e in g.edges():
        anno.addEdge(node_mapping[e[0]], node_mapping[e[1]])
        pbar.update(1)
    skel.add_annotation(anno)
    skel.to_kzip(kzip_path)
    pbar.close()

def svgraph2kzip(ssv: 'SuperSegmentationObject', kzip_path: str):
    """
    Writes the SV graph stored in `ssv.edgelist_path` to a kzip file. The
    representative coordinate of an SV is used as the corresponding node
    location.

    Args:
        ssv: Cell reconstruction object.
        kzip_path: Path to the output kzip file.
    """
    sv_graph = nx.read_edgelist(ssv.edgelist_path, nodetype=int)
    coords = {ix: ssv.get_seg_obj('sv', ix).rep_coord for ix in sv_graph.nodes}
    skel = Skeleton()
    anno = SkeletonAnnotation()
    anno.scaling = ssv.scaling
    node_mapping = {}
    pbar = tqdm.tqdm(total=len(coords) + len(sv_graph.edges()), leave=False)
    for v in sv_graph.nodes:
        c = coords[v]
        n = SkeletonNode().from_scratch(anno, c[0], c[1], c[2])
        n.setComment(f'{v}')
        node_mapping[v] = n
        anno.addNode(n)
        pbar.update(1)
    for e in sv_graph.edges():
        anno.addEdge(node_mapping[e[0]], node_mapping[e[1]])
        pbar.update(1)
    skel.add_annotation(anno)
    skel.to_kzip(kzip_path)
    pbar.close()

def stitch_skel_nx(skel_nx: nx.Graph, n_jobs: int = 1) -> nx.Graph:
    """
    Stitch connected components within a graph by iteratively adding edges
    between the closest components.

    Args:
        skel_nx: Networkx graph. Nodes require a 'position' attribute.
        n_jobs: Number of jobs used for the cKDTree query.

    Returns:
        Single connected component graph.
    """
    if skel_nx.number_of_nodes() == 0:
        return skel_nx
    no_of_seg = nx.number_connected_components(skel_nx)
    if no_of_seg == 1:
        return skel_nx
    skel_nx_nodes = np.array([skel_nx.nodes[ix]['position']
                              for ix in skel_nx.nodes()], dtype=np.int64)
    while no_of_seg != 1:
        rest_nodes = []
        rest_nodes_ixs = []
        list_of_comp = np.array([c for c in sorted(
            nx.connected_components(skel_nx), key=len, reverse=True)])
        for single_rest_graph in list_of_comp[1:]:
            rest_nodes += [skel_nx_nodes[int(ix)] for ix in single_rest_graph]
            rest_nodes_ixs += list(single_rest_graph)
        current_set_of_nodes = [skel_nx_nodes[int(ix)] for ix in list_of_comp[0]]
        current_set_of_nodes_ixs = list(list_of_comp[0])
        tree = spatial.cKDTree(rest_nodes, leafsize=1)
        thread_lengths, indices = tree.query(current_set_of_nodes, n_jobs=n_jobs)
        start_thread_index = np.argmin(thread_lengths)
        stop_thread_index = indices[start_thread_index]
        e1 = current_set_of_nodes_ixs[start_thread_index]
        e2 = rest_nodes_ixs[stop_thread_index]
        skel_nx.add_edge(e1, e2)
        no_of_seg -= 1
    return skel_nx

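# Example for `stitch_skel_nx`: two components, {0, 1} and {2}; the closest
# pair of nodes across components (1 and 2) gets connected:
# >>> import networkx as nx
# >>> import numpy as np
# >>> from syconn.proc.graphs import stitch_skel_nx
# >>> g = nx.Graph()
# >>> for ix, pos in enumerate([[0, 0, 0], [10, 0, 0], [100, 0, 0]]):
# ...     g.add_node(ix, position=np.array(pos))
# >>> g.add_edge(0, 1)
# >>> g = stitch_skel_nx(g)
# >>> sorted(g.edges())
# [(0, 1), (1, 2)]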