"""
Copyright (c) 2021 STMicroelectronics.
All rights reserved.

This software is licensed under terms that can be found in the LICENSE file
in the root directory of this software component.
If no LICENSE file comes with this software, it is provided AS-IS.

Custom DoReFa quantizer (STCustomDoReFa) built on larq, registered as a Keras
custom object.
"""

import larq as lq
import tensorflow as tf


@lq.utils.register_keras_custom_object
class STCustomDoReFa(lq.quantizers.DoReFa):
    """DoReFa quantizer variant with signed activations and symmetric weights.

    In 'activations' mode the inputs are clipped to
    [-1, 1 - 2**-(k_bit - 1)] and rounded to multiples of 2**-(k_bit - 1)
    (for k_bit=8: steps of 1/128).

    In 'weights' mode the inputs go through tanh, are normed by their maximum
    absolute value, quantized on [0, 1] with 2**k_bit - 2 steps and mapped
    back to [-1, 1] (for k_bit=8: integer levels [-127, 127] over 127).

    In both modes the rounding step uses an identity gradient.
    """

    def __init__(self, k_bit: int = 8, mode: str = "activations", **kwargs):
        """Validate the configuration and initialise the parent quantizer.

        Args:
            k_bit: Number of bits of the quantized representation. Must be at
                least 1 in 'activations' mode and at least 2 in 'weights'
                mode.
            mode: Either 'activations' or 'weights'.
            **kwargs: Forwarded unchanged to the parent constructor.

        Raises:
            ValueError: If `mode` is not a supported value or `k_bit` is too
                small for the selected mode.
        """
        if mode not in ('activations', 'weights'):
            raise ValueError(
                f"Invalid Custom DoReFa quantizer mode {mode}. "
                "Valid values are 'activations' and 'weights'."
            )

        # The weight scale is 2**k_bit - 2, which is 0 for k_bit == 1 and
        # would turn every quantized weight into NaN (division by zero).
        # Below 1 bit the activation clipping range is no longer meaningful.
        min_bits = 2 if mode == 'weights' else 1
        if k_bit < min_bits:
            raise ValueError(
                f"Invalid Custom DoReFa quantizer precision {k_bit} for mode "
                f"'{mode}'. k_bit must be at least {min_bits}."
            )

        self.mode = mode
        self.precision = k_bit

        super().__init__(k_bit=k_bit, mode=mode, **kwargs)

    def activation_preprocess(self, inputs):
        """Clip activations to the representable signed range.

        Args:
            inputs: Tensor of activations.

        Returns:
            `inputs` clipped to [-1, 1 - 2**-(k_bit - 1)].
        """
        # One quantization step. The upper bound sits one step below 1.0 so
        # the largest level after rounding is 2**(k_bit - 1) - 1 steps.
        r_eps = 1.0 / (2**(self.precision - 1.0))
        return tf.clip_by_value(inputs, -1.0, 1.0 - r_eps)

    def weight_preprocess(self, inputs):
        """Squash and norm weights into the [0, 1] range.

        Args:
            inputs: Tensor of weights.

        Returns:
            tanh(inputs) divided by twice its maximum absolute value and
            shifted by 0.5; all 0.5 when every squashed weight is zero.
        """
        # Limit inputs to the [-1, 1] range.
        limited = tf.math.tanh(inputs)

        # Divisor for the max-value norm. It is marked as a constant with
        # tf.stop_gradient: otherwise the gradient of limited / max_abs would
        # also flow through max_abs (which depends on `limited`), because TF
        # does not simplify x / x to 1 for the maximum element.
        max_abs = tf.stop_gradient(tf.math.reduce_max(tf.math.abs(limited)))

        # Norm and scale from [-1, 1] to [0, 1], the range expected by the
        # core quantization step. When max_abs is 0 every weight is 0 and
        # divide_no_nan returns 0, so nothing is normed.
        return tf.math.divide_no_nan(limited, 2.0 * max_abs) + 0.5

    @staticmethod
    def _round_with_identity_grad(values, scale):
        """Round `values` to multiples of 1/`scale` with an identity gradient.

        Args:
            values: Tensor to quantize.
            scale: Number of quantization steps per unit.

        Returns:
            round(values * scale) / scale; its gradient passes `dy` through
            unchanged.
        """
        @tf.custom_gradient
        def _quantize(tensor):
            return tf.round(tensor * scale) / scale, lambda dy: dy

        return _quantize(values)

    def call(self, inputs):
        """Quantize `inputs` according to the configured mode.

        Args:
            inputs: Tensor of activations or weights.

        Returns:
            In 'activations' mode, the clipped inputs rounded to multiples of
            2**-(k_bit - 1). In 'weights' mode, the preprocessed weights
            rounded on [0, 1] with 2**k_bit - 2 steps and mapped to [-1, 1].

        Raises:
            ValueError: If `self.mode` is neither 'activations' nor 'weights'.
        """
        # NOTE(review): this override does not chain to super().call(); if
        # the parent quantizer does extra work there (e.g. metrics), it is
        # skipped - confirm this is intended.
        if self.mode == 'activations':
            clipped = self.activation_preprocess(inputs)
            # Activation scale, e.g. 1/128 for 8 bits, zero point 0.
            scale = 2 ** (self.precision - 1)
            return self._round_with_identity_grad(clipped, scale)

        if self.mode == 'weights':
            normed = self.weight_preprocess(inputs)
            # Symmetric quantization: 2**k_bit - 2 steps on [0, 1], e.g.
            # 254 steps for 8 bits.
            scale = (2 ** self.precision) - 2
            quantized = self._round_with_identity_grad(normed, scale)
            # [0, 1] -> [-1, 1]; for 8 bits [0, 254] -> [-127, 127].
            return 2.0 * quantized - 1.0

        raise ValueError(
            f"Invalid DoReFa quantizer mode {self.mode}. "
            "Valid values are 'activations' and 'weights'."
        )

    def get_config(self):
        """Return the parent config extended with 'k_bit' and 'mode'.

        Returns:
            Dictionary holding the parent configuration plus the 'k_bit' and
            'mode' entries of this quantizer.
        """
        return {
            **super().get_config(),
            'k_bit': self.precision,
            'mode': self.mode,
        }
