Skip to content

Commit 586895c

Browse files
committed
feat: Add REST API fallback for PR files when GraphQL fails (chaoss#2875)
Signed-off-by: Xiaoha <blairjade183@gmail.com>
1 parent 3e228c7 commit 586895c

File tree

4 files changed

+693
-199
lines changed

4 files changed

+693
-199
lines changed

augur/tasks/github/pull_requests/files_model/core.py

+308-60
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,315 @@
1+
import httpx
2+
from typing import List, Dict, Any
3+
from augur.tasks.github.util.github_graphql_data_access import (
4+
GithubGraphQlDataAccess,
5+
)
6+
from augur.application.db.models import (
7+
PullRequest,
8+
Repo,
9+
PullRequestFile,
10+
)
11+
from augur.application.db.session import DatabaseSession
12+
import logging
13+
from augur.tasks.github.util.github_random_key_auth import GithubRandomKeyAuth
14+
from augur.tasks.github.util.retry import (
15+
retry_on_exception,
16+
RateLimitError,
17+
AuthenticationError,
18+
)
19+
from augur.tasks.github.util.github_api_errors import GitHubAPIError
120
import sqlalchemy as s
2-
from augur.tasks.github.util.github_graphql_data_access import GithubGraphQlDataAccess, NotFoundException, InvalidDataException
3-
from augur.application.db.models import *
421
from augur.tasks.github.util.util import get_owner_repo
522
from augur.application.db.util import execute_session_query
623
from augur.application.db.lib import get_secondary_data_last_collected, get_updated_prs
724

25+
logger = logging.getLogger(__name__)
826

9-
def pull_request_files_model(repo_id,logger, augur_db, key_auth, full_collection=False):
10-
27+
28+
@retry_on_exception(
    retries=3,
    delay=1.0,
    backoff=2.0,
    exceptions=(GitHubAPIError, RateLimitError)
)
def get_pr_files_from_rest_api(
    owner: str,
    repo: str,
    pr_number: int,
    key_auth: GithubRandomKeyAuth
) -> List[Dict[str, Any]]:
    """
    Get PR files using REST API as a fallback.

    Requests ``/repos/{owner}/{repo}/pulls/{pr_number}/files`` with 100
    entries per page and keeps following the ``rel="next"`` link of each
    response until there is none, so PRs with more files than fit on one
    page are returned completely.

    Args:
        owner: Repository owner
        repo: Repository name
        pr_number: Pull request number
        key_auth: GitHub authentication, passed to ``httpx.Client(auth=...)``

    Returns:
        List of file objects decoded from the JSON responses. When a
        request is answered with 404 or 422, whatever was collected before
        that response is returned (an empty list if it was the first page).

    Raises:
        GitHubAPIError: On any other HTTP or unexpected error
        RateLimitError: When the API answers 403 or 429
        AuthenticationError: When the API answers 401
    """
    url = (
        f"https://api.github.com/repos/{owner}/{repo}/pulls/{pr_number}/files"
    )
    headers = {"Accept": "application/vnd.github.v3+json"}
    files: List[Dict[str, Any]] = []

    try:
        # The context manager closes the client even if a request raises.
        with httpx.Client(auth=key_auth) as client:
            next_url = url
            params = {"per_page": 100}

            while next_url:
                response = client.get(next_url, headers=headers, params=params)
                status = response.status_code

                # Only error statuses are treated as rate limiting. A
                # successful response whose remaining quota is 0 (or that
                # has no rate-limit headers at all) still carries valid data.
                if status in (403, 429):
                    reset_time = int(
                        response.headers.get("X-RateLimit-Reset", 0)
                    )
                    raise RateLimitError(reset_time)

                if status == 401:
                    raise AuthenticationError("Invalid GitHub token")
                if status == 404:
                    logger.warning(f"PR {pr_number} not found")
                    return files
                if status == 422:
                    logger.warning(f"PR {pr_number} diff not available")
                    return files

                response.raise_for_status()
                files.extend(response.json())

                # The "next" link already carries the paging query string,
                # so no explicit params are sent on follow-up requests.
                next_url = response.links.get("next", {}).get("url")
                params = None

        return files

    except (RateLimitError, AuthenticationError):
        # Propagate unchanged so callers (and the retry decorator) can tell
        # these apart from generic API failures.
        raise
    except httpx.HTTPError as e:
        raise GitHubAPIError(f"HTTP error occurred: {str(e)}") from e
    except Exception as e:
        raise GitHubAPIError(f"Error getting PR files: {str(e)}") from e
92+
93+
94+
@retry_on_exception(
    retries=3,
    delay=1.0,
    backoff=2.0,
    exceptions=(GitHubAPIError, RateLimitError)
)
def get_pr_files_from_graphql(
    owner: str,
    repo: str,
    pr_number: int,
    client: GithubGraphQlDataAccess
) -> List[Dict[str, Any]]:
    """
    Get PR files using GraphQL API.

    Args:
        owner: Repository owner
        repo: Repository name
        pr_number: Pull request number
        client: GraphQL client whose ``paginate_resource`` runs the query

    Returns:
        List of file records produced by ``client.paginate_resource``;
        an empty list when it produces nothing.

    Raises:
        GitHubAPIError: On API errors
        RateLimitError: When the error text contains "rate limit exceeded"
    """
    # This keeps the paginated query shape ($numRecords/$cursor, edges,
    # totalCount, pageInfo) this module previously passed to
    # paginate_resource, with only `changeType` added to the node selection.
    # NOTE(review): presumably paginate_resource supplies $numRecords and
    # $cursor itself, as the earlier caller never passed them - confirm
    # against GithubGraphQlDataAccess.
    query = """
        query($repo: String!, $owner: String!, $pr_number: Int!, $numRecords: Int!, $cursor: String) {
            repository(name: $repo, owner: $owner) {
                pullRequest(number: $pr_number) {
                    files(first: $numRecords, after: $cursor) {
                        edges {
                            node {
                                path
                                additions
                                deletions
                                changeType
                            }
                        }
                        totalCount
                        pageInfo {
                            hasNextPage
                            endCursor
                        }
                    }
                }
            }
        }
    """
    variables = {"owner": owner, "repo": repo, "pr_number": pr_number}

    try:
        result = client.paginate_resource(
            query,
            variables,
            ["repository", "pullRequest", "files"]
        )
        # Materialize inside the try block: if the paginator is lazy, this
        # is where its errors surface, and an unconsumed iterator would
        # always look non-empty in the check below.
        files = list(result) if result is not None else []
    except Exception as e:
        if "rate limit exceeded" in str(e).lower():
            # GraphQL doesn't provide reset time
            raise RateLimitError(0) from e
        raise GitHubAPIError(f"GraphQL error: {str(e)}") from e

    if not files:
        logger.warning(f"PR {pr_number} not found in GraphQL")
        return []
    return files
154+
155+
156+
def _split_owner_repo(repo_git: str):
    """Return (owner, repo_name) taken from the last two path segments of a git URL."""
    owner, repo_name = repo_git.rstrip("/").split("/")[-2:]
    # Strip only a trailing ".git"; a blanket replace would also mangle
    # names such as "user.github.io".
    if repo_name.endswith(".git"):
        repo_name = repo_name[:-len(".git")]
    return owner, repo_name


def _fetch_pr_files(
    owner: str,
    repo_name: str,
    pr_number: int,
    key_auth: GithubRandomKeyAuth
) -> List[Dict[str, Any]]:
    """Fetch raw file records for a PR, trying GraphQL first and REST second."""
    try:
        files = get_pr_files_from_graphql(
            owner,
            repo_name,
            pr_number,
            GithubGraphQlDataAccess(key_auth, logger)
        )
        # Normalize to a list so the emptiness check below is meaningful
        # even if a lazy iterable comes back.
        files = list(files) if files else []
    except (GitHubAPIError, RateLimitError):
        logger.warning(
            "GraphQL query failed for PR {}, "
            "falling back to REST API".format(pr_number)
        )
    else:
        if files:
            return files
        logger.warning(
            "GraphQL returned no files for PR {}, "
            "trying REST API".format(pr_number)
        )

    # Single REST attempt, whichever way GraphQL came up empty.
    try:
        return get_pr_files_from_rest_api(
            owner,
            repo_name,
            pr_number,
            key_auth
        )
    except (GitHubAPIError, RateLimitError):
        logger.error(
            "REST API also failed for PR {}".format(pr_number)
        )
        return []


def _normalize_pr_files(
    files: List[Dict[str, Any]],
    pull_request_id: Any,
    repo_id: Any,
    pr_number: int
) -> List[Dict[str, Any]]:
    """Convert raw GraphQL or REST file records into row dicts for storage."""
    rows = []
    for file_data in files:
        if not file_data:
            continue

        # GraphQL records are accepted both wrapped ({"node": {...}}) and
        # flat ({"path": ...}); REST records are keyed by "filename".
        if "node" in file_data or "path" in file_data:
            node = file_data.get("node") or file_data
            if not node.get("path"):
                logger.warning(
                    "Skipping file with no path in PR {}".format(pr_number)
                )
                continue

            # NOTE(review): 'tool_source' carries the file's change status
            # here, as in the original code - confirm that is intended.
            rows.append({
                'pull_request_id': pull_request_id,
                'repo_id': repo_id,
                'pr_file_path': node["path"],
                'pr_file_additions': node.get("additions"),
                'pr_file_deletions': node.get("deletions"),
                'tool_source': (node.get("changeType") or "modified").lower(),
                'tool_version': "1.0",
                'data_source': "GitHub GraphQL API",
            })
        else:
            if not file_data.get("filename"):
                logger.warning(
                    "Skipping file with no filename in PR {}".format(pr_number)
                )
                continue

            rows.append({
                'pull_request_id': pull_request_id,
                'repo_id': repo_id,
                'pr_file_path': file_data["filename"],
                'pr_file_additions': file_data.get("additions", 0),
                'pr_file_deletions': file_data.get("deletions", 0),
                'tool_source': (file_data.get("status") or "modified").lower(),
                'tool_version': "1.0",
                'data_source': "GitHub REST API",
            })

    return rows


def collect_pull_request_files(repo_id: int, pr_number: int, key_auth: GithubRandomKeyAuth) -> List[Dict[str, Any]]:
    """
    Collect files for a pull request using both GraphQL and REST APIs.

    Looks up the PR and its repo in the database, fetches the PR's files
    (GraphQL first, REST as fallback) and converts them into row dicts.

    Args:
        repo_id: Repository ID
        pr_number: Pull request number
        key_auth: GitHub authentication

    Returns:
        List of row dicts with the keys pull_request_id, repo_id,
        pr_file_path, pr_file_additions, pr_file_deletions, tool_source,
        tool_version and data_source. Empty when the PR or repo is not in
        the database, when no files were obtained, or when any error
        occurred (the error is logged, not raised).
    """
    try:
        # NOTE(review): DatabaseSession is constructed without arguments, as
        # in the original code - confirm its constructor allows that.
        with DatabaseSession() as session:
            pr = (
                session.query(PullRequest)
                .filter(
                    PullRequest.repo_id == repo_id,
                    PullRequest.pr_src_number == pr_number,
                )
                .first()
            )
            if not pr:
                logger.warning(f"PR {pr_number} not found in database")
                return []

            repo = session.query(Repo).filter(Repo.repo_id == repo_id).first()
            if not repo:
                logger.warning(f"Repo {repo_id} not found in database")
                return []

            # Copy what is needed so the session can be released before the
            # network calls, which may retry with backoff.
            pull_request_id = pr.pull_request_id
            db_repo_id = repo.repo_id
            repo_git = repo.repo_git

        owner, repo_name = _split_owner_repo(repo_git)

        files = _fetch_pr_files(owner, repo_name, pr_number, key_auth)
        if not files:
            logger.warning(f"No files found for PR {pr_number}")
            return []

        return _normalize_pr_files(files, pull_request_id, db_repo_id, pr_number)

    except Exception as e:
        # Best-effort collection: one failing PR is logged and skipped.
        logger.error(
            "Error processing PR {}: {}".format(pr_number, str(e)),
            exc_info=True
        )
        return []
291+
292+
293+
def pull_request_files_model(repo_id, logger, augur_db, key_auth, full_collection=False):
294+
"""
295+
Main function to collect PR files for a repository
296+
297+
Args:
298+
repo_id: Repository ID
299+
logger: Logger instance
300+
augur_db: Database session
301+
key_auth: GitHub authentication
302+
full_collection: Whether to collect all PRs or only updated ones
303+
"""
11304
if full_collection:
12305
# query existing PRs and the respective url we will append the commits url to
13306
pr_number_sql = s.sql.text("""
14307
SELECT DISTINCT pr_src_number as pr_src_number, pull_requests.pull_request_id
15-
FROM pull_requests--, pull_request_meta
308+
FROM pull_requests
16309
WHERE repo_id = :repo_id
17310
""").bindparams(repo_id=repo_id)
18-
pr_numbers = []
19-
#pd.read_sql(pr_number_sql, self.db, params={})
20311

21-
result = augur_db.execute_sql(pr_number_sql)#.fetchall()
312+
result = augur_db.execute_sql(pr_number_sql)
22313
pr_numbers = [dict(row) for row in result.mappings()]
23314

24315
else:
@@ -36,63 +327,20 @@ def pull_request_files_model(repo_id,logger, augur_db, key_auth, full_collection
36327
repo = execute_session_query(query, 'one')
37328
owner, name = get_owner_repo(repo.repo_git)
38329

39-
github_graphql_data_access = GithubGraphQlDataAccess(key_auth, logger)
40-
41-
pr_file_rows = []
42330
logger.info(f"Getting pull request files for repo: {repo.repo_git}")
331+
pr_file_rows = []
332+
43333
for index, pr_info in enumerate(pr_numbers):
44-
45334
logger.info(f'Querying files for pull request #{index + 1} of {len(pr_numbers)}')
46335

47-
query = """
48-
query($repo: String!, $owner: String!,$pr_number: Int!, $numRecords: Int!, $cursor: String) {
49-
repository(name: $repo, owner: $owner) {
50-
pullRequest(number: $pr_number) {
51-
files ( first: $numRecords, after: $cursor) {
52-
edges {
53-
node {
54-
additions
55-
deletions
56-
path
57-
}
58-
}
59-
totalCount
60-
pageInfo {
61-
hasNextPage
62-
endCursor
63-
}
64-
}
65-
}
66-
}
67-
}
68-
"""
336+
files = collect_pull_request_files(
337+
repo_id,
338+
pr_info['pr_src_number'],
339+
key_auth
340+
)
69341

70-
values = ["repository", "pullRequest", "files"]
71-
params = {
72-
'owner': owner,
73-
'repo': name,
74-
'pr_number': pr_info['pr_src_number'],
75-
}
76-
77-
try:
78-
for pr_file in github_graphql_data_access.paginate_resource(query, params, values):
79-
80-
if not pr_file or 'path' not in pr_file:
81-
continue
82-
83-
data = {
84-
'pull_request_id': pr_info['pull_request_id'],
85-
'pr_file_additions': pr_file['additions'] if 'additions' in pr_file else None,
86-
'pr_file_deletions': pr_file['deletions'] if 'deletions' in pr_file else None,
87-
'pr_file_path': pr_file['path'],
88-
'data_source': 'GitHub API',
89-
'repo_id': repo.repo_id,
90-
}
91-
92-
pr_file_rows.append(data)
93-
except (NotFoundException, InvalidDataException) as e:
94-
logger.warning(e)
95-
continue
342+
if files:
343+
pr_file_rows.extend(files)
96344

97345
if len(pr_file_rows) > 0:
98346
# Execute a bulk upsert with sqlalchemy

0 commit comments

Comments
 (0)