Pull out a main function, remove global state.

Makes it easier to see what inputs I need.
This commit is contained in:
Eli Ribble 2023-11-07 15:07:54 -07:00
parent 65c5517935
commit 0f450672b1
1 changed files with 21 additions and 17 deletions

View File

@ -17,9 +17,8 @@ def ecc_point_to_256_bit_key(point):
sha.update(int.to_bytes(point.y, 32, 'big')) sha.update(int.to_bytes(point.y, 32, 'big'))
return sha.digest() return sha.digest()
curve = registry.get_curve('brainpoolP256r1')
def encrypt_ECC(msg, pubKey): def encrypt_ECC(curve, msg, pubKey):
ciphertextPrivKey = secrets.randbelow(curve.field.n) ciphertextPrivKey = secrets.randbelow(curve.field.n)
sharedECCKey = ciphertextPrivKey * pubKey sharedECCKey = ciphertextPrivKey * pubKey
secretKey = ecc_point_to_256_bit_key(sharedECCKey) secretKey = ecc_point_to_256_bit_key(sharedECCKey)
@ -34,20 +33,25 @@ def decrypt_ECC(encryptedMsg, privKey):
plaintext = decrypt_AES_GCM(ciphertext, nonce, authTag, secretKey) plaintext = decrypt_AES_GCM(ciphertext, nonce, authTag, secretKey)
return plaintext return plaintext
msg = b'Text to be encrypted by ECC public key and ' \ def main():
curve = registry.get_curve('brainpoolP256r1')
msg = b'Text to be encrypted by ECC public key and ' \
b'decrypted by its corresponding ECC private key' b'decrypted by its corresponding ECC private key'
print("original msg:", msg) print("original msg:", msg)
privKey = secrets.randbelow(curve.field.n) privKey = secrets.randbelow(curve.field.n)
pubKey = privKey * curve.g pubKey = privKey * curve.g
encryptedMsg = encrypt_ECC(msg, pubKey) encryptedMsg = encrypt_ECC(curve, msg, pubKey)
encryptedMsgObj = { encryptedMsgObj = {
'ciphertext': binascii.hexlify(encryptedMsg[0]), 'ciphertext': binascii.hexlify(encryptedMsg[0]),
'nonce': binascii.hexlify(encryptedMsg[1]), 'nonce': binascii.hexlify(encryptedMsg[1]),
'authTag': binascii.hexlify(encryptedMsg[2]), 'authTag': binascii.hexlify(encryptedMsg[2]),
'ciphertextPubKey': hex(encryptedMsg[3].x) + hex(encryptedMsg[3].y % 2)[2:] 'ciphertextPubKey': hex(encryptedMsg[3].x) + hex(encryptedMsg[3].y % 2)[2:]
} }
print("encrypted msg:", encryptedMsgObj) print("encrypted msg:", encryptedMsgObj)
decryptedMsg = decrypt_ECC(encryptedMsg, privKey) decryptedMsg = decrypt_ECC(encryptedMsg, privKey)
print("decrypted msg:", decryptedMsg) print("decrypted msg:", decryptedMsg)
if __name__ == "__main__":
main()