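// 提出 #12180082 (AtCoder submission #12180082)
//
//
// ソースコード 拡げる (source code; "拡げる" = "expand", a web-page toggle label)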

#include <bits/stdc++.h>
using namespace std;

// Fenwick tree (binary indexed tree) over 0-indexed positions
// 0 .. treesize-1, supporting point add and prefix/range sums of type T,
// each in O(log treesize).
template <class T>
struct BIT {
  int treesize;
  // 1-indexed internal array; lst[0] is never read or written.
  std::vector<T> lst;
  // Creates a tree of `newn` positions, all holding zero.
  BIT(int newn = 0) : treesize(newn), lst(newn + 1, T(0)) {}
  // a[place] += num. Positions outside [0, treesize) are ignored: a negative
  // place would otherwise reach internal index 0, where `place & -place` is 0
  // and the update loop never terminates (or index before the array).
  void add(int place, T num) {
    if (place < 0 || place >= treesize) return;
    for (++place; place <= treesize; place += place & -place) lst[place] += num;
  }
  // Sum of a[0, place). `place` is clamped to [0, treesize], so bounds past
  // the end return the total instead of reading outside `lst`.
  T sum(int place) const {
    if (place > treesize) place = treesize;
    T res = 0;
    for (; place > 0; place -= place & -place) res += lst[place];
    return res;
  }
  // Sum of a[left, right).
  T sum(int left, int right) const { return sum(right) - sum(left); }
};

// Lowest common ancestor by binary lifting on a tree rooted at `root`.
// par[i][v] is the 2^i-th ancestor of v (-1 once past the root) and dep[v]
// is v's depth below the root. build() costs O(n log n); calc() O(log n).
// Assumes the graph is a tree: dfs only skips the immediate parent.
struct LowestCommonAncestor {
  int n, root, h;
  std::vector<std::vector<int>> g, par;
  std::vector<int> dep;
  // Empty graph on `_n` vertices rooted at `r`: add() the edges, then build().
  // All sizes come from the parameter `_n` so nothing reads an uninitialised
  // member.
  LowestCommonAncestor(int _n = 1, int r = 0)
      : n(_n), root(r), g(_n), dep(_n) {
    allocate_table();
  }
  // Copies adjacency list `_g`, roots it at `r` and builds the table at once.
  LowestCommonAncestor(const std::vector<std::vector<int>> &_g, const int r = 0)
      : n(_g.size()), root(r), g(_g), dep(_g.size()) {
    allocate_table();
    build();
  }
  // Picks the smallest h with 2^h > n and resets par to h rows of n x -1.
  void allocate_table() {
    h = 1;
    while ((1 << h) <= n) ++h;
    par.assign(h, std::vector<int>(n, -1));
  }
  // Adds the undirected edge a-b.
  void add(int a, int b) {
    g[a].push_back(b);
    g[b].push_back(a);
  }
  // Records parent and depth for the subtree of `now`, whose parent is `bf`.
  // Recursive: stack depth equals the tree height.
  void dfs(int now, int bf, int d) {
    par[0][now] = bf;
    dep[now] = d;
    for (int &to : g[now])
      if (to != bf) dfs(to, now, d + 1);
  }
  // Fills par[0] by DFS from root, then par[i + 1][v] = par[i][par[i][v]].
  void build() {
    dfs(root, -1, 0);
    for (int i = 0; i + 1 < h; ++i)
      for (int j = 0; j < n; ++j) {
        if (par[i][j] < 0)
          par[i + 1][j] = -1;
        else
          par[i + 1][j] = par[i][par[i][j]];
      }
  }
  // Returns the lowest common ancestor of x and y (requires build()).
  int calc(int x, int y) {
    if (dep[x] > dep[y]) std::swap(x, y);
    // Lift the deeper vertex y to x's depth, one set bit of the gap at a time.
    for (int i = 0; i < h; ++i)
      if ((dep[y] - dep[x]) >> i & 1) y = par[i][y];
    if (x == y) return x;
    // Jump both while their 2^i-th ancestors still differ.
    for (int i = h - 1; i >= 0; --i)
      if (par[i][x] != par[i][y]) {
        x = par[i][x];
        y = par[i][y];
      }
    return par[0][y];
  }
};

// Euler tour of a tree rooted at `root`. `lst` is the vertex sequence of
// the walk: each vertex is appended on entry and again after returning from
// each of its children. in[v] is the index of v's first appearance and
// out[v] is one past its last, so [in[v], out[v]) spans exactly v's subtree.
struct EulerTour {
  int n, root;
  std::vector<std::vector<int>> g;
  std::vector<int> par, dep, in, out, lst;
  // Empty graph on `_n` vertices rooted at `_root`: add() edges, then build().
  EulerTour(int _n = 1, int _root = 0)
      : n(_n), root(_root), g(_n), par(_n), dep(_n), in(_n), out(_n) {}
  // Copies adjacency list `_g`, roots it at `_root` and builds the tour.
  EulerTour(const std::vector<std::vector<int>> &_g, const int _root = 0)
      : n(_g.size()),
        root(_root),
        g(_g),
        par(_g.size()),
        dep(_g.size()),
        in(_g.size()),
        out(_g.size()) {
    build();
  }
  // Adds the undirected edge a-b.
  void add(int a, int b) {
    g[a].push_back(b);
    g[b].push_back(a);
  }
  // (Re)computes lst, in, out, par and dep from root.
  void build() {
    lst.clear();
    lst.reserve(2 * g.size());  // the tour has at most 2n - 1 entries
    dfs(root, -1, 0);
  }
  // Walks the subtree of `now` (parent `bf`, depth `d`). Every iteration that
  // visits a child ends by re-appending `now`, so `now` is always the last
  // entry when the loop finishes and out[now] is one past its last occurrence.
  // Recursive: stack depth equals the tree height.
  void dfs(int now, int bf, int d) {
    dep[now] = d;
    par[now] = bf;
    in[now] = lst.size();
    lst.push_back(now);
    for (auto &to : g[now])
      if (to != bf) {
        dfs(to, now, d + 1);
        lst.push_back(now);
      }
    out[now] = lst.size();
  }
  // Returns whichever of x, y is deeper (x on a tie); for an edge x-y this is
  // the child endpoint.
  int chil(int x, int y) { return dep[x] < dep[y] ? y : x; }
  // Calls f(in[node], x) and f(out[node], -x). With f adding into a prefix-sum
  // structure, x then appears in every prefix ending inside node's subtree.
  template <typename T, typename F>
  void update(int node, T x, const F &f) {
    f(in[node], x);
    f(out[node], -x);
  }
};

// One tree edge as read by main(): endpoints x and y and colour col (each
// decremented to 0-indexed right after reading) plus the edge length dis.
struct edge {
  int x;
  int y;
  int col;
  long long dis;
};

// One offline query: `id` is its position in the output, x and y are the
// 0-indexed path endpoints, and d is the length substituted for every edge
// of the queried colour.
struct query {
  int id;
  int x;
  int y;
  long long d;
};

// Vertex count and query count.
int n;
int q;

// queries[c]: every query about colour c (0-indexed), answered by solve().
vector<vector<query>> queries;
// g: tree adjacency list.
vector<vector<int>> g;
// coledges[c]: indices into `edges` of the edges with colour c.
vector<vector<int>> coledges;
vector<edge> edges;
// res[i]: answer to the i-th query in input order.
vector<long long> res;
// LCA table and Euler tour, both built from g in main().
LowestCommonAncestor lca;
EulerTour et;
// Fenwick trees over Euler-tour positions: `cnt` counts, and `sum` totals the
// lengths of, the edges on each root-to-vertex path (see solve()).
BIT<long long> cnt;
BIT<long long> sum;

void solve();

int main() {
  cin >> n >> q;
  queries.resize(n);
  g.resize(n);
  coledges.resize(n);
  edges.resize(n - 1);
  res.resize(q);
  for (int i = 0; i < n - 1; ++i) {
    cin >> edges[i].x >> edges[i].y >> edges[i].col >> edges[i].dis;
    --edges[i].x, --edges[i].y, --edges[i].col;
    g[edges[i].x].push_back(edges[i].y);
    g[edges[i].y].push_back(edges[i].x);
    coledges[edges[i].col].push_back(i);
  }
  lca = LowestCommonAncestor(g);
  et = EulerTour(g);
  cnt = BIT<long long>(et.lst.size());
  sum = BIT<long long>(et.lst.size());
  for (int i = 0; i < q; ++i) {
    int x, y, u, v;
    cin >> x >> y >> u >> v;
    queries[x - 1].push_back({i, --u, --v, y});
  }
  solve();
  for (int i = 0; i < q; ++i) cout << res[i] << endl;
  return 0;
}

void solve() {
  auto sumf = [](int l, long long r) { sum.add(l, r); };
  auto cntf = [](int l, long long r) { cnt.add(l, r); };
  // make cnt,sum
  for (auto e : edges) et.update(et.chil(e.x, e.y), e.dis, sumf);
  for (int i = 0; i < n; ++i) {
    // remove col i edge
    for (int id : coledges[i]) {
      edge &e = edges[id];
      int chil = et.chil(e.x, e.y);
      et.update(chil, 1, cntf);
      et.update(chil, -e.dis, sumf);
    }
    // calc distance
    for (auto qu : queries[i]) {
      int r = lca.calc(qu.x, qu.y);
      long long now = sum.sum(et.in[r] + 1, et.in[qu.x] + 1) +
                      sum.sum(et.in[r] + 1, et.in[qu.y] + 1);
      now += qu.d * (cnt.sum(et.in[r] + 1, et.in[qu.x] + 1) +
                     cnt.sum(et.in[r] + 1, et.in[qu.y] + 1));
      res[qu.id] = now;
    }
    // add col i edge
    for (int id : coledges[i]) {
      edge &e = edges[id];
      int chil = et.chil(e.x, e.y);
      et.update(chil, -1, cntf);
      et.update(chil, e.dis, sumf);
    }
  }
}

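/*
提出情報 (Submission information)

提出日時 (Submission time): (not captured in this copy)
問題 (Problem): F - Colorful Tree
ユーザ (User): m_tsubasa
言語 (Language): C++14 (GCC 5.4.1)
得点 (Score): 600
コード長 (Code size): 5215 Byte
結果 (Result): AC
実行時間 (Run time): 639 ms
メモリ (Memory): 45032 KiB
*/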

/*
ジャッジ結果 (Judge results)

セット名 (Set) | 得点 / 配点 (Score / Max) | 結果 (Result)
Sample         | 0 / 0                     | AC × 1
All            | 600 / 600                 | AC × 12

セット名 (Set) | テストケース (Test cases)
Sample         | a01
All            | a01, b02, b03, b04, b05, b06, b07, b08, b09, b10, b11, b12

ケース名 (Case) | 結果 (Result) | 実行時間 (Time) | メモリ (Memory)
a01             | AC            | 1 ms            | 256 KiB
b02             | AC            | 1 ms            | 256 KiB
b03             | AC            | 228 ms          | 4852 KiB
b04             | AC            | 590 ms          | 45032 KiB
b05             | AC            | 639 ms          | 44524 KiB
b06             | AC            | 525 ms          | 42464 KiB
b07             | AC            | 544 ms          | 43748 KiB
b08             | AC            | 612 ms          | 44912 KiB
b09             | AC            | 589 ms          | 44524 KiB
b10             | AC            | 614 ms          | 42612 KiB
b11             | AC            | 604 ms          | 41332 KiB
b12             | AC            | 599 ms          | 41588 KiB
*/