Source code for syconn.extraction.object_extraction_wrapper

# -*- coding: utf-8 -*-
# SyConn - Synaptic connectivity inference toolkit
#
# Copyright (c) 2016 - now
# Max Planck Institute of Neurobiology, Martinsried, Germany
# Authors: Philipp Schubert, Joergen Kornfeld

import os
import shutil
import time
from logging import Logger
from typing import Optional, Dict, List, Tuple, Union, Callable

import numpy as np
from knossos_utils import chunky, knossosdataset

from . import object_extraction_steps as oes
from .. import global_params
from ..extraction import log_extraction
from ..handler import basics


[docs]def calculate_chunk_numbers_for_box(cset, offset, size): """ Calculates the chunk ids that are (partly) contained it the defined volume Args: cset : ChunkDataset offset(np.array): offset of the volume to the origin size(np.array): size of the volume Returns: chunk_list(list): chunk ids dictionary(dict): with reverse mapping """ for dim in range(3): offset_overlap = offset[dim] % cset.chunk_size[dim] offset[dim] -= offset_overlap size[dim] += offset_overlap size[dim] += (cset.chunk_size[dim] - size[dim]) % cset.chunk_size[dim] chunk_list = [] translator = {} for x in range(offset[0], offset[0] + size[0], cset.chunk_size[0]): for y in range(offset[1], offset[1] + size[1], cset.chunk_size[1]): for z in range(offset[2], offset[2] + size[2], cset.chunk_size[2]): chunk_list.append(cset.coord_dict[tuple([x, y, z])]) translator[chunk_list[-1]] = len(chunk_list) - 1 return chunk_list, translator
[docs]def generate_subcell_kd_from_proba( subcell_names: List[str], chunk_size: Optional[Union[list, tuple]] = None, transf_func_kd_overlay: Optional[Dict[str, Callable]] = None, load_cellorganelles_from_kd_overlaycubes: bool = False, cube_of_interest_bb: Optional[Tuple[np.ndarray]] = None, cube_shape: Optional[Tuple[int]] = None, log: Logger = None, overwrite=False, **kwargs): """ Generates a connected components segmentation for the given the sub-cellular structures (e.g. ['mi', 'sj', 'vc]) as KnossosDatasets. The data format of the source data is KnossosDataset which path(s) is defined in ``global_params.config['paths']`` (e.g. key ``kd_mi_path`` for mitochondria). The resulting KDs will be stored at (for each ``co in subcell_names``) ``"{}/knossosdatasets/{}_seg/".format(global_params.config.working_dir, co)``. See :func:`~syconn.extraction.object_extraction_wrapper.from_probabilities_to_kd` for details of the conversion process from the initial probability map to the SV segmentation. Default: thresholding and connected components, thresholds are set via the `config.yml` file, check ``syconn.global_params.config['cell_objects']["probathresholds"]`` of an initialized :class:`~syconn.handler.config.DynConfig` object. Args: subcell_names: chunk_size: transf_func_kd_overlay: load_cellorganelles_from_kd_overlaycubes: cube_of_interest_bb: cube_shape: log: overwrite: **kwargs: Returns: """ if chunk_size is None: chunk_size = [512, 512, 512] if log is None: log = log_extraction if cube_shape is None: cube_shape = (256, 256, 256) kd = basics.kd_factory(global_params.config.kd_seg_path) if cube_of_interest_bb is None: cube_of_interest_bb = [np.zeros(3, dtype=np.int32), kd.boundary] size = cube_of_interest_bb[1] - cube_of_interest_bb[0] + 1 offset = cube_of_interest_bb[0] cd_dir = "{}/chunkdatasets/{}/".format(global_params.config.working_dir, "_".join(subcell_names)) if os.path.isdir(cd_dir): if not overwrite: msg = f'Could not start generation of sub-cellular objects ' \ f'"{subcell_names}" ChunkDataset because it already exists at "{cd_dir}" ' \ f'and overwrite was not set to True.' log_extraction.error(msg) raise FileExistsError(msg) log.debug('Found existing ChunkDataset at {}. Removing it now.'.format(cd_dir)) shutil.rmtree(cd_dir) cd = chunky.ChunkDataset() # TODO: possible to restrict ChunkDataset here already to report correct number of processed chunks? Check # coordinate framework compatibility downstream in `from_probabilities_to_kd` cd.initialize(kd, kd.boundary, chunk_size, cd_dir, box_coords=[0, 0, 0], fit_box_size=True, list_of_coords=[]) log.info('Started object extraction of cellular organelles "{}" from ' '{} chunks.'.format(", ".join(subcell_names), len(cd.chunk_dict))) prob_kd_path_dict = {co: getattr(global_params.config, 'kd_{}_path'.format(co)) for co in subcell_names} prob_threshs = [] # get probability threshold for co in subcell_names: prob_threshs.append(global_params.config['cell_objects']["probathresholds"][co]) path = global_params.config.kd_organelle_seg_paths[co] if os.path.isdir(path): if not overwrite: msg = f'Could not start generation of sub-cellular object ' \ f'"{co}" KnossosDataset because it already exists at "{path}" and overwrite ' \ f'was not set to True.' log_extraction.error(msg) raise FileExistsError(msg) log.debug('Found existing KD at {}. Removing it now.'.format(path)) shutil.rmtree(path) target_kd = knossosdataset.KnossosDataset() scale = np.array(global_params.config['scaling'], dtype=np.float32) target_kd._cube_shape = cube_shape target_kd.scales = [scale, ] target_kd.initialize_without_conf(path, kd.boundary, scale, kd.experiment_name, mags=[1, ], create_pyk_conf=True, create_knossos_conf=False) if load_cellorganelles_from_kd_overlaycubes: # no thresholds needed prob_threshs = None from_probabilities_to_kd(global_params.config.kd_organelle_seg_paths, cd, "_".join(subcell_names), # membrane_kd_path=global_params.config.kd_barrier_path, # TODO: currently does not exist prob_kd_path_dict=prob_kd_path_dict, thresholds=prob_threshs, hdf5names=subcell_names, size=size, offset=offset, load_from_kd_overlaycubes=load_cellorganelles_from_kd_overlaycubes, transf_func_kd_overlay=transf_func_kd_overlay, log=log, **kwargs) shutil.rmtree(cd_dir, ignore_errors=True)
[docs]def from_probabilities_to_kd( target_kd_paths: Optional[Dict[str, str]], cset: 'chunky.ChunkDataset', filename: str, hdf5names: List[str], prob_kd_path_dict: Optional[Dict[str, str]] = None, load_from_kd_overlaycubes: bool = False, transf_func_kd_overlay: Optional[Dict[str, Callable]] = None, log: Optional[Logger] = None, overlap: str = "auto", sigmas: Optional[list] = None, thresholds: Optional[list] = None, debug: bool = False, swapdata: bool = False, offset: Optional[np.ndarray] = None, size: Optional[np.ndarray] = None, suffix: str = "", transform_func: Optional[Callable] = None, func_kwargs: Optional[dict] = None, n_cores: Optional[int] = None, overlap_thresh: Optional[int] = 0, stitch_overlap: Optional[int] = None, membrane_filename: str = None, membrane_kd_path: str = None, hdf5_name_membrane: str = None, n_chunk_jobs: int = None): """ Method for the conversion of classified (hard labels, e.g. 0, 1, 2; see `load_from_kd_overlaycubes` and `transf_func_kd_overlay` parameters) or predicted (probability maps, e.g. 0 .. 1 or 0 .. 255, see `thresholds` parameter). Original data can be provided as ChunkDataset `cset` or via KnossosDataset(s) `prob_kd_path_dict`. The ChunkDataset will be used in any case for storing the intermediate extraction results (per-cube segmentation, stitched results, globally unique segmentation). Notes: * KnossosDatasets given by `target_kd_paths` need to be initialized prior to this function call. Args: target_kd_paths: Paths to (already initialized) output KnossosDatasets. See ``KnossosDataset.initialize_without_conf``. cset: ChunkDataset which is used for the object extraction process and which may additionally contain the source data. The latter can be provided as KnossosDataset(s) (see `prob_kd_path_dict`). filename: The base name used to store the extracted in `cset`. hdf5names: Keys used to store the intermediate extraction results. prob_kd_path_dict: Paths to source KnossosDatasets load_from_kd_overlaycubes: Load prob/seg data from overlaycubes instead of raw cubes. transf_func_kd_overlay: Method which is to applied to cube data if `load_from_kd_overlaycubes` is True. log: TODO: pass log to all methods called overlap: Defines the overlap with neighbouring chunks that is left for later processing steps; if 'auto' the overlap is calculated from the sigma and the stitch_overlap (here: [1., 1., 1.]). sigmas: Defines the sigmas of the Gaussian filters applied to the probability maps. Has to be the same length as hdf5names. If None, no Gaussian filter is applied. thresholds: Threshold for cutting the probability map. Has to be the same length as hdf5names. If None, zeros are used instead (not recommended!) debug: If True, multiprocessing steps only operate on one core using 'map' which allows for better error messages. swapdata: If true an x-z swap is applied to the data prior to processing. offset: Offset of the processed volume. size: Size of the processed volume of the dataset starting at `offset`. suffix: Suffix used for the intermediate processing steps. transform_func: [WIP] Segmentation method which is applied, currently only func:`~syconn.extraction.object_extraction_steps. _object_segmentation_thread` is supported for batch jobs. func_kwargs: keyword arguments for `transform_func`. n_chunk_jobs: Number of jobs. n_cores: Number of cores used for each job in :func:`syconn.extraction.object_extraction_steps.object_segmentation` if batch jobs is enabled. overlap_thresh: Overlap fraction of object in different chunks to be considered stitched. If zero this behavior is disabled. stitch_overlap: Volume evaluated during stitching procedure. membrane_filename: Experimental. One way to allow access to a membrane segmentation when processing vesicle clouds. Filename of the prediction in the chunkdataset. The threshold is currently set at 0.4. membrane_kd_path: Experimental. One way to allow access to a membrane segmentation when processing vesicle clouds. Path to the knossosdataset containing a membrane segmentation. The threshold is currently set at 0.4. hdf5_name_membrane: Experimental. When `membrane_filename` is set this key has to be given to access the data in the saved chunk. Returns: """ if log is None: log = log_extraction all_times = [] step_names = [] if prob_kd_path_dict is not None: kd_keys = list(prob_kd_path_dict.keys()) assert len(kd_keys) == len(hdf5names) for kd_key in kd_keys: assert kd_key in hdf5names if size is not None and offset is not None: chunk_list, chunk_translator = \ calculate_chunk_numbers_for_box(cset, offset, size) else: chunk_translator = {} chunk_list = [ii for ii in range(len(cset.chunk_dict))] for ii in range(len(cset.chunk_dict)): chunk_translator[ii] = ii if thresholds is not None and thresholds[0] <= 1.: thresholds = np.array(thresholds) thresholds *= 255 if sigmas is not None and swapdata == 1: for nb_sigma in range(len(sigmas)): if len(sigmas[nb_sigma]) == 3: sigmas[nb_sigma] = \ basics.switch_array_entries(sigmas[nb_sigma], [0, 2]) # # -------------------------------------------------------------------------- time_start = time.time() cc_info_list, overlap_info = oes.object_segmentation( cset, filename, hdf5names, overlap=overlap, sigmas=sigmas, thresholds=thresholds, chunk_list=chunk_list, debug=debug, swapdata=swapdata, prob_kd_path_dict=prob_kd_path_dict, membrane_filename=membrane_filename, membrane_kd_path=membrane_kd_path, hdf5_name_membrane=hdf5_name_membrane, fast_load=True, suffix=suffix, transform_func=transform_func, transform_func_kwargs=func_kwargs, nb_cpus=n_cores, load_from_kd_overlaycubes=load_from_kd_overlaycubes, transf_func_kd_overlay=transf_func_kd_overlay, n_chunk_jobs=n_chunk_jobs) if stitch_overlap is None: stitch_overlap = overlap_info[1] else: overlap_info[1] = stitch_overlap if not np.all(stitch_overlap <= overlap_info[0]): msg = "Stitch overlap ({}) has to be <= than chunk overlap ({})." \ "".format(overlap_info[1], overlap_info[0]) log.error(msg) raise ValueError(msg) overlap = overlap_info[0] all_times.append(time.time() - time_start) step_names.append("conneceted components") basics.write_obj2pkl(cset.path_head_folder.rstrip("/") + "/connected_components.pkl", [cc_info_list, overlap_info]) # # # ------------------------------------------------------------------------ time_start = time.time() nb_cc_dict = {} max_nb_dict = {} max_labels = {} for hdf5_name in hdf5names: nb_cc_dict[hdf5_name] = np.zeros(len(chunk_list), dtype=np.int32) max_nb_dict[hdf5_name] = np.zeros(len(chunk_list), dtype=np.int32) for cc_info in cc_info_list: nb_cc_dict[cc_info[1]][chunk_translator[cc_info[0]]] = cc_info[2] for hdf5_name in hdf5names: max_nb_dict[hdf5_name][0] = 0 for nb_chunk in range(1, len(chunk_list)): max_nb_dict[hdf5_name][nb_chunk] = \ max_nb_dict[hdf5_name][nb_chunk - 1] + nb_cc_dict[hdf5_name][nb_chunk - 1] max_labels[hdf5_name] = int(max_nb_dict[hdf5_name][-1] + nb_cc_dict[hdf5_name][-1]) all_times.append(time.time() - time_start) step_names.append("max labels") basics.write_obj2pkl(cset.path_head_folder.rstrip("/") + "/max_labels.pkl", max_labels) # # ------------------------------------------------------------------------ time_start = time.time() oes.make_unique_labels(cset, filename, hdf5names, chunk_list, max_nb_dict, chunk_translator, debug, suffix=suffix, n_chunk_jobs=n_chunk_jobs, nb_cpus=n_cores) all_times.append(time.time() - time_start) step_names.append("unique labels") # # ------------------------------------------------------------------------ chunky.save_dataset(cset) # save dataset to be able to load it during make_stitch_list (this # allows to load the ChunkDataset inside the worker instead of pickling it for each, which # slows down the submission process. time_start = time.time() stitch_list = oes.make_stitch_list(cset, filename, hdf5names, chunk_list, stitch_overlap, overlap, debug, suffix=suffix, overlap_thresh=overlap_thresh, n_chunk_jobs=n_chunk_jobs, nb_cpus=n_cores) all_times.append(time.time() - time_start) step_names.append("stitch list") basics.write_obj2pkl(cset.path_head_folder.rstrip("/") + "/stitch_list.pkl", stitch_list) # # # ------------------------------------------------------------------------ # time_start = time.time() merge_dict, merge_list_dict = oes.make_merge_list(hdf5names, stitch_list, max_labels) all_times.append(time.time() - time_start) step_names.append("merge list") basics.write_obj2pkl(cset.path_head_folder.rstrip("/") + "/merge_list.pkl", [merge_dict, merge_list_dict]) # -------------------------------------------------------------------------- time_start = time.time() oes.apply_merge_list(cset, chunk_list, filename, hdf5names, merge_list_dict, debug, suffix=suffix, n_chunk_jobs=n_chunk_jobs, nb_cpus=n_cores) all_times.append(time.time() - time_start) step_names.append("apply merge list") time_start = time.time() chunky.save_dataset(cset) oes.export_cset_to_kd_batchjob( target_kd_paths, cset, '{}_stitched_components'.format(filename), hdf5names, offset=offset, size=size, stride=cset.chunk_size, as_raw=False, orig_dtype=np.uint64, unified_labels=False, log=log, n_max_job=n_chunk_jobs, n_cores=n_cores) all_times.append(time.time() - time_start) step_names.append("export KD") # -------------------------------------------------------------------------- log.debug("Time overview [from_probabilities_to_kd]:") for ii in range(len(all_times)): log.debug("%s: %.3fs" % (step_names[ii], all_times[ii])) log.debug("--------------------------") log.debug("Total Time: %.1f min" % (np.sum(all_times) / 60)) log.debug("--------------------------")