Repository: ZZMarquis/gm Branch: master Commit: 950ad0926746 Files: 25 Total size: 147.4 KB Directory structure: gitextract_meawgwxy/ ├── .github/ │ └── FUNDING.yml ├── .gitignore ├── LICENSE ├── README.md ├── cryptobyte/ │ ├── README │ ├── asn1/ │ │ └── asn1.go │ ├── asn1.go │ ├── builder.go │ └── string.go ├── go.mod ├── sm2/ │ ├── cert/ │ │ ├── gmx509.go │ │ └── gmx509_test.go │ ├── keyexchange.go │ ├── keyexchange_test.go │ ├── sm2.go │ ├── sm2_loop_test.go │ └── sm2_test.go ├── sm3/ │ ├── sm3.go │ └── sm3_test.go ├── sm4/ │ ├── sm4.go │ └── sm4_test.go └── util/ ├── bigint.go ├── bigint_test.go ├── ec.go └── padding.go ================================================ FILE CONTENTS ================================================ ================================================ FILE: .github/FUNDING.yml ================================================ # These are supported funding model platforms github: # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2] patreon: # Replace with a single Patreon username open_collective: # Replace with a single Open Collective username ko_fi: # Replace with a single Ko-fi username tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry liberapay: # Replace with a single Liberapay username issuehunt: # Replace with a single IssueHunt username otechie: # Replace with a single Otechie username custom: https://weibo.com/5978668016/KDzUy5mOq ================================================ FILE: .gitignore ================================================ # Binaries for programs and plugins *.exe *.exe~ *.dll *.so *.dylib # Test binary, build with `go test -c` *.test # Output of the go coverage tool, specifically when used with LiteIDE *.out ================================================ FILE: LICENSE ================================================ Apache License Version 2.0, January 
2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. 
For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. 
Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the 
Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. 
Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. 
To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. Copyright [yyyy] [name of copyright owner] Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ================================================ FILE: README.md ================================================ # gm ================================================ FILE: cryptobyte/README ================================================ 这个包下的代码全部copy自SDK中的golang_org/x/crypto/cryptobyte包,这个包不对应用开放,所以只好全部copy过来 ================================================ FILE: cryptobyte/asn1/asn1.go ================================================ // Copyright 2017 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. // Package asn1 contains supporting types for parsing and building ASN.1 // messages with the cryptobyte package. package asn1 // Tag represents an ASN.1 identifier octet, consisting of a tag number // (indicating a type) and class (such as context-specific or constructed). // // Methods in the cryptobyte package only support the low-tag-number form, i.e. 
// a single identifier octet with bits 7-8 encoding the class and bits 1-6 // encoding the tag number. type Tag uint8 const ( classConstructed = 0x20 classContextSpecific = 0x80 ) // Constructed returns t with the constructed class bit set. func (t Tag) Constructed() Tag { return t | classConstructed } // ContextSpecific returns t with the context-specific class bit set. func (t Tag) ContextSpecific() Tag { return t | classContextSpecific } // The following is a list of standard tag and class combinations. const ( BOOLEAN = Tag(1) INTEGER = Tag(2) BIT_STRING = Tag(3) OCTET_STRING = Tag(4) NULL = Tag(5) OBJECT_IDENTIFIER = Tag(6) ENUM = Tag(10) UTF8String = Tag(12) SEQUENCE = Tag(16 | classConstructed) SET = Tag(17 | classConstructed) PrintableString = Tag(19) T61String = Tag(20) IA5String = Tag(22) UTCTime = Tag(23) GeneralizedTime = Tag(24) GeneralString = Tag(27) ) ================================================ FILE: cryptobyte/asn1.go ================================================ // Copyright 2017 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. package cryptobyte import ( encoding_asn1 "encoding/asn1" "fmt" "math/big" "reflect" "time" "github.com/ZZMarquis/gm/cryptobyte/asn1" ) // This file contains ASN.1-related methods for String and Builder. // Builder // AddASN1Int64 appends a DER-encoded ASN.1 INTEGER. func (b *Builder) AddASN1Int64(v int64) { b.addASN1Signed(asn1.INTEGER, v) } // AddASN1Enum appends a DER-encoded ASN.1 ENUMERATION. func (b *Builder) AddASN1Enum(v int64) { b.addASN1Signed(asn1.ENUM, v) } func (b *Builder) addASN1Signed(tag asn1.Tag, v int64) { b.AddASN1(tag, func(c *Builder) { length := 1 for i := v; i >= 0x80 || i < -0x80; i >>= 8 { length++ } for ; length > 0; length-- { i := v >> uint((length-1)*8) & 0xff c.AddUint8(uint8(i)) } }) } // AddASN1Uint64 appends a DER-encoded ASN.1 INTEGER. 
func (b *Builder) AddASN1Uint64(v uint64) { b.AddASN1(asn1.INTEGER, func(c *Builder) { length := 1 for i := v; i >= 0x80; i >>= 8 { length++ } for ; length > 0; length-- { i := v >> uint((length-1)*8) & 0xff c.AddUint8(uint8(i)) } }) } // AddASN1BigInt appends a DER-encoded ASN.1 INTEGER. func (b *Builder) AddASN1BigInt(n *big.Int) { if b.err != nil { return } b.AddASN1(asn1.INTEGER, func(c *Builder) { if n.Sign() < 0 { // A negative number has to be converted to two's-complement form. So we // invert and subtract 1. If the most-significant-bit isn't set then // we'll need to pad the beginning with 0xff in order to keep the number // negative. nMinus1 := new(big.Int).Neg(n) nMinus1.Sub(nMinus1, bigOne) bytes := nMinus1.Bytes() for i := range bytes { bytes[i] ^= 0xff } if bytes[0]&0x80 == 0 { c.add(0xff) } c.add(bytes...) } else if n.Sign() == 0 { c.add(0) } else { bytes := n.Bytes() if bytes[0]&0x80 != 0 { c.add(0) } c.add(bytes...) } }) } // AddASN1OctetString appends a DER-encoded ASN.1 OCTET STRING. func (b *Builder) AddASN1OctetString(bytes []byte) { b.AddASN1(asn1.OCTET_STRING, func(c *Builder) { c.AddBytes(bytes) }) } const generalizedTimeFormatStr = "20060102150405Z0700" // AddASN1GeneralizedTime appends a DER-encoded ASN.1 GENERALIZEDTIME. func (b *Builder) AddASN1GeneralizedTime(t time.Time) { if t.Year() < 0 || t.Year() > 9999 { b.err = fmt.Errorf("cryptobyte: cannot represent %v as a GeneralizedTime", t) return } b.AddASN1(asn1.GeneralizedTime, func(c *Builder) { c.AddBytes([]byte(t.Format(generalizedTimeFormatStr))) }) } // AddASN1BitString appends a DER-encoded ASN.1 BIT STRING. This does not // support BIT STRINGs that are not a whole number of bytes. 
func (b *Builder) AddASN1BitString(data []byte) { b.AddASN1(asn1.BIT_STRING, func(b *Builder) { b.AddUint8(0) b.AddBytes(data) }) } func (b *Builder) addBase128Int(n int64) { var length int if n == 0 { length = 1 } else { for i := n; i > 0; i >>= 7 { length++ } } for i := length - 1; i >= 0; i-- { o := byte(n >> uint(i*7)) o &= 0x7f if i != 0 { o |= 0x80 } b.add(o) } } func isValidOID(oid encoding_asn1.ObjectIdentifier) bool { if len(oid) < 2 { return false } if oid[0] > 2 || (oid[0] <= 1 && oid[1] >= 40) { return false } for _, v := range oid { if v < 0 { return false } } return true } func (b *Builder) AddASN1ObjectIdentifier(oid encoding_asn1.ObjectIdentifier) { b.AddASN1(asn1.OBJECT_IDENTIFIER, func(b *Builder) { if !isValidOID(oid) { b.err = fmt.Errorf("cryptobyte: invalid OID: %v", oid) return } b.addBase128Int(int64(oid[0])*40 + int64(oid[1])) for _, v := range oid[2:] { b.addBase128Int(int64(v)) } }) } func (b *Builder) AddASN1Boolean(v bool) { b.AddASN1(asn1.BOOLEAN, func(b *Builder) { if v { b.AddUint8(0xff) } else { b.AddUint8(0) } }) } func (b *Builder) AddASN1NULL() { b.add(uint8(asn1.NULL), 0) } // MarshalASN1 calls encoding_asn1.Marshal on its input and appends the result if // successful or records an error if one occurred. func (b *Builder) MarshalASN1(v interface{}) { // NOTE(martinkr): This is somewhat of a hack to allow propagation of // encoding_asn1.Marshal errors into Builder.err. N.B. if you call MarshalASN1 with a // value embedded into a struct, its tag information is lost. if b.err != nil { return } bytes, err := encoding_asn1.Marshal(v) if err != nil { b.err = err return } b.AddBytes(bytes) } // AddASN1 appends an ASN.1 object. The object is prefixed with the given tag. // Tags greater than 30 are not supported and result in an error (i.e. // low-tag-number form only). The child builder passed to the // BuilderContinuation can be used to build the content of the ASN.1 object. 
func (b *Builder) AddASN1(tag asn1.Tag, f BuilderContinuation) { if b.err != nil { return } // Identifiers with the low five bits set indicate high-tag-number format // (two or more octets), which we don't support. if tag&0x1f == 0x1f { b.err = fmt.Errorf("cryptobyte: high-tag number identifier octects not supported: 0x%x", tag) return } b.AddUint8(uint8(tag)) b.addLengthPrefixed(1, true, f) } // String func (s *String) ReadASN1Boolean(out *bool) bool { var bytes String if !s.ReadASN1(&bytes, asn1.INTEGER) || len(bytes) != 1 { return false } switch bytes[0] { case 0: *out = false case 0xff: *out = true default: return false } return true } var bigIntType = reflect.TypeOf((*big.Int)(nil)).Elem() // ReadASN1Integer decodes an ASN.1 INTEGER into out and advances. If out does // not point to an integer or to a big.Int, it panics. It returns true on // success and false on error. func (s *String) ReadASN1Integer(out interface{}) bool { if reflect.TypeOf(out).Kind() != reflect.Ptr { panic("out is not a pointer") } switch reflect.ValueOf(out).Elem().Kind() { case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: var i int64 if !s.readASN1Int64(&i) || reflect.ValueOf(out).Elem().OverflowInt(i) { return false } reflect.ValueOf(out).Elem().SetInt(i) return true case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: var u uint64 if !s.readASN1Uint64(&u) || reflect.ValueOf(out).Elem().OverflowUint(u) { return false } reflect.ValueOf(out).Elem().SetUint(u) return true case reflect.Struct: if reflect.TypeOf(out).Elem() == bigIntType { return s.readASN1BigInt(out.(*big.Int)) } } panic("out does not point to an integer type") } func checkASN1Integer(bytes []byte) bool { if len(bytes) == 0 { // An INTEGER is encoded with at least one octet. return false } if len(bytes) == 1 { return true } if bytes[0] == 0 && bytes[1]&0x80 == 0 || bytes[0] == 0xff && bytes[1]&0x80 == 0x80 { // Value is not minimally encoded. 
return false } return true } var bigOne = big.NewInt(1) func (s *String) readASN1BigInt(out *big.Int) bool { var bytes String if !s.ReadASN1(&bytes, asn1.INTEGER) || !checkASN1Integer(bytes) { return false } if bytes[0]&0x80 == 0x80 { // Negative number. neg := make([]byte, len(bytes)) for i, b := range bytes { neg[i] = ^b } out.SetBytes(neg) out.Add(out, bigOne) out.Neg(out) } else { out.SetBytes(bytes) } return true } func (s *String) readASN1Int64(out *int64) bool { var bytes String if !s.ReadASN1(&bytes, asn1.INTEGER) || !checkASN1Integer(bytes) || !asn1Signed(out, bytes) { return false } return true } func asn1Signed(out *int64, n []byte) bool { length := len(n) if length > 8 { return false } for i := 0; i < length; i++ { *out <<= 8 *out |= int64(n[i]) } // Shift up and down in order to sign extend the result. *out <<= 64 - uint8(length)*8 *out >>= 64 - uint8(length)*8 return true } func (s *String) readASN1Uint64(out *uint64) bool { var bytes String if !s.ReadASN1(&bytes, asn1.INTEGER) || !checkASN1Integer(bytes) || !asn1Unsigned(out, bytes) { return false } return true } func asn1Unsigned(out *uint64, n []byte) bool { length := len(n) if length > 9 || length == 9 && n[0] != 0 { // Too large for uint64. return false } if n[0]&0x80 != 0 { // Negative number. return false } for i := 0; i < length; i++ { *out <<= 8 *out |= uint64(n[i]) } return true } // ReadASN1Enum decodes an ASN.1 ENUMERATION into out and advances. It returns // true on success and false on error. 
func (s *String) ReadASN1Enum(out *int) bool { var bytes String var i int64 if !s.ReadASN1(&bytes, asn1.ENUM) || !checkASN1Integer(bytes) || !asn1Signed(&i, bytes) { return false } if int64(int(i)) != i { return false } *out = int(i) return true } func (s *String) readBase128Int(out *int) bool { ret := 0 for i := 0; len(*s) > 0; i++ { if i == 4 { return false } ret <<= 7 b := s.read(1)[0] ret |= int(b & 0x7f) if b&0x80 == 0 { *out = ret return true } } return false // truncated } // ReadASN1ObjectIdentifier decodes an ASN.1 OBJECT IDENTIFIER into out and // advances. It returns true on success and false on error. func (s *String) ReadASN1ObjectIdentifier(out *encoding_asn1.ObjectIdentifier) bool { var bytes String if !s.ReadASN1(&bytes, asn1.OBJECT_IDENTIFIER) || len(bytes) == 0 { return false } // In the worst case, we get two elements from the first byte (which is // encoded differently) and then every varint is a single byte long. components := make([]int, len(bytes)+1) // The first varint is 40*value1 + value2: // According to this packing, value1 can take the values 0, 1 and 2 only. // When value1 = 0 or value1 = 1, then value2 is <= 39. When value1 = 2, // then there are no restrictions on value2. var v int if !bytes.readBase128Int(&v) { return false } if v < 80 { components[0] = v / 40 components[1] = v % 40 } else { components[0] = 2 components[1] = v - 80 } i := 2 for ; len(bytes) > 0; i++ { if !bytes.readBase128Int(&v) { return false } components[i] = v } *out = components[:i] return true } // ReadASN1GeneralizedTime decodes an ASN.1 GENERALIZEDTIME into out and // advances. It returns true on success and false on error. 
func (s *String) ReadASN1GeneralizedTime(out *time.Time) bool {
	var bytes String
	if !s.ReadASN1(&bytes, asn1.GeneralizedTime) {
		return false
	}
	t := string(bytes)
	res, err := time.Parse(generalizedTimeFormatStr, t)
	if err != nil {
		return false
	}
	// Reject any input that does not round-trip exactly through
	// generalizedTimeFormatStr, so only one spelling of a given time is accepted.
	if serialized := res.Format(generalizedTimeFormatStr); serialized != t {
		return false
	}
	*out = res
	return true
}

// ReadASN1BitString decodes an ASN.1 BIT STRING into out and advances. It
// returns true on success and false on error.
func (s *String) ReadASN1BitString(out *encoding_asn1.BitString) bool {
	var bytes String
	if !s.ReadASN1(&bytes, asn1.BIT_STRING) || len(bytes) == 0 {
		return false
	}

	// The first content octet is the count of unused (padding) bits.
	paddingBits := uint8(bytes[0])
	bytes = bytes[1:]
	// NOTE(review): this extraction is damaged on the next line. Everything
	// between "(1<" and "> 4" was lost, presumably stripped as if it were an
	// HTML tag. The loss presumably takes with it the end of ReadASN1BitString
	// and the start of the DER element reader whose tail follows (readASN1 in
	// upstream cryptobyte). That missing start is where lenLen, len32, headerLen
	// and length would be declared, and other functions may also have been lost
	// in between. The fragment is kept verbatim; restore it from the repository
	// before relying on this code.
	if paddingBits > 7 || len(bytes) == 0 && paddingBits != 0 || len(bytes) > 0 && bytes[len(bytes)-1]&(1< 4 || len(*s) < int(2+lenLen) {
			return false
		}
		// Parse the long-form length octets that follow the tag and the
		// length-of-length octet.
		lenBytes := String((*s)[2 : 2+lenLen])
		if !lenBytes.readUnsigned(&len32, int(lenLen)) {
			return false
		}

		// ITU-T X.690 section 10.1 (DER length forms) requires encoding the length
		// with the minimum number of octets.
		if len32 < 128 {
			// Length should have used short-form encoding.
			return false
		}
		if len32>>((lenLen-1)*8) == 0 {
			// Leading octet is 0. Length should have been at least one byte shorter.
			return false
		}

		headerLen = 2 + uint32(lenLen)
		if headerLen+len32 < len32 {
			// Overflow.
			return false
		}
		length = headerLen + len32
	}

	// Also reject lengths that do not survive conversion to int.
	if uint32(int(length)) != length || !s.ReadBytes((*[]byte)(out), int(length)) {
		return false
	}
	if skipHeader && !out.Skip(int(headerLen)) {
		panic("cryptobyte: internal error")
	}

	return true
}

================================================
FILE: cryptobyte/builder.go
================================================
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package cryptobyte

import (
	"errors"
	"fmt"
)

// A Builder builds byte strings from fixed-length and length-prefixed values.
// Builders either allocate space as needed, or are ‘fixed’, which means that // they write into a given buffer and produce an error if it's exhausted. // // The zero value is a usable Builder that allocates space as needed. // // Simple values are marshaled and appended to a Builder using methods on the // Builder. Length-prefixed values are marshaled by providing a // BuilderContinuation, which is a function that writes the inner contents of // the value to a given Builder. See the documentation for BuilderContinuation // for details. type Builder struct { err error result []byte fixedSize bool child *Builder offset int pendingLenLen int pendingIsASN1 bool inContinuation *bool } // NewBuilder creates a Builder that appends its output to the given buffer. // Like append(), the slice will be reallocated if its capacity is exceeded. // Use Bytes to get the final buffer. func NewBuilder(buffer []byte) *Builder { return &Builder{ result: buffer, } } // NewFixedBuilder creates a Builder that appends its output into the given // buffer. This builder does not reallocate the output buffer. Writes that // would exceed the buffer's capacity are treated as an error. func NewFixedBuilder(buffer []byte) *Builder { return &Builder{ result: buffer, fixedSize: true, } } // Bytes returns the bytes written by the builder or an error if one has // occurred during during building. func (b *Builder) Bytes() ([]byte, error) { if b.err != nil { return nil, b.err } return b.result[b.offset:], nil } // BytesOrPanic returns the bytes written by the builder or panics if an error // has occurred during building. func (b *Builder) BytesOrPanic() []byte { if b.err != nil { panic(b.err) } return b.result[b.offset:] } // AddUint8 appends an 8-bit value to the byte string. func (b *Builder) AddUint8(v uint8) { b.add(byte(v)) } // AddUint16 appends a big-endian, 16-bit value to the byte string. 
func (b *Builder) AddUint16(v uint16) {
	b.add(byte(v>>8), byte(v))
}

// AddUint24 appends a big-endian, 24-bit value to the byte string. The highest
// byte of the 32-bit input value is silently truncated.
func (b *Builder) AddUint24(v uint32) {
	b.add(byte(v>>16), byte(v>>8), byte(v))
}

// AddUint32 appends a big-endian, 32-bit value to the byte string.
func (b *Builder) AddUint32(v uint32) {
	b.add(byte(v>>24), byte(v>>16), byte(v>>8), byte(v))
}

// AddBytes appends a sequence of bytes to the byte string.
func (b *Builder) AddBytes(v []byte) {
	b.add(v...)
}

// BuilderContinuation is continuation-passing interface for building
// length-prefixed byte sequences. Builder methods for length-prefixed
// sequences (AddUint8LengthPrefixed etc) will invoke the BuilderContinuation
// supplied to them. The child builder passed to the continuation can be used
// to build the content of the length-prefixed sequence. For example:
//
//	parent := cryptobyte.NewBuilder(nil)
//	parent.AddUint8LengthPrefixed(func(child *cryptobyte.Builder) {
//		child.AddUint8(42)
//		child.AddUint8LengthPrefixed(func(grandchild *cryptobyte.Builder) {
//			grandchild.AddUint8(5)
//		})
//	})
//
// It is an error to write more bytes to the child than allowed by the reserved
// length prefix. After the continuation returns, the child must be considered
// invalid, i.e. users must not store any copies or references of the child
// that outlive the continuation.
//
// If the continuation panics with a value of type BuildError then the inner
// error will be returned as the error from Bytes. If the child panics
// otherwise then Bytes will repanic with the same value.
type BuilderContinuation func(child *Builder)

// BuildError wraps an error. If a BuilderContinuation panics with this value,
// the panic will be recovered and the inner error will be returned from
// Builder.Bytes.
type BuildError struct {
	Err error
}

// AddUint8LengthPrefixed adds an 8-bit length-prefixed byte sequence.
func (b *Builder) AddUint8LengthPrefixed(f BuilderContinuation) { b.addLengthPrefixed(1, false, f) } // AddUint16LengthPrefixed adds a big-endian, 16-bit length-prefixed byte sequence. func (b *Builder) AddUint16LengthPrefixed(f BuilderContinuation) { b.addLengthPrefixed(2, false, f) } // AddUint24LengthPrefixed adds a big-endian, 24-bit length-prefixed byte sequence. func (b *Builder) AddUint24LengthPrefixed(f BuilderContinuation) { b.addLengthPrefixed(3, false, f) } // AddUint32LengthPrefixed adds a big-endian, 32-bit length-prefixed byte sequence. func (b *Builder) AddUint32LengthPrefixed(f BuilderContinuation) { b.addLengthPrefixed(4, false, f) } func (b *Builder) callContinuation(f BuilderContinuation, arg *Builder) { if !*b.inContinuation { *b.inContinuation = true defer func() { *b.inContinuation = false r := recover() if r == nil { return } if buildError, ok := r.(BuildError); ok { b.err = buildError.Err } else { panic(r) } }() } f(arg) } func (b *Builder) addLengthPrefixed(lenLen int, isASN1 bool, f BuilderContinuation) { // Subsequent writes can be ignored if the builder has encountered an error. if b.err != nil { return } offset := len(b.result) b.add(make([]byte, lenLen)...) if b.inContinuation == nil { b.inContinuation = new(bool) } b.child = &Builder{ result: b.result, fixedSize: b.fixedSize, offset: offset, pendingLenLen: lenLen, pendingIsASN1: isASN1, inContinuation: b.inContinuation, } b.callContinuation(f, b.child) b.flushChild() if b.child != nil { panic("cryptobyte: internal error") } } func (b *Builder) flushChild() { if b.child == nil { return } b.child.flushChild() child := b.child b.child = nil if child.err != nil { b.err = child.err return } length := len(child.result) - child.pendingLenLen - child.offset if length < 0 { panic("cryptobyte: internal error") // result unexpectedly shrunk } if child.pendingIsASN1 { // For ASN.1, we reserved a single byte for the length. 
If that turned out // to be incorrect, we have to move the contents along in order to make // space. if child.pendingLenLen != 1 { panic("cryptobyte: internal error") } var lenLen, lenByte uint8 if int64(length) > 0xfffffffe { b.err = errors.New("pending ASN.1 child too long") return } else if length > 0xffffff { lenLen = 5 lenByte = 0x80 | 4 } else if length > 0xffff { lenLen = 4 lenByte = 0x80 | 3 } else if length > 0xff { lenLen = 3 lenByte = 0x80 | 2 } else if length > 0x7f { lenLen = 2 lenByte = 0x80 | 1 } else { lenLen = 1 lenByte = uint8(length) length = 0 } // Insert the initial length byte, make space for successive length bytes, // and adjust the offset. child.result[child.offset] = lenByte extraBytes := int(lenLen - 1) if extraBytes != 0 { child.add(make([]byte, extraBytes)...) childStart := child.offset + child.pendingLenLen copy(child.result[childStart+extraBytes:], child.result[childStart:]) } child.offset++ child.pendingLenLen = extraBytes } l := length for i := child.pendingLenLen - 1; i >= 0; i-- { child.result[child.offset+i] = uint8(l) l >>= 8 } if l != 0 { b.err = fmt.Errorf("cryptobyte: pending child length %d exceeds %d-byte length prefix", length, child.pendingLenLen) return } if !b.fixedSize { b.result = child.result // In case child reallocated result. } } func (b *Builder) add(bytes ...byte) { if b.err != nil { return } if b.child != nil { panic("attempted write while child is pending") } if len(b.result)+len(bytes) < len(bytes) { b.err = errors.New("cryptobyte: length overflow") } if b.fixedSize && len(b.result)+len(bytes) > cap(b.result) { b.err = errors.New("cryptobyte: Builder is exceeding its fixed-size buffer") return } b.result = append(b.result, bytes...) } // A MarshalingValue marshals itself into a Builder. type MarshalingValue interface { // Marshal is called by Builder.AddValue. It receives a pointer to a builder // to marshal itself into. 
It may return an error that occurred during // marshaling, such as unset or invalid values. Marshal(b *Builder) error } // AddValue calls Marshal on v, passing a pointer to the builder to append to. // If Marshal returns an error, it is set on the Builder so that subsequent // appends don't have an effect. func (b *Builder) AddValue(v MarshalingValue) { err := v.Marshal(b) if err != nil { b.err = err } } ================================================ FILE: cryptobyte/string.go ================================================ // Copyright 2017 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. // Package cryptobyte contains types that help with parsing and constructing // length-prefixed, binary messages, including ASN.1 DER. (The asn1 subpackage // contains useful ASN.1 constants.) // // The String type is for parsing. It wraps a []byte slice and provides helper // functions for consuming structures, value by value. // // The Builder type is for constructing messages. It providers helper functions // for appending values and also for appending length-prefixed submessages – // without having to worry about calculating the length prefix ahead of time. // // See the documentation and examples for the Builder and String types to get // started. package cryptobyte // String represents a string of bytes. It provides methods for parsing // fixed-length and length-prefixed values from it. type String []byte // read advances a String by n bytes and returns them. If less than n bytes // remain, it returns nil. func (s *String) read(n int) []byte { if len(*s) < n { return nil } v := (*s)[:n] *s = (*s)[n:] return v } // Skip advances the String by n byte and reports whether it was successful. func (s *String) Skip(n int) bool { return s.read(n) != nil } // ReadUint8 decodes an 8-bit value into out and advances over it. It // returns true on success and false on error. 
func (s *String) ReadUint8(out *uint8) bool { v := s.read(1) if v == nil { return false } *out = uint8(v[0]) return true } // ReadUint16 decodes a big-endian, 16-bit value into out and advances over it. // It returns true on success and false on error. func (s *String) ReadUint16(out *uint16) bool { v := s.read(2) if v == nil { return false } *out = uint16(v[0])<<8 | uint16(v[1]) return true } // ReadUint24 decodes a big-endian, 24-bit value into out and advances over it. // It returns true on success and false on error. func (s *String) ReadUint24(out *uint32) bool { v := s.read(3) if v == nil { return false } *out = uint32(v[0])<<16 | uint32(v[1])<<8 | uint32(v[2]) return true } // ReadUint32 decodes a big-endian, 32-bit value into out and advances over it. // It returns true on success and false on error. func (s *String) ReadUint32(out *uint32) bool { v := s.read(4) if v == nil { return false } *out = uint32(v[0])<<24 | uint32(v[1])<<16 | uint32(v[2])<<8 | uint32(v[3]) return true } func (s *String) readUnsigned(out *uint32, length int) bool { v := s.read(length) if v == nil { return false } var result uint32 for i := 0; i < length; i++ { result <<= 8 result |= uint32(v[i]) } *out = result return true } func (s *String) readLengthPrefixed(lenLen int, outChild *String) bool { lenBytes := s.read(lenLen) if lenBytes == nil { return false } var length uint32 for _, b := range lenBytes { length = length << 8 length = length | uint32(b) } if int(length) < 0 { // This currently cannot overflow because we read uint24 at most, but check // anyway in case that changes in the future. return false } v := s.read(int(length)) if v == nil { return false } *outChild = v return true } // ReadUint8LengthPrefixed reads the content of an 8-bit length-prefixed value // into out and advances over it. It returns true on success and false on // error. 
func (s *String) ReadUint8LengthPrefixed(out *String) bool {
	const prefixBytes = 1
	return s.readLengthPrefixed(prefixBytes, out)
}

// ReadUint16LengthPrefixed consumes a big-endian two-byte length followed by
// that many bytes and stores those bytes in out. It reports whether it
// succeeded.
func (s *String) ReadUint16LengthPrefixed(out *String) bool {
	const prefixBytes = 2
	return s.readLengthPrefixed(prefixBytes, out)
}

// ReadUint24LengthPrefixed consumes a big-endian three-byte length followed by
// that many bytes and stores those bytes in out. It reports whether it
// succeeded.
func (s *String) ReadUint24LengthPrefixed(out *String) bool {
	const prefixBytes = 3
	return s.readLengthPrefixed(prefixBytes, out)
}

// ReadBytes consumes n bytes and points out at them (no copy is made). It
// returns true on success and false on error, in which case out is unchanged.
func (s *String) ReadBytes(out *[]byte, n int) bool {
	if chunk := s.read(n); chunk != nil {
		*out = chunk
		return true
	}
	return false
}

// CopyBytes consumes len(out) bytes and copies them into out. It returns true
// on success and false on error (not enough input).
func (s *String) CopyBytes(out []byte) bool {
	src := s.read(len(out))
	if src == nil {
		return false
	}
	copied := copy(out, src)
	return copied == len(out)
}

// Empty reports whether s has no bytes left.
func (s String) Empty() bool { return len(s) == 0 } ================================================ FILE: go.mod ================================================ module github.com/ZZMarquis/gm go 1.12 ================================================ FILE: sm2/cert/gmx509.go ================================================ package cert import ( "bytes" "crypto/elliptic" "crypto/x509" "crypto/x509/pkix" "encoding/asn1" "errors" "fmt" "math/big" "net" "net/url" "strconv" "strings" "time" "unicode/utf8" "github.com/ZZMarquis/gm/cryptobyte" cryptobyte_asn1 "github.com/ZZMarquis/gm/cryptobyte/asn1" "github.com/ZZMarquis/gm/sm2" ) var ( oidSM2P256V1 = asn1.ObjectIdentifier{1, 2, 156, 10197, 1, 301} oidSignatureSM3WithSM2 = asn1.ObjectIdentifier{1, 2, 156, 10197, 1, 501} oidPublicKeyECDSA = asn1.ObjectIdentifier{1, 2, 840, 10045, 2, 1} oidExtensionRequest = asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 9, 14} oidExtensionSubjectKeyId = []int{2, 5, 29, 14} oidExtensionKeyUsage = []int{2, 5, 29, 15} oidExtensionExtendedKeyUsage = []int{2, 5, 29, 37} oidExtensionAuthorityKeyId = []int{2, 5, 29, 35} oidExtensionBasicConstraints = []int{2, 5, 29, 19} oidExtensionSubjectAltName = []int{2, 5, 29, 17} oidExtensionCertificatePolicies = []int{2, 5, 29, 32} oidExtensionNameConstraints = []int{2, 5, 29, 30} oidExtensionCRLDistributionPoints = []int{2, 5, 29, 31} oidExtensionAuthorityInfoAccess = []int{1, 3, 6, 1, 5, 5, 7, 1, 1} oidAuthorityInfoAccessOcsp = asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 48, 1} oidAuthorityInfoAccessIssuers = asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 48, 2} ) const ( nameTypeEmail = 1 nameTypeDNS = 2 nameTypeURI = 6 nameTypeIP = 7 ) type publicKeyInfo struct { Raw asn1.RawContent Algorithm pkix.AlgorithmIdentifier PublicKey asn1.BitString } type tbsCertificateRequest struct { Raw asn1.RawContent Version int Subject asn1.RawValue PublicKey publicKeyInfo RawAttributes []asn1.RawValue `asn1:"tag:0"` } type certificateRequest struct { Raw asn1.RawContent TBSCSR 
tbsCertificateRequest SignatureAlgorithm pkix.AlgorithmIdentifier SignatureValue asn1.BitString } func CreateCertificateRequest(template *x509.CertificateRequest, pub *sm2.PublicKey, pri *sm2.PrivateKey, userId []byte) (csr []byte, err error) { var publicKeyBytes []byte var publicKeyAlgorithm pkix.AlgorithmIdentifier publicKeyBytes, publicKeyAlgorithm, err = marshalPublicKey(pub) if err != nil { return nil, err } var extensions []pkix.Extension if (len(template.DNSNames) > 0 || len(template.EmailAddresses) > 0 || len(template.IPAddresses) > 0 || len(template.URIs) > 0) && !oidInExtensions(oidExtensionSubjectAltName, template.ExtraExtensions) { sanBytes, err := marshalSANs(template.DNSNames, template.EmailAddresses, template.IPAddresses, template.URIs) if err != nil { return nil, err } extensions = append(extensions, pkix.Extension{ Id: oidExtensionSubjectAltName, Value: sanBytes, }) } extensions = append(extensions, template.ExtraExtensions...) var attributes []pkix.AttributeTypeAndValueSET attributes = append(attributes, template.Attributes...) if len(extensions) > 0 { // specifiedExtensions contains all the extensions that we // found specified via template.Attributes. specifiedExtensions := make(map[string]bool) for _, atvSet := range template.Attributes { if !atvSet.Type.Equal(oidExtensionRequest) { continue } for _, atvs := range atvSet.Value { for _, atv := range atvs { specifiedExtensions[atv.Type.String()] = true } } } atvs := make([]pkix.AttributeTypeAndValue, 0, len(extensions)) for _, e := range extensions { if specifiedExtensions[e.Id.String()] { // Attributes already contained a value for // this extension and it takes priority. continue } atvs = append(atvs, pkix.AttributeTypeAndValue{ // There is no place for the critical flag in a CSR. Type: e.Id, Value: e.Value, }) } // Append the extensions to an existing attribute if possible. 
appended := false for _, atvSet := range attributes { if !atvSet.Type.Equal(oidExtensionRequest) || len(atvSet.Value) == 0 { continue } atvSet.Value[0] = append(atvSet.Value[0], atvs...) appended = true break } // Otherwise, add a new attribute for the extensions. if !appended { attributes = append(attributes, pkix.AttributeTypeAndValueSET{ Type: oidExtensionRequest, Value: [][]pkix.AttributeTypeAndValue{ atvs, }, }) } } asn1Subject := template.RawSubject if len(asn1Subject) == 0 { asn1Subject, err = asn1.Marshal(template.Subject.ToRDNSequence()) if err != nil { return } } rawAttributes, err := newRawAttributes(attributes) if err != nil { return } tbsCSR := tbsCertificateRequest{ Version: 0, // PKCS #10, RFC 2986 Subject: asn1.RawValue{FullBytes: asn1Subject}, PublicKey: publicKeyInfo{ Algorithm: publicKeyAlgorithm, PublicKey: asn1.BitString{ Bytes: publicKeyBytes, BitLength: len(publicKeyBytes) * 8, }, }, RawAttributes: rawAttributes, } tbsCSRContents, err := asn1.Marshal(tbsCSR) if err != nil { return } tbsCSR.Raw = tbsCSRContents var signature []byte signature, err = sm2.Sign(pri, userId, tbsCSRContents) if err != nil { return } var sigAlgo pkix.AlgorithmIdentifier sigAlgo.Algorithm = oidSignatureSM3WithSM2 return asn1.Marshal(certificateRequest{ TBSCSR: tbsCSR, SignatureAlgorithm: sigAlgo, SignatureValue: asn1.BitString{ Bytes: signature, BitLength: len(signature) * 8, }, }) } // marshalSANs marshals a list of addresses into a the contents of an X.509 // SubjectAlternativeName extension. 
func marshalSANs(dnsNames, emailAddresses []string, ipAddresses []net.IP, uris []*url.URL) (derBytes []byte, err error) { var rawValues []asn1.RawValue for _, name := range dnsNames { rawValues = append(rawValues, asn1.RawValue{Tag: nameTypeDNS, Class: 2, Bytes: []byte(name)}) } for _, email := range emailAddresses { rawValues = append(rawValues, asn1.RawValue{Tag: nameTypeEmail, Class: 2, Bytes: []byte(email)}) } for _, rawIP := range ipAddresses { // If possible, we always want to encode IPv4 addresses in 4 bytes. ip := rawIP.To4() if ip == nil { ip = rawIP } rawValues = append(rawValues, asn1.RawValue{Tag: nameTypeIP, Class: 2, Bytes: ip}) } for _, uri := range uris { rawValues = append(rawValues, asn1.RawValue{Tag: nameTypeURI, Class: 2, Bytes: []byte(uri.String())}) } return asn1.Marshal(rawValues) } // oidNotInExtensions returns whether an extension with the given oid exists in // extensions. func oidInExtensions(oid asn1.ObjectIdentifier, extensions []pkix.Extension) bool { for _, e := range extensions { if e.Id.Equal(oid) { return true } } return false } func marshalPublicKey(pub *sm2.PublicKey) (publicKeyBytes []byte, publicKeyAlgorithm pkix.AlgorithmIdentifier, err error) { publicKeyBytes = pub.GetUnCompressBytes() publicKeyAlgorithm.Algorithm = oidPublicKeyECDSA var paramBytes []byte paramBytes, err = asn1.Marshal(oidSM2P256V1) publicKeyAlgorithm.Parameters.FullBytes = paramBytes err = nil return } // newRawAttributes converts AttributeTypeAndValueSETs from a template // CertificateRequest's Attributes into tbsCertificateRequest RawAttributes. 
func newRawAttributes(attributes []pkix.AttributeTypeAndValueSET) ([]asn1.RawValue, error) { var rawAttributes []asn1.RawValue b, err := asn1.Marshal(attributes) if err != nil { return nil, err } rest, err := asn1.Unmarshal(b, &rawAttributes) if err != nil { return nil, err } if len(rest) != 0 { return nil, errors.New("x509: failed to unmarshal raw CSR Attributes") } return rawAttributes, nil } // ParseCertificateRequest parses a single certificate request from the // given ASN.1 DER data. func ParseCertificateRequest(asn1Data []byte) (*x509.CertificateRequest, error) { var csr certificateRequest rest, err := asn1.Unmarshal(asn1Data, &csr) if err != nil { return nil, err } else if len(rest) != 0 { return nil, asn1.SyntaxError{Msg: "trailing data"} } return parseCertificateRequest(&csr) } func parseCertificateRequest(in *certificateRequest) (*x509.CertificateRequest, error) { out := &x509.CertificateRequest{ Raw: in.Raw, RawTBSCertificateRequest: in.TBSCSR.Raw, RawSubjectPublicKeyInfo: in.TBSCSR.PublicKey.Raw, RawSubject: in.TBSCSR.Subject.FullBytes, Signature: in.SignatureValue.RightAlign(), SignatureAlgorithm: 0, //与x509.go里的实现不一样,因为这里都是固定了使用SM3WithSM2 PublicKeyAlgorithm: 0, //与x509.go里的实现不一样,因为这里都是固定了使用EC公钥 Version: in.TBSCSR.Version, Attributes: parseRawAttributes(in.TBSCSR.RawAttributes), } if !oidSignatureSM3WithSM2.Equal(in.SignatureAlgorithm.Algorithm) { return nil, errors.New("x509: illegal signature algorithm OID") } if !oidPublicKeyECDSA.Equal(in.TBSCSR.PublicKey.Algorithm.Algorithm) { return nil, errors.New("x509: illegal publick key algorithm OID") } var err error out.PublicKey, err = parsePublicKey(&in.TBSCSR.PublicKey) if err != nil { return nil, err } var subject pkix.RDNSequence if rest, err := asn1.Unmarshal(in.TBSCSR.Subject.FullBytes, &subject); err != nil { return nil, err } else if len(rest) != 0 { return nil, errors.New("x509: trailing data after X.509 Subject") } out.Subject.FillFromRDNSequence(&subject) if out.Extensions, err = 
parseCSRExtensions(in.TBSCSR.RawAttributes); err != nil { return nil, err } for _, extension := range out.Extensions { if extension.Id.Equal(oidExtensionSubjectAltName) { out.DNSNames, out.EmailAddresses, out.IPAddresses, out.URIs, err = parseSANExtension(extension.Value) if err != nil { return nil, err } } } return out, nil } func parsePublicKey(keyData *publicKeyInfo) (interface{}, error) { paramsData := keyData.Algorithm.Parameters.FullBytes namedCurveOID := new(asn1.ObjectIdentifier) rest, err := asn1.Unmarshal(paramsData, namedCurveOID) if err != nil { return nil, err } if len(rest) != 0 { return nil, errors.New("x509: trailing data after SM2 parameters") } if !oidSM2P256V1.Equal(*namedCurveOID) { return nil, errors.New("x509: CurveOID is not the OID of SM2P256V1") } curve := sm2.GetSm2P256V1() x, y := elliptic.Unmarshal(curve, keyData.PublicKey.RightAlign()) if x == nil || y == nil { return nil, errors.New("x509: Unmarshal PublicKey failed") } pub := &sm2.PublicKey{ Curve: curve, X: x, Y: y, } return pub, nil } // parseRawAttributes Unmarshals RawAttributes intos AttributeTypeAndValueSETs. func parseRawAttributes(rawAttributes []asn1.RawValue) []pkix.AttributeTypeAndValueSET { var attributes []pkix.AttributeTypeAndValueSET for _, rawAttr := range rawAttributes { var attr pkix.AttributeTypeAndValueSET rest, err := asn1.Unmarshal(rawAttr.FullBytes, &attr) // Ignore attributes that don't parse into pkix.AttributeTypeAndValueSET // (i.e.: challengePassword or unstructuredName). if err == nil && len(rest) == 0 { attributes = append(attributes, attr) } } return attributes } // parseCSRExtensions parses the attributes from a CSR and extracts any // requested extensions. func parseCSRExtensions(rawAttributes []asn1.RawValue) ([]pkix.Extension, error) { // pkcs10Attribute reflects the Attribute structure from section 4.1 of // https://tools.ietf.org/html/rfc2986. 
type pkcs10Attribute struct { Id asn1.ObjectIdentifier Values []asn1.RawValue `asn1:"set"` } var ret []pkix.Extension for _, rawAttr := range rawAttributes { var attr pkcs10Attribute if rest, err := asn1.Unmarshal(rawAttr.FullBytes, &attr); err != nil || len(rest) != 0 || len(attr.Values) == 0 { // Ignore attributes that don't parse. continue } if !attr.Id.Equal(oidExtensionRequest) { continue } var extensions []pkix.Extension if _, err := asn1.Unmarshal(attr.Values[0].FullBytes, &extensions); err != nil { return nil, err } ret = append(ret, extensions...) } return ret, nil } func forEachSAN(extension []byte, callback func(tag int, data []byte) error) error { // RFC 5280, 4.2.1.6 // SubjectAltName ::= GeneralNames // // GeneralNames ::= SEQUENCE SIZE (1..MAX) OF GeneralName // // GeneralName ::= CHOICE { // otherName [0] OtherName, // rfc822Name [1] IA5String, // dNSName [2] IA5String, // x400Address [3] ORAddress, // directoryName [4] Name, // ediPartyName [5] EDIPartyName, // uniformResourceIdentifier [6] IA5String, // iPAddress [7] OCTET STRING, // registeredID [8] OBJECT IDENTIFIER } var seq asn1.RawValue rest, err := asn1.Unmarshal(extension, &seq) if err != nil { return err } else if len(rest) != 0 { return errors.New("x509: trailing data after X.509 extension") } if !seq.IsCompound || seq.Tag != 16 || seq.Class != 0 { return asn1.StructuralError{Msg: "bad SAN sequence"} } rest = seq.Bytes for len(rest) > 0 { var v asn1.RawValue rest, err = asn1.Unmarshal(rest, &v) if err != nil { return err } if err := callback(v.Tag, v.Bytes); err != nil { return err } } return nil } // domainToReverseLabels converts a textual domain name like foo.example.com to // the list of labels in reverse order, e.g. ["com", "example", "foo"]. 
func domainToReverseLabels(domain string) (reverseLabels []string, ok bool) { for len(domain) > 0 { if i := strings.LastIndexByte(domain, '.'); i == -1 { reverseLabels = append(reverseLabels, domain) domain = "" } else { reverseLabels = append(reverseLabels, domain[i+1:]) domain = domain[:i] } } if len(reverseLabels) > 0 && len(reverseLabels[0]) == 0 { // An empty label at the end indicates an absolute value. return nil, false } for _, label := range reverseLabels { if len(label) == 0 { // Empty labels are otherwise invalid. return nil, false } for _, c := range label { if c < 33 || c > 126 { // Invalid character. return nil, false } } } return reverseLabels, true } func parseSANExtension(value []byte) (dnsNames, emailAddresses []string, ipAddresses []net.IP, uris []*url.URL, err error) { err = forEachSAN(value, func(tag int, data []byte) error { switch tag { case nameTypeEmail: emailAddresses = append(emailAddresses, string(data)) case nameTypeDNS: dnsNames = append(dnsNames, string(data)) case nameTypeURI: uri, err := url.Parse(string(data)) if err != nil { return fmt.Errorf("x509: cannot parse URI %q: %s", string(data), err) } if len(uri.Host) > 0 { if _, ok := domainToReverseLabels(uri.Host); !ok { return fmt.Errorf("x509: cannot parse URI %q: invalid domain", string(data)) } } uris = append(uris, uri) case nameTypeIP: switch len(data) { case net.IPv4len, net.IPv6len: ipAddresses = append(ipAddresses, data) default: return errors.New("x509: cannot parse IP address of length " + strconv.Itoa(len(data))) } } return nil }) return } func VerifyDERCSRSign(asn1Data []byte, userId []byte) (bool, error) { csr, err := ParseCertificateRequest(asn1Data) if err != nil { return false, err } return VerifyCSRSign(csr, userId), nil } func VerifyCSRSign(csr *x509.CertificateRequest, userId []byte) bool { pub := csr.PublicKey.(*sm2.PublicKey) return sm2.Verify(pub, userId, csr.RawTBSCertificateRequest, csr.Signature) } func FillCertificateTemplateByCSR(template 
*x509.Certificate, csr *x509.CertificateRequest) { template.Subject = csr.Subject template.PublicKeyAlgorithm = csr.PublicKeyAlgorithm template.PublicKey = csr.PublicKey template.Extensions = csr.Extensions template.ExtraExtensions = csr.ExtraExtensions template.DNSNames = csr.DNSNames template.EmailAddresses = csr.EmailAddresses template.IPAddresses = csr.IPAddresses template.URIs = csr.URIs } func subjectBytes(cert *x509.Certificate) ([]byte, error) { if len(cert.RawSubject) > 0 { return cert.RawSubject, nil } return asn1.Marshal(cert.Subject.ToRDNSequence()) } func reverseBitsInAByte(in byte) byte { b1 := in>>4 | in<<4 b2 := b1>>2&0x33 | b1<<2&0xcc b3 := b2>>1&0x55 | b2<<1&0xaa return b3 } // asn1BitLength returns the bit-length of bitString by considering the // most-significant bit in a byte to be the "first" bit. This convention // matches ASN.1, but differs from almost everything else. func asn1BitLength(bitString []byte) int { bitLen := len(bitString) * 8 for i := range bitString { b := bitString[len(bitString)-i-1] for bit := uint(0); bit < 8; bit++ { if (b>>bit)&1 == 1 { return bitLen } bitLen-- } } return 0 } // RFC 5280, 4.2.1.12 Extended Key Usage // // anyExtendedKeyUsage OBJECT IDENTIFIER ::= { id-ce-extKeyUsage 0 } // // id-kp OBJECT IDENTIFIER ::= { id-pkix 3 } // // id-kp-serverAuth OBJECT IDENTIFIER ::= { id-kp 1 } // id-kp-clientAuth OBJECT IDENTIFIER ::= { id-kp 2 } // id-kp-codeSigning OBJECT IDENTIFIER ::= { id-kp 3 } // id-kp-emailProtection OBJECT IDENTIFIER ::= { id-kp 4 } // id-kp-timeStamping OBJECT IDENTIFIER ::= { id-kp 8 } // id-kp-OCSPSigning OBJECT IDENTIFIER ::= { id-kp 9 } var ( oidExtKeyUsageAny = asn1.ObjectIdentifier{2, 5, 29, 37, 0} oidExtKeyUsageServerAuth = asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 3, 1} oidExtKeyUsageClientAuth = asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 3, 2} oidExtKeyUsageCodeSigning = asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 3, 3} oidExtKeyUsageEmailProtection = asn1.ObjectIdentifier{1, 3, 6, 1, 5, 
5, 7, 3, 4} oidExtKeyUsageIPSECEndSystem = asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 3, 5} oidExtKeyUsageIPSECTunnel = asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 3, 6} oidExtKeyUsageIPSECUser = asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 3, 7} oidExtKeyUsageTimeStamping = asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 3, 8} oidExtKeyUsageOCSPSigning = asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 3, 9} oidExtKeyUsageMicrosoftServerGatedCrypto = asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 311, 10, 3, 3} oidExtKeyUsageNetscapeServerGatedCrypto = asn1.ObjectIdentifier{2, 16, 840, 1, 113730, 4, 1} oidExtKeyUsageMicrosoftCommercialCodeSigning = asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 311, 2, 1, 22} oidExtKeyUsageMicrosoftKernelCodeSigning = asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 311, 61, 1, 1} ) // extKeyUsageOIDs contains the mapping between an ExtKeyUsage and its OID. var extKeyUsageOIDs = []struct { extKeyUsage x509.ExtKeyUsage oid asn1.ObjectIdentifier }{ {x509.ExtKeyUsageAny, oidExtKeyUsageAny}, {x509.ExtKeyUsageServerAuth, oidExtKeyUsageServerAuth}, {x509.ExtKeyUsageClientAuth, oidExtKeyUsageClientAuth}, {x509.ExtKeyUsageCodeSigning, oidExtKeyUsageCodeSigning}, {x509.ExtKeyUsageEmailProtection, oidExtKeyUsageEmailProtection}, {x509.ExtKeyUsageIPSECEndSystem, oidExtKeyUsageIPSECEndSystem}, {x509.ExtKeyUsageIPSECTunnel, oidExtKeyUsageIPSECTunnel}, {x509.ExtKeyUsageIPSECUser, oidExtKeyUsageIPSECUser}, {x509.ExtKeyUsageTimeStamping, oidExtKeyUsageTimeStamping}, {x509.ExtKeyUsageOCSPSigning, oidExtKeyUsageOCSPSigning}, {x509.ExtKeyUsageMicrosoftServerGatedCrypto, oidExtKeyUsageMicrosoftServerGatedCrypto}, {x509.ExtKeyUsageNetscapeServerGatedCrypto, oidExtKeyUsageNetscapeServerGatedCrypto}, {x509.ExtKeyUsageMicrosoftCommercialCodeSigning, oidExtKeyUsageMicrosoftCommercialCodeSigning}, {x509.ExtKeyUsageMicrosoftKernelCodeSigning, oidExtKeyUsageMicrosoftKernelCodeSigning}, } func oidFromExtKeyUsage(eku x509.ExtKeyUsage) (oid asn1.ObjectIdentifier, ok bool) { for _, pair := range 
extKeyUsageOIDs { if eku == pair.extKeyUsage { return pair.oid, true } } return } type basicConstraints struct { IsCA bool `asn1:"optional"` MaxPathLen int `asn1:"optional,default:-1"` } // RFC 5280, 4.2.1.1 type authKeyId struct { Id []byte `asn1:"optional,tag:0"` } // RFC 5280, 4.2.2.1 type authorityInfoAccess struct { Method asn1.ObjectIdentifier Location asn1.RawValue } // RFC 5280 4.2.1.4 type policyInformation struct { Policy asn1.ObjectIdentifier // policyQualifiers omitted } type distributionPointName struct { FullName []asn1.RawValue `asn1:"optional,tag:0"` RelativeName pkix.RDNSequence `asn1:"optional,tag:1"` } // RFC 5280, 4.2.1.14 type distributionPoint struct { DistributionPoint distributionPointName `asn1:"optional,tag:0"` Reason asn1.BitString `asn1:"optional,tag:1"` CRLIssuer asn1.RawValue `asn1:"optional,tag:2"` } func isIA5String(s string) error { for _, r := range s { if r >= utf8.RuneSelf { return fmt.Errorf("x509: %q cannot be encoded as an IA5String", s) } } return nil } func buildExtensions(template *x509.Certificate, subjectIsEmpty bool, authorityKeyId []byte) (ret []pkix.Extension, err error) { ret = make([]pkix.Extension, 10 /* maximum number of elements. 
*/) n := 0 if template.KeyUsage != 0 && !oidInExtensions(oidExtensionKeyUsage, template.ExtraExtensions) { ret[n].Id = oidExtensionKeyUsage ret[n].Critical = true var a [2]byte a[0] = reverseBitsInAByte(byte(template.KeyUsage)) a[1] = reverseBitsInAByte(byte(template.KeyUsage >> 8)) l := 1 if a[1] != 0 { l = 2 } bitString := a[:l] ret[n].Value, err = asn1.Marshal(asn1.BitString{Bytes: bitString, BitLength: asn1BitLength(bitString)}) if err != nil { return } n++ } if (len(template.ExtKeyUsage) > 0 || len(template.UnknownExtKeyUsage) > 0) && !oidInExtensions(oidExtensionExtendedKeyUsage, template.ExtraExtensions) { ret[n].Id = oidExtensionExtendedKeyUsage var oids []asn1.ObjectIdentifier for _, u := range template.ExtKeyUsage { if oid, ok := oidFromExtKeyUsage(u); ok { oids = append(oids, oid) } else { panic("internal error") } } oids = append(oids, template.UnknownExtKeyUsage...) ret[n].Value, err = asn1.Marshal(oids) if err != nil { return } n++ } if template.BasicConstraintsValid && !oidInExtensions(oidExtensionBasicConstraints, template.ExtraExtensions) { // Leaving MaxPathLen as zero indicates that no maximum path // length is desired, unless MaxPathLenZero is set. A value of // -1 causes encoding/asn1 to omit the value as desired. 
maxPathLen := template.MaxPathLen if maxPathLen == 0 && !template.MaxPathLenZero { maxPathLen = -1 } ret[n].Id = oidExtensionBasicConstraints ret[n].Value, err = asn1.Marshal(basicConstraints{template.IsCA, maxPathLen}) ret[n].Critical = true if err != nil { return } n++ } if len(template.SubjectKeyId) > 0 && !oidInExtensions(oidExtensionSubjectKeyId, template.ExtraExtensions) { ret[n].Id = oidExtensionSubjectKeyId ret[n].Value, err = asn1.Marshal(template.SubjectKeyId) if err != nil { return } n++ } if len(authorityKeyId) > 0 && !oidInExtensions(oidExtensionAuthorityKeyId, template.ExtraExtensions) { ret[n].Id = oidExtensionAuthorityKeyId ret[n].Value, err = asn1.Marshal(authKeyId{authorityKeyId}) if err != nil { return } n++ } if (len(template.OCSPServer) > 0 || len(template.IssuingCertificateURL) > 0) && !oidInExtensions(oidExtensionAuthorityInfoAccess, template.ExtraExtensions) { ret[n].Id = oidExtensionAuthorityInfoAccess var aiaValues []authorityInfoAccess for _, name := range template.OCSPServer { aiaValues = append(aiaValues, authorityInfoAccess{ Method: oidAuthorityInfoAccessOcsp, Location: asn1.RawValue{Tag: 6, Class: 2, Bytes: []byte(name)}, }) } for _, name := range template.IssuingCertificateURL { aiaValues = append(aiaValues, authorityInfoAccess{ Method: oidAuthorityInfoAccessIssuers, Location: asn1.RawValue{Tag: 6, Class: 2, Bytes: []byte(name)}, }) } ret[n].Value, err = asn1.Marshal(aiaValues) if err != nil { return } n++ } if (len(template.DNSNames) > 0 || len(template.EmailAddresses) > 0 || len(template.IPAddresses) > 0 || len(template.URIs) > 0) && !oidInExtensions(oidExtensionSubjectAltName, template.ExtraExtensions) { ret[n].Id = oidExtensionSubjectAltName // https://tools.ietf.org/html/rfc5280#section-4.2.1.6 // “If the subject field contains an empty sequence ... then // subjectAltName extension ... 
is marked as critical” ret[n].Critical = subjectIsEmpty ret[n].Value, err = marshalSANs(template.DNSNames, template.EmailAddresses, template.IPAddresses, template.URIs) if err != nil { return } n++ } if len(template.PolicyIdentifiers) > 0 && !oidInExtensions(oidExtensionCertificatePolicies, template.ExtraExtensions) { ret[n].Id = oidExtensionCertificatePolicies policies := make([]policyInformation, len(template.PolicyIdentifiers)) for i, policy := range template.PolicyIdentifiers { policies[i].Policy = policy } ret[n].Value, err = asn1.Marshal(policies) if err != nil { return } n++ } if (len(template.PermittedDNSDomains) > 0 || len(template.ExcludedDNSDomains) > 0 || len(template.PermittedIPRanges) > 0 || len(template.ExcludedIPRanges) > 0 || len(template.PermittedEmailAddresses) > 0 || len(template.ExcludedEmailAddresses) > 0 || len(template.PermittedURIDomains) > 0 || len(template.ExcludedURIDomains) > 0) && !oidInExtensions(oidExtensionNameConstraints, template.ExtraExtensions) { ret[n].Id = oidExtensionNameConstraints ret[n].Critical = template.PermittedDNSDomainsCritical ipAndMask := func(ipNet *net.IPNet) []byte { maskedIP := ipNet.IP.Mask(ipNet.Mask) ipAndMask := make([]byte, 0, len(maskedIP)+len(ipNet.Mask)) ipAndMask = append(ipAndMask, maskedIP...) ipAndMask = append(ipAndMask, ipNet.Mask...) 
return ipAndMask } serialiseConstraints := func(dns []string, ips []*net.IPNet, emails []string, uriDomains []string) (der []byte, err error) { var b cryptobyte.Builder for _, name := range dns { if err = isIA5String(name); err != nil { return nil, err } b.AddASN1(cryptobyte_asn1.SEQUENCE, func(b *cryptobyte.Builder) { b.AddASN1(cryptobyte_asn1.Tag(2).ContextSpecific(), func(b *cryptobyte.Builder) { b.AddBytes([]byte(name)) }) }) } for _, ipNet := range ips { b.AddASN1(cryptobyte_asn1.SEQUENCE, func(b *cryptobyte.Builder) { b.AddASN1(cryptobyte_asn1.Tag(7).ContextSpecific(), func(b *cryptobyte.Builder) { b.AddBytes(ipAndMask(ipNet)) }) }) } for _, email := range emails { if err = isIA5String(email); err != nil { return nil, err } b.AddASN1(cryptobyte_asn1.SEQUENCE, func(b *cryptobyte.Builder) { b.AddASN1(cryptobyte_asn1.Tag(1).ContextSpecific(), func(b *cryptobyte.Builder) { b.AddBytes([]byte(email)) }) }) } for _, uriDomain := range uriDomains { if err = isIA5String(uriDomain); err != nil { return nil, err } b.AddASN1(cryptobyte_asn1.SEQUENCE, func(b *cryptobyte.Builder) { b.AddASN1(cryptobyte_asn1.Tag(6).ContextSpecific(), func(b *cryptobyte.Builder) { b.AddBytes([]byte(uriDomain)) }) }) } return b.Bytes() } permitted, err := serialiseConstraints(template.PermittedDNSDomains, template.PermittedIPRanges, template.PermittedEmailAddresses, template.PermittedURIDomains) if err != nil { return nil, err } excluded, err := serialiseConstraints(template.ExcludedDNSDomains, template.ExcludedIPRanges, template.ExcludedEmailAddresses, template.ExcludedURIDomains) if err != nil { return nil, err } var b cryptobyte.Builder b.AddASN1(cryptobyte_asn1.SEQUENCE, func(b *cryptobyte.Builder) { if len(permitted) > 0 { b.AddASN1(cryptobyte_asn1.Tag(0).ContextSpecific().Constructed(), func(b *cryptobyte.Builder) { b.AddBytes(permitted) }) } if len(excluded) > 0 { b.AddASN1(cryptobyte_asn1.Tag(1).ContextSpecific().Constructed(), func(b *cryptobyte.Builder) { b.AddBytes(excluded) }) } 
}) ret[n].Value, err = b.Bytes() if err != nil { return nil, err } n++ } if len(template.CRLDistributionPoints) > 0 && !oidInExtensions(oidExtensionCRLDistributionPoints, template.ExtraExtensions) { ret[n].Id = oidExtensionCRLDistributionPoints var crlDp []distributionPoint for _, name := range template.CRLDistributionPoints { dp := distributionPoint{ DistributionPoint: distributionPointName{ FullName: []asn1.RawValue{ {Tag: 6, Class: 2, Bytes: []byte(name)}, }, }, } crlDp = append(crlDp, dp) } ret[n].Value, err = asn1.Marshal(crlDp) if err != nil { return } n++ } // Adding another extension here? Remember to update the maximum number // of elements in the make() at the top of the function. return append(ret[:n], template.ExtraExtensions...), nil } type validity struct { NotBefore, NotAfter time.Time } type certificate struct { Raw asn1.RawContent TBSCertificate TBSCertificate SignatureAlgorithm pkix.AlgorithmIdentifier SignatureValue asn1.BitString } type tbsCertificate struct { Raw asn1.RawContent Version int `asn1:"optional,explicit,default:0,tag:0"` SerialNumber *big.Int SignatureAlgorithm pkix.AlgorithmIdentifier Issuer asn1.RawValue Validity validity Subject asn1.RawValue PublicKey publicKeyInfo UniqueId asn1.BitString `asn1:"optional,tag:1"` SubjectUniqueId asn1.BitString `asn1:"optional,tag:2"` Extensions []pkix.Extension `asn1:"optional,explicit,tag:3"` } type TBSCertificate tbsCertificate // emptyASN1Subject is the ASN.1 DER encoding of an empty Subject, which is // just an empty SEQUENCE. var emptyASN1Subject = []byte{0x30, 0} // 为什么要将构建CertificateInfo和签发证书分开呢? 
// 是因为实际应用中的CA密钥大多数都是放在加密卡/加密机中的,签名由加密卡/加密机来完成 func CreateCertificateInfo(template, parent *x509.Certificate, csr *x509.CertificateRequest) (*TBSCertificate, error) { if template.SerialNumber == nil { return nil, errors.New("x509: no SerialNumber given") } asn1Issuer, err := subjectBytes(parent) if err != nil { return nil, err } asn1Subject, err := subjectBytes(template) if err != nil { return nil, err } authorityKeyId := template.AuthorityKeyId if !bytes.Equal(asn1Issuer, asn1Subject) && len(parent.SubjectKeyId) > 0 { authorityKeyId = parent.SubjectKeyId } extensions, err := buildExtensions(template, bytes.Equal(asn1Subject, emptyASN1Subject), authorityKeyId) if err != nil { return nil, err } var sigAlgo pkix.AlgorithmIdentifier sigAlgo.Algorithm = oidSignatureSM3WithSM2 var subjectPubKeyInfo publicKeyInfo rest, err := asn1.Unmarshal(csr.RawSubjectPublicKeyInfo, &subjectPubKeyInfo) if err != nil { return nil, err } else if len(rest) != 0 { return nil, asn1.SyntaxError{Msg: "trailing data"} } c := TBSCertificate{ Version: 2, SerialNumber: template.SerialNumber, SignatureAlgorithm: sigAlgo, Issuer: asn1.RawValue{FullBytes: asn1Issuer}, Validity: validity{template.NotBefore.UTC(), template.NotAfter.UTC()}, Subject: asn1.RawValue{FullBytes: asn1Subject}, PublicKey: subjectPubKeyInfo, Extensions: extensions, } tbsCertContents, err := asn1.Marshal(c) if err != nil { return nil, err } c.Raw = tbsCertContents return &c, nil } func IssueCertificateBySoftCAKey(cinfo *TBSCertificate, caPri *sm2.PrivateKey, userId []byte) ([]byte, error) { signature, err := sm2.Sign(caPri, userId, cinfo.Raw) if err != nil { return nil, err } return CreateCertificate(cinfo, signature) } func CreateCertificate(cinfo *TBSCertificate, signature []byte) ([]byte, error) { var sigAlgo pkix.AlgorithmIdentifier sigAlgo.Algorithm = oidSignatureSM3WithSM2 return asn1.Marshal(certificate{ nil, *cinfo, sigAlgo, asn1.BitString{Bytes: signature, BitLength: len(signature) * 8}, }) } // ParseCertificate 
parses a single certificate from the given ASN.1 DER data. func ParseCertificate(asn1Data []byte) (*x509.Certificate, error) { var cert certificate rest, err := asn1.Unmarshal(asn1Data, &cert) if err != nil { return nil, err } if len(rest) > 0 { return nil, asn1.SyntaxError{Msg: "trailing data"} } return parseCertificate(&cert) } func parseCertificate(in *certificate) (*x509.Certificate, error) { out := new(x509.Certificate) out.Raw = in.Raw out.RawTBSCertificate = in.TBSCertificate.Raw out.RawSubjectPublicKeyInfo = in.TBSCertificate.PublicKey.Raw out.RawSubject = in.TBSCertificate.Subject.FullBytes out.RawIssuer = in.TBSCertificate.Issuer.FullBytes out.Signature = in.SignatureValue.RightAlign() out.SignatureAlgorithm = 0 out.PublicKeyAlgorithm = 0 if !oidSignatureSM3WithSM2.Equal(in.SignatureAlgorithm.Algorithm) { return nil, errors.New("x509: illegal signature algorithm OID") } if !oidPublicKeyECDSA.Equal(in.TBSCertificate.PublicKey.Algorithm.Algorithm) { return nil, errors.New("x509: illegal publick key algorithm OID") } var err error out.PublicKey, err = parsePublicKey(&in.TBSCertificate.PublicKey) if err != nil { return nil, err } out.Version = in.TBSCertificate.Version + 1 out.SerialNumber = in.TBSCertificate.SerialNumber var issuer, subject pkix.RDNSequence if rest, err := asn1.Unmarshal(in.TBSCertificate.Subject.FullBytes, &subject); err != nil { return nil, err } else if len(rest) != 0 { return nil, errors.New("x509: trailing data after X.509 subject") } if rest, err := asn1.Unmarshal(in.TBSCertificate.Issuer.FullBytes, &issuer); err != nil { return nil, err } else if len(rest) != 0 { return nil, errors.New("x509: trailing data after X.509 subject") } out.Issuer.FillFromRDNSequence(&issuer) out.Subject.FillFromRDNSequence(&subject) out.NotBefore = in.TBSCertificate.Validity.NotBefore out.NotAfter = in.TBSCertificate.Validity.NotAfter for _, e := range in.TBSCertificate.Extensions { out.Extensions = append(out.Extensions, e) unhandled := false if len(e.Id) 
== 4 && e.Id[0] == 2 && e.Id[1] == 5 && e.Id[2] == 29 { switch e.Id[3] { case 15: // RFC 5280, 4.2.1.3 var usageBits asn1.BitString if rest, err := asn1.Unmarshal(e.Value, &usageBits); err != nil { return nil, err } else if len(rest) != 0 { return nil, errors.New("x509: trailing data after X.509 KeyUsage") } var usage int for i := 0; i < 9; i++ { if usageBits.At(i) != 0 { usage |= 1 << uint(i) } } out.KeyUsage = x509.KeyUsage(usage) case 19: // RFC 5280, 4.2.1.9 var constraints basicConstraints if rest, err := asn1.Unmarshal(e.Value, &constraints); err != nil { return nil, err } else if len(rest) != 0 { return nil, errors.New("x509: trailing data after X.509 BasicConstraints") } out.BasicConstraintsValid = true out.IsCA = constraints.IsCA out.MaxPathLen = constraints.MaxPathLen out.MaxPathLenZero = out.MaxPathLen == 0 // TODO: map out.MaxPathLen to 0 if it has the -1 default value? (Issue 19285) case 17: out.DNSNames, out.EmailAddresses, out.IPAddresses, out.URIs, err = parseSANExtension(e.Value) if err != nil { return nil, err } if len(out.DNSNames) == 0 && len(out.EmailAddresses) == 0 && len(out.IPAddresses) == 0 && len(out.URIs) == 0 { // If we didn't parse anything then we do the critical check, below. 
unhandled = true } case 30: unhandled, err = parseNameConstraintsExtension(out, e) if err != nil { return nil, err } case 31: // RFC 5280, 4.2.1.13 // CRLDistributionPoints ::= SEQUENCE SIZE (1..MAX) OF DistributionPoint // // DistributionPoint ::= SEQUENCE { // distributionPoint [0] DistributionPointName OPTIONAL, // reasons [1] ReasonFlags OPTIONAL, // cRLIssuer [2] GeneralNames OPTIONAL } // // DistributionPointName ::= CHOICE { // fullName [0] GeneralNames, // nameRelativeToCRLIssuer [1] RelativeDistinguishedName } var cdp []distributionPoint if rest, err := asn1.Unmarshal(e.Value, &cdp); err != nil { return nil, err } else if len(rest) != 0 { return nil, errors.New("x509: trailing data after X.509 CRL distribution point") } for _, dp := range cdp { // Per RFC 5280, 4.2.1.13, one of distributionPoint or cRLIssuer may be empty. if len(dp.DistributionPoint.FullName) == 0 { continue } for _, fullName := range dp.DistributionPoint.FullName { if fullName.Tag == 6 { out.CRLDistributionPoints = append(out.CRLDistributionPoints, string(fullName.Bytes)) } } } case 35: // RFC 5280, 4.2.1.1 var a authKeyId if rest, err := asn1.Unmarshal(e.Value, &a); err != nil { return nil, err } else if len(rest) != 0 { return nil, errors.New("x509: trailing data after X.509 authority key-id") } out.AuthorityKeyId = a.Id case 37: // RFC 5280, 4.2.1.12. 
Extended Key Usage // id-ce-extKeyUsage OBJECT IDENTIFIER ::= { id-ce 37 } // // ExtKeyUsageSyntax ::= SEQUENCE SIZE (1..MAX) OF KeyPurposeId // // KeyPurposeId ::= OBJECT IDENTIFIER var keyUsage []asn1.ObjectIdentifier if rest, err := asn1.Unmarshal(e.Value, &keyUsage); err != nil { return nil, err } else if len(rest) != 0 { return nil, errors.New("x509: trailing data after X.509 ExtendedKeyUsage") } for _, u := range keyUsage { if extKeyUsage, ok := extKeyUsageFromOID(u); ok { out.ExtKeyUsage = append(out.ExtKeyUsage, extKeyUsage) } else { out.UnknownExtKeyUsage = append(out.UnknownExtKeyUsage, u) } } case 14: // RFC 5280, 4.2.1.2 var keyid []byte if rest, err := asn1.Unmarshal(e.Value, &keyid); err != nil { return nil, err } else if len(rest) != 0 { return nil, errors.New("x509: trailing data after X.509 key-id") } out.SubjectKeyId = keyid case 32: // RFC 5280 4.2.1.4: Certificate Policies var policies []policyInformation if rest, err := asn1.Unmarshal(e.Value, &policies); err != nil { return nil, err } else if len(rest) != 0 { return nil, errors.New("x509: trailing data after X.509 certificate policies") } out.PolicyIdentifiers = make([]asn1.ObjectIdentifier, len(policies)) for i, policy := range policies { out.PolicyIdentifiers[i] = policy.Policy } default: // Unknown extensions are recorded if critical. 
unhandled = true } } else if e.Id.Equal(oidExtensionAuthorityInfoAccess) { // RFC 5280 4.2.2.1: Authority Information Access var aia []authorityInfoAccess if rest, err := asn1.Unmarshal(e.Value, &aia); err != nil { return nil, err } else if len(rest) != 0 { return nil, errors.New("x509: trailing data after X.509 authority information") } for _, v := range aia { // GeneralName: uniformResourceIdentifier [6] IA5String if v.Location.Tag != 6 { continue } if v.Method.Equal(oidAuthorityInfoAccessOcsp) { out.OCSPServer = append(out.OCSPServer, string(v.Location.Bytes)) } else if v.Method.Equal(oidAuthorityInfoAccessIssuers) { out.IssuingCertificateURL = append(out.IssuingCertificateURL, string(v.Location.Bytes)) } } } else { // Unknown extensions are recorded if critical. unhandled = true } if e.Critical && unhandled { out.UnhandledCriticalExtensions = append(out.UnhandledCriticalExtensions, e.Id) } } return out, nil } func parseNameConstraintsExtension(out *x509.Certificate, e pkix.Extension) (unhandled bool, err error) { // RFC 5280, 4.2.1.10 // NameConstraints ::= SEQUENCE { // permittedSubtrees [0] GeneralSubtrees OPTIONAL, // excludedSubtrees [1] GeneralSubtrees OPTIONAL } // // GeneralSubtrees ::= SEQUENCE SIZE (1..MAX) OF GeneralSubtree // // GeneralSubtree ::= SEQUENCE { // base GeneralName, // minimum [0] BaseDistance DEFAULT 0, // maximum [1] BaseDistance OPTIONAL } // // BaseDistance ::= INTEGER (0..MAX) outer := cryptobyte.String(e.Value) var toplevel, permitted, excluded cryptobyte.String var havePermitted, haveExcluded bool if !outer.ReadASN1(&toplevel, cryptobyte_asn1.SEQUENCE) || !outer.Empty() || !toplevel.ReadOptionalASN1(&permitted, &havePermitted, cryptobyte_asn1.Tag(0).ContextSpecific().Constructed()) || !toplevel.ReadOptionalASN1(&excluded, &haveExcluded, cryptobyte_asn1.Tag(1).ContextSpecific().Constructed()) || !toplevel.Empty() { return false, errors.New("x509: invalid NameConstraints extension") } if !havePermitted && !haveExcluded || 
len(permitted) == 0 && len(excluded) == 0 { // https://tools.ietf.org/html/rfc5280#section-4.2.1.10: // “either the permittedSubtrees field // or the excludedSubtrees MUST be // present” return false, errors.New("x509: empty name constraints extension") } getValues := func(subtrees cryptobyte.String) (dnsNames []string, ips []*net.IPNet, emails, uriDomains []string, err error) { for !subtrees.Empty() { var seq, value cryptobyte.String var tag cryptobyte_asn1.Tag if !subtrees.ReadASN1(&seq, cryptobyte_asn1.SEQUENCE) || !seq.ReadAnyASN1(&value, &tag) { return nil, nil, nil, nil, fmt.Errorf("x509: invalid NameConstraints extension") } var ( dnsTag = cryptobyte_asn1.Tag(2).ContextSpecific() emailTag = cryptobyte_asn1.Tag(1).ContextSpecific() ipTag = cryptobyte_asn1.Tag(7).ContextSpecific() uriTag = cryptobyte_asn1.Tag(6).ContextSpecific() ) switch tag { case dnsTag: domain := string(value) if err := isIA5String(domain); err != nil { return nil, nil, nil, nil, errors.New("x509: invalid constraint value: " + err.Error()) } trimmedDomain := domain if len(trimmedDomain) > 0 && trimmedDomain[0] == '.' { // constraints can have a leading // period to exclude the domain // itself, but that's not valid in a // normal domain name. 
trimmedDomain = trimmedDomain[1:] } if _, ok := domainToReverseLabels(trimmedDomain); !ok { return nil, nil, nil, nil, fmt.Errorf("x509: failed to parse dnsName constraint %q", domain) } dnsNames = append(dnsNames, domain) case ipTag: l := len(value) var ip, mask []byte switch l { case 8: ip = value[:4] mask = value[4:] case 32: ip = value[:16] mask = value[16:] default: return nil, nil, nil, nil, fmt.Errorf("x509: IP constraint contained value of length %d", l) } if !isValidIPMask(mask) { return nil, nil, nil, nil, fmt.Errorf("x509: IP constraint contained invalid mask %x", mask) } ips = append(ips, &net.IPNet{IP: net.IP(ip), Mask: net.IPMask(mask)}) case emailTag: constraint := string(value) if err := isIA5String(constraint); err != nil { return nil, nil, nil, nil, errors.New("x509: invalid constraint value: " + err.Error()) } // If the constraint contains an @ then // it specifies an exact mailbox name. if strings.Contains(constraint, "@") { if _, ok := parseRFC2821Mailbox(constraint); !ok { return nil, nil, nil, nil, fmt.Errorf("x509: failed to parse rfc822Name constraint %q", constraint) } } else { // Otherwise it's a domain name. domain := constraint if len(domain) > 0 && domain[0] == '.' { domain = domain[1:] } if _, ok := domainToReverseLabels(domain); !ok { return nil, nil, nil, nil, fmt.Errorf("x509: failed to parse rfc822Name constraint %q", constraint) } } emails = append(emails, constraint) case uriTag: domain := string(value) if err := isIA5String(domain); err != nil { return nil, nil, nil, nil, errors.New("x509: invalid constraint value: " + err.Error()) } if net.ParseIP(domain) != nil { return nil, nil, nil, nil, fmt.Errorf("x509: failed to parse URI constraint %q: cannot be IP address", domain) } trimmedDomain := domain if len(trimmedDomain) > 0 && trimmedDomain[0] == '.' { // constraints can have a leading // period to exclude the domain itself, // but that's not valid in a normal // domain name. 
trimmedDomain = trimmedDomain[1:] } if _, ok := domainToReverseLabels(trimmedDomain); !ok { return nil, nil, nil, nil, fmt.Errorf("x509: failed to parse URI constraint %q", domain) } uriDomains = append(uriDomains, domain) default: unhandled = true } } return dnsNames, ips, emails, uriDomains, nil } if out.PermittedDNSDomains, out.PermittedIPRanges, out.PermittedEmailAddresses, out.PermittedURIDomains, err = getValues(permitted); err != nil { return false, err } if out.ExcludedDNSDomains, out.ExcludedIPRanges, out.ExcludedEmailAddresses, out.ExcludedURIDomains, err = getValues(excluded); err != nil { return false, err } out.PermittedDNSDomainsCritical = e.Critical return unhandled, nil } // isValidIPMask returns true iff mask consists of zero or more 1 bits, followed by zero bits. func isValidIPMask(mask []byte) bool { seenZero := false for _, b := range mask { if seenZero { if b != 0 { return false } continue } switch b { case 0x00, 0x80, 0xc0, 0xe0, 0xf0, 0xf8, 0xfc, 0xfe: seenZero = true case 0xff: default: return false } } return true } // rfc2821Mailbox represents a “mailbox” (which is an email address to most // people) by breaking it into the “local” (i.e. before the '@') and “domain” // parts. type rfc2821Mailbox struct { local, domain string } // parseRFC2821Mailbox parses an email address into local and domain parts, // based on the ABNF for a “Mailbox” from RFC 2821. According to // https://tools.ietf.org/html/rfc5280#section-4.2.1.6 that's correct for an // rfc822Name from a certificate: “The format of an rfc822Name is a "Mailbox" // as defined in https://tools.ietf.org/html/rfc2821#section-4.1.2”. 
func parseRFC2821Mailbox(in string) (mailbox rfc2821Mailbox, ok bool) { if len(in) == 0 { return mailbox, false } localPartBytes := make([]byte, 0, len(in)/2) if in[0] == '"' { // Quoted-string = DQUOTE *qcontent DQUOTE // non-whitespace-control = %d1-8 / %d11 / %d12 / %d14-31 / %d127 // qcontent = qtext / quoted-pair // qtext = non-whitespace-control / // %d33 / %d35-91 / %d93-126 // quoted-pair = ("\" text) / obs-qp // text = %d1-9 / %d11 / %d12 / %d14-127 / obs-text // // (Names beginning with “obs-” are the obsolete syntax from // https://tools.ietf.org/html/rfc2822#section-4. Since it has // been 16 years, we no longer accept that.) in = in[1:] QuotedString: for { if len(in) == 0 { return mailbox, false } c := in[0] in = in[1:] switch { case c == '"': break QuotedString case c == '\\': // quoted-pair if len(in) == 0 { return mailbox, false } if in[0] == 11 || in[0] == 12 || (1 <= in[0] && in[0] <= 9) || (14 <= in[0] && in[0] <= 127) { localPartBytes = append(localPartBytes, in[0]) in = in[1:] } else { return mailbox, false } case c == 11 || c == 12 || // Space (char 32) is not allowed based on the // BNF, but RFC 3696 gives an example that // assumes that it is. Several “verified” // errata continue to argue about this point. // We choose to accept it. c == 32 || c == 33 || c == 127 || (1 <= c && c <= 8) || (14 <= c && c <= 31) || (35 <= c && c <= 91) || (93 <= c && c <= 126): // qtext localPartBytes = append(localPartBytes, c) default: return mailbox, false } } } else { // Atom ("." Atom)* NextChar: for len(in) > 0 { // atext from https://tools.ietf.org/html/rfc2822#section-3.2.4 c := in[0] switch { case c == '\\': // Examples given in RFC 3696 suggest that // escaped characters can appear outside of a // quoted string. Several “verified” errata // continue to argue the point. We choose to // accept it. in = in[1:] if len(in) == 0 { return mailbox, false } fallthrough case ('0' <= c && c <= '9') || ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || c == '!' 
|| c == '#' || c == '$' || c == '%' || c == '&' || c == '\'' || c == '*' || c == '+' || c == '-' || c == '/' || c == '=' || c == '?' || c == '^' || c == '_' || c == '`' || c == '{' || c == '|' || c == '}' || c == '~' || c == '.': localPartBytes = append(localPartBytes, in[0]) in = in[1:] default: break NextChar } } if len(localPartBytes) == 0 { return mailbox, false } // https://tools.ietf.org/html/rfc3696#section-3 // “period (".") may also appear, but may not be used to start // or end the local part, nor may two or more consecutive // periods appear.” twoDots := []byte{'.', '.'} if localPartBytes[0] == '.' || localPartBytes[len(localPartBytes)-1] == '.' || bytes.Contains(localPartBytes, twoDots) { return mailbox, false } } if len(in) == 0 || in[0] != '@' { return mailbox, false } in = in[1:] // The RFC species a format for domains, but that's known to be // violated in practice so we accept that anything after an '@' is the // domain part. if _, ok := domainToReverseLabels(in); !ok { return mailbox, false } mailbox.local = string(localPartBytes) mailbox.domain = in return mailbox, true } func extKeyUsageFromOID(oid asn1.ObjectIdentifier) (eku x509.ExtKeyUsage, ok bool) { for _, pair := range extKeyUsageOIDs { if oid.Equal(pair.oid) { return pair.extKeyUsage, true } } return } ================================================ FILE: sm2/cert/gmx509_test.go ================================================ package cert import ( "bytes" "crypto/rand" "crypto/x509" "crypto/x509/pkix" "encoding/asn1" "fmt" "io/ioutil" "math/big" "testing" "time" "github.com/ZZMarquis/gm/sm2" ) func TestX500Name(t *testing.T) { name := new(pkix.Name) name.CommonName = "ID=Mock Root CA" name.Country = []string{"CN"} name.Province = []string{"Beijing"} name.Locality = []string{"Beijing"} name.Organization = []string{"org.zz"} name.OrganizationalUnit = []string{"org.zz"} fmt.Println(name.String()) } func TestCreateCertificateRequest(t *testing.T) { pri, pub, err := 
sm2.GenerateKey(rand.Reader) if err != nil { t.Fatal(err) } sanContents, err := marshalSANs([]string{"foo.example.com"}, nil, nil, nil) if err != nil { t.Fatal(err) } template := x509.CertificateRequest{ Subject: pkix.Name{ CommonName: "test.example.com", Organization: []string{"Σ Acme Co"}, }, DNSNames: []string{"test.example.com"}, // An explicit extension should override the DNSNames from the // template. ExtraExtensions: []pkix.Extension{ { Id: oidExtensionSubjectAltName, Value: sanContents, }, }, } derBytes, err := CreateCertificateRequest(&template, pub, pri, nil) if err != nil { t.Fatal(err) } ioutil.WriteFile("sample.csr", derBytes, 0644) csr, err := ParseCertificateRequest(derBytes) if err != nil { t.Fatal(err) } csrPub := csr.PublicKey.(*sm2.PublicKey) if !bytes.Equal(pub.GetUnCompressBytes(), csrPub.GetUnCompressBytes()) { t.Fatal("public key not equals") } b, err := VerifyDERCSRSign(derBytes, nil) if err != nil { t.Fatal(err) } if !b { t.Fatal("Verify CSR sign not pass") } } func TestCreateCertificate(t *testing.T) { pri, pub, err := sm2.GenerateKey(rand.Reader) if err != nil { t.Fatal(err) } sanContents, err := marshalSANs([]string{"foo.example.com"}, nil, nil, nil) if err != nil { t.Fatal(err) } template := x509.CertificateRequest{ Subject: pkix.Name{ CommonName: "test.example.com", Organization: []string{"Σ Acme Co"}, }, DNSNames: []string{"test.example.com"}, // An explicit extension should override the DNSNames from the // template. 
ExtraExtensions: []pkix.Extension{ { Id: oidExtensionSubjectAltName, Value: sanContents, }, }, } derBytes, err := CreateCertificateRequest(&template, pub, pri, nil) if err != nil { t.Fatal(err) } csr, err := ParseCertificateRequest(derBytes) if err != nil { t.Fatal(err) } testExtKeyUsage := []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth} testUnknownExtKeyUsage := []asn1.ObjectIdentifier{[]int{1, 2, 3}, []int{2, 59, 1}} cerTemplate := x509.Certificate{ // SerialNumber is negative to ensure that negative // values are parsed. This is due to the prevalence of // buggy code that produces certificates with negative // serial numbers. SerialNumber: big.NewInt(-1), NotBefore: time.Now(), NotAfter: time.Unix(time.Now().Unix()+100000000, 0), SubjectKeyId: []byte{1, 2, 3, 4}, KeyUsage: x509.KeyUsageCertSign, ExtKeyUsage: testExtKeyUsage, UnknownExtKeyUsage: testUnknownExtKeyUsage, BasicConstraintsValid: true, IsCA: true, OCSPServer: []string{"http://ocsp.example.com"}, IssuingCertificateURL: []string{"http://crt.example.com/ca1.crt"}, PolicyIdentifiers: []asn1.ObjectIdentifier{[]int{1, 2, 3}}, CRLDistributionPoints: []string{"http://crl1.example.com/ca1.crl", "http://crl2.example.com/ca1.crl"}, } FillCertificateTemplateByCSR(&cerTemplate, csr) cinfo, err := CreateCertificateInfo(&cerTemplate, &cerTemplate, csr) if err != nil { t.Fatal(err) } sign, err := sm2.Sign(pri, nil, cinfo.Raw) if err != nil { t.Fatal(err) } cer, err := CreateCertificate(cinfo, sign) if err != nil { t.Fatal(err) } ioutil.WriteFile("sample.cer", cer, 0644) certificate, err := ParseCertificate(cer) if err != nil { t.Fatal(err) } fmt.Println(certificate.DNSNames) } ================================================ FILE: sm2/keyexchange.go ================================================ package sm2 import ( "bytes" "encoding/binary" "errors" "github.com/ZZMarquis/gm/sm3" "github.com/ZZMarquis/gm/util" "hash" "math/big" ) type ExchangeResult struct { Key []byte S1 []byte S2 []byte 
} func reduce(x *big.Int, w int) *big.Int { intOne := new(big.Int).SetInt64(1) result := util.Lsh(intOne, uint(w)) result = util.Sub(result, intOne) result = util.And(x, result) result = util.SetBit(result, w, 1) return result } func calculateU(w int, selfStaticPriv *PrivateKey, selfEphemeralPriv *PrivateKey, selfEphemeralPub *PublicKey, otherStaticPub *PublicKey, otherEphemeralPub *PublicKey) (x *big.Int, y *big.Int) { x1 := reduce(selfEphemeralPub.X, w) x2 := reduce(otherEphemeralPub.X, w) tA := util.Mul(x1, selfEphemeralPriv.D) tA = util.Add(selfStaticPriv.D, tA) k1 := util.Mul(sm2H, tA) k1 = util.Mod(k1, selfStaticPriv.Curve.N) k2 := util.Mul(k1, x2) k2 = util.Mod(k2, selfStaticPriv.Curve.N) p1x, p1y := selfStaticPriv.Curve.ScalarMult(otherStaticPub.X, otherStaticPub.Y, k1.Bytes()) p2x, p2y := selfStaticPriv.Curve.ScalarMult(otherEphemeralPub.X, otherEphemeralPub.Y, k2.Bytes()) x, y = selfStaticPriv.Curve.Add(p1x, p1y, p2x, p2y) return } func kdfForExch(digest hash.Hash, ux, uy *big.Int, za, zb []byte, keyBits int) []byte { bufSize := 4 if bufSize < digest.BlockSize() { bufSize = digest.BlockSize() } buf := make([]byte, bufSize) rv := make([]byte, (keyBits+7)/8) rvLen := len(rv) uxBytes := ux.Bytes() uyBytes := uy.Bytes() off := 0 ct := uint32(0) for off < rvLen { digest.Reset() digest.Write(uxBytes) digest.Write(uyBytes) digest.Write(za) digest.Write(zb) ct++ binary.BigEndian.PutUint32(buf, ct) digest.Write(buf[:4]) tmp := digest.Sum(nil) copy(buf[:bufSize], tmp[:bufSize]) copyLen := rvLen - off copy(rv[off:off+copyLen], buf[:copyLen]) off += copyLen } return rv } func calculateInnerHash(digest hash.Hash, ux *big.Int, za, zb []byte, p1x, p1y *big.Int, p2x, p2y *big.Int) []byte { digest.Reset() digest.Write(ux.Bytes()) digest.Write(za) digest.Write(zb) digest.Write(p1x.Bytes()) digest.Write(p1y.Bytes()) digest.Write(p2x.Bytes()) digest.Write(p2y.Bytes()) return digest.Sum(nil) } func s1(digest hash.Hash, uy *big.Int, innerHash []byte) []byte { digest.Reset() 
digest.Write([]byte{0x02}) digest.Write(uy.Bytes()) digest.Write(innerHash) return digest.Sum(nil) } func s2(digest hash.Hash, uy *big.Int, innerHash []byte) []byte { digest.Reset() digest.Write([]byte{0x03}) digest.Write(uy.Bytes()) digest.Write(innerHash) return digest.Sum(nil) } func CalculateKeyWithConfirmation(initiator bool, keyBits int, confirmationTag []byte, selfStaticPriv *PrivateKey, selfEphemeralPriv *PrivateKey, selfId []byte, otherStaticPub *PublicKey, otherEphemeralPub *PublicKey, otherId []byte) (*ExchangeResult, error) { if selfId == nil { selfId = make([]byte, 0) } if otherId == nil { otherId = make([]byte, 0) } if initiator && confirmationTag == nil { return nil, errors.New("if initiating, confirmationTag must be set") } selfStaticPub := CalculatePubKey(selfStaticPriv) digest := sm3.New() za := getZ(digest, &selfStaticPriv.Curve, selfStaticPub.X, selfStaticPub.Y, selfId) zb := getZ(digest, &selfStaticPriv.Curve, otherStaticPub.X, otherStaticPub.Y, otherId) w := selfStaticPriv.Curve.BitSize/2 - 1 selfEphemeralPub := CalculatePubKey(selfEphemeralPriv) ux, uy := calculateU(w, selfStaticPriv, selfEphemeralPriv, selfEphemeralPub, otherStaticPub, otherEphemeralPub) if initiator { rv := kdfForExch(digest, ux, uy, za, zb, keyBits) innerHash := calculateInnerHash(digest, ux, za, zb, selfEphemeralPub.X, selfEphemeralPub.Y, otherEphemeralPub.X, otherEphemeralPub.Y) s1 := s1(digest, uy, innerHash) if !bytes.Equal(s1, confirmationTag) { return nil, errors.New("confirmation tag mismatch") } s2 := s2(digest, uy, innerHash) return &ExchangeResult{Key: rv, S2: s2}, nil } else { rv := kdfForExch(digest, ux, uy, zb, za, keyBits) innerHash := calculateInnerHash(digest, ux, zb, za, otherEphemeralPub.X, otherEphemeralPub.Y, selfEphemeralPub.X, selfEphemeralPub.Y) s1 := s1(digest, uy, innerHash) s2 := s2(digest, uy, innerHash) return &ExchangeResult{Key: rv, S1: s1, S2: s2}, nil } } func ResponderConfirm(responderS2 []byte, initiatorS2 []byte) bool { return 
bytes.Equal(responderS2, initiatorS2) } ================================================ FILE: sm2/keyexchange_test.go ================================================ package sm2 import ( "crypto/rand" "testing" ) const ( KeyBits = 128 ) var ( initiatorId = []byte("ABCDEFG1234") responderId = []byte("1234567ABCD") ) func TestSM2KeyExchange(t *testing.T) { initiatorStaticPriv, initiatorStaticPub, _ := GenerateKey(rand.Reader) initiatorEphemeralPriv, initiatorEphemeralPub, _ := GenerateKey(rand.Reader) responderStaticPriv, responderStaticPub, _ := GenerateKey(rand.Reader) responderEphemeralPriv, responderEphemeralPub, _ := GenerateKey(rand.Reader) responderResult, err := CalculateKeyWithConfirmation(false, KeyBits, nil, responderStaticPriv, responderEphemeralPriv, responderId, initiatorStaticPub, initiatorEphemeralPub, initiatorId) if err != nil { t.Error(err.Error()) return } initiatorResult, err := CalculateKeyWithConfirmation(true, KeyBits, responderResult.S1, initiatorStaticPriv, initiatorEphemeralPriv, initiatorId, responderStaticPub, responderEphemeralPub, responderId) if err != nil { t.Error(err.Error()) return } if !ResponderConfirm(responderResult.S2, initiatorResult.S2) { t.Error("responder confirm s2 failed") return } } ================================================ FILE: sm2/sm2.go ================================================ package sm2 import ( "bytes" "crypto/elliptic" "crypto/rand" "encoding/asn1" "encoding/binary" "errors" "fmt" "hash" "io" "math/big" "github.com/ZZMarquis/gm/sm3" "github.com/ZZMarquis/gm/util" ) const ( BitSize = 256 KeyBytes = (BitSize + 7) / 8 UnCompress = 0x04 ) type Sm2CipherTextType int32 const ( // 旧标准的密文顺序 C1C2C3 Sm2CipherTextType = 1 // [GM/T 0009-2012]标准规定的顺序 C1C3C2 Sm2CipherTextType = 2 ) var ( sm2H = new(big.Int).SetInt64(1) sm2SignDefaultUserId = []byte{ 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38} ) var sm2P256V1 P256V1Curve type P256V1Curve struct { 
*elliptic.CurveParams A *big.Int } type PublicKey struct { X, Y *big.Int Curve P256V1Curve } type PrivateKey struct { D *big.Int Curve P256V1Curve } type sm2Signature struct { R, S *big.Int } type sm2CipherC1C3C2 struct { X, Y *big.Int C3 []byte C2 []byte } type sm2CipherC1C2C3 struct { X, Y *big.Int C2 []byte C3 []byte } func init() { initSm2P256V1() } func initSm2P256V1() { sm2P, _ := new(big.Int).SetString("FFFFFFFEFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF00000000FFFFFFFFFFFFFFFF", 16) sm2A, _ := new(big.Int).SetString("FFFFFFFEFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF00000000FFFFFFFFFFFFFFFC", 16) sm2B, _ := new(big.Int).SetString("28E9FA9E9D9F5E344D5A9E4BCF6509A7F39789F515AB8F92DDBCBD414D940E93", 16) sm2N, _ := new(big.Int).SetString("FFFFFFFEFFFFFFFFFFFFFFFFFFFFFFFF7203DF6B21C6052B53BBF40939D54123", 16) sm2Gx, _ := new(big.Int).SetString("32C4AE2C1F1981195F9904466A39C9948FE30BBFF2660BE1715A4589334C74C7", 16) sm2Gy, _ := new(big.Int).SetString("BC3736A2F4F6779C59BDCEE36B692153D0A9877CC62A474002DF32E52139F0A0", 16) sm2P256V1.CurveParams = &elliptic.CurveParams{Name: "SM2-P-256-V1"} sm2P256V1.P = sm2P sm2P256V1.A = sm2A sm2P256V1.B = sm2B sm2P256V1.N = sm2N sm2P256V1.Gx = sm2Gx sm2P256V1.Gy = sm2Gy sm2P256V1.BitSize = BitSize } func GetSm2P256V1() P256V1Curve { return sm2P256V1 } func GenerateKey(rand io.Reader) (*PrivateKey, *PublicKey, error) { priv, x, y, err := elliptic.GenerateKey(sm2P256V1, rand) if err != nil { return nil, nil, err } privateKey := new(PrivateKey) privateKey.Curve = sm2P256V1 privateKey.D = new(big.Int).SetBytes(priv) publicKey := new(PublicKey) publicKey.Curve = sm2P256V1 publicKey.X = x publicKey.Y = y return privateKey, publicKey, nil } func RawBytesToPublicKey(bytes []byte) (*PublicKey, error) { if len(bytes) != KeyBytes*2 { return nil, errors.New(fmt.Sprintf("Public key raw bytes length must be %d", KeyBytes*2)) } publicKey := new(PublicKey) publicKey.Curve = sm2P256V1 publicKey.X = new(big.Int).SetBytes(bytes[:KeyBytes]) publicKey.Y = 
new(big.Int).SetBytes(bytes[KeyBytes:]) return publicKey, nil } func RawBytesToPrivateKey(bytes []byte) (*PrivateKey, error) { if len(bytes) != KeyBytes { return nil, errors.New(fmt.Sprintf("Private key raw bytes length must be %d", KeyBytes)) } privateKey := new(PrivateKey) privateKey.Curve = sm2P256V1 privateKey.D = new(big.Int).SetBytes(bytes) return privateKey, nil } func (pub *PublicKey) GetUnCompressBytes() []byte { xBytes := bigIntTo32Bytes(pub.X) yBytes := bigIntTo32Bytes(pub.Y) xl := len(xBytes) yl := len(yBytes) raw := make([]byte, 1+KeyBytes*2) raw[0] = UnCompress if xl > KeyBytes { copy(raw[1:1+KeyBytes], xBytes[xl-KeyBytes:]) } else if xl < KeyBytes { copy(raw[1+(KeyBytes-xl):1+KeyBytes], xBytes) } else { copy(raw[1:1+KeyBytes], xBytes) } if yl > KeyBytes { copy(raw[1+KeyBytes:], yBytes[yl-KeyBytes:]) } else if yl < KeyBytes { copy(raw[1+KeyBytes+(KeyBytes-yl):], yBytes) } else { copy(raw[1+KeyBytes:], yBytes) } return raw } func (pub *PublicKey) GetRawBytes() []byte { raw := pub.GetUnCompressBytes() return raw[1:] } func (pri *PrivateKey) GetRawBytes() []byte { dBytes := bigIntTo32Bytes(pri.D) dl := len(dBytes) if dl > KeyBytes { raw := make([]byte, KeyBytes) copy(raw, dBytes[dl-KeyBytes:]) return raw } else if dl < KeyBytes { raw := make([]byte, KeyBytes) copy(raw[KeyBytes-dl:], dBytes) return raw } else { return dBytes } } func CalculatePubKey(priv *PrivateKey) *PublicKey { pub := new(PublicKey) pub.Curve = priv.Curve pub.X, pub.Y = priv.Curve.ScalarBaseMult(priv.D.Bytes()) return pub } func nextK(rnd io.Reader, max *big.Int) (*big.Int, error) { intOne := new(big.Int).SetInt64(1) var k *big.Int var err error for { k, err = rand.Int(rnd, max) if err != nil { return nil, err } if k.Cmp(intOne) >= 0 { return k, err } } } func xor(data []byte, kdfOut []byte, dRemaining int) { for i := 0; i != dRemaining; i++ { data[i] ^= kdfOut[i] } } // 表示SM2 Key的大数比较小时,直接通过Bytes()函数得到的字节数组可能不够32字节,这个时候要补齐成32字节 func bigIntTo32Bytes(bn *big.Int) []byte { byteArr := 
bn.Bytes() byteArrLen := len(byteArr) if byteArrLen == KeyBytes { return byteArr } byteArr = append(make([]byte, KeyBytes-byteArrLen), byteArr...) return byteArr } func kdf(digest hash.Hash, c1x *big.Int, c1y *big.Int, encData []byte) { bufSize := 4 if bufSize < digest.Size() { bufSize = digest.Size() } buf := make([]byte, bufSize) encDataLen := len(encData) c1xBytes := bigIntTo32Bytes(c1x) c1yBytes := bigIntTo32Bytes(c1y) off := 0 ct := uint32(0) for off < encDataLen { digest.Reset() digest.Write(c1xBytes) digest.Write(c1yBytes) ct++ binary.BigEndian.PutUint32(buf, ct) digest.Write(buf[:4]) tmp := digest.Sum(nil) copy(buf[:bufSize], tmp[:bufSize]) xorLen := encDataLen - off if xorLen > digest.Size() { xorLen = digest.Size() } xor(encData[off:], buf, xorLen) off += xorLen } } func notEncrypted(encData []byte, in []byte) bool { encDataLen := len(encData) for i := 0; i != encDataLen; i++ { if encData[i] != in[i] { return false } } return true } func Encrypt(pub *PublicKey, in []byte, cipherTextType Sm2CipherTextType) ([]byte, error) { c2 := make([]byte, len(in)) copy(c2, in) var c1 []byte digest := sm3.New() var kPBx, kPBy *big.Int for { k, err := nextK(rand.Reader, pub.Curve.N) if err != nil { return nil, err } kBytes := k.Bytes() c1x, c1y := pub.Curve.ScalarBaseMult(kBytes) c1 = elliptic.Marshal(pub.Curve, c1x, c1y) kPBx, kPBy = pub.Curve.ScalarMult(pub.X, pub.Y, kBytes) kdf(digest, kPBx, kPBy, c2) if !notEncrypted(c2, in) { break } } digest.Reset() digest.Write(bigIntTo32Bytes(kPBx)) digest.Write(in) digest.Write(bigIntTo32Bytes(kPBy)) c3 := digest.Sum(nil) c1Len := len(c1) c2Len := len(c2) c3Len := len(c3) result := make([]byte, c1Len+c2Len+c3Len) if cipherTextType == C1C2C3 { copy(result[:c1Len], c1) copy(result[c1Len:c1Len+c2Len], c2) copy(result[c1Len+c2Len:], c3) } else if cipherTextType == C1C3C2 { copy(result[:c1Len], c1) copy(result[c1Len:c1Len+c3Len], c3) copy(result[c1Len+c3Len:], c2) } else { return nil, errors.New("unknown cipherTextType:" + 
string(cipherTextType)) } return result, nil } func Decrypt(priv *PrivateKey, in []byte, cipherTextType Sm2CipherTextType) ([]byte, error) { c1Len := ((priv.Curve.BitSize+7)>>3)*2 + 1 c1 := make([]byte, c1Len) copy(c1, in[:c1Len]) c1x, c1y := elliptic.Unmarshal(priv.Curve, c1) sx, sy := priv.Curve.ScalarMult(c1x, c1y, sm2H.Bytes()) if util.IsEcPointInfinity(sx, sy) { return nil, errors.New("[h]C1 at infinity") } c1x, c1y = priv.Curve.ScalarMult(c1x, c1y, priv.D.Bytes()) digest := sm3.New() c3Len := digest.Size() c2Len := len(in) - c1Len - c3Len c2 := make([]byte, c2Len) c3 := make([]byte, c3Len) if cipherTextType == C1C2C3 { copy(c2, in[c1Len:c1Len+c2Len]) copy(c3, in[c1Len+c2Len:]) } else if cipherTextType == C1C3C2 { copy(c3, in[c1Len:c1Len+c3Len]) copy(c2, in[c1Len+c3Len:]) } else { return nil, errors.New("unknown cipherTextType:" + string(cipherTextType)) } kdf(digest, c1x, c1y, c2) digest.Reset() digest.Write(bigIntTo32Bytes(c1x)) digest.Write(c2) digest.Write(bigIntTo32Bytes(c1y)) newC3 := digest.Sum(nil) if !bytes.Equal(newC3, c3) { return nil, errors.New("invalid cipher text") } return c2, nil } func MarshalCipher(in []byte, cipherTextType Sm2CipherTextType) ([]byte, error) { byteLen := (sm2P256V1.Params().BitSize + 7) >> 3 c1x := make([]byte, byteLen) c1y := make([]byte, byteLen) c2Len := len(in) - (1 + byteLen*2) - sm3.DigestLength c2 := make([]byte, c2Len) c3 := make([]byte, sm3.DigestLength) pos := 1 copy(c1x, in[pos:pos+byteLen]) pos += byteLen copy(c1y, in[pos:pos+byteLen]) pos += byteLen nc1x := new(big.Int).SetBytes(c1x) nc1y := new(big.Int).SetBytes(c1y) if cipherTextType == C1C2C3 { copy(c2, in[pos:pos+c2Len]) pos += c2Len copy(c3, in[pos:pos+sm3.DigestLength]) result, err := asn1.Marshal(sm2CipherC1C2C3{nc1x, nc1y, c2, c3}) if err != nil { return nil, err } return result, nil } else if cipherTextType == C1C3C2 { copy(c3, in[pos:pos+sm3.DigestLength]) pos += sm3.DigestLength copy(c2, in[pos:pos+c2Len]) result, err := 
asn1.Marshal(sm2CipherC1C3C2{nc1x, nc1y, c3, c2}) if err != nil { return nil, err } return result, nil } else { return nil, errors.New("unknown cipherTextType:" + string(cipherTextType)) } } func UnmarshalCipher(in []byte, cipherTextType Sm2CipherTextType) (out []byte, err error) { if cipherTextType == C1C2C3 { cipher := new(sm2CipherC1C2C3) _, err = asn1.Unmarshal(in, cipher) if err != nil { return nil, err } c1xBytes := bigIntTo32Bytes(cipher.X) c1yBytes := bigIntTo32Bytes(cipher.Y) c1xLen := len(c1xBytes) c1yLen := len(c1yBytes) c2Len := len(cipher.C2) c3Len := len(cipher.C3) result := make([]byte, 1+c1xLen+c1yLen+c2Len+c3Len) pos := 0 result[pos] = UnCompress pos += 1 copy(result[pos:pos+c1xLen], c1xBytes) pos += c1xLen copy(result[pos:pos+c1yLen], c1yBytes) pos += c1yLen copy(result[pos:pos+c2Len], cipher.C2) pos += c2Len copy(result[pos:pos+c3Len], cipher.C3) return result, nil } else if cipherTextType == C1C3C2 { cipher := new(sm2CipherC1C3C2) _, err = asn1.Unmarshal(in, cipher) if err != nil { return nil, err } c1xBytes := bigIntTo32Bytes(cipher.X) c1yBytes := bigIntTo32Bytes(cipher.Y) c1xLen := len(c1xBytes) c1yLen := len(c1yBytes) c2Len := len(cipher.C2) c3Len := len(cipher.C3) result := make([]byte, 1+c1xLen+c1yLen+c2Len+c3Len) pos := 0 result[pos] = UnCompress pos += 1 copy(result[pos:pos+c1xLen], c1xBytes) pos += c1xLen copy(result[pos:pos+c1yLen], c1yBytes) pos += c1yLen copy(result[pos:pos+c3Len], cipher.C3) pos += c3Len copy(result[pos:pos+c2Len], cipher.C2) return result, nil } else { return nil, errors.New("unknown cipherTextType:" + string(cipherTextType)) } } func getZ(digest hash.Hash, curve *P256V1Curve, pubX *big.Int, pubY *big.Int, userId []byte) []byte { digest.Reset() userIdLen := uint16(len(userId) * 8) var userIdLenBytes [2]byte binary.BigEndian.PutUint16(userIdLenBytes[:], userIdLen) digest.Write(userIdLenBytes[:]) if userId != nil && len(userId) > 0 { digest.Write(userId) } digest.Write(bigIntTo32Bytes(curve.A)) 
digest.Write(bigIntTo32Bytes(curve.B)) digest.Write(bigIntTo32Bytes(curve.Gx)) digest.Write(bigIntTo32Bytes(curve.Gy)) digest.Write(bigIntTo32Bytes(pubX)) digest.Write(bigIntTo32Bytes(pubY)) return digest.Sum(nil) } func calculateE(digest hash.Hash, curve *P256V1Curve, pubX *big.Int, pubY *big.Int, userId []byte, src []byte) *big.Int { z := getZ(digest, curve, pubX, pubY, userId) digest.Reset() digest.Write(z) digest.Write(src) eHash := digest.Sum(nil) return new(big.Int).SetBytes(eHash) } func MarshalSign(r, s *big.Int) ([]byte, error) { result, err := asn1.Marshal(sm2Signature{r, s}) if err != nil { return nil, err } return result, nil } func UnmarshalSign(sign []byte) (r, s *big.Int, err error) { sm2Sign := new(sm2Signature) _, err = asn1.Unmarshal(sign, sm2Sign) if err != nil { return nil, nil, err } return sm2Sign.R, sm2Sign.S, nil } func SignToRS(priv *PrivateKey, userId []byte, in []byte) (r, s *big.Int, err error) { digest := sm3.New() pubX, pubY := priv.Curve.ScalarBaseMult(priv.D.Bytes()) if userId == nil { userId = sm2SignDefaultUserId } e := calculateE(digest, &priv.Curve, pubX, pubY, userId, in) intZero := new(big.Int).SetInt64(0) intOne := new(big.Int).SetInt64(1) for { var k *big.Int var err error for { k, err = nextK(rand.Reader, priv.Curve.N) if err != nil { return nil, nil, err } px, _ := priv.Curve.ScalarBaseMult(k.Bytes()) r = util.Add(e, px) r = util.Mod(r, priv.Curve.N) rk := new(big.Int).Set(r) rk = rk.Add(rk, k) if r.Cmp(intZero) != 0 && rk.Cmp(priv.Curve.N) != 0 { break } } dPlus1ModN := util.Add(priv.D, intOne) dPlus1ModN = util.ModInverse(dPlus1ModN, priv.Curve.N) s = util.Mul(r, priv.D) s = util.Sub(k, s) s = util.Mod(s, priv.Curve.N) s = util.Mul(dPlus1ModN, s) s = util.Mod(s, priv.Curve.N) if s.Cmp(intZero) != 0 { break } } return r, s, nil } // 签名结果为DER编码的字节数组 func Sign(priv *PrivateKey, userId []byte, in []byte) ([]byte, error) { r, s, err := SignToRS(priv, userId, in) if err != nil { return nil, err } return MarshalSign(r, s) } func 
VerifyByRS(pub *PublicKey, userId []byte, src []byte, r, s *big.Int) bool { intOne := new(big.Int).SetInt64(1) if r.Cmp(intOne) == -1 || r.Cmp(pub.Curve.N) >= 0 { return false } if s.Cmp(intOne) == -1 || s.Cmp(pub.Curve.N) >= 0 { return false } digest := sm3.New() if userId == nil { userId = sm2SignDefaultUserId } e := calculateE(digest, &pub.Curve, pub.X, pub.Y, userId, src) intZero := new(big.Int).SetInt64(0) t := util.Add(r, s) t = util.Mod(t, pub.Curve.N) if t.Cmp(intZero) == 0 { return false } sgx, sgy := pub.Curve.ScalarBaseMult(s.Bytes()) tpx, tpy := pub.Curve.ScalarMult(pub.X, pub.Y, t.Bytes()) x, y := pub.Curve.Add(sgx, sgy, tpx, tpy) if util.IsEcPointInfinity(x, y) { return false } expectedR := util.Add(e, x) expectedR = util.Mod(expectedR, pub.Curve.N) return expectedR.Cmp(r) == 0 } // 输入签名须为DER编码的字节数组 func Verify(pub *PublicKey, userId []byte, src []byte, sign []byte) bool { r, s, err := UnmarshalSign(sign) if err != nil { return false } return VerifyByRS(pub, userId, src, r, s) } ================================================ FILE: sm2/sm2_loop_test.go ================================================ package sm2 import ( "crypto/rand" "encoding/hex" "fmt" "math/big" "testing" ) const loopCount = 10 var loopTestSignData = []testSm2SignData{ { d: "5DD701828C424B84C5D56770ECF7C4FE882E654CAC53C7CC89A66B1709068B9D", x: "FF6712D3A7FC0D1B9E01FF471A87EA87525E47C7775039D19304E554DEFE0913", y: "F632025F692776D4C13470ECA36AC85D560E794E1BCCF53D82C015988E0EB956", in: "0102030405060708010203040506070801020304050607080102030405060708010203040506070801020304050607080102030405060708010203040506070830450220213C6CD6EBD6A4D5C2D0AB38E29D441836D1457A8118D34864C247D727831962022100D9248480342AC8513CCDF0F89A2250DC8F6EB4F2471E144E9A812E0AF497F801", }, } func TestSignVerifyLoop(t *testing.T) { priv := new(PrivateKey) priv.Curve = GetSm2P256V1() dBytes, _ := hex.DecodeString(loopTestSignData[0].d) priv.D = new(big.Int).SetBytes(dBytes) pub := new(PublicKey) pub.Curve = 
GetSm2P256V1() xBytes, _ := hex.DecodeString(loopTestSignData[0].x) yBytes, _ := hex.DecodeString(loopTestSignData[0].y) pub.X = new(big.Int).SetBytes(xBytes) pub.Y = new(big.Int).SetBytes(yBytes) for i := 0; i < loopCount; i++ { inBytes, _ := hex.DecodeString(loopTestSignData[0].in) sign, err := Sign(priv, nil, inBytes) if err != nil { t.Error(err.Error()) break } result := Verify(pub, nil, inBytes, sign) if !result { t.Error("verify failed") break } fmt.Printf("%d pass\n", i) } } func TestSignVerifyLoop2(t *testing.T) { for i := 0; i < loopCount; i++ { priv, pub, err := GenerateKey(rand.Reader) if err != nil { t.Error(err.Error()) break } inBytes, _ := hex.DecodeString(loopTestSignData[0].in) sign, err := Sign(priv, nil, inBytes) if err != nil { t.Error(err.Error()) break } result := Verify(pub, nil, inBytes, sign) if !result { t.Error("verify failed") break } fmt.Printf("%d pass\n", i) } } func TestSignVerifyLoop3(t *testing.T) { priv, pub, err := GenerateKey(rand.Reader) if err != nil { t.Error(err.Error()) return } for i := 0; i < loopCount; i++ { inBytes, _ := hex.DecodeString(loopTestSignData[0].in) sign, err := Sign(priv, nil, inBytes) if err != nil { t.Error(err.Error()) break } result := Verify(pub, nil, inBytes, sign) if !result { t.Error("verify failed") break } fmt.Printf("%d pass\n", i) } } ================================================ FILE: sm2/sm2_test.go ================================================ package sm2 import ( "bytes" "crypto/rand" "encoding/hex" "fmt" "math/big" "testing" ) func TestGetSm2P256V1(t *testing.T) { curve := GetSm2P256V1() fmt.Printf("P:%s\n", curve.Params().P.Text(16)) fmt.Printf("B:%s\n", curve.Params().B.Text(16)) fmt.Printf("N:%s\n", curve.Params().N.Text(16)) fmt.Printf("Gx:%s\n", curve.Params().Gx.Text(16)) fmt.Printf("Gy:%s\n", curve.Params().Gy.Text(16)) } func TestGenerateKey(t *testing.T) { priv, pub, err := GenerateKey(rand.Reader) if err != nil { t.Error(err.Error()) return } fmt.Printf("priv:%s\n", 
priv.D.Text(16)) fmt.Printf("x:%s\n", pub.X.Text(16)) fmt.Printf("y:%s\n", pub.Y.Text(16)) curve := GetSm2P256V1() if !curve.IsOnCurve(pub.X, pub.Y) { t.Error("x,y is not on Curve") return } fmt.Println("x,y is on sm2 Curve") } func TestEncryptDecrypt_C1C2C3(t *testing.T) { src := []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10} priv, pub, err := GenerateKey(rand.Reader) if err != nil { t.Error(err.Error()) return } fmt.Printf("d:%s\n", hex.EncodeToString(priv.D.Bytes())) fmt.Printf("x:%s\n", hex.EncodeToString(pub.X.Bytes())) fmt.Printf("y:%s\n", hex.EncodeToString(pub.Y.Bytes())) cipherText, err := Encrypt(pub, src, C1C2C3) if err != nil { t.Error(err.Error()) return } fmt.Printf("cipher text:%s\n", hex.EncodeToString(cipherText)) plainText, err := Decrypt(priv, cipherText, C1C2C3) if err != nil { t.Error(err.Error()) return } fmt.Printf("plain text:%s\n", hex.EncodeToString(plainText)) if !bytes.Equal(plainText, src) { t.Error("decrypt result not equal expected") return } } func TestEncryptDecrypt_C1C3C2(t *testing.T) { src := []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10} priv, pub, err := GenerateKey(rand.Reader) if err != nil { t.Error(err.Error()) return } fmt.Printf("d:%s\n", hex.EncodeToString(priv.D.Bytes())) fmt.Printf("x:%s\n", hex.EncodeToString(pub.X.Bytes())) fmt.Printf("y:%s\n", hex.EncodeToString(pub.Y.Bytes())) cipherText, err := Encrypt(pub, src, C1C3C2) if err != nil { t.Error(err.Error()) return } fmt.Printf("cipher text:%s\n", hex.EncodeToString(cipherText)) plainText, err := Decrypt(priv, cipherText, C1C3C2) if err != nil { t.Error(err.Error()) return } fmt.Printf("plain text:%s\n", hex.EncodeToString(plainText)) if !bytes.Equal(plainText, src) { t.Error("decrypt result not equal expected") return } } func TestCipherDerEncode_C1C2C3(t *testing.T) { src := []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10} priv, pub, err := GenerateKey(rand.Reader) if err != nil { t.Error(err.Error()) return } cipherText, err := Encrypt(pub, src, C1C2C3) if err != nil { t.Error(err.Error()) 
return } fmt.Printf("before DER encode, cipher text:%s\n", hex.EncodeToString(cipherText)) derCipher, err := MarshalCipher(cipherText, C1C2C3) if err != nil { t.Error(err.Error()) return } //err = ioutil.WriteFile("derCipher.dat", derCipher, 0644) //if err != nil { // t.Error(err.Error()) // return //} cipherText, err = UnmarshalCipher(derCipher, C1C2C3) if err != nil { t.Error(err.Error()) return } fmt.Printf("after DER decode, cipher text:%s\n", hex.EncodeToString(cipherText)) plainText, err := Decrypt(priv, cipherText, C1C2C3) if err != nil { t.Error(err.Error()) return } fmt.Printf("plain text:%s\n", hex.EncodeToString(plainText)) if !bytes.Equal(plainText, src) { t.Error("decrypt result not equal expected") return } } func TestCipherDerEncode_C1C3C2(t *testing.T) { src := []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10} priv, pub, err := GenerateKey(rand.Reader) if err != nil { t.Error(err.Error()) return } cipherText, err := Encrypt(pub, src, C1C3C2) if err != nil { t.Error(err.Error()) return } fmt.Printf("before DER encode, cipher text:%s\n", hex.EncodeToString(cipherText)) derCipher, err := MarshalCipher(cipherText, C1C3C2) if err != nil { t.Error(err.Error()) return } //err = ioutil.WriteFile("derCipher.dat", derCipher, 0644) //if err != nil { // t.Error(err.Error()) // return //} cipherText, err = UnmarshalCipher(derCipher, C1C3C2) if err != nil { t.Error(err.Error()) return } fmt.Printf("after DER decode, cipher text:%s\n", hex.EncodeToString(cipherText)) plainText, err := Decrypt(priv, cipherText, C1C3C2) if err != nil { t.Error(err.Error()) return } fmt.Printf("plain text:%s\n", hex.EncodeToString(plainText)) if !bytes.Equal(plainText, src) { t.Error("decrypt result not equal expected") return } } type testSm2SignData struct { d string x string y string in string sign string } var testSignData = []testSm2SignData{ { d: "5DD701828C424B84C5D56770ECF7C4FE882E654CAC53C7CC89A66B1709068B9D", x: "FF6712D3A7FC0D1B9E01FF471A87EA87525E47C7775039D19304E554DEFE0913", y: 
"F632025F692776D4C13470ECA36AC85D560E794E1BCCF53D82C015988E0EB956", in: "0102030405060708010203040506070801020304050607080102030405060708", sign: "30450220213C6CD6EBD6A4D5C2D0AB38E29D441836D1457A8118D34864C247D727831962022100D9248480342AC8513CCDF0F89A2250DC8F6EB4F2471E144E9A812E0AF497F801", }, } func TestSign(t *testing.T) { for _, data := range testSignData { priv := new(PrivateKey) priv.Curve = GetSm2P256V1() dBytes, _ := hex.DecodeString(data.d) priv.D = new(big.Int).SetBytes(dBytes) inBytes, _ := hex.DecodeString(data.in) sign, err := Sign(priv, nil, inBytes) if err != nil { t.Error(err.Error()) return } fmt.Printf("sign:%s\n", hex.EncodeToString(sign)) pub := new(PublicKey) pub.Curve = GetSm2P256V1() xBytes, _ := hex.DecodeString(data.x) yBytes, _ := hex.DecodeString(data.y) pub.X = new(big.Int).SetBytes(xBytes) pub.Y = new(big.Int).SetBytes(yBytes) result := Verify(pub, nil, inBytes, sign) if !result { t.Error("verify failed") return } } } func TestVerify(t *testing.T) { for _, data := range testSignData { pub := new(PublicKey) pub.Curve = GetSm2P256V1() xBytes, _ := hex.DecodeString(data.x) yBytes, _ := hex.DecodeString(data.y) pub.X = new(big.Int).SetBytes(xBytes) pub.Y = new(big.Int).SetBytes(yBytes) inBytes, _ := hex.DecodeString(data.in) sign, _ := hex.DecodeString(data.sign) result := Verify(pub, nil, inBytes, sign) if !result { t.Error("verify failed") return } } } ================================================ FILE: sm3/sm3.go ================================================ package sm3 import ( "encoding/binary" "fmt" "hash" "math/bits" ) const ( DigestLength = 32 BlockSize = 16 ) var gT = []uint32{ 0x79CC4519, 0xF3988A32, 0xE7311465, 0xCE6228CB, 0x9CC45197, 0x3988A32F, 0x7311465E, 0xE6228CBC, 0xCC451979, 0x988A32F3, 0x311465E7, 0x6228CBCE, 0xC451979C, 0x88A32F39, 0x11465E73, 0x228CBCE6, 0x9D8A7A87, 0x3B14F50F, 0x7629EA1E, 0xEC53D43C, 0xD8A7A879, 0xB14F50F3, 0x629EA1E7, 0xC53D43CE, 0x8A7A879D, 0x14F50F3B, 0x29EA1E76, 0x53D43CEC, 0xA7A879D8, 
0x4F50F3B1, 0x9EA1E762, 0x3D43CEC5, 0x7A879D8A, 0xF50F3B14, 0xEA1E7629, 0xD43CEC53, 0xA879D8A7, 0x50F3B14F, 0xA1E7629E, 0x43CEC53D, 0x879D8A7A, 0x0F3B14F5, 0x1E7629EA, 0x3CEC53D4, 0x79D8A7A8, 0xF3B14F50, 0xE7629EA1, 0xCEC53D43, 0x9D8A7A87, 0x3B14F50F, 0x7629EA1E, 0xEC53D43C, 0xD8A7A879, 0xB14F50F3, 0x629EA1E7, 0xC53D43CE, 0x8A7A879D, 0x14F50F3B, 0x29EA1E76, 0x53D43CEC, 0xA7A879D8, 0x4F50F3B1, 0x9EA1E762, 0x3D43CEC5} type sm3Digest struct { v [DigestLength / 4]uint32 inWords [BlockSize]uint32 xOff int32 w [68]uint32 xBuf [4]byte xBufOff int32 byteCount int64 } func New() hash.Hash { digest := new(sm3Digest) digest.Reset() return digest } func (digest *sm3Digest) Sum(b []byte) []byte { d1 := digest h := d1.checkSum() return append(b, h[:]...) } // Size returns the number of bytes Sum will return. func (digest *sm3Digest) Size() int { return DigestLength } // BlockSize returns the hash's underlying block size. // The Write method must be able to accept any amount // of data, but it may operate more efficiently if all writes // are a multiple of the block size. 
func (digest *sm3Digest) BlockSize() int { return BlockSize } func (digest *sm3Digest) Reset() { digest.byteCount = 0 digest.xBufOff = 0 for i := 0; i < len(digest.xBuf); i++ { digest.xBuf[i] = 0 } for i := 0; i < len(digest.inWords); i++ { digest.inWords[i] = 0 } for i := 0; i < len(digest.w); i++ { digest.w[i] = 0 } digest.v[0] = 0x7380166F digest.v[1] = 0x4914B2B9 digest.v[2] = 0x172442D7 digest.v[3] = 0xDA8A0600 digest.v[4] = 0xA96F30BC digest.v[5] = 0x163138AA digest.v[6] = 0xE38DEE4D digest.v[7] = 0xB0FB0E4E digest.xOff = 0 } func (digest *sm3Digest) Write(p []byte) (n int, err error) { _ = p[0] inLen := len(p) i := 0 if digest.xBufOff != 0 { for i < inLen { digest.xBuf[digest.xBufOff] = p[i] digest.xBufOff++ i++ if digest.xBufOff == 4 { digest.processWord(digest.xBuf[:], 0) digest.xBufOff = 0 break } } } limit := ((inLen - i) & ^3) + i for ; i < limit; i += 4 { digest.processWord(p, int32(i)) } for i < inLen { digest.xBuf[digest.xBufOff] = p[i] digest.xBufOff++ i++ } digest.byteCount += int64(inLen) n = inLen return } func (digest *sm3Digest) finish() { bitLength := digest.byteCount << 3 digest.Write([]byte{128}) for digest.xBufOff != 0 { digest.Write([]byte{0}) } digest.processLength(bitLength) digest.processBlock() } func (digest *sm3Digest) checkSum() [DigestLength]byte { digest.finish() vlen := len(digest.v) var out [DigestLength]byte for i := 0; i < vlen; i++ { binary.BigEndian.PutUint32(out[i*4:(i+1)*4], digest.v[i]) } return out } func (digest *sm3Digest) processBlock() { for j := 0; j < 16; j++ { digest.w[j] = digest.inWords[j] } for j := 16; j < 68; j++ { wj3 := digest.w[j-3] r15 := (wj3 << 15) | (wj3 >> (32 - 15)) wj13 := digest.w[j-13] r7 := (wj13 << 7) | (wj13 >> (32 - 7)) digest.w[j] = p1(digest.w[j-16]^digest.w[j-9]^r15) ^ r7 ^ digest.w[j-6] } A := digest.v[0] B := digest.v[1] C := digest.v[2] D := digest.v[3] E := digest.v[4] F := digest.v[5] G := digest.v[6] H := digest.v[7] for j := 0; j < 16; j++ { a12 := (A << 12) | (A >> (32 - 12)) s1 := 
a12 + E + gT[j] SS1 := (s1 << 7) | (s1 >> (32 - 7)) SS2 := SS1 ^ a12 Wj := digest.w[j] W1j := Wj ^ digest.w[j+4] TT1 := ff0(A, B, C) + D + SS2 + W1j TT2 := gg0(E, F, G) + H + SS1 + Wj D = C C = (B << 9) | (B >> (32 - 9)) B = A A = TT1 H = G G = (F << 19) | (F >> (32 - 19)) F = E E = p0(TT2) } for j := 16; j < 64; j++ { a12 := (A << 12) | (A >> (32 - 12)) s1 := a12 + E + gT[j] SS1 := (s1 << 7) | (s1 >> (32 - 7)) SS2 := SS1 ^ a12 Wj := digest.w[j] W1j := Wj ^ digest.w[j+4] TT1 := ff1(A, B, C) + D + SS2 + W1j TT2 := gg1(E, F, G) + H + SS1 + Wj D = C C = (B << 9) | (B >> (32 - 9)) B = A A = TT1 H = G G = (F << 19) | (F >> (32 - 19)) F = E E = p0(TT2) } digest.v[0] ^= A digest.v[1] ^= B digest.v[2] ^= C digest.v[3] ^= D digest.v[4] ^= E digest.v[5] ^= F digest.v[6] ^= G digest.v[7] ^= H digest.xOff = 0 } func (digest *sm3Digest) processWord(in []byte, inOff int32) { n := binary.BigEndian.Uint32(in[inOff : inOff+4]) digest.inWords[digest.xOff] = n digest.xOff++ if digest.xOff >= 16 { digest.processBlock() } } func (digest *sm3Digest) processLength(bitLength int64) { if digest.xOff > (BlockSize - 2) { digest.inWords[digest.xOff] = 0 digest.xOff++ digest.processBlock() } for ; digest.xOff < (BlockSize - 2); digest.xOff++ { digest.inWords[digest.xOff] = 0 } digest.inWords[digest.xOff] = uint32(bitLength >> 32) digest.xOff++ digest.inWords[digest.xOff] = uint32(bitLength) digest.xOff++ } func p0(x uint32) uint32 { r9 := bits.RotateLeft32(x, 9) r17 := bits.RotateLeft32(x, 17) return x ^ r9 ^ r17 } func p1(x uint32) uint32 { r15 := bits.RotateLeft32(x, 15) r23 := bits.RotateLeft32(x, 23) return x ^ r15 ^ r23 } func ff0(x uint32, y uint32, z uint32) uint32 { return x ^ y ^ z } func ff1(x uint32, y uint32, z uint32) uint32 { return (x & y) | (x & z) | (y & z) } func gg0(x uint32, y uint32, z uint32) uint32 { return x ^ y ^ z } func gg1(x uint32, y uint32, z uint32) uint32 { return (x & y) | ((^x) & z) } func Sum(data []byte) [DigestLength]byte { var d sm3Digest d.Reset() 
d.Write(data) return d.checkSum() } func PrintT() { var T [64]uint32 fmt.Print("{") for j := 0; j < 16; j++ { T[j] = 0x79CC4519 Tj := (T[j] << uint32(j)) | (T[j] >> (32 - uint32(j))) fmt.Printf("0x%08X, ", Tj) } for j := 16; j < 64; j++ { n := j % 32 T[j] = 0x7A879D8A Tj := (T[j] << uint32(n)) | (T[j] >> (32 - uint32(n))) if j == 63 { fmt.Printf("0x%08X}\n", Tj) } else { fmt.Printf("0x%08X, ", Tj) } } } ================================================ FILE: sm3/sm3_test.go ================================================ package sm3 import ( "bytes" "encoding/hex" "fmt" "testing" ) var testData = map[string]string{ "abc": "66c7f0f462eeedd9d1f2d46bdc10e4e24167c4875cf2f7a2297da02b8f4ba8e0", "abcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcd": "debe9ff92275b8a138604889c18e5a4d6fdb70e5387e5765293dcba39c0c5732"} func TestPrintT(t *testing.T) { PrintT() } func TestSum(t *testing.T) { for src, expected := range testData { testSum(t, src, expected) } } func TestSm3Digest_Sum(t *testing.T) { for src, expected := range testData { testSm3DigestSum(t, src, expected) } } func testSum(t *testing.T, src string, expected string) { hash := Sum([]byte(src)) hashHex := hex.EncodeToString(hash[:]) if hashHex != expected { t.Errorf("result:%s , not equal expected\n", hashHex) return } } func testSm3DigestSum(t *testing.T, src string, expected string) { d := New() d.Write([]byte(src)) hash := d.Sum(nil) hashHex := hex.EncodeToString(hash[:]) if hashHex != expected { t.Errorf("result:%s , not equal expected\n", hashHex) return } } func TestSm3Digest_Write(t *testing.T) { src1 := []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10} src2 := []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10} src3 := []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10} d := New() 
d.Write(src1) d.Write(src2) d.Write(src3) digest1 := d.Sum(nil) fmt.Printf("1 : %s\n", hex.EncodeToString(digest1)) d.Reset() d.Write(src1) d.Write(src2) d.Write(src3) digest2 := d.Sum(nil) fmt.Printf("2 : %s\n", hex.EncodeToString(digest2)) if !bytes.Equal(digest1, digest2) { t.Error("") return } } ================================================ FILE: sm4/sm4.go ================================================ package sm4 import ( "crypto/cipher" "encoding/binary" "errors" "math/bits" "strconv" ) const ( BlockSize = 16 KeySize = 16 ) var sBox = [256]byte{ 0xd6, 0x90, 0xe9, 0xfe, 0xcc, 0xe1, 0x3d, 0xb7, 0x16, 0xb6, 0x14, 0xc2, 0x28, 0xfb, 0x2c, 0x05, 0x2b, 0x67, 0x9a, 0x76, 0x2a, 0xbe, 0x04, 0xc3, 0xaa, 0x44, 0x13, 0x26, 0x49, 0x86, 0x06, 0x99, 0x9c, 0x42, 0x50, 0xf4, 0x91, 0xef, 0x98, 0x7a, 0x33, 0x54, 0x0b, 0x43, 0xed, 0xcf, 0xac, 0x62, 0xe4, 0xb3, 0x1c, 0xa9, 0xc9, 0x08, 0xe8, 0x95, 0x80, 0xdf, 0x94, 0xfa, 0x75, 0x8f, 0x3f, 0xa6, 0x47, 0x07, 0xa7, 0xfc, 0xf3, 0x73, 0x17, 0xba, 0x83, 0x59, 0x3c, 0x19, 0xe6, 0x85, 0x4f, 0xa8, 0x68, 0x6b, 0x81, 0xb2, 0x71, 0x64, 0xda, 0x8b, 0xf8, 0xeb, 0x0f, 0x4b, 0x70, 0x56, 0x9d, 0x35, 0x1e, 0x24, 0x0e, 0x5e, 0x63, 0x58, 0xd1, 0xa2, 0x25, 0x22, 0x7c, 0x3b, 0x01, 0x21, 0x78, 0x87, 0xd4, 0x00, 0x46, 0x57, 0x9f, 0xd3, 0x27, 0x52, 0x4c, 0x36, 0x02, 0xe7, 0xa0, 0xc4, 0xc8, 0x9e, 0xea, 0xbf, 0x8a, 0xd2, 0x40, 0xc7, 0x38, 0xb5, 0xa3, 0xf7, 0xf2, 0xce, 0xf9, 0x61, 0x15, 0xa1, 0xe0, 0xae, 0x5d, 0xa4, 0x9b, 0x34, 0x1a, 0x55, 0xad, 0x93, 0x32, 0x30, 0xf5, 0x8c, 0xb1, 0xe3, 0x1d, 0xf6, 0xe2, 0x2e, 0x82, 0x66, 0xca, 0x60, 0xc0, 0x29, 0x23, 0xab, 0x0d, 0x53, 0x4e, 0x6f, 0xd5, 0xdb, 0x37, 0x45, 0xde, 0xfd, 0x8e, 0x2f, 0x03, 0xff, 0x6a, 0x72, 0x6d, 0x6c, 0x5b, 0x51, 0x8d, 0x1b, 0xaf, 0x92, 0xbb, 0xdd, 0xbc, 0x7f, 0x11, 0xd9, 0x5c, 0x41, 0x1f, 0x10, 0x5a, 0xd8, 0x0a, 0xc1, 0x31, 0x88, 0xa5, 0xcd, 0x7b, 0xbd, 0x2d, 0x74, 0xd0, 0x12, 0xb8, 0xe5, 0xb4, 0xb0, 0x89, 0x69, 0x97, 0x4a, 0x0c, 0x96, 0x77, 0x7e, 0x65, 0xb9, 0xf1, 0x09, 0xc5, 0x6e, 0xc6, 
0x84, 0x18, 0xf0, 0x7d, 0xec, 0x3a, 0xdc, 0x4d, 0x20, 0x79, 0xee, 0x5f, 0x3e, 0xd7, 0xcb, 0x39, 0x48, } var cK = [32]uint32{ 0x00070e15, 0x1c232a31, 0x383f464d, 0x545b6269, 0x70777e85, 0x8c939aa1, 0xa8afb6bd, 0xc4cbd2d9, 0xe0e7eef5, 0xfc030a11, 0x181f262d, 0x343b4249, 0x50575e65, 0x6c737a81, 0x888f969d, 0xa4abb2b9, 0xc0c7ced5, 0xdce3eaf1, 0xf8ff060d, 0x141b2229, 0x30373e45, 0x4c535a61, 0x686f767d, 0x848b9299, 0xa0a7aeb5, 0xbcc3cad1, 0xd8dfe6ed, 0xf4fb0209, 0x10171e25, 0x2c333a41, 0x484f565d, 0x646b7279, } var fK = [4]uint32{ 0xa3b1bac6, 0x56aa3350, 0x677d9197, 0xb27022dc, } type KeySizeError int func (k KeySizeError) Error() string { return "sm4: invalid key size " + strconv.Itoa(int(k)) } type sm4Cipher struct { enc []uint32 dec []uint32 } func NewCipher(key []byte) (cipher.Block, error) { n := len(key) if n != KeySize { return nil, KeySizeError(n) } c := new(sm4Cipher) c.enc = expandKey(key, true) c.dec = expandKey(key, false) return c, nil } func (c *sm4Cipher) BlockSize() int { return BlockSize } func (c *sm4Cipher) Encrypt(dst, src []byte) { if len(src) < BlockSize { panic("sm4: input not full block") } if len(dst) < BlockSize { panic("sm4: output not full block") } processBlock(c.enc, src, dst) } func (c *sm4Cipher) Decrypt(dst, src []byte) { if len(src) < BlockSize { panic("sm4: input not full block") } if len(dst) < BlockSize { panic("sm4: output not full block") } processBlock(c.dec, src, dst) } func expandKey(key []byte, forEnc bool) []uint32 { var mK [4]uint32 mK[0] = binary.BigEndian.Uint32(key[0:4]) mK[1] = binary.BigEndian.Uint32(key[4:8]) mK[2] = binary.BigEndian.Uint32(key[8:12]) mK[3] = binary.BigEndian.Uint32(key[12:16]) var x [5]uint32 x[0] = mK[0] ^ fK[0] x[1] = mK[1] ^ fK[1] x[2] = mK[2] ^ fK[2] x[3] = mK[3] ^ fK[3] var rk [32]uint32 if forEnc { for i := 0; i < 32; i++ { x[(i+4)%5] = encRound(x[i%5], x[(i+1)%5], x[(i+2)%5], x[(i+3)%5], x[(i+4)%5], rk[:], i) } } else { for i := 0; i < 32; i++ { x[(i+4)%5] = decRound(x[i%5], x[(i+1)%5], 
x[(i+2)%5], x[(i+3)%5], x[(i+4)%5], rk[:], i) } } return rk[:] } func tau(a uint32) uint32 { var aArr [4]byte var bArr [4]byte binary.BigEndian.PutUint32(aArr[:], a) bArr[0] = sBox[aArr[0]] bArr[1] = sBox[aArr[1]] bArr[2] = sBox[aArr[2]] bArr[3] = sBox[aArr[3]] return binary.BigEndian.Uint32(bArr[:]) } func lAp(b uint32) uint32 { return b ^ bits.RotateLeft32(b, 13) ^ bits.RotateLeft32(b, 23) } func tAp(z uint32) uint32 { return lAp(tau(z)) } func encRound(x0 uint32, x1 uint32, x2 uint32, x3 uint32, x4 uint32, rk []uint32, i int) uint32 { x4 = x0 ^ tAp(x1^x2^x3^cK[i]) rk[i] = x4 return x4 } func decRound(x0 uint32, x1 uint32, x2 uint32, x3 uint32, x4 uint32, rk []uint32, i int) uint32 { x4 = x0 ^ tAp(x1^x2^x3^cK[i]) rk[31-i] = x4 return x4 } func processBlock(rk []uint32, in []byte, out []byte) { var x [BlockSize / 4]uint32 x[0] = binary.BigEndian.Uint32(in[0:4]) x[1] = binary.BigEndian.Uint32(in[4:8]) x[2] = binary.BigEndian.Uint32(in[8:12]) x[3] = binary.BigEndian.Uint32(in[12:16]) for i := 0; i < 32; i += 4 { x[0] = f0(x[:], rk[i]) x[1] = f1(x[:], rk[i+1]) x[2] = f2(x[:], rk[i+2]) x[3] = f3(x[:], rk[i+3]) } r(x[:]) binary.BigEndian.PutUint32(out[0:4], x[0]) binary.BigEndian.PutUint32(out[4:8], x[1]) binary.BigEndian.PutUint32(out[8:12], x[2]) binary.BigEndian.PutUint32(out[12:16], x[3]) } func l(b uint32) uint32 { return b ^ bits.RotateLeft32(b, 2) ^ bits.RotateLeft32(b, 10) ^ bits.RotateLeft32(b, 18) ^ bits.RotateLeft32(b, 24) } func t(z uint32) uint32 { return l(tau(z)) } func r(a []uint32) { a[0] = a[0] ^ a[3] a[3] = a[0] ^ a[3] a[0] = a[0] ^ a[3] a[1] = a[1] ^ a[2] a[2] = a[1] ^ a[2] a[1] = a[1] ^ a[2] } func f0(x []uint32, rk uint32) uint32 { return x[0] ^ t(x[1]^x[2]^x[3]^rk) } func f1(x []uint32, rk uint32) uint32 { return x[1] ^ t(x[2]^x[3]^x[0]^rk) } func f2(x []uint32, rk uint32) uint32 { return x[2] ^ t(x[3]^x[0]^x[1]^rk) } func f3(x []uint32, rk uint32) uint32 { return x[3] ^ t(x[0]^x[1]^x[2]^rk) } // 
输入的plainText长度必须是BlockSize(16)的整数倍,也就是调用该方法前调用方需先加好padding, // 可调用util.PKCS5Padding()方法进行加padding操作 func ECBEncrypt(key, plainText []byte) (cipherText []byte, err error) { plainTextLen := len(plainText) if plainTextLen%BlockSize != 0 { return nil, errors.New("input not full blocks") } c, err := NewCipher(key) if err != nil { return nil, err } cipherText = make([]byte, plainTextLen) for i := 0; i < plainTextLen; i += BlockSize { c.Encrypt(cipherText[i:i+BlockSize], plainText[i:i+BlockSize]) } return cipherText, nil } // 输出的plainText是加padding的明文,调用方需要自己去padding, // 可调用util.PKCS5UnPadding()方法进行去padding操作 func ECBDecrypt(key, cipherText []byte) (plainText []byte, err error) { cipherTextLen := len(cipherText) if cipherTextLen%BlockSize != 0 { return nil, errors.New("input not full blocks") } c, err := NewCipher(key) if err != nil { return nil, err } plainText = make([]byte, cipherTextLen) for i := 0; i < cipherTextLen; i += BlockSize { c.Decrypt(plainText[i:i+BlockSize], cipherText[i:i+BlockSize]) } return plainText, nil } // 输入的plainText长度必须是BlockSize(16)的整数倍,也就是调用该方法前调用方需先加好padding, // 可调用util.PKCS5Padding()方法进行加padding操作 func CBCEncrypt(key, iv, plainText []byte) (cipherText []byte, err error) { plainTextLen := len(plainText) if plainTextLen%BlockSize != 0 { return nil, errors.New("input not full blocks") } c, err := NewCipher(key) if err != nil { return nil, err } encrypter := cipher.NewCBCEncrypter(c, iv) cipherText = make([]byte, plainTextLen) encrypter.CryptBlocks(cipherText, plainText) return cipherText, nil } // 输出的plainText是加padding的明文,调用方需要自己去padding, // 可调用util.PKCS5UnPadding()方法进行去padding操作 func CBCDecrypt(key, iv, cipherText []byte) (plainText []byte, err error) { cipherTextLen := len(cipherText) if cipherTextLen%BlockSize != 0 { return nil, errors.New("input not full blocks") } c, err := NewCipher(key) if err != nil { return nil, err } decrypter := cipher.NewCBCDecrypter(c, iv) plainText = make([]byte, len(cipherText)) decrypter.CryptBlocks(plainText, 
cipherText) return plainText, nil } ================================================ FILE: sm4/sm4_test.go ================================================ package sm4 import ( "bytes" "encoding/hex" "fmt" "github.com/ZZMarquis/gm/util" "testing" ) type sm4CbcTestData struct { key []byte iv []byte in []byte out []byte } var cbcTestData = []sm4CbcTestData{ { key: []byte{0x7b, 0xea, 0x0a, 0xa5, 0x45, 0x8e, 0xd1, 0xa3, 0x7d, 0xb1, 0x65, 0x2e, 0xfb, 0xc5, 0x95, 0x05}, iv: []byte{0x70, 0xb6, 0xe0, 0x8d, 0x46, 0xee, 0x82, 0x24, 0x45, 0x60, 0x0b, 0x25, 0xc4, 0x71, 0xfa, 0xba}, in: []byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08}, out: []byte{0xca, 0x55, 0xc5, 0x15, 0x0b, 0xf7, 0xf4, 0x6f, 0xc9, 0x89, 0x2a, 0xce, 0x49, 0x78, 0x93, 0x03}, }, { key: []byte{0x7b, 0xea, 0x0a, 0xa5, 0x45, 0x8e, 0xd1, 0xa3, 0x7d, 0xb1, 0x65, 0x2e, 0xfb, 0xc5, 0x95, 0x05}, iv: []byte{0x70, 0xb6, 0xe0, 0x8d, 0x46, 0xee, 0x82, 0x24, 0x45, 0x60, 0x0b, 0x25, 0xc4, 0x71, 0xfa, 0xba}, in: []byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08}, out: []byte{0x95, 0xe1, 0xec, 0x3b, 0x56, 0x4a, 0x46, 0x71, 0xe7, 0xd6, 0xb1, 0x10, 0xe9, 0x09, 0x0b, 0x1b, 0xb7, 0xb5, 0x9e, 0x8d, 0x74, 0x47, 0x1e, 0x70, 0x86, 0x04, 0x6b, 0xe8, 0x78, 0x00, 0x45, 0x32}, }, } func TestSm4_CBC_Encrypt(t *testing.T) { for _, data := range cbcTestData { fmt.Printf("Key:%s\n", hex.EncodeToString(data.key)) fmt.Printf("IV:%s\n", hex.EncodeToString(data.iv)) cipherText, err := CBCEncrypt(data.key, data.iv, util.PKCS5Padding(data.in, BlockSize)) if err != nil { t.Error(err.Error()) return } fmt.Printf("encrypt cipherText:%s\n", hex.EncodeToString(cipherText)) if !bytes.Equal(cipherText, data.out) { t.Error("encrypt cipherText not equal expected") return } plainTextWithPadding, err := CBCDecrypt(data.key, data.iv, cipherText) if err != nil { t.Error(err.Error()) return } fmt.Printf("decrypt cipherText:%s\n", 
hex.EncodeToString(plainTextWithPadding)) plainText := util.PKCS5UnPadding(plainTextWithPadding) if !bytes.Equal(plainText, data.in) { t.Error("decrypt cipherText not equal expected") return } } } type sm4EcbTestData struct { key []byte in []byte } var ecbTestData = []sm4EcbTestData{ { key: []byte("1234567890123456"), in: []byte("ssssssss"), }, { key: []byte("1234567890123456"), in: []byte("ssssssssssssssss"), }, { key: []byte("1234567890123456"), in: []byte("ssssssssssssssssssssssss"), }, } func TestSm4_ECB_Encrypt_PKCS5Padding(t *testing.T) { for _, testData := range ecbTestData { plainTextWithPadding := util.PKCS5Padding(testData.in, BlockSize) cipherText, err := ECBEncrypt(testData.key, plainTextWithPadding) if err != nil { t.Error(err.Error()) return } fmt.Printf("%x\n", cipherText) plainTextWithPadding, err = ECBDecrypt(testData.key, cipherText) if err != nil { t.Error(err.Error()) return } plainText := util.PKCS5UnPadding(plainTextWithPadding) fmt.Println(string(plainText)) if !bytes.Equal(testData.in, plainText) { t.Error("decrypt result not equal expected") return } } } func TestSm4_ECB_Encrypt_ZeroPadding(t *testing.T) { for _, testData := range ecbTestData { plainTextWithPadding := util.ZeroPadding(testData.in, BlockSize) paddingLen := len(plainTextWithPadding) - len(testData.in) cipherText, err := ECBEncrypt(testData.key, plainTextWithPadding) if err != nil { t.Error(err.Error()) return } fmt.Printf("%x\n", cipherText) plainTextWithPadding, err = ECBDecrypt(testData.key, cipherText) if err != nil { t.Error(err.Error()) return } plainText := util.UnZeroPadding(plainTextWithPadding, paddingLen) fmt.Println(string(plainText)) if !bytes.Equal(testData.in, plainText) { t.Error("decrypt result not equal expected") return } } } ================================================ FILE: util/bigint.go ================================================ package util import "math/big" func Add(x, y *big.Int) *big.Int { var z big.Int z.Add(x, y) return &z } func Sub(x, y 
*big.Int) *big.Int { var z big.Int z.Sub(x, y) return &z } func Mod(x, y *big.Int) *big.Int { var z big.Int z.Mod(x, y) return &z } func ModInverse(x, y *big.Int) *big.Int { var z big.Int z.ModInverse(x, y) return &z } func Mul(x, y *big.Int) *big.Int { var z big.Int z.Mul(x, y) return &z } func Lsh(x *big.Int, n uint) *big.Int { var z big.Int z.Lsh(x, n) return &z } func SetBit(x *big.Int, i int, b uint) *big.Int { var z big.Int z.SetBit(x, i, b) return &z } func And(x, y *big.Int) *big.Int { var z big.Int z.And(x, y) return &z } ================================================ FILE: util/bigint_test.go ================================================ package util import ( "fmt" "math/big" "testing" ) func TestAdd(t *testing.T) { a := new(big.Int).SetInt64(1) b := new(big.Int).SetInt64(1) a.Add(a, b) fmt.Printf("a:%s\n", a.Text(10)) fmt.Printf("b:%s\n", b.Text(10)) } func TestAdd2(t *testing.T) { a := new(big.Int).SetInt64(1) b := new(big.Int).SetInt64(1) z := Add(a, b) fmt.Printf("a:%s\n", a.Text(10)) fmt.Printf("b:%s\n", b.Text(10)) fmt.Printf("z:%s\n", z.Text(10)) } ================================================ FILE: util/ec.go ================================================ package util import "math/big" func IsEcPointInfinity(x, y *big.Int) bool { if x.Sign() == 0 && y.Sign() == 0 { return true } return false } func ZForAffine(x, y *big.Int) *big.Int { z := new(big.Int) if x.Sign() != 0 || y.Sign() != 0 { z.SetInt64(1) } return z } ================================================ FILE: util/padding.go ================================================ package util import "bytes" func PKCS5Padding(src []byte, blockSize int) []byte { padding := blockSize - len(src)%blockSize padtext := bytes.Repeat([]byte{byte(padding)}, padding) return append(src, padtext...) 
} func PKCS5UnPadding(src []byte) []byte { length := len(src) unpadding := int(src[length-1]) return src[:(length - unpadding)] } func ZeroPadding(src []byte, blockSize int) []byte { padding := blockSize - len(src)%blockSize padtext := bytes.Repeat([]byte{0}, padding) return append(src, padtext...) } // UnZeroPadding // 由于原文最后一个或若干个字节就有可能为0,所以大多情况下不能简单粗暴地后面有几个0就去掉几个0,除非可以确定最后一个字节肯定不为0. // 所以需要用户自己去指定具体要去掉末尾几个字节,具体要看用户的自己的数据协议怎么设计。 func UnZeroPadding(src []byte, paddingLen int) []byte { return src[:len(src)-paddingLen] }