diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index aaed5aa..95bba71 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -50,12 +50,37 @@ jobs: - name: Generate icons run: npm run gen-icons + - name: Import Apple Developer Certificate + env: + APPLE_CERTIFICATE: ${{ secrets.APPLE_CERTIFICATE }} + APPLE_CERTIFICATE_PASSWORD: ${{ secrets.APPLE_CERTIFICATE_PASSWORD }} + KEYCHAIN_PASSWORD: ${{ secrets.KEYCHAIN_PASSWORD }} + run: | + echo "$APPLE_CERTIFICATE" | base64 --decode > certificate.p12 + security create-keychain -p "$KEYCHAIN_PASSWORD" build.keychain + security default-keychain -s build.keychain + security unlock-keychain -p "$KEYCHAIN_PASSWORD" build.keychain + security set-keychain-settings -t 3600 -u build.keychain + security import certificate.p12 -k build.keychain -P "$APPLE_CERTIFICATE_PASSWORD" -T /usr/bin/codesign + security set-key-partition-list -S apple-tool:,apple:,codesign: -s -k "$KEYCHAIN_PASSWORD" build.keychain + security find-identity -v -p codesigning build.keychain + + - name: Verify Certificate + run: | + CERT_INFO=$(security find-identity -v -p codesigning build.keychain | grep "Apple Development") + CERT_ID=$(echo "$CERT_INFO" | awk -F'"' '{print $2}') + echo "CERT_ID=$CERT_ID" >> "$GITHUB_ENV" + echo "Certificate imported." 
+ - name: Build & publish (Tauri) uses: tauri-apps/tauri-action@v0 env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} TAURI_SIGNING_PRIVATE_KEY: ${{ secrets.TAURI_SIGNING_PRIVATE_KEY }} TAURI_SIGNING_PRIVATE_KEY_PASSWORD: ${{ secrets.TAURI_SIGNING_PRIVATE_KEY_PASSWORD }} + APPLE_CERTIFICATE: ${{ secrets.APPLE_CERTIFICATE }} + APPLE_CERTIFICATE_PASSWORD: ${{ secrets.APPLE_CERTIFICATE_PASSWORD }} + APPLE_SIGNING_IDENTITY: ${{ env.CERT_ID }} with: tagName: v__VERSION__ releaseName: DTM v__VERSION__ diff --git a/README.md b/README.md index db6b016..f7b23df 100644 --- a/README.md +++ b/README.md @@ -1 +1,19 @@ -# DTM \ No newline at end of file +# DTM + +## Building + +To build the app on Mac, you will need to have [Node/NPM](https://nodejs.org/en/download), and [Rust](https://www.rust-lang.org/tools/install) installed, as well as the Xcode command line tools (`xcode-select --install`) + +```bash +npm install +npm run gen-icons + +# Build the app for current architecture +npm run build:mac + +# Build for Mac Universal +npm run build:universal + +# Run in dev mode +npm run dev +``` \ No newline at end of file diff --git a/package-lock.json b/package-lock.json index 195b204..2d1b128 100644 --- a/package-lock.json +++ b/package-lock.json @@ -1,12 +1,12 @@ { "name": "dtm", - "version": "0.2.1", + "version": "0.3.3", "lockfileVersion": 3, "requires": true, "packages": { "": { "name": "dtm", - "version": "0.2.1", + "version": "0.3.3", "dependencies": { "@chakra-ui/react": "^3.30.0", "@emotion/react": "^11.14.0", diff --git a/package.json b/package.json index c52410b..0e63efb 100644 --- a/package.json +++ b/package.json @@ -2,7 +2,7 @@ "name": "dtm", "author": "kcjerrell", "private": true, - "version": "0.3.2", + "version": "0.3.3", "type": "module", "scripts": { "dev": "tauri dev", diff --git a/public/img_not_available.svg b/public/img_not_available.svg new file mode 100644 index 0000000..6e03d34 --- /dev/null +++ b/public/img_not_available.svg @@ -0,0 +1,6 @@ + + + + + + \ No
newline at end of file diff --git a/src-tauri/.cargo/config.toml b/src-tauri/.cargo/config.toml new file mode 100644 index 0000000..d500958 --- /dev/null +++ b/src-tauri/.cargo/config.toml @@ -0,0 +1,2 @@ +[env] +RUST_TEST_THREADS = "1" \ No newline at end of file diff --git a/src-tauri/Cargo.lock b/src-tauri/Cargo.lock index a78be52..00ea6e3 100644 --- a/src-tauri/Cargo.lock +++ b/src-tauri/Cargo.lock @@ -1624,12 +1624,14 @@ name = "dtm" version = "0.3.2" dependencies = [ "anyhow", + "async-trait", "base64 0.22.1", "bytemuck", "byteorder", "bytes", "cc", "chrono", + "dashmap", "dtm_macros", "entity", "flatbuffers", @@ -1643,7 +1645,7 @@ dependencies = [ "log", "migration", "mime", - "moka", + "notify-debouncer-mini", "num_enum", "objc2 0.6.3", "objc2-app-kit 0.3.2", @@ -1674,10 +1676,12 @@ dependencies = [ "tauri-plugin-updater", "tauri-plugin-valtio", "tauri-plugin-window-state", + "tempfile", "tokio", "tracing", "tracing-subscriber", "unicode-normalization", + "walkdir", "web-image-meta", ] @@ -3582,26 +3586,6 @@ dependencies = [ "windows-sys 0.61.2", ] -[[package]] -name = "moka" -version = "0.12.12" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a3dec6bd31b08944e08b58fd99373893a6c17054d6f3ea5006cc894f4f4eee2a" -dependencies = [ - "async-lock 3.4.2", - "crossbeam-channel", - "crossbeam-epoch", - "crossbeam-utils", - "equivalent", - "event-listener 5.4.1", - "futures-util", - "parking_lot", - "portable-atomic", - "smallvec", - "tagptr", - "uuid", -] - [[package]] name = "moxcms" version = "0.7.11" @@ -3767,6 +3751,18 @@ dependencies = [ "walkdir", ] +[[package]] +name = "notify-debouncer-mini" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "17849edfaabd9a5fef1c606d99cfc615a8e99f7ac4366406d86c7942a3184cf2" +dependencies = [ + "log", + "notify", + "notify-types", + "tempfile", +] + [[package]] name = "notify-types" version = "2.0.0" @@ -4719,12 +4715,6 @@ dependencies = [ 
"windows-sys 0.61.2", ] -[[package]] -name = "portable-atomic" -version = "1.13.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f89776e4d69bb58bc6993e99ffa1d11f228b839984854c7daeb5d37f87cbe950" - [[package]] name = "potential_utf" version = "0.1.4" @@ -6548,12 +6538,6 @@ dependencies = [ "version-compare", ] -[[package]] -name = "tagptr" -version = "0.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7b2093cf4c8eb1e67749a6762251bc9cd836b6fc171623bd0a9d324d37af2417" - [[package]] name = "tao" version = "0.34.5" @@ -7144,9 +7128,9 @@ dependencies = [ [[package]] name = "tempfile" -version = "3.24.0" +version = "3.25.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "655da9c7eb6305c55742045d5a8d2037996d61d8de95806335c7c86ce0f82e9c" +checksum = "0136791f7c95b1f6dd99f9cc786b91bb81c3800b639b3478e561ddb7be95e5f1" dependencies = [ "fastrand 2.3.0", "getrandom 0.3.4", diff --git a/src-tauri/Cargo.toml b/src-tauri/Cargo.toml index 7b55df1..dfcc2ec 100644 --- a/src-tauri/Cargo.toml +++ b/src-tauri/Cargo.toml @@ -3,7 +3,7 @@ members = ["migration", "entity", ".", "fpzip-sys", "macros"] [package] name = "dtm" -version = "0.3.2" +version = "0.3.3" description = "A little app for reading Draw Things Metadata" authors = ["kcjerrell"] edition = "2021" @@ -58,7 +58,6 @@ futures = "0.3.28" tokio = { version = "1.48.0", features = ["full"] } mime = "0.3.17" once_cell = "1.21.3" -moka = { version = "0.12.11", features = ["future"] } bytes = "1.10.1" byteorder = "1.5.0" half = "2" @@ -79,6 +78,10 @@ sevenz-rust = "0.6.1" sha2 = "0.10.9" futures-util = "0.3.31" regex = "1.12.2" +walkdir = "2.5.0" +async-trait = "0.1.89" +notify-debouncer-mini = { version = "0.7.0", features = ["macos_fsevent"] } +dashmap = "6.1.0" # macOS-only [target."cfg(target_os = \"macos\")".dependencies] @@ -91,3 +94,6 @@ tauri-plugin-nspopover = { git = "https://github.com/freethinkel/tauri-nspopover 
[target.'cfg(not(any(target_os = "android", target_os = "ios")))'.dependencies] tauri-plugin-updater = { git = "https://github.com/tauri-apps/plugins-workspace", branch = "v2" } tauri-plugin-window-state = { git = "https://github.com/tauri-apps/plugins-workspace", branch = "v2" } + +[dev-dependencies] +tempfile = "3.25.0" diff --git a/src-tauri/capabilities/default.json b/src-tauri/capabilities/default.json index 6a2a45a..dc95e47 100644 --- a/src-tauri/capabilities/default.json +++ b/src-tauri/capabilities/default.json @@ -31,28 +31,46 @@ "fs:allow-home-read-recursive", "fs:allow-watch", "fs:allow-unwatch", + { + "identifier": "fs:allow-watch", + "allow": [ + "$HOME/**", + "/Volumes/**" + ] + }, { "identifier": "fs:allow-exists", "allow": [ - "$HOME/Library/Containers/com.liuliu.draw-things/Data/*" + "$HOME/**", + "/Volumes/**" ] }, { "identifier": "fs:allow-read-dir", "allow": [ - "$HOME/Library/Containers/com.liuliu.draw-things/Data/Documents/*" + "$HOME/**", + "/Volumes/**" ] }, { "identifier": "fs:allow-read-file", "allow": [ - "$HOME/Library/Containers/com.liuliu.draw-things/Data/Documents/*" + "$HOME/**", + "/Volumes/**" ] }, { "identifier": "fs:allow-read", "allow": [ - "$HOME/Library/Containers/com.liuliu.draw-things/Data/Documents/*" + "$HOME/**", + "/Volumes/**" + ] + }, + { + "identifier": "fs:allow-stat", + "allow": [ + "$HOME/**", + "/Volumes/**" ] }, "http:default", diff --git a/src-tauri/entity/Cargo.toml b/src-tauri/entity/Cargo.toml index ec7c2b7..703224a 100644 --- a/src-tauri/entity/Cargo.toml +++ b/src-tauri/entity/Cargo.toml @@ -10,6 +10,6 @@ path = "src/mod.rs" [dependencies] sea-orm = { version = "2.0.0-rc" } -serde = "1.0.228" +serde = { version = "1.0", features = ["derive"] } chrono = { version = "0.4", features = ["serde"] } num_enum = "0.7.5" diff --git a/src-tauri/entity/src/projects.rs b/src-tauri/entity/src/projects.rs index 7b64f1b..f02a49d 100644 --- a/src-tauri/entity/src/projects.rs +++ b/src-tauri/entity/src/projects.rs @@ -4,18 
+4,25 @@ use sea_orm::entity::prelude::*; use serde::Serialize; #[sea_orm::model] -#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel, Serialize)] +#[derive(Clone, Debug, PartialEq, DeriveEntityModel, Serialize)] #[sea_orm(table_name = "projects")] pub struct Model { #[sea_orm(primary_key)] pub id: i64, pub fingerprint: String, - #[sea_orm(unique)] pub path: String, + pub watchfolder_id: i64, pub filesize: Option, pub modified: Option, - pub missing_on: Option, pub excluded: bool, + #[sea_orm( + belongs_to, + from = "watchfolder_id", + to = "id", + on_update = "NoAction", + on_delete = "Cascade" + )] + pub watchfolder: HasOne, #[sea_orm(has_many)] pub images: HasMany, } diff --git a/src-tauri/entity/src/watch_folders.rs b/src-tauri/entity/src/watch_folders.rs index 8a1b456..1183759 100644 --- a/src-tauri/entity/src/watch_folders.rs +++ b/src-tauri/entity/src/watch_folders.rs @@ -4,14 +4,18 @@ use sea_orm::entity::prelude::*; use serde::Serialize; #[sea_orm::model] -#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel, Serialize)] +#[derive(Clone, Debug, PartialEq, DeriveEntityModel, Serialize)] #[sea_orm(table_name = "watch_folders")] pub struct Model { #[sea_orm(primary_key)] pub id: i64, pub path: String, + pub bookmark: String, pub recursive: Option, - pub last_updated: Option, + pub is_missing: bool, + pub is_locked: bool, + #[sea_orm(has_many)] + pub projects: HasMany, } impl ActiveModelBehavior for ActiveModel {} diff --git a/src-tauri/macros/src/lib.rs b/src-tauri/macros/src/lib.rs index 13237da..628f2f8 100644 --- a/src-tauri/macros/src/lib.rs +++ b/src-tauri/macros/src/lib.rs @@ -1,9 +1,7 @@ use proc_macro::TokenStream; -use quote::quote; +use quote::{format_ident, quote}; use syn::{ - parse::{Parse, ParseStream}, - parse_macro_input, Expr, FnArg, GenericArgument, ItemFn, Pat, PathArguments, ReturnType, Token, - Type, + Expr, FnArg, GenericArgument, ItemFn, Pat, PathArguments, ReturnType, Token, Type, parse::{Parse, ParseStream}, parse_macro_input 
}; struct DtmArgs { @@ -180,3 +178,121 @@ pub fn dtm_command(args: TokenStream, input: TokenStream) -> TokenStream { TokenStream::from(expanded) } + +#[proc_macro_attribute] +pub fn dtp_command(_attr: TokenStream, item: TokenStream) -> TokenStream { + // This is now a marker macro when used inside #[dtp_commands] + // If used alone, it will still try to generate, but will fail if inside an impl. + // We'll keep the logic but allow it to be stripped by dtp_commands. + item +} + +#[proc_macro_attribute] +pub fn dtp_commands(_attr: TokenStream, item: TokenStream) -> TokenStream { + let mut input = parse_macro_input!(item as syn::ItemImpl); + let self_ty = &input.self_ty; + + let mut generated_commands = Vec::new(); + + for item in &mut input.items { + if let syn::ImplItem::Fn(method) = item { + let mut has_dtp_command = false; + let mut dtp_command_idx = None; + + for (i, attr) in method.attrs.iter().enumerate() { + if attr.path().is_ident("dtp_command") { + has_dtp_command = true; + dtp_command_idx = Some(i); + break; + } + } + + if has_dtp_command { + // Remove the dtp_command attribute from the method + if let Some(idx) = dtp_command_idx { + method.attrs.remove(idx); + } + + let vis = &method.vis; + let sig = &method.sig; + let fn_name = &sig.ident; + + // Ensure async + if sig.asyncness.is_none() { + return syn::Error::new_spanned( + sig.fn_token, + "dtp_command functions must be async", + ) + .to_compile_error() + .into(); + } + + // Extract args + let mut inputs = sig.inputs.iter(); + + // Ensure first arg is &self + let first = inputs.next(); + match first { + Some(FnArg::Receiver(_)) => {} + _ => { + return syn::Error::new_spanned( + sig, + "dtp_command requires &self as first parameter", + ) + .to_compile_error() + .into(); + } + } + + // Collect remaining args for wrapper + let mut wrapper_args = Vec::new(); + let mut forward_args = Vec::new(); + + for arg in inputs { + if let FnArg::Typed(pat_type) = arg { + wrapper_args.push(pat_type.clone()); + + // 
extract argument name for forwarding + if let Pat::Ident(pat_ident) = &*pat_type.pat { + forward_args.push(pat_ident.ident.clone()); + } + } + } + + let output = &sig.output; + let wrapper_name = format_ident!("dtp_{}", fn_name); + let command_name_str = wrapper_name.to_string(); + + generated_commands.push(quote! { + #[tauri::command] + #vis async fn #wrapper_name( + state: tauri::State<'_, #self_ty>, + #(#wrapper_args),* + ) #output { + log::debug!("DTPService command: {}", #command_name_str); + + let result = state.inner().#fn_name(#(#forward_args),*).await; + + if let Err(ref e) = result { + log::error!( + "DTPService command failed: {} ({})", + #command_name_str, + e + ); + } + + result + } + }); + } + } + } + + let expanded = quote! { + #input + + #(#generated_commands)* + }; + + expanded.into() +} \ No newline at end of file diff --git a/src-tauri/migration/src/m20220101_000001_create_table.rs b/src-tauri/migration/src/m20220101_000001_create_table.rs index a8a52fb..da23f1c 100644 --- a/src-tauri/migration/src/m20220101_000001_create_table.rs +++ b/src-tauri/migration/src/m20220101_000001_create_table.rs @@ -6,25 +6,60 @@ pub struct Migration; #[async_trait::async_trait] impl MigrationTrait for Migration { async fn up(&self, manager: &SchemaManager) -> Result<(), DbErr> { - // projects + // watchfolders manager .create_table( Table::create() - .table(Projects::Table) + .table(WatchFolders::Table) .if_not_exists() .col( - ColumnDef::new(Projects::Id) + ColumnDef::new(WatchFolders::Id) .integer() .not_null() .primary_key() .auto_increment(), ) + .col(ColumnDef::new(WatchFolders::Path).string().not_null()) .col( - ColumnDef::new(Projects::Path) + ColumnDef::new(WatchFolders::Bookmark) .string() + .unique_key() + .not_null(), + ) + .col( + ColumnDef::new(WatchFolders::Recursive) + .boolean() + .default(false), + ) + .col( + ColumnDef::new(WatchFolders::IsMissing) + .boolean() + .default(true), + ) + .col( + ColumnDef::new(WatchFolders::IsLocked) + .boolean() + 
.default(false), + ) + .to_owned(), + ) + .await?; + + // projects + manager + .create_table( + Table::create() + .table(Projects::Table) + .if_not_exists() + .col( + ColumnDef::new(Projects::Id) + .integer() .not_null() - .unique_key(), + .primary_key() + .auto_increment(), ) + .col(ColumnDef::new(Projects::Path).string().not_null()) + .col(ColumnDef::new(Projects::WatchfolderId).integer().not_null()) .col(ColumnDef::new(Projects::Filesize).big_integer().null()) .col(ColumnDef::new(Projects::Modified).big_integer().null()) .col( @@ -39,7 +74,20 @@ impl MigrationTrait for Migration { .not_null() .default(""), ) - .col(ColumnDef::new(Projects::MissingOn).big_integer().null()) + .foreign_key( + ForeignKey::create() + .name("fk_projects_watchfolder") + .from(Projects::Table, Projects::WatchfolderId) + .to(WatchFolders::Table, WatchFolders::Id) + .on_delete(ForeignKeyAction::Cascade), + ) + .index( + Index::create() + .name("idx_projects_path_watchfolder_id") + .col(Projects::Path) + .col(Projects::WatchfolderId) + .unique(), + ) .to_owned(), ) .await?; @@ -387,35 +435,6 @@ impl MigrationTrait for Migration { ) .await?; - // watchfolders - manager - .create_table( - Table::create() - .table(WatchFolders::Table) - .if_not_exists() - .col( - ColumnDef::new(WatchFolders::Id) - .integer() - .not_null() - .primary_key() - .auto_increment(), - ) - .col( - ColumnDef::new(WatchFolders::Path) - .string() - .not_null() - .unique_key(), - ) - .col( - ColumnDef::new(WatchFolders::Recursive) - .boolean() - .default(false), - ) - .col(ColumnDef::new(WatchFolders::LastUpdated).integer().null()) - .to_owned(), - ) - .await?; - manager .get_connection() .execute_unprepared( @@ -475,7 +494,7 @@ enum Projects { Modified, Excluded, Fingerprint, - MissingOn, + WatchfolderId, } #[derive(Iden)] @@ -565,5 +584,7 @@ enum WatchFolders { Id, Path, Recursive, - LastUpdated, + Bookmark, + IsMissing, + IsLocked, } diff --git a/src-tauri/src/bookmarks.rs b/src-tauri/src/bookmarks.rs index 
f408ba6..f121e6d 100644 --- a/src-tauri/src/bookmarks.rs +++ b/src-tauri/src/bookmarks.rs @@ -1,128 +1,33 @@ -use tauri::command; - #[cfg(target_os = "macos")] -mod ffi { - use std::os::raw::c_char; +mod bookmarks_mac; +#[cfg(target_os = "macos")] +pub use bookmarks_mac::*; - extern "C" { - pub fn open_dt_folder_picker(default_path: *const c_char) -> *mut c_char; - pub fn free_string_ptr(ptr: *mut c_char); - pub fn start_accessing_security_scoped_resource(bookmark: *const c_char) -> *mut c_char; - pub fn stop_all_security_scoped_resources(); - pub fn stop_accessing_security_scoped_resource(bookmark: *const c_char); - } -} +#[cfg(target_os = "linux")] +mod bookmarks_linux; +#[cfg(target_os = "linux")] +pub use bookmarks_linux::*; + +// Also support other non-macos platforms as linux-like (simple paths) +#[cfg(all(not(target_os = "macos"), not(target_os = "linux")))] +mod bookmarks_linux; +#[cfg(all(not(target_os = "macos"), not(target_os = "linux")))] +pub use bookmarks_linux::*; -#[derive(serde::Serialize)] + +#[derive(serde::Serialize, serde::Deserialize, Clone)] pub struct PickFolderResult { pub path: String, pub bookmark: String, } -#[command] -pub async fn pick_draw_things_folder( - default_path: Option, -) -> Result, String> { - #[cfg(target_os = "macos")] - { - use std::ffi::{CStr, CString}; - use std::ptr; - - // This function must run on the main thread for UI - // In Tauri v2, commands are async by default on a thread pool. - // NSOpenPanel should ideally be run on main thread. - // However, let's try calling it directly first. If it crashes/hangs, we'll need dispatch. 
- - let c_default_path = match default_path { - Some(path) => Some(CString::new(path).map_err(|e| e.to_string())?), - None => None, - }; - - let ptr_arg = match &c_default_path { - Some(c_str) => c_str.as_ptr(), - None => ptr::null(), - }; - - let ptr = unsafe { ffi::open_dt_folder_picker(ptr_arg) }; - - if ptr.is_null() { - return Ok(None); - } - - let c_str = unsafe { CStr::from_ptr(ptr) }; - let full_result = c_str.to_string_lossy().into_owned(); - - unsafe { ffi::free_string_ptr(ptr) }; - - // Parse "path|bookmark" - if let Some((path, bookmark)) = full_result.split_once('|') { - Ok(Some(PickFolderResult { - path: path.to_string(), - bookmark: bookmark.to_string(), - })) - } else { - Err("Failed to parse picker result".to_string()) - } - } - - #[cfg(not(target_os = "macos"))] - { - Err("Unsupported platform".to_string()) - } -} - -#[command] -pub async fn resolve_bookmark(bookmark: String) -> Result { - #[cfg(target_os = "macos")] - { - use std::ffi::{CStr, CString}; - - let c_bookmark = CString::new(bookmark).map_err(|e| e.to_string())?; - - let ptr = unsafe { ffi::start_accessing_security_scoped_resource(c_bookmark.as_ptr()) }; - - if ptr.is_null() { - return Err("Failed to resolve bookmark or start accessing resource".to_string()); - } - - let c_str = unsafe { CStr::from_ptr(ptr) }; - let result = c_str.to_string_lossy().into_owned(); - - unsafe { ffi::free_string_ptr(ptr) }; - - Ok(result) - } - - #[cfg(not(target_os = "macos"))] - { - Err("Unsupported platform".to_string()) - } -} - -#[command] -pub async fn stop_accessing_bookmark(bookmark: String) -> Result<(), String> { - #[cfg(target_os = "macos")] - { - use std::ffi::CString; - - let c_bookmark = CString::new(bookmark).map_err(|e| e.to_string())?; - - unsafe { - ffi::stop_accessing_security_scoped_resource(c_bookmark.as_ptr()); - }; - - Ok(()) - } - - #[cfg(not(target_os = "macos"))] - { - Err("Unsupported platform".to_string()) - } -} - -pub fn cleanup_bookmarks() { - #[cfg(target_os = "macos")] - 
unsafe { - ffi::stop_all_security_scoped_resources(); - } +#[derive(serde::Serialize, serde::Deserialize, Clone)] +#[serde(tag = "type", content = "data")] +pub enum ResolveResult { + CannotResolve, + Resolved(String), + StaleRefreshed { + new_bookmark: String, + resolved_path: String, + }, } diff --git a/src-tauri/src/bookmarks/bookmarks_linux.rs b/src-tauri/src/bookmarks/bookmarks_linux.rs new file mode 100644 index 0000000..ffd2cbe --- /dev/null +++ b/src-tauri/src/bookmarks/bookmarks_linux.rs @@ -0,0 +1,72 @@ +use std::str::FromStr; + +use crate::dtp_service::AppHandleWrapper; + +use super::{PickFolderResult, ResolveResult}; +use tauri::{command, Manager, State}; +use tauri_plugin_dialog::DialogExt; + +#[command] +pub async fn pick_folder_command( + app: State<'_, AppHandleWrapper>, + default_path: Option, + button_text: Option, +) -> Result, String> { + pick_folder(&app, default_path, button_text).await +} + +pub async fn pick_folder( + app: &AppHandleWrapper, + default_path: Option, + button_text: Option, +) -> Result, String> { + let app = app.app_handle.clone().unwrap(); + let folder_override = match default_path { + Some(path) => match path.starts_with("TESTPATH::") { + true => { + let path = path.strip_prefix("TESTPATH::").unwrap(); + Some(tauri_plugin_fs::FilePath::from_str(path).unwrap()) + } + false => None, + }, + None => None, + }; + + let folder: Option = match folder_override { + Some(path) => Some(path), + None => app.dialog().file().blocking_pick_folder(), + }; + + match folder { + Some(path) => { + let path_str = path.to_string(); + Ok(Some(PickFolderResult { + path: path_str.clone(), + bookmark: path_str, + })) + } + None => Ok(None), + } +} + +#[command] +pub async fn resolve_bookmark(bookmark: String) -> Result { + if bookmark.starts_with("TESTBOOKMARK::") { + return Ok(ResolveResult::Resolved( + bookmark.split("::").last().unwrap().to_string(), + )); + } + + // On Linux, the bookmark IS the path + Ok(ResolveResult::Resolved(bookmark)) +} + 
+#[command] +pub async fn stop_accessing_bookmark(_bookmark: String) -> Result<(), String> { + // No-op on Linux + Ok(()) +} + +pub fn cleanup_bookmarks() { + // No-op on Linux +} diff --git a/src-tauri/src/bookmarks/bookmarks_mac.rs b/src-tauri/src/bookmarks/bookmarks_mac.rs new file mode 100644 index 0000000..396db9e --- /dev/null +++ b/src-tauri/src/bookmarks/bookmarks_mac.rs @@ -0,0 +1,150 @@ +use crate::dtp_service::AppHandleWrapper; + +use super::{PickFolderResult, ResolveResult}; +use tauri::{command, State}; + +mod ffi { + use std::os::raw::c_char; + + extern "C" { + pub fn open_dt_folder_picker( + default_path: *const c_char, + button_text: *const c_char, + ) -> *mut c_char; + pub fn free_string_ptr(ptr: *mut c_char); + pub fn start_accessing_security_scoped_resource(bookmark: *const c_char) -> *mut c_char; + pub fn stop_all_security_scoped_resources(); + pub fn stop_accessing_security_scoped_resource(bookmark: *const c_char); + } +} + +#[derive(serde::Deserialize)] +struct FfiResolveResult { + status: String, + path: String, + new_bookmark: Option, +} + +#[command] +pub async fn pick_folder_command( + app: State<'_, AppHandleWrapper>, + default_path: Option, + button_text: Option, +) -> Result, String> { + pick_folder(&app, default_path, button_text).await +} +pub async fn pick_folder( + app: &AppHandleWrapper, + default_path: Option, + button_text: Option, +) -> Result, String> { + use std::ffi::{CStr, CString}; + + let target_path = match default_path { + Some(p) => p, + None => { + // Default to home directory + match app.get_home_dir() { + Ok(path) => path.to_string_lossy().into_owned(), + Err(_) => return Err("Failed to get home directory".to_string()), + } + } + }; + + let c_default_path = CString::new(target_path).map_err(|e| e.to_string())?; + + let display_button_text = button_text.unwrap_or_else(|| "Select folder".to_string()); + let c_button_text = CString::new(display_button_text).map_err(|e| e.to_string())?; + + let ptr = + unsafe { 
ffi::open_dt_folder_picker(c_default_path.as_ptr(), c_button_text.as_ptr()) }; + + if ptr.is_null() { + return Ok(None); + } + + let c_str = unsafe { CStr::from_ptr(ptr) }; + let json_result = c_str.to_string_lossy().into_owned(); + + unsafe { ffi::free_string_ptr(ptr) }; + + // Parse JSON result + let result: PickFolderResult = serde_json::from_str(&json_result) + .map_err(|e| format!("Failed to parse picker result: {}", e))?; + + Ok(Some(result)) +} + +#[command] +pub async fn resolve_bookmark(bookmark: String) -> Result { + use std::ffi::{CStr, CString}; + + if bookmark.starts_with("TESTBOOKMARK::") { + return Ok(ResolveResult::Resolved( + bookmark.split("::").last().unwrap().to_string(), + )); + } + + let c_bookmark = CString::new(bookmark).map_err(|e| e.to_string())?; + + let ptr = unsafe { ffi::start_accessing_security_scoped_resource(c_bookmark.as_ptr()) }; + + if ptr.is_null() { + return Ok(ResolveResult::CannotResolve); + } + + let c_str = unsafe { CStr::from_ptr(ptr) }; + let json_result = c_str.to_string_lossy().into_owned(); + + unsafe { ffi::free_string_ptr(ptr) }; + + // Parse JSON result from FFI + let ffi_result: FfiResolveResult = serde_json::from_str(&json_result) + .map_err(|e| format!("Failed to parse resolve result: {}", e))?; + + match ffi_result.status.as_str() { + "resolved" => { + log::debug!("Resolved bookmark: {}", ffi_result.path); + Ok(ResolveResult::Resolved(ffi_result.path)) + } + "stale_refreshed" => { + if let Some(new_bookmark) = ffi_result.new_bookmark { + log::debug!("Stale refreshed bookmark: {}", ffi_result.path); + Ok(ResolveResult::StaleRefreshed { + new_bookmark, + resolved_path: ffi_result.path, + }) + } else { + // Should not happen if status is stale_refreshed + log::debug!( + "Stale refreshed bookmark with no new bookmark: {}", + ffi_result.path + ); + Ok(ResolveResult::Resolved(ffi_result.path)) + } + } + _ => { + log::debug!("Cannot resolve bookmark: {}", ffi_result.path); + Ok(ResolveResult::CannotResolve) + } + } +} 
+ +#[command] +pub async fn stop_accessing_bookmark(bookmark: String) -> Result<(), String> { + use std::ffi::CString; + + let c_bookmark = CString::new(bookmark).map_err(|e| e.to_string())?; + + unsafe { + ffi::stop_accessing_security_scoped_resource(c_bookmark.as_ptr()); + }; + + Ok(()) +} + +pub fn cleanup_bookmarks() { + unsafe { + ffi::stop_all_security_scoped_resources(); + } +} diff --git a/src-tauri/src/dtp_service/data.rs b/src-tauri/src/dtp_service/data.rs new file mode 100644 index 0000000..92b1ab7 --- /dev/null +++ b/src-tauri/src/dtp_service/data.rs @@ -0,0 +1,340 @@ +use crate::{ + bookmarks::{self, PickFolderResult}, + dtp_service::{events::DTPEvent, jobs::SyncJob, AppHandleWrapper, DTPService}, + projects_db::{ + dtos::{ + image::ListImagesResult, + model::ModelExtra, + project::ProjectExtra, + tensor::{TensorHistoryClip, TensorHistoryExtra, TensorSize}, + watch_folder::WatchFolderDTO, + }, + filters::ListImagesFilter, + folder_cache, + }, +}; +use dtm_macros::dtp_commands; + +#[dtp_commands] +impl DTPService { + #[dtp_command] + pub async fn list_projects( + &self, + watchfolder_id: Option, + ) -> Result, String> { + let db = self.get_db().await?; + Ok(db.list_projects(watchfolder_id).await?) 
+ } + + #[dtp_command] + pub async fn update_project( + &self, + project_id: i64, + exclude: Option, + ) -> Result<(), String> { + let db = self.get_db().await?; + + if let Some(exclude_val) = exclude { + db.update_exclude(project_id, exclude_val).await?; + } + + let project = db.get_project(project_id).await?; + self.events + .emit(crate::dtp_service::events::DTPEvent::ProjectUpdated( + project, + )); + + Ok(()) + } + + #[dtp_command] + pub async fn list_images( + &self, + project_ids: Option>, + search: Option, + filters: Option>, + sort: Option, + direction: Option, + take: Option, + skip: Option, + count: Option, + show_video: Option, + show_image: Option, + ) -> Result { + let db = self.get_db().await?; + let opts = crate::projects_db::dtos::image::ListImagesOptions { + project_ids, + search, + filters, + sort, + direction, + take, + skip, + count, + show_video, + show_image, + }; + + Ok(db.list_images(opts).await?) + } + + #[dtp_command] + pub async fn find_image_from_preview_id( + &self, + project_id: i64, + preview_id: i64, + ) -> Result, String> { + let db = self.get_db().await?; + Ok(db.find_image_by_preview_id(project_id, preview_id).await?) + } + + #[dtp_command] + pub async fn get_clip(&self, image_id: i64) -> Result, String> { + let db = self.get_db().await?; + Ok(db.get_clip(image_id).await?) + } + + #[dtp_command] + pub async fn list_watch_folders(&self) -> Result, String> { + let db = self.get_db().await?; + Ok(db.list_watch_folders().await?) 
+ } + + #[dtp_command] + pub async fn pick_watch_folder( + &self, + dt_folder: Option, + test_override: Option, + ) -> Result<(), String> { + let result = get_folder(&self.app_handle, dt_folder, test_override).await?; + self.internal_add_watch_folder(result.path, result.bookmark) + .await + } + + pub async fn add_watchfolder( + self: &Self, + path: String, + bookmark: String, + ) -> Result<(), String> { + self.internal_add_watch_folder(path, bookmark).await + } + + async fn internal_add_watch_folder( + &self, + path: String, + bookmark: String, + ) -> Result<(), String> { + let db = self.get_db().await?; + let folder = db.add_watch_folder(&path, &bookmark, false).await?; + + // Resolve the bookmark and update if needed + let resolved = folder_cache::resolve_bookmark(folder.id, &bookmark).await; + if let Ok(resolved) = resolved { + match resolved { + crate::bookmarks::ResolveResult::Resolved(updated_path) => { + if updated_path != path { + db.update_bookmark_path(folder.id, &bookmark, &updated_path) + .await?; + } + } + crate::bookmarks::ResolveResult::StaleRefreshed { + new_bookmark, + resolved_path, + } => { + db.update_bookmark_path(folder.id, &new_bookmark, &resolved_path) + .await?; + } + crate::bookmarks::ResolveResult::CannotResolve => { + // TODO: Mark as missing in DB? 
+ } + } + } + + self.events + .emit(crate::dtp_service::events::DTPEvent::WatchFoldersChanged); + + let scheduler = self.scheduler.read().await; + let scheduler = scheduler.as_ref().unwrap(); + scheduler.add_job(SyncJob::new(false)); + Ok(()) + } + + #[dtp_command] + pub async fn remove_watch_folder(&self, id: i64) -> Result<(), String> { + let db = self.get_db().await?; + db.remove_watch_folders(vec![id]).await?; + + self.events + .emit(crate::dtp_service::events::DTPEvent::WatchFoldersChanged); + + // the projects will be removed automatically by the db + self.events.emit(DTPEvent::ProjectsChanged); + + Ok(()) + } + + #[dtp_command] + pub async fn update_watch_folder(&self, id: i64, recursive: bool) -> Result<(), String> { + let db = self.get_db().await?; + db.update_watch_folder(id, Some(recursive), None, None) + .await?; + + self.events + .emit(crate::dtp_service::events::DTPEvent::WatchFoldersChanged); + + Ok(()) + } + + #[dtp_command] + pub async fn list_models( + &self, + model_type: Option, + ) -> Result, String> { + let db = self.get_db().await?; + Ok(db.list_models(model_type).await?) + } + + #[dtp_command] + pub async fn get_history_full( + &self, + project_id: i64, + row_id: i64, + ) -> Result { + let project = self.get_project(project_id).await?; + Ok(project + .get_history_full(row_id) + .await + .map_err(|e| e.to_string())?) + } + + #[dtp_command] + pub async fn get_tensor_size( + &self, + project_id: i64, + tensor_id: String, + ) -> Result { + let project = self.get_project(project_id).await?; + Ok(project + .get_tensor_size(&tensor_id) + .await + .map_err(|e| e.to_string())?) 
+ } + + #[dtp_command] + pub async fn decode_tensor( + &self, + project_id: i64, + node_id: Option, + tensor_id: String, + as_png: bool, + ) -> Result { + let project = self.get_project(project_id).await?; + let tensor = project + .get_tensor_raw(&tensor_id) + .await + .map_err(|e| e.to_string())?; + + let metadata = match node_id { + Some(node) => Some( + project + .get_history_full(node) + .await + .map_err(|e| e.to_string())? + .history, + ), + None => None, + }; + + let buffer = crate::projects_db::decode_tensor(tensor, as_png, metadata, None) + .map_err(|e| e.to_string())?; + Ok(tauri::ipc::Response::new(buffer)) + } + + #[dtp_command] + pub async fn find_predecessor( + &self, + project_id: i64, + row_id: i64, + lineage: i64, + logical_time: i64, + ) -> Result, String> { + let project = self.get_project(project_id).await?; + Ok(project + .find_predecessor_candidates(row_id, lineage, logical_time) + .await + .map_err(|e| e.to_string())?) + } + + // Helper method to get a DTProject instance + async fn get_project( + &self, + project_id: i64, + ) -> Result, String> { + let db = self.get_db().await?; + let project_ref = crate::projects_db::ProjectRef::Id(project_id); + Ok(db.get_dt_project(project_ref).await?) 
+ } +} + +async fn get_dt_container(app_handle: &AppHandleWrapper) -> Result { + let path = app_handle + .get_home_dir() + .unwrap() + .join("Library/Containers/com.liuliu.draw-things/Data"); + Ok(path.to_string_lossy().to_string()) +} + +async fn get_dt_data_folder(app_handle: &AppHandleWrapper) -> Result { + let path = app_handle + .get_home_dir() + .unwrap() + .join("Library/Containers/com.liuliu.draw-things/Data/Documents"); + Ok(path.to_string_lossy().to_string()) +} + +async fn get_folder( + app_handle: &AppHandleWrapper, + dt_folder: Option, + test_override: Option, +) -> Result { + if let Some(test_override) = test_override { + return Ok(PickFolderResult { + path: test_override.clone(), + bookmark: test_override, + }); + } + + let result = match dt_folder { + Some(true) => { + let result = bookmarks::pick_folder( + app_handle, + Some(get_dt_container(app_handle).await?), + Some("Select Documents Folder".to_string()), + ) + .await?; + + match result { + Some(result) => { + if result.path != get_dt_data_folder(app_handle).await? 
{ + return Err("Must select Documents folder".to_string()); + } + result + } + None => { + return Err("Failed to select a folder".to_string()); + } + } + } + _ => { + let result = bookmarks::pick_folder(app_handle, None, None).await?; + + match result { + Some(result) => result, + None => { + return Err("Failed to select a folder".to_string()); + } + } + } + }; + Ok(result) +} diff --git a/src-tauri/src/dtp_service/dtp_service.rs b/src-tauri/src/dtp_service/dtp_service.rs new file mode 100644 index 0000000..b6ba1a3 --- /dev/null +++ b/src-tauri/src/dtp_service/dtp_service.rs @@ -0,0 +1,253 @@ +use std::{ + fs, + sync::{ + atomic::{AtomicBool, Ordering}, + Arc, + }, +}; + +use dtm_macros::{dtm_command, dtp_commands}; +use tauri::{ipc::Channel, State}; +use tokio::sync::{OnceCell, RwLock}; + +use crate::{ + dtp_service::{ + AppHandleWrapper, events::{self, DTPEvent}, jobs::{FetchModels, Job, JobContext, SyncJob}, scheduler::Scheduler, watch::WatchService + }, + projects_db::{self, DtmProtocol, ProjectsDb, get_last_row}, +}; + +#[derive(Clone)] +pub struct DTPService { + pub app_handle: AppHandleWrapper, + pub events: events::DTPEventsService, + pdb: Arc>>, + pub scheduler: Arc>>, + pub watch: Arc>>, + dtm_protocol: Arc>, + pub auto_watch: Arc, +} + +#[dtp_commands] +impl DTPService { + pub fn new(app_handle: AppHandleWrapper) -> Self { + let pdb = Arc::new(RwLock::new(None)); + let events = events::DTPEventsService::new(); + let scheduler = Arc::new(RwLock::new(None)); + let watch = Arc::new(RwLock::new(None)); + let dtm_protocol = Arc::new(OnceCell::new()); + + Self { + app_handle, + pdb: pdb, + events, + scheduler, + watch, + dtm_protocol, + auto_watch: Arc::new(AtomicBool::new(false)), + } + } + + pub async fn connect( + &self, + channel: Channel, + auto_watch: bool, + db_path: String, + ) -> Result<(), String> { + self.auto_watch.store(auto_watch, Ordering::Relaxed); + let pdb = ProjectsDb::new(&db_path).await.unwrap(); + { + let mut guard = 
self.pdb.write().await; + *guard = Some(pdb.clone()); + } + // #FOLDER + self.events.set_channel(channel); + + let ctx = Arc::new(JobContext { + app_handle: self.app_handle.clone(), + pdb: pdb.clone(), + events: self.events.clone(), + dtp: self.clone(), + }); + + let scheduler = Scheduler::new(ctx.clone()); + { + let mut guard = self.scheduler.write().await; + *guard = Some(scheduler.clone()); + } + + let watch = WatchService::new(scheduler.clone()); + watch.watch_volumes().await.unwrap(); + { + let mut guard = self.watch.write().await; + *guard = Some(watch); + } + + self.events.emit(DTPEvent::DtpServiceReady); + + self.add_job(FetchModels {}); + self.add_job(SyncJob::new(true)); + + Ok(()) + } + + pub async fn get_db(&self) -> Result { + self.pdb + .read() + .await + .clone() + .ok_or_else(|| "DB not ready".to_string()) + } + + pub async fn dtm_protocol(&self) -> &DtmProtocol { + self.dtm_protocol + .get_or_init(|| async { DtmProtocol::new(self.get_db().await.unwrap()) }) + .await + } + + #[dtp_command] + pub async fn sync(&self) -> Result<(), String> { + let scheduler = self.scheduler.read().await; + let scheduler = scheduler.as_ref().unwrap(); + scheduler.add_job(SyncJob::new(false)); + + Ok(()) + } + + // test to compare checking rowid vs file metadata + pub async fn check_all(&self) -> Result<(), String> { + let start = std::time::Instant::now(); + let projects = self.list_projects(None).await.unwrap(); + let mut last_rows: Vec<(i64, i64)> = Vec::new(); + for project in projects { + let last_row = get_last_row(&project.full_path).await.unwrap(); + last_rows.push((project.id, last_row.0)); + } + + println!("Checked all projects: {:?}", last_rows); + println!("Checked all projects: {}", start.elapsed().as_millis()); + Ok(()) + } + pub async fn check_all_2(&self) -> Result<(), String> { + let start = std::time::Instant::now(); + let projects = self.list_projects(None).await.unwrap(); + let mut data: Vec<(i64, i64)> = Vec::new(); + for project in projects { + let 
base = fs::metadata(&project.full_path).map_or(0, |m| m.len() as i64); + let wal = + fs::metadata(format!("{}-wal", &project.full_path)).map_or(0, |m| m.len() as i64); + data.push((base, wal)); + } + + println!("Checked all projects: {:?}", data); + println!("Checked all projects: {}", start.elapsed().as_millis()); + Ok(()) + } + + pub async fn resume_watch(&self, path: &str, recursive: bool) { + let watch = self.watch.read().await; + let watch = watch.as_ref().unwrap(); + watch.watch_folder(path, recursive).await.unwrap(); + } + + pub async fn stop_watch(&self, path: &str) { + let watch = self.watch.read().await; + let watch = watch.as_ref().unwrap(); + watch.stop_watch_folder(path).await.unwrap(); + } + + pub fn add_job(&self, job: T) { + let dtp = self.clone(); + tokio::spawn(async move { + let scheduler = dtp.scheduler.read().await; + let scheduler = scheduler.as_ref().unwrap(); + scheduler.add_job(job); + }); + } + + pub async fn stop(&self) { + { + let watch = self.watch.read().await; + let watch = watch.as_ref().unwrap(); + watch.stop_all().await.unwrap(); + } + { + let mut guard = self.pdb.write().await; + *guard = None; + } + + { + let scheduler = self.scheduler.read().await.clone(); + scheduler.unwrap().stop().await; + } + { + let mut guard = self.scheduler.write().await; + *guard = None; + } + { + let mut guard = self.watch.write().await; + *guard = None; + } + } + + #[dtp_command] + pub async fn lock_folder(&self, watchfolder_id: i64) -> Result<(), String> { + let folder = self + .get_db() + .await + .unwrap() + .update_watch_folder(watchfolder_id, None, None, Some(true)) + .await?; + self.stop_watch(&folder.path).await; + projects_db::close_folder(&folder.path).await; + self.events.emit(DTPEvent::WatchFoldersChanged); + Ok(()) + } +} + +#[dtm_command] +pub async fn dtp_test(state: State<'_, AppHandleWrapper>) -> Result<(), String> { + println!( + "dtp test bla bla {}", + state.get_home_dir().unwrap().to_string_lossy() + ); + Ok(()) +} +// let scheduler 
= state.scheduler.read().await; +// let scheduler = scheduler.as_ref().unwrap(); +// scheduler.add_job(SyncJob); +// Ok("ok".to_string()) + +#[dtm_command] +pub async fn dtp_connect( + app_handle: State<'_, AppHandleWrapper>, + state: State<'_, DTPService>, + channel: Channel, + auto_watch: bool, +) -> Result<(), String> { + let db_path = get_db_path(&app_handle); + check_old_path(&app_handle); + let _ = state.connect(channel, auto_watch, db_path).await; + Ok(()) +} + +fn get_db_path(app_handle: &AppHandleWrapper) -> String { + let app_data_dir = app_handle.get_app_data_dir().unwrap(); + if !app_data_dir.exists() { + std::fs::create_dir_all(&app_data_dir).expect("Failed to create app data dir"); + } + let project_db_path = app_data_dir.join("projects4.db"); + format!("sqlite://{}?mode=rwc", project_db_path.to_str().unwrap()) +} + +fn check_old_path(app_handle: &AppHandleWrapper) { + let app_data_dir = app_handle.get_app_data_dir().unwrap(); + let old_path = app_data_dir.join("projects2.db"); + if old_path.exists() { + fs::remove_file(old_path).unwrap_or_default(); + } + let old_path = app_data_dir.join("projects3.db"); + if old_path.exists() { + fs::remove_file(old_path).unwrap_or_default(); + } +} diff --git a/src-tauri/src/dtp_service/events.rs b/src-tauri/src/dtp_service/events.rs new file mode 100644 index 0000000..2a76bfa --- /dev/null +++ b/src-tauri/src/dtp_service/events.rs @@ -0,0 +1,73 @@ +use std::sync::{Arc, Mutex}; + +use tauri::ipc::Channel; + +use crate::projects_db::dtos::project::ProjectExtra; + +#[derive(Clone)] +pub struct DTPEventsService { + sender: Arc>>>, +} + +impl DTPEventsService { + pub fn new() -> Self { + Self { + sender: Arc::new(Mutex::new(None)), + } + } + + pub fn set_channel(&self, sender: Channel) { + let mut guard = self.sender.lock().unwrap(); + *guard = Some(sender); + } + + pub fn emit(&self, event: DTPEvent) { + let sender = self.sender.clone(); + tauri::async_runtime::spawn(async move { + if let Some(tx) = 
&*sender.lock().unwrap() { + let _ = tx.send(event); + } + }); + } +} + +#[derive(serde::Serialize, Debug)] +#[serde(tag = "type", content = "data", rename_all = "snake_case")] +pub enum DTPEvent { + WatchFoldersChanged, + + ProjectAdded(ProjectExtra), + ProjectRemoved(i64), + ProjectUpdated(ProjectExtra), + // when many projects are changed, such as on delete cascade + ProjectsChanged, + + ModelsChanged, + + ImportStarted, + ImportProgress(ScanProgress), + ImportCompleted, + + SyncStarted, + SyncComplete, + + FolderSyncStarted(i64), + FolderSyncComplete(i64), + + DtpServiceReady, + + /// By default, tuple is (job id, msg) + TestEventStart(Option, Option), + /// By default, tuple is (job id, msg) + TestEventComplete(Option, Option), + /// By default, tuple is (job id, msg, error) + TestEventFailed(Option, Option, Option), +} + +#[derive(serde::Serialize, Debug)] +pub struct ScanProgress { + pub projects_found: u64, + pub projects_scanned: u64, + pub images_found: u64, + pub images_scanned: u64, +} diff --git a/src-tauri/src/dtp_service/helpers.rs b/src-tauri/src/dtp_service/helpers.rs new file mode 100644 index 0000000..49c6bf4 --- /dev/null +++ b/src-tauri/src/dtp_service/helpers.rs @@ -0,0 +1,192 @@ +use std::collections::HashMap; +use std::fs; +use std::path::PathBuf; +use std::time::{SystemTime, UNIX_EPOCH}; +use tauri::{AppHandle, Manager}; +use walkdir::WalkDir; + +use crate::projects_db::dtos::model::ModelType; +use crate::projects_db::dtos::project::ProjectExtra; +use crate::projects_db::folder_cache; + +#[derive(Debug, Clone)] +pub struct ProjectFile { + pub path: String, + pub filesize: u64, + pub modified: i64, + pub _watchfolder_id: i64, + pub has_base: bool, +} + +pub struct GetFolderFilesResult { + pub projects: HashMap, + pub model_info: Vec<(String, ModelType)>, +} + +pub async fn get_folder_files(watchfolder_path: &str, watchfolder_id: i64) -> GetFolderFilesResult { + let mut projects: HashMap = HashMap::new(); + let mut model_info: Vec<(String, 
ModelType)> = Vec::new(); + + // Walk the folder recursively + for entry in WalkDir::new(watchfolder_path) + .follow_links(false) + .into_iter() + .filter_map(Result::ok) + { + let path = entry.path(); + + if path.is_dir() { + continue; + } + + // Safe extension check + let ext = match path.extension().and_then(|s| s.to_str()) { + Some(e) => e, + None => continue, + }; + + match ext { + "sqlite3" | "sqlite3-wal" => { + let project_path = + get_project_path(path.to_string_lossy().to_string(), watchfolder_path); + let project_path = PathBuf::from(project_path).with_extension("sqlite3"); // normalize + + let key = path + .parent() + .map(|p| { + p.join( + path.with_extension("sqlite3") + .file_name() + .unwrap_or_default(), + ) + }) + .unwrap_or_else(|| path.to_path_buf()) + .to_string_lossy() + .to_string(); + + if let Ok(metadata) = fs::metadata(path) { + let project = projects.entry(key.clone()).or_insert_with(|| ProjectFile { + path: project_path.to_string_lossy().to_string(), + has_base: false, + filesize: 0, + modified: 0, + _watchfolder_id: watchfolder_id, + }); + + if ext == "sqlite3" { + project.has_base = true; + } + + project.filesize += metadata.len(); + if let Ok(modified) = metadata.modified() { + if let Some(epoch) = system_time_to_epoch_secs(modified) { + project.modified = project.modified.max(epoch); + } + } + } + } + "json" => { + if let Some(model_type) = path + .file_name() + .and_then(|s| s.to_str()) + .and_then(get_model_file_type) + { + model_info.push((path.to_string_lossy().to_string(), model_type)); + } + } + _ => {} + } + } + + projects.retain(|_, v| v.has_base); + + GetFolderFilesResult { + projects, + model_info, + } +} + +pub fn get_project_path(full_path: String, watchfolder_path: &str) -> String { + let path = PathBuf::from(full_path); + path.strip_prefix(watchfolder_path) + .expect("path should be in watchfolder") + .with_extension("sqlite3") + .to_string_lossy() + .to_string() +} + +pub fn get_full_project_path(project: 
&ProjectExtra) -> String { + let folder = folder_cache::get_folder(project.watchfolder_id).unwrap(); + let path = PathBuf::from(folder) + .join(project.path.to_string()) + .with_extension("sqlite3"); + path.to_string_lossy().to_string() +} + +pub fn get_model_file_type(filename: &str) -> Option { + match filename { + "custom.json" | "uncurated_models.json" | "models.json" => Some(ModelType::Model), + "custom_controlnet.json" | "controlnets.json" => Some(ModelType::Cnet), + "custom_lora.json" | "loras.json" => Some(ModelType::Lora), + _ => None, + } +} + +pub fn system_time_to_epoch_secs(time: SystemTime) -> Option { + time.duration_since(UNIX_EPOCH) + .ok() + .map(|d| d.as_secs() as i64) +} + +#[derive(Clone)] +pub struct AppHandleWrapper { + pub app_handle: Option, +} + +impl AppHandleWrapper { + pub fn new(app_handle: Option) -> Self { + Self { app_handle } + } + + fn get_test_path(&self, path: &str) -> PathBuf { + let base = std::env::current_dir().unwrap().join("test_data/temp"); + let result = match path { + "" => base, + _ => base.join(path), + }; + fs::create_dir_all(&result).unwrap(); + result + } + + pub fn get_home_dir(&self) -> tauri::Result { + if let Some(app_handle) = &self.app_handle { + app_handle.path().home_dir() + } else { + Ok(self.get_test_path("")) + } + } + + pub fn get_app_data_dir(&self) -> tauri::Result { + if let Some(app_handle) = &self.app_handle { + app_handle.path().app_data_dir() + } else { + Ok(self.get_test_path("app_data_dir")) + } + } +} + +impl From for AppHandleWrapper { + fn from(value: AppHandle) -> Self { + Self { + app_handle: Some(value.clone()), + } + } +} + +impl From<&AppHandle> for AppHandleWrapper { + fn from(value: &AppHandle) -> Self { + Self { + app_handle: Some(value.clone()), + } + } +} diff --git a/src-tauri/src/dtp_service/jobs/check_file.rs b/src-tauri/src/dtp_service/jobs/check_file.rs new file mode 100644 index 0000000..f2d69b5 --- /dev/null +++ b/src-tauri/src/dtp_service/jobs/check_file.rs @@ -0,0 +1,96 @@ 
+use std::{fs, sync::Arc}; + +use crate::dtp_service::{ + helpers::system_time_to_epoch_secs, + jobs::{AddProjectJob, Job, JobContext, JobResult, RemoveProjectJob, UpdateProjectJob}, +}; + +pub struct CheckFileJob { + pub project_path: String, +} + +impl CheckFileJob { + pub fn new(project_path: String) -> Self { + Self { project_path } + } +} + +#[async_trait::async_trait] +impl Job for CheckFileJob { + fn get_label(&self) -> String { + "Check file".to_string() + } + + async fn execute(self: &Self, ctx: &JobContext) -> Result { + let watchfolder = ctx + .pdb + .get_watch_folder_for_path(&self.project_path) + .await + .unwrap(); + if watchfolder.is_none() { + return Err("Watch folder not found".to_string()); + } + let watchfolder = watchfolder.unwrap(); + let project_path = self + .project_path + .strip_prefix(format!("{}/", watchfolder.path).as_str()) + .unwrap(); + + let entity = ctx + .pdb + .get_project_by_path(watchfolder.id, &project_path) + .await + .map_err(|e| e.to_string())?; + + if !fs::exists(&self.project_path).unwrap_or(false) { + println!("File does not exist: {}", self.project_path); + match entity { + Some(entity) => { + println!("Removing project: {}", entity.id); + let job = RemoveProjectJob { + project_id: entity.id, + }; + return Ok(JobResult::Subtasks(vec![Arc::new(job)])); + } + None => { + println!("File does not exist and no project found"); + return Ok(JobResult::None); + } + } + } + + let metadata = fs::metadata(&self.project_path).unwrap(); + let filesize = metadata.len() as i64; + let modified = system_time_to_epoch_secs(metadata.modified().unwrap()); + + match entity { + // if an entity was found, compare size and modified + Some(entity) => { + println!("Project found for path: {}", self.project_path); + if entity.filesize.unwrap_or(0) != filesize || entity.modified != modified { + let job = UpdateProjectJob { + project_id: entity.id, + filesize: filesize, + modified: modified.unwrap_or(0), + is_import: false, + }; + return 
Ok(JobResult::Subtasks(vec![Arc::new(job)])); + } + } + None => { + println!("No project found for path: {}", self.project_path); + let job = AddProjectJob { + path: project_path.to_string(), + watchfolder_id: watchfolder.id, + filesize, + modified: modified.unwrap_or(0), + is_import: false, + }; + println!("Adding project: {}", self.project_path); + return Ok(JobResult::Subtasks(vec![Arc::new(job)])); + } + } + + Ok(JobResult::None) + } +} diff --git a/src-tauri/src/dtp_service/jobs/check_folder.rs b/src-tauri/src/dtp_service/jobs/check_folder.rs new file mode 100644 index 0000000..0a15306 --- /dev/null +++ b/src-tauri/src/dtp_service/jobs/check_folder.rs @@ -0,0 +1,164 @@ +use std::{fs, sync::Arc}; + +use crate::{ + dtp_service::{ + events::DTPEvent, + jobs::{sync_folder::SyncFolderJob, CheckFileJob, Job, JobContext, JobResult}, + }, + projects_db::{dtos::watch_folder::WatchFolderDTO, folder_cache, ProjectsDb}, +}; + +#[derive(Debug)] +pub struct CheckFolderJob { + watchfolder: Option, + path: String, + /// reset is_locked for watchfolder + reset_lock: bool, + /// indicates that a SyncFolderJob should follow. 
overrides check_files if both are present + sync: bool, + /// if triggered by the watcher, it should follow with CheckFileJobs + check_files: Option>, +} + +impl CheckFolderJob { + pub fn new( + watchfolder: WatchFolderDTO, + reset_lock: bool, + sync: bool, + check_files: Option>, + ) -> Self { + Self { + path: watchfolder.path.clone(), + watchfolder: Some(watchfolder), + reset_lock, + sync, + check_files, + } + } + + pub fn new_from_path( + path: String, + reset_lock: bool, + sync: bool, + check_files: Option>, + ) -> Self { + Self { + watchfolder: None, + path, + reset_lock, + sync, + check_files, + } + } +} + +#[async_trait::async_trait] +impl Job for CheckFolderJob { + fn get_label(&self) -> String { + format!("CheckFolderJob for {}", self.path) + } + + async fn execute(self: &Self, ctx: &JobContext) -> Result { + ctx.dtp.stop_watch(&self.path).await; + + let mut locked_update: Option = None; + let mut missing_update: Option = None; + + let watchfolder = match &self.watchfolder { + Some(wf) => wf, + None => &ctx.pdb.get_watch_folder_by_path(&self.path).await?.unwrap(), + }; + + let resolved = resolve_folder(&watchfolder, &ctx.pdb) + .await + .unwrap_or(false); + + // check existence of folder + let is_missing = !resolved || !fs::exists(&watchfolder.path).unwrap_or(false); + + // if DTO.missing is different, update folder and all projects + if watchfolder.is_missing != is_missing { + missing_update = Some(is_missing); + } + + if watchfolder.is_locked && self.reset_lock { + locked_update = Some(false); + } + println!("locked_update: {:?}", locked_update); + println!("missing_update: {:?}", missing_update); + if locked_update.is_some() || missing_update.is_some() { + ctx.pdb + .update_watch_folder(watchfolder.id, None, missing_update, locked_update) + .await?; + ctx.events.emit(DTPEvent::ProjectsChanged); + } + + if is_missing { + return Ok(JobResult::None); + } + + if self.sync { + return Ok(JobResult::Subtasks(vec![Arc::new(SyncFolderJob::new( + &watchfolder, + 
))])); + } + + if let Some(files) = &self.check_files { + let jobs: Vec> = files + .iter() + .map(|f| Arc::new(CheckFileJob::new(f.to_string())) as Arc) + .collect(); + return Ok(JobResult::Subtasks(jobs)); + } + + Ok(JobResult::None) + } + + async fn on_complete(&self, ctx: &JobContext) { + ctx.dtp.resume_watch(&self.path, true).await; + } + + async fn on_failed(&self, ctx: &JobContext, _error: String) { + ctx.dtp.resume_watch(&self.path, true).await; + } +} + +impl Into> for CheckFolderJob { + fn into(self) -> Arc { + Arc::new(self) + } +} + +async fn resolve_folder(folder: &WatchFolderDTO, db: &ProjectsDb) -> Result { + let cached = folder_cache::get_folder(folder.id); + if let Some(cached) = cached { + if cached == folder.path { + return Ok(true); + } + } + let resolved = folder_cache::resolve_bookmark(folder.id, &folder.bookmark).await; + if let Ok(resolved) = resolved { + match resolved { + crate::bookmarks::ResolveResult::Resolved(updated_path) => { + if updated_path != folder.path { + db.update_bookmark_path(folder.id, &folder.bookmark, &updated_path) + .await + .unwrap(); + } + } + crate::bookmarks::ResolveResult::StaleRefreshed { + new_bookmark, + resolved_path, + } => { + db.update_bookmark_path(folder.id, &new_bookmark, &resolved_path) + .await + .unwrap(); + } + crate::bookmarks::ResolveResult::CannotResolve => { + // TODO: Mark as missing in DB? 
+ return Ok(false); + } + } + } + Ok(true) +} diff --git a/src-tauri/src/dtp_service/jobs/job.rs b/src-tauri/src/dtp_service/jobs/job.rs new file mode 100644 index 0000000..8a3197d --- /dev/null +++ b/src-tauri/src/dtp_service/jobs/job.rs @@ -0,0 +1,40 @@ +use std::sync::Arc; + +use crate::dtp_service::AppHandleWrapper; +use crate::{ + dtp_service::{ + events::{DTPEvent, DTPEventsService}, + DTPService, + }, + projects_db::ProjectsDb, +}; + +#[async_trait::async_trait] +pub trait Job +where + Self: Send + Sync, +{ + fn get_label(&self) -> String; + async fn execute(self: &Self, ctx: &JobContext) -> Result; + fn start_event(self: &Self) -> Option { + None + } + async fn on_complete(&self, _ctx: &JobContext) {} + async fn on_failed(&self, _ctx: &JobContext, _error: String) {} +} + +#[derive(Default)] +pub enum JobResult { + #[default] + None, + Event(DTPEvent), + Subtasks(Vec>), +} + +#[derive(Clone)] +pub struct JobContext { + pub app_handle: AppHandleWrapper, + pub pdb: ProjectsDb, + pub events: DTPEventsService, + pub dtp: DTPService, +} diff --git a/src-tauri/src/dtp_service/jobs/mod.rs b/src-tauri/src/dtp_service/jobs/mod.rs new file mode 100644 index 0000000..cb7110c --- /dev/null +++ b/src-tauri/src/dtp_service/jobs/mod.rs @@ -0,0 +1,14 @@ +mod job; +mod project_jobs; +mod sync; +mod sync_folder; +mod check_file; +mod check_folder; +mod sync_models; + +pub use job::{Job, JobContext, JobResult}; +pub use project_jobs::{AddProjectJob, RemoveProjectJob, UpdateProjectJob}; +pub use sync::SyncJob; +pub use check_file::CheckFileJob; +pub use check_folder::CheckFolderJob; +pub use sync_models::{FetchModels, SyncModelsJob}; diff --git a/src-tauri/src/dtp_service/jobs/project_jobs.rs b/src-tauri/src/dtp_service/jobs/project_jobs.rs new file mode 100644 index 0000000..1b0f2ad --- /dev/null +++ b/src-tauri/src/dtp_service/jobs/project_jobs.rs @@ -0,0 +1,158 @@ +use std::sync::Arc; + +use crate::dtp_service::{ + events::{DTPEvent, ScanProgress}, + 
jobs::{sync_folder::ProjectSync, Job, JobContext, JobResult}, +}; + +pub struct AddProjectJob { + pub path: String, + pub watchfolder_id: i64, + pub filesize: i64, + pub modified: i64, + pub is_import: bool, +} + +impl AddProjectJob { + pub fn new(project_sync: &ProjectSync, is_import: bool) -> Self { + let file = project_sync.file.as_ref().unwrap(); + Self { + path: file.path.to_string(), + watchfolder_id: project_sync.watchfolder_id, + filesize: file.filesize as i64, + modified: file.modified.into(), + is_import, + } + } +} + +#[async_trait::async_trait] +impl Job for AddProjectJob { + fn get_label(&self) -> String { + format!("AddProjectJob for {}", self.path) + } + + async fn execute(self: &Self, ctx: &JobContext) -> Result { + let result = ctx.pdb.add_project(self.watchfolder_id, &self.path).await; + + if self.is_import { + ctx.events.emit(DTPEvent::ImportProgress(ScanProgress { + projects_found: 1, + projects_scanned: 0, + images_found: 0, + images_scanned: 0, + })); + } + + match result { + Ok(added_project) => { + println!("Project added successfully"); + let id = added_project.id; + ctx.events.emit(DTPEvent::ProjectAdded(added_project)); + Ok(JobResult::Subtasks(vec![Arc::new(UpdateProjectJob { + project_id: id, + filesize: self.filesize, + modified: self.modified, + is_import: self.is_import, + })])) + } + Err(e) => Err(e.to_string()), + } + } +} + +pub struct RemoveProjectJob { + pub project_id: i64, +} + +impl RemoveProjectJob { + pub fn new(project_sync: &ProjectSync) -> Result { + if let Some(entity) = &project_sync.entity { + Ok(Self { + project_id: entity.id, + }) + } else { + Err("Project entity not found".to_string()) + } + } +} + +#[async_trait::async_trait] +impl Job for RemoveProjectJob { + fn get_label(&self) -> String { + format!("RemoveProjectJob for {}", self.project_id) + } + + async fn execute(self: &Self, ctx: &JobContext) -> Result { + let result = ctx + .pdb + .remove_project(self.project_id) + .await + .map_err(|e| e.to_string())?; + 
Ok(JobResult::Event(DTPEvent::ProjectRemoved(result.unwrap()))) + } +} + +pub struct UpdateProjectJob { + pub project_id: i64, + pub filesize: i64, + pub modified: i64, + pub is_import: bool, +} + +impl UpdateProjectJob { + pub fn new(project_sync: &ProjectSync, is_import: bool) -> Result { + if let Some(entity) = &project_sync.entity { + Ok(Self { + project_id: entity.id, + filesize: project_sync.file.as_ref().unwrap().filesize as i64, + modified: project_sync.file.as_ref().unwrap().modified, + is_import, + }) + } else { + Err("Project entity not found".to_string()) + } + } +} + +#[async_trait::async_trait] +impl Job for UpdateProjectJob { + fn get_label(&self) -> String { + format!("UpdateProjectJob for {}", self.project_id) + } + + async fn execute(self: &Self, ctx: &JobContext) -> Result { + let result: Result<(i64, u64), String> = ctx + .pdb + .scan_project(self.project_id, false) + .await + .map_err(|e| e.to_string()); + + match result { + Ok((_id, total)) => { + let project = ctx.pdb.get_project(_id).await.map_err(|e| e.to_string())?; + + let _ = ctx + .pdb + .update_project(project.id, Some(self.filesize), Some(self.modified)) + .await + .map_err(|e| e.to_string())?; + + if self.is_import { + ctx.events.emit(DTPEvent::ImportProgress(ScanProgress { + projects_found: 0, + projects_scanned: 1, + images_found: 0, + images_scanned: total, + })); + } + + Ok(JobResult::Event(DTPEvent::ProjectUpdated(project))) + } + Err(err) => { + log::error!("Error scanning project {}: {}", self.project_id, err); + Err(err.to_string()) + } + } + } +} diff --git a/src-tauri/src/dtp_service/jobs/sync.rs b/src-tauri/src/dtp_service/jobs/sync.rs new file mode 100644 index 0000000..7a1af81 --- /dev/null +++ b/src-tauri/src/dtp_service/jobs/sync.rs @@ -0,0 +1,50 @@ +use std::sync::Arc; + +use crate::dtp_service::events::DTPEvent; +use crate::dtp_service::jobs::CheckFolderJob; + +use super::job::{Job, JobContext, JobResult}; + +pub struct SyncJob { + reset_locks: bool, +} + +impl 
SyncJob { + pub fn new(reset_locks: bool) -> Self { + Self { reset_locks } + } +} + +#[async_trait::async_trait] +impl Job for SyncJob { + fn get_label(&self) -> String { + format!("SyncJob") + } + fn start_event(self: &Self) -> Option { + Some(DTPEvent::SyncStarted) + } + async fn on_complete(self: &Self, ctx: &JobContext) { + ctx.events.emit(DTPEvent::SyncComplete); + } + async fn execute(self: &Self, ctx: &JobContext) -> Result { + let folders = ctx + .pdb + .list_watch_folders() + .await + .map_err(|e| e.to_string())?; + + let subtasks = folders + .iter() + .map(|wf| { + Arc::new(CheckFolderJob::new( + wf.clone(), + self.reset_locks, + true, + None, + )) as Arc + }) + .collect(); + + Ok(JobResult::Subtasks(subtasks)) + } +} diff --git a/src-tauri/src/dtp_service/jobs/sync_folder.rs b/src-tauri/src/dtp_service/jobs/sync_folder.rs new file mode 100644 index 0000000..b1a56e1 --- /dev/null +++ b/src-tauri/src/dtp_service/jobs/sync_folder.rs @@ -0,0 +1,186 @@ +use std::sync::{ + atomic::{AtomicBool, Ordering}, + Arc, +}; + +use crate::{ + dtp_service::{ + events::DTPEvent, + helpers::{get_folder_files, get_full_project_path, ProjectFile}, + jobs::{ + AddProjectJob, Job, JobContext, JobResult, RemoveProjectJob, SyncModelsJob, + UpdateProjectJob, + }, + }, + projects_db::dtos::{project::ProjectExtra, watch_folder::WatchFolderDTO}, +}; + +pub struct SyncFolderJob { + pub watchfolder_id: i64, + pub watchfolder_path: String, + pub is_import: Arc, +} + +impl SyncFolderJob { + pub fn new(watchfolder: &WatchFolderDTO) -> Self { + Self { + watchfolder_id: watchfolder.id, + watchfolder_path: watchfolder.path.clone(), + is_import: Arc::new(AtomicBool::new(false)), + } + } +} + +#[async_trait::async_trait] +impl Job for SyncFolderJob { + fn get_label(&self) -> String { + format!( + "SyncFolderJob for {} ({})", + self.watchfolder_path, self.watchfolder_id + ) + } + fn start_event(self: &Self) -> Option { + Some(DTPEvent::FolderSyncStarted(self.watchfolder_id)) + } + async fn 
execute(self: &Self, ctx: &JobContext) -> Result { + let files = get_folder_files(&self.watchfolder_path, self.watchfolder_id).await; + let mut project_files = files.projects; + let mut sync_projects: Vec = Vec::new(); + let entities = ctx + .pdb + .list_projects(Some(self.watchfolder_id)) + .await + .unwrap(); + + // detect if this is a new folder import + let is_import = entities.is_empty() && !project_files.is_empty(); + if is_import { + ctx.events.emit(DTPEvent::ImportStarted); + self.is_import.store(true, Ordering::Relaxed); + } + + for entity in entities { + let full_path = get_full_project_path(&entity); + let file = project_files.remove(&full_path); + + let sync = ProjectSync::new( + Some(entity), + file, + self.watchfolder_id, + self.watchfolder_path.clone(), + ); + sync_projects.push(sync); + } + + for (_key, file) in project_files.drain() { + let sync = ProjectSync::new( + None, + Some(file), + self.watchfolder_id, + self.watchfolder_path.clone(), + ); + sync_projects.push(sync); + } + + let mut subtasks: Vec> = Vec::new(); + + for sync in sync_projects.iter_mut() { + sync.assign_sync_action(); + + match sync.action { + SyncAction::Add => { + subtasks.push(Arc::new(AddProjectJob::new( + &sync, + self.is_import.load(Ordering::Relaxed), + ))); + } + SyncAction::Remove => { + match RemoveProjectJob::new(&sync) { + Ok(job) => subtasks.push(Arc::new(job)), + Err(e) => log::error!("Failed to create RemoveProjectJob: {}", e), + }; + } + SyncAction::Update => { + subtasks.push(Arc::new( + UpdateProjectJob::new(&sync, self.is_import.load(Ordering::Relaxed)) + .unwrap(), + )); + } + _ => {} + }; + } + + if !files.model_info.is_empty() { + subtasks.push(Arc::new(SyncModelsJob::new( + files.model_info.into_iter().map(|m| m.into()).collect(), + ))); + } + + Ok(JobResult::Subtasks(subtasks)) + } + + async fn on_complete(self: &Self, ctx: &JobContext) { + if self.is_import.load(Ordering::Relaxed) { + ctx.events.emit(DTPEvent::ImportCompleted); + } + ctx.events + 
.emit(DTPEvent::FolderSyncComplete(self.watchfolder_id)); + } +} + +#[derive(Default, Debug, PartialEq, Eq, Clone)] +pub enum SyncAction { + #[default] + None = 0, + Add, + Remove, + Update, +} + +#[derive(Debug, Clone)] +pub struct ProjectSync { + pub entity: Option, + pub file: Option, + pub action: SyncAction, + pub watchfolder_id: i64, + pub watchfolder_path: String, +} + +impl ProjectSync { + pub fn new( + entity: Option, + file: Option, + watchfolder_id: i64, + watchfolder_path: String, + ) -> Self { + let sync = Self { + entity, + file, + action: SyncAction::None, + watchfolder_id, + watchfolder_path, + }; + sync + } + + fn assign_sync_action(&mut self) { + if self.entity.is_none() && self.file.is_some() { + self.action = SyncAction::Add; + return; + } + if self.entity.is_some() && self.file.is_none() { + self.action = SyncAction::Remove; + return; + } + if self.entity.is_none() && self.file.is_none() { + return; + } + if let (Some(entity), Some(file)) = (self.entity.as_ref(), self.file.as_ref()) { + if file.filesize != entity.filesize.unwrap_or(0) as u64 + || file.modified != entity.modified.unwrap_or(0) as i64 + { + self.action = SyncAction::Update; + } + } + } +} diff --git a/src-tauri/src/dtp_service/jobs/sync_models.rs b/src-tauri/src/dtp_service/jobs/sync_models.rs new file mode 100644 index 0000000..7fba43b --- /dev/null +++ b/src-tauri/src/dtp_service/jobs/sync_models.rs @@ -0,0 +1,94 @@ +use crate::dtp_service::{events::DTPEvent, jobs::{Job, JobContext, JobResult}}; +use entity::enums::ModelType; +use serde_json::Value; +use std::sync::Arc; +use tauri_plugin_http::reqwest; + +pub struct ModelInfoFile { + pub path: String, + pub model_type: ModelType, +} + +impl From<(String, ModelType)> for ModelInfoFile { + fn from(value: (String, ModelType)) -> Self { + Self { + path: value.0, + model_type: value.1, + } + } +} + +pub struct SyncModelsJob { + pub model_info: Vec, +} + +impl SyncModelsJob { + pub fn new(model_info: Vec) -> Self { + Self { model_info 
} + } +} + +#[async_trait::async_trait] +impl Job for SyncModelsJob { + fn get_label(&self) -> String { + format!("SyncModelsJob") + } + + async fn execute(self: &Self, ctx: &JobContext) -> Result { + let pdb = ctx.dtp.get_db().await.unwrap(); + for model_info in self.model_info.iter() { + pdb.scan_model_info(&model_info.path, model_info.model_type).await.unwrap(); + }; + Ok(JobResult::Event(DTPEvent::ModelsChanged)) + } +} + +pub struct FetchModels; + +#[async_trait::async_trait] +impl Job for FetchModels { + fn get_label(&self) -> String { + format!("FetchModels") + } + + async fn execute(self: &Self, ctx: &JobContext) -> Result { + let app_data_dir = ctx.app_handle.get_app_data_dir().map_err(|e| e.to_string())?; + std::fs::create_dir_all(&app_data_dir).map_err(|e| e.to_string())?; + + let url = "https://kcjerrell.github.io/dt-models/combined_models.json"; + let response = reqwest::get(url).await.map_err(|e| e.to_string())?; + let json: Value = response.json().await.map_err(|e| e.to_string())?; + + let mut model_files = Vec::new(); + + if let Some(obj) = json.as_object() { + for (key, value) in obj { + if key == "lastUpdate" { + continue; + } + + if let Some(arr) = value.as_array() { + let file_path = app_data_dir.join(format!("{}.json", key)); + let file_content = serde_json::to_string_pretty(arr).map_err(|e| e.to_string())?; + std::fs::write(&file_path, file_content).map_err(|e| e.to_string())?; + + let model_type = match key.as_str() { + "officialModels" | "communityModels" | "uncuratedModels" => ModelType::Model, + "officialLoras" | "communityLoras" => ModelType::Lora, + "officialCnets" | "communityCnets" => ModelType::Cnet, + _ => ModelType::None, + }; + + model_files.push(ModelInfoFile { + path: file_path.to_string_lossy().to_string(), + model_type, + }); + } + } + } + + Ok(JobResult::Subtasks(vec![Arc::new(SyncModelsJob::new( + model_files, + ))])) + } +} diff --git a/src-tauri/src/dtp_service/mod.rs b/src-tauri/src/dtp_service/mod.rs new file mode 100644 
index 0000000..08fe5f3 --- /dev/null +++ b/src-tauri/src/dtp_service/mod.rs @@ -0,0 +1,19 @@ +pub mod events; +mod helpers; +mod scheduler; +mod watch; + +pub mod jobs; + +pub mod data; +pub use data::{ + dtp_decode_tensor, dtp_find_image_from_preview_id, dtp_find_predecessor, dtp_get_clip, + dtp_get_history_full, dtp_get_tensor_size, dtp_list_images, dtp_list_models, dtp_list_projects, + dtp_list_watch_folders, dtp_pick_watch_folder, dtp_remove_watch_folder, dtp_update_project, + dtp_update_watch_folder +}; + +pub mod dtp_service; +pub use dtp_service::{dtp_connect, DTPService, dtp_lock_folder}; + +pub use helpers::{AppHandleWrapper, GetFolderFilesResult, ProjectFile}; diff --git a/src-tauri/src/dtp_service/scheduler.rs b/src-tauri/src/dtp_service/scheduler.rs new file mode 100644 index 0000000..17f1b3d --- /dev/null +++ b/src-tauri/src/dtp_service/scheduler.rs @@ -0,0 +1,295 @@ +use std::{ + collections::HashMap, + sync::{ + atomic::{AtomicU64, Ordering}, + Arc, + }, +}; + +use tokio::sync::{mpsc, Mutex, Semaphore}; + +use crate::dtp_service::{ + events::DTPEvent, + jobs::{Job, JobContext, JobResult}, +}; + +type JobId = u64; + +#[derive(Clone, Debug, Default)] +pub enum JobStatus { + #[default] + Pending, + Active, + // Canceled, + WaitingForSubtasks(isize), + Complete, + Failed(String), +} + +#[derive(Clone, Debug, Default)] +struct JobState { + id: JobId, + parent_id: Option, + status: JobStatus, + jobs_failed: isize, + jobs_completed: isize, +} + +#[derive(Clone)] +struct JobEntry { + job: Arc, + state: JobState, +} + +#[derive(Clone)] +pub struct Scheduler { + tx: Arc>, + jobs: Arc>>, + next_id: Arc, + ctx: Arc, + worker_handle: Arc>>>, +} + +impl Scheduler { + pub fn new(ctx: Arc) -> Self { + let (tx, mut rx) = mpsc::channel::(10000); + + let semaphore = Arc::new(Semaphore::new(4)); + let scheduler = Scheduler { + tx: Arc::new(tx), + ctx, + jobs: Arc::new(Mutex::new(HashMap::new())), + next_id: Arc::new(AtomicU64::new(0)), + worker_handle: 
Arc::new(std::sync::Mutex::new(None)), + }; + + let handle = tokio::spawn({ + let semaphore = semaphore.clone(); + let scheduler = scheduler.clone(); + + async move { + while let Some(job_id) = rx.recv().await { + let permit = semaphore.clone().acquire_owned().await.unwrap(); + let scheduler = scheduler.clone(); + + tokio::spawn(async move { + scheduler.process(job_id).await; + drop(permit); // release worker slot + }); + } + } + }); + + *scheduler.worker_handle.lock().unwrap() = Some(handle); + + scheduler + } + + pub async fn stop(&self) { + if let Some(handle) = self.worker_handle.lock().unwrap().take() { + handle.abort(); + } + } + + async fn process(&self, job_id: JobId) { + // get the job, updating its status along the way + let job: Arc = { + let mut jobs = self.jobs.lock().await; + let Some(entry) = jobs.get_mut(&job_id) else { + log::warn!("[Scheduler] Job {} not found during process", job_id); + return; + }; + entry.state.status = JobStatus::Active; + entry.job.clone() + }; + + let label = job.get_label(); + log::debug!("[Scheduler] Starting job: {}", label); + + // emit start event + if let Some(event) = job.start_event() { + self.ctx.events.emit(event); + } + + // execute job + let result = job.execute(&self.ctx).await; + + let (next_status, event, subtasks) = self.handle_result(result).await; + + match &next_status { + JobStatus::WaitingForSubtasks(count) => self.shelve_job(job_id, count).await, + JobStatus::Complete => self.resolve_job(job_id, &self.ctx, Ok(())).await, + JobStatus::Failed(e) => self.resolve_job(job_id, &self.ctx, Err(e.clone())).await, + _ => {} + }; + + if let Some(subtasks) = subtasks { + for subtask in subtasks { + self.add_job_internal(subtask, Some(job_id)).await; + } + } + + if let Some(event) = event { + self.ctx.events.emit(event); + } + } + + async fn update_parent_job(&self, job_entry: &JobEntry, _ctx: &JobContext) -> Option { + if job_entry.state.parent_id.is_none() { + return None; + } + let parent_id = 
job_entry.state.parent_id.unwrap(); + + let (tasks_remaining, label) = { + let mut jobs = self.jobs.lock().await; + let Some(parent_job) = jobs.get_mut(&parent_id) else { + return None; + }; + let tasks_remaining = match job_entry.state.status { + JobStatus::Complete | JobStatus::Failed(_) => { + self.decrement_subtask_count(&mut parent_job.state) + } + _ => self.get_subtask_count(&parent_job.state), + }; + match job_entry.state.status { + JobStatus::Complete => parent_job.state.jobs_completed += 1, + JobStatus::Failed(_) => parent_job.state.jobs_failed += 1, + _ => {} + } + (tasks_remaining, parent_job.job.get_label()) + }; + + if tasks_remaining == 0 { + Some(parent_id) + } else { + None + } + } + + fn decrement_subtask_count(&self, state: &mut JobState) -> isize { + if let JobStatus::WaitingForSubtasks(tasks_remaining) = state.status { + state.status = JobStatus::WaitingForSubtasks(tasks_remaining - 1); + tasks_remaining - 1 + } else { + 0 + } + } + + fn get_subtask_count(&self, state: &JobState) -> isize { + if let JobStatus::WaitingForSubtasks(tasks_remaining) = state.status { + tasks_remaining + } else { + 0 + } + } + + async fn handle_result( + &self, + result: Result, + ) -> (JobStatus, Option, Option>>) { + let result = match result { + Ok(r) => r, + Err(e) => { + return (JobStatus::Failed(e.clone()), None, None); + } + }; + + let (status, event, subtasks) = match result { + JobResult::Event(event) => (JobStatus::Complete, Some(event), None), + JobResult::None => (JobStatus::Complete, None, None), + JobResult::Subtasks(subtasks) => ( + match subtasks.len() { + 0 => JobStatus::Complete, + _ => JobStatus::WaitingForSubtasks(subtasks.len() as isize), + }, + None, + Some(subtasks), + ), + }; + + (status, event, subtasks) + } + + /// Resolves a job, calling on_complete or on_failed, and updates its parent. + /// If a parent completes all subtasks, it always resolves as successful, + /// even if some subtasks failed. 
+ async fn resolve_job(&self, job_id: JobId, ctx: &JobContext, result: Result<(), String>) { + let mut current_id = Some(job_id); + let mut current_result = result; + + while let Some(id) = current_id { + let mut entry = { + let mut jobs = self.jobs.lock().await; + let Some(entry) = jobs.remove(&id) else { + log::warn!("[Scheduler] Job {} not found during resolution", id); + break; + }; + entry + }; + + match ¤t_result { + Ok(_) => { + entry.state.status = JobStatus::Complete; + entry.job.on_complete(ctx).await; + if entry.state.jobs_failed + entry.state.jobs_completed > 0 { + log::debug!( + "[Scheduler] Finished job: {} and {} subtasks", + entry.job.get_label(), + entry.state.jobs_failed + entry.state.jobs_completed + ); + } else { + log::debug!("[Scheduler] Finished job: {}", entry.job.get_label(),); + } + } + Err(error) => { + entry.state.status = JobStatus::Failed(error.clone()); + log::warn!( + "[Scheduler] Failed job: {} ({}) {}", + entry.job.get_label(), + entry.state.id, + error + ); + entry.job.on_failed(ctx, error.clone()).await; + } + } + + current_id = self.update_parent_job(&entry, ctx).await; + + // Parent jobs always succeed when their subtasks finish, + // regardless of whether this specific subtask failed. 
+ current_result = Ok(()); + } + } + + async fn shelve_job(&self, job_id: JobId, subtasks_remaining: &isize) { + let mut jobs = self.jobs.lock().await; + if let Some(entry) = jobs.get_mut(&job_id) { + entry.state.status = JobStatus::WaitingForSubtasks(*subtasks_remaining); + } else { + log::warn!("[Scheduler] Job {} not found during shelve", job_id); + } + } + + pub fn add_job(&self, job: T) { + let job = Arc::new(job); + let this = self.clone(); + tokio::spawn(async move { + this.add_job_internal(job, None).await; + }); + } + + async fn add_job_internal(&self, job: Arc, parent_id: Option) { + let id = self.next_id.fetch_add(1, Ordering::Relaxed); + let entry = JobEntry { + job, + state: JobState { + id, + parent_id, + status: JobStatus::Pending, + ..Default::default() + }, + }; + let _ = { self.jobs.lock().await.insert(id, entry) }; + let _ = self.tx.send(id).await; + } +} diff --git a/src-tauri/src/dtp_service/watch.rs b/src-tauri/src/dtp_service/watch.rs new file mode 100644 index 0000000..bb992e9 --- /dev/null +++ b/src-tauri/src/dtp_service/watch.rs @@ -0,0 +1,206 @@ +use dashmap::DashMap; +use notify_debouncer_mini::{ + new_debouncer, + notify::{RecommendedWatcher, RecursiveMode}, + DebouncedEvent, Debouncer, +}; +use std::{ + collections::HashSet, + path::Path, + sync::{mpsc::channel, Arc, OnceLock}, +}; +use tokio::time::Duration; +use tokio::{fs, sync::Mutex}; + +use crate::dtp_service::{ + jobs::{CheckFolderJob, SyncJob}, + scheduler::Scheduler, +}; + +pub struct WatchService { + watchers: DashMap, + volume_watcher: OnceLock, + scheduler: Arc, +} + +pub struct FolderWatcher { + watcher: Mutex>, + path: String, + recursive: bool, + task: tokio::task::JoinHandle<()>, +} + +impl FolderWatcher { + pub fn new(path: String, recursive: bool, scheduler: Arc) -> Self { + let (tx_std, rx_std) = std::sync::mpsc::channel::, _>>(); + + let watcher = new_debouncer(Duration::from_secs(2), tx_std).unwrap(); + let folder_path = path.clone(); + let task = 
tokio::task::spawn_blocking(move || { + for res in rx_std { + match res { + Ok(events) => { + let mut projects: HashSet = HashSet::new(); + for event in events { + match event.path.extension().and_then(|ext| ext.to_str()) { + Some("sqlite3") | Some("sqlite3-wal") => { + let project_path = event.path.with_extension("sqlite3"); + projects.insert(project_path.to_str().unwrap().to_string()); + } + _ => {} + } + } + + if !projects.is_empty() { + let job = CheckFolderJob::new_from_path( + folder_path.clone(), + false, + false, + Some(projects.into_iter().collect()), + ); + scheduler.add_job(job); + } + } + Err(e) => eprintln!("Watch error: {:?}", e), + } + } + }); + + Self { + watcher: Mutex::new(watcher), + path: path.to_string(), + recursive, + task, + } + } + + pub async fn start(&self) { + let exists = fs::try_exists(&self.path).await.unwrap_or(false); + if !exists { + return; + } + + let _ = self + .watcher + .lock() + .await + .watcher() + .watch(Path::new(&self.path), RecursiveMode::NonRecursive); + } + + pub async fn stop(&self) { + let _ = self + .watcher + .lock() + .await + .watcher() + .unwatch(Path::new(&self.path)); + } +} + +pub struct VolumeWatcher { + watcher: Mutex>, +} + +impl VolumeWatcher { + pub fn new(scheduler: Arc) -> Self { + let (tx_std, rx_std) = channel::, _>>(); + + let watcher = new_debouncer(Duration::from_secs(2), tx_std).unwrap(); + + let task = tokio::task::spawn_blocking(move || { + for res in rx_std { + match res { + Ok(events) => { + let mut volumes_changed = false; + for event in events { + if let Some(parent) = event.path.parent() { + if parent == Path::new("/Volumes") { + volumes_changed = true; + log::debug!("Volumes changed: {:?}", event.path); + } + } + } + + if volumes_changed { + let job = SyncJob::new(true); + scheduler.add_job(job); + } + } + Err(e) => eprintln!("Watch error: {:?}", e), + } + } + }); + + Self { + watcher: Mutex::new(watcher), + } + } + + pub async fn start(&self) { + self.watcher + .lock() + .await + 
.watcher() + .watch(Path::new("/Volumes"), RecursiveMode::NonRecursive) + .unwrap(); + } + + pub async fn stop(&self) { + self.watcher + .lock() + .await + .watcher() + .unwatch(Path::new("/Volumes")) + .unwrap(); + } +} + +impl WatchService { + pub fn new(scheduler: Scheduler) -> Self { + let scheduler = Arc::new(scheduler); + let watchers = DashMap::new(); + let volume_watcher = OnceLock::new(); + Self { + watchers, + volume_watcher, + scheduler, + } + } + + pub async fn watch_volumes(&self) -> Result<(), String> { + let volume_watcher = self + .volume_watcher + .get_or_init(|| VolumeWatcher::new(self.scheduler.clone())); + volume_watcher.start().await; + Ok(()) + } + + pub async fn stop_watch_volumes(&self) -> Result<(), String> { + let volume_watcher = self.volume_watcher.get().unwrap(); + volume_watcher.stop().await; + Ok(()) + } + + pub async fn watch_folder(&self, path: &str, recursive: bool) -> Result<(), String> { + let watcher = self.watchers.entry(path.to_string()).or_insert_with(|| { + FolderWatcher::new(path.to_string(), recursive, self.scheduler.clone()) + }); + watcher.start().await; + Ok(()) + } + + pub async fn stop_watch_folder(&self, path: &str) -> Result<(), String> { + let watcher = match self.watchers.get(path) { + Some(watcher) => watcher, + None => return Ok(()), + }; + watcher.stop().await; + Ok(()) + } + + #[allow(dead_code)] + pub async fn stop_all(&self) -> Result<(), String> { + Ok(()) + } +} diff --git a/src-tauri/src/lib.rs b/src-tauri/src/lib.rs index 4f76b75..374fe73 100644 --- a/src-tauri/src/lib.rs +++ b/src-tauri/src/lib.rs @@ -8,9 +8,11 @@ use tauri_plugin_window_state::StateFlags; mod clipboard; -mod bookmarks; +pub mod bookmarks; +pub mod dtp_service; mod ffmpeg; mod projects_db; +use dtp_service::dtp_connect; mod vid; use once_cell::sync::Lazy; @@ -19,7 +21,6 @@ use tokio::runtime::Runtime; pub static TOKIO_RT: Lazy = Lazy::new(|| Runtime::new().expect("Failed to create Tokio runtime")); - #[tauri::command] fn 
read_clipboard_types(pasteboard: Option) -> Result, String> { clipboard::read_clipboard_types(pasteboard) @@ -109,9 +110,6 @@ fn show_dev_window(app: tauri::AppHandle) -> Result<(), String> { #[cfg_attr(mobile, tauri::mobile_entry_point)] pub fn run() { - use projects_db::commands::*; - use projects_db::dtm_dtproject_protocol; - tauri::Builder::default() .plugin(tauri_plugin_shell::init()) .plugin(tauri_plugin_dialog::init()) @@ -154,32 +152,6 @@ pub fn run() { write_clipboard_binary, read_clipboard_strings, fetch_image_file, - // get_tensor_history, - // get_tensor, - // get_thumb_half, - projects_db_project_list, - projects_db_project_add, - projects_db_project_remove, - projects_db_project_scan, - projects_db_project_update_exclude, - projects_db_project_bulk_update_missing_on, - projects_db_image_list, - projects_db_get_clip, - projects_db_image_rebuild_fts, - projects_db_watch_folder_list, - projects_db_watch_folder_add, - projects_db_watch_folder_remove, - projects_db_watch_folder_update, - projects_db_scan_model_info, - projects_db_list_models, - dt_project_get_tensor_history, // #unused - dt_project_get_thumb_half, // #unused - dt_project_get_history_full, - dt_project_get_text_history, - dt_project_find_predecessor_candidates, - dt_project_get_tensor_raw, // #unused - dt_project_get_tensor_size, - dt_project_decode_tensor, vid::create_video_from_frames, vid::save_all_clip_frames, vid::check_pattern, @@ -187,25 +159,46 @@ pub fn run() { ffmpeg_download, ffmpeg_download, ffmpeg_call, - bookmarks::pick_draw_things_folder, + bookmarks::pick_folder_command, bookmarks::resolve_bookmark, - bookmarks::stop_accessing_bookmark + bookmarks::stop_accessing_bookmark, + dtp_connect, + dtp_service::data::dtp_pick_watch_folder, + dtp_service::data::dtp_decode_tensor, + dtp_service::data::dtp_find_image_from_preview_id, + dtp_service::data::dtp_find_predecessor, + dtp_service::data::dtp_get_clip, + dtp_service::data::dtp_get_history_full, + 
dtp_service::data::dtp_get_tensor_size, + dtp_service::data::dtp_list_images, + dtp_service::data::dtp_list_models, + dtp_service::data::dtp_list_projects, + dtp_service::data::dtp_list_watch_folders, + dtp_service::data::dtp_remove_watch_folder, + dtp_service::data::dtp_update_project, + dtp_service::data::dtp_update_watch_folder, + dtp_service::dtp_service::dtp_test, + dtp_service::dtp_service::dtp_sync, + dtp_service::dtp_service::dtp_lock_folder, ]) - .register_asynchronous_uri_scheme_protocol("dtm", |_ctx, request, responder| { - std::thread::spawn(move || { - TOKIO_RT.block_on(async move { - if request.uri().host().unwrap() == "dtproject" { - dtm_dtproject_protocol(request, responder).await; - } else { - responder.respond( - http::Response::builder() - .status(http::StatusCode::BAD_REQUEST) - .header(http::header::CONTENT_TYPE, mime::TEXT_PLAIN.essence_str()) - .body("failed to read file".as_bytes().to_vec()) - .unwrap(), - ); - } - }); + .register_asynchronous_uri_scheme_protocol("dtm", |ctx, request, responder| { + let app_handle = ctx.app_handle().clone(); + tauri::async_runtime::spawn(async move { + let dtp_service = app_handle.state::(); + let dtm_protocol = dtp_service.dtm_protocol().await; + if request.uri().host().unwrap() == "dtproject" { + dtm_protocol + .dtm_dtproject_protocol(request, responder) + .await; + } else { + responder.respond( + http::Response::builder() + .status(http::StatusCode::BAD_REQUEST) + .header(http::header::CONTENT_TYPE, mime::TEXT_PLAIN.essence_str()) + .body("failed to read file".as_bytes().to_vec()) + .unwrap(), + ); + } }); }) // .manage(AppState { @@ -227,6 +220,19 @@ pub fn run() { let _window = win_builder.build().unwrap(); + let app_handle_wrapper = dtp_service::AppHandleWrapper::new(Some(app.handle().clone())); + let dtp_service = dtp_service::DTPService::new(app_handle_wrapper.clone()); + + app.manage(dtp_service); + app.manage(app_handle_wrapper); + // tauri::async_runtime::spawn(async move { + // if let Err(e) = 
dtp_service.init().await { + // eprintln!("Failed to init DB: {}", e); + // } else { + // println!("DB initialized"); + // } + // }); + // let _panel_builder = // WebviewWindowBuilder::new(app, "panel", WebviewUrl::App(PathBuf::from("#mini"))) // .title("DT Metadata Mini") @@ -247,7 +253,7 @@ pub fn run() { .run(|_app_handle, event| match event { tauri::RunEvent::Exit => { bookmarks::cleanup_bookmarks(); - }, + } _ => {} }); } diff --git a/src-tauri/src/objc/FolderPicker.m b/src-tauri/src/objc/FolderPicker.m index 5329c58..be7baea 100644 --- a/src-tauri/src/objc/FolderPicker.m +++ b/src-tauri/src/objc/FolderPicker.m @@ -13,11 +13,12 @@ void ensure_bookmarks_initialized() { }); } -const char* open_dt_folder_picker(const char* default_path) { +const char* open_dt_folder_picker(const char* default_path, const char* button_text) { __block char* resultString = NULL; // Ensure we handle the C string safely NSString *defaultPathStr = default_path ? [NSString stringWithUTF8String:default_path] : nil; + NSString *buttonTextStr = button_text ? 
[NSString stringWithUTF8String:button_text] : nil; // NSOpenPanel must be run on the main thread dispatch_sync(dispatch_get_main_queue(), ^{ @@ -25,14 +26,13 @@ void ensure_bookmarks_initialized() { openPanel.canChooseDirectories = YES; openPanel.canChooseFiles = NO; openPanel.allowsMultipleSelection = NO; - openPanel.prompt = @"Select Documents folder"; + openPanel.prompt = buttonTextStr ?: @"Select folder"; if (defaultPathStr) { openPanel.directoryURL = [NSURL fileURLWithPath:defaultPathStr]; } else { NSURL *homeDir = [NSFileManager defaultManager].homeDirectoryForCurrentUser; - NSURL *suggestion = [homeDir URLByAppendingPathComponent:@"Library/Containers/com.liuliu.draw-things/Data/Documents"]; - openPanel.directoryURL = suggestion; + openPanel.directoryURL = homeDir; } if ([openPanel runModal] == NSModalResponseOK) { @@ -42,13 +42,29 @@ void ensure_bookmarks_initialized() { NSData *bookmarkData = [url bookmarkDataWithOptions:NSURLBookmarkCreationWithSecurityScope includingResourceValuesForKeys:nil relativeToURL:nil - error:&error]; + error:&error]; if (bookmarkData) { NSString *base64String = [bookmarkData base64EncodedStringWithOptions:0]; NSString *path = url.path; - NSString *result = [NSString stringWithFormat:@"%@|%@", path, base64String]; - resultString = strdup([result UTF8String]); + + // JSON format: {"path": "...", "bookmark": "..."} + // We need to escape backslashes and quotes in path if necessary (standard JSON rules) + // For simplicity in ObjC without a JSON lib, we can use NSJSONSerialization + + NSDictionary *dict = @{ + @"path": path, + @"bookmark": base64String + }; + + NSData *jsonData = [NSJSONSerialization dataWithJSONObject:dict options:0 error:&error]; + if (jsonData) { + NSString *jsonString = [[NSString alloc] initWithData:jsonData encoding:NSUTF8StringEncoding]; + resultString = strdup([jsonString UTF8String]); + } else { + NSLog(@"Failed to serialize JSON: %@", error); + } + } else { NSLog(@"Failed to create bookmark: %@", error); } 
@@ -74,11 +90,18 @@ void free_string_ptr(char* ptr) { if (!base64String) return NULL; // Check if we already have this bookmark active - // Note: In Swift we used the base64 string as the key. We do the same here. @synchronized(activeBookmarks) { NSURL *existingUrl = activeBookmarks[base64String]; if (existingUrl) { - return strdup([existingUrl.path UTF8String]); + NSDictionary *dict = @{ + @"status": @"resolved", + @"path": existingUrl.path + }; + NSData *jsonData = [NSJSONSerialization dataWithJSONObject:dict options:0 error:nil]; + if (jsonData) { + NSString *jsonString = [[NSString alloc] initWithData:jsonData encoding:NSUTF8StringEncoding]; + return strdup([jsonString UTF8String]); + } } } @@ -88,29 +111,56 @@ void free_string_ptr(char* ptr) { BOOL isStale = NO; NSError *error = nil; NSURL *url = [NSURL URLByResolvingBookmarkData:data - options:NSURLBookmarkResolutionWithSecurityScope + options:NSURLBookmarkResolutionWithSecurityScope | + NSURLBookmarkResolutionWithoutMounting relativeToURL:nil - bookmarkDataIsStale:&isStale - error:&error]; - - if (isStale) { - NSLog(@"Bookmark is stale"); - } + bookmarkDataIsStale:&isStale + error:&error]; if (url) { if ([url startAccessingSecurityScopedResource]) { + NSString *status = @"resolved"; + NSString *newBookmarkBase64 = nil; + + if (isStale) { + NSLog(@"Bookmark is stale, refreshing..."); + NSData *newBookmarkData = [url bookmarkDataWithOptions:NSURLBookmarkCreationWithSecurityScope + includingResourceValuesForKeys:nil + relativeToURL:nil + error:&error]; + if (newBookmarkData) { + newBookmarkBase64 = [newBookmarkData base64EncodedStringWithOptions:0]; + status = @"stale_refreshed"; + } else { + NSLog(@"Failed to refresh stale bookmark: %@", error); + } + } + @synchronized(activeBookmarks) { activeBookmarks[base64String] = url; } - return strdup([url.path UTF8String]); + + NSMutableDictionary *resultDict = [NSMutableDictionary dictionaryWithDictionary:@{ + @"status": status, + @"path": url.path + }]; + if 
(newBookmarkBase64) { + resultDict[@"new_bookmark"] = newBookmarkBase64; + } + + NSData *jsonData = [NSJSONSerialization dataWithJSONObject:resultDict options:0 error:&error]; + if (jsonData) { + NSString *jsonString = [[NSString alloc] initWithData:jsonData encoding:NSUTF8StringEncoding]; + return strdup([jsonString UTF8String]); + } } else { NSLog(@"Failed to start accessing security scoped resource"); - return NULL; } } else { NSLog(@"Error resolving bookmark: %@", error); - return NULL; } + + return NULL; } void stop_accessing_security_scoped_resource(const char* bookmark_base64) { diff --git a/src-tauri/src/projects_db/commands.rs b/src-tauri/src/projects_db/commands.rs index 745460b..7b4c858 100644 --- a/src-tauri/src/projects_db/commands.rs +++ b/src-tauri/src/projects_db/commands.rs @@ -57,10 +57,11 @@ fn update_tags(app_handle: &tauri::AppHandle, tag: &str, data: Value) { )] pub async fn projects_db_project_add( app_handle: tauri::AppHandle, + watch_folder_id: i64, path: String, ) -> Result { - let pdb = ProjectsDb::get_or_init(&app_handle).await?; - let project = pdb.add_project(&path).await?; + let pdb = ProjectsDb::get_or_init(&app_handle.clone().into()).await?; + let project = pdb.add_project(watch_folder_id, &path).await?; update_tags( &app_handle, "projects", @@ -72,15 +73,15 @@ pub async fn projects_db_project_add( } #[dtm_command( - ok = |ctx| format!("removed project {}", project_name(&ctx.path)), - err = |ctx| format!("error removing project {}: {}", project_name(&ctx.path), ctx.res) + ok = |ctx| format!("removed project {}", ctx.id), + err = |ctx| format!("error removing project {}: {}", ctx.id, ctx.res) )] pub async fn projects_db_project_remove( app_handle: tauri::AppHandle, - path: String, + id: i64, ) -> Result<(), String> { - let pdb = ProjectsDb::get_or_init(&app_handle).await?; - let result = pdb.remove_project(&path).await.map_err(|e| e.to_string())?; + let pdb = ProjectsDb::get_or_init(&app_handle.clone().into()).await?; + let result = 
pdb.remove_project(id).await.map_err(|e| e.to_string())?; match result { Some(id) => { @@ -101,8 +102,8 @@ pub async fn projects_db_project_remove( pub async fn projects_db_project_list( app_handle: tauri::AppHandle, ) -> Result, String> { - let pdb = ProjectsDb::get_or_init(&app_handle).await?; - let projects = pdb.list_projects().await.unwrap(); + let pdb = ProjectsDb::get_or_init(&app_handle.clone().into()).await?; + let projects = pdb.list_projects(None).await.unwrap(); Ok(projects) } @@ -112,7 +113,7 @@ pub async fn projects_db_project_update_exclude( id: i32, exclude: bool, ) -> Result<(), String> { - let pdb = ProjectsDb::get_or_init(&app_handle).await?; + let pdb = ProjectsDb::get_or_init(&app_handle.clone().into()).await?; pdb.update_exclude(id, exclude) .await .map_err(|e| e.to_string())?; @@ -124,11 +125,11 @@ pub async fn projects_db_project_update_exclude( #[dtm_command] pub async fn projects_db_project_bulk_update_missing_on( app_handle: tauri::AppHandle, - paths: Vec, - missing_on: Option, + watch_folder_id: i64, + is_missing: bool, ) -> Result<(), String> { - let pdb = ProjectsDb::get_or_init(&app_handle).await?; - pdb.bulk_update_missing_on(paths, missing_on) + let pdb = ProjectsDb::get_or_init(&app_handle.clone().into()).await?; + pdb.bulk_update_missing_on(watch_folder_id, is_missing) .await .map_err(|e| e.to_string())?; invalidate_tags(&app_handle, "projects", "bulk_update"); @@ -136,17 +137,17 @@ pub async fn projects_db_project_bulk_update_missing_on( } #[dtm_command( - ok = |ctx| format!("scanned project {}", project_name(&ctx.path)), - err = |ctx| format!("error scanning project {}: {}", project_name(&ctx.path), ctx.res) + ok = |ctx| format!("scanned project {}", ctx.id), + err = |ctx| format!("error scanning project {}: {}", ctx.id, ctx.res) )] pub async fn projects_db_project_scan( app: tauri::AppHandle, - path: String, + id: i64, full_scan: Option, - _filesize: Option, - _modified: Option, + filesize: Option, + modified: Option, ) -> 
Result { - let pdb = ProjectsDb::get_or_init(&app).await?; + let pdb = ProjectsDb::get_or_init(&app.clone().into()).await?; // let update = |images_scanned: i32, images_total: i32| { // app.emit( // "projects_db_scan_progress", @@ -162,20 +163,17 @@ pub async fn projects_db_project_scan( // .unwrap(); // }; let result: Result<(i64, u64), String> = pdb - .scan_project(&path, full_scan.unwrap_or(false)) + .scan_project(id, full_scan.unwrap_or(false)) .await .map_err(|e| e.to_string()); match result { Ok((_id, total)) => { - let project = pdb - .update_project(&path, _filesize, _modified) - .await - .map_err(|e| e.to_string())?; + let project = pdb.get_project(_id).await.map_err(|e| e.to_string())?; if total > 0 { let project = pdb - .get_project(project.id) + .update_project(project.id, filesize, modified) .await .map_err(|e| e.to_string())?; @@ -200,7 +198,7 @@ pub async fn projects_db_project_scan( Ok(total as i32) } Err(err) => { - log::error!("Error scanning project {}: {}", path, err); + log::error!("Error scanning project {}: {}", id, err); Err(err.to_string()) } } @@ -220,7 +218,7 @@ pub async fn projects_db_image_list( show_video: Option, show_image: Option, ) -> Result { - let projects_db = ProjectsDb::get_or_init(&app).await?; + let projects_db = ProjectsDb::get_or_init(&app.clone().into()).await?; let opts = ListImagesOptions { project_ids, search, @@ -237,18 +235,33 @@ pub async fn projects_db_image_list( Ok(projects_db.list_images(opts).await.unwrap()) } +#[dtm_command] +pub async fn projects_db_image_find_by_preview_id( + app: tauri::AppHandle, + project_id: i64, + preview_id: i64, +) -> Result, String> { + let projects_db = ProjectsDb::get_or_init(&app.clone().into()).await?; + let image = projects_db + .find_image_by_preview_id(project_id, preview_id) + .await + .map_err(|e| e.to_string())?; + + Ok(image) +} + #[dtm_command] pub async fn projects_db_get_clip( app_handle: tauri::AppHandle, image_id: i64, ) -> Result, String> { - let projects_db = 
ProjectsDb::get_or_init(&app_handle).await?; + let projects_db = ProjectsDb::get_or_init(&app_handle.clone().into()).await?; projects_db.get_clip(image_id).await } #[dtm_command] pub async fn projects_db_image_rebuild_fts(app: tauri::AppHandle) -> Result<(), String> { - let projects_db = ProjectsDb::get_or_init(&app).await?; + let projects_db = ProjectsDb::get_or_init(&app.clone().into()).await?; projects_db.rebuild_images_fts().await.unwrap(); Ok(()) } @@ -257,7 +270,7 @@ pub async fn projects_db_image_rebuild_fts(app: tauri::AppHandle) -> Result<(), pub async fn projects_db_watch_folder_list( app: tauri::AppHandle, ) -> Result, String> { - let projects_db = ProjectsDb::get_or_init(&app).await?; + let projects_db = ProjectsDb::get_or_init(&app.clone().into()).await?; Ok(projects_db.list_watch_folders().await.unwrap()) } @@ -268,11 +281,12 @@ pub async fn projects_db_watch_folder_list( pub async fn projects_db_watch_folder_add( app: tauri::AppHandle, path: String, + bookmark: String, recursive: bool, ) -> Result { - let projects_db = ProjectsDb::get_or_init(&app).await?; + let projects_db = ProjectsDb::get_or_init(&app.clone().into()).await?; let result = projects_db - .add_watch_folder(&path, recursive) + .add_watch_folder(&path, &bookmark, recursive) .await .unwrap(); @@ -286,7 +300,7 @@ pub async fn projects_db_watch_folder_remove( app: tauri::AppHandle, ids: Vec, ) -> Result<(), String> { - let projects_db = ProjectsDb::get_or_init(&app).await?; + let projects_db = ProjectsDb::get_or_init(&app.clone().into()).await?; projects_db.remove_watch_folders(ids).await.unwrap(); invalidate_tags(&app, "watchfolders", "remove"); Ok(()) @@ -299,7 +313,7 @@ pub async fn projects_db_watch_folder_update( recursive: Option, last_updated: Option, ) -> Result { - let projects_db = ProjectsDb::get_or_init(&app).await?; + let projects_db = ProjectsDb::get_or_init(&app.clone().into()).await?; let result = projects_db .update_watch_folder(id, recursive, last_updated) .await @@ 
-317,7 +331,7 @@ pub async fn projects_db_scan_model_info( file_path: String, model_type: entity::enums::ModelType, ) -> Result { - let projects_db = ProjectsDb::get_or_init(&app).await?; + let projects_db = ProjectsDb::get_or_init(&app.clone().into()).await?; let count = projects_db .scan_model_info(&file_path, model_type) .await @@ -335,7 +349,7 @@ pub async fn projects_db_list_models( app: tauri::AppHandle, model_type: Option, ) -> Result, String> { - let projects_db = ProjectsDb::get_or_init(&app).await?; + let projects_db = ProjectsDb::get_or_init(&app.clone().into()).await?; Ok(projects_db .list_models(model_type) .await @@ -344,11 +358,12 @@ pub async fn projects_db_list_models( #[dtm_command] pub async fn dt_project_get_tensor_history( - project_file: String, + app: tauri::AppHandle, + project_id: i64, index: u32, count: usize, ) -> Result, String> { - let project = DTProject::get(&project_file).await.unwrap(); + let project = get_project(app, project_id).await?; match project.get_histories(index as i64, count).await { Ok(history) => Ok(history), Err(_e) => Ok(Vec::new()), @@ -357,102 +372,120 @@ pub async fn dt_project_get_tensor_history( #[dtm_command] pub async fn dt_project_get_text_history( - project_file: String, + app: tauri::AppHandle, + project_id: i64, ) -> Result, String> { - let project = DTProject::get(&project_file).await.unwrap(); - Ok(project.get_text_history().await.unwrap()) + let project = get_project(app, project_id).await?; + Ok(project + .get_text_history() + .await + .map_err(|e| e.to_string())?) } #[dtm_command] pub async fn dt_project_get_thumb_half( - project_file: String, + app: tauri::AppHandle, + project_id: i64, thumb_id: i64, ) -> Result, String> { - let project = DTProject::get(&project_file).await.unwrap(); - Ok(project.get_thumb_half(thumb_id).await.unwrap()) + let project = get_project(app, project_id).await?; + Ok(project + .get_thumb_half(thumb_id) + .await + .map_err(|e| e.to_string())?) 
} #[dtm_command] pub async fn dt_project_get_history_full( - project_file: String, + app: tauri::AppHandle, + project_id: i64, row_id: i64, ) -> Result { - let project = DTProject::get(&project_file).await.unwrap(); - let history = project.get_history_full(row_id).await.unwrap(); + let project = get_project(app, project_id).await?; + let history = project + .get_history_full(row_id) + .await + .map_err(|e| e.to_string())?; Ok(history) } #[dtm_command] pub async fn dt_project_get_tensor_raw( app: tauri::AppHandle, - project_id: Option, - project_path: Option, + project_id: i64, tensor_id: String, ) -> Result { - let project = get_project(app, project_path, project_id).await.unwrap(); - let tensor = project.get_tensor_raw(&tensor_id).await.unwrap(); + let project = get_project(app, project_id).await?; + let tensor = project + .get_tensor_raw(&tensor_id) + .await + .map_err(|e| e.to_string())?; Ok(tensor) } #[dtm_command] pub async fn dt_project_get_tensor_size( app: tauri::AppHandle, - project_id: Option, - project_path: Option, + project_id: i64, tensor_id: String, ) -> Result { - let project = get_project(app, project_path, project_id).await.unwrap(); - let tensor = project.get_tensor_size(&tensor_id).await.unwrap(); + let project = get_project(app, project_id).await?; + let tensor = project + .get_tensor_size(&tensor_id) + .await + .map_err(|e| e.to_string())?; Ok(tensor) } #[dtm_command] pub async fn dt_project_decode_tensor( app: tauri::AppHandle, - project_id: Option, - project_file: Option, + project_id: i64, node_id: Option, tensor_id: String, as_png: bool, ) -> Result { - let project = get_project(app, project_file, project_id).await.unwrap(); - let tensor = project.get_tensor_raw(&tensor_id).await.unwrap(); + let project = get_project(app, project_id).await?; + let tensor = project + .get_tensor_raw(&tensor_id) + .await + .map_err(|e| e.to_string())?; let metadata = match node_id { - Some(node) => Some(project.get_history_full(node).await.unwrap().history), 
+ Some(node) => Some( + project + .get_history_full(node) + .await + .map_err(|e| e.to_string())? + .history, + ), None => None, }; - let buffer = decode_tensor(tensor, as_png, metadata, None).unwrap(); + let buffer = decode_tensor(tensor, as_png, metadata, None).map_err(|e| e.to_string())?; Ok(tauri::ipc::Response::new(buffer)) } #[dtm_command] pub async fn dt_project_find_predecessor_candidates( - project_file: String, + app: tauri::AppHandle, + project_id: i64, row_id: i64, lineage: i64, logical_time: i64, ) -> Result, String> { - let project = DTProject::get(&project_file).await.unwrap(); + let project = get_project(app, project_id).await?; Ok(project .find_predecessor_candidates(row_id, lineage, logical_time) .await - .unwrap()) + .map_err(|e| e.to_string())?) } async fn get_project( app: tauri::AppHandle, - project_path: Option, - project_id: Option, + project_id: i64, ) -> Result, String> { - let project_ref = match project_id { - Some(pid) => ProjectRef::Id(pid), - None => match project_path { - Some(path) => ProjectRef::Path(path), - None => return Err("No project specified".to_string()), - }, - }; - let projects_db = ProjectsDb::get_or_init(&app).await?; + let project_ref = ProjectRef::Id(project_id); + let projects_db = ProjectsDb::get_or_init(&app.clone().into()).await?; let project = projects_db.get_dt_project(project_ref).await?; Ok(project) } diff --git a/src-tauri/src/projects_db/dt_project.rs b/src-tauri/src/projects_db/dt_project.rs index 5308502..6d75fe2 100644 --- a/src-tauri/src/projects_db/dt_project.rs +++ b/src-tauri/src/projects_db/dt_project.rs @@ -10,36 +10,88 @@ use crate::projects_db::{ tensor_history_tensor_data::TensorHistoryTensorData, TextHistory, }; -use moka::future::Cache; +use dashmap::DashMap; use once_cell::sync::Lazy; use serde::Serialize; -use sqlx::{query, query_as, sqlite::SqliteRow, Error, Row, SqlitePool}; -use std::sync::{ - atomic::{AtomicBool, Ordering}, - Arc, +use sqlx::{ + query, query_as, + 
sqlite::{SqliteConnection, SqliteRow}, + Connection, Error, Row, SqlitePool, +}; +use std::{ + future::Future, + pin::Pin, + sync::{ + atomic::{AtomicBool, AtomicU64, Ordering}, + Arc, + }, + time::Duration, }; use tokio::sync::OnceCell; -static PROJECT_CACHE: Lazy>> = Lazy::new(|| { - Cache::builder() - .max_capacity(16) - // caching database connections for 3 seconds, so images can be loaded in bulk - // from separate requests. Closing them early to avoid locks, in case project - // is renamed in DT - .time_to_idle(std::time::Duration::from_secs(3)) - .build() -}); +/// TTL for cached projects. After this duration of no access, the project is evicted. +const CACHE_TTL: Duration = Duration::from_secs(3); +/// Grace period after removing from cache before closing the pool, +/// allowing in-flight queries to complete. +const DRAIN_GRACE: Duration = Duration::from_millis(500); + +static PROJECT_CACHE: Lazy>>>> = + Lazy::new(DashMap::new); + +struct CachedProject { + project: Arc, + generation: AtomicU64, +} pub struct DTProject { - pool: SqlitePool, + pool: Arc, path: String, - has_tensor_history: AtomicBool, - has_text_history: AtomicBool, - has_moodboard: AtomicBool, - has_tensors: AtomicBool, + pub tables: Arc>, + pub text_history: Arc>, +} - has_thumbs: AtomicBool, - pub text_history: OnceCell, +pub async fn close_folder(folder_path: &str) { + let to_remove: Vec = PROJECT_CACHE + .iter() + .filter(|entry| entry.key().starts_with(folder_path)) + .map(|entry| entry.key().clone()) + .collect(); + + for key in to_remove { + if let Some((_, cell)) = PROJECT_CACHE.remove(&key) { + if let Some(cached) = cell.get() { + let pool = cached.project.pool.clone(); + tokio::spawn(async move { + tokio::time::sleep(DRAIN_GRACE).await; + pool.close().await; + }); + } + } + } +} + +fn schedule_eviction(path: String, generation: u64) { + tokio::spawn(async move { + tokio::time::sleep(CACHE_TTL).await; + + // Only evict if no one has accessed it since we were scheduled + let 
should_evict = PROJECT_CACHE + .get(&path) + .and_then(|cell| cell.get().map(|c| c.generation.load(Ordering::Relaxed) == generation)) + .unwrap_or(false); + + if should_evict { + if let Some((_, cell)) = PROJECT_CACHE.remove(&path) { + if let Some(cached) = cell.get() { + let pool = cached.project.pool.clone(); + tokio::spawn(async move { + tokio::time::sleep(DRAIN_GRACE).await; + pool.close().await; + }); + } + } + } + }); } #[derive(Debug, Serialize)] @@ -51,83 +103,110 @@ enum DTProjectTable { Thumbs, } +#[derive(Debug, Default, Clone)] +pub struct DTProjectTableStatus { + pub has_tensor_history: bool, + pub has_text_history: bool, + pub has_moodboard: bool, + pub has_tensors: bool, + pub has_thumbs: bool, +} + impl DTProject { pub async fn new(db_path: &str) -> Result { let connect_string = format!("sqlite:{}?mode=ro", db_path); let pool = SqlitePool::connect(&connect_string).await?; let dtp = Self { - pool, + pool: Arc::new(pool), path: db_path.to_string(), - has_tensor_history: AtomicBool::new(false), - has_text_history: AtomicBool::new(false), - has_moodboard: AtomicBool::new(false), - has_tensors: AtomicBool::new(false), - - has_thumbs: AtomicBool::new(false), - text_history: OnceCell::new(), + tables: Arc::new(OnceCell::new()), + text_history: Arc::new(OnceCell::new()), }; dtp.check_tables().await?; - Ok(dtp) } pub async fn get(path: &str) -> Result, Error> { - let arc = PROJECT_CACHE - .try_get_with(path.to_string(), async move { - let proj = DTProject::new(path).await?; - Ok::<_, Error>(Arc::new(proj)) + let cell = PROJECT_CACHE + .entry(path.to_string()) + .or_insert_with(|| Arc::new(OnceCell::new())) + .clone(); + + let result = cell + .get_or_try_init(|| async { + let project = Arc::new(DTProject::new(path).await?); + Ok::, Error>(Arc::new(CachedProject { + project, + generation: AtomicU64::new(0), + })) }) - .await - .map_err(|e| Error::Protocol(e.to_string()))?; + .await; - Ok(arc) - } - - pub async fn check_tables(&self) -> Result<(), Error> { - 
let tables: Vec<(String,)> = - sqlx::query_as::<_, (String,)>("SELECT name FROM sqlite_master WHERE type='table';") - .fetch_all(&self.pool) - .await?; - - for table in tables { - match table.0.as_str() { - "tensorhistorynode" => self.has_tensor_history.store(true, Ordering::Relaxed), - "tensormoodboarddata" => self.has_moodboard.store(true, Ordering::Relaxed), - "tensors" => self.has_tensors.store(true, Ordering::Relaxed), - "thumbnailhistorynode" => self.has_thumbs.store(true, Ordering::Relaxed), - "texthistorynode" => self.has_text_history.store(true, Ordering::Relaxed), - _ => {} + match result { + Ok(cached) => { + let gen = cached.generation.fetch_add(1, Ordering::Relaxed) + 1; + schedule_eviction(path.to_string(), gen); + Ok(cached.project.clone()) + } + Err(e) => { + // Remove the empty OnceCell so the next caller retries fresh + PROJECT_CACHE.remove(path); + Err(e) } } - - Ok(()) } - fn has_table(&self, table: &DTProjectTable) -> bool { - match table { - DTProjectTable::TensorHistory => self.has_tensor_history.load(Ordering::Relaxed), - DTProjectTable::TextHistory => self.has_text_history.load(Ordering::Relaxed), - DTProjectTable::Moodboard => self.has_moodboard.load(Ordering::Relaxed), - DTProjectTable::Tensors => self.has_tensors.load(Ordering::Relaxed), - DTProjectTable::Thumbs => self.has_thumbs.load(Ordering::Relaxed), - } + pub async fn check_tables(&self) -> Result<&DTProjectTableStatus, Error> { + let status = self + .tables + .get_or_try_init::(async || { + let tables: Vec<(String,)> = sqlx::query_as::<_, (String,)>( + "SELECT name FROM sqlite_master WHERE type='table';", + ) + .fetch_all(&*self.pool) + .await + .unwrap(); + + let mut status = DTProjectTableStatus::default(); + + for table in tables { + match table.0.as_str() { + "tensorhistorynode" => { + status.has_tensor_history = true; + } + "tensormoodboarddata" => status.has_moodboard = true, + "tensors" => status.has_tensors = true, + "thumbnailhistorynode" => status.has_thumbs = true, + 
"texthistorynode" => status.has_text_history = true, + _ => {} + } + } + Ok(status) + }) + .await + .unwrap(); + + Ok(status) } async fn check_table(&self, table: &DTProjectTable) -> Result { - if self.has_table(table) { - return Ok(true); - } + let status = self.check_tables().await?; + + let has_table = match table { + DTProjectTable::TensorHistory => status.has_tensor_history, + DTProjectTable::TextHistory => status.has_text_history, + DTProjectTable::Moodboard => status.has_moodboard, + DTProjectTable::Tensors => status.has_tensors, + DTProjectTable::Thumbs => status.has_thumbs, + }; - self.check_tables().await?; - match self.has_table(table) { - true => Ok(true), - false => Err(Error::from(std::io::Error::new( - std::io::ErrorKind::Other, - "Table not found", - ))), + if !has_table { + return Err(Error::Protocol("Table not found".to_string())); } + + Ok(has_table) } pub async fn get_fingerprint(&self) -> Result { @@ -143,7 +222,7 @@ impl DTProject { LIMIT 5 )", ) - .fetch_one(&self.pool) + .fetch_one(&*self.pool) .await?; let fingerprint: String = row.get(0); @@ -170,7 +249,7 @@ impl DTProject { query_as(&full_query_where("thn.rowid >= ?1 AND thn.rowid < ?2")) .bind(first_id) .bind(first_id + count as i64) - .fetch_all(&self.pool) + .fetch_all(&*self.pool) .await?; let grouper = TensorNodeGrouper::new(&result); @@ -231,7 +310,7 @@ impl DTProject { self.check_table(&DTProjectTable::Tensors).await?; let row = query("SELECT type, format, datatype, dim, data FROM tensors WHERE name = ?1") .bind(name) - .fetch_one(&self.pool) + .fetch_one(&*self.pool) .await?; let tensor_type: i64 = row.get(0); @@ -261,7 +340,7 @@ impl DTProject { self.check_table(&DTProjectTable::Tensors).await?; let row = query("SELECT datatype, dim FROM tensors WHERE name = ?1") .bind(name) - .fetch_one(&self.pool) + .fetch_one(&*self.pool) .await?; let datatype: i64 = row.get(0); @@ -311,7 +390,7 @@ impl DTProject { let result = query( "SELECT COUNT(*) AS total_count, MAX(rowid) AS last_rowid 
FROM tensorhistorynode;", ) - .fetch_one(&self.pool) + .fetch_one(&*self.pool) .await?; Ok(DTProjectInfo { @@ -325,7 +404,7 @@ impl DTProject { self.check_table(&DTProjectTable::Thumbs).await?; let result = query("SELECT p FROM thumbnailhistoryhalfnode WHERE __pk0 = ?1") .bind(thumb_id) - .fetch_one(&self.pool) + .fetch_one(&*self.pool) .await?; let thumbnail: Vec = result.get(0); Ok(thumbnail) @@ -335,7 +414,7 @@ impl DTProject { self.check_table(&DTProjectTable::Thumbs).await?; let result = query("SELECT p FROM thumbnailhistorynode WHERE __pk0 = ?1") .bind(thumb_id) - .fetch_one(&self.pool) + .fetch_one(&*self.pool) .await?; let thumbnail: Vec = result.get(0); Ok(thumbnail) @@ -346,7 +425,7 @@ impl DTProject { node_id: i64, ) -> Result, Error> { self.check_table(&DTProjectTable::TensorHistory).await?; - + let history = self.get_history_full(node_id).await?; let num_frames = history.history.num_frames; @@ -354,7 +433,7 @@ impl DTProject { .bind(node_id) .bind(node_id + num_frames as i64) .map(|row: SqliteRow| self.map_clip(row)) - .fetch_all(&self.pool) + .fetch_all(&*self.pool) .await?; Ok(items) @@ -368,7 +447,7 @@ impl DTProject { self.check_table(&DTProjectTable::TensorHistory).await?; let result: Vec = query_as(&full_query_where("thn.rowid == ?1")) .bind(row_id) - .fetch_all(&self.pool) + .fetch_all(&*self.pool) .await?; let mut item = TensorHistoryExtra::from((result, self.path.clone())); @@ -463,7 +542,7 @@ impl DTProject { .bind(lineage) .bind(logical_time) .map(|row: SqliteRow| row.get(0)) - .fetch_all(&self.pool) + .fetch_all(&*self.pool) .await?; Ok(shuffle_ids) @@ -483,7 +562,7 @@ impl DTProject { .bind(logical_time - 1) .bind(row_id) .map(|row: SqliteRow| self.map_full(row)) - .fetch_all(&self.pool) + .fetch_all(&*self.pool) .await?; let mut same_lineage: Option<&TensorHistoryExtra> = None; @@ -552,13 +631,23 @@ impl DTProject { let p: Vec = row.get(0); TextHistoryNode::try_from(p.as_slice()).unwrap() }) - .fetch_all(&self.pool) + 
.fetch_all(&*self.pool) .await?; Ok(items) } } +pub async fn get_last_row(path: &str) -> Result<(i64, i64), Error> { + let connect_string = format!("sqlite:{}?mode=ro", path); + let mut conn = SqliteConnection::connect(&connect_string).await?; + let row = query("SELECT max(rowid) FROM tensorhistorynode") + .fetch_one(&mut conn) + .await?; + let rowid: i64 = row.get(0); + Ok((rowid, rowid)) +} + fn import_query(has_moodboard: bool) -> String { let moodboard = match has_moodboard { true => { @@ -674,7 +763,6 @@ fn full_query_where(where_expr: &str) -> String { } pub enum ProjectRef { - Path(String), Id(i64), } diff --git a/src-tauri/src/projects_db/dtm_dtproject.rs b/src-tauri/src/projects_db/dtm_dtproject.rs index 51d6214..d9eeea9 100644 --- a/src-tauri/src/projects_db/dtm_dtproject.rs +++ b/src-tauri/src/projects_db/dtm_dtproject.rs @@ -1,12 +1,12 @@ +use dashmap::DashMap; use once_cell::sync::Lazy; -use sea_orm::DbErr; -use std::{collections::HashMap, sync::RwLock}; use tauri::{ http::{self, Response, StatusCode, Uri}, UriSchemeResponder, }; use crate::projects_db::{ + projects_db::MixedError, tensors::{decode_tensor, scribble_mask_to_png}, DTProject, ProjectsDb, }; @@ -21,8 +21,7 @@ const MISSING_SVG: &str = r##" // dtm://dtm_dtproject/thumbhalf/5/82988 // dtm://dtm_dtproject/{item type}/{project_id}/{item id} -static PROJECT_PATH_CACHE: Lazy>> = - Lazy::new(|| RwLock::new(HashMap::new())); +static PROJECT_PATH_CACHE: Lazy> = Lazy::new(DashMap::new); #[derive(Default)] struct DTPRequest { @@ -64,68 +63,98 @@ fn parse_request(uri: &Uri) -> Option { Some(req) } -pub async fn dtm_dtproject_protocol(request: http::Request, responder: UriSchemeResponder) { - let response = match handle_request(request).await { - Ok(r) => r, - Err(e) => { - log::error!("DTM Protocol Error: {}", e); - // Response::builder() - // .status(StatusCode::INTERNAL_SERVER_ERROR) - // .body(e.into_bytes()) - // .unwrap() - Response::builder() - .status(StatusCode::OK) - .header("Content-Type", 
"image/svg+xml") - .body(MISSING_SVG.as_bytes().to_vec()) - .unwrap() - } - - }; - - responder.respond(response); +pub struct DtmProtocol { + pdb: ProjectsDb, } -async fn handle_request(request: http::Request) -> Result>, String> { - let req = parse_request(request.uri()); +impl DtmProtocol { + pub fn new(pdb: ProjectsDb) -> Self { + Self { pdb } + } - if req.is_none() { - return Ok(Response::builder() - .status(StatusCode::BAD_REQUEST) - .body("Invalid path format".as_bytes().to_vec()) - .map_err(|e| e.to_string())?); + pub async fn dtm_dtproject_protocol( + &self, + request: http::Request, + responder: UriSchemeResponder, + ) { + let response = match self.handle_request(request).await { + Ok(r) => r, + Err(e) => { + log::error!("DTM Protocol Error: {}", e); + // Response::builder() + // .status(StatusCode::INTERNAL_SERVER_ERROR) + // .body(e.into_bytes()) + // .unwrap() + Response::builder() + .status(StatusCode::OK) + .header("Content-Type", "image/svg+xml") + .body(MISSING_SVG.as_bytes().to_vec()) + .unwrap() + } + }; + + responder.respond(response); } - let req = req.unwrap(); + async fn handle_request( + &self, + request: http::Request, + ) -> Result>, String> { + let req = parse_request(request.uri()); + + if req.is_none() { + return Ok(Response::builder() + .status(StatusCode::BAD_REQUEST) + .body("Invalid path format".as_bytes().to_vec()) + .map_err(|e| e.to_string())?); + } + + let req = req.unwrap(); + + let item_type = req.item_type; + let project_id: i64 = req.project_id; + + let project_path = self + .get_project_path(project_id) + .await + .map_err(|e| format!("Failed to get project path: {}", e))?; + + let item_id = req.item_id; + + let node = req.node; + let scale = req.scale; + let invert = req.invert; + let mask = req.mask; + match item_type.as_str() { + "thumb" => thumb(&project_path, &item_id, false).await, + "thumbhalf" => thumb(&project_path, &item_id, true).await, + "tensor" => tensor(&project_path, &item_id, node, scale, invert, 
mask).await, + _ => Ok(Response::builder() + .status(StatusCode::NOT_FOUND) + .body("Not Found".as_bytes().to_vec()) + .map_err(|e| e.to_string())?), + } + } - let item_type = req.item_type; - let project_id: i64 = req.project_id; + async fn get_project_path(&self, project_id: i64) -> Result { + if let Some(path) = PROJECT_PATH_CACHE.get(&project_id) { + return Ok(path.clone()); + } - let project_path = get_project_path(project_id) - .await - .map_err(|e| format!("Failed to get project path: {}", e))?; - - let item_id = req.item_id; - - let node = req.node; - let scale = req.scale; - let invert = req.invert; - let mask = req.mask; - - match item_type.as_str() { - "thumb" => thumb(&project_path, &item_id, false).await, - "thumbhalf" => thumb(&project_path, &item_id, true).await, - "tensor" => tensor(&project_path, &item_id, node, scale, invert, mask).await, - _ => Ok(Response::builder() - .status(StatusCode::NOT_FOUND) - .body("Not Found".as_bytes().to_vec()) - .map_err(|e| e.to_string())?), + let project = self.pdb.get_project(project_id).await?; + PROJECT_PATH_CACHE.insert(project_id, project.full_path.clone()); + Ok(project.full_path) } } -async fn thumb(path: &str, item_id: &str, half: bool) -> Result>, String> { +async fn thumb( + full_project_path: &str, + item_id: &str, + half: bool, +) -> Result>, String> { let id: i64 = item_id.parse().map_err(|_| "Invalid item ID".to_string())?; - let dtp = DTProject::get(path) + let dtp = DTProject::get(full_project_path) .await .map_err(|e| format!("Failed to open project: {}", e))?; @@ -148,14 +177,14 @@ async fn thumb(path: &str, item_id: &str, half: bool) -> Result } async fn tensor( - project_file: &str, + full_project_path: &str, name: &str, node: Option, scale: Option, invert: Option, _mask: Option, ) -> Result>, String> { - let dtp = DTProject::get(project_file) + let dtp = DTProject::get(full_project_path) .await .map_err(|e| format!("Failed to open project: {}", e))?; @@ -208,20 +237,6 @@ async fn tensor( } } 
-async fn get_project_path(project_id: i64) -> Result { - if let Some(path) = PROJECT_PATH_CACHE.read().unwrap().get(&project_id).cloned() { - return Ok(path); - } - - let pdb = ProjectsDb::get().map_err(|e| DbErr::Custom(e.to_string()))?; - let project = pdb.get_project(project_id).await?; - PROJECT_PATH_CACHE - .write() - .unwrap() - .insert(project_id, project.path.clone()); - Ok(project.path) -} - fn classify_type(s: &str) -> Option<&str> { s.rsplit_once('_').map(|(prefix, _)| prefix) } diff --git a/src-tauri/src/projects_db/dtos/image.rs b/src-tauri/src/projects_db/dtos/image.rs index e066fc4..850a4d0 100644 --- a/src-tauri/src/projects_db/dtos/image.rs +++ b/src-tauri/src/projects_db/dtos/image.rs @@ -49,6 +49,7 @@ pub struct ImageExtra { pub start_width: i32, pub start_height: i32, pub upscaler_scale_factor: Option, + pub is_ready: Option, } #[derive(Debug, Serialize)] diff --git a/src-tauri/src/projects_db/dtos/project.rs b/src-tauri/src/projects_db/dtos/project.rs index d71f790..e2ef44c 100644 --- a/src-tauri/src/projects_db/dtos/project.rs +++ b/src-tauri/src/projects_db/dtos/project.rs @@ -1,39 +1,80 @@ -use entity::projects; use sea_orm::FromQueryResult; use serde::Serialize; -#[derive(Debug, FromQueryResult, Serialize)] -pub struct ProjectExtra { +#[derive(Debug, FromQueryResult)] +pub struct ProjectRow { pub id: i64, pub fingerprint: String, pub path: String, + pub watchfolder_id: i64, pub image_count: Option, pub last_id: Option, pub filesize: Option, pub modified: Option, - pub missing_on: Option, pub excluded: bool, + pub is_missing: bool, + pub is_locked: bool, } -#[derive(Debug, Serialize, Clone)] -pub struct DTProjectInfo { - pub _path: String, - pub _history_count: i64, - pub history_max_id: i64, +#[derive(Debug, FromQueryResult, Serialize, Clone)] +pub struct ProjectExtra { + pub id: i64, + pub fingerprint: String, + pub path: String, + pub watchfolder_id: i64, + pub image_count: Option, + pub last_id: Option, + pub filesize: Option, + pub 
modified: Option, + pub excluded: bool, + pub name: String, + pub full_path: String, + pub is_missing: bool, + pub is_locked: bool, } -impl From for ProjectExtra { - fn from(m: projects::Model) -> Self { +impl From for ProjectExtra { + fn from(m: ProjectRow) -> Self { + let name = std::path::Path::new(&m.path) + .file_stem() + .and_then(|s| s.to_str()) + .unwrap_or("") + .to_string(); + + let wf_path = crate::projects_db::folder_cache::get_folder(m.watchfolder_id); + + let full_path = if let Some(ref wf) = wf_path { + std::path::Path::new(wf) + .join(&m.path) + .to_string_lossy() + .to_string() + } else { + m.path.clone() + }; + + // let is_missing = m.missing_on.is_some() || wf_path.is_none(); + Self { id: m.id, fingerprint: m.fingerprint, path: m.path, - image_count: None, - last_id: None, + watchfolder_id: m.watchfolder_id, + image_count: m.image_count, + last_id: m.last_id, filesize: m.filesize, modified: m.modified, - missing_on: m.missing_on, excluded: m.excluded, + name, + full_path, + is_missing: m.is_missing, + is_locked: m.is_locked, } } } + +#[derive(Debug, Serialize, Clone)] +pub struct DTProjectInfo { + pub _path: String, + pub _history_count: i64, + pub history_max_id: i64, +} diff --git a/src-tauri/src/projects_db/dtos/watch_folder.rs b/src-tauri/src/projects_db/dtos/watch_folder.rs index 228705f..a18b446 100644 --- a/src-tauri/src/projects_db/dtos/watch_folder.rs +++ b/src-tauri/src/projects_db/dtos/watch_folder.rs @@ -2,11 +2,14 @@ use entity::watch_folders; use serde::Serialize; #[derive(Debug, Serialize, Clone)] +#[serde(rename_all = "camelCase")] pub struct WatchFolderDTO { pub id: i64, pub path: String, pub recursive: Option, - pub last_updated: Option, + pub is_missing: bool, + pub is_locked: bool, + pub bookmark: String, } impl From for WatchFolderDTO { @@ -15,7 +18,9 @@ impl From for WatchFolderDTO { id: m.id, path: m.path, recursive: m.recursive, - last_updated: m.last_updated, + is_missing: m.is_missing, + is_locked: m.is_locked, + 
bookmark: m.bookmark, } } } diff --git a/src-tauri/src/projects_db/folder_cache.rs b/src-tauri/src/projects_db/folder_cache.rs new file mode 100644 index 0000000..2df00ec --- /dev/null +++ b/src-tauri/src/projects_db/folder_cache.rs @@ -0,0 +1,30 @@ +use crate::bookmarks; +use once_cell::sync::Lazy; +use std::collections::HashMap; +use std::path::PathBuf; +use std::sync::RwLock; + +pub static CACHE: Lazy>> = Lazy::new(|| RwLock::new(HashMap::new())); + +pub async fn resolve_bookmark(id: i64, bookmark: &str) -> Result { + let result = bookmarks::resolve_bookmark(bookmark.to_string()).await?; + + match &result { + bookmarks::ResolveResult::Resolved(path) => { + CACHE.write().unwrap().insert(id, PathBuf::from(path)); + } + bookmarks::ResolveResult::StaleRefreshed { resolved_path, .. } => { + CACHE.write().unwrap().insert(id, PathBuf::from(resolved_path)); + } + bookmarks::ResolveResult::CannotResolve => { + // Optionally remove from cache if it was there? + // For now just leave it as is or do nothing. 
+ } + } + + Ok(result) +} + +pub fn get_folder(id: i64) -> Option { + CACHE.read().unwrap().get(&id).map(|p| p.to_str().unwrap().to_string()) +} diff --git a/src-tauri/src/projects_db/mod.rs b/src-tauri/src/projects_db/mod.rs index bf1260e..1edf923 100644 --- a/src-tauri/src/projects_db/mod.rs +++ b/src-tauri/src/projects_db/mod.rs @@ -1,15 +1,13 @@ mod dt_project; -pub use dt_project::DTProject; -mod projects_db; +pub use dt_project::{close_folder, get_last_row, DTProject, ProjectRef}; +pub mod projects_db; pub use projects_db::ProjectsDb; mod tensor_history; pub mod tensor_history_generated; -pub mod commands; - mod dtm_dtproject; -pub use dtm_dtproject::{dtm_dtproject_protocol, extract_jpeg_slice}; +pub use dtm_dtproject::{extract_jpeg_slice, DtmProtocol}; mod tensor_history_mod; @@ -23,9 +21,11 @@ pub use text_history::TextHistory; pub mod fbs; -mod filters; +pub mod filters; mod search; pub mod dtos; mod tensor_history_tensor_data; + +pub mod folder_cache; diff --git a/src-tauri/src/projects_db/projects_db/images.rs b/src-tauri/src/projects_db/projects_db/images.rs new file mode 100644 index 0000000..5bea45c --- /dev/null +++ b/src-tauri/src/projects_db/projects_db/images.rs @@ -0,0 +1,164 @@ +use crate::projects_db::{ + dtos::image::{ImageCount, ImageExtra, ListImagesOptions, ListImagesResult}, + dtos::tensor::TensorHistoryClip, + folder_cache, search, DTProject, +}; +use entity::{images, projects, watch_folders}; +use sea_orm::{ + ColumnTrait, EntityTrait, ExprTrait, JoinType, Order, PaginatorTrait, QueryFilter, QueryOrder, + QuerySelect, RelationTrait, +}; +use sea_query::Expr; + +use super::{MixedError, ProjectsDb}; + +impl ProjectsDb { + pub async fn get_image_count(&self) -> Result { + let count = images::Entity::find().count(&self.db).await?; + Ok(count as u32) + } + + pub async fn list_images( + &self, + opts: ListImagesOptions, + ) -> Result { + let direction = match opts.direction.as_deref() { + Some("asc") => Order::Asc, + _ => Order::Desc, + }; + + 
let mut query = images::Entity::find() + .join(JoinType::LeftJoin, images::Relation::Models.def()) + .join(JoinType::LeftJoin, images::Relation::Projects.def()) + .join(JoinType::LeftJoin, projects::Relation::WatchFolders.def()) + .column_as(entity::models::Column::Filename, "model_file") + .column_as( + Expr::col(watch_folders::Column::IsMissing) + .eq(false) + .and(Expr::col(watch_folders::Column::IsLocked).eq(false)), + "is_ready", + ) + .order_by(images::Column::WallClock, direction); + + if let Some(project_ids) = &opts.project_ids { + if !project_ids.is_empty() { + query = query.filter(images::Column::ProjectId.is_in(project_ids.clone())); + } + } + + if let Some(search_text) = &opts.search { + query = search::add_search(query, search_text); + } + + if let Some(filters) = opts.filters { + for f in filters { + query = f.target.apply(f.operator, &f.value, query); + } + } + + let show_image = opts.show_image.unwrap_or(true); + let show_video = opts.show_video.unwrap_or(true); + + if !show_image && !show_video { + return Ok(ListImagesResult { + counts: None, + images: Some(vec![]), + total: 0, + }); + } + + if show_image && !show_video { + query = query.filter(images::Column::NumFrames.is_null()); + } else if !show_image && show_video { + query = query.filter(images::Column::NumFrames.is_not_null()); + } + + if Some(true) == opts.count { + let project_counts = query + .select_only() + .column(images::Column::ProjectId) + .column_as(images::Column::Id.count(), "count") + .group_by(images::Column::ProjectId) + .into_model::() + .all(&self.db) + .await?; + + let mut total: u64 = 0; + let counts = project_counts + .into_iter() + .map(|p| { + total += p.count as u64; + ImageCount { + project_id: p.project_id, + count: p.count, + } + }) + .collect(); + + return Ok(ListImagesResult { + counts: Some(counts), + images: None, + total, + }); + } + + if let Some(skip) = opts.skip { + query = query.offset(skip as u64); + } + + if let Some(take) = opts.take { + query = 
query.limit(take as u64); + } + + let count = query.clone().count(&self.db).await?; + let result = query.into_model::().all(&self.db).await?; + + Ok(ListImagesResult { + images: Some(result), + total: count, + counts: None, + }) + } + + pub async fn find_image_by_preview_id( + &self, + project_id: i64, + preview_id: i64, + ) -> Result, MixedError> { + let image = images::Entity::find() + .filter(images::Column::ProjectId.eq(project_id)) + .filter(images::Column::PreviewId.eq(preview_id)) + .into_model::() + .one(&self.db) + .await?; + + Ok(image) + } + + pub async fn get_clip(&self, image_id: i64) -> Result, MixedError> { + let result: Option<(String, i64, i64)> = images::Entity::find_by_id(image_id) + .join(JoinType::InnerJoin, images::Relation::Projects.def()) + .select_only() + .column(projects::Column::Path) + .column(projects::Column::WatchfolderId) + .column(images::Column::NodeId) + .into_tuple() + .one(&self.db) + .await?; + + let (rel_path, watchfolder_id, node_id) = + result.ok_or_else(|| "Image or Project not found".to_string())?; + + let watch_folder_path = folder_cache::get_folder(watchfolder_id) + .ok_or_else(|| format!("Watch folder {watchfolder_id} not found in cache"))?; + + let full_path = std::path::Path::new(&watch_folder_path).join(rel_path); + let full_path_str = full_path + .to_str() + .ok_or_else(|| "Invalid path encoding".to_string())?; + + let dt_project = DTProject::get(full_path_str).await?; + let histories = dt_project.get_histories_from_clip(node_id).await?; + Ok(histories) + } +} diff --git a/src-tauri/src/projects_db/projects_db/import.rs b/src-tauri/src/projects_db/projects_db/import.rs new file mode 100644 index 0000000..84a2ac9 --- /dev/null +++ b/src-tauri/src/projects_db/projects_db/import.rs @@ -0,0 +1,284 @@ +use crate::projects_db::{ + dtos::image::ListImagesOptions, dtos::tensor::TensorHistoryImport, search::process_prompt, + DTProject, +}; +use entity::{ + enums::{ModelType, Sampler}, + images, +}; +use 
sea_orm::{sea_query::OnConflict, ConnectionTrait, EntityTrait, Set}; +use std::collections::HashMap; + +use super::models::ModelTypeAndFile; +use super::{MixedError, ProjectsDb}; + +const SCAN_BATCH_SIZE: u32 = 500; + +pub struct NodeModelWeight { + pub node_id: i64, + pub model_id: i64, + pub weight: f32, +} + +impl ProjectsDb { + pub async fn scan_project(&self, id: i64, full_scan: bool) -> Result<(i64, u64), MixedError> { + let project = self.get_project(id).await?; + + if project.excluded { + return Ok((project.id, 0)); + } + + let dt_project = DTProject::get(&project.full_path).await?; + let dt_project_info = dt_project.get_info().await?; + let end = dt_project_info.history_max_id; + + let start = match full_scan { + true => 0, + false => project.last_id.or(Some(-1)).unwrap(), + }; + + for batch_start in (start..end).step_by(SCAN_BATCH_SIZE as usize) { + let histories = dt_project + .get_histories(batch_start, SCAN_BATCH_SIZE as usize) + .await?; + + let histories_filtered: Vec = histories + .into_iter() + .filter(|h| full_scan || (h.index_in_a_clip == 0 && h.generated)) + .collect(); + + let preview_thumbs = HashMap::new(); + + let models_lookup = self.process_models(&histories_filtered).await?; + + let (images, batch_image_loras, batch_image_controls) = self.prepare_image_data( + project.id, + &histories_filtered, + &models_lookup, + preview_thumbs, + ); + + let inserted_images = if !images.is_empty() { + images::Entity::insert_many(images) + .on_conflict( + OnConflict::columns(vec![ + images::Column::NodeId, + images::Column::ProjectId, + ]) + .do_nothing() + .to_owned(), + ) + .exec_with_returning(&self.db) + .await? 
+ } else { + vec![] + }; + + let mut node_id_to_image_id: HashMap = HashMap::new(); + for img in inserted_images { + node_id_to_image_id.insert(img.node_id, img.id); + } + + self.insert_related_data( + &node_id_to_image_id, + batch_image_loras, + batch_image_controls, + ) + .await?; + } + + let total = self + .list_images(ListImagesOptions { + project_ids: Some([project.id].to_vec()), + take: Some(0), + ..Default::default() + }) + .await?; + + self.rebuild_images_fts().await?; + + match total.images { + Some(_) => Ok((project.id, total.total)), + None => Err(MixedError::Other( + "Unexpected result: list_images returned no images".to_string(), + )), + } + } + + pub fn prepare_image_data( + &self, + project_id: i64, + histories: &[TensorHistoryImport], + models_lookup: &HashMap, + preview_thumbs: HashMap>, + ) -> ( + Vec, + Vec, + Vec, + ) { + let mut batch_image_loras: Vec = Vec::new(); + let mut batch_image_controls: Vec = Vec::new(); + + let images_models: Vec = histories + .iter() + .map(|h: &TensorHistoryImport| { + let preview_thumb = preview_thumbs.get(&h.preview_id).cloned(); + let mut image = images::ActiveModel { + project_id: Set(project_id), + node_id: Set(h.row_id), + preview_id: Set(h.preview_id), + thumbnail_half: Set(preview_thumb), + clip_id: Set(h.clip_id), + num_frames: Set(h.num_frames.map(|n| n as i16)), + prompt: Set(h.prompt.trim().to_string()), + negative_prompt: Set(h.negative_prompt.trim().to_string()), + prompt_search: Set(process_prompt(&h.prompt)), + negative_prompt_search: Set(process_prompt(&h.negative_prompt)), + refiner_start: Set(Some(h.refiner_start)), + start_width: Set(h.width as i16), + start_height: Set(h.height as i16), + seed: Set(h.seed as i64), + strength: Set(h.strength), + steps: Set(h.steps as i16), + guidance_scale: Set(h.guidance_scale), + shift: Set(h.shift), + hires_fix: Set(h.hires_fix), + tiled_decoding: Set(h.tiled_decoding), + tiled_diffusion: Set(h.tiled_diffusion), + tea_cache: Set(h.tea_cache), + cfg_zero_star: 
Set(h.cfg_zero_star), + upscaler_scale_factor: Set(h.upscaler.as_ref().map(|_| { + if h.upscaler_scale_factor == 2 { + 2 + } else { + 4 + } + })), + wall_clock: Set(h.wall_clock.unwrap_or_default().and_utc()), + has_mask: Set(h.has_mask), + has_depth: Set(h.has_depth), + has_pose: Set(h.has_pose), + has_color: Set(h.has_color), + has_custom: Set(h.has_custom), + has_scribble: Set(h.has_scribble), + has_shuffle: Set(h.has_shuffle), + sampler: Set(Sampler::try_from(h.sampler).unwrap_or(Sampler::EulerA)), + ..Default::default() + }; + + for lora in &h.loras { + if let Some(id) = models_lookup.get(&(lora.model.clone(), ModelType::Lora)) { + batch_image_loras.push(NodeModelWeight { + node_id: h.row_id, + model_id: *id, + weight: lora.weight, + }); + } + } + + for control in &h.controls { + if let Some(id) = models_lookup.get(&(control.model.clone(), ModelType::Cnet)) { + batch_image_controls.push(NodeModelWeight { + node_id: h.row_id, + model_id: *id, + weight: control.weight, + }); + } + } + + if let Some(model_id) = models_lookup.get(&(h.model.clone(), ModelType::Model)) { + image.model_id = Set(Some(*model_id)); + } + + if let Some(refiner) = &h.refiner_model { + if let Some(refiner_id) = + models_lookup.get(&(refiner.clone(), ModelType::Model)) + { + image.refiner_id = Set(Some(*refiner_id)); + } + } + + if let Some(upscaler) = &h.upscaler { + if let Some(upscaler_id) = + models_lookup.get(&(upscaler.clone(), ModelType::Upscaler)) + { + image.upscaler_id = Set(Some(*upscaler_id)); + } + } + + image + }) + .collect(); + + (images_models, batch_image_loras, batch_image_controls) + } + + pub async fn insert_related_data( + &self, + node_id_to_image_id: &HashMap, + batch_image_loras: Vec, + batch_image_controls: Vec, + ) -> Result<(), MixedError> { + let mut lora_models: Vec = Vec::new(); + for lora in batch_image_loras { + if let Some(image_id) = node_id_to_image_id.get(&lora.node_id) { + lora_models.push(entity::image_loras::ActiveModel { + image_id: Set(*image_id), + 
lora_id: Set(lora.model_id), + weight: Set(lora.weight), + ..Default::default() + }); + } + } + + if !lora_models.is_empty() { + entity::image_loras::Entity::insert_many(lora_models) + .on_conflict( + OnConflict::columns([ + entity::image_loras::Column::ImageId, + entity::image_loras::Column::LoraId, + ]) + .do_nothing() + .to_owned(), + ) + .exec(&self.db) + .await?; + } + + let mut control_models: Vec = Vec::new(); + for control in batch_image_controls { + if let Some(image_id) = node_id_to_image_id.get(&control.node_id) { + control_models.push(entity::image_controls::ActiveModel { + image_id: Set(*image_id), + control_id: Set(control.model_id), + weight: Set(control.weight), + ..Default::default() + }); + } + } + + if !control_models.is_empty() { + entity::image_controls::Entity::insert_many(control_models) + .on_conflict( + OnConflict::columns([ + entity::image_controls::Column::ImageId, + entity::image_controls::Column::ControlId, + ]) + .do_nothing() + .to_owned(), + ) + .exec(&self.db) + .await?; + } + + Ok(()) + } + + pub async fn rebuild_images_fts(&self) -> Result<(), MixedError> { + self.db + .execute_unprepared("INSERT INTO images_fts(images_fts) VALUES('rebuild')") + .await?; + + Ok(()) + } +} diff --git a/src-tauri/src/projects_db/projects_db/mixed_error.rs b/src-tauri/src/projects_db/projects_db/mixed_error.rs new file mode 100644 index 0000000..ebd07cf --- /dev/null +++ b/src-tauri/src/projects_db/projects_db/mixed_error.rs @@ -0,0 +1,62 @@ +use sea_orm::DbErr; + +#[derive(Debug)] +pub enum MixedError { + SeaOrm(DbErr), + Io(std::io::Error), + Other(String), + Sqlx(sqlx::Error), + Transaction(sea_orm::TransactionError), +} + +impl std::fmt::Display for MixedError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", mixed_error_to_string(&self)) + } +} + +impl From for MixedError { + fn from(e: std::string::String) -> Self { + MixedError::Other(e) + } +} + +impl From for MixedError { + fn from(e: std::io::Error) 
-> Self { + MixedError::Io(e) + } +} + +impl From for MixedError { + fn from(e: sqlx::Error) -> Self { + MixedError::Sqlx(e) + } +} + +impl From for MixedError { + fn from(e: DbErr) -> Self { + MixedError::SeaOrm(e) + } +} + +impl From> for MixedError { + fn from(e: sea_orm::TransactionError) -> Self { + MixedError::Transaction(e) + } +} + +fn mixed_error_to_string(error: &MixedError) -> String { + match error { + MixedError::Sqlx(e) => e.to_string(), + MixedError::SeaOrm(e) => e.to_string(), + MixedError::Io(e) => e.to_string(), + MixedError::Other(e) => e.to_string(), + MixedError::Transaction(e) => e.to_string(), + } +} + +impl From for String { + fn from(err: MixedError) -> String { + err.to_string() + } +} diff --git a/src-tauri/src/projects_db/projects_db/mod.rs b/src-tauri/src/projects_db/projects_db/mod.rs new file mode 100644 index 0000000..20b5033 --- /dev/null +++ b/src-tauri/src/projects_db/projects_db/mod.rs @@ -0,0 +1,24 @@ +use migration::{Migrator, MigratorTrait}; +use sea_orm::{Database, DatabaseConnection, DbErr}; + +#[derive(Clone, Debug)] +pub struct ProjectsDb { + pub db: DatabaseConnection, +} + +impl ProjectsDb { + pub async fn new(db_path: &str) -> Result { + let db = Database::connect(db_path).await?; + Migrator::up(&db, None).await?; + Ok(Self { db: db }) + } +} + +mod images; +mod import; +mod models; +mod projects; +mod watchfolders; + +mod mixed_error; +pub use mixed_error::MixedError; diff --git a/src-tauri/src/projects_db/projects_db/models.rs b/src-tauri/src/projects_db/projects_db/models.rs new file mode 100644 index 0000000..bd0dc9b --- /dev/null +++ b/src-tauri/src/projects_db/projects_db/models.rs @@ -0,0 +1,263 @@ +use std::collections::{HashMap, HashSet}; + +use crate::projects_db::dtos::{model::ModelExtra, tensor::TensorHistoryImport}; +use entity::{enums::ModelType, image_controls, image_loras, images, models}; +use sea_orm::{sea_query::OnConflict, ColumnTrait, EntityTrait, QueryFilter, QuerySelect, Set}; +use 
serde::Deserialize; + +use super::{MixedError, ProjectsDb}; + +#[derive(Deserialize)] +pub struct ModelInfoImport { + pub file: String, + pub name: String, + pub version: String, + #[serde(default = "default_true")] + pub is_new: bool, +} + +fn default_true() -> bool { + true +} + +pub type ModelTypeAndFile = (String, ModelType); + +impl ProjectsDb { + pub async fn process_models( + &self, + histories: &[TensorHistoryImport], + ) -> Result, MixedError> { + let models: Vec = HashSet::::from_iter( + histories + .iter() + .flat_map(get_all_models_from_tensor_history), + ) + .iter() + .map(|m| models::ActiveModel { + filename: Set(m.0.clone()), + model_type: Set(m.1), + ..Default::default() + }) + .collect(); + + let models = models::Entity::insert_many(models) + .on_conflict( + OnConflict::columns([models::Column::Filename, models::Column::ModelType]) + .update_column(models::Column::Filename) + .to_owned(), + ) + .exec_with_returning(&self.db) + .await?; + + let mut models_lookup: HashMap = HashMap::new(); + for model in models { + models_lookup.insert((model.filename.clone(), model.model_type), model.id); + } + Ok(models_lookup) + } + + pub async fn update_models( + &self, + mut models: HashMap, + model_type: ModelType, + ) -> Result { + if models.is_empty() { + return Ok(0); + } + + let existing_models = models::Entity::find() + .filter(models::Column::ModelType.eq(model_type)) + .all(&self.db) + .await?; + + for model in existing_models { + if let Some(import_model) = models.get_mut(&model.filename) { + if model.name.unwrap_or_default() == import_model.name + && model.version.unwrap_or_default() == import_model.version + { + import_model.is_new = false; + } + } + } + + let active_models: Vec = models + .into_values() + .filter_map(|m| match m.is_new { + true => Some(models::ActiveModel { + filename: Set(m.file.clone()), + name: Set(Some(m.name.clone())), + version: Set(Some(m.version.clone())), + model_type: Set(model_type), + ..Default::default() + }), + false => 
None, + }) + .collect(); + + let count = active_models.len(); + + models::Entity::insert_many(active_models) + .on_conflict( + OnConflict::columns([models::Column::Filename, models::Column::ModelType]) + .update_columns([models::Column::Name, models::Column::Version]) + .to_owned(), + ) + .exec(&self.db) + .await?; + + Ok(count) + } + + pub async fn scan_model_info( + &self, + path: &str, + model_type: ModelType, + ) -> Result { + let file = std::fs::File::open(path)?; + let reader = std::io::BufReader::new(file); + let models_list: Vec = + serde_json::from_reader(reader).map_err(|e| e.to_string())?; + let kvs = models_list.into_iter().map(|m| (m.file.clone(), m)); + + let models_map: HashMap = HashMap::from_iter(kvs); + + let count = self.update_models(models_map, model_type).await?; + + Ok(count) + } + + pub async fn list_models( + &self, + model_type: Option, + ) -> Result, MixedError> { + let mut models_query = models::Entity::find(); + + if let Some(t) = model_type { + models_query = models_query.filter(models::Column::ModelType.eq(t)); + } + + let models = models_query.all(&self.db).await?; + + if models.is_empty() { + return Ok(Vec::new()); + } + + let mut counts: HashMap = HashMap::new(); + + // Model + Refiner usage (images) + { + let rows = images::Entity::find() + .select_only() + .column(images::Column::ModelId) + .column_as(images::Column::Id.count(), "cnt") + .filter(images::Column::ModelId.is_not_null()) + .group_by(images::Column::ModelId) + .into_tuple::<(i64, i64)>() + .all(&self.db) + .await?; + + for (model_id, cnt) in rows { + *counts.entry(model_id).or_default() += cnt; + } + + let rows = images::Entity::find() + .select_only() + .column(images::Column::RefinerId) + .column_as(images::Column::Id.count(), "cnt") + .filter(images::Column::RefinerId.is_not_null()) + .group_by(images::Column::RefinerId) + .into_tuple::<(i64, i64)>() + .all(&self.db) + .await?; + + for (model_id, cnt) in rows { + *counts.entry(model_id).or_default() += cnt; + } + } 
+ + // Lora usage + { + let rows = image_loras::Entity::find() + .select_only() + .column(image_loras::Column::LoraId) + .column_as(image_loras::Column::ImageId.count(), "cnt") + .group_by(image_loras::Column::LoraId) + .into_tuple::<(i64, i64)>() + .all(&self.db) + .await?; + + for (model_id, cnt) in rows { + *counts.entry(model_id).or_default() += cnt; + } + } + + // ControlNet usage + { + let rows = image_controls::Entity::find() + .select_only() + .column(image_controls::Column::ControlId) + .column_as(image_controls::Column::ImageId.count(), "cnt") + .group_by(image_controls::Column::ControlId) + .into_tuple::<(i64, i64)>() + .all(&self.db) + .await?; + + for (model_id, cnt) in rows { + *counts.entry(model_id).or_default() += cnt; + } + } + + // Upscaler usage + { + let rows = images::Entity::find() + .select_only() + .column(images::Column::UpscalerId) + .column_as(images::Column::Id.count(), "cnt") + .filter(images::Column::UpscalerId.is_not_null()) + .group_by(images::Column::UpscalerId) + .into_tuple::<(i64, i64)>() + .all(&self.db) + .await?; + + for (model_id, cnt) in rows { + *counts.entry(model_id).or_default() += cnt; + } + } + + let mut results = Vec::new(); + for model in models { + let count = counts.get(&model.id).copied().unwrap_or(0); + if count > 0 { + results.push(ModelExtra { + id: model.id, + model_type: model.model_type, + filename: model.filename, + name: model.name, + version: model.version, + count, + }); + } + } + + results.sort_by(|a, b| b.count.cmp(&a.count)); + Ok(results) + } +} + +pub fn get_all_models_from_tensor_history(h: &TensorHistoryImport) -> Vec { + let mut all_image_models: Vec = Vec::new(); + all_image_models.push((h.model.clone(), ModelType::Model)); + if let Some(refiner) = &h.refiner_model { + all_image_models.push((refiner.clone(), ModelType::Model)); + } + if let Some(upscaler) = &h.upscaler { + all_image_models.push((upscaler.clone(), ModelType::Upscaler)); + } + for lora in &h.loras { + 
all_image_models.push((lora.model.clone(), ModelType::Lora)); + } + for control in &h.controls { + all_image_models.push((control.model.clone(), ModelType::Cnet)); + } + all_image_models +} diff --git a/src-tauri/src/projects_db/projects_db/projects.rs b/src-tauri/src/projects_db/projects_db/projects.rs new file mode 100644 index 0000000..d8c185e --- /dev/null +++ b/src-tauri/src/projects_db/projects_db/projects.rs @@ -0,0 +1,237 @@ +use crate::projects_db::{ + dtos::project::{ProjectExtra, ProjectRow}, + folder_cache, DTProject, +}; +use entity::{ + images::{self, Entity as Images}, + projects::{self, ActiveModel, Entity as Projects}, + watch_folders, +}; +use sea_orm::{ + ActiveModelTrait, ColumnTrait, EntityTrait, ExprTrait, JoinType, QueryFilter, QuerySelect, + RelationTrait, Set, +}; +use sea_query::{Expr, OnConflict}; + +use super::{MixedError, ProjectsDb}; + +impl ProjectsDb { + pub async fn add_project( + &self, + watch_folder_id: i64, + relative_path: &str, + ) -> Result { + let watch_folder_path = folder_cache::get_folder(watch_folder_id) + .ok_or_else(|| "Watch folder not found in cache".to_string())?; + let full_path = std::path::Path::new(&watch_folder_path).join(relative_path); + let full_path_str = full_path + .to_str() + .ok_or_else(|| "Invalid path".to_string())?; + + let dt_project = DTProject::get(full_path_str).await?; + let fingerprint = dt_project.get_fingerprint().await?; + + let project = ActiveModel { + path: Set(relative_path.to_string()), + watchfolder_id: Set(watch_folder_id), + fingerprint: Set(fingerprint), + ..Default::default() + }; + + let project = Projects::insert(project) + .on_conflict( + OnConflict::columns([ + entity::projects::Column::Path, + entity::projects::Column::WatchfolderId, + ]) + .value(entity::projects::Column::Path, relative_path) + .to_owned(), + ) + .exec_with_returning(&self.db) + .await?; + + let project = self.get_project(project.id).await?; + + Ok(project) + } + + pub async fn remove_project(&self, id: i64) 
-> Result, MixedError> { + let _ = Projects::delete_by_id(id).exec(&self.db).await?; + + Ok(Some(id)) + } + + pub async fn get_project(&self, id: i64) -> Result { + let result = Projects::find_by_id(id) + .join(JoinType::LeftJoin, projects::Relation::Images.def()) + .column_as( + Expr::col((images::Entity, images::Column::ProjectId)).count(), + "image_count", + ) + .column_as( + Expr::col((images::Entity, images::Column::NodeId)).max(), + "last_id", + ) + .join(JoinType::LeftJoin, projects::Relation::WatchFolders.def()) + .column_as( + Expr::col((watch_folders::Entity, watch_folders::Column::Path)), + "watchfolder_path", + ) + .column_as( + Expr::col((watch_folders::Entity, watch_folders::Column::IsMissing)), + "is_missing", + ) + .column_as( + Expr::col((watch_folders::Entity, watch_folders::Column::IsLocked)), + "is_locked", + ) + .group_by(projects::Column::Id) + .into_model::() + .one(&self.db) + .await?; + + Ok(result.unwrap().into()) + } + + pub async fn get_project_by_path( + &self, + watchfolder_id: i64, + path: &str, + ) -> Result, MixedError> { + let project = project_query() + .filter(projects::Column::WatchfolderId.eq(watchfolder_id)) + .filter(projects::Column::Path.eq(path)) + .into_model::() + .one(&self.db) + .await?; + + Ok(project.map(|r| r.into())) + } + + pub async fn list_projects( + &self, + watchfolder_id: Option, + ) -> Result, MixedError> { + let mut query = Projects::find(); + + if let Some(watchfolder_id) = watchfolder_id { + query = query.filter(projects::Column::WatchfolderId.eq(watchfolder_id)); + } + + let query = query + .join(JoinType::LeftJoin, projects::Relation::Images.def()) + .column_as( + Expr::col((Images, images::Column::ProjectId)).count(), + "image_count", + ) + .column_as(Expr::col((Images, images::Column::Id)).max(), "last_id") + .join(JoinType::LeftJoin, projects::Relation::WatchFolders.def()) + .column_as( + Expr::col((watch_folders::Entity, watch_folders::Column::Path)), + "watchfolder_path", + ) + .column_as( + 
Expr::col((watch_folders::Entity, watch_folders::Column::IsMissing)), + "is_missing", + ) + .column_as( + Expr::col((watch_folders::Entity, watch_folders::Column::IsLocked)), + "is_locked", + ) + .group_by(projects::Column::Id) + .into_model::(); + + let results = query.all(&self.db).await?; + + Ok(results.into_iter().map(|r| r.into()).collect()) + } + + // extra calls + pub async fn update_project( + &self, + project_id: i64, + filesize: Option, + modified: Option, + ) -> Result { + let mut project = projects::ActiveModel { + id: Set(project_id), + ..Default::default() + }; + + if let Some(v) = filesize { + project.filesize = Set(Some(v)); + } + if let Some(v) = modified { + project.modified = Set(Some(v)); + } + + let result = project.update(&self.db).await?; + let updated = self.get_project(result.id).await?; + + Ok(updated) + } + + pub async fn update_exclude(&self, project_id: i64, exclude: bool) -> Result<(), MixedError> { + let project = Projects::find_by_id(project_id) + .one(&self.db) + .await? 
+ .ok_or_else(|| MixedError::Other(format!("Project {project_id} not found")))?; + + let mut project: projects::ActiveModel = project.into(); + project.excluded = Set(exclude); + project.modified = Set(None); + project.filesize = Set(None); + project.update(&self.db).await?; + + if exclude { + log::debug!("Excluding project {}", project_id); + // Remove all images associated with this project + // Cascade delete will handle image_controls and image_loras + let result = images::Entity::delete_many() + .filter(images::Column::ProjectId.eq(project_id)) + .exec(&self.db) + .await?; + log::debug!("Deleted {} images", result.rows_affected); + } + + Ok(()) + } + + pub async fn get_dt_project( + &self, + project_ref: crate::projects_db::dt_project::ProjectRef, + ) -> Result, MixedError> { + let full_path = match project_ref { + crate::projects_db::dt_project::ProjectRef::Id(id) => { + let project = self.get_project(id).await?; + project.full_path + } + }; + + Ok(DTProject::get(&full_path).await?) 
+ } +} + +fn project_query() -> sea_orm::Select { + projects::Entity::find() + .join(JoinType::LeftJoin, projects::Relation::Images.def()) + .column_as( + Expr::col((Images, images::Column::ProjectId)).count(), + "image_count", + ) + .column_as(Expr::col((Images, images::Column::Id)).max(), "last_id") + .join(JoinType::LeftJoin, projects::Relation::WatchFolders.def()) + .column_as( + Expr::col((watch_folders::Entity, watch_folders::Column::Path)), + "watchfolder_path", + ) + .column_as( + Expr::col((watch_folders::Entity, watch_folders::Column::IsMissing)), + "is_missing", + ) + .column_as( + Expr::col((watch_folders::Entity, watch_folders::Column::IsLocked)), + "is_locked", + ) + .group_by(projects::Column::Id) +} diff --git a/src-tauri/src/projects_db/projects_db/watchfolders.rs b/src-tauri/src/projects_db/projects_db/watchfolders.rs new file mode 100644 index 0000000..d03714b --- /dev/null +++ b/src-tauri/src/projects_db/projects_db/watchfolders.rs @@ -0,0 +1,123 @@ +use crate::projects_db::dtos::watch_folder::WatchFolderDTO; +use entity::watch_folders; +use sea_orm::{ + sea_query::Expr, ActiveModelBehavior, ActiveModelTrait, ColumnTrait, EntityTrait, QueryFilter, + QueryOrder, Set, +}; + +use super::{MixedError, ProjectsDb}; + +impl ProjectsDb { + pub async fn list_watch_folders(&self) -> Result, MixedError> { + let folders = watch_folders::Entity::find() + .order_by_asc(watch_folders::Column::Path) + .all(&self.db) + .await?; + + Ok(folders.into_iter().map(|f| f.into()).collect()) + } + + pub async fn add_watch_folder( + &self, + path: &str, + bookmark: &str, + recursive: bool, + ) -> Result { + let model = watch_folders::ActiveModel { + path: Set(path.to_string()), + bookmark: Set(bookmark.to_string()), + recursive: Set(Some(recursive)), + ..Default::default() + } + .insert(&self.db) + .await?; + + Ok(model.into()) + } + + pub async fn something(&self) -> Result<(), String> { + self.remove_watch_folders(vec![1]).await?; + Ok(()) + } + + pub async fn 
remove_watch_folders(&self, ids: Vec) -> Result<(), MixedError> { + if ids.is_empty() { + return Ok(()); + } + + watch_folders::Entity::delete_many() + .filter(watch_folders::Column::Id.is_in(ids)) + .exec(&self.db) + .await?; + + Ok(()) + } + + pub async fn update_watch_folder( + &self, + id: i64, + recursive: Option, + is_missing: Option, + is_locked: Option, + ) -> Result { + let mut model = watch_folders::ActiveModel::new(); + model.id = Set(id); + if let Some(r) = recursive { + model.recursive = Set(Some(r)); + } + + if let Some(is_missing) = is_missing { + model.is_missing = Set(is_missing); + } + + if let Some(is_locked) = is_locked { + model.is_locked = Set(is_locked); + } + + let model = model.update(&self.db).await?; + Ok(model.into()) + } + + pub async fn update_bookmark_path( + &self, + id: i64, + bookmark: &str, + path: &str, + ) -> Result { + let mut model: watch_folders::ActiveModel = watch_folders::Entity::find_by_id(id) + .one(&self.db) + .await? + .ok_or_else(|| MixedError::Other(format!("Watch folder {id} not found")))? + .into(); + + model.bookmark = Set(bookmark.to_string()); + model.path = Set(path.to_string()); + + let model = model.update(&self.db).await?; + Ok(model.into()) + } + + pub async fn get_watch_folder_for_path( + &self, + path: &str, + ) -> Result, MixedError> { + let folder = watch_folders::Entity::find() + .filter(Expr::cust_with_values("? 
LIKE path || '/%'", [path])) + .one(&self.db) + .await?; + + Ok(folder.map(|f| f.into())) + } + + pub async fn get_watch_folder_by_path( + &self, + path: &str, + ) -> Result, MixedError> { + let folder = watch_folders::Entity::find() + .filter(watch_folders::Column::Path.eq(path)) + .one(&self.db) + .await?; + + Ok(folder.map(|f| f.into())) + } +} diff --git a/src-tauri/src/projects_db/projects_db.rs b/src-tauri/src/projects_db/projects_dbx.rs similarity index 77% rename from src-tauri/src/projects_db/projects_db.rs rename to src-tauri/src/projects_db/projects_dbx.rs index 91fd19a..ffb7ea4 100644 --- a/src-tauri/src/projects_db/projects_db.rs +++ b/src-tauri/src/projects_db/projects_dbx.rs @@ -1,14 +1,14 @@ use entity::{ enums::{ModelType, Sampler}, images::{self}, - projects, + projects, watch_folders, }; use migration::{Migrator, MigratorTrait}; use sea_orm::{ sea_query::{Expr, OnConflict}, ActiveModelTrait, ColumnTrait, ConnectionTrait, Database, DatabaseConnection, DbErr, - EntityTrait, ExprTrait, JoinType, Order, PaginatorTrait, QueryFilter, QueryOrder, QuerySelect, - QueryTrait, RelationTrait, Set, + EntityTrait, ExprTrait, IntoActiveModel, JoinType, Order, PaginatorTrait, QueryFilter, + QueryOrder, QuerySelect, QueryTrait, RelationTrait, Set, }; use serde::Deserialize; use std::{ @@ -18,17 +18,21 @@ use std::{ use tauri::Manager; use tokio::sync::OnceCell; -use crate::projects_db::{ - dt_project::{self, ProjectRef}, - dtos::{ - image::{ImageCount, ImageExtra, ListImagesOptions, ListImagesResult, Paged}, - model::ModelExtra, - project::ProjectExtra, - tensor::{TensorHistoryClip, TensorHistoryImport}, - watch_folder::WatchFolderDTO, +use crate::{ + dtp_service::AppHandleWrapper, + projects_db::{ + dt_project::{self, ProjectRef}, + dtos::{ + image::{ImageCount, ImageExtra, ListImagesOptions, ListImagesResult}, + model::ModelExtra, + project::{ProjectExtra, ProjectRow}, + tensor::{TensorHistoryClip, TensorHistoryImport}, + watch_folder::WatchFolderDTO, + }, + 
folder_cache, + search::{self, process_prompt}, + DTProject, }, - search::{self, process_prompt}, - DTProject, }; static CELL: OnceCell = OnceCell::const_new(); @@ -39,30 +43,85 @@ pub struct ProjectsDb { pub db: DatabaseConnection, } -fn get_path(app_handle: &tauri::AppHandle) -> String { - let app_data_dir = app_handle.path().app_data_dir().unwrap(); +#[cfg(dev)] +const DB_NAME: &str = "projects4-dev.db"; +#[cfg(not(dev))] +const DB_NAME: &str = "projects4.db"; + +fn get_path(app_handle: &AppHandleWrapper) -> String { + let app_data_dir = app_handle.get_app_data_dir().unwrap(); if !app_data_dir.exists() { std::fs::create_dir_all(&app_data_dir).expect("Failed to create app data dir"); } - let project_db_path = app_data_dir.join("projects3.db"); + let project_db_path = app_data_dir.join(DB_NAME); format!("sqlite://{}?mode=rwc", project_db_path.to_str().unwrap()) } -fn check_old_path(app_handle: &tauri::AppHandle) { - let app_data_dir = app_handle.path().app_data_dir().unwrap(); +fn check_old_path(app_handle: &AppHandleWrapper) { + let app_data_dir = app_handle.get_app_data_dir().unwrap(); let old_path = app_data_dir.join("projects2.db"); if old_path.exists() { fs::remove_file(old_path).unwrap_or_default(); } + let old_path = app_data_dir.join("projects3.db"); + if old_path.exists() { + fs::remove_file(old_path).unwrap_or_default(); + } } impl ProjectsDb { - pub async fn get_or_init(app_handle: &tauri::AppHandle) -> Result<&'static ProjectsDb, String> { + pub async fn get_or_init(app_handle: &AppHandleWrapper) -> Result<&'static ProjectsDb, String> { + if CELL.initialized() { + return Ok(CELL.get().unwrap()); + } check_old_path(app_handle); + return ProjectsDb::get_or_init_path(app_handle, &get_path(app_handle)).await; + } + + pub async fn get_or_init_path( + _app_handle: &AppHandleWrapper, + db_path: &str, + ) -> Result<&'static ProjectsDb, String> { CELL.get_or_try_init(|| async { - ProjectsDb::new(&get_path(app_handle)) + println!("[ProjectsDB] opening db {}", 
db_path); + let db = ProjectsDb::new(&db_path) .await .map_err(|e| e.to_string()) + .unwrap(); + + let folders = entity::watch_folders::Entity::find() + .all(&db.db) + .await + .unwrap(); + + for folder in folders { + let resolved = folder_cache::resolve_bookmark(folder.id, &folder.bookmark).await; + if let Ok(resolved) = resolved { + match resolved { + crate::bookmarks::ResolveResult::Resolved(path) => { + if path != folder.path { + let mut update = folder.into_active_model(); + update.path = Set(path); + update.update(&db.db).await.unwrap(); + } + } + crate::bookmarks::ResolveResult::StaleRefreshed { + new_bookmark, + resolved_path, + } => { + let mut update = folder.into_active_model(); + update.path = Set(resolved_path); + update.bookmark = Set(new_bookmark); + update.update(&db.db).await.unwrap(); + } + crate::bookmarks::ResolveResult::CannotResolve => { + // TODO: Mark as missing in DB? + } + } + } + } + + Ok(db) }) .await } @@ -71,33 +130,49 @@ impl ProjectsDb { CELL.get().ok_or("Database not initialized".to_string()) } - async fn new(db_path: &str) -> Result { + pub async fn new(db_path: &str) -> Result { let db = Database::connect(db_path).await?; Migrator::up(&db, None).await?; Ok(Self { db: db }) } + // not used pub async fn get_image_count(&self) -> Result { let count = images::Entity::find().count(&self.db).await?; Ok(count as u32) } - pub async fn add_project(&self, path: &str) -> Result { - let dt_project = DTProject::get(path).await?; + // path must be relative to watch folder, which can be retrieved through folder_cache + pub async fn add_project( + &self, + watch_folder_id: i64, + relative_path: &str, + ) -> Result { + let watch_folder_path = folder_cache::get_folder(watch_folder_id) + .ok_or_else(|| "Watch folder not found in cache".to_string())?; + let full_path = std::path::Path::new(&watch_folder_path).join(relative_path); + let full_path_str = full_path + .to_str() + .ok_or_else(|| "Invalid path".to_string())?; + + let dt_project = 
DTProject::get(full_path_str).await?; let fingerprint = dt_project.get_fingerprint().await?; let project = projects::ActiveModel { - path: Set(path.to_string()), + path: Set(relative_path.to_string()), + watchfolder_id: Set(watch_folder_id), fingerprint: Set(fingerprint), ..Default::default() }; let project = entity::projects::Entity::insert(project) .on_conflict( - OnConflict::column(entity::projects::Column::Path) - // do a fake update so the row returns - .value(entity::projects::Column::Path, path) - .to_owned(), + OnConflict::columns([ + entity::projects::Column::Path, + entity::projects::Column::WatchfolderId, + ]) + .value(entity::projects::Column::Path, relative_path) + .to_owned(), ) .exec_with_returning(&self.db) .await?; @@ -107,11 +182,11 @@ impl ProjectsDb { Ok(project) } - pub async fn remove_project(&self, path: &str) -> Result, DbErr> { - let project = projects::Entity::find_by_path(path).one(&self.db).await?; + pub async fn remove_project(&self, id: i64) -> Result, DbErr> { + let project = projects::Entity::find_by_id(id).one(&self.db).await?; if project.is_none() { - log::debug!("remove project: No project found for path: {}", path); + log::debug!("remove project: No project found for id: {}", id); return Ok(None); } let project = project.unwrap(); @@ -121,7 +196,7 @@ impl ProjectsDb { .await?; if delete_result.rows_affected == 0 { - log::debug!("remove project: project couldn't be deleted: {}", path); + log::debug!("remove project: project couldn't be deleted: {}", id); } Ok(Some(project.id)) @@ -138,40 +213,44 @@ impl ProjectsDb { "image_count", ) .column_as(Expr::col((Images, images::Column::NodeId)).max(), "last_id") - .into_model::() + .group_by(projects::Column::Id) + .into_model::() .one(&self.db) .await?; - Ok(result.unwrap()) + Ok(result.unwrap().into()) } - pub async fn get_project_by_path(&self, path: &str) -> Result { - use images::Entity as Images; - use projects::Entity as Projects; - - let result = Projects::find_by_path(path) - 
.join(JoinType::LeftJoin, projects::Relation::Images.def()) - .column_as( - Expr::col((Images, images::Column::ProjectId)).count(), - "image_count", - ) - .column_as(Expr::col((Images, images::Column::NodeId)).max(), "last_id") - .into_model::() + pub async fn get_project_by_path( + &self, + watchfolder_id: i64, + path: &str, + ) -> Result, DbErr> { + let project = projects::Entity::find() + .filter(projects::Column::WatchfolderId.eq(watchfolder_id)) + .filter(projects::Column::Path.eq(path)) + .into_model::() .one(&self.db) .await?; - match result { - Some(result) => Ok(result), - None => Err(DbErr::RecordNotFound(format!("Project {path} not found"))), - } + Ok(project.map(|r| r.into())) } /// List all projects, newest first - pub async fn list_projects(&self) -> Result, DbErr> { + pub async fn list_projects( + &self, + watchfolder_id: Option, + ) -> Result, DbErr> { use images::Entity as Images; use projects::Entity as Projects; - let results = Projects::find() + let mut query = Projects::find(); + + if let Some(watchfolder_id) = watchfolder_id { + query = query.filter(projects::Column::WatchfolderId.eq(watchfolder_id)); + } + + let query = query .join(JoinType::LeftJoin, projects::Relation::Images.def()) .column_as( Expr::col((Images, images::Column::ProjectId)).count(), @@ -179,26 +258,24 @@ impl ProjectsDb { ) .column_as(Expr::col((Images, images::Column::Id)).max(), "last_id") .group_by(projects::Column::Id) - .into_model::() - .all(&self.db) - .await?; + .into_model::(); - Ok(results) + let results = query.all(&self.db).await?; + + Ok(results.into_iter().map(|r| r.into()).collect()) } pub async fn update_project( &self, - path: &str, + project_id: i64, filesize: Option, modified: Option, ) -> Result { // Fetch existing project - let mut project: projects::ActiveModel = projects::Entity::find() - .filter(projects::Column::Path.eq(path)) - .one(&self.db) - .await? - .ok_or(DbErr::RecordNotFound(format!("Project {path} not found")))? 
- .into(); + let mut project = projects::ActiveModel { + id: Set(project_id), + ..Default::default() + }; // Apply updates if let Some(v) = filesize { @@ -210,25 +287,26 @@ impl ProjectsDb { } // Save changes - let updated: ProjectExtra = project.update(&self.db).await?.into(); + let result = project.update(&self.db).await?; + + let updated = self.get_project(result.id).await?; Ok(updated) } - pub async fn scan_project( - &self, - path: &str, - full_scan: bool, - ) -> Result<(i64, u64), MixedError> { - let dt_project = DTProject::get(path).await?; - let dt_project_info = dt_project.get_info().await?; - let end = dt_project_info.history_max_id; - let project = self.get_project_by_path(path).await?; + // TODO from here down + // IMPORT + pub async fn scan_project(&self, id: i64, full_scan: bool) -> Result<(i64, u64), MixedError> { + let project = self.get_project(id).await?; if project.excluded { return Ok((project.id, 0)); } + let dt_project = DTProject::get(&project.full_path).await?; + let dt_project_info = dt_project.get_info().await?; + let end = dt_project_info.history_max_id; + let start = match full_scan { true => 0, false => project.last_id.or(Some(-1)).unwrap(), @@ -310,6 +388,7 @@ impl ProjectsDb { } } + // IMPORT async fn process_models( &self, histories: &[TensorHistoryImport], @@ -346,6 +425,7 @@ impl ProjectsDb { Ok(models_lookup) } + // IMPORT fn prepare_image_data( &self, project_id: i64, @@ -468,6 +548,7 @@ impl ProjectsDb { (images, batch_image_loras, batch_image_controls) } + // IMPORT async fn insert_related_data( &self, node_id_to_image_id: &HashMap, @@ -529,6 +610,7 @@ impl ProjectsDb { Ok(()) } + // IMAGES pub async fn list_images(&self, opts: ListImagesOptions) -> Result { // print!("ListImagesOptions: {:#?}\n", opts); @@ -667,6 +749,23 @@ impl ProjectsDb { }) } + // IMAGES + pub async fn find_image_by_preview_id( + &self, + project_id: i64, + preview_id: i64, + ) -> Result, DbErr> { + let image = entity::images::Entity::find() + 
.filter(images::Column::ProjectId.eq(project_id)) + .filter(images::Column::PreviewId.eq(preview_id)) + .into_model::() + .one(&self.db) + .await?; + + Ok(image) + } + + // WATCHFOLDERS pub async fn list_watch_folders(&self) -> Result, DbErr> { let folders = entity::watch_folders::Entity::find() .order_by_asc(entity::watch_folders::Column::Path) @@ -691,22 +790,52 @@ impl ProjectsDb { // // Ok(folder) // } + // WATCHFOLDERS pub async fn add_watch_folder( &self, path: &str, + bookmark: &str, recursive: bool, ) -> Result { let model = entity::watch_folders::ActiveModel { path: Set(path.to_string()), + bookmark: Set(bookmark.to_string()), recursive: Set(Some(recursive)), ..Default::default() } .insert(&self.db) .await?; + let resolved = folder_cache::resolve_bookmark(model.id, bookmark).await; + + if let Ok(resolved) = resolved { + match resolved { + crate::bookmarks::ResolveResult::Resolved(path) => { + if path != model.path { + let mut update = model.clone().into_active_model(); + update.path = Set(path); + update.update(&self.db).await?; + } + } + crate::bookmarks::ResolveResult::StaleRefreshed { + new_bookmark, + resolved_path, + } => { + let mut update = model.clone().into_active_model(); + update.path = Set(resolved_path); + update.bookmark = Set(new_bookmark); + update.update(&self.db).await?; + } + crate::bookmarks::ResolveResult::CannotResolve => { + // Handle case where it couldn't be resolved immediately? 
+ } + } + } + Ok(model.into()) } + // WATCHFOLDERS pub async fn remove_watch_folders(&self, ids: Vec) -> Result<(), DbErr> { if ids.is_empty() { return Ok(()); @@ -720,6 +849,7 @@ impl ProjectsDb { Ok(()) } + // WATCHFOLDERS pub async fn update_watch_folder( &self, id: i64, @@ -745,6 +875,41 @@ impl ProjectsDb { Ok(model.into()) } + // WATCHFOLDERS + pub async fn update_bookmark_path( + &self, + id: i64, + bookmark: &str, + path: &str, + ) -> Result { + let mut model: entity::watch_folders::ActiveModel = + entity::watch_folders::Entity::find_by_id(id as i64) + .one(&self.db) + .await? + .unwrap() + .into(); + + model.bookmark = Set(bookmark.to_string()); + model.path = Set(path.to_string()); + + let model = model.update(&self.db).await?; + Ok(model.into()) + } + + // WATCHFOLDERS + pub async fn get_watch_folder_for_path( + &self, + path: &str, + ) -> Result, DbErr> { + let folder = watch_folders::Entity::find() + .filter(Expr::cust_with_values("? LIKE path || '/%'", [path])) + .one(&self.db) + .await?; + + Ok(folder.map(|f| f.into())) + } + + // PROJECTS pub async fn update_exclude(&self, project_id: i32, exclude: bool) -> Result<(), DbErr> { let project = projects::Entity::find_by_id(project_id) .one(&self.db) @@ -773,37 +938,28 @@ impl ProjectsDb { Ok(()) } + // PROJECTS pub async fn bulk_update_missing_on( &self, - paths: Vec, - missing_on: Option, + watch_folder_id: i64, + is_missing: bool, ) -> Result<(), DbErr> { - if paths.is_empty() { - return Ok(()); - } - - // Look up project IDs from paths - let projects = projects::Entity::find() - .filter(projects::Column::Path.is_in(paths)) - .select_only() - .column(projects::Column::Id) - .into_tuple::() - .all(&self.db) - .await?; - - if projects.is_empty() { - return Ok(()); - } + let missing_on = if is_missing { + Some(chrono::Utc::now().timestamp()) + } else { + None + }; projects::Entity::update_many() .col_expr(projects::Column::MissingOn, Expr::value(missing_on)) - 
.filter(projects::Column::Id.is_in(projects)) + .filter(projects::Column::WatchfolderId.eq(watch_folder_id)) .exec(&self.db) .await?; Ok(()) } + // IMPORT pub async fn rebuild_images_fts(&self) -> Result<(), sea_orm::DbErr> { self.db .execute_unprepared("INSERT INTO images_fts(images_fts) VALUES('rebuild')") @@ -812,38 +968,47 @@ impl ProjectsDb { Ok(()) } + // HELPERS pub async fn get_dt_project( &self, project_ref: ProjectRef, ) -> Result, String> { - let project_path = match project_ref { - ProjectRef::Path(path) => path, + let full_path = match project_ref { ProjectRef::Id(id) => { - let project = entity::projects::Entity::find_by_id(id as i32) - .one(&self.db) - .await - .map_err(|e| e.to_string())? - .unwrap(); - project.path + let project = self.get_project(id).await.map_err(|e| e.to_string())?; + project.full_path } }; - Ok(dt_project::DTProject::get(&project_path).await.unwrap()) + + Ok(dt_project::DTProject::get(&full_path) + .await + .map_err(|e| e.to_string())?) } + // IMAGES pub async fn get_clip(&self, image_id: i64) -> Result, String> { - let result: Option<(String, i64)> = images::Entity::find_by_id(image_id) + let result: Option<(String, i64, i64)> = images::Entity::find_by_id(image_id) .join(JoinType::InnerJoin, images::Relation::Projects.def()) .select_only() .column(entity::projects::Column::Path) + .column(entity::projects::Column::WatchfolderId) .column(images::Column::NodeId) .into_tuple() .one(&self.db) .await .map_err(|e| e.to_string())?; - let (project_path, node_id) = result.ok_or("Image or Project not found")?; + let (rel_path, watchfolder_id, node_id) = result.ok_or("Image or Project not found")?; - let dt_project = DTProject::get(&project_path) + let watch_folder_path = folder_cache::get_folder(watchfolder_id) + .ok_or_else(|| format!("Watch folder {watchfolder_id} not found in cache"))?; + + let full_path = std::path::Path::new(&watch_folder_path).join(rel_path); + let full_path_str = full_path + .to_str() + .ok_or_else(|| "Invalid 
path encoding".to_string())?; + + let dt_project = DTProject::get(full_path_str) .await .map_err(|e| e.to_string())?; dt_project @@ -852,6 +1017,7 @@ impl ProjectsDb { .map_err(|e| e.to_string()) } + // MODELS pub async fn update_models( &self, mut models: HashMap, @@ -910,6 +1076,7 @@ impl ProjectsDb { Ok(count) } + // MODELS pub async fn scan_model_info( &self, path: &str, @@ -928,6 +1095,7 @@ impl ProjectsDb { Ok(count) } + // MODELS pub async fn list_models( &self, model_type: Option, @@ -1048,7 +1216,6 @@ impl ProjectsDb { // 7. Sort by usage desc results.sort_by(|a, b| b.count.cmp(&a.count)); - Ok(results) } } diff --git a/src-tauri/src/vid.rs b/src-tauri/src/vid.rs index 015c20c..f2a1051 100644 --- a/src-tauri/src/vid.rs +++ b/src-tauri/src/vid.rs @@ -4,9 +4,10 @@ use serde::{Deserialize, Serialize}; use std::io::{BufRead, BufReader}; use std::process::{Command, Stdio}; use std::{fs, path::PathBuf}; -use tauri::{Emitter, Manager}; +use tauri::{Emitter, Manager, State}; -use crate::projects_db::{decode_tensor, DTProject, ProjectsDb}; +use crate::dtp_service::DTPService; +use crate::projects_db::{decode_tensor, DTProject}; #[derive(Debug, Deserialize)] #[serde(rename_all = "camelCase")] @@ -22,17 +23,13 @@ pub struct FramesExportOpts { #[tauri::command] pub async fn save_all_clip_frames( app: tauri::AppHandle, + dtp: State<'_, DTPService>, opts: FramesExportOpts, ) -> Result<(usize, String), String> { - let projects_db = ProjectsDb::get_or_init(&app).await?; + let projects_db = dtp.get_db().await.unwrap(); - let result: Option<(String, i64, i64)> = entity::images::Entity::find_by_id(opts.image_id) - .join( - JoinType::InnerJoin, - entity::images::Relation::Projects.def(), - ) + let result: Option<(i64, i64)> = entity::images::Entity::find_by_id(opts.image_id) .select_only() - .column(entity::projects::Column::Path) .column(entity::images::Column::NodeId) .column(entity::images::Column::ProjectId) .into_tuple() @@ -40,10 +37,12 @@ pub async fn 
save_all_clip_frames( .await .map_err(|e| e.to_string())?; - let (project_path, node_id, _project_db_id) = result.ok_or("Image or Project not found")?; + let (node_id, project_id) = result.ok_or("Image or Project not found")?; + + let project = projects_db.get_project(project_id).await.unwrap(); // 2. Fetch Clip Frames - let dt_project = DTProject::get(&project_path) + let dt_project = DTProject::get(&project.full_path) .await .map_err(|e| e.to_string())?; let frames = dt_project @@ -149,6 +148,7 @@ pub struct VideoExportOpts { #[tauri::command] pub async fn create_video_from_frames( app: tauri::AppHandle, + dtp: State<'_, DTPService>, opts: VideoExportOpts, ) -> Result { // ------------------------------------------------- @@ -183,6 +183,7 @@ pub async fn create_video_from_frames( // ------------------------------------------------- let (frame_count, _) = save_all_clip_frames( app.clone(), + dtp, FramesExportOpts { image_id: opts.image_id, output_dir: temp_dir.to_str().unwrap().to_string(), diff --git a/src-tauri/tauri.conf.json b/src-tauri/tauri.conf.json index 129b26b..b118369 100644 --- a/src-tauri/tauri.conf.json +++ b/src-tauri/tauri.conf.json @@ -1,7 +1,7 @@ { "$schema": "https://schema.tauri.app/config/2", "productName": "DTM", - "version": "0.3.2", + "version": "0.3.3", "identifier": "com.kcjer.dtm", "build": { "beforeDevCommand": "vite", @@ -73,9 +73,6 @@ "icons/icon.icns", "icons/icon.ico" ], - "createUpdaterArtifacts": true, - "macOS": { - "signingIdentity": "-" - } + "createUpdaterArtifacts": true } } diff --git a/src-tauri/test-setup.sh b/src-tauri/test-setup.sh new file mode 100644 index 0000000..9122f83 --- /dev/null +++ b/src-tauri/test-setup.sh @@ -0,0 +1,7 @@ +#!/bin/bash +set -e + +curl -L "https://github.com/kcjerrell/dtm/releases/download/test-data-v2/test-data-v2.zip" -o test-data.zip +unzip -o test-data.zip -d . 
+rm test-data.zip +mkdir -p test_data/temp \ No newline at end of file diff --git a/src-tauri/tests/common/mod.rs b/src-tauri/tests/common/mod.rs new file mode 100644 index 0000000..058f688 --- /dev/null +++ b/src-tauri/tests/common/mod.rs @@ -0,0 +1,194 @@ +use std::{ + env, fs, + sync::{Arc, RwLock}, +}; + +use dtm_lib::dtp_service::{ + events::DTPEvent, + jobs::{Job, JobContext, JobResult}, + AppHandleWrapper, DTPService, +}; +use serde_json::Value; +use tempfile::TempDir; + +use crate::common::projects::{WatchFolderHelper, Watchfolder}; + +pub mod projects; + +pub struct EventHelper { + received: Arc>>, +} + +impl EventHelper { + pub fn new() -> (Self, tauri::ipc::Channel) { + let received = Arc::new(RwLock::new(Vec::new())); + let received_clone = received.clone(); + let channel = tauri::ipc::Channel::new(move |event| { + match event { + tauri::ipc::InvokeResponseBody::Json(json_string) => { + let v: Value = serde_json::from_str(&json_string).unwrap(); + let event_type = v["type"].as_str().unwrap(); + println!("Received event: {}", event_type); + received_clone.write().unwrap().push(event_type.to_string()); + } + _ => { + println!("Received data event") + } + } + Ok(()) + }); + (EventHelper { received }, channel) + } + + pub fn count(self: &Self, event_type: &str) -> usize { + self.received + .read() + .unwrap() + .iter() + .filter(|e| *e == event_type) + .count() + } + + pub async fn assert_count(self: &Self, event_type: &str, count: usize) { + let mut max_checks = MAX_WAIT_MS / 100; + while self.count(event_type) < count && max_checks > 0 { + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + max_checks -= 1; + } + assert_eq!( + self.count(event_type), + count, + "Expected {} events of type {}", + count, + event_type + ); + } + + pub fn reset_counts(&self) { + self.received.write().unwrap().clear(); + } +} + +pub const MAX_WAIT_MS: u64 = 8000; +pub fn reset_db() { + let db_path = env::current_dir() + .unwrap() + .join("test_data") + 
.join("temp") + .join("app_data_dir") + .join("projects4-dev.db"); + + if db_path.exists() { + fs::remove_file(db_path).unwrap(); + } +} + +#[derive(Clone)] +pub struct TestJob { + pub id: u64, + pub delay: u64, + pub subtasks: Vec, + pub msg: Option, + pub should_fail: bool, +} +impl TestJob { + pub fn new(id: u64, delay: u64) -> Self { + Self { + id, + delay, + subtasks: Vec::new(), + msg: None, + should_fail: false, + } + } + + pub fn with_fail(mut self) -> Self { + self.should_fail = true; + self + } + + pub fn with_msg(mut self, msg: String) -> Self { + self.msg = Some(msg); + self + } + + pub fn with_subtasks(mut self, subtasks: Vec) -> Self { + self.subtasks = subtasks; + self + } + + pub fn with_subtask(mut self, subtask: TestJob) -> Self { + self.subtasks.push(subtask); + self + } +} +#[async_trait::async_trait] +impl Job for TestJob { + fn get_label(&self) -> String { + format!("TestJob {}", self.id) + } + + fn start_event(&self) -> Option { + Some(DTPEvent::TestEventStart(Some(self.id), None)) + } + + async fn on_complete(&self, ctx: &JobContext) { + ctx.events + .emit(DTPEvent::TestEventComplete(Some(self.id), None)); + } + + async fn on_failed(&self, ctx: &JobContext, error: String) { + ctx.events + .emit(DTPEvent::TestEventFailed(Some(self.id), None, Some(error))); + } + + async fn execute(&self, ctx: &JobContext) -> Result { + println!("Executing TestJob {}", self.id); + tokio::time::sleep(std::time::Duration::from_millis(self.delay)).await; + if self.should_fail { + return Err("TestJob failed".to_string()); + } + if self.subtasks.is_empty() { + Ok(JobResult::None) + } else { + let subtasks = self + .subtasks + .iter() + .map(|j| { + let j: Arc = Arc::new(j.clone()); + j + }) + .collect(); + Ok(JobResult::Subtasks(subtasks)) + } + } +} + +pub async fn test_fixture(auto_watch: bool) -> (DTPService, EventHelper, WatchFolderHelper, String) { + let temp_dir = TempDir::new_in("test_data/temp").unwrap(); + let temp_dir_path = 
temp_dir.path().to_str().unwrap().to_string(); + let wfh = WatchFolderHelper::get(Watchfolder::A, temp_dir); + // reset_db(); + let app_handle = AppHandleWrapper::new(None); + let dtps = DTPService::new(app_handle); + + let app_data_dir = format!("{}/app_data_dir", temp_dir_path); + let db_path = format!("{}/projects4.db", app_data_dir); + fs::create_dir_all(&app_data_dir).unwrap(); + + let (event_helper, channel) = EventHelper::new(); + let _ = dtps + .connect( + channel, + auto_watch, + format!( + "sqlite://{}/app_data_dir/projects4.db?mode=rwc", + temp_dir_path, + ) + .to_string(), + ) + .await + .unwrap(); + + (dtps, event_helper, wfh, db_path) +} diff --git a/src-tauri/tests/common/projects.rs b/src-tauri/tests/common/projects.rs new file mode 100644 index 0000000..bd02b93 --- /dev/null +++ b/src-tauri/tests/common/projects.rs @@ -0,0 +1,140 @@ +use std::{fs, path::PathBuf}; + +use tempfile::TempDir; +use tracing::warn; + +pub const PROJECTS_DIR: &str = "test_data/projects"; +pub const WATCHFOLDER_A: &str = "watchfolder_a"; + +pub enum Watchfolder { + A, +} + +fn get_watchfolder_path(watchfolder: Watchfolder) -> String { + match watchfolder { + Watchfolder::A => WATCHFOLDER_A.to_string(), + } +} + +pub struct TestProject { + pub filename: String, + pub variant: Option, + pub watchfolder: String, +} + +impl TestProject { + pub fn copy(&self) { + let src_path = self.get_src_path(); + let dest_path = self.get_dest_path(); + println!("Copying {} to {}", src_path, dest_path); + fs::copy(src_path, dest_path).unwrap(); + } + + pub fn remove(&self) { + let remove_path = self.get_dest_path(); + let remove_path = PathBuf::from(remove_path); + + if remove_path.exists() { + fs::remove_file(&remove_path).unwrap(); + } + + if remove_path.with_extension("sqlite3-wal").exists() { + fs::remove_file(&remove_path.with_extension("sqlite3-wal")).unwrap(); + } + + if remove_path.with_extension("sqlite3-shm").exists() { + 
fs::remove_file(&remove_path.with_extension("sqlite3-shm")).unwrap(); + } + } + + pub fn copy_variant(&self) { + if let Some(variant) = &self.variant { + let src_path = self.get_variant_src_path(); + let dest_path = self.get_dest_path(); + fs::copy(src_path, dest_path).unwrap(); + } else { + warn!("No variant for {}", self.filename); + } + } + + pub fn get_src_path(&self) -> String { + format!("{}/{}", PROJECTS_DIR, self.filename) + } + + pub fn get_variant_src_path(&self) -> String { + format!("{}/{}", PROJECTS_DIR, self.variant.as_ref().unwrap()) + } + + pub fn get_dest_path(&self) -> String { + format!("{}/{}", self.watchfolder, self.filename) + } +} + +pub struct WatchFolderHelper { + pub projects: Vec, + pub watchfolder_path: String, + pub bookmark: String, + pub temp_dir: TempDir, +} + +impl WatchFolderHelper { + pub fn get(watchfolder: Watchfolder, temp_dir: TempDir) -> Self { + let watchfolder_path = temp_dir + .path() + .join(get_watchfolder_path(watchfolder)) + .to_str() + .unwrap() + .to_string(); + + println!("Watchfolder path: {}", watchfolder_path); + let bookmark: String = format!("TESTBOOKMARK::{}", watchfolder_path); + + let projects = vec![ + TestProject { + filename: "test-project-a2.sqlite3".to_string(), + variant: None, + watchfolder: watchfolder_path.clone(), + }, + TestProject { + filename: "test-project-c-9.sqlite3".to_string(), + variant: Some("test-project-c-10.sqlite3".to_string()), + watchfolder: watchfolder_path.clone(), + }, + ]; + let wh = Self { + projects, + watchfolder_path, + bookmark, + temp_dir, + }; + wh.clear_all(); + wh + } + + pub fn get_count(&self) -> usize { + self.projects.len() + } + + pub fn copy_all(&self) { + for project in &self.projects { + project.copy(); + } + } + + pub fn clear_all(&self) { + let _ = fs::remove_dir_all(&self.watchfolder_path); + let _ = fs::create_dir_all(&self.watchfolder_path); + } + + pub fn remove_all(&self) { + for project in &self.projects { + project.remove(); + } + } + + pub fn 
copy_variants(&self) { + for project in &self.projects { + project.copy_variant(); + } + } +} diff --git a/src-tauri/tests/lib.rs b/src-tauri/tests/lib.rs index 31e1bb2..3e02cb9 100644 --- a/src-tauri/tests/lib.rs +++ b/src-tauri/tests/lib.rs @@ -1,7 +1,14 @@ +mod common; + #[cfg(test)] mod tests { - #[test] - fn it_works() { - assert_eq!(2 + 2, 4); + use dtm_lib::dtp_service::AppHandleWrapper; + use dtm_lib::dtp_service::DTPService; + + use crate::common::projects::WATCHFOLDER_A; + use crate::common::*; + + #[tokio::test] + async fn projects_test() { } } diff --git a/src-tauri/tests/scheduler.rs b/src-tauri/tests/scheduler.rs new file mode 100644 index 0000000..6c25124 --- /dev/null +++ b/src-tauri/tests/scheduler.rs @@ -0,0 +1,73 @@ +mod common; + +#[cfg(test)] +mod tests { + use dtm_lib::dtp_service::AppHandleWrapper; + use dtm_lib::dtp_service::DTPService; + + use crate::common::*; + + #[tokio::test] + async fn schedule_jobs() { + let app_handle = AppHandleWrapper::new(None); + let dtp = DTPService::new(app_handle); + + let (event_helper, channel) = EventHelper::new(); + let _ = dtp + .connect(channel, false, "sqlite::memory:".to_string()) + .await; + + // it can add and run jobs + dtp.add_job(TestJob::new(1, 100)); + event_helper.assert_count("test_event_start", 1).await; + event_helper.assert_count("test_event_complete", 1).await; + + // it can add and run concurrent jobs + // the assumes concurrent threads are 4 + event_helper.reset_counts(); + let start_time = std::time::Instant::now(); + dtp.add_job(TestJob::new(2, 500)); + dtp.add_job(TestJob::new(3, 500)); + dtp.add_job(TestJob::new(4, 500)); + dtp.add_job(TestJob::new(5, 500)); + event_helper.assert_count("test_event_start", 4).await; + event_helper.assert_count("test_event_complete", 0).await; + event_helper.assert_count("test_event_complete", 4).await; + let duration = start_time.elapsed(); + assert!(duration < std::time::Duration::from_millis(1000)); + + // it can add and run jobs with subtasks + 
event_helper.reset_counts(); + let start_time = std::time::Instant::now(); + dtp.add_job( + TestJob::new(6, 500) + .with_subtask(TestJob::new(7, 500).with_subtask(TestJob::new(8, 500))), + ); + event_helper.assert_count("test_event_start", 3).await; + event_helper.assert_count("test_event_complete", 3).await; + assert!(start_time.elapsed() >= std::time::Duration::from_millis(1500)); + + dtp.stop().await; + } + + #[tokio::test] + async fn schedule_jobs_with_failure() { + let app_handle = AppHandleWrapper::new(None); + let dtp_service = DTPService::new(app_handle); + + let (event_helper, channel) = EventHelper::new(); + let _ = dtp_service + .connect(channel, false, "sqlite::memory:".to_string()) + .await; + + let scheduler = { dtp_service.scheduler.read().await.clone().unwrap().clone() }; + + // it can add and run jobs with failure + event_helper.reset_counts(); + scheduler.add_job(TestJob::new(1, 500).with_fail()); + event_helper.assert_count("test_event_start", 1).await; + event_helper.assert_count("test_event_failed", 1).await; + + dtp_service.stop().await; + } +} diff --git a/src-tauri/tests/setup.rs b/src-tauri/tests/setup.rs new file mode 100644 index 0000000..cac1c07 --- /dev/null +++ b/src-tauri/tests/setup.rs @@ -0,0 +1,37 @@ +mod common; + +#[cfg(test)] +mod tests { + + use std::fs; + + use crate::common::*; + + #[tokio::test] + async fn sync_projects_no_watch() { + let (dtps, event_helper, wfh, db_path) = test_fixture(false).await; + + // add empty watch folder + dtps.add_watchfolder(wfh.watchfolder_path.clone(), wfh.bookmark.clone()) + .await + .unwrap(); + + event_helper.assert_count("folder_sync_complete", 1).await; + let projects = dtps.list_projects(None).await.unwrap(); + assert_eq!(projects.len(), 0); + event_helper.reset_counts(); + + // copy projects and sync + wfh.copy_all(); + let _ = dtps.sync().await; + + event_helper.assert_count("folder_sync_complete", 1).await; + let projects = dtps.list_projects(None).await.unwrap(); + 
assert_eq!(projects.len(), 2); + event_helper.reset_counts(); + + dtps.stop().await; + + fs::copy(db_path, "test_data/testdb.db").unwrap(); + } +} \ No newline at end of file diff --git a/src-tauri/tests/sync.rs b/src-tauri/tests/sync.rs new file mode 100644 index 0000000..a962837 --- /dev/null +++ b/src-tauri/tests/sync.rs @@ -0,0 +1,101 @@ +mod common; + +#[cfg(test)] +mod tests { + + use crate::common::*; + + #[tokio::test] + async fn sync_projects_no_watch() { + let (dtps, event_helper, wfh, _) = test_fixture(false).await; + + // add empty watch folder + dtps.add_watchfolder(wfh.watchfolder_path.clone(), wfh.bookmark.clone()) + .await + .unwrap(); + + event_helper.assert_count("folder_sync_complete", 1).await; + let projects = dtps.list_projects(None).await.unwrap(); + assert_eq!(projects.len(), 0); + event_helper.reset_counts(); + + // copy projects and sync + wfh.copy_all(); + let _ = dtps.sync().await; + + event_helper.assert_count("folder_sync_complete", 1).await; + event_helper.assert_count("project_added", 2).await; + event_helper.assert_count("project_updated", 2).await; + let projects = dtps.list_projects(None).await.unwrap(); + assert_eq!(projects.len(), 2); + event_helper.reset_counts(); + + // remove one project + wfh.projects[0].remove(); + let _ = dtps.sync().await; + + event_helper.assert_count("folder_sync_complete", 1).await; + event_helper.assert_count("project_removed", 1).await; + let projects = dtps.list_projects(None).await.unwrap(); + assert_eq!(projects.len(), 1); + event_helper.reset_counts(); + + // update one project + let current_image_count = projects[0].image_count.unwrap(); + wfh.projects[1].copy_variant(); + let _ = dtps.sync().await; + + event_helper.assert_count("folder_sync_complete", 1).await; + event_helper.assert_count("project_updated", 1).await; + let projects = dtps.list_projects(None).await.unwrap(); + assert_eq!(projects.len(), 1); + assert_eq!(projects[0].image_count.unwrap(), current_image_count + 1); + 
event_helper.reset_counts(); + + dtps.stop().await; + } + + #[tokio::test] + async fn sync_projects_with_watch() { + let (dtps, event_helper, wfh, _) = test_fixture(true).await; + + // add empty watch folder + dtps.add_watchfolder(wfh.watchfolder_path.clone(), wfh.bookmark.clone()) + .await + .unwrap(); + + event_helper.assert_count("folder_sync_complete", 1).await; + let projects = dtps.list_projects(None).await.unwrap(); + assert_eq!(projects.len(), 0); + event_helper.reset_counts(); + + // copy projects and sync + wfh.copy_all(); + + event_helper.assert_count("project_added", 2).await; + event_helper.assert_count("project_updated", 2).await; + let projects = dtps.list_projects(None).await.unwrap(); + assert_eq!(projects.len(), 2); + event_helper.reset_counts(); + + // remove one project + wfh.projects[0].remove(); + + event_helper.assert_count("project_removed", 1).await; + let projects = dtps.list_projects(None).await.unwrap(); + assert_eq!(projects.len(), 1); + event_helper.reset_counts(); + + // update one project + let current_image_count = projects[0].image_count.unwrap(); + wfh.projects[1].copy_variant(); + + event_helper.assert_count("project_updated", 1).await; + let projects = dtps.list_projects(None).await.unwrap(); + assert_eq!(projects.len(), 1); + assert_eq!(projects[0].image_count.unwrap(), current_image_count + 1); + event_helper.reset_counts(); + + dtps.stop().await; + } +} diff --git a/src/App.tsx b/src/App.tsx index affac0f..e7eea86 100644 --- a/src/App.tsx +++ b/src/App.tsx @@ -1,7 +1,7 @@ import { HStack, IconButton, Spacer, VStack } from "@chakra-ui/react" import { getCurrentWindow } from "@tauri-apps/api/window" import { AnimatePresence, LayoutGroup, motion } from "motion/react" -import { lazy, type PropsWithChildren, Suspense, useEffect, useRef } from "react" +import { type PropsWithChildren, Suspense, useEffect, useRef } from "react" import { ErrorBoundary } from "react-error-boundary" import { useSnapshot } from "valtio" import { 
CheckRoot, Sidebar, Tooltip } from "@/components" @@ -11,6 +11,7 @@ import { themeHelpers } from "@/theme/helpers" import { toggleColorMode, useColorMode } from "./components/ui/color-mode" import ErrorFallback from "./ErrorFallback" import AppStore from "./hooks/appState" +import { useMetadataDrop } from "./hooks/useDrop" import { Loading } from "./main" import "./menu" import UpgradeButton from "./metadata/toolbar/UpgradeButton" @@ -27,6 +28,8 @@ function App() { const isPreviewActive = useIsPreviewActive() const { colorMode } = useColorMode() + const { handlers: dropHandlers } = useMetadataDrop() + return ( - + {viewDescription.map((item) => ( = T | Readonly + +async function connect(channel: Channel) { + await invoke("dtp_connect", { channel, autoWatch: true }) +} + +async function lockFolder(watchfolderId: number) { + await invoke("dtp_lock_folder", { watchfolderId }) +} + +async function listProjects(watchfolderId?: number): Promise { + return await invoke("dtp_list_projects", { watchfolderId }) +} + +async function updateProject(projectId: number, exclude?: boolean): Promise { + return await invoke("dtp_update_project", { projectId, exclude }) +} + +async function listImages( + source: MaybeReadonly, + skip: number, + take: number, +): Promise { + const result: ListImagesResult = await invoke("dtp_list_images", { + ...source, + skip, + take, + }) + return result +} + +async function listImagesCount(source: MaybeReadonly) { + const opts = { ...source, projectIds: undefined, count: true } + const result: ListImagesResult = await invoke("dtp_list_images", opts) + return result +} + +async function findImageFromPreviewId( + projectId: number, + previewId: number, +): Promise { + return await invoke("dtp_find_image_from_preview_id", { projectId, previewId }) +} + +async function getClip(imageId: number): Promise { + return await invoke("dtp_get_clip", { imageId }) +} + +async function listWatchFolders(): Promise { + return await invoke("dtp_list_watch_folders") 
+} + +async function pickWatchFolder(dtFolder?: boolean): Promise { + let testOverride = undefined + if ((window as unknown as Record).__E2E_FILE_PATH__) { + testOverride = `TESTPATH::${(window as unknown as Record).__E2E_FILE_PATH__}` + ;(window as unknown as Record).__E2E_FILE_PATH__ = "" // Clear it after use + // In E2E tests, we bypass the native picker and return a predefined path. + } + return await invoke("dtp_pick_watch_folder", { dtFolder, testOverride }) +} + +async function removeWatchFolder(id: number): Promise { + return await invoke("dtp_remove_watch_folder", { id }) +} + +async function updateWatchFolder(id: number, recursive: boolean): Promise { + return await invoke("dtp_update_watch_folder", { id, recursive }) +} + +async function listModels(modelType?: ModelType): Promise { + return await invoke("dtp_list_models", { modelType }) +} + +async function getHistoryFull(projectId: number, rowId: number): Promise { + return await invoke("dtp_get_history_full", { projectId, rowId }) +} + +async function getTensorSize(projectId: number, tensorId: string): Promise { + return await invoke("dtp_get_tensor_size", { projectId, tensorId }) +} + +async function decodeTensor( + projectId: number, + tensorId: string, + asPng: boolean, + nodeId?: number, +): Promise> { + const opts = { + tensorId, + projectId, + asPng, + nodeId, + } + return new Uint8Array(await invoke("dtp_decode_tensor", opts)) +} + +async function findPredecessor( + projectId: number, + rowId: number, + lineage: number, + logicalTime: number, +): Promise { + return await invoke("dtp_find_predecessor", { + projectId, + rowId, + lineage, + logicalTime, + }) +} + +async function sync() { + await invoke("dtp_sync") +} + +const DTPService = { + connect, + listProjects, + updateProject, + listImages, + listImagesCount, + findImageFromPreviewId, + getClip, + listWatchFolders, + pickWatchFolder, + removeWatchFolder, + updateWatchFolder, + listModels, + getHistoryFull, + getTensorSize, + decodeTensor, + 
findPredecessor, + sync, + lockFolder +} + +export default DTPService diff --git a/src/commands/DtpServiceTypes.ts b/src/commands/DtpServiceTypes.ts new file mode 100644 index 0000000..5ef01af --- /dev/null +++ b/src/commands/DtpServiceTypes.ts @@ -0,0 +1,316 @@ +import type { DrawThingsConfig, DrawThingsConfigGrouped } from "@/types" + +export type ModelType = "None" | "Model" | "Lora" | "Cnet" | "Upscaler" + +export interface Model { + id: number + model_type: ModelType + filename: string + name?: string | null + version?: string | null + count: number +} + +export interface ProjectExtra { + id: number + fingerprint: string + path: string + watchfolder_id: number + image_count: number | null + last_id: number | null + filesize: number | null + modified: number | null + missing_on: number | null + excluded: boolean + name: string + full_path: string + is_missing: boolean + is_locked: boolean + is_ready: boolean +} + +export interface ImageExtra { + id: number + project_id: number + model_id: number | null + model_file: string | null + prompt: string + negative_prompt: string + num_frames: number | null + preview_id: number + node_id: number + has_depth: boolean + has_pose: boolean + has_color: boolean + has_custom: boolean + has_scribble: boolean + has_shuffle: boolean + start_width: number + start_height: number + upscaler_scale_factor: number | null + is_ready: boolean +} + +export interface ImageCount { + project_id: number + count: number +} + +export interface ListImagesResult { + counts: ImageCount[] | null + images: ImageExtra[] | null + total: number +} + +export type FilterOperator = + | "eq" + | "neq" + | "gt" + | "gte" + | "lt" + | "lte" + | "is" + | "isnot" + | "has" + | "hasall" + | "doesnothave" + +export type FilterTarget = + | "model" + | "lora" + | "control" + | "sampler" + | "content" + | "seed" + | "steps" + | "width" + | "height" + | "textguidance" + | "shift" + +export interface ListImagesFilter { + target: FilterTarget + operator: 
FilterOperator + value: string[] | number[] +} + +export interface ImagesSource { + projectIds?: number[] + search?: string + filters?: ListImagesFilter[] + sort?: string + direction?: "asc" | "desc" + count?: boolean + showVideo?: boolean + showImage?: boolean +} + +export interface WatchFolder { + id: number + path: string + recursive: boolean | null + last_updated: number | null + isMissing: boolean + isLocked: boolean + bookmark: string +} + +export interface TensorHistoryClip { + tensor_id: string + preview_id: number + clip_id: number + index_in_a_clip: number + row_id: number +} + +export interface TensorSize { + width: number + height: number + channels: number +} + +export interface Control { + file?: string + weight: number + guidance_start: number + guidance_end: number + no_prompt: boolean + global_average_pooling: boolean + down_sampling_rate: number + control_mode: string + target_blocks?: string[] + input_override: string +} + +export interface LoRA { + file?: string + weight: number + mode: string +} + +export interface TensorRaw { + tensor_type: number + data_type: number + format: number + width: number + height: number + channels: number + dim: ArrayBuffer + data: ArrayBuffer +} + +export interface DTImageFull { + id: number + prompt: string | null + negativePrompt: string | null + model: Model | null + project: ProjectExtra + config: DrawThingsConfig + groupedConfig: DrawThingsConfigGrouped + clipId: number + numFrames: number + node: XTensorHistoryNode + images: { + tensorId: string | null + previewId: number + maskId: string | null + depthMapId: string | null + scribbleId: string | null + poseId: string | null + colorPaletteId: string | null + customId: string | null + moodboardIds: string[] + } | null +} + +export interface XTensorHistoryNode { + lineage: number + logical_time: number + start_width: number + start_height: number + seed: number + steps: number + guidance_scale: number + strength: number + model: string | null + tensor_id: 
number + mask_id: number + wall_clock: string | null + text_edits: number + text_lineage: number + batch_size: number + sampler: number + hires_fix: boolean + hires_fix_start_width: number + hires_fix_start_height: number + hires_fix_strength: number + upscaler: string | null + scale_factor: number + depth_map_id: number + generated: boolean + image_guidance_scale: number + seed_mode: number + clip_skip: number + controls: Control[] | null + scribble_id: number + pose_id: number + loras: LoRA[] | null + color_palette_id: number + mask_blur: number + custom_id: number + face_restoration: string | null + clip_weight: number + negative_prompt_for_image_prior: boolean + image_prior_steps: number + data_stored: number + preview_id: number + content_offset_x: number + content_offset_y: number + scale_factor_by_120: number + refiner_model: string | null + original_image_height: number + original_image_width: number + crop_top: number + crop_left: number + target_image_height: number + target_image_width: number + aesthetic_score: number + negative_aesthetic_score: number + zero_negative_prompt: boolean + refiner_start: number + negative_original_image_height: number + negative_original_image_width: number + shuffle_data_stored: number + fps_id: number + motion_bucket_id: number + cond_aug: number + start_frame_cfg: number + num_frames: number + mask_blur_outset: number + sharpness: number + shift: number + stage_2_steps: number + stage_2_cfg: number + stage_2_shift: number + tiled_decoding: boolean + decoding_tile_width: number + decoding_tile_height: number + decoding_tile_overlap: number + stochastic_sampling_gamma: number + preserve_original_after_inpaint: boolean + tiled_diffusion: boolean + diffusion_tile_width: number + diffusion_tile_height: number + diffusion_tile_overlap: number + upscaler_scale_factor: number + script_session_id: number + t5_text_encoder: boolean + separate_clip_l: boolean + clip_l_text: string | null + separate_open_clip_g: boolean + 
open_clip_g_text: string | null + speed_up_with_guidance_embed: boolean + guidance_embed: number + resolution_dependent_shift: boolean + tea_cache_start: number + tea_cache_end: number + tea_cache_threshold: number + tea_cache: boolean + separate_t5: boolean + t5_text: string | null + tea_cache_max_skip_steps: number + text_prompt: string | null + negative_text_prompt: string | null + clip_id: number + index_in_a_clip: number + causal_inference_enabled: boolean + causal_inference: number + causal_inference_pad: number + cfg_zero_star: boolean + cfg_zero_init_steps: number + generation_time: number + reason: number +} + +export interface TensorHistoryExtra { + row_id: number + lineage: number + logical_time: number + tensor_id: string | null + mask_id: string | null + depth_map_id: string | null + scribble_id: string | null + pose_id: string | null + color_palette_id: string | null + custom_id: string | null + moodboard_ids: string[] + history: XTensorHistoryNode + project_path: string +} + +export type ScanProgress = { + projects_found: number + projects_scanned: number + images_found: number + images_scanned: number +} diff --git a/src/commands/bookmarks.ts b/src/commands/bookmarks.ts index c3b77b6..57672c5 100644 --- a/src/commands/bookmarks.ts +++ b/src/commands/bookmarks.ts @@ -6,33 +6,19 @@ export interface PickFolderResult { } /** - * Opens a native folder picker on macOS to select the Draw Things Documents folder. + * Opens a native folder picker on macOS. * Returns both the selected folder's path and a base64-encoded security-scoped bookmark. * * @param defaultPath Optional path to suggest in the picker. + * @param buttonText Optional text for the action button (default: "Select folder"). * @returns A PickFolderResult containing path and bookmark, or null if cancelled. 
*/ -export async function pickDrawThingsFolder(defaultPath?: string): Promise { - return await invoke("pick_draw_things_folder", { defaultPath }); -} - -/** - * Resolves a security-scoped bookmark and starts accessing the resource. - * Returns the local file path to the resource. - * The internal cache ensures stopAccessing... is called only when the app exits. - * - * @param bookmark The base64-encoded bookmark string to resolve. - * @returns The resolved local path. - */ -export async function resolveBookmark(bookmark: string): Promise { - return await invoke("resolve_bookmark", { bookmark }); -} - -/** - * Manually stops accessing a security-scoped bookmark and removes it from the bookmark manager. - * - * @param bookmark The base64-encoded bookmark string to release. - */ -export async function stopAccessingBookmark(bookmark: string): Promise { - return await invoke("stop_accessing_bookmark", { bookmark }); +export async function pickFolder(defaultPath?: string, buttonText?: string): Promise { + let path = defaultPath + if ((window as unknown as Record).__E2E_FILE_PATH__) { + path = `TESTPATH::${(window as unknown as Record).__E2E_FILE_PATH__}`; + (window as unknown as Record).__E2E_FILE_PATH__ = ""; // Clear it after use + // In E2E tests, we bypass the native picker and return a predefined path. 
+ } + return await invoke("pick_folder", { defaultPath: path, buttonText }); } diff --git a/src/commands/index.ts b/src/commands/index.ts index 62a0b3f..92168d2 100644 --- a/src/commands/index.ts +++ b/src/commands/index.ts @@ -1,3 +1,10 @@ -export * from './projects' -export * from './vid' -export * from './bookmarks' \ No newline at end of file +// export * from './projects' + +export * from "./bookmarks" +export * from "./vid" + +import DtpService from "./DtpService" + +export * from "./DtpServiceTypes" + +export { DtpService } diff --git a/src/commands/projects.ts b/src/commands/projects.ts deleted file mode 100644 index 1770bf8..0000000 --- a/src/commands/projects.ts +++ /dev/null @@ -1,395 +0,0 @@ -import { invoke } from "@tauri-apps/api/core" -import type { ProjectState } from "@/dtProjects/state/projects" -import type { - ImageExtra, - ListImagesResult, - ProjectExtra, - TensorHistoryClip, -} from "@/generated/types" -import type { DrawThingsConfig, DrawThingsConfigGrouped } from "@/types" -import type { ImagesSource as ListImagesOpts } from "../dtProjects/types" - -export type { ImageExtra, ListImagesResult, ProjectExtra } - -export type Control = { - file?: string - weight: number - guidance_start: number - guidance_end: number - no_prompt: boolean - global_average_pooling: boolean - down_sampling_rate: number - control_mode: string - target_blocks?: string[] - input_override: string -} - -export type LoRA = { - file?: string - weight: number - mode: string -} - -export type XTensorHistoryNode = { - lineage: number - logical_time: number - start_width: number - start_height: number - seed: number - steps: number - guidance_scale: number - strength: number - model?: string - tensor_id: number - mask_id: number - wall_clock?: string - text_edits: number - text_lineage: number - batch_size: number - sampler: number - hires_fix: boolean - hires_fix_start_width: number - hires_fix_start_height: number - hires_fix_strength: number - upscaler?: string - 
scale_factor: number - depth_map_id: number - generated: boolean - image_guidance_scale: number - seed_mode: number - clip_skip: number - controls?: Control[] - scribble_id: number - pose_id: number - loras?: LoRA[] - color_palette_id: number - mask_blur: number - custom_id: number - face_restoration?: string - clip_weight: number - negative_prompt_for_image_prior: boolean - image_prior_steps: number - data_stored: number - preview_id: number - content_offset_x: number - content_offset_y: number - scale_factor_by_120: number - refiner_model?: string - original_image_height: number - original_image_width: number - crop_top: number - crop_left: number - target_image_height: number - target_image_width: number - aesthetic_score: number - negative_aesthetic_score: number - zero_negative_prompt: boolean - refiner_start: number - negative_original_image_height: number - negative_original_image_width: number - shuffle_data_stored: number - fps_id: number - motion_bucket_id: number - cond_aug: number - start_frame_cfg: number - num_frames: number - mask_blur_outset: number - sharpness: number - shift: number - stage_2_steps: number - stage_2_cfg: number - stage_2_shift: number - tiled_decoding: boolean - decoding_tile_width: number - decoding_tile_height: number - decoding_tile_overlap: number - stochastic_sampling_gamma: number - preserve_original_after_inpaint: boolean - tiled_diffusion: boolean - diffusion_tile_width: number - diffusion_tile_height: number - diffusion_tile_overlap: number - upscaler_scale_factor: number - script_session_id: number - t5_text_encoder: boolean - separate_clip_l: boolean - clip_l_text?: string - separate_open_clip_g: boolean - open_clip_g_text?: string - speed_up_with_guidance_embed: boolean - guidance_embed: number - resolution_dependent_shift: boolean - tea_cache_start: number - tea_cache_end: number - tea_cache_threshold: number - tea_cache: boolean - separate_t5: boolean - t5_text?: string - tea_cache_max_skip_steps: number - 
text_prompt?: string - negative_text_prompt?: string - clip_id: number - index_in_a_clip: number - causal_inference_enabled: boolean - causal_inference: number - causal_inference_pad: number - cfg_zero_star: boolean - cfg_zero_init_steps: number - generation_time: number - reason: number -} - -export type TensorHistoryExtra = { - row_id: number - lineage: number - logical_time: number - tensor_id?: string - mask_id?: string - depth_map_id?: string - scribble_id?: string - pose_id?: string - color_palette_id?: string - custom_id?: string - moodboard_ids: string[] - history: XTensorHistoryNode - project_path: string -} - -export type DTImageFull = { - id: number - prompt?: string - negativePrompt?: string - model?: Model - project: ProjectState - config: DrawThingsConfig - groupedConfig: DrawThingsConfigGrouped - clipId: number - numFrames: number - node: XTensorHistoryNode - images?: { - tensorId?: string - previewId?: number - maskId?: string - depthMapId?: string - scribbleId?: string - poseId?: string - colorPaletteId?: string - customId?: string - moodboardIds?: string[] - } -} - -export type ScanProgress = { - projects_scanned: number - projects_total: number - project_path: string - images_scanned: number - images_total: number -} - -export type TensorRaw = { - tensor_type: number - data_type: number - format: number - width: number - height: number - channels: number - dim: ArrayBuffer - data: ArrayBuffer -} - -export type ListImagesOptions = { - projectIds?: number[] - nodeId?: number - sort?: string - direction?: string - model?: number[] - control?: number[] - lora?: number[] - search?: string - take?: number - skip?: number -} - -export type WatchFolder = { - id: number - path: string - recursive: boolean - last_updated?: number | null -} - -// -------------------- -// Command wrappers -// -------------------- - -export const pdb = { - addProject: async (path: string): Promise => { - try { - return await invoke("projects_db_project_add", { path }) - } - 
catch (e) { - if (e === "error communicating with database: Table not found") return undefined - console.error(e) - return undefined - } - }, - - removeProject: async (path: string): Promise => - invoke("projects_db_project_remove", { path }), - - listProjects: async (): Promise => invoke("projects_db_project_list"), - - scanProject: async ( - path: string, - fullScan = false, - filesize?: number, - modified?: number, - ): Promise => - invoke("projects_db_project_scan", { path, fullScan, filesize, modified }), - - updateExclude: async (id: number, exclude: boolean): Promise => - invoke("projects_db_project_update_exclude", { id, exclude }), - - updateMissingOn: async (paths: string[], missingOn: number | null): Promise => - invoke("projects_db_project_bulk_update_missing_on", { paths, missingOn }), - - listImages: async ( - source: MaybeReadonly, - skip: number, - take: number, - ): Promise => { - const result: ListImagesResult = await invoke("projects_db_image_list", { - ...source, - skip, - take, - }) - return result - }, - - getClip: async (imageId: number): Promise => - invoke("projects_db_get_clip", { imageId }), - - /** - * ignores projectIds, returns count of image matches in each project. - */ - listImagesCount: async (source: MaybeReadonly) => { - const opts = { ...source, projectIds: undefined, count: true } - const result: ListImagesResult = await invoke("projects_db_image_list", opts) - return result - }, - - rebuildIndex: async (): Promise => invoke("projects_db_image_rebuild_fts"), - - watchFolders: { - listAll: async (): Promise => invoke("projects_db_watch_folder_list"), - - add: async ( - path: string, - recursive: boolean, - ): Promise => - invoke("projects_db_watch_folder_add", { path, recursive }), - - remove: async (ids: number[] | number): Promise => - invoke("projects_db_watch_folder_remove", { ids: Array.isArray(ids) ? 
ids : [ids] }), - - update: async ( - id: number, - recursive?: boolean, - lastUpdated?: number, - ): Promise => - invoke("projects_db_watch_folder_update", { id, recursive, lastUpdated }), - }, - - scanModelInfo: async (filePath: string, modelType: ModelType): Promise => - invoke("projects_db_scan_model_info", { filePath, modelType }), - - listModels: async (modelType?: ModelType): Promise => - invoke("projects_db_list_models", { modelType }), -} - -export type ModelType = "Model" | "Lora" | "Cnet" | "Upscaler" - -export type Model = { - id: number - model_type: ModelType - filename: string - name?: string - version?: string - count?: number -} - -export type ModelInfo = { - file: string - name: string - version: string - model_type: ModelType -} - -export type TensorSize = { - width: number - height: number - channels: number -} - -export const dtProject = { - // #unused - getTensorHistory: async ( - project_file: string, - index: number, - count: number, - ): Promise[]> => - invoke("dt_project_get_tensor_history", { project_file, index, count }), - - // #unused - getThumbHalf: async (project_file: string, thumb_id: number): Promise => - invoke("dt_project_get_thumb_half", { project_file, thumb_id }), - - getHistoryFull: async (projectFile: string, rowId: number): Promise => - invoke("dt_project_get_history_full", { projectFile, rowId }), - - // #unused - getTensorRaw: async ( - projectFile: string, - projectId: number, - tensorId: string, - ): Promise => - invoke("dt_project_get_tensor_raw", { projectFile, projectId, tensorId }), - - getTensorSize: async (project: string | number, tensorId: string): Promise => { - const opts = { - tensorId, - projectId: typeof project === "string" ? undefined : project, - projectFile: typeof project === "string" ? 
project : undefined, - } - return invoke("dt_project_get_tensor_size", opts) - }, - - decodeTensor: async ( - project: string | number, - tensorId: string, - asPng: boolean, - nodeId?: number, - ): Promise> => { - const opts = { - tensorId, - projectId: typeof project === "string" ? undefined : project, - projectFile: typeof project === "string" ? project : undefined, - asPng, - nodeId, - } - return new Uint8Array(await invoke("dt_project_decode_tensor", opts)) - }, - - getPredecessorCandidates: async ( - projectFile: string, - rowId: number, - lineage: number, - logicalTime: number, - ): Promise => - invoke("dt_project_find_predecessor_candidates", { - projectFile, - rowId, - lineage, - logicalTime, - }), -} diff --git a/src/commands/urls.ts b/src/commands/urls.ts index bbb72b0..e10775d 100644 --- a/src/commands/urls.ts +++ b/src/commands/urls.ts @@ -1,4 +1,4 @@ -import type { ImageExtra } from "./projects" +import type { ImageExtra } from "./DtpServiceTypes" function thumb(image: ImageExtra): string function thumb(projectId: number, previewId: number): string diff --git a/src/components/DataItem.tsx b/src/components/DataItem.tsx index 7a4e78b..98f152d 100644 --- a/src/components/DataItem.tsx +++ b/src/components/DataItem.tsx @@ -540,37 +540,16 @@ const templates = { const { value, ...rest } = props if (!value || (value.guidance === 0 && value.steps === 0)) return null return null - return ( - - ) }, ImagePrior: (props: DataItemTemplateProps<"imagePrior">) => { const { value, ...rest } = props if (!value || value.steps === 0) return null return null - return ( - - ) }, AestheticScore: (props: DataItemTemplateProps<"aestheticScore">) => { const { value, ...rest } = props if (!value || (value.positive === 0 && value.negative === 0)) return null return null - return ( - - ) }, } diff --git a/src/components/FloatIndicator.tsx b/src/components/FloatIndicator.tsx deleted file mode 100644 index f120366..0000000 --- a/src/components/FloatIndicator.tsx +++ /dev/null @@ 
-1,189 +0,0 @@ -import { IconButton } from "@/components" -import { chakra, HStack } from "@chakra-ui/react" -import { AnimatePresence, motion } from "motion/react" -import { createContext, use, useEffect, useMemo, useState, type ComponentProps } from "react" -import { FiX } from "@/components/icons/icons" -import { useDTP } from "../dtProjects/state/context" - -const dur = 0.2 - -interface FloatIndicatorProps extends ChakraProps { - children: React.ReactNode -} - -function Root(props: FloatIndicatorProps) { - const { children, ...restProps } = props - - const [hasAltExtension, setHasAltExtension] = useState(false) - - return ( - - - - {children} - - - - ) -} - -const IndicatorContext = createContext(undefined) - -const IndicatorWrapper = chakra( - motion.div, - { - base: { - display: "flex", - gap: 0, - padding: 0, - zIndex: 0, - }, - }, - { - defaultProps: { - transition: { duration: dur, ease: "easeInOut" }, - initial: { opacity: 0 }, - animate: "normal", - exit: { opacity: 0 }, - variants: { - hovering: {}, - normal: { opacity: 1 }, - }, - className: "group", - layout: true, - whileHover: "hovering", - }, - }, -) - -const Base = chakra( - motion.div, - { - base: { - display: "grid", - gridTemplateColumns: "1fr auto", - gap: 0, - padding: 0, - flexDirection: "row", - borderLeftRadius: "md", - borderRightRadius: "md", - border: "1px solid #777777FF", - position: "relative", - boxShadow: "0px 0px 8px -2px #00000077, 0px 2px 5px -3px #00000077", - color: "fg.3", - bgColor: "bg.deep", - fontSize: "1rem", - alignItems: "stretch", - justifyContent: "stretch", - overflow: "clip", - }, - }, - { - defaultProps: { - variants: {}, - transition: { - duration: dur, - delay: dur * 0.8, - visibility: { duration: 0, delay: dur * 0.8 }, - }, - }, - }, -) - -interface ExtensionProps extends ChakraProps { - altExt?: boolean -} -const Extension = (props: ExtensionProps) => { - const { altExt, children, ...restProps } = props - const { hasAltExtension, setHasAltExtension } = 
use(IndicatorContext) - - const variants = useMemo(() => { - if (!hasAltExtension) { - return { - normal: () => {}, - hovering: () => {}, - } - } - const visible = () => ({ - opacity: 1, - }) - const hidden = () => ({ - opacity: 0, - }) - if (altExt) { - return { - normal: hidden, - hovering: visible, - } - } else { - return { - normal: visible, - hovering: hidden, - } - } - }, [hasAltExtension, altExt]) - - useEffect(() => { - if (altExt) setHasAltExtension(true) - - return () => { - setHasAltExtension(false) - } - }, [altExt, setHasAltExtension]) - - return ( - - {children} - - ) -} - -const ExtensionBase = chakra(motion.div, { - base: { - gridArea: "1/2", - // margin: "0.25rem", - // marginLeft: "0.5rem", - margin: 1, - overflow: "clip", - alignContent: "center", - justifyContent: "center", - }, -}) - -const Label = chakra( - motion.div, - { - base: { - bgColor: "bg.3", - color: "fg.2", - fontWeight: 500, - paddingX: 2, - outline: "1px solid #777777ff", - zIndex: 1, - cursor: "pointer", - borderLeftRadius: "md", - borderRightRadius: "xl", - boxShadow: "0px 0px 6px -2px #000000", - alignContent: "center", - height: "2rem", - }, - }, - { - defaultProps: { - // layout: true, - transition: { duration: dur }, - variants: { - hovering: {}, - normal: {}, - }, - }, - }, -) - -const FloatIndicator = { - Root, - Label, - Extension -} - -export default FloatIndicator diff --git a/src/components/PanelList.test.tsx b/src/components/PanelList.test.tsx deleted file mode 100644 index 52529d1..0000000 --- a/src/components/PanelList.test.tsx +++ /dev/null @@ -1,106 +0,0 @@ -import { render, screen, fireEvent } from "@testing-library/react" -import { describe, it, expect, vi } from "vitest" -import PanelList, { PanelListCommand } from "./PanelList" -import { proxy, useSnapshot } from "valtio" -import { PanelListItem } from '.' 
- -// Mock useSelectableGroup -vi.mock("@/hooks/useSelectableV", () => ({ - useSelectableGroup: (itemsSnap: any[], getItems: any) => ({ - SelectableGroup: ({ children }: any) =>
{children}
, - selectedItems: itemsSnap.filter((i: any) => i.selected), - }), -})) - -vi.mock("@chakra-ui/react", () => ({ - HStack: ({ children, ...props }: any) => ( -
- {children} -
- ), - VStack: ({ children, ...props }: any) => ( -
- {children} -
- ), -})) - -// Mock other components -vi.mock(".", () => ({ - IconButton: ({ children, onClick, disabled }: any) => ( - - ), - PaneListContainer: ({ children }: any) =>
{children}
, - PanelListItem: ({ children }: any) =>
{children}
, - PanelSectionHeader: ({ children }: any) => ( -
{children}
- ), - Tooltip: ({ children, tip }: any) =>
{children}
, -})) - -vi.mock("./common", () => ({ - PaneListScrollContainer: ({ children }: any) => ( -
{children}
- ), - PanelListScrollContent: ({ children }: any) =>
{children}
, -})) - -describe("PanelList", () => { - it("renders header correctly", () => { - const items = proxy([]) - render( []} itemsSnap={items} header="Test Header" />) - expect(screen.getByText("Test Header")).toBeInTheDocument() - }) - - it("renders commands and handles clicks", () => { - const items = proxy([{ id: 1, selected: true }]) - const onClickMock = vi.fn() - const command: PanelListCommand = { - id: "cmd1", - icon: () => CmdIcon, - onClick: onClickMock, - } - - render( - [{ id: 1, selected: true }]} - itemsSnap={items} - commands={[command]} - />, - ) - - const button = screen.getByTestId("icon-button") - expect(button).toBeInTheDocument() - expect(screen.getByText("CmdIcon")).toBeInTheDocument() - - fireEvent.click(button) - expect(onClickMock).toHaveBeenCalled() - }) - - it("disables command when selection requirement not met", () => { - const items = proxy(["something"]) // No selection - const itemsSnap = useSnapshot(items) - const command: PanelListCommand = { - id: "cmd1", - icon: () => CmdIcon, - onClick: vi.fn(), - requiresSelection: true, - } - - render( - []} itemsSnap={itemsSnap} commands={[command]}> - {items.map((item) => ({item}))} - , - ) - - const button = screen.getByTestId("icon-button") - expect(button).toBeDisabled() - - const item = screen.getByText("something") - expect(item).toBeInTheDocument() - item.click() - expect(button).not.toBeDisabled() - }) -}) diff --git a/src/components/PanelList2.tsx b/src/components/PanelList2.tsx new file mode 100644 index 0000000..78986b2 --- /dev/null +++ b/src/components/PanelList2.tsx @@ -0,0 +1,205 @@ +import { HStack, Spacer } from "@chakra-ui/react" +import { type ComponentType, useEffect, useRef } from "react" +import type { Snapshot } from "valtio" +import type { IconType } from "@/components/icons/icons" +import { PiInfo } from "@/components/icons/icons" +import type { Selectable } from "@/hooks/useSelectableV" +import { IconButton, PaneListContainer, PanelListItem, PanelSectionHeader, Tooltip } 
from "." +import { PaneListScrollContainer, PanelListScrollContent, PanelSection } from "./common" + +interface PanelListComponentProps extends ChakraProps { + emptyListText?: string | boolean + commands?: PanelListCommandItem[] + commandContext?: C + header?: string + headerInfo?: string + keyFn?: (item: T | Snapshot) => string | number + /** must be a valtio proxy (or a function that returns one) */ + itemsState: ValueOrGetter + onSelectionChanged?: (selected: T[]) => void + clearSelection?: unknown + selectionMode?: "multipleModifier" | "multipleToggle" | "single" + selectedItems?: Snapshot +} + +export type PanelListCommandItem = PanelListCommand | "spacer" + +export interface PanelListCommand { + id: string + ariaLabel?: string + icon?: IconType | ComponentType + getIcon?: (selected: Snapshot, context?: C) => IconType | ComponentType + requiresSelection?: boolean + requiresSingleSelection?: boolean + getEnabled?: (selected: Snapshot, context?: C) => boolean + /** if present, tipTitle and tipText will be ignored */ + tip?: React.ReactNode + tipTitle?: string + tipText?: string + getTip?: (selected: Snapshot, context?: C) => React.ReactNode + getTipTitle?: (selected: Snapshot, context?: C) => string + getTipText?: (selected: Snapshot, context?: C) => string + onClick: (selected: Snapshot, context?: C) => void +} + +function PanelList(props: PanelListComponentProps) { + const { + children, + emptyListText: emptyListTextProp, + commands, + commandContext, + header, + headerInfo, + keyFn, + itemsState: itemsProp, + onSelectionChanged, + clearSelection, + selectionMode = "multipleModifier", + selectedItems = [], + ...boxProps + } = props + + const scrollContainerRef = useRef(null) + + useEffect(() => { + const el = scrollContainerRef.current + if (!el) return + + const update = () => { + const canScrollTop = el.scrollTop > 0 + const canScrollBottom = el.scrollTop + el.clientHeight < el.scrollHeight + + el.dataset.top = canScrollTop ? 
"1" : "0" + el.dataset.bottom = canScrollBottom ? "1" : "0" + } + + update() + el.addEventListener("scroll", update) + + const ro = new ResizeObserver(update) + ro.observe(el) + + return () => { + el.removeEventListener("scroll", update) + ro.disconnect() + } + }, []) + + const areItemsSelected = selectedItems.length > 0 + + const emptyListText = + emptyListTextProp === false + ? null + : typeof emptyListTextProp === "string" + ? emptyListTextProp + : "(No items)" + + return ( + + {header && ( + + {header} + {headerInfo && ( + + + + )} + + )} + + + {children} + + + {!emptyListText && ( + + {emptyListText} + + )} + + + {commands?.map((command, i) => { + if (command === "spacer") return + + let enabled = true + if (command.requiresSelection && !areItemsSelected) enabled = false + if (command.requiresSingleSelection && selectedItems.length !== 1) + enabled = false + if (command.getEnabled) + enabled = command.getEnabled(selectedItems, commandContext) + + const Icon = command.getIcon + ? command.getIcon(selectedItems, commandContext) + : command.icon + const tip = command.getTip + ? command.getTip(selectedItems, commandContext) + : command.tip + const tipTitle = command.getTipTitle + ? command.getTipTitle(selectedItems, commandContext) + : command.tipTitle + const tipText = command.getTipText + ? 
command.getTipText(selectedItems, commandContext) + : command.tipText + + return ( + command.onClick(selectedItems, commandContext)} + disabled={!enabled} + tip={tip} + tipTitle={tipTitle} + tipText={tipText} + > + {Icon && } + + ) + })} + + + + ) +} + +export default PanelList diff --git a/src/components/Pose.tsx b/src/components/Pose.tsx index 0a6b609..d451c9d 100644 --- a/src/components/Pose.tsx +++ b/src/components/Pose.tsx @@ -1,30 +1,28 @@ -import { Box, Image } from "@chakra-ui/react" +import { Image } from "@chakra-ui/react" import { useEffect, useState } from "react" -import { dtProject } from "@/commands" +import { DtpService } from "@/commands" import { uint8ArrayToBase64 } from "@/utils/helpers" import { drawPose, pointsToPose, tensorToPoints } from "@/utils/pose" interface PoseImageComponentProps extends ChakraProps { - projectPath?: string projectId?: number tensorId?: string } function PoseImage(props: PoseImageComponentProps) { - const { projectPath, projectId, tensorId, ...restProps } = props + const { projectId, tensorId, ...restProps } = props const [src, setSrc] = useState(undefined) useEffect(() => { - const project = projectPath ?? 
projectId - if (project && tensorId) { - dtProject.decodeTensor(project, tensorId, false).then(async (data) => { + if (projectId && tensorId) { + DtpService.decodeTensor(projectId, tensorId, false).then(async (data) => { const points = tensorToPoints(data) const pose = pointsToPose(points, 256, 256) const image = await drawPose(pose, 4) if (image) setSrc(`data:image/png;base64,${await uint8ArrayToBase64(image)}`) }) } - }, [projectPath, tensorId, projectId]) + }, [tensorId, projectId]) return } diff --git a/src/components/common.tsx b/src/components/common.tsx index 6f76412..9a0a42c 100644 --- a/src/components/common.tsx +++ b/src/components/common.tsx @@ -89,6 +89,8 @@ export const CheckRoot = chakra( export const PaneListContainer = chakra("div", { base: { + position: "relative", + height: "auto", maxHeight: "100%", width: "100%", @@ -113,7 +115,8 @@ export const PaneListScrollContainer = chakra( "div", { base: { - height: "100%", + position: "relative", + // height: "100%", width: "100%", paddingY: "1px", gap: 0, @@ -133,13 +136,13 @@ export const PaneListScrollContainer = chakra( export const PanelListScrollContent = chakra("div", { base: { height: "auto", - bgColor: "bg.deep/90", + bgColor: "bg.deep/50", // minHeight: "100%", display: "flex", flexDirection: "column", justifyContent: "flex-start", alignItems: "stretch", - gap: 0.5, + gap: 0, }, }) diff --git a/src/components/icons/icons.tsx b/src/components/icons/icons.tsx index 22c7476..f151f66 100644 --- a/src/components/icons/icons.tsx +++ b/src/components/icons/icons.tsx @@ -16,9 +16,10 @@ export { GiNeedleDrill } from "react-icons/gi" export { GoGear } from "react-icons/go" export type { IconType } from "react-icons/lib" export { LuFolderTree, LuMoon, LuSun, LuX } from "react-icons/lu" -export { MdBlock, MdImageSearch } from "react-icons/md" +export { MdBlock, MdDoNotDisturbOn, MdImageSearch } from "react-icons/md" export { PiCoffee, + PiEject, PiFilmStrip, PiImage, PiImages, diff --git 
a/src/components/index.ts b/src/components/index.ts index 689dd85..57adf44 100644 --- a/src/components/index.ts +++ b/src/components/index.ts @@ -8,7 +8,6 @@ import * as Preview from "./preview" import SliderWithInput from "./SliderWithInput" import Sidebar from "./sidebar/Sidebar" import Tooltip from "./Tooltip" -import VirtualizedList from "./virtualizedList/VirtualizedList" export const { CheckRoot, @@ -26,7 +25,6 @@ export { Tooltip, SliderWithInput, IconButton, - VirtualizedList, Sidebar, Preview, MeasureGrid, diff --git a/src/components/preview/index.tsx b/src/components/preview/index.tsx index 371bc9b..a0f0e40 100644 --- a/src/components/preview/index.tsx +++ b/src/components/preview/index.tsx @@ -350,7 +350,7 @@ export function contain( } as DOMRect } -export function DotSpinner(props) { +export function DotSpinner(props: BoxProps) { const { style, ...rest } = props return ( @@ -378,7 +378,7 @@ const loopTrans = (delay: number) => }) as MotionProps["transition"] function Dot(props: Record) { - const { delay, cx, ix, iy, dur } = props + const { delay, cx } = props return ( { + onPointerDown: (e: React.PointerEvent) => { e.stopPropagation() controls.togglePlayPause() }, - onClick: (e) => e.stopPropagation(), + onClick: (e: React.MouseEvent) => e.stopPropagation(), } : {} diff --git a/src/components/video/context.ts b/src/components/video/context.ts index 6341eb7..a6a3016 100644 --- a/src/components/video/context.ts +++ b/src/components/video/context.ts @@ -1,7 +1,7 @@ import { createContext, useContext, useEffect, useRef } from "react" -import { pdb } from "@/commands" +import type { ImageExtra } from "@/commands" +import DTPService from '@/commands/DtpService' import urls from "@/commands/urls" -import type { ImageExtra } from "@/generated/types" import { useProxyRef } from "@/hooks/valtioHooks" import { everyNth } from "@/utils/helpers" import { useFrameAnimation } from "./hooks" @@ -91,7 +91,7 @@ export function useCreateVideoContext(opts: 
UseCreateVideoContextOpts) { useEffect(() => { if (!image) return - pdb.getClip(image.id).then(async (data) => { + DTPService.getClip(image.id).then(async (data) => { if (!image) return if (!imgRef.current) return diff --git a/src/components/virtualizedList/PVLIst.tsx b/src/components/virtualizedList/PVLIst.tsx deleted file mode 100644 index da91016..0000000 --- a/src/components/virtualizedList/PVLIst.tsx +++ /dev/null @@ -1,314 +0,0 @@ -import { Box, chakra, VStack } from "@chakra-ui/react" -import { useCallback, useEffect, useRef } from "react" -import { proxy, useSnapshot } from "valtio" -import { proxyMap } from "valtio/utils" -import { usePagedItemSource } from "./usePagedItemSource" - -export interface PVListProps extends ChakraProps { - itemComponent: PVListItemComponent - /** number of screens */ - overscan?: number - keyFn?: (item: T, index: number) => string | number - initialRenderCount?: number - itemProps?: P - totalCount: number - pageSize: number - getItems: (skip: number, take: number) => Promise -} - -export type PVListItemComponent = React.ComponentType> - -export interface PVListItemProps { - value: T | Readonly | null - index: number - itemProps: P - onSizeChanged?: (index: number, isBaseSize: boolean) => void -} - -type ProxyMap = ReturnType> -type StateProxy = { - preSpacerHeight: number - postSpacerHeight: number - minThreshold: number - maxThreshold: number - firstIndex: number - lastIndex: number - visibleHeight: number - expanded: ReturnType> -} -function PVList(props: PVListProps) { - const { - itemComponent, - keyFn, - initialRenderCount = 50, - overscan = 2, - itemProps, - totalCount, - pageSize, - getItems, - ...restProps - } = props - const Item = itemComponent - - const { renderItems, setRenderWindow } = usePagedItemSource({ getItems, pageSize, totalCount }) - - const stateRef = useRef>(null) - if (stateRef.current === null) { - stateRef.current = proxy({ - preSpacerHeight: 0, - minThreshold: 0, - firstIndex: 0, - lastIndex: 
initialRenderCount, - maxThreshold: 0, - postSpacerHeight: 0, - visibleHeight: 1, - expanded: proxyMap(), - }) - } - const state = stateRef.current as StateProxy - const snap = useSnapshot(state) - - const scrollContainerRef = useRef(null) - const scrollContentRef = useRef(null) - const topSpaceRef = useRef(null) - const bottomSpaceRef = useRef(null) - - const calcFirstAndLastIndex = useCallback(() => { - const container = scrollContainerRef.current - const content = scrollContentRef.current - if (!container || !content) return { first: -1, last: -1 } - - const scrollTop = container.scrollTop - const scrollBottom = scrollTop + container.clientHeight - const itemHeight = getItemHeight(content, state.expanded, state.firstIndex) - - let first = -1 - let last = -1 - - let t = 0 - for (let i = 0; i < totalCount; i++) { - const h = state.expanded.get(i) ?? itemHeight - t += h - if (first === -1 && t >= scrollTop) first = i - if (t >= scrollBottom) { - last = i - break - } - } - - return { first, last } - }, [totalCount, state]) - - const recalculate = useCallback(() => { - const scrollContent = scrollContentRef.current - const scrollContainer = scrollContainerRef.current - if (!scrollContent || !scrollContainer) return - - const { first, last } = calcFirstAndLastIndex() - if (first === -1 || last === -1) return - const visibleItemsCount = last - first + 1 - - const itemHeight = getItemHeight(scrollContent, state.expanded, state.firstIndex) - - state.firstIndex = Math.max(0, first - visibleItemsCount * overscan) - state.lastIndex = Math.min(totalCount, last + visibleItemsCount * overscan) - setRenderWindow(state.firstIndex, state.lastIndex) - - const [pre, mid, post] = calcRangeHeights( - state.firstIndex, - state.lastIndex, - totalCount, - itemHeight, - state.expanded, - ) - - state.preSpacerHeight = pre - state.postSpacerHeight = post - - state.minThreshold = (scrollContainer.scrollTop + pre) / 2 - const scrollBottom = scrollContainer.scrollTop + 
scrollContainer.clientHeight - state.maxThreshold = (scrollBottom + pre + mid) / 2 - scrollContainer.clientHeight - }, [totalCount, overscan, state, calcFirstAndLastIndex, setRenderWindow]) - - const handleScroll = useCallback( - (e: React.UIEvent) => { - if ( - e.currentTarget.scrollTop < state.minThreshold || - e.currentTarget.scrollTop > state.maxThreshold - ) { - recalculate() - } - }, - [recalculate, state], - ) - - useEffect(() => { - recalculate() - }, [recalculate]) - - useEffect(() => { - if (!scrollContainerRef.current) return - const ro = new ResizeObserver(() => { - if (!scrollContainerRef.current) return - state.visibleHeight = scrollContainerRef.current.clientHeight - recalculate() - }) - ro.observe(scrollContainerRef.current) - - return () => ro.disconnect() - }, [state, recalculate]) - - const handleSizeChanged = useCallback( - (index: number, baseSize: boolean) => { - setTimeout((res) => { - const actualIndex = index - snap.firstIndex + 1 - - if (baseSize) { - state.expanded.delete(index) - return - } - - const element = scrollContentRef.current?.children[actualIndex] as HTMLDivElement - const nextElement = scrollContainerRef.current?.firstElementChild?.children[ - actualIndex + 1 - ] as HTMLDivElement - state.expanded.set(index, nextElement?.offsetTop - element?.offsetTop) - }, 1000) - }, - [state, snap.firstIndex], - ) - - return ( - handleScroll(e)} - {...restProps} - > - - - {renderItems.map((item, i) => { - const index = i + snap.firstIndex - return ( - - ) - })} -
- {snap.postSpacerHeight} {totalCount} -
-
- - {/* {createPortal( - - Items: {items.length}
- Current Item: {snap.currentItem}
- Updates: {snap.updates}
- Rendered: {snap.firstIndex} to {snap.lastIndex}
- Height: {snap.visibleHeight} -
, - document.getElementById("root"), - )} */} -
- ) -} - -const Container = chakra("div", { - base: { - overflowY: "auto", - }, -}) - -const Content = chakra("div", { - base: { - width: "100%", - minHeight: "100%", - overflowY: "visible", - display: "grid", - gridTemplateColumns: "1fr", - justifyContent: "flex-start", - alignItems: "stretch", - gap: 0, - }, -}) - -export default PVList - -function getItemHeight( - container: HTMLDivElement | null, - expanded: Map, - first: number, -) { - if (!container) return 1 - - if (!container || container.childNodes.length <= 2) return 1 - if (container.childNodes.length === 3) - return (container.childNodes[1] as HTMLDivElement).clientHeight - - // instead of getting the item height, I need the distance to the next item - // and thanks to the spacer there will always be a next item so I don't have to worry about checking - let actual = -1 - for (let i = 1; i < container.childNodes.length - 1; i++) { - actual = first + i - 1 - if (expanded.has(actual)) continue - - const a = container.childNodes[i] as HTMLDivElement - const b = container.childNodes[i + 1] as HTMLDivElement - - return b.offsetTop - a.offsetTop - } - - return 1 -} - -function calcRangeHeights( - first: number, - last: number, - length: number, - baseHeight: number, - expanded: ProxyMap, -) { - const [pre, mid, post] = sumGroups(expanded, first, last) - - const preCount = first - const preHeight = (preCount - pre.count) * baseHeight + pre.sum - - const midCount = last - first + 1 - const midHeight = (midCount - mid.count) * baseHeight + mid.sum - - const postCount = length - last - 1 - const postHeight = (postCount - post.count) * baseHeight + post.sum - - return [preHeight, midHeight, postHeight] -} - -function sumGroups(map: Map, pre: number, post: number) { - const preGroup = { sum: 0, count: 0 } - const midGroup = { sum: 0, count: 0 } - const postGroup = { sum: 0, count: 0 } - - for (const [k, v] of map.entries()) { - if (k < pre) { - preGroup.sum += v - preGroup.count++ - } else if (k > post) { - 
postGroup.sum += v - postGroup.count++ - } else { - midGroup.sum += v - midGroup.count++ - } - } - - return [preGroup, midGroup, postGroup] -} diff --git a/src/components/virtualizedList/PagedItemSource.ts b/src/components/virtualizedList/PagedItemSource.ts index 280b809..3a3fa75 100644 --- a/src/components/virtualizedList/PagedItemSource.ts +++ b/src/components/virtualizedList/PagedItemSource.ts @@ -227,6 +227,7 @@ export class EmptyItemSource implements IItemSource { totalCount: 0, activeItemIndex: undefined, activeItem: undefined, + hasInitialLoad: true, }) } diff --git a/src/components/virtualizedList/VirtualizedList.tsx b/src/components/virtualizedList/VirtualizedList.tsx deleted file mode 100644 index ca34f5f..0000000 --- a/src/components/virtualizedList/VirtualizedList.tsx +++ /dev/null @@ -1,309 +0,0 @@ -import { Box, chakra, VStack } from "@chakra-ui/react" -import { type JSX, useCallback, useEffect, useRef } from "react" -import { createPortal } from "react-dom" -import { proxy, useSnapshot } from "valtio" -import { proxyMap } from "valtio/utils" - -export interface VirtualizedListProps extends ChakraProps { - items: T[] | Readonly - itemComponent: React.ComponentType> - /** number of screens */ - overscan?: number - keyFn?: (item: T, index: number) => string | number - initialRenderCount?: number - itemProps?: P -} - -export interface VirtualizedListItemProps { - items: T[] | Readonly - value: T | Readonly - index: number - itemProps: P - onSizeChanged?: (index: number, isBaseSize: boolean) => void -} - -type ProxyMap = ReturnType> -type StateProxy = { - currentItem: number - updates: number - preSpacerHeight: number - postSpacerHeight: number - minThreshold: number - maxThreshold: number - firstIndex: number - lastIndex: number - visibleHeight: number - expanded: ReturnType> -} -function VirtualizedList>(props: VirtualizedListProps) { - const { - itemComponent, - items, - keyFn, - initialRenderCount = 50, - overscan = 2, - itemProps, - ...restProps - } 
= props - const Item = itemComponent - - const stateRef = useRef(null) - if (stateRef.current === null) { - stateRef.current = proxy({ - currentItem: 0, - updates: 0, - preSpacerHeight: 0, - minThreshold: 0, - firstIndex: 0, - lastIndex: initialRenderCount, - maxThreshold: 0, - postSpacerHeight: 0, - visibleHeight: 1, - expanded: proxyMap(), - }) - } - const state = stateRef.current as StateProxy - const snap = useSnapshot(state) - - const scrollContainerRef = useRef(null) - const scrollContentRef = useRef(null) - const topSpaceRef = useRef(null) - const bottomSpaceRef = useRef(null) - - const calcFirstAndLastIndex = useCallback(() => { - const container = scrollContainerRef.current - const content = scrollContentRef.current - if (!container || !content) return { first: -1, last: -1 } - - const scrollTop = container.scrollTop - const scrollBottom = scrollTop + container.clientHeight - const itemHeight = getItemHeight(content, state.expanded, state.firstIndex) - - let first = -1 - let last = -1 - - let t = 0 - for (let i = 0; i < items.length; i++) { - const h = state.expanded.get(i) ?? 
itemHeight - t += h - if (first === -1 && t >= scrollTop) first = i - if (t >= scrollBottom) { - last = i - break - } - } - - return { first, last } - }, [items.length, state]) - - const recalculate = useCallback(() => { - const scrollContent = scrollContentRef.current - const scrollContainer = scrollContainerRef.current - if (!scrollContent || !scrollContainer) return - - const { first, last } = calcFirstAndLastIndex() - if (first === -1 || last === -1) return - const visibleItemsCount = last - first + 1 - - const itemHeight = getItemHeight(scrollContent, state.expanded, state.firstIndex) - - state.firstIndex = Math.max(0, first - visibleItemsCount * overscan) - state.lastIndex = Math.min(items.length, last + visibleItemsCount * overscan) - - const [pre, mid, post] = calcRangeHeights( - state.firstIndex, - state.lastIndex, - items.length, - itemHeight, - state.expanded, - ) - - state.preSpacerHeight = pre - state.postSpacerHeight = post - - state.minThreshold = (scrollContainer.scrollTop + pre) / 2 - const scrollBottom = scrollContainer.scrollTop + scrollContainer.clientHeight - state.maxThreshold = (scrollBottom + pre + mid) / 2 - scrollContainer.clientHeight - }, [items.length, overscan, state, calcFirstAndLastIndex]) - - const handleScroll = useCallback( - (e: React.UIEvent) => { - if ( - e.currentTarget.scrollTop < state.minThreshold || - e.currentTarget.scrollTop > state.maxThreshold - ) { - recalculate() - } - }, - [recalculate, state], - ) - - useEffect(() => { - recalculate() - }, [recalculate]) - - useEffect(() => { - if (!scrollContainerRef.current) return - const ro = new ResizeObserver(() => { - if (!scrollContainerRef.current) return - state.visibleHeight = scrollContainerRef.current.clientHeight - recalculate() - }) - ro.observe(scrollContainerRef.current) - - return () => ro.disconnect() - }, [state, recalculate]) - - const handleSizeChanged = useCallback( - (index: number, baseSize: boolean) => { - const actualIndex = index - snap.firstIndex + 1 - 
- if (baseSize) { - state.expanded.delete(index) - return - } - - const element = scrollContentRef.current?.children[actualIndex] as HTMLDivElement - const nextElement = scrollContainerRef.current?.firstElementChild?.children[ - actualIndex + 1 - ] as HTMLDivElement - state.expanded.set(index, nextElement?.offsetTop - element?.offsetTop) - }, - [state, snap.firstIndex], - ) - - return ( - handleScroll(e)} - {...restProps} - > - - - {items.slice(snap.firstIndex, snap.lastIndex).map((item, i) => { - const index = i + snap.firstIndex - return ( - - ) - })} -
- {snap.postSpacerHeight} {items.length} -
-
- - {/* {createPortal( - - Items: {items.length}
- Current Item: {snap.currentItem}
- Updates: {snap.updates}
- Rendered: {snap.firstIndex} to {snap.lastIndex}
- Height: {snap.visibleHeight} -
, - document.getElementById("root"), - )} */} -
- ) -} - -const Container = chakra("div", { - base: { - overflowY: "auto", - }, -}) - -const Content = chakra("div", { - base: { - width: "100%", - minHeight: "100%", - overflowY: "visible", - display: "grid", - gridTemplateColumns: "1fr", - justifyContent: "flex-start", - alignItems: "stretch", - gap: 0, - }, -}) - -export default VirtualizedList - -function getItemHeight( - container: HTMLDivElement | null, - expanded: Map, - first: number, -) { - if (!container) return 1 - - if (!container || container.childNodes.length <= 2) return 1 - if (container.childNodes.length === 3) - return (container.childNodes[1] as HTMLDivElement).clientHeight - - // instead of getting the item height, I need the distance to the next item - // and thanks to the spacer there will always be a next item so I don't have to worry about checking - let actual = -1 - for (let i = 1; i < container.childNodes.length - 1; i++) { - actual = first + i - 1 - if (expanded.has(actual)) continue - - const a = container.childNodes[i] as HTMLDivElement - const b = container.childNodes[i + 1] as HTMLDivElement - - return b.offsetTop - a.offsetTop - } - - return 1 -} - -function calcRangeHeights( - first: number, - last: number, - length: number, - baseHeight: number, - expanded: ProxyMap, -) { - const [pre, mid, post] = sumGroups(expanded, first, last) - - const preCount = first - const preHeight = (preCount - pre.count) * baseHeight + pre.sum - - const midCount = last - first + 1 - const midHeight = (midCount - mid.count) * baseHeight + mid.sum - - const postCount = length - last - 1 - const postHeight = (postCount - post.count) * baseHeight + post.sum - - return [preHeight, midHeight, postHeight] -} - -function sumGroups(map: Map, pre: number, post: number) { - const preGroup = { sum: 0, count: 0 } - const midGroup = { sum: 0, count: 0 } - const postGroup = { sum: 0, count: 0 } - - for (const [k, v] of map.entries()) { - if (k < pre) { - preGroup.sum += v - preGroup.count++ - } else if (k > post) { - 
postGroup.sum += v - postGroup.count++ - } else { - midGroup.sum += v - midGroup.count++ - } - } - - return [preGroup, midGroup, postGroup] -} diff --git a/src/dtProjects/DTProjects.tsx b/src/dtProjects/DTProjects.tsx index c1f2f50..0f8e0ee 100644 --- a/src/dtProjects/DTProjects.tsx +++ b/src/dtProjects/DTProjects.tsx @@ -30,7 +30,11 @@ function DTProjects(props: ChakraProps) { return ( - + { - const { open } = props +type ImportProgressProps = { + open: boolean + progress?: { + found: number + scanned: number + imageCount: number + } +} + +const ImportProgress = (props: ImportProgressProps) => { + const { open, progress } = props - const { projects } = useDTP() - const projectsSnap = projects.useSnap() + if (!open || !progress) return null - const found = projectsSnap.projects.length - const scanned = projectsSnap.projects.filter((p) => (p.filesize ?? 0) > 0).length - const imageCount = projectsSnap.projects.reduce((acc, p) => acc + (p.image_count ?? 0), 0) + const { found, scanned, imageCount } = progress return ( diff --git a/src/dtProjects/controlPane/ControlPane.tsx b/src/dtProjects/controlPane/ControlPane.tsx index 9f6fa05..61bc1b3 100644 --- a/src/dtProjects/controlPane/ControlPane.tsx +++ b/src/dtProjects/controlPane/ControlPane.tsx @@ -3,7 +3,7 @@ import { IconButton, Panel } from "@/components" import { GoGear, MdImageSearch, PiCoffee } from "@/components/icons/icons" import { useDTP } from "@/dtProjects/state/context" import Tabs from "@/metadata/infoPanel/tabs" -import ProjectsPanel from "./ProjectsPanel" +import ProjectsPanel from "./projectsPanel/ProjectsPanel" import SearchPanel from "./SearchPanel" const tabs = [ @@ -12,14 +12,12 @@ const tabs = [ value: "search", Icon: MdImageSearch, component: SearchPanel, - requiresProjects: true, }, { label: "Projects", value: "projects", Icon: PiCoffee, component: ProjectsPanel, - requiresProjects: true, }, ] @@ -47,27 +45,35 @@ function ControlPane(props: ControlPane) { onValueChange={(e) => { 
uiState.setSelectedTab(e.value as typeof uiSnap.selectedTab) }} + aria-label="Projects tabs" > - - + + ) } function TabList(props: ChakraProps) { - const { projects, uiState } = useDTP() - const snap = projects.useSnap() - - const hasProjects = snap.projects.length > 0 + const { uiState } = useDTP() return ( - - {tabs.map(({ value, Icon, label, requiresProjects }) => { - if (requiresProjects && !hasProjects) return null + + {tabs.map(({ value, Icon, label }) => { return ( - + {label} diff --git a/src/dtProjects/controlPane/ProjectsPanel.tsx b/src/dtProjects/controlPane/ProjectsPanel.tsx deleted file mode 100644 index 422f9f2..0000000 --- a/src/dtProjects/controlPane/ProjectsPanel.tsx +++ /dev/null @@ -1,176 +0,0 @@ -import { Box, FormatByte, HStack } from "@chakra-ui/react" -import { useEffect, useRef, useState } from "react" -import { useSnapshot } from "valtio" -import { computed } from "valtio-reactive" -import { PanelListItem } from "@/components" -import { FiRefreshCw, MdBlock } from "@/components/icons/icons" -import PanelList from "@/components/PanelList" -import { useSelectable } from "@/hooks/useSelectableV" -import TabContent from "@/metadata/infoPanel/TabContent" -import { useDTP } from "../state/context" -import type { ProjectState } from "../state/projects" -import { useProjectsCommands } from "./useProjectsCommands" - -interface ProjectsPanelComponentProps extends ChakraProps {} - -function ProjectsPanel(props: ProjectsPanelComponentProps) { - const { ...restProps } = props - const { projects, images } = useDTP() - const snap = projects.useSnap() - const { imageSource, projectImageCounts } = images.useSnap() - const [showExcluded, setShowExcluded] = useState(false) - const toggleRef = useRef(null) - - const groups = computed({ - activeProjects: () => projects.state.projects.filter((p) => !p.excluded), - excludedProjects: () => projects.state.projects.filter((p) => p.excluded), - allProjects: () => - projects.state.projects.toSorted( - (a, b) => 
(a.excluded ? 1 : -1) - (b.excluded ? 1 : -1), - ), - }) - - const activeProjectsSnap = useSnapshot(groups.activeProjects) - const excludedProjectsSnap = useSnapshot(groups.excludedProjects) - - const isFiltering = - !!imageSource?.filters?.length || - !!imageSource?.search || - imageSource?.showImage !== imageSource?.showVideo - const showEmpty = snap.showEmptyProjects || !isFiltering - - useEffect(() => { - if (showExcluded && toggleRef.current) { - setTimeout(() => { - toggleRef.current?.scrollIntoView({ behavior: "smooth", block: "start" }) - }, 100) - } - }, [showExcluded]) - - const toolbarCommands = useProjectsCommands() - - return ( - - p.path} - commands={toolbarCommands} - onSelectionChanged={(e) => { - projects.setSelectedProjects(e) - }} - > - {activeProjectsSnap.map((p) => { - if (!showEmpty && projectImageCounts?.[p.id] === undefined) return null - return ( - - ) - })} - {excludedProjectsSnap.length > 0 && ( - setShowExcluded(!showExcluded)} - cursor="pointer" - color="fg.3" - _hover={{ color: "fg.1" }} - > - - - - {showExcluded - ? "Hide excluded projects" - : `Show excluded projects (${excludedProjectsSnap.length})`} - - - - )} - {showExcluded && - excludedProjectsSnap.map((p) => )} - - - - {groups.activeProjects.length} projects - - - {groups.activeProjects.reduce((p, c) => p + (c.image_count ?? 0), 0)} images - - - p + (c.filesize ?? 0), 0)} - /> - - - - ) -} - -export default ProjectsPanel - -interface ProjectListItemProps extends ChakraProps { - project: ProjectState - altCount?: number -} -function ProjectListItem(props: ProjectListItemProps) { - const { project, altCount, ...restProps } = props - const { handlers, isSelected } = useSelectable(project) - - let count: number | string = project.image_count ?? 0 - let countStyle: string | undefined - - if (altCount !== count) { - count = altCount || "" - countStyle = "italic" - } // what if every image in the project matches a search? 
then it won't be italic - - const projectName = project.path.split("/").pop()?.slice(0, -8) - - return ( - { - // const text = await invoke("dt_project_get_text_history", { - // projectFile: project.path, - // }) - // console.log(text) - // navigator.clipboard.writeText(text) - // }} - > - - - {projectName} - {project.isMissing && " (missing)"} - - {project.isScanning ? ( - - - ) : ( - - {count} - - )} - - - ) -} diff --git a/src/dtProjects/controlPane/filters/ContentValueSelector.tsx b/src/dtProjects/controlPane/filters/ContentValueSelector.tsx index 9be9409..1d891b1 100644 --- a/src/dtProjects/controlPane/filters/ContentValueSelector.tsx +++ b/src/dtProjects/controlPane/filters/ContentValueSelector.tsx @@ -58,9 +58,9 @@ function ContentValueSelectorComponent(props: ValueSelectorProps) { const ContentValueSelector = ContentValueSelectorComponent as FilterValueSelector -ContentValueSelector.getValueLabel = (values) => { - if (!Array.isArray(values)) return [] - return values.map((v) => +ContentValueSelector.getValueLabel = (values: string | string[]) => { + const valueArray = Array.isArray(values) ? values : [values] + return valueArray.map((v) => v in contentValues ? 
contentValues[v as keyof typeof contentValues] : "unknown", ) } diff --git a/src/dtProjects/controlPane/filters/collections.tsx b/src/dtProjects/controlPane/filters/collections.tsx index d72ee17..be6eb8a 100644 --- a/src/dtProjects/controlPane/filters/collections.tsx +++ b/src/dtProjects/controlPane/filters/collections.tsx @@ -15,7 +15,9 @@ export function createValueLabelCollection(values: Record) { }) } -export type FilterValueSelector = (props: ValueSelectorProps) => JSX.Element +export type FilterValueSelector = ((props: ValueSelectorProps) => JSX.Element) & { + getValueLabel: (value: T) => string[] +} export function getValueSelector(target?: string) { if (!target) return filterTargets.none.ValueComponent diff --git a/src/dtProjects/controlPane/projectsPanel/ProjectFolderGroup.tsx b/src/dtProjects/controlPane/projectsPanel/ProjectFolderGroup.tsx new file mode 100644 index 0000000..09cfa91 --- /dev/null +++ b/src/dtProjects/controlPane/projectsPanel/ProjectFolderGroup.tsx @@ -0,0 +1,146 @@ +import { Box, Button, HStack, Spacer, VStack } from "@chakra-ui/react" +import { useEffect, useRef, useState } from "react" +import { MdBlock, MdDoNotDisturbOn } from "react-icons/md" +import { DtpService } from "@/commands" +import { IconButton, PanelListItem } from "@/components" +import { FiRefreshCw, PiEject } from "@/components/icons/icons" +import type { WatchFolderState } from "@/dtProjects/state/watchFolders" +import { ProjectState } from "@/dtProjects/state/projects" +import ProjectListItem from "./ProjectListItem" + +interface ProjectFolderGroupProps extends ChakraProps { + watchfolder: WatchFolderState + projects: readonly ProjectState[] + altCounts?: Record + showLabel: boolean + onSelectFolder: (watchfolder: WatchFolderState) => void +} + +function ProjectFolderGroup(props: ProjectFolderGroupProps) { + const { watchfolder, projects, altCounts, showLabel, children, onSelectFolder, ...restProps } = + props + + const [highlightGroup, setHighlightGroup] = 
useState(false) + const [showExcluded, setShowExcluded] = useState(false) + + const activeProjects = projects.filter((p) => !p.excluded) + const excludedProjects = projects.filter((p) => p.excluded) + + const label = getLabel(watchfolder) + + return ( + + {showLabel && ( + setHighlightGroup(true)} + onMouseLeave={() => setHighlightGroup(false)} + onClick={() => onSelectFolder(watchfolder)} + > + {/* */} + {label} + + {watchfolder.isMissing && } + {!watchfolder.isMissing && !watchfolder.isLocked && !watchfolder.isDtData && ( + { + e.stopPropagation() + await DtpService.lockFolder(watchfolder.id) + }} + > + + + )} + + )} + {watchfolder.isLocked ? ( + Safe to remove + ) : ( + <> + {activeProjects.map((p) => ( + + ))} + {excludedProjects.length > 0 && ( + setShowExcluded(!showExcluded)} + cursor="pointer" + color="fg.3" + _hover={{ color: "fg.1" }} + > + + + + {showExcluded + ? "Hide excluded projects" + : `Show excluded projects (${excludedProjects.length})`} + + + + )} + {showExcluded && + excludedProjects.map((p) => )} + + )} + {/* {activeProjectsSnap.map((p) => { + if (!showEmpty && projectImageCounts?.[p.id] === undefined) return null + return ( + + ) + })} + {excludedProjectsSnap.length > 0 && ( + setShowExcluded(!showExcluded)} + cursor="pointer" + color="fg.3" + _hover={{ color: "fg.1" }} + > + + + + {showExcluded + ? 
"Hide excluded projects" + : `Show excluded projects (${excludedProjectsSnap.length})`} + + + + )} + {showExcluded && + excludedProjectsSnap.map((p) => )} */} + + ) +} + +function getLabel(watchfolder: WatchFolderState) { + if (watchfolder.isDtData) return "Draw Things" + const parts = watchfolder.path.split("/") + return parts.slice(-2).join("/") +} + +export default ProjectFolderGroup diff --git a/src/dtProjects/controlPane/projectsPanel/ProjectListItem.tsx b/src/dtProjects/controlPane/projectsPanel/ProjectListItem.tsx new file mode 100644 index 0000000..a7f4dd1 --- /dev/null +++ b/src/dtProjects/controlPane/projectsPanel/ProjectListItem.tsx @@ -0,0 +1,49 @@ +import { Box, HStack } from "@chakra-ui/react" +import type { Snapshot } from "valtio" +import { PanelListItem } from "@/components" +import type { ProjectState } from "@/dtProjects/state/projects" + +export interface ProjectListItemProps extends ChakraProps { + project: Snapshot + altCount?: number +} +function ProjectListItem(props: ProjectListItemProps) { + const { project, altCount, ...restProps } = props + + let count: number | string = project.image_count ?? 
0 + let countStyle: string | undefined + if (altCount !== count) { + count = altCount || "" + countStyle = "italic" + } + + const projectName = project.path.split("/").pop()?.slice(0, -8) + + return ( + project.onClick(e)} + {...restProps} + // {...handlers} + > + + + {projectName} + {project.isMissing && " (missing)"} + + + {count} + + + + ) +} + +export default ProjectListItem diff --git a/src/dtProjects/controlPane/projectsPanel/ProjectsPanel.tsx b/src/dtProjects/controlPane/projectsPanel/ProjectsPanel.tsx new file mode 100644 index 0000000..169907b --- /dev/null +++ b/src/dtProjects/controlPane/projectsPanel/ProjectsPanel.tsx @@ -0,0 +1,75 @@ +import { Box, FormatByte, HStack } from "@chakra-ui/react" +import PanelList from "@/components/PanelList2" +import TabContent from "@/metadata/infoPanel/TabContent" +import { useDTP } from "../../state/context" +import { useProjectsCommands } from "../useProjectsCommands" +import ProjectFolderGroup from "./ProjectFolderGroup" + +interface ProjectsPanelComponentProps extends ChakraProps {} + +function ProjectsPanel(props: ProjectsPanelComponentProps) { + const { ...restProps } = props + const { projects, images } = useDTP() + const snap = projects.useSnap() + const { imageSource, projectImageCounts } = images.useSnap() + + const showFolders = true + + const isFiltering = + !!imageSource?.filters?.length || + !!imageSource?.search || + imageSource?.showImage !== imageSource?.showVideo + const showEmpty = snap.showEmptyProjects || !isFiltering + + const toolbarCommands = useProjectsCommands() + + console.log("render") + + return ( + + p.path} + commands={toolbarCommands} + selectedItems={snap.selectedProjects} + onSelectionChanged={(e) => { + projects.setSelectedProjects(e) + }} + > + {showFolders && + snap.folders.map((folderGroup, i, arr) => ( + 1} + watchfolder={folderGroup.watchfolder} + altCounts={projectImageCounts} + projects={folderGroup.projects} + onSelectFolder={(wf) => projects.selectFolderProjects(wf)} + /> + 
))} + + + + {snap.projects.length} projects + + {snap.projects.reduce((p, c) => p + (c.image_count ?? 0), 0)} images + + p + (c.filesize ?? 0), 0)} /> + + + + ) +} + +export default ProjectsPanel diff --git a/src/dtProjects/detailsOverlay/DetailsButtonBar.tsx b/src/dtProjects/detailsOverlay/DetailsButtonBar.tsx index 4c26cb4..610c757 100644 --- a/src/dtProjects/detailsOverlay/DetailsButtonBar.tsx +++ b/src/dtProjects/detailsOverlay/DetailsButtonBar.tsx @@ -6,12 +6,11 @@ import { type ComponentProps, useRef, useState } from "react" import { FiCopy, FiSave } from "react-icons/fi" import { PiListMagnifyingGlassBold } from "react-icons/pi" import type { Snapshot } from "valtio" -import { dtProject, pdb } from "@/commands" +import { DtpService, type ImageExtra } from "@/commands" import { IconButton } from "@/components" import FrameCountIndicator from "@/components/FrameCountIndicator" import VideoFrameIcon from "@/components/icons/VideoFramesIcon" import type { VideoContextType } from "@/components/video/context" -import type { ImageExtra } from "@/generated/types" import { sendToMetadata } from "@/metadata/state/interop" import { useDTP } from "../state/context" import type { ProjectState } from "../state/projects" @@ -54,24 +53,26 @@ function DetailsButtonBar(props: DetailsButtonBarProps) { } console.log("getting frame", frameIndex) - const clip = await pdb.getClip(item.id) + const clip = await DtpService.getClip(item.id) const frame = clip[frameIndex] if (!frame) return - return await dtProject.decodeTensor(item.project_id, frame.tensor_id, true, frame.row_id) + return await DtpService.decodeTensor(item.project_id, frame.tensor_id, true, frame.row_id) } const getImage = async (frameIndex?: number) => { if (isVideo) return getFrame(frameIndex) console.log("getting image") - if (!project?.path || !tensorId) return - return await dtProject.decodeTensor(project.path, tensorId, true, nodeId) + if (!item || !tensorId) return + return await 
DtpService.decodeTensor(item.project_id, tensorId, true, nodeId) } const disabled = !projectId || !tensorId || !show || lockButtons return ( e.stopPropagation()} initial={{ opacity: 0 }} @@ -82,6 +83,7 @@ function DetailsButtonBar(props: DetailsButtonBarProps) { > {subItem?.maskUrl && ( uiState.toggleSubItemMask()} @@ -91,6 +93,7 @@ function DetailsButtonBar(props: DetailsButtonBarProps) { )} { @@ -111,6 +114,7 @@ function DetailsButtonBar(props: DetailsButtonBarProps) { { @@ -135,6 +139,7 @@ function DetailsButtonBar(props: DetailsButtonBarProps) { { @@ -160,6 +165,7 @@ function DetailsButtonBar(props: DetailsButtonBarProps) { <> { if (!item) return @@ -177,6 +183,7 @@ function DetailsButtonBar(props: DetailsButtonBarProps) { /> )} { if (!item) return diff --git a/src/dtProjects/detailsOverlay/DetailsImages.tsx b/src/dtProjects/detailsOverlay/DetailsImages.tsx index c821453..bc8cf7a 100644 --- a/src/dtProjects/detailsOverlay/DetailsImages.tsx +++ b/src/dtProjects/detailsOverlay/DetailsImages.tsx @@ -1,6 +1,6 @@ import { Grid, HStack, Spinner } from "@chakra-ui/react" import type { Snapshot } from "valtio" -import type { DTImageFull } from "@/commands" +import type { DTImageFull, ImageExtra } from "@/commands" import urls from "@/commands/urls" import { VideoContext, type VideoContextType } from "@/components/video/context" import FpsButton from '@/components/video/FpsButton' @@ -8,7 +8,6 @@ import PlayPauseButton from "@/components/video/PlayPauseButton" import Seekbar from "@/components/video/Seekbar" import Video from "@/components/video/Video" import { VideoImage } from "@/components/video/VideoImage" -import type { ImageExtra } from "@/generated/types" import { useGetContext } from "@/hooks/useGetContext" import type { UIControllerState } from "../state/uiState" import { DetailsSpinnerRoot } from "./common" diff --git a/src/dtProjects/detailsOverlay/DetailsOverlay.tsx b/src/dtProjects/detailsOverlay/DetailsOverlay.tsx index c059860..f62de43 100644 --- 
a/src/dtProjects/detailsOverlay/DetailsOverlay.tsx +++ b/src/dtProjects/detailsOverlay/DetailsOverlay.tsx @@ -51,6 +51,8 @@ function DetailsOverlay(props: DetailsOverlayProps) { { // if (snap.subItem) uiState.hideSubItem() diff --git a/src/dtProjects/imagesList/ImagesList.tsx b/src/dtProjects/imagesList/ImagesList.tsx index c697112..19bc7e5 100644 --- a/src/dtProjects/imagesList/ImagesList.tsx +++ b/src/dtProjects/imagesList/ImagesList.tsx @@ -78,7 +78,7 @@ function GridItemAnim( if (!item) return const previewId = `${item?.project_id}/${item?.preview_id}` - const url = `dtm://dtproject/thumbhalf/${previewId}` + const url = item.is_ready ? `dtm://dtproject/thumbhalf/${previewId}` : "/img_not_available.svg" const isVideo = (item.num_frames ?? 0) > 0 const showVideo = isVideo && hoveredIndex === index diff --git a/src/dtProjects/jobs/models.ts b/src/dtProjects/jobs/models.ts index 64684c4..aee20a8 100644 --- a/src/dtProjects/jobs/models.ts +++ b/src/dtProjects/jobs/models.ts @@ -1,6 +1,6 @@ import { path } from "@tauri-apps/api" import { writeTextFile } from "@tauri-apps/plugin-fs" -import { pdb } from "@/commands" +import { DtpService } from "@/commands" import type { JobCallback } from "@/utils/container/queue" import type { DTPJobSpec } from "../state/types" import type { ListModelInfoFilesResult } from "../state/watchFolders" @@ -59,7 +59,7 @@ export function syncModelInfoJob( data: modelInfoFiles, execute: async (data) => { for (const { path, modelType } of data) { - await pdb.scanModelInfo(path, modelType) + await DtpService.scanModelInfo(path, modelType) } }, } diff --git a/src/dtProjects/settingsPanel/GrantAccess.tsx b/src/dtProjects/settingsPanel/GrantAccess.tsx index efe866f..9211b95 100644 --- a/src/dtProjects/settingsPanel/GrantAccess.tsx +++ b/src/dtProjects/settingsPanel/GrantAccess.tsx @@ -1,5 +1,5 @@ import { Text } from "@chakra-ui/react" -import { pickDrawThingsFolder } from "@/commands" +import { useState } from "react" import { PanelButton, 
PanelSection, PanelSectionHeader } from "@/components" import { useDTP } from "../state/context" @@ -7,13 +7,29 @@ interface GrantAccessProps extends ChakraProps {} function GrantAccess(props: GrantAccessProps) { const { ...restProps } = props - const { settings: storage, watchFolders } = useDTP() + const { watchFolders } = useDTP() + const snap = watchFolders.useSnap() + const [isLoading, setIsLoading] = useState(false) - const storageSnap = storage.useSnap() - const hasBookmark = !!storageSnap.permissions.bookmark + const handleGrantAccess = async () => { + setIsLoading(true) + try { + await watchFolders.pickDtFolder() + } catch (e) { + alert(`Couldn't add folder:\n\n${e}`) + console.error(e) + } finally { + setIsLoading(false) + } + } return ( - + Draw Things Access @@ -22,21 +38,7 @@ function GrantAccess(props: GrantAccessProps) { After clicking the button, a file picker will open. Select the Documents folder. Note: DTM does not modify your projects. - { - const bookmark = await pickDrawThingsFolder(watchFolders.containerPath) - if (!bookmark) return - if (bookmark.path !== watchFolders.defaultProjectPath) { - alert( - `Please select the correct folder: ${watchFolders.defaultProjectPath}`, - ) - return - } - storage.updateSetting("permissions", "bookmark", bookmark.bookmark) - watchFolders.addDefaultDataFolder() - }} - > + Select folder diff --git a/src/dtProjects/settingsPanel/SettingsPanel.tsx b/src/dtProjects/settingsPanel/SettingsPanel.tsx index 3319691..bacb764 100644 --- a/src/dtProjects/settingsPanel/SettingsPanel.tsx +++ b/src/dtProjects/settingsPanel/SettingsPanel.tsx @@ -1,17 +1,21 @@ import { Box, HStack, Text, VStack } from "@chakra-ui/react" import { openUrl } from "@tauri-apps/plugin-opener" import { useMemo } from "react" -import { IconButton, LinkButton, PanelListItem, PanelSection, PanelSectionHeader } from "@/components" +import { + IconButton, + LinkButton, + PanelListItem, + PanelSection, + PanelSectionHeader, +} from "@/components" import { 
FaMinus, FaPlus, FiList, FiX, LuFolderTree } from "@/components/icons/icons" import PanelList, { type PanelListCommand } from "@/components/PanelList" import { Slider } from "@/components/ui/slider" import { useSelectable } from "@/hooks/useSelectableV" -import { openAnd } from "@/utils/helpers" import { ContentPanelPopup, type ContentPanelPopupProps } from "../imagesList/ContentPanelPopup" import { useDTP } from "../state/context" import type { WatchFolderState, WatchFoldersController } from "../state/watchFolders" import GrantAccess from "./GrantAccess" -import ResetPermission from "./ResetPermission" function useCommands(watchFolders: WatchFoldersController): PanelListCommand[] { const commands = useMemo( @@ -42,12 +46,9 @@ function useCommands(watchFolders: WatchFoldersController): PanelListCommand - openAnd((f) => watchFolders.addWatchFolder(f), { - directory: true, - multiple: false, - title: `Select watch folder`, - }), + onClick: async () => { + await watchFolders.pickWatchFolder() + }, tip: "Add folder", }, ], @@ -191,7 +192,7 @@ export function SettingsPanel(props: Omit - + {/* */} diff --git a/src/dtProjects/state/context.tsx b/src/dtProjects/state/context.tsx index 538df27..6038a39 100644 --- a/src/dtProjects/state/context.tsx +++ b/src/dtProjects/state/context.tsx @@ -1,12 +1,12 @@ +import { Channel } from "@tauri-apps/api/core" +import DTPService from "@/commands/DtpService" import { UIController } from "@/dtProjects/state/uiState" import { JobQueue } from "@/utils/container/queue" import { Container } from "../../utils/container/container" -import { syncRemoteModelsJob } from "../jobs/models" import DetailsService from "./details" import ImagesController from "./images" import ModelsController from "./models" import ProjectsController from "./projects" -import ScannerService from "./scanner" import SearchController from "./search" import SettingsController from "./settings" import type { DTPContainer, DTPEvents, DTProjectsJobs, DTPServices } from 
"./types" @@ -59,14 +59,14 @@ export function useDTP() { function createContainer() { console.log("creating container") - return new Container(() => { + const channel = new Channel() + return new Container(channel, () => { const jobs = new JobQueue() const uiState = new UIController() const projects = new ProjectsController() const watchFolders = new WatchFoldersController() const models = new ModelsController() const images = new ImagesController() - const scanner = new ScannerService() const search = new SearchController() const details = new DetailsService(projects) const settings = new SettingsController() @@ -75,21 +75,22 @@ function createContainer() { images.buildImageSource({ text: text ?? "", filters: filters ?? [] }) } - Promise.all([ - watchFolders.assignPaths(), - projects.loadProjects(), - models.refreshModels(), - // watchFolders.loadWatchFolders(), - scanner.sync({}), - jobs.addJob(syncRemoteModelsJob()), - ]) + DTPService.connect(channel).then(async () => { + await Promise.all([ + watchFolders.assignPaths(), + projects.loadProjects(), + // models.refreshModels(), + // watchFolders.loadWatchFolders(), + // DTPService.sync(), + // jobs.addJob(syncRemoteModelsJob()), + ]) + }) const controllers = { projects, uiState, models, watchFolders, - scanner, search, images, details, diff --git a/src/dtProjects/state/details.ts b/src/dtProjects/state/details.ts index 4c4d409..8164b06 100644 --- a/src/dtProjects/state/details.ts +++ b/src/dtProjects/state/details.ts @@ -1,5 +1,5 @@ -import { type DTImageFull, dtProject } from "@/commands" -import type { ImageExtra } from '@/generated/types' +import type { DTImageFull, ImageExtra } from "@/commands" +import DTPService from "@/commands/DtpService" import { extractConfigFromTensorHistoryNode, groupConfigProperties } from "@/utils/config" import type ProjectsController from "./projects" import { DTPStateService } from "./types" @@ -20,7 +20,7 @@ class DetailsService extends DTPStateService { const project = 
this.projects.state.projects.find((p) => p.id === item.project_id) if (!project) return - const { history, ...extra } = await dtProject.getHistoryFull(project.path, item.node_id) + const { history, ...extra } = await DTPService.getHistoryFull(item.project_id, item.node_id) const rawConfig = extractConfigFromTensorHistoryNode(history) ?? {} const config = groupConfigProperties(rawConfig) @@ -59,8 +59,8 @@ class DetailsService extends DTPStateService { const history = await this.getDetails(item) if (!history) return - return await dtProject.getPredecessorCandidates( - project.path, + return await DTPService.findPredecessor( + item.project_id, item.node_id, history.node.lineage, history.node.logical_time, diff --git a/src/dtProjects/state/images.ts b/src/dtProjects/state/images.ts index 0250175..12678a9 100644 --- a/src/dtProjects/state/images.ts +++ b/src/dtProjects/state/images.ts @@ -1,11 +1,11 @@ import { proxy, subscribe, useSnapshot } from "valtio" -import { pdb } from "@/commands" +import type { FilterTarget, ImageExtra } from "@/commands" +import DTPService from "@/commands/DtpService" import { EmptyItemSource, type IItemSource, PagedItemSource, } from "@/components/virtualizedList/PagedItemSource" -import type { ImageExtra } from "@/generated/types" import type { ContainerEvent } from "@/utils/container/StateController" import type { ImagesSource } from "../types" import type { ProjectState, ProjectsControllerState } from "./projects" @@ -55,7 +55,6 @@ class ImagesController extends DTPStateController { this.watchProxy(async (get) => { const p = get(projectsService.state.projects) const changed = updateProjectsCache(p, this.projectsCache) - if (changed.length > 0) { await this.container.services.uiState.importLockPromise if (this.eventTimer) return @@ -78,8 +77,8 @@ class ImagesController extends DTPStateController { s.showImage = true s.showVideo = true } - const res = await pdb.listImages(s, skip, take) - return res.images + const res = await 
DTPService.listImages(s, skip, take) + return res.images ?? [] } const getCount = async () => { await this.refreshImageCounts() @@ -96,8 +95,6 @@ class ImagesController extends DTPStateController { itemSource.renderWindow = [0, 20] this.itemSource = itemSource this.state.searchId++ - - this.refreshImageCounts() }) } @@ -121,7 +118,7 @@ class ImagesController extends DTPStateController { async setSearchFilters(filters?: BackendFilter[]) { this.state.imageSource.filters = filters?.map((f) => ({ - target: f.target.toLowerCase(), + target: f.target.toLowerCase() as FilterTarget, operator: f.operator, value: f.value, })) @@ -149,12 +146,12 @@ class ImagesController extends DTPStateController { async refreshImageCounts() { const source = { ...this.state.imageSource } - console.log("refreshImageCounts", source) if (source.showImage === false && source.showVideo === false) { source.showImage = true source.showVideo = true } - const { total, counts } = await pdb.listImagesCount(source) + const { total, counts } = await DTPService.listImagesCount(source) + if (!counts) return const projectCounts = {} as Record for (const count of counts) { projectCounts[count.project_id] = count.count @@ -200,9 +197,10 @@ function updateProjectsCache( const visited: Record = { ...cache } for (const project of projects) { visited[project.id] = null - if (cache[project.id] !== project.image_count) { + const imageCount = project.is_missing || project.is_locked ? 0 : project.image_count + if (cache[project.id] !== imageCount) { projectsChanged.push(project.id) - cache[project.id] = project.image_count ?? 0 + cache[project.id] = imageCount ?? 
0 } } diff --git a/src/dtProjects/state/models.ts b/src/dtProjects/state/models.ts index 6c88f1a..d60f525 100644 --- a/src/dtProjects/state/models.ts +++ b/src/dtProjects/state/models.ts @@ -1,5 +1,6 @@ import { proxy } from "valtio" -import { type Model, pdb } from "@/commands" +import type { Model } from "@/commands" +import DTPService from "@/commands/DtpService" import { getVersionLabel } from "@/utils/models" import type { ModelVersionInfo, VersionModel } from "../types" import { type DTPJob, DTPStateController } from "./types" @@ -22,17 +23,15 @@ class ModelsController extends DTPStateController { }) constructor() { - super("models", "models") - } + super("models") - protected override handleTags(_tags: string, _desc: Record) { - const job = getRefreshModelsJob() - this.container.getService("jobs").addJob(job) - return true + this.container.on("models_changed", async () => { + await this.refreshModels() + }) } async refreshModels() { - const dbModels = await pdb.listModels() + const dbModels = await DTPService.listModels() const versions = { "": { models: 0, controls: 0, loras: 0, label: "Unknown" }, diff --git a/src/dtProjects/state/projects.ts b/src/dtProjects/state/projects.ts index c368745..73be55a 100644 --- a/src/dtProjects/state/projects.ts +++ b/src/dtProjects/state/projects.ts @@ -1,10 +1,12 @@ import { proxy } from "valtio" -import { type ProjectExtra, pdb } from "@/commands" +import type { ProjectExtra } from "@/commands" +import DTPService from "@/commands/DtpService" import { makeSelectable, type Selectable } from "@/hooks/useSelectableV" import va from "@/utils/array" import type { ContainerEvent } from "@/utils/container/StateController" -import { arrayIfOnly } from "@/utils/helpers" +import { areEquivalent, arrayIfOnly, groupMap } from "@/utils/helpers" import { DTPStateController } from "./types" +import type { WatchFolderState } from "./watchFolders" export interface ProjectState extends Selectable { name: string @@ -17,10 +19,14 @@ export 
type ProjectsControllerState = { selectedProjects: ProjectState[] showEmptyProjects: boolean projectsCount: number + folders: { + watchfolder: WatchFolderState + projects: ProjectState[] + }[] } const projectSort = ( - a: Selectable<{ + a: { name: string id: number fingerprint: string @@ -31,8 +37,8 @@ const projectSort = ( modified: number | null missing_on: number | null excluded: boolean - }>, - b: Selectable<{ + }, + b: { name: string id: number fingerprint: string @@ -43,61 +49,52 @@ const projectSort = ( modified: number | null missing_on: number | null excluded: boolean - }>, + }, ): number => a.name.toLowerCase().localeCompare(b.name.toLowerCase()) + class ProjectsController extends DTPStateController { state = proxy({ projects: [], selectedProjects: [], showEmptyProjects: false, projectsCount: 0, + folders: [], }) hasLoaded = false constructor() { - super("projects", "projects") - } - - protected formatTags( - tags: string, - data?: { removed?: number; added?: ProjectExtra; updated?: ProjectExtra; desc?: string }, - ): string { - if (data?.desc) return `invalidate tag: ${tags} - ${data.desc}` - if (data?.removed) return `update tag - removed project - id ${data.removed}` - if (data?.added) - return `update tag - added project - ${data.added.path.split("/").pop()} id ${data.added.id}` - if (data?.updated) - return `update tag - updated project - ${data.updated.path.split("/").pop()} id ${data.updated.id}` - return `update tag: ${tags} ${String(data)}` - } + super("projects") - protected handleTags( - _tags: string, - data: { removed?: number; added?: ProjectExtra; updated?: ProjectExtra }, - ) { - if (data.updated) { - this.updateProject(data.updated.id, data.updated) - } else if (data.added) { - // check if project is already listed - if (this.state.projects.some((p) => p.id === data.added?.id)) { - this.updateProject(data.added.id, data.added) - return true - } + this.container.on("project_added", (project) => { this.state.projects.push( - 
makeSelectable({ ...data.added, name: data.added.path.split("/").pop() as string }), + makeSelectable({ ...project, name: project.path.split("/").pop() as string }), ) this.state.projects.sort(projectSort) this.state.projectsCount++ - } else if (data.removed) { - const project = this.state.projects.find((p) => p.id === data.removed) - if (project) { - va.remove(this.state.projects, project) + this.loadProjectsDebounced() + }) + + this.container.on("projects_changed", () => { + this.loadProjects() + }) + + this.container.on("project_removed", (projectId) => { + const projectState = this.state.projects.find((p) => p.id === projectId) + if (projectState) { + va.remove(this.state.projects, projectState) this.state.projectsCount-- } - } - this.loadProjectsDebounced() - return true + this.loadProjectsDebounced() + }) + + this.container.on("project_updated", (project) => { + const projectState = this.state.projects.find((p) => p.id === project.id) + if (projectState) { + Object.assign(projectState, project) + } + this.loadProjectsDebounced() + }) } updateProject(projectId: number, data: Partial) { @@ -117,23 +114,93 @@ class ProjectsController extends DTPStateController { } async loadProjects() { - const projects = await pdb.listProjects() - va.set( - this.state.projects, - projects - .map((p) => - makeSelectable( - { ...p, name: p.path.split("/").pop() as string }, - this.state.selectedProjects.some((sp) => sp.id === p.id), - ), - ) - .sort(projectSort), - ) - this.state.projectsCount = projects.length + const wfs = this.container.getService("watchFolders") + const watchfolders = await wfs.loadWatchFolders() + const dtpProjects = await (await DTPService.listProjects()).sort(projectSort) + + const folders = groupMap( + dtpProjects, + (p) => [ + p.watchfolder_id, + makeSelectable( + { + ...p, + name: p.path.split("/").pop() as string, + }, + false, + (item, currentValue, modifier) => this.selectItem(item, currentValue, modifier), + ), + ], + (folderId, folderProjects) => { 
+ const folder = watchfolders.find((f) => f.id === folderId) + return { + watchfolder: folder, + projects: folderProjects, + } + }, + ).filter( + (f) => f.watchfolder !== undefined && f.projects.length > 0, + ) as ProjectsControllerState["folders"] + + const newProjects = folders.flatMap((f) => f.projects) + + va.set(this.state.folders, folders) + va.set(this.state.projects, newProjects) + + this.state.projectsCount = this.state.projects.length this.hasLoaded = true this.container.emit("projectsLoaded") } + private lastSelectedProject: ProjectState | null = null + selectItem(item: ProjectState, currentValue: boolean, modifier?: "shift" | "cmd" | null) { + // toggle item + if (modifier === "cmd") { + item.setSelected(!currentValue) + this.lastSelectedProject = item + } + // this is the tricky one + else if (modifier === "shift") { + const lastIndex = this.state.projects.findIndex( + (p) => p.id === this.lastSelectedProject?.id, + ) + const currentIndex = this.state.projects.findIndex((p) => p.id === item.id) + if (currentIndex === -1) return + + // if there is no lastselected index, just select/deselect the item + if (lastIndex === -1) { + item.toggleSelected() + } else { + const from = Math.min(lastIndex, currentIndex) + const to = Math.max(lastIndex, currentIndex) + this.state.projects.forEach((p, i) => { + if (i >= from && i <= to && !p.excluded) p.setSelected(true) + else p.setSelected(false) + }) + } + } + // change selected or deselect if only selected + else { + const areOthersSelected = this.state.projects.some( + (p) => p.id !== item.id && p.selected, + ) + if (areOthersSelected) { + // if others are selected, the current state of this item is irrelevant. 
+ // the selection becomes this item + this.state.projects.forEach((p) => { + p.setSelected(false) + }) + item.setSelected(true) + } else { + // if no others are selected, we can just toggle this item + item.toggleSelected() + } + this.lastSelectedProject = item + } + const selectedProjects = this.state.projects.filter((p) => p.selected) + va.set(this.state.selectedProjects, selectedProjects) + } + private loadProjectsTimeout: NodeJS.Timeout | null = null async loadProjectsDebounced() { if (this.loadProjectsTimeout) { @@ -144,20 +211,6 @@ class ProjectsController extends DTPStateController { }, 2000) } - async removeProjects(projectFiles: string[]) { - for (const projectFile of projectFiles) { - await pdb.removeProject(projectFile) - } - await this.loadProjects() - } - - async addProjects(projectFiles: string[]) { - for (const pf of projectFiles) { - await pdb.addProject(pf) - } - await this.loadProjects() - } - /** * this function can be called with a project or an array of projects * state or snapshot @@ -168,14 +221,12 @@ class ProjectsController extends DTPStateController { for (const project of toUpdate) { const projectState = this.state.projects.find((p) => p.id === project.id) if (!projectState) continue - await pdb.updateExclude(project.id, exclude) + await DTPService.updateProject(project.id, exclude) projectState.excluded = exclude stateUpdate.push(projectState) projectState.setSelected(false) } this.setSelectedProjects([]) - const scanner = this.container.getService("scanner") - await scanner.syncProjects(stateUpdate.map((p) => p.path)) } getProject(projectId?: number | null) { @@ -194,12 +245,44 @@ class ProjectsController extends DTPStateController { } setSelectedProjects(projects: ProjectState[]) { + const projectIds = new Set(projects.map((p) => p.id)) for (const project of this.state.projects) { - project.setSelected(projects.some((p) => p.id === project.id)) + project.setSelected(projectIds.has(project.id)) } va.set(this.state.selectedProjects, 
projects) } + /// set selection to every project in the watchfolder + /// UNLESS every project in the folder and ONLY projects in the folder are selected + /// in which case we deselect all projects + /// this depends on sort being the same + selectFolderProjects(watchfolder: WatchFolderState) { + const selectedIds = this.state.selectedProjects.map((p) => p.id) + const folderGroups = this.state.folders.find( + (f) => f.watchfolder.id === watchfolder.id, + )?.projects + if (!folderGroups) return + const select = !areEquivalent( + selectedIds, + folderGroups.map((p) => p.id), + ) + + const selected: ProjectState[] = [] + + if (select) { + for (const project of this.state.projects) { + project.setSelected(project.watchfolder_id === watchfolder.id) + if (project.selected) selected.push(project) + } + } else { + this.state.projects.forEach((p) => { + p.setSelected(false) + }) + } + + va.set(this.state.selectedProjects, selected) + } + useProjectsSummary() { const snap = this.useSnap() return { @@ -211,7 +294,6 @@ class ProjectsController extends DTPStateController { toggleShowEmptyProjects() { this.state.showEmptyProjects = !this.state.showEmptyProjects - console.log("show empty", this.state.showEmptyProjects) } } diff --git a/src/dtProjects/state/scanner.ts b/src/dtProjects/state/scanner.ts deleted file mode 100644 index eee3a47..0000000 --- a/src/dtProjects/state/scanner.ts +++ /dev/null @@ -1,333 +0,0 @@ -import { exists, stat } from "@tauri-apps/plugin-fs" -import { type ProjectExtra, pdb } from "@/commands" -import type { JobCallback } from "@/utils/container/queue" -import { TMap } from "@/utils/TMap" -import { syncModelInfoJob } from "../jobs/models" -import { getRefreshModelsJob } from "./models" -import { - type DTPContainer, - type DTPJob, - DTPStateService, - type ProjectFilesChangedPayload, - type SyncScope, - type WatchFoldersChangedPayload, -} from "./types" -import type { ListModelInfoFilesResult, ProjectFileStats, WatchFolderState } from "./watchFolders" 
- -class ScannerService extends DTPStateService { - constructor() { - super("scanner") - this.container.on("watchFoldersChanged", (e) => this.onWatchFoldersChanged(e)) - this.container.on("projectFilesChanged", async (e) => this.onProjectFilesChanged(e)) - } - - async onWatchFoldersChanged(e: WatchFoldersChangedPayload) { - const syncFolders = [e.added, e.changed].flat() as WatchFolderState[] - if (syncFolders.length > 0) { - this.sync({ watchFolders: syncFolders }, () => { - console.log("sync finished?") - }) - } - if (e.removed.length > 0) { - this.sync({}, () => { - console.log("sync finished?") - }) - } - } - - async onProjectFilesChanged(e: ProjectFilesChangedPayload) { - this.syncProjects(e.files) - } - - sync(scope: SyncScope, callback?: JobCallback) { - console.log("starting sync job", scope) - const callbackWrapper = () => { - console.log("sync finished") - callback?.() - } - const job = createSyncJob(scope, callbackWrapper) - this.container.getService("jobs").addJob(job) - } - - async syncProjects(projectPaths: string[], callback?: JobCallback) { - const projectStats = await Promise.all(projectPaths.map((p) => getProjectStats(p))) - const projects = projectStats.filter((p) => !!p) as ProjectFileStats[] - this.sync({ projects }, callback) - } - - override dispose() { - super.dispose() - } -} - -export default ScannerService - -async function getProjectStats(projectPath: string) { - if (!projectPath.endsWith(".sqlite3")) return undefined - if (!(await exists(projectPath))) return undefined - - const stats = await stat(projectPath) - - let walStats: Pick>, "size" | "mtime"> = { - size: 0, - mtime: new Date(0), - } - if (await exists(`${projectPath}-wal`)) { - walStats = await stat(`${projectPath}-wal`) - } - - return { - path: projectPath, - size: stats.size + walStats.size, - modified: Math.max(stats.mtime?.getTime() || 0, walStats.mtime?.getTime() || 0), - } -} - -export type ProjectJobPayload = { - action: "add" | "update" | "remove" | "none" | 
"mark-missing" - project: string - size: number - mtime: number -} - -function getSyncScopeLabel(scope: SyncScope) { - if (scope.watchFolders) { - const folders = scope.watchFolders.map((f) => f.path.split("/").pop()) - return `Sync for folders: ${folders.join(", ")}` - } - if (scope.projects) { - const projects = scope.projects.map((p) => p.path.split("/").pop()) - return `Sync for projects: ${projects.join(", ")}` - } - return "Full sync" -} - -function createSyncJob(scope: SyncScope, callback?: JobCallback): DTPJob { - const label = getSyncScopeLabel(scope) - return { - type: "data-sync", - label, - data: scope, - execute: getExecuteSync(callback), - } -} - -type ProjectSyncObject = { - file?: ProjectFileStats - entity?: ProjectExtra - isMissing: boolean - action: "add" | "remove" | "update" | "none" | "mark-missing" -} - -function getSyncObject(opts: Partial): ProjectSyncObject { - return { - file: opts.file, - entity: opts.entity, - isMissing: opts.isMissing ?? false, - action: opts.action ?? "none", - } -} - -function getExecuteSync(callback?: JobCallback) { - async function executeSync(scope: SyncScope, container: DTPContainer) { - const wfs = container.services.watchFolders - const ps = container.services.projects - - const folderScoped = !!scope.watchFolders && scope.watchFolders.length > 0 - const projectScoped = !!scope.projects && scope.projects.length > 0 - - if (folderScoped && projectScoped) throw new Error("not supported at this time") - - const watchFolders = - (await (async () => { - if (folderScoped) return scope.watchFolders - if (projectScoped) return [] - await wfs.loadWatchFolders(true) - return wfs.state.folders - })()) ?? 
[] - - const modelFiles = [] as ListModelInfoFilesResult[] - const projectFiles = [] as ProjectFileStats[] - - for (const folder of watchFolders) { - const folderFiles = await wfs.listFiles(folder) - modelFiles.push(...folderFiles.models) - projectFiles.push(...folderFiles.projects) - } - if (projectScoped) { - projectFiles.push(...(scope.projects ?? [])) - } - - // gather ENTITIES - await ps.loadProjects() - const projectEntities = TMap.from(ps.state.projects, (p) => p.path) - - if (folderScoped && watchFolders?.length) { - projectEntities.retain((path) => watchFolders.some((f) => path.startsWith(f.path))) - } else if (projectScoped && scope.projects?.length) { - const scopedProjects = new Set(scope.projects.map((p) => p.path)) - projectEntities.retain((path) => scopedProjects.has(path)) - } - - const syncs = [] as ProjectSyncObject[] - - for (const projectFile of projectFiles) { - const project = getSyncObject({ - file: projectFile, - entity: projectEntities.take(projectFile.path), - }) - syncs.push(project) - } - - for (const projectEntity of projectEntities.values()) { - const project = getSyncObject({ - entity: projectEntity, - }) - // if a project is not covered by a watchfolder, we can stop searching for a file - const projectFolder = await wfs.getFolderForProject(projectEntity.path) - if (!projectFolder) { - syncs.push(project) - continue - } - - const projectStats = await getProjectStats(projectEntity.path) - if (projectStats) - project.file = { ...projectStats, watchFolderPath: projectFolder.path } - else project.isMissing = true - syncs.push(project) - } - - // create jobs from the entity/file pairs - const jobs = [] as DTPJob[] - - for (const project of syncs) { - // file with no entity, add new project - if (project.file && !project.entity) project.action = "add" - // entity with no file, remove or mark missing - else if (!project.file && project.entity) { - const folder = await wfs.getFolderForProject(project.entity.path) - if (folder?.isMissing) 
project.action = "mark-missing" - else project.action = "remove" - } - // update if sizes or modified times are different - else if (project.file && project.entity && !project.entity.excluded) { - if ( - project.file.size !== project.entity.filesize || - project.file.modified !== project.entity.modified - ) - project.action = "update" - } - } - - let jobsCreated = 0 - let jobsCompleted = 0 - - const jobCallback = () => { - jobsCompleted++ - if (jobsCompleted === jobsCreated) callback?.() - } - - // create jobs - for (const project of syncs) { - if (project.action === "none") continue - const projectPath = project.file?.path ?? project.entity?.path - if (!projectPath) continue - const job = getProjectJob(projectPath, project, jobCallback) - if (job) jobs.push(job) - } - - if (modelFiles.length > 0) { - jobs.push(syncModelInfoJob(modelFiles, jobCallback)) - } - - jobsCreated = jobs.length - - return { jobs } - } - - return executeSync -} - -function getProjectJob( - project: string, - data: ProjectSyncObject, - callback?: JobCallback, -): DTPJob | undefined { - switch (data.action) { - case "add": - if (!data.file) { - console.warn("can't create 'project-add' job without file stats") - return undefined - } - return { - type: "project-add", - data: [data.file], - merge: "first", - callback, - execute: async (data: ProjectFileStats[], container) => { - container.services.uiState.setImportLock(true) - const projects = [] as [ProjectFileStats, ProjectExtra][] - // there are two loops here because of the way the progress bar works - // the first loop creates the projects and gives the progress bar a total count - // the second loop scans each project and advances the progress bar - for (const p of data) { - try { - const project = await pdb.addProject(p.path) - if (project) projects.push([p, project]) - } catch (e) { - console.error(e) - } - } - for (const [p, project] of projects) { - try { - await pdb.scanProject(project.path, false, p.size, p.modified) - } catch (e) { 
- console.error(e) - } - } - container.services.uiState.setImportLock(false) - return { jobs: [getRefreshModelsJob()] } - }, - } - case "update": - if (!data.file) { - console.warn("can't create 'project-update' job without file stats") - return undefined - } - return { - type: "project-update", - data: { - project, - mtime: data.file?.modified, - size: data.file?.size, - action: "update", - }, - callback, - execute: async (data: ProjectJobPayload, _container) => { - await pdb.scanProject(project, false, data.size, data.mtime) - }, - } - case "remove": - return { - type: "project-remove", - data: project, - callback, - execute: async (_data: string, _container) => { - await pdb.removeProject(project) - }, - } - case "mark-missing": - return { - type: "project-mark-missing", - data: [project], - merge: "first", - callback, - execute: async (data: string[], _container) => { - await pdb.updateMissingOn(data, null) - console.log("missing", data) - }, - } - default: - return undefined - } -} diff --git a/src/dtProjects/state/search.ts b/src/dtProjects/state/search.ts index 96e9ccc..cba34fa 100644 --- a/src/dtProjects/state/search.ts +++ b/src/dtProjects/state/search.ts @@ -15,17 +15,7 @@ export type SearchControllerState = { filters: Filter[] } -export type FilterOperator = - | "eq" - | "neq" - | "gt" - | "gte" - | "lt" - | "lte" - | "is" - | "isnot" - | "has" - | "doesnothave" +import type { FilterOperator, FilterTarget } from "@/commands" export type Filter = { index: number @@ -36,7 +26,7 @@ export type Filter = { } export type BackendFilter = { - target: string + target: FilterTarget operator: FilterOperator value: T } @@ -96,7 +86,7 @@ class SearchController extends DTPStateController { const filterTarget = filterTargets[filter.target as keyof typeof filterTargets] const bFilter: BackendFilter = { - target: filter.target, + target: filter.target as FilterTarget, operator: filter.operator, value: arrayIfOnly( filterTarget.prepare ? 
filterTarget.prepare(filter.value) : filter.value, diff --git a/src/dtProjects/state/settings.ts b/src/dtProjects/state/settings.ts index 53ac9bb..6f88376 100644 --- a/src/dtProjects/state/settings.ts +++ b/src/dtProjects/state/settings.ts @@ -1,5 +1,4 @@ import { store } from "@tauri-store/valtio" -import { resolveBookmark, stopAccessingBookmark } from "@/commands" import { DTPStateController } from "./types" type SettingsControllerState = { @@ -10,9 +9,6 @@ type SettingsControllerState = { videoSource: "preview" | "tensor" videoFps: number } - permissions: { - bookmark: string | null - } models: { lastUpdated: string } @@ -26,9 +22,6 @@ const defaultState: SettingsControllerState = { videoSource: "preview", videoFps: 16, }, - permissions: { - bookmark: null, - }, models: { lastUpdated: new Date(0).toISOString(), }, @@ -46,7 +39,6 @@ class SettingsController extends DTPStateController { constructor() { super("settings") - console.log("does bookmark exist?", !!this.state.permissions.bookmark) } updateSetting< @@ -55,22 +47,6 @@ class SettingsController extends DTPStateController { >(group: G, key: K, value: SettingsControllerState[G][K]) { this.state[group][key] = value } - - async setBookmark(bookmark: string) { - await this.clearBookmark() - - this.state.permissions.bookmark = bookmark - await resolveBookmark(bookmark) - } - - async clearBookmark() { - const currentBookmark = this.state.permissions.bookmark - if (currentBookmark) { - await stopAccessingBookmark(currentBookmark) - } - - this.state.permissions.bookmark = null - } } export default SettingsController diff --git a/src/dtProjects/state/types.ts b/src/dtProjects/state/types.ts index fd3b1c8..0a84b0c 100644 --- a/src/dtProjects/state/types.ts +++ b/src/dtProjects/state/types.ts @@ -1,3 +1,5 @@ +import type { ProjectExtra } from "@/commands" +import type { ScanProgress } from "@/commands/DtpServiceTypes" import type { IContainer } from "@/utils/container/interfaces" import type { JobQueue, JobResult, 
JobSpec, JobUnion } from "@/utils/container/queue" import { Service } from "@/utils/container/Service" @@ -6,8 +8,6 @@ import type DetailsService from "./details" import type ImagesController from "./images" import type ModelsController from "./models" import type ProjectsController from "./projects" -import type ScannerService from "./scanner" -import type { ProjectJobPayload } from "./scanner" import type SearchController from "./search" import type SettingsController from "./settings" import type { UIController } from "./uiState" @@ -35,12 +35,8 @@ export type DTProjectsJobs = { data: ProjectFileStats[] result: never } - "project-update": { - data: ProjectJobPayload - result: never - } "project-remove": { - data: string + data: number result: never } "project-folder-scan": { @@ -83,6 +79,26 @@ export type DTProjectsContainer = IContainer export type DTPEvents = { watchFoldersChanged: (payload: WatchFoldersChangedPayload) => void projectFilesChanged: (payload: ProjectFilesChangedPayload) => void + + watch_folders_changed: () => void + project_added: (payload: ProjectExtra) => void + project_removed: (payload: number) => void + project_updated: (payload: ProjectExtra) => void + projects_changed: () => void + + import_started: () => void + import_progress: (payload: ScanProgress) => void + import_completed: () => void + + sync_started: () => void + sync_complete: () => void + + folder_sync_started: (payload: number) => void + folder_sync_complete: (payload: number) => void + + dtp_service_ready: () => void + projectsLoaded: (payload?: undefined) => void + imagesChanged: (payload?: undefined) => void } export interface WatchFoldersChangedPayload { @@ -109,7 +125,6 @@ export interface DTPServices { projects: ProjectsController models: ModelsController watchFolders: WatchFoldersController - scanner: ScannerService search: SearchController images: ImagesController details: DetailsService diff --git a/src/dtProjects/state/uiState.ts b/src/dtProjects/state/uiState.ts 
index 14d9441..ed70556 100644 --- a/src/dtProjects/state/uiState.ts +++ b/src/dtProjects/state/uiState.ts @@ -1,7 +1,8 @@ import { proxy, ref, useSnapshot } from "valtio" -import { type DTImageFull, dtProject, type TensorHistoryExtra } from "@/commands" +import type { DTImageFull, ImageExtra, TensorHistoryExtra } from "@/commands" +import DTPService from "@/commands/DtpService" +import type { ScanProgress } from "@/commands/DtpServiceTypes" import urls from "@/commands/urls" -import type { ImageExtra } from "@/generated/types" import { uint8ArrayToBase64 } from "@/utils/helpers" import { drawPose, pointsToPose, tensorToPoints } from "@/utils/pose" import type { ProjectState } from "./projects" @@ -37,6 +38,12 @@ export type UIControllerState = { isSettingsOpen: boolean isGridInert: boolean importLock: boolean + importLockCount: number + importProgress?: { + found: number + scanned: number + imageCount: number + } } type Handler = (payload: T) => void @@ -60,10 +67,15 @@ export class UIController extends DTPStateController { isSettingsOpen: false, isGridInert: false, importLock: false, + importLockCount: 0, }) constructor() { super("uiState") + + this.container.on("import_started", () => this.startImport()) + this.container.on("import_progress", (progress) => this.updateImport(progress)) + this.container.on("import_completed", () => this.endImport()) } onItemChanged: Handler<{ item: ImageExtra | null }>[] = [] @@ -100,17 +112,31 @@ export class UIController extends DTPStateController { get importLockPromise() { return this._importLockPromise } - /** show/hide the import lock */ - setImportLock(lock: boolean) { - this.state.importLock = lock - if (lock) { - this._importLockPromise = new Promise((resolve) => { - this._importLockResolver = resolve - }) - } else { - this._importLockResolver?.() + startImport() { + this.state.importLock = true + const { promise, resolve } = Promise.withResolvers() + this._importLockPromise = promise + this._importLockResolver = resolve + 
this.state.importLockCount++ + this.state.importProgress = { + found: 0, + scanned: 0, + imageCount: 0, } } + endImport() { + this.state.importLock = false + this.state.importProgress = undefined + this._importLockResolver?.() + this._importLockResolver = null + } + updateImport(progress: ScanProgress) { + const total = this.state.importProgress + if (!total) return + total.found += progress.projects_found + total.scanned += progress.projects_scanned + total.imageCount += progress.images_scanned + } async showDetailsOverlay(item: ImageExtra) { const detailsOverlay = this.state.detailsView @@ -179,7 +205,7 @@ export class UIController extends DTPStateController { } async showSubItemPose(projectId: number, tensorId: string) { - const poseData = await dtProject.decodeTensor(projectId, tensorId, false) + const poseData = await DTPService.decodeTensor(projectId, tensorId, false) const points = tensorToPoints(poseData) const pose = pointsToPose(points, 1024, 1024) const image = await drawPose(pose, 4) @@ -194,7 +220,7 @@ export class UIController extends DTPStateController { } async showSubItemImage(projectId: number, tensorId: string) { - const size = await dtProject.getTensorSize(projectId, tensorId) + const size = await DTPService.getTensorSize(projectId, tensorId) const loadImg = new Image() loadImg.onload = () => { const details = this.state.detailsView diff --git a/src/dtProjects/state/watchFolders.ts b/src/dtProjects/state/watchFolders.ts index a3796a3..db3f5a0 100644 --- a/src/dtProjects/state/watchFolders.ts +++ b/src/dtProjects/state/watchFolders.ts @@ -1,18 +1,10 @@ import { path } from "@tauri-apps/api" -import { - exists, - readDir, - stat, - type UnwatchFn, - type WatchEvent, - watch, -} from "@tauri-apps/plugin-fs" import { proxy } from "valtio" -import { pdb, type WatchFolder } from "@/commands" +import type { WatchFolder } from "@/commands" +import DTPService from "@/commands/DtpService" import { makeSelectable, type Selectable } from 
"@/hooks/useSelectableV" import va from "@/utils/array" -import { DebounceMap } from "@/utils/DebounceMap" -import { arrayIfOnly, compareItems } from "@/utils/helpers" +import { arrayIfOnly } from "@/utils/helpers" import { DTPStateController } from "./types" const modelInfoFilenames = { @@ -27,7 +19,10 @@ const modelInfoFilenames = { export type WatchFoldersControllerState = { folders: WatchFolderState[] - hasDefaultDataFolder: boolean + isDtFolderAdded: boolean + homePath: string | null + containerPath: string | null + defaultDataFolder: string | null } export type WatchFolderState = Selectable< @@ -35,6 +30,7 @@ export type WatchFolderState = Selectable< isMissing?: boolean selected?: boolean firstScan?: boolean + isDtData?: boolean } > @@ -48,6 +44,7 @@ export type ProjectFileStats = { size: number modified: number watchFolderPath?: string + watchFolderId?: number } export type ListFilesResult = { @@ -65,291 +62,88 @@ export type ListFilesResult = { export class WatchFoldersController extends DTPStateController { state = proxy({ folders: [] as WatchFolderState[], - hasDefaultDataFolder: false, + isDtFolderAdded: false, + homePath: null, + containerPath: null, + defaultDataFolder: null, }) async assignPaths() { - this._home = await path.homeDir() - this._containerPath = await path.join(this._home, "Library/Containers/com.liuliu.draw-things/Data") - this._defaultDataFolder = await path.join(this._containerPath, "Documents") - } - - _home: string = "" - _containerPath: string = "" - _defaultDataFolder: string = "" - - watchDisposers = new Map>() - watchCallbacks = new DebounceMap(1500) - - constructor() { - super("watchFolders", "watchfolders") - } + this.state.homePath = await path.homeDir() + this.state.containerPath = await path.join( + this.state.homePath, + "Library/Containers/com.liuliu.draw-things/Data", + ) + this.state.defaultDataFolder = await path.join(this.state.containerPath, "Documents") - override async handleTags(_tags: string, _desc: Record) { - 
await this.loadWatchFolders() - return true + this.state.isDtFolderAdded = this.state.folders.some( + (f) => f.path === this.state.defaultDataFolder, + ) } - async loadWatchFolders(supressEvent = false) { - const res = await pdb.watchFolders.listAll() - const folders = res.map((f) => makeSelectable(f as WatchFolderState)) - - for (const folder of folders) { - if (!this.state.hasDefaultDataFolder && folder.path === this._defaultDataFolder) { - this.state.hasDefaultDataFolder = true - } - folder.isMissing = !(await exists(folder.path)) - } - - const prevFolders = [...this.state.folders] - va.set(this.state.folders, folders) - - const diff = compareItems(prevFolders, folders, (f) => f.id, { ignoreFunctions: true }) - if (!diff.itemsChanged) return - - // why stop and start watching changed? - for (const folder of [...diff.removed, ...diff.changed]) { - this.stopWatch(folder.path) - } - - for (const folder of [...diff.added, ...diff.changed]) { - this.startWatch(folder) - } + constructor() { + super("watchFolders") - if (!supressEvent) this.container.emit("watchFoldersChanged", { ...diff }) - } + this.container.on("watch_folders_changed", async () => { + await this.loadWatchFolders() + await this.container.getService("projects").loadProjects() + }) - // it is not necessary to reload after adding - tags will invalidate - async addWatchFolder(folderPath: string, recursive = false) { - if (await exists(folderPath)) { - const isDtFolder = folderPath === this._defaultDataFolder - await pdb.watchFolders.add(folderPath, recursive || isDtFolder) - } else { - throw new Error("DNE") - } + this.assignPaths().then(() => {}) } - // it is not necessary to reload after removing - tags will invalidate - async removeWatchFolders(folder: WatchFolderState): Promise - async removeWatchFolders(folders: readonly WatchFolderState[]): Promise - async removeWatchFolders(arg: WatchFolderState | readonly WatchFolderState[]): Promise { - const folders = arrayIfOnly(arg) - await 
pdb.watchFolders.remove(folders.map((f) => f.id)) - if (folders.some((f) => f.path === this._defaultDataFolder)) - this.state.hasDefaultDataFolder = false + async loadWatchFolders() { + const res = await DTPService.listWatchFolders() + return this.setWatchfolders(res) } - async setRecursive(folder: WatchFolderState | readonly WatchFolderState[], value: boolean) { - // disallow changing recursive on default folder - const toUpdate = arrayIfOnly(folder).filter((f) => f.path !== this._defaultDataFolder) - for (const folder of toUpdate) { - const updFolder = await pdb.watchFolders.update(folder.id, value) + private setWatchfolders(folders: WatchFolder[]) { + const foldersState = folders.map((f) => this.createWatchFolderState(f)) + this.state.isDtFolderAdded = foldersState.some((folder) => folder.isDtData) - // TODO: is this necessary? I don't think so... - const idx = this.state.folders.findIndex((f) => f.id === folder.id) - if (idx !== -1) { - this.state.folders[idx].recursive = updFolder.recursive - } - } + va.set(this.state.folders, foldersState) + return this.state.folders } - async addDefaultDataFolder() { - await this.addWatchFolder(this._defaultDataFolder, true) + private createWatchFolderState(folder: WatchFolder): WatchFolderState { + return makeSelectable({ ...folder, isDtData: folder.path === this.state.defaultDataFolder }) } - async listFiles(folder: WatchFolderState): Promise { - const result: ListFilesResult = { - projects: [], - models: [], - isMissing: false, - } - - if (!exists(folder.path)) { - result.isMissing = true - return result - } - - const toCheck = [folder.path] - - async function readFolder(currentFolder: string) { - try { - const files = await readDir(currentFolder) - for (const file of files) { - const filePath = await path.join(currentFolder, file.name) - // add folders to list - if (file.isDirectory) { - toCheck.push(filePath) - } - // check project files - this also will check the -wal file - else if (file.name.endsWith(".sqlite3")) { - 
const fileStats = await stat(filePath) - if (!fileStats) continue - - const walPath = filePath + "-wal" - const walStats = (await exists(walPath)) ? await stat(walPath) : undefined - - const project: ProjectFileStats = { - path: filePath, - size: fileStats.size + (walStats?.size ?? 0), - modified: Math.max( - fileStats.mtime?.getTime() ?? 0, - walStats?.mtime?.getTime() ?? 0, - ), - watchFolderPath: currentFolder, - } - result.projects.push(project) - } - // check model files - else if (file.name.endsWith(".json") && file.name in modelInfoFilenames) { - result.models.push({ - path: filePath, - modelType: modelInfoFilenames[file.name], - }) - } - } - } catch (e) { - console.error(e) - } - return result - } - - while (toCheck.length > 0) { - const currentFolder = toCheck.shift() - if (!currentFolder) continue - await readFolder(currentFolder) - if (!folder.recursive) break + async pickDtFolder() { + try { + await DTPService.pickWatchFolder(true) + return true + } catch (e) { + throw e } - - return result } - // TODO: deprecate - async listProjects(folder: WatchFolderState): Promise { + async pickWatchFolder() { try { - if (!(await exists(folder.path))) { - folder.isMissing = true - return [] - } - folder.isMissing = false - const projects = await findFiles(folder.path, folder.recursive, (f) => - f.endsWith(".sqlite3"), - ) - - return projects + await DTPService.pickWatchFolder(false) + return true } catch (e) { console.error(e) - return [] + return false } } - async getFolderForProject(project: string): Promise { - const folders = [] as WatchFolderState[] - for (const folder of this.state.folders) { - const sep = await path.sep() - const folderWithSep = folder.path.endsWith(sep) ? 
folder.path : folder.path + sep - if (project.startsWith(folderWithSep)) { - const projectDir = await path.dirname(project) - if (projectDir === folder.path || folder.recursive) { - folders.push(folder) - } - } + async removeWatchFolders(folder: WatchFolderState): Promise + async removeWatchFolders(folders: readonly WatchFolderState[]): Promise + async removeWatchFolders(arg: WatchFolderState | readonly WatchFolderState[]): Promise { + const folders = arrayIfOnly(arg) + for (const folder of folders) { + await DTPService.removeWatchFolder(folder.id) } - - folders.sort((a, b) => b.path.length - a.path.length) - return folders[0] ?? undefined - } - - async startWatch(folder: WatchFolderState) { - if (this.watchDisposers.has(folder.path)) - throw new Error(`must stop watching folder first, ${folder.path}`) - - console.debug("starting watch", folder.path) - const unwatch = watch( - folder.path, - async (e) => { - if (!shouldReact(e)) return - const projectFiles = e.paths - .filter((p) => p.endsWith(".sqlite3") || p.endsWith(".sqlite3-wal")) - .map((p) => p.replace(/-wal$/g, "")) - if (projectFiles.length === 0) return - console.debug("watch event", JSON.stringify(e)) - const uniqueFiles = Array.from(new Set(projectFiles)) - - for (const file of uniqueFiles) { - this.watchCallbacks.set(file, () => { - this.container.emit("projectFilesChanged", { files: [file] }) - }) - } - }, - { delayMs: 1500, recursive: folder.recursive }, - ) - this.watchDisposers.set(folder.path, unwatch) - } - - async stopWatch(folder: string) { - if (!this.watchDisposers.has(folder)) return - console.debug("stopping watch", folder) - const unwatchPromise = this.watchDisposers.get(folder) - this.watchDisposers.delete(folder) - - const unwatch = await unwatchPromise - unwatch?.() } - get defaultProjectPath() { - return this._defaultDataFolder - } - - get containerPath() { - return this._containerPath - } - - override dispose() { - super.dispose() - - for (const folder of this.watchDisposers.keys()) { 
- this.stopWatch(folder) + async setRecursive(folder: WatchFolderState | readonly WatchFolderState[], recursive: boolean) { + // disallow changing recursive on default folder + const toUpdate = arrayIfOnly(folder).filter((f) => f.path !== this.state.defaultDataFolder) + for (const folder of toUpdate) { + await DTPService.updateWatchFolder(folder.id, recursive) } } } export default WatchFoldersController - -// TODO: remove -async function findFiles( - directory: string, - recursive: boolean, - filterFn: (file: string) => boolean, -) { - const files = [] as string[] - const dirFiles = await readDir(directory) - for (const file of dirFiles) { - if (file.isDirectory && recursive) { - files.push( - ...(await findFiles(await path.join(directory, file.name), recursive, filterFn)), - ) - } - - if (!file.isFile) continue - if (!filterFn(file.name)) continue - files.push(await path.join(directory, file.name)) - } - return files -} - -function shouldReact(event: WatchEvent) { - if (event.paths.every((p) => p.endsWith("shm"))) return false - - const type = event.type as object - - if ("access" in type) return false - if ("remove" in type) return true - if ("create" in type) return true - if ("modify" in type && type.modify && typeof type.modify === "object") { - // only react to changes in the file, not metadata changes - if ("kind" in type.modify && type.modify.kind === "metadata") return false - return true - } - - return true -} diff --git a/src/dtProjects/types.ts b/src/dtProjects/types.ts index e3a772d..0ef63c3 100644 --- a/src/dtProjects/types.ts +++ b/src/dtProjects/types.ts @@ -1,47 +1,10 @@ -import type { Model, XTensorHistoryNode } from "@/commands" +import type { Model, ScanProgress } from "@/commands" import type { BackendFilter } from "./state/search" -export type ScanProgress = { - projects_scanned: number - projects_total: number - project_final: number - project_path: string - images_scanned: number - images_total: number -} export type ScanProgressEvent = { 
payload: ScanProgress } -export type DTImage = { - image_id: number - project_id: number - model_id: number - model_file: string - prompt: string - negative_prompt: string - dt_id: number - row_id: number - wall_clock: string -} - -export interface TensorHistoryExtra { - rowId: number - lineage: number - logicalTime: number - - tensorId?: string | null - maskId?: string | null - depthMapId?: string | null - scribbleId?: string | null - poseId?: string | null - colorPaletteId?: string | null - customId?: string | null - - history: XTensorHistoryNode - projectPath: string -} - export type ImagesSource = { projectIds?: number[] search?: string diff --git a/src/generated/commands.ts b/src/generated/commands.ts deleted file mode 100644 index 5dbd018..0000000 --- a/src/generated/commands.ts +++ /dev/null @@ -1,142 +0,0 @@ -// This file was auto-generated by tauri-ts-generator -// Do not edit this file manually - -import { invoke } from "@tauri-apps/api/core"; -import type { ListImagesFilter, ListImagesResult, ModelExtra, ProjectExtra, TensorHistoryClip, TensorHistoryExtra, TensorHistoryImport, TensorRaw, TensorSize, TextHistoryNodeDTO, WatchFolderDTO } from "./types"; - -export async function readClipboardTypes(pasteboard: string | null): Promise { - return invoke("read_clipboard_types", { pasteboard }); -} - -export async function readClipboardStrings(types: string[], pasteboard: string | null): Promise> { - return invoke>("read_clipboard_strings", { types, pasteboard }); -} - -export async function readClipboardBinary(ty: string, pasteboard: string | null): Promise { - return invoke("read_clipboard_binary", { ty, pasteboard }); -} - -export async function writeClipboardBinary(ty: string, data: number[]): Promise { - return invoke("write_clipboard_binary", { ty, data }); -} - -export async function ffmpegCheck(): Promise { - return invoke("ffmpeg_check"); -} - -export async function ffmpegDownload(): Promise { - return invoke("ffmpeg_download"); -} - -export async 
function ffmpegCall(args: string[]): Promise { - return invoke("ffmpeg_call", { args }); -} - -export async function fetchImageFile(url: string): Promise { - return invoke("fetch_image_file", { url }); -} - -export async function showDevWindow(): Promise { - return invoke("show_dev_window"); -} - -export async function stemAll(): Promise { - return invoke("stem_all"); -} - -export async function createVideoFromFrames(imageId: number): Promise { - return invoke("create_video_from_frames", { imageId }); -} - -export async function projectsDbImageCount(): Promise { - return invoke("projects_db_image_count"); -} - -export async function projectsDbProjectAdd(path: string): Promise { - return invoke("projects_db_project_add", { path }); -} - -export async function projectsDbProjectRemove(path: string): Promise { - return invoke("projects_db_project_remove", { path }); -} - -export async function projectsDbProjectList(): Promise { - return invoke("projects_db_project_list"); -} - -export async function projectsDbProjectUpdateExclude(id: number, exclude: boolean): Promise { - return invoke("projects_db_project_update_exclude", { id, exclude }); -} - -export async function projectsDbProjectScan(path: string, fullScan: boolean | null, filesize: number | null, modified: number | null): Promise { - return invoke("projects_db_project_scan", { path, fullScan, filesize, modified }); -} - -export async function projectsDbImageList(projectIds: number[] | null, search: string | null, filters: ListImagesFilter[] | null, sort: string | null, direction: string | null, take: number | null, skip: number | null, count: boolean | null, showVideo: boolean | null, showImage: boolean | null): Promise { - return invoke("projects_db_image_list", { projectIds, search, filters, sort, direction, take, skip, count, showVideo, showImage }); -} - -export async function projectsDbGetClip(imageId: number): Promise { - return invoke("projects_db_get_clip", { imageId }); -} - -export async function 
projectsDbImageRebuildFts(): Promise { - return invoke("projects_db_image_rebuild_fts"); -} - -export async function projectsDbWatchFolderList(): Promise { - return invoke("projects_db_watch_folder_list"); -} - -export async function projectsDbWatchFolderAdd(path: string, itemType: entity::enums::ItemType, recursive: boolean): Promise { - return invoke("projects_db_watch_folder_add", { path, itemType, recursive }); -} - -export async function projectsDbWatchFolderRemove(ids: number[]): Promise { - return invoke("projects_db_watch_folder_remove", { ids }); -} - -export async function projectsDbWatchFolderUpdate(id: number, recursive: boolean | null, lastUpdated: number | null): Promise { - return invoke("projects_db_watch_folder_update", { id, recursive, lastUpdated }); -} - -export async function projectsDbScanModelInfo(filePath: string, modelType: entity::enums::ModelType): Promise { - return invoke("projects_db_scan_model_info", { filePath, modelType }); -} - -export async function projectsDbListModels(modelType: entity::enums::ModelType | null): Promise { - return invoke("projects_db_list_models", { modelType }); -} - -export async function dtProjectGetTensorHistory(projectFile: string, index: number, count: number): Promise { - return invoke("dt_project_get_tensor_history", { projectFile, index, count }); -} - -export async function dtProjectGetTextHistory(projectFile: string): Promise { - return invoke("dt_project_get_text_history", { projectFile }); -} - -export async function dtProjectGetThumbHalf(projectFile: string, thumbId: number): Promise { - return invoke("dt_project_get_thumb_half", { projectFile, thumbId }); -} - -export async function dtProjectGetHistoryFull(projectFile: string, rowId: number): Promise { - return invoke("dt_project_get_history_full", { projectFile, rowId }); -} - -export async function dtProjectGetTensorRaw(projectId: number | null, projectPath: string | null, tensorId: string): Promise { - return invoke("dt_project_get_tensor_raw", 
{ projectId, projectPath, tensorId }); -} - -export async function dtProjectGetTensorSize(projectId: number | null, projectPath: string | null, tensorId: string): Promise { - return invoke("dt_project_get_tensor_size", { projectId, projectPath, tensorId }); -} - -export async function dtProjectDecodeTensor(projectId: number | null, projectFile: string | null, nodeId: number | null, tensorId: string, asPng: boolean): Promise { - return invoke("dt_project_decode_tensor", { projectId, projectFile, nodeId, tensorId, asPng }); -} - -export async function dtProjectFindPredecessorCandidates(projectFile: string, rowId: number, lineage: number, logicalTime: number): Promise { - return invoke("dt_project_find_predecessor_candidates", { projectFile, rowId, lineage, logicalTime }); -} - diff --git a/src/generated/types.ts b/src/generated/types.ts index e518c3e..36dc285 100644 --- a/src/generated/types.ts +++ b/src/generated/types.ts @@ -230,8 +230,8 @@ export interface WatchFolderDTO { id: number path: string recursive: boolean | null - item_type: string last_updated: number | null + is_missing: boolean } export interface TensorHistoryClip { @@ -290,6 +290,10 @@ export interface ProjectExtra { modified: number | null missing_on: number | null excluded: boolean + name: string + full_path: string + is_missing: boolean + watchfolder_id: number } export type TextType = "PositiveText" | "NegativeText" diff --git a/src/hooks/appState.ts b/src/hooks/appState.ts index a40bdf6..963fbe8 100644 --- a/src/hooks/appState.ts +++ b/src/hooks/appState.ts @@ -105,7 +105,7 @@ async function retryUpdate() { await checkForUpdate() } -async function setView(view: string) { +function setView(view: string) { appState.currentView = view if (!appState.viewRequests[view]) appState.viewRequests[view] = [] localStorage.setItem("currentView", view) diff --git a/src/metadata/useMetadataDrop.ts b/src/hooks/useDrop.ts similarity index 66% rename from src/metadata/useMetadataDrop.ts rename to 
src/hooks/useDrop.ts index ffd923e..9410074 100644 --- a/src/metadata/useMetadataDrop.ts +++ b/src/hooks/useDrop.ts @@ -1,17 +1,10 @@ import { getCurrentWindow } from "@tauri-apps/api/window" -import { useMemo, useRef } from "react" -import { proxy, useSnapshot } from "valtio" -import { loadImage2 } from "./state/imageLoaders" +import { useMemo } from "react" +import { handleDrop } from '@/metadata/state/interop' +import { useProxyRef } from './valtioHooks' export function useMetadataDrop() { - const stateRef = useRef<{ isDragging: boolean; dragCounter: number } | null>(null) - - if (stateRef.current === null) { - stateRef.current = proxy({ isDragging: true, dragCounter: 0 }) - } - const state = stateRef.current - - const snap = useSnapshot(stateRef.current) + const {state, snap} = useProxyRef(() => ({ isDragging: true, dragCounter: 0 })) const handlers = useMemo( () => ({ @@ -23,8 +16,7 @@ export function useMetadataDrop() { state.isDragging = false state.dragCounter = 0 getCurrentWindow().setFocus() - // loadFromPasteboard("drag") - loadImage2("drag") + handleDrop("drag") }, onDragEnter: (e: React.DragEvent) => { e.preventDefault() diff --git a/src/hooks/useSelectableV.tsx b/src/hooks/useSelectableV.tsx index 2875240..b66569f 100644 --- a/src/hooks/useSelectableV.tsx +++ b/src/hooks/useSelectableV.tsx @@ -1,194 +1,209 @@ import { - type Context, - createContext, - memo, - type PropsWithChildren, - useContext, - useMemo, - useRef, + type Context, + createContext, + memo, + type PropsWithChildren, + useContext, + useMemo, + useRef, } from "react" import { proxy, type Snapshot } from "valtio" type SelectableContextType = { - getItems: () => T[] - mode: "single" | "multipleToggle" | "multipleModifier" - onSelectionChanged?: (selectedItems: T[]) => void - keyFn: (item: T | Snapshot) => string | number - lastSelectedItem: React.RefObject - itemsSnap: Snapshot + getItems: () => T[] + mode: "single" | "multipleToggle" | "multipleModifier" + onSelectionChanged?: 
(selectedItems: T[]) => void + keyFn: (item: T | Snapshot) => string | number + lastSelectedItem: React.RefObject + itemsSnap: Snapshot } const SelectableContext = createContext(null) function clearAllSelected(state: SelectableContextType) { - state.getItems().forEach((it) => { - if (it.selected) it.setSelected(false) - }) + state.getItems().forEach((it) => { + if (it.selected) it.setSelected(false) + }) } function selectItem( - state: SelectableContextType, - // key: string | number, - item: T, - modifier?: "shift" | "cmd" | null, - value?: boolean, + state: SelectableContextType, + // key: string | number, + item: T, + modifier?: "shift" | "cmd" | null, + value?: boolean, ) { - const items = state.getItems() - const itemState = items?.find((it) => state.keyFn(it) === state.keyFn(item)) - if (!itemState) return - - // single select mode - if (state.mode === "single") { - const newValue = value ?? !itemState.selected - clearAllSelected(state) - if (newValue) itemState.setSelected(newValue) - } - // multiple toggle mode - else if (state.mode === "multipleToggle") { - const newValue = value ?? !itemState.selected - itemState.setSelected(newValue) - } - // multiple modifier mode - else if (state.mode === "multipleModifier") { - // cmd updates target item only, leaving other selections unchanged - if (modifier === "cmd") { - const newValue = value ?? 
!itemState.selected - itemState.setSelected(newValue) - if (newValue) state.lastSelectedItem.current = itemState - } else if (modifier === "shift" && state.lastSelectedItem.current) { - // this doesn't align with how shift works in Finder - // shift SELECTS all items between the last selected and target item - // and updates the last selected item to the target item - // we need to be careful about state vs snap here - const itemsSnap = state.itemsSnap - const lastItem = state.lastSelectedItem.current - const start = itemsSnap.findIndex(it => state.keyFn(it) === state.keyFn(lastItem)) - const end = itemsSnap.findIndex(it => state.keyFn(it) === state.keyFn(item)) - if (start === -1 || end === -1) return - for (let i = Math.min(start, end); i <= Math.max(start, end); i++) { - const itemSnap = itemsSnap[i] - const it = items.find(i => state.keyFn(i) === state.keyFn(itemSnap)) - if (it) it.setSelected(true) - } - state.lastSelectedItem.current = itemState - } else { - // if modifier not held, undefined selects - unless it's the only selected value - const areOthersSelected = items.some( - (it) => state.keyFn(it) !== state.keyFn(itemState) && it.selected, - ) - const newValue = value ?? (!itemState.selected || areOthersSelected) - clearAllSelected(state) - if (newValue) { - itemState.setSelected(newValue) - state.lastSelectedItem.current = itemState - } - else { - state.lastSelectedItem.current = null - } - } - } - - if (state.onSelectionChanged) { - const selectedItems = items.filter((it) => it.selected) - state.onSelectionChanged(selectedItems) - } + const items = state.getItems() + const itemState = items?.find((it) => state.keyFn(it) === state.keyFn(item)) + if (!itemState) return + + // single select mode + if (state.mode === "single") { + const newValue = value ?? !itemState.selected + clearAllSelected(state) + if (newValue) itemState.setSelected(newValue) + } + // multiple toggle mode + else if (state.mode === "multipleToggle") { + const newValue = value ?? 
!itemState.selected + itemState.setSelected(newValue) + } + // multiple modifier mode + else if (state.mode === "multipleModifier") { + // cmd updates target item only, leaving other selections unchanged + if (modifier === "cmd") { + const newValue = value ?? !itemState.selected + itemState.setSelected(newValue) + if (newValue) state.lastSelectedItem.current = itemState + } else if (modifier === "shift" && state.lastSelectedItem.current) { + // this doesn't align with how shift works in Finder + // shift SELECTS all items between the last selected and target item + // and updates the last selected item to the target item + // we need to be careful about state vs snap here + const itemsSnap = state.itemsSnap + const lastItem = state.lastSelectedItem.current + const start = itemsSnap.findIndex((it) => state.keyFn(it) === state.keyFn(lastItem)) + const end = itemsSnap.findIndex((it) => state.keyFn(it) === state.keyFn(item)) + if (start === -1 || end === -1) return + for (let i = Math.min(start, end); i <= Math.max(start, end); i++) { + const itemSnap = itemsSnap[i] + const it = items.find((i) => state.keyFn(i) === state.keyFn(itemSnap)) + if (it) it.setSelected(true) + } + state.lastSelectedItem.current = itemState + } else { + // if modifier not held, undefined selects - unless it's the only selected value + const areOthersSelected = items.some( + (it) => state.keyFn(it) !== state.keyFn(itemState) && it.selected, + ) + const newValue = value ?? 
(!itemState.selected || areOthersSelected) + clearAllSelected(state) + if (newValue) { + itemState.setSelected(newValue) + state.lastSelectedItem.current = itemState + } else { + state.lastSelectedItem.current = null + } + } + } + + if (state.onSelectionChanged) { + const selectedItems = items.filter((it) => it.selected) + state.onSelectionChanged(selectedItems) + } } type SelectableGroupOptions = { - onSelectionChanged?: (selectedItems: T[]) => void - mode?: "single" | "multipleToggle" | "multipleModifier" - /** This should be a stable reference for cv to memoize */ - keyFn?: (item: T | Snapshot) => string | number + onSelectionChanged?: (selectedItems: T[]) => void + mode?: "single" | "multipleToggle" | "multipleModifier" + /** This should be a stable reference for cv to memoize */ + keyFn?: (item: T | Snapshot) => string | number } const defaultSelectableGroupOptions = { - mode: "single", - keyFn: (item: unknown) => JSON.stringify(item), + mode: "single", + keyFn: (item: unknown) => JSON.stringify(item), } as const export function useSelectableGroup( - itemsSnap: Snapshot, - getItemsState: () => T[], - opts: SelectableGroupOptions = {}, + itemsSnap: Snapshot, + getItemsState: () => T[], + opts: SelectableGroupOptions = {}, ) { - const { mode, keyFn, onSelectionChanged } = { ...defaultSelectableGroupOptions, ...opts } + const { mode, keyFn, onSelectionChanged } = { ...defaultSelectableGroupOptions, ...opts } - const onSelectionChangedRef = useRef((_: T[]) => {}) - if (onSelectionChanged && onSelectionChanged !== onSelectionChangedRef.current) { - onSelectionChangedRef.current = onSelectionChanged - } + const onSelectionChangedRef = useRef((_: T[]) => {}) + if (onSelectionChanged && onSelectionChanged !== onSelectionChangedRef.current) { + onSelectionChangedRef.current = onSelectionChanged + } - const lastSelectedItem = useRef(null) + const lastSelectedItem = useRef(null) - const cv = { - getItems: getItemsState, - itemsSnap, - mode, - keyFn, - onSelectionChanged, 
- lastSelectedItem, - } + const cv = { + getItems: getItemsState, + itemsSnap, + mode, + keyFn, + onSelectionChanged, + lastSelectedItem, + } - const Context = SelectableContext as Context | null> - const SelectableGroup = memo((props: PropsWithChildren) => { - return {props.children} - }) + const Context = SelectableContext as Context | null> + const SelectableGroup = memo((props: PropsWithChildren) => { + return {props.children} + }) - const selectedItems = itemsSnap.filter((item) => item.selected) + const selectedItems = itemsSnap.filter((item) => item.selected) - const clearSelection = () => clearAllSelected(cv) + const clearSelection = () => clearAllSelected(cv) - return { SelectableGroup, selectedItems, clearSelection } + return { SelectableGroup, selectedItems, clearSelection } } export function useSelectable(item: T) { - const context = useContext(SelectableContext) - if (!context) throw new Error("useSelectable must be used within a SelectableGroup") - - const handlers = useMemo( - () => ({ - onClick(e: React.MouseEvent) { - const modifier = getModifier(e) - selectItem(context, item, modifier) - }, - }), - [context, item], - ) - - return { - isSelected: item.selected ?? false, - handlers, - } + const context = useContext(SelectableContext) + if (!context) throw new Error("useSelectable must be used within a SelectableGroup") + + const handlers = useMemo( + () => ({ + onClick(e: React.MouseEvent) { + const modifier = getModifier(e) + selectItem(context, item, modifier) + }, + }), + [context, item], + ) + + return { + isSelected: item.selected ?? 
false, + handlers, + } } export type Selectable = T & { - selected: boolean - setSelected: (value: boolean) => void - toggleSelected: () => void + selected: boolean + setSelected: (value: boolean, modifier?: "shift" | "cmd" | null) => void + toggleSelected: () => void + onClick: (e: React.MouseEvent) => void } -export function makeSelectable(item: T, initialValue = false): Selectable { - const p = proxy({ - ...item, - _selected: initialValue, - get selected() { - return p._selected - }, - setSelected(value: boolean) { - p._selected = value - }, - toggleSelected() { - p._selected = !p._selected - }, - }) - - return p +export function makeSelectable( + item: T, + initialValue = false, + handleClick?: ( + item: Selectable, + currentValue: boolean, + modifier?: "shift" | "cmd" | null, + ) => void, +): Selectable { + const onClick = handleClick ?? ((item: Selectable, value: boolean) => item.toggleSelected()) + + const p = proxy({ + ...item, + _selected: initialValue, + get selected() { + return p._selected + }, + setSelected(value: boolean) { + if (p._selected === value) return + p._selected = value + }, + toggleSelected() { + p.setSelected(!p._selected) + }, + onClick(e: React.MouseEvent) { + const modifier = e.metaKey ? "cmd" : e.shiftKey ? 
"shift" : undefined + onClick(p, p._selected, modifier) + }, + }) + + return p } export function makeSelectableList(items: T[]): Selectable[] { - return items.map((item) => makeSelectable(item)) + return items.map((item) => makeSelectable(item)) } function getModifier(e: React.MouseEvent) { - if (e.shiftKey) return "shift" - if (e.metaKey) return "cmd" - return null + if (e.shiftKey) return "shift" + if (e.metaKey) return "cmd" + return null } diff --git a/src/library/Library.tsx b/src/library/Library.tsx index b1ad0bd..46fd941 100644 --- a/src/library/Library.tsx +++ b/src/library/Library.tsx @@ -4,7 +4,7 @@ import { open } from "@tauri-apps/plugin-dialog" import { readDir, writeTextFile } from "@tauri-apps/plugin-fs" import { proxy, useSnapshot } from "valtio" import { getDrawThingsDataFromExif } from "@/metadata/helpers" -import { getExif } from "@/metadata/state/store" +import { getExif } from "@/metadata/state/metadataStore" import type { DrawThingsMetaData } from "@/types" const store = proxy({ diff --git a/src/main.tsx b/src/main.tsx index 667de65..89279ee 100644 --- a/src/main.tsx +++ b/src/main.tsx @@ -11,10 +11,11 @@ import { HotkeysProvider } from "react-hotkeys-hook" import { themeHelpers } from "./theme/helpers" import { system } from "./theme/theme" import { forwardConsoleAll } from "./utils/tauriLogger" -import App from './App' +import App from "./App" function bootstrap() { - forwardConsoleAll() + if (!import.meta.env.DEV) forwardConsoleAll() + window.toJSON = (object: unknown) => JSON.parse(JSON.stringify(object)) const hash = document.location?.hash?.slice(1) @@ -26,7 +27,9 @@ function bootstrap() { themeHelpers.applySize() if (import.meta.env.DEV) { - const _global = globalThis as unknown as { _devKeyPressHandler?: (e: KeyboardEvent) => void } + const _global = globalThis as unknown as { + _devKeyPressHandler?: (e: KeyboardEvent) => void + } if (_global._devKeyPressHandler) { window.removeEventListener("keypress", _global._devKeyPressHandler) } 
diff --git a/src/menu.ts b/src/menu.ts index 411f8ae..ee48a11 100644 --- a/src/menu.ts +++ b/src/menu.ts @@ -14,11 +14,11 @@ import { subscribe } from "valtio" import { toggleColorMode } from "./components/ui/color-mode" import { postMessage } from "./context/Messages" import AppStore from "./hooks/appState" +import { loadImage2 } from "./metadata/state/imageLoaders" +import { clearAll, clearCurrent, createImageItem } from "./metadata/state/metadataStore" import { themeHelpers } from "./theme/helpers" import { getLocalImage } from "./utils/clipboard" import { viewDescription } from "./views" -import { clearAll, clearCurrent, createImageItem, getMetadataStore } from './metadata/state/store' -import { loadImage2 } from './metadata/state/imageLoaders' const Separator = () => PredefinedMenuItem.new({ item: "Separator" }) @@ -106,14 +106,10 @@ const fileSubmenu = await Submenu.new({ if (imagePath == null) return const image = await getLocalImage(imagePath) if (image) - await createImageItem( - image, - await pathLib.extname(imagePath), - { - source: "open", - file: imagePath, - }, - ) + await createImageItem(image, await pathLib.extname(imagePath), { + source: "open", + file: imagePath, + }) }, }), await MenuItem.new({ diff --git a/src/metadata/Metadata.tsx b/src/metadata/Metadata.tsx index f480331..c1c882c 100644 --- a/src/metadata/Metadata.tsx +++ b/src/metadata/Metadata.tsx @@ -4,40 +4,38 @@ import CurrentImage from "./components/CurrentImage" import History from "./history/History" import InfoPanel from "./infoPanel/InfoPanel" import { loadImage2 } from "./state/imageLoaders" -import { selectImage } from "./state/store" +import { selectImage } from "./state/metadataStore" import Toolbar from "./toolbar/Toolbar" -import { useMetadataDrop } from "./useMetadataDrop" function Metadata(props: ChakraProps) { - const { ...restProps } = props - const { handlers } = useMetadataDrop() + const { ...restProps } = props - useEffect(() => { - const handler = () => 
loadImage2("general") - const escHandler = (e: KeyboardEvent) => { - if (e.key === "Escape") { - selectImage(null) - } - } - window.addEventListener("paste", handler) - window.addEventListener("keydown", escHandler, { capture: false }) + useEffect(() => { + const handler = () => loadImage2("general") + const escHandler = (e: KeyboardEvent) => { + if (e.key === "Escape") { + selectImage(null) + } + } + window.addEventListener("paste", handler) + window.addEventListener("keydown", escHandler, { capture: false }) - return () => { - window.removeEventListener("paste", handler) - window.removeEventListener("keydown", escHandler) - } - }, []) + return () => { + window.removeEventListener("paste", handler) + window.removeEventListener("keydown", escHandler) + } + }, []) - return ( - - - - - - - - - ) + return ( + + + + + + + + + ) } export default Metadata diff --git a/src/metadata/components/CurrentImage.tsx b/src/metadata/components/CurrentImage.tsx index 0396506..24bde57 100644 --- a/src/metadata/components/CurrentImage.tsx +++ b/src/metadata/components/CurrentImage.tsx @@ -1,70 +1,76 @@ import { Box, chakra, Flex } from "@chakra-ui/react" -import { AnimatePresence, motion, } from "motion/react" +import { AnimatePresence, motion } from "motion/react" import { useRef } from "react" import { useSnapshot } from "valtio" import { showPreview } from "@/components/preview" -import { getMetadataStore } from "../state/store" +import { getMetadataStore } from "../state/metadataStore" interface CurrentImageProps extends ChakraProps {} function CurrentImage(props: CurrentImageProps) { - const { ...restProps } = props + const { ...restProps } = props - const snap = useSnapshot(getMetadataStore()) - const { currentImage } = snap + const snap = useSnapshot(getMetadataStore()) + console.log("current image", snap.currentImage) + const { currentImage } = snap - const imgRef = useRef(null) + const imgRef = useRef(null) - return ( - - - {currentImage?.url ? 
( - showPreview(e.currentTarget)} - initial={{ opacity: 0, zIndex: 1 }} - animate={{ opacity: 1 }} - exit={{ opacity: 0, zIndex: 0, transition: {duration: 0} }} - transition={{duration: 0}} - /> - ) : ( - - Drop image here - - )} - - - ) + return ( + + + {currentImage?.url ? ( + showPreview(e.currentTarget)} + initial={{ opacity: 0, zIndex: 1 }} + animate={{ opacity: 1 }} + exit={{ opacity: 0, zIndex: 0, transition: { duration: 0 } }} + transition={{ duration: 0 }} + /> + ) : ( + + Drop image here + + )} + + + ) } export default CurrentImage export const Img = motion.create( - chakra( - "img", - { - base: { - maxWidth: "100%", - maxHeight: "100%", - minWidth: 0, - minHeight: 0, - borderRadius: "sm", - boxShadow: "pane1", - }, - }, - { defaultProps: { draggable: false } }, - ), + chakra( + "img", + { + base: { + maxWidth: "100%", + maxHeight: "100%", + minWidth: 0, + minHeight: 0, + borderRadius: "sm", + boxShadow: "pane1", + }, + }, + { defaultProps: { draggable: false } }, + ), ) diff --git a/src/metadata/components/CurrentImageZoomPan.tsx b/src/metadata/components/CurrentImageZoomPan.tsx deleted file mode 100644 index 9633055..0000000 --- a/src/metadata/components/CurrentImageZoomPan.tsx +++ /dev/null @@ -1,111 +0,0 @@ -// @ts-nocheck - -import { Box, chakra, Flex } from "@chakra-ui/react" -import { getMetadataStore } from "../state/store" -import { useSnapshot } from "valtio" -import { motion, useMotionValue, useSpring } from 'motion/react' -import { useRef } from 'react' - -interface CurrentImageProps extends ChakraProps {} - -function CurrentImage(props: CurrentImageProps) { - const { ...restProps } = props - - const snap = useSnapshot(getMetadataStore()) - const { currentImage } = snap - - const zoomMv = useSpring(1, {bounce: 0, visualDuration: 0.2}) - const originMv = useMotionValue("0px 0px") - const offX = useSpring(0, { bounce: 0, visualDuration: 0.2 }) - const offY = useSpring(0, { bounce: 0, visualDuration: 0.2 }) - - const pinchXY = useRef([0, 0]) - 
const imgRef = useRef(null) - - useEffect(() => { - if (currentImage?.id) { - zoomMv.set(1) - offX.set(0) - offY.set(0) - } - - }, [currentImage?.id, offX, offY, zoomMv]) - - return ( - { - if (e.ctrlKey === true) { - e.preventDefault() - const box = imgRef.current.getBoundingClientRect() - const mx = e.clientX - box.x - const my = e.clientY - box.y - - const scale = zoomMv.get() - const newScale = Math.min(Math.max(scale * (1 - e.deltaY), 0.5), 20) // e.deltaY > 0 ? scale * 0.8 : scale * 1.2 - - const rx = mx / box.width - const ry = my / box.height - - const mx2 = box.width / scale * newScale * rx - const my2 = box.height / scale * newScale * ry - - const ox = offX.get() + mx - mx2 - const oy = offY.get() + my - my2 - - offX.set(ox) - offY.set(oy) - zoomMv.set(newScale) - } - else { - offX.set(offX.get() - e.deltaX * 10) - offY.set(offY.get() - e.deltaY * 10) - } - }} - {...restProps} - > - {currentImage?.url ? ( - showPreview(e.currentTarget)} - style={{ - scale: zoomMv, - transformOrigin: "0 0", - x: offX, - y: offY, - }} - /> - ) : ( - - Drop image here - - )} - - ) -} - -export default CurrentImage - -export const Img = motion.create(chakra( - "img",{ - base: { - maxWidth: "100%", - maxHeight: "100%", - minWidth: 0, - minHeight: 0, - borderRadius: "sm", - boxShadow: "pane1" - }, - }), -) \ No newline at end of file diff --git a/src/metadata/helpers.ts b/src/metadata/helpers.ts index 8bccb9d..f3855f4 100644 --- a/src/metadata/helpers.ts +++ b/src/metadata/helpers.ts @@ -1,5 +1,5 @@ import type { DrawThingsMetaData } from "@/types" -import type { ExifType } from './state/store' +import type { ExifType } from './state/metadataStore' export function hasDrawThingsData( exif?: unknown, diff --git a/src/metadata/history/History.tsx b/src/metadata/history/History.tsx index daba15e..b67eb92 100644 --- a/src/metadata/history/History.tsx +++ b/src/metadata/history/History.tsx @@ -3,139 +3,139 @@ import { motion, useMotionValue } from "motion/react" import { useCallback, 
useRef } from "react" import { useSnapshot } from "valtio" import type { ImageItem } from "../state/ImageItem" -import { getMetadataStore, selectImage } from "../state/store" +import { getMetadataStore, selectImage } from "../state/metadataStore" import HistoryItem from "./HistoryItem" interface HistoryProps extends Omit {} function History(props: HistoryProps) { - const { ...restProps } = props + const { ...restProps } = props - const snap = useSnapshot(getMetadataStore()) - const { images, currentImage } = snap + const snap = useSnapshot(getMetadataStore()) + const { images, currentImage } = snap - const pinned = images.filter((i) => i.pin != null) as ImageItem[] - const unpinned = images.filter((i) => i.pin == null) as ImageItem[] - const imageItems = [...pinned, ...unpinned] as ReadonlyState + const pinned = images.filter((i) => i.pin != null) as ImageItem[] + const unpinned = images.filter((i) => i.pin == null) as ImageItem[] + const imageItems = [...pinned, ...unpinned] as ReadonlyState - const scrollRef = useRef(null) + const scrollRef = useRef(null) - const scrollbarLeft = useMotionValue(0) - const scrollbarRight = useMotionValue(0) - const scrollbarBottom = useMotionValue(0) + const scrollbarLeft = useMotionValue(0) + const scrollbarRight = useMotionValue(0) + const scrollbarBottom = useMotionValue(0) - const updateScroll = useCallback(() => { - if (!scrollRef.current) return - const { scrollLeft, scrollWidth, clientWidth } = scrollRef.current - if (clientWidth >= scrollWidth) { - scrollbarLeft.set(0) - scrollbarRight.set(0) - scrollbarBottom.set(-5) - return - } - const leftP = scrollLeft / scrollWidth - const rightP = (scrollLeft + clientWidth) / scrollWidth - scrollbarLeft.set(leftP * clientWidth) - scrollbarRight.set((rightP - leftP) * clientWidth) - scrollbarBottom.set(0) - }, [scrollbarLeft, scrollbarRight, scrollbarBottom.set]) + const updateScroll = useCallback(() => { + if (!scrollRef.current) return + const { scrollLeft, scrollWidth, clientWidth 
} = scrollRef.current + if (clientWidth >= scrollWidth) { + scrollbarLeft.set(0) + scrollbarRight.set(0) + scrollbarBottom.set(-5) + return + } + const leftP = scrollLeft / scrollWidth + const rightP = (scrollLeft + clientWidth) / scrollWidth + scrollbarLeft.set(leftP * clientWidth) + scrollbarRight.set((rightP - leftP) * clientWidth) + scrollbarBottom.set(0) + }, [scrollbarLeft, scrollbarRight, scrollbarBottom.set]) - return ( - - - + return ( + + + - { - if (!elem) return - scrollRef.current = elem - const ro = new ResizeObserver(updateScroll) - ro.observe(elem) - return () => ro.disconnect() - }} - onScroll={updateScroll} - > - - {imageItems.map((image) => ( - selectImage(image)} - isPinned={image.pin != null} - /> - ))} - - - - ) + { + if (!elem) return + scrollRef.current = elem + const ro = new ResizeObserver(updateScroll) + ro.observe(elem) + return () => ro.disconnect() + }} + onScroll={updateScroll} + > + + {imageItems.map((image) => ( + selectImage(image)} + isPinned={image.pin != null} + /> + ))} + + + + ) } export default History const HistoryContainer = chakra( - "div", - { - base: { - position: "relative", - // marginBottom: "-4px", - flex: "0 0 auto", - transition: "transform 0.1s ease-in-out", - marginTop: "-1.5rem", - height: "4rem", - "&:hover > div.history-scrollbar": { - height: "3px", - }, - "& > div.history-scrollbar": { - height: "0px", - transition: "height 0.2s", - }, - }, - }, - { defaultProps: { className: "group" } }, + "div", + { + base: { + position: "relative", + // marginBottom: "-4px", + flex: "0 0 auto", + transition: "transform 0.1s ease-in-out", + marginTop: "-1.5rem", + height: "4rem", + "&:hover > div.history-scrollbar": { + height: "3px", + }, + "& > div.history-scrollbar": { + height: "0px", + transition: "height 0.2s", + }, + }, + }, + { defaultProps: { className: "group" } }, ) const HistoryScrollContainer = chakra("div", { - base: { - overflowX: "auto", - overflowY: "clip", - height: "100%", - "&::-webkit-scrollbar": { 
display: "none" }, - }, + base: { + overflowX: "auto", + overflowY: "clip", + height: "100%", + "&::-webkit-scrollbar": { display: "none" }, + }, }) const HistoryContent = chakra("div", { - base: { - display: "flex", - flexDirection: "row", - gap: "-1px", - overflow: "visible", - position: "relative", - transform: "translateY(1.5rem)", - transition: "transform 0.15s ease", - _groupHover: { transform: "translateY(1rem)" }, - }, + base: { + display: "flex", + flexDirection: "row", + gap: "-1px", + overflow: "visible", + position: "relative", + transform: "translateY(1.5rem)", + transition: "transform 0.15s ease", + _groupHover: { transform: "translateY(1rem)" }, + }, }) diff --git a/src/metadata/history/HistoryItem.tsx b/src/metadata/history/HistoryItem.tsx index a0a976c..1a9b8b8 100644 --- a/src/metadata/history/HistoryItem.tsx +++ b/src/metadata/history/HistoryItem.tsx @@ -1,148 +1,147 @@ import { type BoxProps, chakra } from "@chakra-ui/react" import { motion } from "motion/react" import { useEffect, useRef } from "react" -import type { getMetadataStore } from "../state/store" +import type { getMetadataStore } from "../state/metadataStore" interface HistoryItemProps extends BoxProps { - image: ReadonlyState["images"][number]> - isSelected: boolean - onSelect?: () => void - isPinned?: boolean + image: ReadonlyState["images"][number]> + isSelected: boolean + onSelect?: () => void + isPinned?: boolean } function HistoryItem(props: HistoryItemProps) { - const { image, isSelected, onSelect, isPinned, ...restProps } = props - const ref = useRef(null) + const { image, isSelected, onSelect, isPinned, ...restProps } = props + const ref = useRef(null) - useEffect(() => { - if (ref.current?.parentElement?.parentElement && isSelected) { - const item = ref.current - const scrollContainer = ref.current.parentElement.parentElement - const xMin = item.offsetLeft - const xMax = xMin + item.offsetWidth - const scrollMin = scrollContainer.scrollLeft - const scrollMax = scrollMin + 
scrollContainer.offsetWidth - if (xMin < scrollMin) { - scrollContainer.scrollTo({ left: xMin }) - } - else if (xMax > scrollMax) { - scrollContainer.scrollTo({ left: xMax - scrollContainer.offsetWidth }) - } - } - }, [isSelected]) + useEffect(() => { + if (ref.current?.parentElement?.parentElement && isSelected) { + const item = ref.current + const scrollContainer = ref.current.parentElement.parentElement + const xMin = item.offsetLeft + const xMax = xMin + item.offsetWidth + const scrollMin = scrollContainer.scrollLeft + const scrollMax = scrollMin + scrollContainer.offsetWidth + if (xMin < scrollMin) { + scrollContainer.scrollTo({ left: xMin }) + } else if (xMax > scrollMax) { + scrollContainer.scrollTo({ left: xMax - scrollContainer.offsetWidth }) + } + } + }, [isSelected]) - return ( - { - e.stopPropagation() - e.preventDefault() - onSelect?.() - }} - isPinned={isPinned} - isSelected={isSelected} - {...restProps} - > - - - - ) + return ( + { + e.stopPropagation() + e.preventDefault() + onSelect?.() + }} + isPinned={isPinned} + isSelected={isSelected} + {...restProps} + > + + + + ) } const HistoryItemBase = chakra("div", { - base: { - display: "flex", - flexDirection: "column", - height: "4rem", - width: "4rem", - flex: "0 0 4rem", - padding: "0px", - overflow: "hidden", - border: "1px solid", - borderColor: "gray.700/50", - marginInline: "-0.5px", - backgroundColor: "var(--chakra-colors-gray-700)", - marginTop: "0px", - transformOrigin: "top", - borderRadius: "0% 0% 0 0", - zIndex: 0, - transform: "scale(1) translateY(5px)", - transition: "all 0.2s ease", - _hover: { - borderRadius: "10% 10% 0 0", - zIndex: 2, - transform: "scale(1.2) translateY(-2px)", - }, - }, - variants: { - isSelected: { - true: { - marginTop: "-3px", - borderRadius: "10% 10% 0 0", - borderTop: 0, - zIndex: 1, - transform: "scale(1.1) translateY(2px)", - }, - }, - isPinned: { - true: { - marginTop: "-3px", - borderTop: 0, - }, - }, - }, + base: { + display: "flex", + flexDirection: 
"column", + height: "4rem", + width: "4rem", + flex: "0 0 4rem", + padding: "0px", + overflow: "hidden", + border: "1px solid", + borderColor: "gray.700/50", + marginInline: "-0.5px", + backgroundColor: "var(--chakra-colors-gray-700)", + marginTop: "0px", + transformOrigin: "top", + borderRadius: "0% 0% 0 0", + zIndex: 0, + transform: "scale(1) translateY(5px)", + transition: "all 0.2s ease", + _hover: { + borderRadius: "10% 10% 0 0", + zIndex: 2, + transform: "scale(1.2) translateY(-2px)", + }, + }, + variants: { + isSelected: { + true: { + marginTop: "-3px", + borderRadius: "10% 10% 0 0", + borderTop: 0, + zIndex: 1, + transform: "scale(1.1) translateY(2px)", + }, + }, + isPinned: { + true: { + marginTop: "-3px", + borderTop: 0, + }, + }, + }, }) const HistoryItemIndicator = chakra("div", { - base: { - width: "100%", - borderRadius: "10% 10% 0 0", - zIndex: 2, - height: 0, - flex: "0 0 auto", - transition: "all 0.2s ease", - }, - variants: { - isSelected: { - true: { - backgroundColor: "var(--chakra-colors-highlight)", - height: "3px", - }, - }, - isPinned: { - true: { - backgroundColor: "var(--chakra-colors-info)", - height: "3px", - }, - }, - }, - compoundVariants: [ - { - isPinned: true, - isSelected: true, - css: { - backgroundColor: "var(--chakra-colors-highlight)", - }, - }, - ], + base: { + width: "100%", + borderRadius: "10% 10% 0 0", + zIndex: 2, + height: 0, + flex: "0 0 auto", + transition: "all 0.2s ease", + }, + variants: { + isSelected: { + true: { + backgroundColor: "var(--chakra-colors-highlight)", + height: "3px", + }, + }, + isPinned: { + true: { + backgroundColor: "var(--chakra-colors-info)", + height: "3px", + }, + }, + }, + compoundVariants: [ + { + isPinned: true, + isSelected: true, + css: { + backgroundColor: "var(--chakra-colors-highlight)", + }, + }, + ], }) export default HistoryItem diff --git a/src/metadata/infoPanel/tabs.tsx b/src/metadata/infoPanel/tabs.tsx index 2afdd81..2bffe31 100644 --- a/src/metadata/infoPanel/tabs.tsx +++ 
b/src/metadata/infoPanel/tabs.tsx @@ -53,10 +53,9 @@ const Indicator = chakra(TabsIndicator, { const Content = chakra(TabsContent, { base: { bgColor: "bg.2", - // height: "100%", padding: 1, + paddingTop: "0.25rem !important", flex: "1 1 auto", - paddingTop: "0.5rem !important", }, }) diff --git a/src/metadata/state/ImageItem.ts b/src/metadata/state/ImageItem.ts index eb879a8..2a29c89 100644 --- a/src/metadata/state/ImageItem.ts +++ b/src/metadata/state/ImageItem.ts @@ -1,119 +1,119 @@ import type { DrawThingsMetaData, ImageSource } from "@/types" import ImageStore, { type ImageStoreEntry } from "@/utils/imageStore" import { getDrawThingsDataFromExif } from "../helpers" -import { type ExifType, getExif } from "./store" +import { type ExifType, getExif } from "./metadataStore" export type ImageItemConstructorOpts = { - id: string - pin?: number | null - loadedAt: number - source: ImageSource - type: string - exif?: ExifType | null - dtData?: DrawThingsMetaData | null - entry?: ImageStoreEntry + id: string + pin?: number | null + loadedAt: number + source: ImageSource + type: string + exif?: ExifType | null + dtData?: DrawThingsMetaData | null + entry?: ImageStoreEntry } export class ImageItem { - id: string - pin?: number | null - loadedAt: number - source: ImageSource - type: string - - private _exif?: ExifType | null - private _dtData?: DrawThingsMetaData | null - private _exifStatus?: "pending" | "done" - private _entry?: ImageStoreEntry - private _entryStatus?: "pending" | "done" | "error" - - constructor(opts: ImageItemConstructorOpts) { - if (!opts.id) throw new Error("ImageItem must have an id") - if (!opts.source) throw new Error("ImageItem must have a source") - if (!opts.type) throw new Error("ImageItem must have a type") - this.id = opts.id - this.source = opts.source - this.type = opts.type - this.pin = opts.pin - this.loadedAt = opts.loadedAt - - if (opts.exif) { - this._exif = opts.exif - this._dtData = opts.dtData - this._exifStatus = "done" - } - - 
if (opts.entry) { - this._entry = opts.entry - this._entryStatus = "done" - } - } - - get exif() { - if (!this._exif && !this._exifStatus) this.loadExif() - - return this._exif - } - - get dtData() { - // return undefined - if (!this._dtData && !this._exifStatus && !this.exif) this.loadExif() - - return this._dtData - } - - async loadExif() { - if (this._exifStatus) return - this._exifStatus = "pending" - - if (!this._entry) await this.loadEntry() - if (!this._entry?.url) return - - try { - const exif = await getExif(this._entry.url) - this._exif = exif - this._dtData = getDrawThingsDataFromExif(exif) ?? null - } catch (e) { - console.warn("couldn't load exif from ", this._entry.url, e) - } finally { - this._exifStatus = "done" - } - } - - get thumbUrl() { - if (!this._entry?.thumbUrl && !this._entryStatus) this.loadEntry() - return this._entry?.thumbUrl - } - - get url() { - if (!this._entry?.url && !this._entryStatus) this.loadEntry() - return this._entry?.url - } - - async loadEntry() { - if (this._entryStatus) return - this._entryStatus = "pending" - - for (let i = 0; i < 3; i++) { - const entry = await ImageStore.get(this.id) - if (entry) { - this._entry = entry - this._entryStatus = "done" - return - } - await new Promise((resolve) => setTimeout(resolve, 500)) - } - - this._entryStatus = "error" - } - - toJSON() { - return { - id: this.id, - source: this.source, - pin: this.pin, - loadedAt: this.loadedAt, - type: this.type, - } - } + id: string + pin?: number | null + loadedAt: number + source: ImageSource + type: string + + private _exif?: ExifType | null + private _dtData?: DrawThingsMetaData | null + private _exifStatus?: "pending" | "done" + private _entry?: ImageStoreEntry + private _entryStatus?: "pending" | "done" | "error" + + constructor(opts: ImageItemConstructorOpts) { + if (!opts.id) throw new Error("ImageItem must have an id") + if (!opts.source) throw new Error("ImageItem must have a source") + if (!opts.type) throw new Error("ImageItem must 
have a type") + this.id = opts.id + this.source = opts.source + this.type = opts.type + this.pin = opts.pin + this.loadedAt = opts.loadedAt + + if (opts.exif) { + this._exif = opts.exif + this._dtData = opts.dtData + this._exifStatus = "done" + } + + if (opts.entry) { + this._entry = opts.entry + this._entryStatus = "done" + } + } + + get exif() { + if (!this._exif && !this._exifStatus) this.loadExif() + + return this._exif + } + + get dtData() { + // return undefined + if (!this._dtData && !this._exifStatus && !this.exif) this.loadExif() + + return this._dtData + } + + async loadExif() { + if (this._exifStatus) return + this._exifStatus = "pending" + + if (!this._entry) await this.loadEntry() + if (!this._entry?.url) return + + try { + const exif = await getExif(this._entry.url) + this._exif = exif + this._dtData = getDrawThingsDataFromExif(exif) ?? null + } catch (e) { + console.warn("couldn't load exif from ", this._entry.url, e) + } finally { + this._exifStatus = "done" + } + } + + get thumbUrl() { + if (!this._entry?.thumbUrl && !this._entryStatus) this.loadEntry() + return this._entry?.thumbUrl + } + + get url() { + if (!this._entry?.url && !this._entryStatus) this.loadEntry() + return this._entry?.url + } + + async loadEntry() { + if (this._entryStatus) return + this._entryStatus = "pending" + + for (let i = 0; i < 3; i++) { + const entry = await ImageStore.get(this.id) + if (entry) { + this._entry = entry + this._entryStatus = "done" + return + } + await new Promise((resolve) => setTimeout(resolve, 500)) + } + + this._entryStatus = "error" + } + + toJSON() { + return { + id: this.id, + source: this.source, + pin: this.pin, + loadedAt: this.loadedAt, + type: this.type, + } + } } diff --git a/src/metadata/state/context.tsx b/src/metadata/state/context.tsx index 1eea8c5..3bcf801 100644 --- a/src/metadata/state/context.tsx +++ b/src/metadata/state/context.tsx @@ -1,33 +1,33 @@ -import { createContext, useContext, type PropsWithChildren } from "react" +import { 
createContext, type PropsWithChildren, useContext } from "react" import type { ImageSource } from "@/types" import type { ImageItem } from "./ImageItem" -import type { ExifType, ImageItemParam, getMetadataStore } from "./store" +import type { ExifType, getMetadataStore, ImageItemParam } from "./metadataStore" export type MetadataStoreContextType = { - state: ReturnType - selectImage(image?: ImageItemParam | null): void - pinImage(image: ImageItemParam, value: number | boolean | null): void - pinImage(useCurrent: true, value: number | boolean | null): void - clearAll(keepTabs: boolean): Promise - clearCurrent(): Promise - createImageItem( - imageData: Uint8Array, - type: string, - source: ImageSource, - ): Promise - getExif(imagePath: string): Promise - getExif(imageDataBuffer: ArrayBuffer): Promise - initialized: boolean + state: ReturnType + selectImage(image?: ImageItemParam | null): void + pinImage(image: ImageItemParam, value: number | boolean | null): void + pinImage(useCurrent: true, value: number | boolean | null): void + clearAll(keepTabs: boolean): Promise + clearCurrent(): Promise + createImageItem( + imageData: Uint8Array, + type: string, + source: ImageSource, + ): Promise + getExif(imagePath: string): Promise + getExif(imageDataBuffer: ArrayBuffer): Promise + initialized: boolean } const MetadataStoreContext = createContext>({ - initialized: false, + initialized: false, }) export function useMetadataStore() { - const context = useContext(MetadataStoreContext) - if (!context) throw new Error("useMetadataStore must be used within a MetadataStoreProvider") - return context + const context = useContext(MetadataStoreContext) + if (!context) throw new Error("useMetadataStore must be used within a MetadataStoreProvider") + return context } export function MetadataStoreProvider(props: PropsWithChildren) {} diff --git a/src/metadata/state/hooks.ts b/src/metadata/state/hooks.ts index 33a071c..99d85db 100644 --- a/src/metadata/state/hooks.ts +++ 
b/src/metadata/state/hooks.ts @@ -1,10 +1,8 @@ -import { useSnapshot } from 'valtio' -import type { ImageItem } from './ImageItem' -import { getMetadataStore } from './store' - - +import { useSnapshot } from "valtio" +import type { ImageItem } from "./ImageItem" +import { getMetadataStore } from "./metadataStore" export function useCurrentImage(): ReadonlyState | undefined { - const snap = useSnapshot(getMetadataStore()) - return snap.currentImage -} \ No newline at end of file + const snap = useSnapshot(getMetadataStore()) + return snap.currentImage +} diff --git a/src/metadata/state/imageLoaders.ts b/src/metadata/state/imageLoaders.ts index 73af9a7..d37c379 100644 --- a/src/metadata/state/imageLoaders.ts +++ b/src/metadata/state/imageLoaders.ts @@ -1,5 +1,6 @@ import * as pathlib from "@tauri-apps/api/path" import plist from "plist" +import { DtpService } from "@/commands" import { postMessage } from "@/context/Messages" import { fetchImage, @@ -12,7 +13,7 @@ import { settledValues } from "@/utils/helpers" import { drawPose } from "@/utils/pose" import { isOpenPose } from "@/utils/poseHelpers" import type { ImageItem } from "./ImageItem" -import { createImageItem } from "./store" +import { createImageItem } from "./metadataStore" const prioritizedTypes = [ "NSFilenamesPboardType", @@ -50,13 +51,15 @@ export async function loadImage2(pasteboard: "general" | "drag") { const data = await getType(type) if (!data) continue - + console.log("loadimage", data) if (isPose(type, data as string)) { return createImageFromPose(data as string) } if (typeof data === "string") { + console.log("trying to load from text") const images = await tryLoadText(data, type, source, checked) + console.log(images.length) if (images.length > 1) return true if (images.length === 1 && images[0].dtData) return true if (type === "NSFilenamesPboardType") { @@ -82,7 +85,19 @@ async function tryLoadText( source: "clipboard" | "drop", excludeMut: string[] = [], ): Promise { - const { files, urls } = 
parseText(text, type) + const { files, urls, dtpImage } = parseText(text, type) + + if (dtpImage) { + const dtpResult = await loadDtpImage(dtpImage) + if (dtpResult) { + const item = await createImageItem(dtpResult.image, "png", { + source, + projectFile: dtpResult.projectFile, + }) + if (item) return [item] + } + } + const items = [] as Parameters[] for (const file of files) { @@ -137,7 +152,7 @@ async function tryLoadText( if (!items.length) return [] - return settledValues(items.map((item) => createImageItem(...item))) + return await settledValues(items.map((item) => createImageItem(...item))) } export function parseText(value: string, type: string) { @@ -239,10 +254,21 @@ export function getLocalPath(path: string) { return null } -function extractPaths(text: string): { files: string[]; urls: string[] } { +function extractPaths(text: string): { + files: string[] + urls: string[] + dtpImage?: { projectId: number; imageId: number } +} { const files: string[] = [] const urls: string[] = [] + const dtpImageRegex = /^dtm:\/\/dtproject\/thumb(?:half)?\/(\d+)\/(\d+)/gm + const dtpMatch = dtpImageRegex.exec(text) + if (dtpMatch) { + const dtpImage = { projectId: Number(dtpMatch[1]), imageId: Number(dtpMatch[2]) } + return { files, urls, dtpImage } + } + // Regex for detecting quoted or unquoted chunks (handles spaces inside quotes) const chunkRegex = /'([^']+)'|"([^"]+)"|(\S+)/g @@ -288,3 +314,17 @@ async function createImageFromPose(text: string) { const image = await drawPose(JSON.parse(text)) if (image) await createImageItem(image, "png", { source: "clipboard" }) } + +async function loadDtpImage(dtpImage: { projectId: number; imageId: number }) { + const imageItem = await DtpService.findImageFromPreviewId(dtpImage.projectId, dtpImage.imageId) + if (!imageItem) return + const history = await DtpService.getHistoryFull(imageItem.project_id, imageItem.node_id) + if (!history || !history.tensor_id) return + const image = await DtpService.decodeTensor( + 
imageItem.project_id, + history.tensor_id, + true, + imageItem.node_id, + ) + return { image, projectFile: history.project_path } +} diff --git a/src/metadata/state/interop.ts b/src/metadata/state/interop.ts index 2468ad5..0e79a2b 100644 --- a/src/metadata/state/interop.ts +++ b/src/metadata/state/interop.ts @@ -1,32 +1,40 @@ import AppStore from "@/hooks/appState" import type { ImageSource } from "@/types" import type { ImageItem } from "./ImageItem" -import { createImageItem, getMetadataStore, selectImage } from './store' +import { loadImage2 } from "./imageLoaders" +import { createImageItem, getMetadataStore, selectImage } from "./metadataStore" export async function sendToMetadata( - imageData: Uint8Array, - type: string, - source: ImageSource, + imageData: Uint8Array, + type: string, + source: ImageSource, ) { - // check if the item already has been sent to the store - const state = getMetadataStore() - let imageItem = state.images.find((im) => - compareImageSource(im.source, source), - ) as Nullable - imageItem ??= await createImageItem(imageData, type, source) + // check if the item already has been sent to the store + const state = getMetadataStore() + let imageItem = state.images.find((im) => + compareImageSource(im.source, source), + ) as Nullable + imageItem ??= await createImageItem(imageData, type, source) - if (imageItem) { - selectImage(imageItem) - await AppStore.setView("metadata") - } + if (imageItem) { + selectImage(imageItem) + AppStore.setView("metadata") + } +} + +export function handleDrop(data: unknown) { + if (data === "drag") { + AppStore.setView("metadata") + loadImage2("drag") + } } function compareImageSource(a: ImageSource, b: ImageSource) { - if (a.source !== b.source) return false - if (a.file !== b.file) return false - if (a.url !== b.url) return false - if (a.projectFile !== b.projectFile) return false - if (a.tensorId !== b.tensorId) return false - if (a.nodeId !== b.nodeId) return false - return true + if (a.source !== b.source) 
return false + if (a.file !== b.file) return false + if (a.url !== b.url) return false + if (a.projectFile !== b.projectFile) return false + if (a.tensorId !== b.tensorId) return false + if (a.nodeId !== b.nodeId) return false + return true } diff --git a/src/metadata/state/store.ts b/src/metadata/state/metadataStore.ts similarity index 91% rename from src/metadata/state/store.ts rename to src/metadata/state/metadataStore.ts index 54fc4ab..9ee550d 100644 --- a/src/metadata/state/store.ts +++ b/src/metadata/state/metadataStore.ts @@ -9,20 +9,21 @@ import ImageStore from "@/utils/imageStore" import { getDrawThingsDataFromExif } from "../helpers" import { ImageItem, type ImageItemConstructorOpts } from "./ImageItem" +console.log("METADATA IMPORTED") + export function bind(instance: T): T { const props = Object.getOwnPropertyNames(Object.getPrototypeOf(instance)) for (const prop of props) { const method = instance[prop as keyof T] if (prop === "constructor" || typeof method !== "function") continue - ; (instance as Record)[prop] = (...args: unknown[]) => - method.apply(instance, args) + ;(instance as Record)[prop] = (...args: unknown[]) => + method.apply(instance, args) } return instance } - function initStore() { const storeInstance = store( getStoreName("metadata"), @@ -79,9 +80,10 @@ let metadataStore: ReturnType | undefined function getStore() { if (!metadataStore) { + console.debug("METADATA: creating store") metadataStore = initStore() } - return metadataStore! 
+ return metadataStore } export function getMetadataStore() { @@ -102,8 +104,8 @@ async function cleanUp() { const clearHistory = AppStore.store.clearHistoryOnExit const clearPins = AppStore.store.clearPinsOnExit - const saveIds = getMetadataStore().images - .filter((im) => { + const saveIds = getMetadataStore() + .images.filter((im) => { if (im.pin != null && !clearPins) return true if (!clearHistory) return true return false @@ -158,8 +160,8 @@ export function pinImage( } function reconcilePins() { - const pins = getMetadataStore().images - .filter((im) => im.pin != null) + const pins = getMetadataStore() + .images.filter((im) => im.pin != null) .sort((a, b) => (a.pin ?? 0) - (b.pin ?? 0)) pins.forEach((im, i) => { @@ -168,7 +170,8 @@ function reconcilePins() { } export async function clearAll(keepTabs = false) { - if (keepTabs) getMetadataStore().images = getMetadataStore().images.filter((im) => im.pin != null) + if (keepTabs) + getMetadataStore().images = getMetadataStore().images.filter((im) => im.pin != null) else getMetadataStore().images = [] await syncImageStore() } @@ -190,12 +193,14 @@ export async function createImageItem( source: ImageSource, ) { console.trace("create image item") + const store = getMetadataStore() if (!imageData || !type || !source) return null if (imageData.length === 0) return null // save image to image store const entry = await ImageStore.save(imageData, type) + console.log("saved image", entry) if (!entry) return null const exif = await getExif(imageData.buffer) @@ -213,10 +218,12 @@ export async function createImageItem( } const imageItem = bind(proxy(new ImageItem(item))) - const itemIndex = getMetadataStore().images.push(imageItem) - 1 + console.log("image item", imageItem) + const itemIndex = store.images.push(imageItem) - 1 + console.log("item index", itemIndex) selectImage(itemIndex) - return getMetadataStore().images[itemIndex] + return store.images[itemIndex] } /** diff --git a/src/metadata/toolbar/Toolbar.tsx 
b/src/metadata/toolbar/Toolbar.tsx index a14701a..fa9a233 100644 --- a/src/metadata/toolbar/Toolbar.tsx +++ b/src/metadata/toolbar/Toolbar.tsx @@ -2,84 +2,93 @@ import { Box } from "@chakra-ui/react" import { AnimatePresence, LayoutGroup, motion } from "motion/react" import { useSnapshot } from "valtio" import { useMessages } from "@/context/Messages" -import { getMetadataStore } from "../state/store" +import { getMetadataStore } from "../state/metadataStore" import { toolbarCommands } from "./commands" import { ContentHeaderContainer, ToolbarButtonGroup, ToolbarContainer, ToolbarRoot } from "./parts" import ToolbarItem from "./ToolbarItem" function Toolbar(props: ChakraProps) { - const { ...restProps } = props + const { ...restProps } = props - const snap = useSnapshot(getMetadataStore()) + const snap = useSnapshot(getMetadataStore()) - const messageChannel = useMessages("toolbar") + const messageChannel = useMessages("toolbar") - // used when rendering command items - // let changedCount = 0 - // const prevState = useRef(toolbarCommands.map((item) => "hide")) + // used when rendering command items + // let changedCount = 0 + // const prevState = useRef(toolbarCommands.map((item) => "hide")) - const buttons = toolbarCommands.map((item) => { - if (item.separator) return () => null - const isVisible = item.check?.(snap) ?? true - const state = isVisible ? "show" : "hide" - // const isChanged = prevState.current[i] !== state - // prevState.current[i] = state - // const order = isChanged ? changedCount++ : undefined + const buttons = toolbarCommands.map((item) => { + if (item.separator) return () => null + const isVisible = item.check?.(snap) ?? true + const state = isVisible ? "show" : "hide" + // const isChanged = prevState.current[i] !== state + // prevState.current[i] = state + // const order = isChanged ? 
changedCount++ : undefined - let key = item.id - if (item.slotId && isVisible) key = item.slotId + let key = item.id + if (item.slotId && isVisible) key = item.slotId - return () => - }) + return () => + }) - return ( - - - - - - - {buttons.map((render) => render())} - - - - - {messageChannel.messages.map((message, i, msgs) => ( - 0 && msgs.length > 1 - ? { - content: '""', - display: "block", - height: "1px", - width: "70%", - bg: "fg.1/50", - marginX: "auto", - } - : undefined - } - overflow={"hidden"} - maxWidth={"100%"} - > - - - {message.message} - - - - ))} - - - - - ) + return ( + + + + + + + {buttons.map((render) => render())} + + + + + {messageChannel.messages.map((message, i, msgs) => ( + 0 && msgs.length > 1 + ? { + content: '""', + display: "block", + height: "1px", + width: "70%", + bg: "fg.1/50", + marginX: "auto", + } + : undefined + } + overflow={"hidden"} + maxWidth={"100%"} + > + + + {message.message} + + + + ))} + + + + + ) } export default Toolbar diff --git a/src/metadata/toolbar/ToolbarItem.tsx b/src/metadata/toolbar/ToolbarItem.tsx index b31cfa8..752810d 100644 --- a/src/metadata/toolbar/ToolbarItem.tsx +++ b/src/metadata/toolbar/ToolbarItem.tsx @@ -1,76 +1,82 @@ import { useSnapshot } from "valtio" import { MotionBox } from "@/components/common" -import { getMetadataStore } from "../state/store" +import { getMetadataStore } from "../state/metadataStore" import type { ToolbarCommand } from "./commands" import ToolbarButton from "./ToolbarButton" const separatorProps: ChakraProps["_before"] = { - content: '""', - width: "1px", - height: "1rem", - bgColor: "fg.2/20", - alignSelf: "center", - marginInline: "-1px", + content: '""', + width: "1px", + height: "1rem", + bgColor: "fg.2/20", + alignSelf: "center", + marginInline: "-1px", } interface ToolbarItemProps { - command: ToolbarCommand> - showSeparator?: boolean - state: "hide" | "show" + command: ToolbarCommand> + showSeparator?: boolean + state: "hide" | "show" } export function 
ToolbarItem(props: ToolbarItemProps) { - const { command, showSeparator, state } = props - const snap = useSnapshot(getMetadataStore()) as ReadonlyState> + const { command, showSeparator, state } = props + const snap = useSnapshot(getMetadataStore()) as ReadonlyState< + ReturnType + > - const tip = command.tip ?? command.getTip?.(snap) - const Icon = command.icon - const content = Icon ? : command.getIcon?.(snap) + const tip = command.tip ?? command.getTip?.(snap) + const Icon = command.icon + const content = Icon ? : command.getIcon?.(snap) - // const hDelay = 0.5 * (order ?? 0) - // const vDelay = 0.5 * (changedCount ?? 0) + // const hDelay = 0.5 * (order ?? 0) + // const vDelay = 0.5 * (changedCount ?? 0) - return ( - - command.action(getMetadataStore())}> - {content} - - - ) + return ( + + command.action(getMetadataStore())} + > + {content} + + + ) } export default ToolbarItem diff --git a/src/metadata/toolbar/commands.tsx b/src/metadata/toolbar/commands.tsx index ebd5a4d..8356b68 100644 --- a/src/metadata/toolbar/commands.tsx +++ b/src/metadata/toolbar/commands.tsx @@ -12,7 +12,7 @@ import { import { postMessage } from "@/context/Messages" import ImageStore from "@/utils/imageStore" import { loadImage2 } from "../state/imageLoaders" -import { clearAll, getMetadataStore, pinImage } from "../state/store" +import { clearAll, getMetadataStore, pinImage } from "../state/metadataStore" import PinnedIcon from "./PinnedIcon" let separatorId = 0 diff --git a/src/scratch/DTPTest.tsx b/src/scratch/DTPTest.tsx new file mode 100644 index 0000000..899dee7 --- /dev/null +++ b/src/scratch/DTPTest.tsx @@ -0,0 +1,69 @@ +import { Button, Grid, Text, VStack } from "@chakra-ui/react" +import { Channel, invoke } from "@tauri-apps/api/core" +import { useRef } from "react" +import { CheckRoot, Panel } from "@/components" +import type { ProjectExtra } from "@/generated/types" +import { useProxyRef } from "@/hooks/valtioHooks" + +function Empty() { + const channel = useRef(null) + + 
const { state, snap } = useProxyRef(() => ({ + events: [] as unknown[], + projects: [] as ProjectExtra[], + })) + + return ( + + + + + + {snap.events.map((event, index) => ( + {JSON.stringify(event)} + ))} + + + + {snap.projects.map((project, index) => ( + {project.name} + ))} + + + + + + ) +} + +export default Empty diff --git a/src/utils/config.ts b/src/utils/config.ts index dec29eb..2c1eb52 100644 --- a/src/utils/config.ts +++ b/src/utils/config.ts @@ -1,4 +1,4 @@ -import type { TensorHistoryNode } from "@/generated/types" +import type { XTensorHistoryNode as TensorHistoryNode } from "@/commands" import { type DrawThingsConfigGrouped, type DrawThingsMetaData, SeedModeLabels } from "@/types" export function extractConfigFromTensorHistoryNode( @@ -16,7 +16,7 @@ export function extractConfigFromTensorHistoryNode( clipLText: node.clip_l_text ?? "", clipSkip: node.clip_skip, clipWeight: node.clip_weight, - controls: node.controls ?? [], + controls: node.controls as DrawThingsMetaData["config"]["controls"]?? 
[], cropLeft: node.crop_left, cropTop: node.crop_top, decodingTileHeight: node.decoding_tile_height * 64, diff --git a/src/utils/container/container.ts b/src/utils/container/container.ts index ffe3427..07d3d5f 100644 --- a/src/utils/container/container.ts +++ b/src/utils/container/container.ts @@ -1,3 +1,4 @@ +import type { Channel } from "@tauri-apps/api/core" import { listen } from "@tauri-apps/api/event" import EventEmitter from "eventemitter3" import { type EventMap, type IContainer, type IStateService, isDisposable } from "./interfaces" @@ -8,7 +9,7 @@ type FutureServices = Record> = type TagHandler = (tag: string, data?: Record) => void type TagFormatter = (tag: string, data?: Record) => string -type TagService = { formatTags: TagFormatter, handleTags: TagHandler } +type TagService = { formatTags: TagFormatter; handleTags: TagHandler } export class Container< T extends { [K in keyof T]: IStateService> } = object, @@ -23,8 +24,9 @@ export class Container< private invalidateUnlistenPromise: Promise<() => void> private updateUnlistenPromise: Promise<() => void> private tagHandlers: Map = new Map() + private channel?: Channel<{ type: string; data: unknown }> - constructor(servicesInit: () => T) { + constructor(channel: Channel<{ type: string; data: unknown }>, servicesInit: () => T) { super() buildContainer>( @@ -49,6 +51,13 @@ export class Container< const { tag, data } = event.payload as { tag: string; data: Record } this.handleTags(tag, data) }) + + this.channel = channel + this.channel.onmessage = (event) => { + const eventType = event.type as EventEmitter.EventNames + const data = [event.data] as EventEmitter.EventArgs + this.emit(eventType, ...data) + } } getService(name: K): T[K] { diff --git a/src/utils/helpers.test.ts b/src/utils/helpers.test.ts index 549ea12..a2a966a 100644 --- a/src/utils/helpers.test.ts +++ b/src/utils/helpers.test.ts @@ -1,5 +1,5 @@ import { describe, expect, it } from "vitest" -import { compareItems, plural } from "./helpers" 
+import { compareItems, groupMap, plural } from "./helpers" describe("compareItems", () => { const keyFn = (item: { id: number }) => item.id @@ -118,3 +118,62 @@ describe("plural", () => { expect(plural(7, "child", "children")).toBe("children") }) }) + +describe("groupMap", () => { + it("should group items by key using default groupFn", () => { + const items = [ + { id: 1, category: "A" }, + { id: 2, category: "B" }, + { id: 3, category: "A" }, + ] + const result = groupMap(items, (item) => [item.category, item]) + expect(result).toEqual([ + { group: "A", items: [items[0], items[2]] }, + { group: "B", items: [items[1]] }, + ]) + }) + + it("should map values using itemFn", () => { + const items = [ + { id: 1, category: "A" }, + { id: 2, category: "B" }, + { id: 3, category: "A" }, + ] + const result = groupMap(items, (item) => [item.category, item.id]) + expect(result).toEqual([ + { group: "A", items: [1, 3] }, + { group: "B", items: [2] }, + ]) + }) + + it("should use custom groupFn", () => { + const items = [ + { id: 1, category: "A" }, + { id: 2, category: "B" }, + { id: 3, category: "A" }, + ] + const result = groupMap( + items, + (item) => [item.category, item.id], + (key, values) => ({ cat: key, ids: values, count: values.length }), + ) + expect(result).toEqual([ + { cat: "A", ids: [1, 3], count: 2 }, + { cat: "B", ids: [2], count: 1 }, + ]) + }) + + it("should handle empty arrays", () => { + const result = groupMap([] as unknown[], (item) => ["key", item]) + expect(result).toEqual([]) + }) + + it("should provide index and array to itemFn", () => { + const items = ["a", "b", "c"] + const result = groupMap(items, (item, index) => [index % 2 === 0 ? 
"even" : "odd", item]) + expect(result).toEqual([ + { group: "even", items: ["a", "c"] }, + { group: "odd", items: ["b"] }, + ]) + }) +}) diff --git a/src/utils/helpers.ts b/src/utils/helpers.ts index 9858ea8..152deeb 100644 --- a/src/utils/helpers.ts +++ b/src/utils/helpers.ts @@ -60,8 +60,8 @@ export function shuffle(array: T[]): T[] { // Pick a random index const i = Math.floor(Math.random() * m--) - // Swap element at m with element at i - ;[array[m], array[i]] = [array[i], array[m]] + // Swap element at m with element at i + ;[array[m], array[i]] = [array[i], array[m]] } return array @@ -107,7 +107,6 @@ export async function openAnd( callback: SingleOpenAndCallback | MultiOpenAndCallback, options: Parameters[0] = {}, ) { - const files = await pickFileForImport(options) if (!files || (Array.isArray(files) && files.length === 0)) return null @@ -188,7 +187,7 @@ export interface CompareOptions { * @param opts Options for comparison * @returns An object containing the added, removed, and changed items */ -export function compareItems>( +export function compareItems( a: T[], b: T[], keyFn: (item: T) => string | number, @@ -233,7 +232,6 @@ function shallowCompare>(a: T, b: T, opts: Com if (ignoreObjects && typeof valA === "object" && valA !== null) continue if (ignoreFunctions && typeof valA === "function") continue if (valA !== b[key]) { - console.log("diff", key, valA, b[key]) return false } } @@ -246,15 +244,57 @@ export function everyNth(arr: T[], n: number): T[] { export async function pickFileForImport(options?: Parameters[0]) { const e2eFilePath = (window as any).__E2E_FILE_PATH__ - console.debug("E2E file path:", e2eFilePath); + console.debug("E2E file path:", e2eFilePath) if (e2eFilePath) { - return e2eFilePath; + return e2eFilePath } - return await open(options); + return await open(options) } export function truncate(text: string, length: number) { if (text.length <= length) return text return `${text.slice(0, length)}...` } + +type GroupMapItemFn = 
(item: TIn, index: number, arr: TIn[]) => [TKey, TOut] +type GroupMapGroupFn = (key: TKey, items: TOut[]) => TGroup + +const defaultGroupFn = (key: unknown, items: unknown[]) => ({ group: key, items }) + +/** + * + * @param items Maps the item to a group key and (mapped) item value + * @param itemFn + */ +export function groupMap( + items: TIn[], + itemFn: GroupMapItemFn, +): { group: TKey; items: TOut[] }[] +export function groupMap( + items: TIn[], + itemFn: GroupMapItemFn, + groupFn: GroupMapGroupFn, +): TGroup[] +export function groupMap( + items: TIn[], + itemFn: GroupMapItemFn, + groupFn: GroupMapGroupFn = defaultGroupFn as unknown as GroupMapGroupFn< + TKey, + TOut, + TGroup + >, +): TGroup[] { + const map = new Map() + for (let i = 0; i < items.length; i++) { + const item = items[i] + const [key, value] = itemFn(item, i, items) + const existing = map.get(key) + if (existing) { + existing.push(value) + } else { + map.set(key, [value]) + } + } + return Array.from(map.entries()).map(([key, value]) => groupFn(key, value)) +} diff --git a/src/utils/imageStore.ts b/src/utils/imageStore.ts index 8af391c..c5d04f0 100644 --- a/src/utils/imageStore.ts +++ b/src/utils/imageStore.ts @@ -9,157 +9,172 @@ const nanoid = customAlphabet("0123456789abcdefghijklmnopqrstuvwxyz", 12) const appDataDir = await path.appDataDir() if (!(await fs.exists(appDataDir))) { - await fs.mkdir(appDataDir) + await fs.mkdir(appDataDir) } const imageFolder = await path.join(appDataDir, getStoreName("images")) if (!(await fs.exists(imageFolder))) { - await fs.mkdir(imageFolder) + await fs.mkdir(imageFolder) } type ImageStoreEntryBase = { - id: string - type: string + id: string + type: string } export type ImageStoreEntry = { - id: string - type: string - url: string - thumbUrl: string -} - -const imagesStore = createStore( - getStoreName("images"), - { images: {} as Record }, - { - autoStart: true, - syncStrategy: "debounce", - syncInterval: 1000, - saveOnChange: true, - hooks: { - 
beforeFrontendSync: (state) => { - console.log("fe sync") - return state - }, - }, - }, -) -window.addEventListener("unload", () => imagesStore.stop()) - -const store = imagesStore.state -const _validTypes = ["png", "tiff", "jpg", "webp"] + id: string + type: string + url: string + thumbUrl: string +} -async function saveImage(image: Uint8Array, type: string): Promise { - if (!type || !_validTypes.includes(type)) return - if (!image || image.length === 0) return +function initStore() { + const storeInstance = createStore( + getStoreName("images"), + { images: {} as Record }, + { + autoStart: true, + syncStrategy: "debounce", + syncInterval: 1000, + saveOnChange: true, + hooks: { + beforeFrontendSync: (state) => { + console.log("fe sync") + return state + }, + }, + }, + ) + window.addEventListener("unload", () => storeInstance.stop()) + return storeInstance +} - try { - const id = await getNewId() - const fname = await getFullPath(id, type) +let imagesStore: ReturnType | null = null - await fs.writeFile(fname, image, { - createNew: true, - }) +function getStore() { + if (!imagesStore) { + console.debug("IMAGES: creating store") + imagesStore = initStore() + } + return imagesStore +} - const entry = { id, type } - store.images[id] = entry +const _validTypes = ["png", "tiff", "jpg", "webp"] - const url = convertFileSrc(fname) - return { ...entry, url, thumbUrl: url } - } catch (e) { - console.error(e) - return - } +async function saveImage(image: Uint8Array, type: string): Promise { + if (!type || !_validTypes.includes(type)) return + if (!image || image.length === 0) return + + try { + const id = await getNewId() + const fname = await getFullPath(id, type) + + await fs.writeFile(fname, image, { + createNew: true, + }) + + const entry = { id, type } + getStore().state.images[id] = entry + + const url = convertFileSrc(fname) + return { ...entry, url, thumbUrl: url } + } catch (e) { + console.error(e) + return + } } async function getImage(id: string): Promise { - const 
entry = store.images[id] + const entry = getStore().state.images[id] - if (!entry) return + if (!entry) return - const url = convertFileSrc(await getFullPath(id, entry.type)) - // const thumbUrl = convertFileSrc(await getThumbPath(id)) - return { ...entry, url, thumbUrl: url } + const url = convertFileSrc(await getFullPath(id, entry.type)) + // const thumbUrl = convertFileSrc(await getThumbPath(id)) + return { ...entry, url, thumbUrl: url } } async function getFullPath(id: string, ext: string) { - return await path.join(imageFolder, `${id}.${ext}`) + return await path.join(imageFolder, `${id}.${ext}`) } async function getThumbPath(id: string) { - return await path.join(imageFolder, `${id}_thumb.png`) + return await path.join(imageFolder, `${id}_thumb.png`) } async function getNewId() { - let id: string + let id: string + const state = getStore().state - do { - id = nanoid() - } while (id in imagesStore) + do { + id = nanoid() + } while (id in state.images) - return id + return id } async function removeImage(id: string) { - const item = store.images[id] - if (!item) return - await removeFile(await getFullPath(id, item.type)) - await removeFile(await getThumbPath(id)) - delete store.images[id] + const state = getStore().state + const item = state.images[id] + if (!item) return + await removeFile(await getFullPath(id, item.type)) + await removeFile(await getThumbPath(id)) + delete state.images[id] } async function syncImages(keepIds: string[] = []) { - for (const id of Object.keys(store.images)) { - if (keepIds.includes(id)) continue - - await removeImage(id) - } - - // for (const file of await fs.readDir(imageFolder)) { - // if (file.name.startsWith(".") || file.isDirectory || file.isSymlink) continue - // console.log("looking at ", file.name) - // const filename = await path.basename(file.name, await path.extname(file.name)) - // const id = filename.split("_")[0] - - // if (!keepIds.includes(id)) { - // await removeFile(await path.join(imageFolder, file.name)) - // 
} - // } + const state = getStore().state + for (const id of Object.keys(state.images)) { + if (keepIds.includes(id)) continue + + await removeImage(id) + } + + // for (const file of await fs.readDir(imageFolder)) { + // if (file.name.startsWith(".") || file.isDirectory || file.isSymlink) continue + // console.log("looking at ", file.name) + // const filename = await path.basename(file.name, await path.extname(file.name)) + // const id = filename.split("_")[0] + + // if (!keepIds.includes(id)) { + // await removeFile(await path.join(imageFolder, file.name)) + // } + // } } async function copyImage(id: string) { - const entry = await getImage(id) - if (!entry) return - console.debug("copying image", entry.id) - const path = await getFullPath(id, entry.type) - const data = await fs.readFile(path) - await invoke("write_clipboard_binary", { ty: `public.${entry.type}`, data }) + const entry = await getImage(id) + if (!entry) return + console.debug("copying image", entry.id) + const path = await getFullPath(id, entry.type) + const data = await fs.readFile(path) + await invoke("write_clipboard_binary", { ty: `public.${entry.type}`, data }) } async function saveCopy(id: string, dest: string) { - const entry = await getImage(id) - if (!entry) return - const path = await getFullPath(id, entry.type) - await fs.copyFile(path, dest) + const entry = await getImage(id) + if (!entry) return + const path = await getFullPath(id, entry.type) + await fs.copyFile(path, dest) } const ImageStore = { - save: saveImage, - get: getImage, - remove: removeImage, - sync: syncImages, - copy: copyImage, - saveCopy: saveCopy, + save: saveImage, + get: getImage, + remove: removeImage, + sync: syncImages, + copy: copyImage, + saveCopy: saveCopy, } export default ImageStore async function removeFile(filePath: string) { - try { - if (await fs.exists(filePath)) { - await fs.remove(filePath) - } - } catch (e) { - console.error(e) - } + try { + if (await fs.exists(filePath)) { + await 
fs.remove(filePath) + } + } catch (e) { + console.error(e) + } } diff --git a/src/utils/reactDevtools.js b/src/utils/reactDevtools.js index ede4cf9..7e65619 100644 --- a/src/utils/reactDevtools.js +++ b/src/utils/reactDevtools.js @@ -1,5 +1,4 @@ /** biome-ignore-all lint: does not need linting*/ -throw new Error() if (window.location.hash === "#dev") { throw new Error() } diff --git a/src/views.ts b/src/views.ts index c9fced7..78f4946 100644 --- a/src/views.ts +++ b/src/views.ts @@ -29,7 +29,7 @@ export const views = { vid: lazy(() => import("./vid/Vid")), library: lazy(() => import("./library/Library")), projects: lazy(() => import("./dtProjects/DTProjects")), - // scratch: lazy(() => import("./scratch/Coffee")), + scratch: lazy(() => import("./scratch/DTPTest")), } // export const views = { diff --git a/test/package.json b/test/package.json index a5e774b..46745b2 100644 --- a/test/package.json +++ b/test/package.json @@ -5,7 +5,8 @@ "type": "module", "scripts": { "test": "wdio run wdio.conf.ts", - "test:dev": "REUSE_BUILD=true wdio run wdio.conf.ts" + "test:dev": "REUSE_BUILD=true wdio run wdio.conf.ts", + "dev": "../src-tauri/target/debug/dtm" }, "dependencies": { "@wdio/cli": "^9.19.0" diff --git a/test/pageobjects/App.ts b/test/pageobjects/App.ts index b7e8883..e4087d2 100644 --- a/test/pageobjects/App.ts +++ b/test/pageobjects/App.ts @@ -11,10 +11,10 @@ class App { async selectView(view: "projects" | "metadata") { if (view === "projects") { await this.projectsButton.click(); - await expect(this.projectsButton).toHaveAttribute("aria-selected", "true") + await expect(this.projectsButton).toHaveAttribute("aria-current", "page") } else if (view === "metadata") { await this.metadataButton.click(); - await expect(this.metadataButton).toHaveAttribute("aria-selected", "true") + await expect(this.metadataButton).toHaveAttribute("aria-current", "page") } } diff --git a/test/specs/example.e2e.ts b/test/specs/example.e2e.ts index 193242f..3bd090b 100644 --- 
a/test/specs/example.e2e.ts +++ b/test/specs/example.e2e.ts @@ -15,12 +15,12 @@ describe('Basic', () => { await new Promise(resolve => setTimeout(resolve, 2000)) await App.metadataButton.click(); - await expect(App.metadataButton).toHaveAttribute("aria-selected", "true") + await expect(App.metadataButton).toHaveAttribute("aria-current", "page") await expect($("div*=Drop image here")).toBeDisplayedInViewport() await App.projectsButton.click(); - await expect(App.projectsButton).toHaveAttribute("aria-selected", "true") + await expect(App.projectsButton).toHaveAttribute("aria-current", "page") await expect($("aria/Projects")).toBeDisplayedInViewport() }) diff --git a/test/specs/projects-reset.e2e.ts b/test/specs/projects-reset.e2e.ts index 534957b..7661a5c 100644 --- a/test/specs/projects-reset.e2e.ts +++ b/test/specs/projects-reset.e2e.ts @@ -15,8 +15,7 @@ beforeEach(async () => { describe('Projects', () => { it('can add a watchfolder', async () => { - await App.projectsButton.click(); - await expect(App.projectsButton).toHaveAttribute("aria-selected", "true") + await App.selectView("projects") const settingsHeader = $("p=Settings") await expect(settingsHeader).toBeDisplayedInViewport() diff --git a/test/specs/projects.e2e.ts b/test/specs/projects.e2e.ts index 6a05ba0..64020f3 100644 --- a/test/specs/projects.e2e.ts +++ b/test/specs/projects.e2e.ts @@ -20,8 +20,7 @@ afterEach(async () => { describe('Projects', () => { it('can select a project', async () => { - await App.projectsButton.click(); - await expect(App.projectsButton).toHaveAttribute("aria-selected", "true") + await App.selectView("projects") // verify projects are listed await expect(DTProjects.projectA).toBeDisplayedInViewport() @@ -72,7 +71,9 @@ describe('Projects', () => { // verify all images are shown again await expect(DTProjects.images).toBeElementsArrayOfSize(countBefore) }) +}) +describe("Projects files", () => { it("projects list stays in sync with file system", async () => { await 
App.selectView("projects")

diff --git a/test/specs/projects2.ts b/test/specs/projects2.ts
new file mode 100644
index 0000000..0520b14
--- /dev/null
+++ b/test/specs/projects2.ts
@@ -0,0 +1,14 @@
+import App from "../pageobjects/App"
+import DTProjects from "../pageobjects/DTProjects"
+import DTImageDetail from "../pageobjects/DTImageDetail"
+
+describe('Projects 2', () => {
+  it("can select an image", async () => {
+    await App.selectView("projects")
+    await browser.waitUntil(async () =>
+      (await $('[data-testid="image-grid"]').getAttribute('aria-busy')) === 'false'
+    )
+    await DTProjects.images[0].click();
+    await expect(DTImageDetail.image).toBeDisplayedInViewport()
+  })
+})
\ No newline at end of file
diff --git a/vite.config.ts b/vite.config.ts
index de875c1..6449a7f 100644
--- a/vite.config.ts
+++ b/vite.config.ts
@@ -1,14 +1,15 @@
-import { defineConfig, ViteDevServer } from "vite";
+import "dotenv/config"
+import { defineConfig } from "vite";
 import react from "@vitejs/plugin-react";
 import tsconfigPaths from "vite-tsconfig-paths"
 import { htmlInjectionPlugin } from "vite-plugin-html-injection";
-// import wasm from "vite-plugin-wasm";
-// import { visualizer } from 'rollup-plugin-visualizer'
+import { visualizer } from 'rollup-plugin-visualizer'
 
 const host = process.env.TAURI_DEV_HOST;
 const isMock = process.env.MOCK_TAURI === "true";
 const reactDevtools = process.env.REACT_DEVTOOLS === "true";
+const showVisualizer = process.env.SHOW_VIS === "true";
 
 const hmr = true
@@ -18,15 +19,6 @@ export default defineConfig(async () => ({
   build: {
     target: "esnext",
     assetsInlineLimit: 0,
-    // cssCodeSplit: false,
-    // sourcemap: true,
-    // rollupOptions: {
-    //   output: {
-    //     manualChunks() {
-    //       return 'app'
-    //     }
-    //   }
-    // }
   },
   plugins: [
     reactDevtools ? htmlInjectionPlugin({
@@ -49,8 +41,7 @@ export default defineConfig(async () => ({
       }
     }),
     tsconfigPaths(),
-    // wasm(),
-    // visualizer({ open: true }),
+    showVisualizer ? 
visualizer({ open: true }) : null, ], resolve: { alias: { @@ -74,8 +65,7 @@ export default defineConfig(async () => ({ }, }, - // Vite options tailored for Tauri development and only applied in `tauri dev` or `tauri build` - // + // 1. prevent Vite from obscuring rust errors clearScreen: false, // 2. tauri expects a fixed port, fail if that port is not available