{#
  ******************************************************************************
  * @file    network.j2.c
  * @author  AST Embedded Analytics Research Platform
  * @brief   AI Tool Automatic Code Generator for Embedded NN computing
  ******************************************************************************
  * @attention
  *
  * Copyright (c) 2017 STMicroelectronics.
  * All rights reserved.
  *
  * This software is licensed under terms that can be found in the LICENSE file
  * in the root directory of this software component.
  * If no LICENSE file comes with this software, it is provided AS-IS.
  ******************************************************************************
#}

{%- import 'network_layers.j2.c' as layers -%}
{%- import 'op_lite_common.j2.c' as lite -%}
{# Template-wide shortcuts: lite graphs config, network name in lower case
   (used for C identifiers) and upper case (used for macro names), the network
   signature truncated to its first 10 characters, and the custom-layer,
   activations and weights sections of the config used throughout this file. -#}
{% set lite_graphs = config[lite.LITE_GRAPHS] -%}
{% set net_name = config['net_name'].lower() -%}
{% set net_signature = config['net_signature'][:10] -%}
{% set NET_NAME = config['net_name'].upper() -%}
{% set custom = config[layers.CUSTOM] -%}
{% set _activations = config[layers.ACTIVATIONS] -%}
{% set _weights = config[layers.WEIGHTS] -%}

/**
  ******************************************************************************
  * @file    {{ net_name }}.c
  * @author  AST Embedded Analytics Research Platform
  * @date    {{ config['date_time'] }}
  * @brief   AI Tool Automatic Code Generator for Embedded NN computing
  ******************************************************************************
  * @attention
  *
  * {{ config['copyright'] }}
  * All rights reserved.
  *
  * This software is licensed under terms that can be found in the LICENSE file
  * in the root directory of this software component.
  * If no LICENSE file comes with this software, it is provided AS-IS.
  ******************************************************************************
  */

{{layers.include_lambda(net_name, config[layers.FUNCTIONS])}}
#include "{{ net_name }}.h"
#include "{{ net_name }}_data.h"

#include "ai_platform.h"
#include "ai_platform_interface.h"
#include "ai_math_helpers.h"

#include "core_common.h"
#include "core_convert.h"

#include "layers.h"

{% if config['reloc']  -%}
#if defined (AI_NETWORK_RELOC)
#include "ai_reloc_network.h"
#endif
{% endif -%}

{{ lite.include_lite(net_name, lite_graphs) }}

/* Name of the static network object declared with AI_NETWORK_OBJ_DECLARE. */
#undef AI_NET_OBJ_INSTANCE
#define AI_NET_OBJ_INSTANCE g_{{ net_name }}
 
/* Model signature string, reported as model_signature by the report API. */
#undef AI_{{ NET_NAME }}_MODEL_SIGNATURE
#define AI_{{ NET_NAME }}_MODEL_SIGNATURE     "{{ config['model_signature'] }}"

/* Tool revision (git commit) reported as tool_revision. Guarded by #ifndef
 * so that a definition made earlier in the build takes precedence. */
#ifndef AI_TOOLS_REVISION_ID
#define AI_TOOLS_REVISION_ID     "{{ config['git_commit'] }}"
#endif

/* Generation date/time taken from the tool config, reported as model_datetime. */
#undef AI_TOOLS_DATE_TIME
#define AI_TOOLS_DATE_TIME   "{{ config['date_time'] }}"

/* C compilation date/time of this file, reported as compile_datetime. */
#undef AI_TOOLS_COMPILE_TIME
#define AI_TOOLS_COMPILE_TIME    __DATE__ " " __TIME__

/* Batch count, fixed to 1 in the generated code. */
#undef AI_{{ NET_NAME }}_N_BATCHES
#define AI_{{ NET_NAME }}_N_BATCHES         (1)

{# Buffer address maps: one ai_ptr slot per activations/weights buffer. Each
   map is only emitted when the config names one, and is handed to
   ai_platform_get_activations_map() / ai_platform_get_weights_map() by the
   generated configure helpers. -#}
{% if _activations['map'] -%}
static ai_ptr {{ _activations['map'] }}[{{ _activations['buffers']|length }}] = AI_C_ARRAY_INIT;
{% endif -%}

{% if _weights['map'] -%}
static ai_ptr {{ _weights['map'] }}[{{ _weights['buffers']|length }}] = AI_C_ARRAY_INIT;
{% endif -%}

{# Prototypes of the custom-layer init/forward callbacks listed in the custom
   section of the config. Only declarations are emitted here; NOTE(review):
   the definitions are presumably supplied by user code — confirm. -#}
{%  for init_func in custom['init']: -%}
void {{ init_func }}(ai_layer* layer);
{% endfor %}
{%  for forward_func in custom['forward']: -%}
void {{ forward_func }}(ai_layer* layer);
{% endfor %}

{# Static declarations produced by the network_layers.j2.c macros, emitted in
   this order: arrays, (optional) array formats via declare_intq, tensors,
   then layers. The loop index is passed to each macro. -#}
/**  Array declarations section  **********************************************/
{% for array in config[layers.ARRAYS]: -%}
{{ layers.declare_array(array, loop.index0) }}
{% endfor %}

{%- if config[layers.FORMATS]|length>0 -%}
/**  Array metadata declarations section  *************************************/
{% for format in config[layers.FORMATS]: -%}
{{ layers.declare_intq(format, loop.index0, is_const=True) }}
{% endfor %}
{%- endif -%}

{# One tensor declaration per entry of the tensors section. -#}
/**  Tensor declarations section  *********************************************/
{% for tensor in config[layers.TENSORS]: -%}
{{ layers.declare_tensor(tensor, loop.index0) }}
{% endfor %}

/**  Layer declarations section  **********************************************/

{# Layers are emitted in reverse list order (note the [::-1]) and only when
   c_code_emit is set. NOTE(review): the reversal is presumably so that each
   layer is declared after the layer it links to — confirm in
   layers.declare_layer. -#}
{% for layer in config[layers.LAYERS][::-1]: -%}
{% if layer['c_code_emit'] -%}
{{"\n"-}}
  {{ layers.declare_layer(layer, 'NULL') }}
{% endif -%}
{% endfor %}

{# Network object declaration. Two variants are emitted and selected at C
   compile time by the tool API version: before 1.5, weights and activations
   are each a single AI_BUFFER. From 1.5 on, they are AI_BUFFER_ARRAY objects.
   Both variants reference the input/output tensor lists, the first layer and
   the truncated network signature. -#}
#if (AI_TOOLS_API_VERSION < AI_TOOLS_API_VERSION_1_5)

AI_NETWORK_OBJ_DECLARE(
  AI_NET_OBJ_INSTANCE, AI_STATIC,
  AI_BUFFER_INIT(AI_FLAG_NONE,  AI_BUFFER_FORMAT_U8,
    AI_BUFFER_SHAPE_INIT(AI_SHAPE_BCWH, 4, 1, {{ _weights['size'] }}, 1, 1),
    {{ _weights['size'] }}, NULL, NULL),
  AI_BUFFER_INIT(AI_FLAG_NONE,  AI_BUFFER_FORMAT_U8,
    AI_BUFFER_SHAPE_INIT(AI_SHAPE_BCWH, 4, 1, {{ _activations['size'] }}, 1, 1),
    {{ _activations['size'] }}, NULL, NULL),
  AI_TENSOR_LIST_IO_OBJ_INIT(AI_FLAG_NONE, AI_{{ NET_NAME }}_IN_NUM, {{ config['net_in_names'] }}),
  AI_TENSOR_LIST_IO_OBJ_INIT(AI_FLAG_NONE, AI_{{ NET_NAME }}_OUT_NUM, {{ config['net_out_names'] }}),
  &{{ config['first_layer'] }}, {{net_signature}}, NULL)

#else

{# An empty AI_BUFFER_ARRAY (count 0, NULL) is emitted when the weights or
   activations size is 0. NOTE(review): otherwise the array count is
   buffers|length, but only one AI_BUFFER_INIT (covering the total size) is
   provided. Confirm this is intended for models with several buffers. -#}
AI_NETWORK_OBJ_DECLARE(
  AI_NET_OBJ_INSTANCE, AI_STATIC,
{%- if _weights['size'] > 0 %}
  AI_BUFFER_ARRAY_OBJ_INIT_STATIC(
  	AI_FLAG_NONE, {{ _weights['buffers']|length }},
    AI_BUFFER_INIT(AI_FLAG_NONE,  AI_BUFFER_FORMAT_U8,
      AI_BUFFER_SHAPE_INIT(AI_SHAPE_BCWH, 4, 1, {{ _weights['size'] }}, 1, 1),
      {{ _weights['size'] }}, NULL, NULL)
{%- else %}
  AI_BUFFER_ARRAY_OBJ_INIT(
	  AI_FLAG_NONE, 0, NULL
{%- endif %}
  ),
{%- if _activations['size'] > 0 %}
  AI_BUFFER_ARRAY_OBJ_INIT_STATIC(
  	AI_FLAG_NONE, {{ _activations['buffers']|length }},
    AI_BUFFER_INIT(AI_FLAG_NONE,  AI_BUFFER_FORMAT_U8,
      AI_BUFFER_SHAPE_INIT(AI_SHAPE_BCWH, 4, 1, {{ _activations['size'] }}, 1, 1),
      {{ _activations['size'] }}, NULL, NULL)
{%- else %}
  AI_BUFFER_ARRAY_OBJ_INIT(
	  AI_FLAG_NONE, 0, NULL
{%- endif %}
  ),
  AI_TENSOR_LIST_IO_OBJ_INIT(AI_FLAG_NONE, AI_{{ NET_NAME }}_IN_NUM, {{ config['net_in_names'] }}),
  AI_TENSOR_LIST_IO_OBJ_INIT(AI_FLAG_NONE, AI_{{ NET_NAME }}_OUT_NUM, {{ config['net_out_names'] }}),
  &{{ config['first_layer'] }}, {{net_signature}}, NULL)

#endif	/*(AI_TOOLS_API_VERSION < AI_TOOLS_API_VERSION_1_5)*/


{% if _activations['map'] %}
/******************************************************************************/
/* Query ai_platform_get_activations_map() with params for the activations map
 * array, then run the generated activations (byte) offsets update.
 * On failure: traps INIT_FAILED / NETWORK_ACTIVATIONS on net_ctx and returns
 * false. Returns true otherwise. */
AI_DECLARE_STATIC
ai_bool {{ net_name }}_configure_activations(
  ai_network* net_ctx, const ai_network_params* params)
{
  AI_ASSERT(net_ctx)

  if (!ai_platform_get_activations_map({{ _activations['map'] }},
                                       {{ _activations['buffers']|length }}, params)) {
    AI_ERROR_TRAP(net_ctx, INIT_FAILED, NETWORK_ACTIVATIONS);
    return false;
  }

  /* Updating activations (byte) offsets */
  {{ layers.activations_init_offsets(_activations['buffers'], _activations['map']) }}
  return true;
}
{% endif %}


{% if _weights['map'] %}
/******************************************************************************/
/* Query ai_platform_get_weights_map() with params for the weights map array,
 * then run the generated weights (byte) offsets update.
 * On failure: traps INIT_FAILED / NETWORK_WEIGHTS on net_ctx and returns
 * false. Returns true otherwise. */
AI_DECLARE_STATIC
ai_bool {{ net_name }}_configure_weights(
  ai_network* net_ctx, const ai_network_params* params)
{
  AI_ASSERT(net_ctx)

  if (!ai_platform_get_weights_map({{ _weights['map'] }},
                                   {{ _weights['buffers']|length }}, params)) {
    AI_ERROR_TRAP(net_ctx, INIT_FAILED, NETWORK_WEIGHTS);
    return false;
  }

  /* Updating weights (byte) offsets */
  {{ layers.weights_init_offsets(_weights['buffers'], _weights['map']) }}
  return true;
}
{% endif %}

/**  PUBLIC APIs SECTION  *****************************************************/
{% if _activations['buffers']|length <= 1 %}


/* Deprecated report getter (only emitted when there is at most one
 * activations buffer). ai_{{ net_name }}_get_report() is always emitted.
 * Pre-fills a report with the model/tool/runtime identification, lets the
 * platform complete it, then copies it to *report.
 * Returns false if report or the network context is NULL, or if the platform
 * cannot build the report. */
AI_DEPRECATED
AI_API_ENTRY
ai_bool ai_{{ net_name }}_get_info(
  ai_handle network, ai_network_report* report)
{
  ai_network* net_ctx = AI_NETWORK_ACQUIRE_CTX(network);

  if (!report || !net_ctx) {
    return false;
  }

  ai_network_report r = {
    .model_name        = AI_{{ NET_NAME }}_MODEL_NAME,
    .model_signature   = AI_{{ NET_NAME }}_MODEL_SIGNATURE,
    .model_datetime    = AI_TOOLS_DATE_TIME,
    .compile_datetime  = AI_TOOLS_COMPILE_TIME,

    .runtime_revision  = ai_platform_runtime_get_revision(),
    .runtime_version   = ai_platform_runtime_get_version(),

    .tool_revision     = AI_TOOLS_REVISION_ID,
    .tool_version      = {AI_TOOLS_VERSION_MAJOR, AI_TOOLS_VERSION_MINOR,
                          AI_TOOLS_VERSION_MICRO, 0x0},
    .tool_api_version  = AI_STRUCT_INIT,

    .api_version            = ai_platform_api_get_version(),
    .interface_api_version  = ai_platform_interface_api_get_version(),

    .n_macc            = {{config['Macc']}},
    .n_inputs          = 0,
    .inputs            = NULL,
    .n_outputs         = 0,
    .outputs           = NULL,
    .params            = AI_STRUCT_INIT,
    .activations       = AI_STRUCT_INIT,
    .n_nodes           = 0,
    .signature         = {{net_signature}},
  };

  if (!ai_platform_api_get_network_report(network, &r)) {
    return false;
  }

  *report = r;
  return true;
}
{% endif %}


/* Fill *report with the model/tool/runtime identification of the network.
 * A local report is pre-filled here, completed by
 * ai_platform_api_get_network_report(), and only copied out on success.
 * Returns false if report or the network context is NULL, or if the platform
 * cannot build the report. */
AI_API_ENTRY
ai_bool ai_{{ net_name }}_get_report(
  ai_handle network, ai_network_report* report)
{
  ai_network* net_ctx = AI_NETWORK_ACQUIRE_CTX(network);

  if (!report || !net_ctx) {
    return false;
  }

  ai_network_report r = {
    .model_name        = AI_{{ NET_NAME }}_MODEL_NAME,
    .model_signature   = AI_{{ NET_NAME }}_MODEL_SIGNATURE,
    .model_datetime    = AI_TOOLS_DATE_TIME,
    .compile_datetime  = AI_TOOLS_COMPILE_TIME,

    .runtime_revision  = ai_platform_runtime_get_revision(),
    .runtime_version   = ai_platform_runtime_get_version(),

    .tool_revision     = AI_TOOLS_REVISION_ID,
    .tool_version      = {AI_TOOLS_VERSION_MAJOR, AI_TOOLS_VERSION_MINOR,
                          AI_TOOLS_VERSION_MICRO, 0x0},
    .tool_api_version  = AI_STRUCT_INIT,

    .api_version            = ai_platform_api_get_version(),
    .interface_api_version  = ai_platform_interface_api_get_version(),

    .n_macc            = {{config['Macc']}},
    .n_inputs          = 0,
    .inputs            = NULL,
    .n_outputs         = 0,
    .outputs           = NULL,
    .map_signature     = AI_MAGIC_SIGNATURE,
    .map_weights       = AI_STRUCT_INIT,
    .map_activations   = AI_STRUCT_INIT,
    .n_nodes           = 0,
    .signature         = {{net_signature}},
  };

  if (!ai_platform_api_get_network_report(network, &r)) {
    return false;
  }

  *report = r;
  return true;
}


/* Return the error state reported by ai_platform_network_get_error() for the
 * given network handle. */
AI_API_ENTRY
ai_error ai_{{ net_name }}_get_error(ai_handle network)
{
  const ai_error net_error = ai_platform_network_get_error(network);
  return net_error;
}


/* Create the network handle from the statically declared network object and
 * the given data config. The tool API version the code was generated for is
 * forwarded to ai_platform_network_create(). */
AI_API_ENTRY
ai_error ai_{{ net_name }}_create(
  ai_handle* network, const ai_buffer* network_config)
{
  const ai_error status = ai_platform_network_create(
    network,
    network_config,
    AI_CONTEXT_OBJ(&AI_NET_OBJ_INSTANCE),
    AI_TOOLS_API_VERSION_MAJOR,
    AI_TOOLS_API_VERSION_MINOR,
    AI_TOOLS_API_VERSION_MICRO);

  return status;
}


/* Create the network, apply the caller-supplied buffer addresses and
 * initialize it in one call.
 *   activations: array with one base address per activations buffer
 *                (params.map_activations.size entries), or NULL to keep the
 *                addresses returned by ai_{{ net_name }}_data_params_get().
 *   weights:     same for the weights buffers (params.map_weights.size).
 * Returns AI_ERROR_NONE on success, otherwise the error recorded on the
 * network (or the creation error if the network could not be created). */
AI_API_ENTRY
ai_error ai_{{ net_name }}_create_and_init(
  ai_handle* network, const ai_handle activations[], const ai_handle weights[])
{
  ai_error err;
  ai_network_params params;

  err = ai_{{ net_name }}_create(network, AI_{{ NET_NAME }}_DATA_CONFIG);
  if (err.type != AI_ERROR_NONE) {
    return err;
  }

  if (ai_{{ net_name }}_data_params_get(&params) != true) {
    /* data_params_get() only sees &params and cannot record anything on the
       network context. Trap the failure explicitly; otherwise get_error()
       could report AI_ERROR_NONE and the caller would use an uninitialized
       network. */
    ai_network* net_ctx = AI_NETWORK_ACQUIRE_CTX(*network);
    AI_ERROR_TRAP(net_ctx, INIT_FAILED, NETWORK_PARAMS);
    err = ai_{{ net_name }}_get_error(*network);
    return err;
  }
#if defined(AI_{{ NET_NAME }}_DATA_ACTIVATIONS_COUNT)
  /* set the addresses of the activations buffers */
  for (ai_u16 idx=0; activations && idx<params.map_activations.size; idx++) {
    AI_BUFFER_ARRAY_ITEM_SET_ADDRESS(&params.map_activations, idx, activations[idx]);
  }
#endif
#if defined(AI_{{ NET_NAME }}_DATA_WEIGHTS_COUNT)
  /* set the addresses of the weight buffers */
  for (ai_u16 idx=0; weights && idx<params.map_weights.size; idx++) {
    AI_BUFFER_ARRAY_ITEM_SET_ADDRESS(&params.map_weights, idx, weights[idx]);
  }
#endif
  if (ai_{{ net_name }}_init(*network, &params) != true) {
    err = ai_{{ net_name }}_get_error(*network);
  }
  return err;
}


/* Return the input buffer descriptors of the network (count in *n_buffer).
 * A NULL handle falls back to the statically declared network object, which
 * is first stamped with AI_MAGIC_CONTEXT_TOKEN. NOTE(review): presumably this
 * lets the descriptors be queried before the network is created. */
AI_API_ENTRY
ai_buffer* ai_{{ net_name }}_inputs_get(ai_handle network, ai_u16 *n_buffer)
{
  ai_handle target = network;

  if (target == AI_HANDLE_NULL) {
    target = (ai_handle)&AI_NET_OBJ_INSTANCE;
    AI_NETWORK_OBJ(target)->magic = AI_MAGIC_CONTEXT_TOKEN;
  }

  return ai_platform_inputs_get(target, n_buffer);
}


/* Return the output buffer descriptors of the network (count in *n_buffer).
 * A NULL handle falls back to the statically declared network object, which
 * is first stamped with AI_MAGIC_CONTEXT_TOKEN. NOTE(review): presumably this
 * lets the descriptors be queried before the network is created. */
AI_API_ENTRY
ai_buffer* ai_{{ net_name }}_outputs_get(ai_handle network, ai_u16 *n_buffer)
{
  ai_handle target = network;

  if (target == AI_HANDLE_NULL) {
    target = (ai_handle)&AI_NET_OBJ_INSTANCE;
    AI_NETWORK_OBJ(target)->magic = AI_MAGIC_CONTEXT_TOKEN;
  }

  return ai_platform_outputs_get(target, n_buffer);
}


/* Tear down the network through ai_platform_network_destroy() and return the
 * handle it reports. */
AI_API_ENTRY
ai_handle ai_{{ net_name }}_destroy(ai_handle network)
{
  const ai_handle result = ai_platform_network_destroy(network);
  return result;
}


/* Initialize a created network with the given params: platform init, then
 * binding of the weights/activations maps (when the model has them), then the
 * platform post-init stage.
 * Returns false if any stage fails. The failing stage records the error on
 * the network (see ai_{{ net_name }}_get_error()). */
AI_API_ENTRY
ai_bool ai_{{ net_name }}_init(
  ai_handle network, const ai_network_params* params)
{
  ai_network* net_ctx = AI_NETWORK_OBJ(ai_platform_network_init(network, params));
  ai_bool ok = true;

  if (!net_ctx) return false;

  {%- if _weights['map'] %}
  ok &= {{ net_name }}_configure_weights(net_ctx, params);
  {%- endif %}
  {%- if _activations['map'] %}
  ok &= {{ net_name }}_configure_activations(net_ctx, params);
  {%- endif %}

  /* Do not run post-init on a network whose buffers could not be bound: the
     configure helpers have already trapped the error, and post-init would
     otherwise run on unbound buffers and may mask the root cause. */
  if (!ok) return false;

  return ai_platform_network_post_init(network);
}


/* Run inference: forwards input and output to ai_platform_network_process()
 * and returns its result. NOTE(review): presumably the number of processed
 * batches; confirm against the platform API documentation. */
AI_API_ENTRY
ai_i32 ai_{{ net_name }}_run(
  ai_handle network, const ai_buffer* input, ai_buffer* output)
{
  const ai_i32 ret = ai_platform_network_process(network, input, output);
  return ret;
}


/* Same as ai_{{ net_name }}_run() but with no output buffer: calls
 * ai_platform_network_process() with output = NULL and returns its result. */
AI_API_ENTRY
ai_i32 ai_{{ net_name }}_forward(ai_handle network, const ai_buffer* input)
{
  ai_buffer* const no_output = NULL;
  const ai_i32 ret = ai_platform_network_process(network, input, no_output);
  return ret;
}

{# Relocatable-network support: only emitted when the config enables reloc,
   and only active when the C build defines AI_NETWORK_RELOC. -#}
{% if config['reloc']  %}
#if defined (AI_NETWORK_RELOC)
AI_RELOC_NETWORK();
#endif
{% endif %}

{# Drop the file-local helper macros defined at the top of this file.
   AI_TOOLS_REVISION_ID and the N_BATCHES macro are left defined. -#}
#undef AI_{{ NET_NAME }}_MODEL_SIGNATURE
#undef AI_NET_OBJ_INSTANCE
#undef AI_TOOLS_DATE_TIME
#undef AI_TOOLS_COMPILE_TIME

{# Renders nothing; NOTE(review): presumably kept so the generated file ends
   with a trailing newline. -#}
{{ "" }}
