Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 24 additions & 34 deletions scapy/asn1/ber.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,11 @@
ASN1_Object,
_ASN1_ERROR,
)
from scapy.libs.codec import GenericCodec_metaclass, GenericCodecObject

from typing import (
Any,
AnyStr,
Dict,
Generic,
List,
Optional,
Tuple,
Expand Down Expand Up @@ -267,43 +266,36 @@ def BER_tagging_enc(s, implicit_tag=None, explicit_tag=None):
# [ BER classes ] #


class BERcodec_metaclass(type):
def __new__(cls,
name, # type: str
bases, # type: Tuple[type, ...]
dct # type: Dict[str, Any]
):
# type: (...) -> Type[BERcodec_Object[Any]]
c = cast('Type[BERcodec_Object[Any]]',
super(BERcodec_metaclass, cls).__new__(cls, name, bases, dct))
try:
c.tag.register(c.codec, c)
except Exception:
warning("Error registering %r for %r" % (c.tag, c.codec))
return c
class BERcodec_metaclass(GenericCodec_metaclass):
"""Metaclass for BER codec objects.

Inherits the tag registration logic from ``GenericCodec_metaclass`` and
adds a BER-specific warning when registration fails.
"""

@classmethod
def _handle_registration_error(cls, c, exc):
# type: (Type[Any], Exception) -> None
warning("Error registering %r for %r" % (c.tag, c.codec))


_K = TypeVar('_K')


class BERcodec_Object(Generic[_K], metaclass=BERcodec_metaclass):
class BERcodec_Object(GenericCodecObject[_K], metaclass=BERcodec_metaclass):
codec = ASN1_Codecs.BER
tag = ASN1_Class_UNIVERSAL.ANY

# Attributes consumed by GenericCodecObject.check_string and .dec
_decoding_error_class = BER_Decoding_Error
_generic_error_classes = (BER_Decoding_Error, ASN1_Error)
_decoding_error_object_class = ASN1_DECODING_ERROR

@classmethod
def asn1_object(cls, val):
# type: (_K) -> ASN1_Object[_K]
return cls.tag.asn1_object(val)

@classmethod
def check_string(cls, s):
# type: (bytes) -> None
if not s:
raise BER_Decoding_Error(
"%s: Got empty object while expecting tag %r" %
(cls.__name__, cls.tag), remaining=s
)

@classmethod
def check_type(cls, s):
# type: (bytes) -> bytes
Expand Down Expand Up @@ -369,6 +361,12 @@ def dec(cls,
safe=False, # type: bool
):
# type: (...) -> Tuple[Union[_ASN1_ERROR, ASN1_Object[_K]], bytes]
# BER overrides dec from GenericCodecObject to add special recovery for
# BER_BadTag_Decoding_Error: instead of wrapping the error in a
# DECODING_ERROR object, it recursively tries to decode from the
# remaining bytes and wraps the result in ASN1_BADTAG.
# Other BER/ASN1 errors are handled inline (same semantics as the
# generic dec, but with BER-specific exception types).
if not safe:
return cls.do_dec(s, context, safe)
try:
Expand All @@ -383,14 +381,6 @@ def dec(cls,
except ASN1_Error as e:
return ASN1_DECODING_ERROR(s, exc=e), b""

@classmethod
def safedec(cls,
s, # type: bytes
context=None, # type: Optional[Type[ASN1_Class]]
):
# type: (...) -> Tuple[Union[_ASN1_ERROR, ASN1_Object[_K]], bytes]
return cls.dec(s, context, safe=True)

@classmethod
def enc(cls, s, size_len=0):
# type: (_K, Optional[int]) -> bytes
Expand Down
134 changes: 22 additions & 112 deletions scapy/asn1fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@
Classes that implement ASN.1 data structures.
"""

import copy

from functools import reduce

from scapy.asn1.asn1 import (
Expand All @@ -30,7 +28,11 @@
BER_tagging_dec,
BER_tagging_enc,
)
from scapy.base_classes import BasePacket
from scapy.libs.codec import (
GenericCodecField,
GenericCodecField_element,
GenericCodecOptionalField,
)
from scapy.volatile import (
GeneralizedTime,
RandChoice,
Expand All @@ -48,7 +50,6 @@
AnyStr,
Callable,
Dict,
Generic,
List,
Optional,
Tuple,
Expand All @@ -67,7 +68,7 @@ class ASN1F_badsequence(Exception):
pass


class ASN1F_element(object):
class ASN1F_element(GenericCodecField_element):
pass


Expand All @@ -79,11 +80,10 @@ class ASN1F_element(object):
_A = TypeVar('_A') # ASN.1 object


class ASN1F_field(ASN1F_element, Generic[_I, _A]):
holds_packets = 0
islist = 0
class ASN1F_field(GenericCodecField[_I, _A], ASN1F_element):
ASN1_tag = ASN1_Class_UNIVERSAL.ANY
context = ASN1_Class_UNIVERSAL # type: Type[ASN1_Class]
_badsequence_error_class = ASN1F_badsequence

def __init__(self,
name, # type: str
Expand Down Expand Up @@ -115,18 +115,6 @@ def __init__(self,
self.network_tag = int(implicit_tag or explicit_tag or self.ASN1_tag)
self.owners = [] # type: List[Type[ASN1_Packet]]

def register_owner(self, cls):
# type: (Type[ASN1_Packet]) -> None
self.owners.append(cls)

def i2repr(self, pkt, x):
# type: (ASN1_Packet, _I) -> str
return repr(x)

def i2h(self, pkt, x):
# type: (ASN1_Packet, _I) -> Any
return x

def m2i(self, pkt, s):
# type: (ASN1_Packet, bytes) -> Tuple[_A, bytes]
"""
Expand Down Expand Up @@ -176,74 +164,10 @@ def i2m(self, pkt, x):
implicit_tag=self.implicit_tag,
explicit_tag=self.explicit_tag)

def any2i(self, pkt, x):
# type: (ASN1_Packet, Any) -> _I
return cast(_I, x)

def extract_packet(self,
cls, # type: Type[ASN1_Packet]
s, # type: bytes
_underlayer=None # type: Optional[ASN1_Packet]
):
# type: (...) -> Tuple[ASN1_Packet, bytes]
try:
c = cls(s, _underlayer=_underlayer)
except ASN1F_badsequence:
c = packet.Raw(s, _underlayer=_underlayer) # type: ignore
cpad = c.getlayer(packet.Raw)
s = b""
if cpad is not None:
s = cpad.load
if cpad.underlayer:
del cpad.underlayer.payload
return c, s

def build(self, pkt):
# type: (ASN1_Packet) -> bytes
return self.i2m(pkt, getattr(pkt, self.name))

def dissect(self, pkt, s):
# type: (ASN1_Packet, bytes) -> bytes
v, s = self.m2i(pkt, s)
self.set_val(pkt, v)
return s

def do_copy(self, x):
# type: (Any) -> Any
if isinstance(x, list):
x = x[:]
for i in range(len(x)):
if isinstance(x[i], BasePacket):
x[i] = x[i].copy()
return x
if hasattr(x, "copy"):
return x.copy()
return x

def set_val(self, pkt, val):
# type: (ASN1_Packet, Any) -> None
setattr(pkt, self.name, val)

def is_empty(self, pkt):
# type: (ASN1_Packet) -> bool
return getattr(pkt, self.name) is None

def get_fields_list(self):
# type: () -> List[ASN1F_field[Any, Any]]
return [self]

def __str__(self):
# type: () -> str
return repr(self)

def randval(self):
# type: () -> RandField[_I]
return cast(RandField[_I], RandInt())

def copy(self):
# type: () -> ASN1F_field[_I, _A]
return copy.copy(self)


############################
# Simple ASN1 Fields #
Expand Down Expand Up @@ -562,7 +486,7 @@ def is_empty(self,
pkt, # type: ASN1_Packet
):
# type: (...) -> bool
return ASN1F_field.is_empty(self, pkt)
return ASN1F_field.is_empty(self, pkt) # type: ignore

def m2i(self,
pkt, # type: ASN1_Packet
Expand Down Expand Up @@ -640,49 +564,35 @@ class ASN1F_TIME_TICKS(ASN1F_INTEGER):
# Complex ASN1 Fields #
#############################

class ASN1F_optional(ASN1F_element):
class ASN1F_optional(GenericCodecOptionalField, ASN1F_element):
"""
ASN.1 field that is optional.
"""
_optional_error_classes = (
ASN1_Error, ASN1F_badsequence, BER_Decoding_Error
)

def __init__(self, field):
# type: (ASN1F_field[Any, Any]) -> None
field.flexible_tag = False
self._field = field

def __getattr__(self, attr):
# type: (str) -> Optional[Any]
return getattr(self._field, attr)

def m2i(self, pkt, s):
# type: (ASN1_Packet, bytes) -> Tuple[Any, bytes]
try:
return self._field.m2i(pkt, s)
except (ASN1_Error, ASN1F_badsequence, BER_Decoding_Error):
except self._optional_error_classes:
# ASN1_Error may be raised by ASN1F_CHOICE
return None, s

def dissect(self, pkt, s):
# type: (ASN1_Packet, bytes) -> bytes
try:
return self._field.dissect(pkt, s)
except (ASN1_Error, ASN1F_badsequence, BER_Decoding_Error):
return cast(bytes, self._field.dissect(pkt, s))
except self._optional_error_classes:
self._field.set_val(pkt, None)
return s

def build(self, pkt):
# type: (ASN1_Packet) -> bytes
if self._field.is_empty(pkt):
return b""
return self._field.build(pkt)

def any2i(self, pkt, x):
# type: (ASN1_Packet, Any) -> Any
return self._field.any2i(pkt, x)

def i2repr(self, pkt, x):
# type: (ASN1_Packet, Any) -> str
return self._field.i2repr(pkt, x)


class ASN1F_omit(ASN1F_field[None, None]):
"""
Expand Down Expand Up @@ -842,7 +752,7 @@ def m2i(self, pkt, s):
cls = self.cls
if not hasattr(cls, "ASN1_root"):
# A normal Packet (!= ASN1)
return self.extract_packet(cls, s, _underlayer=pkt)
return self.extract_packet(cls, s, _underlayer=pkt) # type: ignore
diff_tag, s = BER_tagging_dec(s, hidden_tag=cls.ASN1_root.ASN1_tag, # noqa: E501
implicit_tag=self.implicit_tag,
explicit_tag=self.explicit_tag,
Expand All @@ -855,7 +765,7 @@ def m2i(self, pkt, s):
self.explicit_tag = diff_tag
if not s:
return None, s
return self.extract_packet(cls, s, _underlayer=pkt)
return self.extract_packet(cls, s, _underlayer=pkt) # type: ignore

def i2m(self,
pkt, # type: ASN1_Packet
Expand Down Expand Up @@ -887,7 +797,7 @@ def any2i(self,
# type: (...) -> 'ASN1_Packet'
if hasattr(x, "add_underlayer"):
x.add_underlayer(pkt) # type: ignore
return super(ASN1F_PACKET, self).any2i(pkt, x)
return super(ASN1F_PACKET, self).any2i(pkt, x) # type: ignore

def randval(self): # type: ignore
# type: () -> ASN1_Packet
Expand Down Expand Up @@ -972,7 +882,7 @@ def any2i(self, pkt, x):
value[self.mapping.index(i)] = "1"
x = "".join(value)
x = ASN1_BIT_STRING(x)
return super(ASN1F_FLAGS, self).any2i(pkt, x)
return super(ASN1F_FLAGS, self).any2i(pkt, x) # type: ignore

def get_flags(self, pkt):
# type: (ASN1_Packet) -> List[str]
Expand All @@ -998,7 +908,7 @@ def i2m(self, pkt, val):
# type: (ASN1_Packet, Any) -> bytes
if hasattr(val, "ASN1_root"):
val = ASN1_STRING(bytes(val))
return super(ASN1F_STRING_PacketField, self).i2m(pkt, val)
return super(ASN1F_STRING_PacketField, self).i2m(pkt, val) # type: ignore

def any2i(self, pkt, x):
# type: (ASN1_Packet, Any) -> Any
Expand Down
4 changes: 2 additions & 2 deletions scapy/asn1packet.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@ def self_build(self):
# type: () -> bytes
if self.raw_packet_cache is not None:
return self.raw_packet_cache
return self.ASN1_root.build(self)
return self.ASN1_root.build(self) # type: ignore

def do_dissect(self, x):
# type: (bytes) -> bytes
return self.ASN1_root.dissect(self, x)
return self.ASN1_root.dissect(self, x) # type: ignore
Loading
Loading