Source code for syconn.proc.sd_proc

# -*- coding: utf-8 -*-
# SyConn - Synaptic connectivity inference toolkit
#
# Copyright (c) 2016 - now
# Max Planck Institute of Neurobiology, Martinsried, Germany
# Authors: Philipp Schubert, Joergen Kornfeld

import gc
import glob
import os
import sys
import time

import tqdm

from . import log_proc
from .image import single_conn_comp_img
from .meshes import mesh_area_calc, merge_meshes_incl_norm
from .. import global_params
from ..backend.storage import AttributeDict, VoxelStorage, VoxelStorageDyn, MeshStorage, CompressedStorage
from ..extraction import object_extraction_wrapper as oew
from ..handler import basics
from ..mp import batchjob_utils as qu
from ..mp import mp_utils as sm
from ..proc.meshes import mesh_chunk, find_meshes
from ..reps import rep_helper
from ..reps import segmentation
from ..extraction.find_object_properties import map_subcell_extract_props as map_subcell_extract_props_func

from multiprocessing import Process
import pickle as pkl
import numpy as np
from logging import Logger
import shutil
from collections import defaultdict
from knossos_utils import chunky
from typing import Optional, List, Union


[docs]def dataset_analysis(sd, recompute=True, n_jobs=None, compute_meshprops=False): """ Analyzes a SegmentationDataset and extracts and caches SegmentationObjects attributes as numpy arrays. This function only recognizes dict/storage entries of type int for object attribute collection. Args: sd (SegmentationDataset): The SegmentationDataset to analyze. This is typically a set of cell supervoxels ('sv'). recompute (bool, optional): A flag indicating whether to recompute key information of each object (representative coordinate, bounding box, size). Defaults to True. n_jobs (int, optional): The number of jobs to run in parallel. Defaults to None. compute_meshprops (bool, optional): A flag indicating whether to compute mesh properties. If set to True, it will also calculate meshes (sparsely) if not available. Defaults to False. """ if n_jobs is None: n_jobs = global_params.config.ncore_total # individual tasks are very fast if recompute or compute_meshprops: n_jobs *= 4 paths = sd.so_dir_paths if compute_meshprops: if sd.type not in global_params.config['meshes']['downsampling']: msg = 'SegmentationDataset of type "{}" has no configured mesh parameters. ' \ 'Please add them to global_params.py accordingly.' log_proc.error(msg) raise ValueError(msg) # Partitioning the work multi_params = basics.chunkify(paths, n_jobs) multi_params = [(mps, sd.type, sd.version, sd.working_dir, recompute, compute_meshprops) for mps in multi_params] # Running workers if not qu.batchjob_enabled(): results = sm.start_multiprocess_imap(_dataset_analysis_thread, multi_params, debug=False) # Creating summaries attr_dict = {} for this_attr_dict in results: for attribute in this_attr_dict: if len(this_attr_dict['id']) == 0: continue value = this_attr_dict[attribute] if attribute == 'id': value = np.array(value, np.uint64) if attribute not in attr_dict: if type(value) is not list: sh = list(value.shape) sh[0] = 0 attr_dict[attribute] = np.empty(sh, dtype=value.dtype) else: attr_dict[attribute] = [] if type(value) is not list: # assume numpy array attr_dict[attribute] = np.concatenate([attr_dict[attribute], value]) else: attr_dict[attribute] += value for attribute in attr_dict: if attribute in ['cs_ids', 'mapping_mi_ids', 'mapping_mi_ratios', 'mapping_sj_ids', 'mapping_vc_ids', 'mapping_vc_ratios', 'mapping_sj_ratios']: np.save(sd.path + "/%ss.npy" % attribute, np.array(attr_dict[attribute], dtype=object)) else: np.save(sd.path + "/%ss.npy" % attribute, attr_dict[attribute]) else: path_to_out = qu.batchjob_script(multi_params, "dataset_analysis", suffix=sd.type) out_files = np.array(glob.glob(path_to_out + "/*")) res_keys = [] file_mask = np.zeros(len(out_files), dtype=np.int64) res = sm.start_multiprocess_imap(_dataset_analysis_check, out_files, sm.cpu_count()) for ix_cnt, r in enumerate(res): rk, n_el = r if n_el > 0: file_mask[ix_cnt] = n_el if len(res_keys) == 0: res_keys = rk if len(res_keys) == 0: raise ValueError(f'No objects found during dataset_analysis of {sd}.') n_ids = np.sum(file_mask) log_proc.info(f'Caching {len(res_keys)} attributes of {n_ids} objects in {sd} during ' f'dataset_analysis:\n{res_keys}') out_files = out_files[file_mask > 0] params = [(attr, out_files, n_ids, sd.path) for attr in res_keys] qu.batchjob_script(params, 'dataset_analysis_collect', n_cores=global_params.config['ncores_per_node'], remove_jobfolder=True) shutil.rmtree(os.path.abspath(path_to_out + "/../"), ignore_errors=True)
def _dataset_analysis_check(out_file): """ Checks the output file of the dataset analysis process. This function is typically used in a multiprocessing context to verify the results of the analysis. Args: out_file (str): The path to the output file to check. """ res_keys = [] with open(out_file, 'rb') as f: res_dc = pkl.load(f) n_el = len(res_dc['id']) if n_el > 0: res_keys = list(res_dc.keys()) return res_keys, n_el def _dataset_analysis_collect(args): """ Collects the results of the dataset analysis process. This function is typically used in a multiprocessing context to gather the results of the analysis. Args: args (tuple): A tuple containing the attribute to collect, the output files to collect from, the number of IDs, and the path to the SegmentationDataset. """ attribute, out_files, n_ids, sd_path = args # start_multiprocess_imap obeys parameter order and therefore the # collected attributes will share the same ordering. n_jobs = min(len(out_files), global_params.config['ncores_per_node'] * 4) params = list(basics.chunkify([(p, attribute) for p in out_files], n_jobs)) tmp_res = sm.start_multiprocess_imap( _load_attr_helper, params, nb_cpus=global_params.config['ncores_per_node'] // 2, debug=False) if attribute in ['cs_ids', 'mapping_mi_ids', 'mapping_mi_ratios', 'mapping_sj_ids', 'mapping_vc_ids', 'mapping_vc_ratios', 'mapping_sj_ratios']: tmp_res = [el for lst in tmp_res for el in lst] # flatten lists tmp_res = np.array(tmp_res, dtype=object) else: tmp_res = np.concatenate(tmp_res) assert tmp_res.shape[0] == n_ids, f'Shape mismatch during dataset_analysis of property {attribute}.' np.save(f"{sd_path}/{attribute}s.npy", tmp_res) def _load_attr_helper(args): """ Helper function to load attributes during the dataset analysis process. This function is typically used in a multiprocessing context to load attributes in parallel. Args: args (tuple): A tuple containing the file name and the attribute to load. """ res = [] attr = args[0][1] for arg in args: fname, attr_ex = arg assert attr == attr_ex with open(fname, 'rb') as f: dc = pkl.load(f) if len(dc['id']) == 0: continue value = dc[attr] if attr == 'id': value = np.array(value, np.uint64) if type(value) is not list: # assume numpy array if len(res) == 0: sh = list(value.shape) sh[0] = 0 res = np.empty(sh, dtype=value.dtype) res = np.concatenate([res, value]) else: res += value return res def _dataset_analysis_thread(args): """ Worker function for the dataset analysis process. This function is typically used in a multiprocessing context to perform the analysis in parallel. Args: args (tuple): A tuple containing the paths to process, the object type, the version, the working directory, a flag indicating whether to recompute, and a flag indicating whether to compute mesh properties. """ # TODO: use arrays to store properties already during collection paths = args[0] obj_type = args[1] version = args[2] working_dir = args[3] recompute = args[4] compute_meshprops = args[5] global_attr_dict = dict(id=[], size=[], bounding_box=[], rep_coord=[]) for p in paths: if not len(os.listdir(p)) > 0: os.rmdir(p) else: new_mesh_generated = False this_attr_dc = AttributeDict(p + "/attr_dict.pkl", read_only=not recompute) if recompute: this_vx_dc = VoxelStorage(p + "/voxel.pkl", read_only=True, disable_locking=True) so_ids = list(this_vx_dc.keys()) else: so_ids = list(this_attr_dc.keys()) if compute_meshprops: this_mesh_dc = MeshStorage(p + "/mesh.pkl", read_only=True, disable_locking=True) for so_id in so_ids: global_attr_dict["id"].append(so_id) so = segmentation.SegmentationObject(so_id, obj_type, version, working_dir) so.attr_dict = this_attr_dc[so_id] if recompute: # prevent loading voxels in case we use VoxelStorageDyn if not isinstance(this_vx_dc, VoxelStorageDyn): # use fall-back, so._voxels will be used for mesh computation but also # triggers the calculation of size and bounding box. so.load_voxels(voxel_dc=this_vx_dc) else: # VoxelStorageDyn stores pre-computed bounding box, size and rep coord values. so.calculate_bounding_box(this_vx_dc) so.calculate_size(this_vx_dc) so.calculate_rep_coord(this_vx_dc) so.attr_dict["rep_coord"] = so.rep_coord so.attr_dict["bounding_box"] = so.bounding_box so.attr_dict["size"] = so.size if compute_meshprops: # make sure so._mesh is available prior to mesh_bb and mesh_area call (otherwise every unavailable # mesh will be generated from scratch and saved to so.mesh_path for every single object. if so.id in this_mesh_dc: so._mesh = this_mesh_dc[so.id] else: new_mesh_generated = True so._mesh = so.mesh_from_scratch() this_mesh_dc[so.id] = so._mesh # if mesh does not exist beforehand, it will be generated so.attr_dict["mesh_bb"] = so.mesh_bb so.attr_dict["mesh_area"] = so.mesh_area for attribute in so.attr_dict.keys(): if attribute not in global_attr_dict: global_attr_dict[attribute] = [] global_attr_dict[attribute].append(so.attr_dict[attribute]) this_attr_dc[so_id] = so.attr_dict if recompute or compute_meshprops: this_attr_dc.push() if new_mesh_generated: this_mesh_dc.push() if 'bounding_box' in global_attr_dict: global_attr_dict['bounding_box'] = np.array(global_attr_dict['bounding_box'], dtype=np.int32) if 'rep_coord' in global_attr_dict: global_attr_dict['rep_coord'] = np.array(global_attr_dict['rep_coord'], dtype=np.int32) if 'size' in global_attr_dict: global_attr_dict['size'] = np.array(global_attr_dict['size'], dtype=np.int64) if 'mesh_area' in global_attr_dict: global_attr_dict['mesh_area'] = np.array(global_attr_dict['mesh_area'], dtype=np.float32) return global_attr_dict def _cache_storage_paths(args): """ Caches the storage paths for a given set of object IDs. This function is used to organize the storage of segmentation objects in a hierarchical folder structure for efficient access and retrieval. The function supports both old and new subfolder structures as defined in the global parameters. Args: args (tuple): A tuple containing the following: - target_p (str): The target path where the storage paths will be cached. - all_ids (list): A list of all object IDs for which the storage paths are to be cached. - n_folders_fs (int): The number of folders per filesystem. This parameter is used to determine the depth of the folder hierarchy. Note: The dtype of the object IDs is currently hardcoded as np.uint64 and may need to be made configurable in future versions. """ target_p, all_ids, n_folders_fs = args # outputs target folder hierarchy for object storage if global_params.config.use_new_subfold: target_dir_func = rep_helper.subfold_from_ix_new else: target_dir_func = rep_helper.subfold_from_ix_OLD dest_dc_tmp = defaultdict(list) for obj_id in all_ids: dest_dc_tmp[target_dir_func( obj_id, n_folders_fs)].append(obj_id) del all_ids cd = CompressedStorage(target_p, disable_locking=True) for k, v in dest_dc_tmp.items(): cd[k] = np.array(v, dtype=np.uint64) # TODO: dtype needs to be configurable cd.push()
[docs]def map_subcell_extract_props(kd_seg_path: str, kd_organelle_paths: dict, n_folders_fs: int = 1000, n_folders_fs_sc: int = 1000, n_chunk_jobs: Optional[int] = None, n_cores: int = 1, cube_of_interest_bb: Optional[tuple] = None, chunk_size: Optional[Union[tuple, np.ndarray]] = None, log: Logger = None, overwrite=False): """ Extracts and caches segmentation properties for each Supervoxel (SV) in cell and subcellular segmentation. This function requires KNOSSOSDATASETs (KDs) at `kd_seg_path` and `kd_organelle_paths`. The process involves two main steps: 1. Extract properties (representative coordinate, bounding box, size) and overlap voxels locally (organelles <-> cell segmentation). 2. Write out combined results for each SV object. Args: kd_seg_path (str): Path to the KNOSSOSDATASET (KD) of the cell segmentation. kd_organelle_paths (dict): Dictionary of paths to the KDs of the subcellular segmentations. n_folders_fs (int, optional): Number of folders in the file system for the cell segmentation. Defaults to 1000. n_folders_fs_sc (int, optional): Number of folders in the file system for the subcellular segmentation. Defaults to 1000. n_chunk_jobs (Optional[int], optional): Number of chunks to be processed in parallel. Defaults to None. n_cores (int, optional): Number of cores to be used for processing. Defaults to 1. cube_of_interest_bb (Optional[tuple], optional): Bounding box of the cube of interest. Defaults to None. chunk_size (Optional[Union[tuple, np.ndarray]], optional): Size of the chunks to be processed. Defaults to None. log (Logger, optional): Logger for logging the process. Defaults to None. overwrite (bool, optional): If True, overwrite existing results. Defaults to False. """ kd = basics.kd_factory(kd_seg_path) assert sys.version_info >= (3, 6) # below, dictionaries are unordered! for k, v in kd_organelle_paths.items(): if not np.array_equal(basics.kd_factory(v).boundary, kd.boundary): msg = "Data shape of subcellular structures '{}' differs from " \ "cell segmentation data. {} vs. {}".format( k, basics.kd_factory(v).boundary, kd.boundary) log_proc.error(msg) raise ValueError(msg) # outputs target folder hierarchy for object storages if global_params.config.use_new_subfold: target_dir_func = rep_helper.subfold_from_ix_new else: target_dir_func = rep_helper.subfold_from_ix_OLD # get chunks if log is None: log = log_proc if n_chunk_jobs is None: n_chunk_jobs = global_params.config.ncore_total * 2 if chunk_size is None: chunk_size = [512, 512, 512] if cube_of_interest_bb is None: cube_of_interest_bb = [np.zeros(3, dtype=np.int32), kd.boundary] size = cube_of_interest_bb[1] - cube_of_interest_bb[0] offset = cube_of_interest_bb[0] cd_dir = "{}/chunkdatasets/tmp/".format(global_params.config.temp_path) cd = chunky.ChunkDataset() cd.initialize(kd, size, chunk_size, cd_dir, box_coords=offset, fit_box_size=True) sv_sd = segmentation.SegmentationDataset( working_dir=global_params.config.working_dir, obj_type="sv", version=0, n_folders_fs=n_folders_fs) dir_props = f"{global_params.config.temp_path}/tmp_props/" dir_meshes = f"{global_params.config.temp_path}/tmp_meshes/" # remove previous temporary results. if os.path.isdir(dir_props): if not overwrite: msg = f'Could not start extraction of supervoxel objects ' \ f'because temporary files already exist at "{dir_props}" ' \ f'and overwrite was set to False.' log_proc.error(msg) raise FileExistsError(msg) log.debug(f'Found existing cache folder at {dir_props}. Removing it now.') shutil.rmtree(dir_props) if os.path.isdir(dir_meshes): shutil.rmtree(dir_meshes) os.makedirs(dir_props, exist_ok=True) os.makedirs(dir_meshes, exist_ok=True) all_times = [] step_names = [] dict_paths_tmp = [] # extract mapping start = time.time() # create chunk list represented by offset and unique ID ch_list = [(off, ch_id) for ch_id, off in enumerate(list(cd.coord_dict.keys()))] # create chunked chunk list ch_list = [ch for ch in basics.chunkify_successive(ch_list, max(1, len(cd.coord_dict) // n_chunk_jobs))] multi_params = [(ch, chunk_size, kd_seg_path, kd_organelle_paths, worker_nr, global_params.config.allow_mesh_gen_cells) for worker_nr, ch in enumerate(ch_list)] # results contain meshing information cell_mesh_workers = dict() cell_prop_worker = dict() subcell_mesh_workers = [dict() for _ in range(len(kd_organelle_paths))] subcell_prop_workers = [dict() for _ in range(len(kd_organelle_paths))] # needed for caching target storage folder for all objects all_ids = {k: [] for k in list(kd_organelle_paths.keys()) + ['sv']} if qu.batchjob_enabled(): path_to_out = qu.batchjob_script( multi_params, "map_subcell_extract_props", n_cores=n_cores) out_files = glob.glob(path_to_out + "/*.pkl") for out_file in tqdm.tqdm(out_files, leave=False): with open(out_file, 'rb') as f: worker_nr, ref_mesh_dc = pkl.load(f) for chunk_id, cell_ids in ref_mesh_dc['sv'].items(): cell_mesh_workers[chunk_id] = (worker_nr, cell_ids) # memory consumption of list is about 0.25 cell_prop_worker[worker_nr] = list(set().union(*ref_mesh_dc['sv'].values())) all_ids['sv'].extend(cell_prop_worker[worker_nr]) c_mesh_worker_dc = "{}/c_mesh_worker_dict.pkl".format(global_params.config.temp_path) with open(c_mesh_worker_dc, 'wb') as f: pkl.dump(cell_mesh_workers, f, protocol=4) del cell_mesh_workers c_prop_worker_dc = "{}/c_prop_worker_dict.pkl".format(global_params.config.temp_path) with open(c_prop_worker_dc, 'wb') as f: pkl.dump(cell_prop_worker, f, protocol=4) all_ids['sv'] = np.unique(all_ids['sv']) del cell_prop_worker # Collect organelle worker info # memory consumption of list is about 0.25 for out_file in tqdm.tqdm(out_files, leave=False): with open(out_file, 'rb') as f: worker_nr, ref_mesh_dc = pkl.load(f) # iterate over each subcellular structure for ii, organelle in enumerate(kd_organelle_paths): organelle = global_params.config['process_cell_organelles'][ii] for chunk_id, subcell_ids in ref_mesh_dc[organelle].items(): subcell_mesh_workers[ii][chunk_id] = (worker_nr, subcell_ids) subcell_prop_workers[ii][worker_nr] = list(set().union(*ref_mesh_dc[ organelle].values())) all_ids[organelle].extend(subcell_prop_workers[ii][worker_nr]) for ii, organelle in enumerate(kd_organelle_paths): all_ids[organelle] = np.unique(all_ids[organelle]) sc_mesh_worker_dc = "{}/sc_{}_mesh_worker_dict.pkl".format( global_params.config.temp_path, organelle) with open(sc_mesh_worker_dc, 'wb') as f: pkl.dump(subcell_mesh_workers[ii], f, protocol=4) dict_paths_tmp += [sc_mesh_worker_dc] sc_prop_worker_dc = "{}/sc_{}_prop_worker_dict.pkl".format( global_params.config.temp_path, organelle) with open(sc_prop_worker_dc, 'wb') as f: pkl.dump(subcell_prop_workers[ii], f, protocol=4) dict_paths_tmp += [sc_prop_worker_dc] else: results = sm.start_multiprocess_imap( _map_subcell_extract_props_thread, multi_params, verbose=False, debug=False) for worker_nr, ref_mesh_dc in tqdm.tqdm(results, leave=False): for chunk_id, cell_ids in ref_mesh_dc['sv'].items(): cell_mesh_workers[chunk_id] = (worker_nr, cell_ids) # memory consumption of list is about 0.25 cell_prop_worker[worker_nr] = list(set().union(*ref_mesh_dc['sv'].values())) all_ids['sv'].extend(cell_prop_worker[worker_nr]) # iterate over each subcellular structure for ii, organelle in enumerate(kd_organelle_paths): for chunk_id, subcell_ids in ref_mesh_dc[organelle].items(): subcell_mesh_workers[ii][chunk_id] = (worker_nr, subcell_ids) subcell_prop_workers[ii][worker_nr] = list(set().union(*ref_mesh_dc[ organelle].values())) all_ids[organelle].extend(subcell_prop_workers[ii][worker_nr]) del results c_mesh_worker_dc = "{}/c_mesh_worker_dict.pkl".format(global_params.config.temp_path) with open(c_mesh_worker_dc, 'wb') as f: pkl.dump(cell_mesh_workers, f, protocol=4) del cell_mesh_workers c_prop_worker_dc = "{}/c_prop_worker_dict.pkl".format(global_params.config.temp_path) with open(c_prop_worker_dc, 'wb') as f: pkl.dump(cell_prop_worker, f, protocol=4) del cell_prop_worker all_ids['sv'] = np.unique(all_ids['sv']) for ii, organelle in enumerate(kd_organelle_paths): all_ids[organelle] = np.unique(all_ids[organelle]) sc_mesh_worker_dc = "{}/sc_{}_mesh_worker_dict.pkl".format( global_params.config.temp_path, organelle) with open(sc_mesh_worker_dc, 'wb') as f: pkl.dump(subcell_mesh_workers[ii], f, protocol=4) dict_paths_tmp += [sc_mesh_worker_dc] sc_prop_worker_dc = "{}/sc_{}_prop_worker_dict.pkl".format( global_params.config.temp_path, organelle) with open(sc_prop_worker_dc, 'wb') as f: pkl.dump(subcell_prop_workers[ii], f, protocol=4) dict_paths_tmp += [sc_prop_worker_dc] del subcell_mesh_workers, subcell_prop_workers params_cache = [] for k, v in all_ids.items(): dest_p = f'{global_params.config.temp_path}/storage_targets_{k}.pkl' nf = n_folders_fs if k == 'sv' else n_folders_fs_sc params_cache.append((dest_p, v, nf)) dict_paths_tmp.append(dest_p) _ = sm.start_multiprocess_imap(_cache_storage_paths, params_cache, nb_cpus=global_params.config['ncores_per_node']) del all_ids, params_cache dict_paths_tmp += [c_mesh_worker_dc, c_prop_worker_dc] step_names.append("extract and map segmentation objects") all_times.append(time.time() - start) # reduce step start = time.time() # create folders for existing (sub-)cell supervoxels to prevent concurrent makedirs ids = rep_helper.get_unique_subfold_ixs(n_folders_fs) for k in tqdm.tqdm(ids, leave=False): curr_dir = sv_sd.so_storage_path + target_dir_func(k, n_folders_fs) os.makedirs(curr_dir, exist_ok=True) for ii, organelle in enumerate(kd_organelle_paths): sc_sd = segmentation.SegmentationDataset( working_dir=global_params.config.working_dir, obj_type=organelle, version=0, n_folders_fs=n_folders_fs_sc) ids = rep_helper.get_unique_subfold_ixs(n_folders_fs_sc) for ix in tqdm.tqdm(ids, leave=False): curr_dir = sc_sd.so_storage_path + target_dir_func( ix, n_folders_fs_sc) os.makedirs(curr_dir, exist_ok=True) all_times.append(time.time() - start) step_names.append("conversion of results") # write to subcell. SV attribute dicts # must be executed before '_write_props_to_sv_thread' start = time.time() # create "dummy" IDs which represent each a unique storage path storage_location_ids = rep_helper.get_unique_subfold_ixs(n_folders_fs_sc) n_jobs = int(min(2 * global_params.config.ncore_total, len(storage_location_ids) / 2)) multi_params = [(sv_id_block, n_folders_fs_sc, kd_organelle_paths) for sv_id_block in basics.chunkify(storage_location_ids, n_jobs)] if not qu.batchjob_enabled(): sm.start_multiprocess_imap(_write_props_to_sc_thread, multi_params, debug=False) else: # hacky, but memory load gets high at that size, prevent oom events of slurm and other system relevant parts n_cores = 1 if np.prod(cube_of_interest_bb) < 2e12 else 2 qu.batchjob_script(multi_params, "write_props_to_sc", script_folder=None, remove_jobfolder=True, n_cores=n_cores) # perform dataset analysis to cache all attributes in numpy arrays procs = [] da_kwargs = dict(recompute=False, compute_meshprops=False) for k in kd_organelle_paths: sc_sd = segmentation.SegmentationDataset( working_dir=global_params.config.working_dir, obj_type=k, version=0, n_folders_fs=n_folders_fs_sc) p = Process(target=dataset_analysis, args=(sc_sd,), kwargs=da_kwargs) procs.append(p) for p in procs: p.start() for p in procs: p.join() if p.exitcode != 0: raise Exception(f'Worker {p.name} stopped unexpectedly with exit ' f'code {p.exitcode}.') p.close() all_times.append(time.time() - start) step_names.append("write subcellular SD") # writing cell SV properties to SD start = time.time() # create "dummy" IDs which represent each a unique storage path storage_location_ids = rep_helper.get_unique_subfold_ixs(n_folders_fs) n_jobs = int(max(2 * global_params.config.ncore_total, len(storage_location_ids) / 15)) multi_params = [(sv_id_block, n_folders_fs, global_params.config.allow_mesh_gen_cells, list(kd_organelle_paths.keys())) for sv_id_block in basics.chunkify(storage_location_ids, n_jobs)] if not qu.batchjob_enabled(): sm.start_multiprocess_imap(_write_props_to_sv_thread, multi_params, debug=False) else: # hacky, but memory load gets high at that size, prevent oom events of slurm and other system relevant parts n_cores = 1 if np.prod(size) < 2e12 else 2 qu.batchjob_script(multi_params, "write_props_to_sv", remove_jobfolder=True, n_cores=n_cores) dataset_analysis(sv_sd, recompute=False, compute_meshprops=False) all_times.append(time.time() - start) step_names.append("write cell SD") # clear temporary files if global_params.config.use_new_meshing: shutil.rmtree(dir_meshes) for p in dict_paths_tmp: os.remove(p) shutil.rmtree(cd_dir, ignore_errors=True) shutil.rmtree(dir_props, ignore_errors=True) if qu.batchjob_enabled(): # remove job directory of `map_subcell_extract_props` shutil.rmtree(os.path.abspath(path_to_out + "/../"), ignore_errors=True) log.debug("Time overview [map_subcell_extract_props]:") for ii in range(len(all_times)): log.debug("%s: %.3fs" % (step_names[ii], all_times[ii])) log.debug("--------------------------") log.debug("Total Time: %.1f min" % (np.sum(all_times) / 60)) log.debug("--------------------------")
def _map_subcell_extract_props_thread(args): """ This function is a worker thread for the map_subcell_extract_props function. It extracts properties and mapping information for each chunk of the dataset and stores them in property dictionaries for cellular and subcellular structures. It also generates meshes for each chunk if specified. Args: args (list): A list containing the following elements: - chunks (list): List of tuples, where each tuple contains the offset and chunk id. - chunk_size (tuple): Size of the chunk. - kd_cell_p (str): Path to the KnossosDataset for the cell. - kd_subcell_ps (dict): Dictionary with organelle names as keys and paths to their respective KnossosDatasets as values. - worker_nr (int): Worker number. - generate_sv_mesh (bool): Flag to indicate whether to generate meshes for the chunk. Returns: tuple: A tuple containing the worker number and a dictionary with references to partial results of each object. """ chunks = args[0] chunk_size = args[1] kd_cell_p = args[2] kd_subcell_ps = args[3] # Dict worker_nr = args[4] generate_sv_mesh = args[5] worker_dir_meshes = f"{global_params.config.temp_path}/tmp_meshes/meshes_{worker_nr}/" os.makedirs(worker_dir_meshes, exist_ok=True) worker_dir_props = f"{global_params.config.temp_path}/tmp_props/props_{worker_nr}/" os.makedirs(worker_dir_props, exist_ok=True) kd_cell = basics.kd_factory(kd_cell_p) kd_subcells = {k: basics.kd_factory(kd_subcell_p) for k, kd_subcell_p in kd_subcell_ps.items()} n_subcell = len(kd_subcells) min_obj_vx = global_params.config['cell_objects']['min_obj_vx'] downsampling_dc = global_params.config['meshes']['downsampling'] # cell property dicts cpd_lst = [{}, defaultdict(list), {}] # subcell. property dicts scpd_lst = [[{}, defaultdict(list), {}] for _ in range(n_subcell)] # subcell. mapping dicts scmd_lst = [{} for _ in range(n_subcell)] # existing_oragnelles has the same ordering as kd_subcells.keys() and kd_subcell_p existing_oragnelles = kd_subcells.keys() # objects that are not purely inside this chunk ref_mesh_dict = dict() ref_mesh_dict['sv'] = dict() for organelle in existing_oragnelles: ref_mesh_dict[organelle] = dict() dt_times_dc = {'find_mesh': 0, 'mesh_io': 0, 'data_io': 0, 'overall': 0, 'prop_dicts_extract': 0} # iterate over chunks and store information in property dicts for # subcellular and cellular structures start_all = time.time() for offset, ch_id in chunks: # get all segmentation arrays concatenates as 4D array: [C, X, Y, Z] subcell_d = [] obj_ids_bdry = dict() small_obj_ids_inside = defaultdict(list) for organelle in existing_oragnelles: obj_ids_bdry[organelle] = [] for organelle in kd_subcell_ps: start = time.time() kd_sc = kd_subcells[organelle] subc_d = kd_sc.load_seg(size=chunk_size, offset=offset, mag=1).swapaxes(0, 2) # get objects that are not purely inside this chunk obj_bdry = np.concatenate( [subc_d[0].flat, subc_d[:, 0].flat, subc_d[:, :, 0].flat, subc_d[-1].flat, subc_d[:, -1].flat, subc_d[:, :, -1].flat]) obj_bdry = np.unique(obj_bdry) obj_ids_bdry[organelle] = obj_bdry dt_times_dc['data_io'] += time.time() - start # add auxiliary axis subcell_d.append(subc_d[None,]) subcell_d = np.concatenate(subcell_d) start = time.time() cell_d = kd_cell.load_seg(size=chunk_size, offset=offset, mag=1).swapaxes(0, 2) dt_times_dc['data_io'] += time.time() - start start = time.time() # extract properties and mapping information cell_prop_dicts, subcell_prop_dicts, subcell_mapping_dicts = \ map_subcell_extract_props_func(cell_d, subcell_d) dt_times_dc['prop_dicts_extract'] += time.time() - start # remove objects that are purely inside this chunk and smaller than the size threshold if min_obj_vx['sv'] > 1: obj_bdry = np.concatenate( [cell_d[0].flat, cell_d[:, 0].flat, cell_d[:, :, 0].flat, cell_d[-1].flat, cell_d[:, -1].flat, cell_d[:, :, -1].flat]) obj_bdry = set(np.unique(obj_bdry)) obj_inside = set(list(cell_prop_dicts[0].keys())).difference(obj_bdry) # cell_prop_dicts: [rc, bb, size] for ix in obj_inside: if cell_prop_dicts[2][ix] < min_obj_vx['sv']: small_obj_ids_inside['sv'].append(ix) del cell_prop_dicts[0][ix], cell_prop_dicts[1][ix], cell_prop_dicts[2][ix] # merge cell properties: list list of dicts merge_prop_dicts([cpd_lst, cell_prop_dicts], offset) del cell_prop_dicts # reorder to match [[rc, bb, size], [rc, bb, size]] for e.g. [mi, vc] subcell_prop_dicts = [[subcell_prop_dicts[0][ii], subcell_prop_dicts[1][ii], subcell_prop_dicts[2][ii]] for ii in range(n_subcell)] # remove objects that are purely inside this chunk and smaller than the size threshold for ii, organelle in enumerate(existing_oragnelles): if min_obj_vx[organelle] > 1: # subcell_prop_dicts: [[rc, bb, size], [rc, bb, size], ..] obj_bdry = obj_ids_bdry[organelle] obj_inside = set(list(subcell_prop_dicts[ii][0].keys())).difference(obj_bdry) for ix in obj_inside: if subcell_prop_dicts[ii][2][ix] < min_obj_vx[organelle]: small_obj_ids_inside[organelle].append(ix) del subcell_prop_dicts[ii][0][ix], subcell_prop_dicts[ii][1][ix] del subcell_prop_dicts[ii][2][ix] if ix in subcell_mapping_dicts[ii]: # could not be mapped to cell sv del subcell_mapping_dicts[ii][ix] merge_map_dicts([scmd_lst[ii], subcell_mapping_dicts[ii]]) merge_prop_dicts([scpd_lst[ii], subcell_prop_dicts[ii]], offset) del subcell_mapping_dicts, subcell_prop_dicts if global_params.config.use_new_meshing: for ii, organelle in enumerate(kd_subcell_ps): ch_cache_exists = False # do not redo done work in case this worker is restarted due to memory issues. p = f"{worker_dir_meshes}/{organelle}_{worker_nr}_ch{ch_id}.pkl" if os.path.isfile(p): try: start = time.time() tmp_subcell_meshes = basics.load_pkl2obj(p) dt_times_dc['mesh_io'] += time.time() - start except Exception as e: log_proc.error(f'Exception raised when loading ' f'mesh cache {p}:\n{e}') else: if min_obj_vx[organelle] > 1: for ix in small_obj_ids_inside[organelle]: # the cache was pruned in an early version # of the code before it got dumped -> check if it exists if ix in tmp_subcell_meshes: del tmp_subcell_meshes[ix] ref_mesh_dict[organelle][ch_id] = list(tmp_subcell_meshes.keys()) del tmp_subcell_meshes ch_cache_exists = True if not ch_cache_exists: start = time.time() tmp_subcell_meshes = find_meshes(subcell_d[ii], offset, pad=1, ds=downsampling_dc[organelle]) dt_times_dc['find_mesh'] += time.time() - start start = time.time() output_worker = open(p, 'wb') pkl.dump(tmp_subcell_meshes, output_worker, protocol=4) output_worker.close() dt_times_dc['mesh_io'] += time.time() - start if min_obj_vx[organelle] > 1: for ix in small_obj_ids_inside[organelle]: if ix in tmp_subcell_meshes: del tmp_subcell_meshes[ix] # store reference to partial results of each object ref_mesh_dict[organelle][ch_id] = list(tmp_subcell_meshes.keys()) del tmp_subcell_meshes del subcell_d # collect subcell properties: list of list of dicts # collect subcell mappings to cell SVs: list of list of # dicts and list of dict of dict of int if generate_sv_mesh and global_params.config.use_new_meshing: # do not redo done work in case this worker is restarted due to memory issues. ch_cache_exists = False p = f"{worker_dir_meshes}/sv_{worker_nr}_ch{ch_id}.pkl" if os.path.isfile(p): try: start = time.time() tmp_cell_mesh = basics.load_pkl2obj(p) dt_times_dc['mesh_io'] += time.time() - start except Exception as e: log_proc.error(f'Exception raised when loading mesh cache {p}:' f'\n{e}') else: if min_obj_vx['sv'] > 1: for ix in small_obj_ids_inside['sv']: # the cache was pruned in an early version # of the code before it got dumped -> check if ID exists if ix in tmp_cell_mesh: del tmp_cell_mesh[ix] ref_mesh_dict['sv'][ch_id] = list(tmp_cell_mesh.keys()) del tmp_cell_mesh ch_cache_exists = True if not ch_cache_exists: start = time.time() tmp_cell_mesh = find_meshes(cell_d, offset, pad=1, ds=downsampling_dc['sv']) dt_times_dc['find_mesh'] += time.time() - start start = time.time() output_worker = open(p, 'wb') pkl.dump(tmp_cell_mesh, output_worker, protocol=4) output_worker.close() dt_times_dc['mesh_io'] += time.time() - start if min_obj_vx['sv'] > 1: for ix in small_obj_ids_inside['sv']: if ix in tmp_cell_mesh: del tmp_cell_mesh[ix] # store reference to partial results of each object ref_mesh_dict['sv'][ch_id] = list(tmp_cell_mesh.keys()) del tmp_cell_mesh del cell_d gc.collect() # write worker results basics.write_obj2pkl(f'{worker_dir_props}/cp_{worker_nr}.pkl', cpd_lst) del cpd_lst for ii, organelle in enumerate(existing_oragnelles): basics.write_obj2pkl(f'{worker_dir_props}/scp_{organelle}_{worker_nr}.pkl', scpd_lst[ii]) basics.write_obj2pkl(f'{worker_dir_props}/scm_{organelle}_{worker_nr}.pkl', scmd_lst[ii]) del scmd_lst del scpd_lst if global_params.config.use_new_meshing: dt_times_dc['overall'] = time.time() - start_all dt_str = ["{:<20}".format(f"{k}: {v:.2f}s") for k, v in dt_times_dc.items()] # log_proc.debug('{}'.format("".join(dt_str))) return worker_nr, ref_mesh_dict def _write_props_to_sc_thread(args): """ This function writes the properties of subcellular structures to the corresponding thread. It iterates over the subcellular structures, gets cached mapping and property dictionaries of the current subcellular structure, loads target storage folders for all objects in this chunk, and loads properties and mapping info. It also trims mesh info to objects of interest and fetches all required mesh data. Finally, it writes the properties to the attribute dictionary for this batch of object IDs. Args: args (list): A list containing the object ID chunks, the number of folders in the file system, and a dictionary of kd paths for the subcellular structures. Returns: None """ obj_id_chs = args[0] n_folders_fs = args[1] kd_subcell_ps = args[2] # Dict of kd paths if global_params.config.use_new_subfold: target_dir_func = rep_helper.subfold_from_ix_new else: target_dir_func = rep_helper.subfold_from_ix_OLD mesh_min_obj_vx = global_params.config['meshes']['mesh_min_obj_vx'] global_tmp_path = global_params.config.temp_path # iterate over the subcell structures for organelle in kd_subcell_ps: min_obj_vx = global_params.config['cell_objects']['min_obj_vx'][organelle] # get cached mapping and property dicts of current subcellular structure sc_prop_worker_dc = f"{global_tmp_path}/sc_{organelle}_prop_worker_dict.pkl" with open(sc_prop_worker_dc, "rb") as f: subcell_prop_workers_tmp = pkl.load(f) # load target storage folders for all objects in this chunk dest_dc = dict() dest_dc_tmp = CompressedStorage(f'{global_tmp_path}/storage_targets_' f'{organelle}.pkl', disable_locking=True) all_obj_keys = set() for obj_id_mod in obj_id_chs: k = target_dir_func(obj_id_mod, n_folders_fs) if k not in dest_dc_tmp: value = np.array([], dtype=np.uint64) # TODO: dtype needs to be configurable else: value = dest_dc_tmp[k] all_obj_keys.update(set(value)) dest_dc[k] = value del dest_dc_tmp if len(all_obj_keys) == 0: continue # Now given to IDs of interest, load properties and mapping info prop_dict = [{}, defaultdict(list), {}] mapping_dict = dict() for worker_id, obj_ids in subcell_prop_workers_tmp.items(): intersec = set(obj_ids).intersection(all_obj_keys) if len(intersec) > 0: worker_dir_props = f"{global_tmp_path}/tmp_props/props_{worker_id}/" fname = f'{worker_dir_props}/scp_{organelle}_{worker_id}.pkl' dc = basics.load_pkl2obj(fname) tmp_dcs = [dict(), defaultdict(list), dict()] for k in intersec: tmp_dcs[0][k] = dc[0][k] tmp_dcs[1][k] = dc[1][k] tmp_dcs[2][k] = dc[2][k] del dc merge_prop_dicts([prop_dict, tmp_dcs]) del tmp_dcs # TODO: optimize as above - by creating a temporary dictionary with the intersecting IDs only fname = f'{worker_dir_props}/scm_{organelle}_{worker_id}.pkl' dc = basics.load_pkl2obj(fname) for k in list(dc.keys()): if k not in all_obj_keys: del dc[k] # store number of overlap voxels merge_map_dicts([mapping_dict, dc]) del dc del subcell_prop_workers_tmp # Trim mesh info to objects of interest # keys: chunk IDs, values: (worker_nr, object IDs) sc_mesh_worker_dc_p = f"{global_tmp_path}/sc_{organelle}_mesh_worker_dict.pkl" with open(sc_mesh_worker_dc_p, "rb") as f: subcell_mesh_workers_tmp = pkl.load(f) # convert into object ID -> worker_ids -> chunk_ids sc_mesh_worker_dc = {k: defaultdict(list) for k in all_obj_keys} for ch_id, (worker_id, obj_ids) in subcell_mesh_workers_tmp.items(): for k in set(obj_ids).intersection(all_obj_keys): sc_mesh_worker_dc[k][worker_id].append(ch_id) del subcell_mesh_workers_tmp # get SegmentationDataset of current subcell. sc_sd = segmentation.SegmentationDataset( n_folders_fs=n_folders_fs, obj_type=organelle, working_dir=global_params.config.working_dir, version=0) # iterate over the subcellular SV ID chunks for obj_id_mod in obj_id_chs: obj_keys = dest_dc[target_dir_func(obj_id_mod, n_folders_fs)] obj_keys = set(obj_keys) # fetch all required mesh data if global_params.config.use_new_meshing: # get cached mesh dicts for segmentation object 'organelle' cached_mesh_dc = defaultdict(list) worker_ids = defaultdict(set) for k in obj_keys: # prop_dict contains [rc, bb, size] of the objects s = prop_dict[2][k] if (s < mesh_min_obj_vx) or (s < min_obj_vx): # do not load mesh-cache of small objects continue # # Decided to not add the exclusion of too big SJs # # leaving the code here in case this opinion changes. # bb = prop_dict[1][k] * scaling # check bounding box diagonal # bbd = np.linalg.norm(bb[1] - bb[0], ord=2) # if organelle == 'sj' and bbd > max_bb_sj: # continue for worker_id, ch_ids in sc_mesh_worker_dc[k].items(): worker_ids[worker_id].update(ch_ids) for worker_nr, chunk_ids in worker_ids.items(): for ch_id in chunk_ids: p = f"{global_tmp_path}/tmp_meshes/meshes_{worker_nr}/" \ f"{organelle}_{worker_nr}_ch{ch_id}.pkl" with open(p, "rb") as pkl_file: partial_mesh_dc = pkl.load(pkl_file) # only load keys which are part of the worker's chunk for el in obj_keys.intersection(set(list(partial_mesh_dc.keys()))): cached_mesh_dc[el].append(partial_mesh_dc[el]) # get dummy segmentation object to fetch attribute # dictionary for this batch of object IDs dummy_so = sc_sd.get_segmentation_object(obj_id_mod) attr_p = dummy_so.attr_dict_path vx_p = dummy_so.voxel_path this_attr_dc = AttributeDict(attr_p, read_only=False, disable_locking=True) voxel_dc = VoxelStorageDyn( vx_p, voxel_mode=False, read_only=False, disable_locking=True, voxeldata_path=global_params.config.kd_organelle_seg_paths[organelle]) if global_params.config.use_new_meshing: obj_mesh_dc = MeshStorage(dummy_so.mesh_path, disable_locking=True, read_only=False) for sc_id in obj_keys: size = prop_dict[2][sc_id] if size < min_obj_vx: continue if sc_id in mapping_dict: # TODO: remove the properties mapping_ratios and mapping_ids as # they not need to be stored with the sub-cellular objects anymore (make sure to delete # `correct_for_background` in _apply_mapping_decisions_thread this_attr_dc[sc_id]["mapping_ids"] = \ list(mapping_dict[sc_id].keys()) # normalize to the objects total number of voxels this_attr_dc[sc_id]["mapping_ratios"] = \ [v / size for v in mapping_dict[sc_id].values()] else: this_attr_dc[sc_id]["mapping_ids"] = [] this_attr_dc[sc_id]["mapping_ratios"] = [] rp = np.array(prop_dict[0][sc_id], dtype=np.int32) bbs = np.concatenate(prop_dict[1][sc_id]) size = prop_dict[2][sc_id] this_attr_dc[sc_id]["rep_coord"] = rp this_attr_dc[sc_id]["bounding_box"] = np.array( [bbs[:, 0].min(axis=0), bbs[:, 1].max(axis=0)]) this_attr_dc[sc_id]["size"] = size voxel_dc[sc_id] = bbs # TODO: make use of these stored properties # downstream during the reduce step (requires clarification) voxel_dc.increase_object_size(sc_id, size) voxel_dc.set_object_repcoord(sc_id, rp) if global_params.config.use_new_meshing: try: partial_meshes = cached_mesh_dc[sc_id] except KeyError: # object has size < 10 partial_meshes = [] del cached_mesh_dc[sc_id] list_of_ind = [] list_of_ver = [] list_of_norm = [] for single_mesh in partial_meshes: list_of_ind.append(single_mesh[0]) list_of_ver.append(single_mesh[1]) list_of_norm.append(single_mesh[2]) mesh = merge_meshes_incl_norm(list_of_ind, list_of_ver, list_of_norm) obj_mesh_dc[sc_id] = mesh verts = mesh[1].reshape(-1, 3) if len(verts) > 0: mesh_bb = [np.min(verts, axis=0), np.max(verts, axis=0)] del verts this_attr_dc[sc_id]["mesh_bb"] = mesh_bb this_attr_dc[sc_id]["mesh_area"] = mesh_area_calc(mesh) else: this_attr_dc[sc_id]["mesh_bb"] = this_attr_dc[sc_id]["bounding_box"] * \ dummy_so.scaling this_attr_dc[sc_id]["mesh_area"] = 0 voxel_dc.push() this_attr_dc.push() if global_params.config.use_new_meshing: obj_mesh_dc.push() del obj_mesh_dc, this_attr_dc, voxel_dc gc.collect() def _write_props_to_sv_thread(args): """ This function is a worker thread that writes properties to the supervoxel (SV) thread. It loads the cached mapping and property dictionaries of the current subcellular structure, loads the target storage folders for all objects in the chunk, and loads properties and mapping info. It also fetches all required mesh data and writes the properties and mapping info to the attribute dictionary for the batch of object IDs. Args: args (list): A list containing the object ID chunks, the number of folders in the file system, a boolean indicating whether to generate SV mesh, and the processed organelles. Returns: None """ obj_id_chs = args[0] n_folders_fs = args[1] generate_sv_mesh = args[2] processsed_organelles = args[3] dt_loading_cache = time.time() if global_params.config.use_new_subfold: target_dir_func = rep_helper.subfold_from_ix_new else: target_dir_func = rep_helper.subfold_from_ix_OLD mesh_min_obj_vx = global_params.config['meshes']['mesh_min_obj_vx'] min_obj_vx = global_params.config['cell_objects']['min_obj_vx']['sv'] global_tmp_path = global_params.config.temp_path wd = global_params.config.working_dir # get cached mapping and property dicts of current subcellular structure c_prop_worker_dc = f"{global_tmp_path}/c_prop_worker_dict.pkl" with open(c_prop_worker_dc, "rb") as f: cell_prop_workers_tmp = pkl.load(f) # load target storage folders for all objects in this chunk dest_dc = dict() dest_dc_tmp = CompressedStorage(f'{global_tmp_path}/storage_targets_sv.pkl', disable_locking=True) all_obj_keys = set() for obj_id_mod in obj_id_chs: k = target_dir_func(obj_id_mod, n_folders_fs) if k not in dest_dc_tmp: value = np.array([], dtype=np.uint64) # TODO: dtype needs to be configurable else: value = dest_dc_tmp[k] all_obj_keys.update(set(value)) dest_dc[k] = value if len(all_obj_keys) == 0: return del dest_dc_tmp # Now given to IDs of interest, load properties and mapping info prop_dict = [{}, defaultdict(list), {}] mapping_dicts = {k: {} for k in processsed_organelles} # No size threshold applied in mapping dict as it would require loading the property # dictionaries -> when mapping decision is made on cell level non-existing organelles are # assumed to be below the size threshold. for worker_id, obj_ids in cell_prop_workers_tmp.items(): intersec = set(obj_ids).intersection(all_obj_keys) if len(intersec) > 0: worker_dir_props = f"{global_tmp_path}/tmp_props/props_{worker_id}/" fname = f'{worker_dir_props}/cp_{worker_id}.pkl' dc = basics.load_pkl2obj(fname) tmp_dcs = [dict(), defaultdict(list), dict()] for k in intersec: tmp_dcs[0][k] = dc[0][k] tmp_dcs[1][k] = dc[1][k] tmp_dcs[2][k] = dc[2][k] del dc merge_prop_dicts([prop_dict, tmp_dcs]) del tmp_dcs # TODO: optimize as above - by creating a temporary dictionary with the intersecting IDs only for organelle in processsed_organelles: fname = f'{worker_dir_props}/scm_{organelle}_{worker_id}.pkl' dc = basics.load_pkl2obj(fname) dc = invert_mdc(dc) # invert to have cell IDs in top layer for k in list(dc.keys()): if k not in all_obj_keys: del dc[k] # should behave also for inverted dicts merge_map_dicts([mapping_dicts[organelle], dc]) del dc del cell_prop_workers_tmp for md_k in mapping_dicts.keys(): md = mapping_dicts[md_k] md = invert_mdc(md) # invert to have organelle IDs in highest layer sd_sc = segmentation.SegmentationDataset( obj_type=md_k, working_dir=wd, version=0) size_dc = {k: v for k, v in zip(sd_sc.ids, sd_sc.sizes) if k in md} del sd_sc # normalize overlap with respect to the objects total size for subcell_id in list(md.keys()): # size threshold for objects at the chunk boundary is not applied # when mapping dictionaries are written, therefore objects that # are not part of the SD have been removed. if subcell_id not in size_dc: del md[subcell_id] continue subcell_dc = md[subcell_id] for k, v in subcell_dc.items(): # normalize with respect to the number of voxels of the object subcell_dc[k] = v / size_dc[subcell_id] del size_dc md = invert_mdc(md) # invert to have cell SV IDs in highest layer mapping_dicts[md_k] = md if global_params.config.use_new_meshing and generate_sv_mesh: c_mesh_worker_dc = f"{global_tmp_path}/c_mesh_worker_dict.pkl" with open(c_mesh_worker_dc, "rb") as f: cell_mesh_workers_tmp = pkl.load(f) # convert into object ID -> worker_ids -> chunk_ids mesh_worker_dc = {k: defaultdict(list) for k in all_obj_keys} for ch_id, (worker_id, obj_ids) in cell_mesh_workers_tmp.items(): for k in set(obj_ids).intersection(all_obj_keys): mesh_worker_dc[k][worker_id].append(ch_id) del cell_mesh_workers_tmp dt_loading_cache = time.time() - dt_loading_cache # log_proc.debug('[SV] loaded cache dicts after {:.2f} min.'.format( # dt_loading_cache / 60)) # fetch all required mesh data dt_mesh_merge_io = 0 # get SegmentationDataset of cell SV sv_sd = segmentation.SegmentationDataset( n_folders_fs=n_folders_fs, obj_type="sv", working_dir=wd, version=0) # iterate over the subcellular SV ID chunks dt_mesh_area = 0 dt_mesh_merge = 0 # without io for obj_id_mod in obj_id_chs: obj_keys = dest_dc[target_dir_func(obj_id_mod, n_folders_fs)] obj_keys = set(obj_keys) # load meshes of current batch if global_params.config.use_new_meshing and generate_sv_mesh: # get cached mesh dicts for segmentation object k cached_mesh_dc = defaultdict(list) start = time.time() worker_ids = defaultdict(set) for k in obj_keys: # ignore mesh of small objects s = prop_dict[2][k] if (s < mesh_min_obj_vx) or (s < min_obj_vx): continue # mesh_worker_dc contains a dict with keys: worker_nr (int) and chunk_ids (set) for worker_id, ch_ids in mesh_worker_dc[k].items(): worker_ids[worker_id].update(ch_ids) # log_proc.debug('Loading meshes of {} SVs from {} worker ' # 'caches.'.format(len(obj_keys), len(worker_ids))) for worker_nr, chunk_ids in worker_ids.items(): for ch_id in chunk_ids: p = f"{global_tmp_path}/tmp_meshes/meshes_{worker_nr}/" \ f"sv_{worker_nr}_ch{ch_id}.pkl" pkl_file = open(p, "rb") partial_mesh_dc = pkl.load(pkl_file) pkl_file.close() # only loaded keys which are part of the worker's chunk for el in obj_keys.intersection(set(list(partial_mesh_dc.keys()))): cached_mesh_dc[el].append(partial_mesh_dc[el]) dt_mesh_merge_io += time.time() - start # get dummy segmentation object to fetch attribute dictionary for this batch of object IDs dummy_so = sv_sd.get_segmentation_object(obj_id_mod) attr_p = dummy_so.attr_dict_path vx_p = dummy_so.voxel_path this_attr_dc = AttributeDict(attr_p, read_only=False, disable_locking=True) voxel_dc = VoxelStorageDyn(vx_p, voxel_mode=False, voxeldata_path=global_params.config.kd_seg_path, read_only=False, disable_locking=True) obj_mesh_dc = MeshStorage(dummy_so.mesh_path, disable_locking=True, read_only=False) for sv_id in obj_keys: size = prop_dict[2][sv_id] if size < min_obj_vx: continue for k in processsed_organelles: if sv_id not in mapping_dicts[k]: # no object of this type mapped to current cell SV this_attr_dc[sv_id][f"mapping_{k}_ids"] = [] this_attr_dc[sv_id][f"mapping_{k}_ratios"] = [] continue this_attr_dc[sv_id][f"mapping_{k}_ids"] = \ list(mapping_dicts[k][sv_id].keys()) this_attr_dc[sv_id][f"mapping_{k}_ratios"] = \ list(mapping_dicts[k][sv_id].values()) rp = np.array(prop_dict[0][sv_id], dtype=np.int32) bbs = np.concatenate(prop_dict[1][sv_id]) this_attr_dc[sv_id]["rep_coord"] = rp this_attr_dc[sv_id]["bounding_box"] = np.array( [bbs[:, 0].min(axis=0), bbs[:, 1].max(axis=0)]) this_attr_dc[sv_id]["size"] = size voxel_dc[sv_id] = bbs # TODO: make use of these stored properties downstream during the reduce step voxel_dc.increase_object_size(sv_id, size) voxel_dc.set_object_repcoord(sv_id, rp) if generate_sv_mesh and global_params.config.use_new_meshing: try: partial_meshes = cached_mesh_dc[sv_id] except KeyError: # object has small number of voxels partial_meshes = [] del cached_mesh_dc[sv_id] list_of_ind = [] list_of_ver = [] list_of_norm = [] for single_mesh in partial_meshes: list_of_ind.append(single_mesh[0]) list_of_ver.append(single_mesh[1]) list_of_norm.append(single_mesh[2]) start2 = time.time() mesh = merge_meshes_incl_norm(list_of_ind, list_of_ver, list_of_norm) dt_mesh_merge += time.time() - start2 obj_mesh_dc[sv_id] = mesh start = time.time() verts = mesh[1].reshape(-1, 3) if len(verts) > 0: mesh_bb = np.array([np.min(verts, axis=0), np.max(verts, axis=0)], dtype=np.float32) del verts this_attr_dc[sv_id]["mesh_bb"] = mesh_bb this_attr_dc[sv_id]["mesh_area"] = mesh_area_calc(mesh) else: this_attr_dc[sv_id]["mesh_bb"] = this_attr_dc[sv_id]["bounding_box"] * \ dummy_so.scaling this_attr_dc[sv_id]["mesh_area"] = 0 dt_mesh_area += time.time() - start voxel_dc.push() this_attr_dc.push() if global_params.config.use_new_meshing: obj_mesh_dc.push() # if global_params.config.use_new_meshing: # log_proc.debug(f'[SV] dt mesh area {dt_mesh_area:.2f}s\tdt mesh merge ' # f'{dt_mesh_merge:.2f}s\tdt merge IO ' # f'{dt_mesh_merge_io:.2f}s')
[docs]def merge_meshes_dict(m_storage, tmp_dict): """ This function merges mesh dictionaries. It iterates over the object IDs in the temporary dictionary and merges the meshes for each object ID. Args: m_storage (dict): A dictionary where the key is the object ID and the value is a list of faces, vertices, and normals of the mesh. tmp_dict (dict): A temporary dictionary where the key is the object ID and the value is a list of faces, vertices, and normals of the mesh. Returns: None """ for obj_id in tmp_dict: merge_meshes_single(m_storage, obj_id, tmp_dict[obj_id])
[docs]def merge_meshes_single(m_storage, obj_id, tmp_dict): """ Merges mesh dictionaries for a single object. This function takes in a MeshStorage object, an object id, and a temporary dictionary, and merges the temporary dictionary into the MeshStorage object. Args: m_storage (MeshStorage): A MeshStorage object where the merged data will be stored. obj_id (int): The id of the object whose mesh data is being merged. tmp_dict (dict): A temporary dictionary containing mesh data to be merged into the MeshStorage object. """ if obj_id not in m_storage: m_storage[obj_id] = [tmp_dict[0], tmp_dict[1], tmp_dict[2]] else: # TODO: this needs to be a parameter -> add global parameter for face type n_el = int((len(m_storage[obj_id][1])) / 3) m_storage[obj_id][0] = np.concatenate((m_storage[obj_id][0], tmp_dict[0] + n_el)) m_storage[obj_id][1] = np.concatenate((m_storage[obj_id][1], tmp_dict[1])) m_storage[obj_id][2] = np.concatenate((m_storage[obj_id][2], tmp_dict[2]))
[docs]def merge_prop_dicts(prop_dicts: List[List[dict]], offset: Optional[np.ndarray] = None): """ Merges property dictionaries in-place. All values will be stored in the first dictionary. If an offset is provided, it is added to the representative coordinates and bounding boxes of the objects. Args: prop_dicts (List[List[dict]]): A list of property dictionaries to be merged. offset (Optional[np.ndarray]): An optional offset to be added to the representative coordinates and bounding boxes of the objects. """ tot_rc = prop_dicts[0][0] tot_bb = prop_dicts[0][1] tot_size = prop_dicts[0][2] for el in prop_dicts[1:]: if len(el[0]) == 0: continue if offset is not None: # update chunk offset # TODO: could be done at the end of the map_extract cython code for k in el[0]: el[0][k] = [el[0][k][ii] + offset[ii] for ii in range(3)] tot_rc.update(el[0]) # just overwrite existing elements for k, v in el[1].items(): if offset is None: bb = v else: bb = [[v[0][ii] + offset[ii] for ii in range(3)], [v[1][ii] + offset[ii] for ii in range(3)]] # collect all bounding boxes to enable efficient data loading tot_bb[k].append(bb) for k, v in el[2].items(): if k in tot_size: tot_size[k] += v else: tot_size[k] = v
[docs]def convert_nvox2ratio_mapdict(map_dc): """ Converts the number of overlapping voxels of each subcellular structure object inside the mapping dictionaries to each cell SV (subcell ID -> cell ID -> number overlap vxs) to fraction. Args: map_dc (dict): A dictionary mapping subcellular structure objects to cell SVs and the number of overlapping voxels. """ # TODO consider to implement in cython for subcell_id, subcell_dc in map_dc.items(): s = np.sum(list(subcell_dc.values())) # total number of overlap voxels for k, v in subcell_dc.items(): map_dc[subcell_id][k] = subcell_dc[k] / s
[docs]def invert_mdc(mapping_dict): """ Inverts a mapping dictionary to: cell ID -> subcell ID -> value (ratio or voxel count). Args: mapping_dict (dict): A dictionary mapping subcellular structure objects to cell SVs and the number of overlapping voxels or the ratio of overlapping voxels. Returns: dict: The inverted mapping dictionary. """ mdc_inv = {} for subcell_id, subcell_dc in mapping_dict.items(): for cell_id, v in subcell_dc.items(): if cell_id not in mdc_inv: mdc_inv[cell_id] = {subcell_id: v} else: mdc_inv[cell_id][subcell_id] = v return mdc_inv
[docs]def merge_map_dicts(map_dicts): """ Merges map dictionaries in-place. Values will be stored in the first dictionary. Args: map_dicts (List[dict]): A list of map dictionaries to be merged. """ tot_map = map_dicts[0] for el in map_dicts[1:]: # iterate over subcell. ids with dictionaries as values which store # the number of overlap voxels to cell SVs for sc_id, sc_dc in el.items(): if sc_id in tot_map: for cellsv_id, ol_vx_cnt in sc_dc.items(): if cellsv_id in tot_map[sc_id]: tot_map[sc_id][cellsv_id] += ol_vx_cnt else: tot_map[sc_id][cellsv_id] = ol_vx_cnt else: tot_map[sc_id] = sc_dc
[docs]def init_sos(sos_dict): """ Initializes a list of SegmentationObjects from a dictionary. Args: sos_dict (dict): A dictionary containing parameters for initializing SegmentationObjects. Returns: list: A list of initialized SegmentationObjects. """ loc_dict = sos_dict.copy() svixs = loc_dict["svixs"] del loc_dict["svixs"] sos = [segmentation.SegmentationObject(ix, **loc_dict) for ix in svixs] return sos
[docs]def sos_dict_fact(svixs, version=None, scaling=None, obj_type="sv", working_dir=None, create=False): """ Creates a dictionary with parameters for initializing SegmentationObjects. Args: svixs (list): A list of segmentation object indices. version (str, optional): The version of the SegmentationObjects. Defaults to None. scaling (list, optional): The scaling factors for the SegmentationObjects. Defaults to None. obj_type (str, optional): The type of the SegmentationObjects. Defaults to "sv". working_dir (str, optional): The working directory for the SegmentationObjects. Defaults to None. create (bool, optional): Whether to create the SegmentationObjects if they do not exist. Defaults to False. Returns: dict: A dictionary with parameters for initializing SegmentationObjects. """ if working_dir is None: working_dir = global_params.config.working_dir if scaling is None: scaling = global_params.config['scaling'] sos_dict = {"svixs": svixs, "version": version, "working_dir": working_dir, "scaling": scaling, "create": create, "obj_type": obj_type} return sos_dict
[docs]def predict_sos_views(model, sos, pred_key, nb_cpus=1, woglia=True, verbose=False, raw_only=False, single_cc_only=False, return_proba=False): """ Predicts the views of a list of SegmentationObjects using a given model. Args: model (nn.Model): The model to use for prediction. sos (list): A list of SegmentationObjects whose views are to be predicted. pred_key (str): The key to use for storing the predictions. nb_cpus (int, optional): The number of CPUs to use for prediction. Defaults to 1. woglia (bool, optional): Whether to exclude glia from the prediction. Defaults to True. verbose (bool, optional): Whether to print verbose output. Defaults to False. raw_only (bool, optional): Whether to use only raw data for prediction. Defaults to False. single_cc_only (bool, optional): Whether to use only single connected components for prediction. Defaults to False. return_proba (bool, optional): Whether to return the probabilities of the predictions. Defaults to False. Returns: np.ndarray: The predicted views if return_proba is True, otherwise None. """ nb_chunks = np.max([1, len(sos) // 200]) so_chs = basics.chunkify(sos, nb_chunks) all_probas = [] if verbose: pbar = tqdm.tqdm(total=len(sos), leave=False) for ch in so_chs: views = sm.start_multiprocess_obj("load_views", [[sv, {"woglia": woglia, "raw_only": raw_only}] for sv in ch], nb_cpus=nb_cpus) proba = predict_views(model, views, ch, pred_key, verbose=False, single_cc_only=single_cc_only, return_proba=return_proba, nb_cpus=nb_cpus) if verbose: pbar.update(len(ch)) if return_proba: all_probas.append(np.concatenate(proba)) if verbose: pbar.close() if return_proba: return np.concatenate(all_probas)
[docs]def predict_views(model, views, ch, pred_key, single_cc_only=False, verbose=False, return_proba=False, nb_cpus=1) -> Optional[List[np.ndarray]]: """ Predicts the views of a list of SegmentationObjects using a given model. The predictions are not written to disk if return_proba is True. Args: model (nn.Model): The model to use for prediction. views (np.array): An array of views to be predicted. ch (List[SegmentationObject]): A list of SegmentationObjects whose views are to be predicted. pred_key (str): The key to use for storing the predictions. single_cc_only (bool, optional): Whether to use only single connected components for prediction. Defaults to False. verbose (bool, optional): Whether to print verbose output. Defaults to False. return_proba (bool, optional): Whether to return the probabilities of the predictions. Defaults to False. nb_cpus (int, optional): The number of CPUs to use for prediction. Defaults to 1. Returns: Optional[List[np.ndarray]]: The predicted views if return_proba is True, otherwise None. """ if single_cc_only: for kk in range(len(views)): data = views[kk] for i in range(len(data)): sing_cc = np.concatenate([single_conn_comp_img(data[i, 0, :1]), single_conn_comp_img(data[i, 0, 1:])]) data[i, 0] = sing_cc views[kk] = data part_views = np.cumsum([0] + [len(v) for v in views]) assert len(part_views) == len(views) + 1 views = np.concatenate(views) probas = model.predict_proba(views, verbose=verbose) so_probas = [] for ii, _ in enumerate(part_views[:-1]): sv_probas = probas[part_views[ii]:part_views[ii + 1]] so_probas.append(sv_probas) assert len(part_views) == len(so_probas) + 1 if return_proba: return so_probas if nb_cpus > 1: # make sure locking is enabled if multiprocessed for so in ch: so.enable_locking = True params = [[so, prob, pred_key] for so, prob in zip(ch, so_probas)] sm.start_multiprocess(multi_probas_saver, params, nb_cpus=nb_cpus)
[docs]def multi_probas_saver(args): """ Saves the probabilities of predictions for a list of SegmentationObjects. Args: args (tuple): A tuple containing a SegmentationObject, the probabilities of its predictions, and the key to use for storing the predictions. """ so, probas, key = args so.save_attributes([key], [probas])
[docs]def mesh_proc_chunked(working_dir, obj_type, nb_cpus=None): """ Caches the meshes for all SegmentationObjects within the SegmentationDataset with a given object type. Args: working_dir (str): The working directory for the SegmentationDataset. obj_type (str): The type of the SegmentationObjects whose meshes are to be cached. Object type identifier, like 'sj', 'vc' or 'mi'. nb_cpus (int, optional): The number of CPUs to use for caching. Defaults to the number of cores per node in the global parameters. Default is 20. """ if nb_cpus is None: nb_cpus = global_params.config['ncores_per_node'] sd = segmentation.SegmentationDataset(obj_type, working_dir=working_dir) multi_params = sd.so_dir_paths sm.start_multiprocess_imap(mesh_chunk, multi_params, nb_cpus=nb_cpus, debug=False)