diff --git a/interface.go b/interface.go index aa6e195..c9e0fcc 100644 --- a/interface.go +++ b/interface.go @@ -25,6 +25,8 @@ type WriteAPI interface { // Create or update given item in DynamoDB. Must implemenmt DynamoRecord interface. // DynamoRecord.GetKeys will be called to get values for parition and sort keys. PutItem(ctx context.Context, pk, sk Attribute, item interface{}, opt ...PutOption) error + // Update specified fields on a DynamoDB record + UpdateItem(ctx context.Context, pk, sk Attribute, fields interface{}, opts ...UpdateOption) error DeleteItem(ctx context.Context, pk, sk string) error BatchDeleteItems(ctx context.Context, input []AttributeRecord) []AttributeRecord } diff --git a/tests/transact_items_test.go b/tests/transact_items_test.go index d8b85d3..53ff58a 100644 --- a/tests/transact_items_test.go +++ b/tests/transact_items_test.go @@ -129,3 +129,126 @@ func TestTransactItems(t *testing.T) { } } + +func TestTransactItemsWithUpdate(t *testing.T) { + table := prepareTable(t) + testCases := []struct { + title string + condition string + keys map[string]types.AttributeValue + // items to be added initially + initialItems []Terminal + operations []types.TransactWriteItem + // items expected to exist in table after transaction operation + expected []Terminal + expectedErr error + }{{ + title: "transaction with update operation", + condition: "pk = :pk", + keys: map[string]types.AttributeValue{ + ":pk": &types.AttributeValueMemberS{Value: "merchant3"}, + }, + initialItems: []Terminal{{ + Id: "1", + Pk: "merchant3", + Sk: "terminal1", + }}, + operations: []types.TransactWriteItem{ + table.WithUpdateItem("merchant3", "terminal1", map[string]interface{}{ + "Id": "updated_id", + }), + }, + expected: []Terminal{ + { + Id: "updated_id", + Pk: "merchant3", + Sk: "terminal1", + }, + }, + }, + { + title: "transaction with mixed operations including update", + condition: "pk = :pk", + keys: map[string]types.AttributeValue{ + ":pk": &types.AttributeValueMemberS{Value: "merchant4"}, + }, + initialItems: []Terminal{{ + Id: "1", + Pk: "merchant4", + Sk: "terminal1", + }}, + operations: []types.TransactWriteItem{ + table.WithUpdateItem("merchant4", "terminal1", map[string]interface{}{ + "Id": "updated_terminal1", + }), + table.WithPutItem("merchant4", "terminal2", Terminal{ + Id: "2", + Pk: "merchant4", + Sk: "terminal2", + }), + }, + expected: []Terminal{ + { + Id: "updated_terminal1", + Pk: "merchant4", + Sk: "terminal1", + }, + { + Id: "2", + Pk: "merchant4", + Sk: "terminal2", + }, + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.title, func(t *testing.T) { + t.Parallel() + ctx := context.TODO() + + // Create initial items + if len(tc.initialItems) > 0 { + items := make([]*dynago.TransactPutItemsInput, 0, len(tc.initialItems)) + for _, item := range tc.initialItems { + items = append(items, &dynago.TransactPutItemsInput{ + PartitionKeyValue: dynago.StringValue(item.Pk), + SortKeyValue: dynago.StringValue(item.Sk), + Item: item, + }) + } + err := table.TransactPutItems(ctx, items) + if err != nil { + t.Fatalf("transaction put items failed; got %s", err) + } + } + + // Perform operations + if len(tc.operations) > 0 { + err := table.TransactItems(ctx, tc.operations...) + if err != nil { + t.Fatalf("error occurred %s", err) + } + } + + var out []Terminal + _, err := table.Query(ctx, tc.condition, tc.keys, &out) + if tc.expectedErr != nil { + if err == nil { + t.Fatalf("expected query to fail with %s", tc.expectedErr) + } + if !strings.Contains(err.Error(), tc.expectedErr.Error()) { + t.Fatalf("expected query to fail with %s; got %s", tc.expectedErr, err) + } + return + } + if err != nil { + t.Fatalf("expected query to succeed; got %s", err) + } + if !reflect.DeepEqual(tc.expected, out) { + t.Errorf("expected query to return %v; got %v", tc.expected, out) + } + + }) + } +} diff --git a/tests/update_item_test.go b/tests/update_item_test.go new file mode 100644 index 0000000..2d0f85e --- /dev/null +++ b/tests/update_item_test.go @@ -0,0 +1,291 @@ +package tests + +import ( + "context" + "reflect" + "testing" + + "github.com/oolio-group/dynago" +) + +type UpdateRecord struct { + ID string `json:"id"` + Pk string `json:"pk"` + Sk string `json:"sk"` + Name string `json:"name"` + Age int `json:"age"` + Email string `json:"email"` + Version uint `json:"version"` +} + +func TestUpdateItem(t *testing.T) { + table := prepareTable(t) + ctx := context.Background() + + testCases := []struct { + title string + initialItem UpdateRecord + updateFields interface{} + opts []dynago.UpdateOption + expected UpdateRecord + expectError bool + }{ + { + title: "update single field", + initialItem: UpdateRecord{ + ID: "test1", + Pk: "user#1", + Sk: "profile", + Name: "John Doe", + Age: 30, + }, + updateFields: map[string]interface{}{ + "Name": "Jane Doe", + }, + expected: UpdateRecord{ + ID: "test1", + Pk: "user#1", + Sk: "profile", + Name: "Jane Doe", + Age: 30, + }, + }, + { + title: "update multiple fields", + initialItem: UpdateRecord{ + ID: "test2", + Pk: "user#2", + Sk: "profile", + Name: "Bob Smith", + Age: 25, + Email: "bob@example.com", + }, + updateFields: map[string]interface{}{ + "Name": "Robert Smith", + "Age": 26, + "Email": "robert@example.com", + }, + expected: UpdateRecord{ + ID: "test2", + Pk: "user#2", + Sk: "profile", + Name: "Robert Smith", + Age: 26, + Email: "robert@example.com", + }, + }, + { + title: "update with struct fields", + initialItem: UpdateRecord{ + ID: "test3", + Pk: "user#3", + Sk: "profile", + Name: "Alice Johnson", + Age: 28, + }, + updateFields: struct { + Name string `json:"name"` + Age int `json:"age"` + }{ + Name: "Alice Williams", + Age: 29, + }, + expected: UpdateRecord{ + ID: "test3", + Pk: "user#3", + Sk: "profile", + Name: "Alice Williams", + Age: 29, + }, + }, + { + title: "update with optimistic lock", + initialItem: UpdateRecord{ + ID: "test4", + Pk: "user#4", + Sk: "profile", + Name: "David Brown", + Age: 35, + Version: 1, + }, + updateFields: map[string]interface{}{ + "Name": "David Wilson", + }, + opts: []dynago.UpdateOption{ + dynago.WithOptimisticLockForUpdate("Version", 1), + }, + expected: UpdateRecord{ + ID: "test4", + Pk: "user#4", + Sk: "profile", + Name: "David Wilson", + Age: 35, + Version: 2, // Should be incremented + }, + }, + { + title: "update with conditional expression", + initialItem: UpdateRecord{ + ID: "test5", + Pk: "user#5", + Sk: "profile", + Name: "Emma Davis", + Age: 22, + }, + updateFields: map[string]interface{}{ + "Age": 23, + }, + opts: []dynago.UpdateOption{ + dynago.WithConditionalUpdate( + "attribute_exists(#name)", + map[string]dynago.Attribute{}, + map[string]string{ + "#name": "Name", // Use the correct field name that exists in the struct + }, + ), + }, + expected: UpdateRecord{ + ID: "test5", + Pk: "user#5", + Sk: "profile", + Name: "Emma Davis", + Age: 23, + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.title, func(t *testing.T) { + t.Parallel() + + // Create initial item + pk := dynago.StringValue(tc.initialItem.Pk) + sk := dynago.StringValue(tc.initialItem.Sk) + + err := table.PutItem(ctx, pk, sk, tc.initialItem) + if err != nil { + t.Fatalf("failed to create initial item: %s", err) + } + + // Update the item + err = table.UpdateItem(ctx, pk, sk, tc.updateFields, tc.opts...) + if tc.expectError { + if err == nil { + t.Fatalf("expected update to fail, but it succeeded") + } + return + } + if err != nil { + t.Fatalf("unexpected error during update: %s", err) + } + + // Retrieve and verify the updated item + var result UpdateRecord + err, found := table.GetItem(ctx, pk, sk, &result) + if err != nil { + t.Fatalf("failed to retrieve updated item: %s", err) + } + if !found { + t.Fatalf("item not found after update") + } + + if !reflect.DeepEqual(tc.expected, result) { + t.Errorf("expected updated item to be %+v; got %+v", tc.expected, result) + } + }) + } +} + +func TestUpdateItemCustomExpression(t *testing.T) { + table := prepareTable(t) + ctx := context.Background() + + // Test custom update expression (ADD operation) + initialItem := UpdateRecord{ + ID: "expr_test", + Pk: "user#expr", + Sk: "profile", + Name: "Counter User", + Age: 10, + } + + pk := dynago.StringValue(initialItem.Pk) + sk := dynago.StringValue(initialItem.Sk) + + // Create initial item + err := table.PutItem(ctx, pk, sk, initialItem) + if err != nil { + t.Fatalf("failed to create initial item: %s", err) + } + + // Update using ADD expression to increment age + err = table.UpdateItem(ctx, pk, sk, nil, dynago.WithUpdateExpression( + "ADD #age :increment", + map[string]dynago.Attribute{ + ":increment": dynago.NumberValue(5), + }, + map[string]string{ + "#age": "Age", // Use the actual struct field name + }, + )) + if err != nil { + t.Fatalf("failed to update with custom expression: %s", err) + } + + // Verify the result + var result UpdateRecord + err, found := table.GetItem(ctx, pk, sk, &result) + if err != nil { + t.Fatalf("failed to retrieve updated item: %s", err) + } + if !found { + t.Fatalf("item not found after update") + } + + expectedAge := 15 // 10 + 5 + if result.Age != expectedAge { + t.Errorf("expected age to be %d after ADD operation; got %d", expectedAge, result.Age) + } +} + +func TestUpdateItemErrors(t *testing.T) { + table := prepareTable(t) + ctx := context.Background() + + pk := dynago.StringValue("error#test") + sk := dynago.StringValue("profile") + + testCases := []struct { + title string + fields interface{} + description string + }{ + { + title: "nil fields", + fields: nil, + description: "should fail with nil fields", + }, + { + title: "empty map", + fields: map[string]interface{}{}, + description: "should fail with empty fields", + }, + { + title: "only primary keys", + fields: map[string]interface{}{ + "pk": "should_not_update", + "sk": "should_not_update", + }, + description: "should fail when only primary keys are provided", + }, + } + + for _, tc := range testCases { + t.Run(tc.title, func(t *testing.T) { + err := table.UpdateItem(ctx, pk, sk, tc.fields) + if err == nil { + t.Errorf("%s - expected error but got none", tc.description) + } + }) + } +} \ No newline at end of file diff --git a/transaction_items.go b/transaction_items.go index 4476282..04309e2 100644 --- a/transaction_items.go +++ b/transaction_items.go @@ -44,6 +44,30 @@ func (t *Client) WithPutItem(pk string, sk string, item interface{}) types.Trans } +func (t *Client) WithUpdateItem(pk string, sk string, fields interface{}) types.TransactWriteItem { + // Generate update expression from fields + updateExpr, attrValues, attrNames, err := t.generateUpdateExpression(fields) + if err != nil { + log.Printf("Failed to generate update expression: %s", err.Error()) + return types.TransactWriteItem{} + } + + if updateExpr == "" { + log.Println("No fields to update") + return types.TransactWriteItem{} + } + + return types.TransactWriteItem{ + Update: &types.Update{ + TableName: &t.TableName, + Key: t.NewKeys(StringValue(pk), StringValue(sk)), + UpdateExpression: &updateExpr, + ExpressionAttributeValues: attrValues, + ExpressionAttributeNames: attrNames, + }, + } +} + // TransactItems is a synchronous for writing or deletion operation performed in dynamodb grouped together func (t *Client) TransactItems(ctx context.Context, input ...types.TransactWriteItem) error { diff --git a/update_item.go b/update_item.go new file mode 100644 index 0000000..597a282 --- /dev/null +++ b/update_item.go @@ -0,0 +1,205 @@ +package dynago + +import ( + "context" + "fmt" + "log" + "strings" + + "github.com/aws/aws-sdk-go-v2/feature/dynamodb/attributevalue" + "github.com/aws/aws-sdk-go-v2/service/dynamodb" +) + +type UpdateOption func(*dynamodb.UpdateItemInput) error + +// WithConditionalUpdate enables conditional updates by setting a condition expression +func WithConditionalUpdate(conditionExpression string, values map[string]Attribute, names map[string]string) UpdateOption { + return func(input *dynamodb.UpdateItemInput) error { + input.ConditionExpression = &conditionExpression + if input.ExpressionAttributeValues == nil { + input.ExpressionAttributeValues = map[string]Attribute{} + } + for k, v := range values { + input.ExpressionAttributeValues[k] = v + } + if names != nil { + if input.ExpressionAttributeNames == nil { + input.ExpressionAttributeNames = map[string]string{} + } + for k, v := range names { + input.ExpressionAttributeNames[k] = v + } + } + return nil + } +} + +// WithOptimisticLockForUpdate enables concurrency control by using an optimistic lock for updates +// Similar to PutItem's WithOptimisticLock but for UpdateItem operations +func WithOptimisticLockForUpdate(key string, currentVersion uint) UpdateOption { + return func(input *dynamodb.UpdateItemInput) error { + // Check if version attribute doesn't exist or matches the old version + condition := "attribute_not_exists(#version) or #version = :oldVersion" + input.ConditionExpression = &condition + + if input.ExpressionAttributeNames == nil { + input.ExpressionAttributeNames = map[string]string{} + } + if input.ExpressionAttributeValues == nil { + input.ExpressionAttributeValues = map[string]Attribute{} + } + + input.ExpressionAttributeNames["#version"] = key + input.ExpressionAttributeValues[":oldVersion"] = NumberValue(int64(currentVersion)) + input.ExpressionAttributeValues[":newVersion"] = NumberValue(int64(currentVersion + 1)) + + // Add version increment to update expression + versionUpdate := "#version = :newVersion" + + if input.UpdateExpression == nil || *input.UpdateExpression == "" { + // If no existing expression, create a new SET expression + expr := fmt.Sprintf("SET %s", versionUpdate) + input.UpdateExpression = &expr + } else { + existingExpr := *input.UpdateExpression + if strings.Contains(strings.ToUpper(existingExpr), "SET") { + // Add to existing SET clause + newExpr := strings.Replace(existingExpr, "SET ", fmt.Sprintf("SET %s, ", versionUpdate), 1) + input.UpdateExpression = &newExpr + } else { + // Prepend SET clause to other operations + newExpr := fmt.Sprintf("SET %s %s", versionUpdate, existingExpr) + input.UpdateExpression = &newExpr + } + } + + return nil + } +} + +// WithUpdateExpression allows setting custom update expressions (e.g., "ADD balance :val") +func WithUpdateExpression(expression string, values map[string]Attribute, names map[string]string) UpdateOption { + return func(input *dynamodb.UpdateItemInput) error { + input.UpdateExpression = &expression + + if input.ExpressionAttributeValues == nil { + input.ExpressionAttributeValues = map[string]Attribute{} + } + for k, v := range values { + input.ExpressionAttributeValues[k] = v + } + + if names != nil { + if input.ExpressionAttributeNames == nil { + input.ExpressionAttributeNames = map[string]string{} + } + for k, v := range names { + input.ExpressionAttributeNames[k] = v + } + } + + return nil + } +} + +// UpdateItem updates specified fields on a DynamoDB record +// fields parameter should be a struct or map with fields to update +// If fields is nil, only custom expressions from options will be applied +func (t *Client) UpdateItem(ctx context.Context, pk, sk Attribute, fields interface{}, opts ...UpdateOption) error { + var updateExpr string + var attrValues map[string]Attribute + var attrNames map[string]string + var err error + + // Generate update expression from fields if provided + if fields != nil { + updateExpr, attrValues, attrNames, err = t.generateUpdateExpression(fields) + if err != nil { + return fmt.Errorf("failed to generate update expression: %w", err) + } + } else { + // Initialize empty maps if no fields provided + attrValues = make(map[string]Attribute) + attrNames = make(map[string]string) + } + + input := &dynamodb.UpdateItemInput{ + TableName: &t.TableName, + Key: t.NewKeys(pk, sk), + ExpressionAttributeValues: attrValues, + ExpressionAttributeNames: attrNames, + } + + // Set update expression if we have one from fields + if updateExpr != "" { + input.UpdateExpression = &updateExpr + } + + // Apply option functions + for _, opt := range opts { + if err := opt(input); err != nil { + return fmt.Errorf("failed to apply update option: %w", err) + } + } + + // Check if we have any update expression after applying options + if input.UpdateExpression == nil || *input.UpdateExpression == "" { + return fmt.Errorf("no update expression provided") + } + + _, err = t.client.UpdateItem(ctx, input) + if err != nil { + log.Printf("Failed to update item: %s", err.Error()) + return err + } + + return nil +} + +// generateUpdateExpression creates an update expression from a struct or map +func (t *Client) generateUpdateExpression(fields interface{}) (string, map[string]Attribute, map[string]string, error) { + if fields == nil { + return "", nil, nil, fmt.Errorf("fields cannot be nil") + } + + // Marshal the fields to get attribute values + av, err := attributevalue.MarshalMap(fields) + if err != nil { + return "", nil, nil, fmt.Errorf("failed to marshal fields: %w", err) + } + + if len(av) == 0 { + return "", nil, nil, fmt.Errorf("no fields to update") + } + + var setParts []string + attrValues := make(map[string]Attribute) + attrNames := make(map[string]string) + + // Filter out partition and sort keys from updates + pkName := t.Keys["pk"] + skName := t.Keys["sk"] + + for fieldName, attrValue := range av { + // Skip partition and sort keys + if fieldName == pkName || fieldName == skName { + continue + } + + // Create attribute name and value placeholders + nameKey := fmt.Sprintf("#%s", fieldName) + valueKey := fmt.Sprintf(":%s", fieldName) + + attrNames[nameKey] = fieldName + attrValues[valueKey] = attrValue + setParts = append(setParts, fmt.Sprintf("%s = %s", nameKey, valueKey)) + } + + if len(setParts) == 0 { + return "", nil, nil, fmt.Errorf("no valid fields to update (only primary keys provided)") + } + + updateExpr := fmt.Sprintf("SET %s", strings.Join(setParts, ", ")) + + return updateExpr, attrValues, attrNames, nil +} \ No newline at end of file