Skip to content

Commit 13714a7

Browse files
committed
feat: filter restricted runs on APIs
1 parent 9a5b5aa commit 13714a7

30 files changed

+464
-64
lines changed

course_discovery/apps/api/serializers.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
from course_discovery.apps.api.fields import (
3232
HtmlField, ImageField, SlugRelatedFieldWithReadSerializer, SlugRelatedTranslatableField, StdImageSerializerField
3333
)
34-
from course_discovery.apps.api.utils import StudioAPI
34+
from course_discovery.apps.api.utils import StudioAPI, get_excluded_restriction_types
3535
from course_discovery.apps.catalogs.models import Catalog
3636
from course_discovery.apps.core.api_client.lms import LMSAPIClient
3737
from course_discovery.apps.core.utils import update_instance
@@ -1638,8 +1638,10 @@ class CourseWithRecommendationsSerializer(FlexFieldsSerializerMixin, TimestampMo
16381638
recommendations = serializers.SerializerMethodField()
16391639

16401640
def get_recommendations(self, course):
1641+
excluded_restriction_types = get_excluded_restriction_types(self.context['request'])
1642+
recommended_courses = course.recommendations(excluded_restriction_types=excluded_restriction_types)
16411643
return CourseRecommendationSerializer(
1642-
course.recommendations(),
1644+
recommended_courses,
16431645
many=True,
16441646
context={
16451647
'request': self.context.get('request'),
@@ -1996,7 +1998,7 @@ def get_organization_logo_override_url(self, obj):
19961998
return None
19971999

19982000
@classmethod
1999-
def prefetch_queryset(cls, partner, queryset=None):
2001+
def prefetch_queryset(cls, partner, queryset=None, course_runs=None):
20002002
# Explicitly check if the queryset is None before selecting related
20012003
queryset = queryset if queryset is not None else Program.objects.filter(partner=partner)
20022004

@@ -2020,7 +2022,7 @@ def prefetch_queryset(cls, partner, queryset=None):
20202022
'degree__rankings',
20212023
'degree__quick_facts',
20222024
'labels',
2023-
Prefetch('courses', queryset=MinimalProgramCourseSerializer.prefetch_queryset()),
2025+
Prefetch('courses', queryset=MinimalProgramCourseSerializer.prefetch_queryset(course_runs=course_runs)),
20242026
Prefetch('authoring_organizations', queryset=OrganizationSerializer.prefetch_queryset(partner)),
20252027
)
20262028

@@ -2165,8 +2167,8 @@ class MinimalExtendedProgramSerializer(MinimalProgramSerializer):
21652167
expected_learning_items = serializers.SlugRelatedField(many=True, read_only=True, slug_field='value')
21662168

21672169
@classmethod
2168-
def prefetch_queryset(cls, partner, queryset=None):
2169-
queryset = super().prefetch_queryset(partner=partner, queryset=queryset)
2170+
def prefetch_queryset(cls, partner, queryset=None, course_runs=None):
2171+
queryset = super().prefetch_queryset(partner=partner, queryset=queryset, course_runs=course_runs)
21702172

21712173
return queryset.prefetch_related(
21722174
'expected_learning_items',
@@ -2209,7 +2211,7 @@ class ProgramSerializer(MinimalProgramSerializer):
22092211
product_source = SourceSerializer(required=False, read_only=True)
22102212

22112213
@classmethod
2212-
def prefetch_queryset(cls, partner, queryset=None):
2214+
def prefetch_queryset(cls, partner, queryset=None, course_runs=None):
22132215
"""
22142216
Prefetch the related objects that will be serialized with a `Program`.
22152217
@@ -2255,7 +2257,7 @@ def prefetch_queryset(cls, partner, queryset=None):
22552257
'instructor_ordering',
22562258
# We need the full Course prefetch here to get CourseRun information that methods on the Program
22572259
# model iterate across (e.g. language). These fields aren't prefetched by the minimal Course serializer.
2258-
Prefetch('courses', queryset=CourseSerializer.prefetch_queryset(partner=partner)),
2260+
Prefetch('courses', queryset=CourseSerializer.prefetch_queryset(partner=partner, course_runs=course_runs)),
22592261
Prefetch('authoring_organizations', queryset=OrganizationSerializer.prefetch_queryset(partner)),
22602262
Prefetch('credit_backing_organizations', queryset=OrganizationSerializer.prefetch_queryset(partner)),
22612263
Prefetch('corporate_endorsements', queryset=CorporateEndorsementSerializer.prefetch_queryset()),
@@ -2302,11 +2304,13 @@ class PathwaySerializer(BaseModelSerializer):
23022304
course_run_statuses = serializers.ReadOnlyField()
23032305

23042306
@classmethod
2305-
def prefetch_queryset(cls, partner):
2307+
def prefetch_queryset(cls, partner, course_runs=None):
23062308
queryset = Pathway.objects.filter(partner=partner)
23072309

23082310
return queryset.prefetch_related(
2309-
Prefetch('programs', queryset=MinimalProgramSerializer.prefetch_queryset(partner=partner)),
2311+
Prefetch('programs', queryset=MinimalProgramSerializer.prefetch_queryset(
2312+
partner=partner, course_runs=course_runs
2313+
)),
23102314
)
23112315

23122316
class Meta:

course_discovery/apps/api/tests/test_serializers.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2347,7 +2347,9 @@ def test_detail_fields_in_response(self, is_post_request):
23472347
'staff': MinimalPersonSerializer(course_run.staff, many=True,
23482348
context={'request': request}).data,
23492349
'content_language': course_run.language.code if course_run.language else None,
2350-
2350+
'restriction_type': (
2351+
course_run.restricted_run.restriction_type if hasattr(course_run, 'restricted_run') else None
2352+
)
23512353
}],
23522354
'uuid': str(course.uuid),
23532355
'subjects': [subject.name for subject in course.subjects.all()],
@@ -2418,6 +2420,9 @@ def get_expected_data(cls, course, course_run, course_skill, seat):
24182420
'estimated_hours': get_course_run_estimated_hours(course_run),
24192421
'first_enrollable_paid_seat_price': course_run.first_enrollable_paid_seat_price or 0.0,
24202422
'is_enrollable': course_run.is_enrollable,
2423+
'restriction_type': (
2424+
course_run.restricted_run.restriction_type if hasattr(course_run, 'restricted_run') else None
2425+
)
24212426
}],
24222427
'uuid': str(course.uuid),
24232428
'subjects': [subject.name for subject in course.subjects.all()],
@@ -2549,6 +2554,9 @@ def get_expected_data(cls, course_run, course_skill, request):
25492554
'first_enrollable_paid_seat_sku': course_run.first_enrollable_paid_seat_sku(),
25502555
'first_enrollable_paid_seat_price': course_run.first_enrollable_paid_seat_price,
25512556
'is_enrollable': course_run.is_enrollable,
2557+
'restriction_type': (
2558+
course_run.restricted_run.restriction_type if hasattr(course_run, 'restricted_run') else None
2559+
)
25522560
}
25532561

25542562

@@ -2751,7 +2759,8 @@ def get_expected_data(cls, learner_pathway, request):
27512759
'visible_via_association': True,
27522760
'steps': LearnerPathwayStepSerializer(
27532761
learner_pathway.steps.all(),
2754-
many=True
2762+
many=True,
2763+
context={'request': request}
27552764
).data,
27562765
'created': serialize_datetime(learner_pathway.created),
27572766
}

course_discovery/apps/api/utils.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
from course_discovery.apps.core.api_client.lms import LMSAPIClient
1515
from course_discovery.apps.core.utils import serialize_datetime
16+
from course_discovery.apps.course_metadata.choices import CourseRunRestrictionType
1617
from course_discovery.apps.course_metadata.models import CourseRun
1718

1819
logger = logging.getLogger(__name__)
@@ -199,6 +200,11 @@ def increment_character(character):
199200
return chr(ord(character) + 1) if character != 'z' else 'a'
200201

201202

203+
def get_excluded_restriction_types(request):
204+
include_restricted = request.query_params.get('include_restricted', '').split(',')
205+
return list(set(CourseRunRestrictionType.values) - set(include_restricted))
206+
207+
202208
class StudioAPI:
203209
"""
204210
A convenience class for talking to the Studio API - designed to allow subclassing by the publisher django app,

course_discovery/apps/api/v1/tests/test_views/test_catalogs.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from course_discovery.apps.course_metadata.choices import CourseRunStatus
2222
from course_discovery.apps.course_metadata.models import Course, CourseType
2323
from course_discovery.apps.course_metadata.tests.factories import (
24-
CourseRunFactory, SeatFactory, SeatTypeFactory, SubjectFactory
24+
CourseRunFactory, RestrictedCourseRunFactory, SeatFactory, SeatTypeFactory, SubjectFactory
2525
)
2626
from course_discovery.conftest import get_course_run_states
2727

@@ -335,6 +335,33 @@ def test_courses(self, state):
335335
assert response.status_code == 200
336336
assert response.data['results'] == []
337337

338+
@ddt.data([True, 2], [False, 1])
339+
@ddt.unpack
340+
def test_courses_with_restricted_runs(self, include_restriction_param, expected_result_count):
341+
url = reverse('api:v1:catalog-courses', kwargs={'id': self.catalog.id})
342+
Course.objects.all().delete()
343+
344+
now = datetime.datetime.now(pytz.UTC)
345+
future = now + datetime.timedelta(days=30)
346+
course_run = CourseRunFactory.create(
347+
course__title='ABC Test Course With Archived', end=future, enrollment_end=future
348+
)
349+
restricted_course_run = CourseRunFactory.create(
350+
course=course_run.course,
351+
course__title='ABC Test Course With Archived', end=future, enrollment_end=future,
352+
status=CourseRunStatus.Published
353+
)
354+
RestrictedCourseRunFactory(course_run=restricted_course_run, restriction_type='custom-b2b-enterprise')
355+
SeatFactory.create(course_run=course_run)
356+
SeatFactory.create(course_run=restricted_course_run)
357+
358+
if include_restriction_param:
359+
url += '?include_restricted=custom-b2b-enterprise'
360+
361+
response = self.client.get(url)
362+
assert response.status_code == 200
363+
assert len(response.data['results'][0]['course_runs']) == expected_result_count
364+
338365
def test_courses_with_include_archived(self):
339366
"""
340367
Verify the endpoint returns the list of available and archived courses if include archived

course_discovery/apps/api/v1/tests/test_views/test_course_runs.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import pytz
1010
import responses
1111
from django.contrib.auth.models import Group
12+
from django.core.management import call_command
1213
from django.db.models.functions import Lower
1314
from django.db.models.signals import pre_save
1415
from django.test import override_settings
@@ -1211,6 +1212,43 @@ def test_list_sorted_by_course_start_date(self):
12111212
self.serialize_course_run(CourseRun.objects.all().order_by('start'), many=True)
12121213
)
12131214

1215+
@ddt.data(True, False)
1216+
def test_list_include_restricted(self, include_restriction_param):
1217+
restricted_run = CourseRunFactory(course__partner=self.partner)
1218+
RestrictedCourseRunFactory(course_run=restricted_run, restriction_type='custom-b2c')
1219+
url = reverse('api:v1:course_run-list')
1220+
if include_restriction_param:
1221+
url += '?include_restricted=custom-b2c'
1222+
1223+
with self.assertNumQueries(14, threshold=3):
1224+
response = self.client.get(url)
1225+
1226+
assert response.status_code == 200
1227+
retrieved_keys = [r['key'] for r in response.data['results']]
1228+
if include_restriction_param:
1229+
assert restricted_run.key in retrieved_keys
1230+
else:
1231+
assert restricted_run.key not in retrieved_keys
1232+
1233+
@ddt.data([True, 4], [False, 3])
1234+
@ddt.unpack
1235+
def test_list_query_include_restricted(self, include_restriction_param, expected_result_count):
1236+
CourseRunFactory.create_batch(3, title='Some cool title', course__partner=self.partner)
1237+
CourseRunFactory(title='non-cool title')
1238+
restricted_run = CourseRunFactory(title='Some cool title', course__partner=self.partner)
1239+
RestrictedCourseRunFactory(course_run=restricted_run, restriction_type='custom-b2c')
1240+
query = 'title:Some cool title'
1241+
url = '{root}?q={query}'.format(root=reverse('api:v1:course_run-list'), query=query)
1242+
if include_restriction_param:
1243+
url += '&include_restricted=custom-b2c,custom-b2b-enterprise'
1244+
1245+
call_command('search_index', '--rebuild', '-f')
1246+
1247+
with self.assertNumQueries(30, threshold=3):
1248+
response = self.client.get(url)
1249+
1250+
assert len(response.data['results']) == expected_result_count
1251+
12141252
def test_list_query(self):
12151253
""" Verify the endpoint returns a filtered list of courses """
12161254
course_runs = CourseRunFactory.create_batch(3, title='Some random title', course__partner=self.partner)

course_discovery/apps/api/v1/tests/test_views/test_courses.py

Lines changed: 63 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
from course_discovery.apps.course_metadata.tests.factories import (
3636
CourseEditorFactory, CourseEntitlementFactory, CourseFactory, CourseLocationRestrictionFactory, CourseRunFactory,
3737
CourseTypeFactory, GeoLocationFactory, LevelTypeFactory, OrganizationFactory, ProductValueFactory, ProgramFactory,
38-
SeatFactory, SeatTypeFactory, SourceFactory, SubjectFactory
38+
RestrictedCourseRunFactory, SeatFactory, SeatTypeFactory, SourceFactory, SubjectFactory
3939
)
4040
from course_discovery.apps.course_metadata.toggles import IS_SUBDIRECTORY_SLUG_FORMAT_ENABLED
4141
from course_discovery.apps.course_metadata.utils import data_modified_timestamp_update, ensure_draft_world
@@ -278,6 +278,68 @@ def test_course_runs_are_ordered(self):
278278
self.assertListEqual(response.data['course_run_keys'], expected_keys)
279279
self.assertListEqual([run['key'] for run in response.data['course_runs']], expected_keys)
280280

281+
@ddt.data(True, False)
282+
def test_course_runs_restriction(self, include_restriction_param):
283+
run_restricted = CourseRunFactory(
284+
course=self.course,
285+
start=datetime.datetime(2033, 1, 1, tzinfo=pytz.UTC),
286+
status=CourseRunStatus.Published
287+
)
288+
run_not_restricted = CourseRunFactory(
289+
course=self.course,
290+
start=datetime.datetime(2033, 1, 1, tzinfo=pytz.UTC),
291+
status=CourseRunStatus.Unpublished
292+
)
293+
RestrictedCourseRunFactory(course_run=run_restricted, restriction_type='custom-b2c')
294+
SeatFactory(course_run=run_restricted)
295+
SeatFactory(course_run=run_not_restricted)
296+
297+
url = reverse('api:v1:course-detail', kwargs={'key': self.course.key})
298+
if include_restriction_param:
299+
url += '?include_restricted=custom-b2c'
300+
with self.assertNumQueries(36, threshold=3):
301+
response = self.client.get(url)
302+
assert response.status_code == 200
303+
304+
if not include_restriction_param:
305+
self.assertEqual(response.data['course_run_keys'], [run_not_restricted.key])
306+
self.assertEqual(response.data['course_run_statuses'], [run_not_restricted.status])
307+
self.assertEqual(len(response.data['course_runs']), 1)
308+
self.assertEqual(response.data['advertised_course_run_uuid'], None)
309+
else:
310+
self.assertEqual(set(response.data['course_run_keys']), {run_not_restricted.key, run_restricted.key})
311+
self.assertEqual(
312+
set(response.data['course_run_statuses']),
313+
{run_not_restricted.status, run_restricted.status}
314+
)
315+
self.assertEqual(len(response.data['course_runs']), 2)
316+
self.assertEqual(response.data['advertised_course_run_uuid'], run_restricted.uuid)
317+
318+
def test_course_runs_restriction_param(self):
319+
run_restricted = CourseRunFactory(
320+
course=self.course,
321+
start=datetime.datetime(2033, 1, 1, tzinfo=pytz.UTC),
322+
status=CourseRunStatus.Published
323+
)
324+
run_not_restricted = CourseRunFactory(
325+
course=self.course,
326+
start=datetime.datetime(2033, 1, 1, tzinfo=pytz.UTC),
327+
status=CourseRunStatus.Unpublished
328+
)
329+
RestrictedCourseRunFactory(course_run=run_restricted, restriction_type='custom-b2c')
330+
SeatFactory(course_run=run_restricted)
331+
332+
url = reverse('api:v1:course-detail', kwargs={'key': self.course.key})
333+
url += '?include_restricted=custom-b2c'
334+
with self.assertNumQueries(36, threshold=3):
335+
response = self.client.get(url)
336+
assert response.status_code == 200
337+
338+
self.assertEqual(set(response.data['course_run_keys']), {run_not_restricted.key, run_restricted.key})
339+
self.assertEqual(set(response.data['course_run_statuses']), {run_not_restricted.status, run_restricted.status})
340+
self.assertEqual(len(response.data['course_runs']), 2)
341+
self.assertEqual(response.data['advertised_course_run_uuid'], run_restricted.uuid)
342+
281343
def test_list(self):
282344
""" Verify the endpoint returns a list of all courses. """
283345
url = reverse('api:v1:course-list')

course_discovery/apps/api/v1/tests/test_views/test_programs.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,13 @@
1313
from course_discovery.apps.api.v1.views.programs import ProgramViewSet
1414
from course_discovery.apps.core.tests.factories import USER_PASSWORD, UserFactory
1515
from course_discovery.apps.core.tests.helpers import make_image_file
16-
from course_discovery.apps.course_metadata.choices import ProgramStatus
16+
from course_discovery.apps.course_metadata.choices import CourseRunStatus, ProgramStatus
1717
from course_discovery.apps.course_metadata.models import CourseType, Program, ProgramType
1818
from course_discovery.apps.course_metadata.tests.factories import (
1919
CorporateEndorsementFactory, CourseFactory, CourseRunFactory, CurriculumCourseMembershipFactory, CurriculumFactory,
2020
CurriculumProgramMembershipFactory, DegreeAdditionalMetadataFactory, DegreeFactory, EndorsementFactory,
2121
ExpectedLearningItemFactory, JobOutlookItemFactory, OrganizationFactory, PersonFactory, ProgramFactory,
22-
ProgramTypeFactory, VideoFactory
22+
ProgramTypeFactory, RestrictedCourseRunFactory, VideoFactory
2323
)
2424

2525

@@ -48,13 +48,16 @@ def setup(self, client, django_assert_num_queries, partner):
4848
self.partner = partner
4949
self.request = request
5050

51-
def create_program(self, courses=None, program_type=None):
51+
def create_program(self, courses=None, program_type=None, include_restricted_run=False):
5252
organizations = [OrganizationFactory(partner=self.partner)]
5353
person = PersonFactory()
5454

5555
if courses is None:
5656
courses = [CourseFactory(partner=self.partner)]
57-
CourseRunFactory(course=courses[0], staff=[person])
57+
course_run = CourseRunFactory(course=courses[0], staff=[person])
58+
59+
if include_restricted_run:
60+
RestrictedCourseRunFactory(course_run=course_run, restriction_type='custom-b2c')
5861

5962
if program_type is None:
6063
program_type = ProgramTypeFactory()
@@ -216,6 +219,21 @@ def test_list(self):
216219

217220
self.assert_list_results(self.list_path, expected, 26)
218221

222+
@pytest.mark.parametrize("include_restriction_param", [True, False])
223+
def test_list_restricted_runs(self, include_restriction_param):
224+
self.create_program(include_restricted_run=True)
225+
query_param_string = "?include_restricted=custom-b2c" if include_restriction_param else ""
226+
resp = self.client.get(self.list_path + query_param_string)
227+
228+
if include_restriction_param:
229+
assert resp.data['results'][0]['courses'][0]['course_runs']
230+
assert resp.data['results'][0]['courses'][0]['course_run_statuses']
231+
assert resp.data['results'][0]['course_run_statuses'] == [CourseRunStatus.Published]
232+
else:
233+
assert not resp.data['results'][0]['courses'][0]['course_runs']
234+
assert not resp.data['results'][0]['courses'][0]['course_run_statuses']
235+
assert resp.data['results'][0]['course_run_statuses'] == []
236+
219237
def test_extended_query_param_fields(self):
220238
""" Verify that the `extended` query param will result in an extended amount of fields returned. """
221239
for _ in range(3):

0 commit comments

Comments
 (0)