From: Tanguy Fardet <tanguyfardet@protonmail.com>
---
.build.yml | 2 +-
doc/examples/basic_nest_network.py | 2 +-
doc/examples/random_balanced.py | 6 +++--
nngt/plot/chord_diag | 2 +-
nngt/plot/custom_plt.py | 40 +++++++++++++++++++++++++++---
nngt/plot/hive_helpers.py | 6 ++---
nngt/plot/plt_networks.py | 11 +++-----
testing/test_plots.py | 3 ---
8 files changed, 50 insertions(+), 22 deletions(-)
diff --git a/.build.yml b/.build.yml
index d89ecb3..0a2c4fb 100644
--- a/.build.yml
+++ b/.build.yml
@@ -32,7 +32,7 @@ tasks:
GL=nx coverage run -p -m pytest testing
GL=ig coverage run -p -m pytest testing
GL=nngt coverage run -p -m pytest testing
- GL=ig OMP=2 coverage run -p -m pytest -s testing
+ GL=ig OMP=2 PYNEST_QUIET=1 coverage run -p -m pytest testing
GL=gt OMP=0 MPI=1 mpirun -n 2 coverage run -p -m pytest --with-mpi testing
coverage combine
GIT_BRANCH=$(git show -s --pretty=%D HEAD | tr -s ', /' '\n' | grep -v HEAD | sed -n 2p)
diff --git a/doc/examples/basic_nest_network.py b/doc/examples/basic_nest_network.py
index 5acdcf4..2c69cd4 100644
--- a/doc/examples/basic_nest_network.py
+++ b/doc/examples/basic_nest_network.py
@@ -74,8 +74,8 @@ ng.connect_neural_types(net, -1, -1, "erdos_renyi", density=0.04)
# ------------------ #
if nngt.get_config('with_nest'):
- import nest
import nngt.simulation as ns
+ import nest
'''
Prepare the network and devices.
diff --git a/doc/examples/random_balanced.py b/doc/examples/random_balanced.py
index bf164cd..44255cf 100644
--- a/doc/examples/random_balanced.py
+++ b/doc/examples/random_balanced.py
@@ -155,13 +155,15 @@ Send the network to NEST, set noise and randomize parameters
'''
if nngt.get_config('with_nest'):
- import nest
import nngt.simulation as ns
from nngt.analysis import get_spikes
+ import nest
+
+    print_time = not os.environ.get("PYNEST_QUIET")  # quiet => no progress
nest.ResetKernel()
- nest.SetKernelStatus({"resolution": dt, "print_time": True,
+ nest.SetKernelStatus({"resolution": dt, "print_time": print_time,
"overwrite_files": True, 'local_num_threads': 4})
gids = net.to_nest()
diff --git a/nngt/plot/chord_diag b/nngt/plot/chord_diag
index 370fe8e..f40c3bf 160000
--- a/nngt/plot/chord_diag
+++ b/nngt/plot/chord_diag
@@ -1 +1 @@
-Subproject commit 370fe8e80234950baaa1c655ac4ff18af284d16d
+Subproject commit f40c3bf4d7e3bbaf99d5dc4a11e4f1a1fd3a569b
diff --git a/nngt/plot/custom_plt.py b/nngt/plot/custom_plt.py
index 4f00a94..567c89f 100755
--- a/nngt/plot/custom_plt.py
+++ b/nngt/plot/custom_plt.py
@@ -26,9 +26,10 @@
import itertools
import logging
+from pkg_resources import parse_version
+
import matplotlib as mpl
-import matplotlib.cm as cm
-import matplotlib.colors as clrs
+from matplotlib.colors import Colormap
from matplotlib.markers import MarkerStyle as MS
import nngt
@@ -43,20 +44,51 @@ logger = logging.getLogger(__name__)
with_seaborn = False
+
+def get_cmap(colormap, n=None):
+    '''
+    Get a colormap, optionally resampled to `n` colors.
+
+    Parameters
+    ----------
+    colormap : str, colormap, or None
+        Colormap to return; None selects ``rcParams["image.cmap"]``.
+    n : int, optional
+        Take `n` samples from the colormap.
+    '''
+    # `mpl.colormaps[None]` raises, unlike the old `cm.get_cmap(None)`
+    if colormap is None:
+        colormap = mpl.rcParams["image.cmap"]
+    if not isinstance(colormap, Colormap):
+        colormap = mpl.colormaps[colormap]
+
+    if n is None:
+        return colormap
+
+    # `resampled` is public from matplotlib 3.6.0 on, `_resample` before
+    # @TODO require matplotlib > 3.6.0 in 2024 or something
+    if parse_version(mpl.__version__) < parse_version("3.6.0"):
+        return colormap._resample(n)
+
+    return colormap.resampled(n)
+
+
def palette_continuous(numbers=None):
- pal = cm.get_cmap(nngt._config["palette_continuous"])
+ pal = get_cmap(nngt._config["palette_continuous"])
if numbers is None:
return pal
else:
return pal(numbers)
+
def palette_discrete(numbers=None):
- pal = cm.get_cmap(nngt._config["palette_discrete"])
+ pal = get_cmap(nngt._config["palette_discrete"])
if numbers is None:
return pal
else:
return pal(numbers)
+
# markers list
markers = [m for m in MS.filled_markers if m != '.']
diff --git a/nngt/plot/hive_helpers.py b/nngt/plot/hive_helpers.py
index 4302d5b..d4cdbfd 100755
--- a/nngt/plot/hive_helpers.py
+++ b/nngt/plot/hive_helpers.py
@@ -21,14 +21,14 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
-from matplotlib import cm
from matplotlib.colors import ColorConverter
-from matplotlib.path import Path
from matplotlib.patches import PathPatch
+from matplotlib.path import Path
import numpy as np
from ..lib.test_functions import nonstring_container
+from .custom_plt import get_cmap
from .chord_diag.gradient import linear_gradient
@@ -265,7 +265,7 @@ def _get_colors(axes_colors, edge_colors, angles, num_axes, intra_connections,
if axes_colors is None or isinstance(axes_colors, str):
named_cmap = "Set1" if axes_colors is None else axes_colors
- cmap = cm.get_cmap(named_cmap)
+ cmap = get_cmap(named_cmap)
values = list(range(num_axes))
diff --git a/nngt/plot/plt_networks.py b/nngt/plot/plt_networks.py
index b754766..7b4dd47 100755
--- a/nngt/plot/plt_networks.py
+++ b/nngt/plot/plt_networks.py
@@ -32,7 +32,6 @@ from matplotlib.artist import Artist
from matplotlib.path import Path
from matplotlib.patches import FancyArrowPatch, ArrowStyle, FancyArrow, Circle
from matplotlib.patches import Arc, RegularPolygon, PathPatch, Patch
-from matplotlib.cm import get_cmap
from matplotlib.collections import PatchCollection, PathCollection
from matplotlib.colors import ListedColormap, Normalize, ColorConverter
from matplotlib.markers import MarkerStyle
@@ -41,7 +40,8 @@ from mpl_toolkits.axes_grid1 import make_axes_locatable
import nngt
from nngt.lib import POS, nonstring_container, is_integer
-from .custom_plt import palette_continuous, palette_discrete, format_exponent
+from .custom_plt import (get_cmap, palette_continuous, palette_discrete,
+ format_exponent)
from .chord_diag import chord_diagram as _chord_diag
from .hive_helpers import *
@@ -1925,14 +1925,11 @@ def _discrete_cmap(N, base_cmap=None, discrete=False):
# Modified from Jake VanderPlas
# License: BSD-style
'''
- import matplotlib.pyplot as plt
- # Note that if base_cmap is a string or None, you can simply do
- # return plt.cm.get_cmap(base_cmap, N)
- # The following works for string, None, or a colormap instance:
- base = plt.cm.get_cmap(base_cmap, N)
+ base = get_cmap(base_cmap, N)
color_list = base(np.arange(N))
cmap_name = base.name + str(N)
+
try:
return base.from_list(cmap_name, color_list, N)
except:
diff --git a/testing/test_plots.py b/testing/test_plots.py
index caed66d..a5cdd42 100644
--- a/testing/test_plots.py
+++ b/testing/test_plots.py
@@ -17,9 +17,6 @@ import nngt.generation as ng
import nngt.plot as nplt
-nngt.use_backend("igraph")
-
-
# absolute directory path
dirpath = os.path.abspath(os.path.dirname(__file__))
--
2.34.4