Source code for syconn.reps.segmentation_helper

# -*- coding: utf-8 -*-
# SyConn - Synaptic connectivity inference toolkit
#
# Copyright (c) 2016 - now
# Max-Planck-Institute of Neurobiology, Munich, Germany
# Authors: Philipp Schubert, Joergen Kornfeld
import glob
import os
from collections import defaultdict
from typing import TYPE_CHECKING, Dict, Optional, Tuple, List, Union, Iterable, Any

from scipy import spatial
import numpy as np

from . import log_reps
from . import rep_helper as rh
from .rep_helper import surface_samples
from .. import global_params
from ..backend.storage import AttributeDict, CompressedStorage, MeshStorage, \
    VoxelStorage, SkeletonStorage, VoxelStorageDyn, VoxelStorageLazyLoading
from ..handler.basics import chunkify, temp_seed
from ..handler.multiviews import generate_rendering_locs
from ..mp.mp_utils import start_multiprocess_imap
from ..proc.graphs import create_graph_from_coords

if TYPE_CHECKING:
    from ..reps.segmentation import SegmentationObject, SegmentationDataset
MeshType = Union[Tuple[np.ndarray, np.ndarray, np.ndarray], List[np.ndarray],
                 Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]]


def glia_pred_so(so: 'SegmentationObject', thresh: float, pred_key_appendix: str) -> int:
    """
    Perform the glia classification of a cell supervoxel (0: neuron, 1: glia).

    Args:
        so: The cell supervoxel object.
        thresh: Threshold used for the classification.
        pred_key_appendix: Additional prediction key.

    Returns:
        Glia prediction: 0 (neuron) or 1 (glia).
    """
    assert so.type == "sv"
    pred_key = "glia_probas" + pred_key_appendix
    if pred_key not in so.attr_dict:
        so.load_attr_dict()
    try:
        preds = np.array(so.attr_dict[pred_key][:, 1] > thresh, dtype=np.int32)
        pred = np.mean(so.attr_dict[pred_key][:, 1]) > thresh
    except KeyError:
        raise KeyError('Could not find glia proba key `{}` in so.attr_dict (keys: {})'.format(
            pred_key, so.attr_dict.keys()))
    if pred == 0:
        return 0
    glia_votes = np.sum(preds)
    if glia_votes > int(len(preds) * 0.7):
        return 1
    return 0

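
# Hedged, self-contained sketch of the decision rule above (plain NumPy,
# hypothetical probabilities): a supervoxel is labeled glia only if the mean
# glia probability exceeds the threshold *and* more than 70% of the per-view
# predictions vote glia.
def _example_glia_vote():
    probas = np.array([0.9, 0.8, 0.85, 0.2])           # column 1 of `glia_probas`
    thresh = 0.5
    preds = np.array(probas > thresh, dtype=np.int32)  # per-view votes: 1, 1, 1, 0
    mean_is_glia = np.mean(probas) > thresh            # True (0.6875 > 0.5)
    return int(mean_is_glia and np.sum(preds) > int(len(preds) * 0.7))  # 3 > 2 -> 1
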

def glia_proba_so(so: 'SegmentationObject', pred_key_appendix: str) -> float:
    """
    Get the mean glia probability of a cell supervoxel (0: neuron, 1: glia).

    Args:
        so: The cell supervoxel object.
        pred_key_appendix: Additional prediction key.

    Returns:
        Mean glia probability.
    """
    assert so.type == "sv"
    pred_key = "glia_probas" + pred_key_appendix
    if pred_key not in so.attr_dict:
        so.load_attr_dict()
    return np.mean(so.attr_dict[pred_key][:, 1])


def acquire_obj_ids(sd: 'SegmentationDataset'):
    """
    Acquires all object ids present in the dataset. Loads the id array if
    available; otherwise assembles the id list by iterating over all
    attribute dicts (very slow).
    """
    sd._ids = sd.load_numpy_data('id')
    if sd._ids is None:
        paths = glob.glob(sd.so_storage_path + "/*/*/*/") + \
                glob.glob(sd.so_storage_path + "/*/*/") + \
                glob.glob(sd.so_storage_path + "/*/")
        sd._ids = []
        for path in paths:
            if os.path.exists(path + "attr_dict.pkl"):
                this_ids = list(AttributeDict(path + "attr_dict.pkl", read_only=True).keys())
            else:
                this_ids = []
            sd._ids += this_ids
        sd._ids = np.array(sd._ids)
        np.save(sd.path_ids, sd._ids)


def save_voxels(so: 'SegmentationObject', bin_arr: np.ndarray,
                offset: np.ndarray, overwrite: bool = False):
    """
    Helper function to save SegmentationObject voxels.

    Args:
        so: SegmentationObject.
        bin_arr: Binary mask array; 0: background, 1: supervoxel locations.
        offset: Offset of `bin_arr` in voxel coordinates.
        overwrite: Replace any existing entry instead of appending to it.
    """
    assert bin_arr.dtype == bool
    voxel_dc = VoxelStorage(so.voxel_path, read_only=False, disable_locking=True)
    if so.id in voxel_dc and not overwrite:
        voxel_dc.append(so.id, bin_arr, offset)
    else:
        voxel_dc[so.id] = [bin_arr], [offset]
    voxel_dc.push(so.voxel_path)


def load_voxels_depr(so: 'SegmentationObject',
                     voxel_dc: Optional[VoxelStorage] = None) -> np.ndarray:
    """
    Helper function to load the voxels of a SegmentationObject as a 3D array.
    Also calculates size and bounding box and assigns them to `so._size` and
    `so._bounding_box` respectively.

    Args:
        so: SegmentationObject.
        voxel_dc: VoxelStorage.

    Returns:
        3D binary mask array; 0: background, 1: supervoxel locations.
    """
    if voxel_dc is None:
        voxel_dc = VoxelStorage(so.voxel_path, read_only=True, disable_locking=True)
    so._size = 0
    if so.id not in voxel_dc:
        msg = f"Voxels of {so} do not exist!"
        log_reps.error(msg)
        raise KeyError(msg)
    bin_arrs, block_offsets = voxel_dc[so.id]
    block_extents = []
    for i_bin_arr in range(len(bin_arrs)):
        block_extents.append(np.array(bin_arrs[i_bin_arr].shape) + block_offsets[i_bin_arr])
    block_offsets = np.array(block_offsets, dtype=np.int32)
    block_extents = np.array(block_extents, dtype=np.int32)
    so._bounding_box = np.array([block_offsets.min(axis=0), block_extents.max(axis=0)])
    voxels = np.zeros(so.bounding_box[1] - so.bounding_box[0], dtype=bool)
    for i_bin_arr in range(len(bin_arrs)):
        box = [block_offsets[i_bin_arr] - so.bounding_box[0],
               block_extents[i_bin_arr] - so.bounding_box[0]]
        so._size += np.sum(bin_arrs[i_bin_arr])
        voxels[box[0][0]: box[1][0],
               box[0][1]: box[1][1],
               box[0][2]: box[1][2]][bin_arrs[i_bin_arr]] = True
    return voxels


def load_voxels_downsampled(so: 'SegmentationObject',
                            ds: Tuple[int, int, int] = (2, 2, 1)) -> Union[np.ndarray, List]:
    """Return the object's voxel mask strided by `ds`; empty list if no voxels exist."""
    if isinstance(so.voxels, int):
        return []
    return so.voxels[::ds[0], ::ds[1], ::ds[2]]


def load_voxel_list(so: 'SegmentationObject') -> np.ndarray:
    """
    Helper function to load voxels of a SegmentationObject.

    Args:
        so: SegmentationObject.

    Returns:
        2D array of the coordinates of all voxels in the SegmentationObject.
    """
    if so._voxels is not None:
        voxel_list = np.transpose(np.nonzero(so.voxels)) + so.bounding_box[0]
    elif so.type in ['syn', 'syn_ssv']:
        voxel_dc = VoxelStorageLazyLoading(so.voxel_path)
        voxel_list = voxel_dc[so.id]
        voxel_dc.close()
    else:
        voxel_dc = VoxelStorageDyn(so.voxel_path, read_only=True, disable_locking=True)
        bin_arrs, block_offsets = voxel_dc[so.id]
        voxel_list = []
        for i_bin_arr in range(len(bin_arrs)):
            block_voxels = np.transpose(np.nonzero(bin_arrs[i_bin_arr]))
            block_voxels += block_offsets[i_bin_arr]
            voxel_list.append(block_voxels)
        voxel_list = np.concatenate(voxel_list)
    return voxel_list

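
# The mask-to-coordinate conversion above boils down to ``np.nonzero`` plus an
# offset shift. A minimal, self-contained sketch with hypothetical values:
def _example_mask_to_coords():
    mask = np.zeros((4, 4, 4), dtype=bool)
    mask[0, 0, 1] = True
    mask[1, 2, 3] = True
    offset = np.array([10, 20, 30])  # block offset in dataset coordinates
    coords = np.transpose(np.nonzero(mask)) + offset
    return coords  # -> [[10, 20, 31], [11, 22, 33]]
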

def load_voxel_list_downsampled(so, downsampling=(2, 2, 1)):
    """
    TODO: refactor, probably more efficient implementation possible.

    Args:
        so: SegmentationObject.
        downsampling: Tuple[int].

    Returns:
        2D array of downsampled voxel coordinates.
    """
    downsampling = np.array(downsampling)
    dvoxels = so.load_voxels_downsampled(downsampling)
    voxel_list = np.array(np.transpose(np.nonzero(dvoxels)), dtype=np.int32)
    voxel_list = voxel_list * downsampling + np.array(so.bounding_box[0])
    return voxel_list


def load_voxel_list_downsampled_adapt(so, downsampling=(2, 2, 1)):
    """
    Load the downsampled voxel coordinates of a SegmentationObject; if the
    strided mask is empty, repeatedly halve the downsampling factors until at
    least one voxel survives.
    """
    downsampling = np.array(downsampling, dtype=np.int32)
    dvoxels = so.load_voxels_downsampled(downsampling)
    if len(dvoxels) == 0:
        return []
    while True:
        if True in dvoxels:
            break
        # no voxel survived the striding; refine until one does
        downsampling = downsampling // 2
        downsampling[downsampling < 1] = 1
        dvoxels = so.load_voxels_downsampled(downsampling)
    voxel_list = np.array(np.transpose(np.nonzero(dvoxels)), dtype=np.int32)
    voxel_list = voxel_list * downsampling + np.array(so.bounding_box[0])
    return voxel_list

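
# Self-contained sketch of the adaptive refinement loop above, with a plain
# boolean array standing in for ``so.load_voxels_downsampled`` (hypothetical
# values): the stride is halved until at least one voxel survives.
def _example_adaptive_downsampling():
    mask = np.zeros((8, 8, 8), dtype=bool)
    mask[3, 5, 1] = True                      # single foreground voxel
    ds = np.array((4, 4, 4), dtype=np.int32)  # too coarse; [::4, ::4, ::4] misses it
    while not mask[::ds[0], ::ds[1], ::ds[2]].any():
        ds = np.maximum(ds // 2, 1)           # halve the factors, floor at 1
    return ds  # -> (1, 1, 1); the strided mask now contains the voxel
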

def load_mesh(so: 'SegmentationObject', recompute: bool = False) -> MeshType:
    """
    Load the mesh of a SegmentationObject.
    TODO: Currently ignores the potential color/label array.

    Args:
        so: SegmentationObject.
        recompute: bool.

    Returns:
        indices, vertices, normals; all flattened.
    """
    if not recompute and so.mesh_exists:
        try:
            mesh = MeshStorage(so.mesh_path, disable_locking=True)[so.id]
            if len(mesh) == 2:
                indices, vertices = mesh
                normals = np.zeros((0,), dtype=np.float32)
                col = np.zeros(0, dtype=np.uint8)
            elif len(mesh) == 3:
                indices, vertices, normals = mesh
                col = np.zeros(0, dtype=np.uint8)
            elif len(mesh) == 4:
                indices, vertices, normals, col = mesh
        except Exception as e:
            msg = "\n{}\nException occurred when loading mesh.pkl of SO ({}) " \
                  "with id {}.".format(e, so.type, so.id)
            log_reps.error(msg)
            return [np.zeros((0,)).astype(np.int32), np.zeros((0,)), np.zeros((0,))]
    else:
        if so.type == "sv" and not global_params.config.allow_mesh_gen_cells:
            log_reps.error("Mesh of SV %d not found.\n" % so.id)
            return [np.zeros((0,)).astype(np.int32), np.zeros((0,)), np.zeros((0,))]
        indices, vertices, normals = so.mesh_from_scratch()
        col = np.zeros(0, dtype=np.uint8)
        try:
            so._save_mesh(indices, vertices, normals)
        except Exception as e:
            log_reps.error("Mesh of {} {} could not be saved:\n{}\n".format(
                so.type, so.id, e))
    vertices = np.array(vertices, dtype=np.int32)
    indices = np.array(indices, dtype=np.int64)
    normals = np.array(normals, dtype=np.float32)
    col = np.array(col, dtype=np.uint8)
    return [indices, vertices, normals]


def load_skeleton(so: 'SegmentationObject', recompute: bool = False) -> dict:
    """
    Args:
        so: SegmentationObject.
        recompute: Compute the skeleton; will not store it in ``SkeletonStorage``.

    Returns:
        Dictionary with "nodes", "diameters" and "edges".
    """
    empty_skel = dict(nodes=np.zeros((0, 3)).astype(np.int64),
                      edges=np.zeros((0, 2)),
                      diameters=np.zeros((0,)).astype(np.int32))
    if not recompute and so.skeleton_exists:
        try:
            skeleton_dc = SkeletonStorage(so.skeleton_path,
                                          disable_locking=not so.enable_locking)
            skel = skeleton_dc[so.id]
            if np.ndim(skel['nodes']) == 1:
                skel['nodes'] = skel['nodes'].reshape((-1, 3))
            if np.ndim(skel['edges']) == 1:
                skel['edges'] = skel['edges'].reshape((-1, 2))
        except Exception as e:
            log_reps.error("\n{}\nException occurred when loading skeletons.pkl"
                           " of SO ({}) with id {}.".format(e, so.type, so.id))
            return empty_skel
    elif recompute:
        skel = generate_skeleton_sv(so)
    else:
        msg = f"Skeleton of {so} (size: {so.size}) not found.\n"
        if so.type == "sv":
            if so.size == 1:  # small SVs don't have a skeleton
                log_reps.debug(msg)
            else:
                log_reps.error(msg)
                raise ValueError(msg)
        return empty_skel
    return skel


def save_skeleton(so: 'SegmentationObject', overwrite: bool = False):
    """
    Args:
        so: SegmentationObject.
        overwrite: Replace an existing skeleton entry instead of raising.
    """
    skeleton_dc = SkeletonStorage(so.skeleton_path, read_only=False,
                                  disable_locking=not so.enable_locking)
    if not overwrite and so.id in skeleton_dc:
        raise ValueError(f"Skeleton of {so} already exists.")
    skeleton_dc[so.id] = so.skeleton
    skeleton_dc.push()


def sv_view_exists(args):
    ps, woglia = args
    missing_ids = []
    for p in ps:
        ad = AttributeDict(p + "/attr_dict.pkl", disable_locking=True)
        obj_ixs = ad.keys()
        view_dc_p = p + "/views_woglia.pkl" if woglia else p + "/views.pkl"
        view_dc = CompressedStorage(view_dc_p, disable_locking=True)
        # collect the missing IDs of every storage, not only the last one
        missing_ids.append(np.setdiff1d(list(obj_ixs), list(view_dc.keys())))
    if len(missing_ids) == 0:
        return np.zeros(0, dtype=np.int64)
    return np.concatenate(missing_ids)


def find_missing_sv_views(sd, woglia, n_cores=20):
    folders = sd.so_dir_paths
    np.random.shuffle(folders)
    multi_params = chunkify(folders, 1000)
    params = [(ps, woglia) for ps in multi_params]
    res = start_multiprocess_imap(sv_view_exists, params, nb_cpus=n_cores,
                                  debug=False)
    return np.concatenate(res)


def sv_skeleton_missing(sv):
    if sv.skeleton is None:
        sv.load_skeleton()
    return (sv.skeleton is None) or (len(sv.skeleton['nodes']) == 0)


def find_missing_sv_skeletons(svs, n_cores=20):
    res = start_multiprocess_imap(sv_skeleton_missing, svs, nb_cpus=n_cores,
                                  debug=False)
    return [svs[kk].id for kk in range(len(svs)) if res[kk]]


def sv_attr_exists(args):
    ps, attr_key = args
    missing_ids = []
    for p in ps:
        ad = AttributeDict(p + "/attr_dict.pkl", disable_locking=True)
        for k, v in ad.items():
            if attr_key not in v:
                missing_ids.append(k)
    return missing_ids


def find_missing_sv_attributes(sd: 'SegmentationDataset', attr_key: str,
                               n_cores: int = 20):
    """
    Args:
        sd: SegmentationDataset.
        attr_key: Attribute key to check for.
        n_cores: Number of CPUs used by the worker pool.

    Returns:
        IDs of all objects whose attribute dict lacks `attr_key`.
    """
    multi_params = chunkify(sd.so_dir_paths, 100)
    params = [(ps, attr_key) for ps in multi_params]
    res = start_multiprocess_imap(sv_attr_exists, params, nb_cpus=n_cores,
                                  debug=False)
    return np.concatenate(res)


def load_so_meshes_bulk(sos: Union[List['SegmentationObject'], Iterable['SegmentationObject']],
                        use_new_subfold: bool = True, cache_decomp=True) -> MeshStorage:
    """
    Bulk loader for SegmentationObject (SO) meshes. Minimizes IO by loading
    IDs from the same storage at the same time. This will not assign the
    ``_mesh`` attribute!

    Args:
        sos: SegmentationObjects.
        use_new_subfold: Use the new sub-folder structure.
        cache_decomp: Cache decompressed meshes.

    Returns:
        In-memory ``MeshStorage``; key: ID, value: mesh.
    """
    md_out = MeshStorage(None)  # in-memory dict with compression
    if len(sos) == 0:
        return md_out
    base_path = sos[0].so_storage_path
    nf = sos[0].n_folders_fs
    subf_from_ix = rh.subfold_from_ix_new if use_new_subfold else \
        rh.subfold_from_ix_OLD
    sub2ids = defaultdict(list)
    for so in sos:
        if so._mesh is None:
            subf = subf_from_ix(so.id, nf)
            sub2ids[subf].append(so.id)
        else:
            md_out[so.id] = so._mesh
    for subfold, ids in sub2ids.items():
        mesh_path = f'{base_path}/{subfold}/mesh.pkl'
        md = MeshStorage(mesh_path, disable_locking=True, cache_decomp=cache_decomp)
        for so_id in ids:
            md_out._dc_intern[so_id] = md._dc_intern[so_id]
    assert len(md_out) == len(sos)
    return md_out

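
# Hedged usage sketch for the bulk mesh loader. The dataset and the object IDs
# below are illustrative assumptions, not values from this module.
def _example_load_so_meshes_bulk():
    from syconn.reps.segmentation import SegmentationDataset
    sd = SegmentationDataset(obj_type='sv')
    sos = [sd.get_segmentation_object(ix) for ix in (1234, 5678)]
    meshes = load_so_meshes_bulk(sos)  # one storage read per sub-folder
    indices, vertices, normals = meshes[sos[0].id][:3]
    return indices, vertices, normals
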

def load_so_attr_bulk(sos: List['SegmentationObject'],
                      attr_keys: Union[str, Iterable[str]],
                      use_new_subfold: bool = True,
                      allow_missing: bool = False) -> Union[Dict[str, Dict[int, Any]], Dict[int, Any]]:
    """
    Bulk loader for SegmentationObject (SO) attributes. Minimizes IO by
    loading IDs from the same storage at the same time. Returns a single dict
    if only one attribute key is provided, or a dict of dicts if several are.

    This method will also check if the requested attribute(s) already exist
    in the object's ``attr_dict``. This means using ``cache_properties`` when
    initializing ``SegmentationDataset`` might be beneficial to avoid
    exhaustive file reads in case `sos` is large.

    Args:
        sos: SegmentationObjects.
        attr_keys: Attribute key(s).
        use_new_subfold: Use the new sub-folder structure.
        allow_missing: If True, set the attribute value to None if it is
            missing; if False, raise a KeyError instead.

    Returns:
        (Dict. with key: attr_key of) dict. with key: ID, value: attribute value.
    """
    if type(attr_keys) is str:
        attr_keys = [attr_keys]
    out = {attr_k: dict() for attr_k in attr_keys}
    if len(sos) == 0:
        if len(attr_keys) == 1:
            out = out[attr_keys[0]]
        return out
    base_path = sos[0].so_storage_path
    nf = sos[0].n_folders_fs
    subf_from_ix = rh.subfold_from_ix_new if use_new_subfold else rh.subfold_from_ix_OLD
    sub2ids = defaultdict(list)
    for so in sos:
        keys_missing = len(attr_keys)
        # use cached/loaded attributes
        for k in attr_keys:
            if k in so.attr_dict:
                out[k][so.id] = so.attr_dict[k]
                keys_missing -= 1
        if keys_missing == 0:
            continue
        subf = subf_from_ix(so.id, nf)
        sub2ids[subf].append(so.id)
    for subfold, ids in sub2ids.items():
        attr_p = f'{base_path}/{subfold}/attr_dict.pkl'
        ad = AttributeDict(attr_p, disable_locking=True)
        for so_id in ids:
            so_dict = ad[so_id]
            for attr_key in attr_keys:
                try:
                    out[attr_key][so_id] = so_dict[attr_key]
                except KeyError as e:
                    if allow_missing:
                        out[attr_key][so_id] = None
                    else:
                        raise KeyError(e)
    if len(attr_keys) == 1:
        out = out[attr_keys[0]]
    return out

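
# Hedged usage sketch: with a single key the loader returns a flat dict, with
# several keys a dict of dicts. ``sos`` is an assumed list of SegmentationObjects.
def _example_load_so_attr_bulk(sos):
    sizes = load_so_attr_bulk(sos, 'size')  # {id: size, ...}
    attrs = load_so_attr_bulk(sos, ('size', 'rep_coord'), allow_missing=True)
    return sizes, attrs['rep_coord']        # {'size': {...}, 'rep_coord': {...}}
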

def prepare_so_attr_cache(sd: 'SegmentationDataset', so_ids: np.ndarray,
                          attr_keys: List[str]) -> Dict[str, dict]:
    """
    Args:
        sd: SegmentationDataset.
        so_ids: SegmentationObject IDs for which to collect the attributes.
        attr_keys: Attribute keys to collect. Corresponding numpy arrays must exist.

    Returns:
        Dictionary with `attr_keys` as keys and an attribute dictionary as
        values for the IDs `so_ids`, e.g. ``attr_cache[attr_keys[0]][so_ids[0]]``
        will return the attribute value of type ``attr_keys[0]`` for the first
        SegmentationObject in `so_ids`.
    """
    attr_cache = {k: dict() for k in attr_keys}
    # TODO: Use BinarySearchStore
    soid2ix = {so_id: sd.soid2ix[so_id] for so_id in so_ids}
    sd._soid2ix = None  # free memory
    for attr in attr_keys:
        np_cache = sd.load_numpy_data(attr, allow_nonexisting=False)
        for so_id in so_ids:
            attr_cache[attr][so_id] = np_cache[soid2ix[so_id]]
        del np_cache
    return attr_cache

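
# Hedged usage sketch: the cache is filled from the dataset's columnar numpy
# arrays, avoiding per-storage pickle reads. ``sd`` and the IDs are assumed.
def _example_prepare_so_attr_cache(sd):
    cache = prepare_so_attr_cache(sd, np.array([1234, 5678]), ['size', 'rep_coord'])
    return cache['size'][1234]  # attribute value for a single object
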

def load_so_voxels_bulk(sos: List['SegmentationObject'],
                        use_new_subfold: bool = True, cache_decomp=True):
    """
    Work in progress: bulk loader for SegmentationObject voxels, analogous
    to :func:`load_so_meshes_bulk`.

    Args:
        sos: SegmentationObjects.
        use_new_subfold: Use the new sub-folder structure.
        cache_decomp: Cache decompressed voxel data.
    """
    raise NotImplementedError('WIP')
    # scaffolding for the planned implementation (currently unreachable)
    vd_out = VoxelStorage(None, cache_decomp=cache_decomp)  # in-memory dict with compression
    if len(sos) == 0:
        return vd_out
    base_path = sos[0].so_storage_path
    nf = sos[0].n_folders_fs
    subf_from_ix = rh.subfold_from_ix_new if use_new_subfold else \
        rh.subfold_from_ix_OLD
    sub2ids = defaultdict(list)
    for so in sos:
        subf = subf_from_ix(so.id, nf)
        sub2ids[subf].append(so.id)
    cnt = 0
    for subfold, ids in sub2ids.items():
        voxel_path = f'{base_path}/{subfold}/voxel.pkl'
        vd = VoxelStorage(voxel_path, disable_locking=True)
        for so_id in ids:
            cnt += 1
            vd_out._dc_intern[so_id] = vd._dc_intern[so_id]
    assert cnt == len(sos)
    return vd_out


def _helper_func(args):
    ps, use_vxsize = args
    out = []
    for p in ps:
        if not use_vxsize:
            w = len(AttributeDict(p + '/attr_dict.pkl', disable_locking=True))
        else:
            w = np.sum([v['size'] for v in
                        AttributeDict(p + '/attr_dict.pkl', disable_locking=True).values()])
        out.append(w)
    return out


def get_sd_load_distribution(sd: 'SegmentationDataset',
                             use_vxsize: bool = True) -> np.ndarray:
    """
    Get the load distribution of the SegmentationDataset's AttributeDicts:
    the cumulative voxel size per storage if `use_vxsize` is True, the number
    of objects per storage otherwise.

    Args:
        sd: SegmentationDataset.
        use_vxsize: Weight each storage by voxel count instead of object count.

    Returns:
        Load array.
    """
    n_objects = start_multiprocess_imap(
        _helper_func, [(ch, use_vxsize) for ch in chunkify(sd.so_dir_paths, 1000)],
        nb_cpus=None)
    return np.concatenate(n_objects).astype(np.int64)

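
# Hedged usage sketch; ``sd`` is an assumed, existing SegmentationDataset.
def _example_get_sd_load_distribution(sd):
    loads = get_sd_load_distribution(sd)     # one weight per storage folder
    heaviest = np.argsort(loads)[::-1][:10]  # indices of the ten largest storages
    return heaviest
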

def generate_skeleton_sv(so: 'SegmentationObject') -> Dict[str, np.ndarray]:
    """
    Poor man's solution to generate an SV "skeleton". Used for glia predictions.

    Args:
        so: SegmentationObject.

    Returns:
        Dictionary with keys "nodes", "edges" and "diameters" (all 1).
    """
    verts = so.mesh[1].reshape(-1, 3)
    # choose a random subset of the surface vertices
    np.random.seed(0)
    ixs = np.arange(len(verts))
    np.random.shuffle(ixs)
    ixs = ixs[:int(0.5 * len(ixs))]
    if global_params.config.use_new_renderings_locs:
        locs = generate_rendering_locs(verts[ixs], 1000)
    else:
        locs = surface_samples(verts[ixs], bin_sizes=(1000, 1000, 1000),
                               max_nb_samples=10000, r=500)
    g = create_graph_from_coords(locs, mst=True)
    if g.number_of_edges() == 1:
        edge_list = np.array(list(g.edges()))
    else:
        edge_list = np.array(g.edges())
    del g
    if edge_list.ndim != 2:
        raise ValueError("Edge list is not a 2D array: {}\n{}".format(
            edge_list.shape, edge_list))
    skeleton = dict()
    skeleton["nodes"] = (locs / np.array(so.scaling)).astype(np.int32)
    skeleton["edges"] = edge_list
    skeleton["diameters"] = np.ones(len(locs))
    return skeleton


def calc_center_of_mass(point_arr: np.ndarray) -> np.ndarray:
    """
    Args:
        point_arr: Array of points (in nm or at least isotropic).

    Returns:
        The closest point in `point_arr` to its center of mass.
    """
    # downsample to ensure fast processing - this is deterministic!
    if len(point_arr) > 1e5:
        with temp_seed(0):
            idx = np.random.randint(0, len(point_arr), int(1e5))
            point_arr = point_arr[idx]
    # calculate the mean
    center_of_mass = np.mean(point_arr, axis=0)
    # ensure that the point is contained inside the object,
    # i.e. use the closest existing point to the center of mass
    kdtree = spatial.cKDTree(point_arr)
    dd, ii = kdtree.query(center_of_mass, k=1)
    center_point = point_arr[ii]
    return center_point

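
# Self-contained demonstration of the snapping behaviour above (hypothetical
# points): for a ring, the raw mean falls outside the object, so the nearest
# member point is returned instead.
def _example_calc_center_of_mass():
    angles = np.linspace(0, 2 * np.pi, 100, endpoint=False)
    ring = np.stack([np.cos(angles), np.sin(angles), np.zeros(100)], axis=1)
    center = calc_center_of_mass(ring)  # an actual ring point, not ~(0, 0, 0)
    return center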