@@ -10,7 +10,6 @@ import (
10
10
"errors"
11
11
"fmt"
12
12
"io"
13
- "sync"
14
13
"time"
15
14
16
15
"github.com/klauspost/compress/flate"
@@ -71,7 +70,7 @@ type msgWriterState struct {
71
70
c * Conn
72
71
73
72
mu * mu
74
- writeMu sync. Mutex
73
+ writeMu * mu
75
74
76
75
ctx context.Context
77
76
opcode opcode
@@ -83,8 +82,9 @@ type msgWriterState struct {
83
82
84
83
func newMsgWriterState (c * Conn ) * msgWriterState {
85
84
mw := & msgWriterState {
86
- c : c ,
87
- mu : newMu (c ),
85
+ c : c ,
86
+ mu : newMu (c ),
87
+ writeMu : newMu (c ),
88
88
}
89
89
return mw
90
90
}
@@ -155,10 +155,18 @@ func (mw *msgWriterState) reset(ctx context.Context, typ MessageType) error {
155
155
156
156
// Write writes the given bytes to the WebSocket connection.
157
157
func (mw * msgWriterState ) Write (p []byte ) (_ int , err error ) {
158
- defer errd .Wrap (& err , "failed to write" )
158
+ err = mw .writeMu .lock (mw .ctx )
159
+ if err != nil {
160
+ return 0 , fmt .Errorf ("failed to write: %w" , err )
161
+ }
162
+ defer mw .writeMu .unlock ()
159
163
160
- mw .writeMu .Lock ()
161
- defer mw .writeMu .Unlock ()
164
+ defer func () {
165
+ if err != nil {
166
+ err = fmt .Errorf ("failed to write: %w" , err )
167
+ mw .c .close (err )
168
+ }
169
+ }()
162
170
163
171
if mw .c .flate () {
164
172
// Only enables flate if the length crosses the
@@ -193,8 +201,11 @@ func (mw *msgWriterState) write(p []byte) (int, error) {
193
201
func (mw * msgWriterState ) Close () (err error ) {
194
202
defer errd .Wrap (& err , "failed to close writer" )
195
203
196
- mw .writeMu .Lock ()
197
- defer mw .writeMu .Unlock ()
204
+ err = mw .writeMu .lock (mw .ctx )
205
+ if err != nil {
206
+ return err
207
+ }
208
+ defer mw .writeMu .unlock ()
198
209
199
210
_ , err = mw .c .writeFrame (mw .ctx , true , mw .flate , mw .opcode , nil )
200
211
if err != nil {
@@ -214,7 +225,7 @@ func (mw *msgWriterState) close() {
214
225
putBufioWriter (mw .c .bw )
215
226
}
216
227
217
- mw .writeMu .Lock ()
228
+ mw .writeMu .forceLock ()
218
229
mw .dict .close ()
219
230
}
220
231
@@ -230,8 +241,8 @@ func (c *Conn) writeControl(ctx context.Context, opcode opcode, p []byte) error
230
241
}
231
242
232
243
// frame handles all writes to the connection.
233
- func (c * Conn ) writeFrame (ctx context.Context , fin bool , flate bool , opcode opcode , p []byte ) (int , error ) {
234
- err : = c .writeFrameMu .lock (ctx )
244
+ func (c * Conn ) writeFrame (ctx context.Context , fin bool , flate bool , opcode opcode , p []byte ) (_ int , err error ) {
245
+ err = c .writeFrameMu .lock (ctx )
235
246
if err != nil {
236
247
return 0 , err
237
248
}
@@ -243,6 +254,13 @@ func (c *Conn) writeFrame(ctx context.Context, fin bool, flate bool, opcode opco
243
254
case c .writeTimeout <- ctx :
244
255
}
245
256
257
+ defer func () {
258
+ if err != nil {
259
+ err = fmt .Errorf ("failed to write frame: %w" , err )
260
+ c .close (err )
261
+ }
262
+ }()
263
+
246
264
c .writeHeader .fin = fin
247
265
c .writeHeader .opcode = opcode
248
266
c .writeHeader .payloadLength = int64 (len (p ))
0 commit comments