~tfardet/nngt-developers

NNGT: Plot: node annotations and improved plots v3 APPLIED

~tfardet: 1
 Plot: node annotations and improved plots

 7 files changed, 804 insertions(+), 389 deletions(-)
#601471 .build.yml success
Export patchset (mbox)
How do I use this?

Copy & paste the following snippet into your terminal to import this patchset into git:

curl -s https://lists.sr.ht/~tfardet/nngt-developers/patches/25530/mbox | git am -3
Learn more about email & git

[PATCH NNGT v3] Plot: node annotations and improved plots Export this patch

From: Tanguy Fardet <tanguyfardet@protonmail.com>

Add node annotations.
Fix curved edges.
Improved draw_network options support.
Improved self-loops.
Prepare for igraph new matplotlib support.
Fix NaturalEarth resources download.
---
 .gitignore                                   |    3 +
 doc/examples/graph_structure/plot_layouts.py |   11 +-
 doc/examples/graph_structure/plot_map.py     |    2 +-
 nngt/geospatial/_cartopy_ne.py               |   26 +-
 nngt/plot/plt_networks.py                    | 1074 ++++++++++++------
 nngt/plot/plt_properties.py                  |    9 +-
 testing/test_plots.py                        |   68 +-
 7 files changed, 804 insertions(+), 389 deletions(-)

diff --git a/.gitignore b/.gitignore
index 8b07cc4..e7a6512 100755
--- a/.gitignore
+++ b/.gitignore
@@ -38,3 +38,6 @@ nngt.egg-info/

# visualcode
*.vscode

# KDevelop
*.kdev4
diff --git a/doc/examples/graph_structure/plot_layouts.py b/doc/examples/graph_structure/plot_layouts.py
index d6cb37d..0d84d02 100644
--- a/doc/examples/graph_structure/plot_layouts.py
+++ b/doc/examples/graph_structure/plot_layouts.py
@@ -118,14 +118,17 @@ c2 = nngt.geometry.Shape.disk(100, centroid=(50, 0))

shape = nngt.geometry.Shape.from_polygon(c1.union(c2))

npos  = shape.seed_neurons(num_nodes)
max_nsize = 15
npos  = shape.seed_neurons(num_nodes, soma_radius=0.5*max_nsize)

g = nngt.generation.distance_rule(10, shape=shape, nodes=num_nodes, avg_deg=5)
g = nngt.generation.distance_rule(10, shape=shape, nodes=num_nodes, avg_deg=5,
                                  positions=npos)

cc = nngt.analysis.local_clustering(g)

nngt.plot.draw_network(g, ncolor=cc, axis=axes[1, 1], ecolor="lightgrey",
                       tight=False, show=False)
nngt.plot.draw_network(g, ncolor=cc, axis=axes[1, 1], ecolor="grey", show=False,
                       eborder_width=0.5, eborder_color="w", esize=10,
                       max_nsize=max_nsize, tight=False)

axes[1, 1].set_title("Spatial layout")

diff --git a/doc/examples/graph_structure/plot_map.py b/doc/examples/graph_structure/plot_map.py
index c0eb053..f5a2ac2 100644
--- a/doc/examples/graph_structure/plot_map.py
+++ b/doc/examples/graph_structure/plot_map.py
@@ -60,7 +60,7 @@ g.set_weights(nngt._rng.exponential(2, g.edge_nb()))

# plot using draw_map and the A3 codes stored in "code"
ng.draw_map(g, "code", ncolor="in-degree", esize="weight", threshold=0,
            ecolor="grey", proj=ccrs.EqualEarth(), show=False)
            ecolor="grey", proj=ccrs.EqualEarth(), max_nsize=20, show=False)

plt.tight_layout()
plt.show()
diff --git a/nngt/geospatial/_cartopy_ne.py b/nngt/geospatial/_cartopy_ne.py
index d917910..983db03 100644
--- a/nngt/geospatial/_cartopy_ne.py
+++ b/nngt/geospatial/_cartopy_ne.py
@@ -53,6 +53,9 @@ class NEShpDownloader(Downloader):
    _NE_URL_TEMPLATE = ('https://naciscdn.org/naturalearth/{resolution}'
                        '/{category}/ne_{resolution}_{name}.zip')

    _NE_URL_BACKUP = ('https://naturalearth.s3.amazonaws.com/{resolution}'
                      '_{category}/ne_{resolution}_{name}.zip')

    def __init__(self, url_template=_NE_URL_TEMPLATE,
                 target_path_template=None, pre_downloaded_path_template=''):
        ''' adds some NE defaults to the __init__ of a Downloader'''
@@ -99,8 +102,8 @@ class NEShpDownloader(Downloader):

        return target_path

    @staticmethod
    def default_downloader():
    @classmethod
    def default_downloader(cls, backup=False):
        '''
        Return a generic, standard, NEShpDownloader instance.

@@ -115,10 +118,16 @@ ne_{resolution}_{name}.shp
        '''
        default_spec = ('shapefiles', 'natural_earth', '{category}',
                        'ne_{resolution}_{name}.shp')

        ne_path_template = os.path.join('{config[data_dir]}', *default_spec)

        pre_path_template = os.path.join('{config[pre_existing_data_dir]}',
                                         *default_spec)
        return NEShpDownloader(target_path_template=ne_path_template,

        url_template = cls._NE_URL_BACKUP if backup else cls._NE_URL_TEMPLATE

        return NEShpDownloader(url_template=url_template,
                               target_path_template=ne_path_template,
                               pre_downloaded_path_template=pre_path_template)

    
@@ -152,8 +161,15 @@ def natural_earth(resolution='110m', category='physical', name='coastline'):
    # get hold of the Downloader (typically a NEShpDownloader instance)
    # which we can then simply call its path method to get the appropriate
    # shapefile (it will download if necessary)
    ne_downloader = NEShpDownloader.default_downloader()
    format_dict = {'config': cartopy.config, 'category': category,
                   'name': name, 'resolution': resolution}
                    'name': name, 'resolution': resolution}

    try:
        ne_downloader = NEShpDownloader.default_downloader()
        return ne_downloader.path(format_dict)
    except:
        ne_downloader = NEShpDownloader.default_downloader(backup=True)

    return ne_downloader.path(format_dict)

                                         
diff --git a/nngt/plot/plt_networks.py b/nngt/plot/plt_networks.py
index e6b7d41..4adefed 100755
--- a/nngt/plot/plt_networks.py
+++ b/nngt/plot/plt_networks.py
@@ -28,8 +28,9 @@ import numpy as np

import matplotlib as mpl
from matplotlib.artist import Artist
from matplotlib.path import Path
from matplotlib.patches import FancyArrowPatch, ArrowStyle, FancyArrow, Circle
from matplotlib.patches import Arc, RegularPolygon, PathPatch
from matplotlib.patches import Arc, RegularPolygon, PathPatch, Patch
from matplotlib.cm import get_cmap
from matplotlib.collections import PatchCollection, PathCollection
from matplotlib.colors import ListedColormap, Normalize, ColorConverter
@@ -72,8 +73,7 @@ __all__ = ["chord_diagram", "draw_network", "hive_plot", "library_draw"]
# ------- #

def draw_network(network, nsize="total-degree", ncolor=None, nshape="o",
                 nborder_color="k", nborder_width=0.5, esize=None, ecolor="k",
                 ealpha=0.5, curved_edges=False, threshold=0.5,
                 esize=None, ecolor="k", curved_edges=False, threshold=0.5,
                 decimate_connections=None, spatial=True,
                 restrict_sources=None, restrict_targets=None,
                 restrict_nodes=None, restrict_edges=None,
@@ -115,6 +115,8 @@ def draw_network(network, nsize="total-degree", ncolor=None, nshape="o",
        Edge color. If ecolor="groups", edges color will depend on the source
        and target groups, i.e. only edges from and toward same groups will
        have the same color.
    curved_edges : bool, optional (default: False)
        Whether the edges should be curved or straight.
    threshold : float, optional (default: 0.5)
        Size under which edges are not plotted.
    decimate_connections : int, optional (default: keep all connections)
@@ -171,9 +173,17 @@ def draw_network(network, nsize="total-degree", ncolor=None, nshape="o",
        min_*             float               Minimum value for `nsize` or
                                              `esize`
        ----------------  ------------------  ---------------------------------
        nalpha            float               Node opacity in [0, 1]`
        nalpha            float               Node opacity in [0, 1]`, default 1
        ----------------  ------------------  ---------------------------------
        ealpha            float               Edge opacity, default 0.5
        ----------------  ------------------  ---------------------------------
                                              Color of the border for nodes (n)
        *border_color     color               or edges (e).
                                              Default to black.
        ----------------  ------------------  ---------------------------------
        ealpha            float               Edge opacity in [0, 1]`
                                              Border size for nodes (n) or edges
        *border_width     float               (e). Default to .5 for nodes and
                                              .3 for edges (if `fast` is False).
        ----------------  ------------------  ---------------------------------
                                              Whether to use simple nodes (that
        simple_nodes      bool                are always the same size) or
@@ -185,10 +195,14 @@ def draw_network(network, nsize="total-degree", ncolor=None, nshape="o",
    # figure and axes
    size_inches = (size[0]/float(dpi), size[1]/float(dpi))

    fig = None

    if axis is None:
        fig = plt.figure(facecolor='white', figsize=size_inches,
                         dpi=dpi)
        axis = fig.add_subplot(111, frameon=0, aspect=1)
    else:
        fig = axis.get_figure()

    # projections for geographic plots

@@ -247,31 +261,83 @@ def draw_network(network, nsize="total-degree", ncolor=None, nshape="o",
        adj_mat[:, remove] = 0

    edges = (np.array(adj_mat.nonzero()).T if restrict_edges is None else
             restrict_edges)
             np.asarray(restrict_edges))

    e = len(edges)

    # compute properties
    decimate_connections = 1 if decimate_connections is None\
                           else decimate_connections

    # get node and edge shape/size properties
    simple_nodes = kwargs.get("simple_nodes", False)
    # get positions (all cases except circular layout which is done below the
    # node sizes
    pos = None

    if esize is None:
        esize = 1 if fast else 5
    spatial *= network.is_spatial()

    max_nsize = kwargs.get("max_nsize", None)
    min_nsize = kwargs.get("min_nsize", None)
    if nonstring_container(layout):
        assert np.shape(layout) == (n, 2), "One position per node is required."
        pos = np.asarray(layout).astype(float)
        spatial = False
    elif spatial:
        if show_environment:
            nngt.geometry.plot.plot_shape(network.shape, axis=axis, show=False)

    max_esize = kwargs.get("max_esize", 2)
        nodes = None if restrict_nodes is None else list(restrict_nodes)

        pos = network.get_positions(nodes=nodes).astype(float)
    elif layout in (None, "random"):
        pos = np.random.uniform(size=(n, 2)) - 0.5

        pos[:, 0] *= size[0]
        pos[:, 1] *= size[1]
    elif layout not in ("circular", "random", None):
        raise ValueError("Unknown `layout`: {}".format(layout))

    # get node and edge size extrema and drawing properties
    simple_nodes = kwargs.get("simple_nodes", fast)

    dist = min(size)

    if pos is not None:
        dist = min(pos[:, 0].max() - pos[:, 0].min(),
                   pos[:, 1].max() - pos[:, 1].min())

    max_nsize = kwargs.get("max_nsize", 100 if simple_nodes else 0.05*dist)
    min_nsize = kwargs.get("min_nsize", 0.2*max_nsize)

    max_esize = kwargs.get("max_esize", 5 if fast else 0.05*dist)
    min_esize = kwargs.get("min_esize", 0)

    if fast:
        simple_nodes = True
        max_nsize *= 0.01*min(size)
        min_nsize *= 0.01*min(size)
        max_esize *= 0.005*min(size)
        min_esize *= 0.005*min(size)
        threshold *= 0.005*min(size)

    if esize is None:
        esize = 0.5*max_esize

    max_nsize = (20 if simple_nodes else 5) if max_nsize is None else max_nsize
    # circular layout
    if isinstance(layout, str) and layout == "circular":
        pos = _circular_layout(network, max_nsize)

    # check axis extent
    xmax = pos[:, 0].max()
    xmin = pos[:, 0].min()
    ymax = pos[:, 1].max()
    ymin = pos[:, 1].min()

    height = ymax - ymin
    width = xmax - xmin

    if not show_environment or not spatial or proj is not None:
        # axis.get_data()
        _set_ax_lim(axis, xmax, xmin, ymax, ymin, height, width, xlims, ylims,
                    max_nsize, fast)

    # get node and edge shape/size properties
    markers, nsize, esize = _node_edge_shape_size(
        network, nshape, nsize, max_nsize, min_nsize, esize, max_esize,
        min_esize, restrict_nodes, edges, size, threshold,
@@ -284,16 +350,23 @@ def draw_network(network, nsize="total-degree", ncolor=None, nshape="o",
        else:
            ncolor = "r"

    nborder_color = kwargs.get("nborder_color", "k")
    nborder_width = kwargs.get("nborder_width", 0.5)

    eborder_color = kwargs.get("eborder_color", "k")
    eborder_width = kwargs.get("eborder_width", 0.3)

    discrete_colors, default_ncmap = _get_ncmap(network, ncolor)

    nalpha = kwargs.get("nalpha", 1)
    ealpha = kwargs.get("ealpha", 0.5)

    ncmap = get_cmap(kwargs.get("node_cmap", default_ncmap))

    node_color, nticks, ntickslabels, nlabel = _node_color(
        network, restrict_nodes, ncolor, discrete_colors=discrete_colors)

    if nonstring_container(ncolor):
    if nonstring_container(ncolor) and not len(ncolor) in (3, 4):
        assert len(ncolor) == n, "For color arrays, one " +\
            "color per node is required."
        ncolor = "custom"
@@ -303,75 +376,20 @@ def draw_network(network, nsize="total-degree", ncolor=None, nshape="o",
    if not nonstring_container(nborder_color):
        nborder_color = np.repeat(nborder_color, n)

    # check edge color
    group_based = False

    default_ecmap = (palette_discrete() if not nonstring_container(ncolor) and
                     ecolor == "group" else palette_continuous())

    if ecolor == "groups" or ecolor == "group":
        if not network.is_network():
            raise TypeError(
                "The graph must be a Network to use `ecolor='groups'`.")

        group_based = True
        ecolor      = {}

        for i, src in enumerate(network.population):
            if network.population[src].ids:
                idx1 = network.population[src].ids[0]
                for j, tgt in enumerate(network.population):
                    if network.population[tgt].ids:
                        idx2 = network.population[tgt].ids[0]
                        if src == tgt:
                            ecolor[(src, tgt)] = node_color[idx1]
                        else:
                            ecolor[(src, tgt)] = \
                                np.abs(0.8*node_color[idx1]
                                       - 0.2*node_color[idx2])
    elif not nonstring_container(ecolor):
        ecolor = np.repeat(ecolor, e)

    # draw
    pos = kwargs.get("positions", None)

    spatial *= (network.is_spatial() or pos is not None)

    if nonstring_container(layout):
        assert np.shape(layout) == (n, 2), "One position per node is required."
        pos = np.asarray(layout)
    elif layout == "circular":
        pos = _circular_layout(network, nsize)
    elif pos is None and spatial:
        if show_environment:
            nngt.geometry.plot.plot_shape(network.shape, axis=axis, show=False)

        nodes = None if restrict_nodes is None else list(restrict_nodes)

        pos = network.get_positions(nodes=nodes)
    else:
        pos = np.full((n, 2), np.NaN)

        pos[:, 0] = size[0]*(np.random.uniform(size=n)-0.5)
        pos[:, 1] = size[1]*(np.random.uniform(size=n)-0.5)

    pos = np.array(pos).astype(float)

    # prepare node colors
    if nonstring_container(c) and not isinstance(c[0], str):
    if nonstring_container(c) and not isinstance(c[0], (str, np.ndarray)):
        # make the colorbar for the nodes
        cmap = ncmap
        cnorm = None

        if colorbar:
            cnorm = None

            if discrete_colors:
                cmap = _discrete_cmap(len(nticks), ncmap, discrete_colors)
                cnorm = Normalize(nticks[0]-0.5, nticks[-1] + 0.5)
            else:
                cnorm = Normalize(np.min(c), np.max(c))
                c = cnorm(c)
        if discrete_colors:
            cmap = _discrete_cmap(len(nticks), ncmap, discrete_colors)
            cnorm = Normalize(nticks[0]-0.5, nticks[-1] + 0.5)
        else:
            cnorm = Normalize(np.min(c), np.max(c))
            c = cnorm(c)

        if colorbar:
            sm = plt.cm.ScalarMappable(cmap=cmap, norm=cnorm)

            if discrete_colors:
@@ -407,7 +425,37 @@ def draw_network(network, nsize="total-degree", ncolor=None, nshape="o",
            c = np.array(
                [ncmap((node_color - minc)/(np.max(node_color) - minc))]*n)

    # check edge color
    group_based = False

    default_ecmap = (palette_discrete() if not nonstring_container(ncolor) and
                     ecolor == "group" else palette_continuous())

    if ecolor == "groups" or ecolor == "group":
        if network.structure is None:
            raise TypeError(
                "The graph must have a Structure/NeuralPop to use "
                "`ecolor='groups'`.")

        group_based = True
        ecolor = {}

        for i, src in enumerate(network.structure):
            if network.structure[src].ids:
                idx1 = network.structure[src].ids[0]
                for j, tgt in enumerate(network.structure):
                    if network.structure[tgt].ids:
                        idx2 = network.structure[tgt].ids[0]
                        if src == tgt:
                            ecolor[(src, tgt)] = c[idx1]
                        else:
                            ecolor[(src, tgt)] = 0.7*c[idx1] + 0.3*c[idx2]
    elif not nonstring_container(ecolor):
        ecolor = np.repeat(ecolor, e)

    # plot nodes
    scatter = []

    if simple_nodes:
        if nonstring_container(nshape):
            # matplotlib scatter does not support marker arrays
@@ -416,66 +464,79 @@ def draw_network(network, nsize="total-degree", ncolor=None, nshape="o",
                    ids = g.ids if restrict_nodes is None \
                          else list(set(g.ids).intersection(restrict_nodes))

                    axis.scatter(pos[ids, 0], pos[ids, 1], color=c[ids],
                                 s=0.5*np.array(nsize)[ids],
                                 marker=markers[ids[0]], zorder=2,
                                 edgecolors=nborder_color,
                                 linewidths=nborder_width, alpha=nalpha)
                    scatter.append(
                        axis.scatter(pos[ids, 0], pos[ids, 1], color=c[ids],
                                     s=0.5*np.array(nsize)[ids],
                                     marker=markers[ids[0]], zorder=2,
                                     edgecolors=nborder_color,
                                     linewidths=nborder_width, alpha=nalpha))
            else:
                ids = range(network.node_nb()) if restrict_nodes is None \
                      else restrict_nodes

                for i in ids:
                    axis.plot(
                    scatter.append(axis.scatter(
                        pos[i, 0], pos[i, 1], color=c[i], ms=0.5*nsize[i],
                        marker=nshape[i], ls="", zorder=2,
                        mec=nborder_color[i], mew=nborder_width, alpha=nalpha)
                        marker=nshape[i], zorder=2, mec=nborder_color[i],
                        mew=nborder_width, alpha=nalpha))
        else:
            axis.scatter(pos[:, 0], pos[:, 1], color=c, s=0.5*np.array(nsize),
                         marker=nshape, zorder=2, edgecolor=nborder_color,
                         linewidths=nborder_width, alpha=nalpha)
            scatter.append(axis.scatter(
                pos[:, 0], pos[:, 1], color=c, s=0.5*np.array(nsize),
                marker=nshape, zorder=2, edgecolor=nborder_color,
                linewidths=nborder_width, alpha=nalpha))
    else:
        nodes = []

        axis.set_aspect(1.)

        if network.is_network():
            for group in network.population.values():
                idx = group.ids if restrict_nodes is None \
                      else list(set(restrict_nodes).intersection(group.ids))
        if network.structure is not None:
            converter = None

            if restrict_nodes is not None:
                converter = {n: i for i, n in enumerate(restrict_nodes)}

            for group in network.structure.values():
                idx = group.ids

                if restrict_nodes is not None:
                    idx = [converter[n]
                           for n in set(restrict_nodes).intersection(idx)]

                for i, fc in zip(idx, c[idx]):
                    m = MarkerStyle(markers[i]).get_path()
                    center = np.average(m.vertices, axis=0)
                    m = Path(m.vertices - center, m.codes)
                    transform = Affine2D().scale(
                        0.5*nsize[i]).translate(pos[i][0], pos[i][1])
                    patch = PathPatch(m.transformed(transform), facecolor=fc,
                                      edgecolor=nborder_color[i], alpha=nalpha)
                    patch = PathPatch(
                        m.transformed(transform), facecolor=fc,
                        lw=nborder_width, edgecolor=nborder_color[i],
                        alpha=nalpha)
                    nodes.append(patch)
        else:
            for i, ci in enumerate(c):
                m = MarkerStyle(markers[i]).get_path()
                center = np.average(m.vertices, axis=0)
                m = Path(m.vertices - center, m.codes)
                transform = Affine2D().scale(0.5*nsize[i]).translate(
                    pos[i][0], pos[i][1])
                patch = PathPatch(m.transformed(transform), facecolor=ci,
                                  edgecolor=nborder_color[i], alpha=nalpha)
                    pos[i, 0], pos[i, 1])
                patch = PathPatch(
                    m.transformed(transform), facecolor=ci,
                    lw=nborder_width, edgecolor=nborder_color[i], alpha=nalpha)
                nodes.append(patch)

        nodes = PatchCollection(nodes, match_original=True, alpha=nalpha)
        nodes.set_zorder(2)
        axis.add_collection(nodes)
        scatter = PatchCollection(nodes, match_original=True, alpha=nalpha)
        scatter.set_zorder(2)
        axis.add_collection(scatter)

    if not show_environment or not spatial or proj is not None:
        # axis.get_data()
        _set_ax_lim(axis, pos[:, 0], pos[:, 1], xlims, ylims)
    # draw the edges
    arrows = []

    # use quiver to draw the edges
    if e and decimate_connections != -1:
        avg_size = np.average(nsize)

        arrows = []

        if group_based:
            for src_name, src_group in network.population.items():
                for tgt_name, tgt_group in network.population.items():
            for src_name, src_group in network.structure.items():
                for tgt_name, tgt_group in network.structure.items():
                    s_ids = src_group.ids

                    if restrict_sources is not None:
@@ -492,222 +553,188 @@ def draw_network(network, nsize="total-degree", ncolor=None, nshape="o",

                        edges = np.array(
                            adj_mat[s_min:s_max, t_min:t_max].nonzero(),
                            dtype=int)
                            dtype=int).T

                        edges[0, :] += s_min
                        edges[1, :] += t_min
                        edges[:, 0] += s_min
                        edges[:, 1] += t_min

                        strght_edges, self_loops, strght_sizes, loop_sizes = \
                            _split_edges_sizes(edges, esize,
                                               decimate_connections)

                        # plot
                        ec = default_ecmap(ecolor[(src_name, tgt_name)])
                        ec = ecolor[(src_name, tgt_name)]

                        if fast:
                            dl       = 0.5*np.max(nsize)
                            arrow_x  = pos[strght_edges[1], 0] - \
                                       pos[strght_edges[0], 0]
                        if len(strght_edges) and fast:
                            dl       = 0 if simple_nodes else 0.5*np.max(nsize)
                            arrow_x  = pos[strght_edges[:, 1], 0] - \
                                       pos[strght_edges[:, 0], 0]
                            arrow_x -= np.sign(arrow_x) * dl
                            arrow_y  = pos[strght_edges[1], 1] - \
                                       pos[strght_edges[0], 1]
                            arrow_y  = pos[strght_edges[:, 1], 1] - \
                                       pos[strght_edges[:, 0], 1]
                            arrow_x -= np.sign(arrow_y) * dl

                            axis.quiver(
                                pos[strght_edges[0], 0],
                                pos[strght_edges[0], 1], arrow_x,
                                pos[strght_edges[:, 0], 0],
                                pos[strght_edges[:, 0], 1], arrow_x,
                                arrow_y, scale_units='xy', angles='xy',
                                scale=1, alpha=ealpha, width=0.005*size[0],
                                linewidths=0.5*strght_sizes, edgecolors=ec,
                                zorder=1, **kw)

                            for i, s in enumerate(self_loops):
                                es = loop_sizes[i]
                                ec = loop_colors[i]
                                cs = _connectionstyle(axis, nsize[s],
                                                      loop_sizes[i])

                                arrow = FancyArrowPatch(
                                    pos[s], pos[s], arrowstyle=arrowstyle,
                                    shrinkA=nsize[s], color=ec,
                                    linewidth=es, connectionstyle=cs, zorder=1)

                                axis.add_patch(arrow)
                        else:
                            for s, t in zip(strght_edges[0], strght_edges[1]):
                                scale=1, alpha=ealpha,
                                width=3e-3, linewidths=0.5*strght_sizes,
                                edgecolors=ec, color=ec, zorder=1, **kw)
                        elif len(strght_edges):
                            for i, (s, t) in enumerate(strght_edges):
                                xs, ys = pos[s, 0], pos[s, 1]
                                xt, yt = pos[t, 0], pos[t, 1]
                                dl     = 0.5*nsize[t]
                                dx     = xt-xs
                                dx -= np.sign(dx) * dl
                                dy     = yt-ys
                                dy -= np.sign(dy) * dl

                                if curved_edges:
                                    arrow = FancyArrowPatch(
                                        posA=(xs, ys), posB=(xt, yt),
                                        arrowstyle=arrowstyle,
                                        connectionstyle='arc3,rad=0.1',
                                        alpha=ealpha, fc=ec, lw=0.5)
                                    axis.add_patch(arrow)
                                elif s == t:
                                    # self-loop
                                    arrow = FancyArrowPatch(
                                        posA=(xs, ys), posB=(xt, yt),
                                        arrowstyle=arrowstyle,
                                        connectionstyle='arc3,rad=1',
                                        alpha=ealpha, fc=ec, lw=0.5)

                                    axis.add_patch(arrow)
                                else:
                                    arrows.append(FancyArrow(
                                        xs, ys, dx, dy, width=0.3*avg_size,
                                        head_length=0.7*avg_size,
                                        head_width=0.7*avg_size,
                                        length_includes_head=True,
                                        alpha=ealpha, fc=ec, lw=0.5))
        else:
            if e and decimate_connections >= 1:
                strght_colors, loop_colors = [], []

                strght_edges, self_loops, strght_sizes, loop_sizes = \
                    _split_edges_sizes(edges, esize, decimate_connections,
                    ecolor, strght_colors, loop_colors)
                                sA = 0 if simple_nodes else 0.5*nsize[s]
                                sB = 0 if simple_nodes else 0.5*nsize[t]

                # keep only desired edges
                if None not in (restrict_sources, restrict_targets):
                    new_edges = []
                    new_colors = []
                                cs = 'arc3,rad=0.2' if curved_edges else None

                                astyle = ArrowStyle.Simple(
                                    head_length=0.7*strght_sizes[i],
                                    head_width=0.7*strght_sizes[i],
                                    tail_width=0.3*strght_sizes[i])

                    for edge, ec in zip(strght_edges, strght_colors):
                        s, t = edge
                                arrows.append(FancyArrowPatch(
                                    posA=(xs, ys), posB=(xt, yt),
                                    arrowstyle=astyle, connectionstyle=cs,
                                    alpha=ealpha, fc=ec, zorder=1,
                                    shrinkA=0.5*nsize[s], shrinkB=0.5*nsize[t],
                                    lw=eborder_width, ec=eborder_color))

                        if s in restrict_sources and t in restrict_targets:
                            new_edges.append(edge)
                            new_colors.append(ec)
                        for i, s in enumerate(self_loops):
                            loop = _plot_loop(
                                i, s, pos, loop_sizes, nsize, max_nsize, xmax,
                                xmin, ymax, ymin, height, width, ec, ealpha,
                                eborder_width, eborder_color, fast, network,
                                restrict_nodes)

                            axis.add_artist(loop)
        else:
            strght_colors, loop_colors = [], []

                    strght_edges = np.array(new_edges, dtype=int)
                    strght_colors = new_colors
            strght_edges, self_loops, strght_sizes, loop_sizes = \
                _split_edges_sizes(edges, esize, decimate_connections,
                ecolor, strght_colors, loop_colors)

                    if restrict_nodes is not None:
                        nodes = list(self_loops)
                        nodes.sort()
            # keep only desired edges
            if None not in (restrict_sources, restrict_targets):
                new_edges = []
                new_colors = []

                        new_loops = set()
                        new_colors = []
                for edge, ec in zip(strght_edges, strght_colors):
                    s, t = edge

                        for i, node in enumerate(restrict_nodes):
                            strght_edges[strght_edges == node] = i
                    if s in restrict_sources and t in restrict_targets:
                        new_edges.append(edge)
                        new_colors.append(ec)

                            if node in self_loops:
                                idx = nodes.index(node)
                                new_loops.add(i)
                                new_colors.append(loop_colors[idx])
                strght_edges = np.array(new_edges, dtype=int)
                strght_colors = new_colors

                        self_loops = new_loops
                        loop_colors = new_colors
                elif restrict_sources is not None:
                    new_edges = []
                if restrict_nodes is not None:
                    nodes = list(self_loops)
                    nodes.sort()

                    new_loops = set()
                    new_colors = []

                    for edge, ec in zip(strght_edges, strght_colors):
                        s, _ = edge
                    for i, node in enumerate(restrict_nodes):
                        strght_edges[strght_edges == node] = i

                        if s in restrict_sources:
                            new_edges.append(edge)
                            new_colors.append(ec)
                        if node in self_loops:
                            idx = nodes.index(node)
                            new_loops.add(i)
                            new_colors.append(loop_colors[idx])

                    strght_edges = np.array(new_edges, dtype=int)
                    self_loops = new_loops
                    loop_colors = new_colors
            elif restrict_sources is not None:
                new_edges = []
                new_colors = []

                    loop_colors = [ec for ec, n in zip(loop_colors, self_loops)
                                   if n in restrict_sources]
                    self_loops  = self_loops.intersection(restrict_sources)
                elif restrict_targets is not None:
                    new_edges = []
                    new_colors = []
                for edge, ec in zip(strght_edges, strght_colors):
                    s, _ = edge

                    if s in restrict_sources:
                        new_edges.append(edge)
                        new_colors.append(ec)

                strght_edges = np.array(new_edges, dtype=int)

                    for edge, ec in zip(strght_edges, strght_colors):
                        _, t = edge
                loop_colors = [ec for ec, n in zip(loop_colors, self_loops)
                                if n in restrict_sources]
                self_loops  = self_loops.intersection(restrict_sources)
            elif restrict_targets is not None:
                new_edges = []
                new_colors = []

                        if t in restrict_targets:
                            new_edges.append(edge)
                            new_colors.append(ec)
                for edge, ec in zip(strght_edges, strght_colors):
                    _, t = edge

                    strght_edges = np.array(new_edges, dtype=int)
                    if t in restrict_targets:
                        new_edges.append(edge)
                        new_colors.append(ec)

                    loop_colors = [ec for ec, n in zip(loop_colors, self_loops)
                                   if n in restrict_targets]
                    self_loops  = self_loops.intersection(restrict_targets)
                strght_edges = np.array(new_edges, dtype=int)

                loop_colors = [ec for ec, n in zip(loop_colors, self_loops)
                                if n in restrict_targets]
                self_loops  = self_loops.intersection(restrict_targets)

            if fast:
                if len(strght_edges):
                    dl = 0.5*np.max(nsize) if not simple_nodes else 0.

                    arrow_x  = pos[strght_edges[:, 1], 0] - \
                               pos[strght_edges[:, 0], 0]
                                pos[strght_edges[:, 0], 0]
                    arrow_x -= np.sign(arrow_x) * dl
                    arrow_y  = pos[strght_edges[:, 1], 1] - \
                               pos[strght_edges[:, 0], 1]
                                pos[strght_edges[:, 0], 1]
                    arrow_x -= np.sign(arrow_y) * dl

                    axis.quiver(
                        pos[strght_edges[:, 0], 0], pos[strght_edges[:, 0], 1],
                        arrow_x, arrow_y, scale_units='xy', angles='xy',
                        scale=1, alpha=ealpha, width=1.5e-3, headlength=0.02*size[0], headwidth=0.01*size[0],
                        linewidths=0.5*esize, ec=ecolor, fc=ecolor, zorder=1)

                for i, s in enumerate(self_loops):
                    es = loop_sizes[i]
                    ec = loop_colors[i]
                    cs = _connectionstyle(axis, nsize[s], loop_sizes[i])

                    arrow = FancyArrowPatch(
                        pos[s], pos[s], arrowstyle=arrowstyle,
                        shrinkA=nsize[s], color=ec, linewidth=es,
                        connectionstyle=cs, zorder=1)

                    axis.add_patch(arrow)
                        scale=1, alpha=ealpha, width=3e-3,
                        linewidths=0.5*strght_sizes, ec=ecolor, fc=ecolor,
                        zorder=1)
            else:
                if len(strght_edges):
                    for i, (s, t) in enumerate(strght_edges):
                        xs, ys = pos[s, 0], pos[s, 1]
                        xt, yt = pos[t, 0], pos[t, 1]

                        if curved_edges:
                            arrow = FancyArrowPatch(
                                posA=(xs, ys), posB=(xt, yt),
                                arrowstyle=arrowstyle,
                                connectionstyle='arc3,rad=0.1',
                                alpha=ealpha, fc=ecolor[i], lw=0)
                            axis.add_patch(arrow)
                        else:
                            dl     = 0.5*nsize[t]
                            dx     = xt-xs
                            dx -= np.sign(dx) * dl
                            dy     = yt-ys
                            dy -= np.sign(dy) * dl
                            arrows.append(FancyArrow(
                                xs, ys, dx, dy, width=0.3*strght_sizes[i],
                                head_length=0.7*strght_sizes[i],
                                head_width=0.7*strght_sizes[i],
                                length_includes_head=True, alpha=ealpha,
                                fc=strght_colors[i], lw=0))

                for i, s in enumerate(self_loops):
                    es = loop_sizes[i]
                    ec = loop_colors[i]

                    arrow = FancyArrowPatch(
                        pos[s], pos[s], arrowstyle=arrowstyle,
                        shrinkA=nsize[s], color=ec, linewidth=es,
                        connectionstyle='arc3,rad=1', zorder=1)

                    arrows.append(arrow)

        if not fast:
            arrows = PatchCollection(arrows, match_original=True, alpha=ealpha)
            arrows.set_zorder(1)
            axis.add_collection(arrows)
                        astyle = ArrowStyle.Simple(
                            head_length=0.7*strght_sizes[i],
                            head_width=0.7*strght_sizes[i],
                            tail_width=0.3*strght_sizes[i])

                        sA = 0 if simple_nodes else 0.5*nsize[s]
                        sB = 0 if simple_nodes else 0.5*nsize[t]

                        cs = 'arc3,rad=0.2' if curved_edges else None

                        arrows.append(FancyArrowPatch(
                            posA=(xs, ys), posB=(xt, yt), arrowstyle=astyle,
                            connectionstyle=cs, alpha=ealpha, fc=ecolor[i],
                            zorder=1, shrinkA=sA, shrinkB=sB, lw=eborder_width,
                            ec=eborder_color))

            for i, s in enumerate(self_loops):
                ec = loop_colors[i]
                loop = _plot_loop(
                    i, s, pos, loop_sizes, nsize, max_nsize, xmax, xmin,
                    ymax, ymin, height, width, ec, ealpha, eborder_width,
                    eborder_color, fast, network, restrict_nodes)

                axis.add_artist(loop)

    # add patch arrows
    arrows = PatchCollection(arrows, match_original=True, alpha=ealpha)
    arrows.set_zorder(1)
    axis.add_collection(arrows)

    if kwargs.get('tight', True):
        plt.tight_layout()
@@ -715,6 +742,82 @@ def draw_network(network, nsize="total-degree", ncolor=None, nshape="o",
            hspace=0., wspace=0., left=0., right=0.95 if colorbar else 1.,
            top=1., bottom=0.)

    # annotations
    annotations = kwargs.get("annotations",
        [str(i) for i in range(network.node_nb())] if restrict_nodes is None
        else [str(i) for i in restrict_nodes])

    if isinstance(annotations, str):
        assert annotations in network.node_attributes, \
            "String values for `annotations` must be a node attribute."

        if restrict_nodes is None:
            annotations = network.node_attributes[annotations]
        else:
            annotations = network.get_node_attributes(
                nodes=list(restrict_nodes), name=annotations)
    elif len(annotations) == network.node_nb() and restrict_nodes is not None:
        annotations = [annotations[i] for i in restrict_nodes]
    else:
        assert len(annotations) == n, "One annotation per node is required."

    annotate = kwargs.get("annotate", True)

    if annotate:
        annot = axis.annotate(
            "", xy=(0,0), xytext=(10,10), textcoords="offset points",
            bbox=dict(boxstyle="round", fc="w"),
            arrowprops=dict(arrowstyle="->"))

        annot.set_visible(False)

        def update_annot(ind):
            annot.xy = pos[ind["ind"][0]]
            text = "\n".join([annotations[n] for n in ind["ind"]])
            annot.set_text(text)
            annot.get_bbox_patch().set_facecolor("w")

        def hover(event):
            if hover.bg is None:
                # first run, save the current plot
                hover.bg = fig.canvas.copy_from_bbox(fig.bbox)

            vis = annot.get_visible()
            if event.inaxes == axis:
                if fast or simple_nodes:
                    for sc in scatter:
                        cont, ind = sc.contains(event)
                        if cont:
                            update_annot(ind)
                            fig.canvas.restore_region(hover.bg)
                            annot.set_visible(True)
                            axis.draw_artist(annot)
                            fig.canvas.blit(fig.bbox)
                        else:
                            if vis:
                                annot.set_visible(False)
                                fig.canvas.restore_region(hover.bg)
                                fig.canvas.blit(fig.bbox)
                else:
                    cont, ind = scatter.contains(event)
                    if cont:
                        update_annot(ind)
                        fig.canvas.restore_region(hover.bg)
                        annot.set_visible(True)
                        axis.draw_artist(annot)
                        fig.canvas.blit(fig.bbox)
                    else:
                        if vis:
                            annot.set_visible(False)
                            fig.canvas.restore_region(hover.bg)
                            fig.canvas.blit(fig.bbox)

                fig.canvas.flush_events()

        hover.bg = None

        fig.canvas.mpl_connect("motion_notify_event", hover)

    if show:
        plt.show()

@@ -1117,12 +1220,27 @@ def library_draw(network, nsize="total-degree", ncolor=None, nshape="o",
        ----------------  ------------------  ---------------------------------
        min_*             float               Minimum value for `nsize` or
                                              `esize`
        ----------------  ------------------  ---------------------------------
        annotate          bool                Use annotations to show node
                                              information (default: True)
        ----------------  ------------------  ---------------------------------
                                              Information that will be displayed
        annotations       str or list         such as a node attribute or a list
                                              of values. (default: node id)
        ================  ==================  =================================
    '''
    import matplotlib.pyplot as plt

    # backend and axis
    if nngt.get_config("backend") in ("graph-tool", "igraph"):
    try:
        import igraph
        igv = igraph.__version__
    except:
        igv = '1.0'

    ig_test = nngt.get_config("backend") == "igraph" and igv <= '0.9.6'

    if nngt.get_config("backend") == "graph-tool" or ig_test:
        mpl_backend = mpl.get_backend()

        if mpl_backend.startswith("Qt4"):
@@ -1319,6 +1437,7 @@ def library_draw(network, nsize="total-degree", ncolor=None, nshape="o",
            edge_cmap=palette_continuous(), cmap=ncmap,
            with_labels=show_labels, width=esize, edgecolors=nborder_color)
    elif nngt.get_config("backend") == "igraph":
        import igraph
        from igraph import Layout, PrecalculatedPalette

        pos = None
@@ -1327,6 +1446,8 @@ def library_draw(network, nsize="total-degree", ncolor=None, nshape="o",
            if isinstance(network, nngt.SpatialGraph) and spatial:
                xy  = network.get_positions()
                pos = Layout(xy)
            else:
                pos = network.graph.layout_fruchterman_reingold()
        elif layout == "circular":
            pos = network.graph.layout_circle()
        elif layout == "random":
@@ -1379,9 +1500,12 @@ def library_draw(network, nsize="total-degree", ncolor=None, nshape="o",
            eids  = [network.edge_id(e) for e in restrict_edges]
            graph = network.graph.subgraph_edges(eids, delete_vertices=False)

        graph_artist = GraphArtist(graph, axis, **visual_style)
        if igv > '0.9.6':
            igraph.plot(graph, target=axis, **visual_style)
        else:
            graph_artist = GraphArtist(graph, axis, **visual_style)

        axis.artists.append(graph_artist)
            axis.artists.append(graph_artist)

    if "title" in kwargs:
        axis.set_title(kwargs["title"])
@@ -1505,18 +1629,27 @@ def _node_edge_shape_size(network, nshape, nsize, max_nsize, min_nsize, esize,

            mm = cycle(MarkerStyle.filled_markers)

            shapes  = np.full(network.node_nb(), "", dtype=object)
            shapes  = np.full(n, "", dtype=object)

            for g, m in zip(nshape, mm):
                shapes[g.ids] = m
            if restrict_nodes is None:
                for g, m in zip(nshape, mm):
                    shapes[g.ids] = m
            else:
                converter = {n: i for i, n in enumerate(restrict_nodes)}
                for g, m in zip(nshape, mm):
                    ids = [converter[n]
                           for n in restrict_nodes.intersection(g.ids)]
                    shapes[ids] = m

            markers = list(shapes)
        elif len(nshape) != network.node_nb():
        if len(nshape) == network.node_nb() and restrict_nodes is not None:
            nshape = nshape[list(restrict_nodes)]
        elif len(nshape) != n:
            raise ValueError("When passing an array of markers to "
                             "`nshape`, one entry per node in the "
                             "network must be provided.")
    else:
        markers = [nshape for _ in range(network.node_nb())]
        markers = [nshape for _ in range(n)]

    # size
    if isinstance(nsize, str):
@@ -1537,34 +1670,31 @@ def _node_edge_shape_size(network, nshape, nsize, max_nsize, min_nsize, esize,
            raise ValueError("`nsize` must contain either one entry per node "
                             "or be the same length as `restrict_nodes`.")

    nsize *= 0.01 * size[0]

    if e:
        if isinstance(esize, str):
            esize  = _edge_size(network, edges, esize)
            esize = _edge_size(network, edges, esize)
            esize = _norm_size(esize, max_esize, min_esize)
            esize[esize < threshold] = 0.

        esize *= 0.005 * size[0]  # border on each side (so 0.5 %)
        else:
            esize = _norm_size(esize, max_esize, min_esize)
    else:
        esize = np.array([])

    return markers, nsize, esize


def _set_ax_lim(ax, xdata, ydata, xlims, ylims):
def _set_ax_lim(ax, xmax, xmin, ymax, ymin, height, width, xlims, ylims,
                max_nsize, fast):
    if xlims is not None:
        ax.set_xlim(*xlims)
    else:
        x_min, x_max = np.min(xdata), np.max(xdata)
        width = x_max - x_min
        ax.set_xlim(x_min - 0.05*width, x_max + 0.05*width)
        dx = 0.05*width if fast else 1.5*max_nsize
        ax.set_xlim(xmin - dx, xmax + dx)
    if ylims is not None:
        ax.set_ylim(*ylims)
    else:
        y_min, y_max = np.min(ydata), np.max(ydata)
        height = y_max - y_min
        ax.set_ylim(y_min - 0.05*height, y_max + 0.05*height)
        dy = 0.05*height if fast else 1.5*max_nsize
        ax.set_ylim(ymin - dy, ymax + dy)


def _node_size(network, restrict_nodes, nsize):
@@ -1663,7 +1793,7 @@ def _node_color(network, restrict_nodes, ncolor, discrete_colors=False):
    n = network.node_nb() if restrict_nodes is None else len(restrict_nodes)

    if restrict_nodes is not None:
        restrict_nodes = set(restrict_nodes)
        restrict_nodes = list(set(restrict_nodes))

    if isinstance(ncolor, float):
        color = np.repeat(ncolor, n)
@@ -1711,7 +1841,7 @@ def _node_color(network, restrict_nodes, ncolor, discrete_colors=False):
                    values = network.get_betweenness("node")
                else:
                    values = network.get_betweenness(
                        "node")[list(restrict_nodes)]
                        "node")[restrict_nodes]
            elif ncolor in network.node_attributes:
                values = network.get_node_attributes(
                    name=ncolor, nodes=restrict_nodes)
@@ -1723,7 +1853,7 @@ def _node_color(network, restrict_nodes, ncolor, discrete_colors=False):
                    values = nngt.analyze_graph[ncolor](network)
                else:
                    values = nngt.analyze_graph[ncolor](
                        network)[list(restrict_nodes)]
                        network)[restrict_nodes]
            else:
                raise RuntimeError("Invalid `ncolor`: {}.".format(ncolor))

@@ -1929,59 +2059,12 @@ def _to_gt_prop(graph, value, cmap, ptype='node', color=False):
            return pmap("vector<double>", vals=[cmap(v) for v in normalized])

        return pmap("double", vals=value)
        

    return value


class GraphArtist(Artist):
    """
    Matplotlib artist class that draws igraph graphs.

    Only Cairo-based backends are supported.

    Adapted from: https://stackoverflow.com/a/36154077/5962321
    """

    def __init__(self, graph, axis, palette=None, *args, **kwds):
        """Constructs a graph artist that draws the given graph within
        the given bounding box.

        `graph` must be an instance of `igraph.Graph`.
        `bbox` must either be an instance of `igraph.drawing.BoundingBox`
        or a 4-tuple (`left`, `top`, `width`, `height`). The tuple
        will be passed on to the constructor of `BoundingBox`.
        `palette` is an igraph palette that is used to transform
        numeric color IDs to RGB values. If `None`, a default grayscale
        palette is used from igraph.

        All the remaining positional and keyword arguments are passed
        on intact to `igraph.Graph.__plot__`.
        """
        from igraph import BoundingBox, palettes

        super().__init__()

        self.graph = graph
        self.palette = palette or palettes["gray"]
        self.bbox = BoundingBox(axis.bbox.bounds)
        self.args = args
        self.kwds = kwds

    def draw(self, renderer):
        from matplotlib.backends.backend_cairo import RendererCairo

        if not isinstance(renderer, RendererCairo):
            raise TypeError(
                "graph plotting is supported only on Cairo backends")

        self.graph.__plot__(renderer.gc.ctx, self.bbox, self.palette,
                            *self.args, **self.kwds)

    return value

def _circular_layout(graph, node_size):
    max_nsize = np.max(node_size)

def _circular_layout(graph, max_nsize):
    # chose radius such that r*dtheta > max_nsize
    dtheta = 2*np.pi / graph.node_nb()

@@ -1994,32 +2077,24 @@ def _circular_layout(graph, node_size):
    return np.array((x, y)).T



def _connectionstyle(axis, nsize, esize):
    def cs(posA, posB, *args, **kwargs):
        # check if we need to do a self-loop
        if np.all(posA == posB):
            # Self-loops are scaled by node size
            vshift = 0.1*max(nsize, 2*esize)
            hshift = 0.7*vshift
            # this is called with _screen space_ values so covert back
            # to data space
        # Self-loops are scaled by node size
        vshift = 0.1*max(nsize, 2*esize)
        hshift = 0.7*vshift
        # this is called with _screen space_ values so covert back
        # to data space

            s1 = np.asarray([-hshift, vshift])
            s2 = np.asarray([hshift, vshift])
        s1 = np.asarray([-hshift, vshift])
        s2 = np.asarray([hshift, vshift])

            p1 = axis.transData.inverted().transform(posA)
            p2 = axis.transData.inverted().transform(posA + s1)
            p3 = axis.transData.inverted().transform(posA + s2)
        p1 = axis.transData.inverted().transform(posA)
        p2 = axis.transData.inverted().transform(posA + s1)
        p3 = axis.transData.inverted().transform(posA + s2)

            path = [p1, p2, p3, p1]
        path = [p1, p2, p3, p1]

            ret = mpl.path.Path(axis.transData.transform(path), [1, 2, 2, 2])
        # if not, fall back to the user specified behavior
        else:
            return None

        return ret
        return mpl.path.Path(axis.transData.transform(path), [1, 2, 2, 2])

    return cs

@@ -2051,8 +2126,8 @@ def _split_edges_sizes(edges, esize, decimate_connections, ecolor=None,
        strght_sizes = esize[strght]
        loop_sizes = esize[loops]
    else:
        strght_sizes = [esize]*len(strght_edges)
        loop_sizes = [esize]*len(self_loops)
        strght_sizes = np.full(len(strght_edges), esize)
        loop_sizes = np.full(len(self_loops), esize)

    if decimate_connections > 1:
        strght_edges = \
@@ -2092,3 +2167,264 @@ def _get_ncmap(network, ncolor):
                    else palette_continuous()

    return discrete_colors, default_ncmap


def _plot_loop(i, s, pos, loop_sizes, nsize, max_nsize, xmax, xmin, ymax, ymin,
               height, width, ec, ealpha, eborder_width, eborder_color, fast,
               network, restrict_nodes):
    '''
    Draw self loops
    '''
    es = loop_sizes[i]
    dl = 0.03*max(height, width)
    ns = nsize[s]*dl/max_nsize if fast else nsize[s]

    # get the neighbours
    nn = network.neighbours(s)

    if restrict_nodes is not None:
        nn = nn.intersection(restrict_nodes)

        convert = {n: i for i, n in enumerate(restrict_nodes)}

        nn = {convert[n] for n in nn}

    nn = list(nn - {s})

    vec = pos[nn] - pos[s]
    norm = np.sqrt((vec*vec).sum(axis=1))
    vec = np.asarray([vec[i] / n for i, n in enumerate(norm)])

    dir = np.average(vec, axis=0)
    dir /= np.linalg.norm(dir)

    if fast:
        xy = pos[s] - ns*dir
        return Circle(xy, ns, fc="none", alpha=ealpha, linewidth=0.5*es, ec=ec)

    es = min(0.5*ns, es)
    xy = pos[s] - 0.75*ns*dir

    return Annulus(xy, 0.75*ns, 0.5*es, fc=ec, alpha=ealpha, lw=eborder_width,
                   ec=eborder_color)


class Annulus(Patch):
    """
    An elliptical annulus.
    """

    def __init__(self, xy, r, width, angle=0.0, **kwargs):
        """
        Parameters
        ----------
        xy : (float, float)
            xy coordinates of annulus centre.
        r : float or (float, float)
            The radius, or semi-axes:
            - If float: radius of the outer circle.
            - If two floats: semi-major and -minor axes of outer ellipse.
        width : float
            Width (thickness) of the annular ring. The width is measured inward
            from the outer ellipse so that for the inner ellipse the semi-axes
            are given by ``r - width``. *width* must be less than or equal to
            the semi-minor axis.
        angle : float, default: 0
            Rotation angle in degrees (anti-clockwise from the positive
            x-axis). Ignored for circular annuli (i.e., if *r* is a scalar).
        **kwargs
            Keyword arguments control the `Patch` properties:
            %(Patch:kwdoc)s
        """
        super().__init__(**kwargs)

        self.set_radii(r)
        self.center = xy
        self.width = width
        self.angle = angle
        self._path = None

    def __str__(self):
        if self.a == self.b:
            r = self.a
        else:
            r = (self.a, self.b)

        return "Annulus(xy=(%s, %s), r=%s, width=%s, angle=%s)" % \
                (*self.center, r, self.width, self.angle)

    def set_center(self, xy):
        """
        Set the center of the annulus.
        Parameters
        ----------
        xy : (float, float)
        """
        self._center = xy
        self._path = None
        self.stale = True

    def get_center(self):
        """Return the center of the annulus."""
        return self._center

    center = property(get_center, set_center)

    def set_width(self, width):
        """
        Set the width (thickness) of the annulus ring.
        The width is measured inwards from the outer ellipse.
        Parameters
        ----------
        width : float
        """
        if min(self.a, self.b) <= width:
            raise ValueError(
                'Width of annulus must be less than or equal semi-minor axis')

        self._width = width
        self._path = None
        self.stale = True

    def get_width(self):
        """Return the width (thickness) of the annulus ring."""
        return self._width

    width = property(get_width, set_width)

    def set_angle(self, angle):
        """
        Set the tilt angle of the annulus.
        Parameters
        ----------
        angle : float
        """
        self._angle = angle
        self._path = None
        self.stale = True

    def get_angle(self):
        """Return the angle of the annulus."""
        return self._angle

    angle = property(get_angle, set_angle)

    def set_semimajor(self, a):
        """
        Set the semi-major axis *a* of the annulus.
        Parameters
        ----------
        a : float
        """
        self.a = float(a)
        self._path = None
        self.stale = True

    def set_semiminor(self, b):
        """
        Set the semi-minor axis *b* of the annulus.
        Parameters
        ----------
        b : float
        """
        self.b = float(b)
        self._path = None
        self.stale = True

    def set_radii(self, r):
        """
        Set the semi-major (*a*) and semi-minor radii (*b*) of the annulus.
        Parameters
        ----------
        r : float or (float, float)
            The radius, or semi-axes:
            - If float: radius of the outer circle.
            - If two floats: semi-major and -minor axes of outer ellipse.
        """
        if np.shape(r) == (2,):
            self.a, self.b = r
        elif np.shape(r) == ():
            self.a = self.b = float(r)
        else:
            raise ValueError("Parameter 'r' must be one or two floats.")

        self._path = None
        self.stale = True

    def get_radii(self):
        """Return the semi-major and semi-minor radii of the annulus."""
        return self.a, self.b

    radii = property(get_radii, set_radii)

    def _transform_verts(self, verts, a, b):
        return Affine2D() \
            .scale(*self._convert_xy_units((a, b))) \
            .rotate_deg(self.angle) \
            .translate(*self._convert_xy_units(self.center)) \
            .transform(verts)

    def _recompute_path(self):
        # circular arc
        arc = Path.arc(0, 360)

        # annulus needs to draw an outer ring
        # followed by a reversed and scaled inner ring
        a, b, w = self.a, self.b, self.width
        v1 = self._transform_verts(arc.vertices, a, b)
        v2 = self._transform_verts(arc.vertices[::-1], a - w, b - w)
        v = np.vstack([v1, v2, v1[0, :], (0, 0)])
        c = np.hstack([arc.codes, Path.MOVETO,
                       arc.codes[1:], Path.MOVETO,
                       Path.CLOSEPOLY])
        self._path = Path(v, c)

    def get_path(self):
        if self._path is None:
            self._recompute_path()
        return self._path


class GraphArtist(Artist):
    """
    Matplotlib artist class that draws igraph graphs.

    Only Cairo-based backends are supported.

    Adapted from: https://stackoverflow.com/a/36154077/5962321
    """

    def __init__(self, graph, axis, palette=None, *args, **kwds):
        """Constructs a graph artist that draws the given graph within
        the given bounding box.

        `graph` must be an instance of `igraph.Graph`.
        `bbox` must either be an instance of `igraph.drawing.BoundingBox`
        or a 4-tuple (`left`, `top`, `width`, `height`). The tuple
        will be passed on to the constructor of `BoundingBox`.
        `palette` is an igraph palette that is used to transform
        numeric color IDs to RGB values. If `None`, a default grayscale
        palette is used from igraph.

        All the remaining positional and keyword arguments are passed
        on intact to `igraph.Graph.__plot__`.
        """
        from igraph import BoundingBox, palettes

        super().__init__()

        self.graph = graph
        self.palette = palette or palettes["gray"]
        self.bbox = BoundingBox(axis.bbox.bounds)
        self.args = args
        self.kwds = kwds

    def draw(self, renderer):
        from matplotlib.backends.backend_cairo import RendererCairo

        if not isinstance(renderer, RendererCairo):
            raise TypeError(
                "graph plotting is supported only on Cairo backends")

        self.graph.__plot__(renderer.gc.ctx, self.bbox, self.palette,
                            *self.args, **self.kwds)
diff --git a/nngt/plot/plt_properties.py b/nngt/plot/plt_properties.py
index b678990..775658f 100755
--- a/nngt/plot/plt_properties.py
+++ b/nngt/plot/plt_properties.py
@@ -31,6 +31,8 @@ from nngt.analysis import (degree_distrib, betweenness_distrib,
                           node_attributes, binning)
from .custom_plt import palette_continuous, palette_discrete, format_exponent

from matplotlib.gridspec import SubplotSpec


__all__ = [
    'degree_distribution',
@@ -38,7 +40,7 @@ __all__ = [
    'edge_attributes_distribution',
    'node_attributes_distribution',
    'compare_population_attributes',
    "correlation_to_attribute",
    'correlation_to_attribute',
]


@@ -1015,8 +1017,11 @@ def _set_new_plot(fignum=None, num_new_plots=1, names=None, sharex=None):
    if int(ratio) != int(np.ceil(ratio)):
        num_rows += 1
    # change the geometry
    gs = fig.add_gridspec(num_rows, num_cols)
    for i in range(num_axes - num_new_plots):
        fig.axes[i].change_geometry(num_rows, num_cols, i+1)
        y = i // num_cols
        x = i - num_cols*y
        fig.axes[i].set_subplotspec(gs[y, x])
    lst_new_axes = []
    n_old = num_axes-num_new_plots+1
    for i in range(num_new_plots):
diff --git a/testing/test_plots.py b/testing/test_plots.py
index 4603d97..080b08d 100644
--- a/testing/test_plots.py
+++ b/testing/test_plots.py
@@ -57,7 +57,7 @@ def test_plot_prop():

        nplt.node_attributes_distribution(
            net, ["betweenness", "attr", "out-degree"], colors=["r", "g", "b"],
            show=True)
            show=False)


@pytest.mark.mpi_skip
@@ -80,21 +80,22 @@ def test_draw_network_options():
    # restrict nodes

    nplt.draw_network(net, ncolor="g", nshape='s', ecolor="b",
                      restrict_targets=[1, 2, 3], show=False)
                      restrict_targets=[1, 2, 3], curved_edges=True, show=False)

    nplt.draw_network(net, restrict_nodes=list(range(10)), fast=True,
                      show=False)

    nplt.draw_network(net, restrict_targets=[4, 5, 6, 7, 8], show=False)

    nplt.draw_network(net, restrict_sources=[4, 5, 6, 7, 8], show=False)
    nplt.draw_network(net, restrict_sources=[4, 5, 6, 7, 8], simple_nodes=True,
                      show=False)

    # colors and sizes
    for fast in (True, False):
        maxns = 50 if fast else 10
        minns = 5 if fast else 1
        maxes = 2 if fast else 10
        mines = 0.2 if fast else 1
        maxns = 100 if fast else 20
        minns = 10 if fast else 2
        maxes = 2 if fast else 20
        mines = 0.2 if fast else 2

        nplt.draw_network(net, ncolor="r", nalpha=0.5, ecolor="#999999",
                          ealpha=0.5, nsize="in-degree", max_nsize=maxns,
@@ -113,11 +114,45 @@ def test_draw_network_options():
    nplt.draw_network(net, simple_nodes=True, ncolor="k",
                      decimate_connections=-1, axis=ax, show=False)

    nplt.draw_network(net, simple_nodes=True, ncolor="r", nsize=2,
    nplt.draw_network(net, simple_nodes=True, ncolor="r", nsize=20,
                      restrict_nodes=list(range(10)), esize='weight',
                      ecolor="b", fast=True, axis=ax, show=False)


@pytest.mark.mpi_skip
def test_group_plot():
    ''' Test plotting with a Network and group colors '''
    gsize = 5

    g1 = nngt.Group(gsize)
    g2 = nngt.Group(gsize)

    s = nngt.Structure.from_groups({"1": g1, "2": g2})

    positions = np.concatenate((
        nngt._rng.uniform(-5, -2, size=(gsize, 2)),
        nngt._rng.uniform(2, 5, size=(gsize, 2))))

    g = nngt.SpatialGraph(2*gsize, structure=s, positions=positions)

    nngt.generation.connect_groups(g, g1, g1, "erdos_renyi", edges=5)
    nngt.generation.connect_groups(g, g1, g2, "erdos_renyi", edges=5)
    nngt.generation.connect_groups(g, g2, g2, "erdos_renyi", edges=5)
    nngt.generation.connect_groups(g, g2, g1, "erdos_renyi", edges=5)

    g.new_edge(6, 6, self_loop=True)

    nplt.draw_network(g, ncolor="group", ecolor="group", show_environment=False,
                      fast=True, show=False)

    nplt.draw_network(g, ncolor="group", ecolor="group", max_nsize=0.4,
                      esize=0.3, show_environment=False, show=False)

    nplt.draw_network(g, ncolor="group", ecolor="group", max_nsize=0.4,
                      esize=0.3, show_environment=False, curved_edges=True,
                      show=False)


@pytest.mark.mpi_skip
def test_library_plot():
    ''' Check that plotting with the underlying backend library works '''
@@ -215,6 +250,21 @@ def test_plot_spatial_alpha():
                          nalpha=0.5, esize=0.1 + 3*fast, fast=fast)


@pytest.mark.mpi_skip
def test_annotations():
    num_nodes = 5
    positions = nngt._rng.uniform(-10, 10, (num_nodes, 2))

    g = nngt.generation.erdos_renyi(edges=10, nodes=num_nodes,
                                    positions=positions)

    g.new_node_attribute("name", "string", values=["a", "b", "c", "d", "e"])

    nplt.draw_network(g, annotate=False, show=False)
    nplt.draw_network(g, show=False)
    nplt.draw_network(g, annotations="name", show=False)


if __name__ == "__main__":
    test_plot_config()
    test_plot_prop()
@@ -222,3 +272,5 @@ if __name__ == "__main__":
    test_library_plot()
    test_hive_plot()
    test_plot_spatial_alpha()
    test_group_plot()
    test_annotations()
-- 
2.32.0
NNGT/patches/.build.yml: SUCCESS in 31m3s

[Plot: node annotations and improved plots][0] v3 from [~tfardet][1]

[0]: https://lists.sr.ht/~tfardet/nngt-developers/patches/25530
[1]: mailto:tanguyfardet@protonmail.com

✓ #601471 SUCCESS NNGT/patches/.build.yml https://builds.sr.ht/~tfardet/job/601471