From 62fdd15d682ea4316a3332235fb6e9118b0aa918 Mon Sep 17 00:00:00 2001
From: Daniel Ching
Date: Tue, 19 Oct 2021 13:20:11 -0500
Subject: [PATCH] NEW: Wrap lamino operator in PyTorch module

---
 src/tike/operators/torch/__init__.py |   8 ++
 src/tike/operators/torch/lamino.py   |  76 +++++++++++++++
 tests/operators/torch/test_lamino.py | 136 +++++++++++++++++++++++++++
 tests/test_lamino.py                 |   5 +
 4 files changed, 225 insertions(+)
 create mode 100644 src/tike/operators/torch/__init__.py
 create mode 100644 src/tike/operators/torch/lamino.py
 create mode 100644 tests/operators/torch/test_lamino.py

diff --git a/src/tike/operators/torch/__init__.py b/src/tike/operators/torch/__init__.py
new file mode 100644
index 00000000..7ac5e14d
--- /dev/null
+++ b/src/tike/operators/torch/__init__.py
@@ -0,0 +1,8 @@
+"""Wraps CuPy operators in torch.autograd.Function."""
+
+from .lamino import *
+
+__all__ = [
+    'LaminoFunction',
+    'LaminoModule',
+]
diff --git a/src/tike/operators/torch/lamino.py b/src/tike/operators/torch/lamino.py
new file mode 100644
index 00000000..fc3c9bb1
--- /dev/null
+++ b/src/tike/operators/torch/lamino.py
@@ -0,0 +1,76 @@
+import cupy as cp
+import torch
+
+import tike.operators.cupy
+
+
+class LaminoFunction(torch.autograd.Function):
+    """The forward/adjoint laminography operations.
+
+    Parameters
+    ----------
+    u : (N, N, N, 2) tensor float32
+        A (3 + 1)D tensor whose first three dimensions are spatial and whose
+        last dimension of length 2 holds the real/imaginary components.
+        PyTorch does not presently have good complex-value support.
+    theta : (M, ) tensor float32
+        The rotation angles of the projections.
+    tilt : float
+        The laminography tilt angle in radians.
+    output : (M, N, N, 2) tensor float32
+        Projections through the volume at each rotation angle.
+
+    """
+
+    @staticmethod
+    def forward(ctx, u, theta, tilt=cp.pi / 2):
+        ctx.n = u.shape[0]
+        ctx.tilt = tilt
+        ctx.save_for_backward(theta)
+        with tike.operators.cupy.Lamino(
+                n=ctx.n,
+                tilt=ctx.tilt,
+                eps=1e-6,
+                upsample=2,
+        ) as operator:
+            output = operator.fwd(
+                u=cp.asarray(torch.view_as_complex(u).detach(),
+                             dtype='complex64'),
+                theta=cp.asarray(theta, dtype='float32'),
+            )
+        output = torch.view_as_real(torch.as_tensor(output, device=u.device))
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        theta, = ctx.saved_tensors
+        with tike.operators.cupy.Lamino(
+                n=ctx.n,
+                tilt=ctx.tilt,
+                eps=1e-6,
+                upsample=2,
+        ) as operator:
+            grad_u = operator.adj(
+                data=cp.asarray(torch.view_as_complex(grad_output),
+                                dtype='complex64'),
+                theta=cp.asarray(theta, dtype='float32'),
+            ) / grad_output.shape[0]
+        grad_u = torch.view_as_real(
+            torch.as_tensor(grad_u, device=grad_output.device))
+        grad_theta = grad_tilt = None
+        return grad_u, grad_theta, grad_tilt
+
+
+class LaminoModule(torch.nn.Module):
+
+    def __init__(self, width):
+        super(LaminoModule, self).__init__()
+        self.width = width
+        self.weight = torch.nn.Parameter(
+            torch.zeros(width, width, width, 2, dtype=torch.float32))
+
+    def forward(self, theta, tilt=cp.pi / 2):
+        return LaminoFunction.apply(self.weight, theta, tilt)
+
+    def extra_repr(self):
+        return f'width={self.width}'
diff --git a/tests/operators/torch/test_lamino.py b/tests/operators/torch/test_lamino.py
new file mode 100644
index 00000000..a47dca9b
--- /dev/null
+++ b/tests/operators/torch/test_lamino.py
@@ -0,0 +1,136 @@
+import lzma
+import os
+import pickle
+import unittest
+
+import cupy as cp
+import numpy as np
+import torch
+from torch.nn.modules.loss import GaussianNLLLoss
+
+from tike.operators.torch import LaminoFunction, LaminoModule
+
+
+@unittest.skip('single precision is not enough to pass gradcheck')
+def test_lamino_gradcheck(n=16, ntheta=8):
+
+    lamino = LaminoFunction.apply
+
+    # gradcheck takes a tuple of tensors as input and checks whether the
+    # gradients evaluated with these tensors are close enough to numerical
+    # approximations; it returns True if they all satisfy this condition.
+    input = (
+        torch.randn(
+            n,
+            n,
+            n,
+            2,
+            dtype=torch.float32,
+            requires_grad=True,
+            device='cpu',
+        ),
+        cp.pi * torch.randn(
+            ntheta,
+            dtype=torch.float32,
+            requires_grad=False,
+            device='cpu',
+        ),
+    )
+    test = torch.autograd.gradcheck(
+        lamino,
+        input,
+        eps=1e-6,
+        atol=1e-4,
+        nondet_tol=1e-6,
+    )
+    print(test)
+
+
+testdir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
+
+
+class L2Loss(torch.nn.Module):
+
+    def forward(self, input, target):
+        return torch.mean(torch.square(torch.abs(input - target)))
+
+
+class TestLaminoModel(unittest.TestCase):
+
+    def setUp(self):
+        """Load a dataset for reconstruction."""
+        dataset_file = os.path.join(testdir, 'data/lamino_setup.pickle.lzma')
+        if not os.path.isfile(dataset_file):
+            self.create_dataset(dataset_file)
+        with lzma.open(dataset_file, 'rb') as file:
+            [
+                self.data,
+                self.original,
+                self.theta,
+                self.tilt,
+            ] = pickle.load(file)
+
+    def test_lamino_model(self, num_epoch=32, device=0):
+
+        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+        theta = torch.from_numpy(self.theta).type(torch.float32).to(device)
+        data = torch.view_as_real(
+            torch.from_numpy(self.data).type(torch.complex64)).to(device)
+        var = torch.ones(data.shape, dtype=torch.float32,
+                         requires_grad=True).to(device)
+
+        model = LaminoModule(data.shape[1]).to(device)
+        lossf = GaussianNLLLoss().to(device)
+        optimizer = torch.optim.Adam(model.parameters())
+
+        loss_log = []
+        for epoch in range(num_epoch):
+            pred = model(theta, self.tilt)
+            loss = lossf(pred, data, var)
+            optimizer.zero_grad()
+            loss.backward()
+            optimizer.step()
+            loss_log.append(loss.item())
+            print(f"loss: {loss_log[-1]:.3e} [{epoch:>5d}/{num_epoch:>5d}]")
+
+        obj = torch.view_as_complex(model.weight.cpu().detach()).numpy()
+
+        _save_lamino_result({'obj': obj, 'costs': loss_log}, 'torch')
+
+
+def _save_lamino_result(result, algorithm):
+    try:
+        import matplotlib.pyplot as plt
+        fname = os.path.join(testdir, 'result', 'lamino', f'{algorithm}')
+        os.makedirs(fname, exist_ok=True)
+        plt.figure()
+        plt.title(algorithm)
+        plt.plot(result['costs'])
+        plt.semilogy()
+        plt.savefig(os.path.join(fname, 'convergence.svg'))
+        slice_id = int(35 / 128 * result['obj'].shape[0])
+        plt.imsave(
+            f'{fname}/{slice_id}-phase.png',
+            np.angle(result['obj'][slice_id]).astype('float32'),
+            # The output of np.angle is locked to (-pi, pi]
+            cmap=plt.cm.twilight,
+            vmin=-np.pi,
+            vmax=np.pi,
+        )
+        plt.imsave(
+            f'{fname}/{slice_id}-ampli.png',
+            np.abs(result['obj'][slice_id]).astype('float32'),
+        )
+        import skimage.io
+        skimage.io.imsave(
+            f'{fname}/phase.tiff',
+            np.angle(result['obj']).astype('float32'),
+        )
+        skimage.io.imsave(
+            f'{fname}/ampli.tiff',
+            np.abs(result['obj']).astype('float32'),
+        )
+
+    except ImportError:
+        pass
diff --git a/tests/test_lamino.py b/tests/test_lamino.py
index 6f422a8e..e7b5709d 100644
--- a/tests/test_lamino.py
+++ b/tests/test_lamino.py
@@ -283,6 +283,11 @@ def _save_lamino_result(result, algorithm):
         import matplotlib.pyplot as plt
         fname = os.path.join(testdir, 'result', 'lamino', f'{algorithm}')
         os.makedirs(fname, exist_ok=True)
+        plt.figure()
+        plt.title(algorithm)
+        plt.plot(result['cost'])
+        plt.semilogy()
+        plt.savefig(os.path.join(fname, 'convergence.svg'))
         slice_id = int(35 / 128 * result['obj'].shape[0])
         plt.imsave(
             f'{fname}/{slice_id}-phase.png',
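
Usage note (not part of the patch): below is a minimal sketch of how the new
LaminoModule could be fit to a stack of projections once this patch is applied.
The sizes, angles, placeholder data, learning rate, and epoch count are invented
for illustration only; tests/operators/torch/test_lamino.py is the authoritative
example, and a CUDA device plus the tike CuPy operators are assumed.

import numpy as np
import torch

from tike.operators.torch import LaminoModule

# Hypothetical problem sizes for illustration only.
width, ntheta = 64, 16
device = torch.device('cuda')  # the underlying CuPy operator requires a GPU

# Rotation angles (radians) and placeholder measured projections shaped
# (M, N, N, 2); real data would come from experiment or simulation.
theta = torch.linspace(0, np.pi, ntheta, device=device)
data = torch.zeros(ntheta, width, width, 2, device=device)

model = LaminoModule(width).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(10):
    pred = model(theta, tilt=np.pi / 2)  # forward projection of the volume
    loss = torch.mean(torch.square(pred - data))  # simple L2 data fit
    optimizer.zero_grad()
    loss.backward()  # adjoint projection applied by LaminoFunction.backward
    optimizer.step()

# The reconstructed complex volume is stored in the module's weight tensor.
volume = torch.view_as_complex(model.weight.detach().cpu())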