# -*- coding: utf-8 -*-
# SyConn - Synaptic connectivity inference toolkit
#
# Copyright (c) 2016 - now
# Max Planck Institute of Neurobiology, Martinsried, Germany
# Authors: Sven Dorkenwald, Philipp Schubert, Joergen Kornfeld
import copy
import os
import time
from . import log_reps
from . import segmentation
from .rep_helper import assign_rep_values, colorcode_vertices, surface_samples
from .segmentation_helper import load_skeleton, find_missing_sv_views, \
find_missing_sv_attributes, find_missing_sv_skeletons, load_so_attr_bulk
from .. import global_params
from ..handler.basics import kd_factory, flatten_list
from ..handler.multiviews import generate_rendering_locs
from ..mp.mp_utils import start_multiprocess_obj, start_multiprocess_imap
from ..proc.graphs import create_graph_from_coords, stitch_skel_nx
from ..proc.meshes import write_mesh2kzip
from ..proc.rendering import render_sso_coords
from ..proc.sd_proc import predict_views
from ..extraction.block_processing_C import relabel_vol_nonexist2zero
from ..extraction.in_bounding_boxC import in_bounding_box
from typing import Dict, List, Union, Optional, Tuple, TYPE_CHECKING, Any
if TYPE_CHECKING:
from . import super_segmentation
from ..reps.super_segmentation import SuperSegmentationObject, SuperSegmentationDataset
from ..reps.segmentation import SegmentationObject
from collections.abc import Iterable
from collections import Counter
from multiprocessing.pool import ThreadPool
import networkx as nx
from numba import jit
import numpy as np
import scipy
import scipy.ndimage
from scipy import spatial
from skimage.segmentation import watershed
from skimage.feature import peak_local_max
from scipy import ndimage
from knossos_utils.skeleton_utils import annotation_to_nx_graph, \
load_skeleton as load_skeleton_kzip, Skeleton, SkeletonAnnotation, SkeletonNode
try:
from knossos_utils import mergelist_tools
except ImportError:
from knossos_utils import mergelist_tools_fallback as mergelist_tool
def majority_vote(anno, prop, max_dist):
    """
    Smooth the ``<prop>_pred`` labels of an annotation in place by majority vote.

    For every node, all nodes within a path length of ``max_dist`` (computed by
    :func:`nodes_in_pathlength` on an unmodified deep copy of ``anno``) vote on
    its new label. Label 2 never takes part in the vote. For
    ``prop == "axoness"``, nodes that are already labelled 2 (soma) keep their
    label.

    Args:
        anno: SkeletonAnnotation
            The annotation object containing the skeleton. Modified in place.
        prop: str
            The property to smooth, e.g. 'axoness'. The node data key
            ``prop + '_pred'`` is read and written.
        max_dist: int
            The maximum path length (in nm) of the sliding window used for
            majority voting.

    Returns:
        None
    """
    pred_key = prop + '_pred'
    # Neighbourhoods and votes are taken from a copy so that labels which were
    # already smoothed during this pass do not influence the remaining nodes.
    old_anno = copy.deepcopy(anno)
    nearest_nodes_list = nodes_in_pathlength(old_anno, max_dist)
    for nodes in nearest_nodes_list:
        # the first entry of each neighbourhood is its source node
        curr_node_id = nodes[0].getID()
        new_node = anno.getNodeByID(curr_node_id)
        if prop == "axoness":
            if int(new_node.data["axoness_pred"]) == 2:
                new_node.data["axoness_pred"] = 2
                continue
        property_val = [int(n.data[pred_key]) for n in nodes
                        if int(n.data[pred_key]) != 2]
        if not property_val:
            # Every node in the window is labelled 2, so there is nothing to
            # vote on. Keep the current label instead of failing on an empty vote.
            continue
        counter = Counter(property_val)
        new_ax = counter.most_common(1)[0][0]
        new_node.setDataElem(pred_key, new_ax)
def nodes_in_pathlength(anno, max_path_len):
    """
    Collect, for every node of an annotation, all nodes whose path length from
    that node is at most ``max_path_len``.

    The path length is the sum of Euclidean edge lengths between the scaled
    node coordinates (``getCoordinate_scaled``) along the shortest path in the
    skeleton graph. It is *not* the total length of all edges visited by a
    traversal.

    Args:
        anno: AnnotationObject
            The annotation object containing the nodes.
        max_path_len: float
            The maximum path length from the source node (inclusive).

    Returns:
        List[List[SkeletonNode]]: One list per node in ``anno.getNodes()`` (same
        order). The first element of each list is the source node itself,
        followed by the other reachable nodes in order of increasing path length.
    """
    skel_graph = annotation_to_nx_graph(anno)
    # Scaled coordinates are cached lazily, so each node is converted at most once.
    coord_cache = {}

    def _coord(node):
        if node not in coord_cache:
            coord_cache[node] = np.asarray(node.getCoordinate_scaled(),
                                           dtype=np.float64)
        return coord_cache[node]

    def _edge_length(u, v, _edge_attrs):
        return float(np.linalg.norm(_coord(v) - _coord(u)))

    list_reachable_nodes = []
    for source_node in anno.getNodes():
        # Dijkstra with a cutoff only explores nodes within ``max_path_len``.
        path_lengths = nx.single_source_dijkstra_path_length(
            skel_graph, source_node, cutoff=max_path_len, weight=_edge_length)
        reachable_nodes = [source_node]
        reachable_nodes.extend(n for n in path_lengths if n is not source_node)
        list_reachable_nodes.append(reachable_nodes)
    return list_reachable_nodes
def predict_sso_celltype(sso: 'super_segmentation.SuperSegmentationObject', model: Any, nb_views_model: int = 20,
                         use_syntype=True, overwrite: bool = False, pred_key_appendix="", da_equals_tan: bool = True,
                         n_classes: int = 7, save_to_attr_dict: bool = True):
    """
    Predict the cell type of a cell reconstruction from its cached views.

    The view subsets are built by :func:`~sso_views_to_modelinput` and normalized
    with ``naive_view_normalization_new``. If ``use_syntype`` is set and
    ``global_params.config.syntype_available`` is true, the model additionally
    receives, for every subset, the two synapse-sign ratios returned by
    ``syn_sign_ratio_celltype(sso, comp_types=[1])`` and
    ``syn_sign_ratio_celltype(sso, comp_types=[0])``. The final label is the
    most frequent arg-max over all subset predictions. To predict without
    cached views see `celltype_of_sso_nocache`.

    Args:
        sso: SuperSegmentationObject
            The cell reconstruction to predict the cell type for.
        model: Any
            Object providing ``predict_proba``.
        nb_views_model: int
            Number of views per model input.
        use_syntype: bool
            Whether to feed the synapse-sign ratios to the model.
        overwrite: bool
            If False and a prediction already exists under the prediction key,
            nothing is done.
        pred_key_appendix: str
            Suffix appended to the prediction key ``"celltype_cnn_e3"``.
        da_equals_tan: bool
            If True, the probability in column 6 is added to column 1 and
            column 6 is removed afterwards. Requires ``n_classes == 7``.
        n_classes: int
            Number of classes after the optional merge.
        save_to_attr_dict: bool
            If True, the label, probabilities and certainty are persisted via
            ``sso.save_attributes``. They are always written to ``sso.attr_dict``.

    Returns:
        None

    Raises:
        ValueError: If a predicted label is ``>= n_classes``.
    """
    sso.load_attr_dict()
    pred_key = "celltype_cnn_e3" + pred_key_appendix
    prediction_exists = not overwrite and pred_key in sso.attr_dict
    if prediction_exists:
        return
    from ..handler.prediction import naive_view_normalization_new
    model_views = naive_view_normalization_new(sso_views_to_modelinput(sso, nb_views_model))
    if global_params.config.syntype_available and use_syntype:
        ratio_pair = [syn_sign_ratio_celltype(sso, comp_types=[1, ]),
                      syn_sign_ratio_celltype(sso, comp_types=[0, ])]
        synsign_ratio = np.array([ratio_pair] * len(model_views))
        probas = model.predict_proba((model_views, synsign_ratio))
    else:
        probas = model.predict_proba(model_views)
    # DA and TAN are type modulatory; if this is changed, also change
    # `certainty_celltype` and `celltype_of_sso_nocache`.
    if da_equals_tan:
        assert n_classes == 7, 'Incompatible number of classes for cell type prediction with "da_equals_tan=True".'
        # accumulate the TAN evidence into DA, then drop the TAN column
        probas[:, 1] += probas[:, 6]
        probas = np.delete(probas, [6], axis=1)
        # INT is now at index 6 -> label 6 is INT
    subset_labels = np.argmax(probas, axis=1)
    if np.max(subset_labels) >= n_classes:
        raise ValueError('Unknown cell type predicted.')
    # Majority vote: the label with the highest count across all view subsets.
    vote_counts = np.bincount(subset_labels, minlength=n_classes)
    pred = np.argmax(vote_counts)
    sso.attr_dict[pred_key] = pred
    sso.attr_dict[f"{pred_key}_probas"] = probas
    cert = sso.certainty_celltype(pred_key)
    sso.attr_dict[f"{pred_key}_certainty"] = cert
    if save_to_attr_dict:
        sso.save_attributes([pred_key, f"{pred_key}_probas", f"{pred_key}_certainty"], [pred, probas, cert])
def radius_correction_found_vertices(sso: 'super_segmentation.SuperSegmentationObject',
                                     plump_factor: int = 1, num_found_vertices: int = 10):
    """
    Estimate skeleton node diameters from the distance to the closest mesh vertices.

    For every node, the ``num_found_vertices`` nearest mesh vertices are looked up
    in a KD-tree built on ``sso.mesh[1]``. The node coordinates are multiplied by
    ``sso.scaling`` before the query. The diameter is set to
    ``median(distances) * 2 / 10`` and finally multiplied by ``plump_factor``.
    The existing ``sso.skeleton['diameters']`` array is overwritten in place
    before it is replaced by the scaled copy.

    Args:
        sso: SuperSegmentationObject
            Object providing ``skeleton`` (with 'nodes' and 'diameters'),
            ``mesh`` and ``scaling``.
        plump_factor: int
            Multiplication factor applied to the estimated diameters.
        num_found_vertices: int
            Number of closest vertices queried per node.

    Returns:
        The updated ``sso.skeleton`` dictionary.
    """
    nodes = sso.skeleton['nodes']
    node_diameters = sso.skeleton['diameters']
    mesh_vertices = sso.mesh[1].reshape((-1, 3))
    vertex_tree = spatial.cKDTree(mesh_vertices)
    neighbour_dists, _ = vertex_tree.query(nodes * sso.scaling, num_found_vertices)
    # one row of distances per node (a scalar per node if num_found_vertices == 1)
    for node_ix, dists_of_node in enumerate(neighbour_dists):
        node_diameters[node_ix] = np.median(dists_of_node) * 2 / 10
    sso.skeleton['diameters'] = node_diameters * plump_factor
    return sso.skeleton
def get_sso_axoness_from_coord(sso, coord, k=5):
    """
    Return the most frequent ``axoness`` label among the ``k`` skeleton nodes
    closest to ``coord``.

    Both ``coord`` and the skeleton nodes are multiplied by ``sso.scaling``
    before the nearest-neighbour query. Neighbours that do not exist (the
    skeleton has fewer than ``k`` nodes) are ignored. Ties are resolved in
    favour of the label encountered first, i.e. the closer node.

    Args:
        sso: SuperSegmentationObject
            Object providing ``scaling``, ``load_skeleton()`` and ``skeleton``
            with keys 'nodes' and 'axoness'.
        coord: np.array
            The unscaled coordinate to query.
        k: int
            The number of nearest neighbours used for the vote.

    Returns:
        int: The majority label (0 for dendrite, 1 for axon, or 2 for soma).
    """
    scaling = np.array(sso.scaling)
    query_coord = np.array(coord) * scaling
    sso.load_skeleton()
    node_tree = spatial.cKDTree(sso.skeleton["nodes"] * scaling)
    dists, node_ixs = node_tree.query(query_coord, k=k)
    # missing neighbours are reported with an infinite distance
    found_ixs = node_ixs[dists != np.inf]
    label_counts = Counter(sso.skeleton["axoness"][found_ixs])
    top_label, _top_count = label_counts.most_common(n=1)[0]
    return top_label
def load_voxels_downsampled(sso, downsampling=(2, 2, 1), nb_threads=10):
    """
    Load the voxel mask of a SuperSegmentationObject at reduced resolution.

    The mask covers ``sso.bounding_box`` and is downsampled by ``downsampling``
    per axis. Each supervoxel whose voxels exist is loaded, strided by
    ``downsampling``, and OR-ed into the mask at its (downsampled) offset.

    Args:
        sso: SuperSegmentationObject
            The SuperSegmentationObject to load voxels for.
        downsampling: tuple
            The downsampling factor per axis.
        nb_threads: int
            Number of threads used to load the supervoxels. With ``<= 1`` they
            are loaded sequentially in the calling thread.

    Returns:
        np.ndarray: A downsampled boolean array representing the voxels of the
        SSO, or None if the SSO has no supervoxels.
    """
    downsampling = np.array(downsampling, dtype=np.int32)
    if len(sso.sv_ids) == 0:
        return None
    voxel_box_size = sso.bounding_box[1] - sso.bounding_box[0]
    voxel_box_size = np.ceil(voxel_box_size.astype(np.float32) / downsampling).astype(np.int32)
    # `np.bool` was removed in NumPy 1.24; the builtin is the supported alias.
    voxels = np.zeros(voxel_box_size, dtype=bool)

    def _load_sv_voxels_thread(args):
        sv_id = args[0]
        sv = segmentation.SegmentationObject(sv_id,
                                             obj_type="sv",
                                             version=sso.version_dict["sv"],
                                             working_dir=sso.working_dir,
                                             config=sso.config,
                                             voxel_caching=False)
        if not sv.voxels_exist:
            return
        # Offset of the SV inside the SSV box, in downsampled voxels. Use floor
        # division: in-place true division on an int32 array raises a TypeError.
        start = np.array(sv.bounding_box[0] - sso.bounding_box[0],
                         dtype=np.int32) // downsampling
        size = np.array(sv.bounding_box[1] - sv.bounding_box[0], dtype=np.float32)
        size = np.ceil(size / downsampling).astype(np.int32)
        stop = start + size
        sv_voxels = sv.voxels
        if isinstance(sv_voxels, int):
            # no voxel array available for this SV -> nothing to add
            return
        sv_voxels = sv_voxels[::downsampling[0],
                              ::downsampling[1],
                              ::downsampling[2]]
        # The slice is a view, so the masked assignment writes into `voxels`.
        voxels[start[0]: stop[0],
               start[1]: stop[1],
               start[2]: stop[2]][sv_voxels] = True

    multi_params = [[sv_id] for sv_id in sso.sv_ids]
    if nb_threads > 1:
        pool = ThreadPool(nb_threads)
        try:
            pool.map(_load_sv_voxels_thread, multi_params)
        finally:
            pool.close()
            pool.join()
    else:
        # An explicit loop is required: the builtin `map` is lazy and would
        # never call the loader.
        for params in multi_params:
            _load_sv_voxels_thread(params)
    return voxels
def create_new_skeleton(sv_id, sso):
    """
    Load the skeleton of a single supervoxel of a SuperSegmentationObject.

    The supervoxel is instantiated with the SSO's 'sv' version, working
    directory and config, with locking disabled.

    Args:
        sv_id (int): The ID of the supervoxel.
        sso (SuperSegmentationObject): The SuperSegmentationObject containing the
            supervoxel.

    Returns:
        Tuple[np.ndarray, np.ndarray, np.ndarray]: The 'nodes', 'diameters' and
        'edges' entries of the supervoxel skeleton as returned by
        ``load_skeleton``.
    """
    # Use the unconditionally imported `segmentation` module. The bare
    # `SegmentationObject` name is imported after the `if TYPE_CHECKING:` guard
    # and is therefore not reliably defined at runtime.
    so = segmentation.SegmentationObject(sv_id, obj_type="sv", version=sso.version_dict["sv"],
                                         working_dir=sso.working_dir, config=sso.config)
    so.enable_locking = False
    so.load_attr_dict()
    skel = load_skeleton(so)
    return skel['nodes'], skel['diameters'], skel['edges']
def convert_coord(coord_list, scal):
    """
    Swap the first two components of a coordinate, shift them by one and scale.

    Computes ``np.array([c[1] + 1, c[0] + 1, c[2] + 1]) * np.array(scal)``,
    i.e. the x and y components are exchanged. Only the first three entries of
    ``coord_list`` are used.

    Args:
        coord_list: list or np.ndarray
            The coordinate to convert (at least three entries).
        scal: numeric or np.ndarray
            The scaling factor multiplied onto the shifted coordinate.

    Returns:
        np.ndarray: The swapped, shifted and scaled coordinate.
    """
    first, second, third = coord_list[0], coord_list[1], coord_list[2]
    shifted = np.array([second + 1, first + 1, third + 1])
    return shifted * np.array(scal)
def prune_stub_branches(sso=None, nx_g=None, scal=None, len_thres=1000,
                        preserve_annotations=True):
    """
    Iteratively remove short terminal branches ("stubs") from a skeleton graph.

    A stub is the chain of nodes collected by a depth-first traversal from a tip
    (degree-1 node) up to, but excluding, the first branch node (degree > 2). It
    is removed if the distance between tip and branch node is below
    ``len_thres``. The two positions are first converted with
    :func:`convert_coord`. Sweeps are repeated until no node is removed.
    Afterwards every edge gets a ``weight`` attribute (its length after
    scaling with ``scal``) and the graph is reduced to its minimum spanning tree.

    Args:
        sso: Optional[SuperSegmentationObject]
            If given, its skeleton is replaced by the pruned graph via
            :func:`from_netkx_to_sso`. Its ID is also used in the error message.
        nx_g: networkx.Graph
            The skeleton graph. Every node needs a ``position`` attribute.
        scal: Optional[np.array of size 3]
            Scaling factor for the coordinates. Defaults to
            ``global_params.config['scaling']``.
        len_thres: int
            Stubs shorter than this are pruned (same units as ``position * scal``).
        preserve_annotations: bool
            If True, ``nx_g`` is copied and left unchanged. Otherwise it is
            pruned in place.

    Returns:
        Tuple[Optional[SuperSegmentationObject], networkx.Graph]: The updated
        ``sso`` (None if no ``sso`` was given) and the pruned minimum spanning tree.

    Raises:
        ValueError: If the pruned graph does not consist of exactly one connected
            component.
    """
    if scal is None:
        scal = global_params.config['scaling']
    new_nx_g = nx_g.copy() if preserve_annotations else nx_g
    pruning_complete = False
    while not pruning_complete:
        # Traverse a snapshot, so removals in this sweep do not affect degrees
        # or traversal.
        nx_g = new_nx_g.copy()
        # find all tip nodes, i.e. degree 1 nodes
        end_nodes = list({k for k, v in dict(nx_g.degree()).items() if v == 1})
        for end_node in end_nodes:
            prune_nodes = []
            # DFS from the tip to the first branch node
            for curr_node in nx.traversal.dfs_preorder_nodes(nx_g, end_node):
                if nx_g.degree(curr_node) > 2:
                    loc_end = convert_coord(nx_g.nodes[end_node]['position'], scal)
                    loc_curr = convert_coord(nx_g.nodes[curr_node]['position'], scal)
                    b_len = np.linalg.norm(loc_end - loc_curr)
                    if b_len < len_thres:
                        # Remove this stub, i.e. prune the nodes collected on
                        # the way to the branch point.
                        for prune_node in prune_nodes:
                            new_nx_g.remove_node(prune_node)
                    break
                prune_nodes.append(curr_node)
        if len(new_nx_g.nodes) == len(nx_g.nodes):
            pruning_complete = True
    n_cc = nx.number_connected_components(new_nx_g)
    if n_cc != 1:
        # `sso` is optional, so it must not be dereferenced unconditionally here.
        ssv_id = sso.id if sso is not None else None
        msg = 'Pruning of SV skeletons failed during "prune_stub_branches' \
              '" with {} connected components. Please check the underlying' \
              ' SSV {}. Performing stitching method to add missing edg' \
              'es recursively.'.format(n_cc, ssv_id)
        # NOTE(review): the stitched graph is discarded, since an error is
        # raised right afterwards.
        new_nx_g = stitch_skel_nx(new_nx_g)
        log_reps.critical(msg)
        raise ValueError(msg)
    for u, v in new_nx_g.edges:
        # Cast to float before subtracting. Positions may be unsigned integers
        # (e.g. uint32 skeleton nodes), and their difference would wrap around.
        pos_u = np.asarray(new_nx_g.nodes[u]['position'], dtype=np.float64)
        pos_v = np.asarray(new_nx_g.nodes[v]['position'], dtype=np.float64)
        new_nx_g[u][v]['weight'] = np.linalg.norm((pos_u - pos_v) * scal)
    new_nx_g = nx.minimum_spanning_tree(new_nx_g)
    if sso is not None:
        sso = from_netkx_to_sso(sso, new_nx_g)
    return sso, new_nx_g
def from_netkx_to_sso(sso, skel_nx):
    """
    Store a skeleton graph as the skeleton of a SuperSegmentationObject.

    Node ``i`` of ``sso.skeleton['nodes']`` is the ``i``-th node of
    ``skel_nx.nodes()``. Edges are re-expressed in these contiguous indices, so
    graphs with non-contiguous node IDs (e.g. after pruning) are supported.
    Diameters are initialized with zeros.

    Args:
        sso: SuperSegmentationObject
            The SuperSegmentationObject to update with the new skeleton.
        skel_nx: networkx.Graph
            The skeleton graph. Every node needs a ``position`` attribute.

    Returns:
        SuperSegmentationObject: The updated SSO with the new skeleton.

    Raises:
        AssertionError: If ``skel_nx`` does not consist of exactly one connected
            component. The check runs before ``sso`` is modified.
    """
    assert nx.number_connected_components(skel_nx) == 1
    node_ids = list(skel_nx.nodes())
    # Map node IDs to the row index they get in the node array. Deriving the
    # mapping from node order keeps edges and nodes consistent.
    node_to_ix = {node: ix for ix, node in enumerate(node_ids)}
    sso.skeleton = dict()
    sso.skeleton['nodes'] = np.array([skel_nx.nodes[node]['position'] for node in node_ids],
                                     dtype=np.uint32)
    sso.skeleton['diameters'] = np.zeros(len(sso.skeleton['nodes']),
                                         dtype=np.float32)
    # Important bit, please don't remove (needed after pruning)
    sso.skeleton['edges'] = np.array([(node_to_ix[u], node_to_ix[v]) for u, v in skel_nx.edges()],
                                     dtype=np.uint64).reshape([-1, 2])
    return sso
def create_sso_skeletons_wrapper(ssvs: List['super_segmentation.SuperSegmentationObject'],
                                 dest_paths: Optional[List[str]] = None, nb_cpus: Optional[int] = None,
                                 map_myelin: bool = False, save: bool = True):
    """
    Generate skeletons for a collection of SuperSegmentationObjects.

    If `global_params.config.allow_ssv_skel_gen` is False, existing supervoxel
    skeletons are merged via ``create_sso_skeleton_fast``. Otherwise a skeleton
    is built from a reproducible random half of the mesh vertices. Locations are
    sampled with ``generate_rendering_locs`` or ``surface_samples``, depending on
    `global_params.config.use_new_renderings_locs`, and connected by a minimum
    spanning tree. Such skeletons may lie partially outside the cell
    segmentation, but close to the cell surface.

    Args:
        ssvs: Iterable of cell reconstruction objects (generators are supported).
        dest_paths: Optional iterable with one k.zip path per object in ``ssvs``.
            It is paired with ``ssvs`` element-wise. A single string is rejected.
        nb_cpus: Number of CPUs used for every ``ssv``. Defaults to
            ``global_params.config['ncores_per_node']``.
        map_myelin: If True, uses `map_myelin2coords` to map myelin predictions
            to `ssv.skeleton["nodes"]` and stores the result in
            `ssv.skeleton["myelin"]`. The predictions are then smoothed via
            `majorityvote_skeleton_property`.
        save: If True, writes the generated skeleton to disk via
            ``ssv.save_skeleton``.

    Returns:
        None

    Raises:
        ValueError: If ``dest_paths`` is a string or not iterable, or if the
            generated edge list is not two-dimensional.

    Todo:
        * Add sliding window majority vote for smoothing myelin prediction to `global_params`.
    """
    if nb_cpus is None:
        nb_cpus = global_params.config['ncores_per_node']
    if dest_paths is None:
        # Do not build `[None for _ in ssvs]`: for generator input, that would
        # exhaust `ssvs` before the loop below runs.
        ssv_dest_pairs = ((ssv, None) for ssv in ssvs)
    else:
        if isinstance(dest_paths, str):
            # A string is iterable, but it would be paired with `ssvs` character by character.
            raise ValueError('Destination paths must be given as one path per SSV, '
                             'got a single string.')
        if not isinstance(dest_paths, Iterable):
            raise ValueError('Destination paths given but are not iterable.')
        ssv_dest_pairs = zip(ssvs, dest_paths)
    use_new_renderings_locs = global_params.config.use_new_renderings_locs
    for ssv, dest_path in ssv_dest_pairs:
        ssv.nb_cpus = nb_cpus
        if not global_params.config.allow_ssv_skel_gen:
            # This merges existing SV skeletons - SV skeletons must exist
            ssv = create_sso_skeleton_fast(ssv)
        else:
            verts = ssv.mesh[1].reshape(-1, 3)
            # choose a random (seeded, hence reproducible) half of the surface vertices
            np.random.seed(0)
            ixs = np.arange(len(verts))
            np.random.shuffle(ixs)
            ixs = ixs[:int(0.5 * len(ixs))]
            # TODO: add parameter to config
            if use_new_renderings_locs:
                locs = generate_rendering_locs(verts[ixs], 1000)
            else:
                locs = surface_samples(verts[ixs], bin_sizes=(1000, 1000, 1000), max_nb_samples=10000, r=500)
            g = create_graph_from_coords(locs, mst=True, force_single_cc=True)
            # Converting the EdgeView through a list yields an (n_edges, 2)
            # array for any number of edges (>= 1).
            edge_list = np.array(list(g.edges()))
            del g
            if edge_list.ndim != 2:
                raise ValueError("Edge list ist not a 2D array: {}\n{}".format(
                    edge_list.shape, edge_list))
            ssv.skeleton = dict()
            ssv.skeleton["nodes"] = (locs / np.array(ssv.scaling)).astype(np.int32)
            ssv.skeleton["edges"] = edge_list
            ssv.skeleton["diameters"] = np.ones(len(locs))
        if map_myelin:
            try:
                ssv.skeleton["myelin"] = map_myelin2coords(ssv.skeleton["nodes"], mag=4)
                majorityvote_skeleton_property(ssv, prop_key='myelin')
            except Exception as e:
                raise Exception(f'Myelin mapping in {ssv} failed with: {e}') from e
        if save:
            ssv.save_skeleton()
        if dest_path is not None:
            ssv.save_skeleton_to_kzip(dest_path=dest_path)
def map_myelin2coords(coords: np.ndarray,
                      cube_edge_avg: np.ndarray = np.array([11, 11, 5]),
                      thresh_proba: float = 255 // 2, thresh_majority: float = 0.5,
                      mag: int = 4) -> np.ndarray:
    """
    Retrieve a myelin prediction at every location in `coords`.

    The classification is the majority label within a cube of size
    `cube_edge_avg` around the respective location. Voxels are classified as
    myelinated by thresholding the probability with `thresh_proba`. The fraction
    of myelinated voxels is then compared against `thresh_majority`.

    Examples:
        The entire myelin prediction for a single cell reconstruction, including
        smoothing via :func:`~majorityvote_skeleton_property`, is implemented as:

            from syconn import global_params
            from syconn.reps.super_segmentation import *
            from syconn.reps.super_segmentation_helper import \
                map_myelin2coords, majorityvote_skeleton_property
            # init. example data set
            global_params.wd = '~/SyConn/example_cube1/'
            # initialize example cell reconstruction
            ssd = SuperSegmentationDataset()
            ssv = list(ssd.ssvs)[0]
            ssv.load_skeleton()
            # get myelin predictions
            myelinated = map_myelin2coords(ssv.skeleton["nodes"], mag=4)
            ssv.skeleton["myelin"] = myelinated
            # this will generate a smoothed version at `ssv.skeleton["myelin_avg10000"]`
            majorityvote_skeleton_property(ssv, "myelin")
            # store results as a KNOSSOS readable k.zip file
            ssv.save_skeleton_to_kzip(dest_path='~/{}_myelin.k.zip'.format(ssv.id),
                                      additional_keys=['myelin', 'myelin_avg10000'])

    Args:
        coords: Coordinates used to retrieve myelin predictions. In voxel coordinates (mag=1).
        cube_edge_avg: Cube size used for averaging myelin predictions for each location.
            The loaded data cube will always have the extent given by `cube_edge_avg`, regardless
            of the value of `mag`. Lists and tuples are accepted as well.
        thresh_proba: Classification threshold in uint8 values (0..255).
        thresh_majority: Majority ratio for myelin (between 0..1), i.e.
            `thresh_majority=0.1` means that 10% myelin voxels within `cube_edge_avg`
            will flag the corresponding locations as myelinated.
        mag: Data magnification level used to retrieve the prediction results.

    Returns:
        An array of myelin predictions (0: no myelin, 1: myelinated neuron) for each coordinate.

    Raises:
        ValueError: If the myelin KnossosDataset does not exist in the working directory.
    """
    myelin_kd_p = global_params.config.working_dir + "/knossosdatasets/myelin/"
    if not os.path.isdir(myelin_kd_p):
        raise ValueError(f'Could not find myelin KnossosDataset at {myelin_kd_p}.')
    kd = kd_factory(myelin_kd_p)
    # Convert to an array: `(11, 11, 5) * mag` would repeat the sequence
    # instead of scaling it.
    cube_edge_avg = np.asarray(cube_edge_avg)
    myelin_preds = np.zeros(len(coords), dtype=np.uint8)
    # number of voxels of the loaded cube (extent `cube_edge_avg` at `mag`)
    n_cube_vx = np.prod(cube_edge_avg)
    # convert to mag 1, TODO: requires adaption if anisotropic downsampling was used in KD!
    cube_edge_mag1 = cube_edge_avg * mag
    half_edge_mag1 = cube_edge_mag1 // 2
    for ix, c in enumerate(coords):
        offset = c - half_edge_mag1
        myelin_proba = kd.load_raw(size=cube_edge_mag1, offset=offset, mag=mag).swapaxes(0, 2)
        myelin_ratio = np.sum(myelin_proba > thresh_proba) / n_cube_vx
        myelin_preds[ix] = myelin_ratio > thresh_majority
    return myelin_preds
# New Implementation of skeleton generation which makes use of ssv.rag
def from_netkx_to_arr(skel_nx: nx.Graph) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Convert a skeleton graph into arrays of nodes, diameters and edges.

    Node ``i`` of the output is the ``i``-th node of ``skel_nx.nodes()``. Edges
    are expressed in these contiguous indices, so graphs with arbitrary node IDs
    are supported, e.g. non-contiguous IDs after pruning or sparsification, or
    isolated nodes.

    Args:
        skel_nx: The skeleton graph. Every node needs a ``position`` attribute.

    Returns:
        Tuple[np.ndarray, np.ndarray, np.ndarray]: The node positions (uint32),
        zero-initialized diameters (float32, one per node) and the edges as an
        ``(n_edges, 2)`` uint64 array of node indices.
    """
    node_ids = list(skel_nx.nodes())
    # Derive the index mapping from node order, so edges always refer to the
    # correct rows of the node array.
    node_to_ix = {node: ix for ix, node in enumerate(node_ids)}
    nodes = np.array([skel_nx.nodes[node]['position'] for node in node_ids],
                     dtype=np.uint32)
    diameters = np.zeros(len(nodes), dtype=np.float32)
    # Important bit, please don't remove (needed after pruning)
    edges = np.array([(node_to_ix[u], node_to_ix[v]) for u, v in skel_nx.edges()],
                     dtype=np.uint64).reshape([-1, 2])
    return nodes, diameters, edges
def sparsify_skeleton_fast(g: nx.Graph, scal: Optional[np.ndarray] = None,
                           dot_prod_thresh: float = 0.8,
                           max_dist_thresh: Union[int, float] = 500,
                           min_dist_thresh: Union[int, float] = 50,
                           verbose: bool = False) -> nx.Graph:
    """
    Remove redundant degree-2 nodes from a skeleton graph.

    A degree-2 node is removed, and its two neighbours are connected directly,
    under either of two conditions:
    - the absolute cosine between the (scaled) vectors from the node to its two
      neighbours exceeds ``dot_prod_thresh``, and the neighbours are closer than
      ``max_dist_thresh``;
    - the neighbours are at most ``min_dist_thresh`` apart.
    Sweeps over all degree-2 nodes are repeated until a sweep removes nothing.
    The input graph is copied and left unchanged.

    Args:
        g: The skeleton graph. Every node needs a 3D ``position`` attribute.
        scal: Voxel size multiplied onto positions. Defaults to
            ``global_params.config['scaling']``.
        dot_prod_thresh: Threshold on the absolute cosine ('straightness').
        max_dist_thresh: Maximum neighbour distance for removing a straight node.
        min_dist_thresh: Neighbour distance up to which a node is always removed.
        verbose: If True, the runtime and node reduction are logged at debug level.

    Returns:
        nx.Graph: The sparsified skeleton graph.
    """
    t_start = time.time()
    graph = nx.Graph(g)
    n_nodes_before = graph.number_of_nodes()
    if scal is None:
        scal = global_params.config['scaling']
    n_removed = 1
    while n_removed:
        n_removed = 0
        candidates = list({node for node, deg in dict(graph.degree()).items() if deg == 2})
        for node in candidates:
            neighbours = list(graph.neighbors(node))
            if graph.degree(node) != 2:
                # an earlier removal in this sweep changed the degree of this node
                continue
            nb_a, nb_b = neighbours[0], neighbours[1]
            pos_mid = graph.nodes[node]['position']
            pos_a = graph.nodes[nb_a]['position']
            pos_b = graph.nodes[nb_b]['position']
            vec_a = np.array([int(pos_a[ax]) - int(pos_mid[ax]) for ax in range(3)]) * scal
            vec_b = np.array([int(pos_b[ax]) - int(pos_mid[ax]) for ax in range(3)]) * scal
            cosine = np.dot(vec_a / np.linalg.norm(vec_a), vec_b / np.linalg.norm(vec_b))
            neighbour_dist = np.linalg.norm(
                [int(pos_b[ax] * scal[ax]) - int(pos_a[ax] * scal[ax]) for ax in range(3)])
            is_straight = abs(cosine) > dot_prod_thresh and neighbour_dist < max_dist_thresh
            if is_straight or neighbour_dist <= min_dist_thresh:
                graph.remove_node(node)
                graph.add_edge(nb_a, nb_b)
                n_removed += 1
    if verbose:
        log_reps.debug(f'sparsening took {time.time() - t_start}. Reduced {n_nodes_before} to '
                       f'{graph.number_of_nodes()} nodes')
    return graph
def create_new_skeleton_sv_fast(args):
    """
    Load the skeleton of a single supervoxel, optionally sparsify it, and make
    sure it forms one connected component.

    Multiprocessing helper used by :func:`from_sso_to_netkx_fast`. The
    supervoxel is instantiated with default version and working directory, and
    locking is disabled. Sparsification uses :func:`sparsify_skeleton_fast` with
    its default parameters. If the skeleton consists of more than one connected
    component, a warning is logged. The node IDs are then relabelled to
    ``0..n-1`` in sorted order, and the graph is stitched with
    ``stitch_skel_nx``.

    Args:
        args: Tuple ``(so_id, sparsify)`` with the supervoxel ID and a flag
            that controls whether the skeleton is sparsified.

    Returns:
        Tuple of node positions (uint32), diameters (float32 zeros) and edges
        (uint64 node-index pairs), as produced by :func:`from_netkx_to_arr`.
    """
    so_id, sparsify = args
    # The bare `SegmentationObject` name is imported after the
    # `if TYPE_CHECKING:` guard; the `segmentation` module is always available.
    so = segmentation.SegmentationObject(obj_type="sv", obj_id=so_id)
    so.enable_locking = False
    so.load_attr_dict()
    # ignore diameters, will be populated at the end of create_sso_skeleton_fast
    skel = load_skeleton(so)
    nodes, edges = skel['nodes'].astype(np.uint32), skel['edges']
    # create nx graph
    skel_nx = nx.Graph()
    skel_nx.add_nodes_from([(ix, dict(position=coord)) for ix, coord
                            in enumerate(nodes)])
    skel_nx.add_edges_from([tuple(ix) for ix in edges])
    if sparsify:
        skel_nx = sparsify_skeleton_fast(skel_nx)
    n_cc = nx.number_connected_components(skel_nx)
    if n_cc > 1:
        log_reps.warning('SV {} contained {} connected components in its skel'
                         'eton representation. Stitching now.'
                         ''.format(so.id, n_cc))
        # make node IDs (and hence edge values) contiguous prior to stitching
        relabel = {node: new_ix for new_ix, node in enumerate(sorted(skel_nx.nodes()))}
        skel_nx_tmp = nx.Graph()
        skel_nx_tmp.add_nodes_from([(relabel[node], skel_nx.nodes[node])
                                    for node in skel_nx.nodes()])
        skel_nx_tmp.add_edges_from([(relabel[u], relabel[v]) for u, v in skel_nx.edges()])
        skel_nx = stitch_skel_nx(skel_nx_tmp)
    # just get nodes, diameters and edges
    nodes, diameters, edges = from_netkx_to_arr(skel_nx)
    return nodes, diameters, edges
def from_sso_to_netkx_fast(sso, sparsify=True, max_edge_length=1.5e3):
    """
    Build the skeleton of a cell reconstruction from the skeletons of its supervoxels.

    The supervoxel skeletons are loaded (and optionally sparsified) with
    :func:`create_new_skeleton_sv_fast`, in ``sso.nb_cpus`` processes, and
    concatenated. Supervoxels without a skeleton are skipped. For every edge of
    the supervoxel graph (``sso.load_sv_graph()``), the closest node pair of the
    two SV skeletons is connected, unless that pair is further apart than
    ``max_edge_length``. If the result still has more than one connected
    component, ``stitch_skel_nx`` is used as a (slower) fallback.

    Args:
        sso: The SuperSegmentationObject to process. Uses ``sv_ids``,
            ``nb_cpus``, ``scaling``, ``id``, ``load_attr_dict()`` and
            ``load_sv_graph()``.
        sparsify: Passed on to :func:`create_new_skeleton_sv_fast`.
        max_edge_length: Maximum length of an edge added between two SV
            skeletons. It is measured on ``nodes * sso.scaling`` (presumably nm).

    Returns:
        The skeleton graph: node ``i`` corresponds to row ``i`` of
        ``sso.skeleton['nodes']`` and has a ``position`` attribute. Returns None
        if no supervoxel has a skeleton. In both cases ``sso.skeleton`` is set to
        a dict with the keys 'nodes', 'edges' and 'diameters'.
    """
    sso.load_attr_dict()
    res = start_multiprocess_imap(create_new_skeleton_sv_fast,
                                  [(sv_id, sparsify) for sv_id in sso.sv_ids],
                                  nb_cpus=sso.nb_cpus, show_progress=False,
                                  debug=False)
    nodes, diameters, edges, sv_id_arr = [], [], [], []
    node_offset = 0
    for sv_id, (sv_nodes, sv_diameters, sv_edges) in zip(sso.sv_ids, res):
        if len(sv_nodes) == 0:  # skip missing / empty skeletons, e.g. happens for small SVs
            continue
        nodes.append(sv_nodes)
        diameters.append(sv_diameters)
        # shift SV-local node indices to their position in the concatenated array
        edges.append(sv_edges + node_offset)
        # store mapping from node to SV ID
        sv_id_arr.append([sv_id] * len(sv_nodes))
        node_offset += len(sv_nodes)
    if len(nodes) == 0:
        # No SV has a skeleton. Concatenating empty lists would raise, so store
        # an empty skeleton instead.
        sso.skeleton = {'nodes': np.zeros((0, 3), dtype=np.uint32),
                        'edges': np.zeros((0, 2), dtype=np.uint64),
                        'diameters': np.zeros((0,), dtype=np.float32)}
        return None
    ssv_skel = {'nodes': np.concatenate(nodes), 'edges': [],
                'diameters': np.concatenate(diameters, axis=0)}
    sv_id_arr = np.concatenate(sv_id_arr)
    node_ix_arr = np.arange(len(sv_id_arr))
    # stitching
    if len(sso.sv_ids) > 1:
        # TODO: bridge SVs without skeleton (connect their neighbours in the SV
        #  graph) as soon as SV graphs only connect adjacent SVs. This might speed
        #  up the generation, because splits in the SV graph currently fall back
        #  to the slow `stitch_skel_nx`.
        g = sso.load_sv_graph()
        for e1, e2 in g.edges():
            mask1 = sv_id_arr == e1.id
            mask2 = sv_id_arr == e2.id
            if not np.any(mask1) or not np.any(mask2):
                continue  # SV without skeleton
            # get the closest node pair between the two SVs in question
            nodes1 = (ssv_skel['nodes'][mask1] * sso.scaling).astype(np.float32)
            nodes2 = (ssv_skel['nodes'][mask2] * sso.scaling).astype(np.float32)
            nodes1_ix = node_ix_arr[mask1]
            nodes2_ix = node_ix_arr[mask2]
            tree = spatial.cKDTree(nodes1)
            dists, node_ixs1 = tree.query(nodes2)
            closest = np.argmin(dists)
            # get global index of nodes
            ix2 = nodes2_ix[closest]
            ix1 = nodes1_ix[node_ixs1[closest]]
            node_dist_check = np.linalg.norm(ssv_skel['nodes'][ix1].astype(np.float32) *
                                             sso.scaling - ssv_skel['nodes'][ix2].astype(
                np.float32) * sso.scaling)
            # Only the length matters here. Comparing `np.min(dists)` against the
            # same distance computed a second way only detects float rounding
            # differences, and would drop edges at random.
            if node_dist_check > max_edge_length:
                log_reps.debug(f'Found long edge with length '
                               f'{node_dist_check / 1e3:.0f} um between SVs '
                               f'{e1.id} and {e2.id} although they were '
                               f'connected within the SV graph. Skipping.')
                # TODO: remove as soon as SV graphs only connect adjacent SVs.
                continue
            edges.append(np.array([[ix1, ix2]], dtype=np.uint32))
    ssv_skel['edges'] = np.concatenate(edges)
    skel_nx = nx.Graph()
    skel_nx.add_nodes_from([(ix, dict(position=coord)) for ix, coord
                            in enumerate(ssv_skel['nodes'])])
    skel_nx.add_edges_from([tuple(ix) for ix in ssv_skel['edges']])
    n_cc = nx.number_connected_components(skel_nx)
    if n_cc != 1:
        msg = 'Stitching of SV skeletons failed during "from_sso_to_netkx_' \
              'fast" with {} connected components using the underlying SSV ' \
              'agglomeration. Please check the underlying RAG of SSV {}. ' \
              'Now performing a slower stitching method to add missing edg' \
              'es recursively between the closest connected components. ' \
              'This warning might also occur if two supervoxels are connected ' \
              'over supervoxel(s) without skeleton!' \
              ''.format(n_cc, sso.id)
        skel_nx = stitch_skel_nx(skel_nx)
        log_reps.warning(msg)
        assert nx.number_connected_components(skel_nx) == 1
        ssv_skel['edges'] = np.array(list(skel_nx.edges()), dtype=np.uint64)
    sso.skeleton = ssv_skel
    return skel_nx
def create_sso_skeleton_fast(sso, pruning_thresh=800, sparsify=True, max_dist_thresh=600, dot_prod_thresh=0.0,
                             max_dist_thresh_iter2=600):
    """
    Create a sparse skeleton for a super-segmentation object (SSO).

    The supervoxel skeletons are stitched along the supervoxel graph via
    ``from_sso_to_netkx_fast``. The result is then optionally sparsified,
    pruned (``prune_stub_branches``), optionally sparsified a second time,
    reduced to a minimum spanning tree (edge weights = physical edge length)
    and finally radius-corrected.

    Args:
        sso: The SuperSegmentationObject to process.
        pruning_thresh: Path length in NM below which stub branches are
            removed (see ``prune_stub_branches``).
        sparsify: If True, run ``sparsify_skeleton_fast`` before and after
            pruning.
        max_dist_thresh: Distance threshold in NM for the first
            sparsification round (passed as both ``max_dist_thresh`` and
            ``min_dist_thresh``).
        dot_prod_thresh: Dot-product threshold of two adjacent edges for the
            second sparsification round (see ``sparsify_skeleton_fast``).
        max_dist_thresh_iter2: Distance threshold in NM for the second
            sparsification round.

    Returns:
        The SuperSegmentationObject with the updated skeleton. If the
        stitched skeleton has no nodes, ``sso`` is returned unchanged apart
        from the (empty) skeleton already assigned by
        ``from_sso_to_netkx_fast``.
    """
    skel_nx = from_sso_to_netkx_fast(sso)
    if skel_nx is None:
        # ``from_sso_to_netkx_fast`` returns without a graph when the SSV has
        # no skeleton nodes -> nothing to sparsify, prune or correct.
        return sso
    # Sparsify again after stitching (inexpensive).
    if sparsify:
        skel_nx = sparsify_skeleton_fast(skel_nx, max_dist_thresh=max_dist_thresh,
                                         min_dist_thresh=max_dist_thresh)
    # Prune short stub branches of the stitched skeleton.
    _, skel_nx = prune_stub_branches(nx_g=skel_nx, len_thres=pruning_thresh)
    if sparsify:
        skel_nx = sparsify_skeleton_fast(
            skel_nx, dot_prod_thresh=dot_prod_thresh,
            max_dist_thresh=max_dist_thresh_iter2)
    start = time.time()
    scaling = np.asarray(sso.scaling, dtype=np.float64)
    for u, v in skel_nx.edges():
        # Cast before subtracting: node positions are integer voxel
        # coordinates and an unsigned dtype would wrap around.
        pos_u = np.asarray(skel_nx.nodes[u]['position'], dtype=np.float64)
        pos_v = np.asarray(skel_nx.nodes[v]['position'], dtype=np.float64)
        skel_nx[u][v]['weight'] = np.linalg.norm((pos_u - pos_v) * scaling)
    skel_nx = nx.minimum_spanning_tree(skel_nx)
    log_reps.debug(f'mst took {time.time() - start:.0f} s')
    sso = from_netkx_to_sso(sso, skel_nx)
    # reset weighted graph, it is derived from the (now replaced) skeleton
    sso._weighted_graph = None
    # Estimate the radii
    start = time.time()
    sso.skeleton = radius_correction_found_vertices(sso)
    log_reps.debug(f'radius estimation took {time.time() - start:.0f} s')
    return sso
def glia_pred_exists(so):
    """
    Check whether an astrocyte (glia) prediction is stored for ``so``.

    Args:
        so: Object providing ``load_attr_dict()`` and ``attr_dict``.

    Returns:
        True if ``"glia_probas"`` is a key of the freshly loaded attribute
        dictionary, else False.
    """
    so.load_attr_dict()
    attributes = so.attr_dict
    has_probas = "glia_probas" in attributes
    return has_probas
def get_pca_view_hists(sso, t_net, pca):
    """
    Compute density histograms of the first three PCA components of the
    latent view embedding of ``sso``.

    Args:
        sso: Object whose views are loaded via ``sso.load_views()``.
        t_net: Model with ``predict_proba``; applied to
            ``views2tripletinput(views)`` to obtain latent features.
        pca: Fitted transformer with ``transform``; must yield at least three
            components.

    Returns:
        Object array of shape (3, 2); row ``i`` holds ``(densities, bin_edges)``
        of ``np.histogram`` (50 bins) for PCA component ``i``.
    """
    views = sso.load_views()
    latent = t_net.predict_proba(views2tripletinput(views))
    latent = pca.transform(latent)
    # (component, fixed histogram range)
    hist_cfg = ((0, (-2, 2)), (1, (-3.2, 3)), (2, (-3.5, 3.5)))
    # Build the ragged (50 densities, 51 edges) result explicitly: np.array on
    # ragged nested input raises since NumPy 1.24.
    hists = np.empty((len(hist_cfg), 2), dtype=object)
    for row, (comp, value_range) in enumerate(hist_cfg):
        # ``normed`` was removed from np.histogram in NumPy 1.24; ``density``
        # is equivalent for equal-width bins.
        densities, edges = np.histogram(latent[:, comp], bins=50,
                                        range=value_range, density=True)
        hists[row, 0] = densities
        hists[row, 1] = edges
    return hists
def save_view_pca_proj(sso, t_net, pca, dest_dir, ls=20, s=6.0, special_points=(),
                       special_markers=(), special_kwargs=()):
    """
    Save pairwise scatter plots of the first three PCA components of the
    latent view embedding of ``sso``.

    One PNG per component pair (1, 2), (1, 3), (2, 3) is written to
    ``<dest_dir>/<sso.id>_pca_<a><b>.png``. Points are colored by their
    min-max normalized latent components (used as RGB, alpha = 1).

    Args:
        sso: Object providing ``load_views()`` and ``id``.
        t_net: Model with ``predict_proba``; applied to
            ``views2tripletinput(views)``.
        pca: Fitted transformer with ``transform``; the color mapping assumes
            exactly three output components (RGB + alpha column).
        dest_dir: Output directory.
        ls: Font size for tick labels and axis labels.
        s: Marker size of the latent points.
        special_points: Extra points (indexable by component) to overlay.
        special_markers: Marker per special point; "x" if empty.
        special_kwargs: If non-empty, ``scatter`` keyword arguments used for
            every special point instead of the default style.
    """
    import matplotlib.pyplot as plt
    import matplotlib.ticker as ticker
    views = sso.load_views()
    latent = t_net.predict_proba(views2tripletinput(views))
    latent = pca.transform(latent)
    latent_min = latent.min(axis=0)
    latent_span = latent.max(axis=0) - latent_min
    # constant components would otherwise produce 0/0 -> NaN colors
    latent_span = np.where(latent_span == 0, 1, latent_span)
    col = (np.array(latent) - latent_min) / latent_span
    col = np.concatenate([col, np.ones_like(col)[:, :1]], axis=1)
    for a, b in ((0, 1), (0, 2), (1, 2)):
        fig, ax = plt.subplots()
        ax.scatter(latent[:, a], latent[:, b], c=col, s=s, lw=0.5, marker="o",
                   edgecolors=col)
        for kk, sp in enumerate(special_points):
            if len(special_kwargs) == 0:
                sm = "x" if len(special_markers) == 0 else special_markers[kk]
                ax.scatter(sp[None, a], sp[None, b], s=75.0, lw=2.3,
                           marker=sm, edgecolor="0.3", facecolor="none")
            else:
                ax.scatter(sp[None, a], sp[None, b], **special_kwargs)
        fig.patch.set_facecolor('white')
        # Booleans instead of "off": non-empty strings are truthy and would
        # switch the top/right ticks on in current matplotlib.
        ax.tick_params(axis='both', which='both', labelsize=ls, direction='out',
                       length=4, width=3, right=False, top=False, pad=10)
        ax.set_xlabel(r"$Z_%d$" % (a + 1), fontsize=ls)
        ax.set_ylabel(r"$Z_%d$" % (b + 1), fontsize=ls)
        ax.xaxis.set_major_locator(ticker.MultipleLocator(2))
        ax.yaxis.set_major_locator(ticker.MultipleLocator(2))
        fig.tight_layout()
        fig.savefig(os.path.join(dest_dir, "%d_pca_%d%d.png" % (sso.id, a + 1, b + 1)),
                    dpi=400)
        plt.close(fig)
def label_array_for_sso_skel(sso, comment_converter):
    """
    Translate skeleton node comments into one integer label per node.

    Args:
        sso: SuperSegmentationObject; its skeleton is loaded via
            ``sso.load_skeleton()`` if not present yet.
        comment_converter: Mapping from lower-cased node comment to integer
            label. Comments without an entry keep the label -1.

    Returns:
        np.ndarray (int32) aligned with ``sso.skeleton["nodes"]``.
    """
    if sso.skeleton is None:
        sso.load_skeleton()
    node_comments = skelnode_comment_dict(sso)
    nodes = sso.skeleton["nodes"]
    labels = np.full(len(nodes), -1, dtype=np.int32)
    for node_ix in range(len(nodes)):
        # comments are keyed by the frozenset of the int32 node coordinate
        lookup_key = frozenset(nodes[node_ix].astype(np.int32))
        comment = node_comments[lookup_key].lower()
        try:
            label = comment_converter[comment]
        except KeyError:
            continue
        labels[node_ix] = label
    return labels
def write_axpred_cnn(ssv, pred_key_appendix, dest_path=None, k=1):
    """
    Write the per-view axoness predictions of ``ssv`` as colored mesh
    ("axoness.ply") via ``ssv._pred2mesh``.

    Args:
        ssv: SuperSegmentationObject to visualize.
        pred_key_appendix: Appended to ``"axoness_preds"`` to build the SSV
            attribute key. If that attribute is missing, the predictions are
            collected from all SVs instead.
        dest_path: Destination k.zip; defaults to
            ``ssv.skeleton_kzip_path_views``.
        k: Forwarded to ``ssv._pred2mesh``.
    """
    if dest_path is None:
        dest_path = ssv.skeleton_kzip_path_views
    pred_key = "axoness_preds%s" % pred_key_appendix
    if not ssv.attr_exists(pred_key):
        log_reps.info("Couldn't find specified axoness prediction. Falling back "
                      "to default.")
        # SVs generally have different numbers of views: concatenate the
        # ragged per-SV results directly (np.array on ragged input raises
        # since NumPy 1.24).
        preds = np.concatenate(start_multiprocess_obj(
            "axoness_preds", [[sv, {"pred_key_appendix": pred_key_appendix}]
                              for sv in ssv.svs],
            nb_cpus=ssv.nb_cpus))
    else:
        preds = ssv.lookup_in_attribute_dict(pred_key)
    log_reps.debug("Collected axoness: {}".format(Counter(preds).most_common()))
    locs = ssv.sample_locations()
    log_reps.debug("Collected locations.")
    pred_coords = np.concatenate(locs)
    assert pred_coords.ndim == 2
    assert pred_coords.shape[1] == 3
    # RGBA (0-255), presumably indexed by prediction label 0, 1, 2
    colors = np.array(np.array([[0.6, 0.6, 0.6, 1], [0.841, 0.138, 0.133, 1.],
                                [0.32, 0.32, 0.32, 1.]]) * 255, dtype=np.uint32)
    ssv._pred2mesh(pred_coords, preds, "axoness.ply", dest_path=dest_path, k=k,
                   colors=colors)
def cnn_axoness2skel(sso: 'super_segmentation.SuperSegmentationObject',
                     pred_key_appendix: str = "", k: int = 1,
                     force_reload: bool = False,
                     save_skel: bool = True, use_cache: bool = False):
    """
    Map per-view axoness predictions and probabilities of ``sso`` onto its
    skeleton nodes.

    The per-view predictions (``"axoness_preds_cnn<appendix>"``) and
    probabilities (``"axoness_probas_cnn<appendix>"``) are read from the SSV
    attribute dict or, if missing or ``force_reload`` is set, collected from
    all SVs and stored in ``sso.attr_dict``. They are then assigned to the
    skeleton nodes from the closest sample locations (``colorcode_vertices``
    with ``k=1`` and ``assign_rep_values``).

    Args:
        sso: SuperSegmentationObject to process.
        pred_key_appendix: Appended to all prediction keys to distinguish
            prediction sets.
        k: Deprecated and unused; a warning is logged if it is not 1.
        force_reload: If True, re-collect the SV predictions even if the SSV
            attributes exist.
        save_skel: If True, call ``sso.save_skeleton()`` at the end.
        use_cache: If True, write re-collected SV predictions to the SSV
            attribute dict on disk (``sso.save_attributes``).

    Returns:
        None. The skeleton gets the keys ``"axoness<appendix>"``,
        ``"axoness_probas<appendix>"`` and ``"view_ixs"``.
    """
    if k != 1:
        log_reps.warning("Parameter 'k' is deprecated but was set to {}. "
                         "It is no longer used in this method.".format(k))
    if sso.skeleton is None:
        sso.load_skeleton()
    proba_key = "axoness_probas_cnn%s" % pred_key_appendix
    pred_key = "axoness_preds_cnn%s" % pred_key_appendix
    if not sso.attr_exists(pred_key) or not sso.attr_exists(proba_key) or \
            force_reload:
        # Per-SV results are ragged (different number of views per SV):
        # concatenate directly, np.array on ragged input raises since NumPy 1.24.
        preds = np.concatenate(start_multiprocess_obj(
            "axoness_preds", [[sv, {"pred_key_appendix": pred_key_appendix}]
                              for sv in sso.svs],
            nb_cpus=sso.nb_cpus))
        probas = np.concatenate(start_multiprocess_obj(
            "axoness_probas", [[sv, {"pred_key_appendix": pred_key_appendix}]
                               for sv in sso.svs], nb_cpus=sso.nb_cpus))
        sso.attr_dict[proba_key] = probas
        sso.attr_dict[pred_key] = preds
        if use_cache:
            sso.save_attributes([proba_key, pred_key], [probas, preds])
    else:
        preds = sso.lookup_in_attribute_dict(pred_key)
        probas = sso.lookup_in_attribute_dict(proba_key)
    loc_coords = np.concatenate(sso.sample_locations())
    assert len(loc_coords) == len(preds), "Number of view coordinates is " \
                                          "different from number of view " \
                                          "predictions. SSO %d" % sso.id
    # find NN in loc_coords for every skeleton node and use its prediction
    node_preds = colorcode_vertices(sso.skeleton["nodes"] * sso.scaling,
                                    loc_coords, preds, colors=[0, 1, 2], k=1)
    node_probas, ixs = assign_rep_values(sso.skeleton["nodes"] * sso.scaling,
                                         loc_coords, probas, return_ixs=True)
    # valid indices are 0 .. len(loc_coords) - 1
    assert len(ixs) == 0 or np.max(ixs) < len(loc_coords), \
        "Maximum index for sample coordinates is not smaller than the " \
        "number of sample coordinates."
    sso.skeleton["axoness%s" % pred_key_appendix] = node_preds
    sso.skeleton["axoness_probas%s" % pred_key_appendix] = node_probas
    sso.skeleton["view_ixs"] = ixs
    if save_skel:
        sso.save_skeleton()
def average_node_axoness_views(sso: 'super_segmentation.SuperSegmentationObject',
                               pred_key: Optional[str] = None,
                               pred_key_appendix: str = "",
                               max_dist: int = 10000, return_res: bool = False,
                               use_cache: bool = False):
    """
    Smooth the per-view axoness predictions along the skeleton of ``sso``.

    For every skeleton node, all nodes within a path length of ``max_dist``
    on ``sso.weighted_graph()`` are collected, their assigned view indices
    (``sso.skeleton["view_ixs"]``) are de-duplicated and the most frequent
    prediction among these views is assigned to the node (ties -> smallest
    label).

    Args:
        sso: SuperSegmentationObject whose skeleton nodes are evaluated.
        pred_key: Attribute key of the per-view predictions. Defaults to
            ``"axoness_preds_cnn%s" % pred_key_appendix``. Must not be given
            together with a non-empty ``pred_key_appendix``.
        pred_key_appendix: Used for the default ``pred_key`` and for the name
            of the stored skeleton attribute.
        max_dist: Maximum path length of the averaging window, in units of
            the weighted graph's edge weights (presumably NM).
        return_res: If True, return the result instead of storing it.
        use_cache: If True, predictions collected from the SVs are written to
            the SSV attribute dict on disk; also forwarded to
            ``cnn_axoness2skel``.

    Returns:
        None if the skeleton has no edges. Otherwise a list with one averaged
        prediction per node if ``return_res`` is True, else None and the result
        is stored in ``sso.skeleton["axoness%s_avg%d" % (pred_key_appendix,
        max_dist)]``. This function does not call ``sso.save_skeleton()``
        itself, but ``cnn_axoness2skel`` (run when ``"view_ixs"`` is missing)
        saves the skeleton unless ``return_res`` is True.

    Raises:
        ValueError: If both ``pred_key`` and a non-empty ``pred_key_appendix``
            are given, or ``pred_key`` is not a string.
    """
    if sso.skeleton is None:
        sso.load_skeleton()
    if len(sso.skeleton["edges"]) == 0:
        log_reps.error("Zero edges in skeleton of SSV %d. "
                       "Skipping averaging." % sso.id)
        return
    if pred_key is None:
        pred_key = "axoness_preds_cnn%s" % pred_key_appendix
    elif len(pred_key_appendix) > 0:
        raise ValueError("Only one of the two may be given: 'pred_key' or "
                         "'pred_key_appendix', but not both.")
    if not isinstance(pred_key, str):
        raise ValueError("'pred_key' has to be of type str.")
    if not sso.attr_exists(pred_key) and ("axoness_preds_cnn" not in pred_key):
        # Only reachable for an explicitly given, non-default ``pred_key``
        # (``pred_key_appendix`` is then empty).
        log_reps.warning("Couldn't find specified axoness prediction. Falling"
                         " back to default (-> per SV stored multi-view "
                         "prediction including SSV context")
        # ragged per-SV results -> concatenate directly (NumPy >= 1.24 safe)
        preds = np.concatenate(start_multiprocess_obj(
            "axoness_preds", [[sv, {"pred_key_appendix": pred_key_appendix}]
                              for sv in sso.svs], nb_cpus=sso.nb_cpus))
        sso.attr_dict[pred_key] = preds
        if use_cache:
            sso.save_attributes([pred_key], [preds])
    else:
        preds = sso.lookup_in_attribute_dict(pred_key)
    # fancy indexing below requires an array
    preds = np.asarray(preds)
    loc_coords = np.concatenate(sso.sample_locations())
    assert len(loc_coords) == len(preds), "Number of view coordinates is " \
                                          "different from number of view " \
                                          "predictions. SSO %d" % sso.id
    if "view_ixs" not in sso.skeleton:
        log_reps.info("View indices were not yet assigned to skeleton nodes. "
                      "Running now '_cnn_axonness2skel(sso, "
                      "pred_key_appendix=pred_key_appendix, k=1)'")
        cnn_axoness2skel(sso, pred_key_appendix=pred_key_appendix, k=1,
                         save_skel=not return_res, use_cache=use_cache)
    view_ixs = np.array(sso.skeleton["view_ixs"])
    avg_pred = []
    g = sso.weighted_graph()
    for n in range(g.number_of_nodes()):
        # only the reachable node set is needed, not the paths themselves
        reachable = nx.single_source_dijkstra_path_length(g, n, cutoff=max_dist)
        neighs = np.array(list(reachable.keys()), dtype=np.int64)
        unique_view_ixs = np.unique(view_ixs[neighs])
        cls, cnts = np.unique(preds[unique_view_ixs], return_counts=True)
        avg_pred.append(cls[np.argmax(cnts)])
    if return_res:
        return avg_pred
    sso.skeleton["axoness%s_avg%d" % (pred_key_appendix, max_dist)] = avg_pred
def majority_vote_compartments(sso: 'SuperSegmentationObject', ax_pred_key: str = 'axoness'):
    """
    Relabel every soma-free connected component of the skeleton with its
    majority compartment prediction.

    Nodes labelled 2 (soma) are removed from a copy of the weighted skeleton
    graph. Each remaining connected component receives the most frequent
    label of its nodes; a label-1 majority is only kept if at least 66 % of
    the component's nodes carry label 1, otherwise the component is set to
    label 0 (bias towards dendrite). Soma nodes keep their original label.

    The result is stored in ``sso.skeleton[ax_pred_key + "_comp_maj"]`` and
    the skeleton IS written via ``sso.save_skeleton()``.

    Args:
        sso: The SuperSegmentationObject whose skeleton is processed.
        ax_pred_key: Skeleton key of the per-node compartment predictions.

    Returns:
        None
    """
    g = sso.weighted_graph(add_node_attr=(ax_pred_key,))
    soma_nodes = [n for n, d in g.nodes(data=True) if d[ax_pred_key] == 2]
    soma_free_g = g.copy()
    soma_free_g.remove_nodes_from(soma_nodes)
    new_axoness_dc = nx.get_node_attributes(g, ax_pred_key)
    for cc_nodes in nx.connected_components(soma_free_g):
        preds = [g.nodes[n][ax_pred_key] for n in cc_nodes]
        cls, cnts = np.unique(preds, return_counts=True)
        majority = cls[np.argmax(cnts)]
        probas = np.array(cnts, dtype=np.float32) / np.sum(cnts)
        # positively bias dendrite assignment
        if majority == 1 and probas[cls == 1][0] < 0.66:
            majority = 0
        for n in cc_nodes:
            new_axoness_dc[n] = majority
    nx.set_node_attributes(g, new_axoness_dc, ax_pred_key)
    new_axoness_arr = np.zeros(len(sso.skeleton["nodes"]))
    for n, d in g.nodes(data=True):
        new_axoness_arr[n] = d[ax_pred_key]
    sso.skeleton[ax_pred_key + "_comp_maj"] = new_axoness_arr
    sso.save_skeleton()
def majorityvote_skeleton_property(sso: 'super_segmentation.SuperSegmentationObject',
                                   prop_key: str, max_dist: int = 10000,
                                   return_res: bool = False) -> Optional[np.ndarray]:
    """
    Smooth a categorical skeleton node property with a sliding-window
    majority vote.

    For every node, all nodes within a path length of ``max_dist`` on
    ``sso.weighted_graph()`` (including the node itself) are collected and
    the most frequent value of ``sso.skeleton[prop_key]`` among them is
    assigned (ties -> smallest value, as ``np.unique`` sorts ascending).

    Args:
        sso: The SuperSegmentationObject whose skeleton property is processed.
        prop_key: Key of the property in ``sso.skeleton``.
        max_dist: Maximum path length of the window on the L2-distance
            weighted skeleton graph.
        return_res: If True, return the result instead of storing it.

    Returns:
        Array with one majority value per node if ``return_res`` is True;
        otherwise None and the result is stored in
        ``sso.skeleton["%s_avg%d" % (prop_key, max_dist)]``.
        ``sso.save_skeleton()`` is not called.

    Raises:
        ValueError: If ``prop_key`` does not exist in the skeleton.
    """
    if prop_key not in sso.skeleton:
        raise ValueError(f'Given property "{prop_key}" does not exist in '
                         f'skeleton of SSV {sso.id}.')
    # hoisted and converted once: fancy indexing below requires an array
    prop_arr = np.asarray(sso.skeleton[prop_key])
    g = sso.weighted_graph()
    avg_prop = []
    for n in range(g.number_of_nodes()):
        # only the reachable node set is needed, not the paths themselves
        reachable = nx.single_source_dijkstra_path_length(g, n, cutoff=max_dist)
        neighs = np.array(list(reachable.keys()), dtype=np.int64)
        prop_vals, cnts = np.unique(prop_arr[neighs], return_counts=True)
        avg_prop.append(prop_vals[np.argmax(cnts)])
    avg_prop = np.array(avg_prop)
    if return_res:
        return avg_prop
    sso.skeleton["%s_avg%d" % (prop_key, max_dist)] = avg_prop
def find_incomplete_ssv_views(ssd: 'SuperSegmentationDataset', woglia: bool, n_cores: Optional[int] = None):
    """
    Collect the IDs of all SSVs that contain at least one SV with missing views.

    Args:
        ssd: SuperSegmentationDataset to inspect.
        woglia: Forwarded to ``find_missing_sv_views``; selects the checked
            view storage (presumably views after glia removal -- confirm there).
        n_cores: Number of worker processes; defaults to
            ``global_params.config['ncores_per_node']``.

    Returns:
        List of distinct SSV IDs. SVs that are not part of ``ssd`` are ignored.
    """
    n_workers = global_params.config['ncores_per_node'] if n_cores is None else n_cores
    sv_dataset = ssd.get_segmentationdataset("sv")
    sv_ids_missing = find_missing_sv_views(sv_dataset, woglia, n_workers)
    affected_ssvs = set()
    sv2ssv = ssd.sv2ssv_ids(sv_ids_missing)
    for sv_id in sv_ids_missing:
        try:
            affected_ssvs.add(sv2ssv[sv_id])
        except KeyError:
            continue  # SV is not part of this SSD
    return list(affected_ssvs)
def find_incomplete_ssv_skeletons(ssd, n_cores: Optional[int] = None):
    """
    Collect the IDs of all SSVs that contain at least one SV without skeleton.

    Args:
        ssd: SuperSegmentationDataset to inspect; all SVs of all its SSVs are
            checked via ``find_missing_sv_skeletons``.
        n_cores: Number of worker processes; defaults to
            ``global_params.config['ncores_per_node']``.

    Returns:
        List of distinct SSV IDs. SVs that are not part of ``ssd`` are ignored.
    """
    n_workers = global_params.config['ncores_per_node'] if n_cores is None else n_cores
    all_svs = np.concatenate([list(ssv.svs) for ssv in ssd.ssvs])
    sv_ids_missing = find_missing_sv_skeletons(all_svs, n_workers)
    affected_ssvs = set()
    sv2ssv = ssd.sv2ssv_ids(sv_ids_missing)
    for sv_id in sv_ids_missing:
        try:
            affected_ssvs.add(sv2ssv[sv_id])
        except KeyError:
            continue  # SV is not part of this SSD
    return list(affected_ssvs)
def find_missing_sv_attributes_in_ssv(ssd, attr_key, n_cores: Optional[int] = None):
    """
    Collect the IDs of all SSVs that contain at least one SV lacking ``attr_key``.

    Args:
        ssd: SuperSegmentationDataset to inspect.
        attr_key: Attribute key checked by ``find_missing_sv_attributes``.
        n_cores: Number of worker processes; defaults to
            ``global_params.config['ncores_per_node']``.

    Returns:
        List of distinct SSV IDs. SVs that are not part of ``ssd`` are ignored.
    """
    n_workers = global_params.config['ncores_per_node'] if n_cores is None else n_cores
    sv_dataset = ssd.get_segmentationdataset("sv")
    sv_ids_missing = find_missing_sv_attributes(sv_dataset, attr_key, n_workers)
    affected_ssvs = set()
    sv2ssv = ssd.sv2ssv_ids(sv_ids_missing)
    for sv_id in sv_ids_missing:
        try:
            affected_ssvs.add(sv2ssv[sv_id])
        except KeyError:
            continue  # SV is not part of this SSD
    return list(affected_ssvs)
def predict_views_semseg(views, model, batch_size=10, verbose=False):
    """
    Predict per-pixel semantic segmentation labels for multi-view renderings.

    Every 2D projection (one location, one view) is predicted independently:
    views are scaled to [0, 1], reshaped to (N_LOCS * N_VIEWS, N_CH, X, Y),
    passed to ``model.predict_proba`` and reduced to the arg-max class.

    Args:
        views: uint8 array of shape (N_LOCS, N_CH, N_VIEWS, X, Y) with values
            in [0, 255] (channels e.g. cell shape, mitochondria, synaptic
            junctions, vesicle clouds).
        model: Object with ``predict_proba(x, bs=..., verbose=...)`` returning
            class scores with the class axis at position 1.
        batch_size: Batch size forwarded to ``model.predict_proba``.
        verbose: Forwarded to ``model.predict_proba``.

    Returns:
        Integer label array of shape (N_LOCS, 1, N_VIEWS, X', Y'): the channel
        axis is replaced by a single axis holding the predicted class;
        (X', Y') are the spatial dims of the model output (X, Y if the model
        preserves resolution). Note: this is NOT the shape of ``views``.
    """
    views = views.astype(np.float32) / 255.
    # (N_LOCS, N_CH, N_VIEWS, X, Y) -> (N_LOCS, N_VIEWS, N_CH, X, Y)
    views = views.swapaxes(1, 2)
    orig_shape = views.shape
    # flatten locations and views to predict single projections:
    # (N_LOCS * N_VIEWS, N_CH, X, Y)
    views = views.reshape([-1] + list(orig_shape[2:]))
    probas = model.predict_proba(views, bs=batch_size, verbose=verbose)
    labeled_views = np.argmax(probas, axis=1)[:, None]
    # back to (N_LOCS, N_VIEWS, 1, X', Y')
    labeled_views = labeled_views.reshape(list(orig_shape[:2])
                                          + list(labeled_views.shape[1:]))
    # swap view and channel axis to match the input layout
    return labeled_views.swapaxes(2, 1)
def pred_svs_semseg(model, views, pred_key=None, svs=None, return_pred=False,
                    nb_cpus=1, verbose=False, bs: int = 10):
    """
    Predict semantic segmentation label views for a list of supervoxels.

    The view arrays of all SVs are concatenated, predicted in one go via
    ``predict_views_semseg`` and split back per SV. The label views are
    either returned or written to the SV view storages via ``save_views``
    (``view_key=pred_key``, ``woglia=True``, ``index_views=False``).

    Args:
        model: Model used for prediction (see ``predict_views_semseg``).
        views: One uint8 array of shape (N_LOCS, N_CH, N_VIEWS, X, Y) per SV.
        pred_key: View key used for saving; required unless ``return_pred``.
        svs: SegmentationObjects in the same order and number as ``views``;
            required unless ``return_pred``.
        return_pred: If True, return the label views instead of saving them.
        nb_cpus: Number of processes used for saving.
        verbose: Forwarded to the prediction.
        bs: Batch size used during inference.

    Returns:
        If ``return_pred`` is True, a list with one label-view array per SV
        (empty list for empty ``views``); otherwise None.

    Raises:
        ValueError: If saving is requested but ``svs`` or ``pred_key`` is
            missing, or ``svs`` and ``views`` differ in length.
    """
    if not return_pred and (svs is None or pred_key is None):
        raise ValueError('SV objects and "pred_key" have to be given if'
                         ' predictions should be saved at SV view storages.')
    if not return_pred and len(svs) != len(views):
        # zip() below would otherwise silently drop predictions
        raise ValueError(f'Number of SV objects ({len(svs)}) does not match '
                         f'the number of view arrays ({len(views)}).')
    if len(views) == 0:
        # np.concatenate cannot handle an empty list
        return [] if return_pred else None
    # boundaries of every SV's locations in the concatenated array
    part_views = np.cumsum([0] + [len(v) for v in views])
    # merge N_SV and N_LOCS axes -> (sum(N_LOCS), N_CH, N_VIEWS, X, Y)
    label_views = predict_views_semseg(np.concatenate(views), model,
                                       verbose=verbose, batch_size=bs)
    svs_labelviews = [label_views[start:stop]
                      for start, stop in zip(part_views[:-1], part_views[1:])]
    if return_pred:
        return svs_labelviews
    params = [[sv, dict(views=sv_label_views, index_views=False, woglia=True,
                        view_key=pred_key)]
              for sv, sv_label_views in zip(svs, svs_labelviews)]
    start_multiprocess_obj('save_views', params, nb_cpus=nb_cpus)
def pred_sv_chunk_semseg(args):
    """
    Predict semantic segmentation label views for all SVs stored in a set of
    view-storage chunks and push them to the matching label-view storages.

    Args:
        args: Sequence ``(so_chunk_paths, so_kwargs, pred_kwargs)``:
            * so_chunk_paths: Directories of SV view-storage chunks.
            * so_kwargs: Keyword arguments forwarded to ``sos_dict_fact``.
            * pred_kwargs: Must contain ``"pred_key"`` (view key of the label
              views). Optional ``"woglia"`` (default True) selects
              ``views_woglia.pkl`` instead of ``views.pkl``; optional
              ``"raw_only"`` (default False) keeps only the first view
              channel. The dict is not modified.

    Returns:
        None. Label views are written per chunk to
        ``svs[0].view_path(woglia, view_key=pred_key)``.
    """
    from syconn.proc.sd_proc import sos_dict_fact, init_sos
    from syconn.backend.storage import CompressedStorage
    from syconn.handler.prediction import get_semseg_spiness_model
    so_chunk_paths = args[0]
    so_kwargs = args[1]
    pred_kwargs = args[2]
    # Read options without deleting them: the dict may be shared between
    # chunks/calls and must keep its keys.
    woglia = pred_kwargs.get("woglia", True)  # default: views after glia removal
    raw_only = pred_kwargs.get("raw_only", False)
    pred_key = pred_kwargs["pred_key"]
    model = get_semseg_spiness_model()
    for p in so_chunk_paths:
        # get raw views
        view_dc_p = p + "/views_woglia.pkl" if woglia else p + "/views.pkl"
        view_dc = CompressedStorage(view_dc_p, disable_locking=True)
        svixs = list(view_dc.keys())
        if len(svixs) == 0:
            continue
        views = list(view_dc.values())
        if raw_only:
            # ``views`` is a list of per-SV arrays; slice each one (a list
            # cannot be indexed with ``[:, :1]``).
            views = [v[:, :1] for v in views]
        sd = sos_dict_fact(svixs, **so_kwargs)
        svs = init_sos(sd)
        label_views = pred_svs_semseg(model, views, svs=svs, return_pred=True,
                                      verbose=False)
        # choose any SV to get a path constructor for the view storage
        # (is the same for all SVs of this chunk)
        lview_dc_p = svs[0].view_path(woglia, view_key=pred_key)
        label_vd = CompressedStorage(lview_dc_p, disable_locking=True)
        for sv, sv_label_views in zip(svs, label_views):
            label_vd[sv.id] = sv_label_views
        label_vd.push()
def gliapred_sso_nocache(sso: 'SuperSegmentationObject', model, verbose: bool = True):
    """
    Run multi-view astrocyte inference for every SV of ``sso`` without using
    cached locations or views.

    Views are rendered on the fly at the SV sample locations, predicted with
    ``predict_views`` and the probabilities are assigned to
    ``sso.svs[i].attr_dict['glia_probas']`` (in memory; nothing is saved here).

    Args:
        sso: SuperSegmentationObject with ``version == 'tmp'``.
        model: Model used for astrocyte inference.
        verbose: Forwarded to rendering and prediction.

    Returns:
        None

    Raises:
        AssertionError: If ``sso.version`` is not ``'tmp'``.
    """
    pred_key = "glia_probas"
    assert sso.version == 'tmp', 'Only use this method with ssv.version="tmp".'
    coords = sso.sample_locations(cache=False)
    # start offsets of every SV's locations in the flat array (N_SV + 1 entries)
    offsets = np.cumsum([0] + [len(sv_coords) for sv_coords in coords])
    all_views = render_sso_coords(sso, np.array(flatten_list(coords)), verbose=verbose,
                                  add_cellobjects=False, return_rot_mat=False)
    sv_views = [all_views[offsets[ix]:offsets[ix + 1]] for ix in range(len(sso.svs))]
    del all_views
    probas = predict_views(model, sv_views, None, return_proba=True, pred_key=pred_key,
                           nb_cpus=sso.nb_cpus, verbose=verbose)
    for ix, sv_proba in enumerate(probas):
        sso.svs[ix].attr_dict[pred_key] = sv_proba
@jit(nopython=True)
def semseg2mesh_counter(index_arr: np.ndarray, label_arr: np.ndarray,
                        bg_label: int, count_arr: np.ndarray) -> np.ndarray:
    """
    Accumulate per-vertex label counts from flattened index and label views.

    For every position ``pos``, ``count_arr[index_arr[pos], label_arr[pos]]``
    is incremented by one. Positions whose *vertex index* equals ``bg_label``
    are skipped, i.e. ``bg_label`` acts as the background value of
    ``index_arr``.

    NOTE(review): the skip compares the vertex index (not the predicted label)
    with ``bg_label``; confirm callers pass the background value of the index
    views. ``count_arr`` needs a row for every non-background vertex index and
    a column for every occurring label.

    Args:
        index_arr: Flat array of vertex indices, aligned with ``label_arr``.
        label_arr: Flat array of predicted labels, aligned with ``index_arr``.
        bg_label: Value in ``index_arr`` marking positions without a vertex.
        count_arr: Count array of shape (M, N_LABELS), typically
            zero-initialized; updated in place.

    Returns:
        ``count_arr`` (the same array) with the accumulated counts.
    """
    n_positions = len(index_arr)
    for pos in range(n_positions):
        vertex = index_arr[pos]
        if vertex != bg_label:
            count_arr[vertex][label_arr[pos]] += 1
    return count_arr
def semseg2mesh(sso, semseg_key, nb_views=None, dest_path=None, k=1,
                colors=None, force_recompute=False, index_view_key=None):
    """
    Map semantic segmentation view predictions onto the mesh of a
    SuperSegmentationObject (SSO), and optionally write the colored mesh to a
    k.zip file.

    Per-vertex labels are stored in ``sso.label_dict('vertex')[semseg_key]`` and
    pushed. They are recomputed only if missing or if `force_recompute` is set.

    Args:
        sso: The SuperSegmentationObject whose mesh will be colored based on
            semantic segmentation predictions.
        semseg_key: The key identifying the views containing the semantic
            segmentation results. It is also used as the vertex label key.
        nb_views: The number of views used for the prediction. It is used to derive
            the index view key ``"index{nb_views}"`` if `index_view_key` is not set.
        dest_path: If provided, the colored mesh is written to a k.zip file at
            this path.
        k: The number of nearest vertices used by ``colorcode_vertices`` to
            propagate predictions to all vertices. If k=0, no propagation is done
            and unpredicted vertices keep the 'unpredicted' label
            (background label + 1).
        colors: An array mapping labels to colors. If None, the labels are returned
            instead. If k==0, it must also contain a color for unpredicted
            vertices. Spine prediction example: if k=0, [neck, head, shaft, other,
            background, unpredicted]; otherwise [neck, head, shaft, other,
            background].
        force_recompute: If True, forces re-mapping of the predicted labels to the
            mesh vertices.
        index_view_key: The key identifying the views containing the vertex
            indices. If set, `nb_views` is ignored.

    Notes:
        * ``k>0`` should only be used if a prediction for all vertices is
          absolutely required. Filtering of background and unpredicted vertices
          should be favored if time complexity is critical.

    Returns:
        If `dest_path` is None, a tuple of the mesh indices, vertices, normals and
        colors (or labels if `colors` is None). Otherwise None, and the colored
        mesh is written to `dest_path`.

    Raises:
        ValueError: If the 'unpredicted' label (background label + 1) does not fit
            into uint8.
    """
    ld = sso.label_dict('vertex')
    if force_recompute or semseg_key not in ld:
        if nb_views is None and index_view_key is None:
            # load default index views
            i_views = sso.load_views(index_views=True).flatten()
        else:
            if index_view_key is None:
                index_view_key = "index{}".format(nb_views)
            # load special views
            i_views = sso.load_views(index_view_key).flatten()
        semseg_views = sso.load_views(semseg_key).flatten()
        # the highest vertex ID in the index views marks background pixels
        background_id = np.max(i_views)
        # background label is highest label in prediction (see 'generate_palette' or
        # 'remap_rgb_labelviews' in multiviews.py)
        # TODO: this will fail if no single pixel in all views is background
        # cast to int: `+ 1` on a uint8 numpy scalar may wrap instead of exceeding 255
        background_l = int(np.max(semseg_views))
        unpredicted_l = background_l + 1
        # validate before `unpredicted_l` is written into the uint8 label array below
        if unpredicted_l > 255:
            raise ValueError('Overflow in label view array.')
        pp = len(sso.mesh[1]) // 3
        # A vertex can be hit by more than 255 pixels across all views. uint8
        # counters would silently wrap and corrupt the majority vote (and the
        # "no vote" mask below), so use a wider counter dtype.
        count_arr = np.zeros((pp, background_l + 1), dtype=np.uint32)
        count_arr = semseg2mesh_counter(i_views, semseg_views, background_id,
                                        count_arr)
        # np.argmax returns int64 array.. `colorcode_vertices` complexity is
        # sensitive to the datatype of vertex_labels!
        vertex_labels = np.argmax(count_arr, axis=1).astype(np.uint8)
        # vertices without a single vote are marked as unpredicted
        vertex_labels[np.sum(count_arr, axis=1) == 0] = unpredicted_l
        all_vertices = sso.mesh[1].reshape(-1, 3)
        if k == 0:  # map actual prediction situation / coverage
            # keep unpredicted vertices and vertices with background labels
            predicted_vertices = all_vertices
            predictions = vertex_labels
        else:
            # remove unpredicted vertices
            is_predicted = vertex_labels != unpredicted_l
            predicted_vertices = all_vertices[is_predicted]
            predictions = vertex_labels[is_predicted]
            # remove background class
            is_foreground = predictions != background_l
            predicted_vertices = predicted_vertices[is_foreground]
            predictions = predictions[is_foreground]
        if k > 0:
            # map predictions of predicted vertices to all vertices (high time complexity!)
            maj_vote = colorcode_vertices(all_vertices, predicted_vertices, predictions,
                                          k=k, return_color=False, nb_cpus=sso.nb_cpus)
        else:  # no vertex mask was applied in this case
            maj_vote = predictions
        # add prediction to mesh storage
        ld[semseg_key] = maj_vote
        ld.push()
    else:
        maj_vote = ld[semseg_key].astype(np.int32)
    if colors is not None:
        col = colors[maj_vote].astype(np.uint8)
        if np.sum(col) == 0:
            log_reps.warn('All colors-zero warning during "semseg2mesh"'
                          ' of SSO {}. Make sure color values have uint8 range '
                          '0...255'.format(sso.id))
    else:
        col = maj_vote
    if dest_path is not None:
        if colors is None:
            # write_mesh2kzip only supports RGBA colors and no labels
            col = None
        write_mesh2kzip(dest_path, sso.mesh[0], sso.mesh[1], sso.mesh[2],
                        col, ply_fname=semseg_key + ".ply")
        return
    return sso.mesh[0], sso.mesh[1], sso.mesh[2], col
def celltype_of_sso_nocache(sso, model, ws, nb_views, comp_window, nb_views_model: int = 20,
                            pred_key_appendix: str = "", verbose: bool = False,
                            overwrite: bool = True, use_syntype: bool = True,
                            da_equals_tan: bool = True, n_classes: int = 7, save_to_attr_dict: bool = True):
    """
    Predict the cell type of a SuperSegmentationObject without file-system
    caching of views.

    Raw views are rendered at rendering locations spaced according to
    `comp_window` (about three per window) and kept in the in-memory
    ``sso.view_dict``. They are then predicted with `model`. The view-level
    predictions are combined by majority vote. The results are written to
    ``sso.attr_dict`` as ``'celltype_cnn_e3' + pred_key_appendix`` (label),
    ``..._probas`` (per-view probabilities) and ``..._certainty``.

    Args:
        sso: SuperSegmentationObject to be processed. ``sso.view_caching`` must be
            True.
        model: Model whose ``predict_proba`` returns per-view class probabilities.
        ws: Window size in pixels [y, x] of each view.
        nb_views: Number of views rendered at each location. Views are not stored
            on disk.
        comp_window: Physical extent in nm of the view-window along y.
        nb_views_model: Number of views per model input sample (passed to
            ``sso_views_to_modelinput``).
        pred_key_appendix: Appendix for the prediction keys and the temporary view
            key.
        verbose: If True, adds progress bars for view generation and debug logs.
        overwrite: If False and the prediction key already exists, return early.
            If True, views are re-rendered even if present in the temporary view
            dictionary.
        use_syntype: If True, the symmetric-synapse ratios of axon (1) and
            dendrite (0) compartments are passed to the model as additional input.
        da_equals_tan: If True, the probabilities of class 6 are added to class 1
            and column 6 is removed before the vote. Requires `n_classes` to be 7.
        n_classes: Number of classes used for the majority vote, after the optional
            DA/TAN merge. NOTE(review): earlier documentation described this as the
            model's number of outputs; confirm against the model definition.
        save_to_attr_dict: If True, persists the results via
            ``sso.save_attributes``. The in-memory ``sso.attr_dict`` is always
            updated.

    Returns:
        None

    Raises:
        ValueError: If a predicted label is >= `n_classes`.
    """
    sso.load_attr_dict()
    pred_key = "celltype_cnn_e3" + pred_key_appendix
    if not overwrite and pred_key in sso.attr_dict:
        return
    view_kwargs = dict(ws=ws, comp_window=comp_window, nb_views=nb_views, verbose=verbose, add_cellobjects=True,
                       return_rot_mat=False)
    verts = sso.mesh[1].reshape(-1, 3)
    rendering_locs = generate_rendering_locs(verts, comp_window / 3)  # three views per comp window
    # Overwrite default rendering locations (used later on for the view generation).
    # Sample locations are kept as a list of per-supervoxel arrays, hence the auxiliary
    # list. This matches `semseg_of_sso_nocache` and `view_embedding_of_sso_nocache`,
    # which may reuse the same in-memory state.
    sso._sample_locations = [rendering_locs]
    # this cache is only in-memory, and not file system cache
    assert sso.view_caching, "'view_caching' of {} has to be True in order to" \
                             " run 'celltype_of_sso_nocache'.".format(sso)
    tmp_view_key = 'tmp_views' + pred_key_appendix
    if tmp_view_key not in sso.view_dict or overwrite:
        views = render_sso_coords(sso, rendering_locs, **view_kwargs)  # shape: N, 4, nb_views, y, x
        sso.view_dict[tmp_view_key] = views  # required for `sso_views_to_modelinput`
    from ..handler.prediction import naive_view_normalization_new
    if verbose:
        log_reps.debug('Finished rendering. Starting cell type prediction.')
    inp_d = sso_views_to_modelinput(sso, nb_views_model, view_key=tmp_view_key)
    inp_d = naive_view_normalization_new(inp_d)
    if use_syntype:
        synsign_ratio = np.array([[syn_sign_ratio_celltype(sso, comp_types=[1, ]),
                                   syn_sign_ratio_celltype(sso, comp_types=[0, ])]] * len(inp_d))
        res = model.predict_proba((inp_d, synsign_ratio), bs=40)
    else:
        res = model.predict_proba(inp_d, bs=40)
    if verbose:
        log_reps.debug('Finished prediction.')
    # DA and TAN are type modulatory, if this is changes, also change `certainty_celltype`, `predict_sso_celltype`
    if da_equals_tan:
        assert n_classes == 7
        # accumulate evidence for DA and TAN
        res[:, 1] += res[:, 6]
        # remove TAN in proba array
        res = np.delete(res, [6], axis=1)
        # INT is now at index 6 -> label 6 is INT
    clf = np.argmax(res, axis=1)
    if np.max(clf) >= n_classes:
        raise ValueError('Unknown cell type predicted.')
    # majority vote over all view-level predictions (ties resolve to the lowest label)
    votes = np.bincount(clf, minlength=n_classes)
    pred = np.argmax(votes)
    sso.attr_dict[pred_key] = pred
    sso.attr_dict[f"{pred_key}_probas"] = res
    cert = sso.certainty_celltype(pred_key)
    sso.attr_dict[f"{pred_key}_certainty"] = cert
    if save_to_attr_dict:
        sso.save_attributes([pred_key, f"{pred_key}_probas", f"{pred_key}_certainty"], [pred, res, cert])
def view_embedding_of_sso_nocache(sso: 'SuperSegmentationObject', model: 'torch.nn.Module', ws: Tuple[int, int],
                                  nb_views: int, comp_window: Union[int, float], pred_key_appendix: str = "",
                                  verbose: bool = False, overwrite: bool = True,
                                  add_cellobjects: Union[bool, Iterable] = True):
    """
    Render views of a SuperSegmentationObject and predict their latent view
    embedding without file-system caching.

    Raw views are rendered at rendering locations spaced according to
    `comp_window` (about three per window) and kept only in the in-memory
    ``sso.view_dict``. Each skeleton node receives the latent vector of its
    nearest rendering location. The result is stored as
    ``sso.skeleton['latent_morph' + pred_key_appendix]`` and the skeleton is
    saved. ``predict_views_embedding`` in ``super_segmentation_object`` is an
    alternative that uses file-system caching.

    Args:
        sso: SuperSegmentationObject to process. ``sso.view_caching`` must be True.
        model: Network used for prediction. Its ``predict_proba`` is called with a
            3-tuple of view batches and must return a 5-tuple whose third element
            is the latent representation.
        ws: Window size in pixels [y, x].
        nb_views: Number of views rendered at each rendering location.
        comp_window: Physical extent in nm of the view-window along y.
        pred_key_appendix: Appendix for the skeleton key and the temporary view key.
        verbose: If True, adds progress bars for view generation.
        overwrite: If True, views are re-rendered even if they already exist in the
            in-memory view dictionary.
        add_cellobjects: Forwarded to the rendering. Either a bool or an iterable
            of cell structures to render.

    Returns:
        None
    """
    pred_key = "latent_morph" + pred_key_appendix
    view_kwargs = dict(ws=ws, comp_window=comp_window, nb_views=nb_views,
                       verbose=verbose, add_cellobjects=add_cellobjects,
                       return_rot_mat=False)
    # this cache is only in-memory, and not file system cache
    assert sso.view_caching, "'view_caching' of {} has to be True in order to" \
                             " run 'view_embedding_of_sso_nocache'.".format(sso)
    tmp_view_key = 'tmp_views' + pred_key_appendix
    if tmp_view_key not in sso.view_dict or overwrite:
        # the mesh is only needed to derive new rendering locations
        verts = sso.mesh[1].reshape(-1, 3)
        rendering_locs = generate_rendering_locs(verts, comp_window / 3)  # ~3 views per comp window
        # overwrite default rendering locations (used later on for the view generation)
        sso._sample_locations = rendering_locs[None, ]  # requires auxiliary axis
        # views shape: N, 4, nb_views, y, x
        views = render_sso_coords(sso, rendering_locs, **view_kwargs)
        sso.view_dict[tmp_view_key] = views
    else:
        views = sso.view_dict[tmp_view_key]
    from ..handler.prediction import naive_view_normalization_new
    views = naive_view_normalization_new(views)
    # The inference with TNets can be optimized by splitting the views into three equally sized parts.
    # For now only the first view per location is used; the other two inputs are dummies.
    first_view = views[:, :, 0]
    inp = (first_view, np.zeros_like(first_view), np.zeros_like(first_view))
    # returns dist1, dist2, inp1 latent, inp2 latent, inp3 latent
    _, _, latent, _, _ = model.predict_proba(inp)
    # map latent vecs at rendering locs to skeleton node locations via nearest neighbor
    sso.load_skeleton()
    # view location ordering same as views / latent
    hull_tree = spatial.cKDTree(np.concatenate(sso.sample_locations()))
    node_coords_nm = sso.skeleton["nodes"] * sso.scaling
    try:
        _, ixs = hull_tree.query(node_coords_nm, k=1, workers=sso.nb_cpus)
    except TypeError:
        # SciPy < 1.6 only supports the (since removed) `n_jobs` keyword
        _, ixs = hull_tree.query(node_coords_nm, k=1, n_jobs=sso.nb_cpus)
    sso.skeleton[pred_key] = latent[ixs]
    sso.save_skeleton()
def semseg_of_sso_nocache(sso, model, semseg_key: str, ws: Tuple[int, int],
                          nb_views: int, comp_window: float, k: int = 1,
                          dest_path: Optional[str] = None, verbose: bool = False,
                          add_cellobjects: Union[bool, Iterable] = True, bs: int = 10):
    """
    Render raw and index views of a cell without storing them on the file system,
    predict them with `model`, and map the prediction results onto the mesh.

    Views are rendered at ``sso._sample_locations``. If those are None,
    ``generate_rendering_locs(verts, comp_window / 3)`` is used to create them
    (about three views per window). Vertex labels are stored on the file system
    and can be accessed via ``sso.label_dict('vertex')[semseg_key]``.

    Examples:
        Given a cell reconstruction exported as k.zip at ``cell_kzip_fn``, the
        compartment prediction (axon boutons, dendrite, soma) can be started
        via the following script::

            # set working directory to obtain models
            global_params.wd = '~/SyConn/example_cube1/'

            # get model for compartment detection
            m = get_semseg_axon_model()
            view_props = global_params.config['compartments']['view_properties_semsegax']
            view_props["verbose"] = True

            # load SSO instance from k.zip file
            sso = init_sso_from_kzip(cell_kzip_fn, sso_id=1)

            # run prediction and store result in new kzip
            cell_kzip_fn_axon = cell_kzip_fn[:-6] + '_axon.k.zip'
            semseg_of_sso_nocache(sso, dest_path=cell_kzip_fn_axon, model=m,
                                  **view_props)

        See also the example scripts at::

            $ python SyConn/examples/semseg_axon.py
            $ python SyConn/examples/semseg_spine.py

    Args:
        sso: Cell reconstruction object to be processed. ``sso.view_caching`` must
            be True.
        model: The model used for prediction.
        semseg_key: The key used to store the resulting prediction.
        ws: Window size in pixels (y, x).
        nb_views: Number of views rendered at each rendering location.
        comp_window: Physical extent in nm of the view-window along the y-axis.
        k: Number of nearest vertices used for mesh mapping. If k=0, unpredicted
            vertices are treated as the 'unpredicted' class.
        dest_path: If given, the file path of the colored mesh k.zip file.
        verbose: If True, adds progress bars for view generation and debug logs.
        add_cellobjects: Forwarded to ``sso.predict_semseg`` for raw view
            rendering. Either a bool or a list of structures to render.
        bs: Batch size during inference.

    Returns:
        None
    """
    # validate before mutating any state of `sso`
    assert sso.view_caching, "'view_caching' of {} has to be True in order to" \
                             " run 'semseg_of_sso_nocache'.".format(sso)
    view_kwargs = dict(ws=ws, comp_window=comp_window, nb_views=nb_views,
                       verbose=verbose, save=False)
    raw_view_key = 'raw{}_{}_{}'.format(ws[0], ws[1], nb_views)
    index_view_key = 'index{}_{}_{}'.format(ws[0], ws[1], nb_views)
    # use default rendering locations (used later on for the view generation)
    if sso._sample_locations is None:
        # the mesh is only needed to derive rendering locations; ~three views per comp window
        verts = sso.mesh[1].reshape(-1, 3)
        sso._sample_locations = [generate_rendering_locs(verts, comp_window / 3)]
    # this generates the raw views and their prediction
    sso.predict_semseg(model, semseg_key, raw_view_key=raw_view_key,
                       add_cellobjects=add_cellobjects, bs=bs, **view_kwargs)
    if verbose:
        log_reps.debug('Finished shape-view rendering and sem. seg. prediction.')
    # this generates the index views
    sso.render_indexviews(view_key=index_view_key, force_recompute=True,
                          **view_kwargs)
    if verbose:
        log_reps.debug('Finished index-view rendering.')
    # map prediction onto mesh and saves it to sso._label_dict['vertex'][semseg_key] (also pushed to file system!)
    sso.semseg2mesh(semseg_key, index_view_key=index_view_key, dest_path=dest_path,
                    force_recompute=True, k=k)
    if verbose:
        log_reps.debug('Finished mapping of vertex predictions to mesh.')
def assemble_from_mergelist(ssd: 'SuperSegmentationDataset', mergelist: Union[Dict[int, int], str]):
    """
    Build the supervoxel mapping of `ssd` from a mergelist and save the dataset
    shallowly.

    Existing mapping dict, id changer and version files are overwritten.

    Args:
        ssd: The dataset to be updated with the new mapping. Its ``version_dict``
            must contain an ``'sv'`` entry.
        mergelist: Supervoxel agglomeration. Either a dictionary mapping each
            supervoxel ID (key) to the ID of the agglomerate it belongs to (value),
            or the path to a mergelist file that is parsed with
            ``subobject_map_from_mergelist``.

    Returns:
        None

    Raises:
        TypeError: If `mergelist` is neither a dict nor a str (including None).
    """
    assert "sv" in ssd.version_dict
    if isinstance(mergelist, str):
        # Resolved locally (with the same fallback as the module-level import) so
        # that this does not depend on the module-level alias, whose fallback
        # branch binds `mergelist_tool` rather than `mergelist_tools`.
        try:
            from knossos_utils import mergelist_tools as ml_tools
        except ImportError:
            from knossos_utils import mergelist_tools_fallback as ml_tools
        with open(mergelist, "r") as f:
            mergelist = ml_tools.subobject_map_from_mergelist(f.read())
    elif not isinstance(mergelist, dict):
        raise TypeError("mergelist has unknown type {}; expected dict or str."
                        "".format(type(mergelist)))
    # invert: agglomerate ID -> list of its supervoxel IDs (insertion order preserved)
    mapping_dict = dict()
    for sv_id, agglo_id in mergelist.items():
        mapping_dict.setdefault(agglo_id, []).append(sv_id)
    ssd._mapping_dict = mapping_dict
    ssd.create_mapping_lookup_reverse()
    ssd.save_dataset_shallow(overwrite=True)
def compartments_graph(ssv: 'super_segmentation.SuperSegmentationObject',
                       axoness_key: str) -> Tuple['nx.Graph', 'nx.Graph', 'nx.Graph']:
    """
    Create separate graphs for the dendrite, axon and soma compartments based on
    skeleton node predictions.

    Each returned graph is a copy of ``ssv.weighted_graph(add_node_attr=[axoness_key])``
    from which the nodes of the two other compartments are removed. Nodes with
    labels other than 0-4 are kept in all three graphs.

    Args:
        ssv: Cell reconstruction object. Its skeleton must exist and must contain
            the keys ``'edges'``, ``'nodes'`` and `axoness_key`.
        axoness_key: Key of the compartment predictions in ``ssv.skeleton``
            (0: dendrite, 1: axon, 2: soma). Labels 3 (en-passant bouton) and 4
            (terminal bouton) are treated as 1 (axon).

    Returns:
        Graphs for the dendrite, axon and soma compartments, in that order.
    """
    axon_prediction = np.array(ssv.skeleton[axoness_key])
    # boutons belong to the axon compartment
    axon_prediction[axon_prediction == 3] = 1
    axon_prediction[axon_prediction == 4] = 1
    # np.flatnonzero yields the node indices themselves. np.nonzero returns a tuple of
    # index arrays, and iterating over that tuple would pass whole (unhashable)
    # arrays to `remove_node`.
    # NOTE(review): assumes graph nodes are the skeleton node indices.
    axon_ixs = np.flatnonzero(axon_prediction == 1)
    dendrite_ixs = np.flatnonzero(axon_prediction == 0)
    soma_ixs = np.flatnonzero(axon_prediction == 2)
    so_graph = ssv.weighted_graph(add_node_attr=[axoness_key])
    ax_graph = so_graph.copy()
    den_graph = so_graph.copy()
    for ix in axon_ixs:
        so_graph.remove_node(ix)
        den_graph.remove_node(ix)
    for ix in dendrite_ixs:
        so_graph.remove_node(ix)
        ax_graph.remove_node(ix)
    for ix in soma_ixs:
        ax_graph.remove_node(ix)
        den_graph.remove_node(ix)
    return den_graph, ax_graph, so_graph
def syn_sign_ratio_celltype(ssv: 'super_segmentation.SuperSegmentationObject', weighted: bool = True,
                            recompute: bool = False, comp_types: Optional[List[int]] = None,
                            save: bool = False) -> float:
    """
    Compute the ratio of symmetric synapses (``syn_sign == -1``) on the given
    compartments of a cell reconstruction, based on ``ssv.syn_ssv``.

    Partner cell compartment information is not used. See
    `~syconn.reps.super_segmentation_object.SuperSegmentationObject.syn_sign_ratio`
    for a variant that includes it.

    Todo:
        * Check default of synapse type if synapse type predictions are not
          available -> propagate to this method and return -1.

    Notes:
        * Bouton predictions are converted into the axon label, i.e. 3 -> 1
          (en-passant) and 4 -> 1 (terminal).
        * The compartment of each synapse is looked up from its representative
          coordinate via ``ssv.attr_for_coords`` using the averaged axoness key.
        * The compartment type of the other cell cannot be inferred at this
          point. Adding the property collection before celltype prediction would
          allow more detailed filtering of the synapses, but adds an additional
          round of property collection.

    Args:
        ssv: The cell reconstruction.
        weighted: If True, compute the synapse-area weighted ratio.
        recompute: If True, ignore an existing cached value and recompute.
        comp_types: Compartment types of this cell whose synapses are considered.
            Defaults to ``[1]`` (axon only).
        save: If True, store the result (including -1) in the attribute dict under
            ``'syn_sign_ratio_celltype[_weighted]_<comp_types>'``.

    Returns:
        The (area-weighted) ratio of symmetric synapses, or -1 if no synapses
        (or no synapse area) are present.

    Raises:
        ValueError: If ``syn_sign`` or ``mesh_area`` of a considered synapse is
            missing.
    """
    if comp_types is None:
        comp_types = [1, ]
    ratio_key = 'syn_sign_ratio_celltype'
    if weighted:
        ratio_key += '_weighted'
    ratio_key += '_' + "_".join([str(el) for el in comp_types])
    ratio = ssv.lookup_in_attribute_dict(ratio_key)
    if not recompute and ratio is not None:
        return ratio
    pred_key_ax = "{}_avg{}".format(global_params.config['compartments']['view_properties_semsegax']['semseg_key'],
                                    global_params.config['compartments']['dist_axoness_averaging'])
    if len(ssv.syn_ssv) == 0:
        # honor `save` consistently with the "no valid synapses" case below
        if save:
            ssv.save_attributes([ratio_key], [-1])
        return -1
    props = load_so_attr_bulk(ssv.syn_ssv, ('syn_sign', 'mesh_area', 'rep_coord'), allow_missing=True,
                              use_new_subfold=global_params.config.use_new_subfold)
    rep_coords = [props['rep_coord'][syn.id] for syn in ssv.syn_ssv]
    # np.asarray keeps the boolean-mask assignments below valid even if a plain
    # sequence is returned. NOTE(review): confirm the return type of `attr_for_coords`.
    syn_axs = np.asarray(ssv.attr_for_coords(rep_coords, attr_keys=[pred_key_ax, ])[0])
    # convert boutons to axon class
    syn_axs[syn_axs == 3] = 1
    syn_axs[syn_axs == 4] = 1
    syn_signs = []
    syn_sizes = []
    for syn_ix, syn in enumerate(ssv.syn_ssv):
        if syn_axs[syn_ix] not in comp_types:
            continue
        syn_sign = props['syn_sign'][syn.id]
        mesh_area = props['mesh_area'][syn.id]
        # Check before dividing: with `allow_missing=True` values may be None, and
        # `None / 2` would raise an unrelated TypeError instead of this error.
        if syn_sign is None or mesh_area is None:
            raise ValueError(f'Got at least one None value for syn_sign and/or syn_size of {syn}.')
        syn_signs.append(syn_sign)
        # NOTE(review): halved presumably because the synapse mesh area covers both
        # sides of the contact; confirm.
        syn_sizes.append(mesh_area / 2)
    if len(syn_signs) == 0 or np.sum(syn_sizes) == 0:
        if save:
            ssv.save_attributes([ratio_key], [-1])
        return -1
    syn_signs = np.array(syn_signs)
    syn_sizes = np.array(syn_sizes)
    if weighted:
        ratio = np.sum(syn_sizes[syn_signs == -1]) / float(np.sum(syn_sizes))
    else:
        ratio = np.sum(syn_signs == -1) / float(len(syn_signs))
    if save:
        ssv.save_attributes([ratio_key], [ratio])
    return ratio
def sso_svgraph2kzip(dest_path: str, sso: 'SuperSegmentationObject'):
    """
    Store the supervoxel graph of a SuperSegmentationObject in a KNOSSOS
    compatible k.zip file.

    Each supervoxel becomes exactly one skeleton node, placed at its
    representative coordinate multiplied by ``sso.scaling``, with its ID in
    ``node.data['svid']``. Each edge of ``sso.load_sv_edgelist()`` becomes a
    skeleton edge between the corresponding nodes.

    Args:
        dest_path: The file path where the k.zip will be stored.
        sso: The SuperSegmentationObject whose supervoxel graph is stored.

    Returns:
        None
    """
    sv_edges = sso.load_sv_edgelist()
    anno = SkeletonAnnotation()
    anno.scaling = sso.scaling
    sd = sso.get_seg_dataset('sv')
    # supervoxel ID -> representative coordinate (in voxels)
    coord_dc = dict(zip(sd.ids, sd.rep_coords))
    # NOTE(review): node coordinates are written as voxel * scaling while
    # `anno.scaling` is also set; confirm this is the intended unit.
    nodes = dict()

    def _node_of(sv):
        # Create (and register) one node per supervoxel. Reusing it for every edge
        # makes the exported skeleton connected like the actual SV graph, instead
        # of creating duplicated, disconnected node pairs.
        if sv.id not in nodes:
            coord = coord_dc[sv.id] * sso.scaling
            node = SkeletonNode().from_scratch(anno, coord[0], coord[1], coord[2])
            node.data['svid'] = sv.id
            anno.addNode(node)
            nodes[sv.id] = node
        return nodes[sv.id]

    for sv1, sv2 in sv_edges:
        anno.addEdge(_node_of(sv1), _node_of(sv2))
    dummy_skel = Skeleton()
    dummy_skel.add_annotation(anno)
    dummy_skel.to_kzip(dest_path)