Graeme Winter

Hybrid Programming / Spot Finding on MCU

Bigger ask this time: find spots on Eiger 9M images (~9 megapixel, 16 bit depth) on an MCU which has ~ 200kB of RAM. It can be done.

Key points:

The spot finding typically uses a 7x7 kernel around the pixel in question: it computes a local mean and variance, then applies a threshold calculation to determine whether the pixel is or is not part of the background. If the pixel in question is significantly different from its neighbours then it is signal and therefore interesting. The work here is to do that calculation on a SAMD51 MCU, as found on an Adafruit Grand Central development board.

Input

Prepare data by reading from native HDF5 files, re-arranging into a module stack and saving as uint16_t binary files to a µSD card (supported on the Grand Central). This is justified as it is one way the data could have been saved. Code:

import sys

import h5py
import hdf5plugin
import bitshuffle
import numpy

from matplotlib import pyplot

# Detector geometry: each frame is a grid of modules of ny x nx pixels,
# separated by gaps of gy rows / gx columns.
ny, nx = 512, 1028
gy, gx = 38, 12

# modules are laid out 6 down by 3 across => 18 modules per frame
n_slow, n_fast = 6, 3
n_modules = n_slow * n_fast

with h5py.File(sys.argv[1], "r") as f:
    data = f["/entry/data/data"]

    # explicit check rather than assert, so it is not stripped under -O
    expected = (
        n_slow * ny + (n_slow - 1) * gy,
        n_fast * nx + (n_fast - 1) * gx,
    )
    if tuple(data.shape[1:3]) != expected:
        sys.exit(
            f"Unexpected frame shape {tuple(data.shape[1:3])}, expected {expected}"
        )

    for i in range(data.shape[0]):
        # one (ny, nx) slab per module, module index running down each
        # column of modules first
        stack = numpy.zeros(dtype=numpy.uint16, shape=(n_modules, ny, nx))
        frame = data[i]
        for j in range(n_modules):
            x = j // n_slow
            y = j % n_slow

            y0 = y * (ny + gy)
            x0 = x * (nx + gx)
            stack[j] = frame[y0 : y0 + ny, x0 : x0 + nx]

        # stack is C-contiguous so tobytes() already yields the flat
        # module-by-module, row-by-row uint16 stream
        with open(f"frame_{i:05d}.raw", "wb") as o:
            nbytes = o.write(stack.tobytes())
            print(f"Writing frame_{i:05d}.raw: {nbytes} bytes")

Approach

Write as a µPython extension module with internal buffers, using µPython for SD card access etc. as this is simply a proof of concept. Builds on previous work on writing extension modules, though -Ofast was needed to be able to use the Cortex M4 intrinsic sqrtf function.

Implementation

#include <math.h>
#include "py/dynruntime.h"

// global variables - strictly some of these are not needed

// Rolling row buffers, allocated by signal_filter_init with
// (2 * knl + 2) * width elements each and released by signal_filter_deinit.
uint32_t *im = NULL;     // pixel values, with masked pixels stored as 0
uint32_t *m_sat = NULL;  // summed-area table of the valid-pixel mask (0 / 1)
uint32_t *i_sat = NULL;  // summed-area table of the pixel values
uint32_t *i2_sat = NULL; // summed-area table of the squared pixel values

uint32_t height = 0; // rows per image, set by signal_filter_init
uint32_t width = 0;  // pixels per row, set by signal_filter_init
uint32_t knl = 0;    // kernel half-width: the window is (2 * knl + 1) square
uint32_t row = 0;    // number of signal_filter_row calls so far for this image

int signal_filter_init(uint32_t _height, uint32_t _width, uint32_t _knl) {
  if (im != NULL) {
    return 1;
  }
  height = _height;
  width = _width;
  knl = _knl;

  // buffer size is 2 * KNL + 1 rows, + 1 more row for the pre-subtraction row
  uint32_t nn = (2 * _knl + 2) * _width;

  im = (uint32_t *)m_malloc(sizeof(uint32_t) * nn);
  m_sat = (uint32_t *)m_malloc(sizeof(uint32_t) * nn);
  i_sat = (uint32_t *)m_malloc(sizeof(uint32_t) * nn);
  i2_sat = (uint32_t *)m_malloc(sizeof(uint32_t) * nn);

  return 0;
}

int signal_filter_deinit(void) {
  if (im == NULL) {
    return 1;
  }

  m_free(im);
  m_free(m_sat);
  m_free(i_sat);
  m_free(i2_sat);

  im = m_sat = i_sat = i2_sat = NULL;

  return 0;
}

int signal_filter_row(uint16_t *io_row) {
  float sigma_b = 6.0f, sigma_s = 3.0f;
  uint32_t knl2 = 2 * knl + 1;

  int nsignal = 0;

  if (row < height) {
    // move rows up if we are past the start-up region
    if (row > knl2) {
      for (uint32_t i = 0; i < knl2; i++) {
        uint32_t off = i * width;
        for (uint32_t j = 0; j < width; j++) {
          im[off + j] = im[off + j + width];
          m_sat[off + j] = m_sat[off + j + width];
          i_sat[off + j] = i_sat[off + j + width];
          i2_sat[off + j] = i2_sat[off + j + width];
        }
      }
    }

    // populate row, update SAT
    uint32_t _m = 0, _i = 0, _i2 = 0;
    for (uint32_t j = 0; j < width; j++) {
      uint32_t m = io_row[j] > 0xfffd ? 0 : 1;
      uint32_t p = m * io_row[j];
      uint32_t i = row < knl2 ? row * width : knl2 * width;
      _m += m;
      _i += p;
      _i2 += p * p;
      im[i + j] = p;
      m_sat[i + j] = row > 0 ? _m + m_sat[i + j - width] : _m;
      i_sat[i + j] = row > 0 ? _i + i_sat[i + j - width] : _i;
      i2_sat[i + j] = row > 0 ? _i2 + i2_sat[i + j - width] : _i2;
    }
  }

  // at the very start we cannot do anything useful
  if (row < knl) {
    row++;
    return 0;
  }

  // after that we can start to compute signal tables for earlier rows
  for (uint32_t j = 0; j < width; j++) {

    int32_t j0 = j - knl - 1;
    int32_t j1 = j < (width - knl) ? j + knl : width - 1;

    int32_t i1 = row < knl2 ? row : knl2;
    int32_t i = row < height ? i1 - knl : i1 - (knl - (row - height)) + 1;
    int32_t i0 = i - knl - 1;

    int32_t a = i1 * width + j1;
    int32_t b = i0 * width + j1;
    int32_t c = i1 * width + j0;
    int32_t d = i0 * width + j0;

    uint32_t m_sum = m_sat[a], i_sum = i_sat[a], i2_sum = i2_sat[a];

    if (j0 >= 0 && i0 >= 0) {
      m_sum += m_sat[d] - m_sat[b] - m_sat[c];
      i_sum += i_sat[d] - i_sat[b] - i_sat[c];
      i2_sum += i2_sat[d] - i2_sat[b] - i2_sat[c];
    } else if (j0 >= 0) {
      m_sum -= m_sat[c];
      i_sum -= i_sat[c];
      i2_sum -= i2_sat[c];
    } else if (i0 >= 0) {
      m_sum -= m_sat[b];
      i_sum -= i_sat[b];
      i2_sum -= i2_sat[b];
    }

    // N.B. for photon counting detectors the threshold is usually zero,
    // masked pixels have value zero so cannot be signal - if the pixel is
    // > 0 then the sum must be

    uint16_t signal = 0;
    uint32_t p = im[i * width + j];

    if (p > 0 && m_sum >= 2) {
      float bg_lhs = (float)m_sum * i2_sum - (float)i_sum * i_sum -
                     (float)i_sum * (m_sum - 1);
      float bg_rhs = i_sum * sigma_b * (float) sqrtf((float)2 * (m_sum - 1));
      uint16_t background = bg_lhs > bg_rhs;
      float fg_lhs = (float) m_sum * p - (float)i_sum;
      float fg_rhs = sigma_s * (float) sqrtf((float)i_sum * m_sum);
      uint16_t foreground = fg_lhs > fg_rhs;
      signal = background && foreground;
    }
    io_row[j] = signal;
    nsignal += signal;
  }

  row++;

  if (row == height + knl)
    row = 0;

  return nsignal;
}

Wrapper

#include <stdint.h>
#include "py/dynruntime.h"

// Filter entry points implemented in the companion C source.
// init / deinit return 0 on success and non-zero otherwise; row returns the
// number of signal pixels found in the row it emitted.
int signal_filter_init(uint32_t height, uint32_t width, uint32_t knl);
int signal_filter_row(uint16_t *row);
int signal_filter_deinit(void);

// init(ny, nx): set the filter up for images of ny rows by nx pixels with a
// kernel half-width of 3 (7x7 window). Returns 0 on success and a non-zero
// status if the sizes are invalid or signal_filter_init reports a failure.
STATIC mp_obj_t spot_finder_init(mp_obj_t ny_obj, mp_obj_t nx_obj) {
  mp_int_t ny = mp_obj_get_int(ny_obj);
  mp_int_t nx = mp_obj_get_int(nx_obj);
  const uint32_t knl = 3;

  // zero or negative sizes would wrap to huge values in the uint32_t casts
  if (ny <= 0 || nx <= 0) {
    return mp_obj_new_int(2);
  }

  // hand the status back to Python rather than always reporting success
  int status = signal_filter_init((uint32_t) ny, (uint32_t) nx, knl);
  return mp_obj_new_int(status);
}

// deinit(): release the filter buffers. Returns the status reported by
// signal_filter_deinit (0 on success, non-zero otherwise).
STATIC mp_obj_t spot_finder_deinit(void) {
  int status = signal_filter_deinit();
  return mp_obj_new_int(status);
}

// row width in pixels, as set by signal_filter_init in the filter source
extern uint32_t width;

// row(buffer): pass one row of uint16 pixels, held in a writable buffer
// object, to the filter, which may overwrite it in place with signal flags.
// Returns the value from signal_filter_row, or -1 if the buffer is too small
// to hold one row.
STATIC mp_obj_t spot_finder_row(mp_obj_t row_obj) {
  mp_buffer_info_t bufinfo;
  mp_get_buffer_raise(row_obj, &bufinfo, MP_BUFFER_RW);

  // the filter reads and writes `width` uint16_t values: refuse a short
  // buffer rather than run past its end
  if (bufinfo.len < (size_t) width * sizeof(uint16_t)) {
    return mp_obj_new_int(-1);
  }

  mp_int_t nsignal = signal_filter_row((uint16_t *) bufinfo.buf);
  return mp_obj_new_int(nsignal);
}

// Function objects wrapping the C functions above; the _2 / _0 / _1 suffix is
// the number of positional arguments each one takes.
STATIC MP_DEFINE_CONST_FUN_OBJ_2(spot_finder_init_obj, spot_finder_init);
STATIC MP_DEFINE_CONST_FUN_OBJ_0(spot_finder_deinit_obj, spot_finder_deinit);
STATIC MP_DEFINE_CONST_FUN_OBJ_1(spot_finder_row_obj, spot_finder_row);

// Native module entry point: stores the three function objects as the
// globals `init`, `deinit` and `row`.
// NOTE(review): presumably called by MicroPython when the .mpy is imported -
// confirm against the native module documentation.
mp_obj_t mpy_init(mp_obj_fun_bc_t *self, size_t n_args, size_t n_kw,
                  mp_obj_t *args) {
  // must come first, before any other use of the runtime
  MP_DYNRUNTIME_INIT_ENTRY

  mp_store_global(MP_QSTR_init, MP_OBJ_FROM_PTR(&spot_finder_init_obj));
  mp_store_global(MP_QSTR_deinit, MP_OBJ_FROM_PTR(&spot_finder_deinit_obj));
  mp_store_global(MP_QSTR_row, MP_OBJ_FROM_PTR(&spot_finder_row_obj));

  // must come last
  MP_DYNRUNTIME_INIT_EXIT
}

API

from machine import Pin, SPI
import sdcard
import os
import time
import gc

import spot_filter


def main():
    """Run the spot filter over frames 0-9 on the SD card and print the counts.

    Each frame file is read one detector row at a time and passed through
    ``spot_filter.row``; the number of signal pixels and the elapsed time
    are printed per frame. The SD card is unmounted and the filter released
    even if processing fails part-way through.
    """

    # file is 18 * 512 * 1028 * 2 bytes: 18 modules of 512 rows of 1028
    # uint16 pixels
    n_modules, ny, nx = 18, 512, 1028
    row_bytes = nx * 2

    # extra calls needed per module to flush the rows still held inside the
    # filter; matches the kernel half-width of 3 set in the C wrapper
    n_flush = 3

    spot_filter.init(ny, nx)
    try:
        sd = sdcard.SDCard(
            SPI(
                2,
                baudrate=8000000,
                mosi=Pin("SD_MOSI"),
                miso=Pin("SD_MISO"),
                sck=Pin("SD_SCK"),
            ),
            Pin("SD_CS"),
            baudrate=8000000,
        )

        led = Pin("D13", Pin.OUT)
        os.mount(sd, "/sd")
        try:
            # single row buffer, reused for every read; the filter overwrites
            # it in place with its output
            buffer = bytearray(row_bytes)

            for frame in range(10):
                filename = f"/sd/frame_{frame:05d}.raw"
                with open(filename, "rb") as fin:
                    t0 = time.time()
                    signal = 0
                    for _module in range(n_modules):
                        for _row in range(ny):
                            if fin.readinto(buffer, row_bytes) != row_bytes:
                                raise ValueError(f"{filename} is truncated")
                            signal += spot_filter.row(buffer)
                        # last 3 rows: no new data, just drain the filter
                        for _flush in range(n_flush):
                            signal += spot_filter.row(buffer)
                        led.toggle()

                    t1 = time.time()

                    print(f"{filename} => {signal} signal pixels in {t1 - t0:d}s")
        finally:
            os.umount("/sd")
    finally:
        spot_filter.deinit()


# start processing as soon as the file is executed
main()

Results

Reference results from DIALS for 1st 10 images in data set:

Found 275 strong pixels on image 1
Found 693 strong pixels on image 2
Found 734 strong pixels on image 3
Found 853 strong pixels on image 4
Found 754 strong pixels on image 5
Found 757 strong pixels on image 6
Found 731 strong pixels on image 7
Found 766 strong pixels on image 8
Found 645 strong pixels on image 9
Found 704 strong pixels on image 10

From the µPython console:

>>> %Run -c $EDITOR_CONTENT
/sd/frame_00000.raw => 275 signal pixels in 127s
/sd/frame_00001.raw => 693 signal pixels in 124s
/sd/frame_00002.raw => 734 signal pixels in 124s
/sd/frame_00003.raw => 853 signal pixels in 131s
/sd/frame_00004.raw => 754 signal pixels in 125s
/sd/frame_00005.raw => 757 signal pixels in 124s
/sd/frame_00006.raw => 731 signal pixels in 124s
/sd/frame_00007.raw => 766 signal pixels in 125s
/sd/frame_00008.raw => 645 signal pixels in 124s
/sd/frame_00009.raw => 704 signal pixels in 131s

Timing commentary: simply reading the data in the program above, with no analysis, takes 80s / frame, so there is some work to do on optimising it, but it does work. Still needed: an elegant way of extracting the data, perhaps to a second sequence of files, and, more importantly, a way to fit the connected component analysis in as well, so that the results are actually useful for the next step in the analysis.

Code

Code in repository here - needs custom µPython build for SAMD51 but should be fairly portable, modulo hacking the SD card access.

Reference

To compare with the same analysis in DIALS, generate “gold” reference data with:

import sys

import numpy
import bitshuffle

import iotbx.phil
from scitbx.array_family import flex

from dials.algorithms.spot_finding.factory import SpotFinderFactory
from dials.algorithms.spot_finding.factory import phil_scope as spot_phil
from dials.util.options import ArgumentParser, flatten_experiments


# module size in pixels (rows, columns)
ny, nx = 512, 1028
# rows / columns skipped between adjacent modules: module origins are spaced
# (ny + gy) and (nx + gx) apart when slicing the full image
gy, gx = 38, 12


def signal(imageset, images):
    """Write per-image reference ("gold") signal masks from the dispersion threshold.

    For each index in ``images`` the raw data are read from ``imageset``,
    pixels outside the detector's trusted range are masked, the threshold
    is applied and the resulting 0/1 mask is cut into 18 modules of
    ``ny`` x ``nx`` pixels. The flattened module stack is written
    bitshuffle-LZ4 compressed to ``gold_<image>.bslz4`` and the number of
    signal pixels is printed.

    Args:
        imageset: provides ``get_detector()`` (exactly one panel is
            expected) and ``get_raw_data(index)``.
        images: iterable of image indices passed to ``get_raw_data``.
    """
    panels = imageset.get_detector()
    assert len(panels) == 1

    detector = panels[0]
    trusted = detector.get_trusted_range()

    # the threshold configuration does not depend on the image, so build it
    # once rather than once per image
    spot_params = spot_phil.fetch(
        source=iotbx.phil.parse("spotfinder.threshold.algorithm=dispersion")
    ).extract()
    threshold_function = SpotFinderFactory.configure_threshold(spot_params)

    for image in images:
        data = imageset.get_raw_data(image)[0]

        # pixels outside the trusted range are excluded from the calculation
        negative = data < int(round(trusted[0]))
        hot = data > int(round(trusted[1]))
        bad = negative | hot

        data = data.as_double()

        peak_pixels = (
            threshold_function.compute_threshold(data, ~bad)
            .as_numpy_array()
            .astype(numpy.uint16)
        )

        stack = numpy.zeros(dtype=numpy.uint16, shape=(18, ny, nx))

        # same module ordering as the raw frame files: module index runs
        # down each column of 6 modules first
        for j in range(18):
            x = j // 6
            y = j % 6

            module = peak_pixels[
                y * (ny + gy) : y * (ny + gy) + ny, x * (nx + gx) : x * (nx + gx) + nx
            ]
            stack[j] = module

        stack = stack.reshape((18 * ny * nx,))

        with open(f"gold_{image:05d}.bslz4", "wb") as f:
            nbytes = f.write(bitshuffle.compress_lz4(stack).tobytes())
            print(
                f"{image} -> {numpy.count_nonzero(stack)} in gold_{image:05d}.bslz4 ({nbytes} bytes)"
            )


# Command-line parameters for this script. Only `images` and `image_range`
# are read by main() below; `nproc` and the `output` scope are not referenced
# anywhere in the code shown here.
# NOTE(review): the image_range help example "1,1800" reads as 1-based, while
# main() validates against scan.get_array_range() - confirm the intended
# indexing convention.
phil_scope = iotbx.phil.parse(
    """
images = None
  .type = ints
  .help = "Images on which to perform the analysis (otherwise use all images)"
image_range = None
  .type = ints(value_min=0, size=2)
  .help = "Image range for analysis e.g. 1,1800"
nproc = 1
  .type = int(value_min=1)
    .help = "The number of processes to use."

output {
    mask = pixels.mask
        .type = path
        .help = "Output mask file name"
    print_values = False
        .type = bool
        .help = "Print bad pixel values"
}
"""
)


def main(args=None):
    """Parse the command line, select the images and write the gold masks.

    Exactly one experiment must be given. All images in the scan's array
    range are processed unless ``images`` or ``image_range`` (inclusive of
    both ends) is set; a selection outside the array range stops the
    program with an error message.

    Args:
        args: command-line arguments, passed to the parser as given.
    """
    usage = "program [options] (data_master.h5|data_*.cbf)"

    parser = ArgumentParser(
        usage=usage,
        phil=phil_scope,
        read_experiments=True,
        read_experiments_from_images=True,
    )

    params, options = parser.parse_args(args, show_diff_phil=True)

    experiments = flatten_experiments(params.input.experiments)

    assert len(experiments) == 1, len(experiments)

    experiment = experiments[0]

    imageset = experiment.imageset

    # half-open range: valid image indices are first .. last - 1
    first, last = experiment.scan.get_array_range()

    images = range(first, last)

    if params.images is None and params.image_range is not None:
        start, end = params.image_range
        params.images = list(range(start, end + 1))

    if params.images:
        # `last` itself is already out of range, hence >= rather than >
        if min(params.images) < first or max(params.images) >= last:
            sys.exit("image(s) outside of scan range")
        images = params.images

    signal(imageset, images)


# script entry point: pass the command-line arguments, minus the program name
main(sys.argv[1:])

Comparing the data from the µPython implementation, saved to the SD card, confirms that the correct pixels are identified.