Thanks for visiting codestin.com
Credit goes to github.com

Skip to content

feat: add an encrypt_secret helper function #279

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Jul 2, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 108 additions & 0 deletions _test_unstructured_client/unit/test_encryption.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
from cryptography import x509
from cryptography.hazmat.primitives import serialization, hashes
from cryptography.hazmat.primitives.asymmetric import padding, rsa
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from cryptography.hazmat.backends import default_backend
import os
import base64
from typing import Optional

import pytest

from unstructured_client import UnstructuredClient

@pytest.fixture
def rsa_key_pair():
    """Generate a throwaway 2048-bit RSA key pair for the encryption tests.

    Returns:
        A ``(private_key_pem, public_key_pem)`` tuple of PEM text. The
        private key is serialized unencrypted in TraditionalOpenSSL format;
        the public key is serialized as SubjectPublicKeyInfo.
    """
    key = rsa.generate_private_key(
        public_exponent=65537,
        key_size=2048,
        backend=default_backend(),
    )

    private_pem_bytes = key.private_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PrivateFormat.TraditionalOpenSSL,
        encryption_algorithm=serialization.NoEncryption(),
    )
    public_pem_bytes = key.public_key().public_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PublicFormat.SubjectPublicKeyInfo,
    )

    return private_pem_bytes.decode('utf-8'), public_pem_bytes.decode('utf-8')

def test_encrypt_rsa(rsa_key_pair):
    """Round-trip a short secret and check that it takes the direct-RSA path."""
    private_pem, public_pem = rsa_key_pair
    message = "This is a secret message."

    client = UnstructuredClient()
    encrypted = client.users.encrypt_secret(public_pem, message)

    # A short payload should use direct RSA encryption
    assert encrypted["type"] == 'rsa'

    # No AES key or IV is involved on this path, so both are passed as "".
    round_tripped = client.users.decrypt_secret(
        private_key_pem=private_pem,
        encrypted_value=encrypted["encrypted_value"],
        secret_type=encrypted["type"],
        encrypted_aes_key="",
        aes_iv="",
    )
    assert round_tripped == message


def test_encrypt_rsa_aes(rsa_key_pair):
    """Round-trip a long secret and check that it takes the RSA-AES envelope path."""
    private_pem, public_pem = rsa_key_pair
    message = "This is a secret message." * 100

    client = UnstructuredClient()
    encrypted = client.users.encrypt_secret(public_pem, message)

    # A longer payload uses hybrid RSA-AES encryption
    assert encrypted["type"] == 'rsa_aes'

    round_tripped = client.users.decrypt_secret(
        private_key_pem=private_pem,
        encrypted_value=encrypted["encrypted_value"],
        secret_type=encrypted["type"],
        encrypted_aes_key=encrypted["encrypted_aes_key"],
        aes_iv=encrypted["aes_iv"],
    )
    assert round_tripped == message


# Modulus size in bytes for a 2048-bit key, the key_size the rsa_key_pair fixture generates.
rsa_key_size_bytes = 2048 // 8
# Longest message (in bytes) that fits in one RSA-OAEP block with SHA-256:
# the padding consumes 2 * 32 (hash length) + 2 = 66 bytes of the modulus.
max_payload_size = rsa_key_size_bytes - 66 # OAEP SHA256 overhead

@pytest.mark.parametrize(("plaintext", "secret_type"), [
    ("Short message", "rsa"),
    ("A" * (max_payload_size), "rsa"),  # Just at the RSA limit
    ("A" * (max_payload_size + 1), "rsa_aes"),  # Just over the RSA limit
    ("A" * 500, "rsa_aes"),  # Well over the RSA limit
])
def test_encrypt_around_rsa_size_limit(rsa_key_pair, plaintext, secret_type):
    """
    Test that payloads around the RSA size limit choose the correct algorithm.

    Only the automatic algorithm selection is checked here; round-trip
    decryption is covered by the dedicated rsa and rsa_aes tests.
    """
    _, public_key_pem = rsa_key_pair

    client = UnstructuredClient()

    # encrypt_secret parses the PEM itself, so the test does not need to
    # load the public key separately.
    secret_obj = client.users.encrypt_secret(public_key_pem, plaintext)

    assert secret_obj["type"] == secret_type
    assert secret_obj["encrypted_value"] is not None
2 changes: 1 addition & 1 deletion gen.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ python:
clientServerStatusCodesAsErrors: true
defaultErrorName: SDKError
description: Python Client SDK for Unstructured API
enableCustomCodeRegions: false
enableCustomCodeRegions: true
enumFormat: enum
fixFlags:
responseRequiredSep2024: false
Expand Down
166 changes: 166 additions & 0 deletions src/unstructured_client/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,15 @@
from unstructured_client.models import errors, operations, shared
from unstructured_client.types import BaseModel, OptionalNullable, UNSET

# region imports
from cryptography import x509
from cryptography.hazmat.primitives import serialization, hashes
from cryptography.hazmat.primitives.asymmetric import padding, rsa
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from cryptography.hazmat.backends import default_backend
import os
import base64
# endregion imports

class Users(BaseSDK):
def retrieve(
Expand Down Expand Up @@ -458,3 +467,160 @@ async def store_secret_async(
http_res_text,
http_res,
)

# region sdk-class-body
def _encrypt_rsa_aes(
    self,
    public_key: rsa.RSAPublicKey,
    plaintext: str,
) -> dict:
    """Envelope-encrypt ``plaintext`` with AES-CFB and wrap the AES key with RSA-OAEP.

    Args:
        public_key: RSA public key used to encrypt the generated AES key.
        plaintext: Text to encrypt; it is UTF-8 encoded before encryption.

    Returns:
        dict: ``encrypted_aes_key``, ``aes_iv`` and ``encrypted_value`` as
        base64 text, plus ``type`` set to ``'rsa_aes'``.
    """
    # A fresh random 256-bit key and 16-byte IV are drawn on every call.
    session_key = os.urandom(32)
    init_vector = os.urandom(16)

    # Encrypt the payload itself with AES in CFB mode.
    aes_encryptor = Cipher(
        algorithms.AES(session_key),
        modes.CFB(init_vector),
    ).encryptor()
    encrypted_payload = aes_encryptor.update(plaintext.encode('utf-8'))
    encrypted_payload += aes_encryptor.finalize()

    # Wrap the AES key with the RSA public key (OAEP, SHA-256 for hash and MGF1).
    oaep = padding.OAEP(
        mgf=padding.MGF1(algorithm=hashes.SHA256()),
        algorithm=hashes.SHA256(),
        label=None,
    )
    wrapped_key = public_key.encrypt(session_key, oaep)

    def _b64(raw: bytes) -> str:
        # All binary fields are returned as base64 text.
        return base64.b64encode(raw).decode('utf-8')

    return {
        'encrypted_aes_key': _b64(wrapped_key),
        'aes_iv': _b64(init_vector),
        'encrypted_value': _b64(encrypted_payload),
        'type': 'rsa_aes',
    }

def _encrypt_rsa(
    self,
    public_key: rsa.RSAPublicKey,
    plaintext: str,
) -> dict:
    """Encrypt ``plaintext`` directly with RSA-OAEP.

    Args:
        public_key: RSA public key to encrypt with.
        plaintext: Text to encrypt; it is encoded with ``str.encode()``
            (UTF-8 by default) before encryption.

    Returns:
        dict: ``encrypted_value`` as base64 text, ``type`` set to ``'rsa'``,
        and empty strings for ``encrypted_aes_key`` and ``aes_iv``, which
        are not used on this path.
    """
    oaep = padding.OAEP(
        mgf=padding.MGF1(algorithm=hashes.SHA256()),
        algorithm=hashes.SHA256(),
        label=None,
    )
    encrypted = public_key.encrypt(plaintext.encode(), oaep)
    encoded_value = base64.b64encode(encrypted).decode('utf-8')

    return {
        'encrypted_value': encoded_value,
        'type': 'rsa',
        'encrypted_aes_key': "",
        'aes_iv': "",
    }

def decrypt_secret(
    self,
    private_key_pem: str,
    encrypted_value: str,
    secret_type: str,
    encrypted_aes_key: str,
    aes_iv: str,
) -> str:
    """Decrypt a secret produced with either the ``rsa`` or ``rsa_aes`` scheme.

    Args:
        private_key_pem: PEM-encoded RSA private key, loaded without a password.
        encrypted_value: Base64-encoded ciphertext.
        secret_type: ``'rsa'`` selects direct RSA-OAEP decryption; any other
            value is handled as the rsa_aes envelope.
        encrypted_aes_key: Base64-encoded, RSA-encrypted AES key. Only read
            on the envelope path.
        aes_iv: Base64-encoded AES IV. Only read on the envelope path.

    Returns:
        str: The decrypted bytes decoded as UTF-8.

    Raises:
        TypeError: If the PEM does not contain an RSA private key.
    """
    rsa_private_key = serialization.load_pem_private_key(
        private_key_pem.encode('utf-8'),
        password=None,
        backend=default_backend(),
    )
    if not isinstance(rsa_private_key, rsa.RSAPrivateKey):
        raise TypeError("Private key must be a RSA private key for decryption.")

    # Both paths use the same OAEP parameters (SHA-256 for hash and MGF1).
    oaep = padding.OAEP(
        mgf=padding.MGF1(algorithm=hashes.SHA256()),
        algorithm=hashes.SHA256(),
        label=None,
    )

    if secret_type == 'rsa':
        # Direct RSA: the value itself is the RSA ciphertext.
        decrypted = rsa_private_key.decrypt(base64.b64decode(encrypted_value), oaep)
        return decrypted.decode('utf-8')

    # rsa_aes envelope: unwrap the AES key with RSA, then decrypt the payload.
    wrapped_key = base64.b64decode(encrypted_aes_key)
    init_vector = base64.b64decode(aes_iv)
    payload = base64.b64decode(encrypted_value)

    session_key = rsa_private_key.decrypt(wrapped_key, oaep)

    aes_decryptor = Cipher(
        algorithms.AES(session_key),
        modes.CFB(init_vector),
    ).decryptor()
    decrypted = aes_decryptor.update(payload) + aes_decryptor.finalize()
    return decrypted.decode('utf-8')

def encrypt_secret(
    self,
    encryption_cert_or_key_pem: str,
    plaintext: str,
    encryption_type: Optional[str] = None,
) -> dict:
    """
    Encrypts a plaintext string for securely sending to the Unstructured API.

    Args:
        encryption_cert_or_key_pem (str): A PEM-encoded RSA public key or certificate.
        plaintext (str): The string to encrypt.
        encryption_type (str, optional): "rsa" forces direct RSA-OAEP
            encryption; any other non-empty value selects the "rsa_aes"
            envelope. When omitted or empty, "rsa" is chosen if the UTF-8
            encoded plaintext fits in a single RSA-OAEP block for this key,
            otherwise "rsa_aes".

    Returns:
        dict: A dictionary with the keys "type", "encrypted_value",
        "encrypted_aes_key" and "aes_iv". Binary values are base64-encoded;
        for type "rsa" the AES key and IV are empty strings.

    Raises:
        TypeError: If the PEM does not contain an RSA public key.
    """
    # If a cert is provided, extract the public key
    if "BEGIN CERTIFICATE" in encryption_cert_or_key_pem:
        cert = x509.load_pem_x509_certificate(
            encryption_cert_or_key_pem.encode('utf-8'),
        )

        public_key = cert.public_key()  # type: ignore[assignment]
    else:
        public_key = serialization.load_pem_public_key(
            encryption_cert_or_key_pem.encode('utf-8'),
            backend=default_backend()
        )  # type: ignore[assignment]

    if not isinstance(public_key, rsa.RSAPublicKey):
        raise TypeError("Public key must be a RSA public key for encryption.")

    # If the plaintext is short, use RSA directly
    # Otherwise, use a RSA_AES envelope hybrid
    # Use the length of the public key to determine the encryption type
    key_size_bytes = public_key.key_size // 8
    max_rsa_length = key_size_bytes - 66  # OAEP SHA256 overhead: 2 * 32 + 2

    if not encryption_type:
        # The OAEP limit is in bytes, so measure the encoded payload rather
        # than the character count: non-ASCII text can be under the limit in
        # characters yet too large for direct RSA once UTF-8 encoded.
        payload_length = len(plaintext.encode('utf-8'))
        encryption_type = "rsa" if payload_length <= max_rsa_length else "rsa_aes"

    if encryption_type == "rsa":
        return self._encrypt_rsa(public_key, plaintext)

    return self._encrypt_rsa_aes(public_key, plaintext)
# endregion sdk-class-body