# -*- coding: utf-8 -*-
# SyConn - Synaptic connectivity inference toolkit
#
# Copyright (c) 2016 - now
# Max Planck Institute of Neurobiology, Martinsried, Germany
# Authors: Philipp Schubert, Sven Dorkenwald, Joergen Kornfeld
import os
import shutil
import time
from pickle import UnpicklingError
from .. import global_params
from ..extraction import log_extraction
from ..handler.basics import write_obj2pkl, load_pkl2obj
try:
from lz4.block import compress, decompress
except ImportError:
from lz4 import compress, decompress
# Optional dependency: without ``fasteners`` no inter-process file locks can be
# created; FSBase checks ``LOCKING`` and falls back to lock-free access.
try:
    import fasteners
    LOCKING = True
except ImportError:
    # NOTE: the trailing space in the first literal is required, the two
    # literals are concatenated into a single message.
    print("fasteners could not be imported. Locking will be disabled by default. "
          "Please install fasteners to enable locking (pip install fasteners).")
    LOCKING = False
__all__ = ['FSBase', 'BTBase']
class StorageBase(dict):
    """
    Interface class for data Input/Output operations in SyConn.

    This class is a dictionary-like object that provides an interface for data
    storage and retrieval. All content lives in the internal dictionary
    ``_dc_intern``; the storage inherited from ``dict`` itself is not used.
    Subclasses must implement item access (``__getitem__``, ``__setitem__``,
    ``__delitem__``) as well as ``push``, ``pull``, ``update`` and ``copy``.
    ``_cache_decomp`` / ``_cache_dc`` are provided for subclasses that store
    compressed data and want to cache decompressed values.

    Attributes:
        _cache_decomp (bool): Flag indicating whether to cache decompressed arrays.
        _cache_dc (dict): Cache for decompressed arrays.
        _dc_intern (dict): Internal dictionary for data storage.
    """

    def __init__(self, cache_decomp):
        """
        Initializes the StorageBase object with empty storage and cache.

        Args:
            cache_decomp (bool): Flag indicating whether to cache decompressed arrays.
        """
        super(StorageBase, self).__init__()
        self._cache_decomp = cache_decomp
        self._cache_dc = {}
        self._dc_intern = {}

    def __getitem__(self, key):
        """
        Retrieves an item from the storage. Must be implemented in subclasses.

        Args:
            key (str): The key of the item to retrieve.

        Raises:
            NotImplementedError: Always, in the base class.
        """
        raise NotImplementedError

    def __setitem__(self, key, value):
        """
        Sets an item in the storage. Must be implemented in subclasses.

        Args:
            key (str): The key of the item to set.
            value (any): The value of the item to set.

        Raises:
            NotImplementedError: Always, in the base class.
        """
        raise NotImplementedError

    def __delitem__(self, key):
        """
        Deletes an item from the storage. Must be implemented in subclasses.

        Args:
            key (str): The key of the item to delete.

        Raises:
            NotImplementedError: Always, in the base class.
        """
        raise NotImplementedError

    def __del__(self):
        """
        Finalizer hook for releasing backend resources.

        The base implementation holds no resources and does nothing.
        Subclasses holding resources (e.g. file locks) should override it.
        It deliberately does not raise: exceptions raised in ``__del__``
        cannot be handled by callers and are only reported as
        "Exception ignored in ..." warnings during garbage collection.
        """
        pass

    def __len__(self):
        """
        Returns the number of items in the storage.

        Returns:
            int: The number of items in the storage.
        """
        return self._dc_intern.__len__()

    def __eq__(self, other):
        """
        Checks if the storage is equal to another storage.

        Args:
            other (StorageBase): The other storage to compare with.

        Returns:
            bool: True if ``other`` is a StorageBase with equal internal
            content, False otherwise (including for plain dicts).
        """
        if not isinstance(other, StorageBase):
            return False
        return self._dc_intern.__eq__(other._dc_intern)

    def __ne__(self, other):
        """
        Checks if the storage is not equal to another storage.

        Args:
            other (StorageBase): The other storage to compare with.

        Returns:
            bool: The negation of :meth:`__eq__`.
        """
        return not self.__eq__(other)

    def __contains__(self, item):
        """
        Checks if a key is in the storage.

        Args:
            item (str): The key of the item to check.

        Returns:
            bool: True if the key is in the storage, False otherwise.
        """
        return self._dc_intern.__contains__(item)

    def __iter__(self):
        """
        Returns an iterator over the keys in the storage.

        Returns:
            iterator: An iterator over the keys in the storage.
        """
        return iter(self._dc_intern)

    def __repr__(self):
        """
        Returns a string representation of the storage.

        Returns:
            str: The representation of the internal dictionary.
        """
        return self._dc_intern.__repr__()

    def update(self, other, **kwargs):
        """
        Updates the storage with the items from another storage. Must be
        implemented in subclasses.

        Args:
            other (StorageBase): The other storage to update from.
            **kwargs: Additional keyword arguments.

        Raises:
            NotImplementedError: Always, in the base class.
        """
        raise NotImplementedError

    def copy(self):
        """
        Returns a copy of the storage. Must be implemented in subclasses.

        Returns:
            StorageBase: A copy of the storage.

        Raises:
            NotImplementedError: Always, in the base class.
        """
        raise NotImplementedError

    def items(self):
        """
        Yields the items (key-value pairs) in the storage.

        Values are obtained through ``self[k]``, i.e. via the subclass'
        ``__getitem__``.

        Returns:
            iterator: A generator over the items in the storage.
        """
        for k in self._dc_intern.keys():
            yield k, self[k]

    def values(self):
        """
        Yields the values in the storage (obtained via ``self[k]``).

        Returns:
            iterator: A generator over the values in the storage.
        """
        for k in self._dc_intern.keys():
            yield self[k]

    def keys(self):
        """
        Returns the keys in the storage.

        Returns:
            dict_keys: A view of the keys of the internal dictionary.
        """
        return self._dc_intern.keys()

    def push(self, dest=None):
        """
        Pushes the data to a destination. Must be implemented in subclasses.

        Args:
            dest (str, optional): The destination to push the data to.

        Raises:
            NotImplementedError: Always, in the base class.
        """
        raise NotImplementedError

    def pull(self, source=None):
        """
        Pulls the data from a source. Must be implemented in subclasses.

        Args:
            source (str, optional): The source to pull the data from.

        Raises:
            NotImplementedError: Always, in the base class.
        """
        raise NotImplementedError
# ---------------------------- BT
# ------------------------------------------------------------------------------
class BTBase(StorageBase):
    """
    Placeholder storage backend intended for a BTree-like structure.

    This is currently a stub: item access, ``push`` and ``pull`` are inherited
    from ``StorageBase`` and therefore raise ``NotImplementedError``.

    Attributes:
        identifier (str): Identifier for the BTBase instance.
        read_only (bool): Flag indicating whether the BTBase instance is read-only.
        disable_locking (bool): Flag indicating whether to disable locking.
    """

    def __init__(self, identifier, cache_decomp=False, read_only=True,
                 disable_locking=False):
        """
        Initializes the BTBase object.

        Args:
            identifier (str): Identifier for the BTBase instance.
            cache_decomp (bool, optional): Accepted for interface compatibility
                only; caching of decompressed values is always disabled.
            read_only (bool, optional): Flag indicating whether the BTBase
                instance is read-only.
            disable_locking (bool, optional): Flag indicating whether to
                disable locking.
        """
        # likely 'cache_decomp' not necessary, but needed to match interface of LZ4Dicts
        super(BTBase, self).__init__(cache_decomp=False)
        self.identifier = identifier
        self.read_only = read_only
        self.disable_locking = disable_locking

    def __del__(self):
        """
        Nothing to release; BTBase holds no external resources.

        Overridden explicitly so that destroying a BTBase never raises,
        independent of the base-class finalizer.
        """
        pass

    def __eq__(self, other):
        """
        Checks if the BTBase instance is equal to another BTBase instance.

        Args:
            other (BTBase): The other BTBase instance to compare with.

        Returns:
            bool: True if ``other`` is a BTBase with equal internal content,
            False otherwise.
        """
        if not isinstance(other, BTBase):
            return False
        return self._dc_intern.__eq__(other._dc_intern)
# ---------------------------- lz4
# ------------------------------------------------------------------------------
class FSBase(StorageBase):
    """
    File-system storage backend that keeps its content in a pickled dictionary.

    The content (``_dc_intern``) is read from a pickle file in :meth:`pull`
    and written back in :meth:`push`. If locking is enabled, access to the
    file is guarded by an inter-process lock file ``.<file name>.lk`` placed
    next to it. Item access (``__getitem__``/``__setitem__``) is not defined
    here and is presumably provided by subclasses. Judging by the
    'cache_decomp' flag and ``_cache_dc``, those subclasses may store
    compressed arrays and cache their decompressed form (confirm in the
    subclasses).
    """

    def __init__(self, inp_p: str, cache_decomp: bool = False,
                 read_only: bool = True, max_delay: int = 100,
                 timeout: int = 1000, disable_locking: bool = True,
                 max_nb_attempts: int = 100):
        """
        Initializes the FSBase object and, if a path is given, loads its content.

        Args:
            inp_p (str or None): Path to the file. ``None`` creates a virtual
                object that is neither loaded from nor written to disk.
            cache_decomp (bool, optional): If True, caches deserialized arrays.
                Defaults to False.
            read_only (bool, optional): If True and locking is enabled, the lock
                is only held while loading. Otherwise it is held from
                :meth:`pull` until :meth:`push`. Defaults to True.
            max_delay (int, optional): Maximum delay between lock-acquisition
                retries (passed to fasteners as ``max_delay``). Defaults to 100.
            timeout (int, optional): Total time budget in seconds for acquiring
                the lock, split evenly across the attempts. A ``RuntimeError``
                is raised if the lock could not be acquired. Defaults to 1000.
            disable_locking (bool, optional): If True, disables file locking.
                Forced to True if ``fasteners`` is unavailable. Defaults to True.
            max_nb_attempts (int, optional): Maximum number of total attempts.
                Defaults to 100.

        Raises:
            NotImplementedError: If ``inp_p`` is neither None nor a str.
            RuntimeError: If the file lock could not be acquired.
        """
        # sets _cache_decomp, _cache_dc and _dc_intern
        super(FSBase, self).__init__(cache_decomp)
        if not LOCKING and not disable_locking:
            log_extraction.warning('Locking could not be enabled due to missing "fasteners" package.')
            disable_locking = True
        self.read_only = read_only
        self.a_lock = None
        self.max_delay = max_delay
        self.timeout = timeout
        self.disable_locking = disable_locking
        self._max_nb_attempts = max_nb_attempts
        self._path = inp_p
        if inp_p is not None:
            if isinstance(inp_p, str):
                self.pull(inp_p)
            else:
                msg = "Unsupported initialization type {} for 'FSBase'.".format(type(inp_p))
                log_extraction.error(msg)
                raise NotImplementedError(msg)

    def __delitem__(self, key):
        """
        Deletes the item with the specified key from the stored content.

        Args:
            key: Key to delete.

        Raises:
            AttributeError: If ``key`` does not exist.
        """
        try:
            # operate on the internal dict; `del self[key]` would recurse forever
            del self._dc_intern[key]
        except KeyError:
            msg = "No such attribute {} in dict at {}. Existing keys:" \
                  " {}.".format(key, self._path, list(self.keys()))
            log_extraction.error(msg)
            raise AttributeError(msg) from None
        # also drop a cached decompressed copy (if any) so it cannot outlive the entry
        self._cache_dc.pop(key, None)

    def __del__(self):
        """
        Releases a still-held file lock and drops the in-memory content.

        Written defensively (``getattr``/``hasattr``) because ``__del__`` can
        run on objects whose ``__init__`` did not complete and must not raise.
        """
        lock = getattr(self, 'a_lock', None)
        if lock is not None and lock.acquired:
            lock.release()
        for attr in ('_dc_intern', '_cache_dc'):
            if hasattr(self, attr):
                delattr(self, attr)

    def __len__(self):
        """
        Returns the number of items in the dictionary.
        """
        return self._dc_intern.__len__()

    def __eq__(self, other):
        """
        Checks if ``other`` is an FSBase with equal internal content.
        """
        if not isinstance(other, FSBase):
            return False
        return self._dc_intern.__eq__(other._dc_intern)

    def __ne__(self, other):
        """
        Checks if the current object is not equal to the other object.
        """
        return not self.__eq__(other)

    def __contains__(self, item):
        """
        Checks if the dictionary contains the specified key.
        """
        return self._dc_intern.__contains__(item)

    def __iter__(self):
        """
        Returns an iterator over the keys of the dictionary.
        """
        return iter(self._dc_intern)

    def __repr__(self):
        """
        Returns a string representation of the dictionary.
        """
        return self._dc_intern.__repr__()

    def update(self, other, **kwargs):
        """
        Updates the dictionary with the key-value pairs from other. This method is not implemented.
        """
        raise NotImplementedError

    def copy(self):
        """
        Returns a copy of the dictionary. This method is not implemented.
        """
        raise NotImplementedError

    def items(self):
        """
        Returns a generator that yields the key-value pairs (values via ``self[k]``).
        """
        for k in self._dc_intern.keys():
            yield k, self[k]

    def values(self):
        """
        Returns a generator that yields the values (via ``self[k]``).
        """
        for k in self._dc_intern.keys():
            yield self[k]

    def keys(self):
        """
        Returns the keys in the dictionary.
        """
        return self._dc_intern.keys()

    def _release_file_lock(self):
        """Releases ``a_lock`` if it exists and is currently held."""
        if self.a_lock is not None and self.a_lock.acquired:
            self.a_lock.release()

    def push(self, dest: str = None):
        """
        Writes the content to the specified destination as a pickle file.

        In writable mode with locking enabled, the lock acquired in
        :meth:`pull` is released afterwards, even if writing fails.

        Args:
            dest (str, optional): The storage destination. Defaults to the
                path given at initialization. If both are None, nothing is
                written and a warning is logged.
        """
        if dest is None:
            dest = self._path
        if dest is None:  # support virtual / temporary SSO objects
            log_extraction.warning('"push" called but Storage object was initialized '
                                   'with "None". Content will not be written.')
            return
        try:
            write_obj2pkl(dest, self._dc_intern)
        finally:
            if not self.read_only and not self.disable_locking:
                self._release_file_lock()

    def pull(self, source: str = None):
        """
        Loads the content from the specified source.

        If locking is enabled, an inter-process lock ``.<file name>.lk`` next
        to ``source`` is acquired. In read-only mode it is released right
        after loading. Otherwise it is kept until :meth:`push`. A missing file
        yields empty content. A broken pickle (``UnpicklingError``/``EOFError``)
        is logged and also yields empty content.

        Args:
            source (str, optional): The source location. Defaults to the path
                given at initialization.

        Raises:
            RuntimeError: If the file lock could not be acquired.
        """
        if source is None:
            source = self._path
        fold, fname = os.path.split(source)
        # os.path.join keeps the lock next to the file even for bare file names
        # (empty `fold`), instead of placing it in the file-system root.
        lock_path = os.path.join(fold, "." + fname + ".lk")
        # only create directory if read_only is false. -> support virtual SSO
        if fold and not os.path.isdir(fold) and not self.read_only:
            try:
                os.makedirs(fold)
            except OSError as e:  # if two jobs create the folder at the same time
                log_extraction.warning("Tried to create folder of dict {}, but it already existed. "
                                       "Multiple jobs might work on same chunk almost"
                                       " simultaneously. Error {}.".format(self._path, e))
        # acquires lock until released when saving or after loading if self.read_only
        if not self.disable_locking:
            self.a_lock = fasteners.InterProcessLock(lock_path)
            # at least one attempt; also avoids division by zero below
            max_attempts = max(self._max_nb_attempts, 1)
            nb_attempts = 1
            start = time.time()
            while True:
                try:
                    gotten = self.a_lock.acquire(blocking=True, delay=0.1,
                                                 max_delay=self.max_delay,
                                                 timeout=self.timeout / max_attempts)
                except ValueError:
                    gotten = False
                # stop on success or once the maximum number of attempts is reached
                if gotten or nb_attempts >= max_attempts:
                    break
                nb_attempts += 1
            if not gotten:
                msg = "Unable to acquire file lock for {} after {:.0f}s.".format(source, time.time() - start)
                log_extraction.warning(msg)
                raise RuntimeError(msg)
        try:
            if os.path.isfile(source):
                try:
                    self._dc_intern = load_pkl2obj(source)
                except (UnpicklingError, EOFError) as e:
                    log_extraction.warning("Could not load LZ4Dict ({}). 'push' will"
                                           " overwrite broken .pkl file: {}.".format(self._path, e))
                    self._dc_intern = {}
            else:
                self._dc_intern = {}
        finally:
            # read-only access only needs the lock while loading; release it
            # even if loading raised an unexpected error
            if self.read_only and not self.disable_locking:
                self._release_file_lock()