从矩阵乘法代码看编译器的自动优化

在高性能计算中,矩阵乘法是一个非常基础且重要的操作。一个看似简单的实现,背后却可能蕴含着编译器大量的分析和优化工作。本文将以一个基础的矩阵乘法C语言代码为例,探讨编译器是如何通过SCEV (Scalar Evolution) 表达式来分析循环中的内存访问模式,并在此基础上执行循环展开 (Loop Unrolling) 等优化,从而大幅提升程序性能。

1. 矩阵乘法C代码·

首先,我们来看一个标准的 N × N 矩阵乘法实现 C = A * B。为了简化,我们使用一维数组来模拟二维矩阵,其中元素 M[i][j] 通过 M[i * N + j] 来访问。

1
2
3
4
5
6
7
8
9
10
11
12
void MatrixMul(unsigned int N, int *C, int *A, int *B) {
unsigned int i, j, k;

for (i = 0; i < N; i++) {
for (j = 0; j < N; j++) {
C[i * N + j] = 0;
for (k = 0; k < N; k++) {
C[i * N + j] += A[i * N + k] * B[k * N + j];
}
}
}
}

这段代码通过三层嵌套循环,精确地实现了矩阵乘法的数学定义。虽然功能正确,但其性能在不经过优化的情况下通常不理想。接下来,我们将探讨编译器如何“读懂”并优化这段代码。

2. 编译器如何理解循环:SCEV表达式分析·

为了进行有效的优化,编译器首先需要理解循环中变量和内存地址是如何变化的。LLVM等现代编译器使用一种称为SCEV (Scalar Evolution) 的技术来分析循环中的标量值的演变规律。

SCEV基本概念·

SCEV将循环中变量的变化规律表示为一个“Add Recurrence” (AddRec) 表达式,通常记为 {start, +, step}

  • start: 循环第一次迭代时变量的初始值。
  • +: 表示这是一个加法(或减法)递推。
  • step: 每次循环迭代时变量增加(或减少)的步长。

例如,一个从 0 到 N-1 的循环变量 i,其SCEV表达式就是 {0, +, 1}

SCEV的基本运算规则·

SCEV表达式支持基本的算术运算,例如:

  • 加法: {S1, +, T1} + {S2, +, T2} = {S1 + S2, +, T1 + T2}
  • 与常量相加: C + {S, +, T} = {C + S, +, T}
  • 与常量相乘: C * {S, +, T} = {C * S, +, C * T}

分析 MatrixMul 代码中的SCEV传播·

编译器会从外到内,逐层分析每个循环,并将上一层循环的变量在本层循环中视为不变量(Constant)。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
void MatrixMul(unsigned int N, int *C, int *A, int *B) {
unsigned int i, j, k;

// 最外层循环 (i)
for (i = 0; i < N; i++) {
// 在此循环中,i 的SCEV为 {0, +, 1}。
// N 是一个不变量。
// 因此,表达式 i * N 的SCEV为 {0 * N, +, 1 * N} => {0, +, N}。

// 次内层循环 (j)
for (j = 0; j < N; j++) {
// 进入此循环时,外层的 i 是一个固定值,所以 i * N 在此循环中是
// 一个循环不变量 (Loop Invariant),我们称之为 C0。
// j 的SCEV为 {0, +, 1}。
//
// 访问 C[i * N + j] 的地址偏移量 i * N + j 的SCEV为:
// C0 + {0, +, 1} => {C0, +, 1} 或 {i * N, +, 1}。
// 这表示地址偏移量是一个等差数列,每次迭代增加1。
C[i * N + j] = 0;

// 最内层循环 (k)
for (k = 0; k < N; k++) {
// 进入此循环时,i 和 j 都是固定值。因此:
// 1. C[i * N + j] 的地址是一个不变量,我们称之为 C1。
// 这使得累加操作可以直接在某个寄存器中进行。
// 2. A[i * N + k] 的地址偏移量 i * N + k 的SCEV:
// i * N 是不变量 C0,k 的SCEV是 {0, +, 1}。
// 所以,其SCEV为 C0 + {0, +, 1} => {i * N, +, 1}。
// 表示每次访问 A 的地址都线性增加1个元素大小。
// 3. B[k * N + j] 的地址偏移量 k * N + j 的SCEV:
// j 是不变量 C2,k 的SCEV是 {0, +, 1}。
// k * N 的SCEV是 {0, +, N}。
// 所以,其SCEV为 {0, +, N} + C2 => {j, +, N}。
// 表示每次访问 B 的地址都线性增加 N 个元素大小。
C[i * N + j] += A[i * N + k] * B[k * N + j];
}
}
}
}

通过SCEV分析,编译器得出结论:在最内层的核心计算循环中,内存访问模式是极其规律的。

  • A 的访问是步长为 1 的连续访问 (stride=1)。
  • B 的访问是步长为 N 的等距访问 (stride=N)。
  • C 的访问地址是不变的。

3. 循环展开 (Loop Unrolling)·

基本概念·

循环展开是一种通过复制循环体来减少循环总迭代次数的技术。这样做可以带来几个核心好处:

  1. 减少循环开销:每次迭代都需要执行循环变量的递增和条件判断。展开后,这些指令的执行频率降低了。
  2. 增加指令级并行 (ILP):现代CPU可以同时执行多条指令(超标量、流水线)。循环展开将多个独立的计算操作暴露给CPU,使其能够更好地并行处理,隐藏延迟。
  3. 为其他优化创造机会:展开后的代码更有利于指令调度和向量化(SIMD)等优化。

MatrixMul 内层循环的展开示例·

编译器会尝试展开最内层的 k 循环。假设展开因子 (unroll-count) 为 4:

原始循环:

1
2
3
for (k = 0; k < N; k++) {
C[i * N + j] += A[i * N + k] * B[k * N + j];
}

展开后的伪代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
// 处理循环次数不是4的倍数的剩余部分
int remainder = N % 4;
int k = 0;
for (; k < remainder; k++) {
C[i * N + j] += A[i * N + k] * B[k * N + j];
}

// 主循环,每次处理4个元素
for (; k < N; k += 4) {
C[i * N + j] += A[i * N + k] * B[k * N + j];
C[i * N + j] += A[i * N + (k + 1)] * B[(k + 1) * N + j];
C[i * N + j] += A[i * N + (k + 2)] * B[(k + 2) * N + j];
C[i * N + j] += A[i * N + (k + 3)] * B[(k + 3) * N + j];
}

在展开后的主循环中,一次迭代就完成了四次原始迭代的工作。这不仅将循环判断和跳转指令的数量减少到原来的1/4,还使得CPU可以并行加载 AB 的数据,并同时执行多个乘加运算,从而极大地提高了计算效率。

4. 汇编代码分析·

以下是 MatrixMul 函数在开启优化后由 GCC 编译生成的 RISC-V 完整汇编代码。我们将分段剖析它。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
MatrixMul:
beq a0,zero,.L42
addi sp,sp,-32
slli a6,a0,2 # a6 = 4 * N
sw s0,28(sp)
sw s1,24(sp)
sw s2,20(sp)
sw s3,16(sp)
sw s4,12(sp)
sw s5,8(sp)
mv t3,a2 # t3 = &A
mv t5,a1 # t5 = &C
add t1,a2,a6 # t1 = &A + 4 * N = &A[N]
li t6,0 # t6 = i = 0
.L3:
mv a7,a3 # a7 = &B
mv a1,t5 # a1 = &C
li t4,0 # t4 = j = 0
.L5:
sub a5,t1,t3 # a5 = 4 * N
addi t0,a5,-4 # t0 = 4 * N - 4
srli t2,t0,2 # t2 = N - 1
addi s0,t2,1 # s0 = N
sw zero,0(a1) # C[0] = 0
andi s1,s0,7 # s1 = N mod 8 (loop unroll reminder)
mv t0,a7 # t0 = &B
mv a4,t3 # t4 = &A
li a2,0
beq s1,zero,.L4 # reminder == 0
li s2,1
beq s1,s2,.L31 # reminder == 1
li s3,2
beq s1,s3,.L32 # reminder == 2
li s4,3
beq s1,s4,.L33 # reminder == 3
li s5,4
beq s1,s5,.L34 # reminder == 4
li a5,5
beq s1,a5,.L35 # reminder == 5
li t2,6
beq s1,t2,.L36 # reminder == 6
# reminder == 7
lw a2,0(t3) # process 7 times
lw s0,0(a7)
addi a4,t3,4
add t0,a7,a6
mul a2,a2,s0
sw a2,0(a1)
.L36:
lw s1,0(a4) # process 6 times
lw s2,0(t0)
addi a4,a4,4
add t0,t0,a6
mul s3,s1,s2
add a2,a2,s3
sw a2,0(a1)
.L35:
lw s4,0(a4) # process 5 times
lw s5,0(t0)
addi a4,a4,4
add t0,t0,a6
mul a5,s4,s5
add a2,a2,a5
sw a2,0(a1)
.L34:
lw s0,0(a4) # process 4 times
lw t2,0(t0)
addi a4,a4,4
add t0,t0,a6
mul s1,s0,t2
add a2,a2,s1
sw a2,0(a1)
.L33:
lw s2,0(a4) # process 3 times
lw s3,0(t0)
addi a4,a4,4
add t0,t0,a6
mul s4,s2,s3
add a2,a2,s4
sw a2,0(a1)
.L32:
lw s5,0(a4) # process 2 times
lw a5,0(t0)
addi a4,a4,4
add t0,t0,a6
mul s0,s5,a5
add a2,a2,s0
sw a2,0(a1)
.L31:
lw s1,0(a4) # process 1 times
lw t2,0(t0)
addi a4,a4,4
add t0,t0,a6
mul s2,s1,t2
add a2,a2,s2
sw a2,0(a1)
beq a4,t1,.L40 # if N <= 7, innermost loop end.
.L4:
# loop unroll body (unroll count = 8)
lw s3,0(a4) # s3 = *A_base
lw a5,0(t0) # a5 = *B_base
add s5,t0,a6 # s5 = t0 + 4 * N
add s4,s5,a6 # s4 = s4 + 4 * N
mul a5,s3,a5 # a5 = *A_base * *B_base
add s3,s4,a6 # ...
add s2,s3,a6
add s1,s2,a6
add s0,s1,a6
add t2,s0,a6
add t0,t2,a6
addi a4,a4,32
add a5,a2,a5
sw a5,0(a1)
lw s5,0(s5)
lw a2,-28(a4)
mul a2,a2,s5
add a5,a5,a2
sw a5,0(a1)
lw s5,-24(a4)
lw s4,0(s4)
mul a2,s5,s4
add a5,a5,a2
sw a5,0(a1)
lw s5,-20(a4)
lw s3,0(s3)
mul s4,s5,s3
add a5,a5,s4
sw a5,0(a1)
lw a2,-16(a4)
lw s2,0(s2)
mul s5,a2,s2
add s3,a5,s5
sw s3,0(a1)
lw s4,-12(a4)
lw s1,0(s1)
mul a5,s4,s1
add s2,s3,a5
sw s2,0(a1)
lw a2,-8(a4)
lw s0,0(s0)
mul s5,a2,s0
add s3,s2,s5
sw s3,0(a1)
lw s4,-4(a4)
lw t2,0(t2)
mul s1,s4,t2
add a2,s3,s1
sw a2,0(a1)
bne a4,t1,.L4 # if a4 != &A_base + 4 * N
.L40:
addi a4,t4,1 # j++
addi a1,a1,4 # a1 = &C_base + 4
addi a7,a7,4 # a7 = &B_base + 4
beq a0,a4,.L46 # if j == N, second loop end.
mv t4,a4
j .L5 # contine the second loop.
.L46:
add t3,t3,a6 # t3 = &A_base + 4 * N
add t1,t1,a6 # t1 = &A_base[N] + 4 * N
add t5,t5,a6 # t5 = &C_base + 4 * N
beq t6,t4,.L1 # if i == N, outer loop end.
addi t6,t6,1
j .L3 # contine the outer loop.
.L1:
lw s0,28(sp)
lw s1,24(sp)
lw s2,20(sp)
lw s3,16(sp)
lw s4,12(sp)
lw s5,8(sp)
addi sp,sp,32
jr ra
.L42:
ret
.size MatrixMul, .-MatrixMul
.ident "GCC: (ge53277d849a) 15.0.0 20250107 (experimental)"
.section .note.GNU-stack,"",@progbits

三层循环初始化及步长变化 : SCEV 的应用·

SCEV 的核心作用是把循环中复杂的变址计算(如 i * N)转变为简单的基地址加步长的递推关系。这一点在外两层循环的指针维护做的比较多。

外层 i 循环: SCEV 分析得出,在 i 循环中,AC 矩阵的行首地址 &A[i*N]&C[i*N] 都是以 N * sizeof(int) 为步长递增的,其 SCEV 表达式为 {base, +, N*4}

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
	  slli    a6,a0,2           # a6 = N * 4 (预计算步长)
mv t3,a2 # t3 = &A (A的基地址)
mv t5,a1 # t5 = &C (C的基地址)
li t6,0 # t6 = i = 0
# i-loop start
.L3:
# ... j-loop 和 k-loop 代码省略 ...

# i-loop 的结尾,准备下一次迭代
.L46:
add t3,t3,a6 # t3 = t3 + N*4 => 更新A的行指针
add t5,t5,a6 # t5 = t5 + N*4 => 更新C的行指针
addi t6,t6,1 # i++
# ...
j .L3 # 跳转到下一次 i 循环
  • t3t5 寄存器分别保存了 AC 当前处理行的基地址。

  • 在每次 i 循环结束时(标签 .L46 处),代码执行 add t3, t3, a6add t5, t5, a6。这里的 a6 就是预先计算好的步长 N*4

  • 这完美地印证了 SCEV 的分析:编译器放弃了在每次循环中重新计算 i * N 的笨办法,而是通过在前一次迭代的基地址上直接增加一个固定的步长,来获得当前行的地址。

次内层 j 循环: 同样的原理也应用在 j 循环中。对于 C[i*N+j],其地址在 j 循环内的 SCEV 是 {&C[i*N], +, 4}。对于 B 矩阵 &B[k*N+j] ,其地址偏移也以会随 j 的递增 4 字节为单位递增。

1
2
3
4
5
6
7
# j-loop 的结尾,准备下一次迭代
.L40:
addi a1,a1,4 # a1 = &C[i*N] + 4 => 更新C的元素指针
addi a7,a7,4 # a7 = &B[k*N] + 4 => 更新B的列指针
addi a4,t4,1 # j++
# ...
j .L5 # 跳转到下一次 j 循环

最内层循环 (k-loop)·

余数处理·

这部分是优化的核心。GCC 决定将最内层的 k-loop 展开 8 次。为了处理 N 不是 8 的倍数的情况,它采用了一种非常高效的余数处理策略。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
	sw  zero,0(a1)         		# C[i*N+j] = 0
andi s1,s0,7 # s1 = N % 8 (计算余数)
...
beq s1,zero,.L4 # 如果余数是0,直接跳到主展开循环

# ... 这里是一系列 beq 指令 ...
# 这构成了一个跳转表,根据余数的值(s1)直接跳到对应的处理代码
# 比如,如果余数是3,就跳到 .L33
.L33: # 处理余数=3的情况
lw s2,0(a4)
lw s3,0(t0)
addi a4,a4,4
add t0,t0,a6
mul s4,s2,s3
add a2,a2,s4
sw a2,0(a1)
.L32: # 处理余数=2的情况 (从 .L33 fall-through)
lw s5,0(a4)
lw a5,0(t0)
addi a4,a4,4
add t0,t0,a6
mul s0,s5,a5
add a2,a2,s0
sw a2,0(a1)
.L31: # 处理余数=1的情况 (从 .L32 fall-through)
lw s1,0(a4)
lw t2,0(t0)
addi a4,a4,4
add t0,t0,a6
mul s2,s1,t2
add a2,a2,s2
sw a2,0(a1)
# 余数处理完毕
beq a4,t1,.L40 # 如果 N <= 7,内层循环已完成,直接跳到末尾,结束最内循环

GCC 没有使用一个单独的循环来处理余数,而是生成了一段精巧的 “fall-through” 代码。如果 N % 8 = 3,程序会跳转到 .L33,执行一次乘加,然后顺序执行 .L32.L31 的代码,恰好完成3次迭代。这种方式避免了额外的循环判断开销。

展开的主体循环·

如果 N 大于 7,在处理完余数后,代码会进入 .L4 标签后的主循环。这里我们可以清晰地看到循环体被复制了8次。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
.L4:                                # 主循环体,一次处理8个元素
# ---- 迭代 1 ----
lw s3,0(a4)
lw a5,0(t0)
mul a5,s3,a5
# ...
# ---- 迭代 2 ----
lw s5,0(s5)
lw a2,-28(a4)
mul a2,a2,s5
# ...
# ---- 迭代 3 ----
lw s5,-24(a4)
lw s4,0(s4)
mul a2,s5,s4
# ...
# ---- (迭代 4, 5, 6, 7, 8 省略) ----

addi a4,a4,32 # A 的指针一次性前进 8 * 4 = 32 字节
bne a4,t1,.L4 # 循环条件判断,继续下一批8个元素

5. 反思·

本文在四个月前就想要写了,当时想写的内容远比现在呈现的多得多,为此付出的前期工作也远比文中提到的要多。

记得起初,我为这个主题编写了专门的测试 driver 程序,在玄铁的 Linux K230 开发板上,详细对比了 GCC 和 Clang 编译 matmul 后的性能与汇编代码差距。当时记录下了一些中间文件和心得,原本打算将这些内容完善成一个系列文章,完整地介绍从优化跑分到优化分析的一系列过程。

所以说,有什么想干的事,真的要尽早地去做,不要拖延,“72小时黄金时间”是对的。

在这里,我还是粗略地记录一下当时的一些笔记吧,也许后面会把这个“坑”填上。

  • 测试程序的输入矩阵不能过大。否则,程序的性能瓶颈将由访存效率决定,代码本身的优化将难以体现出优势。
  • 矩阵大小为4或8的倍数时,效率通常会更好。从多组控制矩阵大小的数据来看,存在这个趋势,猜测是因为 Cache Line 的原因。
  • 对于一个调度已经很优秀的基本块,减少其中一两条指令,对其性能优化影响不大
  • GCC 和 Clang 都没有将代码优化成如下“累加器”形式
1
2
3
4
5
int acc = 0;
for (k = 0; k < N; k++) {
acc += A[i * N + k] * B[k * N + j];
}
C[i * N + j] = acc;

这个优化的主要障碍是编译器的别名分析 (Alias Analysis) 不成功。编译器无法确保写入地址 C 和读取地址 AB 指向的内存区域一定不重叠。如果将 AB 矩阵和 C 矩阵的类型分开(例如 AB 使用 short 类型,C 使用 int 类型),那么 TBAA (Type-Based Alias Analysis) 将可以分析出它们一定不是别名,从而让编译器放心地进行此项优化。