Skip to content

Commit 526ba32

Browse files
committed
Merge branch 'hotfix/24.08.2' into develop
2 parents b3e7355 + beaafec commit 526ba32

File tree

1 file changed

+160
-0
lines changed

1 file changed

+160
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,160 @@
1+
import logging
2+
import json
3+
4+
from django.core.management.base import BaseCommand
5+
from django.db import transaction
6+
from google.cloud.storage.client import Client
7+
from google.oauth2.service_account import Credentials
8+
9+
from osf.models import AbstractNode
10+
from osf.utils.migrations import disable_auto_now_fields
11+
from addons.osfstorage.models import Region
12+
13+
# Module-level logger, named after this module, used by every helper below.
logger = logging.getLogger(__name__)
14+
15+
def _get_file_block_map(node):
    """Map each attached file id to the response blocks that reference it.

    Collects the ``registration_response_key`` of every 'file-input' block in
    the node's registration schema, then walks every schema response on the
    node and gathers the response blocks whose ``schema_key`` is one of them.

    Returns:
        dict: ``file_id -> [response_block, ...]``. A block is listed once
        per entry in its response that carries that file id.
    """
    file_input_keys = node.registration_schema.schema_blocks.filter(
        block_type='file-input'
    ).values_list('registration_response_key', flat=True)

    blocks_by_file_id = {}
    for schema_response in node.schema_responses.all():
        file_blocks = schema_response.response_blocks.filter(schema_key__in=file_input_keys)
        for response_block in file_blocks:
            for file_entry in response_block.response:
                # Start a new list on first sight of a file id, then append.
                blocks_by_file_id.setdefault(file_entry['file_id'], []).append(response_block)
    return blocks_by_file_id
28+
29+
def _update_blocks(file_block_map, original_id, cloned_id):
    """Rewrite stored file URLs so they reference the cloned file.

    For every block listed under ``original_id`` in ``file_block_map``, each
    response entry whose 'file_id' equals ``original_id`` gets
    ``original_id`` replaced by ``cloned_id`` in all of its 'file_urls'
    values. Entries for other files are kept untouched. The rebuilt response
    list is assigned back to the block and the block is saved.
    """
    for response_block in file_block_map[original_id]:
        logger.info(f'Updating block {response_block._id} file info')
        updated_entries = []
        for entry in response_block.response:
            if entry['file_id'] == original_id:
                # NOTE(review): only the URLs are rewritten; the entry's
                # 'file_id' keeps the original id -- confirm this is intended.
                urls = entry['file_urls']
                for url_name in urls.keys():
                    urls[url_name] = urls[url_name].replace(original_id, cloned_id)
            updated_entries.append(entry)
        response_block.response = updated_entries
        response_block.save()
40+
41+
def _update_schema_meta(node):
    """Refresh the node's legacy registration-response fields.

    Copies ``all_responses`` from the most recently created schema response
    onto ``node.registration_responses``, stores the expanded responses in
    ``node.registered_meta`` under the registration schema's ``_id``, and
    saves the node.
    """
    logger.info('Updating legacy schema information...')
    # QuerySet.latest() already orders descending on the named field; a '-'
    # prefix inverts that and would pick the *oldest* response instead.
    newest_response = node.schema_responses.latest('created')
    node.registration_responses = newest_response.all_responses
    node.registered_meta[node.registration_schema._id] = node.expand_registration_responses()
    node.save()
    logger.info('Updated legacy schema information.')
47+
48+
def _copy_and_clone_versions(original_file, cloned_file, src_bucket, dest_bucket, dest_bucket_name, dest_region):
    """Copy each version's blob to the destination bucket and clone its record.

    Walks the versions of ``original_file`` ordered by 'identifier'. For each
    one, the blob named by ``location['object']`` is copied from
    ``src_bucket`` to ``dest_bucket``. The version is then cloned with its
    location bucket set to ``dest_bucket_name`` and its region set to
    ``dest_region``, and attached to ``cloned_file`` via ``add_version``.
    The original creator and created/modified values are assigned to the
    clone and saved.

    Raises:
        RuntimeError: if a version's blob cannot be found in ``src_bucket``.
    """
    for v in original_file.versions.order_by('identifier'):
        blob_hash = v.location['object']
        logger.info(f'Preparing to move version {blob_hash}')
        # Copy each version to dest_bucket
        src_blob = src_bucket.get_blob(blob_hash)
        if src_blob is None:
            # get_blob() returns None for a missing object; fail with a clear
            # message instead of an AttributeError from inside copy_blob().
            raise RuntimeError(
                f'Blob {blob_hash} for file {original_file._id} not found in source bucket'
            )
        src_bucket.copy_blob(src_blob, dest_bucket)
        logger.info(f'Blob {blob_hash} copied to destination, cloning version object.')
        # Clone each version, update location
        cloned_v = v.clone()
        cloned_v.location['bucket'] = dest_bucket_name
        # Set FKs
        cloned_v.creator = v.creator
        cloned_v.region = dest_region
        # Save before M2M's can be set
        cloned_v.save()
        cloned_file.add_version(cloned_v)
        # Retain original timestamps
        cloned_v.created = v.created
        cloned_v.modified = v.modified
        cloned_v.save()
        logger.info(f'Version {blob_hash} cloned.')
70+
71+
def _clone_file(file_obj):
    """Create a saved clone of ``file_obj`` and repoint its guids to the clone.

    The clone receives the original's target, parent, checkout and
    copied_from relations, is saved, takes over the original's guids, and is
    saved again with the original created/modified values assigned.

    Returns:
        The saved clone.
    """
    # Clone so the original can later be purged from the source region.
    clone = file_obj.clone()
    # Carry the (generic) foreign keys over from the original.
    for relation in ('target', 'parent', 'checkout', 'copied_from'):
        setattr(clone, relation, getattr(file_obj, relation))
    # The first save must happen before the guids can reference the clone.
    clone.save()
    assert clone.id, f'Cloned file ID not assigned for {file_obj._id}'
    # Repoint the original's guids at the clone.
    file_obj.guids.update(object_id=clone.id)
    # Put the original timestamps back and persist them.
    clone.created = file_obj.created
    clone.modified = file_obj.modified
    clone.save()
    return clone
89+
90+
def change_node_region(node, dest_region, gcs_creds):
    """Move a node's osfstorage files into ``dest_region``.

    Builds a GCS client from ``gcs_creds`` (service-account info) and reads
    the bucket names from each region's waterbutler settings. Returns early
    with a warning when the node's current region is already
    ``dest_region``. Otherwise, inside one database transaction with
    auto-now fields disabled, every file on the node is cloned, its versions
    are copied and cloned, and the original is deleted. For registrations,
    response blocks that reference a moved file are updated and the legacy
    schema fields are refreshed. Finally the addon's region is set to
    ``dest_region``.
    """
    credentials = Credentials.from_service_account_info(gcs_creds)
    storage_client = Client(credentials=credentials)
    storage_addon = node.get_addon('osfstorage')
    src_region = storage_addon.region
    if src_region.id == dest_region.id:
        logger.warning(f'Source and destination regions match: {src_region._id}. Exiting.')
        return

    src_bucket_name = src_region.waterbutler_settings['storage']['bucket']
    dest_bucket_name = dest_region.waterbutler_settings['storage']['bucket']
    src_bucket = storage_client.get_bucket(src_bucket_name)
    dest_bucket = storage_client.get_bucket(dest_bucket_name)

    with transaction.atomic(), disable_auto_now_fields():
        # Only registrations have response blocks that reference files.
        blocks_by_file_id = _get_file_block_map(node) if node.type == 'osf.registration' else {}
        for original in node.files.all():
            logger.info(f'Prepraring to move file {original._id}')
            clone = _clone_file(original)
            if original._id in blocks_by_file_id:
                logger.info(f'Prepraring to update ResponseBlocks for file {original._id}')
                _update_blocks(blocks_by_file_id, original._id, clone._id)
            logger.info(f'File {original._id} cloned, copying versions...')
            _copy_and_clone_versions(
                original, clone, src_bucket, dest_bucket, dest_bucket_name, dest_region
            )
            # Trash original file
            original.delete()
        logger.info('All files complete.')
        if blocks_by_file_id:
            _update_schema_meta(node)
        storage_addon.region = dest_region
        storage_addon.save()
    logger.info('Region updated. Exiting.')
123+
124+
class Command(BaseCommand):
    """Management command that moves one node's osfstorage files to another region.

    Options:
        -n / --node: ``_id`` of the node to migrate.
        -r / --region: ``_id`` of the destination region.
        -c / --credentials: GCS service-account credentials as a JSON string.
    """

    def add_arguments(self, parser):
        """Register the node, region and credentials options."""
        super().add_arguments(parser)
        parser.add_argument(
            '-n',
            '--node',
            type=str,
            action='store',
            dest='node',
            help='Node._id to migrate.',
        )
        parser.add_argument(
            '-r',
            '--region',
            type=str,
            action='store',
            dest='region',
            help='Region._id to migrate files to.',
        )
        parser.add_argument(
            '-c',
            '--credentials',
            type=str,
            action='store',
            dest='gcs_creds',
            help='GCS Credentials to use. JSON string.',
        )

    def handle(self, *args, **options):
        """Load the node and region, parse the credentials, and run the move.

        Raises:
            AssertionError: if the node or region cannot be loaded, or no
                credentials were supplied.
        """
        node = AbstractNode.load(options.get('node', None))
        region = Region.load(options.get('region', None))
        # argparse stores None under 'gcs_creds' when -c is omitted, so a
        # .get() default never applies; fall back explicitly so the
        # 'Credentials required' check below fires instead of a TypeError.
        gcs_creds = json.loads(options.get('gcs_creds') or '{}')
        assert node, 'Node not found'
        assert region, 'Region not found'
        assert gcs_creds, 'Credentials required'
        change_node_region(node, region, gcs_creds)

0 commit comments

Comments
 (0)