
#include "benchmarker.hpp"
#include "curve_export.hpp"
#include "geometrycentral/surface/mesh_graph_algorithms.h"
#include "geometrycentral/surface/flip_geodesics.h"

void computeSkeleton(std::unique_ptr<NeckModel> &nm, int r_hops){
  // Compute Genus
  if (nm->mesh->genus() > 0){
    polyscope::removeAllStructures();
    nm.release();
    nm = NULL;
    exit(7);
  }
  // srand(time(NULL));
  int64_t last_timing = 0;
  std::ofstream file("timing", std::ios::app);

  double total_area = 0.0;
  for (size_t i = 0; i < nm->mesh->nFaces(); i++)
  {
    total_area += nm->geometry->faceAreas[i];
  }

  std::vector<glm::vec3> output_ve;
  std::vector<std::array<size_t, 2>> output_ed;
  int base_count = 0;

  auto start = std::chrono::steady_clock::now();

  ///////////// Pick a Good Root Point
  // // Pick a random source vertex: X
  int ridx = std::rand() % nm->mesh->nVertices();
  // int ridx = 1040;
  Vertex Xc = nm->mesh->vertex(ridx);
  auto sssp_Xc = nm->sssp_report_furthest(Xc);
  vPair Yc = sssp_Xc.second;
  auto sssp_Yc = nm->sssp_report_furthest(Yc.second);
  vPair Zc = sssp_Yc.second;

  nm->_source = Zc.second;
  auto candidates = findLeaders(nm,true,r_hops);

  // Find the longest distance from source
  vPair spair = candidates.back();
  candidates.pop_back();
  float maxv = 0.0;
  vPair tpair;
  auto todel = candidates.begin();
  // Get maximum and remove it from the candidate list.
  for (auto it = candidates.begin(); it < candidates.end(); it++){
    auto x = *it;
    if (x.first > maxv){
      maxv = x.first;
      tpair = x;
      todel = it;
    } 
  }
  file << "V: " << nm->mesh->nVertices() << " E: " << nm->mesh->nEdges() << " F: " << nm->mesh->nFaces() << " c: " << candidates.size()+1 << std::endl;
  candidates.erase(todel);

  // Get Candidates Timing
  last_timing = since(start).count();
  file << "Candidates: " << last_timing << std::endl;

  //////////// Build the skeleton

  // We have the sssp of the diameter. We're going to perform the following loop:
  // Add each node in the path to a set (the skeleton)
  // For each candidate, compute the shortest path from it to the skeleton
  // Add all of those nodes into the skeleton, repeat until candidates are exhausted.

  std::unordered_set<Vertex> skeleton;
  std::vector<std::vector<Halfedge>> skeleton_paths;
  auto sssp = nm->st_dijkstras(spair.second, tpair.second);
  auto spine = nm->get_he_path(sssp, spair.second, tpair.second);
  skeleton_paths.push_back(spine);

  Vertex t_spine = spine.at(0).twin().vertex();

  skeleton.insert(t_spine);
  for (auto he : spine){
    Vertex v = he.vertex();

    skeleton.insert(v);
  }

  for (auto candidate : candidates){
    Vertex s = candidate.second;
    auto stres = nm->stgroup_dijkstras(s, skeleton);
    sssp_t st_sssp = stres.first;
    Vertex t = stres.second;

    auto bone = nm->get_he_path(st_sssp, s,t);
    if (!bone.empty()){
      skeleton_paths.push_back(bone);
      Vertex t_bone = bone.at(0).twin().vertex();
      skeleton.insert(t_bone);
      for (auto he : bone){
        Vertex v = he.vertex();
        skeleton.insert(v);
      }
    }

  }

  file << "Skeleton: " << since(start).count() - last_timing << std::endl;
  last_timing = since(start).count();


  auto out_cycles = nm->get_cycles_from_skeleton(skeleton_paths);
  file << "Cycles: " << since(start).count() - last_timing << std::endl;
  last_timing = since(start).count();


    for (auto cycles : out_cycles){
      std::vector<float> cycle_lens;
      for (auto he_cycle : cycles)
      {
        float a = 0.0;
        for (auto he : he_cycle)
        {
          a += nm->geometry->edgeLengths[he.edge()];
        }
        cycle_lens.push_back(a);
      }

      // For each cycle, add each face into the queue
      // Run BFS, preventing the crossing of a cycle.

      FaceData<bool> visited(*(nm->mesh)); // Visited Set of faces
      std::queue<Face> bfs_q;
      std::vector<double> segment_lengths = std::vector<double>(cycles.size() + 1); // The table of constrained areas
      size_t ind = 0;
      for (auto he_cycle : cycles)
      {
        std::unordered_set<Face> banned_faces; // Banned faces from enquing in one iteration
        double area_sum = 0.0;
        for (auto he : he_cycle)
        {
          // enqueue all the faces induced by one side of the cycle
          if (visited[he.face()] == false)
          {
            bfs_q.push(he.face());
            visited[he.face()] = true;
          }
          // ban all the faces induced by the other side (this should prevent all crossings automatically)
          banned_faces.insert(he.twin().face());
        }

        while (!bfs_q.empty())
        {
          Face f = bfs_q.front();
          bfs_q.pop();
          area_sum += nm->geometry->faceAreas[f];
          for (Face g : f.adjacentFaces())
          {
            if (visited[g] == false && banned_faces.find(g) == banned_faces.end())
            {
              bfs_q.push(g);
              visited[g] = true;
            }
          }
        }
        segment_lengths[ind] = area_sum;
        ind++;
      }

      // Do last segment:

      if (!cycles.empty()) {
        ind = cycles.size()-1;
        std::unordered_set<Face> banned_faces;
        double area_sum = 0.0;
        auto he_cycle = cycles[ind];
        for (auto he : he_cycle)
        {
          auto hetw = he.twin();
          // enqueue all the faces induced by one side of the cycle
          if (visited[hetw.face()] == false)
          {
            bfs_q.push(hetw.face());
            visited[hetw.face()] = true;
          }
          // ban all the faces induced by the other side (this should prevent all crossings automatically)
          banned_faces.insert(hetw.twin().face());
        }
        while (!bfs_q.empty())
        {
          Face f = bfs_q.front();
          bfs_q.pop();
          area_sum += nm->geometry->faceAreas[f];
          for (Face g : f.adjacentFaces())
          {
            if (visited[g] == false && banned_faces.find(g) == banned_faces.end())
            {
              bfs_q.push(g);
              visited[g] = true;
            }
          }
        }
        segment_lengths[ind] = area_sum;
      }
  
      std::vector<double> segment_prefix_sums = std::vector<double>(segment_lengths.size());
      std::vector<double> tightness = std::vector<double>(cycles.size());
      segment_prefix_sums[0] = segment_lengths[0];

      for (size_t i = 1; i < segment_lengths.size(); i++)
      {
        segment_prefix_sums[i] = segment_prefix_sums[i - 1] + segment_lengths[i];
      }

      for (size_t i = 0; i < tightness.size(); i++)
      {
        tightness[i] = std::min(segment_prefix_sums[i], total_area - segment_prefix_sums[i]) / (cycle_lens[i] * cycle_lens[i]);
      }

      std::vector<bool> local_max_cycle(cycles.size());
      if (cycles.size() > 7)
      {
        for (size_t i = 3; i < cycles.size() - 3; i++)
        {
          if (tightness[i] <= .163) {
            continue;
          }
          if (tightness[i] > tightness[i + 1] && tightness[i] > tightness[i - 1] && tightness[i] > tightness[i + 2] && tightness[i] > tightness[i - 2])
          {
            if (tightness[i] > tightness[i + 3] && tightness[i] > tightness[i - 3])
              local_max_cycle[i] = true;
          }
        }
      }

      // Cycle path display:
      for (size_t i = 0; i < cycles.size(); i++)
        {
            if (!local_max_cycle[i]) {
              continue;
            }

            for (size_t j = 0; j < cycles[i].size(); j++)
            {
              Halfedge he = cycles[i][j];
              Vector3 vertdat = nm->geometry->vertexPositions[he.tailVertex()];
              output_ve.push_back({vertdat.x, vertdat.y, vertdat.z});
              output_ed.push_back({base_count + j, base_count + ((j + 1) % (cycles[i].size()))});
            }
            base_count += cycles[i].size();

        }
    }
    file << "Tightness: " << since(start).count() - last_timing << std::endl;
    nm->skeleton_cycles_output = out_cycles;
    auto curve2 = polyscope::registerCurveNetwork("cyclecurve", output_ve, output_ed);
    curve2->setColor({1.0, 0.0, 0.0});
    curve2->setPosition(glm::vec3{0.,0.,0.});
    // nm->salient_cycles_output = cycles;
    file << "Done: " << since(start).count() << std::endl;
    std::cout << "Done: " << since(start).count() << std::endl;
    export_curve_network_obj(output_ve, output_ed, "curvenet.obj");
}

std::vector<vPair> findLeaders(std::unique_ptr<NeckModel> &nm, bool trim, int r)
{
  std::cout << "rhops: " << r << std::endl;
  auto path = nm->sssp(nm->_source);
  auto prev = path.first;
  auto dists = path.second;

  // Find all local candidates
  std::vector<std::pair<float, Vertex>> candidates;

  for (Vertex v : nm->mesh->vertices())
  {
    bool is_candidate = true;
    for (Vertex u : v.adjacentVertices())
    {
      if (dists[v] < dists[u])
      {
        is_candidate = false;
        break;
      }
    }
    if (is_candidate)
    {
      candidates.push_back({dists[v], v});
    }
  }
  std::cout << "# of Candidates/Verts " << candidates.size() << "/" << nm->mesh->nVertices() << std::endl;

  // Sort candidates by distance
  struct
  {
    bool operator()(std::pair<float, Vertex> a, std::pair<float, Vertex> b) const { return a.first > b.first; }
  } pair_ge;

  std::sort(candidates.begin(), candidates.end(), pair_ge);

  // Remove All Candidates within r-hops
  if (trim){
    typedef std::pair<int, Vertex> VHop;
    for (size_t i = 0; i < candidates.size(); i++)
    {
      std::queue<VHop> q;
      std::unordered_set<Vertex> visited;
      Vertex s = candidates[i].second;
      q.push({0, s});
      visited.insert(s);

      while (!q.empty())
      {
        auto hop = q.front().first;
        auto curr = q.front().second;
        q.pop();
        if (hop + 1 == r)
        {
          continue;
        }
        auto it = candidates.begin();
        for (; it < candidates.end(); it++)
        {
          if ((*it).second == curr && curr != s)
          {
            candidates.erase(it);
            break;
          }
        }

        for (Vertex v : curr.adjacentVertices())
        {
          int rp = hop + 1;
          if (visited.find(v) == visited.end())
          {
            visited.insert(v);
            q.push({rp, v});
          }
        }
      }
    }
  }

  std::cout << "# of Candidates Post Trim " << candidates.size() << std::endl;
  // Mark vertices on structure
  std::vector<std::array<double, 3>>
      vcolors(nm->mesh->nVertices(), {0.0, 0.0, 0.0});

  for (size_t i = 0; i < candidates.size(); i++)
  {
    vcolors[candidates[i].second.getIndex()] = {1.0, 0.0, 0.0};
  }

  std::vector<glm::vec3> pcloud;
  for (auto x : candidates)
  {
    Vector3 vertdat = nm->geometry->vertexPositions[x.second.getIndex()];
    pcloud.push_back({vertdat.x, vertdat.y, vertdat.z});
  }
  Vector3 srcdat = nm->geometry->vertexPositions[nm->_source.getIndex()];
  pcloud.push_back({srcdat.x, srcdat.y, srcdat.z});

  auto gen_pcloud = polyscope::registerPointCloud("critpts", pcloud);
  
  auto curve = polyscope::getCurveNetwork("curve");
  curve->addNodeColorQuantity("leaders", vcolors);
  
  candidates.push_back({0.0, nm->_source});

  // gen_pcloud->centerBoundingBox();
  // curve->centerBoundingBox();
  gen_pcloud->setPosition(glm::vec3{0.,0.,0.});
  curve->setPosition(glm::vec3{0.,0.,0.});
  return candidates;
}

