# -*- coding: utf-8 -*-
# SyConn - Synaptic connectivity inference toolkit
#
# Copyright (c) 2016 - now
# Max Planck Institute of Neurobiology, Martinsried, Germany
# Authors: Philipp Schubert, Joergen Kornfeld
# import here, otherwise it might fail if it is imported after importing torch
# see https://github.com/pytorch/pytorch/issues/19739
try:
import open3d as o3d
except ImportError:
pass # for sphinx build
import os
import re
import shutil
from collections import Counter
from logging import Logger
from typing import Iterable, Union, Optional, Any, Tuple, List
import numpy as np
from knossos_utils import knossosdataset
from knossos_utils.chunky import ChunkDataset, save_dataset
from knossos_utils.knossosdataset import KnossosDataset
from scipy.special import softmax
from scipy.stats import entropy
from sklearn.decomposition import PCA
from sklearn.neighbors import KNeighborsClassifier
from .basics import read_txt_from_zip, get_filepaths_from_dir, \
parse_cc_dict_from_kzip
from .compression import load_from_h5py, save_to_h5py
from .. import global_params
from ..handler import log_handler, log_main, basics
from ..handler.basics import chunkify
from ..handler.config import initialize_logging
from ..mp import batchjob_utils as qu
from ..proc.image import apply_morphological_operations
from ..reps import log_reps
# for readthedocs build
try:
import torch
except ImportError:
pass
def load_gt_from_kzip(zip_fname, kd_p, raw_data_offset=75, verbose=False,
                      mag=1):
    """
    Loads ground truth from zip file, generated with Knossos. Corresponding
    dataset config file is located at `kd_p`.

    Args:
        zip_fname (str): Path to the zip file containing ground truth data.
        kd_p (str, bytes or List[str]): Path or list of paths to the Knossos
            dataset configuration file(s). A single path is wrapped in a list.
        raw_data_offset (int or np.array): Number of voxels for additional raw
            offset, i.e., the offset for the raw data will be label_offset -
            raw_data_offset, while the raw data volume will be label_volume +
            2*raw_data_offset. A scalar is adapted to the dataset anisotropy
            using `kd.scale` of the first dataset; otherwise a sequence of
            length 3 has to be provided for a custom x, y, z offset.
            Defaults to 75.
        verbose (bool, optional): Enables verbose output. Defaults to False.
        mag (int): Magnification level of the data. Defaults to 1.

    Returns:
        tuple: A tuple containing two numpy arrays:
            - raw data (float32) normalized by 1/255, one leading entry per
              dataset in `kd_p`.
            - label data as returned by `kd.load_kzip_seg` (not cast here),
              one leading entry per dataset. If the merge list of the k.zip
              cannot be parsed, an all-zero float array of shape `size`
              (bounding box extent of the last dataset) is returned instead.

    Raises:
        ValueError: If `raw_data_offset` is a sequence with a length other
            than 3.
    """
    if isinstance(kd_p, (str, bytes)):
        kd_p = [kd_p]
    raw_data = []
    label_data = []
    for curr_p in kd_p:
        kd = basics.kd_factory(curr_p)
        bb = kd.get_movement_area(zip_fname)
        offset, size = bb[0], bb[1] - bb[0]
        scaling = np.array(kd.scale, dtype=np.int32)
        # NOTE: the scalar is converted on the first iteration only; all
        # following datasets re-use the resulting 3-vector.
        if np.isscalar(raw_data_offset):
            raw_data_offset = np.array(scaling[0] * raw_data_offset / scaling,
                                       dtype=np.int32)
            if verbose:
                log_handler.debug(f'Using scale adapted raw offset: {raw_data_offset}')
        elif len(raw_data_offset) != 3:
            raise ValueError("Offset for raw cubes has to have length 3.")
        else:
            raw_data_offset = np.array(raw_data_offset)
        raw = kd.load_raw(size=(size // mag + 2 * raw_data_offset) * mag,
                          offset=(offset // mag - raw_data_offset) * mag,
                          mag=mag).swapaxes(0, 2)
        raw_data.append(raw[None,])
        label = kd.load_kzip_seg(zip_fname, mag=mag).swapaxes(0, 2)
        label_data.append(label[None,])
    raw = np.concatenate(raw_data, axis=0).astype(np.float32)
    label = np.concatenate(label_data, axis=0)
    try:
        parse_cc_dict_from_kzip(zip_fname)
    except Exception:  # mergelist.txt does not exist
        # no annotated objects -> empty label volume
        label = np.zeros(size)
    return raw / 255., label
def predict_kzip(kzip_p, m_path, kd_path, clf_thresh=0.5, mfp_active=False,
                 dest_path=None, overwrite=False, gpu_ix=0,
                 imposed_patch_size=None):
    """
    Predicts data contained in a k.zip file using a specified predictive model.

    The dense prediction and the raw data are cached in
    ``<dest_path>/<cube_name>_data.h5``. The thresholded prediction is written
    to ``<dest_path>/<cube_name>_pred.k.zip``.

    Args:
        kzip_p (str): Path to the k.zip file containing raw data cube information.
        m_path (str): Path to the predictive model.
        kd_path (str): Path to the Knossos dataset configuration file.
        clf_thresh (float, optional): Classification threshold; voxels with
            ``pred >= clf_thresh`` are written as 1. Defaults to 0.5.
        mfp_active (bool, optional): Flag to activate max-fragment pooling. Defaults
            to False.
        dest_path (str, optional): Destination folder path. If None, the folder of
            k.zip is used. Defaults to None.
        overwrite (bool, optional): If True, the prediction is recomputed even if
            the cached h5 file exists. Defaults to False.
        gpu_ix (int, optional): GPU index for model prediction. Defaults to 0.
        imposed_patch_size (tuple, optional): Imposed patch size for the model.
            A tuple is converted to a list before it is handed to `modelload`.

    Returns:
        None. The results are written to disk.
    """
    cube_name = os.path.splitext(os.path.basename(kzip_p))[0]
    if dest_path is None:
        dest_path = os.path.dirname(kzip_p)
    # os.path.join keeps a relative path relative when `dest_path` is ''
    # (k.zip in the working directory) instead of pointing at '/'.
    h5_path = os.path.join(dest_path, "%s_data.h5" % cube_name)
    if overwrite or not os.path.isfile(h5_path):
        # ELEKTRONN2 is only required if a prediction has to be computed
        from elektronn2.utils.gpu import initgpu
        raw, _ = load_gt_from_kzip(kzip_p, kd_p=kd_path,
                                   raw_data_offset=0)
        raw = xyz2zxy(raw)
        initgpu(gpu_ix)
        from elektronn2.neuromancer.model import modelload
        if isinstance(imposed_patch_size, tuple):
            imposed_patch_size = list(imposed_patch_size)
        m = modelload(m_path, imposed_patch_size=imposed_patch_size,
                      override_mfp_to_active=mfp_active, imposed_batch_size=1)
        # disable dropout for inference
        m.dropout_rates = [0.0] * len(m.dropout_rates)
        pred = m.predict_dense(raw[None,], pad_raw=True)[1]
        # back to XYZ order for storage
        pred = zxy2xyz(pred)
        raw = zxy2xyz(raw)
        save_to_h5py([pred, raw], h5_path, ["pred", "raw"])
    else:
        pred, raw = load_from_h5py(h5_path, hdf5_names=["pred", "raw"])
    offset = parse_movement_area_from_zip(kzip_p)[0]
    overlaycubes2kzip(os.path.join(dest_path, "%s_pred.k.zip" % cube_name),
                      (pred >= clf_thresh).astype(np.uint32),
                      offset, kd_path)
def predict_h5(h5_path, m_path, clf_thresh=None, mfp_active=False,
               gpu_ix=0, imposed_patch_size=None, hdf5_data_key=None,
               data_is_zxy=True, dest_p=None, dest_hdf5_data_key="pred",
               as_uint8=True):
    """
    Predicts data from an h5 file using a specified predictive model.

    Args:
        h5_path (str): Path to the h5 file containing raw data.
        m_path (str): Path to the predictive model.
        clf_thresh (float, optional): Classification threshold applied to the
            model output before any uint8 conversion. If None (or 0), no
            thresholding is applied.
        mfp_active (bool): Flag to activate max-fragment pooling. Defaults to False.
        gpu_ix (int, optional): GPU index for model prediction. Defaults to 0.
        imposed_patch_size (tuple, optional): Imposed patch size for the model.
        hdf5_data_key (str, optional): Key for raw data in the h5 file. If None, the
            file must contain exactly one entry, which is used.
        data_is_zxy (bool): Flag to indicate data order. If False, data is assumed to
            be in [X, Y, Z] format and is converted to ZXY before prediction.
        dest_p (str): Destination path for the prediction output. Defaults to
            ``<h5_path without extension>_pred.h5``.
        dest_hdf5_data_key (str, optional): Key for predicted data in the output h5
            file. Defaults to 'pred'.
        as_uint8 (bool, optional): Store raw (and the un-thresholded prediction)
            as uint8 scaled by 255. Defaults to True.

    Returns:
        None. Raw data and prediction are written to `dest_p` in XYZ order.
    """
    if hdf5_data_key:
        raw = load_from_h5py(h5_path, hdf5_names=[hdf5_data_key])[0]
    else:
        raw = load_from_h5py(h5_path, hdf5_names=None)
        assert len(raw) == 1, "'hdf5_data_key' not given but multiple hdf5 " \
                              "elements found. Please define raw data key."
        raw = raw[0]
    if not data_is_zxy:
        raw = xyz2zxy(raw)
    from elektronn2.utils.gpu import initgpu
    initgpu(gpu_ix)
    # integer raw data is normalized to [0, 1]
    if raw.dtype.kind in ('u', 'i'):
        raw = raw.astype(np.float32) / 255.
    from elektronn2.neuromancer.model import modelload
    if isinstance(imposed_patch_size, tuple):
        imposed_patch_size = list(imposed_patch_size)
    m = modelload(m_path, imposed_patch_size=imposed_patch_size,
                  override_mfp_to_active=mfp_active, imposed_batch_size=1)
    # disable dropout for inference
    m.dropout_rates = [0.0] * len(m.dropout_rates)
    pred = m.predict_dense(raw[None,], pad_raw=True)[1]
    pred = zxy2xyz(pred)
    raw = zxy2xyz(raw)
    if clf_thresh:
        # Threshold the model output itself. Previously the comparison was
        # carried out after scaling to 0..255, which made any fractional
        # threshold meaningless.
        pred = (pred >= clf_thresh).astype(np.float32)
    elif as_uint8:
        pred = (pred * 255).astype(np.uint8)
    if as_uint8:
        raw = (raw * 255).astype(np.uint8)
    if dest_p is None:
        dest_p = os.path.splitext(h5_path)[0] + "_pred.h5"
    if hdf5_data_key is None:
        hdf5_data_key = "raw"
    save_to_h5py([raw, pred], dest_p, [hdf5_data_key, dest_hdf5_data_key])
def overlaycubes2kzip(dest_p: str, vol: np.ndarray, offset: np.ndarray,
                      kd_path: str):
    """
    Writes a segmentation volume to a k.zip file.

    Args:
        dest_p (str): Path to the destination k.zip file.
        vol (np.ndarray): Segmentation or prediction volume as an unsigned integer
            array in XYZ order.
        offset (np.ndarray): Offset of the volume in the dataset.
        kd_path (str): Path to the Knossos dataset configuration file.

    Returns:
        None. The volume is written to `dest_p` at magnification 1.
    """
    dataset = basics.kd_factory(kd_path)
    dataset.from_matrix_to_cubes(data=vol, offset=offset, mags=[1],
                                 kzip_path=dest_p)
def xyz2zxy(vol: np.ndarray) -> np.ndarray:
    """
    Reorders the last three axes to the ELEKTRONN convention
    ([M, .., X, Y, Z] -> [M, .., Z, X, Y]).

    Args:
        vol (np.ndarray): Input volume in the shape [M, .., X, Y, Z].

    Returns:
        np.ndarray: Reordered volume in the shape [M, .., Z, X, Y].
    """
    # no ndim restriction on purpose (multi-channel support)
    # x y z -> z y x -> z x y
    return vol.swapaxes(-1, -3).swapaxes(-1, -2)
def zxy2xyz(vol: np.ndarray) -> np.ndarray:
    """
    Reverts the ELEKTRONN axis convention on the last three axes
    ([M, .., Z, X, Y] -> [M, .., X, Y, Z]).

    Args:
        vol (np.ndarray): Input volume with shape [M, .., Z, X, Y].

    Returns:
        np.ndarray: The volume rearranged to shape [M, .., X, Y, Z].
    """
    # no ndim restriction on purpose (multi-channel support)
    # z x y -> y x z -> x y z
    return vol.swapaxes(-3, -1).swapaxes(-3, -2)
def xyz2zyx(vol: np.ndarray) -> np.ndarray:
    """
    Reverses the order of the last three axes
    ([M, .., X, Y, Z] -> [M, .., Z, Y, X]).

    Args:
        vol: np.array [M, .., X, Y, Z]

    Returns: np.array [M, .., Z, Y, X]
    """
    # no ndim restriction on purpose (multi-channel support)
    return vol.swapaxes(-3, -1)  # [..., z, y, x]
def zyx2xyz(vol: np.ndarray) -> np.ndarray:
    """
    Reverses the order of the last three axes
    ([M, .., Z, Y, X] -> [M, .., X, Y, Z]).

    Args:
        vol: np.array [M, .., Z, Y, X]

    Returns: np.array [M, .., X, Y, Z]
    """
    # no ndim restriction on purpose (multi-channel support)
    return vol.swapaxes(-3, -1)  # [..., x, y, z]
def create_h5_from_kzip(zip_fname: str, kd_p: str,
                        foreground_ids: Optional[Iterable[int]] = None,
                        overwrite: bool = True, raw_data_offset: int = 75,
                        debug: bool = False, mag: int = 1,
                        squeeze_data: bool = True,
                        target_labels: Optional[Iterable[int]] = None,
                        apply_mops_seg: Optional[List[str]] = None):
    """
    Create .h5 files for elektronn3 (zyx) input. Labels are binary
    (0=background, 1=foreground) unless `target_labels` is given.

    Examples:
        Suppose your k.zip file contains the segmentation GT with two
        segmentation IDs 1, 2 and is stored at ``zip_fname``. The
        corresponding ``KnossosDataset`` is located at ``kd_p``.
        The following code snippet will create an ``.h5`` file in the
        folder of ``zip_fname`` with the raw data (additional offset
        controlled by ``raw_data_offset``) and the label data (either
        binary or defined by ``target_labels``) with the keys ``raw`` and
        ``label`` respectively::

            create_h5_from_kzip(kd_p=kd_p, raw_data_offset=75,
                                zip_fname=zip_fname, mag=1, foreground_ids=[1, 2],
                                target_labels=[1, 2])

    Args:
        zip_fname (str): Path to the annotated kzip file.
        kd_p (str): Path to the underlying raw data stored as KnossosDataset.
        foreground_ids (Optional[Iterable[int]], optional): IDs to be
            converted to foreground. If None, the IDs are collected from the
            merge list inside the k.zip; if that fails, no ID is treated as
            foreground. Defaults to None.
        overwrite (bool, optional): If False, an existing h5 file is kept and
            the function returns early. Defaults to True.
        raw_data_offset (int, optional): Number of voxels for additional
            raw offset. Defaults to 75. Forced to 0 if `debug` is True.
        debug (bool, optional): If True, adds a '_debug' suffix to the file
            name and is forwarded to `create_h5_gt_file`. Defaults to False.
        mag (int, optional): Magnification level of the data. Defaults to 1.
        squeeze_data (bool, optional): If True, squeezes label and raw data.
            Defaults to True.
        target_labels (Optional[Iterable[int]], optional): Target labels
            for mapping foreground IDs. Requires `foreground_ids`.
            Defaults to None.
        apply_mops_seg (Optional[List[str]], optional): Morphological
            operations to apply to segmentation. Requires
            ``squeeze_data=True``. Defaults to None.

    Raises:
        ValueError: If `apply_mops_seg` is combined with
            ``squeeze_data=False`` or if `target_labels` is given without
            `foreground_ids`.
    """
    if not squeeze_data and apply_mops_seg is not None:
        raise ValueError('Data might have axis with length one if squeeze_data=False.')
    if target_labels is not None and foreground_ids is None:
        raise ValueError('`target_labels` is set, but `foreground_ids` is None.')
    fname, _ = os.path.splitext(zip_fname)
    # 'cube.k.zip' -> 'cube'
    if fname.endswith(".k"):
        fname = fname[:-2]
    if debug:
        file_appendix = '_debug'
        raw_data_offset = 0
    else:
        file_appendix = ''
    fname_dest = fname + file_appendix + ".h5"
    if os.path.isfile(fname_dest) and not overwrite:
        print("File at {} already exists. Skipping.".format(fname_dest))
        return
    raw, label = load_gt_from_kzip(zip_fname, kd_p, mag=mag,
                                   raw_data_offset=raw_data_offset)
    if squeeze_data:
        raw = raw.squeeze()
        label = label.squeeze()
    if foreground_ids is None:
        try:
            cc_dc = parse_cc_dict_from_kzip(zip_fname)
            foreground_ids = np.concatenate(list(cc_dc.values()))
        except Exception:  # mergelist.txt does not exist
            foreground_ids = []
        print("Foreground IDs not assigned. Inferring from "
              "'mergelist.txt' in k.zip.:", foreground_ids)
    create_h5_gt_file(fname_dest, raw, label, foreground_ids, debug=debug,
                      target_labels=target_labels, apply_mops_seg=apply_mops_seg)
def create_h5_gt_file(fname: str, raw: np.ndarray, label: np.ndarray,
                      foreground_ids: Optional[Iterable[int]] = None,
                      target_labels: Optional[Iterable[int]] = None,
                      debug: bool = False,
                      apply_mops_seg: Optional[List[str]] = None):
    """
    Writes raw and label arrays to an h5 file for ELEKTRONN input (keys
    ``raw`` and ``label``). Labels are mapped with `binarize_labels` and both
    arrays are reordered from XYZ to ZXY before they are stored.

    For true negative cubes pass ``foreground_ids=[]`` (no foreground). If
    `foreground_ids` is None, all non-zero values are treated as foreground.

    Args:
        fname (str): Path of the h5 file. ``.h5`` is appended if the name does
            not end with ``h5``.
        raw (np.ndarray): Raw data array.
        label (np.ndarray): Label data array.
        foreground_ids (Optional[Iterable[int]], optional): IDs to be treated as
            foreground. Defaults to None.
        target_labels (Optional[Iterable[int]], optional): If set, `foreground_ids`
            must also be set. Each ID in `foreground_ids` is mapped to the
            corresponding entry in `target_labels`. Defaults to None.
        debug (bool, optional): If True, raw and labels are multiplied by 255 and
            stored as uint8. Defaults to False.
        apply_mops_seg (Optional[List[str]], optional): Identifiers of morphological
            operations applied to the mapped labels. Defaults to None.

    Raises:
        ValueError: If `target_labels` is given without `foreground_ids`.
    """
    if foreground_ids is None and target_labels is not None:
        raise ValueError('`target_labels` is set, but `foreground_ids` is None.')
    print(os.path.basename(fname))
    print("Label (before):", label.shape, label.dtype, label.min(), label.max())
    mapped = binarize_labels(label, foreground_ids, target_labels=target_labels)
    label = xyz2zxy(mapped)
    raw = xyz2zxy(raw)
    if apply_mops_seg is not None:
        label = apply_morphological_operations(label, morph_ops=apply_mops_seg)
    label = label.astype(np.uint16)
    print("Raw:", raw.shape, raw.dtype, raw.min(), raw.max())
    print("Label (after mapping):", label.shape, label.dtype, label.min(), label.max())
    label_counts = Counter(label.flatten())
    print("-----------------\nGT Summary:\n%s\n" % str(label_counts.items()))
    if not fname.endswith("h5"):
        fname += ".h5"
    if debug:
        # stretch to the full uint8 range for visual inspection
        raw = (raw * 255).astype(np.uint8, copy=False)
        label = label.astype(np.uint8) * 255
    save_to_h5py([raw, label], fname, hdf5_names=["raw", "label"])
def binarize_labels(labels: np.ndarray, foreground_ids: Iterable[int],
                    target_labels: Optional[Iterable[int]] = None):
    """
    Transforms label array to binary label array (0=background, 1=foreground) or
    to the labels provided in `target_labels` by mapping the foreground IDs
    accordingly.

    Args:
        labels (np.ndarray): Input label array.
        foreground_ids (Iterable[int]): IDs to be considered as foreground. A
            single (non-iterable) ID is accepted as well. If None, every
            non-zero label becomes 1 and `target_labels` is ignored. An empty
            iterable yields an all-background result.
        target_labels (Optional[Iterable[int]], optional): Target label for
            each entry in `foreground_ids`. Defaults to 1 for every ID.

    Returns:
        np.ndarray: Transformed label array (uint16).
    """
    new_labels = np.zeros_like(labels)
    if foreground_ids is None:
        target_labels = [1]
        if len(np.unique(labels)) > 2:
            print("------------ WARNING -------------\n"
                  "Found more than two different labels during label "
                  "conversion\n"
                  "----------------------------------")
        new_labels[labels != 0] = 1
    else:
        # Materialize the IDs: a one-shot iterable would otherwise be
        # exhausted while building the default targets below.
        try:
            foreground_ids = list(foreground_ids)
        except TypeError:  # single ID
            foreground_ids = [foreground_ids]
        if target_labels is None:
            target_labels = [1] * len(foreground_ids)
        else:
            target_labels = list(target_labels)
        for ii, ix in enumerate(foreground_ids):
            new_labels[labels == ix] = target_labels[ii]
    # np.max fails on empty input, e.g. for `foreground_ids=[]`
    max_target = np.max(target_labels) if len(target_labels) > 0 else 0
    # sanity checks on the mapping result
    assert len(np.unique(new_labels)) <= len(target_labels) + 1
    if new_labels.size > 0:
        assert 0 <= np.max(new_labels) <= max_target
        assert 0 <= np.min(new_labels) <= max_target
    return new_labels.astype(np.uint16)
def parse_movement_area_from_zip(zip_fname: str) -> np.ndarray:
    """
    Parses the MovementArea (bounding box of labeled volume) from the
    ``annotation.xml`` file within a (k.)zip file.

    Args:
        zip_fname (str): Path to the zip file containing the annotation.xml file.

    Returns:
        np.ndarray: uint64 array with the minimum coordinates in the first row
        and the maximum coordinates in the second row (shape [2, 3] for the
        usual x, y, z attributes).
    """
    xml_text = read_txt_from_zip(zip_fname, "annotation.xml").decode()
    matches = re.findall("MovementArea (.*)/>", xml_text)
    # exactly one MovementArea element is expected
    assert len(matches) == 1
    attributes = matches[0]
    # Movement area is stored with 0-indexing! No adjustment needed
    corners = [
        np.array([re.findall(pattern, attributes)], dtype=np.uint64)
        for pattern in (r'min.\w="(\d+)"', r'max.\w="(\d+)"')
    ]
    return np.concatenate(corners)
def pred_dataset(*args, **kwargs):
    """
    Thin wrapper around `_pred_dataset` that logs a warning about the planned
    replacement by `predict_dense_to_kd`. All arguments are forwarded
    unchanged and the result of `_pred_dataset` is returned.
    """
    notice = ("'pred_dataset' will be replaced by 'predict_dense_to_kd' in"
              " the near future.")
    log_handler.warning(notice)
    return _pred_dataset(*args, **kwargs)
def _pred_dataset(kd_p, kd_pred_p, cd_p, model_p, imposed_patch_size=None,
                  mfp_active=False, gpu_id=0, overwrite=False, i=None, n=None):
    """
    Predicts a dataset or a subset of it using a specified model and saves the results.

    Args:
        kd_p (str): Path to the Knossos dataset configuration file.
        kd_pred_p (str): Path to the Knossos dataset head folder for the prediction.
        cd_p (str): Destination folder for the chunk dataset containing the prediction.
        model_p (str): Path to the ELEKTRONN2 model.
        imposed_patch_size (tuple, optional): Model patch size in Z, X, Y order.
            Defaults to None.
        mfp_active (bool, optional): Flag to activate max-fragment pooling.
            Defaults to False.
        gpu_id (int, optional): GPU index used for prediction. Defaults to 0.
        overwrite (bool, optional): If True, every chunk is predicted again.
            Otherwise only chunks whose "pred" data cannot be loaded are
            predicted. Defaults to False.
        i (int, optional): Index of the current processing unit. Together with
            `n`, selects every n-th chunk starting at `i`. Defaults to None.
        n (int, optional): Total number of processing units. Defaults to None.

    Returns:
        None. The prediction is stored in the chunk dataset; if `n` is None it
        is additionally exported to a KnossosDataset at `kd_pred_p`.
    """
    from elektronn2.utils.gpu import initgpu
    initgpu(gpu_id)
    from elektronn2.neuromancer.model import modelload
    kd = KnossosDataset()
    kd.initialize_from_knossos_path(kd_p, fixed_mag=1)
    if isinstance(imposed_patch_size, tuple):
        imposed_patch_size = list(imposed_patch_size)
    m = modelload(model_p, imposed_patch_size=imposed_patch_size,
                  override_mfp_to_active=mfp_active, imposed_batch_size=1)
    # disable dropout for inference
    m.dropout_rates = [0.0] * len(m.dropout_rates)
    offset = m.target_node.shape.offsets
    # model offsets are ZXY -> reorder to XYZ
    offset = np.array([offset[1], offset[2], offset[0]], dtype=np.int32)
    cd = ChunkDataset()
    cd.initialize(kd, kd.boundary, [512, 512, 256], cd_p, overlap=offset,
                  box_coords=np.zeros(3), fit_box_size=True)
    ch_dc = cd.chunk_dict
    print('Total number of chunks for GPU/GPUs:', len(ch_dc.keys()))
    # dict views cannot be sliced in Python 3 -> materialize first
    chunks = list(ch_dc.values())
    if i is not None and n is not None:
        chunks = chunks[i::n]
    print("Starting prediction of %d chunks in gpu %d\n" % (len(chunks), gpu_id))
    # NOTE: `chunk_pred` is not part of this block; presumably defined
    # elsewhere in this module.
    if not overwrite:
        for chunk in chunks:
            try:
                _ = chunk.load_chunk("pred")[0]
            except Exception:
                # no (readable) prediction available for this chunk yet
                chunk_pred(chunk, m)
    else:
        for chunk in chunks:
            try:
                chunk_pred(chunk, m)
            except KeyboardInterrupt as e:
                print("Exiting out from chunk prediction: ", str(e))
                return
    save_dataset(cd)
    # single gpu processing also exports the cset to kd
    # TODO: Use pyknossos conf like here to support bigger cube size:
    #     target_kd = knossosdataset.KnossosDataset()
    #     target_kd._cube_shape = cube_shape
    #     scale = np.array(global_params.config['scaling'])
    #     target_kd.scales = [scale, ]
    #     target_kd.initialize_without_conf(path, kd.boundary, scale, kd.experiment_name,
    #                                       mags=[1, ], create_pyk_conf=True,
    #                                       create_knossos_conf=False)
    #     target_kd = basics.kd_factory(path)  # test if init is possible
    if n is None:
        kd_pred = KnossosDataset()
        kd_pred.initialize_without_conf(kd_pred_p, kd.boundary, kd.scale,
                                        kd.experiment_name, mags=[1, 2, 4, 8])
        cd.export_cset_to_kd(kd_pred, "pred", ["pred"], [4, 4], as_raw=True,
                             stride=[256, 256, 256])
def predict_dense_to_kd(kd_path: str, target_path: str, model_path: str,
                        n_channel: int, target_names: Optional[Iterable[str]] = None,
                        target_channels: Optional[Iterable[Iterable[int]]] = None,
                        channel_thresholds: Optional[Iterable[Union[float, Any]]] = None,
                        log: Optional[Logger] = None, mag: int = 1,
                        overlap_shape_tiles: Tuple[int, int, int] = (40, 40, 20),
                        cube_of_interest: Optional[Tuple[np.ndarray]] = None,
                        overwrite: bool = False,
                        cube_shape_kd: Optional[Tuple[int]] = None):
    """
    Performs dense prediction on a Knossos dataset and writes the results to new
    KnossosDatasets.

    The predictions are written to new KnossosDatasets with names specified by
    `target_names` in the directory `target_path`. A temporary chunk dataset is
    created at ``<target_path>/cd_tmp/``. The actual inference is carried out by
    the batch job script "predict_dense", one job per chunk-ID batch.

    Notes:
        * The resulting KnossosDatasets do not use pyknossos configurations.
        * Chunk size, tile shape and tile overlap are currently hard-coded
          in this function; the `overlap_shape_tiles` argument is overridden
          (see TODO in the code).

    Args:
        kd_path: Path to the KnossosDataset configuration file of the raw data.
        target_path: Destination directory for the output KnossosDatasets
            containing predictions.
        model_path: Path to the model used for predictions.
        n_channel: Number of channels predicted by the model.
        target_names: Names of the target KnossosDatasets, e.g.
            `target_names=['synapse_fb', 'synapse_type']`. Defaults to
            `['pred']`. Length must match with `target_channels`.
        target_channels: Channel IDs in the prediction for each target
            KnossosDataset, e.g. `target_channels=[(1, 2)]` if the prediction
            has two foreground labels. Defaults to
            `[[ix for ix in range(n_channel)]]`. Length must match with
            `target_names`.
        channel_thresholds: Thresholds for channels, forwarded to the batch
            jobs. Defaults to None for every channel.
        log: Logger for output messages. A 'dense_predictions' logger is
            created in the working directory if None.
        mag: Magnification level of the data.
        overlap_shape_tiles: Overlap in voxels [XYZ] used for each tile
            predicted during inference. Currently ignored, see Notes.
        cube_of_interest: Bounding box of the volume of interest (min and max
            coordinate in voxels) in the respective magnification level (see
            kwarg `mag`). Defaults to the whole dataset.
        overwrite: Whether to overwrite existing KnossosDatasets.
        cube_shape_kd: Cube shape for storing sub-volumes in KnossosDataset on
            the file system. Defaults to (256, 256, 256).

    Raises:
        ValueError: If `target_names` and `target_channels` differ in length or
            if a target KnossosDataset exists and `overwrite` is False.
    """
    if log is None:
        log = initialize_logging('dense_predictions', global_params.config.working_dir + '/logs/', overwrite=False)
    if target_names is None:
        target_names = ['pred']
    if target_channels is None:
        target_channels = [list(range(n_channel))]
    if len(target_names) != len(target_channels):
        msg = 'For every target name the target channels have to be specified.'
        # report via the logger of this run, consistent with the checks below
        log.error(msg)
        raise ValueError(msg)
    if channel_thresholds is None:
        channel_thresholds = [None] * n_channel
    kd = basics.kd_factory(kd_path)
    if cube_of_interest is None:
        cube_of_interest = (np.zeros(3, ), kd.boundary // mag)
    if cube_shape_kd is None:
        cube_shape_kd = (256, 256, 256)
    # TODO: these should be config parameters
    # NOTE: this replaces the value passed in by the caller.
    overlap_shape_tiles = np.array([30, 31, 20])
    overlap_shape = overlap_shape_tiles
    chunk_size = np.array([482, 481, 236])
    # if qu.batchjob_enabled():
    #     chunk_size *= 2
    tile_shape = [271, 181, 138]
    cd = ChunkDataset()
    cd.initialize(kd, cube_of_interest[1], chunk_size, target_path + '/cd_tmp/',
                  box_coords=cube_of_interest[0], list_of_coords=[],
                  fit_box_size=True, overlap=overlap_shape)
    chunk_ids = list(cd.chunk_dict.keys())
    # init target KnossosDatasets
    target_kd_path_list = [target_path + '/{}/'.format(tn) for tn in target_names]
    for path in target_kd_path_list:
        if os.path.isdir(path):
            if not overwrite:
                msg = f'Found existing KD at "{path}" but overwrite is set to False.'
                log.error(msg)
                raise ValueError(msg)
            log.debug('Found existing KD at {}. Removing it now.'.format(path))
            shutil.rmtree(path)
    for path in target_kd_path_list:
        target_kd = knossosdataset.KnossosDataset()
        target_kd._cube_shape = cube_shape_kd
        scale = np.array(global_params.config['scaling'])
        target_kd.scales = [scale, ]
        # TODO: use pyk conf!
        target_kd.initialize_without_conf(path, kd.boundary, kd.scale,
                                          kd.experiment_name, [2 ** x for x in range(6)],
                                          create_pyk_conf=False, create_knossos_conf=True)
        try:  # make sure init works
            basics.kd_factory(path)
        except ValueError as e:
            log.error(f'Could not initialize KnossosDataset at "{path}". {e}')
    # init batchjob parameters: one parameter tuple per batch of chunk IDs
    multi_params = chunkify(chunk_ids, global_params.config.ngpu_total)
    multi_params = [(ch_ids, kd_path, target_path, model_path, overlap_shape,
                     overlap_shape_tiles, tile_shape, chunk_size, n_channel, target_channels,
                     target_kd_path_list, channel_thresholds, mag, cube_of_interest)
                    for ch_ids in multi_params]
    log.info('Started dense prediction of {} in {:d} chunk(s).'.format(", ".join(target_names), len(chunk_ids)))
    if qu.batchjob_enabled():
        n_cores_per_job = global_params.config['ncores_per_node'] // global_params.config['ngpus_per_node']
    else:
        n_cores_per_job = global_params.config['ncores_per_node']
    qu.batchjob_script(multi_params, "predict_dense", n_cores=n_cores_per_job, suffix='_' + '_'.join(target_names),
                       remove_jobfolder=True, log=log, additional_flags="--gres=gpu:1")
    log.info('Finished dense prediction of {}'.format(", ".join(target_names)))
def dense_predictor(args):
    """
    Predicts a batch of chunks and writes the results to the target
    KnossosDatasets.

    Args:
        args: Tuple with the following entries (in this order):
            chunk_ids: List of chunk IDs in the chunk dataset.
            kd_p: Path to the Knossos dataset configuration file (raw data).
            target_p: Folder which contains the temporary chunk dataset
                ``cd_tmp``.
            model_p: Path to the model used for predictions.
            overlap_shape: Overlap of the chunks (XYZ).
            overlap_shape_tiles: Overlap of the tiles used by the predictor
                (XYZ).
            tile_shape: Initial tile shape (XYZ). Halved along one axis at a
                time whenever initializing/probing the predictor raises a
                RuntimeError (e.g. CUDA out of memory).
            chunk_size: Chunk size (XYZ).
            n_channel: Number of channels predicted by the model.
            target_channels: Channel IDs for every target KnossosDataset.
            target_kd_path_list: Paths to the target KnossosDatasets, same
                order as `target_channels`.
            channel_thresholds: Threshold per channel. None means 255 / 2;
                values below 1 are scaled by 255. Only used for targets with
                more than one channel.
            mag: Magnification level of the data.
            cube_of_interest: Bounding box (min, max) of the volume of
                interest.

    Returns:
        None. The function is used for its side effects of writing predictions
        to disk. Targets with a single channel are stored as raw data (uint8
        probability map), all others as segmentation.

    Raises:
        ValueError: If the tile shape cannot be reduced any further.
    """
    # TODO: remove chunk necessity
    # TODO: clean up (e.g. redundant chunk sizes, ...)
    chunk_ids, kd_p, target_p, model_p, overlap_shape, overlap_shape_tiles, tile_shape, chunk_size, n_channel, \
        target_channels, target_kd_path_list, channel_thresholds, mag, cube_of_interest = args
    # init KnossosDataset:
    kd = KnossosDataset()
    kd.initialize_from_knossos_path(kd_p)
    # init ChunkDataset:
    cd = ChunkDataset()
    cd.initialize(kd, cube_of_interest[1], chunk_size, target_p + '/cd_tmp/',
                  box_coords=cube_of_interest[0], list_of_coords=[],
                  fit_box_size=True, overlap=overlap_shape)
    # init Target KnossosDataset
    target_kd_dict = {}
    for path in target_kd_path_list:
        target_kd_dict[path] = basics.kd_factory(path)
    # init Predictor
    from elektronn3.inference import Predictor
    ix = 0
    tile_shape = np.array(tile_shape)
    while True:
        try:
            out_shape = (chunk_size + 2 * np.array(overlap_shape)).astype(np.int32)[::-1]  # ZYX
            out_shape = np.insert(out_shape, 0, n_channel)  # output must equal chunk size
            predictor = Predictor(model_p, strict_shapes=True, tile_shape=tile_shape[::-1],
                                  out_shape=out_shape, overlap_shape=overlap_shape_tiles[::-1],
                                  apply_softmax=True)
            predictor.model.ae = False
            # probe with an empty volume to trigger potential memory errors
            _ = predictor.predict(np.zeros(out_shape[1:])[None, None])
            break
        except RuntimeError:  # cuda MemoryError
            if np.all(tile_shape % 2):
                raise ValueError('Cannot reduce tile shape anymore. Please adapt '
                                 'the tile/overlap/chunk shape in the function '
                                 'that is calling `dense_predictor`.')
            # Find the next axis with an even extent. Wrap around instead of
            # running past the last axis; at least one axis is even here.
            while tile_shape[ix] % 2:
                ix = (ix + 1) % 3
            tile_sh_orig = np.array(tile_shape)
            tile_shape[ix] = tile_shape[ix] // 2
            log_main.warning(f'Changed tile shape from {tile_sh_orig} to '
                             f'{tile_shape} to reduce memory requirements.')
            ix = (ix + 1) % 3  # permute spatial dimension which is reduced
    # predict Chunks
    for ch_id in chunk_ids:
        ch = cd.chunk_dict[ch_id]
        ol = ch.overlap
        size = np.array(np.array(ch.size) + 2 * np.array(ol),
                        dtype=np.int32)
        coords = np.array(np.array(ch.coordinates) - np.array(ol),
                          dtype=np.int32)
        raw = kd.load_raw(size=size * mag, offset=coords * mag, mag=mag)
        pred = dense_predicton_helper(raw.astype(np.float32) / 255., predictor,
                                      is_zyx=True, return_zyx=True)
        # Slice out the original input volume along ZYX, i.e. the last three
        # axes. Explicit end indices keep the full extent for an overlap of
        # 0, where `a:-0` would return an empty array.
        z_len, y_len, x_len = pred.shape[-3:]
        pred = pred[..., ol[2]:z_len - ol[2], ol[1]:y_len - ol[1], ol[0]:x_len - ol[0]]
        for j, ids in enumerate(target_channels):
            path = target_kd_path_list[j]
            data = np.zeros_like(pred[0]).astype(np.uint64)
            save_as_raw = not (len(ids) > 1)
            for label in ids:
                t = channel_thresholds[label]
                # if multiple target labels per dataset store classification
                # results
                # TODO: argmax might be more reasonable
                if not save_as_raw:
                    if t is None:
                        t = 255 / 2
                    if t < 1.:
                        # fractional threshold -> uint8 range of `pred`
                        t = 255 * t
                    pred_mask = pred[label] > t
                    data[pred_mask] = label
                else:
                    # only one label in the target KnossosDataset
                    # -> store probability map.
                    data = pred[label]
            if save_as_raw:
                target_kd_dict[path].save_raw(
                    offset=ch.coordinates * mag, data=data.astype(np.uint8),
                    data_mag=mag, mags=[mag, mag * 2, mag * 4],
                    fast_resampling=True, upsample=False)
            else:
                target_kd_dict[path].save_seg(
                    offset=ch.coordinates * mag, data=data, data_mag=mag,
                    mags=[mag, mag * 2, mag * 4],
                    fast_resampling=True, upsample=False)
def dense_predicton_helper(raw: np.ndarray, predictor: 'Predictor', is_zyx=False,
                           return_zyx=False) -> np.ndarray:
    """
    Runs the predictor on a single volume and returns the result as uint8.

    Args:
        raw: Input volume; XYZ order unless `is_zyx` is True. Batch and channel
            axes are added before calling the predictor.
        predictor: The model which performs the inference. Requires
            `predictor.predict`, whose result must provide `.numpy()`.
        is_zyx: If True, `raw` is already in ZYX order and is not reordered.
        return_zyx: If True, the prediction is returned in ZYX order, otherwise
            its last three axes are reversed to XYZ.

    Returns:
        The inference result (batch axis removed) multiplied by 255 and cast to
        uint8.
    """
    volume = raw if is_zyx else xyz2zyx(raw)
    # predict: pred of the form (N, C, [D,], H, W)
    batch_pred = predictor.predict(volume[None, None]).numpy()
    # drop the N-axis and scale to the uint8 range
    scaled = (np.array(batch_pred[0]) * 255).astype(np.uint8)
    return scaled if return_zyx else zyx2xyz(scaled)
def to_knossos_dataset(kd_p, kd_pred_p, cd_p, model_p,
                       imposed_patch_size, mfp_active=False):
    """
    Exports the "pred" data of a chunk dataset to a KnossosDataset.

    Deprecated; a warning pointing to `predict_dense_to_kd` is logged. The
    model is only loaded to derive the chunk overlap from its target offsets.

    Args:
        kd_p: The path to the Knossos dataset configuration file.
        kd_pred_p: The path to the Knossos dataset directory which will contain the
            prediction.
        cd_p: The folder of the chunk dataset containing the prediction.
        model_p: The path to the ELEKTRONN2 model.
        imposed_patch_size: The patch size of the model.
        mfp_active: A flag indicating whether max-fragment pooling is active.

    Returns:
        None. The function is used for its side effects of writing predictions to
        disk.
    """
    from elektronn2.neuromancer.model import modelload
    log_reps.warning('Depracation Warning; "to_knossos_dataset" is deprecated and will be '
                     'replaced by "predict_dense_to_kd" which immediately .')
    kd = KnossosDataset()
    kd.initialize_from_knossos_path(kd_p, fixed_mag=1)
    patch_size = imposed_patch_size
    if isinstance(patch_size, tuple):
        patch_size = list(patch_size)
    m = modelload(model_p, imposed_patch_size=patch_size,
                  override_mfp_to_active=mfp_active, imposed_batch_size=1)
    # disable dropout
    m.dropout_rates = [0.0] * len(m.dropout_rates)
    zxy_offsets = m.target_node.shape.offsets
    # reorder the model offsets from ZXY to XYZ
    offset = np.array([zxy_offsets[1], zxy_offsets[2], zxy_offsets[0]],
                      dtype=np.int32)
    cd = ChunkDataset()
    cd.initialize(kd, kd.boundary, [512, 512, 256], cd_p, overlap=offset,
                  box_coords=np.zeros(3), fit_box_size=True)
    # TODO: Use pyknossos conf like here to support bigger cube size:
    #     target_kd = knossosdataset.KnossosDataset()
    #     target_kd._cube_shape = cube_shape
    #     scale = np.array(global_params.config['scaling'])
    #     target_kd.scales = [scale, ]
    #     target_kd.initialize_without_conf(path, kd.boundary, scale, kd.experiment_name,
    #                                       mags=[1, ], create_pyk_conf=True,
    #                                       create_knossos_conf=False)
    #     target_kd = basics.kd_factory(path)  # test if init is possible
    kd_pred = KnossosDataset()
    kd_pred.initialize_without_conf(kd_pred_p, kd.boundary, kd.scale,
                                    kd.experiment_name, mags=[1, 2, 4, 8])
    cd.export_cset_to_kd(kd_pred, "pred", ["pred"], [4, 4], as_raw=True,
                         stride=[256, 256, 256])
def prediction_helper(raw, model, override_mfp=True,
                      imposed_patch_size=None):
    """
    Predict a raw volume with an ELEKTRONN2 model.

    The volume is converted with ``xyz2zxy``, normalized to ``float32``,
    passed to ``predict_dense`` and the result is converted back with
    ``zxy2xyz``.

    Args:
        raw: np.array
            Raw volume in XYZ order. Integer arrays are divided by 255, other
            dtypes are assumed to be normalized to [0, 1] already.
        model: str or model object
            Path to the model (loaded via ``modelload``; ``str`` subclasses
            are accepted as well) or an already loaded model object providing
            ``predict_dense``.
        override_mfp: bool
            Passed to ``modelload`` as ``override_mfp_to_active``. Only used
            if `model` is a path.
        imposed_patch_size: tuple
            Patch size imposed on the model (the original documentation asks
            for ZXY order). A tuple is converted to a list. Only used if
            `model` is a path.

    Returns: np.array
        Second element of the ``predict_dense`` output, converted with
        ``zxy2xyz``.

    Raises:
        AssertionError: If the normalized raw data is not within [0, 1].
    """
    if isinstance(model, str):
        from elektronn2.neuromancer.model import modelload
        if isinstance(imposed_patch_size, tuple):
            patch_size = list(imposed_patch_size)
        else:
            patch_size = imposed_patch_size
        m = modelload(model, imposed_patch_size=patch_size,
                      override_mfp_to_active=override_mfp, imposed_batch_size=1)
        # set every dropout rate of the freshly loaded model to zero
        m.dropout_rates = [0.0] * len(m.dropout_rates)
    else:
        m = model
    raw = xyz2zxy(raw)
    if raw.dtype.kind in ('u', 'i'):
        # integer data: convert to float32 and scale it
        raw = raw.astype(np.float32) / 255.
    if not raw.dtype == np.float32:
        # assume already normalized between 0 and 1
        raw = raw.astype(np.float32)
    assert 0 <= np.min(raw) and np.max(raw) <= 1.0, \
        "Normalized raw data has to be within [0, 1]."
    # a leading axis is added to the input
    pred = m.predict_dense(raw[None, ], pad_raw=True)[1]
    return zxy2xyz(pred)
def chunk_pred(ch: 'chunky.Chunk', model: 'torch.nn.Module', debug: bool = False):
    """
    Predict the raw data of a single chunk and store the result in the chunk.

    The output of ``prediction_helper`` is multiplied by 255, cast to
    ``uint8`` and saved under the name "pred" (set name "pred"), overwriting
    existing data.

    Args:
        ch: Chunk providing ``raw_data`` and ``save_chunk``.
        model: Passed unchanged to ``prediction_helper`` (model object or
            path to the model).
        debug: If True, the raw data is additionally saved under the name
            "pred" with set name "raw" (without overwriting).

    Returns:
        None. Data is written to disk as a side effect.
    """
    raw_volume = ch.raw_data()
    scaled = prediction_helper(raw_volume, model) * 255
    ch.save_chunk(scaled.astype(np.uint8), "pred", "pred", overwrite=True)
    if not debug:
        return
    ch.save_chunk(raw_volume, "pred", "raw", overwrite=False)
def get_glia_model_e3():
    """
    Retrieves the elektronn3 model for glia prediction.

    The model path ``global_params.config.mpath_glia_e3`` is handed directly
    to ``InferenceModel`` together with `naive_view_normalization_new` as
    normalization function.

    Returns:
        The InferenceModel for glia prediction.

    Raises:
        ImportError: If elektronn3 cannot be imported. The error is logged
            and re-raised with a pointer to the elektronn3 repository, in
            line with the other model getters of this module.
    """
    try:
        from elektronn3.models.base import InferenceModel
    except ImportError as e:
        msg = ("elektronn3 could not be imported ({}). Please see 'https://github."
               "com/ELEKTRONN/elektronn3' for more information.".format(e))
        log_main.error(msg)
        raise ImportError(msg)
    return InferenceModel(global_params.config.mpath_glia_e3,
                          normalize_func=naive_view_normalization_new)
def get_celltype_model_e3():
    """
    Retrieves the elektronn3 model for cell type prediction.

    The TorchScript file at ``global_params.config.mpath_celltype_e3`` is
    loaded and wrapped in an ``InferenceModel`` with ``bs=40`` and
    ``multi_gpu=True``. No ``normalize_func`` is passed here; according to
    the original documentation, view normalization is applied downstream
    (``predict_sso_celltype``) because the additional scalar inputs must not
    be normalized.

    Returns:
        The InferenceModel for cell type prediction.

    Raises:
        ImportError: If elektronn3 cannot be imported (logged beforehand).
    """
    try:
        from elektronn3.models.base import InferenceModel
    except ImportError as e:
        msg = ("elektronn3 could not be imported ({}). Please see 'https://github."
               "com/ELEKTRONN/elektronn3' for more information.".format(e))
        log_main.error(msg)
        raise ImportError(msg)
    scripted = torch.jit.load(global_params.config.mpath_celltype_e3)
    return InferenceModel(scripted, bs=40, multi_gpu=True)
def get_semseg_spiness_model():
    """
    Retrieves the elektronn3 model for semantic segmentation of spines.

    The TorchScript file at ``global_params.config.mpath_spiness`` is loaded
    and wrapped in an ``InferenceModel``; the source path is stored on the
    wrapper as ``_path``.

    Returns:
        The InferenceModel for semantic segmentation of spines.

    Raises:
        ImportError: If elektronn3 cannot be imported (logged beforehand).
    """
    try:
        from elektronn3.models.base import InferenceModel
    except ImportError as e:
        msg = ("elektronn3 could not be imported ({}). Please see 'https://github."
               "com/ELEKTRONN/elektronn3' for more information.".format(e))
        log_main.error(msg)
        raise ImportError(msg)
    model_path = global_params.config.mpath_spiness
    inference_model = InferenceModel(torch.jit.load(model_path))
    # remember where the model was loaded from
    inference_model._path = model_path
    return inference_model
def get_semseg_axon_model():
    """
    Retrieves the elektronn3 model for semantic segmentation of axons.

    The TorchScript file at ``global_params.config.mpath_axonsem`` is loaded
    and wrapped in an ``InferenceModel``; the source path is stored on the
    wrapper as ``_path``.

    Returns:
        The InferenceModel for semantic segmentation of axons.

    Raises:
        ImportError: If elektronn3 cannot be imported (logged beforehand).
    """
    try:
        from elektronn3.models.base import InferenceModel
    except ImportError as e:
        msg = ("elektronn3 could not be imported ({}). Please see 'https://github."
               "com/ELEKTRONN/elektronn3' for more information.".format(e))
        log_main.error(msg)
        raise ImportError(msg)
    model_path = global_params.config.mpath_axonsem
    inference_model = InferenceModel(torch.jit.load(model_path))
    # remember where the model was loaded from
    inference_model._path = model_path
    return inference_model
def get_tripletnet_model_e3():
    """
    Retrieves the elektronn3 triplet network model.

    The TorchScript file at ``global_params.config.mpath_tnet`` is loaded and
    wrapped in an ``InferenceModel`` with default arguments. According to the
    original documentation the model was trained with
    `naive_view_normalization_new`.

    Returns:
        The InferenceModel of the triplet network.

    Raises:
        ImportError: If elektronn3 cannot be imported (logged beforehand).
    """
    try:
        from elektronn3.models.base import InferenceModel
    except ImportError as e:
        msg = ("elektronn3 could not be imported ({}). Please see 'https://github."
               "com/ELEKTRONN/elektronn3' for more information.".format(e))
        log_main.error(msg)
        raise ImportError(msg)
    scripted = torch.jit.load(global_params.config.mpath_tnet)
    return InferenceModel(scripted)
def get_myelin_cnn():
    """
    Retrieves the elektronn3 model for myelin prediction.

    The TorchScript file at ``global_params.config.mpath_myelin`` is loaded
    and wrapped in an elektronn3 ``Predictor`` with default arguments.

    Returns:
        The Predictor instance for myelin prediction.

    Raises:
        ImportError: If elektronn3 cannot be imported (logged beforehand).
    """
    try:
        from elektronn3.inference.inference import Predictor
    except ImportError as e:
        msg = ("elektronn3 could not be imported ({}). Please see 'https://github."
               "com/ELEKTRONN/elektronn3' for more information.".format(e))
        log_main.error(msg)
        raise ImportError(msg)
    scripted = torch.jit.load(global_params.config.mpath_myelin)
    return Predictor(scripted)
def get_knn_tnet_embedding_e3():
    """
    Builds the k-nearest-neighbors classifier for the triplet network
    embeddings.

    The embeddings are read from the ``pred/`` sub-folder of
    ``global_params.config.mpath_tnet`` by `knn_clf_tnet_embedding`.

    Returns:
        The fitted KNeighborsClassifier instance.
    """
    return knn_clf_tnet_embedding(
        "{}/pred/".format(global_params.config.mpath_tnet))
def get_pca_tnet_embedding_e3():
    """
    Builds the PCA for the triplet network embeddings.

    The embeddings are read from the ``pred/`` sub-folder of
    ``global_params.config.mpath_tnet`` by `pca_tnet_embedding`.

    Returns:
        The fitted PCA instance.
    """
    return pca_tnet_embedding(
        "{}/pred/".format(global_params.config.mpath_tnet))
def naive_view_normalization(d):
    """
    Pseudo-normalization of views, kept for backward compatibility.

    The input is copied to ``float32``. If all values lie within [0, 1], 0.5
    is subtracted from the whole array. Otherwise every entry along the first
    axis is handled on its own: entries whose maximum lies within [0, 1] are
    shifted by -0.5, all others are divided by 255 before the shift.

    Args:
        d: The input data to be normalized.

    Returns:
        The normalized data as ``float32`` array; the input is not modified.
    """
    # TODO: Remove with new dataset, only necessary for backwards compat.
    # (proper normalization: how to store mean and std for inference?)
    views = d.astype(np.float32)
    if np.min(views) >= 0 and np.max(views) <= 1.0:
        return views - 0.5
    for idx, view in enumerate(views):
        if 0 <= np.max(view) <= 1.0:
            shifted = view - 0.5
        else:
            shifted = view / 255. - 0.5
        views[idx] = shifted
    return views
def naive_view_normalization_new(d):
    """
    Maps the input to ``float32``, divides it by 255 and subtracts 0.5.

    Args:
        d: The input data to be normalized.

    Returns:
        The normalized data, i.e. ``d / 255 - 0.5`` computed on a
        ``float32`` copy of the input.
    """
    as_float = d.astype(np.float32)
    rescaled = as_float / 255.
    return rescaled - 0.5
def knn_clf_tnet_embedding(fold, fit_all=False):
    """
    Fits a k-nearest-neighbors classifier on triplet network embeddings.

    Label files are looked up in `fold` by the name patterns
    ``l_axoness_train*.npy`` / ``l_axoness_valid*.npy``; the embedding that
    belongs to a label file is stored under the same path with
    ``l_axoness_`` replaced by ``ls_axoness_``. Validation data is only read
    if `fit_all` is True.

    Args:
        fold (str): The directory containing the embeddings and labels.
        fit_all (bool): If True, fit on training and validation data,
            otherwise on the training data only.

    Returns:
        The fitted KNeighborsClassifier instance (5 neighbors, uniform
        weights).

    Raises:
        ValueError: Raised by ``np.concatenate`` if no file of a required
            split is found.
    """
    def _load_split(split):
        # returns (embeddings as float32, labels as uint16) of one split
        label_key = "l_axoness_{}".format(split)
        embedding_key = "ls_axoness_{}".format(split)
        label_paths = get_filepaths_from_dir(
            fold, fname_includes=[label_key], ending=".npy")
        labels = []
        embeddings = []
        for label_path in label_paths:
            labels.append(np.load(label_path))
            embeddings.append(
                np.load(label_path.replace(label_key, embedding_key)))
        return (np.concatenate(embeddings).astype(dtype=np.float32),
                np.concatenate(labels).astype(dtype=np.uint16))

    data, labels = _load_split("train")
    if fit_all:
        valid_data, valid_labels = _load_split("valid")
        data = np.concatenate([data, valid_data])
        labels = np.concatenate([labels, valid_labels])
    nbrs = KNeighborsClassifier(n_neighbors=5, algorithm='auto',
                                n_jobs=16, weights='uniform')
    nbrs.fit(data, labels.ravel())
    return nbrs
def pca_tnet_embedding(fold, n_components=3, fit_all=False):
    """
    Fits a PCA on triplet network embeddings.

    Embedding files are located via their label files: label files are
    looked up in `fold` by the name patterns ``l_axoness_train*.npy`` /
    ``l_axoness_valid*.npy`` and the embedding is stored under the same path
    with ``l_axoness_`` replaced by ``ls_axoness_``. The labels themselves
    are not needed for the PCA and are therefore not loaded. Validation data
    is only read if `fit_all` is True.

    Args:
        fold: The directory containing the embeddings and labels.
        n_components: The number of principal components to keep.
        fit_all: If True, fit on training and validation data, otherwise on
            the training data only.

    Returns:
        The fitted PCA instance (``whiten=True``, ``random_state=0``).

    Raises:
        ValueError: Raised by ``np.concatenate`` if no file of a required
            split is found.
    """
    def _load_embeddings(split):
        # returns the concatenated embeddings of one split as float32
        label_key = "l_axoness_{}".format(split)
        embedding_key = "ls_axoness_{}".format(split)
        label_paths = get_filepaths_from_dir(
            fold, fname_includes=[label_key], ending=".npy")
        embeddings = [np.load(p.replace(label_key, embedding_key))
                      for p in label_paths]
        return np.concatenate(embeddings).astype(dtype=np.float32)

    data = _load_embeddings("train")
    if fit_all:
        data = np.concatenate([data, _load_embeddings("valid")])
    pca = PCA(n_components, whiten=True, random_state=0)
    pca.fit(data)
    return pca
def certainty_estimate(inp: np.ndarray, is_logit: bool = False) -> float:
    """
    Estimates the certainty of (independent) predictions of the same sample:

        1. If `is_logit` is True, apply softmax on the input.
        2. Average the evidence per class across all samples (``entropy``
           normalizes the result to sum to one).
        3. Compute the entropy, scale it with the maximum entropy (equal
           probabilities) and subtract it from 1 to get the certainty.

    Args:
        inp: 2D array (or array-like) of prediction results (N: number of
            samples, C: number of classes).
        is_logit: If True, applies ``softmax(inp, axis=1)``.

    Returns:
        Certainty measure based on the entropy of a set of (independent)
        predictions: 1 for a one-hot class distribution, 0 for a uniform one.
        With a single class (C=1) the maximum entropy is zero and the
        normalization is undefined; the prediction is then reported as fully
        certain (1.0) instead of NaN.

    Raises:
        ValueError: If the input is not two dimensional.
    """
    inp = np.asarray(inp)
    if inp.ndim != 2:
        raise ValueError('Input is not two dimensional.')
    proba = softmax(inp, axis=1) if is_logit else inp
    # average probabilities across samples
    proba = np.mean(proba, axis=0)
    n_classes = len(proba)
    if n_classes == 1:
        # ln(1) == 0 would lead to 0 / 0 below
        return np.float64(1.0)
    # maximum entropy at equal probabilities: -sum(1/N*ln(1/N)) = ln(N)
    entr_max = np.log(n_classes)
    entr_norm = entropy(proba) / entr_max
    # convert to certainty estimate
    return 1 - entr_norm
def str2int_converter(comment: str, gt_type: str) -> int:
    """
    Translates a semantic string label into its integer label.

    Args:
        comment: The semantic string label to be converted.
        gt_type: The ground truth type, one of "axgt", "spgt", "ctgt_j0251"
            or "ctgt_j0251_v2".

    Returns:
        The integer label. For "axgt" (exact match) and "spgt" (substring
        match, checked in the order head, neck, shaft, other) -1 is returned
        if the comment is not recognized.

    Raises:
        KeyError: If `comment` is unknown for one of the "ctgt_j0251*" types.
        ValueError: If `gt_type` is not supported.
    """
    if gt_type == "axgt":
        exact_labels = {"gt_dendrite": 0, "gt_axon": 1, "gt_soma": 2,
                        "gt_bouton": 3, "gt_terminal": 4}
        return exact_labels.get(comment, -1)

    if gt_type == "spgt":
        # the first keyword contained in the comment wins
        keyword_labels = (("head", 1), ("neck", 0), ("shaft", 2), ("other", 3))
        for keyword, value in keyword_labels:
            if keyword in comment:
                return value
        return -1

    if gt_type not in ('ctgt_j0251', 'ctgt_j0251_v2'):
        raise ValueError("Given groundtruth type is not valid.")

    # list position == integer label; v2 only adds "NGF" at the end
    celltypes = ["STN", "DA", "MSN", "LMAN", "HVC", "TAN", "GPe", "GPi",
                 "FS", "LTS"]
    if gt_type == 'ctgt_j0251_v2':
        celltypes.append("NGF")
    return {name: idx for idx, name in enumerate(celltypes)}[comment]
# Counterpart of `str2int_converter`: converts an integer label back into
# its semantic string label.
def int2str_converter(label: int, gt_type: str) -> str:
    """
    Converts an integer label into a semantic string based on the specified
    ground truth type.

    Args:
        label: The integer label to be converted. Strings (including ``str``
            subclasses such as ``numpy.str_``) are converted with ``int``
            first.
        gt_type: The ground truth type which determines the mapping: 'axgt'
            (cell compartments), 'spgt' (spines), 'ctgt', 'ctgt_v2',
            'ctgt_j0251' and 'ctgt_j0251_v2' (cell types).

    Returns:
        The semantic string label. For 'axgt', 'spgt', 'ctgt' and 'ctgt_v2'
        the integer -1 is returned if the label is unknown ('ctgt_v2'
        additionally prints a message).

    Raises:
        KeyError: If the label is unknown for one of the 'ctgt_j0251*' types.
        ValueError: If the provided ground truth type is not recognized.

    TODO:
        - Remove redundant definitions.
        - Check if the return value for unknown labels (-1) is handled
          appropriately elsewhere in the code, otherwise consider returning
          "N/A".
    """
    if isinstance(label, str):
        label = int(label)

    if gt_type == "axgt":
        names = {0: "gt_dendrite", 1: "gt_axon", 2: "gt_soma",
                 3: "gt_bouton", 4: "gt_terminal"}
        return names.get(label, -1)  # TODO: Check if somewhere -1 is handled, otherwise return "N/A"
    if gt_type == "spgt":
        names = {0: "neck", 1: "head", 2: "shaft", 3: "other"}
        return names.get(label, -1)  # TODO: Check if somewhere -1 is already used, otherwise return "N/A"
    if gt_type == 'ctgt':
        names = {0: "EA", 1: "MSN", 2: "GP", 3: "INT"}
        return names.get(label, -1)  # TODO: Check if somewhere -1 is already used, otherwise return "N/A"
    if gt_type == 'ctgt_v2':
        # DA and TAN are type modulatory, if this is changed, also change `certainty_celltype`
        l_dc_inv = dict(STN=0, modulatory=1, MSN=2, LMAN=3, HVC=4, GP=5, INT=6)
        l_dc = {v: k for k, v in l_dc_inv.items()}
        try:
            return l_dc[label]
        except KeyError:
            print('Unknown label "{}"'.format(label))
            return -1
    if gt_type == 'ctgt_j0251':
        str2int_label = dict(STN=0, DA=1, MSN=2, LMAN=3, HVC=4, TAN=5, GPe=6, GPi=7,
                             FS=8, LTS=9)
        # unknown labels raise a KeyError here
        return {v: k for k, v in str2int_label.items()}[label]
    if gt_type == 'ctgt_j0251_v2':
        str2int_label = dict(STN=0, DA=1, MSN=2, LMAN=3, HVC=4, TAN=5, GPe=6, GPi=7,
                             FS=8, LTS=9, NGF=10)
        # unknown labels raise a KeyError here
        return {v: k for k, v in str2int_label.items()}[label]
    raise ValueError("Given ground truth type is not valid.")