区間DP
今回は区間DPをやっていきます。
理解するまでに時間がかかりましたが、理解できればなるほどと思いました。
例題は以下の問題でやっていきます。
http://judge.u-aizu.ac.jp/onlinejudge/description.jsp?id=1611&lang=jp
とりあえず、
4 1 2 3 4
与えられるデータはこれで行きます。
まず、基本となるコードを書いていきます。
N=int(input()) W=list(map(int,input().split())) #DPテーブル dp = [[0]*(N) for _ in range(N)] #探索する幅 for w in range(1,N+1): for i in range(N): j=w+i if j>=N: continue print(W[i:j+1])
この結果は以下のようになります
[1, 2] [2, 3] [3, 1] [1, 2, 3] [2, 3, 1] [1, 2, 3, 4]
このように探索する幅をどんどん広げて行きます。
そして、この探索範囲で取り出せるブロック数の最大値をdpテーブルに書き込んでいきます。
求め方は2パターンあって、まず1パターン目をかいていきます。
N=int(input()) W=list(map(int,input().split())) #DPテーブル dp = [[0]*(N) for _ in range(N)] for w in range(1,N+1): for i in range(N): j=w+i if j>=N: continue #パターン1 #区間を取り除けるか and 両端を取り除けるか if dp[i+1][j-1]==w-1 and abs(W[i]-W[j])<=1: dp[i][j]=w+1;
まず、dp[i+1][j-1]==w-1の処理から解説していきます。
探索範囲がdp[i][j]の時、両端を除いた区間はdp[i+1][j-1]になります。
画像で表わすとこうなります。
そして、dpテーブルにはその区間で取り除くことができるブロック数の最大値がはいっています。
dp[i+1][j-1]を調べて、w-1となればその区間はすべて取り除くことができることになります。
次のabs(W[i]-W[j])<=1を解説していきます。
仮に両端を取り除いた区間がすべて取り除けた場合、残りは両端のみになります。
今回の問題の条件は隣り合う数字がの差が1以内ならたたき出せるので、absで両端の差を絶対値にして、差が1以内かを判断しています。
それらの条件がすべてそろった場合、dp[i][j]の区間はすべて取り除けることになります。
なので、dp[i][j]=w+1となります。
2パターン目の処理を書いていきます
N=int(input()) W=list(map(int,input().split())) #DPテーブル dp = [[0]*(N) for _ in range(N)] for w in range(1,N+1): for i in range(N): j=w+i if j>=N: continue #パターン1 #区間を取り除けるか and 両端を取り除けるか if dp[i+1][j-1]==w-1 and abs(W[i]-W[j])<=1: dp[i][j]=w+1; #パターン2 #区間を分ける for mid in range(i,j): dp[i][j]=max(dp[i][j] , dp[i][mid]+dp[mid+1][j]); print(dp) print(dp) print(dp[0][-1])
この区間分けは探索区間における取り除くことができるブロック数の最大値を検索するものです。
探索方法はまず、探索区間を2ブロックにわける方法をfor文で全探索します。
そして、2ブロックに分けたそれぞれのブロックの取り除くことができるブロック数の最大値を足して、最終的にはその区間における取り除くことができるブロック数の最大値を求めることができます。
その結果をdpに書き込みます。
幅が小さい順で探索しているので、これらのパターンの結果はすでにdpに書き込んであります。
これで、パターン1とパターン2で全探索ができました。
そして、最終的には以下の表ができます
正解は全区間を探索した結果なので、dp[0][-1]になります。
わかりづらかったかもしれないですが、これが限界です。
参考記事
http://kutimoti.hatenablog.com/entry/2018/03/10/220819
おまけ
以下の問題を解いてみました
https://atcoder.jp/contests/tdpc/tasks/tdpc_iwi
S=str(input()) dp = [[0]*(len(S)) for _ in range(len(S))] for w in range(2,len(S)+1): for i in range(len(S)): j=w+i if j>=len(S): continue if S[i:j+1]=="iwi": dp[i][j]=1 elif dp[i+1][j-2]!=0 and S[i]+S[j-1:]=="iwi": dp[i][j]=dp[i+1][j-2]+1 elif dp[i+2][j-1]!=0 and S[:i+2]+S[j-1]=="iwi": dp[i][j]=dp[i+2][j-1]+1 for mid in range(i,j): dp[i][j]=max(dp[i][j],dp[i][mid]+dp[mid+1][j]) print(dp[0][-1])
最初はこのコードを試して見たんですが、これだと入力例2のように以下のようになる場合に対応できないんですよね。
で、ちょと調べて改良してみたんですけど、WAで通らなかったです。
テストケースを調べてもでてこなくてこれ以上はわかりません
S=str(input()) dp = [[0]*(len(S)) for _ in range(len(S))] for w in range(2,len(S)): for i in range(len(S)): j=w+i if j>=len(S): continue if S[i:j+1]=="iwi": dp[i][j]=1 continue if S[i]=="i" and S[j]=="i": for k in range(i+1,j): if S[k]=="w": #おそらくここの処理が間違っている。全部消える時の処理ができていない dp[i][j]=max(dp[i][j],dp[i+1][k-1]+dp[k+1][j-1]+1) for mid in range(i+1,j): dp[i][j]=max(dp[i][j],dp[i][mid]+dp[mid+1][j]) print(dp[0][-1])