Skip to content

Commit c62c0dc

Browse files
authored
Merge pull request #193 from nhooyr/ensure-close
Ensure connection is closed at all error points
2 parents 43c4dc0 + 2e0dd1c commit c62c0dc

File tree

2 files changed

+42
-26
lines changed

2 files changed

+42
-26
lines changed

read.go

+12-14
Original file line numberDiff line numberDiff line change
@@ -304,7 +304,9 @@ func (c *Conn) reader(ctx context.Context) (_ MessageType, _ io.Reader, err erro
304304
defer c.readMu.unlock()
305305

306306
if !c.msgReader.fin {
307-
return 0, nil, errors.New("previous message not read to completion")
307+
err = errors.New("previous message not read to completion")
308+
c.close(fmt.Errorf("failed to get reader: %w", err))
309+
return 0, nil, err
308310
}
309311

310312
h, err := c.readLoop(ctx)
@@ -361,21 +363,9 @@ func (mr *msgReader) setFrame(h header) {
361363
}
362364

363365
func (mr *msgReader) Read(p []byte) (n int, err error) {
364-
defer func() {
365-
if errors.Is(err, io.ErrUnexpectedEOF) && mr.fin && mr.flate {
366-
err = io.EOF
367-
}
368-
if errors.Is(err, io.EOF) {
369-
err = io.EOF
370-
mr.putFlateReader()
371-
return
372-
}
373-
errd.Wrap(&err, "failed to read")
374-
}()
375-
376366
err = mr.c.readMu.lock(mr.ctx)
377367
if err != nil {
378-
return 0, err
368+
return 0, fmt.Errorf("failed to read: %w", err)
379369
}
380370
defer mr.c.readMu.unlock()
381371

@@ -384,6 +374,14 @@ func (mr *msgReader) Read(p []byte) (n int, err error) {
384374
p = p[:n]
385375
mr.dict.write(p)
386376
}
377+
if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) && mr.fin && mr.flate {
378+
mr.putFlateReader()
379+
return n, io.EOF
380+
}
381+
if err != nil {
382+
err = fmt.Errorf("failed to read: %w", err)
383+
mr.c.close(err)
384+
}
387385
return n, err
388386
}
389387

write.go

+30-12
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@ import (
1010
"errors"
1111
"fmt"
1212
"io"
13-
"sync"
1413
"time"
1514

1615
"github.com/klauspost/compress/flate"
@@ -71,7 +70,7 @@ type msgWriterState struct {
7170
c *Conn
7271

7372
mu *mu
74-
writeMu sync.Mutex
73+
writeMu *mu
7574

7675
ctx context.Context
7776
opcode opcode
@@ -83,8 +82,9 @@ type msgWriterState struct {
8382

8483
func newMsgWriterState(c *Conn) *msgWriterState {
8584
mw := &msgWriterState{
86-
c: c,
87-
mu: newMu(c),
85+
c: c,
86+
mu: newMu(c),
87+
writeMu: newMu(c),
8888
}
8989
return mw
9090
}
@@ -155,10 +155,18 @@ func (mw *msgWriterState) reset(ctx context.Context, typ MessageType) error {
155155

156156
// Write writes the given bytes to the WebSocket connection.
157157
func (mw *msgWriterState) Write(p []byte) (_ int, err error) {
158-
defer errd.Wrap(&err, "failed to write")
158+
err = mw.writeMu.lock(mw.ctx)
159+
if err != nil {
160+
return 0, fmt.Errorf("failed to write: %w", err)
161+
}
162+
defer mw.writeMu.unlock()
159163

160-
mw.writeMu.Lock()
161-
defer mw.writeMu.Unlock()
164+
defer func() {
165+
if err != nil {
166+
err = fmt.Errorf("failed to write: %w", err)
167+
mw.c.close(err)
168+
}
169+
}()
162170

163171
if mw.c.flate() {
164172
// Only enables flate if the length crosses the
@@ -193,8 +201,11 @@ func (mw *msgWriterState) write(p []byte) (int, error) {
193201
func (mw *msgWriterState) Close() (err error) {
194202
defer errd.Wrap(&err, "failed to close writer")
195203

196-
mw.writeMu.Lock()
197-
defer mw.writeMu.Unlock()
204+
err = mw.writeMu.lock(mw.ctx)
205+
if err != nil {
206+
return err
207+
}
208+
defer mw.writeMu.unlock()
198209

199210
_, err = mw.c.writeFrame(mw.ctx, true, mw.flate, mw.opcode, nil)
200211
if err != nil {
@@ -214,7 +225,7 @@ func (mw *msgWriterState) close() {
214225
putBufioWriter(mw.c.bw)
215226
}
216227

217-
mw.writeMu.Lock()
228+
mw.writeMu.forceLock()
218229
mw.dict.close()
219230
}
220231

@@ -230,8 +241,8 @@ func (c *Conn) writeControl(ctx context.Context, opcode opcode, p []byte) error
230241
}
231242

232243
// frame handles all writes to the connection.
233-
func (c *Conn) writeFrame(ctx context.Context, fin bool, flate bool, opcode opcode, p []byte) (int, error) {
234-
err := c.writeFrameMu.lock(ctx)
244+
func (c *Conn) writeFrame(ctx context.Context, fin bool, flate bool, opcode opcode, p []byte) (_ int, err error) {
245+
err = c.writeFrameMu.lock(ctx)
235246
if err != nil {
236247
return 0, err
237248
}
@@ -243,6 +254,13 @@ func (c *Conn) writeFrame(ctx context.Context, fin bool, flate bool, opcode opco
243254
case c.writeTimeout <- ctx:
244255
}
245256

257+
defer func() {
258+
if err != nil {
259+
err = fmt.Errorf("failed to write frame: %w", err)
260+
c.close(err)
261+
}
262+
}()
263+
246264
c.writeHeader.fin = fin
247265
c.writeHeader.opcode = opcode
248266
c.writeHeader.payloadLength = int64(len(p))

0 commit comments

Comments
 (0)