#!/usr/bin/env python3
# -*- Mode: Python -*-
#------------------------------------------------------------------------------#
#  DFTB+: general package for performing fast atomistic simulations            #
#  Copyright (C) 2006 - 2023  DFTB+ developers group                           #
#                                                                              #
#  See the LICENSE file for terms of usage and distribution.                   #
#------------------------------------------------------------------------------#
#
############################################################################
#
#  tagdiff -- small utility to compare numerical results of calculations
#
############################################################################
#
#  (The tagged data is assumed to be represented in the form as provided by
#  the taggedout module of the DFTB project. See the appropriate source code
#  for the details. For the format of the config file see the provided
#  commented sample config file.)
#
############################################################################
from __future__ import print_function
import sys
import os.path

# Oldest interpreter (as sys.hexversion value) this script accepts
REQUIRED_VERSION = 0x2060000
if not sys.hexversion >= REQUIRED_VERSION:
    # Report on stderr and leave with a non-zero status
    print("This script requires Python 2.6 or higher", file=sys.stderr)
    sys.exit(1)

import re
import gzip
from argparse import ArgumentParser
from tagreader import InvalidEntry, ConversionError, IntConverter
from tagreader import InvalidEntry, TaggedCollection, ResultParser
from uncommlines import UncommLines


# Version number reported by the --version option
VERSION = 0.2

# Text shown as description in the command line help
DESCRIPTION = ("Compares two tagged output files according the tolerances\n"
               "in config file(s).\n")

# Possible outcomes of comparing a single entry
RES_OK, RES_SKIPPED, RES_FAILED, RES_ERROR = range(4)

# Exit codes of the script
EXIT_OK, EXIT_ERROR = 0, 1

# Name of the config file used when none is given on the command line
DEFAULT_CONFIG = "tagdiff.conf"

############################################################################
# Exceptions
############################################################################

class DiffError(Exception):
    """Signals that the difference of two entries could not be built.

    Raised by the difference functors; carries a human readable reason as
    its message.
    """


############################################################################
# Difference builder functors
############################################################################

class Diff(object):
    """Base class of the difference builder functors.

    An instance is called with the original and the new entry and returns
    their difference. The string form of an instance is the name of the
    comparison method.
    """

    def __call__(self, orig, new):
        """Placeholder comparison: ignores both entries and returns None.

        orig -- original entry
        new  -- new entry
        """
        return None

    def __str__(self):
        """Returns the name of the comparison method."""
        return "general"



class DiffElement(Diff):
    """Difference functor returning the largest element-wise deviation."""

    def __call__(self, orig, new):
        """Returns the maximum of |orig - new| over the paired elements.

        orig -- original entry (its value attribute is iterated)
        new  -- new entry (its value attribute is iterated)

        Any exception occurring during the calculation (e.g. no elements at
        all) is converted into a DiffError.
        """
        try:
            return max(abs(first - second)
                       for first, second in zip(orig.value, new.value))
        except Exception as exc:
            reason = str(exc)
            raise DiffError("Exception (%s) while building difference"
                            % reason)

    def __str__(self):
        """Returns the name of the comparison method."""
        return "element"



class DiffVector(Diff):
    """Calculates the difference as the max. difference in vector norm.

    The value sequences of both entries are cut into consecutive groups of
    nElement items. For every group the Euclidean norm of the difference is
    formed; the largest of these norms is the result.
    """

    def __init__(self, nElement):
        """nElement -- number of items forming one vector; if -1, the first
                       item of the original entry's shape is used instead
        """
        self.__nElement = nElement


    def __call__(self, orig, new):
        """Returns the largest norm of the group-wise differences.

        orig -- original entry (its value and, for nElement == -1, its
                shape attribute are read)
        new  -- new entry (its value attribute is read)

        Raises DiffError if the group size is not positive, if the two
        value sequences differ in length, if the length is not a multiple
        of the group size or if the calculation itself fails.
        """
        if self.__nElement == -1:
            nElement = orig.shape[0]
        else:
            nElement = self.__nElement

        # A group size of zero would make the modulo below fail with an
        # uncaught ZeroDivisionError, a negative one would yield no groups
        # at all; report both as a regular difference building error.
        if nElement <= 0:
            raise DiffError("Invalid nr. of elements")
        if (len(orig.value) != len(new.value)
            or len(orig.value) % nElement != 0):
            raise DiffError("Invalid nr. of elements")

        try:
            # Squared norm of the difference for every group of nElement items
            diff2 = []
            for ii in range(0, len(orig.value), nElement):
                origvals = orig.value[ii : ii + nElement]
                newvals = new.value[ii : ii + nElement]
                diff2.append(sum(abs(x - y)**2
                                 for x, y in zip(origvals, newvals)))
            # max() fails for empty value sequences, handled below
            maxdiff = max(diff2)**0.5
        except Exception as ee:
            raise DiffError("Exception (%s) while building difference"
                            % str(ee))

        return maxdiff


    def __str__(self):
        """Returns the method name followed by the configured group size."""
        return "vector:%d" % self.__nElement



############################################################################
# Tolerance related objects
############################################################################

class ToleranceEntry(object):
    """Contains tolerance related data"""

    # Valid comparison functors with converters to convert strings to the
    # type of their initialization arguments
    __compFuncs = { "element": (DiffElement, ()),
                    "vector":    (DiffVector, (IntConverter(nolist=True),))
    }

    def __init__(self, pattern, value, compFuncName, compFuncArgs, keep):
        """pattern        -- regular expression
             value        -- tolerance for quantities matching pattern
             compFuncName -- comparison function's name
             compFuncArgs -- arguments to the comparison function
             keep         -- should processed entry keept after processing

        All arguments are passed as strings (compFuncArgs as a sequence of
        strings). Raises InvalidEntry if the pattern, the tolerance or the
        comparison function specification can not be interpreted.
        """
        # Store representation
        field = ":".join([ compFuncName, ] + list(compFuncArgs))
        self.__str = " @ ".join([ pattern, value, field, keep ])

        # Convert regular expression
        try:
            self.__pattern = re.compile(pattern)
        except re.error:
            raise InvalidEntry(msg="Invalid regular expression")

        # Convert value: stored as int if the string is an integer literal,
        # as float otherwise
        try:
            self.__value = int(value)
        except ValueError:
            try:
                self.__value = float(value)
            except ValueError:
                raise InvalidEntry(msg="Invalid tolerance value")

        # Convert comparison method. The reason of a failure is collected in
        # an ordinary local variable: the name bound by "except ... as" is
        # deleted at the end of the except clause in Python 3 and can not be
        # used afterwards.
        failed = True
        reason = ""
        if compFuncName in self.__compFuncs:
            (compFunc, argConverters) = self.__compFuncs[compFuncName]
            if len(argConverters) == len(compFuncArgs):
                try:
                    args = [ convert(arg) for convert, arg
                             in zip(argConverters, compFuncArgs) ]
                    self.__compFunc = compFunc(*args)
                    failed = False
                except ConversionError as ee:
                    reason = str(ee)
            else:
                reason = "Invalid number of arguments"
        else:
            reason = "Invalid function name"
        if failed:
            raise InvalidEntry(msg="Invalid comparison function '%s' (%s)"
                               % (compFuncName, reason))

        # Convert keep-flag: only the literal "keep" switches it on
        self.__keep = (keep == "keep")


    def getPattern(self):
        """Returns the compiled regular expression"""
        return self.__pattern
    pattern = property(getPattern, None, None, "pattern")


    def getValue(self):
        """Returns the tolerance as number"""
        return self.__value
    value = property(getValue, None, None, "value")


    def getCompFunc(self):
        """Returns the initialised comparison functor"""
        return self.__compFunc
    compFunc = property(getCompFunc, None, None, "comparison function")


    def getKeep(self):
        """Returns the keep flag as boolean"""
        return self.__keep
    keep = property(getKeep, None, None, "keep flag")

    def __str__(self):
        """Returns the entry in the " @ " separated form"""
        return self.__str



############################################################################
# Parsing
############################################################################

class ConfigParser(object):
    """Parser for the tagdiff configuration file.

    Every uncommented line of the file is split at the "@" characters into
    pattern, tolerance, comparison function and flag; the result is
    delivered as ToleranceEntry objects.
    """

    # Known comparison functors together with the converters of their
    # initialization arguments
    __comparison = {
        "element": (DiffElement, ()),
        "vector": (DiffVector, (IntConverter(nolist=True),)),
    }

    # Comparison functor (name and arguments) used if the line names none
    __defCompFunc = "element"
    __defCompFuncArgs = ()

    # Flag used if the line contains none
    __defFlag = "nokeep"


    def __init__(self, file):
        """file -- opened, file like object with the configuration; it must
                   stay open as long as entries are requested from the parser
        """
        self.__file = file


    def iterateEntries(self):
        """Generator yielding one ToleranceEntry per configuration line.

        Raises InvalidEntry (with line numbers) if a line has less than two
        fields or can not be converted into a ToleranceEntry.
        """
        for (line, iLine) in UncommLines(self.__file, returnLineNr=True):
            fields = [ part.strip() for part in line.split("@") ]
            nField = len(fields)
            if nField < 2:
                raise InvalidEntry(iLine+1, iLine+2, "Not enough fields")

            pattern, value = fields[0], fields[1]

            # Third field: comparison function with colon separated arguments
            funcSpec = fields[2] if nField > 2 else ""
            if funcSpec == "":
                compFunc = self.__defCompFunc
                compFuncArgs = self.__defCompFuncArgs
            else:
                tokens = [ token.strip() for token in funcSpec.split(":") ]
                compFunc = tokens[0]
                compFuncArgs = tuple(tokens[1:])

            # Fourth field: flag
            flag = fields[3] if nField > 3 else ""
            if flag == "":
                flag = self.__defFlag

            try:
                entry = ToleranceEntry(pattern, value, compFunc,
                                       compFuncArgs, flag)
            except InvalidEntry as ee:
                # Same problem, but now with the position in the file
                raise InvalidEntry(iLine+1, iLine+2, ee.msg)
            yield entry

    entries = property(iterateEntries, None, None,
                       "Sequence of extracted entries.")


############################################################################
# Input/Output
############################################################################
# Text printed for each of the result codes
resultStr = {
    RES_OK: "OK",
    RES_SKIPPED: "Skipped",
    RES_FAILED: "Failed",
    RES_ERROR: "Error",
}

def printResult(name, method, msg, result):
    """Writes the outcome of one comparison as a single line to stdout

    name   -- name of the data
    method -- comparison method
    msg    -- message to print
    result -- result code of the comparison (key of resultStr)
    """
    label = resultStr[result]
    # Fixed width, left aligned columns: name, method, message, result text
    columns = "%-20s %-20s %-27s" % (name, method, msg)
    sys.stdout.write("%s %-10s\n" % (columns, label))



def printError(message):
    """Writes message, prefixed with "ERROR::", as one line to stderr"""
    print("ERROR::%s" % message, file=sys.stderr)


def zOpen(filename, mode):
    """Opens a file according to its extension with different methods

    filename -- name of the file; names ending in ".gz" (and longer than
                the bare extension) are opened via gzip, all others with
                the builtin open()
    mode     -- mode string as accepted by the builtin open()

    Returns the opened file object. A mode without "b" delivers text for
    both kinds of files.
    """
    if len(filename) > 3 and filename.endswith(".gz"):
        # Under Python 3 gzip.open() treats a mode like "r" as binary, while
        # the builtin open() treats it as text. Request text explicitly, so
        # that compressed and plain files behave the same for the callers.
        if (sys.version_info[0] >= 3
                and "b" not in mode and "t" not in mode):
            mode += "t"
        return gzip.open(filename, mode)
    return open(filename, mode)



############################################################################
# Option processing
############################################################################

def parseOptions():
    """Processes the command line of the script

    Returns the tuple (options, args): the namespace with the recognised
    options and the list of the remaining arguments. Prints the help and
    exits with EXIT_ERROR if less than two arguments remain.
    """
    parser = ArgumentParser(description=DESCRIPTION,
                            usage="%(prog)s [ options ] orig new")
    parser.add_argument("--version", action="version",
                        version="%(prog)s " + str(VERSION))
    parser.add_argument(
        "-c", "--config", dest="configfile", action="append",
        help="config file to use (multiple config files can be "
        "specified by using this option multiple times)")
    parser.add_argument(
        "-v", "--verbose", dest="verbose", action="store_true",
        default=False, help="verbose mode")

    # Everything the parser does not recognise is handed back as argument
    options, remaining = parser.parse_known_args()
    if len(remaining) < 2:
        parser.print_help()
        sys.exit(EXIT_ERROR)
    return (options, remaining)



############################################################################
# Main program
############################################################################

def _parseConfigEntries(fp):
    """Returns the list of tolerance entries found in the opened config file"""
    return list(ConfigParser(fp).entries)


def _parseTaggedEntries(fp):
    """Returns the entries of the opened tagged file as TaggedCollection"""
    return TaggedCollection(ResultParser(fp).entries)


def _readFile(fileName, parse):
    """Opens fileName, applies parse to the file object and closes the file

    fileName -- name of the file to read
    parse    -- callable taking the opened file object

    Returns the result of parse, or None if opening or parsing failed; in
    the latter case an error message has been printed already.
    """
    try:
        fp = zOpen(fileName, "r")
    except IOError:
        printError("Input/output error for file '%s'" % (fileName,))
        return None
    try:
        return parse(fp)
    except InvalidEntry as ee:
        printError("Invalid entry (%s) in file '%s' between lines %d and %d"
                   % (ee.msg, fileName, ee.start, ee.end))
    except IOError:
        printError("Input/output error for file '%s'" % (fileName,))
    finally:
        # Close the file on every path, including read errors
        fp.close()
    return None


def main():
    """Compares the two tagged files named on the command line

    Reads the config file(s), the old and the new tagged file and prints
    one result line for every entry of the old file that matches a rule
    of the config.

    Returns EXIT_ERROR if any of the files could not be read or parsed,
    EXIT_OK otherwise (independent of the outcome of the comparisons).
    """

    #
    # Parse options
    #
    options, args = parseOptions()
    oldFile, newFile = args[:2]
    if options.configfile:
        confFiles = options.configfile[:]
    else:
        # Fall back to the default config file next to the script
        confFiles = [os.path.join(os.path.dirname(sys.argv[0]), DEFAULT_CONFIG)]

    configEntries = []
    for confFile in confFiles:
        if options.verbose:
            print("# Reading config file `%s'" % confFile)
        entries = _readFile(confFile, _parseConfigEntries)
        if entries is None:
            return EXIT_ERROR
        configEntries += entries

    #
    # Read and parse files
    #
    if options.verbose:
        print("# Reading old tagged file `%s'" % oldFile)
    old = _readFile(oldFile, _parseTaggedEntries)
    if old is None:
        return EXIT_ERROR

    if options.verbose:
        print("# Reading new tagged file `%s'" % newFile)
    new = _readFile(newFile, _parseTaggedEntries)
    if new is None:
        return EXIT_ERROR

    #
    # Compare entries
    #
    for configEntry in configEntries:
        if options.verbose:
            print(" # Processing rule:     `%s'" % configEntry)

        compFunc = configEntry.compFunc
        oldEntries = old.getMatchingEntries(configEntry.pattern)
        for oldEntry in oldEntries:
            name = oldEntry.name
            newEntry = new.getEntry(name)

            # Unless the rule asks to keep them, processed entries are
            # removed so that later rules do not see them again
            if not configEntry.keep:
                old.delEntry(name)
                new.delEntry(name)

            if not newEntry:
                printResult(name, compFunc, "Not found in new", RES_SKIPPED)
                continue
            if not oldEntry.isComparable(newEntry):
                printResult(name, compFunc, "Mismatching data", RES_ERROR)
                continue

            try:
                result = compFunc(oldEntry, newEntry)
            except DiffError as msg:
                printResult(name, compFunc,
                            "Difference building error (%s)" % msg, RES_ERROR)
                continue

            # The difference is printed in both cases, only the result
            # code depends on the tolerance
            if result <= configEntry.value:
                res = RES_OK
            else:
                res = RES_FAILED
            printResult(name, compFunc, "%-20s" % (str(result)), res)

    return EXIT_OK



############################################################################
# Start
############################################################################

if __name__ == "__main__":
    # Hand the return value of main() to the shell as exit status
    sys.exit(main())

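## Manual test (the config file has to be passed with -c; the old and the
## new tagged file are the two positional arguments)
#
#confFile = "tagdiff.conf"
#oldFile = "orig.tgo"
#newFile = "new.tgo"
#sys.argv = ["alma", "-c", confFile, oldFile, newFile ]
#main()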
