Add decryption when sshkey are selected

This commit is contained in:
2025-10-18 08:40:07 +02:00
parent e29f7856a9
commit 86c378fa61
2 changed files with 49 additions and 19 deletions

View File

@@ -1,6 +1,9 @@
# pylint: disable=line-too-long, C0114
from pathlib import Path from pathlib import Path
from os import chmod from os import chmod
from Crypto.PublicKey import RSA from Crypto.PublicKey import RSA
import yaml
from encryptor import Encryptor
class SshKey: class SshKey:
""" """
@@ -29,9 +32,10 @@ class Collection:
""" """
Object class of Collection type Object class of Collection type
""" """
def __init__(self, collection_name): def __init__(self, collection_name: str, password: str):
self.collection_name = collection_name self.collection_name = collection_name
self.collection_path = Path.home().joinpath(".sshkeymanager", self.collection_name) self.collection_path = Path.home().joinpath(".sshkeymanager", self.collection_name)
self.encryptor = Encryptor(password)
def generate_ssh_key(self, name: str, key_type: str): def generate_ssh_key(self, name: str, key_type: str):
""" """
@@ -48,13 +52,37 @@ class Collection:
key_file_path = self.collection_path.joinpath(my_ssh_key.get_name()) key_file_path = self.collection_path.joinpath(my_ssh_key.get_name())
## Info File ## Info File
with open(f"{key_file_path}.txt", "w+", encoding="utf-8") as info_file: with open(f"{key_file_path}.txt", "w+", encoding="utf-8") as info_file:
info_file.write(f"name: {my_ssh_key.get_name()}\nKey_type: {my_ssh_key.get_type()}") info_file.write(f"name: {my_ssh_key.get_name()}\nkey_type: {my_ssh_key.get_type()}")
## Private Key ## Private Key
with open(key_file_path, "wb") as private_file: with open(key_file_path, "wb") as private_file:
private_file.write(my_ssh_key.get_private()) encrypted_key = self.encryptor.encrypt(my_ssh_key.get_private())
private_file.write(encrypted_key)
chmod(key_file_path, 0o600) chmod(key_file_path, 0o600)
## Public Key ## Public Key
with open(f"{key_file_path}.pub", "wb") as public_file: with open(f"{key_file_path}.pub", "wb") as public_file:
public_file.write(my_ssh_key.get_public()) public_file.write(my_ssh_key.get_public())
def get_ssh_key(self, name: str) -> SshKey:
"""
Get ssh key and decrypt private key
"""
key_file_path = self.collection_path.joinpath(name)
# Info file
with open(f"{key_file_path}.txt", "r", encoding="utf-8") as info_file:
data = yaml.safe_load(info_file)
name = data["name"]
key_type = data["key_type"]
# Private Key
with open(key_file_path, "rb") as private_file:
encrypted_private_key = private_file.read()
private_key = self.encryptor.decrypt(encrypted_private_key)
# Public key
with open(f"{key_file_path}.pub", "rb") as public_file:
public_key = public_file.read()
return SshKey(name=name, key_type=key_type, private=private_key, public=public_key)

View File

@@ -1,10 +1,12 @@
import base64 import base64
import os
from Crypto.Cipher import AES from Crypto.Cipher import AES
from Crypto.Protocol.KDF import PBKDF2 from Crypto.Protocol.KDF import PBKDF2
from Crypto.Random import get_random_bytes from Crypto.Random import get_random_bytes
class Encryptor: class Encryptor:
"""
Class to encrypt/decrypt content
"""
def __init__(self, password: str): def __init__(self, password: str):
self.password = password.encode() self.password = password.encode()
self.salt_size = 16 self.salt_size = 16
@@ -13,31 +15,32 @@ class Encryptor:
self.iterations = 100_000 self.iterations = 100_000
def _derive_key(self, salt: bytes) -> bytes: def _derive_key(self, salt: bytes) -> bytes:
"""
Dérive une clé à partir du mot de passe et du sel.
"""
return PBKDF2(self.password, salt, dkLen=self.key_size, count=self.iterations) return PBKDF2(self.password, salt, dkLen=self.key_size, count=self.iterations)
def encrypt(self, plaintext: str) -> str: def encrypt(self, plaintext: str | bytes) -> str:
""" """
Encrypte une chaîne de texte en base64. Encrypte une chaîne de texte en base64.
""" """
if isinstance(plaintext, str):
plaintext_bytes = plaintext.encode()
else:
plaintext_bytes = plaintext
salt = get_random_bytes(self.salt_size) salt = get_random_bytes(self.salt_size)
key = self._derive_key(salt)
iv = get_random_bytes(self.iv_size) iv = get_random_bytes(self.iv_size)
key = self._derive_key(salt)
# Padding (PKCS7) # Padding (PKCS7)
pad_len = AES.block_size - (len(plaintext.encode()) % AES.block_size) pad_len = AES.block_size - (len(plaintext_bytes) % AES.block_size)
padded = plaintext + chr(pad_len) * pad_len padded = plaintext_bytes + bytes([pad_len] * pad_len)
cipher = AES.new(key, AES.MODE_CBC, iv) cipher = AES.new(key, AES.MODE_CBC, iv)
ciphertext = cipher.encrypt(padded.encode()) ciphertext = cipher.encrypt(padded)
# Encodage final : salt + iv + ciphertext encrypted_data = base64.b64encode(salt + iv + ciphertext)
encrypted_data = base64.b64encode(salt + iv + ciphertext).decode()
return encrypted_data return encrypted_data
def decrypt(self, encrypted_text: str) -> str: def decrypt(self, encrypted_text: str) -> bytes:
""" """
Décrypte une chaîne encodée en base64. Décrypte une chaîne encodée en base64.
""" """
@@ -50,8 +53,7 @@ class Encryptor:
cipher = AES.new(key, AES.MODE_CBC, iv) cipher = AES.new(key, AES.MODE_CBC, iv)
padded_plaintext = cipher.decrypt(ciphertext) padded_plaintext = cipher.decrypt(ciphertext)
# Retrait du padding # Remove padding
pad_len = padded_plaintext[-1] pad_len = padded_plaintext[-1]
plaintext = padded_plaintext[:-pad_len].decode() plaintext = padded_plaintext[:-pad_len]
return plaintext # retourne des bytes
return plaintext