diff --git a/dissect.go b/dissect.go index eda0851..9572686 100644 --- a/dissect.go +++ b/dissect.go @@ -245,6 +245,9 @@ func (dp *DissectedPacket) parseTLSServerName() (string, error) { case dp.TCP != nil: return ExtractTLSServerName(dp.TCP.Payload) case dp.UDP != nil: + if sni, err := ExtractQUICServerName(dp.UDP.Payload); err == nil { + return sni, err + } return ExtractTLSServerName(dp.UDP.Payload) default: return "", ErrDissectTransport diff --git a/dpidrop.go b/dpidrop.go index 3308a92..c1bd1fc 100644 --- a/dpidrop.go +++ b/dpidrop.go @@ -72,11 +72,6 @@ func (r *DPIDropTrafficForTLSSNI) Filter( return nil, false } - // short circuit for UDP packets - if packet.TransportProtocol() != layers.IPProtocolTCP { - return nil, false - } - // try to obtain the SNI sni, err := packet.parseTLSServerName() if err != nil { diff --git a/go.mod b/go.mod index f5fc7a7..b0266a6 100644 --- a/go.mod +++ b/go.mod @@ -14,6 +14,7 @@ require ( require ( github.com/google/btree v1.1.2 // indirect github.com/pkg/errors v0.9.1 // indirect + github.com/quic-go/quic-go v0.36.0 // indirect github.com/stretchr/testify v1.8.1 // indirect golang.org/x/net v0.10.0 // indirect golang.org/x/sys v0.8.0 // indirect diff --git a/go.sum b/go.sum index 1bebe0f..a6adfc1 100644 --- a/go.sum +++ b/go.sum @@ -78,6 +78,8 @@ github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINE github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= +github.com/quic-go/quic-go v0.36.0 h1:JIrO7p7Ug6hssFcARjWDiqS2RAKJHCiwPxBAA989rbI= +github.com/quic-go/quic-go v0.36.0/go.mod h1:zPetvwDlILVxt15n3hr3Gf/I3mDf7LpLKPhR4Ez0AZQ= github.com/rogpeppe/fastuuid v1.1.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ= github.com/sergi/go-diff v1.0.0/go.mod h1:0CfEIISq7TuYL3j771MWULgwwjU+GofnZX9QAmXWZgo= github.com/smartystreets/assertions v1.0.0/go.mod h1:kHHU4qYBaI3q23Pp3VPrmWhuIUrLW/7eUrw0BU5VaoM= diff --git a/quiccrypto.go b/quiccrypto.go new file mode 100644 index 0000000..f0e1bbf --- /dev/null +++ b/quiccrypto.go @@ -0,0 +1,137 @@ +package netem + +import ( + "crypto" + "crypto/aes" + "crypto/cipher" + "encoding/binary" + + "golang.org/x/crypto/hkdf" +) + +// https://www.rfc-editor.org/rfc/rfc9001.html#protection-keys +// +// computeHP derives the header protection key from the initial secret. +func computeHP(secret []byte) (hp []byte) { + hp = hkdfExpandLabel(crypto.SHA256, secret, []byte{}, "quic hp", 16) + return +} + +// SPDX-License-Identifier: BSD-3-Clause +// This code is borrowed from https://github.com/marten-seemann/qtls-go1-15 +// https://github.com/marten-seemann/qtls-go1-15/blob/0d137e9e3594d8e9c864519eff97b323321e5e74/cipher_suites.go#L281 +type aead interface { + cipher.AEAD + + // explicitNonceLen returns the number of bytes of explicit nonce + // included in each record. This is eight for older AEADs and + // zero for modern ones. + explicitNonceLen() int +} + +// SPDX-License-Identifier: BSD-3-Clause +// This code is borrowed from https://github.com/marten-seemann/qtls-go1-15 +// https://github.com/marten-seemann/qtls-go1-15/blob/0d137e9e3594d8e9c864519eff97b323321e5e74/cipher_suites.go#L375 +func aeadAESGCMTLS13(key, nonceMask []byte) aead { + if len(nonceMask) != aeadNonceLength { + panic("tls: internal error: wrong nonce length") + } + aes, err := aes.NewCipher(key) + if err != nil { + panic(err) + } + aead, err := cipher.NewGCM(aes) + if err != nil { + panic(err) + } + ret := &xorNonceAEAD{aead: aead} + copy(ret.nonceMask[:], nonceMask) + return ret +} + +// SPDX-License-Identifier: MIT +// This code is borrowed from https://github.com/lucas-clemente/quic-go/ +// https://github.com/lucas-clemente/quic-go/blob/f3b098775e40f96486c0065204145ddc8675eb7c/internal/handshake/initial_aead.go#L60 +// https://www.rfc-editor.org/rfc/rfc9001.html#protection-keys +// +// computeInitialKeyAndIV derives the packet protection key and Initialization Vector (IV) from the initial secret. +func computeInitialKeyAndIV(secret []byte) (key, iv []byte) { + key = hkdfExpandLabel(crypto.SHA256, secret, []byte{}, "quic key", 16) + iv = hkdfExpandLabel(crypto.SHA256, secret, []byte{}, "quic iv", 12) + return +} + +// SPDX-License-Identifier: MIT +// This code is borrowed from https://github.com/lucas-clemente/quic-go/ +// https://github.com/lucas-clemente/quic-go/blob/f3b098775e40f96486c0065204145ddc8675eb7c/internal/handshake/initial_aead.go#L53 +// https://www.rfc-editor.org/rfc/rfc9001.html#name-initial-secrets +// +// computeSecrets computes the initial secrets based on the destination connection ID. +func computeSecrets(destConnID []byte) (clientSecret, serverSecret []byte) { + initialSalt := []byte{0x38, 0x76, 0x2c, 0xf7, 0xf5, 0x59, 0x34, 0xb3, 0x4d, 0x17, 0x9a, 0xe6, 0xa4, 0xc8, 0x0c, 0xad, 0xcc, 0xbb, 0x7f, 0x0a} + initialSecret := hkdf.Extract(crypto.SHA256.New, destConnID, initialSalt) + clientSecret = hkdfExpandLabel(crypto.SHA256, initialSecret, []byte{}, "client in", crypto.SHA256.Size()) + serverSecret = hkdfExpandLabel(crypto.SHA256, initialSecret, []byte{}, "server in", crypto.SHA256.Size()) + return +} + +// SPDX-License-Identifier: MIT +// This code is borrowed from https://github.com/lucas-clemente/quic-go/ +// https://github.com/lucas-clemente/quic-go/blob/master/internal/handshake/hkdf.go +// +// hkdfExpandLabel HKDF expands a label. +func hkdfExpandLabel(hash crypto.Hash, secret, context []byte, label string, length int) []byte { + b := make([]byte, 3, 3+6+len(label)+1+len(context)) + binary.BigEndian.PutUint16(b, uint16(length)) + b[2] = uint8(6 + len(label)) + b = append(b, []byte("tls13 ")...) + b = append(b, []byte(label)...) + b = b[:3+6+len(label)+1] + b[3+6+len(label)] = uint8(len(context)) + b = append(b, context...) + + out := make([]byte, length) + n, err := hkdf.Expand(hash.New, secret, b).Read(out) + if err != nil || n != length { + panic("quic: HKDF-Expand-Label invocation failed unexpectedly") + } + return out +} + +const aeadNonceLength = 12 + +// SPDX-License-Identifier: BSD-3-Clause +// This code is borrowed from https://github.com/marten-seemann/qtls-go1-15 +// https://github.com/marten-seemann/qtls-go1-15/blob/0d137e9e3594d8e9c864519eff97b323321e5e74/cipher_suites.go#L319 +// +// xorNonceAEAD wraps an AEAD by XORing in a fixed pattern to the nonce before each call. +type xorNonceAEAD struct { + nonceMask [aeadNonceLength]byte + aead cipher.AEAD +} + +func (f *xorNonceAEAD) NonceSize() int { return 8 } // 64-bit sequence number +func (f *xorNonceAEAD) Overhead() int { return f.aead.Overhead() } +func (f *xorNonceAEAD) explicitNonceLen() int { return 0 } + +func (f *xorNonceAEAD) Seal(out, nonce, plaintext, additionalData []byte) []byte { + for i, b := range nonce { + f.nonceMask[4+i] ^= b + } + result := f.aead.Seal(out, f.nonceMask[:], plaintext, additionalData) + for i, b := range nonce { + f.nonceMask[4+i] ^= b + } + return result +} + +func (f *xorNonceAEAD) Open(out, nonce, ciphertext, additionalData []byte) ([]byte, error) { + for i, b := range nonce { + f.nonceMask[4+i] ^= b + } + result, err := f.aead.Open(out, f.nonceMask[:], ciphertext, additionalData) + for i, b := range nonce { + f.nonceMask[4+i] ^= b + } + return result, err +} diff --git a/quicparse.go b/quicparse.go new file mode 100644 index 0000000..7d3c7bc --- /dev/null +++ b/quicparse.go @@ -0,0 +1,357 @@ +package netem + +import ( + "bytes" + "crypto/aes" + "encoding/binary" + "errors" + "fmt" + + "github.com/quic-go/quic-go/quicvarint" + "golang.org/x/crypto/cryptobyte" +) + +// ErrQUICParse is the error returned in case there is a QUIC parse error. +var ErrQUICParse = errors.New("quicparse: parse error") + +// newErrQUICParse returns a new [ErrQUICParse]. +func newErrQUICParse(message string) error { + return fmt.Errorf("%w: %s", ErrQUICParse, message) +} + +// QUICLongHeaderPacket is a decryptable long header packet +type QUICLongHeaderPacket interface { + Decrypt(raw []byte) error +} + +// QUICClientInitial is a data structure to store the header fields and (decrypted) payload of a +// parsed QUIC Client Initial packet. +// Specified in https://www.rfc-editor.org/rfc/rfc9000.html#name-initial-packet. +type QUICClientInitial struct { + // FirstByte is the partly encrypted first byte of the Initial packet. + // The lower 4 bits are protected by QUIC Header Protection. + // * Header Form (1), + // * Fixed Bit (1), + // * Long Packet Type (2), + // * Type-specific bits (4) + FirstByte byte + cursor *bytes.Reader + QUICLongHeaderPacket + + // QUICVersion is the QUIC version number. + QUICVersion uint32 + + // DestinationID is the variable length (up to 20 Byte) Destination Connection ID. + DestinationID []byte + + // SourceID is the variable length (up to 20 Byte) Source Connection ID. + SourceID []byte + + // Token is the QUIC token. + Token []byte + + // Length is the total length of packet number and payload bytes. + Length uint64 + + // PnOffset is the offset for the packet number which prefixes the packet payload. + PnOffset int + + // DecryptedPacketNumber is the decrypted packet number. + // The packet number is expected to be 0 for the Client Initial. + // Produced by QUICClientInitial.Decrypt + DecryptedPacketNumber []byte + + // Payload is the encrypted payload of the QUIC Client Initial. + // Produced by QUICClientInitial.Decrypt + Payload []byte + + // DecryptedPayload is the decrypted payload of the packet. + // Produced by QUICClientInitial.Decrypt + DecryptedPayload []byte +} + +// UnmarshalLongHeaderPacket unmarshals a raw QUIC long header packet +// Return values: +// 1. the parsed QUICClientInitial (on success) +// 2. the remaining data to be parsed [*bytes.Reader] +// 3. error (on failure) +func UnmarshalLongHeaderPacket(raw []byte) (QUICLongHeaderPacket, error) { + // read the packet header byte + cursor := bytes.NewReader(cryptobyte.String(raw)) + firstByte, err := cursor.ReadByte() + if err != nil { + return nil, newErrQUICParse("QUIC packet: cannot read first byte") + } + switch (firstByte & 0b1000_0000) >> 7 { + case 1: // allow long header format + + default: + return nil, newErrQUICParse("QUIC packet: unsupported header type") + } + + // the packet type is encoded in bits 6 and 5 (MSB 8 7 6 5 4 3 2 1 LSB) of the first byte + 1 + ptype := (firstByte & 0x30) >> 4 + switch ptype { + case 0: // Initial packet type + ci := &QUICClientInitial{ + FirstByte: firstByte, + cursor: cursor, + } + return ci, unmarshalInitial(raw, ci, cursor) + default: + return nil, newErrQUICParse("long header: unsupported packet type") + } +} + +// unmarshalInitial unmarshals a raw QUIC Client Initial packet +// Modifies the QUICClientInitial instance, and the cursor [*bytes.Reader]. +// Returns an error on failure. +func unmarshalInitial(raw []byte, ci *QUICClientInitial, cursor *bytes.Reader) error { + var err error + // QUIC version (4) + versionBytes := make([]byte, 4) + if _, err = cursor.Read(versionBytes); err != nil { + return newErrQUICParse("Initial header: cannot read version field") + } + ci.QUICVersion = binary.BigEndian.Uint32(versionBytes) + switch ci.QUICVersion { + case 0x1, 0xff00001d, 0xbabababa: + // all good + default: + return newErrQUICParse("Initial header: unsupported QUIC version") + } + // Destination Connection ID (1 + n) + var lendid uint8 + if lendid, err = cursor.ReadByte(); err != nil { + return newErrQUICParse("Initial header: cannot read length destination ID") + } + ci.DestinationID = make([]byte, int(lendid)) + if _, err = cursor.Read(ci.DestinationID); err != nil { + return newErrQUICParse("Initial header: cannot read destination ID") + } + // Source Connection ID (1 + n) + var lensid uint8 + if lensid, err = cursor.ReadByte(); err != nil { + return newErrQUICParse("Initial header: cannot read length source ID") + } + ci.SourceID = make([]byte, int(lensid)) + if _, err = cursor.Read(ci.SourceID); err != nil { + return newErrQUICParse("Initial header: cannot read source ID") + } + // Token length (n) + tokenlen, err := quicvarint.Read(cursor) + if err != nil { + return newErrQUICParse("Initial header: cannot read token length") + } + // Token (m) + ci.Token = make([]byte, tokenlen) + if _, err = cursor.Read(ci.Token); err != nil { + return newErrQUICParse("Initial header: cannot read token") + } + // Length of the payload + if ci.Length, err = quicvarint.Read(cursor); err != nil { + return newErrQUICParse("Initial header: cannot read payload length") + } + // ci.Length = append([]byte{lengthfirstbyte}, ci.Length...) + ci.PnOffset = int(cursor.Size()) - cursor.Len() + return nil +} + +// Decrypt decrypts the parsed Client Initial. +// Modifies the QUICClientInitial instance. +// Returns an error on failure. +func (ci *QUICClientInitial) Decrypt(raw []byte) error { + // the 16-byte ciphertext sample used for header protection starts at pnOffset + 4 + sampleOffset := ci.PnOffset + 4 + sample := raw[sampleOffset : sampleOffset+16] + + // the AES header protection key is derived from the destination ID and a version-specific salt + clientSecret, _ := computeSecrets(ci.DestinationID) + hp := computeHP(clientSecret) + block, err := aes.NewCipher(hp) + if err != nil { + return newErrQUICParse("decrypt Initial: error creating new AES cipher" + err.Error()) + } + mask := make([]byte, block.BlockSize()) + if len(sample) != len(mask) { + panic("invalid sample size") + } + // the mask used for header protection is obtained by encrypting the ciphertext sample + block.Encrypt(mask, sample) + + // remove header protection (applied to the second half of the first byte) + ci.FirstByte ^= mask[0] & 0xf + + // the packet number length is encoded in the two least significant bits of the first byte + 1 + pnlen := 1 << (ci.FirstByte & 0x03) + ci.DecryptedPacketNumber = make([]byte, pnlen) + if _, err = ci.cursor.Read(ci.DecryptedPacketNumber); err != nil { + return newErrQUICParse("decrypt Initial: cannot read packet number") + } + // remove header protection from the packet number field + for i, _ := range ci.DecryptedPacketNumber { + ci.DecryptedPacketNumber[i] ^= mask[i+1] + if ci.DecryptedPacketNumber[i] != 0 { + return newErrQUICParse("decrypt Initial: unexpected packet number (expect 0)") + } + } + // calculate the length of the payload + payloadLength := int(ci.Length) - pnlen + if payloadLength <= 0 { + return newErrQUICParse("decrypt Initial: no payload") + } + // parse the payload + ci.Payload = make([]byte, payloadLength) + if _, err = ci.cursor.Read(ci.Payload); err != nil { + return newErrQUICParse("decrypt Initial: cannot read payload") + } + // put together the decrypted header: first byte + rest (unprotected) + packet number + // which is needed for payload decryption + decryptedHeader := []byte{ci.FirstByte} + decryptedHeader = append(decryptedHeader, raw[1:ci.PnOffset]...) + decryptedHeader = append(decryptedHeader, ci.DecryptedPacketNumber...) + + // remove packet protection + // the decryption requires the initial client secret, and the decrypted header as associated data + ci.DecryptedPayload = decryptPayload(ci.Payload, clientSecret, decryptedHeader) + return nil +} + +// https://www.rfc-editor.org/rfc/rfc9001.html#name-packet-protection +// +// decryptPayload decrypts the payload of the packet by removing AEAD packet protection. +// AEAD decryption requires the initial client secret and associated data. +// Returns the decrypted payload. +func decryptPayload(payload, clientSecret []byte, ad []byte) []byte { + // derive AEAD packet protection key and initialization vectors from the intial client secret + key, iv := computeInitialKeyAndIV(clientSecret) + cipher := aeadAESGCMTLS13(key, iv) + + nonceBuf := make([]byte, cipher.NonceSize()) + binary.BigEndian.PutUint64(nonceBuf[len(nonceBuf)-8:], uint64(0)) + + // decrypt the payload + decrypted, err := cipher.Open(nil, nonceBuf, payload, ad) + if err != nil { + panic(err) + } + return decrypted +} + +// QUICFrame contains the content of a QUIC data frame. +// The payload of QUIC packets, after removing packet protection, consists of a sequence of complete frames. +type QUICFrame struct { + // Type is the QUIC frame type, as defined in RFC9000 + Type int + // Offset is the byte offset in the stream (stream-level sequence number) + Offset uint64 + // Length is the length of the data payload + Length uint64 + // Payload is the variable-length data payload + Payload []byte +} + +// nextFrame returns the next frame. +// Note that in a QUIC Client Initial there is usually only one frame (CRYPTO). +// It skips PADDING frames. +// +// Returns the next non-padding frame. +func nextFrame(cursor *bytes.Reader) (*QUICFrame, error) { + // read the first byte indicating the frame type + firstByte, err := cursor.ReadByte() + if err != nil { + return nil, newErrQUICParse("QUIC frame: cannot read first byte of frame") + } + for cursor.Len() > 0 { + switch firstByte { + // Skip PADDING frame + case 0x00: + var nextByte byte + for nextByte == 0 { + if nextByte, err = cursor.ReadByte(); err != nil { + return nil, newErrQUICParse("QUIC frame: cannot read first byte of frame") + } + } + continue + // CRYPTO frame https://www.rfc-editor.org/rfc/rfc9000.html#name-crypto-frames + case 0x06: + // create a new frame + crypto := &QUICFrame{ + Type: 0x06, + } + // the stream offset of the CRYPTO data + if crypto.Offset, err = quicvarint.Read(cursor); err != nil { + return nil, newErrQUICParse("CRYPTO frame: cannot read stream offset") + } + // the length of the data field in this CRYPTO frame + if crypto.Length, err = quicvarint.Read(cursor); err != nil { + return nil, newErrQUICParse("CRYPTO frame: cannot read data length") + } + // the cryptographic message data + crypto.Payload = make([]byte, crypto.Length) + if _, err = cursor.Read(crypto.Payload); err != nil { + return nil, newErrQUICParse("CRYPTO frame: cannot read data") + } + return crypto, nil + default: + break + } + } + return nil, newErrQUICParse("unsupported QUIC frame type") +} + +// ExtractQUICServerName takes in input bytes read from the network, attempts +// to determine whether this is a QUIC Client Initial message, +// and, if affirmative, attempts to extract the server name. +func ExtractQUICServerName(rawInput []byte) (string, error) { + if len(rawInput) <= 0 { + return "", newErrTLSParse("no data") + } + // unmarshal the packet + packet, err := UnmarshalLongHeaderPacket(rawInput) + if err != nil { + return "", err + } + // decrypt the initial packet + err = packet.Decrypt(rawInput) + if err != nil { + return "", err + } + ci, ok := packet.(*QUICClientInitial) + if !ok { + return "", newErrQUICParse("unexpected packet type") + } + // iterate through contained frames to find CRYPTO frame with SNI + frame, err := nextFrame(bytes.NewReader(ci.DecryptedPayload)) + for frame != nil { + if err != nil { + return "", err + } + switch frame.Type { + case 0x06: + // unmarshaling a decrypted QUIC CRYPTO frame inside a Client Initial + // packet is like unmarshaling a TLS Client Hello (TLS 1.3) + hx, err := UnmarshalTLSHandshakeMsg(frame.Payload) + if err != nil { + return "", err + } + if hx.ClientHello == nil { + return "", newErrTLSParse("no client hello") + } + exts, err := UnmarshalTLSExtensions(hx.ClientHello.Extensions) + if err != nil { + return "", err + } + snext, found := FindTLSServerNameExtension(exts) + if !found { + return "", newErrTLSParse("no server name extension") + } + ret, err := UnmarshalTLSServerNameExtension(snext.Data) + return ret, err + default: + frame, err = nextFrame(bytes.NewReader(ci.DecryptedPayload)) + continue + } + } + return "", newErrQUICParse("no CRYPTO frame") +} diff --git a/quicparse_test.go b/quicparse_test.go new file mode 100644 index 0000000..7b4a9f3 --- /dev/null +++ b/quicparse_test.go @@ -0,0 +1,141 @@ +package netem + +import ( + "bytes" + "testing" +) + +// QUICInitialBytes contains a QUIC Initial obtained +// from https://quic.xargs.org/#client-initial-packet. +var QUICInitialBytes = []byte{ + // [0:1] packet header byte + 0xcd, + + // [1:5] QUIC version + 0x00, 0x00, 0x00, 0x01, + + // [5:14] destination connection ID (first byte: length) + 0x08, 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, + 0x07, + + // [14:20] source connection ID (first byte: length) + 0x05, 0x63, 0x5f, 0x63, 0x69, 0x64, + + // [20:21] token + 0x00, + + // [21:23] packet length + 0x41, 0x03, + + // [23:24] packet number + 0x98, + + // encrypted payload + 0x1c, 0x36, 0xa7, 0xed, 0x78, 0x71, 0x6b, 0xe9, 0x71, 0x1b, 0xa4, 0x98, 0xb7, 0xed, 0x86, 0x84, 0x43, 0xbb, 0x2e, 0x0c, 0x51, 0x4d, 0x4d, 0x84, 0x8e, 0xad, 0xcc, 0x7a, 0x00, 0xd2, 0x5c, 0xe9, 0xf9, 0xaf, 0xa4, 0x83, 0x97, 0x80, 0x88, 0xde, 0x83, 0x6b, 0xe6, 0x8c, 0x0b, 0x32, 0xa2, 0x45, 0x95, 0xd7, 0x81, 0x3e, 0xa5, 0x41, 0x4a, 0x91, 0x99, 0x32, 0x9a, 0x6d, 0x9f, 0x7f, 0x76, 0x0d, 0xd8, 0xbb, 0x24, 0x9b, 0xf3, 0xf5, 0x3d, 0x9a, 0x77, 0xfb, 0xb7, 0xb3, 0x95, 0xb8, 0xd6, 0x6d, 0x78, 0x79, 0xa5, 0x1f, 0xe5, 0x9e, 0xf9, 0x60, 0x1f, 0x79, 0x99, 0x8e, 0xb3, 0x56, 0x8e, 0x1f, 0xdc, 0x78, 0x9f, 0x64, 0x0a, 0xca, 0xb3, 0x85, 0x8a, 0x82, 0xef, 0x29, 0x30, 0xfa, 0x5c, 0xe1, 0x4b, 0x5b, 0x9e, 0xa0, 0xbd, 0xb2, 0x9f, 0x45, 0x72, 0xda, 0x85, 0xaa, 0x3d, 0xef, 0x39, 0xb7, 0xef, 0xaf, 0xff, 0xa0, 0x74, 0xb9, 0x26, 0x70, 0x70, 0xd5, 0x0b, 0x5d, 0x07, 0x84, 0x2e, 0x49, 0xbb, 0xa3, 0xbc, 0x78, 0x7f, 0xf2, 0x95, 0xd6, 0xae, 0x3b, 0x51, 0x43, 0x05, 0xf1, 0x02, 0xaf, 0xe5, 0xa0, 0x47, 0xb3, 0xfb, 0x4c, 0x99, 0xeb, 0x92, 0xa2, 0x74, 0xd2, 0x44, 0xd6, 0x04, 0x92, 0xc0, 0xe2, 0xe6, 0xe2, 0x12, 0xce, 0xf0, 0xf9, 0xe3, 0xf6, 0x2e, 0xfd, 0x09, 0x55, 0xe7, 0x1c, 0x76, 0x8a, 0xa6, 0xbb, 0x3c, 0xd8, 0x0b, 0xbb, 0x37, 0x55, 0xc8, 0xb7, 0xeb, 0xee, 0x32, 0x71, 0x2f, 0x40, 0xf2, 0x24, 0x51, 0x19, 0x48, 0x70, 0x21, 0xb4, 0xb8, 0x4e, 0x15, 0x65, 0xe3, 0xca, 0x31, 0x96, 0x7a, 0xc8, 0x60, 0x4d, 0x40, 0x32, 0x17, 0x0d, 0xec, 0x28, 0x0a, 0xee, 0xfa, 0x09, 0x5d, 0x08, + + // authentication tag + 0xb3, 0xb7, 0x24, 0x1e, 0xf6, 0x64, 0x6a, 0x6c, 0x86, 0xe5, 0xc6, 0x2c, 0xe0, 0x8b, 0xe0, 0x99, +} + +type test struct { + input []byte + expect QUICClientInitial + expectErr bool + name string +} + +var firstByteErr string = "quicparse: parse error: QUIC packet: cannot read first byte" + +func TestUnmarshalQUICInitial(t *testing.T) { + tests := []test{ + { + name: "with valid Client Initial", + input: QUICInitialBytes, + expect: QUICClientInitial{ + FirstByte: QUICInitialBytes[0], + QUICVersion: 1, + DestinationID: QUICInitialBytes[6:14], + SourceID: QUICInitialBytes[15:20], + Token: nil, + Length: 259, + PnOffset: 23, + Payload: nil, + DecryptedPacketNumber: nil, + DecryptedPayload: nil, + }, + expectErr: false, + }, + { + name: "with empty input", + input: []byte{}, + expect: QUICClientInitial{ + FirstByte: 0, + QUICVersion: 0, + DestinationID: nil, + SourceID: nil, + Token: nil, + Length: 0, + PnOffset: 0, + Payload: nil, + DecryptedPacketNumber: nil, + DecryptedPayload: nil, + }, + expectErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + packet, err := UnmarshalLongHeaderPacket(tt.input) + if (err != nil) != tt.expectErr { + t.Fatal("unexpected error", err, tt.expectErr) + } + if err != nil { + return + } + ci, ok := packet.(*QUICClientInitial) + if !ok { + t.Fatal("unexpected packet type, expected QUICClientInitial") + } + if ci.FirstByte != tt.expect.FirstByte { + t.Fatal("unexpected First Byte", ci.FirstByte) + } + if ci.QUICVersion != tt.expect.QUICVersion { + t.Fatal("unexpected QUIC version") + } + if !bytes.Equal(ci.DestinationID, tt.expect.DestinationID) { + t.Fatal("unexpected QUIC Destination Connection ID") + } + if !bytes.Equal(ci.SourceID, tt.expect.SourceID) { + t.Fatal("unexpected QUIC Source Connection ID") + } + if !bytes.Equal(ci.Token, tt.expect.Token) { + t.Fatal("unexpected QUIC Token") + } + if ci.Length != tt.expect.Length { + t.Fatalf("unexpected Length %b %b", ci.Length, tt.expect.Length) + } + if ci.PnOffset != tt.expect.PnOffset { + t.Fatal("unexpected Packet Number Offset") + } + if !bytes.Equal(ci.Payload, tt.expect.Payload) { + t.Fatal("unexpected encrypted payload", len(ci.Payload), len(tt.expect.Payload)) + } + if !bytes.Equal(ci.DecryptedPacketNumber, tt.expect.DecryptedPacketNumber) { + t.Fatal("unexpected decrypted packet number") + } + if !bytes.Equal(ci.DecryptedPayload, tt.expect.DecryptedPayload) { + t.Fatal("unexpected decrypted payload") + } + }) + } +} + +func TestExtractQUICServerNameExtractQUICServerName(t *testing.T) { + sni, err := ExtractQUICServerName(QUICInitialBytes) + if err != nil { + t.Fatal("unexpected error", err) + } + if sni != "example.ulfheim.net" { + t.Fatal("unexpected Server Name") + } +}