Source code for plot

# -*- coding: utf-8 -*-
"""Plotting functions."""

__author__ = "Matteo Levantino"
__contact__ = "matteo.levantino@esrf.fr"
__licence__ = "MIT"
__copyright__ = "ESRF - The European Synchrotron, Grenoble, France"
__date__ = "01/09/2021"


import numpy as np
import matplotlib.pyplot as plt
from math import ceil, floor
from matplotlib.colors import SymLogNorm  # not available in matplotlib 3.1
from pathlib import Path
from txs.utils import t2str

plt.ion()

qlabel = r"q ($\AA^{-1}$)"


def plot_raw_diffs():
    pass


def plot_raw_abs():
    pass


def plot_diffs_and_abs():
    pass


[docs] def plot_diffs(data, sel=None, every=None, hide_lines=False, cmap=None, yscale='diff', lbls=None, error_bars=False, title='short', plot_ref=True, plot_abs_mean=False, diff_plus_ref=False, tpause=None, tdigits=2, legend_ncols='auto', legend_nlines=None, xlabel=qlabel, xlim=(0, None), ylim=None, figsize=None, fig=None, ax=None, return_fig=False, return_ax=False, return_lines=False): """ Plot time-resolved difference patterns (averaged over different delays). Difference patterns can be hidden or shown upon a click on the corresponding legend entry. Parameters ------------ data : dict or tuple Data reduction result obtained from the txs.datared.datared(). sel : list or slice or None, optional List of time-delays (strings) for the plot. Only difference patterns corresponding to the time-delays list will be displayed in the plot. If None (default), all difference patterns will be displayed. every : int or None, optional Spacing at which time-delays for the plot will be selected. If 'sel' is not None, 'every' will be applied to 'sel' rather than to the whole list of time-delays. If None (default), all difference patterns will be displayed. hide_lines : bool, optional If True, all selected curves are hidden. Default is False. cmap : str or None, optional Name of matplotlib colormap. yscale : str, optional If 'diff' (default), scattering intensity differences are plotted. If 'qdiff', scattering intensity difference times q-values are plotted. lbls : list or None, optional List of labels to be used for the plot legend. Default is None. error_bars : bool, optional If True, patterns are plotted with error bars. Default is False. title : str or None, optional Figure title. If None, the full data folder is used as title. If 'short' (default), the data folder is used as title, but only the last 2 subfolders are kept and the other parent subfolders are removed. remove_folder_parent plot_ref : bool, optional If True (default), the difference pattern corresponding to the reference time-delay is plotted. plot_abs_mean : bool, optional If True, the average absolute pattern (average over all data, irrespective of the time-delay) is plotted on a separate (top) panel. Default is False. diff_plus_ref : bool, optional If True, will plot the absolute patterns for each time delay calculated as average of ref patterns + differential patterns tpause : float or None, optional Interval time after which the figure is updated. tdigits : int, optional Number of digits used for time-delay labels. Default is 2. legend_ncols : int or 'auto' or None If 'auto' (default), the number of columns in the legend is modified automatically and optimized on the basis of the total number of curves. (if nlines is Nones, the legend will use default matplotlib settings) If int, the legend will have 'ncols' columns. If None, the legend will use default matplotlib settings. (same as ncols=1). legend_nlines : int or None If 'auto' (default), the number of lines in the legend is modified automatically and optimized to fit figure height. If None (default), the legend will use default matplotlib settings (no limitations on the number of lines is applied). If int, the legend will have 'nlines' lines. """ if isinstance(data, dict): x = data['q'] t = data['t'] y = data['diff_av'] e = data.get('diff_err', None) if 'azav' in data.keys(): folder = data['azav'].get('folder', None) if 'filt_res' in data.keys(): filt_res = data['filt_res'] else: filt_res = None elif isinstance(data, (list, tuple)): folder = None filt_res = None if len(data) == 3: x, y, t = data elif len(data) == 4: x, y, t, e = data else: raise ValueError("'data' must have length 3 or 4 if array-like.") else: raise TypeError("'data' must be dict or array-like.") if yscale in ['qdiff', 'qdeltai', 'qdi', 'q*deltai']: for k in range(y.shape[1]): y[:, k] = x*y[:, k] ylabel = r'q*$\Delta$I' else: ylabel = r'$\Delta$I' if title is None: title = folder elif title == 'short': title = Path(*Path(folder).parts[-2:]) else: if not isinstance(title, 'str'): raise TypeError("'title' must be None or str.") if lbls is None: lbls = [t2str(tt, digits=tdigits) for tt in t] if filt_res is not None: filt_lbls = ["(%d/%d)" % res for res in filt_res] lbls = ["%s %s" % (l1, l2) for (l1, l2) in zip(lbls, filt_lbls)] if ax is None: if fig is None: fig = plt.figure() else: fig.clear() if plot_abs_mean: ax0, ax = fig.subplots(2, 1, sharex=True) ax0.plot(x, np.median(data['i'], axis=1)) ax0.set_ylabel("Scattering intensity (a.u.)") ax0.set_ylim(0, ) ax0.grid(alpha=0.7) ax0.legend(['median'], loc='upper right') else: ax = fig.subplots(1, 1) elif fig is None: fig = ax.figure() else: ax = fig.gca() if diff_plus_ref: y = y.T + _ref_average(data) y = y.T ret = _plot(x, y, t, e=e, sel=sel, every=every, hide_lines=hide_lines, cmap=cmap, lbls=lbls, error_bars=error_bars, title=title, xlabel=xlabel, ylabel=ylabel, xlim=xlim, ylim=ylim, tpause=tpause, figsize=figsize, fig=fig, ax=ax, return_fig=return_fig, return_ax=return_ax, return_lines=return_lines, legend_ncols=legend_ncols, legend_nlines=legend_nlines) return ret
[docs] def plot_abs(data, sel=None, every=None, hide_lines=False, cmap=None, lbls=None, error_bars=False, title='short', tpause=None, folder_parent=None, xlabel=qlabel, ylabel=None, xlim=(0, None), ylim=(0, None), fig=None, ax=None, return_fig=False, return_ax=False, legend_ncols='auto', legend_nlines=None, figsize=None, map2D=False, map_clim=None, return_lines=False): """ Plot absolute patterns. TO DO: should plot averages absolute patterns diff_av(q, t) + <abs(q, tref)> TO DO: map2D should go to plot_raw_abs and plot_raw_diffs Patterns can be hidden or shown upon a click on the corresponding legend entry. Parameters ---------- data : dict or tuple Data reduction result obtained from the txs.datared.datared(). sel : list or slice or None, optional List of time-delays (strings) for the plot. Only patterns corresponding to the time-delays list will be displayed in the plot. If None (default), all patterns will be displayed. every : int or None, optional Spacing at which time-delays for the plot will be selected. If 'sel' is not None, 'every' will be applied to 'sel' rather than to the whole list of time-delays. If None (default), all patterns will be displayed. hide_lines : bool, optional If True, all selected curves are hidden. Default is False. cmap : str or None, optional Name of matplotlib colormap. lbls : list or None, optional List of labels to be used for the plot legend. Default is None. error_bars : bool, optional If True, patterns are plotted with error bars. Default is False. title : str or None, optional Figure title. If None, the full data folder is used as title. If 'short' (default), the data folder is used as title, but only the last 2 subfolders are kept and the other parent subfolders are removed. tpause : float or None, optional Interval time after which the figure is updated. legend_ncols : int or 'auto' or None If 'auto' (default), the number of columns in the legend is modified automatically and optimized on the basis of the total number of curves. (if nlines is Nones, the legend will use default matplotlib settings) If int, the legend will have 'ncols' columns. If None, the legend will use default matplotlib settings. (same as ncols=1). legend_nlines : int or None If 'auto' (default), the number of lines in the legend is modified automatically and optimized to fit figure height. If None (default), the legend will use default matplotlib settings (no limitations on the number of lines is applied). If int, the legend will have 'nlines' lines. map2D : bool, optionnal If True, will display the patterns as a 2D map (image_number,q) map_clim : tupple or None Ignored if map2D is False """ if isinstance(data, dict): x = data['q'] t = data['delays'] y = data['i'] e = data.get('e', None) folder = data['azav'].get('folder', None) elif isinstance(data, (list, tuple)): folder = None if len(data) == 3: x, y, t = data elif len(data) == 4: x, y, t, e = data else: raise ValueError("'data' must have length 3 or 4 if array-like.") ylabel = 'Scattered intensity (a.u.)' if title is None: title = folder elif title == 'short': title = Path(*Path(folder).parts[-2:]) else: if not isinstance(title, 'str'): raise TypeError("'title' must be None or str.") if lbls is None: lbls = range(y.shape[1]) if ax is None: if fig is None: fig = plt.figure() ax = fig.subplots(1, 1) elif fig is None: fig = ax.figure() ret = _plot(x, y, t, e=e, sel=sel, every=every, hide_lines=hide_lines, cmap=cmap, lbls=lbls, error_bars=error_bars, title=title, xlabel=xlabel, ylabel=ylabel, xlim=xlim, ylim=ylim, tpause=tpause, fig=fig, ax=ax, return_fig=return_fig, return_ax=return_ax, return_lines=return_lines, figsize=figsize, legend_ncols=legend_ncols, legend_nlines=legend_nlines, map2D=map2D, map_clim=map_clim) return ret
[docs] def plot_filt_hist(res, filt='red_chi2', bins=None, fig=None, ax=None): """Plot histogram of filtering parameter ('red_chi2' or 'pts_perc').""" if filt not in ['red_chi2', 'pts_perc']: raise ValueError("'filt' must be 'red_chi2' or 'pts_perc'.") if fig is None: fig = plt.figure() if ax is None: fig.add_subplot(111) ax = plt.gca() y = np.concatenate(res[filt]) ny = len(y) if bins is None: if int(ny/5) > 10: bins = int(ny/5) else: bins = 10 # default for plt.hist() ax.hist(y, bins=bins) ax.set_xlabel(filt) ax.set_ylabel("number of occurrences") ax.set_title("histogram over %d curves (%d bins)" % (ny, bins)) ax.grid(alpha=0.7) if filt+"_max" in res.keys(): if res[filt+"_max"] is not None: ax.legend(["%s = %g" % (filt+"_max", res[filt+"_max"])]) else: ax.legend([filt+"_max"])
[docs] def plot_motor_scan(res, qrange, td, abs_value=True, min_max=False, xscale=1, yscale=1, xlabel='scan motor position', title='auto', fig=None, ax=None): """ Plot average difference signal (averaged over a given q-range) as a function of a scanned motor position. The motor position corresponding to the max value of the averaged difference signal is automatically calculated. Parameters ---------- res : dict Result of txs.datared.datared() with 'scan_motor' != None. qrange : array-like Range of q-values (min, max) to use for signal averaging. td : str, optional Time-delay label corresponding to the data to use for the plot. abs_value : bool, optional If True, the absolute value of difference patterns is used to calculate the best motor position. Default is True. min_max : bool, optional If True, the difference between the max difference signal and the min difference signal is used to calculate the position (NO AVERAGED IS DONE!!!). Default is False. xscale : float, optional Scaling factor for motor position values. Default is 1. yscale : float, optional Scaling factor for signal. Default is 1. xlabel : str, optional Label for signal plot x-axis. Default is "scan motor position".BaseException title : str or None, optional If None, the full data folder is used. If 'auto', the dataset name is used. """ pos = list(res.keys()) res0 = res[pos[0]] tds = np.array(res0['t']) if td not in tds: raise ValueError("'td' must be in: ", tds) td_idx = (tds == td) t = tds[td_idx][0] q = res0['q'] if abs_value and min_max: raise ValueError("'abs_value' and 'min_max' cannot be both True.") sig = [] for p in pos: diff_av = res[p]['diff_av'][:, td_idx][:, 0] qidx = (q >= qrange[0]) & (q <= qrange[1]) if abs_value: sk = np.mean(np.abs(diff_av[qidx]), axis=0) elif min_max: diff_max = np.max(diff_av[qidx], axis=0) diff_min = np.min(diff_av[qidx], axis=0) sk = abs(diff_max - diff_min) else: sk = np.mean(diff_av[qidx], axis=0) sig.append(sk) pos_max = np.array(pos)[sig == max(sig)][0] diff_av_max = res[pos_max]['diff_av'][:, td_idx][:, 0] pos = xscale*np.array(pos) sig = yscale*np.array(sig) pos_max *= xscale diff_av_max *= yscale fig, ax = plt.subplots(2, 1) folder = res0['azav']['folder'] if title is None: title = folder elif title == 'auto': title = Path(*Path(folder).parts[-2:]) else: if not isinstance(title, str): raise TypeError("'title' must be None or str.") fig.suptitle(title) ax[1].plot(pos, sig, '.-') ax[1].set_xlabel(xlabel) ax[1].set_ylabel(r"Average over q=(%.2f, %.2f) $\AA^{-1}$" % (qrange[0], qrange[1])) ax[1].grid(alpha=0.7) ax[1].legend([t]) ax[0].plot(q, diff_av_max) ymin, ymax = ax[0].get_ylim() ax[0].vlines(x=qrange[0], ymin=ymin, ymax=ymax, ls='--', color='k') ax[0].vlines(x=qrange[1], ymin=ymin, ymax=ymax, ls='--', color='k') ax[0].set_xlabel(qlabel) ax[0].set_ylabel(r'$\Delta$I') ax[0].grid(alpha=0.7) ax[0].legend(["pos=%g" % pos_max]) plt.tight_layout()
[docs] def plot_azim_regroup(img, ai, N=600, M=360, center=None, vline=None, label=None, ax=None, return_ax=False, cmap="inferno", clim=None): """ Perform and plot azimuthal 2d regrouping ("caking") of an image Azimuthal regrouping is performed over N radian bins and M angular steps. Parameters ---------- img : array_like Image. ai : pyFAI ... pyFAI azimuthal integrator obj. N : int, optional Number of radial bins. Default is 600. M : int, optional Number of angular steps. Default is 360. cmap : str, optional Figure colormap. Default is 'inferno'. clim : tuple or None, optional Figure color limits. Default is None. """ if center is not None: ai.setFit2D(centerX=center[0], centerY=center[1], directDist=ai.dist*1e3, tilt=ai.tilt[0], tiltPlanRotation=ai.tilt[1]) res = ai.integrate2d(img, N, M, unit="2th_deg") rad = res.radial # 2th in deg azm = res.azimuthal # phi in ?? if ax is None: fig, ax = plt.subplots() ax.clear() colornorm = SymLogNorm( 1, base=10, vmin=np.nanmin(img), vmax=np.nanmax(img)) ax.imshow( res.intensity, origin='lower', extent=[0, rad.max(), azm.min(), azm.max()], aspect='auto', cmap=cmap, clim=clim, norm=colornorm) if label: ax.set_title("2D regrouping") else: ax.set_title(label) ax.set_xlabel(r"Scattering angle $2\theta$ ($^{o}$)") ax.set_ylabel(r"Azimuthal angle $\varphi$ ($^{o}$)") # if center is not None: # xc = center[0]*ai.pixel1 - ai.poni1 # m # yc = center[1]*ai.pixel2 - ai.poni2 # m # r = np.sqrt(xc**2 + yc**2) # m # phi = np.rad2deg(np.arccos(xc/r)) # tth = np.rad2deg(np.arctan(r/ai.dist)) # ax.vlines(x=phi, ymin=azm.min(), ymax=azm.max(), ls='--') # ax.hlines(y=tth, xmin=rad.min(), xmax=rad.max(), ls='--') ax.vlines(x=0, ymin=azm.min(), ymax=azm.max(), ls='--') ax.hlines(y=0, xmin=0, xmax=rad.max(), ls='--') if vline: ax.vlines(x=vline, ymin=azm.min(), ymax=azm.max()) if return_ax: return res, ax return res
def _track_init(qmon=None): """ Prepare new figure to monitor scattering patterns and track specific the signal in specific qrange(s). Parameters ---------- data : output of txs.azav.integrate1d_dataset() Input data qmon : tuple or list or None, optional Monitor q-range(s). If tuple, monitoring q-range is (qmon[0], qmon[1]). If list, each element of the list is a monitoring q-range. If None (default), only absolute patterns are tracked. Returns ------- ... """ if qmon is None: fig, ax0 = plt.subplots(1, 1) ax1 = None ax2 = None elif isinstance(qmon, tuple): qmon1 = qmon qmon2 = None fig, (ax0, ax1) = plt.subplots(1, 2) ax2 = None elif isinstance(qmon, list): qmon1 = qmon[0] qmon2 = qmon[1] fig = plt.figure() ax0 = fig.add_subplot(121) ax1 = fig.add_subplot(222) ax2 = fig.add_subplot(224) else: raise TypeError("'qmon' must be tuple, list or None.") ax = (ax0, ax1, ax2) qmon = (qmon1, qmon2) return fig, ax, qmon
[docs] def track_abs_init(data, qmon=None, qnorm=None, title='short'): """ Prepare new figure to monitor absolute scattering patterns and specific qrange(s). Parameters ---------- data : output of txs.azav.integrate1d_dataset() Input data qmon : tuple or list or None, optional Monitor q-range(s). If tuple, monitoring q-range is (qmon[0], qmon[1]). If list, each element of the list is a monitoring q-range. If None (default), only absolute patterns are tracked. Returns ------- ... """ fig, ax, qmon = _track_init(qmon) ln, sig = _track_abs_plot(data, fig, ax, qmon, init=True, qnorm=qnorm) folder = data['folder'] if title is None: title = folder elif title == 'short': title = Path(*Path(folder).parts[-2:]) else: if not isinstance(title, 'str'): raise TypeError("'title' must be None or str.") fig.suptitle(title) plt.tight_layout() return fig, ax, ln, sig
[docs] def track_diff_init(data, qmon=None, qnorm=None, track_t=None, title='short'): """ Prepare new figure to monitor difference scattering patterns and specific qrange(s). Parameters ---------- data : output of txs.azav.integrate1d_dataset() Input data qmon : tuple or list or None, optional Monitor q-range(s). If tuple, monitoring q-range is (qmon[0], qmon[1]). If list, each element of the list is a monitoring q-range. If None (default), only absolute patterns are tracked. Returns ------- ... """ fig, ax, qmon = _track_init(qmon) fig.set_size_inches(12, 6) ln, sig = _track_diff_plot(data, fig, ax, qmon, init=True, qnorm=qnorm, track_t=track_t) folder = data['azav']['folder'] if title is None: title = folder elif title == 'short': title = Path(*Path(folder).parts[-2:]) else: if not isinstance(title, 'str'): raise TypeError("'title' must be None or str.") fig.suptitle(title) plt.tight_layout() plt.pause(0.1) return fig, ax, ln, sig
def _track_update(fig, qmon=None): """ Update figure for absolute patterns tracking. """ plt.figure(fig.number) if qmon is None: qmon1, qmon2 = None, None elif isinstance(qmon, tuple): qmon1, qmon2 = qmon, None elif isinstance(qmon, list): qmon1, qmon2 = qmon[0], qmon[1] else: raise TypeError("'qmon' must be tuple, list or None.") qmon = (qmon1, qmon2) return qmon
[docs] def track_abs_update(data, fig, ax, ln, sig=None, qmon=None, qnorm=None): """ Update figure for absolute patterns tracking. """ qmon = _track_update(fig, qmon) ln, sig = _track_abs_plot( data, fig, ax, qmon, init=False, ln=ln, qnorm=qnorm) return ln, sig
[docs] def track_diff_update(data, fig, ax, ln, sig=None, qmon=None, qnorm=None, track_t=None): """ Update figure for difference patterns tracking. """ qmon = _track_update(fig, qmon) ln, sig = _track_diff_plot( data, fig, ax, qmon, init=False, ln=ln, qnorm=qnorm, track_t=track_t) return ln, sig
def _track_abs_plot(data, fig, ax, qmon=None, sig=None, qnorm=None, init=False, ln=None, tpause=0.25): """ Returns ------- ln : list Lines to be updated. sig : list Signals calculated so far. """ q = data['q'] i = data['i'] if qnorm is not None: idx = (q >= qnorm[0]) & (q <= qnorm[1]) i = i/np.mean(i[idx, :], axis=0) if init: ln = [None, None, None] elif ln is None: raise ValueError("'l' can be None only if 'init' is True.") if init: ln[0], = ax[0].plot(q, i[:, -1]) ax[0].set_xlabel(qlabel) ax[0].set_ylabel("Scattered intensity (a.u.)") ax[0].grid(alpha=0.7) ax[0].set_title("image_no=%d" % data['i'].shape[1]) ax[0].set_xlim(0, ) ax[0].set_ylim(0, ) else: ln[0].set_xdata(q) ln[0].set_ydata(i[:, -1]) ax[0].set_title("image_no=%d" % data['i'].shape[1]) ax[0].set_xlim(0, ) ax[0].set_ylim(0, ) if sig is None: sig = (None, None) if qmon[0] is not None: sig1 = _get_qrange_sum(q, i, qmon[0], sig[0]) if init: ln[1], = ax[1].plot(sig1, '.-') ax[1].set_xlabel('image #') ax[1].set_ylabel(r'sum over (%g, %g) [$\AA^{-1}$]' % (qmon[0][0], qmon[0][1])) ax[1].grid(alpha=0.7) else: ln[1].set_xdata(range(len(sig1))) ln[1].set_ydata(sig1) ax[1].relim() # to recompute ax.dataLim ax[1].autoscale_view() # to update ax.viewLim else: sig1 = None if qmon[1] is not None: sig2 = _get_qrange_sum(q, i, qmon[1], sig[1]) if init: ln[2], = ax[2].plot(sig2, '.-') ax[2].set_ylabel(r'sum over (%g, %g) [$\AA^{-1}$]' % (qmon[1][0], qmon[1][1])) ax[2].grid(alpha=0.7) else: ln[2].set_xdata(range(len(sig2))) ln[2].set_ydata(sig2) ax[2].relim() # to recompute ax.dataLim ax[2].autoscale_view() # to update ax.viewLim else: sig2 = None sig = [sig1, sig2] plt.pause(tpause) return ln, sig def _track_diff_plot(data, fig, ax, qmon=None, sig=None, qnorm=None, track_t=None, init=False, ln=None, tpause=0.25): """ Returns ------- ln : list Lines to be updated. sig : list Signals calculated so far. """ q = data['q'] if track_t is None or track_t == 'last': t_idx = -1 elif track_t == 'first': t_idx = 0 else: if not isinstance(track_t, str): raise TypeError("'track_t' must be str.") if track_t not in data['t']: raise ValueError("'track_t' is not in available time-delays: " + data['t']) t_idx = np.where(np.array(data['t']) == track_t)[0][0] diffs_t = data['diffs'][t_idx] if qnorm is not None: idx = (q >= qnorm[0]) & (q <= qnorm[1]) diffs_t = diffs_t/np.mean(diffs_t[idx, :], axis=0) if init: ln = [None, None, None] elif ln is None: raise ValueError("'l' can be None only if 'init' is True.") if init: ln[0], = ax[0].plot(q, diffs_t[:, -1]) ax[0].set_xlabel(qlabel) ax[0].set_ylabel("Scattering difference (a.u.)") ax[0].grid(alpha=0.7) rep_t = diffs_t.shape[1] ax[0].set_title("%s, rep=%d" % (data['t'][t_idx], rep_t)) # ax[0].set_xlim(0, ) # ax[0].set_ylim(0, ) else: ln[0].set_xdata(q) ln[0].set_ydata(diffs_t[:, -1]) rep_t = diffs_t.shape[1] ax[0].set_title("%s, rep=%d" % (data['t'][t_idx], rep_t)) # ax[0].set_xlim(0, ) # ax[0].set_ylim(0, ) if sig is None: sig = (None, None) if qmon[0] is not None: # sig1 = _get_qrange_sum(q, diffs_t, qmon[0], sig[0]) sig1 = _get_qrange_max(q, diffs_t, qmon[0], sig[0], abs_value=True) if init: ln[1], = ax[1].plot(sig1, '.-') ax[1].set_xlabel('rep #') ax[1].set_ylabel(r'abs(signal) (%.2g, %.2g) $\AA^{-1}$' % (qmon[0][0], qmon[0][1])) ax[1].grid(alpha=0.7) else: ln[1].set_xdata(range(1, len(sig1)+1)) ln[1].set_ydata(sig1) ax[1].relim() ax[1].autoscale_view() else: sig1 = None if qmon[1] is not None: # sig2 = _get_qrange_sum(q, diffs_t, qmon[1], sig[1]) sig2 = _get_qrange_max(q, diffs_t, qmon[1], sig[1], abs_value=True) if init: ln[2], = ax[2].plot(sig2, '.-') ax[2].set_ylabel(r'abs(signal) (%.2g, %.2g) $\AA^{-1}$' % (qmon[1][0], qmon[1][1])) ax[2].grid(alpha=0.7) else: ln[2].set_xdata(range(1, len(sig2)+1)) ln[2].set_ydata(sig2) ax[2].relim() ax[2].autoscale_view() else: sig2 = None sig = [sig1, sig2] plt.pause(tpause) return ln, sig def _get_qrange_sum(q, i, qrange, sig=None, sum_abs=False): idx = (q >= qrange[0]) & (q <= qrange[1]) i_arr = np.array(i) i_arr_idx = i_arr[idx, :] if sum_abs: i_arr_idx = np.abs(i_arr_idx) if sig is None: sig = np.sum(i_arr_idx, axis=0) else: sig = np.concatenate((sig, np.sum(i_arr_idx, axis=1))) return sig def _get_qrange_max(q, i, qrange, sig=None, abs_value=False): idx = (q >= qrange[0]) & (q <= qrange[1]) i_arr = np.array(i) i_arr_idx = i_arr[idx, :] if abs_value: i_arr_idx = np.abs(i_arr_idx) if sig is None: sig = np.max(i_arr_idx, axis=0) else: sig = np.concatenate((sig, np.max(i_arr_idx, axis=1))) return sig def get_qrange_mean(): pass def _onpick(event, fig, lines_dict): """Toggle visibility of legend line and corresponding plotted line.""" legend_line = event.artist original_line = lines_dict[legend_line] visible = not original_line.get_visible() original_line.set_visible(visible) if visible: legend_line.set_alpha(1.0) else: legend_line.set_alpha(0.2) fig.canvas.draw() return True def _plot(x, y, t, e=None, sel=None, every=None, hide_lines=False, cmap=None, lbls=None, error_bars=False, title=None, xlabel=None, ylabel=None, xlim=None, ylim=None, tpause=None, figsize=None, fig=None, ax=None, legend_ncols='auto', legend_nlines=None, map2D=False, map_clim=None, return_fig=False, return_ax=True, return_lines=False): """ Plot time-resolved patterns. Patterns can be hidden or shown upon a click on the corresponding legend entry. Parameters ---------- ... legend_ncols : int or 'auto' or None If 'auto' (default), the number of columns in the legend is modified automatically and optimized on the basis of the total number of curves. (if nlines is Nones, the legend will use default matplotlib settings) If int, the legend will have 'ncols' columns. If None, the legend will use default matplotlib settings. (same as ncols=1). legend_nlines : int or None If 'auto' (default), the number of lines in the legend is modified automatically and optimized to fit figure height. If None (default), the legend will use default matplotlib settings (no limitations on the number of lines is applied). If int, the legend will have 'nlines' lines. map2D : bool, optionnal If True, will display the patterns as a 2D map (image_number,q) map_clim : tupple or None Ignored if map2D is False """ # for plot_diffs: len(t) = number of unique time-delays # for plot_abs: len(t) = number of time-delays (including repetitions) if len(t) != y.shape[1]: raise ValueError("'t' length must be equal to the number of " + " columns of 'y'.") if sel is None: idx = range(len(t)) if every is not None: idx = idx[::every] elif isinstance(sel, list): # 'sel' is a list of time-delays labels idx = [] sub = [np.where(np.array(t) == s)[0] for s in sel] # sub is a list of arrays # each array contains all indeces corresponding to an element of 'sel' if every is not None: sub = [s[::every] for s in sub] for s in sub: idx.extend(s) add_title = ', '.join(sel) if title is not None: title += ': ' + add_title else: title = add_title elif isinstance(sel, slice): if every is not None and sel.step is None: sel = slice(sel.start, sel.stop, every) idx = range(len(t))[sel] if map2D: plt.imshow(y[:, idx], aspect="auto", clim=map_clim, extent=[0, len(idx), x[-1], x[0]]) ax.set_ylabel(xlabel) ax.set_xlabel("image_number") return lines = [] if isinstance(cmap, str): try: cmap = getattr(plt.cm, cmap) except KeyError: cmap = plt.cm.prism print("WARNING : colormap could not be found, using default " + "colormap ('prism') instead") for linenum, k in enumerate(idx): if cmap is not None: color = cmap(k/(len(idx)-1)) kw = dict(color=color, label=lbls[k]) else: kw = dict(label=lbls[k]) if e is not None and error_bars: line = ax.errorbar(x, y[:, k], e[:, k], **kw)[0] else: line = ax.plot(x, y[:, k], **kw)[0] lines.append(line) _adjust_legend(ax, ncols=legend_ncols, nlines=legend_nlines, figsize=figsize) if hide_lines: ax = plt.gca() for ln in ax.get_lines(): ln.set_visible(False) for ln in ax.legend_.get_lines(): ln.set_alpha(0.7) ax.grid(alpha=0.7) ax.set_xlabel(xlabel) ax.set_ylabel(ylabel) ax.set_xlim(xlim) ax.set_ylim(ylim) ax.set_title(title) _clickable_legend(ax) # fig.suptitle(title) if tpause is not None: plt.pause(tpause) else: plt.tight_layout() if return_fig: if return_ax: if return_lines: return fig, ax, lines else: return fig, ax else: if return_lines: return fig, lines else: return fig elif return_ax: if return_lines: return ax, lines else: return ax elif return_lines: return lines else: return None def _clickable_legend(ax): """ Makes the legend clickable. One click on the legend line toogles the line plot from visible to hidden. """ fig = ax.figure lines_dict = dict() for legend_line, orig_line in zip(ax.legend_.get_lines(), ax.get_lines()): legend_line.set_picker(True) legend_line.set_pickradius(5) lines_dict[legend_line] = orig_line fig.canvas.mpl_connect( 'pick_event', lambda event: _onpick(event, fig, lines_dict)) def _adjust_legend(ax, ncols=None, nlines=None, figsize=None): """ Modify the horizontal size of a plot to add the legend sideways. Parameters ---------- ax : ... ... ncols : int or 'auto' or None If 'auto' (default), the number of columns in the legend is modified automatically and optimized on the basis of the total number of curves. (if nlines is Nones, the legend will use default matplotlib settings) If int, the legend will have 'ncols' columns. If None, the legend will use default matplotlib settings. nlines : int or 'auto' or None If 'auto' (default), the number of lines in the legend is modified automatically and optimized to fit figure height. If None (default), the legend will use default matplotlib settings (no limitations on the number of lines is applied). If int, the legend will have 'nlines' lines. figsize : tuple (Hor, Ver) size of the figure in inches. """ if nlines is not None and not isinstance(nlines, (int, str)): raise TypeError("'nlines' must be int or None.") if ncols is not None and not isinstance(ncols, (int, str)): raise TypeError("'ncols' must be int, str or None.") def get_plot_labels(ax): lab = [] for h in ax.get_lines(): lab.append(h.get_label()) return lab ncurves = len(ax.get_lines()) labels = get_plot_labels(ax) fig = ax.figure if figsize is not None: fig.set_size_inches(figsize) def fig_height(): return fig.get_size_inches()[1] def get_legend_height(nlines): return get_line_height() * nlines def get_max_lines_from_height(height): return floor(height / get_line_height()) def get_line_height(): fontsize = get_legend_fontsize() spacing = plt.rcParams["legend.labelspacing"] points_to_inches = 1 / 72 return fontsize * points_to_inches * (1 + spacing) def get_legend_fontsize(): if isinstance(plt.rcParams["legend.fontsize"], int): return plt.rcParams["legend.fontsize"] sizes = np.array(['xx-small', 'x-small', 'small', 'medium', 'large', 'x-large', 'xx-large']) scales = np.array([1/1.2**3, 1/1.2**2, 1/1.2, 1, 1.2, 1.2**2, 1.2**3]) if isinstance(plt.rcParams["legend.fontsize"], str): default_size = plt.rcParams["font.size"] idx = np.where(sizes == plt.rcParams["legend.fontsize"]) scale = scales[idx][0] return default_size * scale hshrink = 0.15 # factor by which the plot is shrink per legend col def _do_adjust(labels, ncols): x0, y0, width, height = ax.get_position(original=False).bounds fig.subplots_adjust(left=x0, right=x0+width*(1-ncols*hshrink)) ax.legend(labels, loc=6, bbox_to_anchor=(1.01, 0.5), ncol=ncols) if nlines is None and (ncols is None or ncols == 'auto'): # default location should be 'upper right' # to avoid covering TR-WAXS water heating in LS experiments ax.legend(labels, loc='upper right') return if isinstance(nlines, int): if ncols == 'auto': ncols = int(ceil(ncurves/nlines)) elif ncols is None: ncols = 1 labels = labels[:ncols*nlines] elif isinstance(ncols, int): labels = labels[:ncols*nlines] if nlines == 'auto': if ncols == 'auto': nlines = get_max_lines_from_height(fig_height() * 0.9) ncols = ceil(ncurves/nlines) elif isinstance(ncols, int): nlines = ncurves / ncols if get_legend_height(nlines) > fig.get_size_inches()[1] * 0.9: print("WARNING : Figure height is too small, number of " + "entry in legend will be adjusted") nlines = get_max_lines_from_height(fig_height() * 0.9) labels = labels[:ncols*nlines] elif ncols is None: ncols = 1 if ncols * hshrink > 0.8: print("""WARNING : Two many columns, custom layout not applied""") ax.legend(labels, loc=4) return _do_adjust(labels, ncols) def _ref_average(diffs): """ Returns average of reference patterns """ ref_i = diffs["i"][:, np.where(diffs["delays"] == diffs["ref_delay"])] return np.squeeze(ref_i).mean(axis=1)