Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 13 additions & 16 deletions openml/tasks/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -492,32 +492,29 @@ def _create_task_from_xml(xml: str) -> OpenMLTask:
"data_set_id": inputs["source_data"]["oml:data_set"]["oml:data_set_id"],
"evaluation_measure": evaluation_measures,
}
# TODO: add OpenMLClusteringTask?
if task_type in (
TaskType.SUPERVISED_CLASSIFICATION,
TaskType.SUPERVISED_REGRESSION,
TaskType.LEARNING_CURVE,
TaskType.CLUSTERING,
):
# Convert some more parameters
for parameter in inputs["estimation_procedure"]["oml:estimation_procedure"][
"oml:parameter"
]:
est_proc = inputs["estimation_procedure"]["oml:estimation_procedure"]
parameters = est_proc.get("oml:parameter", [])
if isinstance(parameters, dict):
parameters = [parameters]
for parameter in parameters:
name = parameter["@name"]
text = parameter.get("#text", "")
estimation_parameters[name] = text

common_kwargs["estimation_procedure_type"] = inputs["estimation_procedure"][
"oml:estimation_procedure"
]["oml:type"]
common_kwargs["estimation_procedure_id"] = int(
inputs["estimation_procedure"]["oml:estimation_procedure"]["oml:id"]
)

common_kwargs["estimation_procedure_type"] = est_proc.get("oml:type")
est_proc_id = est_proc.get("oml:id")
common_kwargs["estimation_procedure_id"] = int(est_proc_id) if est_proc_id else None
common_kwargs["estimation_parameters"] = estimation_parameters
common_kwargs["target_name"] = inputs["source_data"]["oml:data_set"]["oml:target_feature"]
common_kwargs["data_splits_url"] = inputs["estimation_procedure"][
"oml:estimation_procedure"
]["oml:data_splits_url"]
common_kwargs["target_name"] = (
inputs["source_data"]["oml:data_set"].get("oml:target_feature") or None
)
common_kwargs["data_splits_url"] = est_proc.get("oml:data_splits_url")

cls = {
TaskType.SUPERVISED_CLASSIFICATION: OpenMLClassificationTask,
Expand Down
24 changes: 24 additions & 0 deletions tests/test_tasks/test_clustering_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,30 @@ def test_download_task(self):
assert task.task_type_id == TaskType.CLUSTERING
assert task.dataset_id == 36

@pytest.mark.production()
def test_estimation_procedure_extraction(self):
    """Check that a clustering task exposes its estimation procedure details.

    Calls ``self.use_production_server()``, fetches task 126033 without
    downloading its data, and asserts on the task type, the estimation
    procedure id, and the ``type`` / ``parameters`` / ``data_splits_url``
    entries of ``task.estimation_procedure``.
    """
    # Per the original note, task 126033 has complete estimation procedure data.
    self.use_production_server()
    clustering_task = openml.tasks.get_task(126033, download_data=False)

    assert clustering_task.task_type_id == TaskType.CLUSTERING
    assert clustering_task.estimation_procedure_id == 17

    procedure = clustering_task.estimation_procedure
    assert procedure["type"] == "testontrainingdata"

    # Look the parameters up once; both checks below apply to the same value.
    procedure_parameters = procedure["parameters"]
    assert procedure_parameters is not None
    assert "number_repeats" in procedure_parameters

    assert procedure["data_splits_url"] is not None

@pytest.mark.production()
def test_estimation_procedure_empty_fields(self):
    """Check that a clustering task with sparse estimation data still loads.

    Calls ``self.use_production_server()``, fetches the task identified by
    ``self.task_id`` without downloading its data, and asserts on the task
    type and the estimation procedure id.
    """
    # Original note: task 146714 has empty estimation procedure fields in XML.
    # NOTE(review): the id is read from self.task_id, which is assigned
    # outside this view -- confirm it is 146714.
    self.use_production_server()
    fetched_task = openml.tasks.get_task(self.task_id, download_data=False)

    assert fetched_task.task_type_id == TaskType.CLUSTERING
    assert fetched_task.estimation_procedure_id == 17

def test_upload_task(self):
compatible_datasets = self._get_compatible_rand_dataset()
for i in range(100):
Expand Down