Direct Ptychography

%matplotlib widget
import abtem
import ase
import py4DSTEM

import numpy as np
import matplotlib.pyplot as plt

import ipywidgets
# Evaluate abTEM eagerly (no lazy dask graphs) so the arrays below are concrete.
abtem.config.set({"dask.lazy":False});
# Simulated 4D-STEM datasets under three imaging conditions: in focus, defocused,
# and with extra aberrations. Per the file names these are "STO" simulations
# (presumably SrTiO3 — TODO confirm).
file_path = 'data/'
file_data_in_focus = file_path + 'dpc_STO_simulation_in-focus_1e5.h5'
file_data_defocus = file_path + 'dpc_STO_simulation_defocus_1e5.h5'
file_data_aberrated = file_path + 'dpc_STO_simulation_aberrated_1e5.h5'

dataset_in_focus = py4DSTEM.read(file_data_in_focus)
dataset_defocus = py4DSTEM.read(file_data_defocus)
dataset_aberrated = py4DSTEM.read(file_data_aberrated)
# G(K, Q) for each dataset. ifftshift moves the detector-plane origin to index
# [0, 0] on the last two axes, then fft2 transforms over the first two axes.
# Assumes .data is laid out (scan_x, scan_y, det_x, det_y) — TODO confirm.
dataset_real_space_FFT_in_focus = np.fft.fft2(np.fft.ifftshift(dataset_in_focus.data,axes=(-1,-2)),axes=(0,1))
dataset_real_space_FFT_defocus = np.fft.fft2(np.fft.ifftshift(dataset_defocus.data,axes=(-1,-2)),axes=(0,1))
dataset_real_space_FFT_aberrated = np.fft.fft2(np.fft.ifftshift(dataset_aberrated.data,axes=(-1,-2)),axes=(0,1))
# Beam energy, presumably in eV (abTEM's convention).
energy = 200e3
# Probe convergence semiangle in mrad (compared against angles scaled by 1e3
# in return_sorted_indices).
semiangle = 20
wavelength = abtem.core.energy.energy2wavelength(energy)
# Scan grid shape (R) and detector grid shape (Q).
scan_gpts = dataset_in_focus.Rshape
gpts = dataset_in_focus.Qshape

# A single calibration value is reused for both axes (square pixels).
scan_sampling = (dataset_in_focus.calibration.R_pixel_size,)*2
angular_sampling = (dataset_in_focus.calibration.Q_pixel_size,)*2
# Real-space sampling of a probe grid whose reciprocal sampling matches one
# detector pixel: wavelength / (angular pixel [mrad -> rad] * number of pixels).
sampling = (wavelength*1e3/angular_sampling[0]/gpts[0],)*2
def angular_spatial_frequencies(
    kx,
    ky,
    wavelength,
    dtype=None
):
    """Turn two 1D spatial-frequency axes into polar scattering-angle grids.

    Parameters
    ----------
    kx, ky : 1D arrays
        Spatial frequencies along the first and second grid axes.
    wavelength : float
        Electron wavelength, in the inverse units of ``kx``/``ky``.
    dtype : optional
        If given, both returned grids are cast to this dtype.

    Returns
    -------
    alpha, phi : 2D arrays of shape ``(len(kx), len(ky))``
        ``alpha`` is ``|k| * wavelength``; ``phi`` is ``arctan2(ky, kx)``.
    """
    kx_column = kx[:, None]
    ky_row = ky[None, :]

    radial_angle = np.sqrt(kx_column ** 2 + ky_row ** 2) * wavelength
    azimuth = np.arctan2(ky_row, kx_column)

    if dtype is None:
        return radial_angle, azimuth
    return radial_angle.astype(dtype), azimuth.astype(dtype)
# 1D spatial-frequency axes of the detector-sized probe grid, and the matching
# scattering-angle / azimuth grids used to evaluate the CTFs.
kx, ky = abtem.core.grid.spatial_frequencies(gpts, sampling)
alpha, phi = angular_spatial_frequencies(kx,ky,wavelength,dtype=np.float32)

# Probe CTFs for the three simulated conditions. They share cutoff, grid and
# energy and differ only in aberration coefficients.
ctf_in_focus = abtem.CTF(
    semiangle_cutoff=semiangle,
    sampling=sampling,
    gpts=gpts,
    energy=energy,
)

# Defocus only (C10; units follow abTEM's convention — presumably Angstrom).
ctf_defocus = abtem.CTF(
    semiangle_cutoff=semiangle,
    sampling=sampling,
    gpts=gpts,
    energy=energy,
    C10=-100,
)


# Defocus plus C23 (three-fold astigmatism in Krivanek notation). The +15 deg
# offset in phi23 presumably matches the -15 deg scan rotation below — TODO confirm.
ctf_aberrated = abtem.CTF(
    semiangle_cutoff=semiangle,
    sampling=sampling,
    gpts=gpts,
    energy=energy,
    C10=-100,C23=10000,phi23=np.deg2rad(27.5+15),
)

# Aperture functions A(K) on the unshifted detector grid, via a private abTEM
# helper (presumably aperture times aberration phase factor — TODO confirm).
bf_disk_in_focus = ctf_in_focus._evaluate_from_angular_grid(alpha,phi)
bf_disk_defocus = ctf_defocus._evaluate_from_angular_grid(alpha,phi)
bf_disk_aberrated = ctf_aberrated._evaluate_from_angular_grid(alpha,phi)
# Rotation between the scan and detector frames. The scan frequencies (qx, qy)
# are rotated by -rotation_angle so they can be added directly to the detector
# frequencies kx/ky (see return_shifted_disks).
rotation_angle = np.deg2rad(-15)
ct = np.cos(-rotation_angle)
st = np.sin(-rotation_angle)

# 2D grids (indexing='ij') of scan spatial frequencies Q, rotated into the detector frame.
qx, qy = abtem.core.grid.spatial_frequencies(scan_gpts, scan_sampling)
qx, qy = np.meshgrid(qx,qy, indexing='ij')
qx, qy = qx * ct - qy * st, qy * ct + qx * st
def return_sorted_indices(dataset_real_space_FFT):
    """Rank scan spatial frequencies Q by their total trotter power.

    The power at each Q is ``sum_K |G(K, Q)|^2`` over the last two (detector)
    axes. Two kinds of Q get zero power so they sort last:
    - the DC term Q = 0;
    - any Q whose scattering angle exceeds twice the probe semiangle, where the
      shifted aperture disks can no longer overlap.

    Reads the module-level ``qx``, ``qy``, ``wavelength`` and ``semiangle``.

    Returns
    -------
    (sorted_i, sorted_j)
        Index arrays into the scan-frequency grid, strongest first.
    """
    power = (dataset_real_space_FFT * dataset_real_space_FFT.conj()).sum((-1, -2)).real

    # Exclude the DC term and frequencies outside the double-disk overlap cutoff.
    power[0, 0] = 0.0
    beyond_cutoff = np.sqrt(qx**2 + qy**2) * wavelength * 1e3 > (semiangle * 2)
    power[beyond_cutoff] = 0.0

    descending_flat_order = np.argsort(-power, axis=None)
    rows, cols = np.unravel_index(descending_flat_order, power.shape)
    return rows, cols
def return_shifted_disks(
    ctf,
    index_i,
    index_j,
):
    """Evaluate ``ctf`` on the detector grid shifted by +Q and by -Q.

    Q is the rotated scan spatial frequency ``(qx[index_i, index_j],
    qy[index_i, index_j])``. The module-level detector axes ``kx``/``ky`` are
    offset by it and converted to angles with ``angular_spatial_frequencies``.

    Returns
    -------
    (bf_disk_plus, bf_disk_minus)
        The CTF evaluated at K + Q and at K - Q, respectively.
    """
    shift_x = qx[index_i, index_j]
    shift_y = qy[index_i, index_j]

    def _evaluate_shifted(dx, dy):
        # Evaluate the CTF on the detector frequencies offset by (dx, dy).
        shifted_alpha, shifted_phi = angular_spatial_frequencies(
            kx + dx, ky + dy, wavelength, dtype=np.float32
        )
        return ctf._evaluate_from_angular_grid(shifted_alpha, shifted_phi)

    disk_plus = _evaluate_shifted(shift_x, shift_y)
    disk_minus = _evaluate_shifted(-shift_x, -shift_y)
    return disk_plus, disk_minus

def return_phase_compensated_trotter(
    bf_disk,
    bf_disk_plus,
    bf_disk_minus,
    G_array,
):
    """Multiply the measured trotter by the conjugate phase of the modeled overlap.

    The modeled overlap is ``gamma = conj(A(K)) A(K-Q) - A(K) conj(A(K+Q))``.
    Where ``gamma`` is non-zero it is reduced to a unit phasor
    ``conj(gamma) / |gamma|``; where it is zero it stays zero.

    Returns
    -------
    ``G_array * conj(gamma) / |gamma|``, elementwise.
    """
    overlap = bf_disk.conj() * bf_disk_minus - bf_disk * bf_disk_plus.conj()
    magnitude = np.abs(overlap)
    unit_phasor = overlap.conj()
    # Normalize in place only where the overlap is non-zero; zeros stay zero.
    np.divide(unit_phasor, magnitude, out=unit_phasor, where=magnitude > 0)
    return G_array * unit_phasor

def return_masked_trotter(
    bf_disk,
    bf_disk_plus,
    bf_disk_minus,
    G_array,
):
    """Keep the trotter only where the K-Q disk outweighs the K+Q disk.

    The weight is ``|A(K)| * (|A(K-Q)| - |A(K+Q)|)``. Elements of ``G_array``
    with a strictly positive weight are kept; all others are zeroed.
    """
    lobe_weight = np.abs(bf_disk) * (np.abs(bf_disk_minus) - np.abs(bf_disk_plus))
    keep = lobe_weight > 0
    return G_array * keep
# Start from the scan frequency with the most trotter power in the in-focus data.
sorted_i, sorted_j = return_sorted_indices(dataset_real_space_FFT_in_focus)
index_i, index_j = sorted_i[0], sorted_j[0]
# Initial reconstruction mode: False = masked trotter, True = phase-compensated trotter.
phase_compensate = False
# CTF evaluated at K + Q and K - Q for the selected Q.
bf_disk_plus, bf_disk_minus = return_shifted_disks(
    ctf_in_focus,
    index_i,
    index_j
)
phase_compensated_trotter = return_phase_compensated_trotter(
    bf_disk_in_focus,
    bf_disk_plus,
    bf_disk_minus,
    dataset_real_space_FFT_in_focus[index_i,index_j]
)
masked_trotter = return_masked_trotter(
    bf_disk_in_focus,
    bf_disk_plus,
    bf_disk_minus,
    dataset_real_space_FFT_in_focus[index_i,index_j]
)
# Accumulated Fourier coefficients of the reconstructed object, one per scan frequency.
reconstructed_object = np.zeros(scan_gpts,dtype=np.complex64)
if phase_compensate:
    reconstructed_object[index_i,index_j] = phase_compensated_trotter.sum()
else:
    # The mask keeps only the region where |A(K-Q)| > |A(K+Q)| (one trotter lobe);
    # the factor 2 presumably compensates for the discarded lobe — TODO confirm.
    reconstructed_object[index_i,index_j] = masked_trotter.sum()*2

# Mean total detector counts per probe position, used to normalize the object.
normalization_in_focus = dataset_in_focus.data.sum((-1,-2)).mean()
normalization_defocus = dataset_defocus.data.sum((-1,-2)).mean()
normalization_aberrated = dataset_aberrated.data.sum((-1,-2)).mean()
# Build the 2x3 figure without auto-displaying it; its canvas is embedded in the
# widget layout at the end of the notebook instead.
with plt.ioff():
    dpi = 72
    fig, axs = plt.subplots(2,3, figsize=(675/dpi, 450/dpi), dpi=dpi)

# raw data: measured trotter G(K, Q) at the selected Q, shown as complex RGB
cmplx_data = py4DSTEM.visualize.Complex2RGB(np.fft.fftshift(dataset_real_space_FFT_in_focus[index_i,index_j]))
im_data = axs[0,0].imshow(cmplx_data)

if phase_compensate:
    # gamma: modeled complex overlap (same expression as return_phase_compensated_trotter)
    gamma_data = bf_disk_in_focus.conj() * bf_disk_minus - bf_disk_in_focus * bf_disk_plus.conj()
    gamma_data = py4DSTEM.visualize.Complex2RGB(np.fft.fftshift(gamma_data))
    im_gamma = axs[0,1].imshow(gamma_data)

    # compensated trotter
    compensated_trotter_data = py4DSTEM.visualize.Complex2RGB(np.fft.fftshift(phase_compensated_trotter))
    im_compensated_trotter = axs[0,2].imshow(compensated_trotter_data)
else:
    # mask: signed lobe weight |A(K)| (|A(K-Q)| - |A(K+Q)|) (same as return_masked_trotter)
    mask_data = np.abs(bf_disk_in_focus) * (np.abs(bf_disk_minus) - np.abs(bf_disk_plus))
    im_gamma = axs[0,1].imshow(np.fft.fftshift(mask_data),cmap='RdBu',vmin=-1.5,vmax=1.5)

    # masked trotter
    masked_trotter_data = py4DSTEM.visualize.Complex2RGB(np.fft.fftshift(masked_trotter))
    im_compensated_trotter = axs[0,2].imshow(masked_trotter_data)

# reconstructed psi: Fourier coefficients accumulated so far (fftshifted for display)
reconstructed_data = py4DSTEM.visualize.Complex2RGB(np.fft.fftshift(reconstructed_object),vmin=0,vmax=1,power=0.5)
im_reconstructed = axs[1,0].imshow(reconstructed_data)

# reconstructed obj: restore the DC term, inverse FFT, normalize by mean counts
psi = reconstructed_object.copy()
psi[0,0] = np.abs(dataset_real_space_FFT_in_focus[0,0]).sum()
obj = np.fft.ifft2(psi) / normalization_in_focus

# Histogram-scaled phase and amplitude (amplitude shown as 2 - |obj|, i.e.
# contrast-inverted), each tiled 2x2 for display.
angle_data, _, _ = py4DSTEM.visualize.return_scaled_histogram_ordering(np.angle(obj),normalize=True)
ampl_data, _, _ = py4DSTEM.visualize.return_scaled_histogram_ordering(2-np.abs(obj),normalize=True)
im_angle = axs[1,1].imshow(np.tile(np.fft.fftshift(angle_data),(2,2)),cmap='magma')
im_ampl = axs[1,2].imshow(np.tile(np.fft.fftshift(ampl_data),(2,2)),cmap='gray')

titles = [
    "measured trotters","modeled trotters masks", "masked measured trotter",
    "reconstructed structure factor", "reconstructed phase", "reconstructed amplitude"
]

phase_compensated_titles = [
    "measured trotters","modeled complex trotter overlap", "phase-compensated trotter",
    "reconstructed structure factor", "reconstructed phase", "reconstructed amplitude"
]


# Real-space bar spans the 2x2-tiled image; the inverse-space bar uses the
# reciprocal of the scan field of view; the Fourier bar uses detector mrad.
scalebar_real = {'pixelsize':scan_sampling[1],'pixelunits':r'$\AA$',"Nx":scan_gpts[0]*2,"Ny":scan_gpts[1]*2,"labelsize":10}
scalebar_inverse = {'pixelsize':1/scan_gpts[1]/scan_sampling[1],'pixelunits':r'$\AA^{-1}$',"Nx":scan_gpts[0],"Ny":scan_gpts[1],"labelsize":10}
scalebar_fourier = {'pixelsize':angular_sampling[1],'pixelunits':'mrad',"Nx":gpts[0],"Ny":gpts[1],"labelsize":10}

bars = [
    scalebar_fourier,{},{},
    scalebar_inverse,scalebar_real,{}
       ]

# Label the panels with the title set that matches what the top row actually
# shows, so flipping `phase_compensate` above keeps titles and data consistent.
initial_titles = phase_compensated_titles if phase_compensate else titles
for ax, title, bar in zip(axs.flatten(),initial_titles, bars):
    ax.set(xticks=[],yticks=[],title=title)
    if bar:
        py4DSTEM.visualize.add_scalebar(ax,bar)

fig.tight_layout()
fig.canvas.resizable = False
fig.canvas.header_visible = False
fig.canvas.footer_visible = False
fig.canvas.toolbar_visible = True
fig.canvas.layout.width = '680px'
fig.canvas.layout.height = "500px"
fig.canvas.toolbar_position = 'bottom'
None
# Shared "active dataset" state read by the widget callbacks, in the order
# [ctf, bf_disk, dataset_real_space_FFT, normalization]. choose_dataset
# replaces the entries in place; update_plots unpacks them.
arrays_to_mutate = list((
    ctf_in_focus,
    bf_disk_in_focus,
    dataset_real_space_FFT_in_focus,
    normalization_in_focus,
))
# Widget sizing: a narrow layout for buttons/player, a wide one for the
# dataset selector and slider.
layout = ipywidgets.Layout(height='30px', width='225px')
layout_wide = ipywidgets.Layout(height='30px', width='500px')
style = dict(description_width='initial')

def reset_potentials(*args):
    """Clear the accumulated reconstruction and redraw from the strongest Q.

    ``*args`` is accepted and ignored so this works both as a
    ``Button.on_click`` callback (ipywidgets passes the button) and as a
    plain call.

    When the slider value actually changes, setting it to 0 already triggers
    ``update_plots`` through ``slider.observe(update_plots, "value")``. The
    explicit call is therefore made only when the slider is already at 0 (no
    change event fires). This avoids re-sorting the full 4D FFT twice.
    """
    reconstructed_object[:,:] = 0j
    if slider.value == 0:
        # No change event will fire; redraw explicitly.
        update_plots({'new':slider.value})
    else:
        # The slider observer redraws with the freshly cleared reconstruction.
        slider.value=0
    return None

def toggle_phase_compensation(
    change
):
    """Match the top-row panel titles to the toggle state, then redraw.

    ``change['new']`` is the toggle's new boolean value. ``update_plots`` reads
    the toggle itself to choose the reconstruction algorithm.
    """
    title_set = phase_compensated_titles if change['new'] else titles
    for column in (1, 2):
        axs[0, column].set_title(title_set[column])

    update_plots({'new': slider.value})
    return None

def choose_dataset(
    change
):
    """Switch the active dataset and reset the reconstruction.

    Parameters
    ----------
    change : dict
        Change dict whose ``'new'`` entry is one of the ``dataset_radio``
        labels: "in focus", "defocused" or "aberrated".

    The matching CTF, bright-field disk, scan-FFT data and normalization are
    written into ``arrays_to_mutate`` in place, so ``update_plots`` sees them.
    The reconstruction is then cleared via ``reset_potentials``.

    Raises
    ------
    ValueError
        If ``change['new']`` is not one of the known dataset labels.
    """
    dataset_states = {
        "in focus": (ctf_in_focus, bf_disk_in_focus, dataset_real_space_FFT_in_focus, normalization_in_focus),
        "defocused": (ctf_defocus, bf_disk_defocus, dataset_real_space_FFT_defocus, normalization_defocus),
        "aberrated": (ctf_aberrated, bf_disk_aberrated, dataset_real_space_FFT_aberrated, normalization_aberrated),
    }
    dataset = change['new']
    try:
        selected = dataset_states[dataset]
    except (KeyError, TypeError):
        raise ValueError(
            f"unknown dataset {dataset!r}; expected one of {list(dataset_states)}"
        ) from None

    # Slice-assign so the shared list object is mutated, not rebound.
    arrays_to_mutate[:] = selected

    reset_potentials()
    return None

def update_plots(
    change
):
    """Reconstruct one scan spatial frequency Q and refresh every panel.

    Parameters
    ----------
    change : dict
        Change dict; ``change["new"]`` is the rank (0 = most trotter power) of
        the Q to process, in the order given by ``return_sorted_indices``.

    The active dataset comes from ``arrays_to_mutate`` and the algorithm choice
    from ``phase_compensation_toggle``. The new Fourier coefficient is written
    into the shared ``reconstructed_object``, so coefficients accumulate across
    calls. All six image artists are updated in place.
    """
    sorted_index = change["new"]
    phase_compensate = phase_compensation_toggle.value
    ctf, bf_disk, dataset_real_space_FFT, normalization = arrays_to_mutate
    # Rank Q using the *selected* dataset, so that the frequency ordering, the
    # trotter shown and the reconstruction all come from the same data.
    sorted_i, sorted_j = return_sorted_indices(dataset_real_space_FFT)

    # compute
    index_i = sorted_i[sorted_index]
    index_j = sorted_j[sorted_index]

    bf_disk_plus, bf_disk_minus = return_shifted_disks(
        ctf,
        index_i,
        index_j
    )

    phase_compensated_trotter = return_phase_compensated_trotter(
        bf_disk,
        bf_disk_plus,
        bf_disk_minus,
        dataset_real_space_FFT[index_i,index_j]
    )

    masked_trotter = return_masked_trotter(
        bf_disk,
        bf_disk_plus,
        bf_disk_minus,
        dataset_real_space_FFT[index_i,index_j]
    )

    # The masked trotter keeps one lobe only; the factor 2 presumably
    # compensates for the discarded lobe.
    reconstructed_object[index_i,index_j] = phase_compensated_trotter.sum() if phase_compensate else masked_trotter.sum()*2

    # visualize

    cmplx_data = py4DSTEM.visualize.Complex2RGB(np.fft.fftshift(dataset_real_space_FFT[index_i,index_j]))
    im_data.set_data(cmplx_data)

    if phase_compensate:
        gamma_data = bf_disk.conj() * bf_disk_minus - bf_disk * bf_disk_plus.conj()
        gamma_data = py4DSTEM.visualize.Complex2RGB(np.fft.fftshift(gamma_data))
        im_gamma.set_data(gamma_data)

        compensated_trotter_data = py4DSTEM.visualize.Complex2RGB(np.fft.fftshift(phase_compensated_trotter))
        im_compensated_trotter.set_data(compensated_trotter_data)
    else:
        mask_data = np.abs(bf_disk) * (np.abs(bf_disk_minus) - np.abs(bf_disk_plus))
        im_gamma.set_data(np.fft.fftshift(mask_data))

        masked_trotter_data = py4DSTEM.visualize.Complex2RGB(np.fft.fftshift(masked_trotter))
        im_compensated_trotter.set_data(masked_trotter_data)

    reconstructed_data = py4DSTEM.visualize.Complex2RGB(np.fft.fftshift(reconstructed_object),vmin=0,vmax=1,power=0.5)
    im_reconstructed.set_data(reconstructed_data)

    # Restore the DC term, inverse FFT, and normalize by mean counts per probe position.
    psi = reconstructed_object.copy()
    psi[0,0] = np.abs(dataset_real_space_FFT[0,0]).sum()
    obj = np.fft.ifft2(psi) / normalization

    angle_data, _, _ = py4DSTEM.visualize.return_scaled_histogram_ordering(np.angle(obj),normalize=True)
    ampl_data, _, _ = py4DSTEM.visualize.return_scaled_histogram_ordering(2-np.abs(obj),normalize=True)

    im_angle.set_data(np.tile(np.fft.fftshift(angle_data),(2,2)))
    im_ampl.set_data(np.tile(np.fft.fftshift(ampl_data),(2,2)))

    fig.canvas.draw_idle()
    return None

# Dataset selector; the labels must match those handled by choose_dataset.
dataset_radio = ipywidgets.ToggleButtons(
    options = ["in focus", "defocused", "aberrated"],
    value="in focus",
    style=style,
    layout=layout_wide,
)

reset_button = ipywidgets.Button(
    description="reset reconstructions",
    style=style,
    layout=layout,
)

# Start the toggle in the same mode that drew the initial figure, so the first
# slider move uses the algorithm the figure is already showing.
phase_compensation_toggle = ipywidgets.ToggleButton(
    value=bool(phase_compensate),
    description="phase compensate",
    style=style,
    layout=layout,
)

# Player that steps through ranks 0..31 (linked to the slider below).
play = ipywidgets.Play(
    value=0,
    min=0,
    max=31,
    step=1,
    interval=500,
    style=style,
    layout=layout,
)

slider = ipywidgets.IntSlider(
    min=0,
    max=31,
    step=1,
    layout=layout_wide,
    style=style,
    description="spatial frequency index"
)

# Keep player and slider in sync in the front end; wire the Python callbacks.
ipywidgets.jslink((play, 'value'), (slider, 'value'))
slider.observe(update_plots,"value")
phase_compensation_toggle.observe(toggle_phase_compensation,"value")
reset_button.on_click(reset_potentials)
dataset_radio.observe(choose_dataset,"value")
# Last expression of the cell: the notebook displays this layout.
ipywidgets.VBox(
    [
        ipywidgets.HBox([ipywidgets.Label("simulated dataset:"),dataset_radio]),
        ipywidgets.HBox([reset_button,phase_compensation_toggle,play]),
        slider,
        fig.canvas
    ],
    layout=ipywidgets.Layout(
        align_items="center"
    )
)
Colin Ophus Lab | Stanford
Understanding materials, atom by atom — Colin Ophus Lab
Lab Group Website by Curvenote