diff --git a/certtostore_windows.go b/certtostore_windows.go index e6a7be8..f2e8da8 100644 --- a/certtostore_windows.go +++ b/certtostore_windows.go @@ -46,6 +46,236 @@ import ( "golang.org/x/sys/windows" ) +// winAPI abstracts the Windows CryptoAPI and CNG functions directly used by this package. +type winAPI interface { + CertAddCertificateContextToStore(hCertStore windows.Handle, pCertContext *windows.CertContext, dwAddDisposition uint32, ppStoreContext **windows.CertContext) error + CertCloseStore(hCertStore windows.Handle, flags uint32) error + CertCreateCertificateContext(dwCertEncodingType uint32, pbCertEncoded *byte, cbCertEncoded uint32) (*windows.CertContext, error) + CertDeleteCertificateFromStore(pCertContext *windows.CertContext) error + CertFindCertificateInStore(hCertStore windows.Handle, dwCertEncodingType, dwFindFlags, dwFindType uint32, pvFindPara unsafe.Pointer, pPrevCertContext *windows.CertContext) (*windows.CertContext, error) + CertFreeCertificateChain(pChainContext *windows.CertChainContext) + CertFreeCertificateContext(pCertContext *windows.CertContext) error + CertGetCertificateChain(hChainEngine uintptr, pCertContext *windows.CertContext, pTime *windows.Filetime, hAdditionalStore windows.Handle, pChainPara *windows.CertChainPara, dwFlags uint32, pvReserved unsafe.Pointer, ppChainContext **windows.CertChainContext) (bool, error) + CertGetIntendedKeyUsage(dwCertEncodingType uint32, pCertInfo *windows.CertInfo, pbKeyUsage *byte, pcbKeyUsage *uint32) error + CertOpenStore(storeProvider uintptr, msgAndCertEncodingType, cryptProv uintptr, flags uint32, para uintptr) (windows.Handle, error) + CryptAcquireCertificatePrivateKey(pCert *windows.CertContext, dwFlags uint32, pvReserved unsafe.Pointer, phCryptProvOrNCryptKey *uintptr, pdwKeySpec *uint32, pfCallerFreeProvOrNCryptKey *bool) (bool, error) + CryptFindCertificateKeyProvInfo(pCert *windows.CertContext, dwFlags uint32, pvReserved unsafe.Pointer) (bool, error) + NCryptCreatePersistedKey(hProvider uintptr, phKey *uintptr, pszAlgID *uint16, pszKeyName *uint16, dwLegacyKeySpec uint32, dwFlags uint32) (uintptr, error) + NCryptDecrypt(hKey uintptr, pbInput *byte, cbInput uint32, pPaddingInfo unsafe.Pointer, pbOutput *byte, cbOutput uint32, pcbResult *uint32, dwFlags uint32) (uintptr, error) + NCryptExportKey(hKey uintptr, hExportKey uintptr, pszBlobType *uint16, pParameterList unsafe.Pointer, pbOutput *byte, cbOutput uint32, pcbResult *uint32, dwFlags uint32) (uintptr, error) + NCryptFinalizeKey(hKey uintptr, dwFlags uint32) (uintptr, error) + NCryptFreeObject(hObject uintptr) (uintptr, error) + NCryptGetProperty(hObject uintptr, pszProperty *uint16, pbOutput *byte, cbOutput uint32, pcbResult *uint32, dwFlags uint32) (uintptr, error) + NCryptOpenKey(hProvider uintptr, phKey *uintptr, pszKeyName *uint16, dwLegacyKeySpec uint32, dwFlags uint32) (uintptr, error) + NCryptOpenStorageProvider(phProvider *uintptr, pszProviderName *uint16, dwFlags uint32) (uintptr, error) + NCryptSetProperty(hObject uintptr, pszProperty *uint16, pbInput *byte, cbInput uint32, dwFlags uint32) (uintptr, error) + NCryptSignHash(hKey uintptr, pPaddingInfo unsafe.Pointer, pbHashValue *byte, cbHashValue uint32, pbSignature *byte, cbSignature uint32, pcbResult *uint32, dwFlags uint32) (uintptr, error) +} + +type defaultWinAPI struct{} + +func (defaultWinAPI) CertAddCertificateContextToStore(hCertStore windows.Handle, pCertContext *windows.CertContext, dwAddDisposition uint32, ppStoreContext **windows.CertContext) error { + return windows.CertAddCertificateContextToStore(hCertStore, pCertContext, dwAddDisposition, ppStoreContext) +} + +func (defaultWinAPI) CertCloseStore(hCertStore windows.Handle, flags uint32) error { + return windows.CertCloseStore(hCertStore, flags) +} + +func (defaultWinAPI) CertCreateCertificateContext(dwCertEncodingType uint32, pbCertEncoded *byte, cbCertEncoded uint32) (*windows.CertContext, error) { + return windows.CertCreateCertificateContext(dwCertEncodingType, pbCertEncoded, cbCertEncoded) +} + +func (defaultWinAPI) CertDeleteCertificateFromStore(pCertContext *windows.CertContext) error { + r, _, err := certDeleteCertificateFromStore.Call(uintptr(unsafe.Pointer(pCertContext))) + if r != 1 { + return fmt.Errorf("certdeletecertificatefromstore failed with %X: %v", r, err) + } + return nil +} + +func (defaultWinAPI) CertFindCertificateInStore(hCertStore windows.Handle, dwCertEncodingType, dwFindFlags, dwFindType uint32, pvFindPara unsafe.Pointer, pPrevCertContext *windows.CertContext) (*windows.CertContext, error) { + h, _, err := certFindCertificateInStore.Call( + uintptr(hCertStore), + uintptr(dwCertEncodingType), + uintptr(dwFindFlags), + uintptr(dwFindType), + uintptr(pvFindPara), + uintptr(unsafe.Pointer(pPrevCertContext)), + ) + if h == 0 { + return nil, err + } + return (*windows.CertContext)(unsafe.Pointer(h)), nil +} + +func (defaultWinAPI) CertFreeCertificateChain(pChainContext *windows.CertChainContext) { + certFreeCertificateChain.Call(uintptr(unsafe.Pointer(pChainContext))) +} + +func (defaultWinAPI) CertFreeCertificateContext(pCertContext *windows.CertContext) error { + return windows.CertFreeCertificateContext(pCertContext) +} + +func (defaultWinAPI) CertGetCertificateChain(hChainEngine uintptr, pCertContext *windows.CertContext, pTime *windows.Filetime, hAdditionalStore windows.Handle, pChainPara *windows.CertChainPara, dwFlags uint32, pvReserved unsafe.Pointer, ppChainContext **windows.CertChainContext) (bool, error) { + r, _, err := certGetCertificateChain.Call( + hChainEngine, + uintptr(unsafe.Pointer(pCertContext)), + uintptr(unsafe.Pointer(pTime)), + uintptr(hAdditionalStore), + uintptr(unsafe.Pointer(pChainPara)), + uintptr(dwFlags), + uintptr(pvReserved), + uintptr(unsafe.Pointer(ppChainContext)), + ) + return r != 0, err +} + +func (defaultWinAPI) CertGetIntendedKeyUsage(dwCertEncodingType uint32, pCertInfo *windows.CertInfo, pbKeyUsage *byte, pcbKeyUsage *uint32) error { + r, _, err := certGetIntendedKeyUsage.Call( + uintptr(dwCertEncodingType), + uintptr(unsafe.Pointer(pCertInfo)), + uintptr(unsafe.Pointer(pbKeyUsage)), + uintptr(unsafe.Pointer(pcbKeyUsage)), + ) + if r == 0 { + return err + } + return nil +} + +func (defaultWinAPI) CertOpenStore(storeProvider uintptr, msgAndCertEncodingType, cryptProv uintptr, flags uint32, para uintptr) (windows.Handle, error) { + return windows.CertOpenStore(storeProvider, msgAndCertEncodingType, cryptProv, flags, para) +} + +func (defaultWinAPI) CryptAcquireCertificatePrivateKey(pCert *windows.CertContext, dwFlags uint32, pvReserved unsafe.Pointer, phCryptProvOrNCryptKey *uintptr, pdwKeySpec *uint32, pfCallerFreeProvOrNCryptKey *bool) (bool, error) { + r, _, err := cryptAcquireCertificatePrivateKey.Call( + uintptr(unsafe.Pointer(pCert)), + uintptr(dwFlags), + uintptr(pvReserved), + uintptr(unsafe.Pointer(phCryptProvOrNCryptKey)), + uintptr(unsafe.Pointer(pdwKeySpec)), + uintptr(unsafe.Pointer(pfCallerFreeProvOrNCryptKey)), + ) + return r != 0, err +} + +func (defaultWinAPI) CryptFindCertificateKeyProvInfo(pCert *windows.CertContext, dwFlags uint32, pvReserved unsafe.Pointer) (bool, error) { + r, _, err := cryptFindCertificateKeyProvInfo.Call( + uintptr(unsafe.Pointer(pCert)), + uintptr(dwFlags), + uintptr(pvReserved), + ) + return r != 0, err +} + +func (defaultWinAPI) NCryptCreatePersistedKey(hProvider uintptr, phKey *uintptr, pszAlgID *uint16, pszKeyName *uint16, dwLegacyKeySpec uint32, dwFlags uint32) (uintptr, error) { + r, _, err := nCryptCreatePersistedKey.Call( + hProvider, + uintptr(unsafe.Pointer(phKey)), + uintptr(unsafe.Pointer(pszAlgID)), + uintptr(unsafe.Pointer(pszKeyName)), + uintptr(dwLegacyKeySpec), + uintptr(dwFlags), + ) + return r, err +} + +func (defaultWinAPI) NCryptDecrypt(hKey uintptr, pbInput *byte, cbInput uint32, pPaddingInfo unsafe.Pointer, pbOutput *byte, cbOutput uint32, pcbResult *uint32, dwFlags uint32) (uintptr, error) { + r, _, err := nCryptDecrypt.Call( + hKey, + uintptr(unsafe.Pointer(pbInput)), + uintptr(cbInput), + uintptr(pPaddingInfo), + uintptr(unsafe.Pointer(pbOutput)), + uintptr(cbOutput), + uintptr(unsafe.Pointer(pcbResult)), + uintptr(dwFlags), + ) + return r, err +} + +func (defaultWinAPI) NCryptExportKey(hKey uintptr, hExportKey uintptr, pszBlobType *uint16, pParameterList unsafe.Pointer, pbOutput *byte, cbOutput uint32, pcbResult *uint32, dwFlags uint32) (uintptr, error) { + r, _, err := nCryptExportKey.Call( + hKey, + hExportKey, + uintptr(unsafe.Pointer(pszBlobType)), + uintptr(pParameterList), + uintptr(unsafe.Pointer(pbOutput)), + uintptr(cbOutput), + uintptr(unsafe.Pointer(pcbResult)), + uintptr(dwFlags), + ) + return r, err +} + +func (defaultWinAPI) NCryptFinalizeKey(hKey uintptr, dwFlags uint32) (uintptr, error) { + r, _, err := nCryptFinalizeKey.Call(hKey, uintptr(dwFlags)) + return r, err +} + +func (defaultWinAPI) NCryptFreeObject(hObject uintptr) (uintptr, error) { + r, _, err := nCryptFreeObject.Call(hObject) + return r, err +} + +func (defaultWinAPI) NCryptGetProperty(hObject uintptr, pszProperty *uint16, pbOutput *byte, cbOutput uint32, pcbResult *uint32, dwFlags uint32) (uintptr, error) { + r, _, err := nCryptGetProperty.Call( + hObject, + uintptr(unsafe.Pointer(pszProperty)), + uintptr(unsafe.Pointer(pbOutput)), + uintptr(cbOutput), + uintptr(unsafe.Pointer(pcbResult)), + uintptr(dwFlags), + ) + return r, err +} + +func (defaultWinAPI) NCryptOpenKey(hProvider uintptr, phKey *uintptr, pszKeyName *uint16, dwLegacyKeySpec uint32, dwFlags uint32) (uintptr, error) { + r, _, err := nCryptOpenKey.Call( + hProvider, + uintptr(unsafe.Pointer(phKey)), + uintptr(unsafe.Pointer(pszKeyName)), + uintptr(dwLegacyKeySpec), + uintptr(dwFlags), + ) + return r, err +} + +func (defaultWinAPI) NCryptOpenStorageProvider(phProvider *uintptr, pszProviderName *uint16, dwFlags uint32) (uintptr, error) { + r, _, err := nCryptOpenStorageProvider.Call( + uintptr(unsafe.Pointer(phProvider)), + uintptr(unsafe.Pointer(pszProviderName)), + uintptr(dwFlags), + ) + return r, err +} + +func (defaultWinAPI) NCryptSetProperty(hObject uintptr, pszProperty *uint16, pbInput *byte, cbInput uint32, dwFlags uint32) (uintptr, error) { + r, _, err := nCryptSetProperty.Call( + hObject, + uintptr(unsafe.Pointer(pszProperty)), + uintptr(unsafe.Pointer(pbInput)), + uintptr(cbInput), + uintptr(dwFlags), + ) + return r, err +} + +func (defaultWinAPI) NCryptSignHash(hKey uintptr, pPaddingInfo unsafe.Pointer, pbHashValue *byte, cbHashValue uint32, pbSignature *byte, cbSignature uint32, pcbResult *uint32, dwFlags uint32) (uintptr, error) { + r, _, err := nCryptSignHash.Call( + hKey, + uintptr(pPaddingInfo), + uintptr(unsafe.Pointer(pbHashValue)), + uintptr(cbHashValue), + uintptr(unsafe.Pointer(pbSignature)), + uintptr(cbSignature), + uintptr(unsafe.Pointer(pcbResult)), + uintptr(dwFlags), + ) + return r, err +} + // WinCertStorage provides windows-specific additions to the CertStorage interface. type WinCertStorage interface { CertStorage @@ -250,11 +480,11 @@ func wide(s string) *uint16 { return &w[0] } -func openProvider(provider string) (uintptr, error) { +func openProvider(api winAPI, provider string) (uintptr, error) { var hProv uintptr pname := wide(provider) // Open the provider, the last parameter is not used - r, _, err := nCryptOpenStorageProvider.Call(uintptr(unsafe.Pointer(&hProv)), uintptr(unsafe.Pointer(pname)), 0) + r, err := api.NCryptOpenStorageProvider(&hProv, pname, 0) if r == 0 { return hProv, nil } @@ -349,6 +579,7 @@ type WinCertStore struct { stores map[string]*storeHandle keyAccessFlags uintptr storeFlags uint32 + api winAPI mu sync.Mutex } @@ -428,8 +659,9 @@ func OpenWinCertStoreCurrentUser(provider, container string, issuers, intermedia // - Missing required options (provider, container, issuers) // - Insufficient privileges for machine store access func OpenWinCertStoreWithOptions(opts WinCertStoreOptions) (*WinCertStore, error) { + api := defaultWinAPI{} // Open a handle to the crypto provider we will use for private key operations - cngProv, err := openProvider(opts.Provider) + cngProv, err := openProvider(api, opts.Provider) if err != nil { return nil, fmt.Errorf("unable to open crypto provider %q: %v", opts.Provider, err) } @@ -442,6 +674,7 @@ func OpenWinCertStoreWithOptions(opts WinCertStoreOptions) (*WinCertStore, error container: opts.Container, stores: make(map[string]*storeHandle), storeFlags: opts.StoreFlags, + api: defaultWinAPI{}, } // Deep copy the issuer slices to prevent external modification