Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions cmlutils/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@


class BaseWorkspaceInteractor(object):
# use a variable to store the apiv2 key for repeated use instead
_apiv2_key = None

def __init__(
self,
host: str,
Expand All @@ -27,6 +30,9 @@ def __init__(

@property
def apiv2_key(self) -> str:
if self._apiv2_key:
return self._apiv2_key

endpoint = Template(ApiV1Endpoints.API_KEY.value).substitute(
username=self.username
)
Expand All @@ -44,8 +50,8 @@ def apiv2_key(self) -> str:
ca_path=self.ca_path,
)
response_dict = response.json()
_apiv2_key = response_dict["apiKey"]
return _apiv2_key
self._apiv2_key = response_dict["apiKey"]
return self._apiv2_key

def remove_cdswctl_dir(self, file_path: str):
if os.path.exists(file_path):
Expand Down
25 changes: 14 additions & 11 deletions cmlutils/project_entrypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,12 @@ def project_import_cmd(project_name, verify):
uses_engine = True
project_metadata.pop("default_project_engine_type", None)

project_id = p.check_project_exist(project_metadata["name"])
# check if the project to be imported is a team's project
if "team_name" in project_metadata:
project_id = p.check_project_exist(project_metadata["name"], project_metadata["team_name"])
else:
project_id = p.check_project_exist(project_metadata["name"])

if project_id == None:
logging.info(
"Creating project %s to migrate files and metadata.", project_name
Expand All @@ -288,22 +293,20 @@ def project_import_cmd(project_name, verify):
if "team_name" in project_metadata:
username = project_metadata["team_name"]
creator_username, project_slug = p.get_creator_username()
pimport = ProjectImporter(
host=url,
username=username,
project_name=project_name,
api_key=apiv1_key,
top_level_dir=local_directory,
ca_path=ca_path,
project_slug=project_slug,
)

# reuse the ProjectImporter obj since it already generated the apiv2 key
# this fixed the bug of team projects import where cmlutil was trying to
# generate apiv2 key using the team's username
pimport = p
pimport.username = username
pimport.project_slug = project_slug

start_time = time.time()
if verify:
import_diff_file_list=pimport.transfer_project(log_filedir=log_filedir, verify=True)
else:
pimport.transfer_project(log_filedir=log_filedir)


if uses_engine:
proj_patch_metadata = {"default_project_engine_type": "legacy_engine"}
pimport.convert_project_to_engine_based(
Expand Down
27 changes: 22 additions & 5 deletions cmlutils/projects.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,9 +324,11 @@ def get_creator_username(self):
# Handle Pagination if exists
while next_page_exists:
# Note - projectName param makes LIKE query not the exact match
# If projectname has back slashes, it need to be escaped before passing to the API
escaped_project_name = self.project_name.replace("\\", "\\\\")
endpoint = Template(ApiV1Endpoints.PROJECTS_SUMMARY.value).substitute(
username=self.username,
projectName=self.project_name,
projectName=escaped_project_name,
limit=constants.MAX_API_PAGE_LENGTH,
offset=offset * constants.MAX_API_PAGE_LENGTH,
)
Expand Down Expand Up @@ -358,7 +360,10 @@ def get_creator_username(self):

if project_list:
for project in project_list:
if project["name"] == self.project_name:
# It is possible that project lists can contain other users' public projects, or team's projects
# so there could be projects that has the same name but belong to other users. To ensure that
# we identify the correct project, we need to compare the project owner's name too.
if project["name"] == self.project_name and project["owner"]["username"] == self.username:
if project["owner"]["type"] == constants.ORGANIZATION_TYPE:
return (
project["owner"]["username"],
Expand Down Expand Up @@ -927,9 +932,11 @@ def get_creator_username(self):
# Handle Pagination if exists
while next_page_exists:
# Note - projectName param makes LIKE query not the exact match
# If projectname has back slashes, it need to be escaped before passing to the API
escaped_project_name = self.project_name.replace("\\", "\\\\")
endpoint = Template(ApiV1Endpoints.PROJECTS_SUMMARY.value).substitute(
username=self.username,
projectName=self.project_name,
projectName=escaped_project_name,
limit=constants.MAX_API_PAGE_LENGTH,
offset=offset * constants.MAX_API_PAGE_LENGTH,
)
Expand Down Expand Up @@ -1243,7 +1250,7 @@ def get_all_runtimes_v2(self, page_token=""):
return result_list
return None

def check_project_exist(self, project_name: str) -> str:
def check_project_exist(self, project_name: str, team_name: str = None) -> str:
try:
search_option = {"name": project_name}
encoded_option = urllib.parse.quote(
Expand All @@ -1260,9 +1267,19 @@ def check_project_exist(self, project_name: str) -> str:
ca_path=self.ca_path,
)
project_list = response.json()["projects"]

# If the project is a team's project, then the owner of the project is the team
if team_name:
owner = team_name
else:
owner = self.username

# It is possible that project lists can contain other users' public projects, or team's projects
# so there could be projects that has the same name but belong to other users. To ensure that
# we identify the correct project, we need to compare the project owner's name too.
if project_list:
for project in project_list:
if project["name"] == project_name:
if project["name"] == project_name and project["owner"]["username"] == owner:
return project["id"]
return None
except KeyError as e:
Expand Down