"""
Copyright (c) 2021 STMicroelectronics.
All rights reserved.

This software is licensed under terms that can be found in the LICENSE file
in the root directory of this software component.
If no LICENSE file comes with this software, it is provided AS-IS.

Script to decompile the Python object code stored in Keras Lambda layers
"""


import argparse
import codecs
import io
import logging
import marshal
import os
import pickle
import re
import sys
import types
import uncompyle6
import xdis


# The following imports are needed to replicate the globals seen by the body of the lambda layer
# pylint: disable=unused-import
# pylint: disable=no-name-in-module
from keras import activations  # noqa: F401
from keras import backend as K  # noqa: F401
from keras import constraints  # noqa: F401
from keras import initializers  # noqa: F401
from keras import regularizers  # noqa: F401
from tensorflow.python.eager import backprop  # noqa: F401
from tensorflow.python.eager import context  # noqa: F401
from tensorflow.python.framework import constant_op  # noqa: F401
from tensorflow.python.framework import dtypes  # noqa: F401
from tensorflow.python.framework import ops  # noqa: F401
from tensorflow.python.framework import tensor_shape  # noqa: F401
from tensorflow.python.keras.engine import keras_tensor  # noqa: F401
from tensorflow.python.keras.engine.base_layer import Layer  # noqa: F401
from tensorflow.python.keras.engine.input_spec import InputSpec  # noqa: F401
from tensorflow.python.keras.utils import conv_utils  # noqa: F401
from tensorflow.python.keras.utils import generic_utils  # noqa: F401
from tensorflow.python.keras.utils import tf_utils  # noqa: F401
from tensorflow.python.ops import array_ops  # noqa: F401
from tensorflow.python.ops import gen_array_ops  # noqa: F401
from tensorflow.python.ops import math_ops  # noqa: F401
from tensorflow.python.ops import nn  # noqa: F401
from tensorflow.python.ops import variable_scope  # noqa: F401
from tensorflow.python.platform import tf_logging  # noqa: F401
from tensorflow.python.trackable import base as trackable  # noqa: F401
from tensorflow.python.util import dispatch  # noqa: F401
from tensorflow.python.util import nest  # noqa: F401
from tensorflow.python.util import tf_decorator  # noqa: F401
from tensorflow.python.util import tf_inspect  # noqa: F401
from tensorflow.python.util.tf_export import get_canonical_name_for_symbol  # noqa: F401
from tensorflow.python.util.tf_export import get_symbol_from_name  # noqa: F401
from tensorflow.python.util.tf_export import keras_export  # noqa: F401


def _parse_code(code, string_stream, closure_values):
    """."""         # noqa: DAR101, DAR201
    logging.debug('Parsing code %s', str(code))
    string_stream.seek(0)
    source_code = re.sub(r'^', '    ', string_stream.read(), flags=re.MULTILINE)
    arguments = list(code.co_varnames[:code.co_argcount])
    function_name = 'custom_lambda' if code.co_name == '<lambda>' else code.co_name
    function_signature = 'def ' + function_name + '(' + ','.join(arguments) + '):\n'
    function_closure = ''
    if closure_values:
        for variable, value in zip(code.co_freevars, closure_values):
            function_closure += '    ' + variable + ' = ' + str(value) + '\n'
    source_code = function_signature + function_closure + source_code
    return source_code


def _to_cell(value):
    """
    Ensure that a value is wrapped in a Python cell object.

    Cells are what ``types.FunctionType`` expects in its ``closure``
    argument (see Keras' ``func_load``, which uses the same trick).

    Parameters
    ----------
    value
        Any value; it is returned unchanged if it already is a cell.

    Returns
    -------
        A cell object holding ``value``.
    """
    def dummy_fn():
        # Referencing ``value`` makes it a free variable, so it is captured
        # in ``dummy_fn.__closure__``.
        value  # pylint: disable=pointless-statement
    cell_value = dummy_fn.__closure__[0]
    return value if isinstance(value, type(cell_value)) else cell_value


def _load_code_object(raw_code):
    """
    Unmarshal ``raw_code``, falling back to xdis for foreign bytecode formats.

    Parameters
    ----------
    raw_code
        Marshaled code object bytes.

    Returns
    -------
        The loaded code object, or ``None`` when no loader accepts the data.
    """
    try:
        return marshal.loads(raw_code)
    except (ValueError, EOFError, TypeError):
        # Not loadable by this interpreter: retry with xdis, which can read
        # the marshal formats identified by the magic numbers below.
        pass
    for magic in (62211, 62071, 62041, 3280, 3250, 3000):
        try:
            code = xdis.unmarshal.load_code(io.BytesIO(raw_code), magic)
        except Exception:  # pylint: disable=broad-except
            logging.debug('Failed to load code with code: %s', str(magic))
            continue
        logging.info('Loaded code with code: %s', str(magic))
        return code
    return None


def _recompile_with_other_versions(code, globs, defaults, closure_values, decompile_errors):
    """
    Decompile ``code`` with older bytecode versions and recompile it locally.

    The decompiled body is wrapped into a ``def`` by ``_parse_code`` (closure
    values become local assignments) and compiled with the running
    interpreter.

    Parameters
    ----------
    code
        Code object extracted from the lambda layer.
    globs
        Globals dictionary for the rebuilt function.
    defaults
        Default argument values of the original function (may be ``None``).
    closure_values
        Values of the original closure cells, or ``None``.
    decompile_errors
        Exception types meaning "cannot decompile with this version".

    Returns
    -------
        The rebuilt function, or ``None`` if every version failed.
    """
    for version_info in ((2, 7, 0), (3, 1, 0), (3, 2, 0), (3, 3, 0), (3, 4, 0),
                         (3, 5, 0), (3, 6, 0), (3, 7, 0), (3, 8, 0), (3, 9, 0)):
        logging.debug('Trying to decompile with %s', str(version_info))
        string_stream = io.StringIO()
        try:
            uncompyle6.main.decompile(version_info, code, string_stream)
        except decompile_errors:
            logging.debug('Failure decompiling lambda layer with %s', str(version_info))
            continue

        logging.debug('Success decompiling lambda layer with %s', str(version_info))
        source_code = _parse_code(code, string_stream, closure_values)
        filename = code.co_filename if code.co_filename else 'unknown'
        module_code = compile(source_code, filename, 'exec')
        # The generated module holds a single ``def``; pick its code object.
        function_code = next(const for const in module_code.co_consts
                             if isinstance(const, types.CodeType))
        # Keep the original name (e.g. '<lambda>') so both paths report the
        # same function name instead of the module's '<module>'.
        return types.FunctionType(function_code, globs, name=code.co_name, argdefs=defaults)
    return None


def _write_output(path, function):
    """
    Pickle ``[function name, marshaled function code]`` into ``path``.

    Parameters
    ----------
    path
        Destination file path.
    function
        Function whose name and code object are saved.
    """
    with open(path, 'wb') as output_file:
        pickle.dump([function.__name__, marshal.dumps(function.__code__)], output_file)


def main():
    """
    Decompile the code of a Lambda layer and save it for the current interpreter.

    Reads the pickle given with ``--input``. It holds a base64-encoded marshaled
    code object, the function defaults and the closure values. The function is
    rebuilt and ``[name, marshaled code]`` is written to ``--output``. Exits the
    process with status 0 on success and 1 on failure.
    """
    # pylint: disable=no-member
    parser = argparse.ArgumentParser(description="Decompile python object code")
    parser.add_argument('-o', '--output', help='The file where the decompiled code will be stored', required=True)
    parser.add_argument('-i', '--input', help='An input file storing the code to be decompiled', required=True)
    parser.add_argument('-d', '--debug', help='Enable the debug messages (default=False)', default=False,
                        action='store_true')

    args = parser.parse_args()
    logging.basicConfig(level=logging.DEBUG if args.debug else logging.INFO)
    logging.info('Loading %s', str(args.input))
    if not os.path.exists(args.input):
        logging.error('%s does not exist', args.input)
        sys.exit(1)
    logging.info('Size is %s', str(os.path.getsize(args.input)))
    # NOTE(review): pickle.load can execute arbitrary code; only feed this
    # script input files produced by trusted tooling.
    with open(args.input, 'rb') as input_file:
        raw_code, defaults, closure_values = pickle.load(input_file)

    closure = None
    if closure_values is not None:
        try:
            closure = tuple(_to_cell(value) for value in closure_values)
        except (TypeError, AttributeError):
            closure = None

    code = _load_code_object(codecs.decode(raw_code.encode('ascii'), 'base64'))
    if code is None:
        logging.error('Impossible to load code inside lambda layer')
        sys.exit(1)

    globs = globals().copy()
    decompile_errors = (uncompyle6.parser.ParserError, IndexError, RuntimeError, AssertionError)

    # Check if it can be decompiled with the current version of python; only
    # a native code object can be wrapped directly into a function.
    python_version = tuple(sys.version_info[:3])
    decompiled = False
    if isinstance(code, types.CodeType):
        try:
            uncompyle6.main.decompile(python_version, code, out=io.StringIO())
            decompiled = True
        except decompile_errors:
            logging.debug('Failed to decompile with version %s', str(python_version))

    if decompiled:
        logging.debug('Success decompiling lambda layer with %s (system)', str(python_version))
        function = types.FunctionType(code, globs, argdefs=defaults, closure=closure)
    else:
        # Fall back to decompiling with other versions and recompiling source.
        function = _recompile_with_other_versions(code, globs, defaults, closure_values, decompile_errors)
        if function is None:
            logging.error('Impossible to read bytecode inside lambda layer')
            sys.exit(1)

    _write_output(args.output, function)
    sys.exit(0)


if "__main__" == __name__:
    # Run the command-line entry point only when executed as a script,
    # not when this module is imported.
    main()
