pki_test.go 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277
  1. package pki_test
  2. import (
  3. "crypto/x509"
  4. "encoding/pem"
  5. "fmt"
  6. "math/big"
  7. "math/rand"
  8. "testing"
  9. "time"
  10. "github.com/cad/ovpm/pki"
  11. )
  12. func TestNewCA(t *testing.T) {
  13. // Initialize:
  14. // Prepare:
  15. ca, err := pki.NewCA()
  16. if err != nil {
  17. t.Fatalf("can not create CA in test: %v", err)
  18. }
  19. // Test:
  20. // Is CertHolder empty?
  21. if ca.CertHolder == (pki.CertHolder{}) {
  22. t.Errorf("returned ca.CertHolder can't be empty: %+v", ca.CertHolder)
  23. }
  24. // Is CSR empty length?
  25. if len(ca.CSR) == 0 {
  26. t.Errorf("returned ca.CSR is a zero-length string")
  27. }
  28. var encodingtests = []struct {
  29. name string // name
  30. block string // pem block string
  31. typ string // expected pem block type
  32. }{
  33. {"ca.CSR", ca.CSR, pki.PEMCSRBlockType},
  34. {"ca.CertHolder.Cert", ca.CertHolder.Cert, pki.PEMCertificateBlockType},
  35. {"ca.CertHolder.Key", ca.CertHolder.Key, pki.PEMRSAPrivateKeyBlockType},
  36. }
  37. // Is PEM encoded properly?
  38. for _, tt := range encodingtests {
  39. if !isPEMEncodedProperly(t, tt.block, tt.typ) {
  40. t.Errorf("returned '%s' is not PEM encoded properly: %+v", tt.name, tt.block)
  41. }
  42. }
  43. }
  44. // TestNewCertHolders tests pki.NewServerCertHolder and pki.NewClientCertHolder functions.
  45. func TestNewCertHolders(t *testing.T) {
  46. // Initialize:
  47. ca, _ := pki.NewCA()
  48. // Prepare:
  49. sch, err := pki.NewServerCertHolder(ca)
  50. if err != nil {
  51. t.Fatalf("can not create server cert holder: %v", err)
  52. }
  53. cch, err := pki.NewClientCertHolder(ca, "test-user")
  54. if err != nil {
  55. t.Fatalf("can not create client cert holder: %v", err)
  56. }
  57. // Test:
  58. var certholdertests = []struct {
  59. name string
  60. certHolder *pki.CertHolder
  61. }{
  62. {"server", sch},
  63. {"client", cch},
  64. }
  65. for _, tt := range certholdertests {
  66. // Is CertHolder empty?
  67. if *tt.certHolder == (pki.CertHolder{}) {
  68. t.Errorf("returned '%s' cert holder can't be empty: %+v", tt.name, sch)
  69. }
  70. var encodingtests = []struct {
  71. name string // name
  72. block string // pem block string
  73. typ string // expected pem block type
  74. }{
  75. {tt.name + "CertHolder.Cert", tt.certHolder.Cert, pki.PEMCertificateBlockType},
  76. {tt.name + "CertHolder.Key", tt.certHolder.Key, pki.PEMRSAPrivateKeyBlockType},
  77. }
  78. // Is PEM encoded properly?
  79. for _, tt := range encodingtests {
  80. if !isPEMEncodedProperly(t, tt.block, tt.typ) {
  81. t.Errorf("returned '%s' is not PEM encoded properly: %+v", tt.name, tt.block)
  82. }
  83. }
  84. }
  85. }
  86. func TestNewCRL(t *testing.T) {
  87. // Initialize:
  88. max := 5
  89. n := randomBetween(1, max)
  90. ca, _ := pki.NewCA()
  91. // Prepare:
  92. var certHolders []*pki.CertHolder
  93. for i := 0; i < max; i++ {
  94. username := fmt.Sprintf("user-%d", i)
  95. ch, _ := pki.NewClientCertHolder(ca, username)
  96. certHolders = append(certHolders, ch)
  97. }
  98. // Test:
  99. // Create CRL that revokes first n certificates.
  100. var serials []*big.Int
  101. for i := 0; i < n; i++ {
  102. serials = append(serials, getSerial(t, certHolders[i].Cert))
  103. }
  104. crl, err := pki.NewCRL(ca, serials...)
  105. if err != nil {
  106. t.Fatalf("crl can not be created: %v", err)
  107. }
  108. // Is CRL empty?
  109. if len(crl) == 0 {
  110. t.Fatalf("CRL length expected to be NOT EMPTY %+v", crl)
  111. }
  112. // Is CRL PEM encoded properly?
  113. if !isPEMEncodedProperly(t, crl, pki.PEMx509CRLBlockType) {
  114. t.Fatalf("CRL is expected to be properly PEM encoded %+v", crl)
  115. }
  116. // Parse CRL and get revoked certList.
  117. block, _ := pem.Decode([]byte(crl))
  118. certList, err := x509.ParseCRL(block.Bytes)
  119. if err != nil {
  120. t.Fatalf("CRL's PEM block is expected to be parsed '%+v' but instead it CAN'T BE PARSED: %v", block, err)
  121. }
  122. rcl := certList.TBSCertList.RevokedCertificates
  123. // Is revoked cert list length is n, as correctly?
  124. if len(rcl) != n {
  125. t.Fatalf("revoked cert list lenth is expected to be %d but it is %d", n, len(rcl))
  126. }
  127. // Is revoked certificate list is correct?
  128. for _, serial := range serials {
  129. found := false
  130. for _, rc := range rcl {
  131. //t.Logf("%d == %d", rc.SerialNumber, serial)
  132. if rc.SerialNumber.Cmp(serial) == 0 {
  133. found = true
  134. break
  135. }
  136. }
  137. if !found {
  138. t.Errorf("revoked serial '%d' is expected to be found in the generated CRL but it is NOT FOUND instead", serial)
  139. }
  140. }
  141. }
  142. func TestReadCertFromPEM(t *testing.T) {
  143. // Initialize:
  144. ca, _ := pki.NewCA()
  145. // Prepare:
  146. // Test:
  147. crt, err := pki.ReadCertFromPEM(ca.Cert)
  148. if err != nil {
  149. t.Fatalf("can not get cert from pem %+v", ca)
  150. }
  151. // Is crt nil?
  152. if crt == nil {
  153. t.Fatalf("cert is expected to be 'not nil' but it's 'nil' instead")
  154. }
  155. }
  156. // isPEMEncodedProperly takes an PEM encoded string s and the expected block type typ (e.g. "RSA PRIVATE KEY") and returns whether it can be decodable.
  157. func isPEMEncodedProperly(t *testing.T, s string, typ string) bool {
  158. block, _ := pem.Decode([]byte(s))
  159. if block == nil {
  160. t.Logf("block is nil")
  161. return false
  162. }
  163. if len(block.Bytes) == 0 {
  164. t.Logf("block bytes length is zero")
  165. return false
  166. }
  167. if block.Type != typ {
  168. t.Logf("expected block type '%s' but got '%s'", typ, block.Type)
  169. return false
  170. }
  171. switch block.Type {
  172. case pki.PEMCertificateBlockType:
  173. crt, err := x509.ParseCertificate(block.Bytes)
  174. if err != nil {
  175. t.Logf("certificate parse failed %+v: %v", block, err)
  176. return false
  177. }
  178. if crt == nil {
  179. t.Logf("couldn't parse certificate %+v", block)
  180. return false
  181. }
  182. case pki.PEMRSAPrivateKeyBlockType:
  183. key, err := x509.ParsePKCS1PrivateKey(block.Bytes)
  184. if err != nil {
  185. t.Logf("private key parse failed %+v: %v", block, err)
  186. return false
  187. }
  188. if key == nil {
  189. t.Logf("couldn't parse private key %+v", block)
  190. return false
  191. }
  192. case pki.PEMCSRBlockType:
  193. csr, err := x509.ParseCertificateRequest(block.Bytes)
  194. if err != nil {
  195. t.Logf("CSR parse failed %+v: %v", block, err)
  196. return false
  197. }
  198. if csr == nil {
  199. t.Logf("couldn't parse CSR %+v", block)
  200. return false
  201. }
  202. case pki.PEMx509CRLBlockType:
  203. crl, err := x509.ParseCRL(block.Bytes)
  204. if err != nil {
  205. t.Logf("CRL parse failed %+v: %v", block, err)
  206. return false
  207. }
  208. if crl == nil {
  209. t.Logf("couldn't parse crl %+v", block)
  210. return false
  211. }
  212. }
  213. return true
  214. }
  215. // getSerial returns serial number of a pem encoded certificate
  216. func getSerial(t *testing.T, crt string) *big.Int {
  217. // PEM decode.
  218. block, _ := pem.Decode([]byte(crt))
  219. if block == nil {
  220. t.Fatalf("block is nil %+v", block)
  221. }
  222. // Parse certificate.
  223. cert, err := x509.ParseCertificate(block.Bytes)
  224. if err != nil {
  225. t.Fatalf("certificate can not be parsed from block %+v: %v", block, err)
  226. }
  227. return cert.SerialNumber
  228. }
  229. // randomBetween returns a random int between min and max
  230. func randomBetween(min, max int) int {
  231. rand.Seed(time.Now().Unix())
  232. return rand.Intn(max-min) + min
  233. }