// Autogenerated - do not edit by hand !
#include "TMBad.hpp"
namespace TMBad {

// Default configuration: `compress` is off and `index_remap` is on.
SpJacFun_config::SpJacFun_config()
    : compress(false),
      index_remap(true) {}
}  // namespace TMBad
// Autogenerated - do not edit by hand !
#include "ad_blas.hpp"
namespace TMBad {

// Matrix product of two `vmatrix` operands.
// Allocates an x.rows() by y.cols() result, wraps its storage in a Map and
// delegates to the templated matmul overload (all four flags false).
vmatrix matmul(const vmatrix &x, const vmatrix &y) {
  vmatrix result(x.rows(), y.cols());
  Map<vmatrix> result_view(&result(0), result.rows(), result.cols());
  matmul<false, false, false, false>(x, y, result_view);
  return result;
}

// Matrix product of two `dmatrix` operands via the matrix type's own
// operator*.
dmatrix matmul(const dmatrix &x, const dmatrix &y) {
  dmatrix product = x * y;
  return product;
}
}  // namespace TMBad
// Autogenerated - do not edit by hand !
#include "checkpoint.hpp"
namespace TMBad {

// Returns true when `x` differs from the vector seen on the previous call,
// and in that case remembers `x` as the new reference point.
bool ParametersChanged::operator()(const std::vector<Scalar> &x) {
  if (x == x_prev) return false;
  x_prev = x;
  return true;
}
}  // namespace TMBad
// Autogenerated - do not edit by hand !
#include "code_generator.hpp"
namespace TMBad {

// Replace, in place, every non-overlapping occurrence of `oldStr` in `str`
// by `newStr`.
//
// Scanning resumes right after each inserted replacement, so a `newStr` that
// itself contains `oldStr` (e.g. "]" -> "][idx]") is not re-expanded.
// An empty `oldStr` leaves `str` untouched.
void searchReplace(std::string &str, const std::string &oldStr,
                   const std::string &newStr) {
  // An empty pattern matches at every position; without this guard the loop
  // below would keep inserting `newStr` forever.
  if (oldStr.empty()) return;
  std::string::size_type pos = 0u;
  while ((pos = str.find(oldStr, pos)) != std::string::npos) {
    str.replace(pos, oldStr.length(), newStr);
    // Skip over the text just inserted.
    pos += newStr.length();
  }
}

// Pointer type used for the value/derivative arguments of generated code:
// the scalar type name followed by "**" in GPU mode and "*" otherwise.
std::string code_config::float_ptr() {
  const char *stars = gpu ? "**" : "*";
  return float_str + stars;
}

// Declaration prefix for generated functions, depending on the `gpu` flag.
std::string code_config::void_str() {
  if (gpu) return "__device__ void";
  return "extern \"C\" void";
}

// Emit the preamble of a generated function body. Only GPU mode needs one:
// the declaration of the thread index `idx`.
void code_config::init_code() {
  if (!gpu) return;
  *cout << indent << "int idx = threadIdx.x;" << std::endl;
}

// Write the configured header comment on its own line; nothing is written
// when the comment is empty.
void code_config::write_header_comment() {
  if (header_comment.empty()) return;
  *cout << header_comment << std::endl;
}

// Default code-generation settings: asm node comments enabled, GPU mode
// enabled, two-space indent, the standard "autogenerated" header comment,
// the scalar type name taken from TMBAD_SCALAR_TYPE, and output sent to Rcout.
code_config::code_config()
    : asm_comments(true),
      gpu(true),
      indent("  "),
      header_comment("// Autogenerated - do not edit by hand !"),
      float_str(xstringify(TMBAD_SCALAR_TYPE)),
      cout(&Rcout) {}

// Emit the generated statement(s) of a single tape node.
//
// buffer: text produced by the node's forward/reverse Writer pass.
// cfg:    output configuration (stream, indent, GPU mode, asm comments).
// node:   node number, used only for the optional asm marker comment.
void write_common(std::ostringstream &buffer, code_config cfg, size_t node) {
  std::ostream &out = *cfg.cout;
  const std::string &indent = cfg.indent;
  if (cfg.asm_comments)
    out << indent << "asm(\"// Node: " << node << "\");" << std::endl;

  // Nothing was generated for this node.
  if (buffer.tellp() == 0) return;

  std::string code = buffer.str();
  // In GPU mode every array access gets an extra thread index.
  if (cfg.gpu) searchReplace(code, "]", "][idx]");
  // Put a space between consecutive statements.
  searchReplace(code, ";v", "; v");
  searchReplace(code, ";d", "; d");
  out << indent << code << std::endl;
}

// Generate the source of the `forward` function for the tape `glob`.
// Every operator is run with Writer arguments; its output is captured in a
// per-node buffer and passed to write_common().
void write_forward(global &glob, code_config cfg) {
  std::ostream &out = *cfg.cout;
  cfg.write_header_comment();
  out << cfg.void_str() << " forward(" << cfg.float_ptr() << " v) {"
      << std::endl;
  cfg.init_code();
  ForwardArgs<Writer> args(glob.inputs, glob.values);
  for (size_t node = 0; node < glob.opstack.size(); ++node) {
    std::ostringstream node_code;
    // Writer assignments print to whatever Writer::cout points at.
    Writer::cout = &node_code;
    glob.opstack[node]->forward(args);
    write_common(node_code, cfg, node);
    glob.opstack[node]->increment(args.ptr);
  }
  out << "}" << std::endl;
}

// Generate the source of the `reverse` function for the tape `glob`.
// Operators are visited from last to first; the argument pointer is
// decremented before each operator's reverse pass is captured.
void write_reverse(global &glob, code_config cfg) {
  std::ostream &out = *cfg.cout;
  cfg.write_header_comment();
  out << cfg.void_str() << " reverse(" << cfg.float_ptr() << " v, "
      << cfg.float_ptr() << " d) {" << std::endl;
  cfg.init_code();
  ReverseArgs<Writer> args(glob.inputs, glob.values);
  for (size_t node = glob.opstack.size(); node-- > 0;) {
    glob.opstack[node]->decrement(args.ptr);
    std::ostringstream node_code;
    // Writer assignments print to whatever Writer::cout points at.
    Writer::cout = &node_code;
    glob.opstack[node]->reverse(args);
    write_common(node_code, cfg, node);
  }
  out << "}" << std::endl;
}

// Write a complete translation unit: two include lines, the generated
// forward and reverse functions, and an empty main().
void write_all(global glob, code_config cfg) {
  std::ostream &out = *cfg.cout;
  out << "#include \"global.hpp\"" << std::endl
      << "#include \"ad_blas.hpp\"" << std::endl;
  write_forward(glob, cfg);
  write_reverse(glob, cfg);
  out << "int main() {}" << std::endl;
}
}  // namespace TMBad
#ifndef _WIN32
// Autogenerated - do not edit by hand !
#include "compile.hpp"
namespace TMBad {

// Generate C++ source for the forward and reverse sweeps of `glob`, compile
// it into a shared library with g++ and install the resulting entry points
// in `glob.forward_compiled` / `glob.reverse_compiled`.
//
// glob: tape to compile; only modified when every step succeeds.
// cfg:  code generation settings; GPU mode and asm comments are forced off
//       and the output stream is redirected to the file "tmp.cpp".
//
// On any failure (file cannot be written, compiler error, library or symbols
// cannot be loaded) a message is printed to Rcout and `glob` is left
// untouched, so evaluation keeps using the interpreted tape.
void compile(global &glob, code_config cfg) {
  cfg.gpu = false;
  cfg.asm_comments = false;
  std::ofstream file;
  file.open("tmp.cpp");
  if (!file.is_open()) {
    Rcout << "Could not open 'tmp.cpp' for writing - skipping compilation"
          << std::endl;
    return;
  }
  cfg.cout = &file;

  *cfg.cout << "#include <cmath>" << std::endl;
  *cfg.cout
      << "template<class T>T sign(const T &x) { return (x > 0) - (x < 0); }"
      << std::endl;

  write_forward(glob, cfg);

  write_reverse(glob, cfg);

  // Make sure the complete source is on disk before the compiler reads it.
  file.close();

  int out = system("g++ -O3 -g tmp.cpp -o tmp.so -shared -fPIC");
  if (out != 0) {
    // Do not fall through: a 'tmp.so' left over from an earlier run would
    // otherwise be loaded in place of the code that failed to build.
    Rcout << "Compilation of 'tmp.cpp' failed (status " << out
          << ") - not loading compiled code" << std::endl;
    return;
  }

  void *handle = dlopen("./tmp.so", RTLD_NOW);
  if (handle == NULL) {
    Rcout << "Could not load './tmp.so'" << std::endl;
    return;
  }
  void *forward_sym = dlsym(handle, "forward");
  void *reverse_sym = dlsym(handle, "reverse");
  if (forward_sym == NULL || reverse_sym == NULL) {
    Rcout << "Compiled library lacks 'forward' or 'reverse'" << std::endl;
    dlclose(handle);
    return;
  }
  Rcout << "Loading compiled code!" << std::endl;
  glob.forward_compiled = reinterpret_cast<void (*)(Scalar *)>(forward_sym);
  glob.reverse_compiled =
      reinterpret_cast<void (*)(Scalar *, Scalar *)>(reverse_sym);
}
}  // namespace TMBad
#endif
// Autogenerated - do not edit by hand !
#include "compression.hpp"
namespace TMBad {

// Print a period as "begin: B size: S rep: R".
std::ostream &operator<<(std::ostream &os, const period &x) {
  return os << "begin: " << x.begin << " size: " << x.size
            << " rep: " << x.rep;
}

// Split an operator period `p` into consecutive sub-periods according to the
// regularity of its input indices.
//
// glob:            tape that owns the operators and the input index array.
// p:               period in the operation stack (begin, size, rep).
// max_period_size: forwarded to the period search on each row of differences.
//
// Returns periods that all have the same `size` as `p`, start at multiples
// of `p.size` from `p.begin`, and whose `rep` counts sum to the original
// `p.rep`.
std::vector<period> split_period(global *glob, period p,
                                 size_t max_period_size) {
  typedef std::ptrdiff_t ptrdiff_t;
  // Ensure glob->subgraph_ptr is available for the lookup below.
  glob->subgraph_cache_ptr();

  // Position in glob->inputs of the first input of the period.
  size_t offset = glob->subgraph_ptr[p.begin].first;

  // Number of inputs consumed by one repetition of the period.
  size_t nrow = 0;
  for (size_t i = 0; i < p.size; i++) {
    nrow += glob->opstack[p.begin + i]->input_size();
  }

  size_t ncol = p.rep;

  // View the period's inputs as an nrow x ncol matrix: one column per
  // repetition.
  matrix_view<Index> x(&(glob->inputs[offset]), nrow, ncol);

  // marks[j] == true means: start a new sub-period at repetition j + 1.
  // NOTE(review): assumes p.rep >= 1, otherwise ncol - 1 wraps around.
  std::vector<bool> marks(ncol - 1, false);

  for (size_t i = 0; i < nrow; i++) {
    // Periods found in the differences of row i.
    // NOTE(review): row_diff presumably yields the ncol - 1 successive
    // differences along the row - confirm in its definition.
    std::vector<period> pd =
        periodic<ptrdiff_t>(x.row_diff<ptrdiff_t>(i), max_period_size)
            .find_all();

    // Mark a split before each detected period and after its end.
    for (size_t j = 0; j < pd.size(); j++) {
      if (pd[j].begin > 0) {
        marks[pd[j].begin - 1] = true;
      }
      size_t end = pd[j].begin + pd[j].size * pd[j].rep;
      if (end < marks.size()) marks[end] = true;
    }
  }

  // Walk over the repetitions: a mark opens a new period, otherwise the
  // current one is extended by one repetition.
  std::vector<period> ans;
  p.rep = 1;
  ans.push_back(p);
  for (size_t j = 0; j < marks.size(); j++) {
    if (marks[j]) {
      period pnew = p;
      pnew.begin = p.begin + (j + 1) * p.size;
      pnew.rep = 1;
      ans.push_back(pnew);
    } else {
      ans.back().rep++;
    }
  }

  return ans;
}

// Number of inputs of one repetition (the member `n`).
size_t compressed_input::input_size() const {
  return n;
}

// Refresh the increments of the periodic inputs for the current `counter`:
// each one takes the entry of its period selected by counter modulo the
// period size.
void compressed_input::update_increment_pattern() const {
  for (size_t k = 0; k < (size_t)np; k++) {
    increment_pattern[which_periodic[k]] =
        period_data[period_offsets[k] + counter % period_sizes[k]];
  }
}

// Advance the input indices by one repetition and rewind the input pointer
// of `args` to zero.
void compressed_input::increment(Args<> &args) const {
  if (np != 0) {
    // Periodic inputs: pick this step's increments, then move on.
    update_increment_pattern();
    ++counter;
  }
  for (size_t k = 0; k < n; ++k) {
    inputs[k] += increment_pattern[k];
  }
  args.ptr.first = 0;
}

// Step the input indices back by one repetition and place the input pointer
// of `args` just past the last input.
void compressed_input::decrement(Args<> &args) const {
  args.ptr.first = input_size();
  for (size_t k = 0; k < n; ++k) {
    inputs[k] -= increment_pattern[k];
  }
  if (np != 0) {
    // Periodic inputs: step back, then load the increments of that step.
    --counter;
    update_increment_pattern();
  }
}

// Prepare a forward pass: reset the repetition counter, copy the current
// inputs of `args` into the private buffer and redirect `args` to it.
void compressed_input::forward_init(Args<> &args) const {
  counter = 0;
  inputs.resize(input_size());
  const size_t count = inputs.size();
  for (size_t k = 0; k < count; ++k) {
    inputs[k] = args.input(k);
  }
  args.inputs = inputs.data();
  args.ptr.first = 0;
}

// Prepare a reverse pass: position the private input buffer and the output
// pointer of `args` after the last repetition.
void compressed_input::reverse_init(Args<> &args) {
  inputs.resize(input_size());
  // Inputs as they are after all repetitions: start values plus the
  // accumulated differences stored in input_diff.
  for (size_t i = 0; i < inputs.size(); i++)
    inputs[i] = args.input(i) + input_diff[i];

  // From here on `args` reads its inputs from the private buffer.
  args.inputs = inputs.data();
  args.ptr.first = 0;
  // Skip the outputs of all repetitions (m outputs each).
  args.ptr.second += m * nrep;
  // Load the increment pattern belonging to the last repetition.
  counter = nrep - 1;
  update_increment_pattern();
  args.ptr.first = input_size();
}

// For every input slot compute the smallest (`lower`) and largest (`upper`)
// index it takes while stepping through the repetitions.
void compressed_input::dependencies_intervals(Args<> &args,
                                              std::vector<Index> &lower,
                                              std::vector<Index> &upper) const {
  forward_init(args);
  lower = inputs;
  upper = inputs;
  for (size_t rep = 0; rep < nrep; ++rep) {
    for (size_t slot = 0; slot < inputs.size(); ++slot) {
      const Index current = inputs[slot];
      if (current < lower[slot]) lower[slot] = current;
      if (current > upper[slot]) upper[slot] = current;
    }
    increment(args);
  }
}

// True when `x` repeats with period `p`, i.e. x[j] == x[j % p] for all j.
bool compressed_input::test_period(std::vector<ptrdiff_t> &x, size_t p) {
  size_t j = 0;
  while (j < x.size() && x[j] == x[j % p]) {
    ++j;
  }
  // Reached the end only if no mismatch was found.
  return j == x.size();
}

// Smallest period below `max_period_size` that `x` satisfies; falls back to
// the length of `x` when no such period exists.
size_t compressed_input::find_shortest(std::vector<ptrdiff_t> &x) {
  size_t candidate = 1;
  while (candidate < max_period_size) {
    if (test_period(x, candidate)) return candidate;
    ++candidate;
  }
  return x.size();
}

// Default constructor: performs no explicit member initialization.
compressed_input::compressed_input() {}

// Build the compressed description of a repeated block of input indices.
//
// x:               input index array; the block starts at `offset`.
// offset:          position in `x` of the first input of the block.
// nrow:            number of inputs per repetition (stored as `n`).
// m:               number of outputs per repetition.
// ncol:            number of repetitions (stored as `nrep`).
// max_period_size: upper bound used by find_shortest().
compressed_input::compressed_input(std::vector<Index> &x, size_t offset,
                                   size_t nrow, size_t m, size_t ncol,
                                   size_t max_period_size)
    : n(nrow), m(m), nrep(ncol), counter(0), max_period_size(max_period_size) {
  // One column per repetition, one row per input slot.
  matrix_view<Index> xm(&x[offset], nrow, ncol);

  for (size_t i = 0; i < nrow; i++) {
    // NOTE(review): row_diff presumably returns the successive differences
    // along row i; rd[0] below requires it to be non-empty (ncol >= 2) -
    // confirm against callers.
    std::vector<ptrdiff_t> rd = xm.row_diff<ptrdiff_t>(i);

    size_t p = find_shortest(rd);

    increment_pattern.push_back(rd[0]);
    // p == 1 means a constant increment; anything else is a periodic row
    // whose increments must be looked up in period_data.
    if (p != 1) {
      which_periodic.push_back(i);
      period_sizes.push_back(p);

      // Reuse an identical sequence already stored in period_data if there
      // is one, otherwise append this row's period.
      size_t pos = std::search(period_data.begin(), period_data.end(),
                               rd.begin(), rd.begin() + p) -
                   period_data.begin();
      if (pos < period_data.size()) {
        period_offsets.push_back(pos);
      } else {
        period_offsets.push_back(period_data.size());
        period_data.insert(period_data.end(), rd.begin(), rd.begin() + p);
      }
    }
  }

  np = which_periodic.size();

  // Accumulate the total index shift over all repetitions by starting from
  // zeros and incrementing nrep times; the result is kept in input_diff.
  input_diff.resize(n, 0);
  Args<> args(input_diff);
  forward_init(args);
  for (size_t i = 0; i < nrep; i++) {
    increment(args);
  }
  input_diff = inputs;
}

// Build a StackOp from the period `p` of `glob`'s operation stack: the
// operators of one repetition are copied and the input indices of all
// repetitions are stored in compressed form.
StackOp::StackOp(global *glob, period p, IndexPair ptr, size_t max_period_size)
    : shared(std::make_shared<shared_data>()) {
  global::operation_stack &ops = shared->opstack;
  ops.resize(p.size);
  size_t total_inputs = 0;
  size_t total_outputs = 0;
  for (size_t k = 0; k < p.size; ++k) {
    ops[k] = glob->opstack[p.begin + k]->copy();
    total_inputs += ops[k]->input_size();
    total_outputs += ops[k]->output_size();
  }
  shared->ci = compressed_input(glob->inputs, ptr.first, total_inputs,
                                total_outputs, p.rep, max_period_size);
}

// Print the operator names and the compressed input description to Rcout,
// each line prefixed with `cfg.prefix`. The period tables are only printed
// when at least one input is periodic.
void StackOp::print(global::print_config cfg) {
  global::operation_stack &ops = shared->opstack;
  compressed_input &ci = shared->ci;

  std::vector<const char *> names(ops.size());
  for (size_t k = 0; k < ops.size(); ++k) {
    names[k] = ops[k]->op_name();
  }
  Rcout << cfg.prefix << " opstack = " << names << "\n";

  Rcout << cfg.prefix << " " << "nrep" << " = " << ci.nrep << "\n";
  Rcout << cfg.prefix << " " << "increment_pattern" << " = "
        << ci.increment_pattern << "\n";

  const bool has_periodic = ci.which_periodic.size() > 0;
  if (has_periodic) {
    Rcout << cfg.prefix << " " << "which_periodic" << " = "
          << ci.which_periodic << "\n";
    Rcout << cfg.prefix << " " << "period_sizes" << " = " << ci.period_sizes
          << "\n";
    Rcout << cfg.prefix << " " << "period_offsets" << " = "
          << ci.period_offsets << "\n";
    Rcout << cfg.prefix << " " << "period_data" << " = " << ci.period_data
          << "\n";
  }

  Rcout << "\n";
}

// Number of inputs: those of a single repetition.
Index StackOp::input_size() const {
  return shared->ci.n;
}

// Number of outputs: outputs per repetition times number of repetitions.
Index StackOp::output_size() const {
  const compressed_input &ci = shared->ci;
  return ci.m * ci.nrep;
}

// Source-code generating forward pass: emits a `for` loop that runs the
// stacked operators `nrep` times, using generated index arrays
// (i, ip, wp, ps, po, pd, o) to address inputs and outputs.
// NOTE(review): the text is streamed through a local Writer; where that text
// ends up is decided by Writer's operator<<, which is defined elsewhere.
void StackOp::forward(ForwardArgs<Writer> &args) {
  global::operation_stack &opstack(shared->opstack);
  compressed_input &ci(shared->ci);
  size_t n = ci.n, m = ci.m, nrep = ci.nrep;
  // Start indices of the inputs and outputs of the first repetition.
  std::vector<Index> inputs(n);
  for (size_t i = 0; i < (size_t)n; i++) inputs[i] = args.input(i);
  std::vector<Index> outputs(m);
  for (size_t i = 0; i < (size_t)m; i++) outputs[i] = args.output(i);
  Writer w;
  size_t np = ci.which_periodic.size();
  size_t sp = ci.period_data.size();
  // Loop header: all index arrays are declared in the init-statement.
  w << "for (int count = 0, ";
  if (n > 0) {
    w << "i[" << n << "]=" << inputs << ", " << "ip[" << n
      << "]=" << ci.increment_pattern << ", ";
  }
  if (np > 0) {
    w << "wp[" << np << "]=" << ci.which_periodic << ", " << "ps[" << np
      << "]=" << ci.period_sizes << ", " << "po[" << np
      << "]=" << ci.period_offsets << ", " << "pd[" << sp
      << "]=" << ci.period_data << ", ";
  }
  w << "o[" << m << "]=" << outputs << "; " << "count < " << nrep
    << "; count++) {\n";

  // Loop body: the stacked operators, generated once on a copy of `args`
  // switched to indirect addressing.
  w << "    ";
  ForwardArgs<Writer> args_cpy = args;
  args_cpy.set_indirect();
  for (size_t k = 0; k < opstack.size(); k++) {
    opstack[k]->forward_incr(args_cpy);
  }
  w << "\n";

  // Update the increments of the periodic inputs for this iteration.
  if (np > 0) {
    w << "    ";
    for (size_t k = 0; k < np; k++)
      w << "ip[wp[" << k << "]] = pd[po[" << k << "] + count % ps[" << k
        << "]]; ";
    w << "\n";
  }
  // Advance the input indices.
  if (n > 0) {
    w << "    ";
    for (size_t k = 0; k < n; k++) w << "i[" << k << "] += ip[" << k << "]; ";
    w << "\n";
  }
  // Advance every output index by the number of outputs per repetition.
  w << "    ";
  for (size_t k = 0; k < m; k++) w << "o[" << k << "] += " << m << "; ";
  w << "\n";

  w << "  ";
  w << "}";
}

// Source-code generating reverse pass: emits a `for` loop that counts down
// from `nrep`, steps the generated index arrays backwards and then runs the
// stacked operators in reverse order.
// NOTE(review): the text is streamed through a local Writer; where that text
// ends up is decided by Writer's operator<<, which is defined elsewhere.
void StackOp::reverse(ReverseArgs<Writer> &args) {
  global::operation_stack &opstack(shared->opstack);
  compressed_input &ci(shared->ci);
  size_t n = ci.n, m = ci.m, nrep = ci.nrep;
  // Input indices as they stand after the last repetition.
  std::vector<ptrdiff_t> inputs(input_size());
  for (size_t i = 0; i < inputs.size(); i++) {
    ptrdiff_t tmp;
    // input_diff is stored with the index type. When the negated value
    // compares smaller, the entry is treated as a negative offset and
    // converted to a signed value.
    // NOTE(review): presumably Index is unsigned so negative offsets are
    // stored wrapped around - confirm.
    if (-ci.input_diff[i] < ci.input_diff[i]) {
      tmp = -((ptrdiff_t)-ci.input_diff[i]);
    } else {
      tmp = ci.input_diff[i];
    }
    inputs[i] = args.input(i) + tmp;
  }
  // Output indices just past the last repetition.
  std::vector<Index> outputs(ci.m);
  for (size_t i = 0; i < (size_t)ci.m; i++)
    outputs[i] = args.output(i) + ci.m * ci.nrep;
  Writer w;
  size_t np = ci.which_periodic.size();
  size_t sp = ci.period_data.size();
  // Loop header: all index arrays are declared in the init-statement.
  w << "for (int count = " << nrep << ", ";
  if (n > 0) {
    w << "i[" << n << "]=" << inputs << ", " << "ip[" << n
      << "]=" << ci.increment_pattern << ", ";
  }
  if (np > 0) {
    w << "wp[" << np << "]=" << ci.which_periodic << ", " << "ps[" << np
      << "]=" << ci.period_sizes << ", " << "po[" << np
      << "]=" << ci.period_offsets << ", " << "pd[" << sp
      << "]=" << ci.period_data << ", ";
  }
  w << "o[" << m << "]=" << outputs << "; " << "count > 0 ; ) {\n";

  // The counter is decremented first, then the index arrays are stepped
  // back to the repetition being processed.
  w << "    ";
  w << "count--;\n";
  if (np > 0) {
    w << "    ";
    for (size_t k = 0; k < np; k++)
      w << "ip[wp[" << k << "]] = pd[po[" << k << "] + count % ps[" << k
        << "]]; ";
    w << "\n";
  }
  if (n > 0) {
    w << "    ";
    for (size_t k = 0; k < n; k++) w << "i[" << k << "] -= ip[" << k << "]; ";
    w << "\n";
  }
  w << "    ";
  for (size_t k = 0; k < m; k++) w << "o[" << k << "] -= " << m << "; ";
  w << "\n";

  w << "    ";

  // Loop body: the stacked operators in reverse order, generated on a copy
  // of `args` that uses indirect addressing and starts past the last
  // input/output of one repetition.
  ReverseArgs<Writer> args_cpy = args;
  args_cpy.set_indirect();
  args_cpy.ptr.first = ci.n;
  args_cpy.ptr.second = ci.m;
  for (size_t k = opstack.size(); k > 0;) {
    k--;
    opstack[k]->reverse_decr(args_cpy);
  }
  w << "\n";

  w << "  ";
  w << "}";
}

// Report the inputs of all repetitions as one index interval per input slot.
void StackOp::dependencies(Args<> args, Dependencies &dep) const {
  std::vector<Index> lo;
  std::vector<Index> hi;
  shared->ci.dependencies_intervals(args, lo, hi);
  for (size_t slot = 0; slot < lo.size(); ++slot) {
    dep.add_interval(lo[slot], hi[slot]);
  }
}

// Name of this operator.
const char *StackOp::op_name() {
  return "StackOp";
}

// Reorder the operators of `glob` based on a hash sweep of the tape, then
// rebuild the tape in the new order.
void reorder_sub_expressions(global &glob) {
  // Hash configuration with all `strong_*` options and `reduce` switched off.
  global::hash_config cfg;
  cfg.strong_inv = false;
  cfg.strong_const = false;
  cfg.strong_output = false;
  cfg.reduce = false;
  cfg.deterministic = TMBAD_DETERMINISTIC_HASH;
  std::vector<hash_t> h = glob.hash_sweep(cfg);
  // NOTE(review): presumably maps each variable to the first variable with
  // the same hash - confirm in radix::first_occurance.
  std::vector<Index> remap = radix::first_occurance<Index>(h);

  TMBAD_ASSERT(all_allow_remap(glob));

  // Single forward walk over the tape: apply the toposort_remap functor to
  // the dependencies of every operator.
  Args<> args(glob.inputs);
  for (size_t i = 0; i < glob.opstack.size(); i++) {
    Dependencies dep;
    glob.opstack[i]->dependencies(args, dep);

    // Index of the operator's first output variable.
    Index var = args.ptr.second;
    toposort_remap<Index> fb(remap, var);
    dep.apply(fb);
    glob.opstack[i]->increment(args.ptr);
  }

  // Order variables by their remap value, translate to operators and
  // extract the tape in that order.
  std::vector<Index> ord = radix::order<Index>(remap);
  std::vector<Index> v2o = glob.var2op();
  glob.subgraph_seq = subset(v2o, ord);

  glob = glob.extract_sub();
}

// Reorder the operators of `glob` using a variable remapping built by the
// temporaries_remap functor, then rebuild the tape in the new order.
void reorder_temporaries(global &glob) {
  // Index(-1) marks "not remapped yet".
  std::vector<Index> remap(glob.values.size(), Index(-1));
  // Forward walk: apply temporaries_remap to the sorted, de-duplicated
  // dependencies of every operator.
  Args<> args(glob.inputs);
  for (size_t i = 0; i < glob.opstack.size(); i++) {
    Dependencies dep;
    glob.opstack[i]->dependencies(args, dep);
    sort_unique_inplace(dep);
    // Index of the operator's first output variable.
    Index var = args.ptr.second;
    temporaries_remap<Index> fb(remap, var);
    dep.apply(fb);
    glob.opstack[i]->increment(args.ptr);
  }

  // Backward walk: unmarked variables map to themselves, marked ones follow
  // their target, which is already final at this point when it has a larger
  // index.
  for (size_t i = remap.size(); i > 0;) {
    i--;
    if (remap[i] == Index(-1))
      remap[i] = i;
    else
      remap[i] = remap[remap[i]];
  }

  // Order variables by their remap value, translate to operators and
  // extract the tape in that order.
  std::vector<Index> ord = radix::order<Index>(remap);
  std::vector<Index> v2o = glob.var2op();
  glob.subgraph_seq = subset(v2o, ord);

  glob = glob.extract_sub();
}

// Reorder the operators of `glob` by an iterative depth-first traversal
// started from each dependent variable, so that an operator is emitted only
// after the operators it depends on.
void reorder_depth_first(global &glob) {
  // done[i]: operator i has already been added to `result`.
  std::vector<bool> done(glob.opstack.size(), false);
  std::vector<Index> v2o = glob.var2op();
  std::vector<Index> stack;
  std::vector<Index> result;
  Args<> args(glob.inputs);
  glob.subgraph_cache_ptr();
  for (size_t k = 0; k < glob.dep_index.size(); k++) {
    // Operator that produces the k'th dependent variable.
    Index dep_var = glob.dep_index[k];
    Index i = v2o[dep_var];

    stack.push_back(i);
    while (stack.size() > 0) {
      // Note: this `i` shadows the one declared before the loop.
      Index i = stack.back();
      args.ptr = glob.subgraph_ptr[i];
      Dependencies dep;
      glob.opstack[i]->dependencies(args, dep);
      // NOTE(review): presumably pushes the not-yet-done operators that
      // produce the dependencies - confirm in dfs_add_to_stack.
      dfs_add_to_stack<Index> add_to_stack(stack, done, v2o);
      size_t before = stack.size();
      dep.apply(add_to_stack);
      size_t after = stack.size();

      // Nothing was pushed: operator i can be emitted and popped.
      if (before == after) {
        if (!done[i]) {
          result.push_back(i);
          done[i] = true;
        }
        stack.pop_back();
      }
    }
  }

  glob.subgraph_seq = result;
  glob = glob.extract_sub();

  glob.shrink_to_fit();
}

// Compress repeated operator sequences of `glob` into StackOp operators.
//
// glob:            tape to compress in place.
// max_period_size: maximal period length considered by the period searches.
void compress(global &glob, size_t max_period_size) {
  size_t min_period_rep = TMBAD_MIN_PERIOD_REP;
  // Find periodic runs of operators in the operation stack.
  periodic<global::OperatorPure *> p(glob.opstack, max_period_size,
                                     min_period_rep);
  std::vector<period> periods = p.find_all();

  // Refine every operator period by the regularity of its inputs.
  std::vector<period> periods_expand;
  for (size_t i = 0; i < periods.size(); i++) {
    std::vector<period> tmp = split_period(&glob, periods[i], max_period_size);

    // Too many pieces: keep the period unsplit.
    if (tmp.size() > 10) {
      tmp.resize(0);
      tmp.push_back(periods[i]);
    }

    // Only pieces that actually repeat are kept.
    for (size_t j = 0; j < tmp.size(); j++) {
      if (tmp[j].rep > 1) periods_expand.push_back(tmp[j]);
    }
  }

  std::swap(periods, periods_expand);
  OperatorPure *null_op = get_glob()->getOperator<global::NullOp>();
  // `ptr` tracks the input/output position of operator `k` while walking
  // the stack up to the start of each period.
  IndexPair ptr(0, 0);
  Index k = 0;
  for (size_t i = 0; i < periods.size(); i++) {
    period p = periods[i];
    TMBAD_ASSERT(p.rep >= 1);
    while (k < p.begin) {
      glob.opstack[k]->increment(ptr);
      k++;
    }

    // The StackOp is built while the original operators are still in place.
    OperatorPure *pOp =
        get_glob()->getOperator<StackOp>(&glob, p, ptr, max_period_size);
    // Count the inputs of the replaced operators while releasing them and
    // putting null operators in their place.
    Index ninp = 0;
    for (size_t j = 0; j < p.size * p.rep; j++) {
      ninp += glob.opstack[p.begin + j]->input_size();
      glob.opstack[p.begin + j]->deallocate();
      glob.opstack[p.begin + j] = null_op;
    }
    glob.opstack[p.begin] = pOp;
    // The slot after the StackOp gets a NullOp2 constructed with the number
    // of inputs the StackOp does not consume itself.
    // (periods here have rep > 1, so slot p.begin + 1 is inside the period)
    ninp -= pOp->input_size();
    glob.opstack[p.begin + 1] =
        get_glob()->getOperator<global::NullOp2>(ninp, 0);
  }

  // Rebuild the tape keeping all variables, then release spare capacity.
  std::vector<bool> marks(glob.values.size(), true);
  glob.extract_sub_inplace(marks);
  glob.shrink_to_fit();
}
}  // namespace TMBad
// Autogenerated - do not edit by hand !
#include "global.hpp"
namespace TMBad {

// Table of tape pointers with one slot per thread number; every slot starts
// out as NULL.
global *global_ptr_data[TMBAD_MAX_NUM_THREADS] = {NULL};
// Pointer through which the table above is indexed (see get_glob()).
global **global_ptr = global_ptr_data;
// Stream that Writer's assignment operators print to; null until a code
// generator installs one.
std::ostream *Writer::cout = 0;
// Static flag, initially off.
// NOTE(review): presumably enables operator fusion - confirm in global.hpp.
bool global::fuse = 0;

// Tape pointer stored in the slot of the current thread number.
global *get_glob() {
  return global_ptr[TMBAD_THREAD_NUM];
}

// Default constructor: no explicit initialization.
Dependencies::Dependencies() {}

// Remove all single-index dependencies and all intervals.
void Dependencies::clear() {
  I.resize(0);
  this->resize(0);
}

// Append the closed index interval [a, b] to the interval list.
void Dependencies::add_interval(Index a, Index b) {
  std::pair<Index, Index> interval(a, b);
  I.push_back(interval);
}

// Append the `size` consecutive indices beginning at `start` as one
// interval; an empty segment adds nothing.
void Dependencies::add_segment(Index start, Index size) {
  if (!(size > 0)) return;
  add_interval(start, start + size - 1);
}

// Replace every stored index k by x[k]: first the single indices, then both
// end points of every interval.
void Dependencies::monotone_transform_inplace(const std::vector<Index> &x) {
  for (size_t k = 0; k < this->size(); ++k) {
    (*this)[k] = x[(*this)[k]];
  }
  for (size_t k = 0; k < I.size(); ++k) {
    I[k].first = x[I[k].first];
    I[k].second = x[I[k].second];
  }
}

bool Dependencies::any(const std::vector<bool> &x) const {
  for (size_t i = 0; i < this->size(); i++)
    if (x[(*this)[i]]) return true;
  for (size_t i = 0; i < I.size(); i++) {
    for (Index j = I[i].first; j <= I[i].second; j++) {
      if (x[j]) return true;
    }
  }
  return false;
}

// Text form of an index, using default stream formatting.
std::string tostr(const Index &x) {
  std::ostringstream out;
  out << x;
  return out.str();
}

// Text form of a scalar, using default stream formatting.
std::string tostr(const Scalar &x) {
  std::ostringstream out;
  out << x;
  return out.str();
}

// A Writer is a string holding a piece of generated source text.

// Wrap an existing piece of text.
Writer::Writer(std::string str) : std::string(str) {}

// Text form of a scalar constant, as produced by tostr().
Writer::Writer(Scalar x) : std::string(tostr(x)) {}

// Empty text.
Writer::Writer() {}

// Enclose `x` in parentheses.
std::string Writer::p(std::string x) {
  std::string wrapped = "(";
  wrapped += x;
  wrapped += ")";
  return wrapped;
}

// Arithmetic operators build the text of the corresponding expression.
// Sums and differences are wrapped in parentheses; products, quotients and
// the unary minus are not.

// "(a + b)"
Writer Writer::operator+(const Writer &other) {
  return p(*this + " + " + other);
}

// "(a - b)"
Writer Writer::operator-(const Writer &other) {
  return p(*this + " - " + other);
}

// " - a"
Writer Writer::operator-() { return " - " + *this; }

// "a * b"
Writer Writer::operator*(const Writer &other) { return *this + " * " + other; }

// "a / b"
Writer Writer::operator/(const Writer &other) { return *this + " / " + other; }

// "a*c" with the scalar constant formatted by tostr().
Writer Writer::operator*(const Scalar &other) {
  return *this + "*" + tostr(other);
}

// "(a+c)" with the scalar constant formatted by tostr().
Writer Writer::operator+(const Scalar &other) {
  return p(*this + "+" + tostr(other));
}

// Assignment operators do not modify the Writer. They print the statement
// "<this> OP <other>;" (no trailing newline) to the stream Writer::cout
// points at, which must have been set beforehand.

// "a = b;"
void Writer::operator=(const Writer &other) {
  *cout << *this + " = " + other << ";";
}

// "a += b;"
void Writer::operator+=(const Writer &other) {
  *cout << *this + " += " + other << ";";
}

// "a -= b;"
void Writer::operator-=(const Writer &other) {
  *cout << *this + " -= " + other << ";";
}

// "a *= b;"
void Writer::operator*=(const Writer &other) {
  *cout << *this + " *= " + other << ";";
}

// "a /= b;"
void Writer::operator/=(const Writer &other) {
  *cout << *this + " /= " + other << ";";
}

// Position given by an operator number and an index pair (first, second).
Position::Position(Index node, Index first, Index second)
    : node(node),
      ptr(first, second) {}

// Position with node and both indices equal to zero.
Position::Position()
    : node(0),
      ptr(0, 0) {}

// Positions are ordered by their operator number only.
bool Position::operator<(const Position &other) const {
  return node < other.node;
}

// Empty graph.
graph::graph() {}

// Number of neighbors of `node`: the distance between consecutive entries
// of `p`.
size_t graph::num_neighbors(Index node) {
  return p[node + 1] - p[node];
}

// Pointer to the first neighbor of `node` inside `j`.
Index *graph::neighbors(Index node) {
  return &j[p[node]];
}

// True when `p` has no entries.
bool graph::empty() {
  return p.size() == 0;
}

// Number of nodes: one less than the length of `p`, or zero when empty.
size_t graph::num_nodes() {
  if (empty()) return 0;
  return p.size() - 1;
}

// Print one line per node to Rcout: the node number followed by its
// neighbors.
void graph::print() {
  for (size_t node = 0; node < num_nodes(); ++node) {
    Rcout << node << ": ";
    for (size_t k = 0; k < num_neighbors(node); ++k) {
      Rcout << " " << neighbors(node)[k];
    }
    Rcout << "\n";
  }
}

std::vector<Index> graph::rowcounts() {
  std::vector<Index> ans(num_nodes());
  for (size_t i = 0; i < ans.size(); i++) ans[i] = num_neighbors(i);
  return ans;
}

// For every node, the number of times it occurs as a neighbor.
std::vector<Index> graph::colcounts() {
  std::vector<Index> counts(num_nodes());
  for (size_t e = 0; e < j.size(); ++e) {
    counts[j[e]]++;
  }
  return counts;
}

// Append to `result` every not yet visited neighbor of the nodes in `start`
// and mark it as visited.
//
// `start` and `result` may be the same vector (see graph::search). In that
// case the nodes appended here are processed in turn, so the traversal
// continues until no new node is found. This is why the loop re-reads
// start.size() and indexes start[i] on every iteration instead of using
// iterators, which a push_back could invalidate.
void graph::bfs(const std::vector<Index> &start, std::vector<bool> &visited,
                std::vector<Index> &result) {
  for (size_t i = 0; i < start.size(); i++) {
    Index node = start[i];
    for (size_t j_ = 0; j_ < num_neighbors(node); j_++) {
      Index k = neighbors(node)[j_];
      if (!visited[k]) {
        result.push_back(k);
        visited[k] = true;
      }
    }
  }
}

// Search from `start` using the graph's own `mark` vector as visited flags.
// On return `start` holds the search result and the marks of all nodes in
// it have been cleared again.
void graph::search(std::vector<Index> &start, bool sort_input,
                   bool sort_output) {
  // Allocate the marks on first use.
  if (mark.empty()) mark.resize(num_nodes(), false);

  search(start, mark, sort_input, sort_output);

  for (size_t k = 0; k < start.size(); ++k) {
    mark[start[k]] = false;
  }
}

// Extend `start` in place with the nodes found by searching from it.
//
// start:       start nodes on entry; start nodes plus newly found nodes on
//              return.
// visited:     visited flags; set for the start nodes and for every node
//              added.
// sort_input:  sort and de-duplicate `start` before searching.
// sort_output: sort `start` before returning.
void graph::search(std::vector<Index> &start, std::vector<bool> &visited,
                   bool sort_input, bool sort_output) {
  if (sort_input) sort_unique_inplace(start);

  for (size_t i = 0; i < start.size(); i++) visited[start[i]] = true;

  // `start` is passed as both the work list and the result, so bfs keeps
  // going through the nodes it appends.
  bfs(start, visited, start);

  if (sort_output) sort_inplace(start);
}

std::vector<Index> graph::boundary(const std::vector<Index> &subgraph) {
  if (mark.size() == 0) mark.resize(num_nodes(), false);

  std::vector<Index> boundary;

  for (size_t i = 0; i < subgraph.size(); i++) mark[subgraph[i]] = true;

  bfs(subgraph, mark, boundary);

  for (size_t i = 0; i < subgraph.size(); i++) mark[subgraph[i]] = false;
  for (size_t i = 0; i < boundary.size(); i++) mark[boundary[i]] = false;

  return boundary;
}

// Build the adjacency structure from an edge list.
// After construction the neighbors of node i are j[p[i]] ... j[p[i + 1] - 1],
// in the order their edges appear in `edges`.
graph::graph(size_t num_nodes, const std::vector<IndexPair> &edges) {
  // Count the outgoing edges of every node.
  std::vector<Index> row_counts(num_nodes, 0);
  for (size_t e = 0; e < edges.size(); ++e) {
    row_counts[edges[e].first]++;
  }

  // Cumulative counts give the start offset of every node.
  p.resize(num_nodes + 1);
  p[0] = 0;
  for (size_t node = 0; node < num_nodes; ++node) {
    p[node + 1] = p[node] + row_counts[node];
  }

  // Fill the neighbor array, keeping one write cursor per node.
  std::vector<Index> cursor(p);
  j.resize(edges.size());
  for (size_t e = 0; e < edges.size(); ++e) {
    j[cursor[edges[e].first]++] = edges[e].second;
  }
}

// Flag set with no flag raised.
op_info::op_info() : code(0) {
  // Compile-time check that IntRep has at least one bit per flag.
  static_assert(sizeof(IntRep) * 8 >= op_flag_count,
                "'IntRep' not wide enough!");
}

// Flag set with exactly the flag `f` raised.
op_info::op_info(op_flag f)
    : code(1 << f) {}

// True when the flag `f` is raised.
bool op_info::test(op_flag f) const {
  return (code & (1 << f)) != 0;
}

// Raise every flag that is raised in `other` (set union).
op_info &op_info::operator|=(const op_info &other) {
  code |= other.code;
  return *this;
}

// Keep only the flags that are also raised in `other` (set intersection).
op_info &op_info::operator&=(const op_info &other) {
  code &= other.code;
  return *this;
}

// Empty operation stack.
global::operation_stack::operation_stack() {}

// Copy construction delegates to copy_from().
global::operation_stack::operation_stack(const operation_stack &other) {
  copy_from(other);
}

// Append an operator and merge its info flags into `any`, the union of the
// flags of all operators pushed so far.
void global::operation_stack::push_back(OperatorPure *x) {
  Base::push_back(x);

  any |= x->info();
}

// Copy assignment: release the current content, then copy from `other`.
// Self-assignment is a no-op.
operation_stack &global::operation_stack::operator=(
    const operation_stack &other) {
  if (this == &other) return *this;
  clear();
  copy_from(other);
  return *this;
}

// Destruction releases the operators through clear().
global::operation_stack::~operation_stack() {
  clear();
}

// Empty the stack. The operators are deallocated first, but only when the
// `dynamic` flag is raised in `any`.
void global::operation_stack::clear() {
  const bool has_dynamic = any.test(op_info::dynamic);
  if (has_dynamic) {
    for (size_t k = 0; k < this->size(); ++k) {
      (*this)[k]->deallocate();
    }
  }
  this->resize(0);
}

// Fill this stack from `other`. When `other` has the `dynamic` flag raised
// each operator is appended through its copy() method; otherwise the
// pointers are assigned as they are. The flag summary is taken over from
// `other` in both cases.
void global::operation_stack::copy_from(const operation_stack &other) {
  const bool needs_copy = other.any.test(op_info::dynamic);
  if (!needs_copy) {
    Base::operator=(other);
  } else {
    for (size_t k = 0; k < other.size(); ++k) {
      Base::push_back(other[k]->copy());
    }
  }
  this->any = other.any;
}

// Empty tape: no compiled sweep functions, no parent tape, not in use.
global::global()
    : forward_compiled(NULL),
      reverse_compiled(NULL),
      parent_glob(NULL),
      in_use(false) {}

// Member-wise copy of `other` into this tape, followed by a
// forward_synchronize() pass when the copied operation stack has the
// `synchronize_on_copy` flag raised.
void global::copy_from(const global &other) {
  opstack = other.opstack;
  values = other.values;
  derivs = other.derivs;
  inputs = other.inputs;
  inv_index = other.inv_index;
  dep_index = other.dep_index;
  subgraph_ptr = other.subgraph_ptr;
  subgraph_seq = other.subgraph_seq;
  forward_compiled = other.forward_compiled;
  reverse_compiled = other.reverse_compiled;
  parent_glob = other.parent_glob;
  in_use = other.in_use;
  // Runs after all members are in place because it walks opstack, inputs
  // and values.
  if (opstack.any.test(op_info::synchronize_on_copy)) {
    forward_synchronize();
  }
}

// Copy construction delegates to copy_from().
global::global(const global &other) {
  copy_from(other);
}

// Copy assignment delegates to copy_from(); self-assignment is a no-op.
global &global::operator=(const global &other) {
  if (this == &other) return *this;
  copy_from(other);
  return *this;
}

// Empty the tape: all value, derivative, input and index vectors are
// resized to zero and the operation stack is cleared. The compiled function
// pointers, parent_glob and in_use are left as they are.
void global::clear() {
  values.resize(0);
  derivs.resize(0);
  inputs.resize(0);
  inv_index.resize(0);
  dep_index.resize(0);
  subgraph_ptr.resize(0);
  subgraph_seq.resize(0);
  opstack.clear();
}

// Release memory held by the tape.
//
// tol: a vector is reallocated when its size is below `tol` times its
//      capacity.
//
// `derivs` and `subgraph_ptr` are always emptied and their storage dropped.
void global::shrink_to_fit(double tol) {
  // Swapping with an empty temporary drops the storage.
  std::vector<Scalar>().swap(derivs);
  std::vector<IndexPair>().swap(subgraph_ptr);
  // Swapping with a copy replaces the storage by one allocated for the
  // copy.
  if (values.size() < tol * values.capacity())
    std::vector<Scalar>(values).swap(values);
  if (inputs.size() < tol * inputs.capacity())
    std::vector<Index>(inputs).swap(inputs);
  // Only the pointer vector is copied here; the operators themselves are
  // not.
  if (opstack.size() < tol * opstack.capacity())
    std::vector<OperatorPure *>(opstack).swap(opstack);
}

// Give `derivs` one entry per value and zero all entries from
// `start.ptr.second` onwards.
void global::clear_deriv(Position start) {
  derivs.resize(values.size());
  std::vector<Scalar>::iterator first = derivs.begin() + start.ptr.second;
  std::fill(first, derivs.end(), 0);
}

// Value of the i'th independent variable.
Scalar &global::value_inv(Index i) {
  return values[inv_index[i]];
}

// Derivative of the i'th independent variable.
Scalar &global::deriv_inv(Index i) {
  return derivs[inv_index[i]];
}

// Value of the i'th dependent variable.
Scalar &global::value_dep(Index i) {
  return values[dep_index[i]];
}

// Derivative of the i'th dependent variable.
Scalar &global::deriv_dep(Index i) {
  return derivs[dep_index[i]];
}

// Position of the start of the tape: operator 0, input 0, value 0.
Position global::begin() {
  return Position(0, 0, 0);
}

// Position just past the end of the tape, given by the current number of
// operators, inputs and values.
Position global::end() {
  return Position(opstack.size(), inputs.size(), values.size());
}

// Filter that accepts every index: always returns true.
CONSTEXPR bool global::no_filter::operator[](size_t i) const { return true; }

// Forward sweep of the values. When a compiled forward function is
// installed it is called on the value array and `start` is not used;
// otherwise the operators are run from `start`.
void global::forward(Position start) {
  if (forward_compiled == NULL) {
    ForwardArgs<Scalar> args(inputs, values, this);
    args.ptr = start.ptr;
    forward_loop(args, start.node);
    return;
  }
  forward_compiled(values.data());
}

// Reverse sweep of the derivatives. When a compiled reverse function is
// installed it is called on the value and derivative arrays and `start` is
// not used; otherwise reverse_loop is run with `start.node`.
void global::reverse(Position start) {
  if (reverse_compiled == NULL) {
    ReverseArgs<Scalar> args(inputs, values, derivs, this);
    reverse_loop(args, start.node);
    return;
  }
  reverse_compiled(values.data(), derivs.data());
}

void global::forward_sub() {
  ForwardArgs<Scalar> args(inputs, values, this);
  forward_loop_subgraph(args);
}

void global::reverse_sub() {
  ReverseArgs<Scalar> args(inputs, values, derivs, this);
  reverse_loop_subgraph(args);
}

// Walk the tape from the start and call synchronize() on every operator
// with the arguments positioned at that operator.
void global::forward_synchronize() {
  ForwardArgs<Scalar> args(inputs, values, this);
  for (size_t node = 0; node < opstack.size(); ++node) {
    opstack[node]->synchronize(args);
    opstack[node]->increment(args.ptr);
  }
}

void global::forward(std::vector<bool> &marks) {
  intervals<Index> marked_intervals;
  ForwardArgs<bool> args(inputs, marks, marked_intervals);
  forward_loop(args);
}

void global::reverse(std::vector<bool> &marks) {
  intervals<Index> marked_intervals;
  ReverseArgs<bool> args(inputs, marks, marked_intervals);
  reverse_loop(args);
}

void global::forward_sub(std::vector<bool> &marks,
                         const std::vector<bool> &node_filter) {
  intervals<Index> marked_intervals;
  ForwardArgs<bool> args(inputs, marks, marked_intervals);
  if (node_filter.size() == 0)
    forward_loop_subgraph(args);
  else
    forward_loop(args, 0, node_filter);
}

void global::reverse_sub(std::vector<bool> &marks,
                         const std::vector<bool> &node_filter) {
  intervals<Index> marked_intervals;
  ReverseArgs<bool> args(inputs, marks, marked_intervals);
  if (node_filter.size() == 0)
    reverse_loop_subgraph(args);
  else
    reverse_loop(args, 0, node_filter);
}

void global::forward_dense(std::vector<bool> &marks) {
  intervals<Index> marked_intervals;
  ForwardArgs<bool> args(inputs, marks, marked_intervals);
  for (size_t i = 0; i < opstack.size(); i++) {
    opstack[i]->forward_incr_mark_dense(args);
  }
}

// Collect the dependency intervals of all operators that have `flag`
// raised.
//
// reverse: include the intervals reported by dependencies().
// forward: include the intervals reported by dependencies_updating().
intervals<Index> global::get_intervals(op_info::op_flag flag, bool reverse,
                                       bool forward) const {
  intervals<Index> collected;
  Dependencies dep;
  Args<> args(inputs);
  for (size_t node = 0; node < opstack.size(); ++node) {
    if (opstack[node]->info().test(flag)) {
      dep.clear();
      if (reverse) opstack[node]->dependencies(args, dep);
      if (forward) opstack[node]->dependencies_updating(args, dep);

      // Only the interval part of the dependencies is recorded.
      for (size_t k = 0; k < dep.I.size(); ++k) {
        collected.insert(dep.I[k].first, dep.I[k].second);
      }
    }
    // The argument pointer advances for every operator, flagged or not.
    opstack[node]->increment(args.ptr);
  }
  return collected;
}

// Intervals of the operators flagged `reverse_updating`, using the default
// `reverse`/`forward` arguments of get_intervals().
intervals<Index> global::updating_intervals() const {
  const op_info::op_flag flag = op_info::reverse_updating;
  return get_intervals(flag);
}

// Same as get_intervals(), but restricted to the operators listed in
// `subgraph_seq`.
//
// reverse: include the intervals reported by dependencies().
// forward: include the intervals reported by dependencies_updating().
intervals<Index> global::get_intervals_sub(op_info::op_flag flag, bool reverse,
                                           bool forward) const {
  intervals<Index> collected;
  Dependencies dep;
  Args<> args(inputs);
  subgraph_cache_ptr();
  for (size_t pos = 0; pos < subgraph_seq.size(); ++pos) {
    const Index node = subgraph_seq[pos];
    // Jump directly to the operator's cached argument position.
    args.ptr = subgraph_ptr[node];
    if (!opstack[node]->info().test(flag)) continue;

    dep.clear();
    if (reverse) opstack[node]->dependencies(args, dep);
    if (forward) opstack[node]->dependencies_updating(args, dep);

    // Only the interval part of the dependencies is recorded.
    for (size_t k = 0; k < dep.I.size(); ++k) {
      collected.insert(dep.I[k].first, dep.I[k].second);
    }
  }
  return collected;
}

// Subgraph version of updating_intervals(): intervals of the operators
// flagged `reverse_updating`, using the default arguments of
// get_intervals_sub().
intervals<Index> global::updating_intervals_sub() const {
  const op_info::op_flag flag = op_info::reverse_updating;
  return get_intervals_sub(flag);
}

// Replay value of the i'th independent variable of the original tape.
Replay &global::replay::value_inv(Index i) {
  return values[orig.inv_index[i]];
}

// Replay derivative of the i'th independent variable of the original tape.
Replay &global::replay::deriv_inv(Index i) {
  return derivs[orig.inv_index[i]];
}

// Replay value of the i'th dependent variable of the original tape.
Replay &global::replay::value_dep(Index i) {
  return values[orig.dep_index[i]];
}

// Replay derivative of the i'th dependent variable of the original tape.
Replay &global::replay::deriv_dep(Index i) {
  return derivs[orig.dep_index[i]];
}

// Set up a replay of the tape `orig` onto the tape `target`. The two must be
// different objects.
global::replay::replay(const global &orig, global &target)
    : orig(orig), target(target) {
  TMBAD_ASSERT(&orig != &target);
}

// Begin a replay session: remember the currently active tape, activate the
// target tape unless it already is the active one, and initialise the replay
// values from the values of the original tape.
void global::replay::start() {
  parent_glob = get_glob();
  const bool need_activation = (&target != parent_glob);
  if (need_activation) {
    target.ad_start();
  }
  values.assign(orig.values.begin(), orig.values.end());
}

// End a replay session: deactivate the target tape if start() activated it
// and check that the tape recorded by start() is the active one again.
void global::replay::stop() {
  const bool was_activated = (&target != parent_glob);
  if (was_activated) {
    target.ad_stop();
  }
  TMBAD_ASSERT(parent_glob == get_glob());
}

void global::replay::add_updatable_derivs(const intervals<Index> &I) {
  struct {
    Replay *p;
    void operator()(Index a, Index b) {
      bool all_updatable = true;
      for (size_t i = a; i <= b; i++)
        all_updatable = all_updatable && p[i].updatable();
      Index n = b - a + 1;
      if (!all_updatable) {
        global::AllocOp Z(n);
        Z(p + a, n);
      } else {
        global::ZeroOp Z(n);
        Z(p + a, n);
      }
    }
  } F = {derivs.data()};
  I.apply(F);
}

// Give 'derivs' the same length as 'values' and set every entry to the
// constant Replay(0).
void global::replay::clear_deriv() {
  derivs.assign(values.size(), Replay(0));
}

void global::replay::forward(bool inv_tags, bool dep_tags, Position start,
                             const std::vector<bool> &node_filter) {
  TMBAD_ASSERT(&target == get_glob());
  if (inv_tags) {
    for (size_t i = 0; i < orig.inv_index.size(); i++)
      value_inv(i).Independent();
  }
  ForwardArgs<Replay> args(orig.inputs, values);
  if (node_filter.size() > 0) {
    TMBAD_ASSERT(node_filter.size() == orig.opstack.size());
    orig.forward_loop(args, start.node, node_filter);
  } else {
    orig.forward_loop(args, start.node);
  }
  if (dep_tags) {
    for (size_t i = 0; i < orig.dep_index.size(); i++) value_dep(i).Dependent();
  }
}

// Reverse replay of the original tape onto the (active) target tape.
// Note that the flag order here is (dep_tags, inv_tags), i.e. the opposite of
// forward().
//   dep_tags    - call Dependent() on the derivative of each original
//                 independent variable after the sweep
//   inv_tags    - call Independent() on the derivative of each original
//                 dependent variable before the sweep
//   start       - the sweep stops at operator start.node; derivatives of the
//                 variables before start.ptr.second are reset to Replay(0)
//   node_filter - optional; if non-empty it must have one entry per operator
//                 and is forwarded to reverse_loop()
void global::replay::reverse(bool dep_tags, bool inv_tags, Position start,
                             const std::vector<bool> &node_filter) {
  TMBAD_ASSERT(&target == get_glob());
  if (inv_tags) {
    for (size_t k = 0; k < orig.dep_index.size(); k++) {
      deriv_dep(k).Independent();
    }
  }

  // Operators that update their derivative workspace in place need that
  // workspace prepared up front.
  const bool has_updating = orig.opstack.any.test(op_info::reverse_updating);
  if (has_updating) {
    intervals<Index> I = orig.updating_intervals();
    add_updatable_derivs(I);
  }

  ReverseArgs<Replay> args(orig.inputs, values, derivs);
  if (node_filter.size() == 0) {
    orig.reverse_loop(args, start.node);
  } else {
    TMBAD_ASSERT(node_filter.size() == orig.opstack.size());
    orig.reverse_loop(args, start.node, node_filter);
  }

  std::fill(derivs.begin(), derivs.begin() + start.ptr.second, Replay(0));
  if (dep_tags) {
    for (size_t k = 0; k < orig.inv_index.size(); k++) {
      deriv_inv(k).Dependent();
    }
  }
}

// Forward replay restricted to the subgraph of the original tape (delegates
// to orig.forward_loop_subgraph()).
void global::replay::forward_sub() {
  ForwardArgs<Replay> args(orig.inputs, values);
  orig.forward_loop_subgraph(args);
}

// Reverse replay restricted to the subgraph of the original tape. If any
// operator on the original stack is flagged reverse_updating, the derivative
// workspace of the subgraph's updating operators is prepared first.
void global::replay::reverse_sub() {
  const bool has_updating = orig.opstack.any.test(op_info::reverse_updating);
  if (has_updating) {
    intervals<Index> I = orig.updating_intervals_sub();
    add_updatable_derivs(I);
  }
  ReverseArgs<Replay> args(orig.inputs, values, derivs);
  orig.reverse_loop_subgraph(args);
}

// Clear the replay derivatives through orig.clear_array_subgraph().
void global::replay::clear_deriv_sub() { orig.clear_array_subgraph(derivs); }

void global::forward_replay(bool inv_tags, bool dep_tags) {
  global new_glob;
  global::replay replay(*this, new_glob);
  replay.start();
  replay.forward(inv_tags, dep_tags);
  replay.stop();
  *this = new_glob;
}

void global::subgraph_cache_ptr() const {
  if (subgraph_ptr.size() == opstack.size()) return;
  TMBAD_ASSERT(subgraph_ptr.size() < opstack.size());
  if (subgraph_ptr.size() == 0) subgraph_ptr.push_back(IndexPair(0, 0));
  for (size_t i = subgraph_ptr.size(); i < opstack.size(); i++) {
    IndexPair ptr = subgraph_ptr[i - 1];
    opstack[i - 1]->increment(ptr);
    subgraph_ptr.push_back(ptr);
  }
}

void global::set_subgraph(const std::vector<bool> &marks, bool append) {
  std::vector<Index> v2o = var2op();
  if (!append) subgraph_seq.resize(0);
  Index previous = (Index)-1;
  for (size_t i = 0; i < marks.size(); i++) {
    if (marks[i] && (v2o[i] != previous)) {
      subgraph_seq.push_back(v2o[i]);
      previous = v2o[i];
    }
  }
}

// Pass 'marks' to clear_array_subgraph() with fill value true.
// 'marks' must have one entry per tape variable.
void global::mark_subgraph(std::vector<bool> &marks) {
  TMBAD_ASSERT(marks.size() == values.size());
  clear_array_subgraph(marks, true);
}

// Pass 'marks' to clear_array_subgraph() with fill value false.
// 'marks' must have one entry per tape variable.
void global::unmark_subgraph(std::vector<bool> &marks) {
  TMBAD_ASSERT(marks.size() == values.size());
  clear_array_subgraph(marks, false);
}

void global::subgraph_trivial() {
  subgraph_cache_ptr();
  subgraph_seq.resize(0);
  for (size_t i = 0; i < opstack.size(); i++) subgraph_seq.push_back(i);
}

// Clear this tape's derivatives through clear_array_subgraph().
void global::clear_deriv_sub() { clear_array_subgraph(derivs); }

// Copy the operators listed in 'subgraph_seq' onto 'new_glob' and return it.
//   var_remap - in/out: must be empty or of length values.size(); on return
//               var_remap[old] holds the new index of each copied variable
//   new_glob  - tape to append to (taken by value)
// Independent / dependent variables of this tape are carried over only if
// they are outputs of a copied operator.
global global::extract_sub(std::vector<Index> &var_remap, global new_glob) {
  subgraph_cache_ptr();
  TMBAD_ASSERT(var_remap.size() == 0 || var_remap.size() == values.size());
  var_remap.resize(values.size(), 0);
  std::vector<bool> inv_flag = inv_marks();
  std::vector<bool> dep_flag = dep_marks();
  ForwardArgs<Scalar> args(inputs, values, this);
  for (size_t j = 0; j < subgraph_seq.size(); j++) {
    const Index node = subgraph_seq[j];
    args.ptr = subgraph_ptr[node];

    // Outputs: append values and record old -> new variable index.
    const size_t nout = opstack[node]->output_size();
    for (size_t k = 0; k < nout; k++) {
      const Index old_index = args.output(k);
      const Index new_index = new_glob.values.size();
      var_remap[old_index] = new_index;
      new_glob.values.push_back(args.y(k));
      // Clearing here and flipping below leaves 'true' exactly for the
      // tagged variables that were copied.
      inv_flag[old_index] = false;
      dep_flag[old_index] = false;
    }

    // Inputs: translate through the remap built so far.
    const size_t nin = opstack[node]->input_size();
    for (size_t k = 0; k < nin; k++) {
      new_glob.inputs.push_back(var_remap[args.input(k)]);
    }

    new_glob.opstack.push_back(opstack[node]->copy());
  }

  inv_flag.flip();
  dep_flag.flip();

  for (size_t k = 0; k < inv_index.size(); k++) {
    const Index old_var = inv_index[k];
    if (inv_flag[old_var]) new_glob.inv_index.push_back(var_remap[old_var]);
  }
  for (size_t k = 0; k < dep_index.size(); k++) {
    const Index old_var = dep_index[k];
    if (dep_flag[old_var]) new_glob.dep_index.push_back(var_remap[old_var]);
  }
  return new_glob;
}

// Compact this tape in place, keeping only the operators that
//   - are flagged elimination_protected, or
//   - have at least one marked output, or
//   - are flagged forward_updating and update a marked variable.
// 'marks' must have one entry per tape variable. Values, inputs, the operator
// stack and inv_index / dep_index are all rewritten; dropped operators are
// deallocated.
void global::extract_sub_inplace(std::vector<bool> marks) {
  TMBAD_ASSERT(marks.size() == values.size());
  std::vector<Index> var_remap(values.size(), 0);
  std::vector<bool> independent_variable = inv_marks();
  std::vector<bool> dependent_variable = dep_marks();
  intervals<Index> marked_intervals;
  ForwardArgs<bool> args(inputs, marks, marked_intervals);
  // s / s_input: write positions in 'values' / 'inputs' for the kept entries.
  size_t s = 0, s_input = 0;
  std::vector<bool> opstack_deallocate(opstack.size(), false);

  for (size_t i = 0; i < opstack.size(); i++) {
    op_info info = opstack[i]->info();

    size_t nout = opstack[i]->output_size();
    bool any_marked_output = info.test(op_info::elimination_protected);
    for (size_t j = 0; j < nout; j++) {
      any_marked_output |= args.y(j);
    }
    if (info.test(op_info::forward_updating)) {
      Dependencies dep;
      opstack[i]->dependencies_updating(args, dep);
      any_marked_output |= dep.any(args.values);
    }

    if (any_marked_output) {
      // Move the operator's outputs down to their new positions and record
      // the old -> new variable index.
      for (size_t k = 0; k < nout; k++) {
        Index new_index = s;
        Index old_index = args.output(k);
        var_remap[old_index] = new_index;
        values[new_index] = values[old_index];
        if (independent_variable[old_index]) {
          independent_variable[old_index] = false;
        }
        if (dependent_variable[old_index]) {
          dependent_variable[old_index] = false;
        }
        s++;
      }

      // Rewrite the operator's inputs through the remap.
      size_t nin = opstack[i]->input_size();
      for (size_t k = 0; k < nin; k++) {
        inputs[s_input] = var_remap[args.input(k)];
        s_input++;
      }
    }
    opstack[i]->increment(args.ptr);
    if (!any_marked_output) {
      opstack_deallocate[i] = true;
    }
  }

  // After the flip, 'true' identifies tagged variables that were kept.
  independent_variable.flip();
  dependent_variable.flip();
  std::vector<Index> new_inv_index;
  for (size_t i = 0; i < inv_index.size(); i++) {
    Index old_var = inv_index[i];
    if (independent_variable[old_var])
      new_inv_index.push_back(var_remap[old_var]);
  }
  inv_index = new_inv_index;
  std::vector<Index> new_dep_index;
  for (size_t i = 0; i < dep_index.size(); i++) {
    Index old_var = dep_index[i];
    if (dependent_variable[old_var])
      new_dep_index.push_back(var_remap[old_var]);
  }
  dep_index = new_dep_index;

  inputs.resize(s_input);
  values.resize(s);
  // Compact the operator stack, deallocating the dropped operators.
  size_t k = 0;
  for (size_t i = 0; i < opstack.size(); i++) {
    if (opstack_deallocate[i]) {
      opstack[i]->deallocate();
    } else {
      opstack[k] = opstack[i];
      k++;
    }
  }
  opstack.resize(k);

  // Re-run the forward sweep if any remaining operator is flagged 'dynamic'.
  if (opstack.any.test(op_info::dynamic)) this->forward();
}

// Convenience overload of extract_sub() that discards the variable remap.
global global::extract_sub() {
  std::vector<Index> unused_remap;
  global ans = extract_sub(unused_remap);
  return ans;
}

std::vector<Index> global::var2op() {
  std::vector<Index> var2op(values.size());
  Args<> args(inputs);
  size_t j = 0;
  for (size_t i = 0; i < opstack.size(); i++) {
    opstack[i]->increment(args.ptr);
    for (; j < (size_t)args.ptr.second; j++) {
      var2op[j] = i;
    }
  }
  return var2op;
}

std::vector<bool> global::var2op(const std::vector<bool> &values,
                                 bool include_updating) {
  bool any_upd =
      include_updating && opstack.any.test(op_info::forward_updating);
  std::vector<bool> ans(opstack.size(), false);
  Args<> args(inputs);
  size_t j = 0;
  for (size_t i = 0; i < opstack.size(); i++) {
    if (any_upd && opstack[i]->info().test(op_info::forward_updating)) {
      Dependencies dep;
      opstack[i]->dependencies_updating(args, dep);
      ans[i] = ans[i] || dep.any(values);
    }

    opstack[i]->increment(args.ptr);
    for (; j < (size_t)args.ptr.second; j++) {
      ans[i] = ans[i] || values[j];
    }
  }
  return ans;
}

std::vector<Index> global::op2var(const std::vector<Index> &seq) {
  std::vector<bool> seq_mark = mark_space(opstack.size(), seq);
  std::vector<Index> ans;
  Args<> args(inputs);
  size_t j = 0;
  for (size_t i = 0; i < opstack.size(); i++) {
    opstack[i]->increment(args.ptr);
    for (; j < (size_t)args.ptr.second; j++) {
      if (seq_mark[i]) ans.push_back(j);
    }
  }
  return ans;
}

std::vector<bool> global::op2var(const std::vector<bool> &seq_mark) {
  std::vector<bool> ans(values.size());
  Args<> args(inputs);
  size_t j = 0;
  for (size_t i = 0; i < opstack.size(); i++) {
    opstack[i]->increment(args.ptr);
    for (; j < (size_t)args.ptr.second; j++) {
      if (seq_mark[i]) ans[j] = true;
    }
  }
  return ans;
}

std::vector<Index> global::op2idx(const std::vector<Index> &var_subset,
                                  Index NA) {
  std::vector<Index> v2o = var2op();
  std::vector<Index> op2idx(opstack.size(), NA);
  for (size_t i = var_subset.size(); i > 0;) {
    i--;
    op2idx[v2o[var_subset[i]]] = i;
  }
  return op2idx;
}

// Boolean vector of length 'n' that is true exactly at the positions listed
// in 'ind'.
std::vector<bool> global::mark_space(size_t n, const std::vector<Index> ind) {
  std::vector<bool> flags(n, false);
  const size_t count = ind.size();
  for (size_t k = 0; k < count; k++) {
    flags[ind[k]] = true;
  }
  return flags;
}

std::vector<bool> global::inv_marks() {
  return mark_space(values.size(), inv_index);
}

std::vector<bool> global::dep_marks() {
  return mark_space(values.size(), dep_index);
}

std::vector<bool> global::subgraph_marks() {
  return mark_space(opstack.size(), subgraph_seq);
}

// Functor that appends dependency edges to 'edges'.
//   i         - current operator index, held by reference so the caller's
//               loop counter is seen by operator()
//   num_nodes - number of operators (length of the internal mark vector)
//   keep_var  - only variables flagged here generate edges
//   var2op    - maps a variable to the operator producing it
//   edges     - output edge list
global::append_edges::append_edges(size_t &i, size_t num_nodes,
                                   const std::vector<bool> &keep_var,
                                   std::vector<Index> &var2op,
                                   std::vector<IndexPair> &edges)
    : i(i),
      keep_var(keep_var),
      var2op(var2op),
      edges(edges),
      op_marks(num_nodes, false),
      pos(0) {}

// Visit one dependency variable 'dep_j' of the current operator 'i': if the
// variable is kept and produced by a different operator k that has not yet
// been recorded in this iteration, append the edge (k, i).
void global::append_edges::operator()(Index dep_j) {
  if (!keep_var[dep_j]) return;
  const size_t k = var2op[dep_j];
  if (i == k || op_marks[k]) return;

  IndexPair edge;
  edge.first = k;
  edge.second = i;
  edges.push_back(edge);
  // Avoid duplicate (k, i) edges; reset again by end_iteration().
  op_marks[k] = true;
}

// Remember the current edge count so end_iteration() can find the edges
// added while visiting one operator.
void global::append_edges::start_iteration() { pos = edges.size(); }

// Clear the operator marks set for the edges added since start_iteration().
void global::append_edges::end_iteration() {
  for (size_t j = pos; j < edges.size(); j++) {
    op_marks[edges[j].first] = false;
  }
}

// Build the operator dependency graph of the tape.
//   transpose - swap the direction of every edge
//   keep_var  - one flag per tape variable; only flagged variables generate
//               edges
//   deriv     - skip operators flagged is_zero_deriv
// Edges from dependencies() are stored as (producer, consumer). If any
// operator is flagged forward_updating, a second pass adds the edges from
// dependencies_updating() with the direction swapped.
graph global::build_graph(bool transpose, const std::vector<bool> &keep_var,
                          bool deriv) {
  TMBAD_ASSERT(keep_var.size() == values.size());

  std::vector<Index> var2op = this->var2op();

  bool any_updating = false;

  Args<> args(inputs);
  std::vector<IndexPair> edges;
  Dependencies dep;
  // NOTE: 'F' holds a reference to 'i', so the loops below must keep using
  // this very variable as the operator counter.
  size_t i = 0;
  append_edges F(i, opstack.size(), keep_var, var2op, edges);
  for (; i < opstack.size(); i++) {
    op_info ifo = opstack[i]->info();
    bool skip_node = deriv && ifo.test(op_info::is_zero_deriv);
    any_updating |= ifo.test(op_info::forward_updating);
    if (!skip_node) {
      dep.clear();
      opstack[i]->dependencies(args, dep);
      F.start_iteration();
      dep.apply(F);
      F.end_iteration();
    }
    opstack[i]->increment(args.ptr);
  }
  if (any_updating) {
    // Second pass: edges induced by in-place updates.
    size_t begin = edges.size();
    i = 0;
    args = Args<>(inputs);
    for (; i < opstack.size(); i++) {
      dep.clear();
      opstack[i]->dependencies_updating(args, dep);
      F.start_iteration();
      dep.apply(F);
      F.end_iteration();
      opstack[i]->increment(args.ptr);
    }
    // Reverse the direction of the edges added by the second pass.
    for (size_t j = begin; j < edges.size(); j++)
      std::swap(edges[j].first, edges[j].second);
  }

  if (transpose) {
    for (size_t j = 0; j < edges.size(); j++)
      std::swap(edges[j].first, edges[j].second);
  }

  graph G(opstack.size(), edges);

  // Record which operators produce the independent / dependent variables.
  for (size_t i = 0; i < inv_index.size(); i++)
    G.inv2op.push_back(var2op[inv_index[i]]);
  for (size_t i = 0; i < dep_index.size(); i++)
    G.dep2op.push_back(var2op[dep_index[i]]);
  return G;
}

// Operator graph with untransposed edges (see build_graph()). An empty
// 'keep_var' means "keep all variables".
graph global::forward_graph(std::vector<bool> keep_var, bool deriv) {
  if (keep_var.size() == 0) keep_var.assign(values.size(), true);
  TMBAD_ASSERT(values.size() == keep_var.size());
  const bool transpose = false;
  return build_graph(transpose, keep_var, deriv);
}

// Operator graph with transposed edges (see build_graph()). An empty
// 'keep_var' means "keep all variables".
graph global::reverse_graph(std::vector<bool> keep_var, bool deriv) {
  if (keep_var.size() == 0) keep_var.assign(values.size(), true);
  TMBAD_ASSERT(values.size() == keep_var.size());
  const bool transpose = true;
  return build_graph(transpose, keep_var, deriv);
}

bool global::identical(const global &other) const {
  if (inv_index != other.inv_index) return false;
  ;
  if (dep_index != other.dep_index) return false;
  ;
  if (opstack.size() != other.opstack.size()) return false;
  ;
  for (size_t i = 0; i < opstack.size(); i++) {
    if (opstack[i]->identifier() != other.opstack[i]->identifier())
      return false;
    ;
  }
  if (inputs != other.inputs) return false;
  ;
  if (values.size() != other.values.size()) return false;
  ;
  OperatorPure *constant = getOperator<ConstOp>();
  IndexPair ptr(0, 0);
  for (size_t i = 0; i < opstack.size(); i++) {
    if (opstack[i] == constant) {
      if (values[ptr.second] != other.values[ptr.second]) return false;
      ;
    }
    opstack[i]->increment(ptr);
  }

  return true;
}

// Hash of the tape structure. Folded in, in this order: inv_index,
// dep_index, the operator stack entries, inputs (each preceded by its
// length), the number of values, and the value of every ConstOp output.
hash_t global::hash() const {
  hash_t h = 37;

  hash(h, inv_index.size());
  for (size_t k = 0; k < inv_index.size(); k++) {
    hash(h, inv_index[k]);
  }

  hash(h, dep_index.size());
  for (size_t k = 0; k < dep_index.size(); k++) {
    hash(h, dep_index[k]);
  }

  hash(h, opstack.size());
  for (size_t k = 0; k < opstack.size(); k++) {
    hash(h, opstack[k]);
  }

  hash(h, inputs.size());
  for (size_t k = 0; k < inputs.size(); k++) {
    hash(h, inputs[k]);
  }

  hash(h, values.size());

  // Only constants contribute their value.
  OperatorPure *constant = getOperator<ConstOp>();
  IndexPair ptr(0, 0);
  for (size_t k = 0; k < opstack.size(); k++) {
    if (opstack[k] == constant) {
      hash(h, values[ptr.second]);
    }
    opstack[k]->increment(ptr);
  }

  return h;
}

// Forward sweep that assigns a hash code to every tape variable, combining
// the hashes of an operator's dependencies with a code for the operator.
// Options read from 'cfg':
//   deterministic - use first-occurrence ranks instead of identifier()
//   strong_inv    - give each independent variable a distinct start hash
//                   (from cfg.inv_seed if supplied, otherwise its position)
//   strong_const  - fold the value (and its sign test) of constants in
//   strong_output - distinguish the outputs of a multi-output operator
//   reduce        - return only the hashes of the dependent variables
std::vector<hash_t> global::hash_sweep(hash_config cfg) const {
  std::vector<Index> opstack_id;
  if (cfg.deterministic) {
    std::vector<size_t> ids(opstack.size());
    for (size_t k = 0; k < ids.size(); k++) {
      ids[k] = (size_t)opstack[k]->identifier();
    }
    opstack_id = radix::first_occurance<Index>(ids);
    const hash_t spread = (hash_t(1) << (sizeof(hash_t) * 4)) - 1;
    for (size_t k = 0; k < opstack_id.size(); k++) {
      opstack_id[k] = (opstack_id[k] + 1) * spread;
    }
  }

  std::vector<hash_t> hash_vec(values.size(), 37);
  Dependencies dep;
  OperatorPure *inv = getOperator<InvOp>();
  OperatorPure *constant = getOperator<ConstOp>();

  if (cfg.strong_inv) {
    const bool have_inv_seed = (cfg.inv_seed.size() > 0);
    if (have_inv_seed) {
      TMBAD_ASSERT(cfg.inv_seed.size() == inv_index.size());
    }
    for (size_t i = 0; i < inv_index.size(); i++) {
      hash_vec[inv_index[i]] += (have_inv_seed ? cfg.inv_seed[i] + 1 : (i + 1));
    }
  }

  Args<> args(inputs);
  IndexPair &ptr = args.ptr;
  for (size_t i = 0; i < opstack.size(); i++) {
    // Independent variables keep their start hash.
    if (opstack[i] != inv) {
      dep.clear();
      opstack[i]->dependencies(args, dep);

      // Seed with the first dependency's hash, then fold in the rest.
      hash_t h = 37;
      const size_t ndep = dep.size();
      if (ndep > 0) h = hash_vec[dep[0]];
      for (size_t j = 1; j < ndep; j++) {
        hash(h, hash_vec[dep[j]]);
      }

      if (cfg.deterministic) {
        hash(h, opstack_id[i]);
      } else {
        hash(h, opstack[i]->identifier());
      }

      if (opstack[i] == constant && cfg.strong_const) {
        hash(h, values[ptr.second]);
        hash(h, values[ptr.second] > 0);
      }

      const size_t noutput = opstack[i]->output_size();
      for (size_t j = 0; j < noutput; j++) {
        hash_vec[ptr.second + j] = h + j * cfg.strong_output;
      }
    }
    opstack[i]->increment(ptr);
  }

  if (!cfg.reduce) return hash_vec;
  std::vector<hash_t> ans(dep_index.size());
  for (size_t j = 0; j < dep_index.size(); j++) {
    ans[j] = hash_vec[dep_index[j]];
  }
  return ans;
}

// Hash sweep with a preset configuration. 'weak' turns off strong_inv and
// turns on reduce (only dependent-variable hashes are returned); the other
// options are fixed as set below.
std::vector<hash_t> global::hash_sweep(bool weak) const {
  hash_config cfg;
  // Options controlled by 'weak'.
  cfg.strong_inv = !weak;
  cfg.reduce = weak;
  // Fixed options.
  cfg.strong_const = true;
  cfg.strong_output = true;
  cfg.deterministic = TMBAD_DETERMINISTIC_HASH;
  return hash_sweep(cfg);
}

// Remove operators that do not contribute to the dependent variables.
// Independent and dependent variables are marked, the marks are propagated
// with reverse(marks), and the tape is compacted in place to the operators
// kept by extract_sub_inplace().
// (A disabled `if (false)` alternative based on set_subgraph()/extract_sub()
// was unreachable and has been removed.)
void global::eliminate() {
  this->shrink_to_fit();

  std::vector<bool> marks(values.size(), false);
  for (size_t i = 0; i < inv_index.size(); i++) marks[inv_index[i]] = true;
  for (size_t i = 0; i < dep_index.size(); i++) marks[dep_index[i]] = true;

  reverse(marks);

  this->extract_sub_inplace(marks);
  this->shrink_to_fit();
}

// Default print configuration: empty prefix, "*" as mark, depth 0.
global::print_config::print_config() : prefix(""), mark("*"), depth(0) {}

// Print the tape as a table on Rcout: one row per output variable (or one
// row for an operator without outputs) showing operator name, node number,
// value, derivative ("NA" if derivs does not match values in size), variable
// index and, on the operator's first row, its input indices.
// If a subgraph is set, operators belonging to it are tagged with cfg.mark.
// With cfg.depth > 0 each operator's own print() is called with depth
// reduced by one and the prefix extended by "##".
void global::print(print_config cfg) {
  using std::endl;
  using std::left;
  using std::setw;
  IndexPair ptr(0, 0);
  std::vector<bool> sgm = subgraph_marks();
  bool have_subgraph = (subgraph_seq.size() > 0);
  // Running index of the variable being printed.
  int v = 0;
  print_config cfg2 = cfg;
  cfg2.depth--;
  cfg2.prefix = cfg.prefix + "##";
  // Header row.
  Rcout << cfg.prefix;
  Rcout << setw(7) << "OpName:" << setw(7 + have_subgraph)
        << "Node:" << setw(13) << "Value:" << setw(13) << "Deriv:" << setw(13)
        << "Index:";
  Rcout << "    " << "Inputs:";
  Rcout << endl;
  for (size_t i = 0; i < opstack.size(); i++) {
    Rcout << cfg.prefix;
    Rcout << setw(7) << opstack[i]->op_name();
    if (have_subgraph) {
      if (sgm[i])
        Rcout << cfg.mark;
      else
        Rcout << " ";
    }
    Rcout << setw(7) << i;
    int numvar = opstack[i]->output_size();
    // Operators without outputs still get one (mostly blank) row.
    for (int j = 0; j < numvar + (numvar == 0); j++) {
      if (j > 0) Rcout << cfg.prefix;
      // Continuation rows skip the OpName and Node columns.
      Rcout << setw((7 + 7) * (j > 0) + 13);
      if (numvar > 0)
        Rcout << values[v];
      else
        Rcout << "";
      Rcout << setw(13);
      if (numvar > 0) {
        if (derivs.size() == values.size())
          Rcout << derivs[v];
        else
          Rcout << "NA";
      } else {
        Rcout << "";
      }
      Rcout << setw(13);
      if (numvar > 0) {
        Rcout << v;
      } else {
        Rcout << "";
      }
      if (j == 0) {
        // The number of inputs is the advance of ptr.first.
        IndexPair ptr_old = ptr;
        opstack[i]->increment(ptr);
        int ninput = ptr.first - ptr_old.first;
        for (int k = 0; k < ninput; k++) {
          if (k == 0) Rcout << "   ";
          Rcout << " " << inputs[ptr_old.first + k];
        }
      }
      Rcout << endl;
      if (numvar > 0) {
        v++;
      }
    }
    if (cfg.depth > 0) opstack[i]->print(cfg2);
  }
}

void global::print() { this->print(print_config()); }

// Operator base whose input and output counts are fixed at construction.
global::DynamicInputOutputOperator::DynamicInputOutputOperator(Index ninput,
                                                               Index noutput)
    : ninput_(ninput), noutput_(noutput) {}

// Number of inputs given at construction.
Index global::DynamicInputOutputOperator::input_size() const {
  return ninput_;
}

// Number of outputs given at construction.
Index global::DynamicInputOutputOperator::output_size() const {
  return noutput_;
}

// Operator name as shown by global::print().
const char *global::InvOp::op_name() { return "InvOp"; }

// Operator name as shown by global::print().
const char *global::DepOp::op_name() { return "DepOp"; }

// Replay of a constant: make sure the output is placed on the active tape.
void global::ConstOp::forward(ForwardArgs<Replay> &args) {
  Replay &y = args.y(0);
  y.addToTape();
}

// Operator name as shown by global::print().
const char *global::ConstOp::op_name() { return "ConstOp"; }

// Source-code writer: when args.const_literals is set, assign the constant
// expression y_const(0) to the output; otherwise emit nothing.
void global::ConstOp::forward(ForwardArgs<Writer> &args) {
  if (!args.const_literals) return;
  args.y(0) = args.y_const(0);
}

// Operator with 'n' outputs.
global::DataOp::DataOp(Index n) { Base::noutput = n; }

// Operator name as shown by global::print().
const char *global::DataOp::op_name() { return "DataOp"; }

// Source-code generation is not supported for DataOp: always asserts.
void global::DataOp::forward(ForwardArgs<Writer> &args) { TMBAD_ASSERT(false); }

// Operator with 'n' outputs.
global::AllocOp::AllocOp(Index n) { Base::noutput = n; }

// Numeric forward pass: set all outputs to zero.
void global::AllocOp::forward(ForwardArgs<Scalar> &args) {
  Scalar *first = args.y_ptr(0);
  Scalar *last = first + Base::noutput;
  std::fill(first, last, Scalar(0));
}

// Replay: copy the operator to the target tape and flag each of its outputs
// as updatable.
void global::AllocOp::forward(ForwardArgs<Replay> &args) {
  Complete<AllocOp>(Base::noutput).forward_replay_copy(args);
  for (Index k = 0; k < Base::noutput; k++) {
    Replay &y = args.y(k);
    y.setUpdatable(true);
  }
}

// Operator name as shown by global::print().
const char *global::AllocOp::op_name() { return "AllocOp"; }

// Source-code generation is not supported for AllocOp: always asserts.
void global::AllocOp::forward(ForwardArgs<Writer> &args) {
  TMBAD_ASSERT(false);
}

// Record an AllocOp with 'n' outputs on the active tape and overwrite
// x[0..n-1] with those outputs, each flagged updatable.
void global::AllocOp::operator()(Replay *x, Index n) {
  Complete<AllocOp> Z(n);
  ad_segment y = Z(ad_segment());
  for (size_t k = 0; k < n; k++) {
    Replay &xk = x[k];
    xk = y[k];
    xk.setUpdatable(true);
  }
}

// Operator acting on a segment of length 'n'.
global::ZeroOp::ZeroOp(Index n) : n(n) {}

// Numeric forward pass: overwrite the n values starting at the input with
// zero.
void global::ZeroOp::forward(ForwardArgs<Scalar> &args) {
  Scalar *first = args.x_ptr(0);
  Scalar *last = first + n;
  std::fill(first, last, Scalar(0));
}

// Numeric reverse pass: overwrite the n derivatives starting at the input
// with zero.
void global::ZeroOp::reverse(ReverseArgs<Scalar> &args) {
  Scalar *first = args.dx_ptr(0);
  Scalar *last = first + n;
  std::fill(first, last, Scalar(0));
}

// Operator name as shown by global::print().
const char *global::ZeroOp::op_name() { return "ZeroOp"; }

// ZeroOp reports no ordinary dependencies (see dependencies_updating()).
void global::ZeroOp::dependencies(Args<> &args, Dependencies &dep) const {}

// The updated variables are the n consecutive variables starting at the
// operator's input.
void global::ZeroOp::dependencies_updating(Args<> &args,
                                           Dependencies &dep) const {
  const Index first = args.input(0);
  dep.add_segment(first, n);
}

// Record a ZeroOp on the active tape that acts on x[0..n-1]. Requires n > 0,
// every x[i] updatable, and the x[i] to have consecutive indices.
void global::ZeroOp::operator()(Replay *x, Index n) {
  TMBAD_ASSERT2(n > 0, "'ZeroOp' requires non-zero length");

  bool all_updatable = true;
  for (size_t i = 0; i < n; i++) {
    if (!x[i].updatable()) {
      all_updatable = false;
      break;
    }
  }
  TMBAD_ASSERT2(all_updatable, "'ZeroOp' requires updatable workspace");

  bool consecutive = true;
  for (size_t i = 1; i < n; i++) {
    if (!(x[i].index() - x[i - 1].index() == 1)) {
      consecutive = false;
      break;
    }
  }
  TMBAD_ASSERT2(consecutive, "'ZeroOp' requires consecutive workspace");

  // Only the first variable is passed as input; the length is kept in Z.
  Complete<ZeroOp> Z(n);
  Z(std::vector<ad_plain>(1, x[0]));
}

// Nothing to initialise.
global::NullOp::NullOp() {}

// Operator name as shown by global::print().
const char *global::NullOp::op_name() { return "NullOp"; }

// Operator with the given numbers of inputs and outputs.
global::NullOp2::NullOp2(Index ninput, Index noutput)
    : global::DynamicInputOutputOperator(ninput, noutput) {}

// Operator name as shown by global::print().
const char *global::NullOp2::op_name() { return "NullOp2"; }

// Reference to variable 'i' of tape 'glob'.
global::RefOp::RefOp(global *glob, Index i) : glob(glob), i(i) {}

// Numeric forward pass: the output is the current value of the referenced
// variable.
void global::RefOp::forward(ForwardArgs<Scalar> &args) {
  const Scalar &referenced = glob->values[i];
  args.y(0) = referenced;
}

// Replay: if the active tape is the referenced tape, the reference resolves
// to variable 'i' itself; otherwise a new RefOp to the same variable is
// recorded on the active tape.
void global::RefOp::forward(ForwardArgs<Replay> &args) {
  if (get_glob() != this->glob) {
    global::OperatorPure *pOp =
        get_glob()->getOperator<RefOp>(this->glob, this->i);
    args.y(0) =
        get_glob()->add_to_stack<RefOp>(pOp, std::vector<ad_plain>(0))[0];
    return;
  }
  ad_plain resolved;
  resolved.index = i;
  args.y(0) = resolved;
}

// Reverse replay: only acts when the active tape is the referenced tape.
// NOTE(review): the addition is applied to a temporary copy of args.dx(0),
// so the sum is not stored back into args.dx(0) here; whether this is
// intended (e.g. only for its effect on the tape) cannot be told from this
// file - confirm before changing.
void global::RefOp::reverse(ReverseArgs<Replay> &args) {
  if (get_glob() == this->glob) {
    Replay(args.dx(0)) += args.dy(0);
  }
}

void *global::RefOp::custom_identifier() { return &(glob->values[i]); }

// Operator name as shown by global::print().
const char *global::RefOp::op_name() { return "RefOp"; }

// Try to fuse two operators into one: self_fuse() when both are the same
// operator object, other_fuse() otherwise. add_to_opstack() treats a NULL
// result as "cannot be fused".
OperatorPure *global::Fuse(OperatorPure *Op1, OperatorPure *Op2) {
  if (Op1 != Op2) {
    return Op1->other_fuse(Op2);
  }
  return Op1->self_fuse();
}

// Enable or disable operator fusion in add_to_opstack().
void global::set_fuse(bool flag) { fuse = flag; }

// Push an operator onto the operator stack. With fusion enabled, the new
// operator is repeatedly fused with the current top of the stack for as long
// as Fuse() succeeds; each fused-away top is popped.
void global::add_to_opstack(OperatorPure *pOp) {
  if (fuse) {
    for (;;) {
      if (this->opstack.size() == 0) break;
      OperatorPure *fused = this->Fuse(this->opstack.back(), pOp);
      if (fused == NULL) break;
      this->opstack.pop_back();
      pOp = fused;
    }
  }
  this->opstack.push_back(pOp);
}

// True unless the index still has the NA value set by the default
// constructor.
bool global::ad_plain::initialized() const {
  return !(index == NA);
}

// An ad_plain is on a tape exactly when it is initialized.
bool global::ad_plain::on_some_tape() const { return initialized(); }

// Nothing to do beyond checking that the variable is initialized.
void global::ad_plain::addToTape() const { TMBAD_ASSERT(initialized()); }

// The active tape if this variable is initialized, otherwise NULL.
global *global::ad_plain::glob() const {
  if (!on_some_tape()) return NULL;
  return get_glob();
}

// No-op for ad_plain (compare ad_aug::override_by()).
void global::ad_plain::override_by(const ad_plain &x) const {}

// Uninitialized variable (index NA).
global::ad_plain::ad_plain() : index(NA) {}

// Record the constant 'x' on the active tape (ConstOp) and refer to it.
global::ad_plain::ad_plain(Scalar x) {
  global *cur = get_glob();
  *this = cur->add_to_stack<ConstOp>(x);
}

// Convert an ad_aug: put it on the active tape, clear its updatable flag and
// take over its taped value.
global::ad_plain::ad_plain(ad_aug x) {
  x.addToTape();
  x.setUpdatable(false);
  *this = x.taped_value;
}

// Replay evaluation: delegate to ad_aug::copy().
Replay global::ad_plain::CopyOp::eval(Replay x0) { return x0.copy(); }

// Operator name as shown by global::print().
const char *global::ad_plain::CopyOp::op_name() { return "CopyOp"; }

// Record a CopyOp of this variable on the active tape and return its output.
ad_plain global::ad_plain::copy() const {
  return get_glob()->add_to_stack<CopyOp>(*this);
}

// Replay evaluation: delegate to ad_aug::copy0().
Replay global::ad_plain::ValOp::eval(Replay x0) { return x0.copy0(); }

// ValOp reports no dependencies.
void global::ad_plain::ValOp::dependencies(Args<> &args,
                                           Dependencies &dep) const {}

// Operator name as shown by global::print().
const char *global::ad_plain::ValOp::op_name() { return "ValOp"; }

// Record a ValOp of this variable on the active tape and return its output.
ad_plain global::ad_plain::copy0() const {
  return get_glob()->add_to_stack<ValOp>(*this);
}

// Record an AddOp of (*this, other) on the active tape.
ad_plain global::ad_plain::operator+(const ad_plain &other) const {
  ad_plain sum = get_glob()->add_to_stack<AddOp>(*this, other);
  return sum;
}

// Record a SubOp of (*this, other) on the active tape.
ad_plain global::ad_plain::operator-(const ad_plain &other) const {
  ad_plain diff = get_glob()->add_to_stack<SubOp>(*this, other);
  return diff;
}

// Record a MulOp of (*this, other) on the active tape.
ad_plain global::ad_plain::operator*(const ad_plain &other) const {
  return get_glob()->add_to_stack<MulOp>(*this, other);
}

// Multiply by a scalar: the scalar is first recorded as a constant on the
// active tape, then a MulOp_<true, false> of (*this, constant) is recorded.
ad_plain global::ad_plain::operator*(const Scalar &other) const {
  const ad_plain rhs(other);
  return get_glob()->add_to_stack<MulOp_<true, false> >(*this, rhs);
}

// Record a DivOp of (*this, other) on the active tape.
ad_plain global::ad_plain::operator/(const ad_plain &other) const {
  return get_glob()->add_to_stack<DivOp>(*this, other);
}

// Operator name as shown by global::print().
const char *global::ad_plain::NegOp::op_name() { return "NegOp"; }

// Record a NegOp of this variable on the active tape.
ad_plain global::ad_plain::operator-() const {
  return get_glob()->add_to_stack<NegOp>(*this);
}

// Compound addition: rebinds *this to the result of (*this + other).
ad_plain &global::ad_plain::operator+=(const ad_plain &other) {
  const ad_plain result = *this + other;
  *this = result;
  return *this;
}

// Compound subtraction: rebinds *this to the result of (*this - other).
ad_plain &global::ad_plain::operator-=(const ad_plain &other) {
  const ad_plain result = *this - other;
  *this = result;
  return *this;
}

// Compound multiplication: rebinds *this to the result of (*this * other).
ad_plain &global::ad_plain::operator*=(const ad_plain &other) {
  const ad_plain result = *this * other;
  *this = result;
  return *this;
}

// Compound division: rebinds *this to the result of (*this / other).
ad_plain &global::ad_plain::operator/=(const ad_plain &other) {
  const ad_plain result = *this / other;
  *this = result;
  return *this;
}

void global::ad_plain::Dependent() {
  *this = get_glob()->add_to_stack<DepOp>(*this);
  get_glob()->dep_index.push_back(this->index);
}

void global::ad_plain::Independent() {
  Scalar val = (index == NA ? NAN : this->Value());
  *this = get_glob()->add_to_stack<InvOp>(val);
  get_glob()->inv_index.push_back(this->index);
}

// Mutable reference to this variable's value on the active tape.
Scalar &global::ad_plain::Value() {
  global *cur = get_glob();
  return cur->values[index];
}

// This variable's value on the active tape.
Scalar global::ad_plain::Value() const {
  const global *cur = get_glob();
  return cur->values[index];
}

// This variable's value looked up on the given tape rather than the active
// one.
Scalar global::ad_plain::Value(global *glob) const {
  const Scalar ans = glob->values[index];
  return ans;
}

// Mutable reference to this variable's derivative on the active tape.
Scalar &global::ad_plain::Deriv() {
  global *cur = get_glob();
  return cur->derivs[index];
}

// Make this tape the active tape of the current thread. The previously
// active tape is remembered in parent_glob so ad_stop() can restore it.
// The tape must not already be in use.
void global::ad_start() {
  TMBAD_ASSERT2(!in_use, "Tape already in use");
  TMBAD_ASSERT(parent_glob == NULL);
  parent_glob = global_ptr[TMBAD_THREAD_NUM];
  global_ptr[TMBAD_THREAD_NUM] = this;
  in_use = true;
}

// Deactivate this tape: restore the tape that was active before ad_start()
// as the current thread's active tape and clear the in-use state.
void global::ad_stop() {
  TMBAD_ASSERT2(in_use, "Tape not in use");
  global_ptr[TMBAD_THREAD_NUM] = parent_glob;
  parent_glob = NULL;
  in_use = false;
}

void global::Independent(std::vector<ad_plain> &x) {
  for (size_t i = 0; i < x.size(); i++) {
    x[i].Independent();
  }
}

// Empty segment: length 0, 0 columns, uninitialized offset.
global::ad_segment::ad_segment() : n(0), c(0) {}

// Segment of 'n' variables starting at 'x', stored with 1 column.
global::ad_segment::ad_segment(ad_plain x, size_t n) : x(x), n(n), c(1) {}

// Length-1 segment from an ad_aug (converted to ad_plain first).
global::ad_segment::ad_segment(ad_aug x) : x(ad_plain(x)), n(1), c(1) {}

// Length-1 segment from a scalar (converted to ad_plain first).
global::ad_segment::ad_segment(Scalar x) : x(ad_plain(x)), n(1), c(1) {}

// Segment of 'n' variables starting at tape index 'idx'.
// The column count is set to 1, matching the (ad_plain, size_t) constructor;
// previously 'c' was left uninitialized here, so rows()/cols() read an
// indeterminate value.
global::ad_segment::ad_segment(Index idx, size_t n) : n(n), c(1) {
  x.index = idx;
}

// Segment of r * c variables starting at 'x', stored with 'c' columns.
global::ad_segment::ad_segment(ad_plain x, size_t r, size_t c)
    : x(x), n(r * c), c(c) {}

// Build a segment from an array of 'n' replay variables. Cases, in order:
//   1. zero_check and all entries identically zero: the offset is left
//      uninitialized (see identicalZero()).
//   2. all entries constant: a DataOp with n outputs is recorded and the
//      constants are written into its output values.
//   3. entries not contiguous on the active tape: each entry is copied onto
//      the tape so that the copies are contiguous.
//   4. otherwise the segment starts at x[0].
global::ad_segment::ad_segment(Replay *x, size_t n, bool zero_check)
    : n(n), c(1) {
  if (zero_check && all_zero(x, n)) return;

  if (all_constant(x, n)) {
    global *glob = get_glob();
    // Index of the first DataOp output.
    const size_t m = glob->values.size();
    Complete<DataOp> D(n);
    D(ad_segment());
    for (size_t k = 0; k < n; k++) {
      glob->values[m + k] = x[k].Value();
    }
    this->x.index = m;
    return;
  }

  if (is_contiguous(x, n)) {
    if (n > 0) this->x = x[0];
    return;
  }

  size_t before = get_glob()->values.size();
  this->x = x[0].copy();
  for (size_t k = 1; k < n; k++) {
    x[k].copy();
  }
  size_t after = get_glob()->values.size();
  TMBAD_ASSERT2(after - before == n,
                "Each invocation of copy() should construct a new variable");
}

// A segment whose offset is uninitialized represents an all-zero segment.
bool global::ad_segment::identicalZero() { return !x.initialized(); }

// True if each of x[0..n-1] is on a tape and that tape is the active one.
bool global::ad_segment::all_on_active_tape(Replay *x, size_t n) {
  global *cur_glob = get_glob();
  for (size_t k = 0; k < n; k++) {
    if (!x[k].on_some_tape()) return false;
    if (x[k].glob() != cur_glob) return false;
  }
  return true;
}

// True if x[0..n-1] are all on the active tape and have consecutive indices.
bool global::ad_segment::is_contiguous(Replay *x, size_t n) {
  if (!all_on_active_tape(x, n)) return false;
  for (size_t k = 0; k + 1 < n; k++) {
    if (x[k + 1].index() != x[k].index() + 1) return false;
  }
  return true;
}

// True if every one of x[0..n-1] is identically zero (true for n == 0).
bool global::ad_segment::all_zero(Replay *x, size_t n) {
  size_t k = 0;
  while (k < n && x[k].identicalZero()) k++;
  return k == n;
}

// True if every one of x[0..n-1] is a constant (true for n == 0).
bool global::ad_segment::all_constant(Replay *x, size_t n) {
  size_t k = 0;
  while (k < n && x[k].constant()) k++;
  return k == n;
}

// Total number of variables in the segment.
size_t global::ad_segment::size() const { return n; }

// Number of rows, i.e. size divided by the column count.
// A default-constructed segment has zero columns; it reports zero rows
// instead of dividing by zero.
size_t global::ad_segment::rows() const { return (c == 0 ? 0 : n / c); }

// Number of columns.
size_t global::ad_segment::cols() const { return c; }

// The i'th variable of the segment: tape index is offset index + i.
// No bounds check is performed.
ad_plain global::ad_segment::operator[](size_t i) const {
  ad_plain element;
  element.index = x.index + i;
  return element;
}

// First variable of the segment.
ad_plain global::ad_segment::offset() const { return x; }

// Tape index of the first variable of the segment.
Index global::ad_segment::index() const { return x.index; }

// True if the taped value has been initialized.
bool global::ad_aug::on_some_tape() const { return taped_value.initialized(); }

// True if this variable is taped and its tape is the active one.
bool global::ad_aug::on_active_tape() const {
  if (!on_some_tape()) return false;
  return this->glob() == get_glob();
}

// Same as on_some_tape().
bool global::ad_aug::ontape() const { return on_some_tape(); }

// True if this variable is not taped, i.e. it is a plain constant.
bool global::ad_aug::constant() const { return !taped_value.initialized(); }

// Raw index of the taped value.
Index global::ad_aug::index() const { return taped_value.index; }

// The tape this variable lives on, or NULL for a constant.
global *global::ad_aug::glob() const {
  if (!on_some_tape()) return NULL;
  return data.glob;
}

// Numeric value: the stored constant for an untaped variable, otherwise the
// value looked up on the variable's own tape.
Scalar global::ad_aug::Value() const {
  if (!on_some_tape()) return data.value;
  return taped_value.Value(this->data.glob);
}

// Default constructor: 'data' is not assigned here.
global::ad_aug::ad_aug() {}

// Constant with value 'x'; nothing is recorded on any tape.
global::ad_aug::ad_aug(Scalar x) { data.value = x; }

// Wrap a taped variable; it is associated with the currently active tape.
global::ad_aug::ad_aug(ad_plain x) : taped_value(x) { data.glob = get_glob(); }

void global::ad_aug::addToTape() const {
  if (on_some_tape()) {
    if (data.glob != get_glob()) {
      TMBAD_ASSERT2(in_context_stack(data.glob), "Variable not initialized?");
      global::OperatorPure *pOp =
          get_glob()->getOperator<RefOp>(data.glob, taped_value.index);
      this->taped_value =
          get_glob()->add_to_stack<RefOp>(pOp, std::vector<ad_plain>(0))[0];

      this->data.glob = get_glob();
    }
    return;
  }
  this->taped_value = ad_plain(data.value);
  this->data.glob = get_glob();
}

// Replaces the taped value by x and sets data.glob to the global returned by
// get_glob(). Declared const although it assigns both members.
void global::ad_aug::override_by(const ad_plain &x) const {
  this->taped_value = x;
  this->data.glob = get_glob();
}

// Walks from the global returned by get_glob() up the parent_glob chain and
// reports whether `glob` is found before the chain ends in NULL.
bool global::ad_aug::in_context_stack(global *glob) const {
  for (global *node = get_glob(); node != NULL; node = node->parent_glob) {
    if (node == glob) return true;
  }
  return false;
}

// If the variable is on the active tape, returns taped_value.copy().
// Otherwise returns a duplicate of *this after calling addToTape() on it.
ad_aug global::ad_aug::copy() const {
  if (on_active_tape()) return taped_value.copy();

  ad_aug duplicate = *this;
  duplicate.addToTape();
  return duplicate;
}

// Works on a duplicate of *this: moves it to the active tape if needed, then
// returns taped_value.copy0() of that duplicate.
ad_aug global::ad_aug::copy0() const {
  ad_aug duplicate = *this;
  if (!duplicate.on_active_tape()) duplicate.addToTape();
  return duplicate.taped_value.copy0();
}

// True only for a constant whose stored value compares equal to Scalar(0).
bool global::ad_aug::identicalZero() const {
  if (!constant()) return false;
  return data.value == Scalar(0);
}

// True only for a constant whose stored value compares equal to Scalar(1).
bool global::ad_aug::identicalOne() const {
  if (!constant()) return false;
  return data.value == Scalar(1);
}

// True when both *this and other are constants.
bool global::ad_aug::bothConstant(const ad_aug &other) const {
  if (!constant()) return false;
  return other.constant();
}

// Two constants are identical when their stored values compare equal.
// Otherwise the variables are identical when glob() agrees and the taped
// indices are equal.
bool global::ad_aug::identical(const ad_aug &other) const {
  if (constant() && other.constant()) {
    return data.value == other.data.value;
  }
  if (glob() != other.glob()) return false;
  return taped_value.index == other.taped_value.index;
}

// Sets (flag == true) or clears (flag == false) `updbit` in the taped index.
// For an untaped constant nothing is changed; asking for flag == true there
// trips the assertion.
void global::ad_aug::setUpdatable(bool flag) {
  if (!on_some_tape()) {
    TMBAD_ASSERT2(!flag, "An untaped constant cannot be made 'updatable'");
    return;
  }
  if (flag) {
    taped_value.index = (index() | updbit);
  } else {
    taped_value.index = (index() & ~updbit);
  }
}

// True when the variable is taped and `updbit` is set in its index.
bool global::ad_aug::updatable() const {
  if (!on_some_tape()) return false;
  return (index() & updbit) != 0;
}

// Addition with shortcuts that avoid a taped operation:
// constant + constant is folded, and adding an identical zero returns the
// other operand unchanged. Otherwise both sides go through ad_plain.
ad_aug global::ad_aug::operator+(const ad_aug &other) const {
  if (bothConstant(other)) {
    return Scalar(data.value + other.data.value);
  }
  if (identicalZero()) {
    return other;
  }
  if (other.identicalZero()) {
    return *this;
  }
  return ad_plain(*this) + ad_plain(other);
}

// Subtraction with shortcuts that avoid a taped operation:
// constant - constant is folded, x - 0 returns x, 0 - y returns -y, and
// x - x (per identical()) returns the constant 0.
ad_aug global::ad_aug::operator-(const ad_aug &other) const {
  if (bothConstant(other)) {
    return Scalar(data.value - other.data.value);
  }
  if (other.identicalZero()) {
    return *this;
  }
  if (identicalZero()) {
    return -other;
  }
  if (identical(other)) {
    return Scalar(0);
  }
  return ad_plain(*this) - ad_plain(other);
}

// Unary minus: a constant is negated directly, anything else is negated
// through ad_plain.
ad_aug global::ad_aug::operator-() const {
  if (constant()) {
    return Scalar(-(data.value));
  }
  return -ad_plain(*this);
}

// Multiplication with shortcuts, checked in this order:
// constant * constant is folded; an identical-zero operand is returned as the
// result; an identical-one operand yields the other operand; a single constant
// operand is passed as a Scalar factor; otherwise both sides go through
// ad_plain.
ad_aug global::ad_aug::operator*(const ad_aug &other) const {
  if (bothConstant(other)) {
    return Scalar(data.value * other.data.value);
  }
  if (identicalZero()) {
    return *this;
  }
  if (other.identicalZero()) {
    return other;
  }
  if (identicalOne()) {
    return other;
  }
  if (other.identicalOne()) {
    return *this;
  }
  if (constant()) {
    return ad_plain(other) * Scalar(data.value);
  }
  if (other.constant()) {
    return ad_plain(*this) * Scalar(other.data.value);
  }
  return ad_plain(*this) * ad_plain(other);
}

// Division with shortcuts: constant / constant is folded, an identical-zero
// numerator is returned unchanged (the denominator is not inspected), and
// division by an identical one returns the numerator.
ad_aug global::ad_aug::operator/(const ad_aug &other) const {
  if (bothConstant(other)) {
    return Scalar(data.value / other.data.value);
  }
  if (identicalZero()) {
    return *this;
  }
  if (other.identicalOne()) {
    return *this;
  }
  return ad_plain(*this) / ad_plain(other);
}

// Compound addition, implemented through the binary operator+.
ad_aug &global::ad_aug::operator+=(const ad_aug &other) {
  *this = *this + other;
  return *this;
}

// Compound subtraction, implemented through the binary operator-.
ad_aug &global::ad_aug::operator-=(const ad_aug &other) {
  *this = *this - other;
  return *this;
}

// Compound multiplication, implemented through the binary operator*.
ad_aug &global::ad_aug::operator*=(const ad_aug &other) {
  *this = *this * other;
  return *this;
}

// Compound division, implemented through the binary operator/.
ad_aug &global::ad_aug::operator/=(const ad_aug &other) {
  *this = *this / other;
  return *this;
}

// Marks this variable as dependent: clears the updatable bit, makes sure the
// variable is taped via addToTape(), then forwards to the taped value's
// Dependent().
void global::ad_aug::Dependent() {
  this->setUpdatable(false);
  this->addToTape();
  taped_value.Dependent();
}

// Marks this variable as independent: forwards to the taped value's
// Independent(), copies the stored numeric value into the taped slot, then
// sets data.glob to the global returned by get_glob().
// NOTE(review): data.value is read before data.glob is written; if the two
// share storage this order is required - confirm against the declaration.
void global::ad_aug::Independent() {
  taped_value.Independent();
  taped_value.Value() = this->data.value;
  this->data.glob = get_glob();
}

// Mutable access to the numeric value: data.value for a constant, otherwise
// the reference returned by taped_value.Value().
Scalar &global::ad_aug::Value() {
  if (!on_some_tape()) return data.value;
  return taped_value.Value();
}

// Mutable access to the derivative, forwarded to taped_value.Deriv().
// No check for a taped value is made here.
Scalar &global::ad_aug::Deriv() { return taped_value.Deriv(); }

void global::Independent(std::vector<ad_aug> &x) {
  for (size_t i = 0; i < x.size(); i++) {
    x[i].Independent();
  }
}

// Streams the numeric value of an ad_plain.
std::ostream &operator<<(std::ostream &os, const global::ad_plain &x) {
  os << x.Value();
  return os;
}

// Streams an ad_aug in braces: "{const=<value>}" for a constant, otherwise
// "{value=<v>, index=<i>, tape=<glob pointer>}" where <v> is read from the
// variable's own glob at its taped index.
std::ostream &operator<<(std::ostream &os, const global::ad_aug &x) {
  os << "{";
  if (!x.on_some_tape()) {
    os << "const=" << x.data.value;
  } else {
    os << "value=" << x.data.glob->values[x.taped_value.index] << ", "
       << "index=" << x.taped_value.index << ", "
       << "tape=" << x.data.glob;
  }
  os << "}";
  return os;
}

// Builds an ad_plain_index whose index member is set to i.
ad_plain_index::ad_plain_index(const Index &i) { this->index = i; }

// Builds an ad_plain_index as a copy of the ad_plain x.
ad_plain_index::ad_plain_index(const ad_plain &x) : ad_plain(x) {}

// Builds an ad_aug_index from an ad_plain_index carrying index i.
ad_aug_index::ad_aug_index(const Index &i) : ad_aug(ad_plain_index(i)) {}

// Builds an ad_aug_index as a copy of the ad_aug x.
ad_aug_index::ad_aug_index(const ad_aug &x) : ad_aug(x) {}

// Builds an ad_aug_index from the ad_plain x.
ad_aug_index::ad_aug_index(const ad_plain &x) : ad_aug(x) {}

// Identity overload: the value of a plain Scalar is the Scalar itself.
Scalar Value(Scalar x) { return x; }

// Arithmetic with a plain double on the left-hand side: the double is wrapped
// in an ad_aug and the corresponding ad_aug member operator is applied.
ad_aug operator+(const double &x, const ad_aug &y) { return ad_aug(x) + y; }

ad_aug operator-(const double &x, const ad_aug &y) { return ad_aug(x) - y; }

ad_aug operator*(const double &x, const ad_aug &y) { return ad_aug(x) * y; }

ad_aug operator/(const double &x, const ad_aug &y) { return ad_aug(x) / y; }

// Comparisons with a plain double on the left-hand side: each compares x with
// the numeric value y.Value() and returns a plain bool.
bool operator<(const double &x, const ad_adapt &y) { return x < y.Value(); }

bool operator<=(const double &x, const ad_adapt &y) { return x <= y.Value(); }

bool operator>(const double &x, const ad_adapt &y) { return x > y.Value(); }

bool operator>=(const double &x, const ad_adapt &y) { return x >= y.Value(); }

bool operator==(const double &x, const ad_adapt &y) { return x == y.Value(); }

bool operator!=(const double &x, const ad_adapt &y) { return x != y.Value(); }

// Source-writer overload: builds the expression text "floor(<x>)".
Writer floor(const Writer &x) { return "floor(" + x + ")"; }

// Name reported for FloorOp.
const char *FloorOp::op_name() { return "FloorOp"; }

// Adds a FloorOp applied to x through add_to_stack on the global returned by
// get_glob().
ad_plain floor(const ad_plain &x) {
  global *tape = get_glob();
  return tape->add_to_stack<FloorOp>(x);
}

// A constant argument is evaluated directly and stays a constant; any other
// argument is converted to ad_plain and handled by that overload.
ad_aug floor(const ad_aug &x) {
  if (!x.constant()) return floor(ad_plain(x));
  return Scalar(floor(x.Value()));
}

// Source-writer overload: builds the expression text "ceil(<x>)".
Writer ceil(const Writer &x) { return "ceil(" + x + ")"; }

// Name reported for CeilOp.
const char *CeilOp::op_name() { return "CeilOp"; }

// Adds a CeilOp applied to x through add_to_stack on the global returned by
// get_glob().
ad_plain ceil(const ad_plain &x) {
  global *tape = get_glob();
  return tape->add_to_stack<CeilOp>(x);
}

// A constant argument is evaluated directly and stays a constant; any other
// argument is converted to ad_plain and handled by that overload.
ad_aug ceil(const ad_aug &x) {
  if (!x.constant()) return ceil(ad_plain(x));
  return Scalar(ceil(x.Value()));
}

// Source-writer overload: builds the expression text "trunc(<x>)".
Writer trunc(const Writer &x) { return "trunc(" + x + ")"; }

// Name reported for TruncOp.
const char *TruncOp::op_name() { return "TruncOp"; }

// Adds a TruncOp applied to x through add_to_stack on the global returned by
// get_glob().
ad_plain trunc(const ad_plain &x) {
  global *tape = get_glob();
  return tape->add_to_stack<TruncOp>(x);
}

// A constant argument is evaluated directly and stays a constant; any other
// argument is converted to ad_plain and handled by that overload.
ad_aug trunc(const ad_aug &x) {
  if (!x.constant()) return trunc(ad_plain(x));
  return Scalar(trunc(x.Value()));
}

// Source-writer overload: builds the expression text "round(<x>)".
Writer round(const Writer &x) { return "round(" + x + ")"; }

// Name reported for RoundOp.
const char *RoundOp::op_name() { return "RoundOp"; }

// Adds a RoundOp applied to x through add_to_stack on the global returned by
// get_glob().
ad_plain round(const ad_plain &x) {
  global *tape = get_glob();
  return tape->add_to_stack<RoundOp>(x);
}

// A constant argument is evaluated directly and stays a constant; any other
// argument is converted to ad_plain and handled by that overload.
ad_aug round(const ad_aug &x) {
  if (!x.constant()) return round(ad_plain(x));
  return Scalar(round(x.Value()));
}

// Sign of x as a double: 1 for x >= 0 (including zero), -1 for x < 0, and 0
// when neither comparison holds (NaN).
double sign(const double &x) {
  if (x >= 0) return 1;
  if (x < 0) return -1;
  return 0;
}

// Source-writer overload: builds the expression text "sign(<x>)".
Writer sign(const Writer &x) { return "sign(" + x + ")"; }

// Name reported for SignOp.
const char *SignOp::op_name() { return "SignOp"; }

// Adds a SignOp applied to x through add_to_stack on the global returned by
// get_glob().
ad_plain sign(const ad_plain &x) {
  global *tape = get_glob();
  return tape->add_to_stack<SignOp>(x);
}

// A constant argument is evaluated directly and stays a constant; any other
// argument is converted to ad_plain and handled by that overload.
ad_aug sign(const ad_aug &x) {
  if (!x.constant()) return sign(ad_plain(x));
  return Scalar(sign(x.Value()));
}

// Indicator of x >= 0 as a double: 1 when the comparison holds, else 0.
double ge0(const double &x) { return (x >= 0) ? 1.0 : 0.0; }

// Indicator of x < 0 as a double: 1 when the comparison holds, else 0.
double lt0(const double &x) { return (x < 0) ? 1.0 : 0.0; }

// Source-writer overload: builds the expression text "ge0(<x>)".
Writer ge0(const Writer &x) { return "ge0(" + x + ")"; }

// Name reported for Ge0Op.
const char *Ge0Op::op_name() { return "Ge0Op"; }

// Adds a Ge0Op applied to x through add_to_stack on the global returned by
// get_glob().
ad_plain ge0(const ad_plain &x) {
  global *tape = get_glob();
  return tape->add_to_stack<Ge0Op>(x);
}

// A constant argument is evaluated directly and stays a constant; any other
// argument is converted to ad_plain and handled by that overload.
ad_aug ge0(const ad_aug &x) {
  if (!x.constant()) return ge0(ad_plain(x));
  return Scalar(ge0(x.Value()));
}

// Source-writer overload: builds the expression text "lt0(<x>)".
Writer lt0(const Writer &x) { return "lt0(" + x + ")"; }

// Name reported for Lt0Op.
const char *Lt0Op::op_name() { return "Lt0Op"; }

// Adds a Lt0Op applied to x through add_to_stack on the global returned by
// get_glob().
ad_plain lt0(const ad_plain &x) {
  global *tape = get_glob();
  return tape->add_to_stack<Lt0Op>(x);
}

// A constant argument is evaluated directly and stays a constant; any other
// argument is converted to ad_plain and handled by that overload.
ad_aug lt0(const ad_aug &x) {
  if (!x.constant()) return lt0(ad_plain(x));
  return Scalar(lt0(x.Value()));
}

// Source-writer overload: builds the expression text "zeroDeriv(<x>)".
Writer zeroDeriv(const Writer &x) { return "zeroDeriv(" + x + ")"; }

// Name reported for ZderivOp.
const char *ZderivOp::op_name() { return "ZderivOp"; }

// Adds a ZderivOp applied to x through add_to_stack on the global returned by
// get_glob().
ad_plain zeroDeriv(const ad_plain &x) {
  global *tape = get_glob();
  return tape->add_to_stack<ZderivOp>(x);
}

// A constant argument is evaluated directly and stays a constant; any other
// argument is converted to ad_plain and handled by that overload.
ad_aug zeroDeriv(const ad_aug &x) {
  if (!x.constant()) return zeroDeriv(ad_plain(x));
  return Scalar(zeroDeriv(x.Value()));
}

// Name reported for SderivOp.
const char *SderivOp::op_name() {
  return "SderivOp";
}

// Adds a SderivOp applied to x through add_to_stack on the global returned by
// get_glob().
ad_plain sparseDeriv(const ad_plain &x) {
  global *tape = get_glob();
  return tape->add_to_stack<SderivOp>(x);
}

// Plain double overload: returns x unchanged.
double sparseDeriv(const double &x) { return x; }

// Source-writer overload: builds the expression text "fabs(<x>)".
Writer fabs(const Writer &x) { return "fabs(" + x + ")"; }

// Numeric reverse sweep: dx += dy * sign(x). Skipped when dy compares equal
// to zero.
void AbsOp::reverse(ReverseArgs<Scalar> &args) {
  if (args.dy(0) == Scalar(0)) return;
  args.dx(0) += args.dy(0) * (sign(args.x(0)));
}

// Name reported for AbsOp.
const char *AbsOp::op_name() { return "AbsOp"; }

// Adds an AbsOp applied to x through add_to_stack on the global returned by
// get_glob().
ad_plain fabs(const ad_plain &x) {
  global *tape = get_glob();
  return tape->add_to_stack<AbsOp>(x);
}

// A constant argument is evaluated directly and stays a constant; any other
// argument is converted to ad_plain and handled by that overload.
ad_aug fabs(const ad_aug &x) {
  if (!x.constant()) return fabs(ad_plain(x));
  return Scalar(fabs(x.Value()));
}

// ad_adapt overload: forwards to the ad_aug overload and wraps the result.
ad_adapt fabs(const ad_adapt &x) {
  return ad_adapt(fabs(ad_aug(x)));
}

// Source-writer overload: builds the expression text "sin(<x>)".
Writer sin(const Writer &x) { return "sin(" + x + ")"; }

// Numeric reverse sweep: dx += dy * cos(x). Skipped when dy compares equal to
// zero.
void SinOp::reverse(ReverseArgs<Scalar> &args) {
  if (args.dy(0) == Scalar(0)) return;
  args.dx(0) += args.dy(0) * (cos(args.x(0)));
}

// Name reported for SinOp.
const char *SinOp::op_name() { return "SinOp"; }

// Adds a SinOp applied to x through add_to_stack on the global returned by
// get_glob().
ad_plain sin(const ad_plain &x) {
  global *tape = get_glob();
  return tape->add_to_stack<SinOp>(x);
}

// A constant argument is evaluated directly and stays a constant; any other
// argument is converted to ad_plain and handled by that overload.
ad_aug sin(const ad_aug &x) {
  if (!x.constant()) return sin(ad_plain(x));
  return Scalar(sin(x.Value()));
}

// ad_adapt overload: forwards to the ad_aug overload and wraps the result.
ad_adapt sin(const ad_adapt &x) {
  return ad_adapt(sin(ad_aug(x)));
}

// Source-writer overload: builds the expression text "cos(<x>)".
Writer cos(const Writer &x) { return "cos(" + x + ")"; }

// Numeric reverse sweep: dx += dy * (-sin(x)). Skipped when dy compares equal
// to zero.
void CosOp::reverse(ReverseArgs<Scalar> &args) {
  if (args.dy(0) == Scalar(0)) return;
  args.dx(0) += args.dy(0) * (-sin(args.x(0)));
}

// Name reported for CosOp.
const char *CosOp::op_name() { return "CosOp"; }

// Adds a CosOp applied to x through add_to_stack on the global returned by
// get_glob().
ad_plain cos(const ad_plain &x) {
  global *tape = get_glob();
  return tape->add_to_stack<CosOp>(x);
}

// A constant argument is evaluated directly and stays a constant; any other
// argument is converted to ad_plain and handled by that overload.
ad_aug cos(const ad_aug &x) {
  if (!x.constant()) return cos(ad_plain(x));
  return Scalar(cos(x.Value()));
}

// ad_adapt overload: forwards to the ad_aug overload and wraps the result.
ad_adapt cos(const ad_adapt &x) {
  return ad_adapt(cos(ad_aug(x)));
}

// Source-writer overload: builds the expression text "exp(<x>)".
Writer exp(const Writer &x) { return "exp(" + x + ")"; }

// Numeric reverse sweep: dx += dy * y, reusing the stored output y. Skipped
// when dy compares equal to zero.
void ExpOp::reverse(ReverseArgs<Scalar> &args) {
  if (args.dy(0) == Scalar(0)) return;
  args.dx(0) += args.dy(0) * (args.y(0));
}

// Name reported for ExpOp.
const char *ExpOp::op_name() { return "ExpOp"; }

// Adds an ExpOp applied to x through add_to_stack on the global returned by
// get_glob().
ad_plain exp(const ad_plain &x) {
  global *tape = get_glob();
  return tape->add_to_stack<ExpOp>(x);
}

// A constant argument is evaluated directly and stays a constant; any other
// argument is converted to ad_plain and handled by that overload.
ad_aug exp(const ad_aug &x) {
  if (!x.constant()) return exp(ad_plain(x));
  return Scalar(exp(x.Value()));
}

// ad_adapt overload: forwards to the ad_aug overload and wraps the result.
ad_adapt exp(const ad_adapt &x) {
  return ad_adapt(exp(ad_aug(x)));
}

// Source-writer overload: builds the expression text "log(<x>)".
Writer log(const Writer &x) { return "log(" + x + ")"; }

// Numeric reverse sweep: dx += dy * (1 / x). Skipped when dy compares equal to
// zero.
void LogOp::reverse(ReverseArgs<Scalar> &args) {
  if (args.dy(0) == Scalar(0)) return;
  args.dx(0) += args.dy(0) * (Scalar(1.) / args.x(0));
}

// Name reported for LogOp.
const char *LogOp::op_name() { return "LogOp"; }

// Adds a LogOp applied to x through add_to_stack on the global returned by
// get_glob().
ad_plain log(const ad_plain &x) {
  global *tape = get_glob();
  return tape->add_to_stack<LogOp>(x);
}

// A constant argument is evaluated directly and stays a constant; any other
// argument is converted to ad_plain and handled by that overload.
ad_aug log(const ad_aug &x) {
  if (!x.constant()) return log(ad_plain(x));
  return Scalar(log(x.Value()));
}

// ad_adapt overload: forwards to the ad_aug overload and wraps the result.
ad_adapt log(const ad_adapt &x) {
  return ad_adapt(log(ad_aug(x)));
}

// Source-writer overload: builds the expression text "sqrt(<x>)".
Writer sqrt(const Writer &x) { return "sqrt(" + x + ")"; }

// Numeric reverse sweep: dx += dy * (0.5 / y), reusing the stored output y.
// Skipped when dy compares equal to zero.
void SqrtOp::reverse(ReverseArgs<Scalar> &args) {
  if (args.dy(0) == Scalar(0)) return;
  args.dx(0) += args.dy(0) * (Scalar(0.5) / args.y(0));
}

// Name reported for SqrtOp.
const char *SqrtOp::op_name() { return "SqrtOp"; }

// Adds a SqrtOp applied to x through add_to_stack on the global returned by
// get_glob().
ad_plain sqrt(const ad_plain &x) {
  global *tape = get_glob();
  return tape->add_to_stack<SqrtOp>(x);
}

// A constant argument is evaluated directly and stays a constant; any other
// argument is converted to ad_plain and handled by that overload.
ad_aug sqrt(const ad_aug &x) {
  if (!x.constant()) return sqrt(ad_plain(x));
  return Scalar(sqrt(x.Value()));
}

// ad_adapt overload: forwards to the ad_aug overload and wraps the result.
ad_adapt sqrt(const ad_adapt &x) {
  return ad_adapt(sqrt(ad_aug(x)));
}

// Source-writer overload: builds the expression text "tan(<x>)".
Writer tan(const Writer &x) { return "tan(" + x + ")"; }

// Numeric reverse sweep: dx += dy * (1 / cos(x)^2). Skipped when dy compares
// equal to zero.
void TanOp::reverse(ReverseArgs<Scalar> &args) {
  if (args.dy(0) == Scalar(0)) return;
  args.dx(0) += args.dy(0) * (Scalar(1.) / (cos(args.x(0)) * cos(args.x(0))));
}

// Name reported for TanOp.
const char *TanOp::op_name() { return "TanOp"; }

// Adds a TanOp applied to x through add_to_stack on the global returned by
// get_glob().
ad_plain tan(const ad_plain &x) {
  global *tape = get_glob();
  return tape->add_to_stack<TanOp>(x);
}

// A constant argument is evaluated directly and stays a constant; any other
// argument is converted to ad_plain and handled by that overload.
ad_aug tan(const ad_aug &x) {
  if (!x.constant()) return tan(ad_plain(x));
  return Scalar(tan(x.Value()));
}

// ad_adapt overload: forwards to the ad_aug overload and wraps the result.
ad_adapt tan(const ad_adapt &x) {
  return ad_adapt(tan(ad_aug(x)));
}

// Source-writer overload: builds the expression text "sinh(<x>)".
Writer sinh(const Writer &x) { return "sinh(" + x + ")"; }

// Numeric reverse sweep: dx += dy * cosh(x). Skipped when dy compares equal to
// zero.
void SinhOp::reverse(ReverseArgs<Scalar> &args) {
  if (args.dy(0) == Scalar(0)) return;
  args.dx(0) += args.dy(0) * (cosh(args.x(0)));
}

// Name reported for SinhOp.
const char *SinhOp::op_name() { return "SinhOp"; }

// Adds a SinhOp applied to x through add_to_stack on the global returned by
// get_glob().
ad_plain sinh(const ad_plain &x) {
  global *tape = get_glob();
  return tape->add_to_stack<SinhOp>(x);
}

// A constant argument is evaluated directly and stays a constant; any other
// argument is converted to ad_plain and handled by that overload.
ad_aug sinh(const ad_aug &x) {
  if (!x.constant()) return sinh(ad_plain(x));
  return Scalar(sinh(x.Value()));
}

// ad_adapt overload: forwards to the ad_aug overload and wraps the result.
ad_adapt sinh(const ad_adapt &x) {
  return ad_adapt(sinh(ad_aug(x)));
}

// Source-writer overload: builds the expression text "cosh(<x>)".
Writer cosh(const Writer &x) { return "cosh(" + x + ")"; }

// Numeric reverse sweep: dx += dy * sinh(x). Skipped when dy compares equal to
// zero.
void CoshOp::reverse(ReverseArgs<Scalar> &args) {
  if (args.dy(0) == Scalar(0)) return;
  args.dx(0) += args.dy(0) * (sinh(args.x(0)));
}

// Name reported for CoshOp.
const char *CoshOp::op_name() { return "CoshOp"; }

// Adds a CoshOp applied to x through add_to_stack on the global returned by
// get_glob().
ad_plain cosh(const ad_plain &x) {
  global *tape = get_glob();
  return tape->add_to_stack<CoshOp>(x);
}

// A constant argument is evaluated directly and stays a constant; any other
// argument is converted to ad_plain and handled by that overload.
ad_aug cosh(const ad_aug &x) {
  if (!x.constant()) return cosh(ad_plain(x));
  return Scalar(cosh(x.Value()));
}

// ad_adapt overload: forwards to the ad_aug overload and wraps the result.
ad_adapt cosh(const ad_adapt &x) {
  return ad_adapt(cosh(ad_aug(x)));
}

// Source-writer overload: builds the expression text "tanh(<x>)".
Writer tanh(const Writer &x) { return "tanh(" + x + ")"; }

// Numeric reverse sweep: dx += dy * (1 / cosh(x)^2). Skipped when dy compares
// equal to zero.
void TanhOp::reverse(ReverseArgs<Scalar> &args) {
  if (args.dy(0) == Scalar(0)) return;
  args.dx(0) +=
      args.dy(0) * (Scalar(1.) / (cosh(args.x(0)) * cosh(args.x(0))));
}

// Name reported for TanhOp.
const char *TanhOp::op_name() { return "TanhOp"; }

// Adds a TanhOp applied to x through add_to_stack on the global returned by
// get_glob().
ad_plain tanh(const ad_plain &x) {
  global *tape = get_glob();
  return tape->add_to_stack<TanhOp>(x);
}

// A constant argument is evaluated directly and stays a constant; any other
// argument is converted to ad_plain and handled by that overload.
ad_aug tanh(const ad_aug &x) {
  if (!x.constant()) return tanh(ad_plain(x));
  return Scalar(tanh(x.Value()));
}

// ad_adapt overload: forwards to the ad_aug overload and wraps the result.
ad_adapt tanh(const ad_adapt &x) {
  return ad_adapt(tanh(ad_aug(x)));
}

// Source-writer overload: builds the expression text "expm1(<x>)".
Writer expm1(const Writer &x) { return "expm1(" + x + ")"; }

// Numeric reverse sweep: dx += dy * (y + 1), reusing the stored output y.
// Skipped when dy compares equal to zero.
void Expm1::reverse(ReverseArgs<Scalar> &args) {
  if (args.dy(0) == Scalar(0)) return;
  args.dx(0) += args.dy(0) * (args.y(0) + Scalar(1.));
}

// Name reported for Expm1.
const char *Expm1::op_name() { return "Expm1"; }

// Adds an Expm1 operator applied to x through add_to_stack on the global
// returned by get_glob().
ad_plain expm1(const ad_plain &x) {
  global *tape = get_glob();
  return tape->add_to_stack<Expm1>(x);
}

// A constant argument is evaluated directly and stays a constant; any other
// argument is converted to ad_plain and handled by that overload.
ad_aug expm1(const ad_aug &x) {
  if (!x.constant()) return expm1(ad_plain(x));
  return Scalar(expm1(x.Value()));
}

// ad_adapt overload: forwards to the ad_aug overload and wraps the result.
ad_adapt expm1(const ad_adapt &x) {
  return ad_adapt(expm1(ad_aug(x)));
}

// Source-writer overload: builds the expression text "log1p(<x>)".
Writer log1p(const Writer &x) { return "log1p(" + x + ")"; }

// Numeric reverse sweep: dx += dy * (1 / (x + 1)). Skipped when dy compares
// equal to zero.
void Log1p::reverse(ReverseArgs<Scalar> &args) {
  if (args.dy(0) == Scalar(0)) return;
  args.dx(0) += args.dy(0) * (Scalar(1.) / (args.x(0) + Scalar(1.)));
}

// Name reported for Log1p.
const char *Log1p::op_name() { return "Log1p"; }

// Adds a Log1p operator applied to x through add_to_stack on the global
// returned by get_glob().
ad_plain log1p(const ad_plain &x) {
  global *tape = get_glob();
  return tape->add_to_stack<Log1p>(x);
}

// A constant argument is evaluated directly and stays a constant; any other
// argument is converted to ad_plain and handled by that overload.
ad_aug log1p(const ad_aug &x) {
  if (!x.constant()) return log1p(ad_plain(x));
  return Scalar(log1p(x.Value()));
}

// ad_adapt overload: forwards to the ad_aug overload and wraps the result.
ad_adapt log1p(const ad_adapt &x) {
  return ad_adapt(log1p(ad_aug(x)));
}

// Source-writer overload: builds the expression text "asin(<x>)".
Writer asin(const Writer &x) { return "asin(" + x + ")"; }

// Numeric reverse sweep: dx += dy * (1 / sqrt(1 - x^2)). Skipped when dy
// compares equal to zero.
void AsinOp::reverse(ReverseArgs<Scalar> &args) {
  if (args.dy(0) == Scalar(0)) return;
  args.dx(0) +=
      args.dy(0) * (Scalar(1.) / sqrt(Scalar(1.) - args.x(0) * args.x(0)));
}

// Name reported for AsinOp.
const char *AsinOp::op_name() { return "AsinOp"; }

// Adds an AsinOp applied to x through add_to_stack on the global returned by
// get_glob().
ad_plain asin(const ad_plain &x) {
  global *tape = get_glob();
  return tape->add_to_stack<AsinOp>(x);
}

// A constant argument is evaluated directly and stays a constant; any other
// argument is converted to ad_plain and handled by that overload.
ad_aug asin(const ad_aug &x) {
  if (!x.constant()) return asin(ad_plain(x));
  return Scalar(asin(x.Value()));
}

// ad_adapt overload: forwards to the ad_aug overload and wraps the result.
ad_adapt asin(const ad_adapt &x) {
  return ad_adapt(asin(ad_aug(x)));
}

// Source-writer overload: builds the expression text "acos(<x>)".
Writer acos(const Writer &x) { return "acos(" + x + ")"; }

// Numeric reverse sweep: dx += dy * (-1 / sqrt(1 - x^2)). Skipped when dy
// compares equal to zero.
void AcosOp::reverse(ReverseArgs<Scalar> &args) {
  if (args.dy(0) == Scalar(0)) return;
  args.dx(0) +=
      args.dy(0) * (Scalar(-1.) / sqrt(Scalar(1.) - args.x(0) * args.x(0)));
}

// Name reported for AcosOp.
const char *AcosOp::op_name() { return "AcosOp"; }

// Adds an AcosOp applied to x through add_to_stack on the global returned by
// get_glob().
ad_plain acos(const ad_plain &x) {
  global *tape = get_glob();
  return tape->add_to_stack<AcosOp>(x);
}

// A constant argument is evaluated directly and stays a constant; any other
// argument is converted to ad_plain and handled by that overload.
ad_aug acos(const ad_aug &x) {
  if (!x.constant()) return acos(ad_plain(x));
  return Scalar(acos(x.Value()));
}

// ad_adapt overload: forwards to the ad_aug overload and wraps the result.
ad_adapt acos(const ad_adapt &x) {
  return ad_adapt(acos(ad_aug(x)));
}

// Source-writer overload: builds the expression text "atan(<x>)".
Writer atan(const Writer &x) { return "atan(" + x + ")"; }

// Numeric reverse sweep: dx += dy * (1 / (1 + x^2)). Skipped when dy compares
// equal to zero.
void AtanOp::reverse(ReverseArgs<Scalar> &args) {
  if (args.dy(0) == Scalar(0)) return;
  args.dx(0) +=
      args.dy(0) * (Scalar(1.) / (Scalar(1.) + args.x(0) * args.x(0)));
}

// Name reported for AtanOp.
const char *AtanOp::op_name() { return "AtanOp"; }

// Adds an AtanOp applied to x through add_to_stack on the global returned by
// get_glob().
ad_plain atan(const ad_plain &x) {
  global *tape = get_glob();
  return tape->add_to_stack<AtanOp>(x);
}

// A constant argument is evaluated directly and stays a constant; any other
// argument is converted to ad_plain and handled by that overload.
ad_aug atan(const ad_aug &x) {
  if (!x.constant()) return atan(ad_plain(x));
  return Scalar(atan(x.Value()));
}

// ad_adapt overload: forwards to the ad_aug overload and wraps the result.
ad_adapt atan(const ad_adapt &x) {
  return ad_adapt(atan(ad_aug(x)));
}

// Source-writer overload: builds the expression text "asinh(<x>)".
Writer asinh(const Writer &x) { return "asinh(" + x + ")"; }

// Numeric reverse sweep: dx += dy * (1 / sqrt(x^2 + 1)). Skipped when dy
// compares equal to zero.
void AsinhOp::reverse(ReverseArgs<Scalar> &args) {
  if (args.dy(0) == Scalar(0)) return;
  args.dx(0) +=
      args.dy(0) * (Scalar(1.) / sqrt(args.x(0) * args.x(0) + Scalar(1.)));
}

// Name reported for AsinhOp.
const char *AsinhOp::op_name() { return "AsinhOp"; }

// Adds an AsinhOp applied to x through add_to_stack on the global returned by
// get_glob().
ad_plain asinh(const ad_plain &x) {
  global *tape = get_glob();
  return tape->add_to_stack<AsinhOp>(x);
}

// A constant argument is evaluated directly and stays a constant; any other
// argument is converted to ad_plain and handled by that overload.
ad_aug asinh(const ad_aug &x) {
  if (!x.constant()) return asinh(ad_plain(x));
  return Scalar(asinh(x.Value()));
}

// ad_adapt overload: forwards to the ad_aug overload and wraps the result.
ad_adapt asinh(const ad_adapt &x) {
  return ad_adapt(asinh(ad_aug(x)));
}

// Source-writer overload: builds the expression text "acosh(<x>)".
Writer acosh(const Writer &x) { return "acosh(" + x + ")"; }

// Numeric reverse sweep: dx += dy * (1 / sqrt(x^2 - 1)). Skipped when dy
// compares equal to zero.
void AcoshOp::reverse(ReverseArgs<Scalar> &args) {
  if (args.dy(0) == Scalar(0)) return;
  args.dx(0) +=
      args.dy(0) * (Scalar(1.) / sqrt(args.x(0) * args.x(0) - Scalar(1.)));
}

// Name reported for AcoshOp.
const char *AcoshOp::op_name() { return "AcoshOp"; }

// Adds an AcoshOp applied to x through add_to_stack on the global returned by
// get_glob().
ad_plain acosh(const ad_plain &x) {
  global *tape = get_glob();
  return tape->add_to_stack<AcoshOp>(x);
}

// A constant argument is evaluated directly and stays a constant; any other
// argument is converted to ad_plain and handled by that overload.
ad_aug acosh(const ad_aug &x) {
  if (!x.constant()) return acosh(ad_plain(x));
  return Scalar(acosh(x.Value()));
}

// ad_adapt overload: forwards to the ad_aug overload and wraps the result.
ad_adapt acosh(const ad_adapt &x) {
  return ad_adapt(acosh(ad_aug(x)));
}

// Source-writer overload: builds the expression text "atanh(<x>)".
Writer atanh(const Writer &x) { return "atanh(" + x + ")"; }

// Numeric reverse sweep: dx += dy * (1 / (1 - x^2)). Skipped when dy compares
// equal to zero.
void AtanhOp::reverse(ReverseArgs<Scalar> &args) {
  if (args.dy(0) == Scalar(0)) return;
  args.dx(0) +=
      args.dy(0) * (Scalar(1.) / (Scalar(1) - args.x(0) * args.x(0)));
}

// Name reported for AtanhOp.
const char *AtanhOp::op_name() { return "AtanhOp"; }

// Adds an AtanhOp applied to x through add_to_stack on the global returned by
// get_glob().
ad_plain atanh(const ad_plain &x) {
  global *tape = get_glob();
  return tape->add_to_stack<AtanhOp>(x);
}

// A constant argument is evaluated directly and stays a constant; any other
// argument is converted to ad_plain and handled by that overload.
ad_aug atanh(const ad_aug &x) {
  if (!x.constant()) return atanh(ad_plain(x));
  return Scalar(atanh(x.Value()));
}

// ad_adapt overload: forwards to the ad_aug overload and wraps the result.
ad_adapt atanh(const ad_adapt &x) {
  return ad_adapt(atanh(ad_aug(x)));
}

// Source-writer overload: builds the expression text "atan2(<x1>,<x2>)".
Writer atan2(const Writer &x1, const Writer &x2) {
  return "atan2(" + x1 + "," + x2 + ")";
}

// Name reported for Atan2.
const char *Atan2::op_name() { return "Atan2"; }

// Adds an Atan2 operator applied to (x1, x2) through add_to_stack on the
// global returned by get_glob().
ad_plain atan2(const ad_plain &x1, const ad_plain &x2) {
  global *tape = get_glob();
  return tape->add_to_stack<Atan2>(x1, x2);
}

// When both arguments are constants the result is evaluated directly and
// stays a constant; otherwise both are converted to ad_plain.
ad_aug atan2(const ad_aug &x1, const ad_aug &x2) {
  const bool both_constant = x1.constant() && x2.constant();
  if (!both_constant) return atan2(ad_plain(x1), ad_plain(x2));
  return Scalar(atan2(x1.Value(), x2.Value()));
}

// ad_adapt overload: forwards to the ad_aug overload and wraps the result.
ad_adapt atan2(const ad_adapt &x1, const ad_adapt &x2) {
  return ad_adapt(atan2(ad_aug(x1), ad_aug(x2)));
}

// Source-writer overload: builds the expression text "max(<x1>,<x2>)".
Writer max(const Writer &x1, const Writer &x2) {
  return "max(" + x1 + "," + x2 + ")";
}

// Name reported for MaxOp.
const char *MaxOp::op_name() { return "MaxOp"; }

// Adds a MaxOp applied to (x1, x2) through add_to_stack on the global returned
// by get_glob().
ad_plain max(const ad_plain &x1, const ad_plain &x2) {
  global *tape = get_glob();
  return tape->add_to_stack<MaxOp>(x1, x2);
}

// When both arguments are constants the result is evaluated directly and
// stays a constant; otherwise both are converted to ad_plain.
ad_aug max(const ad_aug &x1, const ad_aug &x2) {
  const bool both_constant = x1.constant() && x2.constant();
  if (!both_constant) return max(ad_plain(x1), ad_plain(x2));
  return Scalar(max(x1.Value(), x2.Value()));
}

// ad_adapt overload: forwards to the ad_aug overload and wraps the result.
ad_adapt max(const ad_adapt &x1, const ad_adapt &x2) {
  return ad_adapt(max(ad_aug(x1), ad_aug(x2)));
}

// Source-writer overload: builds the expression text "min(<x1>,<x2>)".
Writer min(const Writer &x1, const Writer &x2) {
  return "min(" + x1 + "," + x2 + ")";
}

// Name reported for MinOp.
const char *MinOp::op_name() { return "MinOp"; }

// Adds a MinOp applied to (x1, x2) through add_to_stack on the global returned
// by get_glob().
ad_plain min(const ad_plain &x1, const ad_plain &x2) {
  global *tape = get_glob();
  return tape->add_to_stack<MinOp>(x1, x2);
}

// When both arguments are constants the result is evaluated directly and
// stays a constant; otherwise both are converted to ad_plain.
ad_aug min(const ad_aug &x1, const ad_aug &x2) {
  const bool both_constant = x1.constant() && x2.constant();
  if (!both_constant) return min(ad_plain(x1), ad_plain(x2));
  return Scalar(min(x1.Value(), x2.Value()));
}

// ad_adapt overload: forwards to the ad_aug overload and wraps the result.
ad_adapt min(const ad_adapt &x1, const ad_adapt &x2) {
  return ad_adapt(min(ad_aug(x1), ad_aug(x2)));
}

// Returns a new ad_aug built from the numeric value of x only, i.e. without
// the taped value of x.
ad_aug asConstant(ad_aug x) { return x.Value(); }

// Source-writer overload: builds the expression text "pow(<x1>,<x2>)".
Writer pow(const Writer &x1, const Writer &x2) {
  return "pow(" + x1 + "," + x2 + ")";
}

// Power function with constant folding and shortcuts:
//  - both constant: evaluated directly, result stays a constant;
//  - constant exponent: x^0 returns 1, x^1 returns x1, else PowOp<1, 0>;
//  - constant base: 1^y returns 1, else PowOp<0, 1>;
//  - neither constant: PowOp<1, 1>.
ad_aug pow(const ad_aug &x1, const ad_aug &x2) {
  const bool base_constant = x1.constant();
  const bool exponent_constant = x2.constant();

  if (base_constant && exponent_constant) {
    return Scalar(pow(x1.Value(), x2.Value()));
  }
  if (exponent_constant) {
    if (x2.Value() == 0.) return 1.;
    if (x2.Value() == 1.) return x1;
    return PowOp<1, 0>()(ad_plain(x1), ad_plain(x2));
  }
  if (base_constant) {
    if (x1.Value() == 1.) return 1.;
    return PowOp<0, 1>()(ad_plain(x1), ad_plain(x2));
  }
  return PowOp<1, 1>()(ad_plain(x1), ad_plain(x2));
}

// ad_adapt overload: converts both arguments to ad_aug, calls F on them and
// wraps the result.
// NOTE(review): the name `F` may be an unexpanded placeholder from the code
// generator (the neighbouring families name their function explicitly and
// this follows pow) - confirm against the generator input.
ad_adapt F(const ad_adapt &x1, const ad_adapt &x2) {
  return ad_adapt(F(ad_aug(x1), ad_aug(x2)));
}
// Numeric forward sweep: y = (x0 == x1) ? x2 : x3.
void CondExpEqOp::forward(ForwardArgs<Scalar> &args) {
  args.y(0) = (args.x(0) == args.x(1)) ? args.x(2) : args.x(3);
}

// Numeric reverse sweep: dy is added to the derivative of the selected branch
// input only (input 2 when x0 == x1, input 3 otherwise).
void CondExpEqOp::reverse(ReverseArgs<Scalar> &args) {
  const int selected = (args.x(0) == args.x(1)) ? 2 : 3;
  args.dx(selected) += args.dy(0);
}

// Replay forward sweep: re-expresses the operation through CondExpEq.
void CondExpEqOp::forward(ForwardArgs<Replay> &args) {
  args.y(0) = CondExpEq(args.x(0), args.x(1), args.x(2), args.x(3));
}

// Replay reverse sweep: each branch input receives CondExpEq applied to dy
// and a zero, so only the selected branch gets dy.
void CondExpEqOp::reverse(ReverseArgs<Replay> &args) {
  Replay zero_adj(0);
  args.dx(2) += CondExpEq(args.x(0), args.x(1), args.dy(0), zero_adj);
  args.dx(3) += CondExpEq(args.x(0), args.x(1), zero_adj, args.dy(0));
}

// Source-writer forward sweep.
// NOTE(review): the emitted text depends on Writer's stream/assignment side
// effects, which are not visible here; statement order is kept as is.
void CondExpEqOp::forward(ForwardArgs<Writer> &args) {
  Writer code;
  code << "if (" << args.x(0) << "==" << args.x(1) << ") ";
  args.y(0) = args.x(2);
  code << " else ";
  args.y(0) = args.x(3);
}

// Source-writer reverse sweep; same caveat as the Writer forward sweep.
void CondExpEqOp::reverse(ReverseArgs<Writer> &args) {
  Writer code;
  code << "if (" << args.x(0) << "==" << args.x(1) << ") ";
  args.dx(2) += args.dy(0);
  code << " else ";
  args.dx(3) += args.dy(0);
}

// Name reported for CondExpEqOp.
const char *CondExpEqOp::op_name() { return "CExpEq"; }
// Plain numeric selection: x2 when x0 == x1, otherwise x3.
Scalar CondExpEq(const Scalar &x0, const Scalar &x1, const Scalar &x2,
                 const Scalar &x3) {
  return (x0 == x1) ? x2 : x3;
}

// Taped selection: hands the four inputs to a CondExpEqOp via add_to_stack on
// the global returned by get_glob() and returns the first output.
ad_plain CondExpEq(const ad_plain &x0, const ad_plain &x1, const ad_plain &x2,
                   const ad_plain &x3) {
  OperatorPure *op = get_glob()->getOperator<CondExpEqOp>();
  std::vector<ad_plain> inputs(4);
  inputs[0] = x0;
  inputs[1] = x1;
  inputs[2] = x2;
  inputs[3] = x3;
  return get_glob()->add_to_stack<CondExpEqOp>(op, inputs)[0];
}

// When both compared values are constants the branch is chosen immediately
// and x2 or x3 is returned as is; otherwise all four arguments are converted
// to ad_plain and the taped overload is used.
ad_aug CondExpEq(const ad_aug &x0, const ad_aug &x1, const ad_aug &x2,
                 const ad_aug &x3) {
  const bool decidable = x0.constant() && x1.constant();
  if (!decidable) {
    return CondExpEq(ad_plain(x0), ad_plain(x1), ad_plain(x2), ad_plain(x3));
  }
  return (x0.Value() == x1.Value()) ? x2 : x3;
}
// Numeric forward sweep: y = (x0 != x1) ? x2 : x3.
void CondExpNeOp::forward(ForwardArgs<Scalar> &args) {
  args.y(0) = (args.x(0) != args.x(1)) ? args.x(2) : args.x(3);
}

// Numeric reverse sweep: dy is added to the derivative of the selected branch
// input only (input 2 when x0 != x1, input 3 otherwise).
void CondExpNeOp::reverse(ReverseArgs<Scalar> &args) {
  const int selected = (args.x(0) != args.x(1)) ? 2 : 3;
  args.dx(selected) += args.dy(0);
}

// Replay forward sweep: re-expresses the operation through CondExpNe.
void CondExpNeOp::forward(ForwardArgs<Replay> &args) {
  args.y(0) = CondExpNe(args.x(0), args.x(1), args.x(2), args.x(3));
}

// Replay reverse sweep: each branch input receives CondExpNe applied to dy
// and a zero, so only the selected branch gets dy.
void CondExpNeOp::reverse(ReverseArgs<Replay> &args) {
  Replay zero_adj(0);
  args.dx(2) += CondExpNe(args.x(0), args.x(1), args.dy(0), zero_adj);
  args.dx(3) += CondExpNe(args.x(0), args.x(1), zero_adj, args.dy(0));
}

// Source-writer forward sweep.
// NOTE(review): the emitted text depends on Writer's stream/assignment side
// effects, which are not visible here; statement order is kept as is.
void CondExpNeOp::forward(ForwardArgs<Writer> &args) {
  Writer code;
  code << "if (" << args.x(0) << "!=" << args.x(1) << ") ";
  args.y(0) = args.x(2);
  code << " else ";
  args.y(0) = args.x(3);
}

// Source-writer reverse sweep; same caveat as the Writer forward sweep.
void CondExpNeOp::reverse(ReverseArgs<Writer> &args) {
  Writer code;
  code << "if (" << args.x(0) << "!=" << args.x(1) << ") ";
  args.dx(2) += args.dy(0);
  code << " else ";
  args.dx(3) += args.dy(0);
}

// Name reported for CondExpNeOp.
const char *CondExpNeOp::op_name() { return "CExpNe"; }
// Plain numeric selection: x2 when x0 != x1, otherwise x3.
Scalar CondExpNe(const Scalar &x0, const Scalar &x1, const Scalar &x2,
                 const Scalar &x3) {
  return (x0 != x1) ? x2 : x3;
}

// Taped selection: hands the four inputs to a CondExpNeOp via add_to_stack on
// the global returned by get_glob() and returns the first output.
ad_plain CondExpNe(const ad_plain &x0, const ad_plain &x1, const ad_plain &x2,
                   const ad_plain &x3) {
  OperatorPure *op = get_glob()->getOperator<CondExpNeOp>();
  std::vector<ad_plain> inputs(4);
  inputs[0] = x0;
  inputs[1] = x1;
  inputs[2] = x2;
  inputs[3] = x3;
  return get_glob()->add_to_stack<CondExpNeOp>(op, inputs)[0];
}

// When both compared values are constants the branch is chosen immediately
// and x2 or x3 is returned as is; otherwise all four arguments are converted
// to ad_plain and the taped overload is used.
ad_aug CondExpNe(const ad_aug &x0, const ad_aug &x1, const ad_aug &x2,
                 const ad_aug &x3) {
  const bool decidable = x0.constant() && x1.constant();
  if (!decidable) {
    return CondExpNe(ad_plain(x0), ad_plain(x1), ad_plain(x2), ad_plain(x3));
  }
  return (x0.Value() != x1.Value()) ? x2 : x3;
}
// Numeric forward sweep: y = (x0 > x1) ? x2 : x3.
void CondExpGtOp::forward(ForwardArgs<Scalar> &args) {
  args.y(0) = (args.x(0) > args.x(1)) ? args.x(2) : args.x(3);
}

// Numeric reverse sweep: dy is added to the derivative of the selected branch
// input only (input 2 when x0 > x1, input 3 otherwise).
void CondExpGtOp::reverse(ReverseArgs<Scalar> &args) {
  const int selected = (args.x(0) > args.x(1)) ? 2 : 3;
  args.dx(selected) += args.dy(0);
}

// Replay forward sweep: re-expresses the operation through CondExpGt.
void CondExpGtOp::forward(ForwardArgs<Replay> &args) {
  args.y(0) = CondExpGt(args.x(0), args.x(1), args.x(2), args.x(3));
}

// Replay reverse sweep: each branch input receives CondExpGt applied to dy
// and a zero, so only the selected branch gets dy.
void CondExpGtOp::reverse(ReverseArgs<Replay> &args) {
  Replay zero_adj(0);
  args.dx(2) += CondExpGt(args.x(0), args.x(1), args.dy(0), zero_adj);
  args.dx(3) += CondExpGt(args.x(0), args.x(1), zero_adj, args.dy(0));
}

// Source-writer forward sweep.
// NOTE(review): the emitted text depends on Writer's stream/assignment side
// effects, which are not visible here; statement order is kept as is.
void CondExpGtOp::forward(ForwardArgs<Writer> &args) {
  Writer code;
  code << "if (" << args.x(0) << ">" << args.x(1) << ") ";
  args.y(0) = args.x(2);
  code << " else ";
  args.y(0) = args.x(3);
}

// Source-writer reverse sweep; same caveat as the Writer forward sweep.
void CondExpGtOp::reverse(ReverseArgs<Writer> &args) {
  Writer code;
  code << "if (" << args.x(0) << ">" << args.x(1) << ") ";
  args.dx(2) += args.dy(0);
  code << " else ";
  args.dx(3) += args.dy(0);
}

// Name reported for CondExpGtOp.
const char *CondExpGtOp::op_name() { return "CExpGt"; }
// Plain numeric selection: x2 when x0 > x1, otherwise x3.
Scalar CondExpGt(const Scalar &x0, const Scalar &x1, const Scalar &x2,
                 const Scalar &x3) {
  return (x0 > x1) ? x2 : x3;
}

// Taped selection: hands the four inputs to a CondExpGtOp via add_to_stack on
// the global returned by get_glob() and returns the first output.
ad_plain CondExpGt(const ad_plain &x0, const ad_plain &x1, const ad_plain &x2,
                   const ad_plain &x3) {
  OperatorPure *op = get_glob()->getOperator<CondExpGtOp>();
  std::vector<ad_plain> inputs(4);
  inputs[0] = x0;
  inputs[1] = x1;
  inputs[2] = x2;
  inputs[3] = x3;
  return get_glob()->add_to_stack<CondExpGtOp>(op, inputs)[0];
}

// When both compared values are constants the branch is chosen immediately
// and x2 or x3 is returned as is; otherwise all four arguments are converted
// to ad_plain and the taped overload is used.
ad_aug CondExpGt(const ad_aug &x0, const ad_aug &x1, const ad_aug &x2,
                 const ad_aug &x3) {
  const bool decidable = x0.constant() && x1.constant();
  if (!decidable) {
    return CondExpGt(ad_plain(x0), ad_plain(x1), ad_plain(x2), ad_plain(x3));
  }
  return (x0.Value() > x1.Value()) ? x2 : x3;
}
// Numeric forward sweep: y = (x0 < x1) ? x2 : x3.
void CondExpLtOp::forward(ForwardArgs<Scalar> &args) {
  args.y(0) = (args.x(0) < args.x(1)) ? args.x(2) : args.x(3);
}

// Numeric reverse sweep: dy is added to the derivative of the selected branch
// input only (input 2 when x0 < x1, input 3 otherwise).
void CondExpLtOp::reverse(ReverseArgs<Scalar> &args) {
  const int selected = (args.x(0) < args.x(1)) ? 2 : 3;
  args.dx(selected) += args.dy(0);
}

// Replay forward sweep: re-expresses the operation through CondExpLt.
void CondExpLtOp::forward(ForwardArgs<Replay> &args) {
  args.y(0) = CondExpLt(args.x(0), args.x(1), args.x(2), args.x(3));
}

// Replay reverse sweep: each branch input receives CondExpLt applied to dy
// and a zero, so only the selected branch gets dy.
void CondExpLtOp::reverse(ReverseArgs<Replay> &args) {
  Replay zero_adj(0);
  args.dx(2) += CondExpLt(args.x(0), args.x(1), args.dy(0), zero_adj);
  args.dx(3) += CondExpLt(args.x(0), args.x(1), zero_adj, args.dy(0));
}

// Source-writer forward sweep.
// NOTE(review): the emitted text depends on Writer's stream/assignment side
// effects, which are not visible here; statement order is kept as is.
void CondExpLtOp::forward(ForwardArgs<Writer> &args) {
  Writer code;
  code << "if (" << args.x(0) << "<" << args.x(1) << ") ";
  args.y(0) = args.x(2);
  code << " else ";
  args.y(0) = args.x(3);
}

// Source-writer reverse sweep; same caveat as the Writer forward sweep.
void CondExpLtOp::reverse(ReverseArgs<Writer> &args) {
  Writer code;
  code << "if (" << args.x(0) << "<" << args.x(1) << ") ";
  args.dx(2) += args.dy(0);
  code << " else ";
  args.dx(3) += args.dy(0);
}

// Name reported for CondExpLtOp.
const char *CondExpLtOp::op_name() { return "CExpLt"; }
// Plain numeric selection: x2 when x0 < x1, otherwise x3.
Scalar CondExpLt(const Scalar &x0, const Scalar &x1, const Scalar &x2,
                 const Scalar &x3) {
  return (x0 < x1) ? x2 : x3;
}

// Taped selection: hands the four inputs to a CondExpLtOp via add_to_stack on
// the global returned by get_glob() and returns the first output.
ad_plain CondExpLt(const ad_plain &x0, const ad_plain &x1, const ad_plain &x2,
                   const ad_plain &x3) {
  OperatorPure *op = get_glob()->getOperator<CondExpLtOp>();
  std::vector<ad_plain> inputs(4);
  inputs[0] = x0;
  inputs[1] = x1;
  inputs[2] = x2;
  inputs[3] = x3;
  return get_glob()->add_to_stack<CondExpLtOp>(op, inputs)[0];
}

// When both compared values are constants the branch is chosen immediately
// and x2 or x3 is returned as is; otherwise all four arguments are converted
// to ad_plain and the taped overload is used.
ad_aug CondExpLt(const ad_aug &x0, const ad_aug &x1, const ad_aug &x2,
                 const ad_aug &x3) {
  const bool decidable = x0.constant() && x1.constant();
  if (!decidable) {
    return CondExpLt(ad_plain(x0), ad_plain(x1), ad_plain(x2), ad_plain(x3));
  }
  return (x0.Value() < x1.Value()) ? x2 : x3;
}
// Numeric forward sweep: y = (x0 >= x1) ? x2 : x3.
void CondExpGeOp::forward(ForwardArgs<Scalar> &args) {
  args.y(0) = (args.x(0) >= args.x(1)) ? args.x(2) : args.x(3);
}

// Numeric reverse sweep: dy is added to the derivative of the selected branch
// input only (input 2 when x0 >= x1, input 3 otherwise).
void CondExpGeOp::reverse(ReverseArgs<Scalar> &args) {
  const int selected = (args.x(0) >= args.x(1)) ? 2 : 3;
  args.dx(selected) += args.dy(0);
}

// Replay forward sweep: re-expresses the operation through CondExpGe.
void CondExpGeOp::forward(ForwardArgs<Replay> &args) {
  args.y(0) = CondExpGe(args.x(0), args.x(1), args.x(2), args.x(3));
}

// Replay reverse sweep: each branch input receives CondExpGe applied to dy
// and a zero, so only the selected branch gets dy.
void CondExpGeOp::reverse(ReverseArgs<Replay> &args) {
  Replay zero_adj(0);
  args.dx(2) += CondExpGe(args.x(0), args.x(1), args.dy(0), zero_adj);
  args.dx(3) += CondExpGe(args.x(0), args.x(1), zero_adj, args.dy(0));
}

// Source-writer forward sweep.
// NOTE(review): the emitted text depends on Writer's stream/assignment side
// effects, which are not visible here; statement order is kept as is.
void CondExpGeOp::forward(ForwardArgs<Writer> &args) {
  Writer code;
  code << "if (" << args.x(0) << ">=" << args.x(1) << ") ";
  args.y(0) = args.x(2);
  code << " else ";
  args.y(0) = args.x(3);
}

// Source-writer reverse sweep; same caveat as the Writer forward sweep.
void CondExpGeOp::reverse(ReverseArgs<Writer> &args) {
  Writer code;
  code << "if (" << args.x(0) << ">=" << args.x(1) << ") ";
  args.dx(2) += args.dy(0);
  code << " else ";
  args.dx(3) += args.dy(0);
}

// Name reported for CondExpGeOp.
const char *CondExpGeOp::op_name() { return "CExpGe"; }
// Plain numeric selection: x2 when x0 >= x1, otherwise x3.
Scalar CondExpGe(const Scalar &x0, const Scalar &x1, const Scalar &x2,
                 const Scalar &x3) {
  return (x0 >= x1) ? x2 : x3;
}

// Taped selection: hands the four inputs to a CondExpGeOp via add_to_stack on
// the global returned by get_glob() and returns the first output.
ad_plain CondExpGe(const ad_plain &x0, const ad_plain &x1, const ad_plain &x2,
                   const ad_plain &x3) {
  OperatorPure *op = get_glob()->getOperator<CondExpGeOp>();
  std::vector<ad_plain> inputs(4);
  inputs[0] = x0;
  inputs[1] = x1;
  inputs[2] = x2;
  inputs[3] = x3;
  return get_glob()->add_to_stack<CondExpGeOp>(op, inputs)[0];
}

// When both compared values are constants the branch is chosen immediately
// and x2 or x3 is returned as is; otherwise all four arguments are converted
// to ad_plain and the taped overload is used.
ad_aug CondExpGe(const ad_aug &x0, const ad_aug &x1, const ad_aug &x2,
                 const ad_aug &x3) {
  const bool decidable = x0.constant() && x1.constant();
  if (!decidable) {
    return CondExpGe(ad_plain(x0), ad_plain(x1), ad_plain(x2), ad_plain(x3));
  }
  return (x0.Value() >= x1.Value()) ? x2 : x3;
}
// Scalar forward sweep: y(0) receives x(2) when x(0) <= x(1), else x(3).
void CondExpLeOp::forward(ForwardArgs<Scalar> &args) {
  const bool condition_holds = (args.x(0) <= args.x(1));
  args.y(0) = condition_holds ? args.x(2) : args.x(3);
}
// Scalar reverse sweep: the output derivative is accumulated onto whichever
// of inputs 2 / 3 was selected by the comparison x(0) <= x(1).
void CondExpLeOp::reverse(ReverseArgs<Scalar> &args) {
  const int selected = (args.x(0) <= args.x(1)) ? 2 : 3;
  args.dx(selected) += args.dy(0);
}
// Replay forward sweep: re-issues the conditional through CondExpLe on the
// replayed inputs.
void CondExpLeOp::forward(ForwardArgs<Replay> &args) {
  Replay replayed = CondExpLe(args.x(0), args.x(1), args.x(2), args.x(3));
  args.y(0) = replayed;
}
// Replay reverse sweep: each branch input gets dy(0) when it was the selected
// branch and a zero contribution otherwise, expressed again via CondExpLe.
void CondExpLeOp::reverse(ReverseArgs<Replay> &args) {
  Replay no_contribution(0);
  args.dx(2) += CondExpLe(args.x(0), args.x(1), args.dy(0), no_contribution);
  args.dx(3) += CondExpLe(args.x(0), args.x(1), no_contribution, args.dy(0));
}
// Writer forward sweep: streams the text `if (x0<=x1) `, then performs the
// y(0) = x(2) assignment, streams ` else `, then performs y(0) = x(3).
// NOTE(review): presumably Writer assignments emit source text rather than
// compute values, so statement order here defines the generated code - confirm.
void CondExpLeOp::forward(ForwardArgs<Writer> &args) {
  Writer w;
  w << "if (" << args.x(0) << "<=" << args.x(1) << ") ";
  args.y(0) = args.x(2);
  w << " else ";
  args.y(0) = args.x(3);
}
// Writer reverse sweep: streams the text `if (x0<=x1) `, then the dx(2)
// accumulation, streams ` else `, then the dx(3) accumulation.
// NOTE(review): presumably the Writer statements emit source text, so the
// statement order defines the generated code - confirm.
void CondExpLeOp::reverse(ReverseArgs<Writer> &args) {
  Writer w;
  w << "if (" << args.x(0) << "<=" << args.x(1) << ") ";
  args.dx(2) += args.dy(0);
  w << " else ";
  args.dx(3) += args.dy(0);
}
// Short display name of the operator.
const char *CondExpLeOp::op_name() { return "CExpLe"; }
// Plain scalar conditional: returns x2 when x0 <= x1, otherwise x3.
Scalar CondExpLe(const Scalar &x0, const Scalar &x1, const Scalar &x2,
                 const Scalar &x3) {
  return (x0 <= x1) ? x2 : x3;
}
// Taped conditional: pushes a CondExpLeOp with the four operands onto the
// current tape and returns its single output.
ad_plain CondExpLe(const ad_plain &x0, const ad_plain &x1, const ad_plain &x2,
                   const ad_plain &x3) {
  OperatorPure *op = get_glob()->getOperator<CondExpLeOp>();
  std::vector<ad_plain> operands(4);
  operands[0] = x0;
  operands[1] = x1;
  operands[2] = x2;
  operands[3] = x3;
  return get_glob()->add_to_stack<CondExpLeOp>(op, operands)[0];
}
// Conditional on ad_aug operands. When both compared operands are constants
// the branch is resolved immediately; otherwise the taped ad_plain version is
// used.
ad_aug CondExpLe(const ad_aug &x0, const ad_aug &x1, const ad_aug &x2,
                 const ad_aug &x3) {
  if (!(x0.constant() && x1.constant())) {
    return CondExpLe(ad_plain(x0), ad_plain(x1), ad_plain(x2), ad_plain(x3));
  }
  if (x0.Value() <= x1.Value()) {
    return x2;
  }
  return x3;
}

// Number of inputs: the term count given at construction.
Index SumOp::input_size() const {
  return n;
}

// The operator produces a single output.
Index SumOp::output_size() const {
  return 1;
}

// Stores the number of terms `n`.
SumOp::SumOp(size_t n)
    : n(n) {
}

// Short display name of the operator.
const char *SumOp::op_name() {
  return "SumOp";
}

// Number of inputs: the term count given at construction.
Index LogSpaceSumOp::input_size() const {
  return n;
}

// The operator produces a single output.
Index LogSpaceSumOp::output_size() const {
  return 1;
}

// Stores the number of terms `n`.
LogSpaceSumOp::LogSpaceSumOp(size_t n)
    : n(n) {
}

// Scalar forward sweep: y(0) = log(sum_i exp(x(i))), evaluated by first
// subtracting the largest input from every term and adding it back after the
// logarithm.
void LogSpaceSumOp::forward(ForwardArgs<Scalar> &args) {
  Scalar largest = -INFINITY;
  for (size_t k = 0; k < n; k++) {
    const Scalar xk = args.x(k);
    if (largest < xk) largest = xk;
  }
  Scalar &out = args.y(0);
  out = 0;
  for (size_t k = 0; k < n; k++) {
    out += exp(args.x(k) - largest);
  }
  out = largest + log(out);
}

// Replay forward sweep: collects the replayed inputs and re-issues the
// operation through logspace_sum.
void LogSpaceSumOp::forward(ForwardArgs<Replay> &args) {
  const Index num_inputs = input_size();
  std::vector<ad_plain> terms(num_inputs);
  for (Index k = 0; k < num_inputs; k++) {
    terms[k] = args.x(k);
  }
  args.y(0) = logspace_sum(terms);
}

// Short display name of the operator.
const char *LogSpaceSumOp::op_name() {
  return "LSSumOp";
}

// Pushes a LogSpaceSumOp over all entries of `x` onto the current tape and
// returns its single output.
ad_plain logspace_sum(const std::vector<ad_plain> &x) {
  OperatorPure *op = get_glob()->getOperator<LogSpaceSumOp>(x.size());
  std::vector<ad_plain> outputs = get_glob()->add_to_stack<LogSpaceSumOp>(op, x);
  return outputs[0];
}

// One term per stride entry.
Index LogSpaceSumStrideOp::number_of_terms() const {
  return stride.size();
}

// One input (a starting variable) per term.
Index LogSpaceSumStrideOp::input_size() const {
  return number_of_terms();
}

// The operator produces a single output.
Index LogSpaceSumStrideOp::output_size() const {
  return 1;
}

// Stores the per-term strides and the number of rows `n` to sum over.
LogSpaceSumStrideOp::LogSpaceSumStrideOp(std::vector<Index> stride, size_t n)
    : stride(stride),
      n(n) {
}

// Scalar forward sweep: y(0) = log(sum_{i<n} exp(rowsum(i))), where the row
// sums are obtained from `rowsum` applied to the input pointers. The largest
// row sum is subtracted before exponentiating and added back afterwards.
void LogSpaceSumStrideOp::forward(ForwardArgs<Scalar> &args) {
  // Gather a raw pointer to the first value of every term.
  const size_t num_terms = stride.size();
  std::vector<Scalar *> pointers(num_terms);
  Scalar **base = &(pointers[0]);
  for (size_t k = 0; k < num_terms; k++) {
    base[k] = args.x_ptr(k);
  }

  // First pass: largest row sum.
  Scalar largest = -INFINITY;
  for (size_t row = 0; row < n; row++) {
    const Scalar s = rowsum(base, row);
    if (largest < s) largest = s;
  }

  // Second pass: shifted exponentials accumulated directly in the output.
  Scalar &out = args.y(0);
  out = 0;
  for (size_t row = 0; row < n; row++) {
    const Scalar s = rowsum(base, row);
    out += exp(s - largest);
  }
  out = largest + log(out);
}

// Replay forward sweep: collects the replayed inputs and re-issues the
// operation through logspace_sum_stride with the stored strides and row count.
void LogSpaceSumStrideOp::forward(ForwardArgs<Replay> &args) {
  const Index num_inputs = input_size();
  std::vector<ad_plain> terms(num_inputs);
  for (Index k = 0; k < num_inputs; k++) {
    terms[k] = args.x(k);
  }
  args.y(0) = logspace_sum_stride(terms, stride, n);
}

// Registers, for every term, a segment of length n * stride[term] starting at
// that term's input variable.
void LogSpaceSumStrideOp::dependencies(Args<> &args, Dependencies &dep) const {
  const size_t num_terms = (size_t)number_of_terms();
  for (size_t term = 0; term < num_terms; term++) {
    size_t segment_length = n * stride[term];
    dep.add_segment(args.input(term), segment_length);
  }
}

// Short display name of the operator.
const char *LogSpaceSumStrideOp::op_name() {
  return "LSStride";
}

// Writer forward sweep has no implementation for this operator: reaching it
// trips an assertion.
void LogSpaceSumStrideOp::forward(ForwardArgs<Writer> &args) {
  TMBAD_ASSERT(false);
}

// Writer reverse sweep has no implementation for this operator: reaching it
// trips an assertion.
void LogSpaceSumStrideOp::reverse(ReverseArgs<Writer> &args) {
  TMBAD_ASSERT(false);
}

// Pushes a LogSpaceSumStrideOp onto the current tape and returns its single
// output. `x` and `stride` must have the same length (one entry per term).
ad_plain logspace_sum_stride(const std::vector<ad_plain> &x,
                             const std::vector<Index> &stride, size_t n) {
  TMBAD_ASSERT(x.size() == stride.size());
  OperatorPure *op = get_glob()->getOperator<LogSpaceSumStrideOp>(stride, n);
  std::vector<ad_plain> outputs =
      get_glob()->add_to_stack<LogSpaceSumStrideOp>(op, x);
  return outputs[0];
}
}  // namespace TMBad
// Autogenerated - do not edit by hand !
#include "graph2dot.hpp"
namespace TMBad {

// Writes the graph `G` of tape `glob` to `cout` in Graphviz dot syntax.
// Every operator becomes a node labelled with its op_name (plus its index
// when `show_id` is set), nodes listed in glob.subgraph_seq are drawn filled,
// and the operators owning the independent / dependent variables are each
// grouped with rank=same.
void graph2dot(global glob, graph G, bool show_id, std::ostream &cout) {
  cout << "digraph graphname {\n";

  // Node declarations.
  for (size_t op = 0; op < glob.opstack.size(); op++) {
    cout << op << " [label=\"" << glob.opstack[op]->op_name();
    if (show_id) cout << " " << op;
    cout << "\"];\n";
  }

  // Edges.
  for (size_t from = 0; from < G.num_nodes(); from++) {
    const size_t degree = G.num_neighbors(from);
    for (size_t k = 0; k < degree; k++) {
      cout << from << " -> " << G.neighbors(from)[k] << ";\n";
    }
  }

  // Highlight the current subgraph.
  for (size_t s = 0; s < glob.subgraph_seq.size(); s++) {
    size_t highlighted = glob.subgraph_seq[s];
    cout << highlighted << " [style=\"filled\"];\n";
  }

  std::vector<Index> var_to_op = glob.var2op();

  // Independent variables on one rank.
  cout << "{rank=same;";
  for (size_t k = 0; k < glob.inv_index.size(); k++) {
    cout << var_to_op[glob.inv_index[k]] << ";";
  }
  cout << "}\n";

  // Dependent variables on one rank.
  cout << "{rank=same;";
  for (size_t k = 0; k < glob.dep_index.size(); k++) {
    cout << var_to_op[glob.dep_index[k]] << ";";
  }
  cout << "}\n";

  cout << "}\n";
}

// Convenience overload: draws the forward graph of `glob`.
void graph2dot(global glob, bool show_id, std::ostream &cout) {
  graph forward = glob.forward_graph();
  graph2dot(glob, forward, show_id, cout);
}

// Writes the dot representation of graph `G` of `glob` to the file `filename`.
void graph2dot(const char *filename, global glob, graph G, bool show_id) {
  std::ofstream out;
  out.open(filename);
  graph2dot(glob, G, show_id, out);
  out.close();
}

// Writes the dot representation of the forward graph of `glob` to the file
// `filename`.
void graph2dot(const char *filename, global glob, bool show_id) {
  std::ofstream out;
  out.open(filename);
  graph2dot(glob, show_id, out);
  out.close();
}
}  // namespace TMBad
// Autogenerated - do not edit by hand !
#include "graph_transform.hpp"
namespace TMBad {

// Non-template entry point forwarding to which<size_t>.
std::vector<size_t> which(const std::vector<bool> &x) {
  std::vector<size_t> positions = which<size_t>(x);
  return positions;
}

// Product of all entries of `x`; an empty vector yields 1.
size_t prod_int(const std::vector<size_t> &x) {
  size_t product = 1;
  for (std::vector<size_t>::const_iterator it = x.begin(); it != x.end();
       ++it) {
    product *= *it;
  }
  return product;
}

// Runs glob.reverse_sub on a copy of `vars`, restricted to the operators
// returned by glob.var2op(vars), and returns the variables whose mark differs
// from the original `vars`.
std::vector<bool> reverse_boundary(global &glob,
                                   const std::vector<bool> &vars) {
  std::vector<bool> marks(vars);
  std::vector<bool> op_filter = glob.var2op(vars);
  glob.reverse_sub(marks, op_filter);

  // Element-wise exclusive-or with the input marks.
  for (size_t v = 0; v < vars.size(); v++) {
    marks[v] = (marks[v] != vars[v]);
  }
  return marks;
}

// Returns the indices of operators in the "accumulation tree": variables
// that are NOT reached by glob.reverse() starting from the outputs of all
// operators that are not flagged is_linear. When `boundary` is set, the
// variable set is first replaced by its reverse_boundary.
std::vector<Index> get_accumulation_tree(global &glob, bool boundary) {
  std::vector<OperatorPure *> &ops = glob.opstack;

  // Mark every operator that is not flagged linear.
  std::vector<bool> op_marks(ops.size(), false);
  for (size_t k = 0; k < ops.size(); k++) {
    op_marks[k] = !ops[k]->info().test(op_info::is_linear);
  }

  // Propagate those marks backwards, then keep the complement.
  std::vector<bool> var_marks = glob.op2var(op_marks);
  glob.reverse(var_marks);
  var_marks.flip();

  if (boundary) {
    var_marks = reverse_boundary(glob, var_marks);
  }

  op_marks = glob.var2op(var_marks);
  return which<Index>(op_marks);
}

// Returns the positions in glob.opstack of all operators whose op_name()
// compares equal (strcmp) to `name`.
std::vector<Index> find_op_by_name(global &glob, const char *name) {
  std::vector<OperatorPure *> &ops = glob.opstack;
  std::vector<Index> matches;
  for (size_t k = 0; k < ops.size(); k++) {
    const bool same_name = (strcmp(ops[k]->op_name(), name) == 0);
    if (same_name) matches.push_back(k);
  }
  return matches;
}

// Replaces each operator listed in `seq` by a pair of NullOp2 operators and
// turns the outputs of the replaced operators into independent variables.
//
// For every selected operator with nin inputs and nou outputs, the slot just
// before it becomes NullOp2(nin, 0) and its own slot becomes NullOp2(0, nou);
// the original operator is deallocated.
// NOTE(review): presumably make_space_inplace inserts the extra slot before
// each selected operator and updates `seq2` accordingly - confirm.
//
// inv_tags: when false, the existing glob.inv_index is cleared first; when
//           true, none of the selected operators may be the InvOp (asserted).
// dep_tags: when false, glob.dep_index is cleared.
// Returns the variables appended to glob.inv_index.
std::vector<Index> substitute(global &glob, const std::vector<Index> &seq,
                              bool inv_tags, bool dep_tags) {
  std::vector<OperatorPure *> &opstack = glob.opstack;
  std::vector<Index> seq2(seq);
  make_space_inplace(opstack, seq2);
  OperatorPure *invop = glob.getOperator<global::InvOp>();
  for (size_t i = 0; i < seq2.size(); i++) {
    OperatorPure *op = opstack[seq2[i]];
    if (inv_tags) TMBAD_ASSERT(op != invop);
    size_t nin = op->input_size();
    size_t nou = op->output_size();
    // Input-consuming placeholder goes in the slot before; output-producing
    // placeholder takes the operator's own slot.
    opstack[seq2[i] - 1] = glob.getOperator<global::NullOp2>(nin, 0);
    opstack[seq2[i]] = glob.getOperator<global::NullOp2>(0, nou);
    op->deallocate();
  }
  // Flag the operator stack as containing dynamic operators.
  glob.opstack.any |= op_info(op_info::dynamic);
  std::vector<Index> new_inv = glob.op2var(seq2);
  if (!inv_tags) glob.inv_index.resize(0);
  if (!dep_tags) glob.dep_index.resize(0);
  glob.inv_index.insert(glob.inv_index.end(), new_inv.begin(), new_inv.end());
  return new_inv;
}

// Substitutes every operator whose op_name() equals `name`; see the
// index-based overload for the meaning of `inv_tags` and `dep_tags`.
std::vector<Index> substitute(global &glob, const char *name, bool inv_tags,
                              bool dep_tags) {
  std::vector<Index> matching_ops = find_op_by_name(glob, name);
  return substitute(glob, matching_ops, inv_tags, dep_tags);
}

// Splits off the accumulation tree of `glob` and replaces it by an explicit
// linear combination of the boundary variables.
//
// Steps:
//  1. A copy `glob_tree` has the boundary operators (from
//     get_accumulation_tree(glob, true)) substituted by independent variables.
//  2. One forward and one reverse sweep of `glob_tree` give the value V of
//     dependent 0 and the derivatives J w.r.t. the new independents; V is then
//     reduced by sum_i J[i] * x0[i], leaving the constant term.
//  3. In `glob`, the dependents are replaced by J[i] * var_i (with V added to
//     the first term). With `sum_` the terms are summed into one dependent,
//     otherwise each term is marked dependent individually.
// Returns the modified copy of `glob`.
global accumulation_tree_split(global glob, bool sum_) {
  global glob_tree = glob;

  std::vector<Index> boundary = get_accumulation_tree(glob, true);

  substitute(glob_tree, boundary, false, true);
  glob_tree.eliminate();

  size_t n = glob_tree.inv_index.size();

  // Evaluate value and gradient of dependent 0 at the current input values.
  std::vector<Scalar> x0(n);
  for (size_t i = 0; i < n; i++) x0[i] = glob_tree.value_inv(i);
  glob_tree.forward();
  glob_tree.clear_deriv();
  glob_tree.deriv_dep(0) = 1;
  glob_tree.reverse();
  Scalar V = glob_tree.value_dep(0);
  std::vector<Scalar> J(n);
  for (size_t i = 0; i < n; i++) J[i] = glob_tree.deriv_inv(i);

  // Remove the linear part so that V is the constant offset.
  for (size_t i = 0; i < n; i++) V -= J[i] * x0[i];

  // Re-express the result on the original tape in terms of boundary variables.
  std::vector<Index> vars = glob.op2var(boundary);
  glob.dep_index.resize(0);
  glob.ad_start();
  std::vector<ad_aug_index> res(vars.begin(), vars.end());
  for (size_t i = 0; i < vars.size(); i++) {
    res[i] = res[i] * J[i];
    if (i == 0) res[i] += V;
    if (!sum_) res[i].Dependent();
  }
  if (sum_) {
    ad_aug sum_res = sum(res);
    sum_res.Dependent();
  }
  glob.ad_stop();
  glob.eliminate();
  return glob;
}

// Replaces all dependent variables of `glob` by a single dependent equal to
// their sum (sign == 1) or the negated sum (sign == -1).
void aggregate(global &glob, int sign) {
  TMBAD_ASSERT((sign == 1) || (sign == -1));
  glob.ad_start();
  std::vector<ad_aug_index> terms(glob.dep_index.begin(), glob.dep_index.end());
  ad_aug total = 0;
  for (size_t k = 0; k < terms.size(); k++) {
    total += terms[k];
  }
  if (sign < 0) {
    total = -total;
  }
  glob.dep_index.resize(0);
  total.Dependent();
  glob.ad_stop();
}

// Snapshots the dependent-variable list and the operator-stack size of
// `glob` so that restore() can roll back to this point.
old_state::old_state(global &glob)
    : glob(glob) {
  opstack_size = glob.opstack.size();
  dep_index = glob.dep_index;
}

void old_state::restore() {
  glob.dep_index = dep_index;
  while (glob.opstack.size() > opstack_size) {
    Index input_size = glob.opstack.back()->input_size();
    Index output_size = glob.opstack.back()->output_size();
    glob.inputs.resize(glob.inputs.size() - input_size);
    glob.values.resize(glob.values.size() - output_size);
    glob.opstack.back()->deallocate();
    glob.opstack.pop_back();
  }
}

// Binds to `glob`; runs initialize() right away when `do_init` is set.
term_info::term_info(global &glob, bool do_init)
    : glob(glob) {
  if (do_init) {
    initialize();
  }
}

// Groups the dependent variables (terms) of `glob` into classes of identical
// sub-expressions.
//
// inv_remap: per-independent-variable labels passed (after factorisation) to
//            remap_identical_sub_expressions. An empty vector is replaced by
//            all zeros, one entry per independent variable.
// On return `id[k]` is the class label of term k and `count[c]` is
// incremented once for every term carrying label c.
void term_info::initialize(std::vector<Index> inv_remap) {
  if (inv_remap.size() == 0) inv_remap.resize(glob.inv_index.size(), 0);
  inv_remap = radix::factor<Index>(inv_remap);
  std::vector<Index> remap = remap_identical_sub_expressions(glob, inv_remap);
  std::vector<Index> term_ids = subset(remap, glob.dep_index);
  id = radix::factor<Index>(term_ids);
  // With no terms there is nothing to count; std::max_element would return
  // end() and dereferencing it is undefined behaviour.
  if (id.empty()) return;
  Index max_id = *std::max_element(id.begin(), id.end());
  count.resize(max_id + 1, 0);
  for (size_t i = 0; i < id.size(); i++) {
    count[id[i]]++;
  }
}

// Default configuration values.
gk_config::gk_config()
    : debug(false),
      adaptive(false),
      nan2zero(true),
      ytol(1e-2),
      dx(1) {
}

// Product of the bounds of all dimensions whose mask is set.
size_t multivariate_index::count() {
  size_t total = 1;
  for (size_t d = 0; d < bound.size(); d++) {
    if (!mask_[d]) continue;
    total *= bound[d];
  }
  return total;
}

// `dim` dimensions sharing the bound `bound_`; the index starts at all zeros
// and every mask entry is set to `flag`.
multivariate_index::multivariate_index(size_t bound_, size_t dim, bool flag)
    : pointer(0) {
  x.assign(dim, 0);
  bound.assign(dim, bound_);
  mask_.assign(dim, flag);
}

// One dimension per entry of `bound`; the index starts at all zeros and every
// mask entry is set to `flag`.
multivariate_index::multivariate_index(std::vector<size_t> bound, bool flag)
    : pointer(0), bound(bound) {
  const size_t num_dims = bound.size();
  x.resize(num_dims, 0);
  mask_.resize(num_dims, flag);
}

// Inverts every mask entry.
void multivariate_index::flip() {
  mask_.flip();
}

// Advances the index by one step over the masked dimensions, treating the
// first dimension as the fastest running one. `pointer` is kept equal to the
// flat position sum_i x[i] * prod_{j<i} bound[j]. When every masked dimension
// is at its last value the masked coordinates wrap back to zero.
multivariate_index &multivariate_index::operator++() {
  // N is the flat step size of dimension i (product of the preceding bounds,
  // masked or not).
  size_t N = 1;
  for (size_t i = 0; i < x.size(); i++) {
    if (mask_[i]) {
      if (x[i] < bound[i] - 1) {
        // Room left in this dimension: step it and stop.
        x[i]++;
        pointer += N;
        break;
      } else {
        // Dimension exhausted: reset it and carry into the next masked one.
        x[i] = 0;
        pointer -= (bound[i] - 1) * N;
      }
    }
    N *= bound[i];
  }
  return *this;
}

// Conversion to the current flat position.
multivariate_index::operator size_t() {
  return pointer;
}

// Current coordinate along dimension `i`.
size_t multivariate_index::index(size_t i) {
  return x[i];
}

// Copy of the full coordinate vector.
std::vector<size_t> multivariate_index::index() {
  return x;
}

// Writable reference to the mask bit of dimension `i`.
std::vector<bool>::reference multivariate_index::mask(size_t i) {
  std::vector<bool>::reference bit = mask_[i];
  return bit;
}

// Replaces the whole mask; the new mask must have the same length as the
// current one (asserted).
void multivariate_index::set_mask(const std::vector<bool> &mask) {
  TMBAD_ASSERT(mask.size() == mask_.size());
  mask_.assign(mask.begin(), mask.end());
}

// Number of indices in the clique.
size_t clique::clique_size() {
  return indices.size();
}

// Creates a clique with default-constructed members.
clique::clique() {
}

// Keeps only the entries of `indices` and `dim` selected by `mask`.
void clique::subset_inplace(const std::vector<bool> &mask) {
  indices = subset(indices, mask);
  dim = subset(dim, mask);
}

void clique::logsum_init() { logsum.resize(prod_int(dim)); }

// True when the clique holds no indices.
bool clique::empty() const {
  return indices.size() == 0;
}

// True when `i` is one of the clique's indices.
bool clique::contains(Index i) {
  for (size_t k = 0; k < indices.size(); k++) {
    if (indices[k] == i) return true;
  }
  return false;
}

// Describes how this clique's logsum table is traversed along index `ind`
// when the clique is embedded in the larger clique `super`.
//
// super:  clique whose `indices`/`dim` are used as the enclosing index space.
// ind:    the index to be summed over.
// offset: (out) one logsum entry per cell of `super` with `ind` excluded.
// stride: (out) product of `dim[k]` over the leading clique indices that are
//         smaller than `ind`.
void clique::get_stride(const clique &super, Index ind,
                        std::vector<ad_plain> &offset, Index &stride) {
  stride = 1;
  for (size_t k = 0; (k < clique_size()) && (indices[k] < ind); k++) {
    stride *= dim[k];
  }

  // Expand this clique's logsum table to the full index space of `super`:
  // the outer loop runs over the dimensions of `super` that are absent from
  // this clique, the inner loop (mask flipped) over those that are present.
  multivariate_index mv(super.dim);
  size_t nx = mv.count();
  std::vector<bool> mask = lmatch(super.indices, this->indices);
  mask.flip();
  mv.set_mask(mask);
  std::vector<ad_plain> x(nx);
  size_t xa_count = mv.count();
  mv.flip();
  size_t xi_count = mv.count();
  mv.flip();
  TMBAD_ASSERT(x.size() == xa_count * xi_count);
  for (size_t i = 0; i < xa_count; i++, ++mv) {
    mv.flip();
    for (size_t j = 0; j < xi_count; j++, ++mv) {
      TMBAD_ASSERT(logsum[j].on_some_tape());
      x[mv] = logsum[j];
    }
    mv.flip();
  }

  // Pick, for every cell of `super` with `ind` masked out, the expanded entry
  // at the position where the `ind` coordinate is zero.
  mv = multivariate_index(super.dim);
  mask = lmatch(super.indices, std::vector<Index>(1, ind));
  mask.flip();
  mv.set_mask(mask);

  xa_count = mv.count();
  offset.resize(xa_count);
  for (size_t i = 0; i < xa_count; i++, ++mv) {
    offset[i] = x[mv];
  }
}

// Creates an empty grid.
sr_grid::sr_grid() {
}

// Grid of `n` equal-width cells on [a, b]: nodes are the cell midpoints and
// every weight equals the cell width.
sr_grid::sr_grid(Scalar a, Scalar b, size_t n) : x(n), w(n) {
  const Scalar width = (b - a) / n;
  const Scalar first_midpoint = a + width / 2;
  for (size_t k = 0; k < n; k++) {
    x[k] = first_midpoint + k * width;
    w[k] = width;
  }
}

// Grid with nodes 0, 1, ..., n-1 and equal weights 1/n.
// `x` and `w` are sized in the initializer list (as in the (a, b, n)
// constructor); without that the element writes below would index empty
// vectors.
sr_grid::sr_grid(size_t n) : x(n), w(n) {
  for (size_t i = 0; i < n; i++) {
    x[i] = i;
    w[i] = 1. / (double)n;
  }
}

// Number of grid nodes.
size_t sr_grid::size() {
  return x.size();
}

// Returns the first element of the log-weights. The log-weights are computed
// from `w` and passed through forceContiguous whenever their size does not
// match the size of `w`; otherwise the stored values are reused.
ad_plain sr_grid::logw_offset() {
  const size_t num_weights = w.size();
  if (logw.size() == num_weights) {
    return logw[0];
  }
  logw.resize(num_weights);
  for (size_t k = 0; k < w.size(); k++) {
    logw[k] = log(w[k]);
  }
  forceContiguous(logw);
  return logw[0];
}

// Prepares the sequential reduction of tape `glob` over the independent
// variables listed in `random`.
//
// glob:        tape to reduce.
// random:      positions (into glob.inv_index) of the variables to sum out.
// grid:        available grids.
// random2grid: random2grid[i] is the grid number used for random[i].
// perm:        when set, `random` is reordered by reorder_random().
sequential_reduction::sequential_reduction(global &glob,
                                           std::vector<Index> random,
                                           std::vector<sr_grid> grid,
                                           std::vector<Index> random2grid,
                                           bool perm)
    : grid(grid),
      glob(glob),
      random(random),
      replay(glob, new_glob),
      tinfo(glob, false) {
  // Map every independent variable to a grid number (0 unless overridden).
  inv2grid.resize(glob.inv_index.size(), 0);
  for (size_t i = 0; i < random2grid.size(); i++) {
    inv2grid[random[i]] = random2grid[i];
  }

  // Mark the random variables and propagate the marks forward.
  mark.resize(glob.values.size(), false);
  for (size_t i = 0; i < random.size(); i++)
    mark[glob.inv_index[random[i]]] = true;
  glob.forward(mark);

  // Graphs restricted to the marked variables.
  forward_graph = glob.forward_graph(mark);
  reverse_graph = glob.reverse_graph(mark);

  glob.subgraph_cache_ptr();

  var_remap.resize(glob.values.size());

  // Operator -> position of independent / dependent variable (NA if none).
  op2inv_idx = glob.op2idx(glob.inv_index, NA);
  op2dep_idx = glob.op2idx(glob.dep_index, NA);

  if (perm) reorder_random();

  terms_done.resize(glob.dep_index.size(), false);

  // Labels for term_info: non-random variables get the value -(i + 1)
  // converted to Index, random variables get their grid number; the labels
  // are then factorised.
  std::vector<Index> inv_remap(glob.inv_index.size());
  for (size_t i = 0; i < inv_remap.size(); i++) inv_remap[i] = -(i + 1);
  for (size_t i = 0; i < random.size(); i++)
    inv_remap[random[i]] = inv2grid[random[i]];
  inv_remap = radix::factor<Index>(inv_remap);
  tinfo.initialize(inv_remap);
}

// Reorders `random` according to a search over a graph of the independent
// variables.
//
// An edge (r, v) is added whenever independent variable v is found by
// searching forward from random variable r and then backward from everything
// reached. The random variables are then visited by G.search from each
// not-yet-visited one, and the concatenated visit order is reversed.
void sequential_reduction::reorder_random() {
  std::vector<IndexPair> edges;
  std::vector<Index> &inv2op = forward_graph.inv2op;

  // Build the edge list.
  for (size_t i = 0; i < random.size(); i++) {
    std::vector<Index> subgraph(1, inv2op[random[i]]);
    forward_graph.search(subgraph);
    reverse_graph.search(subgraph);
    for (size_t l = 0; l < subgraph.size(); l++) {
      Index inv_other = op2inv_idx[subgraph[l]];
      if (inv_other != NA) {
        IndexPair edge(random[i], inv_other);
        edges.push_back(edge);
      }
    }
  }

  size_t num_nodes = glob.inv_index.size();
  graph G(num_nodes, edges);

  // Collect the search order, one connected piece at a time.
  std::vector<bool> visited(num_nodes, false);
  std::vector<Index> subgraph;
  for (size_t i = 0; i < random.size(); i++) {
    if (visited[random[i]]) continue;
    std::vector<Index> sg(1, random[i]);
    G.search(sg, visited, false, false);
    subgraph.insert(subgraph.end(), sg.begin(), sg.end());
  }
  std::reverse(subgraph.begin(), subgraph.end());
  // The search must have produced exactly one entry per random variable.
  TMBAD_ASSERT(random.size() == subgraph.size());
  random = subgraph;
}

std::vector<size_t> sequential_reduction::get_grid_bounds(
    std::vector<Index> inv_index) {
  std::vector<size_t> ans(inv_index.size());
  for (size_t i = 0; i < inv_index.size(); i++) {
    ans[i] = grid[inv2grid[inv_index[i]]].size();
  }
  return ans;
}

std::vector<sr_grid *> sequential_reduction::get_grid(
    std::vector<Index> inv_index) {
  std::vector<sr_grid *> ans(inv_index.size());
  for (size_t i = 0; i < inv_index.size(); i++) {
    ans[i] = &(grid[inv2grid[inv_index[i]]]);
  }
  return ans;
}

std::vector<ad_aug> sequential_reduction::tabulate(std::vector<Index> inv_index,
                                                   Index dep_index) {
  size_t id = tinfo.id[dep_index];
  size_t count = tinfo.count[id];
  bool do_cache = (count >= 2);
  if (do_cache) {
    if (cache[id].size() > 0) {
      return cache[id];
    }
  }

  std::vector<sr_grid *> inv_grid = get_grid(inv_index);
  std::vector<size_t> grid_bounds = get_grid_bounds(inv_index);
  multivariate_index mv(grid_bounds);
  std::vector<ad_aug> ans(mv.count());
  for (size_t i = 0; i < ans.size(); i++, ++mv) {
    for (size_t j = 0; j < inv_index.size(); j++) {
      replay.value_inv(inv_index[j]) = inv_grid[j]->x[mv.index(j)];
    }
    replay.forward_sub();
    ans[i] = replay.value_dep(dep_index);
  }

  forceContiguous(ans);
  if (do_cache) {
    cache[id] = ans;
  }
  return ans;
}

// Sums out independent variable `i`: all cliques containing `i` are removed
// and replaced by one new clique over the union of their indices minus `i`.
// Each entry of the new clique's logsum table is a logspace_sum_stride over
// the grid of `i`, combining the removed cliques' tables with the log
// quadrature weights.
void sequential_reduction::merge(Index i) {
  // Union of indices of all cliques containing i; c counts those cliques.
  std::vector<Index> super;
  size_t c = 0;
  for (std::list<clique>::iterator it = cliques.begin(); it != cliques.end();
       ++it) {
    if ((*it).contains(i)) {
      super.insert(super.end(), (*it).indices.begin(), (*it).indices.end());
      c++;
    }
  }
  sort_unique_inplace(super);

  // For each such clique, obtain offsets/stride relative to the super clique
  // and erase it from the list.
  std::vector<std::vector<ad_plain> > offset_by_clique(c);
  std::vector<Index> stride_by_clique(c);
  clique C;
  C.indices = super;
  C.dim = get_grid_bounds(super);
  std::list<clique>::iterator it = cliques.begin();
  c = 0;
  while (it != cliques.end()) {
    if ((*it).contains(i)) {
      (*it).get_stride(C, i, offset_by_clique[c], stride_by_clique[c]);
      it = cliques.erase(it);
      c++;
    } else {
      ++it;
    }
  }

  // Drop index i from the new clique and size its table.
  std::vector<bool> mask = lmatch(super, std::vector<Index>(1, i));
  mask.flip();
  C.subset_inplace(mask);
  C.logsum_init();

  // Make sure the log-weights exist before measuring the tape growth below.
  grid[inv2grid[i]].logw_offset();
  size_t v_begin = get_glob()->values.size();
  for (size_t j = 0; j < C.logsum.size(); j++) {
    std::vector<ad_plain> x;
    std::vector<Index> stride;
    for (size_t k = 0; k < offset_by_clique.size(); k++) {
      x.push_back(offset_by_clique[k][j]);
      stride.push_back(stride_by_clique[k]);
    }

    // Log-weights are added as one more term with unit stride.
    x.push_back(grid[inv2grid[i]].logw_offset());
    stride.push_back(1);
    C.logsum[j] = logspace_sum_stride(x, stride, grid[inv2grid[i]].size());
  }
  size_t v_end = get_glob()->values.size();
  // Exactly one new tape value per table entry is expected.
  TMBAD_ASSERT(v_end - v_begin == C.logsum.size());

  cliques.push_back(C);
}

// Processes independent variable `i`: every not-yet-handled term reachable
// from `i` in the forward graph is tabulated over the independent variables
// found by searching backward from it and added as a clique; afterwards `i`
// is summed out via merge(i).
void sequential_reduction::update(Index i) {
  const std::vector<Index> &inv2op = forward_graph.inv2op;

  // Operators reachable from variable i.
  Index start_node = inv2op[i];
  std::vector<Index> subgraph(1, start_node);
  forward_graph.search(subgraph);

  // Collect the reachable terms that have not been handled yet.
  std::vector<Index> dep_clique;
  std::vector<Index> subgraph_terms;
  for (size_t k = 0; k < subgraph.size(); k++) {
    Index node = subgraph[k];
    Index dep_idx = op2dep_idx[node];
    if (dep_idx != NA && !terms_done[dep_idx]) {
      terms_done[dep_idx] = true;
      subgraph_terms.push_back(node);
      dep_clique.push_back(dep_idx);
    }
  }
  for (size_t k = 0; k < subgraph_terms.size(); k++) {
    // Operators the k-th term depends on.
    subgraph.resize(0);
    subgraph.push_back(subgraph_terms[k]);

    reverse_graph.search(subgraph);

    // Independent variables among them form the clique of this term.
    std::vector<Index> inv_clique;
    for (size_t l = 0; l < subgraph.size(); l++) {
      Index tmp = op2inv_idx[subgraph[l]];
      if (tmp != NA) inv_clique.push_back(tmp);
    }

    // Set the subgraph before tabulating.
    // NOTE(review): presumably replay.forward_sub() inside tabulate() reads
    // glob.subgraph_seq - confirm.
    glob.subgraph_seq = subgraph;

    clique C;
    C.indices = inv_clique;
    C.dim = get_grid_bounds(inv_clique);
    C.logsum = tabulate(inv_clique, dep_clique[k]);

    cliques.push_back(C);
  }

  merge(i);
}

void sequential_reduction::show_cliques() {
  Rcout << "Cliques: ";
  std::list<clique>::iterator it;
  for (it = cliques.begin(); it != cliques.end(); ++it) {
    Rcout << it->indices << " ";
  }
  Rcout << "\n";
}

void sequential_reduction::update_all() {
  for (size_t i = 0; i < random.size(); i++) update(random[i]);
}

// Adds up the final result: the single remaining logsum value of every
// clique (each must be fully reduced, i.e. have no indices left) plus the
// replayed value of every term that was never touched by the reduction.
ad_aug sequential_reduction::get_result() {
  ad_aug total = 0;
  for (std::list<clique>::iterator cur = cliques.begin(); cur != cliques.end();
       ++cur) {
    TMBAD_ASSERT(cur->clique_size() == 0);
    TMBAD_ASSERT(cur->logsum.size() == 1);
    total += cur->logsum[0];
  }

  for (size_t term = 0; term < terms_done.size(); term++) {
    if (terms_done[term]) continue;
    total += replay.value_dep(term);
  }
  return total;
}

// Runs the complete reduction and returns the resulting tape `new_glob`.
// Between replay.start() and replay.stop() all random variables are summed
// out (update_all) and the combined result is marked as the dependent
// variable.
global sequential_reduction::marginal() {
  replay.start();
  replay.forward(true, false);
  update_all();
  ad_aug ans = get_result();
  ans.Dependent();
  replay.stop();
  return new_glob;
}

// Binds to `glob` for a split into `num_threads` parts. Aggregation and
// keep_all_inv are off by default; the reverse graph of the tape is built
// up front.
autopar::autopar(global &glob, size_t num_threads)
    : glob(glob),
      num_threads(num_threads),
      do_aggregate(false),
      keep_all_inv(false) {
  this->reverse_graph = glob.reverse_graph();
}

// Returns, for every dependent variable, a depth value: each operator's
// depth is one more than the largest depth among its dependencies (1 when it
// has none).
// NOTE(review): the depth table is sized by the number of operators but is
// indexed with dep[j] and glob.dep_index[j]; whether those are operator or
// variable indices cannot be determined from this block - confirm.
std::vector<size_t> autopar::max_tree_depth() {
  std::vector<Index> max_tree_depth(glob.opstack.size(), 0);
  Dependencies dep;
  Args<> args(glob.inputs);
  for (size_t i = 0; i < glob.opstack.size(); i++) {
    dep.resize(0);
    glob.opstack[i]->dependencies(args, dep);
    for (size_t j = 0; j < dep.size(); j++) {
      max_tree_depth[i] = std::max(max_tree_depth[i], max_tree_depth[dep[j]]);
    }

    max_tree_depth[i]++;

    // Advance the argument pointer past operator i.
    glob.opstack[i]->increment(args.ptr);
  }
  std::vector<size_t> ans(glob.dep_index.size());
  for (size_t j = 0; j < glob.dep_index.size(); j++) {
    ans[j] = max_tree_depth[glob.dep_index[j]];
  }
  return ans;
}

// Distributes the dependent variables over `num_threads` threads.
//
// Dependents are processed in reversed order of max_tree_depth(). For each
// one, dWork is the number of operators found by a reverse-graph search that
// skips already visited operators. A dependent with dWork <= 1 goes to the
// same thread as its predecessor; otherwise to the thread with the least
// accumulated work. Finally node_split[t] is expanded to the full reverse
// subgraph of thread t's dependents (optionally seeded with all independent
// variables when keep_all_inv is set).
void autopar::run() {
  std::vector<size_t> ord = order(max_tree_depth());
  std::reverse(ord.begin(), ord.end());
  std::vector<bool> visited(glob.opstack.size(), false);
  std::vector<Index> start;
  std::vector<Index> dWork(ord.size());
  for (size_t i = 0; i < ord.size(); i++) {
    start.resize(1);
    start[0] = reverse_graph.dep2op[ord[i]];
    reverse_graph.search(start, visited, false, false);
    // Operators newly reached by this dependent.
    dWork[i] = start.size();
    // Disabled debug output.
    if (false) {
      for (size_t k = 0; k < start.size(); k++) {
        Rcout << glob.opstack[start[k]]->op_name() << " ";
      }
      Rcout << "\n";
    }
  }

  // Greedy assignment of dependents to threads.
  std::vector<size_t> thread_assign(ord.size(), 0);
  std::vector<size_t> work_by_thread(num_threads, 0);
  for (size_t i = 0; i < dWork.size(); i++) {
    if (i == 0) {
      thread_assign[i] = 0;
    } else {
      if (dWork[i] <= 1)
        thread_assign[i] = thread_assign[i - 1];
      else
        thread_assign[i] = which_min(work_by_thread);
    }
    work_by_thread[thread_assign[i]] += dWork[i];
  }

  // Seed every thread's node list with the operators of its dependents.
  node_split.resize(num_threads);
  for (size_t i = 0; i < ord.size(); i++) {
    node_split[thread_assign[i]].push_back(reverse_graph.dep2op[ord[i]]);
  }

  // Expand each seed list to the full reverse subgraph.
  for (size_t i = 0; i < num_threads; i++) {
    if (keep_all_inv)
      node_split[i].insert(node_split[i].begin(), reverse_graph.inv2op.begin(),
                           reverse_graph.inv2op.end());
    reverse_graph.search(node_split[i]);
  }
}

// Extracts one sub-tape per thread from the node split computed by run() and
// records which independent / dependent variables of the original tape each
// sub-tape uses (`inv_idx`, `dep_idx`). With do_aggregate each sub-tape is
// aggregated and thread i is assigned the single output position i.
void autopar::extract() {
  vglob.resize(num_threads);
  inv_idx.resize(num_threads);
  dep_idx.resize(num_threads);
  std::vector<Index> tmp;
  for (size_t i = 0; i < num_threads; i++) {
    glob.subgraph_seq = node_split[i];
    vglob[i] = glob.extract_sub(tmp);
    if (do_aggregate) aggregate(vglob[i]);
  }

  // Sentinel for "operator holds no independent / dependent variable".
  Index NA = -1;
  std::vector<Index> op2inv_idx = glob.op2idx(glob.inv_index, NA);
  std::vector<Index> op2dep_idx = glob.op2idx(glob.dep_index, NA);
  for (size_t i = 0; i < num_threads; i++) {
    std::vector<Index> &seq = node_split[i];
    for (size_t j = 0; j < seq.size(); j++) {
      if (op2inv_idx[seq[j]] != NA) inv_idx[i].push_back(op2inv_idx[seq[j]]);
      if (op2dep_idx[seq[j]] != NA) dep_idx[i].push_back(op2dep_idx[seq[j]]);
    }
    if (do_aggregate) {
      // One aggregated output per thread.
      dep_idx[i].resize(1);
      dep_idx[i][0] = i;
    }
  }
}

// Number of independent variables of the original tape.
size_t autopar::input_size() const {
  return glob.inv_index.size();
}

// Number of outputs: one per thread when aggregating, otherwise the number
// of dependent variables of the original tape.
size_t autopar::output_size() const {
  if (do_aggregate) {
    return num_threads;
  }
  return glob.dep_index.size();
}

// Number of inputs, as captured from the autopar object at construction.
Index ParalOp::input_size() const {
  return n;
}

// Number of outputs, as captured from the autopar object at construction.
Index ParalOp::output_size() const {
  return m;
}

// Copies the per-thread tapes and index maps out of `ap` and records its
// input / output sizes.
ParalOp::ParalOp(const autopar &ap)
    : vglob(ap.vglob),
      inv_idx(ap.inv_idx),
      dep_idx(ap.dep_idx),
      n(ap.input_size()),
      m(ap.output_size()) {
}

// Scalar forward sweep: every per-thread tape gets its inputs from the
// operator inputs (via inv_idx), is swept forward, and its dependents are
// then copied to the operator outputs (via dep_idx).
void ParalOp::forward(ForwardArgs<Scalar> &args) {
  size_t num_threads = vglob.size();

  // With OpenMP enabled the per-tape sweeps below run as a parallel loop.
#ifdef _OPENMP
#pragma omp parallel for
#endif

  for (size_t i = 0; i < num_threads; i++) {
    for (size_t j = 0; j < inv_idx[i].size(); j++) {
      vglob[i].value_inv(j) = args.x(inv_idx[i][j]);
    }
    vglob[i].forward();
  }

  // Results are gathered in a separate loop outside the pragma above.
  for (size_t i = 0; i < num_threads; i++) {
    for (size_t j = 0; j < dep_idx[i].size(); j++) {
      args.y(dep_idx[i][j]) = vglob[i].value_dep(j);
    }
  }
}

// Scalar reverse sweep: every per-thread tape has its derivatives cleared,
// is seeded with the output derivatives (via dep_idx) and swept backwards;
// the resulting input derivatives are then accumulated onto the operator
// input derivatives (via inv_idx).
void ParalOp::reverse(ReverseArgs<Scalar> &args) {
  size_t num_threads = vglob.size();

  // With OpenMP enabled the per-tape sweeps below run as a parallel loop.
#ifdef _OPENMP
#pragma omp parallel for
#endif

  for (size_t i = 0; i < num_threads; i++) {
    vglob[i].clear_deriv();
    for (size_t j = 0; j < dep_idx[i].size(); j++) {
      vglob[i].deriv_dep(j) = args.dy(dep_idx[i][j]);
    }
    vglob[i].reverse();
  }

  // Accumulation happens in a separate loop outside the pragma above; several
  // tapes may contribute to the same input.
  for (size_t i = 0; i < num_threads; i++) {
    for (size_t j = 0; j < inv_idx[i].size(); j++) {
      args.dx(inv_idx[i][j]) += vglob[i].deriv_inv(j);
    }
  }
}

// Short display name of the operator.
const char *ParalOp::op_name() {
  return "ParalOp";
}

// Prints every per-thread tape, appending the thread number to the prefix
// of the print configuration.
void ParalOp::print(global::print_config cfg) {
  const size_t num_tapes = vglob.size();
  for (size_t t = 0; t < num_tapes; t++) {
    std::stringstream number;
    number << t;
    global::print_config thread_cfg = cfg;
    thread_cfg.prefix = thread_cfg.prefix + number.str();
    vglob[t].print(thread_cfg);
  }
}

std::vector<Index> get_likely_expression_duplicates(
    const global &glob, std::vector<Index> inv_remap) {
  global::hash_config cfg;
  cfg.strong_inv = true;
  cfg.strong_const = true;
  cfg.strong_output = true;
  cfg.reduce = false;
  cfg.deterministic = TMBAD_DETERMINISTIC_HASH;
  cfg.inv_seed = inv_remap;
  std::vector<hash_t> h = glob.hash_sweep(cfg);
  return radix::first_occurance<Index>(h);
}

bool all_allow_remap(const global &glob) {
  Args<> args(glob.inputs);
  for (size_t i = 0; i < glob.opstack.size(); i++) {
    op_info info = glob.opstack[i]->info();
    if (!info.test(op_info::allow_remap)) {
      return false;
    }
    glob.opstack[i]->increment(args.ptr);
  }
  return true;
}

// Computes a variable remap table for `glob`: remap[v] is the variable that
// v may be replaced by (remap[v] == v when v is kept).
//
// Candidates come from get_likely_expression_duplicates (hash based) and are
// then verified operator by operator; any candidate failing a check is reset
// to the identity.
//
// inv_remap: optional per-independent-variable targets. Independent variable
//            i may only be remapped if it hashes like independent variable
//            inv_remap[i]; with an empty `inv_remap` no independent variable
//            is remapped.
std::vector<Index> remap_identical_sub_expressions(
    global &glob, std::vector<Index> inv_remap) {
  std::vector<Index> remap = get_likely_expression_duplicates(glob, inv_remap);

  // Independent variables: accept only remaps sanctioned by inv_remap.
  for (size_t i = 0; i < glob.inv_index.size(); i++) {
    bool accept = false;
    Index var_i = glob.inv_index[i];
    if (inv_remap.size() > 0) {
      Index j = inv_remap[i];
      Index var_j = glob.inv_index[j];
      accept = remap[var_i] == remap[var_j];
    }
    if (!accept) remap[var_i] = var_i;
  }

  std::vector<Index> v2o = glob.var2op();
  std::vector<Index> dep;
  global::OperatorPure *invop = glob.getOperator<global::InvOp>();
  Dependencies dep1;
  Dependencies dep2;
  size_t reject = 0;
  size_t total = 0;
  Args<> args(glob.inputs);

  // j walks the operators, i is the first output variable of operator j.
  for (size_t j = 0, i = 0, nout = 0; j < glob.opstack.size(); j++, i += nout) {
    nout = glob.opstack[j]->output_size();
    bool any_remap = false;
    for (size_t k = i; k < i + nout; k++) {
      if (remap[k] != k) {
        any_remap = true;
        break;
      }
    }
    if (any_remap) {
      bool ok = true;
      total += nout;

      // The current and the target operator must be the same kind of
      // operator with the same shape.
      global::OperatorPure *CurOp = glob.opstack[v2o[i]];
      global::OperatorPure *RemOp = glob.opstack[v2o[remap[i]]];
      ok &= (CurOp->identifier() == RemOp->identifier());

      ok &= (CurOp->input_size() == RemOp->input_size());
      ok &= (CurOp->output_size() == RemOp->output_size());

      op_info CurInfo = CurOp->info();

      // Multi-output operators: all outputs must map to consecutive outputs
      // of one and the same earlier operator.
      if (ok && (nout > 1)) {
        for (size_t k = 1; k < nout; k++) {
          ok &= (remap[i + k] < i);

          ok &= (v2o[remap[i + k]] == v2o[remap[i]]);

          ok &= (remap[i + k] == remap[i] + k);
        }
      }

      // Independent variables were handled above; never remap them here.
      if (CurOp == invop) {
        ok = false;
      }
      // Constants must agree in value.
      if (ok) {
        if (CurInfo.test(op_info::is_constant)) {
          if (glob.values[i] != glob.values[remap[i]]) {
            ok = false;
          }
        }
      }

      // The dependencies of both operators must pairwise remap to the same
      // variables.
      if (ok) {
        glob.subgraph_cache_ptr();

        args.ptr = glob.subgraph_ptr[v2o[i]];
        dep1.resize(0);
        glob.opstack[v2o[i]]->dependencies(args, dep1);

        args.ptr = glob.subgraph_ptr[v2o[remap[i]]];
        dep2.resize(0);
        glob.opstack[v2o[remap[i]]]->dependencies(args, dep2);

        ok = (dep1.size() == dep2.size());
        if (ok) {
          bool all_equal = true;
          // Note: this `j` shadows the operator counter of the outer loop.
          for (size_t j = 0; j < dep1.size(); j++) {
            all_equal &= (remap[dep1[j]] == remap[dep2[j]]);
          }
          ok = all_equal;
        }
      }

      // Failed verification: keep all outputs of this operator.
      if (!ok) {
        reject += nout;
        for (size_t k = i; k < i + nout; k++) remap[k] = k;
      }
    }
  }

  // Sanity checks: remaps point backwards and are idempotent.
  for (size_t i = 0; i < remap.size(); i++) {
    TMBAD_ASSERT(remap[i] <= i);
    TMBAD_ASSERT(remap[remap[i]] == remap[i]);
  }

  // Collect the dependency intervals of operators that do not allow remapping
  // and apply forbid_remap to `remap` over those intervals.
  if (true) {
    Args<> args(glob.inputs);
    intervals<Index> visited;
    for (size_t i = 0; i < glob.opstack.size(); i++) {
      op_info info = glob.opstack[i]->info();
      if (!info.test(op_info::allow_remap)) {
        Dependencies dep;
        glob.opstack[i]->dependencies(args, dep);
        for (size_t j = 0; j < dep.I.size(); j++) {
          visited.insert(dep.I[j].first, dep.I[j].second);
        }
      }
      glob.opstack[i]->increment(args.ptr);
    }

    forbid_remap<std::vector<Index> > fb(remap);
    visited.apply(fb);
  }
  // `total` and `reject` are statistics only; this silences an unused
  // variable warning.
  if (reject > 0) {
    ((void)(total));
  }

  return remap;
}

// Convenience overload: computes the variable remap table with an empty
// `inv_remap` and rewrites every entry of glob.inputs through that table.
void remap_identical_sub_expressions(global &glob) {
  std::vector<Index> inv_remap;
  std::vector<Index> remap = remap_identical_sub_expressions(glob, inv_remap);

  // Redirect each operator input to the variable chosen by the remap table.
  const size_t n_inputs = glob.inputs.size();
  for (size_t k = 0; k < n_inputs; ++k) {
    glob.inputs[k] = remap[glob.inputs[k]];
  }
}

// Walks the operation stack once and records, for every output variable
// flagged by glob.inv_marks(), the operator number and the IndexPair pointer
// in effect at that operator. Results are stored in visiting order; the
// returned vector is pre-sized to glob.inv_index.size().
std::vector<Position> inv_positions(global &glob) {
  const std::vector<bool> is_inv = glob.inv_marks();
  std::vector<Position> ans(glob.inv_index.size());
  IndexPair ptr(0, 0);
  size_t found = 0;
  for (size_t node = 0; node < glob.opstack.size(); ++node) {
    const Index nout = glob.opstack[node]->output_size();
    for (Index j = 0; j < nout; ++j) {
      // ptr.second + j is the variable index of the j'th output of this node.
      if (!is_inv[ptr.second + j]) continue;
      ans[found].node = node;
      ans[found].ptr = ptr;
      ++found;
    }
    // Advance the pointer pair past this operator.
    glob.opstack[node]->increment(ptr);
  }
  return ans;
}

// Reorders the operators of `glob` with respect to a selection of its
// independent variables.
//
// `inv_idx` holds positions into glob.inv_index and must be strictly
// increasing. Variables are marked starting from the selected independent
// variables; operators whose variables end up unmarked are placed first,
// followed by the remaining operators, and `glob` is replaced by the result of
// extract_sub() on that sequence.
void reorder_graph(global &glob, std::vector<Index> inv_idx) {
  for (size_t i = 1; i < inv_idx.size(); i++) {
    TMBAD_ASSERT(inv_idx[i] > inv_idx[i - 1]);
  }

  // Seed the marks with the selected independent variables and let
  // forward_dense update them (presumably spreading marks to dependent
  // variables -- confirm against global::forward_dense).
  std::vector<bool> marks(glob.values.size(), false);
  for (size_t s = 0; s < inv_idx.size(); s++) {
    marks[glob.inv_index[inv_idx[s]]] = true;
  }
  glob.forward_dense(marks);

  // Intervals reported for op_info::dynamic must be either fully marked or
  // fully unmarked. Partially marked intervals are filled and the marks are
  // re-propagated until no interval needs fixing.
  intervals<Index> dyn = glob.get_intervals(op_info::dynamic, true, true);
  if (dyn.x.size() > 0) {
    struct WholeIntervalMarker {
      std::vector<bool> &marks;
      bool invalid;
      // Called with the closed interval [a, b].
      void operator()(Index a, Index b) {
        const size_t n_marked =
            std::count(marks.begin() + a, marks.begin() + b + 1, true);
        const size_t n_total = b - a + 1;
        if (n_marked == 0 || n_marked == n_total) return;
        invalid = true;
        std::fill(marks.begin() + a, marks.begin() + b + 1, true);
      }
    } fix = {marks, false};

    // One initial pass; every pass that had to fill an interval triggers a
    // new propagation followed by another pass.
    for (dyn.apply(fix); fix.invalid; dyn.apply(fix)) {
      glob.forward_dense(marks);
      fix.invalid = false;
    }
  }
  // Disabled diagnostic output.
  if (false) {
    int c = std::count(marks.begin(), marks.end(), true);
    Rcout << "marked proportion:" << (double)c / (double)marks.size() << "\n";
  }

  // Switch to the complement (unmarked variables) and translate the variable
  // marks into operator marks.
  marks.flip();
  std::vector<bool> op_marks = glob.var2op(marks);

  // Operators flagged for the complement go first, all other operators after.
  std::vector<Index> order = which<Index>(op_marks);
  op_marks.flip();
  std::vector<Index> tail = which<Index>(op_marks);
  order.insert(order.end(), tail.begin(), tail.end());
  glob.subgraph_seq = order;

  glob = glob.extract_sub();
}
}  // namespace TMBad
// Autogenerated - do not edit by hand !
#include "integrate.hpp"
namespace TMBad {

// Identity overload for plain doubles: the value of a double is itself.
double value(double x) {
  return x;
}

// Stores the three settings verbatim in the members of the same name
// (without the trailing underscore).
control::control(int subdivisions_, double reltol_, double abstol_)
    : subdivisions(subdivisions_),
      reltol(reltol_),
      abstol(abstol_) {}
}  // namespace TMBad
// Autogenerated - do not edit by hand !
#include "radix.hpp"
namespace TMBad {}
// Autogenerated - do not edit by hand !
#include "tmbad_allow_comparison.hpp"
namespace TMBad {

// "Less than" on two ad_aug operands: only the numbers returned by Value()
// are compared.
bool operator<(const ad_aug &x, const ad_aug &y) {
  const Scalar lhs = x.Value();
  const Scalar rhs = y.Value();
  return lhs < rhs;
}
// Mixed overload: plain Scalar on the left, ad_aug on the right.
bool operator<(const Scalar &x, const ad_aug &y) {
  const Scalar rhs = y.Value();
  return x < rhs;
}

// "Less than or equal" on two ad_aug operands: only the numbers returned by
// Value() are compared.
bool operator<=(const ad_aug &x, const ad_aug &y) {
  const Scalar lhs = x.Value();
  const Scalar rhs = y.Value();
  return lhs <= rhs;
}
// Mixed overload: plain Scalar on the left, ad_aug on the right.
bool operator<=(const Scalar &x, const ad_aug &y) {
  const Scalar rhs = y.Value();
  return x <= rhs;
}

// "Greater than" on two ad_aug operands: only the numbers returned by Value()
// are compared.
bool operator>(const ad_aug &x, const ad_aug &y) {
  const Scalar lhs = x.Value();
  const Scalar rhs = y.Value();
  return lhs > rhs;
}
// Mixed overload: plain Scalar on the left, ad_aug on the right.
bool operator>(const Scalar &x, const ad_aug &y) {
  const Scalar rhs = y.Value();
  return x > rhs;
}

// "Greater than or equal" on two ad_aug operands: only the numbers returned by
// Value() are compared.
bool operator>=(const ad_aug &x, const ad_aug &y) {
  const Scalar lhs = x.Value();
  const Scalar rhs = y.Value();
  return lhs >= rhs;
}
// Mixed overload: plain Scalar on the left, ad_aug on the right.
bool operator>=(const Scalar &x, const ad_aug &y) {
  const Scalar rhs = y.Value();
  return x >= rhs;
}

// Equality on two ad_aug operands: true when the numbers returned by Value()
// compare equal.
bool operator==(const ad_aug &x, const ad_aug &y) {
  const Scalar lhs = x.Value();
  const Scalar rhs = y.Value();
  return lhs == rhs;
}
// Mixed overload: plain Scalar on the left, ad_aug on the right.
bool operator==(const Scalar &x, const ad_aug &y) {
  const Scalar rhs = y.Value();
  return x == rhs;
}

// Inequality on two ad_aug operands: true when the numbers returned by Value()
// differ.
bool operator!=(const ad_aug &x, const ad_aug &y) {
  const Scalar lhs = x.Value();
  const Scalar rhs = y.Value();
  return lhs != rhs;
}
// Mixed overload: plain Scalar on the left, ad_aug on the right.
bool operator!=(const Scalar &x, const ad_aug &y) {
  const Scalar rhs = y.Value();
  return x != rhs;
}
}  // namespace TMBad
// Autogenerated - do not edit by hand !
#include "vectorize.hpp"
namespace TMBad {

// Stores `n` in the member of the same name; dependencies() uses it as the
// length of the input segment.
VSumOp::VSumOp(size_t n)
    : n(n) {
}

// Registers the operator's dependencies as one segment of length `n` that
// starts at the variable referenced by input 0.
void VSumOp::dependencies(Args<> &args, Dependencies &dep) const {
  const Index first = args.input(0);
  dep.add_segment(first, n);
}

// The Writer forward pass is not implemented for VSumOp: reaching this
// function always trips the assertion.
void VSumOp::forward(ForwardArgs<Writer> &args) {
  TMBAD_ASSERT(false);
}

// The Writer reverse pass is not implemented for VSumOp: reaching this
// function always trips the assertion.
void VSumOp::reverse(ReverseArgs<Writer> &args) {
  TMBAD_ASSERT(false);
}

// Name reported for this operator.
const char *VSumOp::op_name() {
  return "VSumOp";
}

// Applies a VSumOp sized to the segment `x` and returns element 0 of the
// operator's result.
ad_aug sum(ad_segment x) {
  global::Complete<VSumOp> vsum(x.size());
  return vsum(x)[0];
}

// Pointer to the first referenced entry of the owning tape's `values` array.
Scalar *SegmentRef::value_ptr() {
  return glob_ptr->values.data() + offset;
}

// Pointer to the first referenced entry of the owning tape's `derivs` array.
Scalar *SegmentRef::deriv_ptr() {
  return glob_ptr->derivs.data() + offset;
}

// Default constructor: creates a null reference (no tape, zero offset, zero
// size). Initializing the members explicitly guarantees that isNull() is
// well-defined (and true) on a default-constructed object instead of reading
// an indeterminate pointer.
SegmentRef::SegmentRef() : glob_ptr(NULL), offset(0), size(0) {}

// Rebuilds a SegmentRef from its packed representation: the memory at `x` is
// reinterpreted as a SegmentRef object and copied into *this.
SegmentRef::SegmentRef(const Scalar *x) {
  const SegmentRef *packed = reinterpret_cast<const SegmentRef *>(x);
  *this = *packed;
}

// Builds a reference to `s` consecutive entries starting at offset `o` on the
// tape `g`.
SegmentRef::SegmentRef(global *g, Index o, Index s)
    : glob_ptr(g),
      offset(o),
      size(s) {}

// Rebuilds a SegmentRef from a packed ad_segment. The segment must consist of
// exactly ScalarPack<SegmentRef>::size entries; their values are gathered into
// a local buffer whose memory is then reinterpreted as a SegmentRef.
SegmentRef::SegmentRef(const ad_segment &x) {
  static const size_t K = ScalarPack<SegmentRef>::size;
  TMBAD_ASSERT(x.size() == K);
  Scalar raw[K];
  size_t k = 0;
  while (k < K) {
    raw[k] = x[k].Value();
    ++k;
  }
  *this = *reinterpret_cast<SegmentRef *>(raw);
}

// True when this reference is not attached to any tape (null glob_ptr).
bool SegmentRef::isNull() {
  return !glob_ptr;
}

// Overwrites the `size` field of the SegmentRef stored in the current tape's
// `values` array at the position given by pack.index().
void SegmentRef::resize(ad_segment &pack, Index n) {
  const Index start = pack.index();
  Scalar *slot = get_glob()->values.data() + start;
  reinterpret_cast<SegmentRef *>(slot)->size = n;
}

// Applies a PackOp sized to `x` and returns its result. The `up` flag selects
// between the PackOp<true> and PackOp<false> instantiations. The segment must
// start inside the current tape's value array.
ad_segment pack(const ad_segment &x, bool up) {
  TMBAD_ASSERT2(x.index() < get_glob()->values.size(),
                "Packing invalid ad_segment");
  if (!up) {
    global::Complete<PackOp<false> > packer(x.size());
    return packer(x);
  }
  global::Complete<PackOp<true> > packer(x.size());
  return packer(x);
}

// Decodes the SegmentRef packed in `x` to learn the segment length, then
// applies an UnpkOp<false> of that length to `x` and returns the result.
ad_segment unpack(const ad_segment &x) {
  const SegmentRef ref(x);
  Index n_unpacked = ref.size;
  global::Complete<UnpkOp<false> > unpacker(n_unpacked);
  return unpacker(x);
}

void unpack(const ad_segment &x, ad_segment &y) {
  Index n = SegmentRef(x).size;
  global::Complete<UnpkOp<true> > op(n);
  op(x, y);
}

// Treats `x` as an array of packed SegmentRef records, each occupying
// ScalarPack<SegmentRef>::size scalars, and returns the value pointer of
// record number `j`.
Scalar *unpack(const std::vector<Scalar> &x, Index j) {
  const Index stride = ScalarPack<SegmentRef>::size;
  const Scalar *record = &x[j * stride];
  SegmentRef ref(record);
  return ref.value_ptr();
}

std::vector<ad_aug> concat(const std::vector<ad_segment> &x) {
  std::vector<ad_aug> ans;
  for (size_t i = 0; i < x.size(); i++) {
    ad_segment xi = x[i];
    for (size_t j = 0; j < xi.size(); j++) {
      ans.push_back(xi[j]);
    }
  }
  return ans;
}
}  // namespace TMBad
