diff --git a/README.md b/README.md index 8820ce9..8ba9e42 100644 --- a/README.md +++ b/README.md @@ -54,6 +54,9 @@ providers: #cdn: false # Manage Page Rules (URLFWD) records # pagerules: true + # Optional. Define Cloudflare plan type for the zones. Default: free, + # options: free, enterprise + #plan_type: free # Optional. Default: 4. Number of times to retry if a 429 response # is received. #retry_count: 4 diff --git a/octodns_cloudflare/__init__.py b/octodns_cloudflare/__init__.py index 770bf9b..c2a2265 100644 --- a/octodns_cloudflare/__init__.py +++ b/octodns_cloudflare/__init__.py @@ -16,6 +16,8 @@ from octodns.provider.base import BaseProvider from octodns.record import Create, Record, Update +from octodns_cloudflare.record import CloudflareZoneRecord + try: # pragma: no cover from octodns.record.https import HttpsValue from octodns.record.svcb import SvcbValue @@ -89,6 +91,7 @@ def __init__( account_id=None, cdn=False, pagerules=True, + plan_type=None, retry_count=4, retry_period=300, auth_error_retry_count=0, @@ -100,11 +103,12 @@ def __init__( ): self.log = getLogger(f'CloudflareProvider[{id}]') self.log.debug( - '__init__: id=%s, email=%s, token=***, account_id=%s, cdn=%s', + '__init__: id=%s, email=%s, token=***, account_id=%s, cdn=%s, plan=%s', id, email, account_id, cdn, + plan_type, ) super().__init__(id, *args, **kwargs) @@ -123,6 +127,7 @@ def __init__( self.account_id = account_id self.cdn = cdn self.pagerules = pagerules + self.plan_type = plan_type self.retry_count = retry_count self.retry_period = retry_period self.auth_error_retry_count = auth_error_retry_count @@ -214,7 +219,15 @@ def zones(self): else: page = None - self._zones = IdnaDict({f'{z["name"]}.': z['id'] for z in zones}) + self._zones = IdnaDict( + { + f'{z["name"]}.': { + 'id': z['id'], + 'plan': z.get('plan', {}).get('legacy_id', None), + } + for z in zones + } + ) return self._zones @@ -468,7 +481,7 @@ def _data_for_SSHFP(self, _type, records): def zone_records(self, zone): if zone.name not in self._zone_records: - zone_id = self.zones.get(zone.name, False) + zone_id = self.zones.get(zone.name, {}).get('id', False) if not zone_id: return [] @@ -1016,7 +1029,11 @@ def _gen_key(self, data): def _apply_Create(self, change): new = change.new - zone_id = self.zones[new.zone.name] + if new._type == 'CF_ZONE': + self._update_plan(new.zone.name, new.value['plan']) + return + + zone_id = self.zones[new.zone.name]['id'] if new._type == 'URLFWD': path = f'/zones/{zone_id}/pagerules' else: @@ -1026,7 +1043,12 @@ def _apply_Create(self, change): def _apply_Update(self, change): zone = change.new.zone - zone_id = self.zones[zone.name] + if change.new._type == 'CF_ZONE': + self._update_plan(zone.name, change.new.value['plan']) + return + + zone_id = self.zones[zone.name]['id'] + hostname = zone.hostname_from_fqdn(change.new.fqdn[:-1]) _type = change.new._type @@ -1156,6 +1178,12 @@ def _apply_Update(self, change): def _apply_Delete(self, change): existing = change.existing + if existing._type == 'CF_ZONE': + # Cloudflare plan record deletion is interpreted as a change to the free plan + self._update_plan( + existing.zone.name, CloudflareZoneRecord.FREE_PLAN + ) + return existing_name = existing.fqdn[:-1] # Make sure to map ALIAS to CNAME when looking for the target to delete existing_type = 'CNAME' if existing._type == 'ALIAS' else existing._type @@ -1166,7 +1194,9 @@ def _apply_Delete(self, change): parsed_uri = urlsplit(uri) record_name = parsed_uri.netloc record_type = 'URLFWD' - zone_id = self.zones.get(existing.zone.name, False) + zone_id = self.zones.get(existing.zone.name, {}).get( + 'id', False + ) if ( existing_name == record_name and existing_type == record_type @@ -1184,6 +1214,32 @@ def _apply_Delete(self, change): ) self._try_request('DELETE', path) + def _supported_plans(self, zone_name): + zone_id = self.zones[zone_name]['id'] + path = f'/zones/{zone_id}/available_plans' + resp = self._try_request('GET', path) + try: + result = resp['result'] + if isinstance(result, list): + return [plan['legacy_id'] for plan in result] + except KeyError: + pass + msg = f'{self.id}: unable to determine supported plans, do you have an Enterprise account?' + raise SupportsException(msg) + + def _update_plan(self, zone_name, plan): + if self.zones[zone_name]['plan'] == plan: + return + if plan in self._supported_plans(zone_name): + zone_id = self.zones[zone_name]['id'] + data = {'plan': {'legacy_id': plan}} + resp = self._try_request('PATCH', f'/zones/{zone_id}', data=data) + # Update the cached plan information + self.zones[zone_name]['plan'] = resp['result']['plan']['legacy_id'] + else: + msg = f'{self.id}: {plan} is not supported for {zone_name}' + raise SupportsException(msg) + def _apply(self, plan): desired = plan.desired changes = plan.changes @@ -1197,9 +1253,12 @@ def _apply(self, plan): data = {'name': name[:-1], 'jump_start': False} if self.account_id is not None: data['account'] = {'id': self.account_id} + if self.plan_type is not None: + data['plan'] = {'legacy_id': self.plan_type} resp = self._try_request('POST', '/zones', data=data) - zone_id = resp['result']['id'] - self.zones[name] = zone_id + self.zones[name] = {'id': resp['result']['id']} + if self.plan_type is not None: + self.zones[name]['plan'] = resp['result']['plan']['legacy_id'] self._zone_records[name] = {} # Force the operation order to be Delete() -> Create() -> Update() @@ -1220,6 +1279,26 @@ def _extra_changes(self, existing, desired, changes): existing_records = {r: r for r in existing.records} changed_records = {c.record for c in changes} + # Check if plan needs to be updated + desired_plan = self.plan_type + if desired_plan is not None: + zone_name = desired.name + if zone_name in self.zones: + current_plan = self.zones[zone_name]['plan'] + if current_plan != desired_plan: + # Add a fake record with custom type to trigger the plan update + record = Record.new( + desired, + '_plan_update', + { + 'type': 'CF_ZONE', # Custom type for Cloudflare zone updates + 'ttl': 300, + 'value': {'plan': self.plan_type}, + }, + ) + extra_changes.append(Update(record, record)) + + # Check for other changes (proxied status, auto-ttl, comments, tags) for desired_record in desired.records: existing_record = existing_records.get(desired_record, None) if not existing_record: # Will be created diff --git a/octodns_cloudflare/record.py b/octodns_cloudflare/record.py new file mode 100644 index 0000000..6b6c76e --- /dev/null +++ b/octodns_cloudflare/record.py @@ -0,0 +1,47 @@ +from octodns.record import Record + + +class CloudflareZoneRecord(Record): + """ + Custom record type for Cloudflare zone. + + Supports updating the zone plan. + """ + + _type = 'CF_ZONE' + _value_type = dict + + FREE_PLAN = 'free' + + def __init__(self, zone, name, data, *args, **kwargs): + super().__init__(zone, name, data, *args, **kwargs) + self.value = data['value'] + + @classmethod + def validate(cls, _name, fqdn, data): + value = data['value'] + if not isinstance(value, dict): + return [f'CF_ZONE value must be a dict, not {type(value)}'] + + if 'plan' not in value: + return ['CF_ZONE value must include "plan" key'] + + if not all(isinstance(v, str) for v in value.values()): + return ['CF_ZONE values must be strings'] + + return [] + + def _equality_tuple(self): + return (self.zone.name, self._type, self.name, self.value['plan']) + + def changes(self, other, target): + if not isinstance(other, CloudflareZoneRecord): + return True + return other.value != self.value + + def __repr__(self): + return f'CloudflareZoneRecord<{self.value}>' + + +# Register the custom record type +Record.register_type(CloudflareZoneRecord) diff --git a/tests/test_octodns_provider_cloudflare.py b/tests/test_octodns_provider_cloudflare.py index 5a90358..e6d270a 100644 --- a/tests/test_octodns_provider_cloudflare.py +++ b/tests/test_octodns_provider_cloudflare.py @@ -1005,7 +1005,7 @@ def test_pagerules(self): # Set things up to preexist/mock as necessary zone = Zone('unit.tests.', []) # Stuff a fake zone id in place - provider._zones = {zone.name: '42'} + provider._zones = {zone.name: {'id': '42'}} provider._request = Mock() side_effect = [ { @@ -2891,3 +2891,266 @@ def test_process_desired_zone(self): msg = str(ctx.exception) self.assertTrue('subber.unit.tests.' in msg) self.assertTrue('coresponding NS record' in msg) + + def test_plan_handling(self): + provider = CloudflareProvider( + 'test', 'email', 'token', 'account_id', plan_type='enterprise' + ) + provider._try_request = Mock() + + # Test 1: Creating new zone with plan_type + provider._try_request.side_effect = [ + { + # GET /zones response (empty) + 'result': [], + 'result_info': {'count': 0, 'per_page': 50}, + }, + { + # POST /zones response + 'result': {'id': '42', 'plan': {'legacy_id': 'enterprise'}}, + 'result_info': {'count': 1, 'per_page': 50}, + }, + ] + + zone = Zone('unit.tests.', []) + plan = Plan(zone, zone, [], True) + provider._apply(plan) + + provider._try_request.assert_has_calls( + [ + call( + 'GET', + '/zones', + params={ + 'page': 1, + 'per_page': 50, + 'account.id': 'account_id', + }, + ), + call( + 'POST', + '/zones', + data={ + 'name': 'unit.tests', + 'jump_start': False, + 'account': {'id': 'account_id'}, + 'plan': {'legacy_id': 'enterprise'}, + }, + ), + ] + ) + + self.assertEqual( + provider.zones['unit.tests.'], {'id': '42', 'plan': 'enterprise'} + ) + + # Reset for next test + provider._try_request.reset_mock() + + # Test 2: No plan update for new zone (plan is set during creation) + provider._try_request.side_effect = [ + { + # GET /zones response (empty) + 'result': [], + 'result_info': {'count': 0, 'per_page': 50}, + } + ] + + existing = Zone('unit.tests.', []) + desired = Zone('unit.tests.', []) + changes = [] + + extra = provider._extra_changes(existing, desired, changes) + self.assertEqual(0, len(extra)) # No extra changes for new zone + + # Test 3: Plan update via extra changes and apply + provider._zones = {'unit.tests.': {'id': '42', 'plan': 'pro'}} + + extra = provider._extra_changes(existing, desired, changes) + self.assertEqual(1, len(extra)) + self.assertIsInstance(extra[0], Update) + self.assertEqual('CF_ZONE', extra[0].new._type) + self.assertEqual({'plan': 'enterprise'}, extra[0].new.value) + + # Test 4: Plan update fails when available plans can't be determined + provider._try_request.reset_mock() + provider._try_request.side_effect = [ + { + # GET /zones/42/available_plans returns no plans + 'result': [] + } + ] + + with self.assertRaises(SupportsException) as ctx: + provider.apply(Plan(existing, desired, changes + extra, True)) + self.assertEqual( + 'test: enterprise is not supported for unit.tests.', + str(ctx.exception), + ) + + # Test 5: Plan update fails when desired plan isn't available + provider._try_request.reset_mock() + provider._try_request.side_effect = [ + { + # GET /zones/42/available_plans returns only pro plan + 'result': [{'legacy_id': 'pro'}] + } + ] + + with self.assertRaises(SupportsException) as ctx: + provider.apply(Plan(existing, desired, changes + extra, True)) + self.assertEqual( + 'test: enterprise is not supported for unit.tests.', + str(ctx.exception), + ) + + # Test 6: Successful plan update + provider._try_request.reset_mock() + provider._try_request.side_effect = [ + { + # GET /zones/42/available_plans + 'result': [{'legacy_id': 'pro'}, {'legacy_id': 'enterprise'}] + }, + { + # PATCH /zones/42 (plan update) + 'result': {'plan': {'legacy_id': 'enterprise'}} + }, + ] + + provider.apply(Plan(existing, desired, changes + extra, True)) + + provider._try_request.assert_has_calls( + [ + call('GET', '/zones/42/available_plans'), + call( + 'PATCH', + '/zones/42', + data={'plan': {'legacy_id': 'enterprise'}}, + ), + ] + ) + + # Test 7: No plan update when zone doesn't exist + provider._zones = {} + provider._try_request.reset_mock() + extra = provider._extra_changes(existing, desired, changes) + self.assertEqual(0, len(extra)) + + # Test 8: No plan update when plan_type is None + provider = CloudflareProvider( + 'test', 'email', 'token', 'account_id', plan_type=None + ) + provider._zones = {'unit.tests.': {'id': '42', 'plan': 'pro'}} + provider._try_request = Mock() + + extra = provider._extra_changes(existing, desired, changes) + self.assertEqual( + 0, len(extra) + ) # No extra changes when plan_type is None + provider._try_request.assert_not_called() # No API calls should be made + + # Test 9: No plan update when current plan matches desired plan + provider = CloudflareProvider( + 'test', 'email', 'token', 'account_id', plan_type='enterprise' + ) + provider._try_request = Mock() + provider._zones = {'unit.tests.': {'id': '42', 'plan': 'enterprise'}} + provider._update_plan('unit.tests.', 'enterprise') # Should do nothing + provider._try_request.assert_not_called() + + # Test 10: Regular record update (non-CF_ZONE) + provider = CloudflareProvider( + 'test', 'email', 'token', 'account_id', plan_type='enterprise' + ) + provider._try_request = Mock() + provider._zones = {'unit.tests.': {'id': '42', 'plan': 'pro'}} + + record = Record.new( + existing, 'test', {'type': 'A', 'ttl': 300, 'value': '1.2.3.4'} + ) + + # Mock the zone_records method + provider.zone_records = Mock( + return_value=[ + { + 'id': 'record-id', + 'type': 'A', + 'name': 'test.unit.tests', + 'content': '1.2.3.4', + 'ttl': 300, + 'proxied': False, + 'zone_id': '42', + } + ] + ) + + provider._apply_Update(Update(record, record)) + provider._try_request.assert_called_once_with( + 'PUT', + '/zones/42/dns_records/record-id', + data={ + 'content': '1.2.3.4', + 'name': 'test.unit.tests', + 'type': 'A', + 'ttl': 300, + 'proxied': False, + }, + ) + + # Test 11: Plan update with empty response + provider._try_request = Mock() + provider._try_request.side_effect = [{'result': None}] + provider._zones = {'unit.tests.': {'id': '42', 'plan': 'pro'}} + + with self.assertRaises(SupportsException) as ctx: + provider._update_plan('unit.tests.', 'enterprise') + self.assertEqual( + 'test: unable to determine supported plans, do you have an Enterprise account?', + str(ctx.exception), + ) + + # Test 12: Plan update with malformed response + provider._try_request.side_effect = [ + {'result': [{}]} # Missing legacy_id + ] + with self.assertRaises(SupportsException) as ctx: + provider._update_plan('unit.tests.', 'enterprise') + self.assertEqual( + 'test: unable to determine supported plans, do you have an Enterprise account?', + str(ctx.exception), + ) + + # Test 13: Create CF_ZONE record + provider = CloudflareProvider('test', 'email', 'token') + provider._update_plan = Mock() + + zone = Zone('unit.tests.', []) + new = Record.new( + zone, + '_plan_update', + {'type': 'CF_ZONE', 'ttl': 300, 'value': {'plan': 'pro'}}, + ) + + change = Create(new) + provider._apply_Create(change) + provider._update_plan.assert_called_once_with('unit.tests.', 'pro') + + # Test 14: Delete CF_ZONE record + provider = CloudflareProvider('test', 'email', 'token') + provider._update_plan = Mock() + + existing = Record.new( + zone, + '_plan_update', + { + 'type': 'CF_ZONE', + 'ttl': 300, + 'value': { + 'plan': 'pro' # Current plan doesn't matter for deletion + }, + }, + ) + + change = Delete(existing) + provider._apply_Delete(change) + provider._update_plan.assert_called_once_with('unit.tests.', 'free') diff --git a/tests/test_octodns_record_cloudflare.py b/tests/test_octodns_record_cloudflare.py new file mode 100644 index 0000000..74f3b98 --- /dev/null +++ b/tests/test_octodns_record_cloudflare.py @@ -0,0 +1,172 @@ +from unittest import TestCase + +from octodns.record import Record, ValidationError +from octodns.zone import Zone + +from octodns_cloudflare.record import CloudflareZoneRecord + + +class TestCloudflareZoneRecord(TestCase): + def test_cloudflare_zone_record(self): + # Test valid plan record creation + zone = Zone('unit.tests.', []) + + record = Record.new( + zone, + '_plan_update', + {'type': 'CF_ZONE', 'ttl': 300, 'value': {'plan': 'enterprise'}}, + ) + + self.assertIsInstance(record, CloudflareZoneRecord) + self.assertEqual('CF_ZONE', record._type) + self.assertEqual({'plan': 'enterprise'}, record.value) + + # Test validation errors + with self.assertRaises(ValidationError) as ctx: + Record.new( + zone, + '_plan_update', + { + 'type': 'CF_ZONE', + 'ttl': 300, + 'value': 'invalid', # Should be dict + }, + ) + self.assertTrue('_plan_update.unit.tests.' in str(ctx.exception)) + self.assertTrue('must be a dict' in str(ctx.exception)) + + with self.assertRaises(ValidationError) as ctx: + Record.new( + zone, + '_plan_update', + {'type': 'CF_ZONE', 'ttl': 300, 'value': {}}, # Missing 'plan' + ) + self.assertTrue('_plan_update.unit.tests.' in str(ctx.exception)) + self.assertTrue('must include "plan" key' in str(ctx.exception)) + + with self.assertRaises(ValidationError) as ctx: + Record.new( + zone, + '_plan_update', + { + 'type': 'CF_ZONE', + 'ttl': 300, + 'value': {'plan': 123}, # Invalid value type + }, + ) + self.assertTrue('_plan_update.unit.tests.' in str(ctx.exception)) + self.assertTrue('must be strings' in str(ctx.exception)) + + # Test equality comparison + record1 = Record.new( + zone, + '_plan_update', + {'type': 'CF_ZONE', 'ttl': 300, 'value': {'plan': 'enterprise'}}, + ) + + record2 = Record.new( + zone, + '_plan_update', + {'type': 'CF_ZONE', 'ttl': 300, 'value': {'plan': 'enterprise'}}, + ) + + record3 = Record.new( + zone, + '_plan_update', + { + 'type': 'CF_ZONE', + 'ttl': 300, + 'value': {'plan': 'business'}, # Different target plan + }, + ) + + self.assertEqual(record1, record2) + self.assertNotEqual(record1, record3) + + # Test comparison with non-CloudflareZoneRecord + other_record = Record.new( + zone, '_plan_update', {'type': 'A', 'ttl': 300, 'value': '1.2.3.4'} + ) + self.assertTrue(record1.changes(other_record, None)) + + def test_cloudflare_zone_record_repr(self): + zone = Zone('unit.tests.', []) + record = Record.new( + zone, + '_plan_update', + {'type': 'CF_ZONE', 'ttl': 300, 'value': {'plan': 'enterprise'}}, + ) + self.assertEqual( + "CloudflareZoneRecord<{'plan': 'enterprise'}>", repr(record) + ) + + def test_cloudflare_zone_record_changes(self): + zone = Zone('unit.tests.', []) + + # Test changes with non-CloudflareZoneRecord + record = Record.new( + zone, + '_plan_update', + {'type': 'CF_ZONE', 'ttl': 300, 'value': {'plan': 'enterprise'}}, + ) + other = Record.new( + zone, '_plan_update', {'type': 'A', 'ttl': 300, 'value': '1.2.3.4'} + ) + self.assertTrue(record.changes(other, None)) + + # Test changes with same type but different values + other = Record.new( + zone, + '_plan_update', + {'type': 'CF_ZONE', 'ttl': 300, 'value': {'plan': 'pro'}}, + ) + self.assertTrue(record.changes(other, None)) + + # Test no changes with identical records + other = Record.new( + zone, + '_plan_update', + {'type': 'CF_ZONE', 'ttl': 300, 'value': {'plan': 'enterprise'}}, + ) + self.assertFalse(record.changes(other, None)) + + def test_cloudflare_zone_record_equality(self): + zone = Zone('unit.tests.', []) + record1 = Record.new( + zone, + '_plan_update', + {'type': 'CF_ZONE', 'ttl': 300, 'value': {'plan': 'enterprise'}}, + ) + + # Same record + record2 = Record.new( + zone, + '_plan_update', + {'type': 'CF_ZONE', 'ttl': 300, 'value': {'plan': 'enterprise'}}, + ) + self.assertEqual(record1, record2) + + # Different plan + record3 = Record.new( + zone, + '_plan_update', + {'type': 'CF_ZONE', 'ttl': 300, 'value': {'plan': 'pro'}}, + ) + self.assertNotEqual(record1, record3) + + # Different name + record4 = Record.new( + zone, + 'other_update', + {'type': 'CF_ZONE', 'ttl': 300, 'value': {'plan': 'enterprise'}}, + ) + self.assertNotEqual(record1, record4) + + # Different zone + other_zone = Zone('other.tests.', []) + record5 = Record.new( + other_zone, + '_plan_update', + {'type': 'CF_ZONE', 'ttl': 300, 'value': {'plan': 'enterprise'}}, + ) + self.assertNotEqual(record1, record5)