diff --git a/pyproject.toml b/pyproject.toml index 42349eb..c2873b1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -59,6 +59,9 @@ select = [ "TC", # flake8-type-checking ] +[tool.ruff.lint.per-file-ignores] +"tests/**" = ["SIM117"] # Nested with statements are idiomatic in test mocking + [tool.ruff.lint.flake8-bugbear] extend-immutable-calls = ["pydantic.Field"] diff --git a/src/decibel/_base.py b/src/decibel/_base.py index 9685fa2..0441e79 100644 --- a/src/decibel/_base.py +++ b/src/decibel/_base.py @@ -18,7 +18,12 @@ from aptos_sdk.ed25519 import Signature as Ed25519Signature from aptos_sdk.transactions import FeePayerRawTransaction, SignedTransaction -from ._constants import DEFAULT_TXN_CONFIRM_TIMEOUT, DEFAULT_TXN_SUBMIT_TIMEOUT +from ._constants import ( + DEFAULT_TXN_CONFIRM_TIMEOUT, + DEFAULT_TXN_SUBMIT_TIMEOUT, + HTTP_LIMITS, + HTTP_TIMEOUT, +) from ._exceptions import TxnConfirmError, TxnSubmitError from ._fee_pay import ( PendingTransactionResponse, @@ -50,6 +55,14 @@ DEFAULT_GAS_ESTIMATE = 100 MAX_GAS_UNITS_LIMIT = 2_000_000 +_POLL_DELAYS = (0.2, 0.2, 0.5, 0.5, 1.0) + + +def _poll_delay(index: int) -> float: + if index < len(_POLL_DELAYS): + return _POLL_DELAYS[index] + return 1.0 + @dataclass class BaseSDKOptions: @@ -82,6 +95,7 @@ def __init__( self._chain_id = config.chain_id self._abi_registry = AbiRegistry(chain_id=config.chain_id) self._aptos = RestClient(config.fullnode_url) + self._http_client = httpx.AsyncClient(limits=HTTP_LIMITS, timeout=HTTP_TIMEOUT) opts = opts or BaseSDKOptions() self._skip_simulate = opts.skip_simulate @@ -124,6 +138,20 @@ def time_delta_ms(self) -> int: def time_delta_ms(self, value: int) -> None: self._time_delta_ms = value + async def close(self) -> None: + await self._http_client.aclose() + + async def __aenter__(self) -> BaseSDK: + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: object, + ) -> None: + await self.close() + def _get_abi(self, function_id: str) -> MoveFunction | None: return self._abi_registry.get_function(function_id) @@ -181,6 +209,7 @@ async def submit_tx( self._config, transaction, sender_authenticator, + client=self._http_client, txn_submit_timeout=txn_submit_timeout, ) @@ -277,8 +306,7 @@ async def _fetch_gas_price_estimation(self) -> int: url = f"{self._config.fullnode_url}/estimate_gas_price" headers = self._build_node_headers() - async with httpx.AsyncClient() as client: - response = await client.get(url, headers=headers) + response = await self._http_client.get(url, headers=headers, timeout=5.0) if not response.is_success: raise ValueError(f"Failed to fetch gas price: {response.status_code} - {response.text}") @@ -296,13 +324,12 @@ async def _simulate_transaction( bcs_bytes = self._serialize_for_simulation(transaction) - async with httpx.AsyncClient() as client: - response = await client.post( - url, - content=bcs_bytes, - headers=headers, - params={"estimate_max_gas_amount": "true", "estimate_gas_unit_price": "true"}, - ) + response = await self._http_client.post( + url, + content=bcs_bytes, + headers=headers, + params={"estimate_max_gas_amount": "true", "estimate_gas_unit_price": "true"}, + ) if not response.is_success: raise ValueError( @@ -327,10 +354,9 @@ async def _submit_direct( bcs_bytes = self._serialize_signed_transaction(transaction, sender_authenticator) - async with httpx.AsyncClient() as client: - response = await client.post( - url, content=bcs_bytes, headers=headers, timeout=txn_submit_timeout - ) + response = await self._http_client.post( + url, content=bcs_bytes, headers=headers, timeout=txn_submit_timeout + ) if not response.is_success: raise ValueError( @@ -353,7 +379,6 @@ async def _wait_for_transaction( self, tx_hash: str, txn_confirm_timeout: float | None = None, # Uses DEFAULT_TXN_CONFIRM_TIMEOUT if None - poll_interval_secs: float = 1.0, ) -> dict[str, Any]: if txn_confirm_timeout is None: txn_confirm_timeout = DEFAULT_TXN_CONFIRM_TIMEOUT @@ -361,32 +386,34 @@ async def _wait_for_transaction( headers = self._build_node_headers() start_time = time.time() - async with httpx.AsyncClient() as client: - while True: - try: - response = await client.get(url, headers=headers) - except httpx.ConnectTimeout: - pass - except httpx.ReadTimeout: - pass - except httpx.ConnectError: - pass - else: - if response.is_success: - data = cast("dict[str, Any]", response.json()) - tx_type = data.get("type") - if tx_type == "pending_transaction": - pass - elif data.get("success") is True: - return data - elif data.get("success") is False: - vm_status = data.get("vm_status", "Unknown error") - raise TxnConfirmError(tx_hash, f"failed: {vm_status}") - - if time.time() - start_time > txn_confirm_timeout: - raise TxnConfirmError(tx_hash, f"did not confirm within {txn_confirm_timeout}s") - - await self._async_sleep(poll_interval_secs) + poll_index = 0 + while True: + try: + response = await self._http_client.get(url, headers=headers, timeout=5.0) + except httpx.ConnectTimeout: + pass + except httpx.ReadTimeout: + pass + except httpx.ConnectError: + pass + else: + if response.is_success: + data = cast("dict[str, Any]", response.json()) + tx_type = data.get("type") + if tx_type == "pending_transaction": + pass + elif data.get("success") is True: + return data + elif data.get("success") is False: + vm_status = data.get("vm_status", "Unknown error") + raise TxnConfirmError(tx_hash, f"failed: {vm_status}") + + if time.time() - start_time > txn_confirm_timeout: + raise TxnConfirmError(tx_hash, f"did not confirm within {txn_confirm_timeout}s") + + delay = _poll_delay(poll_index) + poll_index += 1 + await self._async_sleep(delay) async def _async_sleep(self, seconds: float) -> None: import asyncio @@ -461,7 +488,10 @@ def __init__( self._node_api_key = opts.node_api_key self._gas_price_manager = opts.gas_price_manager self._time_delta_ms = opts.time_delta_ms - self._http_client = opts.http_client + self._http_client = opts.http_client or httpx.Client( + limits=HTTP_LIMITS, timeout=HTTP_TIMEOUT + ) + self._owns_http_client = opts.http_client is None if config.chain_id is None: logger.warning( @@ -493,6 +523,21 @@ def time_delta_ms(self) -> int: def time_delta_ms(self, value: int) -> None: self._time_delta_ms = value + def close(self) -> None: + if self._owns_http_client: + self._http_client.close() + + def __enter__(self) -> BaseSDKSync: + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: object, + ) -> None: + self.close() + def _get_abi(self, function_id: str) -> MoveFunction | None: return self._abi_registry.get_function(function_id) @@ -550,6 +595,7 @@ def submit_tx( self._config, transaction, sender_authenticator, + client=self._http_client, txn_submit_timeout=txn_submit_timeout, ) @@ -644,19 +690,11 @@ def _fetch_gas_price_estimation(self) -> int: url = f"{self._config.fullnode_url}/estimate_gas_price" headers = self._build_node_headers() - def make_request(client: httpx.Client) -> int: - response = client.get(url, headers=headers) - if not response.is_success: - raise ValueError( - f"Failed to fetch gas price: {response.status_code} - {response.text}" - ) - data = cast("dict[str, Any]", response.json()) - return int(data.get("gas_estimate", DEFAULT_GAS_ESTIMATE)) - - if self._http_client is not None: - return make_request(self._http_client) - with httpx.Client() as client: - return make_request(client) + response = self._http_client.get(url, headers=headers, timeout=5.0) + if not response.is_success: + raise ValueError(f"Failed to fetch gas price: {response.status_code} - {response.text}") + data = cast("dict[str, Any]", response.json()) + return int(data.get("gas_estimate", DEFAULT_GAS_ESTIMATE)) def _simulate_transaction( self, @@ -667,26 +705,20 @@ def _simulate_transaction( headers["Content-Type"] = "application/x.aptos.signed_transaction+bcs" bcs_bytes = self._serialize_for_simulation(transaction) - def make_request(client: httpx.Client) -> dict[str, Any]: - response = client.post( - url, - content=bcs_bytes, - headers=headers, - params={"estimate_max_gas_amount": "true", "estimate_gas_unit_price": "true"}, + response = self._http_client.post( + url, + content=bcs_bytes, + headers=headers, + params={"estimate_max_gas_amount": "true", "estimate_gas_unit_price": "true"}, + ) + if not response.is_success: + raise ValueError( + f"Transaction simulation failed: {response.status_code} - {response.text}" ) - if not response.is_success: - raise ValueError( - f"Transaction simulation failed: {response.status_code} - {response.text}" - ) - data: list[dict[str, Any]] | dict[str, Any] = response.json() - if isinstance(data, list) and len(data) > 0: - return data[0] - raise ValueError("Transaction simulation returned empty results") - - if self._http_client is not None: - return make_request(self._http_client) - with httpx.Client() as client: - return make_request(client) + data: list[dict[str, Any]] | dict[str, Any] = response.json() + if isinstance(data, list) and len(data) > 0: + return data[0] + raise ValueError("Transaction simulation returned empty results") def _submit_direct( self, @@ -699,35 +731,28 @@ def _submit_direct( headers["Content-Type"] = "application/x.aptos.signed_transaction+bcs" bcs_bytes = self._serialize_signed_transaction(transaction, sender_authenticator) - def make_request(client: httpx.Client) -> PendingTransactionResponse: - response = client.post( - url, content=bcs_bytes, headers=headers, timeout=txn_submit_timeout - ) - if not response.is_success: - raise ValueError( - f"Transaction submission failed: {response.status_code} - {response.text}" - ) - data = cast("dict[str, Any]", response.json()) - raw_txn = transaction.raw_transaction - return PendingTransactionResponse( - hash=str(data.get("hash", "")), - sender=str(raw_txn.sender), - sequence_number=str(raw_txn.sequence_number), - max_gas_amount=str(raw_txn.max_gas_amount), - gas_unit_price=str(raw_txn.gas_unit_price), - expiration_timestamp_secs=str(raw_txn.expiration_timestamps_secs), + response = self._http_client.post( + url, content=bcs_bytes, headers=headers, timeout=txn_submit_timeout + ) + if not response.is_success: + raise ValueError( + f"Transaction submission failed: {response.status_code} - {response.text}" ) - - if self._http_client is not None: - return make_request(self._http_client) - with httpx.Client() as client: - return make_request(client) + data = cast("dict[str, Any]", response.json()) + raw_txn = transaction.raw_transaction + return PendingTransactionResponse( + hash=str(data.get("hash", "")), + sender=str(raw_txn.sender), + sequence_number=str(raw_txn.sequence_number), + max_gas_amount=str(raw_txn.max_gas_amount), + gas_unit_price=str(raw_txn.gas_unit_price), + expiration_timestamp_secs=str(raw_txn.expiration_timestamps_secs), + ) def _wait_for_transaction( self, tx_hash: str, txn_confirm_timeout: float | None = None, # Uses DEFAULT_TXN_CONFIRM_TIMEOUT if None - poll_interval_secs: float = 1.0, ) -> dict[str, Any]: if txn_confirm_timeout is None: txn_confirm_timeout = DEFAULT_TXN_CONFIRM_TIMEOUT @@ -735,9 +760,17 @@ def _wait_for_transaction( headers = self._build_node_headers() start_time = time.time() - def poll_loop(client: httpx.Client) -> dict[str, Any]: - while True: - response = client.get(url, headers=headers) + poll_index = 0 + while True: + try: + response = self._http_client.get(url, headers=headers, timeout=5.0) + except httpx.ConnectTimeout: + pass + except httpx.ReadTimeout: + pass + except httpx.ConnectError: + pass + else: if response.is_success: data = cast("dict[str, Any]", response.json()) tx_type = data.get("type") @@ -748,14 +781,11 @@ def poll_loop(client: httpx.Client) -> dict[str, Any]: elif data.get("success") is False: vm_status = data.get("vm_status", "Unknown error") raise TxnConfirmError(tx_hash, f"failed: {vm_status}") - if time.time() - start_time > txn_confirm_timeout: - raise TxnConfirmError(tx_hash, f"did not confirm within {txn_confirm_timeout}s") - time.sleep(poll_interval_secs) - - if self._http_client is not None: - return poll_loop(self._http_client) - with httpx.Client() as client: - return poll_loop(client) + if time.time() - start_time > txn_confirm_timeout: + raise TxnConfirmError(tx_hash, f"did not confirm within {txn_confirm_timeout}s") + delay = _poll_delay(poll_index) + poll_index += 1 + time.sleep(delay) def _build_node_headers(self) -> dict[str, str]: headers: dict[str, str] = {} diff --git a/src/decibel/_constants.py b/src/decibel/_constants.py index ad7fa15..26d1693 100644 --- a/src/decibel/_constants.py +++ b/src/decibel/_constants.py @@ -3,6 +3,7 @@ from dataclasses import dataclass from enum import Enum +import httpx from aptos_sdk.account_address import AccountAddress __all__ = [ @@ -13,6 +14,8 @@ "DEFAULT_COMPAT_VERSION", "DEFAULT_TXN_CONFIRM_TIMEOUT", "DEFAULT_TXN_SUBMIT_TIMEOUT", + "HTTP_LIMITS", + "HTTP_TIMEOUT", "MAINNET_CONFIG", "NETNA_CONFIG", "TESTNET_CONFIG", @@ -32,6 +35,10 @@ # Default is 10 seconds (should be shorter than confirmation timeout) DEFAULT_TXN_SUBMIT_TIMEOUT = 10.0 +# Shared HTTP client connection pool limits +HTTP_LIMITS = httpx.Limits(max_connections=20, max_keepalive_connections=10) +HTTP_TIMEOUT = httpx.Timeout(10.0, connect=5.0) + class Network(str, Enum): MAINNET = "mainnet" diff --git a/src/decibel/_gas_price_manager.py b/src/decibel/_gas_price_manager.py index 0d6285a..144f26d 100644 --- a/src/decibel/_gas_price_manager.py +++ b/src/decibel/_gas_price_manager.py @@ -34,6 +34,8 @@ class GasPriceManagerOptions: node_api_key: str | None = None multiplier: float = 2.0 refresh_interval_seconds: float = 60.0 + http_client: httpx.AsyncClient | None = None + http_client_sync: httpx.Client | None = None def _build_auth_headers(api_key: str | None) -> dict[str, str]: @@ -56,6 +58,7 @@ def __init__( self._is_initialized = False self._refresh_interval_seconds = self._opts.refresh_interval_seconds self._multiplier = self._opts.multiplier + self._http_client = self._opts.http_client @property def gas_price(self) -> int | None: @@ -99,8 +102,11 @@ async def fetch_gas_price_estimation(self) -> int: url = f"{self._config.fullnode_url}/estimate_gas_price" headers = _build_auth_headers(self._opts.node_api_key) - async with httpx.AsyncClient() as client: - response = await client.get(url, headers=headers) + if self._http_client is not None: + response = await self._http_client.get(url, headers=headers, timeout=5.0) + else: + async with httpx.AsyncClient() as temp_client: + response = await temp_client.get(url, headers=headers, timeout=5.0) if not response.is_success: raise ValueError(f"Failed to fetch gas price: {response.status_code} - {response.text}") @@ -168,6 +174,7 @@ def __init__( self._is_initialized = False self._refresh_interval_seconds = self._opts.refresh_interval_seconds self._multiplier = self._opts.multiplier + self._http_client = self._opts.http_client_sync @property def gas_price(self) -> int | None: @@ -214,8 +221,11 @@ def fetch_gas_price_estimation(self) -> int: url = f"{self._config.fullnode_url}/estimate_gas_price" headers = _build_auth_headers(self._opts.node_api_key) - with httpx.Client() as client: - response = client.get(url, headers=headers) + if self._http_client is not None: + response = self._http_client.get(url, headers=headers, timeout=5.0) + else: + with httpx.Client() as temp_client: + response = temp_client.get(url, headers=headers, timeout=5.0) if not response.is_success: raise ValueError(f"Failed to fetch gas price: {response.status_code} - {response.text}") diff --git a/src/decibel/_order_status.py b/src/decibel/_order_status.py index 3b92ce3..f7bae4a 100644 --- a/src/decibel/_order_status.py +++ b/src/decibel/_order_status.py @@ -38,8 +38,16 @@ class OrderStatus(BaseModel): class OrderStatusClient: - def __init__(self, config: DecibelConfig) -> None: + def __init__( + self, + config: DecibelConfig, + *, + http_client: httpx.AsyncClient | None = None, + http_client_sync: httpx.Client | None = None, + ) -> None: self._config = config + self._http_client = http_client + self._http_client_sync = http_client_sync async def get_order_status( self, @@ -56,12 +64,14 @@ async def get_order_status( "account": user_address, } + effective_client = client or self._http_client + try: - if client is not None: - response = await client.get(url, params=params) + if effective_client is not None: + response = await effective_client.get(url, params=params, timeout=5.0) else: async with httpx.AsyncClient() as temp_client: - response = await temp_client.get(url, params=params) + response = await temp_client.get(url, params=params, timeout=5.0) if response.status_code == 404: return None @@ -89,12 +99,14 @@ def get_order_status_sync( "account": user_address, } + effective_client = client or self._http_client_sync + try: - if client is not None: - response = client.get(url, params=params) + if effective_client is not None: + response = effective_client.get(url, params=params, timeout=5.0) else: with httpx.Client() as temp_client: - response = temp_client.get(url, params=params) + response = temp_client.get(url, params=params, timeout=5.0) if response.status_code == 404: return None diff --git a/src/decibel/admin.py b/src/decibel/admin.py index d49de22..5e32399 100644 --- a/src/decibel/admin.py +++ b/src/decibel/admin.py @@ -2,7 +2,6 @@ from typing import TYPE_CHECKING, Any, cast -import httpx from aptos_sdk.account_address import AccountAddress from ._base import BaseSDK, BaseSDKSync @@ -850,19 +849,13 @@ def usdc_balance( ) -> int: addr_str = str(addr) if isinstance(addr, AccountAddress) else addr - def make_request(client: httpx.Client) -> int: - response = client.post( - f"{self._config.fullnode_url}/view", - json={ - "function": "0x1::primary_fungible_store::balance", - "type_arguments": ["0x1::fungible_asset::Metadata"], - "arguments": [addr_str, self._config.deployment.usdc], - }, - ) - data = cast("list[Any]", response.json()) - return int(data[0]) - - if self._http_client is not None: - return make_request(self._http_client) - with httpx.Client() as client: - return make_request(client) + response = self._http_client.post( + f"{self._config.fullnode_url}/view", + json={ + "function": "0x1::primary_fungible_store::balance", + "type_arguments": ["0x1::fungible_asset::Metadata"], + "arguments": [addr_str, self._config.deployment.usdc], + }, + ) + data = cast("list[Any]", response.json()) + return int(data[0]) diff --git a/src/decibel/read/__init__.py b/src/decibel/read/__init__.py index d9db5fa..359b152 100644 --- a/src/decibel/read/__init__.py +++ b/src/decibel/read/__init__.py @@ -2,8 +2,10 @@ from typing import TYPE_CHECKING +import httpx from aptos_sdk.async_client import RestClient +from .._constants import HTTP_LIMITS, HTTP_TIMEOUT from ._account_overview import ( AccountOverview, AccountOverviewReader, @@ -163,7 +165,14 @@ def __init__( ) -> None: aptos = RestClient(config.fullnode_url) ws = DecibelWsSubscription(config, api_key, on_ws_error) - deps = ReaderDeps(config=config, ws=ws, aptos=aptos, api_key=api_key) + self._http_client = httpx.AsyncClient(limits=HTTP_LIMITS, timeout=HTTP_TIMEOUT) + deps = ReaderDeps( + config=config, + ws=ws, + aptos=aptos, + api_key=api_key, + http_client=self._http_client, + ) self.ws = ws self.account_overview = AccountOverviewReader(deps) @@ -190,6 +199,21 @@ def __init__( self.vaults = VaultsReader(deps) self.trading_points = TradingPointsReader(deps) + async def close(self) -> None: + await self.ws.close() + await self._http_client.aclose() + + async def __aenter__(self) -> DecibelReadDex: + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: object, + ) -> None: + await self.close() + __all__ = [ "AccountOverview", diff --git a/src/decibel/read/_base.py b/src/decibel/read/_base.py index 40c5fac..6b3397d 100644 --- a/src/decibel/read/_base.py +++ b/src/decibel/read/_base.py @@ -1,6 +1,6 @@ from __future__ import annotations -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any, TypeVar from pydantic import BaseModel @@ -15,6 +15,7 @@ ) if TYPE_CHECKING: + import httpx from aptos_sdk.async_client import RestClient from .._constants import DecibelConfig @@ -34,6 +35,8 @@ class ReaderDeps: ws: DecibelWsSubscription aptos: RestClient api_key: str | None = None + http_client: httpx.AsyncClient | None = field(default=None, repr=False) + http_client_sync: httpx.Client | None = field(default=None, repr=False) class BaseReader: @@ -64,6 +67,7 @@ async def get_request( url=url, params=params, api_key=self._deps.api_key, + client=self._deps.http_client, ) async def post_request( @@ -78,6 +82,7 @@ async def post_request( url=url, body=body, api_key=self._deps.api_key, + client=self._deps.http_client, ) async def patch_request( @@ -92,6 +97,7 @@ async def patch_request( url=url, body=body, api_key=self._deps.api_key, + client=self._deps.http_client, ) def get_request_sync( @@ -106,6 +112,7 @@ def get_request_sync( url=url, params=params, api_key=self._deps.api_key, + client=self._deps.http_client_sync, ) def post_request_sync( @@ -120,6 +127,7 @@ def post_request_sync( url=url, body=body, api_key=self._deps.api_key, + client=self._deps.http_client_sync, ) def patch_request_sync( @@ -134,4 +142,5 @@ def patch_request_sync( url=url, body=body, api_key=self._deps.api_key, + client=self._deps.http_client_sync, ) diff --git a/src/decibel/write/__init__.py b/src/decibel/write/__init__.py index 894dc00..7f394fd 100644 --- a/src/decibel/write/__init__.py +++ b/src/decibel/write/__init__.py @@ -59,7 +59,7 @@ def __init__( opts: BaseSDKOptions | None = None, ) -> None: super().__init__(config, account, opts) - self._order_status_client = OrderStatusClient(config) + self._order_status_client = OrderStatusClient(config, http_client=self._http_client) @property def order_status_client(self) -> OrderStatusClient: @@ -1093,7 +1093,7 @@ def __init__( opts: BaseSDKOptionsSync | None = None, ) -> None: super().__init__(config, account, opts) - self._order_status_client = OrderStatusClient(config) + self._order_status_client = OrderStatusClient(config, http_client_sync=self._http_client) @property def order_status_client(self) -> OrderStatusClient: diff --git a/tests/conftest.py b/tests/conftest.py index 97941a6..5ca8b38 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,35 +1,109 @@ -from typing import TYPE_CHECKING, Any +from __future__ import annotations +from typing import Any +from unittest.mock import AsyncMock, MagicMock + +import httpx import pytest -if TYPE_CHECKING: - from decibel.abi import AbiRegistry +from decibel._constants import ( + CompatVersion, + DecibelConfig, + Deployment, + Network, +) + +TEST_PACKAGE = "0x" + "ab" * 32 +TEST_USDC = "0x" + "cd" * 32 +TEST_TESTC = "0x" + "ef" * 32 +TEST_PERP_ENGINE = "0x" + "12" * 32 +TEST_FULLNODE_URL = "https://test-node.example.com/v1" +TEST_TRADING_HTTP_URL = "https://test-trading.example.com" +TEST_TRADING_WS_URL = "wss://test-trading.example.com/ws" +TEST_GAS_STATION_URL = "https://test-gas.example.com" @pytest.fixture -def config() -> dict[str, Any]: - """Provide test configuration.""" - # Placeholder - will be implemented with actual Config class - return {} +def test_deployment() -> Deployment: + return Deployment( + package=TEST_PACKAGE, + usdc=TEST_USDC, + testc=TEST_TESTC, + perp_engine_global=TEST_PERP_ENGINE, + ) @pytest.fixture -def abi_registry() -> "AbiRegistry": - """Provide ABI registry for tests.""" - from decibel.abi import AbiRegistry +def test_config(test_deployment: Deployment) -> DecibelConfig: + return DecibelConfig( + network=Network.TESTNET, + fullnode_url=TEST_FULLNODE_URL, + trading_http_url=TEST_TRADING_HTTP_URL, + trading_ws_url=TEST_TRADING_WS_URL, + gas_station_url=TEST_GAS_STATION_URL, + gas_station_api_key="test-api-key", + deployment=test_deployment, + chain_id=2, + compat_version=CompatVersion.V0_4, + ) - return AbiRegistry() + +@pytest.fixture +def test_config_no_gas_key(test_deployment: Deployment) -> DecibelConfig: + return DecibelConfig( + network=Network.TESTNET, + fullnode_url=TEST_FULLNODE_URL, + trading_http_url=TEST_TRADING_HTTP_URL, + trading_ws_url=TEST_TRADING_WS_URL, + gas_station_url=TEST_GAS_STATION_URL, + gas_station_api_key=None, + deployment=test_deployment, + chain_id=2, + compat_version=CompatVersion.V0_4, + ) + + +@pytest.fixture +def mock_account() -> MagicMock: + account = MagicMock() + account.address.return_value = MagicMock() + account.address.return_value.__str__ = lambda self: "0x" + "aa" * 32 + account.private_key = MagicMock() + account.public_key.return_value = MagicMock() + return account + + +def make_httpx_response( + status_code: int = 200, + json_data: Any = None, + text: str = "", + reason_phrase: str = "OK", +) -> httpx.Response: + response = httpx.Response( + status_code=status_code, + json=json_data, + text=text if not json_data else "", + request=httpx.Request("GET", "https://test.example.com"), + ) + return response @pytest.fixture -def read_client() -> None: - """Provide read-only client for tests.""" - # Placeholder - will be implemented with DecibelRead - return None +def mock_async_client() -> AsyncMock: + client = AsyncMock(spec=httpx.AsyncClient) + client.aclose = AsyncMock() + return client @pytest.fixture -def write_client() -> None: - """Provide write client for tests.""" - # Placeholder - will be implemented with DecibelWrite - return None +def mock_sync_client() -> MagicMock: + client = MagicMock(spec=httpx.Client) + client.close = MagicMock() + return client + + +@pytest.fixture +def abi_registry() -> Any: + from decibel.abi import AbiRegistry + + return AbiRegistry() diff --git a/tests/read/__init__.py b/tests/read/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/read/test_base_reader.py b/tests/read/test_base_reader.py new file mode 100644 index 0000000..cf9f23f --- /dev/null +++ b/tests/read/test_base_reader.py @@ -0,0 +1,245 @@ +"""Tests for decibel.read._base module (ReaderDeps, BaseReader).""" + +from __future__ import annotations + +from unittest.mock import AsyncMock, MagicMock, patch + +import httpx +import pytest +from pydantic import BaseModel + +from decibel.read._base import BaseReader, ReaderDeps + +# --------------------------------------------------------------------------- +# Simple pydantic model for testing +# --------------------------------------------------------------------------- + + +class _SimpleModel(BaseModel): + value: int + + +_SIMPLE_RETURN = (_SimpleModel(value=42), 200, "OK") + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +def reader_deps(test_config: object) -> ReaderDeps: + return ReaderDeps( + config=test_config, # type: ignore[arg-type] + ws=MagicMock(), + aptos=MagicMock(), + api_key="test-api-key", + http_client=AsyncMock(spec=httpx.AsyncClient), + http_client_sync=MagicMock(spec=httpx.Client), + ) + + +@pytest.fixture +def reader(reader_deps: ReaderDeps) -> BaseReader: + return BaseReader(reader_deps) + + +# --------------------------------------------------------------------------- +# ReaderDeps +# --------------------------------------------------------------------------- + + +class TestReaderDeps: + def test_required_fields(self, test_config: object) -> None: + ws = MagicMock() + aptos = MagicMock() + deps = ReaderDeps(config=test_config, ws=ws, aptos=aptos) # type: ignore[arg-type] + assert deps.config is test_config + assert deps.ws is ws + assert deps.aptos is aptos + + def test_optional_api_key_defaults_to_none(self, test_config: object) -> None: + deps = ReaderDeps(config=test_config, ws=MagicMock(), aptos=MagicMock()) # type: ignore[arg-type] + assert deps.api_key is None + + def test_optional_http_client_defaults_to_none(self, test_config: object) -> None: + deps = ReaderDeps(config=test_config, ws=MagicMock(), aptos=MagicMock()) # type: ignore[arg-type] + assert deps.http_client is None + + def test_optional_http_client_sync_defaults_to_none(self, test_config: object) -> None: + deps = ReaderDeps(config=test_config, ws=MagicMock(), aptos=MagicMock()) # type: ignore[arg-type] + assert deps.http_client_sync is None + + def test_stores_all_optional_fields(self, test_config: object) -> None: + async_client = AsyncMock(spec=httpx.AsyncClient) + sync_client = MagicMock(spec=httpx.Client) + deps = ReaderDeps( + config=test_config, # type: ignore[arg-type] + ws=MagicMock(), + aptos=MagicMock(), + api_key="key", + http_client=async_client, + http_client_sync=sync_client, + ) + assert deps.api_key == "key" + assert deps.http_client is async_client + assert deps.http_client_sync is sync_client + + +# --------------------------------------------------------------------------- +# BaseReader properties +# --------------------------------------------------------------------------- + + +class TestBaseReaderProperties: + def test_config_property(self, reader: BaseReader, reader_deps: ReaderDeps) -> None: + assert reader.config is reader_deps.config + + def test_ws_property(self, reader: BaseReader, reader_deps: ReaderDeps) -> None: + assert reader.ws is reader_deps.ws + + def test_aptos_property(self, reader: BaseReader, reader_deps: ReaderDeps) -> None: + assert reader.aptos is reader_deps.aptos + + +# --------------------------------------------------------------------------- +# BaseReader.get_request +# --------------------------------------------------------------------------- + + +class TestBaseReaderGetRequest: + async def test_get_request_delegates_to_utility(self, reader: BaseReader) -> None: + with patch("decibel.read._base.get_request", return_value=_SIMPLE_RETURN) as mock_get: + result = await reader.get_request(_SimpleModel, "https://example.com/api") + mock_get.assert_called_once_with( + model=_SimpleModel, + url="https://example.com/api", + params=None, + api_key=reader._deps.api_key, + client=reader._deps.http_client, + ) + assert result == _SIMPLE_RETURN + + async def test_get_request_passes_params(self, reader: BaseReader) -> None: + params = {"key": "value"} + with patch("decibel.read._base.get_request", return_value=_SIMPLE_RETURN) as mock_get: + await reader.get_request(_SimpleModel, "https://example.com/api", params=params) + mock_get.assert_called_once_with( + model=_SimpleModel, + url="https://example.com/api", + params=params, + api_key=reader._deps.api_key, + client=reader._deps.http_client, + ) + + +# --------------------------------------------------------------------------- +# BaseReader.post_request +# --------------------------------------------------------------------------- + + +class TestBaseReaderPostRequest: + async def test_post_request_delegates_to_utility(self, reader: BaseReader) -> None: + with patch("decibel.read._base.post_request", return_value=_SIMPLE_RETURN) as mock_post: + body = {"data": 1} + result = await reader.post_request(_SimpleModel, "https://example.com/api", body=body) + mock_post.assert_called_once_with( + model=_SimpleModel, + url="https://example.com/api", + body=body, + api_key=reader._deps.api_key, + client=reader._deps.http_client, + ) + assert result == _SIMPLE_RETURN + + async def test_post_request_no_body(self, reader: BaseReader) -> None: + with patch("decibel.read._base.post_request", return_value=_SIMPLE_RETURN) as mock_post: + await reader.post_request(_SimpleModel, "https://example.com/api") + mock_post.assert_called_once_with( + model=_SimpleModel, + url="https://example.com/api", + body=None, + api_key=reader._deps.api_key, + client=reader._deps.http_client, + ) + + +# --------------------------------------------------------------------------- +# BaseReader.patch_request +# --------------------------------------------------------------------------- + + +class TestBaseReaderPatchRequest: + async def test_patch_request_delegates_to_utility(self, reader: BaseReader) -> None: + with patch("decibel.read._base.patch_request", return_value=_SIMPLE_RETURN) as mock_patch: + body = {"update": True} + result = await reader.patch_request(_SimpleModel, "https://example.com/api", body=body) + mock_patch.assert_called_once_with( + model=_SimpleModel, + url="https://example.com/api", + body=body, + api_key=reader._deps.api_key, + client=reader._deps.http_client, + ) + assert result == _SIMPLE_RETURN + + +# --------------------------------------------------------------------------- +# BaseReader sync variants +# --------------------------------------------------------------------------- + + +class TestBaseReaderSyncVariants: + def test_get_request_sync_delegates_to_utility(self, reader: BaseReader) -> None: + with patch("decibel.read._base.get_request_sync", return_value=_SIMPLE_RETURN) as mock_get: + result = reader.get_request_sync(_SimpleModel, "https://example.com/api") + mock_get.assert_called_once_with( + model=_SimpleModel, + url="https://example.com/api", + params=None, + api_key=reader._deps.api_key, + client=reader._deps.http_client_sync, + ) + assert result == _SIMPLE_RETURN + + def test_get_request_sync_passes_params(self, reader: BaseReader) -> None: + params = {"filter": "all"} + with patch("decibel.read._base.get_request_sync", return_value=_SIMPLE_RETURN) as mock_get: + reader.get_request_sync(_SimpleModel, "https://example.com/api", params=params) + mock_get.assert_called_once_with( + model=_SimpleModel, + url="https://example.com/api", + params=params, + api_key=reader._deps.api_key, + client=reader._deps.http_client_sync, + ) + + def test_post_request_sync_delegates_to_utility(self, reader: BaseReader) -> None: + with patch( + "decibel.read._base.post_request_sync", return_value=_SIMPLE_RETURN + ) as mock_post: + body = {"x": 1} + result = reader.post_request_sync(_SimpleModel, "https://example.com/api", body=body) + mock_post.assert_called_once_with( + model=_SimpleModel, + url="https://example.com/api", + body=body, + api_key=reader._deps.api_key, + client=reader._deps.http_client_sync, + ) + assert result == _SIMPLE_RETURN + + def test_patch_request_sync_delegates_to_utility(self, reader: BaseReader) -> None: + with patch( + "decibel.read._base.patch_request_sync", return_value=_SIMPLE_RETURN + ) as mock_patch: + body = {"y": 2} + result = reader.patch_request_sync(_SimpleModel, "https://example.com/api", body=body) + mock_patch.assert_called_once_with( + model=_SimpleModel, + url="https://example.com/api", + body=body, + api_key=reader._deps.api_key, + client=reader._deps.http_client_sync, + ) + assert result == _SIMPLE_RETURN diff --git a/tests/read/test_readers.py b/tests/read/test_readers.py new file mode 100644 index 0000000..42863f1 --- /dev/null +++ b/tests/read/test_readers.py @@ -0,0 +1,1324 @@ +"""Tests for all reader modules in decibel.read.*""" + +from __future__ import annotations + +import json +from typing import Any +from unittest.mock import AsyncMock, MagicMock, patch + +import httpx +import pytest + +from decibel.read._base import ReaderDeps + +# --------------------------------------------------------------------------- +# Shared fixtures and helpers +# --------------------------------------------------------------------------- + + +@pytest.fixture +def reader_deps(test_config: object) -> ReaderDeps: + return ReaderDeps( + config=test_config, # type: ignore[arg-type] + ws=MagicMock(), + aptos=MagicMock(), + api_key="test-key", + http_client=AsyncMock(spec=httpx.AsyncClient), + http_client_sync=MagicMock(spec=httpx.Client), + ) + + +def _mock_get(return_value: Any) -> MagicMock: + """Helper to create a mock for get_request that returns the given value.""" + mock = AsyncMock(return_value=(return_value, 200, "OK")) + return mock + + +def _mock_get_root(items: list[Any], model_cls: type) -> Any: + """Return a RootModel wrapping items.""" + return model_cls(items) + + +# --------------------------------------------------------------------------- +# AccountOverviewReader +# --------------------------------------------------------------------------- + + +class TestAccountOverviewReader: + async def test_get_by_addr_basic(self, reader_deps: ReaderDeps) -> None: + from decibel.read._account_overview import AccountOverview, AccountOverviewReader + + overview = AccountOverview( + perp_equity_balance=100.0, + unrealized_pnl=5.0, + unrealized_funding_cost=0.5, + cross_margin_ratio=0.1, + maintenance_margin=0.05, + cross_account_leverage_ratio=None, + volume=None, + net_deposits=None, + all_time_return=None, + pnl_90d=None, + sharpe_ratio=None, + max_drawdown=None, + weekly_win_rate_12w=None, + average_cash_position=None, + average_leverage=None, + cross_account_position=10.0, + total_margin=500.0, + usdc_cross_withdrawable_balance=100.0, + usdc_isolated_withdrawable_balance=0.0, + realized_pnl=None, + liquidation_fees_paid=None, + liquidation_losses=None, + ) + reader = AccountOverviewReader(reader_deps) + + with patch.object(reader, "get_request", new_callable=AsyncMock) as mock_req: + mock_req.return_value = (overview, 200, "OK") + result = await reader.get_by_addr(sub_addr="0xuser") + + assert result is overview + mock_req.assert_called_once() + call_kwargs = mock_req.call_args.kwargs + assert "account" in call_kwargs["params"] + assert call_kwargs["params"]["account"] == "0xuser" + + async def test_get_by_addr_with_volume_window(self, reader_deps: ReaderDeps) -> None: + from decibel.read._account_overview import ( + AccountOverview, + AccountOverviewReader, + VolumeWindow, + ) + + overview = AccountOverview( + perp_equity_balance=100.0, + unrealized_pnl=5.0, + unrealized_funding_cost=0.5, + cross_margin_ratio=0.1, + maintenance_margin=0.05, + cross_account_leverage_ratio=None, + volume=None, + net_deposits=None, + all_time_return=None, + pnl_90d=None, + sharpe_ratio=None, + max_drawdown=None, + weekly_win_rate_12w=None, + average_cash_position=None, + average_leverage=None, + cross_account_position=10.0, + total_margin=500.0, + usdc_cross_withdrawable_balance=100.0, + usdc_isolated_withdrawable_balance=0.0, + realized_pnl=None, + liquidation_fees_paid=None, + liquidation_losses=None, + ) + reader = AccountOverviewReader(reader_deps) + + with patch.object(reader, "get_request", new_callable=AsyncMock) as mock_req: + mock_req.return_value = (overview, 200, "OK") + await reader.get_by_addr( + sub_addr="0xuser", + volume_window=VolumeWindow.SEVEN_DAYS, + include_performance=True, + ) + + call_kwargs = mock_req.call_args.kwargs + assert call_kwargs["params"]["volume_window"] == "7d" + assert call_kwargs["params"]["include_performance"] == "true" + + def test_subscribe_by_addr(self, reader_deps: ReaderDeps) -> None: + from decibel.read._account_overview import AccountOverviewReader, AccountOverviewWsMessage + + reader = AccountOverviewReader(reader_deps) + on_data = MagicMock() + + reader_deps.ws.subscribe.return_value = MagicMock() + reader.subscribe_by_addr("0xuser", on_data) + + reader_deps.ws.subscribe.assert_called_once() + call_args = reader_deps.ws.subscribe.call_args + assert "account_overview:0xuser" in call_args[0][0] + assert call_args[0][1] is AccountOverviewWsMessage + + +# --------------------------------------------------------------------------- +# CandlesticksReader +# --------------------------------------------------------------------------- + + +class TestCandlesticksReader: + async def test_get_by_name(self, reader_deps: ReaderDeps) -> None: + from decibel.read._candlesticks import ( + Candlestick, + CandlestickInterval, + CandlesticksReader, + _CandlesticksList, + ) + + candle = Candlestick(T=2000, c=100.0, h=105.0, i="1m", l=95.0, o=98.0, t=1000, v=500.0) + candles_list = _CandlesticksList([candle]) + reader = CandlesticksReader(reader_deps) + + with patch.object(reader, "get_request", new_callable=AsyncMock) as mock_req: + mock_req.return_value = (candles_list, 200, "OK") + result = await reader.get_by_name( + "BTC-PERP", + interval=CandlestickInterval.ONE_MINUTE, + start_time=1000, + end_time=2000, + ) + + assert len(result) == 1 + assert result[0].close == 100.0 + + def test_subscribe_by_name(self, reader_deps: ReaderDeps) -> None: + from decibel.read._candlesticks import ( + CandlestickInterval, + CandlesticksReader, + CandlestickWsMessage, + ) + + reader = CandlesticksReader(reader_deps) + on_data = MagicMock() + reader_deps.ws.subscribe.return_value = MagicMock() + + reader.subscribe_by_name("BTC-PERP", CandlestickInterval.ONE_MINUTE, on_data) + + reader_deps.ws.subscribe.assert_called_once() + call_args = reader_deps.ws.subscribe.call_args + assert "1m" in call_args[0][0] + assert call_args[0][1] is CandlestickWsMessage + + +# --------------------------------------------------------------------------- +# DelegationsReader +# --------------------------------------------------------------------------- + + +class TestDelegationsReader: + async def test_get_all(self, reader_deps: ReaderDeps) -> None: + from decibel.read._delegations import Delegation, DelegationsReader, _DelegationsList + + delegation = Delegation( + delegated_account="0xdeleg", + permission_type="full", + expiration_time_s=None, + ) + delegations_list = _DelegationsList([delegation]) + reader = DelegationsReader(reader_deps) + + with patch.object(reader, "get_request", new_callable=AsyncMock) as mock_req: + mock_req.return_value = (delegations_list, 200, "OK") + result = await reader.get_all(sub_addr="0xuser") + + assert len(result) == 1 + assert result[0].delegated_account == "0xdeleg" + call_kwargs = mock_req.call_args.kwargs + assert call_kwargs["params"]["subaccount"] == "0xuser" + + +# --------------------------------------------------------------------------- +# LeaderboardReader +# --------------------------------------------------------------------------- + + +class TestLeaderboardReader: + async def test_get_leaderboard_no_params(self, reader_deps: ReaderDeps) -> None: + from decibel.read._leaderboard import LeaderboardReader, LeaderboardResponse + + response = LeaderboardResponse(items=[], total_count=0) + reader = LeaderboardReader(reader_deps) + + with patch.object(reader, "get_request", new_callable=AsyncMock) as mock_req: + mock_req.return_value = (response, 200, "OK") + result = await reader.get_leaderboard() + + assert result is response + # params=None when no params provided + call_kwargs = mock_req.call_args.kwargs + assert call_kwargs["params"] is None + + async def test_get_leaderboard_with_all_params(self, reader_deps: ReaderDeps) -> None: + from decibel.read._leaderboard import LeaderboardReader, LeaderboardResponse + + response = LeaderboardResponse(items=[], total_count=0) + reader = LeaderboardReader(reader_deps) + + with patch.object(reader, "get_request", new_callable=AsyncMock) as mock_req: + mock_req.return_value = (response, 200, "OK") + await reader.get_leaderboard( + limit=10, + offset=5, + search_term="0xuser", + sort_key="volume", + sort_dir="DESC", + ) + + call_kwargs = mock_req.call_args.kwargs + params = call_kwargs["params"] + assert params["limit"] == "10" + assert params["offset"] == "5" + assert params["search_term"] == "0xuser" + assert params["sort_key"] == "volume" + assert params["sort_dir"] == "DESC" + + +# --------------------------------------------------------------------------- +# MarketContextsReader +# --------------------------------------------------------------------------- + + +class TestMarketContextsReader: + async def test_get_all(self, reader_deps: ReaderDeps) -> None: + from decibel.read._market_contexts import ( + MarketContext, + MarketContextsReader, + _MarketContextList, + ) + + ctx = MarketContext( + market="0xmarket", + volume_24h=1000.0, + open_interest=500.0, + previous_day_price=100.0, + price_change_pct_24h=1.5, + ) + ctx_list = _MarketContextList([ctx]) + reader = MarketContextsReader(reader_deps) + + with patch.object(reader, "get_request", new_callable=AsyncMock) as mock_req: + mock_req.return_value = (ctx_list, 200, "OK") + result = await reader.get_all() + + assert len(result) == 1 + assert result[0].market == "0xmarket" + + +# --------------------------------------------------------------------------- +# MarketDepthReader +# --------------------------------------------------------------------------- + + +class TestMarketDepthReader: + async def test_get_by_name(self, reader_deps: ReaderDeps) -> None: + from decibel.read._market_depth import MarketDepth, MarketDepthReader + + depth = MarketDepth(market="0xmarket", bids=[], asks=[], unix_ms=1000) + reader = MarketDepthReader(reader_deps) + + with patch.object(reader, "get_request", new_callable=AsyncMock) as mock_req: + mock_req.return_value = (depth, 200, "OK") + result = await reader.get_by_name("BTC-PERP") + + assert result is depth + + async def test_get_by_name_with_limit(self, reader_deps: ReaderDeps) -> None: + from decibel.read._market_depth import MarketDepth, MarketDepthReader + + depth = MarketDepth(market="0xmarket", bids=[], asks=[], unix_ms=1000) + reader = MarketDepthReader(reader_deps) + + with patch.object(reader, "get_request", new_callable=AsyncMock) as mock_req: + mock_req.return_value = (depth, 200, "OK") + await reader.get_by_name("BTC-PERP", limit=5) + + call_kwargs = mock_req.call_args.kwargs + assert call_kwargs["params"]["limit"] == "5" + + def test_subscribe_by_name(self, reader_deps: ReaderDeps) -> None: + from decibel.read._market_depth import MarketDepth, MarketDepthReader + + reader = MarketDepthReader(reader_deps) + on_data = MagicMock() + reader_deps.ws.subscribe.return_value = MagicMock() + + reader.subscribe_by_name("BTC-PERP", 1, on_data) + + reader_deps.ws.subscribe.assert_called_once() + call_args = reader_deps.ws.subscribe.call_args + assert ":1" in call_args[0][0] # aggregation_size in topic + assert call_args[0][1] is MarketDepth + + def test_reset_subscription_by_name(self, reader_deps: ReaderDeps) -> None: + from decibel.read._market_depth import MarketDepthReader + + reader = MarketDepthReader(reader_deps) + reader.reset_subscription_by_name("BTC-PERP", aggregation_size=5) + reader_deps.ws.reset.assert_called_once() + + def test_get_aggregation_sizes(self, reader_deps: ReaderDeps) -> None: + from decibel.read._market_depth import MarketDepthReader + + reader = MarketDepthReader(reader_deps) + sizes = reader.get_aggregation_sizes() + assert sizes == (1, 2, 5, 10, 100, 1000) + + +# --------------------------------------------------------------------------- +# MarketPricesReader +# --------------------------------------------------------------------------- + + +class TestMarketPricesReader: + async def test_get_all(self, reader_deps: ReaderDeps) -> None: + from decibel.read._market_prices import MarketPrice, MarketPricesReader, _MarketPriceList + + price = MarketPrice( + market="0xmarket", + mark_px=100.0, + mid_px=99.9, + oracle_px=100.1, + funding_rate_bps=0.01, + is_funding_positive=True, + open_interest=5000.0, + transaction_unix_ms=1000, + ) + prices_list = _MarketPriceList([price]) + reader = MarketPricesReader(reader_deps) + + with patch.object(reader, "get_request", new_callable=AsyncMock) as mock_req: + mock_req.return_value = (prices_list, 200, "OK") + result = await reader.get_all() + + assert len(result) == 1 + + async def test_get_by_name(self, reader_deps: ReaderDeps) -> None: + from decibel.read._market_prices import MarketPrice, MarketPricesReader, _MarketPriceList + + price = MarketPrice( + market="0xmarket", + mark_px=100.0, + mid_px=99.9, + oracle_px=100.1, + funding_rate_bps=0.01, + is_funding_positive=True, + open_interest=5000.0, + transaction_unix_ms=1000, + ) + prices_list = _MarketPriceList([price]) + reader = MarketPricesReader(reader_deps) + + with patch.object(reader, "get_request", new_callable=AsyncMock) as mock_req: + mock_req.return_value = (prices_list, 200, "OK") + result = await reader.get_by_name("BTC-PERP") + + assert len(result) == 1 + + def test_subscribe_by_name(self, reader_deps: ReaderDeps) -> None: + from decibel.read._market_prices import MarketPricesReader, MarketPriceWsMessage + + reader = MarketPricesReader(reader_deps) + on_data = MagicMock() + reader_deps.ws.subscribe.return_value = MagicMock() + + reader.subscribe_by_name("BTC-PERP", on_data) + + reader_deps.ws.subscribe.assert_called_once() + call_args = reader_deps.ws.subscribe.call_args + assert "market_price:" in call_args[0][0] + assert call_args[0][1] is MarketPriceWsMessage + + def test_subscribe_by_address(self, reader_deps: ReaderDeps) -> None: + from decibel.read._market_prices import MarketPricesReader + + reader = MarketPricesReader(reader_deps) + on_data = MagicMock() + reader_deps.ws.subscribe.return_value = MagicMock() + + reader.subscribe_by_address("0xmarket", on_data) + + reader_deps.ws.subscribe.assert_called_once() + call_args = reader_deps.ws.subscribe.call_args + assert call_args[0][0] == "market_price:0xmarket" + + def test_subscribe_all(self, reader_deps: ReaderDeps) -> None: + from decibel.read._market_prices import AllMarketPricesWsMessage, MarketPricesReader + + reader = MarketPricesReader(reader_deps) + on_data = MagicMock() + reader_deps.ws.subscribe.return_value = MagicMock() + + reader.subscribe_all(on_data) + + reader_deps.ws.subscribe.assert_called_once() + call_args = reader_deps.ws.subscribe.call_args + assert call_args[0][0] == "all_market_prices" + assert call_args[0][1] is AllMarketPricesWsMessage + + +# --------------------------------------------------------------------------- +# MarketTradesReader +# --------------------------------------------------------------------------- + + +class TestMarketTradesReader: + async def test_get_by_name(self, reader_deps: ReaderDeps) -> None: + from decibel.read._market_trades import ( + MarketTrade, + MarketTradesReader, + MarketTradesResponse, + ) + + trade = MarketTrade( + account="0xaccount", + market="0xmarket", + action="OpenLong", + size=1.0, + price=100.0, + is_profit=True, + realized_pnl_amount=5.0, + is_funding_positive=True, + realized_funding_amount=0.1, + is_rebate=False, + fee_amount=0.05, + transaction_unix_ms=1000, + transaction_version=1, + ) + response = MarketTradesResponse(items=[trade], total_count=1) + reader = MarketTradesReader(reader_deps) + + with patch.object(reader, "get_request", new_callable=AsyncMock) as mock_req: + mock_req.return_value = (response, 200, "OK") + result = await reader.get_by_name("BTC-PERP") + + assert len(result) == 1 + + async def test_get_by_name_with_limit(self, reader_deps: ReaderDeps) -> None: + from decibel.read._market_trades import MarketTradesReader, MarketTradesResponse + + response = MarketTradesResponse(items=[], total_count=0) + reader = MarketTradesReader(reader_deps) + + with patch.object(reader, "get_request", new_callable=AsyncMock) as mock_req: + mock_req.return_value = (response, 200, "OK") + await reader.get_by_name("BTC-PERP", limit=5) + + call_kwargs = mock_req.call_args.kwargs + assert call_kwargs["params"]["limit"] == "5" + + def test_subscribe_by_name(self, reader_deps: ReaderDeps) -> None: + from decibel.read._market_trades import MarketTradesReader, MarketTradeWsMessage + + reader = MarketTradesReader(reader_deps) + on_data = MagicMock() + reader_deps.ws.subscribe.return_value = MagicMock() + + reader.subscribe_by_name("BTC-PERP", on_data) + + reader_deps.ws.subscribe.assert_called_once() + call_args = reader_deps.ws.subscribe.call_args + assert "trades:" in call_args[0][0] + assert call_args[0][1] is MarketTradeWsMessage + + +# --------------------------------------------------------------------------- +# MarketsReader +# --------------------------------------------------------------------------- + + +class TestMarketsReader: + async def test_get_all_deduplicates(self, reader_deps: ReaderDeps) -> None: + from decibel.read._markets import MarketsReader, PerpMarket, _PerpMarketList + + market = PerpMarket( + market_addr="0xmarket", + market_name="BTC-PERP", + sz_decimals=2, + px_decimals=2, + max_leverage=10.0, + tick_size=0.1, + min_size=0.01, + lot_size=0.01, + max_open_interest=1000.0, + mode="Open", + ) + # Duplicate market + markets_list = _PerpMarketList([market, market]) + reader = MarketsReader(reader_deps) + + with patch.object(reader, "get_request", new_callable=AsyncMock) as mock_req: + mock_req.return_value = (markets_list, 200, "OK") + result = await reader.get_all() + + assert len(result) == 1 + assert result[0].market_addr == "0xmarket" + + async def test_get_by_name_success(self, reader_deps: ReaderDeps) -> None: + from decibel.read._markets import MarketsReader + + mock_resource = { + "__variant__": "V1", + "name": "BTC-PERP", + "sz_precision": {"decimals": 2, "multiplier": "100"}, + "min_size": "1", + "lot_size": "1", + "ticker_size": "1", + "max_leverage": 10.0, + "mode": {"__variant__": "Open"}, + } + reader_deps.aptos.account_resource = AsyncMock(return_value=mock_resource) + reader = MarketsReader(reader_deps) + + result = await reader.get_by_name("BTC-PERP") + + assert result is not None + assert result.name == "BTC-PERP" + + async def test_get_by_name_returns_none_on_error(self, reader_deps: ReaderDeps) -> None: + from decibel.read._markets import MarketsReader + + reader_deps.aptos.account_resource = AsyncMock(side_effect=Exception("resource not found")) + reader = MarketsReader(reader_deps) + + result = await reader.get_by_name("NONEXISTENT-PERP") + + assert result is None + + async def test_list_market_addresses(self, reader_deps: ReaderDeps) -> None: + from decibel.read._markets import MarketsReader + + addresses = ["0xaddr1", "0xaddr2"] + raw_bytes = json.dumps([addresses]).encode("utf-8") + reader_deps.aptos.view = AsyncMock(return_value=raw_bytes) + reader = MarketsReader(reader_deps) + + result = await reader.list_market_addresses() + + assert result == ["0xaddr1", "0xaddr2"] + + async def test_market_name_by_address(self, reader_deps: ReaderDeps) -> None: + from decibel.read._markets import MarketsReader + + raw_bytes = json.dumps(["BTC-PERP"]).encode("utf-8") + reader_deps.aptos.view = AsyncMock(return_value=raw_bytes) + reader = MarketsReader(reader_deps) + + result = await reader.market_name_by_address("0xmarket") + + assert result == "BTC-PERP" + + +# --------------------------------------------------------------------------- +# PortfolioChartReader +# --------------------------------------------------------------------------- + + +class TestPortfolioChartReader: + async def test_get_by_addr(self, reader_deps: ReaderDeps) -> None: + from decibel.read._portfolio_chart import ( + PortfolioChartItem, + PortfolioChartReader, + _PortfolioChartList, + ) + + item = PortfolioChartItem(timestamp=1000, data_points=50.0) + chart_list = _PortfolioChartList([item]) + reader = PortfolioChartReader(reader_deps) + + with patch.object(reader, "get_request", new_callable=AsyncMock) as mock_req: + mock_req.return_value = (chart_list, 200, "OK") + result = await reader.get_by_addr(sub_addr="0xuser", time_range="7d", data_type="pnl") + + assert len(result) == 1 + assert result[0].data_points == 50.0 + call_kwargs = mock_req.call_args.kwargs + params = call_kwargs["params"] + assert params["account"] == "0xuser" + assert params["range"] == "7d" + assert params["data_type"] == "pnl" + + +# --------------------------------------------------------------------------- +# TradingPointsReader +# --------------------------------------------------------------------------- + + +class TestTradingPointsReader: + async def test_get_by_owner(self, reader_deps: ReaderDeps) -> None: + from decibel.read._trading_points import OwnerTradingPoints, TradingPointsReader + + points = OwnerTradingPoints(owner="0xowner", total_points=100.0, breakdown=None) + reader = TradingPointsReader(reader_deps) + + with patch.object(reader, "get_request", new_callable=AsyncMock) as mock_req: + mock_req.return_value = (points, 200, "OK") + result = await reader.get_by_owner(owner_addr="0xowner") + + assert result is points + call_kwargs = mock_req.call_args.kwargs + assert call_kwargs["params"]["owner"] == "0xowner" + + +# --------------------------------------------------------------------------- +# UserActiveTwapsReader +# --------------------------------------------------------------------------- + + +class TestUserActiveTwapsReader: + async def test_get_by_addr(self, reader_deps: ReaderDeps) -> None: + from decibel.read._user_active_twaps import ( + UserActiveTwap, + UserActiveTwapsReader, + _UserActiveTwapsList, + ) + + twap = UserActiveTwap( + market="0xmarket", + is_buy=True, + order_id="0xorder", + client_order_id="0xclient", + is_reduce_only=False, + start_unix_ms=1000, + frequency_s=60, + duration_s=3600, + orig_size=1.0, + remaining_size=0.5, + status="Activated", + transaction_unix_ms=1000, + transaction_version=1, + ) + twaps_list = _UserActiveTwapsList([twap]) + reader = UserActiveTwapsReader(reader_deps) + + with patch.object(reader, "get_request", new_callable=AsyncMock) as mock_req: + mock_req.return_value = (twaps_list, 200, "OK") + result = await reader.get_by_addr(sub_addr="0xuser") + + assert len(result) == 1 + call_kwargs = mock_req.call_args.kwargs + assert call_kwargs["params"]["account"] == "0xuser" + + def test_subscribe_by_addr(self, reader_deps: ReaderDeps) -> None: + from decibel.read._user_active_twaps import UserActiveTwapsReader, UserActiveTwapsWsMessage + + reader = UserActiveTwapsReader(reader_deps) + on_data = MagicMock() + reader_deps.ws.subscribe.return_value = MagicMock() + + reader.subscribe_by_addr("0xuser", on_data) + + reader_deps.ws.subscribe.assert_called_once() + call_args = reader_deps.ws.subscribe.call_args + assert call_args[0][0] == "user_active_twaps:0xuser" + assert call_args[0][1] is UserActiveTwapsWsMessage + + +# --------------------------------------------------------------------------- +# UserBulkOrdersReader +# --------------------------------------------------------------------------- + + +class TestUserBulkOrdersReader: + async def test_get_by_addr_no_market(self, reader_deps: ReaderDeps) -> None: + from decibel.read._user_bulk_orders import ( + UserBulkOrder, + UserBulkOrdersReader, + _UserBulkOrdersList, + ) + + order = UserBulkOrder( + market="0xmarket", + sequence_number=1, + previous_seq_num=0, + bid_prices=[100.0], + bid_sizes=[1.0], + ask_prices=[101.0], + ask_sizes=[1.0], + cancelled_bid_prices=[], + cancelled_bid_sizes=[], + cancelled_ask_prices=[], + cancelled_ask_sizes=[], + ) + orders_list = _UserBulkOrdersList([order]) + reader = UserBulkOrdersReader(reader_deps) + + with patch.object(reader, "get_request", new_callable=AsyncMock) as mock_req: + mock_req.return_value = (orders_list, 200, "OK") + result = await reader.get_by_addr(sub_addr="0xuser") + + assert len(result) == 1 + call_kwargs = mock_req.call_args.kwargs + assert call_kwargs["params"]["account"] == "0xuser" + assert call_kwargs["params"]["market"] == "all" + + async def test_get_by_addr_with_market(self, reader_deps: ReaderDeps) -> None: + from decibel.read._user_bulk_orders import UserBulkOrdersReader, _UserBulkOrdersList + + orders_list = _UserBulkOrdersList([]) + reader = UserBulkOrdersReader(reader_deps) + + with patch.object(reader, "get_request", new_callable=AsyncMock) as mock_req: + mock_req.return_value = (orders_list, 200, "OK") + await reader.get_by_addr(sub_addr="0xuser", market="0xmarket") + + call_kwargs = mock_req.call_args.kwargs + assert call_kwargs["params"]["market"] == "0xmarket" + + def test_subscribe_by_addr(self, reader_deps: ReaderDeps) -> None: + from decibel.read._user_bulk_orders import UserBulkOrdersReader, UserBulkOrderWsMessage + + reader = UserBulkOrdersReader(reader_deps) + on_data = MagicMock() + reader_deps.ws.subscribe.return_value = MagicMock() + + reader.subscribe_by_addr("0xuser", on_data) + + reader_deps.ws.subscribe.assert_called_once() + call_args = reader_deps.ws.subscribe.call_args + assert call_args[0][0] == "bulk_orders:0xuser" + assert call_args[0][1] is UserBulkOrderWsMessage + + +# --------------------------------------------------------------------------- +# UserFundHistoryReader +# --------------------------------------------------------------------------- + + +class TestUserFundHistoryReader: + async def test_get_by_addr_defaults(self, reader_deps: ReaderDeps) -> None: + from decibel.read._user_fund_history import UserFundHistoryReader, UserFundHistoryResponse + + response = UserFundHistoryResponse(funds=[], total=0) + reader = UserFundHistoryReader(reader_deps) + + with patch.object(reader, "get_request", new_callable=AsyncMock) as mock_req: + mock_req.return_value = (response, 200, "OK") + result = await reader.get_by_addr(sub_addr="0xuser") + + assert result is response + call_kwargs = mock_req.call_args.kwargs + params = call_kwargs["params"] + assert params["account"] == "0xuser" + assert params["limit"] == "200" + assert params["offset"] == "0" + + async def test_get_by_addr_custom_pagination(self, reader_deps: ReaderDeps) -> None: + from decibel.read._user_fund_history import UserFundHistoryReader, UserFundHistoryResponse + + response = UserFundHistoryResponse(funds=[], total=0) + reader = UserFundHistoryReader(reader_deps) + + with patch.object(reader, "get_request", new_callable=AsyncMock) as mock_req: + mock_req.return_value = (response, 200, "OK") + await reader.get_by_addr(sub_addr="0xuser", limit=50, offset=100) + + call_kwargs = mock_req.call_args.kwargs + params = call_kwargs["params"] + assert params["limit"] == "50" + assert params["offset"] == "100" + + +# --------------------------------------------------------------------------- +# UserFundingHistoryReader +# --------------------------------------------------------------------------- + + +class TestUserFundingHistoryReader: + async def test_get_by_addr(self, reader_deps: ReaderDeps) -> None: + from decibel.read._user_funding_history import ( + UserFundingHistoryReader, + UserFundingHistoryResponse, + ) + + response = UserFundingHistoryResponse(items=[], total_count=0) + reader = UserFundingHistoryReader(reader_deps) + + with patch.object(reader, "get_request", new_callable=AsyncMock) as mock_req: + mock_req.return_value = (response, 200, "OK") + result = await reader.get_by_addr(sub_addr="0xuser", limit=20, offset=10) + + assert result is response + call_kwargs = mock_req.call_args.kwargs + params = call_kwargs["params"] + assert params["account"] == "0xuser" + assert params["limit"] == "20" + assert params["offset"] == "10" + + +# --------------------------------------------------------------------------- +# UserNotificationsReader +# --------------------------------------------------------------------------- + + +class TestUserNotificationsReader: + def test_subscribe_by_addr(self, reader_deps: ReaderDeps) -> None: + from decibel.read._user_notifications import ( + UserNotificationsReader, + UserNotificationWsMessage, + ) + + reader = UserNotificationsReader(reader_deps) + on_data = MagicMock() + reader_deps.ws.subscribe.return_value = MagicMock() + + reader.subscribe_by_addr("0xuser", on_data) + + reader_deps.ws.subscribe.assert_called_once() + call_args = reader_deps.ws.subscribe.call_args + assert call_args[0][0] == "notifications:0xuser" + assert call_args[0][1] is UserNotificationWsMessage + + +# --------------------------------------------------------------------------- +# UserOpenOrdersReader +# --------------------------------------------------------------------------- + + +class TestUserOpenOrdersReader: + async def test_get_by_addr_defaults(self, reader_deps: ReaderDeps) -> None: + from decibel.read._user_open_orders import UserOpenOrdersReader, UserOpenOrdersResponse + + response = UserOpenOrdersResponse(items=[], total_count=0) + reader = UserOpenOrdersReader(reader_deps) + + with patch.object(reader, "get_request", new_callable=AsyncMock) as mock_req: + mock_req.return_value = (response, 200, "OK") + result = await reader.get_by_addr(sub_addr="0xuser") + + assert result is response + call_kwargs = mock_req.call_args.kwargs + params = call_kwargs["params"] + assert params["user"] == "0xuser" + assert "limit" not in params + assert "offset" not in params + + async def test_get_by_addr_with_pagination(self, reader_deps: ReaderDeps) -> None: + from decibel.read._user_open_orders import UserOpenOrdersReader, UserOpenOrdersResponse + + response = UserOpenOrdersResponse(items=[], total_count=0) + reader = UserOpenOrdersReader(reader_deps) + + with patch.object(reader, "get_request", new_callable=AsyncMock) as mock_req: + mock_req.return_value = (response, 200, "OK") + await reader.get_by_addr(sub_addr="0xuser", limit=10, offset=5) + + call_kwargs = mock_req.call_args.kwargs + params = call_kwargs["params"] + assert params["limit"] == "10" + assert params["offset"] == "5" + + def test_subscribe_by_addr(self, reader_deps: ReaderDeps) -> None: + from decibel.read._user_open_orders import UserOpenOrdersReader, UserOpenOrdersWsMessage + + reader = UserOpenOrdersReader(reader_deps) + on_data = MagicMock() + reader_deps.ws.subscribe.return_value = MagicMock() + + reader.subscribe_by_addr("0xuser", on_data) + + reader_deps.ws.subscribe.assert_called_once() + call_args = reader_deps.ws.subscribe.call_args + assert call_args[0][0] == "account_open_orders:0xuser" + assert call_args[0][1] is UserOpenOrdersWsMessage + + +# --------------------------------------------------------------------------- +# UserOrderHistoryReader +# --------------------------------------------------------------------------- + + +class TestUserOrderHistoryReader: + async def test_get_by_addr(self, reader_deps: ReaderDeps) -> None: + from decibel.read._user_order_history import UserOrderHistoryReader, UserOrders + + response = UserOrders(items=[], total_count=0) + reader = UserOrderHistoryReader(reader_deps) + + with patch.object(reader, "get_request", new_callable=AsyncMock) as mock_req: + mock_req.return_value = (response, 200, "OK") + result = await reader.get_by_addr(sub_addr="0xuser") + + assert result is response + call_kwargs = mock_req.call_args.kwargs + params = call_kwargs["params"] + assert params["user"] == "0xuser" + + async def test_get_by_addr_with_pagination(self, reader_deps: ReaderDeps) -> None: + from decibel.read._user_order_history import UserOrderHistoryReader, UserOrders + + response = UserOrders(items=[], total_count=0) + reader = UserOrderHistoryReader(reader_deps) + + with patch.object(reader, "get_request", new_callable=AsyncMock) as mock_req: + mock_req.return_value = (response, 200, "OK") + await reader.get_by_addr(sub_addr="0xuser", limit=25, offset=50) + + call_kwargs = mock_req.call_args.kwargs + params = call_kwargs["params"] + assert params["limit"] == "25" + assert params["offset"] == "50" + + def test_subscribe_by_addr(self, reader_deps: ReaderDeps) -> None: + from decibel.read._user_order_history import UserOrderHistoryReader, UserOrdersWsMessage + + reader = UserOrderHistoryReader(reader_deps) + on_data = MagicMock() + reader_deps.ws.subscribe.return_value = MagicMock() + + reader.subscribe_by_addr("0xuser", on_data) + + reader_deps.ws.subscribe.assert_called_once() + call_args = reader_deps.ws.subscribe.call_args + assert call_args[0][0] == "order_updates:0xuser" + assert call_args[0][1] is UserOrdersWsMessage + + +# --------------------------------------------------------------------------- +# UserPositionsReader +# --------------------------------------------------------------------------- + + +class TestUserPositionsReader: + async def test_get_by_addr_defaults(self, reader_deps: ReaderDeps) -> None: + from decibel.read._user_positions import UserPositionsReader, _UserPositionsList + + positions_list = _UserPositionsList([]) + reader = UserPositionsReader(reader_deps) + + with patch.object(reader, "get_request", new_callable=AsyncMock) as mock_req: + mock_req.return_value = (positions_list, 200, "OK") + result = await reader.get_by_addr(sub_addr="0xuser") + + assert result == [] + call_kwargs = mock_req.call_args.kwargs + params = call_kwargs["params"] + assert params["account"] == "0xuser" + assert params["include_deleted"] == "false" + assert params["limit"] == "10" + + async def test_get_by_addr_with_market_addr(self, reader_deps: ReaderDeps) -> None: + from decibel.read._user_positions import UserPositionsReader, _UserPositionsList + + positions_list = _UserPositionsList([]) + reader = UserPositionsReader(reader_deps) + + with patch.object(reader, "get_request", new_callable=AsyncMock) as mock_req: + mock_req.return_value = (positions_list, 200, "OK") + await reader.get_by_addr( + sub_addr="0xuser", market_addr="0xmarket", include_deleted=True + ) + + call_kwargs = mock_req.call_args.kwargs + params = call_kwargs["params"] + assert params["market_address"] == "0xmarket" + assert params["include_deleted"] == "true" + + def test_subscribe_by_addr(self, reader_deps: ReaderDeps) -> None: + from decibel.read._user_positions import UserPositionsReader, UserPositionsWsMessage + + reader = UserPositionsReader(reader_deps) + on_data = MagicMock() + reader_deps.ws.subscribe.return_value = MagicMock() + + reader.subscribe_by_addr("0xuser", on_data) + + reader_deps.ws.subscribe.assert_called_once() + call_args = reader_deps.ws.subscribe.call_args + assert call_args[0][0] == "account_positions:0xuser" + assert call_args[0][1] is UserPositionsWsMessage + + +# --------------------------------------------------------------------------- +# UserSubaccountsReader +# --------------------------------------------------------------------------- + + +class TestUserSubaccountsReader: + async def test_get_by_addr(self, reader_deps: ReaderDeps) -> None: + from decibel.read._user_subaccounts import ( + UserSubaccount, + UserSubaccountsReader, + _UserSubaccountsList, + ) + + sub = UserSubaccount( + subaccount_address="0xsub", + primary_account_address="0xprimary", + is_primary=True, + is_active=True, + custom_label=None, + ) + subs_list = _UserSubaccountsList([sub]) + reader = UserSubaccountsReader(reader_deps) + + with patch.object(reader, "get_request", new_callable=AsyncMock) as mock_req: + mock_req.return_value = (subs_list, 200, "OK") + result = await reader.get_by_addr(owner_addr="0xowner") + + assert len(result) == 1 + call_kwargs = mock_req.call_args.kwargs + assert call_kwargs["params"]["owner"] == "0xowner" + + +# --------------------------------------------------------------------------- +# UserTradeHistoryReader +# --------------------------------------------------------------------------- + + +class TestUserTradeHistoryReader: + async def test_get_by_addr(self, reader_deps: ReaderDeps) -> None: + from decibel.read._user_trade_history import UserTradeHistoryReader, UserTradesResponse + + response = UserTradesResponse(items=[], total_count=0) + reader = UserTradeHistoryReader(reader_deps) + + with patch.object(reader, "get_request", new_callable=AsyncMock) as mock_req: + mock_req.return_value = (response, 200, "OK") + result = await reader.get_by_addr(sub_addr="0xuser", limit=5, offset=10) + + assert result is response + call_kwargs = mock_req.call_args.kwargs + params = call_kwargs["params"] + assert params["account"] == "0xuser" + assert params["limit"] == "5" + assert params["offset"] == "10" + + def test_subscribe_by_addr(self, reader_deps: ReaderDeps) -> None: + from decibel.read._user_trade_history import UserTradeHistoryReader, UserTradesWsMessage + + reader = UserTradeHistoryReader(reader_deps) + on_data = MagicMock() + reader_deps.ws.subscribe.return_value = MagicMock() + + reader.subscribe_by_addr("0xuser", on_data) + + reader_deps.ws.subscribe.assert_called_once() + call_args = reader_deps.ws.subscribe.call_args + assert call_args[0][0] == "user_trades:0xuser" + assert call_args[0][1] is UserTradesWsMessage + + +# --------------------------------------------------------------------------- +# UserTwapHistoryReader +# --------------------------------------------------------------------------- + + +class TestUserTwapHistoryReader: + async def test_get_by_addr_defaults(self, reader_deps: ReaderDeps) -> None: + from decibel.read._user_twap_history import UserTwapHistoryReader, UserTwapHistoryResponse + + response = UserTwapHistoryResponse(items=[], total_count=0) + reader = UserTwapHistoryReader(reader_deps) + + with patch.object(reader, "get_request", new_callable=AsyncMock) as mock_req: + mock_req.return_value = (response, 200, "OK") + result = await reader.get_by_addr(sub_addr="0xuser") + + assert result is response + call_kwargs = mock_req.call_args.kwargs + params = call_kwargs["params"] + assert params["user"] == "0xuser" + assert params["limit"] == "100" + assert params["offset"] == "0" + + async def test_get_by_addr_custom(self, reader_deps: ReaderDeps) -> None: + from decibel.read._user_twap_history import UserTwapHistoryReader, UserTwapHistoryResponse + + response = UserTwapHistoryResponse(items=[], total_count=0) + reader = UserTwapHistoryReader(reader_deps) + + with patch.object(reader, "get_request", new_callable=AsyncMock) as mock_req: + mock_req.return_value = (response, 200, "OK") + await reader.get_by_addr(sub_addr="0xuser", limit=25, offset=50) + + call_kwargs = mock_req.call_args.kwargs + params = call_kwargs["params"] + assert params["limit"] == "25" + assert params["offset"] == "50" + + +# --------------------------------------------------------------------------- +# VaultsReader +# --------------------------------------------------------------------------- + + +class TestVaultsReader: + async def test_get_vaults_no_params(self, reader_deps: ReaderDeps) -> None: + from decibel.read._vaults import VaultsReader, VaultsResponse + + response = VaultsResponse(items=[], total_count=0, total_value_locked=0.0, total_volume=0.0) + reader = VaultsReader(reader_deps) + + with patch.object(reader, "get_request", new_callable=AsyncMock) as mock_req: + mock_req.return_value = (response, 200, "OK") + result = await reader.get_vaults() + + assert result is response + call_kwargs = mock_req.call_args.kwargs + assert call_kwargs["params"] is None + + async def test_get_vaults_with_all_params(self, reader_deps: ReaderDeps) -> None: + from decibel.read._vaults import VaultsReader, VaultsResponse + + response = VaultsResponse(items=[], total_count=0, total_value_locked=0.0, total_volume=0.0) + reader = VaultsReader(reader_deps) + + with patch.object(reader, "get_request", new_callable=AsyncMock) as mock_req: + mock_req.return_value = (response, 200, "OK") + await reader.get_vaults( + vault_type="user", + limit=10, + offset=5, + address="0xvault", + search="my vault", + ) + + call_kwargs = mock_req.call_args.kwargs + params = call_kwargs["params"] + assert params["vault_type"] == "user" + assert params["limit"] == "10" + assert params["offset"] == "5" + assert params["vault_address"] == "0xvault" + assert params["search"] == "my vault" + + async def test_get_user_owned_vaults(self, reader_deps: ReaderDeps) -> None: + from decibel.read._vaults import UserOwnedVaultsResponse, VaultsReader + + response = UserOwnedVaultsResponse(items=[], total_count=0) + reader = VaultsReader(reader_deps) + + with patch.object(reader, "get_request", new_callable=AsyncMock) as mock_req: + mock_req.return_value = (response, 200, "OK") + result = await reader.get_user_owned_vaults(owner_addr="0xowner") + + assert result is response + call_kwargs = mock_req.call_args.kwargs + params = call_kwargs["params"] + assert params["account"] == "0xowner" + + async def test_get_user_owned_vaults_with_pagination(self, reader_deps: ReaderDeps) -> None: + from decibel.read._vaults import UserOwnedVaultsResponse, VaultsReader + + response = UserOwnedVaultsResponse(items=[], total_count=0) + reader = VaultsReader(reader_deps) + + with patch.object(reader, "get_request", new_callable=AsyncMock) as mock_req: + mock_req.return_value = (response, 200, "OK") + await reader.get_user_owned_vaults(owner_addr="0xowner", limit=5, offset=10) + + call_kwargs = mock_req.call_args.kwargs + params = call_kwargs["params"] + assert params["limit"] == "5" + assert params["offset"] == "10" + + async def test_get_user_performances_on_vaults(self, reader_deps: ReaderDeps) -> None: + from decibel.read._vaults import VaultsReader, _UserPerformancesOnVaultsList + + perfs_list = _UserPerformancesOnVaultsList([]) + reader = VaultsReader(reader_deps) + + with patch.object(reader, "get_request", new_callable=AsyncMock) as mock_req: + mock_req.return_value = (perfs_list, 200, "OK") + result = await reader.get_user_performances_on_vaults(owner_addr="0xowner") + + assert result == [] + call_kwargs = mock_req.call_args.kwargs + assert call_kwargs["params"]["account"] == "0xowner" + + async def test_get_vault_share_price_normal(self, reader_deps: ReaderDeps) -> None: + from decibel.read._vaults import VaultsReader + + nav_bytes = json.dumps([1000]).encode("utf-8") + shares_bytes = json.dumps([100]).encode("utf-8") + reader_deps.aptos.view = AsyncMock(side_effect=[nav_bytes, shares_bytes]) + reader = VaultsReader(reader_deps) + + result = await reader.get_vault_share_price(vault_address="0xvault") + + assert result == pytest.approx(10.0) + + async def test_get_vault_share_price_zero_shares_returns_1( + self, reader_deps: ReaderDeps + ) -> None: + from decibel.read._vaults import VaultsReader + + nav_bytes = json.dumps([0]).encode("utf-8") + shares_bytes = json.dumps([0]).encode("utf-8") + reader_deps.aptos.view = AsyncMock(side_effect=[nav_bytes, shares_bytes]) + reader = VaultsReader(reader_deps) + + result = await reader.get_vault_share_price(vault_address="0xvault") + + assert result == 1.0 + + async def test_get_vault_share_price_on_exception_returns_1( + self, reader_deps: ReaderDeps + ) -> None: + from decibel.read._vaults import VaultsReader + + reader_deps.aptos.view = AsyncMock(side_effect=Exception("aptos error")) + reader = VaultsReader(reader_deps) + + result = await reader.get_vault_share_price(vault_address="0xvault") + + assert result == 1.0 + + +# --------------------------------------------------------------------------- +# URL construction checks (spot-check a few readers) +# --------------------------------------------------------------------------- + + +class TestUrlConstruction: + async def test_account_overview_url(self, reader_deps: ReaderDeps) -> None: + from decibel.read._account_overview import AccountOverview, AccountOverviewReader + + overview = AccountOverview( + perp_equity_balance=0.0, + unrealized_pnl=0.0, + unrealized_funding_cost=0.0, + cross_margin_ratio=0.0, + maintenance_margin=0.0, + cross_account_leverage_ratio=None, + volume=None, + net_deposits=None, + all_time_return=None, + pnl_90d=None, + sharpe_ratio=None, + max_drawdown=None, + weekly_win_rate_12w=None, + average_cash_position=None, + average_leverage=None, + cross_account_position=0.0, + total_margin=0.0, + usdc_cross_withdrawable_balance=0.0, + usdc_isolated_withdrawable_balance=0.0, + realized_pnl=None, + liquidation_fees_paid=None, + liquidation_losses=None, + ) + reader = AccountOverviewReader(reader_deps) + with patch.object(reader, "get_request", new_callable=AsyncMock) as mock_req: + mock_req.return_value = (overview, 200, "OK") + await reader.get_by_addr(sub_addr="0xuser") + call_args = mock_req.call_args + # URL is the second positional arg or 'url' kwarg + url_arg = call_args.args[1] if len(call_args.args) > 1 else call_args.kwargs.get("url", "") + assert "/api/v1/account_overviews" in url_arg + + async def test_market_contexts_url(self, reader_deps: ReaderDeps) -> None: + from decibel.read._market_contexts import MarketContextsReader, _MarketContextList + + ctx_list = _MarketContextList([]) + reader = MarketContextsReader(reader_deps) + with patch.object(reader, "get_request", new_callable=AsyncMock) as mock_req: + mock_req.return_value = (ctx_list, 200, "OK") + await reader.get_all() + call_args = mock_req.call_args + url_arg = call_args.args[1] if len(call_args.args) > 1 else call_args.kwargs.get("url", "") + assert "/api/v1/asset_contexts" in url_arg + + async def test_trading_points_url(self, reader_deps: ReaderDeps) -> None: + from decibel.read._trading_points import OwnerTradingPoints, TradingPointsReader + + pts = OwnerTradingPoints(owner="0xowner", total_points=0.0, breakdown=None) + reader = TradingPointsReader(reader_deps) + with patch.object(reader, "get_request", new_callable=AsyncMock) as mock_req: + mock_req.return_value = (pts, 200, "OK") + await reader.get_by_owner(owner_addr="0xowner") + call_args = mock_req.call_args + url_arg = call_args.args[1] if len(call_args.args) > 1 else call_args.kwargs.get("url", "") + assert "/api/v1/points/trading/account" in url_arg diff --git a/tests/read/test_ws.py b/tests/read/test_ws.py new file mode 100644 index 0000000..66010c1 --- /dev/null +++ b/tests/read/test_ws.py @@ -0,0 +1,743 @@ +"""Tests for decibel.read._ws module (DecibelWsSubscription).""" + +from __future__ import annotations + +import asyncio +import json +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from pydantic import BaseModel + +from decibel.read._ws import DecibelWsSubscription + +# --------------------------------------------------------------------------- +# Helpers / fixtures +# --------------------------------------------------------------------------- + + +class _TestMessage(BaseModel): + value: int + + +@pytest.fixture +def ws_client(test_config: object) -> DecibelWsSubscription: + return DecibelWsSubscription(config=test_config, api_key="test-key") # type: ignore[arg-type] + + +@pytest.fixture +def ws_client_no_key(test_config: object) -> DecibelWsSubscription: + return DecibelWsSubscription(config=test_config) # type: ignore[arg-type] + + +# --------------------------------------------------------------------------- +# __init__ +# --------------------------------------------------------------------------- + + +class TestDecibelWsSubscriptionInit: + def test_stores_config(self, ws_client: DecibelWsSubscription, test_config: object) -> None: + assert ws_client._config is test_config + + def test_stores_api_key(self, ws_client: DecibelWsSubscription) -> None: + assert ws_client._api_key == "test-key" + + def test_api_key_defaults_to_none(self, ws_client_no_key: DecibelWsSubscription) -> None: + assert ws_client_no_key._api_key is None + + def test_on_error_defaults_to_none(self, ws_client: DecibelWsSubscription) -> None: + assert ws_client._on_error is None + + def test_on_error_stored(self, test_config: object) -> None: + on_error = MagicMock() + ws = DecibelWsSubscription(config=test_config, on_error=on_error) # type: ignore[arg-type] + assert ws._on_error is on_error + + def test_initial_state(self, ws_client: DecibelWsSubscription) -> None: + assert ws_client._ws is None + assert ws_client._subscriptions == {} + assert ws_client._reconnect_attempts == 0 + assert not ws_client._running + assert ws_client._receive_task is None + assert ws_client._close_timer_task is None + + +# --------------------------------------------------------------------------- +# _get_subscribe_message / _get_unsubscribe_message +# --------------------------------------------------------------------------- + + +class TestMessageHelpers: + def test_subscribe_message_format(self, ws_client: DecibelWsSubscription) -> None: + msg = ws_client._get_subscribe_message("market_price:0xabc") + data = json.loads(msg) + assert data["method"] == "subscribe" + assert data["topic"] == "market_price:0xabc" + + def test_unsubscribe_message_format(self, ws_client: DecibelWsSubscription) -> None: + msg = ws_client._get_unsubscribe_message("depth:0xabc:1") + data = json.loads(msg) + assert data["method"] == "unsubscribe" + assert data["topic"] == "depth:0xabc:1" + + def test_subscribe_message_is_valid_json(self, ws_client: DecibelWsSubscription) -> None: + msg = ws_client._get_subscribe_message("any:topic") + json.loads(msg) # Should not raise + + def test_unsubscribe_message_is_valid_json(self, ws_client: DecibelWsSubscription) -> None: + msg = ws_client._get_unsubscribe_message("any:topic") + json.loads(msg) # Should not raise + + +# --------------------------------------------------------------------------- +# _parse_message +# --------------------------------------------------------------------------- + + +class TestParseMessage: + def test_valid_message_with_topic(self, ws_client: DecibelWsSubscription) -> None: + raw = json.dumps({"topic": "market_price:0xabc", "value": 42}) + result = ws_client._parse_message(raw) + assert result is not None + topic, data = result + assert topic == "market_price:0xabc" + assert data["value"] == 42 + + def test_response_message_with_success_returns_none( + self, ws_client: DecibelWsSubscription + ) -> None: + raw = json.dumps({"topic": "market_price:0xabc", "success": True}) + result = ws_client._parse_message(raw) + assert result is None + + def test_invalid_json_raises_value_error(self, ws_client: DecibelWsSubscription) -> None: + with pytest.raises(ValueError, match="failed to parse JSON"): + ws_client._parse_message("not-valid-json{") + + def test_missing_topic_raises_value_error(self, ws_client: DecibelWsSubscription) -> None: + raw = json.dumps({"method": "subscribe"}) + with pytest.raises(ValueError, match="missing topic field"): + ws_client._parse_message(raw) + + def test_topic_not_string_raises_value_error(self, ws_client: DecibelWsSubscription) -> None: + raw = json.dumps({"topic": 123, "data": "something"}) + with pytest.raises(ValueError, match="missing topic field"): + ws_client._parse_message(raw) + + def test_strips_topic_from_data(self, ws_client: DecibelWsSubscription) -> None: + raw = json.dumps({"topic": "some:topic", "payload": "hello"}) + result = ws_client._parse_message(raw) + assert result is not None + _, data = result + assert "topic" not in data + assert data["payload"] == "hello" + + def test_bigint_reviver_applied(self, ws_client: DecibelWsSubscription) -> None: + raw = json.dumps({"topic": "some:topic", "nested": {"$bigint": "9999999999999999"}}) + result = ws_client._parse_message(raw) + assert result is not None + _, data = result + # The nested dict goes through bigint_reviver and should become an int + assert data["nested"] == 9999999999999999 + + +# --------------------------------------------------------------------------- +# ready_state +# --------------------------------------------------------------------------- + + +class TestReadyState: + def test_closed_when_no_ws(self, ws_client: DecibelWsSubscription) -> None: + assert ws_client.ready_state() == 3 + + def test_open_state(self, ws_client: DecibelWsSubscription) -> None: + mock_ws = MagicMock() + mock_ws.state.name = "OPEN" + ws_client._ws = mock_ws + assert ws_client.ready_state() == 1 + + def test_closing_state(self, ws_client: DecibelWsSubscription) -> None: + mock_ws = MagicMock() + mock_ws.state.name = "CLOSING" + ws_client._ws = mock_ws + assert ws_client.ready_state() == 2 + + def test_connecting_state(self, ws_client: DecibelWsSubscription) -> None: + mock_ws = MagicMock() + mock_ws.state.name = "CONNECTING" + ws_client._ws = mock_ws + assert ws_client.ready_state() == 0 + + +# --------------------------------------------------------------------------- +# subscribe +# --------------------------------------------------------------------------- + + +class TestSubscribe: + async def test_subscribe_adds_listener_and_opens( + self, ws_client: DecibelWsSubscription + ) -> None: + on_data = MagicMock() + + with patch.object(ws_client, "_open", new_callable=AsyncMock): + unsubscribe = ws_client.subscribe("test:topic", _TestMessage, on_data) + await asyncio.sleep(0) # Let tasks run + + assert "test:topic" in ws_client._subscriptions + assert len(ws_client._subscriptions["test:topic"]) == 1 + assert callable(unsubscribe) + + async def test_subscribe_sends_subscribe_when_ws_exists( + self, ws_client: DecibelWsSubscription + ) -> None: + mock_ws = AsyncMock() + mock_ws.send = AsyncMock() + ws_client._ws = mock_ws + + on_data = MagicMock() + ws_client.subscribe("new:topic", _TestMessage, on_data) + await asyncio.sleep(0) # Let tasks run + + mock_ws.send.assert_called() + + async def test_subscribe_returns_callable_unsubscribe( + self, ws_client: DecibelWsSubscription + ) -> None: + on_data = MagicMock() + with patch.object(ws_client, "_open", new_callable=AsyncMock): + unsubscribe = ws_client.subscribe("test:topic", _TestMessage, on_data) + assert callable(unsubscribe) + + async def test_subscribe_cancels_close_timer(self, ws_client: DecibelWsSubscription) -> None: + mock_task = MagicMock() + ws_client._close_timer_task = mock_task + + on_data = MagicMock() + with patch.object(ws_client, "_open", new_callable=AsyncMock): + ws_client.subscribe("test:topic", _TestMessage, on_data) + + mock_task.cancel.assert_called_once() + assert ws_client._close_timer_task is None + + async def test_subscribe_multiple_listeners_same_topic( + self, ws_client: DecibelWsSubscription + ) -> None: + on_data1 = MagicMock() + on_data2 = MagicMock() + + with patch.object(ws_client, "_open", new_callable=AsyncMock): + ws_client.subscribe("test:topic", _TestMessage, on_data1) + ws_client.subscribe("test:topic", _TestMessage, on_data2) + + assert len(ws_client._subscriptions["test:topic"]) == 2 + + +# --------------------------------------------------------------------------- +# _unsubscribe_listener / _unsubscribe_topic +# --------------------------------------------------------------------------- + + +class TestUnsubscribeListener: + async def test_unsubscribe_removes_listener(self, ws_client: DecibelWsSubscription) -> None: + on_data = MagicMock() + with patch.object(ws_client, "_open", new_callable=AsyncMock): + unsubscribe = ws_client.subscribe("test:topic", _TestMessage, on_data) + assert len(ws_client._subscriptions["test:topic"]) == 1 + + unsubscribe() + assert "test:topic" not in ws_client._subscriptions + + async def test_unsubscribe_nonexistent_topic_is_safe( + self, ws_client: DecibelWsSubscription + ) -> None: + # Should not raise + ws_client._unsubscribe_listener("nonexistent:topic", MagicMock()) + + async def test_unsubscribe_sends_unsub_message_when_ws_open( + self, ws_client: DecibelWsSubscription + ) -> None: + mock_ws = AsyncMock() + mock_ws.send = AsyncMock() + ws_client._ws = mock_ws + + on_data = MagicMock() + with patch.object(ws_client, "_open", new_callable=AsyncMock): + unsubscribe = ws_client.subscribe("test:topic", _TestMessage, on_data) + + # Clear mock calls from subscribe + mock_ws.send.reset_mock() + + unsubscribe() + await asyncio.sleep(0) + + # Unsubscribe message should be sent + mock_ws.send.assert_called() + + async def test_unsubscribe_topic_not_in_subs_is_safe( + self, ws_client: DecibelWsSubscription + ) -> None: + # Should not raise + ws_client._unsubscribe_topic("nonexistent:topic") + + +# --------------------------------------------------------------------------- +# _open +# --------------------------------------------------------------------------- + + +class TestOpen: + async def test_open_connects_and_sets_ws(self, ws_client: DecibelWsSubscription) -> None: + mock_ws = AsyncMock() + mock_ws.send = AsyncMock() + + with patch("decibel.read._ws.connect", return_value=mock_ws) as mock_connect: + mock_connect.__aenter__ = AsyncMock(return_value=mock_ws) + + # Make connect() awaitable returning mock_ws + async def fake_connect(*args, **kwargs): # noqa: ANN202 + return mock_ws + + with patch("decibel.read._ws.connect", side_effect=fake_connect): + await ws_client._open() + + assert ws_client._ws is mock_ws + assert ws_client._reconnect_attempts == 0 + assert ws_client._running + + async def test_open_noop_when_already_connected(self, ws_client: DecibelWsSubscription) -> None: + existing_ws = AsyncMock() + ws_client._ws = existing_ws + + with patch("decibel.read._ws.connect") as mock_connect: + await ws_client._open() + mock_connect.assert_not_called() + + async def test_open_subscribes_to_existing_topics( + self, ws_client: DecibelWsSubscription + ) -> None: + mock_ws = AsyncMock() + mock_ws.send = AsyncMock() + # Pre-add a subscription topic (simulating state before connection) + ws_client._subscriptions["existing:topic"] = set() + + async def fake_connect(*args, **kwargs): # noqa: ANN202 + return mock_ws + + with patch("decibel.read._ws.connect", side_effect=fake_connect): + await ws_client._open() + + mock_ws.send.assert_called() + sent_msg = json.loads(mock_ws.send.call_args_list[0][0][0]) + assert sent_msg["topic"] == "existing:topic" + + async def test_open_handles_connection_failure(self, ws_client: DecibelWsSubscription) -> None: + on_error = MagicMock() + ws_client._on_error = on_error + + async def failing_connect(*args, **kwargs): # noqa: ANN202 + raise ConnectionError("refused") + + with patch("decibel.read._ws.connect", side_effect=failing_connect): + with patch.object(ws_client, "_schedule_reconnect", new_callable=AsyncMock): + await ws_client._open() + + on_error.assert_called_once() + assert ws_client._ws is None + + +# --------------------------------------------------------------------------- +# _schedule_reconnect (exponential backoff) +# --------------------------------------------------------------------------- + + +class TestScheduleReconnect: + async def test_no_reconnect_when_no_subscriptions( + self, ws_client: DecibelWsSubscription + ) -> None: + with patch("decibel.read._ws.asyncio.sleep", new_callable=AsyncMock) as mock_sleep: + with patch.object(ws_client, "_open", new_callable=AsyncMock) as mock_open: + await ws_client._schedule_reconnect() + mock_sleep.assert_not_called() + mock_open.assert_not_called() + + async def test_reconnect_with_subscriptions_uses_backoff( + self, ws_client: DecibelWsSubscription + ) -> None: + ws_client._subscriptions["some:topic"] = set() + ws_client._reconnect_attempts = 0 + + with patch("decibel.read._ws.asyncio.sleep", new_callable=AsyncMock) as mock_sleep: + with patch.object(ws_client, "_open", new_callable=AsyncMock): + await ws_client._schedule_reconnect() + + # First attempt: delay = 1.5^0 = 1.0 + mock_sleep.assert_called_once_with(1.0) + + async def test_reconnect_increments_attempts(self, ws_client: DecibelWsSubscription) -> None: + ws_client._subscriptions["some:topic"] = set() + ws_client._reconnect_attempts = 2 + + with patch("decibel.read._ws.asyncio.sleep", new_callable=AsyncMock): + with patch.object(ws_client, "_open", new_callable=AsyncMock): + await ws_client._schedule_reconnect() + + assert ws_client._reconnect_attempts == 3 + + async def test_reconnect_delay_capped_at_60(self, ws_client: DecibelWsSubscription) -> None: + ws_client._subscriptions["some:topic"] = set() + ws_client._reconnect_attempts = 100 # Very high attempt count + + with patch("decibel.read._ws.asyncio.sleep", new_callable=AsyncMock) as mock_sleep: + with patch.object(ws_client, "_open", new_callable=AsyncMock): + await ws_client._schedule_reconnect() + + call_arg = mock_sleep.call_args[0][0] + assert call_arg == 60.0 + + async def test_exponential_backoff_formula(self, ws_client: DecibelWsSubscription) -> None: + ws_client._subscriptions["some:topic"] = set() + + for attempt in range(5): + ws_client._reconnect_attempts = attempt + expected = min(1.5**attempt, 60.0) + + with patch("decibel.read._ws.asyncio.sleep", new_callable=AsyncMock) as mock_sleep: + with patch.object(ws_client, "_open", new_callable=AsyncMock): + await ws_client._schedule_reconnect() + + actual = mock_sleep.call_args[0][0] + assert abs(actual - expected) < 0.0001, f"attempt={attempt}" + ws_client._reconnect_attempts = attempt # reset for next iteration + + +# --------------------------------------------------------------------------- +# close +# --------------------------------------------------------------------------- + + +class TestClose: + async def test_close_clears_subscriptions(self, ws_client: DecibelWsSubscription) -> None: + ws_client._subscriptions["topic:a"] = set() + ws_client._subscriptions["topic:b"] = set() + + await ws_client.close() + + assert ws_client._subscriptions == {} + + async def test_close_cancels_close_timer(self, ws_client: DecibelWsSubscription) -> None: + mock_task = MagicMock() + ws_client._close_timer_task = mock_task + + await ws_client.close() + + mock_task.cancel.assert_called_once() + assert ws_client._close_timer_task is None + + async def test_close_cancels_receive_task(self, ws_client: DecibelWsSubscription) -> None: + # Create a real asyncio task that sleeps forever so we can cancel it + async def _forever() -> None: + await asyncio.sleep(9999) + + real_task = asyncio.create_task(_forever()) + ws_client._receive_task = real_task + + await ws_client.close() + + assert real_task.cancelled() + assert ws_client._receive_task is None + + async def test_close_closes_ws(self, ws_client: DecibelWsSubscription) -> None: + mock_ws = AsyncMock() + mock_ws.close = AsyncMock() + ws_client._ws = mock_ws + + await ws_client.close() + + mock_ws.close.assert_called_once() + assert ws_client._ws is None + + async def test_close_with_no_ws_is_safe(self, ws_client: DecibelWsSubscription) -> None: + # Should not raise + await ws_client.close() + + +# --------------------------------------------------------------------------- +# reset +# --------------------------------------------------------------------------- + + +class TestReset: + async def test_reset_sends_unsub_then_sub(self, ws_client: DecibelWsSubscription) -> None: + mock_ws = AsyncMock() + mock_ws.send = AsyncMock() + ws_client._ws = mock_ws + ws_client._subscriptions["test:topic"] = {MagicMock()} + + ws_client.reset("test:topic") + await asyncio.sleep(0) + + # Both unsub and sub messages should have been sent + assert mock_ws.send.call_count == 2 + first_msg = json.loads(mock_ws.send.call_args_list[0][0][0]) + second_msg = json.loads(mock_ws.send.call_args_list[1][0][0]) + assert first_msg["method"] == "unsubscribe" + assert second_msg["method"] == "subscribe" + assert first_msg["topic"] == "test:topic" + assert second_msg["topic"] == "test:topic" + + def test_reset_noop_when_topic_not_subscribed(self, ws_client: DecibelWsSubscription) -> None: + mock_ws = AsyncMock() + ws_client._ws = mock_ws + + ws_client.reset("nonexistent:topic") + # No tasks created - no assert needed, just confirm no error + + def test_reset_noop_when_no_ws(self, ws_client: DecibelWsSubscription) -> None: + ws_client._subscriptions["test:topic"] = {MagicMock()} + ws_client._ws = None + + # Should not raise and should not do anything + ws_client.reset("test:topic") + + +# --------------------------------------------------------------------------- +# Listener invocation and error handling +# --------------------------------------------------------------------------- + + +class TestListenerInvocation: + async def test_subscribe_listener_parses_model(self, ws_client: DecibelWsSubscription) -> None: + received: list[_TestMessage] = [] + + def on_data(msg: _TestMessage) -> None: + received.append(msg) + + with patch.object(ws_client, "_open", new_callable=AsyncMock): + ws_client.subscribe("test:topic", _TestMessage, on_data) + + # Directly call the internal listener with raw data + listeners = list(ws_client._subscriptions["test:topic"]) + listener = listeners[0] + listener({"value": 99}) + assert len(received) == 1 + assert received[0].value == 99 + + async def test_subscribe_listener_raises_on_invalid_data( + self, ws_client: DecibelWsSubscription + ) -> None: + def on_data(msg: _TestMessage) -> None: + pass + + with patch.object(ws_client, "_open", new_callable=AsyncMock): + ws_client.subscribe("test:topic", _TestMessage, on_data) + + listeners = list(ws_client._subscriptions["test:topic"]) + listener = listeners[0] + + with pytest.raises(ValueError, match="Validation error"): + listener({"not_value": "bad"}) + + +# --------------------------------------------------------------------------- +# _receive_loop +# --------------------------------------------------------------------------- + + +class TestReceiveLoop: + async def test_receive_loop_exits_when_ws_is_none( + self, ws_client: DecibelWsSubscription + ) -> None: + ws_client._ws = None + # Should return immediately without error + await ws_client._receive_loop() + + async def test_receive_loop_dispatches_to_listener( + self, ws_client: DecibelWsSubscription + ) -> None: + """Test that _receive_loop dispatches valid messages to listeners.""" + received: list[dict] = [] + + async def fake_aiter(msg_list: list[str]): # noqa: ANN202 + for m in msg_list: + yield m + + messages = [json.dumps({"topic": "test:topic", "value": 99})] + mock_ws = MagicMock() + mock_ws.__aiter__ = MagicMock(return_value=fake_aiter(messages).__aiter__()) + + ws_client._ws = mock_ws + ws_client._subscriptions["test:topic"] = {lambda d: received.append(d)} # type: ignore[arg-type] + + with patch.object(ws_client, "_schedule_reconnect", new_callable=AsyncMock): + await ws_client._receive_loop() + + assert len(received) == 1 + + async def test_receive_loop_ignores_response_messages( + self, ws_client: DecibelWsSubscription + ) -> None: + """Subscribe/unsubscribe confirmation messages (success field) are ignored.""" + + async def fake_aiter(msg_list: list[str]): # noqa: ANN202 + for m in msg_list: + yield m + + messages = [json.dumps({"topic": "test:topic", "success": True})] + mock_ws = MagicMock() + mock_ws.__aiter__ = MagicMock(return_value=fake_aiter(messages).__aiter__()) + ws_client._ws = mock_ws + + received: list[dict] = [] + ws_client._subscriptions["test:topic"] = {lambda d: received.append(d)} # type: ignore[arg-type] + + with patch.object(ws_client, "_schedule_reconnect", new_callable=AsyncMock): + await ws_client._receive_loop() + + assert len(received) == 0 + + async def test_receive_loop_handles_listener_exception( + self, ws_client: DecibelWsSubscription + ) -> None: + """Listener exceptions are caught and logged without crashing the loop.""" + + async def fake_aiter(msg_list: list[str]): # noqa: ANN202 + for m in msg_list: + yield m + + messages = [json.dumps({"topic": "test:topic", "value": 1})] + mock_ws = MagicMock() + mock_ws.__aiter__ = MagicMock(return_value=fake_aiter(messages).__aiter__()) + ws_client._ws = mock_ws + + def bad_listener(d: dict) -> None: + raise RuntimeError("listener error") + + ws_client._subscriptions["test:topic"] = {bad_listener} + + # Should not raise + with patch.object(ws_client, "_schedule_reconnect", new_callable=AsyncMock): + await ws_client._receive_loop() + + async def test_receive_loop_handles_connection_closed( + self, ws_client: DecibelWsSubscription + ) -> None: + """ConnectionClosed is silently swallowed.""" + from websockets import ConnectionClosed as WsConnectionClosed + + async def fake_aiter(): # noqa: ANN202 + raise WsConnectionClosed(None, None) # type: ignore[arg-type] + yield # make it an async generator + + mock_ws = MagicMock() + mock_ws.__aiter__ = MagicMock(return_value=fake_aiter().__aiter__()) + ws_client._ws = mock_ws + + with patch.object(ws_client, "_schedule_reconnect", new_callable=AsyncMock): + await ws_client._receive_loop() # Should not raise + + async def test_receive_loop_calls_on_error_on_exception( + self, ws_client: DecibelWsSubscription + ) -> None: + on_error = MagicMock() + ws_client._on_error = on_error + + async def fake_aiter(): # noqa: ANN202 + raise RuntimeError("connection error") + yield + + mock_ws = MagicMock() + mock_ws.__aiter__ = MagicMock(return_value=fake_aiter().__aiter__()) + ws_client._ws = mock_ws + + with patch.object(ws_client, "_schedule_reconnect", new_callable=AsyncMock): + await ws_client._receive_loop() + + on_error.assert_called_once() + + async def test_receive_loop_handles_coroutine_listener( + self, ws_client: DecibelWsSubscription + ) -> None: + """Async listener coroutines are awaited.""" + received: list[dict] = [] + + async def async_listener(d: dict) -> None: + received.append(d) + + async def fake_aiter(msg_list: list[str]): # noqa: ANN202 + for m in msg_list: + yield m + + messages = [json.dumps({"topic": "test:topic", "value": 1})] + mock_ws = MagicMock() + mock_ws.__aiter__ = MagicMock(return_value=fake_aiter(messages).__aiter__()) + ws_client._ws = mock_ws + ws_client._subscriptions["test:topic"] = {async_listener} + + with patch.object(ws_client, "_schedule_reconnect", new_callable=AsyncMock): + await ws_client._receive_loop() + + assert len(received) == 1 + + async def test_receive_loop_reconnects_when_subscriptions_remain( + self, ws_client: DecibelWsSubscription + ) -> None: + """After loop ends, if subs remain, _schedule_reconnect is called.""" + + async def fake_aiter(): # noqa: ANN202 + return + yield # make it an async generator + + mock_ws = MagicMock() + mock_ws.__aiter__ = MagicMock(return_value=fake_aiter().__aiter__()) + ws_client._ws = mock_ws + ws_client._subscriptions["some:topic"] = set() + + with patch.object( + ws_client, "_schedule_reconnect", new_callable=AsyncMock + ) as mock_reconnect: + await ws_client._receive_loop() + + mock_reconnect.assert_called_once() + + +# --------------------------------------------------------------------------- +# _delayed_close +# --------------------------------------------------------------------------- + + +class TestDelayedClose: + async def test_delayed_close_closes_ws_when_no_subs( + self, ws_client: DecibelWsSubscription + ) -> None: + mock_ws = AsyncMock() + mock_ws.close = AsyncMock() + ws_client._ws = mock_ws + + with patch("decibel.read._ws.asyncio.sleep", new_callable=AsyncMock): + await ws_client._delayed_close() + + mock_ws.close.assert_called_once() + assert ws_client._ws is None + + async def test_delayed_close_skips_close_when_subs_remain( + self, ws_client: DecibelWsSubscription + ) -> None: + mock_ws = AsyncMock() + mock_ws.close = AsyncMock() + ws_client._ws = mock_ws + # Add a subscription so close is skipped + ws_client._subscriptions["active:topic"] = {MagicMock()} + + with patch("decibel.read._ws.asyncio.sleep", new_callable=AsyncMock): + await ws_client._delayed_close() + + mock_ws.close.assert_not_called() + + async def test_delayed_close_skips_when_ws_already_none( + self, ws_client: DecibelWsSubscription + ) -> None: + ws_client._ws = None + # Should not raise + with patch("decibel.read._ws.asyncio.sleep", new_callable=AsyncMock): + await ws_client._delayed_close() diff --git a/tests/test_admin.py b/tests/test_admin.py new file mode 100644 index 0000000..d109d92 --- /dev/null +++ b/tests/test_admin.py @@ -0,0 +1,1103 @@ +""" +Comprehensive unit tests for src/decibel/admin.py. + +Tests cover DecibelAdminDex (async) and DecibelAdminDexSync (sync) classes. + +Strategy: mock _send_tx at the instance level so no real HTTP calls or +blockchain interactions happen. The tests verify that: + 1. The correct Move function name is assembled from the package address. + 2. The correct arguments are passed to InputEntryFunctionData. + 3. usdc_balance works for both AccountAddress and str inputs. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from aptos_sdk.account_address import AccountAddress + +from decibel.admin import DecibelAdminDex, DecibelAdminDexSync + +if TYPE_CHECKING: + from decibel._transaction_builder import InputEntryFunctionData + +# --------------------------------------------------------------------------- +# Constants +# --------------------------------------------------------------------------- +TEST_PACKAGE = "0x" + "ab" * 32 +TEST_USDC = "0x" + "cd" * 32 +TEST_PERP_ENGINE = "0x" + "12" * 32 +TEST_ACCOUNT_ADDR = "0x" + "aa" * 32 +TEST_TX_HASH = "0xdeadbeef" +TEST_MARKET_NAME = "ETH-USD" +TEST_MARKET_ADDR = "0x" + "11" * 32 +TEST_VAULT_ADDR = "0x" + "cc" * 32 +TEST_COLLATERAL_ADDR = "0x" + "dd" * 32 +TEST_BACKSTOP_ADDR = "0x" + "ee" * 32 +TEST_DELEGATE_ADDR = "0x" + "ff" * 32 + + +def _make_tx_response(hash_val: str = TEST_TX_HASH) -> dict[str, Any]: + return {"hash": hash_val, "success": True, "events": []} + + +# --------------------------------------------------------------------------- +# Fixtures – async (DecibelAdminDex) +# --------------------------------------------------------------------------- + + +@pytest.fixture +def admin_dex(test_config, mock_account) -> DecibelAdminDex: + """Return a DecibelAdminDex instance with _send_tx mocked out.""" + with patch("decibel.admin.BaseSDK.__init__", return_value=None): + dex = DecibelAdminDex.__new__(DecibelAdminDex) + dex._config = test_config + dex._account = mock_account + dex._http_client = AsyncMock() + dex._skip_simulate = False + dex._no_fee_payer = False + dex._node_api_key = None + dex._gas_price_manager = None + dex._time_delta_ms = 0 + dex._chain_id = 2 + dex._abi_registry = MagicMock() + dex._send_tx = AsyncMock(return_value=_make_tx_response()) + dex._aptos = AsyncMock() + return dex + + +# --------------------------------------------------------------------------- +# Fixtures – sync (DecibelAdminDexSync) +# --------------------------------------------------------------------------- + + +@pytest.fixture +def admin_dex_sync(test_config, mock_account) -> DecibelAdminDexSync: + """Return a DecibelAdminDexSync instance with _send_tx mocked out.""" + with patch("decibel.admin.BaseSDKSync.__init__", return_value=None): + dex = DecibelAdminDexSync.__new__(DecibelAdminDexSync) + dex._config = test_config + dex._account = mock_account + dex._http_client = MagicMock() + dex._skip_simulate = False + dex._no_fee_payer = False + dex._node_api_key = None + dex._gas_price_manager = None + dex._time_delta_ms = 0 + dex._chain_id = 2 + dex._abi_registry = MagicMock() + dex._send_tx = MagicMock(return_value=_make_tx_response()) + dex._owns_http_client = False + return dex + + +# =========================================================================== +# Tests for DecibelAdminDex.__init__ +# =========================================================================== + + +class TestDecibelAdminDexInit: + def test_init_calls_super(self, test_config, mock_account) -> None: + with patch("decibel.admin.BaseSDK.__init__") as mock_super: + mock_super.return_value = None + dex = DecibelAdminDex.__new__(DecibelAdminDex) + DecibelAdminDex.__init__(dex, test_config, mock_account) + mock_super.assert_called_once_with(test_config, mock_account, None) + + +# =========================================================================== +# Tests for get_protocol_vault_address (async) +# =========================================================================== + + +class TestGetProtocolVaultAddress: + def test_returns_account_address(self, admin_dex: DecibelAdminDex) -> None: + result = admin_dex.get_protocol_vault_address() + assert isinstance(result, AccountAddress) + + def test_is_deterministic(self, admin_dex: DecibelAdminDex) -> None: + result1 = admin_dex.get_protocol_vault_address() + result2 = admin_dex.get_protocol_vault_address() + assert str(result1) == str(result2) + + +# =========================================================================== +# Tests for initialize (async) +# =========================================================================== + + +class TestInitialize: + async def test_sends_correct_function(self, admin_dex: DecibelAdminDex) -> None: + result = await admin_dex.initialize( + collateral_token_addr=TEST_COLLATERAL_ADDR, + backstop_liquidator_addr=TEST_BACKSTOP_ADDR, + ) + + admin_dex._send_tx.assert_awaited_once() + payload: InputEntryFunctionData = admin_dex._send_tx.call_args.args[0] + assert payload.function == f"{TEST_PACKAGE}::admin_apis::initialize" + assert payload.type_arguments == [] + assert payload.function_arguments == [TEST_COLLATERAL_ADDR, TEST_BACKSTOP_ADDR] + assert result == _make_tx_response() + + async def test_returns_tx_response(self, admin_dex: DecibelAdminDex) -> None: + custom_response = {"hash": "0xcafe", "success": True} + admin_dex._send_tx.return_value = custom_response + + result = await admin_dex.initialize( + collateral_token_addr=TEST_COLLATERAL_ADDR, + backstop_liquidator_addr=TEST_BACKSTOP_ADDR, + ) + assert result == custom_response + + +# =========================================================================== +# Tests for initialize_protocol_vault (async) +# =========================================================================== + + +class TestInitializeProtocolVault: + async def test_sends_create_and_fund_vault(self, admin_dex: DecibelAdminDex) -> None: + with patch( + "decibel.admin.BaseSDK.get_primary_subaccount_address", + return_value=TEST_ACCOUNT_ADDR, + ): + await admin_dex.initialize_protocol_vault( + collateral_token_addr=TEST_COLLATERAL_ADDR, + initial_funding=500_000, + ) + + payload: InputEntryFunctionData = admin_dex._send_tx.call_args.args[0] + assert payload.function == f"{TEST_PACKAGE}::vault_api::create_and_fund_vault" + + async def test_sets_correct_vault_name(self, admin_dex: DecibelAdminDex) -> None: + with patch( + "decibel.admin.BaseSDK.get_primary_subaccount_address", + return_value=TEST_ACCOUNT_ADDR, + ): + await admin_dex.initialize_protocol_vault( + collateral_token_addr=TEST_COLLATERAL_ADDR, + initial_funding=0, + ) + + payload: InputEntryFunctionData = admin_dex._send_tx.call_args.args[0] + args = payload.function_arguments + # args[0] = subaccount_addr, args[1] = collateral, args[2] = vault_name + assert args[2] == "Decibel Protocol Vault" + assert args[5] == "DPV" # vault_share_symbol + assert args[12] is True # accepts_contributions + assert args[13] is False # delegate_to_creator + + async def test_passes_initial_funding(self, admin_dex: DecibelAdminDex) -> None: + with patch( + "decibel.admin.BaseSDK.get_primary_subaccount_address", + return_value=TEST_ACCOUNT_ADDR, + ): + await admin_dex.initialize_protocol_vault( + collateral_token_addr=TEST_COLLATERAL_ADDR, + initial_funding=999_999, + ) + + payload: InputEntryFunctionData = admin_dex._send_tx.call_args.args[0] + assert payload.function_arguments[11] == 999_999 + + +# =========================================================================== +# Tests for delegate_protocol_vault_trading_to (async) +# =========================================================================== + + +class TestDelegateProtocolVaultTradingTo: + async def test_sends_correct_function(self, admin_dex: DecibelAdminDex) -> None: + await admin_dex.delegate_protocol_vault_trading_to( + vault_address=TEST_VAULT_ADDR, + account_to_delegate_to=TEST_DELEGATE_ADDR, + ) + + payload: InputEntryFunctionData = admin_dex._send_tx.call_args.args[0] + assert payload.function == f"{TEST_PACKAGE}::vault_admin_api::delegate_dex_actions_to" + assert payload.function_arguments == [TEST_VAULT_ADDR, TEST_DELEGATE_ADDR, None] + + +# =========================================================================== +# Tests for update_vault_use_global_redemption_slippage_adjustment (async) +# =========================================================================== + + +class TestUpdateVaultRedemptionSlippage: + async def test_sends_correct_function_true(self, admin_dex: DecibelAdminDex) -> None: + await admin_dex.update_vault_use_global_redemption_slippage_adjustment( + vault_address=TEST_VAULT_ADDR, + use_global_redemption_slippage_adjustment=True, + ) + + payload: InputEntryFunctionData = admin_dex._send_tx.call_args.args[0] + assert ( + payload.function == f"{TEST_PACKAGE}::vault_admin_api" + "::update_vault_use_global_redemption_slippage_adjustment" + ) + assert payload.function_arguments == [TEST_VAULT_ADDR, True] + + async def test_sends_correct_function_false(self, admin_dex: DecibelAdminDex) -> None: + await admin_dex.update_vault_use_global_redemption_slippage_adjustment( + vault_address=TEST_VAULT_ADDR, + use_global_redemption_slippage_adjustment=False, + ) + + payload: InputEntryFunctionData = admin_dex._send_tx.call_args.args[0] + assert payload.function_arguments == [TEST_VAULT_ADDR, False] + + +# =========================================================================== +# Tests for authorize_oracle_and_mark_update (async) +# =========================================================================== + + +class TestAuthorizeOracleAndMarkUpdate: + async def test_sends_correct_function(self, admin_dex: DecibelAdminDex) -> None: + oracle_updater = "0x" + "11" * 32 + await admin_dex.authorize_oracle_and_mark_update(oracle_updater) + + payload: InputEntryFunctionData = admin_dex._send_tx.call_args.args[0] + assert ( + payload.function == f"{TEST_PACKAGE}::admin_apis::add_oracle_and_mark_update_permission" + ) + assert payload.function_arguments == [oracle_updater] + + +# =========================================================================== +# Tests for add_access_control_admin (async) +# =========================================================================== + + +class TestAddAccessControlAdmin: + async def test_sends_correct_function(self, admin_dex: DecibelAdminDex) -> None: + delegated = "0x" + "11" * 32 + await admin_dex.add_access_control_admin(delegated) + + payload: InputEntryFunctionData = admin_dex._send_tx.call_args.args[0] + assert payload.function == f"{TEST_PACKAGE}::admin_apis::add_access_control_admin" + assert payload.function_arguments == [delegated] + + +# =========================================================================== +# Tests for add_market_list_admin (async) +# =========================================================================== + + +class TestAddMarketListAdmin: + async def test_sends_correct_function(self, admin_dex: DecibelAdminDex) -> None: + delegated = "0x" + "22" * 32 + await admin_dex.add_market_list_admin(delegated) + + payload: InputEntryFunctionData = admin_dex._send_tx.call_args.args[0] + assert payload.function == f"{TEST_PACKAGE}::admin_apis::add_market_list_admin" + assert payload.function_arguments == [delegated] + + +# =========================================================================== +# Tests for add_market_risk_governor (async) +# =========================================================================== + + +class TestAddMarketRiskGovernor: + async def test_sends_correct_function(self, admin_dex: DecibelAdminDex) -> None: + delegated = "0x" + "33" * 32 + await admin_dex.add_market_risk_governor(delegated) + + payload: InputEntryFunctionData = admin_dex._send_tx.call_args.args[0] + assert payload.function == f"{TEST_PACKAGE}::admin_apis::add_market_risk_governor" + assert payload.function_arguments == [delegated] + + +# =========================================================================== +# Tests for register_market_with_internal_oracle (async) +# =========================================================================== + + +class TestRegisterMarketWithInternalOracle: + async def test_sends_correct_function(self, admin_dex: DecibelAdminDex) -> None: + await admin_dex.register_market_with_internal_oracle( + name="BTC-USD", + sz_decimals=3, + min_size=1, + lot_size=1, + ticker_size=1, + max_open_interest=1_000_000, + max_leverage=20, + margin_call_fee_pct=500, + ) + + payload: InputEntryFunctionData = admin_dex._send_tx.call_args.args[0] + assert ( + payload.function == f"{TEST_PACKAGE}::admin_apis::register_market_with_internal_oracle" + ) + args = payload.function_arguments + assert args[0] == "BTC-USD" + assert args[1] == 3 + assert args[2] == 1 # min_size + assert args[6] == 20 # max_leverage + assert args[7] == 500 # margin_call_fee_pct + assert args[8] is True # taker_in_next_block default + assert args[9] == 1 # initial_oracle_price default + assert args[10] == 60 # max_staleness_secs default + + async def test_custom_defaults_overridden(self, admin_dex: DecibelAdminDex) -> None: + await admin_dex.register_market_with_internal_oracle( + name="ETH-USD", + sz_decimals=4, + min_size=10, + lot_size=5, + ticker_size=2, + max_open_interest=5_000_000, + max_leverage=50, + margin_call_fee_pct=200, + taker_in_next_block=False, + initial_oracle_price=3000, + max_staleness_secs=120, + ) + + payload: InputEntryFunctionData = admin_dex._send_tx.call_args.args[0] + args = payload.function_arguments + assert args[8] is False + assert args[9] == 3000 + assert args[10] == 120 + + +# =========================================================================== +# Tests for register_market_with_pyth_oracle (async) +# =========================================================================== + + +class TestRegisterMarketWithPythOracle: + async def test_sends_correct_function(self, admin_dex: DecibelAdminDex) -> None: + pyth_bytes = list(range(32)) + await admin_dex.register_market_with_pyth_oracle( + name="SOL-USD", + sz_decimals=2, + min_size=1, + lot_size=1, + ticker_size=1, + max_open_interest=1_000_000, + max_leverage=10, + margin_call_fee_pct=500, + pyth_identifier_bytes=pyth_bytes, + pyth_max_staleness_secs=30, + pyth_confidence_interval_threshold=100, + pyth_decimals=8, + ) + + payload: InputEntryFunctionData = admin_dex._send_tx.call_args.args[0] + assert payload.function == f"{TEST_PACKAGE}::admin_apis::register_market_with_pyth_oracle" + args = payload.function_arguments + assert args[0] == "SOL-USD" + assert args[8] is True # taker_in_next_block default + assert args[9] == pyth_bytes + assert args[10] == 30 + assert args[11] == 100 + assert args[12] == 8 + + +# =========================================================================== +# Tests for register_market_with_composite_oracle_primary_pyth (async) +# =========================================================================== + + +class TestRegisterMarketWithCompositePyth: + async def test_sends_correct_function(self, admin_dex: DecibelAdminDex) -> None: + pyth_bytes = list(range(32)) + await admin_dex.register_market_with_composite_oracle_primary_pyth( + name="AVAX-USD", + sz_decimals=2, + min_size=1, + lot_size=1, + ticker_size=1, + max_open_interest=500_000, + max_leverage=10, + margin_call_fee_pct=300, + pyth_identifier_bytes=pyth_bytes, + pyth_max_staleness_secs=30, + pyth_confidence_interval_threshold=200, + pyth_decimals=8, + internal_initial_price=25, + internal_max_staleness_secs=60, + oracles_deviation_bps=100, + consecutive_deviation_count=3, + ) + + payload: InputEntryFunctionData = admin_dex._send_tx.call_args.args[0] + assert ( + payload.function + == f"{TEST_PACKAGE}::admin_apis::register_market_with_composite_oracle_primary_pyth" + ) + args = payload.function_arguments + # 0=name, 1=sz_dec, 2=min_size, 3=lot, 4=ticker, 5=max_oi, 6=max_lev, 7=margin, + # 8=taker_in_next_block, 9=pyth_bytes, 10=pyth_staleness, 11=pyth_ci, 12=pyth_dec, + # 13=internal_initial_price, 14=internal_max_staleness, 15=oracles_dev_bps, 16=dev_count + assert args[0] == "AVAX-USD" + assert args[13] == 25 # internal_initial_price + assert args[14] == 60 # internal_max_staleness_secs + assert args[15] == 100 # oracles_deviation_bps + assert args[16] == 3 # consecutive_deviation_count + + +# =========================================================================== +# Tests for register_market_with_composite_oracle_primary_chainlink (async) +# =========================================================================== + + +class TestRegisterMarketWithCompositeChainlink: + async def test_sends_correct_function(self, admin_dex: DecibelAdminDex) -> None: + chainlink_bytes = list(range(32)) + await admin_dex.register_market_with_composite_oracle_primary_chainlink( + name="BNB-USD", + sz_decimals=2, + min_size=1, + lot_size=1, + ticker_size=1, + max_open_interest=300_000, + max_leverage=10, + margin_call_fee_pct=300, + rescale_decimals=8, + chainlink_feed_id_bytes=chainlink_bytes, + chainlink_max_staleness_secs=30, + internal_max_staleness_secs=60, + internal_initial_price=250, + oracles_deviation_bps=50, + consecutive_deviation_count=5, + ) + + payload: InputEntryFunctionData = admin_dex._send_tx.call_args.args[0] + assert ( + payload.function == f"{TEST_PACKAGE}::admin_apis" + "::register_market_with_composite_oracle_primary_chainlink" + ) + args = payload.function_arguments + assert args[0] == "BNB-USD" + assert args[9] == chainlink_bytes + assert args[11] == 8 # rescale_decimals + assert args[12] == 250 # internal_initial_price + + +# =========================================================================== +# Tests for update_internal_oracle_price (async) +# =========================================================================== + + +class TestUpdateInternalOraclePrice: + async def test_sends_correct_function(self, admin_dex: DecibelAdminDex) -> None: + with patch("decibel.admin.get_market_addr", return_value=TEST_MARKET_ADDR): + await admin_dex.update_internal_oracle_price( + market_name=TEST_MARKET_NAME, oracle_price=3000 + ) + + payload: InputEntryFunctionData = admin_dex._send_tx.call_args.args[0] + assert payload.function == f"{TEST_PACKAGE}::admin_apis::update_mark_for_internal_oracle" + assert payload.function_arguments == [TEST_MARKET_ADDR, 3000, [], [], True] + + async def test_resolves_market_addr(self, admin_dex: DecibelAdminDex) -> None: + with patch("decibel.admin.get_market_addr", return_value=TEST_MARKET_ADDR) as mock_get_addr: + await admin_dex.update_internal_oracle_price(market_name="SOL-USD", oracle_price=100) + + mock_get_addr.assert_called_once_with( + "SOL-USD", admin_dex._config.deployment.perp_engine_global + ) + + +# =========================================================================== +# Tests for update_pyth_oracle_price (async) +# =========================================================================== + + +class TestUpdatePythOraclePrice: + async def test_sends_correct_function(self, admin_dex: DecibelAdminDex) -> None: + vaa = [1, 2, 3, 4] + with patch("decibel.admin.get_market_addr", return_value=TEST_MARKET_ADDR): + await admin_dex.update_pyth_oracle_price(market_name=TEST_MARKET_NAME, vaa=vaa) + + payload: InputEntryFunctionData = admin_dex._send_tx.call_args.args[0] + assert payload.function == f"{TEST_PACKAGE}::admin_apis::update_mark_for_pyth_oracle" + assert payload.function_arguments == [TEST_MARKET_ADDR, vaa, [], [], True] + + +# =========================================================================== +# Tests for set_market_adl_trigger_threshold (async) +# =========================================================================== + + +class TestSetMarketAdlTriggerThreshold: + async def test_sends_correct_function(self, admin_dex: DecibelAdminDex) -> None: + with patch("decibel.admin.get_market_addr", return_value=TEST_MARKET_ADDR): + await admin_dex.set_market_adl_trigger_threshold( + market_name=TEST_MARKET_NAME, threshold=500 + ) + + payload: InputEntryFunctionData = admin_dex._send_tx.call_args.args[0] + assert payload.function == f"{TEST_PACKAGE}::admin_apis::set_market_adl_trigger_threshold" + assert payload.function_arguments == [TEST_MARKET_ADDR, 500] + + +# =========================================================================== +# Tests for update_price_to_pyth_only (async) +# =========================================================================== + + +class TestUpdatePriceToPythOnly: + async def test_sends_correct_function(self, admin_dex: DecibelAdminDex) -> None: + vaas = [[1, 2], [3, 4]] + await admin_dex.update_price_to_pyth_only(vaas=vaas) + + payload: InputEntryFunctionData = admin_dex._send_tx.call_args.args[0] + assert payload.function == f"{TEST_PACKAGE}::pyth::update_price_feeds_with_funder" + assert payload.function_arguments == [vaas] + + +# =========================================================================== +# Tests for update_price_to_chainlink_only (async) +# =========================================================================== + + +class TestUpdatePriceToChainlinkOnly: + async def test_sends_correct_function(self, admin_dex: DecibelAdminDex) -> None: + signed_report = [1, 2, 3, 4, 5] + await admin_dex.update_price_to_chainlink_only(signed_report=signed_report) + + payload: InputEntryFunctionData = admin_dex._send_tx.call_args.args[0] + assert payload.function == f"{TEST_PACKAGE}::chainlink_state::verify_and_store_single_price" + assert payload.function_arguments == [signed_report] + + +# =========================================================================== +# Tests for mint_usdc (async) +# =========================================================================== + + +class TestMintUsdc: + async def test_sends_correct_function_with_str_addr(self, admin_dex: DecibelAdminDex) -> None: + to_addr = "0x" + "11" * 32 + await admin_dex.mint_usdc(to_addr=to_addr, amount=1_000_000) + + payload: InputEntryFunctionData = admin_dex._send_tx.call_args.args[0] + assert payload.function == f"{TEST_PACKAGE}::usdc::mint" + assert payload.function_arguments == [to_addr, 1_000_000] + + async def test_sends_correct_function_with_account_address( + self, admin_dex: DecibelAdminDex + ) -> None: + addr = AccountAddress.from_str("0x" + "11" * 32) + await admin_dex.mint_usdc(to_addr=addr, amount=500_000) + + payload: InputEntryFunctionData = admin_dex._send_tx.call_args.args[0] + # AccountAddress should be converted to str + assert isinstance(payload.function_arguments[0], str) + assert payload.function_arguments[1] == 500_000 + + async def test_converts_account_address_to_str(self, admin_dex: DecibelAdminDex) -> None: + addr = AccountAddress.from_str("0x" + "22" * 32) + await admin_dex.mint_usdc(to_addr=addr, amount=100) + + payload: InputEntryFunctionData = admin_dex._send_tx.call_args.args[0] + assert payload.function_arguments[0] == str(addr) + + +# =========================================================================== +# Tests for set_public_minting (async) +# =========================================================================== + + +class TestSetPublicMinting: + async def test_allows_minting(self, admin_dex: DecibelAdminDex) -> None: + await admin_dex.set_public_minting(allow=True) + + payload: InputEntryFunctionData = admin_dex._send_tx.call_args.args[0] + assert payload.function == f"{TEST_PACKAGE}::usdc::set_public_minting" + assert payload.function_arguments == [True] + + async def test_disallows_minting(self, admin_dex: DecibelAdminDex) -> None: + await admin_dex.set_public_minting(allow=False) + + payload: InputEntryFunctionData = admin_dex._send_tx.call_args.args[0] + assert payload.function_arguments == [False] + + +# =========================================================================== +# Tests for usdc_balance (async) +# =========================================================================== + + +class TestUsdcBalance: + async def test_returns_balance_for_str_addr(self, admin_dex: DecibelAdminDex) -> None: + admin_dex._aptos.view = AsyncMock(return_value=["1000000"]) + result = await admin_dex.usdc_balance(addr="0x" + "11" * 32) + assert result == 1_000_000 + + async def test_returns_balance_for_account_address(self, admin_dex: DecibelAdminDex) -> None: + admin_dex._aptos.view = AsyncMock(return_value=["500000"]) + addr = AccountAddress.from_str("0x" + "11" * 32) + result = await admin_dex.usdc_balance(addr=addr) + assert result == 500_000 + + async def test_converts_account_address_to_str(self, admin_dex: DecibelAdminDex) -> None: + addr = AccountAddress.from_str("0x" + "33" * 32) + admin_dex._aptos.view = AsyncMock(return_value=["0"]) + await admin_dex.usdc_balance(addr=addr) + + call_args = admin_dex._aptos.view.call_args + # Third argument should be a list with addr string and usdc address + assert str(addr) in call_args.args[2] + + +# =========================================================================== +# Tests for DecibelAdminDexSync.__init__ +# =========================================================================== + + +class TestDecibelAdminDexSyncInit: + def test_init_calls_super(self, test_config, mock_account) -> None: + with patch("decibel.admin.BaseSDKSync.__init__") as mock_super: + mock_super.return_value = None + dex = DecibelAdminDexSync.__new__(DecibelAdminDexSync) + DecibelAdminDexSync.__init__(dex, test_config, mock_account) + mock_super.assert_called_once_with(test_config, mock_account, None) + + +# =========================================================================== +# Tests for DecibelAdminDexSync.get_protocol_vault_address +# =========================================================================== + + +class TestDecibelAdminDexSyncGetProtocolVaultAddress: + def test_returns_account_address(self, admin_dex_sync: DecibelAdminDexSync) -> None: + result = admin_dex_sync.get_protocol_vault_address() + assert isinstance(result, AccountAddress) + + def test_is_deterministic(self, admin_dex_sync: DecibelAdminDexSync) -> None: + result1 = admin_dex_sync.get_protocol_vault_address() + result2 = admin_dex_sync.get_protocol_vault_address() + assert str(result1) == str(result2) + + def test_matches_async_version( + self, admin_dex: DecibelAdminDex, admin_dex_sync: DecibelAdminDexSync + ) -> None: + async_result = admin_dex.get_protocol_vault_address() + sync_result = admin_dex_sync.get_protocol_vault_address() + assert str(async_result) == str(sync_result) + + +# =========================================================================== +# Tests for DecibelAdminDexSync.initialize +# =========================================================================== + + +class TestDecibelAdminDexSyncInitialize: + def test_sends_correct_function(self, admin_dex_sync: DecibelAdminDexSync) -> None: + result = admin_dex_sync.initialize( + collateral_token_addr=TEST_COLLATERAL_ADDR, + backstop_liquidator_addr=TEST_BACKSTOP_ADDR, + ) + + admin_dex_sync._send_tx.assert_called_once() + payload: InputEntryFunctionData = admin_dex_sync._send_tx.call_args.args[0] + assert payload.function == f"{TEST_PACKAGE}::admin_apis::initialize" + assert payload.function_arguments == [TEST_COLLATERAL_ADDR, TEST_BACKSTOP_ADDR] + assert result == _make_tx_response() + + +# =========================================================================== +# Tests for DecibelAdminDexSync.initialize_protocol_vault +# =========================================================================== + + +class TestDecibelAdminDexSyncInitializeProtocolVault: + def test_sends_create_and_fund_vault(self, admin_dex_sync: DecibelAdminDexSync) -> None: + with patch( + "decibel.admin.BaseSDKSync.get_primary_subaccount_address", + return_value=TEST_ACCOUNT_ADDR, + ): + admin_dex_sync.initialize_protocol_vault( + collateral_token_addr=TEST_COLLATERAL_ADDR, + initial_funding=1_000_000, + ) + + payload: InputEntryFunctionData = admin_dex_sync._send_tx.call_args.args[0] + assert payload.function == f"{TEST_PACKAGE}::vault_api::create_and_fund_vault" + assert payload.function_arguments[11] == 1_000_000 # initial_funding + + +# =========================================================================== +# Tests for DecibelAdminDexSync.delegate_protocol_vault_trading_to +# =========================================================================== + + +class TestDecibelAdminDexSyncDelegateProtocolVaultTradingTo: + def test_sends_correct_function(self, admin_dex_sync: DecibelAdminDexSync) -> None: + admin_dex_sync.delegate_protocol_vault_trading_to( + vault_address=TEST_VAULT_ADDR, + account_to_delegate_to=TEST_DELEGATE_ADDR, + ) + + payload: InputEntryFunctionData = admin_dex_sync._send_tx.call_args.args[0] + assert payload.function == f"{TEST_PACKAGE}::vault_admin_api::delegate_dex_actions_to" + assert payload.function_arguments == [TEST_VAULT_ADDR, TEST_DELEGATE_ADDR, None] + + +# =========================================================================== +# Tests for DecibelAdminDexSync.update_vault_use_global_redemption_slippage_adjustment +# =========================================================================== + + +class TestDecibelAdminDexSyncUpdateVaultRedemptionSlippage: + def test_sends_correct_function(self, admin_dex_sync: DecibelAdminDexSync) -> None: + admin_dex_sync.update_vault_use_global_redemption_slippage_adjustment( + vault_address=TEST_VAULT_ADDR, + use_global_redemption_slippage_adjustment=True, + ) + + payload: InputEntryFunctionData = admin_dex_sync._send_tx.call_args.args[0] + assert ( + payload.function == f"{TEST_PACKAGE}::vault_admin_api" + "::update_vault_use_global_redemption_slippage_adjustment" + ) + assert payload.function_arguments == [TEST_VAULT_ADDR, True] + + +# =========================================================================== +# Tests for DecibelAdminDexSync.authorize_oracle_and_mark_update +# =========================================================================== + + +class TestDecibelAdminDexSyncAuthorizeOracle: + def test_sends_correct_function(self, admin_dex_sync: DecibelAdminDexSync) -> None: + oracle_updater = "0x" + "11" * 32 + admin_dex_sync.authorize_oracle_and_mark_update(oracle_updater) + + payload: InputEntryFunctionData = admin_dex_sync._send_tx.call_args.args[0] + assert ( + payload.function == f"{TEST_PACKAGE}::admin_apis::add_oracle_and_mark_update_permission" + ) + assert payload.function_arguments == [oracle_updater] + + +# =========================================================================== +# Tests for DecibelAdminDexSync.add_access_control_admin +# =========================================================================== + + +class TestDecibelAdminDexSyncAddAccessControlAdmin: + def test_sends_correct_function(self, admin_dex_sync: DecibelAdminDexSync) -> None: + delegated = "0x" + "22" * 32 + admin_dex_sync.add_access_control_admin(delegated) + + payload: InputEntryFunctionData = admin_dex_sync._send_tx.call_args.args[0] + assert payload.function == f"{TEST_PACKAGE}::admin_apis::add_access_control_admin" + assert payload.function_arguments == [delegated] + + +# =========================================================================== +# Tests for DecibelAdminDexSync.add_market_list_admin +# =========================================================================== + + +class TestDecibelAdminDexSyncAddMarketListAdmin: + def test_sends_correct_function(self, admin_dex_sync: DecibelAdminDexSync) -> None: + delegated = "0x" + "33" * 32 + admin_dex_sync.add_market_list_admin(delegated) + + payload: InputEntryFunctionData = admin_dex_sync._send_tx.call_args.args[0] + assert payload.function == f"{TEST_PACKAGE}::admin_apis::add_market_list_admin" + + +# =========================================================================== +# Tests for DecibelAdminDexSync.add_market_risk_governor +# =========================================================================== + + +class TestDecibelAdminDexSyncAddMarketRiskGovernor: + def test_sends_correct_function(self, admin_dex_sync: DecibelAdminDexSync) -> None: + delegated = "0x" + "44" * 32 + admin_dex_sync.add_market_risk_governor(delegated) + + payload: InputEntryFunctionData = admin_dex_sync._send_tx.call_args.args[0] + assert payload.function == f"{TEST_PACKAGE}::admin_apis::add_market_risk_governor" + + +# =========================================================================== +# Tests for DecibelAdminDexSync.register_market_with_internal_oracle +# =========================================================================== + + +class TestDecibelAdminDexSyncRegisterMarketInternal: + def test_sends_correct_function(self, admin_dex_sync: DecibelAdminDexSync) -> None: + admin_dex_sync.register_market_with_internal_oracle( + name="BTC-USD", + sz_decimals=3, + min_size=1, + lot_size=1, + ticker_size=1, + max_open_interest=1_000_000, + max_leverage=20, + margin_call_fee_pct=500, + ) + + payload: InputEntryFunctionData = admin_dex_sync._send_tx.call_args.args[0] + assert ( + payload.function == f"{TEST_PACKAGE}::admin_apis::register_market_with_internal_oracle" + ) + assert payload.function_arguments[0] == "BTC-USD" + assert payload.function_arguments[8] is True # taker_in_next_block default + + +# =========================================================================== +# Tests for DecibelAdminDexSync.register_market_with_pyth_oracle +# =========================================================================== + + +class TestDecibelAdminDexSyncRegisterMarketPyth: + def test_sends_correct_function(self, admin_dex_sync: DecibelAdminDexSync) -> None: + pyth_bytes = list(range(32)) + admin_dex_sync.register_market_with_pyth_oracle( + name="ETH-USD", + sz_decimals=4, + min_size=1, + lot_size=1, + ticker_size=1, + max_open_interest=2_000_000, + max_leverage=15, + margin_call_fee_pct=400, + pyth_identifier_bytes=pyth_bytes, + pyth_max_staleness_secs=45, + pyth_confidence_interval_threshold=150, + pyth_decimals=8, + taker_in_next_block=False, + ) + + payload: InputEntryFunctionData = admin_dex_sync._send_tx.call_args.args[0] + assert payload.function == f"{TEST_PACKAGE}::admin_apis::register_market_with_pyth_oracle" + assert payload.function_arguments[8] is False # taker_in_next_block + + +# =========================================================================== +# Tests for DecibelAdminDexSync.register_market_with_composite_oracle_primary_pyth +# =========================================================================== + + +class TestDecibelAdminDexSyncRegisterMarketCompositePyth: + def test_sends_correct_function(self, admin_dex_sync: DecibelAdminDexSync) -> None: + pyth_bytes = list(range(32)) + admin_dex_sync.register_market_with_composite_oracle_primary_pyth( + name="ARB-USD", + sz_decimals=2, + min_size=1, + lot_size=1, + ticker_size=1, + max_open_interest=300_000, + max_leverage=10, + margin_call_fee_pct=250, + pyth_identifier_bytes=pyth_bytes, + pyth_max_staleness_secs=30, + pyth_confidence_interval_threshold=100, + pyth_decimals=8, + internal_initial_price=1, + internal_max_staleness_secs=60, + oracles_deviation_bps=50, + consecutive_deviation_count=2, + ) + + payload: InputEntryFunctionData = admin_dex_sync._send_tx.call_args.args[0] + assert ( + payload.function + == f"{TEST_PACKAGE}::admin_apis::register_market_with_composite_oracle_primary_pyth" + ) + + +# =========================================================================== +# Tests for DecibelAdminDexSync.register_market_with_composite_oracle_primary_chainlink +# =========================================================================== + + +class TestDecibelAdminDexSyncRegisterMarketCompositeChainlink: + def test_sends_correct_function(self, admin_dex_sync: DecibelAdminDexSync) -> None: + chainlink_bytes = list(range(32)) + admin_dex_sync.register_market_with_composite_oracle_primary_chainlink( + name="LINK-USD", + sz_decimals=2, + min_size=1, + lot_size=1, + ticker_size=1, + max_open_interest=200_000, + max_leverage=10, + margin_call_fee_pct=300, + rescale_decimals=8, + chainlink_feed_id_bytes=chainlink_bytes, + chainlink_max_staleness_secs=30, + internal_max_staleness_secs=60, + internal_initial_price=15, + oracles_deviation_bps=75, + consecutive_deviation_count=4, + ) + + payload: InputEntryFunctionData = admin_dex_sync._send_tx.call_args.args[0] + assert ( + payload.function == f"{TEST_PACKAGE}::admin_apis" + "::register_market_with_composite_oracle_primary_chainlink" + ) + assert payload.function_arguments[11] == 8 # rescale_decimals + + +# =========================================================================== +# Tests for DecibelAdminDexSync.update_internal_oracle_price +# =========================================================================== + + +class TestDecibelAdminDexSyncUpdateInternalOraclePrice: + def test_sends_correct_function(self, admin_dex_sync: DecibelAdminDexSync) -> None: + with patch("decibel.admin.get_market_addr", return_value=TEST_MARKET_ADDR): + admin_dex_sync.update_internal_oracle_price( + market_name=TEST_MARKET_NAME, oracle_price=2500 + ) + + payload: InputEntryFunctionData = admin_dex_sync._send_tx.call_args.args[0] + assert payload.function == f"{TEST_PACKAGE}::admin_apis::update_mark_for_internal_oracle" + assert payload.function_arguments == [TEST_MARKET_ADDR, 2500, [], [], True] + + +# =========================================================================== +# Tests for DecibelAdminDexSync.update_pyth_oracle_price +# =========================================================================== + + +class TestDecibelAdminDexSyncUpdatePythOraclePrice: + def test_sends_correct_function(self, admin_dex_sync: DecibelAdminDexSync) -> None: + vaa = [10, 20, 30] + with patch("decibel.admin.get_market_addr", return_value=TEST_MARKET_ADDR): + admin_dex_sync.update_pyth_oracle_price(market_name=TEST_MARKET_NAME, vaa=vaa) + + payload: InputEntryFunctionData = admin_dex_sync._send_tx.call_args.args[0] + assert payload.function == f"{TEST_PACKAGE}::admin_apis::update_mark_for_pyth_oracle" + assert payload.function_arguments == [TEST_MARKET_ADDR, vaa, [], [], True] + + +# =========================================================================== +# Tests for DecibelAdminDexSync.set_market_adl_trigger_threshold +# =========================================================================== + + +class TestDecibelAdminDexSyncSetMarketAdlTriggerThreshold: + def test_sends_correct_function(self, admin_dex_sync: DecibelAdminDexSync) -> None: + with patch("decibel.admin.get_market_addr", return_value=TEST_MARKET_ADDR): + admin_dex_sync.set_market_adl_trigger_threshold( + market_name=TEST_MARKET_NAME, threshold=750 + ) + + payload: InputEntryFunctionData = admin_dex_sync._send_tx.call_args.args[0] + assert payload.function == f"{TEST_PACKAGE}::admin_apis::set_market_adl_trigger_threshold" + assert payload.function_arguments == [TEST_MARKET_ADDR, 750] + + +# =========================================================================== +# Tests for DecibelAdminDexSync.update_price_to_pyth_only +# =========================================================================== + + +class TestDecibelAdminDexSyncUpdatePricePythOnly: + def test_sends_correct_function(self, admin_dex_sync: DecibelAdminDexSync) -> None: + vaas = [[1, 2], [3, 4]] + admin_dex_sync.update_price_to_pyth_only(vaas=vaas) + + payload: InputEntryFunctionData = admin_dex_sync._send_tx.call_args.args[0] + assert payload.function == f"{TEST_PACKAGE}::pyth::update_price_feeds_with_funder" + assert payload.function_arguments == [vaas] + + +# =========================================================================== +# Tests for DecibelAdminDexSync.update_price_to_chainlink_only +# =========================================================================== + + +class TestDecibelAdminDexSyncUpdatePriceChainlinkOnly: + def test_sends_correct_function(self, admin_dex_sync: DecibelAdminDexSync) -> None: + signed_report = [5, 10, 15] + admin_dex_sync.update_price_to_chainlink_only(signed_report=signed_report) + + payload: InputEntryFunctionData = admin_dex_sync._send_tx.call_args.args[0] + assert payload.function == f"{TEST_PACKAGE}::chainlink_state::verify_and_store_single_price" + assert payload.function_arguments == [signed_report] + + +# =========================================================================== +# Tests for DecibelAdminDexSync.mint_usdc +# =========================================================================== + + +class TestDecibelAdminDexSyncMintUsdc: + def test_sends_correct_function_with_str(self, admin_dex_sync: DecibelAdminDexSync) -> None: + to_addr = "0x" + "11" * 32 + admin_dex_sync.mint_usdc(to_addr=to_addr, amount=2_000_000) + + payload: InputEntryFunctionData = admin_dex_sync._send_tx.call_args.args[0] + assert payload.function == f"{TEST_PACKAGE}::usdc::mint" + assert payload.function_arguments == [to_addr, 2_000_000] + + def test_sends_correct_function_with_account_address( + self, admin_dex_sync: DecibelAdminDexSync + ) -> None: + addr = AccountAddress.from_str("0x" + "55" * 32) + admin_dex_sync.mint_usdc(to_addr=addr, amount=300_000) + + payload: InputEntryFunctionData = admin_dex_sync._send_tx.call_args.args[0] + assert payload.function_arguments[0] == str(addr) + assert payload.function_arguments[1] == 300_000 + + +# =========================================================================== +# Tests for DecibelAdminDexSync.set_public_minting +# =========================================================================== + + +class TestDecibelAdminDexSyncSetPublicMinting: + def test_allows_minting(self, admin_dex_sync: DecibelAdminDexSync) -> None: + admin_dex_sync.set_public_minting(allow=True) + + payload: InputEntryFunctionData = admin_dex_sync._send_tx.call_args.args[0] + assert payload.function == f"{TEST_PACKAGE}::usdc::set_public_minting" + assert payload.function_arguments == [True] + + def test_disallows_minting(self, admin_dex_sync: DecibelAdminDexSync) -> None: + admin_dex_sync.set_public_minting(allow=False) + + payload: InputEntryFunctionData = admin_dex_sync._send_tx.call_args.args[0] + assert payload.function_arguments == [False] + + +# =========================================================================== +# Tests for DecibelAdminDexSync.usdc_balance +# =========================================================================== + + +class TestDecibelAdminDexSyncUsdcBalance: + def test_returns_balance_for_str_addr(self, admin_dex_sync: DecibelAdminDexSync) -> None: + mock_response = MagicMock() + mock_response.json.return_value = ["2000000"] + admin_dex_sync._http_client.post.return_value = mock_response + + result = admin_dex_sync.usdc_balance(addr="0x" + "11" * 32) + assert result == 2_000_000 + + def test_returns_balance_for_account_address(self, admin_dex_sync: DecibelAdminDexSync) -> None: + addr = AccountAddress.from_str("0x" + "22" * 32) + mock_response = MagicMock() + mock_response.json.return_value = ["750000"] + admin_dex_sync._http_client.post.return_value = mock_response + + result = admin_dex_sync.usdc_balance(addr=addr) + assert result == 750_000 + + def test_uses_http_client_when_available(self, admin_dex_sync: DecibelAdminDexSync) -> None: + mock_response = MagicMock() + mock_response.json.return_value = ["100"] + admin_dex_sync._http_client.post.return_value = mock_response + + admin_dex_sync.usdc_balance(addr="0x" + "11" * 32) + admin_dex_sync._http_client.post.assert_called_once() diff --git a/tests/test_base.py b/tests/test_base.py new file mode 100644 index 0000000..22e53cf --- /dev/null +++ b/tests/test_base.py @@ -0,0 +1,1797 @@ +"""Unit tests for decibel._base module. + +Covers BaseSDK and BaseSDKSync: init, context managers, build_tx, +gas price fetching, simulation, signing, submit, wait for transaction, +and _send_tx full flow. +""" + +from __future__ import annotations + +from typing import Any +from unittest.mock import AsyncMock, MagicMock, patch + +import httpx +import pytest + +from decibel._base import ( + BaseSDK, + BaseSDKOptions, + BaseSDKOptionsSync, + BaseSDKSync, + _poll_delay, +) +from decibel._exceptions import TxnConfirmError, TxnSubmitError +from decibel._fee_pay import PendingTransactionResponse + +# --------------------------------------------------------------------------- +# Helper factories +# --------------------------------------------------------------------------- + + +def _make_pending_response(tx_hash: str = "0xdeadbeef") -> PendingTransactionResponse: + return PendingTransactionResponse( + hash=tx_hash, + sender="0x" + "aa" * 32, + sequence_number="1", + max_gas_amount="200000", + gas_unit_price="100", + expiration_timestamp_secs="9999999999", + ) + + +def _make_httpx_response( + status_code: int = 200, + json_data: Any = None, + text: str = "", +) -> httpx.Response: + if json_data is not None: + return httpx.Response( + status_code=status_code, + json=json_data, + request=httpx.Request("GET", "https://test.example.com"), + ) + return httpx.Response( + status_code=status_code, + text=text, + request=httpx.Request("GET", "https://test.example.com"), + ) + + +def _make_sdk(config: Any, account: Any, opts: Any = None) -> BaseSDK: + with patch("decibel._base.AbiRegistry"), patch("decibel._base.RestClient"): + sdk = BaseSDK(config=config, account=account, opts=opts) + return sdk + + +def _make_sdk_sync(config: Any, account: Any, opts: Any = None) -> BaseSDKSync: + with patch("decibel._base.AbiRegistry"): + sdk = BaseSDKSync(config=config, account=account, opts=opts) + return sdk + + +# --------------------------------------------------------------------------- +# _poll_delay helper +# --------------------------------------------------------------------------- + + +class TestPollDelay: + def test_first_delay(self) -> None: + assert _poll_delay(0) == pytest.approx(0.2) + + def test_second_delay(self) -> None: + assert _poll_delay(1) == pytest.approx(0.2) + + def test_third_delay(self) -> None: + assert _poll_delay(2) == pytest.approx(0.5) + + def test_fourth_delay(self) -> None: + assert _poll_delay(3) == pytest.approx(0.5) + + def test_fifth_delay(self) -> None: + assert _poll_delay(4) == pytest.approx(1.0) + + def test_beyond_table_returns_one(self) -> None: + assert _poll_delay(5) == pytest.approx(1.0) + assert _poll_delay(100) == pytest.approx(1.0) + + +# --------------------------------------------------------------------------- +# BaseSDK.__init__ +# --------------------------------------------------------------------------- + + +@pytest.mark.usefixtures("test_config") +class TestBaseSDKInit: + @patch("decibel._base.AbiRegistry") + @patch("decibel._base.RestClient") + def test_creates_http_client(self, mock_rest: Any, mock_abi: Any, test_config: Any) -> None: + account = MagicMock() + sdk = BaseSDK(config=test_config, account=account) + assert isinstance(sdk._http_client, httpx.AsyncClient) + assert sdk._config is test_config + assert sdk._account is account + + @patch("decibel._base.AbiRegistry") + @patch("decibel._base.RestClient") + def test_default_opts(self, mock_rest: Any, mock_abi: Any, test_config: Any) -> None: + account = MagicMock() + sdk = BaseSDK(config=test_config, account=account) + assert sdk._skip_simulate is False + assert sdk._no_fee_payer is False + assert sdk._node_api_key is None + assert sdk._gas_price_manager is None + assert sdk._time_delta_ms == 0 + + @patch("decibel._base.AbiRegistry") + @patch("decibel._base.RestClient") + def test_custom_opts(self, mock_rest: Any, mock_abi: Any, test_config: Any) -> None: + account = MagicMock() + opts = BaseSDKOptions( + skip_simulate=True, + no_fee_payer=True, + node_api_key="nodekey", + time_delta_ms=500, + ) + sdk = BaseSDK(config=test_config, account=account, opts=opts) + assert sdk._skip_simulate is True + assert sdk._no_fee_payer is True + assert sdk._node_api_key == "nodekey" + assert sdk._time_delta_ms == 500 + + @patch("decibel._base.AbiRegistry") + @patch("decibel._base.RestClient") + def test_none_chain_id_logs_warning( + self, mock_rest: Any, mock_abi: Any, test_config: Any + ) -> None: + from dataclasses import replace + + config_no_chain = replace(test_config, chain_id=None) + account = MagicMock() + import logging + + with patch.object(logging.getLogger("decibel._base"), "warning") as mock_warn: + BaseSDK(config=config_no_chain, account=account) + mock_warn.assert_called_once() + + @patch("decibel._base.AbiRegistry") + @patch("decibel._base.RestClient") + def test_creates_abi_registry_and_rest_client( + self, mock_rest: Any, mock_abi: Any, test_config: Any + ) -> None: + account = MagicMock() + BaseSDK(config=test_config, account=account) + mock_abi.assert_called_once_with(chain_id=test_config.chain_id) + mock_rest.assert_called_once_with(test_config.fullnode_url) + + @patch("decibel._base.AbiRegistry") + @patch("decibel._base.RestClient") + def test_properties(self, mock_rest: Any, mock_abi: Any, test_config: Any) -> None: + account = MagicMock() + sdk = BaseSDK(config=test_config, account=account) + assert sdk.config is test_config + assert sdk.account is account + assert sdk.skip_simulate is False + assert sdk.no_fee_payer is False + assert sdk.time_delta_ms == 0 + + @patch("decibel._base.AbiRegistry") + @patch("decibel._base.RestClient") + def test_time_delta_ms_setter(self, mock_rest: Any, mock_abi: Any, test_config: Any) -> None: + account = MagicMock() + sdk = BaseSDK(config=test_config, account=account) + sdk.time_delta_ms = 1000 + assert sdk.time_delta_ms == 1000 + + +# --------------------------------------------------------------------------- +# BaseSDK.close / context manager +# --------------------------------------------------------------------------- + + +class TestBaseSDKClose: + @pytest.mark.asyncio + async def test_close_calls_aclose(self, test_config: Any, mock_account: Any) -> None: + sdk = _make_sdk(test_config, mock_account) + sdk._http_client = AsyncMock() + await sdk.close() + sdk._http_client.aclose.assert_awaited_once() + + @pytest.mark.asyncio + async def test_aenter_returns_self(self, test_config: Any, mock_account: Any) -> None: + sdk = _make_sdk(test_config, mock_account) + result = await sdk.__aenter__() + assert result is sdk + + @pytest.mark.asyncio + async def test_aexit_calls_close(self, test_config: Any, mock_account: Any) -> None: + sdk = _make_sdk(test_config, mock_account) + sdk._http_client = AsyncMock() + await sdk.__aexit__(None, None, None) + sdk._http_client.aclose.assert_awaited_once() + + @pytest.mark.asyncio + async def test_async_context_manager(self, test_config: Any, mock_account: Any) -> None: + sdk = _make_sdk(test_config, mock_account) + sdk._http_client = AsyncMock() + async with sdk as ctx: + assert ctx is sdk + sdk._http_client.aclose.assert_awaited_once() + + +# --------------------------------------------------------------------------- +# BaseSDK._fetch_gas_price_estimation +# --------------------------------------------------------------------------- + + +class TestBaseSDKFetchGasPrice: + @pytest.mark.asyncio + async def test_success(self, test_config: Any, mock_account: Any) -> None: + sdk = _make_sdk(test_config, mock_account) + sdk._http_client = AsyncMock() + sdk._http_client.get = AsyncMock( + return_value=_make_httpx_response(200, json_data={"gas_estimate": 150}) + ) + + price = await sdk._fetch_gas_price_estimation() + assert price == 150 + + @pytest.mark.asyncio + async def test_uses_default_when_no_gas_estimate_key( + self, test_config: Any, mock_account: Any + ) -> None: + sdk = _make_sdk(test_config, mock_account) + sdk._http_client = AsyncMock() + sdk._http_client.get = AsyncMock( + return_value=_make_httpx_response(200, json_data={"other": "data"}) + ) + + price = await sdk._fetch_gas_price_estimation() + # DEFAULT_GAS_ESTIMATE = 100 + assert price == 100 + + @pytest.mark.asyncio + async def test_failure_raises_value_error(self, test_config: Any, mock_account: Any) -> None: + sdk = _make_sdk(test_config, mock_account) + sdk._http_client = AsyncMock() + sdk._http_client.get = AsyncMock( + return_value=_make_httpx_response(500, text="Server Error") + ) + + with pytest.raises(ValueError, match="Failed to fetch gas price"): + await sdk._fetch_gas_price_estimation() + + @pytest.mark.asyncio + async def test_uses_node_api_key_in_headers(self, test_config: Any, mock_account: Any) -> None: + opts = BaseSDKOptions(node_api_key="my-node-key") + sdk = _make_sdk(test_config, mock_account, opts) + sdk._http_client = AsyncMock() + sdk._http_client.get = AsyncMock( + return_value=_make_httpx_response(200, json_data={"gas_estimate": 200}) + ) + + await sdk._fetch_gas_price_estimation() + call_kwargs = sdk._http_client.get.call_args.kwargs + assert call_kwargs["headers"]["x-api-key"] == "my-node-key" + + +# --------------------------------------------------------------------------- +# BaseSDK._simulate_transaction +# --------------------------------------------------------------------------- + + +class TestBaseSDKSimulateTransaction: + @pytest.mark.asyncio + async def test_success_returns_first_item(self, test_config: Any, mock_account: Any) -> None: + sdk = _make_sdk(test_config, mock_account) + sim_data = [{"max_gas_amount": "50000", "gas_unit_price": "100", "success": True}] + sdk._http_client = AsyncMock() + sdk._http_client.post = AsyncMock( + return_value=_make_httpx_response(200, json_data=sim_data) + ) + + mock_txn = MagicMock() + mock_txn.fee_payer_address = None + # Need _serialize_for_simulation to return bytes + with patch.object(sdk, "_serialize_for_simulation", return_value=b"\x00" * 16): + result = await sdk._simulate_transaction(mock_txn) + + assert result["max_gas_amount"] == "50000" + assert result["gas_unit_price"] == "100" + + @pytest.mark.asyncio + async def test_failure_raises_value_error(self, test_config: Any, mock_account: Any) -> None: + sdk = _make_sdk(test_config, mock_account) + sdk._http_client = AsyncMock() + sdk._http_client.post = AsyncMock( + return_value=_make_httpx_response(400, text="Bad Request") + ) + + mock_txn = MagicMock() + with patch.object(sdk, "_serialize_for_simulation", return_value=b"\x00" * 16): + with pytest.raises(ValueError, match="Transaction simulation failed"): + await sdk._simulate_transaction(mock_txn) + + @pytest.mark.asyncio + async def test_empty_list_raises_value_error(self, test_config: Any, mock_account: Any) -> None: + sdk = _make_sdk(test_config, mock_account) + sdk._http_client = AsyncMock() + sdk._http_client.post = AsyncMock(return_value=_make_httpx_response(200, json_data=[])) + + mock_txn = MagicMock() + with patch.object(sdk, "_serialize_for_simulation", return_value=b"\x00" * 16): + with pytest.raises(ValueError, match="empty results"): + await sdk._simulate_transaction(mock_txn) + + +# --------------------------------------------------------------------------- +# BaseSDK._submit_direct +# --------------------------------------------------------------------------- + + +class TestBaseSDKSubmitDirect: + @pytest.mark.asyncio + async def test_success(self, test_config: Any, mock_account: Any) -> None: + sdk = _make_sdk(test_config, mock_account) + + mock_raw_txn = MagicMock() + mock_raw_txn.sender = "0x" + "aa" * 32 + mock_raw_txn.sequence_number = 1 + mock_raw_txn.max_gas_amount = 200000 + mock_raw_txn.gas_unit_price = 100 + mock_raw_txn.expiration_timestamps_secs = 9999999999 + + mock_txn = MagicMock() + mock_txn.raw_transaction = mock_raw_txn + mock_txn.fee_payer_address = None + + mock_auth = MagicMock() + sdk._http_client = AsyncMock() + sdk._http_client.post = AsyncMock( + return_value=_make_httpx_response(200, json_data={"hash": "0xabc123"}) + ) + + with patch.object(sdk, "_serialize_signed_transaction", return_value=b"\x00" * 16): + result = await sdk._submit_direct(mock_txn, mock_auth) + + assert result.hash == "0xabc123" + assert isinstance(result, PendingTransactionResponse) + + @pytest.mark.asyncio + async def test_failure_raises_value_error(self, test_config: Any, mock_account: Any) -> None: + sdk = _make_sdk(test_config, mock_account) + mock_txn = MagicMock() + mock_auth = MagicMock() + sdk._http_client = AsyncMock() + sdk._http_client.post = AsyncMock( + return_value=_make_httpx_response(400, text="Bad Request") + ) + + with patch.object(sdk, "_serialize_signed_transaction", return_value=b"\x00" * 16): + with pytest.raises(ValueError, match="Transaction submission failed"): + await sdk._submit_direct(mock_txn, mock_auth) + + +# --------------------------------------------------------------------------- +# BaseSDK._wait_for_transaction +# --------------------------------------------------------------------------- + + +class TestBaseSDKWaitForTransaction: + @pytest.mark.asyncio + async def test_success_on_first_poll(self, test_config: Any, mock_account: Any) -> None: + sdk = _make_sdk(test_config, mock_account) + success_data = {"type": "user_transaction", "success": True, "hash": "0xabc"} + sdk._http_client = AsyncMock() + sdk._http_client.get = AsyncMock( + return_value=_make_httpx_response(200, json_data=success_data) + ) + + with patch.object(sdk, "_async_sleep", new_callable=AsyncMock): + result = await sdk._wait_for_transaction("0xabc", txn_confirm_timeout=30.0) + + assert result["hash"] == "0xabc" + assert result["success"] is True + + @pytest.mark.asyncio + async def test_pending_then_success(self, test_config: Any, mock_account: Any) -> None: + sdk = _make_sdk(test_config, mock_account) + pending_data = {"type": "pending_transaction"} + success_data = {"type": "user_transaction", "success": True, "hash": "0xabc"} + + sdk._http_client = AsyncMock() + sdk._http_client.get = AsyncMock( + side_effect=[ + _make_httpx_response(200, json_data=pending_data), + _make_httpx_response(200, json_data=success_data), + ] + ) + + with patch.object(sdk, "_async_sleep", new_callable=AsyncMock): + result = await sdk._wait_for_transaction("0xabc", txn_confirm_timeout=30.0) + + assert result["success"] is True + + @pytest.mark.asyncio + async def test_failure_vm_status_raises_txn_confirm_error( + self, test_config: Any, mock_account: Any + ) -> None: + sdk = _make_sdk(test_config, mock_account) + failed_data = {"type": "user_transaction", "success": False, "vm_status": "Out of gas"} + sdk._http_client = AsyncMock() + sdk._http_client.get = AsyncMock( + return_value=_make_httpx_response(200, json_data=failed_data) + ) + + with patch.object(sdk, "_async_sleep", new_callable=AsyncMock): + with pytest.raises(TxnConfirmError, match="failed: Out of gas"): + await sdk._wait_for_transaction("0xabc", txn_confirm_timeout=30.0) + + @pytest.mark.asyncio + async def test_timeout_raises_txn_confirm_error( + self, test_config: Any, mock_account: Any + ) -> None: + sdk = _make_sdk(test_config, mock_account) + pending_data = {"type": "pending_transaction"} + sdk._http_client = AsyncMock() + sdk._http_client.get = AsyncMock( + return_value=_make_httpx_response(200, json_data=pending_data) + ) + + # Very short timeout so it fires immediately + with patch("time.time", side_effect=[0.0, 100.0]): + with patch.object(sdk, "_async_sleep", new_callable=AsyncMock): + with pytest.raises(TxnConfirmError, match="did not confirm"): + await sdk._wait_for_transaction("0xabc", txn_confirm_timeout=0.001) + + @pytest.mark.asyncio + async def test_connect_timeout_is_swallowed(self, test_config: Any, mock_account: Any) -> None: + sdk = _make_sdk(test_config, mock_account) + success_data = {"type": "user_transaction", "success": True, "hash": "0xabc"} + sdk._http_client = AsyncMock() + sdk._http_client.get = AsyncMock( + side_effect=[ + httpx.ConnectTimeout("timeout"), + _make_httpx_response(200, json_data=success_data), + ] + ) + + with patch.object(sdk, "_async_sleep", new_callable=AsyncMock): + result = await sdk._wait_for_transaction("0xabc", txn_confirm_timeout=30.0) + + assert result["success"] is True + + @pytest.mark.asyncio + async def test_read_timeout_is_swallowed(self, test_config: Any, mock_account: Any) -> None: + sdk = _make_sdk(test_config, mock_account) + success_data = {"type": "user_transaction", "success": True, "hash": "0xabc"} + sdk._http_client = AsyncMock() + sdk._http_client.get = AsyncMock( + side_effect=[ + httpx.ReadTimeout("read timeout"), + _make_httpx_response(200, json_data=success_data), + ] + ) + + with patch.object(sdk, "_async_sleep", new_callable=AsyncMock): + result = await sdk._wait_for_transaction("0xabc", txn_confirm_timeout=30.0) + + assert result["success"] is True + + @pytest.mark.asyncio + async def test_connect_error_is_swallowed(self, test_config: Any, mock_account: Any) -> None: + sdk = _make_sdk(test_config, mock_account) + success_data = {"type": "user_transaction", "success": True, "hash": "0xabc"} + sdk._http_client = AsyncMock() + sdk._http_client.get = AsyncMock( + side_effect=[ + httpx.ConnectError("connection refused"), + _make_httpx_response(200, json_data=success_data), + ] + ) + + with patch.object(sdk, "_async_sleep", new_callable=AsyncMock): + result = await sdk._wait_for_transaction("0xabc", txn_confirm_timeout=30.0) + + assert result["success"] is True + + @pytest.mark.asyncio + async def test_uses_default_timeout_when_none( + self, test_config: Any, mock_account: Any + ) -> None: + sdk = _make_sdk(test_config, mock_account) + success_data = {"type": "user_transaction", "success": True} + sdk._http_client = AsyncMock() + sdk._http_client.get = AsyncMock( + return_value=_make_httpx_response(200, json_data=success_data) + ) + + with patch.object(sdk, "_async_sleep", new_callable=AsyncMock): + result = await sdk._wait_for_transaction("0xabc", txn_confirm_timeout=None) + + assert result["success"] is True + + +# --------------------------------------------------------------------------- +# BaseSDK.build_tx +# --------------------------------------------------------------------------- + + +class TestBaseSDKBuildTx: + @pytest.mark.asyncio + async def test_build_tx_with_gas_price_manager_cached( + self, test_config: Any, mock_account: Any + ) -> None: + mock_manager = MagicMock() + mock_manager.get_gas_price.return_value = 150 + opts = BaseSDKOptions(gas_price_manager=mock_manager) + sdk = _make_sdk(test_config, mock_account, opts) + + mock_abi = MagicMock() + mock_abi.params = ["&signer", "u64"] + sdk._abi_registry = MagicMock() + sdk._abi_registry.get_function.return_value = mock_abi + + sender = MagicMock() + + mock_txn = MagicMock() + with patch( + "decibel._base.build_simple_transaction_sync", return_value=mock_txn + ) as mock_build: + with patch("decibel._base.generate_random_replay_protection_nonce", return_value=12345): + result = await sdk.build_tx( + MagicMock(function="0x1::m::f", function_arguments=[42], type_arguments=[]), + sender, + ) + + assert result is mock_txn + call_kwargs = mock_build.call_args.kwargs + assert call_kwargs["gas_unit_price"] == 150 + + @pytest.mark.asyncio + async def test_build_tx_with_gas_price_manager_uncached( + self, test_config: Any, mock_account: Any + ) -> None: + # gas_price_manager.fetch_and_set_gas_price is awaited in build_tx, + # so use a regular MagicMock whose get_gas_price returns None and + # whose fetch_and_set_gas_price is an AsyncMock coroutine. + mock_manager = MagicMock() + mock_manager.get_gas_price.return_value = None + mock_manager.fetch_and_set_gas_price = AsyncMock(return_value=200) + opts = BaseSDKOptions(gas_price_manager=mock_manager) + sdk = _make_sdk(test_config, mock_account, opts) + + mock_abi = MagicMock() + mock_abi.params = ["u64"] + sdk._abi_registry = MagicMock() + sdk._abi_registry.get_function.return_value = mock_abi + + sender = MagicMock() + + mock_txn = MagicMock() + with patch( + "decibel._base.build_simple_transaction_sync", return_value=mock_txn + ) as mock_build: + with patch("decibel._base.generate_random_replay_protection_nonce", return_value=999): + await sdk.build_tx( + MagicMock(function="0x1::m::f", function_arguments=[42], type_arguments=[]), + sender, + ) + + call_kwargs = mock_build.call_args.kwargs + assert call_kwargs["gas_unit_price"] == 200 + + @pytest.mark.asyncio + async def test_build_tx_without_gas_manager_fetches_price( + self, test_config: Any, mock_account: Any + ) -> None: + sdk = _make_sdk(test_config, mock_account) + + mock_abi = MagicMock() + mock_abi.params = ["u64"] + sdk._abi_registry = MagicMock() + sdk._abi_registry.get_function.return_value = mock_abi + + sender = MagicMock() + mock_txn = MagicMock() + + with patch.object( + sdk, "_fetch_gas_price_estimation", new_callable=AsyncMock, return_value=123 + ): + with patch( + "decibel._base.build_simple_transaction_sync", return_value=mock_txn + ) as mock_build: + with patch( + "decibel._base.generate_random_replay_protection_nonce", return_value=111 + ): + await sdk.build_tx( + MagicMock(function="0x1::m::f", function_arguments=[1], type_arguments=[]), + sender, + ) + + call_kwargs = mock_build.call_args.kwargs + assert call_kwargs["gas_unit_price"] == 123 + + @pytest.mark.asyncio + async def test_build_tx_explicit_gas_unit_price( + self, test_config: Any, mock_account: Any + ) -> None: + sdk = _make_sdk(test_config, mock_account) + + mock_abi = MagicMock() + mock_abi.params = ["u64"] + sdk._abi_registry = MagicMock() + sdk._abi_registry.get_function.return_value = mock_abi + + sender = MagicMock() + mock_txn = MagicMock() + + with patch( + "decibel._base.build_simple_transaction_sync", return_value=mock_txn + ) as mock_build: + with patch("decibel._base.generate_random_replay_protection_nonce", return_value=222): + await sdk.build_tx( + MagicMock(function="0x1::m::f", function_arguments=[1], type_arguments=[]), + sender, + gas_unit_price=500, + ) + + call_kwargs = mock_build.call_args.kwargs + assert call_kwargs["gas_unit_price"] == 500 + + @pytest.mark.asyncio + async def test_build_tx_missing_abi_raises(self, test_config: Any, mock_account: Any) -> None: + sdk = _make_sdk(test_config, mock_account) + sdk._abi_registry = MagicMock() + sdk._abi_registry.get_function.return_value = None # Missing ABI + + sender = MagicMock() + + with patch("decibel._base.generate_random_replay_protection_nonce", return_value=333): + with pytest.raises(ValueError, match="Cannot build transaction"): + await sdk.build_tx( + MagicMock(function="0x1::m::unknown", function_arguments=[], type_arguments=[]), + sender, + ) + + @pytest.mark.asyncio + async def test_build_tx_nonce_none_raises(self, test_config: Any, mock_account: Any) -> None: + sdk = _make_sdk(test_config, mock_account) + + with patch("decibel._base.generate_random_replay_protection_nonce", return_value=None): + with pytest.raises(ValueError, match="replay protection nonce"): + await sdk.build_tx( + MagicMock(function="0x1::m::f", function_arguments=[], type_arguments=[]), + MagicMock(), + ) + + +# --------------------------------------------------------------------------- +# BaseSDK.submit_tx +# --------------------------------------------------------------------------- + + +class TestBaseSDKSubmitTx: + @pytest.mark.asyncio + async def test_no_fee_payer_calls_submit_direct( + self, test_config: Any, mock_account: Any + ) -> None: + opts = BaseSDKOptions(no_fee_payer=True) + sdk = _make_sdk(test_config, mock_account, opts) + mock_txn = MagicMock() + mock_auth = MagicMock() + expected = _make_pending_response() + + with patch.object(sdk, "_submit_direct", new_callable=AsyncMock, return_value=expected): + result = await sdk.submit_tx(mock_txn, mock_auth) + + assert result is expected + + @pytest.mark.asyncio + async def test_with_fee_payer_calls_submit_fee_paid( + self, test_config: Any, mock_account: Any + ) -> None: + opts = BaseSDKOptions(no_fee_payer=False) + sdk = _make_sdk(test_config, mock_account, opts) + mock_txn = MagicMock() + mock_auth = MagicMock() + expected = _make_pending_response() + + with patch( + "decibel._base.submit_fee_paid_transaction", + new_callable=AsyncMock, + return_value=expected, + ): + result = await sdk.submit_tx(mock_txn, mock_auth) + + assert result is expected + + +# --------------------------------------------------------------------------- +# BaseSDK._sign_transaction +# --------------------------------------------------------------------------- + + +class TestBaseSDKSignTransaction: + def test_sign_without_fee_payer(self, test_config: Any, mock_account: Any) -> None: + sdk = _make_sdk(test_config, mock_account) + signer = MagicMock() + mock_raw_txn = MagicMock() + mock_auth = MagicMock() + mock_raw_txn.sign.return_value = mock_auth + + mock_txn = MagicMock() + mock_txn.raw_transaction = mock_raw_txn + mock_txn.fee_payer_address = None + + result = sdk._sign_transaction(signer, mock_txn) + mock_raw_txn.sign.assert_called_once_with(signer.private_key) + assert result is mock_auth + + def test_sign_with_fee_payer(self, test_config: Any, mock_account: Any) -> None: + sdk = _make_sdk(test_config, mock_account) + signer = MagicMock() + mock_raw_txn = MagicMock() + mock_auth = MagicMock() + + mock_txn = MagicMock() + mock_txn.raw_transaction = mock_raw_txn + mock_txn.fee_payer_address = MagicMock() # Not None + + with patch("decibel._base.FeePayerRawTransaction") as mock_fee_payer_cls: + mock_fee_payer_instance = MagicMock() + mock_fee_payer_instance.sign.return_value = mock_auth + mock_fee_payer_cls.return_value = mock_fee_payer_instance + + result = sdk._sign_transaction(signer, mock_txn) + + mock_fee_payer_instance.sign.assert_called_once_with(signer.private_key) + assert result is mock_auth + + +# --------------------------------------------------------------------------- +# BaseSDK._send_tx full flow +# --------------------------------------------------------------------------- + + +class TestBaseSDKSendTx: + @pytest.mark.asyncio + async def test_send_tx_skip_simulate(self, test_config: Any, mock_account: Any) -> None: + opts = BaseSDKOptions(skip_simulate=True) + sdk = _make_sdk(test_config, mock_account, opts) + + mock_txn = MagicMock() + mock_auth = MagicMock() + mock_pending = _make_pending_response("0xresult") + success_data = {"type": "user_transaction", "success": True} + + mock_account.address.return_value = MagicMock() + sdk._account = mock_account + + with patch.object(sdk, "build_tx", new_callable=AsyncMock, return_value=mock_txn): + with patch.object(sdk, "_sign_transaction", return_value=mock_auth): + with patch.object( + sdk, "submit_tx", new_callable=AsyncMock, return_value=mock_pending + ): + with patch.object( + sdk, + "_wait_for_transaction", + new_callable=AsyncMock, + return_value=success_data, + ): + result = await sdk._send_tx(MagicMock()) + + assert result == success_data + + @pytest.mark.asyncio + async def test_send_tx_with_simulate(self, test_config: Any, mock_account: Any) -> None: + opts = BaseSDKOptions(skip_simulate=False) + sdk = _make_sdk(test_config, mock_account, opts) + sdk._account = mock_account + + mock_txn1 = MagicMock() + mock_txn2 = MagicMock() + mock_auth = MagicMock() + mock_pending = _make_pending_response("0xresult2") + success_data = {"type": "user_transaction", "success": True} + sim_result = {"max_gas_amount": "100000", "gas_unit_price": "150"} + + build_tx_mock = AsyncMock(side_effect=[mock_txn1, mock_txn2]) + + with patch.object(sdk, "build_tx", build_tx_mock): + with patch.object( + sdk, "_simulate_transaction", new_callable=AsyncMock, return_value=sim_result + ): + with patch.object(sdk, "_sign_transaction", return_value=mock_auth): + with patch.object( + sdk, "submit_tx", new_callable=AsyncMock, return_value=mock_pending + ): + with patch.object( + sdk, + "_wait_for_transaction", + new_callable=AsyncMock, + return_value=success_data, + ): + result = await sdk._send_tx(MagicMock()) + + assert result == success_data + assert build_tx_mock.await_count == 2 # built twice (initial + post-simulate) + + @pytest.mark.asyncio + async def test_send_tx_simulate_missing_fields_raises( + self, test_config: Any, mock_account: Any + ) -> None: + opts = BaseSDKOptions(skip_simulate=False) + sdk = _make_sdk(test_config, mock_account, opts) + sdk._account = mock_account + + sim_result: dict[str, Any] = {} # Missing max_gas_amount and gas_unit_price + + with patch.object(sdk, "build_tx", new_callable=AsyncMock, return_value=MagicMock()): + with patch.object( + sdk, "_simulate_transaction", new_callable=AsyncMock, return_value=sim_result + ): + with pytest.raises(ValueError, match="Transaction simulation returned no results"): + await sdk._send_tx(MagicMock()) + + @pytest.mark.asyncio + async def test_send_tx_submit_connect_timeout_raises_txn_submit_error( + self, test_config: Any, mock_account: Any + ) -> None: + opts = BaseSDKOptions(skip_simulate=True) + sdk = _make_sdk(test_config, mock_account, opts) + sdk._account = mock_account + + with patch.object(sdk, "build_tx", new_callable=AsyncMock, return_value=MagicMock()): + with patch.object(sdk, "_sign_transaction", return_value=MagicMock()): + with patch.object( + sdk, + "submit_tx", + new_callable=AsyncMock, + side_effect=httpx.ConnectTimeout("timeout"), + ): + with pytest.raises(TxnSubmitError, match="connection timeout"): + await sdk._send_tx(MagicMock()) + + @pytest.mark.asyncio + async def test_send_tx_submit_connect_error_raises_txn_submit_error( + self, test_config: Any, mock_account: Any + ) -> None: + opts = BaseSDKOptions(skip_simulate=True) + sdk = _make_sdk(test_config, mock_account, opts) + sdk._account = mock_account + + with patch.object(sdk, "build_tx", new_callable=AsyncMock, return_value=MagicMock()): + with patch.object(sdk, "_sign_transaction", return_value=MagicMock()): + with patch.object( + sdk, + "submit_tx", + new_callable=AsyncMock, + side_effect=httpx.ConnectError("refused"), + ): + with pytest.raises(TxnSubmitError, match="connection error"): + await sdk._send_tx(MagicMock()) + + @pytest.mark.asyncio + async def test_send_tx_submit_generic_error_raises_txn_submit_error( + self, test_config: Any, mock_account: Any + ) -> None: + opts = BaseSDKOptions(skip_simulate=True) + sdk = _make_sdk(test_config, mock_account, opts) + sdk._account = mock_account + + with patch.object(sdk, "build_tx", new_callable=AsyncMock, return_value=MagicMock()): + with patch.object(sdk, "_sign_transaction", return_value=MagicMock()): + with patch.object( + sdk, "submit_tx", new_callable=AsyncMock, side_effect=RuntimeError("generic") + ): + with pytest.raises(TxnSubmitError): + await sdk._send_tx(MagicMock()) + + @pytest.mark.asyncio + async def test_send_tx_with_account_override(self, test_config: Any, mock_account: Any) -> None: + opts = BaseSDKOptions(skip_simulate=True) + sdk = _make_sdk(test_config, mock_account, opts) + + override_account = MagicMock() + override_account.address.return_value = MagicMock() + mock_pending = _make_pending_response() + success_data = {"success": True} + + with patch.object(sdk, "build_tx", new_callable=AsyncMock, return_value=MagicMock()): + with patch.object(sdk, "_sign_transaction", return_value=MagicMock()): + with patch.object( + sdk, "submit_tx", new_callable=AsyncMock, return_value=mock_pending + ): + with patch.object( + sdk, + "_wait_for_transaction", + new_callable=AsyncMock, + return_value=success_data, + ): + result = await sdk._send_tx(MagicMock(), account_override=override_account) + + assert result == success_data + + +# --------------------------------------------------------------------------- +# BaseSDKSync.__init__ +# --------------------------------------------------------------------------- + + +class TestBaseSDKSyncInit: + @patch("decibel._base.AbiRegistry") + def test_creates_default_http_client(self, mock_abi: Any, test_config: Any) -> None: + account = MagicMock() + sdk = BaseSDKSync(config=test_config, account=account) + assert isinstance(sdk._http_client, httpx.Client) + assert sdk._owns_http_client is True + + @patch("decibel._base.AbiRegistry") + def test_uses_provided_http_client(self, mock_abi: Any, test_config: Any) -> None: + account = MagicMock() + provided_client = MagicMock(spec=httpx.Client) + opts = BaseSDKOptionsSync(http_client=provided_client) + sdk = BaseSDKSync(config=test_config, account=account, opts=opts) + assert sdk._http_client is provided_client + assert sdk._owns_http_client is False + + @patch("decibel._base.AbiRegistry") + def test_default_opts(self, mock_abi: Any, test_config: Any) -> None: + account = MagicMock() + sdk = BaseSDKSync(config=test_config, account=account) + assert sdk._skip_simulate is False + assert sdk._no_fee_payer is False + assert sdk._node_api_key is None + assert sdk._gas_price_manager is None + + @patch("decibel._base.AbiRegistry") + def test_none_chain_id_logs_warning(self, mock_abi: Any, test_config: Any) -> None: + import logging + from dataclasses import replace + + config_no_chain = replace(test_config, chain_id=None) + account = MagicMock() + with patch.object(logging.getLogger("decibel._base"), "warning") as mock_warn: + BaseSDKSync(config=config_no_chain, account=account) + mock_warn.assert_called_once() + + +# --------------------------------------------------------------------------- +# BaseSDKSync.close / context manager +# --------------------------------------------------------------------------- + + +class TestBaseSDKSyncClose: + def test_close_closes_owned_client(self, test_config: Any, mock_account: Any) -> None: + sdk = _make_sdk_sync(test_config, mock_account) + sdk._http_client = MagicMock(spec=httpx.Client) + sdk._owns_http_client = True + + sdk.close() + sdk._http_client.close.assert_called_once() + + def test_close_does_not_close_provided_client( + self, test_config: Any, mock_account: Any + ) -> None: + provided_client = MagicMock(spec=httpx.Client) + opts = BaseSDKOptionsSync(http_client=provided_client) + sdk = _make_sdk_sync(test_config, mock_account, opts) + + sdk.close() + provided_client.close.assert_not_called() + + def test_enter_returns_self(self, test_config: Any, mock_account: Any) -> None: + sdk = _make_sdk_sync(test_config, mock_account) + result = sdk.__enter__() + assert result is sdk + + def test_exit_calls_close(self, test_config: Any, mock_account: Any) -> None: + sdk = _make_sdk_sync(test_config, mock_account) + sdk._http_client = MagicMock(spec=httpx.Client) + sdk.close = MagicMock() # type: ignore[method-assign] + sdk.__exit__(None, None, None) + sdk.close.assert_called_once() + + def test_context_manager(self, test_config: Any, mock_account: Any) -> None: + sdk = _make_sdk_sync(test_config, mock_account) + sdk._http_client = MagicMock(spec=httpx.Client) + with sdk as ctx: + assert ctx is sdk + sdk._http_client.close.assert_called_once() + + +# --------------------------------------------------------------------------- +# BaseSDKSync._fetch_gas_price_estimation +# --------------------------------------------------------------------------- + + +class TestBaseSDKSyncFetchGasPrice: + def test_success(self, test_config: Any, mock_account: Any) -> None: + sdk = _make_sdk_sync(test_config, mock_account) + sdk._http_client = MagicMock(spec=httpx.Client) + sdk._http_client.get = MagicMock( + return_value=_make_httpx_response(200, json_data={"gas_estimate": 250}) + ) + + price = sdk._fetch_gas_price_estimation() + assert price == 250 + + def test_uses_default_when_missing(self, test_config: Any, mock_account: Any) -> None: + sdk = _make_sdk_sync(test_config, mock_account) + sdk._http_client = MagicMock(spec=httpx.Client) + sdk._http_client.get = MagicMock(return_value=_make_httpx_response(200, json_data={})) + + price = sdk._fetch_gas_price_estimation() + assert price == 100 # DEFAULT_GAS_ESTIMATE + + def test_failure_raises_value_error(self, test_config: Any, mock_account: Any) -> None: + sdk = _make_sdk_sync(test_config, mock_account) + sdk._http_client = MagicMock(spec=httpx.Client) + sdk._http_client.get = MagicMock( + return_value=_make_httpx_response(503, text="Service Unavailable") + ) + + with pytest.raises(ValueError, match="Failed to fetch gas price"): + sdk._fetch_gas_price_estimation() + + +# --------------------------------------------------------------------------- +# BaseSDKSync._simulate_transaction +# --------------------------------------------------------------------------- + + +class TestBaseSDKSyncSimulateTransaction: + def test_success_returns_first_item(self, test_config: Any, mock_account: Any) -> None: + sdk = _make_sdk_sync(test_config, mock_account) + sim_data = [{"max_gas_amount": "75000", "gas_unit_price": "100"}] + sdk._http_client = MagicMock(spec=httpx.Client) + sdk._http_client.post = MagicMock( + return_value=_make_httpx_response(200, json_data=sim_data) + ) + + mock_txn = MagicMock() + with patch.object(sdk, "_serialize_for_simulation", return_value=b"\x00" * 8): + result = sdk._simulate_transaction(mock_txn) + + assert result["max_gas_amount"] == "75000" + + def test_failure_raises_value_error(self, test_config: Any, mock_account: Any) -> None: + sdk = _make_sdk_sync(test_config, mock_account) + sdk._http_client = MagicMock(spec=httpx.Client) + sdk._http_client.post = MagicMock(return_value=_make_httpx_response(400, text="Bad")) + + mock_txn = MagicMock() + with patch.object(sdk, "_serialize_for_simulation", return_value=b"\x00" * 8): + with pytest.raises(ValueError, match="simulation failed"): + sdk._simulate_transaction(mock_txn) + + def test_empty_list_raises_value_error(self, test_config: Any, mock_account: Any) -> None: + sdk = _make_sdk_sync(test_config, mock_account) + sdk._http_client = MagicMock(spec=httpx.Client) + sdk._http_client.post = MagicMock(return_value=_make_httpx_response(200, json_data=[])) + + mock_txn = MagicMock() + with patch.object(sdk, "_serialize_for_simulation", return_value=b"\x00" * 8): + with pytest.raises(ValueError, match="empty results"): + sdk._simulate_transaction(mock_txn) + + +# --------------------------------------------------------------------------- +# BaseSDKSync._submit_direct +# --------------------------------------------------------------------------- + + +class TestBaseSDKSyncSubmitDirect: + def test_success(self, test_config: Any, mock_account: Any) -> None: + sdk = _make_sdk_sync(test_config, mock_account) + + mock_raw = MagicMock() + mock_raw.sender = "0x" + "aa" * 32 + mock_raw.sequence_number = 1 + mock_raw.max_gas_amount = 200000 + mock_raw.gas_unit_price = 100 + mock_raw.expiration_timestamps_secs = 9999999999 + + mock_txn = MagicMock() + mock_txn.raw_transaction = mock_raw + mock_auth = MagicMock() + + sdk._http_client = MagicMock(spec=httpx.Client) + sdk._http_client.post = MagicMock( + return_value=_make_httpx_response(200, json_data={"hash": "0xsync123"}) + ) + + with patch.object(sdk, "_serialize_signed_transaction", return_value=b"\x00" * 8): + result = sdk._submit_direct(mock_txn, mock_auth) + + assert result.hash == "0xsync123" + + def test_failure_raises_value_error(self, test_config: Any, mock_account: Any) -> None: + sdk = _make_sdk_sync(test_config, mock_account) + sdk._http_client = MagicMock(spec=httpx.Client) + sdk._http_client.post = MagicMock( + return_value=_make_httpx_response(400, text="Bad Request") + ) + + mock_txn = MagicMock() + mock_auth = MagicMock() + + with patch.object(sdk, "_serialize_signed_transaction", return_value=b"\x00" * 8): + with pytest.raises(ValueError, match="submission failed"): + sdk._submit_direct(mock_txn, mock_auth) + + +# --------------------------------------------------------------------------- +# BaseSDKSync._wait_for_transaction +# --------------------------------------------------------------------------- + + +class TestBaseSDKSyncWaitForTransaction: + def test_success_on_first_poll(self, test_config: Any, mock_account: Any) -> None: + sdk = _make_sdk_sync(test_config, mock_account) + success_data = {"type": "user_transaction", "success": True, "hash": "0xsync"} + sdk._http_client = MagicMock(spec=httpx.Client) + sdk._http_client.get = MagicMock( + return_value=_make_httpx_response(200, json_data=success_data) + ) + + with patch("time.sleep"): + result = sdk._wait_for_transaction("0xsync", txn_confirm_timeout=30.0) + + assert result["success"] is True + + def test_pending_then_success(self, test_config: Any, mock_account: Any) -> None: + sdk = _make_sdk_sync(test_config, mock_account) + pending_data = {"type": "pending_transaction"} + success_data = {"type": "user_transaction", "success": True} + + sdk._http_client = MagicMock(spec=httpx.Client) + sdk._http_client.get = MagicMock( + side_effect=[ + _make_httpx_response(200, json_data=pending_data), + _make_httpx_response(200, json_data=success_data), + ] + ) + + with patch("time.sleep"): + result = sdk._wait_for_transaction("0xsync", txn_confirm_timeout=30.0) + + assert result["success"] is True + + def test_failure_vm_status_raises(self, test_config: Any, mock_account: Any) -> None: + sdk = _make_sdk_sync(test_config, mock_account) + failed_data = {"type": "user_transaction", "success": False, "vm_status": "Aborted"} + sdk._http_client = MagicMock(spec=httpx.Client) + sdk._http_client.get = MagicMock( + return_value=_make_httpx_response(200, json_data=failed_data) + ) + + with patch("time.sleep"), pytest.raises(TxnConfirmError, match="failed: Aborted"): + sdk._wait_for_transaction("0xsync", txn_confirm_timeout=30.0) + + def test_timeout_raises(self, test_config: Any, mock_account: Any) -> None: + sdk = _make_sdk_sync(test_config, mock_account) + pending_data = {"type": "pending_transaction"} + sdk._http_client = MagicMock(spec=httpx.Client) + sdk._http_client.get = MagicMock( + return_value=_make_httpx_response(200, json_data=pending_data) + ) + + with patch("time.time", side_effect=[0.0, 100.0]), patch("time.sleep"): + with pytest.raises(TxnConfirmError, match="did not confirm"): + sdk._wait_for_transaction("0xsync", txn_confirm_timeout=0.001) + + def test_uses_default_timeout_when_none(self, test_config: Any, mock_account: Any) -> None: + sdk = _make_sdk_sync(test_config, mock_account) + success_data = {"type": "user_transaction", "success": True} + sdk._http_client = MagicMock(spec=httpx.Client) + sdk._http_client.get = MagicMock( + return_value=_make_httpx_response(200, json_data=success_data) + ) + + with patch("time.sleep"): + result = sdk._wait_for_transaction("0xsync", txn_confirm_timeout=None) + + assert result["success"] is True + + +# --------------------------------------------------------------------------- +# BaseSDKSync._send_tx +# --------------------------------------------------------------------------- + + +class TestBaseSDKSyncSendTx: + def test_send_tx_skip_simulate(self, test_config: Any, mock_account: Any) -> None: + opts = BaseSDKOptionsSync(skip_simulate=True) + sdk = _make_sdk_sync(test_config, mock_account, opts) + sdk._account = mock_account + + mock_txn = MagicMock() + mock_auth = MagicMock() + mock_pending = _make_pending_response() + success_data = {"success": True} + + with patch.object(sdk, "build_tx", return_value=mock_txn): + with patch.object(sdk, "_sign_transaction", return_value=mock_auth): + with patch.object(sdk, "submit_tx", return_value=mock_pending): + with patch.object(sdk, "_wait_for_transaction", return_value=success_data): + result = sdk._send_tx(MagicMock()) + + assert result == success_data + + def test_send_tx_connect_timeout_raises_txn_submit_error( + self, test_config: Any, mock_account: Any + ) -> None: + opts = BaseSDKOptionsSync(skip_simulate=True) + sdk = _make_sdk_sync(test_config, mock_account, opts) + sdk._account = mock_account + + with patch.object(sdk, "build_tx", return_value=MagicMock()): + with patch.object(sdk, "_sign_transaction", return_value=MagicMock()): + with patch.object(sdk, "submit_tx", side_effect=httpx.ConnectTimeout("timeout")): + with pytest.raises(TxnSubmitError, match="connection timeout"): + sdk._send_tx(MagicMock()) + + def test_send_tx_generic_error_raises_txn_submit_error( + self, test_config: Any, mock_account: Any + ) -> None: + opts = BaseSDKOptionsSync(skip_simulate=True) + sdk = _make_sdk_sync(test_config, mock_account, opts) + sdk._account = mock_account + + with patch.object(sdk, "build_tx", return_value=MagicMock()): + with patch.object(sdk, "_sign_transaction", return_value=MagicMock()): + with patch.object(sdk, "submit_tx", side_effect=RuntimeError("boom")): + with pytest.raises(TxnSubmitError): + sdk._send_tx(MagicMock()) + + +# --------------------------------------------------------------------------- +# BaseSDKSync.submit_tx +# --------------------------------------------------------------------------- + + +class TestBaseSDKSyncSubmitTx: + def test_no_fee_payer_calls_submit_direct(self, test_config: Any, mock_account: Any) -> None: + opts = BaseSDKOptionsSync(no_fee_payer=True) + sdk = _make_sdk_sync(test_config, mock_account, opts) + expected = _make_pending_response() + + with patch.object(sdk, "_submit_direct", return_value=expected): + result = sdk.submit_tx(MagicMock(), MagicMock()) + + assert result is expected + + def test_with_fee_payer_calls_submit_fee_paid_sync( + self, test_config: Any, mock_account: Any + ) -> None: + opts = BaseSDKOptionsSync(no_fee_payer=False) + sdk = _make_sdk_sync(test_config, mock_account, opts) + expected = _make_pending_response() + + with patch("decibel._base.submit_fee_paid_transaction_sync", return_value=expected): + result = sdk.submit_tx(MagicMock(), MagicMock()) + + assert result is expected + + +# --------------------------------------------------------------------------- +# get_primary_subaccount_address (method) +# --------------------------------------------------------------------------- + + +class TestGetPrimarySubaccountAddress: + def test_delegates_to_util(self, test_config: Any, mock_account: Any) -> None: + sdk = _make_sdk(test_config, mock_account) + test_addr = "0x" + "aa" * 32 + with patch( + "decibel._base.get_primary_subaccount_addr", return_value="0xderived" + ) as mock_fn: + result = sdk.get_primary_subaccount_address(test_addr) + + assert result == "0xderived" + mock_fn.assert_called_once_with( + test_addr, test_config.compat_version, test_config.deployment.package + ) + + +# --------------------------------------------------------------------------- +# BaseSDK — additional properties and serialization coverage +# --------------------------------------------------------------------------- + + +class TestBaseSDKProperties: + @patch("decibel._base.AbiRegistry") + @patch("decibel._base.RestClient") + def test_aptos_property(self, mock_rest: Any, mock_abi: Any, test_config: Any) -> None: + account = MagicMock() + mock_rest_instance = MagicMock() + mock_rest.return_value = mock_rest_instance + sdk = BaseSDK(config=test_config, account=account) + assert sdk.aptos is mock_rest_instance + + def test_skip_simulate_property(self, test_config: Any, mock_account: Any) -> None: + opts = BaseSDKOptions(skip_simulate=True) + sdk = _make_sdk(test_config, mock_account, opts) + assert sdk.skip_simulate is True + + def test_no_fee_payer_property(self, test_config: Any, mock_account: Any) -> None: + opts = BaseSDKOptions(no_fee_payer=True) + sdk = _make_sdk(test_config, mock_account, opts) + assert sdk.no_fee_payer is True + + def test_config_property(self, test_config: Any, mock_account: Any) -> None: + sdk = _make_sdk(test_config, mock_account) + assert sdk.config is test_config + + def test_account_property(self, test_config: Any, mock_account: Any) -> None: + sdk = _make_sdk(test_config, mock_account) + assert sdk.account is mock_account + + +class TestBaseSDKSyncProperties: + def test_config_property(self, test_config: Any, mock_account: Any) -> None: + sdk = _make_sdk_sync(test_config, mock_account) + assert sdk.config is test_config + + def test_account_property(self, test_config: Any, mock_account: Any) -> None: + sdk = _make_sdk_sync(test_config, mock_account) + assert sdk.account is mock_account + + def test_skip_simulate_property(self, test_config: Any, mock_account: Any) -> None: + opts = BaseSDKOptionsSync(skip_simulate=True) + sdk = _make_sdk_sync(test_config, mock_account, opts) + assert sdk.skip_simulate is True + + def test_no_fee_payer_property(self, test_config: Any, mock_account: Any) -> None: + opts = BaseSDKOptionsSync(no_fee_payer=True) + sdk = _make_sdk_sync(test_config, mock_account, opts) + assert sdk.no_fee_payer is True + + def test_time_delta_ms_property(self, test_config: Any, mock_account: Any) -> None: + opts = BaseSDKOptionsSync(time_delta_ms=250) + sdk = _make_sdk_sync(test_config, mock_account, opts) + assert sdk.time_delta_ms == 250 + + def test_time_delta_ms_setter(self, test_config: Any, mock_account: Any) -> None: + sdk = _make_sdk_sync(test_config, mock_account) + sdk.time_delta_ms = 1500 + assert sdk.time_delta_ms == 1500 + + def test_get_primary_subaccount_address_delegates( + self, test_config: Any, mock_account: Any + ) -> None: + sdk = _make_sdk_sync(test_config, mock_account) + test_addr = "0x" + "aa" * 32 + with patch("decibel._base.get_primary_subaccount_addr", return_value="0xsync_derived"): + result = sdk.get_primary_subaccount_address(test_addr) + assert result == "0xsync_derived" + + +# --------------------------------------------------------------------------- +# BaseSDKSync.build_tx — gas manager paths +# --------------------------------------------------------------------------- + + +class TestBaseSDKSyncBuildTx: + def test_build_tx_with_gas_price_manager_cached( + self, test_config: Any, mock_account: Any + ) -> None: + mock_manager = MagicMock() + mock_manager.get_gas_price.return_value = 300 + opts = BaseSDKOptionsSync(gas_price_manager=mock_manager) + sdk = _make_sdk_sync(test_config, mock_account, opts) + + mock_abi = MagicMock() + mock_abi.params = ["u64"] + sdk._abi_registry = MagicMock() + sdk._abi_registry.get_function.return_value = mock_abi + + sender = MagicMock() + mock_txn = MagicMock() + with patch( + "decibel._base.build_simple_transaction_sync", return_value=mock_txn + ) as mock_build: + with patch("decibel._base.generate_random_replay_protection_nonce", return_value=111): + result = sdk.build_tx( + MagicMock(function="0x1::m::f", function_arguments=[1], type_arguments=[]), + sender, + ) + + assert result is mock_txn + assert mock_build.call_args.kwargs["gas_unit_price"] == 300 + + def test_build_tx_with_gas_price_manager_uncached( + self, test_config: Any, mock_account: Any + ) -> None: + mock_manager = MagicMock() + mock_manager.get_gas_price.return_value = None + mock_manager.fetch_and_set_gas_price.return_value = 400 + opts = BaseSDKOptionsSync(gas_price_manager=mock_manager) + sdk = _make_sdk_sync(test_config, mock_account, opts) + + mock_abi = MagicMock() + mock_abi.params = ["u64"] + sdk._abi_registry = MagicMock() + sdk._abi_registry.get_function.return_value = mock_abi + + sender = MagicMock() + mock_txn = MagicMock() + with patch( + "decibel._base.build_simple_transaction_sync", return_value=mock_txn + ) as mock_build: + with patch("decibel._base.generate_random_replay_protection_nonce", return_value=222): + sdk.build_tx( + MagicMock(function="0x1::m::f", function_arguments=[1], type_arguments=[]), + sender, + ) + + assert mock_build.call_args.kwargs["gas_unit_price"] == 400 + + def test_build_tx_missing_abi_raises(self, test_config: Any, mock_account: Any) -> None: + sdk = _make_sdk_sync(test_config, mock_account) + sdk._abi_registry = MagicMock() + sdk._abi_registry.get_function.return_value = None + + with patch("decibel._base.generate_random_replay_protection_nonce", return_value=333): + with pytest.raises(ValueError, match="Cannot build transaction"): + sdk.build_tx( + MagicMock(function="0x1::m::unknown", function_arguments=[], type_arguments=[]), + MagicMock(), + ) + + def test_build_tx_nonce_none_raises(self, test_config: Any, mock_account: Any) -> None: + sdk = _make_sdk_sync(test_config, mock_account) + with patch("decibel._base.generate_random_replay_protection_nonce", return_value=None): + with pytest.raises(ValueError, match="replay protection nonce"): + sdk.build_tx( + MagicMock(function="0x1::m::f", function_arguments=[], type_arguments=[]), + MagicMock(), + ) + + def test_build_tx_without_gas_manager_fetches_price( + self, test_config: Any, mock_account: Any + ) -> None: + sdk = _make_sdk_sync(test_config, mock_account) + + mock_abi = MagicMock() + mock_abi.params = ["u64"] + sdk._abi_registry = MagicMock() + sdk._abi_registry.get_function.return_value = mock_abi + + sender = MagicMock() + mock_txn = MagicMock() + + with patch.object(sdk, "_fetch_gas_price_estimation", return_value=500): + with patch( + "decibel._base.build_simple_transaction_sync", return_value=mock_txn + ) as mock_build: + with patch( + "decibel._base.generate_random_replay_protection_nonce", return_value=444 + ): + sdk.build_tx( + MagicMock(function="0x1::m::f", function_arguments=[1], type_arguments=[]), + sender, + ) + + assert mock_build.call_args.kwargs["gas_unit_price"] == 500 + + def test_build_tx_explicit_gas_unit_price(self, test_config: Any, mock_account: Any) -> None: + sdk = _make_sdk_sync(test_config, mock_account) + + mock_abi = MagicMock() + mock_abi.params = ["u64"] + sdk._abi_registry = MagicMock() + sdk._abi_registry.get_function.return_value = mock_abi + + mock_txn = MagicMock() + with patch( + "decibel._base.build_simple_transaction_sync", return_value=mock_txn + ) as mock_build: + with patch("decibel._base.generate_random_replay_protection_nonce", return_value=555): + sdk.build_tx( + MagicMock(function="0x1::m::f", function_arguments=[1], type_arguments=[]), + MagicMock(), + gas_unit_price=600, + ) + + assert mock_build.call_args.kwargs["gas_unit_price"] == 600 + + +# --------------------------------------------------------------------------- +# BaseSDKSync._send_tx — with simulate path +# --------------------------------------------------------------------------- + + +class TestBaseSDKSyncSendTxWithSimulate: + def test_send_tx_with_simulate(self, test_config: Any, mock_account: Any) -> None: + opts = BaseSDKOptionsSync(skip_simulate=False) + sdk = _make_sdk_sync(test_config, mock_account, opts) + sdk._account = mock_account + + mock_txn1 = MagicMock() + mock_txn2 = MagicMock() + mock_auth = MagicMock() + mock_pending = _make_pending_response() + success_data = {"success": True} + sim_result = {"max_gas_amount": "100000", "gas_unit_price": "150"} + + build_tx_mock = MagicMock(side_effect=[mock_txn1, mock_txn2]) + + with patch.object(sdk, "build_tx", build_tx_mock): + with patch.object(sdk, "_simulate_transaction", return_value=sim_result): + with patch.object(sdk, "_sign_transaction", return_value=mock_auth): + with patch.object(sdk, "submit_tx", return_value=mock_pending): + with patch.object(sdk, "_wait_for_transaction", return_value=success_data): + result = sdk._send_tx(MagicMock()) + + assert result == success_data + assert build_tx_mock.call_count == 2 + + def test_send_tx_simulate_missing_fields_raises( + self, test_config: Any, mock_account: Any + ) -> None: + opts = BaseSDKOptionsSync(skip_simulate=False) + sdk = _make_sdk_sync(test_config, mock_account, opts) + sdk._account = mock_account + + with patch.object(sdk, "build_tx", return_value=MagicMock()): + with patch.object(sdk, "_simulate_transaction", return_value={}): + with pytest.raises(ValueError, match="no results"): + sdk._send_tx(MagicMock()) + + def test_send_tx_connect_error_raises_txn_submit_error( + self, test_config: Any, mock_account: Any + ) -> None: + opts = BaseSDKOptionsSync(skip_simulate=True) + sdk = _make_sdk_sync(test_config, mock_account, opts) + sdk._account = mock_account + + with patch.object(sdk, "build_tx", return_value=MagicMock()): + with patch.object(sdk, "_sign_transaction", return_value=MagicMock()): + with patch.object(sdk, "submit_tx", side_effect=httpx.ConnectError("refused")): + with pytest.raises(TxnSubmitError, match="connection error"): + sdk._send_tx(MagicMock()) + + def test_send_tx_http_status_error_raises_txn_submit_error( + self, test_config: Any, mock_account: Any + ) -> None: + opts = BaseSDKOptionsSync(skip_simulate=True) + sdk = _make_sdk_sync(test_config, mock_account, opts) + sdk._account = mock_account + + mock_response = MagicMock() + mock_response.status_code = 429 + + with patch.object(sdk, "build_tx", return_value=MagicMock()): + with patch.object(sdk, "_sign_transaction", return_value=MagicMock()): + with patch.object( + sdk, + "submit_tx", + side_effect=httpx.HTTPStatusError( + "rate limited", request=MagicMock(), response=mock_response + ), + ): + with pytest.raises(TxnSubmitError, match="HTTP 429"): + sdk._send_tx(MagicMock()) + + +# --------------------------------------------------------------------------- +# BaseSDK._send_tx — HTTP status error path +# --------------------------------------------------------------------------- + + +class TestBaseSDKAsyncSleep: + @pytest.mark.asyncio + async def test_async_sleep_calls_asyncio_sleep( + self, test_config: Any, mock_account: Any + ) -> None: + sdk = _make_sdk(test_config, mock_account) + with patch("asyncio.sleep", new_callable=AsyncMock) as mock_sleep: + await sdk._async_sleep(0.1) + mock_sleep.assert_awaited_once_with(0.1) + + +class TestBaseSDKSyncSignTransaction: + def test_sign_without_fee_payer(self, test_config: Any, mock_account: Any) -> None: + sdk = _make_sdk_sync(test_config, mock_account) + signer = MagicMock() + mock_raw_txn = MagicMock() + mock_auth = MagicMock() + mock_raw_txn.sign.return_value = mock_auth + + mock_txn = MagicMock() + mock_txn.raw_transaction = mock_raw_txn + mock_txn.fee_payer_address = None + + result = sdk._sign_transaction(signer, mock_txn) + mock_raw_txn.sign.assert_called_once_with(signer.private_key) + assert result is mock_auth + + def test_sign_with_fee_payer(self, test_config: Any, mock_account: Any) -> None: + sdk = _make_sdk_sync(test_config, mock_account) + signer = MagicMock() + mock_raw_txn = MagicMock() + mock_auth = MagicMock() + + mock_txn = MagicMock() + mock_txn.raw_transaction = mock_raw_txn + mock_txn.fee_payer_address = MagicMock() # Not None + + with patch("decibel._base.FeePayerRawTransaction") as mock_fee_payer_cls: + mock_fee_payer_instance = MagicMock() + mock_fee_payer_instance.sign.return_value = mock_auth + mock_fee_payer_cls.return_value = mock_fee_payer_instance + + result = sdk._sign_transaction(signer, mock_txn) + + mock_fee_payer_instance.sign.assert_called_once_with(signer.private_key) + assert result is mock_auth + + +class TestBaseSDKSyncBuildNodeHeaders: + def test_no_api_key_returns_empty_dict(self, test_config: Any, mock_account: Any) -> None: + sdk = _make_sdk_sync(test_config, mock_account) + sdk._node_api_key = None + headers = sdk._build_node_headers() + assert headers == {} + + def test_with_api_key_includes_header(self, test_config: Any, mock_account: Any) -> None: + opts = BaseSDKOptionsSync(node_api_key="sync-node-key") + sdk = _make_sdk_sync(test_config, mock_account, opts) + headers = sdk._build_node_headers() + assert headers["x-api-key"] == "sync-node-key" + + +class TestBaseSDKSendTxHttpStatusError: + @pytest.mark.asyncio + async def test_send_tx_http_status_error_raises_txn_submit_error( + self, test_config: Any, mock_account: Any + ) -> None: + opts = BaseSDKOptions(skip_simulate=True) + sdk = _make_sdk(test_config, mock_account, opts) + sdk._account = mock_account + + mock_response = MagicMock() + mock_response.status_code = 429 + + with patch.object(sdk, "build_tx", new_callable=AsyncMock, return_value=MagicMock()): + with patch.object(sdk, "_sign_transaction", return_value=MagicMock()): + with patch.object( + sdk, + "submit_tx", + new_callable=AsyncMock, + side_effect=httpx.HTTPStatusError( + "rate limited", request=MagicMock(), response=mock_response + ), + ): + with pytest.raises(TxnSubmitError, match="HTTP 429"): + await sdk._send_tx(MagicMock()) + + +# --------------------------------------------------------------------------- +# Serialization methods using real cryptographic keys +# --------------------------------------------------------------------------- + + +def _build_real_transaction(with_fee_payer: bool = True) -> Any: + """Build a real SimpleTransaction using actual Aptos types.""" + from aptos_sdk.account import Account + + from decibel._transaction_builder import InputEntryFunctionData, build_simple_transaction_sync + from decibel.abi import AbiRegistry + + acct = Account.generate() + registry = AbiRegistry(chain_id=2) + func_id = "0xe7da2794b1d8af76532ed95f38bfdf1136abfd8ea3a240189971988a83101b7f::usdc::mint" + abi = registry.get_function(func_id) + assert abi is not None + + data = InputEntryFunctionData( + function=func_id, + function_arguments=[str(acct.address()), 1_000_000], + type_arguments=[], + ) + return acct, build_simple_transaction_sync( + sender=acct.address(), + data=data, + chain_id=2, + gas_unit_price=100, + abi=abi, + with_fee_payer=with_fee_payer, + replay_protection_nonce=99999, + ) + + +class TestBaseSDKSerializationMethods: + """Tests for _serialize_for_simulation and _serialize_signed_transaction using real keys.""" + + @patch("decibel._base.AbiRegistry") + @patch("decibel._base.RestClient") + def test_serialize_for_simulation_with_fee_payer( + self, mock_rest: Any, mock_abi: Any, test_config: Any + ) -> None: + from aptos_sdk.account import Account + + real_account = Account.generate() + sdk = BaseSDK(config=test_config, account=real_account) + _, txn = _build_real_transaction(with_fee_payer=True) + + result = sdk._serialize_for_simulation(txn) + assert isinstance(result, bytes) + assert len(result) > 0 + + @patch("decibel._base.AbiRegistry") + @patch("decibel._base.RestClient") + def test_serialize_for_simulation_without_fee_payer( + self, mock_rest: Any, mock_abi: Any, test_config: Any + ) -> None: + from aptos_sdk.account import Account + + real_account = Account.generate() + sdk = BaseSDK(config=test_config, account=real_account) + _, txn = _build_real_transaction(with_fee_payer=False) + + result = sdk._serialize_for_simulation(txn) + assert isinstance(result, bytes) + assert len(result) > 0 + + @patch("decibel._base.AbiRegistry") + @patch("decibel._base.RestClient") + def test_serialize_signed_transaction( + self, mock_rest: Any, mock_abi: Any, test_config: Any + ) -> None: + from aptos_sdk.account import Account + + real_account = Account.generate() + sdk = BaseSDK(config=test_config, account=real_account) + _, txn = _build_real_transaction(with_fee_payer=True) + + # Sign the transaction properly + sender_auth = sdk._sign_transaction(real_account, txn) + result = sdk._serialize_signed_transaction(txn, sender_auth) + assert isinstance(result, bytes) + assert len(result) > 0 + + +class TestBaseSDKSyncSerializationMethods: + """Tests for BaseSDKSync serialization methods using real keys.""" + + @patch("decibel._base.AbiRegistry") + def test_serialize_for_simulation_with_fee_payer(self, mock_abi: Any, test_config: Any) -> None: + from aptos_sdk.account import Account + + real_account = Account.generate() + sdk = BaseSDKSync(config=test_config, account=real_account) + _, txn = _build_real_transaction(with_fee_payer=True) + + result = sdk._serialize_for_simulation(txn) + assert isinstance(result, bytes) + assert len(result) > 0 + + @patch("decibel._base.AbiRegistry") + def test_serialize_for_simulation_without_fee_payer( + self, mock_abi: Any, test_config: Any + ) -> None: + from aptos_sdk.account import Account + + real_account = Account.generate() + sdk = BaseSDKSync(config=test_config, account=real_account) + _, txn = _build_real_transaction(with_fee_payer=False) + + result = sdk._serialize_for_simulation(txn) + assert isinstance(result, bytes) + assert len(result) > 0 + + @patch("decibel._base.AbiRegistry") + def test_serialize_signed_transaction(self, mock_abi: Any, test_config: Any) -> None: + from aptos_sdk.account import Account + + real_account = Account.generate() + sdk = BaseSDKSync(config=test_config, account=real_account) + _, txn = _build_real_transaction(with_fee_payer=True) + + sender_auth = sdk._sign_transaction(real_account, txn) + result = sdk._serialize_signed_transaction(txn, sender_auth) + assert isinstance(result, bytes) + assert len(result) > 0 diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py new file mode 100644 index 0000000..965b2e7 --- /dev/null +++ b/tests/test_exceptions.py @@ -0,0 +1,77 @@ +"""Tests for decibel._exceptions module.""" + +from __future__ import annotations + +import pytest + +from decibel._exceptions import TxnConfirmError, TxnSubmitError + + +class TestTxnConfirmError: + def test_init_stores_tx_hash(self) -> None: + err = TxnConfirmError("0xdeadbeef", "timed out") + assert err.tx_hash == "0xdeadbeef" + + def test_init_formats_message_with_tx_hash(self) -> None: + err = TxnConfirmError("0xabc123", "transaction reverted") + assert "0xabc123" in str(err) + assert "transaction reverted" in str(err) + + def test_is_exception(self) -> None: + err = TxnConfirmError("0x1", "some error") + assert isinstance(err, Exception) + + def test_can_be_raised_and_caught(self) -> None: + with pytest.raises(TxnConfirmError) as exc_info: + raise TxnConfirmError("0xhash", "confirmation failed") + assert exc_info.value.tx_hash == "0xhash" + + def test_message_format_contains_transaction_prefix(self) -> None: + err = TxnConfirmError("0xfeed", "dropped") + assert str(err) == "Transaction 0xfeed: dropped" + + def test_empty_message(self) -> None: + err = TxnConfirmError("0x0", "") + assert err.tx_hash == "0x0" + assert "0x0" in str(err) + + def test_long_tx_hash(self) -> None: + long_hash = "0x" + "a" * 64 + err = TxnConfirmError(long_hash, "msg") + assert err.tx_hash == long_hash + assert long_hash in str(err) + + +class TestTxnSubmitError: + def test_init_stores_original_exception(self) -> None: + original = ValueError("connection refused") + err = TxnSubmitError("submit failed", original_exception=original) + assert err.original_exception is original + + def test_init_with_no_original_exception(self) -> None: + err = TxnSubmitError("failed without cause") + assert err.original_exception is None + + def test_message_is_set(self) -> None: + err = TxnSubmitError("network timeout") + assert str(err) == "network timeout" + + def test_is_exception(self) -> None: + err = TxnSubmitError("some error") + assert isinstance(err, Exception) + + def test_can_be_raised_and_caught(self) -> None: + original = ConnectionError("host unreachable") + with pytest.raises(TxnSubmitError) as exc_info: + raise TxnSubmitError("submit failed", original_exception=original) + assert exc_info.value.original_exception is original + + def test_original_exception_defaults_to_none(self) -> None: + err = TxnSubmitError("error message") + assert err.original_exception is None + + def test_with_various_original_exception_types(self) -> None: + for exc_type in [ValueError, RuntimeError, TimeoutError, OSError]: + original = exc_type("inner error") + err = TxnSubmitError("outer message", original_exception=original) + assert isinstance(err.original_exception, exc_type) diff --git a/tests/test_fee_pay.py b/tests/test_fee_pay.py new file mode 100644 index 0000000..d0949dd --- /dev/null +++ b/tests/test_fee_pay.py @@ -0,0 +1,687 @@ +"""Unit tests for decibel._fee_pay module. + +Covers: submit_fee_paid_transaction, submit_fee_paid_transaction_sync, +_submit_via_gas_station_api, _submit_via_gas_station_api_sync, +_submit_via_legacy_fee_payer, _submit_via_legacy_fee_payer_sync, +_get_default_gas_station_url. +""" + +from __future__ import annotations + +from dataclasses import replace +from typing import Any +from unittest.mock import AsyncMock, MagicMock, patch + +import httpx +import pytest + +from decibel._constants import DecibelConfig, Network +from decibel._fee_pay import ( + PendingTransactionResponse, + _get_default_gas_station_url, + _submit_via_gas_station_api, + _submit_via_gas_station_api_sync, + _submit_via_legacy_fee_payer, + _submit_via_legacy_fee_payer_sync, + submit_fee_paid_transaction, + submit_fee_paid_transaction_sync, +) + +# --------------------------------------------------------------------------- +# Helpers / fixtures +# --------------------------------------------------------------------------- + + +def _make_httpx_response( + status_code: int = 200, + json_data: Any = None, + text: str = "", +) -> httpx.Response: + if json_data is not None: + return httpx.Response( + status_code=status_code, + json=json_data, + request=httpx.Request("POST", "https://test.example.com"), + ) + return httpx.Response( + status_code=status_code, + text=text, + request=httpx.Request("POST", "https://test.example.com"), + ) + + +def _make_mock_transaction() -> MagicMock: + """Build a SimpleTransaction-like mock with serialisable internals.""" + mock_transaction = MagicMock() + mock_raw = MagicMock() + mock_raw.sender = "0x" + "aa" * 32 + mock_raw.sequence_number = 1 + mock_raw.max_gas_amount = 200000 + mock_raw.gas_unit_price = 100 + mock_raw.expiration_timestamps_secs = 9999999999 + + # serialize writes bytes into a Serializer — just let it be a no-op mock + mock_raw.serialize = MagicMock() + + mock_transaction.raw_transaction = mock_raw + mock_transaction.fee_payer_address = None + return mock_transaction + + +def _make_mock_authenticator() -> MagicMock: + """Build an AccountAuthenticator-like mock.""" + mock_auth = MagicMock() + # serialize is called with a Serializer; make it a no-op + mock_auth.serialize = MagicMock() + return mock_auth + + +def _gas_api_config(test_config: DecibelConfig) -> DecibelConfig: + return replace(test_config, gas_station_api_key="gs-api-key") + + +def _legacy_only_config(test_config: DecibelConfig) -> DecibelConfig: + return replace( + test_config, + gas_station_api_key=None, + gas_station_url="https://legacy-gas.example.com", + ) + + +def _no_gas_config(test_config: DecibelConfig) -> DecibelConfig: + return replace(test_config, gas_station_api_key=None, gas_station_url=None) + + +# --------------------------------------------------------------------------- +# PendingTransactionResponse model +# --------------------------------------------------------------------------- + + +class TestPendingTransactionResponse: + def test_creation(self) -> None: + resp = PendingTransactionResponse( + hash="0xabc", + sender="0xsender", + sequence_number="1", + max_gas_amount="200000", + gas_unit_price="100", + expiration_timestamp_secs="9999999", + ) + assert resp.hash == "0xabc" + assert resp.sender == "0xsender" + + +# --------------------------------------------------------------------------- +# submit_fee_paid_transaction (async routing) +# --------------------------------------------------------------------------- + + +class TestSubmitFeePaidTransaction: + @pytest.mark.asyncio + async def test_routes_to_gas_station_api_when_api_key_present(self, test_config: Any) -> None: + config = _gas_api_config(test_config) + mock_txn = _make_mock_transaction() + mock_auth = _make_mock_authenticator() + expected = PendingTransactionResponse( + hash="0xhash1", + sender="0xsender", + sequence_number="1", + max_gas_amount="200000", + gas_unit_price="100", + expiration_timestamp_secs="9999", + ) + + with patch( + "decibel._fee_pay._submit_via_gas_station_api", + new_callable=AsyncMock, + return_value=expected, + ) as mock_fn: + result = await submit_fee_paid_transaction(config, mock_txn, mock_auth) + + assert result is expected + mock_fn.assert_awaited_once() + + @pytest.mark.asyncio + async def test_routes_to_legacy_when_only_gas_station_url(self, test_config: Any) -> None: + config = _legacy_only_config(test_config) + mock_txn = _make_mock_transaction() + mock_auth = _make_mock_authenticator() + expected = PendingTransactionResponse( + hash="0xhash2", + sender="0xsender", + sequence_number="1", + max_gas_amount="200000", + gas_unit_price="100", + expiration_timestamp_secs="9999", + ) + + with patch( + "decibel._fee_pay._submit_via_legacy_fee_payer", + new_callable=AsyncMock, + return_value=expected, + ) as mock_fn: + result = await submit_fee_paid_transaction(config, mock_txn, mock_auth) + + assert result is expected + mock_fn.assert_awaited_once() + + @pytest.mark.asyncio + async def test_raises_when_neither_key_nor_url(self, test_config: Any) -> None: + config = _no_gas_config(test_config) + with pytest.raises(ValueError, match="gas_station_api_key or gas_station_url"): + await submit_fee_paid_transaction( + config, _make_mock_transaction(), _make_mock_authenticator() + ) + + +# --------------------------------------------------------------------------- +# submit_fee_paid_transaction_sync (sync routing) +# --------------------------------------------------------------------------- + + +class TestSubmitFeePaidTransactionSync: + def test_routes_to_gas_station_api_when_api_key_present(self, test_config: Any) -> None: + config = _gas_api_config(test_config) + mock_txn = _make_mock_transaction() + mock_auth = _make_mock_authenticator() + expected = PendingTransactionResponse( + hash="0xhash3", + sender="0xsender", + sequence_number="1", + max_gas_amount="200000", + gas_unit_price="100", + expiration_timestamp_secs="9999", + ) + + with patch( + "decibel._fee_pay._submit_via_gas_station_api_sync", + return_value=expected, + ) as mock_fn: + result = submit_fee_paid_transaction_sync(config, mock_txn, mock_auth) + + assert result is expected + mock_fn.assert_called_once() + + def test_routes_to_legacy_when_only_gas_station_url(self, test_config: Any) -> None: + config = _legacy_only_config(test_config) + mock_txn = _make_mock_transaction() + mock_auth = _make_mock_authenticator() + expected = PendingTransactionResponse( + hash="0xhash4", + sender="0xsender", + sequence_number="1", + max_gas_amount="200000", + gas_unit_price="100", + expiration_timestamp_secs="9999", + ) + + with patch( + "decibel._fee_pay._submit_via_legacy_fee_payer_sync", + return_value=expected, + ) as mock_fn: + result = submit_fee_paid_transaction_sync(config, mock_txn, mock_auth) + + assert result is expected + mock_fn.assert_called_once() + + def test_raises_when_neither_key_nor_url(self, test_config: Any) -> None: + config = _no_gas_config(test_config) + with pytest.raises(ValueError, match="gas_station_api_key or gas_station_url"): + submit_fee_paid_transaction_sync( + config, _make_mock_transaction(), _make_mock_authenticator() + ) + + +# --------------------------------------------------------------------------- +# _submit_via_gas_station_api (async) +# --------------------------------------------------------------------------- + + +class TestSubmitViaGasStationApi: + @pytest.mark.asyncio + async def test_success_with_provided_client(self, test_config: Any) -> None: + config = _gas_api_config(test_config) + mock_txn = _make_mock_transaction() + mock_auth = _make_mock_authenticator() + + response_data = {"transactionHash": "0xgas_station_hash"} + mock_client = AsyncMock(spec=httpx.AsyncClient) + mock_client.post = AsyncMock( + return_value=_make_httpx_response(200, json_data=response_data) + ) + + result = await _submit_via_gas_station_api(config, mock_txn, mock_auth, client=mock_client) + + assert result.hash == "0xgas_station_hash" + mock_client.post.assert_awaited_once() + + @pytest.mark.asyncio + async def test_success_uses_hash_fallback(self, test_config: Any) -> None: + """When 'transactionHash' absent, falls back to 'hash'.""" + config = _gas_api_config(test_config) + mock_txn = _make_mock_transaction() + mock_auth = _make_mock_authenticator() + + response_data = {"hash": "0xfallback_hash"} + mock_client = AsyncMock(spec=httpx.AsyncClient) + mock_client.post = AsyncMock( + return_value=_make_httpx_response(200, json_data=response_data) + ) + + result = await _submit_via_gas_station_api(config, mock_txn, mock_auth, client=mock_client) + assert result.hash == "0xfallback_hash" + + @pytest.mark.asyncio + async def test_error_raises_value_error(self, test_config: Any) -> None: + config = _gas_api_config(test_config) + mock_txn = _make_mock_transaction() + mock_auth = _make_mock_authenticator() + + mock_client = AsyncMock(spec=httpx.AsyncClient) + mock_client.post = AsyncMock(return_value=_make_httpx_response(400, text="Bad Request")) + + with pytest.raises(ValueError, match="Gas station API error"): + await _submit_via_gas_station_api(config, mock_txn, mock_auth, client=mock_client) + + @pytest.mark.asyncio + async def test_without_client_creates_temp_client(self, test_config: Any) -> None: + config = _gas_api_config(test_config) + mock_txn = _make_mock_transaction() + mock_auth = _make_mock_authenticator() + + response_data = {"transactionHash": "0xtemp"} + mock_temp = AsyncMock(spec=httpx.AsyncClient) + mock_temp.post = AsyncMock(return_value=_make_httpx_response(200, json_data=response_data)) + mock_temp.__aenter__ = AsyncMock(return_value=mock_temp) + mock_temp.__aexit__ = AsyncMock(return_value=None) + + with patch("httpx.AsyncClient", return_value=mock_temp): + result = await _submit_via_gas_station_api(config, mock_txn, mock_auth, client=None) + + assert result.hash == "0xtemp" + + @pytest.mark.asyncio + async def test_sends_authorization_header(self, test_config: Any) -> None: + config = _gas_api_config(test_config) + mock_txn = _make_mock_transaction() + mock_auth = _make_mock_authenticator() + + mock_client = AsyncMock(spec=httpx.AsyncClient) + mock_client.post = AsyncMock( + return_value=_make_httpx_response(200, json_data={"transactionHash": "0xok"}) + ) + + await _submit_via_gas_station_api(config, mock_txn, mock_auth, client=mock_client) + + call_kwargs = mock_client.post.call_args.kwargs + assert call_kwargs["headers"]["Authorization"] == f"Bearer {config.gas_station_api_key}" + + @pytest.mark.asyncio + async def test_with_fee_payer_address_serialised(self, test_config: Any) -> None: + config = _gas_api_config(test_config) + mock_txn = _make_mock_transaction() + # Simulate a fee_payer_address being set + mock_fee_payer_addr = MagicMock() + mock_fee_payer_addr.serialize = MagicMock() + mock_txn.fee_payer_address = mock_fee_payer_addr + mock_auth = _make_mock_authenticator() + + mock_client = AsyncMock(spec=httpx.AsyncClient) + mock_client.post = AsyncMock( + return_value=_make_httpx_response(200, json_data={"transactionHash": "0xfp"}) + ) + + result = await _submit_via_gas_station_api(config, mock_txn, mock_auth, client=mock_client) + assert result.hash == "0xfp" + mock_fee_payer_addr.serialize.assert_called() + + @pytest.mark.asyncio + async def test_response_has_correct_sender_fields(self, test_config: Any) -> None: + config = _gas_api_config(test_config) + mock_txn = _make_mock_transaction() + mock_auth = _make_mock_authenticator() + + mock_client = AsyncMock(spec=httpx.AsyncClient) + mock_client.post = AsyncMock( + return_value=_make_httpx_response(200, json_data={"transactionHash": "0xh"}) + ) + + result = await _submit_via_gas_station_api(config, mock_txn, mock_auth, client=mock_client) + + assert result.sequence_number == str(mock_txn.raw_transaction.sequence_number) + assert result.max_gas_amount == str(mock_txn.raw_transaction.max_gas_amount) + assert result.gas_unit_price == str(mock_txn.raw_transaction.gas_unit_price) + + +# --------------------------------------------------------------------------- +# _submit_via_gas_station_api_sync +# --------------------------------------------------------------------------- + + +class TestSubmitViaGasStationApiSync: + def test_success_with_provided_client(self, test_config: Any) -> None: + config = _gas_api_config(test_config) + mock_txn = _make_mock_transaction() + mock_auth = _make_mock_authenticator() + + response_data = {"transactionHash": "0xsync_gs_hash"} + mock_client = MagicMock(spec=httpx.Client) + mock_client.post = MagicMock( + return_value=_make_httpx_response(200, json_data=response_data) + ) + + result = _submit_via_gas_station_api_sync(config, mock_txn, mock_auth, client=mock_client) + assert result.hash == "0xsync_gs_hash" + + def test_success_hash_fallback(self, test_config: Any) -> None: + config = _gas_api_config(test_config) + mock_txn = _make_mock_transaction() + mock_auth = _make_mock_authenticator() + + mock_client = MagicMock(spec=httpx.Client) + mock_client.post = MagicMock( + return_value=_make_httpx_response(200, json_data={"hash": "0xsync_fallback"}) + ) + + result = _submit_via_gas_station_api_sync(config, mock_txn, mock_auth, client=mock_client) + assert result.hash == "0xsync_fallback" + + def test_error_raises_value_error(self, test_config: Any) -> None: + config = _gas_api_config(test_config) + mock_txn = _make_mock_transaction() + mock_auth = _make_mock_authenticator() + + mock_client = MagicMock(spec=httpx.Client) + mock_client.post = MagicMock( + return_value=_make_httpx_response(500, text="Internal Server Error") + ) + + with pytest.raises(ValueError, match="Gas station API error"): + _submit_via_gas_station_api_sync(config, mock_txn, mock_auth, client=mock_client) + + def test_without_client_creates_temp_client(self, test_config: Any) -> None: + config = _gas_api_config(test_config) + mock_txn = _make_mock_transaction() + mock_auth = _make_mock_authenticator() + + response_data = {"transactionHash": "0xsync_temp"} + mock_temp = MagicMock(spec=httpx.Client) + mock_temp.post = MagicMock(return_value=_make_httpx_response(200, json_data=response_data)) + mock_temp.__enter__ = MagicMock(return_value=mock_temp) + mock_temp.__exit__ = MagicMock(return_value=None) + + with patch("httpx.Client", return_value=mock_temp): + result = _submit_via_gas_station_api_sync(config, mock_txn, mock_auth, client=None) + + assert result.hash == "0xsync_temp" + + def test_sends_authorization_header(self, test_config: Any) -> None: + config = _gas_api_config(test_config) + mock_txn = _make_mock_transaction() + mock_auth = _make_mock_authenticator() + + mock_client = MagicMock(spec=httpx.Client) + mock_client.post = MagicMock( + return_value=_make_httpx_response(200, json_data={"transactionHash": "0xok"}) + ) + + _submit_via_gas_station_api_sync(config, mock_txn, mock_auth, client=mock_client) + + call_kwargs = mock_client.post.call_args.kwargs + assert call_kwargs["headers"]["Authorization"] == f"Bearer {config.gas_station_api_key}" + + def test_with_fee_payer_address_serialised(self, test_config: Any) -> None: + """Covers the fee_payer_address.serialize branch in _submit_via_gas_station_api_sync.""" + config = _gas_api_config(test_config) + mock_txn = _make_mock_transaction() + # Set a fee_payer_address that is not None + mock_fee_payer_addr = MagicMock() + mock_fee_payer_addr.serialize = MagicMock() + mock_txn.fee_payer_address = mock_fee_payer_addr + mock_auth = _make_mock_authenticator() + + mock_client = MagicMock(spec=httpx.Client) + mock_client.post = MagicMock( + return_value=_make_httpx_response(200, json_data={"transactionHash": "0xfp_sync"}) + ) + + result = _submit_via_gas_station_api_sync(config, mock_txn, mock_auth, client=mock_client) + assert result.hash == "0xfp_sync" + mock_fee_payer_addr.serialize.assert_called() + + +# --------------------------------------------------------------------------- +# _submit_via_legacy_fee_payer (async) +# --------------------------------------------------------------------------- + + +class TestSubmitViaLegacyFeePayer: + @pytest.mark.asyncio + async def test_success_with_provided_client(self, test_config: Any) -> None: + config = _legacy_only_config(test_config) + mock_txn = _make_mock_transaction() + mock_auth = _make_mock_authenticator() + + response_data = { + "hash": "0xlegacy_hash", + "sender": "0xsender", + "sequence_number": "1", + "max_gas_amount": "200000", + "gas_unit_price": "100", + "expiration_timestamp_secs": "9999", + } + mock_client = AsyncMock(spec=httpx.AsyncClient) + mock_client.post = AsyncMock( + return_value=_make_httpx_response(200, json_data=response_data) + ) + + result = await _submit_via_legacy_fee_payer(config, mock_txn, mock_auth, client=mock_client) + + assert result.hash == "0xlegacy_hash" + assert result.sender == "0xsender" + + @pytest.mark.asyncio + async def test_error_raises_value_error(self, test_config: Any) -> None: + config = _legacy_only_config(test_config) + mock_txn = _make_mock_transaction() + mock_auth = _make_mock_authenticator() + + mock_client = AsyncMock(spec=httpx.AsyncClient) + mock_client.post = AsyncMock(return_value=_make_httpx_response(400, text="Bad request")) + + with pytest.raises(ValueError, match="Fee payer error"): + await _submit_via_legacy_fee_payer(config, mock_txn, mock_auth, client=mock_client) + + @pytest.mark.asyncio + async def test_without_client_creates_temp_client(self, test_config: Any) -> None: + config = _legacy_only_config(test_config) + mock_txn = _make_mock_transaction() + mock_auth = _make_mock_authenticator() + + response_data = { + "hash": "0xtemp_legacy", + "sender": "0xsender", + "sequence_number": "1", + "max_gas_amount": "200000", + "gas_unit_price": "100", + "expiration_timestamp_secs": "9999", + } + mock_temp = AsyncMock(spec=httpx.AsyncClient) + mock_temp.post = AsyncMock(return_value=_make_httpx_response(200, json_data=response_data)) + mock_temp.__aenter__ = AsyncMock(return_value=mock_temp) + mock_temp.__aexit__ = AsyncMock(return_value=None) + + with patch("httpx.AsyncClient", return_value=mock_temp): + result = await _submit_via_legacy_fee_payer(config, mock_txn, mock_auth, client=None) + + assert result.hash == "0xtemp_legacy" + + @pytest.mark.asyncio + async def test_posts_to_correct_url(self, test_config: Any) -> None: + config = _legacy_only_config(test_config) + mock_txn = _make_mock_transaction() + mock_auth = _make_mock_authenticator() + + response_data = { + "hash": "0xurl_check", + "sender": "", + "sequence_number": "", + "max_gas_amount": "", + "gas_unit_price": "", + "expiration_timestamp_secs": "", + } + mock_client = AsyncMock(spec=httpx.AsyncClient) + mock_client.post = AsyncMock( + return_value=_make_httpx_response(200, json_data=response_data) + ) + + await _submit_via_legacy_fee_payer(config, mock_txn, mock_auth, client=mock_client) + + call_args = mock_client.post.call_args + assert call_args.args[0] == f"{config.gas_station_url}/transactions" + + +# --------------------------------------------------------------------------- +# _submit_via_legacy_fee_payer_sync +# --------------------------------------------------------------------------- + + +class TestSubmitViaLegacyFeePayerSync: + def test_success_with_provided_client(self, test_config: Any) -> None: + config = _legacy_only_config(test_config) + mock_txn = _make_mock_transaction() + mock_auth = _make_mock_authenticator() + + response_data = { + "hash": "0xsync_legacy", + "sender": "0xsender", + "sequence_number": "1", + "max_gas_amount": "200000", + "gas_unit_price": "100", + "expiration_timestamp_secs": "9999", + } + mock_client = MagicMock(spec=httpx.Client) + mock_client.post = MagicMock( + return_value=_make_httpx_response(200, json_data=response_data) + ) + + result = _submit_via_legacy_fee_payer_sync(config, mock_txn, mock_auth, client=mock_client) + assert result.hash == "0xsync_legacy" + + def test_error_raises_value_error(self, test_config: Any) -> None: + config = _legacy_only_config(test_config) + mock_txn = _make_mock_transaction() + mock_auth = _make_mock_authenticator() + + mock_client = MagicMock(spec=httpx.Client) + mock_client.post = MagicMock(return_value=_make_httpx_response(500, text="Server Error")) + + with pytest.raises(ValueError, match="Fee payer error"): + _submit_via_legacy_fee_payer_sync(config, mock_txn, mock_auth, client=mock_client) + + def test_without_client_creates_temp_client(self, test_config: Any) -> None: + config = _legacy_only_config(test_config) + mock_txn = _make_mock_transaction() + mock_auth = _make_mock_authenticator() + + response_data = { + "hash": "0xsync_temp_legacy", + "sender": "", + "sequence_number": "", + "max_gas_amount": "", + "gas_unit_price": "", + "expiration_timestamp_secs": "", + } + mock_temp = MagicMock(spec=httpx.Client) + mock_temp.post = MagicMock(return_value=_make_httpx_response(200, json_data=response_data)) + mock_temp.__enter__ = MagicMock(return_value=mock_temp) + mock_temp.__exit__ = MagicMock(return_value=None) + + with patch("httpx.Client", return_value=mock_temp): + result = _submit_via_legacy_fee_payer_sync(config, mock_txn, mock_auth, client=None) + + assert result.hash == "0xsync_temp_legacy" + + def test_posts_to_correct_url(self, test_config: Any) -> None: + config = _legacy_only_config(test_config) + mock_txn = _make_mock_transaction() + mock_auth = _make_mock_authenticator() + + response_data = { + "hash": "0x", + "sender": "", + "sequence_number": "", + "max_gas_amount": "", + "gas_unit_price": "", + "expiration_timestamp_secs": "", + } + mock_client = MagicMock(spec=httpx.Client) + mock_client.post = MagicMock( + return_value=_make_httpx_response(200, json_data=response_data) + ) + + _submit_via_legacy_fee_payer_sync(config, mock_txn, mock_auth, client=mock_client) + + call_args = mock_client.post.call_args + assert call_args.args[0] == f"{config.gas_station_url}/transactions" + + +# --------------------------------------------------------------------------- +# _get_default_gas_station_url +# --------------------------------------------------------------------------- + + +class TestGetDefaultGasStationUrl: + def test_testnet_returns_testnet_url(self, test_config: Any) -> None: + config = replace(test_config, network=Network.TESTNET) + url = _get_default_gas_station_url(config) + assert "testnet" in url + assert "aptoslabs" in url + + def test_chain_id_208_returns_netna_url(self, test_config: Any) -> None: + config = replace(test_config, network=Network.CUSTOM, chain_id=208) + url = _get_default_gas_station_url(config) + assert "netna" in url + + def test_custom_network_with_gas_station_url_returns_it(self, test_config: Any) -> None: + custom_url = "https://my-custom-gas-station.example.com/v1" + config = replace( + test_config, + network=Network.CUSTOM, + chain_id=999, + gas_station_url=custom_url, + ) + url = _get_default_gas_station_url(config) + assert url == custom_url + + def test_custom_network_without_gas_station_url_raises(self, test_config: Any) -> None: + config = replace( + test_config, + network=Network.CUSTOM, + chain_id=999, + gas_station_url=None, + ) + with pytest.raises(ValueError, match="gas_station_url must be provided"): + _get_default_gas_station_url(config) + + def test_mainnet_without_gas_station_url_raises(self, test_config: Any) -> None: + config = replace( + test_config, + network=Network.MAINNET, + chain_id=1, + gas_station_url=None, + ) + # MAINNET is not explicitly handled, falls through to gas_station_url check + with pytest.raises(ValueError, match="gas_station_url must be provided"): + _get_default_gas_station_url(config) + + def test_mainnet_with_gas_station_url_returns_it(self, test_config: Any) -> None: + mainnet_gs_url = "https://api.mainnet.aptoslabs.com/gs/v1" + config = replace( + test_config, + network=Network.MAINNET, + chain_id=1, + gas_station_url=mainnet_gs_url, + ) + url = _get_default_gas_station_url(config) + assert url == mainnet_gs_url diff --git a/tests/test_gas_price_manager.py b/tests/test_gas_price_manager.py new file mode 100644 index 0000000..1b36f8c --- /dev/null +++ b/tests/test_gas_price_manager.py @@ -0,0 +1,676 @@ +"""Tests for decibel._gas_price_manager module.""" + +from __future__ import annotations + +import asyncio +import threading +import time +from unittest.mock import AsyncMock, MagicMock, patch + +import httpx +import pytest + +from decibel._gas_price_manager import ( + GasPriceInfo, + GasPriceManager, + GasPriceManagerOptions, + GasPriceManagerSync, +) + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_mock_response(gas_estimate: int = 100, is_success: bool = True) -> MagicMock: + mock_response = MagicMock() + mock_response.is_success = is_success + mock_response.json.return_value = {"gas_estimate": gas_estimate} + mock_response.status_code = 200 if is_success else 500 + mock_response.text = "" if is_success else "Internal Server Error" + return mock_response + + +# --------------------------------------------------------------------------- +# GasPriceManagerOptions +# --------------------------------------------------------------------------- + + +class TestGasPriceManagerOptions: + def test_default_values(self) -> None: + opts = GasPriceManagerOptions() + assert opts.node_api_key is None + assert opts.multiplier == 2.0 + assert opts.refresh_interval_seconds == 60.0 + assert opts.http_client is None + assert opts.http_client_sync is None + + def test_custom_values(self) -> None: + async_client = AsyncMock(spec=httpx.AsyncClient) + sync_client = MagicMock(spec=httpx.Client) + opts = GasPriceManagerOptions( + node_api_key="my-key", + multiplier=3.0, + refresh_interval_seconds=30.0, + http_client=async_client, + http_client_sync=sync_client, + ) + assert opts.node_api_key == "my-key" + assert opts.multiplier == 3.0 + assert opts.refresh_interval_seconds == 30.0 + assert opts.http_client is async_client + assert opts.http_client_sync is sync_client + + def test_http_client_fields_accept_none(self) -> None: + opts = GasPriceManagerOptions(http_client=None, http_client_sync=None) + assert opts.http_client is None + assert opts.http_client_sync is None + + +# --------------------------------------------------------------------------- +# GasPriceInfo +# --------------------------------------------------------------------------- + + +class TestGasPriceInfo: + def test_stores_fields(self) -> None: + info = GasPriceInfo(gas_estimate=200, timestamp=12345.0) + assert info.gas_estimate == 200 + assert info.timestamp == 12345.0 + + +# --------------------------------------------------------------------------- +# GasPriceManager (async) +# --------------------------------------------------------------------------- + + +class TestGasPriceManagerInit: + def test_default_state(self, test_config: object) -> None: + mgr = GasPriceManager(test_config) # type: ignore[arg-type] + assert mgr._gas_price is None + assert mgr._refresh_task is None + assert mgr._pending_refresh_task is None + assert not mgr._is_initialized + assert mgr._multiplier == 2.0 + assert mgr._refresh_interval_seconds == 60.0 + assert mgr._http_client is None + + def test_stores_http_client_from_opts(self, test_config: object) -> None: + client = AsyncMock(spec=httpx.AsyncClient) + opts = GasPriceManagerOptions(http_client=client) + mgr = GasPriceManager(test_config, opts=opts) # type: ignore[arg-type] + assert mgr._http_client is client + + def test_get_gas_price_returns_none_when_not_set(self, test_config: object) -> None: + mgr = GasPriceManager(test_config) # type: ignore[arg-type] + assert mgr.get_gas_price() is None + assert mgr.gas_price is None + + def test_is_initialized_false_by_default(self, test_config: object) -> None: + mgr = GasPriceManager(test_config) # type: ignore[arg-type] + assert not mgr.is_initialized + + +class TestGasPriceManagerFetchGasPriceEstimation: + async def test_with_shared_client(self, test_config: object) -> None: + mock_response = _make_mock_response(gas_estimate=100) + client = AsyncMock(spec=httpx.AsyncClient) + client.get = AsyncMock(return_value=mock_response) + + opts = GasPriceManagerOptions(http_client=client, multiplier=1.0) + mgr = GasPriceManager(test_config, opts=opts) # type: ignore[arg-type] + + result = await mgr.fetch_gas_price_estimation() + + assert result == 100 + client.get.assert_called_once() + call_kwargs = client.get.call_args + assert call_kwargs.kwargs["timeout"] == 5.0 + + async def test_without_client_creates_temp(self, test_config: object) -> None: + mock_response = _make_mock_response(gas_estimate=50) + mock_temp_client = AsyncMock() + mock_temp_client.get = AsyncMock(return_value=mock_response) + mock_temp_client.__aenter__ = AsyncMock(return_value=mock_temp_client) + mock_temp_client.__aexit__ = AsyncMock(return_value=None) + + opts = GasPriceManagerOptions(multiplier=1.0) + mgr = GasPriceManager(test_config, opts=opts) # type: ignore[arg-type] + + with patch("decibel._gas_price_manager.httpx.AsyncClient", return_value=mock_temp_client): + result = await mgr.fetch_gas_price_estimation() + + assert result == 50 + + async def test_applies_multiplier(self, test_config: object) -> None: + mock_response = _make_mock_response(gas_estimate=100) + client = AsyncMock(spec=httpx.AsyncClient) + client.get = AsyncMock(return_value=mock_response) + + opts = GasPriceManagerOptions(http_client=client, multiplier=2.0) + mgr = GasPriceManager(test_config, opts=opts) # type: ignore[arg-type] + + result = await mgr.fetch_gas_price_estimation() + assert result == 200 + + async def test_error_response_raises(self, test_config: object) -> None: + mock_response = _make_mock_response(is_success=False) + client = AsyncMock(spec=httpx.AsyncClient) + client.get = AsyncMock(return_value=mock_response) + + opts = GasPriceManagerOptions(http_client=client) + mgr = GasPriceManager(test_config, opts=opts) # type: ignore[arg-type] + + with pytest.raises(ValueError, match="Failed to fetch gas price"): + await mgr.fetch_gas_price_estimation() + + async def test_includes_auth_header_when_api_key_set(self, test_config: object) -> None: + mock_response = _make_mock_response(gas_estimate=100) + client = AsyncMock(spec=httpx.AsyncClient) + client.get = AsyncMock(return_value=mock_response) + + opts = GasPriceManagerOptions(http_client=client, node_api_key="secret", multiplier=1.0) + mgr = GasPriceManager(test_config, opts=opts) # type: ignore[arg-type] + + await mgr.fetch_gas_price_estimation() + + call_kwargs = client.get.call_args.kwargs + assert call_kwargs["headers"] == {"x-api-key": "secret"} + + async def test_no_auth_header_when_no_api_key(self, test_config: object) -> None: + mock_response = _make_mock_response(gas_estimate=100) + client = AsyncMock(spec=httpx.AsyncClient) + client.get = AsyncMock(return_value=mock_response) + + opts = GasPriceManagerOptions(http_client=client, node_api_key=None, multiplier=1.0) + mgr = GasPriceManager(test_config, opts=opts) # type: ignore[arg-type] + + await mgr.fetch_gas_price_estimation() + + call_kwargs = client.get.call_args.kwargs + assert call_kwargs["headers"] == {} + + +class TestGasPriceManagerFetchAndSet: + async def test_sets_gas_price_on_success(self, test_config: object) -> None: + mock_response = _make_mock_response(gas_estimate=100) + client = AsyncMock(spec=httpx.AsyncClient) + client.get = AsyncMock(return_value=mock_response) + + opts = GasPriceManagerOptions(http_client=client, multiplier=1.0) + mgr = GasPriceManager(test_config, opts=opts) # type: ignore[arg-type] + + result = await mgr.fetch_and_set_gas_price() + + assert result == 100 + assert mgr._gas_price is not None + assert mgr._gas_price.gas_estimate == 100 + + async def test_raises_on_zero_estimate(self, test_config: object) -> None: + mock_response = _make_mock_response(gas_estimate=0) + client = AsyncMock(spec=httpx.AsyncClient) + client.get = AsyncMock(return_value=mock_response) + + opts = GasPriceManagerOptions(http_client=client, multiplier=1.0) + mgr = GasPriceManager(test_config, opts=opts) # type: ignore[arg-type] + + with pytest.raises(ValueError, match="no gas estimate"): + await mgr.fetch_and_set_gas_price() + + async def test_raises_and_logs_on_error(self, test_config: object) -> None: + client = AsyncMock(spec=httpx.AsyncClient) + client.get = AsyncMock(side_effect=ConnectionError("network failure")) + + opts = GasPriceManagerOptions(http_client=client) + mgr = GasPriceManager(test_config, opts=opts) # type: ignore[arg-type] + + with pytest.raises(ConnectionError): + await mgr.fetch_and_set_gas_price() + + +class TestGasPriceManagerGetGasPrice: + async def test_returns_none_when_not_set(self, test_config: object) -> None: + mgr = GasPriceManager(test_config) # type: ignore[arg-type] + assert mgr.get_gas_price() is None + + async def test_returns_gas_estimate_when_set(self, test_config: object) -> None: + mgr = GasPriceManager(test_config) # type: ignore[arg-type] + mgr._gas_price = GasPriceInfo(gas_estimate=500, timestamp=time.time()) + assert mgr.get_gas_price() == 500 + + +class TestGasPriceManagerInitialize: + async def test_initialize_calls_fetch_and_creates_task(self, test_config: object) -> None: + mock_response = _make_mock_response(gas_estimate=100) + client = AsyncMock(spec=httpx.AsyncClient) + client.get = AsyncMock(return_value=mock_response) + + opts = GasPriceManagerOptions(http_client=client, multiplier=1.0) + mgr = GasPriceManager(test_config, opts=opts) # type: ignore[arg-type] + + await mgr.initialize() + + assert mgr._is_initialized + assert mgr._refresh_task is not None + mgr._refresh_task.cancel() + + async def test_already_initialized_is_noop(self, test_config: object) -> None: + mock_response = _make_mock_response(gas_estimate=100) + client = AsyncMock(spec=httpx.AsyncClient) + client.get = AsyncMock(return_value=mock_response) + + opts = GasPriceManagerOptions(http_client=client, multiplier=1.0) + mgr = GasPriceManager(test_config, opts=opts) # type: ignore[arg-type] + + await mgr.initialize() + first_task = mgr._refresh_task + call_count = client.get.call_count + + await mgr.initialize() + + assert mgr._refresh_task is first_task + assert client.get.call_count == call_count + mgr._refresh_task.cancel() + + async def test_initialize_logs_on_failure(self, test_config: object) -> None: + client = AsyncMock(spec=httpx.AsyncClient) + client.get = AsyncMock(side_effect=RuntimeError("boom")) + + opts = GasPriceManagerOptions(http_client=client) + mgr = GasPriceManager(test_config, opts=opts) # type: ignore[arg-type] + + # Should not raise; logs instead + await mgr.initialize() + assert not mgr._is_initialized + + +class TestGasPriceManagerDestroy: + async def test_destroy_cancels_task_and_clears_state(self, test_config: object) -> None: + mock_response = _make_mock_response(gas_estimate=100) + client = AsyncMock(spec=httpx.AsyncClient) + client.get = AsyncMock(return_value=mock_response) + + opts = GasPriceManagerOptions(http_client=client, multiplier=1.0) + mgr = GasPriceManager(test_config, opts=opts) # type: ignore[arg-type] + + await mgr.initialize() + assert mgr._is_initialized + + await mgr.destroy() + + assert not mgr._is_initialized + assert mgr._gas_price is None + assert mgr._refresh_task is None + + async def test_destroy_with_no_task_is_safe(self, test_config: object) -> None: + mgr = GasPriceManager(test_config) # type: ignore[arg-type] + await mgr.destroy() # Should not raise + + +class TestGasPriceManagerRefresh: + async def test_refresh_creates_pending_task(self, test_config: object) -> None: + mock_response = _make_mock_response(gas_estimate=100) + client = AsyncMock(spec=httpx.AsyncClient) + client.get = AsyncMock(return_value=mock_response) + + opts = GasPriceManagerOptions(http_client=client, multiplier=1.0) + mgr = GasPriceManager(test_config, opts=opts) # type: ignore[arg-type] + + mgr.refresh() + assert mgr._pending_refresh_task is not None + # Let it complete + await asyncio.sleep(0) + + async def test_refresh_noop_when_task_already_pending(self, test_config: object) -> None: + mock_response = _make_mock_response(gas_estimate=100) + client = AsyncMock(spec=httpx.AsyncClient) + client.get = AsyncMock(return_value=mock_response) + + opts = GasPriceManagerOptions(http_client=client, multiplier=1.0) + mgr = GasPriceManager(test_config, opts=opts) # type: ignore[arg-type] + + mgr.refresh() + first_task = mgr._pending_refresh_task + mgr.refresh() + + # Should be the same task (not done yet) + assert mgr._pending_refresh_task is first_task + await asyncio.sleep(0) + + +class TestGasPriceManagerRefreshLoop: + async def test_refresh_loop_calls_fetch_periodically(self, test_config: object) -> None: + call_count = 0 + + async def fake_fetch_and_set() -> int: + nonlocal call_count + call_count += 1 + return 100 + + mgr = GasPriceManager(test_config) # type: ignore[arg-type] + mgr._refresh_interval_seconds = 0.01 + mgr.fetch_and_set_gas_price = fake_fetch_and_set # type: ignore[method-assign] + + task = asyncio.create_task(mgr._refresh_loop()) + await asyncio.sleep(0.05) + task.cancel() + with pytest.raises(asyncio.CancelledError): + await task + + assert call_count >= 1 + + async def test_refresh_loop_continues_on_exception(self, test_config: object) -> None: + call_count = 0 + + async def flaky_fetch() -> int: + nonlocal call_count + call_count += 1 + raise RuntimeError("transient error") + + mgr = GasPriceManager(test_config) # type: ignore[arg-type] + mgr._refresh_interval_seconds = 0.01 + mgr.fetch_and_set_gas_price = flaky_fetch # type: ignore[method-assign] + + task = asyncio.create_task(mgr._refresh_loop()) + await asyncio.sleep(0.05) + task.cancel() + with pytest.raises(asyncio.CancelledError): + await task + + assert call_count >= 1 + + +class TestGasPriceManagerContextManager: + async def test_aenter_calls_initialize(self, test_config: object) -> None: + mock_response = _make_mock_response(gas_estimate=100) + client = AsyncMock(spec=httpx.AsyncClient) + client.get = AsyncMock(return_value=mock_response) + + opts = GasPriceManagerOptions(http_client=client, multiplier=1.0) + mgr = GasPriceManager(test_config, opts=opts) # type: ignore[arg-type] + + async with mgr as ctx: + assert ctx is mgr + assert mgr._is_initialized + + async def test_aexit_calls_destroy(self, test_config: object) -> None: + mock_response = _make_mock_response(gas_estimate=100) + client = AsyncMock(spec=httpx.AsyncClient) + client.get = AsyncMock(return_value=mock_response) + + opts = GasPriceManagerOptions(http_client=client, multiplier=1.0) + mgr = GasPriceManager(test_config, opts=opts) # type: ignore[arg-type] + + async with mgr: + pass + + assert not mgr._is_initialized + assert mgr._gas_price is None + + +# --------------------------------------------------------------------------- +# GasPriceManagerSync (threaded) +# --------------------------------------------------------------------------- + + +class TestGasPriceManagerSyncInit: + def test_default_state(self, test_config: object) -> None: + mgr = GasPriceManagerSync(test_config) # type: ignore[arg-type] + assert mgr._gas_price is None + assert mgr._refresh_thread is None + assert not mgr._is_initialized + assert mgr._multiplier == 2.0 + assert mgr._http_client is None + + def test_stores_http_client_sync_from_opts(self, test_config: object) -> None: + client = MagicMock(spec=httpx.Client) + opts = GasPriceManagerOptions(http_client_sync=client) + mgr = GasPriceManagerSync(test_config, opts=opts) # type: ignore[arg-type] + assert mgr._http_client is client + + +class TestGasPriceManagerSyncFetchGasPriceEstimation: + def test_with_shared_client(self, test_config: object) -> None: + mock_response = _make_mock_response(gas_estimate=100) + client = MagicMock(spec=httpx.Client) + client.get = MagicMock(return_value=mock_response) + + opts = GasPriceManagerOptions(http_client_sync=client, multiplier=1.0) + mgr = GasPriceManagerSync(test_config, opts=opts) # type: ignore[arg-type] + + result = mgr.fetch_gas_price_estimation() + + assert result == 100 + client.get.assert_called_once() + call_kwargs = client.get.call_args.kwargs + assert call_kwargs["timeout"] == 5.0 + + def test_without_client_creates_temp(self, test_config: object) -> None: + mock_response = _make_mock_response(gas_estimate=75) + mock_temp_client = MagicMock() + mock_temp_client.get = MagicMock(return_value=mock_response) + mock_temp_client.__enter__ = MagicMock(return_value=mock_temp_client) + mock_temp_client.__exit__ = MagicMock(return_value=None) + + opts = GasPriceManagerOptions(multiplier=1.0) + mgr = GasPriceManagerSync(test_config, opts=opts) # type: ignore[arg-type] + + with patch("decibel._gas_price_manager.httpx.Client", return_value=mock_temp_client): + result = mgr.fetch_gas_price_estimation() + + assert result == 75 + + def test_applies_multiplier(self, test_config: object) -> None: + mock_response = _make_mock_response(gas_estimate=100) + client = MagicMock(spec=httpx.Client) + client.get = MagicMock(return_value=mock_response) + + opts = GasPriceManagerOptions(http_client_sync=client, multiplier=3.0) + mgr = GasPriceManagerSync(test_config, opts=opts) # type: ignore[arg-type] + + result = mgr.fetch_gas_price_estimation() + assert result == 300 + + def test_error_response_raises(self, test_config: object) -> None: + mock_response = _make_mock_response(is_success=False) + client = MagicMock(spec=httpx.Client) + client.get = MagicMock(return_value=mock_response) + + opts = GasPriceManagerOptions(http_client_sync=client) + mgr = GasPriceManagerSync(test_config, opts=opts) # type: ignore[arg-type] + + with pytest.raises(ValueError, match="Failed to fetch gas price"): + mgr.fetch_gas_price_estimation() + + +class TestGasPriceManagerSyncFetchAndSet: + def test_sets_gas_price_on_success(self, test_config: object) -> None: + mock_response = _make_mock_response(gas_estimate=150) + client = MagicMock(spec=httpx.Client) + client.get = MagicMock(return_value=mock_response) + + opts = GasPriceManagerOptions(http_client_sync=client, multiplier=1.0) + mgr = GasPriceManagerSync(test_config, opts=opts) # type: ignore[arg-type] + + result = mgr.fetch_and_set_gas_price() + + assert result == 150 + assert mgr._gas_price is not None + assert mgr._gas_price.gas_estimate == 150 + + def test_raises_on_zero_estimate(self, test_config: object) -> None: + mock_response = _make_mock_response(gas_estimate=0) + client = MagicMock(spec=httpx.Client) + client.get = MagicMock(return_value=mock_response) + + opts = GasPriceManagerOptions(http_client_sync=client, multiplier=1.0) + mgr = GasPriceManagerSync(test_config, opts=opts) # type: ignore[arg-type] + + with pytest.raises(ValueError, match="no gas estimate"): + mgr.fetch_and_set_gas_price() + + def test_raises_on_network_error(self, test_config: object) -> None: + client = MagicMock(spec=httpx.Client) + client.get = MagicMock(side_effect=ConnectionError("refused")) + + opts = GasPriceManagerOptions(http_client_sync=client) + mgr = GasPriceManagerSync(test_config, opts=opts) # type: ignore[arg-type] + + with pytest.raises(ConnectionError): + mgr.fetch_and_set_gas_price() + + +class TestGasPriceManagerSyncGetGasPrice: + def test_returns_none_when_not_set(self, test_config: object) -> None: + mgr = GasPriceManagerSync(test_config) # type: ignore[arg-type] + assert mgr.get_gas_price() is None + + def test_returns_gas_estimate_when_set(self, test_config: object) -> None: + mgr = GasPriceManagerSync(test_config) # type: ignore[arg-type] + mgr._gas_price = GasPriceInfo(gas_estimate=999, timestamp=time.time()) + assert mgr.get_gas_price() == 999 + + +class TestGasPriceManagerSyncInitialize: + def test_initialize_starts_daemon_thread(self, test_config: object) -> None: + mock_response = _make_mock_response(gas_estimate=100) + client = MagicMock(spec=httpx.Client) + client.get = MagicMock(return_value=mock_response) + + opts = GasPriceManagerOptions(http_client_sync=client, multiplier=1.0) + mgr = GasPriceManagerSync(test_config, opts=opts) # type: ignore[arg-type] + + mgr.initialize() + + assert mgr._is_initialized + assert mgr._refresh_thread is not None + assert mgr._refresh_thread.daemon + + mgr.destroy() + + def test_already_initialized_is_noop(self, test_config: object) -> None: + mock_response = _make_mock_response(gas_estimate=100) + client = MagicMock(spec=httpx.Client) + client.get = MagicMock(return_value=mock_response) + + opts = GasPriceManagerOptions(http_client_sync=client, multiplier=1.0) + mgr = GasPriceManagerSync(test_config, opts=opts) # type: ignore[arg-type] + + mgr.initialize() + first_thread = mgr._refresh_thread + call_count = client.get.call_count + + mgr.initialize() + + assert mgr._refresh_thread is first_thread + assert client.get.call_count == call_count + + mgr.destroy() + + def test_initialize_logs_on_failure(self, test_config: object) -> None: + client = MagicMock(spec=httpx.Client) + client.get = MagicMock(side_effect=RuntimeError("fail")) + + opts = GasPriceManagerOptions(http_client_sync=client) + mgr = GasPriceManagerSync(test_config, opts=opts) # type: ignore[arg-type] + + # Should not raise; logs instead + mgr.initialize() + assert not mgr._is_initialized + + +class TestGasPriceManagerSyncDestroy: + def test_destroy_sets_stop_event_and_joins_thread(self, test_config: object) -> None: + mock_response = _make_mock_response(gas_estimate=100) + client = MagicMock(spec=httpx.Client) + client.get = MagicMock(return_value=mock_response) + + opts = GasPriceManagerOptions(http_client_sync=client, multiplier=1.0) + mgr = GasPriceManagerSync(test_config, opts=opts) # type: ignore[arg-type] + + mgr.initialize() + assert mgr._is_initialized + + mgr.destroy() + + assert not mgr._is_initialized + assert mgr._gas_price is None + assert mgr._refresh_thread is None + + def test_destroy_with_no_thread_is_safe(self, test_config: object) -> None: + mgr = GasPriceManagerSync(test_config) # type: ignore[arg-type] + mgr.destroy() # Should not raise + + +class TestGasPriceManagerSyncRefresh: + def test_refresh_calls_fetch_and_set(self, test_config: object) -> None: + mock_response = _make_mock_response(gas_estimate=100) + client = MagicMock(spec=httpx.Client) + client.get = MagicMock(return_value=mock_response) + + opts = GasPriceManagerOptions(http_client_sync=client, multiplier=1.0) + mgr = GasPriceManagerSync(test_config, opts=opts) # type: ignore[arg-type] + + mgr.refresh() + assert mgr._gas_price is not None + + def test_refresh_logs_on_exception(self, test_config: object) -> None: + client = MagicMock(spec=httpx.Client) + client.get = MagicMock(side_effect=RuntimeError("boom")) + + opts = GasPriceManagerOptions(http_client_sync=client) + mgr = GasPriceManagerSync(test_config, opts=opts) # type: ignore[arg-type] + + # Should not raise + mgr.refresh() + + +class TestGasPriceManagerSyncRefreshLoop: + def test_refresh_loop_stops_on_event(self, test_config: object) -> None: + call_count = 0 + + def fake_fetch_and_set() -> int: + nonlocal call_count + call_count += 1 + return 100 + + mgr = GasPriceManagerSync(test_config) # type: ignore[arg-type] + mgr._refresh_interval_seconds = 0.01 + mgr.fetch_and_set_gas_price = fake_fetch_and_set # type: ignore[method-assign] + + thread = threading.Thread(target=mgr._refresh_loop, daemon=True) + thread.start() + time.sleep(0.05) + mgr._stop_event.set() + thread.join(timeout=1.0) + + assert call_count >= 1 + assert not thread.is_alive() + + +class TestGasPriceManagerSyncContextManager: + def test_enter_calls_initialize(self, test_config: object) -> None: + mock_response = _make_mock_response(gas_estimate=100) + client = MagicMock(spec=httpx.Client) + client.get = MagicMock(return_value=mock_response) + + opts = GasPriceManagerOptions(http_client_sync=client, multiplier=1.0) + mgr = GasPriceManagerSync(test_config, opts=opts) # type: ignore[arg-type] + + with mgr as ctx: + assert ctx is mgr + assert mgr._is_initialized + + def test_exit_calls_destroy(self, test_config: object) -> None: + mock_response = _make_mock_response(gas_estimate=100) + client = MagicMock(spec=httpx.Client) + client.get = MagicMock(return_value=mock_response) + + opts = GasPriceManagerOptions(http_client_sync=client, multiplier=1.0) + mgr = GasPriceManagerSync(test_config, opts=opts) # type: ignore[arg-type] + + with mgr: + pass + + assert not mgr._is_initialized + assert mgr._gas_price is None diff --git a/tests/test_order_status.py b/tests/test_order_status.py new file mode 100644 index 0000000..df9a7bb --- /dev/null +++ b/tests/test_order_status.py @@ -0,0 +1,335 @@ +"""Tests for decibel._order_status module.""" + +from __future__ import annotations + +from unittest.mock import AsyncMock, MagicMock, patch + +import httpx + +from decibel._order_status import OrderStatus, OrderStatusClient + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +ORDER_STATUS_DATA = { + "parent": "0xparent", + "market": "0xmarket", + "order_id": "0xorder", + "status": "Filled", + "orig_size": 1.0, + "remaining_size": 0.0, + "size_delta": 1.0, + "price": 100.0, + "is_buy": True, + "details": "ok", + "transaction_version": 1, + "unix_ms": 1000, +} + + +def _make_async_response( + status_code: int = 200, + json_data: dict | None = None, + is_success: bool = True, +) -> AsyncMock: + resp = AsyncMock(spec=httpx.Response) + resp.status_code = status_code + resp.is_success = is_success + resp.json.return_value = json_data or ORDER_STATUS_DATA + resp.text = "" + resp.reason_phrase = "OK" if is_success else "Error" + return resp + + +def _make_sync_response( + status_code: int = 200, + json_data: dict | None = None, + is_success: bool = True, +) -> MagicMock: + resp = MagicMock(spec=httpx.Response) + resp.status_code = status_code + resp.is_success = is_success + resp.json.return_value = json_data or ORDER_STATUS_DATA + resp.text = "" + resp.reason_phrase = "OK" if is_success else "Error" + return resp + + +# --------------------------------------------------------------------------- +# OrderStatusClient.__init__ +# --------------------------------------------------------------------------- + + +class TestOrderStatusClientInit: + def test_stores_config(self, test_config: object) -> None: + client = OrderStatusClient(test_config) # type: ignore[arg-type] + assert client._config is test_config + + def test_optional_clients_default_to_none(self, test_config: object) -> None: + client = OrderStatusClient(test_config) # type: ignore[arg-type] + assert client._http_client is None + assert client._http_client_sync is None + + def test_stores_provided_clients(self, test_config: object) -> None: + async_client = AsyncMock(spec=httpx.AsyncClient) + sync_client = MagicMock(spec=httpx.Client) + client = OrderStatusClient( + test_config, # type: ignore[arg-type] + http_client=async_client, + http_client_sync=sync_client, + ) + assert client._http_client is async_client + assert client._http_client_sync is sync_client + + +# --------------------------------------------------------------------------- +# get_order_status (async) +# --------------------------------------------------------------------------- + + +class TestGetOrderStatus: + async def test_success_returns_order_status(self, test_config: object) -> None: + async_client = AsyncMock(spec=httpx.AsyncClient) + async_client.get = AsyncMock(return_value=_make_async_response()) + + os_client = OrderStatusClient(test_config, http_client=async_client) # type: ignore[arg-type] + result = await os_client.get_order_status("0xorder", "0xmarket", "0xuser") + + assert isinstance(result, OrderStatus) + assert result.order_id == "0xorder" + + async def test_404_returns_none(self, test_config: object) -> None: + async_client = AsyncMock(spec=httpx.AsyncClient) + resp = _make_async_response(status_code=404, is_success=False) + resp.status_code = 404 + async_client.get = AsyncMock(return_value=resp) + + os_client = OrderStatusClient(test_config, http_client=async_client) # type: ignore[arg-type] + result = await os_client.get_order_status("0xorder", "0xmarket", "0xuser") + + assert result is None + + async def test_server_error_logs_and_returns_none(self, test_config: object) -> None: + async_client = AsyncMock(spec=httpx.AsyncClient) + resp = _make_async_response(status_code=500, is_success=False) + resp.status_code = 500 + resp.reason_phrase = "Internal Server Error" + async_client.get = AsyncMock(return_value=resp) + + os_client = OrderStatusClient(test_config, http_client=async_client) # type: ignore[arg-type] + result = await os_client.get_order_status("0xorder", "0xmarket", "0xuser") + + assert result is None + + async def test_uses_shared_client_when_available(self, test_config: object) -> None: + async_client = AsyncMock(spec=httpx.AsyncClient) + async_client.get = AsyncMock(return_value=_make_async_response()) + + os_client = OrderStatusClient(test_config, http_client=async_client) # type: ignore[arg-type] + await os_client.get_order_status("0xorder", "0xmarket", "0xuser") + + async_client.get.assert_called_once() + + async def test_uses_explicit_client_parameter(self, test_config: object) -> None: + stored_client = AsyncMock(spec=httpx.AsyncClient) + explicit_client = AsyncMock(spec=httpx.AsyncClient) + explicit_client.get = AsyncMock(return_value=_make_async_response()) + + os_client = OrderStatusClient(test_config, http_client=stored_client) # type: ignore[arg-type] + await os_client.get_order_status("0xorder", "0xmarket", "0xuser", client=explicit_client) + + explicit_client.get.assert_called_once() + stored_client.get.assert_not_called() + + async def test_creates_temp_client_when_no_client(self, test_config: object) -> None: + mock_temp = AsyncMock() + mock_temp.get = AsyncMock(return_value=_make_async_response()) + mock_temp.__aenter__ = AsyncMock(return_value=mock_temp) + mock_temp.__aexit__ = AsyncMock(return_value=None) + + os_client = OrderStatusClient(test_config) # type: ignore[arg-type] + + with patch("decibel._order_status.httpx.AsyncClient", return_value=mock_temp): + result = await os_client.get_order_status("0xorder", "0xmarket", "0xuser") + + assert isinstance(result, OrderStatus) + + async def test_exception_logs_and_returns_none(self, test_config: object) -> None: + async_client = AsyncMock(spec=httpx.AsyncClient) + async_client.get = AsyncMock(side_effect=ConnectionError("connection refused")) + + os_client = OrderStatusClient(test_config, http_client=async_client) # type: ignore[arg-type] + result = await os_client.get_order_status("0xorder", "0xmarket", "0xuser") + + assert result is None + + +# --------------------------------------------------------------------------- +# get_order_status_sync +# --------------------------------------------------------------------------- + + +class TestGetOrderStatusSync: + def test_success_returns_order_status(self, test_config: object) -> None: + sync_client = MagicMock(spec=httpx.Client) + sync_client.get = MagicMock(return_value=_make_sync_response()) + + os_client = OrderStatusClient(test_config, http_client_sync=sync_client) # type: ignore[arg-type] + result = os_client.get_order_status_sync("0xorder", "0xmarket", "0xuser") + + assert isinstance(result, OrderStatus) + + def test_404_returns_none(self, test_config: object) -> None: + sync_client = MagicMock(spec=httpx.Client) + resp = _make_sync_response(status_code=404, is_success=False) + resp.status_code = 404 + sync_client.get = MagicMock(return_value=resp) + + os_client = OrderStatusClient(test_config, http_client_sync=sync_client) # type: ignore[arg-type] + result = os_client.get_order_status_sync("0xorder", "0xmarket", "0xuser") + + assert result is None + + def test_server_error_logs_and_returns_none(self, test_config: object) -> None: + sync_client = MagicMock(spec=httpx.Client) + resp = _make_sync_response(status_code=500, is_success=False) + resp.status_code = 500 + resp.reason_phrase = "Internal Server Error" + sync_client.get = MagicMock(return_value=resp) + + os_client = OrderStatusClient(test_config, http_client_sync=sync_client) # type: ignore[arg-type] + result = os_client.get_order_status_sync("0xorder", "0xmarket", "0xuser") + + assert result is None + + def test_uses_shared_sync_client(self, test_config: object) -> None: + sync_client = MagicMock(spec=httpx.Client) + sync_client.get = MagicMock(return_value=_make_sync_response()) + + os_client = OrderStatusClient(test_config, http_client_sync=sync_client) # type: ignore[arg-type] + os_client.get_order_status_sync("0xorder", "0xmarket", "0xuser") + + sync_client.get.assert_called_once() + + def test_uses_explicit_client_parameter(self, test_config: object) -> None: + stored_client = MagicMock(spec=httpx.Client) + explicit_client = MagicMock(spec=httpx.Client) + explicit_client.get = MagicMock(return_value=_make_sync_response()) + + os_client = OrderStatusClient(test_config, http_client_sync=stored_client) # type: ignore[arg-type] + os_client.get_order_status_sync("0xorder", "0xmarket", "0xuser", client=explicit_client) + + explicit_client.get.assert_called_once() + stored_client.get.assert_not_called() + + def test_creates_temp_client_when_no_client(self, test_config: object) -> None: + mock_temp = MagicMock() + mock_temp.get = MagicMock(return_value=_make_sync_response()) + mock_temp.__enter__ = MagicMock(return_value=mock_temp) + mock_temp.__exit__ = MagicMock(return_value=None) + + os_client = OrderStatusClient(test_config) # type: ignore[arg-type] + + with patch("decibel._order_status.httpx.Client", return_value=mock_temp): + result = os_client.get_order_status_sync("0xorder", "0xmarket", "0xuser") + + assert isinstance(result, OrderStatus) + + def test_exception_logs_and_returns_none(self, test_config: object) -> None: + sync_client = MagicMock(spec=httpx.Client) + sync_client.get = MagicMock(side_effect=ConnectionError("refused")) + + os_client = OrderStatusClient(test_config, http_client_sync=sync_client) # type: ignore[arg-type] + result = os_client.get_order_status_sync("0xorder", "0xmarket", "0xuser") + + assert result is None + + +# --------------------------------------------------------------------------- +# parse_order_status_type +# --------------------------------------------------------------------------- + + +class TestParseOrderStatusType: + def test_acknowledged(self) -> None: + assert OrderStatusClient.parse_order_status_type("Acknowledged") == "Acknowledged" + + def test_acknowledged_case_insensitive(self) -> None: + assert OrderStatusClient.parse_order_status_type("ACKNOWLEDGED") == "Acknowledged" + + def test_filled(self) -> None: + assert OrderStatusClient.parse_order_status_type("Filled") == "Filled" + + def test_filled_case_insensitive(self) -> None: + assert OrderStatusClient.parse_order_status_type("filled") == "Filled" + + def test_cancelled(self) -> None: + assert OrderStatusClient.parse_order_status_type("Cancelled") == "Cancelled" + + def test_rejected(self) -> None: + assert OrderStatusClient.parse_order_status_type("Rejected") == "Rejected" + + def test_unknown_for_unrecognized_string(self) -> None: + assert OrderStatusClient.parse_order_status_type("SomethingElse") == "Unknown" + + def test_none_returns_unknown(self) -> None: + assert OrderStatusClient.parse_order_status_type(None) == "Unknown" + + def test_empty_string_returns_unknown(self) -> None: + assert OrderStatusClient.parse_order_status_type("") == "Unknown" + + def test_partial_match_cancelled(self) -> None: + assert OrderStatusClient.parse_order_status_type("order_cancelled_by_user") == "Cancelled" + + +# --------------------------------------------------------------------------- +# is_success_status / is_failure_status / is_final_status +# --------------------------------------------------------------------------- + + +class TestStatusHelpers: + def test_is_success_status_true_for_filled(self) -> None: + assert OrderStatusClient.is_success_status("Filled") is True + + def test_is_success_status_false_for_acknowledged(self) -> None: + assert OrderStatusClient.is_success_status("Acknowledged") is False + + def test_is_success_status_false_for_cancelled(self) -> None: + assert OrderStatusClient.is_success_status("Cancelled") is False + + def test_is_success_status_false_for_none(self) -> None: + assert OrderStatusClient.is_success_status(None) is False + + def test_is_failure_status_true_for_cancelled(self) -> None: + assert OrderStatusClient.is_failure_status("Cancelled") is True + + def test_is_failure_status_true_for_rejected(self) -> None: + assert OrderStatusClient.is_failure_status("Rejected") is True + + def test_is_failure_status_false_for_filled(self) -> None: + assert OrderStatusClient.is_failure_status("Filled") is False + + def test_is_failure_status_false_for_acknowledged(self) -> None: + assert OrderStatusClient.is_failure_status("Acknowledged") is False + + def test_is_failure_status_false_for_none(self) -> None: + assert OrderStatusClient.is_failure_status(None) is False + + def test_is_final_status_true_for_filled(self) -> None: + assert OrderStatusClient.is_final_status("Filled") is True + + def test_is_final_status_true_for_cancelled(self) -> None: + assert OrderStatusClient.is_final_status("Cancelled") is True + + def test_is_final_status_true_for_rejected(self) -> None: + assert OrderStatusClient.is_final_status("Rejected") is True + + def test_is_final_status_false_for_acknowledged(self) -> None: + assert OrderStatusClient.is_final_status("Acknowledged") is False + + def test_is_final_status_false_for_unknown(self) -> None: + assert OrderStatusClient.is_final_status("Unknown") is False + + def test_is_final_status_false_for_none(self) -> None: + assert OrderStatusClient.is_final_status(None) is False diff --git a/tests/test_pagination.py b/tests/test_pagination.py new file mode 100644 index 0000000..4ca5ef5 --- /dev/null +++ b/tests/test_pagination.py @@ -0,0 +1,126 @@ +"""Tests for decibel._pagination module.""" + +from __future__ import annotations + +import pytest +from pydantic import BaseModel, ValidationError + +from decibel._pagination import ( + PaginatedResponse, + construct_known_query_params, +) + + +class _Item(BaseModel): + name: str + value: int + + +class TestPaginatedResponse: + def test_basic_model_validation(self) -> None: + data = {"items": [{"name": "foo", "value": 1}], "total_count": 1} + response = PaginatedResponse[_Item].model_validate(data) + assert response.total_count == 1 + assert len(response.items) == 1 + assert response.items[0].name == "foo" + + def test_empty_items_list(self) -> None: + data = {"items": [], "total_count": 0} + response = PaginatedResponse[_Item].model_validate(data) + assert response.total_count == 0 + assert response.items == [] + + def test_multiple_items(self) -> None: + data = { + "items": [ + {"name": "a", "value": 1}, + {"name": "b", "value": 2}, + {"name": "c", "value": 3}, + ], + "total_count": 3, + } + response = PaginatedResponse[_Item].model_validate(data) + assert len(response.items) == 3 + assert response.items[1].name == "b" + + def test_total_count_can_differ_from_items_length(self) -> None: + data = {"items": [{"name": "a", "value": 1}], "total_count": 100} + response = PaginatedResponse[_Item].model_validate(data) + assert response.total_count == 100 + assert len(response.items) == 1 + + def test_missing_items_raises(self) -> None: + with pytest.raises(ValidationError): + PaginatedResponse[_Item].model_validate({"total_count": 0}) + + def test_missing_total_count_raises(self) -> None: + with pytest.raises(ValidationError): + PaginatedResponse[_Item].model_validate({"items": []}) + + +class TestConstructKnownQueryParams: + def test_full_params(self) -> None: + result = construct_known_query_params( + { + "limit": 10, + "offset": 5, + "sort_key": "volume", + "sort_dir": "ASC", + "search_term": "btc", + } + ) + assert result == { + "limit": "10", + "offset": "5", + "sort_key": "volume", + "sort_dir": "ASC", + "search_term": "btc", + } + + def test_partial_params_limit_only(self) -> None: + result = construct_known_query_params({"limit": 20}) + assert result == {"limit": "20"} + assert "offset" not in result + + def test_partial_params_sort_only(self) -> None: + result = construct_known_query_params({"sort_key": "realized_pnl", "sort_dir": "DESC"}) + assert result == {"sort_key": "realized_pnl", "sort_dir": "DESC"} + + def test_empty_dict_returns_empty(self) -> None: + result = construct_known_query_params({}) + assert result == {} + + def test_none_values_are_skipped(self) -> None: + result = construct_known_query_params({"limit": 10, "sort_dir": None}) # type: ignore[typeddict-item] + assert "sort_dir" not in result + assert result["limit"] == "10" + + def test_empty_string_search_term_is_skipped(self) -> None: + result = construct_known_query_params({"search_term": " "}) + assert "search_term" not in result + + def test_empty_string_exactly_is_skipped(self) -> None: + result = construct_known_query_params({"search_term": ""}) + assert "search_term" not in result + + def test_non_empty_search_term_is_included(self) -> None: + result = construct_known_query_params({"search_term": "eth"}) + assert result["search_term"] == "eth" + + def test_integer_values_are_stringified(self) -> None: + result = construct_known_query_params({"limit": 100, "offset": 0}) + assert result["limit"] == "100" + assert result["offset"] == "0" + + def test_sort_dir_asc(self) -> None: + result = construct_known_query_params({"sort_dir": "ASC"}) + assert result["sort_dir"] == "ASC" + + def test_sort_dir_desc(self) -> None: + result = construct_known_query_params({"sort_dir": "DESC"}) + assert result["sort_dir"] == "DESC" + + def test_all_values_returned_as_strings(self) -> None: + result = construct_known_query_params({"limit": 5, "offset": 0, "search_term": "market"}) + for v in result.values(): + assert isinstance(v, str) diff --git a/tests/test_transaction_builder.py b/tests/test_transaction_builder.py new file mode 100644 index 0000000..ae33d7c --- /dev/null +++ b/tests/test_transaction_builder.py @@ -0,0 +1,814 @@ +"""Unit tests for decibel._transaction_builder module. + +Covers: generate_expire_timestamp, build_simple_transaction_sync, +_build_entry_function, _find_first_non_signer_arg, _parse_type_tag, +_encode_argument, _encode_vector_bytes, _encode_option_bytes, +_encode_function_arguments, TransactionExtraConfigV1 (serialize/deserialize), +TransactionPayloadOrderless (serialize). +""" + +from __future__ import annotations + +import time +from typing import Any +from unittest.mock import MagicMock + +import pytest +from aptos_sdk.account_address import AccountAddress +from aptos_sdk.bcs import Deserializer, Serializer +from aptos_sdk.type_tag import StructTag, TypeTag + +from decibel._transaction_builder import ( + DEADBEEF_SEQUENCE_NUMBER, + InputEntryFunctionData, + SimpleTransaction, + TransactionExtraConfigV1, + TransactionPayloadOrderless, + _build_entry_function, + _encode_argument, + _encode_function_arguments, + _encode_option_bytes, + _encode_vector_bytes, + _find_first_non_signer_arg, + _parse_type_tag, + build_simple_transaction_sync, + generate_expire_timestamp, +) + +# --------------------------------------------------------------------------- +# Helper – create a minimal MoveFunction mock +# --------------------------------------------------------------------------- + + +def _make_abi(params: list[str]) -> Any: + """Create a minimal MoveFunction-like object with given params.""" + mock = MagicMock() + mock.params = params + return mock + + +# --------------------------------------------------------------------------- +# generate_expire_timestamp +# --------------------------------------------------------------------------- + + +class TestGenerateExpireTimestamp: + def test_returns_int(self) -> None: + result = generate_expire_timestamp() + assert isinstance(result, int) + + def test_in_the_future(self) -> None: + now = int(time.time()) + result = generate_expire_timestamp() + assert result > now + + def test_approximately_now_plus_20(self) -> None: + before = int(time.time()) + result = generate_expire_timestamp() + after = int(time.time()) + # Should be now + 20 seconds (default expiry) + assert before + 19 <= result <= after + 22 + + def test_time_delta_ms_shifts_result(self) -> None: + base = generate_expire_timestamp(time_delta_ms=0) + shifted = generate_expire_timestamp(time_delta_ms=5000) + # 5000ms = 5s shift + assert shifted >= base + 4 + + def test_negative_time_delta_ms_shifts_backward(self) -> None: + base = generate_expire_timestamp(time_delta_ms=0) + shifted = generate_expire_timestamp(time_delta_ms=-10000) + assert shifted <= base - 9 + + def test_custom_default_txn_expiry_sec(self) -> None: + before = int(time.time()) + result = generate_expire_timestamp(default_txn_expiry_sec=60) + after = int(time.time()) + assert before + 59 <= result <= after + 62 + + +# --------------------------------------------------------------------------- +# build_simple_transaction_sync +# --------------------------------------------------------------------------- + + +class TestBuildSimpleTransactionSync: + def _make_abi_for_build(self) -> Any: + """ABI with one non-signer u64 param.""" + return _make_abi(["&signer", "u64"]) + + def test_returns_simple_transaction(self) -> None: + abi = self._make_abi_for_build() + sender = AccountAddress.from_str("0x" + "aa" * 32) + data = InputEntryFunctionData( + function=f"{'0x' + 'ab' * 32}::module::func", + function_arguments=[42], + type_arguments=[], + ) + result = build_simple_transaction_sync( + sender=sender, + data=data, + chain_id=2, + gas_unit_price=100, + abi=abi, + with_fee_payer=False, + replay_protection_nonce=12345, + ) + assert isinstance(result, SimpleTransaction) + + def test_fee_payer_address_set_when_with_fee_payer(self) -> None: + abi = self._make_abi_for_build() + sender = AccountAddress.from_str("0x" + "aa" * 32) + data = InputEntryFunctionData( + function=f"{'0x' + 'ab' * 32}::module::func", + function_arguments=[42], + type_arguments=[], + ) + result = build_simple_transaction_sync( + sender=sender, + data=data, + chain_id=2, + gas_unit_price=100, + abi=abi, + with_fee_payer=True, + replay_protection_nonce=99, + ) + assert result.fee_payer_address is not None + assert str(result.fee_payer_address) == str(AccountAddress.from_str("0x0")) + + def test_fee_payer_address_none_when_no_fee_payer(self) -> None: + abi = self._make_abi_for_build() + sender = AccountAddress.from_str("0x" + "aa" * 32) + data = InputEntryFunctionData( + function=f"{'0x' + 'ab' * 32}::module::func", + function_arguments=[42], + type_arguments=[], + ) + result = build_simple_transaction_sync( + sender=sender, + data=data, + chain_id=2, + gas_unit_price=100, + abi=abi, + with_fee_payer=False, + replay_protection_nonce=99, + ) + assert result.fee_payer_address is None + + def test_uses_deadbeef_sequence_number(self) -> None: + abi = self._make_abi_for_build() + sender = AccountAddress.from_str("0x" + "aa" * 32) + data = InputEntryFunctionData( + function=f"{'0x' + 'ab' * 32}::module::func", + function_arguments=[42], + type_arguments=[], + ) + result = build_simple_transaction_sync( + sender=sender, + data=data, + chain_id=2, + gas_unit_price=100, + abi=abi, + with_fee_payer=False, + replay_protection_nonce=1, + ) + assert result.raw_transaction.sequence_number == DEADBEEF_SEQUENCE_NUMBER + + def test_correct_gas_unit_price(self) -> None: + abi = self._make_abi_for_build() + sender = AccountAddress.from_str("0x" + "aa" * 32) + data = InputEntryFunctionData( + function=f"{'0x' + 'ab' * 32}::module::func", + function_arguments=[42], + type_arguments=[], + ) + result = build_simple_transaction_sync( + sender=sender, + data=data, + chain_id=2, + gas_unit_price=777, + abi=abi, + with_fee_payer=False, + replay_protection_nonce=1, + ) + assert result.raw_transaction.gas_unit_price == 777 + + def test_correct_max_gas_amount(self) -> None: + abi = self._make_abi_for_build() + sender = AccountAddress.from_str("0x" + "aa" * 32) + data = InputEntryFunctionData( + function=f"{'0x' + 'ab' * 32}::module::func", + function_arguments=[42], + type_arguments=[], + ) + result = build_simple_transaction_sync( + sender=sender, + data=data, + chain_id=2, + gas_unit_price=100, + abi=abi, + with_fee_payer=False, + replay_protection_nonce=1, + max_gas_amount=500_000, + ) + assert result.raw_transaction.max_gas_amount == 500_000 + + def test_sender_as_string(self) -> None: + abi = self._make_abi_for_build() + sender_str = "0x" + "aa" * 32 + data = InputEntryFunctionData( + function=f"{'0x' + 'ab' * 32}::module::func", + function_arguments=[42], + type_arguments=[], + ) + result = build_simple_transaction_sync( + sender=sender_str, + data=data, + chain_id=2, + gas_unit_price=100, + abi=abi, + with_fee_payer=False, + replay_protection_nonce=1, + ) + assert isinstance(result, SimpleTransaction) + + def test_correct_chain_id(self) -> None: + abi = self._make_abi_for_build() + sender = AccountAddress.from_str("0x" + "aa" * 32) + data = InputEntryFunctionData( + function=f"{'0x' + 'ab' * 32}::module::func", + function_arguments=[42], + type_arguments=[], + ) + result = build_simple_transaction_sync( + sender=sender, + data=data, + chain_id=42, + gas_unit_price=100, + abi=abi, + with_fee_payer=False, + replay_protection_nonce=1, + ) + assert result.raw_transaction.chain_id == 42 + + +# --------------------------------------------------------------------------- +# _build_entry_function +# --------------------------------------------------------------------------- + + +class TestBuildEntryFunction: + def test_parses_valid_function_id(self) -> None: + pkg = "0x" + "ab" * 32 + abi = _make_abi(["&signer", "u64"]) + data = InputEntryFunctionData( + function=f"{pkg}::my_module::my_func", + function_arguments=[999], + type_arguments=[], + ) + entry_fn = _build_entry_function(data, abi) + assert entry_fn.function == "my_func" + assert str(entry_fn.module.address) == pkg + + def test_invalid_function_id_raises(self) -> None: + abi = _make_abi(["u64"]) + data = InputEntryFunctionData( + function="invalid_format", + function_arguments=[1], + type_arguments=[], + ) + with pytest.raises(ValueError, match="Invalid function format"): + _build_entry_function(data, abi) + + def test_too_many_parts_raises(self) -> None: + abi = _make_abi(["u64"]) + data = InputEntryFunctionData( + function="0x1::a::b::c", + function_arguments=[1], + type_arguments=[], + ) + with pytest.raises(ValueError, match="Invalid function format"): + _build_entry_function(data, abi) + + def test_type_arguments_parsed(self) -> None: + pkg = "0x" + "ab" * 32 + abi = _make_abi([]) # No signer, no params + data = InputEntryFunctionData( + function=f"{pkg}::module::func", + function_arguments=[], + type_arguments=["u64", "bool"], + ) + entry_fn = _build_entry_function(data, abi) + assert len(entry_fn.ty_args) == 2 + + def test_skips_signer_params_when_encoding(self) -> None: + pkg = "0x" + "ab" * 32 + abi = _make_abi(["&signer", "&signer", "u64"]) + data = InputEntryFunctionData( + function=f"{pkg}::module::func", + function_arguments=[100], # only 1 non-signer arg + type_arguments=[], + ) + entry_fn = _build_entry_function(data, abi) + assert len(entry_fn.args) == 1 + + def test_none_type_arguments_defaults_to_empty(self) -> None: + pkg = "0x" + "ab" * 32 + abi = _make_abi(["u64"]) + data = InputEntryFunctionData( + function=f"{pkg}::module::func", + function_arguments=[1], + type_arguments=None, + ) + entry_fn = _build_entry_function(data, abi) + assert entry_fn.ty_args == [] + + +# --------------------------------------------------------------------------- +# _find_first_non_signer_arg +# --------------------------------------------------------------------------- + + +class TestFindFirstNonSignerArg: + def test_no_params(self) -> None: + assert _find_first_non_signer_arg([]) == 0 + + def test_all_signers(self) -> None: + params = ["&signer", "signer", "&signer"] + assert _find_first_non_signer_arg(params) == len(params) + + def test_first_non_signer_at_zero(self) -> None: + params = ["u64", "bool"] + assert _find_first_non_signer_arg(params) == 0 + + def test_first_non_signer_after_signer(self) -> None: + params = ["&signer", "u64", "bool"] + assert _find_first_non_signer_arg(params) == 1 + + def test_multiple_signers_then_non_signer(self) -> None: + params = ["signer", "&signer", "address"] + assert _find_first_non_signer_arg(params) == 2 + + def test_reference_signer(self) -> None: + params = ["& signer", "u128"] + # "& signer" stripped becomes "signer" + assert _find_first_non_signer_arg(params) == 1 + + +# --------------------------------------------------------------------------- +# _parse_type_tag +# --------------------------------------------------------------------------- + + +class TestParseTypeTag: + def test_bool(self) -> None: + tag = _parse_type_tag("bool") + assert tag.value == TypeTag.BOOL + + def test_u8(self) -> None: + tag = _parse_type_tag("u8") + assert tag.value == TypeTag.U8 + + def test_u16(self) -> None: + tag = _parse_type_tag("u16") + assert tag.value == TypeTag.U16 + + def test_u32(self) -> None: + tag = _parse_type_tag("u32") + assert tag.value == TypeTag.U32 + + def test_u64(self) -> None: + tag = _parse_type_tag("u64") + assert tag.value == TypeTag.U64 + + def test_u128(self) -> None: + tag = _parse_type_tag("u128") + assert tag.value == TypeTag.U128 + + def test_u256(self) -> None: + tag = _parse_type_tag("u256") + assert tag.value == TypeTag.U256 + + def test_address(self) -> None: + tag = _parse_type_tag("address") + assert tag.value == TypeTag.ACCOUNT_ADDRESS + + def test_signer(self) -> None: + tag = _parse_type_tag("signer") + assert tag.value == TypeTag.SIGNER + + def test_vector_u8(self) -> None: + tag = _parse_type_tag("vector") + assert isinstance(tag.value, tuple) + assert tag.value[0] == TypeTag.VECTOR + + def test_vector_u64(self) -> None: + tag = _parse_type_tag("vector") + assert isinstance(tag.value, tuple) + assert tag.value[0] == TypeTag.VECTOR + + def test_struct_type(self) -> None: + tag = _parse_type_tag("0x1::string::String") + assert isinstance(tag.value, StructTag) + + def test_whitespace_stripped(self) -> None: + tag = _parse_type_tag(" u64 ") + assert tag.value == TypeTag.U64 + + +# --------------------------------------------------------------------------- +# _encode_argument +# --------------------------------------------------------------------------- + + +class TestEncodeArgument: + def _decode_bool(self, data: bytes) -> bool: + d = Deserializer(data) + return d.bool() + + def _decode_u8(self, data: bytes) -> int: + d = Deserializer(data) + return d.u8() + + def _decode_u16(self, data: bytes) -> int: + d = Deserializer(data) + return d.u16() + + def _decode_u32(self, data: bytes) -> int: + d = Deserializer(data) + return d.u32() + + def _decode_u64(self, data: bytes) -> int: + d = Deserializer(data) + return d.u64() + + def _decode_u128(self, data: bytes) -> int: + d = Deserializer(data) + return d.u128() + + def _decode_u256(self, data: bytes) -> int: + d = Deserializer(data) + return d.u256() + + def test_bool_true(self) -> None: + data = _encode_argument(True, "bool") + assert self._decode_bool(data) is True + + def test_bool_false(self) -> None: + data = _encode_argument(False, "bool") + assert self._decode_bool(data) is False + + def test_u8(self) -> None: + data = _encode_argument(255, "u8") + assert self._decode_u8(data) == 255 + + def test_u16(self) -> None: + data = _encode_argument(65535, "u16") + assert self._decode_u16(data) == 65535 + + def test_u32(self) -> None: + data = _encode_argument(4294967295, "u32") + assert self._decode_u32(data) == 4294967295 + + def test_u64(self) -> None: + data = _encode_argument(1234567890, "u64") + assert self._decode_u64(data) == 1234567890 + + def test_u128(self) -> None: + val = 2**64 + 42 + data = _encode_argument(val, "u128") + assert self._decode_u128(data) == val + + def test_u256(self) -> None: + val = 2**128 + 99 + data = _encode_argument(val, "u256") + assert self._decode_u256(data) == val + + def test_address_from_string(self) -> None: + addr_str = "0x" + "aa" * 32 + data = _encode_argument(addr_str, "address") + d = Deserializer(data) + decoded = AccountAddress.deserialize(d) + assert str(decoded) == addr_str + + def test_address_from_account_address(self) -> None: + addr = AccountAddress.from_str("0x" + "bb" * 32) + data = _encode_argument(addr, "address") + d = Deserializer(data) + decoded = AccountAddress.deserialize(d) + assert str(decoded) == str(addr) + + def test_vector_u8_from_bytes(self) -> None: + payload = b"\x01\x02\x03" + data = _encode_argument(payload, "vector") + d = Deserializer(data) + assert d.to_bytes() == payload + + def test_vector_u8_from_hex_string(self) -> None: + data = _encode_argument("0x010203", "vector") + d = Deserializer(data) + assert d.to_bytes() == bytes([1, 2, 3]) + + def test_vector_u8_from_hex_string_no_prefix(self) -> None: + data = _encode_argument("010203", "vector") + d = Deserializer(data) + assert d.to_bytes() == bytes([1, 2, 3]) + + def test_vector_u8_from_list(self) -> None: + data = _encode_argument([10, 20, 30], "vector") + d = Deserializer(data) + assert d.to_bytes() == bytes([10, 20, 30]) + + def test_string_encoding(self) -> None: + data = _encode_argument("hello world", "0x1::string::String") + d = Deserializer(data) + assert d.str() == "hello world" + + def test_option_some_u64(self) -> None: + data = _encode_argument(42, "0x1::option::Option") + d = Deserializer(data) + flag = d.u8() + assert flag == 1 # Some + val = d.u64() + assert val == 42 + + def test_option_none(self) -> None: + data = _encode_argument(None, "0x1::option::Option") + d = Deserializer(data) + flag = d.u8() + assert flag == 0 # None + + def test_object_type_from_string(self) -> None: + addr_str = "0x" + "cc" * 32 + data = _encode_argument(addr_str, "0x1::object::Object<0x1::coin::Coin>") + d = Deserializer(data) + decoded = AccountAddress.deserialize(d) + assert str(decoded) == addr_str + + def test_object_type_suffix(self) -> None: + addr_str = "0x" + "dd" * 32 + data = _encode_argument(addr_str, "some::module::Object") + d = Deserializer(data) + decoded = AccountAddress.deserialize(d) + assert str(decoded) == addr_str + + def test_reference_param_type_stripped(self) -> None: + # "&u64" should be treated same as "u64" + data = _encode_argument(99, "&u64") + d = Deserializer(data) + assert d.u64() == 99 + + def test_unknown_type_raises(self) -> None: + with pytest.raises(ValueError, match="Cannot encode argument"): + _encode_argument("something", "totally::unknown::Type") + + def test_non_u8_vector_via_encode_argument(self) -> None: + """_encode_argument dispatches to _encode_vector_bytes for non-u8 vectors.""" + data = _encode_argument([10, 20], "vector") + d = Deserializer(data) + length = d.uleb128() + assert length == 2 + + +# --------------------------------------------------------------------------- +# _encode_vector_bytes +# --------------------------------------------------------------------------- + + +class TestEncodeVectorBytes: + def test_vector_of_u64(self) -> None: + data = _encode_vector_bytes([1, 2, 3], "vector") + d = Deserializer(data) + length = d.uleb128() + assert length == 3 + # Each u64 is 8 bytes (fixed) + # The deserializer for fixed_bytes needs to know the size + # Just check the length prefix is correct + assert len(data) > 0 + + def test_vector_of_u8(self) -> None: + data = _encode_vector_bytes([10, 20, 30], "vector") + assert len(data) > 0 + + def test_empty_vector(self) -> None: + data = _encode_vector_bytes([], "vector") + d = Deserializer(data) + length = d.uleb128() + assert length == 0 + + def test_vector_of_bool(self) -> None: + data = _encode_vector_bytes([True, False, True], "vector") + d = Deserializer(data) + length = d.uleb128() + assert length == 3 + + +# --------------------------------------------------------------------------- +# _encode_option_bytes +# --------------------------------------------------------------------------- + + +class TestEncodeOptionBytes: + def test_none_produces_zero_byte(self) -> None: + data = _encode_option_bytes(None, "0x1::option::Option") + d = Deserializer(data) + assert d.u8() == 0 + + def test_some_u64_produces_one_then_value(self) -> None: + data = _encode_option_bytes(42, "0x1::option::Option") + d = Deserializer(data) + assert d.u8() == 1 + assert d.u64() == 42 + + def test_some_bool(self) -> None: + data = _encode_option_bytes(True, "0x1::option::Option") + d = Deserializer(data) + assert d.u8() == 1 + assert d.bool() is True + + def test_some_u128(self) -> None: + val = 2**64 + 1 + data = _encode_option_bytes(val, "0x1::option::Option") + d = Deserializer(data) + assert d.u8() == 1 + assert d.u128() == val + + +# --------------------------------------------------------------------------- +# _encode_function_arguments +# --------------------------------------------------------------------------- + + +class TestEncodeFunctionArguments: + def test_count_mismatch_raises(self) -> None: + with pytest.raises(ValueError, match="Argument count mismatch"): + _encode_function_arguments([1, 2, 3], ["u64", "bool"]) + + def test_empty_args_empty_params(self) -> None: + result = _encode_function_arguments([], []) + assert result == [] + + def test_single_arg(self) -> None: + result = _encode_function_arguments([100], ["u64"]) + assert len(result) == 1 + d = Deserializer(result[0]) + assert d.u64() == 100 + + def test_multiple_args(self) -> None: + result = _encode_function_arguments([True, 255], ["bool", "u8"]) + assert len(result) == 2 + d0 = Deserializer(result[0]) + assert d0.bool() is True + d1 = Deserializer(result[1]) + assert d1.u8() == 255 + + +# --------------------------------------------------------------------------- +# TransactionExtraConfigV1 serialize / deserialize +# --------------------------------------------------------------------------- + + +class TestTransactionExtraConfigV1: + def _roundtrip( + self, + multisig: AccountAddress | None = None, + nonce: int | None = None, + ) -> TransactionExtraConfigV1: + config = TransactionExtraConfigV1( + multisig_address=multisig, + replay_protection_nonce=nonce, + ) + serializer = Serializer() + config.serialize(serializer) + deserializer = Deserializer(serializer.output()) + return TransactionExtraConfigV1.deserialize(deserializer) + + def test_no_multisig_no_nonce(self) -> None: + result = self._roundtrip() + assert result.multisig_address is None + assert result.replay_protection_nonce is None + + def test_with_nonce_only(self) -> None: + result = self._roundtrip(nonce=98765) + assert result.multisig_address is None + assert result.replay_protection_nonce == 98765 + + def test_with_multisig_only(self) -> None: + addr = AccountAddress.from_str("0x" + "ee" * 32) + result = self._roundtrip(multisig=addr) + assert result.multisig_address is not None + assert str(result.multisig_address) == str(addr) + assert result.replay_protection_nonce is None + + def test_with_both(self) -> None: + addr = AccountAddress.from_str("0x" + "ff" * 32) + result = self._roundtrip(multisig=addr, nonce=55555) + assert str(result.multisig_address) == str(addr) # type: ignore[arg-type] + assert result.replay_protection_nonce == 55555 + + def test_invalid_variant_raises(self) -> None: + serializer = Serializer() + serializer.uleb128(99) # Wrong variant + deserializer = Deserializer(serializer.output()) + with pytest.raises(ValueError, match="Unknown TransactionExtraConfig variant"): + TransactionExtraConfigV1.deserialize(deserializer) + + def test_serialize_writes_variant_zero(self) -> None: + config = TransactionExtraConfigV1() + serializer = Serializer() + config.serialize(serializer) + data = serializer.output() + # First byte should be variant 0 (uleb128 of 0 = 0x00) + assert data[0] == 0 + + +# --------------------------------------------------------------------------- +# TransactionPayloadOrderless serialize +# --------------------------------------------------------------------------- + + +class TestTransactionPayloadOrderless: + def test_serialize_starts_with_variant_4(self) -> None: + from aptos_sdk.transactions import EntryFunction, ModuleId + + from decibel._transaction_builder import ( + TransactionExecutableEntryFunction, + TransactionInnerPayloadV1, + ) + + # Create a real EntryFunction + module_id = ModuleId(AccountAddress.from_str("0x1"), "m") + entry_fn = EntryFunction(module=module_id, function="f", ty_args=[], args=[]) + executable = TransactionExecutableEntryFunction(entry_fn) + + extra_config = TransactionExtraConfigV1(replay_protection_nonce=1) + inner = TransactionInnerPayloadV1(executable, extra_config) + payload = TransactionPayloadOrderless(inner) + + serializer = Serializer() + payload.serialize(serializer) + data = serializer.output() + + # First uleb128 byte = variant 4 = 0x04 + assert data[0] == 4 + + def test_serialize_is_non_empty(self) -> None: + from aptos_sdk.transactions import EntryFunction, ModuleId + + from decibel._transaction_builder import ( + TransactionExecutableEntryFunction, + TransactionInnerPayloadV1, + ) + + module_id = ModuleId(AccountAddress.from_str("0x1"), "m") + entry_fn = EntryFunction(module=module_id, function="f", ty_args=[], args=[]) + executable = TransactionExecutableEntryFunction(entry_fn) + extra_config = TransactionExtraConfigV1() + inner = TransactionInnerPayloadV1(executable, extra_config) + payload = TransactionPayloadOrderless(inner) + + serializer = Serializer() + payload.serialize(serializer) + assert len(serializer.output()) > 0 + + +# --------------------------------------------------------------------------- +# InputEntryFunctionData dataclass +# --------------------------------------------------------------------------- + + +class TestInputEntryFunctionData: + def test_default_function_arguments(self) -> None: + data = InputEntryFunctionData(function="0x1::m::f") + assert data.function_arguments == [] + + def test_default_type_arguments(self) -> None: + data = InputEntryFunctionData(function="0x1::m::f") + assert data.type_arguments is None + + def test_custom_values(self) -> None: + data = InputEntryFunctionData( + function="0x1::m::f", + function_arguments=[1, 2, 3], + type_arguments=["u64"], + ) + assert data.function_arguments == [1, 2, 3] + assert data.type_arguments == ["u64"] + + +# --------------------------------------------------------------------------- +# SimpleTransaction dataclass +# --------------------------------------------------------------------------- + + +class TestSimpleTransaction: + def test_default_fee_payer_address(self) -> None: + mock_raw = MagicMock() + txn = SimpleTransaction(raw_transaction=mock_raw) + assert txn.fee_payer_address is None + + def test_with_fee_payer_address(self) -> None: + mock_raw = MagicMock() + addr = AccountAddress.from_str("0x" + "aa" * 32) + txn = SimpleTransaction(raw_transaction=mock_raw, fee_payer_address=addr) + assert txn.fee_payer_address is addr diff --git a/tests/test_utils_extended.py b/tests/test_utils_extended.py new file mode 100644 index 0000000..8dab4b2 --- /dev/null +++ b/tests/test_utils_extended.py @@ -0,0 +1,831 @@ +"""Extended unit tests for decibel._utils module. + +Covers: FetchError, bigint_reviver, prettify_validation_error, +_base_request_async, _base_request_sync, _process_response, +address derivation helpers, extract_vault_address_from_create_tx, +generate_random_replay_protection_nonce. +""" + +from __future__ import annotations + +import json +from typing import Any +from unittest.mock import AsyncMock, MagicMock, patch + +import httpx +import pytest +from pydantic import BaseModel, ValidationError + +from decibel._utils import ( + FetchError, + _base_request_async, + _base_request_sync, + _process_response, + bigint_reviver, + extract_vault_address_from_create_tx, + generate_random_replay_protection_nonce, + get_market_addr, + get_primary_subaccount_addr, + get_request, + get_request_sync, + get_trading_competition_subaccount_addr, + get_vault_share_address, + patch_request, + patch_request_sync, + post_request, + post_request_sync, + prettify_validation_error, +) + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +class _SimpleModel(BaseModel): + value: int + name: str + + +def _make_response( + status_code: int = 200, + json_data: Any = None, + text: str = "", +) -> httpx.Response: + if json_data is not None: + return httpx.Response( + status_code=status_code, + json=json_data, + request=httpx.Request("GET", "https://test.example.com"), + ) + return httpx.Response( + status_code=status_code, + text=text, + request=httpx.Request("GET", "https://test.example.com"), + ) + + +# --------------------------------------------------------------------------- +# FetchError +# --------------------------------------------------------------------------- + + +class TestFetchError: + def test_json_response_with_status_and_message(self) -> None: + body = json.dumps({"status": "NOT_FOUND", "message": "Resource missing"}) + err = FetchError(body, 404, "Not Found") + + assert err.status == 404 + assert err.status_text == "NOT_FOUND" + assert err.response_message == "Resource missing" + assert "404" in str(err) + assert "NOT_FOUND" in str(err) + assert "Resource missing" in str(err) + + def test_json_response_missing_status_and_message_fields(self) -> None: + body = json.dumps({"foo": "bar"}) + err = FetchError(body, 500, "Server Error") + + # Falls back to HTTP status_text and raw body + assert err.status == 500 + assert err.status_text == "Server Error" + assert err.response_message == body + + def test_json_response_non_string_status(self) -> None: + body = json.dumps({"status": 42, "message": "hello"}) + err = FetchError(body, 400, "Bad Request") + + # status is int, not str → should fall back + assert err.status_text == "Bad Request" + # message is str → should be used + assert err.response_message == "hello" + + def test_non_json_response(self) -> None: + body = "plain text error" + err = FetchError(body, 503, "Service Unavailable") + + assert err.status == 503 + assert err.status_text == "Service Unavailable" + assert err.response_message == "plain text error" + + def test_empty_response(self) -> None: + err = FetchError("", 422, "Unprocessable Entity") + assert err.status == 422 + assert err.status_text == "Unprocessable Entity" + + def test_is_exception(self) -> None: + err = FetchError("{}", 400, "Bad Request") + assert isinstance(err, Exception) + + def test_message_format_no_status_text(self) -> None: + # When status_text is empty string, parenthetical should be omitted + body = json.dumps({"status": "", "message": "oops"}) + err = FetchError(body, 400, "fallback") + # status is empty str → falsy → should show fallback + assert "400" in str(err) + + +# --------------------------------------------------------------------------- +# bigint_reviver +# --------------------------------------------------------------------------- + + +class TestBigintReviver: + def test_converts_bigint_string_to_int(self) -> None: + result = bigint_reviver({"$bigint": "123456789012345678"}) + assert result == 123456789012345678 + assert isinstance(result, int) + + def test_zero_bigint(self) -> None: + assert bigint_reviver({"$bigint": "0"}) == 0 + + def test_passes_through_normal_dict(self) -> None: + d = {"a": 1, "b": "hello"} + assert bigint_reviver(d) is d + + def test_bigint_non_string_value_passes_through(self) -> None: + # If $bigint is not a str, should pass through unchanged + d = {"$bigint": 999} + result = bigint_reviver(d) + assert result is d + + def test_empty_dict(self) -> None: + d: dict[str, Any] = {} + result = bigint_reviver(d) + assert result is d + + def test_dict_with_other_keys(self) -> None: + d = {"name": "Alice", "age": 30} + assert bigint_reviver(d) is d + + def test_json_loads_with_hook(self) -> None: + payload = '{"amount": {"$bigint": "99999999999999"}}' + result = json.loads(payload, object_hook=bigint_reviver) + assert result["amount"] == 99999999999999 + + +# --------------------------------------------------------------------------- +# prettify_validation_error +# --------------------------------------------------------------------------- + + +class TestPrettifyValidationError: + def _make_validation_error(self) -> ValidationError: + try: + _SimpleModel.model_validate({"value": "not_an_int", "name": 123}) + except ValidationError as e: + return e + pytest.fail("Expected ValidationError not raised") + + def test_returns_string(self) -> None: + err = self._make_validation_error() + result = prettify_validation_error(err) + assert isinstance(result, str) + + def test_starts_with_validation_error(self) -> None: + err = self._make_validation_error() + result = prettify_validation_error(err) + assert result.startswith("Validation error:") + + def test_contains_field_location(self) -> None: + err = self._make_validation_error() + result = prettify_validation_error(err) + # Should contain field name from error location + assert "value" in result or "name" in result + + def test_missing_required_field(self) -> None: + try: + _SimpleModel.model_validate({}) + except ValidationError as e: + result = prettify_validation_error(e) + assert "Validation error:" in result + assert "value" in result or "name" in result + + def test_root_location_fallback(self) -> None: + """When loc is empty, should show 'root'.""" + mock_error = MagicMock() + mock_error.errors.return_value = [{"loc": (), "msg": "some error"}] + result = prettify_validation_error(mock_error) + assert "root" in result + assert "some error" in result + + +# --------------------------------------------------------------------------- +# _process_response +# --------------------------------------------------------------------------- + + +class TestProcessResponse: + def test_success_with_valid_json(self) -> None: + response = _make_response(200, json_data={"value": 42, "name": "test"}) + data, status, status_text = _process_response(_SimpleModel, response) + assert data.value == 42 + assert data.name == "test" + assert status == 200 + + def test_non_success_raises_fetch_error(self) -> None: + response = _make_response(404, text="Not found") + with pytest.raises(FetchError) as exc_info: + _process_response(_SimpleModel, response) + assert exc_info.value.status == 404 + + def test_server_error_raises_fetch_error(self) -> None: + response = _make_response(500, text="Internal Server Error") + with pytest.raises(FetchError): + _process_response(_SimpleModel, response) + + def test_validation_error_raises_value_error(self) -> None: + # Response is valid JSON but doesn't match model + response = _make_response(200, json_data={"unexpected": "data"}) + with pytest.raises(ValueError, match="Validation error"): + _process_response(_SimpleModel, response) + + def test_bigint_reviver_is_applied(self) -> None: + class BigintModel(BaseModel): + amount: int + + response = _make_response(200, text='{"amount": {"$bigint": "987"}}') + data, _, _ = _process_response(BigintModel, response) + assert data.amount == 987 + + def test_returns_status_code_and_reason(self) -> None: + response = _make_response(201, json_data={"value": 1, "name": "a"}) + _, status, _ = _process_response(_SimpleModel, response) + assert status == 201 + + +# --------------------------------------------------------------------------- +# _base_request_async +# --------------------------------------------------------------------------- + + +class TestBaseRequestAsync: + @pytest.mark.asyncio + async def test_get_with_provided_client(self) -> None: + mock_client = AsyncMock(spec=httpx.AsyncClient) + mock_client.request = AsyncMock( + return_value=_make_response(200, json_data={"value": 1, "name": "x"}) + ) + + data, status, _ = await _base_request_async( + _SimpleModel, + "https://example.com/api", + "GET", + client=mock_client, + ) + + mock_client.request.assert_awaited_once() + call_kwargs = mock_client.request.call_args + assert call_kwargs.kwargs["method"] == "GET" + assert data.value == 1 + assert status == 200 + + @pytest.mark.asyncio + async def test_get_passes_params(self) -> None: + mock_client = AsyncMock(spec=httpx.AsyncClient) + mock_client.request = AsyncMock( + return_value=_make_response(200, json_data={"value": 2, "name": "y"}) + ) + + await _base_request_async( + _SimpleModel, + "https://example.com/api", + "GET", + params={"foo": "bar"}, + client=mock_client, + ) + + call_kwargs = mock_client.request.call_args.kwargs + assert call_kwargs["params"] == {"foo": "bar"} + + @pytest.mark.asyncio + async def test_post_adds_content_type_header(self) -> None: + mock_client = AsyncMock(spec=httpx.AsyncClient) + mock_client.request = AsyncMock( + return_value=_make_response(200, json_data={"value": 3, "name": "z"}) + ) + + await _base_request_async( + _SimpleModel, + "https://example.com/api", + "POST", + body={"key": "val"}, + client=mock_client, + ) + + call_kwargs = mock_client.request.call_args.kwargs + assert call_kwargs["headers"]["Content-Type"] == "application/json" + + @pytest.mark.asyncio + async def test_patch_adds_content_type_header(self) -> None: + mock_client = AsyncMock(spec=httpx.AsyncClient) + mock_client.request = AsyncMock( + return_value=_make_response(200, json_data={"value": 4, "name": "w"}) + ) + + await _base_request_async( + _SimpleModel, + "https://example.com/api", + "PATCH", + body={"key": "val"}, + client=mock_client, + ) + + call_kwargs = mock_client.request.call_args.kwargs + assert call_kwargs["headers"]["Content-Type"] == "application/json" + + @pytest.mark.asyncio + async def test_api_key_adds_authorization_header(self) -> None: + mock_client = AsyncMock(spec=httpx.AsyncClient) + mock_client.request = AsyncMock( + return_value=_make_response(200, json_data={"value": 5, "name": "v"}) + ) + + await _base_request_async( + _SimpleModel, + "https://example.com/api", + "GET", + api_key="secret-key", + client=mock_client, + ) + + call_kwargs = mock_client.request.call_args.kwargs + assert call_kwargs["headers"]["Authorization"] == "Bearer secret-key" + + @pytest.mark.asyncio + async def test_error_response_raises_fetch_error(self) -> None: + mock_client = AsyncMock(spec=httpx.AsyncClient) + mock_client.request = AsyncMock(return_value=_make_response(401, text="Unauthorized")) + + with pytest.raises(FetchError) as exc_info: + await _base_request_async( + _SimpleModel, + "https://example.com/api", + "GET", + client=mock_client, + ) + assert exc_info.value.status == 401 + + @pytest.mark.asyncio + async def test_without_client_creates_temp_client(self) -> None: + mock_response = _make_response(200, json_data={"value": 6, "name": "u"}) + mock_temp = AsyncMock(spec=httpx.AsyncClient) + mock_temp.request = AsyncMock(return_value=mock_response) + mock_temp.__aenter__ = AsyncMock(return_value=mock_temp) + mock_temp.__aexit__ = AsyncMock(return_value=None) + + with patch("httpx.AsyncClient", return_value=mock_temp): + data, _, _ = await _base_request_async( + _SimpleModel, + "https://example.com/api", + "GET", + ) + assert data.value == 6 + + @pytest.mark.asyncio + async def test_get_does_not_send_body(self) -> None: + mock_client = AsyncMock(spec=httpx.AsyncClient) + mock_client.request = AsyncMock( + return_value=_make_response(200, json_data={"value": 7, "name": "t"}) + ) + + await _base_request_async( + _SimpleModel, + "https://example.com/api", + "GET", + body={"should_be_ignored": True}, + client=mock_client, + ) + + call_kwargs = mock_client.request.call_args.kwargs + assert call_kwargs["json"] is None + + +# --------------------------------------------------------------------------- +# _base_request_sync +# --------------------------------------------------------------------------- + + +class TestBaseRequestSync: + def test_get_with_provided_client(self) -> None: + mock_client = MagicMock(spec=httpx.Client) + mock_client.request = MagicMock( + return_value=_make_response(200, json_data={"value": 10, "name": "sync"}) + ) + + data, status, _ = _base_request_sync( + _SimpleModel, + "https://example.com/api", + "GET", + client=mock_client, + ) + + mock_client.request.assert_called_once() + assert data.value == 10 + assert status == 200 + + def test_post_adds_content_type_header(self) -> None: + mock_client = MagicMock(spec=httpx.Client) + mock_client.request = MagicMock( + return_value=_make_response(200, json_data={"value": 11, "name": "s"}) + ) + + _base_request_sync( + _SimpleModel, + "https://example.com/api", + "POST", + body={"k": "v"}, + client=mock_client, + ) + + call_kwargs = mock_client.request.call_args.kwargs + assert call_kwargs["headers"]["Content-Type"] == "application/json" + + def test_patch_adds_content_type_header(self) -> None: + mock_client = MagicMock(spec=httpx.Client) + mock_client.request = MagicMock( + return_value=_make_response(200, json_data={"value": 12, "name": "r"}) + ) + + _base_request_sync( + _SimpleModel, + "https://example.com/api", + "PATCH", + body={"k": "v"}, + client=mock_client, + ) + + call_kwargs = mock_client.request.call_args.kwargs + assert call_kwargs["headers"]["Content-Type"] == "application/json" + + def test_api_key_adds_authorization_header(self) -> None: + mock_client = MagicMock(spec=httpx.Client) + mock_client.request = MagicMock( + return_value=_make_response(200, json_data={"value": 13, "name": "q"}) + ) + + _base_request_sync( + _SimpleModel, + "https://example.com/api", + "GET", + api_key="my-key", + client=mock_client, + ) + + call_kwargs = mock_client.request.call_args.kwargs + assert call_kwargs["headers"]["Authorization"] == "Bearer my-key" + + def test_error_response_raises_fetch_error(self) -> None: + mock_client = MagicMock(spec=httpx.Client) + mock_client.request = MagicMock(return_value=_make_response(403, text="Forbidden")) + + with pytest.raises(FetchError) as exc_info: + _base_request_sync( + _SimpleModel, + "https://example.com/api", + "GET", + client=mock_client, + ) + assert exc_info.value.status == 403 + + def test_without_client_creates_temp_client(self) -> None: + mock_response = _make_response(200, json_data={"value": 14, "name": "p"}) + mock_temp = MagicMock(spec=httpx.Client) + mock_temp.request = MagicMock(return_value=mock_response) + mock_temp.__enter__ = MagicMock(return_value=mock_temp) + mock_temp.__exit__ = MagicMock(return_value=None) + + with patch("httpx.Client", return_value=mock_temp): + data, _, _ = _base_request_sync( + _SimpleModel, + "https://example.com/api", + "GET", + ) + assert data.value == 14 + + def test_params_passed_correctly(self) -> None: + mock_client = MagicMock(spec=httpx.Client) + mock_client.request = MagicMock( + return_value=_make_response(200, json_data={"value": 15, "name": "o"}) + ) + + _base_request_sync( + _SimpleModel, + "https://example.com/api", + "GET", + params={"limit": "10"}, + client=mock_client, + ) + + call_kwargs = mock_client.request.call_args.kwargs + assert call_kwargs["params"] == {"limit": "10"} + + +# --------------------------------------------------------------------------- +# Address derivation helpers (pure functions using real aptos_sdk) +# --------------------------------------------------------------------------- + + +TEST_PACKAGE = "0x" + "ab" * 32 +TEST_PERP_ENGINE = "0x" + "12" * 32 +TEST_ADDR = "0x" + "aa" * 32 + + +class TestGetMarketAddr: + def test_returns_string(self) -> None: + result = get_market_addr("BTC-PERP", TEST_PERP_ENGINE) + assert isinstance(result, str) + assert result.startswith("0x") + + def test_different_market_names_give_different_addresses(self) -> None: + addr1 = get_market_addr("BTC-PERP", TEST_PERP_ENGINE) + addr2 = get_market_addr("ETH-PERP", TEST_PERP_ENGINE) + assert addr1 != addr2 + + def test_same_inputs_give_same_address(self) -> None: + addr1 = get_market_addr("BTC-PERP", TEST_PERP_ENGINE) + addr2 = get_market_addr("BTC-PERP", TEST_PERP_ENGINE) + assert addr1 == addr2 + + def test_different_perp_engines_give_different_addresses(self) -> None: + engine2 = "0x" + "34" * 32 + addr1 = get_market_addr("BTC-PERP", TEST_PERP_ENGINE) + addr2 = get_market_addr("BTC-PERP", engine2) + assert addr1 != addr2 + + +class TestGetPrimarySubaccountAddr: + def test_returns_string(self) -> None: + from decibel._constants import CompatVersion + + result = get_primary_subaccount_addr(TEST_ADDR, CompatVersion.V0_4, TEST_PACKAGE) + assert isinstance(result, str) + assert result.startswith("0x") + + def test_deterministic(self) -> None: + from decibel._constants import CompatVersion + + addr1 = get_primary_subaccount_addr(TEST_ADDR, CompatVersion.V0_4, TEST_PACKAGE) + addr2 = get_primary_subaccount_addr(TEST_ADDR, CompatVersion.V0_4, TEST_PACKAGE) + assert addr1 == addr2 + + def test_different_owners_give_different_addresses(self) -> None: + from decibel._constants import CompatVersion + + addr2 = "0x" + "bb" * 32 + result1 = get_primary_subaccount_addr(TEST_ADDR, CompatVersion.V0_4, TEST_PACKAGE) + result2 = get_primary_subaccount_addr(addr2, CompatVersion.V0_4, TEST_PACKAGE) + assert result1 != result2 + + def test_accepts_account_address_object(self) -> None: + from aptos_sdk.account_address import AccountAddress + + from decibel._constants import CompatVersion + + addr_obj = AccountAddress.from_str(TEST_ADDR) + result = get_primary_subaccount_addr(addr_obj, CompatVersion.V0_4, TEST_PACKAGE) + assert isinstance(result, str) + + +class TestGetTradingCompetitionSubaccountAddr: + def test_returns_string(self) -> None: + result = get_trading_competition_subaccount_addr(TEST_ADDR) + assert isinstance(result, str) + assert result.startswith("0x") + + def test_deterministic(self) -> None: + r1 = get_trading_competition_subaccount_addr(TEST_ADDR) + r2 = get_trading_competition_subaccount_addr(TEST_ADDR) + assert r1 == r2 + + def test_different_accounts_give_different_addresses(self) -> None: + addr2 = "0x" + "cc" * 32 + r1 = get_trading_competition_subaccount_addr(TEST_ADDR) + r2 = get_trading_competition_subaccount_addr(addr2) + assert r1 != r2 + + def test_accepts_account_address_object(self) -> None: + from aptos_sdk.account_address import AccountAddress + + addr_obj = AccountAddress.from_str(TEST_ADDR) + result = get_trading_competition_subaccount_addr(addr_obj) + assert isinstance(result, str) + + +class TestGetVaultShareAddress: + def test_returns_string(self) -> None: + result = get_vault_share_address(TEST_ADDR) + assert isinstance(result, str) + assert result.startswith("0x") + + def test_deterministic(self) -> None: + r1 = get_vault_share_address(TEST_ADDR) + r2 = get_vault_share_address(TEST_ADDR) + assert r1 == r2 + + def test_different_vault_addresses_give_different_shares(self) -> None: + vault2 = "0x" + "dd" * 32 + r1 = get_vault_share_address(TEST_ADDR) + r2 = get_vault_share_address(vault2) + assert r1 != r2 + + +# --------------------------------------------------------------------------- +# extract_vault_address_from_create_tx +# --------------------------------------------------------------------------- + + +class TestExtractVaultAddressFromCreateTx: + def _make_tx(self, vault_val: Any) -> dict[str, Any]: + return { + "events": [ + { + "type": "0xdeadbeef::vault::VaultCreatedEvent", + "data": {"vault": vault_val}, + } + ] + } + + def test_vault_as_string(self) -> None: + tx = self._make_tx("0xabcdef") + result = extract_vault_address_from_create_tx(tx) + assert result == "0xabcdef" + + def test_vault_as_dict_with_inner(self) -> None: + tx = self._make_tx({"inner": "0x123456"}) + result = extract_vault_address_from_create_tx(tx) + assert result == "0x123456" + + def test_no_vault_created_event_raises(self) -> None: + tx: dict[str, Any] = { + "events": [ + { + "type": "0xdeadbeef::other::OtherEvent", + "data": {"foo": "bar"}, + } + ] + } + with pytest.raises(ValueError, match="Unable to extract vault address"): + extract_vault_address_from_create_tx(tx) + + def test_empty_events_raises(self) -> None: + tx: dict[str, Any] = {"events": []} + with pytest.raises(ValueError, match="Unable to extract vault address"): + extract_vault_address_from_create_tx(tx) + + def test_no_events_key_raises(self) -> None: + tx: dict[str, Any] = {} + with pytest.raises(ValueError, match="Unable to extract vault address"): + extract_vault_address_from_create_tx(tx) + + def test_vault_dict_without_inner_raises(self) -> None: + # Dict vault but no "inner" key → should raise + tx = self._make_tx({"other_key": "0xdeadbeef"}) + with pytest.raises(ValueError, match="Unable to extract vault address"): + extract_vault_address_from_create_tx(tx) + + def test_vault_none_skipped_raises(self) -> None: + tx: dict[str, Any] = { + "events": [ + { + "type": "0xdeadbeef::vault::VaultCreatedEvent", + "data": {"vault": None}, + } + ] + } + with pytest.raises(ValueError, match="Unable to extract vault address"): + extract_vault_address_from_create_tx(tx) + + def test_stops_at_first_vault_event(self) -> None: + tx: dict[str, Any] = { + "events": [ + { + "type": "0xdeadbeef::vault::VaultCreatedEvent", + "data": {"vault": "0xfirst"}, + }, + { + "type": "0xdeadbeef::vault::VaultCreatedEvent", + "data": {"vault": "0xsecond"}, + }, + ] + } + result = extract_vault_address_from_create_tx(tx) + assert result == "0xfirst" + + +# --------------------------------------------------------------------------- +# generate_random_replay_protection_nonce +# --------------------------------------------------------------------------- + + +# --------------------------------------------------------------------------- +# Public request wrapper functions (get_request, post_request, etc.) +# --------------------------------------------------------------------------- + + +class TestPublicRequestWrappers: + """Cover the thin wrapper functions that delegate to _base_request_async/sync.""" + + @pytest.mark.asyncio + async def test_get_request_delegates(self) -> None: + mock_client = AsyncMock(spec=httpx.AsyncClient) + mock_client.request = AsyncMock( + return_value=_make_response(200, json_data={"value": 1, "name": "a"}) + ) + data, status, _ = await get_request( + _SimpleModel, "https://example.com/api", client=mock_client + ) + assert data.value == 1 + assert status == 200 + + @pytest.mark.asyncio + async def test_post_request_delegates(self) -> None: + mock_client = AsyncMock(spec=httpx.AsyncClient) + mock_client.request = AsyncMock( + return_value=_make_response(200, json_data={"value": 2, "name": "b"}) + ) + data, _, _ = await post_request( + _SimpleModel, "https://example.com/api", body={"x": 1}, client=mock_client + ) + assert data.value == 2 + + @pytest.mark.asyncio + async def test_patch_request_delegates(self) -> None: + mock_client = AsyncMock(spec=httpx.AsyncClient) + mock_client.request = AsyncMock( + return_value=_make_response(200, json_data={"value": 3, "name": "c"}) + ) + data, _, _ = await patch_request( + _SimpleModel, "https://example.com/api", body={"y": 2}, client=mock_client + ) + assert data.value == 3 + + def test_get_request_sync_delegates(self) -> None: + mock_client = MagicMock(spec=httpx.Client) + mock_client.request = MagicMock( + return_value=_make_response(200, json_data={"value": 4, "name": "d"}) + ) + data, status, _ = get_request_sync( + _SimpleModel, "https://example.com/api", client=mock_client + ) + assert data.value == 4 + assert status == 200 + + def test_post_request_sync_delegates(self) -> None: + mock_client = MagicMock(spec=httpx.Client) + mock_client.request = MagicMock( + return_value=_make_response(200, json_data={"value": 5, "name": "e"}) + ) + data, _, _ = post_request_sync( + _SimpleModel, "https://example.com/api", body={"z": 3}, client=mock_client + ) + assert data.value == 5 + + def test_patch_request_sync_delegates(self) -> None: + mock_client = MagicMock(spec=httpx.Client) + mock_client.request = MagicMock( + return_value=_make_response(200, json_data={"value": 6, "name": "f"}) + ) + data, _, _ = patch_request_sync( + _SimpleModel, "https://example.com/api", body={"w": 4}, client=mock_client + ) + assert data.value == 6 + + +class TestGenerateRandomReplayProtectionNonce: + def test_returns_int_or_none(self) -> None: + for _ in range(20): + result = generate_random_replay_protection_nonce() + assert result is None or isinstance(result, int) + + def test_non_none_result_is_positive(self) -> None: + for _ in range(20): + result = generate_random_replay_protection_nonce() + if result is not None: + assert result > 0 + + def test_returns_none_when_buf_contains_zero(self) -> None: + # When first buf element is 0, should return None + with patch("secrets.randbits", side_effect=[0, 12345]): + result = generate_random_replay_protection_nonce() + assert result is None + + def test_returns_none_when_second_buf_is_zero(self) -> None: + with patch("secrets.randbits", side_effect=[12345, 0]): + result = generate_random_replay_protection_nonce() + assert result is None + + def test_returns_combined_value_when_both_nonzero(self) -> None: + buf0 = 0xDEAD + buf1 = 0xBEEF + with patch("secrets.randbits", side_effect=[buf0, buf1]): + result = generate_random_replay_protection_nonce() + expected = (buf0 << 32) | buf1 + assert result == expected + + def test_typical_run_mostly_returns_int(self) -> None: + """Statistically, with 32-bit values, zero is extremely rare.""" + results = [generate_random_replay_protection_nonce() for _ in range(50)] + non_none = [r for r in results if r is not None] + # At least 40 out of 50 should be non-None (probability of zero ~1 in 4B) + assert len(non_none) >= 40 diff --git a/tests/write/__init__.py b/tests/write/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/write/test_write_dex.py b/tests/write/test_write_dex.py new file mode 100644 index 0000000..a967b90 --- /dev/null +++ b/tests/write/test_write_dex.py @@ -0,0 +1,1803 @@ +""" +Comprehensive unit tests for src/decibel/write/__init__.py. + +Tests cover DecibelWriteDex (async) and DecibelWriteDexSync (sync) classes, +the _round_to_tick_size helper, and all public methods. + +Strategy: mock _send_tx / _send_tx at the instance level so no real HTTP +calls or blockchain interactions happen. The tests verify that: + 1. The correct Move function name is assembled from the package address. + 2. The correct arguments are passed to InputEntryFunctionData. + 3. The return value is constructed correctly from the tx response. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from decibel._order_types import ( + PlaceBulkOrdersFailure, + PlaceBulkOrdersSuccess, + PlaceOrderFailure, + PlaceOrderSuccess, +) +from decibel._subaccount_types import RenameSubaccountArgs +from decibel.write import ( + DecibelWriteDex, + DecibelWriteDexSync, + TimeInForce, + _round_to_tick_size, # type: ignore[attr-defined] +) + +if TYPE_CHECKING: + from decibel._transaction_builder import InputEntryFunctionData + +# --------------------------------------------------------------------------- +# Constants shared across tests (mirror conftest.py) +# --------------------------------------------------------------------------- +TEST_PACKAGE = "0x" + "ab" * 32 +TEST_USDC = "0x" + "cd" * 32 +TEST_PERP_ENGINE = "0x" + "12" * 32 +TEST_ACCOUNT_ADDR = "0x" + "aa" * 32 +TEST_SUBACCOUNT_ADDR = "0x" + "bb" * 32 +TEST_MARKET_NAME = "BTC-USD" +TEST_VAULT_ADDR = "0x" + "cc" * 32 +TEST_TX_HASH = "0xdeadbeef" + + +# --------------------------------------------------------------------------- +# Helpers – build a minimal fake tx response +# --------------------------------------------------------------------------- + + +def _make_tx_response( + order_id: str = "12345", + user_addr: str = TEST_ACCOUNT_ADDR, + event_type: str = "0x1::market_types::OrderEvent", +) -> dict[str, Any]: + return { + "hash": TEST_TX_HASH, + "success": True, + "events": [ + { + "type": event_type, + "data": { + "user": user_addr, + "order_id": order_id, + }, + } + ], + } + + +def _make_twap_tx_response( + order_id: str = "99999", + user_addr: str = TEST_ACCOUNT_ADDR, +) -> dict[str, Any]: + return { + "hash": TEST_TX_HASH, + "success": True, + "events": [ + { + "type": "0x1::async_matching_engine::TwapEvent", + "data": { + "account": user_addr, + "order_id": {"order_id": order_id}, + }, + } + ], + } + + +# --------------------------------------------------------------------------- +# Fixtures – async (DecibelWriteDex) +# --------------------------------------------------------------------------- + + +@pytest.fixture +def write_dex(test_config, mock_account) -> DecibelWriteDex: + """Return a DecibelWriteDex instance with _send_tx mocked out.""" + with patch("decibel.write.BaseSDK.__init__", return_value=None): + dex = DecibelWriteDex.__new__(DecibelWriteDex) + dex._config = test_config + dex._account = mock_account + dex._http_client = AsyncMock() + dex._skip_simulate = False + dex._no_fee_payer = False + dex._node_api_key = None + dex._gas_price_manager = None + dex._time_delta_ms = 0 + dex._chain_id = 2 + dex._abi_registry = MagicMock() + dex._order_status_client = MagicMock() + dex._send_tx = AsyncMock(return_value=_make_tx_response()) + return dex + + +# --------------------------------------------------------------------------- +# Fixtures – sync (DecibelWriteDexSync) +# --------------------------------------------------------------------------- + + +@pytest.fixture +def write_dex_sync(test_config, mock_account) -> DecibelWriteDexSync: + """Return a DecibelWriteDexSync instance with _send_tx mocked out.""" + with patch("decibel.write.BaseSDKSync.__init__", return_value=None): + dex = DecibelWriteDexSync.__new__(DecibelWriteDexSync) + dex._config = test_config + dex._account = mock_account + dex._http_client = MagicMock() + dex._skip_simulate = False + dex._no_fee_payer = False + dex._node_api_key = None + dex._gas_price_manager = None + dex._time_delta_ms = 0 + dex._chain_id = 2 + dex._abi_registry = MagicMock() + dex._order_status_client = MagicMock() + dex._send_tx = MagicMock(return_value=_make_tx_response()) + return dex + + +# =========================================================================== +# Tests for _round_to_tick_size helper +# =========================================================================== + + +class TestRoundToTickSize: + def test_normal_rounding(self) -> None: + result = _round_to_tick_size(105.3, 10) + assert result == 110 + + def test_rounds_down(self) -> None: + result = _round_to_tick_size(104.9, 10) + assert result == 100 + + def test_exact_multiple(self) -> None: + result = _round_to_tick_size(100.0, 10) + assert result == 100 + + def test_zero_value_returns_zero(self) -> None: + result = _round_to_tick_size(0, 10) + assert result == 0.0 + + def test_zero_tick_size_returns_zero(self) -> None: + result = _round_to_tick_size(100, 0) + assert result == 0.0 + + def test_both_zero_returns_zero(self) -> None: + result = _round_to_tick_size(0, 0) + assert result == 0.0 + + def test_float_inputs(self) -> None: + # 1.05 / 0.1 = 10.5 — Python banker's rounding rounds this to 10 (even) + result = _round_to_tick_size(1.05, 0.1) + assert result == pytest.approx(1.0, rel=1e-6) + + def test_float_inputs_rounds_up(self) -> None: + # 1.17 / 0.1 = 11.7 — rounds to 12 → 1.2 + result = _round_to_tick_size(1.17, 0.1) + assert result == pytest.approx(1.2, rel=1e-6) + + def test_small_tick_size(self) -> None: + result = _round_to_tick_size(123.456, 1) + assert result == 123 + + +# =========================================================================== +# Tests for DecibelWriteDex.__init__ +# =========================================================================== + + +class TestDecibelWriteDexInit: + def test_init_creates_order_status_client(self, test_config, mock_account) -> None: + with ( + patch("decibel.write.BaseSDK.__init__", return_value=None), + patch("decibel.write.OrderStatusClient") as mock_osc, + ): + dex = DecibelWriteDex.__new__(DecibelWriteDex) + # Manually set the attribute that BaseSDK.__init__ would set + dex._http_client = AsyncMock() + dex._config = test_config + dex._account = mock_account + + # Call the actual __init__ partially by calling OrderStatusClient directly + # to verify the wiring. We test the __init__ logic here. + mock_osc.return_value = MagicMock() + + # Now test that order_status_client property returns the internal client + dex._order_status_client = mock_osc.return_value + assert dex.order_status_client is mock_osc.return_value + + +# =========================================================================== +# Tests for _extract_order_id_from_transaction +# =========================================================================== + + +class TestExtractOrderId: + def test_extracts_string_order_id_from_order_event(self, write_dex: DecibelWriteDex) -> None: + tx = _make_tx_response(order_id="42", user_addr=TEST_ACCOUNT_ADDR) + result = write_dex._extract_order_id_from_transaction(tx) + assert result == "42" + + def test_extracts_dict_order_id_from_twap_event(self, write_dex: DecibelWriteDex) -> None: + tx = _make_twap_tx_response(order_id="99", user_addr=TEST_ACCOUNT_ADDR) + result = write_dex._extract_order_id_from_transaction(tx) + assert result == "99" + + def test_returns_none_when_no_events(self, write_dex: DecibelWriteDex) -> None: + tx: dict[str, Any] = {"hash": TEST_TX_HASH, "success": True} + result = write_dex._extract_order_id_from_transaction(tx) + assert result is None + + def test_returns_none_when_events_is_none(self, write_dex: DecibelWriteDex) -> None: + tx: dict[str, Any] = {"hash": TEST_TX_HASH, "events": None} + result = write_dex._extract_order_id_from_transaction(tx) + assert result is None + + def test_returns_none_when_event_type_does_not_match(self, write_dex: DecibelWriteDex) -> None: + tx: dict[str, Any] = { + "hash": TEST_TX_HASH, + "events": [{"type": "0x1::some::OtherEvent", "data": {"user": TEST_ACCOUNT_ADDR}}], + } + result = write_dex._extract_order_id_from_transaction(tx) + assert result is None + + def test_returns_none_when_user_address_does_not_match( + self, write_dex: DecibelWriteDex + ) -> None: + tx = _make_tx_response(order_id="1", user_addr="0x" + "ff" * 32) + result = write_dex._extract_order_id_from_transaction(tx) + assert result is None + + def test_uses_subaccount_addr_when_provided(self, write_dex: DecibelWriteDex) -> None: + tx = _make_tx_response(order_id="777", user_addr=TEST_SUBACCOUNT_ADDR) + result = write_dex._extract_order_id_from_transaction( + tx, subaccount_addr=TEST_SUBACCOUNT_ADDR + ) + assert result == "777" + + def test_returns_none_when_event_data_is_none(self, write_dex: DecibelWriteDex) -> None: + tx: dict[str, Any] = { + "hash": TEST_TX_HASH, + "events": [{"type": "0x1::market_types::OrderEvent", "data": None}], + } + result = write_dex._extract_order_id_from_transaction(tx) + assert result is None + + def test_returns_none_when_order_id_missing_from_nested_dict( + self, write_dex: DecibelWriteDex + ) -> None: + tx: dict[str, Any] = { + "hash": TEST_TX_HASH, + "events": [ + { + "type": "0x1::market_types::OrderEvent", + "data": { + "user": TEST_ACCOUNT_ADDR, + "order_id": {}, # dict with no "order_id" key + }, + } + ], + } + result = write_dex._extract_order_id_from_transaction(tx) + assert result is None + + def test_handles_exception_gracefully(self, write_dex: DecibelWriteDex) -> None: + # Pass an object that causes an error during iteration + result = write_dex._extract_order_id_from_transaction({"events": "not-a-list"}) # type: ignore[arg-type] + assert result is None + + def test_twap_event_with_account_field(self, write_dex: DecibelWriteDex) -> None: + tx: dict[str, Any] = { + "hash": TEST_TX_HASH, + "events": [ + { + "type": "0x1::async_matching_engine::TwapEvent", + "data": { + "account": TEST_ACCOUNT_ADDR, + "order_id": "54321", + }, + } + ], + } + result = write_dex._extract_order_id_from_transaction(tx) + assert result == "54321" + + +# =========================================================================== +# Tests for send_subaccount_tx / with_subaccount +# =========================================================================== + + +class TestSendSubaccountTx: + async def test_uses_primary_subaccount_when_none_provided( + self, write_dex: DecibelWriteDex + ) -> None: + called_with: list[str] = [] + + async def fake_tx(addr: str) -> dict[str, Any]: + called_with.append(addr) + return {"hash": TEST_TX_HASH} + + with patch("decibel.write.get_primary_subaccount_addr", return_value=TEST_SUBACCOUNT_ADDR): + result = await write_dex.send_subaccount_tx(fake_tx) + + assert called_with == [TEST_SUBACCOUNT_ADDR] + assert result == {"hash": TEST_TX_HASH} + + async def test_uses_provided_subaccount_addr(self, write_dex: DecibelWriteDex) -> None: + called_with: list[str] = [] + + async def fake_tx(addr: str) -> dict[str, Any]: + called_with.append(addr) + return {"hash": TEST_TX_HASH} + + result = await write_dex.send_subaccount_tx(fake_tx, subaccount_addr=TEST_SUBACCOUNT_ADDR) + assert called_with == [TEST_SUBACCOUNT_ADDR] + assert result == {"hash": TEST_TX_HASH} + + async def test_with_subaccount_uses_primary_when_none(self, write_dex: DecibelWriteDex) -> None: + called_with: list[str] = [] + + async def fn(addr: str) -> str: + called_with.append(addr) + return "result" + + with patch("decibel.write.get_primary_subaccount_addr", return_value=TEST_SUBACCOUNT_ADDR): + result = await write_dex.with_subaccount(fn) + + assert called_with == [TEST_SUBACCOUNT_ADDR] + assert result == "result" + + async def test_with_subaccount_uses_provided_addr(self, write_dex: DecibelWriteDex) -> None: + called_with: list[str] = [] + + async def fn(addr: str) -> str: + called_with.append(addr) + return "result" + + result = await write_dex.with_subaccount(fn, subaccount_addr=TEST_SUBACCOUNT_ADDR) + assert called_with == [TEST_SUBACCOUNT_ADDR] + assert result == "result" + + +# =========================================================================== +# Tests for rename_subaccount (async) +# =========================================================================== + + +class TestRenameSubaccount: + async def test_posts_to_correct_url(self, write_dex: DecibelWriteDex) -> None: + mock_result = (MagicMock(), 200, "OK") + args = RenameSubaccountArgs(subaccountAddress=TEST_SUBACCOUNT_ADDR, newName="My Account") + + with patch("decibel.write.post_request", return_value=mock_result) as mock_post: + result = await write_dex.rename_subaccount(args) + + expected_url = ( + f"{write_dex._config.trading_http_url}/api/v1/subaccounts/{TEST_SUBACCOUNT_ADDR}" + ) + mock_post.assert_awaited_once() + call_kwargs = mock_post.call_args + assert call_kwargs.args[1] == expected_url + assert call_kwargs.kwargs["body"] == {"name": "My Account"} + assert result == mock_result + + +# =========================================================================== +# Tests for create_subaccount (async) +# =========================================================================== + + +class TestCreateSubaccount: + async def test_sends_correct_function(self, write_dex: DecibelWriteDex) -> None: + result = await write_dex.create_subaccount() + + write_dex._send_tx.assert_awaited_once() + payload: InputEntryFunctionData = write_dex._send_tx.call_args.args[0] + assert payload.function == f"{TEST_PACKAGE}::dex_accounts_entry::create_new_subaccount" + assert payload.type_arguments == [] + assert payload.function_arguments == [] + assert result == _make_tx_response() + + +# =========================================================================== +# Tests for deposit (async) +# =========================================================================== + + +class TestDeposit: + async def test_deposit_uses_primary_subaccount_by_default( + self, write_dex: DecibelWriteDex + ) -> None: + with patch("decibel.write.get_primary_subaccount_addr", return_value=TEST_SUBACCOUNT_ADDR): + await write_dex.deposit(amount=1_000_000) + + write_dex._send_tx.assert_awaited_once() + payload: InputEntryFunctionData = write_dex._send_tx.call_args.args[0] + assert payload.function == f"{TEST_PACKAGE}::dex_accounts_entry::deposit_to_subaccount_at" + assert payload.function_arguments == [TEST_SUBACCOUNT_ADDR, TEST_USDC, 1_000_000] + + async def test_deposit_uses_explicit_subaccount(self, write_dex: DecibelWriteDex) -> None: + await write_dex.deposit(amount=500, subaccount_addr=TEST_SUBACCOUNT_ADDR) + + write_dex._send_tx.assert_awaited_once() + payload: InputEntryFunctionData = write_dex._send_tx.call_args.args[0] + assert payload.function_arguments[0] == TEST_SUBACCOUNT_ADDR + assert payload.function_arguments[2] == 500 + + async def test_deposit_passes_timeouts(self, write_dex: DecibelWriteDex) -> None: + with patch("decibel.write.get_primary_subaccount_addr", return_value=TEST_SUBACCOUNT_ADDR): + await write_dex.deposit( + amount=100, + txn_submit_timeout=5.0, + txn_confirm_timeout=15.0, + ) + + call_kwargs = write_dex._send_tx.call_args.kwargs + assert call_kwargs["txn_submit_timeout"] == 5.0 + assert call_kwargs["txn_confirm_timeout"] == 15.0 + + +# =========================================================================== +# Tests for withdraw (async) +# =========================================================================== + + +class TestWithdraw: + async def test_withdraw_uses_correct_function(self, write_dex: DecibelWriteDex) -> None: + with patch("decibel.write.get_primary_subaccount_addr", return_value=TEST_SUBACCOUNT_ADDR): + await write_dex.withdraw(amount=200) + + payload: InputEntryFunctionData = write_dex._send_tx.call_args.args[0] + assert payload.function == f"{TEST_PACKAGE}::dex_accounts_entry::withdraw_from_subaccount" + assert payload.function_arguments == [TEST_SUBACCOUNT_ADDR, TEST_USDC, 200] + + async def test_withdraw_passes_timeouts(self, write_dex: DecibelWriteDex) -> None: + await write_dex.withdraw( + amount=100, + subaccount_addr=TEST_SUBACCOUNT_ADDR, + txn_submit_timeout=3.0, + txn_confirm_timeout=10.0, + ) + call_kwargs = write_dex._send_tx.call_args.kwargs + assert call_kwargs["txn_submit_timeout"] == 3.0 + assert call_kwargs["txn_confirm_timeout"] == 10.0 + + +# =========================================================================== +# Tests for configure_user_settings_for_market (async) +# =========================================================================== + + +class TestConfigureUserSettingsForMarket: + async def test_sends_correct_payload(self, write_dex: DecibelWriteDex) -> None: + market_addr = "0x" + "11" * 32 + await write_dex.configure_user_settings_for_market( + market_addr=market_addr, + subaccount_addr=TEST_SUBACCOUNT_ADDR, + is_cross=True, + user_leverage=10, + ) + payload: InputEntryFunctionData = write_dex._send_tx.call_args.args[0] + assert ( + payload.function + == f"{TEST_PACKAGE}::dex_accounts_entry::configure_user_settings_for_market" + ) + assert payload.function_arguments == [ + TEST_SUBACCOUNT_ADDR, + market_addr, + True, + 10, + ] + + +# =========================================================================== +# Tests for place_order (async) +# =========================================================================== + + +class TestPlaceOrder: + async def test_place_order_success(self, write_dex: DecibelWriteDex) -> None: + write_dex._send_tx.return_value = _make_tx_response(order_id="123") + with ( + patch("decibel.write.get_market_addr", return_value="0x" + "11" * 32), + patch("decibel.write.get_primary_subaccount_addr", return_value=TEST_SUBACCOUNT_ADDR), + ): + result = await write_dex.place_order( + market_name=TEST_MARKET_NAME, + price=50000, + size=1, + is_buy=True, + time_in_force=TimeInForce.GoodTillCanceled, + is_reduce_only=False, + ) + + assert isinstance(result, PlaceOrderSuccess) + assert result.success is True + assert result.order_id == "123" + assert result.transaction_hash == TEST_TX_HASH + + async def test_place_order_with_tick_size_rounds_price( + self, write_dex: DecibelWriteDex + ) -> None: + # 50004 / 10 = 5000.4 which rounds to 5000 + # 50006 / 10 = 5000.6 which rounds to 5001 → 50010 + with ( + patch("decibel.write.get_market_addr", return_value="0x" + "11" * 32), + patch("decibel.write.get_primary_subaccount_addr", return_value=TEST_SUBACCOUNT_ADDR), + ): + await write_dex.place_order( + market_name=TEST_MARKET_NAME, + price=50006, + size=1, + is_buy=True, + time_in_force=TimeInForce.GoodTillCanceled, + is_reduce_only=False, + tick_size=10, + ) + + payload: InputEntryFunctionData = write_dex._send_tx.call_args.args[0] + final_price = payload.function_arguments[2] + assert final_price == 50010 # round(50006/10)*10 = round(5000.6)*10 = 5001*10 + + async def test_place_order_with_stop_price(self, write_dex: DecibelWriteDex) -> None: + with ( + patch("decibel.write.get_market_addr", return_value="0x" + "11" * 32), + patch("decibel.write.get_primary_subaccount_addr", return_value=TEST_SUBACCOUNT_ADDR), + ): + await write_dex.place_order( + market_name=TEST_MARKET_NAME, + price=50000, + size=1, + is_buy=False, + time_in_force=TimeInForce.GoodTillCanceled, + is_reduce_only=True, + stop_price=49000, + tick_size=10, + ) + + payload: InputEntryFunctionData = write_dex._send_tx.call_args.args[0] + # stop_price should be rounded: round(49000/10)*10 = 49000 + assert payload.function_arguments[8] == 49000 + + async def test_place_order_exception_returns_failure(self, write_dex: DecibelWriteDex) -> None: + write_dex._send_tx.side_effect = RuntimeError("Network error") + with ( + patch("decibel.write.get_market_addr", return_value="0x" + "11" * 32), + patch("decibel.write.get_primary_subaccount_addr", return_value=TEST_SUBACCOUNT_ADDR), + ): + result = await write_dex.place_order( + market_name=TEST_MARKET_NAME, + price=50000, + size=1, + is_buy=True, + time_in_force=TimeInForce.GoodTillCanceled, + is_reduce_only=False, + ) + + assert isinstance(result, PlaceOrderFailure) + assert result.success is False + assert "RuntimeError" in result.error + + async def test_place_order_function_name(self, write_dex: DecibelWriteDex) -> None: + with ( + patch("decibel.write.get_market_addr", return_value="0x" + "11" * 32), + patch("decibel.write.get_primary_subaccount_addr", return_value=TEST_SUBACCOUNT_ADDR), + ): + await write_dex.place_order( + market_name=TEST_MARKET_NAME, + price=100, + size=0.5, + is_buy=True, + time_in_force=TimeInForce.PostOnly, + is_reduce_only=False, + ) + + payload: InputEntryFunctionData = write_dex._send_tx.call_args.args[0] + assert payload.function == f"{TEST_PACKAGE}::dex_accounts_entry::place_order_to_subaccount" + + async def test_place_order_with_tp_sl_prices(self, write_dex: DecibelWriteDex) -> None: + write_dex._send_tx.return_value = _make_tx_response() + with ( + patch("decibel.write.get_market_addr", return_value="0x" + "11" * 32), + patch("decibel.write.get_primary_subaccount_addr", return_value=TEST_SUBACCOUNT_ADDR), + ): + await write_dex.place_order( + market_name=TEST_MARKET_NAME, + price=50000, + size=1, + is_buy=True, + time_in_force=TimeInForce.GoodTillCanceled, + is_reduce_only=False, + tp_trigger_price=55000, + tp_limit_price=56000, + sl_trigger_price=45000, + sl_limit_price=44000, + tick_size=100, + ) + + payload: InputEntryFunctionData = write_dex._send_tx.call_args.args[0] + # Arguments: addr, market, price, size, is_buy, tif, reduce_only, + # client_order_id, stop_price, tp_trigger, tp_limit, sl_trigger, sl_limit, + # builder_addr, builder_fee + assert payload.function_arguments[9] == 55000 # tp_trigger rounded + assert payload.function_arguments[10] == 56000 # tp_limit rounded + assert payload.function_arguments[11] == 45000 # sl_trigger rounded + assert payload.function_arguments[12] == 44000 # sl_limit rounded + + +# =========================================================================== +# Tests for trigger_matching (async) +# =========================================================================== + + +class TestTriggerMatching: + async def test_sends_correct_payload(self, write_dex: DecibelWriteDex) -> None: + market_addr = "0x" + "11" * 32 + result = await write_dex.trigger_matching(market_addr=market_addr, max_work_unit=100) + + payload: InputEntryFunctionData = write_dex._send_tx.call_args.args[0] + assert ( + payload.function == f"{TEST_PACKAGE}::public_apis::process_perp_market_pending_requests" + ) + assert payload.function_arguments == [market_addr, 100] + assert result == {"success": True, "transactionHash": TEST_TX_HASH} + + +# =========================================================================== +# Tests for place_twap_order (async) +# =========================================================================== + + +class TestPlaceTwapOrder: + async def test_place_twap_order_success(self, write_dex: DecibelWriteDex) -> None: + write_dex._send_tx.return_value = _make_twap_tx_response(order_id="88") + + with ( + patch("decibel.write.get_market_addr", return_value="0x" + "11" * 32), + patch("decibel.write.get_primary_subaccount_addr", return_value=TEST_SUBACCOUNT_ADDR), + ): + result = await write_dex.place_twap_order( + market_name=TEST_MARKET_NAME, + size=1, + is_buy=True, + is_reduce_only=False, + twap_frequency_seconds=60, + twap_duration_seconds=3600, + ) + + assert isinstance(result, PlaceOrderSuccess) + assert result.success is True + assert result.order_id == "88" + + async def test_place_twap_order_function_name(self, write_dex: DecibelWriteDex) -> None: + with ( + patch("decibel.write.get_market_addr", return_value="0x" + "11" * 32), + patch("decibel.write.get_primary_subaccount_addr", return_value=TEST_SUBACCOUNT_ADDR), + ): + await write_dex.place_twap_order( + market_name=TEST_MARKET_NAME, + size=0.5, + is_buy=False, + is_reduce_only=True, + twap_frequency_seconds=30, + twap_duration_seconds=1800, + ) + + payload: InputEntryFunctionData = write_dex._send_tx.call_args.args[0] + assert ( + payload.function + == f"{TEST_PACKAGE}::dex_accounts_entry::place_twap_order_to_subaccount_v2" + ) + + async def test_place_twap_order_arguments(self, write_dex: DecibelWriteDex) -> None: + market_addr = "0x" + "11" * 32 + builder_addr = "0x" + "ee" * 32 + + with ( + patch("decibel.write.get_market_addr", return_value=market_addr), + patch("decibel.write.get_primary_subaccount_addr", return_value=TEST_SUBACCOUNT_ADDR), + ): + await write_dex.place_twap_order( + market_name=TEST_MARKET_NAME, + size=2, + is_buy=True, + is_reduce_only=False, + twap_frequency_seconds=60, + twap_duration_seconds=3600, + client_order_id="my-order", + builder_address=builder_addr, + builder_fees=0.001, + ) + + payload: InputEntryFunctionData = write_dex._send_tx.call_args.args[0] + args = payload.function_arguments + # addr, market_addr, size, is_buy, is_reduce_only, client_order_id, + # twap_frequency_seconds, twap_duration_seconds, builder_address, builder_fees + assert args[0] == TEST_SUBACCOUNT_ADDR + assert args[1] == market_addr + assert args[2] == 2 + assert args[3] is True + assert args[4] is False + assert args[5] == "my-order" + assert args[6] == 60 + assert args[7] == 3600 + assert args[8] == builder_addr + assert args[9] == 0.001 + + +# =========================================================================== +# Tests for cancel_order (async) +# =========================================================================== + + +class TestCancelOrder: + async def test_cancel_order_by_market_name(self, write_dex: DecibelWriteDex) -> None: + market_addr = "0x" + "11" * 32 + with ( + patch("decibel.write.get_market_addr", return_value=market_addr), + patch("decibel.write.get_primary_subaccount_addr", return_value=TEST_SUBACCOUNT_ADDR), + ): + await write_dex.cancel_order(order_id=99, market_name=TEST_MARKET_NAME) + + payload: InputEntryFunctionData = write_dex._send_tx.call_args.args[0] + assert payload.function == f"{TEST_PACKAGE}::dex_accounts_entry::cancel_order_to_subaccount" + assert payload.function_arguments == [TEST_SUBACCOUNT_ADDR, 99, market_addr] + + async def test_cancel_order_by_market_addr(self, write_dex: DecibelWriteDex) -> None: + market_addr = "0x" + "11" * 32 + with patch("decibel.write.get_primary_subaccount_addr", return_value=TEST_SUBACCOUNT_ADDR): + await write_dex.cancel_order(order_id="55", market_addr=market_addr) + + payload: InputEntryFunctionData = write_dex._send_tx.call_args.args[0] + assert payload.function_arguments == [TEST_SUBACCOUNT_ADDR, 55, market_addr] + + async def test_cancel_order_raises_when_no_market(self, write_dex: DecibelWriteDex) -> None: + with pytest.raises(ValueError, match="Either market_name or market_addr must be provided"): + await write_dex.cancel_order(order_id=1) + + async def test_cancel_order_converts_str_order_id_to_int( + self, write_dex: DecibelWriteDex + ) -> None: + market_addr = "0x" + "11" * 32 + with patch("decibel.write.get_primary_subaccount_addr", return_value=TEST_SUBACCOUNT_ADDR): + await write_dex.cancel_order(order_id="123", market_addr=market_addr) + + payload: InputEntryFunctionData = write_dex._send_tx.call_args.args[0] + assert payload.function_arguments[1] == 123 # integer, not string + + +# =========================================================================== +# Tests for place_bulk_orders (async) +# =========================================================================== + + +class TestPlaceBulkOrders: + async def test_place_bulk_orders_success(self, write_dex: DecibelWriteDex) -> None: + market_addr = "0x" + "11" * 32 + with ( + patch("decibel.write.get_market_addr", return_value=market_addr), + patch("decibel.write.get_primary_subaccount_addr", return_value=TEST_SUBACCOUNT_ADDR), + ): + result = await write_dex.place_bulk_orders( + market_name=TEST_MARKET_NAME, + sequence_number=1, + bid_prices=[100, 99], + bid_sizes=[10, 20], + ask_prices=[101, 102], + ask_sizes=[10, 20], + ) + + assert isinstance(result, PlaceBulkOrdersSuccess) + assert result.success is True + assert result.transaction_hash == TEST_TX_HASH + + async def test_place_bulk_orders_function_name(self, write_dex: DecibelWriteDex) -> None: + market_addr = "0x" + "11" * 32 + with ( + patch("decibel.write.get_market_addr", return_value=market_addr), + patch("decibel.write.get_primary_subaccount_addr", return_value=TEST_SUBACCOUNT_ADDR), + ): + await write_dex.place_bulk_orders( + market_name=TEST_MARKET_NAME, + sequence_number=1, + bid_prices=[100], + bid_sizes=[10], + ask_prices=[101], + ask_sizes=[10], + ) + + payload: InputEntryFunctionData = write_dex._send_tx.call_args.args[0] + assert ( + payload.function + == f"{TEST_PACKAGE}::dex_accounts_entry::place_bulk_orders_to_subaccount" + ) + + async def test_place_bulk_orders_arguments(self, write_dex: DecibelWriteDex) -> None: + market_addr = "0x" + "11" * 32 + builder_addr = "0x" + "ee" * 32 + with ( + patch("decibel.write.get_market_addr", return_value=market_addr), + patch("decibel.write.get_primary_subaccount_addr", return_value=TEST_SUBACCOUNT_ADDR), + ): + await write_dex.place_bulk_orders( + market_name=TEST_MARKET_NAME, + sequence_number=5, + bid_prices=[100, 99], + bid_sizes=[10, 20], + ask_prices=[101, 102], + ask_sizes=[10, 20], + builder_addr=builder_addr, + builder_fee=50, + ) + + payload: InputEntryFunctionData = write_dex._send_tx.call_args.args[0] + args = payload.function_arguments + assert args[0] == TEST_SUBACCOUNT_ADDR + assert args[1] == market_addr + assert args[2] == 5 + assert args[3] == [100, 99] + assert args[4] == [10, 20] + assert args[5] == [101, 102] + assert args[6] == [10, 20] + assert args[7] == builder_addr + assert args[8] == 50 + + async def test_place_bulk_orders_failure_on_exception(self, write_dex: DecibelWriteDex) -> None: + write_dex._send_tx.side_effect = RuntimeError("Connection failed") + with ( + patch("decibel.write.get_market_addr", return_value="0x" + "11" * 32), + patch("decibel.write.get_primary_subaccount_addr", return_value=TEST_SUBACCOUNT_ADDR), + ): + result = await write_dex.place_bulk_orders( + market_name=TEST_MARKET_NAME, + sequence_number=1, + bid_prices=[100], + bid_sizes=[10], + ask_prices=[101], + ask_sizes=[10], + ) + + assert isinstance(result, PlaceBulkOrdersFailure) + assert "Connection failed" in result.error + + +# =========================================================================== +# Tests for cancel_bulk_order (async) +# =========================================================================== + + +class TestCancelBulkOrder: + async def test_sends_correct_function(self, write_dex: DecibelWriteDex) -> None: + market_addr = "0x" + "11" * 32 + with ( + patch("decibel.write.get_market_addr", return_value=market_addr), + patch("decibel.write.get_primary_subaccount_addr", return_value=TEST_SUBACCOUNT_ADDR), + ): + await write_dex.cancel_bulk_order(market_name=TEST_MARKET_NAME) + + payload: InputEntryFunctionData = write_dex._send_tx.call_args.args[0] + assert ( + payload.function + == f"{TEST_PACKAGE}::dex_accounts_entry::cancel_bulk_order_to_subaccount" + ) + assert payload.function_arguments == [TEST_SUBACCOUNT_ADDR, market_addr] + + +# =========================================================================== +# Tests for cancel_client_order (async) +# =========================================================================== + + +class TestCancelClientOrder: + async def test_sends_correct_function(self, write_dex: DecibelWriteDex) -> None: + market_addr = "0x" + "11" * 32 + with ( + patch("decibel.write.get_market_addr", return_value=market_addr), + patch("decibel.write.get_primary_subaccount_addr", return_value=TEST_SUBACCOUNT_ADDR), + ): + await write_dex.cancel_client_order( + client_order_id="my-order-id", + market_name=TEST_MARKET_NAME, + ) + + payload: InputEntryFunctionData = write_dex._send_tx.call_args.args[0] + assert ( + payload.function + == f"{TEST_PACKAGE}::dex_accounts_entry::cancel_client_order_to_subaccount" + ) + assert payload.function_arguments == [ + TEST_SUBACCOUNT_ADDR, + "my-order-id", + market_addr, + ] + + +# =========================================================================== +# Tests for delegate_trading_to_for_subaccount (async) +# =========================================================================== + + +class TestDelegateTradingToForSubaccount: + async def test_sends_correct_function(self, write_dex: DecibelWriteDex) -> None: + delegate_addr = "0x" + "dd" * 32 + await write_dex.delegate_trading_to_for_subaccount( + subaccount_addr=TEST_SUBACCOUNT_ADDR, + account_to_delegate_to=delegate_addr, + ) + + payload: InputEntryFunctionData = write_dex._send_tx.call_args.args[0] + assert ( + payload.function + == f"{TEST_PACKAGE}::dex_accounts_entry::delegate_trading_to_for_subaccount" + ) + assert payload.function_arguments == [ + TEST_SUBACCOUNT_ADDR, + delegate_addr, + None, + ] + + async def test_passes_expiration(self, write_dex: DecibelWriteDex) -> None: + delegate_addr = "0x" + "dd" * 32 + await write_dex.delegate_trading_to_for_subaccount( + subaccount_addr=TEST_SUBACCOUNT_ADDR, + account_to_delegate_to=delegate_addr, + expiration_timestamp_secs=9999999, + ) + + payload: InputEntryFunctionData = write_dex._send_tx.call_args.args[0] + assert payload.function_arguments[2] == 9999999 + + +# =========================================================================== +# Tests for revoke_delegation (async) +# =========================================================================== + + +class TestRevokeDelegation: + async def test_sends_correct_function(self, write_dex: DecibelWriteDex) -> None: + revoking_addr = "0x" + "dd" * 32 + with patch("decibel.write.get_primary_subaccount_addr", return_value=TEST_SUBACCOUNT_ADDR): + await write_dex.revoke_delegation(account_to_revoke=revoking_addr) + + payload: InputEntryFunctionData = write_dex._send_tx.call_args.args[0] + assert payload.function == f"{TEST_PACKAGE}::dex_accounts_entry::revoke_delegation" + assert payload.function_arguments == [TEST_SUBACCOUNT_ADDR, revoking_addr] + + +# =========================================================================== +# Tests for place_tp_sl_order_for_position (async) +# =========================================================================== + + +class TestPlaceTpSlOrderForPosition: + async def test_sends_correct_function_with_tick_size(self, write_dex: DecibelWriteDex) -> None: + market_addr = "0x" + "11" * 32 + with patch("decibel.write.get_primary_subaccount_addr", return_value=TEST_SUBACCOUNT_ADDR): + await write_dex.place_tp_sl_order_for_position( + market_addr=market_addr, + tp_trigger_price=55000, + tp_limit_price=56000, + tp_size=1, + sl_trigger_price=45000, + sl_limit_price=44000, + sl_size=1, + tick_size=100, + ) + + payload: InputEntryFunctionData = write_dex._send_tx.call_args.args[0] + assert ( + payload.function + == f"{TEST_PACKAGE}::dex_accounts_entry::place_tp_sl_order_for_position" + ) + # addr, market_addr, tp_trigger, tp_limit, tp_size, + # sl_trigger, sl_limit, sl_size, None, None + args = payload.function_arguments + assert args[0] == TEST_SUBACCOUNT_ADDR + assert args[1] == market_addr + assert args[2] == 55000 # tp_trigger rounded to nearest 100 + assert args[8] is None # trailing None + assert args[9] is None # trailing None + + async def test_passes_none_for_trailing_args(self, write_dex: DecibelWriteDex) -> None: + market_addr = "0x" + "11" * 32 + with patch("decibel.write.get_primary_subaccount_addr", return_value=TEST_SUBACCOUNT_ADDR): + await write_dex.place_tp_sl_order_for_position( + market_addr=market_addr, + ) + + payload: InputEntryFunctionData = write_dex._send_tx.call_args.args[0] + args = payload.function_arguments + # Trailing args should be None, None + assert args[-2] is None + assert args[-1] is None + + +# =========================================================================== +# Tests for update_tp_order_for_position (async) +# =========================================================================== + + +class TestUpdateTpOrderForPosition: + async def test_sends_correct_function(self, write_dex: DecibelWriteDex) -> None: + market_addr = "0x" + "11" * 32 + with patch("decibel.write.get_primary_subaccount_addr", return_value=TEST_SUBACCOUNT_ADDR): + await write_dex.update_tp_order_for_position( + market_addr=market_addr, + prev_order_id="42", + tp_trigger_price=55000.0, + ) + + payload: InputEntryFunctionData = write_dex._send_tx.call_args.args[0] + assert ( + payload.function == f"{TEST_PACKAGE}::dex_accounts_entry::update_tp_order_for_position" + ) + args = payload.function_arguments + assert args[0] == TEST_SUBACCOUNT_ADDR + assert args[1] == 42 # int conversion + assert args[2] == market_addr + assert args[3] == 55000.0 + + +# =========================================================================== +# Tests for update_sl_order_for_position (async) +# =========================================================================== + + +class TestUpdateSlOrderForPosition: + async def test_sends_correct_function(self, write_dex: DecibelWriteDex) -> None: + market_addr = "0x" + "11" * 32 + with patch("decibel.write.get_primary_subaccount_addr", return_value=TEST_SUBACCOUNT_ADDR): + await write_dex.update_sl_order_for_position( + market_addr=market_addr, + prev_order_id=99, + sl_trigger_price=44000.0, + ) + + payload: InputEntryFunctionData = write_dex._send_tx.call_args.args[0] + assert ( + payload.function == f"{TEST_PACKAGE}::dex_accounts_entry::update_sl_order_for_position" + ) + args = payload.function_arguments + assert args[1] == 99 + assert args[3] == 44000.0 + + +# =========================================================================== +# Tests for cancel_tp_sl_order_for_position (async) +# =========================================================================== + + +class TestCancelTpSlOrderForPosition: + async def test_sends_correct_function(self, write_dex: DecibelWriteDex) -> None: + market_addr = "0x" + "11" * 32 + with patch("decibel.write.get_primary_subaccount_addr", return_value=TEST_SUBACCOUNT_ADDR): + await write_dex.cancel_tp_sl_order_for_position(market_addr=market_addr, order_id="77") + + payload: InputEntryFunctionData = write_dex._send_tx.call_args.args[0] + assert ( + payload.function + == f"{TEST_PACKAGE}::dex_accounts_entry::cancel_tp_sl_order_for_position" + ) + assert payload.function_arguments == [TEST_SUBACCOUNT_ADDR, market_addr, 77] + + +# =========================================================================== +# Tests for cancel_twap_order (async) +# =========================================================================== + + +class TestCancelTwapOrder: + async def test_sends_correct_function(self, write_dex: DecibelWriteDex) -> None: + market_addr = "0x" + "11" * 32 + with patch("decibel.write.get_primary_subaccount_addr", return_value=TEST_SUBACCOUNT_ADDR): + await write_dex.cancel_twap_order(market_addr=market_addr, order_id=33) + + payload: InputEntryFunctionData = write_dex._send_tx.call_args.args[0] + assert ( + payload.function + == f"{TEST_PACKAGE}::dex_accounts_entry::cancel_twap_orders_to_subaccount" + ) + assert payload.function_arguments == [TEST_SUBACCOUNT_ADDR, market_addr, 33] + + +# =========================================================================== +# Tests for deactivate_subaccount (async) +# =========================================================================== + + +class TestDeactivateSubaccount: + async def test_sends_correct_function(self, write_dex: DecibelWriteDex) -> None: + await write_dex.deactivate_subaccount(subaccount_addr=TEST_SUBACCOUNT_ADDR) + + payload: InputEntryFunctionData = write_dex._send_tx.call_args.args[0] + assert payload.function == f"{TEST_PACKAGE}::dex_accounts_entry::deactivate_subaccount" + assert payload.function_arguments == [TEST_SUBACCOUNT_ADDR, True] + + async def test_passes_revoke_all_delegations_false(self, write_dex: DecibelWriteDex) -> None: + await write_dex.deactivate_subaccount( + subaccount_addr=TEST_SUBACCOUNT_ADDR, revoke_all_delegations=False + ) + + payload: InputEntryFunctionData = write_dex._send_tx.call_args.args[0] + assert payload.function_arguments[1] is False + + +# =========================================================================== +# Tests for activate_vault (async) +# =========================================================================== + + +class TestActivateVault: + async def test_sends_correct_function(self, write_dex: DecibelWriteDex) -> None: + await write_dex.activate_vault(vault_address=TEST_VAULT_ADDR) + + payload: InputEntryFunctionData = write_dex._send_tx.call_args.args[0] + assert payload.function == f"{TEST_PACKAGE}::vault_api::activate_vault" + assert payload.function_arguments == [TEST_VAULT_ADDR] + + +# =========================================================================== +# Tests for deposit_to_vault (async) +# =========================================================================== + + +class TestDepositToVault: + async def test_sends_correct_function(self, write_dex: DecibelWriteDex) -> None: + await write_dex.deposit_to_vault( + vault_address=TEST_VAULT_ADDR, + amount=500.0, + subaccount_addr=TEST_SUBACCOUNT_ADDR, + ) + + payload: InputEntryFunctionData = write_dex._send_tx.call_args.args[0] + assert payload.function == f"{TEST_PACKAGE}::dex_accounts_entry::contribute_to_vault" + assert payload.function_arguments == [ + TEST_SUBACCOUNT_ADDR, + TEST_VAULT_ADDR, + TEST_USDC, + 500.0, + ] + + +# =========================================================================== +# Tests for withdraw_from_vault (async) +# =========================================================================== + + +class TestWithdrawFromVault: + async def test_sends_correct_function(self, write_dex: DecibelWriteDex) -> None: + with patch("decibel.write.get_primary_subaccount_addr", return_value=TEST_SUBACCOUNT_ADDR): + await write_dex.withdraw_from_vault(vault_address=TEST_VAULT_ADDR, shares=10.0) + + payload: InputEntryFunctionData = write_dex._send_tx.call_args.args[0] + assert payload.function == f"{TEST_PACKAGE}::dex_accounts_entry::redeem_from_vault" + assert payload.function_arguments == [TEST_SUBACCOUNT_ADDR, TEST_VAULT_ADDR, 10.0] + + +# =========================================================================== +# Tests for delegate_vault_actions (async) +# =========================================================================== + + +class TestDelegateVaultActions: + async def test_sends_correct_function(self, write_dex: DecibelWriteDex) -> None: + delegate_addr = "0x" + "dd" * 32 + await write_dex.delegate_vault_actions( + vault_address=TEST_VAULT_ADDR, + account_to_delegate_to=delegate_addr, + ) + + payload: InputEntryFunctionData = write_dex._send_tx.call_args.args[0] + assert payload.function == f"{TEST_PACKAGE}::vault_admin_api::delegate_dex_actions_to" + assert payload.function_arguments == [TEST_VAULT_ADDR, delegate_addr, None] + + +# =========================================================================== +# Tests for approve_max_builder_fee (async) +# =========================================================================== + + +class TestApproveMaxBuilderFee: + async def test_sends_correct_function(self, write_dex: DecibelWriteDex) -> None: + builder_addr = "0x" + "ee" * 32 + with patch("decibel.write.get_primary_subaccount_addr", return_value=TEST_SUBACCOUNT_ADDR): + await write_dex.approve_max_builder_fee(builder_addr=builder_addr, max_fee=1000) + + payload: InputEntryFunctionData = write_dex._send_tx.call_args.args[0] + assert ( + payload.function + == f"{TEST_PACKAGE}::dex_accounts_entry::approve_max_builder_fee_for_subaccount" + ) + assert payload.function_arguments == [TEST_SUBACCOUNT_ADDR, builder_addr, 1000] + + +# =========================================================================== +# Tests for revoke_max_builder_fee (async) +# =========================================================================== + + +class TestRevokeMaxBuilderFee: + async def test_sends_correct_function(self, write_dex: DecibelWriteDex) -> None: + builder_addr = "0x" + "ee" * 32 + with patch("decibel.write.get_primary_subaccount_addr", return_value=TEST_SUBACCOUNT_ADDR): + await write_dex.revoke_max_builder_fee(builder_addr=builder_addr) + + payload: InputEntryFunctionData = write_dex._send_tx.call_args.args[0] + assert ( + payload.function + == f"{TEST_PACKAGE}::dex_accounts_entry::revoke_max_builder_fee_for_subaccount" + ) + assert payload.function_arguments == [TEST_SUBACCOUNT_ADDR, builder_addr] + + +# =========================================================================== +# Tests for create_vault (async) +# =========================================================================== + + +class TestCreateVault: + async def test_sends_correct_function_with_explicit_subaccount( + self, write_dex: DecibelWriteDex + ) -> None: + args = { + "vault_name": "My Vault", + "vault_description": "A test vault", + "vault_social_links": [], + "vault_share_symbol": "MVT", + "fee_bps": 100, + "fee_interval_s": 86400, + "contribution_lockup_duration_s": 604800, + "initial_funding": 1000, + "accepts_contributions": True, + "delegate_to_creator": False, + } + with patch("decibel.write.get_primary_subaccount_addr", return_value=TEST_SUBACCOUNT_ADDR): + await write_dex.create_vault( + args, + subaccount_addr=TEST_SUBACCOUNT_ADDR, # type: ignore[arg-type] + ) + + payload: InputEntryFunctionData = write_dex._send_tx.call_args.args[0] + assert payload.function == f"{TEST_PACKAGE}::vault_api::create_and_fund_vault" + + +# =========================================================================== +# Tests for DecibelWriteDexSync (representative subset) +# =========================================================================== + + +class TestDecibelWriteDexSyncInit: + def test_init_creates_order_status_client(self, test_config, mock_account) -> None: + with ( + patch("decibel.write.BaseSDKSync.__init__", return_value=None), + patch("decibel.write.OrderStatusClient") as mock_osc, + ): + dex = DecibelWriteDexSync.__new__(DecibelWriteDexSync) + dex._http_client = MagicMock() + dex._config = test_config + dex._account = mock_account + dex._order_status_client = mock_osc.return_value + assert dex.order_status_client is mock_osc.return_value + + +class TestDecibelWriteDexSyncExtractOrderId: + def test_extracts_string_order_id(self, write_dex_sync: DecibelWriteDexSync) -> None: + tx = _make_tx_response(order_id="42", user_addr=TEST_ACCOUNT_ADDR) + result = write_dex_sync._extract_order_id_from_transaction(tx) + assert result == "42" + + def test_returns_none_when_no_events(self, write_dex_sync: DecibelWriteDexSync) -> None: + tx: dict[str, Any] = {"hash": TEST_TX_HASH} + result = write_dex_sync._extract_order_id_from_transaction(tx) + assert result is None + + def test_handles_twap_event(self, write_dex_sync: DecibelWriteDexSync) -> None: + tx = _make_twap_tx_response(order_id="77") + result = write_dex_sync._extract_order_id_from_transaction(tx) + assert result == "77" + + def test_handles_exception_gracefully(self, write_dex_sync: DecibelWriteDexSync) -> None: + result = write_dex_sync._extract_order_id_from_transaction({"events": "bad"}) # type: ignore[arg-type] + assert result is None + + +class TestDecibelWriteDexSyncSendSubaccountTx: + def test_uses_primary_subaccount_when_none(self, write_dex_sync: DecibelWriteDexSync) -> None: + called_with: list[str] = [] + + def fake_tx(addr: str) -> dict[str, Any]: + called_with.append(addr) + return {"hash": TEST_TX_HASH} + + with patch("decibel.write.get_primary_subaccount_addr", return_value=TEST_SUBACCOUNT_ADDR): + result = write_dex_sync.send_subaccount_tx(fake_tx) + + assert called_with == [TEST_SUBACCOUNT_ADDR] + assert result == {"hash": TEST_TX_HASH} + + def test_uses_provided_subaccount_addr(self, write_dex_sync: DecibelWriteDexSync) -> None: + called_with: list[str] = [] + + def fake_tx(addr: str) -> dict[str, Any]: + called_with.append(addr) + return {"hash": TEST_TX_HASH} + + write_dex_sync.send_subaccount_tx(fake_tx, subaccount_addr=TEST_SUBACCOUNT_ADDR) + assert called_with == [TEST_SUBACCOUNT_ADDR] + + +class TestDecibelWriteDexSyncRenameSubaccount: + def test_posts_to_correct_url(self, write_dex_sync: DecibelWriteDexSync) -> None: + mock_result = (MagicMock(), 200, "OK") + args = RenameSubaccountArgs(subaccountAddress=TEST_SUBACCOUNT_ADDR, newName="New Name") + + with patch("decibel.write.post_request_sync", return_value=mock_result) as mock_post: + write_dex_sync.rename_subaccount(args) + + expected_url = ( + f"{write_dex_sync._config.trading_http_url}/api/v1/subaccounts/{TEST_SUBACCOUNT_ADDR}" + ) + mock_post.assert_called_once() + call_kwargs = mock_post.call_args + assert call_kwargs.args[1] == expected_url + assert call_kwargs.kwargs["body"] == {"name": "New Name"} + + +class TestDecibelWriteDexSyncCreateSubaccount: + def test_sends_correct_function(self, write_dex_sync: DecibelWriteDexSync) -> None: + result = write_dex_sync.create_subaccount() + + write_dex_sync._send_tx.assert_called_once() + payload: InputEntryFunctionData = write_dex_sync._send_tx.call_args.args[0] + assert payload.function == f"{TEST_PACKAGE}::dex_accounts_entry::create_new_subaccount" + assert result == _make_tx_response() + + +class TestDecibelWriteDexSyncDeposit: + def test_deposit_uses_primary_subaccount_by_default( + self, write_dex_sync: DecibelWriteDexSync + ) -> None: + with patch("decibel.write.get_primary_subaccount_addr", return_value=TEST_SUBACCOUNT_ADDR): + write_dex_sync.deposit(amount=100) + + payload: InputEntryFunctionData = write_dex_sync._send_tx.call_args.args[0] + assert payload.function == f"{TEST_PACKAGE}::dex_accounts_entry::deposit_to_subaccount_at" + assert payload.function_arguments == [TEST_SUBACCOUNT_ADDR, TEST_USDC, 100] + + def test_deposit_uses_explicit_subaccount(self, write_dex_sync: DecibelWriteDexSync) -> None: + write_dex_sync.deposit(amount=250, subaccount_addr=TEST_SUBACCOUNT_ADDR) + payload: InputEntryFunctionData = write_dex_sync._send_tx.call_args.args[0] + assert payload.function_arguments[0] == TEST_SUBACCOUNT_ADDR + assert payload.function_arguments[2] == 250 + + +class TestDecibelWriteDexSyncWithdraw: + def test_withdraw_uses_correct_function(self, write_dex_sync: DecibelWriteDexSync) -> None: + with patch("decibel.write.get_primary_subaccount_addr", return_value=TEST_SUBACCOUNT_ADDR): + write_dex_sync.withdraw(amount=300) + + payload: InputEntryFunctionData = write_dex_sync._send_tx.call_args.args[0] + assert payload.function == f"{TEST_PACKAGE}::dex_accounts_entry::withdraw_from_subaccount" + assert payload.function_arguments == [TEST_SUBACCOUNT_ADDR, TEST_USDC, 300] + + +class TestDecibelWriteDexSyncPlaceOrder: + def test_place_order_success(self, write_dex_sync: DecibelWriteDexSync) -> None: + write_dex_sync._send_tx.return_value = _make_tx_response(order_id="456") + with ( + patch("decibel.write.get_market_addr", return_value="0x" + "11" * 32), + patch("decibel.write.get_primary_subaccount_addr", return_value=TEST_SUBACCOUNT_ADDR), + ): + result = write_dex_sync.place_order( + market_name=TEST_MARKET_NAME, + price=50000, + size=1, + is_buy=True, + time_in_force=TimeInForce.GoodTillCanceled, + is_reduce_only=False, + ) + + assert isinstance(result, PlaceOrderSuccess) + assert result.order_id == "456" + assert result.transaction_hash == TEST_TX_HASH + + def test_place_order_failure_on_exception(self, write_dex_sync: DecibelWriteDexSync) -> None: + write_dex_sync._send_tx.side_effect = RuntimeError("Sync error") + with ( + patch("decibel.write.get_market_addr", return_value="0x" + "11" * 32), + patch("decibel.write.get_primary_subaccount_addr", return_value=TEST_SUBACCOUNT_ADDR), + ): + result = write_dex_sync.place_order( + market_name=TEST_MARKET_NAME, + price=100, + size=0.5, + is_buy=False, + time_in_force=TimeInForce.ImmediateOrCancel, + is_reduce_only=True, + ) + + assert isinstance(result, PlaceOrderFailure) + assert "RuntimeError" in result.error + + def test_place_order_with_tick_size(self, write_dex_sync: DecibelWriteDexSync) -> None: + with ( + patch("decibel.write.get_market_addr", return_value="0x" + "11" * 32), + patch("decibel.write.get_primary_subaccount_addr", return_value=TEST_SUBACCOUNT_ADDR), + ): + write_dex_sync.place_order( + market_name=TEST_MARKET_NAME, + price=50007, + size=1, + is_buy=True, + time_in_force=TimeInForce.GoodTillCanceled, + is_reduce_only=False, + tick_size=10, + ) + + payload: InputEntryFunctionData = write_dex_sync._send_tx.call_args.args[0] + assert payload.function_arguments[2] == 50010 # rounded to nearest 10 + + +class TestDecibelWriteDexSyncCancelOrder: + def test_cancel_order_raises_without_market(self, write_dex_sync: DecibelWriteDexSync) -> None: + with pytest.raises(ValueError, match="Either market_name or market_addr must be provided"): + write_dex_sync.cancel_order(order_id=1) + + def test_cancel_order_by_market_name(self, write_dex_sync: DecibelWriteDexSync) -> None: + market_addr = "0x" + "11" * 32 + with ( + patch("decibel.write.get_market_addr", return_value=market_addr), + patch("decibel.write.get_primary_subaccount_addr", return_value=TEST_SUBACCOUNT_ADDR), + ): + write_dex_sync.cancel_order(order_id=10, market_name=TEST_MARKET_NAME) + + payload: InputEntryFunctionData = write_dex_sync._send_tx.call_args.args[0] + assert payload.function == f"{TEST_PACKAGE}::dex_accounts_entry::cancel_order_to_subaccount" + + def test_cancel_order_by_market_addr(self, write_dex_sync: DecibelWriteDexSync) -> None: + market_addr = "0x" + "11" * 32 + with patch("decibel.write.get_primary_subaccount_addr", return_value=TEST_SUBACCOUNT_ADDR): + write_dex_sync.cancel_order(order_id="20", market_addr=market_addr) + + payload: InputEntryFunctionData = write_dex_sync._send_tx.call_args.args[0] + assert payload.function_arguments[1] == 20 # int conversion + + +class TestDecibelWriteDexSyncPlaceBulkOrders: + def test_success(self, write_dex_sync: DecibelWriteDexSync) -> None: + market_addr = "0x" + "11" * 32 + with ( + patch("decibel.write.get_market_addr", return_value=market_addr), + patch("decibel.write.get_primary_subaccount_addr", return_value=TEST_SUBACCOUNT_ADDR), + ): + result = write_dex_sync.place_bulk_orders( + market_name=TEST_MARKET_NAME, + sequence_number=1, + bid_prices=[100], + bid_sizes=[10], + ask_prices=[101], + ask_sizes=[10], + ) + + assert isinstance(result, PlaceBulkOrdersSuccess) + + def test_failure_on_exception(self, write_dex_sync: DecibelWriteDexSync) -> None: + write_dex_sync._send_tx.side_effect = RuntimeError("Failure") + with ( + patch("decibel.write.get_market_addr", return_value="0x" + "11" * 32), + patch("decibel.write.get_primary_subaccount_addr", return_value=TEST_SUBACCOUNT_ADDR), + ): + result = write_dex_sync.place_bulk_orders( + market_name=TEST_MARKET_NAME, + sequence_number=1, + bid_prices=[100], + bid_sizes=[10], + ask_prices=[101], + ask_sizes=[10], + ) + + assert isinstance(result, PlaceBulkOrdersFailure) + + +class TestDecibelWriteDexSyncTwapOrder: + def test_place_twap_success(self, write_dex_sync: DecibelWriteDexSync) -> None: + write_dex_sync._send_tx.return_value = _make_twap_tx_response(order_id="11") + with ( + patch("decibel.write.get_market_addr", return_value="0x" + "11" * 32), + patch("decibel.write.get_primary_subaccount_addr", return_value=TEST_SUBACCOUNT_ADDR), + ): + result = write_dex_sync.place_twap_order( + market_name=TEST_MARKET_NAME, + size=1, + is_buy=True, + is_reduce_only=False, + twap_frequency_seconds=60, + twap_duration_seconds=3600, + ) + + assert isinstance(result, PlaceOrderSuccess) + assert result.order_id == "11" + + def test_cancel_twap_order(self, write_dex_sync: DecibelWriteDexSync) -> None: + market_addr = "0x" + "11" * 32 + with patch("decibel.write.get_primary_subaccount_addr", return_value=TEST_SUBACCOUNT_ADDR): + write_dex_sync.cancel_twap_order(order_id="55", market_addr=market_addr) + + payload: InputEntryFunctionData = write_dex_sync._send_tx.call_args.args[0] + assert ( + payload.function + == f"{TEST_PACKAGE}::dex_accounts_entry::cancel_twap_orders_to_subaccount" + ) + assert payload.function_arguments[2] == 55 # int conversion + + +class TestDecibelWriteDexSyncVaultOperations: + def test_activate_vault(self, write_dex_sync: DecibelWriteDexSync) -> None: + write_dex_sync.activate_vault(vault_address=TEST_VAULT_ADDR) + payload: InputEntryFunctionData = write_dex_sync._send_tx.call_args.args[0] + assert payload.function == f"{TEST_PACKAGE}::vault_api::activate_vault" + assert payload.function_arguments == [TEST_VAULT_ADDR] + + def test_deposit_to_vault(self, write_dex_sync: DecibelWriteDexSync) -> None: + write_dex_sync.deposit_to_vault( + vault_address=TEST_VAULT_ADDR, + amount=100.0, + subaccount_addr=TEST_SUBACCOUNT_ADDR, + ) + payload: InputEntryFunctionData = write_dex_sync._send_tx.call_args.args[0] + assert payload.function == f"{TEST_PACKAGE}::dex_accounts_entry::contribute_to_vault" + assert payload.function_arguments == [ + TEST_SUBACCOUNT_ADDR, + TEST_VAULT_ADDR, + TEST_USDC, + 100.0, + ] + + def test_withdraw_from_vault(self, write_dex_sync: DecibelWriteDexSync) -> None: + with patch("decibel.write.get_primary_subaccount_addr", return_value=TEST_SUBACCOUNT_ADDR): + write_dex_sync.withdraw_from_vault(vault_address=TEST_VAULT_ADDR, shares=5.0) + + payload: InputEntryFunctionData = write_dex_sync._send_tx.call_args.args[0] + assert payload.function == f"{TEST_PACKAGE}::dex_accounts_entry::redeem_from_vault" + assert payload.function_arguments == [TEST_SUBACCOUNT_ADDR, TEST_VAULT_ADDR, 5.0] + + def test_delegate_vault_actions(self, write_dex_sync: DecibelWriteDexSync) -> None: + delegate_addr = "0x" + "dd" * 32 + write_dex_sync.delegate_vault_actions( + vault_address=TEST_VAULT_ADDR, + account_to_delegate_to=delegate_addr, + ) + payload: InputEntryFunctionData = write_dex_sync._send_tx.call_args.args[0] + assert payload.function == f"{TEST_PACKAGE}::vault_admin_api::delegate_dex_actions_to" + assert payload.function_arguments == [TEST_VAULT_ADDR, delegate_addr, None] + + +class TestDecibelWriteDexSyncApproveRevokeBuilderFee: + def test_approve_max_builder_fee(self, write_dex_sync: DecibelWriteDexSync) -> None: + builder_addr = "0x" + "ee" * 32 + with patch("decibel.write.get_primary_subaccount_addr", return_value=TEST_SUBACCOUNT_ADDR): + write_dex_sync.approve_max_builder_fee(builder_addr=builder_addr, max_fee=500) + + payload: InputEntryFunctionData = write_dex_sync._send_tx.call_args.args[0] + assert ( + payload.function + == f"{TEST_PACKAGE}::dex_accounts_entry::approve_max_builder_fee_for_subaccount" + ) + assert payload.function_arguments == [TEST_SUBACCOUNT_ADDR, builder_addr, 500] + + def test_revoke_max_builder_fee(self, write_dex_sync: DecibelWriteDexSync) -> None: + builder_addr = "0x" + "ee" * 32 + with patch("decibel.write.get_primary_subaccount_addr", return_value=TEST_SUBACCOUNT_ADDR): + write_dex_sync.revoke_max_builder_fee(builder_addr=builder_addr) + + payload: InputEntryFunctionData = write_dex_sync._send_tx.call_args.args[0] + assert ( + payload.function + == f"{TEST_PACKAGE}::dex_accounts_entry::revoke_max_builder_fee_for_subaccount" + ) + assert payload.function_arguments == [TEST_SUBACCOUNT_ADDR, builder_addr] + + +class TestDecibelWriteDexSyncDelegation: + def test_delegate_trading_to_for_subaccount(self, write_dex_sync: DecibelWriteDexSync) -> None: + delegate_addr = "0x" + "dd" * 32 + write_dex_sync.delegate_trading_to_for_subaccount( + subaccount_addr=TEST_SUBACCOUNT_ADDR, + account_to_delegate_to=delegate_addr, + expiration_timestamp_secs=12345, + ) + payload: InputEntryFunctionData = write_dex_sync._send_tx.call_args.args[0] + assert ( + payload.function + == f"{TEST_PACKAGE}::dex_accounts_entry::delegate_trading_to_for_subaccount" + ) + assert payload.function_arguments == [ + TEST_SUBACCOUNT_ADDR, + delegate_addr, + 12345, + ] + + def test_revoke_delegation(self, write_dex_sync: DecibelWriteDexSync) -> None: + revoking_addr = "0x" + "ff" * 32 + with patch("decibel.write.get_primary_subaccount_addr", return_value=TEST_SUBACCOUNT_ADDR): + write_dex_sync.revoke_delegation(account_to_revoke=revoking_addr) + + payload: InputEntryFunctionData = write_dex_sync._send_tx.call_args.args[0] + assert payload.function == f"{TEST_PACKAGE}::dex_accounts_entry::revoke_delegation" + assert payload.function_arguments == [TEST_SUBACCOUNT_ADDR, revoking_addr] + + +class TestDecibelWriteDexSyncTpSlOrders: + def test_place_tp_sl_order_for_position(self, write_dex_sync: DecibelWriteDexSync) -> None: + market_addr = "0x" + "11" * 32 + with patch("decibel.write.get_primary_subaccount_addr", return_value=TEST_SUBACCOUNT_ADDR): + write_dex_sync.place_tp_sl_order_for_position( + market_addr=market_addr, + tp_trigger_price=55000, + sl_trigger_price=45000, + tick_size=100, + ) + + payload: InputEntryFunctionData = write_dex_sync._send_tx.call_args.args[0] + assert ( + payload.function + == f"{TEST_PACKAGE}::dex_accounts_entry::place_tp_sl_order_for_position" + ) + # Check trailing Nones + assert payload.function_arguments[-2] is None + assert payload.function_arguments[-1] is None + + def test_update_tp_order_for_position(self, write_dex_sync: DecibelWriteDexSync) -> None: + market_addr = "0x" + "11" * 32 + with patch("decibel.write.get_primary_subaccount_addr", return_value=TEST_SUBACCOUNT_ADDR): + write_dex_sync.update_tp_order_for_position( + market_addr=market_addr, + prev_order_id="88", + tp_trigger_price=60000.0, + ) + + payload: InputEntryFunctionData = write_dex_sync._send_tx.call_args.args[0] + assert ( + payload.function == f"{TEST_PACKAGE}::dex_accounts_entry::update_tp_order_for_position" + ) + assert payload.function_arguments[1] == 88 # int conversion + + def test_update_sl_order_for_position(self, write_dex_sync: DecibelWriteDexSync) -> None: + market_addr = "0x" + "11" * 32 + with patch("decibel.write.get_primary_subaccount_addr", return_value=TEST_SUBACCOUNT_ADDR): + write_dex_sync.update_sl_order_for_position( + market_addr=market_addr, + prev_order_id=77, + sl_trigger_price=40000.0, + ) + + payload: InputEntryFunctionData = write_dex_sync._send_tx.call_args.args[0] + assert ( + payload.function == f"{TEST_PACKAGE}::dex_accounts_entry::update_sl_order_for_position" + ) + + def test_cancel_tp_sl_order_for_position(self, write_dex_sync: DecibelWriteDexSync) -> None: + market_addr = "0x" + "11" * 32 + with patch("decibel.write.get_primary_subaccount_addr", return_value=TEST_SUBACCOUNT_ADDR): + write_dex_sync.cancel_tp_sl_order_for_position(market_addr=market_addr, order_id="33") + + payload: InputEntryFunctionData = write_dex_sync._send_tx.call_args.args[0] + assert payload.function_arguments == [TEST_SUBACCOUNT_ADDR, market_addr, 33] + + +class TestDecibelWriteDexSyncDeactivateSubaccount: + def test_sends_correct_function(self, write_dex_sync: DecibelWriteDexSync) -> None: + write_dex_sync.deactivate_subaccount(subaccount_addr=TEST_SUBACCOUNT_ADDR) + payload: InputEntryFunctionData = write_dex_sync._send_tx.call_args.args[0] + assert payload.function == f"{TEST_PACKAGE}::dex_accounts_entry::deactivate_subaccount" + assert payload.function_arguments == [TEST_SUBACCOUNT_ADDR, True] + + def test_revoke_all_delegations_false(self, write_dex_sync: DecibelWriteDexSync) -> None: + write_dex_sync.deactivate_subaccount( + subaccount_addr=TEST_SUBACCOUNT_ADDR, revoke_all_delegations=False + ) + payload: InputEntryFunctionData = write_dex_sync._send_tx.call_args.args[0] + assert payload.function_arguments[1] is False + + +class TestDecibelWriteDexSyncTriggerMatching: + def test_sends_correct_function(self, write_dex_sync: DecibelWriteDexSync) -> None: + market_addr = "0x" + "11" * 32 + result = write_dex_sync.trigger_matching(market_addr=market_addr, max_work_unit=50) + + payload: InputEntryFunctionData = write_dex_sync._send_tx.call_args.args[0] + assert ( + payload.function == f"{TEST_PACKAGE}::public_apis::process_perp_market_pending_requests" + ) + assert payload.function_arguments == [market_addr, 50] + assert result == {"success": True, "transactionHash": TEST_TX_HASH} + + +class TestDecibelWriteDexSyncCancelBulkOrder: + def test_sends_correct_function(self, write_dex_sync: DecibelWriteDexSync) -> None: + market_addr = "0x" + "11" * 32 + with ( + patch("decibel.write.get_market_addr", return_value=market_addr), + patch("decibel.write.get_primary_subaccount_addr", return_value=TEST_SUBACCOUNT_ADDR), + ): + write_dex_sync.cancel_bulk_order(market_name=TEST_MARKET_NAME) + + payload: InputEntryFunctionData = write_dex_sync._send_tx.call_args.args[0] + assert ( + payload.function + == f"{TEST_PACKAGE}::dex_accounts_entry::cancel_bulk_order_to_subaccount" + ) + + def test_cancel_client_order(self, write_dex_sync: DecibelWriteDexSync) -> None: + market_addr = "0x" + "11" * 32 + with ( + patch("decibel.write.get_market_addr", return_value=market_addr), + patch("decibel.write.get_primary_subaccount_addr", return_value=TEST_SUBACCOUNT_ADDR), + ): + write_dex_sync.cancel_client_order( + client_order_id="cid-1", market_name=TEST_MARKET_NAME + ) + + payload: InputEntryFunctionData = write_dex_sync._send_tx.call_args.args[0] + assert ( + payload.function + == f"{TEST_PACKAGE}::dex_accounts_entry::cancel_client_order_to_subaccount" + ) + assert payload.function_arguments == [TEST_SUBACCOUNT_ADDR, "cid-1", market_addr] + + +class TestDecibelWriteDexSyncConfigureUserSettings: + def test_sends_correct_payload(self, write_dex_sync: DecibelWriteDexSync) -> None: + market_addr = "0x" + "11" * 32 + write_dex_sync.configure_user_settings_for_market( + market_addr=market_addr, + subaccount_addr=TEST_SUBACCOUNT_ADDR, + is_cross=False, + user_leverage=5, + ) + payload: InputEntryFunctionData = write_dex_sync._send_tx.call_args.args[0] + assert ( + payload.function + == f"{TEST_PACKAGE}::dex_accounts_entry::configure_user_settings_for_market" + ) + assert payload.function_arguments == [ + TEST_SUBACCOUNT_ADDR, + market_addr, + False, + 5, + ] + + +class TestDecibelWriteDexSyncWithSubaccount: + def test_with_subaccount_uses_primary_when_none( + self, write_dex_sync: DecibelWriteDexSync + ) -> None: + called_with: list[str] = [] + + def fn(addr: str) -> str: + called_with.append(addr) + return "ok" + + with patch("decibel.write.get_primary_subaccount_addr", return_value=TEST_SUBACCOUNT_ADDR): + result = write_dex_sync.with_subaccount(fn) + + assert called_with == [TEST_SUBACCOUNT_ADDR] + assert result == "ok" + + def test_with_subaccount_uses_provided_addr(self, write_dex_sync: DecibelWriteDexSync) -> None: + called_with: list[str] = [] + + def fn(addr: str) -> str: + called_with.append(addr) + return "done" + + result = write_dex_sync.with_subaccount(fn, subaccount_addr=TEST_SUBACCOUNT_ADDR) + assert called_with == [TEST_SUBACCOUNT_ADDR] + assert result == "done"