跳转至

区间 DP

区间 DP(Interval DP)

区间 DP 是一类以区间为状态的动态规划问题。核心思想是:将一个大区间的最优解,通过枚举分割点,由若干小区间的最优解合并而来。

  • 状态:dp[i][j] 表示区间 [i, j] 上的最优值
  • 转移:枚举分割点 k,将 [i, j] 拆成 [i, k][k+1, j]
  • 顺序:按区间长度从小到大递推(先算小区间,再合并成大区间)

通用模板:

# 枚举区间长度
for length in range(2, n + 1):

    # 枚举左端点
    for i in range(0, n - length + 1):

        # 右端点
        j = i + length - 1             

        # 枚举分割点
        for k in range(i, j):

            # 更新状态
            dp[i][j] = optimal(dp[i][j], dp[i][k]  dp[k+1][j]  cost)

典型问题:

问题 状态含义 合并方式
石子合并 区间合并的最小代价 dp[i][k] + dp[k+1][j] + sum(i,j)
矩阵链乘 最少乘法次数 dp[i][k] + dp[k+1][j] + p[i]*p[k+1]*p[j+1]
戳气球 最大硬币数 dp[i][k] + dp[k][j] + nums[i]*nums[k]*nums[j]
最长回文子序列 最长回文长度 端点匹配 / 不匹配分类讨论

经典问题(石子合并

问题描述

n 堆石子排成一行,每堆石子有一个重量 a[i]。每次可以合并相邻两堆,代价为这两堆石子的重量之和。经过 n-1 次合并后,所有石子合为一堆。求最小总代价

示例:石子重量为 [1, 3, 5, 2]

合并 [1,3] → 代价 4,得到 [4, 5, 2]
合并 [5,2] → 代价 7,得到 [4, 7]
合并 [4,7] → 代价 11,得到 [11]
总代价 = 4 + 7 + 11 = 22

思路分析

最后一次合并一定是将某个 [i, k][k+1, j] 合为整个区间 [i, j]。因此可以枚举最后一次合并的分割点 k,递归地求解子区间。

graph TD
    A["[1, 3, 5, 2]<br/>dp[0][3]"] --> B["[1, 3, 5] + [2]<br/>dp[0][2] + dp[3][3]"]
    A --> C["[1, 3] + [5, 2]<br/>dp[0][1] + dp[2][3]"]
    A --> D["[1] + [3, 5, 2]<br/>dp[0][0] + dp[1][3]"]

    C --> E["[1]+[3]<br/>dp[0][0]+dp[1][1]"]
    C --> F["[5]+[2]<br/>dp[2][2]+dp[3][3]"]

    classDef blue fill:#4aa3df,color:#fff,stroke:#4aa3df;
    classDef green fill:#7db67d,color:#fff,stroke:#7db67d;
    classDef orange fill:#f28c52,color:#fff,stroke:#f28c52;

    class A blue;
    class B,C,D green;
    class E,F orange;

状态定义

dp[i][j] = 合并区间 [i, j] 内所有石子的最小代价

  • 基础状态:dp[i][i] = 0(单堆石子无需合并)

状态转移方程

\[ dp[i][j] = \min_{i \le k < j} \{ dp[i][k] + dp[k+1][j] \} + \text{sum}(i, j) \]

其中 \(\text{sum}(i, j) = \sum_{t=i}^{j} a[t]\) 是区间 [i, j] 的石子总重量,可用前缀和 \(O(1)\) 求出。

遍历顺序

计算 dp[i][j] 时需要用到所有比 [i, j] 更短的子区间。因此必须按区间长度从小到大递推:先算长度 1(基础),再算长度 2、3、...、n。

代码实现

def stone_merge(stones: list[int]) -> int:
    n = len(stones)

    prefix = [0] * (n + 1)
    for i in range(n):
        prefix[i + 1] = prefix[i] + stones[i]

    dp = [[0] * n for _ in range(n)]

    for length in range(2, n + 1):
        for i in range(n - length + 1):
            j = i + length - 1
            dp[i][j] = float('inf')
            for k in range(i, j):
                dp[i][j] = min(dp[i][j], dp[i][k] + dp[k + 1][j] + prefix[j + 1] - prefix[i])

    return dp[0][n - 1]
int stone_merge(vector<int>& stones) {
    int n = stones.size();
    vector<int> prefix(n + 1, 0);
    for (int i = 0; i < n; i++)
        prefix[i + 1] = prefix[i] + stones[i];

    vector<vector<int>> dp(n, vector<int>(n, 0));

    for (int len = 2; len <= n; len++) {
        for (int i = 0; i + len - 1 < n; i++) {
            int j = i + len - 1;
            dp[i][j] = INT_MAX;
            for (int k = i; k < j; k++) {
                dp[i][j] = min(dp[i][j], dp[i][k] + dp[k + 1][j] + prefix[j + 1] - prefix[i]);
            }
        }
    }

    return dp[0][n - 1];
}
fn stone_merge(stones: &[i32]) -> i32 {
    let n = stones.len();
    let mut prefix = vec![0; n + 1];
    for i in 0..n {
        prefix[i + 1] = prefix[i] + stones[i];
    }

    let mut dp = vec![vec![0; n]; n];

    for len in 2..=n {
        for i in 0..=n - len {
            let j = i + len - 1;
            dp[i][j] = i32::MAX;
            for k in i..j {
                dp[i][j] = dp[i][j].min(dp[i][k] + dp[k + 1][j] + prefix[j + 1] - prefix[i]);
            }
        }
    }

    dp[0][n - 1]
}
  • 时间复杂度:\(O(n^3)\),三层循环(长度、左端点、分割点)
  • 空间复杂度:\(O(n^2)\)

四边形不等式优化

核心思想

朴素区间 DP 的瓶颈在于枚举分割点 k,范围为 [i, j-1]。如果代价函数满足四边形不等式,则最优分割点 opt[i][j] 具有单调性:

\[ opt[i][j-1] \le opt[i][j] \le opt[i+1][j] \]

利用这一性质,可以将分割点 k 的搜索范围从 [i, j-1] 缩小到 [opt[i][j-1], opt[i+1][j]],总时间复杂度降为 \(O(n^2)\)

石子合并问题的代价函数 \(w(i, j) = \text{sum}(i, j)\) 满足四边形不等式,因此可以使用此优化。

def stone_merge_optimized(stones: list[int]) -> int:
    n = len(stones)
    prefix = [0] * (n + 1)
    for i in range(n):
        prefix[i + 1] = prefix[i] + stones[i]

    dp = [[0] * n for _ in range(n)]
    opt = [[0] * n for _ in range(n)]  # 最优分割点

    for i in range(n):
        opt[i][i] = i

    for length in range(2, n + 1):
        for i in range(n - length + 1):
            j = i + length - 1
            dp[i][j] = float('inf')
            lo, hi = opt[i][j - 1], opt[i + 1][j] if i + 1 < n else j - 1
            for k in range(lo, min(hi, j - 1) + 1):
                cost = dp[i][k] + dp[k + 1][j] + prefix[j + 1] - prefix[i]
                if cost < dp[i][j]:
                    dp[i][j] = cost
                    opt[i][j] = k

    return dp[0][n - 1]
int stone_merge_optimized(vector<int>& stones) {
    int n = stones.size();
    vector<int> prefix(n + 1, 0);
    for (int i = 0; i < n; i++)
        prefix[i + 1] = prefix[i] + stones[i];

    vector<vector<int>> dp(n, vector<int>(n, 0));
    vector<vector<int>> opt(n, vector<int>(n, 0));

    for (int i = 0; i < n; i++)
        opt[i][i] = i;

    for (int len = 2; len <= n; len++) {
        for (int i = 0; i + len - 1 < n; i++) {
            int j = i + len - 1;
            dp[i][j] = INT_MAX;
            int lo = opt[i][j - 1];
            int hi = (i + 1 < n) ? opt[i + 1][j] : j - 1;
            for (int k = lo; k <= min(hi, j - 1); k++) {
                int cost = dp[i][k] + dp[k + 1][j] + prefix[j + 1] - prefix[i];
                if (cost < dp[i][j]) {
                    dp[i][j] = cost;
                    opt[i][j] = k;
                }
            }
        }
    }

    return dp[0][n - 1];
}
fn stone_merge_optimized(stones: &[i32]) -> i32 {
    let n = stones.len();
    let mut prefix = vec![0; n + 1];
    for i in 0..n {
        prefix[i + 1] = prefix[i] + stones[i];
    }

    let mut dp = vec![vec![0; n]; n];
    let mut opt = vec![vec![0usize; n]; n];

    for i in 0..n {
        opt[i][i] = i;
    }

    for len in 2..=n {
        for i in 0..=n - len {
            let j = i + len - 1;
            dp[i][j] = i32::MAX;
            let lo = opt[i][j - 1];
            let hi = if i + 1 < n { opt[i + 1][j] } else { j - 1 };
            for k in lo..=hi.min(j - 1) {
                let cost = dp[i][k] + dp[k + 1][j] + prefix[j + 1] - prefix[i];
                if cost < dp[i][j] {
                    dp[i][j] = cost;
                    opt[i][j] = k;
                }
            }
        }
    }

    dp[0][n - 1]
}
  • 时间复杂度:\(O(n^2)\)
  • 空间复杂度:\(O(n^2)\)

经典问题(最长回文子序列

问题描述

给定字符串 s,求其最长回文子序列的长度。子序列可以不连续。

例如:s = "bbbab" → 最长回文子序列为 "bbbb",长度为 4。

状态定义

dp[i][j] = 字符串 s[i..j] 中最长回文子序列的长度

  • 基础状态:dp[i][i] = 1(单个字符是回文)

状态转移方程

\[ dp[i][j] = \begin{cases} dp[i+1][j-1] + 2 & \text{if } s[i] = s[j] \\ \max(dp[i+1][j],\ dp[i][j-1]) & \text{if } s[i] \ne s[j] \end{cases} \]

为什么是区间 DP

状态 dp[i][j] 定义在区间 [i, j] 上,且转移依赖更小的区间 [i+1, j-1][i+1, j][i, j-1]。需要按区间长度从小到大计算。

def longest_palindrome_subseq(s: str) -> int:
    n = len(s)
    dp = [[0] * n for _ in range(n)]

    for i in range(n):
        dp[i][i] = 1

    for length in range(2, n + 1):
        for i in range(n - length + 1):
            j = i + length - 1
            if s[i] == s[j]:
                dp[i][j] = dp[i + 1][j - 1] + 2
            else:
                dp[i][j] = max(dp[i + 1][j], dp[i][j - 1])

    return dp[0][n - 1]
int longest_palindrome_subseq(string& s) {
    int n = s.size();
    vector<vector<int>> dp(n, vector<int>(n, 0));

    for (int i = 0; i < n; i++)
        dp[i][i] = 1;

    for (int len = 2; len <= n; len++) {
        for (int i = 0; i + len - 1 < n; i++) {
            int j = i + len - 1;
            if (s[i] == s[j])
                dp[i][j] = dp[i + 1][j - 1] + 2;
            else
                dp[i][j] = max(dp[i + 1][j], dp[i][j - 1]);
        }
    }

    return dp[0][n - 1];
}
fn longest_palindrome_subseq(s: &str) -> i32 {
    let s: Vec<char> = s.chars().collect();
    let n = s.len();
    let mut dp = vec![vec![0; n]; n];

    for i in 0..n {
        dp[i][i] = 1;
    }

    for len in 2..=n {
        for i in 0..=n - len {
            let j = i + len - 1;
            if s[i] == s[j] {
                dp[i][j] = dp[i + 1][j - 1] + 2;
            } else {
                dp[i][j] = dp[i + 1][j].max(dp[i][j - 1]);
            }
        }
    }

    dp[0][n - 1]
}
  • 时间复杂度:\(O(n^2)\)
  • 空间复杂度:\(O(n^2)\)

经典问题(戳气球

问题描述

n 个气球,每个气球上标有数字 nums[i]。每次戳破一个气球 i,获得 nums[i-1] * nums[i] * nums[i+1] 枚硬币(越界视为 1)。求戳破所有气球能获得的最大硬币数

思路分析

逆向思维:最后一个被戳破的气球

直接考虑"先戳哪个"会导致边界不断变化,难以建立子问题。

逆向思考:枚举区间 (i, j)最后一个被戳破的气球 k。 此时 k 被戳破时,左右邻居一定是 ij(因为区间内其他气球都已经被戳了),代价为 nums[i] * nums[k] * nums[j]

状态定义

在原数组两端添加虚拟气球 1nums = [1] + nums + [1]

dp[i][j] = 戳破开区间 (i, j) 内所有气球能获得的最大硬币数

  • 基础状态:dp[i][i+1] = 0(开区间内没有气球)

状态转移方程

枚举 (i, j) 中最后一个被戳破的气球 k

\[ dp[i][j] = \max_{i < k < j} \{ dp[i][k] + dp[k][j] + nums[i] \cdot nums[k] \cdot nums[j] \} \]
def max_coins(nums: list[int]) -> int:
    nums = [1] + nums + [1]
    n = len(nums)
    dp = [[0] * n for _ in range(n)]

    for length in range(3, n + 1):  # 开区间长度至少为 3 才有气球
        for i in range(n - length + 1):
            j = i + length - 1
            for k in range(i + 1, j):
                dp[i][j] = max(dp[i][j], dp[i][k] + dp[k][j] + nums[i] * nums[k] * nums[j])

    return dp[0][n - 1]
int max_coins(vector<int>& nums) {
    nums.insert(nums.begin(), 1);
    nums.push_back(1);
    int n = nums.size();
    vector<vector<int>> dp(n, vector<int>(n, 0));

    for (int len = 3; len <= n; len++) {
        for (int i = 0; i + len - 1 < n; i++) {
            int j = i + len - 1;
            for (int k = i + 1; k < j; k++) {
                dp[i][j] = max(dp[i][j], dp[i][k] + dp[k][j] + nums[i] * nums[k] * nums[j]);
            }
        }
    }

    return dp[0][n - 1];
}
fn max_coins(nums: &[i32]) -> i32 {
    let mut a = vec![1];
    a.extend_from_slice(nums);
    a.push(1);
    let n = a.len();
    let mut dp = vec![vec![0; n]; n];

    for len in 3..=n {
        for i in 0..=n - len {
            let j = i + len - 1;
            for k in (i + 1)..j {
                dp[i][j] = dp[i][j].max(dp[i][k] + dp[k][j] + a[i] * a[k] * a[j]);
            }
        }
    }

    dp[0][n - 1]
}
  • 时间复杂度:\(O(n^3)\)
  • 空间复杂度:\(O(n^2)\)

复杂度对比总结

问题 时间复杂度 空间复杂度 分割点含义
石子合并(朴素) \(O(n^3)\) \(O(n^2)\) 最后一次合并的位置
石子合并(四边形不等式) \(O(n^2)\) \(O(n^2)\) 利用最优分割点单调性
最长回文子序列 \(O(n^2)\) \(O(n^2)\) 端点匹配 / 不匹配
戳气球 \(O(n^3)\) \(O(n^2)\) 最后一个被戳破的气球