Evaluating the Transfer of Information in Phase Retrieval STEM Techniques

Segmented Iterative Ptychography

# enable interactive matplotlib
%matplotlib widget 

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
import ctf # import custom plotting / utils
import cmasher as cmr

from tqdm.notebook import tqdm

import ipywidgets
from IPython.display import display

4D STEM Simulation

# parameters
# Global simulation constants; most functions below read these as module-level names.
n = 96  # number of grid points along each axis (used by every fftfreq / array shape below)
q_max = 2 # inverse Angstroms
q_probe = 1 # inverse Angstroms
wavelength = 0.019687 # 300kV
sampling = 1 / q_max / 2 # Angstroms
reciprocal_sampling = 2 * q_max / n # inverse Angstroms

scan_step_size = 1 # pixels
sx = sy = n//scan_step_size  # number of scan positions along each scan axis
phi0 = 1.0  # Fourier-space amplitude handed to white_noise_object_2D

# colormaps for the CTF panels and for the real-space sample panels
cmap = cmr.eclipse
sample_cmap = 'gray'

# line colors for the radially averaged CTF plot
segmented_icom_line_color = 'cornflowerblue'
segmented_ptycho_line_color = 'orchid'
pixelated_ptycho_line_color = 'darkmagenta'  # NOTE(review): not referenced in the visible part of this file

White Noise Potential

def white_noise_object_2D(n, phi0):
    """Return a real (n, n) array whose 2D FFT has modulus phi0 and random phase.

    A random phase map is drawn with ``np.random.randn`` and made odd under
    k -> -k, so that the resulting Fourier array is Hermitian and its inverse
    FFT is real up to rounding. Frequencies that are their own mirror image
    (the DC term and, for even n, the Nyquist row/column) get zero phase:
    their amplitude stays phi0, only the phase is removed.
    """
    is_even = n % 2 == 0

    # Positive-frequency indices and their mirrored partners,
    # ordered so that mirror[i] == n - forward[i].
    forward = np.arange(1, (n if is_even else n + 1) // 2)
    mirror = np.arange(n // 2 + 1, n)[::-1]

    phase = np.random.randn(n, n)

    # Impose phase(-k) = -phase(k), quadrant by quadrant ...
    phase[np.ix_(forward, forward)] = -phase[np.ix_(mirror, mirror)]
    phase[np.ix_(forward, mirror)] = -phase[np.ix_(mirror, forward)]
    # ... and along the two axes through the origin.
    phase[0, forward] = -phase[0, mirror]
    phase[forward, 0] = -phase[mirror, 0]

    # Self-conjugate frequencies must be real: set their phase to zero.
    if is_even:
        phase[n // 2, :] = 0
        phase[:, n // 2] = 0
    phase[0, 0] = 0

    spectrum = np.exp(2j * np.pi * phase) * phi0

    # The imaginary part is only floating-point residue; discard it.
    return np.fft.ifft2(spectrum).real

# potential
# Random white-noise test potential on the n x n grid, and the complex object
# exp(i * potential) built from it; simulate_intensities reads complex_obj as a global.
potential = white_noise_object_2D(n,phi0)
complex_obj = np.exp(1j*potential)

Import sample potentials

# Load precomputed sample potentials from disk and subtract each array's mean in place.
# The file names suggest 192x192 arrays sampled out to 4*q_probe — TODO confirm against the data.
sto_potential = np.load("data/STO_projected-potential_192x192_4qprobe.npy")
sto_potential -= sto_potential.mean()
mof_potential = np.load("data/MOF_projected-potential_192x192_4qprobe.npy")
mof_potential -= mof_potential.mean()
apo_potential = np.load("data/apoF_projected-potential_192x192_4qprobe.npy")
apo_potential -= apo_potential.mean()

Probe

# we build probe in Fourier space, using a soft aperture

# Corner-centered (FFT-ordered) spatial-frequency grid and its radial magnitude.
qx = qy = np.fft.fftfreq(n,sampling)
q2 = qx[:,None]**2 + qy[None,:]**2
q  = np.sqrt(q2)

# Aperture amplitude: 1 well inside q_probe, 0 well outside, with a linear
# intensity ramp one reciprocal-space pixel wide centered on q_probe
# (the square root turns the clipped intensity ramp into an amplitude).
aperture_fourier = np.sqrt(
    np.clip(
        (q_probe - q)/reciprocal_sampling + 0.5,
        0,
        1,
    ),
)

# # normalized s.t. np.sum(np.abs(probe_array_fourier)**2) = 1.0
# probe_array_fourier /= np.sqrt(np.sum(np.abs(probe_array_fourier)**2))

# # we then take the inverse FFT, and normalize s.t. np.sum(np.abs(probe_array)**2) = 1.0
# probe_array = np.fft.ifft2(probe_array_fourier) * n
def simulate_intensities(defocus, batch_size=n**2, pbar=None):
    """Simulate one diffraction pattern per scan position for a defocused probe.

    The probe is built in Fourier space as ``aperture_fourier`` times the
    defocus phase ``exp(-i*pi*wavelength*defocus*q**2)``, normalised so that
    the sum of its squared modulus is 1, and brought to real space with an
    inverse FFT scaled by ``n`` (which keeps the real-space sum of squared
    modulus at 1 as well).

    Reads the module-level ``n``, ``wavelength``, ``q``, ``aperture_fourier``,
    ``complex_obj``, ``row`` and ``col``.

    Parameters
    ----------
    defocus
        Defocus value entering the probe phase (same length unit as 1/q).
    batch_size
        Number of scan positions handed to ``ctf.simulate_data`` per call.
        Any positive integer is accepted; the last batch is simply shorter
        when ``batch_size`` does not divide ``n**2``.
    pbar
        Optional tqdm-like progress bar; it is reset to the number of batches,
        advanced once per batch and coloured green when finished.

    Returns
    -------
    list
        ``[amplitudes, probe_array, probe_array_fourier]`` where ``amplitudes``
        has shape ``(n**2, n, n)`` and holds whatever ``ctf.simulate_data``
        returns for each position (presumably diffraction amplitudes — the
        callers square it to obtain intensities).

    Raises
    ------
    ValueError
        If ``batch_size`` is smaller than 1.
    """
    m = n**2
    batch_size = int(batch_size)
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")

    # Ceiling division: a trailing partial batch counts as a batch.
    n_batch = -(-m // batch_size)
    amplitudes = np.zeros((m,n,n))

    if pbar is not None:
        pbar.reset(n_batch)
        pbar.colour = None
        pbar.refresh()

    probe_array_fourier = aperture_fourier * np.exp(-1j*np.pi*wavelength*defocus*q**2)
    probe_array_fourier /= np.sqrt(np.sum(np.abs(probe_array_fourier)**2))
    probe_array = np.fft.ifft2(probe_array_fourier) * n

    # Contiguous chunks of scan indices; identical to the former
    # reshape((n_batch, batch_size)) whenever batch_size divides m.
    for start in range(0, m, batch_size):
        batch_order = np.arange(start, min(start + batch_size, m))
        amplitudes[batch_order] = ctf.simulate_data(
            complex_obj,
            probe_array,
            row[batch_order],
            col[batch_order],
        )
        if pbar is not None:
            pbar.update(1)

    if pbar is not None:
        pbar.colour = 'green'

    return [amplitudes, probe_array, probe_array_fourier]
# Raster scan positions (in pixels) on an 'ij'-indexed grid, flattened to (num_positions, 2).
x = y = np.arange(0.,n,scan_step_size)
xx, yy = np.meshgrid(x,y,indexing='ij')
positions = np.stack((xx.ravel(),yy.ravel()),axis=-1)
# Per-position index arrays used to cut object patches (recon[row, col]) — semantics defined in ctf.return_patch_indices.
row, col = ctf.return_patch_indices(positions,(n,n),(n,n))

# Initial in-focus simulation; the list is [amplitudes, probe_array, probe_array_fourier]
# and is later overwritten in place by simulate_wrapper.
amplitudes_probe = simulate_intensities(defocus=0, batch_size=1024, pbar=None)

# intensities[0]: squared amplitudes reshaped to (sx, sy, n, n) and divided by n**2;
# intensities[1]: total intensity of each pattern, shape (sx, sy).
intensities = [amplitudes_probe[0].reshape((sx,sy,n,n))**2 / n**2,None]
intensities[1] = intensities[0].sum((-1,-2))

Virtual Detectors and CoM calculation

def annular_segmented_detectors(
    gpts,
    sampling,
    n_angular_bins,
    rotation_offset = 0,
    inner_radius = 0,
    outer_radius = np.inf,
):
    """Build the masks of an annular detector split into equal angular segments.

    Parameters
    ----------
    gpts
        ``(nx, ny)`` number of detector pixels along each axis.
    sampling
        ``(sx, sy)`` real-space sampling passed to ``np.fft.fftfreq`` to
        obtain the spatial-frequency coordinate of each pixel.
    n_angular_bins
        Number of equal-angle segments.
    rotation_offset
        Angle in radians added to every pixel's azimuth before binning.
    inner_radius, outer_radius
        A pixel belongs to the annulus when
        ``inner_radius <= |k| < outer_radius``.

    Returns
    -------
    list of ndarray
        ``n_angular_bins`` integer 0/1 masks of shape ``(nx, ny)``, each
        ``fftshift``-ed so that zero frequency sits at the array center.
        Pixels outside the annulus are 0 in every mask.
    """
    nx, ny = gpts
    sx, sy = sampling

    k_x = np.fft.fftfreq(nx, sx)
    k_y = np.fft.fftfreq(ny, sy)

    k = np.sqrt(k_x[:, None]**2 + k_y[None, :]**2)
    radial_mask = (inner_radius <= k) & (k < outer_radius)

    theta = (np.arctan2(k_y[None, :], k_x[:, None]) + rotation_offset) % (2 * np.pi)
    sector = np.floor(n_angular_bins * (theta / (2 * np.pi)))
    # For an azimuth a hair below zero the floating-point modulo rounds up to
    # exactly 2*pi, which would give sector == n_angular_bins and leave the
    # pixel in no segment at all; such pixels belong to the last segment.
    sector = np.clip(sector, 0, n_angular_bins - 1)

    # Segment labels 1..n_angular_bins inside the annulus, 0 outside.
    angular_bins = (sector + 1) * radial_mask.astype("int")

    return [
        np.fft.fftshift((angular_bins == i).astype("int"))
        for i in range(1, n_angular_bins + 1)
    ]

def compute_com_using_virtual_detectors(
    corner_centered_intensities,
    center_centered_masks,
    corner_centered_intensities_sum,
    sx,sy,
    kxa,kya,
):
    """Estimate the centre of mass of each pattern from segment-integrated signals.

    Every detector segment contributes its fraction of the pattern's total
    intensity, placed at the mean spatial frequency of the segment's pixels.

    Parameters
    ----------
    corner_centered_intensities
        Array of shape ``(sx, sy, nx, ny)`` with zero frequency at index [0, 0].
    center_centered_masks
        Sequence of ``(nx, ny)`` masks with zero frequency at the array
        center; they are ``ifftshift``-ed here to match the intensities.
    corner_centered_intensities_sum
        Total intensity of each pattern, shape ``(sx, sy)``.
    sx, sy
        Number of scan positions along each axis.
    kxa, kya
        Corner-centered ``(nx, ny)`` grids holding the x and y spatial
        frequency of every detector pixel.

    Returns
    -------
    (com_x, com_y)
        Two ``(sx, sy)`` arrays.
    """
    masks = np.fft.ifftshift(np.asarray(center_centered_masks), axes=(-1,-2))

    com_x = np.zeros((sx,sy))
    com_y = np.zeros((sx,sy))

    for mask in masks:
        rows, cols = np.where(mask)
        if rows.size == 0:
            # A segment without pixels collects no signal; skipping it avoids
            # the NaN that the mean of an empty selection would inject.
            continue
        fraction = (
            corner_centered_intensities[:, :, rows, cols].sum(-1)
            / corner_centered_intensities_sum
        )
        com_x += fraction * np.mean(kxa[rows, cols])
        com_y += fraction * np.mean(kya[rows, cols])

    return com_x, com_y

def integrate_com(
    com_x,
    com_y,
    kx_op,
    ky_op,
):
    """Combine two centre-of-mass components into one real image in Fourier space.

    Each component is Fourier transformed and multiplied element-wise by its
    operator (``kx_op`` for ``com_x``, ``ky_op`` for ``com_y``); the two
    products are added, transformed back, and the real part is returned.
    """
    spectrum_x = np.fft.fft2(com_x) * kx_op
    spectrum_y = np.fft.fft2(com_y) * ky_op
    integrated = np.fft.ifft2(spectrum_x + spectrum_y)
    return integrated.real

def bin_amplitudes_using_virtual_detectors(
    corner_centered_amplitudes,
    center_centered_masks,
):
    """Integrate each diffraction pattern over every detector segment.

    Parameters
    ----------
    corner_centered_amplitudes
        Array of shape ``(num_patterns, nx, ny)`` with zero frequency at
        index [0, 0].
    center_centered_masks
        Sequence of ``(nx, ny)`` masks with zero frequency at the array
        center; converted to boolean and ``ifftshift``-ed to match.

    Returns
    -------
    ndarray
        Shape ``(num_masks, num_patterns)``; entry ``[i, j]`` is the square
        root of the summed squared amplitude of pattern ``j`` inside mask ``i``.
    """
    masks = np.fft.ifftshift(
        np.asarray(center_centered_masks).astype(np.bool_),
        axes=(-1, -2),
    )

    # Square once; the amplitude stack is large and the square is mask-independent.
    squared_amplitudes = corner_centered_amplitudes**2

    values = np.zeros((masks.shape[0], corner_centered_amplitudes.shape[0]))
    for index, mask in enumerate(masks):
        values[index] = np.sqrt(np.sum(squared_amplitudes * mask, axis=(-1, -2)))

    return values

def virtual_detector_ptycho_reconstruction(
    binned_amplitude_values,
    row,
    col,
    positions,
    center_centered_masks,
    recon,
    probe_array,
    pbars,
    batch_size = n**2,
    iterations=64,
    step_size=1.0,
):
    """Iteratively refine a complex object from detector-segment amplitudes.

    For every batch of scan positions the current exit wave (probe times
    object patch) is Fourier transformed, its pixels inside each detector
    segment are rescaled so the segment's integrated intensity matches the
    measured value, pixels outside all segments are set to zero, and the
    difference to the unmodified wave is back-propagated into an object update.

    Parameters
    ----------
    binned_amplitude_values
        Array of shape ``(num_masks, num_patterns)`` of segment amplitudes.
    row, col
        Index arrays such that ``recon[row[i], col[i]]`` is the object patch
        for pattern ``i``.
    positions
        Scan positions, indexed per pattern and forwarded to ``ctf.sum_patches``.
    center_centered_masks
        Detector masks with zero frequency at the array center.
    recon
        Current complex object estimate. NOTE: the first update is applied
        in place (``recon += ...``) before the name is rebound to a new array.
    probe_array
        2D probe; its shape ``(nx, ny)`` is also passed to ``ctf.sum_patches``.
    pbars
        ``(outer_pbar, inner_pbar)`` progress bars for iterations and batches.
    batch_size
        Patterns per batch; must divide the number of patterns because the
        shuffled order is reshaped to ``(num_batches, batch_size)``. The
        default is evaluated once, at definition time, from the module-level ``n``.
    iterations
        Number of passes over all patterns.
    step_size
        Multiplier applied to each object update.

    Returns
    -------
    ndarray
        The updated complex object.

    Side effects: calls the module-level ``update_ptycho_panel`` after every
    iteration and drives both progress bars.
    """
    m = binned_amplitude_values.shape[1]  # number of diffraction patterns
    nx, ny = probe_array.shape
    n = int(m // batch_size)  # number of batches; shadows the module-level `n` inside this function
    
    # random visiting order of the patterns, reshuffled after every iteration
    order = np.arange(m)
    np.random.shuffle(order)

    # corner-centered boolean masks, plus the pixels covered by no mask
    # (assumes the masks do not overlap — TODO confirm)
    masks = np.fft.ifftshift(np.asarray(center_centered_masks,dtype=np.bool_),axes=(-1,-2))
    inverse_mask = (1-masks.sum(0)).astype(np.bool_)
    
    # normalization
    # mean total binned intensity per pattern, divided by the number of probe pixels
    probe_normalization = np.mean(np.sum(binned_amplitude_values**2,0)) / nx / ny
    shifted_probes = probe_array * np.sqrt(probe_normalization)

    outer_pbar,inner_pbar = pbars
    
    outer_pbar.reset(iterations)
    outer_pbar.colour= None
    outer_pbar.refresh()
    
    for iter_index in range(iterations):

        inner_pbar.reset(n)
        inner_pbar.colour= None
        inner_pbar.refresh()
        
        for batch_index in range(n):

            batch_order = order.reshape((n,batch_size))[batch_index]
        
            batch_amplitudes = binned_amplitude_values[:,batch_order]
            batch_pos = positions[batch_order]
            batch_row = row[batch_order]
            batch_col = col[batch_order]
            
            # recon
            obj_patches = recon[batch_row,batch_col]

            # forward model: exit wave and its Fourier transform
            overlap = shifted_probes * obj_patches
            fourier_overlap = np.fft.fft2(overlap)
            fourier_intensities = np.abs(fourier_overlap)**2
            
            # preprocess fourier overlap
            # total intensity is recorded BEFORE the uncovered pixels are zeroed
            old_fourier_overlap_sum = np.sum(np.abs(fourier_overlap)**2)
            fourier_overlap[...,inverse_mask] = 0.0

            modified_fourier_overlap = fourier_overlap.copy()
            new_fourier_overlap_sum = 0.0
            for mask, amp_val in zip(masks,batch_amplitudes):
                # modelled integrated intensity of this segment, one value per pattern
                squared_val = np.sum(fourier_intensities * mask,axis=(-1,-2))
                new_fourier_overlap_sum += np.sum(squared_val)
                # rescale the segment so its integrated amplitude equals the measured one
                # NOTE(review): squared_val == 0 would divide by zero here — confirm it cannot occur
                modified_fourier_overlap[...,mask] *= (amp_val/np.sqrt(squared_val))[:,None]

            # global rescale by sqrt(total intensity / intensity inside the segments)
            modified_fourier_overlap /= np.sqrt(old_fourier_overlap_sum/new_fourier_overlap_sum)

            # real-space difference between the corrected and the current exit wave
            grad = np.fft.ifft2(modified_fourier_overlap-fourier_overlap)

            update = ctf.sum_patches(
                grad*np.conj(shifted_probes),
                batch_pos,
                (nx,ny),
                (nx,ny),
            ) / probe_normalization
            
            recon += (step_size*update)

            # constrain the object amplitude to [0, 1] while keeping its phase;
            # this rebinds `recon` to a new array
            amp = np.abs(recon).clip(0.0,1.0)
            recon = amp * np.exp(1j*np.angle(recon))
            inner_pbar.update(1)

        np.random.shuffle(order)
        update_ptycho_panel(recon)
        outer_pbar.update(1)

    inner_pbar.colour='green'
    outer_pbar.colour='green'
    return recon
# Spatial frequencies
# Corner-centered (FFT-ordered) frequency grids in float32, 'ij' indexed.
kx = ky = np.fft.fftfreq(n,sampling).astype(np.float32)
kxa, kya = np.meshgrid(kx, ky, indexing='ij')

k2 = kxa**2 + kya**2
k = np.sqrt(k2)  # computed before k2[0, 0] is overwritten, so k[0, 0] stays 0
k2[0, 0] = np.inf  # makes both operators exactly 0 at zero frequency instead of 0/0

# iCoM operators
# Fourier-space multipliers -i*k/|k|^2 applied to the CoM components in integrate_com.
kx_op = -1.0j * kxa / k2
ky_op = -1.0j * kya / k2

# Compute the inverse error
# Radial weight k*pi/sqrt(2); only used to weight the CTF in update_figure.
inverse_error = (k*np.pi/np.sqrt(2))
# Initial masks and recon
# Default detector: one ring of 4 segments between q_probe/2 and 1.05*q_probe.
# Stored in a one-element list so update_figure can replace the masks in place.
virtual_masks_annular = [ annular_segmented_detectors(
    gpts=(n,n),
    sampling=(sampling,sampling),
    n_angular_bins=4,
    inner_radius=q_probe/2,
    outer_radius=q_probe*1.05,
    rotation_offset=0,
)]
# com
# Segment-based centre of mass for every scan position.
com_x, com_y = compute_com_using_virtual_detectors(
    intensities[0],
    virtual_masks_annular[0],
    intensities[1],
    sx,sy,
    kxa,kya,
)

# Integrate the CoM field and derive its contrast transfer function
# (the exact CTF definition lives in ctf.compute_ctf).
icom_annular = integrate_com(com_x,com_y,kx_op,ky_op)
ctf_annular = ctf.compute_ctf(icom_annular) 

# Radial average used for the line plot: bin positions and averaged values.
q_bins_annular, I_bins_annular = ctf.radially_average_ctf(
    ctf_annular,
    (sampling,sampling)
)

# # Analytical CTF (probe autocorrelation)
# ctf_analytic = np.abs(
#     np.real(
#         np.fft.ifft2(
#             np.abs(
#                 np.fft.fft2(
#                     amplitudes_probe[2]
#                 )
#             )**2
#         )
#     )
# )

# # Radially-averaged CTF and SNR
# q_bins_analytic, I_bins_analytic = ctf.radially_average_ctf(ctf_analytic,(sampling,sampling))
# One-element lists act as mutable module-level state that the widget callbacks overwrite.
# amplitude_values[0]: cached segment amplitudes, None until computed or after the detector changes.
amplitude_values = [None]
# ptycho_recon[0]: current complex object estimate, initialised to all ones.
ptycho_recon = [np.ones((n,n),dtype=np.complex128)]
# Create the 2x4 figure with interactive mode off so it is not displayed on creation;
# its canvas is embedded in the widget layout at the end of the file.
with plt.ioff():
    dpi=72
    fig, axs = plt.subplots(2,4,figsize=(640/dpi,400/dpi),dpi=dpi)

# detector
ax_detector = axs[0,0]
im_detector = ax_detector.imshow(ctf.combined_images_rgb(virtual_masks_annular[0]))

# annular CTF
ax_ctf_annular_dpc = axs[0,1]
im_ctf_dpc = ax_ctf_annular_dpc.imshow(ctf.histogram_scaling(np.fft.fftshift(ctf_annular),normalize=True),cmap=cmap)

# ptycho CTF
# placeholder image; filled in by update_ptycho_panel
ax_ctf_annular_ptycho = axs[0,2]
im_ctf_ptycho = ax_ctf_annular_ptycho.imshow(np.zeros((n,n)),cmap=cmap,vmin=0,vmax=1)

# analytic CTF radially-averaged
ax_ctf_rad = axs[0,3]
# plot_ctf_dpc_analytical = ax_ctf_rad.plot(q_bins_analytic,I_bins_analytic,color='k',label='pixelated iCOM')[0]

# Line handles kept so callbacks can update the y-data; the ptycho curve starts at zero.
plot_ctf_dpc = ax_ctf_rad.plot(q_bins_annular, I_bins_annular, color=segmented_icom_line_color,label='segmented iCOM')[0]
plot_ctf_ptycho = ax_ctf_rad.plot(q_bins_annular, np.zeros_like(I_bins_annular), color=segmented_ptycho_line_color,label='segmented ptycho')[0]
ax_ctf_rad.legend()

# remove ticks, add titles to 2D-plots
# The loop covers all eight axes (row-major order); the radial plot gets its x ticks back below.
for ax, title in zip(
    axs.flatten(),
    [
        "detector geometry",
        "segmented iCOM CTF",
        "segmented ptycho CTF",
        "radially averaged CTF",
        "white noise object",
        "strontium titanate",
        "metal-organic framework",
        "apoferritin protein",
    ]
):
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_title(title)

# reciprocal-space scalebars on the three 2D panels of the top row
for ax in axs[0,:3]:
    ctf.add_scalebar(ax,length=n//4,sampling=reciprocal_sampling,units=r'$q_{\mathrm{probe}}$')

# Radial CTF plot: fixed ranges, dashed vertical lines at the initial inner/outer collection angles.
# update_figure later removes and redraws these lines via ax_ctf_rad.collections[0].
ax_ctf_rad.set_ylim([0,1])
ax_ctf_rad.set_xlim([0,q_max])
ax_ctf_rad.vlines([q_probe/2,q_probe*1.05],0,2,colors='k',linestyles='--',linewidth=1,)
ax_ctf_rad.set_xticks([0,q_probe,q_max])
ax_ctf_rad.set_xticklabels([0,1,2])
ax_ctf_rad.set_aspect(2)
ax_ctf_rad.set_xlabel(r"spatial frequency, $q/q_{\mathrm{probe}}$")

# Bottom row: one panel per sample. Each starts as an all-zero n x n placeholder
# and receives its image from update_ptycho_panel.
ax_white_noise_obj = axs[1,0]
im_white_noise_obj = ax_white_noise_obj.imshow(
    np.zeros((n,n)),vmin=0,vmax=1,
    cmap=sample_cmap
)
ctf.add_scalebar(ax_white_noise_obj,length=n//5,sampling=sampling,units=r'Å')


ax_sto_obj = axs[1,1]
im_sto_obj = ax_sto_obj.imshow(
    np.zeros((n,n)),vmin=0,vmax=1,
    cmap=sample_cmap
)
# NOTE(review): 23.67 is presumably the field of view in Å spread over the n displayed pixels — confirm
sto_sampling = 23.67 / n  # Å
ctf.add_scalebar(ax_sto_obj,length=n//5,sampling=sto_sampling,units=r'Å')

ax_mof_obj = axs[1,2]
im_mof_obj = ax_mof_obj.imshow(
    np.zeros((n,n)),vmin=0,vmax=1,
    cmap=sample_cmap
)
mof_sampling = 4.48 / n  # nm
ctf.add_scalebar(ax_mof_obj,length=n//5,sampling=mof_sampling,units=r'nm')

ax_apo_obj = axs[1,3]
im_apo_obj = ax_apo_obj.imshow(
    np.zeros((n,n)),vmin=0,vmax=1,
    cmap=sample_cmap
)
apo_sampling = 19.2 / n  # nm
ctf.add_scalebar(ax_apo_obj,length=n//5,sampling=apo_sampling,units=r'nm')

# Dim every panel that depends on a reconstruction until one has been computed.
im_ctf_ptycho.set_alpha(0.25)
im_white_noise_obj.set_alpha(0.25)
im_sto_obj.set_alpha(0.25)
im_mof_obj.set_alpha(0.25)
im_apo_obj.set_alpha(0.25)

# fix ipympl canvas from resizing
# Hide the canvas chrome (header, footer, toolbar) and pin the widget size in pixels.
fig.canvas.resizable = False
fig.canvas.header_visible = False
fig.canvas.footer_visible = False
fig.canvas.toolbar_visible = False
# fig.canvas.toolbar_visible = True
# fig.canvas.toolbar_position = 'bottom'
fig.canvas.layout.width = '640px'
fig.canvas.layout.height = '420px'
fig.tight_layout()
# fig
def update_ptycho_panel(ptycho_recon):
    """Refresh every panel that depends on the ptychographic reconstruction.

    Computes the CTF of ``ptycho_recon`` with ``ctf.compute_ctf``, shows it in
    the ptycho CTF image, applies it as a Fourier-space filter to the white
    noise potential and to the three loaded sample potentials, updates the
    radially averaged ptycho curve, restores full opacity on those artists
    and redraws the canvas.

    The CTF is zero-padded around its center to the size of each sample
    before filtering (no padding for the white-noise potential, which already
    lives on the CTF grid). This assumes every sample is at least as large as
    the CTF and that the size difference is even.

    Parameters
    ----------
    ptycho_recon
        Complex object estimate passed to ``ctf.compute_ctf``.
    """
    ctf_ptycho = ctf.compute_ctf(ptycho_recon)
    centered_ctf = np.fft.fftshift(ctf_ptycho)

    im_ctf_ptycho.set_data(
        ctf.histogram_scaling(centered_ctf, normalize=True)
    )

    def _filter_with_ctf(sample_potential):
        """Return the histogram-scaled sample after multiplying its FFT by the padded CTF."""
        pad_width = [
            ((full - small) // 2,) * 2
            for full, small in zip(sample_potential.shape, centered_ctf.shape)
        ]
        padded_ctf = np.fft.ifftshift(np.pad(centered_ctf, pad_width))
        filtered = np.fft.ifft2(np.fft.fft2(sample_potential) * padded_ctf).real
        return ctf.histogram_scaling(filtered, normalize=True)

    # real space samples
    for image, sample_potential in (
        (im_white_noise_obj, potential),
        (im_sto_obj, sto_potential),
        (im_mof_obj, mof_potential),
        (im_apo_obj, apo_potential),
    ):
        image.set_data(_filter_with_ctf(sample_potential))

    # radially average
    _, I_bins_ptycho = ctf.radially_average_ctf(
        ctf_ptycho,
        (sampling,sampling)
    )
    plot_ctf_ptycho.set_ydata(I_bins_ptycho)

    for artist in (
        im_ctf_ptycho,
        im_white_noise_obj,
        im_sto_obj,
        im_mof_obj,
        im_apo_obj,
        plot_ctf_ptycho,
    ):
        artist.set_alpha(1)

    fig.canvas.draw()
    return None
def compute_ptycho_updates(
    batch_size,
    iterations,
    pbars,
):
    """Continue the reconstruction stored in ``ptycho_recon[0]``.

    The segment amplitudes cached in ``amplitude_values[0]`` are computed
    first if the cache is empty; the reconstruction then runs for
    ``iterations`` passes with the given ``batch_size`` and its result
    replaces ``ptycho_recon[0]``. ``pbars`` is the (outer, inner) pair of
    progress bars forwarded to the reconstruction routine.
    """
    binned = amplitude_values[0]
    if binned is None:
        # Lazily bin the simulated amplitudes with the current detector masks.
        binned = bin_amplitudes_using_virtual_detectors(
            amplitudes_probe[0],
            virtual_masks_annular[0],
        )
        amplitude_values[0] = binned

    ptycho_recon[0] = virtual_detector_ptycho_reconstruction(
        binned_amplitude_values=binned,
        row=row,
        col=col,
        positions=positions,
        center_centered_masks=virtual_masks_annular[0],
        recon=ptycho_recon[0],
        probe_array=amplitudes_probe[1],
        pbars=pbars,
        batch_size=batch_size,
        iterations=iterations,
    )
# Shared widget styling; continuous_update=False means observers fire on release, not while dragging.
style = {'description_width': 'initial'}
layout = ipywidgets.Layout(width="320px",height="30px")
smaller_layout = ipywidgets.Layout(width="160px",height="30px")
kwargs = {'style':style,'layout':layout,'continuous_update':False}

# Inner and outer radius of the detector annulus; initial values match the initial masks.
inner_collection_angle_slider = ipywidgets.FloatSlider(
    value = q_probe/2,
    min = 0,
    max = q_probe, 
    step = q_probe/20,
    description = r"inner collection angle [$q_{\mathrm{probe}}$]",
    **kwargs
)

outer_collection_angle_slider = ipywidgets.FloatSlider(
    value = q_probe*1.05, 
    min = q_probe/20, 
    max = q_max, 
    step = q_probe/20,
    description = r"outer collection angle [$q_{\mathrm{probe}}$]",
    **kwargs
)

# Number of angular segments per ring.
number_of_segments_slider = ipywidgets.IntSlider(
    value = 4, 
    min = 3, 
    max = 16, 
    step = 1,
    description = "number of annular segments",
    **kwargs
)

# Detector rotation in degrees; the maximum is re-derived from the segment count by an observer below.
rotation_offset_slider = ipywidgets.IntSlider(
    value = 0, min = 0, max = 180/4, step = 1,
    description = "rotation offset [°]",
    **kwargs
)

# Number of concentric rings the annulus is split into.
number_of_rings_slider = ipywidgets.IntSlider(
    value = 1, 
    min = 1, 
    max = 8, 
    step = 1,
    description = "number of radial rings",
    **kwargs
)

# When on, update_figure rotates every second ring by half a segment.
rotate_half_the_rings = ipywidgets.ToggleButton(
    value = False,
    description = 'offset radial rings',
    disabled = False,
    layout=ipywidgets.Layout(width="155px",height="30px")
)

# When on, update_figure spaces ring boundaries uniformly in radius squared instead of radius.
area_toggle = ipywidgets.ToggleButton(
    value = False,
    description = 'distribute by area',
    layout=ipywidgets.Layout(width="155px",height="30px")
)

def update_outer_collection_angle(change):
    """Keep the outer slider's minimum 5% above the new inner collection angle."""
    outer_collection_angle_slider.min = change['new']*1.05

def update_inner_collection_angle(change):
    """Cap the inner slider at the new outer collection angle."""
    inner_collection_angle_slider.max = change['new']

# Each radius slider constrains the range of the other one.
inner_collection_angle_slider.observe(update_outer_collection_angle, names='value')
outer_collection_angle_slider.observe(update_inner_collection_angle, names='value')

# rotation offset is modulo 180/n
def update_rotation_offset_range(change):
    """Limit the rotation offset to 180 degrees divided by the new segment count."""
    rotation_offset_slider.max = 180/change['new']

number_of_segments_slider.observe(update_rotation_offset_range, names='value') 

# Batch-size selector. Only divisors of the number of diffraction patterns are
# offered, because the reconstruction reshapes the pattern order into
# (number_of_batches, batch_size).
m = n**2
batch_sizes = np.flatnonzero(m % np.arange(1, m + 1) == 0) + 1
batch_size_slider = ipywidgets.SelectionSlider(
    options=batch_sizes,
    value=batch_sizes[-7],  # seventh-largest divisor (1024 for n = 96)
    description= "batch size",
    **kwargs
)

# Number of outer-loop iterations run per click of the reconstruct button.
iterations_slider = ipywidgets.IntSlider(
    value = 4, min = 1, max = 32, step = 1,
    description = "(outer loop) iterations",
    **kwargs
)

# Runs the reconstruction (handler: click_wrapper).
iterate_button = ipywidgets.Button(
    description="reconstruct (expensive)",
    layout=smaller_layout,
)

# Resets the object estimate (handler: reset_wrapper).
reset_button = ipywidgets.Button(
    description="reset object",
    layout=smaller_layout,
)

# Defocus passed to simulate_intensities; range is -n..n in steps of 1.
defocus_slider = ipywidgets.IntSlider(
    value = 0, min = -n, max = n, step = 1,
    description = "negative defocus, $C_{1,0}$ [Å]",
    **kwargs
)

# Re-runs the 4D-STEM simulation (handler: simulate_wrapper).
simulate_button = ipywidgets.Button(
    description='simulate (expensive)',
    layout=ipywidgets.Layout(width="160px",height="30px")
)

# Progress bar for the simulation, created without displaying it; only the first two
# children of its widget container are embedded (presumably the label and the bar itself).
simulation_pbar = tqdm(total=9,display=False)
simulation_pbar_wrapper = ipywidgets.HBox(
    simulation_pbar.container.children[:2],
    layout=ipywidgets.Layout(width="160px")
)

def defocus_wrapper(*args):
    """Mark all results as stale after the defocus slider changed.

    Flags the simulate button, resets the reconstruction, dims every CTF
    image, CTF curve and sample panel, and resets the simulation progress bar.
    """
    simulate_button.button_style = 'warning'
    reset_wrapper()

    # (the analytical iCOM curve is currently disabled and therefore not dimmed)
    stale_artists = (
        im_ctf_dpc,
        im_ctf_ptycho,
        plot_ctf_dpc,
        plot_ctf_ptycho,
        im_white_noise_obj,
        im_sto_obj,
        im_mof_obj,
        im_apo_obj,
    )
    for artist in stale_artists:
        artist.set_alpha(0.25)

    simulation_pbar.reset()
defocus_slider.observe(defocus_wrapper,names='value')

def simulate_wrapper(*args):
    """Re-run the 4D-STEM simulation for the current defocus and refresh the figure.

    All controls are disabled while the work runs and are re-enabled even if
    the simulation or the figure update raises; in that case the exception
    propagates and the button styles and reconstruction progress bars are
    left untouched, so the simulate button keeps signalling pending work.

    Writes ``amplitudes_probe`` (in place), ``intensities[0]``,
    ``intensities[1]`` and ``amplitude_values[0]``.
    """
    disable_all(True)
    try:
        amplitudes_probe[:] = simulate_intensities(
            defocus=defocus_slider.value,
            batch_size=1024,
            pbar=simulation_pbar
        )

        intensities[0] = amplitudes_probe[0].reshape((sx,sy,n,n))**2 / n**2
        intensities[1] = intensities[0].sum((-1,-2))

        # ctf_analytic = np.real(
        #     np.fft.ifft2(
        #         np.abs(
        #             np.fft.fft2(
        #                 amplitudes_probe[2]
        #             )
        #         )**2
        #     )
        # )

        # # Radially-averaged CTF and SNR
        # q_bins_analytic, I_bins_analytic = ctf.radially_average_ctf(ctf_analytic,(sampling,sampling))
        # plot_ctf_dpc_analytical.set_ydata(I_bins_analytic)
        # plot_ctf_dpc_analytical.set_alpha(1)

        # NOTE(review): update_figure clears amplitude_values[0] again whenever
        # it is set, so this binning mainly triggers the reset performed there.
        amplitude_values[0] = bin_amplitudes_using_virtual_detectors(
            amplitudes_probe[0],
            virtual_masks_annular[0],
        )

        update_figure("dummy")
    finally:
        # Never leave the UI locked, whatever happened above.
        disable_all(False)

    iterate_button.button_style = 'warning'
    outer_reconstruct_pbar.reset()
    inner_reconstruct_pbar.reset()
    simulate_button.button_style = ''
simulate_button.on_click(simulate_wrapper)

def reset_wrapper(*args):
    """Discard the current reconstruction and start again from an all-ones object.

    The ptycho panels are refreshed with the fresh object, the reconstruct
    button is flagged unless a simulation is already pending, both
    reconstruction progress bars are reset, and the reconstruction-dependent
    images are dimmed.
    """
    fresh_object = np.ones((n, n), dtype=np.complex128)
    ptycho_recon[0] = fresh_object
    update_ptycho_panel(fresh_object)

    simulation_pending = simulate_button.button_style == 'warning'
    if not simulation_pending:
        iterate_button.button_style = 'warning'

    for progress_bar in (outer_reconstruct_pbar, inner_reconstruct_pbar):
        progress_bar.reset()

    for image in (
        im_ctf_ptycho,
        im_white_noise_obj,
        im_sto_obj,
        im_mof_obj,
        im_apo_obj,
    ):
        image.set_alpha(0.25)

reset_button.on_click(reset_wrapper)

def disable_all(boolean):
    """Set the ``disabled`` attribute of every control to ``boolean``.

    Used to lock the interface while a simulation or reconstruction runs.
    """
    controls = (
        inner_collection_angle_slider,
        outer_collection_angle_slider,
        number_of_segments_slider,
        rotation_offset_slider,
        number_of_rings_slider,
        rotate_half_the_rings,
        area_toggle,
        batch_size_slider,
        iterations_slider,
        reset_button,
        iterate_button,
        defocus_slider,
        simulate_button,
        simulation_pbar_wrapper,
    )
    for control in controls:
        control.disabled = boolean

def click_wrapper(*args):
    """Run the reconstruction with the batch size and iteration count from the sliders.

    All controls are disabled while the reconstruction runs and are re-enabled
    even if it raises; in that case the exception propagates and the
    reconstruct button keeps its warning style.
    """
    disable_all(True)
    try:
        compute_ptycho_updates(
            batch_size=batch_size_slider.value,
            iterations=iterations_slider.value,
            pbars=(outer_reconstruct_pbar,inner_reconstruct_pbar),
        )
    finally:
        # Never leave the UI locked, whatever happened above.
        disable_all(False)
    iterate_button.button_style = ''

iterate_button.on_click(click_wrapper)
# Progress bars for the reconstruction: outer = iterations, inner = batches within one iteration.
# Both are created without displaying them; the totals given here are placeholders that the
# reconstruction overwrites via reset(). Only the first two children of each tqdm container
# are embedded in the layout (presumably the label and the bar itself).
outer_reconstruct_pbar = tqdm(total=4,display=False)
outer_reconstruct_pbar_wrapper = ipywidgets.HBox(
    outer_reconstruct_pbar.container.children[:2],
    layout=ipywidgets.Layout(width="160px")
)

inner_reconstruct_pbar = tqdm(total=9,display=False)
inner_reconstruct_pbar_wrapper = ipywidgets.HBox(
    inner_reconstruct_pbar.container.children[:2],
    layout=ipywidgets.Layout(width="160px")
)
def update_figure(
    *args,
):
    """Rebuild the detector from the widget values and refresh the iCOM panels.

    Steps:
    1. Split the annulus between the inner and outer collection angle into
       the requested number of rings (boundaries uniform in radius, or in
       radius squared when "distribute by area" is on) and build the
       segment masks of every ring; every second ring is rotated by half a
       segment when "offset radial rings" is on.
    2. Store the stacked masks in ``virtual_masks_annular[0]``.
    3. Recompute the segment-based CoM, the integrated CoM and its CTF, and
       update the detector image, the iCOM CTF image, the radially averaged
       curve and the dashed collection-angle lines.
    4. Flag the reconstruct button, and invalidate cached segment amplitudes
       (which also resets the reconstruction) if any were cached.

    Accepts and ignores any positional arguments so it can serve as a widget
    observer.
    """
    n_rings = number_of_rings_slider.value
    n_segments = number_of_segments_slider.value
    inner_angle = inner_collection_angle_slider.value
    outer_angle = outer_collection_angle_slider.value

    # ring boundaries, n_rings + 1 values from inner to outer
    if area_toggle.value:
        ring_collection_angles = np.linspace(
            inner_angle**2,
            outer_angle**2,
            num=n_rings + 1
        )**(1/2)
    else:
        ring_collection_angles = np.linspace(
            inner_angle,
            outer_angle,
            num=n_rings + 1
        )

    # half-segment rotation applied to every second ring
    if rotate_half_the_rings.value:
        ring_rotation = np.deg2rad((180/n_segments))
    else:
        ring_rotation = 0
    base_rotation = np.deg2rad(rotation_offset_slider.value)

    _virtual_masks_annular = [
        annular_segmented_detectors(
            gpts=(n,n),
            sampling=(sampling,sampling),
            n_angular_bins=n_segments,
            inner_radius=ring_collection_angles[j],
            outer_radius=ring_collection_angles[j+1],
            rotation_offset=base_rotation + ring_rotation*(j%2),
        )
        for j in range(n_rings)
    ]

    # one flat stack of masks, ring after ring
    virtual_masks_annular[0] = np.vstack(_virtual_masks_annular)

    com_x, com_y = compute_com_using_virtual_detectors(
        intensities[0],
        virtual_masks_annular[0],
        intensities[1],
        sx,sy,
        kxa,kya,
    )

    icom_annular = integrate_com(com_x,com_y,kx_op,ky_op)
    ctf_annular = ctf.compute_ctf(icom_annular)

    _, I_bins_annular = ctf.radially_average_ctf(
        ctf_annular,
        (sampling,sampling)
    )

    # update data

    # 2D arrays
    im_detector.set_data(ctf.combined_images_rgb(virtual_masks_annular[0]))
    im_ctf_dpc.set_data(ctf.histogram_scaling(np.fft.fftshift(ctf_annular),normalize=True))

    # 1D lines
    plot_ctf_dpc.set_ydata(I_bins_annular)

    # vlines
    # The dashed lines are expected to be the first collection on this axes;
    # drop them and draw new ones at the current collection angles.
    ax_ctf_rad.collections[0].remove()
    ax_ctf_rad.vlines(
        [
            inner_angle,
            outer_angle
        ],0,2,
        colors='k',linestyles='--',linewidth=1,
    )

    im_ctf_dpc.set_alpha(1)
    plot_ctf_dpc.set_alpha(1)

    iterate_button.button_style = 'warning'
    if amplitude_values[0] is not None:
        # cached amplitudes belong to the previous detector geometry
        amplitude_values[0] = None
        reset_wrapper()
    for image in (
        im_ctf_ptycho,
        im_white_noise_obj,
        im_sto_obj,
        im_mof_obj,
        im_apo_obj,
    ):
        image.set_alpha(0.25)

    # re-draw figure
    fig.canvas.draw_idle()
    return None

# Any change to the detector geometry controls rebuilds the masks and the iCOM panels.
inner_collection_angle_slider.observe(update_figure,names='value')
outer_collection_angle_slider.observe(update_figure,names='value')
number_of_segments_slider.observe(update_figure,names='value')
rotation_offset_slider.observe(update_figure,names='value')
number_of_rings_slider.observe(update_figure,names='value')
rotate_half_the_rings.observe(update_figure,names='value')
area_toggle.observe(update_figure,names='value')
# No reconstruction has been run yet, so the reconstruct button starts flagged.
iterate_button.button_style = 'warning'
def simulate(
    defocus,
):
    """Re-simulate the dataset for ``defocus`` and refresh the figure.

    Non-interactive counterpart of ``simulate_wrapper``: it overwrites
    ``amplitudes_probe`` in place with the ``[amplitudes, probe_array,
    probe_array_fourier]`` list returned by ``simulate_intensities`` (called
    with its default batch size), rebuilds ``intensities[0]`` and
    ``intensities[1]`` from the amplitudes, and calls ``update_figure``.
    """
    amplitudes_probe[:] = simulate_intensities(
        defocus=defocus,
    )
    # squared amplitudes per scan position, scaled as in the initial simulation
    intensities[0] = amplitudes_probe[0].reshape((sx,sy,n,n))**2 / n**2
    intensities[1] = intensities[0].sum((-1,-2))

    update_figure("dummy")

    return None
# Assemble and show the interface: three groups of controls separated by horizontal rules
# (simulation, detector geometry, reconstruction), with the figure canvas underneath.
display(
    ipywidgets.VBox(
        [
            ipywidgets.VBox(
                [
                    ipywidgets.HBox([defocus_slider,simulate_button,simulation_pbar_wrapper]),
                    ipywidgets.HTML("<hr>",layout=ipywidgets.Layout(width="640px")),
                    ipywidgets.HBox([inner_collection_angle_slider,outer_collection_angle_slider]),
                    ipywidgets.HBox([number_of_segments_slider,rotation_offset_slider]),
                    ipywidgets.HBox([number_of_rings_slider,rotate_half_the_rings,area_toggle]),
                    ipywidgets.HTML("<hr>",layout=ipywidgets.Layout(width="640px")),
                    ipywidgets.HBox([batch_size_slider,iterations_slider]),
                    ipywidgets.HBox([reset_button,iterate_button,outer_reconstruct_pbar_wrapper,inner_reconstruct_pbar_wrapper]),
                ]
            ),
            fig.canvas
        ]
    )
)