diff --git a/certtostore_windows.go b/certtostore_windows.go index b6c82eb..23c9182 100644 --- a/certtostore_windows.go +++ b/certtostore_windows.go @@ -25,6 +25,7 @@ import ( "crypto/rsa" "crypto/x509" "encoding/binary" + "encoding/hex" "errors" "fmt" "io" @@ -243,6 +244,12 @@ type pssPaddingInfo struct { cbSalt uint32 } +// cryptHashBlob is the CRYPT_HASH_BLOB struct in wincrypt.h. +type cryptHashBlob struct { + cbData uint32 + pbData *byte +} + // wide returns a pointer to a a uint16 representing the equivalent // to a Windows LPCWSTR. func wide(s string) *uint16 { @@ -1825,3 +1832,59 @@ func (w *WinCertStore) CertByCommonName(commonName string) (*x509.Certificate, } return nil, nil, nil, cryptENotFound } + +// CertBySHA1Hash searches for a certificate by its SHA1 hash in the store. +// The hash must be provided as a hex-encoded string, containing only hexadecimal +// characters. +// The returned *windows.CertContext must be freed by the caller using +// FreeCertContext to avoid resource leaks. +func (w *WinCertStore) CertBySHA1Hash(hash string) (*x509.Certificate, + *windows.CertContext, [][]*x509.Certificate, error) { + storeHandle, err := w.storeHandle(w.storeDomain(), my) + if err != nil { + return nil, nil, nil, fmt.Errorf("failed to open certificate store: %v", err) + } + + // Convert hex string to binary data + hashBytes, err := hex.DecodeString(hash) + if err != nil { + return nil, nil, nil, fmt.Errorf("invalid hex string: %v", err) + } + + // Create CRYPT_HASH_BLOB structure + hashBlob := cryptHashBlob{ + cbData: uint32(len(hashBytes)), + pbData: &hashBytes[0], + } + + // Find the certificate by its SHA1 hash, there can be only one so the `prev` context is NULL. + certContext, err := findCert( + storeHandle, + encodingX509ASN|encodingPKCS7, + 0, + windows.CERT_FIND_SHA1_HASH, + (*uint16)(unsafe.Pointer(&hashBlob)), + nil, + ) + if err != nil { + return nil, nil, nil, fmt.Errorf("could not find certificate by SHA1 hash %q: %w", + hash, err) + } + if certContext == nil { + return nil, nil, nil, cryptENotFound + } + + cert, err := certContextToX509(certContext) + if err != nil { + FreeCertContext(certContext) // Free context to avoid memory leak + return nil, nil, nil, err + } + + if err := w.resolveChains(certContext); err != nil { + FreeCertContext(certContext) // Free context to avoid memory leak + return nil, nil, nil, err + } + + // Found a valid certificate, return it. + return cert, certContext, w.certChains, nil +} diff --git a/certtostore_windows_test.go b/certtostore_windows_test.go index a8afe77..7e62226 100644 --- a/certtostore_windows_test.go +++ b/certtostore_windows_test.go @@ -17,8 +17,10 @@ package certtostore import ( "crypto/rand" "crypto/rsa" + "crypto/sha1" "crypto/x509" "crypto/x509/pkix" + "encoding/hex" "errors" "fmt" "math/big" @@ -676,3 +678,121 @@ func TestCertByCommonName(t *testing.T) { t.Errorf("chains[0][0] is not the leaf; got %v, want leaf %v", chains[0][0].Subject, found.Subject) } } + +func TestCertBySHA1Hash(t *testing.T) { + // Open a valid, writable current-user store. + opts := WinCertStoreOptions{ + Provider: ProviderMSSoftware, + Container: "TestContainerForSHA1Lookup", + Issuers: []string{"CN=Test CA"}, + IntermediateIssuers: []string{"CN=Intermediate CA"}, + LegacyKey: false, + CurrentUser: true, + StoreFlags: 0, + } + store, err := OpenWinCertStoreWithOptions(opts) + if err != nil { + t.Fatalf("failed to open store: %v", err) + } + defer store.Close() + + // Create a self-signed cert with a unique CN. + cn := fmt.Sprintf("__certtostore_%d__", time.Now().UnixNano()) + key, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatalf("rsa.GenerateKey: %v", err) + } + template := x509.Certificate{ + SerialNumber: big.NewInt(time.Now().UnixNano()), + Subject: pkix.Name{ + CommonName: cn, + }, + NotBefore: time.Now().Add(-1 * time.Minute), + NotAfter: time.Now().Add(5 * time.Minute), + KeyUsage: x509.KeyUsageDigitalSignature, + BasicConstraintsValid: true, + } + der, err := x509.CreateCertificate(rand.Reader, &template, &template, &key.PublicKey, key) + if err != nil { + t.Fatalf("x509.CreateCertificate: %v", err) + } + cert, err := x509.ParseCertificate(der) + if err != nil { + t.Fatalf("x509.ParseCertificate: %v", err) + } + + // Insert the cert directly into the current-user MY store to avoid private key association. + // Local constants for CertOpenStore. + const ( + certStoreProvSystem = 10 // CERT_STORE_PROV_SYSTEM + certSystemStoreCurrentUser = 1 << 16 + x509ASN = 1 // X509_ASN_ENCODING + pkcs7ASN = 65536 // PKCS_7_ASN_ENCODING + ) + myW, err := windows.UTF16PtrFromString("MY") + if err != nil { + t.Fatalf("UTF16PtrFromString: %v", err) + } + h, err := windows.CertOpenStore( + certStoreProvSystem, + 0, + 0, + certSystemStoreCurrentUser, + uintptr(unsafe.Pointer(myW)), + ) + if err != nil { + t.Fatalf("CertOpenStore: %v", err) + } + defer windows.CertCloseStore(h, 0) + + ctx, err := windows.CertCreateCertificateContext( + x509ASN|pkcs7ASN, + &cert.Raw[0], + uint32(len(cert.Raw)), + ) + if err != nil { + t.Fatalf("CertCreateCertificateContext: %v", err) + } + defer windows.CertFreeCertificateContext(ctx) + + if err := windows.CertAddCertificateContextToStore(h, ctx, windows.CERT_STORE_ADD_ALWAYS, nil); err != nil { + t.Fatalf("CertAddCertificateContextToStore: %v", err) + } + + // Create a SHA-1 hash of the created cert + hasher := sha1.New() + hasher.Write(cert.Raw) + sha1target := hasher.Sum(nil) + + // Query by SHA-1 hash which is not expected in cert store + _, _, _, err = store.CertBySHA1Hash("1234567890abcdef1234567890abcdef12345678") // random hash + if !errors.Is(err, cryptENotFound) { + t.Fatalf("expected cryptENotFound error, got %v", err) + } + + // Query by legitimate SHA-1 hash. + found, foundCtx, _, err := store.CertBySHA1Hash(hex.EncodeToString(sha1target)) + if err != nil { + t.Fatalf("CertBySHA1Hash returned error: %v", err) + } + + if found == nil { + t.Fatal("expected a certificate, got nil") + } + + if foundCtx == nil { + t.Fatal("expected a cert context, got nil") + } + + // Ensure cleanup: RemoveCertByContext frees foundCtx. + defer func() { + if delErr := RemoveCertByContext(foundCtx); delErr != nil { + t.Fatalf("RemoveCertByContext: %v", delErr) + } + }() + + // Validate result. + if found.Subject.CommonName != cn { + t.Fatalf("unexpected CommonName: got %q, want %q", found.Subject.CommonName, cn) + } +}