Evaluating the Transfer of Information in Phase Retrieval STEM Techniques
Contents
Pixelated Iterative Ptychography
# enable interactive matplotlib
%matplotlib widget
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
import ctf # import custom plotting / utils
import cmasher as cmr
from tqdm.notebook import tqdm
import ipywidgets
4D-STEM Simulation
# ---------------------------------------------------------------------------
# Simulation parameters
# ---------------------------------------------------------------------------
n = 96  # pixels per side of the real- and reciprocal-space grids

# reciprocal-space extents [inverse Angstroms]
q_max = 2  # largest spatial frequency represented on the grid
q_probe = 1  # probe-forming aperture cutoff

wavelength = 0.019687  # electron wavelength [Angstroms] at 300 kV

# pixel sizes: real space [Angstroms] and reciprocal space [inverse Angstroms]
sampling = 1 / q_max / 2
reciprocal_sampling = 2 * q_max / n

# scan raster: step in pixels and the resulting number of positions per axis
scan_step_size = 1
sx = sy = n // scan_step_size

phi0 = 1.0  # Fourier amplitude of the white-noise object
C10 = 0  # initial defocus [Angstroms]
C30 = 0  # initial spherical aberration
# display settings: colormaps for the CTF and object panels, and the line
# colour for the radially averaged CTF
sample_cmap = 'gray'
cmap = cmr.viola
iter_ptycho_line_color = 'darkmagenta'
# Example specimens: projected potentials on 192x192 grids, each shifted to
# zero mean in place (removing the DC component).
sto_potential = np.load("data/STO_projected-potential_192x192_4qprobe.npy")
mof_potential = np.load("data/MOF_projected-potential_192x192_4qprobe.npy")
apo_potential = np.load("data/apoF_projected-potential_192x192_4qprobe.npy")
potentials = [sto_potential, mof_potential, apo_potential]
for _potential in potentials:
    _potential -= _potential.mean()
del _potential

# pixel size of each potential = field of view / pixels per side
sto_sampling = 23.67 / sto_potential.shape[0]  # Å
mof_sampling = 4.48 / mof_potential.shape[0]  # nm
apo_sampling = 19.2 / apo_potential.shape[0]  # nm
White-Noise Potential
def white_noise_object_2D(n, phi0, rng=None):
    """Create a real 2D array whose FFT has random phase and constant amplitude.

    Every Fourier coefficient is ``phi0 * exp(2j*pi*theta)``. The random
    phases ``theta`` are made antisymmetric under ``k -> -k``, so the Fourier
    array is Hermitian and its inverse FFT is real up to round-off.

    Parameters
    ----------
    n : int
        Number of pixels along each side of the square output.
    phi0 : float
        Amplitude of every Fourier coefficient (including DC).
    rng : numpy.random.Generator or numpy.random.RandomState, optional
        Source of randomness. When omitted, the global ``np.random`` state is
        used (the original behaviour). Pass a seeded generator for
        reproducible objects.

    Returns
    -------
    numpy.ndarray
        Real-valued array of shape ``(n, n)``.
    """
    even = n % 2 == 0
    # positive frequency indices k = 1 .. ceil(n/2)-1 and their partners n-k (= -k)
    pos_ind = np.arange(1, (n if even else n + 1) // 2)
    neg_ind = np.flip(np.arange(n // 2 + 1, n))

    # random phases, in units of full turns
    if rng is None:
        arr = np.random.randn(n, n)
    else:
        arr = rng.standard_normal((n, n))

    # enforce theta(-k) = -theta(k) so that the Fourier array is Hermitian
    # top-left // bottom-right quadrants
    arr[pos_ind[:, None], pos_ind[None, :]] = -arr[neg_ind[:, None], neg_ind[None, :]]
    # bottom-left // top-right quadrants
    arr[pos_ind[:, None], neg_ind[None, :]] = -arr[neg_ind[:, None], pos_ind[None, :]]
    # kx = 0 row
    arr[0, pos_ind] = -arr[0, neg_ind]
    # ky = 0 column
    arr[pos_ind, 0] = -arr[neg_ind, 0]

    # The DC term and, for even n, the Nyquist row/column are not covered by
    # the antisymmetrisation above. Give them zero *phase*, which makes their
    # coefficients the real value phi0 (their amplitude is NOT zeroed), so
    # Hermitian symmetry still holds.
    if even:
        arr[n // 2, :] = 0
        arr[:, n // 2] = 0
    arr[0, 0] = 0

    # constant-amplitude Fourier array
    fourier = np.exp(2j * np.pi * arr) * phi0
    # the imaginary part of the inverse FFT is only floating-point round-off
    return np.fft.ifft2(fourier).real
# White-noise weak phase object: real potential V and its transmission
# function exp(iV).
potential = white_noise_object_2D(n, phi0)
complex_obj = np.exp(1j * potential)
Probe
# The probe is built in Fourier space, on the FFT-ordered reciprocal grid.
qx = qy = np.fft.fftfreq(n, sampling)  # inverse Angstroms
q2 = qx[:, np.newaxis]**2 + qy[np.newaxis, :]**2
q = np.sqrt(q2)

# Soft-edged aperture: the clipped ramp falls from 1 to 0 over one reciprocal
# pixel centred on q_probe; the square root makes the ramp linear in
# intensity (|aperture|**2) rather than in amplitude.
aperture_fourier = np.sqrt(
    np.clip((q_probe - q) / reciprocal_sampling + 0.5, 0, 1)
)
def return_chi(
    q,
    wavelength,
    C10,
    C30,
):
    """Return the aberration phase chi for defocus and spherical aberration.

    chi = (2*pi / wavelength) * (alpha**2 / 2 * C10 + alpha**4 / 4 * C30),
    where alpha = q * wavelength is the scattering angle.

    Parameters
    ----------
    q : float or ndarray
        Spatial-frequency magnitude(s), in inverse length units.
    wavelength : float
        Electron wavelength, in the same length unit.
    C10, C30 : float
        Defocus and spherical-aberration coefficients, in the same length unit.

    Returns
    -------
    float or ndarray
        Phase in radians, with the same shape as ``q``.
    """
    alpha = q * wavelength
    defocus_term = alpha**2 / 2 * C10
    spherical_term = alpha**4 / 4 * C30
    return (defocus_term + spherical_term) * (2 * np.pi / wavelength)
def simulate_intensities(q, wavelength, C10, C30, batch_size=n**2, pbar=None):
    """Simulate the far-field amplitudes for every scan position.

    Builds an aberrated probe from the global soft aperture
    ``aperture_fourier`` and the aberration phase from ``return_chi``. It then
    calls ``ctf.simulate_data`` on the global ``complex_obj`` at the global
    patch indices ``row``/``col``, ``batch_size`` positions at a time.

    Parameters
    ----------
    q : ndarray
        Spatial-frequency magnitude grid passed to ``return_chi``.
    wavelength : float
        Electron wavelength.
    C10, C30 : float
        Defocus and spherical aberration, in the same length unit as
        ``wavelength``.
    batch_size : int
        Number of positions per ``ctf.simulate_data`` call. It does not need
        to divide ``n**2``; the last batch is simply shorter.
    pbar : tqdm progress bar, optional
        Reset to the number of batches and advanced once per batch.

    Returns
    -------
    list
        ``[amplitudes, probe_array]``. ``amplitudes`` has shape
        ``(n**2, n, n)`` and holds whatever ``ctf.simulate_data`` returns per
        position (presumably diffraction amplitudes). ``probe_array`` is the
        real-space probe.

    Raises
    ------
    ValueError
        If ``batch_size`` is smaller than 1.
    """
    m = n**2
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")
    # Consecutive chunks of scan indices. The last chunk may be shorter when
    # batch_size does not divide m; a plain reshape would raise there.
    batch_orders = [
        np.arange(start, min(start + batch_size, m))
        for start in range(0, m, batch_size)
    ]
    amplitudes = np.zeros((m, n, n))

    if pbar is not None:
        pbar.reset(len(batch_orders))
        pbar.colour = None
        pbar.refresh()

    # Aberrated probe: soft aperture times exp(-i*chi). Normalised so that the
    # Fourier intensities sum to 1; with numpy's ifft2 normalisation, the
    # factor n then gives the real-space probe unit total intensity.
    exp_chi = np.exp(-1j * return_chi(q, wavelength, C10, C30))
    probe_array_fourier = aperture_fourier * exp_chi
    probe_array_fourier /= np.sqrt(np.sum(np.abs(probe_array_fourier)**2))
    probe_array = np.fft.ifft2(probe_array_fourier) * n

    for batch_order in batch_orders:
        amplitudes[batch_order] = ctf.simulate_data(
            complex_obj,
            probe_array,
            row[batch_order],
            col[batch_order],
        )
        if pbar is not None:
            pbar.update(1)

    if pbar is not None:
        pbar.colour = 'green'
    return [amplitudes, probe_array]
# Raster scan over the full n x n field of view (step given in pixels).
x = y = np.arange(0., n, scan_step_size)
xx, yy = np.meshgrid(x, y, indexing='ij')
positions = np.stack([xx.ravel(), yy.ravel()], axis=1)
row, col = ctf.return_patch_indices(positions, (n, n), (n, n))

# Default (aberration-free) dataset, stored as [amplitudes, probe].
amplitudes_probe = simulate_intensities(
    q, wavelength, C10, C30,
    batch_size=1024,
    pbar=None,
)

# intensities[0]: 4D array (sx, sy, n, n) of squared amplitudes / n**2.
# intensities[1]: its sum over the two detector axes, one value per position.
intensities = [amplitudes_probe[0].reshape((sx, sy, n, n))**2 / n**2]
intensities.append(intensities[0].sum((-1, -2)))
Ptychographic Reconstruction
def ptycho_reconstruction(
    amplitudes,
    row,
    col,
    positions,
    recon_array,
    probe_array,
    pbars,
    batch_size=n**2,
    iterations=64,
    step_size=1.0,
):
    """Iteratively refine a complex object from measured amplitudes (fixed probe).

    Each mini-batch of probe positions goes through these steps:

    1. The current object patches are multiplied by the scaled probe.
    2. The exit waves are propagated to the detector with an FFT.
    3. The Fourier magnitudes are replaced by the measured ``amplitudes``,
       keeping the phases.
    4. The real-space difference is back-projected with ``conj(probe)``
       through ``ctf.sum_patches`` and added to the object.

    After every update the object amplitude is clipped to [0, 1]. The
    position order is reshuffled after each pass, and ``update_ptycho_panel``
    is called to redraw the figure.

    Parameters
    ----------
    amplitudes : ndarray, shape (m, nx, ny)
        Measured amplitudes, one frame per probe position.
    row, col : ndarray
        Per-position patch indices into ``recon_array``.
    positions : ndarray, shape (m, 2)
        Probe positions, forwarded to ``ctf.sum_patches``.
    recon_array : ndarray, complex
        Starting object estimate. It is copied, so the caller's array is not
        modified.
    probe_array : ndarray, shape (nx, ny)
        Real-space probe.
    pbars : tuple
        ``(outer, inner)`` tqdm progress bars for passes and batches.
    batch_size : int
        Positions per update; must be a positive divisor of ``m``.
    iterations : int
        Number of passes over all positions.
    step_size : float
        Scale factor applied to each object update.

    Returns
    -------
    ndarray
        The refined complex object.

    Raises
    ------
    ValueError
        If ``batch_size`` is not a positive divisor of ``m``.
    """
    m = amplitudes.shape[0]
    nx, ny = probe_array.shape
    if batch_size < 1 or m % batch_size != 0:
        raise ValueError(
            f"batch_size ({batch_size}) must be a positive divisor of the "
            f"number of probe positions ({m})"
        )
    n_batches = int(m // batch_size)

    # Work on a copy: the first in-place update would otherwise partially
    # modify the caller's array.
    recon_array = recon_array.copy()

    order = np.arange(m)
    np.random.shuffle(order)

    # Scale the probe by the measured intensity: the position-summed
    # intensity, averaged over detector pixels, divided by nx*ny.
    probe_normalization = np.mean(np.sum(amplitudes**2, 0)) / nx / ny
    scaled_probe = probe_array * np.sqrt(probe_normalization)

    outer_pbar, inner_pbar = pbars
    outer_pbar.reset(iterations)
    outer_pbar.colour = None
    outer_pbar.refresh()

    for _ in range(iterations):
        inner_pbar.reset(n_batches)
        inner_pbar.colour = None
        inner_pbar.refresh()

        # one row of scan indices per mini-batch, in this pass's shuffled order
        batches = order.reshape((n_batches, batch_size))
        for batch_order in batches:
            batch_amplitudes = amplitudes[batch_order]
            batch_pos = positions[batch_order]
            batch_row = row[batch_order]
            batch_col = col[batch_order]

            # forward model: exit waves and their far-field propagation
            obj_patches = recon_array[batch_row, batch_col]
            overlap = scaled_probe * obj_patches
            fourier_overlap = np.fft.fft2(overlap)

            # Fourier-magnitude constraint: measured amplitudes, current phases
            modified_fourier_overlap = batch_amplitudes * np.exp(1j * np.angle(fourier_overlap))
            grad = np.fft.ifft2(modified_fourier_overlap - fourier_overlap)

            # back-project the exit-wave correction onto the object
            update = ctf.sum_patches(
                grad * np.conj(scaled_probe),
                batch_pos,
                (nx, ny),
                (nx, ny),
            ) / probe_normalization
            recon_array += step_size * update

            # keep the phase, clip the amplitude to [0, 1]
            amp = np.abs(recon_array).clip(0.0, 1.0)
            recon_array = amp * np.exp(1j * np.angle(recon_array))
            inner_pbar.update(1)

        np.random.shuffle(order)
        update_ptycho_panel(recon_array)
        outer_pbar.update(1)

    inner_pbar.colour = 'green'
    outer_pbar.colour = 'green'
    return recon_array
# Current reconstruction, kept in a one-element list so the widget callbacks
# can replace it. It starts as a flat, unit-amplitude object.
ptycho_recon = [np.ones((n, n), dtype=np.complex128)]

# Spatial-frequency axis of the radially averaged CTF plot: from 0 up to
# q_max in steps of one reciprocal pixel (inverse Angstroms).
q_bins = np.arange(0, q_max + reciprocal_sampling, reciprocal_sampling)
# Three-panel figure: 2D CTF, radially averaged CTF, and the CTF-filtered
# object. It is built with interactive mode off, so nothing is displayed here;
# its canvas is placed in the widget layout at the end of the notebook.
with plt.ioff():
    dpi=72
    fig, axs = plt.subplots(1,3,figsize=(640/dpi,270/dpi),dpi=dpi)
    # ptycho CTF
    # Left panel: zero placeholder, replaced by the fftshift-ed CTF in
    # update_ptycho_panel.
    ax_ctf_pixelated_ptycho = axs[0]
    im_ctf_ptycho = ax_ctf_pixelated_ptycho.imshow(
        np.zeros((n,n)),
        cmap=cmap,
        vmin=-1,
        vmax=1,
    )
    # n//4 reciprocal pixels * reciprocal_sampling = q_max/2, which equals
    # q_probe with the parameters used in this notebook
    ctf.add_scalebar(
        ax_ctf_pixelated_ptycho,
        length=n//4,
        sampling=reciprocal_sampling,
        units=r'$q_{\mathrm{probe}}$',
        color="black"
    )
    ax_ctf_pixelated_ptycho.set(xticks=[],yticks=[],title="contrast transfer function (CTF)")
    # CTF radially-averaged
    # Middle panel: placeholder line whose y-data update_ptycho_panel replaces.
    ax_ctf_rad = axs[1]
    plot_ctf_ptycho = ax_ctf_rad.plot(q_bins, np.zeros_like(q_bins), color=iter_ptycho_line_color)[0]
    # The x range 0..2 matches q_max. q_bins is in inverse Angstroms, which
    # coincides with units of q_probe only because q_probe = 1 here.
    ax_ctf_rad.set(
        xlim=[0,2],
        ylim=[0,1.025],
        aspect= 2 / 1.025,
        xticks=[0,1,2],
        yticks=[],
        xlabel=r"spatial frequency, $q/q_{\mathrm{probe}}$",
        title="radially averaged CTF"
    )
    # Right panel: the CTF-filtered object (or the reconstructed phase for the
    # white-noise object), drawn by update_ptycho_panel.
    # NOTE(review): imshow fixes the extent from this (n, n) placeholder, and
    # set_data does not change it, so the 192x192 example potentials are drawn
    # into the same extent -- confirm the 40-pixel scale bars below are
    # labelled correctly for them.
    ax_obj = axs[2]
    im_obj = ax_obj.imshow(
        np.zeros((n,n)),
        cmap=sample_cmap,
        vmin=0,
        vmax=1
    )
    # One 40-pixel scale bar per selectable object, added in dropdown order:
    # white noise (Å), STO (Å), MOF (nm), apoferritin (nm).
    ctf.add_scalebar(
        ax_obj,
        length=40,
        sampling=sampling,
        units=r'Å',
        size_vertical=2,
    )
    ctf.add_scalebar(
        ax_obj,
        length=40,
        sampling=sto_sampling,
        units=r'Å',
        size_vertical=2,
    )
    ctf.add_scalebar(
        ax_obj,
        length=40,
        sampling=mof_sampling,
        units=r'nm',
        size_vertical=2
    )
    ctf.add_scalebar(
        ax_obj,
        length=40,
        sampling=apo_sampling,
        units=r'nm',
        size_vertical=2
    )
    # This unpacking relies on the four scale bars above being ax_obj's only
    # artists, in the order they were added.
    # NOTE(review): assumes ctf.add_scalebar adds exactly one artist per
    # call -- confirm.
    noise_scalebar, sto_scalebar, mof_scalebar, apo_scalebar = ax_obj.artists
    # Initially only the white-noise scale bar is shown (the dropdown's first
    # option); update_ptycho_panel toggles visibility afterwards.
    sto_scalebar.set_visible(False)
    mof_scalebar.set_visible(False)
    apo_scalebar.set_visible(False)
    ax_obj.set(xticks=[],yticks=[],title="CTF-convolved weak phase object")
    # dimmed until update_ptycho_panel sets the alpha back to 1
    im_obj.set_alpha(0.25)
    fig.tight_layout()
    # Fixed 640x280 px canvas without header, footer or toolbar (ipympl canvas
    # options; the notebook uses %matplotlib widget).
    fig.canvas.resizable = False
    fig.canvas.header_visible = False
    fig.canvas.footer_visible = False
    fig.canvas.toolbar_visible = False
    fig.canvas.layout.height = "280px"
    fig.canvas.layout.width = '640px'
# bare None as the cell's last statement -- presumably so the cell echoes nothing
None
def update_ptycho_panel(ptycho_recon):
    """Redraw every panel of the figure from a reconstruction.

    Steps:

    - Compute the CTF of the reconstruction's phase with ``ctf.compute_ctf``
      and show it, centred, in the left panel.
    - Plot its radial average in the middle panel.
    - In the right panel, show the reconstructed phase itself (white-noise
      object), or the selected example potential filtered by the CTF in
      Fourier space.
    - Show only the matching scale bar, restore full opacity, and redraw the
      canvas.

    Parameters
    ----------
    ptycho_recon : ndarray, complex
        Current object estimate.
    """
    ctf_ptycho = ctf.compute_ctf(np.angle(ptycho_recon))
    ctf_centered = np.fft.fftshift(ctf_ptycho)
    im_ctf_ptycho.set_data(ctf_centered)

    _, I_bins_ptycho = ctf.radially_average_ctf(
        ctf_ptycho,
        (sampling, sampling),
    )
    plot_ctf_ptycho.set_ydata(I_bins_ptycho)

    object_index = object_dropdown.value
    if object_index == 0:
        convolved_object = np.angle(ptycho_recon)
    else:
        chosen_potential = potentials[object_index - 1]
        # The example potentials have more pixels (192x192) than the CTF grid
        # (n x n). Zero-pad the centred CTF symmetrically up to the
        # potential's shape. The padding is derived from the shapes rather
        # than a hard-coded 48-pixel border, so changing n or the potential
        # size does not break the product below. ifftshift then restores FFT
        # ordering.
        pad_width = []
        for pot_len, ctf_len in zip(chosen_potential.shape, ctf_ptycho.shape):
            extra = pot_len - ctf_len
            pad_width.append((extra // 2, extra - extra // 2))
        zero_pad_ctf = np.fft.ifftshift(np.pad(ctf_centered, pad_width))
        convolved_object = np.fft.ifft2(
            np.fft.fft2(chosen_potential) * zero_pad_ctf
        ).real

    # show only the scale bar that matches the displayed object
    noise_scalebar.set_visible(object_index == 0)
    sto_scalebar.set_visible(object_index == 1)
    mof_scalebar.set_visible(object_index == 2)
    apo_scalebar.set_visible(object_index == 3)

    convolved_object = ctf.histogram_scaling(convolved_object, normalize=True)
    im_obj.set_data(convolved_object)

    # panels may have been dimmed by the reset/defocus callbacks
    im_ctf_ptycho.set_alpha(1)
    plot_ctf_ptycho.set_alpha(1)
    im_obj.set_alpha(1)
    fig.canvas.draw()
    return None
def compute_ptycho_updates(
    batch_size,
    iterations,
    pbars,
):
    """Continue the reconstruction and store the result in ``ptycho_recon[0]``.

    Starts from the current ``ptycho_recon[0]`` and uses the simulated
    amplitudes (``amplitudes_probe[0]``) and probe (``amplitudes_probe[1]``).
    """
    measured_amplitudes = amplitudes_probe[0]
    probe = amplitudes_probe[1]
    ptycho_recon[0] = ptycho_reconstruction(
        measured_amplitudes,
        row,
        col,
        positions,
        ptycho_recon[0],
        probe,
        pbars,
        batch_size=batch_size,
        iterations=iterations,
    )
# Shared widget styling. Note that every widget built with **kwargs shares
# the same `layout` instance.
style = {'description_width': 'initial'}
layout = ipywidgets.Layout(width="320px", height="30px")
smaller_layout = ipywidgets.Layout(width="160px", height="30px")
kwargs = dict(style=style, layout=layout, continuous_update=False)
# Number of probe positions, and every batch size that splits them into equal
# batches (the divisors of m), in ascending order.
m = n**2
batch_sizes = np.array([d for d in range(1, m + 1) if m % d == 0], dtype=np.intp)
# --- reconstruction controls ---
batch_size_slider = ipywidgets.SelectionSlider(
    options=batch_sizes,
    value=batch_sizes[-7],
    description="batch size",
    **kwargs,
)
iterations_slider = ipywidgets.IntSlider(
    value=4,
    min=1,
    max=32,
    step=1,
    description="(outer loop) iterations",
    **kwargs,
)
iterate_button = ipywidgets.Button(description="reconstruct (expensive)", layout=smaller_layout)
reset_button = ipywidgets.Button(description="reset object", layout=smaller_layout)

# --- aberration / simulation controls ---
C10_slider = ipywidgets.FloatSlider(
    value=0,
    min=-500,
    max=500,
    step=1,
    description=r"negative defocus, $C_{1,0}$ [Å]",
    **kwargs,
)
C30_slider = ipywidgets.FloatSlider(
    value=0,
    min=-100,
    max=100,
    step=0.1,
    description=r"spherical aberration, $C_{3,0}$ [µm]",
    **kwargs,
)
simulate_button = ipywidgets.Button(
    description='simulate (expensive)',
    layout=ipywidgets.Layout(width="160px", height="30px"),
)

# --- object selection: index 0 is the white-noise object, 1-3 index into `potentials` ---
object_dropdown = ipywidgets.Dropdown(
    options=[
        ("white noise object", 0),
        ("strontium titanate", 1),
        ("metal-organic framework", 2),
        ("apoferritin protein", 3),
    ],
    **kwargs,
)
# Simulation progress bar: total=9 matches n**2 / 1024 batches in
# simulate_wrapper. Only the first two children of the tqdm container are
# shown, in a fixed-width box.
simulation_pbar = tqdm(total=9, display=False)
simulation_pbar_wrapper = ipywidgets.HBox(
    simulation_pbar.container.children[:2],
    layout=ipywidgets.Layout(width="160px"),
)

# no reconstruction has been run yet
iterate_button.button_style = 'warning'
def defocus_wrapper(*args):
    """Slider callback: invalidate the simulated data after an aberration change.

    Highlights the simulate button, resets the reconstruction, dims the
    figure panels and clears the simulation progress bar.
    """
    simulate_button.button_style = 'warning'
    reset_wrapper()
    for artist in (im_ctf_ptycho, plot_ctf_ptycho, im_obj):
        artist.set_alpha(0.25)
    simulation_pbar.reset()


C10_slider.observe(defocus_wrapper, names='value')
C30_slider.observe(defocus_wrapper, names='value')
def simulate_wrapper(*args):
    """Button callback: re-simulate the dataset with the slider aberrations.

    Disables the controls while running. Stores the new
    ``[amplitudes, probe]`` in ``amplitudes_probe`` (in place) and the
    derived intensities in ``intensities``, then flags the reconstruction as
    needing to be run. The controls are re-enabled even if the simulation
    raises.
    """
    # return_chi expects C30 in the same length unit as wavelength and C10
    # (Angstroms, matching the Å-labelled C10 slider). The C30 slider is
    # labelled in micrometres, so convert: 1 µm = 1e4 Å.
    angstrom_per_micrometre = 1e4
    disable_all(True)
    try:
        amplitudes_probe[:] = simulate_intensities(
            q,
            wavelength,
            C10=C10_slider.value,
            C30=C30_slider.value * angstrom_per_micrometre,
            batch_size=1024,
            pbar=simulation_pbar,
        )
        intensities[0] = amplitudes_probe[0].reshape((sx, sy, n, n))**2 / n**2
        intensities[1] = intensities[0].sum((-1, -2))
    finally:
        # never leave the UI locked if the simulation fails
        disable_all(False)
    iterate_button.button_style = 'warning'
    outer_reconstruct_pbar.reset()
    inner_reconstruct_pbar.reset()
    simulate_button.button_style = ''


simulate_button.on_click(simulate_wrapper)
def reset_wrapper(*args):
    """Replace the reconstruction with a flat unit object and refresh the UI."""
    fresh_object = np.ones((n, n), dtype=np.complex128)
    ptycho_recon[0] = fresh_object
    update_ptycho_panel(fresh_object)
    # prompt for a reconstruction only when the simulated data are current
    if not simulate_button.button_style == 'warning':
        iterate_button.button_style = 'warning'
    im_obj.set_alpha(0.25)
    for bar in (outer_reconstruct_pbar, inner_reconstruct_pbar):
        bar.reset()


reset_button.on_click(reset_wrapper)
def disable_all(boolean):
    """Disable (``boolean=True``) or enable (``boolean=False``) every control."""
    controls = (
        batch_size_slider,
        iterations_slider,
        reset_button,
        iterate_button,
        C10_slider,
        C30_slider,
        object_dropdown,
        simulate_button,
        # NOTE(review): an HBox has no `disabled` trait, so this assignment
        # presumably has no visible effect
        simulation_pbar_wrapper,
    )
    for widget in controls:
        widget.disabled = boolean
def click_wrapper(*args):
    """Button callback: run the requested number of reconstruction iterations.

    The controls are disabled while the reconstruction runs and are always
    re-enabled afterwards, even if it raises. The button highlight is cleared
    only on success.
    """
    disable_all(True)
    try:
        compute_ptycho_updates(
            batch_size=batch_size_slider.value,
            iterations=iterations_slider.value,
            pbars=(outer_reconstruct_pbar, inner_reconstruct_pbar),
        )
    finally:
        # never leave the UI locked if the reconstruction fails
        disable_all(False)
    iterate_button.button_style = ''


iterate_button.on_click(click_wrapper)
# Reconstruction progress bars: one tick per outer iteration and one per
# batch. The totals match the defaults: 4 iterations, and 9216 / 1024 = 9
# batches.
outer_reconstruct_pbar = tqdm(total=4, display=False)
outer_reconstruct_pbar_wrapper = ipywidgets.HBox(
    outer_reconstruct_pbar.container.children[:2],
    layout=ipywidgets.Layout(width="160px"),
)

inner_reconstruct_pbar = tqdm(total=9, display=False)
inner_reconstruct_pbar_wrapper = ipywidgets.HBox(
    inner_reconstruct_pbar.container.children[:2],
    layout=ipywidgets.Layout(width="160px"),
)
def update_current_ptycho_wrapper(*args):
    """Dropdown callback: redraw the panels for the current reconstruction."""
    update_ptycho_panel(ptycho_recon[0])


object_dropdown.observe(update_current_ptycho_wrapper, "value")
# Assemble the app:
#   - aberration sliders
#   - simulation controls
#   - divider
#   - reconstruction controls
#   - the interactive figure
# This is left as the cell's final expression so that the notebook displays it.
ipywidgets.VBox(
    [
        ipywidgets.VBox(
            [
                ipywidgets.HBox([C10_slider, C30_slider]),
                ipywidgets.HBox(
                    [simulate_button, simulation_pbar_wrapper, object_dropdown]
                ),
                ipywidgets.HTML("<hr>", layout=ipywidgets.Layout(width="640px")),
                ipywidgets.HBox([batch_size_slider, iterations_slider]),
                ipywidgets.HBox(
                    [
                        reset_button,
                        iterate_button,
                        outer_reconstruct_pbar_wrapper,
                        inner_reconstruct_pbar_wrapper,
                    ]
                ),
            ]
        ),
        fig.canvas,
    ]
)
VBox(children=(VBox(children=(HBox(children=(FloatSlider(value=0.0, continuous_update=False, description='nega…