diff --git a/src/viser/_gui_api.py b/src/viser/_gui_api.py index a1b6ae4b3..f3c68a4ae 100644 --- a/src/viser/_gui_api.py +++ b/src/viser/_gui_api.py @@ -129,27 +129,34 @@ def _handle_gui_updates( if handle is None: return - handle_state = handle._impl + prop_name = message.prop_name + prop_value = message.prop_value + del message - # Do some type casting. This is necessary when we expect floats but the - # Javascript side gives us integers. - if handle_state.typ is tuple: - assert len(message.value) == len(handle_state.value) - value = tuple( - type(handle_state.value[i])(message.value[i]) - for i in range(len(message.value)) - ) - else: - value = handle_state.typ(message.value) + handle_state = handle._impl + assert hasattr(handle_state, prop_name) + current_value = getattr(handle_state, prop_name) + + has_changed = current_value != prop_value + + if prop_name == "value": + # Do some type casting. This is necessary when we expect floats but the + # Javascript side gives us integers. + if handle_state.typ is tuple: + assert len(prop_value) == len(handle_state.value) + prop_value = tuple( + type(handle_state.value[i])(prop_value[i]) + for i in range(len(prop_value)) + ) + else: + prop_value = handle_state.typ(prop_value) # Only call update when value has actually changed. - if not handle_state.is_button and value == handle_state.value: + if not handle_state.is_button and not has_changed: return # Update state. - with self._get_api()._atomic_lock: - handle_state.value = value - handle_state.update_timestamp = time.time() + setattr(handle_state, prop_name, prop_value) # Trigger callbacks. for cb in handle_state.update_cb: @@ -165,8 +172,9 @@ def _handle_gui_updates( assert False cb(GuiEvent(client, client_id, handle)) + if handle_state.sync_cb is not None: - handle_state.sync_cb(client_id, value) + handle_state.sync_cb(client_id, prop_name, prop_value) def _get_container_id(self) -> str: """Get container ID associated with the current thread.""" @@ -196,6 +204,7 @@ def add_gui_folder( label: str, order: Optional[float] = None, expand_by_default: bool = True, + visible: bool = True, ) -> GuiFolderHandle: """Add a folder, and return a handle that can be used to populate it. @@ -204,6 +213,7 @@ def add_gui_folder( order: Optional ordering, smallest values will be displayed first. expand_by_default: Open the folder by default. Set to False to collapse it by default. + visible: Whether the component is visible. Returns: A handle that can be used as a context to populate the folder. @@ -217,6 +227,7 @@ def add_gui_folder( label=label, container_id=self._get_container_id(), expand_by_default=expand_by_default, + visible=visible, ) ) return GuiFolderHandle( @@ -258,24 +269,37 @@ def add_gui_modal( def add_gui_tab_group( self, order: Optional[float] = None, + visible: bool = True, ) -> GuiTabGroupHandle: """Add a tab group. Args: order: Optional ordering, smallest values will be displayed first. + visible: Whether the component is visible. Returns: A handle that can be used as a context to populate the tab group. """ tab_group_id = _make_unique_id() order = _apply_default_order(order) + + self._get_api()._queue( + _messages.GuiAddTabGroupMessage( + order=order, + id=tab_group_id, + container_id=self._get_container_id(), + tab_labels=(), + visible=visible, + tab_icons_base64=(), + tab_container_ids=(), + ) + ) return GuiTabGroupHandle( _tab_group_id=tab_group_id, _labels=[], _icons_base64=[], _tabs=[], _gui_api=self, - _container_id=self._get_container_id(), _order=order, ) @@ -284,6 +308,7 @@ def add_gui_markdown( content: str, image_root: Optional[Path] = None, order: Optional[float] = None, + visible: bool = True, ) -> GuiMarkdownHandle: """Add markdown to the GUI. @@ -291,6 +316,7 @@ def add_gui_markdown( content: Markdown content to display. image_root: Optional root directory to resolve relative image paths. order: Optional ordering, smallest values will be displayed first. + visible: Whether the component is visible. Returns: A handle that can be used to interact with the GUI element. @@ -298,7 +324,7 @@ def add_gui_markdown( handle = GuiMarkdownHandle( _gui_api=self, _id=_make_unique_id(), - _visible=True, + _visible=visible, _container_id=self._get_container_id(), _order=_apply_default_order(order), _image_root=image_root, @@ -357,19 +383,19 @@ def add_gui_button( order = _apply_default_order(order) return GuiButtonHandle( self._create_gui_input( - initial_value=False, + value=False, message=_messages.GuiAddButtonMessage( order=order, id=id, label=label, container_id=self._get_container_id(), hint=hint, - initial_value=False, + value=False, color=color, icon_base64=None if icon is None else base64_from_icon(icon), + disabled=disabled, + visible=visible, ), - disabled=disabled, - visible=visible, is_button=True, )._impl ) @@ -425,23 +451,23 @@ def add_gui_button_group( Returns: A handle that can be used to interact with the GUI element. """ - initial_value = options[0] + value = options[0] id = _make_unique_id() order = _apply_default_order(order) return GuiButtonGroupHandle( self._create_gui_input( - initial_value, + value, message=_messages.GuiAddButtonGroupMessage( order=order, id=id, label=label, container_id=self._get_container_id(), hint=hint, - initial_value=initial_value, + value=value, options=tuple(options), + disabled=disabled, + visible=visible, ), - disabled=disabled, - visible=visible, )._impl, ) @@ -467,21 +493,22 @@ def add_gui_checkbox( Returns: A handle that can be used to interact with the GUI element. """ - assert isinstance(initial_value, bool) + value = initial_value + assert isinstance(value, bool) id = _make_unique_id() order = _apply_default_order(order) return self._create_gui_input( - initial_value, + value, message=_messages.GuiAddCheckboxMessage( order=order, id=id, label=label, container_id=self._get_container_id(), hint=hint, - initial_value=initial_value, + value=value, + disabled=disabled, + visible=visible, ), - disabled=disabled, - visible=visible, ) def add_gui_text( @@ -506,21 +533,22 @@ def add_gui_text( Returns: A handle that can be used to interact with the GUI element. """ - assert isinstance(initial_value, str) + value = initial_value + assert isinstance(value, str) id = _make_unique_id() order = _apply_default_order(order) return self._create_gui_input( - initial_value, + value, message=_messages.GuiAddTextMessage( order=order, id=id, label=label, container_id=self._get_container_id(), hint=hint, - initial_value=initial_value, + value=value, + disabled=disabled, + visible=visible, ), - disabled=disabled, - visible=visible, ) def add_gui_number( @@ -552,8 +580,9 @@ def add_gui_number( Returns: A handle that can be used to interact with the GUI element. """ + value = initial_value - assert isinstance(initial_value, (int, float)) + assert isinstance(value, (int, float)) if step is None: # It's ok that `step` is always a float, even if the value is an integer, @@ -561,7 +590,7 @@ def add_gui_number( step = float( # type: ignore onp.min( [ - _compute_step(initial_value), + _compute_step(value), _compute_step(min), _compute_step(max), ] @@ -573,21 +602,21 @@ def add_gui_number( id = _make_unique_id() order = _apply_default_order(order) return self._create_gui_input( - initial_value=initial_value, + value, message=_messages.GuiAddNumberMessage( order=order, id=id, label=label, container_id=self._get_container_id(), hint=hint, - initial_value=initial_value, + value=value, min=min, max=max, precision=_compute_precision_digits(step), step=step, + disabled=disabled, + visible=visible, ), - disabled=disabled, - visible=visible, is_button=False, ) @@ -619,7 +648,8 @@ def add_gui_vector2( Returns: A handle that can be used to interact with the GUI element. """ - initial_value = cast_vector(initial_value, 2) + value = initial_value + value = cast_vector(value, 2) min = cast_vector(min, 2) if min is not None else None max = cast_vector(max, 2) if max is not None else None id = _make_unique_id() @@ -627,7 +657,7 @@ def add_gui_vector2( if step is None: possible_steps: List[float] = [] - possible_steps.extend([_compute_step(x) for x in initial_value]) + possible_steps.extend([_compute_step(x) for x in value]) if min is not None: possible_steps.extend([_compute_step(x) for x in min]) if max is not None: @@ -635,21 +665,21 @@ def add_gui_vector2( step = float(onp.min(possible_steps)) return self._create_gui_input( - initial_value, + value, message=_messages.GuiAddVector2Message( order=order, id=id, label=label, container_id=self._get_container_id(), hint=hint, - initial_value=initial_value, + value=value, min=min, max=max, step=step, precision=_compute_precision_digits(step), + disabled=disabled, + visible=visible, ), - disabled=disabled, - visible=visible, ) def add_gui_vector3( @@ -680,7 +710,8 @@ def add_gui_vector3( Returns: A handle that can be used to interact with the GUI element. """ - initial_value = cast_vector(initial_value, 2) + value = initial_value + value = cast_vector(value, 2) min = cast_vector(min, 3) if min is not None else None max = cast_vector(max, 3) if max is not None else None id = _make_unique_id() @@ -688,7 +719,7 @@ def add_gui_vector3( if step is None: possible_steps: List[float] = [] - possible_steps.extend([_compute_step(x) for x in initial_value]) + possible_steps.extend([_compute_step(x) for x in value]) if min is not None: possible_steps.extend([_compute_step(x) for x in min]) if max is not None: @@ -696,21 +727,21 @@ def add_gui_vector3( step = float(onp.min(possible_steps)) return self._create_gui_input( - initial_value, + value, message=_messages.GuiAddVector3Message( order=order, id=id, label=label, container_id=self._get_container_id(), hint=hint, - initial_value=initial_value, + value=value, min=min, max=max, step=step, precision=_compute_precision_digits(step), + disabled=disabled, + visible=visible, ), - disabled=disabled, - visible=visible, ) # See add_gui_dropdown for notes on overloads. @@ -764,24 +795,25 @@ def add_gui_dropdown( Returns: A handle that can be used to interact with the GUI element. """ - if initial_value is None: - initial_value = options[0] + value = initial_value + if value is None: + value = options[0] id = _make_unique_id() order = _apply_default_order(order) return GuiDropdownHandle( self._create_gui_input( - initial_value, + value, message=_messages.GuiAddDropdownMessage( order=order, id=id, label=label, container_id=self._get_container_id(), hint=hint, - initial_value=initial_value, + value=value, options=tuple(options), + disabled=disabled, + visible=visible, ), - disabled=disabled, - visible=visible, )._impl, _impl_options=tuple(options), ) @@ -818,29 +850,30 @@ def add_gui_slider( Returns: A handle that can be used to interact with the GUI element. """ + value: IntOrFloat = initial_value assert max >= min if step > max - min: step = max - min - assert max >= initial_value >= min + assert max >= value >= min # GUI callbacks cast incoming values to match the type of the initial value. If # the min, max, or step is a float, we should cast to a float. # # This should also match what the IntOrFloat TypeVar resolves to. - if type(initial_value) is int and ( + if type(value) is int and ( type(min) is float or type(max) is float or type(step) is float ): - initial_value = float(initial_value) # type: ignore + value = float(value) # type: ignore # TODO: as of 6/5/2023, this assert will break something in nerfstudio. (at # least LERF) # - # assert type(min) == type(max) == type(step) == type(initial_value) + # assert type(min) == type(max) == type(step) == type(value) id = _make_unique_id() order = _apply_default_order(order) return self._create_gui_input( - initial_value=initial_value, + value, message=_messages.GuiAddSliderMessage( order=order, id=id, @@ -850,8 +883,10 @@ def add_gui_slider( min=min, max=max, step=step, - initial_value=initial_value, + value=value, precision=_compute_precision_digits(step), + visible=visible, + disabled=disabled, marks=tuple( {"value": float(x[0]), "label": x[1]} if isinstance(x, tuple) @@ -861,8 +896,6 @@ def add_gui_slider( if marks is not None else None, ), - disabled=disabled, - visible=visible, is_button=False, ) @@ -922,7 +955,7 @@ def add_gui_multi_slider( id = _make_unique_id() order = _apply_default_order(order) return self._create_gui_input( - initial_value=initial_value, + value=initial_value, message=_messages.GuiAddMultiSliderMessage( order=order, id=id, @@ -933,7 +966,9 @@ def add_gui_multi_slider( min_range=min_range, max=max, step=step, - initial_value=initial_value, + value=initial_value, + visible=visible, + disabled=disabled, fixed_endpoints=fixed_endpoints, precision=_compute_precision_digits(step), marks=tuple( @@ -945,8 +980,6 @@ def add_gui_multi_slider( if marks is not None else None, ), - disabled=disabled, - visible=visible, is_button=False, ) @@ -973,20 +1006,21 @@ def add_gui_rgb( A handle that can be used to interact with the GUI element. """ + value = initial_value id = _make_unique_id() order = _apply_default_order(order) return self._create_gui_input( - initial_value, + value, message=_messages.GuiAddRgbMessage( order=order, id=id, label=label, container_id=self._get_container_id(), hint=hint, - initial_value=initial_value, + value=value, + disabled=disabled, + visible=visible, ), - disabled=disabled, - visible=visible, ) def add_gui_rgba( @@ -1011,28 +1045,27 @@ def add_gui_rgba( Returns: A handle that can be used to interact with the GUI element. """ + value = initial_value id = _make_unique_id() order = _apply_default_order(order) return self._create_gui_input( - initial_value, + value, message=_messages.GuiAddRgbaMessage( order=order, id=id, label=label, container_id=self._get_container_id(), hint=hint, - initial_value=initial_value, + value=value, + disabled=disabled, + visible=visible, ), - disabled=disabled, - visible=visible, ) def _create_gui_input( self, - initial_value: T, + value: T, message: _messages._GuiAddInputBase, - disabled: bool, - visible: bool, is_button: bool = False, ) -> GuiInputHandle[T]: """Private helper for adding a simple GUI element.""" @@ -1043,19 +1076,20 @@ def _create_gui_input( # Construct handle. handle_state = _GuiHandleState( label=message.label, - typ=type(initial_value), + message_type=type(message), + typ=type(value), gui_api=self, - value=initial_value, + value=value, + initial_value=value, update_timestamp=time.time(), container_id=self._get_container_id(), update_cb=[], is_button=is_button, sync_cb=None, - disabled=False, - visible=True, + disabled=message.disabled, + visible=message.visible, id=message.id, order=message.order, - initial_value=initial_value, hint=message.hint, ) @@ -1063,8 +1097,12 @@ def _create_gui_input( # This will be a no-op for client handles. if not is_button: - def sync_other_clients(client_id: ClientId, value: Any) -> None: - message = _messages.GuiSetValueMessage(id=handle_state.id, value=value) + def sync_other_clients( + client_id: ClientId, prop_name: str, prop_value: Any + ) -> None: + message = _messages.GuiUpdateMessage( + handle_state.id, prop_name, prop_value + ) message.excluded_self_client = client_id self._get_api()._queue(message) @@ -1072,10 +1110,4 @@ def sync_other_clients(client_id: ClientId, value: Any) -> None: handle = GuiInputHandle(handle_state) - # Set the disabled/visible fields. These will queue messages under-the-hood. - if disabled: - handle.disabled = disabled - if not visible: - handle.visible = visible - return handle diff --git a/src/viser/_gui_handles.py b/src/viser/_gui_handles.py index 110430795..1e22ea94c 100644 --- a/src/viser/_gui_handles.py +++ b/src/viser/_gui_handles.py @@ -10,6 +10,7 @@ from pathlib import Path from typing import ( TYPE_CHECKING, + Any, Callable, Dict, Generic, @@ -29,14 +30,10 @@ from ._icons_enum import IconName from ._message_api import _encode_image_base64 from ._messages import ( - GuiAddDropdownMessage, - GuiAddMarkdownMessage, - GuiAddTabGroupMessage, GuiCloseModalMessage, GuiRemoveMessage, - GuiSetDisabledMessage, - GuiSetValueMessage, - GuiSetVisibleMessage, + GuiUpdateMessage, + Message, ) from .infra import ClientId @@ -84,7 +81,7 @@ class _GuiHandleState(Generic[T]): is_button: bool """Indicates a button element, which requires special handling.""" - sync_cb: Optional[Callable[[ClientId, T], None]] + sync_cb: Optional[Callable[[ClientId, str, Any], None]] """Callback for synchronizing inputs across clients.""" disabled: bool @@ -95,6 +92,8 @@ class _GuiHandleState(Generic[T]): initial_value: T hint: Optional[str] + message_type: Type[Message] + @dataclasses.dataclass class _GuiInputHandle(Generic[T]): @@ -137,7 +136,7 @@ def value(self, value: T | onp.ndarray) -> None: # Send to client, except for buttons. if not self._impl.is_button: self._impl.gui_api._get_api()._queue( - GuiSetValueMessage(self._impl.id, value) # type: ignore + GuiUpdateMessage(self._impl.id, "value", value) ) # Set internal state. We automatically convert numpy arrays to the expected @@ -176,7 +175,7 @@ def disabled(self, disabled: bool) -> None: return self._impl.gui_api._get_api()._queue( - GuiSetDisabledMessage(self._impl.id, disabled=disabled) + GuiUpdateMessage(self._impl.id, "disabled", disabled) ) self._impl.disabled = disabled @@ -192,7 +191,7 @@ def visible(self, visible: bool) -> None: return self._impl.gui_api._get_api()._queue( - GuiSetVisibleMessage(self._impl.id, visible=visible) + GuiUpdateMessage(self._impl.id, "visible", visible) ) self._impl.visible = visible @@ -312,15 +311,7 @@ def options(self, options: Iterable[StringType]) -> None: self._impl.initial_value = self._impl_options[0] self._impl.gui_api._get_api()._queue( - GuiAddDropdownMessage( - order=self._impl.order, - id=self._impl.id, - label=self._impl.label, - container_id=self._impl.container_id, - hint=self._impl.hint, - initial_value=self._impl.initial_value, - options=self._impl_options, - ) + GuiUpdateMessage(self._impl.id, "options", self._impl_options) ) if self.value not in self._impl_options: @@ -334,7 +325,6 @@ class GuiTabGroupHandle: _icons_base64: List[Optional[str]] _tabs: List[GuiTabHandle] _gui_api: GuiApi - _container_id: str # Parent. _order: float @property @@ -364,15 +354,20 @@ def remove(self) -> None: self._gui_api._get_api()._queue(GuiRemoveMessage(self._tab_group_id)) def _sync_with_client(self) -> None: - """Send a message that syncs tab state with the client.""" + """Send messages for syncing tab state with the client.""" + self._gui_api._get_api()._queue( + GuiUpdateMessage(self._tab_group_id, "tab_labels", tuple(self._labels)) + ) + self._gui_api._get_api()._queue( + GuiUpdateMessage( + self._tab_group_id, "tab_icons_base64", tuple(self._icons_base64) + ) + ) self._gui_api._get_api()._queue( - GuiAddTabGroupMessage( - order=self.order, - id=self._tab_group_id, - container_id=self._container_id, - tab_labels=tuple(self._labels), - tab_icons_base64=tuple(self._icons_base64), - tab_container_ids=tuple(tab._id for tab in self._tabs), + GuiUpdateMessage( + self._tab_group_id, + "tab_container_ids", + tuple(tab._id for tab in self._tabs), ) ) @@ -561,11 +556,10 @@ def content(self) -> str: def content(self, content: str) -> None: self._content = content self._gui_api._get_api()._queue( - GuiAddMarkdownMessage( - order=self._order, - id=self._id, - markdown=_parse_markdown(content, self._image_root), - container_id=self._container_id, + GuiUpdateMessage( + self._id, + "markdown", + _parse_markdown(content, self._image_root), ) ) @@ -585,7 +579,7 @@ def visible(self, visible: bool) -> None: if visible == self.visible: return - self._gui_api._get_api()._queue(GuiSetVisibleMessage(self._id, visible=visible)) + self._gui_api._get_api()._queue(GuiUpdateMessage(self._id, "visible", visible)) self._visible = visible def __post_init__(self) -> None: diff --git a/src/viser/_messages.py b/src/viser/_messages.py index 26d567b68..5b9566a87 100644 --- a/src/viser/_messages.py +++ b/src/viser/_messages.py @@ -4,7 +4,7 @@ from __future__ import annotations import dataclasses -from typing import Any, Optional, Tuple, Union +from typing import Any, Callable, ClassVar, Optional, Tuple, Type, TypeVar, Union import numpy as onp import numpy.typing as onpt @@ -16,6 +16,8 @@ class Message(infra.Message): + _tags: ClassVar[Tuple[str, ...]] = tuple() + @override def redundancy_key(self) -> str: """Returns a unique key for this message, used for detecting redundant @@ -39,6 +41,19 @@ def redundancy_key(self) -> str: return "_".join(parts) +T = TypeVar("T", bound=Type[Message]) + + +def tag_class(tag: str) -> Callable[[T], T]: + """Decorator for tagging a class with a `type` field.""" + + def wrapper(cls: T) -> T: + cls._tags = (cls._tags or ()) + (tag,) + return cls + + return wrapper + + @dataclasses.dataclass class ViewerCameraMessage(Message): """Message for a posed viewer camera. @@ -348,6 +363,7 @@ class ResetSceneMessage(Message): """Reset scene.""" +@tag_class("GuiAddComponentMessage") @dataclasses.dataclass class GuiAddFolderMessage(Message): order: float @@ -355,16 +371,20 @@ class GuiAddFolderMessage(Message): label: str container_id: str expand_by_default: bool + visible: bool +@tag_class("GuiAddComponentMessage") @dataclasses.dataclass class GuiAddMarkdownMessage(Message): order: float id: str markdown: str container_id: str + visible: bool +@tag_class("GuiAddComponentMessage") @dataclasses.dataclass class GuiAddTabGroupMessage(Message): order: float @@ -373,6 +393,7 @@ class GuiAddTabGroupMessage(Message): tab_labels: Tuple[str, ...] tab_icons_base64: Tuple[Union[str, None], ...] tab_container_ids: Tuple[str, ...] + visible: bool @dataclasses.dataclass @@ -384,7 +405,9 @@ class _GuiAddInputBase(Message): label: str container_id: str hint: Optional[str] - initial_value: Any + value: Any + visible: bool + disabled: bool @dataclasses.dataclass @@ -399,11 +422,12 @@ class GuiCloseModalMessage(Message): id: str +@tag_class("GuiAddComponentMessage") @dataclasses.dataclass class GuiAddButtonMessage(_GuiAddInputBase): - # All GUI elements currently need an `initial_value` field. + # All GUI elements currently need an `value` field. # This makes our job on the frontend easier. - initial_value: bool + value: bool color: Optional[ Literal[ "dark", @@ -425,84 +449,94 @@ class GuiAddButtonMessage(_GuiAddInputBase): icon_base64: Optional[str] +@tag_class("GuiAddComponentMessage") @dataclasses.dataclass class GuiAddSliderMessage(_GuiAddInputBase): min: float max: float step: Optional[float] - initial_value: float + value: float precision: int marks: Optional[Tuple[GuiSliderMark, ...]] = None +@tag_class("GuiAddComponentMessage") @dataclasses.dataclass class GuiAddMultiSliderMessage(_GuiAddInputBase): min: float max: float step: Optional[float] min_range: Optional[float] - initial_value: Tuple[float, ...] precision: int fixed_endpoints: bool = False marks: Optional[Tuple[GuiSliderMark, ...]] = None +@tag_class("GuiAddComponentMessage") @dataclasses.dataclass class GuiAddNumberMessage(_GuiAddInputBase): - initial_value: float + value: float precision: int step: float min: Optional[float] max: Optional[float] +@tag_class("GuiAddComponentMessage") @dataclasses.dataclass class GuiAddRgbMessage(_GuiAddInputBase): - initial_value: Tuple[int, int, int] + value: Tuple[int, int, int] +@tag_class("GuiAddComponentMessage") @dataclasses.dataclass class GuiAddRgbaMessage(_GuiAddInputBase): - initial_value: Tuple[int, int, int, int] + value: Tuple[int, int, int, int] +@tag_class("GuiAddComponentMessage") @dataclasses.dataclass class GuiAddCheckboxMessage(_GuiAddInputBase): - initial_value: bool + value: bool +@tag_class("GuiAddComponentMessage") @dataclasses.dataclass class GuiAddVector2Message(_GuiAddInputBase): - initial_value: Tuple[float, float] + value: Tuple[float, float] min: Optional[Tuple[float, float]] max: Optional[Tuple[float, float]] step: float precision: int +@tag_class("GuiAddComponentMessage") @dataclasses.dataclass class GuiAddVector3Message(_GuiAddInputBase): - initial_value: Tuple[float, float, float] + value: Tuple[float, float, float] min: Optional[Tuple[float, float, float]] max: Optional[Tuple[float, float, float]] step: float precision: int +@tag_class("GuiAddComponentMessage") @dataclasses.dataclass class GuiAddTextMessage(_GuiAddInputBase): - initial_value: str + value: str +@tag_class("GuiAddComponentMessage") @dataclasses.dataclass class GuiAddDropdownMessage(_GuiAddInputBase): - initial_value: str + value: str options: Tuple[str, ...] +@tag_class("GuiAddComponentMessage") @dataclasses.dataclass class GuiAddButtonGroupMessage(_GuiAddInputBase): - initial_value: str + value: str options: Tuple[str, ...] @@ -515,34 +549,15 @@ class GuiRemoveMessage(Message): @dataclasses.dataclass class GuiUpdateMessage(Message): - """Sent client->server when a GUI input is changed.""" + """Sent client<->server when any property of a GUI component is changed.""" id: str - value: Any - - -@dataclasses.dataclass -class GuiSetVisibleMessage(Message): - """Sent client->server when a GUI input is changed.""" + prop_name: str + prop_value: Any - id: str - visible: bool - - -@dataclasses.dataclass -class GuiSetDisabledMessage(Message): - """Sent client->server when a GUI input is changed.""" - - id: str - disabled: bool - - -@dataclasses.dataclass -class GuiSetValueMessage(Message): - """Sent server->client to set the value of a particular input.""" - - id: str - value: Any + @override + def redundancy_key(self) -> str: + return type(self).__name__ + "-" + self.id + "-" + self.prop_name @dataclasses.dataclass diff --git a/src/viser/client/src/ControlPanel/Generated.tsx b/src/viser/client/src/ControlPanel/Generated.tsx index fa2c592bb..5bb2bb02c 100644 --- a/src/viser/client/src/ControlPanel/Generated.tsx +++ b/src/viser/client/src/ControlPanel/Generated.tsx @@ -1,51 +1,60 @@ -import { - GuiAddFolderMessage, - GuiAddTabGroupMessage, -} from "../WebsocketMessages"; -import { ViewerContext, ViewerContextContents } from "../App"; +import { ViewerContext } from "../App"; import { makeThrottledMessageSender } from "../WebsocketFunctions"; -import { computeRelativeLuminance } from "./GuiState"; -import { - Collapse, - Image, - Paper, - Tabs, - TabsValue, - useMantineTheme, -} from "@mantine/core"; +import { GuiComponentContext } from "./GuiComponentContext"; -import { - Box, - Button, - Checkbox, - ColorInput, - Flex, - NumberInput, - Select, - Slider, - Text, - TextInput, - Tooltip, -} from "@mantine/core"; -import { MultiSlider } from "./MultiSlider"; +import { Box } from "@mantine/core"; import React from "react"; -import Markdown from "../Markdown"; -import { ErrorBoundary } from "react-error-boundary"; -import { useDisclosure } from "@mantine/hooks"; -import { IconChevronDown, IconChevronUp } from "@tabler/icons-react"; +import ButtonComponent from "../components/Button"; +import SliderComponent from "../components/Slider"; +import NumberInputComponent from "../components/NumberInput"; +import TextInputComponent from "../components/TextInput"; +import CheckboxComponent from "../components/Checkbox"; +import Vector2Component from "../components/Vector2"; +import Vector3Component from "../components/Vector3"; +import DropdownComponent from "../components/Dropdown"; +import RgbComponent from "../components/Rgb"; +import RgbaComponent from "../components/Rgba"; +import ButtonGroupComponent from "../components/ButtonGroup"; +import MarkdownComponent from "../components/Markdown"; +import TabGroupComponent from "../components/TabGroup"; +import FolderComponent from "../components/Folder"; +import MultiSliderComponent from "../components/MultiSlider"; /** Root of generated inputs. */ export default function GeneratedGuiContainer({ - // We need to take viewer as input in drei's elements, where contexts break. containerId, - viewer, - folderDepth, }: { containerId: string; - viewer?: ViewerContextContents; - folderDepth?: number; }) { - if (viewer === undefined) viewer = React.useContext(ViewerContext)!; + const viewer = React.useContext(ViewerContext)!; + const updateGuiProps = viewer.useGui((state) => state.updateGuiProps); + const messageSender = makeThrottledMessageSender(viewer.websocketRef, 50); + + function setValue(id: string, value: any) { + updateGuiProps(id, "value", value); + messageSender({ + type: "GuiUpdateMessage", + id: id, + prop_name: "value", + prop_value: value, + }); + } + return ( + + + + ); +} + +function GuiContainer({ containerId }: { containerId: string }) { + const viewer = React.useContext(ViewerContext)!; const guiIdSet = viewer.useGui((state) => state.guiIdSetFromContainerId[containerId]) ?? {}; @@ -55,829 +64,61 @@ export default function GeneratedGuiContainer({ const guiOrderFromId = viewer!.useGui((state) => state.guiOrderFromId); if (guiIdSet === undefined) return null; - const guiIdOrderPairArray = guiIdArray.map((id) => ({ + let guiIdOrderPairArray = guiIdArray.map((id) => ({ id: id, order: guiOrderFromId[id], })); + guiIdOrderPairArray = guiIdOrderPairArray.sort((a, b) => a.order - b.order); const out = ( - - {guiIdOrderPairArray - .sort((a, b) => a.order - b.order) - .map((pair, index) => ( - - ))} + + {guiIdOrderPairArray.map((pair) => ( + + ))} ); return out; } /** A single generated GUI element. */ -function GeneratedInput({ - id, - viewer, - folderDepth, - last, -}: { - id: string; - viewer?: ViewerContextContents; - folderDepth: number; - last: boolean; -}) { - // Handle GUI input types. - if (viewer === undefined) viewer = React.useContext(ViewerContext)!; - const conf = viewer.useGui((state) => state.guiConfigFromId[id]); - - // Handle nested containers. - if (conf.type == "GuiAddFolderMessage") - return ( - - - - ); - if (conf.type == "GuiAddTabGroupMessage") - return ; - if (conf.type == "GuiAddMarkdownMessage") { - let { visible } = - viewer.useGui((state) => state.guiAttributeFromId[conf.id]) || {}; - visible = visible ?? true; - if (!visible) return <>; - return ( - - Markdown Failed to Render} - > - {conf.markdown} - - - ); - } - - const messageSender = makeThrottledMessageSender(viewer.websocketRef, 50); - function updateValue(value: any) { - setGuiValue(conf.id, value); - messageSender({ type: "GuiUpdateMessage", id: conf.id, value: value }); - } - - const setGuiValue = viewer.useGui((state) => state.setGuiValue); - const value = - viewer.useGui((state) => state.guiValueFromId[conf.id]) ?? - conf.initial_value; - const theme = useMantineTheme(); - - let { visible, disabled } = - viewer.useGui((state) => state.guiAttributeFromId[conf.id]) || {}; - - visible = visible ?? true; - disabled = disabled ?? false; - - if (!visible) return <>; - - let inputColor = - computeRelativeLuminance(theme.fn.primaryColor()) > 50.0 - ? theme.colors.gray[9] - : theme.white; - - let labeled = true; - let input = null; - let containerProps = {}; +function GeneratedInput(props: { guiId: string }) { + const viewer = React.useContext(ViewerContext)!; + const conf = viewer.useGui((state) => state.guiConfigFromId[props.guiId]); switch (conf.type) { + case "GuiAddFolderMessage": + return ; + case "GuiAddTabGroupMessage": + return ; + case "GuiAddMarkdownMessage": + return ; case "GuiAddButtonMessage": - labeled = false; - if (conf.color !== null) { - inputColor = - computeRelativeLuminance( - theme.colors[conf.color][theme.fn.primaryShade()], - ) > 50.0 - ? theme.colors.gray[9] - : theme.white; - } - - input = ( - - ); - break; + return ; case "GuiAddSliderMessage": - input = ( - - ({ - thumb: { - background: theme.fn.primaryColor(), - borderRadius: "0.1rem", - height: "0.75rem", - width: "0.625rem", - }, - trackContainer: { - zIndex: 3, - position: "relative", - }, - markLabel: { - transform: "translate(-50%, 0.03rem)", - fontSize: "0.6rem", - textAlign: "center", - }, - marksContainer: { - left: "0.2rem", - right: "0.2rem", - }, - markWrapper: { - position: "absolute", - top: `0.03rem`, - ...(conf.marks === null - ? /* Shift the mark labels so they don't spill too far out the left/right when we only have min and max marks. */ - { - ":first-child": { - "div:nth-child(2)": { - transform: "translate(-0.2rem, 0.03rem)", - }, - }, - ":last-child": { - "div:nth-child(2)": { - transform: "translate(-90%, 0.03rem)", - }, - }, - } - : {}), - }, - mark: { - border: "0px solid transparent", - background: - theme.colorScheme === "dark" - ? theme.colors.dark[4] - : theme.colors.gray[2], - width: "0.42rem", - height: "0.42rem", - transform: `translateX(-50%)`, - }, - markFilled: { - background: disabled - ? theme.colorScheme === "dark" - ? theme.colors.dark[3] - : theme.colors.gray[4] - : theme.fn.primaryColor(), - }, - })} - pt="0.2em" - showLabelOnHover={false} - min={conf.min} - max={conf.max} - step={conf.step ?? undefined} - precision={conf.precision} - value={value} - onChange={updateValue} - marks={ - conf.marks === null - ? [ - { - value: conf.min, - label: `${parseInt(conf.min.toFixed(6))}`, - }, - { - value: conf.max, - label: `${parseInt(conf.max.toFixed(6))}`, - }, - ] - : conf.marks - } - disabled={disabled} - /> - { - // Ignore empty values. - newValue !== "" && updateValue(newValue); - }} - size="xs" - min={conf.min} - max={conf.max} - hideControls - step={conf.step ?? undefined} - precision={conf.precision} - sx={{ width: "3rem" }} - styles={{ - input: { - padding: "0.375em", - letterSpacing: "-0.5px", - minHeight: "1.875em", - height: "1.875em", - }, - }} - ml="xs" - /> - - ); - break; + return ; case "GuiAddMultiSliderMessage": - input = ( - ({ - thumb: { - background: theme.fn.primaryColor(), - borderRadius: "0.1rem", - height: "0.75rem", - width: "0.625rem", - }, - trackContainer: { - zIndex: 3, - position: "relative", - }, - markLabel: { - transform: "translate(-50%, 0.03rem)", - fontSize: "0.6rem", - textAlign: "center", - }, - marksContainer: { - left: "0.2rem", - right: "0.2rem", - }, - markWrapper: { - position: "absolute", - top: `0.03rem`, - ...(conf.marks === null - ? /* Shift the mark labels so they don't spill too far out the left/right when we only have min and max marks. */ - { - ":first-child": { - "div:nth-child(2)": { - transform: "translate(-0.2rem, 0.03rem)", - }, - }, - ":last-child": { - "div:nth-child(2)": { - transform: "translate(-90%, 0.03rem)", - }, - }, - } - : {}), - }, - mark: { - border: "0px solid transparent", - background: - theme.colorScheme === "dark" - ? theme.colors.dark[4] - : theme.colors.gray[2], - width: "0.42rem", - height: "0.42rem", - transform: `translateX(-50%)`, - }, - markFilled: { - background: disabled - ? theme.colorScheme === "dark" - ? theme.colors.dark[3] - : theme.colors.gray[4] - : theme.fn.primaryColor(), - }, - })} - pt="0.2em" - showLabelOnHover={false} - min={conf.min} - max={conf.max} - step={conf.step ?? undefined} - precision={conf.precision} - value={value} - onChange={updateValue} - marks={ - conf.marks === null - ? [ - { - value: conf.min, - label: `${parseInt(conf.min.toFixed(6))}`, - }, - { - value: conf.max, - label: `${parseInt(conf.max.toFixed(6))}`, - }, - ] - : conf.marks - } - disabled={disabled} - fixedEndpoints={conf.fixed_endpoints} - minRange={conf.min_range || undefined} - /> - ); - - if (conf.marks?.some((x) => x.label) || conf.marks === null) - containerProps = { ...containerProps, mb: "xs" }; - break; + return ; case "GuiAddNumberMessage": - input = ( - { - // Ignore empty values. - newValue !== "" && updateValue(newValue); - }} - styles={{ - input: { - minHeight: "1.625rem", - height: "1.625rem", - }, - }} - disabled={disabled} - stepHoldDelay={500} - stepHoldInterval={(t) => Math.max(1000 / t ** 2, 25)} - /> - ); - break; + return ; case "GuiAddTextMessage": - input = ( - { - updateValue(value.target.value); - }} - styles={{ - input: { - minHeight: "1.625rem", - height: "1.625rem", - padding: "0 0.5em", - }, - }} - disabled={disabled} - /> - ); - break; + return ; case "GuiAddCheckboxMessage": - input = ( - { - updateValue(value.target.checked); - }} - disabled={disabled} - styles={{ - icon: { - color: inputColor + " !important", - }, - }} - /> - ); - break; + return ; case "GuiAddVector2Message": - input = ( - - ); - break; + return ; case "GuiAddVector3Message": - input = ( - - ); - break; + return ; case "GuiAddDropdownMessage": - input = ( - setValue(id, value)} + disabled={disabled} + searchable + maxDropdownHeight={400} + size="xs" + styles={{ + input: { + padding: "0.5em", + letterSpacing: "-0.5px", + minHeight: "1.625rem", + height: "1.625rem", + }, + }} + // zIndex of dropdown should be >modal zIndex. + // On edge cases: it seems like existing dropdowns are always closed when a new modal is opened. + zIndex={1000} + withinPortal + /> + + ); +} diff --git a/src/viser/client/src/components/Folder.tsx b/src/viser/client/src/components/Folder.tsx new file mode 100644 index 000000000..81fecd1b3 --- /dev/null +++ b/src/viser/client/src/components/Folder.tsx @@ -0,0 +1,78 @@ +import * as React from "react"; +import { useDisclosure } from "@mantine/hooks"; +import { GuiAddFolderMessage } from "../WebsocketMessages"; +import { IconChevronDown, IconChevronUp } from "@tabler/icons-react"; +import { Box, Collapse, Paper } from "@mantine/core"; +import { GuiComponentContext } from "../ControlPanel/GuiComponentContext"; +import { ViewerContext } from "../App"; + +export default function FolderComponent({ + id, + label, + visible, + expand_by_default, +}: GuiAddFolderMessage) { + const viewer = React.useContext(ViewerContext)!; + const [opened, { toggle }] = useDisclosure(expand_by_default); + const guiIdSet = viewer.useGui((state) => state.guiIdSetFromContainerId[id]); + const guiContext = React.useContext(GuiComponentContext)!; + const isEmpty = guiIdSet === undefined || Object.keys(guiIdSet).length === 0; + + const ToggleIcon = opened ? IconChevronUp : IconChevronDown; + if (!visible) return <>; + return ( + + + {label} + + + + + + + + + + + + ); +} diff --git a/src/viser/client/src/components/Markdown.tsx b/src/viser/client/src/components/Markdown.tsx new file mode 100644 index 000000000..752937c16 --- /dev/null +++ b/src/viser/client/src/components/Markdown.tsx @@ -0,0 +1,20 @@ +import { Box, Text } from "@mantine/core"; +import Markdown from "../Markdown"; +import { ErrorBoundary } from "react-error-boundary"; +import { GuiAddMarkdownMessage } from "../WebsocketMessages"; + +export default function MarkdownComponent({ + visible, + markdown, +}: GuiAddMarkdownMessage) { + if (!visible) return <>; + return ( + + Markdown Failed to Render} + > + {markdown} + + + ); +} diff --git a/src/viser/client/src/components/MultiSlider.tsx b/src/viser/client/src/components/MultiSlider.tsx new file mode 100644 index 000000000..726a2d337 --- /dev/null +++ b/src/viser/client/src/components/MultiSlider.tsx @@ -0,0 +1,117 @@ +import React from "react"; +import { GuiAddMultiSliderMessage } from "../WebsocketMessages"; +import { Box } from "@mantine/core"; +import { GuiComponentContext } from "../ControlPanel/GuiComponentContext"; +import { ViserInputComponent } from "./common"; +import { MultiSlider } from "./MultiSliderPrimitive/MultiSlider"; + +export default function MultiSliderComponent({ + id, + label, + hint, + visible, + disabled, + value, + ...otherProps +}: GuiAddMultiSliderMessage) { + const { setValue } = React.useContext(GuiComponentContext)!; + if (!visible) return <>; + const updateValue = (value: number[]) => setValue(id, value); + const { min, max, precision, step, marks, fixed_endpoints, min_range } = + otherProps; + const input = ( + + ({ + thumb: { + background: theme.fn.primaryColor(), + borderRadius: "0.1rem", + height: "0.75rem", + width: "0.625rem", + }, + trackContainer: { + zIndex: 3, + position: "relative", + }, + markLabel: { + transform: "translate(-50%, 0.03rem)", + fontSize: "0.6rem", + textAlign: "center", + }, + marksContainer: { + left: "0.2rem", + right: "0.2rem", + }, + markWrapper: { + position: "absolute", + top: `0.03rem`, + ...(marks === null + ? /* Shift the mark labels so they don't spill too far out the left/right when we only have min and max marks. */ + { + ":first-child": { + "div:nth-child(2)": { + transform: "translate(-0.2rem, 0.03rem)", + }, + }, + ":last-child": { + "div:nth-child(2)": { + transform: "translate(-90%, 0.03rem)", + }, + }, + } + : {}), + }, + mark: { + border: "0px solid transparent", + background: + theme.colorScheme === "dark" + ? theme.colors.dark[4] + : theme.colors.gray[2], + width: "0.42rem", + height: "0.42rem", + transform: `translateX(-50%)`, + }, + markFilled: { + background: disabled + ? theme.colorScheme === "dark" + ? theme.colors.dark[3] + : theme.colors.gray[4] + : theme.fn.primaryColor(), + }, + })} + pt="0.2em" + showLabelOnHover={false} + min={min} + max={max} + step={step ?? undefined} + precision={precision} + value={value} + onChange={updateValue} + marks={ + marks === null + ? [ + { + value: min, + label: `${parseInt(min.toFixed(6))}`, + }, + { + value: max, + label: `${parseInt(max.toFixed(6))}`, + }, + ] + : marks + } + disabled={disabled} + fixedEndpoints={fixed_endpoints} + minRange={min_range || undefined} + /> + + ); + + return ( + {input} + ); +} diff --git a/src/viser/client/src/ControlPanel/MultiSlider.styles.tsx b/src/viser/client/src/components/MultiSliderPrimitive/MultiSlider.styles.tsx similarity index 100% rename from src/viser/client/src/ControlPanel/MultiSlider.styles.tsx rename to src/viser/client/src/components/MultiSliderPrimitive/MultiSlider.styles.tsx diff --git a/src/viser/client/src/ControlPanel/MultiSlider.tsx b/src/viser/client/src/components/MultiSliderPrimitive/MultiSlider.tsx similarity index 100% rename from src/viser/client/src/ControlPanel/MultiSlider.tsx rename to src/viser/client/src/components/MultiSliderPrimitive/MultiSlider.tsx diff --git a/src/viser/client/src/components/NumberInput.tsx b/src/viser/client/src/components/NumberInput.tsx new file mode 100644 index 000000000..8063dade1 --- /dev/null +++ b/src/viser/client/src/components/NumberInput.tsx @@ -0,0 +1,45 @@ +import * as React from "react"; +import { GuiComponentContext } from "../ControlPanel/GuiComponentContext"; +import { GuiAddNumberMessage } from "../WebsocketMessages"; +import { ViserInputComponent } from "./common"; +import { NumberInput } from "@mantine/core"; + +export default function NumberInputComponent({ + visible, + id, + label, + hint, + value, + disabled, + ...otherProps +}: GuiAddNumberMessage) { + const { setValue } = React.useContext(GuiComponentContext)!; + const { precision, min, max, step } = otherProps; + if (!visible) return <>; + return ( + + { + // Ignore empty values. + newValue !== "" && setValue(id, newValue); + }} + styles={{ + input: { + minHeight: "1.625rem", + height: "1.625rem", + }, + }} + disabled={disabled} + stepHoldDelay={500} + stepHoldInterval={(t) => Math.max(1000 / t ** 2, 25)} + /> + + ); +} diff --git a/src/viser/client/src/components/Rgb.tsx b/src/viser/client/src/components/Rgb.tsx new file mode 100644 index 000000000..fdc6a67c9 --- /dev/null +++ b/src/viser/client/src/components/Rgb.tsx @@ -0,0 +1,37 @@ +import * as React from "react"; +import { ColorInput } from "@mantine/core"; +import { GuiComponentContext } from "../ControlPanel/GuiComponentContext"; +import { rgbToHex, hexToRgb } from "./utils"; +import { ViserInputComponent } from "./common"; +import { GuiAddRgbMessage } from "../WebsocketMessages"; + +export default function RgbComponent({ + id, + label, + hint, + value, + disabled, + visible, +}: GuiAddRgbMessage) { + const { setValue } = React.useContext(GuiComponentContext)!; + if (!visible) return <>; + return ( + + setValue(id, hexToRgb(v))} + format="hex" + // zIndex of dropdown should be >modal zIndex. + // On edge cases: it seems like existing dropdowns are always closed when a new modal is opened. + dropdownZIndex={1000} + withinPortal + styles={{ + input: { height: "1.625rem", minHeight: "1.625rem" }, + icon: { transform: "scale(0.8)" }, + }} + /> + + ); +} diff --git a/src/viser/client/src/components/Rgba.tsx b/src/viser/client/src/components/Rgba.tsx new file mode 100644 index 000000000..b96491b90 --- /dev/null +++ b/src/viser/client/src/components/Rgba.tsx @@ -0,0 +1,36 @@ +import * as React from "react"; +import { ColorInput } from "@mantine/core"; +import { GuiComponentContext } from "../ControlPanel/GuiComponentContext"; +import { rgbaToHex, hexToRgba } from "./utils"; +import { ViserInputComponent } from "./common"; +import { GuiAddRgbaMessage } from "../WebsocketMessages"; + +export default function RgbaComponent({ + id, + label, + hint, + value, + disabled, + visible, +}: GuiAddRgbaMessage) { + const { setValue } = React.useContext(GuiComponentContext)!; + if (!visible) return <>; + return ( + + setValue(id, hexToRgba(v))} + format="hexa" + // zIndex of dropdown should be >modal zIndex. + // On edge cases: it seems like existing dropdowns are always closed when a new modal is opened. + dropdownZIndex={1000} + withinPortal + styles={{ + input: { height: "1.625rem", minHeight: "1.625rem" }, + }} + /> + + ); +} diff --git a/src/viser/client/src/components/Slider.tsx b/src/viser/client/src/components/Slider.tsx new file mode 100644 index 000000000..4e7ca0a81 --- /dev/null +++ b/src/viser/client/src/components/Slider.tsx @@ -0,0 +1,138 @@ +import React from "react"; +import { GuiAddSliderMessage } from "../WebsocketMessages"; +import { Slider, Flex, NumberInput } from "@mantine/core"; +import { GuiComponentContext } from "../ControlPanel/GuiComponentContext"; +import { ViserInputComponent } from "./common"; + +export default function SliderComponent({ + id, + label, + hint, + visible, + disabled, + value, + ...otherProps +}: GuiAddSliderMessage) { + const { setValue } = React.useContext(GuiComponentContext)!; + if (!visible) return <>; + const updateValue = (value: number) => setValue(id, value); + const { min, max, precision, step, marks } = otherProps; + const input = ( + + ({ + thumb: { + background: theme.fn.primaryColor(), + borderRadius: "0.1rem", + height: "0.75rem", + width: "0.625rem", + }, + trackContainer: { + zIndex: 3, + position: "relative", + }, + markLabel: { + transform: "translate(-50%, 0.03rem)", + fontSize: "0.6rem", + textAlign: "center", + }, + marksContainer: { + left: "0.2rem", + right: "0.2rem", + }, + markWrapper: { + position: "absolute", + top: `0.03rem`, + ...(marks === null + ? /* Shift the mark labels so they don't spill too far out the left/right when we only have min and max marks. */ + { + ":first-child": { + "div:nth-child(2)": { + transform: "translate(-0.2rem, 0.03rem)", + }, + }, + ":last-child": { + "div:nth-child(2)": { + transform: "translate(-90%, 0.03rem)", + }, + }, + } + : {}), + }, + mark: { + border: "0px solid transparent", + background: + theme.colorScheme === "dark" + ? theme.colors.dark[4] + : theme.colors.gray[2], + width: "0.42rem", + height: "0.42rem", + transform: `translateX(-50%)`, + }, + markFilled: { + background: disabled + ? theme.colorScheme === "dark" + ? theme.colors.dark[3] + : theme.colors.gray[4] + : theme.fn.primaryColor(), + }, + })} + pt="0.2em" + pb="0.4em" + showLabelOnHover={false} + min={min} + max={max} + step={step ?? undefined} + precision={precision} + value={value} + onChange={updateValue} + marks={ + marks === null + ? [ + { + value: min, + label: `${parseInt(min.toFixed(6))}`, + }, + { + value: max, + label: `${parseInt(max.toFixed(6))}`, + }, + ] + : marks + } + disabled={disabled} + /> + { + // Ignore empty values. + newValue !== "" && updateValue(newValue); + }} + size="xs" + min={min} + max={max} + hideControls + step={step ?? undefined} + precision={precision} + sx={{ width: "3rem" }} + styles={{ + input: { + padding: "0.375em", + letterSpacing: "-0.5px", + minHeight: "1.875em", + height: "1.875em", + }, + }} + ml="xs" + /> + + ); + + return ( + {input} + ); +} diff --git a/src/viser/client/src/components/TabGroup.tsx b/src/viser/client/src/components/TabGroup.tsx new file mode 100644 index 000000000..cd9aeeb57 --- /dev/null +++ b/src/viser/client/src/components/TabGroup.tsx @@ -0,0 +1,55 @@ +import * as React from "react"; +import { GuiAddTabGroupMessage } from "../WebsocketMessages"; +import { Tabs, TabsValue } from "@mantine/core"; +import { Image } from "@mantine/core"; +import { GuiComponentContext } from "../ControlPanel/GuiComponentContext"; + +export default function TabGroupComponent({ + tab_labels, + tab_icons_base64, + tab_container_ids, + visible, +}: GuiAddTabGroupMessage) { + const [tabState, setTabState] = React.useState("0"); + const icons = tab_icons_base64; + const { GuiContainer } = React.useContext(GuiComponentContext)!; + if (!visible) return <>; + return ( + + + {tab_labels.map((label, index) => ( + ({ + filter: + theme.colorScheme == "dark" ? "invert(1)" : undefined, + })} + src={"data:image/svg+xml;base64," + icons[index]} + /> + ) + } + > + {label} + + ))} + + {tab_container_ids.map((containerId, index) => ( + + + + ))} + + ); +} diff --git a/src/viser/client/src/components/TextInput.tsx b/src/viser/client/src/components/TextInput.tsx new file mode 100644 index 000000000..1d4002b02 --- /dev/null +++ b/src/viser/client/src/components/TextInput.tsx @@ -0,0 +1,30 @@ +import * as React from "react"; +import { TextInput } from "@mantine/core"; +import { ViserInputComponent } from "./common"; +import { GuiAddTextMessage } from "../WebsocketMessages"; +import { GuiComponentContext } from "../ControlPanel/GuiComponentContext"; + +export default function TextInputComponent(props: GuiAddTextMessage) { + const { id, hint, label, value, disabled, visible } = props; + const { setValue } = React.useContext(GuiComponentContext)!; + if (!visible) return <>; + return ( + + { + setValue(id, value.target.value); + }} + styles={{ + input: { + minHeight: "1.625rem", + height: "1.625rem", + padding: "0 0.5em", + }, + }} + disabled={disabled} + /> + + ); +} diff --git a/src/viser/client/src/components/Vector2.tsx b/src/viser/client/src/components/Vector2.tsx new file mode 100644 index 000000000..089d0dc4d --- /dev/null +++ b/src/viser/client/src/components/Vector2.tsx @@ -0,0 +1,33 @@ +import * as React from "react"; +import { GuiComponentContext } from "../ControlPanel/GuiComponentContext"; +import { GuiAddVector2Message } from "../WebsocketMessages"; +import { VectorInput, ViserInputComponent } from "./common"; + +export default function Vector2Component({ + id, + hint, + label, + visible, + disabled, + value, + ...otherProps +}: GuiAddVector2Message) { + const { min, max, step, precision } = otherProps; + const { setValue } = React.useContext(GuiComponentContext)!; + if (!visible) return <>; + return ( + + setValue(id, value)} + min={min} + max={max} + step={step} + precision={precision} + disabled={disabled} + /> + + ); +} diff --git a/src/viser/client/src/components/Vector3.tsx b/src/viser/client/src/components/Vector3.tsx new file mode 100644 index 000000000..4b20219f8 --- /dev/null +++ b/src/viser/client/src/components/Vector3.tsx @@ -0,0 +1,33 @@ +import * as React from "react"; +import { GuiComponentContext } from "../ControlPanel/GuiComponentContext"; +import { GuiAddVector3Message } from "../WebsocketMessages"; +import { VectorInput, ViserInputComponent } from "./common"; + +export default function Vector3Component({ + id, + hint, + label, + visible, + disabled, + value, + ...otherProps +}: GuiAddVector3Message) { + const { min, max, step, precision } = otherProps; + const { setValue } = React.useContext(GuiComponentContext)!; + if (!visible) return <>; + return ( + + setValue(id, value)} + min={min} + max={max} + step={step} + precision={precision} + disabled={disabled} + /> + + ); +} diff --git a/src/viser/client/src/components/common.tsx b/src/viser/client/src/components/common.tsx new file mode 100644 index 000000000..0e2f41fbd --- /dev/null +++ b/src/viser/client/src/components/common.tsx @@ -0,0 +1,149 @@ +import * as React from "react"; +import { Box, Flex, Text, NumberInput, Tooltip } from "@mantine/core"; +import { GuiComponentContext } from "../ControlPanel/GuiComponentContext"; + +export function ViserInputComponent({ + id, + label, + hint, + children, +}: { + id: string; + children: React.ReactNode; + label?: string; + hint?: string | null; +}) { + const { folderDepth } = React.useContext(GuiComponentContext)!; + if (hint !== undefined && hint !== null) { + children = // We need to add for inputs that we can't assign refs to. + ( + + {children} + + ); + } + + if (label !== undefined) + children = ( + + ); + + return ( + + {children} + + ); +} + +/** GUI input with a label horizontally placed to the left of it. */ +function LabeledInput(props: { + id: string; + label: string; + input: React.ReactNode; + folderDepth: number; +}) { + return ( + + + + + + + {props.input} + + ); +} + +export function VectorInput( + props: + | { + id: string; + n: 2; + value: [number, number]; + min: [number, number] | null; + max: [number, number] | null; + step: number; + precision: number; + onChange: (value: number[]) => void; + disabled: boolean; + } + | { + id: string; + n: 3; + value: [number, number, number]; + min: [number, number, number] | null; + max: [number, number, number] | null; + step: number; + precision: number; + onChange: (value: number[]) => void; + disabled: boolean; + }, +) { + return ( + + {[...Array(props.n).keys()].map((i) => ( + { + const updated = [...props.value]; + updated[i] = v === "" ? 0.0 : v; + props.onChange(updated); + }} + size="xs" + styles={{ + root: { flexGrow: 1, width: 0 }, + input: { + paddingLeft: "0.5em", + paddingRight: "1.75em", + textAlign: "right", + minHeight: "1.875em", + height: "1.875em", + }, + rightSection: { width: "1.2em" }, + control: { + width: "1.1em", + }, + }} + precision={props.precision} + step={props.step} + min={props.min === null ? undefined : props.min[i]} + max={props.max === null ? undefined : props.max[i]} + stepHoldDelay={500} + stepHoldInterval={(t) => Math.max(1000 / t ** 2, 25)} + disabled={props.disabled} + /> + ))} + + ); +} diff --git a/src/viser/client/src/components/utils.tsx b/src/viser/client/src/components/utils.tsx new file mode 100644 index 000000000..49271026b --- /dev/null +++ b/src/viser/client/src/components/utils.tsx @@ -0,0 +1,37 @@ +// Color conversion helpers. + +export function rgbToHex([r, g, b]: [number, number, number]): string { + const hexR = r.toString(16).padStart(2, "0"); + const hexG = g.toString(16).padStart(2, "0"); + const hexB = b.toString(16).padStart(2, "0"); + return `#${hexR}${hexG}${hexB}`; +} + +export function hexToRgb(hexColor: string): [number, number, number] { + const hex = hexColor.slice(1); // Remove the # in #ffffff. + const r = parseInt(hex.substring(0, 2), 16); + const g = parseInt(hex.substring(2, 4), 16); + const b = parseInt(hex.substring(4, 6), 16); + return [r, g, b]; +} +export function rgbaToHex([r, g, b, a]: [ + number, + number, + number, + number, +]): string { + const hexR = r.toString(16).padStart(2, "0"); + const hexG = g.toString(16).padStart(2, "0"); + const hexB = b.toString(16).padStart(2, "0"); + const hexA = a.toString(16).padStart(2, "0"); + return `#${hexR}${hexG}${hexB}${hexA}`; +} + +export function hexToRgba(hexColor: string): [number, number, number, number] { + const hex = hexColor.slice(1); // Remove the # in #ffffff. + const r = parseInt(hex.substring(0, 2), 16); + const g = parseInt(hex.substring(2, 4), 16); + const b = parseInt(hex.substring(4, 6), 16); + const a = parseInt(hex.substring(6, 8), 16); + return [r, g, b, a]; +} diff --git a/src/viser/infra/_messages.py b/src/viser/infra/_messages.py index ce9d2af20..e4c68a535 100644 --- a/src/viser/infra/_messages.py +++ b/src/viser/infra/_messages.py @@ -18,6 +18,32 @@ ClientId = Any +def _prepare_for_deserialization(value: Any, annotation: Type) -> Any: + # If annotated as a float but we got an integer, cast to float. These + # are both `number` in Javascript. + if annotation is float: + return float(value) + elif annotation is int: + return int(value) + elif get_origin(annotation) is tuple: + out = [] + args = get_args(annotation) + if len(args) >= 2 and args[1] == ...: + args = (args[0],) * len(value) + elif len(value) != len(args): + warnings.warn(f"[viser] {value} does not match annotation {annotation}") + return value + + for i, v in enumerate(value): + out.append( + # Hack to be OK with wrong type annotations. + # https://github.com/nerfstudio-project/nerfstudio/pull/1805 + _prepare_for_deserialization(v, args[i]) if i < len(args) else v + ) + return tuple(out) + return value + + def _prepare_for_serialization(value: Any, annotation: Type) -> Any: """Prepare any special types for serialization.""" @@ -38,19 +64,19 @@ def _prepare_for_serialization(value: Any, annotation: Type) -> Any: out = [] args = get_args(annotation) - if len(args) >= 1: - if len(args) >= 2 and args[1] == ...: - args = (args[0],) * len(value) - elif len(value) != len(args): - warnings.warn(f"[viser] {value} does not match annotation {annotation}") - - for i, v in enumerate(value): - out.append( - # Hack to be OK with wrong type annotations. - # https://github.com/nerfstudio-project/nerfstudio/pull/1805 - _prepare_for_serialization(v, args[i]) if i < len(args) else v - ) - return tuple(out) + if len(args) >= 2 and args[1] == ...: + args = (args[0],) * len(value) + elif len(value) != len(args): + warnings.warn(f"[viser] {value} does not match annotation {annotation}") + return value + + for i, v in enumerate(value): + out.append( + # Hack to be OK with wrong type annotations. + # https://github.com/nerfstudio-project/nerfstudio/pull/1805 + _prepare_for_serialization(v, args[i]) if i < len(args) else v + ) + return tuple(out) # For arrays, we serialize underlying data directly. The client is responsible for # reading using the correct dtype. @@ -77,13 +103,25 @@ class Message(abc.ABC): def as_serializable_dict(self) -> Dict[str, Any]: """Convert a Python Message object into bytes.""" - hints = get_type_hints_cached(type(self)) + message_type = type(self) + hints = get_type_hints_cached(message_type) out = { k: _prepare_for_serialization(v, hints[k]) for k, v in vars(self).items() } - out["type"] = type(self).__name__ + out["type"] = message_type.__name__ return out + @classmethod + def _from_serializable_dict(cls, mapping: Dict[str, Any]) -> Dict[str, Any]: + """Convert a dict message back into a Python Message object.""" + + hints = get_type_hints_cached(cls) + + mapping = { + k: _prepare_for_deserialization(v, hints[k]) for k, v in mapping.items() + } + return mapping + @classmethod def deserialize(cls, message: bytes) -> Message: """Convert bytes into a Python Message object.""" @@ -95,23 +133,8 @@ def deserialize(cls, message: bytes) -> Message: k: tuple(v) if isinstance(v, list) else v for k, v in mapping.items() } message_type = cls._subclass_from_type_string()[cast(str, mapping.pop("type"))] - - # If annotated as a float but we got an integer, cast to float. These - # are both `number` in Javascript. - def coerce_floats(value: Any, annotation: Type[Any]) -> Any: - if annotation is float: - return float(value) - elif get_origin(annotation) is tuple: - return tuple( - coerce_floats(value[i], typ) - for i, typ in enumerate(get_args(annotation)) - ) - else: - return value - - type_hints = get_type_hints(message_type) - mapping = {k: coerce_floats(v, type_hints[k]) for k, v in mapping.items()} - return message_type(**mapping) # type: ignore + message_kwargs = message_type._from_serializable_dict(mapping) + return message_type(**message_kwargs) @classmethod @functools.lru_cache(maxsize=100) diff --git a/src/viser/infra/_typescript_interface_gen.py b/src/viser/infra/_typescript_interface_gen.py index e08a9229b..3204d9a33 100644 --- a/src/viser/infra/_typescript_interface_gen.py +++ b/src/viser/infra/_typescript_interface_gen.py @@ -1,4 +1,5 @@ import dataclasses +from collections import defaultdict from typing import Any, ClassVar, Type, Union, cast, get_type_hints import numpy as onp @@ -81,6 +82,7 @@ def generate_typescript_interfaces(message_cls: Type[Message]) -> str: """Generate TypeScript definitions for all subclasses of a base message class.""" out_lines = [] message_types = message_cls.get_subclasses() + tag_map = defaultdict(list) # Generate interfaces for each specific message. for cls in message_types: @@ -93,6 +95,15 @@ def generate_typescript_interfaces(message_cls: Type[Message]) -> str: out_lines.append(" * (automatically generated)") out_lines.append(" */") + for tag in getattr(cls, "_tags", []): + tag_map[tag].append(cls.__name__) + + get_ts_type = getattr(cls, "_get_ts_type", None) + if get_ts_type is not None: + assert callable(get_ts_type) + out_lines.append(get_ts_type()) + continue + out_lines.append(f"export interface {cls.__name__} " + "{") out_lines.append(f' type: "{cls.__name__}";') field_names = set([f.name for f in dataclasses.fields(cls)]) # type: ignore @@ -113,6 +124,13 @@ def generate_typescript_interfaces(message_cls: Type[Message]) -> str: out_lines.append(f" | {cls.__name__}") out_lines[-1] = out_lines[-1] + ";" + # Generate union type over all tags. + for tag, cls_names in tag_map.items(): + out_lines.append(f"export type {tag} = ") + for cls_name in cls_names: + out_lines.append(f" | {cls_name}") + out_lines[-1] = out_lines[-1] + ";" + interfaces = "\n".join(out_lines) + "\n" # Add header and return.