{#
  ******************************************************************************
  * @file    stai_network.j2.c
  * @author  AST Embedded Analytics Research Platform
  * @brief   AI Tool Automatic Code Generator for Embedded NN computing
  ******************************************************************************
  * @attention
  *
  * Copyright (c) 2023 STMicroelectronics.
  * All rights reserved.
  *
  * This software is licensed under terms that can be found in the LICENSE file
  * in the root directory of this software component.
  * If no LICENSE file comes with this software, it is provided AS-IS.
  ******************************************************************************
#}
{%- import 'network_layers.j2.c' as layers -%}
{%- import 'op_lite_common.j2.c' as lite -%}
{%- import 'stai_common.j2.c' as stai -%}
{%- set lite_graphs = config['lite_graphs'] -%}
{%- set net_name = config['net_name'].lower() -%}
{%- set NET_NAME = config['net_name'].upper() -%}
{%- set hybrid_lite = config['hybrid_lite'] -%}
{%- set _activations = config[layers.ACTIVATIONS] -%}
{%- set _functions = config[layers.FUNCTIONS] -%}
{%- set _states = config[layers.STATES] -%}
{%- set _weights = config[layers.WEIGHTS] -%}

/**
  ******************************************************************************
  * @file    {{ net_name }}.c
  * @author  AST Embedded Analytics Research Platform
  * @date    {{ config['date_time'] }}
  * @brief   AI Tool Automatic Code Generator for Embedded NN computing
  ******************************************************************************
  * @attention
  *
  * {{ config['copyright'] }}
  * All rights reserved.
  *
  * This software is licensed under terms that can be found in the LICENSE file
  * in the root directory of this software component.
  * If no LICENSE file comes with this software, it is provided AS-IS.
  ******************************************************************************
  */

#include "ai_lite_inspect.h"
#include "ai_platform_interface.h"
#include "layers.h"
#include "core_convert.h"
#include "{{ net_name }}.h"
#include "{{ net_name }}_details.h"
{%- if _weights['hexify'] %}
#include "{{ net_name }}_data.h"
{%- endif %}
#include "stai_events.h"
{{layers.include_lambda(net_name, _functions)}}

{%- for file in config['includes']: %}
#include "{{ file }}"
{% endfor -%}

/*****************************************************************************/
/* Version of this generated STAI internal API (major.minor.micro). */
#define STAI_INTERNAL_API_MAJOR               (1)
#define STAI_INTERNAL_API_MINOR               (0)
#define STAI_INTERNAL_API_MICRO               (0)

/* Marker stored in the context's _magic field by stai_<net>_init(); the other
   entry points reject a context whose _magic differs (see _STAI_CONTEXT_ACQUIRE). */
#define STAI_MAGIC                            (0xB1C00100)

/*****************************************************************************/
/* Token pasting through an extra level of indirection, so that macro arguments
   are fully expanded before being concatenated. */
#define _STAI_CONCAT_ARG(a, b)     a ## b
#define STAI_CONCAT(a, b)         _STAI_CONCAT_ARG(a, b)

/*!  STAI_CAST SECTION                       *********************************/
/* Fully parenthesized cast of `expr` to `type`. */
#define STAI_CAST(type, expr) \
  ((type)(expr))


/*****************************************************************************/
/* Convert a size/count value to the stai_size type. */
#define STAI_SIZE(_size) \
  ((stai_size)(_size))

/*****************************************************************************/
/* Designated-initializer helpers for STAI descriptor structures.
 * NOTE(review): STAI_PACK is defined outside this file; it presumably unwraps a
 * parenthesized argument (e.g. an array initializer passed as a single macro
 * argument) - TODO confirm against its definition. */

/* Initializer for a buffer descriptor: size, address (stored as uintptr_t), flags. */
#define STAI_INIT_BUFFER(_flags, _size, _address) \
  { \
    .size = (_size), \
    .address = (uintptr_t)(_address), \
    .flags = (_flags), \
  }

/* Initializer for a stai_tensor descriptor: byte size, flags, format, shape /
   scale / zero-point arrays and name (used to build g_<net>_info below). */
#define STAI_INIT_TENSOR(_name, _flags, _fmt, _size_bytes, _shape, _scale, _zeropoint) \
  { \
    .size_bytes = (_size_bytes), \
    .flags = (_flags), \
    .format = (stai_format)(_fmt), \
    .shape = STAI_PACK(_shape), \
    .scale = STAI_PACK(_scale), \
    .zeropoint = STAI_PACK(_zeropoint), \
    .name = (_name) \
  }

/* Array descriptor {size, data} referencing existing storage `_ptr`. */
#define STAI_INIT_ARRAY(_size, _ptr) \
  { .size = STAI_SIZE(_size), .data = STAI_PACK(_ptr) }


/* Same as STAI_INIT_ARRAY, with the data pointer cast to `_type`. */
#define STAI_CAST_ARRAY(_type, _size, _ptr) \
  { .size = STAI_SIZE(_size), .data = (_type)STAI_PACK(_ptr) }


/* Array descriptor whose data is a compound literal of `_size` elements of
   `_type`, built from the variadic arguments. */
#define STAI_DECLARE_ARRAY(_type, _size, ...) \
  { .size = STAI_SIZE(_size), .data = (_type[_size]) { STAI_PACK(__VA_ARGS__) } }


/* Empty array descriptor: size 0, data NULL. */
#define STAI_EMPTY_ARRAY() \
  { .size = 0, .data = NULL }


/* Initializer for a version structure; the reserved field is zeroed. */
#define STAI_INIT_VERSION(_major, _minor, _micro) \
  { .major = (_major), .minor = (_minor), .micro = (_micro), .reserved = 0x0 }

/*****************************************************************************/
/**  Getters and setters  **/

/* Number of elements of an STAI array value (any struct with a `.size` member).
   The argument is parenthesized so that expressions such as `*p` expand correctly. */
#define STAI_GET_ARRAY_SIZE(nd_array) \
  ((nd_array).size)


/* Element `pos` of an STAI array value (any struct with a `.data` member). */
#define STAI_GET_ARRAY_ELEM(nd_array, pos) \
  ((nd_array).data[(pos)])

/* Record an error on a network context and leave the calling function.
 *  - net_ctx : context pointer. A NULL or misaligned pointer makes the caller return
 *              STAI_ERROR_NETWORK_INVALID_CONTEXT_HANDLE / _ALIGNMENT immediately,
 *              without writing anything.
 *  - cond    : error condition, evaluated only when `value` is an error code
 *              (>= STAI_ERROR_GENERIC).
 *  - value   : error stored in net_ctx->_return_code, but only if no error was
 *              stored before (the first error is kept).
 *  - exit    : value returned by the caller when `cond` holds.
 * Expands to a brace block, so no trailing semicolon is required.
 * NOTE: `net_ctx` and `value` may be evaluated more than once. */
#define _STAI_SET_ERROR(net_ctx, cond, value, exit) { \
  if (!(net_ctx)) { return STAI_ERROR_NETWORK_INVALID_CONTEXT_HANDLE; } \
  if (((uintptr_t)(net_ctx)) & (_STAI_CONTEXT_ALIGNMENT-1)) { return STAI_ERROR_NETWORK_INVALID_CONTEXT_ALIGNMENT; } \
  if (((value) >= STAI_ERROR_GENERIC) && (cond)) { \
    if ((net_ctx)->_return_code == STAI_SUCCESS) { \
      (net_ctx)->_return_code = (value); \
    } \
    return (exit); \
  } \
}

/*****************************************************************************/
/* TODO REMOVE THESE TWO MACROS */
#define STAI_EVENT_NODE_START_CB
#define STAI_EVENT_NODE_STOP_CB

/* Per-node event hooks: _STAI_<NET>_EVENT_NODE_START_CB / _STOP_CB(node_id, n, ptrs...).
 * When STAI_EVENT_NODE_START_CB / STAI_EVENT_NODE_STOP_CB are defined (as done just
 * above), a hook calls the callback registered in net_ctx->_callback, if any, with
 * the registered cookie, the event kind and a stai_event_node_start_stop payload
 * (node id plus the n buffer pointers). Otherwise the hook is an empty statement.
 * Both forms expect a `net_ctx` variable in scope and carry their own trailing `;`.
 * The callback form is wrapped in do { } while(0) so that it always behaves as a
 * single statement (its inner `if` cannot capture a caller's `else`).
 * A hook already defined before this point is kept unchanged. */
#ifdef STAI_EVENT_NODE_START_CB
#ifndef _STAI_{{ NET_NAME }}_EVENT_NODE_START_CB
  #define _STAI_{{ NET_NAME }}_EVENT_NODE_START_CB(_node_id, _buffers_size, ...) \
  do { \
    if (net_ctx->_callback) { \
      const stai_event_node_start_stop _start_event = { \
        .node_id=(_node_id), \
        .buffers={ \
          .size=(_buffers_size), \
          .data=(stai_ptr const*)(const stai_ptr[_buffers_size])STAI_PACK(__VA_ARGS__) \
        } \
      }; \
      net_ctx->_callback(net_ctx->_callback_cookie, STAI_EVENT_NODE_START, (const void*)&_start_event); \
    } \
  } while(0);
#endif
#else
#ifndef _STAI_{{ NET_NAME }}_EVENT_NODE_START_CB
  #define _STAI_{{ NET_NAME }}_EVENT_NODE_START_CB(_node_id, _buffers_size, ...) \
    do { /* _STAI_{{ NET_NAME }}_EVENT_NODE_START_CB() */ } while(0);
#endif
#endif      /* STAI_EVENT_NODE_START_CB */

#ifdef STAI_EVENT_NODE_STOP_CB
#ifndef _STAI_{{ NET_NAME }}_EVENT_NODE_STOP_CB
  #define _STAI_{{ NET_NAME }}_EVENT_NODE_STOP_CB(_node_id, _buffers_size, ...) \
  do { \
    if (net_ctx->_callback) { \
      const stai_event_node_start_stop _stop_event = { \
        .node_id=(_node_id), \
        .buffers={ \
          .size=(_buffers_size), \
          .data=(stai_ptr const*)(const stai_ptr[_buffers_size])STAI_PACK(__VA_ARGS__) \
        } \
      }; \
      net_ctx->_callback(net_ctx->_callback_cookie, STAI_EVENT_NODE_STOP, (const void*)&_stop_event); \
    } \
  } while(0);
#endif
#else
#ifndef _STAI_{{ NET_NAME }}_EVENT_NODE_STOP_CB
  #define _STAI_{{ NET_NAME }}_EVENT_NODE_STOP_CB(_node_id, _buffers_size, ...) \
    do { /* _STAI_{{ NET_NAME }}_EVENT_NODE_STOP_CB() */ } while(0);
#endif
#endif      /* STAI_EVENT_NODE_STOP_CB */


/*****************************************************************************/
/* Model identification strings: signature and generation date come from the code
   generator, the compile date/time from the C compiler. */
#define _STAI_{{ NET_NAME }}_MODEL_SIGNATURE     "{{ config['model_signature'] }}"
#define _STAI_{{ NET_NAME }}_DATETIME            "{{ config['date_time'] }}"
#define _STAI_{{ NET_NAME }}_COMPILE_DATETIME    __DATE__ " " __TIME__

/* Required alignment of the application-provided context buffer (checked by _STAI_SET_ERROR). */
#define _STAI_CONTEXT_ALIGNMENT        (STAI_{{NET_NAME}}_CONTEXT_ALIGNMENT)

/*****************************************************************************/
/* Activation and state memory pools.
 * Depending on the generator options (allocate_activations / allocate_states), each
 * pool is either a statically allocated, aligned byte array defined here, or a NULL
 * placeholder. A NULL placeholder leaves the matching context buffer-table entry
 * unset until the application provides the memory through
 * stai_<net>_set_activations() or stai_<net>_set_states(). */

{%- if config['allocate_activations'] %}
{% for buf in _activations['buffers']: -%}
/* Declare and allocate activation buffer #{{loop.index}} */
STAI_ALIGNED(STAI_{{NET_NAME}}_ACTIVATION_{{loop.index}}_ALIGNMENT)
static uint8_t g_{{net_name}}_activations_{{loop.index}}[STAI_{{NET_NAME}}_ACTIVATION_{{loop.index}}_SIZE];
{% endfor %}
{%- else %}
{% for buf in _activations['buffers']: -%}
#define g_{{net_name}}_activations_{{loop.index}}     (NULL)
{% endfor %}
{% endif %}

{%- if config['allocate_states'] %}
{%- if _states['size'] %}
{% for buf in _states['buffers']: -%}
/*  Declare and allocate state buffer #{{loop.index}}  */
STAI_ALIGNED(STAI_{{NET_NAME}}_STATE_{{loop.index}}_ALIGNMENT)
static uint8_t g_{{ net_name }}_states_{{loop.index}}[STAI_{{ NET_NAME }}_STATE_{{loop.index}}_SIZE];
{% endfor %}
{% endif %}
{%- else %}
{%- if _states['size'] %}
{% for buf in _states['buffers']: -%}
#define g_{{ net_name }}_states_{{loop.index}}     (NULL)
{% endfor %}
{% endif %}
{% endif %}

#if defined(HAVE_{{ NET_NAME }}_INFO)
/*****************************************************************************/
/* Static description of the network, copied out by stai_<net>_get_info().
 * Only compiled when HAVE_<NET>_INFO is defined. The tensor tables are compound
 * literals at file scope, so they have static storage duration.
 * I/O shapes are rendered from dl_shape.to_stai() with the first element dropped
 * (NOTE(review): presumably the batch dimension - TODO confirm).
 * Each activation / weight / state entry describes one memory pool as an unnamed
 * flat U8 buffer of pool_size elements, with empty scale and zero-point arrays. */
static const stai_network_info g_{{ net_name }}_info = {
  .model_signature = _STAI_{{ NET_NAME }}_MODEL_SIGNATURE,
  .c_compile_datetime = _STAI_{{ NET_NAME }}_COMPILE_DATETIME,
  .c_model_name = STAI_{{ NET_NAME }}_MODEL_NAME,
  .c_model_datetime = _STAI_{{ NET_NAME }}_DATETIME,
  .c_model_signature = 0x0,
  .runtime_version = STAI_INIT_VERSION(10, 2, 0),
  .tool_version = STAI_INIT_VERSION({{config['st_ai_version']['major']}}, {{config['st_ai_version']['minor']}}, {{config['st_ai_version']['micro']}}),
  .api_version = STAI_INIT_VERSION(1, 0, 0),
  .n_macc = STAI_{{ NET_NAME }}_MACC_NUM,
  .n_nodes = STAI_{{ NET_NAME }}_NODES_NUM,
  .flags = STAI_{{ NET_NAME }}_FLAGS,
  .n_inputs = STAI_{{ NET_NAME }}_IN_NUM,
  .n_outputs = STAI_{{ NET_NAME }}_OUT_NUM,
  .n_activations = STAI_{{ NET_NAME }}_ACTIVATIONS_NUM,
  .n_weights = STAI_{{ NET_NAME }}_WEIGHTS_NUM,
  .n_states = STAI_{{ NET_NAME }}_STATES_NUM,
  .inputs = (stai_tensor[STAI_{{ NET_NAME }}_IN_NUM]) {
{%- for _shape, _buffer in config['in_shapes']: %}
    STAI_INIT_TENSOR(
      STAI_{{ NET_NAME }}_IN_{{loop.index}}_NAME,
      STAI_{{ NET_NAME }}_IN_{{loop.index}}_FLAGS,
      STAI_{{ NET_NAME }}_IN_{{loop.index}}_FORMAT,
      STAI_{{ NET_NAME }}_IN_{{loop.index}}_SIZE_BYTES,
      {{ stai.render_array('int32_t', _buffer['dl_shape'].to_stai()[1:]) }},
      {{ stai.render_array('float', _buffer['fmt'].get_scale(), postfix='f') }},
      {{ stai.render_array('int16_t', _buffer['fmt'].get_zero()) }}),
{%- endfor %}
    },
    .outputs = (stai_tensor[STAI_{{ NET_NAME }}_OUT_NUM]) {
{%- for _shape, _buffer in config['out_shapes']: %}
    STAI_INIT_TENSOR(
      STAI_{{ NET_NAME }}_OUT_{{loop.index}}_NAME,
      STAI_{{ NET_NAME }}_OUT_{{loop.index}}_FLAGS,
      STAI_{{ NET_NAME }}_OUT_{{loop.index}}_FORMAT,
      STAI_{{ NET_NAME }}_OUT_{{loop.index}}_SIZE_BYTES,
      {{ stai.render_array('int32_t', _buffer['dl_shape'].to_stai()[1:]) }},
      {{ stai.render_array('float', _buffer['fmt'].get_scale(), postfix='f') }},
      {{ stai.render_array('int16_t', _buffer['fmt'].get_zero()) }}),
{%- endfor %}
    },
{%- if _activations['size'] > 0 %}
  .activations = (stai_tensor[STAI_{{ NET_NAME }}_ACTIVATIONS_NUM]) {
{%- for _activation in _activations['buffers'] %}
    STAI_INIT_TENSOR(
      (NULL),
      STAI_{{ NET_NAME }}_ACTIVATION_{{loop.index}}_FLAGS,
      STAI_FORMAT_U8,
      STAI_{{ NET_NAME }}_ACTIVATION_{{loop.index}}_SIZE_BYTES,
      {{ stai.render_array('int32_t', [_activation['pool_size']]) }},
      STAI_EMPTY_ARRAY(),
      STAI_EMPTY_ARRAY()),
{%- endfor %}
    },
{%- else %}
  .activations = NULL,
{%- endif %}
{%- if _weights['size'] > 0 %}
  .weights = (stai_tensor[STAI_{{ NET_NAME }}_WEIGHTS_NUM]) {
{%- for _weight in _weights['buffers'] %}
    STAI_INIT_TENSOR(
      (NULL),
      STAI_{{ NET_NAME }}_WEIGHT_{{loop.index}}_FLAGS,
      STAI_FORMAT_U8,
      STAI_{{ NET_NAME }}_WEIGHT_{{loop.index}}_SIZE_BYTES,
      {{ stai.render_array('int32_t', [_weight['pool_size']]) }},
      STAI_EMPTY_ARRAY(),
      STAI_EMPTY_ARRAY()),
{%- endfor %}
    },
{% else %}
    .weights = NULL,
{% endif -%}
{%- if _states['size'] > 0 %}
  .states = (stai_tensor[STAI_{{ NET_NAME }}_STATES_NUM]) {
{%- for _state in _states['buffers'] %}
    STAI_INIT_TENSOR(
      (NULL),
      STAI_{{ NET_NAME }}_STATE_{{loop.index}}_FLAGS,
      STAI_FORMAT_U8,
      STAI_{{ NET_NAME }}_STATE_{{loop.index}}_SIZE_BYTES,
      {{ stai.render_array('int32_t', [_state['pool_size']]) }},
      STAI_EMPTY_ARRAY(),
      STAI_EMPTY_ARRAY())
{%- endfor %}
    }
{%- else %}
  .states = NULL
{%- endif %}
};
#endif

/* Declare `_net_ctx` as the typed view of the opaque handle `_net_handle` and validate it.
 * STAI_ASSERT is defined elsewhere (NOTE(review): presumably a debug-only check).
 * _STAI_SET_ERROR then makes the caller return early for a NULL or misaligned handle.
 * If the context's _magic field is not STAI_MAGIC (the value stai_<net>_init() stores),
 * it records STAI_ERROR_NETWORK_INVALID_CONTEXT_HANDLE and returns the stored error.
 * The expansion declares a local variable, so use it only once per function. */
#define _STAI_CONTEXT_ACQUIRE(_net_ctx, _net_handle) \
  _stai_{{ net_name }}_context* _net_ctx = (_stai_{{ net_name }}_context*)(_net_handle); \
  STAI_ASSERT(_net_ctx != NULL) \
  _STAI_SET_ERROR(_net_ctx, _net_ctx->_magic != STAI_MAGIC, \
                  STAI_ERROR_NETWORK_INVALID_CONTEXT_HANDLE, _net_ctx->_return_code)


/*****************************************************************************/
/* Refresh the readiness bits in net_ctx->_flags.
 * The buffer tables checked are activations, weights and states when the model has
 * them, then inputs and outputs. For each one, the matching STAI_FLAG_* bit is OR-ed
 * in when every entry of the table is non-NULL. Bits are never cleared here. */
static
void _stai_{{ net_name }}_check(_stai_{{ net_name }}_context* net_ctx)
{
  stai_size n_ready;

{% if _activations['size'] > 0 -%}
  /* Activations: count the leading non-NULL entries */
  n_ready = 0;
  while ((n_ready < STAI_{{ NET_NAME }}_ACTIVATIONS_NUM) && (net_ctx->_activations[n_ready] != NULL)) {
    n_ready++;
  }
  net_ctx->_flags |= (n_ready == STAI_{{ NET_NAME }}_ACTIVATIONS_NUM) ? STAI_FLAG_ACTIVATIONS : STAI_FLAG_NONE;
{% endif -%}

{% if _weights['size'] > 0 -%}
  /* Weights */
  n_ready = 0;
  while ((n_ready < STAI_{{ NET_NAME }}_WEIGHTS_NUM) && (net_ctx->_weights[n_ready] != NULL)) {
    n_ready++;
  }
  net_ctx->_flags |= (n_ready == STAI_{{ NET_NAME }}_WEIGHTS_NUM) ? STAI_FLAG_WEIGHTS : STAI_FLAG_NONE;
{% endif -%}

{% if _states['size'] > 0 -%}
  /* States */
  n_ready = 0;
  while ((n_ready < STAI_{{ NET_NAME }}_STATES_NUM) && (net_ctx->_states[n_ready] != NULL)) {
    n_ready++;
  }
  net_ctx->_flags |= (n_ready == STAI_{{ NET_NAME }}_STATES_NUM) ? STAI_FLAG_STATES : STAI_FLAG_NONE;
{% endif -%}

  /* Inputs */
  n_ready = 0;
  while ((n_ready < STAI_{{ NET_NAME }}_IN_NUM) && (net_ctx->_inputs[n_ready] != NULL)) {
    n_ready++;
  }
  net_ctx->_flags |= (n_ready == STAI_{{ NET_NAME }}_IN_NUM) ? STAI_FLAG_INPUTS : STAI_FLAG_NONE;

  /* Outputs */
  n_ready = 0;
  while ((n_ready < STAI_{{ NET_NAME }}_OUT_NUM) && (net_ctx->_outputs[n_ready] != NULL)) {
    n_ready++;
  }
  net_ctx->_flags |= (n_ready == STAI_{{ NET_NAME }}_OUT_NUM) ? STAI_FLAG_OUTPUTS : STAI_FLAG_NONE;

  STAI_PRINT("  [_stai_network_check] flags: 0x%08x\n", net_ctx->_flags)
}


/*****************************************************************************/
/* Initialize a network instance inside application-provided memory.
 * `network` must point to a raw buffer of STAI_<NET>_CONTEXT_SIZE bytes, aligned on
 * STAI_<NET>_CONTEXT_ALIGNMENT. The steps are:
 *  1. validate the handle (NULL / alignment) before anything is written through it;
 *  2. check the context size;
 *  3. copy a complete context into the buffer: magic, default flags, no callback,
 *     and buffer tables pre-filled with the statically allocated buffers or NULL;
 *  4. compute the readiness flags.
 * Returns STAI_SUCCESS or the first error recorded in the context. */
STAI_API_ENTRY
stai_return_code stai_{{net_name}}_init(
  stai_network* network)
{
  /* Memory where to store internal context is provided by applications as a raw byte buffer */
  _stai_{{ net_name }}_context* net_ctx = (_stai_{{ net_name }}_context*)(network);

  /* Validate the raw handle before the first store through it: writing _return_code
     into a NULL or misaligned buffer would be undefined behavior. */
  if (!net_ctx) { return STAI_ERROR_NETWORK_INVALID_CONTEXT_HANDLE; }
  if (((uintptr_t)net_ctx) & (_STAI_CONTEXT_ALIGNMENT-1)) { return STAI_ERROR_NETWORK_INVALID_CONTEXT_ALIGNMENT; }

  net_ctx->_return_code = STAI_SUCCESS;
  STAI_PRINT("[Entering Network Init] network(%p) context_size(%d)\n", (void*)net_ctx, (int32_t)sizeof(_stai_{{ net_name }}_context))

  _STAI_SET_ERROR(net_ctx, STAI_{{ NET_NAME }}_CONTEXT_SIZE != sizeof(_stai_{{ net_name }}_context),
                 STAI_ERROR_NETWORK_INVALID_CONTEXT_SIZE, net_ctx->_return_code)

  {
    const _stai_{{ net_name }}_context _{{ net_name }}_context = {
      ._magic = STAI_MAGIC,
      ._signature = STAI_{{ NET_NAME }}_MODEL_SIGNATURE,
      ._flags = STAI_{{ NET_NAME }}_FLAGS,
      ._return_code = STAI_SUCCESS,
      ._callback = NULL,
      ._callback_cookie = NULL,
      {% if _activations['size'] > 0 -%}
      ._activations = {
      {% for buf in _activations['buffers']: -%}
        (stai_ptr)g_{{net_name}}_activations_{{loop.index}}{%- if not loop.last -%}, {%- endif -%}
      {% endfor %}
      },
      {% endif -%}
      {% if _states['size'] > 0 -%}
      ._states = {
      {% for buf in _states['buffers']: -%}
        (stai_ptr)g_{{ net_name }}_states_{{loop.index}}{%- if not loop.last -%}, {%- endif -%}
      {% endfor %}
      },
      {% endif -%}
      {% if _weights['size'] > 0 -%}
      ._weights = {
      {% for buf in _weights['buffers']: -%}
      {% if _weights['hexify'] -%}
        (stai_ptr){{buf['buffer_c_name']}}{%- if not loop.last -%}, {%- endif -%}
      {% else -%}
        NULL{%- if not loop.last -%}, {%- endif -%}
      {% endif -%}
      {% endfor %}
      },
      {% endif -%}
      ._inputs = {
    {% for buf in config['c_net_in']: -%}
      {% if config['allocate_activations'] and buf['pool_id'] is not none -%}
        (stai_ptr)g_{{net_name}}_activations_{{buf['pool_id'] + 1}} + {{buf['offset']}}{%- if not loop.last -%},{%- endif -%}
      {% else -%}
        NULL{%- if not loop.last -%},{%- endif -%}
      {% endif -%}
    {% endfor -%}
      },
      ._outputs = {
    {% for buf in config['c_net_out']: -%}
      {% if config['allocate_activations'] and buf['pool_id'] is not none -%}
        (stai_ptr)g_{{net_name}}_activations_{{buf['pool_id'] + 1}} + {{buf['offset']}}{%- if not loop.last -%},{%- endif -%}
      {% else -%}
        NULL{%- if not loop.last -%},{%- endif -%}
      {% endif -%}
    {% endfor -%}
      },
    };

    /* Copy the fully built context into the application buffer. This is a plain
       struct assignment: buffer pointers are copied, not the buffers they reference. */
    *net_ctx = _{{net_name}}_context;

    _stai_{{net_name}}_check(net_ctx);
  }

  return net_ctx->_return_code;
}


/* Deinitialize a network instance.
 * Only the flag word is reset, to its compile-time initial value STAI_<NET>_FLAGS.
 * Buffer pointers and the callback stay as they are. Returns the first error
 * recorded in the context, or STAI_SUCCESS. */
STAI_API_ENTRY
stai_return_code stai_{{ net_name }}_deinit(
  stai_network* network)
{
  _STAI_CONTEXT_ACQUIRE(net_ctx, network)
  net_ctx->_flags = STAI_{{ NET_NAME }}_FLAGS;   /* readiness bits dropped until the next set_*() */
  return net_ctx->_return_code;
}

/*****************************************************************************/
/* This section contains, in order:
 *  - prototypes of the custom-layer init/forward functions (NOTE(review): presumably
 *    implemented in user / custom-layer sources outside this file);
 *  - the hybrid-lite static descriptors (intq parameters, arrays, tensors, layers)
 *    emitted by the code generator;
 *  - the per-layer functions, rendered with a `_stai_<net>_context* net_ctx`
 *    parameter string (TODO confirm exactly how the renderer uses it). */
{%  for init_func in config['custom']['init']: -%}
void {{ init_func }}(ai_layer* layer);
{% endfor %}
{%  for forward_func in config['custom']['forward']: -%}
void {{ forward_func }}(ai_layer* layer);
{% endfor %}

{% for _, intq in hybrid_lite[layers.INTQS].items(): -%}
{{ layers.declare_intq(intq, loop.index0) }}
{% endfor %}

{% for _, array in hybrid_lite[layers.ARRAYS].items(): -%}
{{ layers.declare_array(array, loop.index0) }}
{% endfor %}

{% for _, tensor in hybrid_lite[layers.TENSORS].items(): -%}
{{ layers.declare_tensor(tensor, loop.index0) }}
{% endfor %}

{%- for layer in hybrid_lite[layers.LAYERS] %}
{{ layers.declare_layer(layer['layer'], 'NULL') }}
{% endfor -%}

{%- set stai_net_ctx = "_stai_" + net_name + "_context* net_ctx" -%}
{{ lite.hybrid_lite_render_functions(hybrid_lite['layers'], stai_net_ctx, False, NET_NAME) }}

/*****************************************************************************/

/* Execute one inference.
 * Every buffer table the model needs must be fully set, as tracked by the readiness
 * bits in net_ctx->_flags: activations, weights and states when present, plus inputs
 * and outputs. If one is missing, the matching error is recorded and returned
 * before any layer runs. `mode` is currently ignored.
 * Returns STAI_SUCCESS or the first error recorded in the context. */
STAI_API_ENTRY
stai_return_code stai_{{ net_name }}_run(
  stai_network* network,
  const stai_run_mode mode)
{
  STAI_UNUSED(mode)
  _STAI_CONTEXT_ACQUIRE(net_ctx, network)
{% if _activations['size'] > 0 %}
  _STAI_SET_ERROR(net_ctx, (net_ctx->_flags & STAI_FLAG_ACTIVATIONS) != STAI_FLAG_ACTIVATIONS,
                  STAI_ERROR_NETWORK_INVALID_ACTIVATIONS_PTR, net_ctx->_return_code)
{% endif %}
  _STAI_SET_ERROR(net_ctx, (net_ctx->_flags & STAI_FLAG_INPUTS) != STAI_FLAG_INPUTS,
                  STAI_ERROR_NETWORK_INVALID_IN_PTR, net_ctx->_return_code)
  _STAI_SET_ERROR(net_ctx, (net_ctx->_flags & STAI_FLAG_OUTPUTS) != STAI_FLAG_OUTPUTS,
                  STAI_ERROR_NETWORK_INVALID_OUT_PTR, net_ctx->_return_code)
{% if _weights['size'] > 0 %}
  _STAI_SET_ERROR(net_ctx, (net_ctx->_flags & STAI_FLAG_WEIGHTS) != STAI_FLAG_WEIGHTS,
                  STAI_ERROR_NETWORK_INVALID_WEIGHTS_PTR, net_ctx->_return_code)
{% endif %}
{% if _states['size'] > 0 %}
  /* Stateful model: refuse to run while some state buffer is still unset. The
     rendered layers presumably access these buffers (TODO confirm). */
  _STAI_SET_ERROR(net_ctx, (net_ctx->_flags & STAI_FLAG_STATES) != STAI_FLAG_STATES,
                  STAI_ERROR_NETWORK_INVALID_STATES_PTR, net_ctx->_return_code)
{% endif %}

  {% for layer in config['layers']: -%}
  {{ lite.lite_render_section(layer) }}
  {% endfor -%}

  return net_ctx->_return_code;
}

/*****************************************************************************/
/*  Getters APIs Section  */

/* Size in bytes of the opaque context buffer the application must pass to
   stai_<net>_init() (init rejects a mismatch with the actual context type). */
STAI_API_ENTRY
stai_size stai_{{ net_name }}_get_context_size(void)
{
  return (stai_size)STAI_{{ NET_NAME }}_CONTEXT_SIZE;
}

#if defined(HAVE_{{ NET_NAME }}_INFO)
/* Copy the static network description g_<net>_info into *info.
 * The copy is a struct assignment, so its table pointers still refer to the static
 * descriptor tables. A NULL `info` records STAI_ERROR_NETWORK_INVALID_INFO and
 * returns the context's first recorded error; otherwise returns STAI_SUCCESS. */
STAI_API_ENTRY
stai_return_code stai_{{ net_name }}_get_info(
  stai_network* network,
  stai_network_info* info)
{
  _STAI_CONTEXT_ACQUIRE(net_ctx, network)
  _STAI_SET_ERROR(net_ctx, NULL == info, STAI_ERROR_NETWORK_INVALID_INFO, net_ctx->_return_code)

  *info = g_{{ net_name }}_info;
  return STAI_SUCCESS;
}
#endif


/* Report the activation buffers registered in the context.
 * Writes STAI_<NET>_ACTIVATIONS_NUM to *n_activations (a NULL n_activations is an
 * invalid argument). If the model has activations and `activations` is non-NULL,
 * that array, which must hold that many entries, receives the current pointers. */
STAI_API_ENTRY
stai_return_code stai_{{ net_name }}_get_activations(
  stai_network* network, stai_ptr* activations, stai_size* n_activations)
{
  _STAI_CONTEXT_ACQUIRE(net_ctx, network)

  _STAI_SET_ERROR(net_ctx, n_activations == NULL, STAI_ERROR_NETWORK_INVALID_API_ARGUMENTS, net_ctx->_return_code)
  *n_activations = STAI_{{ NET_NAME }}_ACTIVATIONS_NUM;
{% if _activations['size'] > 0 -%}
  if (activations != NULL) {
    for (stai_size i = 0; i < STAI_{{ NET_NAME }}_ACTIVATIONS_NUM; ++i) {
      activations[i] = net_ctx->_activations[i];
    }
  }
{%- else %}
  AI_UNUSED(activations);
{% endif -%}
  return net_ctx->_return_code;
}


/* Report the weight buffers registered in the context.
 * Writes STAI_<NET>_WEIGHTS_NUM to *n_weights (a NULL n_weights is an invalid
 * argument). If the model has weights and `weights` is non-NULL, that array, which
 * must hold that many entries, receives the current pointers. */
STAI_API_ENTRY
stai_return_code stai_{{ net_name }}_get_weights(
  stai_network* network, stai_ptr* weights, stai_size* n_weights)
{
  _STAI_CONTEXT_ACQUIRE(net_ctx, network)
  _STAI_SET_ERROR(net_ctx, n_weights == NULL, STAI_ERROR_NETWORK_INVALID_API_ARGUMENTS, net_ctx->_return_code)
  *n_weights = STAI_{{ NET_NAME }}_WEIGHTS_NUM;
{% if _weights['size'] > 0 -%}
  if (weights != NULL) {
    for (stai_size i = 0; i < STAI_{{ NET_NAME }}_WEIGHTS_NUM; ++i) {
      weights[i] = net_ctx->_weights[i];
    }
  }
{%- else %}
  AI_UNUSED(weights);
{% endif -%}

  return net_ctx->_return_code;
}


/* Report the input buffers registered in the context.
 * Writes STAI_<NET>_IN_NUM to *n_inputs (a NULL n_inputs is an invalid argument).
 * If `inputs` is non-NULL, that array, which must hold that many entries, receives
 * the current pointers. */
STAI_API_ENTRY
stai_return_code stai_{{ net_name }}_get_inputs(
  stai_network* network, stai_ptr* inputs, stai_size* n_inputs)
{
  _STAI_CONTEXT_ACQUIRE(net_ctx, network)
  _STAI_SET_ERROR(net_ctx, n_inputs == NULL, STAI_ERROR_NETWORK_INVALID_API_ARGUMENTS, net_ctx->_return_code)

  *n_inputs = STAI_{{ NET_NAME }}_IN_NUM;
  if (inputs != NULL) {
    for (stai_size i = 0; i < STAI_{{ NET_NAME }}_IN_NUM; ++i) {
      inputs[i] = net_ctx->_inputs[i];
    }
  }
  return net_ctx->_return_code;
}


/* Report the output buffers registered in the context.
 * Writes STAI_<NET>_OUT_NUM to *n_outputs (a NULL n_outputs is an invalid argument).
 * If `outputs` is non-NULL, that array, which must hold that many entries, receives
 * the current pointers. */
STAI_API_ENTRY
stai_return_code stai_{{ net_name }}_get_outputs(
  stai_network* network, stai_ptr* outputs, stai_size* n_outputs)
{
  _STAI_CONTEXT_ACQUIRE(net_ctx, network)
  _STAI_SET_ERROR(net_ctx, n_outputs == NULL, STAI_ERROR_NETWORK_INVALID_API_ARGUMENTS, net_ctx->_return_code)

  *n_outputs = STAI_{{ NET_NAME }}_OUT_NUM;
  if (outputs != NULL) {
    for (stai_size i = 0; i < STAI_{{ NET_NAME }}_OUT_NUM; ++i) {
      outputs[i] = net_ctx->_outputs[i];
    }
  }
  return net_ctx->_return_code;
}


/* Return the first error recorded in the context, or STAI_SUCCESS if none so far
   (_STAI_SET_ERROR only stores an error while _return_code is still STAI_SUCCESS). */
STAI_API_ENTRY
stai_return_code stai_{{ net_name }}_get_error(
  stai_network* network)
{
  _STAI_CONTEXT_ACQUIRE(net_ctx, network)
  return net_ctx->_return_code;
}


/* Report the internal state buffers registered in the context (one per state pool).
 * Writes STAI_<NET>_STATES_NUM to *n_states (a NULL n_states is an invalid argument).
 * If the model has states and `states` is non-NULL, that array, which must hold that
 * many entries, receives the current pointers. */
STAI_API_ENTRY
stai_return_code stai_{{ net_name }}_get_states(
  stai_network* network, stai_ptr* states, stai_size* n_states)
{
  _STAI_CONTEXT_ACQUIRE(net_ctx, network)
  _STAI_SET_ERROR(net_ctx, n_states == NULL, STAI_ERROR_NETWORK_INVALID_API_ARGUMENTS, net_ctx->_return_code)

  *n_states = STAI_{{ NET_NAME }}_STATES_NUM;
{% if _states['size'] > 0 -%}
  if (states != NULL) {
    for (stai_size i = 0; i < STAI_{{ NET_NAME }}_STATES_NUM; ++i) {
      states[i] = net_ctx->_states[i];
    }
  }
{%- else %}
  STAI_UNUSED(states)
{% endif -%}
  return net_ctx->_return_code;
}


/*****************************************************************************/
/*  Setters APIs Section  */

/* Register the application-provided activation buffers.
 * Requirements: `activations` is non-NULL, n_activations equals
 * STAI_<NET>_ACTIVATIONS_NUM, and each pointer is non-NULL and aligned as given by
 * STAI_<NET>_ACTIVATIONS_ALIGNMENTS.
 * Pointers are stored one by one, so a failure part-way leaves earlier entries updated.
 * Inputs/outputs that live inside an activation pool are then re-pointed at
 * pool base + offset. On the success path the readiness flags are refreshed.
 * Returns the first recorded error or STAI_SUCCESS. */
STAI_API_ENTRY
stai_return_code stai_{{ net_name }}_set_activations(
  stai_network* network,
  const stai_ptr* activations,
  const stai_size n_activations)
{
  _STAI_CONTEXT_ACQUIRE(net_ctx, network)
{% if _activations['size'] > 0 -%}
  const uintptr_t align_req[] = STAI_{{ NET_NAME }}_ACTIVATIONS_ALIGNMENTS;
  STAI_PRINT("  [stai_{{ net_name }}_set_activations] network(%p) activations[%d]: %p\n\n", net_ctx, n_activations, activations)
  _STAI_SET_ERROR(net_ctx, !activations,
                  STAI_ERROR_NETWORK_INVALID_API_ARGUMENTS, net_ctx->_return_code)
  _STAI_SET_ERROR(net_ctx, n_activations!=STAI_{{ NET_NAME }}_ACTIVATIONS_NUM,
                  STAI_ERROR_NETWORK_INVALID_ACTIVATIONS_NUM, net_ctx->_return_code)

  for (stai_size i = 0; activations && (i < STAI_{{ NET_NAME }}_ACTIVATIONS_NUM); ++i) {
    const stai_ptr ptr = activations[i];
    STAI_PRINT("  activation[%d]: %p\n", i, ptr)
    _STAI_SET_ERROR(net_ctx, ptr == NULL,
                    STAI_ERROR_NETWORK_INVALID_ACTIVATIONS_PTR, net_ctx->_return_code)
    _STAI_SET_ERROR(net_ctx, ((uintptr_t)ptr) & (align_req[i]-1),
                    STAI_ERROR_INVALID_BUFFER_ALIGNMENT, net_ctx->_return_code)
    net_ctx->_activations[i] = ptr;
  }

{%- for buf in config['c_net_in']: %}
{%- if buf['pool_id'] is not none %}
  net_ctx->_inputs[{{loop.index0}}] = activations[{{buf['pool_id']}}] + {{buf['offset']}};
{% endif -%}
{% endfor -%}
{%- for buf in config['c_net_out']: %}
{%- if buf['pool_id'] is not none %}
  net_ctx->_outputs[{{loop.index0}}] = activations[{{buf['pool_id']}}] + {{buf['offset']}};
{% endif -%}
{% endfor -%}
{%- else %}
  AI_UNUSED(activations);
  AI_UNUSED(n_activations);

{% endif -%}
  _stai_{{ net_name }}_check(net_ctx);
  return net_ctx->_return_code;
}


/* Register the application-provided weight buffers.
 * Requirements: `weights` is non-NULL, n_weights equals STAI_<NET>_WEIGHTS_NUM, and
 * each pointer is non-NULL and aligned as given by STAI_<NET>_WEIGHTS_ALIGNMENTS.
 * Pointers are stored one by one, so a failure part-way leaves earlier entries updated.
 * For models without weight pools the arguments are ignored. On the success path
 * the readiness flags are refreshed. */
STAI_API_ENTRY
stai_return_code stai_{{ net_name }}_set_weights(
  stai_network* network,
  const stai_ptr* weights,
  const stai_size n_weights)
{
  _STAI_CONTEXT_ACQUIRE(net_ctx, network)
{% if _weights['size'] > 0 -%}
  const uintptr_t align_req[] = STAI_{{ NET_NAME }}_WEIGHTS_ALIGNMENTS;
  _STAI_SET_ERROR(net_ctx, !weights,
                  STAI_ERROR_NETWORK_INVALID_API_ARGUMENTS, net_ctx->_return_code)
  _STAI_SET_ERROR(net_ctx, n_weights!=STAI_{{ NET_NAME }}_WEIGHTS_NUM,
                  STAI_ERROR_NETWORK_INVALID_WEIGHTS_NUM, net_ctx->_return_code)
  for (stai_size i = 0; weights && (i < STAI_{{ NET_NAME }}_WEIGHTS_NUM); ++i) {
    const stai_ptr ptr = weights[i];
    STAI_PRINT("  weight[%d]: %p\n", i, ptr)
    _STAI_SET_ERROR(net_ctx, ptr == NULL,
                    STAI_ERROR_NETWORK_INVALID_WEIGHTS_PTR, net_ctx->_return_code)
    _STAI_SET_ERROR(net_ctx, ((uintptr_t)ptr) & (align_req[i]-1),
                    STAI_ERROR_INVALID_BUFFER_ALIGNMENT, net_ctx->_return_code)
    net_ctx->_weights[i] = ptr;
  }
{%- else %}
  AI_UNUSED(weights);
  AI_UNUSED(n_weights);
{% endif -%}

  _stai_{{ net_name }}_check(net_ctx);
  return net_ctx->_return_code;
}


/* Register the application-provided input buffers.
 * Requirements: `inputs` is non-NULL, n_inputs equals STAI_<NET>_IN_NUM, and each
 * pointer is non-NULL and aligned as given by STAI_<NET>_IN_ALIGNMENTS.
 * Pointers are stored one by one, so a failure part-way leaves earlier entries updated.
 * On the success path the readiness flags are refreshed. */
STAI_API_ENTRY
stai_return_code stai_{{ net_name }}_set_inputs(
  stai_network* network,
  const stai_ptr* inputs,
  const stai_size n_inputs)
{
  _STAI_CONTEXT_ACQUIRE(net_ctx, network)
  const uintptr_t align_req[] = STAI_{{ NET_NAME }}_IN_ALIGNMENTS;

  _STAI_SET_ERROR(net_ctx, !inputs,
                  STAI_ERROR_NETWORK_INVALID_API_ARGUMENTS, net_ctx->_return_code)
  _STAI_SET_ERROR(net_ctx, n_inputs!=STAI_{{ NET_NAME }}_IN_NUM,
                  STAI_ERROR_NETWORK_INVALID_IN_NUM, net_ctx->_return_code)

  for (stai_size i = 0; inputs && (i < STAI_{{ NET_NAME }}_IN_NUM); ++i) {
    const stai_ptr ptr = inputs[i];
    STAI_PRINT("  input[%d]: %p\n", i, ptr)
    _STAI_SET_ERROR(net_ctx, ptr == NULL,
                    STAI_ERROR_NETWORK_INVALID_IN_PTR, net_ctx->_return_code)
    _STAI_SET_ERROR(net_ctx, ((uintptr_t)ptr) & (align_req[i]-1),
                    STAI_ERROR_INVALID_BUFFER_ALIGNMENT, net_ctx->_return_code)
    net_ctx->_inputs[i] = ptr;
  }

  _stai_{{ net_name }}_check(net_ctx);
  return net_ctx->_return_code;
}


/* Register the application-provided output buffers.
 * Requirements: `outputs` is non-NULL, n_outputs equals STAI_<NET>_OUT_NUM, and each
 * pointer is non-NULL and aligned as given by STAI_<NET>_OUT_ALIGNMENTS.
 * Pointers are stored one by one, so a failure part-way leaves earlier entries updated.
 * On the success path the readiness flags are refreshed. Returns the first recorded
 * error or STAI_SUCCESS. */
STAI_API_ENTRY
stai_return_code stai_{{ net_name }}_set_outputs(
  stai_network* network,
  const stai_ptr* outputs,
  const stai_size n_outputs)
{
  const uintptr_t _outputs_alignment[] = STAI_{{ NET_NAME }}_OUT_ALIGNMENTS;
  _STAI_CONTEXT_ACQUIRE(net_ctx, network)
  _STAI_SET_ERROR(net_ctx, !outputs,
                  STAI_ERROR_NETWORK_INVALID_API_ARGUMENTS, net_ctx->_return_code)
  _STAI_SET_ERROR(net_ctx, n_outputs!=STAI_{{ NET_NAME }}_OUT_NUM,
                  STAI_ERROR_NETWORK_INVALID_OUT_NUM, net_ctx->_return_code)

  /* Bound the loop by the compile-time output count, as the other setters do. That
     count sizes the tables indexed here (_outputs_alignment[] and net_ctx->_outputs[]);
     a caller-supplied count is not used for indexing. */
  for (stai_size idx=0; outputs && idx<STAI_{{ NET_NAME }}_OUT_NUM; idx++) {
    STAI_PRINT("  output[%d]: %p\n", idx, outputs[idx])
    _STAI_SET_ERROR(net_ctx, outputs[idx]==NULL,
                    STAI_ERROR_NETWORK_INVALID_OUT_PTR, net_ctx->_return_code)
    _STAI_SET_ERROR(net_ctx, ((uintptr_t)outputs[idx]) & (_outputs_alignment[idx]-1),
                    STAI_ERROR_INVALID_BUFFER_ALIGNMENT, net_ctx->_return_code)
    net_ctx->_outputs[idx] = outputs[idx];
  }

  _stai_{{ net_name }}_check(net_ctx);
  return net_ctx->_return_code;
}


/* Register the application-provided internal state buffers (stateful models).
 * Requirements: `states` is non-NULL, n_states equals STAI_<NET>_STATES_NUM, and each
 * pointer is non-NULL and aligned as given by STAI_<NET>_STATES_ALIGNMENTS.
 * Pointers are stored one by one, so a failure part-way leaves earlier entries updated.
 * For models without states the arguments are ignored. On the success path the
 * readiness flags are refreshed. */
STAI_API_ENTRY
stai_return_code stai_{{ net_name }}_set_states(
  stai_network* network,
  const stai_ptr* states,
  const stai_size n_states)
{
  _STAI_CONTEXT_ACQUIRE(net_ctx, network)
{% if _states['size'] > 0 -%}
  const uintptr_t align_req[] = STAI_{{ NET_NAME }}_STATES_ALIGNMENTS;
  STAI_PRINT("  [stai_{{ net_name }}_set_states] network(%p) states[%d]: %p\n\n", net_ctx, n_states, states)
  _STAI_SET_ERROR(net_ctx, !states,
                  STAI_ERROR_NETWORK_INVALID_API_ARGUMENTS, net_ctx->_return_code)
  _STAI_SET_ERROR(net_ctx, n_states!=STAI_{{ NET_NAME }}_STATES_NUM,
                  STAI_ERROR_NETWORK_INVALID_STATES_NUM, net_ctx->_return_code)

  for (stai_size i = 0; states && (i < STAI_{{ NET_NAME }}_STATES_NUM); ++i) {
    const stai_ptr ptr = states[i];
    STAI_PRINT("  state[%d]: %p\n", i, ptr)
    _STAI_SET_ERROR(net_ctx, ptr == NULL,
                    STAI_ERROR_NETWORK_INVALID_STATES_PTR, net_ctx->_return_code)
    _STAI_SET_ERROR(net_ctx, ((uintptr_t)ptr) & (align_req[i]-1),
                    STAI_ERROR_INVALID_BUFFER_ALIGNMENT, net_ctx->_return_code)
    net_ctx->_states[i] = ptr;
  }
{%- else %}
  STAI_UNUSED(states)
  STAI_UNUSED(n_states)
{% endif -%}

  _stai_{{ net_name }}_check(net_ctx);
  return net_ctx->_return_code;
}

/* Register the event callback and its user cookie on the context.
 * A NULL `cb` is accepted: the node event hooks only call net_ctx->_callback when it
 * is non-NULL, so passing NULL disables notifications.
 * Returns the first error recorded in the context, or STAI_SUCCESS. */
STAI_API_ENTRY
stai_return_code stai_{{ net_name }}_set_callback(
  stai_network* network, const stai_event_cb cb, void* cb_cookie)
{
  _STAI_CONTEXT_ACQUIRE(net_ctx, network)
  STAI_PRINT("  set_callback %p cb %p cookie %p\n", net_ctx, cb, cb_cookie)

  net_ctx->_callback_cookie = cb_cookie;
  net_ctx->_callback = cb;

  return net_ctx->_return_code;
}

/* Keep the private helper macros local to this translation unit. */
#undef _STAI_CONTEXT_ACQUIRE
#undef _STAI_SET_ERROR
#undef _STAI_CONTEXT_ALIGNMENT
#undef _STAI_{{ NET_NAME }}_EVENT_NODE_START_CB
#undef _STAI_{{ NET_NAME }}_EVENT_NODE_STOP_CB
#undef _STAI_{{ NET_NAME }}_COMPILE_DATETIME
#undef _STAI_{{ NET_NAME }}_DATETIME
#undef _STAI_{{ NET_NAME }}_MODEL_SIGNATURE

{{ "" }}
