Source code for syconn.extraction.cs_processing_steps

# -*- coding: utf-8 -*-
# SyConn - Synaptic connectivity inference toolkit
#
# Copyright (c) 2016 - now
# Max Planck Institute of Neurobiology, Martinsried, Germany
# Authors: Philipp Schubert, Joergen Kornfeld
import datetime
import os
import shutil
import time
from collections import defaultdict
from logging import Logger
from typing import Optional, Dict, List, Tuple, TYPE_CHECKING
from itertools import chain

import joblib
import numpy as np
import tqdm
import pandas
from knossos_utils import skeleton_utils, skeleton
from scipy import spatial
from sklearn import ensemble
from sklearn.model_selection import cross_val_predict
from sklearn import metrics
try:
    import open3d as o3d
except ImportError:
    pass  # for sphinx build

from . import log_extraction
from .. import global_params
from ..backend.storage import AttributeDict, VoxelStorageDyn, MeshStorage, VoxelStorageLazyLoading
from ..handler.basics import chunkify
from ..handler.config import initialize_logging
from ..mp import batchjob_utils as qu
from ..mp import mp_utils as sm
from ..reps import segmentation_helper as seghelp
from ..reps.super_segmentation_dataset import filter_ssd_by_total_pathlength
from ..reps import super_segmentation, segmentation, connectivity_helper as ch
from ..reps.rep_helper import subfold_from_ix, ix_from_subfold, get_unique_subfold_ixs
from ..proc.meshes import gen_mesh_voxelmask, calc_contact_syn_mesh


def collect_properties_from_ssv_partners(wd, obj_version=None, ssd_version=None, debug=False):
    """
    Annotate every ``syn_ssv`` object with properties of its two partner cells.

    The procedure runs three passes:

    1. Per cell, a temporary ``cache_syn.pkl`` is written which holds the properties
       of all synapses the cell takes part in
       (see :func:`_collect_properties_from_ssv_partners_thread`). Cells are processed
       in descending order of their 'size' attribute.
    2. Per ``syn_ssv`` storage folder, the cached values of both partners are copied into
       the synapse attribute dictionaries (see :func:`_from_cell_to_syn_dict`).
    3. All ``cache_syn.pkl`` files are removed again.

    Afterwards the following keys are available in the ``attr_dict`` of ``syn_ssv`` typed
    :class:`~syconn.reps.segmentation.SegmentationObject`:

    * 'partner_axoness': Cell compartment type (axon: 1, dendrite: 0, soma: 2,
      en-passant bouton: 3, terminal bouton: 4) of the partner neurons.
    * 'partner_spiness': Spine compartment predictions (0: dendritic shaft, 1: spine head,
      2: spine neck, 3: other) of both neurons.
    * 'partner_spineheadvol': Spinehead volume in µm^3.
    * 'partner_celltypes': Celltype of both neurons.
    * 'latent_morph': Local morphology embeddings of the pre- and post-synaptic partners.
    * 'syn_sign': -1 if 'syn_type_sym_ratio' exceeds the configured 'sym_thresh',
      1 otherwise.

    Args:
        wd (str): The working directory.
        obj_version (str, optional): Version of the ``syn_ssv`` dataset. Defaults to None.
        ssd_version (str, optional): Version of the super segmentation dataset.
            Defaults to None.
        debug (bool, optional): Forwarded to the local multiprocessing helper. Defaults to False.
    """
    ssd = super_segmentation.SuperSegmentationDataset(working_dir=wd, version=ssd_version)

    # pass 1: write per-cell caches, biggest cells first
    ids_by_size_desc = ssd.ssv_ids[np.argsort(ssd.load_numpy_data('size'))[::-1]]
    multi_params = [[wd, obj_version, ssd_version, id_chunk]
                    for id_chunk in chunkify(ids_by_size_desc, global_params.config.ncore_total * 2)]
    if qu.batchjob_enabled():
        _ = qu.batchjob_script(multi_params, "collect_properties_from_ssv_partners",
                               remove_jobfolder=True)
    else:
        _ = sm.start_multiprocess_imap(_collect_properties_from_ssv_partners_thread, multi_params,
                                       debug=debug)

    # pass 2: iterate over the storage paths of the synapse objects
    sd_syn_ssv = segmentation.SegmentationDataset("syn_ssv", working_dir=wd, version=obj_version)
    multi_params = [[path_chunk, wd, obj_version, ssd_version]
                    for path_chunk in chunkify(sd_syn_ssv.so_dir_paths,
                                               global_params.config.ncore_total * 2)]
    if qu.batchjob_enabled():
        _ = qu.batchjob_script(multi_params, "from_cell_to_syn_dict", remove_jobfolder=True)
    else:
        _ = sm.start_multiprocess_imap(_from_cell_to_syn_dict, multi_params, debug=debug)

    # pass 3: delete the temporary per-cell caches
    log_extraction.debug('Deleting cache dictionaries now.')
    # TODO: start as thread!
    cleanup_params = [(ssv_id, ssd.config) for ssv_id in ssd.ssv_ids]
    sm.start_multiprocess_imap(_delete_all_cache_dc, cleanup_params, nb_cpus=None)
    log_extraction.debug('Deleted all cache dictionaries.')
def _collect_properties_from_ssv_partners_thread(args):
    """
    Worker of :func:`collect_properties_from_ssv_partners` (pass 1).

    For every given cell a ``cache_syn.pkl`` is written into its SSV folder. It stores,
    aligned with the IDs under 'synssv_ids', the cell's properties at all synapses it
    participates in.

    Notes:
        * Cells without any synapse or without any mesh vertex get zero-filled entries
          for all properties.
        * The cell type falls back to -1 if 'celltype_cnn_e3' is missing in the
          cell's attribute dictionary.
        * The spine head volume is -1 for synapses which are not a key of the cell's
          'spinehead_vol' attribute.

    Args:
        args (tuple): Working directory, ``syn_ssv`` version, super segmentation dataset
            version and the IDs of the cells to process.
    """
    wd, obj_version, ssd_version, ssv_ids = args

    spine_kwargs = global_params.config['spines']['semseg2coords_spines']
    n_embedding = global_params.config['tcmn']['ndim_embedding']
    sd_syn_ssv = segmentation.SegmentationDataset(obj_type="syn_ssv", working_dir=wd,
                                                  version=obj_version)
    ssd = super_segmentation.SuperSegmentationDataset(working_dir=wd, version=ssd_version)

    syn_neuronpartners = sd_syn_ssv.load_numpy_data("neuron_partners")
    compartment_conf = global_params.config['compartments']
    pred_key_ax = "{}_avg{}".format(compartment_conf['view_properties_semsegax']['semseg_key'],
                                    global_params.config['compartments']['dist_axoness_averaging'])

    for ssv_id in ssv_ids:
        cell = ssd.get_super_segmentation_object(ssv_id)
        cell.load_attr_dict()
        syn_cache = AttributeDict(cell.ssv_dir + "/cache_syn.pkl", read_only=False,
                                  disable_locking=True)

        # synapses in which this cell is one of the two partners
        is_partner = (syn_neuronpartners[:, 0] == ssv_id) | (syn_neuronpartners[:, 1] == ssv_id)
        ssv_synids = sd_syn_ssv.ids[is_partner]
        n_syn = len(ssv_synids)

        if n_syn == 0 or cell.mesh[1].shape[0] == 0:
            # nothing to look up -> zero placeholders
            syn_cache['partner_spineheadvol'] = np.zeros((n_syn,), dtype=np.float32)
            syn_cache['partner_axoness'] = np.zeros((n_syn,), dtype=np.int32)
            syn_cache['synssv_ids'] = ssv_synids
            syn_cache['partner_spiness'] = np.zeros((n_syn,), dtype=np.int32)
            syn_cache['partner_celltypes'] = np.zeros((n_syn,), dtype=np.int32)
            syn_cache['latent_morph'] = np.zeros([n_syn, n_embedding], dtype=np.float32)
            syn_cache.push()
            continue

        syn_coords = sd_syn_ssv.rep_coords[is_partner]

        try:
            cell_type = cell.attr_dict['celltype_cnn_e3']
        except KeyError:
            cell_type = -1
        celltypes = [cell_type] * n_syn

        axoness, latent_morph = cell.attr_for_coords(syn_coords,
                                                     attr_keys=[pred_key_ax, 'latent_morph'])
        spiness = cell.semseg_for_coords(syn_coords, 'spiness', **spine_kwargs)

        spinehead_vols = []
        for syn_id in ssv_synids:
            if syn_id in cell.attr_dict['spinehead_vol']:
                spinehead_vols.append(cell.attr_dict['spinehead_vol'][syn_id])
            else:
                spinehead_vols.append(-1)
        spinehead_vols = np.array(spinehead_vols, dtype=np.float32)

        syn_cache['partner_spineheadvol'] = np.array(spinehead_vols)
        syn_cache['partner_axoness'] = axoness
        syn_cache['synssv_ids'] = ssv_synids
        syn_cache['partner_spiness'] = spiness
        syn_cache['partner_celltypes'] = np.array(celltypes)
        syn_cache['latent_morph'] = latent_morph
        syn_cache.push()


def _from_cell_to_syn_dict(args):
    """
    Worker of :func:`collect_properties_from_ssv_partners` (pass 2).

    Copies the per-cell values stored in the ``cache_syn.pkl`` files of both partner cells
    into the attribute dictionary of every ``syn_ssv`` object found in the given storage
    folders and sets 'syn_sign' (-1 if 'syn_type_sym_ratio' is above the configured
    'sym_thresh', else 1).

    Args:
        args (tuple): ``syn_ssv`` storage folders, working directory, ``syn_ssv`` version
            and super segmentation dataset version.

    Raises:
        ValueError: If a synapse ID does not occur exactly once in the cache of one of
            its partner cells.
    """
    so_dir_paths, wd, obj_version, ssd_version = args

    ssd = super_segmentation.SuperSegmentationDataset(working_dir=wd, version=ssd_version)
    sd_syn_ssv = segmentation.SegmentationDataset(obj_type="syn_ssv", working_dir=wd,
                                                  version=obj_version)
    cell_obj_conf = global_params.config['cell_objects']

    for storage_dir in so_dir_paths:
        storage = AttributeDict(storage_dir + "/attr_dict.pkl", read_only=False,
                                disable_locking=True)
        for synssv_id in storage.keys():
            synssv_o = sd_syn_ssv.get_segmentation_object(synssv_id)
            synssv_o.load_attr_dict()

            syn_sign = 1
            if synssv_o.attr_dict['syn_type_sym_ratio'] > cell_obj_conf['sym_thresh']:
                syn_sign = -1

            axoness, latent_morph, spinehead_vol, spiness, celltypes = [], [], [], [], []
            for ssv_partner_id in synssv_o.attr_dict["neuron_partners"]:
                ssv_o = ssd.get_super_segmentation_object(ssv_partner_id)
                cache_dc = AttributeDict(ssv_o.ssv_dir + "/cache_syn.pkl")

                # position of this synapse within the partner's cache arrays
                hits = np.transpose(np.nonzero(cache_dc['synssv_ids'] == synssv_id))
                if len(hits) != 1:
                    msg = f"Could not find synssv with ID {synssv_id} in 'cache_syn.pkl' of {ssv_o}."
                    log_extraction.error(msg)
                    raise ValueError(msg)
                pos = hits[0][0]

                axoness.append(cache_dc['partner_axoness'][pos])
                spiness.append(cache_dc['partner_spiness'][pos])
                celltypes.append(cache_dc['partner_celltypes'][pos])
                latent_morph.append(cache_dc['latent_morph'][pos])
                spinehead_vol.append(cache_dc['partner_spineheadvol'][pos])

            partner_props = {'partner_axoness': axoness,
                             'partner_spiness': spiness,
                             'partner_celltypes': celltypes,
                             'partner_spineheadvol': spinehead_vol,
                             'syn_sign': syn_sign,
                             'latent_morph': latent_morph}
            synssv_o.attr_dict.update(partner_props)
            storage[synssv_id] = synssv_o.attr_dict
        storage.push()


def _delete_all_cache_dc(args):
    """
    Remove the temporary ``cache_syn.pkl`` of a single cell, if it exists.

    Args:
        args (tuple): Super segmentation object ID and the configuration used to
            instantiate the object.
    """
    ssv_id, config = args
    ssv_o = super_segmentation.SuperSegmentationObject(ssv_id, config=config)
    cache_path = ssv_o.ssv_dir + "/cache_syn.pkl"
    if not os.path.exists(cache_path):
        return
    os.remove(cache_path)
def filter_relevant_syn(sd_syn: segmentation.SegmentationDataset,
                        ssd: super_segmentation.SuperSegmentationDataset,
                        log: Logger) -> Dict[int, list]:
    """
    Group supervoxel-level objects (``syn`` or ``cs``) by the pair of cells they connect.

    Objects are dropped if one of their supervoxels cannot be mapped to a cell (mapped
    ID 0) or if both supervoxels belong to the same cell (intra-cell contacts do not need
    to be agglomerated). For ``cs`` datasets, contacts are additionally restricted to
    cells returned by ``filter_ssd_by_total_pathlength`` if
    ``config['cell_contacts']['min_path_length_partners']`` is set to a value > 0.

    Args:
        sd_syn (segmentation.SegmentationDataset): Dataset of supervoxel-level objects.
        ssd (super_segmentation.SuperSegmentationDataset): The super segmentation dataset.
        log (Logger): Logger for progress messages. If None, the module logger is used.

    Returns:
        Dict[int, list]: Keys are the encoded cell pair (larger SSV ID shifted left by
        32 bit plus the smaller SSV ID), values are the lists of object IDs between the two
        cells. See :py:func:`~syconn.reps.connectivity_helper.sv_id_to_partner_ids_vec` for
        decoding into SSV IDs.
    """
    if log is None:
        log = log_extraction

    # syn objects are a subset of the contact site objects and share their IDs, i.e. the
    # partner supervoxel IDs can be decoded from the object IDs directly (bit shift).
    syn_ids = sd_syn.ids.copy()
    sv_ids = ch.cs_id_to_partner_ids_vec(syn_ids)
    log.debug(f'Generated supervoxel IDs for all {sd_syn.type} objects.')

    # this might mean that all syn between svs with IDs>max(np.uint32) are discarded
    max_known_sv_id = ssd.mapping_lookup_reverse.id_array[-1]
    sv_ids[sv_ids > max_known_sv_id] = 0

    # lookup SV -> SSV, restricted to the supervoxels that are part of an object
    # -> This produces extra overhead when processing flattened SV graphs (SSVs only consist of 1 SV)
    mapping_dc = ssd.sv2ssv_ids(np.unique(sv_ids.flatten()), nb_cpus=sm.cpu_count())
    log.debug('Generated sv-ssv mapping dict.')

    # TODO: apply with multiple processes and shared mapping_dc
    def mapper(sv_id):
        if sv_id in mapping_dc:
            return mapping_dc[sv_id]
        return 0

    # np.vectorize is not faster than "map", only more convenient
    vectorized_mapper = np.vectorize(mapper)
    mapped_ssv_ids = vectorized_mapper(sv_ids.reshape(-1)).reshape(sv_ids.shape)
    log.debug(f'Mapped SV IDs to SSV IDs for all {sd_syn.type} objects.')
    del mapping_dc

    # keep objects where both supervoxels belong to a cell ...
    both_mapped = np.all(mapped_ssv_ids > 0, axis=1)
    syn_ids = syn_ids[both_mapped]
    mapped_ssv_ids = mapped_ssv_ids[both_mapped]

    # ... and where those two cells differ (inter-ssv contact sites)
    different_cells = mapped_ssv_ids[:, 0] != mapped_ssv_ids[:, 1]
    syn_ids = syn_ids[different_cells]
    inter_ssv_contacts = mapped_ssv_ids[different_cells]

    # TODO: generalize by adding it as a method parameter and pass config value to the method.
    #  Also add a config value for syn_ssv!
    # filter small SSV if min path length was set in config
    min_path_length_partners = global_params.config['cell_contacts']['min_path_length_partners']
    apply_length_filter = ((sd_syn.type == 'cs') and (min_path_length_partners is not None)
                           and (min_path_length_partners > 0))
    if apply_length_filter:
        filtered_ssv_ids = filter_ssd_by_total_pathlength(ssd, min_path_length_partners)
        log.info(f'Filtering contact sites formed with at least once small cell (min. path length of a '
                 f'cell {min_path_length_partners} µm). {len(filtered_ssv_ids)} Cells fulfill that '
                 f'criterion.')
        # a contact is kept only if both of its partners passed the filter
        partner_ok = np.isin(inter_ssv_contacts.reshape(-1), filtered_ssv_ids).reshape(-1, 2)
        keep = np.all(partner_ok, axis=1)
        inter_ssv_contacts = inter_ssv_contacts[keep]
        syn_ids = syn_ids[keep]
        assert len(inter_ssv_contacts) == len(syn_ids)
        if len(inter_ssv_contacts) == 0:
            log.warning(f'No contact site found after filtering small cells.')
        else:
            log.info(f'Found {len(inter_ssv_contacts)} supervoxel contact sites (merged, unsplit) between'
                     f' {len(filtered_ssv_ids)} cells.')

    # bit shifted combination of the SSV partner IDs, identifies the cell pair
    larger_ids = np.max(inter_ssv_contacts, axis=1)
    smaller_ids = np.min(inter_ssv_contacts, axis=1)
    relevant_ssv_ids_enc = np.left_shift(larger_ids, 32) + smaller_ids
    log.debug(f'Filtered intra-cell {sd_syn.type} objects and created {sd_syn.type}_ssv IDs.')

    # lookup from the cell pair to its supervoxel-level objects
    ssv_to_syn_ids_dc = defaultdict(list)
    for pair_enc, obj_id in zip(relevant_ssv_ids_enc, syn_ids):
        ssv_to_syn_ids_dc[pair_enc].append(obj_id)
    log.debug(f'Created cell-pair {sd_syn.type} object lists for subsequent split and agglomeration.')
    return ssv_to_syn_ids_dc
def combine_and_split_syn(wd, cs_gap_nm=300, ssd_version=None, syn_version=None,
                          nb_cpus=None, n_folders_fs=10000, log=None, overwrite=False):
    """
    Create the 'syn_ssv' dataset (version "0") from the 'syn' objects.

    'syn' objects between the same pair of cells are collected
    (:func:`filter_relevant_syn`), clustered into connected components and every
    component is written as one 'syn_ssv' object (see
    :func:`_combine_and_split_syn_thread`). The 'syn' attributes
    ['cs_id', 'asym_prop', 'sym_prop'] are aggregated. All resulting objects contain the
    attributes ['syn_sign', 'syn_type_sym_ratio', 'asym_prop', 'sym_prop', 'cs_ids',
    'neuron_partners'].

    Notes:
        * 'rep_coord' is derived from the center of mass of all object voxels.
        * 'cs_id'/'cs_ids' is the same as syn_id ('syn' are just a subset of 'cs',
          preserving the IDs).

    Args:
        wd (str): The working directory.
        cs_gap_nm (int, optional): Maximum gap in nm between voxels of different 'syn'
            fragments which are merged into one object. Defaults to 300.
        ssd_version (str, optional): Version of the super segmentation dataset.
            Defaults to None.
        syn_version (str, optional): Version of the 'syn' dataset. Defaults to None.
        nb_cpus (int, optional): Number of CPUs used if no batch job system is enabled.
            Defaults to None.
        n_folders_fs (int, optional): Number of storage folders of the target dataset.
            Defaults to 10000.
        log (Logger, optional): Logger passed on to the sub-routines. Defaults to None.
        overwrite (bool, optional): If True, an existing 'syn_ssv' storage is deleted first.
            Defaults to False.

    Raises:
        FileExistsError: If the target storage exists and ``overwrite`` is False.
    """
    ssd = super_segmentation.SuperSegmentationDataset(wd, version=ssd_version)
    syn_sd = segmentation.SegmentationDataset("syn", working_dir=wd, version=syn_version)

    # TODO: this procedure creates folders with single and double digits, e.g. '0' and '00'.
    #  Single digit folders are not used during write-outs, they are probably generated
    #  within this method's makedirs
    log_extraction.debug(f'Filtering relevant synapses.')
    rel_ssv_with_syn_ids = filter_relevant_syn(syn_sd, ssd, log=log)
    log_extraction.debug(f'Filtering relevant synapses done.')

    storage_location_ids = get_unique_subfold_ixs(n_folders_fs)
    n_used_paths = min(global_params.config.ncore_total * 8, len(storage_location_ids),
                       len(rel_ssv_with_syn_ids))
    all_rel_paths = [subfold_from_ix(ix, n_folders_fs) for ix in storage_location_ids]
    voxel_rel_paths = chunkify(all_rel_paths, n_used_paths)

    # target SD for SSV syn objects
    sd_syn_ssv = segmentation.SegmentationDataset("syn_ssv", working_dir=wd, version="0",
                                                  create=False, n_folders_fs=n_folders_fs)
    target_exists = os.path.exists(sd_syn_ssv.so_storage_path)
    if target_exists and not overwrite:
        raise FileExistsError(f'"{sd_syn_ssv.so_storage_path}" already exists, but '
                              f'overwrite was set to False.')
    if target_exists:
        shutil.rmtree(sd_syn_ssv.so_storage_path)

    # prepare folder structure (everything but the last path level)
    parent_rel_paths = np.unique([subfold_from_ix(ix, n_folders_fs)[:-2]
                                  for ix in storage_location_ids])
    for parent in parent_rel_paths:
        os.makedirs(sd_syn_ssv.so_storage_path + parent)

    # TODO: apply weighting-scheme to balance worker load
    cell_pair_items = list(rel_ssv_with_syn_ids.items())
    cell_pair_chunks = chunkify(cell_pair_items, n_used_paths)
    multi_params = []
    for ii in range(n_used_paths):
        multi_params.append((wd, cell_pair_chunks[ii], voxel_rel_paths[ii], syn_sd.version,
                             sd_syn_ssv.version, cs_gap_nm))

    if qu.batchjob_enabled():
        _ = qu.batchjob_script(multi_params, "combine_and_split_syn", remove_jobfolder=True,
                               log=log)
    else:
        _ = sm.start_multiprocess_imap(_combine_and_split_syn_thread, multi_params,
                                       nb_cpus=nb_cpus, debug=False)
def _combine_and_split_syn_thread(args):
    """
    Worker of :func:`combine_and_split_syn`.

    For every cell pair, the voxels of all its 'syn' fragments are clustered with
    :func:`connected_cluster_kdtree`. Each resulting connected component with at least
    ``config['cell_objects']['min_obj_vx']['syn_ssv']`` voxels is stored as one 'syn_ssv'
    object (voxels, mesh and attribute dictionary). The fragment attributes 'sym_prop' and
    'asym_prop' are aggregated as sums weighted by the fraction of component voxels each
    fragment contributes.

    Args:
        args (list): Working directory, list of ``(encoded cell pair, syn IDs)`` items,
            relative storage paths this worker may write to, 'syn' version, 'syn_ssv'
            version and the maximum gap in nm (``cs_gap_nm``).

    Raises:
        ValueError: If none of the 'syn' fragments of a cell pair has voxels, if an object
            ID does not map to the currently opened storage folder, or if the worker
            runs out of storage paths.
    """
    wd = args[0]
    rel_ssv_with_syn_ids_items = args[1]
    voxel_rel_paths = args[2]
    syn_version = args[3]
    syn_ssv_version = args[4]
    cs_gap_nm = args[5]

    sd_syn_ssv = segmentation.SegmentationDataset("syn_ssv", working_dir=wd,
                                                  version=syn_ssv_version)
    sd_syn = segmentation.SegmentationDataset("syn", working_dir=wd, version=syn_version)

    scaling = sd_syn.scaling
    syn_meshing_kws = global_params.config['meshes']['meshing_props_points']['syn_ssv']
    mesh_min_obj_vx = global_params.config['meshes']['mesh_min_obj_vx']
    cell_obj_cnf = global_params.config['cell_objects']
    use_new_subfold = global_params.config.use_new_subfold
    # TODO: add to config, also used in 'ix_from_subfold' if 'global_params.config.use_new_subfold=True'
    div_base = 1e3
    id_chunk_cnt = 0
    # number of cell pairs after which the next storage folder is opened
    n_per_voxel_path = np.ceil(float(len(rel_ssv_with_syn_ids_items)) / len(voxel_rel_paths))
    n_items_for_path = 0
    cur_path_id = 0
    base_dir = sd_syn_ssv.so_storage_path + voxel_rel_paths[cur_path_id]
    os.makedirs(base_dir, exist_ok=True)
    # get ID/path to storage to save intermediate results
    base_id = ix_from_subfold(voxel_rel_paths[cur_path_id], sd_syn.n_folders_fs)
    syn_ssv_id = base_id

    voxel_dc = VoxelStorageLazyLoading(base_dir + "/voxel.npz", overwrite=True)
    attr_dc = AttributeDict(base_dir + "/attr_dict.pkl", read_only=False)
    mesh_dc = MeshStorage(base_dir + "/mesh.pkl", read_only=False)

    for ssvpartners_enc, syn_ids in rel_ssv_with_syn_ids_items:
        n_items_for_path += 1
        ssv_ids = ch.cs_id_to_partner_ids_vec([ssvpartners_enc])[0]

        # Collect attributes and voxels of all fragments. ``synix_list`` stores, for every
        # voxel, the position of its fragment in ``syn_attr_list``. Both lists have to
        # share the same index space, so all fragments are enumerated in a single loop;
        # enumerating only ``syn_ids[1:]`` from 0 would label the second fragment's
        # voxels with index 0 and shift every later fragment by one.
        syn_attr_list = []  # used to collect syn properties
        voxel_list = []
        synix_list = []
        for syn_ix, syn_id in enumerate(syn_ids):
            syn = sd_syn.get_segmentation_object(syn_id)
            syn.load_attr_dict()
            syn_attr_list.append(syn.attr_dict)
            voxel_list.append(syn.voxel_list)
            synix_list += [syn_ix] * len(voxel_list[-1])
        syn_attr_list = np.array(syn_attr_list)
        synix_list = np.array(synix_list)
        if len(synix_list) == 0:
            msg = 'Voxels not available for syn-objects {}.'.format(syn_ids)
            log_extraction.error(msg)
            raise ValueError(msg)

        ccs = connected_cluster_kdtree(voxel_list, dist_intra_object=cs_gap_nm,
                                       dist_inter_object=20000, scale=scaling)
        voxel_list = np.concatenate(voxel_list)

        for this_cc in ccs:
            # do not process synapse again if job has been restarted
            if syn_ssv_id not in attr_dc:
                this_cc_mask = np.array(list(this_cc))
                # retrieve the index of the syn objects selected for this CC
                this_syn_ixs, this_syn_ids_cnt = np.unique(synix_list[this_cc_mask],
                                                           return_counts=True)
                # fraction of the component's voxels contributed by each fragment
                this_agg_syn_weights = this_syn_ids_cnt / np.sum(this_syn_ids_cnt)
                if np.sum(this_syn_ids_cnt) < cell_obj_cnf['min_obj_vx']['syn_ssv']:
                    # too small; the object ID is not consumed
                    continue
                this_attr = syn_attr_list[this_syn_ixs]
                this_vx = voxel_list[this_cc_mask]

                syn_ssv = sd_syn_ssv.get_segmentation_object(syn_ssv_id)
                # the ID has to resolve to the storage folder which is currently open
                if (os.path.abspath(syn_ssv.attr_dict_path)
                        != os.path.abspath(base_dir + "/attr_dict.pkl")):
                    raise ValueError(f'Path mis-match!')
                synssv_attr_dc = dict(neuron_partners=ssv_ids)
                voxel_dc[syn_ssv_id] = this_vx
                synssv_attr_dc["rep_coord"] = (seghelp.calc_center_of_mass(this_vx * scaling)
                                               // scaling).astype(np.int32)
                synssv_attr_dc["bounding_box"] = np.array([np.min(this_vx, axis=0),
                                                           np.max(this_vx, axis=0)])
                synssv_attr_dc["size"] = len(this_vx)
                # calc_contact_syn_mesh returns a list with a single mesh (for syn_ssv)
                if mesh_min_obj_vx < synssv_attr_dc["size"]:
                    syn_ssv._mesh = calc_contact_syn_mesh(syn_ssv, voxel_dc=voxel_dc,
                                                          **syn_meshing_kws)[0]
                    mesh_dc[syn_ssv.id] = syn_ssv.mesh
                    synssv_attr_dc["mesh_bb"] = syn_ssv.mesh_bb
                    synssv_attr_dc["mesh_area"] = syn_ssv.mesh_area
                else:
                    # below the meshing threshold: store an empty mesh
                    zero_mesh = [np.zeros((0,), dtype=np.int32), np.zeros((0,), dtype=np.int32),
                                 np.zeros((0,), dtype=np.float32)]
                    mesh_dc[syn_ssv.id] = zero_mesh
                    synssv_attr_dc["mesh_bb"] = synssv_attr_dc["bounding_box"] * scaling
                    synssv_attr_dc["mesh_area"] = 0

                # aggregate syn properties
                syn_props_agg = {}
                # cs_id is the same as syn_id ('syn' are just a subset of 'cs')
                for dc in this_attr:
                    for k in ['cs_id', 'sym_prop', 'asym_prop']:
                        syn_props_agg.setdefault(k, []).append(dc[k])
                # rename and delete old entry
                syn_props_agg['cs_ids'] = syn_props_agg['cs_id']
                del syn_props_agg['cs_id']

                # 'syn_ssv' synapse type as weighted sum of the 'syn' fragment types
                sym_prop = np.sum(this_agg_syn_weights * np.array(syn_props_agg['sym_prop']))
                asym_prop = np.sum(this_agg_syn_weights * np.array(syn_props_agg['asym_prop']))
                syn_props_agg['sym_prop'] = sym_prop
                syn_props_agg['asym_prop'] = asym_prop

                if sym_prop + asym_prop == 0:
                    sym_ratio = -1
                else:
                    sym_ratio = sym_prop / float(asym_prop + sym_prop)
                syn_props_agg["syn_type_sym_ratio"] = sym_ratio
                syn_sign = -1 if sym_ratio > cell_obj_cnf['sym_thresh'] else 1
                syn_props_agg["syn_sign"] = syn_sign

                # add syn_ssv dict to AttributeStorage
                synssv_attr_dc.update(syn_props_agg)
                attr_dc[syn_ssv_id] = synssv_attr_dc

            # advance to the next object ID which maps to the current storage folder
            if use_new_subfold:
                syn_ssv_id += np.uint(1)
                if syn_ssv_id - base_id >= div_base:
                    # next ID chunk mapped to this storage
                    id_chunk_cnt += 1
                    old_base_id = base_id
                    base_id += np.uint(sd_syn_ssv.n_folders_fs * div_base * id_chunk_cnt)
                    assert subfold_from_ix(base_id, sd_syn_ssv.n_folders_fs, old_version=False) == \
                           subfold_from_ix(old_base_id, sd_syn_ssv.n_folders_fs, old_version=False)
                    syn_ssv_id = base_id
            else:
                syn_ssv_id += np.uint(sd_syn.n_folders_fs)

        if n_items_for_path > n_per_voxel_path:
            # flush the current storage folder and open the next one
            voxel_dc.push()
            voxel_dc.close()
            mesh_dc.push()
            attr_dc.push()
            cur_path_id += 1
            if len(voxel_rel_paths) == cur_path_id:
                raise ValueError(f'Worker ran out of possible storage paths for storing {sd_syn_ssv.type}.')
            n_items_for_path = 0
            id_chunk_cnt = 0
            base_id = ix_from_subfold(voxel_rel_paths[cur_path_id], sd_syn.n_folders_fs)
            syn_ssv_id = base_id
            base_dir = sd_syn_ssv.so_storage_path + voxel_rel_paths[cur_path_id]
            os.makedirs(base_dir, exist_ok=True)
            voxel_dc = VoxelStorageLazyLoading(base_dir + "/voxel.npz")
            attr_dc = AttributeDict(base_dir + "/attr_dict.pkl", read_only=False)
            mesh_dc = MeshStorage(base_dir + "/mesh.pkl", read_only=False)

    # flush whatever is left in the storage folder opened last
    if n_items_for_path > 0:
        voxel_dc.push()
        voxel_dc.close()
        attr_dc.push()
        mesh_dc.push()
def connected_cluster_kdtree(voxel_coords: List[np.ndarray], dist_intra_object: float,
                             dist_inter_object: float, scale: np.ndarray) -> List[set]:
    """
    Identify connected components within N objects in two stages.

    Stage 1 adds, separately for every input object, an edge between all of its voxel
    pairs which are at most 2 voxels apart (distance in voxel coordinates) and computes
    the connected components of the resulting graph.

    Stage 2 looks at all pairs of stage-1 components whose representative coordinates
    (one arbitrary voxel per component) are within `dist_inter_object`. Such a pair is
    connected through its closest voxel pair if the minimum distance between any of
    their voxels does not exceed `dist_intra_object`.

    Args:
        voxel_coords (List[np.ndarray]): List of numpy arrays in voxel coordinates, one
            array per object.
        dist_intra_object (float): Maximum distance between two voxels of different
            fragments to consider them the same object. In nm.
        dist_inter_object (float): Maximum distance between the representative
            coordinates of two components to check for close voxels between them. In nm.
        scale (np.ndarray): Voxel sizes in nm (XYZ).

    Returns:
        List[set]: Connected components across all N input objects. Set elements are
        indices into the concatenation of `voxel_coords`.
    """
    import networkx as nx

    graph = nx.Graph()
    # index of the first voxel of every object within the concatenated voxel array
    first_ixs = np.cumsum([0] + [len(vxs) for vxs in voxel_coords[:-1]])

    # stage 1: intra object edges
    for obj_vxs, first_ix in zip(voxel_coords, first_ixs):
        graph.add_nodes_from(np.arange(len(obj_vxs)) + first_ix)
        neighbor_pairs = spatial.cKDTree(obj_vxs).query_pairs(r=2)
        graph.add_edges_from(np.array(list(neighbor_pairs), dtype=np.int64) + first_ix)

    coords_nm = np.concatenate(voxel_coords) * scale
    components = [np.array(list(cc)) for cc in nx.connected_components(graph)]
    rep_coords = np.array([coords_nm[comp[0]] for comp in components])
    candidate_pairs = spatial.cKDTree(rep_coords).query_pairs(r=dist_inter_object)

    # stage 2: add minimal inter-object edges
    for ix_a, ix_b in candidate_pairs:
        vx_ixs_a = components[ix_a]
        vx_ixs_b = components[ix_b]
        tree_a = spatial.cKDTree(coords_nm[vx_ixs_a])
        # distances are inf where no voxel of 'a' lies within the upper bound
        dists, nn_ixs = tree_a.query(coords_nm[vx_ixs_b],
                                     distance_upper_bound=dist_intra_object)
        closest = np.argmin(dists)
        if dists[closest] > dist_intra_object:
            continue
        graph.add_edge(vx_ixs_a[nn_ixs[closest]], vx_ixs_b[closest])

    return list(nx.connected_components(graph))
def combine_and_split_cs(wd, ssd_version=None, cs_version=None, nb_cpus=None,
                         n_folders_fs=10000, log=None, overwrite=False):
    """
    Create the 'cs_ssv' dataset (version "0") from the 'cs' objects.

    'cs' objects between the same pair of cells are collected
    (:func:`filter_relevant_syn`) and handed to :func:`_combine_and_split_cs_thread`,
    which computes connected cs-objects on SSV level and re-calculates their attributes
    (mesh_area, size, etc.). The connected component analysis is performed on the mesh
    of all cell-cell contacts instead of their voxels.

    Notes:
        * 'rep_coord' is derived from the center of mass of all mesh vertices.

    Args:
        wd (str): The working directory.
        ssd_version (str, optional): Version of the super segmentation dataset.
            Defaults to None.
        cs_version (str, optional): Version of the 'cs' dataset. Defaults to None.
        nb_cpus (int, optional): Number of CPUs used if no batch job system is enabled.
            Defaults to None.
        n_folders_fs (int, optional): Number of storage folders of the target dataset.
            Defaults to 10000.
        log (Logger, optional): Logger passed on to the sub-routines. Defaults to None.
        overwrite (bool, optional): If True, an existing 'cs_ssv' storage is deleted first.
            Defaults to False.

    Raises:
        FileExistsError: If the target storage exists and ``overwrite`` is False.
    """
    ssd = super_segmentation.SuperSegmentationDataset(wd, version=ssd_version)
    cs_sd = segmentation.SegmentationDataset("cs", working_dir=wd, version=cs_version)
    cs_version = cs_sd.version

    rel_ssv_with_cs_ids = filter_relevant_syn(cs_sd, ssd, log=log)
    # the datasets are not needed anymore
    del ssd, cs_sd

    storage_location_ids = get_unique_subfold_ixs(n_folders_fs)
    n_used_paths = min(global_params.config.ncore_total * 30, len(storage_location_ids),
                       len(rel_ssv_with_cs_ids))
    all_rel_paths = [subfold_from_ix(ix, n_folders_fs) for ix in storage_location_ids]
    voxel_rel_paths = chunkify(all_rel_paths, n_used_paths)

    # target SD for SSV cs objects
    sd_cs_ssv = segmentation.SegmentationDataset("cs_ssv", working_dir=wd, version="0",
                                                 create=False, n_folders_fs=n_folders_fs)
    target_exists = os.path.exists(sd_cs_ssv.so_storage_path)
    if target_exists and not overwrite:
        raise FileExistsError(f'"{sd_cs_ssv.so_storage_path}" already exists, but overwrite was set to False.')
    if target_exists:
        shutil.rmtree(sd_cs_ssv.so_storage_path)

    # prepare folder structure (everything but the last path level)
    voxel_rel_paths_2stage = np.unique([subfold_from_ix(ix, n_folders_fs)[:-2]
                                        for ix in storage_location_ids])
    for parent in voxel_rel_paths_2stage:
        os.makedirs(sd_cs_ssv.so_storage_path + parent)

    rel_ssv_with_cs_ids_items = list(rel_ssv_with_cs_ids.items())
    rel_csssv_to_cs_ids_items_chunked = chunkify(rel_ssv_with_cs_ids_items, n_used_paths)
    multi_params = []
    for ii in range(n_used_paths):
        multi_params.append((wd, rel_csssv_to_cs_ids_items_chunked[ii], voxel_rel_paths[ii],
                             cs_version, sd_cs_ssv.version))
    # drop the intermediate containers before the workers are started
    del rel_ssv_with_cs_ids_items, rel_csssv_to_cs_ids_items_chunked, voxel_rel_paths_2stage, voxel_rel_paths

    if qu.batchjob_enabled():
        _ = qu.batchjob_script(multi_params, "combine_and_split_cs", remove_jobfolder=True,
                               log=log)
    else:
        _ = sm.start_multiprocess_imap(_combine_and_split_cs_thread, multi_params,
                                       nb_cpus=nb_cpus, debug=False)
def _open_cs_ssv_storages(base_dir):
    """
    Create `base_dir` if necessary and open the writable attribute and mesh storages inside it.

    Used for the first and for every subsequent storage folder of a worker so
    that all of them are opened with the same options (``compress=True`` for
    the mesh storage).

    Args:
        base_dir (str): Absolute path to the storage folder.

    Returns:
        tuple: ``(attr_dc, mesh_dc)``, the opened :class:`AttributeDict` and :class:`MeshStorage`.
    """
    os.makedirs(base_dir, exist_ok=True)
    attr_dc = AttributeDict(base_dir + "/attr_dict.pkl", read_only=False)
    mesh_dc = MeshStorage(base_dir + "/mesh.pkl", read_only=False, compress=True)
    return attr_dc, mesh_dc


def _combine_and_split_cs_thread(args):
    """
    Worker of :func:`combine_and_split_cs`.

    For every cell pair, the voxel masks of all its 'cs' objects are meshed
    jointly; every connected component mesh becomes one 'cs_ssv' object whose
    mesh and attributes ('neuron_partners', 'mesh_bb', 'mesh_area',
    'bounding_box', 'rep_coord', 'cs_ids', 'size') are written to the storage
    folders assigned to this worker. 'size' is currently always stored as 0.

    Args:
        args (tuple): A tuple containing the following elements:
            - wd: The working directory.
            - rel_ssv_with_cs_ids_items: A list of tuples, where each tuple contains an encoded SSV partner
              pair and a list of contact site IDs.
            - voxel_rel_paths: A list of relative storage folders this worker may write to.
            - cs_version: The version of the 'cs' dataset.
            - cs_ssv_version: The version of the 'cs_ssv' dataset.

    Raises:
        ValueError: If the storage path of a new 'cs_ssv' ID does not match the currently opened storage
            folder, or if the worker runs out of storage folders.
    """
    wd = args[0]
    rel_ssv_with_cs_ids_items = args[1]
    voxel_rel_paths = args[2]
    cs_version = args[3]
    cs_ssv_version = args[4]

    sd_cs_ssv = segmentation.SegmentationDataset("cs_ssv", working_dir=wd, version=cs_ssv_version)
    sd_cs = segmentation.SegmentationDataset("cs", working_dir=wd, version=cs_version)

    scaling = sd_cs.scaling
    meshing_kws = global_params.config['meshes']['meshing_props_points']['cs_ssv']
    mesh_min_obj_vx = global_params.config['meshes']['mesh_min_obj_vx']
    use_new_subfold = global_params.config.use_new_subfold
    # TODO: add to config, also used in 'ix_from_subfold' if 'global_params.config.use_new_subfold=True'
    div_base = 1e3
    id_chunk_cnt = 0
    n_per_voxel_path = np.ceil(float(len(rel_ssv_with_cs_ids_items)) / len(voxel_rel_paths))
    n_items_for_path = 0
    cur_path_id = 0
    base_dir = sd_cs_ssv.so_storage_path + voxel_rel_paths[cur_path_id]
    # get ID/path to storage to save intermediate results
    # NOTE(review): IDs are derived with sd_cs.n_folders_fs while the paths are checked against
    # sd_cs_ssv below - presumably both datasets use the same n_folders_fs; confirm.
    base_id = ix_from_subfold(voxel_rel_paths[cur_path_id], sd_cs.n_folders_fs)
    cs_ssv_id = base_id
    attr_dc, mesh_dc = _open_cs_ssv_storages(base_dir)

    # iterate over cell partners and their contact site IDs (each contact site is between two supervoxels
    # of the partner cells)
    for ssvpartners_enc, cs_ids in rel_ssv_with_cs_ids_items:
        n_items_for_path += 1
        ssv_ids = ch.cs_id_to_partner_ids_vec([ssvpartners_enc])[0]

        # collect voxel mask iterators and the total voxel count of all contacts of this cell pair
        cs_lst = sd_cs.get_segmentation_object(cs_ids)
        vxl_iter_lst = []
        vx_cnt = 0
        for cs in cs_lst:
            vx_store = VoxelStorageDyn(cs.voxel_path, read_only=True, disable_locking=True)
            vxl_iter_lst.append(vx_store.iter_voxelmask_offset(cs.id, overlap=1))
            vx_cnt += vx_store.object_size(cs.id)
        if mesh_min_obj_vx > vx_cnt:
            # too small to be meshed - no cs_ssv objects for this pair
            ccs = []
        else:
            # generate connected component meshes; vertices are in nm
            ccs = gen_mesh_voxelmask(chain(*vxl_iter_lst), scale=scaling, **meshing_kws)

        for mesh_cc in ccs:
            cs_ssv = sd_cs_ssv.get_segmentation_object(cs_ssv_id)
            if os.path.abspath(cs_ssv.attr_dict_path) != os.path.abspath(base_dir + "/attr_dict.pkl"):
                raise ValueError('Path mis-match!')
            csssv_attr_dc = dict(neuron_partners=ssv_ids)
            # don't store normals
            cs_ssv._mesh = [mesh_cc[0], mesh_cc[1], np.zeros((0,), dtype=np.float32)]
            mesh_dc[cs_ssv.id] = cs_ssv.mesh
            csssv_attr_dc["mesh_bb"] = cs_ssv.mesh_bb
            csssv_attr_dc["mesh_area"] = cs_ssv.mesh_area
            csssv_attr_dc["bounding_box"] = (cs_ssv.mesh_bb // scaling).astype(np.int32)
            csssv_attr_dc["rep_coord"] = (seghelp.calc_center_of_mass(mesh_cc[1].reshape((-1, 3)))
                                          // scaling).astype(np.int64)
            csssv_attr_dc["cs_ids"] = list(cs_ids)
            # TODO: compute the volume from a watertight open3d mesh as soon as open3d >= 0.11 is
            #  supported (glibc error on cluster prevents upgrade); until then 'size' is a placeholder.
            csssv_attr_dc["size"] = 0

            # add cs_ssv dict to AttributeStorage
            attr_dc[cs_ssv_id] = csssv_attr_dc
            if use_new_subfold:
                cs_ssv_id += np.uint(1)
                if cs_ssv_id - base_id >= div_base:
                    # next ID chunk mapped to this storage
                    id_chunk_cnt += 1
                    old_base_id = base_id
                    base_id += np.uint(sd_cs_ssv.n_folders_fs * div_base) * id_chunk_cnt
                    assert subfold_from_ix(base_id, sd_cs_ssv.n_folders_fs, old_version=False) == \
                           subfold_from_ix(old_base_id, sd_cs_ssv.n_folders_fs, old_version=False)
                    cs_ssv_id = base_id
            else:
                cs_ssv_id += np.uint(sd_cs.n_folders_fs)

        if n_items_for_path > n_per_voxel_path:
            # current storage folder is full: flush it and continue with the next one
            attr_dc.push()
            mesh_dc.push()
            cur_path_id += 1
            if len(voxel_rel_paths) == cur_path_id:
                raise ValueError(f'Worker ran out of possible storage paths for storing {sd_cs_ssv.type}.')
            n_items_for_path = 0
            id_chunk_cnt = 0
            base_id = ix_from_subfold(voxel_rel_paths[cur_path_id], sd_cs.n_folders_fs)
            cs_ssv_id = base_id
            base_dir = sd_cs_ssv.so_storage_path + voxel_rel_paths[cur_path_id]
            attr_dc, mesh_dc = _open_cs_ssv_storages(base_dir)

    # flush whatever is left in the last storage folder
    if n_items_for_path > 0:
        attr_dc.push()
        mesh_dc.push()
def cc_large_voxel_lists(voxel_list, cs_gap_nm, max_concurrent_nodes=5000, verbose=False):
    """
    Identify connected components within a set of voxel coordinates.

    A k-d tree is built on `voxel_list`. Starting from a seed voxel, all
    voxels within `cs_gap_nm` of already collected voxels are added to the
    current component until it stops growing; then the first not yet visited
    voxel seeds the next component.

    Args:
        voxel_list (array-like): Voxel coordinates, one row per voxel. Lists are converted to an array.
        cs_gap_nm (float): The maximum distance between two voxels to consider them as part of the same
            connected component. Must be in the same unit as `voxel_list`.
        max_concurrent_nodes (int, optional): The maximum number of voxels queried per iteration.
            Defaults to 5000.
        verbose (bool, optional): If True, log debug information. Defaults to False.

    Returns:
        list: A list of sets, where each set contains the indices of voxels that belong to the same
        connected component. Empty if `voxel_list` is empty.
    """
    voxel_list = np.asarray(voxel_list)
    n_voxels = len(voxel_list)
    if n_voxels == 0:
        return []

    kdtree = spatial.cKDTree(voxel_list)
    # checked[i] is True once voxel i was used as query point
    checked = np.zeros(n_voxels, dtype=bool)
    next_ids = np.array([0])
    ccs = [set(next_ids)]

    while True:
        if verbose:
            log_extraction.debug("NEXT - %d - %d" % (len(next_ids), int(np.sum(checked))))
            for cc in ccs:
                log_extraction.debug("N voxels in cc: %d" % (len(cc)))

        if len(next_ids) == 0:
            # current component is complete; seed a new one with the first unvisited voxel
            unchecked_ids = np.flatnonzero(~checked)
            if len(unchecked_ids) == 0:
                break
            next_ids = unchecked_ids[:1]
            ccs.append(set(next_ids))

        q_ids = kdtree.query_ball_point(voxel_list[next_ids], r=cs_gap_nm)
        checked[next_ids] = True
        for q_id in q_ids:
            ccs[-1].update(q_id)

        # continue with members of the current component that were not used as query point yet
        cc_ids = np.fromiter(ccs[-1], dtype=np.int64, count=len(ccs[-1]))
        next_ids = cc_ids[~checked[cc_ids]][:max_concurrent_nodes]
    return ccs
def map_objects_from_synssv_partners(wd: str, obj_version: Optional[str] = None,
                                     ssd_version: Optional[str] = None, n_jobs: Optional[int] = None,
                                     debug: bool = False, log: Logger = None,
                                     max_rep_coord_dist_nm: Optional[float] = None):
    """
    Map sub-cellular objects of the synaptic partners to 'syn_ssv' objects and store the result in their
    attribute dicts.

    Runs in three stages: (1) per cell, mitochondria ('mi') and vesicle cloud ('vc') statistics are
    computed for all its 'syn_ssv' objects and cached in ``cache_syn.pkl`` inside the cell's folder;
    (2) the cached values are copied into the 'syn_ssv' attribute dicts; (3) the cache files are deleted.

    The following keys will be available in the ``attr_dict`` of ``syn_ssv``-typed
    :class:`~syconn.reps.segmentation.SegmentationObject` (``%d`` is the index of the partner in
    'neuron_partners'):

    * 'n_mi_objs_%d': Number of 'mi' objects with a non-zero estimated voxel count close to the synapse.
    * 'n_mi_vxs_%d': Approximated number of 'mi' voxels close to the synapse.
    * 'min_dst_mi_nm_%d': Minimal distance to a 'mi' object in nm.
    * 'n_vc_objs_%d': Number of 'vc' objects with a non-zero estimated voxel count close to the synapse.
    * 'n_vc_vxs_%d': Approximated number of 'vc' voxels close to the synapse.
    * 'min_dst_vc_nm_%d': Minimal distance to a 'vc' object in nm.

    Args:
        wd (str): The working directory.
        obj_version (str, optional): The version of the 'syn_ssv' dataset. Defaults to None.
        ssd_version (str, optional): The version of the 'ssv' dataset. Defaults to None.
        n_jobs (int, optional): The number of chunks the work is split into. Defaults to None, in which
            case ``global_params.config.ncore_total * 4`` is used.
        debug (bool, optional): Passed to the multiprocessing call of the first stage. Defaults to False.
        log (Logger, optional): The logger to use. Defaults to None (module logger).
        max_rep_coord_dist_nm (float, optional): The maximum distance between the representative
            coordinates of a synapse and a sub-cellular object to consider the object at all. In
            nanometers. Defaults to None, in which case the value is read from the config
            (``['cell_objects']['max_rep_coord_dist_nm']``).
    """
    if n_jobs is None:
        n_jobs = global_params.config.ncore_total * 4
    if max_rep_coord_dist_nm is None:
        max_rep_coord_dist_nm = global_params.config['cell_objects']['max_rep_coord_dist_nm']

    # stage 1: per-cell caches
    ssd = super_segmentation.SuperSegmentationDataset(working_dir=wd, version=ssd_version)
    multi_params = [[wd, obj_version, ssd_version, ids_small_chunk, max_rep_coord_dist_nm]
                    for ids_small_chunk in chunkify(ssd.ssv_ids, n_jobs)]
    if qu.batchjob_enabled():
        _ = qu.batchjob_script(multi_params, "map_objects_from_synssv_partners", log=log,
                               remove_jobfolder=True)
    else:
        _ = sm.start_multiprocess_imap(_map_objects_from_synssv_partners_thread, multi_params, debug=debug)

    # stage 2: iterate over storage folders of syn_ssv objects
    sd_syn_ssv = segmentation.SegmentationDataset("syn_ssv", working_dir=wd, version=obj_version)
    multi_params = [[so_dir_paths, wd, obj_version, ssd_version]
                    for so_dir_paths in chunkify(sd_syn_ssv.so_dir_paths, n_jobs)]
    if qu.batchjob_enabled():
        _ = qu.batchjob_script(multi_params, "objects_from_cell_to_syn_dict", log=log,
                               remove_jobfolder=True)
    else:
        _ = sm.start_multiprocess_imap(_objects_from_cell_to_syn_dict, multi_params, debug=False)

    # stage 3: delete cache_dc of every cell
    if log is None:
        log = log_extraction
    log.debug('Deleting cache dictionaries now.')
    cleanup_params = [(ssv_id, ssd.config) for ssv_id in ssd.ssv_ids]
    sm.start_multiprocess_imap(_delete_all_cache_dc, cleanup_params,
                               nb_cpus=global_params.config['ncores_per_node'])
    log.debug('Deleted all cache dictionaries.')
def _map_objects_from_synssv_partners_thread(args: tuple):
    """
    Helper of 'map_objects_from_synssv_partners'. Maps cellular organelles ('mi', 'vc') to the 'syn_ssv'
    objects of every given cell; these values are needed for the RFC model executed in
    'classify_synssv_objects'.

    For every cell the results are written to ``cache_syn.pkl`` in the cell's folder with the keys
    'synssv_ids', 'n_mi_objs', 'n_mi_vxs', 'min_dst_mi_nm', 'n_vc_objs', 'n_vc_vxs' and 'min_dst_vc_nm'
    (arrays aligned with 'synssv_ids'). Existing cache files are currently always replaced.

    Args:
        args (tuple): ``(wd, obj_version, ssd_version, ssv_ids, max_rep_coord_dist_nm)``. See
            'map_objects_from_synssv_partners' for more details.
    """
    max_vert_dist_nm = global_params.config['cell_objects']['max_vert_dist_nm']
    # TODO: add global overwrite kwarg
    overwrite = True
    wd, obj_version, ssd_version, ssv_ids, max_rep_coord_dist_nm = args
    use_new_subfold = global_params.config.use_new_subfold
    sd_syn_ssv = segmentation.SegmentationDataset(obj_type="syn_ssv", working_dir=wd, version=obj_version)
    ssd = super_segmentation.SuperSegmentationDataset(working_dir=wd, version=ssd_version)
    sd_vc = segmentation.SegmentationDataset(obj_type="vc", working_dir=wd)
    sd_mi = segmentation.SegmentationDataset(obj_type="mi", working_dir=wd)

    syn_neuronpartners = sd_syn_ssv.load_numpy_data("neuron_partners")
    for ssv_id in ssv_ids:  # Iterate over cells
        ssv_o = ssd.get_super_segmentation_object(ssv_id)
        cache_path = ssv_o.ssv_dir + "/cache_syn.pkl"
        if overwrite and os.path.isfile(cache_path):
            os.remove(cache_path)
        cache_dc = AttributeDict(cache_path, read_only=False, disable_locking=True)
        if not overwrite and ('n_vc_vxs' in cache_dc):
            continue

        # all syn_ssv objects in which this cell is one of the two partners
        curr_ssv_mask = (syn_neuronpartners[:, 0] == ssv_id) | \
                        (syn_neuronpartners[:, 1] == ssv_id)
        synssv_ids = sd_syn_ssv.ids[curr_ssv_mask]
        n_synssv = len(synssv_ids)
        n_mi_objs = np.zeros((n_synssv,), dtype=np.int32)
        n_mi_vxs = np.zeros((n_synssv,), dtype=np.int32)
        n_vc_objs = np.zeros((n_synssv,), dtype=np.int32)
        n_vc_vxs = np.zeros((n_synssv,), dtype=np.int32)
        min_dst_mi = np.zeros((n_synssv,), dtype=np.float32)
        min_dst_vc = np.zeros((n_synssv,), dtype=np.float32)
        cache_dc['synssv_ids'] = synssv_ids
        if n_synssv == 0:
            # nothing to map - store the empty arrays so that the cache is complete
            cache_dc['n_mi_objs'] = n_mi_objs
            cache_dc['n_mi_vxs'] = n_mi_vxs
            cache_dc['n_vc_objs'] = n_vc_objs
            cache_dc['n_vc_vxs'] = n_vc_vxs
            cache_dc['min_dst_mi_nm'] = min_dst_mi
            cache_dc['min_dst_vc_nm'] = min_dst_vc
            cache_dc.push()
            continue

        # restrict organelle datasets to the objects mapped to this cell
        vc_mask = np.in1d(sd_vc.ids, ssv_o.vc_ids)
        mi_mask = np.in1d(sd_mi.ids, ssv_o.mi_ids)
        vc_ids = sd_vc.ids[vc_mask]
        mi_ids = sd_mi.ids[mi_mask]
        vc_sizes = sd_vc.sizes[vc_mask]
        mi_sizes = sd_mi.sizes[mi_mask]

        kdtree_synssv = spatial.cKDTree(sd_syn_ssv.rep_coords[curr_ssv_mask] * sd_syn_ssv.scaling)
        # vesicle clouds
        kdtree_vc = spatial.cKDTree(sd_vc.rep_coords[vc_mask] * sd_vc.scaling)
        # mitos
        kdtree_mi = spatial.cKDTree(sd_mi.rep_coords[mi_mask] * sd_mi.scaling)

        # returns a list of neighboring objects for every synssv (note: ix is the index within the
        # masked mi_ids / vc_ids arrays)
        close_mi_ixs = kdtree_synssv.query_ball_tree(kdtree_mi, r=max_rep_coord_dist_nm)
        close_vc_ixs = kdtree_synssv.query_ball_tree(kdtree_vc, r=max_rep_coord_dist_nm)

        # load the meshes of all close-by objects in bulk
        close_mi_ids = mi_ids[np.unique(np.concatenate(close_mi_ixs)).astype(np.int32)]
        close_vc_ids = vc_ids[np.unique(np.concatenate(close_vc_ixs).astype(np.int32))]
        md_mi = seghelp.load_so_meshes_bulk(sd_mi.get_segmentation_object(close_mi_ids),
                                            use_new_subfold=use_new_subfold)
        md_vc = seghelp.load_so_meshes_bulk(sd_vc.get_segmentation_object(close_vc_ids),
                                            use_new_subfold=use_new_subfold)

        for ii, synssv_id in enumerate(synssv_ids):
            synssv_obj = sd_syn_ssv.get_segmentation_object(synssv_id)
            mis = sd_mi.get_segmentation_object(mi_ids[close_mi_ixs[ii]])
            # set cached sizes and meshes
            for jj, ix in enumerate(close_mi_ixs[ii]):
                mi = mis[jj]
                mi._size = mi_sizes[ix]
                mi._mesh = md_mi[mi.id]
            vcs = sd_vc.get_segmentation_object(vc_ids[close_vc_ixs[ii]])
            for jj, ix in enumerate(close_vc_ixs[ii]):
                vc = vcs[jj]
                vc._size = vc_sizes[ix]
                vc._mesh = md_vc[vc.id]
            n_mi_objs[ii], n_mi_vxs[ii], min_dst_mi[ii] = _map_objects_from_synssv(
                synssv_obj, mis, max_vert_dist_nm['mi'])
            n_vc_objs[ii], n_vc_vxs[ii], min_dst_vc[ii] = _map_objects_from_synssv(
                synssv_obj, vcs, max_vert_dist_nm['vc'])

        cache_dc['n_mi_objs'] = n_mi_objs
        cache_dc['n_mi_vxs'] = n_mi_vxs
        cache_dc['min_dst_mi_nm'] = min_dst_mi
        cache_dc['n_vc_objs'] = n_vc_objs
        cache_dc['n_vc_vxs'] = n_vc_vxs
        cache_dc['min_dst_vc_nm'] = min_dst_vc
        cache_dc.push()


def _map_objects_from_synssv(synssv_o, seg_objs, max_vert_dist_nm, sample_fact=2):
    """
    Map cellular organelles to a single 'syn_ssv' object.

    For every object in `seg_objs` the fraction of its (sub-sampled) mesh vertices that lies within
    `max_vert_dist_nm` of a (sub-sampled) synapse voxel is multiplied with the object's size to estimate
    the number of close-by object voxels. Objects with an empty mesh contribute 0 voxels and do not
    affect the minimal distance.

    Note:
        Loading meshes for approximating close-by object volume is slow - consider exchanging with
        summed object size?

    Args:
        synssv_o: 'syn_ssv' synapse object; `voxel_list` and `scaling` are used.
        seg_objs: Iterable of SegmentationObjects of type 'vc' or 'mi'; `mesh` and `size` are used.
        max_vert_dist_nm: Query radius for SegmentationObject vertices in nm. Used to estimate the
            number of nearby object voxels.
        sample_fact: Only use every Xth synapse voxel and every Xth mesh vertex.

    Returns:
        n_objects: Number of SegmentationObjects with an estimated close-by voxel count > 0.
        n_vxs: Approximated number of object voxels within `max_vert_dist_nm`.
        min_dist: Minimal vertex-to-voxel distance (in nm); 1e12 if no vertex of any object is within
            `max_vert_dist_nm`.
    """
    synssv_kdtree = spatial.cKDTree(synssv_o.voxel_list[::sample_fact] * synssv_o.scaling)
    min_dist = 1e12  # in nm
    n_obj_vxs = []
    for obj in seg_objs:
        # use mesh vertices instead of voxels
        obj_vxs = obj.mesh[1].reshape(-1, 3)[::sample_fact]
        if len(obj_vxs) == 0:
            # empty mesh: avoids division by zero and np.min on an empty array
            n_obj_vxs.append(0.0)
            continue
        # misses (beyond distance_upper_bound) are reported with infinite distance
        ds, _ = synssv_kdtree.query(obj_vxs, distance_upper_bound=max_vert_dist_nm)
        # surface fraction of subcellular object which is close to synapse
        close_frac = np.sum(ds < np.inf) / len(obj_vxs)
        closest = np.min(ds)
        if closest < min_dist:
            min_dist = closest
        # estimate number of voxels by close-by surface area fraction times total number of voxels
        n_obj_vxs.append(close_frac * obj.size)
    n_obj_vxs = np.array(n_obj_vxs)
    n_objects = np.sum(n_obj_vxs > 0)
    n_vxs = np.sum(n_obj_vxs)
    return n_objects, n_vxs, min_dist


def _objects_from_cell_to_syn_dict(args):
    """
    Helper of 'map_objects_from_synssv_partners'. Copies the per-cell values cached in ``cache_syn.pkl``
    into the attribute dicts of the 'syn_ssv' objects stored in the given folders.

    For partner index ``ii`` (position in 'neuron_partners') the keys 'n_mi_objs_{ii}', 'n_mi_vxs_{ii}',
    'n_vc_objs_{ii}', 'n_vc_vxs_{ii}', 'min_dst_mi_nm_{ii}' and 'min_dst_vc_nm_{ii}' are written.

    Args:
        args (tuple): ``(so_dir_paths, wd, obj_version, ssd_version)``. See
            'map_objects_from_synssv_partners' for more details.

    Raises:
        ValueError: If a 'syn_ssv' ID is not found exactly once in the cache of one of its partner cells.
    """
    so_dir_paths, wd, obj_version, ssd_version = args
    ssd = super_segmentation.SuperSegmentationDataset(working_dir=wd, version=ssd_version)
    sd_syn_ssv = segmentation.SegmentationDataset(obj_type="syn_ssv", working_dir=wd, version=obj_version)
    cached_keys = ('n_mi_objs', 'n_mi_vxs', 'n_vc_objs', 'n_vc_vxs', 'min_dst_mi_nm', 'min_dst_vc_nm')

    for so_dir_path in so_dir_paths:
        this_attr_dc = AttributeDict(so_dir_path + "/attr_dict.pkl", read_only=False, disable_locking=True)
        for synssv_id in this_attr_dc.keys():
            synssv_o = sd_syn_ssv.get_segmentation_object(synssv_id)
            synssv_o.load_attr_dict()
            map_dc = dict()
            for ii, ssv_partner_id in enumerate(synssv_o.attr_dict["neuron_partners"]):
                ssv_o = ssd.get_super_segmentation_object(ssv_partner_id)
                cache_dc = AttributeDict(ssv_o.ssv_dir + "/cache_syn.pkl")
                # position of this synapse within the partner's cached arrays
                index = np.transpose(np.nonzero(cache_dc['synssv_ids'] == synssv_id))
                if len(index) != 1:
                    msg = "Partner cell ID mismatch."
                    log_extraction.error(msg)
                    raise ValueError(msg)
                index = index[0][0]
                for key in cached_keys:
                    map_dc[f'{key}_{ii}'] = cache_dc[key][index]
            synssv_o.attr_dict.update(map_dc)
            this_attr_dc[synssv_id] = synssv_o.attr_dict
        this_attr_dc.push()
def classify_synssv_objects(wd, obj_version=None, log=None, nb_cpus=None):
    """
    Classify 'syn_ssv' objects into synaptic or non-synaptic using an RFC model.

    The storage folders of the 'syn_ssv' dataset are split into
    ``global_params.config.ncore_total`` chunks which are processed by
    :func:`_classify_synssv_objects_thread`, either as batch jobs or with local
    multiprocessing. The result is stored in the attribute dicts of the objects.
    For the required attributes see `synssv_o_features`.

    Args:
        wd (str): Working directory.
        obj_version (str): Version of the 'syn_ssv' dataset.
        log (Logger): Logger passed to the batch job handler.
        nb_cpus (int): Number of CPUs used for local multiprocessing.

    Returns:
        None
    """
    sd_syn_ssv = segmentation.SegmentationDataset("syn_ssv", working_dir=wd, version=obj_version)
    path_chunks = chunkify(sd_syn_ssv.so_dir_paths, global_params.config.ncore_total)
    multi_params = []
    for so_dir_paths in path_chunks:
        multi_params.append((so_dir_paths, wd, obj_version))

    if qu.batchjob_enabled():
        _ = qu.batchjob_script(multi_params, "classify_synssv_objects", log=log, remove_jobfolder=True)
    else:
        _ = sm.start_multiprocess_imap(_classify_synssv_objects_thread, multi_params, nb_cpus=nb_cpus)
def _classify_synssv_objects_thread(args):
    """
    Helper of 'classify_synssv_objects'.

    Loads the RFC from ``global_params.config.mpath_syn_rfc`` (or from
    ``mpath_syn_rfc_fallback`` if loading raises an ImportError), predicts the
    probability of class 1 for every 'syn_ssv' object in the given storage
    folders and stores it as 'syn_prob' in the object's attribute dict.

    Args:
        args (tuple): ``(so_dir_paths, wd, obj_version)``. See 'classify_synssv_objects' for more details.
    """
    so_dir_paths, wd, obj_version = args
    sd_syn_ssv = segmentation.SegmentationDataset(obj_type="syn_ssv", working_dir=wd, version=obj_version)

    try:
        rfc = joblib.load(global_params.config.mpath_syn_rfc)
    except ImportError:
        rfc = joblib.load(global_params.config.mpath_syn_rfc_fallback)

    for folder in so_dir_paths:
        folder_attr_dc = AttributeDict(folder + "/attr_dict.pkl", read_only=False)
        for obj_id in folder_attr_dc.keys():
            syn_obj = sd_syn_ssv.get_segmentation_object(obj_id)
            # use the attributes from the opened storage instead of loading them per object
            syn_obj.attr_dict = folder_attr_dc[obj_id]
            class_probas = rfc.predict_proba([synssv_o_features(syn_obj)])
            syn_obj.attr_dict["syn_prob"] = class_probas[0][1]
            folder_attr_dc[obj_id] = syn_obj.attr_dict
        folder_attr_dc.push()
def write_conn_gt_kzips(conn, n_objects, folder):
    """
    Write .k.zip summary files for a random subset of objects.

    `n_objects` IDs are drawn from `conn.ids` without replacement. For each of
    them the object itself, its mesh and a skeleton annotation with a single
    node at the object's representative coordinate are written to
    ``<folder>/obj_<id>.k.zip``.

    Args:
        conn: Dataset providing `ids` and `get_segmentation_object`.
        n_objects: Number of objects to draw.
        folder: Folder to write the .k.zip files to; created if it does not exist.
    """
    if not os.path.exists(folder):
        os.makedirs(folder)

    sampled_ixs = np.random.choice(len(conn.ids), n_objects, replace=False)
    for conn_id in conn.ids[sampled_ixs]:
        kzip_path = folder + "/obj_%d.k.zip" % conn_id
        obj = conn.get_segmentation_object(conn_id)
        obj.save_kzip(kzip_path)
        obj.mesh2kzip(kzip_path)

        # single-node annotation marking the representative coordinate
        anno = skeleton.SkeletonAnnotation()
        anno.scaling = obj.scaling
        anno.comment = "rep coord - %d" % obj.size
        x, y, z = obj.rep_coord[0], obj.rep_coord[1], obj.rep_coord[2]
        anno.addNode(skeleton.SkeletonNode().from_scratch(anno, x, y, z, radius=1))
        skeleton_utils.write_skeleton(kzip_path, [anno])
def create_syn_rfc(sd_syn_ssv: 'segmentation.SegmentationDataset', path2file: str, overwrite: bool = False,
                   rfc_path_out: str = None, max_dist_vx: int = 20) -> \
        Tuple[ensemble.RandomForestClassifier, np.ndarray, np.ndarray]:
    """
    Trains a random forest classifier (RFC) to distinguish between synaptic and non-synaptic
    objects. Features are generated from the objects in `sd_syn_ssv` associated with the annotated
    coordinates stored in `path2file`.
    If `rfc_path_out` is None, the trained classifier is written to
    ``global_params.config.mpath_syn_rfc``. Cross-validation results, two feature plots and a k.zip
    with the mapped objects (``mapped_synssv.k.zip``) are written next to the classifier.

    Args:
        sd_syn_ssv (segmentation.SegmentationDataset): SegmentationDataset object of type ``syn_ssv``.
            Used to identify synaptic object candidates annotated in the kzip/xls file at `path2file`.
        path2file (str): Path to kzip file with synapse labels as node comments
            ("non-synaptic", "synaptic"; labels used for classifier are 0 and 1 respectively). Any path
            not ending with 'k.zip' is read as Excel table.
        overwrite (bool): If True, existing files will be replaced. Defaults to False.
        rfc_path_out (str): Filename for dumped RFC. If None, the default path is used.
        max_dist_vx (int): Maximum voxel distance (sum of absolute coordinate differences) between
            sample and target. Samples without a target object closer than this are dropped.
            Defaults to 20.

    Returns:
        Tuple[ensemble.RandomForestClassifier, np.ndarray, np.ndarray]: The trained random forest
        classifier and the feature and label data.

    Raises:
        FileExistsError: If the RFC file or the k.zip with mapped objects exists and `overwrite` is False.
        ValueError: If the GT file contains no usable label or no sample could be mapped to an object.
    """
    log = log_extraction
    if global_params.config.working_dir is not None or rfc_path_out is not None:
        if rfc_path_out is None:
            model_base_dir = os.path.split(global_params.config.mpath_syn_rfc)[0]
            log.info(f'Working directory is set to {global_params.config.working_dir} - '
                     f'trained RFC will be dumped at {model_base_dir}.')
            os.makedirs(model_base_dir, exist_ok=True)
            rfc_path_out = global_params.config.mpath_syn_rfc
        else:
            log = initialize_logging('create_syn_rfc', os.path.dirname(rfc_path_out))
        if os.path.isfile(rfc_path_out) and not overwrite:
            msg = f'RFC file already exists at "{rfc_path_out}", but overwrite was set to False.'
            log.error(msg)
            raise FileExistsError(msg)
    assert sd_syn_ssv.type == 'syn_ssv'
    log.info(f'Initiated RFC fitting procedure with GT file "{path2file}" and {sd_syn_ssv}.')

    # NOTE(review): if neither a working directory nor rfc_path_out is set, rfc_path_out is still None
    # here and os.path.split fails - the "not dumped" branch at the end looks unreachable; confirm.
    out_dir = os.path.split(rfc_path_out)[0]
    mapped_synssv_objects_kzip = f'{out_dir}/mapped_synssv.k.zip'
    if os.path.isfile(mapped_synssv_objects_kzip):
        if not overwrite:
            raise FileExistsError(f'File with mapped synssv objects already exists '
                                  f'at "{mapped_synssv_objects_kzip}"')
        os.remove(mapped_synssv_objects_kzip)

    # -- collect ground truth labels and coordinates
    label_coords = []
    labels = []
    if path2file.endswith('k.zip'):
        anno = skeleton_utils.load_skeleton(path2file)['Synapse annotation']
        for node in anno.getNodes():
            c = node.getComment()
            if c not in ('synaptic', 'non-synaptic'):
                continue
            labels.append(c)
            label_coords.append(np.array(node.getCoordinate()))
    else:
        table = pandas.read_excel(path2file, header=0, names=[
            'ixs', 'coord', 'pre', 'post', 'syn', 'doublechecked', 'triplechecked', '?', 'comments']).values
        table = table[:, :7]
        for row in table:
            # 'triplechecked' takes precedence over 'doublechecked'; empty cells are NaN (float)
            c_orig = row[5]
            c = row[6]
            if isinstance(c, str) and 'yes' in c:
                unified_comment = 'synaptic'
            elif isinstance(c, str) and 'no' in c:
                unified_comment = 'non-synaptic'
            elif isinstance(c_orig, str) and 'yes' in c_orig:
                unified_comment = 'synaptic'
            elif isinstance(c_orig, str) and 'no' in c_orig:
                unified_comment = 'non-synaptic'
            else:
                log.warning(f'Did not understand GT comment "{c}". Skipping')
                continue
            labels.append(unified_comment)
            label_coords.append(np.array(row[1].split(','), dtype=np.float32))
    if len(labels) == 0:
        raise ValueError(f'No synapse labels found in "{path2file}".')

    labels = np.array(labels)
    label_coords = np.array(label_coords)

    # get deterministic order by sorting by coordinate first and then seeded shuffling
    ixs = [i[0] for i in sorted(enumerate(label_coords),
                                key=lambda x: [x[1][0], x[1][1], x[1][2]])]
    ixs_random = np.arange(len(ixs))
    np.random.seed(0)
    np.random.shuffle(ixs_random)
    ixs = np.array(ixs)
    label_coords = label_coords[ixs][ixs_random]
    labels = labels[ixs][ixs_random]

    # -- map label coordinates to syn_ssv objects
    log.info('Setting up kd-trees for coord-to-synapse mapping.')
    conn_kdtree = spatial.cKDTree(sd_syn_ssv.rep_coords * sd_syn_ssv.scaling)
    ds, list_ids = conn_kdtree.query(label_coords * sd_syn_ssv.scaling)
    synssv_ids = sd_syn_ssv.ids[list_ids]
    mask = np.ones(synssv_ids.shape, dtype=bool)
    log.info(f'Mapped {len(labels)} GT coordinates to {sd_syn_ssv.type}-objects.')
    # never request more neighbors than objects exist (missing neighbors get an out-of-range index)
    n_candidates = min(20, len(sd_syn_ssv.ids))
    for label_id in np.where(ds > 0)[0]:
        dists, close_ids = conn_kdtree.query(label_coords[label_id] * sd_syn_ssv.scaling, k=n_candidates)
        dists, close_ids = np.atleast_1d(dists), np.atleast_1d(close_ids)
        for close_id in close_ids[np.argsort(dists)]:
            conn_o = sd_syn_ssv.get_segmentation_object(sd_syn_ssv.ids[close_id])
            vx_ds = np.sum(np.abs(conn_o.voxel_list - label_coords[label_id]), axis=-1)
            if np.min(vx_ds) < max_dist_vx:
                synssv_ids[label_id] = sd_syn_ssv.ids[close_id]
                break
        else:
            # no candidate object is close enough - drop this sample
            mask[label_id] = 0

    if np.sum(mask) == 0:
        raise ValueError(f'None of the GT coordinates has a distance < {max_dist_vx} vx to a '
                         f'{sd_syn_ssv.type}-object.')
    synssv_ids = synssv_ids[mask]
    labels = labels[mask]
    # keep coordinates aligned with IDs and labels (used for the GT nodes below)
    label_coords = label_coords[mask]
    log.info(f'Found {np.sum(mask)}/{len(mask)} samples with a distance < {max_dist_vx} vx to the target.')

    # -- feature generation
    log.info('Synapse features will now be generated.')
    features = []
    pbar = tqdm.tqdm(total=len(synssv_ids), leave=False)
    for synssv_id in synssv_ids:
        synssv_o = sd_syn_ssv.get_segmentation_object(synssv_id)
        features.append(synssv_o_features(synssv_o))
        pbar.update(1)
    pbar.close()
    features = np.array(features)

    # -- cross validation
    log.info('Performing 10-fold cross validation.')
    rfc = ensemble.RandomForestClassifier(n_estimators=2000, max_features='sqrt',
                                          n_jobs=-1, random_state=0, oob_score=True)
    mask_annotated = (labels == "synaptic") | (labels == 'non-synaptic')
    v_features = features[mask_annotated]
    v_labels = labels[mask_annotated]
    v_labels = (v_labels == "synaptic").astype(np.int32)
    feature_names = np.array(synssv_o_featurenames())
    probas = cross_val_predict(rfc, v_features, v_labels, cv=10, method='predict_proba')
    preds = np.argmax(probas, axis=1)
    log.info(metrics.classification_report(v_labels, preds, target_names=['non-synaptic', 'synaptic']))

    if rfc_path_out is not None:
        import matplotlib.pyplot as plt
        import seaborn
        log.info(f'Wrote 10-fold cross validation probas, predictions, features and labels of trained RFC to '
                 f'"{os.path.split(rfc_path_out)[0]}".')
        np.save(out_dir + '/rfc_probas.npy', probas)
        np.save(out_dir + '/rfc_preds.npy', preds)
        np.save(out_dir + '/rfc_labels.npy', v_labels)
        np.save(out_dir + '/rfc_features.npy', v_features)
        np.save(out_dir + '/rfc_feature_names.npy', feature_names)

        plt.figure()
        df = pandas.DataFrame(data=dict(mesh_area=v_features[:, feature_names == 'mesh_area_um2'].flatten(),
                                        size=v_features[:, feature_names == 'size_vx'].flatten(),
                                        correct=preds == v_labels,
                                        gt=['syn' if el == 1 else 'non-synaptic' for el in v_labels]))
        ax = seaborn.jointplot(data=df, x='size', y='mesh_area', hue='correct',
                               xlim=(0, df['size'].max()*1.1), ylim=(0, df['mesh_area'].max()*1.1))
        ax.set_axis_labels('size vx [1]', 'mesh area [um^2]')
        plt.savefig(out_dir + '/feature_hist_size_vs_area_pred.png')
        plt.close()

        plt.figure()
        ax = seaborn.jointplot(data=df, x='size', y='mesh_area', hue='gt',
                               xlim=(0, df['size'].max()*1.1), ylim=(0, df['mesh_area'].max()*1.1))
        ax.set_axis_labels('size vx [1]', 'mesh area [um^2]')
        plt.savefig(out_dir + '/feature_hist_size_vs_area_gt.png')
        plt.close()

    # -- final fit on all samples
    rfc.fit(v_features, v_labels)
    acc = rfc.score(v_features, v_labels)
    log.info(f'Training set accuracy: {acc:.4f}')
    feature_imp = rfc.feature_importances_
    assert len(feature_imp) == len(feature_names)
    log.info('RFC importances:\n' + "\n".join(
        [f"{feature_names[ii]}: {feature_imp[ii]}" for ii in range(len(feature_imp))]))

    # -- export mapped objects for manual revision
    log.info(f'Synapses will be annotated and written to "{mapped_synssv_objects_kzip}" for manual revision.')
    skel = skeleton.Skeleton()
    anno = skeleton.SkeletonAnnotation()
    anno.scaling = sd_syn_ssv.scaling
    pbar = tqdm.tqdm(total=len(synssv_ids), leave=False)
    for kk, synssv_id in enumerate(synssv_ids):
        synssv_o = sd_syn_ssv.get_segmentation_object(synssv_id)
        # node at the object's representative coordinate carrying prediction and features
        rep_coord = synssv_o.rep_coord * sd_syn_ssv.scaling
        pred_correct = preds[kk] == v_labels[kk]
        n = skeleton.SkeletonNode().from_scratch(anno, rep_coord[0], rep_coord[1], rep_coord[2])
        n.setComment(f'{preds[kk]} {pred_correct} {probas[kk][1]:.2f}')
        n.data.update({k: v for k, v in zip(feature_names, v_features[kk])})
        anno.addNode(n)
        # node at the annotated GT coordinate, connected to the object node
        rep_coord = label_coords[kk] * sd_syn_ssv.scaling
        n_l = skeleton.SkeletonNode().from_scratch(anno, rep_coord[0], rep_coord[1], rep_coord[2])
        n_l.setComment('gt node; {}'.format(labels[kk]))
        if not pred_correct:
            # only meshes of misclassified objects are exported
            synssv_o.mesh2kzip(mapped_synssv_objects_kzip, ext_color=None,
                               ply_name='{}.ply'.format(synssv_id))
        anno.addNode(n_l)
        anno.addEdge(n, n_l)
        pbar.update(1)
    pbar.close()
    skel.add_annotation(anno)
    skel.to_kzip(mapped_synssv_objects_kzip)

    if rfc_path_out is not None:
        joblib.dump(rfc, rfc_path_out)
        log.info(f'Wrote parameters of trained RFC to "{rfc_path_out}".')
    else:
        log.info('Working directory and rfc_path_out not set - trained RFC was not dumped to file.')
    return rfc, v_features, v_labels
def synssv_o_features(synssv_o: 'segmentation.SegmentationObject') -> list:
    """
    Collect the 'syn_ssv' features used for synapse prediction with an RFC.

    The feature vector starts with the object's size and mesh area, followed
    by six organelle attributes per entry in 'neuron_partners' (in partner
    order): 'n_mi_objs', 'n_mi_vxs', 'min_dst_mi_nm', 'n_vc_objs', 'n_vc_vxs'
    and 'min_dst_vc_nm', each suffixed with the partner index.

    Args:
        synssv_o (segmentation.SegmentationObject): The SegmentationObject for which to collect features.
            Its ``attr_dict`` must already contain the keys listed above.

    Returns:
        list: A list of features for the given SegmentationObject.
    """
    feats = [synssv_o.size, synssv_o.mesh_area]
    attrs = synssv_o.attr_dict
    per_partner_keys = ("n_mi_objs_%d", "n_mi_vxs_%d", "min_dst_mi_nm_%d",
                        "n_vc_objs_%d", "n_vc_vxs_%d", "min_dst_vc_nm_%d")
    for partner_ix in range(len(attrs["neuron_partners"])):
        feats.extend(attrs[key_template % partner_ix] for key_template in per_partner_keys)
    return feats
def synssv_o_featurenames() -> list:
    """
    Return the names of the features used for synapse prediction.

    The order is: object size, mesh area, then six organelle features for
    neuron 1 followed by the same six for neuron 2.

    Returns:
        list: A list of feature names.
    """
    names = ['size_vx', 'mesh_area_um2']
    organelle_feats = ('n_mi_objs', 'n_mi_vxs', 'min_dst_mi_nm', 'n_vc_objs', 'n_vc_vxs', 'min_dst_vc_nm')
    for neuron in ('neuron1', 'neuron2'):
        for feat in organelle_feats:
            names.append(feat + '_' + neuron)
    return names
def export_matrix(obj_version: Optional[str] = None, dest_folder: Optional[str] = None,
                  threshold_syn: float = 0, export_kzip: bool = False, log: Optional[Logger] = None):
    """
    Exports the connectivity matrix as a tab-separated .csv file and optionally as a .k.zip file.

    Only 'syn_ssv' objects with 'syn_prob' > `threshold_syn` are exported. The 'size' column holds
    half the mesh area multiplied with the synaptic sign. An existing ``conn_mat.csv`` /
    ``conn_mat.k.zip`` is renamed with a time stamp instead of being overwritten.

    Args:
        obj_version (str, optional): Version of the 'syn_ssv' dataset. Defaults to None.
        dest_folder (str, optional): Destination folder for the exported file; created if missing.
            Defaults to None, in which case ``<working_dir>/connectivity_matrix/`` is used.
        threshold_syn (float, optional): Minimum synapse probability (exclusive). If None, the value is
            read from the config (``['cell_objects']['thresh_synssv_proba']``). Defaults to 0.
        export_kzip (bool, optional): If True, additionally exports the synapses as skeleton nodes to a
            .k.zip file, one annotation per compartment pair.
            Note that this can result in large memory consumption. Defaults to False.
        log (Logger, optional): Logger for logging the process. Defaults to None (module logger).
    """
    if threshold_syn is None:
        threshold_syn = global_params.config['cell_objects']['thresh_synssv_proba']
    if dest_folder is None:
        dest_folder = global_params.config.working_dir + '/connectivity_matrix/'
    if log is None:
        log = log_extraction
    # create the folder itself (not only its parent) also if no trailing slash was given
    os.makedirs(dest_folder, exist_ok=True)
    dest_name = dest_folder + '/conn_mat'
    log.info(f'Starting export of connectivity matrix with minimum probability {threshold_syn} as csv file to "{dest_name}".')
    sd_syn_ssv = segmentation.SegmentationDataset("syn_ssv", working_dir=global_params.config.working_dir,
                                                  version=obj_version)

    syn_prob = sd_syn_ssv.load_numpy_data("syn_prob")

    m = syn_prob > threshold_syn
    m_axs = sd_syn_ssv.load_numpy_data("partner_axoness")[m]
    m_cts = sd_syn_ssv.load_numpy_data("partner_celltypes")[m]
    m_sp = sd_syn_ssv.load_numpy_data("partner_spiness")[m]
    m_coords = sd_syn_ssv.rep_coords[m]
    # half of the mesh area is used as synapse size
    m_sizes = sd_syn_ssv.load_numpy_data("mesh_area")[m] / 2
    m_ssv_partners = sd_syn_ssv.load_numpy_data("neuron_partners")[m]
    m_syn_prob = syn_prob[m]
    m_syn_sign = sd_syn_ssv.load_numpy_data("syn_sign")[m]
    m_syn_asym_ratio = sd_syn_ssv.load_numpy_data("syn_type_sym_ratio")[m]
    m_spineheadvol = sd_syn_ssv.load_numpy_data("partner_spineheadvol")[m]
    m_latent_morph = sd_syn_ssv.load_numpy_data("latent_morph")[m]  # N, 2, m
    m_latent_morph = m_latent_morph.reshape(len(m_latent_morph), -1)  # N, 2*m

    # make sure cache-arrays have ndim == 2, TODO: check when writing cached arrays
    m_sizes = np.multiply(m_sizes, m_syn_sign).squeeze()[:, None]  # N, 1
    m_axs = m_axs.squeeze()  # N, 2
    m_sp = m_sp.squeeze()  # N, 2
    m_syn_prob = m_syn_prob.squeeze()[:, None]  # N, 1
    table = np.concatenate([m_coords, m_ssv_partners, m_sizes, m_axs, m_cts, m_sp, m_syn_prob,
                            m_spineheadvol, m_latent_morph], axis=1)

    # do not overwrite previous files
    if os.path.isfile(dest_name + '.csv'):
        st = datetime.datetime.fromtimestamp(time.time()).strftime('%Y-%m-%d %H:%M:%S')
        os.rename(dest_name + '.csv', '{}_{}.csv'.format(dest_name, st))

    np.savetxt(dest_name + ".csv", table, delimiter="\t",
               header="x\ty\tz\tssv1\tssv2\tsize\tcomp1\tcomp2\tcelltype1\t"
                      "celltype2\tspiness1\tspiness2\tsynprob\tspinehead_vol1"
                      "\tspinehead_vol2" +
                      "".join(["\tlatentmorph1_{}".format(ix) for ix in range(
                          global_params.config['tcmn']['ndim_embedding'])]) +
                      "".join(["\tlatentmorph2_{}".format(ix) for ix in range(
                          global_params.config['tcmn']['ndim_embedding'])])
               )

    if export_kzip:
        ax_labels = np.array(["N/A", "D", "A", "S"])  # TODO: this is already defined in handler.multiviews!
        ax_label_ids = np.array([-1, 0, 1, 2])
        annotations = []
        m_sizes = np.abs(m_sizes)

        ms_axs = np.sort(m_axs, axis=1)
        # transform labels 3 and 4 to 1 (bouton and terminal to axon to apply correct filter)
        ms_axs[ms_axs == 3] = 1
        ms_axs[ms_axs == 4] = 1
        try:
            u_axs = np.unique(ms_axs, axis=0)
        except TypeError:  # in case numpy < 1.13
            u_axs = np.vstack(list({tuple(row) for row in ms_axs}))
        # one annotation per unique (sorted) compartment pair
        for u_ax in u_axs:
            anno = skeleton.SkeletonAnnotation()
            anno.scaling = sd_syn_ssv.scaling
            cmt = "{} - {}".format(ax_labels[ax_label_ids == u_ax[0]][0],
                                   ax_labels[ax_label_ids == u_ax[1]][0])
            anno.comment = cmt
            for i_syn in np.where(np.sum(np.abs(ms_axs - u_ax), axis=1) == 0)[0]:
                c = m_coords[i_syn]
                # somewhat approximated from sphere volume:
                r = np.power(m_sizes[i_syn] / 3., 1 / 3.)
                skel_node = skeleton.SkeletonNode().from_scratch(anno, c[0], c[1], c[2], radius=r)
                skel_node.data["ids"] = m_ssv_partners[i_syn]
                skel_node.data["size"] = m_sizes[i_syn]
                skel_node.data["syn_prob"] = m_syn_prob[i_syn]
                skel_node.data["sign"] = m_syn_sign[i_syn]
                skel_node.data["in_ex_frac"] = m_syn_asym_ratio[i_syn]
                skel_node.data['sp'] = m_sp[i_syn]
                skel_node.data['ct'] = m_cts[i_syn]
                skel_node.data['ax'] = m_axs[i_syn]
                skel_node.data['latent_morph'] = m_latent_morph[i_syn]
                anno.addNode(skel_node)
            annotations.append(anno)

        # do not overwrite previous files
        if os.path.isfile(dest_name + '.k.zip'):
            st = datetime.datetime.fromtimestamp(time.time()).strftime('%Y-%m-%d %H:%M:%S')
            os.rename(dest_name + '.k.zip', '{}_{}.k.zip'.format(dest_name, st))
        skeleton_utils.write_skeleton(dest_name + ".k.zip", annotations)