Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/python-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: [3.5, 3.6, 3.7, 3.8]
python-version: [3.6, 3.7, 3.8, 3.9]

steps:
- uses: actions/checkout@v2
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
cryptography==3.2
cryptography==3.3.2
yapf==0.21.0
61 changes: 25 additions & 36 deletions sshpubkeys/keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,11 @@
MalformedDataError, MissingMandatoryOptionValueError, TooLongKeyError, TooShortKeyError, UnknownOptionNameError
)
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.asymmetric import ec
from cryptography.hazmat.primitives.asymmetric.dsa import DSAParameterNumbers, DSAPublicNumbers
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPublicNumbers
from cryptography.hazmat.primitives.asymmetric.utils import Prehashed
from cryptography.hazmat.primitives.serialization import Encoding, PublicFormat
from cryptography.hazmat.primitives.asymmetric import ec
from cryptography.hazmat.primitives import hashes
from urllib.parse import urlparse

import base64
Expand All @@ -40,6 +39,7 @@
class _ECVerifyingKey:
"""ecdsa.key.VerifyingKey reimplementation
"""

def __init__(self, pubkey, default_hashfunc):
self.pubkey = pubkey
self.default_hashfunc = default_hashfunc
Expand All @@ -52,20 +52,17 @@ def curve(self):
def __repr__(self):
pub_key = self.to_string("compressed")
self.to_string("raw")
return "VerifyingKey({0!r}, {1!r}, {2})".format(
pub_key, self.curve.name, self.default_hashfunc.name
)
return f"VerifyingKey({pub_key!r}, {self.curve.name!r}, {self.default_hashfunc.name})"

def to_string(self, encoding="raw"):
"""Pub key as bytes string"""
if encoding == "raw":
return self.pubkey.public_numbers().encode_point()[1:]
elif encoding == "uncompressed":
if encoding == "uncompressed":
return self.pubkey.public_numbers().encode_point()
elif encoding == "compressed":
if encoding == "compressed":
return self.pubkey.public_bytes(Encoding.X962, PublicFormat.CompressedPoint)
else:
raise ValueError(encoding)
raise ValueError(encoding)

def to_pem(self, point_encoding="uncompressed"):
"""Pub key as PEM"""
Expand All @@ -83,10 +80,6 @@ def verify(self, signature, data):
"""Verify signature of provided data"""
return self.pubkey.verify(signature, data, ec.ECDSA(self.default_hashfunc))

def verify_digest(self, signature, digest):
"""Verify signature over prehashed digest"""
return self.pubkey.verify(signature, data, ec.ECDSA(Prehashed(digest)))


class AuthorizedKeysFile: # pylint:disable=too-few-public-methods
"""Represents a full authorized_keys file.
Expand Down Expand Up @@ -191,7 +184,7 @@ def __init__(self, keydata=None, **kwargs):
pass

def __str__(self):
return "Key type: %s, bits: %s, options: %s" % (self.key_type.decode(), self.bits, self.options)
return f"Key type: {self.key_type.decode()}, bits: {self.bits}, options: {self.options}"

def reset(self):
"""Reset all data fields."""
Expand Down Expand Up @@ -235,15 +228,15 @@ def _unpack_by_int(self, data, current_position):
try:
requested_data_length = struct.unpack('>I', data[current_position:current_position + self.INT_LEN])[0]
except struct.error as ex:
raise MalformedDataError("Unable to unpack %s bytes from the data" % self.INT_LEN) from ex
raise MalformedDataError(f"Unable to unpack {self.INT_LEN} bytes from the data") from ex

# Move pointer to the beginning of the data field
current_position += self.INT_LEN
remaining_data_length = len(data[current_position:])

if remaining_data_length < requested_data_length:
raise MalformedDataError(
"Requested %s bytes, but only %s bytes available." % (requested_data_length, remaining_data_length)
f"Requested {requested_data_length} bytes, but only {remaining_data_length} bytes available."
)

next_data = data[current_position:current_position + requested_data_length]
Expand Down Expand Up @@ -326,15 +319,15 @@ def parse_add_single_option(opt):
opt_name = opt
opt_value = True
if " " in opt_name or not self.OPTION_NAME_RE.match(opt_name):
raise InvalidOptionNameError("%s is not a valid option name." % opt_name)
raise InvalidOptionNameError(f"{opt_name} is not a valid option name.")
if self.strict_mode:
for valid_opt_name, value_required in self.OPTIONS_SPEC:
if opt_name.lower() == valid_opt_name:
if value_required and opt_value is True:
raise MissingMandatoryOptionValueError("%s is missing a mandatory value." % opt_name)
raise MissingMandatoryOptionValueError(f"{opt_name} is missing a mandatory value.")
break
else:
raise UnknownOptionNameError("%s is an unrecognized option name." % opt_name)
raise UnknownOptionNameError(f"{opt_name} is an unrecognized option name.")
if opt_name not in parsed_options:
parsed_options[opt_name] = []
parsed_options[opt_name].append(opt_value)
Expand Down Expand Up @@ -377,11 +370,11 @@ def _process_ssh_rsa(self, data):
max_length = self.RSA_MAX_LENGTH_LOOSE
if self.bits < min_length:
raise TooShortKeyError(
"%s key data can not be shorter than %s bits (was %s)" % (self.key_type.decode(), min_length, self.bits)
f"{self.key_type.decode()} key data can not be shorter than {min_length} bits (was {self.bits})"
)
if self.bits > max_length:
raise TooLongKeyError(
"%s key data can not be longer than %s bits (was %s)" % (self.key_type.decode(), max_length, self.bits)
f"{self.key_type.decode()} key data can not be longer than {max_length} bits (was {self.bits})"
)
return current_position

Expand All @@ -396,20 +389,18 @@ def _process_ssh_dss(self, data):
q_bits = self._bits_in_number(data_fields["q"])
p_bits = self._bits_in_number(data_fields["p"])
if q_bits != self.DSA_N_LENGTH:
raise InvalidKeyError("Incorrect DSA key parameters: bits(p)=%s, q=%s" % (self.bits, q_bits))
raise InvalidKeyError(f"Incorrect DSA key parameters: bits(p)={self.bits}, q={q_bits}")
if self.strict_mode:
min_length = self.DSA_MIN_LENGTH_STRICT
max_length = self.DSA_MAX_LENGTH_STRICT
else:
min_length = self.DSA_MIN_LENGTH_LOOSE
max_length = self.DSA_MAX_LENGTH_LOOSE
if p_bits < min_length:
raise TooShortKeyError(
"%s key can not be shorter than %s bits (was %s)" % (self.key_type.decode(), min_length, p_bits)
)
raise TooShortKeyError(f"{self.key_type.decode()} key can not be shorter than {min_length} bits (was {p_bits})")
if p_bits > max_length:
raise TooLongKeyError(
"%s key data can not be longer than %s bits (was %s)" % (self.key_type.decode(), max_length, p_bits)
f"{self.key_type.decode()} key data can not be longer than {max_length} bits (was {p_bits})"
)

dsa_parameters = DSAParameterNumbers(data_fields["p"], data_fields["q"], data_fields["g"])
Expand All @@ -422,14 +413,12 @@ def _process_ecdsa_sha(self, data):
"""Parses ecdsa-sha public keys."""
current_position, curve_information = self._unpack_by_int(data, 0)
if curve_information not in self.ECDSA_CURVE_DATA:
raise NotImplementedError("Invalid curve type: %s" % curve_information)
raise NotImplementedError(f"Invalid curve type: {curve_information}")
curve, hash_algorithm = self.ECDSA_CURVE_DATA[curve_information]

current_position, key_data = self._unpack_by_int(data, current_position)
try:
ecdsa_pubkey = ec.EllipticCurvePublicKey.from_encoded_point(
curve, key_data
)
ecdsa_pubkey = ec.EllipticCurvePublicKey.from_encoded_point(curve, key_data)
except ValueError as ex:
raise InvalidKeyError("Invalid ecdsa key") from ex
self.bits = curve.key_size
Expand All @@ -452,7 +441,7 @@ def _process_ed25516(self, data):

self.bits = verifying_key_length
if self.bits != 256:
raise InvalidKeyLengthError("ed25519 keys must be 256 bits (was %s bits)" % self.bits)
raise InvalidKeyLengthError(f"ed25519 keys must be 256 bits (was {self.bits} bits)")
return current_position

def _validate_application_string(self, application):
Expand All @@ -463,7 +452,7 @@ def _validate_application_string(self, application):
try:
parsed_url = urlparse(application)
except ValueError as err:
raise InvalidKeyError("Application string: %s" % err) from err
raise InvalidKeyError(f"Application string: {err}") from err
if parsed_url.scheme != b"ssh":
raise InvalidKeyError('Application string must begin with "ssh:"')

Expand Down Expand Up @@ -494,7 +483,7 @@ def _process_key(self, data):
return self._process_sk_ecdsa_sha(data)
if self.key_type.strip().startswith(b"sk-ssh-ed25519"):
return self._process_sk_ed25519(data)
raise NotImplementedError("Invalid key type: %s" % self.key_type.decode())
raise NotImplementedError(f"Invalid key type: {self.key_type.decode()}")

def parse(self, keydata=None):
"""Validates SSH public key.
Expand Down Expand Up @@ -528,15 +517,15 @@ def parse(self, keydata=None):
# Check key type
current_position, unpacked_key_type = self._unpack_by_int(self._decoded_key, 0)
if key_type is not None and key_type != unpacked_key_type.decode():
raise InvalidTypeError("Keytype mismatch: %s != %s" % (key_type, unpacked_key_type.decode()))
raise InvalidTypeError(f"Keytype mismatch: {key_type} != {unpacked_key_type.decode()}")

self.key_type = unpacked_key_type

key_data_length = self._process_key(self._decoded_key[current_position:])
current_position = current_position + key_data_length

if current_position != len(self._decoded_key):
raise MalformedDataError("Leftover data: %s bytes" % (len(self._decoded_key) - current_position))
raise MalformedDataError(f"Leftover data: {len(self._decoded_key) - current_position} bytes")

if self.disallow_options and self.options:
raise InvalidOptionsError("Options are disallowed.")
24 changes: 12 additions & 12 deletions tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,17 +105,17 @@ def ch(option, parsed_option):
return lambda self: self.check_valid_option(option, parsed_option)

for i, items in enumerate(options):
prefix_tmp = "%s_%s" % (items[0], i)
setattr(TestOptions, "test_%s" % prefix_tmp, ch(items[1], items[2]))
prefix_tmp = f"{items[0]}_{i}"
setattr(TestOptions, f"test_{prefix_tmp}", ch(items[1], items[2]))


def loop_invalid_options(options):
def ch(option, expected_error):
return lambda self: self.check_invalid_option(option, expected_error)

for i, items in enumerate(options):
prefix_tmp = "%s_%s" % (items[0], i)
setattr(TestOptions, "test_%s" % prefix_tmp, ch(items[1], items[2]))
prefix_tmp = f"{items[0]}_{i}"
setattr(TestOptions, f"test_{prefix_tmp}", ch(items[1], items[2]))


def loop_valid(keyset, prefix):
Expand All @@ -126,7 +126,7 @@ def ch(pubkey, bits, fingerprint_md5, fingerprint_sha256, options, comment, **kw

for items in keyset:
modes = items.pop()
prefix_tmp = "%s_%s" % (prefix, items.pop())
prefix_tmp = f"{prefix}_{items.pop()}"
for mode in modes:
if mode == "strict":
kwargs = {"strict": True}
Expand All @@ -138,7 +138,7 @@ def ch(pubkey, bits, fingerprint_md5, fingerprint_sha256, options, comment, **kw
else:
pubkey, bits, fingerprint_md5, fingerprint_sha256, options, comment = items
setattr(
TestKeys, "test_%s_mode_%s" % (prefix_tmp, mode),
TestKeys, f"test_{prefix_tmp}_mode_{mode}",
ch(pubkey, bits, fingerprint_md5, fingerprint_sha256, options, comment, **kwargs)
)

Expand All @@ -151,32 +151,32 @@ def ch(pubkey, expected_error, **kwargs):

for items in keyset:
modes = items.pop()
prefix_tmp = "%s_%s" % (prefix, items.pop())
prefix_tmp = f"{prefix}_{items.pop()}"
for mode in modes:
if mode == "strict":
kwargs = {"strict": True}
else:
kwargs = {"strict": False}
pubkey, expected_error = items
setattr(TestKeys, "test_%s_mode_%s" % (prefix_tmp, mode), ch(pubkey, expected_error, **kwargs))
setattr(TestKeys, f"test_{prefix_tmp}_mode_{mode}", ch(pubkey, expected_error, **kwargs))


def loop_authorized_keys(keyset):
def ch(file_str, valid_keys_count):
return lambda self: self.check_valid_file(file_str, valid_keys_count)

for i, items in enumerate(keyset):
prefix_tmp = "%s_%s" % (items[0], i)
setattr(TestAuthorizedKeys, "test_%s" % prefix_tmp, ch(items[1], items[2]))
prefix_tmp = f"{items[0]}_{i}"
setattr(TestAuthorizedKeys, f"test_{prefix_tmp}", ch(items[1], items[2]))


def loop_invalid_authorized_keys(keyset):
def ch(file_str, expected_error, **kwargs):
return lambda self: self.check_invalid_file(file_str, expected_error, **kwargs)

for i, items in enumerate(keyset):
prefix_tmp = "%s_%s" % (items[0], i)
setattr(TestAuthorizedKeys, "test_invalid_%s" % prefix_tmp, ch(items[1], items[2]))
prefix_tmp = f"{items[0]}_{i}"
setattr(TestAuthorizedKeys, f"test_invalid_{prefix_tmp}", ch(items[1], items[2]))


loop_valid(list_of_valid_keys, "valid_key")
Expand Down