Skip to content

Commit fe8b1bc

Browse files
committed
Alignerr project creation
1 parent 2fab8c9 commit fe8b1bc

19 files changed

+1832
-1
lines changed

libs/labelbox/pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ dependencies = [
1212
"tqdm>=4.66.2",
1313
"geojson>=3.1.0",
1414
"lbox-clients==1.1.2",
15+
"PyYAML>=6.0",
1516
]
1617
readme = "README.md"
1718
requires-python = ">=3.9,<3.14"

libs/labelbox/src/labelbox/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@
7878
from labelbox.schema.ontology_kind import OntologyKind
7979
from labelbox.schema.organization import Organization
8080
from labelbox.schema.project import Project
81+
from labelbox.alignerr.schema.project_rate import ProjectRateV2 as ProjectRate
8182
from labelbox.schema.project_model_config import ProjectModelConfig
8283
from labelbox.schema.project_overview import (
8384
ProjectOverview,
@@ -98,7 +99,6 @@
9899
ResponseOption,
99100
PromptResponseClassification,
100101
)
101-
from lbox.exceptions import *
102102
from labelbox.schema.taskstatus import TaskStatus
103103
from labelbox.schema.api_key import ApiKey
104104
from labelbox.schema.timeunit import TimeUnit
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from .alignerr_project import AlignerrWorkspace
2+
3+
__all__ = ['AlignerrWorkspace']
Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
from enum import Enum
2+
from typing import TYPE_CHECKING, Optional
3+
4+
import logging
5+
6+
from labelbox.alignerr.schema.project_rate import ProjectRateV2
7+
from labelbox.alignerr.schema.project_domain import ProjectDomain
8+
from labelbox.pagination import PaginatedCollection
9+
10+
logger = logging.getLogger(__name__)
11+
12+
13+
if TYPE_CHECKING:
14+
from labelbox import Client
15+
from labelbox.schema.project import Project
16+
from labelbox.alignerr.schema.project_domain import ProjectDomain
17+
18+
19+
class AlignerrRole(Enum):
20+
Labeler = "LABELER"
21+
Reviewer = "REVIEWER"
22+
Admin = "ADMIN"
23+
24+
25+
class AlignerrProject:
26+
def __init__(self, client: "Client", project: "Project", _internal: bool = False):
27+
if not _internal:
28+
raise RuntimeError(
29+
"AlignerrProject cannot be initialized directly. "
30+
"Use AlignerrProjectBuilder or AlignerrProjectFactory to create instances."
31+
)
32+
self.client = client
33+
self.project = project
34+
35+
@property
36+
def project(self) -> Optional["Project"]:
37+
return self._project
38+
39+
@project.setter
40+
def project(self, project: "Project"):
41+
self._project = project
42+
43+
44+
def domains(self) -> PaginatedCollection:
45+
"""Get all domains associated with this project.
46+
47+
Returns:
48+
PaginatedCollection of ProjectDomain instances
49+
"""
50+
return ProjectDomain.get_by_project_id(
51+
client=self.client,
52+
project_id=self.project.uid
53+
)
54+
55+
def add_domain(self, project_domain: ProjectDomain):
56+
return ProjectDomain.connect_project_to_domains(
57+
client=self.client,
58+
project_id=self.project.uid,
59+
domain_ids=[project_domain.uid]
60+
)
61+
62+
def get_project_rate(self) -> Optional["ProjectRateV2"]:
63+
return ProjectRateV2.get_by_project_id(
64+
client=self.client,
65+
project_id=self.project.uid
66+
)
67+
68+
def set_project_rate(self, project_rate_input):
69+
return ProjectRateV2.set_project_rate(
70+
client=self.client,
71+
project_id=self.project.uid,
72+
project_rate_input=project_rate_input
73+
)
74+
75+
76+
77+
78+
79+
class AlignerrWorkspace:
80+
def __init__(self, client: "Client"):
81+
self.client = client
82+
83+
def project_builder(self):
84+
from labelbox.alignerr.alignerr_project_builder import AlignerrProjectBuilder
85+
return AlignerrProjectBuilder(self.client)
86+
87+
def project_prototype(self):
88+
from labelbox.alignerr.alignerr_project_factory import AlignerrProjectFactory
89+
return AlignerrProjectFactory(self.client)
90+
91+
Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,154 @@
1+
import datetime
2+
from typing import TYPE_CHECKING, Optional
3+
import logging
4+
5+
from labelbox.alignerr.schema.project_rate import BillingMode
6+
from labelbox.alignerr.schema.project_rate import ProjectRateInput
7+
from labelbox.alignerr.schema.project_rate import ProjectRateV2
8+
from labelbox.alignerr.schema.project_domain import ProjectDomain
9+
from labelbox.schema.media_type import MediaType
10+
11+
logger = logging.getLogger(__name__)
12+
13+
14+
if TYPE_CHECKING:
15+
from labelbox import Client
16+
from labelbox.alignerr.alignerr_project import AlignerrProject, AlignerrRole
17+
18+
19+
class AlignerrProjectBuilder:
20+
def __init__(self, client: "Client"):
21+
self.client = client
22+
self._alignerr_rates: dict[str, ProjectRateInput] = {}
23+
self._customer_rate: ProjectRateInput = None
24+
self._domains: list[ProjectDomain] = []
25+
self.role_name_to_id = self._get_role_name_to_id()
26+
27+
def set_name(self, name: str):
28+
self.project_name = name
29+
return self
30+
31+
def set_media_type(self, media_type: "MediaType"):
32+
self.project_media_type = media_type
33+
return self
34+
35+
def set_alignerr_role_rate(
36+
self,
37+
*,
38+
role_name: "AlignerrRole",
39+
rate: float,
40+
billing_mode: BillingMode,
41+
effective_since: datetime.datetime,
42+
effective_until: Optional[datetime.datetime] = None,
43+
):
44+
if role_name.value not in self.role_name_to_id:
45+
raise ValueError(f"Role {role_name.value} not found")
46+
47+
role_id = self.role_name_to_id[role_name.value]
48+
role_name = role_name.value
49+
50+
# Convert datetime objects to ISO format strings
51+
effective_since_str = effective_since.isoformat() if isinstance(effective_since, datetime.datetime) else effective_since
52+
effective_until_str = effective_until.isoformat() if isinstance(effective_until, datetime.datetime) else effective_until
53+
54+
self._alignerr_rates[role_name] = ProjectRateInput(
55+
rateForId=role_id,
56+
isBillRate=False,
57+
billingMode=billing_mode,
58+
rate=rate,
59+
effectiveSince=effective_since_str,
60+
effectiveUntil=effective_until_str,
61+
)
62+
return self
63+
64+
def set_customer_rate(
65+
self,
66+
*,
67+
rate: float,
68+
billing_mode: BillingMode,
69+
effective_since: datetime.datetime,
70+
effective_until: Optional[datetime.datetime] = None,
71+
):
72+
# Convert datetime objects to ISO format strings
73+
effective_since_str = effective_since.isoformat() if isinstance(effective_since, datetime.datetime) else effective_since
74+
effective_until_str = effective_until.isoformat() if isinstance(effective_until, datetime.datetime) else effective_until
75+
76+
self._customer_rate = ProjectRateInput(
77+
rateForId="", # Empty string for customer rate
78+
isBillRate=True,
79+
billingMode=billing_mode,
80+
rate=rate,
81+
effectiveSince=effective_since_str,
82+
effectiveUntil=effective_until_str,
83+
)
84+
return self
85+
86+
def set_domains(self, domains: list[str]):
87+
for domain in domains:
88+
project_domain_page = ProjectDomain.search(self.client, search_by_name=domain)
89+
domain_result = project_domain_page.get_one()
90+
if domain_result is None:
91+
raise ValueError(f"Domain {domain} not found")
92+
self._domains.append(domain_result)
93+
return self
94+
95+
96+
def create(self, skip_validation: bool = False):
97+
if not skip_validation:
98+
self._validate()
99+
logger.info("Creating project")
100+
101+
project_data = {
102+
"name": self.project_name,
103+
"media_type": self.project_media_type,
104+
}
105+
labelbox_project = self.client.create_project(**project_data)
106+
107+
# Import here to avoid circular imports
108+
from labelbox.alignerr.alignerr_project import AlignerrProject
109+
alignerr_project = AlignerrProject(self.client, labelbox_project, _internal=True)
110+
111+
self._create_rates(alignerr_project)
112+
self._create_domains(alignerr_project)
113+
114+
return alignerr_project
115+
116+
def _create_rates(self, alignerr_project: "AlignerrProject"):
117+
for alignerr_role, project_rate in self._alignerr_rates.items():
118+
logger.info(f"Setting project rate for {alignerr_role}")
119+
alignerr_project.set_project_rate(project_rate)
120+
121+
def _create_domains(self, alignerr_project: "AlignerrProject"):
122+
if self._domains:
123+
logger.info(f"Setting domains: {[domain.name for domain in self._domains]}")
124+
domain_ids = [domain.uid for domain in self._domains]
125+
ProjectDomain.connect_project_to_domains(
126+
client=self.client,
127+
project_id=alignerr_project.project.uid,
128+
domain_ids=domain_ids
129+
)
130+
131+
def _validate_alignerr_rates(self):
132+
# Import here to avoid circular imports
133+
from labelbox.alignerr.alignerr_project import AlignerrRole
134+
135+
required_role_rates = set([AlignerrRole.Labeler.value, AlignerrRole.Reviewer.value])
136+
137+
for role_name in self._alignerr_rates.keys():
138+
required_role_rates.remove(role_name)
139+
if len(required_role_rates) > 0:
140+
raise ValueError(
141+
f"Required role rates are not set: {required_role_rates}"
142+
)
143+
144+
def _validate_customer_rate(self):
145+
if self._customer_rate is None:
146+
raise ValueError("Customer rate is not set")
147+
148+
def _validate(self):
149+
self._validate_alignerr_rates()
150+
self._validate_customer_rate()
151+
152+
def _get_role_name_to_id(self) -> dict[str, str]:
153+
roles = self.client.get_roles()
154+
return {role.name: role.uid for role in roles.values()}

0 commit comments

Comments
 (0)