"""Plotting functions to ease the analysis process.
.. warning::
These functions make use of ipywidgets and are meant to be used within a
Jupyter notebook with the 'magic' `%matplotlib widget`.
"""
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.cm import get_cmap
from ipywidgets import widgets
from txs.heating import remove_heating
from txs.utils import str2t
from txs.analysis.utils import sort_delays
from txs.analysis.svd import SVD
[docs]
def _delay_title(t_key, heating_removed):
    """Return the plot title for the time delay *t_key*.

    Parameters
    ----------
    t_key : object
        The delay, inserted as-is in the title.
    heating_removed : bool
        Whether heating removal is applied. If so, a second line explains
        that the semi-transparent curves are the uncorrected data.
    """
    title = f"$\\rm \\Delta t$ = {t_key}"
    if heating_removed:
        title += "\nsemi-transparent -> no heating removal"
    return title


def plot_dataset(
    data,
    sample_names,
    heating=None,
    heating_qlim=(1.6, 2),
    q_scaled=False
):
    """Plot the processed data of several samples for each time delay.

    An integer slider selects the time delay. Every sample measured at that
    delay is plotted, each one vertically offset from the previous one.

    Parameters
    ----------
    data : list of dict
        One entry per sample. Each entry is read for the keys ``'q'``,
        ``'t'`` (the delays) and ``'diff_av'`` (indexed as
        ``[:, delay_index]``).
    sample_names : sequence of str
        Legend labels, indexed like *data*.
    heating : optional
        If not None, it is passed to ``txs.heating.remove_heating``. The
        corrected curves are drawn opaque on top of semi-transparent
        uncorrected ones.
    heating_qlim : tuple of float
        q range forwarded to ``remove_heating`` as ``qlim``.
    q_scaled : bool
        If True, the curves are multiplied by q.

    Raises
    ------
    ValueError
        If *data* contains no time delay at all.
    """
    fig, ax = plt.subplots(figsize=(10, 6))
    incr = 0.001
    ylabel = 'q' if q_scaled else ''

    # Map each delay (in order of first appearance) to the indices of the
    # samples that contain it.
    delays = {}
    for s_idx, sample in enumerate(data):
        for delay in sample['t']:
            delays.setdefault(delay, []).append(s_idx)
    if not delays:
        raise ValueError("data contains no time delays")
    delay_keys = list(delays)

    # remove_heating processes a whole sample: compute it at most once per
    # sample instead of on every slider move.
    heating_removed = {}

    def _corrected(s_idx):
        """Return the heating-corrected version of sample *s_idx* (cached)."""
        if s_idx not in heating_removed:
            heating_removed[s_idx] = remove_heating(
                data[s_idx], heating, qlim=heating_qlim, verbose=False
            )
        return heating_removed[s_idx]

    delay_slider = widgets.IntSlider(
        min=0, max=len(delay_keys) - 1, value=0, step=1, description="delay"
    )
    cmap = get_cmap('tab10')

    @widgets.interact(delay=delay_slider)
    def update_plot(delay):
        ax.clear()
        t_key = delay_keys[delay]
        for idx, s_idx in enumerate(delays[t_key]):
            sample = data[s_idx]
            q = sample['q']
            q_scale = q if q_scaled else np.ones_like(q)
            t_idx = list(sample['t']).index(t_key)
            offset = idx * incr
            ax.errorbar(
                q,
                q_scale * sample['diff_av'][:, t_idx] + offset,
                label=sample_names[s_idx],
                color=cmap(idx),
                alpha=0.4 if heating is not None else 1,
            )
            if heating is not None:
                red2 = _corrected(s_idx)
                ax.errorbar(
                    q,
                    q_scale * red2['diff_av'][:, t_idx] + offset,
                    color=cmap(idx),
                )
            # Zero line of this (offset) curve.
            ax.axhline(offset, ls=":", alpha=0.2)
        ax.set_xlabel(r"q [$\rm \AA^{-1}$]")
        ax.set_ylabel(f"{ylabel}S(q)")
        ax.legend(loc=2, bbox_to_anchor=(1, 1))
        ax.set_title(_delay_title(t_key, heating is not None))
        ax.relim()
        ax.autoscale_view()
        fig.tight_layout()

    update_plot(0)
[docs]
def plot_svd_analysis(
    data,
    heating=None,
    heating_qlim=(1.8, 2),
    q_scaled=False,
    include_ref=False,
    figsize=(10, 10),
    plot_grid=None,
    ref_delays=None,
):
    """Perform and plot a SVD analysis on the provided data.

    The figure contains four panels:

    - top: the difference curves (semi-transparent) overlaid with their
      reconstruction from the selected SVD components (dashed). Clicking a
      legend entry toggles the corresponding curves.
    - middle: the U vectors.
    - bottom left: the V vectors as a function of time.
    - bottom right: the autocorrelations of the U and V vectors.

    Sliders select the SVD rank and the vertical offset between curves.
    A rank of 0 selects every component whose U autocorrelation is > 0.5.

    Parameters
    ----------
    data : dict
        Dataset read for the keys ``'q'``, ``'t'`` and ``'diff_av'``. Each
        column of ``'diff_av'`` is plotted as one curve, labelled with the
        matching entry of ``'t'``.
    heating : optional
        If not None, heating is removed first with
        ``txs.heating.remove_heating``.
    heating_qlim : tuple of float
        q range forwarded to ``remove_heating`` as ``qlim``.
    q_scaled : bool
        If True, curves and U vectors are multiplied by q.
    include_ref, ref_delays :
        Forwarded to ``txs.analysis.utils.sort_delays``.
    figsize : tuple of float
        Figure size.
    plot_grid : GridSpec, optional
        Grid indexed as ``[0, :]``, ``[1, :]``, ``[2, 0]`` and ``[2, 1]``.
        A 3x2 grid is created when None.

    Returns
    -------
    data : dict
        The (heating-corrected and) delay-sorted dataset.
    svd : SVD
        The SVD object after ``run()``.
    """
    if heating is not None:
        data = remove_heating(data, heating, qlim=heating_qlim, verbose=False)
    q = data['q']
    q_scale = q if q_scaled else np.ones_like(q)
    ylabel = "q" if q_scaled else ""
    data = sort_delays(data, ref_delays=ref_delays, include_ref=include_ref)
    t = str2t(data['t'])
    svd = SVD(data['diff_av']).run()
    n_curves = data['diff_av'].shape[1]

    fig = plt.figure(figsize=figsize)
    if plot_grid is None:
        gs = plt.GridSpec(3, 2, height_ratios=(1, 1, 1))
    else:
        gs = plot_grid
    ax = []
    ax.append(fig.add_subplot(gs[0, :]))
    ax.append(fig.add_subplot(gs[1, :], sharex=ax[0]))
    ax.append(fig.add_subplot(gs[2, 0]))
    ax.append(fig.add_subplot(gs[2, 1]))
    # The cursor only lives on the two q-axis panels.
    cursor = Cursor(ax[:2])
    cmap = get_cmap('gnuplot')

    svd_rank = widgets.IntSlider(
        min=0,
        max=10,
        step=1,
        value=0,
        description="svd rank",
    )
    incr_slider = widgets.FloatSlider(
        min=0, max=0.01, value=0.0005, step=0.0002, description="increment"
    )

    # Legend line -> curves it toggles. Rebuilt on every update.
    leg_lines = {}

    def on_pick(event):
        """Toggle the visibility of the curves attached to a legend entry."""
        legline = event.artist
        origlines = leg_lines.get(legline)
        if not origlines:
            # Not one of the current legend entries: nothing to toggle.
            return
        for line in origlines:
            visible = not line.get_visible()
            line.set_visible(visible)
        legline.set_alpha(1.0 if visible else 0.2)

    # NOTE(review): `keep_lims` has a bool default, so ipywidgets' interact
    # likely exposes it as a checkbox as well - confirm this is wanted.
    @widgets.interact(rank=svd_rank, incr=incr_slider)
    def update_plot(rank, incr, keep_lims=True):
        ax0_lims = (ax[0].get_xlim(), ax[0].get_ylim())
        ax1_lims = (ax[1].get_xlim(), ax[1].get_ylim())

        # --- top panel: data and SVD reconstruction ---------------------
        ax[0].clear()
        correlations = svd.autocorr()
        if rank == 0:
            # Automatic rank: components whose U autocorrelation is > 0.5.
            rank = np.arange(correlations[0].size)[correlations[0] > 0.5]
        # The reconstruction only depends on the rank: compute it once
        # instead of once per plotted curve.
        recomposed = svd.recompose(rank)
        lines = []
        leg_lines.clear()
        for idx, val in reversed(list(enumerate(data['diff_av'].T))):
            color = cmap(idx / n_curves)
            offset = incr * idx
            data_line, = ax[0].plot(
                data['q'],
                q_scale * val + offset,
                color=color,
                alpha=0.4,
            )
            fit_line, = ax[0].plot(
                data['q'],
                q_scale * recomposed[:, idx] + offset,
                color=color,
                label=data['t'][idx],
                ls='--',
            )
            zero_line = ax[0].axhline(offset, color='black', ls=':', alpha=0.2)
            lines.append([data_line, fit_line, zero_line])
        ax[0].set_xlabel(r"q [$\rm \AA^{-1}$]")
        ax[0].set_ylabel(rf'$\rm {ylabel}\Delta S(q)$ ')
        ax[0].grid(alpha=0.2)
        if keep_lims:
            ax[0].set_xlim(ax0_lims[0])
            ax[0].set_ylim(ax0_lims[1])
        leg = ax[0].legend(loc=2, bbox_to_anchor=(1, 1), fontsize=10)
        # Only the dashed reconstructions are labelled, so legend entries
        # and line groups come in the same order: make them clickable.
        for legline, origlines in zip(leg.get_lines(), lines):
            legline.set_picker(True)
            legline.set_pickradius(10)
            leg_lines[legline] = origlines

        # --- middle panel: U vectors -------------------------------------
        patterns = svd.patterns(rank)
        u_vectors = patterns[0]
        v_vectors = patterns[1]
        ax[1].clear()
        for idx, val in enumerate(u_vectors.T):
            ax[1].plot(data['q'], q_scale * val + incr * idx)
            ax[1].axhline(incr * idx, color='black', ls=':', alpha=0.2)
        ax[1].set_xlabel(r"q [$\rm \AA^{-1}$]")
        ax[1].set_ylabel(f'{ylabel}U vectors\n[arb. units]')
        ax[1].grid(alpha=0.2)
        if keep_lims:
            ax[1].set_xlim(ax1_lims[0])
            ax[1].set_ylim(ax1_lims[1])
        ax[1].autoscale_view()

        # --- bottom left: V vectors against time -------------------------
        ax[2].clear()
        ax[2].plot(t, v_vectors.T)
        ax[2].set_xlabel('time')
        ax[2].set_ylabel('V vectors\n[arb. units]')
        ax[2].grid(alpha=0.2)
        ax[2].axhline(0, color='black', ls=':')
        ax[2].relim()
        ax[2].autoscale_view()

        # --- bottom right: autocorrelations -------------------------------
        ax[3].clear()
        ax[3].plot(correlations[0], marker='o', label='U vectors')
        ax[3].plot(correlations[1], marker='^', label='V vectors')
        ax[3].axhline(0.5, color='black', ls=':')
        ax[3].set_xlabel('rank')
        ax[3].set_ylabel('autocorrelations\n[arb. units]')
        ax[3].grid(alpha=0.2)
        ax[3].relim()
        ax[3].autoscale_view()
        ax[3].legend(loc=2, bbox_to_anchor=(1, 1))

        # Axes.clear() dropped the cursor lines: recreate them.
        cursor.drawline()

    # Redraw without restoring the (still default) limits of the new axes.
    update_plot(0, 0.0005, False)
    fig.tight_layout()
    fig.canvas.mpl_connect('motion_notify_event', cursor.on_mouse_move)
    fig.canvas.mpl_connect('button_press_event', cursor.on_mouse_press)
    fig.canvas.mpl_connect('pick_event', on_pick)
    return data, svd
[docs]
class Cursor:
    """A vertical line cursor shared by one or several axes.

    A dashed vertical line follows the mouse x position on every managed
    axes. Left-clicking inside a managed axes pins a persistent line at that
    x position on every managed axes. Right-clicking inside a managed axes
    removes all pinned lines. Events from other axes are ignored, since
    their x coordinates are not comparable.

    Parameters
    ----------
    ax : Axes or list/tuple of Axes
        The axes on which the cursor lines are drawn. They are expected to
        share the same x coordinates.
    """

    def __init__(self, ax):
        if isinstance(ax, tuple):
            ax = list(ax)
        elif not isinstance(ax, list):
            ax = [ax]
        self.ax = ax
        self.vertical_line = []
        self.persistent_lines = []
        self.drawline()

    def drawline(self):
        """(Re)create the moving cursor line on each managed axes.

        Call this again after the axes have been cleared, because clearing
        them discards the previous lines.
        """
        self.vertical_line = [
            val.axvline(color='k', lw=0.8, ls='--', alpha=0.6)
            for val in self.ax
        ]

    def on_mouse_move(self, event):
        """Move the cursor lines to the mouse x position, or hide them.

        The lines are hidden while the mouse is outside the managed axes.
        """
        inside = event.inaxes in self.ax
        for line in self.vertical_line:
            line.set_visible(inside)
            if inside:
                line.set_xdata([event.xdata])

    def on_mouse_press(self, event):
        """Pin a line on left click (button 1), clear pins on right click (3)."""
        if event.inaxes not in self.ax:
            return
        if event.button == 1:
            for val in self.ax:
                line = val.axvline(
                    event.xdata, color='k', lw=0.8, ls='--', alpha=0.6
                )
                self.persistent_lines.append(line)
        elif event.button == 3:
            for line in self.persistent_lines:
                try:
                    line.remove()
                except (ValueError, NotImplementedError):
                    # Best effort: the line may already be detached from
                    # its axes (e.g. after the axes were cleared).
                    pass
            # Forget the removed lines so they are not removed twice.
            self.persistent_lines.clear()