递归与递推

递归实现指数型枚举

92. 递归实现指数型枚举

1
2
3
4
5
6
7
8
9
10
11
n = int(input())
# Bit (n - k) of mask says whether k is in the subset.  Counting the masks
# upward reproduces the "skip k before taking k" order of the classic
# recursive enumeration, starting with the empty subset (an empty line).
for mask in range(1 << n):
    chosen = [k for k in range(1, n + 1) if mask >> (n - k) & 1]
    print(' '.join(map(str, chosen)))

递归实现排列型枚举

94. 递归实现排列型枚举

1
2
3
4
5
6
7
8
9
10
11
12
13
14
n = int(input())
path = [0] * (n + 1)        # path[pos] = number placed at position pos (1-based)
taken = [False] * (n + 1)   # taken[v] is True while v is already on the path


def dfs(u):
    """Fill positions u..n with unused numbers, printing each full permutation
    in lexicographic order."""
    if u > n:
        print(' '.join(map(str, path[1:])))
        return
    for v in range(1, n + 1):
        if taken[v]:
            continue
        path[u] = v
        taken[v] = True
        dfs(u + 1)
        taken[v] = False


dfs(1)

递归实现组合型枚举

93. 递归实现组合型枚举

1
2
3
4
5
6
7
8
9
10
11
12
n, m = map(int, input().split())
combo = [0] * (m + 1)   # combo[pos] = the pos-th chosen number (1-based)


def dfs(u, start):
    """Choose the numbers for positions u..m from start..n in increasing order."""
    # Prune: only n - start + 1 numbers remain but m - u + 1 are still needed.
    if n - start + 1 < m - u + 1:
        return
    if u == m + 1:
        print(' '.join(map(str, combo[1:])))
        return
    for v in range(start, n + 1):
        combo[u] = v
        dfs(u + 1, v + 1)


dfs(1, 1)

带分数

1209. 带分数

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
n = int(input())
used = [False] * 10          # used[d]: digit d already appears in a or c
res = 0
DIGITS = set('123456789')


def check(a, c):
    """True when a, b = (n - a) * c and c together use each digit 1-9 exactly once."""
    b = (n - a) * c
    return set(f'{a}{b}{c}') == DIGITS


def dfs_c(u, a, c):
    """With the integer part a fixed, grow the denominator c digit by digit."""
    global res
    b = (n - a) * c
    if len(f'{a}{b}{c}') > 9:    # already too many digits: no extension can fit
        return
    res += check(a, c)
    for d in range(1, 10):
        if used[d]:
            continue
        used[d] = True
        dfs_c(u + 1, a, c * 10 + d)
        used[d] = False


def dfs_a(u, a):
    """Grow the integer part a digit by digit; search denominators for each a < n."""
    if a >= n:
        return
    dfs_c(u, a, 0)
    for d in range(1, 10):
        if used[d]:
            continue
        used[d] = True
        dfs_a(u + 1, a * 10 + d)
        used[d] = False


dfs_a(0, 0)
print(res)

简单斐波那契

717. 简单斐波那契

1
2
3
4
5
n = int(input())
fib = [0, 1]
# Grow the sequence until it holds at least n terms, then print the first n.
while len(fib) < n:
    fib.append(fib[-1] + fib[-2])
print(' '.join(map(str, fib[:n])))

费解的开关

95. 费解的开关

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
t = int(input())
# Offsets of a cell and its four orthogonal neighbours.
dirs = [(0, 0), (1, 0), (0, 1), (-1, 0), (0, -1)]


def turn(a, b):
    """Press switch (a, b): flip it and every in-bounds neighbour in the global g."""
    for dx, dy in dirs:
        x, y = a + dx, b + dy
        if 0 <= x < 5 and 0 <= y < 5:
            g[x][y] = '1' if g[x][y] == '0' else '0'


for case in range(t):
    g = [list(input()) for _ in range(5)]
    if case < t - 1:
        input()                      # blank line separating boards
    original = [row[:] for row in g]
    best = 10
    for first in range(32):
        steps = 0
        # The presses in row 0 are chosen by the bits of `first`.
        for col in range(5):
            if first >> col & 1:
                turn(0, col)
                steps += 1
        # Every later press is forced: a dark cell in row r is fixed from row r + 1.
        for row in range(4):
            for col in range(5):
                if g[row][col] == '0':
                    turn(row + 1, col)
                    steps += 1
        if '0' not in g[4]:
            best = min(best, steps)
        g = [row[:] for row in original]
    print(-1 if best > 6 else best)

翻硬币

1208. 翻硬币

1
2
3
4
5
6
7
a, b = list(input()), list(input())
flips = 0
for i in range(len(a) - 1):
    if a[i] == b[i]:
        continue
    # Flip coins i and i + 1.  Coin i is never inspected again, so only i + 1 is updated.
    a[i + 1] = 'o' if a[i + 1] == '*' else '*'
    flips += 1
print(flips)

飞行员兄弟

116. 飞行员兄弟

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
g = [list(input()) for _ in range(4)]
# Board as a 16-bit mask: bit 4*i + j is set while handle (i, j) is closed ('+').
start = 0
for i in range(4):
    for j in range(4):
        if g[i][j] == '+':
            start |= 1 << (4 * i + j)
# effect[4*i + j]: the handles flipped by operating (i, j).  That is its whole
# row and column, with (i, j) itself flipped exactly once.
effect = []
for i in range(4):
    for j in range(4):
        mask = 0
        for k in range(4):
            mask |= 1 << (4 * i + k)
            mask |= 1 << (4 * k + j)
        effect.append(mask)
res = None
for x in range(1 << 16):
    state = start
    tmp = []
    for bit in range(16):
        if x >> bit & 1:
            state ^= effect[bit]
            tmp.append((bit // 4 + 1, bit % 4 + 1))
    # Keep the FEWEST presses; on ties, the first x enumerated wins.
    if state == 0 and (res is None or len(tmp) < len(res)):
        res = tmp
if res is None:
    res = []
print(len(res))
for i, j in res:
    print(i, j)

二分与前缀和

数的范围

789. 数的范围

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
n, q = map(int, input().split())
nums = list(map(int, input().split()))


def lower(k):
    """Smallest index whose value is >= k (n - 1 if there is none)."""
    lo, hi = 0, n - 1
    while lo < hi:
        mid = (lo + hi) // 2
        if nums[mid] >= k:
            hi = mid
        else:
            lo = mid + 1
    return lo


def upper(k, lo):
    """Largest index in [lo, n - 1] whose value is <= k."""
    hi = n - 1
    while lo < hi:
        mid = (lo + hi + 1) // 2
        if nums[mid] <= k:
            lo = mid
        else:
            hi = mid - 1
    return lo


for _ in range(q):
    k = int(input())
    first = lower(k)
    if nums[first] != k:
        print(-1, -1)
        continue
    print(first, upper(k, first))

bisect

1
2
3
4
5
6
7
8
9
10
11
import bisect
n, q = map(int, input().split())
nums = list(map(int, input().split()))
for _ in range(q):
    k = int(input())
    lo, hi = bisect.bisect_left(nums, k), bisect.bisect_right(nums, k)
    # [lo, hi) is the run of values equal to k; it is empty when k is absent.
    print(*((lo, hi - 1) if lo < hi else (-1, -1)))

数的三次方根

790. 数的三次方根

1
2
3
4
5
6
7
8
9
n = float(input())
lo, hi = -10000, 10000
# Bisection on the monotone function x**3; stop once the bracket is narrower than 1e-8.
while hi - lo > 1e-8:
    mid = (lo + hi) / 2
    if mid * mid * mid > n:
        hi = mid
    else:
        lo = mid
print(f'{lo:.6f}')

机器人跳跃问题

730. 机器人跳跃问题

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
n = int(input())
nums = list(map(int, input().split()))
h = max(nums)


def check(e):
    """True if starting with energy e, the energy never drops below 0."""
    for num in nums:
        e += e - num          # e becomes 2e - height, whether jumping up or down
        if e > h:
            return True       # above every building, energy can only grow
        if e < 0:
            return False
    return True


lo, hi = 0, h
while lo < hi:
    mid = (lo + hi) // 2
    if check(mid):
        hi = mid
    else:
        lo = mid + 1
print(lo)

四平方和

1221. 四平方和

暴力

1
2
3
4
5
6
7
8
9
10
import math
n = int(input())


def fun():
    """Return (a, b, c, d) with a <= b <= c <= d and a*a + b*b + c*c + d*d == n.

    a, then b, then c are tried in increasing order, and d is derived from the
    remainder.  Returns None if no split is found.
    """
    for a in range(int((n / 4) ** 0.5) + 1):
        rest_a = n - a * a
        for b in range(a, int((rest_a / 3) ** 0.5) + 1):
            rest_b = rest_a - b * b
            for c in range(b, int((rest_b / 2) ** 0.5) + 1):
                rest_c = rest_b - c * c
                d = int(math.sqrt(rest_c))
                if d * d == rest_c:
                    return a, b, c, d
    return None


print(*fun())

哈希表

1
2
3
4
5
6
7
8
9
10
11
12
13
n = int(input())


def _four_squares(n):
    """Return (a, b, c, d) with a*a + b*b + c*c + d*d == n, or None.

    Pairs (c, d) with c <= d are tabulated by c*c + d*d, keeping the first
    (smallest c) pair for each sum.  Then (a, b) is searched in increasing order.
    """
    first = {}
    for c in range(int(n ** 0.5) + 1):
        # d starts at c: a pair with d < c is never the first stored for its sum,
        # because (d, c) has the same sum and a smaller leading element.
        for d in range(c, int((n - c * c) ** 0.5) + 1):
            first.setdefault(c * c + d * d, (c, d))
    for a in range(int(n ** 0.5) + 1):
        for b in range(int((n - a * a) ** 0.5) + 1):
            rest = n - a * a - b * b
            if rest in first:
                c, d = first[rest]
                return a, b, c, d
    return None


# Returning from a function replaces exit(), which only exists when `site` is loaded.
answer = _four_squares(n)
if answer is not None:
    print(*answer)

分巧克力

1227. 分巧克力

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
n, k = map(int, input().split())
chocolates = [list(map(int, input().split())) for _ in range(n)]


def check(x):
    """True if cutting x-by-x squares yields at least k pieces."""
    pieces = 0
    for choc in chocolates:
        pieces += (choc[0] // x) * (choc[1] // x)
        if pieces >= k:
            return True
    return False


lo, hi = 1, 100000
while lo < hi:
    mid = (lo + hi + 1) // 2
    if check(mid):
        lo = mid
    else:
        hi = mid - 1
print(lo)

前缀和

795. 前缀和

1
2
3
4
5
6
7
8
n, m = map(int, input().split())
nums = list(map(int, input().split()))
prefix = [0] * (n + 1)      # prefix[i] = nums[0] + ... + nums[i - 1]
for i in range(1, n + 1):
    prefix[i] = prefix[i - 1] + nums[i - 1]
for _ in range(m):
    l, r = map(int, input().split())
    print(prefix[r] - prefix[l - 1])

子矩阵的和

796. 子矩阵的和

1
2
3
4
5
6
7
8
9
10
11
from sys import stdin
input = lambda: stdin.readline().strip()
n, m, q = map(int, input().split())
# 1-based grid, padded with a zero row and a zero column.
grid = [[0] * (m + 1)]
for _ in range(n):
    grid.append([0] + list(map(int, input().split())))
pre = [row[:] for row in grid]
for i in range(1, n + 1):
    row, up = pre[i], pre[i - 1]
    for j in range(1, m + 1):
        row[j] = up[j] + row[j - 1] - up[j - 1] + grid[i][j]
for _ in range(q):
    x1, y1, x2, y2 = map(int, input().split())
    print(pre[x2][y2] - pre[x1 - 1][y2] - pre[x2][y1 - 1] + pre[x1 - 1][y1 - 1])

激光炸弹

99. 激光炸弹 - AcWing题

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
cnt, r = map(int, input().split())
SIZE = 5002
values = [[0] * SIZE for _ in range(SIZE)]
r = min(r, SIZE - 1)          # a square wider than the map covers all of it
n = m = r                     # scan at least r rows/cols so the first window fits
for _ in range(cnt):
    x, y, w = map(int, input().split())
    x, y = x + 1, y + 1       # shift to 1-based so row/col 0 can act as padding
    n, m = max(n, x), max(m, y)
    values[x][y] += w
for i in range(1, n + 1):
    cur, prev = values[i], values[i - 1]
    for j in range(1, m + 1):
        cur[j] += prev[j] + cur[j - 1] - prev[j - 1]
res = 0
for i in range(r, n + 1):
    for j in range(r, m + 1):
        res = max(res, values[i][j] - values[i - r][j] - values[i][j - r] + values[i - r][j - r])
print(res)

K倍区间

1230. K倍区间

1
2
3
4
5
6
7
8
9
10
n, k = map(int, input().split())
sums = [0] + [int(input()) for _ in range(n)]
for i in range(1, n + 1):
    sums[i] += sums[i - 1]
# cnt[v] = number of earlier prefix sums congruent to v (mod k).  Residues lie
# in [0, k), and k may exceed n + 1, so the table must have k slots.
cnt = [0] * k
res = 0
for s in sums:
    res += cnt[s % k]
    cnt[s % k] += 1
print(res)

数学与简单DP

买不到的数目

AcWing 1205. 买不到的数目

暴力

1
2
3
4
5
6
7
8
9
10
11
12
13
14
n, m = map(int, input().split())


def check(i, n, m):
    """True if i can be written as a*n + b*m with a, b >= 0 (plain recursion)."""
    if i == 0:
        return True
    return (i >= n and check(i - n, n, m)) or (i >= m and check(i - m, n, m))


res = 0
for i in range(1, 1000):
    if not check(i, n, m):
        res = i     # keep the largest unreachable amount seen so far
print(res)

优化

1
2
3
4
5
6
7
8
9
10
11
12
13
n, m = map(int, input().split())
LIMIT = 1000001
reachable = [False] * LIMIT   # reachable[i]: i is a non-negative combination of n and m
reachable[0] = True
for i in range(1, LIMIT):
    reachable[i] = (i >= m and reachable[i - m]) or (i >= n and reachable[i - n])
res = 0
for i in range(1, LIMIT):
    if not reachable[i]:
        res = i
print(res)

公式

1
2
n, m = map(int, input().split())
# Largest unreachable amount for two coprime denominations: n*m - n - m.
print(n * m - n - m)

蚂蚁感冒

1211. 蚂蚁感冒

1
2
3
4
5
6
7
8
9
10
11
12
n = int(input())
nums = list(map(int, input().split()))
first = nums[0]
others = nums[1:n]
# Ants to the left walking right, and ants to the right walking left, meet the sick one.
left = sum(1 for x in others if abs(first) > abs(x) and x > 0)
right = sum(1 for x in others if abs(first) < abs(x) and x < 0)
if (first > 0 and right == 0) or (first < 0 and left == 0):
    print(1)
else:
    print(1 + left + right)

饮料换购

1216. 饮料换购

1
2
3
4
5
6
n = int(input())
total = caps = n
while caps >= 3:
    new = caps // 3          # every 3 caps buy one more drink
    total += new
    caps = new + caps % 3
print(total)

背包问题

2. 01背包问题

1
2
3
4
5
6
7
8
9
n, v = map(int, input().split())
items = [list(map(int, input().split())) for _ in range(n)]
# dp[i][j]: best value using the first i items with capacity j.
dp = [[0] * (v + 1) for _ in range(n + 1)]
for i in range(1, n + 1):
    w, val = items[i - 1][0], items[i - 1][1]
    prev, cur = dp[i - 1], dp[i]
    for j in range(1, v + 1):
        cur[j] = prev[j] if j < w else max(prev[j], prev[j - w] + val)
print(dp[n][v])

摘花生

1015. 摘花生

1
2
3
4
5
6
7
8
9
t = int(input())
for _ in range(t):
    r, c = map(int, input().split())
    g = [list(map(int, input().split())) for _ in range(r)]
    # dp[i][j]: most peanuts collectable on a right/down path ending at (i, j).
    dp = [[0] * (c + 1) for _ in range(r + 1)]
    for i in range(1, r + 1):
        for j in range(1, c + 1):
            dp[i][j] = g[i - 1][j - 1] + max(dp[i - 1][j], dp[i][j - 1])
    print(dp[r][c])

最长上升子序列

895. 最长上升子序列

1
2
3
4
5
6
7
8
n = int(input())
nums = [0] + list(map(int, input().split()))
dp = [1] * (n + 1)   # dp[i]: longest strictly increasing subsequence ending at nums[i]
for i in range(1, n + 1):
    best = max((dp[j] for j in range(1, i) if nums[j] < nums[i]), default=0)
    dp[i] = max(dp[i], best + 1)
print(max(dp))

地宫取宝

1212. 地宫取宝

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
n, m, k = map(int, input().split())
MOD = 1000000007
# Values are shifted by +1 so that 0 can mean "nothing taken yet" on the max-value axis.
g = [[0] * (m + 1)] + [[0] + [int(x) + 1 for x in input().split()] for _ in range(n)]
V = 14   # shifted values are 1..13; index 0 = nothing taken
# dp[i][j][u][v]: ways to reach (i, j) holding u items whose largest shifted value is v.
dp = [[[[0] * V for _ in range(k + 1)] for _ in range(m + 1)] for _ in range(n + 1)]
dp[1][1][0][0] = 1          # skip the entrance treasure
dp[1][1][1][g[1][1]] = 1    # take it
for i in range(1, n + 1):
    for j in range(1, m + 1):
        if i == 1 and j == 1:
            continue
        w = g[i][j]
        up, left, cur = dp[i - 1][j], dp[i][j - 1], dp[i][j]
        for u in range(k + 1):
            for v in range(V):
                # Arrive from above or the left without taking this treasure.
                total = up[u][v] + left[u][v]
                # Take it: allowed only when it beats the previous maximum t < v.
                if u > 0 and v == w:
                    for t in range(v):
                        total += up[u - 1][t] + left[u - 1][t]
                cur[u][v] = total % MOD     # reduce on every path, including "take"
print(sum(dp[n][m][k]) % MOD)

波动数列

1214. 波动数列

1
2
3
4
5
6
7
8
n, s, a, b = map(int, input().split())
MOD = 100000007
# dp[i][j]: number of ways to choose, for steps 1..i, a contribution of +t*a or
# -t*b at step t so that the total is congruent to j (mod n).
dp = [[0] * n for _ in range(n)]
dp[0][0] = 1
for i in range(1, n):
    prev = dp[i - 1]
    dp[i] = [(prev[(j - i * a) % n] + prev[(j + i * b) % n]) % MOD for j in range(n)]
print(dp[n - 1][s % n])

枚举、模拟与排序

连号区间数

1210. 连号区间数

1
2
3
4
5
6
7
8
9
10
11
n = int(input())
nums = list(map(int, input().split()))
res = 0
for i in range(n):
    lo = float('inf')
    hi = float('-inf')
    for j in range(i, n):
        lo = min(lo, nums[j])
        hi = max(hi, nums[j])
        # Assuming distinct values (a permutation), the slice is consecutive iff
        # max - min equals its length minus one.
        if hi - lo == j - i:
            res += 1
print(res)

递增三元组

1236. 递增三元组

前缀和

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
n = int(input())
M = 100002
a = [int(x) + 1 for x in input().split()]   # +1 so value 0 maps to index 1
b = [int(x) + 1 for x in input().split()]
c = [int(x) + 1 for x in input().split()]


def prefix_counts(vals):
    """pre[v] = how many of vals[0:n] are <= v."""
    cnt = [0] * M
    for i in range(n):
        cnt[vals[i]] += 1
    pre = [0] * M
    for v in range(1, M):
        pre[v] = pre[v - 1] + cnt[v]
    return pre


suma, sumc = prefix_counts(a), prefix_counts(c)
# For each middle value: (# a strictly smaller) * (# c strictly larger).
print(sum(suma[x - 1] * (n - sumc[x]) for x in b))

二分

1
2
3
4
5
6
7
8
9
10
11
12
13
import bisect
n = int(input())
a = sorted(map(int, input().split()))
b = list(map(int, input().split()))
c = sorted(map(int, input().split()))
res = 0
for mid in b:
    smaller = bisect.bisect_left(a, mid)            # a-values strictly below mid
    larger = n - bisect.bisect_right(c, mid)        # c-values strictly above mid
    res += smaller * larger
print(res)

特别数的和

1245. 特别数的和

1
2
3
4
5
6
n = int(input())
# Sum every number in 1..n whose decimal form contains a 0, 1, 2 or 9.
print(sum(i for i in range(1, n + 1) if any(d in str(i) for d in '0129')))

错误票据

1204. 错误票据

1
2
3
4
5
6
7
8
9
10
11
12
n = int(input())
ids = []
for _ in range(n):
    ids += map(int, input().split())
ids.sort()
missing = dup = 0
for prev, cur in zip(ids, ids[1:]):
    if cur == prev + 2:
        missing = cur - 1      # a gap of one id
    elif cur == prev:
        dup = cur              # the repeated id
print(missing, dup)

回文日期

466. 回文日期

1
2
3
4
5
6
7
8
9
10
11
12
date1, date2 = input(), input()
days = [0, 31, 28, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31]
# Each valid MMDD fixes the year as its reverse.  Feb 29 needs year 9220, a
# leap year, so that date is added explicitly.
candidates = ['92200229']
for month in range(1, 13):
    for day in range(1, days[month] + 1):
        mmdd = f'{month:02d}{day:02d}'
        candidates.append(mmdd[::-1] + mmdd)
print(sum(1 for d in candidates if date1 <= d <= date2))

移动距离

1219. 移动距离

1
2
3
4
5
6
7
w, m, n = map(int, input().split())


def pos(k):
    """(row, col) of building k (1-based) in a snake layout of width w."""
    row, col = divmod(k - 1, w)
    if row % 2:
        col = w - 1 - col
    return row, col


(x1, y1), (x2, y2) = pos(m), pos(n)
print(abs(x1 - x2) + abs(y1 - y2))

日期问题

1229. 日期问题

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
from time import strptime
l, r, s = '1960-01-01', '2059-12-31', input().strip().split('/')
res = []


def check(date):
    """Collect `date` (YYYY-MM-DD) if it is a real calendar date inside [l, r]."""
    if not l <= date <= r:
        return
    try:
        strptime(date, '%Y-%m-%d')
    except ValueError:          # e.g. month 13 or Feb 30
        return
    res.append(date)


for i in ['19', '20']:
    check(i + s[0] + '-' + s[1] + '-' + s[2])   # read as year/month/day
    check(i + s[2] + '-' + s[1] + '-' + s[0])   # read as day/month/year
    check(i + s[2] + '-' + s[0] + '-' + s[1])   # read as month/day/year
# De-duplicate first, then sort: iterating a set directly loses the ascending order.
for i in sorted(set(res)):
    print(i)

航班时间

1231. 航班时间

1
2
3
4
5
6
7
8
9
10
11
12
13
t = int(input())
for _ in range(t):
    total = 0   # (arrival - departure) summed over both directions, in seconds
    for _ in range(2):
        parts = input().split()
        for i, part in enumerate(parts):
            if i == 2:
                total += int(part[2]) * 24 * 3600    # "(+d)": extra days
            else:
                hms = list(map(int, part.split(':')))
                secs = hms[0] * 3600 + hms[1] * 60 + hms[2]
                total += secs if i % 2 else -secs
    total //= 2     # averaging both directions cancels the time-zone offset
    print(f'{total // 3600:02d}:{total % 3600 // 60:02d}:{total % 60:02d}')

外卖店优先级

1241. 外卖店优先级

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
n, m, t = map(int, input().split())
prio = [0] * (n + 1)      # priority of each shop as of its latest order
cached = [0] * (n + 1)    # 1 while the shop is in the priority cache
last = [0] * (n + 1)      # time of the shop's latest order
orders = sorted(list(map(int, input().split())) for _ in range(m))
for ts, shop in orders:
    gap = ts - last[shop] - 1          # order-free moments since the previous order
    if gap > 0:
        prio[shop] = max(0, prio[shop] - gap)
    if prio[shop] <= 3:
        cached[shop] = 0
    prio[shop] += 2
    last[shop] = ts
    if prio[shop] > 5:
        cached[shop] = 1
for shop in range(1, n + 1):
    prio[shop] = max(0, prio[shop] - (t - last[shop]))
    if prio[shop] <= 3:
        cached[shop] = 0
print(sum(cached))

归并排序

787. 归并排序

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
n = int(input())
nums = list(map(int, input().split()))


def merge_sort(l, r):
    """Sort nums[l..r] (inclusive) in place with top-down merge sort."""
    if l >= r:
        return
    mid = (l + r) // 2
    merge_sort(l, mid)
    merge_sort(mid + 1, r)
    left, right = nums[l:mid + 1], nums[mid + 1:r + 1]
    merged = []
    i = j = 0
    while i < len(left) and j < len(right):
        if left[i] < right[j]:
            merged.append(left[i])
            i += 1
        else:
            merged.append(right[j])
            j += 1
    merged += left[i:]
    merged += right[j:]
    nums[l:r + 1] = merged


merge_sort(0, n - 1)
print(*nums)

逆序对的数量

788. 逆序对的数量

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
n = int(input())
nums = list(map(int, input().split()))


def merge_sort(l, r):
    """Sort nums[l..r] in place and return the number of inversions it contained."""
    if l >= r:
        return 0
    mid = (l + r) // 2
    count = merge_sort(l, mid) + merge_sort(mid + 1, r)
    left, right = nums[l:mid + 1], nums[mid + 1:r + 1]
    merged = []
    i = j = 0
    while i < len(left) and j < len(right):
        if left[i] <= right[j]:
            merged.append(left[i])
            i += 1
        else:
            # right[j] is smaller than every element still waiting in `left`.
            count += len(left) - i
            merged.append(right[j])
            j += 1
    merged += left[i:]
    merged += right[j:]
    nums[l:r + 1] = merged
    return count


print(merge_sort(0, n - 1))

树状数组与线段树

动态求连续区间和

树状数组

1264. 动态求连续区间和

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
n, m = map(int, input().split())
nums = [0] + list(map(int, input().split()))
tree = [0] * (n + 1)   # Fenwick tree over positions 1..n


def lowbit(x):
    """Lowest set bit of x."""
    return x & -x


def add(x, v):
    """Add v at position x."""
    while x <= n:
        tree[x] += v
        x += lowbit(x)


def query(x):
    """Sum of positions 1..x."""
    total = 0
    while x > 0:
        total += tree[x]
        x -= lowbit(x)
    return total


for i in range(1, n + 1):
    add(i, nums[i])
for _ in range(m):
    k, a, b = map(int, input().split())
    if k == 0:
        print(query(b) - query(a - 1))
    else:
        add(a, b)

线段树

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
n, m = map(int, input().split())
nums = [0] + list(map(int, input().split()))
# tree[u] = [left end, right end, sum over that range]; children are 2u and 2u + 1.
tree = [[0, 0, 0] for _ in range(4 * n + 1)]


def pushup(u):
    """Recompute node u's sum from its two children."""
    tree[u][2] = tree[2 * u][2] + tree[2 * u + 1][2]


def build(u, l, r):
    """Create node u for nums[l..r] and its subtree."""
    tree[u] = [l, r, 0]
    if l == r:
        tree[u][2] = nums[l]
        return
    mid = (l + r) // 2
    build(2 * u, l, mid)
    build(2 * u + 1, mid + 1, r)
    pushup(u)


def query(u, l, r):
    """Sum of nums[l..r] intersected with node u's range."""
    lo, hi, total = tree[u]
    if l <= lo and hi <= r:
        return total
    mid = (lo + hi) // 2
    res = 0
    if l <= mid:
        res += query(2 * u, l, r)
    if r > mid:
        res += query(2 * u + 1, l, r)
    return res


def modify(u, x, v):
    """Add v at position x."""
    lo, hi = tree[u][0], tree[u][1]
    if lo == hi:
        tree[u][2] += v
        return
    mid = (lo + hi) // 2
    modify(2 * u if x <= mid else 2 * u + 1, x, v)
    pushup(u)


build(1, 1, n)
for _ in range(m):
    k, a, b = map(int, input().split())
    if k == 0:
        print(query(1, a, b))
    else:
        modify(1, a, b)

数星星

1265. 数星星

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
n = int(input())
SIZE = 32001                 # x + 1 ranges over 1..32001
tree = [0] * (SIZE + 1)
res = [0] * n                # res[level] = number of stars at that level


def lowbit(x):
    """Lowest set bit of x."""
    return x & -x


def add(x):
    """Record one star at (shifted) x."""
    while x <= SIZE:
        tree[x] += 1
        x += lowbit(x)


def query(x):
    """Stars recorded so far with shifted x' <= x."""
    total = 0
    while x > 0:
        total += tree[x]
        x -= lowbit(x)
    return total


for _ in range(n):
    x, _y = map(int, input().split())
    # Assumes input sorted by y then x (as the problem states), so every star
    # counted so far with x' <= x lies to the lower-left of this one.
    res[query(x + 1)] += 1
    add(x + 1)
for count in res:
    print(count)

数列区间最大值

1270. 数列区间最大值

dp(爆空间)

1
2
3
4
5
6
7
8
9
n, m = map(int, input().split())
nums = [0] + list(map(int, input().split()))
# dp[i][j] = max(nums[i..j]).  Quadratic memory: kept only as the naive baseline.
dp = [[0] * (n + 1) for _ in range(n + 1)]
for i in range(1, n + 1):
    dp[i][i] = nums[i]          # seed with the element itself, not 0, so negatives work
    for j in range(i + 1, n + 1):
        dp[i][j] = max(dp[i][j - 1], nums[j])
for _ in range(m):
    x, y = map(int, input().split())
    print(dp[x][y])

线段树(TLE)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
n, m = map(int, input().split())
nums = [0] + list(map(int, input().split()))
# tree[u] = [left end, right end, maximum over that range]; children are 2u and 2u + 1.
tree = [[0, 0, 0] for _ in range(4 * n + 1)]


def pushup(u):
    """Recompute node u's maximum from its two children."""
    tree[u][2] = max(tree[u << 1][2], tree[u << 1 | 1][2])


def build(u, l, r):
    """Create node u for nums[l..r] and its subtree."""
    tree[u] = [l, r, 0]
    if l == r:
        tree[u][2] = nums[l]
    else:
        mid = l + r >> 1
        build(u << 1, l, mid)
        build(u << 1 | 1, mid + 1, r)
        pushup(u)


def query(u, l, r):
    """Maximum of nums[l..r] intersected with node u's range."""
    if tree[u][0] >= l and tree[u][1] <= r:
        return tree[u][2]
    mid = tree[u][0] + tree[u][1] >> 1
    # -inf (not 0) for a side that is not visited, so ranges of negative numbers work.
    left = right = float('-inf')
    if l <= mid:
        left = query(u << 1, l, r)
    if r > mid:
        right = query(u << 1 | 1, l, r)
    return max(left, right)


build(1, 1, n)
for _ in range(m):
    x, y = map(int, input().split())
    print(query(1, x, y))

小朋友排队

1215. 小朋友排队

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
N = 1000001                  # heights are shifted by +1 into 1..N
n = int(input())
nums = [0] + [int(x) + 1 for x in input().split()]
moves = [0] * (n + 1)        # moves[i]: swaps child i must take part in


def lowbit(x):
    """Lowest set bit of x."""
    return x & -x


def add(x, v):
    """Add v at height x in the global Fenwick tree."""
    while x <= N:
        tree[x] += v
        x += lowbit(x)


def query(x):
    """Count recorded heights <= x."""
    total = 0
    while x > 0:
        total += tree[x]
        x -= lowbit(x)
    return total


tree = [0] * (N + 1)
for i in range(1, n + 1):
    moves[i] = query(N) - query(nums[i])     # taller children on the left
    add(nums[i], 1)
tree = [0] * (N + 1)
for i in range(n, 0, -1):
    moves[i] += query(nums[i] - 1)           # shorter children on the right
    add(nums[i], 1)
# A child swapped k times gets unhappiness 1 + 2 + ... + k.
print(sum(k * (k + 1) // 2 for k in moves[1:]))

油漆面积

1228. 油漆面积

1
2
线段树
太难,跳过!

三体攻击

1232. 三体攻击

二分 + 三维差分 (难)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
from sys import stdin
input = lambda: stdin.readline().strip()


def get(i, j, k):
    """Flatten (i, j, k), 0 <= i <= A, 0 <= j <= B, 0 <= k <= C, into [0, N).

    The strides must be B + 1 and C + 1 so that the zero padding slot (index 0)
    of each axis is distinct.  With strides B and C, (i, j, 0) collided with
    (i, j - 1, C) and the prefix sums read unrelated cells.
    """
    return (i * (B + 1) + j) * (C + 1) + k


def check(mid):
    """Return True if, after the first `mid` attacks, some ship's life is negative."""
    b = [0] * N

    def add(x, y, z, v):
        # Corners on layer A + 1 / B + 1 / C + 1 lie past every cell the prefix
        # sum visits, so they cannot affect any ship and are skipped.
        if x <= A and y <= B and z <= C:
            b[get(x, y, z)] += v

    for i in range(1, mid + 1):
        x1, x2, y1, y2, z1, z2, c = op[i]
        add(x1, y1, z1, -c)
        add(x1, y1, z2 + 1, c)
        add(x1, y2 + 1, z1, c)
        add(x1, y2 + 1, z2 + 1, -c)
        add(x2 + 1, y1, z1, c)
        add(x2 + 1, y1, z2 + 1, -c)
        add(x2 + 1, y2 + 1, z1, -c)
        add(x2 + 1, y2 + 1, z2 + 1, c)
    # A 3-D prefix sum of b turns the difference array into per-cell (negative) damage.
    for i in range(1, A + 1):
        for j in range(1, B + 1):
            for k in range(1, C + 1):
                idx = get(i, j, k)
                b[idx] += (b[get(i - 1, j, k)] + b[get(i, j - 1, k)] + b[get(i, j, k - 1)]
                           + b[get(i - 1, j - 1, k - 1)] - b[get(i - 1, j - 1, k)]
                           - b[get(i - 1, j, k - 1)] - b[get(i, j - 1, k - 1)])
                if s[idx] + b[idx] < 0:
                    return True
    return False


A, B, C, m = map(int, input().split())
N = (A + 1) * (B + 1) * (C + 1)
s = [0] * N                      # s[get(i, j, k)] = initial life of ship (i, j, k)
lives = list(map(int, input().split()))
t = 0
for i in range(1, A + 1):
    for j in range(1, B + 1):
        for k in range(1, C + 1):
            s[get(i, j, k)] = lives[t]
            t += 1
op = [[]] + [list(map(int, input().split())) for _ in range(m)]
l, r = 1, m
while l < r:
    mid = (l + r) // 2
    if check(mid):
        r = mid
    else:
        l = mid + 1
print(l)

差分

797. 差分

1
2
3
4
5
6
7
8
9
10
11
12
13
14
n, m = map(int, input().split())
nums = [0] + list(map(int, input().split()))
diff = [0] * (n + 2)


def insert(l, r, c):
    """Add c to every position in l..r via the difference array."""
    diff[l] += c
    diff[r + 1] -= c


for i in range(1, n + 1):
    insert(i, i, nums[i])            # seed the difference array with the input
for _ in range(m):
    l, r, c = map(int, input().split())
    insert(l, r, c)
for i in range(1, n + 1):
    nums[i] = nums[i - 1] + diff[i]
print(*nums[1:])

差分矩阵

798. 差分矩阵

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
n, m, q = map(int, input().split())
grid = [[0] * (m + 1)] + [[0] + list(map(int, input().split())) for _ in range(n)]
delta = [[0] * (m + 2) for _ in range(n + 2)]


def insert(x1, y1, x2, y2, c):
    """Add c to the sub-matrix (x1, y1)-(x2, y2) via its 2-D difference array."""
    delta[x1][y1] += c
    delta[x1][y2 + 1] -= c
    delta[x2 + 1][y1] -= c
    delta[x2 + 1][y2 + 1] += c


for _ in range(q):
    x1, y1, x2, y2, c = map(int, input().split())
    insert(x1, y1, x2, y2, c)
for i in range(1, n + 1):
    for j in range(1, m + 1):
        # Fold the original value into delta before grid[i][j] is overwritten
        # by the running prefix sum.
        insert(i, j, i, j, grid[i][j])
        grid[i][j] = grid[i - 1][j] + grid[i][j - 1] - grid[i - 1][j - 1] + delta[i][j]
        print(grid[i][j], end=' ')
    print()

螺旋折线

1237. 螺旋折线

1
2
3
4
5
6
7
8
9
x, y = map(int, input().split())
if abs(x) <= y:                      # top edge of ring y
    k = y
    print(2 * k * (2 * k - 1) + x + k)
elif abs(y) <= x:                    # right edge of ring x
    k = x
    print(4 * k * k + k - y)
elif abs(x) <= -y + 1:               # bottom edge of ring -y
    k = -y
    print(2 * k * (2 * k + 1) + k - x)
else:                                # left edge of ring -x
    k = -x
    print((2 * k - 1) ** 2 + k + y - 1)

双指针、BFS与图论

日志统计

1238. 日志统计

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
n, d, k = map(int, input().split())
logs = sorted(list(map(int, input().split())) for _ in range(n))
likes = [0] * 100001       # likes[id] inside the current time window
hot = [False] * 100001
j = 0
for ts, pid in logs:
    likes[pid] += 1
    # Shrink the window to log times in (ts - d, ts].
    while ts - logs[j][0] >= d:
        likes[logs[j][1]] -= 1
        j += 1
    if likes[pid] >= k:
        hot[pid] = True
for pid, flag in enumerate(hot):
    if flag:
        print(pid)

献给阿尔吉侬的花束

1101. 献给阿尔吉侬的花束

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
from collections import deque
t = int(input())
dirs = [(0, 1), (1, 0), (0, -1), (-1, 0)]


def bfs(start, end, g):
    """Print the fewest moves from start to end avoiding '#', or 'oop!' if unreachable."""
    dist = [[-1] * c for _ in range(r)]
    dist[start[0]][start[1]] = 0
    q = deque([start])
    while q:
        a, b = q.popleft()
        for dx, dy in dirs:
            x, y = a + dx, b + dy
            if not (0 <= x < r and 0 <= y < c):
                continue
            if g[x][y] == '#' or dist[x][y] != -1:
                continue
            dist[x][y] = dist[a][b] + 1
            if x == end[0] and y == end[1]:
                print(dist[x][y])
                return
            q.append((x, y))
    print('oop!')


for _ in range(t):
    r, c = map(int, input().split())
    g = [input() for _ in range(r)]
    for i in range(r):
        for j in range(c):
            if g[i][j] == 'S':
                start = (i, j)
            elif g[i][j] == 'E':
                end = (i, j)
    bfs(start, end, g)

红与黑

1113. 红与黑

dfs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
dirs = [(0, 1), (1, 0), (0, -1), (-1, 0)]


def dfs(a, b):
    """Mark (a, b) visited and count it plus every '.' tile reachable from it."""
    st[a][b] = True
    total = 1
    for dx, dy in dirs:
        x, y = a + dx, b + dy
        if 0 <= x < n and 0 <= y < m and g[x][y] == '.' and not st[x][y]:
            st[x][y] = True
            total += dfs(x, y)
    return total


while True:
    m, n = map(int, input().split())     # width, height
    if m == 0:
        break
    g = [input() for _ in range(n)]
    st = [[False] * m for _ in range(n)]
    for i in range(n):
        for j in range(m):
            if g[i][j] == '@':
                print(dfs(i, j))
                break

bfs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
from collections import deque
dirs = [(0, 1), (1, 0), (0, -1), (-1, 0)]


def bfs(a, b):
    """Count the '.' tiles reachable from (a, b), including (a, b) itself."""
    st[a][b] = True
    q = deque([(a, b)])
    count = 1
    while q:
        px, py = q.popleft()
        for dx, dy in dirs:
            x, y = px + dx, py + dy
            if 0 <= x < n and 0 <= y < m and g[x][y] == '.' and not st[x][y]:
                st[x][y] = True
                count += 1
                q.append((x, y))
    return count


while True:
    m, n = map(int, input().split())     # width, height
    if m == 0:
        break
    g = [input() for _ in range(n)]
    st = [[False] * m for _ in range(n)]
    for i in range(n):
        for j in range(m):
            if g[i][j] == '@':
                print(bfs(i, j))
                break

交换瓶子

1224. 交换瓶子

贪心

1
2
3
4
5
6
7
8
9
10
11
12
13
n = int(input())
nums = [0] + list(map(int, input().split()))
seen = [False] * (n + 1)
cycles = 0
for i in range(1, n + 1):
    t = nums[i]
    if seen[t]:
        continue
    # A new cycle of the permutation: walk it once, marking every member.
    seen[t] = True
    cycles += 1
    while not seen[nums[t]]:
        t = nums[t]
        seen[t] = True
# A cycle of length L needs L - 1 swaps.
print(n - cycles)

并查集

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
n = int(input())
nums = [0] + list(map(int, input().split()))
fa = list(range(n + 1))
s = [0] * (n + 1)


def find(x):
    """Return the root of x's set, compressing the path iteratively.

    The recursive version exceeds Python's default recursion limit (about 1000)
    once a long parent chain forms, e.g. when all bottles form one big cycle.
    """
    root = x
    while fa[root] != root:
        root = fa[root]
    while x != root:
        nxt = fa[x]
        fa[x] = root
        x = nxt
    return root


for i in range(1, n + 1):
    fx, fy = find(nums[i]), find(nums[nums[i]])
    fa[fx] = fy
for i in range(1, n + 1):
    s[find(i)] += 1
# Each set is one cycle of the permutation; a cycle of size L needs L - 1 swaps.
res = 0
for i in range(1, n + 1):
    if s[i]:
        res += s[i] - 1
print(res)

完全二叉树的权值

1240. 完全二叉树的权值

前缀和

1
2
3
4
5
6
7
8
9
10
11
12
13
14
n = int(input())
pre = [0] + list(map(int, input().split()))
for i in range(1, n + 1):
    pre[i] += pre[i - 1]
best, depth = float('-inf'), 1
start, level = 1, 1
while start <= n:
    # Level `level` holds nodes start .. 2*start - 1, clipped to n.
    total = pre[min(n, 2 * start - 1)] - pre[start - 1]
    if total > best:        # strict: on ties the shallower level is kept
        best, depth = total, level
    start, level = start * 2, level + 1
print(depth)

地牢大师

1096. 地牢大师

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
from collections import deque
dirs = [(1, 0, 0), (-1, 0, 0), (0, 1, 0), (0, -1, 0), (0, 0, 1), (0, 0, -1)]


def bfs(start, end):
    """Print the escape time from start to end through non-'#' cells, or 'Trapped!'."""
    dist = [[[-1] * c for _ in range(r)] for _ in range(l)]
    dist[start[0]][start[1]][start[2]] = 0
    q = deque([start])
    while q:
        a, b, e = q.popleft()
        for dx, dy, dz in dirs:
            x, y, z = a + dx, b + dy, e + dz
            if not (0 <= x < l and 0 <= y < r and 0 <= z < c):
                continue
            if g[x][y][z] == '#' or dist[x][y][z] != -1:
                continue
            dist[x][y][z] = dist[a][b][e] + 1
            if x == end[0] and y == end[1] and z == end[2]:
                print(f'Escaped in {dist[x][y][z]} minute(s).')
                return
            q.append((x, y, z))
    print('Trapped!')


while True:
    l, r, c = map(int, input().split())
    if l == 0:
        break
    g = []
    for _ in range(l):
        g.append([input() for _ in range(r)])
        input()                     # blank line after each level
    for i in range(l):
        for j in range(r):
            for k in range(c):
                if g[i][j][k] == 'S':
                    start = (i, j, k)
                elif g[i][j][k] == 'E':
                    end = (i, j, k)
    bfs(start, end)

全球变暖

1233. 全球变暖

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
from collections import deque
n = int(input())
g = [input() for _ in range(n)]
st = [[False] * n for _ in range(n)]
dirs = [(0, 1), (1, 0), (0, -1), (-1, 0)]


def bfs(a, b):
    """Flood-fill the island containing (a, b).  If every cell of it touches
    the sea, the island disappears and the global res is incremented."""
    global res
    st[a][b] = True
    q = deque([(a, b)])
    size = coast = 0
    while q:
        x0, y0 = q.popleft()
        size += 1
        touches_sea = False
        for dx, dy in dirs:
            x, y = x0 + dx, y0 + dy
            if not (0 <= x < n and 0 <= y < n):
                continue
            if g[x][y] == '.':
                touches_sea = True
            elif g[x][y] == '#' and not st[x][y]:
                st[x][y] = True
                q.append((x, y))
        coast += touches_sea
    if size == coast:
        res += 1


res = 0
for i in range(n):
    for j in range(n):
        if g[i][j] == '#' and not st[i][j]:
            bfs(i, j)
print(res)

大臣的旅费

1207. 大臣的旅费

dfs(爆栈)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
import sys
sys.setrecursionlimit(100000)
n = int(input())
g = [[] for _ in range(n + 1)]      # g[u] = [(neighbour, road length), ...]
dist = [0] * (n + 1)
for _ in range(n - 1):
    a, b, c = map(int, input().split())
    g[a].append((b, c))
    g[b].append((a, c))


def dfs(u, father, distance):
    """Record in dist the distance from the DFS root to every node below u."""
    dist[u] = distance
    for v, w in g[u]:
        if v != father:
            dfs(v, u, distance + w)


# Tree diameter: find the farthest node from 1, then the farthest node from it.
dfs(1, -1, 0)
far = dist.index(max(dist))
dfs(far, -1, 0)
longest = max(dist)
# Kilometre i costs 10 + i, so a trip of length L costs 10L + L(L+1)/2.
print(10 * longest + longest * (longest + 1) // 2)

bfs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
from collections import deque
n = int(input())
g = [[] for _ in range(n + 1)]      # g[u] = [(neighbour, road length), ...]
dist = [-1] * (n + 1)
for _ in range(n - 1):
    a, b, c = map(int, input().split())
    g[a].append((b, c))
    g[b].append((a, c))


def bfs(u):
    """Fill the global dist with path lengths from u (the graph is a tree)."""
    dist[u] = 0
    q = deque([u])
    while q:
        cur = q.popleft()
        for v, w in g[cur]:
            if dist[v] == -1:
                dist[v] = dist[cur] + w
                q.append(v)


bfs(1)
far = dist.index(max(dist))
dist = [-1] * (n + 1)
bfs(far)
longest = max(dist)
print(longest * 10 + longest * (longest + 1) // 2)

贪心

股票买卖 II

1055. 股票买卖 II

1
2
3
4
5
6
n = int(input())
prices = list(map(int, input().split()))
# Unlimited transactions: collect every positive day-to-day rise.
print(sum(max(0, prices[i] - prices[i - 1]) for i in range(1, n)))

货仓选址

104. 货仓选址

1
2
3
4
5
6
7
8
n = int(input())
xs = sorted(map(int, input().split()))
median = xs[n // 2]          # the median minimises the sum of absolute distances
print(sum(abs(median - x) for x in xs))

糖果传递

122. 糖果传递

1
2
3
4
5
6
7
8
9
10
11
12
13
n = int(input())
nums = [0] + [int(input()) for _ in range(n)]
avg = sum(nums) // n
c = [0] * (n + 2)
for i in range(n, 1, -1):
    c[i] = c[i + 1] - nums[i] + avg
c = sorted(c[1:n + 1])
mid = c[n // 2]
print(sum(abs(v - mid) for v in c))
1
2
3
4
5
6
7
8
9
10
11
12
n = int(input())
nums = [int(input()) for _ in range(n)]
avg = sum(nums) // n
pre, run = [], 0
for v in nums:
    run += v - avg           # running surplus over the average
    pre.append(run)
pre.sort()
t = pre[n // 2]
print(sum(abs(t - v) for v in pre))

雷达设备

112. 雷达设备

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
n, d = map(int, input().split())
segs = []                 # coast interval from which a radar covers each island
possible = True
for _ in range(n):
    x, y = map(int, input().split())
    if y > d:
        possible = False  # no radar on the coast can reach this island
        break
    half = (d * d - y * y) ** 0.5
    segs.append((x - half, x + half))
# Branching replaces exit(0), which only exists when the `site` module is loaded.
if not possible:
    print(-1)
else:
    # Greedy: sort by right end; place a radar at the right end whenever the
    # current interval is not covered by the last radar.
    segs.sort(key=lambda seg: seg[1])
    res, last = 0, float('-inf')
    for lo, hi in segs:
        if lo > last:
            res += 1
            last = hi
    print(res)

付账问题

1235. 付账问题

1
2
3
4
5
6
7
8
9
10
11
12
13
14
n, s = map(int, input().split())
nums = sorted(map(int, input().split()))
avg = s / n
res = 0
for i, num in enumerate(nums):
    share = s / (n - i)       # equal share of what is still owed by those remaining
    if num >= share:
        # Everyone from here on can afford the same share.
        res += (share - avg) ** 2 * (n - i)
        break
    res += (avg - num) ** 2   # this person pays everything they have
    s -= num
res = (res / n) ** 0.5
print(f'{res:.4f}')

乘积最大

1239. 乘积最大

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
n, k = map(int, input().split())
nums = sorted(int(input()) for _ in range(n))
MOD = 1000000009
res = sign = 1
l, r = 0, n - 1
if k % 2:
    # Odd k: take the largest number alone first.  If even it is negative, every
    # number is negative, and the aim becomes the smallest |product|.
    res = nums[r]
    r -= 1
    k -= 1
    if res < 0:
        sign = -1
while k:
    x, y = nums[l] * nums[l + 1], nums[r] * nums[r - 1]
    if x * sign > y * sign:
        res *= x
        l += 2
    else:
        res *= y
        r -= 2
    # Reduce modulo MOD while keeping the sign (like C's % on a negative operand).
    res = res % MOD if res > 0 else -(-res % MOD)
    k -= 2
print(res)

后缀表达式

1247. 后缀表达式

1
2
3
4
5
6
7
8
9
10
n, m = map(int, input().split())
nums = list(map(int, input().split()))
if m == 0:
    print(sum(nums))
else:
    nums.sort()
    # With at least one minus sign, every middle term can be made to count as
    # |x|; the largest is added and the smallest is subtracted.
    print(nums[-1] - nums[0] + sum(abs(x) for x in nums[1:-1]))

灵能传输

1248. 灵能传输

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
t = int(input())
for _ in range(t):
    n = int(input())
    s = [0] + list(map(int, input().split()))
    for i in range(1, n + 1):
        s[i] += s[i - 1]             # prefix sums s[0..n]
    s0, sn = s[0], s[n]
    if s0 > sn:
        s0, sn = sn, s0
    s.sort()
    # s0 must be its FIRST occurrence and sn its LAST.  With a plain .index for
    # both, equal values shared one slot, so one prefix sum was placed twice and
    # another dropped (e.g. a = [1, -1] gave 0 instead of 1).
    s0 = s.index(s0)
    sn = n - s[::-1].index(sn)
    l, r = 0, n
    res = [0] * (n + 1)
    st = [False] * (n + 1)
    # Walk down from s0 in steps of 2 to build the left end of the arrangement.
    for i in range(s0, -1, -2):
        res[l] = s[i]
        st[i] = True
        l += 1
    # Walk up from sn in steps of 2 to build the right end.
    for i in range(sn, n + 1, 2):
        res[r] = s[i]
        st[i] = True
        r -= 1
    # Remaining values fill the middle in sorted order.
    for i in range(n + 1):
        if not st[i]:
            res[l] = s[i]
            l += 1
    ans = 0
    for i in range(1, n + 1):
        ans = max(ans, abs(res[i] - res[i - 1]))
    print(ans)

数论

等差数列

1246. 等差数列

1
2
3
4
5
6
7
8
9
10
11
12
n = int(input())
nums = sorted(map(int, input().split()))


def gcd(a, b):
    """Greatest common divisor of a and b (Euclid, iterative)."""
    while b:
        a, b = b, a % b
    return a


d = 0
for x in nums[1:n]:
    d = gcd(x - nums[0], d)
# Common difference d (0 means every term is equal).
print((nums[-1] - nums[0]) // d + 1 if d else n)

X的因子链

1295. X的因子链

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
N = (1 << 20) + 10
primes, st, minp = [], [False] * N, [0] * N


def get_prime(n):
    """Linear sieve over 2..n: fills primes and minp (smallest prime factor)."""
    for i in range(2, n + 1):
        if not st[i]:
            primes.append(i)
            minp[i] = i
        # Iterating the list directly checks the bound before any index is used.
        for p in primes:
            if p * i > n:
                break
            st[p * i] = True
            minp[p * i] = p
            if i % p == 0:
                break


get_prime(N - 1)
while True:
    # Stop at end of input (or a blank line) instead of swallowing every error.
    try:
        x = int(input())
    except (EOFError, ValueError):
        break
    # Factor x = p1^e1 * ... * pk^ek using the smallest-prime-factor table.
    exps = []
    while x > 1:
        p = minp[x]
        e = 0
        while x % p == 0:
            x //= p
            e += 1
        exps.append(e)
    total = sum(exps)
    # The longest chain multiplies in one prime per step; the number of such
    # chains is the multinomial total! / (e1! * ... * ek!).
    res = 1
    for i in range(1, total + 1):
        res *= i
    for e in exps:
        for j in range(1, e + 1):
            res //= j
    print(total, res)

聪明的燕姿

1296. 聪明的燕姿

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
N = 50000
primes, st = [], [False] * N


def get_prime(n):
    """Linear sieve: append every prime <= n to primes; st[x] marks composites."""
    for i in range(2, n + 1):
        if not st[i]:
            primes.append(i)
        j = 0
        while i * primes[j] <= n:
            st[i * primes[j]] = True
            if i % primes[j] == 0:
                break
            j += 1


def is_prime(x):
    """Primality test: sieve lookup for x < N, trial division otherwise."""
    if x < 2:
        return False          # st[0] / st[1] are unset, so they must be excluded here
    if x < N:
        return not st[x]
    i = 0
    while primes[i] <= x / primes[i]:
        if x % primes[i] == 0:
            return False
        i += 1
    return True


def dfs(last, prod, s):
    """Collect into res every number whose divisor sum is s.

    last: index in primes of the largest prime already used (-1 for none).
    prod: product of the prime powers chosen so far.
    """
    if s == 1:
        res.append(prod)
        return
    # If s - 1 is a prime larger than every prime used so far, the remaining
    # factor can be that prime to the first power (sigma(p) = p + 1).
    if s - 1 > (1 if last < 0 else primes[last]) and is_prime(s - 1):
        res.append(prod * (s - 1))
    i = last + 1
    while primes[i] <= s / primes[i]:
        p = primes[i]
        j, t = 1 + p, p      # j = 1 + p + ... + p^e, t = p^e
        while t <= s:
            if s % j == 0:
                dfs(i, prod * t, s // j)
            t *= p
            j += t
        i += 1


get_prime(N - 1)
while True:
    try:
        x = int(input())
    except (EOFError, ValueError):
        break
    res = []
    dfs(-1, 1, x)
    print(len(res))
    if res:
        res.sort()
        print(*res)

五指山

1299. 五指山

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
t = int(input())


def exgcd(a, b):
    """Return (g, x, y) with a*x + b*y == g == gcd(a, b)."""
    if b == 0:
        return a, 1, 0
    g, x1, y1 = exgcd(b, a % b)
    return g, y1, x1 - a // b * y1


for _ in range(t):
    n, d, x, y = map(int, input().split())
    g, _a, k = exgcd(n, d)
    # Solve k*d == y - x (mod n); solvable iff gcd(n, d) divides y - x.
    if (y - x) % g:
        print("Impossible")
    else:
        k *= (y - x) // g
        print(k % (n // g))

最大比例

1223. 最大比例

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
n = int(input())
nums = sorted(map(int, input().split()))


def gcd(a, b):
    """Return the greatest common divisor of a and b (iterative Euclid)."""
    while b:
        a, b = b, a % b
    return a


def gcd_sub(a, b):
    """Return the largest base r with a == r**i and b == r**j (exponent gcd).

    Repeatedly replaces (a, b) by (b, a // b) with a >= b until b == 1,
    mirroring Euclid's algorithm on the exponents.
    """
    while True:
        if a < b:
            a, b = b, a
        if b == 1:
            return a
        a, b = b, a // b


# Reduce each distinct term's ratio to the smallest term to lowest terms.
ups, downs = [], []
for i in range(1, n):
    cur = nums[i]
    if cur == nums[i - 1]:
        continue
    d = gcd(cur, nums[0])
    ups.append(cur // d)
    downs.append(nums[0] // d)

up, down = ups[0], downs[0]
for u, v in zip(ups[1:], downs[1:]):
    up = gcd_sub(up, u)
    down = gcd_sub(down, v)
print(f'{up}/{down}')

C 循环

1301. C 循环

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
def exgcd(a, b):
    """Return (g, x, y) with a*x + b*y == g == gcd(a, b) (recursive extended Euclid)."""
    if b:
        g, x1, y1 = exgcd(b, a % b)
        return g, y1, x1 - a // b * y1
    return a, 1, 0


while True:
    a, b, c, k = map(int, input().split())
    if not (a or b or c or k):
        break
    # Smallest x >= 0 with a + c*x ≡ b (mod 2^k), if any.
    modulus = 1 << k
    g, x, _ = exgcd(c, modulus)
    if (b - a) % g:
        print("FOREVER")
    else:
        print(x * ((b - a) // g) % (modulus // g))

正则问题

1225. 正则问题

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
s = input().strip()
pos = 0


def dfs():
    """Return the longest string length matched from s[pos].

    Parsing stops at the next unmatched ')' or at end of input. 'x' adds 1,
    '(...)' adds the group's value, and '|' keeps the larger alternative.
    """
    global pos
    length = 0
    while pos < len(s):
        ch = s[pos]
        if ch == ')':
            break
        pos += 1
        if ch == '(':
            length += dfs()
            pos += 1  # step over the matching ')'
        elif ch == '|':
            length = max(length, dfs())
        else:
            length += 1
    return length


print(dfs())

糖果

1243. 糖果

IDA*

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
def lowbit(x):
    """Return the lowest set bit of x (0 when x == 0); uses -x == ~x + 1."""
    return x & (~x + 1)
def h(st):
    """Lower bound on the number of packets still needed to cover all m flavours.

    Takes the lowest missing flavour and counts one packet for it. It then
    treats every flavour of every packet containing that flavour as covered,
    and repeats. Flavours counted this way can never share a packet, so each
    needs its own.
    """
    remaining = (1 << m) - 1 - st
    estimate = 0
    while remaining:
        low = lowbit(remaining)
        covered = 0
        for row in col[log2[low]]:
            covered |= row
        remaining = (remaining - low) & ~covered
        estimate += 1
    return estimate
def dfs(depth, st):
    """Return True if choosing ``depth`` more packets can cover every flavour.

    Each chosen packet contains a still-missing flavour. Branches on the
    missing flavour contained in the fewest packets (lowest index on ties)
    and prunes with the lower bound h(st).
    """
    full = (1 << m) - 1
    if depth == 0 or h(st) > depth:
        return st == full
    best = -1
    pending = full - st
    while pending:
        bit = lowbit(pending)
        pending -= bit
        flavour = log2[bit]
        if best == -1 or len(col[flavour]) < len(col[best]):
            best = flavour
    # If nothing is missing, best stays -1 and col[-1] is indexed.
    # Presumably that is the extra, always-empty bucket, so this returns False.
    return any(dfs(depth - 1, row | st) for row in col[best])
n, m, k = map(int, input().split())
# log2[1 << b] == b for each flavour bit b.
log2 = [0] * (1 << (m + 1))
for bit_index in range(m):
    log2[1 << bit_index] = bit_index
# col[f] lists the bitmask of every packet containing flavour f; col[m] stays empty.
col = [[] for _ in range(m + 1)]
for _ in range(n):
    mask = 0
    for candy in map(int, input().split()):
        mask |= 1 << (candy - 1)
    for flavour in range(m):
        if mask >> flavour & 1:
            col[flavour].append(mask)

# Iterative deepening: the first depth that succeeds is the minimum.
answer = -1
for depth in range(m + 1):
    if dfs(depth, 0):
        answer = depth
        break
print(answer)

dp

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
n, m, k = map(int, input().split())
full = (1 << m) - 1
# dp[mask] = fewest packets found whose flavour union is mask; -1 = not reached.
dp = [-1] * (full + 2)
val = [0] * n
for idx in range(n):
    mask = 0
    for candy in map(int, input().split()):
        mask |= 1 << (candy - 1)
    val[idx] = mask
    dp[mask] = 1

for pack in val:
    for state in range(full + 1):
        if dp[state] == -1:
            continue
        merged = state | pack
        cost = dp[state] + dp[pack]
        if dp[merged] == -1 or cost < dp[merged]:
            dp[merged] = cost
print(dp[full])

复杂DP

鸣人的影分身

1050. 鸣人的影分身

1
2
3
4
5
6
7
8
9
10
11
t = int(input())
for _ in range(t):
    m, n = map(int, input().split())
    # ways[i][j] = number of ways to write i as an unordered sum of j
    # non-negative parts (partitions of i into at most j parts).
    ways = [[0] * (n + 1) for _ in range(m + 1)]
    ways[0][0] = 1
    for total in range(m + 1):
        row = ways[total]
        for parts in range(1, n + 1):
            # Either some part is 0 (drop it), or all parts are >= 1
            # (subtract 1 from each).
            extra = ways[total - parts][parts] if total >= parts else 0
            row[parts] = row[parts - 1] + extra
    print(ways[m][n])

糖果

1047. 糖果

1
2
3
4
5
6
7
8
n, k = map(int, input().split())
# best[r] = largest total of a chosen subset (of values read so far) whose
# sum ≡ r (mod k); -inf when no subset has that remainder.
best = [float('-inf')] * (k + 1)
best[0] = 0
for _ in range(n):
    x = int(input())
    prev = best
    best = [max(prev[r], prev[(r - x) % k] + x) for r in range(k + 1)]
print(best[0])

密码脱落

1222. 密码脱落

1
2
3
4
5
6
7
8
9
10
11
12
13
s = input()
n = len(s)
# dp[l][r] = length of the longest palindromic subsequence of s[l..r].
dp = [[0] * (n + 1) for _ in range(n + 1)]
for l in range(n - 1, -1, -1):
    dp[l][l] = 1
    for r in range(l + 1, n):
        best = max(dp[l + 1][r], dp[l][r - 1])
        if s[l] == s[r]:
            best = max(best, dp[l + 1][r - 1] + 2)
        dp[l][r] = best
# Characters outside the longest palindromic subsequence are the dropped ones.
print(n - dp[0][n - 1])

生命之树

1220. 生命之树

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
import sys


def max_connected_sum(n, w, adj_list):
    """Return the largest total weight of a non-empty connected vertex set.

    The tree is rooted at vertex 1. dp[u] is the best sum of a connected set
    whose highest vertex is u: w[u] plus every child's dp when positive.
    The traversal is iterative, so path-shaped trees with ~1e5 vertices don't
    exceed Python's recursion limit or the C stack.

    Args:
        n: number of vertices, labelled 1..n.
        w: w[u] is the weight of vertex u (w[0] unused).
        adj_list: adj_list[u] lists the neighbours of u.
    """
    dp = [float('-inf')] * (n + 1)
    parent = [0] * (n + 1)
    parent[1] = -1
    order = []  # vertices in DFS preorder
    stack = [1]
    while stack:
        u = stack.pop()
        order.append(u)
        for v in adj_list[u]:
            if v != parent[u]:
                parent[v] = u
                stack.append(v)
    # Reverse preorder finishes every child before its parent.
    for u in reversed(order):
        dp[u] = w[u] + sum(max(0, dp[v]) for v in adj_list[u] if v != parent[u])
    return max(dp)


def main():
    """Read the tree from stdin and print the answer."""
    readline = sys.stdin.readline
    n = int(readline())
    w = [0] + list(map(int, readline().split()))
    adj_list = [[] for _ in range(n + 1)]
    for _ in range(n - 1):
        a, b = map(int, readline().split())
        adj_list[a].append(b)
        adj_list[b].append(a)
    print(max_connected_sum(n, w, adj_list))


if __name__ == "__main__":
    main()

包子凑数

1226. 包子凑数

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
n = int(input())
w = [0] + [int(input()) for _ in range(n)]


def gcd(a, b):
    """Return the greatest common divisor of a and b (iterative Euclid)."""
    while b:
        a, b = b, a % b
    return a


d = w[1]
for size in w[2:]:
    d = gcd(d, size)

if d == 1:
    # Unbounded-knapsack reachability over totals below LIMIT. LIMIT presumably
    # exceeds the largest unreachable total for the problem's bounds.
    LIMIT = 10000
    reachable = [False] * LIMIT
    reachable[0] = True
    for size in w[1:]:
        for total in range(size, LIMIT):
            if reachable[total - size]:
                reachable[total] = True
    print(reachable.count(False))
else:
    # A common divisor > 1 leaves infinitely many totals unreachable.
    print("INF")

括号配对

1070. 括号配对

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
s = input()
n = len(s)
# dp[i][j] (1-based) = most characters of s[i..j] that can be kept as a
# properly matched bracket subsequence.
dp = [[0] * (n + 1) for _ in range(n + 1)]


def check(l, r):
    """Return True if s[l-1] and s[r-1] (1-based l, r) form '()' or '[]'."""
    return (s[l - 1], s[r - 1]) in (('(', ')'), ('[', ']'))


for i in range(n, 0, -1):
    row = dp[i]
    for j in range(i + 1, n + 1):
        best = dp[i + 1][j - 1] + 2 if check(i, j) else 0
        for k in range(i, j):
            best = max(best, row[k] + dp[k + 1][j])
        row[j] = best
# Every unmatched character needs one inserted partner.
print(n - dp[1][n])

斐波那契前 n 项和

1303. 斐波那契前 n 项和

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
def mul1(b, c, mod):
    """Return the length-3 row vector b times the 3x3 matrix c, modulo mod."""
    return [sum(b[j] * c[j][i] for j in range(3)) % mod for i in range(3)]


def mul2(b, c, mod):
    """Return the 3x3 matrix product b @ c, modulo mod."""
    return [[sum(b[i][k] * c[k][j] for k in range(3)) % mod for j in range(3)]
            for i in range(3)]


def fib_prefix_sum(n, m):
    """Return (F1 + F2 + ... + Fn) mod m, with F1 = F2 = 1.

    The row [F(i), F(i+1), S(i)] advances one step when multiplied by the
    transition matrix. Fast exponentiation applies the n - 1 steps in
    O(log n) matrix products.
    """
    if n <= 0:
        return 0  # empty sum; also avoids looping forever on negative n
    state = [1, 1, 1]  # [F1, F2, S1]
    step = [[0, 1, 0], [1, 1, 1], [0, 0, 1]]
    e = n - 1
    while e:
        if e & 1:
            state = mul1(state, step, m)
        step = mul2(step, step, m)
        e >>= 1
    # Reduce here too: for n == 1 no multiplication ran, yet S1 = 1 must
    # still be taken modulo m (it is 0 when m == 1).
    return state[2] % m


if __name__ == "__main__":
    n, m = map(int, input().split())
    print(fib_prefix_sum(n, m))

垒骰子

1217. 垒骰子

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
n, m = map(int, input().split())
MOD = int(1e9) + 7
# op[f] = index of the face opposite face f (0-based: 1-4, 2-5, 3-6).
op = [3, 4, 5, 0, 1, 2]
# f1[f] = arrangements of the current stack with face f on top; the factor 4
# presumably counts the rotations of a die about the vertical axis.
f1 = [4] * 6
st = [[False] * 6 for _ in range(6)]
for _ in range(m):
    a, b = (int(v) - 1 for v in input().split())
    st[a][b] = st[b][a] = True
# d[j][i]: ways to put a die with top face i on a stack whose top face is j.
# It is 0 when the new die's bottom face (op[i]) conflicts with j.
d = [[0 if st[j][op[i]] else 4 for i in range(6)] for j in range(6)]


def mul1(b, c):
    """Return the row vector b times the 6x6 matrix c, modulo MOD."""
    return [sum(b[j] * c[j][i] for j in range(6)) % MOD for i in range(6)]


def mul2(b, c):
    """Return the 6x6 matrix product b @ c, modulo MOD."""
    return [[sum(b[i][k] * c[k][j] for k in range(6)) % MOD for j in range(6)]
            for i in range(6)]


exp = n - 1
while exp:
    if exp & 1:
        f1 = mul1(f1, d)
    d = mul2(d, d)
    exp >>= 1
print(sum(f1) % MOD)

疑难杂题

修改数组

1242. 修改数组

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
import sys

# p[v] == v means v is still unused; a used v points to v + 1.
# The size presumably covers A_i <= 1e6 plus up to 1e5 shifts — confirm limits.
p = list(range(1100001))


def find(x):
    """Return the smallest unused value >= x, compressing the visited path.

    Iterative, so long chains of consecutive used values cannot overflow the
    recursion limit.
    """
    root = x
    while p[root] != root:
        root = p[root]
    while p[x] != root:
        nxt = p[x]
        p[x] = root
        x = nxt
    return root


def assign_distinct(nums):
    """Replace each number by the smallest value >= it not used earlier."""
    res = []
    for num in nums:
        x = find(num)
        res.append(x)
        p[x] = x + 1
    return res


def main():
    """Read n and the array from stdin and print the modified array."""
    data = sys.stdin.read().split()
    n = int(data[0])
    print(*assign_distinct(map(int, data[1:n + 1])))


if __name__ == "__main__":
    main()

倍数问题

1234. 倍数问题

1
2
3
4
5
6
7
8
9
10
11
12
13
14
n, K = map(int, input().split())
nums = list(map(int, input().split()))
# dp[cnt][r] = largest sum of cnt chosen numbers with sum ≡ r (mod K).
NEG = float('-inf')
dp = [[NEG] * K for _ in range(4)]
dp[0][0] = 0
buckets = [[] for _ in range(K)]
for num in nums:
    buckets[num % K].append(num)
# Within one remainder class only its three largest values can matter.
for bucket in buckets:
    for value in sorted(bucket, reverse=True)[:3]:
        for cnt in (3, 2, 1):  # descending: each value is used at most once
            prev, cur = dp[cnt - 1], dp[cnt]
            for r in range(K):
                cur[r] = max(cur[r], prev[(r - value) % K] + value)
print(dp[3][0])

组合数问题

523. 组合数问题

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
def build_table(k, size=2001):
    """Return prefix counts of binomial coefficients divisible by k.

    s[i][j] is the number of pairs (a, b) with 0 <= a <= i and
    0 <= b <= min(a, j) such that C(a, b) % k == 0, for 0 <= i, j < size.
    """
    c = [[0] * size for _ in range(size)]  # c[a][b] = C(a, b) % k
    s = [[0] * size for _ in range(size)]
    zero_row = [0] * size
    for i in range(size):
        ci = c[i]
        ci[0] = 1 % k  # C(i, 0) = 1, reduced so that k == 1 is counted too
        for j in range(1, i + 1):
            ci[j] = (c[i - 1][j] + c[i - 1][j - 1]) % k
        above = s[i - 1] if i else zero_row
        si = s[i]
        row_count = 0  # multiples of k among C(i, 0..min(i, j))
        for j in range(size):
            if j <= i and ci[j] == 0:
                row_count += 1
            si[j] = above[j] + row_count
    return s


if __name__ == "__main__":
    t, k = map(int, input().split())
    s = build_table(k)
    for _ in range(t):
        n, m = map(int, input().split())
        print(s[n][m])

模拟散列表

840. 模拟散列表

1
2
3
4
5
6
7
8
9
10
11
n = int(input())
# Only membership is queried, so a set of inserted keys is enough.
seen = set()
for _ in range(n):
    op, x = input().split()
    if op == 'I':
        seen.add(x)
    elif op == 'Q':
        print('Yes' if x in seen else 'No')