#! /usr/bin/env python

from __future__ import print_function

import copy
import fnmatch
import os
import os.path
import re

class error(Exception):
    '''Conversion failure, e.g. a Makefile.am without the expected AM variables.'''

class converter(object):
    '''Base class for line-oriented editors of a single source file.

    The file is read into ``self.src`` (a list of lines, each keeping its
    line terminator) and an untouched copy is kept in ``self.orig``.
    Subclasses return named compiled regular expressions from
    ``searches()``; ``scan()`` records, per name, the 1-based numbers of the
    lines that match in ``self.changes``, and ``update()`` acts on them.
    ``insert()`` and ``remove()`` keep the recorded numbers in step with
    the edits; a recorded line that is removed becomes 0.

    When ``write`` is true the current ``self.src`` is written back to
    ``source`` when the object is destroyed.
    '''

    # Terminator appended to lines added by insert() and used in __str__.
    # Files are read and written in text mode, which already maps '\n' to
    # the platform line ending; os.linesep here would give '\r\r\n' on
    # Windows.
    LF = '\n'

    def __init__(self, source, verbose = False, write = True):
        self.verbose = verbose
        self.write = write
        self.source = source
        self.changes = { }
        self.greps = self.searches()
        # Every search name starts with an empty list of matching lines.
        for name in self.greps:
            self.changes[name] = []
        with open(source, 'r') as f:
            self.src = f.readlines()
        self.orig = copy.deepcopy(self.src)
        self.orig_changes = { }
        self.edits = ''

    def __del__(self):
        # __init__ may have failed before these attributes existed; only
        # write back when there is something to write.
        if getattr(self, 'write', False) and getattr(self, 'src', None) is not None:
            with open(self.source, 'w') as f:
                f.writelines(self.src)
            self.src = None

    def __str__(self):
        LF = self.LF
        out = ['-' * 20 + LF]
        # IN is the original line count, OUT the current one.
        out.append('%s: IN: %-5d' % (self.source, len(self.orig)))
        out.append(' OUT: %-5d' % (len(self.src)))
        out.append(' DIFF:%-5d' % (len(self.orig) - len(self.src)))
        out.append(' EDITS: %s %s' % (self.has_edits(), self.edits))
        out.append(LF)
        out.append('CHANGES:' + LF)
        for c in self.orig_changes:
            numbers = self.orig_changes[c]
            if len(numbers) > 0:
                out.append(' %s: %d%s' % (c, len(numbers), LF))
                for l in numbers:
                    out.append(' %5d: %s%s' % (l, self._chomp(self.orig[l - 1]), LF))
        out.append('DIFF:' + LF)
        out.append(''.join(self.diff()))
        return ''.join(out)

    @staticmethod
    def _chomp(line):
        '''Return *line* without its trailing newline, if it has one.'''
        if line.endswith('\n'):
            return line[:-1]
        return line

    def _get(self, pos, length = 1):
        '''Return up to *length* lines from 1-based *pos*, newlines removed.'''
        if pos <= len(self.src):
            return [self._chomp(s) for s in self.src[pos - 1:pos - 1 + length]]
        return []

    def diff(self):
        '''Return a unified diff (line iterator) from the original to the current lines.'''
        import difflib
        return difflib.unified_diff(self.orig, self.src)

    def has_changes(self):
        '''Hook for subclasses; the base class always reports changes.'''
        return True

    def has_edits(self):
        '''True if insert() or remove() has been called.'''
        return len(self.edits) > 0

    def searches(self):
        '''Return a dict of name -> compiled regex; none in the base class.'''
        return { }

    def get(self, l):
        '''Return the logical line starting at 1-based line *l*.

        Physical lines ending in a backslash are joined with the following
        line(s), dropping the backslash and separating the pieces with a
        space.  A line that does not end in a backslash, including an empty
        line, ends the logical line.
        '''
        parts = []
        while l <= self.source_len():
            s = self._get(l)[0]
            if not s.endswith('\\'):
                parts.append(s)
                break
            parts.append(s[:-1])
            l += 1
        return ' '.join(p for p in parts if len(p) > 0)

    def source_len(self):
        '''Number of lines currently held.'''
        return len(self.src)

    def source_empty(self, pos):
        '''True if the 1-based line *pos* has no content besides its newline.'''
        return len(self._chomp(self.src[pos - 1])) == 0

    def insert(self, pos, lines):
        '''Insert *lines* (a string or list of strings, without newlines)
        before 1-based line *pos* and shift recorded line numbers >= pos.'''
        if type(lines) is str:
            lines = [lines]
        before = len(self.src)
        self.src = self.src[:pos - 1] + \
                   [l + self.LF for l in lines] + \
                   self.src[pos - 1:]
        added = len(lines)
        for numbers in self.changes.values():
            for i, n in enumerate(numbers):
                if n >= pos:
                    numbers[i] = n + added
        self.edits += ' add:%d/%d ' % (before, len(self.src))

    def remove(self, pos, count = 1):
        '''Delete *count* lines starting at 1-based *pos*.

        Recorded line numbers inside the deleted range become 0; those after
        it move up by *count*.
        '''
        before = len(self.src)
        for _ in range(count):
            del self.src[pos - 1]
        for numbers in self.changes.values():
            for i, n in enumerate(numbers):
                if pos <= n < pos + count:
                    numbers[i] = 0
                elif n >= pos:
                    numbers[i] = n - count
        self.edits += ' del:%d/%d' % (before, len(self.src))

    def replace(self, pos, lines):
        '''Replace as many lines at *pos* as there are in *lines*.'''
        if type(lines) is str:
            lines = [lines]
        self.remove(pos, len(lines))
        self.insert(pos, lines)

    def scan(self):
        '''Record the 1-based numbers of lines matching each search regex
        and snapshot them in ``self.orig_changes``.'''
        for lc, line in enumerate(self.src, 1):
            for g, pattern in self.greps.items():
                if pattern.search(line):
                    self.changes.setdefault(g, []).append(lc)
        self.orig_changes = copy.deepcopy(self.changes)

    def update(self):
        '''Hook for subclasses: act on the lines recorded by scan().'''
        pass

    def convert(self):
        '''Scan the source, then update it if has_changes() allows.'''
        self.scan()
        if self.has_changes():
            self.update()

class jobs(object):
    '''A set of files to convert, found by glob patterns under a directory.

    With *globs* the tree under *path* is walked (following symlinks) and
    every file whose base name matches a pattern is collected, unless its
    base name or joined path appears in *excludes*.  Without *globs*,
    *path* itself names the file(s) to process.
    '''

    def __init__(self, path, globs = None, excludes = None, verbose = False):
        self.verbose = verbose
        self.path = path
        self.globs = globs
        # Fresh list per instance; a shared default list would leak between
        # instances.
        self.excludes = [] if excludes is None else excludes
        self.files = []
        if globs is not None:
            # A single pattern must not be iterated character by character.
            if not isinstance(globs, (list, tuple)):
                globs = [globs]
            for g in globs:
                self.find_files(path, g)
        else:
            if isinstance(path, (list, tuple)):
                self.files = list(path)
            else:
                self.files = [path]

    def __str__(self):
        return 'Files: %d' % (len(self.files))

    def find_files(self, path, glob, reset = False):
        '''Add files under *path* whose base name matches *glob*; keep the
        file list sorted.  With *reset* the current list is cleared first.'''
        if reset:
            self.files = []
        found = []
        for root, dirs, files in os.walk(path, followlinks = True):
            for f in files:
                ff = os.path.join(root, f)
                if fnmatch.fnmatch(f, glob) and \
                   f not in self.excludes and ff not in self.excludes:
                    found.append(ff)
        self.files = sorted(self.files + found)

    def run(self, changer, include_map, write = True):
        '''Create ``changer(self.path, file, include_map, verbose, write)``
        for every file, call its convert() and return the list of changers.'''
        changes = []
        for f in self.files:
            c = changer(self.path, f, include_map, verbose = self.verbose, write = write)
            c.convert()
            if self.verbose:
                if c.has_edits():
                    print('%70s: updated' % (f))
                print(c)
            changes.append(c)
        return changes

class parse_am(converter):
    '''Parse an automake Makefile.am and regenerate its programs/libraries.

    convert() (via update()) collects every variable matched by
    searches() into ``self.data['am']`` as ``{variable: [words]}``.
    process_data() then adds an entry to ``self.data`` per program and per
    library, and generate_program()/generate_library() emit the
    replacement Makefile.am lines for one entry.
    '''

    def __init__(self, base, source, include_map, verbose = False, write = True):
        super(parse_am, self).__init__(source, verbose, write)
        self.base = base
        # Directory of this Makefile.am relative to the testsuite base.
        self.testdir = os.path.dirname(os.path.relpath(source, base))
        self.data = { 'programs': [], 'libraries': [], 'am': { } }
        # include-variable name -> list of -I flags it stands for.
        self.include_map = include_map
        self.include_maps_used = []

    def __str__(self):
        s = super(parse_am, self).__str__()
        import pprint
        s += 'file: ' + self.source + self.LF
        s += 'dir: ' + self.testdir + self.LF
        s += pprint.pformat(self.data)
        return s

    def _get_am_line(self, l):
        '''Split the logical line at *l* into (variable, words).

        Returns (None, None) for blank lines, comments, or text that holds
        no assignment.  ``=``, ``+=``, ``:=`` and ``?=`` are accepted; the
        first ``=`` on the line is the operator.
        '''
        s = self.get(l)
        stripped = s.lstrip()
        if len(stripped) == 0 or stripped[0] == '#':
            return None, None
        eq = s.find('=')
        if eq < 0:
            return None, None
        tag = s[:eq].rstrip()
        if tag.endswith(('+', ':', '?')):
            tag = tag[:-1]
        return tag.strip(), s[eq + 1:].split()

    def _parse_am_type(self, l, label = None):
        '''Append the words assigned at line *l* to ``data['am']``.'''
        am = self.data['am']
        tag, data = self._get_am_line(l)
        if tag is not None:
            am.setdefault(tag, []).extend(data)

    def _process_am_tag(self, am, prog, tag, label, err = True):
        '''Copy ``am[tag]`` into ``prog[label]``; raise error if missing and *err*.'''
        if tag not in am:
            if err:
                raise error('no AM %s found: %s' % (label, tag))
        else:
            # Copy so entries never share (and later mutate) the same list.
            prog[label] = list(am[tag])

    def _map_include(self, include):
        '''Return ``$(name)`` if *include* belongs to an include map, else *include*.'''
        for m in self.include_map:
            if include in self.include_map[m]:
                self.include_maps_used = sorted(set(self.include_maps_used + [m]))
                return '$(%s)' % (m)
        return include

    def _get_filename(self, name, filename):
        '''Return *filename* (relative to this Makefile.am) relative to base.'''
        fullpath = os.path.normpath(os.path.join(self.base, self.testdir, filename))
        return os.path.relpath(fullpath, self.base)

    def searches(self):
        return { 'programs':  re.compile(r'^.*_PROGRAMS.*='),
                 'libraries': re.compile(r'^.*_LIBRARIES.*='),
                 'sources':   re.compile(r'^.*_SOURCES.*='),
                 'ldadd':     re.compile(r'^.*_LDADD.*='),
                 # Only assignments to AM_CPPFLAGS, not $(AM_CPPFLAGS) uses.
                 'cppflags':  re.compile(r'AM_CPPFLAGS.*='),
                 'data':      re.compile(r'dist_rtems_tests_DATA.*=') }

    def _parse_change(self, change):
        '''Parse every recorded line for search *change*; 0 marks a removed line.'''
        for l in self.changes.get(change, []):
            if l != 0:
                self._parse_am_type(l)

    def parse_programs(self):
        self._parse_change('programs')

    def parse_libraries(self):
        self._parse_change('libraries')

    def parse_sources(self):
        self._parse_change('sources')

    def parse_ldadd(self):
        self._parse_change('ldadd')

    def parse_cppflags(self):
        self._parse_change('cppflags')

    def parse_data(self):
        self._parse_change('data')

    def update(self):
        self.parse_programs()
        self.parse_libraries()
        self.parse_sources()
        self.parse_ldadd()
        self.parse_cppflags()
        self.parse_data()

    def process_data(self):
        '''Build per-program and per-library entries from ``data['am']``.

        Raises error when no *_PROGRAMS or *_LIBRARIES variable was found.
        '''
        am = self.data['am']
        progs = []
        libs = []
        for k in am:
            if k.endswith('_PROGRAMS'):
                progs += am[k]
        for k in am:
            if k.endswith('_LIBRARIES'):
                libs += am[k]
        if len(progs) == 0 and len(libs) == 0:
            print(self)
            raise error('no AM programs or libraries found')
        for name in progs:
            self.data['programs'] += [name]
            self.data[name] = { }
            self.process_name(am, self.data[name], name)
        for name in libs:
            lib = name
            # Automake variable prefixes use '_' in place of '.'.
            name = name.replace('.', '_')
            self.data['libraries'] += [name]
            self.data[name] = { 'LIB': lib }
            self.process_name(am, self.data[name], name)

    def process_name(self, am, prog, name):
        '''Fill *prog* with the SOURCES/CPPFLAGS/LDADD/DATA for *name*.'''
        self._process_am_tag(am, prog, '%s_SOURCES' % (name), 'SOURCES')
        self._process_am_tag(am, prog, 'AM_CPPFLAGS', 'CPPFLAGS', err = False)
        self._process_am_tag(am, prog, '%s_LDADD' % (name), 'LDADD', err = False)
        self._process_am_tag(am, prog, 'dist_rtems_tests_DATA', 'DATA', err = False)
        self.process_sources(prog, name)
        self.process_cppflags(prog, name)
        self.process_dist_data(prog, name)

    def process_sources(self, prog, name):
        '''Rewrite source paths relative to the testsuite base directory.'''
        prog['SOURCES'] = [self._get_filename(name, s) for s in prog['SOURCES']]

    def process_cppflags(self, prog, name):
        '''Normalise CPPFLAGS.

        The result is $(AM_CPPFLAGS) and $(TEST_FLAGS_<name>), followed by
        the sorted unique -I (mapped), -D and other flags.
        '''
        if 'CPPFLAGS' in prog:
            includes = []
            defines = []
            others = []
            for f in prog['CPPFLAGS']:
                if f.startswith('-I'):
                    includes += [self._map_include(f)]
                elif f.startswith('-D'):
                    defines += [f]
                else:
                    others += [f]
            prog['CPPFLAGS'] = ['$(AM_CPPFLAGS)', '$(TEST_FLAGS_%s)' % (name)] + \
                               sorted(set(includes)) + sorted(set(defines)) + sorted(set(others))

    def process_dist_data(self, prog, name):
        '''Split DATA into SCREENS (*.scn) and DOCS (everything else).'''
        if 'DATA' in prog:
            screens = [d for d in prog['DATA'] if d.endswith('.scn')]
            docs = [d for d in prog['DATA'] if not d.endswith('.scn')]
            if len(screens) > 0:
                prog['SCREENS'] = screens
            if len(docs) > 0:
                prog['DOCS'] = docs

    def generate_program(self, tests, screens, docs, name):
        '''Return the Makefile.am lines for program *name*, or None if unknown here.'''
        if name not in self.data['programs']:
            return None
        prog = self.data[name]
        out = ['if TEST_%s' % (name)]
        out += ['%s += %s' % (tests, name)]
        if 'SCREENS' in prog:
            out += ['%s += %s' % (screens, ' '.join(prog['SCREENS']))]
        if 'DOCS' in prog:
            out += ['%s += %s' % (docs, ' '.join(prog['DOCS']))]
        out += ['%s_SOURCES = %s' % (name, ' '.join(prog['SOURCES']))]
        if 'CPPFLAGS' in prog:
            out += ['%s_CPPFLAGS = %s' % (name, ' '.join(prog['CPPFLAGS']))]
        if 'LDADD' in prog:
            out += ['%s_LDADD = %s' % (name, ' '.join(prog['LDADD']))]
        out += ['endif']
        return out

    def generate_library(self, libraries, name):
        '''Return the Makefile.am lines for library *name*, or None if unknown here.'''
        if name not in self.data['libraries']:
            return None
        lib = self.data[name]
        out = ['if TEST_%s' % (name)]
        out += ['%s += %s' % (libraries, lib['LIB'])]
        out += ['%s_SOURCES = %s' % (name, ' '.join(lib['SOURCES']))]
        out += ['endif']
        return out

    def get_used_include_maps(self):
        return self.include_maps_used

    def program_names(self):
        return self.data['programs']

    def library_names(self):
        return self.data['libraries']

def format_block(text, width = 70):
    '''Wrap one Makefile.am line into backslash-continued physical lines.

    Continuation lines are indented with a tab, and every line except the
    last ends with " \\".  Tokens are never split: long words (e.g. paths)
    stay whole and hyphens are not break points, because a break inside a
    token would change what make reads.

    :param text: the logical line to wrap.
    :param width: target line width passed to textwrap (default 70).
    :returns: list of lines without terminators; empty for blank *text*.
    '''
    import textwrap
    wrapped = textwrap.wrap(text, width = width, subsequent_indent = '\t',
                            break_long_words = False,
                            break_on_hyphens = False)
    if len(wrapped) > 0:
        wrapped = [l + ' \\' for l in wrapped[:-1]] + [wrapped[-1]]
    return wrapped

def run():
    '''Generate a combined Makefile.am and configure fragment per testsuite.

    For every suite below ``testsuites/`` the per-test ``*.am`` files are
    parsed with parse_am (``write`` follows --dry-run).  The generated
    ``testsuites/<suite>/Makefile.am`` and ``testsuites/<suite>/<suite>.ac``
    are printed with --verbose.  Without --dry-run they are written; the
    Makefile.am gets the new text, a line of '>' characters, then its
    previous contents.  Exits with status 1 on a conversion error.
    '''
    import argparse
    import sys
    parser = argparse.ArgumentParser(description = 'Patch a build.')
    parser.add_argument('--verbose', dest = 'verbose', action = 'store_true',
                        default = False, help='Verbose output')
    parser.add_argument('--dry-run', dest = 'write', action = 'store_false',
                        default = True, help='Dry run, do not write the changes out')
    args = parser.parse_args()

    # suite -> (title, tests var, libraries var, screens var, docs var)
    testsuite = {
        'samples': ('Samples',
                       'samples', 'sample_libs', 'sample_screens', 'sample_docs'),
        'benchmarks': ('Benchmarks',
                       'benchmarks', 'benchmark_libs', 'benchmark_screens', 'benchmark_docs'),
        'rhealstone': ('Real-time Benchmarking',
                       'rhealstones', 'rhealstone_libs', 'rhealstone_screens', 'rhealstone_docs'),
        'fstests': ('File System Testsuite',
                    'fs_tests', 'fs_libs', 'fs_screens', 'fs_docs'),
        'sptests': ('Single Processor Testsuite',
                    'sp_tests', 'sp_libs', 'sp_screens', 'sp_docs'),
        'tmtests': ('Timing Testsuite',
                    'tm_tests', 'tm_libs', 'tm_screens', 'tm_docs'),
        'libtests': ('Library Testsuite',
                     'lib_tests', 'lib_lib', 'lib_screens', 'lib_docs'),
        'psxtests': ('POSIX Testsuite',
                     'psx_tests', 'psx_lib', 'psx_screens', 'psx_docs'),
        'psxtmtests': ('POSIX Timing Testsuite',
                       'psxtm_tests', 'psxtm_lib', 'psxtm_screens', 'psxtm_docs'),
        'smptests': ('SMP Testsuite',
                     'smp_tests', 'smp_lib', 'smp_screens', 'smp_docs'),
    }
    # Make variable -> -I flags it replaces in generated CPPFLAGS.
    include_map = {
        'test_includes' : ['-I$(top_srcdir)/support',
                           '-I$(top_srcdir)/../psxtests/include'],
        'support_includes': [ '-I$(top_srcdir)/../support/include' ]
    }

    # Sorted for a deterministic processing order.
    for tests in sorted(testsuite):
        test_makefile_am = 'testsuites/%s/Makefile.am' % (tests)
        test_configure_ac = 'testsuites/%s/%s.ac' % (tests, tests)
        j = jobs(os.path.join('testsuites', tests),
                 ['*.am'],
                 excludes = [test_makefile_am,
                             'testsuites/samples/base_mp/Makefile.am',
                             'testsuites/samples/base_mp/node1/Makefile.am',
                             'testsuites/samples/base_mp/node2/Makefile.am'],
                 verbose = False)
        try:
            config = testsuite[tests]
            changers = j.run(parse_am, include_map, write = args.write)
            include_maps_used = []
            for c in changers:
                c.process_data()
                if args.verbose:
                    print(c)
                include_maps_used = sorted(set(include_maps_used + c.get_used_include_maps()))
            programs = []
            libraries = []
            for c in changers:
                programs += c.program_names()
                libraries += c.library_names()
            programs = sorted(programs)
            libraries = sorted(libraries)
            configure_ac = [ '# BSP Test configuration' ]
            makefile_am = [
                '#',
                '# %s' % (config[0]),
                '#',
                '',
                'ACLOCAL_AMFLAGS = -I ../aclocal',
                '',
                'include $(RTEMS_ROOT)/make/custom/@RTEMS_BSP@.cfg',
                'include $(top_srcdir)/../automake/compile.am',
                '',
                '%s = ' % (config[1]),
                '%s = ' % (config[3]),
                '%s = ' % (config[4]),
            ]

            if len(libraries) > 0:
                makefile_am += ['%s = ' % (config[2])]

            makefile_am += ['']

            if len(include_maps_used) > 0:
                for i in include_maps_used:
                    makefile_am += format_block('%s = %s' % (i, ' '.join(include_map[i])))
                makefile_am += ['']
            for name in programs:
                # The first Makefile.am that defines the program wins.
                for c in changers:
                    out = c.generate_program(config[1], config[3], config[4], name)
                    if out is not None:
                        for l in out:
                            makefile_am += format_block(l)
                        makefile_am += ['']
                        configure_ac += ['RTEMS_TEST_CHECK([%s])' % (name)]
                        break
            for name in libraries:
                for c in changers:
                    out = c.generate_library(config[2], name)
                    if out is not None:
                        for l in out:
                            makefile_am += format_block(l)
                        makefile_am += ['']
                        break
            makefile_am += ['rtems_tests_PROGRAMS = $(%s)' % (config[1])]
            if len(libraries) > 0:
                makefile_am += ['noinst_LIBRARIES = $(%s)' % (config[2])]
            makefile_am += ['dist_rtems_tests_DATA = $(%s) $(%s)' % (config[3], config[4])]
            makefile_am += ['']
            makefile_am += ['include $(top_srcdir)/../automake/local.am']
            # Text-mode files translate '\n' to the platform line ending, so
            # os.linesep here would double the CR on Windows.
            configure_ac = [l + '\n' for l in configure_ac]
            makefile_am = [l + '\n' for l in makefile_am]
            if args.verbose:
                print('=' * 80)
                for l in makefile_am:
                    print(l[:-1])
            if args.write:
                with open(test_configure_ac, 'w') as f:
                    f.writelines(configure_ac)
                with open(test_makefile_am, 'r') as f:
                    original = f.readlines()
                with open(test_makefile_am, 'w') as f:
                    f.writelines(makefile_am)
                    # The separator needs its own terminator, otherwise the
                    # first original line is joined onto it.
                    f.write('>' * 80 + '\n')
                    f.writelines(original)
        except error as e:
            print('error: %s' % (e))
            sys.exit(1)

if __name__ == '__main__':
    # Only convert when executed as a script, never on import.
    run()
