Contents
Direct Ptychography
%matplotlib widget
import abtem
import ase
import py4DSTEM
import numpy as np
import matplotlib.pyplot as plt
import ipywidgets
# Run abTEM eagerly instead of building lazy dask graphs.
abtem.config.set({"dask.lazy":False});

# Input files for the three simulated acquisitions.
file_path = 'data/'
file_data_in_focus = file_path + 'dpc_STO_simulation_in-focus_1e5.h5'
file_data_defocus = file_path + 'dpc_STO_simulation_defocus_1e5.h5'
file_data_aberrated = file_path + 'dpc_STO_simulation_aberrated_1e5.h5'

dataset_in_focus, dataset_defocus, dataset_aberrated = (
    py4DSTEM.read(path)
    for path in (file_data_in_focus, file_data_defocus, file_data_aberrated)
)


def _scan_fourier_transform(data):
    """Fourier transform a 4D array over its two leading (scan) axes.

    The two trailing (detector) axes are first ``ifftshift``-ed, which moves
    the detector's centre pixel to index (0, 0) -- the layout the unshifted
    FFT grids used below expect.
    """
    detector_unshifted = np.fft.ifftshift(data, axes=(-1, -2))
    return np.fft.fft2(detector_unshifted, axes=(0, 1))


dataset_real_space_FFT_in_focus = _scan_fourier_transform(dataset_in_focus.data)
dataset_real_space_FFT_defocus = _scan_fourier_transform(dataset_defocus.data)
dataset_real_space_FFT_aberrated = _scan_fourier_transform(dataset_aberrated.data)

# Probe parameters (abTEM expects the energy in eV and the cutoff in mrad).
energy = 200e3
semiangle = 20
wavelength = abtem.core.energy.energy2wavelength(energy)

# Grid geometry, taken from the in-focus dataset's calibration.
scan_gpts = dataset_in_focus.Rshape
gpts = dataset_in_focus.Qshape
scan_sampling = (dataset_in_focus.calibration.R_pixel_size,) * 2
angular_sampling = (dataset_in_focus.calibration.Q_pixel_size,) * 2
# Real-space sampling implied by the detector grid; the 1e3 converts the
# mrad angular pixel size to rad.
sampling = (wavelength * 1e3 / angular_sampling[0] / gpts[0],) * 2
def angular_spatial_frequencies(
    kx,
    ky,
    wavelength,
    dtype=None
):
    """Convert 1D Cartesian spatial-frequency axes into polar angle grids.

    Parameters
    ----------
    kx, ky : ndarray
        1D frequency axes; ``kx`` indexes rows and ``ky`` columns of the output.
    wavelength : float
        Scale factor turning |k| into a scattering angle.
    dtype : dtype, optional
        If given, both outputs are cast to it.

    Returns
    -------
    alpha : ndarray, shape (len(kx), len(ky))
        ``sqrt(kx**2 + ky**2) * wavelength``.
    phi : ndarray, shape (len(kx), len(ky))
        Azimuth ``arctan2(ky, kx)``.
    """
    kx_column = kx[:, np.newaxis]
    ky_row = ky[np.newaxis, :]
    polar = np.sqrt(kx_column**2 + ky_row**2) * wavelength
    azimuth = np.arctan2(ky_row, kx_column)
    if dtype is None:
        return polar, azimuth
    return polar.astype(dtype), azimuth.astype(dtype)
# Detector-plane frequency axes and the matching polar angle grids.
kx, ky = abtem.core.grid.spatial_frequencies(gpts, sampling)
alpha, phi = angular_spatial_frequencies(kx, ky, wavelength, dtype=np.float32)

# One CTF per simulated condition; they share aperture, grid and energy.
_ctf_common = dict(
    semiangle_cutoff=semiangle,
    sampling=sampling,
    gpts=gpts,
    energy=energy,
)
ctf_in_focus = abtem.CTF(**_ctf_common)
ctf_defocus = abtem.CTF(**_ctf_common, C10=-100)
ctf_aberrated = abtem.CTF(
    **_ctf_common,
    C10=-100,
    C23=10000,
    phi23=np.deg2rad(27.5 + 15),
)

# Complex aperture functions evaluated on the unshifted detector grid.
bf_disk_in_focus = ctf_in_focus._evaluate_from_angular_grid(alpha, phi)
bf_disk_defocus = ctf_defocus._evaluate_from_angular_grid(alpha, phi)
bf_disk_aberrated = ctf_aberrated._evaluate_from_angular_grid(alpha, phi)

# Scan-space frequency grid, rotated by -rotation_angle (i.e. +15 degrees).
rotation_angle = np.deg2rad(-15)
ct = np.cos(-rotation_angle)
st = np.sin(-rotation_angle)
_qx_axis, _qy_axis = abtem.core.grid.spatial_frequencies(scan_gpts, scan_sampling)
_qx_grid, _qy_grid = np.meshgrid(_qx_axis, _qy_axis, indexing='ij')
qx = _qx_grid * ct - _qy_grid * st
qy = _qy_grid * ct + _qx_grid * st
def return_sorted_indices(dataset_real_space_FFT):
    """Rank scan-space frequencies by total trotter power, strongest first.

    Power at each scan frequency is ``sum |G|**2`` over the two trailing
    (detector) axes. Two kinds of frequency are zeroed so they sort last:
    - the DC term;
    - every frequency whose angle ``|q| * wavelength * 1e3`` exceeds twice
      ``semiangle``. Beyond that, a disk shifted by q no longer overlaps the
      central disk.

    Reads the module-level ``qx``, ``qy``, ``wavelength`` and ``semiangle``.

    Returns
    -------
    tuple of ndarray
        Row and column index arrays, in order of decreasing power.
    """
    power = (dataset_real_space_FFT * np.conj(dataset_real_space_FFT)).sum(axis=(-1, -2)).real
    power[0, 0] = 0.0  # DC component carries no trotter information
    beyond_double_aperture = np.sqrt(qx**2 + qy**2) * wavelength * 1e3 > (semiangle * 2)
    power[beyond_double_aperture] = 0.0
    flat_order = np.argsort(-power, axis=None)
    rows, cols = np.unravel_index(flat_order, power.shape)
    return rows, cols
def return_shifted_disks(
    ctf,
    index_i,
    index_j,
):
    """Evaluate ``ctf`` on the detector grid shifted by +q and by -q.

    ``q = (qx[index_i, index_j], qy[index_i, index_j])`` is taken from the
    module-level scan-frequency grids. ``kx``, ``ky`` and ``wavelength`` are
    also read from module scope.

    Returns
    -------
    tuple
        ``(A(k + q), A(k - q))``, with angles computed in float32.
    """
    def evaluate_at(offset_x, offset_y):
        # Aperture function on the detector grid translated by (offset_x, offset_y).
        shifted_alpha, shifted_phi = angular_spatial_frequencies(
            kx + offset_x, ky + offset_y, wavelength, dtype=np.float32
        )
        return ctf._evaluate_from_angular_grid(shifted_alpha, shifted_phi)

    shift_x = qx[index_i, index_j]
    shift_y = qy[index_i, index_j]
    return evaluate_at(shift_x, shift_y), evaluate_at(-shift_x, -shift_y)
def return_phase_compensated_trotter(
    bf_disk,
    bf_disk_plus,
    bf_disk_minus,
    G_array,
):
    """Remove the modeled aperture phase from a measured trotter.

    The modeled overlap is ``Gamma = conj(A) * A_minus - A * conj(A_plus)``.
    ``G_array`` is multiplied by ``conj(Gamma) / |Gamma|``, a unit-modulus
    phase factor, wherever ``|Gamma| > 0``. Where ``Gamma`` is zero the
    factor is 0, so those pixels are zeroed.

    Returns
    -------
    ndarray
        ``G_array * conj(Gamma) / |Gamma|`` (0 where ``Gamma == 0``).
    """
    overlap = bf_disk.conj() * bf_disk_minus - bf_disk * bf_disk_plus.conj()
    magnitude = np.abs(overlap)
    unit_phase = overlap.conj()
    nonzero = magnitude > 0
    unit_phase[nonzero] /= magnitude[nonzero]
    return G_array * unit_phase
def return_masked_trotter(
    bf_disk,
    bf_disk_plus,
    bf_disk_minus,
    G_array,
):
    """Keep ``G_array`` only where the signed overlap mask is positive.

    The mask is ``|A| * (|A_minus| - |A_plus|)``. It is positive inside the
    central disk where the (k - q) disk dominates the (k + q) disk.

    Returns
    -------
    ndarray
        ``G_array`` multiplied by the boolean ``mask > 0``.
    """
    central_amplitude = np.abs(bf_disk)
    signed_overlap = central_amplitude * (np.abs(bf_disk_minus) - np.abs(bf_disk_plus))
    keep = signed_overlap > 0
    return G_array * keep
# Seed the reconstruction with the strongest trotter of the in-focus data.
sorted_i, sorted_j = return_sorted_indices(dataset_real_space_FFT_in_focus)
index_i = sorted_i[0]
index_j = sorted_j[0]
phase_compensate = False  # initial display mode; the toggle widget takes over later

bf_disk_plus, bf_disk_minus = return_shifted_disks(ctf_in_focus, index_i, index_j)
_initial_trotter_args = (
    bf_disk_in_focus,
    bf_disk_plus,
    bf_disk_minus,
    dataset_real_space_FFT_in_focus[index_i, index_j],
)
phase_compensated_trotter = return_phase_compensated_trotter(*_initial_trotter_args)
masked_trotter = return_masked_trotter(*_initial_trotter_args)

reconstructed_object = np.zeros(scan_gpts, dtype=np.complex64)
reconstructed_object[index_i, index_j] = (
    phase_compensated_trotter.sum() if phase_compensate else masked_trotter.sum() * 2
)

# Mean total counts per probe position (sum over detector, mean over scan).
normalization_in_focus, normalization_defocus, normalization_aberrated = (
    ds.data.sum((-1, -2)).mean()
    for ds in (dataset_in_focus, dataset_defocus, dataset_aberrated)
)
with plt.ioff():
    # Build the 2x3 panel figure without auto-displaying it; it is shown later
    # through ``fig.canvas`` inside the widget layout.
    dpi = 72
    fig, axs = plt.subplots(2, 3, figsize=(675 / dpi, 450 / dpi), dpi=dpi)

    # [0, 0] measured trotter at the selected scan frequency; fftshift puts
    # the detector centre in the middle of the panel.
    cmplx_data = py4DSTEM.visualize.Complex2RGB(
        np.fft.fftshift(dataset_real_space_FFT_in_focus[index_i, index_j])
    )
    im_data = axs[0, 0].imshow(cmplx_data)

    if phase_compensate:
        # [0, 1] modeled complex overlap of the central and +/- shifted disks.
        gamma_data = (
            bf_disk_in_focus.conj() * bf_disk_minus
            - bf_disk_in_focus * bf_disk_plus.conj()
        )
        gamma_data = py4DSTEM.visualize.Complex2RGB(np.fft.fftshift(gamma_data))
        im_gamma = axs[0, 1].imshow(gamma_data)
        # [0, 2] measured trotter with the modeled phase removed.
        compensated_trotter_data = py4DSTEM.visualize.Complex2RGB(
            np.fft.fftshift(phase_compensated_trotter)
        )
        im_compensated_trotter = axs[0, 2].imshow(compensated_trotter_data)
    else:
        # [0, 1] signed overlap mask; positive where the trotter is kept.
        mask_data = np.abs(bf_disk_in_focus) * (np.abs(bf_disk_minus) - np.abs(bf_disk_plus))
        im_gamma = axs[0, 1].imshow(
            np.fft.fftshift(mask_data), cmap='RdBu', vmin=-1.5, vmax=1.5
        )
        # [0, 2] measured trotter restricted to the positive-mask region.
        masked_trotter_data = py4DSTEM.visualize.Complex2RGB(np.fft.fftshift(masked_trotter))
        im_compensated_trotter = axs[0, 2].imshow(masked_trotter_data)

    # [1, 0] Fourier coefficients reconstructed so far.
    reconstructed_data = py4DSTEM.visualize.Complex2RGB(
        np.fft.fftshift(reconstructed_object), vmin=0, vmax=1, power=0.5
    )
    im_reconstructed = axs[1, 0].imshow(reconstructed_data)

    # [1, 1] / [1, 2] real-space object. The DC coefficient is the total
    # detected intensity. ifft2 divides by the number of scan positions, and
    # the result is divided by the mean per-position counts, so for
    # non-negative intensities obj has mean ~1.
    psi = reconstructed_object.copy()
    psi[0, 0] = np.abs(dataset_real_space_FFT_in_focus[0, 0]).sum()
    obj = np.fft.ifft2(psi) / normalization_in_focus
    angle_data, _, _ = py4DSTEM.visualize.return_scaled_histogram_ordering(np.angle(obj), normalize=True)
    ampl_data, _, _ = py4DSTEM.visualize.return_scaled_histogram_ordering(2 - np.abs(obj), normalize=True)
    # Tile 2x2: the FFT-based reconstruction is periodic over the scan.
    im_angle = axs[1, 1].imshow(np.tile(np.fft.fftshift(angle_data), (2, 2)), cmap='magma')
    im_ampl = axs[1, 2].imshow(np.tile(np.fft.fftshift(ampl_data), (2, 2)), cmap='gray')

    titles = [
        "measured trotters", "modeled trotters masks", "masked measured trotter",
        "reconstructed structure factor", "reconstructed phase", "reconstructed amplitude"
    ]
    phase_compensated_titles = [
        "measured trotters", "modeled complex trotter overlap", "phase-compensated trotter",
        "reconstructed structure factor", "reconstructed phase", "reconstructed amplitude"
    ]

    # Scalebars: real space is doubled because of the 2x2 tiling above.
    scalebar_real = {'pixelsize': scan_sampling[1], 'pixelunits': r'$\AA$', "Nx": scan_gpts[0] * 2, "Ny": scan_gpts[1] * 2, "labelsize": 10}
    scalebar_inverse = {'pixelsize': 1 / scan_gpts[1] / scan_sampling[1], 'pixelunits': r'$\AA^{-1}$', "Nx": scan_gpts[0], "Ny": scan_gpts[1], "labelsize": 10}
    scalebar_fourier = {'pixelsize': angular_sampling[1], 'pixelunits': 'mrad', "Nx": gpts[0], "Ny": gpts[1], "labelsize": 10}
    bars = [
        scalebar_fourier, {}, {},
        scalebar_inverse, scalebar_real, {}
    ]

    # Label the panels for the mode they were actually drawn in; previously
    # ``titles`` was used even when ``phase_compensate`` was True.
    initial_titles = phase_compensated_titles if phase_compensate else titles
    for ax, title, bar in zip(axs.flatten(), initial_titles, bars):
        ax.set(xticks=[], yticks=[], title=title)
        if bar:
            py4DSTEM.visualize.add_scalebar(ax, bar)

    fig.tight_layout()
    fig.canvas.resizable = False
    fig.canvas.header_visible = False
    fig.canvas.footer_visible = False
    fig.canvas.toolbar_visible = True
    fig.canvas.layout.width = '680px'
    fig.canvas.layout.height = "500px"
    fig.canvas.toolbar_position = 'bottom'

None
# Current selection: (ctf, bf_disk, real-space FFT, normalization).
# ``choose_dataset`` replaces the entries in place, and ``update_plots``
# reads them back from this module-level list.
arrays_to_mutate = [
    ctf_in_focus,
    bf_disk_in_focus,
    dataset_real_space_FFT_in_focus,
    normalization_in_focus,
]

# Shared widget geometry and style.
layout = ipywidgets.Layout(width='225px', height='30px')
layout_wide = ipywidgets.Layout(width='500px', height='30px')
style = dict(description_width='initial')
def reset_potentials(*args):
    """Clear the reconstruction and redraw at the strongest spatial frequency.

    Used as the reset button's click handler (the clicked button arrives in
    ``args`` and is ignored) and after a dataset switch.
    """
    reconstructed_object[:, :] = 0j
    if slider.value == 0:
        # Assigning an unchanged value fires no observer, so redraw explicitly.
        update_plots({'new': 0})
    else:
        # The slider's "value" observer runs update_plots once; calling it
        # again here would recompute and redraw the same frame twice.
        slider.value = 0
    return None
def toggle_phase_compensation(
    change
):
    """Relabel the middle and right top panels for the new mode, then redraw.

    Observer for the "phase compensate" toggle. ``change['new']`` is the new
    toggle state. ``update_plots`` reads the mode back from the widget itself,
    so no module-level flag is updated here.
    """
    chosen_titles = phase_compensated_titles if change['new'] else titles
    axs[0, 1].set_title(chosen_titles[1])
    axs[0, 2].set_title(chosen_titles[2])
    update_plots({'new': slider.value})
    return None
def choose_dataset(
    change
):
    """Swap the active dataset in ``arrays_to_mutate`` and reset the view.

    Observer for the dataset toggle buttons. ``change['new']`` must be one of
    "in focus", "defocused" or "aberrated"; any other value raises
    ``ValueError``.
    """
    selections = {
        "in focus": (
            ctf_in_focus,
            bf_disk_in_focus,
            dataset_real_space_FFT_in_focus,
            normalization_in_focus,
        ),
        "defocused": (
            ctf_defocus,
            bf_disk_defocus,
            dataset_real_space_FFT_defocus,
            normalization_defocus,
        ),
        "aberrated": (
            ctf_aberrated,
            bf_disk_aberrated,
            dataset_real_space_FFT_aberrated,
            normalization_aberrated,
        ),
    }
    selection = selections.get(change['new'])
    if selection is None:
        raise ValueError()
    # Replace entries in place so the shared list object stays the same.
    arrays_to_mutate[:] = selection
    reset_potentials()
    return None
def update_plots(
    change
):
    """Add one spatial frequency to the reconstruction and refresh every panel.

    Registered as the slider's "value" observer. The other callbacks also call
    it directly with ``{'new': slider.value}``.

    Parameters
    ----------
    change : dict
        Change notification; ``change["new"]`` is the rank (0 = strongest) of
        the scan-space spatial frequency to process.

    Returns
    -------
    None
        Side effects only. One coefficient of the module-level
        ``reconstructed_object`` is written, and the image artists of ``fig``
        are updated.
    """
    sorted_index = change["new"]
    phase_compensate = phase_compensation_toggle.value
    ctf, bf_disk, dataset_real_space_FFT, normalization = arrays_to_mutate

    # Rank by the trotter power of the *selected* dataset, not always the
    # in-focus one, so the ordering reflects the data being reconstructed.
    sorted_i, sorted_j = return_sorted_indices(dataset_real_space_FFT)
    index_i = sorted_i[sorted_index]
    index_j = sorted_j[sorted_index]
    G_array = dataset_real_space_FFT[index_i, index_j]

    # compute: only the trotter variant for the active mode is needed
    bf_disk_plus, bf_disk_minus = return_shifted_disks(ctf, index_i, index_j)
    if phase_compensate:
        trotter = return_phase_compensated_trotter(bf_disk, bf_disk_plus, bf_disk_minus, G_array)
        reconstructed_object[index_i, index_j] = trotter.sum()
        gamma_data = bf_disk.conj() * bf_disk_minus - bf_disk * bf_disk_plus.conj()
        middle_panel = py4DSTEM.visualize.Complex2RGB(np.fft.fftshift(gamma_data))
    else:
        trotter = return_masked_trotter(bf_disk, bf_disk_plus, bf_disk_minus, G_array)
        # The mask keeps one side of the overlap; the factor 2 presumably
        # compensates for the discarded lobe.
        reconstructed_object[index_i, index_j] = trotter.sum() * 2
        mask_data = np.abs(bf_disk) * (np.abs(bf_disk_minus) - np.abs(bf_disk_plus))
        middle_panel = np.fft.fftshift(mask_data)

    # visualize
    im_data.set_data(py4DSTEM.visualize.Complex2RGB(np.fft.fftshift(G_array)))
    im_gamma.set_data(middle_panel)
    im_compensated_trotter.set_data(py4DSTEM.visualize.Complex2RGB(np.fft.fftshift(trotter)))

    reconstructed_data = py4DSTEM.visualize.Complex2RGB(
        np.fft.fftshift(reconstructed_object), vmin=0, vmax=1, power=0.5
    )
    im_reconstructed.set_data(reconstructed_data)

    # The DC term is set to the total detected intensity so, after ifft2 and
    # dividing by the mean per-position counts, obj has mean ~1.
    psi = reconstructed_object.copy()
    psi[0, 0] = np.abs(dataset_real_space_FFT[0, 0]).sum()
    obj = np.fft.ifft2(psi) / normalization
    angle_data, _, _ = py4DSTEM.visualize.return_scaled_histogram_ordering(np.angle(obj), normalize=True)
    ampl_data, _, _ = py4DSTEM.visualize.return_scaled_histogram_ordering(2 - np.abs(obj), normalize=True)
    im_angle.set_data(np.tile(np.fft.fftshift(angle_data), (2, 2)))
    im_ampl.set_data(np.tile(np.fft.fftshift(ampl_data), (2, 2)))

    fig.canvas.draw_idle()
    return None
# Controls: dataset selector, reset, mode toggle, autoplay and a manual
# slider over the spatial-frequency rank (0 = strongest).
dataset_radio = ipywidgets.ToggleButtons(
    options=["in focus", "defocused", "aberrated"],
    value="in focus",
    style=style,
    layout=layout_wide,
)
_compact = dict(style=style, layout=layout)
reset_button = ipywidgets.Button(description="reset reconstructions", **_compact)
phase_compensation_toggle = ipywidgets.ToggleButton(
    value=False,
    description="phase compensate",
    **_compact,
)
play = ipywidgets.Play(value=0, min=0, max=31, step=1, interval=500, **_compact)
slider = ipywidgets.IntSlider(
    min=0,
    max=31,
    step=1,
    layout=layout_wide,
    style=style,
    description="spatial frequency index",
)

# Wiring: play drives the slider in the frontend; the slider and toggles
# trigger the Python callbacks.
ipywidgets.jslink((play, 'value'), (slider, 'value'))
for _widget, _handler in (
    (slider, update_plots),
    (phase_compensation_toggle, toggle_phase_compensation),
):
    _widget.observe(_handler, "value")
reset_button.on_click(reset_potentials)
dataset_radio.observe(choose_dataset, "value")

# Final expression: the notebook displays this layout.
ipywidgets.VBox(
    [
        ipywidgets.HBox([ipywidgets.Label("simulated dataset:"), dataset_radio]),
        ipywidgets.HBox([reset_button, phase_compensation_toggle, play]),
        slider,
        fig.canvas,
    ],
    layout=ipywidgets.Layout(align_items="center"),
)
VBox(children=(HBox(children=(Label(value='simulated dataset:'), ToggleButtons(layout=Layout(height='30px', wi…