AcWing算法基础课

基础算法

快速排序

785. 快速排序

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
# 785. Quick sort: read n integers and print them in ascending order.
n = int(input())
nums = list(map(int, input().split()))


def quick_sort(l, r):
    """Sort nums[l..r] (inclusive) in place with Hoare partitioning.

    The pivot is the middle element.  After partitioning,
    nums[l..hi] <= pivot <= nums[hi+1..r], so both halves are sorted recursively.
    """
    if l >= r:
        return
    pivot = nums[(l + r) // 2]
    lo, hi = l - 1, r + 1
    while lo < hi:
        # Advance lo to the next element that belongs on the right side.
        lo += 1
        while nums[lo] < pivot:
            lo += 1
        # Retreat hi to the next element that belongs on the left side.
        hi -= 1
        while nums[hi] > pivot:
            hi -= 1
        if lo < hi:
            nums[lo], nums[hi] = nums[hi], nums[lo]
    quick_sort(l, hi)
    quick_sort(hi + 1, r)


quick_sort(0, n - 1)
print(*nums)

786. 第k个数

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
# 786. k-th smallest number: quick-select over nums (k is 1-based).
n, k = map(int, input().split())
nums = list(map(int, input().split()))


def quick_select(l, r, k):
    """Return the k-th smallest element (1-based) of nums[l..r].

    Partitions nums[l..r] in place around its middle element (Hoare scheme)
    so that nums[l..j] <= pivot <= nums[j+1..r], then recurses only into
    the side that contains the k-th element.
    """
    if l >= r:
        return nums[l]
    x = nums[(l + r) // 2]
    i, j = l - 1, r + 1  # sentinels just outside [l, r]
    while i < j:
        # i stops on an element >= x, j on an element <= x; swapping them
        # keeps smaller values on the left (ascending partition).
        i += 1
        while nums[i] < x:
            i += 1
        j -= 1
        while nums[j] > x:
            j -= 1
        if i < j:
            nums[i], nums[j] = nums[j], nums[i]
    left_size = j - l + 1
    if k <= left_size:
        return quick_select(l, j, k)
    return quick_select(j + 1, r, k - left_size)


print(quick_select(0, n - 1, k))

归并排序

787. 归并排序

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
# 787. Merge sort: read n integers and print them in ascending order.
n = int(input())
nums = [int(x) for x in input().split()]


def merge_sort(nums, l, r):
    """Stably sort nums[l..r] (inclusive) in place with top-down merge sort."""
    if l >= r:
        return
    mid = (l + r) // 2
    merge_sort(nums, l, mid)
    merge_sort(nums, mid + 1, r)
    merged = []
    a, b = l, mid + 1
    while a <= mid and b <= r:
        # `<=` takes from the left run on ties, which keeps the sort stable.
        if nums[a] <= nums[b]:
            merged.append(nums[a])
            a += 1
        else:
            merged.append(nums[b])
            b += 1
    # At most one of these two leftovers is non-empty.
    merged.extend(nums[a:mid + 1])
    merged.extend(nums[b:r + 1])
    nums[l:r + 1] = merged


merge_sort(nums, 0, n - 1)
print(' '.join(map(str, nums)))

788. 逆序对的数量

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
# 788. Number of inversion pairs, counted while merge-sorting.
n = int(input())
nums = list(map(int, input().split()))


def merge_sort(l, r):
    """Sort nums[l..r] in place and return how many inversions it contained.

    An inversion is a pair i < j with nums[i] > nums[j].
    """
    if l >= r:
        return 0
    mid = (l + r) // 2
    count = merge_sort(l, mid)
    count += merge_sort(mid + 1, r)
    merged = []
    a, b = l, mid + 1
    while a <= mid and b <= r:
        if nums[a] <= nums[b]:
            merged.append(nums[a])
            a += 1
        else:
            # nums[b] is smaller than every remaining left element
            # nums[a..mid], so it closes mid - a + 1 inversions at once.
            count += mid - a + 1
            merged.append(nums[b])
            b += 1
    merged += nums[a:mid + 1] + nums[b:r + 1]
    nums[l:r + 1] = merged
    return count


print(merge_sort(0, n - 1))

二分

789. 数的范围

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
# 789. For each query x, print the first and last index of x in the sorted
# array, or "-1 -1" when x does not occur.
n, q = map(int, input().split())
nums = [int(x) for x in input().split()]
for _ in range(q):
    x = int(input())
    # Leftmost index whose value is >= x.
    lo, hi = 0, n - 1
    while lo < hi:
        mid = (lo + hi) // 2
        if nums[mid] >= x:
            hi = mid
        else:
            lo = mid + 1
    if nums[lo] != x:
        print('-1 -1')
        continue
    first = lo
    # Rightmost index whose value is <= x; mid rounds up so `lo = mid` progresses.
    lo, hi = 0, n - 1
    while lo < hi:
        mid = (lo + hi + 1) // 2
        if nums[mid] <= x:
            lo = mid
        else:
            hi = mid - 1
    print(f'{first} {lo}')

790. 数的三次方根

1
2
3
4
5
6
7
8
9
# 790. Cube root of a real number, found by bisection on [-100, 100].
n = float(input())
lo, hi = -100, 100
# Invariant: lo <= hi, so abs(lo - hi) == hi - lo.
while hi - lo > 1e-8:
    mid = (lo + hi) / 2
    if n < mid ** 3:
        hi = mid
    else:
        lo = mid
print(f'{lo:.6f}')

前缀和

795. 前缀和

1
2
3
4
5
6
7
8
# 795. Prefix sums: answer m queries for the sum of elements l..r (1-based).
n, m = map(int, input().split())
nums = [int(x) for x in input().split()]
sums = [0] * (n + 1)  # sums[i] = nums[0] + ... + nums[i - 1]
for i in range(1, n + 1):
    sums[i] = sums[i - 1] + nums[i - 1]
for _ in range(m):
    l, r = map(int, input().split())
    total = sums[r] - sums[l - 1]
    print(total)

796. 子矩阵的和

1
2
3
4
5
6
7
8
9
10
11
12
13
# 796. 2D prefix sums: each query asks for the sum of the sub-matrix
# with corners (x1, y1) and (x2, y2) (1-based, inclusive).
n, m, q = map(int, input().split())
nums = [[0] * (m + 1) for _ in range(n + 1)]
sums = [[0] * (m + 1) for _ in range(n + 1)]
for i in range(1, n + 1):
    nums[i] = [0] + list(map(int, input().split()))  # pad column 0
for i in range(1, n + 1):
    above, row, cells = sums[i - 1], sums[i], nums[i]
    for j in range(1, m + 1):
        # Inclusion-exclusion over the two overlapping rectangles.
        row[j] = above[j] + row[j - 1] - above[j - 1] + cells[j]
for _ in range(q):
    x1, y1, x2, y2 = map(int, input().split())
    print(sums[x2][y2] - sums[x1 - 1][y2] - sums[x2][y1 - 1] + sums[x1 - 1][y1 - 1])

差分

797. 差分

1
2
3
4
5
6
7
8
9
10
11
12
13
14
# 797. Difference array: apply m range additions, then print the array.
n, m = map(int, input().split())
nums = [0] + list(map(int, input().split()))
diffs = [0] * (n + 2)


def insert(l, r, c):
    """Add c to positions l..r (1-based) by editing the difference array."""
    diffs[l] += c
    diffs[r + 1] -= c


# Seed the difference array with the initial values.
for i in range(1, n + 1):
    insert(i, i, nums[i])
for _ in range(m):
    l, r, c = map(int, input().split())
    insert(l, r, c)
# A running prefix sum of the differences rebuilds the final array.
running = 0
for i in range(1, n + 1):
    running += diffs[i]
    nums[i] = running
    print(running, end=' ')

798. 差分矩阵

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
# 798. 2D difference matrix: apply q sub-matrix additions, then print the matrix.
n, m, q = map(int, input().split())
nums = [[0] * (m + 2) for _ in range(n + 2)]
diffs = [[0] * (m + 2) for _ in range(n + 2)]


def insert(x1, y1, x2, y2, c):
    """Add c to every cell of the sub-matrix (x1, y1)-(x2, y2), 1-based inclusive."""
    diffs[x1][y1] += c
    diffs[x1][y2 + 1] -= c
    diffs[x2 + 1][y1] -= c
    diffs[x2 + 1][y2 + 1] += c


for i in range(1, n + 1):
    values = list(map(int, input().split()))
    nums[i] = [0] + values
    for j in range(1, m + 1):
        insert(i, j, i, j, values[j - 1])
for _ in range(q):
    x1, y1, x2, y2, c = map(int, input().split())
    insert(x1, y1, x2, y2, c)
# A 2D prefix sum over diffs, written back into nums, reconstructs the result.
for i in range(1, n + 1):
    upper, current, delta = nums[i - 1], nums[i], diffs[i]
    for j in range(1, m + 1):
        current[j] = upper[j] + current[j - 1] - upper[j - 1] + delta[j]
        print(current[j], end=' ')
    print()

双指针

799. 最长连续不重复子序列

1
2
3
4
5
6
7
8
9
10
11
# 799. Length of the longest contiguous run with no repeated value.
n = int(input())
nums = list(map(int, input().split()))
counts = dict.fromkeys(nums, 0)  # occurrences of each value inside the window
left = best = 0
for right in range(n):
    value = nums[right]
    counts[value] += 1
    # Shrink from the left until `value` occurs only once in the window.
    while counts[value] > 1:
        counts[nums[left]] -= 1
        left += 1
    best = max(best, right - left + 1)
print(best)

800. 数组元素的目标和

1
2
3
4
5
6
7
8
9
10
# 800. Find i, j with a[i] + b[j] == x in two ascending arrays.
n, m, x = map(int, input().split())
a = list(map(int, input().split()))
b = list(map(int, input().split()))
# i walks a upward while j walks b downward; because both arrays are
# ascending, j never has to move back up.
j = m - 1
for i in range(n):
    while j != 0 and a[i] + b[j] > x:
        j -= 1
    if a[i] + b[j] == x:
        print(i, j)
        break

二进制

801. 二进制中1的个数

1
2
3
4
5
6
7
8
9
10
11
# 801. Print the number of 1 bits of each input number.
n = int(input())
nums = list(map(int, input().split()))


def lowbit(x):
    """Return the value of the lowest set bit of x (e.g. lowbit(12) == 4)."""
    return x & -x


def count_ones(value):
    """Count set bits by repeatedly clearing the lowest one."""
    total = 0
    while value:
        value -= lowbit(value)
        total += 1
    return total


for num in nums:
    print(count_ones(num), end=' ')

离散化

802. 区间和

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
# 802. Interval sum over a huge coordinate range, using coordinate discretization.
n, m = map(int, input().split())
adds, querys, indexs = [], [], []
for _ in range(n):
    x, c = map(int, input().split())
    adds.append((x, c))
    indexs.append(x)
for _ in range(m):
    l, r = map(int, input().split())
    querys.append((l, r))
    indexs.append(l)
    indexs.append(r)
# list(set(...)) does not preserve order, so deduplicate first and sort the
# result; find() below relies on indexs being sorted.
indexs = sorted(set(indexs))
n = len(indexs)


def find(x):
    """Return the 1-based rank of coordinate x in indexs (x must be present)."""
    l, r = 0, n - 1
    while l < r:
        mid = (l + r) // 2
        if indexs[mid] >= x:
            r = mid
        else:
            l = mid + 1
    return l + 1


nums = [0] * (n + 1)  # nums[k]: total added at the k-th distinct coordinate
sums = [0] * (n + 1)  # prefix sums of nums
for x, c in adds:
    nums[find(x)] += c
for i in range(1, n + 1):
    sums[i] = sums[i - 1] + nums[i]
for l, r in querys:
    print(sums[find(r)] - sums[find(l) - 1])

区间合并

803. 区间合并

1
2
3
4
5
6
7
8
9
10
11
12
13
# 803. Count the intervals left after merging all overlapping ones.
n = int(input())
nums = [list(map(int, input().split())) for _ in range(n)]
nums.sort(key=lambda seg: seg[0])
merged = 0
end = float('-inf')  # right end of the interval currently being grown
for start, stop in nums:
    if start > end:
        # Disjoint from the current interval: start a new one.
        merged += 1
        end = stop
    else:
        end = max(end, stop)
print(merged)

数据结构

单链表

826. 单链表

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
# 826. Singly linked list stored in parallel arrays.
# e[i] is the value of node i and ne[i] the index of its successor (-1 ends
# the list); node i holds the (i + 1)-th inserted number.


def insert_head(x):
    """Insert value x as the new first node."""
    global head, idx
    e[idx] = x
    ne[idx] = head
    head = idx
    idx += 1


def insert(k, x):
    """Insert value x directly after node k."""
    global idx
    e[idx] = x
    ne[idx] = ne[k]
    ne[k] = idx
    idx += 1


def remove(k):
    """Unlink the node that follows node k."""
    ne[k] = ne[ne[k]]


N = 100010
e, ne = [0] * N, [0] * N
head, idx = -1, 0
n = int(input())
for _ in range(n):
    ops = input().split()
    if ops[0] == 'H':
        insert_head(int(ops[1]))
    elif ops[0] == 'I':
        insert(int(ops[1]) - 1, int(ops[2]))
    else:
        k = int(ops[1])
        if not k:
            # "D 0" removes the head node.
            head = ne[head]
        else:
            # Only for k >= 1: remove(-1) would silently rewrite ne[-1]
            # (the last array slot) through negative indexing.
            remove(k - 1)
i = head
res = []
while i != -1:
    res.append(e[i])
    i = ne[i]
print(' '.join(map(str, res)))

双链表

827. 双链表

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
# 827. Doubly linked list stored in parallel arrays.
# Node 0 is the left sentinel and node 1 the right sentinel, so the k-th
# inserted number lives at node index k + 1.
N = 100010
e, l, r = [0] * N, [0] * N, [0] * N
r[0], l[1], idx = 1, 0, 2


def insert(k, x):
    """Insert value x immediately to the right of node k."""
    global idx
    node = idx
    e[node] = x
    l[node], r[node] = k, r[k]
    l[r[k]] = node
    r[k] = node
    idx += 1


def remove(k):
    """Unlink node k from its neighbours."""
    left, right = l[k], r[k]
    r[left] = right
    l[right] = left


n = int(input())
for _ in range(n):
    cmd, *args = input().split()
    vals = [int(v) for v in args]
    if cmd == 'L':
        insert(0, vals[0])
    elif cmd == 'R':
        insert(l[1], vals[0])
    elif cmd == 'IL':
        insert(l[vals[0] + 1], vals[1])
    elif cmd == 'IR':
        insert(vals[0] + 1, vals[1])
    else:
        remove(vals[0] + 1)
node = r[0]
res = []
while node != 1:
    res.append(e[node])
    node = r[node]
print(' '.join(map(str, res)))

828. 模拟栈

1
2
3
4
5
6
7
8
9
10
11
12
# 828. Simulated stack driven by "push x" / "pop" / "query" / "empty" commands.
n = int(input())
stack = []
for _ in range(n):
    op, *args = input().split()
    if op == 'push':
        stack.append(args[0])  # values are kept as their input strings
    elif op == 'pop':
        stack.pop()
    elif op == 'query':
        print(stack[-1])
    elif op == 'empty':
        print('YES' if not stack else 'NO')

3302. 表达式求值

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
# 3302. Evaluate an infix expression with + - * / and parentheses using an
# operand stack and an operator stack.
dic = {'(': 0, '+': 1, '-': 1, '*': 2, '/': 2}  # operator precedence
ops, nums = [], []


def new_eval():
    """Pop one operator and two operands, apply the operator, push the result."""
    rhs = nums.pop()
    lhs = nums.pop()
    op = ops.pop()
    if op == '+':
        nums.append(lhs + rhs)
    elif op == '-':
        nums.append(lhs - rhs)
    elif op == '*':
        nums.append(lhs * rhs)
    elif op == '/':
        # int(a / b) truncates toward zero (floor division // would round down).
        nums.append(int(lhs / rhs))


a = input()
n = len(a)
i = 0
while i < n:
    c = a[i]
    if c.isdigit():
        # Read the whole multi-digit number; i ends just past its last digit.
        x = 0
        while i < n and a[i].isdigit():
            x = x * 10 + int(a[i])
            i += 1
        nums.append(x)
        continue
    if c == '(':
        ops.append(c)
    elif c == ')':
        while ops[-1] != '(':
            new_eval()
        ops.pop()
    else:
        # Reduce everything of equal or higher precedence (left associativity).
        while ops and dic[ops[-1]] >= dic[c]:
            new_eval()
        ops.append(c)
    i += 1
while ops:
    new_eval()
print(nums[-1])

队列

829. 模拟队列

1
2
3
4
5
6
7
8
9
10
11
12
13
# 829. Simulated queue driven by "push x" / "pop" / "query" / "empty" commands.
import collections
n = int(input())
queue = collections.deque()
for _ in range(n):
    op, *args = input().split()
    if op == 'push':
        queue.append(args[0])  # values are kept as their input strings
    elif op == 'pop':
        queue.popleft()
    elif op == 'query':
        print(queue[0])
    elif op == 'empty':
        print('YES' if not queue else 'NO')

单调栈

830. 单调栈

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
# 830. For each element, print the nearest element to its left that is
# strictly smaller, or -1 if none exists.
n = int(input())
nums = list(map(int, input().split()))
stack, res = [], []
for num in nums:
    # Candidates >= num can never be the answer for num or any later element.
    while stack and stack[-1] >= num:
        stack.pop()
    res.append(stack[-1] if stack else -1)
    stack.append(num)
print(' '.join(map(str, res)))

单调队列

154. 滑动窗口

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
# 154. Minimum and maximum of every length-k sliding window (monotonic queue).
n, k = map(int, input().split())
nums = list(map(int, input().split()))
q = [0] * 1000010  # array-backed deque of indices into nums


def window_extremes(should_pop):
    """Return the extreme value of each length-k window, left to right.

    should_pop(tail, cur) says whether the value at the deque's tail can be
    discarded once cur arrives; `>` yields minima and `<` yields maxima.
    """
    hh, tt = 0, -1
    out = []
    for i in range(n):
        if hh <= tt and i - k + 1 > q[hh]:  # front index slid out of the window
            hh += 1
        while hh <= tt and should_pop(nums[q[tt]], nums[i]):
            tt -= 1
        tt += 1
        q[tt] = i
        if i >= k - 1:
            out.append(nums[q[hh]])
    return out


res1 = window_extremes(lambda tail, cur: tail > cur)
res2 = window_extremes(lambda tail, cur: tail < cur)
print(' '.join(map(str, res1)))
print(' '.join(map(str, res2)))

KMP

831. KMP字符串

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
# 831. KMP: print the 0-based start of every occurrence of pattern p in text s.
n = int(input())
p = ' ' + input()  # pattern, made 1-indexed by a leading pad
m = int(input())
s = ' ' + input()  # text, made 1-indexed by a leading pad
ne = [0] * 1000010  # ne[i]: longest proper prefix of p[1..i] that is also its suffix


def advance(j, ch):
    """Given that p[1..j] is matched, return the matched length after reading ch."""
    while j and ch != p[j + 1]:
        j = ne[j]
    if ch == p[j + 1]:
        j += 1
    return j


j = 0
for i in range(2, n + 1):
    j = advance(j, p[i])
    ne[i] = j
j = 0
res = []
for i in range(1, m + 1):
    j = advance(j, s[i])
    if j == n:
        res.append(i - n)
        j = ne[j]
print(' '.join(map(str, res)))

Trie

835. Trie字符串统计

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
# 835. Trie string statistics: "I x" inserts x, "Q x" prints how many times
# x has been inserted. Strings are assumed to be lowercase letters.
tries = [[0] * 26]  # tries[p][t]: child of node p for letter t (0 = none); node 0 is the root
cnt = [0]  # cnt[p]: number of inserted strings that end at node p
idx = 1  # index the next new node will get (always len(tries))


def insert(string):
    """Insert string into the trie, creating nodes on demand."""
    global idx
    p = 0
    for char in string:
        t = ord(char) - 97
        if not tries[p][t]:
            # Nodes are allocated lazily, so capacity is bounded only by the
            # total length of inserted strings (a fixed pool can overflow).
            tries.append([0] * 26)
            cnt.append(0)
            tries[p][t] = idx
            idx += 1
        p = tries[p][t]
    cnt[p] += 1


def query(string):
    """Return how many times string has been inserted (0 if never)."""
    p = 0
    for char in string:
        t = ord(char) - 97
        if not tries[p][t]:
            return 0
        p = tries[p][t]
    return cnt[p]


n = int(input())
for _ in range(n):
    op, string = input().split()
    if op == 'I':
        insert(string)
    elif op == 'Q':
        print(query(string))

143. 最大异或对

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
# 143. Maximum XOR of any two input numbers, using a binary trie.
n = int(input())
nums = list(map(int, input().split()))
# Enough bits to represent every input (inputs are assumed non-negative).
BITS = max(nums, default=0).bit_length()
tries = [[0, 0]]  # tries[p][u]: child of node p for bit u (0 = none); node 0 is the root
idx, res = 0, 0  # idx: index of the most recently created node


def insert(x):
    """Insert the BITS-bit binary form of x, most significant bit first."""
    global idx
    p = 0
    for i in reversed(range(BITS)):
        u = x >> i & 1
        if not tries[p][u]:
            # Allocate lazily instead of pre-building millions of nodes.
            tries.append([0, 0])
            idx += 1
            tries[p][u] = idx
        p = tries[p][u]


def query(x):
    """Return the stored number whose XOR with x is largest (greedy, bit by bit)."""
    p, best = 0, 0
    for i in reversed(range(BITS)):
        u = x >> i & 1
        if tries[p][u ^ 1]:
            # Prefer the opposite bit: it sets this bit of the XOR to 1.
            best = best * 2 + (u ^ 1)
            p = tries[p][u ^ 1]
        else:
            best = best * 2 + u
            p = tries[p][u]
    return best


for num in nums:
    insert(num)
    res = max(res, query(num) ^ num)
print(res)

并查集

836. 合并集合

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
# 836. Union-find: "M a b" merges two sets, "Q a b" asks whether a and b
# are in the same set.
n, m = map(int, input().split())
p = list(range(n + 1))  # p[x]: parent of x; roots point to themselves


def find(x):
    """Return the root of x's set, compressing the path on the way.

    Iterative: unions without ranks can build parent chains much longer
    than Python's default recursion limit.
    """
    root = x
    while p[root] != root:
        root = p[root]
    while p[x] != root:
        p[x], x = root, p[x]
    return root


for _ in range(m):
    op, a, b = input().split()
    a, b = int(a), int(b)
    if op == 'M':
        p[find(a)] = find(b)
    elif op == 'Q':
        print('Yes' if find(a) == find(b) else 'No')

837. 连通块中点的数量

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
# 837. Union-find with component sizes: "C a b" connects, "Q1 a b" asks
# connectivity, "Q2 a" prints the size of a's component.
n, m = map(int, input().split())
p = list(range(n + 1))  # p[x]: parent of x; roots point to themselves
size = [1] * (n + 1)  # size[root]: number of points in that root's component


def find(x):
    """Return the root of x's set, compressing the path on the way.

    Iterative: unions without ranks can build parent chains much longer
    than Python's default recursion limit.
    """
    root = x
    while p[root] != root:
        root = p[root]
    while p[x] != root:
        p[x], x = root, p[x]
    return root


for _ in range(m):
    ops = input().split()
    if ops[0] == 'C':
        ra, rb = find(int(ops[1])), find(int(ops[2]))
        if ra != rb:
            size[rb] += size[ra]
            p[ra] = rb
    elif ops[0] == 'Q1':
        a, b = int(ops[1]), int(ops[2])
        print('Yes' if find(a) == find(b) else 'No')
    elif ops[0] == 'Q2':
        print(size[find(int(ops[1]))])

240. 食物链

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
# 240. Food chain: count false statements with a weighted union-find.
# d[x] is the distance from x to p[x]; after find(x) it is the distance to
# the root. Modulo 3, equal distances mean "same kind" and a difference of 1
# means "x eats y".
n, m = map(int, input().split())
p, d = list(range(n + 1)), [0] * (n + 1)
res = 0  # number of false statements


def find(x):
    """Return the root of x, compressing the path and folding distances into d.

    Iterative so that long parent chains cannot exceed Python's recursion limit.
    """
    path = []
    while p[x] != x:
        path.append(x)
        x = p[x]
    root = x
    # Walk outward from the node nearest the root. Each node's parent has
    # already been re-pointed at the root, so d[parent] is now parent->root.
    for node in reversed(path):
        d[node] += d[p[node]]
        p[node] = root
    return root


for _ in range(m):
    op, x, y = map(int, input().split())
    if x > n or y > n:
        res += 1
        continue
    px, py = find(x), find(y)
    diff = (d[x] - d[y]) % 3
    if op == 1:  # "x and y are the same kind"
        if px == py:
            if diff != 0:
                res += 1
        else:
            # Merge only distinct sets; writing d[px] when px == py would
            # give the root a non-zero distance to itself.
            p[px] = py
            d[px] = d[y] - d[x]
    elif op == 2:  # "x eats y"
        if px == py:
            if diff != 1:
                res += 1
        else:
            p[px] = py
            d[px] = d[y] - d[x] + 1
print(res)

838. 堆排序

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
# 838. Heap sort: print the m smallest of n numbers in ascending order.
n, m = map(int, input().split())
heap = [0] + list(map(int, input().split()))  # 1-based min-heap; heap[0] is padding


def down(k):
    """Sift heap[k] down until no child within the first n slots is smaller."""
    t = k
    if 2 * k <= n and heap[2 * k] < heap[t]:
        t = 2 * k
    if 2 * k + 1 <= n and heap[2 * k + 1] < heap[t]:
        t = 2 * k + 1
    if t != k:
        heap[t], heap[k] = heap[k], heap[t]
        down(t)


# Heapify bottom-up. Stop at index 1: down(0) would compare against the
# padding slot heap[0] and could swap it into the heap.
for i in range(n // 2, 0, -1):
    down(i)
for _ in range(m):
    print(heap[1], end=' ')
    heap[1] = heap[n]
    n -= 1
    down(1)

839. 模拟堆

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
# 839. Min-heap that supports deleting or changing the k-th inserted element.
# ph[k]: heap position of the k-th inserted element; hp[i]: insertion number
# of the element at heap position i.
N = 100010
heap, ph, hp = [0] * N, [0] * N, [0] * N
size, idx = 0, 0
n = int(input())


def swap(a, b):
    """Swap heap positions a and b, keeping ph and hp consistent."""
    ph[hp[a]], ph[hp[b]] = b, a
    hp[a], hp[b] = hp[b], hp[a]
    heap[a], heap[b] = heap[b], heap[a]


def down(k):
    """Sift the element at position k down while a child is smaller."""
    while True:
        smallest = k
        for child in (2 * k, 2 * k + 1):
            if child <= size and heap[child] < heap[smallest]:
                smallest = child
        if smallest == k:
            return
        swap(smallest, k)
        k = smallest


def up(k):
    """Sift the element at position k up while its parent is larger."""
    parent = k // 2
    while parent and heap[parent] > heap[k]:
        swap(parent, k)
        k = parent
        parent = k // 2


for _ in range(n):
    ops = input().split()
    cmd = ops[0]
    if cmd == 'I':
        size += 1
        idx += 1
        heap[size] = int(ops[1])
        ph[idx], hp[size] = size, idx
        up(size)
    elif cmd == 'PM':
        print(heap[1])
    elif cmd == 'DM':
        swap(1, size)
        size -= 1
        down(1)
    elif cmd == 'D':
        k = ph[int(ops[1])]
        swap(k, size)
        size -= 1
        # The moved element may need to go either way.
        down(k)
        up(k)
    elif cmd == 'C':
        k = ph[int(ops[1])]
        heap[k] = int(ops[2])
        down(k)
        up(k)

哈希表

840. 模拟散列表

拉链法

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
# 840. Hash set with separate chaining stored in arrays.
N = 100003  # table size
h, e, ne = [-1] * N, [0] * N, [0] * N  # h: bucket heads; e/ne: chain nodes
n = int(input())
idx = 0


def insert(x):
    """Prepend x to the chain of bucket x % N."""
    global idx
    bucket = x % N  # Python's % is non-negative here, even for negative x
    e[idx], ne[idx] = x, h[bucket]
    h[bucket] = idx
    idx += 1


def find(x):
    """Return True if x has been inserted."""
    node = h[x % N]
    while node != -1:
        if e[node] == x:
            return True
        node = ne[node]
    return False


for _ in range(n):
    op, x = input().split()
    if op == 'I':
        insert(int(x))
    elif op == 'Q':
        print('Yes' if find(int(x)) else 'No')

开放寻址法

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
# 840. Hash set with open addressing and linear probing.
N = 200003  # table size
null = 0x3f3f3f3f  # sentinel marking an empty slot
h = [null] * N
n = int(input())


def find(x):
    """Return the slot that holds x, or the empty slot where x would go."""
    slot = x % N
    while h[slot] != null and h[slot] != x:
        slot = (slot + 1) % N  # wrap around at the end of the table
    return slot


for _ in range(n):
    op, x = input().split()
    value = int(x)
    slot = find(value)
    if op == 'I':
        h[slot] = value
    elif op == 'Q':
        print('Yes' if h[slot] == value else 'No')

python 自带

1
2
3
4
5
6
7
8
9
# 840. Same task with Python's built-in set; keys are the raw input strings.
n = int(input())
seen = set()
for _ in range(n):
    op, x = input().split()
    if op == 'I':
        seen.add(x)
    elif op == 'Q':
        print('Yes' if x in seen else 'No')

841. 字符串哈希

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# 841. Polynomial string hashing: decide whether two substrings are equal.
n, m = map(int, input().split())
s = input()
Q, P = 1 << 64, 131  # modulus 2^64 and hash base
h, p = [0] * (n + 1), [1] * (n + 1)  # h[i]: hash of s[:i]; p[i]: P**i mod Q
for i in range(1, n + 1):
    h[i] = (h[i - 1] * P + ord(s[i - 1])) % Q
    p[i] = (p[i - 1] * P) % Q


def get(l, r):
    """Hash of s[l..r] (1-based, inclusive)."""
    return (h[r] - h[l - 1] * p[r - l + 1]) % Q


for _ in range(m):
    l1, r1, l2, r2 = map(int, input().split())
    print('Yes' if get(l1, r1) == get(l2, r2) else 'No')

搜索与图论

DFS

842. 排列数字

dfs 做法

1
2
3
4
5
6
7
8
9
10
11
12
13
14
# 842. Print every permutation of 1..n in lexicographic order.
n = int(input())
path = [0] * n
st = [False] * (n + 1)  # st[v]: v is already used in the current prefix


def dfs(x):
    """Fill path[x:] with all arrangements of unused numbers, printing each full one."""
    if x == n:
        print(' '.join(map(str, path)))
        return
    for v in range(1, n + 1):
        if st[v]:
            continue
        st[v] = True
        path[x] = v
        dfs(x + 1)
        st[v] = False


dfs(0)

python permutation方法

1
2
3
4
5
# 842. Same output using itertools.permutations (lexicographic for sorted input).
import itertools
n = int(input())
for perm in itertools.permutations(range(1, n + 1)):
    print(' '.join(str(v) for v in perm))

843. n-皇后问题

全排列

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
# 843. n-queens, placing one queen per row.
n = int(input())
g = [['.'] * n for _ in range(n)]
# col[y], dg[x + y], udg[n - x + y]: whether that column / diagonal is taken.
col, dg, udg = [0] * n, [0] * (2 * n), [0] * (2 * n)


def dfs(x):
    """Place queens in rows x..n-1; print the board after each full placement."""
    if x == n:
        for line in g:
            print(''.join(line))
        print()
        return
    for y in range(n):
        if col[y] or dg[x + y] or udg[n - x + y]:
            continue
        g[x][y] = 'Q'
        col[y] = dg[x + y] = udg[n - x + y] = 1
        dfs(x + 1)
        g[x][y] = '.'
        col[y] = dg[x + y] = udg[n - x + y] = 0


dfs(0)

原始暴力枚举

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
# 843. n-queens by deciding every cell individually (row-major order).
n = int(input())
g = [['.'] * n for _ in range(n)]
row, col, dg, udg = [0] * n, [0] * n, [0] * (2 * n), [0] * (2 * n)


def dfs(x, y, s):
    """Decide cell (x, y); s counts queens placed so far.

    Every board with exactly n queens is printed, followed by a blank line.
    """
    if y == n:  # wrap to the start of the next row
        x, y = x + 1, 0
    if x == n:
        if s == n:
            for line in g:
                print(''.join(line))
            print()
        return
    # Option 1: put a queen here if no row/column/diagonal conflict.
    if not (row[x] or col[y] or dg[x + y] or udg[n - x + y]):
        g[x][y] = 'Q'
        row[x] = col[y] = dg[x + y] = udg[n - x + y] = 1
        dfs(x, y + 1, s + 1)
        g[x][y] = '.'
        row[x] = col[y] = dg[x + y] = udg[n - x + y] = 0
    # Option 2: leave this cell empty.
    dfs(x, y + 1, s)


dfs(0, 0, 0)

BFS

844. 走迷宫

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
# 844. Maze: fewest moves from the top-left to the bottom-right cell (BFS).
from collections import deque
n, m = map(int, input().split())
g = [list(map(int, input().split())) for _ in range(n)]  # 0 = open, 1 = wall
dist = [[-1] * m for _ in range(n)]  # dist[x][y]: moves from (0, 0); -1 = unvisited
prev = [[None] * m for _ in range(n)]  # prev[x][y]: cell (x, y) was first reached from
q = deque([(0, 0)])
dist[0][0] = 0
while q:
    a, b = q.popleft()
    for dx, dy in ((0, 1), (1, 0), (0, -1), (-1, 0)):
        x, y = a + dx, b + dy
        if 0 <= x < n and 0 <= y < m and not g[x][y] and dist[x][y] == -1:
            dist[x][y] = dist[a][b] + 1
            prev[x][y] = (a, b)
            q.append((x, y))
print(dist[-1][-1])

# Path-reconstruction demo: walks the predecessors back from the exit,
# printing each cell. Off by default because the expected output is only
# the step count.
SHOW_PATH = False
if SHOW_PATH and dist[-1][-1] != -1:
    x, y = n - 1, m - 1
    while x > 0 or y > 0:
        x, y = prev[x][y]
        print(x, y)

845. 八数码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
# 845. Eight puzzle: fewest moves to reach "12345678x" (BFS over board strings).
from collections import deque
start = ''.join(input().split())
queue = deque([start])
d = {start: 0}  # d[state]: moves needed to reach state from start
target = '12345678x'


def swap(s, idx1, idx2):
    """Return a copy of s with the characters at idx1 and idx2 exchanged."""
    l, r = sorted((idx1, idx2))
    return s[:l] + s[r] + s[l + 1: r] + s[l] + s[r + 1:]


def bfs():
    """Return the move count from start to target, or -1 if unreachable."""
    while queue:
        state = queue.popleft()
        steps = d[state]
        if state == target:
            return steps
        pos = state.find('x')
        x, y = divmod(pos, 3)
        for dx, dy in ((0, 1), (1, 0), (0, -1), (-1, 0)):
            a, b = x + dx, y + dy
            if 0 <= a < 3 and 0 <= b < 3:
                nxt = swap(state, a * 3 + b, pos)
                if nxt not in d:
                    d[nxt] = steps + 1
                    queue.append(nxt)
    return -1


print(bfs())

树与图的深度优先遍历

846. 树的重心

用链表作为邻接表

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
# 846. Tree centroid: minimise the largest component left after deleting a vertex.
n = int(input())
h, e, ne = [-1] * (n + 1), [0] * (2 * n), [0] * (2 * n)  # array adjacency list
state = [False] * (n + 1)  # state[v]: v has been visited
idx, ans = 0, n


def add(a, b):
    """Add the directed edge a -> b."""
    global idx
    idx += 1
    e[idx] = b
    ne[idx] = h[a]
    h[a] = idx


def dfs(u):
    """Return the size of u's subtree and update ans for every vertex in it.

    For a vertex v, the largest component after deleting v is
    max(biggest child subtree, n - size[v]); ans keeps the minimum of that.
    Uses an explicit stack, because recursion on a path-shaped tree would
    exceed Python's recursion limit.
    """
    global ans
    parent = [0] * (n + 1)
    size = [1] * (n + 1)
    largest = [0] * (n + 1)  # largest[v]: size of v's biggest child subtree
    order = []
    state[u] = True
    stack = [u]
    while stack:
        v = stack.pop()
        order.append(v)
        cur = h[v]
        while cur != -1:
            j = e[cur]
            if not state[j]:
                state[j] = True
                parent[j] = v
                stack.append(j)
            cur = ne[cur]
    # A child is popped after its parent, so reverse order visits children first.
    for v in reversed(order):
        ans = min(ans, max(largest[v], n - size[v]))
        if v != u:
            pv = parent[v]
            size[pv] += size[v]
            largest[pv] = max(largest[pv], size[v])
    return size[u]


for _ in range(n - 1):
    a, b = map(int, input().split())
    add(a, b)
    add(b, a)
dfs(1)
print(ans)

使用python的list[list]

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
# 846. Tree centroid, with a list-of-lists adjacency list.
n = int(input())
adj_list = [[] for _ in range(n + 1)]
state = [False] * (n + 1)  # state[v]: v has been visited
ans = n
for _ in range(n - 1):
    a, b = map(int, input().split())
    adj_list[a].append(b)
    adj_list[b].append(a)


def dfs(u):
    """Return the size of u's subtree and update ans for every vertex in it.

    For a vertex v, the largest component after deleting v is
    max(biggest child subtree, n - size[v]); ans keeps the minimum of that.
    Uses an explicit stack, because recursion on a path-shaped tree would
    exceed Python's recursion limit.
    """
    global ans
    parent = [0] * (n + 1)
    size = [1] * (n + 1)
    largest = [0] * (n + 1)  # largest[v]: size of v's biggest child subtree
    order = []
    state[u] = True
    stack = [u]
    while stack:
        v = stack.pop()
        order.append(v)
        for j in adj_list[v]:
            if not state[j]:
                state[j] = True
                parent[j] = v
                stack.append(j)
    # A child is popped after its parent, so reverse order visits children first.
    for v in reversed(order):
        ans = min(ans, max(largest[v], n - size[v]))
        if v != u:
            pv = parent[v]
            size[pv] += size[v]
            largest[pv] = max(largest[pv], size[v])
    return size[u]


dfs(1)
print(ans)

树与图的广度优先遍历

847. 图中点的层次

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
# 847. Fewest edges from node 1 to node n in a directed graph (BFS).
from collections import deque
n, m = map(int, input().split())
h, e, ne = [-1] * (n + 1), [0] * (m + 1), [0] * (m + 1)  # array adjacency list
dist = [-1] * (n + 1)  # -1 marks an unvisited node
queue = deque([1])
dist[1] = 0
idx = 0


def add(a, b):
    """Add the directed edge a -> b."""
    global idx
    idx += 1
    e[idx], ne[idx] = b, h[a]
    h[a] = idx


for _ in range(m):
    a, b = map(int, input().split())
    add(a, b)


def bfs():
    """Return the distance from node 1 to node n, or -1 if n is unreachable."""
    while queue:
        node = queue.popleft()
        if node == n:
            return dist[node]
        edge = h[node]
        while edge != -1:
            nxt = e[edge]
            if dist[nxt] == -1:
                dist[nxt] = dist[node] + 1
                queue.append(nxt)
            edge = ne[edge]
    return -1


print(bfs())

使用python的list[list]

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
# 847. Fewest edges from node 1 to node n, with a list-of-lists adjacency list.
from collections import deque
n, m = map(int, input().split())
adj_list = [[] for _ in range(n + 1)]
# d[v]: BFS distance from node 1; -1 = not yet reached. A 0 marker would be
# ambiguous, because node 1 itself has distance 0 and could be revisited.
d = [-1] * (n + 1)
d[1] = 0
queue = deque([1])
for _ in range(m):
    a, b = map(int, input().split())
    adj_list[a].append(b)


def bfs():
    """Return the distance from node 1 to node n, or -1 if n is unreachable."""
    while queue:
        cur = queue.popleft()
        distance = d[cur]
        if cur == n:
            return distance
        for j in adj_list[cur]:
            if d[j] == -1:
                d[j] = distance + 1
                queue.append(j)
    return -1


print(bfs())

拓扑排序

848. 有向图的拓扑序列

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
# 848. Topological order of a directed graph, or -1 if it has a cycle.
from collections import deque
n, m = map(int, input().split())
h, e, ne = [-1] * (n + 1), [0] * (m + 1), [0] * (m + 1)  # array adjacency list
d = [0] * (n + 1)  # d[v]: current in-degree of v
idx = 0
queue = deque()


def add(a, b):
    """Add the directed edge a -> b."""
    global idx
    idx += 1
    e[idx], ne[idx] = b, h[a]
    h[a] = idx


for _ in range(m):
    a, b = map(int, input().split())
    add(a, b)
    d[b] += 1


def topsort():
    """Print a topological order (Kahn's algorithm), or -1 if none exists."""
    queue.extend(v for v in range(1, n + 1) if d[v] == 0)
    res = []
    while queue:
        node = queue.popleft()
        res.append(node)
        edge = h[node]
        while edge != -1:
            nxt = e[edge]
            d[nxt] -= 1
            if d[nxt] == 0:
                queue.append(nxt)
            edge = ne[edge]
    # Fewer than n emitted vertices means the rest lie on or behind a cycle.
    print(' '.join(map(str, res)) if len(res) == n else '-1')


topsort()

使用python的list[list]

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
# 848. Topological order, with a list-of-lists adjacency list.
from collections import deque
n, m = map(int, input().split())
adj_list = [[] for _ in range(n + 1)]
in_degree = [0] * (n + 1)
for _ in range(m):
    a, b = map(int, input().split())
    adj_list[a].append(b)
    in_degree[b] += 1
# Kahn's algorithm: start from every vertex without incoming edges.
queue = deque(v for v in range(1, n + 1) if in_degree[v] == 0)
res = []
while queue:
    node = queue.popleft()
    res.append(node)
    for nxt in adj_list[node]:
        in_degree[nxt] -= 1
        if in_degree[nxt] == 0:
            queue.append(nxt)
if len(res) == n:
    print(' '.join(map(str, res)))
else:
    print(-1)

Dijkstra

849. Dijkstra求最短路 I

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
# 849. Dense Dijkstra (adjacency matrix, O(n^2)) from node 1 to node n.
n, m = map(int, input().split())
g = [[float('inf')] * (n + 1) for _ in range(n + 1)]
dist = [float('inf')] * (n + 1)
state = [False] * (n + 1)  # state[v]: dist[v] is final
dist[1] = 0
for _ in range(m):
    a, b, c = map(int, input().split())
    g[a][b] = min(g[a][b], c)  # keep only the lightest parallel edge


def dijkstra():
    """Compute shortest distances from node 1; print dist[n] or -1 if unreachable."""
    for _ in range(n):
        # Closest unfinished vertex (first one on ties).
        t = -1
        for j in range(1, n + 1):
            if not state[j] and (t == -1 or dist[j] < dist[t]):
                t = j
        state[t] = True
        row = g[t]
        for j in range(1, n + 1):
            dist[j] = min(dist[j], dist[t] + row[j])
    print(dist[n] if dist[n] != float('inf') else -1)


dijkstra()

850. Dijkstra求最短路 II

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
# 850. Heap-optimised Dijkstra from node 1 to node n.
import heapq
n, m = map(int, input().split())
adj_list = [[] for _ in range(n + 1)]  # adj_list[u]: (neighbor, weight) pairs
dist = [float('inf')] * (n + 1)
dist[1] = 0
heap = []
heapq.heappush(heap, (0, 1))
s = set()  # nodes whose distance is final
for _ in range(m):
    x, y, z = map(int, input().split())
    adj_list[x].append((y, z))
while heap:
    d, node = heapq.heappop(heap)
    if node in s:
        continue  # stale entry: node was already finalised with a smaller distance
    s.add(node)
    for neighbor, weight in adj_list[node]:
        if dist[neighbor] > dist[node] + weight:
            dist[neighbor] = dist[node] + weight
            heapq.heappush(heap, (dist[neighbor], neighbor))
print(dist[n] if dist[n] != float('inf') else -1)

bellman-ford

853. 有边数限制的最短路

1
2
3
4
5
6
7
8
9
10
11
12
# 853. Shortest path from 1 to n using at most k edges (Bellman-Ford).
n, m, k = map(int, input().split())
e = [tuple(map(int, input().split())) for _ in range(m)]  # (a, b, weight)
dist = [float('inf')] * (n + 1)
dist[1] = 0
for _ in range(k):
    # Relax against a snapshot so each round extends paths by at most one edge.
    backup = dist[:]
    for a, b, w in e:
        if backup[a] + w < dist[b]:
            dist[b] = backup[a] + w
# inf plus any finite weight is still inf, so an exact comparison is safe.
print(dist[n] if dist[n] != float('inf') else 'impossible')

spfa

851. spfa求最短路

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
# 851. SPFA shortest path from node 1 to node n.
from collections import deque
n, m = map(int, input().split())
adj_list = [[] for _ in range(n + 1)]  # adj_list[u]: (neighbor, weight) pairs
for _ in range(m):
    a, b, w = map(int, input().split())
    adj_list[a].append((b, w))
dist = [float('inf')] * (n + 1)
in_queue = [False] * (n + 1)  # in_queue[v]: v is currently waiting in the queue
dist[1] = 0
in_queue[1] = True
queue = deque([1])
while queue:
    cur = queue.popleft()
    in_queue[cur] = False
    for neighbor, weight in adj_list[cur]:
        if dist[cur] + weight < dist[neighbor]:
            dist[neighbor] = dist[cur] + weight
            if not in_queue[neighbor]:
                in_queue[neighbor] = True
                queue.append(neighbor)
print(dist[n] if dist[n] != float('inf') else 'impossible')

852. spfa判断负环

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
# 852. Detect a negative cycle anywhere in the graph with SPFA.
from collections import deque
n, m = map(int, input().split())
adj_list = [[] for _ in range(n + 1)]  # adj_list[u]: (neighbor, weight) pairs
# All distances start at 0 and every vertex starts queued, which acts like
# a virtual source connected to each vertex.
dist, cnt = [0] * (n + 1), [0] * (n + 1)  # cnt[v]: edges on v's current best path
state = [True] * (n + 1)  # state[v]: v is currently in the queue
queue = deque(range(1, n + 1))
for _ in range(m):
    a, b, c = map(int, input().split())
    adj_list[a].append((b, c))


def spfa():
    """Return True if a negative cycle exists."""
    while queue:
        cur = queue.popleft()
        state[cur] = False
        for neighbor, weight in adj_list[cur]:
            if dist[cur] + weight >= dist[neighbor]:
                continue
            dist[neighbor] = dist[cur] + weight
            cnt[neighbor] = cnt[cur] + 1
            # A shortest path with n or more edges must repeat a vertex.
            if cnt[neighbor] >= n:
                return True
            if not state[neighbor]:
                state[neighbor] = True
                queue.append(neighbor)
    return False


print('Yes' if spfa() else 'No')

Floyd

854. Floyd求最短路

1
2
3
4
5
6
7
8
9
10
11
12
13
14
# 854. All-pairs shortest paths with Floyd-Warshall, then answer q queries.
n, m, q = map(int, input().split())
g = [[float('inf')] * (n + 1) for _ in range(n + 1)]
for i in range(1, n + 1):
    g[i][i] = 0
for _ in range(m):
    a, b, c = map(int, input().split())
    g[a][b] = min(g[a][b], c)  # keep only the lightest parallel edge
for k in range(1, n + 1):
    gk = g[k]
    for i in range(1, n + 1):
        gi = g[i]
        for j in range(1, n + 1):
            gi[j] = min(gi[j], gi[k] + gk[j])
for _ in range(q):
    a, b = map(int, input().split())
    print(g[a][b] if g[a][b] != float('inf') else 'impossible')

Prim

858. Prim算法求最小生成树

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
# 858. Minimum spanning tree weight with dense Prim's algorithm.
n, m = map(int, input().split())
g = [[float('inf')] * (n + 1) for _ in range(n + 1)]
state = [False] * (n + 1)  # state[v]: v is already in the tree
dist = [float('inf')] * (n + 1)  # dist[v]: lightest edge from v into the tree
for _ in range(m):
    a, b, c = map(int, input().split())
    g[a][b] = min(g[a][b], c)
    g[b][a] = g[a][b]


def prim():
    """Return the MST weight, or 'impossible' if the graph is disconnected."""
    total = 0
    for step in range(n):
        t = min((v for v in range(1, n + 1) if not state[v]), key=lambda v: dist[v])
        if step:
            # After the first vertex, every vertex must join via a real edge.
            if dist[t] == float('inf'):
                return 'impossible'
            total += dist[t]
        row = g[t]
        for j in range(1, n + 1):
            if row[j] < dist[j]:
                dist[j] = row[j]
        state[t] = True
    return total


print(prim())

Kruskal

859. Kruskal算法求最小生成树

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
# 859. Minimum spanning tree weight with Kruskal's algorithm.
n, m = map(int, input().split())
e = []
p = list(range(n + 1))  # union-find parent array
res, cnt = 0, 0  # total tree weight, number of tree edges taken


def find(x):
    """Return the root of x's set, compressing the path (iterative).

    Iterative because unions without ranks can build chains longer than
    Python's default recursion limit.
    """
    root = x
    while p[root] != root:
        root = p[root]
    while p[x] != root:
        p[x], x = root, p[x]
    return root


for _ in range(m):
    a, b, c = map(int, input().split())
    e.append((a, b, c))
e.sort(key=lambda edge: edge[2])
for a, b, c in e:
    a, b = find(a), find(b)
    if a != b:
        p[a] = b
        res += c
        cnt += 1
print(res if cnt == n - 1 else 'impossible')

染色法判定二分图

860. 染色法判定二分图

递归 dfs 在图很大(如长链)时会超过 Python 的递归深度限制而爆栈

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
# 860. Bipartite check by 2-coloring every connected component (DFS).
n, m = map(int, input().split())
adj_list = [[] for _ in range(n + 1)]
color = [0] * (n + 1)  # 0 = uncolored; 1 / -1 = the two sides
for _ in range(m):
    a, b = map(int, input().split())
    adj_list[a].append(b)
    adj_list[b].append(a)


def dfs(u, c):
    """Color u's component starting with color c; return False on a conflict.

    Uses an explicit stack: the recursive form can exceed Python's recursion
    limit on long paths.
    """
    color[u] = c
    stack = [u]
    while stack:
        node = stack.pop()
        for neighbor in adj_list[node]:
            if not color[neighbor]:
                color[neighbor] = -color[node]
                stack.append(neighbor)
            elif color[neighbor] == color[node]:
                return False
    return True


for i in range(1, n + 1):
    if not color[i]:
        if not dfs(i, 1):
            print('No')
            break
else:
    print('Yes')

bfs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
# 860. Bipartite check by 2-coloring every connected component (BFS).
from collections import deque
n, m = map(int, input().split())
adj_list = [[] for _ in range(n + 1)]
color = [0] * (n + 1)  # 0 = uncolored; 1 / -1 = the two sides
for _ in range(m):
    a, b = map(int, input().split())
    adj_list[a].append(b)
    adj_list[b].append(a)


def bfs(u):
    """Color u's component starting with color 1; return False on a conflict.

    A node is colored when it is enqueued, not when it is dequeued, so it
    enters the queue at most once and the search stays O(n + m).
    """
    color[u] = 1
    queue = deque([u])
    while queue:
        node = queue.popleft()
        for neighbor in adj_list[node]:
            if not color[neighbor]:
                color[neighbor] = -color[node]
                queue.append(neighbor)
            elif color[neighbor] == color[node]:
                return False
    return True


for i in range(1, n + 1):
    if not color[i]:
        if not bfs(i):
            print('No')
            break
else:
    print('Yes')

匈牙利算法

861. 二分图的最大匹配

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
# 861. Maximum bipartite matching with the Hungarian (augmenting path) method.
n1, n2, m = map(int, input().split())
n = max(n1, n2)
adj_list = [[] for _ in range(n + 1)]  # edges from left nodes to right nodes
match = [0] * (n + 1)  # match[b]: left node matched to right node b (0 = none)
for _ in range(m):
    a, b = map(int, input().split())
    adj_list[a].append(b)


def find(u):
    """Try to match left node u along an augmenting path; return True on success."""
    for v in adj_list[u]:
        if state[v]:
            continue
        state[v] = True
        # v is free, or its current partner can be moved to another right node.
        if match[v] == 0 or find(match[v]):
            match[v] = u
            return True
    return False


res = 0
for i in range(1, n1 + 1):
    state = [False] * (n + 1)  # right nodes already tried during this search
    if find(i):
        res += 1
print(res)

数学知识

质数

866. 试除法判定质数

1
2
3
4
5
6
7
8
9
10
# 866. Primality test by trial division.
import math
n = int(input())


def prime(x):
    """Return True if x is prime (checks divisors 2..floor(sqrt(x)))."""
    if x < 2:
        return False
    return all(x % d for d in range(2, int(math.sqrt(x) + 1)))


for _ in range(n):
    value = int(input())
    print('Yes' if prime(value) else 'No')

867. 分解质因数

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# 867. Prime factorisation by trial division.
import math
n = int(input())


def divid(x):
    """Print each prime factor of x with its exponent, one "p e" pair per line."""
    limit = int(math.sqrt(x) + 1)
    for i in range(2, limit):
        if x % i:
            continue
        exponent = 0
        while x % i == 0:
            x //= i
            exponent += 1
        print(i, exponent)
    # Whatever remains above 1 is a single prime factor larger than sqrt(x).
    if x > 1:
        print(x, 1)


for _ in range(n):
    divid(int(input()))
    print()

868. 筛质数

线性筛法:每个合数只会被它的最小质因数筛掉,时间复杂度 O(n)

1
2
3
4
5
6
7
8
9
10
11
12
13
# 868. Count primes <= n with a linear sieve: each composite is crossed out
# only by its smallest prime factor.
n = int(input())
state = [True] * (n + 1)  # state[v] becomes False once v is known to be composite
res = []  # primes found so far, ascending
for i in range(2, n + 1):
    if state[i]:
        res.append(i)
    for p in res:
        if p * i > n:
            break
        state[p * i] = False
        if i % p == 0:
            # p is i's smallest prime factor; larger primes are left to later i.
            break
print(len(res))

埃氏筛法 O(n log log n)

1
2
3
4
5
6
7
8
9
# 868. Count primes <= n with the sieve of Eratosthenes.
n = int(input())
state = [True] * (n + 1)  # state[v] becomes False once v is crossed out
res = 0
for i in range(2, n + 1):
    if not state[i]:
        continue
    res += 1
    # Cross out every multiple 2i, 3i, ... up to n in one slice assignment.
    state[2 * i::i] = [False] * len(range(2 * i, n + 1, i))
print(res)

约数

869. 试除法求约数

1
2
3
4
5
6
7
8
9
10
11
12
13
14
# 869. List all divisors of each input number.
import math
n = int(input())


def divisor(x):
    """Print all divisors of x in ascending order, space-separated."""
    small, large = [], []
    for i in range(1, int(math.sqrt(x) + 1)):
        if x % i == 0:
            small.append(i)
            if i * i != x:  # don't list a square root twice
                large.append(x // i)
    print(' '.join(map(str, sorted(small + large))))


for _ in range(n):
    divisor(int(input()))

870. 约数个数

$N = p_{1}^{\alpha_{1}} \cdot p_{2}^{\alpha_{2}} \cdots p_{k}^{\alpha_{k}}$

约数个数 $res = (\alpha_{1} + 1)(\alpha_{2} + 1) \cdots (\alpha_{k} + 1)$

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
# 870. Number of divisors of the product of all inputs, modulo 1e9 + 7.
import math
MOD = 10 ** 9 + 7  # an int, so all modular arithmetic stays exact
n = int(input())
exponents = {}  # exponents[p]: total exponent of prime p in the product


def divisor(x):
    """Add the prime factorisation of x into `exponents`."""
    for i in range(2, int(math.sqrt(x) + 1)):
        while x % i == 0:
            x //= i
            exponents[i] = exponents.get(i, 0) + 1
    if x > 1:
        exponents[x] = exponents.get(x, 0) + 1


for _ in range(n):
    divisor(int(input()))
# Divisor count = product of (exponent + 1) over all primes.
res = 1
for a in exponents.values():
    res = res * (a + 1) % MOD
print(res)

871. 约数之和

约数之和 $res = (p_{1}^{0} + p_{1}^{1} + \cdots + p_{1}^{\alpha_{1}})(p_{2}^{0} + p_{2}^{1} + \cdots + p_{2}^{\alpha_{2}}) \cdots (p_{k}^{0} + p_{k}^{1} + \cdots + p_{k}^{\alpha_{k}})$

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
# 871. Sum of divisors of the product of all inputs, modulo 1e9 + 7.
import math
n = int(input())
exponents = {}  # exponents[p]: total exponent of prime p (avoids shadowing builtin dict)
MOD = 10 ** 9 + 7


def divisor(x):
    """Add the prime factorisation of x into `exponents`."""
    for i in range(2, int(math.sqrt(x) + 1)):
        while x % i == 0:
            x //= i
            exponents[i] = exponents.get(i, 0) + 1
    if x > 1:
        exponents[x] = exponents.get(x, 0) + 1


for _ in range(n):
    divisor(int(input()))
res = 1
for p, a in exponents.items():
    # Horner's rule: after a steps, t = 1 + p + p^2 + ... + p^a (mod MOD).
    t = 1
    for _ in range(a):
        t = (t * p + 1) % MOD
    res = res * t % MOD
print(res)

872. 最大公约数

辗转相除法

1
2
3
4
5
6
# 872. Greatest common divisor with the Euclidean algorithm.
n = int(input())


def gcd(a, b):
    """Return gcd(a, b) by repeatedly replacing (a, b) with (b, a % b)."""
    while b:
        a, b = b, a % b
    return a


for _ in range(n):
    a, b = map(int, input().split())
    print(gcd(a, b))

python自带

1
2
3
4
5
# 872. Same task using the standard library's math.gcd.
import math
n = int(input())
for _ in range(n):
    a, b = map(int, input().split())
    common = math.gcd(a, b)
    print(common)

欧拉函数

873. 欧拉函数

$\varphi(1) = 1$

对 n > 1,设 $p_{1}, p_{2}, \cdots, p_{k}$ 为 n 的全部不同质因数,则 $\varphi(n) = n \prod_{i=1}^{k}\left(1 - \frac{1}{p_{i}}\right)$

当n是质数:$\varphi(n) = n - 1$

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# 873. Euler's totient phi(x) for each input x.
import math
n = int(input())


def euler(x):
    """Print phi(x) = x * prod(1 - 1/p) over the distinct primes p dividing x.

    Uses exact integer steps res // p * (p - 1) instead of float
    multiplication, which can truncate to the wrong integer.
    """
    res = x
    for i in range(2, int(math.sqrt(x) + 1)):
        if x % i == 0:
            # i is a new prime factor of the original x, so it still divides res.
            res = res // i * (i - 1)
            while x % i == 0:
                x //= i
    if x > 1:
        res = res // x * (x - 1)
    print(res)


for _ in range(n):
    euler(int(input()))

874. 筛法求欧拉函数

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
def euler(x):
    """Print and return phi(1) + phi(2) + ... + phi(x) using a linear (Euler) sieve.

    The bound is the parameter x (it used to be ignored in favour of a global).
    """
    state = [False] * (x + 1)  # state[k] becomes True once k is known composite
    phi = [0] * (x + 1)
    if x >= 1:
        phi[1] = 1
    primes = []
    for i in range(2, x + 1):
        if not state[i]:
            primes.append(i)
            phi[i] = i - 1
        for p in primes:
            if p * i > x:
                break
            state[p * i] = True
            if i % p == 0:
                # p already divides i: phi(p * i) = phi(i) * p
                phi[p * i] = phi[i] * p
                break
            # p does not divide i (coprime): phi(p * i) = phi(i) * (p - 1)
            phi[p * i] = phi[i] * (p - 1)
    res = sum(phi)
    print(res)
    return res


if __name__ == '__main__':
    euler(int(input()))

快速幂

875. 快速幂

费马小定理:若 $p$ 为质数且 $p \nmid a$,则 $a^{p - 1} \equiv 1 \pmod{p}$

1
2
3
4
5
6
7
8
9
n = int(input())
for _ in range(n):
    base, exp, mod = map(int, input().split())
    # Square-and-multiply: walk the bits of exp from the lowest upwards.
    result = 1
    while exp:
        if exp % 2 == 1:
            result = result * base % mod
        base = base * base % mod
        exp //= 2
    print(result)

876. 快速幂求逆元

当n为质数且 $n \nmid b$ 时,由费马小定理,b 模 n 的乘法逆元为 $x = b^{n - 2} \bmod n$

当n不是质数时,使用扩展欧几里得算法求逆元,即解 $b x \equiv 1 \pmod{n}$(要求 $\gcd(b, n) = 1$)

1
2
3
4
5
6
7
8
9
10
11
def quick_mi(a, k, p):
    """Return a**k mod p by square-and-multiply."""
    acc, base = 1, a
    while k:
        if k & 1:
            acc = acc * base % p
        base = base * base % p
        k >>= 1
    return acc


n = int(input())
for _ in range(n):
    a, p = map(int, input().split())
    if a % p == 0:
        # a is a multiple of the prime p, so it has no inverse modulo p.
        print('impossible')
    else:
        # Fermat: a^(p-2) is the inverse of a modulo the prime p.
        print(quick_mi(a, p - 2, p))

扩展欧几里得算法

877. 扩展欧几里得算法

求解$ax + by = gcd(a, b)$

当b=0时 $ax+by=a$ 故而 $x=1, y=0$

当 $b \neq 0$ 时,设 $x', y'$ 满足 $b x' + (a \bmod b) y' = \gcd(a, b)$,则 $x = y', \quad y = x' - \lfloor\frac{a}{b}\rfloor \cdot y'$

1
2
3
4
5
6
7
8
9
10
def exgcd(a, b):
    """Return (x, y) with a*x + b*y == gcd(a, b), via the extended Euclidean algorithm."""
    if b == 0:
        return 1, 0
    x1, y1 = exgcd(b, a % b)  # b*x1 + (a % b)*y1 == gcd(a, b)
    return y1, x1 - (a // b) * y1


n = int(input())
for _ in range(n):
    a, b = map(int, input().split())
    print(*exgcd(a, b))

878. 线性同余方程

当且仅当 $\gcd(a,m) \mid b$ 时有解;先求出一组解使得 $a x_{0} + m y_{0} = \gcd(a,m)$,

所以 $x = x_{0} \cdot \frac{b}{\gcd(a,m)} \bmod m$

1
2
3
4
5
6
7
8
9
10
11
def exgcd(a, b):
    """Return (d, x, y) with d == gcd(a, b) and a*x + b*y == d."""
    if b == 0:
        return a, 1, 0
    d, x1, y1 = exgcd(b, a % b)
    return d, y1, x1 - (a // b) * y1


n = int(input())
for _ in range(n):
    a, b, m = map(int, input().split())
    d, x, _ = exgcd(a, m)
    if b % d:
        print('impossible')
    else:
        # Scale a*x ≡ d (mod m) up to a*x ≡ b (mod m), then reduce into [0, m).
        print(x * b // d % m)

中国剩余定理

$M = m_{1} \cdot m_{2} \cdots m_{k}, \quad M_{i} = \frac{M}{m_{i}}$,$M_{i}^{-1}$ 表示 $M_{i}$ 模 $m_{i}$ 的逆,即 $M_{i} \cdot M_{i}^{-1} \equiv 1 \pmod{m_{i}}$(要求 $m_{1}, \ldots, m_{k}$ 两两互质)

$x \equiv a_{1} M_{1} M_{1}^{-1} + a_{2} M_{2} M_{2}^{-1} + \cdots + a_{k} M_{k} M_{k}^{-1} \pmod{M}$

204. 表达整数的奇怪方式

注:本题中 $m_1, m_2, \cdots, m_k$ 不一定两两互质,因此不能直接套用上面的公式,而是用扩展欧几里得把同余方程两两合并

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
def exgcd(a, b):
    """Return (d, x, y) with d == gcd(a, b) and a*x + b*y == d."""
    if b == 0:
        return a, 1, 0
    d, x1, y1 = exgcd(b, a % b)
    return d, y1, x1 - (a // b) * y1


n = int(input())
# Running state: x ≡ m1 (mod a1). Merge the remaining congruences one by one.
a1, m1 = map(int, input().split())
for _ in range(n - 1):
    a2, m2 = map(int, input().split())
    # Need a1*k1 - a2*k2 == m2 - m1; solvable iff gcd(a1, a2) divides it.
    d, k1, _ = exgcd(a1, a2)
    diff = m2 - m1
    if diff % d:
        print(-1)
        break
    step = a2 // d
    k1 = k1 * (diff // d) % step  # smallest non-negative k1
    m1 += a1 * k1
    a1 = a1 * a2 // d  # lcm(a1, a2)
else:
    print(m1 % a1)

高斯消元

883. 高斯消元解线性方程组

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
def gauss(g, n, eps=1e-6):
    """Solve an n-variable linear system by Gauss-Jordan elimination with partial pivoting.

    g is the augmented matrix (n rows of n coefficients followed by the
    right-hand side) and is modified in place.
    Returns (status, solution):
      status 0 -> unique solution, `solution` is a list of n floats;
      status 1 -> infinitely many solutions, `solution` is None;
      status 2 -> no solution, `solution` is None.
    """
    r = 0  # next row that will receive a pivot (r <= c always holds)
    for c in range(n):
        # Pivot search must cover every row not yet used as a pivot (r..n-1).
        # Starting at c instead would skip rows after a free column and can
        # misreport a consistent system as having no solution.
        t = max(range(r, n), key=lambda i: abs(g[i][c]))
        if abs(g[t][c]) < eps:
            continue  # no pivot in this column: it is a free variable
        g[r], g[t] = g[t], g[r]
        pivot = g[r][c]
        for j in range(c, n + 1):
            g[r][j] /= pivot  # pivot entry becomes exactly 1
        for i in range(r + 1, n):
            factor = g[i][c]
            if abs(factor) > eps:
                for j in range(c, n + 1):
                    g[i][j] -= factor * g[r][j]
        r += 1
    if r < n:
        # Rows r..n-1 now have all-zero coefficients; a non-zero right-hand
        # side there reads 0 = b, so the system is inconsistent.
        if any(abs(g[i][n]) > eps for i in range(r, n)):
            return 2, None
        return 1, None
    # Back-substitution on the unit upper-triangular system.
    for i in range(n - 1, -1, -1):
        for j in range(i + 1, n):
            g[i][n] -= g[i][j] * g[j][n]
    return 0, [g[i][n] for i in range(n)]


def main():
    n = int(input())
    g = [list(map(float, input().split())) for _ in range(n)]
    status, xs = gauss(g, n)
    if status == 2:
        print('No solution')
    elif status == 1:
        print('Infinite group solutions')
    else:
        for x in xs:
            s = f'{x:.2f}'
            # A tiny negative value would otherwise be printed as "-0.00".
            print('0.00' if s == '-0.00' else s)


if __name__ == '__main__':
    main()

884. 高斯消元解异或线性方程组

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
n = int(input())
g = [list(map(int, input().split())) for _ in range(n)]


def gauss():
    """Gaussian elimination over GF(2) on the global augmented matrix g; prints the result."""
    row = 0
    for col in range(n):
        # First row at or below `row` with a 1 in this column (falls back to `row`).
        pivot = next((i for i in range(row, n) if g[i][col]), row)
        if not g[pivot][col]:
            continue
        g[pivot][col:], g[row][col:] = g[row][col:], g[pivot][col:]
        for i in range(row + 1, n):
            if g[i][col]:
                for j in range(col, n + 1):
                    g[i][j] ^= g[row][j]
        row += 1
    if row < n:
        # Leftover rows are all-zero on the left; a 1 on the right means 0 = 1.
        if any(g[i][n] for i in range(row, n)):
            print('No solution')
        else:
            print('Multiple sets of solutions')
        return
    for i in reversed(range(n)):
        for j in range(i + 1, n):
            g[i][n] ^= g[i][j] & g[j][n]
    for i in range(n):
        print(g[i][n])


gauss()

求组合数

885. 求组合数 I

$C_{a}^{b} = C_{a - 1}^{b - 1} + C_{a - 1}^{b}$

1
2
3
4
5
6
7
8
9
n = int(input())
N, MOD = 2010, int(1e9+7)
# Pascal's triangle: g[i][j] = C(i, j) mod MOD. Column 0 is preset to 1 and
# every other entry is the sum of the two entries above it.
g = [[1] + [0] * N for _ in range(N)]
for i in range(1, N):
    above, row = g[i - 1], g[i]
    for j in range(1, i + 1):
        row[j] = (above[j] + above[j - 1]) % MOD
for _ in range(n):
    a, b = map(int, input().split())
    print(g[a][b])

886. 求组合数 II

注 $\frac{a}{b} \enspace mod \enspace p \neq \frac{a \enspace mod \enspace p}{b \enspace mod \enspace p}$

可以用逆元计算 $\frac{a}{b} \enspace mod \enspace p = a \times b^{-1} \enspace mod \enspace p$

$C^{b}_{a} = \frac{a!}{b! * (a - b)!} = a! * infact(b!) * infact((a - b)!)$

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
MOD = int(1e9 + 7)
N = 100010


def qmi(a, k, p):
    """Return a**k mod p by binary exponentiation."""
    res = 1
    while k:
        if k & 1:
            res = res * a % p
        k >>= 1
        a = a * a % p
    return res


def build_tables(size=N, mod=MOD):
    """Return (fact, infact) with fact[i] = i! mod `mod` and infact[i] its inverse, 0 <= i < size.

    Requires `mod` prime and 1 <= size <= mod so that every factorial is
    invertible. Only one modular exponentiation is done, for the largest
    factorial. The rest follow from 1/(i-1)! = i * 1/i!, which is O(size).
    """
    fact = [1] * size
    for i in range(1, size):
        fact[i] = fact[i - 1] * i % mod
    infact = [1] * size
    infact[size - 1] = qmi(fact[size - 1], mod - 2, mod)
    for i in range(size - 1, 0, -1):
        infact[i - 1] = infact[i] * i % mod
    return fact, infact


def main():
    n = int(input())
    fact, infact = build_tables()
    for _ in range(n):
        a, b = map(int, input().split())
        # C(a, b) = a! / (b! (a - b)!)
        print(fact[a] * infact[b] * infact[a - b] % MOD)


if __name__ == '__main__':
    main()

887. 求组合数 III

卢卡斯定理 Lucas,复杂度 $O(\log_{p} N \cdot p \log p)$

$C_{a}^{b} \equiv C_{\lfloor a/p \rfloor}^{\lfloor b/p \rfloor} \cdot C_{a \bmod p}^{b \bmod p} \pmod{p}$

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
def qmi(a, k, p):
    """Return a**k mod p by binary exponentiation."""
    res = 1
    while k:
        if k & 1:
            res = res * a % p
        k >>= 1
        a = a * a % p
    return res


def C(a, b, p):
    """C(a, b) mod prime p for 0 <= a, b < p; 0 when b > a.

    p is now an explicit argument; it used to be read from a module-level
    variable assigned inside the input loop.
    """
    if b > a:
        return 0
    res = 1
    j = a
    for i in range(1, b + 1):
        res = res * j % p
        res = res * qmi(i, p - 2, p) % p  # divide by i via Fermat inverse
        j -= 1
    return res


def lucas(a, b, p):
    """C(a, b) mod prime p via Lucas' theorem (works for a, b much larger than p)."""
    if a < p and b < p:
        return C(a, b, p)
    return C(a % p, b % p, p) * lucas(a // p, b // p, p) % p


def main():
    n = int(input())
    for _ in range(n):
        a, b, p = map(int, input().split())
        print(lucas(a, b, p))


if __name__ == '__main__':
    main()

888. 求组合数 IV

1
2
3
import math

a, b = map(int, input().split())
# Python ints have arbitrary precision, so the exact binomial can be printed directly.
numerator = math.factorial(a)
denominator = math.factorial(b) * math.factorial(a - b)
print(numerator // denominator)

889. 满足条件的01序列

卡特兰数 $ans = C_{2n}^{n} - C_{2n}^{n - 1} = \frac{C_{2n}^{n}}{n + 1}$

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
n = int(input())
p = int(1e9 + 7)


def qmi(a, k, p):
    """Return a**k mod p by binary exponentiation."""
    result = 1
    while k:
        if k & 1:
            result = result * a % p
        a = a * a % p
        k >>= 1
    return result


# C(2n, n) = (2n)(2n-1)...(n+1) / n!, accumulated one factor at a time.
res = 1
for i in range(1, n + 1):
    res = res * (2 * n - i + 1) % p
    res = res * qmi(i, p - 2, p) % p
# Catalan number = C(2n, n) / (n + 1); the division uses the modular inverse.
print(res * qmi(n + 1, p - 2, p) % p)

使用公式+python硬解(很慢)

1
2
3
import math

n = int(input())
MOD = int(1e9 + 7)
# Exact Catalan number (2n)! / (n! * n! * (n + 1)), reduced only at the end.
catalan = math.factorial(2 * n) // (math.factorial(n) ** 2 * (n + 1))
print(catalan % MOD)

容斥原理

$$
\left|\bigcup_{i=1}^{m} S_{i}\right| = \sum_{i}|S_{i}| - \sum_{i<j}|S_{i} \cap S_{j}| + \sum_{i<j<l}|S_{i} \cap S_{j} \cap S_{l}| - \cdots + (-1)^{m-1}\left|\bigcap_{i=1}^{m} S_{i}\right|
$$

890. 能被整除的数

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
n, m = map(int, input().split())
p = list(map(int, input().split()))
res = 0
# Every non-empty subset of the m primes, encoded as a bitmask.
for mask in range(1, 1 << m):
    prod, size = 1, 0
    overflow = False
    for j in range(m):
        if not mask >> j & 1:
            continue
        if prod * p[j] > n:
            overflow = True  # no number <= n is divisible by this product
            break
        prod *= p[j]
        size += 1
    if overflow:
        continue
    # Inclusion-exclusion: odd-sized subsets are added, even-sized subtracted.
    res += n // prod if size % 2 else -(n // prod)
print(res)

博弈论

891. Nim游戏

mex(S) 表示不属于集合 S 的最小非负整数

1
2
3
4
5
6
n = int(input())
nums = list(map(int, input().split()))
# Classic Nim: the first player wins iff the XOR of all pile sizes is non-zero.
xor_sum = nums[0]
for pile in (nums[i] for i in range(1, n)):
    xor_sum ^= pile
print('Yes' if xor_sum else 'No')

892. 台阶-Nim游戏

1
2
3
4
5
6
n = int(input())
nums = list(map(int, input().split()))
# Staircase Nim: only piles on odd-numbered steps (0-based indices 0, 2, 4, ...)
# matter; the first player wins iff their XOR is non-zero.
res = nums[0]
for stones in (nums[i] for i in range(2, n, 2)):
    res ^= stones
print('Yes' if res else 'No')

893. 集合-Nim游戏

$SG(x) = \operatorname{mex}\{SG(y_{1}), SG(y_{2}), \cdots, SG(y_{k})\}$,其中 $y_{1}, \ldots, y_{k}$ 为状态 $x$ 的所有后继状态

$SG(G)=SG(G_{1})\oplus SG(G_{2}) \oplus \cdots \oplus SG(G_{m})$

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
def sg_values(limit, moves):
    """SG value of every pile size 0..limit when a move removes exactly s stones, s in `moves`.

    Assumes every move size is positive. f[x] depends only on smaller sizes,
    so the table is filled bottom-up. This avoids the ~10^4-deep recursion
    that a recursive sg() needs when a move of size 1 is allowed.
    """
    f = [0] * (limit + 1)
    for x in range(limit + 1):
        reachable = {f[x - s] for s in moves if x >= s}
        mex = 0
        while mex in reachable:
            mex += 1
        f[x] = mex
    return f


def first_player_wins(nums, moves):
    """True iff the XOR of the SG values of all piles in `nums` is non-zero."""
    f = sg_values(max(nums, default=0), moves)
    res = 0
    for num in nums:
        res ^= f[num]
    return res != 0


def main():
    int(input())  # k: number of move sizes (the next line carries them)
    s = list(map(int, input().split()))
    int(input())  # n: number of piles (the next line carries them)
    nums = list(map(int, input().split()))
    print('Yes' if first_player_wins(nums, s) else 'No')


if __name__ == '__main__':
    main()

894. 拆分-Nim游戏

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
n = int(input())
nums = list(map(int, input().split()))
f = [-1] * 101  # memo: f[x] == -1 means "not computed yet"


def sg(x):
    """SG value of a pile of x beans; a move replaces it by two piles of sizes i, j < x."""
    if f[x] == -1:
        seen = {sg(i) ^ sg(j) for i in range(x) for j in range(i + 1)}
        mex = 0
        while mex in seen:
            mex += 1
        f[x] = mex
    return f[x]


def nim(n, nums):
    """XOR of the SG values of all piles; non-zero means the first player wins."""
    total = 0
    for pile in nums:
        total ^= sg(pile)
    return total


print('Yes' if nim(n, nums) else 'No')

动态规划

背包问题

2. 01背包问题

二维dp

1
2
3
4
5
6
7
8
9
10
11
12
13
n, m = map(int, input().split())
v, w = [0] * (n + 1), [0] * (n + 1)
for i in range(1, n + 1):
    v[i], w[i] = map(int, input().split())
# f[i][j]: best total value using the first i items with capacity j.
f = [[0] * (m + 1) for _ in range(n + 1)]
for i in range(1, n + 1):
    prev, cur = f[i - 1], f[i]
    for j in range(1, m + 1):
        skip = prev[j]
        cur[j] = max(skip, prev[j - v[i]] + w[i]) if j >= v[i] else skip
print(f[n][m])

一维dp

1
2
3
4
5
6
7
8
9
10
11
n, m = map(int, input().split())
items = [tuple(map(int, input().split())) for _ in range(n)]
f = [0] * (m + 1)
for vol, val in items:
    # Descending capacity keeps f[j - vol] from the previous item round,
    # so each item is used at most once.
    for j in range(m, vol - 1, -1):
        f[j] = max(f[j], f[j - vol] + val)
print(f[m])

3. 完全背包问题

二维dp

1
2
3
4
5
6
7
8
9
10
11
12
13
n, m = map(int, input().split())
v, w = [0] * (n + 1), [0] * (n + 1)
for i in range(1, n + 1):
    v[i], w[i] = map(int, input().split())
# f[i][j]: best value with the first i item kinds (unlimited copies), capacity j.
f = [[0] * (m + 1) for _ in range(n + 1)]
for i in range(1, n + 1):
    prev, cur = f[i - 1], f[i]
    for j in range(1, m + 1):
        cur[j] = prev[j]
        if j >= v[i]:
            # Reading the current row allows item i to be taken again.
            cur[j] = max(cur[j], cur[j - v[i]] + w[i])
print(f[n][m])

一维dp

1
2
3
4
5
6
7
8
9
10
11
12
n, m = map(int, input().split())
items = [tuple(map(int, input().split())) for _ in range(n)]
f = [0] * (m + 1)
for vol, val in items:
    # Ascending capacity lets f[j - vol] already include this item,
    # i.e. unlimited copies are allowed.
    for j in range(max(vol, 1), m + 1):
        f[j] = max(f[j], f[j - vol] + val)
print(f[m])

4. 多重背包问题 I

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
n, m = map(int, input().split())
v, w, s = [0] * (n + 1), [0] * (n + 1), [0] * (n + 1)
for i in range(1, n + 1):
    v[i], w[i], s[i] = map(int, input().split())
f = [[0] * (m + 1) for _ in range(n + 1)]
for i in range(1, n + 1):
    for j in range(1, m + 1):
        # Try k = 1..s[i] copies of item i while they still fit; k = 0 is the start value.
        best, k = f[i - 1][j], 1
        while k <= s[i] and k * v[i] <= j:
            best = max(best, f[i - 1][j - k * v[i]] + k * w[i])
            k += 1
        f[i][j] = best
print(f[n][m])

5. 多重背包问题 II

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
def bundles(count):
    """Split `count` copies into bundle sizes 1, 2, 4, ... plus a final remainder."""
    size = 1
    while size < count:
        yield size
        count -= size
        size *= 2
    if count:
        yield count


n, m = map(int, input().split())
N = 20010
v, w, f = [0] * (N + 1), [0] * (N + 1), [0] * (N + 1)
idx = 1
for _ in range(n):
    a, b, c = map(int, input().split())
    for size in bundles(c):
        v[idx], w[idx] = a * size, b * size
        idx += 1
# Each bundle is now an ordinary 0/1 item.
for i in range(1, idx):
    vi, wi = v[i], w[i]
    for j in range(m, vi - 1, -1):
        f[j] = max(f[j], f[j - vi] + wi)
print(f[m])

9. 分组背包问题

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
n, m = map(int, input().split())
N = 101
v = [[0] * N for _ in range(N)]
w = [[0] * N for _ in range(N)]
s, f = [0] * N, [0] * N
for i in range(1, n + 1):
    s[i] = int(input())
    for j in range(1, s[i] + 1):
        v[i][j], w[i][j] = map(int, input().split())
for i in range(1, n + 1):
    group = list(zip(v[i][1:s[i] + 1], w[i][1:s[i] + 1]))
    # Descending capacity: at most one item of this group is added to each f[j].
    for j in range(m, 0, -1):
        for vol, val in group:
            if vol <= j:
                f[j] = max(f[j], f[j - vol] + val)
print(f[m])

线性DP

898. 数字三角形

1
2
3
4
5
6
7
8
9
10
11
n = int(input())
NEG_INF = -1e9  # sentinel for cells outside the triangle
a = [[NEG_INF] * (n + 1)] + [[NEG_INF] + list(map(int, input().split())) for _ in range(n)]
f = [[NEG_INF] * (n + 1) for _ in range(n + 1)]
f[1][1] = a[1][1]
for i in range(2, n + 1):
    above = f[i - 1]
    for j in range(1, i + 1):
        # Arrive from the upper-left or from directly above, whichever is larger.
        f[i][j] = max(above[j - 1], above[j]) + a[i][j]
print(max(f[n]))

895. 最长上升子序列

1
2
3
4
5
6
7
8
n = int(input())
a = [0] + list(map(int, input().split()))
f = [1] * (n + 1)
for i in range(1, n + 1):
    # Longest strictly increasing subsequence ending at a[i].
    f[i] = max([f[j] + 1 for j in range(1, i) if a[j] < a[i]], default=1)
print(max(f))

897. 最长公共子序列

1
2
3
4
5
6
7
8
9
n, m = map(int, input().split())
a, b = ' ' + input(), ' ' + input()  # leading pad makes the strings 1-indexed
f = [[0] * (m + 1) for _ in range(n + 1)]
for i in range(1, n + 1):
    for j in range(1, m + 1):
        if a[i] == b[j]:
            f[i][j] = max(f[i - 1][j], f[i][j - 1], f[i - 1][j - 1] + 1)
        else:
            f[i][j] = max(f[i - 1][j], f[i][j - 1])
print(f[n][m])

902. 最短编辑距离

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
n = int(input())
a = ' ' + input()
m = int(input())
b = ' ' + input()
# f[i][j]: minimum edits turning a[1..i] into b[1..j].
f = [[0] * (m + 1) for _ in range(n + 1)]
for i in range(n + 1):
    f[i][0] = i
for j in range(m + 1):
    f[0][j] = j
for i in range(1, n + 1):
    for j in range(1, m + 1):
        replace_cost = 0 if a[i] == b[j] else 1
        f[i][j] = min(f[i - 1][j] + 1, f[i][j - 1] + 1, f[i - 1][j - 1] + replace_cost)
print(f[n][m])

899. 编辑距离

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
n, m = map(int, input().split())
N = 11
f = [[0] * N for _ in range(N)]  # shared DP table, reused by every distance() call
a = [' ' + input() for _ in range(n)]


def distance(src, dst):
    """Edit distance between src[1:] and dst[1:]; both carry a one-char ' ' pad."""
    rows, cols = len(src), len(dst)
    for i in range(1, rows):
        f[i][0] = i
    for j in range(1, cols):
        f[0][j] = j
    for i in range(1, rows):
        for j in range(1, cols):
            cost = 0 if src[i] == dst[j] else 1
            f[i][j] = min(f[i - 1][j] + 1, f[i][j - 1] + 1, f[i - 1][j - 1] + cost)
    return f[rows - 1][cols - 1]


for _ in range(m):
    word, limit = input().split()
    word, limit = ' ' + word, int(limit)
    print(sum(1 for s in a if distance(s, word) <= limit))

896. 最长上升子序列 II

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
n = int(input())
a = list(map(int, input().split()))
# q[k]: smallest tail value of an increasing subsequence of length k (k >= 1).
q = [0] * (n + 1)
length = 0
for i in range(n):
    x = a[i]
    lo, hi = 0, length
    # Largest k in [0, length] with q[k] < x; q[0] is never compared (mid >= 1).
    while lo < hi:
        mid = (lo + hi + 1) // 2
        if q[mid] < x:
            lo = mid
        else:
            hi = mid - 1
    q[hi + 1] = x
    length = max(length, hi + 1)
print(length)

区间DP

282. 石子合并

1
2
3
4
5
6
7
8
9
10
11
12
n = int(input())
s = [0] + list(map(int, input().split()))
for i in range(1, n + 1):
    s[i] += s[i - 1]  # prefix sums: s[r] - s[l - 1] is the weight of piles l..r
f = [[0] * (n + 1) for _ in range(n + 1)]
for length in range(2, n + 1):
    for left in range(1, n - length + 2):
        right = left + length - 1
        merge_cost = s[right] - s[left - 1]
        # Split point k: merge [left, k] and [k + 1, right] first, then join them.
        f[left][right] = min(f[left][k] + f[k + 1][right] for k in range(left, right)) + merge_cost
print(f[1][n])

计数类DP

900. 整数划分

1
2
3
4
5
6
7
8
n = int(input())
MOD = int(1e9 + 7)
# Counting version of the complete knapsack: parts 1..n, each usable any number of times.
ways = [0] * (n + 1)
ways[0] = 1
for part in range(1, n + 1):
    for total in range(part, n + 1):
        ways[total] = (ways[total] + ways[total - part]) % MOD
print(ways[n])

数位统计DP

338. 计数问题

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
def count(n, x):
    """Number of times digit x (0-9) appears when writing 1, 2, ..., n."""
    width, rest = 0, n
    while rest:  # number of decimal digits of n
        width += 1
        rest //= 10
    total = 0
    for pos in range(width):
        unit = 10 ** pos  # place value of the digit being counted
        high = n // (unit * 10)  # value formed by the digits to its left
        # Prefixes 0..high-1 each contribute `unit` hits; for x == 0 the prefix
        # must be at least 1, otherwise the 0 would be a leading zero.
        total += (high if x else high - 1) * unit
        digit = n // unit % 10
        if digit == x:
            total += n % unit + 1  # prefix == high: suffixes 0..n % unit only
        elif digit > x:
            total += unit  # prefix == high: every suffix is allowed
    return total


while True:
    a, b = map(int, input().split())
    if a == 0 and b == 0:
        break
    lo, hi = min(a, b), max(a, b)
    print(' '.join(str(count(hi, d) - count(lo - 1, d)) for d in range(10)), end=' \n')

状态压缩DP

291. 蒙德里安的梦想

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
def _column_fillable(mask, n):
    """True if the cells of an n-cell column NOT set in `mask` can be filled by vertical dominoes.

    That holds iff every maximal run of 0 bits has even length.
    """
    run = 0
    for bit in range(n):
        if mask >> bit & 1:
            if run & 1:
                return False
            run = 0
        else:
            run += 1
    return not run & 1


def fun(n, m):
    """Count the ways to tile an n-row, m-column board with 1x2 dominoes.

    Columns are filled left to right. A state is the bitmask of rows in the
    current column covered by horizontal dominoes that stick in from the
    previous column. f[j] is the number of ways to reach state j.
    """
    full = 1 << n
    fillable = [_column_fillable(mask, n) for mask in range(full)]
    # Compatible predecessor states do not depend on the column, so compute
    # them once instead of re-testing every (j, k) pair for every column.
    sources = [[k for k in range(full) if not j & k and fillable[j | k]]
               for j in range(full)]
    f = [0] * full
    f[0] = 1  # nothing sticks into the first column
    for _ in range(m):
        f = [sum(f[k] for k in sources[j]) for j in range(full)]
    return f[0]  # nothing may stick out past the last column


def main():
    while True:
        a, b = map(int, input().split())
        if not a and not b:
            break
        print(fun(a, b))


if __name__ == '__main__':
    main()

91. 最短Hamilton路径

1
2
3
4
5
6
7
8
9
10
11
def shortest_hamilton(g, n):
    """Length of the shortest path from vertex 0 to vertex n-1 that visits every vertex exactly once.

    g is an n x n matrix of edge weights. f[state][j] is the shortest path
    that starts at 0, visits exactly the vertices in bitmask `state`, and
    currently ends at j.
    """
    INF = float('inf')
    f = [[INF] * n for _ in range(1 << n)]
    f[1][0] = 0
    for state in range(1, 1 << n):
        if not state & 1:
            continue  # every valid path contains the start vertex 0
        for j in range(n):
            if not state >> j & 1:
                continue
            prev = state ^ (1 << j)  # the path before stepping onto j
            for k in range(n):
                if prev >> k & 1:
                    f[state][j] = min(f[state][j], f[prev][k] + g[k][j])
    # All vertices visited, ending at n - 1 (explicit index, not f[-1]).
    return f[(1 << n) - 1][n - 1]


def main():
    n = int(input())
    g = [list(map(int, input().split())) for _ in range(n)]
    print(shortest_hamilton(g, n))


if __name__ == '__main__':
    main()

树形DP

285. 没有上司的舞会

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
def max_happiness(n, happy, edges):
    """Maximum total happiness of a party where no one attends together with their direct boss.

    happy[1..n] are the happiness values. Each (a, b) in `edges` means b is
    the direct boss of a. The root is the first employee without a boss.
    The tree is processed iteratively: a recursive DFS overflows Python's
    recursion limit on chain-shaped trees.
    """
    children = [[] for _ in range(n + 1)]
    has_boss = [False] * (n + 1)
    for a, b in edges:
        has_boss[a] = True
        children[b].append(a)
    root = 1
    while has_boss[root]:
        root += 1
    # Pre-order via an explicit stack; reversed, it lists children before parents.
    order, stack = [], [root]
    while stack:
        u = stack.pop()
        order.append(u)
        stack.extend(children[u])
    skip = [0] * (n + 1)  # best total in u's subtree when u stays home
    take = [0] * (n + 1)  # best total in u's subtree when u attends
    for u in reversed(order):
        take[u] = happy[u]
        for v in children[u]:
            skip[u] += max(take[v], skip[v])
            take[u] += skip[v]
    return max(skip[root], take[root])


def main():
    n = int(input())
    happy = [0] + [int(input()) for _ in range(n)]
    edges = [tuple(map(int, input().split())) for _ in range(n - 1)]
    print(max_happiness(n, happy, edges))


if __name__ == '__main__':
    main()

记忆化搜索

901. 滑雪

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
DIRECTIONS = [(0, 1), (1, 0), (0, -1), (-1, 0)]


def longest_ski(g, n, m):
    """Length (in cells) of the longest strictly descending path in the n x m height grid g.

    Same recurrence as the memoised search: f[x][y] = 1 + max f over lower
    neighbours. Cells are evaluated in ascending height order, so every
    strictly lower neighbour is already final. This avoids recursion, which
    can be up to n*m deep. (The recursive version also computed the
    neighbour column as `b + r` instead of `y + r`.)
    """
    f = [[1] * m for _ in range(n)]
    best = 0
    for h, x, y in sorted((g[x][y], x, y) for x in range(n) for y in range(m)):
        for dx, dy in DIRECTIONS:
            a, b = x + dx, y + dy
            if 0 <= a < n and 0 <= b < m and g[a][b] < h:
                f[x][y] = max(f[x][y], f[a][b] + 1)
        best = max(best, f[x][y])
    return best


def main():
    n, m = map(int, input().split())
    g = [list(map(int, input().split())) for _ in range(n)]
    print(longest_ski(g, n, m))


if __name__ == '__main__':
    main()

贪心

区间问题

905. 区间选点

1
2
3
4
5
6
7
8
9
def min_points(intervals):
    """Minimum number of points such that every closed interval [a, b] contains one.

    Greedy: scan intervals by increasing right end and place a point at the
    right end of each interval that does not contain the last point.
    """
    res, end = 0, float('-inf')
    # In Python 3 sort's `key` must be passed by keyword; positional raises TypeError.
    for a, b in sorted(intervals, key=lambda iv: iv[1]):
        if a > end:
            res += 1
            end = b
    return res


def main():
    n = int(input())
    g = [list(map(int, input().split())) for _ in range(n)]
    print(min_points(g))


if __name__ == '__main__':
    main()

908. 最大不相交区间数量

1
2
3
4
5
6
7
8
9
def max_disjoint(intervals):
    """Maximum number of pairwise non-overlapping closed intervals.

    Greedy: scan by increasing right end and keep an interval whenever it
    starts after the right end of the last kept one.
    """
    res, end = 0, float('-inf')
    # In Python 3 sort's `key` must be passed by keyword; positional raises TypeError.
    for a, b in sorted(intervals, key=lambda iv: iv[1]):
        if a > end:
            res += 1
            end = b
    return res


def main():
    n = int(input())
    g = [list(map(int, input().split())) for _ in range(n)]
    print(max_disjoint(g))


if __name__ == '__main__':
    main()

906. 区间分组

1
2
3
4
5
6
7
8
9
10
import heapq

n = int(input())
g = sorted(list(map(int, input().split())) for _ in range(n))
ends = []  # min-heap: right end of the last interval in each group
for a, b in g:
    if ends and ends[0] < a:
        # Reuse the group that frees up earliest.
        heapq.heapreplace(ends, b)
    else:
        heapq.heappush(ends, b)
print(len(ends))

907. 区间覆盖

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
def cover(start, end, segs):
    """Minimum number of sorted intervals needed to cover [start, end], or -1 if impossible."""
    used = i = 0
    while i < len(segs):
        # Among intervals beginning at or before `start`, take the one reaching furthest.
        reach = float('-inf')
        while i < len(segs) and segs[i][0] <= start:
            reach = max(reach, segs[i][1])
            i += 1
        if reach < start:
            return -1
        start = reach
        used += 1
        if reach >= end:
            return used
    return -1


s, t = map(int, input().split())
n = int(input())
g = sorted(list(map(int, input().split())) for _ in range(n))
print(cover(s, t, g))

Huffman树

148. 合并果子

1
2
3
4
5
6
7
8
9
10
import heapq

n = int(input())
heap = list(map(int, input().split()))
heapq.heapify(heap)
cost = 0
# Huffman: always merge the two lightest heaps first.
while len(heap) > 1:
    merged = heapq.heappop(heap) + heapq.heappop(heap)
    cost += merged
    heapq.heappush(heap, merged)
print(cost)

排序不等式

913. 排队打水

1
2
3
4
5
6
7
n = int(input())
times = sorted(map(int, input().split()))
# Shortest first: the i-th person (0-based) makes the n - i - 1 people after them wait.
print(sum(t * (n - i - 1) for i, t in enumerate(times)))

绝对值不等式

104. 货仓选址

1
2
3
4
5
6
7
n = int(input())
xs = sorted(map(int, input().split()))
# The median minimises the sum of absolute distances.
print(sum(abs(x - xs[n // 2]) for x in xs))

推公式

125. 耍杂技的牛

1
2
3
4
5
6
7
8
def min_max_risk(cows):
    """Smallest possible maximum risk when stacking cows given as (weight, strength) pairs.

    A cow's risk is the total weight above it minus its strength. Sorting by
    weight + strength ascending (top to bottom) is optimal.
    """
    res, above = float('-inf'), 0
    # In Python 3 sort's `key` must be passed by keyword; positional raises TypeError.
    for w, s in sorted(cows, key=lambda c: c[0] + c[1]):
        res = max(res, above - s)
        above += w
    return res


def main():
    n = int(input())
    g = [list(map(int, input().split())) for _ in range(n)]
    print(min_max_risk(g))


if __name__ == '__main__':
    main()

Python注意

容易爆栈

1
2
# Raise the interpreter's recursion-depth limit for deeply recursive solutions.
# NOTE: even with a higher limit, very deep recursion can still exhaust the
# underlying C stack; prefer an iterative rewrite where possible.
import sys 
sys.setrecursionlimit(100000)

Python 并不适合深度递归:解释器本身限制了递归深度,即使调高限制,深递归也会占用大量栈空间,甚至导致程序崩溃;能改写成循环或显式栈时尽量改写。

交换str两个字符的位置

1
2
3
def swap(s, idx1, idx2):
    """Return a copy of string s with the characters at idx1 and idx2 exchanged.

    Strings are immutable, so the swap goes through a list of characters.
    This also handles idx1 == idx2 (the slicing version duplicated the
    character) and accepts negative indices like normal indexing.
    """
    chars = list(s)
    chars[idx1], chars[idx2] = chars[idx2], chars[idx1]
    return ''.join(chars)

为函数增加记忆化(memoization)

1
2
3
import functools


# lru_cache automatically memoises a function's results, which is very handy
# for recursive algorithms. All arguments must be hashable. maxsize=None keeps
# every result; the default maxsize of 128 evicts old entries.
@functools.lru_cache(maxsize=None)
def fib(n):
    """Example usage: memoised Fibonacci, linear instead of exponential time."""
    return n if n < 2 else fib(n - 1) + fib(n - 2)

科学计数法字面量(如 1e9)是浮点数,当作整数使用时要用 int() 转换

1
2
# Scientific-notation literals such as 1e9 are floats by default, so convert
# to int before using the value as an integer modulus.
MOD = int(1e9 + 7)

取模%运算

c++中

1
2
3
4
// C++ (C++11 and later) integer division truncates toward zero, so % takes the sign of the dividend.
cout<< 7 % 4 << endl;   // 3
cout<< -7 % 4 << endl; // -3
cout<< 7 % -4 << endl; // 3
cout<< -7 % -4 << endl; // -3

python中

1
2
3
4
# Python's // floors toward negative infinity, so % takes the sign of the divisor.
# Comments use '#': '//' would be parsed as floor division (None // 3 -> TypeError).
print(7 % 4)    # 3
print(-7 % 4)   # 1
print(7 % -4)   # -1
print(-7 % -4)  # -3

C 语言和 Python 在涉及有负数取余运算时,结果可能不同的本质原因是:C 语言中是向0取整,而 Python 是向负无穷取整

输入

1
2
3
from sys import stdin


def input():
    """Replacement for the builtin input(): one line from sys.stdin with surrounding whitespace stripped."""
    return stdin.readline().strip()


n, m = map(int, input().split())

常用函数

1
2
3
import math
math.factorial(x)  # x! for a non-negative integer x (x, a, b are placeholders for your own values)
math.gcd(a, b)     # greatest common divisor of integers a and b

二分

二分模板:若区间更新为 r = mid 与 l = mid + 1,取 mid = l + r >> 1;若更新为 l = mid 与 r = mid - 1,则必须取 mid = l + r + 1 >> 1,否则当 l = r - 1 时会死循环