diff --git a/Cargo.lock b/Cargo.lock index 1270e673..1edb4ce9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2159,6 +2159,7 @@ dependencies = [ "bincode", "faster-hex", "futures", + "hex", "kaspa-addresses", "kaspa-bip32", "kaspa-consensus-client", @@ -2170,6 +2171,7 @@ dependencies = [ "kaspa-utils", "kaspa-wallet-core", "kaspa-wallet-keys", + "kaspa-wallet-pskt", "kaspa-wrpc-client", "paste", "pyo3", diff --git a/Cargo.toml b/Cargo.toml index 1d975e77..355474e4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,6 +16,7 @@ ahash = "0.8.12" bincode = "1.3.3" faster-hex = "0.9.0" futures = "0.3.31" +hex = "0.4.3" kaspa-addresses = { git = "https://github.com/kaspanet/rusty-kaspa.git", rev = "1a2f98a" } kaspa-bip32 = { git = "https://github.com/kaspanet/rusty-kaspa.git", rev = "1a2f98a" } kaspa-consensus-client = { git = "https://github.com/kaspanet/rusty-kaspa.git", rev = "1a2f98a" } @@ -27,6 +28,7 @@ kaspa-txscript = { git = "https://github.com/kaspanet/rusty-kaspa.git", rev = "1 kaspa-utils = { git = "https://github.com/kaspanet/rusty-kaspa.git", rev = "1a2f98a" } kaspa-wallet-core = { git = "https://github.com/kaspanet/rusty-kaspa.git", rev = "1a2f98a" } kaspa-wallet-keys = { git = "https://github.com/kaspanet/rusty-kaspa.git", rev = "1a2f98a" } +kaspa-wallet-pskt = { git = "https://github.com/kaspanet/rusty-kaspa.git", rev = "1a2f98a" } kaspa-wrpc-client = { git = "https://github.com/kaspanet/rusty-kaspa.git", rev = "1a2f98a" } paste = "1.0" pyo3 = { version = "0.27.1", features = ['multiple-pymethods'] } diff --git a/examples/transactions/pskt.py b/examples/transactions/pskt.py new file mode 100644 index 00000000..511a6b00 --- /dev/null +++ b/examples/transactions/pskt.py @@ -0,0 +1,142 @@ +import asyncio +from kaspa import ( + Hash, + Mnemonic, + Opcodes, + PSKT, + Resolver, + RpcClient, + ScriptBuilder, + TransactionInput, + TransactionOutpoint, + UtxoEntryReference, + XPrv, + address_from_script_public_key, + calculate_transaction_mass, + create_transaction, + sign_transaction, +) + + +def derive(seed, account_index): + xprv = XPrv(seed).derive_path(f"m/45'/111111'/{account_index}'") + xpub = xprv.to_xpub() + prv = xprv.derive_child(1).to_private_key() + pub = xpub.derive_child(1).to_public_key() + return prv, pub + + +async def main(): + ####################################################### + # Derive 3 accounts to use for Multisig PSKT demo + ####################################################### + seed = Mnemonic(( + 'predict cloud noise economy home stereo tag cancel adult pistol act remove ' + 'equip cricket man summer neutral black art miracle foam world clown say' + )).to_seed() + + prv1, pub1 = derive(seed, 0) + print(f'Account 1:\n - prv: {prv1.to_string()}\n - pub: {pub1.to_string()}\n') + + prv2, pub2 = derive(seed, 1) + print(f'Account 2:\n - prv: {prv2.to_string()}\n - pub: {pub2.to_string()}\n') + + prv3, pub3 = derive(seed, 2) + print(f'Account 3:\n - prv: {prv3.to_string()}\n - pub: {pub3.to_string()}\n') + + ####################################################### + # Create Multisig address + ####################################################### + redeem_script = ScriptBuilder()\ + .add_i64(2)\ + .add_data(pub1.to_x_only_public_key().to_string())\ + .add_data(pub2.to_x_only_public_key().to_string())\ + .add_data(pub3.to_x_only_public_key().to_string())\ + .add_i64(3)\ + .add_op(Opcodes.OpCheckMultiSig) + spk = redeem_script.create_pay_to_script_hash_script() + address = address_from_script_public_key(spk, "testnet") + + print(f"Multisig address: {address}") + + while True: + if input("Send funds to address (y to proceed): ") == "y": + break + + ####################################################### + # Get address's UTXOs + ####################################################### + client = RpcClient(resolver=Resolver(), network_id='testnet-10') + await client.connect(strategy='fallback') + utxos = await client.get_utxos_by_addresses(request={'addresses': [address]}) + utxos = utxos["entries"] + utxos = sorted(utxos, key=lambda x: x['utxoEntry']['amount'], reverse=True) + total = sum(item["utxoEntry"]["amount"] for item in utxos) + print(utxos) + # utxo = utxos["entries"][0] + + ####################################################### + # Placeholder TX for fee calculation + ####################################################### + # outputs = [ + # {"address": address, "amount": int(total)} + # ] + # tx = create_transaction(utxos, outputs, 0, None, 2) + # mass = calculate_transaction_mass("testnet-10", tx) + + ####################################################### + # Get feerates & create actual TX + ####################################################### + # fee_rates = await client.get_fee_estimate() + # fee_rate = int(fee_rates["estimate"]["priorityBucket"]["feerate"]) + + # outputs = [ + # {"address": address, "amount": int(total - (fee_rate * mass)), "scriptPublicKey": ""} + # ] + # tx = create_transaction(utxos, outputs, 0, None, 1) + # tx_signed = sign_transaction(tx, [prv1], True) + + ####################################################### + # Create PSKT + ####################################################### + pskt = PSKT() + pskt_serialized = pskt.serialize() + print(pskt_serialized) + + ####################################################### + # Create input + ####################################################### + input0 = TransactionInput.from_dict({ + 'previousOutpoint': { 'transactionId': 'c38eb7191a2e0df6089b05cf7df9c92dc559db618184b11cbb8c5ba30b024bce', 'index': 1 }, + 'signatureScript': '', + 'sequence': 0, + 'sigOpCount': 1, + 'utxo': { + 'utxo': { + 'address': 'kaspatest:prganzek6uhsn4rv29g6qkeh8rduae6n3ul0xk5fnzjtugqhfaxcx0ee2dn47', + 'outpoint': {'transactionId': 'c38eb7191a2e0df6089b05cf7df9c92dc559db618184b11cbb8c5ba30b024bce', 'index': 1}, + 'amount': 98699920028, + 'scriptPublicKey': '0000aa20d1d98b36d72f09d46c5151a05b3738dbcee7538f3ef35a8998a4be20174f4d8387', + 'blockDaaScore': 354263497, + 'isCoinbase': False + } + } + }) + # pskt = pskt.input(input) + + # previous_outpoint = TransactionOutpoint( + # transaction_id=Hash(utxo["outpoint"]['transactionId']), + # index=utxo["outpoint"]['index'] + # ) + # input_0 = TransactionInput( + # previous_outpoint=previous_outpoint, + # signature_script=b"", + # sequence=0, + # sig_op_count=2, + # utxo=None + # ) + # pskt.to_constructor().input(input_0) + # print(pskt.serialize()) + +if __name__ == "__main__": + asyncio.run(main()) \ No newline at end of file diff --git a/kaspa.pyi b/kaspa.pyi index eb95ebbe..7cedf093 100644 --- a/kaspa.pyi +++ b/kaspa.pyi @@ -762,6 +762,57 @@ class Outputs: """ ... +@typing.final +class PSKT: + r""" + Partially Signed Kaspa Transaction + """ + @property + def role(self) -> builtins.str: ... + @property + def payload(self) -> builtins.str: ... + def __new__(cls, payload: typing.Optional[typing.Any] = None) -> PSKT: ... + def serialize(self) -> builtins.str: ... + def creator(self) -> PSKT: + r""" + Change role to `CREATOR` + """ + def to_constructor(self) -> PSKT: + r""" + Change role to `CONSTRUCTOR` + """ + def to_updater(self) -> PSKT: + r""" + Change role to `UPDATER` + """ + def to_signer(self) -> PSKT: + r""" + Change role to `SIGNER` + """ + def to_combiner(self) -> PSKT: + r""" + Change role to `COMBINER` + """ + def to_finalizer(self) -> PSKT: + r""" + Change role to `FINALIZER` + """ + def to_extractor(self) -> PSKT: + r""" + Change role to `EXTRACTOR` + """ + def fallback_lock_time(self, lock_time: builtins.int) -> PSKT: ... + def inputs_modifiable(self) -> PSKT: ... + def outputs_modifiable(self) -> PSKT: ... + def no_more_inputs(self) -> PSKT: ... + def no_more_outputs(self) -> PSKT: ... + def input_and_redeem_script(self, input: TransactionInput, data: builtins.str) -> PSKT: ... + def input(self, input: TransactionInput) -> PSKT: ... + def output(self, output: TransactionOutput) -> PSKT: ... + def set_sequence(self, n: builtins.int, input_index: builtins.int) -> PSKT: ... + def calculate_id(self) -> Hash: ... + def calculate_mass(self, data: NetworkId) -> builtins.int: ... + @typing.final class PaymentOutput: r""" @@ -1403,6 +1454,76 @@ class PublicKeyGenerator: str: The generator info string. """ +@typing.final +class PyPsktConsensusClientError(builtins.Exception): + r""" + PSKT Consensus Client Error + """ + ... + +@typing.final +class PyPsktCreateNotAllowedError(builtins.Exception): + r""" + PSKT Creation Not Allowed Error + """ + ... + +@typing.final +class PyPsktCtorError(builtins.Exception): + r""" + PSKT Constructor Error + """ + ... + +@typing.final +class PyPsktCustomError(builtins.Exception): + r""" + Custom PSKT Error + """ + ... + +@typing.final +class PyPsktError(builtins.Exception): + r""" + PSKT Error + """ + ... + +@typing.final +class PyPsktExpectedStateError(builtins.Exception): + r""" + PSKT Expected State Error + """ + ... + +@typing.final +class PyPsktInvalidPayloadError(builtins.Exception): + r""" + PSKT Invalid Payload Error + """ + ... + +@typing.final +class PyPsktNotInitializedError(builtins.Exception): + r""" + PSKT Not Initialized Error + """ + ... + +@typing.final +class PyPsktStateError(builtins.Exception): + r""" + PSKT State Error + """ + ... + +@typing.final +class PyPsktTxNotFinalizedError(builtins.Exception): + r""" + PSKT Tx Not Finalized Error + """ + ... + @typing.final class Resolver: r""" @@ -2461,6 +2582,21 @@ class UtxoContext: Return a range of mature UTXO entries. """ +@typing.final +class UtxoEntries: + r""" + UTXO entries collection for flexible input handling. + + This type is not intended to be instantiated directly from Python. + It serves as a helper type that allows Rust functions to accept a list + of UTXO entries in multiple convenient forms. + + Accepts: + list[UtxoEntryReference]: A list of UtxoEntryReference objects. + list[dict]: A list of dicts with UtxoEntryReference-compatible keys. + """ + ... + @typing.final class UtxoEntries: r""" @@ -2505,21 +2641,6 @@ class UtxoEntries: """ def __eq__(self, other: UtxoEntries) -> builtins.bool: ... -@typing.final -class UtxoEntries: - r""" - UTXO entries collection for flexible input handling. - - This type is not intended to be instantiated directly from Python. - It serves as a helper type that allows Rust functions to accept a list - of UTXO entries in multiple convenient forms. - - Accepts: - list[UtxoEntryReference]: A list of UtxoEntryReference objects. - list[dict]: A list of dicts with UtxoEntryReference-compatible keys. - """ - ... - @typing.final class UtxoEntry: r""" diff --git a/src/lib.rs b/src/lib.rs index dcef9df5..b30e1f0a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -12,9 +12,16 @@ use pyo3_stub_gen::define_stub_info_gatherer; define_stub_info_gatherer!(stub_info); #[pymodule] -fn kaspa(m: &Bound<'_, PyModule>) -> PyResult<()> { +fn kaspa(py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> { + // Init logging bridge pyo3_log::init(); + // Create/register exceptions submodule + let exceptions = PyModule::new(py, "exceptions")?; + m.add_submodule(&exceptions)?; + + // Register classes & functions + m.add_class::()?; m.add_class::()?; @@ -157,5 +164,48 @@ fn kaspa(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; m.add_class::()?; + exceptions.add( + "PsktCustomError", + py.get_type::(), + )?; + exceptions.add( + "PsktStateError", + py.get_type::(), + )?; + exceptions.add( + "PsktExpectedStateError", + py.get_type::(), + )?; + exceptions.add( + "PsktCtorError", + py.get_type::(), + )?; + exceptions.add( + "PsktInvalidPayloadError", + py.get_type::(), + )?; + exceptions.add( + "PsktTxNotFinalizedError", + py.get_type::(), + )?; + exceptions.add( + "PsktCreateNotAllowedError", + py.get_type::(), + )?; + exceptions.add( + "PsktNotInitializedError", + py.get_type::(), + )?; + exceptions.add( + "PsktConsensusClientError", + py.get_type::(), + )?; + exceptions.add( + "PsktError", + py.get_type::(), + )?; + + m.add_class::()?; + Ok(()) } diff --git a/src/macros.rs b/src/macros.rs index 87571440..925f841c 100644 --- a/src/macros.rs +++ b/src/macros.rs @@ -61,3 +61,37 @@ macro_rules! wrap_unit_enum_for_py { // } }; } + +// PyO3 provides create_exception! macro. However we cannot use it. +// Because we need to use proc macro #[gen_stub_pyclass] to include the defined +// exception in the stub file. When using create_exception!, we cannot apply +// #[gen_stub_pyclass]. +// When PyO3 is able to generate stub files (currently experimental) +// this could likely be removed in favor of that approach. +#[macro_export] +macro_rules! create_py_exception { + ($(#[$meta:meta])* $name:ident, $py_name:literal) => { + $(#[$meta])* + #[allow(dead_code)] + #[gen_stub_pyclass] + #[pyclass(name = $py_name, extends = PyException)] + pub struct $name { + message: String, + } + + // This is required, otherwise PyO3 cannot initialize the Exception on Python side + #[pymethods] + impl $name { + #[new] + pub fn new(message: String) -> Self { + Self { message } + } + } + + impl $name { + pub fn new_err(message: impl Into) -> PyErr { + PyErr::new::(message.into()) + } + } + }; +} diff --git a/src/wallet/mod.rs b/src/wallet/mod.rs index d6f842e3..e63d6e6b 100644 --- a/src/wallet/mod.rs +++ b/src/wallet/mod.rs @@ -1,3 +1,4 @@ pub mod bip32; pub mod core; pub mod keys; +pub mod pskt; diff --git a/src/wallet/pskt/error.rs b/src/wallet/pskt/error.rs new file mode 100644 index 00000000..d0cb1239 --- /dev/null +++ b/src/wallet/pskt/error.rs @@ -0,0 +1,95 @@ +use kaspa_wallet_pskt::wasm::error::Error as NativeError; +use pyo3::exceptions::PyException; +use pyo3::prelude::*; +use pyo3_stub_gen::derive::gen_stub_pyclass; + +// Custom Python Exceptions + +crate::create_py_exception!( + /// Custom PSKT Error + PyPsktCustomError, "PsktCustomError" +); + +crate::create_py_exception!( + /// PSKT State Error + PyPsktStateError, "PsktStateError" +); + +crate::create_py_exception!( + /// PSKT Expected State Error + PyPsktExpectedStateError, "PsktExpectedStateError" +); + +crate::create_py_exception!( + /// PSKT Constructor Error + PyPsktCtorError, "PsktCtorError" +); + +crate::create_py_exception!( + /// PSKT Invalid Payload Error + PyPsktInvalidPayloadError, "PsktInvalidPayloadError" +); + +crate::create_py_exception!( + /// PSKT Tx Not Finalized Error + PyPsktTxNotFinalizedError, "PsktTxNotFinalizedError" +); + +crate::create_py_exception!( + /// PSKT Creation Not Allowed Error + PyPsktCreateNotAllowedError, "PsktCreateNotAllowedError" +); + +crate::create_py_exception!( + /// PSKT Not Initialized Error + PyPsktNotInitializedError, "PsktNotInitializedError" +); + +crate::create_py_exception!( + /// PSKT Consensus Client Error + PyPsktConsensusClientError, "PsktConsensusClientError" +); + +crate::create_py_exception!( + /// PSKT Error + PyPsktError, "PsktError" +); + +// Internal error type +// Wraps natively defined WASM Error +// Returns corresponding custom Python exception to python +pub struct Error(NativeError); + +impl From for PyErr { + fn from(value: Error) -> Self { + match value.0 { + NativeError::Custom(msg) => PyPsktCustomError::new_err(msg), + NativeError::State(msg) => PyPsktStateError::new_err(msg), + NativeError::ExpectedState(msg) => PyPsktExpectedStateError::new_err(msg), + NativeError::Ctor(msg) => PyPsktCtorError::new_err(msg), + NativeError::InvalidPayload => { + PyPsktInvalidPayloadError::new_err(NativeError::InvalidPayload.to_string()) + } + NativeError::TxNotFinalized(inner) => { + PyPsktTxNotFinalizedError::new_err(inner.to_string()) + } + NativeError::CreateNotAllowed => { + PyPsktCreateNotAllowedError::new_err(NativeError::CreateNotAllowed.to_string()) + } + NativeError::NotInitialized => { + PyPsktNotInitializedError::new_err(NativeError::NotInitialized.to_string()) + } + NativeError::ConsensusClient(inner) => { + PyPsktConsensusClientError::new_err(inner.to_string()) + } + NativeError::Pskt(inner) => PyPsktError::new_err(inner.to_string()), + _ => PyException::new_err("Unhandled error type"), + } + } +} + +impl From for Error { + fn from(value: NativeError) -> Self { + Error(value) + } +} diff --git a/src/wallet/pskt/mod.rs b/src/wallet/pskt/mod.rs new file mode 100644 index 00000000..6e53d0ba --- /dev/null +++ b/src/wallet/pskt/mod.rs @@ -0,0 +1,361 @@ +pub mod error; + +use crate::consensus::client::input::PyTransactionInput; +use crate::consensus::client::output::PyTransactionOutput; +use crate::consensus::client::transaction::PyTransaction; +use crate::consensus::core::network::PyNetworkId; +use crate::consensus::core::tx::TransactionId; +use error::Error; +use kaspa_consensus_client::{Transaction, TransactionInput, TransactionOutput}; +use kaspa_consensus_core::network::NetworkType; +use kaspa_wallet_pskt::pskt::Input; +use kaspa_wallet_pskt::wasm::error::Error as WasmError; +use kaspa_wallet_pskt::{ + error::Error as NativeError, + pskt::{Inner, PSKT}, + role::*, + wasm::pskt::State, +}; +use pyo3::{exceptions::PyException, prelude::*}; +use pyo3_stub_gen::derive::*; +use std::sync::{Arc, Mutex, MutexGuard}; + +/// Partially Signed Kaspa Transaction +#[gen_stub_pyclass] +#[pyclass(name = "PSKT")] +#[derive(Clone)] +pub struct PyPSKT { + state: Arc>>, +} + +impl PyPSKT { + fn take(&self) -> State { + self.state.lock().unwrap().take().unwrap() + } + + fn replace(&self, state: State) -> PyResult { + self.state.lock().unwrap().replace(state); + Ok(self.clone()) + } + + fn state(&self) -> MutexGuard<'_, Option> { + self.state.lock().unwrap() + } +} + +#[gen_stub_pymethods] +#[pymethods] +impl PyPSKT { + #[new] + #[pyo3(signature = (payload=None))] + pub fn new(payload: Option>) -> PyResult { + let pskt = match payload { + None => PyPSKT::from(State::Creator(PSKT::::default())), + Some(p) => { + if let Ok(s) = p.extract::() { + let inner: State = serde_json::from_str(&s) + .map_err(|err| Error::from(WasmError::Ctor(err.to_string())))?; + PyPSKT::from(inner) + } else if let Ok(py_tx) = p.extract::() { + let tx: Transaction = py_tx.into(); + let inner: Inner = tx + .try_into() + .map_err(|err: NativeError| PyException::new_err(err.to_string()))?; + PyPSKT::from(State::NoOp(Some(inner))) + } else { + return Err(Error::from(WasmError::InvalidPayload).into()); + } + } + }; + + Ok(pskt) + } + + #[getter] + pub fn get_role(&self) -> String { + self.state().as_ref().unwrap().display().to_string() + } + + #[getter] + pub fn get_payload(&self) -> PyResult { + let state = self.state(); + serde_json::to_string(state.as_ref().unwrap()) + .map_err(|err| PyException::new_err(err.to_string())) + // workflow_wasm::serde::to_value(state.as_ref().unwrap()).unwrap() + } + + pub fn serialize(&self) -> String { + let state = self.state(); + serde_json::to_string(state.as_ref().unwrap()).unwrap() + } + + /// Change role to `CREATOR` + pub fn creator(&self) -> PyResult { + let state = match self.take() { + State::NoOp(inner) => match inner { + None => State::Creator(PSKT::default()), + Some(_) => Err(Error::from(WasmError::CreateNotAllowed))?, + }, + state => Err(Error::from(WasmError::state(state)))?, + }; + + self.replace(state) + } + + /// Change role to `CONSTRUCTOR` + pub fn to_constructor(&self) -> PyResult { + let state = match self.take() { + State::NoOp(inner) => { + State::Constructor(inner.ok_or(Error::from(WasmError::NotInitialized))?.into()) + } + State::Creator(pskt) => State::Constructor(pskt.constructor()), + state => Err(Error::from(WasmError::state(state)))?, + }; + + self.replace(state) + } + + /// Change role to `UPDATER` + pub fn to_updater(&self) -> PyResult { + let state = match self.take() { + State::NoOp(inner) => { + State::Updater(inner.ok_or(Error::from(WasmError::NotInitialized))?.into()) + } + State::Constructor(constructor) => State::Updater(constructor.updater()), + state => Err(Error::from(WasmError::state(state)))?, + }; + + self.replace(state) + } + + /// Change role to `SIGNER` + pub fn to_signer(&self) -> PyResult { + let state = match self.take() { + State::NoOp(inner) => { + State::Signer(inner.ok_or(Error::from(WasmError::NotInitialized))?.into()) + } + State::Constructor(pskt) => State::Signer(pskt.signer()), + State::Updater(pskt) => State::Signer(pskt.signer()), + State::Combiner(pskt) => State::Signer(pskt.signer()), + state => Err(Error::from(WasmError::state(state)))?, + }; + + self.replace(state) + } + + /// Change role to `COMBINER` + pub fn to_combiner(&self) -> PyResult { + let state = match self.take() { + State::NoOp(inner) => { + State::Combiner(inner.ok_or(Error::from(WasmError::NotInitialized))?.into()) + } + State::Constructor(pskt) => State::Combiner(pskt.combiner()), + State::Updater(pskt) => State::Combiner(pskt.combiner()), + State::Signer(pskt) => State::Combiner(pskt.combiner()), + state => Err(Error::from(WasmError::state(state)))?, + }; + + self.replace(state) + } + + /// Change role to `FINALIZER` + pub fn to_finalizer(&self) -> PyResult { + let state = match self.take() { + State::NoOp(inner) => { + State::Finalizer(inner.ok_or(Error::from(WasmError::NotInitialized))?.into()) + } + State::Combiner(pskt) => State::Finalizer(pskt.finalizer()), + state => Err(Error::from(WasmError::state(state)))?, + }; + + self.replace(state) + } + + /// Change role to `EXTRACTOR` + pub fn to_extractor(&self) -> PyResult { + let state = match self.take() { + State::NoOp(inner) => { + State::Extractor(inner.ok_or(Error::from(WasmError::NotInitialized))?.into()) + } + State::Finalizer(pskt) => State::Extractor( + pskt.extractor() + .map_err(WasmError::from) + .map_err(Error::from)?, + ), + state => Err(Error::from(WasmError::state(state)))?, + }; + + self.replace(state) + } + + pub fn fallback_lock_time(&self, lock_time: u64) -> PyResult { + let state = match self.take() { + State::Creator(pskt) => State::Creator(pskt.fallback_lock_time(lock_time)), + _ => Err(Error::from(WasmError::expected_state("Creator")))?, + }; + + self.replace(state) + } + + pub fn inputs_modifiable(&self) -> PyResult { + let state = match self.take() { + State::Creator(pskt) => State::Creator(pskt.inputs_modifiable()), + _ => Err(Error::from(WasmError::expected_state("Creator")))?, + }; + + self.replace(state) + } + + pub fn outputs_modifiable(&self) -> PyResult { + let state = match self.take() { + State::Creator(pskt) => State::Creator(pskt.outputs_modifiable()), + _ => Err(Error::from(WasmError::expected_state("Creator")))?, + }; + + self.replace(state) + } + + pub fn no_more_inputs(&self) -> PyResult { + let state = match self.take() { + State::Constructor(pskt) => State::Constructor(pskt.no_more_inputs()), + _ => Err(Error::from(WasmError::expected_state("Constructor")))?, + }; + + self.replace(state) + } + + pub fn no_more_outputs(&self) -> PyResult { + let state = match self.take() { + State::Constructor(pskt) => State::Constructor(pskt.no_more_outputs()), + _ => Err(Error::from(WasmError::expected_state("Constructor")))?, + }; + + self.replace(state) + } + + pub fn input_and_redeem_script( + &self, + input: PyTransactionInput, + data: String, + ) -> PyResult { + let input = TransactionInput::from(input); + let mut input: Input = input + .try_into() + .map_err(|err| Error::from(WasmError::from(err)))?; + input.redeem_script = Some(hex::decode(data).map_err(|e| { + Error::from(WasmError::custom(format!( + "Redeem script is not a hex string: {}", + e + ))) + })?); + let state = match self.take() { + State::Constructor(pskt) => State::Constructor(pskt.input(input)), + _ => Err(Error::from(WasmError::expected_state("Constructor")))?, + }; + + self.replace(state) + } + + pub fn input(&self, input: PyTransactionInput) -> PyResult { + let input = TransactionInput::from(input); + let state = match self.take() { + State::Constructor(pskt) => State::Constructor( + pskt.input( + input + .try_into() + .map_err(|err| Error::from(WasmError::from(err)))?, + ), + ), + _ => Err(Error::from(WasmError::expected_state("Constructor")))?, + }; + + self.replace(state) + } + + pub fn output(&self, output: PyTransactionOutput) -> PyResult { + let output = TransactionOutput::from(output); + let state = match self.take() { + State::Constructor(pskt) => State::Constructor( + pskt.output( + output + .try_into() + .map_err(|err| Error::from(WasmError::from(err)))?, + ), + ), + _ => Err(Error::from(WasmError::expected_state("Constructor")))?, + }; + + self.replace(state) + } + + pub fn set_sequence(&self, n: u64, input_index: usize) -> PyResult { + let state = match self.take() { + State::Updater(pskt) => State::Updater( + pskt.set_sequence(n, input_index) + .map_err(|err| Error::from(WasmError::from(err)))?, + ), + _ => Err(Error::from(WasmError::expected_state("Updater")))?, + }; + + self.replace(state) + } + + pub fn calculate_id(&self) -> PyResult { + let state = self.state(); + match state.as_ref().unwrap() { + State::Signer(pskt) => Ok(pskt.calculate_id().into()), + _ => Err(Error::from(WasmError::expected_state("Signer")))?, + } + } + + pub fn calculate_mass(&self, data: PyNetworkId) -> PyResult { + let network_type = data.get_network_type(); + + let cloned_pskt = self.clone(); + + let extractor = { + let finalizer = cloned_pskt.to_finalizer()?; + + let finalizer_state = finalizer.state().clone().unwrap(); + + match finalizer_state { + State::Finalizer(pskt) => { + for input in pskt.inputs.iter() { + if input.redeem_script.is_some() { + return Err(Error::from(WasmError::custom( + "Mass calculation is not supported for inputs with redeem scripts", + )) + .into()); + } + } + let pskt = pskt + .finalize_sync(|inner: &Inner| -> PyResult>> { + Ok(vec![vec![0u8, 65]; inner.inputs.len()]) + }) + .map_err(|e| { + Error::from(WasmError::custom(format!("Failed to finalize PSKT: {e}"))) + })?; + pskt.extractor() + .map_err(|err| Error::from(WasmError::TxNotFinalized(err)))? + } + _ => panic!("Finalizer state is not valid"), + } + }; + let tx = extractor + .extract_tx_unchecked(&NetworkType::from(network_type).into()) + .map_err(|e| { + Error::from(WasmError::custom(format!( + "Failed to extract transaction: {e}" + ))) + })?; + Ok(tx.tx.mass()) + } +} + +impl From for PyPSKT { + fn from(value: State) -> Self { + PyPSKT { + state: Arc::new(Mutex::new(Some(value))), + } + } +}