「BZOJ2734」[HNOI2012]集合选数

Description

集合论与图论》这门课程有一道作业题,要求同学们求出{1,2,3,4,5}\{1, 2, 3, 4, 5\}的所有满足以下条件的子集:若 xx 在该子集中,则 2x2x3x3x 不能在该子集中。同学们不喜欢这种具有枚举性 质的题目,于是把它变成了以下问题:对于任意一个正整数 n100000n\leq 100000,如何求出{1,2,,n}\{1, 2,\cdots , n\} 的满足上述约束条件的子集的个数(只需输出对 1,000,000,0011,000,000,001 取模的结果),现在这个问题就交给你了。

Input

只有一行,其中有一个正整数 nn

Output

仅包含一个正整数,表示{1,2,,n}\{1, 2,\cdots , n\}有多少个满足上述约束条件 的子集。

Sample Input

1
4

Sample Output

1
8

Solution

[T3T9T3m0T2T3×2T9×2T3m1×2T3mn×2nT]\begin{bmatrix}T & 3T&9T&\cdots&3^{m_0}T \\2T & 3\times2T &9\times 2T&\cdots &3^{m_1}\times2T\\ \cdots&\cdots&\dots&\cdots&3^{m_{n}}\times2^{n}T\end{bmatrix}

构造矩阵。那么枚举一次TT,就是在这个矩阵里选数,且选的位置不会上下左右相邻。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
#include<cstdio>
#include<algorithm>
#include<iostream>
#include<cstring>
#include<map>
#include<cmath>
#include<queue>
#include<bitset>
#define mk make_pair
#define fi first
#define nd second
#define pii pair<int,int>
#define pb push_back
#define sqr(x) ((x)*(x))
using namespace std;
typedef long long ll;
inline ll read() {ll x = 0; char ch = getchar(), w = 1;while(ch < '0' || ch > '9') {if(ch == '-') w = -1;
ch = getchar();}while(ch >= '0' && ch <= '9') {x = x * 10 + ch - '0';ch = getchar();}return x * w;}
void write(ll x) {if(x < 0) putchar('-'), x = -x;if(x > 9) write(x / 10);putchar(x % 10 + '0');}
inline void writeln(ll x) {write(x);puts("");}
/*
16.6096404744
10.4795163714
*/
int a[200][103];
int n, m[200], tot;
ll f[2][1 << 15];
const ll p = 1000000001;
void add(ll &x, ll v) {
x += v;
if(x >= p) x %= p;
}
bool vis[1100000];
int bin[200];
int main() {
tot = read();
ll ans = 1;
for(int T = 1; T <= tot; ++T) if(!vis[T]) {
int t = T;
for(n = 0;; ++n) {
a[n][0] = t;
m[n] = -1;
for(int j = t; j <= tot; j*=3) {
a[n][++m[n]] = j; //暴力算,不要用log
}
bin[n] = 1 << (m[n]+1);
t *= 2;
if(t > tot) break;
}
for(int i = 0; i <= n; ++i)
for(int j = 0; j <= m[i]; ++j) vis[a[i][j]] = 1;
t = 0;
memset(f[t],0,sizeof f[t]);
for(int i = 0; i < bin[0]; ++i) if(!(i & (i >>1)) && !(i & (i <<1)))f[t][i] = 1;
for(int i = 1; i <= n; ++i) {
t ^= 1;
memset(f[t],0,sizeof f[t]);
for(int j = 0; j < bin[i]; ++j)
for(int k = 0; k < bin[i - 1]; ++k)
if(!(j & k) && !(j & (j << 1)) && !(j & (j >> 1)))
add(f[t][j], f[t ^ 1][k]);
}
ll res = 0;
for(int i = 0; i < bin[n]; ++i) add(res, f[t][i]);
ans = ans * res % p;
}
writeln(ans);
return 0;
}