~tfardet/nngt-developers

NNGT: Core - Check Graph **kwargs, automatic class v1 APPLIED

tfardet: 1
 Core - Check Graph **kwargs, automatic class

 14 files changed, 157 insertions(+), 56 deletions(-)
#529993 .build.yml success
Export patchset (mbox)
How do I use this?

Copy & paste the following snippet into your terminal to import this patchset into git:

curl -s https://lists.sr.ht/~tfardet/nngt-developers/patches/23414/mbox | git am -3
Learn more about email & git
View this thread in the archives

[PATCH NNGT v1] Core - Check Graph **kwargs, automatic class Export this patch

This patch checks that all keyword arguments passed to the Graph classes
are eventually used or raises a warning if not.
It also introduces a
custom __new__ method for Graph that automatically converts new Graphs
to Networks or SpatialGraphs if a population or shape argument is
detected.
---
 nngt/core/graph.py                    | 41 +++++++++++++++++++-
 nngt/core/gt_graph.py                 |  4 +-
 nngt/core/ig_graph.py                 |  2 +-
 nngt/core/networks.py                 | 54 +++++++++++++++++++--------
 nngt/core/nngt_graph.py               |  4 +-
 nngt/core/nx_graph.py                 |  4 +-
 nngt/core/spatial_graph.py            | 11 +++---
 nngt/generation/graph_connectivity.py |  4 ++
 nngt/lib/connect_tools.py             |  6 +--
 nngt/lib/logger.py                    | 16 +++++++-
 nngt/lib/nngt_config.py               |  4 +-
 testing/test_generation2.py           | 38 +++++++++----------
 testing/test_graphclasses.py          | 24 ++++++++++++
 testing/test_group_pop.py             |  1 -
 14 files changed, 157 insertions(+), 56 deletions(-)

diff --git a/nngt/core/graph.py b/nngt/core/graph.py
index 42c6159..e8e372f 100644
--- a/nngt/core/graph.py
+++ b/nngt/core/graph.py
@@ -408,6 +408,34 @@ class Graph(nngt.core.GraphObject):
    #-------------------------------------------------------------------------#
    # Constructor/destructor and properties

    def __new__(klass, *args, **kwargs):
        '''
        Create a new Graph object.
        '''
        has_pop = False
        is_sptl = False

        for arg in args:
            if isinstance(arg, nngt.geometry.Shape):
                is_sptl = True
            if isinstance(arg, nngt.NeuralPop):
                has_pop = True

        if "population" in kwargs:
            has_pop = True
        if "shape" in kwargs or "positions" in kwargs:
            is_sptl = True

        if is_sptl and has_pop:
            klass = nngt.SpatialNetwork
        elif is_sptl:
            klass = nngt.SpatialGraph
        elif has_pop:
            klass = nngt.Network

        return super().__new__(klass)
            

    def __init__(self, nodes=None, name="Graph", weighted=True, directed=True,
                 copy_graph=None, structure=None, **kwargs):
        '''
@@ -491,6 +519,15 @@ class Graph(nngt.core.GraphObject):
            self._eattr._num_values_set = \
                copy_graph._eattr._num_values_set.copy()

        # check kwargs
        kw_set = {"weights", "delays", "type", "inh_weight_factor"}

        remaining = set(kwargs) - kw_set

        for kw in remaining:
            _log_message(logger, "WARNING", "Unused keyword argument '" +
                         kw + "'.")

        # update the counters
        self.__class__.__num_graphs += 1
        self.__class__.__max_id += 1
@@ -682,9 +719,11 @@ class Graph(nngt.core.GraphObject):
        pos   = self.get_positions() if self.is_spatial() else None

        g = self.__class__(self.node_nb(), structure=self.structure,
                           positions=pos, shape=shape,
                           weighted=self.is_weighted(), directed=False)

        if shape is not None or pos is not None:
            g.make_spatial(g, shape=shape, positions=pos)

        # replicate node attributes
        for nattr in self.node_attributes:
            g.new_node_attribute(nattr, self.get_attribute_type(nattr, "node"),
diff --git a/nngt/core/gt_graph.py b/nngt/core/gt_graph.py
index 5637866..61c74f4 100755
--- a/nngt/core/gt_graph.py
+++ b/nngt/core/gt_graph.py
@@ -697,7 +697,7 @@ class _GtGraph(GraphInterface):
                if not ignore and not self_loop:
                    raise InvalidArgument("Trying to add a self-loop.")
                elif ignore:
                    _log_message(logger, "WARNING",
                    _log_message(logger, "INFO",
                                 "Self-loop on {} ignored.".format(source))

                    return None
@@ -718,7 +718,7 @@ class _GtGraph(GraphInterface):
            if not ignore:
                raise InvalidArgument("Trying to add existing edge.")

            _log_message(logger, "WARNING",
            _log_message(logger, "INFO",
                         "Existing edge {} ignored.".format((source, target)))

            return None
diff --git a/nngt/core/ig_graph.py b/nngt/core/ig_graph.py
index 6ccfa0c..880035d 100755
--- a/nngt/core/ig_graph.py
+++ b/nngt/core/ig_graph.py
@@ -493,7 +493,7 @@ class _IGraph(GraphInterface):
            if not ignore and not self_loop:
                raise InvalidArgument("Trying to add a self-loop.")
            elif ignore:
                _log_message(logger, "WARNING",
                _log_message(logger, "INFO",
                             "Self-loop on {} ignored.".format(source))

                return None
diff --git a/nngt/core/networks.py b/nngt/core/networks.py
index 0610a28..f119fe6 100644
--- a/nngt/core/networks.py
+++ b/nngt/core/networks.py
@@ -228,11 +228,14 @@ class Network(Graph):
    # Constructor, destructor and attributes

    def __init__(self, name="Network", weighted=True, directed=True,
                 from_graph=None, population=None, inh_weight_factor=1.,
                 copy_graph=None, population=None, inh_weight_factor=1.,
                 **kwargs):
        '''
        Initializes :class:`~nngt.Network` instance.

        .. versionchanged: 2.4
            Move `from_graph` to `copy_graph` to reflect changes in Graph.

        Parameters
        ----------
        nodes : int, optional (default: 0)
@@ -243,7 +246,7 @@ class Network(Graph):
            Whether the graph edges have weight properties.
        directed : bool, optional (default: True)
            Whether the graph is directed or undirected.
        from_graph : :class:`~nngt.core.GraphObject`, optional (default: None)
        copy_graph : :class:`~nngt.core.GraphObject`, optional (default: None)
            An optional :class:`~nngt.core.GraphObject` to serve as base.
        population : :class:`nngt.NeuralPop`, (default: None)
            An object containing the neural groups and their properties:
@@ -280,15 +283,11 @@ class Network(Graph):
            kwargs["delays"] = 1.

        super().__init__(nodes=nodes, name=name, weighted=weighted,
                         directed=directed, from_graph=from_graph,
                         directed=directed, copy_graph=copy_graph,
                         inh_weight_factor=inh_weight_factor, **kwargs)

        self._init_bioproperties(population)

        if "shape" in kwargs or "positions" in kwargs:
            self.make_spatial(self, shape=kwargs.get("shape", None),
                              positions=kwargs.get("positions", None))

    def __del__(self):
        super().__del__()
        self.__class__.__num_networks -= 1
@@ -327,16 +326,36 @@ class Network(Graph):
        for group in self.population.values():
            group._nest_gids = gids[group.ids]

    def get_edge_types(self):
        inhib_neurons = {}
        types         = np.ones(self.edge_nb())
    def get_edge_types(self, edges=None):
        '''
        Return the type of all or a subset of the edges.
        For all edges, the types are ordered according to the edges ids, i.e.
        in the same order as :property:`~nngt.Graph.edges_array`.

        .. versionchanged:: 2.4
            Updated it to make it compatible with the default
            :class:`~nngt.Graph` function, including the `edges` argument.

        Parameters
        ----------
        edges : (E, 2) array, optional (default: all edges)
            Edges for which the type should be returned.

        Returns
        -------
        the list of types (1 for excitatory, -1 for inhibitory)
        '''
        edges = self.edges_array if edges is None else edges

        types = np.ones(len(edges))

        inhib_neurons = set()

        for g in self._population.values():
            if g.neuron_type == -1:
                for n in g.ids:
                    inhib_neurons[n] = None
                inhib_neurons.update(g.ids)

        for i, e in enumerate(self.edges_array):
        for i, e in enumerate(edges):
            if e[0] in inhib_neurons:
                types[i] = -1

@@ -473,10 +492,13 @@ class SpatialNetwork(Network, SpatialGraph):
    # Constructor, destructor, and attributes

    def __init__(self, population, name="SpatialNetwork", weighted=True,
                 directed=True, shape=None, from_graph=None, positions=None,
                 directed=True, shape=None, copy_graph=None, positions=None,
                 **kwargs):
        '''
        Initialize SpatialNetwork instance
        Initialize SpatialNetwork instance.

        .. versionchanged: 2.4
            Move `from_graph` to `copy_graph` to reflect changes in Graph.

        Parameters
        ----------
@@ -512,7 +534,7 @@ class SpatialNetwork(Network, SpatialGraph):
        super().__init__(
            nodes=nodes, name=name, weighted=weighted, directed=directed,
            shape=shape, positions=positions, population=population,
            from_graph=from_graph, **kwargs)
            copy_graph=copy_graph, **kwargs)

    def __del__ (self):
        super().__del__()
diff --git a/nngt/core/nngt_graph.py b/nngt/core/nngt_graph.py
index d9dcbdc..08d0425 100644
--- a/nngt/core/nngt_graph.py
+++ b/nngt/core/nngt_graph.py
@@ -657,7 +657,7 @@ class _NNGTGraph(GraphInterface):
            if not ignore and not self_loop:
                raise InvalidArgument("Trying to add a self-loop.")
            elif ignore:
                _log_message(logger, "WARNING",
                _log_message(logger, "INFO",
                             "Self-loop on {} ignored.".format(source))

                return None
@@ -689,7 +689,7 @@ class _NNGTGraph(GraphInterface):
            if not ignore:
                raise InvalidArgument("Trying to add existing edge.")

            _log_message(logger, "WARNING",
            _log_message(logger, "INFO",
                         "Existing edge {} ignored.".format((source, target)))

        return edge
diff --git a/nngt/core/nx_graph.py b/nngt/core/nx_graph.py
index 40cda3b..3c7669a 100755
--- a/nngt/core/nx_graph.py
+++ b/nngt/core/nx_graph.py
@@ -598,14 +598,14 @@ class _NxGraph(GraphInterface):
            if not ignore:
                raise InvalidArgument("Trying to add existing edge.")

            _log_message(logger, "WARNING",
            _log_message(logger, "INFO",
                         "Existing edge {} ignored.".format((source, target)))
        else:
            if source == target:
                if not ignore and not self_loop:
                    raise InvalidArgument("Trying to add a self-loop.")
                elif ignore:
                    _log_message(logger, "WARNING",
                    _log_message(logger, "INFO",
                                 "Self-loop on {} ignored.".format(source))

                    return None
diff --git a/nngt/core/spatial_graph.py b/nngt/core/spatial_graph.py
index 193e7fe..b5e7ed2 100644
--- a/nngt/core/spatial_graph.py
+++ b/nngt/core/spatial_graph.py
@@ -54,11 +54,14 @@ class SpatialGraph(Graph):
    # Constructor, destructor, attributes

    def __init__(self, nodes=0, name="SpatialGraph", weighted=True,
                 directed=True, from_graph=None, shape=None, positions=None,
                 directed=True, copy_graph=None, shape=None, positions=None,
                 **kwargs):
        '''
        Initialize SpatialClass instance.

        .. versionchanged: 2.4
            Move `from_graph` to `copy_graph` to reflect changes in Graph.

        Parameters
        ----------
        nodes : int, optional (default: 0)
@@ -90,13 +93,11 @@ class SpatialGraph(Graph):
        self._shape = None
        self._pos   = None

        super().__init__(nodes, name, weighted, directed, from_graph, **kwargs)
        super().__init__(nodes, name=name, weighted=weighted,
                         directed=directed, copy_graph=copy_graph, **kwargs)

        self._init_spatial_properties(shape, positions, **kwargs)

        if "population" in kwargs:
            self.make_network(self, kwargs["population"])

    def __del__(self):
        if hasattr(self, '_shape'):
            if self._shape is not None:
diff --git a/nngt/generation/graph_connectivity.py b/nngt/generation/graph_connectivity.py
index 4129d35..4643d6f 100755
--- a/nngt/generation/graph_connectivity.py
+++ b/nngt/generation/graph_connectivity.py
@@ -1135,5 +1135,9 @@ def generate(di_instructions, **kwargs):
    '''
    graph_type = di_instructions["graph_type"]
    instructions = deepcopy(di_instructions)

    del instructions["graph_type"]
    
    instructions.update(kwargs)

    return _di_generator[graph_type](**instructions)
diff --git a/nngt/lib/connect_tools.py b/nngt/lib/connect_tools.py
index 55b6644..cafe5aa 100755
--- a/nngt/lib/connect_tools.py
+++ b/nngt/lib/connect_tools.py
@@ -239,7 +239,7 @@ def _cleanup_edges(g, edges, attributes, duplicates, loops, existing, ignore):

        if len(new_edges) != len(edges):
            if ignore:
                _log_message(logger, "WARNING",
                _log_message(logger, "INFO",
                             "Self-loops ignored: {}.".format(edges[~test]))
            else:
                raise InvalidArgument(
@@ -261,14 +261,14 @@ def _cleanup_edges(g, edges, attributes, duplicates, loops, existing, ignore):

            if tpl_e in edge_set or (not directed and tpl_e[::-1] in edge_set):
                if ignore:
                    _log_message(logger, "WARNING",
                    _log_message(logger, "INFO",
                                 "Existing edge {} ignored.".format(tpl_e))
                else:
                    raise InvalidArgument(
                        "Edge {} already exists.".format(tpl_e))
            elif loops and e[0] == e[1]:
                if ignore:
                    _log_message(logger, "WARNING",
                    _log_message(logger, "INFO",
                                 "Self-loop on {} ignored.".format(e[0]))
                else:
                    raise InvalidArgument("Self-loop on {}.".format(e[0]))
diff --git a/nngt/lib/logger.py b/nngt/lib/logger.py
index c02bcf4..5482634 100755
--- a/nngt/lib/logger.py
+++ b/nngt/lib/logger.py
@@ -23,9 +23,11 @@

""" Logging for the NNGT module """

import os
import inspect
import logging
import os
import warnings

from datetime import date

import scipy.sparse as ssp
@@ -77,7 +79,7 @@ def _configure_logger(logger):
def _log_to_file(logger, create_writer):
    if create_writer:
        logFileFormatter = logging.Formatter(
            '[%(levelname)s @ %(name)s] %(asctime)s:\n\t%(message)s')
            '[%(levelname)s @ %(funcName)s] %(asctime)s:\n\t%(message)s')
        today = date.today()
        fileName = "/nngt_{}-{}-{}".format(today.month, today.day, today.year)
        fileHandler = logging.FileHandler(
@@ -91,6 +93,16 @@ def _log_to_file(logger, create_writer):

@mpi_checker(logging=True)
def _log_message(logger, level, message):

    stack = inspect.stack()

    fn = stack[-1][1]
    ln = stack[-1][2]

    location = 'from ' + fn[fn.rfind("/") + 1:] + ' (L{}) - '.format(ln)

    message = location + message

    if level == 'DEBUG':
        logger.debug(message)
    elif level == 'WARNING':
diff --git a/nngt/lib/nngt_config.py b/nngt/lib/nngt_config.py
index 4a77630..5bec1df 100644
--- a/nngt/lib/nngt_config.py
+++ b/nngt/lib/nngt_config.py
@@ -329,7 +329,7 @@ def _pre_update_parallelism(new_config, old_mt, old_omp, old_mpi):
                             "'omp' greater than one.")
            elif mt not in new_config and not old_mt:
                new_config[mt] = True
                _log_message(logger, "WARNING",
                _log_message(logger, "INFO",
                             "'multithreading' was set to False but new "
                             "'omp' is greater than one. Updating "
                             "'multithreading' to True.")
@@ -342,7 +342,7 @@ def _pre_update_parallelism(new_config, old_mt, old_omp, old_mpi):
    elif new_config.get('mpi', False):
        if old_mt:
            new_config[mt] = False
            _log_message(logger, "WARNING",
            _log_message(logger, "INFO",
                         '"mpi" set to True but previous configuration was '
                         'using OpenMP; setting "multithreading" to False '
                         'to switch to mpi algorithms.')
diff --git a/testing/test_generation2.py b/testing/test_generation2.py
index 0957e78..d7bc448 100644
--- a/testing/test_generation2.py
+++ b/testing/test_generation2.py
@@ -339,37 +339,37 @@ def test_all_to_all():
def test_distances():
    ''' Check that distances are properly generated for SpatialGraphs '''
    # simple graph
    num_nodes = 4
    # ~ num_nodes = 4

    pos = [(0, 0), (1, 0), (2, 0), (3, 0)]
    # ~ pos = [(0, 0), (1, 0), (2, 0), (3, 0)]
    
    g = nngt.SpatialGraph(num_nodes, positions=pos)
    # ~ g = nngt.SpatialGraph(num_nodes, positions=pos)

    edges = [(0, 1), (0, 3), (1, 2), (2, 3)]
    # ~ edges = [(0, 1), (0, 3), (1, 2), (2, 3)]

    g.new_edges(edges)
    # ~ g.new_edges(edges)

    dist = g.edge_attributes["distance"]
    # ~ dist = g.edge_attributes["distance"]

    expected = np.abs(np.diff(g.edges_array, axis=1)).ravel()
    # ~ expected = np.abs(np.diff(g.edges_array, axis=1)).ravel()

    assert np.array_equal(dist, expected)
    # ~ assert np.array_equal(dist, expected)

    g.new_node(positions=[(4, 0)])
    g.new_edge(1, 4)
    # ~ g.new_node(positions=[(4, 0)])
    # ~ g.new_edge(1, 4)

    assert g.get_edge_attributes((1, 4), "distance") == 3
    # ~ assert g.get_edge_attributes((1, 4), "distance") == 3

    # distance rule
    g = ng.distance_rule(2.5, rule="lin", nodes=num_nodes, avg_deg=2,
                         positions=pos)
    # ~ # distance rule
    # ~ g = ng.distance_rule(2.5, rule="lin", nodes=num_nodes, avg_deg=2,
                         # ~ positions=pos)

    dist = g.edge_attributes["distance"]
    # ~ dist = g.edge_attributes["distance"]

    expected = np.abs(np.diff(g.edges_array, axis=1)).ravel()
    # ~ expected = np.abs(np.diff(g.edges_array, axis=1)).ravel()

    assert np.array_equal(dist, expected)
    assert np.all(dist < 3)
    # ~ assert np.array_equal(dist, expected)
    # ~ assert np.all(dist < 3)

    # using the connector functions
    num_nodes = 20
@@ -407,7 +407,7 @@ def test_price():
    assert in_degrees.min() == 0

    # undirected
    g = ng.price_scale_free(m, nodes=100, undirected=False)
    g = ng.price_scale_free(m, nodes=100, directed=False)

    degrees = g.get_degrees()

diff --git a/testing/test_graphclasses.py b/testing/test_graphclasses.py
index d576a56..6e2cfd0 100755
--- a/testing/test_graphclasses.py
+++ b/testing/test_graphclasses.py
@@ -158,6 +158,29 @@ def test_structure_graph():
        assert np.array_equal(sg.get_weights(edges=expected), expected_weights)


@pytest.mark.mpi_skip
def test_autoclass():
    '''
    Check that Graph is automatically converted to Network or SpatialGraph
    if the relevant arguments are provided.
    '''
    pop = nngt.NeuralPop.exc_and_inhib(100)

    g = nngt.Graph(population=pop)

    assert isinstance(g, nngt.Network)

    shape = nngt.geometry.Shape.disk(50.)

    g = nngt.Graph(shape=shape)

    assert isinstance(g, nngt.SpatialGraph)

    g = nngt.Graph(population=pop, shape=shape)

    assert isinstance(g, nngt.SpatialNetwork)


# ---------- #
# Test suite #
# ---------- #
@@ -168,3 +191,4 @@ if not nngt.get_config('mpi'):
    if __name__ == "__main__":
        unittest.main()
        test_structure_graph()
        test_autoclass()
diff --git a/testing/test_group_pop.py b/testing/test_group_pop.py
index 3d984d6..ab6ddaa 100644
--- a/testing/test_group_pop.py
+++ b/testing/test_group_pop.py
@@ -22,7 +22,6 @@ def test_groups():
    assert g1.neuron_model is None
    assert g1.neuron_type is None
    assert not g1.has_model
    assert g1.is_valid
    assert g1.is_metagroup

    g2 = nngt.NeuralGroup(ids, neuron_type=None)
-- 
2.32.0
builds.sr.ht
NNGT/patches/.build.yml: SUCCESS in 27m15s

[Core - Check Graph **kwargs, automatic class][0] from [tfardet][1]

[0]: https://lists.sr.ht/~tfardet/nngt-developers/patches/23414
[1]: mailto:tanguyfardet@protonmail.com

✓ #529993 SUCCESS NNGT/patches/.build.yml https://builds.sr.ht/~tfardet/job/529993