sync.WaitGroup 的作用就是让主函数等待所有 goroutine 都执行完毕,再退出。

一个最简单的例子如下,如果没有 wg,那么 main 会在 goroutine 执行之前就退出,从而不会看到任何 output。

func main() {
    wg := sync.WaitGroup{}

    for i := 0; i < 3; i++ {
        wg.Add(1)
        go func(i int) {
            fmt.Println(i)
            wg.Done()
        }(i)
    }
    wg.Wait()
}

那么 WaitGroup 是如何实现的呢?

万变不离其宗,其底层还是基于 go runtime 提供的信号量机制,也就是 runtime_Semrelease()runtime_Semacquire(), 在之前的文章 Golang RWMutex 的实现netpoll 的实现 中都有它们的影子存在。

runtime_Semacquire(s *uint32) 
此函数会阻塞直到信号量*s的值大于0,原子减这个值。

runtime_Semrelease(s *uint32, lifo bool, skipframes int) 
此函数执行原子增信号量的值,然后通知被runtime_Semacquire阻塞的协程

说到底,就是用 信号量gopark 来控制 goroutine 是运行还是挂起,wg.Add() 对应信号量的增减,wg.Wait() 对应线程/协程的挂起。

WaitGroup 的源码位于 src/sync/waitgroup.go 文件中,一共才 100 多行,下面就逐一分析下。

WaitGroup

type WaitGroup struct {
	noCopy noCopy

	state1 [3]uint32
}

其中 noCopy 是干啥的呢? 搜索一番以后发现了这样的解释:

noCopy 用于 go vet 检查 sync.WaitGroup 类型变量是否采用了值传递的方式

如果采用了值传递,go vet 检查会抛出错误:call of foo copies lock value: sync.WaitGroup contains sync.noCopy, 因为如果采用值传递,那么 state1 就会被复制一份,而对应的信号量并不会跟着复制,所以值传递后复制出来的是一个不可用的 WaitGroup

state1 一共占用 3 * sizeof(uint32) = 12 字节,为什么要写成这样呢? 因为为了省字节而舍弃了可读性。 从结构体原本的实现能更容易看出这个结构体的意图。


+------------+-------------+--------------+
|  counter   |  waiter     |   sem        |
|            |             |              |
+------------+-------------+--------------+

分别代表了

  1. 第一个 4 字节的计数器用于 wg.Add() 和 wg.Done() 的计数
  2. 第二个 4 字节的计数器用于 wg.Wait() 调用者的计数
  3. 信号量

为什么要用 counter 和 waiter 两个计数器呢? 第一个很好理解,因为不止一个 goroutine 调用 Add 和 Done。

同样的道理,也可以有多个 goroutine 调用 wg.Wait() 等待,之所以没有理解,是因为我使用过的 wg 只有一个 wg.Wait() 调用者,没有考虑过多个的情景。

state 函数是用来获取 [3]uint32 中存的各个数据,

func (wg *WaitGroup) state() (statep *uint64, semap *uint32) {
    // 判断地址是否8字节对齐
    if uintptr(unsafe.Pointer(&wg.state1))%8 == 0 {
        return (*uint64)(unsafe.Pointer(&wg.state1)), &wg.state1[2]
    } else {
        return (*uint64)(unsafe.Pointer(&wg.state1[1])), &wg.state1[0]
    }
}

为什么其中需要对 8 求模呢? 因为 Golang 在 amd64 的机器上,64位原子操作需要满足64位对齐(也就是8字节对齐),32位编译器不能保证这点。

wg.Add()

有了前面的基础以后,看 Add 和 Wait 的实现就简单多,直接上关键部分的代码:

func (wg *WaitGroup) Add(delta int) {
    statep, semap := wg.state()
   
    // 将 delta 加到 counter 计数器上
    state := atomic.AddUint64(statep, uint64(delta)<<32)
    v := int32(state >> 32) // v 是调用 wg.Add() 的计数器
    w := uint32(state)      // w 是调用 wg.Wait() 的计数器
  
    if v < 0 {
        panic("sync: negative WaitGroup counter")
    }

    // 如果 w 不等于0,表明已经有 Wait 调用在等待,此时,再调Add会报错
    if w != 0 && delta > 0 && v == int32(delta) {
        panic("sync: WaitGroup misuse: Add called concurrently with Wait")
    }

    if v > 0 || w == 0 {
        return    // 正常情况下函数从这里返回
    }
   
    // 如果执行到这里,是最后一个 wg.Done() 执行
    if *statep != state {
        panic("sync: WaitGroup misuse: Add called concurrently with Wait")
    }

    // Reset waiters count to 0.
    *statep = 0

    for ; w != 0; w-- {

        // 根据 w 的值,将信号量做 release 操作
        // 如果 semap = 0,则阻塞的 wait 会监听到
        runtime_Semrelease(semap, false, 0)
    }
}

wg.Wait()

func (wg *WaitGroup) Wait() {
    statep, semap := wg.state()
   
    for {
        state := atomic.LoadUint64(statep)
        v := int32(state >> 32)
        w := uint32(state)
        if v == 0 {
            // Counter is 0, no need to wait.
            return
        }
        // Increment waiters count.
        // 将 w 计数器加 1, 由此可见 wg.Wait() 可以在 goroutine 中并发

        if atomic.CompareAndSwapUint64(statep, state, state+1) {
            // 当信号量 > 0,会执行 acquire 操作
            runtime_Semacquire(semap)
            
            if *statep != 0 {
                panic("sync: WaitGroup is reused before previous Wait has returned")
            }
            return
        }
    }
}

参考资料