# Source code for syconn.extraction.object_extraction_wrapper

# -*- coding: utf-8 -*-
# SyConn - Synaptic connectivity inference toolkit
#
# Copyright (c) 2016 - now
# Max Planck Institute of Neurobiology, Martinsried, Germany
# Authors: Philipp Schubert, Joergen Kornfeld

import os
import shutil
import time
from logging import Logger
from typing import Optional, Dict, List, Tuple, Union, Callable

import numpy as np
from knossos_utils import chunky, knossosdataset

from . import object_extraction_steps as oes
from .. import global_params
from ..extraction import log_extraction
from ..handler import basics


def calculate_chunk_numbers_for_box(cset, offset, size):
    """
    Collect the ids of all chunks that are (partly) contained in a volume.

    The box given by ``offset`` and ``size`` is first expanded so that it
    starts and ends on multiples of ``cset.chunk_size``. Every chunk origin
    inside the expanded box is then looked up in ``cset.coord_dict``.

    Note:
        ``offset`` and ``size`` are modified **in place** by the alignment
        step; after the call they describe the chunk-aligned box.

    Args:
        cset (ChunkDataset): Provides ``chunk_size`` (per-axis chunk extent)
            and ``coord_dict`` (chunk origin tuple -> chunk id).
        offset (np.array): Offset of the volume to the origin (3 entries).
        size (np.array): Size of the volume (3 entries).

    Returns:
        chunk_list (list): Chunk ids, ordered with x as the slowest and z as
            the fastest varying axis.
        translator (dict): Reverse mapping from chunk id to its index in
            ``chunk_list``.
    """
    chunk_shape = cset.chunk_size
    for axis in range(3):
        # Move the start of the box down to the previous chunk border and
        # grow the box by the same amount so that its end stays where it was.
        remainder = offset[axis] % chunk_shape[axis]
        offset[axis] -= remainder
        size[axis] += remainder
        # Pad the extent up to the next multiple of the chunk size.
        size[axis] += (chunk_shape[axis] - size[axis]) % chunk_shape[axis]

    x_starts, y_starts, z_starts = (
        range(offset[axis], offset[axis] + size[axis], chunk_shape[axis])
        for axis in range(3)
    )
    chunk_list = [cset.coord_dict[(x, y, z)]
                  for x in x_starts
                  for y in y_starts
                  for z in z_starts]
    translator = {chunk_id: index for index, chunk_id in enumerate(chunk_list)}
    return chunk_list, translator
def generate_subcell_kd_from_proba(
        subcell_names: List[str],
        chunk_size: Optional[Union[list, tuple]] = None,
        transf_func_kd_overlay: Optional[Dict[str, Callable]] = None,
        load_cellorganelles_from_kd_overlaycubes: bool = False,
        cube_of_interest_bb: Optional[Tuple[np.ndarray]] = None,
        cube_shape: Optional[Tuple[int]] = None,
        log: Optional[Logger] = None,
        overwrite: bool = False,
        **kwargs):
    """
    Generate a connected components segmentation for the given sub-cellular
    structures (e.g. ['mi', 'sj', 'vc']) as KnossosDatasets.

    The data format of the source data is KnossosDataset which path(s) is
    defined in ``global_params.config['paths']`` (e.g. key ``kd_mi_path`` for
    mitochondria). The resulting KDs will be stored at (for each
    ``co in subcell_names``)
    ``"{}/knossosdatasets/{}_seg/".format(global_params.config.working_dir, co)``.
    See :func:`~syconn.extraction.object_extraction_wrapper.from_probabilities_to_kd`
    for details of the conversion process from the initial probability map to
    the SV segmentation. Default: thresholding and connected components,
    thresholds are set via the `config.yml` file, check
    ``syconn.global_params.config['cell_objects']["probathresholds"]`` of an
    initialized :class:`~syconn.handler.config.DynConfig` object.

    A temporary ChunkDataset is created below
    ``<working_dir>/chunkdatasets/`` and removed again once the extraction
    has finished.

    Args:
        subcell_names (List[str]): Sub-cellular structures to generate a
            segmentation for.
        chunk_size (Optional[Union[list, tuple]]): Size of the chunks to be
            processed. Defaults to ``[512, 512, 512]``.
        transf_func_kd_overlay (Optional[Dict[str, Callable]]): Transformation
            functions for overlay data, forwarded to
            :func:`from_probabilities_to_kd`.
        load_cellorganelles_from_kd_overlaycubes (bool): Load cell organelles
            from overlay cubes. If True, no thresholds are passed on.
        cube_of_interest_bb (Optional[Tuple[np.ndarray]]): Bounding box
            (first entry: lower corner, second entry: upper corner) of the
            cube of interest. Defaults to the full extent of the segmentation
            KnossosDataset. The passed arrays are not modified.
        cube_shape (Optional[Tuple[int]]): Cube shape of the target
            KnossosDatasets. Defaults to ``(256, 256, 256)``.
        log (Optional[Logger]): Logger used for all messages. Defaults to the
            module logger ``log_extraction``.
        overwrite (bool): Remove an existing ChunkDataset / target
            KnossosDataset instead of raising.
        **kwargs: Additional keyword arguments forwarded to
            :func:`from_probabilities_to_kd`.

    Raises:
        FileExistsError: If the temporary ChunkDataset or one of the target
            KnossosDatasets already exists and ``overwrite`` is False.

    Returns:
        None
    """
    if chunk_size is None:
        chunk_size = [512, 512, 512]
    if log is None:
        log = log_extraction
    if cube_shape is None:
        cube_shape = (256, 256, 256)

    kd = basics.kd_factory(global_params.config.kd_seg_path)
    if cube_of_interest_bb is None:
        cube_of_interest_bb = [np.zeros(3, dtype=np.int32), kd.boundary]
    size = cube_of_interest_bb[1] - cube_of_interest_bb[0] + 1
    # Work on a copy: the chunk alignment performed downstream adjusts the
    # offset in place and must not alter the caller's bounding box.
    offset = np.array(cube_of_interest_bb[0])

    cd_dir = "{}/chunkdatasets/{}/".format(global_params.config.working_dir,
                                           "_".join(subcell_names))
    if os.path.isdir(cd_dir):
        if not overwrite:
            msg = f'Could not start generation of sub-cellular objects ' \
                  f'"{subcell_names}" ChunkDataset because it already exists at "{cd_dir}" ' \
                  f'and overwrite was not set to True.'
            log.error(msg)
            raise FileExistsError(msg)
        log.debug('Found existing ChunkDataset at {}. Removing it now.'.format(cd_dir))
        shutil.rmtree(cd_dir)

    cd = chunky.ChunkDataset()
    # TODO: possible to restrict ChunkDataset here already to report correct
    #  number of processed chunks? Check coordinate framework compatibility
    #  downstream in `from_probabilities_to_kd`
    cd.initialize(kd, kd.boundary, chunk_size, cd_dir,
                  box_coords=[0, 0, 0], fit_box_size=True,
                  list_of_coords=[])
    log.info('Started object extraction of cellular organelles "{}" from '
             '{} chunks.'.format(", ".join(subcell_names), len(cd.chunk_dict)))

    prob_kd_path_dict = {co: getattr(global_params.config, 'kd_{}_path'.format(co))
                         for co in subcell_names}

    # Collect the probability threshold of every organelle type and set up
    # an empty target KnossosDataset for each of them.
    prob_threshs = []
    for co in subcell_names:
        prob_threshs.append(global_params.config['cell_objects']["probathresholds"][co])
        path = global_params.config.kd_organelle_seg_paths[co]
        if os.path.isdir(path):
            if not overwrite:
                msg = f'Could not start generation of sub-cellular object ' \
                      f'"{co}" KnossosDataset because it already exists at "{path}" and overwrite ' \
                      f'was not set to True.'
                log.error(msg)
                raise FileExistsError(msg)
            log.debug('Found existing KD at {}. Removing it now.'.format(path))
            shutil.rmtree(path)
        target_kd = knossosdataset.KnossosDataset()
        scale = np.array(global_params.config['scaling'], dtype=np.float32)
        target_kd._cube_shape = cube_shape
        target_kd.scales = [scale, ]
        target_kd.initialize_without_conf(path, kd.boundary, scale, kd.experiment_name,
                                          mags=[1, ], create_pyk_conf=True,
                                          create_knossos_conf=False)

    if load_cellorganelles_from_kd_overlaycubes:  # no thresholds needed
        prob_threshs = None

    from_probabilities_to_kd(
        global_params.config.kd_organelle_seg_paths, cd, "_".join(subcell_names),
        # membrane_kd_path=global_params.config.kd_barrier_path,  # TODO: currently does not exist
        prob_kd_path_dict=prob_kd_path_dict, thresholds=prob_threshs,
        hdf5names=subcell_names, size=size, offset=offset,
        load_from_kd_overlaycubes=load_cellorganelles_from_kd_overlaycubes,
        transf_func_kd_overlay=transf_func_kd_overlay, log=log, **kwargs)
    # The ChunkDataset only holds intermediate results; drop it afterwards.
    shutil.rmtree(cd_dir, ignore_errors=True)
def from_probabilities_to_kd(
        target_kd_paths: Optional[Dict[str, str]],
        cset: 'chunky.ChunkDataset',
        filename: str,
        hdf5names: List[str],
        prob_kd_path_dict: Optional[Dict[str, str]] = None,
        load_from_kd_overlaycubes: bool = False,
        transf_func_kd_overlay: Optional[Dict[str, Callable]] = None,
        log: Optional[Logger] = None,
        overlap: str = "auto",
        sigmas: Optional[list] = None,
        thresholds: Optional[list] = None,
        debug: bool = False,
        swapdata: bool = False,
        offset: Optional[np.ndarray] = None,
        size: Optional[np.ndarray] = None,
        suffix: str = "",
        transform_func: Optional[Callable] = None,
        func_kwargs: Optional[dict] = None,
        n_cores: Optional[int] = None,
        overlap_thresh: Optional[int] = 0,
        stitch_overlap: Optional[int] = None,
        membrane_filename: str = None,
        membrane_kd_path: str = None,
        hdf5_name_membrane: str = None,
        n_chunk_jobs: int = None):
    """
    Converts classified or predicted data into a ChunkDataset or
    KnossosDataset(s).

    The ChunkDataset is used to store intermediate extraction results such as
    per-cube segmentation, stitched results, and globally unique segmentation.
    The function requires pre-initialized KnossosDatasets given by
    `target_kd_paths`.

    Processing steps (each one is timed and reported via ``log.debug``):

    1. object segmentation per chunk (``oes.object_segmentation``),
    2. computation of per-chunk label offsets and the maximum label,
    3. making labels globally unique (``oes.make_unique_labels``),
    4. collecting objects to stitch across chunks (``oes.make_stitch_list``),
    5. building the merge list (``oes.make_merge_list``),
    6. applying the merge list (``oes.apply_merge_list``),
    7. export to the target KnossosDataset(s)
       (``oes.export_cset_to_kd_batchjob``).

    Intermediate results are pickled to ``cset.path_head_folder``
    (``connected_components.pkl``, ``max_labels.pkl``, ``stitch_list.pkl``,
    ``merge_list.pkl``).

    Args:
        target_kd_paths (Optional[Dict[str, str]]): Paths to pre-initialized
            output KnossosDatasets.
        cset (chunky.ChunkDataset): ChunkDataset used for object extraction
            and may contain source data.
        filename (str): Base name used to store the extracted in `cset`.
        hdf5names (List[str]): Keys used to store intermediate extraction
            results.
        prob_kd_path_dict (Optional[Dict[str, str]]): Paths to source
            KnossosDatasets. Its keys must match ``hdf5names``.
        load_from_kd_overlaycubes (bool): If True, load prob/seg data from
            overlaycubes instead of raw cubes.
        transf_func_kd_overlay (Optional[Dict[str, Callable]]): Method applied
            to cube data if `load_from_kd_overlaycubes` is True.
        log (Optional[Logger]): Logger for logging events. Defaults to the
            module logger ``log_extraction``.
        overlap (str): Defines overlap with neighbouring chunks left for later
            processing steps.
        sigmas (Optional[list]): Defines sigmas of Gaussian filters applied to
            probability maps. If ``swapdata`` is set, 3-element entries are
            replaced in this list with their x-z swapped version.
        thresholds (Optional[list]): Threshold for cutting probability map. If
            the first threshold is <= 1, all thresholds are scaled by 255.
        debug (bool): If True, multiprocessing steps only operate on one core
            using 'map'.
        swapdata (bool): If True, an x-z swap is applied to data prior to
            processing.
        offset (Optional[np.ndarray]): Offset of processed volume. The passed
            array is not modified; a chunk-aligned copy is used internally.
        size (Optional[np.ndarray]): Size of processed volume of dataset
            starting at `offset`. The passed array is not modified; a
            chunk-aligned copy is used internally.
        suffix (str): Suffix used for intermediate processing steps.
        transform_func (Optional[Callable]): Segmentation method applied.
        func_kwargs (Optional[dict]): Keyword arguments for `transform_func`.
        n_cores (Optional[int]): Number of cores used for each job.
        overlap_thresh (Optional[int]): Overlap fraction of object in
            different chunks to be considered stitched.
        stitch_overlap (Optional[int]): Volume evaluated during stitching
            procedure. Must not exceed the chunk overlap.
        membrane_filename (str): Filename of prediction in chunkdataset for
            accessing membrane segmentation.
        membrane_kd_path (str): Path to knossosdataset containing a membrane
            segmentation.
        hdf5_name_membrane (str): Key to access data in saved chunk when
            `membrane_filename` is set.
        n_chunk_jobs (int): Number of jobs.

    Raises:
        ValueError: If the stitch overlap is larger than the chunk overlap.

    Returns:
        None
    """
    if log is None:
        log = log_extraction
    all_times = []
    step_names = []

    def _pkl_path(name):
        """Path of an intermediate result file inside the ChunkDataset folder."""
        return cset.path_head_folder.rstrip("/") + "/" + name

    if prob_kd_path_dict is not None:
        kd_keys = list(prob_kd_path_dict.keys())
        assert len(kd_keys) == len(hdf5names)
        for kd_key in kd_keys:
            assert kd_key in hdf5names

    if size is not None and offset is not None:
        # `calculate_chunk_numbers_for_box` aligns its arguments to the chunk
        # grid in place. Operate on copies so the caller's arrays stay
        # untouched; the aligned copies are used for the export below.
        offset = np.array(offset)
        size = np.array(size)
        chunk_list, chunk_translator = \
            calculate_chunk_numbers_for_box(cset, offset, size)
    else:
        # No sub-volume requested: process every chunk, identity mapping.
        chunk_list = list(range(len(cset.chunk_dict)))
        chunk_translator = {ii: ii for ii in chunk_list}

    if thresholds is not None and thresholds[0] <= 1.:
        # Thresholds given as fractions are rescaled by 255.
        thresholds = np.array(thresholds)
        thresholds *= 255

    if sigmas is not None and swapdata == 1:
        for nb_sigma in range(len(sigmas)):
            if len(sigmas[nb_sigma]) == 3:
                sigmas[nb_sigma] = \
                    basics.switch_array_entries(sigmas[nb_sigma], [0, 2])

    # --------------------------------------------------------------------------
    # Step 1: per-chunk object segmentation
    time_start = time.time()
    cc_info_list, overlap_info = oes.object_segmentation(
        cset, filename, hdf5names, overlap=overlap, sigmas=sigmas,
        thresholds=thresholds, chunk_list=chunk_list, debug=debug,
        swapdata=swapdata, prob_kd_path_dict=prob_kd_path_dict,
        membrane_filename=membrane_filename, membrane_kd_path=membrane_kd_path,
        hdf5_name_membrane=hdf5_name_membrane, fast_load=True,
        suffix=suffix, transform_func=transform_func,
        transform_func_kwargs=func_kwargs, nb_cpus=n_cores,
        load_from_kd_overlaycubes=load_from_kd_overlaycubes,
        transf_func_kd_overlay=transf_func_kd_overlay,
        n_chunk_jobs=n_chunk_jobs)
    if stitch_overlap is None:
        stitch_overlap = overlap_info[1]
    else:
        overlap_info[1] = stitch_overlap
    if not np.all(stitch_overlap <= overlap_info[0]):
        msg = "Stitch overlap ({}) has to be <= than chunk overlap ({})." \
              "".format(overlap_info[1], overlap_info[0])
        log.error(msg)
        raise ValueError(msg)
    overlap = overlap_info[0]
    all_times.append(time.time() - time_start)
    step_names.append("connected components")
    basics.write_obj2pkl(_pkl_path("connected_components.pkl"),
                         [cc_info_list, overlap_info])

    # --------------------------------------------------------------------------
    # Step 2: label offset of every chunk (running sum of the object counts of
    # all preceding chunks) and the resulting maximum label per key
    time_start = time.time()
    nb_cc_dict = {}
    max_nb_dict = {}
    max_labels = {}
    for hdf5_name in hdf5names:
        nb_cc_dict[hdf5_name] = np.zeros(len(chunk_list), dtype=np.int32)
        max_nb_dict[hdf5_name] = np.zeros(len(chunk_list), dtype=np.int32)
    for cc_info in cc_info_list:
        # cc_info[0] is translated to a chunk index, cc_info[1] selects the
        # key and cc_info[2] is stored as the count for that chunk
        nb_cc_dict[cc_info[1]][chunk_translator[cc_info[0]]] = cc_info[2]
    for hdf5_name in hdf5names:
        max_nb_dict[hdf5_name][0] = 0
        for nb_chunk in range(1, len(chunk_list)):
            max_nb_dict[hdf5_name][nb_chunk] = \
                max_nb_dict[hdf5_name][nb_chunk - 1] + \
                nb_cc_dict[hdf5_name][nb_chunk - 1]
        max_labels[hdf5_name] = int(max_nb_dict[hdf5_name][-1] +
                                    nb_cc_dict[hdf5_name][-1])
    all_times.append(time.time() - time_start)
    step_names.append("max labels")
    basics.write_obj2pkl(_pkl_path("max_labels.pkl"), max_labels)

    # --------------------------------------------------------------------------
    # Step 3: make labels unique across chunks
    time_start = time.time()
    oes.make_unique_labels(cset, filename, hdf5names, chunk_list, max_nb_dict,
                           chunk_translator, debug, suffix=suffix,
                           n_chunk_jobs=n_chunk_jobs, nb_cpus=n_cores)
    all_times.append(time.time() - time_start)
    step_names.append("unique labels")

    # --------------------------------------------------------------------------
    # Step 4: stitch list
    # Save dataset to be able to load it during make_stitch_list (this allows
    # to load the ChunkDataset inside the worker instead of pickling it for
    # each, which slows down the submission process).
    chunky.save_dataset(cset)
    time_start = time.time()
    stitch_list = oes.make_stitch_list(cset, filename, hdf5names, chunk_list,
                                       stitch_overlap, overlap, debug,
                                       suffix=suffix,
                                       overlap_thresh=overlap_thresh,
                                       n_chunk_jobs=n_chunk_jobs,
                                       nb_cpus=n_cores)
    all_times.append(time.time() - time_start)
    step_names.append("stitch list")
    basics.write_obj2pkl(_pkl_path("stitch_list.pkl"), stitch_list)

    # --------------------------------------------------------------------------
    # Step 5: merge list
    time_start = time.time()
    merge_dict, merge_list_dict = oes.make_merge_list(hdf5names, stitch_list,
                                                      max_labels)
    all_times.append(time.time() - time_start)
    step_names.append("merge list")
    basics.write_obj2pkl(_pkl_path("merge_list.pkl"),
                         [merge_dict, merge_list_dict])

    # --------------------------------------------------------------------------
    # Step 6: apply merge list
    time_start = time.time()
    oes.apply_merge_list(cset, chunk_list, filename, hdf5names,
                         merge_list_dict, debug, suffix=suffix,
                         n_chunk_jobs=n_chunk_jobs, nb_cpus=n_cores)
    all_times.append(time.time() - time_start)
    step_names.append("apply merge list")

    # --------------------------------------------------------------------------
    # Step 7: export the stitched components to the target KnossosDataset(s)
    time_start = time.time()
    chunky.save_dataset(cset)
    oes.export_cset_to_kd_batchjob(
        target_kd_paths, cset, '{}_stitched_components'.format(filename),
        hdf5names, offset=offset, size=size, stride=cset.chunk_size,
        as_raw=False, orig_dtype=np.uint64, unified_labels=False, log=log,
        n_max_job=n_chunk_jobs, n_cores=n_cores)
    all_times.append(time.time() - time_start)
    step_names.append("export KD")

    # --------------------------------------------------------------------------
    log.debug("Time overview [from_probabilities_to_kd]:")
    for step_name, step_time in zip(step_names, all_times):
        log.debug("%s: %.3fs", step_name, step_time)
    log.debug("--------------------------")
    log.debug("Total Time: %.1f min", np.sum(all_times) / 60)
    log.debug("--------------------------")