diff --git a/ldapserver/asn1.py b/ldapserver/asn1.py index 57ae0e790cf1d45e61332d3359d8bb38546ef0a1..fdbda7b91cd2c7706b1f203a7fcd73ea2df6fc52 100644 --- a/ldapserver/asn1.py +++ b/ldapserver/asn1.py @@ -41,14 +41,6 @@ def decode_ber(data): rest = data[index + length:] return BERObject((ber_class, ber_constructed, ber_type), ber_content), rest -def decode_ber_integer(data): - if not data: - return 0 - value = -1 if data[0] & 0x80 else 0 - for octet in data: - value = value << 8 | octet - return value - def encode_ber(obj): tag = (obj.tag[0] & 0b11) << 6 | (obj.tag[1] & 1) << 5 | (obj.tag[2] & 0b11111) length = len(obj.content) @@ -60,11 +52,6 @@ def encode_ber(obj): length = length >> 8 return bytes([tag, 0x80 | len(octets)]) + bytes(reversed(octets)) + obj.content -def encode_ber_integer(value): - if value < 0 or value > 255: - raise NotImplementedError('Encoding of integers greater than 255 is not implemented') - return bytes([value]) - class BERType(ABC): @classmethod @abstractmethod @@ -103,13 +90,19 @@ class Integer(BERType): obj, rest = decode_ber(data) if obj.tag != cls.BER_TAG: raise ValueError() - return decode_ber_integer(obj.content), rest + return int.from_bytes(obj.content, 'big', signed=True), rest @classmethod def to_ber(cls, obj): if not isinstance(obj, int): raise TypeError() - return encode_ber(BERObject(cls.BER_TAG, encode_ber_integer(obj))) + if obj < 0: + res = obj.to_bytes((8 + (obj + 1).bit_length()) // 8, byteorder='big', signed=True) + else: + res = obj.to_bytes(max(1, (obj.bit_length() + 7) // 8), 'big', signed=False) + if res[0] & 0x80: + res = b'\x00' + res + return encode_ber(BERObject(cls.BER_TAG, res)) class Boolean(BERType): BER_TAG = (0, False, 1) @@ -117,9 +110,9 @@ class Boolean(BERType): @classmethod def from_ber(cls, data): obj, rest = decode_ber(data) - if obj.tag != cls.BER_TAG: + if obj.tag != cls.BER_TAG or len(obj.content) != 1: raise ValueError() - return bool(decode_ber_integer(obj.content)), rest + return bool(obj.content[0]), rest @classmethod def to_ber(cls, obj): @@ -269,23 +262,20 @@ def retag(cls, tag): BER_TAG = tag return Overwritten -class Enum(BERType): +class Enum(Integer): BER_TAG = (0, False, 10) ENUM_TYPE = typing.ClassVar[enum.Enum] @classmethod def from_ber(cls, data): - obj, rest = decode_ber(data) - if obj.tag != cls.BER_TAG: - raise ValueError() - value = decode_ber_integer(obj.content) + value, rest = super().from_ber(data) return cls.ENUM_TYPE(value), rest @classmethod def to_ber(cls, obj): if not isinstance(obj, cls.ENUM_TYPE): raise TypeError() - return encode_ber(BERObject(cls.BER_TAG, encode_ber_integer(obj.value))) + return super().to_ber(obj.value) def wrapenum(enumtype): class WrappedEnum(Enum): diff --git a/tests/test_asn1.py b/tests/test_asn1.py new file mode 100644 index 0000000000000000000000000000000000000000..021929ea8c7df491c73bae18fa4f6ac90c22490f --- /dev/null +++ b/tests/test_asn1.py @@ -0,0 +1,58 @@ +import unittest +import enum + +from ldapserver import asn1 + +class TestOctetString(unittest.TestCase): + def test_from_ber(self): + self.assertEqual(asn1.OctetString.from_ber(b'\x04\x00'), (b'', b'')) + self.assertEqual(asn1.OctetString.from_ber(b'\x04\x03foo'), (b'foo', b'')) + + def test_to_ber(self): + self.assertEqual(asn1.OctetString.to_ber(b''), b'\x04\x00') + self.assertEqual(asn1.OctetString.to_ber(b'foo'), b'\x04\x03foo') + +class TestInteger(unittest.TestCase): + def test_from_ber(self): + self.assertEqual(asn1.Integer.from_ber(b'\x02\x01\x00'), (0, b'')) + self.assertEqual(asn1.Integer.from_ber(b'\x02\x01\x01'), (1, b'')) + self.assertEqual(asn1.Integer.from_ber(b'\x02\x01\x7f'), (127, b'')) + self.assertEqual(asn1.Integer.from_ber(b'\x02\x02\x00\x80'), (128, b'')) + self.assertEqual(asn1.Integer.from_ber(b'\x02\x02\x01\x00'), (256, b'')) + self.assertEqual(asn1.Integer.from_ber(b'\x02\x01\xff'), (-1, b'')) + self.assertEqual(asn1.Integer.from_ber(b'\x02\x01\x80'), (-128, b'')) + self.assertEqual(asn1.Integer.from_ber(b'\x02\x02\xff\x7f'), (-129, b'')) + + def test_to_ber(self): + self.assertEqual(asn1.Integer.to_ber(0), b'\x02\x01\x00') + self.assertEqual(asn1.Integer.to_ber(1), b'\x02\x01\x01') + self.assertEqual(asn1.Integer.to_ber(127), b'\x02\x01\x7f') + self.assertEqual(asn1.Integer.to_ber(128), b'\x02\x02\x00\x80') + self.assertEqual(asn1.Integer.to_ber(256), b'\x02\x02\x01\x00') + self.assertEqual(asn1.Integer.to_ber(-1), b'\x02\x01\xff') + self.assertEqual(asn1.Integer.to_ber(-128), b'\x02\x01\x80') + self.assertEqual(asn1.Integer.to_ber(-129), b'\x02\x02\xff\x7f') + +class TestBoolean(unittest.TestCase): + def test_from_ber(self): + self.assertEqual(asn1.Boolean.from_ber(b'\x01\x01\xff'), (True, b'')) + self.assertEqual(asn1.Boolean.from_ber(b'\x01\x01\x00'), (False, b'')) + + def test_to_ber(self): + self.assertEqual(asn1.Boolean.to_ber(True), b'\x01\x01\xff') + self.assertEqual(asn1.Boolean.to_ber(False), b'\x01\x01\x00') + +class TestEnum(unittest.TestCase): + def test_from_ber(self): + class CustomEnum(enum.Enum): + NULL = 0 + ONE = 1 + self.assertEqual(asn1.wrapenum(CustomEnum).from_ber(b'\x0a\x01\x00'), (CustomEnum.NULL, b'')) + self.assertEqual(asn1.wrapenum(CustomEnum).from_ber(b'\x0a\x01\x01'), (CustomEnum.ONE, b'')) + + def test_to_ber(self): + class CustomEnum(enum.Enum): + NULL = 0 + ONE = 1 + self.assertEqual(asn1.wrapenum(CustomEnum).to_ber(CustomEnum.NULL), b'\x0a\x01\x00') + self.assertEqual(asn1.wrapenum(CustomEnum).to_ber(CustomEnum.ONE), b'\x0a\x01\x01')