""" Base class providing common functionality for analyzing Leed patterns. """

from __future__ import division

__version__ = 0.3

from UserList import UserList
import re
import logging
logger = logging.getLogger("leedbase")

import numpy as np
np.seterr(all="raise")

import PyQt4.QtGui as _qt

# only for guess from fit
from scipy import optimize, ndimage

# load packages for available file types
formats_available = []
try:
    import pyfits
    formats_available.append("FITS")    
except:
    logger.warning("The pyfits package is not properly installed.")
try:
    import Image
    formats_available.append("IMG")    
except:
    logger.warning("The pil package (Image) is not properly installed.")

#### Image loading ####
class ImageLoader(object):
    """ Abstract base class for a class loading LEED images.

    Behaves as an iterator that walks through the images in order of
    ascending beam energy and yields (image, energy) tuples.

    Subclasses need to provide
        - get_energy(image_path)
        - get_image(image_path)
    """
    def __init__(self, image_paths):
        """ image_paths: iterable of paths of the image files to load. """
        # build a dictionary with energy as key and imagePath as value
        # (two files with the same energy: the later one wins)
        self.files = {}
        for image_path in image_paths:
            energy = self.get_energy(image_path)
            self.files[energy] = image_path
        self.energies = sorted(self.files.keys())
        self.restart()

    def current_energy(self):
        """ Get current energy.

        NOTE(review): directly after restart() the index is -1, so this
        returns the highest energy - confirm that callers only use it after
        the first next().
        """
        return self.energies[self.index]

    def __iter__(self):
        return self

    def restart(self):
        """ Start at lowest energy again. """
        # -1 means "before the first image"; next() moves to index 0
        self.index = -1

    def previous(self):
        """ Get image at next lower beam energy.

        Raises StopIteration if there is no image at a lower energy.
        """
        # "<= 0" also covers the freshly restarted loader (index -1), which
        # would otherwise wrap around to the end of the energy list.
        if self.index <= 0:
            raise StopIteration("there is no previous image")
        self.index -= 1
        energy = self.energies[self.index]
        return self.get_image(self.files[energy]), energy

    def next(self):
        """ Get image at next higher beam energy.

        Raises StopIteration after the image with the highest energy.
        """
        if self.index < len(self.energies)-1:
            self.index += 1
            energy = self.energies[self.index]
            return self.get_image(self.files[energy]), energy
        else:
            raise StopIteration()

    # Python 3 spelling of the iterator protocol
    __next__ = next

    # FIXME: untested
    def custom_iter(self, energies):
        """ Returns an iterator to iter over the given energies.

        Yields (image, energy) tuples in the order of the given energies.
        Raises an Exception (on the first iteration step) if one of the
        energies is not known to this loader.
        """
        non_elements = set(energies) - set(self.energies)
        if non_elements:
            raise Exception("ImageLoader doesn't have the following elements: %s" % (list(non_elements)))
        for energy in energies:
            # get_image() expects the path of the file, not the energy
            yield self.get_image(self.files[energy]), energy

class ImgImageLoader(ImageLoader):
    """ Loader for IMG files: a text header followed by raw pixel data.

    The beam energy is read from the "Beam Voltage (eV)" header entry.
    """
    def __init__(self, files, **kwargs):
        # additional keyword arguments are accepted but ignored
        ImageLoader.__init__(self, files)

    def get_energy(self, image_path):
        """ Returns the beam voltage stored in the header of the file. """
        with open(image_path, "rb") as f:
            return self.load_header(f)["Beam Voltage (eV)"]

    def load_header(self, f):
        """ Reads the header from the open file object f.

        Returns a dict with the entries "Beam Voltage (eV)", "Date",
        "Comment", "x1", "y1", "x2", "y2" and "Number of frames"; entries
        missing in the file keep their default (0 or "").
        On return the file position is directly behind the header.

        Raises ValueError if the file has no "Header length:" entry.
        """
        # find header length
        line = f.readline()
        while "Header length:" not in line:
            if not line:
                # readline() returns an empty string at end of file; without
                # this check the loop would never terminate
                raise ValueError("no 'Header length:' entry found in image header")
            line = f.readline()
        header_length = int(line.split(": ")[1].strip())
        # jump back to beginning
        f.seek(0)
        # read in header
        header_raw = f.read(header_length)
        ## process header ##
        # dict containing names of all interesting entries; the type of the
        # default value decides how the entry is converted
        header = {"Beam Voltage (eV)": 0, "Date": "", "Comment": "", "x1": 0, "y1": 0, "x2": 0, "y2": 0, "Number of frames": 0}
        for line in header_raw.split("\n"):
            # split only once so that values containing ": " stay intact
            parts = line.split(": ", 1)
            if len(parts) < 2 or parts[0] not in header:
                continue
            key, value = parts
            if isinstance(header[key], int):
                # convert int entries
                header[key] = int(value)
            elif isinstance(header[key], str):
                # convert string entries
                header[key] = value.strip()
        return header

    def get_image(self, image_path):
        """ Returns the image stored in the file as numpy array. """
        with open(image_path, "rb") as f:
            header = self.load_header(f)
            # read in content (everything behind the header)
            content = f.read()
            # calculate size of image from header information
            size = (header["x2"] - header["x1"] + 1, header["y2"] - header["y1"] +1)
            # load image with PIL
            # NOTE(review): Image.fromstring is the legacy PIL spelling; newer
            # Pillow releases presumably need Image.frombytes - confirm
            # against the installed version.
            pilimage = Image.fromstring("F", size, content, "raw", "F;16", 0, 1)
            return np.asarray(pilimage)

class FitsImageLoader(ImageLoader):
    """ Loader for FITS files; the beam energy is taken from the file name. """
    def __init__(self, files, regex=r"\d{1,3}(?=.fit)"):
        """
        files: paths of the FITS files
        regex: regular expression whose match in the file path is converted
               to the integer beam energy. The default matches one to three
               digits that are followed by any character and "fit".
        """
        # has to be set before ImageLoader.__init__ because that calls get_energy
        self.regex = regex
        ImageLoader.__init__(self, files)

    def get_energy(self, image_path):
        """ Returns the energy encoded in the path as int.

        Raises ValueError if the regex does not match the path.
        """
        m = re.search(self.regex, image_path)
        if m is None:
            raise ValueError("could not extract beam energy from file name %r" % (image_path,))
        return int(m.group())

    def get_image(self, image_path):
        """ Returns the data of the primary HDU of the FITS file.

        An IOError is logged and re-raised unchanged, so the original
        message and traceback are kept.
        """
        try:
            hdulist = pyfits.open(image_path)
            try:
                return hdulist[0].data
            finally:
                # close the file even if accessing the data fails
                hdulist.close()
        except IOError:
            logger.error("IOError while processing %s", image_path)
            raise

class ImageFormat:
    """ Describes one image file format and the loader responsible for it. """
    def __init__(self, abbrev, extensions, loader):
        """
        abbrev: short name of the format (e.g. FITS)
        extensions: list of file patterns belonging to the format (e.g. *.fit, *.fits)
        loader: ImageLoader subclass that reads files of this format
        """
        self.abbrev, self.extensions, self.loader = abbrev, extensions, loader

    def __str__(self):
        """ Returns a label such as "FITS-Files (*.fit *.fits)". """
        patterns = " ".join(self.extensions)
        return "{0}-Files ({1})".format(self.abbrev, patterns)

""" Dictionary of available ImageFormats. """
# Keys are the display strings of the formats (str(format_)), values the
# ImageFormat objects. Only formats whose abbreviation is listed in
# formats_available (i.e. whose package could be imported) are included.
IMAGE_FORMATS = dict(
    (str(format_), format_)
    for format_ in (
        ImageFormat("FITS", ["*.fit", "*.fits"], FitsImageLoader),
        ImageFormat("IMG", ["*.img"], ImgImageLoader),
    )
    if format_.abbrev in formats_available
)

def _normalize255(array):
    """ Returns a normalized array of uint8."""
    nmin, nmax = array.min(), array.max()
    if nmin:
        array = array - nmin
    scale = 255.0 / (nmax - nmin)
    if scale != 1.0:
        array = array * scale
    return array.astype("uint8")

def npimage2qimage(npimage):
    """ Converts numpy grayscale image to qimage.

    npimage: 2d array; it is normalized to the range 0..255 first.
    Returns an 8 bit indexed QImage with a gray color table that owns its
    pixel data.
    """
    h, w = npimage.shape
    # QImage reads the buffer row by row, so the rows have to be contiguous
    gray = np.ascontiguousarray(_normalize255(npimage))
    # Pass the row length explicitly: without it QImage assumes rows padded
    # to 32 bit, which skews images whose width is not a multiple of 4.
    qimage = _qt.QImage(gray.data, w, h, gray.strides[0], _qt.QImage.Format_Indexed8)
    for i in range(256):
        qimage.setColor(i, _qt.qRgb(i, i, i))
    # The QImage above only references the numpy buffer, which is released
    # when this function returns. Return a deep copy instead.
    return qimage.copy()
########################

#### generate points ####
def points_in_square(size):
    """ returns int points lying in a square centered on 0,0 with given size.

        size is half the edge length of the square. A float size is truncated
        towards zero; this loses no points because the integer points with
        |x| <= size are exactly those with |x| <= int(size).

        >>> [point for point in points_in_square(1)]
        [(-1, -1), (-1, 0), (-1, 1), (0, -1), (0, 0), (0, 1), (1, -1), (1, 0), (1, 1)]
    """
    # range() instead of xrange() so the function works on Python 2 and 3;
    # the sequence is iterated once per x value by the generator below.
    coords = range(int(-size), int(size) + 1, 1)
    return ((x, y) for x in coords for y in coords)

def points_in_circle(radius):
    """ Yields the int points of the disc around 0, 0 with the given radius.

        Points lying exactly on the circle are included.
    """
    limit = radius**2
    for point in points_in_square(radius):
        px, py = point
        if px**2 + py**2 <= limit:
            yield px, py

def points_at_circle(x0, y0, radius):
    """ Yields the points on the edge of a circle around x0, y0.
        Implemented with the Midpoint circle algorithm.

        The four axis points are yielded first, followed by eight mirrored
        points per step. Some points are yielded more than once (e.g. radius
        1 yields every axis point three times).
    """
    decision = 1 - radius
    step_x = 1
    step_y = -2 * radius
    dx, dy = 0, radius

    # points on the axes
    yield x0, y0 + radius
    yield x0, y0 - radius
    yield x0 + radius, y0
    yield x0 - radius, y0

    # walk one octant and mirror each point into the other seven
    while dx < dy:
        if decision >= 0:
            dy -= 1
            step_y += 2
            decision += step_y
        dx += 1
        step_x += 2
        decision += step_x
        for off_x, off_y in ((dx, dy), (dy, dx)):
            yield x0 + off_x, y0 + off_y
            yield x0 - off_x, y0 + off_y
            yield x0 + off_x, y0 - off_y
            yield x0 - off_x, y0 - off_y

def points_in_annulus(big_radius, small_radius):
    """ Yields the int points between two circles centered on 0, 0.

        Points on the outer circle are included, points on the inner circle
        are not.

        >>> [point for point in points_in_annulus(2, 1)]
        [(-2, 0), (-1, -1), (-1, 1), (0, -2), (0, 2), (1, -1), (1, 1), (2, 0)]
    """
    outer_squared = big_radius**2
    inner_squared = small_radius**2
    for point in points_in_square(big_radius):
        px, py = point
        dist_squared = px**2 + py**2
        if dist_squared <= inner_squared or dist_squared > outer_squared:
            # inside the small circle or outside the big one
            continue
        yield px, py
##########################

class SpotModel:
    """ Data model of a spot: five lists of equal length that get one new
        entry per call of update().
    """

    def __init__(self):
        self.x, self.y = [], []
        self.intensity, self.energy, self.radius = [], [], []

    def update(self, x, y, intensity, energy, radius):
        """ Appends one data point to each of the lists. """
        data_points = (
            (self.x, x),
            (self.y, y),
            (self.intensity, intensity),
            (self.energy, energy),
            (self.radius, radius),
        )
        for series, value in data_points:
            series.append(value)

class ImageRegion(UserList):
    """ Region of a image defined by contained spots.

    The elements of the list are the (y, x) pixel coordinates of the region.
    """
    def middle(self):
        """ Returns the mean position of the pixels as (x, y).

        Note the order: the pixels are stored as (y, x), the result is (x, y).
        Raises ValueError for an empty region.
        """
        ylist, xlist = zip(*self.data)
        return np.mean(xlist), np.mean(ylist)

    def radius(self):
        """ Returns the radius of a circle whose area equals the number of
            pixels in the region.
        """
        nitems = len(self.data)
        return (nitems / np.pi)**0.5

#### Kalman Filter ####
class AbstractKalmanFilter(object):
    """ Abstract implementation of a Kalman filter.

    Matrices and Vectors can be given in any input format np.matrix() understands.
    Vectors are internally transposed and should therefore be given as column vectors.
    """

    def __init__(self, x, P, H):
        """ Initialize Kalman filter.

        x: start state vector
        P: start state covariance matrix
        H: measurement matrix
        """

        super(AbstractKalmanFilter, self).__init__()
        self.x = np.asmatrix(x).T
        self.P = np.asmatrix(P)
        self.H = np.asmatrix(H)
        # identity matrix of state vector size
        self._1 = np.asmatrix(np.identity(max(self.x.shape)))

    def predict(self, F, Q=None):
        """ Predict next state.

        F: state transition matrix
        Q: process covariance matrix; None (default) means no process noise,
           whatever the size of the state vector is
        """

        F = np.asmatrix(F)
        self.x = F * self.x
        self.P = F * self.P * F.T
        if Q is not None:
            self.P = self.P + np.asmatrix(Q)

    def _calcS(self, R=None):
        """ Returns the covariance matrix of a predicted measurement.

        R: measurement covariance matrix; left out of the sum if None
        """
        S = self.H * self.P * self.H.T
        if R is not None:
            S = S + np.asmatrix(R)
        return S

    def _calcz(self):
        """ Predicted measurement result. """
        return self.H * self.x

    def update(self, z, R):
        """ Update state estimate.

        z: measurement vector
        R: measurement covariance matrix
        """

        z = np.asmatrix(z).T
        S = self._calcS(R)
        z_predicted = self._calcz()
        # Kalman gain
        K = self.P * self.H.T * S.I
        self.x = self.x  +  K * (z - z_predicted)
        self.P = (self._1 - K * self.H) * self.P

    def measurement_distance(self, z, R=None):
        """ Returns the normalized distance squared of the given measurement.

        z: measurement vector
        R: measurement covariance matrix (optional)

        The result is a 1x1 matrix.
        """
        z = np.asmatrix(z).T
        S = self._calcS(R)
        diff = z - self._calcz()
        return diff.T * S.I * diff

class AbstractPVKalmanFilter(AbstractKalmanFilter):
    """ Kalman filter tracking a point in 2d. The state is (x, y, vx, vy);
        only the position is measured.
    """
    def __init__(self, x_in, y_in, P, time):
        """
        x_in, y_in: start position (the start velocity is zero)
        P: start state covariance matrix
        time: start time, stored as old_time for use by the subclasses
        """
        self.old_time = time
        start_state = [x_in, y_in, 0, 0]
        # measurement matrix picking the two position components
        position_only = [[1, 0, 0, 0],
                         [0, 1, 0, 0]]
        super(AbstractPVKalmanFilter, self).__init__(start_state, P, position_only)

    def get_position(self):
        """ Returns the current position estimate as (x, y) floats. """
        x_pos = float(self.x[0])
        y_pos = float(self.x[1])
        return x_pos, y_pos

    def get_position_err(self):
        """ Returns the square roots of the position variances P[0,0] and P[1,1]. """
        var_x = float(self.P[0,0])
        var_y = float(self.P[1,1])
        return var_x**0.5, var_y**0.5

class PVKalmanFilter0(AbstractPVKalmanFilter):
    """ Position/velocity filter that keeps the velocity constant. """
    def predict(self, time, *args, **kwargs):
        """ Predicts the state at the given time.

        Further arguments are handed on to AbstractKalmanFilter.predict.
        """
        elapsed = time - self.old_time
        transition = [
            [1, 0, elapsed, 0],
            [0, 1, 0, elapsed],
            [0, 0, 1, 0],
            [0, 0, 0, 1],
        ]
        super(PVKalmanFilter0, self).predict(transition, *args, **kwargs)
        self.old_time = time

class PVKalmanFilter1(AbstractPVKalmanFilter):
    """ Position/velocity filter whose velocity factor is 1 + a * dt with
        a = -1.5 / old_time.
    """
    def predict(self, time, *args, **kwargs):
        """ Predicts the state at the given time.

        Further arguments are handed on to AbstractKalmanFilter.predict.
        """
        elapsed = time - self.old_time
        accel = - 1.5 / self.old_time
        velocity_gain = 1 + accel * elapsed
        # integral of the velocity factor over the elapsed time
        position_gain = elapsed + 0.5 * accel * elapsed**2
        transition = [
            [1, 0, position_gain, 0],
            [0, 1, 0, position_gain],
            [0, 0, velocity_gain, 0],
            [0, 0, 0, velocity_gain],
        ]
        super(PVKalmanFilter1, self).predict(transition, *args, **kwargs)
        self.old_time = time

class PVKalmanFilter2(AbstractPVKalmanFilter):
    """ Position/velocity filter whose velocity factor is a polynomial of
        second order in dt.
    """
    def predict(self, time, *args, **kwargs):
        """ Predicts the state at the given time.

        Further arguments are handed on to AbstractKalmanFilter.predict.
        """
        elapsed = time - self.old_time
        accel = - 1.5 / self.old_time
        accel_dot = 1.875 / self.old_time**2
        velocity_gain = 1 + accel * elapsed + accel_dot * elapsed**2
        # integral of the velocity factor over the elapsed time
        position_gain = elapsed + 0.5 * accel * elapsed**2 + (1.0 / 3.0) * accel_dot * elapsed**3
        transition = [
            [1, 0, position_gain, 0],
            [0, 1, 0, position_gain],
            [0, 0, velocity_gain, 0],
            [0, 0, 0, velocity_gain],
        ]
        super(PVKalmanFilter2, self).predict(transition, *args, **kwargs)
        self.old_time = time

class PVKalmanFilter3(AbstractPVKalmanFilter):
    """ Position/velocity filter whose velocity factor is a polynomial of
        third order in dt.
    """
    def predict(self, time, *args, **kwargs):
        """ Predicts the state at the given time.

        Further arguments are handed on to AbstractKalmanFilter.predict.
        """
        elapsed = time - self.old_time
        accel = - 1.5 / self.old_time
        accel_dot = 1.875 / self.old_time**2
        accel_ddot = - 2.1875 / self.old_time**3
        velocity_gain = 1 + accel * elapsed + accel_dot * elapsed**2 + accel_ddot * elapsed**3
        # integral of the velocity factor over the elapsed time
        position_gain = elapsed + 0.5 * accel * elapsed**2 + (1.0 / 3.0) * accel_dot * elapsed**3 + (1.0 / 4.0) * accel_ddot * elapsed**4
        transition = [
            [1, 0, position_gain, 0],
            [0, 1, 0, position_gain],
            [0, 0, velocity_gain, 0],
            [0, 0, 0, velocity_gain],
        ]
        super(PVKalmanFilter3, self).predict(transition, *args, **kwargs)
        self.old_time = time

class PVKalmanFilterExactV(AbstractPVKalmanFilter):
    """ Position/velocity filter that scales the velocity with
        (old_time / time)**1.5; the position uses the second order
        polynomial in dt.
    """
    def predict(self, time, *args, **kwargs):
        """ Predicts the state at the given time.

        Further arguments are handed on to AbstractKalmanFilter.predict.
        """
        elapsed = time - self.old_time
        accel = - 1.5 / self.old_time
        accel_dot = 1.875 / self.old_time**2
        velocity_gain = (self.old_time / time)**1.5
        position_gain = elapsed + 0.5 * accel * elapsed**2 + (1.0 / 3.0) * accel_dot * elapsed**3
        transition = [
            [1, 0, position_gain, 0],
            [0, 1, 0, position_gain],
            [0, 0, velocity_gain, 0],
            [0, 0, 0, velocity_gain],
        ]
        super(PVKalmanFilterExactV, self).predict(transition, *args, **kwargs)
        self.old_time = time
#######################

class Tracker:
    """ Follows one spot from image to image. A Kalman filter predicts the
        position, a guess from the image intensities corrects it.
    """
    def __init__(self, x_in, y_in, radius, energy,
            input_precision=1, window_scaling=False, min_window_size = 0):
        """
        x_in, y_in: start position of spot
        radius: radius of the window around the spot
        energy: beam energy of the start image (used as time of the filter)
        input_precision: start variance of the position
        window_scaling: if True the radius is rescaled for every image
        min_window_size: constant part of the radius when scaling is used
        """
        self.radius = radius
        start_cov = np.diag([input_precision, input_precision, 1000, 1000])
        self.kalman = PVKalmanFilter3(x_in, y_in, start_cov, energy)
        self.window_scaling = window_scaling
        if window_scaling:
            self.min_window_size = min_window_size
            # constant c in radius = c / sqrt(energy) + min_window_size
            self.c_size = energy**0.5 * (self.radius - self.min_window_size)

    def feed_image(self, image):
        """ Tracks the spot into the next image.

        image: tuple (npimage, energy)
        returns: x, y, intensity, energy, radius
        """
        npimage, energy = image
        if self.window_scaling:
            self.radius = self.c_size / energy**0.5 + self.min_window_size
        process_cov = np.diag([10**(-2),10**(-2),10**(-6),10**(-6)])
        self.kalman.predict(energy, process_cov)
        x_pred, y_pred = self.kalman.get_position()
        x_guess, y_guess, guess_cov = guesser(npimage, x_pred, y_pred, self.radius, kalman = self.kalman)
        self.kalman.update([x_guess, y_guess], guess_cov)
        x, y = self.kalman.get_position()
        intensity = calc_intensity(npimage, x, y, self.radius)
        return x, y, intensity, energy, self.radius

def guess_from_com(image, *args, **kwargs):
    """ Returns the result of com() for the image, weighted with the mask of
        the pixels above the Otsu threshold. Extra arguments are ignored.
    """
    above_threshold = image > otsu(image)
    return com(image, above_threshold)

def guess_from_brightest(image, *args, **kwargs):
    """ Returns an iterator yielding column and row index (in this order) of
        the first maximum of the image. Extra arguments are ignored.
    """
    flat_index = np.argmax(image)
    row_col = np.unravel_index(flat_index, image.shape)
    return reversed(row_col)

def guess_from_fit(image, x_mid=0, y_mid=0, size = 5):
    """ Guesses the spot position by fitting a 2d gaussian to the image.

    image: 2d array indexed as image[y, x]
    x_mid, y_mid: start values for the center of the gaussian
    size: start value for both widths of the gaussian

    returns: (x, y) of the fitted center or None if the fit raised an error.

    An earlier variant based on optimize.leastsq (which also gave a
    covariance estimate) was replaced by the bounded L-BFGS-B fit used here.
    """
    max_ = image.max()
    back = np.mean(image.flatten())
    rows, cols = np.indices(image.shape)
    width_bounds = (0.5, 10)
    # keep the start width inside its bounds
    width = min(max(size, width_bounds[0]), width_bounds[1])
    # The model is evaluated on (rows, cols), so the first center parameter
    # of gaussian2d is the row (y) and the second one the column (x).
    params = [max_ - back, y_mid, x_mid, width, width, back]

    def err_squared(p):
        return ((np.ravel(gaussian2d(*p)(rows, cols) - image))**2).sum()

    offset = 10
    bounds = [(0, None), (y_mid - offset, y_mid + offset), (x_mid - offset, x_mid + offset),
              width_bounds, width_bounds, (0, None)]
    try:
        value = optimize.fmin_l_bfgs_b(err_squared, np.asarray(params), bounds=bounds, approx_grad=True)
    except Exception:
        # e.g. FloatingPointError, because numpy is set to raise on them
        logger.info("gaussian fit failed")
        return None
    params_opt = value[0]
    y_res = params_opt[1]
    x_res = params_opt[2]
    return x_res, y_res

def guess_from_spots(image, use_com=False, **kwargs):
    """ Finds the spots in the image by thresholding with Otsu's method.

    image: 2d array indexed as image[y, x]
    use_com: if True the position of every spot is refined with com_circle;
             spots for which that fails are dropped
    Additional keyword arguments are ignored.

    returns: dict with the list of (x, y) spot positions under "spots"
    """
    bin_image = image > otsu(image)
    spots = find_spots(bin_image)
    if use_com:
        spots_pos = []
        for spot in spots:
            x_mid, y_mid = spot.middle()
            # middle() returns (x, y) but com_circle unpacks its center as (y, x)
            refined = com_circle(image, (y_mid, x_mid), spot.radius())
            if refined is not None:
                spots_pos.append(refined)
    else:
        spots_pos = [spot.middle() for spot in spots]
    return {"spots" : spots_pos}

def guess_from_am(image, x_mid, y_mid, expansion_factor=4, **kwargs):
    """ Finds spots by lowering the threshold step by step.

    Starts with the regions above a high threshold and then lowers the
    threshold in steps down to the Otsu threshold. Every pixel that appears
    in a step is added to the nearest spot if its distance to the edge of
    that spot is below expansion_factor; otherwise it starts a new spot.

    image: 2d array indexed as image[y, x]
    x_mid, y_mid and additional keyword arguments are not used.

    returns: list of (x, y) positions of all spots with more than 2 pixels
    """
    otsu_thresh = otsu(image)
    max_image = image.max()
    increment = (max_image - otsu_thresh) * 0.1
    thresh = max_image - 5 * increment
    # every entry is [region, (x, y) middle, number of pixels]
    spots = [[spot, spot.middle(), len(spot)] for spot in find_spots(image > thresh)]
    while thresh > otsu_thresh:
        old_thresh = thresh
        thresh -= increment
        # pixels that became white in this step, in row-major order
        new_pixels = np.argwhere((image >= thresh) & (image < old_thresh))
        for y, x in new_pixels:
            y, x = int(y), int(x)
            distances = []
            for i, spot in enumerate(spots):
                # dist = dist_to_mid - r; the middle is stored as (x, y)
                dist = ((x - spot[1][0])**2 + (y - spot[1][1])**2)**0.5 - \
                        ((spot[2] / np.pi))**0.5
                distances.append((dist, i))
            # there may be no spot at all yet, then the pixel starts one
            if distances and min(distances)[0] < expansion_factor:
                index = min(distances)[1]
                spots[index][0].append((y, x))
                spots[index][1] = spots[index][0].middle()
                spots[index][2] += 1
            else:
                spots.append([ImageRegion([(y, x)]), (x, y), 1])
    spots_pos = [spot[1] for spot in spots if spot[2] > 2]
    return spots_pos

def guesser(npimage, x_in, y_in, radius, func = guess_from_spots, max_radius = 20, kalman = None, gamma = 6, smooth = True, default_cov=np.diag([2, 2])):
    """ Guesses the position of a spot in a window around x_in, y_in.

    npimage: complete image, indexed as npimage[y, x]
    x_in, y_in: expected position of the spot
    radius: half size of the window that is searched
    func: function doing the search in the window. It has to return None or
          a dict with a list of (x, y) positions under "spots" and
          optionally a list of their covariances under "covs".
    max_radius: currently not used
    kalman: filter used to rate the found spots by their
            measurement_distance; it is required although it defaults to None
    gamma: maximum measurement distance accepted (validation gate)
    smooth: apply a gaussian filter to the window before searching
    default_cov: covariance used for spots if func gives none

    returns: x, y, covariance of the best spot. If no spot is accepted the
             input position and a covariance of diag(1000, 1000) are returned.
    """
    def failure(reason):
        logger.info("no guess, because " + reason)
        return x_in, y_in, np.asmatrix(np.diag([1000, 1000]))

    try:
        x_min, x_max, y_min, y_max = adjust_slice(npimage, x_in - radius, x_in + radius + 1,
                                     y_in - radius, y_in + radius + 1)
    except IndexError:
        return failure("position outside image")

    image = npimage[y_min : y_max, x_min : x_max]
    if smooth:
        image = ndimage.gaussian_filter(image, 1, mode="nearest")
    result = func(image, x_mid = x_in - x_min, y_mid = y_in - y_min, size = radius * 0.8)
    if result is None:
        return failure("no bright spot found")
    spots = result["spots"]
    if len(spots) == 0:
        # e.g. no pixel above the threshold; min() below would fail
        return failure("no bright spot found")
    if "covs" in result:
        covs = result["covs"]
    else:
        covs = [default_cov] * len(spots)
    # window coordinates -> image coordinates
    spots = [(x_spot + x_min, y_spot + y_min) for x_spot, y_spot in spots]
    gate_cov = np.diag([0.1, 0.1])
    # plain floats, so that the comparison does not depend on matrix semantics
    distances = [np.asarray(kalman.measurement_distance(spot, gate_cov)).item()
                 for spot in spots]
    min_index = min(range(len(spots)), key=distances.__getitem__)
    if distances[min_index] > gamma:
        return failure("no spot in validation gate")
    x_res, y_res = spots[min_index]
    return x_res, y_res, covs[min_index]
    
def adjust_slice(image, x_sl_min, x_sl_max, y_sl_min, y_sl_max):
    """
    Converts the slice limits to int and clips them to the image.

    Returns (x_min, x_max, y_min, y_max). A warning is logged if a limit had
    to be clipped. Raises IndexError if the resulting slice is empty in x or
    y direction.

    >>> image = np.ones((2, 2))
    >>> adjust_slice(image, 0, 1.5, 0, 2)
    (0, 1, 0, 2)
    >>> adjust_slice(image, -5.5, 2, -0.5, 10)
    (0, 2, 0, 2)
    >>> adjust_slice(image, -5, -4, -0.5, 10)
    Traceback (most recent call last):
    ...
    IndexError
    """

    height, width = image.shape
    upper_limits = (width, width, height, height)
    requested = [int(x_sl_min), int(x_sl_max), int(y_sl_min), int(y_sl_max)]
    clipped = [min(max(value, 0), limit) for value, limit in zip(requested, upper_limits)]
    if clipped != requested:
        logger.warning("slice had to be adjusted to fit image.")
    x_lo, x_hi, y_lo, y_hi = clipped
    if x_lo == x_hi or y_lo == y_hi:
        raise IndexError()
    return tuple(clipped)
   
def calc_intensity(npimage, x, y, radius, background_substraction=True):
    """ Returns the summed intensity of the disc with the given radius
        around x, y.

        With background_substraction the background intensity times the
        pixel count reported by _integral_intensity is subtracted.
    """
    total, npixels = _integral_intensity(npimage, points_in_circle(radius), x, y)
    if not background_substraction:
        return total
    background = background_intensity(npimage, x, y, radius)
    return total - background * npixels

def background_intensity(npimage, x, y, radius):
    """ Returns trimmed mean of points at circle.

    The image values at the points given by points_at_circle(x, y, radius)
    are sorted; the lowest and the highest 5 % are dropped before the mean
    is calculated. Coordinates are truncated to int for indexing.
    """
    intensities = sorted(npimage[int(py)][int(px)]
                         for px, py in points_at_circle(x, y, radius))
    off = int(0.05 * len(intensities))
    # not [off:-off]: for off == 0 (less than 20 points) that slice is empty
    trimmed = intensities[off:len(intensities) - off]
    return np.mean(trimmed)

def _integral_intensity(npimage, pixels, x, y):
    """ Sums the values of the image at the given pixel increments.
        returns: summed intensity, number of pixels
    """
    intensity = 0
    for cnt, incs in enumerate(pixels):
        intensity += npimage[y + incs[1]][x + incs[0]]
    return intensity, cnt
##################

#### routines ####

def gaussian2d(height, center_x, center_y, width_x, width_y,
                offset=0, slope_x=0, slope_y = 0):
    """ Returns a function f(x, y) evaluating a two dimensional gaussian
        with the given parameters plus a constant offset.

        slope_x and slope_y are accepted but do not enter the result.
    """
    def evaluate(x, y):
        scaled_x = (center_x - x) / width_x
        scaled_y = (center_y - y) / width_y
        bell = height * np.exp(-(scaled_x**2 + scaled_y**2) / 2)
        return np.asarray(bell) + offset
    return evaluate

def moments(data):
    """ Estimates the parameters of a 2D gaussian from the moments of the
    given distribution.

    Returns [height, x, y, width_x, width_y], where x belongs to the first
    and y to the second axis of data and height is the maximum of data.
    """
    total = data.sum()
    first_axis, second_axis = np.indices(data.shape)
    x = (first_axis * data).sum() / total
    y = (second_axis * data).sum() / total
    # widths from the line through the center along each axis
    col = data[:, int(y)]
    col_offsets = np.arange(col.size) - y
    width_x = np.sqrt(abs(col_offsets**2 * col).sum() / col.sum())
    row = data[int(x), :]
    row_offsets = np.arange(row.size) - x
    width_y = np.sqrt(abs(row_offsets**2 * row).sum() / row.sum())
    return [data.max(), x, y, width_x, width_y]

def com_circle(npimage, mid, radius):
    """ Returns the center of mass of the image inside a circle.

    npimage: image indexed as npimage[y, x]; it is not modified
    mid: center of the circle as (y, x); it is rounded to the nearest pixel
    radius: the circle used has the radius radius + 1

    returns: result of com(), i.e. (x, y) or None
    """
    y, x = mid
    y = int(round(y))
    x = int(round(x))
    height, width = npimage.shape
    weights = np.zeros_like(npimage)
    for x_inc, y_inc in points_in_circle(radius + 1):
        row, col = y + y_inc, x + x_inc
        # explicit check: negative indices would not raise an IndexError but
        # wrap around to the opposite side of the image
        if 0 <= row < height and 0 <= col < width:
            weights[row, col] = 1
    # hand over the masked copy, so that com() cannot alter npimage
    return com(npimage * weights)

def com(npimage, weights=None):
    """ Returns the center of mass of the image as (x, y).

    npimage: image indexed as npimage[y, x]; it is not modified
    weights: optional array the image is multiplied with first

    returns: (x, y) or None (with a logged warning) if the sum of the
             weighted image is zero or not finite
    """
    data = np.asarray(npimage, dtype=float)
    if weights is not None:
        # new array instead of "*=": the image of the caller stays untouched
        data = data * weights
    total = data.sum()
    if total == 0 or not np.isfinite(total):
        logger.warning("com couldn't be calculated")
        return None
    ymax, xmax = data.shape
    x_res = (np.arange(xmax) * data.sum(axis=0)).sum() / total
    y_res = (np.arange(ymax) * data.sum(axis=1)).sum() / total
    return x_res, y_res

def gaussian1d(height, center_x, width_x, offset=0, slope=0):
    """ Returns a function f(x) evaluating a one dimensional gaussian with
        the given parameters on top of the linear background
        offset + slope * x.
    """
    def evaluate(x):
        scaled = (center_x - x) / width_x
        bell = height * np.exp(-(scaled**2) / 2)
        return bell + offset + np.asarray(x) * slope
    return evaluate

def find_spots(bin_image):
    """ Splits the true pixels of a binary image into regions.

    A pixel is taken from the pool of true pixels and _expand() collects the
    pixels belonging to it (flood filling over the 8 neighbours) until the
    pool is empty.

    returns: list of ImageRegions containing (y, x) pixel indices
    """
    # ndenumerate might cause crash (?!)
    remaining = set()
    for index, value in np.ndenumerate(bin_image):
        if value:
            remaining.add(index)

    regions = []
    while remaining:
        seed = remaining.pop()
        members = _expand(seed, remaining)
        members.add(seed)
        regions.append(ImageRegion(members))
    return regions

def _expand(pixel, white_pixels):
    y, x = pixel
    next_pixels = set((y + i, x + j) for i in [-1, 0, 1] for j in [-1, 0, 1])
    next_pixels = next_pixels.intersection(white_pixels)
    white_pixels -= next_pixels
    if not white_pixels:
        return next_pixels
    if not next_pixels:
        return set([pixel])
    point = next_pixels.copy()
    for next_pixel in next_pixels:
        point.update(_expand(next_pixel, white_pixels))
    return point

def otsu(image):
    """ Returns the threshold of the image according to Otsu's method.

    image: array for which the threshold is to be found.

    The values are sorted into 300 bins; the returned threshold is the
    middle of the bin with the biggest between-class variance (the last one
    if there are several).

    C++ code by Jordan Bevik <Jordan.Bevic@qtiworld.com>
    ported to ImageJ plugin by G.Landini
    ported to python by A.Mayer
    """
    nbins = 300
    hist, bin_edges = np.histogram(image, bins = nbins)
    bin_mid = bin_edges[:-1] + 0.5 * np.diff(bin_edges)
    n_total = image.size
    # total intensity of the image
    intensity_total = sum(bin_mid*hist)
    # running intensity and number of points of all bins <= k
    intensity_below = bin_mid[0] * hist[0]
    n_below = hist[0]
    best_bcv = 0
    best_k = 0

    # try every bin but the first and the last one as threshold
    for k in range(1, nbins-1):
        intensity_below += bin_mid[k] * hist[k]
        n_below += hist[k]
        denom = n_below * (n_total - n_below)
        # between class variance; 0 if one of the classes is empty
        bcv = 0
        if denom:
            num = (n_below / n_total) * intensity_total - intensity_below
            bcv = (num * num) / denom
        if bcv >= best_bcv:
            best_bcv = bcv
            best_k = k
    return bin_mid[best_k]
##################

if __name__ == "__main__":
    # executed as script: run the doctests of this module
    import doctest
    doctest.testmod()
