# -*- coding: utf-8 -*-
# SyConn - Synaptic connectivity inference toolkit
#
# Copyright (c) 2016 - now
# Max-Planck-Institute of Neurobiology, Munich, Germany
# Authors: Philipp Schubert, Joergen Kornfeld
from . import log_proc
from .. import global_params
from ..handler import basics
from ..mp import batchjob_utils as qu
from ..mp import mp_utils as sm
from ..proc.meshes import mesh_creator_sso
from ..reps import segmentation, super_segmentation
from ..reps.segmentation_helper import prepare_so_attr_cache
from ..reps.super_segmentation import SuperSegmentationObject, SuperSegmentationDataset
from typing import Iterable, Tuple
import numpy as np
import tqdm
from collections import Counter
from typing import Optional, List
from logging import Logger
def aggregate_segmentation_object_mappings(ssd: SuperSegmentationDataset, obj_types: List[str],
                                           n_jobs: Optional[int] = None, nb_cpus: Optional[int] = None):
    """
    Aggregate per-supervoxel object mappings into the attribute dictionaries of all SSVs.

    The SSV IDs of ``ssd`` are ordered by descending ``'size'`` attribute, split into ``n_jobs``
    blocks and every block is handed to ``_aggregate_segmentation_object_mappings_thread``, either
    through local multiprocessing or, if BatchJob is enabled, through the
    ``"aggregate_segmentation_object_mappings"`` batch job script.

    Args:
        ssd (SuperSegmentationDataset): Dataset whose SSV attribute dictionaries get populated.
        obj_types (List[str]): Object type identifiers to aggregate, e.g. ``['mi', 'sj', 'vc']``.
            Each one (and ``"sv"``) has to be a key of ``ssd.version_dict``; this is checked
            with ``assert``.
        n_jobs (Optional[int]): Number of blocks the SSV IDs are split into. Defaults to
            ``2 * global_params.config.ncore_total``.
        nb_cpus (Optional[int]): Forwarded as ``n_cores`` to the batch job script, or as
            ``nb_cpus`` to the local multiprocessing pool.
    """
    for requested_type in obj_types:
        assert requested_type in ssd.version_dict
    assert "sv" in ssd.version_dict

    if n_jobs is None:
        n_jobs = global_params.config.ncore_total * 2

    # Order the SSV IDs by descending 'size' attribute before chunking them.
    all_ssv_ids = ssd.ssv_ids
    descending_size_order = np.argsort(ssd.load_numpy_data('size'))[::-1]
    id_blocks = basics.chunkify(all_ssv_ids[descending_size_order], n_jobs)
    multi_params = [
        (id_block, ssd.version, ssd.version_dict, ssd.working_dir, obj_types, ssd.type)
        for id_block in id_blocks
    ]

    if qu.batchjob_enabled():
        _ = qu.batchjob_script(multi_params, "aggregate_segmentation_object_mappings", n_cores=nb_cpus,
                               remove_jobfolder=True)
    else:
        _ = sm.start_multiprocess_imap(_aggregate_segmentation_object_mappings_thread, multi_params,
                                       debug=False, nb_cpus=nb_cpus)
def _aggregate_segmentation_object_mappings_thread(args):
    """
    Worker of ``aggregate_segmentation_object_mappings`` for one block of SSV IDs.

    For every SSV in the block, the ``mapping_<type>_ids`` / ``mapping_<type>_ratios`` attributes
    of all its supervoxels are summed per object ID and the result is written to the SSV
    attribute dictionary under the same two keys.

    Args:
        args: Sequence holding, in this order: the SSV IDs to process, the SSD version, the SSD
            version dictionary, the SSD working directory, the object types to aggregate and
            the SSD type.

    Raises:
        KeyError: If a mapping attribute is not available for a supervoxel of a processed SSV.
    """
    ssv_obj_ids, version, version_dict = args[0], args[1], args[2]
    working_dir, obj_types, ssd_type = args[3], args[4], args[5]

    ssd = super_segmentation.SuperSegmentationDataset(working_dir, version, ssd_type=ssd_type,
                                                      version_dict=version_dict)
    svids = np.concatenate([ssd.mapping_dict[ssvid] for ssvid in ssv_obj_ids])
    # Drop the reference to the mapping dictionary; it is not used below (presumably to save memory).
    ssd._mapping_dict = None

    # Collect the supervoxel attributes that have to be cached, two per object type.
    so_attr_of_interest = [
        attr_key
        for obj_type in obj_types
        for attr_key in (f"mapping_{obj_type}_ids", f"mapping_{obj_type}_ratios")
    ]
    attr_cache = prepare_so_attr_cache(segmentation.SegmentationDataset('sv', config=ssd.config), svids,
                                       so_attr_of_interest)

    for ssv_id in ssv_obj_ids:
        ssv = ssd.get_super_segmentation_object(ssv_id)
        ssv.load_attr_dict()
        summed_ratios = {obj_type: Counter() for obj_type in obj_types}
        for svid in ssv.sv_ids:
            for obj_type in obj_types:
                ids_key = f"mapping_{obj_type}_ids"
                ratios_key = f"mapping_{obj_type}_ratios"
                try:
                    sv_obj_ids = attr_cache[ids_key][svid]
                    sv_obj_ratios = attr_cache[ratios_key][svid]
                    # Counter addition sums the ratios of object IDs shared between supervoxels.
                    summed_ratios[obj_type] += Counter(dict(zip(sv_obj_ids, sv_obj_ratios)))
                except KeyError:
                    raise KeyError(f'Could not find attribute "{ids_key}" for '
                                   f'cell supervoxel {svid} during "_aggregate_segmentation_object_mappings_thread".')
        for obj_type, ratio_counter in summed_ratios.items():
            ssv.attr_dict[f"mapping_{obj_type}_ids"] = list(ratio_counter.keys())
            ssv.attr_dict[f"mapping_{obj_type}_ratios"] = list(ratio_counter.values())
        ssv.save_attr_dict()
def apply_mapping_decisions(ssd: SuperSegmentationDataset,
                            obj_types: List[str], n_jobs: Optional[int] = None,
                            nb_cpus: Optional[int] = None):
    """
    Apply the mapping thresholds to all SSVs and store the accepted object IDs per object type.

    The SSV IDs of ``ssd`` are ordered by descending ``'size'`` attribute, split into ``n_jobs``
    blocks and every block is processed by ``_apply_mapping_decisions_thread``, either through
    local multiprocessing or, if BatchJob is enabled, through the ``"apply_mapping_decisions"``
    batch job script. The aggregated mapping attributes are expected to exist already (see
    ``aggregate_segmentation_object_mappings``).

    Args:
        ssd (SuperSegmentationDataset): Dataset whose SSV attribute dictionaries get populated.
        obj_types (List[str]): Object type identifiers to be mapped to the cells, e.g.
            ``['mi', 'sj', 'vc']``. Each one has to be a key of ``ssd.version_dict``; this is
            checked with ``assert``.
        n_jobs (Optional[int]): Number of blocks the SSV IDs are split into. Defaults to
            ``2 * global_params.config.ncore_total``.
        nb_cpus (Optional[int]): Forwarded as ``n_cores`` to the batch job script, or as
            ``nb_cpus`` to the local multiprocessing pool.
    """
    for obj_type in obj_types:
        assert obj_type in ssd.version_dict
    if n_jobs is None:
        n_jobs = global_params.config.ncore_total * 2
    multi_params = basics.chunkify(ssd.ssv_ids[np.argsort(ssd.load_numpy_data('size'))[::-1]], n_jobs)
    multi_params = [(ssv_id_block, ssd.version, ssd.version_dict, ssd.working_dir,
                     obj_types, ssd.type) for ssv_id_block in multi_params]
    if not qu.batchjob_enabled():
        # Honor `nb_cpus` for local multiprocessing as well (it used to be ignored here),
        # consistent with `aggregate_segmentation_object_mappings`.
        _ = sm.start_multiprocess_imap(_apply_mapping_decisions_thread, multi_params,
                                       nb_cpus=nb_cpus)
    else:
        _ = qu.batchjob_script(
            multi_params, "apply_mapping_decisions", n_cores=nb_cpus, remove_jobfolder=True)
def _get_mapping_thresholds(cell_objects_dc, obj_type):
    """Return ``(lower_ratio, upper_ratio, sizethreshold)`` configured for ``obj_type``.

    Args:
        cell_objects_dc: The ``'cell_objects'`` section of the dataset configuration.
        obj_type (str): Object type identifier used as key in the threshold dictionaries.

    Returns:
        Tuple of lower mapping ratio, upper mapping ratio (``1.`` if not configured) and
        size threshold.

    Raises:
        ValueError: If the lower mapping ratio or the size threshold is not configured.
    """
    try:
        lower_ratio = cell_objects_dc["lower_mapping_ratios"][obj_type]
    except KeyError as err:
        msg = "Lower ratio undefined."
        log_proc.critical(msg)
        raise ValueError(msg) from err
    try:
        upper_ratio = cell_objects_dc["upper_mapping_ratios"][obj_type]
    except KeyError:
        log_proc.error("Upper ratio undefined - 1. assumed.")
        upper_ratio = 1.
    try:
        sizethreshold = cell_objects_dc["sizethresholds"][obj_type]
    except KeyError as err:
        msg = "Size threshold undefined."
        log_proc.critical(msg)
        raise ValueError(msg) from err
    return lower_ratio, upper_ratio, sizethreshold


def _filter_mapped_objects(ssv, obj_type, obj_sd, lower_ratio, upper_ratio, sizethreshold):
    """Write the object IDs of ``obj_type`` that pass all thresholds to ``ssv.attr_dict[obj_type]``.

    An object is kept if its aggregated mapping ratio is above ``lower_ratio``, not above
    ``upper_ratio`` (only checked if ``upper_ratio < 1``) and its size is above ``sizethreshold``.

    Raises:
        ValueError: If the aggregated mapping ratios or IDs are missing in ``ssv.attr_dict``.
    """
    ratios_key = f"mapping_{obj_type}_ratios"
    ids_key = f"mapping_{obj_type}_ids"
    if ratios_key not in ssv.attr_dict:
        msg = f"No mapping ratios found in SSV {ssv}."
        log_proc.error(msg)
        raise ValueError(msg)
    if ids_key not in ssv.attr_dict:
        msg = f"No mapping ids found in SSV {ssv}."
        log_proc.error(msg)
        raise ValueError(msg)
    obj_ratios = np.array(ssv.attr_dict[ratios_key])
    id_mask = obj_ratios > lower_ratio
    if upper_ratio < 1.:
        id_mask[obj_ratios > upper_ratio] = False
    candidate_ids = np.array(ssv.attr_dict[ids_key])[id_mask]
    ssv.attr_dict[obj_type] = [
        candidate_id for candidate_id in candidate_ids
        if obj_sd.get_segmentation_object(candidate_id).size > sizethreshold
    ]


def _apply_mapping_decisions_thread(args):
    """
    Worker of ``apply_mapping_decisions`` for one block of SSV IDs.

    For every SSV, missing basic attributes (``'sv'``, ``'rep_coord'``, ``'bounding_box'``,
    ``'size'``) are generated and saved. Then, for every object type, the aggregated mapping
    (``mapping_<type>_ids`` / ``mapping_<type>_ratios``) is filtered with the lower/upper ratio
    and size thresholds from ``ssd.config['cell_objects']`` and the accepted IDs are stored in
    ``attr_dict[<type>]``. SSVs without mapping ratios are collected, their mappings are
    generated by calling ``_aggregate_segmentation_object_mappings_thread`` and they are
    processed in a second round.

    Args:
        args (list): A sequence containing the following elements:
            - ssv_obj_ids (list): List of SSO ids.
            - version (str): Version of the SSD.
            - version_dict (dict): Dictionary containing versions of the SSD.
            - working_dir (str): Working directory of the SSD.
            - obj_types (list): List of object types to be mapped.
            - ssd_type (str): Type of the SSD.

    Raises:
        ValueError: If the lower ratio or size threshold of an object type is not configured, or
            if mapping ratios/IDs are still missing after the second round.
    """
    ssv_obj_ids = args[0]
    version = args[1]
    version_dict = args[2]
    working_dir = args[3]
    obj_types = args[4]
    ssd_type = args[5]
    ssd = super_segmentation.SuperSegmentationDataset(working_dir, version, ssd_type=ssd_type,
                                                      version_dict=version_dict)
    cell_objects_dc = ssd.config['cell_objects']

    # Thresholds are looked up separately for every object type. Previously the values of the
    # first object type were re-used for all following ones.
    thresholds = {}
    sd_dc = {}
    for obj_t in obj_types:
        assert obj_t in ssd.version_dict
        # cache size property of objects
        sd_dc[obj_t] = segmentation.SegmentationDataset(obj_t, config=ssd.config, version=ssd.version_dict[obj_t],
                                                        cache_properties=['size'])
        thresholds[obj_t] = _get_mapping_thresholds(cell_objects_dc, obj_t)

    missing_mapping_ids = set()
    for ssv_id in tqdm.tqdm(ssv_obj_ids, disable=True):
        ssv = ssd.get_super_segmentation_object(ssv_id)
        ssv.load_attr_dict()
        if 'sv' not in ssv.attr_dict:
            ssv.attr_dict["sv"] = ssd.mapping_dict[ssv.id]
            log_proc.warning(f"No supervoxel IDs found in SSV {ssv}, but it was possible to generate them.")
        if "rep_coord" not in ssv.attr_dict:
            ssv.attr_dict["rep_coord"] = ssv.rep_coord
            log_proc.warning(f"No rep. coord. found in SSV {ssv}, but it was possible to generate it.")
        if "bounding_box" not in ssv.attr_dict:
            ssv.attr_dict["bounding_box"] = ssv.bounding_box
            log_proc.warning(f"No bounding box found in SSV {ssv}, but it was possible to generate it.")
        if "size" not in ssv.attr_dict:
            ssv.attr_dict["size"] = ssv.size
            log_proc.warning(f"No size found in SSV {ssv}, but it was possible to generate it.")
        # save here already because sub-sequent call of '_aggregate_segmentation_object_mappings_thread'
        # does not have access to it otherwise
        ssv.save_attr_dict()

        missing_mapping_info = False
        for obj_type in obj_types:
            if f"mapping_{obj_type}_ratios" not in ssv.attr_dict:
                # Postpone this SSV to the second round; nothing modified below is saved for it.
                missing_mapping_info = True
                log_proc.warning(f"No mapping ratios found in SSV {ssv}, but it was possible "
                                 f"to perform the mapping.")
                break
            _filter_mapped_objects(ssv, obj_type, sd_dc[obj_type], *thresholds[obj_type])
        if missing_mapping_info:
            missing_mapping_ids.add(ssv.id)
        else:
            ssv.save_attr_dict()

    # TODO: this is a safety-precaution which should not be necessary at this point.
    if len(missing_mapping_ids) == 0:
        return

    # second round after finding missing mapping ratios
    func_args = (list(missing_mapping_ids), ssd.version, ssd.version_dict, ssd.working_dir, obj_types, ssd.type)
    log_proc.warning(f"Found {len(missing_mapping_ids)} SSOs without mapping info. Mapping now again.")
    _aggregate_segmentation_object_mappings_thread(func_args)
    for ssv_id in tqdm.tqdm(missing_mapping_ids, disable=True, desc='SSO (mapping)'):
        ssv = ssd.get_super_segmentation_object(ssv_id)
        ssv.load_attr_dict()
        for obj_type in obj_types:
            # Raises ValueError if the mapping attributes are still missing.
            _filter_mapped_objects(ssv, obj_type, sd_dc[obj_type], *thresholds[obj_type])
        ssv.save_attr_dict()
def map_synssv_objects(synssv_version: Optional[str] = None, log: Optional[Logger] = None,
                       nb_cpus=None, n_jobs=None, syn_threshold=None):
    """
    Map syn_ssv objects to all SSVs of the SSD located at the configured working directory.

    The SSV IDs are split into ``n_jobs`` blocks which are processed by
    ``map_synssv_objects_thread`` (local multiprocessing) or by the ``"map_synssv_objects"``
    batch job script (if BatchJob is enabled). This operation may take a while.

    Args:
        synssv_version (str, optional): Version identifier of the syn_ssv
            SegmentationDataset. Defaults to None.
        log (Logger, optional): Logger, only forwarded to the batch job script.
            Defaults to None.
        nb_cpus (int, optional): Number of CPUs, only forwarded to the local
            multiprocessing pool. Defaults to None.
        n_jobs (int, optional): Number of blocks the SSV IDs are split into.
            Defaults to ``4 * global_params.config.ncore_total``.
        syn_threshold (float, optional): Probability threshold applied during the
            mapping of syn_ssv objects. Defaults to
            ``global_params.config['cell_objects']['thresh_synssv_proba']``.
    """
    if n_jobs is None:
        n_jobs = 4 * global_params.config.ncore_total
    if syn_threshold is None:
        syn_threshold = global_params.config['cell_objects']['thresh_synssv_proba']

    ssd = SuperSegmentationDataset(global_params.config.working_dir)
    multi_params = [
        [id_block, ssd.version, ssd.version_dict, ssd.working_dir, ssd.type, synssv_version,
         syn_threshold]
        for id_block in basics.chunkify(ssd.ssv_ids, n_jobs)
    ]

    if qu.batchjob_enabled():
        _ = qu.batchjob_script(multi_params, "map_synssv_objects",
                               remove_jobfolder=True, log=log)
    else:
        _ = sm.start_multiprocess_imap(map_synssv_objects_thread, multi_params, nb_cpus=nb_cpus)
def map_synssv_objects_thread(args):
    """
    Worker of ``map_synssv_objects`` for one block of SSV IDs.

    Loads the ``neuron_partners``, ``syn_prob`` and ``id`` arrays of the syn_ssv
    SegmentationDataset, keeps the entries with a probability above ``syn_threshold`` and stores,
    for every SSV of the block, the IDs of all syn_ssv objects that list the SSV as first or
    second partner in ``attr_dict["syn_ssv"]``. Afterwards the ``'syn_ssv'`` mesh is loaded and,
    if ``global_params.config.syntype_available`` is set, ``typedsyns2mesh`` is called.

    Args:
        args (list): A sequence containing the following elements:
            - ssv_obj_ids (list): List of SSO ids.
            - version (str): Version of the SSD.
            - version_dict (dict): Dictionary containing versions of the SSD.
            - working_dir (str): Working directory of the SSD.
            - ssd_type (str): Type of the SSD.
            - synssv_version (str): Version of the syn_ssv objects.
            - syn_threshold (float): Probability threshold applied during the mapping of syn_ssv objects.
    """
    (ssv_obj_ids, version, version_dict, working_dir,
     ssd_type, synssv_version, syn_threshold) = args
    ssd = super_segmentation.SuperSegmentationDataset(working_dir, version, ssd_type=ssd_type,
                                                      version_dict=version_dict)
    syn_ssv_sd = segmentation.SegmentationDataset(obj_type="syn_ssv", working_dir=working_dir,
                                                  version=synssv_version)
    ssv_partners = syn_ssv_sd.load_numpy_data("neuron_partners")
    syn_prob = syn_ssv_sd.load_numpy_data("syn_prob")
    synssv_ids = syn_ssv_sd.load_numpy_data("id")

    # Restrict IDs and partner pairs to the synapses above the probability threshold.
    is_confident = syn_prob > syn_threshold
    synssv_ids = synssv_ids[is_confident]
    ssv_partners = ssv_partners[is_confident]

    for ssv_id in ssv_obj_ids:
        # enable cache of syn_ssv SegmentationObjects, including their meshes
        # -> reused in typedsyns2mesh call
        ssv = ssd.get_super_segmentation_object(ssv_id, caching=True)
        ssv.load_attr_dict()
        ids_as_first_partner = synssv_ids[np.in1d(ssv_partners[:, 0], ssv.id)]
        ids_as_second_partner = synssv_ids[np.in1d(ssv_partners[:, 1], ssv.id)]
        ssv.attr_dict["syn_ssv"] = np.concatenate([ids_as_first_partner, ids_as_second_partner])
        ssv.save_attr_dict()
        # cache syn_ssv mesh and typed meshes if available
        ssv.load_mesh('syn_ssv')
        if global_params.config.syntype_available:
            ssv.typedsyns2mesh()
def mesh_proc_ssv(working_dir: str, version: Optional[str] = None,
                  ssd_type: str = 'ssv', nb_cpus: Optional[int] = None):
    """
    Run ``mesh_creator_sso`` on every SSV of a SuperSegmentationDataset using local
    multiprocessing.

    Args:
        working_dir (str): Path to the working directory.
        version (str, optional): Version identifier of the SSD, like 'spgt' for the
            spine ground truth SSD. Defaults to None.
        ssd_type (str): Type of the SSD. Default is 'ssv'.
        nb_cpus (int, optional): Number of CPUs, forwarded to the local
            multiprocessing pool. Defaults to None.
    """
    dataset = super_segmentation.SuperSegmentationDataset(
        working_dir=working_dir, version=version, ssd_type=ssd_type)
    # Materialize all SSV objects up front; each one is a single work item.
    all_ssvs = list(dataset.ssvs)
    sm.start_multiprocess_imap(mesh_creator_sso, all_ssvs, nb_cpus=nb_cpus, debug=False)
def split_ssv(ssv: SuperSegmentationObject, splitted_sv_ids: Iterable[int]) \
        -> Tuple[SuperSegmentationObject, SuperSegmentationObject]:
    """
    Split a SuperSegmentationObject into two new SuperSegmentationObjects.

    The first new object receives all supervoxels of ``ssv`` that are not in
    ``splitted_sv_ids``, the second one receives ``splitted_sv_ids``. Both are created with
    ``init_ssv`` using unused IDs from ``get_available_ssv_ids``. The original ``ssv`` is not
    modified or removed.

    Args:
        ssv (SuperSegmentationObject): The SuperSegmentationObject to be split. Its ``_ssd``
            attribute has to be set.
        splitted_sv_ids (Iterable[int]): The supervoxel IDs to split off. All of them have to
            be part of ``ssv.sv_ids``.

    Returns:
        Tuple[SuperSegmentationObject, SuperSegmentationObject]: The two new
        SuperSegmentationObjects resulting from the split.

    Raises:
        ValueError: If ``ssv._ssd`` is None or if ``splitted_sv_ids`` contains IDs that are not
            part of ``ssv``.
    """
    if ssv._ssd is None:
        raise ValueError('SSV dataset has to be defined. Use "get_superseg'
                         'mentation_object" method to instantiate SSO objects,'
                         ' or assign "_dataset".')
    ssd = ssv._ssd
    orig_ids = set(ssv.sv_ids)
    # TODO: Support ssv.rag splitting
    splitted_sv_ids = set(splitted_sv_ids)
    # Reject IDs that do not belong to this SSV (the check used to be inverted and rejected
    # exactly the valid inputs).
    if not splitted_sv_ids.issubset(orig_ids):
        raise ValueError('All splitted SV IDs have to be part of the SSV.')
    set1 = orig_ids.difference(splitted_sv_ids)
    set2 = splitted_sv_ids
    # TODO: run SSD modification methods, e.g. cached numpy arrays holding SSV attributes
    # TODO: run contactsite modification methods, e.g. change all contactsites which SSV partners contain ssv.id etc.
    # TODO: run all classification models
    new_id1, new_id2 = list(get_available_ssv_ids(ssd, n=2))
    ssv1 = init_ssv(new_id1, list(set1), ssd=ssd)
    ssv2 = init_ssv(new_id2, list(set2), ssd=ssd)
    # TODO: add ignore flag or destroy original SSV in its SSD.
    return ssv1, ssv2
def init_ssv(ssv_id: int, sv_ids: List[int], ssd: SuperSegmentationDataset) \
        -> SuperSegmentationObject:
    """
    Create a new SuperSegmentationObject inside the given dataset and preprocess it.

    The object is instantiated with ``create=True`` using the version and working directory
    of ``ssd``; afterwards its ``preprocess`` method is called (presumably requires the
    supervoxel and cell organelle SegmentationDatasets to exist already).

    Args:
        ssv_id (int): The ID for the new SuperSegmentationObject.
        sv_ids (List[int]): The supervoxel IDs for the new SuperSegmentationObject.
        ssd (SuperSegmentationDataset): The dataset providing version and working
            directory for the new object.

    Returns:
        SuperSegmentationObject: The newly initialized SuperSegmentationObject.
    """
    sso_kwargs = dict(sv_ids=sv_ids, version=ssd.version, create=True,
                      working_dir=ssd.working_dir)
    new_sso = SuperSegmentationObject(ssv_id, **sso_kwargs)
    new_sso.preprocess()
    return new_sso
def get_available_ssv_ids(ssd, n=2):
    """
    Generate the ``n`` smallest non-negative integers that are not used as SSV IDs in ``ssd``.

    Candidates are tested in ascending order starting at 0. Exactly ``n`` IDs are always
    produced, also when all IDs up to the current maximum are taken or when the dataset
    contains no SSV at all.

    Args:
        ssd: The SuperSegmentationDataset to check for available IDs (only ``ssd.ssv_ids``
            is read).
        n (int, optional): The number of IDs to generate. Defaults to 2.

    Yields:
        int: The next available SuperSegmentationObject ID.
    """
    # Build the lookup once; testing membership in the ID array for every candidate is O(N) each.
    used_ids = set(np.asarray(ssd.ssv_ids).ravel().tolist())
    n_found = 0
    candidate = 0
    # The previous `range(max_id + n)` bound yielded only n - 1 IDs if 0..max_id were all in use.
    while n_found < n:
        if candidate not in used_ids:
            n_found += 1
            yield candidate
        candidate += 1