import sys
import logging

from PyQt4.QtCore import (QPoint, QRectF, QPointF, Qt, SIGNAL, QTimer, QObject)
from PyQt4.QtGui import (QApplication, QMainWindow, QGraphicsView,
    QGraphicsScene, QImage, QWidget, QHBoxLayout, QPen,
    QVBoxLayout, QPushButton, QGraphicsEllipseItem, QGraphicsItem,
    QPainter, QKeySequence, QAction, QIcon, QFileDialog, QProgressBar,
    QBrush, QFrame, QLabel, QRadioButton, QGridLayout)
import numpy as np

from . import config
from . import __version__
from base import *
from io import *

##H #
import matplotlib.pyplot as plt
import pickle
from . import flatten as flt
#####



# Configure the root logger once at import time; the log file name and the
# log level both come from the package's config module.
logging.basicConfig(filename = config.loggingFilename, level=config.loggingLevel)

class QGraphicsSpotView(QGraphicsEllipseItem):
    """ Circular QGraphicsItem that marks a spot on a QGraphicsScene.

        The circle is selectable, movable and focusable, so it can be
        dragged with the mouse or moved and resized with the keyboard.

    """

    def __init__(self, point, radius, parent=None):
        # The ellipse rect is centred on the item's own origin, therefore the
        # item position is the centre of the circle.
        corner = QPointF(radius, radius)
        bounds = QRectF(-corner, corner)
        super(QGraphicsSpotView, self).__init__(bounds, parent)
        self.setPen(QPen(Qt.blue))
        self.setPos(point)
        flags = (QGraphicsItem.ItemIsSelectable
                 | QGraphicsItem.ItemIsMovable
                 | QGraphicsItem.ItemIsFocusable)
        self.setFlags(flags)

    def keyPressEvent(self, event):
        """ Reacts to key presses while the circle has focus.

            Plus / minus grow / shrink the circle. The arrow keys move it;
            holding Shift selects the small step instead of the big one.
            Any other key is left untouched.
        """

        pressed = event.key()
        if pressed == Qt.Key_Plus:
            self.changeSize(config.QGraphicsSpotView_spotSizeChange)
            return
        if pressed == Qt.Key_Minus:
            self.changeSize(-config.QGraphicsSpotView_spotSizeChange)
            return

        movers = ((Qt.Key_Right, self.moveRight),
                  (Qt.Key_Left, self.moveLeft),
                  (Qt.Key_Up, self.moveUp),
                  (Qt.Key_Down, self.moveDown))
        for arrow, move in movers:
            if pressed != arrow:
                continue
            if event.modifiers() & Qt.ShiftModifier:
                move(config.QGraphicsSpotView_smallMove)
            else:
                move(config.QGraphicsSpotView_bigMove)
            return

    def onPositionChange(self, point):
        """ Slot: moves the circle to the requested position."""
        self.setPos(point)

    def radius(self):
        """ Returns half of the current rect width."""
        return 0.5 * self.rect().width()

    def onRadiusChange(self, radius):
        """ Slot: resizes the circle towards the requested radius.

            NOTE(review): changeSize() scales its argument by 1/sqrt(2), so
            the resulting radius is not exactly `radius` -- confirm intended.
        """
        self.changeSize(radius - self.radius())

    def moveRight(self, distance):
        """ Shifts the circle to the right by distance."""
        self.moveBy(distance, 0.0)

    def moveLeft(self, distance):
        """ Shifts the circle to the left by distance."""
        self.moveBy(-distance, 0.0)

    def moveUp(self, distance):
        """ Shifts the circle up by distance (towards smaller y)."""
        self.moveBy(0.0, -distance)

    def moveDown(self, distance):
        """ Shifts the circle down by distance (towards larger y)."""
        self.moveBy(0.0, distance)

    def changeSize(self, inc):
        """ Grows (inc > 0) or shrinks (inc < 0) the circle.

            inc is divided by sqrt(2) and the result is applied to every
            side of the bounding rect.
        """

        delta = inc / 2 ** 0.5
        resized = self.rect().adjusted(-delta, -delta, delta, delta)
        self.setRect(resized)

class QSpotModel(QObject):
    """
    QObject wrapper around a SpotModel that announces updates via signals.

    Signals emitted by update(), in this order:
    - positionChanged (carries a QPointF)
    - radiusChanged
    - intensityChanged
    """

    def __init__(self, parent = None):
        super(QSpotModel, self).__init__(parent)
        # The wrapped model holding the actual data.
        self.m = SpotModel()

    def update(self, x, y, intensity, energy, radius):
        """ Forwards the new values to the wrapped model, then notifies listeners. """
        self.m.update(x, y, intensity, energy, radius)
        notifications = (
            ("positionChanged", QPointF(x, y)),
            ("radiusChanged", radius),
            ("intensityChanged", intensity),
        )
        for signal_name, payload in notifications:
            self.emit(SIGNAL(signal_name), payload)

class GraphicsScene(QGraphicsScene):
    """ Scene holding the background image and the spot circles."""

    def __init__(self, parent=None):
        super(GraphicsScene, self).__init__(parent)

    def mousePressEvent(self, event):
        """ Handles a mouse press.

            A press on an existing item is passed on to the default handler.
            A left-click on empty space creates a new spot there, which
            becomes the only selected item and receives focus.
            Other presses on empty space are dropped.
        """

        position = event.scenePos()
        if self.itemAt(position):
            super(GraphicsScene, self).mousePressEvent(event)
            return
        if event.button() != Qt.LeftButton:
            return

        spot = QGraphicsSpotView(position, config.GraphicsScene_defaultRadius)
        self.clearSelection()
        self.addItem(spot)
        spot.setSelected(True)
        self.setFocusItem(spot)

    def keyPressEvent(self, event):
        """ Handles a key press.

            Without a focus item nothing happens. With one, Delete removes
            it from the scene; every other key goes to the default handler.
        """

        focused = self.focusItem()
        if not focused:
            return
        if event.key() == Qt.Key_Delete:
            self.removeItem(focused)
            return
        super(GraphicsScene, self).keyPressEvent(event)

    def drawBackground(self, painter, rect):
        """ Paints the background image at the origin, if one has been set. """
        try:
            background = self.image
        except AttributeError:
            # setBackground() has not been called yet.
            return
        painter.drawImage(QPoint(0, 0), background)

    def setBackground(self, image):
        """ Stores the background image and schedules a repaint. """
        self.image = image
        self.update()

    def removeAll(self):
        """ Removes every item from the scene; the background image stays. """
        for spot in self.items():
            self.removeItem(spot)

class GraphicsView(QGraphicsView):
    """ View that keeps the whole scene rect visible on a light gray backdrop. """
    def __init__(self, parent=None):
        super(GraphicsView, self).__init__(parent)
        self.setRenderHints(QPainter.Antialiasing)

    # Mouse-wheel zooming is currently disabled:
#    def wheelEvent(self, event):
#        factor = 1.41 ** (-event.delta() / 240.0)
#        self.scale(factor, factor)

    def resizeEvent(self, event):
        """ Rescales the view so the scene rect fits, keeping the aspect ratio. """
        visible_area = self.sceneRect()
        self.fitInView(visible_area, Qt.KeepAspectRatio)

    def drawBackground(self, painter, rect):
        """ Fills the exposed area with gray, then lets the scene paint on top. """
        backdrop = QBrush(Qt.lightGray)
        painter.fillRect(rect, backdrop)
        scene = self.scene()
        scene.drawBackground(painter, rect)

class FileDialog(QFileDialog):
    """ QFileDialog preconfigured with the ExistingFiles file mode. """
    def __init__(self, **kwargs):
        # All keyword arguments are handed to QFileDialog unchanged.
        QFileDialog.__init__(self, **kwargs)
        self.setFileMode(QFileDialog.ExistingFiles)


############## H ###########
class PlotWidget(QWidget):
    """ Small window for choosing a plotting mode.

        Exposes three radio buttons (rbutton1..3), a message label and the
        Accept / Cancel push buttons (pbutton1 / pbutton2). The buttons are
        not connected here.
    """

    def __init__(self):
        super(PlotWidget, self).__init__()
        self.initUI()

    def initUI(self):
        """ Creates the child widgets and arranges them in a grid. """
        self.setGeometry(300, 300, 300, 150)
        self.setWindowTitle('Select plotting method')

        self.gridLayout = QGridLayout()
        self.setLayout(self.gridLayout)

        # Top cell: the mode choices stacked vertically, message label last.
        self.rbutton1 = QRadioButton('Plot &intensities', self)
        self.rbutton2 = QRadioButton('Plot a&verage', self)
        self.rbutton3 = QRadioButton('&Plot intensities && average', self)
        self.verticalLayout = QVBoxLayout()
        for choice in (self.rbutton1, self.rbutton2, self.rbutton3):
            self.verticalLayout.addWidget(choice)
        self.gridLayout.addLayout(self.verticalLayout, 0, 0, 1, 1)
        self.label = QLabel("", self)
        self.verticalLayout.addWidget(self.label)

        # Bottom cell: Accept / Cancel side by side.
        self.hLayout = QHBoxLayout()
        self.pbutton1 = QPushButton('&Accept', self)
        self.pbutton2 = QPushButton('&Cancel', self)
        for button in (self.pbutton1, self.pbutton2):
            self.hLayout.addWidget(button)
        self.gridLayout.addLayout(self.hLayout, 1, 0, 1, 1)

##############

class MainWindow(QMainWindow):
    """ easyLeed's main window. """
    def __init__(self, parent=None):
        """ Build the main window: central view, actions, menus, tool bar and status bar. """
        super(MainWindow, self).__init__(parent)
        self.setWindowTitle("easyLeed %s" % __version__)

        #### setup central widget ####
        # The plot-mode chooser is created up front and only shown on demand
        # (see plottingOptions).
        self.plotwid = PlotWidget()
        self.scene = GraphicsScene(self)
        self.view = GraphicsView()
        self.view.setScene(self.scene)
        self.view.setMinimumSize(640, 480)
        self.setCentralWidget(self.view)
        
        #### define actions ####
        processRunAction = self.createAction("&Run", self.run,
                QKeySequence("Ctrl+r"), None,
                "Run the analysis of the images.")
        processRestartAction = self.createAction("&Restart", self.restart,
                QKeySequence("Ctrl+z"), None,
                "Reset chosen points and jump to first image.")
        processNextAction = self.createAction("&Next Image", self.next_,
                QKeySequence("Ctrl+n"), None,
                "Open next image.")
        processPreviousAction = self.createAction("&Previous Image", self.previous,
                QKeySequence("Ctrl+p"), None,
                "Open previous image.")
                
## H #
        # actions to "Process" menu
        processPlotAction = self.createAction("&Plot", self.plotting, QKeySequence("Ctrl+d"), None, "Plot the energy/intensity.")
        processPlotAverageAction = self.createAction("&Plot average", self.plottingAverage, QKeySequence("Ctrl+g"), None, "Plot the energy/intensity average.")
        # still needs work
        #processSpotsAction = self.createAction("&Process Spots", self.processSpots, QKeySequence("Ctrl+"), None, "Process spots.")
        # NOTE: this action is only placed on the tool bar; it is not part of
        # self.processActions, so enableProcessActions() never toggles it.
        processPlotOptions = self.createAction("&Plot...", self.plottingOptions, None, None, "Choose plotting method.")

######

        # None entries become separators (see addActions).
        self.processActions = [processNextAction, processPreviousAction, None, processRunAction, processRestartAction, None, processPlotAction, None, processPlotAverageAction]
        fileOpenAction = self.createAction("&Open...", self.fileOpen,
                QKeySequence.Open, None,
                "Open a directory containing the image files.")
        self.fileSaveAction = self.createAction("&Save intensities...", self.saveIntensity,
                QKeySequence.Save, None,
                "Save the calculated intensities to a text file.")
                
## H #
        # actions to "File" menu
        self.fileSavePlotAction = self.createAction("&Save plot...", self.savePlot, QKeySequence("Ctrl+a"), None, "Save the plot to a pdf file.")
        # Plot saving is only enabled once there is a plot to be saved
        self.fileSavePlotAction.setEnabled(False)
        self.fileQuitAction = self.createAction("&Quit", self.fileQuit, QKeySequence("Ctrl+q"), None, "Close the application.")
        self.fileSaveSpotsAction = self.createAction("&Save spot locations...", self.saveSpots, QKeySequence("Ctrl+t"), None, "Save the spots to a file.")
        # Enabled once there is data to be saved (see run)
        self.fileSaveSpotsAction.setEnabled(False)
        self.fileLoadSpotsAction = self.createAction("&Load spot locations...", self.loadSpots, QKeySequence("Ctrl+l"), None, "Load spots from a file.")
######
        
                
        
        self.fileActions = [fileOpenAction, self.fileSaveAction, self.fileSavePlotAction, self.fileSaveSpotsAction, self.fileLoadSpotsAction, None, self.fileQuitAction]

        #### Create menu bar ####
        fileMenu = self.menuBar().addMenu("&File")
        # Saving intensities is enabled by run(), once a worker exists.
        self.fileSaveAction.setEnabled(False)
        self.addActions(fileMenu, self.fileActions)
        processMenu = self.menuBar().addMenu("&Process")
        self.addActions(processMenu, self.processActions)
        # Process actions stay disabled until images have been opened (see fileOpen).
        self.enableProcessActions(False)

        #### Create tool bar ####
        toolBar = self.addToolBar("&Toolbar")
        # adding actions to the toolbar; addActions() creates a separator for every None
        self.toolBarActions = [self.fileQuitAction, None, fileOpenAction, None, processRunAction, None, processPreviousAction, None, processNextAction, None, processPlotOptions, None, None, processRestartAction]
        self.addActions(toolBar, self.toolBarActions)

        #### Create status bar ####
        self.statusBar().showMessage("Ready", 5000)
        # Permanent label showing the energy of the displayed image (see setImage).
        self.energyLabel = QLabel()
        self.energyLabel.setFrameStyle(QFrame.StyledPanel | QFrame.Sunken)
        self.statusBar().addPermanentWidget(self.energyLabel)
  
    def addActions(self, target, actions):
        """
        Populate target (e.g. a menu or tool bar) with the given actions.
        Every None entry is turned into a separator.

        """
        for entry in actions:
            if entry is not None:
                target.addAction(entry)
                continue
            target.addSeparator()
    
    def createAction(self, text, slot=None, shortcut=None, icon=None,
                     tip=None, checkable=False, signal="triggered()"):
        """ Build a QAction owned by this window; only the given options are applied.

        icon is a resource base name (looked up as ":/<icon>.png"), tip is
        used as both tool tip and status tip, and slot is connected to signal.
        """
        new_action = QAction(text, self)
        if icon is not None:
            icon_path = ":/{0}.png".format(icon)
            new_action.setIcon(QIcon(icon_path))
        if shortcut is not None:
            new_action.setShortcut(shortcut)
        if tip is not None:
            for apply_tip in (new_action.setToolTip, new_action.setStatusTip):
                apply_tip(tip)
        if slot is not None:
            self.connect(new_action, SIGNAL(signal), slot)
        if checkable:
            new_action.setCheckable(True)
        return new_action

    def enableProcessActions(self, enable):
        """ Enable or disable every action of the "Process" menu. """
        for entry in self.processActions:
            if not entry:
                # Separator placeholder (None).
                continue
            entry.setEnabled(enable)

    def next_(self):
        """ Show the next image, or a status message if the loader is exhausted. """
        try:
            following = self.loader.next()
        except StopIteration:
            self.statusBar().showMessage("Reached last picture", 5000)
            return
        self.setImage(following)

    def previous(self):
        """ Show the previous image, or a status message if there is none. """
        try:
            preceding = self.loader.previous()
        except StopIteration:
            self.statusBar().showMessage("Reached first picture", 5000)
            return
        self.setImage(preceding)

    def restart(self):
        """ Remove all spots, rewind the loader, show its next image and close the plot. """
        self.scene.removeAll()
        loader = self.loader
        loader.restart()
        first_image = loader.next()
        self.setImage(first_image)
        plt.close()

    def setImage(self, image):
        """ Display an (image array, energy) pair and remember its energy. """
        pixels, energy = image
        background = npimage2qimage(pixels)
        # The scene rect follows the image size.
        self.view.setSceneRect(QRectF(background.rect()))
        self.scene.setBackground(background)
        self.current_energy = energy
        self.energyLabel.setText("Energy %s eV" % self.current_energy)

    def saveIntensity(self):
        """ Ask for a file name and let the worker save its results there. """
        chosen = QFileDialog.getSaveFileName(self, "Save intensities to a file")
        filename = str(chosen)
        if not filename:
            # Dialog was cancelled.
            return
        self.worker.save(filename)

    def fileOpen(self):
        """ Prompts the user to select input image files.

        The scene is cleared before the dialog opens. If the dialog is
        accepted, a loader matching the selected name filter is created,
        its first image is shown and the process actions are enabled.
        A loading failure is logged and reported in the status bar.
        """
        # NOTE(review): the spots are removed even if the dialog is cancelled
        # afterwards -- confirm that this is intended.
        self.scene.removeAll()
        dialog = FileDialog(parent = self,
                caption = "Choose image files", filter= ";;".join(IMAGE_FORMATS))
        if not dialog.exec_():
            return

        files = dialog.selectedFiles()
        filetype = IMAGE_FORMATS[str(dialog.selectedNameFilter())]
        files = [str(file_) for file_ in files]
        try:
            self.loader = filetype.loader(files, regex=config.IO_energyRegex)
            self.setImage(self.loader.next())
            self.enableProcessActions(True)
        except Exception:
            # Catch Exception rather than everything, so KeyboardInterrupt and
            # SystemExit still propagate; log the traceback so the real cause
            # is not lost behind the generic status message.
            logging.exception("Could not load the selected image files")
            self.statusBar().showMessage("Invalid filename. Check naming policy.", 5000)
            
    def stopProcessing(self):
        """ Slot of the "Stop" button: asks the loop in run() to stop before the next image. """
        self.stopped = True

    def run(self):
        """ Track all spots of the scene through the images left in the loader.

        A progress bar and a "Stop" button are shown in the status bar while
        the images are processed. The view is non-interactive during the run.
        The elapsed time in seconds is printed once processing has finished.
        """
        import time
        time_before = time.time()

        self.stopped = False
        progress = QProgressBar()
        stop = QPushButton("Stop", self)
        self.connect(stop, SIGNAL("clicked()"), self.stopProcessing)
        # Progress is measured in energy, from the current to the last image.
        progress.setMinimum(int(self.loader.current_energy()))
        progress.setMaximum(int(self.loader.energies[-1]))
        statusLayout = QHBoxLayout()
        statusLayout.addWidget(progress)
        statusLayout.addWidget(stop)
        statusWidget = QWidget(self)
        statusWidget.setLayout(statusLayout)
        self.statusBar().addWidget(statusWidget)
        self.view.setInteractive(False)
        self.scene.clearSelection()
        try:
            self.worker = Worker(self.scene.items(), self.current_energy, parent=self)
            self.fileSaveAction.setEnabled(True)
            self.fileSaveSpotsAction.setEnabled(True)
            for image in self.loader:
                if self.stopped:
                    break
                progress.setValue(int(image[1]))
                # Keep the GUI responsive (e.g. for the Stop button).
                QApplication.processEvents()
                self.setImage(image)
                self.worker.process(image)
                QApplication.processEvents()
        finally:
            # Restore the GUI even if processing an image raised.
            self.view.setInteractive(True)
            self.statusBar().removeWidget(statusWidget)
            # removeWidget() only hides the widget; delete it so repeated
            # runs do not pile up hidden progress widgets.
            statusWidget.deleteLater()

        print(time.time() - time_before)

    def disableInput(self):
        """ Make every scene item unselectable, unfocusable and immovable. """
        locked_flags = (QGraphicsItem.ItemIsSelectable,
                        QGraphicsItem.ItemIsFocusable,
                        QGraphicsItem.ItemIsMovable)
        for spot in self.scene.items():
            for flag in locked_flags:
                spot.setFlag(flag, False)

##H #
	## Plotting with matplotlib ##

    def plotting(self):
        """ Plot one I(E) curve per spot with matplotlib.

        Shows "No plottable data." in the status bar if no worker (and hence
        no data) exists yet. Enables the "Save plot" action afterwards.
        """
        # Only the data lookup is guarded: a missing worker means there is
        # nothing to plot, while errors raised by matplotlib should surface.
        try:
            curves = [(model.m.energy, model.m.intensity)
                      for model, tracker in self.worker.spots_map.values()]
        except AttributeError:
            self.statusBar().showMessage("No plottable data.", 5000)
            return

        plt.ioff()
        # setting the axis labels
        plt.xlabel("Energy [eV]")
        plt.ylabel("Intensity")
        plt.title("I(E)-curve")
        # removes the ticks from y-axis
        plt.gca().get_yaxis().set_ticks([])
        plt.ion()
        # Each spot's intensities are plotted against that spot's own energies
        # (pairing every energy list with every intensity list would draw
        # each curve once per spot).
        for energy, intensity in curves:
            plt.plot(energy, intensity)
        # and show it
        plt.draw()
        self.fileSavePlotAction.setEnabled(True)

    def plottingAverage(self):
        """ Plot the intensity averaged over all spots against the energy.

        Shows "No plottable data." in the status bar if no worker exists yet
        or if the worker tracks no spots. Enables the "Save plot" action
        afterwards, since there is now a plot to save.
        """
        try:
            spots = [(model.m.energy, model.m.intensity)
                     for model, tracker in self.worker.spots_map.values()]
        except AttributeError:
            spots = []
        if not spots:
            # No worker yet, or a run without any spots.
            self.statusBar().showMessage("No plottable data.", 5000)
            return

        # The energy axis of the first spot is used for the averaged curve.
        energy = spots[0][0]
        intensities = [intensity for _, intensity in spots]
        # zip(*...) groups the values of all spots image by image; float()
        # forces true division under Python 2 even for integer intensities.
        list_of_average_intensities = [sum(per_image) / float(len(per_image))
                                       for per_image in zip(*intensities)]

        plt.ioff()
        # setting the axis labels
        plt.xlabel("Energy [eV]")
        plt.ylabel("Intensity")
        plt.title("I(E)-curve")
        # removes the ticks from y-axis
        plt.gca().get_yaxis().set_ticks([])
        plt.ion()
        plt.plot(energy, list_of_average_intensities,'k-', linewidth=3, label = 'Average')
        plt.draw()
        self.fileSavePlotAction.setEnabled(True)

    def plottingOptions(self):
        """ Show the plot-mode chooser window. """
        # Connect the buttons only once: Qt keeps duplicate connections, so
        # reconnecting on every call would run PlotWidgetAccept once per
        # previous invocation of this method.
        if not getattr(self, "_plotOptionsConnected", False):
            QObject.connect(self.plotwid.pbutton1, SIGNAL("clicked()"), self.PlotWidgetAccept)
            QObject.connect(self.plotwid.pbutton2, SIGNAL("clicked()"), self.plotwid.close)
            self._plotOptionsConnected = True

        self.plotwid.show()

    def PlotWidgetAccept(self):
        """ Slot of the chooser's Accept button: draw the selected kind of plot.

        Clears the current figure, draws the chosen plot(s) and closes the
        chooser. If no mode is checked, a hint is shown in the chooser instead.
        """
        chooser = self.plotwid
        if chooser.rbutton1.isChecked():
            renderers = (self.plotting,)
        elif chooser.rbutton2.isChecked():
            renderers = (self.plottingAverage,)
        elif chooser.rbutton3.isChecked():
            renderers = (self.plotting, self.plottingAverage)
        else:
            chooser.label.setText('Check a mode')
            return

        plt.clf()
        for render in renderers:
            render()
        chooser.close()


#           
#
#
	
	## Saving the plot

    def savePlot(self):
        """ Ask for a file name and save the current matplotlib figure as PDF. """
        # savefile prompt
        filename = str(QFileDialog.getSaveFileName(self, "Save the plot to a file"))
        if filename:
            # matplotlib.pyplot's savefig; the format is always pdf,
            # whatever extension the chosen file name has
            plt.savefig(filename, format="pdf")

    ## Quitting the application
    # Dedicated quit function: closing only the main window might leave something running in the background.
    def fileQuit(self):
        """ Close all application windows, then call plt.close() on the plot. """
        QApplication.closeAllWindows()
        plt.close()

    ## Some spot controlling
    # saves the spot locations to a file, uses workers saveloc-function
    def saveSpots(self):
        """ Ask for a file name and let the worker write the spot locations to it. """
        chosen = QFileDialog.getSaveFileName(self, "Save the spot locations to a file")
        filename = str(chosen)
        if not filename:
            # Dialog was cancelled.
            return
        self.worker.saveloc(filename)

    # loads the spots from a file
    def loadSpots(self):
        """ Load spot locations from a pickle file and add them to the scene.

        The file is expected to hold one (energy, x, y, radius) record per
        spot, each field being a sequence over the processed images (the
        layout written by Worker.saveloc). Only the first entry of every
        sequence is used, i.e. all spots are placed where they were at the
        first energy. The spot added last ends up selected and focused.
        """
        filename = QFileDialog.getOpenFileName(self, 'Open spot location file')
        if not filename:
            return

        # SECURITY: unpickling can execute arbitrary code. Only open spot
        # files that come from a trusted source.
        # The with-statement closes the file even if unpickling fails.
        with open(filename, 'rb') as pkl_file:
            location = pickle.load(pkl_file)

        # TODO: restore the spots at their respective energies; this might
        # involve modifying the algorithm for calculating the intensity.
        # Iterating the records directly also copes with an empty file.
        for energy, locationx, locationy, radius in location:
            point = QPointF(locationx[0], locationy[0])
            item = QGraphicsSpotView(point, radius[0])
            # adding the item to the gui
            self.scene.clearSelection()
            self.scene.addItem(item)
            item.setSelected(True)
            self.scene.setFocusItem(item)
            
            
            

    ## Controlling the spots
    # Doesn't do much yet: it only enumerates the scene items (circles) in the GUI, as a first test for naming the spots.
    def processSpots(self):
        """ Create a SpotControl for the current scene items, numbering from 0. """
        spots = self.scene.items()
        self.control = SpotControl(spots, 0)
      

##### H #
# Experimental: the menu action that would use this class is still commented out.
class SpotControl(QObject):
    """Class that manages the spots.

    At the moment it only assigns consecutive integer names to the given
    spots, starting at i, and prints them.
    """

    def __init__(self, spots, i):
        # A QObject subclass must initialise its base class before any
        # QObject functionality is used on the instance.
        super(SpotControl, self).__init__()
        # One consecutive number per spot, starting at i.
        self.names = [i + offset for offset, _ in enumerate(spots)]
        #text = "%s" % (self.names[i])
        #textItem = QGraphicsSimpleTextItem(self, text, parent = None, scene = scene.GraphicsView)
        print(self.names)

#######


class Worker(QObject):
    """ Worker that manages the spots.

    Creates one (QSpotModel, Tracker) pair per spot view, feeds images to
    the trackers and writes the collected results to disk.
    """

    def __init__(self, spots, energy, parent=None):
        """ Set up a model and a tracker for every spot view in spots.

        spots  -- spot views providing scenePos(), radius() and the
                  onPositionChange / onRadiusChange slots
        energy -- energy handed to every tracker on creation
        """
        super(Worker, self).__init__(parent)
        # Maps spot view -> (QSpotModel, tracker).
        self.spots_map = {}
        for spot in spots:
            pos = spot.scenePos()
            tracker = Tracker(pos.x(), pos.y(), spot.radius(), energy,
                        input_precision = config.Tracking_inputPrecision,
                        window_scaling = config.Tracking_windowScalingOn)
            # Alternative tracker, kept for reference:
#            tracker = TrackerPhysics(pos.x(), pos.y(), 217, 218, spot.radius(), energy)
            model = QSpotModel(self)
            # The view follows the position and radius reported by its model.
            self.connect(model, SIGNAL("positionChanged"), spot.onPositionChange)
            self.connect(model, SIGNAL("radiusChanged"), spot.onRadiusChange)
            self.spots_map[spot] = (model, tracker)

    def process(self, image):
        """ Feed image to every tracker and store the result in its model. """
        for model, tracker in self.spots_map.values():
            # feed_image returns x, y, intensity, energy and radius
            model.update(*tracker.feed_image(image))

    def save(self, filename):
        """ Write intensities to filename + ".int" and positions to filename + ".pos".

        Every row starts with the energy (taken from the first spot). The
        ".int" rows continue with one intensity per spot, the ".pos" rows
        with all x values followed by all y values.
        Raises IndexError if there are no spots.
        """
        models = [model.m for model, tracker in self.spots_map.values()]
        energy = models[0].energy
        intensities = [spot.intensity for spot in models]
        np.savetxt(filename + ".int", list(zip(energy, *intensities)))
        coordinates = [spot.x for spot in models] + [spot.y for spot in models]
        np.savetxt(filename + ".pos", list(zip(energy, *coordinates)))

##### H #####
    def saveloc(self, filename):
        """ Pickle one (energy, x, y, radius) record per spot to filename. """
        records = [(model.m.energy, model.m.x, model.m.y, model.m.radius)
                   for model, tracker in self.spots_map.values()]
        # The with-statement closes the file even if pickling fails.
        with open(filename, 'wb') as output:
            pickle.dump(records, output)
    
