FAMILYTREE 문제 질문드립니다.

  • sven
    sven

    FAMILYTREE

    #include <vector>
    #include <algorithm>
    #include <iostream>
    #include <cstdio>
    #include <climits>
    using namespace std;
    
    typedef vector<int> VI;
    typedef vector <VI> VVI;
    #define PB push_back
    #define FOR(i,a,b) for(int i=(a); i<=(b); ++i)
    #define REP(i,n) for(int i=0; i<(n); ++i)
    #define ALL(X) (X).begin(),(X).end()
    #define FORE(it,X) for(__typeof((X).begin()) it=(X).begin(); it!=(X).end();++it)
    
    int N, Q;
    VI A; //parent
    VI D; //depth
    VVI B; //childs
    VI locInTrip, no2serial, serial2no;
    int nextSerial;
    
    struct RMQ {
      int n;
      VI rangeMin;
      RMQ(const VI &array) {
        n = array.size();
        int m = 1; while(m < n) m <<= 1;
        rangeMin.resize(m * 2);
        init(array, 0, n-1, 1);
      }
      int init(const VI &array, int l, int r, int node) {
        if(l == r) return rangeMin[node] = array[l];
        int mid = (l+r)/2;
        int lMin = init(array, l, mid, node * 2);
        int rMin = init(array, mid + 1, r, node * 2 + 1);
        return rangeMin[node] = min(lMin, rMin);
      }
      int query(int l, int r, int node, int nodeL, int nodeR) {
        if(r < nodeL or nodeR < l) return INT_MAX;
        if(l <= nodeL and nodeR <= r) return rangeMin[node];
        int mid = (nodeL + nodeR) / 2;
        return min(query(l, r, node*2, nodeL, mid), query(l, r, node*2+1, mid+1, nodeR));
      }
      //[l,r)
      int query(int l, int r) {
        return query(l, r, 1, 0, n-1);
      }
      int update(int idx, int newVal, int node, int nodeL, int nodeR) {
        if(idx < nodeL or nodeR < idx)
          return rangeMin[node];
        if(nodeL == nodeR) return rangeMin[node] = newVal;
        int mid = (nodeL + nodeR) / 2;
        return rangeMin[node] = min(
            update(idx, newVal, node * 2, nodeL, mid),
            update(idx, newVal, node * 2 + 1, mid + 1, nodeR));
      }
      int update(int idx, int newVal) {
        return update(idx, newVal, 1, 0, n-1);
      }
    };
    
    void preorder(int cur, int depth, VI &ret) {
      no2serial[cur] = nextSerial;
      serial2no[nextSerial] = cur;
      ++nextSerial;
      D[cur] = depth;
      locInTrip[cur] = ret.size();
      ret.PB(no2serial[cur]);
      FORE(i, B[cur]) {
        preorder(*i, depth+1, ret);
        ret.PB(no2serial[cur]);
      }
    }
    int solve() {
      cin >> N >> Q;
      A = VI(N); A[0] = -1; REP(i, N-1) scanf("%d", &A[i+1]);
      B = VVI(N, VI()); REP(i, N-1) B[A[i+1]].PB(i+1);
      D = VI(N); 
      VI ret;
      locInTrip = VI(N);
      serial2no = VI(N); no2serial = VI(N);
      nextSerial = 0;
      preorder(0, 0, ret);
      /*REP(i, N) P(i,B[i]);
      P(ret)*/
    
      RMQ R(ret); 
      REP(i, Q) {
        int a, b;
        cin >> a >> b;
        int ans = D[a] + D[b];
        int u = locInTrip[a], v = locInTrip[b];
        if(u > v) swap(u, v);
        int lca = serial2no[R.query(u, v)];
        ans -= 2*D[lca];
        printf("%d\n", ans);
      }
      return 0;
    }
    
    int main() {
      int T; cin >> T; REP(i, T)
        solve();
      return 0;
    }
    

    책의 코드를 구현하려고 했는데요, 시간 초과가 뜹니다.
    preorder 로 순회하되 자식을 순회하고 돌아오면 자신을 다시 넣어주는 형태로 배열을 만들면, 이 배열에서 어떤 두 원소의 least common ancestor 는 그 두 원소가 처음으로 등장한 시점 사이에서 가장 depth 가 낮은 원소이고, 이것을 RMQ 로 찾으려고 했습니다.
    시간 복잡도는 O(logN * Q) 인 것 같고, cin/cout 같이 시간이 오래 걸리는 연산들은 printf/scanf 등으로 바꿨는데도 시간 초과가 뜨는군요 ㅜㅜ 조언 부탁드립니다.


    10년 전
1개의 댓글이 있습니다.
  • JongMan
    JongMan

    지금 좀 들여다 봤는데.. cin과 stdio를 섞어 써서 느려지는 것 같아요. cin을 싹 다 scanf로 바꿔보세요.


    10년 전 link
  • 정회원 권한이 있어야 커멘트를 다실 수 있습니다. 정회원이 되시려면 온라인 저지에서 5문제 이상을 푸시고, 가입 후 7일 이상이 지나셔야 합니다. 현재 문제를 푸셨습니다.