#! /usr/bin/env python

# SPDX-License-Identifier: BSD-2-Clause
"""RTEMS LibBSD Kernel Symbols

Generate the symbols for the kernel headers and merge in any new ones
"""

#
# Copyright (C) 2021 Chris Johns <chrisj@rtems.org>, All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# 1. Redistributions of source code must retain the above copyright
#    notice, this list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright
#    notice, this list of conditions and the following disclaimer in the
#    documentation and/or other materials provided with the distribution.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.
#

from __future__ import print_function

import argparse
import os
import re
import sys

# Version string printed in the start-up banner.
version = "1.0"

# Object files to scan for kernel symbols. Each entry is a pair of
# (directory relative to a BSP build directory, regular expression matched
# against the start of each file name found below that directory). Raw
# strings are used so the regex escape `\.` is not treated as a (invalid)
# Python string escape.
kern_objects = [
    ( 'freebsd/sys', r'.*\.o' ),
    ( 'rtemsbsd/rtems', r'rtems-kernel-.*\.o' )
] # yapf: disable

# Regular expressions for symbols that must not appear in the generated
# header. Each pattern is compiled and applied with a regex search against
# the symbol name, so a pattern without a leading `^` (for example
# 'sys_init') matches anywhere in the name.
kern_excludes = [
    '^rtems_',
    '^accept$',
    '^arc4random$',
    '^bind$',
    '^blackhole$',
    '^bootverbose$',
    '^bpf_filter$',
    '^bpf_jitter$',
    '^bpf_jitter_enable$',
    '^bpf_validate$',
    '^cache_enter$',
    '^connect$',
    '^drop_redirect$',
    '^drop_synfin$',
    '^free$',
    '^getentropy$',
    '^getpeername$',
    '^getsockname$',
    '^getsockopt$',
    '^global_epoch$',
    '^global_epoch_preempt$',
    '^ifqmaxlen$',
    '^in6addr_any$',
    '^in6addr_linklocal_allnodes$',
    '^in6addr_loopback$',
    '^in6addr_nodelocal_allnodes$',
    '^in_epoch$',
    '^kevent$',
    '^kqueue$',
    '^listen$',
    '^malloc$',
    '^max_datalen$',
    '^max_hdr$',
    '^max_linkhdr$',
    '^max_protohdr$',
    '^maxsockets$',
    '^nd6_debug$',
    '^nd6_delay$',
    '^nd6_gctimer$',
    '^nd6_maxnudhint$',
    '^nd6_mmaxtries$',
    '^nd6_onlink_ns_rfc4861$',
    '^nd6_prune$',
    '^nd6_umaxtries$',
    '^nd6_useloopback$',
    '^net_epoch$',
    '^net_epoch_preempt$',
    '^nmbclusters$',
    '^nmbjumbo16$',
    '^nmbjumbo9$',
    '^nmbjumbop$',
    '^nmbufs$',
    '^nolocaltimewait$',
    '^path_mtu_discovery$',
    '^pause$',
    '^pf_osfp_entry_pl$',
    '^pf_osfp_pl$',
    '^pipe$',
    '^poll$',
    '^pselect$',
    '^random$',
    '^realloc$',
    '^reallocf$',
    '^recvfrom$',
    '^recvmsg$',
    '^rtems',
    '^select$',
    '^sendmsg$',
    '^sendto$',
    '^setfib$',
    '^setsockopt$',
    '^shutdown$',
    '^socket$',
    '^socketpair$',
    '^soreceive_stream$',
    '^srandom$',
    '^strdup$',
    '^sysctlbyname$',
    '^sysctl$',
    '^sysctlnametomib$',
    'sys_init',
    '^taskqueue_',
    '^tcp_offload_listen_start$',
    '^tcp_offload_listen_stop$',
    '^ticks$',
    '^useloopback$',
    '^_Watchdog_Ticks_since_boot$'
] # yapf: disable

# Default path of the kernel namespace header; used as the default for both
# the header that is read (--kern-header) and the header written (--output).
kern_header = 'rtemsbsd/include/machine/rtems-bsd-kernel-namespace.h'


class exit(Exception):
    """Raised to leave run() early; `code` is the process exit code."""
    def __init__(self, code):
        # The code is handed to sys.exit() by the handler in run().
        self.code = code


class error(Exception):
    """Base class for the tool's errors; carries a printable message."""
    def set_output(self, msg):
        """Record msg as the text returned by str() for this exception."""
        self.msg = msg

    def __str__(self):
        """Return the recorded message, or the default exception text."""
        # set_output() may never have been called, for example when this
        # class is raised directly. Fall back to the standard exception text
        # rather than failing with AttributeError while reporting an error.
        try:
            return self.msg
        except AttributeError:
            return Exception.__str__(self)


class general_error(error):
    """General failure whose message is the text of `what` after 'error: '."""
    def __init__(self, what):
        message = 'error: ' + str(what)
        self.set_output(message)


class command:
    """Run an external command and capture its standard output.

    After run() returns, `exit_code` holds the command's exit status (0 on
    success) and `output` holds the captured standard output split into a
    list of lines. `output` is left as None if the run was interrupted from
    the keyboard before any output was collected.
    """
    def __init__(self, cmd, cwd='.'):
        """Record the command to run.

        cmd -- the command and its arguments as a list of strings.
        cwd -- the directory the command is run in.
        """
        self.exit_code = 0
        self.output = None
        self.cmd = cmd
        self.cwd = cwd
        self.result = None

    def run(self):
        """Run the command, setting `exit_code` and `output`.

        A non-zero exit status is not an error here; it is recorded in
        `exit_code` together with whatever output was captured. Raises
        general_error if the command cannot be started.
        """

        import subprocess

        #
        # Support Python 2.6
        #
        if "check_output" not in dir(subprocess):

            def f(*popenargs, **kwargs):
                if 'stdout' in kwargs:
                    raise ValueError(
                        'stdout argument not allowed, it will be overridden.')
                process = subprocess.Popen(stdout=subprocess.PIPE,
                                           *popenargs,
                                           **kwargs)
                output, unused_err = process.communicate()
                retcode = process.poll()
                if retcode:
                    cmd = kwargs.get("args")
                    if cmd is None:
                        cmd = popenargs[0]
                    raise subprocess.CalledProcessError(retcode, cmd)
                return output

            subprocess.check_output = f

        self.exit_code = 0
        try:
            if os.name == 'nt':
                # NOTE(review): `sh -c` takes only the first following
                # argument as the command string; later list elements become
                # positional parameters. Confirm this is the intended
                # invocation on Windows hosts.
                cmd = ['sh', '-c'] + self.cmd
            else:
                cmd = self.cmd
            # The command must be run for both branches above; previously
            # this was nested under the `else` and never ran on Windows.
            output = subprocess.check_output(cmd,
                                             cwd=self.cwd).decode("utf-8")
            self.output = output.split(os.linesep)
        except subprocess.CalledProcessError as cpe:
            self.exit_code = cpe.returncode
            # The Python 2.6 fallback raises without any output attached, so
            # do not assume the attribute exists or is set.
            failed = getattr(cpe, 'output', None) or b''
            self.output = failed.decode("utf-8").split(os.linesep)
        except OSError as ose:
            cs = ' '.join(cmd)
            if len(cs) > 80:
                cs = cs[:80] + '...'
            raise general_error('bootstrap failed: %s in %s: %s' % \
                                (cs, self.cwd, (str(ose))))
        except KeyboardInterrupt:
            # Deliberately ignored; `output` keeps its previous value.
            pass


class kernel_symbols:
    """Collect kernel symbols and generate the kernel namespace header.

    Symbols come from two places: the `#define` lines of an existing header
    (load_header) and the nm output for the object files of each BSP found
    in a build directory (load_symbols). The merged, de-duplicated and
    sorted symbols are turned into header source by generate_header.

    Attributes:
      bsd_tag  -- prefix added to each symbol in the generated defines.
      excludes -- compiled regular expressions; a symbol matching any of
                  them (regex search) is dropped by _clean.
      bsps     -- per BSP dict with the keys 'output' (nm output lines),
                  'objects' (object file paths) and 'symbols'.
      header   -- 'source' lines and 'symbols' of the loaded header.
      output   -- 'source' lines and 'symbols' of the generated header.
      analysis -- 'mapped', 'unmapped' and 'new' symbol lists (analyse).
    """

    # nm symbol type letters accepted when parsing the nm output.
    _nm_types = frozenset(['A', 'B', 'C', 'D', 'R', 'T', 'W'])

    def __init__(self, excludes):
        """Create an empty symbol set.

        excludes -- iterable of regular expression strings for symbols that
                    are to be excluded.
        """
        self.bsd_tag = '_bsd_'
        self.excludes = [re.compile(exc) for exc in excludes]
        self.bsps = {}
        self.header = {'source': [], 'symbols': []}
        self.output = {'source': [], 'symbols': []}
        self.analysis = {'mapped': [], 'unmapped': [], 'new': []}

    @staticmethod
    def _find(base, spec):
        """Return the files below base/spec[0] whose name matches spec[1].

        spec[1] is a regular expression matched from the start of the file
        name. The directory tree is walked recursively.
        """
        pattern = re.compile(spec[1])
        found = []
        for root, _dirs, files in os.walk(os.path.join(base, spec[0]),
                                         topdown=True):
            found.extend(
                os.path.join(root, name) for name in files
                if pattern.match(name))
        return found

    @staticmethod
    def _find_bsps(build):
        """Return the names of the BSP directories directly below build.

        A directory is taken to be a BSP if its name has the form
        `<arch>-rtems<digit>...-<name>`.
        """
        pattern = re.compile('^.*-rtems[0-9].*-.*')
        return [
            name for name in os.listdir(build)
            if os.path.isdir(os.path.join(build, name))
            and pattern.match(name) is not None
        ]

    @staticmethod
    def bsp_arch(bsp):
        """Return the first two '-' separated fields of a BSP name."""
        bs = bsp.split('-')
        return bs[0] + '-' + bs[1]

    def _clean(self, symbols):
        """Return symbols sorted, without duplicates or excluded names."""
        kept = set(sym for sym in symbols
                   if not any(exc.search(sym) is not None
                              for exc in self.excludes))
        return sorted(kept)

    def load_header(self, header):
        """Load the header file and collect the symbols it defines.

        Only lines of the form `#define <symbol> <value>` (exactly three
        whitespace separated fields) contribute a symbol.
        """
        with open(header, 'r') as h:
            self.header['source'] = h.read().splitlines()
        define = re.compile(r'^#define\s')
        symbols = []
        for line in self.header['source']:
            if define.match(line) is not None:
                fields = line.split()
                if len(fields) == 3:
                    symbols.append(fields[1])
        self.header['symbols'] = self._clean(self.header['symbols'] + symbols)

    def _parse_nm(self, bsp, lines):
        """Return the symbols named in nm output lines, in order of appearance.

        The BSD tag prefix is removed from a symbol if present. A warning is
        printed when a symbol is seen more than once; the duplicate is still
        returned. bsp is only used in the warning text.
        """
        obj = '-'
        syms = []
        # A set keeps the duplicate check cheap for large symbol lists.
        seen = set()
        for line in lines:
            line = line.strip()
            if not line:
                continue
            if line.endswith(':'):
                # A `<path>:` line names the object the next symbols are in.
                obj = os.path.basename(line[:-1])
                continue
            fields = line.split()
            if len(fields) == 3:
                # Drop the leading value field, leaving the type and name.
                fields = fields[1:]
            if len(fields) < 2 or fields[0] not in self._nm_types:
                continue
            name = fields[1]
            sym = name
            if sym.startswith(self.bsd_tag):
                sym = sym[len(self.bsd_tag):]
            if sym in seen:
                print('warning: duplicate symbol: %s:%s: %s (%s)' %
                      (bsp, obj, sym, name))
            seen.add(sym)
            syms.append(sym)
        return syms

    def load_symbols(self, specs, excludes, build, tools):
        """Run nm on the objects of every BSP in build and collect symbols.

        specs    -- list of (directory, file name regex) pairs selecting the
                    object files below each BSP build directory.
        excludes -- not used; kept so existing callers continue to work.
        build    -- path to the build output directory holding the BSPs.
        tools    -- tools prefix; nm is run as <tools>/bin/<arch>-nm, or as
                    <arch>-nm when tools is None.
        """
        for bsp in self._find_bsps(build):
            data = {'output': [], 'objects': [], 'symbols': []}
            self.bsps[bsp] = data
            bsp_build = os.path.join(build, bsp)
            for spec in specs:
                data['objects'] += self._find(bsp_build, spec)
            nm_cmd = self.bsp_arch(bsp) + '-nm'
            if tools is not None:
                nm_cmd = os.path.join(tools, 'bin', nm_cmd)
            nm = command([nm_cmd] + data['objects'])
            nm.run()
            if nm.exit_code != 0:
                # Whatever output was captured is still parsed below.
                print('warning: %s: %s: exit code %d' %
                      (bsp, nm_cmd, nm.exit_code))
            # The command leaves its output as None if it was interrupted.
            data['output'] = nm.output if nm.output is not None else []
            data['symbols'] += self._parse_nm(bsp, data['output'])

    def generate_header(self):
        """Generate the header source lines from the output symbols."""
        preamble = [
            '/*', '* RTEMS Libbsd, this file is generated. Do not edit.', '*/',
            '#ifndef _RTEMS_BSD_MACHINE_RTEMS_BSD_KERNEL_SPACE_H_',
            '#error "the header file <machine/rtems-bsd-kernel-space.h> must be included first"',
            '#endif', ''
        ]
        defines = [
            '#define\t%s %s%s' % (sym, self.bsd_tag, sym)
            for sym in self.output['symbols']
        ]
        self.output['source'] = preamble + defines

    def write_header(self, header):
        """Write the generated source to the file header, UTF-8 encoded."""
        with open(header, 'wb') as o:
            o.write(
                os.linesep.join(self.output['source'] + ['']).encode("utf-8"))

    def write_sym_data(self):
        """Write each BSP's nm output to sym-data-<arch>.txt in the cwd.

        BSPs that share an architecture write to the same file name.
        """
        for bsp in self.bsps:
            arch = self.bsp_arch(bsp)
            with open('sym-data-' + arch + '.txt', 'w') as o:
                # The file is opened in text mode, which already translates
                # '\n' to the platform line ending; joining with os.linesep
                # would double the carriage return on Windows.
                o.write('\n'.join(self.bsps[bsp]['output'] or []))

    def merge(self, symbols):
        """Merge symbols into the output symbols."""
        self.output['symbols'] = \
            self._clean(self.output['symbols'] + symbols)

    def merge_bsp(self):
        """Merge the symbols of every loaded BSP into the output symbols."""
        for bsp in self.bsps:
            self.merge(self.bsps[bsp]['symbols'])

    def analyse(self):
        """Classify symbols as 'mapped', 'new' or 'unmapped'.

        A BSP symbol is 'mapped' if it is in the loaded header, otherwise it
        is 'new'. A header symbol that is not mapped is 'unmapped'.
        """
        known = set(self.header['symbols'])
        for bsp in self.bsps:
            for sym in self.bsps[bsp]['symbols']:
                key = 'mapped' if sym in known else 'new'
                self.analysis[key].append(sym)
        for key in self.analysis:
            self.analysis[key] = self._clean(self.analysis[key])
        mapped = set(self.analysis['mapped'])
        self.analysis['unmapped'] = [
            sym for sym in self.header['symbols'] if sym not in mapped
        ]

    def diff(self):
        """Return a unified diff of the loaded and generated header source."""
        import difflib
        return list(
            difflib.unified_diff(self.header['source'], self.output['source']))

    def report(self):
        """Return the report as a list of lines.

        The per BSP column headed 'Unmapped' holds the number of unique,
        non-excluded symbols of the BSP and 'Total' the number of symbols
        read for it.
        """
        out = [
            'Symbols:',
            ' header   : %d' % (len(self.header['symbols'])),
            ' mapped   : %d' % (len(self.analysis['mapped'])),
            ' unmapped : %d' % (len(self.analysis['unmapped'])),
            ' new      : %d' % (len(self.analysis['new']))
        ]
        # Width of the BSP name column; 0 when there are no BSPs.
        max_len = max([len(bsp) for bsp in self.bsps] or [0])
        out.append('BSPs: %*s Unmapped Total' % (max_len - 4, ' '))
        for bsp in self.bsps:
            symbols = self.bsps[bsp]['symbols']
            out.append(' %-*s: %-8d %d' %
                       (max_len, bsp, len(self._clean(symbols)), len(symbols)))
        out += ['New:'] + [' ' + sym for sym in self.analysis['new']]
        out += ['Unmapped:'] + [' ' + sym for sym in self.analysis['unmapped']]
        return out


def _write_if_changed(ks, diff, path):
    """Write the generated header to path unless diff shows no changes."""
    if len(diff) == 0:
        print('info: no changes; header not updated')
    else:
        print('info: writing: %s' % (path))
        ks.write_header(path)


def run(args):
    """Command line entry point; always ends by calling sys.exit().

    args -- the full argument vector; args[0] is the program name and is
            skipped when parsing.

    The exit code is 0 on success and 1 if a general_error is raised or the
    user interrupts the run.
    """
    try:
        argsp = argparse.ArgumentParser(
            prog='rtems-kern-symbols',
            description="RTEMS LibBSD Kernel Symbols")
        argsp.add_argument('-t',
                           '--rtems-tools',
                           help='RTEMS Tools (default: %(default)s).',
                           type=str,
                           default=None)
        argsp.add_argument(
            '-w',
            '--write',
            action='store_true',
            help=
            'Write the header to the output file name (default: %(default)s).')
        argsp.add_argument(
            '-d',
            '--diff',
            action='store_true',
            help='Show a diff if the header has changed (default: %(default)s).'
        )
        argsp.add_argument(
            '-o',
            '--output',
            type=str,
            default=kern_header,
            help='output path to the write the header (default: %(default)s).')
        argsp.add_argument(
            '-b',
            '--build',
            type=str,
            default='build',
            help='path to the rtems libbsd build output (default: %(default)s).'
        )
        argsp.add_argument(
            '-K',
            '--kern-header',
            type=str,
            default=kern_header,
            help=
            'kernel header to load existing symbols from(default: %(default)s).'
        )
        argsp.add_argument(
            '-S',
            '--sym-data',
            action="store_true",
            help=
            'Write the BSP symbol data that is parsed (default: %(default)s).')
        argsp.add_argument(
            '-r',
            '--regenerate',
            action="store_true",
            help=
            'Regenerate the header file from the symbols in it, write option ignored (default: %(default)s).'
        )
        argsp.add_argument('-R',
                           '--report',
                           action="store_true",
                           help='Generate a report (default: %(default)s).')
        argopts = argsp.parse_args(args[1:])

        print('RTEMS LibBSD Kernel Symbols, %s' % (version))

        if not os.path.exists(argopts.build):
            raise general_error('path does not exist: %s' % (argopts.build))

        ks = kernel_symbols(kern_excludes)

        ks.load_header(argopts.kern_header)

        if argopts.regenerate:
            # Rebuild the header from its own symbols only; the build output
            # is not scanned and the header is written whenever it differs.
            ks.merge(ks.header['symbols'])
            print('Regenerating: symbols: %d: %s' %
                  (len(ks.output['symbols']), argopts.output))
            ks.generate_header()
            _write_if_changed(ks, ks.diff(), argopts.output)
            raise exit(0)

        ks.load_symbols(kern_objects, kern_excludes, argopts.build,
                        argopts.rtems_tools)

        if argopts.sym_data:
            ks.write_sym_data()

        ks.analyse()
        ks.merge(ks.header['symbols'])
        ks.merge_bsp()
        ks.generate_header()

        diff = ks.diff()
        if argopts.write:
            _write_if_changed(ks, diff, argopts.output)

        if argopts.report:
            print(os.linesep.join(ks.report()))

        if argopts.diff:
            print('Diff: %d' % (len(diff)))
            print(os.linesep.join(diff))

    except general_error as gerr:
        print(gerr)
        print('RTEMS Kernel Symbols FAILED', file=sys.stderr)
        sys.exit(1)
    except KeyboardInterrupt:
        # There is no `log` module in this file; report on stderr instead of
        # failing with a NameError while handling the interrupt.
        print('abort: user terminated', file=sys.stderr)
        sys.exit(1)
    except exit as ec:
        sys.exit(ec.code)
    sys.exit(0)


# Run the tool only when executed as a script; run() exits the process.
if __name__ == "__main__":
    run(sys.argv)
