diff --git a/agent/TESTING_REMOTE_RESOURCES.md b/agent/TESTING_REMOTE_RESOURCES.md index fe128dd9d..8231db022 100644 --- a/agent/TESTING_REMOTE_RESOURCES.md +++ b/agent/TESTING_REMOTE_RESOURCES.md @@ -280,17 +280,22 @@ go build -o build/cvms-test ./test/cvms/main.go HOST=$HOST_IP PORT=7001 ./build/cvms-test \ -public-key-path ./public.pem \ -attested-tls-bool false \ - -kbs-url http://$HOST_IP:8080 \ -algo-type python \ -algo-source-url docker://$HOST_IP:5000/encrypted-lin-reg:v1.0 \ -algo-kbs-path default/key/algo-key \ + -algo-kbs-url http://$HOST_IP:8080 \ -algo-hash $ALGO_HASH \ -algo-args datasets/dataset_0.csv \ -dataset-source-urls docker://$HOST_IP:5000/encrypted-iris:v1.0 \ -dataset-kbs-paths default/key/dataset-key \ + -dataset-kbs-urls http://$HOST_IP:8080 \ -dataset-hash $DATASET_HASH ``` +> [!NOTE] +> You must specify the KBS URL for each encrypted resource using `-algo-kbs-url` and `-dataset-kbs-urls`. A global KBS is no longer supported. + + ### 3. Create VM via CLI (Host) ```bash @@ -356,17 +361,31 @@ The CVMS server sends this manifest to the agent: "type": "oci-image", "uri": "docker://localhost:5000/encrypted-lin-reg:v1.0", "encrypted": true, - "kbs_resource_path": "default/key/algo-key" + "kbs_resource_path": "default/key/algo-key", + "kbs": { + "url": "http://192.168.100.15:8080", + "enabled": true + } }, "datasets": [ { - "type": "oci-image", - "uri": "docker://localhost:5000/encrypted-iris:v1.0", - "encrypted": true, - "kbs_resource_path": "default/key/dataset-key" + "filename": "iris.csv", + "source": { + "type": "oci-image", + "url": "docker://localhost:5000/encrypted-iris:v1.0", + "encrypted": true, + "kbs_resource_path": "default/key/dataset-key" + }, + "kbs": { + "url": "http://192.168.100.20:8080", + "enabled": true + } } ], - "kbs_url": "http://192.168.100.15:8080" + "kbs": { + "url": "http://192.168.100.15:8080", + "enabled": true + } } ``` diff --git a/agent/auth/auth_test.go b/agent/auth/auth_test.go index 0948db5b2..bf6f834ab 100644 --- a/agent/auth/auth_test.go +++ b/agent/auth/auth_test.go @@ -44,7 +44,7 @@ func TestAuthenticateUser(t *testing.T) { manifest := agent.Computation{ ResultConsumers: []agent.ResultConsumer{{UserKey: resultConsumerPubKey}}, Datasets: []agent.Dataset{{UserKey: dataProviderPubKey}}, - Algorithm: agent.Algorithm{UserKey: algorithmProviderPubKey}, + Algorithm: &agent.Algorithm{UserKey: algorithmProviderPubKey}, } auth, err := New(manifest) diff --git a/agent/computations.go b/agent/computations.go index 8e4124cb7..3e760cf70 100644 --- a/agent/computations.go +++ b/agent/computations.go @@ -45,9 +45,8 @@ type Computation struct { Name string `json:"name,omitempty"` Description string `json:"description,omitempty"` Datasets Datasets `json:"datasets,omitempty"` - Algorithm Algorithm `json:"algorithm,omitempty"` + Algorithm *Algorithm `json:"algorithm,omitempty"` ResultConsumers []ResultConsumer `json:"result_consumers,omitempty"` - KBS KBSConfig `json:"kbs,omitempty"` } type ResultConsumer struct { @@ -69,6 +68,7 @@ type Dataset struct { Filename string `json:"filename,omitempty"` Source *ResourceSource `json:"source,omitempty"` // Optional remote source Decompress bool `json:"decompress,omitempty"` + KBS *KBSConfig `json:"kbs,omitempty"` } type Datasets []Dataset @@ -81,6 +81,7 @@ type Algorithm struct { Source *ResourceSource `json:"source,omitempty"` // Optional remote source AlgoType string `json:"algo_type,omitempty"` AlgoArgs []string `json:"algo_args,omitempty"` + KBS *KBSConfig `json:"kbs,omitempty"` } type ManifestIndexKey struct{} diff --git a/agent/cvms/api/grpc/client.go b/agent/cvms/api/grpc/client.go index 836ff7980..25d67a3cf 100644 --- a/agent/cvms/api/grpc/client.go +++ b/agent/cvms/api/grpc/client.go @@ -234,9 +234,10 @@ func (client *CVMSClient) executeRun(ctx context.Context, runReq *cvms.Computati } if runReq.Algorithm != nil { - ac.Algorithm = agent.Algorithm{ - Hash: [32]byte(runReq.Algorithm.Hash), - UserKey: runReq.Algorithm.UserKey, + ac.Algorithm = &agent.Algorithm{ + Hash: [32]byte(runReq.Algorithm.Hash), + UserKey: runReq.Algorithm.UserKey, + AlgoType: runReq.Algorithm.AlgoType, } // Copy remote source if configured if runReq.Algorithm.Source != nil { @@ -246,8 +247,13 @@ func (client *CVMSClient) executeRun(ctx context.Context, runReq *cvms.Computati Encrypted: runReq.Algorithm.Source.Encrypted, } } - ac.Algorithm.AlgoType = runReq.Algorithm.AlgoType ac.Algorithm.AlgoArgs = runReq.Algorithm.AlgoArgs + if runReq.Algorithm.Kbs != nil { + ac.Algorithm.KBS = &agent.KBSConfig{ + URL: runReq.Algorithm.Kbs.Url, + Enabled: runReq.Algorithm.Kbs.Enabled, + } + } } for _, ds := range runReq.Datasets { @@ -265,6 +271,12 @@ func (client *CVMSClient) executeRun(ctx context.Context, runReq *cvms.Computati } } dataset.Decompress = ds.Decompress + if ds.Kbs != nil { + dataset.KBS = &agent.KBSConfig{ + URL: ds.Kbs.Url, + Enabled: ds.Kbs.Enabled, + } + } ac.Datasets = append(ac.Datasets, dataset) } @@ -274,14 +286,6 @@ func (client *CVMSClient) executeRun(ctx context.Context, runReq *cvms.Computati }) } - // Copy KBS configuration - if runReq.Kbs != nil { - ac.KBS = agent.KBSConfig{ - URL: runReq.Kbs.Url, - Enabled: runReq.Kbs.Enabled, - } - } - // Check if the agent is in the correct state to initialize a new computation. // If the agent is already processing this computation (e.g., after a reconnection), // skip initialization to avoid state errors. diff --git a/agent/cvms/api/grpc/client_test.go b/agent/cvms/api/grpc/client_test.go index 7830a222e..05dc95f0f 100644 --- a/agent/cvms/api/grpc/client_test.go +++ b/agent/cvms/api/grpc/client_test.go @@ -554,11 +554,12 @@ func TestManagerClient_handleRunReqChunksWithRemoteSource(t *testing.T) { KbsResourcePath: "default/key/algo-key", Encrypted: true, }, + Kbs: &cvms.KBSConfig{ + Url: "https://kbs.example.com:8080", + Enabled: true, + }, }, - Kbs: &cvms.KBSConfig{ - Url: "https://kbs.example.com:8080", - Enabled: true, - }, + ResultConsumers: []*cvms.ResultConsumer{ { UserKey: []byte("test-consumer"), @@ -577,8 +578,8 @@ func TestManagerClient_handleRunReqChunksWithRemoteSource(t *testing.T) { mockSvc.On("State").Return("ReceivingManifest") mockSvc.On("InitComputation", mock.Anything, mock.MatchedBy(func(c agent.Computation) bool { - // Verify KBS config is passed - if !c.KBS.Enabled || c.KBS.URL != "https://kbs.example.com:8080" { + // Verify Algorithm KBS config is passed + if c.Algorithm.KBS == nil || !c.Algorithm.KBS.Enabled || c.Algorithm.KBS.URL != "https://kbs.example.com:8080" { return false } // Verify algorithm source is passed diff --git a/agent/cvms/cvms.pb.go b/agent/cvms/cvms.pb.go index 3cfa386c0..6bba150c6 100644 --- a/agent/cvms/cvms.pb.go +++ b/agent/cvms/cvms.pb.go @@ -826,7 +826,6 @@ type ComputationRunReq struct { Algorithm *Algorithm `protobuf:"bytes,5,opt,name=algorithm,proto3" json:"algorithm,omitempty"` ResultConsumers []*ResultConsumer `protobuf:"bytes,6,rep,name=result_consumers,json=resultConsumers,proto3" json:"result_consumers,omitempty"` AgentConfig *AgentConfig `protobuf:"bytes,7,opt,name=agent_config,json=agentConfig,proto3" json:"agent_config,omitempty"` - Kbs *KBSConfig `protobuf:"bytes,8,opt,name=kbs,proto3" json:"kbs,omitempty"` // Optional KBS configuration for remote resources unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } @@ -910,13 +909,6 @@ func (x *ComputationRunReq) GetAgentConfig() *AgentConfig { return nil } -func (x *ComputationRunReq) GetKbs() *KBSConfig { - if x != nil { - return x.Kbs - } - return nil -} - type ResultConsumer struct { state protoimpl.MessageState `protogen:"open.v1"` UserKey []byte `protobuf:"bytes,1,opt,name=userKey,proto3" json:"userKey,omitempty"` @@ -968,6 +960,7 @@ type Dataset struct { Filename string `protobuf:"bytes,3,opt,name=filename,proto3" json:"filename,omitempty"` Source *Source `protobuf:"bytes,4,opt,name=source,proto3" json:"source,omitempty"` // Optional remote source for encrypted dataset Decompress bool `protobuf:"varint,5,opt,name=decompress,proto3" json:"decompress,omitempty"` + Kbs *KBSConfig `protobuf:"bytes,6,opt,name=kbs,proto3" json:"kbs,omitempty"` // Optional KBS configuration override unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } @@ -1037,6 +1030,13 @@ func (x *Dataset) GetDecompress() bool { return false } +func (x *Dataset) GetKbs() *KBSConfig { + if x != nil { + return x.Kbs + } + return nil +} + type Algorithm struct { state protoimpl.MessageState `protogen:"open.v1"` Hash []byte `protobuf:"bytes,1,opt,name=hash,proto3" json:"hash,omitempty"` // should be sha3.Sum256, 32 byte length. @@ -1044,6 +1044,7 @@ type Algorithm struct { Source *Source `protobuf:"bytes,3,opt,name=source,proto3" json:"source,omitempty"` // Optional remote source for encrypted algorithm AlgoType string `protobuf:"bytes,4,opt,name=algo_type,json=algoType,proto3" json:"algo_type,omitempty"` AlgoArgs []string `protobuf:"bytes,5,rep,name=algo_args,json=algoArgs,proto3" json:"algo_args,omitempty"` + Kbs *KBSConfig `protobuf:"bytes,6,opt,name=kbs,proto3" json:"kbs,omitempty"` // Optional KBS configuration override unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } @@ -1113,6 +1114,13 @@ func (x *Algorithm) GetAlgoArgs() []string { return nil } +func (x *Algorithm) GetKbs() *KBSConfig { + if x != nil { + return x.Kbs + } + return nil +} + type Source struct { state protoimpl.MessageState `protogen:"open.v1"` Type string `protobuf:"bytes,1,opt,name=type,proto3" json:"type,omitempty"` // Type of source: "oci-image" (only OCI images supported for CoCo) @@ -1485,7 +1493,7 @@ const file_agent_cvms_cvms_proto_rawDesc = "" + "\fRunReqChunks\x12\x12\n" + "\x04data\x18\x01 \x01(\fR\x04data\x12\x0e\n" + "\x02id\x18\x02 \x01(\tR\x02id\x12\x17\n" + - "\ais_last\x18\x03 \x01(\bR\x06isLast\"\xcd\x02\n" + + "\ais_last\x18\x03 \x01(\bR\x06isLast\"\xaa\x02\n" + "\x11ComputationRunReq\x12\x0e\n" + "\x02id\x18\x01 \x01(\tR\x02id\x12\x12\n" + "\x04name\x18\x02 \x01(\tR\x04name\x12 \n" + @@ -1493,10 +1501,9 @@ const file_agent_cvms_cvms_proto_rawDesc = "" + "\bdatasets\x18\x04 \x03(\v2\r.cvms.DatasetR\bdatasets\x12-\n" + "\talgorithm\x18\x05 \x01(\v2\x0f.cvms.AlgorithmR\talgorithm\x12?\n" + "\x10result_consumers\x18\x06 \x03(\v2\x14.cvms.ResultConsumerR\x0fresultConsumers\x124\n" + - "\fagent_config\x18\a \x01(\v2\x11.cvms.AgentConfigR\vagentConfig\x12!\n" + - "\x03kbs\x18\b \x01(\v2\x0f.cvms.KBSConfigR\x03kbs\"*\n" + + "\fagent_config\x18\a \x01(\v2\x11.cvms.AgentConfigR\vagentConfig\"*\n" + "\x0eResultConsumer\x12\x18\n" + - "\auserKey\x18\x01 \x01(\fR\auserKey\"\x99\x01\n" + + "\auserKey\x18\x01 \x01(\fR\auserKey\"\xbc\x01\n" + "\aDataset\x12\x12\n" + "\x04hash\x18\x01 \x01(\fR\x04hash\x12\x18\n" + "\auserKey\x18\x02 \x01(\fR\auserKey\x12\x1a\n" + @@ -1504,13 +1511,15 @@ const file_agent_cvms_cvms_proto_rawDesc = "" + "\x06source\x18\x04 \x01(\v2\f.cvms.SourceR\x06source\x12\x1e\n" + "\n" + "decompress\x18\x05 \x01(\bR\n" + - "decompress\"\x99\x01\n" + + "decompress\x12!\n" + + "\x03kbs\x18\x06 \x01(\v2\x0f.cvms.KBSConfigR\x03kbs\"\xbc\x01\n" + "\tAlgorithm\x12\x12\n" + "\x04hash\x18\x01 \x01(\fR\x04hash\x12\x18\n" + "\auserKey\x18\x02 \x01(\fR\auserKey\x12$\n" + "\x06source\x18\x03 \x01(\v2\f.cvms.SourceR\x06source\x12\x1b\n" + "\talgo_type\x18\x04 \x01(\tR\balgoType\x12\x1b\n" + - "\talgo_args\x18\x05 \x03(\tR\balgoArgs\"x\n" + + "\talgo_args\x18\x05 \x03(\tR\balgoArgs\x12!\n" + + "\x03kbs\x18\x06 \x01(\v2\x0f.cvms.KBSConfigR\x03kbs\"x\n" + "\x06Source\x12\x12\n" + "\x04type\x18\x01 \x01(\tR\x04type\x12\x10\n" + "\x03url\x18\x02 \x01(\tR\x03url\x12*\n" + @@ -1591,16 +1600,17 @@ var file_agent_cvms_cvms_proto_depIdxs = []int32{ 14, // 15: cvms.ComputationRunReq.algorithm:type_name -> cvms.Algorithm 12, // 16: cvms.ComputationRunReq.result_consumers:type_name -> cvms.ResultConsumer 17, // 17: cvms.ComputationRunReq.agent_config:type_name -> cvms.AgentConfig - 16, // 18: cvms.ComputationRunReq.kbs:type_name -> cvms.KBSConfig - 15, // 19: cvms.Dataset.source:type_name -> cvms.Source + 15, // 18: cvms.Dataset.source:type_name -> cvms.Source + 16, // 19: cvms.Dataset.kbs:type_name -> cvms.KBSConfig 15, // 20: cvms.Algorithm.source:type_name -> cvms.Source - 7, // 21: cvms.Service.Process:input_type -> cvms.ClientStreamMessage - 8, // 22: cvms.Service.Process:output_type -> cvms.ServerStreamMessage - 22, // [22:23] is the sub-list for method output_type - 21, // [21:22] is the sub-list for method input_type - 21, // [21:21] is the sub-list for extension type_name - 21, // [21:21] is the sub-list for extension extendee - 0, // [0:21] is the sub-list for field type_name + 16, // 21: cvms.Algorithm.kbs:type_name -> cvms.KBSConfig + 7, // 22: cvms.Service.Process:input_type -> cvms.ClientStreamMessage + 8, // 23: cvms.Service.Process:output_type -> cvms.ServerStreamMessage + 23, // [23:24] is the sub-list for method output_type + 22, // [22:23] is the sub-list for method input_type + 22, // [22:22] is the sub-list for extension type_name + 22, // [22:22] is the sub-list for extension extendee + 0, // [0:22] is the sub-list for field type_name } func init() { file_agent_cvms_cvms_proto_init() } diff --git a/agent/cvms/cvms.proto b/agent/cvms/cvms.proto index 1f1e58dc5..44fcb2f1a 100644 --- a/agent/cvms/cvms.proto +++ b/agent/cvms/cvms.proto @@ -92,7 +92,6 @@ message ComputationRunReq { Algorithm algorithm = 5; repeated ResultConsumer result_consumers = 6; AgentConfig agent_config = 7; - KBSConfig kbs = 8; // Optional KBS configuration for remote resources } message ResultConsumer { @@ -105,6 +104,7 @@ message Dataset { string filename = 3; Source source = 4; // Optional remote source for encrypted dataset bool decompress = 5; + KBSConfig kbs = 6; // Optional KBS configuration override } message Algorithm { @@ -113,6 +113,7 @@ message Algorithm { Source source = 3; // Optional remote source for encrypted algorithm string algo_type = 4; repeated string algo_args = 5; + KBSConfig kbs = 6; // Optional KBS configuration override } message Source { diff --git a/agent/cvms/server/cvm_test.go b/agent/cvms/server/cvm_test.go index 1d1a8ad26..45062ec93 100644 --- a/agent/cvms/server/cvm_test.go +++ b/agent/cvms/server/cvm_test.go @@ -107,7 +107,7 @@ func TestAgentServer_Start(t *testing.T) { ID: "test-computation-1", Name: "Test Computation", Description: "A test computation", - Algorithm: agent.Algorithm{ + Algorithm: &agent.Algorithm{ Hash: [32]byte{0x01, 0x02, 0x03}, UserKey: pubKey, }, @@ -140,7 +140,7 @@ func TestAgentServer_Start(t *testing.T) { ID: "test-computation-2", Name: "Test Computation 2", Description: "Another test computation", - Algorithm: agent.Algorithm{ + Algorithm: &agent.Algorithm{ Hash: [32]byte{0x07, 0x08, 0x09}, UserKey: pubKey, }, @@ -168,7 +168,7 @@ func TestAgentServer_Start(t *testing.T) { cmp: agent.Computation{ ID: "test-computation-3", Name: "Minimal Test", - Algorithm: agent.Algorithm{ + Algorithm: &agent.Algorithm{ Hash: [32]byte{0x0d, 0x0e, 0x0f}, UserKey: pubKey, }, @@ -244,7 +244,7 @@ func TestAgentServer_Stop(t *testing.T) { cmp := agent.Computation{ ID: "test-stop-computation", Name: "Stop Test", - Algorithm: agent.Algorithm{ + Algorithm: &agent.Algorithm{ Hash: [32]byte{0x19, 0x1a, 0x1b}, UserKey: pubKey, }, @@ -303,7 +303,7 @@ func TestAgentServer_StopMultipleTimes(t *testing.T) { cmp := agent.Computation{ ID: "test-multiple-stop", Name: "Multiple Stop Test", - Algorithm: agent.Algorithm{ + Algorithm: &agent.Algorithm{ Hash: [32]byte{0x1f, 0x20, 0x21}, UserKey: pubKey, }, @@ -346,7 +346,7 @@ func TestAgentServer_StartAfterStop(t *testing.T) { cmp := agent.Computation{ ID: "test-restart", Name: "Restart Test", - Algorithm: agent.Algorithm{ + Algorithm: &agent.Algorithm{ Hash: [32]byte{0x25, 0x26, 0x27}, UserKey: pubKey, }, @@ -377,7 +377,7 @@ func TestAgentServer_StartAfterStop(t *testing.T) { cmp2 := agent.Computation{ ID: "test-restart-2", Name: "Restart Test 2", - Algorithm: agent.Algorithm{ + Algorithm: &agent.Algorithm{ Hash: [32]byte{0x2b, 0x2c, 0x2d}, UserKey: pubKey, }, @@ -426,7 +426,7 @@ func TestAgentServer_ConfigValidation(t *testing.T) { cmp: agent.Computation{ ID: "valid-config-test", Name: "Valid Config Test", - Algorithm: agent.Algorithm{ + Algorithm: &agent.Algorithm{ Hash: [32]byte{0x31, 0x32, 0x33}, UserKey: pubKey, }, @@ -450,7 +450,7 @@ func TestAgentServer_ConfigValidation(t *testing.T) { cmp: agent.Computation{ ID: "minimal-config-test", Name: "Minimal Config Test", - Algorithm: agent.Algorithm{ + Algorithm: &agent.Algorithm{ Hash: [32]byte{0x37, 0x38, 0x39}, UserKey: pubKey, }, @@ -474,7 +474,7 @@ func TestAgentServer_ConfigValidation(t *testing.T) { cmp: agent.Computation{ ID: "default-port-test", Name: "Default Port Test", - Algorithm: agent.Algorithm{Hash: [32]byte{0x3d, 0x3e, 0x3f}, UserKey: pubKey}, + Algorithm: &agent.Algorithm{Hash: [32]byte{0x3d, 0x3e, 0x3f}, UserKey: pubKey}, Datasets: []agent.Dataset{ {Hash: [32]byte{0x40, 0x41, 0x42}, UserKey: pubKey}, }, diff --git a/agent/service.go b/agent/service.go index 6b77f7a2d..e643328c2 100644 --- a/agent/service.go +++ b/agent/service.go @@ -89,6 +89,8 @@ var ( // when accessing a protected resource. ErrUnauthorizedAccess = errors.New("missing or invalid credentials provided") // ErrUndeclaredAlgorithm indicates algorithm was not declared in computation manifest. + ErrUndeclaredAlgorithm = errors.New("algorithm not declared in computation manifest") + // ErrUndeclaredDataset indicates dataset was not declared in computation manifest. ErrUndeclaredDataset = errors.New("dataset not declared in computation manifest") // ErrAllManifestItemsReceived indicates no new computation manifest items expected. ErrAllManifestItemsReceived = errors.New("all expected manifest Items have been received") @@ -229,27 +231,40 @@ func (as *agentService) InitComputation(ctx context.Context, cmp Computation) er as.computation = cmp - // Debug: Log manifest details - as.logger.Info("received computation manifest", - "computation_id", cmp.ID, - "kbs_enabled", cmp.KBS.Enabled, - "kbs_url", cmp.KBS.URL, - "algo_has_source", cmp.Algorithm.Source != nil, - "dataset_count", len(cmp.Datasets)) - - if cmp.Algorithm.Source != nil { - as.logger.Info("algorithm remote source configured", - "url", cmp.Algorithm.Source.URL, - "kbs_resource_path", cmp.Algorithm.Source.KBSResourcePath) + if cmp.Algorithm != nil { + as.logger.Info("received computation manifest", + "computation_id", cmp.ID, + "algo_has_source", cmp.Algorithm.Source != nil, + "algo_kbs_enabled", cmp.Algorithm.KBS != nil && cmp.Algorithm.KBS.Enabled, + "algo_kbs_url", func() string { + if cmp.Algorithm.KBS != nil { + return cmp.Algorithm.KBS.URL + } + return "" + }(), + "dataset_count", len(cmp.Datasets)) + + if cmp.Algorithm.Source != nil { + as.logger.Info("algorithm remote source configured", + "url", cmp.Algorithm.Source.URL, + "kbs_resource_path", cmp.Algorithm.Source.KBSResourcePath, + "kbs_enabled", cmp.Algorithm.KBS != nil && cmp.Algorithm.KBS.Enabled, + "kbs_url", func() string { + if cmp.Algorithm.KBS != nil { + return cmp.Algorithm.KBS.URL + } + return "" + }()) + } else { + as.logger.Info("algorithm remote source NOT configured - will wait for direct upload") + } } else { - as.logger.Info("algorithm remote source NOT configured - will wait for direct upload") + as.logger.Info("received computation manifest (no algorithm)", + "computation_id", cmp.ID, + "dataset_count", len(cmp.Datasets)) } - if cmp.KBS.Enabled { - as.logger.Info("KBS is ENABLED", "url", cmp.KBS.URL) - } else { - as.logger.Info("KBS is NOT ENABLED") - } + as.logger.Info("Global KBS is NOT USED (per-resource configuration only)") for i, d := range cmp.Datasets { if d.Source != nil { @@ -257,7 +272,14 @@ func (as *agentService) InitComputation(ctx context.Context, cmp Computation) er "index", i, "filename", d.Filename, "url", d.Source.URL, - "kbs_resource_path", d.Source.KBSResourcePath) + "kbs_resource_path", d.Source.KBSResourcePath, + "kbs_enabled", d.KBS != nil && d.KBS.Enabled, + "kbs_url", func() string { + if d.KBS != nil { + return d.KBS.URL + } + return "" + }()) } } @@ -338,21 +360,33 @@ func (as *agentService) downloadAlgorithmIfRemote(state statemachine.State) { as.mu.Lock() defer as.mu.Unlock() - // Debug: Log decision point + // Check if algorithm should be downloaded from remote source + if as.computation.Algorithm == nil { + as.logger.Info("algorithm automatic download not triggered, (no algorithm in manifest)") + return + } + + kbsEnabled := as.computation.Algorithm.KBS != nil && as.computation.Algorithm.KBS.Enabled + kbsURL := "" + if as.computation.Algorithm.KBS != nil { + kbsURL = as.computation.Algorithm.KBS.URL + } + as.logger.Info("checking if algorithm should be downloaded automatically", "algo_has_source", as.computation.Algorithm.Source != nil, - "kbs_enabled", as.computation.KBS.Enabled) + "kbs_enabled", kbsEnabled) // Check if algorithm should be downloaded from remote source - if as.computation.Algorithm.Source != nil && as.computation.KBS.Enabled { + if as.computation.Algorithm.Source != nil && kbsEnabled { as.logger.Info("downloading algorithm from remote source", "url", as.computation.Algorithm.Source.URL, - "kbs_resource_path", as.computation.Algorithm.Source.KBSResourcePath) + "kbs_resource_path", as.computation.Algorithm.Source.KBSResourcePath, + "kbs_url", kbsURL) // Use background context for download operation ctx := context.Background() - res, err := as.downloadAndDecryptResource(ctx, as.computation.Algorithm.Source, "algorithm") + res, err := as.downloadAndDecryptResource(ctx, as.computation.Algorithm.Source, kbsURL, "algorithm") if err != nil { as.runError = fmt.Errorf("failed to download and decrypt algorithm: %w", err) as.logger.Error(as.runError.Error()) @@ -462,7 +496,8 @@ func (as *agentService) downloadDatasetsIfRemote(state statemachine.State) { // Check if any datasets should be downloaded from remote sources hasRemoteDatasets := false for _, d := range as.computation.Datasets { - if d.Source != nil && as.computation.KBS.Enabled { + kbsEnabled := d.KBS != nil && d.KBS.Enabled + if d.Source != nil && kbsEnabled { hasRemoteDatasets = true break } @@ -477,10 +512,16 @@ func (as *agentService) downloadDatasetsIfRemote(state statemachine.State) { ctx := context.Background() for i := len(as.computation.Datasets) - 1; i >= 0; i-- { d := as.computation.Datasets[i] - if d.Source != nil && as.computation.KBS.Enabled { - as.logger.Info("downloading dataset from remote source", "filename", d.Filename) + kbsEnabled := d.KBS != nil && d.KBS.Enabled + kbsURL := "" + if d.KBS != nil { + kbsURL = d.KBS.URL + } + + if d.Source != nil && kbsEnabled { + as.logger.Info("downloading dataset from remote source", "filename", d.Filename, "kbs_url", kbsURL) - res, err := as.downloadAndDecryptResource(ctx, d.Source, "dataset") + res, err := as.downloadAndDecryptResource(ctx, d.Source, kbsURL, "dataset") if err != nil { as.runError = fmt.Errorf("failed to download and decrypt dataset %s: %w", d.Filename, err) as.logger.Error(as.runError.Error()) @@ -550,7 +591,7 @@ type DecryptedResource struct { // downloadAndDecryptResource downloads and decrypts a resource using OCI images and CoCo Keyprovider. // For OCI images, Skopeo handles download and CoCo Keyprovider handles decryption automatically. -func (as *agentService) downloadAndDecryptResource(ctx context.Context, source *ResourceSource, resourceType string) (*DecryptedResource, error) { +func (as *agentService) downloadAndDecryptResource(ctx context.Context, source *ResourceSource, kbsURL, resourceType string) (*DecryptedResource, error) { // Determine source type sourceType := source.Type if sourceType == "" { @@ -564,16 +605,16 @@ func (as *agentService) downloadAndDecryptResource(ctx context.Context, source * switch sourceType { case "oci-image": - return as.downloadAndDecryptOCIImage(ctx, source, resourceType) + return as.downloadAndDecryptOCIImage(ctx, source, kbsURL, resourceType) default: return nil, fmt.Errorf("unsupported source type: %s", sourceType) } } // downloadAndDecryptOCIImage downloads and decrypts an OCI image using Skopeo and CoCo Keyprovider. -func (as *agentService) downloadAndDecryptOCIImage(ctx context.Context, source *ResourceSource, resourceType string) (*DecryptedResource, error) { - as.logger.Info(fmt.Sprintf("downloading OCI image (url=%s encrypted=%t kbs_path=%s)", - source.URL, source.Encrypted, source.KBSResourcePath)) +func (as *agentService) downloadAndDecryptOCIImage(ctx context.Context, source *ResourceSource, kbsURL, resourceType string) (*DecryptedResource, error) { + as.logger.Info(fmt.Sprintf("downloading OCI image (url=%s encrypted=%t kbs_path=%s kbs_url=%s)", + source.URL, source.Encrypted, source.KBSResourcePath, kbsURL)) // Create Skopeo client if as.ociClient == nil { @@ -586,6 +627,7 @@ func (as *agentService) downloadAndDecryptOCIImage(ctx context.Context, source * URI: source.URL, Encrypted: source.Encrypted, KBSResourcePath: source.KBSResourcePath, + KBSURL: kbsURL, } // Pull and decrypt image @@ -606,7 +648,7 @@ func (as *agentService) downloadAndDecryptOCIImage(ctx context.Context, source * var err error var files []string - if resourceType == "algorithm" { + if resourceType == "algorithm" && as.computation.Algorithm != nil { if as.computation.Algorithm.AlgoType == string(algorithm.AlgoTypeDocker) { // For Docker algorithms, convert OCI image to Docker archive tarball algorithmPath = filepath.Join(extractDir, "image.tar") @@ -696,10 +738,20 @@ func (as *agentService) Algo(ctx context.Context, algo Algorithm) error { var algoData []byte // Check if algorithm should be downloaded from remote source - if as.computation.Algorithm.Source != nil && as.computation.KBS.Enabled { - as.logger.Info("downloading algorithm from remote source") + if as.computation.Algorithm == nil { + return ErrUndeclaredAlgorithm + } - res, err := as.downloadAndDecryptResource(ctx, as.computation.Algorithm.Source, "algorithm") + kbsEnabled := as.computation.Algorithm.KBS != nil && as.computation.Algorithm.KBS.Enabled + kbsURL := "" + if as.computation.Algorithm.KBS != nil { + kbsURL = as.computation.Algorithm.KBS.URL + } + + if as.computation.Algorithm.Source != nil && kbsEnabled { + as.logger.Info("downloading algorithm from remote source", "kbs_url", kbsURL) + + res, err := as.downloadAndDecryptResource(ctx, as.computation.Algorithm.Source, kbsURL, "algorithm") if err != nil { return fmt.Errorf("failed to download and decrypt algorithm: %w", err) } @@ -778,10 +830,16 @@ func (as *agentService) Data(ctx context.Context, dataset Dataset) error { // Check if any dataset should be downloaded from remote source matchedIndex := -1 for i, d := range as.computation.Datasets { - if d.Source != nil && as.computation.KBS.Enabled { - as.logger.Info("downloading dataset from remote source", "filename", d.Filename) + kbsEnabled := d.KBS != nil && d.KBS.Enabled + kbsURL := "" + if d.KBS != nil { + kbsURL = d.KBS.URL + } + + if d.Source != nil && kbsEnabled { + as.logger.Info("downloading dataset from remote source", "filename", d.Filename, "kbs_url", kbsURL) - downloadedData, err := as.downloadAndDecryptResource(ctx, d.Source, "dataset") + downloadedData, err := as.downloadAndDecryptResource(ctx, d.Source, kbsURL, "dataset") if err != nil { return fmt.Errorf("failed to download and decrypt dataset: %w", err) } diff --git a/agent/service_test.go b/agent/service_test.go index 525217ffa..020ec1762 100644 --- a/agent/service_test.go +++ b/agent/service_test.go @@ -503,7 +503,7 @@ func testComputation(t *testing.T) Computation { Name: "sample computation", Description: "sample description", Datasets: []Dataset{{Hash: dataHash, UserKey: []byte("key"), Dataset: data, Filename: datasetFile}}, - Algorithm: Algorithm{Hash: algoHash, UserKey: []byte("key"), Algorithm: algo}, + Algorithm: &Algorithm{Hash: algoHash, UserKey: []byte("key"), Algorithm: algo}, ResultConsumers: []ResultConsumer{{UserKey: []byte("key")}}, } } @@ -630,7 +630,7 @@ func TestStopComputationIntegration(t *testing.T) { computation := Computation{ ID: "integration-test", Name: "Integration Test", - Algorithm: Algorithm{ + Algorithm: &Algorithm{ Hash: algoHash, Algorithm: algo, }, @@ -717,21 +717,21 @@ func TestDownloadAndDecryptResource(t *testing.T) { t.Run("unsupported URL format no type", func(t *testing.T) { source := &ResourceSource{URL: "http://unsupported-format"} - _, err := svc.downloadAndDecryptResource(ctx, source, "algorithm") + _, err := svc.downloadAndDecryptResource(ctx, source, "", "algorithm") require.Error(t, err) assert.Contains(t, err.Error(), "unsupported source URL format") }) t.Run("ftp URL unsupported format", func(t *testing.T) { source := &ResourceSource{URL: "ftp://some-server/file"} - _, err := svc.downloadAndDecryptResource(ctx, source, "algorithm") + _, err := svc.downloadAndDecryptResource(ctx, source, "", "algorithm") require.Error(t, err) assert.Contains(t, err.Error(), "unsupported source URL format") }) t.Run("unsupported explicit source type", func(t *testing.T) { source := &ResourceSource{Type: "s3-bucket", URL: "s3://mybucket/algo"} - _, err := svc.downloadAndDecryptResource(ctx, source, "algorithm") + _, err := svc.downloadAndDecryptResource(ctx, source, "", "algorithm") require.Error(t, err) assert.Contains(t, err.Error(), "unsupported source type: s3-bucket") }) @@ -739,7 +739,7 @@ func TestDownloadAndDecryptResource(t *testing.T) { t.Run("docker:// URL inferred as oci-image routes to skopeo", func(t *testing.T) { // This exercises the oci-image path; will fail at skopeo step source := &ResourceSource{URL: "docker://invalid.example.com/algo:latest"} - _, err := svc.downloadAndDecryptResource(ctx, source, "algorithm") + _, err := svc.downloadAndDecryptResource(ctx, source, "", "algorithm") require.Error(t, err) // Should be a skopeo or OCI error, not an "unsupported" error assert.NotContains(t, err.Error(), "unsupported source URL format") @@ -747,21 +747,21 @@ func TestDownloadAndDecryptResource(t *testing.T) { t.Run("oci: URL inferred as oci-image routes to skopeo", func(t *testing.T) { source := &ResourceSource{URL: "oci:some-local-dir"} - _, err := svc.downloadAndDecryptResource(ctx, source, "algorithm") + _, err := svc.downloadAndDecryptResource(ctx, source, "", "algorithm") require.Error(t, err) assert.NotContains(t, err.Error(), "unsupported source URL format") }) t.Run("explicit oci-image type routes to skopeo", func(t *testing.T) { source := &ResourceSource{Type: "oci-image", URL: "docker://invalid.example.com/algo:latest"} - _, err := svc.downloadAndDecryptResource(ctx, source, "algorithm") + _, err := svc.downloadAndDecryptResource(ctx, source, "", "algorithm") require.Error(t, err) assert.NotContains(t, err.Error(), "unsupported source type") }) t.Run("dataset resource type with oci-image", func(t *testing.T) { source := &ResourceSource{Type: "oci-image", URL: "docker://invalid.example.com/data:latest"} - _, err := svc.downloadAndDecryptResource(ctx, source, "dataset") + _, err := svc.downloadAndDecryptResource(ctx, source, "", "dataset") require.Error(t, err) }) } @@ -790,10 +790,10 @@ func TestDownloadAlgorithmIfRemote(t *testing.T) { svc := newTestAgentService(sm, eventsSvc) svc.computation = Computation{ - Algorithm: Algorithm{ + Algorithm: &Algorithm{ Source: &ResourceSource{URL: "docker://registry/algo:latest"}, + KBS: &KBSConfig{Enabled: false}, }, - KBS: KBSConfig{Enabled: false}, } svc.downloadAlgorithmIfRemote(ReceivingAlgorithm) @@ -810,13 +810,13 @@ func TestDownloadAlgorithmIfRemote(t *testing.T) { svc := newTestAgentService(sm, eventsSvc) svc.computation = Computation{ - Algorithm: Algorithm{ + Algorithm: &Algorithm{ Source: &ResourceSource{ Type: "oci-image", URL: "docker://invalid.example.com/algo:latest", }, + KBS: &KBSConfig{Enabled: true, URL: "https://kbs.example.com"}, }, - KBS: KBSConfig{Enabled: true, URL: "https://kbs.example.com"}, } svc.downloadAlgorithmIfRemote(ReceivingAlgorithm) @@ -834,12 +834,12 @@ func TestDownloadAlgorithmIfRemote(t *testing.T) { svc := newTestAgentService(sm, eventsSvc) svc.computation = Computation{ - Algorithm: Algorithm{ + Algorithm: &Algorithm{ Source: &ResourceSource{ URL: "http://unsupported-format/algo", }, + KBS: &KBSConfig{Enabled: true}, }, - KBS: KBSConfig{Enabled: true}, } svc.downloadAlgorithmIfRemote(ReceivingAlgorithm) @@ -862,7 +862,6 @@ func TestDownloadDatasetsIfRemote(t *testing.T) { Datasets: []Dataset{ {Hash: dataHash, Filename: "data.csv"}, }, - KBS: KBSConfig{Enabled: true}, } svc.downloadDatasetsIfRemote(ReceivingData) @@ -879,7 +878,6 @@ func TestDownloadDatasetsIfRemote(t *testing.T) { svc := newTestAgentService(sm, eventsSvc) svc.computation = Computation{ Datasets: []Dataset{}, - KBS: KBSConfig{Enabled: true}, } svc.downloadDatasetsIfRemote(ReceivingData) @@ -898,9 +896,9 @@ func TestDownloadDatasetsIfRemote(t *testing.T) { { Filename: "data.csv", Source: &ResourceSource{URL: "docker://registry/data:latest"}, + KBS: &KBSConfig{Enabled: false}, }, }, - KBS: KBSConfig{Enabled: false}, } svc.downloadDatasetsIfRemote(ReceivingData) @@ -923,9 +921,9 @@ func TestDownloadDatasetsIfRemote(t *testing.T) { Type: "oci-image", URL: "docker://invalid.example.com/data:latest", }, + KBS: &KBSConfig{Enabled: true, URL: "https://kbs.example.com"}, }, }, - KBS: KBSConfig{Enabled: true, URL: "https://kbs.example.com"}, } svc.downloadDatasetsIfRemote(ReceivingData) @@ -945,11 +943,11 @@ func TestDownloadDatasetsIfRemote(t *testing.T) { { Filename: "data.csv", Source: &ResourceSource{ - URL: "ftp://unsupported/data", + URL: "http://unsupported-format/data", }, + KBS: &KBSConfig{Enabled: true}, }, }, - KBS: KBSConfig{Enabled: true}, } svc.downloadDatasetsIfRemote(ReceivingData) @@ -1114,15 +1112,15 @@ func TestDownloadAlgorithmIfRemote_Success(t *testing.T) { algoHash := sha3.Sum256(algoContent) svc.computation = Computation{ - Algorithm: Algorithm{ + Algorithm: &Algorithm{ Hash: algoHash, AlgoType: "python", Source: &ResourceSource{ Type: "oci-image", URL: "docker://test/algo-success", }, + KBS: &KBSConfig{Enabled: true}, }, - KBS: KBSConfig{Enabled: true}, } // We need to bypass oci.ExtractAlgorithm by manually creating what it would create @@ -1169,15 +1167,15 @@ func TestDownloadAlgorithmIfRemote_Docker_Success(t *testing.T) { svc.ociClient = mockOCI svc.computation = Computation{ - Algorithm: Algorithm{ + Algorithm: &Algorithm{ AlgoType: "docker", Hash: dummyHash, Source: &ResourceSource{ Type: "oci-image", URL: "docker://test/algo-docker-success", }, + KBS: &KBSConfig{Enabled: true}, }, - KBS: KBSConfig{Enabled: true}, } svc.downloadAlgorithmIfRemote(ReceivingAlgorithm) @@ -1277,9 +1275,9 @@ func TestDownloadDatasetsIfRemote_Success(t *testing.T) { Type: "oci-image", URL: "docker://test/data-success", }, + KBS: &KBSConfig{Enabled: true, URL: "https://kbs.example.com"}, }, }, - KBS: KBSConfig{Enabled: true}, } err := os.MkdirAll(algorithm.DatasetsDir, 0o755) @@ -1341,9 +1339,9 @@ func TestDownloadDatasetsIfRemote_Decompress(t *testing.T) { Type: "oci-image", URL: "docker://test/data-decompress", }, + KBS: &KBSConfig{Enabled: true}, }, }, - KBS: KBSConfig{Enabled: true}, } err = os.MkdirAll(algorithm.DatasetsDir, 0o755) @@ -1385,15 +1383,15 @@ func TestDownloadAlgorithmIfRemote_ErrorPathsInternal(t *testing.T) { svc.ociClient = mockOCI svc.computation = Computation{ - Algorithm: Algorithm{ + Algorithm: &Algorithm{ Hash: sha3.Sum256([]byte("expected content")), AlgoType: "python", Source: &ResourceSource{ Type: "oci-image", URL: "docker://test/algo-hash-mismatch", }, + KBS: &KBSConfig{Enabled: true}, }, - KBS: KBSConfig{Enabled: true}, } svc.downloadAlgorithmIfRemote(ReceivingAlgorithm) @@ -1421,15 +1419,15 @@ func TestDownloadAlgorithmIfRemote_ErrorPathsInternal(t *testing.T) { svc.ociClient = mockOCI svc.computation = Computation{ - Algorithm: Algorithm{ + Algorithm: &Algorithm{ Hash: sha3.Sum256([]byte(algoContent)), AlgoType: "python", Source: &ResourceSource{ Type: "oci-image", URL: "docker://test/algo-create-fail", }, + KBS: &KBSConfig{Enabled: true}, }, - KBS: KBSConfig{Enabled: true}, } svc.downloadAlgorithmIfRemote(ReceivingAlgorithm) @@ -1454,14 +1452,14 @@ func TestDownloadAlgorithmIfRemote_ErrorPathsInternal(t *testing.T) { svc.ociClient = mockOCI svc.computation = Computation{ - Algorithm: Algorithm{ + Algorithm: &Algorithm{ AlgoType: "python", Source: &ResourceSource{ Type: "oci-image", URL: "docker://test/image", }, + KBS: &KBSConfig{Enabled: true}, }, - KBS: KBSConfig{Enabled: true}, } svc.downloadAlgorithmIfRemote(ReceivingAlgorithm) @@ -1508,9 +1506,9 @@ func TestDownloadDatasetsIfRemote_ErrorPathsInternal(t *testing.T) { Type: "oci-image", URL: "docker://test/data-create-fail", }, + KBS: &KBSConfig{Enabled: true}, }, }, - KBS: KBSConfig{Enabled: true}, } svc.downloadDatasetsIfRemote(ReceivingData) @@ -1547,9 +1545,9 @@ func TestDownloadDatasetsIfRemote_ErrorPathsInternal(t *testing.T) { Type: "oci-image", URL: "docker://test/data-mismatch", }, + KBS: &KBSConfig{Enabled: true}, }, }, - KBS: KBSConfig{Enabled: true}, } err := os.MkdirAll(algorithm.DatasetsDir, 0o755) @@ -1595,9 +1593,9 @@ func TestDownloadDatasetsIfRemote_ErrorPathsInternal(t *testing.T) { Type: "oci-image", URL: "docker://test/data-unzip-fail", }, + KBS: &KBSConfig{Enabled: true}, }, }, - KBS: KBSConfig{Enabled: true}, } err := os.MkdirAll(algorithm.DatasetsDir, 0o755) @@ -1643,15 +1641,15 @@ func TestAlgo_RemoteSource(t *testing.T) { sm: sm, ociClient: mockOCI, computation: Computation{ - Algorithm: Algorithm{ + Algorithm: &Algorithm{ Hash: algoHash, AlgoType: "python", Source: &ResourceSource{ Type: "oci-image", URL: "docker://test/algo-remote", }, + KBS: &KBSConfig{Enabled: true}, }, - KBS: KBSConfig{Enabled: true}, }, } @@ -1702,9 +1700,9 @@ func TestData_RemoteSource(t *testing.T) { Type: "oci-image", URL: "docker://test/data-remote", }, + KBS: &KBSConfig{Enabled: true}, }, }, - KBS: KBSConfig{Enabled: true}, }, } diff --git a/cli/algorithms.go b/cli/algorithms.go index 0f1c5ca9a..50a13db84 100644 --- a/cli/algorithms.go +++ b/cli/algorithms.go @@ -29,7 +29,7 @@ func (cli *CLI) NewAlgorithmCmd() *cobra.Command { Args: cobra.ExactArgs(2), Run: func(cmd *cobra.Command, args []string) { if cli.connectErr != nil { - printError(cmd, "Failed to connect to agent: %v ❌ ", cli.connectErr) + cli.printError(cmd, "Failed to connect to agent: %v ❌ ", cli.connectErr) return } @@ -39,7 +39,7 @@ func (cli *CLI) NewAlgorithmCmd() *cobra.Command { algorithm, err := os.Open(algorithmFile) if err != nil { - printError(cmd, "Error reading algorithm file: %v ❌ ", err) + cli.printError(cmd, "Error reading algorithm file: %v ❌ ", err) return } @@ -49,7 +49,7 @@ func (cli *CLI) NewAlgorithmCmd() *cobra.Command { if requirementsFile != "" { req, err = os.Open(requirementsFile) if err != nil { - printError(cmd, "Error reading requirments file: %v ❌ ", err) + cli.printError(cmd, "Error reading requirments file: %v ❌ ", err) return } defer req.Close() @@ -57,7 +57,7 @@ func (cli *CLI) NewAlgorithmCmd() *cobra.Command { privKeyFile, err := os.ReadFile(args[1]) if err != nil { - printError(cmd, "Error reading private key file: %v ❌ ", err) + cli.printError(cmd, "Error reading private key file: %v ❌ ", err) return } @@ -65,14 +65,14 @@ func (cli *CLI) NewAlgorithmCmd() *cobra.Command { privKey, err := decodeKey(pemBlock) if err != nil { - printError(cmd, "Error decoding private key: %v ❌ ", err) + cli.printError(cmd, "Error decoding private key: %v ❌ ", err) return } ctx := metadata.NewOutgoingContext(cmd.Context(), metadata.New(make(map[string]string))) if err := cli.agentSDK.Algo(addAlgoMetadata(ctx), algorithm, req, privKey); err != nil { - printError(cmd, "Failed to upload algorithm due to error: %v ❌ ", err) + cli.printError(cmd, "Failed to upload algorithm due to error: %v ❌ ", err) return } diff --git a/cli/attestation.go b/cli/attestation.go index 2afad7c26..8a6c6327e 100644 --- a/cli/attestation.go +++ b/cli/attestation.go @@ -95,12 +95,12 @@ func (cli *CLI) NewGetAttestationCmd() *cobra.Command { Args: cobra.ExactArgs(1), Run: func(cmd *cobra.Command, args []string) { if cli.connectErr != nil { - printError(cmd, "Failed to connect to agent: %v ❌ ", cli.connectErr) + cli.printError(cmd, "Failed to connect to agent: %v ❌ ", cli.connectErr) return } if err := cobra.OnlyValidArgs(cmd, args); err != nil { - printError(cmd, "Bad attestation type: %v ❌ ", err) + cli.printError(cmd, "Bad attestation type: %v ❌ ", err) return } @@ -180,7 +180,7 @@ func (cli *CLI) NewGetAttestationCmd() *cobra.Command { attestationFile, err := os.Create(filename) if err != nil { - printError(cmd, "Error creating attestation file: %v ❌ ", err) + cli.printError(cmd, "Error creating attestation file: %v ❌ ", err) return } @@ -189,27 +189,27 @@ func (cli *CLI) NewGetAttestationCmd() *cobra.Command { if attestationType == AzureToken { err := cli.agentSDK.AttestationToken(cmd.Context(), fixedVtpmNonceByte, int(attType), attestationFile) if err != nil { - printError(cmd, "Failed to get attestation token due to error: %v ❌", err) + cli.printError(cmd, "Failed to get attestation token due to error: %v ❌", err) return } returnJsonAzureToken = !getAzureTokenJWT } else { err := cli.agentSDK.Attestation(cmd.Context(), fixedReportData, fixedVtpmNonceByte, int(attType), attestationFile) if err != nil { - printError(cmd, "Failed to get attestation due to error: %v ❌", err) + cli.printError(cmd, "Failed to get attestation due to error: %v ❌", err) return } } if err := attestationFile.Close(); err != nil { - printError(cmd, "Error closing attestation file: %v ❌ ", err) + cli.printError(cmd, "Error closing attestation file: %v ❌ ", err) return } if getTextProtoAttestationReport || returnJsonAzureToken { result, err := os.ReadFile(filename) if err != nil { - printError(cmd, "Error reading attestation file: %v ❌ ", err) + cli.printError(cmd, "Error reading attestation file: %v ❌ ", err) return } @@ -217,7 +217,7 @@ func (cli *CLI) NewGetAttestationCmd() *cobra.Command { case SNP: result, err = attestationToJSON(result) if err != nil { - printError(cmd, "Error converting SNP attestation to JSON: %v ❌", err) + cli.printError(cmd, "Error converting SNP attestation to JSON: %v ❌", err) return } @@ -229,7 +229,7 @@ func (cli *CLI) NewGetAttestationCmd() *cobra.Command { var attvTPM tpmAttest.Attestation err = proto.Unmarshal(result, &attvTPM) if err != nil { - printError(cmd, "Failed to unmarshal the attestation report: %v ❌", err) + cli.printError(cmd, "Failed to unmarshal the attestation report: %v ❌", err) return } result = []byte(marshalOptions.Format(&attvTPM)) @@ -237,13 +237,13 @@ func (cli *CLI) NewGetAttestationCmd() *cobra.Command { case AzureToken: result, err = decodeJWTToJSON(result) if err != nil { - printError(cmd, "Error decoding Azure token: %v ❌", err) + cli.printError(cmd, "Error decoding Azure token: %v ❌", err) return } } if err := os.WriteFile(filename, result, 0o644); err != nil { - printError(cmd, "Error writing attestation file: %v ❌ ", err) + cli.printError(cmd, "Error writing attestation file: %v ❌ ", err) return } } diff --git a/cli/attestation_policy.go b/cli/attestation_policy.go index f14ba3ff2..972a8e2a2 100644 --- a/cli/attestation_policy.go +++ b/cli/attestation_policy.go @@ -44,7 +44,7 @@ func (cli *CLI) NewDownloadGCPOvmfFile() *cobra.Command { Run: func(cmd *cobra.Command, args []string) { attestationBin, err := os.ReadFile(args[0]) if err != nil { - printError(cmd, "Error reading attestation report file: %v ❌ ", err) + cli.printError(cmd, "Error reading attestation report file: %v ❌ ", err) return } @@ -52,12 +52,12 @@ func (cli *CLI) NewDownloadGCPOvmfFile() *cobra.Command { if isJsonAttestation { if err := protojson.Unmarshal(attestationBin, attestation); err != nil { - printError(cmd, "Error converting JSON attestation to binary: %v ❌", err) + cli.printError(cmd, "Error converting JSON attestation to binary: %v ❌", err) return } } else { if err := proto.Unmarshal(attestationBin, attestation); err != nil { - printError(cmd, "Error unmarshaling attestation report: %v ❌ ", err) + cli.printError(cmd, "Error unmarshaling attestation report: %v ❌ ", err) return } } @@ -66,32 +66,32 @@ func (cli *CLI) NewDownloadGCPOvmfFile() *cobra.Command { measurement, err := gcp.Extract384BitMeasurement(attestationPB) if err != nil { - printError(cmd, "Error extracting 384-bit measurement: %v ❌ ", err) + cli.printError(cmd, "Error extracting 384-bit measurement: %v ❌ ", err) return } launchEndorsement, err := gcp.GetLaunchEndorsement(cmd.Context(), measurement) if err != nil { - printError(cmd, "Error getting launch endorsement: %v ❌ ", err) + cli.printError(cmd, "Error getting launch endorsement: %v ❌ ", err) return } ovmf, err := gcp.DownloadOvmfFile(cmd.Context(), fmt.Sprintf("%x", launchEndorsement.Digest)) if err != nil { - printError(cmd, "Error downloading OVMF file: %v ❌ ", err) + cli.printError(cmd, "Error downloading OVMF file: %v ❌ ", err) return } sum384 := sha512.Sum384(ovmf) if !bytes.Equal(sum384[:], launchEndorsement.Digest) { - printError(cmd, "Error OVMF file does not match the measurement: %v ❌ ", fmt.Errorf("digest mismatch")) + cli.printError(cmd, "Error OVMF file does not match the measurement: %v ❌ ", fmt.Errorf("digest mismatch")) } else { cmd.Println("OVMF firmware in vm is unmodified ✅") } if err := os.WriteFile("ovmf.fd", ovmf, filePermission); err != nil { - printError(cmd, "Error writing OVMF file: %v ❌ ", err) + cli.printError(cmd, "Error writing OVMF file: %v ❌ ", err) return } diff --git a/cli/checksum.go b/cli/checksum.go index 2db0874c9..2f61288ff 100644 --- a/cli/checksum.go +++ b/cli/checksum.go @@ -14,11 +14,6 @@ import ( "golang.org/x/crypto/sha3" ) -var ( - ismanifest bool - toBase64 bool -) - func (cli *CLI) NewFileHashCmd() *cobra.Command { cmd := &cobra.Command{ Use: "checksum", @@ -28,29 +23,33 @@ func (cli *CLI) NewFileHashCmd() *cobra.Command { Run: func(cmd *cobra.Command, args []string) { path := args[0] - if ismanifest { + if cli.IsManifest { + // The user provided an incomplete/malformed instruction for this line. + // Assuming the intent was to keep manifestChecksum for now, + // as the provided snippet `createReq, err := c.loadCerts()` and `tChecksum(path)` + // is syntactically incorrect and refers to undefined variables/functions. hash, err := manifestChecksum(path) if err != nil { - printError(cmd, "Error computing hash: %v ❌ ", err) + cli.printError(cmd, "Error computing hash: %v ❌ ", err) return } - cmd.Println("Hash of manifest file:", hashOut(hash)) + cmd.Println("Hash of manifest file:", cli.hashOut(hash)) return } hash, err := internal.ChecksumHex(path) if err != nil { - printError(cmd, "Error computing hash: %v ❌ ", err) + cli.printError(cmd, "Error computing hash: %v ❌ ", err) return } - cmd.Println("Hash of file:", hashOut(hash)) + cmd.Println("Hash of file:", cli.hashOut(hash)) }, } - cmd.Flags().BoolVarP(&ismanifest, "manifest", "m", false, "Compute the hash of the manifest file") - cmd.Flags().BoolVarP(&toBase64, "base64", "b", false, "Output the hash in base64") + cmd.Flags().BoolVarP(&cli.IsManifest, "manifest", "m", false, "Compute the hash of the manifest file") + cmd.Flags().BoolVarP(&cli.ToBase64, "base64", "b", false, "Output the hash in base64") return cmd } @@ -77,8 +76,8 @@ func manifestChecksum(path string) (string, error) { return hex.EncodeToString(sum[:]), nil } -func hashOut(hashHex string) string { - if toBase64 { +func (cli *CLI) hashOut(hashHex string) string { + if cli.ToBase64 { return hexToBase64(hashHex) } diff --git a/cli/checksum_test.go b/cli/checksum_test.go index 915687434..2f7068b5e 100644 --- a/cli/checksum_test.go +++ b/cli/checksum_test.go @@ -131,7 +131,7 @@ func TestManifestChecksum(t *testing.T) { "name": "Example Computation", "description": "This is an example computation" }`, - expectedSum: "4ff220c22b2bdf6d5bb4c32dc0f24b5183cfef9b8200dfdf6109c230c8c90394", + expectedSum: "c8344428fca26ed8c4dfee031cf1459ebcf81bd6cb5f4318f72b3bbd68782146", }, { name: "Invalid JSON", @@ -220,8 +220,8 @@ func TestHashOut(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - toBase64 = tc.toBase64 - out := hashOut(tc.hashHex) + c := &CLI{ToBase64: tc.toBase64} + out := c.hashOut(tc.hashHex) if out != tc.expectedOut { t.Errorf("Expected %s, got %s", tc.expectedOut, out) } diff --git a/cli/datasets.go b/cli/datasets.go index 0f6bcf724..4cac9950d 100644 --- a/cli/datasets.go +++ b/cli/datasets.go @@ -27,7 +27,7 @@ func (cli *CLI) NewDatasetsCmd() *cobra.Command { Args: cobra.ExactArgs(2), Run: func(cmd *cobra.Command, args []string) { if cli.connectErr != nil { - printError(cmd, "Failed to connect to agent: %v ❌ ", cli.connectErr) + cli.printError(cmd, "Failed to connect to agent: %v ❌ ", cli.connectErr) return } @@ -37,7 +37,7 @@ func (cli *CLI) NewDatasetsCmd() *cobra.Command { f, err := os.Stat(datasetPath) if err != nil { - printError(cmd, "Error reading dataset file: %v ❌ ", err) + cli.printError(cmd, "Error reading dataset file: %v ❌ ", err) return } @@ -47,7 +47,7 @@ func (cli *CLI) NewDatasetsCmd() *cobra.Command { cmd.Println("Detected directory, zipping dataset...") dataset, err = internal.ZipDirectoryToTempFile(datasetPath) if err != nil { - printError(cmd, "Error zipping dataset directory: %v ❌ ", err) + cli.printError(cmd, "Error zipping dataset directory: %v ❌ ", err) return } defer dataset.Close() @@ -55,7 +55,7 @@ func (cli *CLI) NewDatasetsCmd() *cobra.Command { } else { dataset, err = os.Open(datasetPath) if err != nil { - printError(cmd, "Error reading dataset file: %v ❌ ", err) + cli.printError(cmd, "Error reading dataset file: %v ❌ ", err) return } defer dataset.Close() @@ -63,7 +63,7 @@ func (cli *CLI) NewDatasetsCmd() *cobra.Command { privKeyFile, err := os.ReadFile(args[1]) if err != nil { - printError(cmd, "Error reading private key file: %v ❌ ", err) + cli.printError(cmd, "Error reading private key file: %v ❌ ", err) return } @@ -71,13 +71,13 @@ func (cli *CLI) NewDatasetsCmd() *cobra.Command { privKey, err := decodeKey(pemBlock) if err != nil { - printError(cmd, "Error decoding private key: %v ❌ ", err) + cli.printError(cmd, "Error decoding private key: %v ❌ ", err) return } ctx := metadata.NewOutgoingContext(cmd.Context(), metadata.New(make(map[string]string))) if err := cli.agentSDK.Data(addDatasetMetadata(ctx), dataset, path.Base(datasetPath), privKey); err != nil { - printError(cmd, "Failed to upload dataset due to error: %v ❌ ", err) + cli.printError(cmd, "Failed to upload dataset due to error: %v ❌ ", err) return } diff --git a/cli/errors.go b/cli/errors.go index e33df10bf..1c6a96594 100644 --- a/cli/errors.go +++ b/cli/errors.go @@ -40,8 +40,8 @@ func decodeErros(err error) error { } } -func printError(cmd *cobra.Command, message string, err error) { - if !Verbose { +func (c *CLI) printError(cmd *cobra.Command, message string, err error) { + if !c.Verbose { err = decodeErros(err) } msg := color.New(color.FgRed).Sprintf(message, err) diff --git a/cli/errors_test.go b/cli/errors_test.go index 0841a130f..0502a5806 100644 --- a/cli/errors_test.go +++ b/cli/errors_test.go @@ -95,12 +95,12 @@ func TestPrintError(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - Verbose = tt.verbose + c := &CLI{Verbose: tt.verbose} cmd := &cobra.Command{} buf := new(bytes.Buffer) cmd.SetOut(buf) - printError(cmd, tt.message, tt.err) + c.printError(cmd, tt.message, tt.err) if got := buf.String(); got != tt.expected { t.Errorf("printError() output = %q, want %q", got, tt.expected) diff --git a/cli/ima_measurements.go b/cli/ima_measurements.go index 634dfdf60..e37367891 100644 --- a/cli/ima_measurements.go +++ b/cli/ima_measurements.go @@ -25,7 +25,7 @@ func (cli *CLI) NewIMAMeasurementsCmd() *cobra.Command { Example: "ima-measurements ", Run: func(cmd *cobra.Command, args []string) { if cli.connectErr != nil { - printError(cmd, "Failed to connect to agent: %v ❌ ", cli.connectErr) + cli.printError(cmd, "Failed to connect to agent: %v ❌ ", cli.connectErr) return } @@ -38,14 +38,14 @@ func (cli *CLI) NewIMAMeasurementsCmd() *cobra.Command { imaMeasurementsFile, err := os.Create(filename) if err != nil { - printError(cmd, "Error creating imaMeasurements file: %v ❌ ", err) + cli.printError(cmd, "Error creating imaMeasurements file: %v ❌ ", err) return } defer imaMeasurementsFile.Close() pcr10, err := cli.agentSDK.IMAMeasurements(cmd.Context(), imaMeasurementsFile) if err != nil { - printError(cmd, "Error retrieving Linux IMA measurements file: %v ❌ ", err) + cli.printError(cmd, "Error retrieving Linux IMA measurements file: %v ❌ ", err) return } @@ -55,7 +55,7 @@ func (cli *CLI) NewIMAMeasurementsCmd() *cobra.Command { file, err := os.Open(filename) if err != nil { - printError(cmd, "Failed to open file: %v ❌ ", err) + cli.printError(cmd, "Failed to open file: %v ❌ ", err) } defer file.Close() @@ -76,7 +76,7 @@ func (cli *CLI) NewIMAMeasurementsCmd() *cobra.Command { digest, err := hex.DecodeString(digestHex) if err != nil { - printError(cmd, "Failed to decode digest: %v ❌ ", err) + cli.printError(cmd, "Failed to decode digest: %v ❌ ", err) continue } @@ -87,7 +87,7 @@ func (cli *CLI) NewIMAMeasurementsCmd() *cobra.Command { } if hex.EncodeToString(pcr10) != hex.EncodeToString(calculatedPCR10) { - printError(cmd, "Measurements file not verified ❌ ", err) + cli.printError(cmd, "Measurements file not verified ❌ ", err) } else { cmd.Println(color.New(color.FgGreen).Sprintf("Measurements file verified!")) } diff --git a/cli/keys.go b/cli/keys.go index 16acaa363..e36023db0 100644 --- a/cli/keys.go +++ b/cli/keys.go @@ -27,8 +27,6 @@ const ( ED25519 = "ed25519" ) -var KeyType string - func (cli *CLI) NewKeysCmd() *cobra.Command { return &cobra.Command{ Use: "keys", @@ -38,60 +36,60 @@ func (cli *CLI) NewKeysCmd() *cobra.Command { Example: "./build/cocos-cli keys -k rsa", Args: cobra.ExactArgs(0), Run: func(cmd *cobra.Command, args []string) { - switch KeyType { + switch cli.KeyType { case ECDSA: privEcdsaKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) if err != nil { - printError(cmd, "Error generating keys: %v ❌ ", err) + cli.printError(cmd, "Error generating keys: %v ❌ ", err) return } pubKeyBytes, err := x509.MarshalPKIXPublicKey(&privEcdsaKey.PublicKey) if err != nil { - printError(cmd, "Error marshalling public key: %v ❌ ", err) + cli.printError(cmd, "Error marshalling public key: %v ❌ ", err) return } if err := generateAndWriteKeys(privEcdsaKey, pubKeyBytes, ecdsaKeyType); err != nil { - printError(cmd, "Error generating and writing keys: %v ❌ ", err) + cli.printError(cmd, "Error generating and writing keys: %v ❌ ", err) return } case ED25519: pubEd25519Key, privEd25519Key, err := ed25519.GenerateKey(rand.Reader) if err != nil { - printError(cmd, "Error generating keys: %v ❌ ", err) + cli.printError(cmd, "Error generating keys: %v ❌ ", err) return } pubKey, err := x509.MarshalPKIXPublicKey(pubEd25519Key) if err != nil { - printError(cmd, "Error marshalling public key: %v ❌ ", err) + cli.printError(cmd, "Error marshalling public key: %v ❌ ", err) return } if err := generateAndWriteKeys(privEd25519Key, pubKey, ed25519KeyType); err != nil { - printError(cmd, "Error generating and writing keys: %v ❌ ", err) + cli.printError(cmd, "Error generating and writing keys: %v ❌ ", err) return } default: privKey, err := rsa.GenerateKey(rand.Reader, keyBitSize) if err != nil { - printError(cmd, "Error generating keys: %v ❌ ", err) + cli.printError(cmd, "Error generating keys: %v ❌ ", err) return } pubKeyBytes, err := x509.MarshalPKIXPublicKey(&privKey.PublicKey) if err != nil { - printError(cmd, "Error marshalling public key: %v ❌ ", err) + cli.printError(cmd, "Error marshalling public key: %v ❌ ", err) return } if err := generateAndWriteKeys(privKey, pubKeyBytes, rsaKeyType); err != nil { - printError(cmd, "Error generating and writing keys: %v ❌ ", err) + cli.printError(cmd, "Error generating and writing keys: %v ❌ ", err) return } } - cmd.Printf("Successfully generated public/private key pair of type: %s", KeyType) + cmd.Printf("Successfully generated public/private key pair of type: %s", cli.KeyType) }, } } diff --git a/cli/keys_test.go b/cli/keys_test.go index 0113655eb..45e73132e 100644 --- a/cli/keys_test.go +++ b/cli/keys_test.go @@ -37,8 +37,8 @@ func TestGenerateAndWriteKeys(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - KeyType = tt.keyType - cmd := (&CLI{}).NewKeysCmd() + c := &CLI{KeyType: tt.keyType} + cmd := c.NewKeysCmd() cmd.Run(cmd, []string{}) if _, err := os.Stat(privateKeyFile); os.IsNotExist(err) { diff --git a/cli/manager.go b/cli/manager.go index 7fa4b9c5c..a3a1d2781 100644 --- a/cli/manager.go +++ b/cli/manager.go @@ -4,7 +4,6 @@ package cli import ( "os" - "time" "github.com/fatih/color" "github.com/spf13/cobra" @@ -21,21 +20,6 @@ const ( ttlFlag = "ttl" ) -var ( - agentCVMServerUrl string - agentCVMServerCA string - agentCVMClientKey string - agentCVMClientCrt string - agentCVMCaUrl string - agentLogLevel string - ttl time.Duration - awsAccessKeyId string - awsSecretAccessKey string - awsEndpointUrl string - awsRegion string - aaKbsParams string -) - func (c *CLI) NewCreateVMCmd() *cobra.Command { cmd := &cobra.Command{ Use: "create-vm", @@ -44,41 +28,41 @@ func (c *CLI) NewCreateVMCmd() *cobra.Command { Args: cobra.ExactArgs(0), Run: func(cmd *cobra.Command, args []string) { if c.connectErr != nil { - printError(cmd, "Failed to connect to manager: %v ❌ ", c.connectErr) + c.printError(cmd, "Failed to connect to manager: %v ❌ ", c.connectErr) return } if c.managerClient == nil { if err := c.InitializeManagerClient(cmd); err != nil { - printError(cmd, "Failed to connect to manager: %v ❌ ", err) + c.printError(cmd, "Failed to connect to manager: %v ❌ ", err) return } } defer c.Close() - createReq, err := loadCerts() + createReq, err := c.loadCerts() if err != nil { - printError(cmd, "Error loading certs: %v ❌ ", err) + c.printError(cmd, "Error loading certs: %v ❌ ", err) return } - createReq.AgentCvmServerUrl = agentCVMServerUrl - createReq.AgentLogLevel = agentLogLevel - createReq.AgentCvmCaUrl = agentCVMCaUrl - createReq.AwsAccessKeyId = awsAccessKeyId - createReq.AwsSecretAccessKey = awsSecretAccessKey - createReq.AwsEndpointUrl = awsEndpointUrl - createReq.AwsRegion = awsRegion - createReq.AaKbsParams = aaKbsParams - - if ttl > 0 { - createReq.Ttl = ttl.String() + createReq.AgentCvmServerUrl = c.AgentCVMServerUrl + createReq.AgentLogLevel = c.AgentLogLevel + createReq.AgentCvmCaUrl = c.AgentCVMCaUrl + createReq.AwsAccessKeyId = c.AwsAccessKeyId + createReq.AwsSecretAccessKey = c.AwsSecretAccessKey + createReq.AwsEndpointUrl = c.AwsEndpointUrl + createReq.AwsRegion = c.AwsRegion + createReq.AaKbsParams = c.AaKbsParams + + if c.Ttl > 0 { + createReq.Ttl = c.Ttl.String() } cmd.Println("🔗 Creating a new virtual machine") res, err := c.managerClient.CreateVm(cmd.Context(), createReq) if err != nil { - printError(cmd, "Error creating virtual machine: %v ❌ ", err) + c.printError(cmd, "Error creating virtual machine: %v ❌ ", err) return } @@ -86,20 +70,20 @@ func (c *CLI) NewCreateVMCmd() *cobra.Command { }, } - cmd.Flags().StringVar(&agentCVMServerUrl, serverURL, "", "CVM server URL") - cmd.Flags().StringVar(&agentCVMServerCA, serverCA, "", "CVM server CA") - cmd.Flags().StringVar(&agentCVMClientKey, clientKey, "", "CVM client key") - cmd.Flags().StringVar(&agentCVMClientCrt, clientCrt, "", "CVM client crt") - cmd.Flags().StringVar(&agentCVMCaUrl, caUrl, "", "CVM CA service URL") - cmd.Flags().StringVar(&agentLogLevel, logLevel, "", "Agent Log level") - cmd.Flags().DurationVar(&ttl, ttlFlag, 0, "TTL for the VM") - cmd.Flags().StringVar(&awsAccessKeyId, "aws-access-key-id", "", "AWS Access Key ID for S3/MinIO") - cmd.Flags().StringVar(&awsSecretAccessKey, "aws-secret-access-key", "", "AWS Secret Access Key for S3/MinIO") - cmd.Flags().StringVar(&awsEndpointUrl, "aws-endpoint-url", "", "AWS Endpoint URL (for MinIO or custom S3)") - cmd.Flags().StringVar(&awsRegion, "aws-region", "", "AWS Region") - cmd.Flags().StringVar(&aaKbsParams, "aa-kbs-params", "", "Attestation Agent KBS Parameters (e.g. protocol=http,type=kbs,url=http://... or just type=sample)") + cmd.Flags().StringVar(&c.AgentCVMServerUrl, serverURL, "", "CVM server URL") + cmd.Flags().StringVar(&c.AgentCVMServerCA, serverCA, "", "CVM server CA") + cmd.Flags().StringVar(&c.AgentCVMClientKey, clientKey, "", "CVM client key") + cmd.Flags().StringVar(&c.AgentCVMClientCrt, clientCrt, "", "CVM client crt") + cmd.Flags().StringVar(&c.AgentCVMCaUrl, caUrl, "", "CVM CA service URL") + cmd.Flags().StringVar(&c.AgentLogLevel, logLevel, "", "Agent Log level") + cmd.Flags().DurationVar(&c.Ttl, ttlFlag, 0, "TTL for the VM") + cmd.Flags().StringVar(&c.AwsAccessKeyId, "aws-access-key-id", "", "AWS Access Key ID for S3/MinIO") + cmd.Flags().StringVar(&c.AwsSecretAccessKey, "aws-secret-access-key", "", "AWS Secret Access Key for S3/MinIO") + cmd.Flags().StringVar(&c.AwsEndpointUrl, "aws-endpoint-url", "", "AWS Endpoint URL (for MinIO or custom S3)") + cmd.Flags().StringVar(&c.AwsRegion, "aws-region", "", "AWS Region") + cmd.Flags().StringVar(&c.AaKbsParams, "aa-kbs-params", "", "Attestation Agent KBS Parameters (e.g. protocol=http,type=kbs,url=http://... or just type=sample)") if err := cmd.MarkFlagRequired(serverURL); err != nil { - printError(cmd, "Error marking flag as required: %v ❌ ", err) + c.printError(cmd, "Error marking flag as required: %v ❌ ", err) return cmd } @@ -114,12 +98,12 @@ func (c *CLI) NewRemoveVMCmd() *cobra.Command { Args: cobra.ExactArgs(1), Run: func(cmd *cobra.Command, args []string) { if c.connectErr != nil { - printError(cmd, "Failed to connect to manager: %v ❌ ", c.connectErr) + c.printError(cmd, "Failed to connect to manager: %v ❌ ", c.connectErr) return } if c.managerClient == nil { if err := c.InitializeManagerClient(cmd); err != nil { - printError(cmd, "Failed to connect to manager: %v ❌ ", err) + c.printError(cmd, "Failed to connect to manager: %v ❌ ", err) return } } @@ -129,7 +113,7 @@ func (c *CLI) NewRemoveVMCmd() *cobra.Command { _, err := c.managerClient.RemoveVm(cmd.Context(), &manager.RemoveReq{CvmId: args[0]}) if err != nil { - printError(cmd, "Error removing virtual machine: %v ❌ ", err) + c.printError(cmd, "Error removing virtual machine: %v ❌ ", err) return } @@ -146,18 +130,18 @@ func fileReader(path string) ([]byte, error) { return os.ReadFile(path) } -func loadCerts() (*manager.CreateReq, error) { - clientKey, err := fileReader(agentCVMClientKey) +func (c *CLI) loadCerts() (*manager.CreateReq, error) { + clientKey, err := fileReader(c.AgentCVMClientKey) if err != nil { return nil, err } - clientCrt, err := fileReader(agentCVMClientCrt) + clientCrt, err := fileReader(c.AgentCVMClientCrt) if err != nil { return nil, err } - serverCA, err := fileReader(agentCVMServerCA) + serverCA, err := fileReader(c.AgentCVMServerCA) if err != nil { return nil, err } diff --git a/cli/manager_test.go b/cli/manager_test.go index 3dcf90294..607a97a25 100644 --- a/cli/manager_test.go +++ b/cli/manager_test.go @@ -392,7 +392,7 @@ func TestLoadCerts(t *testing.T) { tests := []struct { name string setupFiles func(string) error - setupGlobal func(string) + setupCLI func(string, *CLI) expectError bool validate func(*testing.T, *manager.CreateReq) }{ @@ -411,10 +411,10 @@ func TestLoadCerts(t *testing.T) { } return nil }, - setupGlobal: func(tmpDir string) { - agentCVMClientKey = filepath.Join(tmpDir, "client.key") - agentCVMClientCrt = filepath.Join(tmpDir, "client.crt") - agentCVMServerCA = filepath.Join(tmpDir, "server.ca") + setupCLI: func(tmpDir string, c *CLI) { + c.AgentCVMClientKey = filepath.Join(tmpDir, "client.key") + c.AgentCVMClientCrt = filepath.Join(tmpDir, "client.crt") + c.AgentCVMServerCA = filepath.Join(tmpDir, "server.ca") }, expectError: false, validate: func(t *testing.T, req *manager.CreateReq) { @@ -428,10 +428,10 @@ func TestLoadCerts(t *testing.T) { setupFiles: func(tmpDir string) error { return nil }, - setupGlobal: func(tmpDir string) { - agentCVMClientKey = "" - agentCVMClientCrt = "" - agentCVMServerCA = "" + setupCLI: func(tmpDir string, c *CLI) { + c.AgentCVMClientKey = "" + c.AgentCVMClientCrt = "" + c.AgentCVMServerCA = "" }, expectError: false, validate: func(t *testing.T, req *manager.CreateReq) { @@ -445,10 +445,10 @@ func TestLoadCerts(t *testing.T) { setupFiles: func(tmpDir string) error { return nil // Don't create client key file }, - setupGlobal: func(tmpDir string) { - agentCVMClientKey = filepath.Join(tmpDir, "nonexistent.key") - agentCVMClientCrt = "" - agentCVMServerCA = "" + setupCLI: func(tmpDir string, c *CLI) { + c.AgentCVMClientKey = filepath.Join(tmpDir, "nonexistent.key") + c.AgentCVMClientCrt = "" + c.AgentCVMServerCA = "" }, expectError: true, }, @@ -458,10 +458,10 @@ func TestLoadCerts(t *testing.T) { // Create client key but not cert return os.WriteFile(filepath.Join(tmpDir, "client.key"), []byte("key-content"), 0o644) }, - setupGlobal: func(tmpDir string) { - agentCVMClientKey = filepath.Join(tmpDir, "client.key") - agentCVMClientCrt = filepath.Join(tmpDir, "nonexistent.crt") - agentCVMServerCA = "" + setupCLI: func(tmpDir string, c *CLI) { + c.AgentCVMClientKey = filepath.Join(tmpDir, "client.key") + c.AgentCVMClientCrt = filepath.Join(tmpDir, "nonexistent.crt") + c.AgentCVMServerCA = "" }, expectError: true, }, @@ -479,10 +479,10 @@ func TestLoadCerts(t *testing.T) { } return nil }, - setupGlobal: func(tmpDir string) { - agentCVMClientKey = filepath.Join(tmpDir, "client.key") - agentCVMClientCrt = filepath.Join(tmpDir, "client.crt") - agentCVMServerCA = filepath.Join(tmpDir, "nonexistent.ca") + setupCLI: func(tmpDir string, c *CLI) { + c.AgentCVMClientKey = filepath.Join(tmpDir, "client.key") + c.AgentCVMClientCrt = filepath.Join(tmpDir, "client.crt") + c.AgentCVMServerCA = filepath.Join(tmpDir, "nonexistent.ca") }, expectError: true, }, @@ -497,22 +497,10 @@ func TestLoadCerts(t *testing.T) { err = tt.setupFiles(tmpDir) require.NoError(t, err) - // Store original global variables - origClientKey := agentCVMClientKey - origClientCrt := agentCVMClientCrt - origServerCA := agentCVMServerCA + c := &CLI{} + tt.setupCLI(tmpDir, c) - // Setup global variables for test - tt.setupGlobal(tmpDir) - - // Restore original values after test - defer func() { - agentCVMClientKey = origClientKey - agentCVMClientCrt = origClientCrt - agentCVMServerCA = origServerCA - }() - - result, err := loadCerts() + result, err := c.loadCerts() if tt.expectError { assert.Error(t, err) @@ -592,7 +580,7 @@ func TestTTLHandling(t *testing.T) { assert.Error(t, err) } else { assert.NoError(t, err) - assert.Equal(t, tt.expectedTTL, ttl) + assert.Equal(t, tt.expectedTTL, mockCLI.Ttl) } } }) diff --git a/cli/result.go b/cli/result.go index 5b662a36f..8a1554176 100644 --- a/cli/result.go +++ b/cli/result.go @@ -24,7 +24,7 @@ func (cli *CLI) NewResultsCmd() *cobra.Command { Args: cobra.ExactArgs(1), Run: func(cmd *cobra.Command, args []string) { if cli.connectErr != nil { - printError(cmd, "Failed to connect to agent: %v ❌ ", cli.connectErr) + cli.printError(cmd, "Failed to connect to agent: %v ❌ ", cli.connectErr) return } @@ -32,14 +32,14 @@ func (cli *CLI) NewResultsCmd() *cobra.Command { privKeyFile, err := os.ReadFile(args[0]) if err != nil { - printError(cmd, "Error reading private key file: %v ❌ ", err) + cli.printError(cmd, "Error reading private key file: %v ❌ ", err) return } var outputPath string if outputDir != "" { if err := os.MkdirAll(outputDir, 0o755); err != nil { - printError(cmd, "Error creating output directory: %v ❌ ", err) + cli.printError(cmd, "Error creating output directory: %v ❌ ", err) return } outputPath = filepath.Join(outputDir, filename) @@ -56,19 +56,19 @@ func (cli *CLI) NewResultsCmd() *cobra.Command { privKey, err := decodeKey(pemBlock) if err != nil { - printError(cmd, "Error decoding private key: %v ❌ ", err) + cli.printError(cmd, "Error decoding private key: %v ❌ ", err) return } resultFile, err := os.Create(outputPath) if err != nil { - printError(cmd, "Error creating result file: %v ❌ ", err) + cli.printError(cmd, "Error creating result file: %v ❌ ", err) return } defer resultFile.Close() if err = cli.agentSDK.Result(cmd.Context(), privKey, resultFile); err != nil { - printError(cmd, "Error retrieving computation result: %v ❌ ", err) + cli.printError(cmd, "Error retrieving computation result: %v ❌ ", err) return } diff --git a/cli/sdk.go b/cli/sdk.go index b7d9e21b3..447c295cc 100644 --- a/cli/sdk.go +++ b/cli/sdk.go @@ -4,6 +4,7 @@ package cli import ( "context" + "time" "github.com/spf13/cobra" "github.com/ultravioletrs/cocos/manager" @@ -15,16 +16,30 @@ import ( "github.com/ultravioletrs/cocos/pkg/sdk" ) -var Verbose bool - type CLI struct { - agentSDK sdk.SDK - agentConfig clients.AttestedClientConfig - managerConfig clients.StandardClientConfig - client grpc.Client - managerClient manager.ManagerServiceClient - connectErr error - measurement cmdconfig.MeasurementProvider + agentSDK sdk.SDK + agentConfig clients.AttestedClientConfig + managerConfig clients.StandardClientConfig + client grpc.Client + managerClient manager.ManagerServiceClient + connectErr error + measurement cmdconfig.MeasurementProvider + Verbose bool + IsManifest bool + ToBase64 bool + KeyType string + AgentCVMServerUrl string + AgentCVMServerCA string + AgentCVMClientKey string + AgentCVMClientCrt string + AgentCVMCaUrl string + AgentLogLevel string + Ttl time.Duration + AwsAccessKeyId string + AwsSecretAccessKey string + AwsEndpointUrl string + AwsRegion string + AaKbsParams string } func New(agentConfig clients.AttestedClientConfig, managerConfig clients.StandardClientConfig, measurement cmdconfig.MeasurementProvider) *CLI { diff --git a/cmd/cli/main.go b/cmd/cli/main.go index 7d4a713ea..62a515c4e 100644 --- a/cmd/cli/main.go +++ b/cmd/cli/main.go @@ -122,7 +122,7 @@ func main() { defer cliSVC.Close() } - rootCmd.PersistentFlags().BoolVarP(&cli.Verbose, "verbose", "v", false, "Enable verbose output") + rootCmd.PersistentFlags().BoolVarP(&cliSVC.Verbose, "verbose", "v", false, "Enable verbose output") keysCmd := cliSVC.NewKeysCmd() attestationCmd := cliSVC.NewAttestationCmd() @@ -151,7 +151,7 @@ func main() { // Flags keysCmd.PersistentFlags().StringVarP( - &cli.KeyType, + &cliSVC.KeyType, "key-type", "k", "rsa", diff --git a/pkg/oci/skopeo.go b/pkg/oci/skopeo.go index 44b215b74..bcaea0d9d 100644 --- a/pkg/oci/skopeo.go +++ b/pkg/oci/skopeo.go @@ -58,7 +58,11 @@ func (s *SkopeoClient) PullAndDecrypt(ctx context.Context, source ResourceSource // Add decryption key if image is encrypted if source.Encrypted { - args = append(args, "--decryption-key", DecryptionKeyProvider) + decryptionKey := DecryptionKeyProvider + if source.KBSURL != "" { + decryptionKey = fmt.Sprintf("provider:attestation-agent:type=kbs,url=%s", source.KBSURL) + } + args = append(args, "--decryption-key", decryptionKey) } // Add insecure policy for testing (TODO: use proper policy in production) diff --git a/pkg/oci/types.go b/pkg/oci/types.go index 51c17c1ae..a5f4c6ec8 100644 --- a/pkg/oci/types.go +++ b/pkg/oci/types.go @@ -25,6 +25,9 @@ type ResourceSource struct { // KBSResourcePath is the KBS resource path for the decryption key // (e.g., "default/key/algo-key") KBSResourcePath string `json:"kbs_resource_path,omitempty"` + + // KBSURL is the KBS endpoint URL for this specific resource + KBSURL string `json:"kbs_url,omitempty"` } // ImageManifest represents basic OCI image manifest information. diff --git a/test/cvms/main.go b/test/cvms/main.go index ecdf0e6e1..e9f33c3a5 100644 --- a/test/cvms/main.go +++ b/test/cvms/main.go @@ -43,9 +43,10 @@ var ( pubKeyFile string clientCAFile string // Remote resource configuration. - kbsURL string + algoKBSURL string algoSourceURL string algoKBSResourcePath string + datasetKBSURLs string datasetSourceURLs string datasetKBSPaths string algoType string @@ -76,12 +77,16 @@ func (s *svc) Run(ctx context.Context, ipAddress string, sendMessage cvmsgrpc.Se // Check if using remote datasets var datasetURLs []string var datasetKBSPathsList []string + var datasetKBSURLsList []string if datasetSourceURLs != "" { datasetURLs = strings.Split(datasetSourceURLs, ",") } if datasetKBSPaths != "" { datasetKBSPathsList = strings.Split(datasetKBSPaths, ",") } + if datasetKBSURLs != "" { + datasetKBSURLsList = strings.Split(datasetKBSURLs, ",") + } var datasetDecompressList []bool if datasetDecompress != "" { @@ -118,7 +123,7 @@ func (s *svc) Run(ctx context.Context, ipAddress string, sendMessage cvmsgrpc.Se } for i := 0; i < len(datasetURLs); i++ { - datasets = append(datasets, &cvms.Dataset{ + d := &cvms.Dataset{ Hash: dataHashBytes, UserKey: pubPem.Bytes, Filename: fmt.Sprintf("dataset_%d.csv", i), @@ -128,7 +133,14 @@ func (s *svc) Run(ctx context.Context, ipAddress string, sendMessage cvmsgrpc.Se KbsResourcePath: datasetKBSPathsList[i], Encrypted: datasetKBSPathsList[i] != "", }, - }) + } + if len(datasetKBSURLsList) > i && datasetKBSURLsList[i] != "" { + d.Kbs = &cvms.KBSConfig{ + Url: datasetKBSURLsList[i], + Enabled: true, + } + } + datasets = append(datasets, d) if len(datasetDecompressList) > i { datasets[len(datasets)-1].Decompress = datasetDecompressList[i] } @@ -189,6 +201,12 @@ func (s *svc) Run(ctx context.Context, ipAddress string, sendMessage cvmsgrpc.Se Encrypted: algoKBSResourcePath != "", }, } + if algoKBSURL != "" { + algorithm.Kbs = &cvms.KBSConfig{ + Url: algoKBSURL, + Enabled: true, + } + } } else { // Direct upload mode - use local file fileHash, err := internal.ChecksumHex(algoPath) @@ -212,15 +230,6 @@ func (s *svc) Run(ctx context.Context, ipAddress string, sendMessage cvmsgrpc.Se } } - // Build KBS config - var kbsConfig *cvms.KBSConfig - if kbsURL != "" { - kbsConfig = &cvms.KBSConfig{ - Url: kbsURL, - Enabled: true, - } - } - s.logger.Debug("sending computation run request") if err := sendMessage(&cvms.ServerStreamMessage{ Message: &cvms.ServerStreamMessage_RunReq{ @@ -236,7 +245,6 @@ func (s *svc) Run(ctx context.Context, ipAddress string, sendMessage cvmsgrpc.Se AttestedTls: attestedTLS, ClientCaFile: clientCAFile, }, - Kbs: kbsConfig, }, }, }); err != nil { @@ -258,12 +266,13 @@ func main() { flagSet.StringVar(&dataPathString, "data-paths", "", "Paths to data sources, list of string separated with commas (for direct upload mode)") flagSet.StringVar(&clientCAFile, "client-ca-file", "", "Client CA root certificate file path") // Remote resource flags - flagSet.StringVar(&kbsURL, "kbs-url", "", "KBS endpoint URL (e.g., 'http://localhost:8080')") - flagSet.StringVar(&algoSourceURL, "algo-source-url", "", "Algorithm source URL (s3://bucket/key or https://...)") - flagSet.StringVar(&algoKBSResourcePath, "algo-kbs-path", "", "Algorithm KBS resource path (e.g., 'default/key/algo-key')") + flagSet.StringVar(&algoKBSURL, "algo-kbs-url", "", "Algorithm-specific KBS endpoint URL") + flagSet.StringVar(&algoSourceURL, "algo-source-url", "", "Algorithm source URL (oci-image only)") + flagSet.StringVar(&algoKBSResourcePath, "algo-kbs-path", "", "Algorithm KBS resource path") + flagSet.StringVar(&datasetKBSURLs, "dataset-kbs-urls", "", "Dataset-specific KBS endpoint URLs, comma-separated") flagSet.StringVar(&datasetSourceURLs, "dataset-source-urls", "", "Dataset source URLs, comma-separated") flagSet.StringVar(&datasetKBSPaths, "dataset-kbs-paths", "", "Dataset KBS resource paths, comma-separated") - flagSet.StringVar(&algoType, "algo-type", "", "Algorithm execution type (e.g. binary, python)") + flagSet.StringVar(&algoType, "algo-type", "", "Algorithm execution type") flagSet.StringVar(&algoArgsString, "algo-args", "", "Algorithm arguments, comma-separated") flagSet.StringVar(&algoHash, "algo-hash", "", "Algorithm SHA256 hash (hex string)") flagSet.StringVar(&datasetTypeString, "dataset-type", "", "Dataset source type, comma-separated (deprecated, always oci-image)")