Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -50,12 +50,14 @@ class DataForwarderSerializer(Serializer):
(DataForwarderProviderSlug.SPLUNK, "Splunk"),
]
)
config = serializers.DictField(child=serializers.CharField(allow_blank=False), default=dict)
config = serializers.DictField(child=serializers.CharField(allow_blank=True), default=dict)
project_ids = serializers.ListField(
child=serializers.IntegerField(), allow_empty=True, required=True
child=serializers.IntegerField(), allow_empty=True, required=False, default=list
)

def validate_config(self, config) -> SQSConfig | SegmentConfig | SplunkConfig:
# Filter out empty string values (cleared optional fields)
config = {k: v for k, v in config.items() if v != ""}
provider = self.initial_data.get("provider")

if provider == DataForwarderProviderSlug.SQS:
Expand Down Expand Up @@ -210,12 +212,16 @@ def create(self, validated_data: Mapping[str, Any]) -> DataForwarder:

# Enroll specified projects
if project_ids:
for project_id in project_ids:
DataForwarderProject.objects.create(
data_forwarder=data_forwarder,
project_id=project_id,
is_enabled=False,
)
DataForwarderProject.objects.bulk_create(
[
DataForwarderProject(
data_forwarder=data_forwarder,
project_id=project_id,
is_enabled=True,
)
for project_id in project_ids
]
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bug: Update without project_ids deletes all enrolled projects

The update method always processes project_ids from validated_data, but this field now has a default value of an empty list when not provided in the request. When updating a data forwarder's main config without explicitly providing project_ids, the serializer will use the default empty list, causing the code to delete all enrolled project associations. The method should only modify project_ids if they were actually present in the request data.

Fix in Cursor Fix in Web

return data_forwarder

def update(self, instance: DataForwarder, validated_data: Mapping[str, Any]) -> DataForwarder:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -346,14 +346,17 @@ def test_create_missing_config(self) -> None:
response = self.get_error_response(self.organization.slug, status_code=400, **payload)
assert "config" in str(response.data).lower()

def test_create_missing_project_ids(self) -> None:
def test_create_without_project_ids(self) -> None:
payload = {
"provider": DataForwarderProviderSlug.SEGMENT,
"config": {"write_key": "test_key"},
}

response = self.get_error_response(self.organization.slug, status_code=400, **payload)
assert "project_ids" in str(response.data).lower()
response = self.get_success_response(self.organization.slug, status_code=201, **payload)
assert response.data["provider"] == DataForwarderProviderSlug.SEGMENT

data_forwarder = DataForwarder.objects.get(id=response.data["id"])
assert data_forwarder.projects.count() == 0

def test_create_sqs_fifo_queue_validation(self) -> None:
payload = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,17 +69,6 @@ def test_required_fields(self) -> None:
assert not serializer.is_valid()
assert "provider" in serializer.errors

# Missing project_ids
serializer = DataForwarderSerializer(
data={
"organization_id": self.organization.id,
"provider": DataForwarderProviderSlug.SEGMENT,
"config": {"write_key": "test_key"},
}
)
assert not serializer.is_valid()
assert "project_ids" in serializer.errors

def test_provider_choice_validation(self) -> None:
# Valid providers
provider_configs = {
Expand Down Expand Up @@ -230,8 +219,8 @@ def test_sqs_config_validation_empty_credentials(self) -> None:
)
assert not serializer.is_valid()
assert "config" in serializer.errors
config_errors = serializer.errors["config"]
assert "access_key" in config_errors or "secret_key" in config_errors
config_errors_str = str(serializer.errors["config"])
assert "access_key" in config_errors_str and "secret_key" in config_errors_str

def test_sqs_config_validation_fifo_queue_without_message_group_id(self) -> None:
config: dict[str, str] = {
Expand Down Expand Up @@ -419,8 +408,8 @@ def test_splunk_config_validation_empty_strings(self) -> None:
)
assert not serializer.is_valid()
assert "config" in serializer.errors
config_errors = serializer.errors["config"]
assert "index" in config_errors or "source" in config_errors
config_errors_str = str(serializer.errors["config"])
assert "index" in config_errors_str and "source" in config_errors_str

def test_splunk_config_validation_invalid_token_format(self) -> None:
config: dict[str, str] = {
Expand Down
Loading