Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 51 additions & 24 deletions certtostore_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@ import (
"unsafe"

"github.com/google/deck"
"golang.org/x/crypto/cryptobyte"
"golang.org/x/crypto/cryptobyte/asn1"
"golang.org/x/crypto/cryptobyte"
"golang.org/x/sys/windows"
)

Expand All @@ -52,9 +52,11 @@ type WinCertStorage interface {
CertStorage

// Remove removes certificates issued by any of w.issuers from the user and/or system cert stores.
// If it is unable to remove any certificates, it returns an error.
Remove(removeSystem bool) error

// RemoveByCertInfo removes certificate(s) with the given subject and serial number from the user and/or system cert stores.
RemoveByCertInfo(certinfo *windows.CertInfo, removeSystem bool) error

// Link will associate the certificate installed in the system store to the user store.
Link() error

Expand Down Expand Up @@ -98,8 +100,10 @@ const (
certStoreLocalMachineID = 2 // CERT_SYSTEM_STORE_LOCAL_MACHINE_ID
infoIssuerFlag = 4 // CERT_INFO_ISSUER_FLAG
compareNameStrW = 8 // CERT_COMPARE_NAME_STR_A
compareSubjectCert = 11 // CERT_COMPARE_SUBJECT_CERT
compareShift = 16 // CERT_COMPARE_SHIFT
findIssuerStr = compareNameStrW<<compareShift | infoIssuerFlag // CERT_FIND_ISSUER_STR_W
findSubjectCert = compareSubjectCert << compareShift // CERT_FIND_SUBJECT_CERT
signatureKeyUsage = 0x80 // CERT_DIGITAL_SIGNATURE_KEY_USAGE

// Legacy CryptoAPI flags
Expand Down Expand Up @@ -264,7 +268,7 @@ func openProvider(provider string) (uintptr, error) {

// findCert wraps the CertFindCertificateInStore call. Note that any cert context passed
// into prev will be freed. If no certificate was found, nil will be returned.
func findCert(store windows.Handle, enc, findFlags, findType uint32, para *uint16, prev *windows.CertContext) (*windows.CertContext, error) {
func findCert[T any](store windows.Handle, enc, findFlags, findType uint32, para *T, prev *windows.CertContext) (*windows.CertContext, error) {
h, _, err := certFindCertificateInStore.Call(
uintptr(store),
uintptr(enc),
Expand Down Expand Up @@ -786,7 +790,6 @@ func (w *WinCertStore) linkLegacy() error {
}

// Remove removes certificates issued by any of w.issuers from the user and/or system cert stores.
// If it is unable to remove any certificates, it returns an error.
func (w *WinCertStore) Remove(removeSystem bool) error {
if w.isReadOnly() {
return fmt.Errorf("cannot remove certificates from a read-only store")
Expand All @@ -801,23 +804,35 @@ func (w *WinCertStore) Remove(removeSystem bool) error {

// remove removes a certificate issued by w.issuer from the user and/or system cert stores.
func (w *WinCertStore) remove(issuer string, removeSystem bool) error {
return w.removeCert(func(h windows.Handle) (*windows.CertContext, error) {
return findCert(
h,
encodingX509ASN|encodingPKCS7,
0,
findIssuerStr,
wide(issuer),
nil)
}, removeSystem)
}

type findCertFn func(h windows.Handle) (*windows.CertContext, error)

// remove removes a certificate found via findCertFn from the user and/or system cert stores.
func (w *WinCertStore) removeCert(findCertFn findCertFn, removeSystem bool) error {
h, err := w.storeHandle(certStoreCurrentUser, my)
if err != nil {
return err
}

userCertContext, err := findCert(
h,
encodingX509ASN|encodingPKCS7,
0,
findIssuerStr,
wide(issuer),
nil)
// Find the user cert.
userCertContext, err := findCertFn(h)
if err != nil {
return fmt.Errorf("remove: finding user certificate issued by %s failed: %v", issuer, err)
return fmt.Errorf("remove: finding user certificate failed: %v", err)
}

if userCertContext != nil {
if userCertContext == nil {
deck.Info("No user certificate found.")
} else {
if err := RemoveCertByContext(userCertContext); err != nil {
return fmt.Errorf("failed to remove user cert: %v", err)
}
Expand All @@ -835,28 +850,40 @@ func (w *WinCertStore) remove(issuer string, removeSystem bool) error {
return err
}

systemCertContext, err := findCert(
h2,
encodingX509ASN|encodingPKCS7,
0,
findIssuerStr,
wide(issuer),
nil)
// Find the system cert.
systemCertContext, err := findCertFn(
h2)
if err != nil {
return fmt.Errorf("remove: finding system certificate issued by %s failed: %v", issuer, err)
return fmt.Errorf("remove: finding system certificate failed: %v", err)
}

if systemCertContext != nil {
if systemCertContext == nil {
deck.Info("No system certificate found.")
} else {
if err := RemoveCertByContext(systemCertContext); err != nil {
return fmt.Errorf("failed to remove system cert: %v", err)
}
deck.Info("Cleaned up a system certificate.")
fmt.Fprintln(os.Stderr, "Cleaned up a system certificate.")
}

return nil
}

// RemoveByCertInfo removes certificate(s) with the given subject and serial number from the user and/or system cert stores.
func (w *WinCertStore) RemoveByCertInfo(certinfo *windows.CertInfo, removeSystem bool) error {
if w.isReadOnly() {
return fmt.Errorf("cannot remove certificates from a read-only store")
}
return w.removeCert(func(h windows.Handle) (*windows.CertContext, error) {
return findCert(
h,
encodingX509ASN|encodingPKCS7,
0,
findSubjectCert,
certinfo,
nil)
}, removeSystem)
}

// RemoveCertByContext wraps CertDeleteCertificateFromStore. If the call succeeds, nil is returned, otherwise
// the extended error is returned.
func RemoveCertByContext(certContext *windows.CertContext) error {
Expand Down
Loading