From d86f20cc86d6df1eb47c66cb1136d5018ba77a36 Mon Sep 17 00:00:00 2001 From: wujideng <643604012@qq.com> Date: Wed, 25 Dec 2024 15:44:11 +0800 Subject: [PATCH] feat: support selection of encryption and decryption methods, add gm sm4 crypto type --- docker/conf/config_sdb.yaml | 1 + go.mod | 1 + go.sum | 4 + pkg/filter/crypto/filter.go | 19 ++-- pkg/misc/crypto.go | 188 +++++++++++++++++++++++++++++++++++- pkg/misc/crypto_test.go | 81 ++++++++++++++++ 6 files changed, 284 insertions(+), 10 deletions(-) diff --git a/docker/conf/config_sdb.yaml b/docker/conf/config_sdb.yaml index d1739b60..7e124ed7 100644 --- a/docker/conf/config_sdb.yaml +++ b/docker/conf/config_sdb.yaml @@ -71,3 +71,4 @@ app_config: - table: departments columns: [ "dept_name" ] aeskey: 123456789abcdefg + cryptoType: aesgcm diff --git a/go.mod b/go.mod index 99967897..fef3dc4e 100644 --- a/go.mod +++ b/go.mod @@ -32,6 +32,7 @@ require ( github.com/spf13/cobra v1.1.1 github.com/stretchr/testify v1.7.1 github.com/testcontainers/testcontainers-go v0.13.0 + github.com/tjfoc/gmsm v1.4.1 github.com/uber-go/atomic v1.4.0 github.com/valyala/fasthttp v1.34.0 go.etcd.io/etcd/api/v3 v3.5.0-alpha.0 diff --git a/go.sum b/go.sum index 31d6b3c4..67aa2d26 100644 --- a/go.sum +++ b/go.sum @@ -1237,6 +1237,8 @@ github.com/tikv/client-go/v2 v2.0.0-alpha.0.20210831090540-391fcd842dc8/go.mod h github.com/tikv/pd v1.1.0-beta.0.20210323121136-78679e5e209d/go.mod h1:Jw9KG11C/23Rr7DW4XWQ7H5xOgGZo6DFL1OKAF4+Igw= github.com/tikv/pd v1.1.0-beta.0.20210818112400-0c5667766690 h1:qGn7fDqj7IZ5dozy7QVkoj+0bama92ruVGHqoCBg7W4= github.com/tikv/pd v1.1.0-beta.0.20210818112400-0c5667766690/go.mod h1:rammPjeZgpvfrQRPkijcx8tlxF1XM5+m6kRXrkDzCAA= +github.com/tjfoc/gmsm v1.4.1 h1:aMe1GlZb+0bLjn+cKTPEvvn9oUEBlJitaZiiBwsbgho= +github.com/tjfoc/gmsm v1.4.1/go.mod h1:j4INPkHWMrhJb38G+J6W4Tw0AbuN8Thu3PbdVYhVcTE= github.com/tklauser/go-sysconf v0.3.4/go.mod h1:Cl2c8ZRWfHD5IrfHo9VN+FX9kCFjIOyVklgXycLB6ek= github.com/tklauser/numcpus v0.2.1/go.mod h1:9aU+wOc6WjUIZEwWMP62PL/41d65P+iks1gBkr4QyP8= github.com/tmc/grpc-websocket-proxy v0.0.0-20170815181823-89b8d40f7ca8/go.mod h1:ncp9v5uamzpCO7NfCPTXjqaC+bZgJeR0sMTm6dMHP7U= @@ -1412,6 +1414,7 @@ golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPh golang.org/x/crypto v0.0.0-20200728195943-123391ffb6de/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20200820211705-5c72a883971a/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20201002170205-7f63de1d35b0/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/crypto v0.0.0-20201012173705-84dcc777aaee/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20201112155050-0c6587e931a9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= golang.org/x/crypto v0.0.0-20210817164053-32db794688a5/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= @@ -1505,6 +1508,7 @@ golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81R golang.org/x/net v0.0.0-20200707034311-ab3426394381/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= golang.org/x/net v0.0.0-20200822124328-c89045814202/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= golang.org/x/net v0.0.0-20201006153459-a7d1128ccaa0/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= +golang.org/x/net v0.0.0-20201010224723-4f7140c49acb/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.0.0-20201031054903-ff519b6c9102/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.0.0-20201110031124-69a78807bb2b/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= diff --git a/pkg/filter/crypto/filter.go b/pkg/filter/crypto/filter.go index 611a1649..a9e91159 100644 --- a/pkg/filter/crypto/filter.go +++ b/pkg/filter/crypto/filter.go @@ -67,9 +67,10 @@ type _filter struct { } type ColumnCrypto struct { - Table string - Columns []string - AesKey string + Table string + Columns []string + AesKey string + CryptoType misc.CryptoType } type columnIndex struct { @@ -304,7 +305,7 @@ func encryptInsertValues(columns []*columnIndex, config *ColumnCrypto, valueList if param, ok := arg.(*driver.ValueExpr); ok { value := param.GetBytes() if len(value) != 0 { - encoded, err := misc.AesEncryptGCM(value, []byte(config.AesKey), []byte(aesIV)) + encoded, err := misc.CryptoEncrypt(value, []byte(config.AesKey), []byte(aesIV), config.CryptoType) if err != nil { return errors.Wrapf(err, "Encryption of %s failed", column.Column) } @@ -326,7 +327,7 @@ func encryptUpdateValues(updateStmt *ast.UpdateStmt, config *ColumnCrypto) error if param, ok := arg.(*driver.ValueExpr); ok { value := param.GetBytes() if len(value) != 0 { - encoded, err := misc.AesEncryptGCM(value, []byte(config.AesKey), []byte(aesIV)) + encoded, err := misc.CryptoEncrypt(value, []byte(config.AesKey), []byte(aesIV), config.CryptoType) if err != nil { return errors.Wrapf(err, "Encryption of %s failed", column.Column) } @@ -345,14 +346,14 @@ func encryptBindVars(columns []*columnIndex, config *ColumnCrypto, args *map[str parameterID := fmt.Sprintf("v%d", column.Index+1) param := (*args)[parameterID] if arg, ok := param.(string); ok { - encoded, err := misc.AesEncryptGCM([]byte(arg), []byte(config.AesKey), []byte(aesIV)) + encoded, err := misc.CryptoEncrypt([]byte(arg), []byte(config.AesKey), []byte(aesIV), config.CryptoType) if err != nil { return errors.Errorf("Encryption of %s failed: %v", column.Column, err) } val := hex.EncodeToString(encoded) (*args)[parameterID] = val } else if arg, ok := param.([]byte); ok { - encoded, err := misc.AesEncryptGCM(arg, []byte(config.AesKey), []byte(aesIV)) + encoded, err := misc.CryptoEncrypt(arg, []byte(config.AesKey), []byte(aesIV), config.CryptoType) if err != nil { return errors.Errorf("Encryption of %s failed: %v", column.Column, err) } @@ -372,7 +373,7 @@ func decryptDecodedResult(decodedResult *mysql.Result, config *ColumnCrypto, col if protoValue != nil { if originalVal, ok := protoValue.Val.([]byte); ok { if n, err := hex.Decode(originalVal, originalVal); err == nil { - if decodedVal, err := misc.AesDecryptGCM(originalVal[:n], []byte(config.AesKey), []byte(aesIV)); err == nil { + if decodedVal, err := misc.CryptoDecrypt(originalVal[:n], []byte(config.AesKey), []byte(aesIV), config.CryptoType); err == nil { r.Values[column.Index].Val = decodedVal } } @@ -385,7 +386,7 @@ func decryptDecodedResult(decodedResult *mysql.Result, config *ColumnCrypto, col if protoValue != nil { if originalVal, ok := protoValue.Val.([]byte); ok { if n, err := hex.Decode(originalVal, originalVal); err == nil { - if decodedVal, err := misc.AesDecryptGCM(originalVal[:n], []byte(config.AesKey), []byte(aesIV)); err == nil { + if decodedVal, err := misc.CryptoDecrypt(originalVal[:n], []byte(config.AesKey), []byte(aesIV), config.CryptoType); err == nil { r.Values[column.Index].Val = decodedVal } } diff --git a/pkg/misc/crypto.go b/pkg/misc/crypto.go index f68a3bce..c721afb5 100644 --- a/pkg/misc/crypto.go +++ b/pkg/misc/crypto.go @@ -21,11 +21,112 @@ import ( "crypto/aes" "crypto/cipher" "crypto/rand" + "fmt" + "github.com/pkg/errors" + "github.com/tjfoc/gmsm/sm4" "io" +) - "github.com/pkg/errors" +type CryptoType int + +const ( + CryptoAESGCM CryptoType = iota + CryptoAESCBC + CryptoAESECB + CryptoAESCFB + CryptoSM4GCM + CryptoSM4ECB + CryptoSM4CBC + CryptoSM4CFB + CryptoSM4OFB ) +func (c *CryptoType) UnmarshalText(text []byte) error { + if c == nil { + return errors.New("can't unmarshal a nil *CryptoType") + } + if !c.unmarshalText(bytes.ToLower(text)) { + return fmt.Errorf("unrecognized protocol type: %q", text) + } + return nil +} + +func (c *CryptoType) unmarshalText(text []byte) bool { + switch string(text) { + case "aesgcm": + *c = CryptoAESGCM + case "aescbc": + *c = CryptoAESCBC + case "aesecb": + *c = CryptoAESECB + case "aescfb": + *c = CryptoAESCFB + case "sm4gcm": + *c = CryptoSM4GCM + case "sm4ecb": + *c = CryptoSM4ECB + case "sm4cbc": + *c = CryptoSM4CBC + case "sm4cfb": + *c = CryptoSM4CFB + case "sm4ofb": + *c = CryptoSM4OFB + default: + return false + } + return true +} + +func CryptoEncrypt(data []byte, key []byte, iv []byte, cryptoType CryptoType) ([]byte, error) { + switch cryptoType { + case CryptoAESGCM: + return AesEncryptGCM(data, key, iv) + case CryptoAESCBC: + return AesEncryptCBC(data, key, iv) + case CryptoAESECB: + return AesEncryptECB(data, key) + case CryptoAESCFB: + return AesEncryptCFB(data, key) + case CryptoSM4GCM: + return Sm4EncryptGCM(data, key, iv) + case CryptoSM4ECB: + return Sm4EncryptECB(data, key) + case CryptoSM4CBC: + return Sm4EncryptCBC(data, key, iv) + case CryptoSM4CFB: + return Sm4EncryptCFB(data, key, iv) + case CryptoSM4OFB: + return Sm4EncryptOFB(data, key, iv) + default: + return AesEncryptGCM(data, key, iv) + } +} + +func CryptoDecrypt(encrypted []byte, key []byte, iv []byte, cryptoType CryptoType) ([]byte, error) { + switch cryptoType { + case CryptoAESGCM: + return AesDecryptGCM(encrypted, key, iv) + case CryptoAESCBC: + return AesDecryptCBC(encrypted, key, iv) + case CryptoAESECB: + return AesDecryptECB(encrypted, key) + case CryptoAESCFB: + return AesDecryptCFB(encrypted, key) + case CryptoSM4GCM: + return Sm4DecryptGCM(encrypted, key, iv) + case CryptoSM4ECB: + return Sm4DecryptECB(encrypted, key) + case CryptoSM4CBC: + return Sm4DecryptCBC(encrypted, key, iv) + case CryptoSM4CFB: + return Sm4DecryptCFB(encrypted, key, iv) + case CryptoSM4OFB: + return Sm4DecryptOFB(encrypted, key, iv) + default: + return AesDecryptGCM(encrypted, key, iv) + } +} + func AesEncryptGCM(origData []byte, key []byte, iv []byte) (encrypted []byte, err error) { var block cipher.Block block, err = aes.NewCipher(key) @@ -178,3 +279,88 @@ func AesDecryptCFB(encrypted []byte, key []byte) (decrypted []byte, err error) { stream.XORKeyStream(encrypted, encrypted) return encrypted, err } + +func Sm4EncryptGCM(origData, key []byte, iv []byte) (encrypted []byte, err error) { + // Sm4GCM /** + // key: 对称加密密钥 + // IV: IV向量 + // in: + // A: 附加的可鉴别数据(ADD) + // mode: true - 加密; false - 解密验证 + // + // return: 密文C, 鉴别标签T, 错误 + encrypted, _, err = sm4.Sm4GCM(key, iv, origData, []byte{}, true) + if err != nil { + return nil, err + } + return encrypted, nil +} + +func Sm4DecryptGCM(encrypted, key []byte, iv []byte) (decrypted []byte, err error) { + decrypted, _, err = sm4.Sm4GCM(key, iv, encrypted, []byte{}, true) + if err != nil { + return nil, err + } + return decrypted, nil +} + +func Sm4EncryptECB(origData, key []byte) (encrypted []byte, err error) { + return sm4.Sm4Ecb(key, origData, true) +} + +func Sm4DecryptECB(encrypted, key []byte) (decrypted []byte, err error) { + return sm4.Sm4Ecb(key, encrypted, false) +} + +func Sm4EncryptCBC(origData, key, iv []byte) (encrypted []byte, err error) { + if err = sm4.SetIV(EnsureByteArrayLength16(iv)); err != nil { + return nil, err + } + return sm4.Sm4Cbc(key, origData, true) +} + +func Sm4DecryptCBC(encrypted, key, iv []byte) (decrypted []byte, err error) { + if err = sm4.SetIV(EnsureByteArrayLength16(iv)); err != nil { + return nil, err + } + return sm4.Sm4Cbc(key, encrypted, false) +} + +func Sm4EncryptCFB(origData, key, iv []byte) (encrypted []byte, err error) { + if err = sm4.SetIV(EnsureByteArrayLength16(iv)); err != nil { + return nil, err + } + return sm4.Sm4CFB(key, origData, true) +} + +func Sm4DecryptCFB(encrypted, key, iv []byte) (decrypted []byte, err error) { + if err = sm4.SetIV(EnsureByteArrayLength16(iv)); err != nil { + return nil, err + } + return sm4.Sm4CFB(key, encrypted, false) +} + +func Sm4EncryptOFB(origData, key, iv []byte) (encrypted []byte, err error) { + if err = sm4.SetIV(EnsureByteArrayLength16(iv)); err != nil { + return nil, err + } + return sm4.Sm4OFB(key, origData, true) +} + +func Sm4DecryptOFB(encrypted, key, iv []byte) (decrypted []byte, err error) { + if err = sm4.SetIV(EnsureByteArrayLength16(iv)); err != nil { + return nil, err + } + return sm4.Sm4OFB(key, encrypted, false) +} + +func EnsureByteArrayLength16(input []byte) []byte { + if len(input) == 16 { + return input + } + repeated := append(input, input...) + for len(repeated) < 16 { + repeated = append(repeated, input...) + } + return repeated[:16] +} diff --git a/pkg/misc/crypto_test.go b/pkg/misc/crypto_test.go index 39c87503..180910f2 100644 --- a/pkg/misc/crypto_test.go +++ b/pkg/misc/crypto_test.go @@ -86,3 +86,84 @@ func TestAesDecryptCFB(t *testing.T) { assert.Nil(t, err) assert.Equal(t, []byte("exampleplaintext"), decrypted) } + +func TestSm4EncryptGCM(t *testing.T) { + key, _ := hex.DecodeString("31323334353637383961626364656667") + plaintext := []byte("sunset4") + encrypted, err := Sm4EncryptGCM(plaintext, key, []byte("greatdbpack!")) + assert.Nil(t, err) + t.Logf("%x", encrypted) +} + +func TestSm4DecryptGCM(t *testing.T) { + key, _ := hex.DecodeString("31323334353637383961626364656667") + encrypted, _ := hex.DecodeString("4b3dd6cb3e0145") + decrypted, err := Sm4DecryptGCM(encrypted, key, []byte("greatdbpack!")) + assert.Nil(t, err) + t.Logf("%s", decrypted) + assert.Equal(t, []byte("sunset4"), decrypted) +} + +func TestSm4EncryptECB(t *testing.T) { + key, _ := hex.DecodeString("6368616e676520746869732070617373") + plaintext := []byte("exampleplaintext") + encrypted, err := Sm4EncryptECB(plaintext, key) + assert.Nil(t, err) + t.Logf("%x", encrypted) +} + +func TestSm4DecryptECB(t *testing.T) { + key, _ := hex.DecodeString("6368616e676520746869732070617373") + encrypted, _ := hex.DecodeString("1cadd74166afbe5f4bdaf6ebb49d4c46ce96714d2c0839338f995f4854c61b58") + decrypted, err := Sm4DecryptECB(encrypted, key) + assert.Nil(t, err) + assert.Equal(t, []byte("exampleplaintext"), decrypted) +} + +func TestSm4EncryptCBC(t *testing.T) { + key, _ := hex.DecodeString("6368616e676520746869732070617373") + plaintext := []byte("exampleplaintext") + encrypted, err := Sm4EncryptCBC(plaintext, key, []byte("impressivedbpack")) + assert.Nil(t, err) + t.Logf("%x", encrypted) +} + +func TestSm4DecryptCBC(t *testing.T) { + key, _ := hex.DecodeString("6368616e676520746869732070617373") + encrypted, _ := hex.DecodeString("2e88063cb32a13ce8fbfb60512c23d78d257734049682849d7c82a19f00e131a") + decrypted, err := Sm4DecryptCBC(encrypted, key, []byte("impressivedbpack")) + assert.Nil(t, err) + assert.Equal(t, []byte("exampleplaintext"), decrypted) +} + +func TestSm4EncryptCFB(t *testing.T) { + key, _ := hex.DecodeString("6368616e676520746869732070617373") + plaintext := []byte("exampleplaintext") + encrypted, err := Sm4EncryptCFB(plaintext, key, []byte("impressivedbpack")) + assert.Nil(t, err) + t.Logf("%x", encrypted) +} + +func TestSm4DecryptCFB(t *testing.T) { + key, _ := hex.DecodeString("6368616e676520746869732070617373") + encrypted, _ := hex.DecodeString("5ce63f4fac3744073aa91ac44bdc4ab44a19895a9fcb106947eae2cecfd99e62") + decrypted, err := Sm4DecryptCFB(encrypted, key, []byte("impressivedbpack")) + assert.Nil(t, err) + assert.Equal(t, []byte("exampleplaintext"), decrypted) +} + +func TestSm4EncryptOFB(t *testing.T) { + key, _ := hex.DecodeString("6368616e676520746869732070617373") + plaintext := []byte("exampleplaintext") + encrypted, err := Sm4EncryptOFB(plaintext, key, []byte("impressivedbpack")) + assert.Nil(t, err) + t.Logf("%x", encrypted) +} + +func TestSm4DecryptOFB(t *testing.T) { + key, _ := hex.DecodeString("6368616e676520746869732070617373") + encrypted, _ := hex.DecodeString("5ce63f4fac3744073aa91ac44bdc4ab4f83abab6ff8e4fd91da0740e339f9b2d") + decrypted, err := Sm4DecryptOFB(encrypted, key, []byte("impressivedbpack")) + assert.Nil(t, err) + assert.Equal(t, []byte("exampleplaintext"), decrypted) +}