# -*- coding: utf-8 -*-
# SyConn - Synaptic connectivity inference toolkit
#
# Copyright (c) 2016 - now
# Max Planck Institute of Neurobiology, Martinsried, Germany
# Authors: Philipp Schubert, Joergen Kornfeld
import datetime
import os
import shutil
import time
from collections import defaultdict
from logging import Logger
from typing import Optional, Dict, List, Tuple, TYPE_CHECKING
from itertools import chain
import joblib
import numpy as np
import tqdm
import pandas
from knossos_utils import skeleton_utils, skeleton
from scipy import spatial
from sklearn import ensemble
from sklearn.model_selection import cross_val_predict
from sklearn import metrics
try:
import open3d as o3d
except ImportError:
pass # for sphinx build
from . import log_extraction
from .. import global_params
from ..backend.storage import AttributeDict, VoxelStorageDyn, MeshStorage, VoxelStorageLazyLoading
from ..handler.basics import chunkify
from ..handler.config import initialize_logging
from ..mp import batchjob_utils as qu
from ..mp import mp_utils as sm
from ..reps import segmentation_helper as seghelp
from ..reps.super_segmentation_dataset import filter_ssd_by_total_pathlength
from ..reps import super_segmentation, segmentation, connectivity_helper as ch
from ..reps.rep_helper import subfold_from_ix, ix_from_subfold, get_unique_subfold_ixs
from ..proc.meshes import gen_mesh_voxelmask, calc_contact_syn_mesh
def collect_properties_from_ssv_partners(wd, obj_version=None, ssd_version=None, debug=False):
"""
Collects axoness, cell types and spiness from the synaptic partners and stores
them in the 'syn_ssv' objects. Also maps syn_type_sym_ratio to the synaptic sign
(-1 for symmetric, 1 for asymmetric synapses).
The following keys will be available in the ``attr_dict`` of ``syn_ssv``
typed :class:`~syconn.reps.segmentation.SegmentationObject`:
* 'partner_axoness': Cell compartment type (axon: 1, dendrite: 0, soma: 2,
en-passant bouton: 3, terminal bouton: 4) of the partner neurons.
* 'partner_spiness': Spine compartment predictions (0: dendritic shaft,
1: spine head, 2: spine neck, 3: other) of both neurons.
* 'partner_spineheadvol': Spinehead volume in µm^3.
* 'partner_celltypes': Cell types of both neurons.
* 'latent_morph': Local morphology embeddings of the pre- and post-
synaptic partners.
Args:
wd (str): Path to the working directory.
obj_version (str): Version of the 'syn_ssv' SegmentationDataset.
ssd_version (str): Version of the SuperSegmentationDataset.
debug (bool): If True, run the multiprocessing calls in debug mode.
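Examples:
Minimal usage sketch (the working directory path is a placeholder; versions fall back to the dataset defaults):
>>> collect_properties_from_ssv_partners('/path/to/working_dir')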
"""
ssd = super_segmentation.SuperSegmentationDataset(working_dir=wd,
version=ssd_version)
multi_params = []
for ids_small_chunk in chunkify(ssd.ssv_ids[np.argsort(ssd.load_numpy_data('size'))[::-1]],
global_params.config.ncore_total * 2):
multi_params.append([wd, obj_version, ssd_version, ids_small_chunk])
if not qu.batchjob_enabled():
_ = sm.start_multiprocess_imap(
_collect_properties_from_ssv_partners_thread, multi_params,
debug=debug)
else:
_ = qu.batchjob_script(
multi_params, "collect_properties_from_ssv_partners",
remove_jobfolder=True)
# iterate over paths with syn
sd_syn_ssv = segmentation.SegmentationDataset("syn_ssv", working_dir=wd,
version=obj_version)
multi_params = []
for so_dir_paths in chunkify(sd_syn_ssv.so_dir_paths, global_params.config.ncore_total * 2):
multi_params.append([so_dir_paths, wd, obj_version,
ssd_version])
if not qu.batchjob_enabled():
_ = sm.start_multiprocess_imap(
_from_cell_to_syn_dict, multi_params,
debug=debug)
else:
_ = qu.batchjob_script(
multi_params, "from_cell_to_syn_dict", remove_jobfolder=True)
log_extraction.debug('Deleting cache dictionaries now.')
# delete cache_dicts
# TODO: start as thread!
sm.start_multiprocess_imap(_delete_all_cache_dc, [(ssv_id, ssd.config) for ssv_id in ssd.ssv_ids],
nb_cpus=None)
log_extraction.debug('Deleted all cache dictionaries.')
def _collect_properties_from_ssv_partners_thread(args):
"""
Helper function of 'collect_properties_from_ssv_partners'.
Notes:
* SSV objects that do not have any mesh vertex will be assigned zero values for all properties.
Args:
args (tuple): see 'collect_properties_from_ssv_partners'.
"""
wd, obj_version, ssd_version, ssv_ids = args
semseg2coords_kwargs = global_params.config['spines']['semseg2coords_spines']
n_embedding = global_params.config['tcmn']['ndim_embedding']
sd_syn_ssv = segmentation.SegmentationDataset(obj_type="syn_ssv",
working_dir=wd,
version=obj_version)
ssd = super_segmentation.SuperSegmentationDataset(working_dir=wd, version=ssd_version)
syn_neuronpartners = sd_syn_ssv.load_numpy_data("neuron_partners")
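# axoness prediction key of the form '<semseg_key>_avg<averaging_dist>', e.g. 'axoness_avg10000'
# (the actual values are taken from the config)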
pred_key_ax = "{}_avg{}".format(global_params.config['compartments'][
'view_properties_semsegax']['semseg_key'],
global_params.config['compartments'][
'dist_axoness_averaging'])
for ssv_id in ssv_ids: # Iterate over cells
ssv_o = ssd.get_super_segmentation_object(ssv_id)
ssv_o.load_attr_dict()
cache_dc = AttributeDict(ssv_o.ssv_dir + "/cache_syn.pkl",
read_only=False, disable_locking=True)
curr_ssv_mask = (syn_neuronpartners[:, 0] == ssv_id) | \
(syn_neuronpartners[:, 1] == ssv_id)
ssv_synids = sd_syn_ssv.ids[curr_ssv_mask]
if len(ssv_synids) == 0 or ssv_o.mesh[1].shape[0] == 0:
cache_dc['partner_spineheadvol'] = np.zeros((len(ssv_synids),), dtype=np.float32)
cache_dc['partner_axoness'] = np.zeros((len(ssv_synids),), dtype=np.int32)
cache_dc['synssv_ids'] = ssv_synids
cache_dc['partner_spiness'] = np.zeros((len(ssv_synids),), dtype=np.int32)
cache_dc['partner_celltypes'] = np.zeros((len(ssv_synids),), dtype=np.int32)
cache_dc['latent_morph'] = np.zeros([len(ssv_synids), n_embedding], dtype=np.float32)
cache_dc.push()
continue
ssv_syncoords = sd_syn_ssv.rep_coords[curr_ssv_mask]
try:
ct = ssv_o.attr_dict['celltype_cnn_e3']
except KeyError:
ct = -1
celltypes = [ct] * len(ssv_synids)
curr_ax, latent_morph = ssv_o.attr_for_coords(
ssv_syncoords, attr_keys=[pred_key_ax, 'latent_morph'])
curr_sp = ssv_o.semseg_for_coords(ssv_syncoords, 'spiness', **semseg2coords_kwargs)
sh_vol = np.array([ssv_o.attr_dict['spinehead_vol'][syn_id] if syn_id in ssv_o.attr_dict['spinehead_vol']
else -1 for syn_id in ssv_synids], dtype=np.float32)
cache_dc['partner_spineheadvol'] = np.array(sh_vol)
cache_dc['partner_axoness'] = curr_ax
cache_dc['synssv_ids'] = ssv_synids
cache_dc['partner_spiness'] = curr_sp
cache_dc['partner_celltypes'] = np.array(celltypes)
cache_dc['latent_morph'] = latent_morph
cache_dc.push()
def _from_cell_to_syn_dict(args):
"""
args : Tuple
see 'collect_properties_from_ssv_partners'
"""
so_dir_paths, wd, obj_version, ssd_version = args
ssd = super_segmentation.SuperSegmentationDataset(working_dir=wd,
version=ssd_version)
sd_syn_ssv = segmentation.SegmentationDataset(obj_type="syn_ssv",
working_dir=wd,
version=obj_version)
cell_obj_conf = global_params.config['cell_objects']
for so_dir_path in so_dir_paths:
this_attr_dc = AttributeDict(so_dir_path + "/attr_dict.pkl",
read_only=False, disable_locking=True)
for synssv_id in this_attr_dc.keys():
synssv_o = sd_syn_ssv.get_segmentation_object(synssv_id)
synssv_o.load_attr_dict()
sym_asym_ratio = synssv_o.attr_dict['syn_type_sym_ratio']
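# fraction of symmetric type above 'sym_thresh' -> symmetric synapse (sign -1), otherwise asymmetric (sign 1)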
syn_sign = -1 if sym_asym_ratio > cell_obj_conf['sym_thresh'] else 1
axoness = []
latent_morph = []
spinehead_vol = []
spiness = []
celltypes = []
for ssv_partner_id in synssv_o.attr_dict["neuron_partners"]:
ssv_o = ssd.get_super_segmentation_object(ssv_partner_id)
cache_dc = AttributeDict(ssv_o.ssv_dir + "/cache_syn.pkl")
index = np.transpose(np.nonzero(cache_dc['synssv_ids'] == synssv_id))
if len(index) != 1:
msg = f"Could not find synssv with ID {synssv_id} in 'cache_syn.pkl' of {ssv_o}."
log_extraction.error(msg)
raise ValueError(msg)
index = index[0][0]
axoness.append(cache_dc['partner_axoness'][index])
spiness.append(cache_dc['partner_spiness'][index])
celltypes.append(cache_dc['partner_celltypes'][index])
latent_morph.append(cache_dc['latent_morph'][index])
spinehead_vol.append(cache_dc['partner_spineheadvol'][index])
synssv_o.attr_dict.update({'partner_axoness': axoness,
'partner_spiness': spiness,
'partner_celltypes': celltypes,
'partner_spineheadvol': spinehead_vol,
'syn_sign': syn_sign,
'latent_morph': latent_morph})
this_attr_dc[synssv_id] = synssv_o.attr_dict
this_attr_dc.push()
def _delete_all_cache_dc(args):
ssv_id, config = args
ssv_o = super_segmentation.SuperSegmentationObject(ssv_id, config=config)
if os.path.exists(ssv_o.ssv_dir + "/cache_syn.pkl"):
os.remove(ssv_o.ssv_dir + "/cache_syn.pkl")
def filter_relevant_syn(sd_syn: segmentation.SegmentationDataset,
ssd: super_segmentation.SuperSegmentationDataset, log: Logger) -> Dict[int, list]:
"""
Filters out intra-SSV contact sites (i.e. contacts within a single SSV rather than between two SSVs),
which do not need to be agglomerated, and groups the remaining objects by their SSV partner pair.
Notes:
* Also applicable to cs.
Args:
sd_syn:
ssd:
log:
Returns:
Lookup from encoded SSV partner IDs (see :py:func:`~syconn.reps.connectivity_helper.cs_id_to_partner_ids_vec`
for decoding into the two SSV IDs) to SV syn. object IDs; keys: encoded SSV partner IDs, values: list of SV syn IDs.
"""
if log is None:
log = log_extraction
# get all cs IDs belonging to syn objects and then retrieve the corresponding SV IDs via bit shift
# syn objects are just a subset of contact site objects (which originally store the partner IDs) with the same IDs
# -> not necessary to load the cs_ids.
syn_ids = sd_syn.ids.copy()
sv_ids = ch.cs_id_to_partner_ids_vec(syn_ids)
log.debug(f'Generated supervoxel IDs for all {sd_syn.type} objects.')
# SV IDs beyond the last entry of the reverse mapping (SVs not part of any SSV) are set to 0; the corresponding syn objects are discarded below
sv_ids[sv_ids > ssd.mapping_lookup_reverse.id_array[-1]] = 0
# this creates a lookup dict for SV to SSV ID only for SV involved in synaptic contacts
# -> This produces extra overhead when processing flattened SV graphs (SSVs only consist of 1 SV)
mapping_dc = ssd.sv2ssv_ids(np.unique(sv_ids.flatten()), nb_cpus=sm.cpu_count())
log.debug('Generated sv-ssv mapping dict.')
# TODO: apply with multiple processes and shared mapping_dc
def mapper(x): return mapping_dc[x] if x in mapping_dc else 0
# np.vectorize is neither concurrent nor more efficient than `map`; it is just more convenient.
mapped_ssv_ids = np.vectorize(mapper)(sv_ids.reshape(-1)).reshape(sv_ids.shape)
log.debug(f'Mapped SV IDs to SSV IDs for all {sd_syn.type} objects.')
del mapping_dc
mask = np.all(mapped_ssv_ids > 0, axis=1)
syn_ids = syn_ids[mask]
filtered_mapped_ssv_ids = mapped_ssv_ids[mask]
# this identifies all inter-ssv contact sites
mask = filtered_mapped_ssv_ids[:, 0] != filtered_mapped_ssv_ids[:, 1]
syn_ids = syn_ids[mask]
inter_ssv_contacts = filtered_mapped_ssv_ids[mask]
# TODO: generalize by adding it as a method parameter and pass config value to the method. Also add a config value
# for syn_ssv!
# filter small SSV if min path length was set in config
min_path_length_partners = global_params.config['cell_contacts']['min_path_length_partners']
if (sd_syn.type == 'cs') and (min_path_length_partners is not None) and (min_path_length_partners > 0):
filtered_ssv_ids = filter_ssd_by_total_pathlength(ssd, min_path_length_partners)
log.info(f'Filtering out contact sites formed with at least one small cell (min. total path length per '
f'cell: {min_path_length_partners} µm). {len(filtered_ssv_ids)} cells fulfill that '
f'criterion.')
# check for every element of inter_ssv_contacts if it is inside filtered_ssv_ids
res = np.isin(inter_ssv_contacts.reshape(-1), filtered_ssv_ids).reshape(-1, 2)
mask = np.all(res, axis=1)
inter_ssv_contacts = inter_ssv_contacts[mask]
syn_ids = syn_ids[mask]
assert len(inter_ssv_contacts) == len(syn_ids)
if len(inter_ssv_contacts) == 0:
log.warning('No contact sites found after filtering out small cells.')
else:
log.info(f'Found {len(inter_ssv_contacts)} supervoxel contact sites (merged, unsplit) between'
f' {len(filtered_ssv_ids)} cells.')
# get bit shifted combination of SSV partner IDs, used to collect all corresponding synapse IDs between the two
# cells
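# e.g. a partner pair (a, b) is encoded as (max(a, b) << 32) + min(a, b) and decoded again via ch.cs_id_to_partner_ids_vec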
relevant_ssv_ids_enc = np.left_shift(np.max(inter_ssv_contacts, axis=1), 32) + np.min(inter_ssv_contacts, axis=1)
log.debug(f'Filtered intra-cell {sd_syn.type} objects and created {sd_syn.type}_ssv IDs.')
# create lookup from SSV-wide synapses to SV syn. objects
ssv_to_syn_ids_dc = defaultdict(list)
for i_entry in range(len(relevant_ssv_ids_enc)):
ssv_to_syn_ids_dc[relevant_ssv_ids_enc[i_entry]].append(syn_ids[i_entry])
log.debug(f'Created cell-pair {sd_syn.type} object lists for subsequent split and agglomeration.')
return ssv_to_syn_ids_dc
def combine_and_split_syn(wd, cs_gap_nm=300, ssd_version=None, syn_version=None,
nb_cpus=None, n_folders_fs=10000, log=None, overwrite=False):
"""
Creates 'syn_ssv' objects from 'syn' objects. To this end, connected components of 'syn'
objects are computed on the SSV level and the respective 'syn' attributes ['cs_id', 'asym_prop', 'sym_prop'] are aggregated.
All objects of the resulting 'syn_ssv' SegmentationDataset contain the following attributes:
['syn_sign', 'syn_type_sym_ratio', 'asym_prop', 'sym_prop', 'cs_ids', 'neuron_partners']
Notes:
* 'rep_coord' property is calculated as the voxel (part of the object) closest to the center of mass of
all object voxels.
* 'cs_id'/'cs_ids' is the same as syn_id ('syn' are just a subset of 'cs', preserving the IDs).
Args:
wd :
cs_gap_nm :
ssd_version :
syn_version :
nb_cpus :
log:
n_folders_fs:
overwrite:
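Examples:
Minimal usage sketch (the working directory path is a placeholder; the resulting 'syn_ssv' dataset is written as version "0"):
>>> combine_and_split_syn('/path/to/working_dir', cs_gap_nm=300, overwrite=True)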
"""
ssd = super_segmentation.SuperSegmentationDataset(wd, version=ssd_version)
syn_sd = segmentation.SegmentationDataset("syn", working_dir=wd, version=syn_version)
# TODO: this procedure creates folders with single and double digits, e.g. '0' and '00'. Single digit folders are
# not used during write-outs, they are probably generated within this method's makedirs
log_extraction.debug(f'Filtering relevant synapses.')
rel_ssv_with_syn_ids = filter_relevant_syn(syn_sd, ssd, log=log)
log_extraction.debug(f'Filtering relevant synapses done.')
storage_location_ids = get_unique_subfold_ixs(n_folders_fs)
n_used_paths = min(global_params.config.ncore_total * 8, len(storage_location_ids),
len(rel_ssv_with_syn_ids))
voxel_rel_paths = chunkify([subfold_from_ix(ix, n_folders_fs) for ix in storage_location_ids],
n_used_paths)
# target SD for SSV syn objects
sd_syn_ssv = segmentation.SegmentationDataset("syn_ssv", working_dir=wd,
version="0", create=False,
n_folders_fs=n_folders_fs)
if os.path.exists(sd_syn_ssv.so_storage_path):
if not overwrite:
raise FileExistsError(f'"{sd_syn_ssv.so_storage_path}" already exists, but '
f'overwrite was set to False.')
shutil.rmtree(sd_syn_ssv.so_storage_path)
# prepare folder structure
voxel_rel_paths_2stage = np.unique([subfold_from_ix(ix, n_folders_fs)[:-2]
for ix in storage_location_ids])
for p in voxel_rel_paths_2stage:
os.makedirs(sd_syn_ssv.so_storage_path + p)
# TODO: apply weighting-scheme to balance worker load
rel_ssv_with_syn_ids_items = list(rel_ssv_with_syn_ids.items())
rel_synssv_to_syn_ids_items_chunked = chunkify(rel_ssv_with_syn_ids_items, n_used_paths)
multi_params = [(wd, rel_synssv_to_syn_ids_items_chunked[ii], voxel_rel_paths[ii],
syn_sd.version, sd_syn_ssv.version, cs_gap_nm) for
ii in range(n_used_paths)]
if not qu.batchjob_enabled():
_ = sm.start_multiprocess_imap(_combine_and_split_syn_thread,
multi_params, nb_cpus=nb_cpus, debug=False)
else:
_ = qu.batchjob_script(
multi_params, "combine_and_split_syn", remove_jobfolder=True, log=log)
def _combine_and_split_syn_thread(args):
wd = args[0]
rel_ssv_with_syn_ids_items = args[1]
voxel_rel_paths = args[2]
syn_version = args[3]
syn_ssv_version = args[4]
cs_gap_nm = args[5]
sd_syn_ssv = segmentation.SegmentationDataset("syn_ssv", working_dir=wd,
version=syn_ssv_version)
sd_syn = segmentation.SegmentationDataset("syn", working_dir=wd, version=syn_version)
scaling = sd_syn.scaling
syn_meshing_kws = global_params.config['meshes']['meshing_props_points']['syn_ssv']
mesh_min_obj_vx = global_params.config['meshes']['mesh_min_obj_vx']
cell_obj_cnf = global_params.config['cell_objects']
use_new_subfold = global_params.config.use_new_subfold
# TODO: add to config, also used in 'ix_from_subfold' if 'global_params.config.use_new_subfold=True'
div_base = 1e3
id_chunk_cnt = 0
n_per_voxel_path = np.ceil(float(len(rel_ssv_with_syn_ids_items)) / len(voxel_rel_paths))
n_items_for_path = 0
cur_path_id = 0
base_dir = sd_syn_ssv.so_storage_path + voxel_rel_paths[cur_path_id]
os.makedirs(base_dir, exist_ok=True)
# get ID/path to storage to save intermediate results
base_id = ix_from_subfold(voxel_rel_paths[cur_path_id], sd_syn.n_folders_fs)
syn_ssv_id = base_id
voxel_dc = VoxelStorageLazyLoading(base_dir + "/voxel.npz", overwrite=True)
attr_dc = AttributeDict(base_dir + "/attr_dict.pkl", read_only=False)
mesh_dc = MeshStorage(base_dir + "/mesh.pkl", read_only=False)
for ssvpartners_enc, syn_ids in rel_ssv_with_syn_ids_items:
n_items_for_path += 1
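# decode the two SSV partner IDs from the combined 64-bit key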
ssv_ids = ch.cs_id_to_partner_ids_vec([ssvpartners_enc])[0]
syn = sd_syn.get_segmentation_object(syn_ids[0])
# verify ssv_partner_ids
syn.load_attr_dict()
syn_attr_list = [syn.attr_dict] # used to collect syn properties
voxel_list = [syn.voxel_list]
# store index of syn. objects for attribute dict retrieval
synix_list = [0] * len(voxel_list[0])
for syn_ix, syn_id in enumerate(syn_ids[1:]):
syn = sd_syn.get_segmentation_object(syn_id)
syn.load_attr_dict()
syn_attr_list.append(syn.attr_dict)
voxel_list.append(syn.voxel_list)
synix_list += [syn_ix] * len(voxel_list[-1])
syn_attr_list = np.array(syn_attr_list)
synix_list = np.array(synix_list)
if len(synix_list) == 0:
msg = 'Voxels not available for syn-objects {}.'.format(syn_ids)
log_extraction.error(msg)
raise ValueError(msg)
ccs = connected_cluster_kdtree(voxel_list, dist_intra_object=cs_gap_nm,
dist_inter_object=20000, scale=scaling)
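# each connected component is one 'syn_ssv' candidate, given as a set of indices into the concatenated voxel arrays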
voxel_list = np.concatenate(voxel_list)
for this_cc in ccs:
# do not process synapse again if job has been restarted
if syn_ssv_id not in attr_dc:
this_cc_mask = np.array(list(this_cc))
# retrieve the index of the syn objects selected for this CC
this_syn_ixs, this_syn_ids_cnt = np.unique(synix_list[this_cc_mask],
return_counts=True)
# weights: voxel-count fraction of each 'syn' fragment within this connected component; used below for the
# weighted average of the sym./asym. properties
this_agg_syn_weights = this_syn_ids_cnt / np.sum(this_syn_ids_cnt)
if np.sum(this_syn_ids_cnt) < cell_obj_cnf['min_obj_vx']['syn_ssv']:
continue
this_attr = syn_attr_list[this_syn_ixs]
this_vx = voxel_list[this_cc_mask]
syn_ssv = sd_syn_ssv.get_segmentation_object(syn_ssv_id)
if (os.path.abspath(syn_ssv.attr_dict_path)
!= os.path.abspath(base_dir + "/attr_dict.pkl")):
raise ValueError(f'Path mis-match!')
synssv_attr_dc = dict(neuron_partners=ssv_ids)
voxel_dc[syn_ssv_id] = this_vx
synssv_attr_dc["rep_coord"] = (seghelp.calc_center_of_mass(this_vx * scaling) // scaling).astype(np.int32)
synssv_attr_dc["bounding_box"] = np.array([np.min(this_vx, axis=0), np.max(this_vx, axis=0)])
synssv_attr_dc["size"] = len(this_vx)
# calc_contact_syn_mesh returns a list with a single mesh (for syn_ssv)
if mesh_min_obj_vx < synssv_attr_dc["size"]:
syn_ssv._mesh = calc_contact_syn_mesh(syn_ssv, voxel_dc=voxel_dc, **syn_meshing_kws)[0]
mesh_dc[syn_ssv.id] = syn_ssv.mesh
synssv_attr_dc["mesh_bb"] = syn_ssv.mesh_bb
synssv_attr_dc["mesh_area"] = syn_ssv.mesh_area
else:
zero_mesh = [np.zeros((0,), dtype=np.int32), np.zeros((0,), dtype=np.int32),
np.zeros((0,), dtype=np.float32)]
mesh_dc[syn_ssv.id] = zero_mesh
synssv_attr_dc["mesh_bb"] = synssv_attr_dc["bounding_box"] * scaling
synssv_attr_dc["mesh_area"] = 0
# aggregate syn properties
syn_props_agg = {}
# cs_id is the same as syn_id ('syn' are just a subset of 'cs')
for dc in this_attr:
for k in ['cs_id', 'sym_prop', 'asym_prop']:
syn_props_agg.setdefault(k, []).append(dc[k])
# rename and delete old entry
syn_props_agg['cs_ids'] = syn_props_agg['cs_id']
del syn_props_agg['cs_id']
# 'syn_ssv' synapse type as weighted sum of the 'syn' fragment types
sym_prop = np.sum(this_agg_syn_weights * np.array(syn_props_agg['sym_prop']))
asym_prop = np.sum(this_agg_syn_weights * np.array(syn_props_agg['asym_prop']))
syn_props_agg['sym_prop'] = sym_prop
syn_props_agg['asym_prop'] = asym_prop
if sym_prop + asym_prop == 0:
sym_ratio = -1
else:
sym_ratio = sym_prop / float(asym_prop + sym_prop)
syn_props_agg["syn_type_sym_ratio"] = sym_ratio
syn_sign = -1 if sym_ratio > cell_obj_cnf['sym_thresh'] else 1
syn_props_agg["syn_sign"] = syn_sign
# add syn_ssv dict to AttributeStorage
synssv_attr_dc.update(syn_props_agg)
attr_dc[syn_ssv_id] = synssv_attr_dc
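# advance to the next object ID that maps to the current storage folder: with the new subfolder scheme IDs are
# contiguous within chunks of size 'div_base', otherwise they are strided by 'n_folders_fs'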
if use_new_subfold:
syn_ssv_id += np.uint(1)
if syn_ssv_id - base_id >= div_base:
# next ID chunk mapped to this storage
id_chunk_cnt += 1
old_base_id = base_id
base_id += np.uint(sd_syn_ssv.n_folders_fs * div_base * id_chunk_cnt)
assert subfold_from_ix(base_id, sd_syn_ssv.n_folders_fs, old_version=False) == \
subfold_from_ix(old_base_id, sd_syn_ssv.n_folders_fs, old_version=False)
syn_ssv_id = base_id
else:
syn_ssv_id += np.uint(sd_syn.n_folders_fs)
if n_items_for_path > n_per_voxel_path:
voxel_dc.push()
voxel_dc.close()
mesh_dc.push()
attr_dc.push()
cur_path_id += 1
if len(voxel_rel_paths) == cur_path_id:
raise ValueError(f'Worker ran out of possible storage paths for storing {sd_syn_ssv.type}.')
n_items_for_path = 0
id_chunk_cnt = 0
base_id = ix_from_subfold(voxel_rel_paths[cur_path_id], sd_syn.n_folders_fs)
syn_ssv_id = base_id
base_dir = sd_syn_ssv.so_storage_path + voxel_rel_paths[cur_path_id]
os.makedirs(base_dir, exist_ok=True)
voxel_dc = VoxelStorageLazyLoading(base_dir + "/voxel.npz")
attr_dc = AttributeDict(base_dir + "/attr_dict.pkl", read_only=False)
mesh_dc = MeshStorage(base_dir + "/mesh.pkl", read_only=False)
if n_items_for_path > 0:
voxel_dc.push()
voxel_dc.close()
attr_dc.push()
mesh_dc.push()
def connected_cluster_kdtree(voxel_coords: List[np.ndarray], dist_intra_object: float,
dist_inter_object: float, scale: np.ndarray) -> List[set]:
"""
Identify connected components within N objects. This is a two-stage process: the first stage adds edges between
voxels of the same object that are at most 2 voxels apart. The edges are added to a global graph which is used to
calculate connected components. In the second stage, connected components are considered close if they are within a
maximum distance of `dist_inter_object`, measured between one (arbitrary) voxel per component used as its
representative coordinate. Close connected components are then connected if the minimum distance between any of
their voxels is smaller than `dist_intra_object`.
Args:
voxel_coords: List of numpy arrays in voxel coordinates.
dist_intra_object: Maximum distance between two voxels of different synapse fragments to
consider them the same object. In nm.
dist_inter_object: Maximum distance between two objects to check for close voxels
between them. In nm.
scale: Voxel sizes in nm (XYZ).
Returns:
Connected components (as sets of indices into the concatenated input arrays) across all N input objects,
linked at distances of at most `dist_intra_object`.
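Examples:
Minimal sketch with two synthetic fragments (coordinates and scale are made up); both fragments end up in a
single connected component because their closest voxels are about 90 nm apart:
>>> import numpy as np
>>> frag_a = np.array([[0, 0, 0], [1, 0, 0]])
>>> frag_b = np.array([[10, 0, 0]])
>>> ccs = connected_cluster_kdtree([frag_a, frag_b], dist_intra_object=300,
... dist_inter_object=20000, scale=np.array([10, 10, 25]))
>>> len(ccs)
1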
"""
import networkx as nx
graph = nx.Graph()
ixs_offset = np.cumsum([0] + [len(syn_vxs) for syn_vxs in voxel_coords[:-1]])
# add intra object edges
for ii in range(len(voxel_coords)):
off = ixs_offset[ii]
graph.add_nodes_from(np.arange(len(voxel_coords[ii])) + off)
kdtree = spatial.cKDTree(voxel_coords[ii])
pairs = np.array(list(kdtree.query_pairs(r=2)), dtype=np.int64)
graph.add_edges_from(pairs + off)
del kdtree, pairs
voxel_coords_flat = np.concatenate(voxel_coords) * scale
ccs = [np.array(list(cc)) for cc in nx.connected_components(graph)]
rep_coords = np.array([voxel_coords_flat[cc[0]] for cc in ccs])
kdtree = spatial.cKDTree(rep_coords)
pairs = kdtree.query_pairs(r=dist_inter_object)
del kdtree
# add minimal inter-object edges
for c1, c2 in pairs:
c1_ixs = ccs[c1]
c2_ixs = ccs[c2]
kd1 = spatial.cKDTree(voxel_coords_flat[c1_ixs])
dists, nn_ixs = kd1.query(voxel_coords_flat[c2_ixs], distance_upper_bound=dist_intra_object)
if min(dists) > dist_intra_object:
continue
argmin = np.argmin(dists)
ix_c1 = c1_ixs[nn_ixs[argmin]]
ix_c2 = c2_ixs[argmin]
graph.add_edge(ix_c1, ix_c2)
return list(nx.connected_components(graph))
def combine_and_split_cs(wd, ssd_version=None, cs_version=None, nb_cpus=None, n_folders_fs=10000,
log=None, overwrite=False):
"""
Creates 'cs_ssv' objects from 'cs' objects. Computes connected
cs-objects on SSV level and re-calculates their attributes (mesh_area, size, ..).
In contrast to :func:`~combine_and_split_syn` this method performs connected component analysis on
the mesh of all cell-cell contacts instead of their voxels.
Notes:
* 'rep_coord' property is calculated as the mesh vertex closest to the center of mass of all mesh vertices.
Args:
wd :
ssd_version :
cs_version :
nb_cpus :
log:
n_folders_fs:
overwrite:
"""
ssd = super_segmentation.SuperSegmentationDataset(wd, version=ssd_version)
cs_sd = segmentation.SegmentationDataset("cs", working_dir=wd, version=cs_version)
cs_version = cs_sd.version
rel_ssv_with_cs_ids = filter_relevant_syn(cs_sd, ssd, log=log)
del ssd, cs_sd
storage_location_ids = get_unique_subfold_ixs(n_folders_fs)
n_used_paths = min(global_params.config.ncore_total * 30, len(storage_location_ids),
len(rel_ssv_with_cs_ids))
voxel_rel_paths = chunkify([subfold_from_ix(ix, n_folders_fs) for ix in storage_location_ids],
n_used_paths)
# target SD for SSV cs objects
sd_cs_ssv = segmentation.SegmentationDataset("cs_ssv", working_dir=wd, version="0", create=False,
n_folders_fs=n_folders_fs)
if os.path.exists(sd_cs_ssv.so_storage_path):
if not overwrite:
raise FileExistsError(f'"{sd_cs_ssv.so_storage_path}" already exists, but overwrite was set to False.')
shutil.rmtree(sd_cs_ssv.so_storage_path)
# prepare folder structure
voxel_rel_paths_2stage = np.unique([subfold_from_ix(ix, n_folders_fs)[:-2]
for ix in storage_location_ids])
for p in voxel_rel_paths_2stage:
os.makedirs(sd_cs_ssv.so_storage_path + p)
rel_ssv_with_cs_ids_items = list(rel_ssv_with_cs_ids.items())
rel_csssv_to_cs_ids_items_chunked = chunkify(rel_ssv_with_cs_ids_items, n_used_paths)
multi_params = [(wd, rel_csssv_to_cs_ids_items_chunked[ii], voxel_rel_paths[ii],
cs_version, sd_cs_ssv.version) for
ii in range(n_used_paths)]
del rel_ssv_with_cs_ids_items, rel_csssv_to_cs_ids_items_chunked, voxel_rel_paths_2stage, voxel_rel_paths
if not qu.batchjob_enabled():
_ = sm.start_multiprocess_imap(_combine_and_split_cs_thread, multi_params, nb_cpus=nb_cpus, debug=False)
else:
_ = qu.batchjob_script(multi_params, "combine_and_split_cs", remove_jobfolder=True, log=log)
def _combine_and_split_cs_thread(args):
wd = args[0]
rel_ssv_with_cs_ids_items = args[1]
voxel_rel_paths = args[2]
cs_version = args[3]
cs_ssv_version = args[4]
sd_cs_ssv = segmentation.SegmentationDataset("cs_ssv", working_dir=wd, version=cs_ssv_version)
sd_cs = segmentation.SegmentationDataset("cs", working_dir=wd, version=cs_version)
scaling = sd_cs.scaling
meshing_kws = global_params.config['meshes']['meshing_props_points']['cs_ssv']
mesh_min_obj_vx = global_params.config['meshes']['mesh_min_obj_vx']
use_new_subfold = global_params.config.use_new_subfold
# TODO: add to config, also used in 'ix_from_subfold' if 'global_params.config.use_new_subfold=True'
div_base = 1e3
id_chunk_cnt = 0
n_per_voxel_path = np.ceil(float(len(rel_ssv_with_cs_ids_items)) / len(voxel_rel_paths))
n_items_for_path = 0
cur_path_id = 0
base_dir = sd_cs_ssv.so_storage_path + voxel_rel_paths[cur_path_id]
os.makedirs(base_dir, exist_ok=True)
# get ID/path to storage to save intermediate results
base_id = ix_from_subfold(voxel_rel_paths[cur_path_id], sd_cs.n_folders_fs)
cs_ssv_id = base_id
attr_dc = AttributeDict(base_dir + "/attr_dict.pkl", read_only=False)
mesh_dc = MeshStorage(base_dir + "/mesh.pkl", read_only=False, compress=True)
# iterate over cell partners and their contact site IDs (each contact site is between two supervoxels
# of the partner cells)
for ssvpartners_enc, cs_ids in rel_ssv_with_cs_ids_items:
n_items_for_path += 1
ssv_ids = ch.cs_id_to_partner_ids_vec([ssvpartners_enc])[0]
# verify ssv_partner_ids
cs_lst = sd_cs.get_segmentation_object(cs_ids)
vxl_iter_lst = []
vx_cnt = 0
for cs in cs_lst:
vx_store = VoxelStorageDyn(cs.voxel_path, read_only=True,
disable_locking=True)
vxl_iter_lst.append(vx_store.iter_voxelmask_offset(cs.id, overlap=1))
vx_cnt += vx_store.object_size(cs.id)
if mesh_min_obj_vx > vx_cnt:
ccs = []
else:
# generate connected component meshes; vertices are in nm
ccs = gen_mesh_voxelmask(chain(*vxl_iter_lst), scale=scaling, **meshing_kws)
for mesh_cc in ccs:
cs_ssv = sd_cs_ssv.get_segmentation_object(cs_ssv_id)
if (os.path.abspath(cs_ssv.attr_dict_path)
!= os.path.abspath(base_dir + "/attr_dict.pkl")):
raise ValueError(f'Path mis-match!')
csssv_attr_dc = dict(neuron_partners=ssv_ids)
# don't store normals
cs_ssv._mesh = [mesh_cc[0], mesh_cc[1], np.zeros((0,), dtype=np.float32)]
mesh_dc[cs_ssv.id] = cs_ssv.mesh
csssv_attr_dc["mesh_bb"] = cs_ssv.mesh_bb
csssv_attr_dc["mesh_area"] = cs_ssv.mesh_area
csssv_attr_dc["bounding_box"] = (cs_ssv.mesh_bb // scaling).astype(np.int32)
csssv_attr_dc["rep_coord"] = (seghelp.calc_center_of_mass(mesh_cc[1].reshape((-1, 3))) // scaling).astype(np.int64)
csssv_attr_dc["cs_ids"] = list(cs_ids)
# create open3d mesh instance to compute volume
# # TODO: add this as soon as open3d >= 0.11 is supported (glibc error on cluster prevents upgrade)
# tm = o3d.geometry.TriangleMesh()
# tm.triangles = o3d.utility.Vector3iVector(mesh_cc[0].reshape((-1, 3)))
# tm.vertices = o3d.utility.Vector3dVector(mesh_cc[1].reshape((-1, 3)))
# tm.vertex_normals = o3d.utility.Vector3dVector(mesh_cc[2].reshape((-1, 3)))
# assert tm.is_watertight()
# csssv_attr_dc["size"] = tm.get_volume() // np.prod(scaling)
csssv_attr_dc["size"] = 0
# add cs_ssv dict to AttributeStorage
attr_dc[cs_ssv_id] = csssv_attr_dc
if use_new_subfold:
cs_ssv_id += np.uint(1)
if cs_ssv_id - base_id >= div_base:
# next ID chunk mapped to this storage
id_chunk_cnt += 1
old_base_id = base_id
base_id += np.uint(sd_cs_ssv.n_folders_fs * div_base) * id_chunk_cnt
assert subfold_from_ix(base_id, sd_cs_ssv.n_folders_fs, old_version=False) == \
subfold_from_ix(old_base_id, sd_cs_ssv.n_folders_fs, old_version=False)
cs_ssv_id = base_id
else:
cs_ssv_id += np.uint(sd_cs.n_folders_fs)
if n_items_for_path > n_per_voxel_path:
attr_dc.push()
mesh_dc.push()
cur_path_id += 1
if len(voxel_rel_paths) == cur_path_id:
raise ValueError(f'Worker ran out of possible storage paths for storing {sd_cs_ssv.type}.')
n_items_for_path = 0
id_chunk_cnt = 0
base_id = ix_from_subfold(voxel_rel_paths[cur_path_id], sd_cs.n_folders_fs)
cs_ssv_id = base_id
base_dir = sd_cs_ssv.so_storage_path + voxel_rel_paths[cur_path_id]
os.makedirs(base_dir, exist_ok=True)
attr_dc = AttributeDict(base_dir + "/attr_dict.pkl", read_only=False)
mesh_dc = MeshStorage(base_dir + "/mesh.pkl", read_only=False)
if n_items_for_path > 0:
attr_dc.push()
mesh_dc.push()
def cc_large_voxel_lists(voxel_list, cs_gap_nm, max_concurrent_nodes=5000, verbose=False):
kdtree = spatial.cKDTree(voxel_list)
checked_ids = np.array([], dtype=np.int32)
next_ids = np.array([0])
ccs = [set(next_ids)]
current_ccs = 0
vx_ids = np.arange(len(voxel_list), dtype=np.int32)
while True:
if verbose:
log_extraction.debug("NEXT - %d - %d" % (len(next_ids),
len(checked_ids)))
for cc in ccs:
log_extraction.debug("N voxels in cc: %d" % (len(cc)))
if len(next_ids) == 0:
p_ids = vx_ids[~np.in1d(vx_ids, checked_ids)]
if len(p_ids) == 0:
break
else:
current_ccs += 1
ccs.append(set([p_ids[0]]))
next_ids = p_ids[:1]
q_ids = kdtree.query_ball_point(voxel_list[next_ids], r=cs_gap_nm, )
checked_ids = np.concatenate([checked_ids, next_ids])
for q_id in q_ids:
ccs[current_ccs].update(q_id)
cc_ids = np.array(list(ccs[current_ccs]))
next_ids = vx_ids[cc_ids[~np.in1d(cc_ids, checked_ids)][:max_concurrent_nodes]]
return ccs
def map_objects_from_synssv_partners(wd: str, obj_version: Optional[str] = None,
ssd_version: Optional[str] = None,
n_jobs: Optional[int] = None,
debug: bool = False, log: Logger = None,
max_rep_coord_dist_nm: Optional[float] = None):
"""
Maps sub-cellular objects (mitochondria 'mi' and vesicle clouds 'vc') of the synaptic partners of 'syn_ssv'
objects and stores the results in their attribute dict.
The following keys will be available in the ``attr_dict`` of ``syn_ssv``-typed
:class:`~syconn.reps.segmentation.SegmentationObject` (%d is the index of the synaptic partner):
* 'n_mi_objs_%d': Number of mitochondria with at least one mesh vertex close to the synapse.
* 'n_mi_vxs_%d': Approximate number of nearby mitochondrion voxels.
* 'min_dst_mi_nm_%d': Minimum distance to a mitochondrion in nm.
* 'n_vc_objs_%d': Number of vesicle clouds with at least one mesh vertex close to the synapse.
* 'n_vc_vxs_%d': Approximate number of nearby vesicle cloud voxels.
* 'min_dst_vc_nm_%d': Minimum distance to a vesicle cloud in nm.
Args:
wd:
obj_version:
ssd_version:
n_jobs:
debug:
log:
max_rep_coord_dist_nm:
Returns:
"""
if n_jobs is None:
n_jobs = global_params.config.ncore_total * 4
if max_rep_coord_dist_nm is None:
max_rep_coord_dist_nm = global_params.config['cell_objects']['max_rep_coord_dist_nm']
ssd = super_segmentation.SuperSegmentationDataset(working_dir=wd,
version=ssd_version)
multi_params = []
for ids_small_chunk in chunkify(ssd.ssv_ids, n_jobs):
multi_params.append([wd, obj_version, ssd_version, ids_small_chunk,
max_rep_coord_dist_nm])
if not qu.batchjob_enabled():
_ = sm.start_multiprocess_imap(
_map_objects_from_synssv_partners_thread, multi_params,
debug=debug)
else:
_ = qu.batchjob_script(
multi_params, "map_objects_from_synssv_partners", log=log,
remove_jobfolder=True)
# iterate over paths with syn
sd_syn_ssv = segmentation.SegmentationDataset("syn_ssv", working_dir=wd,
version=obj_version)
multi_params = []
for so_dir_paths in chunkify(sd_syn_ssv.so_dir_paths, n_jobs):
multi_params.append([so_dir_paths, wd, obj_version,
ssd_version])
if not qu.batchjob_enabled():
_ = sm.start_multiprocess_imap(
_objects_from_cell_to_syn_dict, multi_params,
debug=False)
else:
_ = qu.batchjob_script(
multi_params, "objects_from_cell_to_syn_dict", log=log,
remove_jobfolder=True)
if log is None:
log = log_extraction
log.debug('Deleting cache dictionaries now.')
# delete cache_dc
sm.start_multiprocess_imap(_delete_all_cache_dc, [(ssv_id, ssd.config) for ssv_id in ssd.ssv_ids],
nb_cpus=global_params.config['ncores_per_node'])
log.debug('Deleted all cache dictionaries.')
def _map_objects_from_synssv_partners_thread(args: tuple):
"""
Helper function of 'map_objects_from_synssv_partners'.
Args:
args: see 'map_objects_from_synssv_partners'
Returns:
"""
max_vert_dist_nm = global_params.config['cell_objects']['max_vert_dist_nm']
# TODO: add global overwrite kwarg
overwrite = True
wd, obj_version, ssd_version, ssv_ids, max_rep_coord_dist_nm = args
use_new_subfold = global_params.config.use_new_subfold
sd_syn_ssv = segmentation.SegmentationDataset(obj_type="syn_ssv",
working_dir=wd,
version=obj_version)
ssd = super_segmentation.SuperSegmentationDataset(working_dir=wd,
version=ssd_version)
sd_vc = segmentation.SegmentationDataset(obj_type="vc", working_dir=wd)
sd_mi = segmentation.SegmentationDataset(obj_type="mi", working_dir=wd)
syn_neuronpartners = sd_syn_ssv.load_numpy_data("neuron_partners")
# dts = dict(id_mask=0, kds=0, map_verts=0, directio=0, meshcache=0)
for ssv_id in ssv_ids: # Iterate over cells
ssv_o = ssd.get_super_segmentation_object(ssv_id)
# start = time.time()
if overwrite and os.path.isfile(ssv_o.ssv_dir + "/cache_syn.pkl"):
os.remove(ssv_o.ssv_dir + "/cache_syn.pkl")
cache_dc = AttributeDict(ssv_o.ssv_dir + "/cache_syn.pkl",
read_only=False, disable_locking=True)
if not overwrite and ('n_vc_vxs' in cache_dc):
continue
# dts['directio'] += time.time() - start
curr_ssv_mask = (syn_neuronpartners[:, 0] == ssv_id) | \
(syn_neuronpartners[:, 1] == ssv_id)
synssv_ids = sd_syn_ssv.ids[curr_ssv_mask]
n_synssv = len(synssv_ids)
n_mi_objs = np.zeros((n_synssv,), dtype=np.int32)
n_mi_vxs = np.zeros((n_synssv,), dtype=np.int32)
n_vc_objs = np.zeros((n_synssv,), dtype=np.int32)
n_vc_vxs = np.zeros((n_synssv,), dtype=np.int32)
min_dst_mi = np.zeros((n_synssv,), dtype=np.float32)
min_dst_vc = np.zeros((n_synssv,), dtype=np.float32)
cache_dc['synssv_ids'] = synssv_ids
if n_synssv == 0:
cache_dc['n_mi_objs'] = n_mi_objs
cache_dc['n_mi_vxs'] = n_mi_vxs
cache_dc['n_vc_objs'] = n_vc_objs
cache_dc['n_vc_vxs'] = n_vc_vxs
cache_dc['min_dst_mi_nm'] = min_dst_mi
cache_dc['min_dst_vc_nm'] = min_dst_vc
cache_dc.push()
continue
# start = time.time()
vc_mask = np.in1d(sd_vc.ids, ssv_o.vc_ids)
mi_mask = np.in1d(sd_mi.ids, ssv_o.mi_ids)
# dts['id_mask'] += time.time() - start
vc_ids = sd_vc.ids[vc_mask]
mi_ids = sd_mi.ids[mi_mask]
vc_sizes = sd_vc.sizes[vc_mask]
mi_sizes = sd_mi.sizes[mi_mask]
# start = time.time()
kdtree_synssv = spatial.cKDTree(sd_syn_ssv.rep_coords[curr_ssv_mask] * sd_syn_ssv.scaling)
# vesicle clouds
kdtree_vc = spatial.cKDTree(sd_vc.rep_coords[vc_mask] * sd_vc.scaling)
# mitos
kdtree_mi = spatial.cKDTree(sd_mi.rep_coords[mi_mask] * sd_mi.scaling)
# returns a list of neighboring objects for every synssv (note: ix is now the index within ssv_o.mi_ids)
close_mi_ixs = kdtree_synssv.query_ball_tree(kdtree_mi, r=max_rep_coord_dist_nm)
close_vc_ixs = kdtree_synssv.query_ball_tree(kdtree_vc, r=max_rep_coord_dist_nm)
# dts['kds'] += time.time() - start
# start = time.time()
close_mi_ids = mi_ids[np.unique(np.concatenate(close_mi_ixs)).astype(np.int32)]
close_vc_ids = vc_ids[np.unique(np.concatenate(close_vc_ixs).astype(np.int32))]
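# bulk-load the cached meshes of all nearby organelles once; they are attached to the individual objects in the loop below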
md_mi = seghelp.load_so_meshes_bulk(sd_mi.get_segmentation_object(close_mi_ids),
use_new_subfold=use_new_subfold)
md_vc = seghelp.load_so_meshes_bulk(sd_vc.get_segmentation_object(close_vc_ids),
use_new_subfold=use_new_subfold)
# md_synssv = seghelp.load_so_meshes_bulk(sd_syn_ssv.get_segmentation_object(synssv_ids),
# use_new_subfold=use_new_subfold)
# dts['meshcache'] += time.time() - start
# start = time.time()
for ii, synssv_id in enumerate(synssv_ids):
synssv_obj = sd_syn_ssv.get_segmentation_object(synssv_id)
# synssv_obj._mesh = md_synssv[synssv_id]
mis = sd_mi.get_segmentation_object(mi_ids[close_mi_ixs[ii]])
# load cached meshes
for jj, ix in enumerate(close_mi_ixs[ii]):
mi = mis[jj]
mi._size = mi_sizes[ix]
mi._mesh = md_mi[mi.id]
vcs = sd_vc.get_segmentation_object(vc_ids[close_vc_ixs[ii]])
for jj, ix in enumerate(close_vc_ixs[ii]):
vc = vcs[jj]
vc._size = vc_sizes[ix]
vc._mesh = md_vc[vc.id]
n_mi_objs[ii], n_mi_vxs[ii], min_dst_mi[ii] = _map_objects_from_synssv(
synssv_obj, mis, max_vert_dist_nm['mi'])
n_vc_objs[ii], n_vc_vxs[ii], min_dst_vc[ii] = _map_objects_from_synssv(
synssv_obj, vcs, max_vert_dist_nm['vc'])
# dts['map_verts'] += time.time() - start
# start = time.time()
cache_dc['n_mi_objs'] = n_mi_objs
cache_dc['n_mi_vxs'] = n_mi_vxs
cache_dc['min_dst_mi_nm'] = min_dst_mi
cache_dc['n_vc_objs'] = n_vc_objs
cache_dc['n_vc_vxs'] = n_vc_vxs
cache_dc['min_dst_vc_nm'] = min_dst_vc
cache_dc.push()
# dts['directio'] += time.time() - start
def _map_objects_from_synssv(synssv_o, seg_objs, max_vert_dist_nm, sample_fact=2):
"""
TODO: Loading meshes for approximating close-by object volume is slow - exchange with summed object size?
Maps cellular organelles to syn_ssv objects. Needed for the RFC model which
is executed in 'classify_synssv_objects'.
Helper function of `_map_objects_from_synssv_partners_thread`.
Args:
synssv_o: 'syn_ssv' synapse object.
seg_objs: SegmentationObject of type 'vc' or 'mi'
max_vert_dist_nm: Query radius for SegmentationObject vertices. Used to estimate
number of nearby object voxels.
sample_fact: only use every Xth vertex.
Returns:
Number of SegmentationObjects with >0 vertices, approximated number of
object voxels within `max_vert_dist_nm` and minimal distance
(in nm; maximum value: 1e12 nm in case no object is present).
"""
synssv_kdtree = spatial.cKDTree(synssv_o.voxel_list[::sample_fact] * synssv_o.scaling)
min_dist = 1e12 # in nm
n_obj_vxs = []
for obj in seg_objs:
# use mesh vertices instead of voxels
obj_vxs = obj.mesh[1].reshape(-1, 3)[::sample_fact]
ds, _ = synssv_kdtree.query(obj_vxs, distance_upper_bound=max_vert_dist_nm)
# surface fraction of subcellular object which is close to synapse
close_frac = np.sum(ds < np.inf) / len(obj_vxs)
if np.min(ds) < min_dist:
min_dist = np.min(ds)
# estimate number of voxels by close-by surface area fraction times total number of voxels
n_obj_vxs.append(close_frac * obj.size)
n_obj_vxs = np.array(n_obj_vxs)
n_objects = np.sum(n_obj_vxs > 0)
n_vxs = np.sum(n_obj_vxs)
return n_objects, n_vxs, min_dist
def _objects_from_cell_to_syn_dict(args):
"""
args : Tuple
see 'map_objects_from_synssv_partners'
"""
so_dir_paths, wd, obj_version, ssd_version = args
ssd = super_segmentation.SuperSegmentationDataset(working_dir=wd,
version=ssd_version)
sd_syn_ssv = segmentation.SegmentationDataset(obj_type="syn_ssv",
working_dir=wd,
version=obj_version)
for so_dir_path in so_dir_paths:
this_attr_dc = AttributeDict(so_dir_path + "/attr_dict.pkl",
read_only=False, disable_locking=True)
for synssv_id in this_attr_dc.keys():
synssv_o = sd_syn_ssv.get_segmentation_object(synssv_id)
synssv_o.load_attr_dict()
map_dc = dict()
for ii, ssv_partner_id in enumerate(synssv_o.attr_dict["neuron_partners"]):
ssv_o = ssd.get_super_segmentation_object(ssv_partner_id)
cache_dc = AttributeDict(ssv_o.ssv_dir + "/cache_syn.pkl")
index = np.transpose(np.nonzero(cache_dc['synssv_ids'] == synssv_id))
if len(index) != 1:
msg = "Partner cell ID mismatch."
log_extraction.error(msg)
raise ValueError(msg)
index = index[0][0]
map_dc[f'n_mi_objs_{ii}'] = cache_dc['n_mi_objs'][index]
map_dc[f'n_mi_vxs_{ii}'] = cache_dc['n_mi_vxs'][index]
map_dc[f'n_vc_objs_{ii}'] = cache_dc['n_vc_objs'][index]
map_dc[f'n_vc_vxs_{ii}'] = cache_dc['n_vc_vxs'][index]
map_dc[f'min_dst_mi_nm_{ii}'] = cache_dc['min_dst_mi_nm'][index]
map_dc[f'min_dst_vc_nm_{ii}'] = cache_dc['min_dst_vc_nm'][index]
synssv_o.attr_dict.update(map_dc)
this_attr_dc[synssv_id] = synssv_o.attr_dict
this_attr_dc.push()
def classify_synssv_objects(wd, obj_version=None, log=None, nb_cpus=None):
"""
TODO: Replace by new synapse detection.
Classify SSV contact sites into synaptic or non-synaptic using an RFC model
and store the result in the attribute dict of the syn_ssv objects.
For requirements see `synssv_o_features`.
Args:
wd:
obj_version:
log:
nb_cpus:
Returns:
"""
sd_syn_ssv = segmentation.SegmentationDataset("syn_ssv", working_dir=wd,
version=obj_version)
multi_params = chunkify(sd_syn_ssv.so_dir_paths, global_params.config.ncore_total)
multi_params = [(so_dir_paths, wd, obj_version) for so_dir_paths in
multi_params]
if not qu.batchjob_enabled():
_ = sm.start_multiprocess_imap(_classify_synssv_objects_thread,
multi_params, nb_cpus=nb_cpus)
else:
_ = qu.batchjob_script(
multi_params, "classify_synssv_objects", log=log,
remove_jobfolder=True)
def _classify_synssv_objects_thread(args):
"""
Helper function of 'classify_synssv_objects'.
Args:
args : Tuple
see 'classify_synssv_objects'
"""
so_dir_paths, wd, obj_version = args
sd_syn_ssv = segmentation.SegmentationDataset(obj_type="syn_ssv",
working_dir=wd,
version=obj_version)
try:
rfc = joblib.load(global_params.config.mpath_syn_rfc)
except ImportError:
rfc = joblib.load(global_params.config.mpath_syn_rfc_fallback)
for so_dir_path in so_dir_paths:
this_attr_dc = AttributeDict(so_dir_path + "/attr_dict.pkl",
read_only=False)
for synssv_id in this_attr_dc.keys():
synssv_o = sd_syn_ssv.get_segmentation_object(synssv_id)
synssv_o.attr_dict = this_attr_dc[synssv_id]
feats = synssv_o_features(synssv_o)
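# probability of the positive ('synaptic') class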
syn_prob = rfc.predict_proba([feats])[0][1]
synssv_o.attr_dict.update({"syn_prob": syn_prob})
this_attr_dc[synssv_id] = synssv_o.attr_dict
this_attr_dc.push()
# Code for property extraction of contact sites (syn_ssv)
def write_conn_gt_kzips(conn, n_objects, folder):
if not os.path.exists(folder):
os.makedirs(folder)
conn_ids = conn.ids[np.random.choice(len(conn.ids), n_objects, replace=False)]
for conn_id in conn_ids:
obj = conn.get_segmentation_object(conn_id)
p = folder + "/obj_%d.k.zip" % conn_id
obj.save_kzip(p)
obj.mesh2kzip(p)
a = skeleton.SkeletonAnnotation()
a.scaling = obj.scaling
a.comment = "rep coord - %d" % obj.size
a.addNode(skeleton.SkeletonNode().from_scratch(a, obj.rep_coord[0],
obj.rep_coord[1],
obj.rep_coord[2],
radius=1))
skeleton_utils.write_skeleton(folder + "/obj_%d.k.zip" % conn_id, [a])
def create_syn_rfc(sd_syn_ssv: 'segmentation.SegmentationDataset', path2file: str, overwrite: bool = False,
rfc_path_out: str = None, max_dist_vx: int = 20) -> \
Tuple[ensemble.RandomForestClassifier, np.ndarray, np.ndarray]:
"""
Trains a random forest classifier (RFC) to distinguish between synaptic and non-synaptic
objects. Features are generated from the objects in `sd_syn_ssv` associated with the annotated
coordinates stored in `path2file`.
Will write the trained classifier to ``global_params.config.mpath_syn_rfc``.
Args:
sd_syn_ssv: :class:`~syconn.reps.segmentation.SegmentationDataset` object of
type ``syn_ssv``. Used to identify synaptic object candidates annotated
in the kzip/xls file at `path2file`.
path2file: Path to kzip file with synapse labels as node comments
("non-synaptic", "synaptic"; labels used for classifier are 0 and 1
respectively).
overwrite: Replace existing files.
rfc_path_out: Filename for dumped RFC.
max_dist_vx: Maximum voxel distance between sample and target.
Returns:
The trained random forest classifier and the feature and label data.
"""
log = log_extraction
if global_params.config.working_dir is not None or rfc_path_out is not None:
if rfc_path_out is None:
model_base_dir = os.path.split(global_params.config.mpath_syn_rfc)[0]
log.info(f'Working directory is set to {global_params.config.working_dir} - '
f'trained RFC will be dumped at {model_base_dir}.')
os.makedirs(model_base_dir, exist_ok=True)
rfc_path_out = global_params.config.mpath_syn_rfc
else:
log = initialize_logging('create_syn_rfc', os.path.dirname(rfc_path_out))
if os.path.isfile(rfc_path_out) and not overwrite:
msg = f'RFC model file already exists at "{rfc_path_out}" and overwrite is disabled.'
log.error(msg)
raise FileExistsError(msg)
assert sd_syn_ssv.type == 'syn_ssv'
log.info(f'Initiated RFC fitting procedure with GT file "{path2file}" and {sd_syn_ssv}.')
mapped_synssv_objects_kzip = f'{os.path.split(rfc_path_out)[0]}/mapped_synssv.k.zip'
if os.path.isfile(mapped_synssv_objects_kzip):
if not overwrite:
raise FileExistsError(f'File with mapped synssv objects already exists '
f'at "{mapped_synssv_objects_kzip}"')
os.remove(mapped_synssv_objects_kzip)
label_coords = []
labels = []
if path2file.endswith('k.zip'):
anno = skeleton_utils.load_skeleton(path2file)['Synapse annotation']
for node in anno.getNodes():
c = node.getComment()
if not ((c == 'synaptic') | (c == 'non-synaptic')):
continue
labels.append(c)
label_coords.append(np.array(node.getCoordinate()))
else:
df = pandas.read_excel(path2file, header=0, names=[
'ixs', 'coord', 'pre', 'post', 'syn', 'doublechecked', 'triplechecked', '?', 'comments']).values
df = df[:, :7]
for ix in range(df.shape[0]):
c_orig = df[ix, 5]
c = df[ix, 6]
if type(c) != float and 'yes' in c:
unified_comment = 'synaptic'
elif type(c) != float and 'no' in c:
unified_comment = 'non-synaptic'
elif 'yes' in c_orig:
unified_comment = 'synaptic'
elif 'no' in c_orig:
unified_comment = 'non-synaptic'
else:
log.warning(f'Did not understand GT comment "{c}". Skipping.')
continue
labels.append(unified_comment)
label_coords.append(np.array(df[ix, 1].split(','), dtype=np.float32))
labels = np.array(labels)
label_coords = np.array(label_coords)
# get deterministic order by sorting by coordinate first and then seeded shuffling
ixs = [i[0] for i in sorted(enumerate(label_coords),
key=lambda x: [x[1][0], x[1][1], x[1][2]])]
ixs_random = np.arange(len(ixs))
np.random.seed(0)
np.random.shuffle(ixs_random)
ixs = np.array(ixs)
label_coords = label_coords[ixs][ixs_random]
labels = labels[ixs][ixs_random]
log.info('Setting up kd-trees for coord-to-synapse mapping.')
conn_kdtree = spatial.cKDTree(sd_syn_ssv.rep_coords * sd_syn_ssv.scaling)
ds, list_ids = conn_kdtree.query(label_coords * sd_syn_ssv.scaling)
synssv_ids = sd_syn_ssv.ids[list_ids]
mask = np.ones(synssv_ids.shape, dtype=bool)
log.info(f'Mapped {len(labels)} GT coordinates to {sd_syn_ssv.type}-objects.')
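# refine the mapping for labels whose nearest representative coordinate did not match exactly (distance > 0):
# check the 20 nearest syn_ssv objects and keep the first whose voxels lie within 'max_dist_vx'
# (Manhattan distance in voxels), otherwise drop the sample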
for label_id in np.where(ds > 0)[0]:
dists, close_ids = conn_kdtree.query(label_coords[label_id] * sd_syn_ssv.scaling,
k=20)
for close_id in close_ids[np.argsort(dists)]:
conn_o = sd_syn_ssv.get_segmentation_object(sd_syn_ssv.ids[close_id])
vx_ds = np.sum(np.abs(conn_o.voxel_list - label_coords[label_id]),
axis=-1)
if np.min(vx_ds) < max_dist_vx:
synssv_ids[label_id] = sd_syn_ssv.ids[close_id]
break
if np.min(vx_ds) > max_dist_vx:
mask[label_id] = 0
if np.sum(mask) == 0:
raise ValueError
synssv_ids = synssv_ids[mask]
labels = labels[mask]
log.info(f'Found {np.sum(mask)}/{len(mask)} samples with a distance < {max_dist_vx} vx to the target.')
log.info(f'Synapse features will now be generated.')
features = []
pbar = tqdm.tqdm(total=len(synssv_ids), leave=False)
for kk, synssv_id in enumerate(synssv_ids):
synssv_o = sd_syn_ssv.get_segmentation_object(synssv_id)
features.append(synssv_o_features(synssv_o))
pbar.update(1)
pbar.close()
features = np.array(features)
log.info('Performing 10-fold cross validation.')
rfc = ensemble.RandomForestClassifier(n_estimators=2000, max_features='sqrt',
n_jobs=-1, random_state=0,
oob_score=True)
mask_annotated = (labels == "synaptic") | (labels == 'non-synaptic')
v_features = features[mask_annotated]
v_labels = labels[mask_annotated]
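# encode labels as binary targets: 1 = 'synaptic', 0 = 'non-synaptic'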
v_labels = (v_labels == "synaptic").astype(np.int32)
# score = cross_val_score(rfc, v_features, v_labels, cv=10)
# log.info('RFC oob score: {:.4f}'.format(rfc.oob_score))
# log.info('RFC CV score +- std: {:.4f} +- {:.4f}'.format(
# np.mean(score), np.std(score)))
# if score < 0.95:
# log.info(f'Individual CV scores: {score}')
feature_names = np.array(synssv_o_featurenames())
probas = cross_val_predict(rfc, v_features, v_labels, cv=10, method='predict_proba')
preds = np.argmax(probas, axis=1)
log.info(metrics.classification_report(v_labels, preds, target_names=['non-synaptic', 'synaptic']))
if rfc_path_out is not None:
import matplotlib.pyplot as plt
import seaborn
log.info(f'Wrote 10-fold cross validation probas, predictions, features and labels of trained RFC to '
f'"{os.path.split(rfc_path_out)[0]}".')
np.save(os.path.split(rfc_path_out)[0] + '/rfc_probas.npy', probas)
np.save(os.path.split(rfc_path_out)[0] + '/rfc_preds.npy', preds)
np.save(os.path.split(rfc_path_out)[0] + '/rfc_labels.npy', v_labels)
np.save(os.path.split(rfc_path_out)[0] + '/rfc_features.npy', v_features)
np.save(os.path.split(rfc_path_out)[0] + '/rfc_feature_names.npy', feature_names)
plt.figure()
df = pandas.DataFrame(data=dict(mesh_area=v_features[:, feature_names == 'mesh_area_um2'].flatten(),
size=v_features[:, feature_names == 'size_vx'].flatten(),
correct=preds == v_labels,
gt=['syn' if el == 1 else 'non-synaptic' for el in v_labels]))
ax = seaborn.jointplot(data=df, x='size', y='mesh_area', hue='correct',
xlim=(0, df['size'].max()*1.1), ylim=(0, df['mesh_area'].max()*1.1))
ax.set_axis_labels('size vx [1]', 'mesh area [um^2]')
plt.savefig(os.path.split(rfc_path_out)[0] + '/feature_hist_size_vs_area_pred.png')
plt.close()
plt.figure()
ax = seaborn.jointplot(data=df, x='size', y='mesh_area', hue='gt',
xlim=(0, df['size'].max()*1.1), ylim=(0, df['mesh_area'].max()*1.1))
ax.set_axis_labels('size vx [1]', 'mesh area [um^2]')
plt.savefig(os.path.split(rfc_path_out)[0] + '/feature_hist_size_vs_area_gt.png')
plt.close()
rfc.fit(v_features, v_labels)
acc = rfc.score(v_features, v_labels)
log.info(f'Training set accuracy: {acc:.4f}')
feature_imp = rfc.feature_importances_
assert len(feature_imp) == len(feature_names)
log.info('RFC importances:\n' + "\n".join(
[f"{feature_names[ii]}: {feature_imp[ii]}" for ii in range(len(feature_imp))]))
log.info(f'Synapses will be annotated and written to "{mapped_synssv_objects_kzip}" for manual revision.')
skel = skeleton.Skeleton()
anno = skeleton.SkeletonAnnotation()
anno.scaling = sd_syn_ssv.scaling
pbar = tqdm.tqdm(total=len(synssv_ids), leave=False)
for kk, synssv_id in enumerate(synssv_ids):
synssv_o = sd_syn_ssv.get_segmentation_object(synssv_id)
rep_coord = synssv_o.rep_coord * sd_syn_ssv.scaling
pred_correct = preds[kk] == v_labels[kk]
n = skeleton.SkeletonNode().from_scratch(anno, rep_coord[0], rep_coord[1], rep_coord[2])
n.setComment(f'{preds[kk]} {pred_correct} {probas[kk][1]:.2f}')
n.data.update({k: v for k, v in zip(feature_names, v_features[kk])})
anno.addNode(n)
rep_coord = label_coords[kk] * sd_syn_ssv.scaling
n_l = skeleton.SkeletonNode().from_scratch(anno, rep_coord[0], rep_coord[1], rep_coord[2])
n_l.setComment('gt node; {}'.format(labels[kk]))
if not pred_correct:
synssv_o.mesh2kzip(mapped_synssv_objects_kzip, ext_color=None, ply_name='{}.ply'.format(synssv_id))
anno.addNode(n_l)
anno.addEdge(n, n_l)
pbar.update(1)
pbar.close()
skel.add_annotation(anno)
skel.to_kzip(mapped_synssv_objects_kzip)
if rfc_path_out is not None:
joblib.dump(rfc, rfc_path_out)
log.info(f'Wrote parameters of trained RFC to "{rfc_path_out}".')
else:
log.info('Working directory and rfc_path_out not set - trained RFC was not dumped to file.')
return rfc, v_features, v_labels
def synssv_o_features(synssv_o: segmentation.SegmentationObject) -> list:
"""
Collects the 'syn_ssv' features used for synapse prediction with the RFC.
Args:
synssv_o: 'syn_ssv' SegmentationObject with a loaded attribute dict.
Returns:
List of features; the order matches :func:`synssv_o_featurenames`.
"""
features = [synssv_o.size, synssv_o.mesh_area]
partner_ids = synssv_o.attr_dict["neuron_partners"]
for i_partner_id, partner_id in enumerate(partner_ids):
features.append(synssv_o.attr_dict["n_mi_objs_%d" % i_partner_id])
features.append(synssv_o.attr_dict["n_mi_vxs_%d" % i_partner_id])
features.append(synssv_o.attr_dict["min_dst_mi_nm_%d" % i_partner_id])
features.append(synssv_o.attr_dict["n_vc_objs_%d" % i_partner_id])
features.append(synssv_o.attr_dict["n_vc_vxs_%d" % i_partner_id])
features.append(synssv_o.attr_dict["min_dst_vc_nm_%d" % i_partner_id])
return features
def synssv_o_featurenames() -> list:
return ['size_vx', 'mesh_area_um2', 'n_mi_objs_neuron1',
'n_mi_vxs_neuron1', 'min_dst_mi_nm_neuron1', 'n_vc_objs_neuron1', 'n_vc_vxs_neuron1',
'min_dst_vc_nm_neuron1', 'n_mi_objs_neuron2', 'n_mi_vxs_neuron2',
'min_dst_mi_nm_neuron2', 'n_vc_objs_neuron2', 'n_vc_vxs_neuron2', 'min_dst_vc_nm_neuron2']
def export_matrix(obj_version: Optional[str] = None, dest_folder: Optional[str] = None,
threshold_syn: float = 0, export_kzip: bool = False, log: Optional[Logger] = None):
"""
Writes .csv and optionally .kzip (large memory consumption) summary file of connectivity matrix.
Args:
obj_version (str):
dest_folder : Destination folder for the csv (and optional kzip) files.
threshold_syn :
Probability threshold applied to filter synapses. Defaults to 0, i.e. exporting all synapses;
if set to None, the config value 'thresh_synssv_proba' is used.
export_kzip:
Export connectivity matrix as kzip - high memory consumption.
log: Logger.
"""
if threshold_syn is None:
threshold_syn = global_params.config['cell_objects']['thresh_synssv_proba']
if dest_folder is None:
dest_folder = global_params.config.working_dir + '/connectivity_matrix/'
if log is None:
log = log_extraction
os.makedirs(os.path.split(dest_folder)[0], exist_ok=True)
dest_name = dest_folder + '/conn_mat'
log.info(f'Starting export of connectivity matrix with minimum probability {threshold_syn} as csv file to "{dest_name}".')
sd_syn_ssv = segmentation.SegmentationDataset("syn_ssv", working_dir=global_params.config.working_dir,
version=obj_version)
syn_prob = sd_syn_ssv.load_numpy_data("syn_prob")
m = syn_prob > threshold_syn
m_axs = sd_syn_ssv.load_numpy_data("partner_axoness")[m]
m_cts = sd_syn_ssv.load_numpy_data("partner_celltypes")[m]
m_sp = sd_syn_ssv.load_numpy_data("partner_spiness")[m]
m_coords = sd_syn_ssv.rep_coords[m]
# m_sizes = sd_syn_ssv.sizes[m]
m_sizes = sd_syn_ssv.load_numpy_data("mesh_area")[m] / 2
m_ssv_partners = sd_syn_ssv.load_numpy_data("neuron_partners")[m]
m_syn_prob = syn_prob[m]
m_syn_sign = sd_syn_ssv.load_numpy_data("syn_sign")[m]
m_syn_asym_ratio = sd_syn_ssv.load_numpy_data("syn_type_sym_ratio")[m]
m_spineheadvol = sd_syn_ssv.load_numpy_data("partner_spineheadvol")[m]
m_latent_morph = sd_syn_ssv.load_numpy_data("latent_morph")[m] # N, 2, m
m_latent_morph = m_latent_morph.reshape(len(m_latent_morph), -1) # N, 2*m
# (loop of skeleton node generation)
# make sure cache-arrays have ndim == 2, TODO: check when writing cached arrays
m_sizes = np.multiply(m_sizes, m_syn_sign).squeeze()[:, None] # N, 1
m_axs = m_axs.squeeze() # N, 2
m_sp = m_sp.squeeze() # N, 2
m_syn_prob = m_syn_prob.squeeze()[:, None] # N, 1
table = np.concatenate([m_coords, m_ssv_partners, m_sizes, m_axs, m_cts,
m_sp, m_syn_prob, m_spineheadvol, m_latent_morph], axis=1)
# do not overwrite previous files
if os.path.isfile(dest_name + '.csv'):
st = datetime.datetime.fromtimestamp(time.time()).strftime('%Y-%m-%d %H:%M:%S')
os.rename(dest_name + '.csv', '{}_{}.csv'.format(dest_name, st))
np.savetxt(dest_name + ".csv", table, delimiter="\t",
header="x\ty\tz\tssv1\tssv2\tsize\tcomp1\tcomp2\tcelltype1\t"
"celltype2\tspiness1\tspiness2\tsynprob\tspinehead_vol1"
"\tspinehead_vol2" +
"".join(["\tlatentmorph1_{}".format(ix) for ix in range(
global_params.config['tcmn']['ndim_embedding'])]) +
"".join(["\tlatentmorph2_{}".format(ix) for ix in range(
global_params.config['tcmn']['ndim_embedding'])])
)
# # super high memory consumption for j0251
# wiring, borders = generate_wiring_array(log=log, thresh_syn_prob=threshold_syn, syn_version=obj_version)
# plot_wiring(f'{dest_folder}', wiring, borders, borders, log=log)
# plot_cumul_wiring(f'{dest_folder}', wiring, borders, min_cumul_synarea=0, log=log)
if export_kzip:
ax_labels = np.array(["N/A", "D", "A", "S"]) # TODO: this is already defined in handler.multiviews!
ax_label_ids = np.array([-1, 0, 1, 2])
annotations = []
m_sizes = np.abs(m_sizes)
ms_axs = np.sort(m_axs, axis=1)
# transform labels 3 and 4 to 1 (bouton and terminal to axon to apply correct filter)
ms_axs[ms_axs == 3] = 1
ms_axs[ms_axs == 4] = 1
# vigra currently requires numpy==1.11.1
try:
u_axs = np.unique(ms_axs, axis=0)
except TypeError: # in case numpy < 1.13
u_axs = np.vstack({tuple(row) for row in ms_axs})
for u_ax in u_axs:
anno = skeleton.SkeletonAnnotation()
anno.scaling = sd_syn_ssv.scaling
cmt = "{} - {}".format(ax_labels[ax_label_ids == u_ax[0]][0],
ax_labels[ax_label_ids == u_ax[1]][0])
anno.comment = cmt
for i_syn in np.where(np.sum(np.abs(ms_axs - u_ax), axis=1) == 0)[0]:
c = m_coords[i_syn]
# somewhat approximated from sphere volume:
r = np.power(m_sizes[i_syn] / 3., 1 / 3.)
# r = m_sizes[i_syn]
skel_node = skeleton.SkeletonNode().from_scratch(anno, c[0], c[1], c[2],
radius=r)
skel_node.data["ids"] = m_ssv_partners[i_syn]
skel_node.data["size"] = m_sizes[i_syn]
skel_node.data["syn_prob"] = m_syn_prob[i_syn]
skel_node.data["sign"] = m_syn_sign[i_syn]
skel_node.data["in_ex_frac"] = m_syn_asym_ratio[i_syn]
skel_node.data['sp'] = m_sp[i_syn]
skel_node.data['ct'] = m_cts[i_syn]
skel_node.data['ax'] = m_axs[i_syn]
skel_node.data['latent_morph'] = m_latent_morph[i_syn]
anno.addNode(skel_node)
annotations.append(anno)
# do not overwrite previous files
if os.path.isfile(dest_name + '.k.zip'):
st = datetime.datetime.fromtimestamp(time.time()).strftime('%Y-%m-%d %H:%M:%S')
os.rename(dest_name + '.k.zip', '{}_{}.k.zip'.format(dest_name, st))
skeleton_utils.write_skeleton(dest_name + ".k.zip", annotations)