Evaluating the Transfer of Information in Phase Retrieval STEM Techniques

Pixelated Iterative Ptychography

# enable interactive matplotlib
%matplotlib widget 

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
import ctf # import custom plotting / utils
import cmasher as cmr

from tqdm.notebook import tqdm
import ipywidgets

4D STEM Simulation

# parameters
n = 96
q_max = 2  # inverse Angstroms
q_probe = 1  # inverse Angstroms
wavelength = 0.019687  # 300kV
sampling = 1 / q_max / 2  # Angstroms
reciprocal_sampling = 2 * q_max / n  # inverse Angstroms

scan_step_size = 1  # pixels
sx = sy = n // scan_step_size
phi0 = 1.0

C10 = 0
C30 = 0

cmap = cmr.viola
sample_cmap = 'gray'
iter_ptycho_line_color = 'darkmagenta'


def _load_zero_mean_potential(path):
    """Load a projected-potential array from ``path`` and subtract its mean in place."""
    arr = np.load(path)
    arr -= arr.mean()
    return arr


sto_potential = _load_zero_mean_potential("data/STO_projected-potential_192x192_4qprobe.npy")
mof_potential = _load_zero_mean_potential("data/MOF_projected-potential_192x192_4qprobe.npy")
apo_potential = _load_zero_mean_potential("data/apoF_projected-potential_192x192_4qprobe.npy")

# example objects selectable in the dropdown (indices 1..3)
potentials = [sto_potential, mof_potential, apo_potential]

# pixel size of each example = field-of-view length / pixels along the first axis
sto_sampling = 23.67 / sto_potential.shape[0]  # Å
mof_sampling = 4.48 / mof_potential.shape[0]  # nm
apo_sampling = 19.2 / apo_potential.shape[0]  # nm

White Noise Potential

def white_noise_object_2D(n, phi0):
    """Return an (n, n) real array whose 2D FFT has modulus ``phi0`` and random phase.

    A Gaussian random phase map is drawn and then made odd under
    k -> -k (indices taken modulo n). As a result,
    ``phi0 * exp(2j*pi*phase)`` is Hermitian-symmetric and its inverse FFT
    is real up to round-off. Frequencies that are their own partner get
    zero phase: the DC term and, for even ``n``, the Nyquist row and column.
    """
    is_even = n % 2 == 0

    # positive frequency indices 1..ceil(n/2)-1 and their mirror partners n-k
    upper = n // 2 if is_even else (n + 1) // 2
    pos = np.arange(1, upper)
    neg = n - pos

    phase = np.random.randn(n, n)

    # enforce phase[-k] = -phase[k], one quadrant pair / axis at a time
    phase[pos[:, None], pos[None, :]] = -phase[neg[:, None], neg[None, :]]
    phase[pos[:, None], neg[None, :]] = -phase[neg[:, None], pos[None, :]]
    phase[0, pos] = -phase[0, neg]
    phase[pos, 0] = -phase[neg, 0]

    # self-conjugate frequencies must be real -> zero phase
    if is_even:
        phase[n // 2, :] = 0
        phase[:, n // 2] = 0
    phase[0, 0] = 0

    spectrum = np.exp(2j * np.pi * phase) * phi0

    # the imaginary part is round-off only; drop it
    return np.fft.ifft2(spectrum).real

# random white-noise phase object and its unit-modulus transmission function
potential = white_noise_object_2D(n, phi0)
complex_obj = np.exp(potential * 1j)

Probe

# The probe is built in Fourier space from a soft-edged circular aperture.

qx = qy = np.fft.fftfreq(n, sampling)
q2 = qx[:, np.newaxis]**2 + qy[np.newaxis, :]**2
q = np.sqrt(q2)

# Linear roll-off about one reciprocal pixel wide, centred on q_probe and
# clipped to [0, 1]; the square root is then applied on top of that ramp.
aperture_fourier = np.sqrt(np.clip((q_probe - q) / reciprocal_sampling + 0.5, 0, 1))
def return_chi(
    q,
    wavelength,
    C10,
    C30,
):
    """Return the aberration phase chi(q) in radians.

    chi = (2*pi / wavelength) * (alpha**2 * C10 / 2 + alpha**4 * C30 / 4),
    with alpha = q * wavelength. ``C10``, ``C30`` and ``wavelength`` must be
    in the same length unit, and ``q`` in its inverse. ``q`` may be a scalar
    or an array (evaluated elementwise).
    """
    alpha = q * wavelength
    polynomial = alpha**2 / 2 * C10 + alpha**4 / 4 * C30
    return polynomial * (2 * np.pi / wavelength)

def simulate_intensities(q,wavelength,C10,C30, batch_size=n**2, pbar=None):
    """Simulate diffraction amplitudes of ``complex_obj`` at every scan position.

    The probe is built in Fourier space as ``aperture_fourier * exp(-1j*chi)``
    and normalised to unit total Fourier intensity. It is then inverse-FFT'd
    and multiplied by ``n``. Scan positions are processed in batches through
    ``ctf.simulate_data``, using the module-level ``complex_obj``, ``row`` and
    ``col``.

    Parameters
    ----------
    q : ndarray
        Spatial-frequency magnitude grid, forwarded to ``return_chi``.
    wavelength, C10, C30 : float
        Forwarded to ``return_chi``. C10 and C30 are in the same length unit
        as ``wavelength``.
    batch_size : int
        Scan positions per ``ctf.simulate_data`` call. It does not have to
        divide the number of positions; the last batch may be shorter.
    pbar : tqdm progress bar or None
        If given, it is reset to the number of batches and advanced once per
        batch.

    Returns
    -------
    list
        ``[amplitudes, probe_array]``. ``amplitudes`` has shape
        ``(len(positions), n, n)``.

    Raises
    ------
    ValueError
        If ``batch_size`` is smaller than 1.
    """
    batch_size = int(batch_size)
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")

    # One pattern per scan position; this equals n**2 only for scan_step_size == 1.
    m = len(positions)
    # Consecutive index batches; the last one may be shorter than batch_size.
    batches = [
        np.arange(start, min(start + batch_size, m))
        for start in range(0, m, batch_size)
    ]
    amplitudes = np.zeros((m,n,n))

    if pbar is not None:
        pbar.reset(len(batches))
        pbar.colour = None
        pbar.refresh()

    exp_chi = np.exp(-1j*return_chi(q,wavelength,C10,C30))
    probe_array_fourier = aperture_fourier * exp_chi
    # normalise so that sum(|P(q)|^2) == 1
    probe_array_fourier /= np.sqrt(np.sum(np.abs(probe_array_fourier)**2))
    probe_array = np.fft.ifft2(probe_array_fourier) * n

    for batch_order in batches:
        amplitudes[batch_order] = ctf.simulate_data(
            complex_obj,
            probe_array,
            row[batch_order],
            col[batch_order],
        )
        if pbar is not None:
            pbar.update(1)

    if pbar is not None:
        pbar.colour = 'green'

    return [amplitudes, probe_array]
# Scan positions in pixels on a regular grid with spacing scan_step_size,
# flattened to an (N, 2) array. ctf.return_patch_indices returns the row/col
# index arrays later used to gather object patches for each position.
x = y = np.arange(0., n, scan_step_size)
xx, yy = np.meshgrid(x, y, indexing='ij')
positions = np.stack((xx.ravel(), yy.ravel()), axis=1)
row, col = ctf.return_patch_indices(positions, (n, n), (n, n))

# initial simulation with the default C10 = C30 = 0
amplitudes_probe = simulate_intensities(q, wavelength, C10, C30, batch_size=1024, pbar=None)

# intensities[0]: squared amplitudes / n**2 reshaped to (sx, sy, n, n)
# intensities[1]: their sum over the last two (detector) axes
intensities = [amplitudes_probe[0].reshape((sx, sy, n, n))**2 / n**2, None]
intensities[1] = intensities[0].sum((-1, -2))

Ptychographic Reconstruction

def ptycho_reconstruction(
    amplitudes,
    row,
    col,
    positions,
    recon_array,
    probe_array,
    pbars,
    batch_size = n**2,
    iterations=64,
    step_size=1.0,
):
    """Refine a complex object estimate by iterative, batched ptychography.

    For each batch of scan positions:

    1. The exit waves ``probe * object_patch`` are Fourier transformed.
    2. Their moduli are replaced by the measured ``amplitudes``, keeping the
       phases.
    3. The back-transformed difference is multiplied by the conjugate probe.
    4. The result is scattered onto the object with ``ctf.sum_patches`` and
       added with weight ``step_size``.

    After every update the object modulus is clipped to [0, 1]. Scan order is
    reshuffled after every outer iteration, and the figure is refreshed via
    ``update_ptycho_panel``.

    Parameters
    ----------
    amplitudes : ndarray
        Measured Fourier amplitudes; axis 0 indexes scan positions.
    row, col : ndarray
        Object-pixel index arrays per scan position (used to gather patches).
    positions : ndarray
        Scan positions, forwarded to ``ctf.sum_patches``.
    recon_array : ndarray
        Initial object estimate. It is copied, so the caller's array is not
        modified.
    probe_array : ndarray
        2D probe of shape ``(nx, ny)``; rescaled internally.
    pbars : tuple
        ``(outer_pbar, inner_pbar)`` tqdm progress bars.
    batch_size : int
        Scan positions per update. It need not divide the number of
        positions; the last batch of each iteration may be shorter.
    iterations : int
        Number of passes over all scan positions.
    step_size : float
        Weight applied to each object update.

    Returns
    -------
    ndarray
        The refined complex object.

    Raises
    ------
    ValueError
        If ``batch_size`` is smaller than 1.
    """
    m = amplitudes.shape[0]
    nx, ny = probe_array.shape

    batch_size = int(batch_size)
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")
    # Batches per iteration (ceil division). Named n_batch so it does not
    # shadow the global grid size ``n``.
    n_batch = -(-m // batch_size)

    # Work on a complex copy; previously the first ``+=`` wrote into the
    # caller's array in place.
    recon_array = np.array(recon_array, dtype=np.complex128)

    order = np.arange(m)
    np.random.shuffle(order)

    # Scalar derived from the total measured intensity. It scales the probe
    # and normalises every object update.
    probe_normalization = np.mean(np.sum(amplitudes**2,0)) / nx / ny
    shifted_probes = probe_array * np.sqrt(probe_normalization)
    conj_probes = np.conj(shifted_probes)  # loop-invariant

    outer_pbar, inner_pbar = pbars

    outer_pbar.reset(iterations)
    outer_pbar.colour = None
    outer_pbar.refresh()

    for _ in range(iterations):

        inner_pbar.reset(n_batch)
        inner_pbar.colour = None
        inner_pbar.refresh()

        for start in range(0, m, batch_size):
            batch_order = order[start:start + batch_size]

            batch_amplitudes = amplitudes[batch_order]
            batch_pos = positions[batch_order]
            obj_patches = recon_array[row[batch_order], col[batch_order]]

            # far-field transform of the current exit waves
            fourier_overlap = np.fft.fft2(shifted_probes * obj_patches)

            # measured modulus with the current phase; back-transform the difference
            modified_fourier_overlap = batch_amplitudes*np.exp(1j*np.angle(fourier_overlap))
            grad = np.fft.ifft2(modified_fourier_overlap-fourier_overlap)

            update = ctf.sum_patches(
                grad*conj_probes,
                batch_pos,
                (nx,ny),
                (nx,ny),
            ) / probe_normalization

            recon_array += step_size*update

            # clip the object modulus to at most 1, keeping its phase
            amp = np.abs(recon_array).clip(0.0,1.0)
            recon_array = amp * np.exp(1j*np.angle(recon_array))
            inner_pbar.update(1)

        np.random.shuffle(order)
        update_ptycho_panel(recon_array)
        outer_pbar.update(1)

    inner_pbar.colour='green'
    outer_pbar.colour='green'
    return recon_array
# Current object estimate, boxed in a one-element list so the callbacks can
# replace ptycho_recon[0] without rebinding a global.
ptycho_recon = [np.ones((n, n), dtype=np.complex128)]

# x-values of the radial CTF plot: 0, dq, 2*dq, ... up to about q_max
q_bins = np.arange(0, q_max + reciprocal_sampling, reciprocal_sampling)
with plt.ioff():
    # create the figure without auto-displaying it; its canvas is placed in the widget layout
    dpi = 72
    fig, axs = plt.subplots(1, 3, figsize=(640 / dpi, 270 / dpi), dpi=dpi)

ax_ctf_pixelated_ptycho = axs[0]
ax_ctf_rad = axs[1]

# --- left panel: 2D CTF image (data filled in by update_ptycho_panel) ---
im_ctf_ptycho = ax_ctf_pixelated_ptycho.imshow(np.zeros((n, n)), cmap=cmap, vmin=-1, vmax=1)
ctf.add_scalebar(
    ax_ctf_pixelated_ptycho,
    length=n // 4,
    sampling=reciprocal_sampling,
    units=r'$q_{\mathrm{probe}}$',
    color="black",
)
ax_ctf_pixelated_ptycho.set(xticks=[], yticks=[], title="contrast transfer function (CTF)")

# --- middle panel: radially averaged CTF versus spatial frequency ---
(plot_ctf_ptycho,) = ax_ctf_rad.plot(q_bins, np.zeros_like(q_bins), color=iter_ptycho_line_color)
ax_ctf_rad.set(
    xlim=[0, 2],
    ylim=[0, 1.025],
    aspect=2 / 1.025,
    xticks=[0, 1, 2],
    yticks=[],
    xlabel=r"spatial frequency, $q/q_{\mathrm{probe}}$",
    title="radially averaged CTF",
)

# --- right panel: reconstructed phase or CTF-filtered example object ---
ax_obj = axs[2]
im_obj = ax_obj.imshow(np.zeros((n, n)), cmap=sample_cmap, vmin=0, vmax=1)

# One scale bar per selectable object (white noise, STO, MOF, apoferritin).
# They are added in this order so they can be unpacked from ax_obj.artists below.
for bar_sampling, bar_units in (
    (sampling, r'Å'),
    (sto_sampling, r'Å'),
    (mof_sampling, r'nm'),
    (apo_sampling, r'nm'),
):
    ctf.add_scalebar(ax_obj, length=40, sampling=bar_sampling, units=bar_units, size_vertical=2)
del bar_sampling, bar_units

noise_scalebar, sto_scalebar, mof_scalebar, apo_scalebar = ax_obj.artists
for hidden_scalebar in (sto_scalebar, mof_scalebar, apo_scalebar):
    hidden_scalebar.set_visible(False)
del hidden_scalebar

ax_obj.set(xticks=[], yticks=[], title="CTF-convolved weak phase object")

# dimmed until update_ptycho_panel draws a result
im_obj.set_alpha(0.25)
fig.tight_layout()

# fixed-size canvas without header, footer or toolbar
fig.canvas.resizable = False
fig.canvas.header_visible = False
fig.canvas.footer_visible = False
fig.canvas.toolbar_visible = False
fig.canvas.layout.height = "280px"
fig.canvas.layout.width = '640px'
None
def update_ptycho_panel(ptycho_recon):
    """Redraw all three figure panels for the object estimate ``ptycho_recon``.

    The CTF of the reconstructed phase (``ctf.compute_ctf``) is shown as an
    fft-shifted image and as a radial average. The right panel then shows
    either the reconstructed phase itself (dropdown index 0) or the selected
    example potential filtered in Fourier space by that CTF. In the latter
    case the CTF is zero-padded to the potential's grid size; the padding is
    derived from the array shapes rather than a fixed width.
    """
    ctf_ptycho = ctf.compute_ctf(np.angle(ptycho_recon))
    im_ctf_ptycho.set_data(
        np.fft.fftshift(ctf_ptycho)
    )

    _, I_bins_ptycho = ctf.radially_average_ctf(
        ctf_ptycho,
        (sampling,sampling)
    )
    plot_ctf_ptycho.set_ydata(I_bins_ptycho)

    object_index = object_dropdown.value
    if object_index==0:
        # white-noise object: show the reconstructed phase directly
        convolved_object = np.angle(ptycho_recon)
    else:
        chosen_potential = potentials[object_index-1]
        # Embed the centred CTF in a zero-padded grid matching the potential.
        # pad_before = N//2 - n//2 puts the zero frequency at index N//2 of the
        # padded, centred array, and ifftshift maps it back to index 0. For a
        # 96x96 CTF and the 192x192 potentials this is the former fixed pad
        # of 48 on every side.
        pad_width = []
        for full_size, ctf_size in zip(chosen_potential.shape, ctf_ptycho.shape):
            before = full_size // 2 - ctf_size // 2
            pad_width.append((before, full_size - ctf_size - before))
        padded_ctf = np.fft.ifftshift(
            np.pad(np.fft.fftshift(ctf_ptycho), pad_width)
        )
        convolved_object = np.fft.ifft2(
            np.fft.fft2(chosen_potential) * padded_ctf
        ).real

    # show only the scale bar that matches the displayed object
    noise_scalebar.set_visible(object_index==0)
    sto_scalebar.set_visible(object_index==1)
    mof_scalebar.set_visible(object_index==2)
    apo_scalebar.set_visible(object_index==3)

    convolved_object = ctf.histogram_scaling(convolved_object,normalize=True)
    im_obj.set_data(convolved_object)

    # undo the dimming applied while results are stale
    im_ctf_ptycho.set_alpha(1)
    plot_ctf_ptycho.set_alpha(1)
    im_obj.set_alpha(1)

    fig.canvas.draw()
    return None
def compute_ptycho_updates(
    batch_size,
    iterations,
    pbars,
):
    """Continue the reconstruction from ``ptycho_recon[0]`` on the current data.

    The result is stored back into ``ptycho_recon[0]``.
    """
    measured_amplitudes = amplitudes_probe[0]
    probe = amplitudes_probe[1]
    ptycho_recon[0] = ptycho_reconstruction(
        measured_amplitudes, row, col, positions,
        ptycho_recon[0], probe, pbars,
        batch_size=batch_size,
        iterations=iterations,
    )
style = {'description_width': 'initial'}
layout = ipywidgets.Layout(width="320px", height="30px")
smaller_layout = ipywidgets.Layout(width="160px", height="30px")
kwargs = {'style': style, 'layout': layout, 'continuous_update': False}

# Offer only batch sizes that evenly divide m = n**2 (ascending order).
m = n**2
batch_sizes = np.flatnonzero(m % np.arange(1, m + 1) == 0) + 1
batch_size_slider = ipywidgets.SelectionSlider(
    options=batch_sizes,
    value=batch_sizes[-7],  # 1024 when n == 96
    description="batch size",
    **kwargs,
)

# Interactive controls; the callback functions below read their values and
# toggle their ``disabled`` / ``button_style`` state.

# number of outer passes forwarded to ptycho_reconstruction
iterations_slider = ipywidgets.IntSlider(
    value = 4, min = 1, max = 32, step = 1,
    description = "(outer loop) iterations",
    **kwargs
)

# runs the (slow) reconstruction
iterate_button = ipywidgets.Button(
    description="reconstruct (expensive)",
    layout=smaller_layout,
)

# restores the uniform starting object
reset_button = ipywidgets.Button(
    description="reset object",
    layout=smaller_layout,
)

# C10 in Å, the same length unit as ``wavelength`` in return_chi
C10_slider = ipywidgets.FloatSlider(
    value = 0,
    min = -500,
    max = 500, 
    step = 1,
    description = r"negative defocus, $C_{1,0}$ [Å]",
    **kwargs
)

# NOTE(review): labelled in µm, while return_chi treats C30 in the length unit
# of ``wavelength`` (Å); the value must be scaled by 1e4 wherever it is passed on.
C30_slider = ipywidgets.FloatSlider(
    value = 0,
    min = -100,
    max = 100, 
    step = 0.1,
    description = r"spherical aberration, $C_{3,0}$ [µm]",
    **kwargs
)

# re-runs simulate_intensities for the current aberrations
simulate_button = ipywidgets.Button(
    description='simulate (expensive)',
    layout=ipywidgets.Layout(width="160px",height="30px")
)

# 0 = white-noise object; 1..3 select potentials[value-1] in update_ptycho_panel
object_dropdown = ipywidgets.Dropdown(
    options=[("white noise object",0),("strontium titanate",1),("metal-organic framework",2),("apoferritin protein",3)],
    **kwargs
)

# total is replaced by pbar.reset() inside simulate_intensities. Only the first
# two container children are kept (presumably the label and bar widgets;
# TODO confirm against tqdm.notebook).
simulation_pbar = tqdm(total=9,display=False)
simulation_pbar_wrapper = ipywidgets.HBox(simulation_pbar.container.children[:2],layout=ipywidgets.Layout(width="160px"))

# highlight "reconstruct" as the pending action
iterate_button.button_style = 'warning'

def defocus_wrapper(*args):
    """Invalidate current results after an aberration slider changes.

    Flags the simulate button, resets the reconstruction and dims all three
    panels until new data are simulated.
    """
    simulate_button.button_style = 'warning'
    reset_wrapper()
    for stale_artist in (im_ctf_ptycho, plot_ctf_ptycho, im_obj):
        stale_artist.set_alpha(0.25)
    simulation_pbar.reset()


for aberration_slider in (C10_slider, C30_slider):
    aberration_slider.observe(defocus_wrapper, names='value')

def simulate_wrapper(*args):
    """Re-simulate the 4D-STEM data for the current aberration sliders.

    The C30 slider is labelled in µm, but ``return_chi`` expects C30 in the
    same length unit as ``wavelength`` (Å). The slider value is therefore
    converted (1 µm = 1e4 Å) before use.

    All controls are re-enabled even if the simulation raises, so the UI
    cannot get stuck in the disabled state.
    """
    disable_all(True)
    try:
        amplitudes_probe[:] = simulate_intensities(
            q,
            wavelength,
            C10=C10_slider.value,
            C30=C30_slider.value * 1e4,  # µm -> Å
            batch_size=1024,
            pbar=simulation_pbar
        )

        intensities[0] = amplitudes_probe[0].reshape((sx,sy,n,n))**2 / n**2
        intensities[1] = intensities[0].sum((-1,-2))
    finally:
        disable_all(False)

    iterate_button.button_style = 'warning'
    outer_reconstruct_pbar.reset()
    inner_reconstruct_pbar.reset()
    simulate_button.button_style = ''


simulate_button.on_click(simulate_wrapper)

def reset_wrapper(*args):
    """Replace the object estimate with a uniform all-ones object and redraw."""
    fresh_object = np.ones((n, n), dtype=np.complex128)
    ptycho_recon[0] = fresh_object
    update_ptycho_panel(fresh_object)
    # prompt for reconstruction only when the simulated data are not stale
    if simulate_button.button_style != 'warning':
        iterate_button.button_style = 'warning'
    im_obj.set_alpha(0.25)
    for reconstruct_pbar in (outer_reconstruct_pbar, inner_reconstruct_pbar):
        reconstruct_pbar.reset()


reset_button.on_click(reset_wrapper)

def disable_all(boolean):
    """Set ``disabled = boolean`` on every interactive control (and the simulation pbar box)."""
    controls = (
        batch_size_slider,
        iterations_slider,
        reset_button,
        iterate_button,
        C10_slider,
        C30_slider,
        object_dropdown,
        simulate_button,
        simulation_pbar_wrapper,
    )
    for control in controls:
        control.disabled = boolean

def click_wrapper(*args):
    """Run a reconstruction with the current batch-size and iteration settings.

    Controls are disabled while it runs. They are re-enabled afterwards even
    if the reconstruction raises, so the UI cannot remain locked.
    """
    disable_all(True)
    try:
        compute_ptycho_updates(
            batch_size=batch_size_slider.value,
            iterations=iterations_slider.value,
            pbars=(outer_reconstruct_pbar,inner_reconstruct_pbar),
        )
    finally:
        disable_all(False)
    # cleared only on success; a failed run keeps the "pending" highlight
    iterate_button.button_style = ''


iterate_button.on_click(click_wrapper)
def _compact_pbar(total):
    """Create an undisplayed tqdm bar plus a 160px HBox holding its first two child widgets."""
    bar = tqdm(total=total, display=False)
    box = ipywidgets.HBox(bar.container.children[:2], layout=ipywidgets.Layout(width="160px"))
    return bar, box


# totals are placeholders; ptycho_reconstruction resets them before use
outer_reconstruct_pbar, outer_reconstruct_pbar_wrapper = _compact_pbar(4)
inner_reconstruct_pbar, inner_reconstruct_pbar_wrapper = _compact_pbar(9)

def update_current_ptycho_wrapper(*args):
    """Redraw the panels for the stored estimate (e.g. after the object dropdown changes)."""
    update_ptycho_panel(ptycho_recon[0])


object_dropdown.observe(update_current_ptycho_wrapper, names="value")
# Two control rows for simulation, a separator, two rows for reconstruction,
# then the figure canvas; the final expression is what the cell displays.
control_rows = [
    ipywidgets.HBox([C10_slider, C30_slider]),
    ipywidgets.HBox([simulate_button, simulation_pbar_wrapper, object_dropdown]),
    ipywidgets.HTML("<hr>", layout=ipywidgets.Layout(width="640px")),
    ipywidgets.HBox([batch_size_slider, iterations_slider]),
    ipywidgets.HBox([
        reset_button,
        iterate_button,
        outer_reconstruct_pbar_wrapper,
        inner_reconstruct_pbar_wrapper,
    ]),
]
ipywidgets.VBox([ipywidgets.VBox(control_rows), fig.canvas])