"""
Base classes for standard models
:copyright: Copyright 2006-2023 by the PyNN team, see AUTHORS.
:license: CeCILL, see LICENSE for details.
"""
import warnings
from copy import deepcopy
import numpy as np
import neo
import quantities as pq
from .. import errors, models
from ..parameters import ParameterSpace
from ..morphology import IonChannelDistribution, SynapseDistribution
excitatory_receptor_types = ["excitatory", "AMPA", "NMDA"]
inhibitory_receptor_types = ["inhibitory", "GABA", "GABAA", "GABAB"]
# ==============================================================================
#   Standard cells
# ==============================================================================
def build_scaling_functions(pynn_name, sim_name, scale_factor):
    def f(**p):
        return p[pynn_name] * scale_factor
    def g(**p):
        return p[sim_name] / scale_factor
    return f, g
def build_translations(*translation_list):
    """
    Build a translation dictionary from a list of translations/transformations.
    """
    translations = {}
    for item in translation_list:
        err_msg = f"Translation tuples must have between 2 and 4 items. Actual content: {item}"
        assert 2 <= len(item) <= 4, err_msg
        pynn_name = item[0]
        sim_name = item[1]
        if len(item) == 2:  # no transformation
            f = pynn_name
            g = sim_name
            type_ = "simple"
        elif len(item) == 3:  # simple multiplicative factor
            scale_factor = item[2]
            f, g = build_scaling_functions(pynn_name, sim_name, scale_factor)
            type_ = "scaled"
        elif len(item) == 4:  # more complex transformation
            f = item[2]
            g = item[3]
            type_ = "computed"
        translations[pynn_name] = {'translated_name': sim_name,
                                   'forward_transform': f,
                                   'reverse_transform': g,
                                   'type': type_}
    return translations
class StandardModelType(models.BaseModelType):
    """Base class for standardized cell model and synapse model classes."""
    translations = {}
    extra_parameters = {}
    @property
    def native_parameters(self):
        """
        A :class:`ParameterSpace` containing parameter names and values
        translated from the standard PyNN names and units to simulator-specific
        ("native") names and units.
        """
        return self.translate(self.parameter_space)
    def translate(self, parameters, copy=True):
        """Translate standardized model parameters to simulator-specific parameters."""
        if copy:
            _parameters = deepcopy(parameters)
        else:
            _parameters = parameters
        cls = self.__class__
        if parameters.schema != self.get_schema():
            # should replace this with a PyNN-specific exception type
            raise Exception(f"Schemas do not match: {parameters.schema} != {self.get_schema()}")
        native_parameters = {}
        for name in parameters.keys():
            D = self.translations[name]
            pname = D['translated_name']
            if callable(D['forward_transform']):
                pval = D['forward_transform'](**_parameters)
            else:
                try:
                    pval = eval(D['forward_transform'], globals(), _parameters)
                except NameError as err:
                    raise NameError(
                        f"Problem translating '{pname}' in {cls.__name__}. "
                        f"Transform: '{D['forward_transform']}'. Parameters: {parameters}. {err}"
                    )
                except ZeroDivisionError:
                    raise
            native_parameters[pname] = pval
        return ParameterSpace(native_parameters, schema=None, shape=parameters.shape)
    def reverse_translate(self, native_parameters):
        """Translate simulator-specific model parameters to standardized parameters."""
        cls = self.__class__
        standard_parameters = {}
        for name, D in self.translations.items():
            tname = D['translated_name']
            if tname in native_parameters.keys():
                if callable(D['reverse_transform']):
                    standard_parameters[name] = D['reverse_transform'](**native_parameters)
                else:
                    try:
                        standard_parameters[name] = eval(
                            D['reverse_transform'], {}, native_parameters)
                    except NameError as err:
                        raise NameError(
                            f"Problem translating '{name}' in {cls.__name__}. "
                            f"Transform: '{D['reverse_transform']}'. "
                            f"Parameters: {native_parameters}. {err}"
                        )
        return ParameterSpace(standard_parameters,
                              schema=self.get_schema(),
                              shape=native_parameters.shape)
    def simple_parameters(self):
        """Return a list of parameters for which there is a one-to-one
        correspondance between standard and native parameter values."""
        return [name for name in self.translations
                if self.translations[name]['type'] == "simple"]
    def scaled_parameters(self):
        """Return a list of parameters for which there is a unit change between
        standard and native parameter values."""
        return [name for name in self.translations
                if self.translations[name]['type'] == "scaled"]
    def computed_parameters(self):
        """Return a list of parameters whose values must be computed from
        more than one other parameter."""
        return [name for name in self.translations
                if self.translations[name]['type'] == "computed"]
    def computed_parameters_include(self, parameter_names):
        return any(name in self.computed_parameters() for name in parameter_names)
    def get_native_names(self, *names):
        """
        Return a list of native parameter names for a given model.
        """
        if names:
            translations = (self.translations[name] for name in names)
        else:  # return all names
            translations = self.translations.values()
        return [D['translated_name'] for D in translations]
[docs]class StandardCellType(StandardModelType, models.BaseCellType):
    """Base class for standardized cell model classes."""
    recordable = ['spikes', 'v', 'gsyn']
    receptor_types = ('excitatory', 'inhibitory')
    always_local = False  # override for NEST spike sources 
class StandardCellTypeComponent(StandardModelType, models.BaseCellTypeComponent):
    """docstring needed"""
    pass
class StandardIonChannelModel(StandardModelType, models.BaseIonChannelModel):
    """Base class for standardized ion channel models."""
    def get_schema(self):
        return {
            "conductance_density": IonChannelDistribution,
            "e_rev": float
        }
class StandardPostSynapticResponse(StandardModelType, models.BasePostSynapticResponse):
    """Base class for standardized post-synaptic receptor models."""
    def get_schema(self):
        return {
            "locations": SynapseDistribution,
            "e_syn": float,
            "tau_syn": float  # should be a tuple, if multiple time constants
        }
    def set_parent(self, parent):
        """
        """
        self.parent = parent
class StandardCurrentSource(StandardModelType, models.BaseCurrentSource):
    """Base class for standardized current source model classes."""
    def inject_into(self, cells):
        """
        Inject the current from this source into the supplied group of cells.
        `cells` may be a :class:`Population`, :class:`PopulationView`,
        :class:`Assembly` or a list of :class:`ID` objects.
        """
        raise NotImplementedError("Should be redefined in the local simulator electrodes")
    def __getattr__(self, name):
        if name == "set":
            err_msg = "For current sources, set values using the parameter name directly, " \
                      "e.g. source.amplitude = 0.5, or use 'set_parameters()' " \
                      "e.g. source.set_parameters(amplitude=0.5)"
            raise AttributeError(err_msg)
        try:
            val = self.get_parameters()[name]
        except KeyError:
            try:
                val = self.__getattribute__(name)
            except AttributeError:
                raise errors.NonExistentParameterError(name,
                                                       self.__class__.__name__,
                                                       self.get_parameter_names())
        return val
    def __setattr__(self, name, value):
        if self.has_parameter(name):
            self.set_parameters(**{name: value})
        else:
            object.__setattr__(self, name, value)
    def set_parameters(self, copy=True, **parameters):
        """
        Set current source parameters, given as a sequence of parameter=value arguments.
        """
        # if some of the parameters are computed from the values of other
        # parameters, need to get and translate all parameters
        if self.computed_parameters_include(parameters):
            all_parameters = self.get_parameters()
            all_parameters.update(parameters)
            parameters = all_parameters
        else:
            parameters = ParameterSpace(parameters, self.get_schema(), (1,))
        parameters = self.translate(parameters, copy=copy)
        self.set_native_parameters(parameters)
    def get_parameters(self):
        """Return a dict of all current source parameters."""
        parameters = self.get_native_parameters()
        parameters = self.reverse_translate(parameters)
        return parameters
    def set_native_parameters(self, parameters):
        raise NotImplementedError
    def get_native_parameters(self):
        raise NotImplementedError
    def _round_timestamp(self, value, resolution):
        # todo: consider using decimals module,
        # since rounding of floating point numbers is so horrible
        return np.rint(value / resolution) * resolution
    def get_data(self):
        """Return the recorded current as a Neo signal object"""
        t_arr, i_arr = self._get_data()
        intervals = np.diff(t_arr)
        if intervals.size > 0 and intervals.max() - intervals.min() < 1e-9:
            signal = neo.AnalogSignal(i_arr, units="nA", t_start=t_arr[0] * pq.ms,
                                      sampling_period=intervals[0] * pq.ms)
        else:
            signal = neo.IrregularlySampledSignal(t_arr, i_arr, units="nA", time_units="ms")
        return signal
class ModelNotAvailable(object):
    """Not available for this simulator."""
    def __init__(self, *args, **kwargs):
        raise NotImplementedError(
            f"The {self.__class__.__name__} model is not available for this simulator.")
# ==============================================================================
#   Synapse Dynamics classes
# ==============================================================================
def check_weights(weights, projection):
    # if projection.post is an Assembly, some components might have cond-synapses, others curr,
    # so need a more sophisticated check here. For now, skipping check and emitting a warning
    if (
        hasattr(projection.post, "_homogeneous_synapses")
        and not projection.post._homogeneous_synapses     # noqa: W503
    ):
        warnings.warn("Not checking weights due to due mixture of synapse types")
    if isinstance(weights, np.ndarray):
        all_negative = (weights <= 0).all()
        all_positive = (weights >= 0).all()
        if not (all_negative or all_positive):
            raise errors.ConnectionError("Weights must be either all positive or all negative")
    elif np.isreal(weights):
        all_positive = weights >= 0
        all_negative = weights <= 0
    else:
        raise errors.ConnectionError("Weights must be a number or an array of numbers.")
    if projection.post.conductance_based or projection.receptor_type in excitatory_receptor_types:
        if not all_positive:
            raise errors.ConnectionError(
                "Weights must be positive for conductance-based and/or excitatory synapses"
            )
    elif (
        projection.post.conductance_based is False
        and projection.receptor_type in inhibitory_receptor_types  # noqa: W503
    ):
        if not all_negative:
            raise errors.ConnectionError(
                "Weights must be negative for current-based, inhibitory synapses"
            )
    else:
        # This can happen for multi-synapse models
        # if the receptor_type is not one of the commonly used ones
        warnings.warn("Can't check weight, conductance status unknown.")
def check_delays(delays, projection):
    min_delay = projection._simulator.state.min_delay
    max_delay = projection._simulator.state.max_delay
    if isinstance(delays, np.ndarray):
        below_max = (delays <= max_delay).all()
        above_min = (delays >= min_delay).all()
        in_range = below_max and above_min
    elif np.isreal(delays):
        in_range = min_delay <= delays <= max_delay
    else:
        raise errors.ConnectionError("Delays must be a number or an array of numbers.")
    if not in_range:
        raise errors.ConnectionError(
            f"Delay ({delays}) is out of range [{min_delay}, {max_delay}]")
[docs]class StandardSynapseType(StandardModelType, models.BaseSynapseType):
    parameter_checks = {
        'weight': check_weights,
        # 'delay': check_delays  # this needs to be revisited in the context of min_delay = "auto"
    }
[docs]    def get_schema(self):
        """
        Returns the model schema: i.e. a mapping of parameter names to allowed
        parameter types.
        """
        base_schema = dict((name, type(value))
                           for name, value in self.default_parameters.items())
        base_schema['delay'] = float
        # delay has default value None, meaning "use the minimum delay",
        # so we have to correct the auto-generated schema
        return base_schema  
[docs]class STDPWeightDependence(StandardModelType):
    """Base class for models of STDP weight dependence."""
    def __init__(self, **parameters):
        StandardModelType.__init__(self, **parameters) 
[docs]class STDPTimingDependence(StandardModelType):
    """Base class for models of STDP timing dependence (triplets, etc)"""
    def __init__(self, **parameters):
        StandardModelType.__init__(self, **parameters)