Skip to content
Snippets Groups Projects
Commit 5bb03903 authored by Julian Rother's avatar Julian Rother
Browse files

Fixed BER integer encoding and simplified decoding

parent 70be07d3
No related branches found
No related tags found
No related merge requests found
Pipeline #7065 passed
...@@ -41,14 +41,6 @@ def decode_ber(data): ...@@ -41,14 +41,6 @@ def decode_ber(data):
rest = data[index + length:] rest = data[index + length:]
return BERObject((ber_class, ber_constructed, ber_type), ber_content), rest return BERObject((ber_class, ber_constructed, ber_type), ber_content), rest
def decode_ber_integer(data):
if not data:
return 0
value = -1 if data[0] & 0x80 else 0
for octet in data:
value = value << 8 | octet
return value
def encode_ber(obj): def encode_ber(obj):
tag = (obj.tag[0] & 0b11) << 6 | (obj.tag[1] & 1) << 5 | (obj.tag[2] & 0b11111) tag = (obj.tag[0] & 0b11) << 6 | (obj.tag[1] & 1) << 5 | (obj.tag[2] & 0b11111)
length = len(obj.content) length = len(obj.content)
...@@ -60,11 +52,6 @@ def encode_ber(obj): ...@@ -60,11 +52,6 @@ def encode_ber(obj):
length = length >> 8 length = length >> 8
return bytes([tag, 0x80 | len(octets)]) + bytes(reversed(octets)) + obj.content return bytes([tag, 0x80 | len(octets)]) + bytes(reversed(octets)) + obj.content
def encode_ber_integer(value):
if value < 0 or value > 255:
raise NotImplementedError('Encoding of integers greater than 255 is not implemented')
return bytes([value])
class BERType(ABC): class BERType(ABC):
@classmethod @classmethod
@abstractmethod @abstractmethod
...@@ -103,13 +90,19 @@ class Integer(BERType): ...@@ -103,13 +90,19 @@ class Integer(BERType):
obj, rest = decode_ber(data) obj, rest = decode_ber(data)
if obj.tag != cls.BER_TAG: if obj.tag != cls.BER_TAG:
raise ValueError() raise ValueError()
return decode_ber_integer(obj.content), rest return int.from_bytes(obj.content, 'big', signed=True), rest
@classmethod @classmethod
def to_ber(cls, obj): def to_ber(cls, obj):
if not isinstance(obj, int): if not isinstance(obj, int):
raise TypeError() raise TypeError()
return encode_ber(BERObject(cls.BER_TAG, encode_ber_integer(obj))) if obj < 0:
res = obj.to_bytes((8 + (obj + 1).bit_length()) // 8, byteorder='big', signed=True)
else:
res = obj.to_bytes(max(1, (obj.bit_length() + 7) // 8), 'big', signed=False)
if res[0] & 0x80:
res = b'\x00' + res
return encode_ber(BERObject(cls.BER_TAG, res))
class Boolean(BERType): class Boolean(BERType):
BER_TAG = (0, False, 1) BER_TAG = (0, False, 1)
...@@ -117,9 +110,9 @@ class Boolean(BERType): ...@@ -117,9 +110,9 @@ class Boolean(BERType):
@classmethod @classmethod
def from_ber(cls, data): def from_ber(cls, data):
obj, rest = decode_ber(data) obj, rest = decode_ber(data)
if obj.tag != cls.BER_TAG: if obj.tag != cls.BER_TAG or len(obj.content) != 1:
raise ValueError() raise ValueError()
return bool(decode_ber_integer(obj.content)), rest return bool(obj.content[0]), rest
@classmethod @classmethod
def to_ber(cls, obj): def to_ber(cls, obj):
...@@ -269,23 +262,20 @@ def retag(cls, tag): ...@@ -269,23 +262,20 @@ def retag(cls, tag):
BER_TAG = tag BER_TAG = tag
return Overwritten return Overwritten
class Enum(BERType): class Enum(Integer):
BER_TAG = (0, False, 10) BER_TAG = (0, False, 10)
ENUM_TYPE = typing.ClassVar[enum.Enum] ENUM_TYPE = typing.ClassVar[enum.Enum]
@classmethod @classmethod
def from_ber(cls, data): def from_ber(cls, data):
obj, rest = decode_ber(data) value, rest = super().from_ber(data)
if obj.tag != cls.BER_TAG:
raise ValueError()
value = decode_ber_integer(obj.content)
return cls.ENUM_TYPE(value), rest return cls.ENUM_TYPE(value), rest
@classmethod @classmethod
def to_ber(cls, obj): def to_ber(cls, obj):
if not isinstance(obj, cls.ENUM_TYPE): if not isinstance(obj, cls.ENUM_TYPE):
raise TypeError() raise TypeError()
return encode_ber(BERObject(cls.BER_TAG, encode_ber_integer(obj.value))) return super().to_ber(obj.value)
def wrapenum(enumtype): def wrapenum(enumtype):
class WrappedEnum(Enum): class WrappedEnum(Enum):
......
import unittest
import enum
from ldapserver import asn1
class TestOctetString(unittest.TestCase):
def test_from_ber(self):
self.assertEqual(asn1.OctetString.from_ber(b'\x04\x00'), (b'', b''))
self.assertEqual(asn1.OctetString.from_ber(b'\x04\x03foo'), (b'foo', b''))
def test_to_ber(self):
self.assertEqual(asn1.OctetString.to_ber(b''), b'\x04\x00')
self.assertEqual(asn1.OctetString.to_ber(b'foo'), b'\x04\x03foo')
class TestInteger(unittest.TestCase):
def test_from_ber(self):
self.assertEqual(asn1.Integer.from_ber(b'\x02\x01\x00'), (0, b''))
self.assertEqual(asn1.Integer.from_ber(b'\x02\x01\x01'), (1, b''))
self.assertEqual(asn1.Integer.from_ber(b'\x02\x01\x7f'), (127, b''))
self.assertEqual(asn1.Integer.from_ber(b'\x02\x02\x00\x80'), (128, b''))
self.assertEqual(asn1.Integer.from_ber(b'\x02\x02\x01\x00'), (256, b''))
self.assertEqual(asn1.Integer.from_ber(b'\x02\x01\xff'), (-1, b''))
self.assertEqual(asn1.Integer.from_ber(b'\x02\x01\x80'), (-128, b''))
self.assertEqual(asn1.Integer.from_ber(b'\x02\x02\xff\x7f'), (-129, b''))
def test_to_ber(self):
self.assertEqual(asn1.Integer.to_ber(0), b'\x02\x01\x00')
self.assertEqual(asn1.Integer.to_ber(1), b'\x02\x01\x01')
self.assertEqual(asn1.Integer.to_ber(127), b'\x02\x01\x7f')
self.assertEqual(asn1.Integer.to_ber(128), b'\x02\x02\x00\x80')
self.assertEqual(asn1.Integer.to_ber(256), b'\x02\x02\x01\x00')
self.assertEqual(asn1.Integer.to_ber(-1), b'\x02\x01\xff')
self.assertEqual(asn1.Integer.to_ber(-128), b'\x02\x01\x80')
self.assertEqual(asn1.Integer.to_ber(-129), b'\x02\x02\xff\x7f')
class TestBoolean(unittest.TestCase):
def test_from_ber(self):
self.assertEqual(asn1.Boolean.from_ber(b'\x01\x01\xff'), (True, b''))
self.assertEqual(asn1.Boolean.from_ber(b'\x01\x01\x00'), (False, b''))
def test_to_ber(self):
self.assertEqual(asn1.Boolean.to_ber(True), b'\x01\x01\xff')
self.assertEqual(asn1.Boolean.to_ber(False), b'\x01\x01\x00')
class TestEnum(unittest.TestCase):
def test_from_ber(self):
class CustomEnum(enum.Enum):
NULL = 0
ONE = 1
self.assertEqual(asn1.wrapenum(CustomEnum).from_ber(b'\x0a\x01\x00'), (CustomEnum.NULL, b''))
self.assertEqual(asn1.wrapenum(CustomEnum).from_ber(b'\x0a\x01\x01'), (CustomEnum.ONE, b''))
def test_to_ber(self):
class CustomEnum(enum.Enum):
NULL = 0
ONE = 1
self.assertEqual(asn1.wrapenum(CustomEnum).to_ber(CustomEnum.NULL), b'\x0a\x01\x00')
self.assertEqual(asn1.wrapenum(CustomEnum).to_ber(CustomEnum.ONE), b'\x0a\x01\x01')
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment