# -*- coding: utf-8 -*-
# SyConn - Synaptic connectivity inference toolkit
#
# Copyright (c) 2016 - now
# Max Planck Institute of Neurobiology, Martinsried, Germany
# Authors: Philipp Schubert, Joergen Kornfeld
import glob
import os
import pickle as pkl
import shutil
import networkx as nx
import numpy as np
import scipy.ndimage
import skimage.segmentation
from knossos_utils import chunky
from .block_processing_C import relabel_vol
from .. import global_params
from ..handler import basics, log_handler, compression
from ..handler.basics import kd_factory
from ..mp import batchjob_utils as qu, mp_utils as sm
from ..proc.general import cut_array_in_one_dim
from ..proc.image import apply_morphological_operations, get_aniso_struct
try:
import vigra
from vigra.filters import gaussianSmoothing, distanceTransform
from vigra.analysis import watershedsNew
except ImportError as e:
gaussianSmoothing = None
log_handler.error('ImportError. Could not import VIGRA. '
'`object_segmentation` will not be possible. {}'.format(e))
def gauss_threshold_connected_components(*args, **kwargs):
    """Backward-compatible alias for :func:`object_segmentation`."""
    return object_segmentation(*args, **kwargs)
def object_segmentation(cset, filename, hdf5names, overlap="auto", sigmas=None,
                        thresholds=None, chunk_list=None, debug=False,
                        swapdata=False, prob_kd_path_dict=None,
                        membrane_filename=None, membrane_kd_path=None,
                        hdf5_name_membrane=None, fast_load=False, suffix="",
                        nb_cpus=None, transform_func=None,
                        transform_func_kwargs=None, transf_func_kd_overlay=None,
                        load_from_kd_overlaycubes=False, n_chunk_jobs=None):
    """
    Extracts connected components from probability maps.

    By default the following procedure is used:

    1. Gaussian filter (defined by sigma)
    2. Thresholding (defined by threshold)
    3. Connected components analysis

    If `transform_func` is set, the specified method will be applied by every
    worker on the chunk's probability map to generate the segmentation instead.
    Add `transform_func_kwargs` in case `transform_func` needs specific
    arguments. In case of vesicle clouds (hdf5_name in ["p4", "vc"]) the
    membrane segmentation is used to cut connected vesicle clouds across cells
    apart (only if a membrane segmentation is provided).

    Args:
        cset: ChunkDataset instance.
        filename (str): Filename of the prediction in the ChunkDataset.
        hdf5names (list): Names/labels to be extracted and processed from the
            prediction file.
        overlap (str or np.ndarray): Overlap with neighbouring chunks that is
            left for later processing steps; if 'auto' the overlap is
            calculated from the sigmas, the stitch overlap (here: [1., 1., 1.])
            and the number of binary erosions configured in
            ``global_params.config['cell_objects']['extract_morph_op']``.
        sigmas (list or None): Sigmas of the gaussian filters applied to the
            probability maps. Has to be the same length as hdf5names. If None
            no gaussian filter is applied.
        thresholds (list of float or np.ndarray): Threshold for cutting the
            probability map. Has to be the same length as hdf5names. If None
            zeros are used instead (not recommended!).
        chunk_list (list): Selective list of chunks for which this function
            should work on. If None all chunks are used.
        debug (bool): Intended to enable single-core debugging; NOTE(review):
            currently not forwarded to the worker pool (``debug=False`` is
            passed below).
        swapdata (bool): If True an x-z swap is applied to the data prior to
            processing.
        prob_kd_path_dict: Maps hdf5 names to KnossosDataset paths to load the
            probability/segmentation data from.
        membrane_filename (str): One way to allow access to a membrane
            segmentation when processing vesicle clouds. Filename of the
            prediction in the ChunkDataset. The threshold is currently set
            at 0.4.
        membrane_kd_path (str): One way to allow access to a membrane
            segmentation when processing vesicle clouds. Path to the
            KnossosDataset containing a membrane segmentation. The threshold
            is currently set at 0.4.
        hdf5_name_membrane (str): When using `membrane_filename` this key has
            to be given to access the data in the saved chunk.
        fast_load (bool): If True the chunk data is blindly loaded without
            checking for enough offset to compute the overlap area. Faster,
            because no neighbouring chunk has to be accessed.
        suffix (str): Suffix for the intermediate results.
        nb_cpus: Number of cores per worker.
        transform_func (callable): Segmentation method which is applied.
        transform_func_kwargs (dict): Keyword arguments for `transform_func`.
        load_from_kd_overlaycubes (bool): Load prob/seg data from overlay
            cubes instead of raw cubes.
        transf_func_kd_overlay: Method which is applied to cube data if
            `load_from_kd_overlaycubes` is True.
        n_chunk_jobs: Number of total jobs.

    Returns:
        tuple: ``(results_as_list, [overlap, stitch_overlap])`` where
        `results_as_list` contains the number of connected components per
        chunk and hdf5 name.
    """
    if transform_func is None:
        transform_func = _object_segmentation_thread
    if thresholds is None:
        thresholds = np.zeros(len(hdf5names))
    if sigmas is None:
        sigmas = np.zeros(len(hdf5names))
    if not len(sigmas) == len(thresholds) == len(hdf5names):
        raise Exception("Number of thresholds, sigmas and HDF5 names does not "
                        "match!")
    if n_chunk_jobs is None:
        n_chunk_jobs = global_params.config.ncore_total * 4
    if chunk_list is None:
        chunk_list = [ii for ii in range(len(cset.chunk_dict))]
    # deterministic shuffle to balance the expected load across jobs
    rand_ixs = np.arange(len(chunk_list))
    np.random.seed(0)
    np.random.shuffle(rand_ixs)
    chunk_list = np.array(chunk_list)[rand_ixs]
    chunk_blocks = basics.chunkify(chunk_list, n_chunk_jobs)
    # BUGFIX: use equality instead of identity -- `overlap is "auto"` relies
    # on string interning and emits a SyntaxWarning on Python >= 3.8. The
    # isinstance guard also avoids element-wise comparison if an ndarray
    # overlap was passed in.
    if isinstance(overlap, str) and overlap == "auto":
        # `sigmas` is guaranteed to be set at this point (zeros if None was
        # passed), so no extra None check is needed here.
        max_sigma = np.array([np.max(sigmas)] * 3)
        # 4 sigma covers the effective support of the gaussian kernel
        overlap = np.ceil(max_sigma * 4)
        morph_ops = global_params.config['cell_objects']['extract_morph_op']
        scaling = global_params.config['scaling']
        aniso = scaling[2] // scaling[0]
        n_erosions = 0
        for v in morph_ops.values():
            v = np.array(v)
            # factor 2: erodes both sides; aniso: morphological operation
            # kernel is laterally increased by this factor
            n_erosions = max(n_erosions, 2 * aniso * np.sum(v == 'binary_erosion'))
        overlap = np.max([overlap, [n_erosions, n_erosions, n_erosions // aniso]], axis=0).astype(np.int32)
    stitch_overlap = np.max([overlap.copy(), [1, 1, 1]], axis=0)
    multi_params = []
    for chunk_sub in chunk_blocks:
        multi_params.append(
            [[cset.chunk_dict[nb_chunk] for nb_chunk in chunk_sub],
             cset.path_head_folder, filename, hdf5names, overlap,
             sigmas, thresholds, swapdata, prob_kd_path_dict,
             membrane_filename, membrane_kd_path,
             hdf5_name_membrane, fast_load, suffix, transform_func_kwargs,
             load_from_kd_overlaycubes, transf_func_kd_overlay])
    if not qu.batchjob_enabled():
        results = sm.start_multiprocess_imap(transform_func, multi_params, nb_cpus=nb_cpus, debug=False,
                                             use_dill=True)
        results_as_list = []
        for result in results:
            for entry in result:
                results_as_list.append(entry)
    else:
        assert transform_func == _object_segmentation_thread, "batch jobs currently only supported for " \
                                                              "`_object_segmentation_thread`."
        path_to_out = qu.batchjob_script(
            multi_params, "object_segmentation", n_cores=nb_cpus, use_dill=True, suffix=filename)
        out_files = glob.glob(path_to_out + "/*")
        results_as_list = []
        for out_file in out_files:
            with open(out_file, 'rb') as f:
                for entry in pkl.load(f):
                    results_as_list.append(entry)
        # remove the whole batchjob folder, not only the out files
        shutil.rmtree(os.path.abspath(path_to_out + "/../"), ignore_errors=True)
    return results_as_list, [overlap, stitch_overlap]
def _object_segmentation_thread(args):
    """
    Default worker of :func:`object_segmentation`. Performs a gaussian blur
    with subsequent thresholding to extract connected components of a
    probability map. Result summaries are returned and connected components
    are stored as .h5 files (one ``*_connected_components*.h5`` per chunk).

    TODO: Add generic '_segmentation_thread' to enable a clean support of
        custom-made segmentation functions passed to 'object_segmentation' via
        'transform_func'-kwargs.

    Args:
        args (list): Positional payload packed by :func:`object_segmentation`;
            see the unpacking below for the exact order.

    Returns:
        list of lists: Per chunk and hdf5 name: ``[chunk.number, hdf5_name,
        max_label]`` (number of connected components found in the chunk).
    """
    chunks = args[0]
    path_head_folder = args[1]
    filename = args[2]
    hdf5names = args[3]
    overlap = args[4]
    sigmas = args[5]
    thresholds = args[6]
    swapdata = args[7]
    prob_kd_path_dict = args[8]
    membrane_filename = args[9]
    membrane_kd_path = args[10]
    hdf5_name_membrane = args[11]
    fast_load = args[12]
    suffix = args[13]
    # NOTE(review): `transform_func_kwargs` is unpacked but not used in this
    # default worker; it is only relevant for custom `transform_func`s.
    transform_func_kwargs = args[14]
    load_from_kd_overlaycubes = args[15]
    transf_func_kd_overlay = args[16]
    # e.g. {'sj': ['binary_closing', 'binary_opening'], 'mi': [], 'cell': []}
    morph_ops = global_params.config['cell_objects']['extract_morph_op']
    min_seed_vx = global_params.config['cell_objects']['min_seed_vx']
    scaling = np.array(global_params.config['scaling'])
    # anisotropic structuring element for the morphological operations
    struct = get_aniso_struct(scaling)
    nb_cc_list = []
    for chunk in chunks:
        # load the chunk plus `overlap` voxels on every side
        box_offset = np.array(chunk.coordinates) - np.array(overlap)
        size = np.array(chunk.size) + 2 * np.array(overlap)
        if swapdata:
            size = basics.switch_array_entries(size, [0, 2])
        if prob_kd_path_dict is not None:
            bin_data_dict = {}
            if load_from_kd_overlaycubes:  # enable possibility to load from overlay cubes as well
                data_k = None
                # if all hdf5 names point to the same KD, load the cube once
                exp_value = next(iter(prob_kd_path_dict.values()))
                all_equal = all(v == exp_value for v in prob_kd_path_dict.values())
                if all_equal:
                    kd = kd_factory(prob_kd_path_dict[hdf5names[0]])
                    data_k = kd.load_seg(size=size, offset=box_offset, mag=1).swapaxes(0, 2)
                for kd_key in hdf5names:
                    if not all_equal:
                        kd = kd_factory(prob_kd_path_dict[kd_key])
                        data_k = kd.load_seg(size=size, offset=box_offset, mag=1).swapaxes(0, 2)
                    if transf_func_kd_overlay is not None:
                        # per-key transformation of the overlay data
                        bin_data_dict[kd_key] = transf_func_kd_overlay[kd_key](data_k)
                    else:
                        bin_data_dict[kd_key] = data_k
            else:  # load raw
                for kd_key in prob_kd_path_dict.keys():
                    kd = kd_factory(prob_kd_path_dict[kd_key])
                    bin_data_dict[kd_key] = kd.load_raw(size=size, offset=box_offset,
                                                        mag=1).swapaxes(0, 2)
        else:
            if not fast_load:
                # stitches the overlap area from neighbouring chunks
                cset = chunky.load_dataset(path_head_folder)
                bin_data_dict = cset.from_chunky_to_matrix(size, box_offset,
                                                           filename, hdf5names)
            else:
                bin_data_dict = compression.load_from_h5py(chunk.folder + filename + ".h5",
                                                           hdf5names, as_dict=True)
        labels_data = []
        for nb_hdf5_name in range(len(hdf5names)):
            hdf5_name = hdf5names[nb_hdf5_name]
            tmp_data = bin_data_dict[hdf5_name]
            tmp_data_shape = tmp_data.shape
            # crop symmetrically to chunk size + 2 * overlap if the loaded
            # cube is larger than requested
            offset = (np.array(tmp_data_shape) - np.array(chunk.size) -
                      2 * np.array(overlap)) / 2
            offset = offset.astype(np.int32)
            if np.any(offset < 0):
                offset = np.array([0, 0, 0])
            tmp_data = tmp_data[offset[0]: tmp_data_shape[0] - offset[0],
                                offset[1]: tmp_data_shape[1] - offset[1],
                                offset[2]: tmp_data_shape[2] - offset[2]]
            if np.sum(sigmas[nb_hdf5_name]) != 0:
                tmp_data = gaussianSmoothing(tmp_data, sigmas[nb_hdf5_name])
            # vesicle clouds: zero-out voxels on membranes to cut clouds that
            # span multiple cells apart (membrane threshold fixed at 0.4)
            if hdf5_name in ["p4", "vc"] and membrane_filename is not None and hdf5_name_membrane is not None:
                membrane_data = compression.load_from_h5py(chunk.folder + membrane_filename + ".h5",
                                                           hdf5_name_membrane)[0]
                membrane_data_shape = membrane_data.shape
                offset = (np.array(membrane_data_shape) - np.array(tmp_data.shape)) / 2
                membrane_data = membrane_data[offset[0]: membrane_data_shape[0] - offset[0],
                                              offset[1]: membrane_data_shape[1] - offset[1],
                                              offset[2]: membrane_data_shape[2] - offset[2]]
                tmp_data[membrane_data > 255 * .4] = 0
                del membrane_data
            elif hdf5_name in ["p4", "vc"] and membrane_kd_path is not None:
                kd_bar = kd_factory(membrane_kd_path)
                membrane_data = kd_bar.load_raw(size=size, offset=box_offset,
                                                mag=1).swapaxes(0, 2)
                tmp_data[membrane_data > 255 * .4] = 0
                del membrane_data
            # binarize; skipped for overlay cubes which already hold labels
            if thresholds[nb_hdf5_name] != 0 and not load_from_kd_overlaycubes:
                tmp_data = np.array(tmp_data > thresholds[nb_hdf5_name], dtype=np.uint8)
            if hdf5_name in morph_ops:
                if 'binary_erosion' in morph_ops[hdf5_name]:
                    # ops before the first erosion are applied to the mask,
                    # ops from the first erosion on produce watershed seeds
                    first_erosion_ix = morph_ops[hdf5_name].index('binary_erosion')
                    tmp_data = apply_morphological_operations(tmp_data.copy(), morph_ops[hdf5_name][:first_erosion_ix],
                                                              mop_kwargs=dict(structure=struct))
                    # apply erosion operations to generate watershed seeds
                    markers = apply_morphological_operations(tmp_data.copy(),
                                                             morph_ops[hdf5_name][first_erosion_ix:],
                                                             mop_kwargs=dict(structure=struct))
                    markers = scipy.ndimage.label(markers)[0].astype(np.uint32)
                    # remove small fragments and 0; this might also delete objects bigger than min_size as
                    # this threshold is applied after N binary erosion!
                    if hdf5_name in min_seed_vx and min_seed_vx[hdf5_name] > 1:
                        min_size = min_seed_vx[hdf5_name]
                        ixs, cnt = np.unique(markers, return_counts=True)
                        m = (ixs != 0) & (cnt < min_size)
                        ixs_del = np.sort(ixs[m])
                        ixs_keep = np.sort(ixs[~m])
                        # set small objects to 0
                        label_m = {ix_del: 0 for ix_del in ixs_del}
                        # fill "holes" in ID space with to-be-kept object IDs
                        ii = len(ixs_keep) - 1
                        for ix_del in ixs_del:
                            if (ix_del > ixs_keep[ii]) or (ixs_keep[ii] == 0) or (ii < 0):
                                break
                            label_m[ixs_keep[ii]] = ix_del
                            ii -= 1
                        # in-place modification of markers array
                        relabel_vol(markers, label_m)
                    # seeded watershed on the (anisotropy-aware) distance
                    # transform, restricted to the thresholded mask
                    distance = distanceTransform(tmp_data.astype(np.uint32, copy=False), background=False,
                                                 pixel_pitch=scaling.astype(np.uint32))
                    this_labels_data = skimage.segmentation.watershed(-distance, markers, mask=tmp_data)
                    max_label = np.max(this_labels_data)
                else:
                    mop_data = apply_morphological_operations(tmp_data.copy(), morph_ops[hdf5_name],
                                                              mop_kwargs=dict(structure=struct))
                    this_labels_data, max_label = scipy.ndimage.label(mop_data)
            else:
                this_labels_data, max_label = scipy.ndimage.label(tmp_data)
            nb_cc_list.append([chunk.number, hdf5_name, max_label])
            labels_data.append(this_labels_data)
        h5_fname = chunk.folder + filename + "_connected_components%s.h5" % suffix
        os.makedirs(os.path.split(h5_fname)[0], exist_ok=True)
        compression.save_to_h5py(labels_data, h5_fname, hdf5names)
        del labels_data
    return nb_cc_list
def make_unique_labels(cset, filename, hdf5names, chunk_list, max_nb_dict,
                       chunk_translator, debug, suffix="",
                       n_chunk_jobs=None, nb_cpus=1):
    """
    Makes labels unique across chunks by offsetting every chunk's labels.

    Args:
        cset: ChunkDataset instance.
        filename (str): Filename of the prediction in the ChunkDataset.
        hdf5names (list of str): Names/labels to be extracted and processed
            from the prediction file.
        chunk_list (list of int): Selective list of chunks for which this
            function should work on. If None all chunks are used.
        max_nb_dict (dict): Maps each chunk id to an integer which needs to be
            added to all its entries.
        chunk_translator (dict): Remapping from chunk ids to position in
            `chunk_list`.
        debug (bool): If True multiprocessed steps only operate on one core
            using 'map' which allows for better error messages.
        suffix (str): Suffix for the intermediate results.
        n_chunk_jobs (int): Number of total jobs.
        nb_cpus (int): Number of cores per worker.
    """
    if n_chunk_jobs is None:
        n_chunk_jobs = global_params.config.ncore_total
    multi_params_glob = []
    for chunk_sub in basics.chunkify(chunk_list, n_chunk_jobs):
        job_params = []
        for nb_chunk in chunk_sub:
            # per-chunk label offsets for every hdf5 name
            offsets = {hdf5_name: max_nb_dict[hdf5_name][chunk_translator[nb_chunk]]
                       for hdf5_name in hdf5names}
            job_params.append([cset.chunk_dict[nb_chunk], filename, hdf5names,
                               offsets, suffix])
        multi_params_glob.append(job_params)
    if qu.batchjob_enabled():
        _ = qu.batchjob_script(
            multi_params_glob, "make_unique_labels", suffix=filename,
            remove_jobfolder=True, n_cores=nb_cpus)
    else:
        _ = sm.start_multiprocess_imap(_make_unique_labels_thread,
                                       multi_params_glob, debug=debug)
def _make_unique_labels_thread(func_args):
    """
    Worker of :func:`make_unique_labels`. Adds the per-chunk offset to all
    non-zero labels and writes the result to ``*_unique_components*.h5``.
    """
    for chunk, filename, hdf5names, this_max_nb_dict, suffix in func_args:
        cc_data_list = compression.load_from_h5py(
            chunk.folder + filename + "_connected_components%s.h5" % suffix, hdf5names)
        for ix, hdf5_name in enumerate(hdf5names):
            # uint64 so the offset addition cannot overflow
            shifted = cc_data_list[ix].astype(np.uint64)
            # keep background (0) untouched
            shifted[shifted > 0] += this_max_nb_dict[hdf5_name]
            cc_data_list[ix] = shifted
        compression.save_to_h5py(cc_data_list, chunk.folder + filename + "_unique_components%s.h5" % suffix, hdf5names)
def make_stitch_list(cset, filename, hdf5names, chunk_list, stitch_overlap,
                     overlap, debug, suffix="", nb_cpus=None,
                     overlap_thresh=0, n_chunk_jobs=None):
    """
    Creates a stitch list for the overlap region between chunks.

    Args:
        cset: ChunkDataset instance.
        filename (str): Filename of the prediction in the ChunkDataset.
        hdf5names (list of str): Names/labels to be extracted and processed
            from the prediction file.
        chunk_list (list of int): Selective list of chunks for which this
            function should work on. If None all chunks are used.
        stitch_overlap (np.ndarray): Overlap with neighbouring chunks that is
            left for stitching.
        overlap (np.ndarray): Overlap with neighbouring chunks that is left
            for later processing steps.
        debug (bool): If True multiprocessed steps only operate on one core
            using 'map' which allows for better error messages.
        suffix (str): Suffix for the intermediate results.
        nb_cpus (int): Number of cores used per worker.
        overlap_thresh (float): Overlap fraction of object in different chunks
            to be considered stitched. If zero this behavior is disabled.
        n_chunk_jobs (int): Number of total jobs.

    Returns:
        dict: Per hdf5 name a list of pairs of overlapping component ids.
    """
    if n_chunk_jobs is None:
        n_chunk_jobs = global_params.config.ncore_total
    multi_params = [[cset.path_head_folder, block, filename, hdf5names,
                     stitch_overlap, overlap,
                     suffix, chunk_list, overlap_thresh]
                    for block in basics.chunkify(chunk_list, n_chunk_jobs)]
    stitch_list = {hdf5_name: [] for hdf5_name in hdf5names}
    if not qu.batchjob_enabled():
        results = sm.start_multiprocess_imap(_make_stitch_list_thread,
                                             multi_params, debug=debug)
        for result in results:
            for hdf5_name in hdf5names:
                stitch_list[hdf5_name].extend(result[hdf5_name])
    else:
        path_to_out = qu.batchjob_script(multi_params, "make_stitch_list",
                                         suffix=filename, n_cores=nb_cpus)
        for out_file in glob.glob(path_to_out + "/*"):
            with open(out_file, 'rb') as f:
                result = pkl.load(f)
            for hdf5_name in hdf5names:
                stitch_list[hdf5_name].extend(result[hdf5_name])
        shutil.rmtree(os.path.abspath(path_to_out + "/../"), ignore_errors=True)
    return stitch_list
def _make_stitch_list_thread(args):
    """
    Worker of :func:`make_stitch_list`. Collects pairs of object IDs whose
    voxels touch in the stitch-overlap area of neighbouring chunks.

    Args:
        args (list): Positional payload packed by :func:`make_stitch_list`;
            see the unpacking below for the exact order.

    Returns:
        dict: Per hdf5 name a list of ``(id_a, id_b)`` pairs (sorted tuples)
        of components considered to be the same object.
    """
    from scipy import spatial  # hoisted out of the inner loop (was re-imported per pair)
    cpath_head_folder = args[0]
    nb_chunks = args[1]
    filename = args[2]
    hdf5names = args[3]
    stitch_overlap = args[4]
    overlap = args[5]
    suffix = args[6]
    chunk_list = args[7]
    overlap_thresh = args[8]
    map_dict = {hdf5_name: set() for hdf5_name in hdf5names}
    cset = chunky.load_dataset(cpath_head_folder)
    for nb_chunk in nb_chunks:
        chunk = cset.chunk_dict[nb_chunk]
        cc_data_list = compression.load_from_h5py(chunk.folder + filename +
                                                  "_unique_components%s.h5" % suffix, hdf5names)
        # TODO: optimize get_neighbouring_chunks
        neighbours, pos = cset.get_neighbouring_chunks(chunk, chunklist=chunk_list,
                                                       con_mode=7)
        # Compare only half of the 6-neighborhood for every chunk, which
        # suffices to cover all overlap areas. Checking all neighbors for
        # every chunk would lead to twice and redundant computational load.
        keep = np.any(pos > 0, axis=1)
        neighbours = neighbours[keep]
        pos = pos[keep]
        for ii in range(3):
            if neighbours[ii] != -1:
                compare_chunk = cset.chunk_dict[neighbours[ii]]
                cc_data_list_to_compare = \
                    compression.load_from_h5py(
                        compare_chunk.folder + filename + "_unique_components%s.h5" % suffix, hdf5names)
                cc_area = {}
                cc_area_to_compare = {}
                # contact dimension (perpendicular to the contact plane);
                # renamed from `id` which shadowed the builtin
                cd = np.argmax(pos[ii])
                for nb_hdf5_name in range(len(hdf5names)):
                    this_cc_data = cc_data_list[nb_hdf5_name]
                    this_cc_data_to_compare = \
                        cc_data_list_to_compare[nb_hdf5_name]
                    # slabs of both chunks that cover the same world-space area
                    cc_area[nb_hdf5_name] = \
                        cut_array_in_one_dim(
                            this_cc_data,
                            -overlap[cd] - stitch_overlap[cd],
                            -overlap[cd] + stitch_overlap[cd], cd)
                    cc_area_to_compare[nb_hdf5_name] = \
                        cut_array_in_one_dim(
                            this_cc_data_to_compare,
                            overlap[cd] - stitch_overlap[cd],
                            overlap[cd] + stitch_overlap[cd], cd)
                for nb_hdf5_name in range(len(hdf5names)):
                    hdf5_name = hdf5names[nb_hdf5_name]
                    # voxels labeled as foreground in both chunks
                    stitch_ixs = np.transpose(np.nonzero((cc_area[nb_hdf5_name] != 0) &
                                                         (cc_area_to_compare[nb_hdf5_name] != 0)))
                    ignore_ids = set()  # if already inspected and overlap is insufficient
                    for stitch_pos in stitch_ixs:
                        stitch_pos = tuple(stitch_pos)
                        this_id = cc_area[nb_hdf5_name][stitch_pos]
                        compare_id = cc_area_to_compare[nb_hdf5_name][stitch_pos]
                        pair = tuple(sorted([this_id, compare_id]))
                        if (pair not in map_dict[hdf5_name]) and (pair not in ignore_ids):
                            if overlap_thresh > 0:
                                # world-space coordinates of both objects
                                obj_coord_intern = np.transpose(np.nonzero(cc_data_list[nb_hdf5_name] == this_id))
                                obj_coord_intern_compare = np.transpose(
                                    np.nonzero(cc_data_list_to_compare[nb_hdf5_name] == compare_id))
                                c1 = chunk.coordinates - chunk.overlap + obj_coord_intern + np.array([1, 1, 1])
                                c2 = compare_chunk.coordinates - compare_chunk.overlap + obj_coord_intern_compare + np.array(
                                    [1, 1, 1])
                                kdt = spatial.cKDTree(c1)
                                dists, ixs = kdt.query(c2)
                                match_vx = np.sum(dists == 0)
                                # relative overlap: shared voxels over mean object size
                                match_vx_rel = 2 * float(match_vx) / (len(c1) + len(c2))
                                # BUGFIX: compare against the configured
                                # `overlap_thresh` instead of a hardcoded 0.1,
                                # which silently ignored the parameter.
                                if match_vx_rel > overlap_thresh:
                                    map_dict[hdf5_name].add(pair)
                                else:
                                    ignore_ids.add(pair)
                            else:
                                map_dict[hdf5_name].add(pair)
    # convert to lists for pickling / downstream consumption
    for k, v in map_dict.items():
        map_dict[k] = list(v)
    return map_dict
def make_merge_list(hdf5names, stitch_list, max_labels):
    """
    Creates a merge list from a stitch list by mapping all connected ids to
    one representative id.

    Args:
        hdf5names (list of str): Names/labels to be extracted and processed
            from the prediction file.
        stitch_list (dict): Contains pairs of overlapping component ids for
            each hdf5 name.
        max_labels (dict): Contains the number of different component ids for
            each hdf5 name.

    Returns:
        tuple: ``(merge_dict, merge_list_dict)`` — a per-hdf5-name mapping
        from component id to representative id, and a per-hdf5-name lookup
        array (index = old id, value = new id).
    """
    merge_dict = {}
    merge_list_dict = {}
    for hdf5_name in hdf5names:
        graph = nx.from_edgelist(stitch_list[hdf5_name])
        id_map = {}
        # identity lookup table; connected ids get overwritten below
        lookup = np.arange(max_labels[hdf5_name] + 1)
        for component in nx.connected_components(graph):
            component = list(component)
            representative = component[0]
            for member in component:
                id_map[member] = representative
                lookup[member] = representative
        merge_dict[hdf5_name] = id_map
        merge_list_dict[hdf5_name] = lookup
    return merge_dict, merge_list_dict
def apply_merge_list(cset, chunk_list, filename, hdf5names, merge_list_dict,
                     debug, suffix="", n_chunk_jobs=None, nb_cpus=1):
    """
    Applies merge list to all chunks.

    Args:
        cset: ChunkDataset instance.
        chunk_list (list of int): Selective list of chunks for which this
            function should work on. If None all chunks are used.
        filename (str): Filename of the prediction in the ChunkDataset.
        hdf5names (list of str): Names/labels to be extracted and processed
            from the prediction file.
        merge_list_dict (dict): Per-hdf5-name lookup array mapping old ids to
            merged ids.
        debug (bool): Intended for single-core debugging; NOTE(review):
            currently not forwarded to the worker pool.
        suffix (str): Suffix for the intermediate results.
        n_chunk_jobs (int): Number of total jobs.
        nb_cpus: Number of cores per worker.
    """
    merge_list_dict_path = cset.path_head_folder + "merge_list_dict.pkl"
    # BUGFIX: use a context manager so the file handle is closed even if
    # pickling fails (was a bare open/close pair).
    with open(merge_list_dict_path, "wb") as f:
        pkl.dump(merge_list_dict, f, protocol=4)
    if n_chunk_jobs is None:
        n_chunk_jobs = global_params.config.ncore_total * 2
    chunk_blocks = basics.chunkify(chunk_list, n_chunk_jobs)
    multi_params = []
    for chunk_block in chunk_blocks:
        multi_params.append([[cset.chunk_dict[nb_chunk] for nb_chunk in chunk_block],
                             filename, hdf5names, merge_list_dict_path, suffix])
    if not qu.batchjob_enabled():
        sm.start_multiprocess_imap(_apply_merge_list_thread, multi_params)
    else:
        qu.batchjob_script(
            multi_params, "apply_merge_list", suffix=filename, n_cores=nb_cpus,
            remove_jobfolder=True)
def _apply_merge_list_thread(args):
    """
    Worker of :func:`apply_merge_list`. Crops each chunk's unique-component
    volume to the chunk core, relabels it via the merge lookup table and
    writes the result to ``*_stitched_components*.h5``.

    Args:
        args (list): ``[chunks, filename, hdf5names, merge_list_dict_path,
            postfix]``.
    """
    chunks = args[0]
    filename = args[1]
    hdf5names = args[2]
    merge_list_dict_path = args[3]
    postfix = args[4]
    # BUGFIX: close the pickle file explicitly; `pkl.load(open(...))` leaked
    # the file handle.
    with open(merge_list_dict_path, 'rb') as f:
        merge_list_dict = pkl.load(f)
    for chunk in chunks:
        cc_data_list = compression.load_from_h5py(
            chunk.folder + filename + "_unique_components%s.h5" % postfix, hdf5names)
        for nb_hdf5_name in range(len(hdf5names)):
            hdf5_name = hdf5names[nb_hdf5_name]
            this_cc = cc_data_list[nb_hdf5_name]
            id_changer = merge_list_dict[hdf5_name]
            this_shape = this_cc.shape
            offset = (np.array(this_shape) - chunk.size) // 2  # offset needs to be integer
            # crop the symmetric overlap margin; only the chunk core is kept
            this_cc = this_cc[offset[0]: this_shape[0] - offset[0],
                              offset[1]: this_shape[1] - offset[1],
                              offset[2]: this_shape[2] - offset[2]]
            # fancy indexing with the lookup table relabels all voxels at once
            this_cc = id_changer[this_cc]
            cc_data_list[nb_hdf5_name] = this_cc
        compression.save_to_h5py(cc_data_list,
                                 chunk.folder + filename +
                                 "_stitched_components%s.h5" % postfix,
                                 hdf5names)
def export_cset_to_kd_batchjob(target_kd_paths, cset, name, hdf5names, n_cores=1,
                               offset=None, size=None, stride=(4 * 128, 4 * 128, 4 * 128),
                               overwrite=False, as_raw=False, fast_downsampling=False,
                               n_max_job=None, unified_labels=False, orig_dtype=np.uint8, log=None,
                               compresslevel=None):
    """
    Batchjob version of :class:`knossos_utils.chunky.ChunkDataset.export_cset_to_kd`
    method, see ``knossos_utils.chunky`` for details.

    Notes:
        * KnossosDataset needs to be initialized beforehand (see
          :func:`~KnossosDataset.initialize_without_conf`).
        * Only works if data mag = 1.

    Args:
        target_kd_paths: Target KnossosDataset paths, one per hdf5 name.
        cset: Source ChunkDataset.
        name: Filename of the data in the ChunkDataset.
        hdf5names: Names/labels to export.
        n_cores: Cores per batchjob worker.
        offset: Export start coordinate; defaults to the dataset origin.
        size: Export extent; defaults to the KnossosDataset boundary.
        stride: Cube size processed per work item.
        overwrite:
        as_raw: Write raw (image) data instead of segmentation.
        fast_downsampling:
        n_max_job: Number of total jobs.
        unified_labels:
        orig_dtype: dtype used when loading from the ChunkDataset.
        log: Logger instance.
        compresslevel: Compression level in case segmentation data is written
            (seg.sz.zip files).

    Returns:
        None.
    """
    if n_max_job is None:
        n_max_job = global_params.config.ncore_total
    target_kds = {hdf5name: kd_factory(target_kd_paths[hdf5name])
                  for hdf5name in hdf5names}
    # use any KD to infer the boundary; all KDs must agree on it
    ref_boundary = target_kds[hdf5names[0]].boundary
    for hdf5name in hdf5names[1:]:
        assert np.all(ref_boundary == target_kds[hdf5name].boundary), \
            "KnossosDataset boundaries differ."
    if offset is None or size is None:
        offset = np.zeros(3, dtype=np.int32)
        size = np.copy(ref_boundary)
    # regular grid of cube origins covering the export volume
    grid = [np.array([coordx, coordy, coordz])
            for coordx in range(offset[0], offset[0] + size[0], stride[0])
            for coordy in range(offset[1], offset[1] + size[1], stride[1])
            for coordz in range(offset[2], offset[2] + size[2], stride[2])]
    coord_blocks = basics.chunkify(grid, n_max_job)
    multi_params = [[coords, stride, cset.path_head_folder, target_kd_paths, name,
                     hdf5names, as_raw, unified_labels, 1, orig_dtype,
                     fast_downsampling, overwrite,
                     compresslevel] for coords in coord_blocks]
    job_suffix = "_" + "_".join(hdf5names)
    qu.batchjob_script(
        multi_params, "export_cset_to_kds", n_cores=n_cores, remove_jobfolder=True,
        suffix=job_suffix, log=log)
def _export_cset_as_kds_thread(args):
    """
    Helper of :func:`export_cset_to_kd_batchjob`: loads one cube from the
    ChunkDataset and writes it to the target KnossosDataset(s).

    TODO: refactor.

    Args:
        args (list): Positional payload; see the unpacking below for the
            exact order.
    """
    coords = args[0]
    size = np.array(args[1])
    cset_path = args[2]
    target_kd_paths = args[3]
    name = args[4]
    hdf5names = args[5]
    as_raw = args[6]
    unified_labels = args[7]
    nb_threads = args[8]  # unused, kept for payload compatibility
    orig_dtype = args[9]
    fast_downsampling = args[10]
    overwrite = args[11]
    compresslevel = args[12]
    cset = chunky.load_dataset(cset_path, update_paths=True)
    # Backwards compatibility: a single KD path string for a single hdf5 name.
    # BUGFIX: use isinstance instead of `type(...) is str`.
    if isinstance(target_kd_paths, str) and len(hdf5names) == 1:
        kd = kd_factory(target_kd_paths)
        target_kds = {hdf5names[0]: kd}
    else:  # Add proper case handling for incorrect arguments
        target_kds = {}
        for hdf5name in hdf5names:
            path = target_kd_paths[hdf5name]
            target_kd = kd_factory(path)
            target_kds[hdf5name] = target_kd
    # clip the cube to the dataset boundary
    for dim in range(3):
        if coords[dim] + size[dim] > cset.box_size[dim]:
            size[dim] = cset.box_size[dim] - coords[dim]
    data_dict = cset.from_chunky_to_matrix(size, coords, name, hdf5names,
                                           dtype=orig_dtype)
    for hdf5name in hdf5names:
        curr_d = data_dict[hdf5name]
        # float probability maps in (0, 1] are rescaled to uint8
        if (curr_d.dtype.kind not in ("u", "i")) and (0 < np.max(curr_d) <= 1.0):
            curr_d = (curr_d * 255).astype(np.uint8)
        data_dict[hdf5name] = []  # free the reference early to limit peak memory
        data_list = curr_d
        # make it ZYX
        data_list = np.swapaxes(data_list, 0, 2)
        kd = target_kds[hdf5name]
        if as_raw:
            kd.save_raw(offset=coords, mags=kd.available_mags, data=data_list, data_mag=1,
                        fast_resampling=fast_downsampling)
        else:
            kd.save_seg(offset=coords, mags=kd.available_mags, data=data_list, data_mag=1,
                        fast_resampling=fast_downsampling, compresslevel=compresslevel)