Contents
PRISM Algorithm
%matplotlib widget
import py4DSTEM
import abtem
import ase
import ipywidgets
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
# Set abTEM's "dask.lazy" configuration key to False for the whole session.
# NOTE(review): presumably this makes abTEM compute results eagerly instead of
# returning lazy dask arrays -- confirm against the abTEM configuration docs.
abtem.config.set({"dask.lazy":False});
Converged Probe Expansion
# parameters
gpts = np.array([128, 128])
sampling = np.array([0.125, 0.125])
extent = gpts * sampling
energy = 300e3
semiangle_cutoff = 20
defocus = 100

# Shared mutable state that the widget callbacks read and overwrite in place.
# Slot layout:
#   0: s_matrix
#   1: spiral_ordering
#   2: indices_i
#   3: indices_j
#   4: pos_px
#   5: position_coefs
#   6: defocus
#   7: ctf_coefs
arrays_mutated = [None] * 8

# inputs a
interpolation_factor = 1

# S-matrix built without a potential; it is stored in slot 0 and later
# replaced whenever the interpolation factor changes.
dummy_s_matrix = abtem.SMatrix(
    potential=None,
    sampling=sampling,
    gpts=gpts,
    energy=energy,
    semiangle_cutoff=semiangle_cutoff,
    downsample=None,
    interpolation=interpolation_factor,
).build()

arrays_mutated[0] = dummy_s_matrix
def return_ordering_indices(s_matrix):
    """Return the beam ordering and the grid indices of every plane wave.

    Parameters
    ----------
    s_matrix
        Object exposing ``wave_vectors`` (one 2-component row per plane
        wave), ``extent`` and ``gpts``.

    Returns
    -------
    spiral_ordering, indices_i, indices_j
        ``spiral_ordering`` sorts the plane waves by increasing squared
        wave-vector magnitude. ``indices_i`` / ``indices_j`` are the integer
        grid indices of each wave vector, wrapped into ``[0, gpts)``.
    """
    wave_vectors = s_matrix.wave_vectors
    spiral_ordering = np.argsort(np.sum(wave_vectors**2, axis=1))

    # wave_vector * extent is an integer only up to floating-point error.
    # Round to the nearest integer before casting: a bare astype("int")
    # truncates toward zero and would turn e.g. 28.999999999999996 into 28.
    grid_indices = np.rint(wave_vectors * s_matrix.extent).astype("int")

    # Negative frequencies wrap around to the far end of the grid.
    indices_i, indices_j = np.mod(grid_indices, s_matrix.gpts).T
    return spiral_ordering, indices_i, indices_j
# Fill slots 1-3 (spiral_ordering, indices_i, indices_j) from the S-matrix.
(
    arrays_mutated[1],
    arrays_mutated[2],
    arrays_mutated[3],
) = return_ordering_indices(arrays_mutated[0])

# inputs b
defocus = 100
# Initial probe position in pixels: the centre of the grid.
pos_px = gpts / 2
def return_position_coefs(s_matrix, pos_px):
    """Return the S-matrix position coefficients for one probe position.

    ``pos_px`` is a pixel position; it is measured relative to the grid
    centre (module-level ``gpts / 2``) and scaled by ``s_matrix.sampling``
    before being handed to abTEM as a single-point custom scan.
    """
    offset_px = pos_px - gpts / 2
    scan = abtem.scan.CustomScan(offset_px * np.array(s_matrix.sampling))
    # Only the first entry of the returned coefficients is used.
    return s_matrix._calculate_positions_coefficients(scan)[0]
def return_ctf_coefs(s_matrix, defocus):
    """Return the S-matrix CTF coefficients for the given defocus.

    The CTF is built with the module-level ``energy`` and an infinite
    semi-angle cutoff, so the CTF itself applies no aperture.
    """
    ctf = abtem.CTF(
        defocus=defocus,
        semiangle_cutoff=np.inf,
        energy=energy,
    )
    return s_matrix._calculate_ctf_coefficients(ctf=ctf)
# Store the raw inputs first (slots 4 and 6), then the coefficients derived
# from them (slots 5 and 7).
arrays_mutated[4] = pos_px
arrays_mutated[6] = defocus
arrays_mutated[5] = return_position_coefs(arrays_mutated[0], pos_px)
arrays_mutated[7] = return_ctf_coefs(arrays_mutated[0], defocus)

# inputs c
num_planewaves = 5
def return_selected_inds(
    num_planewaves,
    spiral_ordering,
    indices_i,
    indices_j,
):
    """Pick the first ``num_planewaves`` beams of the spiral ordering.

    Returns the selected beam indices together with the matching entries of
    ``indices_i`` and ``indices_j``.
    """
    selected = spiral_ordering[:num_planewaves]
    return selected, indices_i[selected], indices_j[selected]
def return_arrays(
    s_matrix,
    inds,
    inds_i,
    inds_j,
    position_coefs,
    ctf_coefs,
):
    """Build the beam-coefficient image, the weighted plane waves and their sum.

    Returns
    -------
    beams
        complex64 array of shape ``s_matrix.gpts`` holding each selected
        beam's coefficient at ``(inds_i, inds_j)`` and zero elsewhere.
    scaled_beams
        ``s_matrix.array[inds]`` with each entry multiplied by its coefficient.
    probe
        Sum of ``scaled_beams`` over the beam axis.
    """
    # Per-beam weight: CTF coefficient times position coefficient.
    weights = ctf_coefs[inds] * position_coefs[inds]

    scaled_beams = s_matrix.array[inds] * weights[:, np.newaxis, np.newaxis]
    probe = scaled_beams.sum(axis=0)

    beams = np.zeros(s_matrix.gpts, dtype=np.complex64)
    beams[inds_i, inds_j] = weights

    return beams, scaled_beams, probe
# Initial beam selection and arrays for the first render of the figure.
# Slots 1-3 of arrays_mutated are spiral_ordering, indices_i and indices_j.
inds, inds_i, inds_j = return_selected_inds(num_planewaves, *arrays_mutated[1:4])

beams, scaled_beams, probe = return_arrays(
    arrays_mutated[0],
    inds,
    inds_i,
    inds_j,
    arrays_mutated[5],  # position_coefs
    arrays_mutated[7],  # ctf_coefs
)
dpi = 72

# Create the figure with interactive mode switched off; its canvas is placed
# in the widget layout at the end of the notebook.
with plt.ioff():
    fig = plt.figure(figsize=(675 / dpi, 235 / dpi), dpi=dpi)

# 2 x 6 grid: beam coefficients (left), four plane-wave panels (middle),
# probe (right).
spec = GridSpec(2, 6, figure=fig)
ax1 = fig.add_subplot(spec[:, :2])
ax2a = fig.add_subplot(spec[0, 2])
ax2b = fig.add_subplot(spec[0, 3])
ax2c = fig.add_subplot(spec[1, 2])
ax2d = fig.add_subplot(spec[1, 3])
ax3 = fig.add_subplot(spec[:, 4:])

# Left panel: beam coefficients, fftshift-ed before display.
beams_rgb = py4DSTEM.visualize.Complex2RGB(
    np.fft.fftshift(beams),
    vmin=0,
    vmax=1,
)
im_beams = ax1.imshow(beams_rgb)
ax1.set_title("Fourier-space beam coefficients")

# Middle panels: the last (up to) four selected plane waves. The index is
# clamped so that fewer than four beams still address valid entries.
plane_tgb_tl = py4DSTEM.visualize.Complex2RGB(
    scaled_beams[max(-num_planewaves, -4)]
)
im_plane_tl = ax2a.imshow(plane_tgb_tl)

if num_planewaves <= 1:
    # No beam for this panel: draw a placeholder and hide the axes.
    im_plane_tr = ax2b.imshow(np.ones_like(plane_tgb_tl))
    ax2b.set_visible(False)
else:
    plane_tgb_tr = py4DSTEM.visualize.Complex2RGB(
        scaled_beams[max(1 - num_planewaves, -3)]
    )
    im_plane_tr = ax2b.imshow(plane_tgb_tr)

if num_planewaves <= 2:
    im_plane_bl = ax2c.imshow(np.ones_like(plane_tgb_tl))
    ax2c.set_visible(False)
else:
    plane_tgb_bl = py4DSTEM.visualize.Complex2RGB(
        scaled_beams[max(2 - num_planewaves, -2)]
    )
    im_plane_bl = ax2c.imshow(plane_tgb_bl)

if num_planewaves <= 3:
    im_plane_br = ax2d.imshow(np.ones_like(plane_tgb_tl))
    ax2d.set_visible(False)
else:
    plane_tgb_br = py4DSTEM.visualize.Complex2RGB(
        scaled_beams[max(3 - num_planewaves, -1)]
    )
    im_plane_br = ax2d.imshow(plane_tgb_br)

# Shared caption for the four middle panels.
fig.text(
    0.5,
    0.95,
    "largest-4 frequency beams",
    horizontalalignment="center",
    fontsize=12,
)

# Right panel: the summed probe, fftshift-ed before display.
probe_rgb = py4DSTEM.visualize.Complex2RGB(np.fft.fftshift(probe))
im_probe = ax3.imshow(probe_rgb)
ax3.set_title("Real-space converged probe")

# Remove tick marks from every panel.
for ax in fig.axes:
    ax.set_xticks([])
    ax.set_yticks([])

# Canvas chrome: fixed size, toolbar only, placed at the bottom.
fig.canvas.resizable = False
fig.canvas.header_visible = False
fig.canvas.footer_visible = False
fig.canvas.toolbar_visible = True
fig.canvas.toolbar_position = 'bottom'
fig.canvas.layout.width = '675px'
fig.canvas.layout.height = '265px'

spec.tight_layout(fig)

# Bare expression so the notebook cell does not echo a return value.
None
# Layout and description style shared by the sliders defined below.
layout = ipywidgets.Layout(height="30px", width="330px")
style = {"description_width": "initial"}
def update_num_planewaves(change):
    """Redraw every panel for the number of beams given in ``change["new"]``.

    Reads the current S-matrix, ordering and coefficients from
    ``arrays_mutated`` and updates the existing image artists in place.
    """
    num_planewaves = change["new"]
    to_rgb = py4DSTEM.visualize.Complex2RGB

    # Slots 1-3 hold spiral_ordering, indices_i and indices_j.
    selection = return_selected_inds(num_planewaves, *arrays_mutated[1:4])
    beams, scaled_beams, probe = return_arrays(
        arrays_mutated[0],
        *selection,
        arrays_mutated[5],
        arrays_mutated[7],
    )

    im_beams.set_data(to_rgb(np.fft.fftshift(beams), vmin=0, vmax=1))

    # The top-left panel is always drawn; the index is clamped so that fewer
    # than four beams still address a valid entry.
    im_plane_tl.set_data(to_rgb(scaled_beams[max(-num_planewaves, -4)]))

    # The other three panels are shown only when enough beams are selected.
    optional_panels = (
        (1, im_plane_tr, ax2b),
        (2, im_plane_bl, ax2c),
        (3, im_plane_br, ax2d),
    )
    for offset, image, axes in optional_panels:
        if num_planewaves > offset:
            beam_index = max(offset - num_planewaves, offset - 4)
            image.set_data(to_rgb(scaled_beams[beam_index]))
            axes.set_visible(True)
        else:
            axes.set_visible(False)

    im_probe.set_data(to_rgb(np.fft.fftshift(probe)))

    fig.canvas.draw_idle()
    return None
# Both widgets start at the number of beams used for the initial render
# (num_planewaves), so the kernel-side slider value matches the figure before
# the front-end link has synchronised anything.
play = ipywidgets.Play(
    value=num_planewaves,
    min=1,
    max=len(arrays_mutated[0]),
    step=1,
    interval=25,
)

slider = ipywidgets.IntSlider(
    value=num_planewaves,
    min=1,
    max=len(arrays_mutated[0]),
    step=1,
    layout=layout,
    style=style,
    description="# of beams",
)

# Link the two values in the browser, and redraw whenever the slider changes.
ipywidgets.jslink((play, 'value'), (slider, 'value'))
slider.observe(update_num_planewaves, "value")
def update_interpolation_factor(change):
    """Rebuild the S-matrix for a new interpolation factor and refresh the plot.

    ``change["new"]`` is the selected interpolation factor. The S-matrix and
    everything derived from it (ordering, grid indices, position and CTF
    coefficients) are recomputed and written back into ``arrays_mutated``.
    The upper bound of the beam slider and of the play widget is then set to
    the new number of plane waves.
    """
    interpolation_factor = change["new"]
    s_matrix = abtem.SMatrix(
        potential=None,
        sampling=sampling,
        gpts=gpts,
        energy=energy,
        semiangle_cutoff=semiangle_cutoff,
        downsample=None,
        interpolation=interpolation_factor,
    ).build()

    arrays_mutated[0] = s_matrix
    arrays_mutated[1:4] = return_ordering_indices(s_matrix)
    arrays_mutated[5] = return_position_coefs(s_matrix, arrays_mutated[4])
    arrays_mutated[7] = return_ctf_coefs(s_matrix, arrays_mutated[6])

    num_beams = len(s_matrix)
    value_before = slider.value

    # Lowering ``max`` below the current value makes the bounded widget clamp
    # its value, which fires the slider's "value" observer and redraws.
    slider.max = num_beams
    # Keep the linked play widget within the same range as the slider.
    play.max = num_beams

    # Redraw explicitly only when the value was NOT clamped. Comparing with
    # the previous value (rather than testing ``value < max``) also covers the
    # case where the unchanged value happens to equal the new maximum, which
    # would otherwise leave the plot showing the old S-matrix.
    if slider.value == value_before:
        update_num_planewaves({"new": slider.value})
    return None
def update_defocus(change):
    """Store the new defocus, recompute the CTF coefficients and redraw."""
    new_defocus = change["new"]
    arrays_mutated[6] = new_defocus
    arrays_mutated[7] = return_ctf_coefs(arrays_mutated[0], new_defocus)

    # Redraw with the number of beams currently selected on the slider.
    update_num_planewaves({"new": slider.value})
    return None
# Discrete choice of interpolation factor; a change rebuilds the S-matrix.
interpolation_slider = ipywidgets.SelectionSlider(
    description="interpolation factor",
    options=[1, 2, 4, 8, 16],
    value=1,
    layout=layout,
    style=style,
)
interpolation_slider.observe(update_interpolation_factor, names="value")

# Continuous defocus control; a change recomputes the CTF coefficients.
defocus_slider = ipywidgets.FloatSlider(
    description="defocus [Å]",
    min=-150,
    max=150,
    value=100,
    layout=layout,
    style=style,
)
defocus_slider.observe(update_defocus, names="value")
def onclick(event):
    """Move the probe to the clicked pixel and redraw.

    Events whose ``ydata`` is ``None`` are ignored. Otherwise the click's
    (row, column) data coordinates become the new probe position and the
    position coefficients are recomputed.
    """
    if event.ydata is None:
        return

    arrays_mutated[4] = np.array([event.ydata, event.xdata])
    arrays_mutated[5] = return_position_coefs(arrays_mutated[0], arrays_mutated[4])
    update_num_planewaves({"new": slider.value})
# Route mouse clicks on the figure to the probe-position handler.
cid = fig.canvas.mpl_connect("button_press_event", onclick)

# Final cell expression: two rows of controls stacked above the figure canvas.
ipywidgets.VBox(
    children=[
        ipywidgets.HBox(children=[slider, play]),
        ipywidgets.HBox(children=[interpolation_slider, defocus_slider]),
        fig.canvas,
    ]
)
VBox(children=(HBox(children=(IntSlider(value=5, description='# of beams', layout=Layout(height='30px', width=…