|
1 | 1 | import logging |
| 2 | +import pprint |
| 3 | +from dataclasses import dataclass |
2 | 4 | from pathlib import Path |
3 | 5 |
|
| 6 | +from attr import asdict |
4 | 7 | from lsprotocol import types |
5 | | -from lsprotocol.types import Position, Range, TextEdit |
| 8 | +from lsprotocol.types import CreateFile, CreateFileOptions, DeleteFile, Position, Range, RenameFile, TextEdit |
6 | 9 | from pygls.workspace import TextDocument, Workspace |
7 | 10 |
|
8 | 11 | from codegen.sdk.codebase.io.file_io import FileIO |
|
11 | 14 | logger = logging.getLogger(__name__) |
12 | 15 |
|
13 | 16 |
|
| 17 | +@dataclass |
| 18 | +class File: |
| 19 | + doc: TextDocument |
| 20 | + path: Path |
| 21 | + change: TextEdit | None = None |
| 22 | + other_change: CreateFile | RenameFile | DeleteFile | None = None |
| 23 | + version: int = 0 |
| 24 | + |
| 25 | + @property |
| 26 | + def deleted(self) -> bool: |
| 27 | + return self.other_change is not None and self.other_change.kind == "delete" |
| 28 | + |
| 29 | + @property |
| 30 | + def created(self) -> bool: |
| 31 | + return self.other_change is not None and self.other_change.kind == "create" |
| 32 | + |
| 33 | + @property |
| 34 | + def identifier(self) -> types.OptionalVersionedTextDocumentIdentifier: |
| 35 | + return types.OptionalVersionedTextDocumentIdentifier(uri=self.path.as_uri(), version=self.version) |
| 36 | + |
| 37 | + |
14 | 38 | class LSPIO(IO): |
15 | 39 | base_io: FileIO |
16 | 40 | workspace: Workspace |
17 | | - changes: dict[str, TextEdit] = {} |
| 41 | + files: dict[Path, File] |
18 | 42 |
|
19 | 43 | def __init__(self, workspace: Workspace): |
20 | 44 | self.workspace = workspace |
21 | 45 | self.base_io = FileIO() |
| 46 | + self.files = {} |
22 | 47 |
|
23 | | - def _get_doc(self, path: Path) -> TextDocument | None: |
| 48 | + def _get_doc(self, path: Path) -> TextDocument: |
24 | 49 | uri = path.as_uri() |
25 | 50 | logger.info(f"Getting document for {uri}") |
26 | 51 | return self.workspace.get_text_document(uri) |
27 | 52 |
|
| 53 | + def _get_file(self, path: Path) -> File: |
| 54 | + if path not in self.files: |
| 55 | + doc = self._get_doc(path) |
| 56 | + self.files[path] = File(doc=doc, path=path, version=doc.version or 0) |
| 57 | + return self.files[path] |
| 58 | + |
| 59 | + def read_text(self, path: Path) -> str: |
| 60 | + file = self._get_file(path) |
| 61 | + if file.deleted: |
| 62 | + msg = f"File {path} has been deleted" |
| 63 | + raise FileNotFoundError(msg) |
| 64 | + if file.change: |
| 65 | + return file.change.new_text |
| 66 | + if file.created: |
| 67 | + return "" |
| 68 | + return file.doc.source |
| 69 | + |
28 | 70 | def read_bytes(self, path: Path) -> bytes: |
29 | | - if self.changes.get(path.as_uri()): |
30 | | - return self.changes[path.as_uri()].new_text.encode("utf-8") |
31 | | - if doc := self._get_doc(path): |
32 | | - return doc.source.encode("utf-8") |
33 | | - return self.base_io.read_bytes(path) |
| 71 | + file = self._get_file(path) |
| 72 | + if file.deleted: |
| 73 | + msg = f"File {path} has been deleted" |
| 74 | + raise FileNotFoundError(msg) |
| 75 | + if file.change: |
| 76 | + return file.change.new_text.encode("utf-8") |
| 77 | + if file.created: |
| 78 | + return b"" |
| 79 | + return file.doc.source.encode("utf-8") |
34 | 80 |
|
35 | 81 | def write_bytes(self, path: Path, content: bytes) -> None: |
36 | 82 | logger.info(f"Writing bytes to {path}") |
37 | 83 | start = Position(line=0, character=0) |
38 | | - if doc := self._get_doc(path): |
39 | | - end = Position(line=len(doc.source), character=len(doc.source)) |
| 84 | + file = self._get_file(path) |
| 85 | + if self.file_exists(path): |
| 86 | + lines = self.read_text(path).splitlines() |
| 87 | + if len(lines) == 0: |
| 88 | + end = Position(line=0, character=0) |
| 89 | + else: |
| 90 | + end = Position(line=len(lines) - 1, character=len(lines[-1])) |
| 91 | + file.change = TextEdit(range=Range(start=start, end=end), new_text=content.decode("utf-8")) |
40 | 92 | else: |
41 | | - end = Position(line=0, character=0) |
42 | | - self.changes[path.as_uri()] = TextEdit(range=Range(start=start, end=end), new_text=content.decode("utf-8")) |
| 93 | + file.other_change = CreateFile(uri=path.as_uri(), options=CreateFileOptions()) |
| 94 | + file.change = TextEdit(range=Range(start=start, end=start), new_text=content.decode("utf-8")) |
43 | 95 |
|
44 | 96 | def save_files(self, files: set[Path] | None = None) -> None: |
45 | | - self.base_io.save_files(files) |
| 97 | + logger.info(f"Saving files {files}") |
46 | 98 |
|
47 | 99 | def check_changes(self) -> None: |
48 | 100 | self.base_io.check_changes() |
49 | 101 |
|
50 | 102 | def delete_file(self, path: Path) -> None: |
| 103 | + file = self._get_file(path) |
| 104 | + file.other_change = DeleteFile(uri=path.as_uri()) |
51 | 105 | self.base_io.delete_file(path) |
52 | 106 |
|
53 | 107 | def file_exists(self, path: Path) -> bool: |
54 | | - if doc := self._get_doc(path): |
55 | | - try: |
56 | | - doc.source |
57 | | - except FileNotFoundError: |
58 | | - return False |
| 108 | + file = self._get_file(path) |
| 109 | + if file.deleted: |
| 110 | + return False |
| 111 | + if file.change: |
| 112 | + return True |
| 113 | + if file.created: |
| 114 | + return True |
| 115 | + try: |
| 116 | + file.doc.source |
59 | 117 | return True |
60 | | - return self.base_io.file_exists(path) |
| 118 | + except FileNotFoundError: |
| 119 | + return False |
61 | 120 |
|
62 | 121 | def untrack_file(self, path: Path) -> None: |
63 | 122 | self.base_io.untrack_file(path) |
64 | 123 |
|
65 | | - def get_document_changes(self) -> list[types.TextDocumentEdit]: |
66 | | - ret = [] |
67 | | - for uri, change in self.changes.items(): |
68 | | - id = types.OptionalVersionedTextDocumentIdentifier(uri=uri) |
69 | | - ret.append(types.TextDocumentEdit(text_document=id, edits=[change])) |
70 | | - self.changes = {} |
71 | | - return ret |
| 124 | + def get_workspace_edit(self) -> types.WorkspaceEdit: |
| 125 | + document_changes = [] |
| 126 | + for _, file in self.files.items(): |
| 127 | + id = file.identifier |
| 128 | + if file.other_change: |
| 129 | + document_changes.append(file.other_change) |
| 130 | + file.other_change = None |
| 131 | + if file.change: |
| 132 | + document_changes.append(types.TextDocumentEdit(text_document=id, edits=[file.change])) |
| 133 | + file.version += 1 |
| 134 | + file.change = None |
| 135 | + logger.info(f"Workspace edit: {pprint.pformat(list(map(asdict, document_changes)))}") |
| 136 | + return types.WorkspaceEdit(document_changes=document_changes) |
0 commit comments