#ifndef HAVE_CHECKPOINT_HPP
#define HAVE_CHECKPOINT_HPP
// Autogenerated - do not edit by hand !
#include <memory>

#include "global.hpp"
#include "vectorize.hpp"

namespace TMBad {

/** \brief Fixed derivative table used by `AtomOp` */
template <class ADFun, bool packed_ = false>
struct standard_derivative_table : std::vector<ADFun> {
  static const bool packed = packed_;
  /** \brief Add derivatives up to this order. */
  void requireOrder(size_t n) {
    while ((*this).size() <= n) {
      Position root;
      std::swap((*this).back().tail_start, root);
      ADFun deriv = (*this).back().WgtJacFun();
      std::swap((*this).back().tail_start, root);

      (*this).push_back(deriv);
    }
  }
  /** \brief Retaping this derivative table has no effect. */
  template <class ARGS>
  void retape(ARGS &args) {}
  /** \brief Set zero order function of this derivative table. */
  standard_derivative_table(const ADFun &F) : std::vector<ADFun>(1, F) {}
};

/** \brief Default tester for `retaping_derivative_table`.
    \details Function object that is called with a parameter vector and
    carries a parameter vector `x_prev` as state. Only the declaration of
    `operator()` is visible here; presumably it compares `x` against
    `x_prev` - confirm against the definition.
*/
struct ParametersChanged {
  /** \brief This tester operates on plain (non packed) parameters. */
  static const bool packed = false;
  /** \brief Stored parameter vector (state of the tester). */
  std::vector<Scalar> x_prev;
  /** \brief Test whether retaping is required.
      \param x Current parameter vector
      \return `true` if retaping is required. `false` otherwise.
  */
  bool operator()(const std::vector<Scalar> &x);
};

/** \brief Adaptive derivative table used by `AtomOp`
    \details Extends `standard_derivative_table` with the functor `F` that
    generated the zero order entry and a tester `test` that decides when
    the table has to be rebuilt.
*/
template <class Functor, class ADFun, class Test = ParametersChanged,
          bool packed_ = false>
struct retaping_derivative_table : standard_derivative_table<ADFun, packed_> {
  /** \brief Functor used to (re)create the zero order entry. */
  Functor F;
  /** \brief Tester called with the current inputs by `retape`. */
  Test test;
  /** \brief Retape the zero derivative and remove all higher orders
      from the table.
      \details Nothing happens unless `test` returns `true` for the
      current inputs.
  */
  template <class ARGS>
  void retape(ARGS &args) {
    const size_t num_inputs = (*this)[0].Domain();
    std::vector<Scalar> x_now = args.x_segment(0, num_inputs);
    const bool must_retape = test(x_now);
    if (!must_retape) return;
    // Drop all higher order entries, then rebuild order zero.
    this->resize(1);
    (*this)[0] = ADFun(F, x_now);
  }
  /** \brief Set zero order *functor* used to retape this derivative
      table.
      \param F Functor taped to obtain the zero order entry
      \param x Inputs at which `F` is taped initially
      \param test Tester deciding when to retape
  */
  template <class V>
  retaping_derivative_table(const Functor &F, const V &x, Test test = Test())
      : standard_derivative_table<ADFun, packed_>(ADFun(F, x)),
        F(F),
        test(test) {}
};

/** \brief Manage shared operator data across multiple threads.
    \details An operator that shares a resource using the normal
    `std::shared_ptr` can be made thread safe by simply replacing
    'shared_ptr' by 'omp_shared_ptr'. Thread safety is guaranteed in
    the following two situations:

    1. Parallel taping: The thread that *constructs* an operator is the
    same thread that *uses* the operator. Thread safety (and memory
    locality) in this case is already guaranteed by the normal
    `std::shared_ptr`.

    2. Automatic parallelization: The tape is constructed using a single
    thread, then split and **copied** to the individual worker
    threads. Thread safety (and memory locality) in this case requires
    `omp_shared_ptr`.

    Each instance has an extra pointer to a 'book-keeping' structure,
    essentially doubling its size compared to the normal shared_ptr.
    Operator copy is still a fairly quick operation involving
    essentially two reference counter increments.

    The resource sharing has the following properties:

    - *Space efficiency* Threads that do *not* require the shared
       resource will *not* have an allocated copy (no waste across
       threads).
    - *Memory locality* First touch principle: When a new thread takes
       ownership a deep copy is invoked.
    - *Access performance* Once the object is constructed there is no
       additional overhead compared to normal `std::shared_ptr`.

    \note An empty instance (default constructed, or constructed from
    an empty `std::shared_ptr`) can be copied; the copy is empty as
    well.
*/
template <class T>
struct omp_shared_ptr {
  typedef std::shared_ptr<T> Base;
  /** \brief Handle to the object used by this instance. */
  Base sp;
  /** \brief Book-keeping shared by all copies: one weak reference per
      thread slot, indexed by `TMBAD_THREAD_NUM`. Null for a default
      constructed instance. */
  std::shared_ptr<std::vector<std::weak_ptr<T> > > weak_refs;

  /** \brief Take over `x` and register it in the slot of the calling
      thread. */
  omp_shared_ptr(const Base &x)
      : sp(x), weak_refs(std::make_shared<std::vector<std::weak_ptr<T> > >()) {
    (*weak_refs).resize(TMBAD_MAX_NUM_THREADS);
    (*weak_refs)[TMBAD_THREAD_NUM] = x;
  }
  /** \brief Copy for use by the calling thread.
      \details If the slot of the calling thread refers to a live
      object, that object is shared. Otherwise a deep copy of the object
      held by `other` is made and registered in the slot.
  */
  omp_shared_ptr(const omp_shared_ptr &other) : weak_refs(other.weak_refs) {
    // Copy of a default constructed instance: there is no book-keeping
    // structure to consult, so the copy stays empty.
    if (!weak_refs) return;
    std::weak_ptr<T> &slot = (*weak_refs)[TMBAD_THREAD_NUM];
    sp = slot.lock();
    // Deep copy only if the slot is vacant and there is something to
    // copy (dereferencing an empty `other` would be undefined).
    if (!sp && other.sp) {
      sp = std::make_shared<T>(*other);

      slot = sp;
    }
  }
  /** \brief Construct an empty instance. */
  omp_shared_ptr() {}
  T &operator*() const { return *sp; }
  T *operator->() const { return sp.get(); }
  explicit operator bool() const { return (bool)sp; }
};

/** \brief Generic checkpoint operator
    \details This class implements checkpointing.

    In short, a checkpoint operator is an *atomic operator* for which
    the derivatives are generated automatically up to any order using
    AD. The derivatives are stored in a *derivatives table* that can
    be shared among different instances of the operator.

    There are two main use cases of interest:

    1. The **fixed graph case** where the computational graph doesn't
    change. The purpose of this case is to reduce memory when the same
    operation sequence is repeated many times. This case uses a
    `standard_derivative_table`.
    2. The **adaptive case** where the computational graph can change
    dynamically. The purpose of this case is to allow algorithms that
    use parameter dependent branching. This case uses a
    `retaping_derivative_table`.

    Memory management
    -----------------

    - Last operator 'alive' cleans up the shared derivatives table.

    - A derivative table can *always* be shared among operators within
      the same tape. However, shared derivative tables cause an
      overhead because the forward pass written by one operator may be
      invalidated by another operator. The reverse pass must account
      for this by re-calculating the forward values if necessary.

    - A derivative table can *never* be shared between different
      threads. A thread must have its own local (deep) copy of the
      table. If parallel tapes are constructed independently this is
      automatically the case.

    Retaping
    --------

    - Changed parameter inputs triggers retaping of the derivative
      table in the adaptive case. This happens correctly even in cases
      where the zero order operator is no longer present on the tape
      (it may have been removed by the tape optimizer).

    \note The fixed graph case can in principle allow for sparsity,
    but this is not yet implemented.
*/
template <class DerivativeTable>
struct AtomOp : global::DynamicOperator<-1, -1> {
  static const bool have_input_size_output_size = true;
  static const bool add_forward_replay_copy = true;

  /** \brief Derivative table used by this operator. */
  TMBAD_SHARED_PTR<DerivativeTable> dtab;

  /** \brief Index into the derivative table evaluated by this
      operator. */
  int order;

  /** \brief Flags forwarded as the `clear_all` / `deriv_all` arguments
      of `Jacobian` in the numeric reverse pass. */
  struct control {
    bool clear_all;
    bool deriv_all;
    control() : clear_all(true), deriv_all(true) {}

    /** \brief Exchange the two flags. */
    void swap() { std::swap(clear_all, deriv_all); }
  } ctrl;

  /** \brief Construct an operator without a derivative table.
      \details `order` is set to zero so that copying a default
      constructed operator never reads an uninitialized value. */
  AtomOp() : order(0) {}

  /** \brief Construct the derivative table from the given arguments and
      start at order zero. */
  template <class T1>
  AtomOp(const T1 &F) : dtab(std::make_shared<DerivativeTable>(F)), order(0) {}
  template <class T1, class T2>
  AtomOp(const T1 &F, const T2 &x)
      : dtab(std::make_shared<DerivativeTable>(F, x)), order(0) {}
  template <class T1, class T2, class T3>
  AtomOp(const T1 &F, const T2 &x, const T3 &t)
      : dtab(std::make_shared<DerivativeTable>(F, x, t)), order(0) {}

  /** \brief Number of inputs of the table entry of the current order. */
  Index input_size() const { return (*dtab)[order].Domain(); }
  /** \brief Number of outputs of the table entry of the current order. */
  Index output_size() const { return (*dtab)[order].Range(); }

  /** \brief Numeric forward pass: evaluate the table entry of the
      current order at the inputs and store the result in the outputs. */
  void forward(ForwardArgs<Scalar> &args) {
    // Give the table the chance to rebuild itself for the current inputs
    (*dtab).retape(args);

    // A retape may have removed higher orders; make sure ours exists
    (*dtab).requireOrder(order);

    size_t n = input_size();
    size_t m = output_size();

    auto x = args.x_segment(0, n);

    args.y_segment(0, m) = (*dtab)[order](x);
  }

  /** \brief Numeric reverse pass: add the `Jacobian` of the table entry
      of the current order, evaluated with inputs `x` and weights `w`
      (the output derivatives), to the input derivatives. */
  void reverse(ReverseArgs<Scalar> &args) {
    (*dtab).retape(args);

    (*dtab).requireOrder(order);

    size_t n = input_size();
    size_t m = output_size();

    auto x = args.x_segment(0, n);
    auto w = args.dy_segment(0, m);

    args.dx_segment(0, n) +=
        (*dtab)[order].Jacobian(x, w, ctrl.clear_all, ctrl.deriv_all);
  }

  /** \brief Replay reverse pass: express the derivative as a new
      `AtomOp` of order `order + 1` applied to the concatenation of the
      inputs and the output derivatives. */
  void reverse(ReverseArgs<global::Replay> &args) {
    size_t n = input_size();
    size_t m = output_size();

    std::vector<global::Replay> x = args.x_segment(0, n);
    if (DerivativeTable::packed) x = repack(x);
    std::vector<global::Replay> w = args.dy_segment(0, m);
    // Inputs of the next order operator: (x, w)
    std::vector<global::Replay> xw;
    xw.insert(xw.end(), x.begin(), x.end());
    xw.insert(xw.end(), w.begin(), w.end());

    (*dtab).requireOrder(order + 1);
    // The copy refers to the same table but evaluates the next order,
    // with the two control flags exchanged.
    AtomOp cpy(*this);
    cpy.order++;
    cpy.ctrl.swap();
    args.dx_segment(0, n) += global::Complete<AtomOp>(cpy)(xw);
  }

  /** \brief Forward passes for other argument types are not
      supported. */
  template <class T>
  void forward(ForwardArgs<T> &args) {
    TMBAD_ASSERT(false);
  }
  /** \brief Source code generation is not supported. */
  void reverse(ReverseArgs<Writer> &args) { TMBAD_ASSERT(false); }

  static const bool have_custom_identifier = true;
  /** \brief Identify the operator by the address of its derivative
      table. */
  void *custom_identifier() { return &(*dtab); }

  const char *op_name() { return "AtomOp"; }

  /** \brief Print the order, the table size and address, followed by
      the table entry of the current order. */
  void print(global::print_config cfg) {
    Rcout << cfg.prefix;
    Rcout << "order=" << order << " ";
    Rcout << "(*dtab).size()=" << (*dtab).size() << " ";
    Rcout << "dtab=" << &(*dtab) << "\n";
    (*dtab)[order].print(cfg);
  }
};

/** \brief Transform a functor to have packed input/output

    \details A functor operating on `ad_segment`s must be transformed to
    'packed' format in order to be used by `AtomOp` with `packed=true`.

    - Evaluation functor is constructed using `PackWrap<Functor>`.
    - The tester object is constructed using `PackWrap<Tester>`.
*/
template <class Functor>
struct PackWrap {
  Functor F;
  PackWrap(const Functor &F) : F(F) {}
  /** \brief Transformed functor assuming original maps
   * `std::vector<ad_segment>` to `ad_segment` */
  template <class T>
  std::vector<T> operator()(const std::vector<T> &xp) {
    Index K = ScalarPack<SegmentRef>::size;
    size_t n = xp.size() / K;
    TMBAD_ASSERT2(n * K == xp.size(), "Invalid packed arguments");
    std::vector<ad_segment> x(n);
    for (size_t i = 0; i < n; i++) x[i] = unpack(xp, i);
    ad_segment y = F(x);
    ad_segment yp = pack(y);
    std::vector<T> ans = concat(std::vector<ad_segment>(1, yp));
    return ans;
  }
  /** \brief Transformed 'tester' assuming original maps `std::vector<Scalar*>`
   * to `bool` */
  bool operator()(const std::vector<Scalar> &xp) {
    Index K = ScalarPack<SegmentRef>::size;
    size_t n = xp.size() / K;
    TMBAD_ASSERT2(n * K == xp.size(), "Invalid packed arguments");
    std::vector<Scalar *> x(n);
    for (size_t i = 0; i < n; i++) x[i] = unpack(xp, i);
    return F(x);
  }
};

}  // namespace TMBad
#endif  // HAVE_CHECKPOINT_HPP
