# -*- coding: utf-8 -*-
# SyConn - Synaptic connectivity inference toolkit
#
# Copyright (c) 2016 - now
# Max Planck Institute of Neurobiology, Martinsried, Germany
# Authors: Philipp Schubert, Joergen Kornfeld
import collections.abc
import contextlib
import gc
import glob
import os
import pickle as pkl
import re
import shutil
import signal
import tempfile
import zipfile
from collections import defaultdict
from typing import List, Union
import networkx as nx
import numpy as np
import tqdm
from knossos_utils import KnossosDataset
from knossos_utils.skeleton import SkeletonAnnotation, SkeletonNode
from plyfile import PlyData
from . import log_handler
from .. import global_params


def kd_factory(kd_path: str, channel: str = 'jpg'):
    """
    Initializes a KnossosDataset at the given `kd_path`.

    Notes:
        * Prioritizes pyk.conf files.

    Todo:
        * Requires additional adjustment of the data type, i.e. setting the
          channel explicitly currently leads to uint32 <-> uint64 issues in
          the CS segmentation.

    Args:
        kd_path: Path to the KnossosDataset.
        channel: Channel to use. Currently not used.

    Returns:
        The initialized KnossosDataset.
    """
    kd = KnossosDataset()
    # TODO: set appropriate channel
    # kd.set_channel(channel)
    pyk_confs = glob.glob(f'{kd_path}/*.pyk.conf')
    if os.path.isfile(kd_path):
        kd.initialize_from_conf(kd_path)
    elif len(pyk_confs) == 1:
        kd.initialize_from_pyknossos_path(pyk_confs[0])
elif os.path.isfile(kd_path + "/mag1/knossos.conf"):
# Initializes the dataset by parsing the knossos.conf in path + "mag1"
kd_path += "/mag1/knossos.conf"
kd.initialize_from_knossos_path(kd_path)
else:
raise ValueError(f'Could not find KnossosDataset config at {kd_path}.')
return kd
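
# Usage sketch (hypothetical path). ``kd_factory`` resolves, in this order: a
# direct config file path, a single *.pyk.conf inside the directory, or
# mag1/knossos.conf below it:
# >>> kd = kd_factory('/mnt/example_dataset')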


def switch_array_entries(this_array, entries):
    """Swaps the two entries of `this_array` given by `entries` in place and
    returns the array."""
entry_0 = this_array[entries[0]]
this_array[entries[0]] = this_array[entries[1]]
this_array[entries[1]] = entry_0
return this_array
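
# Example: the swap happens in place and the array is also returned:
# >>> a = np.array([10, 20, 30])
# >>> switch_array_entries(a, (0, 2))
# array([30, 20, 10])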


def crop_bool_array(arr):
    """
    Crops a bool array to its True region.

    Args:
        arr: 3d bool array
            array to crop

    Returns: 3d bool array, list
        cropped array, offset
    """
in_mask_indices = [np.flatnonzero(arr.sum(axis=(1, 2))),
np.flatnonzero(arr.sum(axis=(0, 2))),
np.flatnonzero(arr.sum(axis=(0, 1)))]
return arr[in_mask_indices[0].min(): in_mask_indices[0].max() + 1,
in_mask_indices[1].min(): in_mask_indices[1].max() + 1,
in_mask_indices[2].min(): in_mask_indices[2].max() + 1],\
[in_mask_indices[0].min(),
in_mask_indices[1].min(),
in_mask_indices[2].min()]
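
# Example: a single True voxel at (1, 2, 3) yields a 1x1x1 crop with offset
# [1, 2, 3]:
# >>> arr = np.zeros((4, 4, 4), dtype=bool)
# >>> arr[1, 2, 3] = True
# >>> cropped, offset = crop_bool_array(arr)  # cropped.shape == (1, 1, 1)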


def group_ids_to_so_storage(ids, params, significant_digits=5):
    """Groups `ids` and the corresponding entries of each list in `params` by
    the last `significant_digits` digits of the zero-padded ID string."""
    id_dict = defaultdict(list)
    param_dicts = [defaultdict(list) for _ in range(len(params))]
    for i_id in range(len(ids)):
        this_id = ids[i_id]
        # pad to at least `significant_digits` digits (was hard-coded to 5)
        this_id_str = "%0*d" % (significant_digits, this_id)
        id_dict[this_id_str[-significant_digits:]].append(this_id)
for i_param in range(len(params)):
param_dicts[i_param][this_id_str[-significant_digits:]].\
append(params[i_param][i_id])
return [id_dict] + param_dicts
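
# Sketch: IDs sharing their last `significant_digits` digits end up in the
# same bucket, e.g. 1 and 100001 both map to key '00001' (`size_dict` is an
# illustrative name for the dict built from the single param list):
# >>> id_dict, size_dict = group_ids_to_so_storage([1, 100001], [[10, 20]])
# >>> id_dict['00001']
# [1, 100001]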


def majority_element_1d(arr):
    """
    Returns the most frequent element in `arr`. Ties resolve to the smallest
    element, as ``np.unique`` returns sorted values.

    Args:
        arr: np.array

    Returns: scalar
    """
uni_el, cnts = np.unique(arr, return_counts=True)
return uni_el[np.argmax(cnts)]
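
# Example:
# >>> majority_element_1d(np.array([1, 2, 2, 3]))
# 2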


def get_paths_of_skelID(id_list, traced_skel_dir):
    """
    Gathers the paths to the k.zips of all skeletons with an ID in `id_list`.

    Args:
        id_list: list of str
            skeleton IDs
        traced_skel_dir: str
            directory of mapped skeletons

    Returns: list of str
        paths of the skeletons in `id_list`; ``None`` for IDs without a match
    """
mapped_skel_paths = get_filepaths_from_dir(traced_skel_dir)
mapped_skel_ids = re.findall(r'iter_\d+_(\d+)', ''.join(mapped_skel_paths))
wanted_paths = []
for skelID in id_list:
try:
path = mapped_skel_paths[mapped_skel_ids.index(str(skelID))]
wanted_paths.append(path)
        except ValueError:  # skeleton ID not found
wanted_paths.append(None)
return wanted_paths


def coordpath2anno(coords, scaling=None, add_edges=True):
    """
    Creates a skeleton from scaled coordinates; assumes `coords` are ordered
    for edge creation.

    Args:
        coords: np.array
            scaled coordinates
        scaling: tuple
        add_edges: bool

    Returns: SkeletonAnnotation
    """
if scaling is None:
scaling = global_params.config['scaling']
anno = SkeletonAnnotation()
anno.scaling = scaling
scaling = np.array(scaling, dtype=np.int32)
rep_nodes = []
coords = np.array(coords, dtype=np.int32)
for c in coords:
unscaled_c = c / scaling
n = SkeletonNode().from_scratch(anno, unscaled_c[0], unscaled_c[1],
unscaled_c[2])
anno.addNode(n)
rep_nodes.append(n)
if add_edges:
for i in range(1, len(rep_nodes)):
anno.addEdge(rep_nodes[i-1], rep_nodes[i])
return anno
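
# Sketch (hypothetical coordinates in nm; `scaling` in nm per voxel):
# >>> anno = coordpath2anno(np.array([[0, 0, 0], [100, 0, 0], [200, 0, 0]]),
# ...                       scaling=(10, 10, 20))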


def get_filepaths_from_dir(directory, ending=('k.zip',), recursively=False,
                           exclude_endings=False, fname_includes=()):
    """
    Collects all files with a certain ending from a directory.

    Args:
        directory: str
            path to the lookup directory
        ending: tuple/list/str
            ending(s) of the files
        recursively: bool
            also add files from subdirectories
        exclude_endings: bool
            filenames with an ending in `ending` will be skipped instead
        fname_includes: str or list
            only file names containing one of these substrings will be added

    Returns: list of str
        paths to the files
    """
# make it backwards compatible
if type(ending) is str:
ending = [ending]
if type(fname_includes) is str:
fname_includes = [fname_includes]
files = []
corr_incl = True
corr_end = True
    if recursively:
        walker = ((r, fs) for r, _, fs in os.walk(directory))
    else:
        walker = [(directory, next(os.walk(directory))[2])]
    for r, fs in walker:
        for f in fs:
            if len(ending) > 0:
                corr_end = np.any([f.endswith(end) for end in ending])
                if exclude_endings:
                    corr_end = not corr_end
            if len(fname_includes) > 0:
                corr_incl = np.any([substr in f for substr in fname_includes])
            if corr_end and corr_incl:
                files.append(os.path.join(r, f))
    return files
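
# Example (hypothetical directory): collect all k.zip files in a tree:
# >>> paths = get_filepaths_from_dir('/mnt/skeletons', ending='k.zip',
# ...                                recursively=True)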


def read_txt_from_zip(zip_fname, fname_in_zip):
"""
Read text file from zip.
Args:
zip_fname: str
fname_in_zip: str
Returns: bytes
"""
with zipfile.ZipFile(zip_fname, allowZip64=True) as z:
txt = z.read(fname_in_zip)
return txt


def read_mesh_from_zip(zip_fname, fname_in_zip):
    """
    Reads a ply file from a zip. Currently does not support normals!

    Args:
        zip_fname: str
        fname_in_zip: str

    Returns: np.array, np.array, np.array
        indices, vertices, normals (currently ``None``)
    """
    with zipfile.ZipFile(zip_fname, allowZip64=True) as z:
        with z.open(fname_in_zip) as f:
            plydata = PlyData.read(f)
vert = plydata['vertex'].data
vert = vert.view((np.float32, len(vert.dtype.names))).flatten()
ind = np.array(plydata['face'].data['vertex_indices'].tolist()).flatten()
# TODO: support normals
# norm = plydata['normals'].data
# norm = vert.view((np.float32, len(vert.dtype.names))).flatten()
return [ind, vert, None]


def read_meshes_from_zip(zip_fname, fnames_in_zip):
    """
    Reads ply files from a zip. Currently does not support normals!

    Args:
        zip_fname: str
        fnames_in_zip: list of str

    Returns: list of (np.array, np.array, np.array)
        one (indices, vertices, normals) tuple per file; normals are
        currently ``None``
    """
    meshes = []
    with zipfile.ZipFile(zip_fname, allowZip64=True) as z:
        for fname_in_zip in fnames_in_zip:
            with z.open(fname_in_zip) as f:
                plydata = PlyData.read(f)
vert = plydata['vertex'].data
vert = vert.view((np.float32, len(vert.dtype.names))).flatten()
ind = np.array(plydata['face'].data['vertex_indices'].tolist()).flatten()
# TODO: support normals
# norm = plydata['normals'].data
# norm = vert.view((np.float32, len(vert.dtype.names))).flatten()
meshes.append((ind, vert, None))
return meshes


def write_txt2kzip(kzip_path, text, fname_in_zip, force_overwrite=False):
"""
Write string to file in k.zip.
Args:
kzip_path: str
text: str or bytes
fname_in_zip: str
name of file when added to zip
force_overwrite: bool
Returns:
"""
texts2kzip(kzip_path, [text], [fname_in_zip],
force_overwrite=force_overwrite)
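
# Sketch (hypothetical file): store a text file inside an annotation archive:
# >>> write_txt2kzip('/tmp/example.k.zip', 'hello', 'readme.txt')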


def texts2kzip(kzip_path, texts, fnames_in_zip, force_overwrite=False):
    """
    Writes strings to files in a k.zip.

    Args:
        kzip_path: str
        texts: List[str]
        fnames_in_zip: List[str]
            names of the files when added to the zip
        force_overwrite: bool
    """
if not kzip_path.endswith('.k.zip'):
kzip_path += '.k.zip'
if os.path.isfile(kzip_path):
try:
if force_overwrite:
with zipfile.ZipFile(kzip_path, "w", zipfile.ZIP_DEFLATED) as zf:
for i in range(len(texts)):
zf.writestr(fnames_in_zip[i], texts[i])
else:
for i in range(len(texts)):
remove_from_zip(kzip_path, fnames_in_zip[i])
with zipfile.ZipFile(kzip_path, "a", zipfile.ZIP_DEFLATED) as zf:
for i in range(len(texts)):
zf.writestr(fnames_in_zip[i], texts[i])
        except Exception as e:
            log_handler.error(f"Couldn't open file {kzip_path} for reading "
                              f"and overwriting. {e}")
else:
try:
with zipfile.ZipFile(kzip_path, "w", zipfile.ZIP_DEFLATED) as zf:
for i in range(len(texts)):
zf.writestr(fnames_in_zip[i], texts[i])
except Exception as e:
log_handler.error("Couldn't open file {} for writing. {}"
"".format(kzip_path, e))


def write_data2kzip(kzip_path, fpath, fname_in_zip=None, force_overwrite=False):
"""
Write file to k.zip.
Args:
kzip_path: str
fpath: str
fname_in_zip: str
name of file when added to zip
force_overwrite: bool
Returns:
"""
data2kzip(kzip_path, [fpath], [fname_in_zip], force_overwrite)


def data2kzip(kzip_path: str, fpaths, fnames_in_zip=None, force_overwrite=True,
              verbose=False):
    """
    Writes files to a k.zip and removes the source files at `fpaths`
    afterwards.

    Args:
        kzip_path: str
        fpaths: List[str]
        fnames_in_zip: List[str]
            names of the files when added to the zip; defaults to the
            basenames of `fpaths`
        force_overwrite: bool
        verbose: bool
    """
    if not kzip_path.endswith('.k.zip'):
        kzip_path += '.k.zip'
    nb_files = len(fpaths)
    if fnames_in_zip is None:
        # fall back to the basenames of `fpaths` (see loops below)
        fnames_in_zip = [None] * nb_files
    pbar = None
    if verbose:
        log_handler.info('Writing {} files to .zip.'.format(nb_files))
        pbar = tqdm.tqdm(total=nb_files, leave=False)
if os.path.isfile(kzip_path):
try:
if force_overwrite:
with zipfile.ZipFile(kzip_path, "w", zipfile.ZIP_DEFLATED,
allowZip64=True) as zf:
for ii in range(nb_files):
file_name = os.path.split(fpaths[ii])[1]
if fnames_in_zip[ii] is not None:
file_name = fnames_in_zip[ii]
zf.write(fpaths[ii], file_name)
if verbose:
pbar.update()
else:
for ii in range(nb_files):
file_name = os.path.split(fpaths[ii])[1]
if fnames_in_zip[ii] is not None:
file_name = fnames_in_zip[ii]
remove_from_zip(kzip_path, file_name)
with zipfile.ZipFile(kzip_path, "a", zipfile.ZIP_DEFLATED,
allowZip64=True) as zf:
for ii in range(nb_files):
file_name = os.path.split(fpaths[ii])[1]
if fnames_in_zip[ii] is not None:
file_name = fnames_in_zip[ii]
zf.write(fpaths[ii], file_name)
if verbose:
pbar.update()
except Exception as e:
log_handler.error("Couldn't open file {} for reading and"
" overwriting. Error: {}".format(kzip_path, e))
else:
try:
with zipfile.ZipFile(kzip_path, "w", zipfile.ZIP_DEFLATED,
allowZip64=True) as zf:
for ii in range(nb_files):
file_name = os.path.split(fpaths[ii])[1]
if fnames_in_zip[ii] is not None:
file_name = fnames_in_zip[ii]
zf.write(fpaths[ii], file_name)
if verbose:
pbar.update()
except Exception as e:
log_handler.error("Couldn't open file {} for writing. Error: "
"{}".format(kzip_path, e))
for ii in range(nb_files):
os.remove(fpaths[ii])
if verbose:
pbar.close()
log_handler.info('Done writing files to .zip.')


def remove_from_zip(zipfname, *filenames):
"""
Removes filenames from zipfile
Args:
zipfname: str
Path to zipfile
*filenames: list of str
files to delete
Returns:
"""
tempdir = tempfile.mkdtemp()
try:
tempname = os.path.join(tempdir, 'new.zip')
with zipfile.ZipFile(zipfname, 'r', allowZip64=True) as zipread:
with zipfile.ZipFile(tempname, 'w', allowZip64=True) as zipwrite:
for item in zipread.infolist():
if item.filename not in filenames:
data = zipread.read(item.filename)
zipwrite.writestr(item, data)
shutil.move(tempname, zipfname)
finally:
shutil.rmtree(tempdir)


def write_obj2pkl(path, objects):
    """
    Writes an object to a pickle file (atomically, via a temporary file).

    Args:
        path: str
            Destination.
        objects: object
    """
gc.disable()
if isinstance(path, str):
with open(path + ".tmp", 'wb') as output:
pkl.dump(objects, output, protocol=pkl.HIGHEST_PROTOCOL)
shutil.move(path + ".tmp", path)
    else:
        # legacy argument order: (objects, path) - warn and swap
        log_handler.warn("write_obj2pkl takes arguments 'path' (str) and "
                         "'objects' (python object).")
with open(objects + ".tmp", 'wb') as output:
pkl.dump(path, output, protocol=pkl.HIGHEST_PROTOCOL)
shutil.move(objects + ".tmp", objects)
gc.enable()


def load_pkl2obj(path):
    """
    Loads a pickled object from file.

    Args:
        path: str
            path of the source file

    Returns:
        The unpickled object.
    """
gc.disable()
try:
with open(path, 'rb') as inp:
objects = pkl.load(inp)
except UnicodeDecodeError: # python3 compatibility
with open(path, 'rb') as inp:
objects = pkl.loads(inp.read(), encoding='bytes')
objects = convert_keys_byte2str(objects)
gc.enable()
return objects
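
# Round-trip sketch (hypothetical path); writes are atomic via a .tmp file:
# >>> write_obj2pkl('/tmp/example.pkl', {'a': 1})
# >>> load_pkl2obj('/tmp/example.pkl')
# {'a': 1}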


def convert_keys_byte2str(dc):
    """Recursively decodes all bytes keys of a dict to str (in place)."""
    if type(dc) is not dict:
        return dc
for k in list(dc.keys()):
v = convert_keys_byte2str(dc[k])
if type(k) is bytes:
dc[k.decode('utf-8')] = v
del dc[k]
return dc
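
# Example: nested bytes keys are decoded recursively:
# >>> convert_keys_byte2str({b'a': {b'b': 1}})
# {'a': {'b': 1}}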


def chunkify(lst: Union[list, np.ndarray], n: int) -> List[list]:
    """
    Splits a list into ``min(n, len(lst))`` sub-lists (round-robin).

    Args:
        lst:
        n:

    Examples:
        >>> len(chunkify(np.arange(10), 2))
        2
        >>> len(chunkify(np.arange(10), 100))
        10

    Returns:
        List of chunks. Length is ``min(n, len(lst))``.
    """
if len(lst) < n:
n = len(lst)
return [lst[i::n] for i in range(n)]


def chunkify_weighted(lst, n, weights):
    """
    Splits a list into n sub-lists whose summed weights are approximately
    balanced: elements are sorted by weight in descending order and then
    distributed round-robin.

    Args:
        lst: np.ndarray
        n: int
        weights: np.ndarray
    """
    if len(lst) < n:
        n = len(lst)
        return [lst[i::n] for i in range(n)]  # no weighting needed
    ordered = np.argsort(weights)
    lst = np.asarray(lst)[ordered[::-1]]
    return [lst[i::n] for i in range(n)]
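
# Example: weights [1, 4, 2, 3] are sorted in descending order and distributed
# round-robin, balancing the summed weight per chunk (6 vs. 4):
# >>> chunkify_weighted(np.arange(4), 2, np.array([1, 4, 2, 3]))
# [array([1, 2]), array([3, 0])]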


def chunkify_successive(l, n):
"""Yield successive n-sized chunks from l."""
for i in range(0, len(l), n):
yield l[i:i + n]
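
# Example: successive (ordered) chunks; the last one may be shorter:
# >>> list(chunkify_successive([1, 2, 3, 4, 5], 2))
# [[1, 2], [3, 4], [5]]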


def flatten_list(lst):
    """
    Flattens a list of lists into a np.array (same ordering as
    ``np.concatenate``).

    Args:
        lst: list of lists

    Returns: np.array
    """
res = np.array([el for sub in lst for el in sub])
return res


def flatten(x):
    """
    Replacement for compiler.ast.flatten - performs recursive flattening, in
    contrast to :func:`flatten_list` above.
    Public domain code:
    https://stackoverflow.com/questions/16176742/
    python-3-replacement-for-deprecated-compiler-ast-flatten-function

    Args:
        x:

    Returns: generator over the flattened elements of x
    """
    def iselement(e):
        # strings are iterable, but should be treated as atomic elements
        return not (isinstance(e, collections.abc.Iterable)
                    and not isinstance(e, str))
    for el in x:
        if iselement(el):
            yield el
        else:
            yield from flatten(el)
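
# Example: strings are kept atomic while nested iterables are flattened:
# >>> list(flatten([1, [2, [3, 'ab']]]))
# [1, 2, 3, 'ab']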


def get_skelID_from_path(skel_path):
"""
Parse skeleton ID from filename.
Args:
skel_path: str
path to skeleton
Returns: int
skeleton ID
"""
return int(re.findall(r'iter_0_(\d+)', skel_path)[0])


def safe_copy(src, dest, safe=True):
    """
    Copies a file; raises ``FileExistsError`` if the destination exists. Taken
    from Misandrist on Stackoverflow (03/31/17).

    Args:
        src: str
            path to the source file
        dest: str
            path to the destination file
        safe: bool
            If False, copies the file with replacement.
    """
if safe:
fd = os.open(dest, os.O_CREAT | os.O_EXCL | os.O_WRONLY)
# Copy the file and automatically close files at the end
with os.fdopen(fd, 'wb') as f:
with open(src, 'rb') as sf:
shutil.copyfileobj(sf, f)
else:
shutil.copy(src, dest)


# https://gist.github.com/tcwalther/ae058c64d5d9078a9f333913718bba95
# class based on: http://stackoverflow.com/a/21919644/487556
class DelayedInterrupt(object):
    def __init__(self, signals):
        if not isinstance(signals, (list, tuple)):
            signals = [signals]
        self.sigs = signals

    def __enter__(self):
        self.signal_received = {}
        self.old_handlers = {}
        for sig in self.sigs:
            self.signal_received[sig] = False
            self.old_handlers[sig] = signal.getsignal(sig)

            # bind `sig` as a default argument; a plain closure would suffer
            # from late binding and record every signal under the last `sig`
            def handler(s, frame, sig=sig):
                self.signal_received[sig] = (s, frame)
                # Note: in Python 3.5+, you can use signal.Signals(sig).name
                log_handler.info('Signal %s received. Delaying '
                                 'KeyboardInterrupt.' % sig)
            signal.signal(sig, handler)

def __exit__(self, type, value, traceback):
for sig in self.sigs:
signal.signal(sig, self.old_handlers[sig])
if self.signal_received[sig] and self.old_handlers[sig]:
self.old_handlers[sig](*self.signal_received[sig])
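
# Usage sketch: delay Ctrl+C while writing a file; the stored handler is
# invoked (raising KeyboardInterrupt) once the block exits:
# >>> with DelayedInterrupt([signal.SIGINT, signal.SIGTERM]):
# ...     write_obj2pkl('/tmp/example.pkl', {'a': 1})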


def parse_cc_dict_from_txt(txt):
    """
    Parses connected components from a KNOSSOS mergelist text file.

    Args:
        txt: str or bytes

    Returns: dict
        object ID -> np.array of supervoxel IDs
    """
cc_dict = {}
for line in txt.splitlines()[::4]:
if type(line) is bytes:
curr_line = line.decode()
else:
curr_line = line
line_nb = np.array(re.findall(r"(\d+)", curr_line), dtype=np.uint64)
curr_ixs = line_nb[3:]
cc_ix = line_nb[0]
curr_ixs = curr_ixs[curr_ixs != 0]
cc_dict[cc_ix] = curr_ixs
return cc_dict
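
# Sketch: a KNOSSOS mergelist stores one object per four lines; the first line
# holds the object ID, two flags and the supervoxel IDs:
# >>> parse_cc_dict_from_txt("4 0 0 4 7 9\n\n\n\n")  # -> {4: array([4, 7, 9], ...)}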


def parse_cc_dict_from_kml(kml_path):
    """
    Parses connected components from a KNOSSOS mergelist file.

    Args:
        kml_path: str

    Returns: dict
    """
    with open(kml_path, "rb") as f:
        txt = f.read().decode()
    return parse_cc_dict_from_txt(txt)


def parse_cc_dict_from_g(g):
    """
    Parses connected components from a networkx graph. The minimum node ID of
    each component is used as its key (largest components first).

    Args:
        g: nx.Graph

    Returns: dict
    """
    cc_dict = {}
    # use minimum ID in CC as SSV ID; ``connected_components`` yields sets,
    # which do not support indexing, hence the sort
    for cc in sorted(nx.connected_components(g), key=len, reverse=True):
        cc = sorted(cc)
        cc_dict[cc[0]] = cc
    return cc_dict
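
# Example: two components keyed by their minimum node ID (larger one first):
# >>> g = nx.Graph([(1, 2), (3, 4), (4, 5)])
# >>> parse_cc_dict_from_g(g)
# {3: [3, 4, 5], 1: [1, 2]}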


def parse_cc_dict_from_kzip(k_path):
    """
    Parses connected components from the ``mergelist.txt`` inside a k.zip.

    Args:
        k_path: str

    Returns: dict
    """
    txt = read_txt_from_zip(k_path, "mergelist.txt").decode()
    return parse_cc_dict_from_txt(txt)


@contextlib.contextmanager
def temp_seed(seed):
    """
    Context manager that temporarily seeds numpy's global RNG and restores
    the previous state on exit. From
    https://stackoverflow.com/questions/49555991/can-i-create-a-local-numpy-random-seed

    Args:
        seed: RNG seed.
    """
state = np.random.get_state()
np.random.seed(seed)
try:
yield
finally:
np.random.set_state(state)
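
# Usage sketch: draws inside the block are reproducible, and the global RNG
# state is restored afterwards:
# >>> with temp_seed(0):
# ...     a = np.random.rand(2)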


def str_delta_sec(seconds: int) -> str:
"""
String time formatting - omits time units which are zero.
Examples:
>>> sec = 2 * 24 * 3600 + 12 * 3600 + 5 * 60 + 1
>>> str_rep = str_delta_sec(sec)
>>> assert str_rep == '2d:12h:05min:01s'
>>> assert str_delta_sec(4 * 3600 + 20 * 60 + 10) == '4h:20min:10s'
Args:
seconds: Number of seconds, e.g. result of a time delta.
Returns:
String representation, e.g. ``'2d:12h:05min:01s'`` for
``sec = 1 + 5 * 60 + 12 * 3600 + 2 * 24 * 3600``.
"""
m, s = divmod(int(seconds), 60)
h, m = divmod(m, 60)
d, h = divmod(h, 24)
str_rep = ''
if d > 0:
str_rep += f'{d:d}d:'
if h > 0:
str_rep += f'{h:d}h:'
if m > 0:
str_rep += f'{m:02d}min:'
str_rep += f'{s:02d}s'
return str_rep