Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 15 additions & 5 deletions scapy/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,8 @@ class Field(Generic[I, M], metaclass=Field_metaclass):
islist = 0
ismutable = False
holds_packets = 0
_is_conditional = False
_may_end = False

def __init__(self, name, default, fmt="H"):
# type: (str, Any, str) -> None
Expand Down Expand Up @@ -198,7 +200,7 @@ def i2count(self, pkt, x):
def h2i(self, pkt, x):
# type: (Optional[Packet], Any) -> I
"""Convert human value to internal value"""
return cast(I, x)
return x # type: ignore

def i2h(self, pkt, x):
# type: (Optional[Packet], I) -> Any
Expand All @@ -208,16 +210,16 @@ def i2h(self, pkt, x):
def m2i(self, pkt, x):
# type: (Optional[Packet], M) -> I
"""Convert machine value to internal value"""
return cast(I, x)
return x # type: ignore

def i2m(self, pkt, x):
# type: (Optional[Packet], Optional[I]) -> M
"""Convert internal value to machine value"""
if x is None:
return cast(M, 0)
return 0 # type: ignore
elif isinstance(x, str):
return cast(M, bytes_encode(x))
return cast(M, x)
return bytes_encode(x) # type: ignore
return x # type: ignore

def any2i(self, pkt, x):
# type: (Optional[Packet], Any) -> Optional[I]
Expand Down Expand Up @@ -257,6 +259,10 @@ def getfield(self, pkt, s):
first the raw packet string after having removed the extracted field,
second the extracted field itself in internal representation.
"""
# Use unpack_from for plain bytes (avoids temporary slice allocation).
# Fall back to unpack+slice for subclasses that override __getitem__.
if type(s) is bytes:
return s[self.sz:], self.m2i(pkt, self.struct.unpack_from(s)[0])
return s[self.sz:], self.m2i(pkt, self.struct.unpack(s[:self.sz])[0])

def do_copy(self, x):
Expand Down Expand Up @@ -311,6 +317,8 @@ class _FieldContainer(object):
A field that acts as a container for another field
"""
__slots__ = ["fld"]
_is_conditional = False
_may_end = False

def __getattr__(self, attr):
# type: (str) -> Any
Expand Down Expand Up @@ -349,6 +357,7 @@ class MayEnd(_FieldContainer):
to an empty value, else the behavior will be unexpected.
"""
__slots__ = ["fld"]
_may_end = True

def __init__(self, fld):
# type: (Any) -> None
Expand Down Expand Up @@ -380,6 +389,7 @@ def any2i(self, pkt, val):

class ConditionalField(_FieldContainer):
__slots__ = ["fld", "cond"]
_is_conditional = True

def __init__(self,
fld, # type: AnyField
Expand Down
71 changes: 51 additions & 20 deletions scapy/packet.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
Field,
FlagsField,
FlagValue,
MayEnd,
MultiEnumField,
MultipleTypeField,
PadField,
Expand Down Expand Up @@ -155,7 +154,7 @@ def __init__(self,
**fields # type: Any
):
# type: (...) -> None
self.time = time.time() # type: Union[EDecimal, float]
self.time = 0.0 if _internal else time.time() # type: Union[EDecimal, float]
self.sent_time = None # type: Union[EDecimal, float, None]
self.name = (self.__class__.__name__
if self._name is None else
Expand Down Expand Up @@ -352,11 +351,12 @@ def do_init_cached_fields(self, for_dissect_only=False):
cls_name = self.__class__

# Build the fields information
if Packet.class_default_fields.get(cls_name, None) is None:
default_fields = Packet.class_default_fields.get(cls_name)
if default_fields is None:
self.prepare_cached_fields(self.fields_desc)
default_fields = Packet.class_default_fields.get(cls_name)

# Use fields information from cache
default_fields = Packet.class_default_fields.get(cls_name, None)
if default_fields:
self.default_fields = default_fields
self.fieldtype = Packet.class_fieldtype[cls_name]
Expand Down Expand Up @@ -516,13 +516,18 @@ def getfieldval(self, attr):
# type: (str) -> Any
if self.deprecated_fields and attr in self.deprecated_fields:
attr = self._resolve_alias(attr)
if attr in self.fields:
try:
return self.fields[attr]
if attr in self.overloaded_fields:
except KeyError:
pass
try:
return self.overloaded_fields[attr]
if attr in self.default_fields:
except KeyError:
pass
try:
return self.default_fields[attr]
return self.payload.getfieldval(attr)
except KeyError:
return self.payload.getfieldval(attr)

def getfield_and_val(self, attr):
# type: (str) -> Tuple[AnyField, Any]
Expand Down Expand Up @@ -725,17 +730,24 @@ def copy_fields_dict(self, fields):
def _raw_packet_cache_field_value(self, fld, val, copy=False):
# type: (AnyField, Any, bool) -> Optional[Any]
"""Get a value representative of a mutable field to detect changes"""
_cpy = lambda x: fld.do_copy(x) if copy else x # type: Callable[[Any], Any]
if fld.holds_packets:
# avoid copying whole packets (perf: #GH3894)
if fld.islist:
if copy:
return [
(fld.do_copy(x.fields), x.payload.raw_packet_cache)
for x in val
]
return [
(_cpy(x.fields), x.payload.raw_packet_cache) for x in val
(x.fields, x.payload.raw_packet_cache) for x in val
]
else:
return (_cpy(val.fields), val.payload.raw_packet_cache)
if copy:
return (fld.do_copy(val.fields),
val.payload.raw_packet_cache)
return (val.fields, val.payload.raw_packet_cache)
elif fld.islist or fld.ismutable:
return _cpy(val)
return fld.do_copy(val) if copy else val
return None

def clear_cache(self):
Expand Down Expand Up @@ -1081,7 +1093,7 @@ def do_dissect(self, s):
for f in self.fields_desc:
s, fval = f.getfield(self, s)
# Skip unused ConditionalField
if isinstance(f, ConditionalField) and fval is None:
if f._is_conditional and fval is None:
continue
# We need to track fields with mutable values to discard
# .raw_packet_cache when needed.
Expand All @@ -1090,9 +1102,9 @@ def do_dissect(self, s):
self._raw_packet_cache_field_value(f, fval, copy=True)
self.fields[f.name] = fval
# Nothing left to dissect
if not s and (isinstance(f, MayEnd) or
(fval is not None and isinstance(f, ConditionalField) and
isinstance(f.fld, MayEnd))):
if not s and (f._may_end or
(fval is not None and f._is_conditional and
f.fld._may_end)): # type: ignore
break
self.raw_packet_cache = _raw[:-len(s)] if s else _raw
self.explicit = 1
Expand Down Expand Up @@ -1162,10 +1174,29 @@ def guess_payload_class(self, payload):
for t in self.aliastypes:
for fval, cls in t.payload_guess:
try:
if all(v == self.getfieldval(k)
for k, v in fval.items()):
fields = self.fields
overloaded = self.overloaded_fields
default = self.default_fields
deprecated_fields = self.deprecated_fields
matched = True
for k, v in fval.items():
if deprecated_fields:
fv = self.getfieldval(k)
else:
# Inline getfieldval for speed when there are no
# deprecated field aliases to resolve.
if k in fields:
fv = fields[k]
elif k in overloaded:
fv = overloaded[k]
else:
fv = default[k]
if v != fv:
matched = False
break
if matched:
return cls # type: ignore
except AttributeError:
except (AttributeError, KeyError):
pass
return self.default_payload_class(payload)

Expand Down Expand Up @@ -1858,7 +1889,7 @@ def __new__(cls, *args, **kargs):
if singl is None:
cls.__singl__ = singl = Packet.__new__(cls)
Packet.__init__(singl)
return cast(NoPayload, singl)
return singl # type: ignore

def __init__(self, *args, **kargs):
# type: (*Any, **Any) -> None
Expand Down
Loading