diff --git a/src/project_x_py/client/trading.py b/src/project_x_py/client/trading.py index 9761784..4a724e4 100644 --- a/src/project_x_py/client/trading.py +++ b/src/project_x_py/client/trading.py @@ -66,6 +66,7 @@ async def main(): import datetime import logging +from dataclasses import fields from datetime import timedelta from typing import Any @@ -77,6 +78,14 @@ async def main(): logger = logging.getLogger(__name__) +_POSITION_FIELDS = frozenset(field.name for field in fields(Position)) + + +def _position_from_response(data: dict[str, Any]) -> Position: + return Position( + **{field: data[field] for field in _POSITION_FIELDS if field in data} + ) + class TradingMixin: """Mixin class providing trading functionality.""" @@ -182,7 +191,7 @@ async def search_open_positions( else: return [] - return [Position(**pos) for pos in positions_data] + return [_position_from_response(pos) for pos in positions_data] async def search_trades( self, diff --git a/src/project_x_py/models.py b/src/project_x_py/models.py index 05c9897..ec17cb7 100644 --- a/src/project_x_py/models.py +++ b/src/project_x_py/models.py @@ -366,6 +366,7 @@ class Position: 0=UNDEFINED, 1=LONG, 2=SHORT size (int): Position size (number of contracts, always positive) averagePrice (float): Average entry price of the position + contractDisplayName (Optional[str]): Human-readable contract display name Note: This model contains only the fields returned by ProjectX API. @@ -385,11 +386,12 @@ class Position: type: int size: int averagePrice: float + contractDisplayName: str | None = None # Allow dict-like access for compatibility in tests/utilities - def __getitem__(self, key: str) -> Union[int, str, float]: + def __getitem__(self, key: str) -> Union[int, str, float, None]: value = getattr(self, key) - if isinstance(value, int | str | float): + if value is None or isinstance(value, int | str | float): return value else: raise TypeError( diff --git a/src/project_x_py/types/api_responses.py b/src/project_x_py/types/api_responses.py index 0c4d5ec..66422cf 100644 --- a/src/project_x_py/types/api_responses.py +++ b/src/project_x_py/types/api_responses.py @@ -133,6 +133,7 @@ class PositionResponse(TypedDict): type: int # 0=UNDEFINED, 1=LONG, 2=SHORT size: int averagePrice: float + contractDisplayName: NotRequired[str] class TradeResponse(TypedDict): diff --git a/tests/client/test_trading_legacy.py b/tests/client/test_trading_legacy.py index a9815c6..d9f3fc4 100644 --- a/tests/client/test_trading_legacy.py +++ b/tests/client/test_trading_legacy.py @@ -114,6 +114,45 @@ async def test_search_open_positions_success(self, trading_client): assert positions[1].type == 2 # SHORT trading_client._ensure_authenticated.assert_called_once() + @pytest.mark.asyncio + async def test_search_open_positions_preserves_display_name_and_ignores_unknown_fields( + self, trading_client + ): + """Test position search keeps known display fields and ignores unknown ones.""" + trading_client.account_info = Account( + id=12345, + name="Test Account", + balance=10000.0, + canTrade=True, + isVisible=True, + simulated=False, + ) + + mock_response = { + "success": True, + "positions": [ + { + "id": "pos1", + "accountId": 12345, + "contractId": "CON.F.US.MNQ.Z25", + "contractDisplayName": "MNQZ25", + "unknownGatewayField": "ignored", + "creationTimestamp": datetime.datetime.now(pytz.UTC).isoformat(), + "size": 2, + "averagePrice": 21342.25, + "type": 1, + }, + ], + } + trading_client._make_request.return_value = mock_response + + positions = await trading_client.search_open_positions() + + assert len(positions) == 1 + assert positions[0].contractId == "CON.F.US.MNQ.Z25" + assert positions[0].contractDisplayName == "MNQZ25" + assert positions[0].size == 2 + @pytest.mark.asyncio async def test_search_open_positions_with_account_id(self, trading_client): """Test position search with specific account ID.""" diff --git a/tests/types/test_api_responses.py b/tests/types/test_api_responses.py index 12a6231..1f920bb 100644 --- a/tests/types/test_api_responses.py +++ b/tests/types/test_api_responses.py @@ -115,6 +115,7 @@ def test_position_response_structure(self): assert hints["size"] is int assert "averagePrice" in hints assert hints["averagePrice"] is float + assert "contractDisplayName" in hints def test_trade_response_structure(self): """Test TradeResponse has correct fields.""" @@ -267,6 +268,7 @@ def test_real_world_response_creation(self): "id": 67890, "accountId": 12345, "contractId": "CON.F.US.MNQ.U25", + "contractDisplayName": "MNQU25", "creationTimestamp": "2024-01-01T10:00:00Z", "type": 1, # LONG "size": 5, @@ -275,6 +277,7 @@ def test_real_world_response_creation(self): assert position["type"] == 1 assert position["size"] == 5 + assert position["contractDisplayName"] == "MNQU25" def test_market_data_responses(self): """Test market data response structures.""" diff --git a/tests/types/test_models.py b/tests/types/test_models.py index 8457fd3..707cb50 100644 --- a/tests/types/test_models.py +++ b/tests/types/test_models.py @@ -130,6 +130,12 @@ def test_basic_properties_and_indexing(self): assert p.direction == "LONG" assert p["averagePrice"] == pytest.approx(2050.0) assert p.symbol == "MGC" + assert p.contractDisplayName is None + + def test_contract_display_name(self): + p = self.make_position(contractDisplayName="MGCM25") + assert p.contractDisplayName == "MGCM25" + assert p["contractDisplayName"] == "MGCM25" def test_short_position_helpers(self): p = self.make_position(type=2, size=3)