Skip to content

Commit af7fc8a

Browse files
committed
Alignerr project creation
1 parent 2fab8c9 commit af7fc8a

19 files changed

+1927
-1
lines changed

libs/labelbox/pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ dependencies = [
1212
"tqdm>=4.66.2",
1313
"geojson>=3.1.0",
1414
"lbox-clients==1.1.2",
15+
"PyYAML>=6.0",
1516
]
1617
readme = "README.md"
1718
requires-python = ">=3.9,<3.14"

libs/labelbox/src/labelbox/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@
7878
from labelbox.schema.ontology_kind import OntologyKind
7979
from labelbox.schema.organization import Organization
8080
from labelbox.schema.project import Project
81+
from labelbox.alignerr.schema.project_rate import ProjectRateV2 as ProjectRate
8182
from labelbox.schema.project_model_config import ProjectModelConfig
8283
from labelbox.schema.project_overview import (
8384
ProjectOverview,
@@ -98,7 +99,6 @@
9899
ResponseOption,
99100
PromptResponseClassification,
100101
)
101-
from lbox.exceptions import *
102102
from labelbox.schema.taskstatus import TaskStatus
103103
from labelbox.schema.api_key import ApiKey
104104
from labelbox.schema.timeunit import TimeUnit
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from .alignerr_project import AlignerrWorkspace
2+
3+
__all__ = ["AlignerrWorkspace"]
Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
from enum import Enum
2+
from typing import TYPE_CHECKING, Optional
3+
4+
import logging
5+
6+
from labelbox.alignerr.schema.project_rate import ProjectRateV2
7+
from labelbox.alignerr.schema.project_domain import ProjectDomain
8+
from labelbox.pagination import PaginatedCollection
9+
10+
logger = logging.getLogger(__name__)
11+
12+
13+
if TYPE_CHECKING:
14+
from labelbox import Client
15+
from labelbox.schema.project import Project
16+
from labelbox.alignerr.schema.project_domain import ProjectDomain
17+
18+
19+
class AlignerrRole(Enum):
20+
Labeler = "LABELER"
21+
Reviewer = "REVIEWER"
22+
Admin = "ADMIN"
23+
24+
25+
class AlignerrProject:
26+
def __init__(
27+
self, client: "Client", project: "Project", _internal: bool = False
28+
):
29+
if not _internal:
30+
raise RuntimeError(
31+
"AlignerrProject cannot be initialized directly. "
32+
"Use AlignerrProjectBuilder or AlignerrProjectFactory to create instances."
33+
)
34+
self.client = client
35+
self.project = project
36+
37+
@property
38+
def project(self) -> Optional["Project"]:
39+
return self._project
40+
41+
@project.setter
42+
def project(self, project: "Project"):
43+
self._project = project
44+
45+
def domains(self) -> PaginatedCollection:
46+
"""Get all domains associated with this project.
47+
48+
Returns:
49+
PaginatedCollection of ProjectDomain instances
50+
"""
51+
return ProjectDomain.get_by_project_id(
52+
client=self.client, project_id=self.project.uid
53+
)
54+
55+
def add_domain(self, project_domain: ProjectDomain):
56+
return ProjectDomain.connect_project_to_domains(
57+
client=self.client,
58+
project_id=self.project.uid,
59+
domain_ids=[project_domain.uid],
60+
)
61+
62+
def get_project_rate(self) -> Optional["ProjectRateV2"]:
63+
return ProjectRateV2.get_by_project_id(
64+
client=self.client, project_id=self.project.uid
65+
)
66+
67+
def set_project_rate(self, project_rate_input):
68+
return ProjectRateV2.set_project_rate(
69+
client=self.client,
70+
project_id=self.project.uid,
71+
project_rate_input=project_rate_input,
72+
)
73+
74+
75+
class AlignerrWorkspace:
76+
def __init__(self, client: "Client"):
77+
self.client = client
78+
79+
def project_builder(self):
80+
from labelbox.alignerr.alignerr_project_builder import (
81+
AlignerrProjectBuilder,
82+
)
83+
84+
return AlignerrProjectBuilder(self.client)
85+
86+
def project_prototype(self):
87+
from labelbox.alignerr.alignerr_project_factory import (
88+
AlignerrProjectFactory,
89+
)
90+
91+
return AlignerrProjectFactory(self.client)
Lines changed: 178 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,178 @@
1+
import datetime
2+
from typing import TYPE_CHECKING, Optional
3+
import logging
4+
5+
from labelbox.alignerr.schema.project_rate import BillingMode
6+
from labelbox.alignerr.schema.project_rate import ProjectRateInput
7+
from labelbox.alignerr.schema.project_rate import ProjectRateV2
8+
from labelbox.alignerr.schema.project_domain import ProjectDomain
9+
from labelbox.schema.media_type import MediaType
10+
11+
logger = logging.getLogger(__name__)
12+
13+
14+
if TYPE_CHECKING:
15+
from labelbox import Client
16+
from labelbox.alignerr.alignerr_project import AlignerrProject, AlignerrRole
17+
18+
19+
class AlignerrProjectBuilder:
20+
def __init__(self, client: "Client"):
21+
self.client = client
22+
self._alignerr_rates: dict[str, ProjectRateInput] = {}
23+
self._customer_rate: ProjectRateInput = None
24+
self._domains: list[ProjectDomain] = []
25+
self.role_name_to_id = self._get_role_name_to_id()
26+
27+
def set_name(self, name: str):
28+
self.project_name = name
29+
return self
30+
31+
def set_media_type(self, media_type: "MediaType"):
32+
self.project_media_type = media_type
33+
return self
34+
35+
def set_alignerr_role_rate(
36+
self,
37+
*,
38+
role_name: "AlignerrRole",
39+
rate: float,
40+
billing_mode: BillingMode,
41+
effective_since: datetime.datetime,
42+
effective_until: Optional[datetime.datetime] = None,
43+
):
44+
if role_name.value not in self.role_name_to_id:
45+
raise ValueError(f"Role {role_name.value} not found")
46+
47+
role_id = self.role_name_to_id[role_name.value]
48+
role_name = role_name.value
49+
50+
# Convert datetime objects to ISO format strings
51+
effective_since_str = (
52+
effective_since.isoformat()
53+
if isinstance(effective_since, datetime.datetime)
54+
else effective_since
55+
)
56+
effective_until_str = (
57+
effective_until.isoformat()
58+
if isinstance(effective_until, datetime.datetime)
59+
else effective_until
60+
)
61+
62+
self._alignerr_rates[role_name] = ProjectRateInput(
63+
rateForId=role_id,
64+
isBillRate=False,
65+
billingMode=billing_mode,
66+
rate=rate,
67+
effectiveSince=effective_since_str,
68+
effectiveUntil=effective_until_str,
69+
)
70+
return self
71+
72+
def set_customer_rate(
73+
self,
74+
*,
75+
rate: float,
76+
billing_mode: BillingMode,
77+
effective_since: datetime.datetime,
78+
effective_until: Optional[datetime.datetime] = None,
79+
):
80+
# Convert datetime objects to ISO format strings
81+
effective_since_str = (
82+
effective_since.isoformat()
83+
if isinstance(effective_since, datetime.datetime)
84+
else effective_since
85+
)
86+
effective_until_str = (
87+
effective_until.isoformat()
88+
if isinstance(effective_until, datetime.datetime)
89+
else effective_until
90+
)
91+
92+
self._customer_rate = ProjectRateInput(
93+
rateForId="", # Empty string for customer rate
94+
isBillRate=True,
95+
billingMode=billing_mode,
96+
rate=rate,
97+
effectiveSince=effective_since_str,
98+
effectiveUntil=effective_until_str,
99+
)
100+
return self
101+
102+
def set_domains(self, domains: list[str]):
103+
for domain in domains:
104+
project_domain_page = ProjectDomain.search(
105+
self.client, search_by_name=domain
106+
)
107+
domain_result = project_domain_page.get_one()
108+
if domain_result is None:
109+
raise ValueError(f"Domain {domain} not found")
110+
self._domains.append(domain_result)
111+
return self
112+
113+
def create(self, skip_validation: bool = False):
114+
if not skip_validation:
115+
self._validate()
116+
logger.info("Creating project")
117+
118+
project_data = {
119+
"name": self.project_name,
120+
"media_type": self.project_media_type,
121+
}
122+
labelbox_project = self.client.create_project(**project_data)
123+
124+
# Import here to avoid circular imports
125+
from labelbox.alignerr.alignerr_project import AlignerrProject
126+
127+
alignerr_project = AlignerrProject(
128+
self.client, labelbox_project, _internal=True
129+
)
130+
131+
self._create_rates(alignerr_project)
132+
self._create_domains(alignerr_project)
133+
134+
return alignerr_project
135+
136+
def _create_rates(self, alignerr_project: "AlignerrProject"):
137+
for alignerr_role, project_rate in self._alignerr_rates.items():
138+
logger.info(f"Setting project rate for {alignerr_role}")
139+
alignerr_project.set_project_rate(project_rate)
140+
141+
def _create_domains(self, alignerr_project: "AlignerrProject"):
142+
if self._domains:
143+
logger.info(
144+
f"Setting domains: {[domain.name for domain in self._domains]}"
145+
)
146+
domain_ids = [domain.uid for domain in self._domains]
147+
ProjectDomain.connect_project_to_domains(
148+
client=self.client,
149+
project_id=alignerr_project.project.uid,
150+
domain_ids=domain_ids,
151+
)
152+
153+
def _validate_alignerr_rates(self):
154+
# Import here to avoid circular imports
155+
from labelbox.alignerr.alignerr_project import AlignerrRole
156+
157+
required_role_rates = set(
158+
[AlignerrRole.Labeler.value, AlignerrRole.Reviewer.value]
159+
)
160+
161+
for role_name in self._alignerr_rates.keys():
162+
required_role_rates.remove(role_name)
163+
if len(required_role_rates) > 0:
164+
raise ValueError(
165+
f"Required role rates are not set: {required_role_rates}"
166+
)
167+
168+
def _validate_customer_rate(self):
169+
if self._customer_rate is None:
170+
raise ValueError("Customer rate is not set")
171+
172+
def _validate(self):
173+
self._validate_alignerr_rates()
174+
self._validate_customer_rate()
175+
176+
def _get_role_name_to_id(self) -> dict[str, str]:
177+
roles = self.client.get_roles()
178+
return {role.name: role.uid for role in roles.values()}

0 commit comments

Comments
 (0)