{#
  ******************************************************************************
  * @file    legacy_network.j2.c
  * @author  AST Embedded Analytics Research Platform
  * @brief   AI Tool Automatic Code Generator for Embedded NN computing
  ******************************************************************************
  * @attention
  *
  * Copyright (c) 2023 STMicroelectronics.
  * All rights reserved.
  *
  * This software is licensed under terms that can be found in the LICENSE file
  * in the root directory of this software component.
  * If no LICENSE file comes with this software, it is provided AS-IS.
  ******************************************************************************
#}

{%- import 'network_layers.j2.c' as layers -%}
{%- import 'op_lite_common.j2.c' as lite -%}
{% set lite_graphs = config[lite.LITE_GRAPHS] -%}
{% set net_name = config['net_name'].lower() -%}
{% set net_signature = config['net_signature'][:10] -%}
{% set NET_NAME = config['net_name'].upper() -%}
{% set data_module = config['data_module'].lower() -%}
{% set DATA_MODULE = config['data_module'].upper() -%}
{% set custom = config[layers.CUSTOM] -%}
{% set _activations = config[layers.ACTIVATIONS] -%}
{% set _weights = config[layers.WEIGHTS] -%}

/**
  ******************************************************************************
  * @file    {{ net_name }}.c
  * @author  AST Embedded Analytics Research Platform
  * @date    {{ config['date_time'] }}
  * @brief   AI Tool Automatic Code Generator for Embedded NN computing
  ******************************************************************************
  * @attention
  *
  * {{ config['copyright'] }}
  * All rights reserved.
  *
  * This software is licensed under terms that can be found in the LICENSE file
  * in the root directory of this software component.
  * If no LICENSE file comes with this software, it is provided AS-IS.
  ******************************************************************************
  */

{{layers.include_lambda(net_name, config[layers.FUNCTIONS])}}
#include "legacy_{{ net_name }}.h"
#include "legacy_{{ data_module }}.h"

#include "ai_platform.h"
#include "ai_platform_interface.h"
#include "{{ net_name }}.h"

{% if config['reloc']  -%}
#if defined (AI_NETWORK_RELOC)
#include "ai_reloc_network.h"
#endif
{% endif -%}


/* Point AI_NET_OBJ_INSTANCE at this file's legacy context object */
#undef AI_NET_OBJ_INSTANCE
#define AI_NET_OBJ_INSTANCE g_legacy_{{ net_name }}
 
/* Model/tool identification strings reported by ai_{{ net_name }}_get_report() */
#undef AI_{{ NET_NAME }}_MODEL_SIGNATURE
#define AI_{{ NET_NAME }}_MODEL_SIGNATURE     "{{config['model_signature']}}"

#ifndef AI_TOOLS_REVISION_ID
#define AI_TOOLS_REVISION_ID     "{{ config['git_commit'] }}"
#endif

#undef AI_TOOLS_DATE_TIME
#define AI_TOOLS_DATE_TIME   "{{ config['date_time'] }}"

#undef AI_TOOLS_COMPILE_TIME
#define AI_TOOLS_COMPILE_TIME    __DATE__ " " __TIME__

/* This wrapper always declares a single batch */
#undef AI_{{ NET_NAME }}_N_BATCHES
#define AI_{{ NET_NAME }}_N_BATCHES         (1)

{% if _activations['map'] -%}
/* Activation buffer address table, filled by ai_platform_get_activations_map()
   in {{ net_name }}_configure_activations() */
static ai_ptr {{ _activations['map'] }}[{{ _activations['buffers']|length }}] = AI_C_ARRAY_INIT;
{% endif -%}

{% if _weights['map'] -%}
/* Weight buffer address table, filled by ai_platform_get_weights_map()
   in {{ net_name }}_configure_weights() */
static ai_ptr {{ _weights['map'] }}[{{ _weights['buffers']|length }}] = AI_C_ARRAY_INIT;
{% endif -%}


/* Prototypes of the custom-layer init/forward callbacks listed in the generator config */
{%  for init_func in custom['init']: -%}
void {{ init_func }}(ai_layer* layer);
{% endfor %}
{%  for forward_func in custom['forward']: -%}
void {{ forward_func }}(ai_layer* layer);
{% endfor %}


/*
 * _AI_{{ NET_NAME }}_CONTEXT_DECLARE(name_, attr_, signature_, _flags, _ctx_size)
 *
 * Defines the legacy ai_network_context object `name_` (storage qualifier
 * `attr_`, 4-byte aligned) with:
 *  - magic AI_MAGIC_LEGACY_TOKEN and the given signature;
 *  - one ai_buffer descriptor per model input/output, created with NULL data
 *    pointers (the public APIs below update the data pointers later);
 *  - the weights map (one U8 descriptor per weight pool, data pointer taken
 *    from buffer_c_name_addr) and the activations map (one U8 descriptor per
 *    pool, NULL data), or empty arrays when the corresponding size is 0;
 *  - `_ctx`, a zero-initialized ai_u64 array covering `_ctx_size` bytes
 *    (rounded up to a multiple of 8); the APIs cast it to stai_network*.
 *
 * `_flags` is not used by the expansion.
 * NOTE(review): the weights list emits a trailing comma after its last
 * element while the activations list suppresses it; this relies on
 * AI_BUFFER_ARRAY_OBJ_INIT_STATIC tolerating the extra comma - confirm.
 */
#define _AI_{{ NET_NAME }}_CONTEXT_DECLARE( \
  name_, attr_, \
  signature_, _flags, _ctx_size) \
  AI_ALIGNED(4) \
  attr_ ai_network_context name_ = { \
    .magic = AI_MAGIC_LEGACY_TOKEN, \
    .signature = (signature_), \
    .tool_api_version = 0x0, \
    .error = AI_ERROR_INIT(NONE, NONE), \
    ._inputs = (ai_buffer[]) { \
{%- for (_shape, _buffer), c_net_in in zip(config['in_shapes'], config['c_net_in']): %}
      AI_BUFFER_INIT(AI_FLAG_NONE, {{ _buffer['fmt'].get_c_buffer_format() }},\
        AI_BUFFER_SHAPE_INIT(0, {{ c_net_in['shape'].to_bcwh()|length }}, {{ c_net_in['shape'].to_bcwh()|join(", ") }}), {{ _shape.get_size() }}, NULL, NULL),\
{%- endfor %}
    },\
    ._outputs = (ai_buffer[]) { \
{%- for (_shape, _buffer), c_net_out in zip(config['out_shapes'], config['c_net_out']): %}
      AI_BUFFER_INIT(AI_FLAG_NONE, {{ _buffer['fmt'].get_c_buffer_format() }}, \
        AI_BUFFER_SHAPE_INIT(0, {{ c_net_out['shape'].to_bcwh()|length}}, {{ c_net_out['shape'].to_bcwh()|join(", ") }}), {{ _shape.get_size() }}, NULL, NULL), \
{%- endfor %}
    }, \
    ._map_weights = \
{%- if _weights['size'] > 0 %}
      AI_BUFFER_ARRAY_OBJ_INIT_STATIC( \
        AI_FLAG_NONE, {{ _weights['buffers']|length }}, \
  {% for buf in _weights['buffers']: -%}
  AI_BUFFER_INIT(AI_FLAG_NONE,  AI_BUFFER_FORMAT_U8,\
    AI_BUFFER_SHAPE_INIT(AI_SHAPE_BCWH, 4, 1, {{ buf['pool_size'] }}, 1, 1),\
    {{ buf['pool_size'] }}, NULL, {{ buf['buffer_c_name_addr'] }}),   /* {{ buf['name'] }} */\
  {% endfor -%}
      ), \
{%- else %}
      AI_BUFFER_ARRAY_OBJ_INIT( \
        AI_FLAG_NONE, 0, NULL \
      ), \
{%- endif %}
    ._map_activations = \
{%- if _activations['size'] > 0 %}
      AI_BUFFER_ARRAY_OBJ_INIT_STATIC( \
        AI_FLAG_NONE, {{ _activations['buffers']|length }}, \
{%- for _buf in _activations['buffers']: %}
        AI_BUFFER_INIT(AI_FLAG_NONE,  AI_BUFFER_FORMAT_U8, \
          AI_BUFFER_SHAPE_INIT(0, 4, 1, {{ _buf['pool_size'] }}, 1, 1), \
          {{ _buf['pool_size'] }}, NULL, NULL){%-if not loop.last -%},{%-endif-%} \
{%- endfor %}
      ), \
{%- else %}
       AI_BUFFER_ARRAY_OBJ_INIT( \
         AI_FLAG_NONE, 0, NULL \
       ), \
{%- endif %}
    ._ctx = (ai_u64[((_ctx_size) + 7) / 8]){0}, \
  };

/* The static legacy context instance; its stai storage is sized by
   STAI_{{ NET_NAME }}_CONTEXT_SIZE. */
_AI_{{ NET_NAME }}_CONTEXT_DECLARE(
  g_legacy_{{ net_name }}, AI_STATIC,
  0x0, AI_FLAG_NONE, STAI_{{ NET_NAME }}_CONTEXT_SIZE)

{% if _activations['map'] %}
/******************************************************************************/
/**
 * @brief Bind the activations buffers described by @p params to the network.
 *
 * The buffer addresses are first resolved into {{ _activations['map'] }} by
 * ai_platform_get_activations_map(), then handed to the stai runtime.
 * Any failure traps an INIT_FAILED / NETWORK_ACTIVATIONS error on @p net_ctx.
 *
 * @return true when both steps succeed, false otherwise.
 */
AI_DECLARE_STATIC
ai_bool {{ net_name }}_configure_activations(
  ai_network_context* net_ctx, const ai_network_params* params)
{
  AI_ASSERT(net_ctx)

  /* Short-circuit: the runtime is only updated once the map is resolved */
  const ai_bool bound =
    ai_platform_get_activations_map({{ _activations['map'] }}, {{ _activations['buffers']|length }}, params) &&
    (stai_{{ net_name }}_set_activations((stai_network*)net_ctx->_ctx,
                                         (stai_ptr*){{ _activations['map'] }},
                                         {{ _activations['buffers']|length }}) == STAI_SUCCESS);
  if (!bound) {
    AI_ERROR_TRAP(net_ctx, INIT_FAILED, NETWORK_ACTIVATIONS);
  }
  return bound;
}
{% endif %}

{% if _weights['map'] %}
/******************************************************************************/
/**
 * @brief Bind the weights buffers described by @p params to the network.
 *
 * The buffer addresses are first resolved into {{ _weights['map'] }} by
 * ai_platform_get_weights_map(), then handed to the stai runtime.
 * Any failure traps an INIT_FAILED / NETWORK_WEIGHTS error on @p net_ctx.
 *
 * @return true when both steps succeed, false otherwise.
 */
AI_DECLARE_STATIC
ai_bool {{ net_name }}_configure_weights(
  ai_network_context* net_ctx, const ai_network_params* params)
{
  AI_ASSERT(net_ctx)

  /* Short-circuit: the runtime is only updated once the map is resolved */
  const ai_bool bound =
    ai_platform_get_weights_map({{ _weights['map'] }}, {{ _weights['buffers']|length }}, params) &&
    (stai_{{ net_name }}_set_weights((stai_network*)net_ctx->_ctx,
                                     (stai_ptr*){{ _weights['map'] }},
                                     {{ _weights['buffers']|length }}) == STAI_SUCCESS);
  if (!bound) {
    AI_ERROR_TRAP(net_ctx, INIT_FAILED, NETWORK_WEIGHTS);
  }
  return bound;
}
{% endif %}

/**  PUBLIC APIs SECTION  *****************************************************/

/**
 * @brief Fill a legacy ai_network_report for @p network.
 *
 * Static metadata comes from the generated macros. The MACC count, node count
 * and C-model signature are then refreshed from stai_{{ net_name }}_get_info().
 * The data pointers of the reported input/output descriptors (shared with
 * the context) and of the weights/activations maps are refreshed from the
 * stai runtime.
 *
 * @param network  legacy network handle
 * @param report   destination report; false is returned when NULL
 * @return true on success; false for a NULL report or unusable handle, or
 *         when an info/input/output query fails (those failures also trap an
 *         INVALID_HANDLE/NETWORK error on the context).
 */
AI_API_ENTRY
ai_bool ai_{{ net_name }}_get_report(
  ai_handle network, ai_network_report* report)
{
  ai_network_context* net_ctx = (ai_network_context*)ai_platform_context_acquire(network);

  if (report && net_ctx)
  {
    stai_network_info stai_info;

    ai_network_report r = {
      .model_name        = AI_{{ NET_NAME }}_MODEL_NAME,
      .model_signature   = AI_{{ NET_NAME }}_MODEL_SIGNATURE,
      .model_datetime    = AI_TOOLS_DATE_TIME,

      .compile_datetime  = AI_TOOLS_COMPILE_TIME,

      .runtime_revision  = ai_platform_runtime_get_revision(),
      .runtime_version   = ai_platform_runtime_get_version(),

      .tool_revision     = AI_TOOLS_REVISION_ID,
      .tool_version      = {AI_TOOLS_VERSION_MAJOR, AI_TOOLS_VERSION_MINOR,
                            AI_TOOLS_VERSION_MICRO, 0x0},
      .tool_api_version  = AI_STRUCT_INIT,

      .api_version            = ai_platform_api_get_version(),
      .interface_api_version  = ai_platform_interface_api_get_version(),

      .n_macc            = {{config['Macc']}},
      .n_inputs          = STAI_{{ NET_NAME }}_IN_NUM,
      .inputs            = net_ctx->_inputs,
      .n_outputs         = STAI_{{ NET_NAME }}_OUT_NUM,
      .outputs           = net_ctx->_outputs,
      .map_signature     = AI_MAGIC_SIGNATURE,
      .map_weights       = net_ctx->_map_weights,
      .map_activations   = net_ctx->_map_activations,
      .n_nodes           = {{config['layers']|length}},
      .signature         = {{net_signature}},
    };

    if (stai_{{ net_name }}_get_info((stai_network*)net_ctx->_ctx, &stai_info) != STAI_SUCCESS) {
      AI_ERROR_TRAP(net_ctx, INVALID_HANDLE, NETWORK);
      return false;
    }

    /* Prefer the runtime's view of the model over the generator's constants */
    r.n_macc      = stai_info.n_macc;
    r.n_nodes     = stai_info.n_nodes;
    r.signature   = stai_info.c_model_signature;

    stai_size n_inputs = STAI_{{ NET_NAME }}_IN_NUM;
    stai_ptr inputs[STAI_{{ NET_NAME }}_IN_NUM];

    if (stai_{{ net_name }}_get_inputs((stai_network*)net_ctx->_ctx, inputs, &n_inputs) != STAI_SUCCESS) {
      AI_ERROR_TRAP(net_ctx, INVALID_HANDLE, NETWORK);
      return false;
    }

    /* Only copy the pointers the runtime actually returned */
    for (stai_size idx = 0; idx < r.n_inputs && idx < n_inputs; idx++) {
      r.inputs[idx].data = AI_HANDLE_PTR(inputs[idx]);
    }

    stai_size n_outputs = STAI_{{ NET_NAME }}_OUT_NUM;
    stai_ptr outputs[STAI_{{ NET_NAME }}_OUT_NUM];

    if (stai_{{ net_name }}_get_outputs((stai_network*)net_ctx->_ctx, outputs, &n_outputs) != STAI_SUCCESS) {
      AI_ERROR_TRAP(net_ctx, INVALID_HANDLE, NETWORK);
      return false;
    }

    for (stai_size idx = 0; idx < r.n_outputs && idx < n_outputs; idx++) {
      r.outputs[idx].data = AI_HANDLE_PTR(outputs[idx]);
    }
{%- if _activations['size'] > 0 %}

    /* Activation addresses are best-effort: on a failed query the map is left
       untouched instead of being filled from an uninitialized array. */
    stai_ptr activations[STAI_{{ NET_NAME }}_ACTIVATIONS_NUM];
    stai_size n_activations = STAI_{{ NET_NAME }}_ACTIVATIONS_NUM;
    if (stai_{{ net_name }}_get_activations((stai_network*)net_ctx->_ctx, activations, &n_activations) == STAI_SUCCESS) {
      for (stai_size idx = 0; idx < r.map_activations.size && idx < n_activations; idx++) {
        r.map_activations.buffer[idx].data = activations[idx];
      }
    }
{%- endif %}
{%- if _weights['size'] > 0 %}

    /* Same best-effort policy for the weight addresses */
    stai_ptr weights[STAI_{{ NET_NAME }}_WEIGHTS_NUM];
    stai_size n_weights = STAI_{{ NET_NAME }}_WEIGHTS_NUM;
    if (stai_{{ net_name }}_get_weights((stai_network*)net_ctx->_ctx, weights, &n_weights) == STAI_SUCCESS) {
      for (stai_size idx = 0; idx < r.map_weights.size && idx < n_weights; idx++) {
        r.map_weights.buffer[idx].data = weights[idx];
      }
    }
{%- endif %}

    *report = r;

    return true;
  }
  return false;
}

/**
 * @brief Return the last error recorded for @p network by the platform layer.
 */
AI_API_ENTRY
ai_error ai_{{ net_name }}_get_error(ai_handle network)
{
  const ai_error last_error = ai_platform_network_get_error(network);
  return last_error;
}


/**
 * @brief Create a network handle bound to the static legacy context
 *        g_legacy_{{ net_name }}.
 *
 * @param network         receives the created handle
 * @param network_config  configuration buffer forwarded to the platform
 * @return the error reported by ai_platform_network_create().
 */
AI_API_ENTRY
ai_error ai_{{ net_name }}_create(
  ai_handle* network, const ai_buffer* network_config)
{
  return ai_platform_network_create(network,
                                    network_config,
                                    AI_CONTEXT_OBJ(&g_legacy_{{ net_name }}),
                                    AI_TOOLS_API_VERSION_MAJOR,
                                    AI_TOOLS_API_VERSION_MINOR,
                                    AI_TOOLS_API_VERSION_MICRO);
}


/**
 * @brief Create the network and initialize it in a single call.
 *
 * The default parameters come from the {{ data_module }} data module; the
 * optional @p activations / @p weights arrays override the corresponding
 * buffer addresses before initialization.
 *
 * @param network      receives the created handle
 * @param activations  activation buffer addresses, or NULL to keep the defaults
 * @param weights      weight buffer addresses, or NULL to keep the defaults
 * @return the create error if creation failed, otherwise the network error
 *         after a failed parameter lookup or init, else a NONE error.
 */
AI_API_ENTRY
ai_error ai_{{ net_name }}_create_and_init(
  ai_handle* network, const ai_handle activations[], const ai_handle weights[])
{
    ai_error err;
    ai_network_params params;
    /* Use the templated names: hard-coded ai_network_* symbols would target
       the wrong (or a missing) network whenever net_name != "network". */
    err = ai_{{ net_name }}_create(network, AI_{{ DATA_MODULE }}_CONFIG);
    if (err.type != AI_ERROR_NONE) {
        return err;
    }
    if (ai_{{ data_module }}_params_get(&params) != true) {
        return ai_{{ net_name }}_get_error(*network);
    }

#if defined(AI_{{ DATA_MODULE }}_ACTIVATIONS_COUNT)
    if (activations) {
        /* set the addresses of the activations buffers */
        for (int idx=0; idx<params.map_activations.size; idx++) {
            AI_BUFFER_ARRAY_ITEM_SET_ADDRESS(&params.map_activations, idx, activations[idx]);
        }
    }
#endif
#if defined(AI_{{ DATA_MODULE }}_WEIGHTS_COUNT)
    if (weights) {
        /* set the addresses of the weight buffers */
        for (int idx=0; idx<params.map_weights.size; idx++) {
            AI_BUFFER_ARRAY_ITEM_SET_ADDRESS(&params.map_weights, idx, weights[idx]);
        }
    }
#endif
    if (ai_{{ net_name }}_init(*network, &params) != true) {
        err = ai_{{ net_name }}_get_error(*network);
    }
    return err;
}


/**
 * @brief Return the network input buffer descriptors.
 *
 * A NULL @p network selects the static legacy instance g_legacy_{{ net_name }}
 * (its magic is set to AI_MAGIC_LEGACY_TOKEN). The input pointers returned by
 * the stai runtime are passed, with the context's descriptor array, to
 * ai_platform_stai_bind_io().
 *
 * @param network   legacy handle, or AI_HANDLE_NULL for the static instance
 * @param n_buffer  optional (may be NULL); receives the number of inputs
 * @return the descriptor array from ai_platform_stai_bind_io(), or NULL on error.
 */
AI_API_ENTRY
ai_buffer* ai_{{ net_name }}_inputs_get(ai_handle network, ai_u16 *n_buffer)
{
  /* Accept a NULL count pointer instead of dereferencing it below */
  ai_u16 n_local = 0;
  if (n_buffer == NULL) {
    n_buffer = &n_local;
  }

  if (network == AI_HANDLE_NULL) {
    network = (ai_handle)&g_legacy_{{ net_name }};
    ((ai_network_context*)network)->magic = AI_MAGIC_LEGACY_TOKEN;
  }
  ai_network_context* net_ctx = (ai_network_context*)ai_platform_context_acquire(network);
  if (net_ctx == NULL) {
    return NULL;
  }

  stai_ptr _inputs[STAI_{{ NET_NAME }}_IN_NUM];
  stai_size n_inputs = 0;
  stai_return_code ret = stai_{{ net_name }}_get_inputs((stai_network*)net_ctx->_ctx, _inputs, &n_inputs);
  if (ret != STAI_SUCCESS) {
    ai_platform_stai_handle_error(&net_ctx->error, ret);
    return NULL;
  }

  *n_buffer = STAI_{{ NET_NAME }}_IN_NUM;
  return ai_platform_stai_bind_io(n_buffer, net_ctx->_inputs, _inputs, n_inputs);
}


/**
 * @brief Return the network output buffer descriptors.
 *
 * A NULL @p network selects the static legacy instance g_legacy_{{ net_name }}
 * (its magic is set to AI_MAGIC_LEGACY_TOKEN). The output pointers returned by
 * the stai runtime are passed, with the context's descriptor array, to
 * ai_platform_stai_bind_io().
 *
 * @param network   legacy handle, or AI_HANDLE_NULL for the static instance
 * @param n_buffer  optional (may be NULL); receives the number of outputs
 * @return the descriptor array from ai_platform_stai_bind_io(), or NULL on error.
 */
AI_API_ENTRY
ai_buffer* ai_{{ net_name }}_outputs_get(ai_handle network, ai_u16 *n_buffer)
{
  /* Accept a NULL count pointer instead of dereferencing it below */
  ai_u16 n_local = 0;
  if (n_buffer == NULL) {
    n_buffer = &n_local;
  }

  if (network == AI_HANDLE_NULL) {
    network = (ai_handle)&g_legacy_{{ net_name }};
    ((ai_network_context*)network)->magic = AI_MAGIC_LEGACY_TOKEN;
  }
  ai_network_context* net_ctx = (ai_network_context*)ai_platform_context_acquire(network);
  if (net_ctx == NULL) {
    return NULL;
  }

  stai_ptr _outputs[STAI_{{ NET_NAME }}_OUT_NUM];
  stai_size n_outputs = 0;
  stai_return_code ret = stai_{{ net_name }}_get_outputs((stai_network*)net_ctx->_ctx, _outputs, &n_outputs);
  if (ret != STAI_SUCCESS) {
    ai_platform_stai_handle_error(&net_ctx->error, ret);
    return NULL;
  }

  *n_buffer = STAI_{{ NET_NAME }}_OUT_NUM;
  return ai_platform_stai_bind_io(n_buffer, net_ctx->_outputs, _outputs, n_outputs);
}


/**
 * @brief Release @p network through the platform layer.
 * @return the handle value returned by ai_platform_network_destroy().
 */
AI_API_ENTRY
ai_handle ai_{{ net_name }}_destroy(ai_handle network)
{
  ai_handle const released = ai_platform_network_destroy(network);
  return released;
}


/**
 * @brief Initialize @p network with the buffers described by @p params.
 *
 * After the platform and stai initialization, every configuration step
 * (weights, activations, post-init) runs even if an earlier one failed; the
 * results are AND-ed together.
 *
 * @return true only if every step succeeded.
 */
AI_API_ENTRY
ai_bool ai_{{ net_name }}_init(
  ai_handle network, const ai_network_params* params)
{
  ai_network_context* net_ctx = (ai_network_context*)ai_platform_network_init(network, params);
  if (net_ctx == NULL) {
    return false;
  }

  const stai_return_code ret = stai_{{ net_name }}_init((stai_network*)net_ctx->_ctx);
  if (ret != STAI_SUCCESS) {
    ai_platform_stai_handle_error(&net_ctx->error, ret);
    return false;
  }

  ai_bool ok = true;
{%- if _weights['map'] %}
  ok &= {{ net_name }}_configure_weights(net_ctx, params);
{%- endif %}
{%- if _activations['map'] %}
  ok &= {{ net_name }}_configure_activations(net_ctx, params);
{%- endif %}
  ok &= ai_platform_network_post_init(network);
  return ok;
}

/**
 * @brief Legacy run entry point.
 *
 * Binds the caller's @p input / @p output buffers to the stai network (only
 * for tensors not allocated by the generator) and runs it synchronously.
 *
 * When neither inputs nor outputs are allocated, the buffers are bound batch
 * by batch through ai_platform_stai_update_io(). The return value is derived
 * from its batch counter (presumably the number of processed batches; TODO
 * confirm). Otherwise a single run is performed, returning 1 on success and
 * 0 on error.
 *
 * @return 0 on error (including an unusable handle).
 */
AI_API_ENTRY
ai_i32 ai_{{ net_name }}_run(
  ai_handle network, const ai_buffer* input, ai_buffer* output)
{
  ai_network_context* net_ctx = (ai_network_context*)ai_platform_context_acquire(network);
  if (net_ctx == NULL) {
    /* Unusable handle: nothing can be bound or run */
    return 0;
  }
  stai_return_code ret;
{% if config['allocate_inputs'] == 0 and config['allocate_outputs'] == 0 -%}
  stai_ptr _inputs[STAI_{{ NET_NAME }}_IN_NUM];
  stai_ptr _outputs[STAI_{{ NET_NAME }}_OUT_NUM];
  ai_i32 idx = 0;
  while (ai_platform_stai_update_io(&idx, _inputs, _outputs,
                                    input, output,
                                    STAI_{{ NET_NAME }}_IN_NUM,
                                    STAI_{{ NET_NAME }}_OUT_NUM))
  {
    ret = stai_{{ net_name }}_set_inputs((stai_network*)net_ctx->_ctx, _inputs, STAI_{{ NET_NAME }}_IN_NUM);
    if (ret != STAI_SUCCESS) {
      ai_platform_stai_handle_error(&net_ctx->error, ret);
      break;
    }
    ret = stai_{{ net_name }}_set_outputs((stai_network*)net_ctx->_ctx, _outputs, STAI_{{ NET_NAME }}_OUT_NUM);
    if (ret != STAI_SUCCESS) {
      ai_platform_stai_handle_error(&net_ctx->error, ret);
      break;
    }
    ret = stai_{{ net_name }}_run((stai_network*)net_ctx->_ctx, STAI_MODE_SYNC);

    if (ret != STAI_SUCCESS) {
      ai_platform_stai_handle_error(&net_ctx->error, ret);
      break;
    }
  }
  idx--;

  return idx;
{% else %}
{% if config['allocate_inputs'] == 0 %}
  stai_ptr _inputs[STAI_{{ NET_NAME }}_IN_NUM];
  for(int input_index = 0; input_index < STAI_{{ NET_NAME }}_IN_NUM; input_index++) {
     _inputs[input_index] = input[input_index].data;
  }
  ret = stai_{{ net_name }}_set_inputs((stai_network*)net_ctx->_ctx, _inputs, STAI_{{ NET_NAME }}_IN_NUM);
  if (ret != STAI_SUCCESS) {
    ai_platform_stai_handle_error(&net_ctx->error, ret);
    return 0;
  }
{% else %}
  STAI_UNUSED(input)
{%- endif %}
{% if config['allocate_outputs'] == 0 %}
  stai_ptr _outputs[STAI_{{ NET_NAME }}_OUT_NUM];
  for(int output_index = 0; output_index < STAI_{{ NET_NAME }}_OUT_NUM; output_index++) {
     _outputs[output_index] = output[output_index].data;
  }
  ret = stai_{{ net_name }}_set_outputs((stai_network*)net_ctx->_ctx, _outputs, STAI_{{ NET_NAME }}_OUT_NUM);
  if (ret != STAI_SUCCESS) {
    ai_platform_stai_handle_error(&net_ctx->error, ret);
    return 0;
  }
{% else %}
  STAI_UNUSED(output)
{%- endif %}

  /* Single-shot run; only emitted here so the batched variant above does not
     end with unreachable code after its return. */
  ret = stai_{{ net_name }}_run((stai_network*)net_ctx->_ctx, STAI_MODE_SYNC);
  if (ret != STAI_SUCCESS) {
    ai_platform_stai_handle_error(&net_ctx->error, ret);
    return 0;
  }
  return 1;
{%- endif %}
}


/**
 * @brief Legacy forward entry point.
 *
 * This wrapper performs no inference here: @p input is ignored, the handle is
 * acquired but not used, and 0 is always returned.
 */
AI_API_ENTRY
ai_i32 ai_{{ net_name }}_forward(ai_handle network, const ai_buffer* input)
{
  STAI_UNUSED(input)
  ai_network_context* net_ctx = (ai_network_context*)ai_platform_context_acquire(network);
  /* Result intentionally unused; silences unused-variable warnings */
  (void)net_ctx;
  return 0;
}

{% if config['reloc']  %}
/* Relocatable-network hook: emitted only when the generator config enables
   relocation, and compiled only when AI_NETWORK_RELOC is defined. */
#if defined (AI_NETWORK_RELOC)
AI_RELOC_NETWORK();
#endif
{% endif %}


/* Drop the file-local identification macros defined near the top of this file */
#undef AI_{{ NET_NAME }}_MODEL_SIGNATURE
#undef AI_TOOLS_DATE_TIME
#undef AI_TOOLS_COMPILE_TIME

{{ "" }}
