TSP2 : MST+DP 시간초과 문제

  • minshogi
    minshogi

    안녕하세요? 읽어주셔서 감사합니다.
    종만북을 참고하여 MST Heuristic과 DP를 활용한 코드를 짜보았습니다.
    MST만 사용했을 땐 DP만 사용했을 때와 비슷한 시간으로 통과가 됐으나 MST와 DP를 함께 사용하니 시간초과가 뜨네요..
    책을 보고 거의 베끼다싶이 했는데 시간초과가 떠버리니 당황스럽네요.
    어떤 부분을 더 최적화 해야할지, 어떤 부분 때문에 시간초과가 나는건지 알고싶습니다.
    아래는 제 코드입니다.

    TSP2

    #include <cstdio>
    #include <iostream>
    #include <algorithm>
    #include <vector>
    #include <map>
    #include <limits>
    using namespace std;
    
    struct DisjointSet
    {
        vector<int> parent, rank;
        DisjointSet(int n) : parent(n), rank(n,1)
        {
            for(int i=0;i<n;++i) parent[i]=i;
        }
        int find(int u)
        {
            if(u == parent[u])  return u;
            return parent[u] = find(parent[u]);
        }
        bool merge(int u,int v)
        {
            u = find(u); v = find(v);
            if(u==v)    return false;
            if(rank[u]>rank[v]) swap(u,v);
            parent[u]=v;
            if(rank[u]==rank[v]) ++rank[v];
            return true;
        }
    };
    
    const double INF = numeric_limits<double>::max();
    const int MAX_N = 20;
    
    vector<pair<double,pair<int, int> > > edges;
    vector<vector<double> > dist;
    
    //DP[here][remaining][visited]
    map<int, double> DP[MAX_N][6];
    
    int N,END;
    double BEST;
    
    int left_bit(int visited){  return N-__builtin_popcount(visited);}
    
    void get_dist();
    void make_edge();
    double mstHeuristic(const int here, const int visited);
    bool prune(double currentLen, int here, const int visited);
    double DP_search(int here, int visited);
    void search(double currentLen, int here, int visited);
    double solve();
    
    int main()
    {
        int T;  scanf("%d",&T);
        while(T--)
        {
            printf("%.10lf\n",solve());
        }
    }
    
    void get_dist()
    {   
        dist.clear();   dist.resize(N,vector<double>(N));
    
        for(int i=0;i<N;++i)
            for(int j=0;j<N;++j)
                scanf("%lf",&dist[i][j]);
    }
    
    void make_edge()
    {
        edges.clear();
        for(int i=0;i<N;++i)
            for(int j=0;j<N;++j)
                edges.push_back(make_pair(dist[i][j],make_pair(i,j)));
    
        sort(edges.begin(), edges.end());
    }
    
    double mstHeuristic(const int here, const int visited)
    {
        DisjointSet DS(N);
        double taken=0;
        for(int i=0;i<edges.size();++i)
        {
            int a = edges[i].second.first, b = edges[i].second.second;
            if(a!=here && (visited&(1<<a))) continue;
            if(b!=here && (visited&(1<<b))) continue;
            if(DS.merge(a,b))
                taken+=edges[i].first;
        }
        return taken;
    }
    
    bool prune(double currentLen, int here, const int visited)
    {
        return (mstHeuristic(here,visited)+currentLen >= BEST);
    }
    
    double DP_search(int here, int visited)
    {
    
        if(visited == END) return 0;
    
        int remaining = left_bit(visited);
        double& ret = DP[here][remaining][visited];
        if(ret>0)   return ret;
    
        ret = INF;
        for(int there=0;there<N;++there)
        {
            if(dist[here][there]!=0 && !(visited&(1<<there)))
            {
                ret = min(ret, DP_search(there,visited|(1<<there)) + dist[here][there]);
            }
        }
        return ret;
    }
    
    void search(double currentLen, int here, int visited)
    {
        if(left_bit(visited)+5 >= N)
        {
            BEST = min(BEST, currentLen+DP_search(here, visited));
            return;
        }
    
        if(prune(currentLen, here, visited)) return;
    
        for(int there=0;there<N; ++there)
        {
            if(dist[here][there]!=0 && !(visited&(1<<there)))
            {
                search(currentLen+dist[here][there], there, visited|(1<<there));
            }
        }
    }
    
    double solve()
    {
        scanf("%d",&N);
        BEST=INF;
        END = (1<<N)-1;
        for(int i=0;i<MAX_N;++i)
            for(int j=0;j<6;++j)
                DP[i][j].clear();
    
    
        int visited=0;
    
        get_dist();
        make_edge();
    
        for(int i=0;i<N;++i) search(0,i,visited|(1<<i));
    
        return BEST;
    }
    

    7년 전
0개의 댓글이 있습니다.
  • 정회원 권한이 있어야 커멘트를 다실 수 있습니다. 정회원이 되시려면 온라인 저지에서 5문제 이상을 푸시고, 가입 후 7일 이상이 지나셔야 합니다. 현재 문제를 푸셨습니다.