Source code for syconn.reps.super_segmentation_helper

# -*- coding: utf-8 -*-
# SyConn - Synaptic connectivity inference toolkit
#
# Copyright (c) 2016 - now
# Max Planck Institute of Neurobiology, Martinsried, Germany
# Authors: Sven Dorkenwald, Philipp Schubert, Joergen Kornfeld
import copy
import os
import time

from . import log_reps
from . import segmentation
from .rep_helper import assign_rep_values, colorcode_vertices, surface_samples
from .segmentation_helper import load_skeleton, find_missing_sv_views, \
    find_missing_sv_attributes, find_missing_sv_skeletons, load_so_attr_bulk
from .. import global_params
from ..handler.basics import kd_factory, flatten_list
from ..handler.multiviews import generate_rendering_locs
from ..mp.mp_utils import start_multiprocess_obj, start_multiprocess_imap
from ..proc.graphs import create_graph_from_coords, stitch_skel_nx
from ..proc.meshes import write_mesh2kzip
from ..proc.rendering import render_sso_coords
from ..proc.sd_proc import predict_views
from ..extraction.block_processing_C import relabel_vol_nonexist2zero
from ..extraction.in_bounding_boxC import in_bounding_box

from typing import Dict, List, Union, Optional, Tuple, TYPE_CHECKING, Any
if TYPE_CHECKING:
    from . import super_segmentation
    from ..reps.super_segmentation import SuperSegmentationObject, SuperSegmentationDataset
    from ..reps.segmentation import SegmentationObject

from collections.abc import Iterable
from collections import Counter
from multiprocessing.pool import ThreadPool
import networkx as nx
from numba import jit
import numpy as np
import scipy
import scipy.ndimage
from scipy import spatial
from skimage.segmentation import watershed
from skimage.feature import peak_local_max
from scipy import ndimage
from knossos_utils.skeleton_utils import annotation_to_nx_graph, \
    load_skeleton as load_skeleton_kzip, Skeleton, SkeletonAnnotation, SkeletonNode
try:
    from knossos_utils import mergelist_tools
except ImportError:
    from knossos_utils import mergelist_tools_fallback as mergelist_tool


def majority_vote(anno, prop, max_dist):
    """
    Smooth the per-node prediction ``<prop>_pred`` of a skeleton annotation by
    majority voting within a sliding window.

    Windows are computed on an unmodified deep copy of `anno` with
    :func:`nodes_in_pathlength`, so already-smoothed labels never take part in
    later votes. Label 2 is excluded from every vote. For ``prop == "axoness"``,
    nodes labelled 2 (soma) are additionally left unchanged.

    Args:
        anno: SkeletonAnnotation whose nodes are updated in place.
        prop: Property name, e.g. 'axoness'. Labels are read from and written
            to the node data key ``prop + '_pred'``.
        max_dist: Window size (path-length budget, see
            :func:`nodes_in_pathlength`) in the units of
            ``getCoordinate_scaled()``.

    Returns:
        None
    """
    pred_key = prop + '_pred'
    old_anno = copy.deepcopy(anno)
    nearest_nodes_list = nodes_in_pathlength(old_anno, max_dist)
    for nodes in nearest_nodes_list:
        # the first entry of every window is its source node
        curr_node_id = nodes[0].getID()
        new_node = anno.getNodeByID(curr_node_id)
        if prop == "axoness" and int(new_node.data["axoness_pred"]) == 2:
            # leave somata untouched (the label is only normalised to int)
            new_node.data["axoness_pred"] = 2
            continue
        property_val = [int(n.data[pred_key]) for n in nodes
                        if int(n.data[pred_key]) != 2]
        if not property_val:
            # Every node in the window carries label 2, so there is nothing to
            # vote on. Keep the current label instead of raising IndexError.
            continue
        new_ax = Counter(property_val).most_common(1)[0][0]
        new_node.setDataElem(pred_key, new_ax)
def nodes_in_pathlength(anno, max_path_len):
    """
    For every node of `anno`, collect the nodes visited by a breadth-first
    traversal that starts at that node and stops once a length budget is
    exceeded.

    The budget is consumed by the *cumulative* length of all BFS tree edges
    traversed so far, summed over every branch explored. It is not the
    geodesic distance of each individual node to the source.

    Args:
        anno: Annotation object providing ``getNodes()``. It is converted into a
            graph with ``annotation_to_nx_graph``.
        max_path_len: Length budget, in the units of
            ``getCoordinate_scaled()``.

    Returns:
        List[List[SkeletonNode]]: One list per node of ``anno.getNodes()``, in
        the same order. Each list starts with the source node itself, followed
        by the visited nodes in BFS order.
    """
    skel_graph = annotation_to_nx_graph(anno)
    windows = []
    for src in anno.getNodes():
        visited = [src]
        budget_used = 0.0
        for parent, child in nx.bfs_edges(skel_graph, src):
            step = np.array(child.getCoordinate_scaled()) - \
                np.array(parent.getCoordinate_scaled())
            budget_used += np.linalg.norm(step)
            if budget_used > max_path_len:
                break
            visited.append(child)
        windows.append(visited)
    return windows
def predict_sso_celltype(sso: 'super_segmentation.SuperSegmentationObject', model: Any,
                         nb_views_model: int = 20, use_syntype=True,
                         overwrite: bool = False, pred_key_appendix="",
                         da_equals_tan: bool = True, n_classes: int = 7,
                         save_to_attr_dict: bool = True):
    """
    Predict the cell type of `sso` from its cached multi-views (and, if
    enabled, its synapse sign ratios) by majority vote over random view
    subsets.

    The view subsets are built by :func:`sso_views_to_modelinput` and
    normalized with ``naive_view_normalization_new``. For predictions without
    cached views, see ``celltype_of_sso_nocache``.

    Results are stored in ``sso.attr_dict`` under
    ``"celltype_cnn_e3" + pred_key_appendix`` (majority class), with the
    suffixes ``_probas`` (per-subset probabilities) and ``_certainty``
    (``sso.certainty_celltype``).

    Args:
        sso: Cell reconstruction to classify.
        model: Object providing ``predict_proba``. It is called with the view
            subsets, or with ``(views, synsign_ratio)`` if synapse types are
            used.
        nb_views_model: Number of views per subset.
        use_syntype: Add the synapse sign ratios
            (``syn_sign_ratio_celltype`` for ``comp_types=[1]`` and ``[0]``) as
            model input. Only effective if
            ``global_params.config.syntype_available``.
        overwrite: If False and a prediction already exists under the result
            key, return immediately.
        pred_key_appendix: Suffix appended to the result key.
        da_equals_tan: Add column 6 (TAN) of the probabilities to column 1 (DA)
            and remove column 6 before voting. Requires ``n_classes == 7``.
        n_classes: Number of classes the majority vote is performed over. A
            predicted label ``>= n_classes`` raises ``ValueError``.
        save_to_attr_dict: If True, persist the three result entries via
            ``sso.save_attributes``. ``sso.attr_dict`` is updated in memory
            regardless.

    Returns:
        None
    """
    sso.load_attr_dict()
    pred_key = "celltype_cnn_e3" + pred_key_appendix
    proba_key = f"{pred_key}_probas"
    cert_key = f"{pred_key}_certainty"
    if not overwrite and pred_key in sso.attr_dict:
        return
    from ..handler.prediction import naive_view_normalization_new
    view_subsets = sso_views_to_modelinput(sso, nb_views_model)
    view_subsets = naive_view_normalization_new(view_subsets)
    if global_params.config.syntype_available and use_syntype:
        ratio_row = [syn_sign_ratio_celltype(sso, comp_types=[1, ]),
                     syn_sign_ratio_celltype(sso, comp_types=[0, ])]
        synsign_ratio = np.array([ratio_row] * len(view_subsets))
        probas = model.predict_proba((view_subsets, synsign_ratio))
    else:
        probas = model.predict_proba(view_subsets)
    # DA and TAN are both modulatory types. If this changes, also adapt
    # `certainty_celltype` and `celltype_of_sso_nocache`.
    if da_equals_tan:
        assert n_classes == 7, 'Incompatible number of classes for cell type prediction with "da_equals_tan=True".'
        # DA accumulates the TAN evidence, then the TAN column is dropped
        probas[:, 1] += probas[:, 6]
        probas = np.delete(probas, [6], axis=1)
        # INT is now at index 6 -> label 6 is INT
    labels = np.argmax(probas, axis=1)
    if np.max(labels) >= n_classes:
        raise ValueError('Unknown cell type predicted.')
    # one vote per view subset; the first class wins ties
    votes = np.bincount(labels, minlength=n_classes)
    pred = np.argmax(votes)
    sso.attr_dict[pred_key] = pred
    sso.attr_dict[proba_key] = probas
    cert = sso.certainty_celltype(pred_key)
    sso.attr_dict[cert_key] = cert
    if save_to_attr_dict:
        sso.save_attributes([pred_key, proba_key, cert_key], [pred, probas, cert])
def sso_views_to_modelinput(sso: 'super_segmentation.SuperSegmentationObject', nb_views: int,
                            view_key: Optional[str] = None) -> np.ndarray:
    """
    Split all 2D projection views of a cell reconstruction into random subsets
    of `nb_views` views each, e.g. as cell type model input.

    The steps are:

    1. Shuffle the multi-view locations with a fixed seed (0).
    2. Pool all views.
    3. If fewer than `nb_views` views exist, duplicate randomly chosen views
       until one subset can be formed.

    Views that do not fill a complete subset are dropped.

    Todo:
        * shuffle after reshaping from (#multi-view locations, 4 channels,
          #nb_views, 128, 256) to (#multi-view locations * #nb_views,
          4 channels, 128, 256)?

    Args:
        sso: The cell reconstruction object. It must contain at least one SV.
        nb_views: Number of views per subset.
        view_key: Key of the stored views, passed to ``sso.load_views``.

    Returns:
        np.ndarray of shape (#subsets, #channels, nb_views, height, width).
        For the default views this is (#subsets, 4, nb_views, 128, 256).
    """
    # A private RNG seeded with 0 draws exactly the same shuffle/choice
    # sequence as the former global ``np.random.seed(0)`` calls. It does not
    # reset the caller's global NumPy random state as a side effect.
    rng = np.random.RandomState(0)
    assert len(sso.sv_ids) > 0
    views = sso.load_views(view_key=view_key)
    # shuffles the multi-view locations (first axis) in place
    rng.shuffle(views)
    # (#locations, #channels, #views per location, H, W)
    #   -> (#channels, #locations * #views per location, H, W)
    n_channels = views.shape[1]
    img_shape = tuple(views.shape[-2:])
    views = views.swapaxes(1, 0).reshape((n_channels, -1) + img_shape)
    assert views.shape[1] > 0
    if views.shape[1] < nb_views:
        # pad with randomly duplicated views so that one subset can be formed
        rand_ixs = rng.choice(views.shape[1], nb_views - views.shape[1])
        views = np.append(views, views[:, rand_ixs], axis=1)
    nb_samples = views.shape[1] // nb_views
    assert nb_samples > 0
    out_d = views[:, :nb_samples * nb_views]
    out_d = out_d.reshape((n_channels, -1, nb_views) + img_shape).swapaxes(1, 0)
    return out_d
def radius_correction_found_vertices(sso: 'super_segmentation.SuperSegmentationObject',
                                     plump_factor: int = 1, num_found_vertices: int = 10):
    """
    Estimate skeleton node diameters from the distances to the closest mesh
    vertices.

    For every skeleton node (scaled by ``sso.scaling``), the
    `num_found_vertices` nearest vertices of ``sso.mesh[1]`` are queried. The
    node diameter is set to ``median(distances) * 2 / 10`` and finally
    multiplied by `plump_factor`.

    Args:
        sso: Object providing ``skeleton`` ('nodes', 'diameters'), ``mesh`` and
            ``scaling``.
        plump_factor: Multiplication factor applied to all diameters.
        num_found_vertices: Number of nearest mesh vertices queried per node.

    Returns:
        dict: ``sso.skeleton`` with updated 'diameters'. The previous
        'diameters' array is also overwritten in place before being replaced
        by the scaled copy.
    """
    node_coords = sso.skeleton['nodes']
    diam = sso.skeleton['diameters']
    kdtree = spatial.cKDTree(sso.mesh[1].reshape((-1, 3)))
    nn_dists, _ = kdtree.query(node_coords * sso.scaling, num_found_vertices)
    if nn_dists.ndim == 1:
        # a single neighbour was queried -> one distance per node
        nn_dists = nn_dists[:, np.newaxis]
    diam[:len(node_coords)] = np.median(nn_dists, axis=1) * 2 / 10
    sso.skeleton['diameters'] = diam * plump_factor
    return sso.skeleton
def get_sso_axoness_from_coord(sso, coord, k=5):
    """
    Return the most frequent ``axoness`` label among the `k` skeleton nodes
    closest to `coord`.

    Neighbour slots that the KD-tree reports with infinite distance (fewer
    than `k` nodes available) are ignored. Ties are resolved in favour of the
    label encountered first, i.e. of the closer node.

    Args:
        sso: SuperSegmentationObject. Its skeleton is loaded via
            ``sso.load_skeleton()``.
        coord: Unscaled (voxel) coordinate to query.
        k: Number of nearest skeleton nodes taking part in the vote.

    Returns:
        The majority label from ``sso.skeleton["axoness"]`` (0 dendrite,
        1 axon, 2 soma).
    """
    scaling = np.array(sso.scaling)
    query_pt = np.array(coord) * scaling
    sso.load_skeleton()
    tree = spatial.cKDTree(sso.skeleton["nodes"] * scaling)
    dists, ixs = tree.query(query_pt, k=k)
    found = ixs[dists != np.inf]
    labels = sso.skeleton["axoness"][found]
    return Counter(labels).most_common(n=1)[0][0]
def load_voxels_downsampled(sso, downsampling=(2, 2, 1), nb_threads=10):
    """
    Build a downsampled boolean voxel mask of a SuperSegmentationObject from
    the voxels of its supervoxels.

    The mask covers ``sso.bounding_box`` with shape
    ``ceil((bb[1] - bb[0]) / downsampling)``. Every SV with existing voxels is
    strided by `downsampling` and OR-ed into the mask at its (floor-divided)
    offset.

    Args:
        sso: SuperSegmentationObject to load voxels for.
        downsampling: Per-axis downsampling factors, in the axis order of the
            bounding boxes / voxel arrays.
        nb_threads: Number of threads used to load the SVs. Values <= 1 load
            sequentially.

    Returns:
        np.ndarray of dtype bool, or None if `sso` has no SVs.
    """
    def _load_sv_voxels_thread(args):
        sv_id = args[0]
        sv = segmentation.SegmentationObject(sv_id, obj_type="sv",
                                             version=sso.version_dict["sv"],
                                             working_dir=sso.working_dir,
                                             config=sso.config, voxel_caching=False)
        if not sv.voxels_exist:
            return
        # Offset of the SV box inside the downsampled SSO volume. This must be
        # an integer division: in-place true division (``/=``) of an int32
        # array raises a casting error in NumPy.
        start = np.array(sv.bounding_box[0] - sso.bounding_box[0], dtype=np.int32) // downsampling
        size = np.array(sv.bounding_box[1] - sv.bounding_box[0], dtype=np.float32)
        size = np.ceil(size / downsampling).astype(np.int32)
        end = start + size
        sv_voxels = sv.voxels
        if not isinstance(sv_voxels, int):
            sv_voxels = sv_voxels[::downsampling[0], ::downsampling[1], ::downsampling[2]]
        voxels[start[0]: end[0], start[1]: end[1], start[2]: end[2]][sv_voxels] = True

    downsampling = np.array(downsampling, dtype=np.int32)
    if len(sso.sv_ids) == 0:
        return None
    voxel_box_size = sso.bounding_box[1] - sso.bounding_box[0]
    voxel_box_size = voxel_box_size.astype(np.float32)
    voxel_box_size = np.ceil(voxel_box_size / downsampling).astype(np.int32)
    # ``np.bool`` was removed in NumPy 1.24; the builtin is the replacement
    voxels = np.zeros(voxel_box_size, dtype=bool)
    multi_params = [[sv_id] for sv_id in sso.sv_ids]
    if nb_threads > 1:
        pool = ThreadPool(nb_threads)
        try:
            pool.map(_load_sv_voxels_thread, multi_params)
        finally:
            pool.close()
            pool.join()
    else:
        # ``map`` is lazy in Python 3; iterate explicitly so the SVs are loaded
        for params in multi_params:
            _load_sv_voxels_thread(params)
    return voxels
def create_new_skeleton(sv_id, sso):
    """
    Load the stored skeleton of one supervoxel of a SuperSegmentationObject.

    Args:
        sv_id (int): The ID of the supervoxel.
        sso (SuperSegmentationObject): Provides ``version_dict["sv"]``,
            ``working_dir`` and ``config`` for the supervoxel object.

    Returns:
        Tuple[np.ndarray, np.ndarray, np.ndarray]: The skeleton's 'nodes',
        'diameters' and 'edges' as returned by ``load_skeleton``.
    """
    # ``SegmentationObject`` is only imported under ``TYPE_CHECKING`` in this
    # module, so at runtime it must be accessed through ``segmentation``.
    so = segmentation.SegmentationObject(sv_id, obj_type="sv",
                                         version=sso.version_dict["sv"],
                                         working_dir=sso.working_dir,
                                         config=sso.config)
    so.enable_locking = False
    so.load_attr_dict()
    skel = load_skeleton(so)
    return skel['nodes'], skel['diameters'], skel['edges']
def convert_coord(coord_list, scal):
    """
    Swap the first two axes of a 3D coordinate, shift every axis by +1 and
    scale the result.

    Args:
        coord_list: Indexable with at least three entries; only indices 0-2
            are used.
        scal: Scaling factor(s), broadcast against the 3-vector.

    Returns:
        np.ndarray: ``[c[1] + 1, c[0] + 1, c[2] + 1] * scal``.
    """
    first = coord_list[0]
    second = coord_list[1]
    third = coord_list[2]
    swapped = np.array([second + 1, first + 1, third + 1])
    return swapped * np.array(scal)
def prune_stub_branches(sso=None, nx_g=None, scal=None, len_thres=1000, preserve_annotations=True):
    """
    Iteratively remove short terminal branches ("stubs") from a skeleton graph
    and return its minimum spanning tree.

    A stub is the chain of nodes from a degree-1 end node up to (excluding)
    the first node with degree > 2. It is removed if the Euclidean distance
    between the end node and that branch node, both transformed with
    :func:`convert_coord`, is below `len_thres`. Passes are repeated until no
    more nodes are removed. Afterwards every edge gets its scaled Euclidean
    length as 'weight' and the minimum spanning tree is computed.

    Args:
        sso: Optional SuperSegmentationObject. If given, its skeleton is
            replaced by the pruned graph via :func:`from_netkx_to_sso`.
        nx_g: networkx graph whose nodes carry a 'position' attribute.
        scal: Coordinate scaling. Defaults to
            ``global_params.config['scaling']``.
        len_thres: Stubs shorter than this (in scaled units) are pruned.
        preserve_annotations: If True, work on a copy of `nx_g`. Otherwise
            `nx_g` itself is modified.

    Returns:
        Tuple[Optional[SuperSegmentationObject], networkx.Graph]: `sso` (None
        if not given) and the pruned minimum spanning tree. A tuple is returned
        in both cases.

    Raises:
        ValueError: If the pruned graph does not consist of exactly one
            connected component.
    """
    if scal is None:
        scal = global_params.config['scaling']
    new_nx_g = nx_g.copy() if preserve_annotations else nx_g
    pruning_complete = False
    while not pruning_complete:
        # decisions of one pass are based on this snapshot
        nx_g = new_nx_g.copy()
        # tip nodes, i.e. degree-1 nodes
        end_nodes = [node for node, deg in nx_g.degree() if deg == 1]
        for end_node in end_nodes:
            prune_nodes = []
            # DFS from the tip to the first branch node
            for curr_node in nx.traversal.dfs_preorder_nodes(nx_g, end_node):
                if nx_g.degree(curr_node) > 2:
                    loc_end = convert_coord(nx_g.nodes[end_node]['position'], scal)
                    loc_curr = convert_coord(nx_g.nodes[curr_node]['position'], scal)
                    b_len = np.linalg.norm(loc_end - loc_curr)
                    if b_len < len_thres:
                        # remove this stub, i.e. the nodes collected on the
                        # way to the branch point
                        new_nx_g.remove_nodes_from(prune_nodes)
                    break
                prune_nodes.append(curr_node)
        if len(new_nx_g.nodes) == len(nx_g.nodes):
            pruning_complete = True
    n_cc = nx.number_connected_components(new_nx_g)
    if n_cc != 1:
        ssv_id = sso.id if sso is not None else None
        msg = 'Pruning of SV skeletons failed during "prune_stub_branches' \
              '" with {} connected components. Please check the underlying' \
              ' SSV {}.'.format(n_cc, ssv_id)
        log_reps.critical(msg)
        raise ValueError(msg)
    for u, v in new_nx_g.edges:
        # Cast before subtracting: positions are often unsigned (e.g. uint32),
        # where a negative difference would wrap around to a huge value and
        # corrupt the spanning-tree weights.
        pos_u = np.asarray(new_nx_g.nodes[u]['position'], dtype=np.float64)
        pos_v = np.asarray(new_nx_g.nodes[v]['position'], dtype=np.float64)
        new_nx_g[u][v]['weight'] = np.linalg.norm((pos_u - pos_v) * scal)
    new_nx_g = nx.minimum_spanning_tree(new_nx_g)
    if sso is not None:
        sso = from_netkx_to_sso(sso, new_nx_g)
    return sso, new_nx_g
def from_netkx_to_sso(sso, skel_nx):
    """
    Replace ``sso.skeleton`` by the skeleton stored in a networkx graph.

    Node coordinates are taken from the 'position' attribute in the graph's
    node iteration order. Edges are re-expressed as row indices into those
    nodes, and diameters are initialised with zeros.

    Args:
        sso: SuperSegmentationObject to update.
        skel_nx: networkx graph with a single connected component and a
            'position' attribute per node.

    Returns:
        SuperSegmentationObject: `sso` with ``skeleton`` set to a dict with
        'nodes' (uint32), 'diameters' (float32 zeros) and 'edges'
        (uint64, shape (n, 2)).
    """
    sso.skeleton = dict()
    node_ids = list(skel_nx.nodes())
    sso.skeleton['nodes'] = np.array([skel_nx.nodes[ix]['position'] for ix in node_ids],
                                     dtype=np.uint32)
    sso.skeleton['diameters'] = np.zeros(len(sso.skeleton['nodes']), dtype=np.float32)
    assert nx.number_connected_components(skel_nx) == 1
    # Important bit, please don't remove (needed after pruning): node IDs may
    # be non-contiguous. They are mapped to row indices in the SAME order used
    # for 'nodes' above, so edges refer to the right coordinates even if the
    # graph does not iterate its nodes in sorted order.
    node_id_to_row = {node_id: row for row, node_id in enumerate(node_ids)}
    temp_edges = [(node_id_to_row[u], node_id_to_row[v]) for u, v in skel_nx.edges()]
    sso.skeleton['edges'] = np.array(temp_edges, dtype=np.uint64).reshape([-1, 2])
    return sso
def _skeleton_from_surface_samples(ssv: 'super_segmentation.SuperSegmentationObject',
                                   use_new_renderings_locs: bool) -> None:
    """Set ``ssv.skeleton`` to a spanning tree over locations sampled from half of its mesh vertices."""
    verts = ssv.mesh[1].reshape(-1, 3)
    # Choose a random subset of surface vertices. A private RNG seeded with 0
    # gives the same permutation as the former global ``np.random.seed(0)`` +
    # ``np.random.shuffle`` but leaves the caller's global random state intact.
    ixs = np.arange(len(verts))
    np.random.RandomState(0).shuffle(ixs)
    ixs = ixs[:int(0.5 * len(ixs))]  # TODO: add parameter to config
    if use_new_renderings_locs:
        locs = generate_rendering_locs(verts[ixs], 1000)
    else:
        locs = surface_samples(verts[ixs], bin_sizes=(1000, 1000, 1000),
                               max_nb_samples=10000, r=500)
    g = create_graph_from_coords(locs, mst=True, force_single_cc=True)
    # always materialise the edge view as a list -> (n_edges, 2) array
    edge_list = np.array(list(g.edges()))
    del g
    if edge_list.size == 0:
        # a single sampling location yields a skeleton without edges
        edge_list = np.zeros((0, 2), dtype=np.int64)
    if edge_list.ndim != 2:
        raise ValueError("Edge list ist not a 2D array: {}\n{}".format(
            edge_list.shape, edge_list))
    ssv.skeleton = dict()
    ssv.skeleton["nodes"] = (locs / np.array(ssv.scaling)).astype(np.int32)
    ssv.skeleton["edges"] = edge_list
    ssv.skeleton["diameters"] = np.ones(len(locs))


def create_sso_skeletons_wrapper(ssvs: List['super_segmentation.SuperSegmentationObject'],
                                 dest_paths: Optional[str] = None,
                                 nb_cpus: Optional[int] = None,
                                 map_myelin: bool = False, save: bool = True):
    """
    Generate skeletons for cell reconstructions.

    If ``global_params.config.allow_ssv_skel_gen`` is True, skeletons are
    created from samples of the cell surface (mesh vertices). These may lie
    partially outside the cell segmentation, close to its surface. Otherwise
    the existing supervoxel skeletons are merged via
    ``create_sso_skeleton_fast``.

    Args:
        ssvs: Iterable of cell reconstruction objects. One-shot iterators are
            supported.
        dest_paths: Iterable of k.zip paths, one per object in `ssvs`. If
            given, the skeleton is additionally written there via
            ``ssv.save_skeleton_to_kzip``.
        nb_cpus: Number of CPUs set on every `ssv`. Defaults to
            ``global_params.config['ncores_per_node']``.
        map_myelin: If True, map myelin predictions to
            ``ssv.skeleton["nodes"]`` with :func:`map_myelin2coords`
            (``ssv.skeleton["myelin"]``) and smooth them with
            ``majorityvote_skeleton_property``.
        save: If True, store the skeleton with ``ssv.save_skeleton()``.

    Returns:
        None

    Raises:
        ValueError: If `dest_paths` is given but not iterable.

    Todo:
        * Add sliding window majority vote for smoothing myelin prediction to
          `global_params`.
    """
    import itertools
    if nb_cpus is None:
        nb_cpus = global_params.config['ncores_per_node']
    if dest_paths is None:
        # Endless supply of "no path". Unlike a list built by iterating over
        # `ssvs`, this does not exhaust `ssvs` if it is a one-shot iterator.
        dest_paths = itertools.repeat(None)
    elif not isinstance(dest_paths, Iterable):
        raise ValueError('Destination paths given but are not iterable.')
    use_new_renderings_locs = global_params.config.use_new_renderings_locs
    for ssv, dest_path in zip(ssvs, dest_paths):
        ssv.nb_cpus = nb_cpus
        if not global_params.config.allow_ssv_skel_gen:
            # This merges existing SV skeletons - SV skeletons must exist
            ssv = create_sso_skeleton_fast(ssv)
        else:
            _skeleton_from_surface_samples(ssv, use_new_renderings_locs)
        if map_myelin:
            try:
                ssv.skeleton["myelin"] = map_myelin2coords(ssv.skeleton["nodes"], mag=4)
                majorityvote_skeleton_property(ssv, prop_key='myelin')
            except Exception as e:
                raise Exception(f'Myelin mapping in {ssv} failed with: {e}') from e
        if save:
            ssv.save_skeleton()
        if dest_path is not None:
            ssv.save_skeleton_to_kzip(dest_path=dest_path)
def map_myelin2coords(coords: np.ndarray,
                      cube_edge_avg: np.ndarray = np.array([11, 11, 5]),
                      thresh_proba: float = 255 // 2,
                      thresh_majority: float = 0.5,
                      mag: int = 4) -> np.ndarray:
    """
    Classify every location in `coords` as myelinated or not.

    For each location, a cube of `cube_edge_avg` voxels (at magnification
    `mag`) of the myelin probability map is loaded. Voxels with a value above
    `thresh_proba` count as myelin. The location is labelled myelinated if the
    fraction of myelin voxels exceeds `thresh_majority`.

    Examples:
        Myelin prediction for a single cell reconstruction, including the
        smoothing via :func:`~majorityvote_skeleton_property`::

            from syconn import global_params
            from syconn.reps.super_segmentation import *
            from syconn.reps.super_segmentation_helper import \\
                map_myelin2coords, majorityvote_skeleton_property

            # init. example data set
            global_params.wd = '~/SyConn/example_cube1/'
            # initialize example cell reconstruction
            ssd = SuperSegmentationDataset()
            ssv = list(ssd.ssvs)[0]
            ssv.load_skeleton()
            # get myelin predictions
            myelinated = map_myelin2coords(ssv.skeleton["nodes"], mag=4)
            ssv.skeleton["myelin"] = myelinated
            # generates a smoothed version at `ssv.skeleton["myelin_avg10000"]`
            majorityvote_skeleton_property(ssv, "myelin")
            # store results as a KNOSSOS readable k.zip file
            ssv.save_skeleton_to_kzip(dest_path='~/{}_myelin.k.zip'.format(ssv.id),
                                      additional_keys=['myelin', 'myelin_avg10000'])

    Args:
        coords: Locations in voxel coordinates (mag 1).
        cube_edge_avg: Extent of the loaded cube in voxels at `mag`,
            independent of the value of `mag`.
        thresh_proba: Probability threshold in uint8 values (0..255).
        thresh_majority: Required fraction of myelin voxels (0..1). E.g. 0.1
            flags a location as myelinated if more than 10% of its cube voxels
            are myelin.
        mag: Magnification level the prediction is read from.

    Returns:
        np.ndarray (uint8): 1 for myelinated, 0 otherwise, one entry per
        coordinate.

    Raises:
        ValueError: If the myelin KnossosDataset does not exist in the working
            directory.
    """
    myelin_kd_p = global_params.config.working_dir + "/knossosdatasets/myelin/"
    if not os.path.isdir(myelin_kd_p):
        raise ValueError(f'Could not find myelin KnossosDataset at {myelin_kd_p}.')
    kd = kd_factory(myelin_kd_p)
    myelin_preds = np.zeros((len(coords)), dtype=np.uint8)
    # number of voxels in the loaded (mag-level) cube
    n_cube_vx = np.prod(cube_edge_avg)
    # cube extent in mag 1 coordinates
    # TODO: requires adaption if anisotropic downsampling was used in KD!
    size_mag1 = cube_edge_avg * mag
    half_extent = size_mag1 // 2
    for ix, center in enumerate(coords):
        cube = kd.load_raw(size=size_mag1, offset=center - half_extent,
                           mag=mag).swapaxes(0, 2)
        frac_myelinated = np.sum(cube > thresh_proba) / n_cube_vx
        myelin_preds[ix] = frac_myelinated > thresh_majority
    return myelin_preds
# New Implementation of skeleton generation which makes use of ssv.rag
def from_netkx_to_arr(skel_nx: 'nx.Graph') -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Convert a skeleton graph into node, diameter and edge arrays.

    Node coordinates are taken from the 'position' attribute in the graph's
    node iteration order. Edges are re-expressed as row indices into that node
    array, so node IDs may be arbitrary and non-contiguous (e.g. after
    pruning). Isolated nodes keep their row without shifting the edge indices.

    Args:
        skel_nx: Graph with a 'position' attribute per node.

    Returns:
        Tuple[np.ndarray, np.ndarray, np.ndarray]: Nodes (uint32), diameters
        (float32 zeros, one per node) and edges (uint64, shape (n_edges, 2)).
    """
    node_ids = list(skel_nx.nodes())
    nodes = np.array([skel_nx.nodes[ix]['position'] for ix in node_ids], dtype=np.uint32)
    diameters = np.zeros(len(nodes), dtype=np.float32)
    # Important bit, please don't remove (needed after pruning): map node IDs
    # to contiguous row indices in the SAME order used for `nodes`.
    node_id_to_row = {node_id: row for row, node_id in enumerate(node_ids)}
    edges = [(node_id_to_row[u], node_id_to_row[v]) for u, v in skel_nx.edges()]
    edges = np.array(edges, dtype=np.uint64).reshape([-1, 2])
    return nodes, diameters, edges
def sparsify_skeleton_fast(g: nx.Graph, scal: Optional[np.ndarray] = None, dot_prod_thresh: float = 0.8,
                           max_dist_thresh: Union[int, float] = 500, min_dist_thresh: Union[int, float] = 50,
                           verbose: bool = False) -> nx.Graph:
    """
    Remove redundant degree-2 nodes from a copy of a skeleton graph.

    A degree-2 node is removed and its two neighbours are connected directly
    if either of the following holds:

    * the absolute cosine between the two scaled edge vectors exceeds
      `dot_prod_thresh` and the neighbours are closer than `max_dist_thresh`;
    * the neighbours are at most `min_dist_thresh` apart.

    Passes are repeated until no more nodes are removed.

    Args:
        g: Skeleton graph with a 'position' attribute per node. It is not
            modified.
        scal: Voxel size used to scale positions. Defaults to
            ``global_params.config['scaling']``.
        dot_prod_thresh: Threshold for the absolute cosine of the two edges.
        max_dist_thresh: Maximum neighbour distance for the cosine criterion.
        min_dist_thresh: Neighbour distance at or below which the node is
            always removed.
        verbose: If True, log the run time and node reduction at debug level.

    Returns:
        nx.Graph: The sparsified copy.
    """
    t0 = time.time()
    skel_nx = nx.Graph(g)
    n_nodes_start = skel_nx.number_of_nodes()
    if scal is None:
        scal = global_params.config['scaling']

    def _pos(node):
        return skel_nx.nodes[node]['position']

    def _scaled_offset(src, dst):
        # integer voxel offset (dst - src), scaled afterwards
        return np.array([int(_pos(dst)[ix]) - int(_pos(src)[ix]) for ix in range(3)]) * scal

    def _unit(vec):
        return vec / np.linalg.norm(vec)

    while True:
        n_removed = 0
        candidates = list({node for node, deg in skel_nx.degree() if deg == 2})
        for node in candidates:
            # earlier removals in this pass may have changed the degree
            if skel_nx.degree(node) != 2:
                continue
            left, right = [nb for nb in skel_nx.neighbors(node)]
            cos_angle = np.dot(_unit(_scaled_offset(node, left)),
                               _unit(_scaled_offset(node, right)))
            dist = np.linalg.norm([int(_pos(right)[ix] * scal[ix]) - int(_pos(left)[ix] * scal[ix])
                                   for ix in range(3)])
            is_straight_and_short = abs(cos_angle) > dot_prod_thresh and dist < max_dist_thresh
            if is_straight_and_short or dist <= min_dist_thresh:
                skel_nx.remove_node(node)
                skel_nx.add_edge(left, right)
                n_removed += 1
        if n_removed == 0:
            break
    if verbose:
        log_reps.debug(f'sparsening took {time.time() - t0}. Reduced {n_nodes_start} to '
                       f'{skel_nx.number_of_nodes()} nodes')
    return skel_nx
def create_new_skeleton_sv_fast(args):
    """
    Load the skeleton of one supervoxel as a graph, optionally sparsify it,
    join its connected components and return it as contiguous arrays.

    Multi-processing helper of :func:`from_sso_to_netkx_fast`.

    Args:
        args: Tuple ``(so_id, sparsify)`` with:

            * ``so_id``: ID of the supervoxel (``obj_type="sv"``). Version,
              working directory and config are the ``SegmentationObject``
              defaults.
            * ``sparsify``: If True, reduce the graph with
              :func:`sparsify_skeleton_fast` (default parameters).

    Returns:
        Tuple[np.ndarray, np.ndarray, np.ndarray]: Nodes, diameters and edges
        as returned by :func:`from_netkx_to_arr`. The stored SV diameters are
        not used; the returned diameters are zeros. If the skeleton has more
        than one connected component, the components are joined with
        ``stitch_skel_nx`` first.
    """
    so_id, sparsify = args
    # ``SegmentationObject`` is only imported under ``TYPE_CHECKING`` in this
    # module, so at runtime it must be accessed through ``segmentation``.
    so = segmentation.SegmentationObject(obj_type="sv", obj_id=so_id)
    so.enable_locking = False
    so.load_attr_dict()
    skel = load_skeleton(so)
    # diameters are ignored here; they are populated at the end of
    # ``create_sso_skeleton_fast``
    nodes, edges = skel['nodes'].astype(np.uint32), skel['edges']
    # create nx graph
    skel_nx = nx.Graph()
    skel_nx.add_nodes_from([(ix, dict(position=coord)) for ix, coord in enumerate(nodes)])
    skel_nx.add_edges_from([tuple(ix) for ix in edges])
    if sparsify:
        skel_nx = sparsify_skeleton_fast(skel_nx)
    n_cc = nx.number_connected_components(skel_nx)
    if n_cc > 1:
        log_reps.warning('SV {} contained {} connected components in its skel'
                         'eton representation. Stitching now.'
                         ''.format(so.id, n_cc))
        # make node IDs (and thus edge values) contiguous prior to stitching
        id_map = {node_id: ii for ii, node_id in enumerate(sorted(skel_nx.nodes()))}
        skel_nx_tmp = nx.Graph()
        skel_nx_tmp.add_nodes_from([(id_map[ix], skel_nx.nodes[ix]) for ix in skel_nx.nodes()])
        skel_nx_tmp.add_edges_from([(id_map[u], id_map[v]) for u, v in skel_nx.edges()])
        skel_nx = stitch_skel_nx(skel_nx_tmp)
    nodes, diameters, edges = from_netkx_to_arr(skel_nx)
    return nodes, diameters, edges
def from_sso_to_netkx_fast(sso, sparsify=True, max_edge_length=1.5e3):
    """
    Build the skeleton graph of a cell reconstruction by merging the
    (optionally sparsified) skeletons of its supervoxels.

    SV skeletons are computed with :func:`create_new_skeleton_sv_fast` via
    ``start_multiprocess_imap`` (``sso.nb_cpus``). For every edge of the SV
    graph (``sso.load_sv_graph()``), the closest node pair between the two SV
    skeletons is connected, unless the pair is more than `max_edge_length`
    apart. Remaining connected components are joined with ``stitch_skel_nx``.

    Args:
        sso: The SuperSegmentationObject to process.
        sparsify: Forwarded to :func:`create_new_skeleton_sv_fast`.
        max_edge_length: Maximum length of an edge added between two SV
            skeletons, in the units of ``nodes * sso.scaling`` (presumably nm;
            the debug log reports it divided by 1e3 as um).

    Returns:
        networkx.Graph of the merged skeleton (node attribute 'position'), or
        None if none of the SVs has a skeleton node. In both cases
        ``sso.skeleton`` is set to a dict with 'nodes', 'edges' and
        'diameters'.
    """
    sso.load_attr_dict()
    res = start_multiprocess_imap(create_new_skeleton_sv_fast,
                                  [(sv_id, sparsify) for sv_id in sso.sv_ids],
                                  nb_cpus=sso.nb_cpus, show_progress=False, debug=False)
    nodes, diameters, edges, sv_id_arr = [], [], [], []
    # node offset of every SV skeleton inside the concatenated arrays
    # (first offset is 0, the last length is not needed)
    n_nodes_per_sv = [0] + list(np.cumsum([len(el[0]) for el in res])[:-1])
    for ii, (sv_nodes, sv_diameters, sv_edges) in enumerate(res):
        if len(sv_nodes) == 0:
            # skip missing / empty skeletons, e.g. happens for small SVs
            continue
        nodes.append(sv_nodes)
        diameters.append(sv_diameters)
        edges.append(sv_edges + int(n_nodes_per_sv[ii]))
        # store mapping from node to SV ID
        sv_id_arr.append([sso.sv_ids[ii]] * len(sv_nodes))
    if len(nodes) == 0:
        # No SV has a skeleton. np.concatenate would raise on the empty lists.
        sso.skeleton = {'nodes': np.zeros((0, 3), dtype=np.uint32),
                        'edges': np.zeros((0, 2), dtype=np.uint64),
                        'diameters': np.zeros((0,), dtype=np.float32)}
        return None
    ssv_skel = {'nodes': np.concatenate(nodes),
                'diameters': np.concatenate(diameters, axis=0)}
    sv_id_arr = np.concatenate(sv_id_arr)
    node_ix_arr = np.arange(len(sv_id_arr))
    # stitching
    if len(sso.sv_ids) > 1:
        # TODO: once SV graphs only connect adjacent SVs, bridge SVs without a
        #  skeleton (connect all pairs of their neighbours). This would avoid
        #  splits in the SV graph that require the slow `stitch_skel_nx`
        #  fallback below.
        g = sso.load_sv_graph()
        scaled_nodes = (ssv_skel['nodes'] * sso.scaling).astype(np.float32)
        for e1, e2 in g.edges():
            mask1 = sv_id_arr == e1.id
            mask2 = sv_id_arr == e2.id
            nodes1 = scaled_nodes[mask1]
            nodes2 = scaled_nodes[mask2]
            if len(nodes1) == 0 or len(nodes2) == 0:
                continue  # SV without skeleton
            dists, node_ixs1 = spatial.cKDTree(nodes1).query(nodes2)
            closest = np.argmin(dists)
            # global indices of the closest node pair
            ix2 = node_ix_arr[mask2][closest]
            ix1 = node_ix_arr[mask1][node_ixs1[closest]]
            # Use the KD-tree distance directly. Recomputing the same distance
            # from differently rounded (float32 vs float64) coordinates and
            # comparing the two could spuriously skip valid edges.
            edge_len = dists[closest]
            if edge_len > max_edge_length:
                log_reps.debug(f'Found long edge with length '
                               f'{edge_len / 1e3:.0f} um between SVs '
                               f'{e1.id} and {e2.id} although they were '
                               f'connected within the SV graph. Skipping.')
                # TODO: remove as soon as SV graphs only connect adjacent SVs.
                continue
            edges.append(np.array([[ix1, ix2]], dtype=np.uint32))
    ssv_skel['edges'] = np.concatenate(edges)
    skel_nx = nx.Graph()
    skel_nx.add_nodes_from([(ix, dict(position=coord)) for ix, coord in enumerate(ssv_skel['nodes'])])
    skel_nx.add_edges_from([tuple(ix) for ix in ssv_skel['edges']])
    if nx.number_connected_components(skel_nx) != 1:
        msg = 'Stitching of SV skeletons failed during "from_sso_to_netkx_' \
              'fast" with {} connected components using the underlying SSV ' \
              'agglomeration. Please check the underlying RAG of SSV {}. ' \
              'Now performing a slower stitching method to add missing edg' \
              'es recursively between the closest connected components. ' \
              'This warning might also occur if two supervoxels are connected ' \
              'over supervoxel(s) without skeleton!' \
              ''.format(nx.number_connected_components(skel_nx), sso.id)
        skel_nx = stitch_skel_nx(skel_nx)
        log_reps.warning(msg)
        assert nx.number_connected_components(skel_nx) == 1
    # keep the (n_edges, 2) layout also for a single-node skeleton
    ssv_skel['edges'] = np.array(list(skel_nx.edges()), dtype=np.uint64).reshape(-1, 2)
    sso.skeleton = ssv_skel
    return skel_nx
def create_sso_skeleton_fast(sso, pruning_thresh=800, sparsify=True,
                             max_dist_thresh=600, dot_prod_thresh=0.0,
                             max_dist_thresh_iter2=600):
    """
    Create the skeleton of a SuperSegmentationObject from its SV skeletons.

    The steps are:

    1. Stitch the SV skeletons along the SV graph (``from_sso_to_netkx_fast``,
       using ``sso.nb_cpus`` processes).
    2. Optionally sparsify.
    3. Prune stub branches shorter than ``pruning_thresh``.
    4. Optionally sparsify a second time.
    5. Compute the minimum spanning tree of the L2-weighted graph (nm).
    6. Store it in ``sso`` via ``from_netkx_to_sso``.
    7. Estimate the node radii with ``radius_correction_found_vertices``.

    Args:
        sso: The SuperSegmentationObject to process.
        pruning_thresh: Branches with a path length (nm) below this value are
            removed (see ``prune_stub_branches``).
        sparsify: If True, the SV skeletons and the stitched skeleton are
            sparsified.
        max_dist_thresh: Distance threshold (nm) of the first sparsification
            after stitching.
        dot_prod_thresh: Nodes are pruned during the second sparsification
            when the dot product of their two adjacent edges is above this
            value.
        max_dist_thresh_iter2: Distance threshold (nm) of the second
            sparsification after pruning.

    Returns:
        ``sso`` with an updated ``sso.skeleton`` (MST with radius estimates).
        If none of the SVs has a skeleton, ``sso`` is returned right after
        stitching.
    """
    skel_nx = from_sso_to_netkx_fast(sso, sparsify=sparsify)
    if skel_nx is None:
        # `from_sso_to_netkx_fast` returns None if no SV has a skeleton
        return sso

    if sparsify:
        # sparsify again after stitching (inexpensive)
        # NOTE(review): `min_dist_thresh` is set to `max_dist_thresh` -- confirm
        skel_nx = sparsify_skeleton_fast(skel_nx, max_dist_thresh=max_dist_thresh,
                                         min_dist_thresh=max_dist_thresh)
    # prune short stub branches of the stitched skeleton
    _, skel_nx = prune_stub_branches(nx_g=skel_nx, len_thres=pruning_thresh)
    if sparsify:
        # remove nodes whose adjacent edges have a dot product above
        # `dot_prod_thresh` (e.g. 0.95 corresponds to angles below ~18 deg)
        skel_nx = sparsify_skeleton_fast(
            skel_nx, dot_prod_thresh=dot_prod_thresh,
            max_dist_thresh=max_dist_thresh_iter2)

    start = time.time()
    scaling = sso.scaling
    for n1, n2 in skel_nx.edges:
        # cast before subtracting: unsigned voxel coordinates would wrap around
        pos1 = np.asarray(skel_nx.nodes[n1]['position'], dtype=np.float64)
        pos2 = np.asarray(skel_nx.nodes[n2]['position'], dtype=np.float64)
        skel_nx[n1][n2]['weight'] = np.linalg.norm((pos1 - pos2) * scaling)
    skel_nx = nx.minimum_spanning_tree(skel_nx)
    log_reps.debug(f'mst took {time.time() - start:.0f} s')
    sso = from_netkx_to_sso(sso, skel_nx)
    # reset weighted graph
    sso._weighted_graph = None

    # estimating the radii
    start = time.time()
    sso.skeleton = radius_correction_found_vertices(sso)
    log_reps.debug(f'radius estimation took {time.time() - start:.0f} s')
    return sso
def glia_pred_exists(so):
    """
    Check whether glia probabilities are stored for an object.

    Args:
        so: Object providing ``load_attr_dict()`` and ``attr_dict``
            (e.g. a SegmentationObject).

    Returns:
        True if ``'glia_probas'`` is a key of ``so.attr_dict`` after loading
        the attribute dictionary.
    """
    so.load_attr_dict()
    attributes = so.attr_dict
    return "glia_probas" in attributes
def views2tripletinput(views):
    """
    Convert views into the input format of the triplet network.

    Only the first entry along axis 2 (the first view) is kept. Two all-ones
    arrays of the same shape are appended along that axis, and the result is
    cast to float32.

    Args:
        views: Array with at least three dimensions; axis 2 is the view axis.

    Returns:
        float32 array whose axis 2 has length 3.
    """
    first_view = views[:, :, :1]
    ones = np.ones_like(first_view)
    stacked = np.concatenate((first_view, ones, ones), axis=2)
    return stacked.astype(np.float32)
def get_pca_view_hists(sso, t_net, pca):
    """
    Compute histograms of the first three PCA components of the triplet-net
    latent representation of an SSV's views.

    Args:
        sso: SSV providing ``load_views()``.
        t_net: Model with ``predict_proba``; its output is fed to ``pca``.
        pca: Fitted object with ``transform`` yielding at least 3 components.

    Returns:
        Object array of shape (3, 2). Row ``i`` holds the density-normalized
        histogram values (50 bins) and the 51 bin edges of component ``i``.
    """
    views = sso.load_views()
    latent = t_net.predict_proba(views2tripletinput(views))
    latent = pca.transform(latent)
    # (component index, histogram range) of the first three PCA components
    hist_specs = [(0, (-2, 2)), (1, (-3.2, 3)), (2, (-3.5, 3.5))]
    # counts (50) and edges (51) differ in length -> explicit object array,
    # because np.array on such ragged input raises in NumPy >= 1.24
    hists = np.empty((len(hist_specs), 2), dtype=object)
    for row, (comp, value_range) in enumerate(hist_specs):
        # `density` replaces the `normed` keyword removed in NumPy 1.24; both
        # give the same result for equal-width bins
        counts, edges = np.histogram(latent[:, comp], bins=50,
                                     range=value_range, density=True)
        hists[row, 0] = counts
        hists[row, 1] = edges
    return hists
def save_view_pca_proj(sso, t_net, pca, dest_dir, ls=20, s=6.0,
                       special_points=(), special_markers=(),
                       special_kwargs=()):
    """
    Plot pairwise scatter plots of the first three PCA components of the
    triplet-net latent representation of an SSV's views and save them as PNG.

    One figure per component pair (1-2, 1-3, 2-3) is written to
    ``<dest_dir>/<sso.id>_pca_<a><b>.png``. Points are colored by their
    min-max normalized latent coordinates, with an appended alpha column.
    Presumably ``pca`` therefore yields exactly three components, so that the
    colors are valid RGBA -- TODO confirm.

    Args:
        sso: SSV providing ``load_views()`` and ``id``.
        t_net: Model with ``predict_proba``.
        pca: Fitted object with ``transform``.
        dest_dir: Output directory.
        ls: Font size of tick and axis labels.
        s: Marker size of the latent points.
        special_points: Additional points (arrays in PCA space) to highlight.
        special_markers: Marker per special point; defaults to ``"x"``.
        special_kwargs: Keyword arguments (mapping) used for plotting the
            special points. If empty, a default style is used.
    """
    import matplotlib.pyplot as plt
    import matplotlib.ticker as ticker
    views = sso.load_views()
    latent = t_net.predict_proba(views2tripletinput(views))
    latent = pca.transform(latent)
    lat_min = latent.min(axis=0)
    col = (np.array(latent) - lat_min) / (latent.max(axis=0) - lat_min)
    col = np.concatenate([col, np.ones_like(col)[:, :1]], axis=1)
    # booleans are required: the former "off" strings are truthy and would
    # enable the right/top ticks instead of hiding them
    tick_kwargs = dict(labelsize=ls, direction='out', length=4, width=3,
                       right=False, top=False, pad=10)
    for a, b in ([0, 1], [0, 2], [1, 2]):
        fig, ax = plt.subplots()
        plt.scatter(latent[:, a], latent[:, b], c=col, s=s, lw=0.5,
                    marker="o", edgecolors=col)
        for kk, sp in enumerate(special_points):
            sm = special_markers[kk] if len(special_markers) > 0 else "x"
            if len(special_kwargs) == 0:
                plt.scatter(sp[None, a], sp[None, b], s=75.0, lw=2.3,
                            marker=sm, edgecolor="0.3", facecolor="none")
            else:
                plt.scatter(sp[None, a], sp[None, b], **special_kwargs)
        fig.patch.set_facecolor('white')
        for which in ('major', 'minor'):
            for axis in ('x', 'y'):
                ax.tick_params(axis=axis, which=which, **tick_kwargs)
        plt.xlabel(r"$Z_%d$" % (a + 1), fontsize=ls)
        plt.ylabel(r"$Z_%d$" % (b + 1), fontsize=ls)
        ax.xaxis.set_major_locator(ticker.MultipleLocator(2))
        ax.yaxis.set_major_locator(ticker.MultipleLocator(2))
        plt.tight_layout()
        plt.savefig(os.path.join(dest_dir, "%d_pca_%d%d.png"
                                 % (sso.id, a + 1, b + 1)), dpi=400)
        plt.close(fig)
def skelnode_comment_dict(sso):
    """
    Map the node coordinates of the kzip skeleton at
    ``sso.skeleton_kzip_path`` to the node comments.

    Keys are ``frozenset`` objects of the node coordinate. This means
    coordinates that are permutations of each other, or that contain repeated
    values, can map to the same key; a later node then overwrites an earlier
    one.

    Args:
        sso: Object providing ``skeleton_kzip_path``.

    Returns:
        dict mapping ``frozenset(coordinate)`` to the node comment.
    """
    annotation = load_skeleton_kzip(sso.skeleton_kzip_path)["skeleton"]
    return {frozenset(node.getCoordinate()): node.getComment()
            for node in annotation.getNodes()}
def label_array_for_sso_skel(sso, comment_converter):
    """
    Convert the node comments of the kzip skeleton at
    ``sso.skeleton_kzip_path`` into a label array ordered like
    ``sso.skeleton["nodes"]``.

    Comments are lower-cased before the lookup. Comments not present in
    ``comment_converter`` get label -1.

    Args:
        sso: SuperSegmentationObject; its skeleton is loaded if not present.
        comment_converter: dict mapping (lower-case) node comments to
            integer labels.

    Returns:
        int32 array with one label per skeleton node.

    Raises:
        KeyError: If a skeleton node coordinate has no match among the kzip
            skeleton nodes.
    """
    if sso.skeleton is None:
        sso.load_skeleton()
    comments = skelnode_comment_dict(sso)
    skel_nodes = sso.skeleton["nodes"]
    labels = np.full(len(skel_nodes), -1, dtype=np.int32)
    for node_ix, coord in enumerate(skel_nodes):
        node_comment = comments[frozenset(coord.astype(np.int32))].lower()
        try:
            labels[node_ix] = comment_converter[node_comment]
        except KeyError:
            continue  # comment without label mapping keeps -1
    return labels
def write_axpred_cnn(ssv, pred_key_appendix, dest_path=None, k=1):
    """
    Write the multi-view axoness predictions of an SSV as a colored mesh
    ("axoness.ply") via ``ssv._pred2mesh``.

    If ``"axoness_preds<pred_key_appendix>"`` is not stored for the SSV, the
    per-SV ``axoness_preds`` are collected (with ``ssv.nb_cpus`` processes)
    and concatenated.

    Args:
        ssv: SuperSegmentationObject.
        pred_key_appendix: Appendix of the prediction key.
        dest_path: Output kzip path; defaults to
            ``ssv.skeleton_kzip_path_views``.
        k: Forwarded to ``ssv._pred2mesh``.
    """
    if dest_path is None:
        dest_path = ssv.skeleton_kzip_path_views
    pred_key = "axoness_preds%s" % pred_key_appendix
    if not ssv.attr_exists(pred_key):
        log_reps.info("Couldn't find specified axoness prediction. Falling back "
                      "to default.")
        # per-SV prediction arrays can differ in length -> concatenate the
        # list directly; wrapping it in np.array raises for ragged input in
        # NumPy >= 1.24
        preds = np.concatenate(start_multiprocess_obj(
            "axoness_preds", [[sv, {"pred_key_appendix": pred_key_appendix}]
                              for sv in ssv.svs], nb_cpus=ssv.nb_cpus))
    else:
        preds = ssv.lookup_in_attribute_dict(pred_key)
    log_reps.debug("Collected axoness: {}".format(Counter(preds).most_common()))
    locs = ssv.sample_locations()
    log_reps.debug("Collected locations.")
    pred_coords = np.concatenate(locs)
    assert pred_coords.ndim == 2
    assert pred_coords.shape[1] == 3
    # one RGBA color per predicted class (0, 1, 2)
    colors = np.array(np.array([[0.6, 0.6, 0.6, 1], [0.841, 0.138, 0.133, 1.],
                                [0.32, 0.32, 0.32, 1.]]) * 255, dtype=np.uint32)
    ssv._pred2mesh(pred_coords, preds, "axoness.ply", dest_path=dest_path, k=k,
                   colors=colors)
def cnn_axoness2skel(sso: 'super_segmentation.SuperSegmentationObject',
                     pred_key_appendix: str = "", k: int = 1,
                     force_reload: bool = False,
                     save_skel: bool = True, use_cache: bool = False):
    """
    Map the multi-view axoness predictions and probabilities of an SSV onto
    its skeleton nodes.

    The per-SV ``axoness_preds``/``axoness_probas`` are collected (with
    ``sso.nb_cpus`` processes) and concatenated when
    ``"axoness_preds_cnn<appendix>"`` or ``"axoness_probas_cnn<appendix>"``
    is missing from the SSV attributes, or when ``force_reload`` is set. They
    are then stored in ``sso.attr_dict`` under these keys. Every skeleton node
    receives the prediction and probabilities of its nearest sample location.

    The results are written to the skeleton keys ``"axoness<appendix>"``,
    ``"axoness_probas<appendix>"`` and ``"view_ixs"`` (index of the assigned
    sample location).

    Args:
        sso: The SuperSegmentationObject.
        pred_key_appendix: Appendix of the prediction keys.
        k: Deprecated and unused; a warning is logged if it is not 1.
        force_reload: Recollect the SV predictions even if stored.
        save_skel: Call ``sso.save_skeleton()`` at the end.
        use_cache: Write the collected SV predictions to the SSV attribute
            dict on disk.

    Returns:
        None
    """
    if k != 1:
        log_reps.warning("Parameter 'k' is deprecated but was set to {}. "
                         "It is no longer used in this method.".format(k))
    if sso.skeleton is None:
        sso.load_skeleton()
    proba_key = "axoness_probas_cnn%s" % pred_key_appendix
    pred_key = "axoness_preds_cnn%s" % pred_key_appendix
    if not sso.attr_exists(pred_key) or not sso.attr_exists(proba_key) or \
            force_reload:
        # per-SV arrays can differ in length -> concatenate the lists directly
        # (np.array on ragged input raises in NumPy >= 1.24)
        preds = np.concatenate(start_multiprocess_obj(
            "axoness_preds", [[sv, {"pred_key_appendix": pred_key_appendix}]
                              for sv in sso.svs], nb_cpus=sso.nb_cpus))
        probas = np.concatenate(start_multiprocess_obj(
            "axoness_probas", [[sv, {"pred_key_appendix": pred_key_appendix}]
                               for sv in sso.svs], nb_cpus=sso.nb_cpus))
        sso.attr_dict[proba_key] = probas
        sso.attr_dict[pred_key] = preds
        if use_cache:
            sso.save_attributes([proba_key, pred_key], [probas, preds])
    else:
        preds = sso.lookup_in_attribute_dict(pred_key)
        probas = sso.lookup_in_attribute_dict(proba_key)
    loc_coords = np.concatenate(sso.sample_locations())
    assert len(loc_coords) == len(preds), "Number of view coordinates is " \
                                          "different from number of view " \
                                          "predictions. SSO %d" % sso.id
    node_coords = sso.skeleton["nodes"] * sso.scaling
    # nearest sample location (k=1) of every skeleton node provides its
    # prediction and probabilities
    node_preds = colorcode_vertices(node_coords, loc_coords, preds,
                                    colors=[0, 1, 2], k=1)
    node_probas, ixs = assign_rep_values(node_coords, loc_coords, probas,
                                         return_ixs=True)
    # valid indices into `loc_coords` are 0 ... len(loc_coords) - 1
    assert np.max(ixs) < len(loc_coords), "Maximum index for sample " \
                                          "coordinates exceeds the number " \
                                          "of sample coordinates."
    sso.skeleton["axoness%s" % pred_key_appendix] = node_preds
    sso.skeleton["axoness_probas%s" % pred_key_appendix] = node_probas
    sso.skeleton["view_ixs"] = ixs
    if save_skel:
        sso.save_skeleton()
def average_node_axoness_views(sso: 'super_segmentation.SuperSegmentationObject',
                               pred_key: Optional[str] = None,
                               pred_key_appendix: str = "",
                               max_dist: int = 10000, return_res: bool = False,
                               use_cache: bool = False):
    """
    Average the axoness predictions along the skeleton with a sliding window
    of path length ``max_dist``.

    Every skeleton node gets the majority prediction of the sample locations
    (``sso.skeleton["view_ixs"]``) assigned to all nodes within path length
    ``max_dist`` on ``sso.weighted_graph()``. If ``"view_ixs"`` is missing,
    ``cnn_axoness2skel`` is run first.

    Args:
        sso: The SuperSegmentationObject.
        pred_key: Attribute key of the per-location predictions. Defaults to
            ``"axoness_preds_cnn%s" % pred_key_appendix``.
        pred_key_appendix: Appendix for the default key. It must not be
            combined with ``pred_key``.
        max_dist: Maximum path length of the sliding window.
        return_res: Return the result instead of storing it in the skeleton.
        use_cache: Forwarded to ``cnn_axoness2skel`` and used to save
            collected SV predictions to disk.

    Returns:
        If ``return_res`` is True, a list with the averaged prediction per
        skeleton node. Otherwise None; the result is stored under
        ``"axoness%s_avg%d" % (pred_key_appendix, max_dist)`` in
        ``sso.skeleton`` without calling ``sso.save_skeleton()``. Also
        returns None (and logs an error) if the skeleton has no edges.

    Raises:
        ValueError: If both ``pred_key`` and ``pred_key_appendix`` are given
            or ``pred_key`` is not a string.
    """
    if sso.skeleton is None:
        sso.load_skeleton()
    if len(sso.skeleton["edges"]) == 0:
        log_reps.error("Zero edges in skeleton of SSV %d. "
                       "Skipping averaging." % sso.id)
        return
    if pred_key is None:
        pred_key = "axoness_preds_cnn%s" % pred_key_appendix
    elif len(pred_key_appendix) > 0:
        raise ValueError("Only one of the two may be given: 'pred_key' or "
                         "'pred_key_appendix', but not both.")
    if not isinstance(pred_key, str):
        raise ValueError("'pred_key' has to be of type str.")

    if not sso.attr_exists(pred_key) and ("axoness_preds_cnn" not in pred_key):
        # only reachable with an explicit custom `pred_key`, i.e.
        # `pred_key_appendix` is empty here
        log_reps.warning("Couldn't find specified axoness prediction. Falling"
                         " back to default (-> per SV stored multi-view "
                         "prediction including SSV context")
        # per-SV arrays can differ in length -> concatenate the list directly
        preds = np.concatenate(start_multiprocess_obj(
            "axoness_preds", [[sv, {"pred_key_appendix": pred_key_appendix}]
                              for sv in sso.svs], nb_cpus=sso.nb_cpus))
        sso.attr_dict[pred_key] = preds
        if use_cache:
            sso.save_attributes([pred_key], [preds])
    else:
        preds = sso.lookup_in_attribute_dict(pred_key)
    loc_coords = np.concatenate(sso.sample_locations())
    assert len(loc_coords) == len(preds), "Number of view coordinates is " \
                                          "different from number of view " \
                                          "predictions. SSO %d" % sso.id
    if "view_ixs" not in sso.skeleton:
        log_reps.info("View indices were not yet assigned to skeleton nodes. "
                      "Running now 'cnn_axoness2skel(sso, "
                      "pred_key_appendix=pred_key_appendix, k=1)'")
        cnn_axoness2skel(sso, pred_key_appendix=pred_key_appendix, k=1,
                         save_skel=not return_res, use_cache=use_cache)
    view_ixs = np.array(sso.skeleton["view_ixs"])
    preds = np.asarray(preds)
    avg_pred = []
    g = sso.weighted_graph()
    for n in range(g.number_of_nodes()):
        # all nodes within path length `max_dist` (path lists are not needed)
        dists = nx.single_source_dijkstra_path_length(g, n, cutoff=max_dist)
        neighs = np.array(list(dists.keys()), dtype=np.int64)
        # count every sample location only once
        unique_view_ixs = np.unique(view_ixs[neighs])
        cls, cnts = np.unique(preds[unique_view_ixs], return_counts=True)
        avg_pred.append(cls[np.argmax(cnts)])
    if return_res:
        return avg_pred
    sso.skeleton["axoness%s_avg%d" % (pred_key_appendix, max_dist)] = avg_pred
def majority_vote_compartments(sso: 'SuperSegmentationObject',
                               ax_pred_key: str = 'axoness'):
    """
    Relabel every soma-free connected component of the skeleton with its
    majority compartment prediction.

    Nodes with label 2 are removed from ``sso.weighted_graph()`` before the
    connected components are computed; those nodes keep their label. If the
    majority label of a component is 1 but it holds less than 66 % of the
    component's nodes, label 0 is used instead (bias towards dendrite).

    The result is stored as ``sso.skeleton[ax_pred_key + "_comp_maj"]``
    (float array) and then persisted with ``sso.save_skeleton()``.

    Args:
        sso: The SuperSegmentationObject whose skeleton is processed.
        ax_pred_key: Skeleton key of the compartment predictions.

    Returns:
        None
    """
    g = sso.weighted_graph(add_node_attr=(ax_pred_key,))
    soma_free_g = g.copy()
    # split the graph at the soma (label 2)
    soma_free_g.remove_nodes_from(
        [n for n, d in g.nodes(data=True) if d[ax_pred_key] == 2])
    new_axoness_dc = nx.get_node_attributes(g, ax_pred_key)
    for cc_nodes in nx.connected_components(soma_free_g):
        preds = [g.nodes[n][ax_pred_key] for n in cc_nodes]
        cls, cnts = np.unique(preds, return_counts=True)
        majority = cls[np.argmax(cnts)]
        probas = np.array(cnts, dtype=np.float32) / np.sum(cnts)
        # positively bias dendrite assignment
        if (majority == 1) and (probas[cls == 1] < 0.66):
            majority = 0
        for n in cc_nodes:
            new_axoness_dc[n] = majority
    nx.set_node_attributes(g, new_axoness_dc, ax_pred_key)
    new_axoness_arr = np.zeros((len(sso.skeleton["nodes"])))
    for n, d in g.nodes(data=True):
        new_axoness_arr[n] = d[ax_pred_key]
    sso.skeleton[ax_pred_key + "_comp_maj"] = new_axoness_arr
    sso.save_skeleton()
def majorityvote_skeleton_property(sso: 'super_segmentation.SuperSegmentationObject',
                                   prop_key: str, max_dist: int = 10000,
                                   return_res: bool = False
                                   ) -> Optional[np.ndarray]:
    """
    Apply a sliding-window majority vote to a skeleton node property.

    Every node gets the most frequent value of ``sso.skeleton[prop_key]``
    among all nodes within path length ``max_dist`` on the L2-distance
    weighted skeleton graph (``sso.weighted_graph()``).

    Args:
        sso: The SuperSegmentationObject (cell reconstruction).
        prop_key: Skeleton key of the node property (array or list).
        max_dist: Maximum path length of the sliding window.
        return_res: Return the result instead of storing it.

    Returns:
        If ``return_res`` is True, an array with the majority value per node.
        Otherwise None; the result is stored as
        ``sso.skeleton["%s_avg%d" % (prop_key, max_dist)]``.
        ``sso.save_skeleton()`` is never called.

    Raises:
        ValueError: If ``prop_key`` does not exist in the skeleton.
    """
    if prop_key not in sso.skeleton:
        raise ValueError(f'Given property "{prop_key}" does not exist in '
                         f'skeleton of SSV {sso.id}.')
    # array conversion allows fancy indexing also for list-valued properties
    prop_arr = np.asarray(sso.skeleton[prop_key])
    g = sso.weighted_graph()
    avg_prop = []
    for n in range(g.number_of_nodes()):
        # all nodes within path length `max_dist` (path lists are not needed)
        dists = nx.single_source_dijkstra_path_length(g, n, cutoff=max_dist)
        neighs = np.array(list(dists.keys()), dtype=np.int64)
        prop_vals, cnts = np.unique(prop_arr[neighs], return_counts=True)
        avg_prop.append(prop_vals[np.argmax(cnts)])
    avg_prop = np.array(avg_prop)
    if return_res:
        return avg_prop
    sso.skeleton["%s_avg%d" % (prop_key, max_dist)] = avg_prop
def find_incomplete_ssv_views(ssd: 'SuperSegmentationDataset', woglia: bool,
                              n_cores: Optional[int] = None):
    """
    Return the IDs of SSVs that contain at least one SV with missing views.

    Args:
        ssd: The SuperSegmentationDataset to check.
        woglia: Forwarded to ``find_missing_sv_views`` (views after glia
            removal or not).
        n_cores: Number of cores; defaults to
            ``global_params.config['ncores_per_node']``.

    Returns:
        List of SSV IDs. SVs that are not part of ``ssd`` are ignored.
    """
    if n_cores is None:
        n_cores = global_params.config['ncores_per_node']
    sv_dataset = ssd.get_segmentationdataset("sv")
    incomplete_sv_ids = find_missing_sv_views(sv_dataset, woglia, n_cores)
    sv2ssv = ssd.sv2ssv_ids(incomplete_sv_ids)
    ssv_ids = set()
    for sv_id in incomplete_sv_ids:
        try:
            ssv_ids.add(sv2ssv[sv_id])
        except KeyError:
            continue  # sv does not exist in this SSD
    return list(ssv_ids)
def find_incomplete_ssv_skeletons(ssd: 'SuperSegmentationDataset',
                                  n_cores: Optional[int] = None):
    """
    Return the IDs of SSVs that contain at least one SV with a missing
    skeleton.

    Args:
        ssd: The SuperSegmentationDataset to check.
        n_cores: Number of cores; defaults to
            ``global_params.config['ncores_per_node']``.

    Returns:
        List of SSV IDs (empty if ``ssd`` has no SSVs). SVs that are not part
        of ``ssd`` are ignored.
    """
    if n_cores is None:
        n_cores = global_params.config['ncores_per_node']
    sv_lists = [list(ssv.svs) for ssv in ssd.ssvs]
    if len(sv_lists) == 0:
        # np.concatenate raises on an empty sequence
        return []
    svs = np.concatenate(sv_lists)
    incomplete_sv_ids = find_missing_sv_skeletons(svs, n_cores)
    missing_ssv_ids = set()
    incomplete_ssv_ids = ssd.sv2ssv_ids(incomplete_sv_ids)
    for sv_id in incomplete_sv_ids:
        try:
            missing_ssv_ids.add(incomplete_ssv_ids[sv_id])
        except KeyError:
            pass  # sv does not exist in this SSD
    return list(missing_ssv_ids)
def find_missing_sv_attributes_in_ssv(ssd, attr_key, n_cores: Optional[int] = None):
    """
    Return the IDs of SSVs that contain at least one SV without the attribute
    ``attr_key``.

    Args:
        ssd: The SuperSegmentationDataset to check.
        attr_key: Attribute key to look for.
        n_cores: Number of cores; defaults to
            ``global_params.config['ncores_per_node']``.

    Returns:
        List of SSV IDs. SVs that are not part of ``ssd`` are ignored.
    """
    if n_cores is None:
        n_cores = global_params.config['ncores_per_node']
    sv_dataset = ssd.get_segmentationdataset("sv")
    incomplete_sv_ids = find_missing_sv_attributes(sv_dataset, attr_key, n_cores)
    sv2ssv = ssd.sv2ssv_ids(incomplete_sv_ids)
    ssv_ids = set()
    for sv_id in incomplete_sv_ids:
        try:
            ssv_ids.add(sv2ssv[sv_id])
        except KeyError:
            continue  # sv does not exist in this SSD
    return list(ssv_ids)
def predict_views_semseg(views, model, batch_size=10, verbose=False):
    """
    Predict semantic segmentation labels for every 2D projection in ``views``.

    The views are scaled to [0, 1]. Every projection (location x view) is
    predicted as a separate sample via
    ``model.predict_proba(..., bs=batch_size, verbose=verbose)``, and the
    argmax over the class axis (axis 1 of the model output) is taken.

    Args:
        views: uint8 array (0...255) of shape [N_LOCS, N_CH, N_VIEWS, X, Y].
        model: Model providing ``predict_proba``.
        batch_size: Batch size forwarded to the model.
        verbose: Forwarded to the model.

    Returns:
        Label array of shape [N_LOCS, 1, N_VIEWS, ...], where the trailing
        dimensions are the spatial dimensions of the model output.
    """
    scaled = views.astype(np.float32) / 255.
    # [N_LOCS, N_VIEWS, N_CH, X, Y] -> every projection becomes one sample
    per_view = scaled.swapaxes(1, 2)
    n_locs, n_views = per_view.shape[:2]
    samples = per_view.reshape((-1,) + per_view.shape[2:])
    probas = model.predict_proba(samples, bs=batch_size, verbose=verbose)
    labels = np.argmax(probas, axis=1)[:, None]
    labels = labels.reshape((n_locs, n_views) + labels.shape[1:])
    # back to [N_LOCS, 1, N_VIEWS, ...]
    return labels.swapaxes(2, 1)
def pred_svs_semseg(model, views, pred_key=None, svs=None, return_pred=False,
                    nb_cpus=1, verbose=False, bs: int = 10):
    """
    Predict the semantic segmentation of the views of several SVs at once.

    All views are concatenated and predicted in one call of
    ``predict_views_semseg``, then split back per SV. The results are either
    returned or saved with ``SegmentationObject.save_views``
    (``view_key=pred_key``, ``woglia=True``, ``index_views=False``) using
    ``nb_cpus`` processes.

    Args:
        model: Model used by ``predict_views_semseg``.
        views: List with one uint8 array (0...255) of shape
            [N_LOCS, N_CH, N_VIEWS, X, Y] per SV.
        pred_key: View key used for saving; required if ``return_pred`` is
            False.
        svs: SegmentationObjects matching ``views``; required if
            ``return_pred`` is False.
        return_pred: Return the label views instead of saving them.
        nb_cpus: Number of processes used for saving.
        verbose: Forwarded to ``predict_views_semseg``.
        bs: Batch size used during inference.

    Returns:
        List with the label views per SV if ``return_pred`` is True,
        otherwise None.

    Raises:
        ValueError: If results should be saved but ``svs`` or ``pred_key``
            is missing.
    """
    if not return_pred and (svs is None or pred_key is None):
        raise ValueError('SV objects and "pred_key" have to be given if'
                         ' predictions should be saved at SV view storages.')
    offsets = np.cumsum([0] + [len(v) for v in views])
    assert len(offsets) == len(views) + 1
    # merge axis 0, i.e. N_SV and N_LOCS to N_SV*N_LOCS, and predict at once
    label_views = predict_views_semseg(np.concatenate(views), model,
                                       verbose=verbose, batch_size=bs)
    svs_labelviews = [label_views[start:stop]
                      for start, stop in zip(offsets[:-1], offsets[1:])]
    assert len(offsets) == len(svs_labelviews) + 1
    if return_pred:
        return svs_labelviews
    save_params = [[sv, dict(views=sv_lviews, index_views=False, woglia=True,
                             view_key=pred_key)]
                   for sv, sv_lviews in zip(svs, svs_labelviews)]
    start_multiprocess_obj('save_views', save_params, nb_cpus=nb_cpus)
def pred_sv_chunk_semseg(args):
    """
    Predict the semantic segmentation of the 2D projections of all SVs stored
    in a list of view-storage chunks.

    For every chunk path ``p``, the views in ``p + "/views_woglia.pkl"`` (or
    ``p + "/views.pkl"`` if ``woglia`` is False) are predicted with
    ``get_semseg_spiness_model()``. The label views are written to the view
    storage of ``pred_key`` (one storage per chunk).

    Args:
        args: Tuple ``(so_chunk_paths, so_kwargs, pred_kwargs)``.
            ``so_kwargs`` is forwarded to ``sos_dict_fact``. ``pred_kwargs``
            must contain ``'pred_key'`` and may contain ``'woglia'`` (default
            True) and ``'raw_only'`` (default False: use all view channels).
            ``pred_kwargs`` is not modified.

    Returns:
        None. The predictions are stored in the SV view storages.
    """
    from syconn.proc.sd_proc import sos_dict_fact, init_sos
    from syconn.backend.storage import CompressedStorage
    from syconn.handler.prediction import get_semseg_spiness_model
    so_chunk_paths = args[0]
    so_kwargs = args[1]
    # work on a copy so that the caller's dict is not mutated
    pred_kwargs = dict(args[2])
    # by default use views after glia removal
    woglia = pred_kwargs.pop("woglia", True)
    raw_only = pred_kwargs.pop("raw_only", False)
    pred_key = pred_kwargs["pred_key"]
    model = get_semseg_spiness_model()
    for p in so_chunk_paths:
        # get raw views
        view_dc_p = p + "/views_woglia.pkl" if woglia else p + "/views.pkl"
        view_dc = CompressedStorage(view_dc_p, disable_locking=True)
        svixs = list(view_dc.keys())
        if len(svixs) == 0:
            continue
        views = list(view_dc.values())
        if raw_only:
            # keep only channel 0 of every SV's views; `views` is a list, so
            # this has to be done per SV
            views = [v[:, :1] for v in views]
        sd = sos_dict_fact(svixs, **so_kwargs)
        svs = init_sos(sd)
        label_views = pred_svs_semseg(model, views, pred_key=pred_key, svs=svs,
                                      return_pred=True, verbose=False)
        # choose any SV to get a path constructor for the view storage (it is
        # the same for all SVs of this chunk)
        lview_dc_p = svs[0].view_path(woglia, view_key=pred_key)
        label_vd = CompressedStorage(lview_dc_p, disable_locking=True)
        for sv, sv_label_views in zip(svs, label_views):
            label_vd[sv.id] = sv_label_views
        label_vd.push()
def gliapred_sso_nocache(sso: 'SuperSegmentationObject', model,
                         verbose: bool = True):
    """
    Run the multi-view glia inference of an SSV without caching views on disk.

    Views are rendered at ``sso.sample_locations(cache=False)`` (without cell
    objects), split per SV and predicted with ``predict_views``. The
    probabilities are stored as ``'glia_probas'`` in the attribute dict of
    every SV, e.g. ``sso.svs[idx].attr_dict['glia_probas']``. They are not
    written to disk here.

    Args:
        sso: SuperSegmentationObject with ``version == 'tmp'``.
        model: Model used by ``predict_views``.
        verbose: Forwarded to rendering and prediction.

    Returns:
        None
    """
    pred_key = "glia_probas"
    assert sso.version == 'tmp', 'Only use this method with ssv.version="tmp".'
    locs_per_sv = sso.sample_locations(cache=False)
    # boundaries of every SV's locations in the flattened location list
    bounds = np.cumsum([0] + [len(locs) for locs in locs_per_sv])
    rendered = render_sso_coords(sso, np.array(flatten_list(locs_per_sv)),
                                 verbose=verbose, add_cellobjects=False,
                                 return_rot_mat=False)
    sv_views = [rendered[bounds[ix]:bounds[ix + 1]]
                for ix in range(len(sso.svs))]
    del rendered
    probas = predict_views(model, sv_views, None, return_proba=True,
                           pred_key=pred_key, nb_cpus=sso.nb_cpus,
                           verbose=verbose)
    for ix, sv_proba in enumerate(probas):
        sso.svs[ix].attr_dict[pred_key] = sv_proba
@jit(nopython=True)
def semseg2mesh_counter(index_arr: np.ndarray, label_arr: np.ndarray,
                        bg_label: int, count_arr: np.ndarray) -> np.ndarray:
    """
    Count, for every vertex, how often each label occurs in ``label_arr``.

    ``index_arr`` and ``label_arr`` are aligned flat arrays: entry ``ii``
    states that the pixel showing vertex ``index_arr[ii]`` was assigned label
    ``label_arr[ii]``.

    Args:
        index_arr: Flat array of vertex indices, aligned with ``label_arr``.
        label_arr: Flat array of labels, aligned with ``index_arr``. Every
            label must be a valid column index of ``count_arr``.
        bg_label: Value in ``index_arr`` that marks pixels without vertex
            (background); these entries are skipped. Despite its name, it is
            compared against the vertex index, not against the label.
        count_arr: Zero-initialized array of shape (M, L), where M covers all
            vertex indices and L > max(label_arr). It is updated in place.
            Narrow integer dtypes (e.g. uint8) can overflow.

    Returns:
        ``count_arr`` with the accumulated per-vertex label counts.
    """
    for ii in range(len(index_arr)):
        vertex_ix = index_arr[ii]
        if vertex_ix == bg_label:
            continue  # pixel shows no vertex
        label = label_arr[ii]
        count_arr[vertex_ix][label] += 1
    return count_arr
def semseg2mesh(sso, semseg_key, nb_views=None, dest_path=None, k=1,
                colors=None, force_recompute=False, index_view_key=None):
    """
    Map semantic segmentation predictions of the views of an SSV onto its mesh
    vertices and optionally write the colored mesh to a kzip.

    The per-vertex labels are cached in ``sso.label_dict('vertex')`` under
    ``semseg_key`` and reused unless ``force_recompute`` is set.

    Args:
        sso: The SuperSegmentationObject.
        semseg_key: Key of the views holding the semantic segmentation.
        nb_views: Number of views used for the prediction. It selects the
            index views ``"index<nb_views>"`` if ``index_view_key`` is not
            set. If both are None, the default index views are used.
        dest_path: If given, the colored mesh is written to this kzip as
            ``<semseg_key>.ply`` and None is returned.
        k: Number of nearest predicted vertices used to label every vertex.
            With ``k=0``, unpredicted vertices get the label
            ``max(label) + 1`` and background-labeled vertices are kept.
        colors: Array-like mapping labels to RGBA colors (0...255). If None,
            the labels are returned instead of colors. For ``k=0``, include a
            color for unpredicted vertices, e.g. for spine prediction
            ``[neck, head, shaft, other, background, unpredicted]``;
            otherwise ``[neck, head, shaft, other, background]``.
        force_recompute: Recompute the vertex labels even if cached.
        index_view_key: Key of the index views; overrides ``nb_views``.

    Notes:
        ``k>0`` should only be used if a prediction for all vertices is
        required; filtering background and unpredicted vertices is cheaper.

    Returns:
        If ``dest_path`` is None, ``(indices, vertices, normals, col)`` of the
        mesh, where ``col`` are the colors or, if ``colors`` is None, the
        labels. Otherwise None.

    Raises:
        ValueError: If the unpredicted label does not fit into uint8.
    """
    ld = sso.label_dict('vertex')
    if force_recompute or semseg_key not in ld:
        # view loading
        if nb_views is None and index_view_key is None:
            # load default index views
            i_views = sso.load_views(index_views=True).flatten()
        else:
            if index_view_key is None:
                index_view_key = "index{}".format(nb_views)
            # load special index views
            i_views = sso.load_views(index_view_key).flatten()
        semseg_views = sso.load_views(semseg_key).flatten()
        # the highest index-view value marks pixels without vertex
        # TODO: this will fail if no single pixel in all views is background
        background_id = np.max(i_views)
        # background label is the highest label in the prediction (see
        # 'generate_palette' or 'remap_rgb_labelviews' in multiviews.py).
        # Python int: `+ 1` on a uint8 scalar could wrap around to 0.
        background_l = int(np.max(semseg_views))
        unpredicted_l = background_l + 1
        if unpredicted_l > 255:
            raise ValueError('Overflow in label view array.')
        n_vertices = len(sso.mesh[1]) // 3
        # uint32 counters: a vertex can be visible in more than 255 pixels,
        # which would overflow uint8 counters and corrupt the argmax
        count_arr = np.zeros((n_vertices, background_l + 1), dtype=np.uint32)
        count_arr = semseg2mesh_counter(i_views, semseg_views, background_id,
                                        count_arr)
        # np.argmax returns int64 array.. `colorcode_vertices` complexity is
        # sensitive to the datatype of vertex_labels!
        vertex_labels = np.argmax(count_arr, axis=1).astype(np.uint8)
        vertex_labels[np.sum(count_arr, axis=1) == 0] = unpredicted_l
        vertices = sso.mesh[1].reshape(-1, 3)
        if k == 0:
            # map actual prediction situation / coverage: keep unpredicted
            # vertices and vertices with background labels
            predicted_vertices = vertices
            predictions = vertex_labels
        else:
            # remove unpredicted vertices and the background class
            keep = (vertex_labels != unpredicted_l) & \
                   (vertex_labels != background_l)
            predicted_vertices = vertices[keep]
            predictions = vertex_labels[keep]
        if k > 0:
            # High time complexity! map predictions of predicted vertices to
            # all vertices
            maj_vote = colorcode_vertices(
                vertices, predicted_vertices, predictions, k=k,
                return_color=False, nb_cpus=sso.nb_cpus)
        else:
            # no vertex mask was applied in this case
            maj_vote = predictions
        # add prediction to mesh storage
        ld[semseg_key] = maj_vote
        ld.push()
    else:
        maj_vote = ld[semseg_key].astype(np.int32)
    if colors is not None:
        col = np.asarray(colors)[maj_vote].astype(np.uint8)
        if np.sum(col) == 0:
            log_reps.warning('All colors-zero warning during "semseg2mesh"'
                             ' of SSO {}. Make sure color values have uint8 range '
                             '0...255'.format(sso.id))
    else:
        col = maj_vote
    if dest_path is not None:
        if colors is None:
            # write_mesh2kzip only supports RGBA colors and no labels
            col = None
        write_mesh2kzip(dest_path, sso.mesh[0], sso.mesh[1], sso.mesh[2],
                        col, ply_fname=semseg_key + ".ply")
        return
    return sso.mesh[0], sso.mesh[1], sso.mesh[2], col
def celltype_of_sso_nocache(sso, model, ws, nb_views, comp_window,
                            nb_views_model: int = 20,
                            pred_key_appendix: str = "",
                            verbose: bool = False, overwrite: bool = True,
                            use_syntype: bool = True,
                            da_equals_tan: bool = True, n_classes: int = 7,
                            save_to_attr_dict: bool = True):
    """
    Predict the cell type of a cell reconstruction without file-system view caching.

    Raw views are rendered in memory at locations generated from the mesh
    vertices (about three per `comp_window`), stored in ``sso.view_dict`` and
    classified with `model`. The majority vote over all view predictions is the
    cell type. Prediction, per-view probabilities and certainty are written to
    ``sso.attr_dict`` under ``"celltype_cnn_e3" + pred_key_appendix`` and the
    ``_probas`` / ``_certainty`` suffixed keys.

    Args:
        sso: SuperSegmentationObject to be processed; ``sso.view_caching`` must
            be True.
        model: Model exposing ``predict_proba(inp, bs=...)``.
        ws: Window size in pixels [y, x] of each view.
        nb_views: Number of views rendered at each location.
        comp_window: Physical extent in nm of the view-window along y.
        nb_views_model: Bootstrap sample size of view locations used as model
            input.
        pred_key_appendix: Appendix for the prediction key and the temporary
            view key.
        verbose: If True, adds progress bars / debug logging.
        overwrite: If False and the prediction key already exists, nothing is
            done. If True, views are re-rendered even if cached in memory.
        use_syntype: If True, the symmetric-synapse ratios of compartment types
            1 and 0 are passed to the model as additional input.
        da_equals_tan: If True, the probabilities at index 6 are added to
            index 1 and column 6 is removed; requires `n_classes` to be 7.
        n_classes: Number of classes after the optional DA/TAN merge.
        save_to_attr_dict: If True, the three result keys are persisted via
            ``sso.save_attributes``. The in-memory ``sso.attr_dict`` is updated
            regardless.

    Returns:
        None
    """
    sso.load_attr_dict()
    pred_key = "celltype_cnn_e3" + pred_key_appendix
    if not overwrite and pred_key in sso.attr_dict:
        return

    render_kwargs = dict(ws=ws, comp_window=comp_window, nb_views=nb_views,
                         verbose=verbose, add_cellobjects=True,
                         return_rot_mat=False)
    mesh_vertices = sso.mesh[1].reshape(-1, 3)
    # roughly three views per comp window
    locs = generate_rendering_locs(mesh_vertices, comp_window / 3)
    # replaces the default rendering locations of this (in-memory) object
    # NOTE(review): sibling functions wrap the locations in an extra axis/list; confirm this is intended
    sso._sample_locations = locs

    # views are only cached in memory, never on the file system
    assert sso.view_caching, "'view_caching' of {} has to be True in order to" \
                             " run 'celltype_of_sso_nocache'.".format(sso)
    view_key = 'tmp_views' + pred_key_appendix
    if view_key not in sso.view_dict or overwrite:
        # view array shape: N, 4, nb_views, y, x; read back by `sso_views_to_modelinput`
        sso.view_dict[view_key] = render_sso_coords(sso, locs, **render_kwargs)

    from ..handler.prediction import naive_view_normalization_new
    if verbose:
        log_reps.debug('Finished rendering. Starting cell type prediction.')
    model_input = sso_views_to_modelinput(sso, nb_views_model, view_key=view_key)
    model_input = naive_view_normalization_new(model_input)

    if not use_syntype:
        probas = model.predict_proba(model_input, bs=40)
    else:
        ratio_pair = [syn_sign_ratio_celltype(sso, comp_types=[1, ]),
                      syn_sign_ratio_celltype(sso, comp_types=[0, ])]
        synsign_ratio = np.array([ratio_pair] * len(model_input))
        probas = model.predict_proba((model_input, synsign_ratio), bs=40)
    if verbose:
        log_reps.debug('Finished prediction.')

    # DA and TAN are type modulatory; if this changes, also change
    # `certainty_celltype` and `predict_sso_celltype`.
    if da_equals_tan:
        assert n_classes == 7
        # accumulate DA and TAN evidence in column 1, then drop the TAN column;
        # afterwards INT sits at index 6
        probas[:, 1] += probas[:, 6]
        probas = np.delete(probas, [6], axis=1)

    view_labels = np.argmax(probas, axis=1)
    if np.max(view_labels) >= n_classes:
        raise ValueError('Unknown cell type predicted.')
    # majority vote over all view predictions (first class wins ties)
    pred = np.argmax(np.bincount(view_labels, minlength=n_classes))

    probas_key = f"{pred_key}_probas"
    certainty_key = f"{pred_key}_certainty"
    sso.attr_dict[pred_key] = pred
    sso.attr_dict[probas_key] = probas
    cert = sso.certainty_celltype(pred_key)
    sso.attr_dict[certainty_key] = cert
    if save_to_attr_dict:
        sso.save_attributes([pred_key, probas_key, certainty_key],
                            [pred, probas, cert])
def view_embedding_of_sso_nocache(sso: 'SuperSegmentationObject', model: 'torch.nn.Module',
                                  ws: Tuple[int, int], nb_views: int,
                                  comp_window: Union[int, float],
                                  pred_key_appendix: str = "", verbose: bool = False,
                                  overwrite: bool = True,
                                  add_cellobjects: Union[bool, Iterable] = True):
    """
    Renders views and predicts the view embedding of a SuperSegmentationObject
    without file-system caching.

    Raw views are rendered at rendering locations determined by `comp_window`
    (about three per window) and kept only in ``sso.view_dict``. They are
    embedded with `model`, and each skeleton node receives the latent vector
    of its nearest rendering location. The `predict_views_embedding` method in
    `super_segmentation_object` is an alternative that uses file-system caching.

    The result is stored in ``sso.skeleton["latent_morph" + pred_key_appendix]``
    and the skeleton is saved.

    Args:
        sso: SuperSegmentationObject to process; ``sso.view_caching`` must be True.
        model: Model exposing ``predict_proba(inp)`` that returns five outputs,
            the third of which is used as the latent embedding.
        ws: Window size in pixels [y, x].
        nb_views: Number of views rendered at each rendering location.
        comp_window: Physical extent in nm of the view-window along y.
        pred_key_appendix: Appendix for the skeleton key and the temporary
            view key.
        verbose: If True, adds progress bars for view generation.
        overwrite: If True, views are re-rendered even if present in the
            in-memory view dictionary.
        add_cellobjects: Whether to add cell objects during rendering. Accepts a
            bool or a list of structures; passed to `render_sso_coords` when
            views are (re-)rendered.

    Returns:
        None
    """
    pred_key = "latent_morph" + pred_key_appendix
    view_kwargs = dict(ws=ws, comp_window=comp_window, nb_views=nb_views,
                       verbose=verbose, add_cellobjects=add_cellobjects,
                       return_rot_mat=False)
    # this cache is only in-memory, and not file system cache
    assert sso.view_caching, "'view_caching' of {} has to be True in order to" \
                             " run 'view_embedding_of_sso_nocache'.".format(sso)
    tmp_view_key = 'tmp_views' + pred_key_appendix
    if tmp_view_key not in sso.view_dict or overwrite:
        verts = sso.mesh[1].reshape(-1, 3)
        # ~3 views per comp window
        rendering_locs = generate_rendering_locs(verts, comp_window / 3)
        # overwrite default rendering locations (used later on for the view generation);
        # requires auxiliary axis
        sso._sample_locations = rendering_locs[None, ]
        # views shape: N, 4, nb_views, y, x
        views = render_sso_coords(sso, rendering_locs, **view_kwargs)
        sso.view_dict[tmp_view_key] = views  # required for `sso_views_to_modelinput`
    else:
        views = sso.view_dict[tmp_view_key]
    from ..handler.prediction import naive_view_normalization_new
    views = naive_view_normalization_new(views)
    # The inference with TNets could be optimized by splitting the views into
    # three equally sized parts; for now only the first view is used.
    first_view = views[:, :, 0]
    inp = (first_view, np.zeros_like(first_view), np.zeros_like(first_view))
    # model returns: dist1, dist2, inp1 latent, inp2 latent, inp3 latent
    _, _, latent, _, _ = model.predict_proba(inp)
    # map latent vecs at rendering locs to skeleton node locations via nearest neighbor
    sso.load_skeleton()
    # view location ordering same as views / latent
    hull_tree = spatial.cKDTree(np.concatenate(sso.sample_locations()))
    # `workers` replaces `n_jobs`, which was removed in SciPy 1.9
    _, ixs = hull_tree.query(sso.skeleton["nodes"] * sso.scaling, k=1,
                             workers=sso.nb_cpus)
    sso.skeleton[pred_key] = latent[ixs]
    sso.save_skeleton()
def semseg_of_sso_nocache(sso, model, semseg_key: str, ws: Tuple[int, int],
                          nb_views: int, comp_window: float, k: int = 1,
                          dest_path: Optional[str] = None, verbose: bool = False,
                          add_cellobjects: Union[bool, Iterable] = True,
                          bs: int = 10):
    """
    Run semantic segmentation of a cell reconstruction with in-memory views only.

    Raw and index views are rendered at the object's sample locations with the
    given view properties and are never written to the file system. The raw
    views are predicted with `model`, and the predictions are mapped onto the
    mesh. The resulting vertex labels are stored on the file system and can be
    accessed via ``sso.label_dict('vertex')[semseg_key]``.

    If ``sso._sample_locations`` is None, the rendering locations are generated
    with ``generate_rendering_locs(verts, comp_window / 3)``.

    Examples:
        Given a cell reconstruction exported as k.zip at ``cell_kzip_fn``, the
        compartment prediction (axon boutons, dendrite, soma) can be started via
        the following script::

            # set working directory to obtain models
            global_params.wd = '~/SyConn/example_cube1/'

            # get model for compartment detection
            m = get_semseg_axon_model()
            view_props = global_params.config['compartments']['view_properties_semsegax']
            view_props["verbose"] = True

            # load SSO instance from k.zip file
            sso = init_sso_from_kzip(cell_kzip_fn, sso_id=1)

            # run prediction and store result in new kzip
            cell_kzip_fn_axon = cell_kzip_fn[:-6] + '_axon.k.zip'
            semseg_of_sso_nocache(sso, dest_path=cell_kzip_fn_axon, model=m,
                                  **view_props)

        See also the example scripts at::

            $ python SyConn/examples/semseg_axon.py
            $ python SyConn/examples/semseg_spine.py

    Args:
        sso: Cell reconstruction object; ``sso.view_caching`` must be True.
        model: The model used for prediction.
        semseg_key: Key under which the resulting prediction is stored.
        ws: Window size in pixels (y, x).
        nb_views: Number of views rendered at each rendering location.
        comp_window: Physical extent in nm of the view-window along y.
        k: Number of nearest vertices to average over for mesh mapping. If k=0,
            unpredicted vertices are treated as 'unpredicted' class.
        dest_path: Optional path of a k.zip file to which the colored mesh is
            written.
        verbose: If True, adds progress bars and debug logging.
        add_cellobjects: Whether to add cell objects when rendering the raw
            views (bool or list of structures); forwarded to
            ``sso.predict_semseg``.
        bs: Batch size during inference.

    Returns:
        None
    """
    view_kwargs = dict(ws=ws, comp_window=comp_window, nb_views=nb_views,
                       verbose=verbose, save=False)
    key_suffix = '{}_{}_{}'.format(ws[0], ws[1], nb_views)
    raw_view_key = 'raw' + key_suffix
    index_view_key = 'index' + key_suffix
    mesh_vertices = sso.mesh[1].reshape(-1, 3)
    if sso._sample_locations is None:
        # fall back to ~three rendering locations per comp window
        sso._sample_locations = [generate_rendering_locs(mesh_vertices, comp_window / 3)]
    assert sso.view_caching, "'view_caching' of {} has to be True in order to" \
                             " run 'semseg_of_sso_nocache'.".format(sso)

    # step 1: render raw views and predict them
    sso.predict_semseg(model, semseg_key, raw_view_key=raw_view_key,
                       add_cellobjects=add_cellobjects, bs=bs, **view_kwargs)
    if verbose:
        log_reps.debug('Finished shape-view rendering and sem. seg. prediction.')

    # step 2: render the index views used for the vertex lookup
    sso.render_indexviews(view_key=index_view_key, force_recompute=True,
                          **view_kwargs)
    if verbose:
        log_reps.debug('Finished index-view rendering.')

    # step 3: map predictions onto the mesh; stored in
    # sso._label_dict['vertex'][semseg_key] and pushed to the file system
    sso.semseg2mesh(semseg_key, index_view_key=index_view_key,
                    dest_path=dest_path, force_recompute=True, k=k)
    if verbose:
        log_reps.debug('Finished mapping of vertex predictions to mesh.')
def assemble_from_mergelist(ssd: 'SuperSegmentationDataset',
                            mergelist: Union[Dict[int, int], str]):
    """
    Creates a mapping dictionary and saves the dataset shallowly based on a mergelist.

    This function will overwrite existing mapping dict, id changer, and version files.

    Args:
        ssd (SuperSegmentationDataset): The dataset to be updated with the new mapping.
        mergelist: Supervoxel agglomeration provided either as a dictionary or as
            a file path to a previously generated mergelist. The dictionary maps
            each key (presumably a supervoxel ID) to its agglomerate ID.

    Raises:
        ValueError: If `mergelist` is None.
        TypeError: If `mergelist` is neither a dict nor a str.

    Returns:
        None
    """
    if mergelist is None:
        # previously this fell through and crashed with an AttributeError
        raise ValueError("'mergelist' must be a dict or a path to a mergelist "
                         "file, got None.")
    assert "sv" in ssd.version_dict
    if isinstance(mergelist, str):
        with open(mergelist, "r") as f:
            mergelist = mergelist_tools.subobject_map_from_mergelist(f.read())
    elif not isinstance(mergelist, dict):
        raise TypeError(f"mergelist has unknown type {type(mergelist)}.")

    # invert key -> agglomerate ID into agglomerate ID -> [keys]; dict keys are
    # created in order of first appearance in `mergelist.values()`
    mapping_dict = {agglo_id: [] for agglo_id in mergelist.values()}
    for sv_id, agglo_id in mergelist.items():
        mapping_dict[agglo_id].append(sv_id)
    ssd._mapping_dict = mapping_dict
    ssd.create_mapping_lookup_reverse()
    ssd.save_dataset_shallow(overwrite=True)
def compartments_graph(ssv: 'super_segmentation.SuperSegmentationObject',
                       axoness_key: str) -> Tuple[nx.Graph, nx.Graph, nx.Graph]:
    """
    Creates graphs for axon, dendrite, and soma compartments based on skeleton
    node predictions.

    Args:
        ssv: Cell reconstruction object. Its skeleton must exist and must contain
            keys ``'edges'``, ``'nodes'`` and `axoness_key`.
        axoness_key: Key for axon predictions in `ssv.skeleton` (0: dendrite,
            1: axon, 2: soma). Labels 3 (en-passant bouton) and 4 (terminal
            bouton) are treated as 1 (axon).

    Returns:
        Graphs for dendrite, axon, and soma compartments, respectively. Each is a
        copy of ``ssv.weighted_graph(...)`` with the nodes of the other two
        compartments removed. Nodes with any other label are kept in all three.
    """
    axon_prediction = np.array(ssv.skeleton[axoness_key])
    # boutons count as axon
    axon_prediction[np.isin(axon_prediction, (3, 4))] = 1
    # `np.flatnonzero` yields flat node indices. Iterating `np.nonzero(...)`
    # would yield whole index arrays, which are not valid (hashable) graph nodes.
    axon_ixs = np.flatnonzero(axon_prediction == 1).tolist()
    dendrite_ixs = np.flatnonzero(axon_prediction == 0).tolist()
    soma_ixs = np.flatnonzero(axon_prediction == 2).tolist()

    so_graph = ssv.weighted_graph(add_node_attr=[axoness_key])
    ax_graph = so_graph.copy()
    den_graph = so_graph.copy()

    so_graph.remove_nodes_from(axon_ixs + dendrite_ixs)
    ax_graph.remove_nodes_from(dendrite_ixs + soma_ixs)
    den_graph.remove_nodes_from(axon_ixs + soma_ixs)
    return den_graph, ax_graph, so_graph
def syn_sign_ratio_celltype(ssv: 'super_segmentation.SuperSegmentationObject',
                            weighted: bool = True, recompute: bool = False,
                            comp_types: Optional[List[int]] = None,
                            save: bool = False) -> float:
    """
    Computes the ratio of symmetric synapses on specified compartments of a
    cell reconstruction.

    The ratio is based on the synapse objects associated with the
    SuperSegmentationObject. Excludes partner cell compartment information.
    Refer to `~syconn.reps.super_segmentation_object.SuperSegmentationObject.syn_sign_ratio`
    for partner inclusion.

    Todo:
        * Check default of synapse type if synapse type predictions are not
          available -> propagate to this method and return -1.

    Notes:
        * Bouton predictions are converted into axon label, i.e., 3 -> 1
          (en-passant) and 4 -> 1 (terminal).
        * Compartment predictions are collected after first attribute access
          during celltype prediction. The key 'partner_axoness' is not available
          in `self.syn_ssv` until the relevant processing function is called
          (see :func:`~syconn.exec.exec_syns.run_matrix_export`).
        * The compartment type of the other cell cannot be inferred at this
          point. Think about adding the property collection before celltype
          prediction -> would allow more detailed filtering of the synapses,
          but adds an additional round of property collection.

    Args:
        ssv (SuperSegmentationObject): The cell reconstruction.
        weighted (bool): If True, compute synapse-area weighted ratio (the
            synapse size is half of the synapse's ``mesh_area``).
        recompute (bool): If True, ignores existing values and recomputes.
        comp_types (list, optional): Functional compartment types used for
            computing the ratio. Default is [1, ] for axons only.
        save (bool): If True, saves the computed ratio (also -1) using a key of
            the form 'syn_sign_ratio_celltype[_weighted]_<comp_types>'.

    Raises:
        ValueError: If a considered synapse has no ``syn_sign`` or ``mesh_area``.

    Returns:
        float: The (area-weighted) ratio of symmetric synapses (``syn_sign ==
        -1``) or -1 if no (matching) synapses are present.
    """
    if comp_types is None:
        comp_types = [1, ]
    ratio_key = 'syn_sign_ratio_celltype'
    if weighted:
        ratio_key += '_weighted'
    ratio_key += '_' + "_".join([str(el) for el in comp_types])
    ratio = ssv.lookup_in_attribute_dict(ratio_key)
    if not recompute and ratio is not None:
        return ratio

    pred_key_ax = "{}_avg{}".format(
        global_params.config['compartments']['view_properties_semsegax']['semseg_key'],
        global_params.config['compartments']['dist_axoness_averaging'])
    if len(ssv.syn_ssv) == 0:
        # consistent with the "no matching synapses" path below
        if save:
            ssv.save_attributes([ratio_key], [-1])
        return -1
    props = load_so_attr_bulk(ssv.syn_ssv, ('syn_sign', 'mesh_area', 'rep_coord'),
                              allow_missing=True,
                              use_new_subfold=global_params.config.use_new_subfold)
    syn_axs = ssv.attr_for_coords([props['rep_coord'][syn.id] for syn in ssv.syn_ssv],
                                  attr_keys=[pred_key_ax, ])[0]
    # convert boutons to axon class
    syn_axs[syn_axs == 3] = 1
    syn_axs[syn_axs == 4] = 1

    syn_signs = []
    syn_sizes = []
    for syn_ix, syn in enumerate(ssv.syn_ssv):
        if syn_axs[syn_ix] not in comp_types:
            continue
        syn_sign = props['syn_sign'][syn.id]
        mesh_area = props['mesh_area'][syn.id]
        # Check the raw values BEFORE halving the area: `None / 2` would raise a
        # TypeError and make this check unreachable.
        if syn_sign is None or mesh_area is None:
            raise ValueError(f'Got at least one None value for syn_sign and/or '
                             f'syn_size of {syn}.')
        syn_signs.append(syn_sign)
        syn_sizes.append(mesh_area / 2)

    if len(syn_signs) == 0 or np.sum(syn_sizes) == 0:
        if save:
            ssv.save_attributes([ratio_key], [-1])
        return -1
    syn_signs = np.array(syn_signs)
    syn_sizes = np.array(syn_sizes)
    if weighted:
        ratio = np.sum(syn_sizes[syn_signs == -1]) / float(np.sum(syn_sizes))
    else:
        ratio = np.sum(syn_signs == -1) / float(len(syn_signs))
    if save:
        ssv.save_attributes([ratio_key], [ratio])
    return ratio
def extract_spinehead_volume_mesh(sso: 'super_segmentation.SuperSegmentationObject',
                                  ctx_vol=(200, 200, 100)):
    """
    Calculate the volume of spine heads using a watershed approach on cell segmentation.

    For every synapse located on a dendritic spine head, the cell segmentation
    in a box around the synapse is downsampled to the z voxel size and split
    by a watershed. The watershed seeds are the local maxima of the mask's
    distance transform. Each seed is labeled by a majority vote among its k
    nearest mesh vertices, using the 'spiness' vertex predictions. The spine
    head object at the synapse is kept. Results are stored in
    ``sso.attr_dict['spinehead_vol']`` as a dict ``{syn_ssv ID: volume}``.

    Notes:
        - The calculated 'spinehead_vol' values are in micrometers cubed (µm^3).
        - The segmentation mask is downsampled to match the z voxel size.
        - The predicted cell mesh must have 'spiness' in
          `label_dict('vertex')['spiness']`.
        - To store results, invoke `sso.save_attr_dict()`.

    Args:
        sso: The SuperSegmentationObject to be processed. It requires a predicted
            cell mesh, i.e. 'spiness' must be present in
            `label_dict('vertex')['spiness']`.
        ctx_vol: Additional volume (in voxels) around the synapse representative
            coordinate used for the volume estimation. The inspected box spans
            ``rep_coord - ctx_vol`` to ``rep_coord + ctx_vol`` (offset clipped
            at 0).

    Raises:
        ValueError: If 'spiness' vertex labels are missing or no segmentation of
            the cell is found around a synapse.

    Returns:
        None
    """
    if len(sso.attr_dict) == 0:
        sso.load_attr_dict()
    sso.attr_dict['spinehead_vol'] = {}
    ctx_vol = np.array(ctx_vol)
    scaling = sso.scaling
    if 'spiness' not in sso.label_dict('vertex'):
        msg = f'"spiness" not available in vertex label dict of SSO {sso.id}.'
        log_reps.error(msg)
        raise ValueError(msg)
    ssv_svids = sso.sv_ids
    ssv_syncoords = np.array([syn.rep_coord for syn in sso.syn_ssv])
    if len(ssv_syncoords) == 0:
        return
    ssv_synids = np.array([syn.id for syn in sso.syn_ssv])
    # mesh vertices in voxel coordinates (mag 1)
    verts = sso.mesh[1].reshape(-1, 3) / scaling
    sp_semseg = sso.label_dict('vertex')['spiness']
    if np.ndim(sp_semseg) == 2:
        sp_semseg = sp_semseg.squeeze(1)
    ignore_labels = sso.config['spines']['semseg2coords_spines']['ignore_labels']
    keep_verts = ~np.isin(sp_semseg, ignore_labels)
    verts = verts[keep_verts]
    sp_semseg = sp_semseg[keep_verts]

    curr_sp = sso.semseg_for_coords(ssv_syncoords, 'spiness',
                                    **sso.config['spines']['semseg2coords_spines'])
    pred_key_ax = "{}_avg{}".format(
        sso.config['compartments']['view_properties_semsegax']['semseg_key'],
        sso.config['compartments']['dist_axoness_averaging'])
    curr_ax = sso.attr_for_coords(ssv_syncoords, attr_keys=[pred_key_ax])[0]
    # keep synapses with spiness label 1 on compartment label 0
    is_spinehead_syn = (curr_sp == 1) & (curr_ax == 0)
    ssv_syncoords = ssv_syncoords[is_spinehead_syn]
    ssv_synids = ssv_synids[is_spinehead_syn]
    if len(ssv_syncoords) == 0:  # no spine head synapses
        return

    # per-axis downsampling factors to (roughly) the z voxel size
    ds = sso.scaling[2] // np.array(sso.scaling)
    assert np.all(ds > 0)
    kd = kd_factory(sso.config.kd_seg_path)
    k_nn = sso.config['spines']['semseg2coords_spines']['k']
    size = (2 * ctx_vol).astype(np.int32)

    # iterate over spine head synapses
    for c, ssv_syn_id in zip(ssv_syncoords, ssv_synids):
        offset = c - ctx_vol
        offset[offset < 0] = 0
        # get cell segmentation mask (downsampled by `ds`)
        seg = kd.load_seg(offset=offset, size=size, mag=1).swapaxes(2, 0)
        seg = ndimage.zoom(seg, 1 / ds, order=0)
        if len(ssv_svids) > 1:
            relabel_vol_nonexist2zero(seg, {k: 1 for k in ssv_svids})
        else:
            seg = (seg == ssv_svids[0]).astype(np.int32)
        seg = ndimage.binary_fill_holes(seg)
        if np.sum(seg) == 0:
            msg = (f'Could not find segmentation at {offset} and size {size} for SSVs '
                   f'{ssv_svids}. syn_ssv ID: {ssv_syn_id}.')
            log_reps.error(msg)
            raise ValueError(msg)

        # set watershed seeds using vertices
        vert_ixs_bb = in_bounding_box(verts, np.array([offset + size / 2, size]))
        # `np.bool` was removed in NumPy 1.24
        vert_ixs_bb = np.array(vert_ixs_bb, dtype=bool)
        verts_bb = verts[vert_ixs_bb]
        semseg_bb = sp_semseg[vert_ixs_bb]
        # pathological case, such as re-entering cells within a smaller test cube lead
        # to missing meshes and skeletons if the process is very small. Synapse objects
        # get correctly identified, but context is insufficient for mesh
        # generation/prediction. Does not occur in real data.
        if len(semseg_bb) == 0:
            continue
        # relabelled spine neck as 9, actually not needed here
        semseg_bb[semseg_bb == 0] = 9
        distance = ndimage.distance_transform_edt(seg)
        maxima = peak_local_max(distance, footprint=np.ones((3, 3, 3)),
                                labels=seg).astype(np.uint64)
        # assign labels from nearby vertices; convert maxima coordinates back to mag 1 via 'ds'
        maxima_sp = colorcode_vertices(maxima * ds, verts_bb - offset, semseg_bb,
                                       k=k_nn, return_color=False, nb_cpus=sso.nb_cpus)
        local_maxi = np.zeros_like(distance)
        local_maxi[maxima[:, 0], maxima[:, 1], maxima[:, 2]] = maxima_sp
        labels = watershed(-distance, local_maxi, mask=seg).astype(np.uint64)
        labels[labels != 1] = 0  # only keep spine head locations
        labels, nb_obj = ndimage.label(labels)

        max_id = 1
        # if more than one spine head object get the one with the majority voxels in vicinity
        if nb_obj > 1:
            # `labels` lives in the downsampled grid -> convert the synapse location
            # (relative to `offset`) accordingly and clip the window at 0 to avoid
            # negative (wrap-around) slice starts
            c_ds = ((c - offset) // ds).astype(np.int64)
            lo = np.maximum(c_ds - 10, 0)
            hi = c_ds + 11
            ls = labels[lo[0]:hi[0], lo[1]:hi[1], lo[2]:hi[2]]
            ids, cnts = np.unique(ls, return_counts=True)
            cnts = cnts[ids != 0]
            ids = ids[ids != 0]
            if len(ids) == 0:
                # no object near the synapse: use the object with the closest voxel (in nm)
                nz = np.nonzero(labels)
                obj_ids = labels[nz]
                # downsampled voxel coordinates -> global mag 1 coordinates
                coords = np.transpose(nz) * ds + offset
                nn_kdt = spatial.cKDTree(coords * sso.scaling)
                _, nn_id = nn_kdt.query([c * sso.scaling])
                max_id = obj_ids[nn_id[0]]
            else:
                max_id = ids[np.argmax(cnts)]
        n_voxels_spinehead = np.sum(labels == max_id)
        # voxel volume of the downsampled grid, nm^3 -> um^3
        vol_sh = n_voxels_spinehead * np.prod(scaling * ds) / 1e9
        sso.attr_dict['spinehead_vol'][ssv_syn_id] = vol_sh
def sso_svgraph2kzip(dest_path: str, sso: 'SuperSegmentationObject'):
    """
    Stores the supervoxel graph of a SuperSegmentationObject in a KNOSSOS
    compatible kzip file.

    Every supervoxel becomes exactly one skeleton node, placed at its
    representative coordinate multiplied by ``sso.scaling``, with its ID in
    ``node.data['svid']``. Each supervoxel edge becomes a skeleton edge between
    those shared nodes, so the exported skeleton has the connectivity of the
    supervoxel graph. Self-edges only create the node.

    Args:
        dest_path (str): The file path where the k.zip will be stored.
        sso (SuperSegmentationObject): The SuperSegmentationObject whose
            supervoxel graph is to be stored.

    Returns:
        None
    """
    sv_edges = sso.load_sv_edgelist()
    anno = SkeletonAnnotation()
    anno.scaling = sso.scaling
    sd = sso.get_seg_dataset('sv')
    coord_dc = dict(zip(sd.ids, sd.rep_coords))
    nodes_by_svid = {}

    def _node_for(sv):
        # Create one node per supervoxel and reuse it for all incident edges.
        # Creating new nodes per edge would export disconnected node pairs.
        node = nodes_by_svid.get(sv.id)
        if node is None:
            coord = coord_dc[sv.id] * sso.scaling
            node = SkeletonNode().from_scratch(anno, coord[0], coord[1], coord[2])
            node.data['svid'] = sv.id
            anno.addNode(node)
            nodes_by_svid[sv.id] = node
        return node

    for sv1, sv2 in sv_edges:
        n1 = _node_for(sv1)
        n2 = _node_for(sv2)
        if n1 is not n2:
            anno.addEdge(n1, n2)
    dummy_skel = Skeleton()
    dummy_skel.add_annotation(anno)
    dummy_skel.to_kzip(dest_path)