diff --git a/ext/openssl/ossl_pkey.c b/ext/openssl/ossl_pkey.c index e88074ddf..481bd8a8e 100644 --- a/ext/openssl/ossl_pkey.c +++ b/ext/openssl/ossl_pkey.c @@ -83,33 +83,41 @@ ossl_pkey_wrap(EVP_PKEY *pkey) # include static EVP_PKEY * -ossl_pkey_read(BIO *bio, const char *input_type, int selection, VALUE pass) +ossl_pkey_read(BIO *bio, const char *input_type, int selection, VALUE pass, + int *retryable) { void *ppass = (void *)pass; OSSL_DECODER_CTX *dctx; EVP_PKEY *pkey = NULL; int pos = 0, pos2; + *retryable = 0; dctx = OSSL_DECODER_CTX_new_for_pkey(&pkey, input_type, NULL, NULL, selection, NULL, NULL); if (!dctx) goto out; - if (OSSL_DECODER_CTX_set_pem_password_cb(dctx, ossl_pem_passwd_cb, + if (selection == EVP_PKEY_KEYPAIR && + OSSL_DECODER_CTX_set_pem_password_cb(dctx, ossl_pem_passwd_cb, ppass) != 1) goto out; while (1) { if (OSSL_DECODER_from_bio(dctx, bio) == 1) - goto out; - if (BIO_eof(bio)) break; + if (ERR_GET_REASON(ERR_peek_error()) != ERR_R_UNSUPPORTED) + break; + if (BIO_eof(bio) == 1) { + *retryable = 1; + break; + } pos2 = BIO_tell(bio); - if (pos2 < 0 || pos2 <= pos) + if (pos2 < 0 || pos2 <= pos) { + *retryable = 1; break; + } ossl_clear_error(); pos = pos2; } out: - OSSL_BIO_reset(bio); OSSL_DECODER_CTX_free(dctx); return pkey; } @@ -117,7 +125,6 @@ ossl_pkey_read(BIO *bio, const char *input_type, int selection, VALUE pass) EVP_PKEY * ossl_pkey_read_generic(BIO *bio, VALUE pass) { - EVP_PKEY *pkey = NULL; /* First check DER, then check PEM. */ const char *input_types[] = {"DER", "PEM"}; int input_type_num = (int)(sizeof(input_types) / sizeof(char *)); @@ -166,18 +173,22 @@ ossl_pkey_read_generic(BIO *bio, VALUE pass) EVP_PKEY_PUBLIC_KEY }; int selection_num = (int)(sizeof(selections) / sizeof(int)); - int i, j; - for (i = 0; i < input_type_num; i++) { - for (j = 0; j < selection_num; j++) { - pkey = ossl_pkey_read(bio, input_types[i], selections[j], pass); - if (pkey) { - goto out; + for (int i = 0; i < input_type_num; i++) { + for (int j = 0; j < selection_num; j++) { + if (i || j) { + ossl_clear_error(); + BIO_reset(bio); } + + int retryable; + EVP_PKEY *pkey = ossl_pkey_read(bio, input_types[i], selections[j], + pass, &retryable); + if (pkey || !retryable) + return pkey; } } - out: - return pkey; + return NULL; } #else EVP_PKEY * diff --git a/test/openssl/test_pkey.rb b/test/openssl/test_pkey.rb index 71f5da81d..f177cbc47 100644 --- a/test/openssl/test_pkey.rb +++ b/test/openssl/test_pkey.rb @@ -60,6 +60,133 @@ def test_s_generate_key assert_not_equal nil, pkey.private_key end + def test_s_read_pem_unknown_block + # A PEM-encoded certificate and a PEM-encoded private key are combined. + # Check that OSSL_STORE doesn't stop after the first PEM block. + pkey = Fixtures.pkey("rsa-1") + subject = OpenSSL::X509::Name.new([["CN", "test"]]) + cert = issue_cert(subject, pkey, 1, [], nil, nil) + + input = cert.to_pem + pkey.private_to_pem + decoded = OpenSSL::PKey.read(input) + assert_equal(pkey.private_to_der, decoded.private_to_der) + end + + def test_s_read_der_then_pem + # Ensure DER decode -> DER encode round-trip. If the input is valid as both + # DER and PEM (with garbage data before and after the block), DER is + # prioritized. + inner_rsa = Fixtures.pkey("rsa-1") + inner_pem = inner_rsa.private_to_pem + + asn1 = OpenSSL::ASN1::Sequence([ + OpenSSL::ASN1::Sequence([ + OpenSSL::ASN1::ObjectId("rsaEncryption"), + OpenSSL::ASN1::Null(nil) + ]), + OpenSSL::ASN1::BitString( + OpenSSL::ASN1::Sequence([ + OpenSSL::ASN1::Integer(OpenSSL::BN.new("\n" + inner_pem, 2)), + OpenSSL::ASN1::Integer(65537) + ]).to_der + ) + ]) + assert_include(asn1.to_der, inner_pem) + + pkey = OpenSSL::PKey.read(asn1.to_der) + assert_equal(asn1.to_der, pkey.to_der) + assert_not_predicate(pkey, :private?) + end + + def test_s_read_passphrase + pkey0 = Fixtures.pkey("rsa-1") + encrypted_pem = pkey0.private_to_pem("AES-256-CBC", "correct_passphrase") + assert_match(/\A-----BEGIN ENCRYPTED PRIVATE KEY-----/, encrypted_pem) + + # Correct passphrase passed as the second argument + pkey1 = OpenSSL::PKey.read(encrypted_pem, "correct_passphrase") + assert_equal(pkey0.private_to_der, pkey1.private_to_der) + + # Correct passphrase returned by the block. The block gets false + called = 0 + flag = nil + pkey2 = OpenSSL::PKey.read(encrypted_pem) { |f| + called += 1 + flag = f + "correct_passphrase" + } + assert_equal(pkey0.private_to_der, pkey2.private_to_der) + assert_equal(1, called) + assert_false(flag) + + # Incorrect passphrase passed. The block is not called + called = 0 + assert_raise(OpenSSL::PKey::PKeyError) { + OpenSSL::PKey.read(encrypted_pem, "incorrect_passphrase") { + called += 1 + } + } + assert_equal(0, called) + + # Incorrect passphrase returned by the block. The block is called only once + called = 0 + assert_raise(OpenSSL::PKey::PKeyError) { + OpenSSL::PKey.read(encrypted_pem) { + called += 1 + "incorrect_passphrase" + } + } + assert_equal(1, called) + + # Incorrect passphrase returned by the block. The input contains two PEM + # blocks. + called = 0 + assert_raise(OpenSSL::PKey::PKeyError) { + OpenSSL::PKey.read(encrypted_pem + encrypted_pem) { + called += 1 + "incorrect_passphrase" + } + } + assert_equal(1, called) + end + + def test_s_read_passphrase_tty + pkey0 = Fixtures.pkey("rsa-1") + encrypted_pem = pkey0.private_to_pem("AES-256-CBC", "correct_passphrase") + + # Correct passphrase passed to OpenSSL's prompt + script = <<~"end;" + require "openssl" + Process.setsid + OpenSSL::PKey.read(#{encrypted_pem.dump}) + puts "ok" + end; + assert_in_out_err([*$:.map { |l| "-I#{l}" }, "-e#{script}"], + "correct_passphrase\n") { |stdout, stderr| + assert_equal(["Enter PEM pass phrase:"], stderr) + assert_equal(["ok"], stdout) + } + + # Incorrect passphrase passed to OpenSSL's prompt + script = <<~"end;" + require "openssl" + Process.setsid + begin + OpenSSL::PKey.read(#{encrypted_pem.dump}) + rescue OpenSSL::PKey::PKeyError + puts "ok" + else + puts "expected OpenSSL::PKey::PKeyError" + end + end; + stdin = "incorrect_passphrase\n" * 5 + assert_in_out_err([*$:.map { |l| "-I#{l}" }, "-e#{script}"], + stdin) { |stdout, stderr| + assert_equal(1, stderr.count("Enter PEM pass phrase:")) + assert_equal(["ok"], stdout) + } + end + def test_hmac_sign_verify pkey = OpenSSL::PKey.generate_key("HMAC", { "key" => "abcd" })