{#
  ******************************************************************************
  * @file    main_app.j2.c
  * @author  AST Embedded Analytics Research Platform
  * @brief   AI Tool Automatic Code Generator for Embedded NN computing
  ******************************************************************************
  * @attention
  *
  * Copyright (c) 2023 STMicroelectronics.
  * All rights reserved.
  *
  * This software is licensed under terms that can be found in the LICENSE file
  * in the root directory of this software component.
  * If no LICENSE file comes with this software, it is provided AS-IS.
  ******************************************************************************
#}
{%- import 'network_layers.j2.c' as layers -%}
{% set _activation_buffers = config[layers.ACTIVATIONS] -%}
{% set _state_buffers = config[layers.STATES] -%}
{% set _weight_buffers = config[layers.WEIGHTS] -%}
{% set net_name = config['net_name'].lower() -%}
{% set NET_NAME = config['net_name'].upper() -%}


/**
  ******************************************************************************
  * @file    stai_main_app.c
  * @author  AST Embedded Analytics Research Platform
  * @date    {{ config['date_time'] }}
  * @brief   AI Tool Automatic Code Generator for Embedded NN computing
  ******************************************************************************
  * {{ config['copyright'] }}
  * All rights reserved.
  *
  * This software is licensed under terms that can be found in the LICENSE file
  * in the root directory of this software component.
  * If no LICENSE file comes with this software, it is provided AS-IS.
  ******************************************************************************
  */

#define EXPORT_RUNTIME_LITE_APIS

#ifndef MINIMAL_COMPILE
#include <inttypes.h>
#include <stdio.h>
#include <string.h>
#endif

#include "stai.h"
#include "{{ net_name }}_inputs.h"
#include "{{ net_name }}.h"

/* No lite-inspection callback is installed by this harness. */
#define my_lite_inspect_cb      (NULL)

/*
 * LOG_PRINT(fmt, ...): printf-style logging followed by an immediate flush of
 * stdout. The expansion carries its own terminating ';' so that call sites
 * written without a trailing semicolon still compile; therefore it must not be
 * used as the sole body of an un-braced if/else.
 */
#if defined(MINIMAL_COMPILE)
/* Logging compiled out: the arguments are discarded and never evaluated. */
#define LOG_PRINT(fmt, ...) \
  do { } while (0);
#else
#define LOG_PRINT(fmt, ...) \
  do { printf(fmt, ##__VA_ARGS__); fflush(stdout); } while (0);
#endif

/**
  * @brief  Minimal decimal string to int32_t conversion (no libc dependency).
  *
  * Accepts an optional leading '+' or '-' followed by decimal digits, and stops
  * at the first non-digit character (like atoi). The magnitude saturates at
  * 0x7FFFFFFF instead of overflowing, because signed overflow is undefined behavior.
  *
  * @param  str  NUL-terminated string; may be NULL.
  * @retval The parsed value; 0 when str is NULL or has no leading digits.
  */
int32_t simple_atoi(const char* str)
{
   if (!str) {
      return 0;
   }

   int i = 0;
   int negative = 0;
   if (str[0] == '-' || str[0] == '+') {
      negative = (str[0] == '-');
      i = 1;
   }

   int32_t ret = 0;
   for (; str[i] >= '0' && str[i] <= '9'; i++) {
      const int32_t digit = (int32_t)(str[i] - '0');
      /* Would ret * 10 + digit exceed INT32_MAX? Saturate instead. */
      if (ret > (0x7FFFFFFF - digit) / 10) {
         ret = 0x7FFFFFFF;
         break;
      }
      ret = ret * 10 + digit;
   }
   return negative ? -ret : ret;
}

/*
 * Markers called immediately before and after stai_<net>_run() in main().
 * noinline/noclone keep each one a single, distinct out-of-line symbol, and the
 * nop(s) give each a non-empty body. NOTE(review): presumably an external
 * harness (debugger/simulator) breaks or traces on these symbols to time the
 * inference; the differing nop counts also keep the two bodies from being
 * identical (possibly to defeat identical-code folding). Confirm with the harness.
 * Declared with (void) so they have real prototypes.
 */
__attribute__((noinline,noclone,optimize("O3"))) void th_signal_start(void) {  __asm__("nop");}
__attribute__((noinline,noclone,optimize("O3"))) void th_signal_stop(void)  {  __asm__("nop");__asm__("nop");}


int main(int argc, char *argv[])
{
#ifndef MINIMAL_COMPILE
   int32_t sample = simple_atoi(argv[1]);
   if (sample == 0)
   {
      LOG_PRINT("__START_SELF_INSPECTION__\n")
      LOG_PRINT("name:{{ net_name }}\n");
      LOG_PRINT("n_inputs:{{ config['c_net_in']|length }}\n");
      {% for input in config['c_net_in'] -%}
         {# using jinja filter to retrieve shape length from get_value() input tuple -#}
         {% set _values = input['shape'].get_values() -%}
         {% if _values|length == 6 -%}
      LOG_PRINT("inputtensor_{{ loop.index }}:({{ _values[5] }},{{ _values[1] }},{{ _values[2] }},{{ _values[3] }},{{ _values[4] }},{{ _values[0] }})#{{ input['fmt'].get_c_format() }}#{{ input['array_byte_size']}}#{{ input['fmt'].get_scale()}}#{{ input['fmt'].get_zero()}}\n");
         {% elif _values|length == 5 -%}
      LOG_PRINT("inputtensor_{{ loop.index }}:({{ _values[4] }},{{ _values[1] }},{{ _values[2] }},{{ _values[3] }},{{ _values[0] }})#{{ input['fmt'].get_c_format() }}#{{ input['array_byte_size']}}#{{ input['fmt'].get_scale()}}#{{ input['fmt'].get_zero()}}\n");
         {% else -%}
      LOG_PRINT("inputtensor_{{ loop.index }}:({{ _values[3] }},{{ _values[1] }},{{ _values[2] }},{{ _values[0] }})#{{ input['fmt'].get_c_format() }}#{{ input['array_byte_size']}}#{{ input['fmt'].get_scale()}}#{{ input['fmt'].get_zero()}}\n");
         {% endif %}
      {%- endfor %}
      LOG_PRINT("n_outputs:{{ config['c_net_out']|length }}\n");
      {% for output in config['c_net_out'] -%}
         {# using jinja filter to retrieve shape length from get_value() output tuple -#}
         {% set _values = output['shape'].get_values() -%}
         {% if _values|length == 6 -%}
      LOG_PRINT("outputtensor_{{ loop.index }}:({{ _values[5] }},{{ _values[1] }},{{ _values[2] }},{{ _values[3] }},{{ _values[4] }},{{ _values[0] }})#{{ output['fmt'].get_c_format() }}#{{ output['array_byte_size']}}#{{ output['fmt'].get_scale()}}#{{ output['fmt'].get_zero()}}\n");
         {% elif _values|length == 5 -%}
      LOG_PRINT("outputtensor_{{ loop.index }}:({{ _values[4] }},{{ _values[1] }},{{ _values[2] }},{{ _values[3] }},{{ _values[0] }})#{{ output['fmt'].get_c_format() }}#{{ output['array_byte_size']}}#{{ output['fmt'].get_scale()}}#{{ output['fmt'].get_zero()}}\n");
         {% else -%}
      LOG_PRINT("outputtensor_{{ loop.index }}:({{ _values[3] }},{{ _values[1] }},{{ _values[2] }},{{ _values[0] }})#{{ output['fmt'].get_c_format() }}#{{ output['array_byte_size']}}#{{ output['fmt'].get_scale()}}#{{ output['fmt'].get_zero()}}\n");
         {% endif -%}
      {% endfor %}
      LOG_PRINT("n_nodes:{{ config[layers.LAYERS]|length }}\n");
      LOG_PRINT("activations:{{ _activation_buffers['size'] }}\n");
      LOG_PRINT("weights:{{ _weight_buffers['size'] }}\n");
      LOG_PRINT("runtime_name:STM.AI\n");
      LOG_PRINT("macc:{{ config['Macc'] }}\n");
      LOG_PRINT("runtime_version:%d.%d.%d\n", STAI_TOOLS_VERSION_MAJOR, STAI_TOOLS_VERSION_MINOR, STAI_TOOLS_VERSION_MICRO);
      LOG_PRINT("runtime_tools_version:%d.%d.%d\n", STAI_TOOLS_VERSION_MAJOR, STAI_TOOLS_VERSION_MINOR, STAI_TOOLS_VERSION_MICRO);
      LOG_PRINT("__STOP_SELF_INSPECTION__\n")

      return 0;
   }
   sample -= 1;
#endif

   /*  Declare and allocate memory for private network context  */
   STAI_NETWORK_CONTEXT_DECLARE(network, STAI_{{ NET_NAME }}_CONTEXT_SIZE)
   stai_return_code return_code = STAI_SUCCESS;

   stai_runtime_init();

   {%- if not config['allocate_activations'] %}
   {% if _activation_buffers['size'] > 0 %}
   /*  Declare activations buffer pointers array  */
   stai_ptr activation_buffers[STAI_{{ NET_NAME }}_ACTIVATIONS_NUM] = {0};

   {%- for _activation in _activation_buffers['buffers'] -%}
   /*  Allocate and set activation buffer #{{loop.index}}  */
   STAI_ALIGNED(STAI_{{ NET_NAME }}_ACTIVATION_{{loop.index}}_ALIGNMENT)
   uint8_t activation{{loop.index}}[STAI_{{ NET_NAME }}_ACTIVATION_{{loop.index}}_SIZE] = {0};
   activation_buffers[{{loop.index0}}] = (stai_ptr)(activation{{loop.index}});
   {% endfor %}
   {% endif %}
   {% endif %}

   {%- if not config['allocate_states'] %}
   {% if _state_buffers['size'] > 0 %}
   /*  Declare states buffer pointers array  */
   stai_ptr state_buffers[STAI_{{ NET_NAME }}_STATES_NUM] = {0};

   {%- for _state in _state_buffers['buffers'] -%}
   /*  Allocate and set state buffer #{{loop.index}}  */
   STAI_ALIGNED(STAI_{{ NET_NAME }}_STATE_{{loop.index}}_ALIGNMENT)
   uint8_t state{{loop.index}}[STAI_{{ NET_NAME }}_STATE_{{loop.index}}_SIZE] = {0};
   state_buffers[{{loop.index0}}] = (stai_ptr)(state{{loop.index}});
   {% endfor %}
   {% endif %}
   {% endif %}

   {% if config['allocate_outputs'] -%}
#ifndef MINIMAL_COMPILE
   {% endif -%}
   {% for _output in config['c_net_out'] %}
   /*  Allocate and declare output buffer #{{loop.index}}  */
   {{ _output['fmt'].get_c_type() }} output{{loop.index}}[{{ _output['padded_elements'] }}] __attribute__((aligned(4)));
   {% endfor %}
   {% if config['allocate_outputs'] -%}
#endif
   {%- endif -%}

   /*  Declare inputs buffer pointers array  */
   {% if config['allocate_inputs'] -%}
#ifndef MINIMAL_COMPILE
   {% endif -%}
   stai_ptr input_buffers[STAI_{{ NET_NAME }}_IN_NUM] = {
   {%- for _input in config['c_net_in'] %}
#ifndef MINIMAL_COMPILE
      (stai_ptr)input{{loop.index}}[sample]
#else
      (stai_ptr)&input{{loop.index}}
#endif
      {% if not loop.last -%}
         ,
      {%- endif -%}
   {%- endfor -%}
   };
   {% if config['allocate_inputs'] -%}
#endif
   {%- endif -%}

   /*  Declare outputs buffer pointers array  */
   {% if config['allocate_outputs'] -%}
#ifndef MINIMAL_COMPILE
   {% endif -%}
   stai_ptr output_buffers[STAI_{{ NET_NAME }}_OUT_NUM] = {
   {%- for _output in config['c_net_out'] %}
      (stai_ptr)&output{{loop.index}}
      {% if not loop.last -%}
          ,
      {%- endif -%}
   {%- endfor -%}
   };
   {% if config['allocate_outputs'] -%}
#endif
   {%- endif -%}

   /*  Initialize network context  */
   return_code = stai_{{ net_name }}_init(network);
   if (return_code != STAI_SUCCESS) {
      LOG_PRINT("  ## Test Failed executing stai init: 0x%x.\n\n", return_code)
      return -1;
   }
   {%- if not config['allocate_activations'] %}
   {% if _activation_buffers['size'] > 0 %}
   /*  Set network activations buffers  */
   return_code = stai_{{ net_name }}_set_activations(network, activation_buffers, STAI_{{ NET_NAME }}_ACTIVATIONS_NUM);
   if (return_code != STAI_SUCCESS) {
      LOG_PRINT("  ## Test Failed executing stai set activations: 0x%x.\n\n", return_code)
      return -1;
   }
   {%- endif -%}
   {%- endif -%}

   {%- if not config['allocate_states'] %}
   {% if _state_buffers['size'] > 0 %}
   /*  Set network states buffers  */
   return_code = stai_{{ net_name }}_set_states(network, state_buffers, STAI_{{ NET_NAME }}_STATES_NUM);
   if (return_code != STAI_SUCCESS) {
      LOG_PRINT("  ## Test Failed executing stai set states: 0x%x.\n\n", return_code)
      return -1;
   }
   {%- endif -%}
   {%- endif -%}

   {% if not config['allocate_inputs'] -%}
   /*  Set network inputs buffers  */
   return_code = stai_{{ net_name }}_set_inputs(network, input_buffers, STAI_{{ NET_NAME }}_IN_NUM);
   if (return_code != STAI_SUCCESS) {
      LOG_PRINT("  ## Test Failed executing stai set inputs: 0x%x.\n\n", return_code)
      return -1;
   }
   {%- endif %}
   {% if not config['allocate_outputs'] -%}
   /*  Set network outputs buffers  */
   return_code = stai_{{ net_name }}_set_outputs(network, output_buffers, STAI_{{ NET_NAME }}_OUT_NUM);
   if (return_code != STAI_SUCCESS) {
      LOG_PRINT("  ## Test Failed executing stai set outputs: 0x%x.\n\n", return_code)
      return -1;
   }
   {%- endif %}

#ifndef MINIMAL_COMPILE
   {# copy inputs into activation buffers if allocate_inputs is used #}
   {% if config['allocate_inputs'] -%}
   LOG_PRINT("Copying inputs in activations\n");
   stai_ptr inputs[STAI_{{ NET_NAME }}_IN_NUM];
   stai_size n_inputs = 0;
   /*  Get network inputs buffers  */
   return_code = stai_{{ net_name }}_get_inputs(network, inputs, &n_inputs);
   if (return_code != STAI_SUCCESS) {
      LOG_PRINT("  ## Test Failed executing stai get inputs: 0x%x.\n\n", return_code)
      return -1;
   }
   {% for _ in config['c_net_in'] -%}
   LOG_PRINT("Copy %d from %p to %p\n", STAI_{{ NET_NAME }}_IN_{{loop.index}}_SIZE_BYTES, input_buffers[{{loop.index0}}], inputs[{{loop.index0}}]);
   memcpy(inputs[{{loop.index0}}], input_buffers[{{loop.index0}}], STAI_{{ NET_NAME }}_IN_{{loop.index}}_SIZE_BYTES);
   {%- endfor %}
   LOG_PRINT("Copied inputs in activations memory\n");
   {% endif %}
#endif

   /*  Execute network model inference on sample test (synchronous mode)  */
   LOG_PRINT("Starting inference\n");
   th_signal_start();
   return_code = stai_{{ net_name }}_run(network, STAI_MODE_SYNC);
   th_signal_stop();
   LOG_PRINT("Completed inference\n");
   if (return_code != STAI_SUCCESS) {
      LOG_PRINT("  ## Test Failed executing stai network: 0x%x.\n\n", return_code)
      return -1;
   }

#ifndef MINIMAL_COMPILE
   {% if config['allocate_outputs'] -%}
   stai_ptr outputs[STAI_{{ NET_NAME }}_OUT_NUM];
   stai_size n_outputs = 0;
   /*  Get network outputs buffers  */
   return_code = stai_{{ net_name }}_get_outputs(network, outputs, &n_outputs);
   if (return_code != STAI_SUCCESS) {
      LOG_PRINT("  ## Test Failed executing stai get inputs: 0x%x.\n\n", return_code)
      return -1;
   }
   {% for _ in config['c_net_out'] -%}
   memcpy(output_buffers[{{loop.index0}}], outputs[{{loop.index0}}],  STAI_{{ NET_NAME }}_OUT_{{loop.index}}_SIZE_BYTES);
   {%- endfor %}
   {% endif %}
#endif

#ifndef MINIMAL_COMPILE
   typedef union alias {
      int8_t _int8_t;
      uint8_t _uint8_t;
      int16_t _int16_t;
      uint16_t _uint16_t;
      int32_t _int32_t;
      uint32_t _uint32_t;
      float _float_t;
      bool _bool_t;
   } alias;
   {% for _output in config['c_net_out'] -%}
   LOG_PRINT("__START_OUTPUT{{loop.index}} __\n");
   for(int32_t o = 0; o < {{ _output['padded_elements'] / (32 if _output['fmt'].is_minus_plus_one() else 1) }}; o++) {
      {% if _output['fmt'].is_quantized_integer(size=8) and _output['fmt'].is_signed() -%}
      const alias temp = {._int8_t=*(const int8_t*)&output{{loop.index}}[o]};
      {% elif _output['fmt'].is_quantized_integer(size=8) and _output['fmt'].is_unsigned() -%}
      const alias temp = {._uint8_t=*(const uint8_t*)&output{{loop.index}}[o]};
      {% elif _output['fmt'].is_quantized_integer(size=16) and _output['fmt'].is_signed() -%}
      const alias temp = {._int16_t=*(const int16_t*)&output{{loop.index}}[o]};
      {% elif _output['fmt'].is_quantized_integer(size=16) and _output['fmt'].is_unsigned() -%}
      const alias temp = {._uint16_t=*(const uint16_t*)&output{{loop.index}}[o]};
      {% elif _output['fmt'].get_bit_size() == 32 and _output['fmt'].is_signed() -%}
      const alias temp = {._int32_t=*(const int32_t*)&output{{loop.index}}[o]};
      {% elif _output['fmt'].get_bit_size() == 32 and _output['fmt'].is_unsigned() -%}
      const alias temp = {._uint32_t=*(const uint32_t*)&output{{loop.index}}[o]};
      {% elif _output['fmt'].is_bool() -%}
      const alias temp = {._bool_t=*(const bool*)&output{{loop.index}}[o]};
      {% elif _output['fmt'].is_float() -%}
      const alias temp = {._float_t=*(const float*)&output{{loop.index}}[o]};
      {% else %}
      const alias temp = {._uint32_t=*(const uint32_t*)&output{{loop.index}}[o]};
      {% endif -%}
      if(o != 0 && o % 10 == 0) {
         LOG_PRINT("\n");
      }
      LOG_PRINT("%08" PRIx32 " ", temp._uint32_t);
   }
   LOG_PRINT("\n__END_OUTPUT{{loop.index}} __\n");
   {% endfor %}
#endif

   /*  Network de-initialization  */
   return_code = stai_{{ net_name }}_deinit(network);

   stai_runtime_deinit();

   return (return_code == STAI_SUCCESS) ? 0 : -1;
}
