#! /usr/bin/env python

#
# Parse the inc-list.txt file and process the results.
#

from __future__ import print_function

import datetime
import fnmatch
import itertools
import os
import os.path
import re
import sys

import pprint
pp = pprint.PrettyPrinter(indent = 2)

# CPU architecture names; each one is a candidate value for the
# @RTEMS_CPU@ placeholder in include paths.
archs = [
    'arm',
    'avr',
    'bfin',
    'epiphany',
    'h8300',
    'i386',
    'lm32',
    'm32c',
    'm32r',
    'm68k',
    'mips',
    'moxie',
    'nios2',
    'no_cpu',
    'or1k',
    'powerpc',
    'riscv',
    'riscv32',
    'sh',
    'sparc',
    'sparc64',
    'v850'
]

# BSP names; each one is a candidate value for the @RTEMS_BSP@
# placeholder in include paths.  Keep entries unique.
bsps = [
    'TLL6527M',
    'altera-cyclone-v',
    'atsam',
    'av5282',
    'beagle',
    'beatnik',
    'bf537Stamp',
    'csb336',
    'csb337',
    'csb350',
    'csb360',
    'eZKit533',
    'edb7312',
    'epiphany_sim',
    'erc32',
    'gdbarmsim',
    'gdbv850sim',
    'gen5200',
    'gen68340',
    'gen68360',
    'gen83xx',
    'generic_or1k',
    'genmcf548x',
    'gensh1',
    'gensh2',
    'gensh4',
    'gumstix',
    'haleakala',
    'hurricane',
    'imx',
    'jmr3904',
    'leon2',
    'leon3',
    'lm32_evr',
    'lm3s69xx',
    'lpc176x',
    'lpc24xx',
    'lpc32xx',
    'm32cbsp',
    'malta',
    'mcf5206elite',
    'mcf52235',
    'mcf5225x',
    'mcf5235',
    'mcf5329',
    'milkymist',
    'motorola_powerpc',
    'moxiesim',
    'mpc55xxevb',
    'mpc8260ads',
    'mrm332',
    'mvme147s',
    'mvme147',
    'mvme162',
    'mvme167',
    'mvme3100',
    'mvme5500',
    'niagara',
    'nios2_iss',
    'no_bsp',
    'pc386',
    'psim',
    'qemuppc',
    'qoriq',
    'raspberrypi',
    'rbtx4925',
    'rbtx4938',
    'realview-pbx-a9',
    'riscv_generic',
    'rtl22xx',
    'shsim',
    'smdk2410',
    'ss555',
    'stm32f4',
    't32mppc',
    'tms570',
    'tqm8xx',
    'uC5282',
    'usiii',
    'virtex4',
    'virtex5',
    'virtex',
    'xilinx-zynq',
]

class error(Exception):
    """Script-level error; str() gives the message prefixed with 'error: '."""
    def __init__(self, what):
        # Initialise the base class so args/repr carry the message too.
        Exception.__init__(self, 'error: ' + what)
        self.set_output('error: ' + what)
    def set_output(self, msg):
        """Replace the text returned by str()."""
        self.msg = msg
    def __str__(self):
        return self.msg

class includes(object):
    """Parse an include-list file and plan moving preinstalled headers.

    Construction walks ``top`` collecting every ``*.h`` file (as paths
    relative to ``top``) and then loads the record file ``name``.  Each
    record is a comma separated line whose first field is its type:

      dirs,<dir>,<field>
      preinstall,<base>,<target>,<dependent>
      tmpinstall,<base>,<target>,<dependent>

    ``move()`` maps each ``$(PROJECT_INCLUDE)`` preinstall target onto a
    destination using ``data['map']``; ``report()``, ``script()`` and
    ``makefile_patch_data()`` render the planned moves as text.
    """

    def __init__(self, top, name):
        """Scan ``top`` for headers and load the record file ``name``."""
        self.data = { 'headers'   : [],
                      'header-map': { },
                      'refs'      : {},
                      'macros'    : {},
                      'dirs'      : [],
                      'preinstall': {},
                      'tmpinstall': {},
                      'move'      : {},
                      'targets'   : {},
                      'remaining' : [],
                      # (source prefix, destination prefix) pairs, tried in
                      # order by move(); @...@ placeholders are expanded.
                      'map'       : [('c/src/libchip',                            'bsps/include'),
                                     ('c/src/lib/libcpu/@RTEMS_CPU@',             'bsps/@RTEMS_CPU@/include'),
                                     ('c/src/lib/libbsp/@RTEMS_CPU@/@RTEMS_BSP@', 'bsps/@RTEMS_CPU@/@RTEMS_BSP@/include'),
                                     ('c/src/lib/libbsp/@RTEMS_CPU@',             'bsps/@RTEMS_CPU@/include'),
                                     ('c/src/lib/libbsp/shared',                  'bsps/include'),
                                     ('c/src/lib/libbsp/shared',                  ''),
                                     ('cpukit/score/cpu/@RTEMS_CPU@',             'cpukit/score/cpu/@RTEMS_CPU@/include'),
                                     ('cpukit/libdl/include/arch/@RTEMS_CPU@',    'cpukit/score/cpu/@RTEMS_CPU@/include'),
                                     ('cpukit',                                   'cpukit/include')],
                      # Candidate values for each @...@ placeholder.
                      'subst': { '@RTEMS_CPU@' : sorted(set(archs)),
                                 '@RTEMS_BSP@' : sorted(set(bsps)),
                                 '@exceptions@': ['new-exceptions'] } }
        # Header basename that switches tracing on while it is processed;
        # the 'xx ' prefix presumably keeps it from matching a real header.
        self.trace_name = 'xx cpuopts.h'
        self.trace = False
        self.top = top
        self.name = None
        self.last_time = datetime.datetime.now()
        self.indicator = 0
        self.indicator_dots = 3
        self.macros_match = re.compile(r'(\$\([^\)]+\))')
        self.subst_match = re.compile(r'(@[^@]+@)')
        # Set mirror of data['headers'] so membership tests are O(1).
        self.header_set = set()
        for m in self.data['map']:
            self.data['move'][m[1]] = []
        self.getheaders()
        self.load(name)

    def _subst(self, subst, line):
        """Apply each ``(old, new)`` pair in ``subst`` to ``line`` in turn."""
        for s in subst:
            line = line.replace(s[0], s[1])
        return line

    def _normal_join(self, base, part):
        """Join ``base`` and ``part`` and normalise the result."""
        return os.path.normpath(os.path.join(base, part))

    def _clean_rec(self, rec):
        """Strip the line terminator and normalise every field of ``rec``.

        Files are read in text mode, so only '\\n' (or a stray '\\r\\n')
        can end the last field.  An empty field normalises to '.'.
        """
        return [os.path.normpath(r.rstrip('\r\n')) for r in rec]

    def _indicator_start(self, label):
        """Print ``label`` followed by room for the progress dots."""
        print('%s%s ' % (label, ' ' * self.indicator_dots), end = '')
        sys.stdout.flush()
        self.indicator = 0

    def _indicator_end(self, label):
        """Back over the progress dots and print the closing ``label``."""
        back = '\b' * (self.indicator_dots + 1)
        print(' %s%s' % (back, label))

    def _indicator(self):
        """Redraw the progress dots, at most once per elapsed second."""
        now = datetime.datetime.now()
        delta = now - self.last_time
        if delta.seconds > 0:
            self.last_time = now
            back = '\b' * (self.indicator_dots + 1)
            forward = '.' * self.indicator
            if self.indicator < self.indicator_dots:
                self.indicator += 1
            else:
                self.indicator = 0
            print('%s%-*s ' % (back, self.indicator_dots, forward), end = '')
            sys.stdout.flush()

    def _refs(self, key):
        """Count one more reference to ``key``."""
        self.data['refs'][key] = self.data['refs'].get(key, 0) + 1

    def _subst_glob(self, line):
        """Expand every @...@ placeholder in ``line``.

        Returns a list of ``(expanded_line, pats)`` pairs, one for each
        combination of placeholder values, where ``pats`` is the list of
        ``(placeholder, value)`` pairs used.  The last placeholder varies
        fastest.  A line without placeholders yields ``[(line, [])]``.

        Raises ``error`` for a placeholder missing from ``data['subst']``.
        """
        subs = []
        for s in self.subst_match.split(line):
            if s.startswith('@'):
                if s not in self.data['subst']:
                    raise error('subst unknown: %s' % (s))
                # A placeholder used twice is one choice, not two.
                if s not in subs:
                    subs.append(s)
        if not subs:
            return [(line, [])]
        substs = []
        choices = [self.data['subst'][s] for s in subs]
        for values in itertools.product(*choices):
            pats = list(zip(subs, values))
            l = line
            for p in pats:
                l = l.replace(p[0], p[1])
            substs.append((l, pats))
        return substs

    def _header_macros(self, base, dependent):
        """Join ``base`` and ``dependent`` and resolve Makefile macros.

        Every ``$(...)`` reference is recorded in ``data['macros']``;
        ``$(srcdir)`` and ``$(top_srcdir)`` are removed and other macros
        are left in place.  Returns ``(base, joined_path)``.
        """
        if self.trace:
            print('M] 1) b=%s d=%s' % (base, dependent))
        d = ''
        for m in self.macros_match.split(self._normal_join(base, dependent)):
            if self.trace:
                print('M] 2) m=%r d=%s' % (m, d))
            if len(m) > 0 and m[0] == '$':
                if m not in self.data['macros']:
                    self.data['macros'][m] = []
                self.data['macros'][m] += [(base, dependent)]
                if m in ['$(srcdir)', '$(top_srcdir)']:
                    m = ''
            d += m
        return base, d

    def _header_subst(self, base, dependent):
        """Return ``(base, path)`` pairs for expansions that are real headers.

        ``base``/``dependent`` is placeholder-expanded and each expansion
        found among the scanned headers is returned relative to ``base``.
        """
        dependents = []
        for s in self._subst_glob(self._normal_join(base, dependent)):
            if self.trace:
                print('S] 1) s=', s)
            if s[0] in self.header_set:
                dependents += [(base, os.path.normpath(s[0][len(base) + 1:]))]
        if self.trace:
            if len(dependents):
                print('S] 2) d=%s' % (', '.join([d[0] for d in dependents])))
            else:
                print('S] 3) d=empty')
        return dependents

    def _normalise_path(self, base, dependent):
        """Split the macro-resolved path into ``(base, path relative to base)``.

        Raises ``error`` when no existing directory can serve as the base.
        """
        if self.trace:
            print('N] 1) b=%s d=%s' % (base, dependent))
        base, dependent = self._header_macros(base, dependent)
        if self.trace:
            print('N] 2) b=%s d=%s' % (base, dependent))
        if not dependent.startswith(base):
            # commonprefix() compares character by character, so the result
            # may end part way through a path component; fall back to its
            # parent directory when it does not exist.
            base = os.path.normpath(os.path.commonprefix([base, dependent]))
            if not os.path.exists(base):
                b = base
                base = os.path.dirname(base)
                if not os.path.exists(base):
                    raise error('base path is invalid: %s' % (b))
            if self.trace:
                print('N] 3) b=%s d=%s' % (base, dependent))
        if self.trace:
            # Guard the index: dependent may be no longer than base.
            db = dependent[len(base)] if len(dependent) > len(base) else ''
            print('N] 4) b=%s d=%s db=%s' % (base, dependent, db))
        if dependent.startswith(base) and \
           len(dependent) > len(base) and \
           dependent[len(base)] == os.sep:
            dependent = os.path.normpath(dependent[len(base) + 1:])
        if self.trace:
            print('N] 5) b=%s d=%s' % (base, dependent))
        np = self._normal_join(base, dependent)
        if not np.startswith(base):
            base = os.path.normpath(os.path.commonprefix([base, np]))
        if self.trace:
            print('N] 6) b=%s d=%s np=%s' % (base, dependent, np))
        return base, os.path.normpath(np[len(base) + 1:])

    def _installable(self, key, base, target, dependent, line):
        """Record that ``target`` installs ``dependent`` under ``data[key]``.

        ``line`` is the record's line number (not used here).
        """
        if os.path.basename(target) == self.trace_name:
            self.trace = True
        if self.trace:
            print('I] %s: b=%s t=%s d=%s' % (key, base, target, dependent))
        self._refs(target)
        base, dependent = self._normalise_path(base, dependent)
        if self.trace:
            print('I] (np) b=%s d=%s' % (base, dependent))
        if target not in self.data[key]:
            self.data[key][target] = []
        self.data[key][target] += self._header_subst(base, dependent)
        if os.path.basename(target) == self.trace_name:
            self.trace = False

    def _dirs(self, rec, line):
        """Handle a 'dirs' record (3 fields)."""
        if len(rec) != 3:
            raise error('invalid dirs rec: %d: %s' % (line, ','.join(rec)))
        self._refs(rec[1])
        self.data['dirs'] += [rec[1]]

    def _preinstall(self, rec, line):
        """Handle a 'preinstall' record (4 fields)."""
        if len(rec) != 4:
            raise error('invalid preinstall rec: %d: %s' % (line, ','.join(rec)))
        self._installable('preinstall', rec[1], rec[2], rec[3], line)

    def _tmpinstall(self, rec, line):
        """Handle a 'tmpinstall' record (4 fields)."""
        if len(rec) != 4:
            raise error('invalid tmpinstall rec: %d: %s' % (line, ','.join(rec)))
        self._installable('tmpinstall', rec[1], rec[2], rec[3], line)

    @staticmethod
    def _dedup(items):
        """Return ``items`` without repeats, keeping first-seen order."""
        seen = set()
        unique = []
        for item in items:
            if item not in seen:
                seen.add(item)
                unique.append(item)
        return unique

    def _unique_installable(self, key):
        """Remove repeated entries from each target list in ``data[key]``."""
        for h in self.data[key]:
            self.data[key][h] = self._dedup(self.data[key][h])

    def _unique(self):
        """Remove repeats from the loaded data, deterministically.

        Order is kept because move() gives a header to the first target
        and map entry that claims it.
        """
        self.data['dirs'] = self._dedup(self.data['dirs'])
        self._unique_installable('preinstall')
        self._unique_installable('tmpinstall')

    def getheaders(self):
        """Collect every ``*.h`` file below ``top`` (following symlinks)."""
        for root, dirs, files in os.walk(self.top, followlinks = True):
            for f in files:
                if fnmatch.fnmatch(f, '*.h'):
                    h = self._normal_join(root, f)[len(self.top) + 1:]
                    self.data['headers'] += [h]
                    self.header_set.add(h)
                    hbase = os.path.basename(h)
                    if hbase not in self.data['header-map']:
                        self.data['header-map'][hbase] = []
                    self.data['header-map'][hbase] += [h]

    def load(self, name):
        """Parse the record file ``name``; blank lines are skipped.

        Raises ``error`` for an unknown record type or a malformed record.
        """
        with open(name, 'r') as d:
            raw = d.readlines()
        self.name = name
        handlers = { 'dirs'      : self._dirs,
                     'preinstall': self._preinstall,
                     'tmpinstall': self._tmpinstall }
        for lc, l in enumerate(raw, 1):
            if not l.strip():
                continue
            ls = self._clean_rec(l.split(','))
            if ls[0] not in handlers:
                raise error('invalid record type: %d: %s' % (lc, ls[0]))
            handlers[ls[0]](ls, lc)
        self._unique()

    def move(self):
        """Work out where each $(PROJECT_INCLUDE) preinstalled header goes.

        For each ``(source, destination)`` prefix pair in ``data['map']``
        (in order) and each preinstall target starting with
        ``$(PROJECT_INCLUDE)``, a source header whose path starts with the
        placeholder-expanded source prefix is given the destination formed
        by replacing ``$(PROJECT_INCLUDE)`` in the target with the expanded
        destination prefix.  The first claim on a source header wins.

        Fills ``data['move']`` (per map destination: tuples of source
        header, destination header, target, source prefix),
        ``data['targets']`` (destination header -> source headers) and
        ``data['remaining']`` (targets, excluding .adb/.ads, for which no
        header was assigned a destination).
        """
        trace = False
        moved = {}
        self.data['remaining'] = sorted([p for p in self.data['preinstall'].keys()
                                         if p.startswith('$(PROJECT_INCLUDE)') and \
                                         not p.endswith(('.adb', '.ads'))])
        for m in self.data['map']:
            substs = self._subst_glob(m[0])
            for p in self.data['preinstall']:
                if p.startswith('$(PROJECT_INCLUDE)'):
                    for h in self.data['preinstall'][p]:
                        source_header = self._normal_join(h[0], h[1])
                        source_base = m[0] + '/'
                        dest_base = m[1]
                        if os.path.basename(source_header) == self.trace_name:
                            trace = True
                        if trace:
                            print()
                            print('  1| %s %s map: %s %s' % (p, source_header, source_base, dest_base))
                        # Use the first placeholder expansion whose source
                        # prefix matches this header.
                        for s in substs:
                            source_subst_base = self._subst(s[1], source_base)
                            if source_header.startswith(source_subst_base):
                                source_base = source_subst_base
                                dest_base = self._subst(s[1], dest_base)
                                if trace:
                                    print('  3| %s %s' % (source_base, dest_base))
                                break
                        if trace:
                            print('  4| %s %s' % (source_base, source_header))
                        if source_header.startswith(source_base):
                            if source_header not in moved:
                                if trace:
                                    print('  X| %s %d %s' % (source_header,
                                                             len(source_base),
                                                             source_header[len(source_base):]))
                                dest_header = p.replace('$(PROJECT_INCLUDE)', dest_base)
                                if source_header != dest_header:
                                    self.data['move'][m[1]] += [(source_header, dest_header, p, source_base)]
                                    if dest_header not in self.data['targets']:
                                        self.data['targets'][dest_header] = []
                                    self.data['targets'][dest_header] += [source_header]
                                moved[source_header] = dest_header
                                if p in self.data['remaining']:
                                    self.data['remaining'].remove(p)
                                if trace:
                                    print('  5| %s (%s) %s (%s)' % (source_header,
                                                                    source_base,
                                                                    dest_header,
                                                                    dest_base))
                            elif trace:
                                print('  6| %s %s' % (source_base, dest_base))
                        trace = False

    def report(self):
        """Return a plain-text report of the planned moves.

        Lists the remaining targets, destination headers with more than one
        source, and every planned move grouped by map destination.  Lines
        end in '\\n'; text-mode output converts it for the platform.
        """
        targets = self.data['targets']
        dups = [t for t in sorted(targets) if len(targets[t]) > 1]
        s  = 'RTEMS Includes Move\n'
        s += '===================\n'
        s += '\n'
        s += 'Preinstall\n'
        s += '----------\n'
        s += '\n'
        s += 'Remaining: %d\n' % (len(self.data['remaining']))
        for r in sorted(self.data['remaining']):
            s += ' %s\n' % (r)
        s += '\n'
        s += 'Duplicate Targets: %d\n' % (len(dups))
        for t in dups:
            s += ' %3d: %s\n' % (len(targets[t]), t)
        s += '\n'
        for t in dups:
            s += ' %s:\n' % (t)
            for h in sorted(targets[t]):
                s += '  %s\n' % (h)
            s += '\n'
        s += '\n'
        total = sum(len(self.data['move'][inc]) for inc in self.data['move'])
        s += 'Total: %d\n' % (total)
        s += '\n'
        for inc in sorted(self.data['move']):
            moves = self.data['move'][inc]
            if moves:
                s += ' %s: %d\n' % (inc, len(moves))
                width = max(len(m[0]) for m in moves)
                for m in sorted(moves):
                    s += '  %-*s => %s (%s) \n' % (width, m[0], m[1], m[2])
                s += '\n'
        return s

    def get_paths(self):
        """Return the sorted, unique directories of all move destinations."""
        paths = set()
        for inc in self.data['move']:
            for m in self.data['move'][inc]:
                paths.add(os.path.dirname(m[1]))
        return sorted(paths)

    def script(self):
        """Return a POSIX shell script that performs the moves with git.

        Moves to a destination with more than one source are skipped.  Each
        map destination is committed separately; no commit is emitted when
        nothing in the group was moved, which would otherwise fail under
        ``set -e``.
        """
        paths = self.get_paths()
        s  = '#! /bin/sh\n'
        s += '#\n'
        s += '# RTEMS Project 2017\n'
        s += '#\n'
        s += '#  Automatically generated, do not edit.\n'
        s += '#\n'
        s += '# Comment the DRY_RUN shell variable to commit the changes.\n'
        s += '#\n'
        s += 'set -x\n'
        s += 'set -e\n'
        s += 'DRY_RUN="--dry-run"\n'
        s += '\n'
        s += '# make directories\n'
        for p in paths:
            s += 'mkdir -p %s\n' % (p)
        s += '\n'
        for inc in sorted(self.data['move']):
            moves = self.data['move'][inc]
            if not moves:
                continue
            s += '# %s: %d\n' % (inc, len(moves))
            mvs = ['git mv  %s %s\n' % (m[0], m[1])
                   for m in sorted(moves)
                   if not len(self.data['targets'][m[1]]) > 1]
            s += ''.join(mvs)
            if mvs:
                s += 'git commit ${DRY_RUN} -m "preinstall: Moving to %s" ' % (inc)
                # Quoted: an unquoted '#' would start a shell comment.
                s += '-m "Update #3254."\n'
            s += '\n'
        return s

    def makefile_patch_data(self):
        """Return one 'source,dest,source-prefix,target-suffix' line per move.

        The target suffix is the target with its '$(PROJECT_INCLUDE)/'
        prefix removed.  Moves to a destination with more than one source
        are skipped, as in script().
        """
        prefix = len('$(PROJECT_INCLUDE)/')
        s = ''
        for inc in sorted(self.data['move']):
            for m in sorted(self.data['move'][inc]):
                if len(self.data['targets'][m[1]]) > 1:
                    continue
                s += '%s,%s,%s,%s\n' % (m[0], m[1], m[3], m[2][prefix:])
        return s

def run(top, name):
    """Plan the header moves for the tree at ``top`` from record file ``name``.

    Prints the report to stdout, then writes the git move script to
    'h-move' and the Makefile patch data to 'mam-patch-it-data', both
    relative to the current directory.
    """
    mover = includes(top, name)
    mover.move()
    print(mover.report())
    outputs = [('h-move', mover.script),
               ('mam-patch-it-data', mover.makefile_patch_data)]
    for path, render in outputs:
        with open(path, 'w') as out:
            out.write(render())

if __name__ == "__main__":
    # The header scan starts from the current directory, so run this from
    # the top of the source tree.
    ec = 0
    try:
        if len(sys.argv) != 2:
            raise error('invalid arguments: usage: inc-mover file')
        run(os.getcwd(), sys.argv[1])

    except error as e:
        print(e, file = sys.stderr)
        ec = 1
    except KeyboardInterrupt:
        print('user interrupted', file = sys.stderr)
        ec = 1
    except Exception:
        # Report, then re-raise so the traceback is still shown.
        print('internal error', file = sys.stderr)
        raise
    sys.exit(ec)
