Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 64 additions & 5 deletions src/sortedcontainers/sortedset.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

"""

from collections.abc import MutableSet, Sequence, Set
from collections.abc import Iterable, MutableSet, Sequence, Set
from itertools import chain
from operator import eq, ge, gt, le, lt, ne
from textwrap import dedent
Expand Down Expand Up @@ -470,7 +470,21 @@ def difference(self, *iterables):
diff = self._set.difference(*iterables)
return self._fromset(diff, key=self._key)

__sub__ = difference
def __sub__(self, other):
"""Return the difference of two sets as a new sorted set.

``ss.__sub__(other)`` <==> ``ss - other``

Return :data:`NotImplemented` when `other` is not iterable so that
Python can fall back to `other`'s reflected operator instead of raising.

:param other: `other` iterable
:return: new sorted set

"""
if isinstance(other, Iterable):
return self.difference(other)
return NotImplemented

def difference_update(self, *iterables):
"""Remove all values of `iterables` from this sorted set.
Expand Down Expand Up @@ -524,7 +538,22 @@ def intersection(self, *iterables):
intersect = self._set.intersection(*iterables)
return self._fromset(intersect, key=self._key)

__and__ = intersection
def __and__(self, other):
"""Return the intersection of two sets as a new sorted set.

``ss.__and__(other)`` <==> ``ss & other``

Return :data:`NotImplemented` when `other` is not iterable so that
Python can fall back to `other`'s reflected operator instead of raising.

:param other: `other` iterable
:return: new sorted set

"""
if isinstance(other, Iterable):
return self.intersection(other)
return NotImplemented

__rand__ = __and__

def intersection_update(self, *iterables):
Expand Down Expand Up @@ -575,7 +604,22 @@ def symmetric_difference(self, other):
diff = self._set.symmetric_difference(other)
return self._fromset(diff, key=self._key)

__xor__ = symmetric_difference
def __xor__(self, other):
"""Return the symmetric difference with `other` as a new sorted set.

``ss.__xor__(other)`` <==> ``ss ^ other``

Return :data:`NotImplemented` when `other` is not iterable so that
Python can fall back to `other`'s reflected operator instead of raising.

:param other: `other` iterable
:return: new sorted set

"""
if isinstance(other, Iterable):
return self.symmetric_difference(other)
return NotImplemented

__rxor__ = __xor__

def symmetric_difference_update(self, other):
Expand Down Expand Up @@ -623,7 +667,22 @@ def union(self, *iterables):
"""
return self.__class__(chain(iter(self), *iterables), key=self._key)

__or__ = union
def __or__(self, other):
"""Return new sorted set with values from itself and `other`.

``ss.__or__(other)`` <==> ``ss | other``

Return :data:`NotImplemented` when `other` is not iterable so that
Python can fall back to `other`'s reflected operator instead of raising.

:param other: `other` iterable
:return: new sorted set

"""
if isinstance(other, Iterable):
return self.union(other)
return NotImplemented

__ror__ = __or__

def update(self, *iterables):
Expand Down
33 changes: 33 additions & 0 deletions tests/test_coverage_sortedset.py
Original file line number Diff line number Diff line change
Expand Up @@ -535,3 +535,36 @@ def test_pickle():
beta = pickle.loads(data)
assert alpha == beta
assert alpha._key == beta._key


def test_set_op_returns_notimplemented():
# Set-algebra operators must return NotImplemented for an unsupported
# (non-iterable) operand instead of raising, so the other operand's
# reflected operator can run. See issue #219.
temp = SortedSet(range(10))

assert temp.__and__(5) is NotImplemented
assert temp.__or__(5) is NotImplemented
assert temp.__sub__(5) is NotImplemented
assert temp.__xor__(5) is NotImplemented

# Iterable operands keep working as before.
assert temp & [1, 2, 20] == SortedSet([1, 2])
assert temp - [0, 1] == SortedSet(range(2, 10))
assert temp ^ [0, 10] == SortedSet(list(range(1, 10)) + [10])
assert temp | [10] == SortedSet(range(11))

# A non-iterable operand that defines a reflected operator now wins.
class OnlyReflected:
def __rand__(self, other):
return 'reflected'

assert (temp & OnlyReflected()) == 'reflected'

# A plain non-iterable with no reflected operator still raises TypeError.
try:
temp & 5
except TypeError:
pass
else:
assert False, 'expected TypeError'