From 7e5abab3bcb914ff592b9962600cfe8986008ea2 Mon Sep 17 00:00:00 2001 From: jornh Date: Tue, 21 May 2024 20:36:43 +0000 Subject: [PATCH] initial enterprise --- internal/github/pr._test.go | 33 ++++++++++++++++++++++++ internal/github/pr.go | 51 +++++++++++++++++++++++++++---------- 2 files changed, 71 insertions(+), 13 deletions(-) create mode 100644 internal/github/pr._test.go diff --git a/internal/github/pr._test.go b/internal/github/pr._test.go new file mode 100644 index 0000000..a7b8da9 --- /dev/null +++ b/internal/github/pr._test.go @@ -0,0 +1,33 @@ +package github + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestParseGithubURL(t *testing.T) { + url := "https://github.com/foo/bar" + var enterprise bool + var baseURL string + var exists string + var err error + + enterprise, baseURL, exists, _, err = parseGitHubURL(url) + require.NoError(t, err) + require.Equal(t, false, enterprise) + require.Equal(t, "https://github.com", baseURL) + require.Equal(t, "foo", exists) + + url = "https://mygithub.co.uk/foo-BAR/bar" + // url = "https://otherdomain.example/baz/bar" + enterprise, baseURL, exists, _, err = parseGitHubURL(url) + require.NoError(t, err) + require.Equal(t, true, enterprise) + require.Equal(t, "https://mygithub.co.uk", baseURL) + require.Equal(t, "foo-BAR", exists) + // file = "bogus.go" + // exists, err = Exists(file) + // require.NoError(t, err) + // require.False(t, exists) +} diff --git a/internal/github/pr.go b/internal/github/pr.go index 1502317..3185570 100644 --- a/internal/github/pr.go +++ b/internal/github/pr.go @@ -21,18 +21,36 @@ func OpenPR( commitBranch string, repoCreds git.RepoCredentials, ) (string, error) { - owner, repo, err := parseGitHubURL(repoURL) + var baseURL string + isEnterprise, baseURL, owner, repo, err := parseGitHubURL(repoURL) if err != nil { return "", err } - githubClient := github.NewClient( - oauth2.NewClient( - ctx, - oauth2.StaticTokenSource( - &oauth2.Token{AccessToken: repoCreds.Password}, + var githubClient *github.Client + if isEnterprise { + githubClient, err = github.NewEnterpriseClient( + baseURL, + baseURL, + oauth2.NewClient( + ctx, + oauth2.StaticTokenSource( + &oauth2.Token{AccessToken: repoCreds.Password}, + ), ), - ), - ) + ) + if err != nil { + return "", err + } + } else { + githubClient = github.NewClient( + oauth2.NewClient( + ctx, + oauth2.StaticTokenSource( + &oauth2.Token{AccessToken: repoCreds.Password}, + ), + ), + ) + } pr, _, err := githubClient.PullRequests.Create( ctx, owner, @@ -57,11 +75,18 @@ func OpenPR( return *pr.HTMLURL, nil } -func parseGitHubURL(url string) (string, string, error) { - regex := regexp.MustCompile(`^https\://github\.com/([\w-]+)/([\w-]+).*`) +func parseGitHubURL(url string) (bool, string, string, string, error) { + // regex := regexp.MustCompile(`^https\://github\.[\w]/([\w-]+)/([\w-]+).*`) + regex := regexp.MustCompile(`^https\://([\w.-]+)/([\w-]+)/([\w-]+).*`) parts := regex.FindStringSubmatch(url) - if len(parts) != 3 { - return "", "", fmt.Errorf("error parsing github repository URL %q", url) + if len(parts) != 4 { + return false, "", "", "", fmt.Errorf("error parsing github repository URL %q", url) + } + isEnterprise := false + if parts[1] != "github.com" { + isEnterprise = true } - return parts[1], parts[2], nil + baseURL := "https://" + parts[1] + + return isEnterprise, baseURL, parts[2], parts[3], nil }