diff --git a/src/httpfs_extension.cpp b/src/httpfs_extension.cpp index 79d9923..aabd6ca 100644 --- a/src/httpfs_extension.cpp +++ b/src/httpfs_extension.cpp @@ -91,6 +91,9 @@ static void LoadInternal(ExtensionLoader &loader) { Value(50)); config.AddExtensionOption("unsafe_disable_etag_checks", "Disable checks on ETag consistency", LogicalType::BOOLEAN, Value(false)); + + // S3 Access Grants config + config.AddExtensionOption("s3_access_grants_enabled", "Enable S3 Access grants", LogicalType::BOOLEAN, Value(false)); // HuggingFace options config.AddExtensionOption("hf_max_per_page", "Debug option to limit number of items returned in list requests", diff --git a/src/include/lru_cache.hpp b/src/include/lru_cache.hpp new file mode 100644 index 0000000..5c6c07d --- /dev/null +++ b/src/include/lru_cache.hpp @@ -0,0 +1,62 @@ +#pragma once +#include +#include +#include "duckdb/common/mutex.hpp" + +namespace duckdb { + template + class LRUCache + { + public: + typedef typename std::pair key_value_pair_t; + typedef typename std::list::iterator list_iterator_t; + LRUCache (size_t capacity): _capacity(capacity) {} + void Put(const key_t& key, const value_t& value) { + lock_guard parallel_lock(lock); + auto it = _cache_items_map.find(key); + _cache_items_list.push_front(key_value_pair_t(key, value)); + if (it != _cache_items_map.end()) { + _cache_items_list.erase(it->second); + _cache_items_map.erase(it); + } + _cache_items_map[key] = _cache_items_list.begin(); + + if (_cache_items_map.size() > _capacity) { + auto last = _cache_items_list.end(); + last--; + _cache_items_map.erase(last->first); + _cache_items_list.pop_back(); + } + } + bool Get(const key_t& key, value_t& value) { + lock_guard parallel_lock(lock); + auto it = _cache_items_map.find(key); + if (it == _cache_items_map.end()) { + return false; + } else { + _cache_items_list.splice(_cache_items_list.begin(), _cache_items_list, it->second); + value = it->second->second; + return true; + } + } + void Delete(const key_t& 
key) { + lock_guard parallel_lock(lock); + auto it = _cache_items_map.find(key); + if (it == _cache_items_map.end()) { + return; //another thread already cleaned up + } + _cache_items_list.erase(it->second); + _cache_items_map.erase(it); + } + size_t Size() const { + return _cache_items_map.size(); + } + + private: + std::list _cache_items_list; + std::unordered_map _cache_items_map; + size_t _capacity; + mutex lock; + }; + +} \ No newline at end of file diff --git a/src/include/s3fs.hpp b/src/include/s3fs.hpp index 525e0dd..fc524c4 100644 --- a/src/include/s3fs.hpp +++ b/src/include/s3fs.hpp @@ -31,6 +31,7 @@ struct S3AuthParams { bool use_ssl = true; bool s3_url_compatibility_mode = false; bool requester_pays = false; + bool s3_access_grants_enabled = false; string oauth2_bearer_token; // OAuth2 bearer token for GCS static S3AuthParams ReadFrom(optional_ptr opener, FileOpenerInfo &info); @@ -46,6 +47,7 @@ struct AWSEnvironmentCredentialsProvider { static constexpr const char *DUCKDB_USE_SSL_ENV_VAR = "DUCKDB_S3_USE_SSL"; static constexpr const char *DUCKDB_KMS_KEY_ID_ENV_VAR = "DUCKDB_S3_KMS_KEY_ID"; static constexpr const char *DUCKDB_REQUESTER_PAYS_ENV_VAR = "DUCKDB_S3_REQUESTER_PAYS"; + static constexpr const char *DUCKDB_S3_ACCESS_GRANTS_ENABLED_ENV_VAR = "DUCKDB_S3_ACCESS_GRANTS_ENABLED"; explicit AWSEnvironmentCredentialsProvider(DBConfig &config) : config(config) {}; diff --git a/src/s3fs.cpp b/src/s3fs.cpp index 2a92f22..fb2b999 100644 --- a/src/s3fs.cpp +++ b/src/s3fs.cpp @@ -19,6 +19,7 @@ #include "duckdb/storage/buffer_manager.hpp" #include "create_secret_functions.hpp" +#include "lru_cache.hpp" #include #include @@ -28,9 +29,22 @@ namespace duckdb { +struct TemporaryAWSCredential +{ + string access_key_id; + string secret_access_key; + string session_token; + timestamp_t expiration; +}; + +static LRUCache AccountIdCache(1024); +static LRUCache BucketOwnerAccountIdCache(2048); +static LRUCache AccessGrantsCache(4096); +static LRUCache 
AccessDeniedCache(4096); + static HTTPHeaders create_s3_header(string url, string query, string host, string service, string method, const S3AuthParams &auth_params, string date_now = "", string datetime_now = "", - string payload_hash = "", string content_type = "") { + string payload_hash = "", string content_type = "", string account_id = "") { HTTPHeaders res; res["Host"] = host; @@ -70,6 +84,10 @@ static HTTPHeaders create_s3_header(string url, string query, string host, strin res["x-amz-request-payer"] = "requester"; } + if(!account_id.empty()) { + res["x-amz-account-id"] = account_id; + } + string signed_headers = ""; hash_bytes canonical_request_hash; hash_str canonical_request_hash_str; @@ -79,7 +97,11 @@ static HTTPHeaders create_s3_header(string url, string query, string host, strin res["content-type"] = content_type; #endif } - signed_headers += "host;x-amz-content-sha256;x-amz-date"; + signed_headers += "host"; + if (!account_id.empty()) { + signed_headers += ";x-amz-account-id"; + } + signed_headers += ";x-amz-content-sha256;x-amz-date"; if (use_requester_pays) { signed_headers += ";x-amz-request-payer"; } @@ -93,7 +115,11 @@ static HTTPHeaders create_s3_header(string url, string query, string host, strin if (content_type.length() > 0) { canonical_request += "\ncontent-type:" + content_type; } - canonical_request += "\nhost:" + host + "\nx-amz-content-sha256:" + payload_hash + "\nx-amz-date:" + datetime_now; + canonical_request += "\nhost:" + host; + if (!account_id.empty()) { + canonical_request += "\nx-amz-account-id:" + account_id; + } + canonical_request += "\nx-amz-content-sha256:" + payload_hash + "\nx-amz-date:" + datetime_now; if (use_requester_pays) { canonical_request += "\nx-amz-request-payer:requester"; } @@ -129,6 +155,201 @@ static HTTPHeaders create_s3_header(string url, string query, string host, strin return res; } +optional_idx FindTagContents(const string &response, const string &tag, idx_t cur_pos, string &result) { + string 
open_tag = "<" + tag + ">"; + string close_tag = ""; + auto open_tag_pos = response.find(open_tag, cur_pos); + if (open_tag_pos == string::npos) { + // tag not found + return optional_idx(); + } + auto close_tag_pos = response.find(close_tag, open_tag_pos + open_tag.size()); + if (close_tag_pos == string::npos) { + throw InternalException("Failed to parse S3 result: found open tag for %s but did not find matching close tag", + tag); + } + result = response.substr(open_tag_pos + open_tag.size(), close_tag_pos - open_tag_pos - open_tag.size()); + return close_tag_pos + close_tag.size(); +} + +string get_current_account_id(HTTPParams &http_params, S3AuthParams &auth_params) { + string cached_account_id; + if (AccountIdCache.Get(auth_params.access_key_id, cached_account_id)) { + return cached_account_id; + } + string query = "Action=GetCallerIdentity&Version=2011-06-15"; + string host = "sts." + auth_params.region + ".amazonaws.com"; + auto full_url = "https://" + host + "?" + query; + auto headers = create_s3_header("/", query, host, "sts", "GET", auth_params); + std::stringstream response; + GetRequestInfo get_account("https://" + host, "/?" 
+ query, headers, http_params, + [&](const HTTPResponse &response) { + if (static_cast(response.status) >= 400) { + throw Exception(ExceptionType::INVALID_INPUT, full_url); + } + return true; + }, + [&](const_data_ptr_t data, idx_t data_length) { + response << string(const_char_ptr_cast(data), data_length); + return true; + }); + auto result = http_params.http_util.Request(get_account); + if (result->HasRequestError()) { + throw IOException("%s error for HTTP GET to '%s'", result->GetRequestError(), full_url); + } + string account_id; + FindTagContents(response.str(), "Account", 0, account_id); + AccountIdCache.Put(auth_params.access_key_id, account_id); + return account_id; +} + +string get_account_id_for_s3_object(HTTPParams &http_params, S3AuthParams &auth_params, const string& url) { + auto parsed_url = S3FileSystem::S3UrlParse(url, auth_params); + string cached_account_id; + if (BucketOwnerAccountIdCache.Get(parsed_url.bucket, cached_account_id)) { + return cached_account_id; + } + string caller_account_id = get_current_account_id(http_params, auth_params); + string query = "s3prefix=" + StringUtil::URLEncode(url); + string access_grants_url = "/v20180820/accessgrantsinstance/prefix"; + string host = caller_account_id + ".s3-control." + auth_params.region + ".amazonaws.com"; + auto full_url = "https://" + host + access_grants_url + "?" + query; + auto headers = create_s3_header(access_grants_url, query, host, "s3", "GET", auth_params, "", "", "", "", caller_account_id); + std::stringstream response; + GetRequestInfo get_access_grant_urn(host, access_grants_url + "?" 
+ query, headers, http_params, + [&](const HTTPResponse &response) { + if (static_cast(response.status) >= 400) { + throw Exception(ExceptionType::INVALID_INPUT, full_url); + } + return true; + }, + [&](const_data_ptr_t data, idx_t data_length) { + response << string(const_char_ptr_cast(data), data_length); + return true; + }); + auto result = http_params.http_util.Request(get_access_grant_urn); + if (result->HasRequestError()) { + throw IOException("%s error for HTTP GET to '%s'", result->GetRequestError(), full_url); + } + string access_grants_arn; + FindTagContents(response.str(), "AccessGrantsInstanceArn", 0, access_grants_arn); + vector parts = StringUtil::Split(access_grants_arn, ':'); + BucketOwnerAccountIdCache.Put(parsed_url.bucket, parts[4]); + return parts[4]; +} + +bool get_data_access(HTTPParams &http_params, S3AuthParams &auth_params, const string& operation, const string& url, string& access_key_id, string& secret_access_key, string& session_token) { + timestamp_t access_denied_timestamp; + string url_fixed_prefix = url; + if (StringUtil::StartsWith(url_fixed_prefix, "s3a://")) { + url_fixed_prefix = StringUtil::Replace(url_fixed_prefix, "s3a://", "s3://"); + } + + if (AccessDeniedCache.Get(url_fixed_prefix, access_denied_timestamp)) { + // do not try if we find a recent(5 min) access denied + if (access_denied_timestamp > Timestamp::GetCurrentTimestamp()) { + return false; + } + AccessDeniedCache.Delete(url_fixed_prefix); + } + auto account_id = get_account_id_for_s3_object(http_params, auth_params, url_fixed_prefix); + int current_pos = url_fixed_prefix.size() - 1; + TemporaryAWSCredential creds; + while (current_pos > 3) { + auto prefix = url_fixed_prefix.substr(0, current_pos) + "/*"; + auto found = AccessGrantsCache.Get(prefix, creds); + if (found) { + if (creds.expiration > Timestamp::GetCurrentTimestamp()) + { + access_key_id = creds.access_key_id; + secret_access_key = creds.secret_access_key; + session_token = creds.session_token; + return 
true; + } + AccessGrantsCache.Delete(prefix); + + } + current_pos = url_fixed_prefix.rfind("/", current_pos - 1); + } + current_pos = url_fixed_prefix.size() - 1; + while (current_pos > 3) { + auto prefix = url_fixed_prefix.substr(0, current_pos) + "*"; + auto found = AccessGrantsCache.Get(prefix, creds); + if (found) { + if (creds.expiration > Timestamp::GetCurrentTimestamp()) + { + access_key_id = creds.access_key_id; + secret_access_key = creds.secret_access_key; + session_token = creds.session_token; + return true; + } + AccessGrantsCache.Delete(prefix); + + } + current_pos--; + } + string query = "durationSeconds=3600&permission=" + operation+ "&privilege=Default&target=" + StringUtil::URLEncode(url_fixed_prefix) + "&targetType=Object"; + string access_grants_url = "/v20180820/accessgrantsinstance/dataaccess"; + string host = account_id + ".s3-control." + auth_params.region + ".amazonaws.com"; + auto full_url = "https://" + host + access_grants_url + "?" + query; + auto headers = create_s3_header(access_grants_url, query, host, "s3", "GET", auth_params, "", "", "", "", account_id); + std::stringstream response; + GetRequestInfo get_data_access_creds(host, access_grants_url + "?" 
+ query, headers, http_params, + [&](const HTTPResponse &response) { + return true; + }, + [&](const_data_ptr_t data, idx_t data_length) { + response << string(const_char_ptr_cast(data), data_length); + return true; + }); + auto result = http_params.http_util.Request(get_data_access_creds); + // We ignore all errors here as we want to fallback to normal IAM creds + if (result->HasRequestError() || (int)result->status >= 400) { + // cache access denied for 5 min; offsets are in microseconds (same unit as the 10 min buffer below) + if ((int)result->status == 403) { + AccessDeniedCache.Put(url_fixed_prefix, Timestamp::GetCurrentTimestamp() + 5 * 60 * 1000000); + } + return false; + } + string response_str = response.str(); + string expiration; + string matched_grant_target; + FindTagContents(response_str, "AccessKeyId", 0, access_key_id); + FindTagContents(response_str, "SecretAccessKey", 0, secret_access_key); + FindTagContents(response_str, "SessionToken", 0, session_token); + FindTagContents(response_str, "Expiration", 0, expiration); + FindTagContents(response_str, "MatchedGrantTarget", 0, matched_grant_target); + timestamp_t expiration_ts; + bool has_offset; + string_t tz(nullptr, 0); + Timestamp::TryConvertTimestampTZ(expiration.c_str(), expiration.size(), expiration_ts, true, has_offset, tz); + expiration_ts -= 10 * 60 * 1000000; // 10 min buffer + AccessGrantsCache.Put(matched_grant_target, TemporaryAWSCredential {access_key_id, secret_access_key, session_token, expiration_ts}); // positional aggregate init in member declaration order; designated initializers require C++20 + return true; +} + +void update_credentials_from_access_grants(HTTPParams &http_params, S3AuthParams &auth_params, const string& method, const string& url) { + if (!auth_params.s3_access_grants_enabled) { + return; + } + + if (auth_params.region.empty()) { + throw Exception(ExceptionType::INVALID_CONFIGURATION, "You must specify a region"); + } + + string operation = "WRITE"; + if (method == "GET" || method == "HEAD") { + operation = "READ"; + } + string access_key_id, secret_access_key, session_token; + if 
(get_data_access(http_params, auth_params, operation, url, access_key_id, secret_access_key, session_token)) { + auth_params.access_key_id = access_key_id; + auth_params.secret_access_key = secret_access_key; + auth_params.session_token = session_token; + } + +} + string S3FileSystem::UrlDecode(string input) { return StringUtil::URLDecode(input, true); } @@ -165,6 +386,7 @@ void AWSEnvironmentCredentialsProvider::SetAll() { this->SetExtensionOptionValue("s3_use_ssl", DUCKDB_USE_SSL_ENV_VAR); this->SetExtensionOptionValue("s3_kms_key_id", DUCKDB_KMS_KEY_ID_ENV_VAR); this->SetExtensionOptionValue("s3_requester_pays", DUCKDB_REQUESTER_PAYS_ENV_VAR); + this->SetExtensionOptionValue("s3_access_grants_enabled", DUCKDB_S3_ACCESS_GRANTS_ENABLED_ENV_VAR); } S3AuthParams AWSEnvironmentCredentialsProvider::CreateParams() { @@ -179,6 +401,7 @@ S3AuthParams AWSEnvironmentCredentialsProvider::CreateParams() { params.kms_key_id = DUCKDB_KMS_KEY_ID_ENV_VAR; params.use_ssl = DUCKDB_USE_SSL_ENV_VAR; params.requester_pays = DUCKDB_REQUESTER_PAYS_ENV_VAR; + params.s3_access_grants_enabled = DUCKDB_S3_ACCESS_GRANTS_ENABLED_ENV_VAR; return params; } @@ -228,6 +451,11 @@ S3AuthParams S3AuthParams::ReadFrom(optional_ptr opener, FileOpenerI result.endpoint = "s3.amazonaws.com"; } + Value value; + if (FileOpener::TryGetCurrentSetting(opener, "s3_access_grants_enabled", value)) { + result.s3_access_grants_enabled = value.GetValue(); + } + return result; } @@ -741,7 +969,8 @@ unique_ptr S3FileSystem::PostRequest(FileHandle &handle, string ur headers["Content-Type"] = "application/octet-stream"; } else { // Use existing S3 authentication - auto payload_hash = GetPayloadHash(buffer_in, buffer_in_len); + update_credentials_from_access_grants(handle.Cast().http_params, auth_params, "POST", url); + auto payload_hash = GetPayloadHash(buffer_in, buffer_in_len); headers = create_s3_header(parsed_s3_url.path, http_params, parsed_s3_url.host, "s3", "POST", auth_params, "", "", payload_hash, 
"application/octet-stream"); } @@ -764,6 +993,7 @@ unique_ptr S3FileSystem::PutRequest(FileHandle &handle, string url headers["Content-Type"] = content_type; } else { // Use existing S3 authentication + update_credentials_from_access_grants(handle.Cast().http_params, auth_params, "PUT", url); auto payload_hash = GetPayloadHash(buffer_in, buffer_in_len); headers = create_s3_header(parsed_s3_url.path, http_params, parsed_s3_url.host, "s3", "PUT", auth_params, "", "", payload_hash, content_type); @@ -784,6 +1014,7 @@ unique_ptr S3FileSystem::HeadRequest(FileHandle &handle, string s3 headers["Host"] = parsed_s3_url.host; } else { // Use existing S3 authentication + update_credentials_from_access_grants(handle.Cast().http_params, auth_params, "HEAD", s3_url); headers = create_s3_header(parsed_s3_url.path, "", parsed_s3_url.host, "s3", "HEAD", auth_params, "", "", "", ""); } @@ -803,6 +1034,7 @@ unique_ptr S3FileSystem::GetRequest(FileHandle &handle, string s3_ headers["Host"] = parsed_s3_url.host; } else { // Use existing S3 authentication + update_credentials_from_access_grants(handle.Cast().http_params, auth_params, "GET", s3_url); headers = create_s3_header(parsed_s3_url.path, "", parsed_s3_url.host, "s3", "GET", auth_params, "", "", "", ""); } @@ -823,6 +1055,7 @@ unique_ptr S3FileSystem::GetRangeRequest(FileHandle &handle, strin headers["Host"] = parsed_s3_url.host; } else { // Use existing S3 authentication + update_credentials_from_access_grants(handle.Cast().http_params, auth_params, "GET", s3_url); headers = create_s3_header(parsed_s3_url.path, "", parsed_s3_url.host, "s3", "GET", auth_params, "", "", "", ""); } @@ -842,6 +1075,7 @@ unique_ptr S3FileSystem::DeleteRequest(FileHandle &handle, string headers["Host"] = parsed_s3_url.host; } else { // Use existing S3 authentication + update_credentials_from_access_grants(handle.Cast().http_params, auth_params, "DELETE", s3_url); headers = create_s3_header(parsed_s3_url.path, "", parsed_s3_url.host, "s3", "DELETE", 
auth_params, "", "", "", ""); } @@ -1202,6 +1436,7 @@ HTTPException S3FileSystem::GetHTTPError(FileHandle &handle, const HTTPResponse return GetS3Error(s3_handle.auth_params, response, url); } + string AWSListObjectV2::Request(string &path, HTTPParams &http_params, S3AuthParams &s3_auth_params, string &continuation_token, optional_ptr state, bool use_delimiter) { auto parsed_url = S3FileSystem::S3UrlParse(path, s3_auth_params); @@ -1252,23 +1487,6 @@ string AWSListObjectV2::Request(string &path, HTTPParams &http_params, S3AuthPar return response.str(); } -optional_idx FindTagContents(const string &response, const string &tag, idx_t cur_pos, string &result) { - string open_tag = "<" + tag + ">"; - string close_tag = ""; - auto open_tag_pos = response.find(open_tag, cur_pos); - if (open_tag_pos == string::npos) { - // tag not found - return optional_idx(); - } - auto close_tag_pos = response.find(close_tag, open_tag_pos + open_tag.size()); - if (close_tag_pos == string::npos) { - throw InternalException("Failed to parse S3 result: found open tag for %s but did not find matching close tag", - tag); - } - result = response.substr(open_tag_pos + open_tag.size(), close_tag_pos - open_tag_pos - open_tag.size()); - return close_tag_pos + close_tag.size(); -} - void AWSListObjectV2::ParseFileList(string &aws_response, vector &result) { // Example S3 response: //