diff --git a/src/__tests__/bridges.test.ts b/src/__tests__/bridges.test.ts index dfd5d49..b09cde9 100644 --- a/src/__tests__/bridges.test.ts +++ b/src/__tests__/bridges.test.ts @@ -119,6 +119,28 @@ describe("EthersAdapterSigner", () => { }) }) + it("infers the root primaryType when omitted, not the first types key", async () => { + const adapter = createMockAdapter() + const signer = walletAdapterToEthersSigner(adapter, {}) + // `Person` (a dependency) is declared before `Mail` (the root). The primary + // type is the struct not referenced by any other, i.e. `Mail` - not the + // first key. + const types = { + Person: [{ name: "wallet", type: "address" }], + Mail: [ + { name: "from", type: "Person" }, + { name: "contents", type: "string" }, + ], + } + await signer.signTypedData({ name: "Test" }, types, { + from: { wallet: "0x1234567890abcdef1234567890abcdef12345678" }, + contents: "hi", + }) + expect(adapter.signTypedData).toHaveBeenCalledWith( + expect.objectContaining({ primaryType: "Mail" }), + ) + }) + it("connect returns new signer with different provider", () => { const adapter = createMockAdapter() const signer = walletAdapterToEthersSigner(adapter, { id: 1 }) diff --git a/src/__tests__/eip712.test.ts b/src/__tests__/eip712.test.ts new file mode 100644 index 0000000..a248aa2 --- /dev/null +++ b/src/__tests__/eip712.test.ts @@ -0,0 +1,22 @@ +import { describe, expect, it } from "vitest" +import { hashTypedData } from "../util/eip712.js" + +describe("hashTypedData", () => { + const domain = { name: "Test", version: "1", chainId: 1 } + + it("rejects an odd-length hex bytes value instead of silently truncating", () => { + const types = { Doc: [{ name: "data", type: "bytes" }] } + // 5 hex digits -> odd length. Previously the last nibble was dropped and a + // wrong hash was signed; it must now throw instead. + expect(() => + hashTypedData(domain, "Doc", { data: "0xabcde" }, types), + ).toThrow(/odd length/i) + }) + + it("still hashes a valid even-length bytes value", () => { + const types = { Doc: [{ name: "data", type: "bytes" }] } + const hash = hashTypedData(domain, "Doc", { data: "0xabcd" }, types) + expect(hash).toBeInstanceOf(Uint8Array) + expect(hash.length).toBe(32) + }) +}) diff --git a/src/bridges/ethers.ts b/src/bridges/ethers.ts index b186481..77c506c 100644 --- a/src/bridges/ethers.ts +++ b/src/bridges/ethers.ts @@ -108,8 +108,7 @@ export class EthersAdapterSigner { domain, types, message: value, - primaryType: - primaryType ?? Object.keys(types).find(t => t !== "EIP712Domain") ?? "", + primaryType: primaryType ?? inferPrimaryType(types), }) } @@ -117,3 +116,22 @@ export class EthersAdapterSigner { return new EthersAdapterSigner(this.adapter, provider) } } + +/** + * Infer the EIP-712 primary type the way ethers.js does: the struct that is not + * referenced as a field type by any other struct (the root of the type graph). + * The previous heuristic took the first key in `types`, which signs the wrong + * struct when the root is not declared first (e.g. dependencies listed above it). + */ +function inferPrimaryType(types: Record): string { + const named = Object.keys(types).filter(t => t !== "EIP712Domain") + const referenced = new Set() + for (const name of named) { + for (const field of types[name] ?? []) { + const base = String(field.type).replace(/(\[\d*\])+$/, "") + if (base in types) referenced.add(base) + } + } + const roots = named.filter(t => !referenced.has(t)) + return roots[0] ?? named[0] ?? "" +} diff --git a/src/util/eip712.ts b/src/util/eip712.ts index d1f75f2..6839326 100644 --- a/src/util/eip712.ts +++ b/src/util/eip712.ts @@ -187,6 +187,12 @@ function encodeValue( function hexToBytes(hex: string): Uint8Array { const clean = hex.startsWith("0x") ? hex.slice(2) : hex + // An odd-length hex string would otherwise be silently truncated (the last + // nibble dropped), producing a wrong EIP-712 hash for `bytes` values. Reject + // it so signing fails loudly instead of signing the wrong data. + if (clean.length % 2 !== 0) { + throw new Error(`Invalid hex value: odd length (${clean.length} digits)`) + } const bytes = new Uint8Array(clean.length / 2) for (let i = 0; i < bytes.length; i++) { bytes[i] = Number.parseInt(clean.slice(i * 2, i * 2 + 2), 16)