# -*- coding: utf-8 -*-
#
# This file is part of Python-ASN1. Python-ASN1 is free software that is
# made available under the MIT license. Consult the file "LICENSE" that is
# distributed together with this file for the exact licensing terms.
#
# Python-ASN1 is copyright (c) 2007-2016 by the Python-ASN1 authors. See the
# file "AUTHORS" for a complete overview.
"""
This module provides an ASN.1 encoder and decoder.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import collections
import re
from builtins import bytes
from builtins import int
from builtins import range
from builtins import str
from enum import IntEnum
from numbers import Number
__version__ = '2.5.0'  # release version of this module
class Numbers(IntEnum):
    """Universal-class ASN.1 tag numbers known to the encoder and decoder.

    Values are the decimal tag numbers, e.g. ``Sequence`` is UNIVERSAL 16.
    """
    Boolean = 1
    Integer = 2
    BitString = 3
    OctetString = 4
    Null = 5
    ObjectIdentifier = 6
    Enumerated = 10
    UTF8String = 12
    Sequence = 16
    Set = 17
    PrintableString = 19
    IA5String = 22
    UTCTime = 23
    UnicodeString = 30  # tag 30; X.680 calls this type BMPString
class Types(IntEnum):
    """Encoding form of a tag: the ``0x20`` bit of the identifier octet."""
    Constructed = 1 << 5
    Primitive = 0
class Classes(IntEnum):
    """Tag class: the two high bits (mask ``0xc0``) of the identifier octet."""
    Universal = 0 << 6
    Application = 1 << 6
    Context = 2 << 6
    Private = 3 << 6
#: A named tuple to represent ASN.1 tags as returned by `Decoder.peek()` and
#: `Decoder.read()`. Fields: ``nr`` (tag number), ``typ`` (``Types`` value)
#: and ``cls`` (``Classes`` value).
Tag = collections.namedtuple('Tag', ['nr', 'typ', 'cls'])
class Error(Exception):
    """ASN.1 encoding or decoding error."""
class Encoder(object):
    """ASN.1 encoder. Uses DER encoding.

    Usage: call `Encoder.start()`, then any combination of `Encoder.write()`
    calls and `Encoder.enter()` / `Encoder.leave()` pairs, and finally
    `Encoder.output()` to obtain the encoded bytes.
    """

    def __init__(self):  # type: () -> None
        """Constructor."""
        # None until start() is called. Afterwards a stack of frames: one for
        # the top level plus one per constructed type currently being built.
        # Each frame is a list of already-encoded byte chunks.
        self.m_stack = None

    def start(self):  # type: () -> None
        """This method instructs the encoder to start encoding a new ASN.1
        output. This method may be called at any time to reset the encoder,
        and resets the current output (if any).
        """
        self.m_stack = [[]]

    def enter(self, nr, cls=None):  # type: (int, int) -> None
        """This method starts the construction of a constructed type.

        Args:
            nr (int): The desired ASN.1 type. Use ``Numbers`` enumeration.

            cls (int): This optional parameter specifies the class
                of the constructed type. The default class to use is the
                universal class. Use ``Classes`` enumeration.

        Returns:
            None

        Raises:
            `Error`: if `Encoder.start()` has not been called.
        """
        if self.m_stack is None:
            raise Error('Encoder not initialized. Call start() first.')
        if cls is None:
            cls = Classes.Universal
        # The tag goes into the enclosing frame right away. The length can
        # only be emitted in leave(), once the content of the new frame is
        # complete.
        self._emit_tag(nr, Types.Constructed, cls)
        self.m_stack.append([])

    def leave(self):  # type: () -> None
        """This method completes the construction of a constructed type and
        writes the encoded representation to the output buffer.

        Raises:
            `Error`: if `Encoder.start()` has not been called, or if there is
            no constructed type (opened with `Encoder.enter()`) to complete.
        """
        if self.m_stack is None:
            raise Error('Encoder not initialized. Call start() first.')
        if len(self.m_stack) == 1:
            raise Error('Tag stack is empty.')
        value = b''.join(self.m_stack[-1])
        del self.m_stack[-1]
        # The tag was already emitted into the parent frame by enter();
        # append the length and the collected content after it.
        self._emit_length(len(value))
        self._emit(value)

    def write(self, value, nr=None, typ=None, cls=None):  # type: (object, int, int, int) -> None
        """This method encodes one ASN.1 tag and writes it to the output buffer.

        Note:
            Normally, ``value`` will be the only parameter to this method.
            In this case Python-ASN1 will autodetect the correct ASN.1 type from
            the type of ``value``, and will output the encoded value based on this
            type.

        Args:
            value (any): The value of the ASN.1 tag to write. Python-ASN1 will
                try to autodetect the correct ASN.1 type from the type of
                ``value``: ``bool`` -> Boolean, ``int`` -> Integer,
                ``str`` -> PrintableString, ``bytes`` -> OctetString and
                ``None`` -> Null.

            nr (int): If the desired ASN.1 type cannot be autodetected or is
                autodetected wrongly, the ``nr`` parameter can be provided to
                specify the ASN.1 type to be used. Use ``Numbers`` enumeration.

            typ (int): This optional parameter can be used to write constructed
                types to the output by setting it to indicate the constructed
                encoding type. In this case, ``value`` must already be valid ASN.1
                encoded data as plain Python bytes. This is not normally how
                constructed types should be encoded though, see `Encoder.enter()`
                and `Encoder.leave()` for the recommended way of doing this.
                Use ``Types`` enumeration.

            cls (int): This parameter can be used to override the class of the
                ``value``. The default class is the universal class.
                Use ``Classes`` enumeration.

        Returns:
            None

        Raises:
            `Error`: if the encoder is not started, if ``nr`` is missing for a
            non-universal class, if the ASN.1 type cannot be autodetected from
            ``value``, or if ``value`` does not fit the requested type.
        """
        if self.m_stack is None:
            raise Error('Encoder not initialized. Call start() first.')
        if typ is None:
            typ = Types.Primitive
        if cls is None:
            cls = Classes.Universal
        if cls != Classes.Universal and nr is None:
            raise Error('Please specify a tag number (nr) when using classes Application, Context or Private')
        if nr is None:
            # bool must be tested before int because bool is a subclass of int.
            if isinstance(value, bool):
                nr = Numbers.Boolean
            elif isinstance(value, int):
                nr = Numbers.Integer
            elif isinstance(value, str):
                nr = Numbers.PrintableString
            elif isinstance(value, bytes):
                nr = Numbers.OctetString
            elif value is None:
                nr = Numbers.Null
            else:
                # Without this, nr stays None and _emit_tag() fails with an
                # unhelpful TypeError.
                raise Error('Cannot determine ASN.1 type for value of type {0}. '
                            'Please specify a tag number (nr).'.format(type(value).__name__))
        value = self._encode_value(cls, nr, value)
        self._emit_tag(nr, typ, cls)
        self._emit_length(len(value))
        self._emit(value)

    def output(self):  # type: () -> bytes
        """This method returns the encoded ASN.1 data as plain Python ``bytes``.
        This method can be called multiple times, also during encoding.
        In the latter case the data that has been encoded so far is
        returned.

        Note:
            It is an error to call this method if the encoder is still
            constructing a constructed type, i.e. if `Encoder.enter()` has been
            called more times than `Encoder.leave()`.

        Returns:
            bytes: The DER encoded ASN.1 data.

        Raises:
            `Error`
        """
        if self.m_stack is None:
            raise Error('Encoder not initialized. Call start() first.')
        if len(self.m_stack) != 1:
            raise Error('Stack is not empty.')
        output = b''.join(self.m_stack[0])
        return output

    def _emit_tag(self, nr, typ, cls):  # type: (int, int, int) -> None
        """Emit a tag, choosing the short or long identifier form by ``nr``."""
        if nr < 31:
            self._emit_tag_short(nr, typ, cls)
        else:
            self._emit_tag_long(nr, typ, cls)

    def _emit_tag_short(self, nr, typ, cls):  # type: (int, int, int) -> None
        """Emit a single-octet tag (tag number < 31)."""
        assert nr < 31
        self._emit(bytes([nr | typ | cls]))

    def _emit_tag_long(self, nr, typ, cls):  # type: (int, int, int) -> None
        """Emit a multi-octet tag (tag number >= 31)."""
        # Low five bits all set (0x1f) announce that the number follows.
        head = bytes([typ | cls | 0x1f])
        self._emit(head)
        # Base-128 digits, most significant first; every digit except the
        # last carries the 0x80 continuation bit.
        values = [(nr & 0x7f)]
        nr >>= 7
        while nr:
            values.append((nr & 0x7f) | 0x80)
            nr >>= 7
        values.reverse()
        for val in values:
            self._emit(bytes([val]))

    def _emit_length(self, length):  # type: (int) -> None
        """Emit length octets."""
        if length < 128:
            self._emit_length_short(length)
        else:
            self._emit_length_long(length)

    def _emit_length_short(self, length):  # type: (int) -> None
        """Emit the short length form (< 128 octets)."""
        assert length < 128
        self._emit(bytes([length]))

    def _emit_length_long(self, length):  # type: (int) -> None
        """Emit the long length form (>= 128 octets)."""
        # Big-endian length octets, preceded by 0x80 | number-of-octets.
        values = []
        while length:
            values.append(length & 0xff)
            length >>= 8
        values.reverse()
        # really for correctness as this should not happen anytime soon
        assert len(values) < 127
        head = bytes([0x80 | len(values)])
        self._emit(head)
        for val in values:
            self._emit(bytes([val]))

    def _emit(self, s):  # type: (bytes) -> None
        """Emit raw bytes into the innermost open frame."""
        assert isinstance(s, bytes)
        self.m_stack[-1].append(s)

    def _encode_value(self, cls, nr, value):  # type: (int, int, any) -> bytes
        """Encode a value according to its universal tag number.

        Values of non-universal classes, and universal types without a
        dedicated encoder, are passed through unchanged and must already be
        bytes.
        """
        if cls != Classes.Universal:
            return value
        if nr in (Numbers.Integer, Numbers.Enumerated):
            return self._encode_integer(value)
        if nr in (Numbers.OctetString, Numbers.PrintableString,
                  Numbers.UTF8String, Numbers.IA5String,
                  Numbers.UnicodeString, Numbers.UTCTime):
            return self._encode_octet_string(value)
        if nr == Numbers.BitString:
            return self._encode_bit_string(value)
        if nr == Numbers.Boolean:
            return self._encode_boolean(value)
        if nr == Numbers.Null:
            return self._encode_null()
        if nr == Numbers.ObjectIdentifier:
            return self._encode_object_identifier(value)
        return value

    @staticmethod
    def _encode_boolean(value):  # type: (bool) -> bytes
        """Encode a boolean (DER: 0xff for true, 0x00 for false)."""
        return value and bytes(b'\xff') or bytes(b'\x00')

    @staticmethod
    def _encode_integer(value):  # type: (int) -> bytes
        """Encode an integer as minimal big-endian two's complement."""
        if value < 0:
            value = -value
            negative = True
            limit = 0x80
        else:
            negative = False
            limit = 0x7f
        # Collect the magnitude little-endian, stopping once the top byte
        # fits the sign bit (0x7f for positives, 0x80 for negatives).
        values = []
        while value > limit:
            values.append(value & 0xff)
            value >>= 8
        values.append(value & 0xff)
        if negative:
            # create two's complement
            for i in range(len(values)):  # Invert bits
                values[i] = 0xff - values[i]
            for i in range(len(values)):  # Add 1
                values[i] += 1
                if values[i] <= 0xff:
                    break
                assert i != len(values) - 1
                values[i] = 0x00
        if negative and values[len(values) - 1] == 0x7f:  # Two's complement corner case
            values.append(0xff)
        values.reverse()
        return bytes(values)

    @staticmethod
    def _encode_octet_string(value):  # type: (object) -> bytes
        """Encode an octetstring (text is encoded as UTF-8).

        Raises:
            `Error`: if ``value`` is neither text nor bytes.
        """
        # Use the primitive encoding
        if isinstance(value, str):
            return value.encode('utf-8')
        if isinstance(value, bytes):
            return value
        # Explicit error rather than an assert, which is stripped under -O.
        raise Error('Expecting str or bytes value.')

    @staticmethod
    def _encode_bit_string(value):  # type: (object) -> bytes
        """Encode a bitstring. Assumes no unused bits.

        Raises:
            `Error`: if ``value`` is not bytes.
        """
        # Use the primitive encoding
        if not isinstance(value, bytes):
            raise Error('Expecting bytes value.')
        # Leading octet: number of unused bits in the final octet (always 0).
        return b'\x00' + value

    @staticmethod
    def _encode_null():  # type: () -> bytes
        """Encode a Null value."""
        return bytes(b'')

    _re_oid = re.compile(r'^[0-9]+(\.[0-9]+)+$')

    def _encode_object_identifier(self, oid):  # type: (str) -> bytes
        """Encode an object identifier given in dotted form, e.g. '1.2.840'.

        Raises:
            `Error`: if ``oid`` is not a dotted string of at least two
            non-negative integers, or if its first two arcs are illegal. The
            first arc must be 0, 1 or 2, and the second arc must be at most
            39 unless the first arc is 2.
        """
        if not self._re_oid.match(oid):
            raise Error('Illegal object identifier')
        cmps = list(map(int, oid.split('.')))
        if cmps[0] > 2 or (cmps[0] < 2 and cmps[1] > 39):
            raise Error('Illegal object identifier')
        # The first two arcs are packed into a single subidentifier.
        cmps = [40 * cmps[0] + cmps[1]] + cmps[2:]
        cmps.reverse()
        result = []
        for cmp_data in cmps:
            # Base-128, least significant group first; after the final reverse
            # every group but the last of each subidentifier carries 0x80.
            result.append(cmp_data & 0x7f)
            while cmp_data > 0x7f:
                cmp_data >>= 7
                result.append(0x80 | (cmp_data & 0x7f))
        result.reverse()
        return bytes(result)
class Decoder(object):
    """ASN.1 decoder. Understands BER (and DER which is a subset).

    Indefinite-length encodings are not supported.
    """

    def __init__(self):  # type: () -> None
        """Constructor."""
        # None until start() is called. Afterwards a stack of [offset, data]
        # frames: one for the whole input plus one per entered constructed type.
        self.m_stack = None
        # Tag already read by peek() but not yet consumed by read()/enter().
        self.m_tag = None

    def start(self, data):  # type: (bytes) -> None
        """This method instructs the decoder to start decoding the ASN.1 input
        ``data``, which must be passed in as plain Python bytes.
        This method may be called at any time to start a new decoding job.
        If this method is called while currently decoding another input, that
        decoding context is discarded.

        Note:
            It is not necessary to specify the encoding because the decoder
            assumes the input is in BER or DER format.

        Args:
            data (bytes): ASN.1 input, in BER or DER format, to be decoded.

        Returns:
            None

        Raises:
            `Error`
        """
        if not isinstance(data, bytes):
            raise Error('Expecting bytes instance.')
        self.m_stack = [[0, bytes(data)]]
        self.m_tag = None

    def peek(self):  # type: () -> Tag
        """This method returns the current ASN.1 tag (i.e. the tag that a
        subsequent `Decoder.read()` call would return) without updating the
        decoding offset. In case no more data is available from the input,
        this method returns ``None`` to signal end-of-file.

        This method is useful if you don't know whether the next tag will be a
        primitive or a constructed tag. Depending on the return value of `peek`,
        you would decide to either issue a `Decoder.read()` in case of a primitive
        type, or an `Decoder.enter()` in case of a constructed type.

        Note:
            Because this method does not advance the current offset in the input,
            calling it multiple times in a row will return the same value for all
            calls.

        Returns:
            `Tag`: The current ASN.1 tag.

        Raises:
            `Error`
        """
        if self.m_stack is None:
            raise Error('No input selected. Call start() first.')
        if self.m_tag is None:
            # Check end-of-input only when no tag is pending: a peeked tag
            # must still be returned even if its octets ended the input.
            if self._end_of_input():
                return None
            self.m_tag = self._read_tag()
        return self.m_tag

    def read(self, tagnr=None):  # type: (Number) -> (Tag, any)
        """This method decodes one ASN.1 tag from the input and returns it as a
        ``(tag, value)`` tuple. ``tag`` is a 3-tuple ``(nr, typ, cls)``,
        while ``value`` is a Python object representing the ASN.1 value.
        The offset in the input is increased so that the next `Decoder.read()`
        call will return the next tag. In case no more data is available from
        the input, this method returns ``None`` to signal end-of-file.

        Args:
            tagnr (int): Optional tag number used to select how the value is
                decoded, instead of the tag number read from the input.

        Returns:
            `Tag`, value: The current ASN.1 tag and its value.

        Raises:
            `Error`
        """
        if self.m_stack is None:
            raise Error('No input selected. Call start() first.')
        tag = self.peek()
        if tag is None:
            return None
        length = self._read_length()
        if tagnr is None:
            tagnr = tag.nr
        value = self._read_value(tag.cls, tagnr, length)
        self.m_tag = None
        return tag, value

    def eof(self):  # type: () -> bool
        """Return True if we are at the end of input.

        Returns:
            bool: True if all input has been decoded, and False otherwise.

        Raises:
            `Error`: if `Decoder.start()` has not been called.
        """
        if self.m_stack is None:
            raise Error('No input selected. Call start() first.')
        return self._end_of_input()

    def enter(self):  # type: () -> None
        """This method enters the constructed type that is at the current
        decoding offset.

        Note:
            It is an error to call `Decoder.enter()` if the to be decoded ASN.1 tag
            is not of a constructed type, or if the input is exhausted.

        Returns:
            None

        Raises:
            `Error`
        """
        if self.m_stack is None:
            raise Error('No input selected. Call start() first.')
        tag = self.peek()
        if tag is None:
            raise Error('Cannot enter: end of input.')
        if tag.typ != Types.Constructed:
            raise Error('Cannot enter a non-constructed tag.')
        length = self._read_length()
        # Consume the whole content from the current frame first, so the
        # parent offset already points past it when leave() pops the frame.
        bytes_data = self._read_bytes(length)
        self.m_stack.append([0, bytes_data])
        self.m_tag = None

    def leave(self):  # type: () -> None
        """This method leaves the last constructed type that was
        `Decoder.enter()`-ed.

        Note:
            It is an error to call `Decoder.leave()` if no constructed type is
            currently entered. Any content of the constructed type that has
            not been read is skipped.

        Returns:
            None

        Raises:
            `Error`
        """
        if self.m_stack is None:
            raise Error('No input selected. Call start() first.')
        if len(self.m_stack) == 1:
            raise Error('Tag stack is empty.')
        del self.m_stack[-1]
        self.m_tag = None

    def _read_tag(self):  # type: () -> Tag
        """Read a tag (identifier octets) from the input."""
        byte = self._read_byte()
        cls = byte & 0xc0
        typ = byte & 0x20
        nr = byte & 0x1f
        if nr == 0x1f:  # Long form of tag encoding
            # Base-128 digits follow; 0x80 marks that more digits follow.
            nr = 0
            while True:
                byte = self._read_byte()
                nr = (nr << 7) | (byte & 0x7f)
                if not byte & 0x80:
                    break
        return Tag(nr=nr, typ=typ, cls=cls)

    def _read_length(self):  # type: () -> int
        """Read a length from the input.

        The short form (single octet < 0x80) and the definite long form
        (``0x80 | n`` followed by ``n`` big-endian length octets) are
        supported.

        Raises:
            `Error`: on the reserved octet 0xff, on the indefinite-length
            form 0x80 (not supported), or on premature end of input.
        """
        byte = self._read_byte()
        if byte & 0x80:
            count = byte & 0x7f
            if count == 0x7f:
                raise Error('ASN1 syntax error')
            if count == 0:
                # 0x80 introduces a BER indefinite length. Treating it as a
                # zero length would silently misparse everything after it.
                raise Error('Indefinite length encoding is not supported.')
            bytes_data = self._read_bytes(count)
            length = 0
            for byte in bytes_data:
                length = (length << 8) | int(byte)
            try:
                length = int(length)
            except OverflowError:
                pass
        else:
            length = byte
        return length

    def _read_value(self, cls, nr, length):  # type: (int, int, int) -> any
        """Read ``length`` octets and decode them according to ``nr``.

        Non-universal classes and universal types without a dedicated decoder
        are returned as raw bytes.
        """
        bytes_data = self._read_bytes(length)
        if cls != Classes.Universal:
            value = bytes_data
        elif nr == Numbers.Boolean:
            value = self._decode_boolean(bytes_data)
        elif nr in (Numbers.Integer, Numbers.Enumerated):
            value = self._decode_integer(bytes_data)
        elif nr == Numbers.OctetString:
            value = self._decode_octet_string(bytes_data)
        elif nr == Numbers.Null:
            value = self._decode_null(bytes_data)
        elif nr == Numbers.ObjectIdentifier:
            value = self._decode_object_identifier(bytes_data)
        elif nr in (Numbers.PrintableString, Numbers.IA5String, Numbers.UTF8String, Numbers.UTCTime):
            value = self._decode_printable_string(bytes_data)
        elif nr == Numbers.BitString:
            value = self._decode_bitstring(bytes_data)
        else:
            value = bytes_data
        return value

    def _read_byte(self):  # type: () -> int
        """Return the next input byte, or raise an error on end-of-input."""
        index, input_data = self.m_stack[-1]
        try:
            byte = input_data[index]
        except IndexError:
            raise Error('Premature end of input.')
        self.m_stack[-1][0] += 1
        return byte

    def _read_bytes(self, count):  # type: (int) -> bytes
        """Return the next ``count`` bytes of input. Raise error on
        end-of-input."""
        index, input_data = self.m_stack[-1]
        bytes_data = input_data[index:index + count]
        if len(bytes_data) != count:
            raise Error('Premature end of input.')
        self.m_stack[-1][0] += count
        return bytes_data

    def _end_of_input(self):  # type: () -> bool
        """Return True if the innermost frame is fully consumed."""
        index, input_data = self.m_stack[-1]
        assert not index > len(input_data)
        return index == len(input_data)

    @staticmethod
    def _decode_boolean(bytes_data):  # type: (bytes) -> bool
        """Decode a boolean value (any non-zero octet is true)."""
        if len(bytes_data) != 1:
            raise Error('ASN1 syntax error')
        if bytes_data[0] == 0:
            return False
        return True

    @staticmethod
    def _decode_integer(bytes_data):  # type: (bytes) -> int
        """Decode a big-endian two's complement integer value.

        Raises:
            `Error`: if the content is empty or not minimally encoded.
        """
        if len(bytes_data) == 0:
            raise Error('ASN1 syntax error')
        values = [int(b) for b in bytes_data]
        # check if the integer is normalized
        if len(values) > 1 and (values[0] == 0xff and values[1] & 0x80 or values[0] == 0x00 and not (values[1] & 0x80)):
            raise Error('ASN1 syntax error')
        value = 0
        for val in values:
            value = (value << 8) | val
        if values[0] & 0x80:
            # Sign bit set: interpret as two's complement of width 8 * len.
            value -= 1 << (8 * len(values))
        try:
            value = int(value)
        except OverflowError:
            pass
        return value

    @staticmethod
    def _decode_octet_string(bytes_data):  # type: (bytes) -> bytes
        """Decode an octet string."""
        return bytes_data

    @staticmethod
    def _decode_null(bytes_data):  # type: (bytes) -> any
        """Decode a Null value."""
        if len(bytes_data) != 0:
            raise Error('ASN1 syntax error')
        return None

    @staticmethod
    def _decode_object_identifier(bytes_data):  # type: (bytes) -> str
        """Decode an object identifier into its dotted string form.

        The first subidentifier packs the first two arcs. Values below 80
        encode first arc 0 or 1 as ``40 * first + second``. Values of 80 and
        above always belong to first arc 2, with second arc ``value - 80``.

        Raises:
            `Error`: if the content is empty, contains a 0x80-padded
            subidentifier, or ends in the middle of a subidentifier.
        """
        result = []
        value = 0
        for i in range(len(bytes_data)):
            byte = int(bytes_data[i])
            # A subidentifier must not start with a 0x80 padding octet.
            if value == 0 and byte == 0x80:
                raise Error('ASN1 syntax error')
            value = (value << 7) | (byte & 0x7f)
            if not byte & 0x80:
                result.append(value)
                value = 0
        # The last octet must terminate a subidentifier (high bit clear).
        if len(result) == 0 or int(bytes_data[-1]) & 0x80:
            raise Error('ASN1 syntax error')
        first = result[0]
        if first < 80:
            arcs = [first // 40, first % 40]
        else:
            arcs = [2, first - 80]
        result = arcs + result[1:]
        result = list(map(str, result))
        return str('.'.join(result))

    @staticmethod
    def _decode_printable_string(bytes_data):  # type: (bytes) -> str
        """Decode a printable string."""
        return bytes_data.decode('utf-8')

    @staticmethod
    def _decode_bitstring(bytes_data):  # type: (bytes) -> bytes
        """Decode a bitstring.

        The first content octet gives the number of unused trailing bits
        (0-7). The remaining octets are returned shifted right by that many
        bits, i.e. as the big-endian bytes of the bit string's value.

        Raises:
            `Error`: if the content is empty, the unused-bit count is out of
            range, or unused bits are declared without any content octets.
        """
        if len(bytes_data) == 0:
            raise Error('ASN1 syntax error')
        num_unused_bits = bytes_data[0]
        if not (0 <= num_unused_bits <= 7):
            raise Error('ASN1 syntax error')
        if num_unused_bits == 0:
            return bytes_data[1:]
        if len(bytes_data) == 1:
            # An empty bit string must declare zero unused bits.
            raise Error('ASN1 syntax error')
        # Shift off unused bits
        remaining = bytearray(bytes_data[1:])
        bitmask = (1 << num_unused_bits) - 1
        # Low bits shifted out of one octet become the HIGH bits of the next.
        carry_shift = 8 - num_unused_bits
        removed_bits = 0
        for i in range(len(remaining)):
            byte = int(remaining[i])
            remaining[i] = (byte >> num_unused_bits) | (removed_bits << carry_shift)
            removed_bits = byte & bitmask
        return bytes(remaining)