diff --git a/src/collection.py b/src/collection.py index 130ae74..1b1781f 100644 --- a/src/collection.py +++ b/src/collection.py @@ -1,6 +1,9 @@ +# pylint: disable=line-too-long, C0114 from pathlib import Path from os import chmod from Crypto.PublicKey import RSA +import yaml +from encryptor import Encryptor class SshKey: """ @@ -29,9 +32,10 @@ class Collection: """ Object class of Collection type """ - def __init__(self, collection_name): + def __init__(self, collection_name: str, password: str): self.collection_name = collection_name self.collection_path = Path.home().joinpath(".sshkeymanager", self.collection_name) + self.encryptor = Encryptor(password) def generate_ssh_key(self, name: str, key_type: str): """ @@ -48,13 +52,37 @@ class Collection: key_file_path = self.collection_path.joinpath(my_ssh_key.get_name()) ## Info File with open(f"{key_file_path}.txt", "w+", encoding="utf-8") as info_file: - info_file.write(f"name: {my_ssh_key.get_name()}\nKey_type: {my_ssh_key.get_type()}") + info_file.write(f"name: {my_ssh_key.get_name()}\nkey_type: {my_ssh_key.get_type()}") ## Private Key with open(key_file_path, "wb") as private_file: - private_file.write(my_ssh_key.get_private()) + encrypted_key = self.encryptor.encrypt(my_ssh_key.get_private()) + private_file.write(encrypted_key) chmod(key_file_path, 0o600) ## Public Key with open(f"{key_file_path}.pub", "wb") as public_file: public_file.write(my_ssh_key.get_public()) + + def get_ssh_key(self, name: str) -> SshKey: + """ + Get ssh key and decrypt private key + """ + key_file_path = self.collection_path.joinpath(name) + + # Info file + with open(f"{key_file_path}.txt", "r", encoding="utf-8") as info_file: + data = yaml.safe_load(info_file) + name = data["name"] + key_type = data["key_type"] + + # Private Key + with open(key_file_path, "rb") as private_file: + encrypted_private_key = private_file.read() + private_key = self.encryptor.decrypt(encrypted_private_key) + + # Public key + with open(f"{key_file_path}.pub", "rb") as public_file: + public_key = public_file.read() + + return SshKey(name=name, key_type=key_type, private=private_key, public=public_key) diff --git a/src/encryptor.py b/src/encryptor.py index eb27ca5..afed1d6 100644 --- a/src/encryptor.py +++ b/src/encryptor.py @@ -1,10 +1,12 @@ import base64 -import os from Crypto.Cipher import AES from Crypto.Protocol.KDF import PBKDF2 from Crypto.Random import get_random_bytes class Encryptor: + """ + Class to encrypt/decrypt content + """ def __init__(self, password: str): self.password = password.encode() self.salt_size = 16 @@ -13,31 +15,32 @@ class Encryptor: self.iterations = 100_000 def _derive_key(self, salt: bytes) -> bytes: - """ - Dérive une clé à partir du mot de passe et du sel. - """ return PBKDF2(self.password, salt, dkLen=self.key_size, count=self.iterations) - def encrypt(self, plaintext: str) -> str: + def encrypt(self, plaintext: str | bytes) -> str: """ Encrypte une chaîne de texte en base64. """ + if isinstance(plaintext, str): + plaintext_bytes = plaintext.encode() + else: + plaintext_bytes = plaintext + salt = get_random_bytes(self.salt_size) - key = self._derive_key(salt) iv = get_random_bytes(self.iv_size) + key = self._derive_key(salt) # Padding (PKCS7) - pad_len = AES.block_size - (len(plaintext.encode()) % AES.block_size) - padded = plaintext + chr(pad_len) * pad_len + pad_len = AES.block_size - (len(plaintext_bytes) % AES.block_size) + padded = plaintext_bytes + bytes([pad_len] * pad_len) cipher = AES.new(key, AES.MODE_CBC, iv) - ciphertext = cipher.encrypt(padded.encode()) + ciphertext = cipher.encrypt(padded) - # Encodage final : salt + iv + ciphertext - encrypted_data = base64.b64encode(salt + iv + ciphertext).decode() + encrypted_data = base64.b64encode(salt + iv + ciphertext) return encrypted_data - def decrypt(self, encrypted_text: str) -> str: + def decrypt(self, encrypted_text: str) -> bytes: """ Décrypte une chaîne encodée en base64. """ @@ -50,8 +53,7 @@ class Encryptor: cipher = AES.new(key, AES.MODE_CBC, iv) padded_plaintext = cipher.decrypt(ciphertext) - # Retrait du padding + # Remove padding pad_len = padded_plaintext[-1] - plaintext = padded_plaintext[:-pad_len].decode() - - return plaintext + plaintext = padded_plaintext[:-pad_len] + return plaintext # retourne des bytes