初心者のプログラミング日記

プログラミング初心者の日記

プログラミングに関することを書いていきます。

区間DP

今回は区間DPをやっていきます。
理解するまでに時間がかかりましたが、理解できればなるほどと思いました。
例題は以下の問題でやっていきます。
http://judge.u-aizu.ac.jp/onlinejudge/description.jsp?id=1611&lang=jp
とりあえず、

4
1 2 3 4

与えられるデータはこれで行きます。
まず、基本となるコードを書いていきます。

N=int(input())
W=list(map(int,input().split()))

#DPテーブル      
dp = [[0]*(N) for _ in range(N)]
#探索する幅
for w in range(1,N+1):
    for i in range(N):
        j=w+i
        if j>=N:
            continue
         print(W[i:j+1])

この結果は以下のようになります

[1, 2]
[2, 3]
[3, 1]
[1, 2, 3]
[2, 3, 1]
[1, 2, 3, 4]

このように探索する幅をどんどん広げて行きます。
そして、この探索範囲で取り出せるブロック数の最大値をdpテーブルに書き込んでいきます。
求め方は2パターンあって、まず1パターン目をかいていきます。

N=int(input())
W=list(map(int,input().split()))

#DPテーブル      
dp = [[0]*(N) for _ in range(N)]
for w in range(1,N+1):
    for i in range(N):
        j=w+i
        if j>=N:
            continue
        #パターン1
        #区間を取り除けるか and 両端を取り除けるか
        if dp[i+1][j-1]==w-1 and abs(W[i]-W[j])<=1:
            dp[i][j]=w+1;

まず、dp[i+1][j-1]==w-1の処理から解説していきます。
探索範囲がdp[i][j]の時、両端を除いた区間はdp[i+1][j-1]になります。
f:id:nasubiFX:20201016181933p:plain
画像で表わすとこうなります。
そして、dpテーブルにはその区間で取り除くことができるブロック数の最大値がはいっています。
dp[i+1][j-1]を調べて、w-1となればその区間はすべて取り除くことができることになります。
次のabs(W[i]-W[j])<=1を解説していきます。
仮に両端を取り除いた区間がすべて取り除けた場合、残りは両端のみになります。
今回の問題の条件は隣り合う数字がの差が1以内ならたたき出せるので、absで両端の差を絶対値にして、差が1以内かを判断しています。
それらの条件がすべてそろった場合、dp[i][j]の区間はすべて取り除けることになります。
なので、dp[i][j]=w+1となります。
f:id:nasubiFX:20201016183305p:plain

2パターン目の処理を書いていきます

N=int(input())
W=list(map(int,input().split()))

#DPテーブル      
dp = [[0]*(N) for _ in range(N)]
for w in range(1,N+1):
    for i in range(N):
        j=w+i
        if j>=N:
            continue
        #パターン1
        #区間を取り除けるか and 両端を取り除けるか
        if dp[i+1][j-1]==w-1 and abs(W[i]-W[j])<=1:
            dp[i][j]=w+1;
            
        #パターン2
        #区間を分ける
        for mid in range(i,j):
            dp[i][j]=max(dp[i][j] , dp[i][mid]+dp[mid+1][j]);
            
        print(dp)
        
print(dp)
print(dp[0][-1])

この区間分けは探索区間における取り除くことができるブロック数の最大値を検索するものです。
探索方法はまず、探索区間を2ブロックにわける方法をfor文で全探索します。
そして、2ブロックに分けたそれぞれのブロックの取り除くことができるブロック数の最大値を足して、最終的にはその区間における取り除くことができるブロック数の最大値を求めることができます。
その結果をdpに書き込みます。
f:id:nasubiFX:20201016190553p:plain
幅が小さい順で探索しているので、これらのパターンの結果はすでにdpに書き込んであります。
これで、パターン1とパターン2で全探索ができました。
そして、最終的には以下の表ができます
f:id:nasubiFX:20201016192437p:plain
正解は全区間を探索した結果なので、dp[0][-1]になります。

わかりづらかったかもしれないですが、これが限界です。
参考記事
http://kutimoti.hatenablog.com/entry/2018/03/10/220819

おまけ

以下の問題を解いてみました
https://atcoder.jp/contests/tdpc/tasks/tdpc_iwi

S=str(input())
 
dp = [[0]*(len(S)) for _ in range(len(S))]

for w in range(2,len(S)+1):
    for i in range(len(S)):
        j=w+i
        if j>=len(S):
            continue
        if S[i:j+1]=="iwi":
            dp[i][j]=1
        elif dp[i+1][j-2]!=0 and S[i]+S[j-1:]=="iwi":
            dp[i][j]=dp[i+1][j-2]+1
        elif dp[i+2][j-1]!=0 and S[:i+2]+S[j-1]=="iwi":
            dp[i][j]=dp[i+2][j-1]+1
        
        for mid in range(i,j):
            dp[i][j]=max(dp[i][j],dp[i][mid]+dp[mid+1][j])
            
print(dp[0][-1])

最初はこのコードを試して見たんですが、これだと入力例2のように以下のようになる場合に対応できないんですよね。
f:id:nasubiFX:20201018151947p:plain
で、ちょと調べて改良してみたんですけど、WAで通らなかったです。
テストケースを調べてもでてこなくてこれ以上はわかりません

S=str(input())
 
dp = [[0]*(len(S)) for _ in range(len(S))]
 
for w in range(2,len(S)):
    for i in range(len(S)):
        j=w+i
        if j>=len(S):
            continue
        if S[i:j+1]=="iwi":
            dp[i][j]=1
            continue
        
        if S[i]=="i" and S[j]=="i":
                for k in range(i+1,j):
                    if S[k]=="w":
                        #おそらくここの処理が間違っている。全部消える時の処理ができていない
                        dp[i][j]=max(dp[i][j],dp[i+1][k-1]+dp[k+1][j-1]+1)
            
        for mid in range(i+1,j):
            dp[i][j]=max(dp[i][j],dp[i][mid]+dp[mid+1][j])

print(dp[0][-1])