Skip to content

Commit 449d3a4

Browse files
committed
make default github provider configureable w/r to api endpoints (allow connecting to other github instances, e.g. ghec)
1 parent f9fa41d commit 449d3a4

File tree

4 files changed

+82
-9
lines changed

4 files changed

+82
-9
lines changed

README.md

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -404,6 +404,21 @@ Add custom AI providers:
404404
- `copilot` - GitHub Copilot (default)
405405
- `github_models` - GitHub Marketplace models (disabled by default)
406406

407+
## Github Enterprise
408+
409+
If your employer provides access to Copilot via a Github Enterprise instance ("GHEC") you can provide the respective URLs with the following config keys:
410+
411+
```lua
412+
{
413+
-- github instance main address w/o protocol prefix, default: "github.com" (without "https://"). E.g. a github-enterprise address might look like this: "mycorp.ghe.com"
414+
github_instance_url = 'mycorp.ghe.com',
415+
-- github instance api address w/o protocol prefix, default: "api.github.com" (without "https://"). E.g.: "api.mycorp.ghe.com"
416+
github_instance_api_url = 'api.mycorp.ghe.com',
417+
}
418+
```
419+
420+
(These keys are used in the default Copilot "provider", this is an alternative to defining a full custom provider)
421+
407422
# API Reference
408423

409424
## Core

lua/CopilotChat/config.lua

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,8 @@
4949
---@field functions table<string, CopilotChat.config.functions.Function>?
5050
---@field prompts table<string, CopilotChat.config.prompts.Prompt|string>?
5151
---@field mappings CopilotChat.config.mappings?
52+
---@field github_instance_url string?
53+
---@field github_instance_api_url string?
5254
return {
5355

5456
-- Shared config starts here (can be passed to functions at runtime and configured via setup function)
@@ -102,6 +104,9 @@ return {
102104

103105
chat_autocomplete = true, -- Enable chat autocompletion (when disabled, requires manual `mappings.complete` trigger)
104106

107+
github_instance_url = 'github.com', -- github instance main address w/o protocol prefix (without "https://"). E.g. a github-enterprise address might look like this: "mycorp.ghe.com"
108+
github_instance_api_url = 'api.github.com', -- github instance api address w/o protocol prefix (without "https://"). E.g.: "api.mycorp.ghe.com"
109+
105110
log_path = vim.fn.stdpath('state') .. '/CopilotChat.log', -- Default path to log file
106111
history_path = vim.fn.stdpath('data') .. '/copilotchat_history', -- Default path to stored history
107112

lua/CopilotChat/config/providers.lua

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,22 @@ local constants = require('CopilotChat.constants')
22
local notify = require('CopilotChat.notify')
33
local utils = require('CopilotChat.utils')
44
local plenary_utils = require('plenary.async.util')
5+
local log = require('plenary.log')
56

67
local EDITOR_VERSION = 'Neovim/' .. vim.version().major .. '.' .. vim.version().minor .. '.' .. vim.version().patch
78

9+
---@class CopilotChat
10+
---@field config CopilotChat.config.Config
11+
---@field chat CopilotChat.ui.chat.Chat
12+
local MC = setmetatable({}, {
13+
__index = function(t, key)
14+
if key == 'config' then
15+
return require('CopilotChat.config')
16+
end
17+
return rawget(t, key)
18+
end,
19+
})
20+
821
local token_cache = nil
922
local unsaved_token_cache = {}
1023
local function load_tokens()
@@ -50,7 +63,7 @@ end
5063
---@return string
5164
local function github_device_flow(tag, client_id, scope)
5265
local function request_device_code()
53-
local res = utils.curl_post('https://github.com/login/device/code', {
66+
local res = utils.curl_post('https://' .. MC.config.github_instance_url .. '/login/device/code', {
5467
body = {
5568
client_id = client_id,
5669
scope = scope,
@@ -66,7 +79,7 @@ local function github_device_flow(tag, client_id, scope)
6679
while true do
6780
plenary_utils.sleep(interval * 1000)
6881

69-
local res = utils.curl_post('https://github.com/login/oauth/access_token', {
82+
local res = utils.curl_post('https://' .. MC.config.github_instance_url .. '/login/oauth/access_token', {
7083
body = {
7184
client_id = client_id,
7285
device_code = device_code,
@@ -146,7 +159,7 @@ local function get_github_copilot_token(tag)
146159
local parsed_data = utils.json_decode(file_data)
147160
if parsed_data then
148161
for key, value in pairs(parsed_data) do
149-
if string.find(key, 'github.com') and value and value.oauth_token then
162+
if string.find(key, MC.config.github_instance_url) and value and value.oauth_token then
150163
return set_token(tag, value.oauth_token, false)
151164
end
152165
end
@@ -173,7 +186,7 @@ local function get_github_models_token(tag)
173186

174187
-- loading token from gh cli if available
175188
if vim.fn.executable('gh') == 0 then
176-
local result = utils.system({ 'gh', 'auth', 'token', '-h', 'github.com' })
189+
local result = utils.system({ 'gh', 'auth', 'token', '-h', MC.config.github_instance_url })
177190
if result and result.code == 0 and result.stdout then
178191
local gh_token = vim.trim(result.stdout)
179192
if gh_token ~= '' and not gh_token:find('no oauth token') then
@@ -214,10 +227,12 @@ M.copilot = {
214227
endpoints_api = '',
215228

216229
get_headers = function()
217-
local response, err = utils.curl_get('https://api.github.com/copilot_internal/v2/token', {
230+
local url = 'https://' .. MC.config.github_instance_api_url .. '/copilot_internal/v2/token'
231+
log.debug('get headers - get ' .. url)
232+
local response, err = utils.curl_get(url, {
218233
json_response = true,
219234
headers = {
220-
['Authorization'] = 'Token ' .. get_github_copilot_token('github_copilot'),
235+
['Authorization'] = 'Token ' .. get_github_copilot_token(MC.config.github_instance_api_url),
221236
},
222237
})
223238

@@ -249,10 +264,10 @@ M.copilot = {
249264
end,
250265

251266
get_info = function(headers)
252-
local response, err = utils.curl_get('https://api.github.com/copilot_internal/user', {
267+
local response, err = utils.curl_get('https://' .. MC.config.github_instance_url .. '/copilot_internal/user', {
253268
json_response = true,
254269
headers = {
255-
['Authorization'] = 'Token ' .. get_github_copilot_token('github_copilot'),
270+
['Authorization'] = 'Token ' .. get_github_copilot_token(MC.config.github_instance_url),
256271
},
257272
})
258273

@@ -299,7 +314,7 @@ M.copilot = {
299314
end,
300315

301316
get_models = function(headers)
302-
local response, err = utils.curl_get('https://api.githubcopilot.com/models', {
317+
log.info('getting models .. headers: ' .. utils.to_string(headers))
303318
local response, err = utils.curl_get(M.endpoints_api .. '/models', {
304319
json_response = true,
305320
headers = headers,

lua/CopilotChat/utils.lua

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -450,6 +450,44 @@ M.curl_post = async.wrap(function(url, opts, callback)
450450
curl.post(url, args)
451451
end, 3)
452452

453+
function M.to_string(tbl)
454+
-- credit: http://lua-users.org/wiki/TableSerialization (universal tostring)
455+
local function table_print(tt, indent, done)
456+
done = done or {}
457+
indent = indent or 0
458+
if type(tt) == 'table' then
459+
local sb = {}
460+
for key, value in pairs(tt) do
461+
table.insert(sb, string.rep(' ', indent)) -- indent it
462+
if type(value) == 'table' and not done[value] then
463+
done[value] = true
464+
table.insert(sb, key .. ' = {\n')
465+
table.insert(sb, table_print(value, indent + 2, done))
466+
table.insert(sb, string.rep(' ', indent)) -- indent it
467+
table.insert(sb, '}\n')
468+
elseif 'number' == type(key) then
469+
table.insert(sb, string.format('"%s"\n', tostring(value)))
470+
else
471+
table.insert(sb, string.format('%s = "%s"\n', tostring(key), tostring(value)))
472+
end
473+
end
474+
return table.concat(sb)
475+
else
476+
return tt .. '\n'
477+
end
478+
end
479+
480+
if 'nil' == type(tbl) then
481+
return tostring(nil)
482+
elseif 'table' == type(tbl) then
483+
return table_print(tbl)
484+
elseif 'string' == type(tbl) then
485+
return tbl
486+
else
487+
return tostring(tbl)
488+
end
489+
end
490+
453491
local function filter_files(files, max_count)
454492
local filetype = require('plenary.filetype')
455493

0 commit comments

Comments
 (0)