「BZOJ4518」[SDOI2016]征途

Description

Pine 开始了从SS地到TT地的征途。

SS地到TT地的路可以划分成nn段,相邻两段路的分界点设有休息站。
Pine 计划用mm天到达TT地。除第mm天外,每一天晚上 Pine 都必须在休息站过夜。所以,一段路必须在同一天中走完。
Pine 希望每一天走的路长度尽可能相近,所以他希望每一天走的路的长度的方差尽可能小。

帮助 Pine 求出最小方差是多少。

设方差是vv,可以证明,v×m2v \times m ^ 2是一个整数。为了避免精度误差,输出结果时输出v×m2v \times m ^ 2

Input

第一行两个数nnmm
第二行nn个数,表示nn段路的长度。

Output

一个数,最小方差乘以m2m ^ 2后的值。

Sample Input

1
2
5 2
1 2 5 8 6

Sample Output

1
36

HINT

对于30%30\%的数据,1n101 \leq n \leq 10
对于60%60\%的数据,1n1001 \leq n \leq 100
对于100%100\%的数据,1n30001 \leq n \leq 3000
保证从SSTT的总路程不超过3000030000

Solution

xˉ=i=1mxi\bar{x}=\sum_{i=1}^m x_i

ans=min{i=1m(xˉxi)m}m2ans=\min\{\frac{\sum_{i=1}^m(\bar{x}-x_i)}{m}\}\cdot m^2

min\min去掉,化简一下得到:

ans=mi=1mxi2(i=1mxi)2ans=m\sum_{i=1}^mx_i^2-(\sum_{i=1}^m x_i)^2

同时有(i=1mxi)=(i=1nxi)(\sum_{i=1}^m x_i)=(\sum_{i=1}^n x_i),前后两个xx含义不同。
fj,if_{j,i}表示前ii个数,划分成jj个集合的最小值。
转移用斜率优化一下。
观察(i=1nxi)2-(\sum_{i=1}^n x_i)^2是个定值,最后考虑就行了。
第一项中mm是个定值,所以我们维护i=1mxi2\sum_{i=1}^mx_i^2的最大值即时答案。
假设从fj1,kf_{j-1,k}转移过来。

fj,i=fj1,k+(sisk)2f_{j,i}=f_{j-1,k}+(s_i-s_k)^2

标准的斜率优化板子!
顺便滚存,注意初值f0,i=inf(1in)f_{0,i}=inf(1\leq i\leq n)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
#include<cstdio>
#include<algorithm>
#include<iostream>
#include<cstring>
#include<map>
#include<queue>
#include<bitset>
#define mk make_pair
#define fi first
#define nd second
#define pii pair<int,int>
#define pb push_back
#define sqr(x) ((x)*(x))
using namespace std;
typedef long long ll;
inline ll read() {ll x = 0; char ch = getchar(), w = 1;while(ch < '0' || ch > '9') {if(ch == '-') w = -1;
ch = getchar();}while(ch >= '0' && ch <= '9') {x = x * 10 + ch - '0';ch = getchar();}return x * w;}
void write(ll x) {if(x < 0) putchar('-'), x = -x;if(x > 9) write(x / 10);putchar(x % 10 + '0');}
inline void writeln(ll x) {write(x);puts("");}

const int N = 3100;
ll f[2][N];
int n, m;
ll a[N], s[N];
int q[N], h, t;
bool c;
double F(int x) {
return s[x] * s[x] + f[c ^ 1][x];
}
double slope(int x, int y) {
return (double) (F(y) - F(x)) / (s[y] - s[x]);
}
int main() {
n = read(), m = read();
for(int i = 1; i <= n; ++i) a[i] = read(), s[i] = s[i - 1] + a[i], f[0][i] = 1e16;
for(int j = 1; j <= m; ++j) {
q[h = t = 1] = 0;
c ^= 1;
memset(f[c],0,sizeof f[c]);
for(int i = 1; i <= n; ++i) {
while(h < t && slope(q[h], q[h + 1]) <= 2 * s[i]) ++h;
f[c][i] = f[c ^ 1][q[h]] + (s[i] - s[q[h]]) * (s[i] - s[q[h]]);
while(h < t && slope(q[t - 1], q[t]) >= slope(q[t-1], i)) --t;
q[++t] = i;
}
}
writeln(m * f[c][n] - s[n]*s[n]);
return 0;
}