~tfardet: 1 Plot: node annotations and improved plots 7 files changed, 804 insertions(+), 389 deletions(-)
Copy & paste the following snippet into your terminal to import this patchset into git:
curl -s https://lists.sr.ht/~tfardet/nngt-developers/patches/25530/mbox | git am -3Learn more about email & git
From: Tanguy Fardet <tanguyfardet@protonmail.com> Add node annotations. Fix curved edges. Improved draw_network options support. Improved self-loops. Prepare for igraph new matplotlib support. Fix NaturalEarth resources download. --- .gitignore | 3 + doc/examples/graph_structure/plot_layouts.py | 11 +- doc/examples/graph_structure/plot_map.py | 2 +- nngt/geospatial/_cartopy_ne.py | 26 +- nngt/plot/plt_networks.py | 1074 ++++++++++++------ nngt/plot/plt_properties.py | 9 +- testing/test_plots.py | 68 +- 7 files changed, 804 insertions(+), 389 deletions(-) diff --git a/.gitignore b/.gitignore index 8b07cc4..e7a6512 100755 --- a/.gitignore +++ b/.gitignore @@ -38,3 +38,6 @@ nngt.egg-info/ # visualcode *.vscode + +# KDevelop +*.kdev4 diff --git a/doc/examples/graph_structure/plot_layouts.py b/doc/examples/graph_structure/plot_layouts.py index d6cb37d..0d84d02 100644 --- a/doc/examples/graph_structure/plot_layouts.py +++ b/doc/examples/graph_structure/plot_layouts.py @@ -118,14 +118,17 @@ c2 = nngt.geometry.Shape.disk(100, centroid=(50, 0)) shape = nngt.geometry.Shape.from_polygon(c1.union(c2)) -npos = shape.seed_neurons(num_nodes) +max_nsize = 15 +npos = shape.seed_neurons(num_nodes, soma_radius=0.5*max_nsize) -g = nngt.generation.distance_rule(10, shape=shape, nodes=num_nodes, avg_deg=5) +g = nngt.generation.distance_rule(10, shape=shape, nodes=num_nodes, avg_deg=5, + positions=npos) cc = nngt.analysis.local_clustering(g) -nngt.plot.draw_network(g, ncolor=cc, axis=axes[1, 1], ecolor="lightgrey", - tight=False, show=False) +nngt.plot.draw_network(g, ncolor=cc, axis=axes[1, 1], ecolor="grey", show=False, + eborder_width=0.5, eborder_color="w", esize=10, + max_nsize=max_nsize, tight=False) axes[1, 1].set_title("Spatial layout") diff --git a/doc/examples/graph_structure/plot_map.py b/doc/examples/graph_structure/plot_map.py index c0eb053..f5a2ac2 100644 --- a/doc/examples/graph_structure/plot_map.py +++ b/doc/examples/graph_structure/plot_map.py @@ -60,7 +60,7 @@ g.set_weights(nngt._rng.exponential(2, g.edge_nb())) # plot using draw_map and the A3 codes stored in "code" ng.draw_map(g, "code", ncolor="in-degree", esize="weight", threshold=0, - ecolor="grey", proj=ccrs.EqualEarth(), show=False) + ecolor="grey", proj=ccrs.EqualEarth(), max_nsize=20, show=False) plt.tight_layout() plt.show() diff --git a/nngt/geospatial/_cartopy_ne.py b/nngt/geospatial/_cartopy_ne.py index d917910..983db03 100644 --- a/nngt/geospatial/_cartopy_ne.py +++ b/nngt/geospatial/_cartopy_ne.py @@ -53,6 +53,9 @@ class NEShpDownloader(Downloader): _NE_URL_TEMPLATE = ('https://naciscdn.org/naturalearth/{resolution}' '/{category}/ne_{resolution}_{name}.zip') + _NE_URL_BACKUP = ('https://naturalearth.s3.amazonaws.com/{resolution}' + '_{category}/ne_{resolution}_{name}.zip') + def __init__(self, url_template=_NE_URL_TEMPLATE, target_path_template=None, pre_downloaded_path_template=''): ''' adds some NE defaults to the __init__ of a Downloader''' @@ -99,8 +102,8 @@ class NEShpDownloader(Downloader): return target_path - @staticmethod - def default_downloader(): + @classmethod + def default_downloader(cls, backup=False): ''' Return a generic, standard, NEShpDownloader instance. @@ -115,10 +118,16 @@ ne_{resolution}_{name}.shp ''' default_spec = ('shapefiles', 'natural_earth', '{category}', 'ne_{resolution}_{name}.shp') + ne_path_template = os.path.join('{config[data_dir]}', *default_spec) + pre_path_template = os.path.join('{config[pre_existing_data_dir]}', *default_spec) - return NEShpDownloader(target_path_template=ne_path_template, + + url_template = cls._NE_URL_BACKUP if backup else cls._NE_URL_TEMPLATE + + return NEShpDownloader(url_template=url_template, + target_path_template=ne_path_template, pre_downloaded_path_template=pre_path_template) @@ -152,8 +161,15 @@ def natural_earth(resolution='110m', category='physical', name='coastline'): # get hold of the Downloader (typically a NEShpDownloader instance) # which we can then simply call its path method to get the appropriate # shapefile (it will download if necessary) - ne_downloader = NEShpDownloader.default_downloader() format_dict = {'config': cartopy.config, 'category': category, - 'name': name, 'resolution': resolution} + 'name': name, 'resolution': resolution} + + try: + ne_downloader = NEShpDownloader.default_downloader() + return ne_downloader.path(format_dict) + except: + ne_downloader = NEShpDownloader.default_downloader(backup=True) + return ne_downloader.path(format_dict) + diff --git a/nngt/plot/plt_networks.py b/nngt/plot/plt_networks.py index e6b7d41..4adefed 100755 --- a/nngt/plot/plt_networks.py +++ b/nngt/plot/plt_networks.py @@ -28,8 +28,9 @@ import numpy as np import matplotlib as mpl from matplotlib.artist import Artist +from matplotlib.path import Path from matplotlib.patches import FancyArrowPatch, ArrowStyle, FancyArrow, Circle -from matplotlib.patches import Arc, RegularPolygon, PathPatch +from matplotlib.patches import Arc, RegularPolygon, PathPatch, Patch from matplotlib.cm import get_cmap from matplotlib.collections import PatchCollection, PathCollection from matplotlib.colors import ListedColormap, Normalize, ColorConverter @@ -72,8 +73,7 @@ __all__ = ["chord_diagram", "draw_network", "hive_plot", "library_draw"] # ------- # def draw_network(network, nsize="total-degree", ncolor=None, nshape="o", - nborder_color="k", nborder_width=0.5, esize=None, ecolor="k", - ealpha=0.5, curved_edges=False, threshold=0.5, + esize=None, ecolor="k", curved_edges=False, threshold=0.5, decimate_connections=None, spatial=True, restrict_sources=None, restrict_targets=None, restrict_nodes=None, restrict_edges=None, @@ -115,6 +115,8 @@ def draw_network(network, nsize="total-degree", ncolor=None, nshape="o", Edge color. If ecolor="groups", edges color will depend on the source and target groups, i.e. only edges from and toward same groups will have the same color. + curved_edges : bool, optional (default: False) + Whether the edges should be curved or straight. threshold : float, optional (default: 0.5) Size under which edges are not plotted. decimate_connections : int, optional (default: keep all connections) @@ -171,9 +173,17 @@ def draw_network(network, nsize="total-degree", ncolor=None, nshape="o", min_* float Minimum value for `nsize` or `esize` ---------------- ------------------ --------------------------------- - nalpha float Node opacity in [0, 1]` + nalpha float Node opacity in [0, 1]`, default 1 + ---------------- ------------------ --------------------------------- + ealpha float Edge opacity, default 0.5 + ---------------- ------------------ --------------------------------- + Color of the border for nodes (n) + *border_color color or edges (e). + Default to black. ---------------- ------------------ --------------------------------- - ealpha float Edge opacity in [0, 1]` + Border size for nodes (n) or edges + *border_width float (e). Default to .5 for nodes and + .3 for edges (if `fast` is False). ---------------- ------------------ --------------------------------- Whether to use simple nodes (that simple_nodes bool are always the same size) or @@ -185,10 +195,14 @@ def draw_network(network, nsize="total-degree", ncolor=None, nshape="o", # figure and axes size_inches = (size[0]/float(dpi), size[1]/float(dpi)) + fig = None + if axis is None: fig = plt.figure(facecolor='white', figsize=size_inches, dpi=dpi) axis = fig.add_subplot(111, frameon=0, aspect=1) + else: + fig = axis.get_figure() # projections for geographic plots @@ -247,31 +261,83 @@ def draw_network(network, nsize="total-degree", ncolor=None, nshape="o", adj_mat[:, remove] = 0 edges = (np.array(adj_mat.nonzero()).T if restrict_edges is None else - restrict_edges) + np.asarray(restrict_edges)) e = len(edges) - # compute properties decimate_connections = 1 if decimate_connections is None\ else decimate_connections - # get node and edge shape/size properties - simple_nodes = kwargs.get("simple_nodes", False) + # get positions (all cases except circular layout which is done below the + # node sizes + pos = None - if esize is None: - esize = 1 if fast else 5 + spatial *= network.is_spatial() - max_nsize = kwargs.get("max_nsize", None) - min_nsize = kwargs.get("min_nsize", None) + if nonstring_container(layout): + assert np.shape(layout) == (n, 2), "One position per node is required." + pos = np.asarray(layout).astype(float) + spatial = False + elif spatial: + if show_environment: + nngt.geometry.plot.plot_shape(network.shape, axis=axis, show=False) - max_esize = kwargs.get("max_esize", 2) + nodes = None if restrict_nodes is None else list(restrict_nodes) + + pos = network.get_positions(nodes=nodes).astype(float) + elif layout in (None, "random"): + pos = np.random.uniform(size=(n, 2)) - 0.5 + + pos[:, 0] *= size[0] + pos[:, 1] *= size[1] + elif layout not in ("circular", "random", None): + raise ValueError("Unknown `layout`: {}".format(layout)) + + # get node and edge size extrema and drawing properties + simple_nodes = kwargs.get("simple_nodes", fast) + + dist = min(size) + + if pos is not None: + dist = min(pos[:, 0].max() - pos[:, 0].min(), + pos[:, 1].max() - pos[:, 1].min()) + + max_nsize = kwargs.get("max_nsize", 100 if simple_nodes else 0.05*dist) + min_nsize = kwargs.get("min_nsize", 0.2*max_nsize) + + max_esize = kwargs.get("max_esize", 5 if fast else 0.05*dist) min_esize = kwargs.get("min_esize", 0) if fast: simple_nodes = True + max_nsize *= 0.01*min(size) + min_nsize *= 0.01*min(size) + max_esize *= 0.005*min(size) + min_esize *= 0.005*min(size) + threshold *= 0.005*min(size) + + if esize is None: + esize = 0.5*max_esize - max_nsize = (20 if simple_nodes else 5) if max_nsize is None else max_nsize + # circular layout + if isinstance(layout, str) and layout == "circular": + pos = _circular_layout(network, max_nsize) + + # check axis extent + xmax = pos[:, 0].max() + xmin = pos[:, 0].min() + ymax = pos[:, 1].max() + ymin = pos[:, 1].min() + + height = ymax - ymin + width = xmax - xmin + + if not show_environment or not spatial or proj is not None: + # axis.get_data() + _set_ax_lim(axis, xmax, xmin, ymax, ymin, height, width, xlims, ylims, + max_nsize, fast) + # get node and edge shape/size properties markers, nsize, esize = _node_edge_shape_size( network, nshape, nsize, max_nsize, min_nsize, esize, max_esize, min_esize, restrict_nodes, edges, size, threshold, @@ -284,16 +350,23 @@ def draw_network(network, nsize="total-degree", ncolor=None, nshape="o", else: ncolor = "r" + nborder_color = kwargs.get("nborder_color", "k") + nborder_width = kwargs.get("nborder_width", 0.5) + + eborder_color = kwargs.get("eborder_color", "k") + eborder_width = kwargs.get("eborder_width", 0.3) + discrete_colors, default_ncmap = _get_ncmap(network, ncolor) nalpha = kwargs.get("nalpha", 1) + ealpha = kwargs.get("ealpha", 0.5) ncmap = get_cmap(kwargs.get("node_cmap", default_ncmap)) node_color, nticks, ntickslabels, nlabel = _node_color( network, restrict_nodes, ncolor, discrete_colors=discrete_colors) - if nonstring_container(ncolor): + if nonstring_container(ncolor) and not len(ncolor) in (3, 4): assert len(ncolor) == n, "For color arrays, one " +\ "color per node is required." ncolor = "custom" @@ -303,75 +376,20 @@ def draw_network(network, nsize="total-degree", ncolor=None, nshape="o", if not nonstring_container(nborder_color): nborder_color = np.repeat(nborder_color, n) - # check edge color - group_based = False - - default_ecmap = (palette_discrete() if not nonstring_container(ncolor) and - ecolor == "group" else palette_continuous()) - - if ecolor == "groups" or ecolor == "group": - if not network.is_network(): - raise TypeError( - "The graph must be a Network to use `ecolor='groups'`.") - - group_based = True - ecolor = {} - - for i, src in enumerate(network.population): - if network.population[src].ids: - idx1 = network.population[src].ids[0] - for j, tgt in enumerate(network.population): - if network.population[tgt].ids: - idx2 = network.population[tgt].ids[0] - if src == tgt: - ecolor[(src, tgt)] = node_color[idx1] - else: - ecolor[(src, tgt)] = \ - np.abs(0.8*node_color[idx1] - - 0.2*node_color[idx2]) - elif not nonstring_container(ecolor): - ecolor = np.repeat(ecolor, e) - - # draw - pos = kwargs.get("positions", None) - - spatial *= (network.is_spatial() or pos is not None) - - if nonstring_container(layout): - assert np.shape(layout) == (n, 2), "One position per node is required." - pos = np.asarray(layout) - elif layout == "circular": - pos = _circular_layout(network, nsize) - elif pos is None and spatial: - if show_environment: - nngt.geometry.plot.plot_shape(network.shape, axis=axis, show=False) - - nodes = None if restrict_nodes is None else list(restrict_nodes) - - pos = network.get_positions(nodes=nodes) - else: - pos = np.full((n, 2), np.NaN) - - pos[:, 0] = size[0]*(np.random.uniform(size=n)-0.5) - pos[:, 1] = size[1]*(np.random.uniform(size=n)-0.5) - - pos = np.array(pos).astype(float) - # prepare node colors - if nonstring_container(c) and not isinstance(c[0], str): + if nonstring_container(c) and not isinstance(c[0], (str, np.ndarray)): # make the colorbar for the nodes cmap = ncmap + cnorm = None - if colorbar: - cnorm = None - - if discrete_colors: - cmap = _discrete_cmap(len(nticks), ncmap, discrete_colors) - cnorm = Normalize(nticks[0]-0.5, nticks[-1] + 0.5) - else: - cnorm = Normalize(np.min(c), np.max(c)) - c = cnorm(c) + if discrete_colors: + cmap = _discrete_cmap(len(nticks), ncmap, discrete_colors) + cnorm = Normalize(nticks[0]-0.5, nticks[-1] + 0.5) + else: + cnorm = Normalize(np.min(c), np.max(c)) + c = cnorm(c) + if colorbar: sm = plt.cm.ScalarMappable(cmap=cmap, norm=cnorm) if discrete_colors: @@ -407,7 +425,37 @@ def draw_network(network, nsize="total-degree", ncolor=None, nshape="o", c = np.array( [ncmap((node_color - minc)/(np.max(node_color) - minc))]*n) + # check edge color + group_based = False + + default_ecmap = (palette_discrete() if not nonstring_container(ncolor) and + ecolor == "group" else palette_continuous()) + + if ecolor == "groups" or ecolor == "group": + if network.structure is None: + raise TypeError( + "The graph must have a Structure/NeuralPop to use " + "`ecolor='groups'`.") + + group_based = True + ecolor = {} + + for i, src in enumerate(network.structure): + if network.structure[src].ids: + idx1 = network.structure[src].ids[0] + for j, tgt in enumerate(network.structure): + if network.structure[tgt].ids: + idx2 = network.structure[tgt].ids[0] + if src == tgt: + ecolor[(src, tgt)] = c[idx1] + else: + ecolor[(src, tgt)] = 0.7*c[idx1] + 0.3*c[idx2] + elif not nonstring_container(ecolor): + ecolor = np.repeat(ecolor, e) + # plot nodes + scatter = [] + if simple_nodes: if nonstring_container(nshape): # matplotlib scatter does not support marker arrays @@ -416,66 +464,79 @@ def draw_network(network, nsize="total-degree", ncolor=None, nshape="o", ids = g.ids if restrict_nodes is None \ else list(set(g.ids).intersection(restrict_nodes)) - axis.scatter(pos[ids, 0], pos[ids, 1], color=c[ids], - s=0.5*np.array(nsize)[ids], - marker=markers[ids[0]], zorder=2, - edgecolors=nborder_color, - linewidths=nborder_width, alpha=nalpha) + scatter.append( + axis.scatter(pos[ids, 0], pos[ids, 1], color=c[ids], + s=0.5*np.array(nsize)[ids], + marker=markers[ids[0]], zorder=2, + edgecolors=nborder_color, + linewidths=nborder_width, alpha=nalpha)) else: ids = range(network.node_nb()) if restrict_nodes is None \ else restrict_nodes for i in ids: - axis.plot( + scatter.append(axis.scatter( pos[i, 0], pos[i, 1], color=c[i], ms=0.5*nsize[i], - marker=nshape[i], ls="", zorder=2, - mec=nborder_color[i], mew=nborder_width, alpha=nalpha) + marker=nshape[i], zorder=2, mec=nborder_color[i], + mew=nborder_width, alpha=nalpha)) else: - axis.scatter(pos[:, 0], pos[:, 1], color=c, s=0.5*np.array(nsize), - marker=nshape, zorder=2, edgecolor=nborder_color, - linewidths=nborder_width, alpha=nalpha) + scatter.append(axis.scatter( + pos[:, 0], pos[:, 1], color=c, s=0.5*np.array(nsize), + marker=nshape, zorder=2, edgecolor=nborder_color, + linewidths=nborder_width, alpha=nalpha)) else: nodes = [] - axis.set_aspect(1.) - if network.is_network(): - for group in network.population.values(): - idx = group.ids if restrict_nodes is None \ - else list(set(restrict_nodes).intersection(group.ids)) + if network.structure is not None: + converter = None + + if restrict_nodes is not None: + converter = {n: i for i, n in enumerate(restrict_nodes)} + + for group in network.structure.values(): + idx = group.ids + + if restrict_nodes is not None: + idx = [converter[n] + for n in set(restrict_nodes).intersection(idx)] + for i, fc in zip(idx, c[idx]): m = MarkerStyle(markers[i]).get_path() + center = np.average(m.vertices, axis=0) + m = Path(m.vertices - center, m.codes) transform = Affine2D().scale( 0.5*nsize[i]).translate(pos[i][0], pos[i][1]) - patch = PathPatch(m.transformed(transform), facecolor=fc, - edgecolor=nborder_color[i], alpha=nalpha) + patch = PathPatch( + m.transformed(transform), facecolor=fc, + lw=nborder_width, edgecolor=nborder_color[i], + alpha=nalpha) nodes.append(patch) else: for i, ci in enumerate(c): m = MarkerStyle(markers[i]).get_path() + center = np.average(m.vertices, axis=0) + m = Path(m.vertices - center, m.codes) transform = Affine2D().scale(0.5*nsize[i]).translate( - pos[i][0], pos[i][1]) - patch = PathPatch(m.transformed(transform), facecolor=ci, - edgecolor=nborder_color[i], alpha=nalpha) + pos[i, 0], pos[i, 1]) + patch = PathPatch( + m.transformed(transform), facecolor=ci, + lw=nborder_width, edgecolor=nborder_color[i], alpha=nalpha) nodes.append(patch) - nodes = PatchCollection(nodes, match_original=True, alpha=nalpha) - nodes.set_zorder(2) - axis.add_collection(nodes) + scatter = PatchCollection(nodes, match_original=True, alpha=nalpha) + scatter.set_zorder(2) + axis.add_collection(scatter) - if not show_environment or not spatial or proj is not None: - # axis.get_data() - _set_ax_lim(axis, pos[:, 0], pos[:, 1], xlims, ylims) + # draw the edges + arrows = [] - # use quiver to draw the edges if e and decimate_connections != -1: avg_size = np.average(nsize) - arrows = [] - if group_based: - for src_name, src_group in network.population.items(): - for tgt_name, tgt_group in network.population.items(): + for src_name, src_group in network.structure.items(): + for tgt_name, tgt_group in network.structure.items(): s_ids = src_group.ids if restrict_sources is not None: @@ -492,222 +553,188 @@ def draw_network(network, nsize="total-degree", ncolor=None, nshape="o", edges = np.array( adj_mat[s_min:s_max, t_min:t_max].nonzero(), - dtype=int) + dtype=int).T - edges[0, :] += s_min - edges[1, :] += t_min + edges[:, 0] += s_min + edges[:, 1] += t_min strght_edges, self_loops, strght_sizes, loop_sizes = \ _split_edges_sizes(edges, esize, decimate_connections) # plot - ec = default_ecmap(ecolor[(src_name, tgt_name)]) + ec = ecolor[(src_name, tgt_name)] - if fast: - dl = 0.5*np.max(nsize) - arrow_x = pos[strght_edges[1], 0] - \ - pos[strght_edges[0], 0] + if len(strght_edges) and fast: + dl = 0 if simple_nodes else 0.5*np.max(nsize) + arrow_x = pos[strght_edges[:, 1], 0] - \ + pos[strght_edges[:, 0], 0] arrow_x -= np.sign(arrow_x) * dl - arrow_y = pos[strght_edges[1], 1] - \ - pos[strght_edges[0], 1] + arrow_y = pos[strght_edges[:, 1], 1] - \ + pos[strght_edges[:, 0], 1] arrow_x -= np.sign(arrow_y) * dl axis.quiver( - pos[strght_edges[0], 0], - pos[strght_edges[0], 1], arrow_x, + pos[strght_edges[:, 0], 0], + pos[strght_edges[:, 0], 1], arrow_x, arrow_y, scale_units='xy', angles='xy', - scale=1, alpha=ealpha, width=0.005*size[0], - linewidths=0.5*strght_sizes, edgecolors=ec, - zorder=1, **kw) - - for i, s in enumerate(self_loops): - es = loop_sizes[i] - ec = loop_colors[i] - cs = _connectionstyle(axis, nsize[s], - loop_sizes[i]) - - arrow = FancyArrowPatch( - pos[s], pos[s], arrowstyle=arrowstyle, - shrinkA=nsize[s], color=ec, - linewidth=es, connectionstyle=cs, zorder=1) - - axis.add_patch(arrow) - else: - for s, t in zip(strght_edges[0], strght_edges[1]): + scale=1, alpha=ealpha, + width=3e-3, linewidths=0.5*strght_sizes, + edgecolors=ec, color=ec, zorder=1, **kw) + elif len(strght_edges): + for i, (s, t) in enumerate(strght_edges): xs, ys = pos[s, 0], pos[s, 1] xt, yt = pos[t, 0], pos[t, 1] - dl = 0.5*nsize[t] - dx = xt-xs - dx -= np.sign(dx) * dl - dy = yt-ys - dy -= np.sign(dy) * dl - - if curved_edges: - arrow = FancyArrowPatch( - posA=(xs, ys), posB=(xt, yt), - arrowstyle=arrowstyle, - connectionstyle='arc3,rad=0.1', - alpha=ealpha, fc=ec, lw=0.5) - axis.add_patch(arrow) - elif s == t: - # self-loop - arrow = FancyArrowPatch( - posA=(xs, ys), posB=(xt, yt), - arrowstyle=arrowstyle, - connectionstyle='arc3,rad=1', - alpha=ealpha, fc=ec, lw=0.5) - - axis.add_patch(arrow) - else: - arrows.append(FancyArrow( - xs, ys, dx, dy, width=0.3*avg_size, - head_length=0.7*avg_size, - head_width=0.7*avg_size, - length_includes_head=True, - alpha=ealpha, fc=ec, lw=0.5)) - else: - if e and decimate_connections >= 1: - strght_colors, loop_colors = [], [] - strght_edges, self_loops, strght_sizes, loop_sizes = \ - _split_edges_sizes(edges, esize, decimate_connections, - ecolor, strght_colors, loop_colors) + sA = 0 if simple_nodes else 0.5*nsize[s] + sB = 0 if simple_nodes else 0.5*nsize[t] - # keep only desired edges - if None not in (restrict_sources, restrict_targets): - new_edges = [] - new_colors = [] + cs = 'arc3,rad=0.2' if curved_edges else None + + astyle = ArrowStyle.Simple( + head_length=0.7*strght_sizes[i], + head_width=0.7*strght_sizes[i], + tail_width=0.3*strght_sizes[i]) - for edge, ec in zip(strght_edges, strght_colors): - s, t = edge + arrows.append(FancyArrowPatch( + posA=(xs, ys), posB=(xt, yt), + arrowstyle=astyle, connectionstyle=cs, + alpha=ealpha, fc=ec, zorder=1, + shrinkA=0.5*nsize[s], shrinkB=0.5*nsize[t], + lw=eborder_width, ec=eborder_color)) - if s in restrict_sources and t in restrict_targets: - new_edges.append(edge) - new_colors.append(ec) + for i, s in enumerate(self_loops): + loop = _plot_loop( + i, s, pos, loop_sizes, nsize, max_nsize, xmax, + xmin, ymax, ymin, height, width, ec, ealpha, + eborder_width, eborder_color, fast, network, + restrict_nodes) + + axis.add_artist(loop) + else: + strght_colors, loop_colors = [], [] - strght_edges = np.array(new_edges, dtype=int) - strght_colors = new_colors + strght_edges, self_loops, strght_sizes, loop_sizes = \ + _split_edges_sizes(edges, esize, decimate_connections, + ecolor, strght_colors, loop_colors) - if restrict_nodes is not None: - nodes = list(self_loops) - nodes.sort() + # keep only desired edges + if None not in (restrict_sources, restrict_targets): + new_edges = [] + new_colors = [] - new_loops = set() - new_colors = [] + for edge, ec in zip(strght_edges, strght_colors): + s, t = edge - for i, node in enumerate(restrict_nodes): - strght_edges[strght_edges == node] = i + if s in restrict_sources and t in restrict_targets: + new_edges.append(edge) + new_colors.append(ec) - if node in self_loops: - idx = nodes.index(node) - new_loops.add(i) - new_colors.append(loop_colors[idx]) + strght_edges = np.array(new_edges, dtype=int) + strght_colors = new_colors - self_loops = new_loops - loop_colors = new_colors - elif restrict_sources is not None: - new_edges = [] + if restrict_nodes is not None: + nodes = list(self_loops) + nodes.sort() + + new_loops = set() new_colors = [] - for edge, ec in zip(strght_edges, strght_colors): - s, _ = edge + for i, node in enumerate(restrict_nodes): + strght_edges[strght_edges == node] = i - if s in restrict_sources: - new_edges.append(edge) - new_colors.append(ec) + if node in self_loops: + idx = nodes.index(node) + new_loops.add(i) + new_colors.append(loop_colors[idx]) - strght_edges = np.array(new_edges, dtype=int) + self_loops = new_loops + loop_colors = new_colors + elif restrict_sources is not None: + new_edges = [] + new_colors = [] - loop_colors = [ec for ec, n in zip(loop_colors, self_loops) - if n in restrict_sources] - self_loops = self_loops.intersection(restrict_sources) - elif restrict_targets is not None: - new_edges = [] - new_colors = [] + for edge, ec in zip(strght_edges, strght_colors): + s, _ = edge + + if s in restrict_sources: + new_edges.append(edge) + new_colors.append(ec) + + strght_edges = np.array(new_edges, dtype=int) - for edge, ec in zip(strght_edges, strght_colors): - _, t = edge + loop_colors = [ec for ec, n in zip(loop_colors, self_loops) + if n in restrict_sources] + self_loops = self_loops.intersection(restrict_sources) + elif restrict_targets is not None: + new_edges = [] + new_colors = [] - if t in restrict_targets: - new_edges.append(edge) - new_colors.append(ec) + for edge, ec in zip(strght_edges, strght_colors): + _, t = edge - strght_edges = np.array(new_edges, dtype=int) + if t in restrict_targets: + new_edges.append(edge) + new_colors.append(ec) - loop_colors = [ec for ec, n in zip(loop_colors, self_loops) - if n in restrict_targets] - self_loops = self_loops.intersection(restrict_targets) + strght_edges = np.array(new_edges, dtype=int) + + loop_colors = [ec for ec, n in zip(loop_colors, self_loops) + if n in restrict_targets] + self_loops = self_loops.intersection(restrict_targets) if fast: if len(strght_edges): dl = 0.5*np.max(nsize) if not simple_nodes else 0. arrow_x = pos[strght_edges[:, 1], 0] - \ - pos[strght_edges[:, 0], 0] + pos[strght_edges[:, 0], 0] arrow_x -= np.sign(arrow_x) * dl arrow_y = pos[strght_edges[:, 1], 1] - \ - pos[strght_edges[:, 0], 1] + pos[strght_edges[:, 0], 1] arrow_x -= np.sign(arrow_y) * dl axis.quiver( pos[strght_edges[:, 0], 0], pos[strght_edges[:, 0], 1], arrow_x, arrow_y, scale_units='xy', angles='xy', - scale=1, alpha=ealpha, width=1.5e-3, headlength=0.02*size[0], headwidth=0.01*size[0], - linewidths=0.5*esize, ec=ecolor, fc=ecolor, zorder=1) - - for i, s in enumerate(self_loops): - es = loop_sizes[i] - ec = loop_colors[i] - cs = _connectionstyle(axis, nsize[s], loop_sizes[i]) - - arrow = FancyArrowPatch( - pos[s], pos[s], arrowstyle=arrowstyle, - shrinkA=nsize[s], color=ec, linewidth=es, - connectionstyle=cs, zorder=1) - - axis.add_patch(arrow) + scale=1, alpha=ealpha, width=3e-3, + linewidths=0.5*strght_sizes, ec=ecolor, fc=ecolor, + zorder=1) else: if len(strght_edges): for i, (s, t) in enumerate(strght_edges): xs, ys = pos[s, 0], pos[s, 1] xt, yt = pos[t, 0], pos[t, 1] - if curved_edges: - arrow = FancyArrowPatch( - posA=(xs, ys), posB=(xt, yt), - arrowstyle=arrowstyle, - connectionstyle='arc3,rad=0.1', - alpha=ealpha, fc=ecolor[i], lw=0) - axis.add_patch(arrow) - else: - dl = 0.5*nsize[t] - dx = xt-xs - dx -= np.sign(dx) * dl - dy = yt-ys - dy -= np.sign(dy) * dl - arrows.append(FancyArrow( - xs, ys, dx, dy, width=0.3*strght_sizes[i], - head_length=0.7*strght_sizes[i], - head_width=0.7*strght_sizes[i], - length_includes_head=True, alpha=ealpha, - fc=strght_colors[i], lw=0)) - - for i, s in enumerate(self_loops): - es = loop_sizes[i] - ec = loop_colors[i] - - arrow = FancyArrowPatch( - pos[s], pos[s], arrowstyle=arrowstyle, - shrinkA=nsize[s], color=ec, linewidth=es, - connectionstyle='arc3,rad=1', zorder=1) - - arrows.append(arrow) - - if not fast: - arrows = PatchCollection(arrows, match_original=True, alpha=ealpha) - arrows.set_zorder(1) - axis.add_collection(arrows) + astyle = ArrowStyle.Simple( + head_length=0.7*strght_sizes[i], + head_width=0.7*strght_sizes[i], + tail_width=0.3*strght_sizes[i]) + + sA = 0 if simple_nodes else 0.5*nsize[s] + sB = 0 if simple_nodes else 0.5*nsize[t] + + cs = 'arc3,rad=0.2' if curved_edges else None + + arrows.append(FancyArrowPatch( + posA=(xs, ys), posB=(xt, yt), arrowstyle=astyle, + connectionstyle=cs, alpha=ealpha, fc=ecolor[i], + zorder=1, shrinkA=sA, shrinkB=sB, lw=eborder_width, + ec=eborder_color)) + + for i, s in enumerate(self_loops): + ec = loop_colors[i] + loop = _plot_loop( + i, s, pos, loop_sizes, nsize, max_nsize, xmax, xmin, + ymax, ymin, height, width, ec, ealpha, eborder_width, + eborder_color, fast, network, restrict_nodes) + + axis.add_artist(loop) + + # add patch arrows + arrows = PatchCollection(arrows, match_original=True, alpha=ealpha) + arrows.set_zorder(1) + axis.add_collection(arrows) if kwargs.get('tight', True): plt.tight_layout() @@ -715,6 +742,82 @@ def draw_network(network, nsize="total-degree", ncolor=None, nshape="o", hspace=0., wspace=0., left=0., right=0.95 if colorbar else 1., top=1., bottom=0.) + # annotations + annotations = kwargs.get("annotations", + [str(i) for i in range(network.node_nb())] if restrict_nodes is None + else [str(i) for i in restrict_nodes]) + + if isinstance(annotations, str): + assert annotations in network.node_attributes, \ + "String values for `annotations` must be a node attribute." + + if restrict_nodes is None: + annotations = network.node_attributes[annotations] + else: + annotations = network.get_node_attributes( + nodes=list(restrict_nodes), name=annotations) + elif len(annotations) == network.node_nb() and restrict_nodes is not None: + annotations = [annotations[i] for i in restrict_nodes] + else: + assert len(annotations) == n, "One annotation per node is required." + + annotate = kwargs.get("annotate", True) + + if annotate: + annot = axis.annotate( + "", xy=(0,0), xytext=(10,10), textcoords="offset points", + bbox=dict(boxstyle="round", fc="w"), + arrowprops=dict(arrowstyle="->")) + + annot.set_visible(False) + + def update_annot(ind): + annot.xy = pos[ind["ind"][0]] + text = "\n".join([annotations[n] for n in ind["ind"]]) + annot.set_text(text) + annot.get_bbox_patch().set_facecolor("w") + + def hover(event): + if hover.bg is None: + # first run, save the current plot + hover.bg = fig.canvas.copy_from_bbox(fig.bbox) + + vis = annot.get_visible() + if event.inaxes == axis: + if fast or simple_nodes: + for sc in scatter: + cont, ind = sc.contains(event) + if cont: + update_annot(ind) + fig.canvas.restore_region(hover.bg) + annot.set_visible(True) + axis.draw_artist(annot) + fig.canvas.blit(fig.bbox) + else: + if vis: + annot.set_visible(False) + fig.canvas.restore_region(hover.bg) + fig.canvas.blit(fig.bbox) + else: + cont, ind = scatter.contains(event) + if cont: + update_annot(ind) + fig.canvas.restore_region(hover.bg) + annot.set_visible(True) + axis.draw_artist(annot) + fig.canvas.blit(fig.bbox) + else: + if vis: + annot.set_visible(False) + fig.canvas.restore_region(hover.bg) + fig.canvas.blit(fig.bbox) + + fig.canvas.flush_events() + + hover.bg = None + + fig.canvas.mpl_connect("motion_notify_event", hover) + if show: plt.show() @@ -1117,12 +1220,27 @@ def library_draw(network, nsize="total-degree", ncolor=None, nshape="o", ---------------- ------------------ --------------------------------- min_* float Minimum value for `nsize` or `esize` + ---------------- ------------------ --------------------------------- + annotate bool Use annotations to show node + information (default: True) + ---------------- ------------------ --------------------------------- + Information that will be displayed + annotations str or list such as a node attribute or a list + of values. (default: node id) ================ ================== ================================= ''' import matplotlib.pyplot as plt # backend and axis - if nngt.get_config("backend") in ("graph-tool", "igraph"): + try: + import igraph + igv = igraph.__version__ + except: + igv = '1.0' + + ig_test = nngt.get_config("backend") == "igraph" and igv <= '0.9.6' + + if nngt.get_config("backend") == "graph-tool" or ig_test: mpl_backend = mpl.get_backend() if mpl_backend.startswith("Qt4"): @@ -1319,6 +1437,7 @@ def library_draw(network, nsize="total-degree", ncolor=None, nshape="o", edge_cmap=palette_continuous(), cmap=ncmap, with_labels=show_labels, width=esize, edgecolors=nborder_color) elif nngt.get_config("backend") == "igraph": + import igraph from igraph import Layout, PrecalculatedPalette pos = None @@ -1327,6 +1446,8 @@ def library_draw(network, nsize="total-degree", ncolor=None, nshape="o", if isinstance(network, nngt.SpatialGraph) and spatial: xy = network.get_positions() pos = Layout(xy) + else: + pos = network.graph.layout_fruchterman_reingold() elif layout == "circular": pos = network.graph.layout_circle() elif layout == "random": @@ -1379,9 +1500,12 @@ def library_draw(network, nsize="total-degree", ncolor=None, nshape="o", eids = [network.edge_id(e) for e in restrict_edges] graph = network.graph.subgraph_edges(eids, delete_vertices=False) - graph_artist = GraphArtist(graph, axis, **visual_style) + if igv > '0.9.6': + igraph.plot(graph, target=axis, **visual_style) + else: + graph_artist = GraphArtist(graph, axis, **visual_style) - axis.artists.append(graph_artist) + axis.artists.append(graph_artist) if "title" in kwargs: axis.set_title(kwargs["title"]) @@ -1505,18 +1629,27 @@ def _node_edge_shape_size(network, nshape, nsize, max_nsize, min_nsize, esize, mm = cycle(MarkerStyle.filled_markers) - shapes = np.full(network.node_nb(), "", dtype=object) + shapes = np.full(n, "", dtype=object) - for g, m in zip(nshape, mm): - shapes[g.ids] = m + if restrict_nodes is None: + for g, m in zip(nshape, mm): + shapes[g.ids] = m + else: + converter = {n: i for i, n in enumerate(restrict_nodes)} + for g, m in zip(nshape, mm): + ids = [converter[n] + for n in restrict_nodes.intersection(g.ids)] + shapes[ids] = m markers = list(shapes) - elif len(nshape) != network.node_nb(): + if len(nshape) == network.node_nb() and restrict_nodes is not None: + nshape = nshape[list(restrict_nodes)] + elif len(nshape) != n: raise ValueError("When passing an array of markers to " "`nshape`, one entry per node in the " "network must be provided.") else: - markers = [nshape for _ in range(network.node_nb())] + markers = [nshape for _ in range(n)] # size if isinstance(nsize, str): @@ -1537,34 +1670,31 @@ def _node_edge_shape_size(network, nshape, nsize, max_nsize, min_nsize, esize, raise ValueError("`nsize` must contain either one entry per node " "or be the same length as `restrict_nodes`.") - nsize *= 0.01 * size[0] - if e: if isinstance(esize, str): - esize = _edge_size(network, edges, esize) + esize = _edge_size(network, edges, esize) esize = _norm_size(esize, max_esize, min_esize) esize[esize < threshold] = 0. - - esize *= 0.005 * size[0] # border on each side (so 0.5 %) + else: + esize = _norm_size(esize, max_esize, min_esize) else: esize = np.array([]) return markers, nsize, esize -def _set_ax_lim(ax, xdata, ydata, xlims, ylims): +def _set_ax_lim(ax, xmax, xmin, ymax, ymin, height, width, xlims, ylims, + max_nsize, fast): if xlims is not None: ax.set_xlim(*xlims) else: - x_min, x_max = np.min(xdata), np.max(xdata) - width = x_max - x_min - ax.set_xlim(x_min - 0.05*width, x_max + 0.05*width) + dx = 0.05*width if fast else 1.5*max_nsize + ax.set_xlim(xmin - dx, xmax + dx) if ylims is not None: ax.set_ylim(*ylims) else: - y_min, y_max = np.min(ydata), np.max(ydata) - height = y_max - y_min - ax.set_ylim(y_min - 0.05*height, y_max + 0.05*height) + dy = 0.05*height if fast else 1.5*max_nsize + ax.set_ylim(ymin - dy, ymax + dy) def _node_size(network, restrict_nodes, nsize): @@ -1663,7 +1793,7 @@ def _node_color(network, restrict_nodes, ncolor, discrete_colors=False): n = network.node_nb() if restrict_nodes is None else len(restrict_nodes) if restrict_nodes is not None: - restrict_nodes = set(restrict_nodes) + restrict_nodes = list(set(restrict_nodes)) if isinstance(ncolor, float): color = np.repeat(ncolor, n) @@ -1711,7 +1841,7 @@ def _node_color(network, restrict_nodes, ncolor, discrete_colors=False): values = network.get_betweenness("node") else: values = network.get_betweenness( - "node")[list(restrict_nodes)] + "node")[restrict_nodes] elif ncolor in network.node_attributes: values = network.get_node_attributes( name=ncolor, nodes=restrict_nodes) @@ -1723,7 +1853,7 @@ def _node_color(network, restrict_nodes, ncolor, discrete_colors=False): values = nngt.analyze_graph[ncolor](network) else: values = nngt.analyze_graph[ncolor]( - network)[list(restrict_nodes)] + network)[restrict_nodes] else: raise RuntimeError("Invalid `ncolor`: {}.".format(ncolor)) @@ -1929,59 +2059,12 @@ def _to_gt_prop(graph, value, cmap, ptype='node', color=False): return pmap("vector<double>", vals=[cmap(v) for v in normalized]) return pmap("double", vals=value) - - - return value - - -class GraphArtist(Artist): - """ - Matplotlib artist class that draws igraph graphs. - - Only Cairo-based backends are supported. - Adapted from: https://stackoverflow.com/a/36154077/5962321 - """ - - def __init__(self, graph, axis, palette=None, *args, **kwds): - """Constructs a graph artist that draws the given graph within - the given bounding box. - - `graph` must be an instance of `igraph.Graph`. - `bbox` must either be an instance of `igraph.drawing.BoundingBox` - or a 4-tuple (`left`, `top`, `width`, `height`). The tuple - will be passed on to the constructor of `BoundingBox`. - `palette` is an igraph palette that is used to transform - numeric color IDs to RGB values. If `None`, a default grayscale - palette is used from igraph. - - All the remaining positional and keyword arguments are passed - on intact to `igraph.Graph.__plot__`. - """ - from igraph import BoundingBox, palettes - - super().__init__() - - self.graph = graph - self.palette = palette or palettes["gray"] - self.bbox = BoundingBox(axis.bbox.bounds) - self.args = args - self.kwds = kwds - - def draw(self, renderer): - from matplotlib.backends.backend_cairo import RendererCairo - - if not isinstance(renderer, RendererCairo): - raise TypeError( - "graph plotting is supported only on Cairo backends") - - self.graph.__plot__(renderer.gc.ctx, self.bbox, self.palette, - *self.args, **self.kwds) + return value -def _circular_layout(graph, node_size): - max_nsize = np.max(node_size) +def _circular_layout(graph, max_nsize): # chose radius such that r*dtheta > max_nsize dtheta = 2*np.pi / graph.node_nb() @@ -1994,32 +2077,24 @@ def _circular_layout(graph, node_size): return np.array((x, y)).T - def _connectionstyle(axis, nsize, esize): def cs(posA, posB, *args, **kwargs): - # check if we need to do a self-loop - if np.all(posA == posB): - # Self-loops are scaled by node size - vshift = 0.1*max(nsize, 2*esize) - hshift = 0.7*vshift - # this is called with _screen space_ values so covert back - # to data space + # Self-loops are scaled by node size + vshift = 0.1*max(nsize, 2*esize) + hshift = 0.7*vshift + # this is called with _screen space_ values so covert back + # to data space - s1 = np.asarray([-hshift, vshift]) - s2 = np.asarray([hshift, vshift]) + s1 = np.asarray([-hshift, vshift]) + s2 = np.asarray([hshift, vshift]) - p1 = axis.transData.inverted().transform(posA) - p2 = axis.transData.inverted().transform(posA + s1) - p3 = axis.transData.inverted().transform(posA + s2) + p1 = axis.transData.inverted().transform(posA) + p2 = axis.transData.inverted().transform(posA + s1) + p3 = axis.transData.inverted().transform(posA + s2) - path = [p1, p2, p3, p1] + path = [p1, p2, p3, p1] - ret = mpl.path.Path(axis.transData.transform(path), [1, 2, 2, 2]) - # if not, fall back to the user specified behavior - else: - return None - - return ret + return mpl.path.Path(axis.transData.transform(path), [1, 2, 2, 2]) return cs @@ -2051,8 +2126,8 @@ def _split_edges_sizes(edges, esize, decimate_connections, ecolor=None, strght_sizes = esize[strght] loop_sizes = esize[loops] else: - strght_sizes = [esize]*len(strght_edges) - loop_sizes = [esize]*len(self_loops) + strght_sizes = np.full(len(strght_edges), esize) + loop_sizes = np.full(len(self_loops), esize) if decimate_connections > 1: strght_edges = \ @@ -2092,3 +2167,264 @@ def _get_ncmap(network, ncolor): else palette_continuous() return discrete_colors, default_ncmap + + +def _plot_loop(i, s, pos, loop_sizes, nsize, max_nsize, xmax, xmin, ymax, ymin, + height, width, ec, ealpha, eborder_width, eborder_color, fast, + network, restrict_nodes): + ''' + Draw self loops + ''' + es = loop_sizes[i] + dl = 0.03*max(height, width) + ns = nsize[s]*dl/max_nsize if fast else nsize[s] + + # get the neighbours + nn = network.neighbours(s) + + if restrict_nodes is not None: + nn = nn.intersection(restrict_nodes) + + convert = {n: i for i, n in enumerate(restrict_nodes)} + + nn = {convert[n] for n in nn} + + nn = list(nn - {s}) + + vec = pos[nn] - pos[s] + norm = np.sqrt((vec*vec).sum(axis=1)) + vec = np.asarray([vec[i] / n for i, n in enumerate(norm)]) + + dir = np.average(vec, axis=0) + dir /= np.linalg.norm(dir) + + if fast: + xy = pos[s] - ns*dir + return Circle(xy, ns, fc="none", alpha=ealpha, linewidth=0.5*es, ec=ec) + + es = min(0.5*ns, es) + xy = pos[s] - 0.75*ns*dir + + return Annulus(xy, 0.75*ns, 0.5*es, fc=ec, alpha=ealpha, lw=eborder_width, + ec=eborder_color) + + +class Annulus(Patch): + """ + An elliptical annulus. + """ + + def __init__(self, xy, r, width, angle=0.0, **kwargs): + """ + Parameters + ---------- + xy : (float, float) + xy coordinates of annulus centre. + r : float or (float, float) + The radius, or semi-axes: + - If float: radius of the outer circle. + - If two floats: semi-major and -minor axes of outer ellipse. + width : float + Width (thickness) of the annular ring. The width is measured inward + from the outer ellipse so that for the inner ellipse the semi-axes + are given by ``r - width``. *width* must be less than or equal to + the semi-minor axis. + angle : float, default: 0 + Rotation angle in degrees (anti-clockwise from the positive + x-axis). Ignored for circular annuli (i.e., if *r* is a scalar). + **kwargs + Keyword arguments control the `Patch` properties: + %(Patch:kwdoc)s + """ + super().__init__(**kwargs) + + self.set_radii(r) + self.center = xy + self.width = width + self.angle = angle + self._path = None + + def __str__(self): + if self.a == self.b: + r = self.a + else: + r = (self.a, self.b) + + return "Annulus(xy=(%s, %s), r=%s, width=%s, angle=%s)" % \ + (*self.center, r, self.width, self.angle) + + def set_center(self, xy): + """ + Set the center of the annulus. + Parameters + ---------- + xy : (float, float) + """ + self._center = xy + self._path = None + self.stale = True + + def get_center(self): + """Return the center of the annulus.""" + return self._center + + center = property(get_center, set_center) + + def set_width(self, width): + """ + Set the width (thickness) of the annulus ring. + The width is measured inwards from the outer ellipse. + Parameters + ---------- + width : float + """ + if min(self.a, self.b) <= width: + raise ValueError( + 'Width of annulus must be less than or equal semi-minor axis') + + self._width = width + self._path = None + self.stale = True + + def get_width(self): + """Return the width (thickness) of the annulus ring.""" + return self._width + + width = property(get_width, set_width) + + def set_angle(self, angle): + """ + Set the tilt angle of the annulus. + Parameters + ---------- + angle : float + """ + self._angle = angle + self._path = None + self.stale = True + + def get_angle(self): + """Return the angle of the annulus.""" + return self._angle + + angle = property(get_angle, set_angle) + + def set_semimajor(self, a): + """ + Set the semi-major axis *a* of the annulus. + Parameters + ---------- + a : float + """ + self.a = float(a) + self._path = None + self.stale = True + + def set_semiminor(self, b): + """ + Set the semi-minor axis *b* of the annulus. + Parameters + ---------- + b : float + """ + self.b = float(b) + self._path = None + self.stale = True + + def set_radii(self, r): + """ + Set the semi-major (*a*) and semi-minor radii (*b*) of the annulus. + Parameters + ---------- + r : float or (float, float) + The radius, or semi-axes: + - If float: radius of the outer circle. + - If two floats: semi-major and -minor axes of outer ellipse. + """ + if np.shape(r) == (2,): + self.a, self.b = r + elif np.shape(r) == (): + self.a = self.b = float(r) + else: + raise ValueError("Parameter 'r' must be one or two floats.") + + self._path = None + self.stale = True + + def get_radii(self): + """Return the semi-major and semi-minor radii of the annulus.""" + return self.a, self.b + + radii = property(get_radii, set_radii) + + def _transform_verts(self, verts, a, b): + return Affine2D() \ + .scale(*self._convert_xy_units((a, b))) \ + .rotate_deg(self.angle) \ + .translate(*self._convert_xy_units(self.center)) \ + .transform(verts) + + def _recompute_path(self): + # circular arc + arc = Path.arc(0, 360) + + # annulus needs to draw an outer ring + # followed by a reversed and scaled inner ring + a, b, w = self.a, self.b, self.width + v1 = self._transform_verts(arc.vertices, a, b) + v2 = self._transform_verts(arc.vertices[::-1], a - w, b - w) + v = np.vstack([v1, v2, v1[0, :], (0, 0)]) + c = np.hstack([arc.codes, Path.MOVETO, + arc.codes[1:], Path.MOVETO, + Path.CLOSEPOLY]) + self._path = Path(v, c) + + def get_path(self): + if self._path is None: + self._recompute_path() + return self._path + + +class GraphArtist(Artist): + """ + Matplotlib artist class that draws igraph graphs. + + Only Cairo-based backends are supported. + + Adapted from: https://stackoverflow.com/a/36154077/5962321 + """ + + def __init__(self, graph, axis, palette=None, *args, **kwds): + """Constructs a graph artist that draws the given graph within + the given bounding box. + + `graph` must be an instance of `igraph.Graph`. + `bbox` must either be an instance of `igraph.drawing.BoundingBox` + or a 4-tuple (`left`, `top`, `width`, `height`). The tuple + will be passed on to the constructor of `BoundingBox`. + `palette` is an igraph palette that is used to transform + numeric color IDs to RGB values. If `None`, a default grayscale + palette is used from igraph. + + All the remaining positional and keyword arguments are passed + on intact to `igraph.Graph.__plot__`. + """ + from igraph import BoundingBox, palettes + + super().__init__() + + self.graph = graph + self.palette = palette or palettes["gray"] + self.bbox = BoundingBox(axis.bbox.bounds) + self.args = args + self.kwds = kwds + + def draw(self, renderer): + from matplotlib.backends.backend_cairo import RendererCairo + + if not isinstance(renderer, RendererCairo): + raise TypeError( + "graph plotting is supported only on Cairo backends") + + self.graph.__plot__(renderer.gc.ctx, self.bbox, self.palette, + *self.args, **self.kwds) diff --git a/nngt/plot/plt_properties.py b/nngt/plot/plt_properties.py index b678990..775658f 100755 --- a/nngt/plot/plt_properties.py +++ b/nngt/plot/plt_properties.py @@ -31,6 +31,8 @@ from nngt.analysis import (degree_distrib, betweenness_distrib, node_attributes, binning) from .custom_plt import palette_continuous, palette_discrete, format_exponent +from matplotlib.gridspec import SubplotSpec + __all__ = [ 'degree_distribution', @@ -38,7 +40,7 @@ __all__ = [ 'edge_attributes_distribution', 'node_attributes_distribution', 'compare_population_attributes', - "correlation_to_attribute", + 'correlation_to_attribute', ] @@ -1015,8 +1017,11 @@ def _set_new_plot(fignum=None, num_new_plots=1, names=None, sharex=None): if int(ratio) != int(np.ceil(ratio)): num_rows += 1 # change the geometry + gs = fig.add_gridspec(num_rows, num_cols) for i in range(num_axes - num_new_plots): - fig.axes[i].change_geometry(num_rows, num_cols, i+1) + y = i // num_cols + x = i - num_cols*y + fig.axes[i].set_subplotspec(gs[y, x]) lst_new_axes = [] n_old = num_axes-num_new_plots+1 for i in range(num_new_plots): diff --git a/testing/test_plots.py b/testing/test_plots.py index 4603d97..080b08d 100644 --- a/testing/test_plots.py +++ b/testing/test_plots.py @@ -57,7 +57,7 @@ def test_plot_prop(): nplt.node_attributes_distribution( net, ["betweenness", "attr", "out-degree"], colors=["r", "g", "b"], - show=True) + show=False) @pytest.mark.mpi_skip @@ -80,21 +80,22 @@ def test_draw_network_options(): # restrict nodes nplt.draw_network(net, ncolor="g", nshape='s', ecolor="b", - restrict_targets=[1, 2, 3], show=False) + restrict_targets=[1, 2, 3], curved_edges=True, show=False) nplt.draw_network(net, restrict_nodes=list(range(10)), fast=True, show=False) nplt.draw_network(net, restrict_targets=[4, 5, 6, 7, 8], show=False) - nplt.draw_network(net, restrict_sources=[4, 5, 6, 7, 8], show=False) + nplt.draw_network(net, restrict_sources=[4, 5, 6, 7, 8], simple_nodes=True, + show=False) # colors and sizes for fast in (True, False): - maxns = 50 if fast else 10 - minns = 5 if fast else 1 - maxes = 2 if fast else 10 - mines = 0.2 if fast else 1 + maxns = 100 if fast else 20 + minns = 10 if fast else 2 + maxes = 2 if fast else 20 + mines = 0.2 if fast else 2 nplt.draw_network(net, ncolor="r", nalpha=0.5, ecolor="#999999", ealpha=0.5, nsize="in-degree", max_nsize=maxns, @@ -113,11 +114,45 @@ def test_draw_network_options(): nplt.draw_network(net, simple_nodes=True, ncolor="k", decimate_connections=-1, axis=ax, show=False) - nplt.draw_network(net, simple_nodes=True, ncolor="r", nsize=2, + nplt.draw_network(net, simple_nodes=True, ncolor="r", nsize=20, restrict_nodes=list(range(10)), esize='weight', ecolor="b", fast=True, axis=ax, show=False) +@pytest.mark.mpi_skip +def test_group_plot(): + ''' Test plotting with a Network and group colors ''' + gsize = 5 + + g1 = nngt.Group(gsize) + g2 = nngt.Group(gsize) + + s = nngt.Structure.from_groups({"1": g1, "2": g2}) + + positions = np.concatenate(( + nngt._rng.uniform(-5, -2, size=(gsize, 2)), + nngt._rng.uniform(2, 5, size=(gsize, 2)))) + + g = nngt.SpatialGraph(2*gsize, structure=s, positions=positions) + + nngt.generation.connect_groups(g, g1, g1, "erdos_renyi", edges=5) + nngt.generation.connect_groups(g, g1, g2, "erdos_renyi", edges=5) + nngt.generation.connect_groups(g, g2, g2, "erdos_renyi", edges=5) + nngt.generation.connect_groups(g, g2, g1, "erdos_renyi", edges=5) + + g.new_edge(6, 6, self_loop=True) + + nplt.draw_network(g, ncolor="group", ecolor="group", show_environment=False, + fast=True, show=False) + + nplt.draw_network(g, ncolor="group", ecolor="group", max_nsize=0.4, + esize=0.3, show_environment=False, show=False) + + nplt.draw_network(g, ncolor="group", ecolor="group", max_nsize=0.4, + esize=0.3, show_environment=False, curved_edges=True, + show=False) + + @pytest.mark.mpi_skip def test_library_plot(): ''' Check that plotting with the underlying backend library works ''' @@ -215,6 +250,21 @@ def test_plot_spatial_alpha(): nalpha=0.5, esize=0.1 + 3*fast, fast=fast) +@pytest.mark.mpi_skip +def test_annotations(): + num_nodes = 5 + positions = nngt._rng.uniform(-10, 10, (num_nodes, 2)) + + g = nngt.generation.erdos_renyi(edges=10, nodes=num_nodes, + positions=positions) + + g.new_node_attribute("name", "string", values=["a", "b", "c", "d", "e"]) + + nplt.draw_network(g, annotate=False, show=False) + nplt.draw_network(g, show=False) + nplt.draw_network(g, annotations="name", show=False) + + if __name__ == "__main__": test_plot_config() test_plot_prop() @@ -222,3 +272,5 @@ if __name__ == "__main__": test_library_plot() test_hive_plot() test_plot_spatial_alpha() + test_group_plot() + test_annotations() -- 2.32.0
builds.sr.ht <builds@sr.ht>NNGT/patches/.build.yml: SUCCESS in 31m3s [Plot: node annotations and improved plots][0] v3 from [~tfardet][1] [0]: https://lists.sr.ht/~tfardet/nngt-developers/patches/25530 [1]: mailto:tanguyfardet@protonmail.com ✓ #601471 SUCCESS NNGT/patches/.build.yml https://builds.sr.ht/~tfardet/job/601471