sync.WaitGroup 的作用就是让主函数等待所有 goroutine 都执行完毕,再退出。
一个最简单的例子如下,如果没有 wg,那么 main 会在 goroutine 执行之前就退出,从而不会看到任何 output。
func main() {
wg := sync.WaitGroup{}
for i := 0; i < 3; i++ {
wg.Add(1)
go func(i int) {
fmt.Println(i)
wg.Done()
}(i)
}
wg.Wait()
}
那么 WaitGroup 是如何实现的呢?
万变不离其宗,其底层还是基于 go runtime 提供的信号量机制,也就是 runtime_Semrelease()
和 runtime_Semacquire()
,
在之前的文章 Golang RWMutex 的实现 和 netpoll 的实现 中都有它们的影子存在。
runtime_Semacquire(s *uint32)
此函数会阻塞直到信号量*s的值大于0,原子减这个值。
runtime_Semrelease(s *uint32, lifo bool, skipframes int)
此函数执行原子增信号量的值,然后通知被runtime_Semacquire阻塞的协程
说到底,就是用 信号量 和 gopark 来控制 goroutine 是运行还是挂起,wg.Add()
对应信号量的增减,wg.Wait()
对应线程/协程的挂起。
WaitGroup 的源码位于 src/sync/waitgroup.go
文件中,一共才 100 多行,下面就逐一分析下。
WaitGroup
type WaitGroup struct {
noCopy noCopy
state1 [3]uint32
}
其中 noCopy 是干啥的呢? 搜索一番以后发现了这样的解释:
noCopy 用于 go vet 检查 sync.WaitGroup 类型变量是否采用了值传递的方式
如果采用了值传递,go vet 检查会抛出错误:call of foo copies lock value: sync.WaitGroup contains sync.noCopy, 因为如果采用值传递,那么 state1 就会被复制一份,而对应的信号量并不会跟着复制,所以值传递后复制出来的是一个不可用的 WaitGroup
state1 一共占用 3 * sizeof(uint32) = 12
字节,为什么要写成这样呢? 因为为了省字节而舍弃了可读性。 从结构体原本的实现能更容易看出这个结构体的意图。
+------------+-------------+--------------+
| counter | waiter | sem |
| | | |
+------------+-------------+--------------+
分别代表了
- 第一个 4 字节的计数器用于 wg.Add() 和 wg.Done() 的计数
- 第二个 4 字节的计数器用于 wg.Wait() 调用者的计数
- 信号量
为什么要用 counter 和 waiter 两个计数器呢? 第一个很好理解,因为不止一个 goroutine 调用 Add 和 Done。
同样的道理,也可以有多个 goroutine 调用 wg.Wait() 等待,之所以没有理解,是因为我使用过的 wg 只有一个 wg.Wait() 调用者,没有考虑过多个的情景。
state 函数是用来获取 [3]uint32
中存的各个数据,
func (wg *WaitGroup) state() (statep *uint64, semap *uint32) {
// 判断地址是否8字节对齐
if uintptr(unsafe.Pointer(&wg.state1))%8 == 0 {
return (*uint64)(unsafe.Pointer(&wg.state1)), &wg.state1[2]
} else {
return (*uint64)(unsafe.Pointer(&wg.state1[1])), &wg.state1[0]
}
}
为什么其中需要对 8 求模呢? 因为 Golang 在 amd64 的机器上,64位原子操作需要满足64位对齐(也就是8字节对齐),32位编译器不能保证这点。
wg.Add()
有了前面的基础以后,看 Add 和 Wait 的实现就简单多,直接上关键部分的代码:
func (wg *WaitGroup) Add(delta int) {
statep, semap := wg.state()
// 将 delta 加到 counter 计数器上
state := atomic.AddUint64(statep, uint64(delta)<<32)
v := int32(state >> 32) // v 是调用 wg.Add() 的计数器
w := uint32(state) // w 是调用 wg.Wait() 的计数器
if v < 0 {
panic("sync: negative WaitGroup counter")
}
// 如果 w 不等于0,表明已经有 Wait 调用在等待,此时,再调Add会报错
if w != 0 && delta > 0 && v == int32(delta) {
panic("sync: WaitGroup misuse: Add called concurrently with Wait")
}
if v > 0 || w == 0 {
return // 正常情况下函数从这里返回
}
// 如果执行到这里,是最后一个 wg.Done() 执行
if *statep != state {
panic("sync: WaitGroup misuse: Add called concurrently with Wait")
}
// Reset waiters count to 0.
*statep = 0
for ; w != 0; w-- {
// 根据 w 的值,将信号量做 release 操作
// 如果 semap = 0,则阻塞的 wait 会监听到
runtime_Semrelease(semap, false, 0)
}
}
wg.Wait()
func (wg *WaitGroup) Wait() {
statep, semap := wg.state()
for {
state := atomic.LoadUint64(statep)
v := int32(state >> 32)
w := uint32(state)
if v == 0 {
// Counter is 0, no need to wait.
return
}
// Increment waiters count.
// 将 w 计数器加 1, 由此可见 wg.Wait() 可以在 goroutine 中并发
if atomic.CompareAndSwapUint64(statep, state, state+1) {
// 当信号量 > 0,会执行 acquire 操作
runtime_Semacquire(semap)
if *statep != 0 {
panic("sync: WaitGroup is reused before previous Wait has returned")
}
return
}
}
}