diff --git a/close.go b/close.go index fcc68065..9b9594b2 100644 --- a/close.go +++ b/close.go @@ -99,11 +99,7 @@ func CloseStatus(err error) StatusCode { func (c *Conn) Close(code StatusCode, reason string) (err error) { defer errd.Wrap(&err, "failed to close WebSocket") - if c.casClosing() { - err = c.waitGoroutines() - if err != nil { - return err - } + if c.userClosed.Swap(true) && c.isClosed() { return net.ErrClosed } defer func() { @@ -112,6 +108,10 @@ func (c *Conn) Close(code StatusCode, reason string) (err error) { } }() + if c.casClosing() { + return c.waitGoroutines() + } + err = c.closeHandshake(code, reason) err2 := c.close() @@ -132,11 +132,7 @@ func (c *Conn) Close(code StatusCode, reason string) (err error) { func (c *Conn) CloseNow() (err error) { defer errd.Wrap(&err, "failed to immediately close WebSocket") - if c.casClosing() { - err = c.waitGoroutines() - if err != nil { - return err - } + if c.userClosed.Swap(true) && c.isClosed() { return net.ErrClosed } defer func() { @@ -145,6 +141,10 @@ func (c *Conn) CloseNow() (err error) { } }() + if c.casClosing() { + return c.waitGoroutines() + } + err = c.close() err2 := c.waitGoroutines() diff --git a/conn.go b/conn.go index 09234871..c4387c14 100644 --- a/conn.go +++ b/conn.go @@ -77,9 +77,10 @@ type Conn struct { closeReadCtx context.Context closeReadDone chan struct{} - closing atomic.Bool - closeMu sync.Mutex // Protects following. - closed chan struct{} + userClosed atomic.Bool // Set by Close/CloseNow on first user call. + closing atomic.Bool + closeMu sync.Mutex // Protects following. + closed chan struct{} pingCounter atomic.Int64 activePingsMu sync.Mutex