diff --git a/api/tests/unit_tests/controllers/console/datasets/test_data_source.py b/api/tests/unit_tests/controllers/console/datasets/test_data_source.py index 3060062adf..4297e6d5ea 100644 --- a/api/tests/unit_tests/controllers/console/datasets/test_data_source.py +++ b/api/tests/unit_tests/controllers/console/datasets/test_data_source.py @@ -176,6 +176,35 @@ class TestDataSourceApi: with pytest.raises(ValueError): method(api, "b1", "disable") + def test_patch_binding_scoped_to_current_tenant(self, app, patch_tenant, mock_engine): + """Verify that the patch query includes tenant_id to prevent IDOR attacks.""" + + api = DataSourceApi() + method = unwrap(api.patch) + + binding = MagicMock(id="b1", disabled=True) + + with ( + app.test_request_context("/"), + patch("controllers.console.datasets.data_source.Session") as mock_session_class, + patch("controllers.console.datasets.data_source.db.session.add"), + patch("controllers.console.datasets.data_source.db.session.commit"), + ): + mock_session = MagicMock() + mock_session_class.return_value.__enter__.return_value = mock_session + mock_session.execute.return_value.scalar_one_or_none.return_value = binding + + method(api, "b1", "enable") + + # Inspect the SELECT statement passed to session.execute + call_args = mock_session.execute.call_args + stmt = call_args[0][0] + compiled = stmt.compile(compile_kwargs={"literal_binds": True}) + compiled_where = str(compiled) + + assert "tenant_id = 'tenant-1'" in compiled_where, ( + "The patch query must filter by tenant_id to prevent IDOR vulnerabilities" + ) class TestDataSourceNotionListApi: def test_get_credential_not_found(self, app, patch_tenant):