feat: enhance ECC and SM algorithms, add driver path configuration

This commit is contained in:
halo
2026-03-06 16:56:39 +08:00
parent ab36f72a86
commit 8bac7d8fb0
10 changed files with 197 additions and 117 deletions

View File

@@ -1,7 +1,13 @@
from django.conf import settings
from .device import Device
DEFAULT_DRIVER_PATH = "./lib/libpiico_ccmu.so"
def open_piico_device(driver_path) -> Device:
def open_piico_device(driver_path=None) -> Device:
if driver_path is None:
driver_path = settings.PIICO_DRIVER_PATH or DEFAULT_DRIVER_PATH
d = Device()
d.open(driver_path)
return d

View File

@@ -1,8 +1,4 @@
cipher_alg_id = {
"sm4_ebc": 0x00000401,
"sm4_cbc": 0x00000402,
"sm4_mac": 0x00000405,
}
from .const import CIPHER_ALG_ID, SGD_SM2_3
class ECCCipher:
@@ -12,45 +8,46 @@ class ECCCipher:
self.private_key = private_key
def encrypt(self, plain_text):
return self._session.ecc_encrypt(self.public_key, plain_text, 0x00020800)
return self._session.ecc_encrypt(self.public_key, plain_text, SGD_SM2_3)
def decrypt(self, cipher_text):
return self._session.ecc_decrypt(self.private_key, cipher_text, 0x00020800)
return self._session.ecc_decrypt(self.private_key, cipher_text, SGD_SM2_3)
class EBCCipher:
def __init__(self, session, key_val):
self._session = session
self._key = self.__get_key(key_val)
self._key = key_val if isinstance(key_val, bytes) else bytes(key_val, encoding='utf-8')
self._alg = "sm4_ebc"
self._iv = None
def __get_key(self, key_val):
key_val = self.__padding(key_val)
return self._session.import_key(key_val)
@staticmethod
def __padding(val):
# padding
val = bytes(val)
while len(val) == 0 or len(val) % 16 != 0:
val += b'\0'
return val
def __padding(data):
if not isinstance(data, bytes):
data = bytes(data, encoding='utf-8')
while len(data) == 0 or len(data) % 16 != 0:
data += b'\0'
return data
def encrypt(self, plain_text):
plain_text = self.__padding(plain_text)
cipher_text = self._session.encrypt(plain_text, self._key, cipher_alg_id[self._alg], self._iv)
cipher_text = self._session.encrypt(plain_text, self._key, CIPHER_ALG_ID[self._alg], self._iv)
return bytes(cipher_text)
def decrypt(self, cipher_text):
plain_text = self._session.decrypt(cipher_text, self._key, cipher_alg_id[self._alg], self._iv)
plain_text = self._session.decrypt(cipher_text, self._key, CIPHER_ALG_ID[self._alg], self._iv)
return bytes(plain_text)
def destroy(self):
self._session.destroy_cipher_key(self._key)
self._session.close()
def __del__(self):
try:
self.destroy()
except Exception:
pass
class CBCCipher(EBCCipher):

View File

@@ -0,0 +1,34 @@
# ECC algorithm IDs
SGD_SM2 = 0x00020200
SGD_SM2_3 = 0x00020800
# ECC key bits
ECC_KEY_BITS_256 = 0x100
ECC_KEY_BITS_64 = 0x40
# ECC point format
ECC_POINT_UNCOMPRESSED = 0x04
# Hash algorithm IDs
SGD_SM3 = 0x00000001
SGD_SHA1 = 0x00000002
SGD_SHA256 = 0x00000004
SGD_SHA512 = 0x00000008
HASH_ALG_ID = {
"sm3": SGD_SM3,
"sha1": SGD_SHA1,
"sha256": SGD_SHA256,
"sha512": SGD_SHA512,
}
# Cipher algorithm IDs
SGD_SM4_ECB = 0x00000401
SGD_SM4_CBC = 0x00000402
SGD_SM4_MAC = 0x00000405
CIPHER_ALG_ID = {
"sm4_ebc": SGD_SM4_ECB,
"sm4_cbc": SGD_SM4_CBC,
"sm4_mac": SGD_SM4_MAC,
}

View File

@@ -1,12 +1,11 @@
import os
import base64
from ctypes import *
from .cipher import *
from .const import SGD_SM2
from .digest import *
from .exception import PiicoError
from .session import Session
from .cipher import *
from .digest import *
from django.core.cache import cache
from redis_lock import Lock as RedisLock
class Device:
@@ -14,11 +13,12 @@ class Device:
__device = None
def open(self, driver_path="./libpiico_ccmu.so"):
if self.__device is not None:
return
# load driver
self.__load_driver(driver_path)
# open device
self.__open_device()
self.__reset_key_store()
def close(self):
if self.__device is None:
@@ -37,12 +37,21 @@ class Device:
def generate_ecc_key_pair(self):
session = self.new_session()
return session.generate_ecc_key_pair(alg_id=0x00020200)
return session.generate_ecc_key_pair(alg_id=SGD_SM2)
def generate_random(self, length=64):
session = self.new_session()
return session.generate_random(length)
def verify_sign(self, public_key, raw_data, sign_data):
session = self.new_session()
return session.verify_sign_ecc(
SGD_SM2,
base64.b64decode(public_key),
base64.b64decode(raw_data),
base64.b64decode(sign_data),
)
def new_sm2_ecc_cipher(self, public_key, private_key):
session = self.new_session()
return ECCCipher(session, public_key, private_key)
@@ -59,11 +68,11 @@ class Device:
session = self.new_session()
return Digest(session, mode)
def sm3_hmac(self, key, data):
session = self.new_session()
return session.sm3_hmac(key, data)
def __load_driver(self, path):
# check driver status
if self._driver is not None:
raise Exception("already load driver")
# load driver
self._driver = cdll.LoadLibrary(path)
def __open_device(self):
@@ -72,30 +81,3 @@ class Device:
if ret != 0:
raise PiicoError("open piico device failed", ret)
self.__device = device
def __reset_key_store(self):
redis_client = cache.client.get_client()
server_hostname = os.environ.get("SERVER_HOSTNAME")
RESET_LOCK_KEY = f"spiico:{server_hostname}:reset"
LOCK_EXPIRE_SECONDS = 300
if self._driver is None:
raise PiicoError("no driver loaded", 0)
if self.__device is None:
raise PiicoError("device not open", 0)
# ---- 分布式锁Redis-Lock 实现 Redlock ----
lock = RedisLock(
redis_client,
RESET_LOCK_KEY,
expire=LOCK_EXPIRE_SECONDS, # 锁自动过期
auto_renewal=False, # 不自动续租
)
# 尝试获取锁,拿不到直接返回
if not lock.acquire(blocking=False):
return
# ---- 真正执行 reset ----
ret = self._driver.SPII_ResetModule(self.__device)
if ret != 0:
raise PiicoError("reset device failed", ret)

View File

@@ -1,15 +1,10 @@
hash_alg_id = {
"sm3": 0x00000001,
"sha1": 0x00000002,
"sha256": 0x00000004,
"sha512": 0x00000008,
}
from .const import HASH_ALG_ID
class Digest:
def __init__(self, session, alg_name="sm3"):
if hash_alg_id[alg_name] is None:
if HASH_ALG_ID.get(alg_name) is None:
raise Exception("unsupported hash alg {}".format(alg_name))
self._alg_name = alg_name
@@ -17,7 +12,7 @@ class Digest:
self.__init_hash()
def __init_hash(self):
self._session.hash_init(hash_alg_id[self._alg_name])
self._session.hash_init(HASH_ALG_ID[self._alg_name])
def update(self, data):
self._session.hash_update(data)

View File

@@ -1,5 +1,7 @@
from ctypes import *
from .const import ECC_POINT_UNCOMPRESSED
ECCref_MAX_BITS = 512
ECCref_MAX_LEN = int((ECCref_MAX_BITS + 7) / 8)
@@ -11,19 +13,22 @@ class EncodeMixin:
class ECCrefPublicKey(Structure, EncodeMixin):
_fields_ = [
('bits', c_uint),
('x', c_ubyte * ECCref_MAX_LEN),
('y', c_ubyte * ECCref_MAX_LEN),
("bits", c_uint),
("x", c_ubyte * ECCref_MAX_LEN),
("y", c_ubyte * ECCref_MAX_LEN),
]
def encode(self):
return bytes([0x04]) + bytes(self.x[32:]) + bytes(self.y[32:])
return bytes([ECC_POINT_UNCOMPRESSED]) + bytes(self.x[32:]) + bytes(self.y[32:])
class ECCrefPrivateKey(Structure, EncodeMixin):
_fields_ = [
('bits', c_uint,),
('K', c_ubyte * ECCref_MAX_LEN),
(
"bits",
c_uint,
),
("K", c_ubyte * ECCref_MAX_LEN),
]
def encode(self):
@@ -41,9 +46,9 @@ class ECCCipherEncode(EncodeMixin):
def encode(self):
c1 = bytes(self.x[32:]) + bytes(self.y[32:])
c2 = bytes(self.C[:self.L])
c2 = bytes(self.C[: self.L])
c3 = bytes(self.M)
return bytes([0x04]) + c1 + c2 + c3
return bytes([ECC_POINT_UNCOMPRESSED]) + c1 + c2 + c3
def new_ecc_cipher_cla(length):
@@ -52,15 +57,19 @@ def new_ecc_cipher_cla(length):
if _cache.__contains__(cla_name):
return _cache[cla_name]
else:
cla = type(cla_name, (Structure, ECCCipherEncode), {
"_fields_": [
('x', c_ubyte * ECCref_MAX_LEN),
('y', c_ubyte * ECCref_MAX_LEN),
('M', c_ubyte * 32),
('L', c_uint),
('C', c_ubyte * length)
]
})
cla = type(
cla_name,
(Structure, ECCCipherEncode),
{
"_fields_": [
("x", c_ubyte * ECCref_MAX_LEN),
("y", c_ubyte * ECCref_MAX_LEN),
("M", c_ubyte * 32),
("L", c_uint),
("C", c_ubyte * length),
]
},
)
_cache[cla_name] = cla
return cla
@@ -69,3 +78,10 @@ class ECCKeyPair:
def __init__(self, public_key, private_key):
self.public_key = public_key
self.private_key = private_key
class ECCSignature(Structure, EncodeMixin):
_fields_ = [
("r", c_ubyte * ECCref_MAX_LEN),
("s", c_ubyte * ECCref_MAX_LEN),
]

View File

@@ -1,6 +1,9 @@
from ctypes import *
from .ecc import ECCrefPublicKey, ECCrefPrivateKey, ECCKeyPair
from Cryptodome.Util.asn1 import DerSequence
from .const import ECC_KEY_BITS_256
from .ecc import ECCrefPublicKey, ECCrefPrivateKey, ECCKeyPair, ECCSignature
from .exception import PiicoError
from .session_mixin import SM3Mixin, SM4Mixin, SM2Mixin
@@ -24,12 +27,47 @@ class Session(SM2Mixin, SM3Mixin, SM4Mixin):
def generate_ecc_key_pair(self, alg_id):
public_key = ECCrefPublicKey()
private_key = ECCrefPrivateKey()
ret = self._driver.SDF_GenerateKeyPair_ECC(self._session, c_int(alg_id), c_int(256), pointer(public_key),
pointer(private_key))
ret = self._driver.SDF_GenerateKeyPair_ECC(
self._session,
c_int(alg_id),
c_int(256),
pointer(public_key),
pointer(private_key),
)
if ret != 0:
raise PiicoError("generate ecc key pair failed", ret)
return ECCKeyPair(public_key.encode(), private_key.encode())
def verify_sign_ecc(self, alg_id, public_key, raw_data, sign_data):
pos = 0
k1 = bytes([0] * 32) + bytes(public_key[pos : pos + 32])
pos += 32
k2 = bytes([0] * 32) + bytes(public_key[pos : pos + 32])
pk = ECCrefPublicKey(
c_uint(ECC_KEY_BITS_256), (c_ubyte * len(k1))(*k1), (c_ubyte * len(k2))(*k2)
)
seq_der = DerSequence()
decoded_sign = seq_der.decode(sign_data)
if decoded_sign and len(decoded_sign) != 2:
raise PiicoError("verify_sign decoded_sign", -1)
r = bytes([0] * 32) + int(decoded_sign[0]).to_bytes(32, byteorder="big")
s = bytes([0] * 32) + int(decoded_sign[1]).to_bytes(32, byteorder="big")
signature = ECCSignature((c_ubyte * len(r))(*r), (c_ubyte * len(s))(*s))
plain_text = (c_ubyte * len(raw_data))(*raw_data)
ret = self._driver.SDF_ExternalVerify_ECC(
self._session,
c_int(alg_id),
pointer(pk),
plain_text,
c_int(len(plain_text)),
pointer(signature),
)
if ret != 0:
raise PiicoError("verify_sign", ret)
return True
def close(self):
ret = self._driver.SDF_CloseSession(self._session)
if ret != 0:

View File

@@ -1,4 +1,4 @@
from .const import ECC_KEY_BITS_64
from .ecc import *
from .exception import PiicoError
@@ -13,13 +13,12 @@ class BaseMixin:
class SM2Mixin(BaseMixin):
def ecc_encrypt(self, public_key, plain_text, alg_id):
pos = 1
pos = 0
k1 = bytes([0] * 32) + bytes(public_key[pos:pos + 32])
k1 = (c_ubyte * len(k1))(*k1)
pos += 32
k2 = bytes([0] * 32) + bytes(public_key[pos:pos + 32])
pk = ECCrefPublicKey(c_uint(0x40), (c_ubyte * len(k1))(*k1), (c_ubyte * len(k2))(*k2))
pk = ECCrefPublicKey(c_uint(ECC_KEY_BITS_64), (c_ubyte * len(k1))(*k1), (c_ubyte * len(k2))(*k2))
plain_text = (c_ubyte * len(plain_text))(*plain_text)
ecc_data = new_ecc_cipher_cla(len(plain_text))()
@@ -32,7 +31,7 @@ class SM2Mixin(BaseMixin):
def ecc_decrypt(self, private_key, cipher_text, alg_id):
k = bytes([0] * 32) + bytes(private_key[:32])
vk = ECCrefPrivateKey(c_uint(0x40), (c_ubyte * len(k))(*k))
vk = ECCrefPrivateKey(c_uint(ECC_KEY_BITS_64), (c_ubyte * len(k))(*k))
pos = 1
# c1
@@ -85,24 +84,24 @@ class SM3Mixin(BaseMixin):
raise PiicoError("hash final failed", ret)
return bytes(result_data[:result_length.value])
def sm3_hmac(self, key, data):
key_buf = (c_ubyte * len(key))(*key)
data_buf = (c_ubyte * len(data))(*data)
hash_buf = (c_ubyte * 32)()
hash_length = c_uint()
ret = self._driver.SPII_SM3Hmac(
self._session,
key_buf, c_uint(len(key)),
data_buf, c_uint(len(data)),
hash_buf, pointer(hash_length),
)
if ret != 0:
raise PiicoError("sm3 hmac failed", ret)
return bytes(hash_buf[:hash_length.value])
class SM4Mixin(BaseMixin):
def import_key(self, key_val):
# to c lang
key_val = (c_ubyte * len(key_val))(*key_val)
key = c_void_p()
ret = self._driver.SDF_ImportKey(self._session, key_val, c_int(len(key_val)), pointer(key))
if ret != 0:
raise PiicoError("import key failed", ret)
return key
def destroy_cipher_key(self, key):
ret = self._driver.SDF_DestroyKey(self._session, key)
if ret != 0:
raise Exception("destroy key failed")
def encrypt(self, plain_text, key, alg, iv=None):
return self.__do_cipher_action(plain_text, key, alg, iv, True)
@@ -110,20 +109,35 @@ class SM4Mixin(BaseMixin):
return self.__do_cipher_action(cipher_text, key, alg, iv, False)
def __do_cipher_action(self, text, key, alg, iv=None, encrypt=True):
text = (c_ubyte * len(text))(*text)
text_buf = (c_ubyte * len(text))(*text)
key_buf = (c_ubyte * len(key))(*key)
iv_buf = None
iv_len = 0
if iv is not None:
iv = (c_ubyte * len(iv))(*iv)
iv_buf = (c_ubyte * len(iv))(*iv)
iv_len = len(iv)
temp_data = (c_ubyte * len(text))()
temp_data_length = c_int()
if encrypt:
ret = self._driver.SDF_Encrypt(self._session, key, c_int(alg), iv, text, c_int(len(text)), temp_data,
pointer(temp_data_length))
ret = self._driver.SPII_EncryptEx(
self._session, text_buf, c_int(len(text)),
key_buf, c_int(len(key)),
iv_buf, c_int(iv_len),
c_int(alg),
temp_data, pointer(temp_data_length),
)
if ret != 0:
raise PiicoError("encrypt failed", ret)
else:
ret = self._driver.SDF_Decrypt(self._session, key, c_int(alg), iv, text, c_int(len(text)), temp_data,
pointer(temp_data_length))
ret = self._driver.SPII_DecryptEx(
self._session, text_buf, c_int(len(text)),
key_buf, c_int(len(key)),
iv_buf, c_int(iv_len),
c_int(alg),
temp_data, pointer(temp_data_length),
)
if ret != 0:
raise PiicoError("decrypt failed", ret)
return temp_data[:temp_data_length.value]

View File

@@ -210,9 +210,7 @@ class Crypto:
if not crypt_algo:
if settings.GMSSL_ENABLED:
if settings.PIICO_DEVICE_ENABLE:
piico_driver_path = settings.PIICO_DRIVER_PATH if settings.PIICO_DRIVER_PATH \
else "./lib/libpiico_ccmu.so"
device = piico.open_piico_device(piico_driver_path)
device = piico.open_piico_device()
self.cryptor_map["piico_gm"] = get_piico_gm_sm4_ecb_crypto(device)
crypt_algo = 'piico_gm'
else: