diff --git a/docs/reference/configuration.md b/docs/reference/configuration.md index ae6c7ea..b38c538 100644 --- a/docs/reference/configuration.md +++ b/docs/reference/configuration.md @@ -13,6 +13,7 @@ spr uses the following Git configuration values: | `branchPrefix` | `--branch-prefix` | String used to prefix autogenerated names of pull request branches | | `spr/GITHUB_USERNAME/` | | `requireApproval` | | If true, `spr land` will refuse to land a pull request that is not accepted | false | | `requireTestPlan` | | If true, `spr diff` will refuse to process a commit without a test plan | true | +| `githubApiDomain` | | Domain where the Github API can be found. Override for Github Enterprise support. | true | api.github.com - The config keys are all in the `spr` section; for example, `spr.githubAuthToken`. diff --git a/spr/src/commands/init.rs b/spr/src/commands/init.rs index 98bec11..2d539f6 100644 --- a/spr/src/commands/init.rs +++ b/spr/src/commands/init.rs @@ -74,7 +74,27 @@ pub async fn init() -> Result<()> { return Err(Error::new("Cannot continue without an access token.")); } + let github_api_domain = dialoguer::Input::::new() + .with_prompt("Github API domain (override for Github Enterprise)") + .with_initial_text( + config + .get_string("spr.githubApiDomain") + .ok() + .unwrap_or_else(|| "api.github.com".to_string()), + ) + .interact_text()?; + + config.set_str("spr.githubApiDomain", &github_api_domain)?; + + let api_base_url; + if github_api_domain == "api.github.com" { + api_base_url = "https://api.github.com/v3/".into() + } else { + api_base_url = format!("https://{github_api_domain}/api/v3/"); + }; + let octocrab = octocrab::OctocrabBuilder::new() + .base_url(api_base_url)? .personal_token(pat.clone()) .build()?; let github_user = octocrab.current().user().await?; @@ -121,7 +141,7 @@ pub async fn init() -> Result<()> { let url = repo.find_remote(&remote)?.url().map(String::from); let regex = - lazy_regex::regex!(r#"github\.com[/:]([\w\-\.]+/[\w\-\.]+?)(.git)?$"#); + lazy_regex::regex!(r#"[^/]+[/:]([\w\-\.]+/[\w\-\.]+?)(.git)?$"#); let github_repo = config .get_string("spr.githubRepository") .ok() diff --git a/spr/src/commands/list.rs b/spr/src/commands/list.rs index 81670e0..09e04c0 100644 --- a/spr/src/commands/list.rs +++ b/spr/src/commands/list.rs @@ -32,7 +32,7 @@ pub async fn list( }; let request_body = SearchQuery::build_query(variables); let res = graphql_client - .post("https://api.github.com/graphql") + .post(config.api_base_url() + "graphql") .json(&request_body) .send() .await?; diff --git a/spr/src/config.rs b/spr/src/config.rs index da61598..c75dee7 100644 --- a/spr/src/config.rs +++ b/spr/src/config.rs @@ -18,6 +18,7 @@ pub struct Config { pub branch_prefix: String, pub require_approval: bool, pub require_test_plan: bool, + pub github_api_domain: String, } impl Config { @@ -29,6 +30,7 @@ impl Config { branch_prefix: String, require_approval: bool, require_test_plan: bool, + github_api_domain: String, ) -> Self { let master_ref = GitHubBranch::new_from_branch_name( &master_branch, @@ -43,15 +45,25 @@ impl Config { branch_prefix, require_approval, require_test_plan, + github_api_domain, } } pub fn pull_request_url(&self, number: u64) -> String { - format!( - "https://github.com/{owner}/{repo}/pull/{number}", - owner = &self.owner, - repo = &self.repo - ) + if self.github_api_domain == "api.github.com" { + format!( + "https://github.com/{owner}/{repo}/pull/{number}", + owner = &self.owner, + repo = &self.repo + ) + } else { + format!( + "https://{domain}/{owner}/{repo}/pull/{number}", + domain = &self.github_api_domain, + owner = &self.owner, + repo = &self.repo + ) + } } pub fn parse_pull_request_field(&self, text: &str) -> Option { @@ -66,7 +78,7 @@ impl Config { } let regex = lazy_regex::regex!( - r#"^\s*https?://github.com/([\w\-]+)/([\w\-]+)/pull/(\d+)([/?#].*)?\s*$"# + r#"^\s*https?://[^/]+/([\w\-]+)/([\w\-]+)/pull/(\d+)([/?#].*)?\s*$"# ); let m = regex.captures(text); if let Some(caps) = m { @@ -140,6 +152,17 @@ impl Config { self.master_ref.branch_name(), ) } + + pub fn api_base_url(&self) -> String { + if self.github_api_domain == "api.github.com" { + "https://api.github.com/".into() + } else { + format!( + "https://{domain}/api/", + domain = &self.github_api_domain, + ) + } + } } #[cfg(test)] @@ -156,6 +179,7 @@ mod tests { "spr/foo/".into(), false, true, + "api.github.com".into(), ) } @@ -169,6 +193,25 @@ mod tests { ); } + #[test] + fn test_pull_request_url_github_enterprise() { + let gh = crate::config::Config::new( + "acme".into(), + "codez".into(), + "origin".into(), + "master".into(), + "spr/foo/".into(), + false, + true, + "github.acme.com".into(), + ); + + assert_eq!( + &gh.pull_request_url(123), + "https://github.acme.com/acme/codez/pull/123" + ); + } + #[test] fn test_parse_pull_request_field_empty() { let gh = config_factory(); @@ -229,4 +272,45 @@ mod tests { Some(123) ); } + #[test] + fn test_parse_pull_request_field_url_github_enterprise() { + let gh = config_factory(); + + assert_eq!( + gh.parse_pull_request_field( + "https://github.acme.com/acme/codez/pull/123" + ), + Some(123) + ); + assert_eq!( + gh.parse_pull_request_field( + " https://github.acme.com/acme/codez/pull/123 " + ), + Some(123) + ); + assert_eq!( + gh.parse_pull_request_field( + "https://github.acme.com/acme/codez/pull/123/" + ), + Some(123) + ); + assert_eq!( + gh.parse_pull_request_field( + "https://github.acme.com/acme/codez/pull/123?x=a" + ), + Some(123) + ); + assert_eq!( + gh.parse_pull_request_field( + "https://github.acme.com/acme/codez/pull/123/foo" + ), + Some(123) + ); + assert_eq!( + gh.parse_pull_request_field( + "https://github.acme.com/acme/codez/pull/123#abc" + ), + Some(123) + ); + } } diff --git a/spr/src/github.rs b/spr/src/github.rs index d8849f5..8019ee5 100644 --- a/spr/src/github.rs +++ b/spr/src/github.rs @@ -164,7 +164,7 @@ impl GitHub { }; let request_body = PullRequestQuery::build_query(variables); let res = graphql_client - .post("https://api.github.com/graphql") + .post(config.api_base_url() + "graphql") .json(&request_body) .send() .await?; @@ -440,7 +440,7 @@ impl GitHub { let request_body = PullRequestMergeabilityQuery::build_query(variables); let res = self .graphql_client - .post("https://api.github.com/graphql") + .post(self.config.api_base_url() + "graphql") .json(&request_body) .send() .await?; diff --git a/spr/src/main.rs b/spr/src/main.rs index c1881f9..92e248b 100644 --- a/spr/src/main.rs +++ b/spr/src/main.rs @@ -135,6 +135,9 @@ pub async fn spr() -> Result<()> { .get_bool("spr.requireTestPlan") .ok() .unwrap_or(true); + let github_api_domain = git_config + .get_string("spr.githubApiDomain") + .unwrap_or_else(|_| "api.github.com".to_string()); let config = spr::config::Config::new( github_owner, @@ -144,6 +147,7 @@ pub async fn spr() -> Result<()> { branch_prefix, require_approval, require_test_plan, + github_api_domain, ); let git = spr::git::Git::new(repo); @@ -158,7 +162,9 @@ pub async fn spr() -> Result<()> { }?; octocrab::initialise( - octocrab::Octocrab::builder().personal_token(github_auth_token.clone()), + octocrab::Octocrab::builder() + .base_url(config.api_base_url() + "v3/")? + .personal_token(github_auth_token.clone()) )?; let mut headers = header::HeaderMap::new();