# Source code for src.utils.plot

import os
import re
import random
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from matplotlib import rcParams
from matplotlib.font_manager import FontProperties
from typing import List, Dict, Literal, Tuple

# Names exported by ``from src.utils.plot import *``.
__all__ = [
    "get_fig", "plot_resilience", "EqualizeNormalize", "plotOD",
    "clear_svg", "load_font", "use_chinese_font", "merge_axes",
]


def get_fig(
    RN,
    CN,
    FW=None,
    FH=None,
    AW=None,
    AH=None,
    A_ratio=1.0,
    LM=3,
    RM=3,
    TM=3,
    BM=3,
    HS=None,
    VS=None,
    dpi=300,
    fontsize=7,
    lw=0.5,
    gridspec=False,
    font_family="DejaVu Sans",
    **kwargs,
):
    """Create a figure holding an ``RN x CN`` grid of axes sized in physical units.

    The figure/axes size is resolved from the first size argument given, in
    this order:

    - ``AW`` (axes width, cm): the figure width follows; the axes height is
      ``AH`` (cm) if given, else ``AW / A_ratio``.
    - ``AH`` (axes height, cm): ``AW = AH * A_ratio``.
    - ``FW`` (figure width, cm): the axes width follows; the axes height comes
      from ``FH`` (cm) if given, else ``AW / A_ratio``.
    - ``FH`` (figure height, cm): ``AW = AH * A_ratio``.

    Margins ``LM/RM/TM/BM`` and gaps ``HS/VS`` are multiples of ``fontsize``
    (points); ``HS``/``VS`` default to the left/top margin.

    Side effect: updates the global ``plt.rcParams`` (font family and sizes,
    line widths, PDF/SVG font output).

    Returns:
        ``(figinfo, fig, gs)`` if ``gridspec`` is true, otherwise
        ``(figinfo, fig, axes)`` where ``axes`` is a flat array of axes (or a
        one-element list for a 1x1 grid). ``figinfo`` holds the resolved sizes
        in inches, their figure-relative fractions (``r_*``) and the margin
        strips as ``(left, bottom, width, height)`` figure-fraction boxes.

    Raises:
        ValueError: if none of ``FW``, ``FH``, ``AW``, ``AH`` is given, or if the
            margins/gaps leave no positive room for the axes.
    """
    # Validate and resolve the layout before touching global rcParams so that an
    # invalid call leaves no side effects behind.
    if FW is None and FH is None and AW is None and AH is None:
        raise ValueError("get_fig needs at least one of FW, FH, AW or AH (in cm)")

    # Margins and gaps are expressed in font sizes; convert them to inches.
    em = fontsize / 72
    LM, RM, TM, BM = LM * em, RM * em, TM * em, BM * em
    HS = HS * em if HS is not None else LM
    VS = VS * em if VS is not None else TM

    if AW is not None:
        AW /= 2.54
        FW = LM + CN * AW + (CN - 1) * HS + RM
        AH = AH / 2.54 if AH is not None else AW / A_ratio
        FH = TM + RN * AH + (RN - 1) * VS + BM
    elif AH is not None:
        AH /= 2.54
        AW = AH * A_ratio
        FH = TM + RN * AH + (RN - 1) * VS + BM
        FW = LM + CN * AW + (CN - 1) * HS + RM
    elif FW is not None:
        FW /= 2.54
        AW = (FW - LM - RM - (CN - 1) * HS) / CN
        if FH is not None:
            FH /= 2.54
            AH = (FH - TM - BM - (RN - 1) * VS) / RN
        else:
            AH = AW / A_ratio
            FH = TM + RN * AH + (RN - 1) * VS + BM
    else:  # only FH given
        FH /= 2.54
        AH = (FH - TM - BM - (RN - 1) * VS) / RN
        AW = AH * A_ratio
        FW = LM + CN * AW + (CN - 1) * HS + RM

    if AW <= 0 or AH <= 0:
        raise ValueError(
            f"margins and gaps leave no room for the axes (AW={AW:.3g} in, "
            f"AH={AH:.3g} in); enlarge the figure or reduce LM/RM/TM/BM/HS/VS"
        )

    plt.rcParams["font.family"] = font_family
    for key in (
        "font.size",
        "axes.labelsize",
        "axes.titlesize",
        "xtick.labelsize",
        "ytick.labelsize",
        "legend.fontsize",
        "figure.titlesize",
    ):
        plt.rcParams[key] = fontsize
    for key in (
        "lines.linewidth",
        "axes.linewidth",
        "xtick.major.width",
        "ytick.major.width",
        "xtick.minor.width",
        "ytick.minor.width",
        "grid.linewidth",
    ):
        plt.rcParams[key] = lw
    # Keep text as editable text in vector output (TrueType in PDF, <text> in SVG).
    plt.rcParams["pdf.fonttype"] = 42
    plt.rcParams["svg.fonttype"] = "none"

    figinfo = dict(
        RN=RN,
        CN=CN,
        FW=FW,
        FH=FH,
        AW=AW,
        AH=AH,
        LM=LM,
        RM=RM,
        TM=TM,
        BM=BM,
        HS=HS,
        VS=VS,
        r_AW=AW / FW,
        r_AH=AH / FH,
        r_HS=HS / FW,
        r_VS=VS / FH,
        r_LM=LM / FW,
        r_RM=RM / FW,
        r_TM=TM / FH,
        r_BM=BM / FH,
        fontsize=fontsize,
        lw=lw,
        top_box=(LM / FW, 1 - TM / FH, 1 - (LM + RM) / FW, TM / FH),
        bottom_box=(LM / FW, 0, 1 - (LM + RM) / FW, BM / FH),
        right_box=(1 - RM / FW, BM / FH, RM / FW, 1 - (TM + BM) / FH),
        left_box=(0, BM / FH, LM / FW, 1 - (TM + BM) / FH),
    )

    # wspace/hspace are fractions of the average axes width/height.
    spacing = dict(
        left=LM / FW,
        right=1 - RM / FW,
        top=1 - TM / FH,
        bottom=BM / FH,
        wspace=HS / AW,
        hspace=VS / AH,
    )
    if gridspec:
        fig = plt.figure(figsize=(FW, FH), dpi=dpi, **kwargs)
        gs = fig.add_gridspec(RN, CN, **spacing)
        return figinfo, fig, gs

    fig, axes = plt.subplots(RN, CN, figsize=(FW, FH), dpi=dpi, **kwargs)
    fig.subplots_adjust(**spacing)
    axes = axes.ravel() if RN * CN > 1 else [axes]
    return figinfo, fig, axes
def plot_resilience(
    f: callable,
    extent=(0, 1, 0, 1),
    gridnum=(1000, 1000),
    cmap=None,
    norm=None,
    ax=None,
    reset_xylim=True,
    lw=1,
):
    """
    Plot the resilience of a 2D function.

    The function is drawn as a background image, and its zero set ``f(x, y) = 0``
    is overlaid: solid black where ``dz/dy < 0`` (stable equilibria if ``f`` is
    read as ``dy/dt``), dashed gray elsewhere.

    - f: 2D function, such as lambda x, y: y - 3*y**2 - y**3 + x*y**3
    - extent: Function domain (xmin, xmax, ymin, ymax)
    - gridnum: Grid resolution (xnum, ynum)
    - cmap / norm: colormap and normalization of the background image
      (defaults: blue-white-red map with a symmetric-log norm on [-1, 1])
    - ax: target axes (default: ``plt.gca()``)
    - reset_xylim: pad the axis limits by 5% of the extent on each side
    - lw: line width of the zero-set curves

    ---
    Example:
    >>> f = lambda x, y: y - 3*y**2 - y**3 + x*y**3
    >>> plot_resilience(f, ax=axes[0], extent=(0, 5, 0, 5))
    >>> f = lambda x, y: np.sin(np.sqrt(x**2 + y ** 2))
    >>> plot_resilience(f, ax=axes[1], extent=(-4*np.pi, 4*np.pi, -4*np.pi, 4*np.pi))
    >>> f = lambda x, y: np.sin(np.sqrt(x ** 2 + y ** 2)) * x - y * np.cos(np.sqrt(x ** 2 + y ** 2))
    >>> plot_resilience(f, ax=axes[2], extent=(-4*np.pi, 4*np.pi, -4*np.pi, 4*np.pi))
    """
    from matplotlib.collections import Collection
    from matplotlib.path import Path

    if cmap is None:
        cmap = mcolors.LinearSegmentedColormap.from_list(
            "my_bwr", ["#74b9ff", "white", "#ff7675"]
        )
    if norm is None:
        norm = mcolors.SymLogNorm(linthresh=0.01, linscale=0.01, vmin=-1, vmax=1)
    if ax is None:
        ax = plt.gca()

    def draw(points, sign):
        xs, ys = np.asarray(points).T
        if sign < 0:
            ax.plot(xs, ys, "black", linewidth=lw, linestyle="-", solid_capstyle="round")
        else:
            ax.plot(
                xs, ys, "gray", linewidth=lw, linestyle=(3, (3, 1)), solid_capstyle="round"
            )

    def grid_index(value, lo, hi, num):
        # Clamp so float round-off at the domain edge never yields -1 (which
        # would silently wrap to the opposite edge) or num.
        idx = int(np.floor((value - lo) / (hi - lo) * (num - 1)))
        return min(max(idx, 0), num - 1)

    def subpaths(path):
        """Yield the vertices of every connected polyline contained in ``path``."""
        verts = np.asarray(path.vertices, dtype=float)
        if len(verts) == 0:
            return
        if path.codes is None:
            yield verts
            return
        codes = np.asarray(path.codes)
        bounds = list(np.flatnonzero(codes == Path.MOVETO)) + [len(verts)]
        if bounds[0] != 0:
            bounds.insert(0, 0)
        for start, stop in zip(bounds[:-1], bounds[1:]):
            piece = verts[start:stop].copy()
            if codes[stop - 1] == Path.CLOSEPOLY:
                # Matplotlib ignores the CLOSEPOLY vertex; close onto the start.
                piece[-1] = piece[0]
            yield piece

    xmin, xmax, ymin, ymax = extent
    xnum, ynum = gridnum
    x, y = np.meshgrid(np.linspace(xmin, xmax, xnum), np.linspace(ymin, ymax, ynum))
    z = f(x, y)
    ax.imshow(
        z,
        origin="lower",
        aspect="auto",
        cmap=cmap,
        norm=norm,
        zorder=0,
        extent=(xmin, xmax, ymin, ymax),
    )

    # The contour artist is only used to extract the z = 0 level set; remove it.
    contours = ax.contour(x, y, z, levels=[0], linewidths=0)
    if isinstance(contours, Collection):
        # matplotlib >= 3.8: ContourSet is itself a Collection with one path per
        # level (``.collections`` is deprecated and removed in 3.10).
        paths = list(contours.get_paths())
        contours.remove()
    else:
        # Older matplotlib: one PathCollection per level.
        paths = [p for coll in contours.collections for p in coll.get_paths()]
        for coll in contours.collections:
            coll.remove()

    # Derivative along axis 0 of the grid, i.e. along y.
    dzdy = np.gradient(z)[0]
    for path in paths:
        for piece in subpaths(path):
            saved, sign = [], None
            for x0, y0 in piece:
                saved.append((x0, y0))
                slope = dzdy[
                    grid_index(y0, ymin, ymax, ynum), grid_index(x0, xmin, xmax, xnum)
                ]
                if sign is None:
                    sign = np.sign(slope)
                if np.sign(slope * sign) < 0:
                    # Stability flips here: close this run and start the next one
                    # at the same vertex so the two styles join without a gap.
                    draw(saved, sign)
                    sign *= -1
                    saved = [(x0, y0)]
            if len(saved) > 1:
                draw(saved, sign)

    if reset_xylim:
        pad_x = (xmax - xmin) * 0.05
        pad_y = (ymax - ymin) * 0.05
        ax.set_xlim(xmin - pad_x, xmax + pad_x)
        ax.set_ylim(ymin - pad_y, ymax + pad_y)
class EqualizeNormalize(mcolors.Normalize):
    """Normalize by distribution rather than raw value (histogram equalization).

    Values are mapped through the empirical CDF of ``samples``: ``vmin`` maps
    to 0, ``vmax`` to 1, and equal-width output intervals cover roughly equal
    numbers of samples.
    """

    def __init__(self, samples, clip=False, bins=256):
        """
        Args:
            samples: array-like of reference values defining the mapping.
                Masked and non-finite (NaN/inf) entries are ignored.
            clip: forwarded to ``matplotlib.colors.Normalize``.
            bins: number of histogram bins used to approximate the CDF.

        Raises:
            ValueError: if ``samples`` contains no finite, unmasked value.
        """
        valid = np.ma.masked_invalid(np.ma.asarray(samples, dtype=float)).compressed()
        if valid.size == 0:
            raise ValueError("EqualizeNormalize needs at least one finite sample")
        super().__init__(vmin=valid.min(), vmax=valid.max(), clip=clip)
        hist, bin_edges = np.histogram(valid, bins=bins, range=(self.vmin, self.vmax))
        # Empirical CDF evaluated at every bin edge: cdf[k] is the fraction of
        # samples left of bin_edges[k], so it starts at 0 and ends at 1.
        cdf = np.concatenate(([0.0], np.cumsum(hist, dtype=float)))
        cdf /= cdf[-1]
        self.bin_edges = bin_edges
        self.cdf = cdf

    def __call__(self, value, clip=False):
        """Map ``value`` (scalar or array) into [0, 1] through the sample CDF.

        Values outside [vmin, vmax] saturate at 0 / 1. Input masks are kept and
        NaN results are masked. ``clip`` is accepted for API compatibility only.
        """
        data = np.ma.asarray(value)
        raw = np.asarray(np.ma.getdata(data), dtype=float)
        mapped = np.interp(raw.ravel(), self.bin_edges, self.cdf).reshape(raw.shape)
        return np.ma.masked_array(mapped, mask=np.ma.getmaskarray(data) | np.isnan(mapped))

    def inverse(self, value):
        """Map normalized values in [0, 1] back to data values (inverse CDF)."""
        data = np.ma.asarray(value)
        raw = np.asarray(np.ma.getdata(data), dtype=float)
        mapped = np.interp(raw.ravel(), self.cdf, self.bin_edges).reshape(raw.shape)
        return np.ma.masked_array(mapped, mask=np.ma.getmaskarray(data))
# Module-level custom colormaps.

# Rainbow-like palette: red -> orange -> yellow -> green -> cyan -> blue -> purple.
myhsv = mcolors.LinearSegmentedColormap.from_list(
    name="myhsv",
    colors=["#d63031", "#e17055", "#fdcb6e", "#00b894", "#00cec9", "#0984e3", "#6c5ce7"],
    N=256,
)

# Diverging palette: blue -> white -> red.
mybwr = mcolors.LinearSegmentedColormap.from_list(
    name="mybwr",
    colors=["#0984e3", "#ffffff", "#d63031"],
    N=256,
)

# Blue -> red -> yellow ramp with a non-linear (gamma=5.0) spacing; the same
# colours and gamma are used by plotOD.
myhot = mcolors.LinearSegmentedColormap.from_list(
    name="myhot",
    colors=["#0308F8", "#FD0B1B", "#ffff00"],
    gamma=5.0,
)
def plotOD(
    ax,
    source: List[str],
    destination: List[str],
    flow: List[float],
    location: Dict[str, Tuple[float, float]],
    linetype: Literal[
        "straight", "parabola", "rotated_parabola", "projected_parabola"
    ] = "straight",
    N=100,
    zorder=10,
    **kwargs,
):
    """Plot origin-destination (OD) flows as curves on ``ax``.

    Each triple ``(source[i], destination[i], flow[i])`` draws one curve from
    ``location[source[i]]`` to ``location[destination[i]]``. Flows are
    normalized by their distribution (:class:`EqualizeNormalize`); the
    normalized value ``w`` sets the colour, line width ``0.05 + 0.1 w``, alpha
    ``0.4 + 0.4 w`` and z-order ``zorder + w``. Pairs with an endpoint missing
    from ``location`` are skipped, and a warning is printed.

    Args:
        ax: target axes.
        source, destination: node names, aligned with ``flow``.
        flow: flow values (a list or a pandas Series).
        location: node name -> (x, y).
        linetype:
            - "straight": straight segments.
            - "parabola": vertical bulge ``4 t (1 - t) w * y_extent * scale``,
              where ``y_extent`` is the y-range of ``ax`` when the call starts
              and ``scale`` is ``kwargs["scale"]`` (default 0.1).
            - "rotated_parabola": bulge perpendicular to the segment; kwargs
              ``adjust_up`` / ``adjust_down`` flip leftward (dx < 0) /
              rightward (dx > 0) segments so the bulges point consistently.
            - "projected_parabola": curve bent relative to a centre ``p0``
              (kwarg; default the centre of the current axes limits) with depth
              ``D`` (kwarg, default 10).
        N: number of points per curve.
        zorder: base z-order of the curves.

    Returns:
        ax

    Raises:
        ValueError: if ``linetype`` is not one of the values above.
    """
    if linetype not in ("straight", "parabola", "rotated_parabola", "projected_parabola"):
        raise ValueError(f"Invalid linetype: {linetype}")

    cmap = mcolors.LinearSegmentedColormap.from_list(
        "cmap", ["#0308F8", "#FD0B1B", "#ffff00"], gamma=5.0
    )
    # np.asarray accepts both a plain list (as annotated) and a pandas Series.
    norm = EqualizeNormalize(np.asarray(flow, dtype=float))
    t = np.linspace(0, 1, N)

    ignored = (set(source) | set(destination)) - set(location)
    if ignored:
        print(f"\033[33mWarning: {ignored} are not in location\033[0m")

    # Resolve geometry parameters once: plotting may autoscale the axes, so
    # re-reading the limits per curve would make the scale drift between curves.
    if linetype == "parabola":
        bulge = np.diff(ax.get_ylim())[0] * kwargs.get("scale", 0.1)
    elif linetype == "projected_parabola":
        p0 = np.array(
            kwargs.get("p0", [np.mean(ax.get_xlim()), np.mean(ax.get_ylim())])
        )[:, np.newaxis]
        D = kwargs.get("D", 10)
    up = np.array([[0], [1]])

    for s, d, f in zip(source, destination, flow):
        if s not in location or d not in location:
            continue
        w = float(norm(f))
        p1 = np.array(location[s], dtype=float)[:, np.newaxis]  # (2, 1)
        p2 = np.array(location[d], dtype=float)[:, np.newaxis]  # (2, 1)
        segment = p1 * (1 - t) + p2 * t  # (2, N)

        if linetype == "straight":
            xy = segment
        elif linetype == "parabola":
            xy = segment + up * (4 * t * (1 - t) * w * bulge)
        elif linetype == "rotated_parabola":
            C, S = (p2 - p1)[:, 0]
            # Rotate/scale the unit chord [-1, 1] onto p1 -> p2.
            A = np.array([[C, -S], [S, C]]) / 2
            if kwargs.get("adjust_up", None) and (C < 0):
                A = -A
            if kwargs.get("adjust_down", None) and (C > 0):
                A = -A
            height = 0.5 * w
            xy = A @ np.array([2 * t - 1, height * 4 * t * (1 - t)]) + 0.5 * (p1 + p2)
        else:  # "projected_parabola"
            xy = p0 + D * (segment - p0) / (D - 4 * t * (1 - t) * (w + 1))

        ax.plot(
            *xy,
            lw=0.05 + 0.1 * w,
            alpha=0.4 + 0.4 * w,
            color=cmap(w),
            zorder=zorder + w,
        )
    return ax
def _expand_font_shorthand(value):
    """Expand a CSS ``font`` shorthand (e.g. ``"italic 700 9.8px 'Arial', sans-serif"``)
    into ``(attribute, value)`` pairs; returns ``[]`` if no font size is found."""
    match = re.search(
        r"(?:^|\s)(\d*\.?\d+(?:px|pt|em|rem|%))(?:/(\S+))?(?:\s+(.*))?$", value
    )
    if match is None:
        return []
    attrs = []
    # Tokens before the size are the optional style / variant / weight keywords.
    for token in value[: match.start(1)].split():
        token = token.lower()
        if token in ("italic", "oblique"):
            attrs.append(("font-style", token))
        elif token == "small-caps":
            attrs.append(("font-variant", token))
        elif token in ("bold", "bolder", "lighter") or re.fullmatch(r"[1-9]00", token):
            attrs.append(("font-weight", token))
    attrs.append(("font-size", match.group(1)))
    if match.group(2):
        attrs.append(("line-height", match.group(2)))
    families = match.group(3)
    if families:
        # Keep only the first family in the list, without quotes.
        first = families.split(",")[0].strip().strip("'\"")
        if first:
            attrs.append(("font-family", first))
    return attrs


def _svg_style_to_attributes(style):
    """Turn an inline CSS ``style`` string into ``(attribute, value)`` pairs.

    Every ``name: value`` declaration becomes one pair; the ``font`` shorthand
    is expanded by :func:`_expand_font_shorthand`.
    """
    attrs = []
    for declaration in style.split(";"):
        name, sep, value = declaration.partition(":")
        name, value = name.strip().lower(), value.strip()
        if not sep or not name or not value:
            continue
        if name == "font":
            attrs.extend(_expand_font_shorthand(value))
        else:
            attrs.append((name, value))
    return attrs


def clear_svg(path, debug=False):
    """
    Rewrite the text styling of a matplotlib SVG file in place for PowerPoint.

    SVG files generated by matplotlib may use syntax like
    <text style="font: 9.8px 'Arial'; text-anchor: middle" x="80.307802" y="193.900483">2020-02-02</text>,
    but PowerPoint does not recognize the shorthand `font: 9.8px 'Arial';` and only
    supports forms such as `font-family: 'Arial'; font-size: 9.8px;`. This function
    therefore converts the ``style`` attribute of every <text>/<tspan> into
    presentation attributes:

    - the ``font`` shorthand becomes font-style, font-variant, font-weight
      (keywords or 100-900), font-size, line-height and font-family (first
      family only);
    - every other declaration (e.g. text-anchor, fill, opacity) becomes an
      attribute of the same name, so text colour and transparency are kept.

    lxml is used when available; otherwise the standard-library ElementTree
    is used (the DOCTYPE is not preserved in that case).

    - path: SVG file to rewrite.
    - debug: print each original style string and the resulting attributes.
    """
    try:
        from lxml import etree
    except ImportError:
        import xml.etree.ElementTree as etree

        # Without these the stdlib writer emits "ns0:"-prefixed element names.
        etree.register_namespace("", "http://www.w3.org/2000/svg")
        etree.register_namespace("xlink", "http://www.w3.org/1999/xlink")

    svg_ns = "{http://www.w3.org/2000/svg}"
    tree = etree.parse(path)
    root = tree.getroot()
    for text in root.findall(f".//{svg_ns}tspan") + root.findall(f".//{svg_ns}text"):
        style = text.attrib.pop("style", "")
        for name, value in _svg_style_to_attributes(style):
            text.set(name, value)
        if debug:
            print(f'"{style}" -> "{text.attrib}"')
    tree.write(path, encoding="utf-8", xml_declaration=True)
def load_font(
    url="https://ghp.ci/https://github.com/notofonts/noto-cjk/blob/main/Sans/SubsetOTF/SC/NotoSansSC-Regular.otf",
    font_path="/tmp/SimHei.otf",
    timeout=60,
):
    """
    Download (once) a Simplified-Chinese font and make it matplotlib's default sans-serif.

    The font is cached at ``font_path`` and reused on later calls. It is
    registered with matplotlib's font manager and prepended to
    ``font.sans-serif``; ``axes.unicode_minus`` is disabled so minus signs are
    drawn with a plain hyphen-minus.

    Usage with the returned FontProperties:
    - plt.title("Example Chart: Square of Numbers", fontproperties=font, size=15)
    - plt.xlabel("Number", fontproperties=font, size=12)
    - plt.ylabel("Square", fontproperties=font, size=12)

    - url: where to download the font file from.
    - font_path: local cache path of the font file.
    - timeout: HTTP timeout in seconds.

    Raises:
        RuntimeError: if the download does not return HTTP 200.
    """
    import requests
    from matplotlib import font_manager
    from matplotlib.font_manager import FontProperties

    if not os.path.exists(font_path):
        response = requests.get(url, timeout=timeout)
        if response.status_code != 200:
            raise RuntimeError(
                f"Failed to download the font (HTTP {response.status_code})."
            )
        # Write under a temporary name and then rename, so an interrupted write
        # never leaves a truncated file that later calls would reuse.
        partial_path = font_path + ".part"
        with open(partial_path, "wb") as f:
            f.write(response.content)
        os.replace(partial_path, font_path)

    font = FontProperties(fname=font_path)
    # font.sans-serif takes family *names*, not file paths: register the file so
    # its family name can be resolved, then put that name first.
    font_manager.fontManager.addfont(font_path)
    name = font.get_name()
    plt.rcParams["font.family"] = "sans-serif"
    plt.rcParams["font.sans-serif"] = [name] + [
        family for family in plt.rcParams["font.sans-serif"] if family != name
    ]
    plt.rcParams["axes.unicode_minus"] = False
    return font
def _overlapping_area(bbox1, bbox2):
    """Return the area of the intersection of two boxes (0 if they are disjoint).

    Boxes are objects with ``x0``, ``x1``, ``y0``, ``y1`` attributes (such as
    matplotlib ``Bbox``), assumed to satisfy ``x0 <= x1`` and ``y0 <= y1``.
    """
    width = min(bbox1.x1, bbox2.x1) - max(bbox1.x0, bbox2.x0)
    height = min(bbox1.y1, bbox2.y1) - max(bbox1.y0, bbox2.y0)
    return max(0, width) * max(0, height)


def _distance_to_box(point, box):
    """Return the Euclidean distance from ``point`` (x, y) to ``box`` (0 if inside)."""
    nearest_x = max(box.x0, min(point[0], box.x1))
    nearest_y = max(box.y0, min(point[1], box.y1))
    return np.sqrt((point[0] - nearest_x) ** 2 + (point[1] - nearest_y) ** 2)


def adjust_text(texts, ax, step=0.01, max_iterations=100, mode="xy"):
    """Nudge overlapping text labels apart with a simple randomized force model.

    Each iteration picks a random overlapping pair ``(i, j)`` and moves text
    ``j`` by ``step * (F1 + F2 + F3)``:

    - F1 repels ``j`` from ``i`` (difference of their positions);
    - F2 weakly pulls ``j`` back toward its original position, scaled by the
      pixel distance from that anchor to ``j``'s rendered extent;
    - F3 is standard-normal noise.

    Overlaps are measured between rendered window extents. The loop stops as
    soon as no overlap remains, or after ``max_iterations``.

    Args:
        texts: sequence of matplotlib Text artists; positions are updated in place.
        ax: axes whose figure canvas provides the renderer.
        step: factor applied to the resultant force.
        max_iterations: maximum number of moves.
        mode: "xy" to move freely, "x" / "y" to restrict moves to that axis.
    """
    texts = list(texts)
    if len(texts) < 2:  # nothing can overlap
        return
    renderer = ax.figure.canvas.get_renderer()
    raw_pos = np.array([text.get_position() for text in texts], dtype=float)
    cur_pos = raw_pos.copy()
    # overlap[i, j]: intersection area (pixels^2) of the extents of texts i and j.
    overlap = np.zeros((len(texts), len(texts)))

    def update(i, upper_only=False):
        bbox_i = texts[i].get_window_extent(renderer=renderer)
        for j, other in enumerate(texts):
            if j == i or (upper_only and j <= i):
                continue
            bbox_j = other.get_window_extent(renderer=renderer)
            overlap[i, j] = overlap[j, i] = _overlapping_area(bbox_i, bbox_j)

    def normalize(v):
        return v / np.linalg.norm(v).clip(1e-6)

    for i in range(len(texts)):
        update(i, upper_only=True)

    for _ in range(max_iterations):
        pairs = np.argwhere(overlap)
        if len(pairs) == 0:  # also covers labels that never overlapped
            break
        i, j = random.choice(pairs)
        # Repulsion from i to j.
        F1 = cur_pos[j] - cur_pos[i]
        # Attraction from raw_pos[j] to j; the distance is measured in display
        # pixels, the same units as the window extent.
        bbox = texts[j].get_window_extent(renderer=renderer)
        anchor_px = texts[j].get_transform().transform(raw_pos[j])
        F2 = normalize(raw_pos[j] - cur_pos[j]) * _distance_to_box(anchor_px, bbox) * 0.001
        # Random perturbation.
        F3 = np.random.randn(2)
        F = F1 + F2 + F3
        if mode == "x":
            F[1] = 0
        if mode == "y":
            F[0] = 0
        cur_pos[j] += F * step
        texts[j].set_position(cur_pos[j])
        update(j)
def use_chinese_font(fontpath="/usr/share/fonts/opentype/noto/NotoSansCJK-Regular.ttc"):
    """Make the font file at ``fontpath`` matplotlib's default family (e.g. for CJK text).

    The file is first registered with matplotlib's font manager, so its family
    name resolves even if the font lives outside the directories matplotlib
    scanned when it built its font cache. ``axes.unicode_minus`` is disabled so
    minus signs are drawn with a plain hyphen-minus.

    Returns:
        The ``FontProperties`` for ``fontpath``.
    """
    from matplotlib import font_manager

    font_manager.fontManager.addfont(fontpath)
    font_prop = FontProperties(fname=fontpath)
    rcParams["font.family"] = font_prop.get_name()
    rcParams["axes.unicode_minus"] = False
    return font_prop
def merge_axes(axes):
    """Replace several axes with one new axes spanning their combined bounding box.

    Args:
        axes: non-empty iterable of Axes, presumably all from the same figure
            (e.g. a slice of the flat ``axes`` returned by ``get_fig``).

    Returns:
        The new Axes, added to the figure of the first input axes. The input
        axes are removed.

    Raises:
        ValueError: if ``axes`` is empty.
    """
    axes = list(axes)  # allow one-shot iterables such as generators
    if not axes:
        raise ValueError("merge_axes requires at least one axes")
    boxes = [ax.get_position() for ax in axes]
    xmin = min(box.x0 for box in boxes)
    xmax = max(box.x1 for box in boxes)
    ymin = min(box.y0 for box in boxes)
    ymax = max(box.y1 for box in boxes)
    # Use the axes' own figure: plt.gca() may belong to another figure (or
    # create a new one).
    merged_ax = axes[0].figure.add_axes([xmin, ymin, xmax - xmin, ymax - ymin])
    for ax in axes:
        ax.remove()  # Remove the old axes to avoid overlap
    return merged_ax