mirror of
https://github.com/jumpserver/jumpserver.git
synced 2026-07-02 23:23:21 +00:00
361 lines
14 KiB
Python
361 lines
14 KiB
Python
# -*- coding: utf-8 -*-
|
||
#
|
||
|
||
import base64
|
||
import os
|
||
import tempfile
|
||
|
||
from django.conf import settings
|
||
from django.core.exceptions import PermissionDenied
|
||
from django.utils.translation import gettext_lazy as _
|
||
|
||
from users.models import User
|
||
from common.utils import get_logger
|
||
from ..base import JMSBaseAuthBackend
|
||
from .sdk import ukey_sdk_config
|
||
from .exceptions import (
|
||
UKeyAuthError,
|
||
UKeyUserNotFoundError,
|
||
UkeySNMismatchError,
|
||
UKeyCertNormalizationError,
|
||
UKeyCertChainError,
|
||
UKeyCertCNMismatchError,
|
||
UKeySignatureError,
|
||
UKeyCertExpiredError,
|
||
UKeyCertUnsupportedAlgorithmError,
|
||
)
|
||
from .utils import is_sm2_pem
|
||
from authentication.errors.const import reason_user_inactive, reason_choices
|
||
|
||
|
||
# Public API of this module: only the backend class is exported.
__all__ = ["UKeyBackend"]

# Logger used by every UKeyBackend verification step to record failures.
logger = get_logger(__name__)
|
||
|
||
|
||
class UKeyBackend(JMSBaseAuthBackend):
    """Authenticate a user with a UKey certificate and a signed challenge.

    Flow: look up the user and compare the presented UKey serial number with
    the one bound to the account, normalise the certificate to PEM, then run
    either the SM2 flow or the RSA flow. Each flow loads the certificate,
    verifies it against the configured CA, checks its validity period,
    matches its CN against the username and finally verifies the signature
    over the challenge.
    """

    backend = settings.AUTH_BACKEND_UKEY

    @staticmethod
    def is_enabled():
        """Return the ``AUTH_UKEY`` setting, i.e. whether this backend is on."""
        return settings.AUTH_UKEY

    # ── Entry point ───────────────────────────────────────────────────────────

    def authenticate(self, request, username, cert, signature, challenge, ukey_sn=None):
        """Run every UKey check and return the authenticated user.

        :param request: current request or ``None``; on failure its
            ``error_message`` attribute is set to the failure text.
        :param username: username the client claims; must equal the cert CN.
        :param cert: client certificate as PEM text, bare base64 or DER bytes.
        :param signature: signature over ``challenge`` (bytes, hex or base64).
        :param challenge: the signed challenge, ``str`` or ``bytes``.
        :param ukey_sn: UKey serial number; must equal ``user.ukey_sn``.
        :return: the ``User`` when every check passes.
        :raises PermissionDenied: on any failure, including an inactive user.
        """
        try:
            user = self._check_user_and_ukey_sn(username, ukey_sn)
            cert_pem = self._load_cert_pem(cert)
            if self._is_sm2_cert(cert_pem):
                user = self._authenticate_sm2(cert_pem, username, signature, challenge, user)
            else:
                user = self._authenticate_other(cert_pem, username, signature, challenge, user)
            if not self.user_can_authenticate(user):
                raise PermissionDenied(reason_choices[reason_user_inactive])
            return user
        except Exception as e:
            # UKeyAuthError is presumably the base of the UKey* errors raised
            # below; anything else is a bug or an environment problem, so keep
            # the traceback for operators.
            if not isinstance(e, (UKeyAuthError, PermissionDenied)):
                logger.exception('UKeyBackend: unexpected error during authentication')
            if request:
                request.error_message = str(e)
            raise PermissionDenied(str(e)) from e

    # ── Part 1: user and UKey SN pre-check ────────────────────────────────────

    def _check_user_and_ukey_sn(self, username, ukey_sn):
        """Look up ``username`` and check ``ukey_sn`` against its bound SN.

        Both serial numbers are compared after stripping surrounding
        whitespace; an account without a bound SN never matches.

        :return: the matching ``User``.
        :raises UKeyUserNotFoundError: no user has that username.
        :raises UkeySNMismatchError: the SN is missing or differs.
        """
        presented_sn = (ukey_sn or '').strip()
        user = User.objects.filter(username=username).first()
        if user is None:
            logger.error('UKeyBackend: user %r not found', username)
            raise UKeyUserNotFoundError()
        bound_sn = (user.ukey_sn or '').strip()
        # An empty presented SN can never equal a non-empty bound SN.
        if not bound_sn or presented_sn != bound_sn:
            logger.error('UKeyBackend: ukey_sn mismatch for user %r', username)
            raise UkeySNMismatchError()
        return user

    # ── Part 2: SM2 certificate verification flow ─────────────────────────────

    def _authenticate_sm2(self, cert_pem, username, signature, challenge, user):
        """SM2 flow: load -> CA chain -> validity -> CN -> signature; return ``user``."""
        sm2_cert = self._load_sm2_cert(cert_pem)
        self._verify_sm2_cert_chain(sm2_cert)
        self._verify_sm2_cert_validity(sm2_cert)
        self._verify_cert_cn(sm2_cert.get_subject().get('commonName'), username)
        self._verify_sm2_signature(sm2_cert.get_subject_public_key(), signature, challenge)
        return user

    @staticmethod
    def _import_sm2_pem(pem_text):
        """Load PEM text into a new ``Sm2Certificate`` and return it.

        ``Sm2Certificate.import_pem`` is given a file path, so the text is
        written to a temporary file first; that file is removed again whether
        or not loading succeeds.
        """
        from common.utils.gmssl_python import Sm2Certificate

        fd, path = tempfile.mkstemp(suffix='.crt')
        try:
            # fdopen takes ownership of fd and closes it with the file object.
            with os.fdopen(fd, 'w', encoding='utf-8') as f:
                f.write(pem_text)
            sm2_cert = Sm2Certificate()
            sm2_cert.import_pem(path)
            return sm2_cert
        finally:
            try:
                os.unlink(path)
            except OSError as e:
                # A leftover temp file must not change the authentication result.
                logger.warning('UKeyBackend: failed to remove temp file %s: %s', path, e)

    @staticmethod
    def _load_sm2_cert(cert_pem):
        """Load the client's SM2 certificate from PEM text.

        :raises UKeyCertNormalizationError: the PEM cannot be loaded.
        """
        try:
            return UKeyBackend._import_sm2_pem(cert_pem)
        except Exception as e:
            logger.error('UKeyBackend: failed to load SM2 cert: %s', e)
            raise UKeyCertNormalizationError() from e

    @staticmethod
    def _verify_sm2_cert_validity(sm2_cert):
        """Check the SM2 certificate's not_before / not_after window.

        :raises UKeyCertExpiredError: the window cannot be read, or the
            current time lies outside it.
        """
        try:
            validity = sm2_cert.get_validity()
        except Exception as e:
            logger.error('UKeyBackend: failed to get SM2 cert validity: %s', e)
            raise UKeyCertExpiredError() from e
        UKeyBackend._check_validity_period(validity.not_before, validity.not_after, 'SM2')

    @staticmethod
    def _verify_sm2_cert_chain(sm2_cert):
        """Verify the SM2 certificate against the configured CA certificate.

        Uses ``Sm2Certificate.verify_by_ca_certificate`` with ``SM2_DEFAULT_ID``.

        :raises UKeyCertChainError: no CA is configured, the CA cannot be
            loaded, verification errors out, or verification fails.
        """
        from common.utils.gmssl_python import SM2_DEFAULT_ID

        ca_cert_content = ukey_sdk_config.ca_cert_content
        if not ca_cert_content:
            logger.error('UKeyBackend: AUTH_UKEY_CA_CERT_CONTENT not configured')
            raise UKeyCertChainError()

        try:
            ca_cert = UKeyBackend._import_sm2_pem(ca_cert_content)
            ok = sm2_cert.verify_by_ca_certificate(ca_cert, SM2_DEFAULT_ID)
        except Exception as e:
            logger.error('UKeyBackend: SM2 cert chain verification error: %s', e)
            raise UKeyCertChainError() from e

        if not ok:
            logger.error('UKeyBackend: SM2 cert chain verification failed')
            raise UKeyCertChainError()

    @staticmethod
    def _verify_sm2_signature(sm2_key, signature, challenge):
        """Verify an SM2withSM3 signature over ``challenge`` via gmssl_python.

        :raises UKeySignatureError: the signature or challenge is malformed,
            verification errors out, or the signature does not match.
        """
        from common.utils.gmssl_python import Sm2Signature, DO_VERIFY, SM2_DEFAULT_ID

        try:
            # Decoding lives inside the try so malformed input yields the same
            # UKeySignatureError as a wrong signature.
            sig_bytes = UKeyBackend._decode_signature(signature)
            signed_data = UKeyBackend._challenge_as_bytes(challenge)
            verifier = Sm2Signature(sm2_key, SM2_DEFAULT_ID, DO_VERIFY)
            verifier.update(signed_data)
            ok = bool(verifier.verify(sig_bytes))
        except Exception as e:
            logger.error('UKeyBackend: SM2 signature verification error: %s', e)
            raise UKeySignatureError() from e

        if not ok:
            logger.error('UKeyBackend: SM2 signature mismatch')
            raise UKeySignatureError()

    # ── Part 3: RSA / other certificate verification flow ─────────────────────

    def _authenticate_other(self, cert_pem, username, signature, challenge, user):
        """RSA flow: load -> CA chain -> validity -> CN -> signature; return ``user``."""
        cert, pub_key = self._load_rsa_cert(cert_pem)
        self._verify_rsa_cert_chain(cert)
        self._verify_rsa_cert_validity(cert)
        self._verify_cert_cn(self._extract_rsa_cert_cn(cert), username)
        self._verify_rsa_signature(pub_key, signature, challenge)
        return user

    @staticmethod
    def _load_rsa_cert(cert_pem):
        """Load a PEM certificate and require an RSA public key.

        :return: ``(cert, pub_key)``.
        :raises UKeyCertNormalizationError: the PEM cannot be parsed.
        :raises UKeyCertUnsupportedAlgorithmError: the key is not RSA
            (EC keys are rejected explicitly).
        """
        from cryptography import x509
        from cryptography.hazmat.primitives.asymmetric import ec, rsa

        try:
            cert = x509.load_pem_x509_certificate(cert_pem.encode())
        except Exception as e:
            logger.error('UKeyBackend: failed to load certificate: %s', e)
            raise UKeyCertNormalizationError() from e

        pub_key = cert.public_key()
        if isinstance(pub_key, ec.EllipticCurvePublicKey):
            logger.error('UKeyBackend: ECDSA certificate verification is not supported')
            raise UKeyCertUnsupportedAlgorithmError()
        if not isinstance(pub_key, rsa.RSAPublicKey):
            logger.error('UKeyBackend: unsupported key type: %s', type(pub_key).__name__)
            raise UKeyCertUnsupportedAlgorithmError()
        return cert, pub_key

    @staticmethod
    def _verify_rsa_cert_validity(cert):
        """Check the certificate's not_valid_before_utc / not_valid_after_utc window."""
        UKeyBackend._check_validity_period(
            cert.not_valid_before_utc, cert.not_valid_after_utc, 'RSA'
        )

    @staticmethod
    def _verify_rsa_cert_chain(cert):
        """Check that ``cert`` was issued and signed by the configured CA.

        ``Certificate.verify_directly_issued_by`` checks that the leaf's issuer
        name equals the CA's subject and that the CA's public key verifies the
        leaf's signature, whatever the CA key type (not only RSA/PKCS#1 v1.5).
        Validity periods are checked separately.

        :raises UKeyCertChainError: no CA is configured, the CA cannot be
            loaded, or the leaf was not issued by it.
        """
        from cryptography import x509
        from cryptography.exceptions import InvalidSignature

        ca_cert_content = ukey_sdk_config.ca_cert_content
        if not ca_cert_content:
            logger.error('UKeyBackend: AUTH_UKEY_CA_CERT_CONTENT not configured')
            raise UKeyCertChainError()
        try:
            ca_cert = x509.load_pem_x509_certificate(ca_cert_content.encode())
            # Raises ValueError on issuer-name mismatch or unsupported
            # algorithm, TypeError on unsupported CA key type.
            cert.verify_directly_issued_by(ca_cert)
        except InvalidSignature as e:
            logger.error('UKeyBackend: RSA cert chain verification failed')
            raise UKeyCertChainError() from e
        except Exception as e:
            logger.error('UKeyBackend: RSA cert chain verification error: %s', e)
            raise UKeyCertChainError() from e

    @staticmethod
    def _extract_rsa_cert_cn(cert):
        """Return the first subject CN of ``cert``, or ``None`` if absent/unreadable."""
        from cryptography import x509

        try:
            return cert.subject.get_attributes_for_oid(x509.NameOID.COMMON_NAME)[0].value
        except (IndexError, ValueError):
            return None

    @staticmethod
    def _verify_rsa_signature(pub_key, signature, challenge):
        """Verify an RSA PKCS#1 v1.5 + SHA-256 signature over ``challenge``.

        :raises UKeySignatureError: the signature or challenge is malformed,
            or the signature does not match.
        """
        from cryptography.exceptions import InvalidSignature
        from cryptography.hazmat.primitives import hashes
        from cryptography.hazmat.primitives.asymmetric import padding

        try:
            # Decoding lives inside the try so malformed input yields the same
            # UKeySignatureError as a wrong signature.
            sig_bytes = UKeyBackend._decode_signature(signature)
            signed_data = UKeyBackend._challenge_as_bytes(challenge)
            pub_key.verify(sig_bytes, signed_data, padding.PKCS1v15(), hashes.SHA256())
        except InvalidSignature as e:
            logger.error('UKeyBackend: RSA signature mismatch')
            raise UKeySignatureError() from e
        except Exception as e:
            logger.error('UKeyBackend: RSA signature verification error: %s', e)
            raise UKeySignatureError() from e

    # ── Shared helpers ────────────────────────────────────────────────────────

    @staticmethod
    def _check_validity_period(not_before, not_after, label=''):
        """Raise unless now lies within [not_before, not_after] (SM2 and RSA).

        ``not_before`` / ``not_after`` may be naive (local time) or aware
        datetimes; ``now`` is taken in the matching form so the comparison
        is valid.

        :raises UKeyCertExpiredError: not yet valid, or already expired.
        """
        import datetime

        if not_before.tzinfo is not None:
            now = datetime.datetime.now(datetime.timezone.utc)
        else:
            now = datetime.datetime.now()

        if now < not_before:
            logger.error(
                'UKeyBackend: %s certificate not yet valid, valid from %s', label, not_before
            )
            raise UKeyCertExpiredError()
        if now > not_after:
            logger.error(
                'UKeyBackend: %s certificate has expired at %s', label, not_after
            )
            raise UKeyCertExpiredError()

    @staticmethod
    def _verify_cert_cn(cert_cn, username):
        """Raise ``UKeyCertCNMismatchError`` unless the cert CN equals ``username``."""
        if cert_cn != username:
            logger.error(
                'UKeyBackend: cert CN %r does not match username %r', cert_cn, username
            )
            raise UKeyCertCNMismatchError()

    @staticmethod
    def _challenge_as_bytes(challenge):
        """Return ``challenge`` as bytes (``str`` is UTF-8 encoded).

        :raises TypeError: ``challenge`` is neither ``str`` nor ``bytes``.
        :raises ValueError: ``challenge`` is empty; a signature over empty
            data would be replayable.
        """
        if isinstance(challenge, bytes):
            data = challenge
        elif isinstance(challenge, str):
            data = challenge.encode('utf-8')
        else:
            raise TypeError('challenge must be str or bytes, got %s' % type(challenge).__name__)
        if not data:
            raise ValueError('challenge is empty')
        return data

    @staticmethod
    def _load_cert_pem(cert_data):
        """Normalise raw certificate data to PEM text.

        :raises UKeyCertNormalizationError: the data is not a recognisable
            certificate encoding.
        """
        try:
            return UKeyBackend._normalize_cert_to_pem(cert_data)
        except Exception as e:
            logger.error('UKeyBackend: cert normalization failed: %s', e)
            raise UKeyCertNormalizationError() from e

    @staticmethod
    def _is_sm2_cert(cert_pem):
        """Return whether the PEM certificate uses SM2 (delegates to ``is_sm2_pem``)."""
        return is_sm2_pem(cert_pem)

    @staticmethod
    def _normalize_cert_to_pem(cert_data):
        """Convert a certificate to standard PEM text.

        Accepts PEM already carrying BEGIN/END lines (``str`` or ``bytes``),
        a bare base64 string (whitespace and line breaks ignored), or DER
        bytes.

        :raises ValueError: the input is empty or not valid base64.
        """
        if isinstance(cert_data, bytes):
            stripped = cert_data.strip()
            if stripped.startswith(b'-----BEGIN'):
                return stripped.decode('utf-8')
            # DER is binary: encode the raw bytes; stripping could remove
            # legitimate trailing 0x20/0x0a bytes of the signature.
            b64 = base64.b64encode(cert_data).decode('ascii')
        else:
            cert_data = cert_data.strip()
            if cert_data.startswith('-----BEGIN'):
                return cert_data
            b64 = ''.join(cert_data.split())
            base64.b64decode(b64, validate=True)  # reject non-base64 input

        if not b64:
            raise ValueError('certificate data is empty')

        # PEM bodies are wrapped at 64 characters per line.
        lines = [b64[i:i + 64] for i in range(0, len(b64), 64)]
        return (
            '-----BEGIN CERTIFICATE-----\n'
            + '\n'.join(lines)
            + '\n-----END CERTIFICATE-----\n'
        )

    @staticmethod
    def _decode_signature(signature):
        """Convert a signature value to bytes.

        ``bytes`` are returned unchanged. A string has all whitespace removed
        and is then decoded as hex, falling back to strict base64.

        :raises ValueError: the string is empty or neither hex nor base64.
        """
        if isinstance(signature, bytes):
            return signature

        sig = ''.join(signature.split())
        if not sig:
            raise ValueError('Cannot decode signature: empty')
        try:
            return bytes.fromhex(sig)
        except ValueError:
            pass
        try:
            # validate=True rejects characters outside the base64 alphabet
            # instead of silently dropping them.
            return base64.b64decode(sig, validate=True)
        except ValueError:  # binascii.Error is a ValueError subclass
            pass
        raise ValueError('Cannot decode signature: unknown format')
|