Skip to content

Commit 583d25c

Browse files
James Robinson and Kafonek authored
Port to pydantic 2 (#192)
* Progress porting to pydantic 2
* isort
* Tests over BaseRTU
* Test over File.construct_url
* Test over User.construct_auth_type.
* Test over Space.construct_url()
* Test project.construct_url
* Notebook model unit tests.
* isort
* Remove unused import
* Another round of dependency updates
* Test suite passing without warnings
* lint
* Remove bump-pydantic, no longer needed
* Better changelog
* Merge changelog
* Move to new changelog section
* clean up DeltaCallback class (not Pydantic anymore)

---------

Co-authored-by: Kafonek <matt.kafonek@noteable.io>
1 parent 48e0591 commit 583d25c

30 files changed

+750
-493
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
77
For pre-1.0 releases, see [0.0.35 Changelog](https://github.com/noteable-io/origami/blob/0.0.35/CHANGELOG.md)
88

99
## [Unreleased]
10+
### Changed
11+
- Upgraded pydantic from 1.X to 2.4.2.
1012

1113
### [1.1.5] - 2023-11-06
1214
### Fixed

origami/clients/api.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ async def user_info(self) -> User:
104104
endpoint = "/users/me"
105105
resp = await self.client.get(endpoint)
106106
resp.raise_for_status()
107-
user = User.parse_obj(resp.json())
107+
user = User.model_validate(resp.json())
108108
self.add_tags_and_contextvars(user_id=str(user.id))
109109
return user
110110

@@ -191,7 +191,7 @@ async def create_space(self, name: str, description: Optional[str] = None) -> Sp
191191
endpoint = "/spaces"
192192
resp = await self.client.post(endpoint, json={"name": name, "description": description})
193193
resp.raise_for_status()
194-
space = Space.parse_obj(resp.json())
194+
space = Space.model_validate(resp.json())
195195
self.add_tags_and_contextvars(space_id=str(space.id))
196196
return space
197197

@@ -200,7 +200,7 @@ async def get_space(self, space_id: uuid.UUID) -> Space:
200200
endpoint = f"/spaces/{space_id}"
201201
resp = await self.client.get(endpoint)
202202
resp.raise_for_status()
203-
space = Space.parse_obj(resp.json())
203+
space = Space.model_validate(resp.json())
204204
return space
205205

206206
async def delete_space(self, space_id: uuid.UUID) -> None:
@@ -216,7 +216,7 @@ async def list_space_projects(self, space_id: uuid.UUID) -> List[Project]:
216216
endpoint = f"/spaces/{space_id}/projects"
217217
resp = await self.client.get(endpoint)
218218
resp.raise_for_status()
219-
projects = [Project.parse_obj(project) for project in resp.json()]
219+
projects = [Project.model_validate(project) for project in resp.json()]
220220
return projects
221221

222222
async def share_space(
@@ -267,7 +267,7 @@ async def create_project(
267267
},
268268
)
269269
resp.raise_for_status()
270-
project = Project.parse_obj(resp.json())
270+
project = Project.model_validate(resp.json())
271271
self.add_tags_and_contextvars(project_id=str(project.id))
272272
return project
273273

@@ -276,15 +276,15 @@ async def get_project(self, project_id: uuid.UUID) -> Project:
276276
endpoint = f"/projects/{project_id}"
277277
resp = await self.client.get(endpoint)
278278
resp.raise_for_status()
279-
project = Project.parse_obj(resp.json())
279+
project = Project.model_validate(resp.json())
280280
return project
281281

282282
async def delete_project(self, project_id: uuid.UUID) -> Project:
283283
self.add_tags_and_contextvars(project_id=str(project_id))
284284
endpoint = f"/projects/{project_id}"
285285
resp = await self.client.delete(endpoint)
286286
resp.raise_for_status()
287-
project = Project.parse_obj(resp.json())
287+
project = Project.model_validate(resp.json())
288288
return project
289289

290290
async def share_project(
@@ -323,7 +323,7 @@ async def list_project_files(self, project_id: uuid.UUID) -> List[File]:
323323
endpoint = f"/projects/{project_id}/files"
324324
resp = await self.client.get(endpoint)
325325
resp.raise_for_status()
326-
files = [File.parse_obj(file) for file in resp.json()]
326+
files = [File.model_validate(file) for file in resp.json()]
327327
return files
328328

329329
# Files are flat files (like text, csv, etc) or Notebooks.
@@ -355,7 +355,7 @@ async def _multi_step_file_create(
355355
upload_url = js["presigned_upload_url_info"]["parts"][0]["upload_url"]
356356
upload_id = js["presigned_upload_url_info"]["upload_id"]
357357
upload_key = js["presigned_upload_url_info"]["key"]
358-
file = File.parse_obj(js)
358+
file = File.model_validate(js)
359359

360360
# (2) Upload to pre-signed url
361361
# TODO: remove this hack if/when we get containers in Skaffold to be able to translate
@@ -393,7 +393,7 @@ async def create_notebook(
393393
self.add_tags_and_contextvars(project_id=str(project_id))
394394
if notebook is None:
395395
notebook = Notebook()
396-
content = notebook.json().encode()
396+
content = notebook.model_dump_json().encode()
397397
file = await self._multi_step_file_create(project_id, path, "notebook", content)
398398
self.add_tags_and_contextvars(file_id=str(file.id))
399399
logger.info("Created new notebook", extra={"file_id": str(file.id)})
@@ -405,7 +405,7 @@ async def get_file(self, file_id: uuid.UUID) -> File:
405405
endpoint = f"/v1/files/{file_id}"
406406
resp = await self.client.get(endpoint)
407407
resp.raise_for_status()
408-
file = File.parse_obj(resp.json())
408+
file = File.model_validate(resp.json())
409409
return file
410410

411411
async def get_file_content(self, file_id: uuid.UUID) -> bytes:
@@ -433,15 +433,15 @@ async def get_file_versions(self, file_id: uuid.UUID) -> List[FileVersion]:
433433
endpoint = f"/files/{file_id}/versions"
434434
resp = await self.client.get(endpoint)
435435
resp.raise_for_status()
436-
versions = [FileVersion.parse_obj(version) for version in resp.json()]
436+
versions = [FileVersion.model_validate(version) for version in resp.json()]
437437
return versions
438438

439439
async def delete_file(self, file_id: uuid.UUID) -> File:
440440
self.add_tags_and_contextvars(file_id=str(file_id))
441441
endpoint = f"/v1/files/{file_id}"
442442
resp = await self.client.delete(endpoint)
443443
resp.raise_for_status()
444-
file = File.parse_obj(resp.json())
444+
file = File.model_validate(resp.json())
445445
return file
446446

447447
async def share_file(
@@ -497,7 +497,7 @@ async def launch_kernel(
497497
}
498498
resp = await self.client.post(endpoint, json=data)
499499
resp.raise_for_status()
500-
kernel_session = KernelSession.parse_obj(resp.json())
500+
kernel_session = KernelSession.model_validate(resp.json())
501501
self.add_tags_and_contextvars(kernel_session_id=str(kernel_session.id))
502502
logger.info(
503503
"Launched new kernel",
@@ -517,7 +517,7 @@ async def get_output_collection(
517517
endpoint = f"/outputs/collection/{output_collection_id}"
518518
resp = await self.client.get(endpoint)
519519
resp.raise_for_status()
520-
return KernelOutputCollection.parse_obj(resp.json())
520+
return KernelOutputCollection.model_validate(resp.json())
521521

522522
async def connect_realtime(self, file: Union[File, uuid.UUID, str]) -> "RTUClient": # noqa
523523
"""

origami/clients/rtu.py

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515

1616
import httpx
1717
import orjson
18-
from pydantic import BaseModel, parse_obj_as
1918
from sending.backends.websocket import WebsocketManager
2019
from websockets.client import WebSocketClientProtocol
2120

@@ -51,7 +50,7 @@
5150
KernelStatusUpdateResponse,
5251
)
5352
from origami.models.rtu.channels.system import AuthenticateReply, AuthenticateRequest
54-
from origami.models.rtu.discriminators import RTURequest, RTUResponse
53+
from origami.models.rtu.discriminators import RTURequest, RTUResponse, RTUResponseParser
5554
from origami.models.rtu.errors import InconsistentStateEvent
5655
from origami.notebook.builder import CellNotFound, NotebookBuilder
5756

@@ -87,7 +86,8 @@ async def inbound_message_hook(self, contents: str) -> RTUResponse:
8786
# to error or BaseRTUResponse)
8887
data: dict = orjson.loads(contents)
8988
data["channel_prefix"] = data.get("channel", "").split("/")[0]
90-
rtu_event = parse_obj_as(RTUResponse, data)
89+
90+
rtu_event = RTUResponseParser.validate_python(data)
9191

9292
# Debug Logging
9393
extra_dict = {
@@ -98,15 +98,18 @@ async def inbound_message_hook(self, contents: str) -> RTUResponse:
9898
if isinstance(rtu_event, NewDeltaEvent):
9999
extra_dict["delta_type"] = rtu_event.data.delta_type
100100
extra_dict["delta_action"] = rtu_event.data.delta_action
101-
logger.debug(f"Received: {data}\nParsed: {rtu_event.dict()}", extra=extra_dict)
101+
102+
if logging.DEBUG >= logging.root.level:
103+
logger.debug(f"Received: {data}\nParsed: {rtu_event.model_dump()}", extra=extra_dict)
104+
102105
return rtu_event
103106

104107
async def outbound_message_hook(self, contents: RTURequest) -> str:
105108
"""
106109
Hook applied to every message we send out over the websocket.
107110
- Anything calling .send() should pass in an RTU Request pydantic model
108111
"""
109-
return contents.json()
112+
return contents.model_dump_json()
110113

111114
def send(self, message: RTURequest) -> None:
112115
"""Override WebsocketManager-defined method for type hinting and logging."""
@@ -118,7 +121,9 @@ def send(self, message: RTURequest) -> None:
118121
if message.event == "new_delta_request":
119122
extra_dict["delta_type"] = message.data.delta.delta_type
120123
extra_dict["delta_action"] = message.data.delta.delta_action
124+
121125
logger.debug("Sending: RTU request", extra=extra_dict)
126+
122127
super().send(message) # the .outbound_message_hook handles serializing this to json
123128

124129
async def on_exception(self, exc: Exception):
@@ -143,11 +148,10 @@ class DeltaRejected(Exception):
143148

144149

145150
# Used in registering callback functions that get called right after squashing a Delta
146-
class DeltaCallback(BaseModel):
147-
# callback function should be async and expect one argument: a FileDelta
148-
# Doesn't matter what it returns. Pydantic doesn't validate Callable args/return.
149-
delta_class: Type[FileDelta]
150-
fn: Callable[[FileDelta], Awaitable[None]]
151+
class DeltaCallback:
152+
def __init__(self, delta_class: Type[FileDelta], fn: Callable[[FileDelta], Awaitable[None]]):
153+
self.delta_class = delta_class
154+
self.fn = fn
151155

152156

153157
class DeltaRequestCallbackManager:
@@ -455,7 +459,7 @@ async def load_seed_notebook(self):
455459
resp = await plain_http_client.get(file.presigned_download_url)
456460
resp.raise_for_status()
457461

458-
seed_notebook = Notebook.parse_obj(resp.json())
462+
seed_notebook = Notebook.model_validate(resp.json())
459463
self.builder = NotebookBuilder(seed_notebook=seed_notebook)
460464

461465
# See Sending backends.websocket for details but a quick refresher on hook timing:
@@ -494,7 +498,7 @@ async def auth_hook(self, *args, **kwargs):
494498
# we observe the auth reply. Instead use the unauth_ws directly and manually serialize
495499
ws: WebSocketClientProtocol = await self.manager.unauth_ws
496500
logger.info(f"Sending auth request with jwt {jwt[:5]}...{jwt[-5:]}")
497-
await ws.send(auth_request.json())
501+
await ws.send(auth_request.model_dump_json())
498502

499503
async def on_auth(self, msg: AuthenticateReply):
500504
# hook for Application code to override, consider catastrophic failure on auth failure

origami/models/api/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,4 +9,4 @@ class ResourceBase(BaseModel):
99
id: uuid.UUID
1010
created_at: datetime
1111
updated_at: datetime
12-
deleted_at: Optional[datetime]
12+
deleted_at: Optional[datetime] = None

origami/models/api/datasources.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,9 @@ class DataSource(BaseModel):
1212
type_id: str # e.g. duckdb, postgresql
1313
sql_cell_handle: str # this goes in cell metadata for SQL cells
1414
# One of these three will be not None, and that tells you the scope of the datasource
15-
space_id: Optional[uuid.UUID]
16-
project_id: Optional[uuid.UUID]
17-
user_id: Optional[uuid.UUID]
15+
space_id: Optional[uuid.UUID] = None
16+
project_id: Optional[uuid.UUID] = None
17+
user_id: Optional[uuid.UUID] = None
1818
created_by_id: uuid.UUID
1919
created_at: datetime
2020
updated_at: datetime

origami/models/api/files.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import uuid
44
from typing import Literal, Optional
55

6-
from pydantic import validator
6+
from pydantic import model_validator
77

88
from origami.models.api.base import ResourceBase
99

@@ -22,10 +22,12 @@ class File(ResourceBase):
2222
presigned_download_url: Optional[str] = None
2323
url: Optional[str] = None
2424

25-
@validator("url", always=True)
26-
def construct_url(cls, v, values):
25+
@model_validator(mode="after")
26+
def construct_url(self):
2727
noteable_url = os.environ.get("PUBLIC_NOTEABLE_URL", "https://app.noteable.io")
28-
return f"{noteable_url}/f/{values['id']}/{values['path']}"
28+
self.url = f"{noteable_url}/f/{self.id}/{self.path}"
29+
30+
return self
2931

3032

3133
class FileVersion(ResourceBase):

origami/models/api/outputs.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,11 @@ class KernelOutputContent(BaseModel):
1414

1515
class KernelOutput(ResourceBase):
1616
type: str
17-
display_id: Optional[str]
17+
display_id: Optional[str] = None
1818
available_mimetypes: List[str]
1919
content_metadata: KernelOutputContent
20-
content: Optional[KernelOutputContent]
21-
content_for_llm: Optional[KernelOutputContent]
20+
content: Optional[KernelOutputContent] = None
21+
content_for_llm: Optional[KernelOutputContent] = None
2222
parent_collection_id: uuid.UUID
2323

2424

origami/models/api/projects.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,18 +2,20 @@
22
import uuid
33
from typing import Optional
44

5-
from pydantic import validator
5+
from pydantic import model_validator
66

77
from origami.models.api.base import ResourceBase
88

99

1010
class Project(ResourceBase):
1111
name: str
12-
description: Optional[str]
12+
description: Optional[str] = None
1313
space_id: uuid.UUID
1414
url: Optional[str] = None
1515

16-
@validator("url", always=True)
17-
def construct_url(cls, v, values):
16+
@model_validator(mode="after")
17+
def construct_url(self):
1818
noteable_url = os.environ.get("PUBLIC_NOTEABLE_URL", "https://app.noteable.io")
19-
return f"{noteable_url}/p/{values['id']}"
19+
self.url = f"{noteable_url}/p/{self.id}"
20+
21+
return self

origami/models/api/spaces.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,19 @@
11
import os
22
from typing import Optional
33

4-
from pydantic import validator
4+
from pydantic import model_validator
55

66
from origami.models.api.base import ResourceBase
77

88

99
class Space(ResourceBase):
1010
name: str
11-
description: Optional[str]
11+
description: Optional[str] = None
1212
url: Optional[str] = None
1313

14-
@validator("url", always=True)
15-
def construct_url(cls, v, values):
14+
@model_validator(mode="after")
15+
def construct_url(self):
1616
noteable_url = os.environ.get("PUBLIC_NOTEABLE_URL", "https://app.noteable.io")
17-
return f"{noteable_url}/s/{values['id']}"
17+
self.url = f"{noteable_url}/s/{self.id}"
18+
19+
return self

origami/models/api/users.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import uuid
22
from typing import Optional
33

4-
from pydantic import validator
4+
from pydantic import model_validator
55

66
from origami.models.api.base import ResourceBase
77

@@ -10,14 +10,16 @@ class User(ResourceBase):
1010
"""The user fields sent to/from the server"""
1111

1212
handle: str
13-
email: Optional[str] # not returned if looking up user other than yourself
13+
email: Optional[str] = None # not returned if looking up user other than yourself
1414
first_name: str
1515
last_name: str
16-
origamist_default_project_id: Optional[uuid.UUID]
17-
principal_sub: Optional[str] # from /users/me only, represents auth type
18-
auth_type: Optional[str]
16+
origamist_default_project_id: Optional[uuid.UUID] = None
17+
principal_sub: Optional[str] = None # from /users/me only, represents auth type
18+
auth_type: Optional[str] = None
1919

20-
@validator("auth_type", always=True)
21-
def construct_auth_type(cls, v, values):
22-
if values.get("principal_sub"):
23-
return values["principal_sub"].split("|")[0]
20+
@model_validator(mode="after")
21+
def construct_auth_type(self):
22+
if self.principal_sub:
23+
self.auth_type = self.principal_sub.split("|")[0]
24+
25+
return self

0 commit comments

Comments
 (0)