diff --git a/payments/api.py b/payments/api.py index 8aded70e..f1dc154f 100644 --- a/payments/api.py +++ b/payments/api.py @@ -7,7 +7,7 @@ import reversion from django.conf import settings from django.contrib.auth import get_user_model -from django.core.exceptions import ObjectDoesNotExist, PermissionDenied, ValidationError +from django.core.exceptions import ObjectDoesNotExist, PermissionDenied from django.db import transaction from django.db.models import Q, QuerySet from django.urls import reverse @@ -176,92 +176,77 @@ def process_cybersource_payment_response( Returns: Order.state """ - if not PaymentGateway.validate_processor_response( settings.ECOMMERCE_DEFAULT_PAYMENT_GATEWAY, request ): - msg = "Could not validate response from the payment processor." - raise PermissionDenied(msg) + error_message = "Could not validate response from the payment processor." + raise PermissionDenied(error_message) processor_response = PaymentGateway.get_formatted_response( settings.ECOMMERCE_DEFAULT_PAYMENT_GATEWAY, request ) - - reason_code = processor_response.response_code + reason_code = ( + int(processor_response.response_code) + if ( + processor_response.response_code + and processor_response.response_code.isdigit() + ) + else None + ) transaction_id = processor_response.transaction_id - if reason_code and reason_code.isdigit(): - reason_code = int(reason_code) + + # Log transaction status + if reason_code is not None: message = ( - "Transaction was not successful. " - "Transaction ID:%s Reason Code:%d Message:%s" + f"Transaction ID: {transaction_id}, Reason Code: {reason_code}, " + f"Message: {processor_response.message}" ) if reason_code in CYBERSOURCE_ERROR_CODES: - # Log the errors as errors, so they make Sentry logs. - log.error(message, transaction_id, reason_code, processor_response.message) + log.error("Transaction was not successful. %s", message) elif reason_code not in CYBERSOURCE_ACCEPT_CODES: - # These may be declines or reviews - only log in debug mode. - log.debug(message, transaction_id, reason_code, processor_response.message) - - return_message = "" - - if processor_response.state == ProcessorResponse.STATE_DECLINED: - # Transaction declined for some reason - # This probably means the order needed to go through the process - # again so maybe tell the user to do a thing. - msg = f"Transaction declined: {processor_response.message}" - log.debug(msg) - order.decline() - return_message = order.state - elif processor_response.state == ProcessorResponse.STATE_ERROR: - # Error - something went wrong with the request - msg = f"Error happened submitting the transaction: {processor_response.message}" - log.debug(msg) - order.errored() - return_message = order.state - elif processor_response.state in [ - ProcessorResponse.STATE_CANCELLED, - ProcessorResponse.STATE_REVIEW, - ]: - # Transaction cancelled or reviewed - # Transaction could be cancelled for reasons that don't necessarily - # mean that the entire order is invalid, so we'll do nothing with - # the order here (other than set it to Cancelled). - # Transaction could be - msg = f"Transaction cancelled/reviewed: {processor_response.message}" - log.debug(msg) - order.cancel() - return_message = order.state - - elif ( - processor_response.state == ProcessorResponse.STATE_ACCEPTED - or reason_code == CYBERSOURCE_REASON_CODE_SUCCESS - ): - # It actually worked here - basket = Basket.objects.filter(user=order.purchaser).first() - try: - msg = f"Transaction accepted!: {processor_response.message}" - log.debug(msg) - fulfill_completed_order(order, request.POST, basket, source) - except ValidationError: - msg = ( - "Missing transaction id from transaction response: " - f"{processor_response.message}" - ) - log.debug(msg) - raise + log.debug("Transaction was not successful. %s", message) + + # Handle processor response states + state_handlers = { + ProcessorResponse.STATE_DECLINED: order.decline, + ProcessorResponse.STATE_ERROR: order.errored, + ProcessorResponse.STATE_CANCELLED: order.cancel, + ProcessorResponse.STATE_REVIEW: order.cancel, + ProcessorResponse.STATE_ACCEPTED: lambda: fulfill_completed_order( + order, + request.POST, + Basket.objects.filter(user=order.purchaser).first(), + source, + ), + } - return_message = order.state + if processor_response.state in state_handlers: + log.debug( + "Transaction %s: %s", + processor_response.state.lower(), + processor_response.message, + ) + state_handlers[processor_response.state]() + elif reason_code == CYBERSOURCE_REASON_CODE_SUCCESS: + log.debug("Transaction accepted!: %s", processor_response.message) + fulfill_completed_order( + order, + request.POST, + Basket.objects.filter(user=order.purchaser).first(), + source, + ) else: - msg = ( - f"Unknown state {processor_response.state} found: transaction ID" - f"{transaction_id}, reason code {reason_code}, response message" - f" {processor_response.message}" + log.error( + "Unknown state %s found: transaction ID %s, reason code %s, " + "response message %s", + processor_response.state, + transaction_id, + reason_code, + processor_response.message, ) - log.error(msg) order.cancel() - return_message = order.state - return return_message + return order.state def refund_order( @@ -282,161 +267,76 @@ def refund_order( tuple of (bool, str) : A boolean identifying if an order refund was successful, and the error message (if there is one) """ - refund_amount = kwargs.get("refund_amount") - refund_reason = kwargs.get("refund_reason", "") - message = "" - if reference_number is not None: - order = FulfilledOrder.objects.get(reference_number=reference_number) - elif order_id is not None: - order = FulfilledOrder.objects.get(pk=order_id) - else: - message = "Either order_id or reference_number is required to fetch the Order." - log.error(message) - return False, message + # Validate input + if not any([order_id, reference_number]): + log.error("Either order_id or reference_number is required to fetch the Order.") + return ( + False, + "Either order_id or reference_number is required to fetch the Order.", + ) + + # Fetch the order + try: + order = ( + FulfilledOrder.objects.get(reference_number=reference_number) + if reference_number + else FulfilledOrder.objects.get(pk=order_id) + ) + except FulfilledOrder.DoesNotExist: + log.exception( + "Order with %s %s not found.", + "reference_number" if reference_number else "order_id", + reference_number or order_id, + ) + raise + + # Validate order state if order.state != Order.STATE.FULFILLED: - message = f"Order with order_id {order.id} is not in fulfilled state." - log.error(message) - return False, message + log.error("Order with order_id %s is not in fulfilled state.", order.id) + return False, f"Order with order_id {order.id} is not in fulfilled state." + # Fetch the most recent transaction order_recent_transaction = order.transactions.first() - if not order_recent_transaction: - message = f"There is no associated transaction against order_id {order.id}" - log.error(message) - return False, message - - transaction_dict = order_recent_transaction.data + log.error("There is no associated transaction against order_id %s", order.id) + return False, f"There is no associated transaction against order_id {order.id}" - # Check for a PayPal payment - if there's one, we can't process it - if "paypal_token" in transaction_dict: + # Check for PayPal payment + if "paypal_token" in order_recent_transaction.data: msg = ( f"PayPal: Order {order.reference_number} contains a PayPal" "transaction. Please contact Finance to refund this order." ) raise PaypalRefundError(msg) - # The refund amount can be different then the payment amount, so we override - # that before PaymentGateway processing. - # e.g. While refunding order from Django Admin we can select custom amount. - if refund_amount: - transaction_dict["req_amount"] = refund_amount + # Prepare refund request + transaction_dict = order_recent_transaction.data + if "refund_amount" in kwargs: + transaction_dict["req_amount"] = kwargs["refund_amount"] + # Process refund refund_gateway_request = PaymentGateway.create_refund_request( settings.ECOMMERCE_DEFAULT_PAYMENT_GATEWAY, transaction_dict ) - response = PaymentGateway.start_refund( - settings.ECOMMERCE_DEFAULT_PAYMENT_GATEWAY, - refund_gateway_request, + settings.ECOMMERCE_DEFAULT_PAYMENT_GATEWAY, refund_gateway_request ) - if response.state in REFUND_SUCCESS_STATES: - # Record refund transaction with PaymentGateway's refund response - order.refund( - api_response_data=response.response_data, - amount=transaction_dict["req_amount"], - reason=refund_reason, - ) - else: + # Handle refund response + if response.state not in REFUND_SUCCESS_STATES: log.error( - "There was an error with the Refund API request %s", - response.message, + "There was an error with the Refund API request: %s", response.message ) - # PaymentGateway didn't raise an exception and instead gave a - # Response but the response status was not success so we manually - # rollback the transaction in this case. - msg = f"Payment gateway returned an error: {response.message}" - raise PaymentGatewayError(msg) - - return True, message - - -def check_and_process_pending_orders_for_resolution(refnos=None): - """ - Check pending orders for resolution. By default, this will pull all the - pending orders that are in the system. - - Args: - - refnos (list or None): check specific reference numbers - Returns: - - Tuple of counts: fulfilled count, cancelled count, error count - - """ - - gateway = PaymentGateway.get_gateway_class( - settings.ECOMMERCE_DEFAULT_PAYMENT_GATEWAY + error_message = f"Payment gateway returned an error: {response.message}" + raise PaymentGatewayError(error_message) + + # Record successful refund + order.refund( + api_response_data=response.response_data, + amount=transaction_dict["req_amount"], + reason=kwargs.get("refund_reason", ""), ) - - if refnos is not None: - pending_orders = PendingOrder.objects.filter( - state=PendingOrder.STATE.PENDING, reference_number__in=refnos - ).values_list("reference_number", flat=True) - else: - pending_orders = PendingOrder.objects.filter( - state=PendingOrder.STATE.PENDING - ).values_list("reference_number", flat=True) - - if len(pending_orders) == 0: - return (0, 0, 0) - - msg = f"Resolving {len(pending_orders)} orders" - log.info(msg) - - results = gateway.find_and_get_transactions(pending_orders) - - if len(results.keys()) == 0: - msg = "No orders found to resolve." - log.info(msg) - return (0, 0, 0) - - fulfilled_count = cancel_count = error_count = 0 - - for result in results: - payload = results[result] - if int(payload["reason_code"]) == CYBERSOURCE_REASON_CODE_SUCCESS: - try: - order = PendingOrder.objects.filter( - state=PendingOrder.STATE.PENDING, - reference_number=payload["req_reference_number"], - ).get() - - order.fulfill(payload) - fulfilled_count += 1 - - msg = f"Fulfilled order {order.reference_number}." - log.info(msg) - except Exception as e: - msg = ( - "Couldn't process pending order for fulfillment " - f"{payload['req_reference_number']}: {e!s}" - ) - log.exception(msg) - error_count += 1 - else: - try: - order = PendingOrder.objects.filter( - state=PendingOrder.STATE.PENDING, - reference_number=payload["req_reference_number"], - ).get() - - order.cancel() - order.transactions.create( - transaction_id=payload["transaction_id"], - amount=order.total_price_paid, - data=payload, - reason=f"Cancelled due to processor code {payload['reason_code']}", - ) - order.save() - cancel_count += 1 - - msg = f"Cancelled order {order.reference_number}." - log.info(msg) - except Exception: - msg = "Couldn't process pending order for cancellation %s" - log.exception(msg, payload["req_reference_number"]) - error_count += 1 - - return (fulfilled_count, cancel_count, error_count) + return True, "" def send_post_sale_webhook(system_id, order_id, source): @@ -450,27 +350,23 @@ def send_post_sale_webhook(system_id, order_id, source): system = IntegratedSystem.objects.get(pk=system_id) system_webhook_url = system.webhook_url - if system_webhook_url: - log.info( - ( - "send_post_sale_webhook: Calling webhook endpoint %s for order %s " - "with source %s" - ), - system_webhook_url, - order.reference_number, - source, - ) - else: + if not system.webhook_url: log.warning( - ( - "send_post_sale_webhook: No webhook URL set for system %s, skipping" - "for order %s" - ), + "send_post_sale_webhook: No webhook URL set for system %s, " + "skipping for order %s", system.slug, order.reference_number, ) return + log.info( + "send_post_sale_webhook: Calling webhook endpoint %s for order %s " + "with source %s", + system.webhook_url, + order.reference_number, + source, + ) + order_info = WebhookOrder( order=order, lines=[ @@ -504,13 +400,12 @@ def process_post_sale_webhooks(order_id, source): pk=order_id ) - systems = [ - product.system - for product in [ - line.product_version._object_version.object # noqa: SLF001 - for line in order.lines.all() - ] - ] + # Extract unique systems from the order lines + systems = { + line.product_version._object_version.object.system # noqa: SLF001 + for line in order.lines.all() + if line.product_version and line.product_version._object_version # noqa: SLF001 + } for system in systems: if not system.webhook_url: @@ -579,20 +474,183 @@ def get_auto_apply_discounts_for_basket(basket_id: int) -> QuerySet[Discount]: QuerySet: The auto-apply discounts that can be applied to the basket. """ basket = Basket.objects.get(pk=basket_id) - return ( - Discount.objects.filter( - Q(product__in=basket.get_products()) | Q(product__isnull=True) + products = basket.get_products() + + return Discount.objects.filter( + Q(product__in=products) | Q(product__isnull=True), + Q(integrated_system=basket.integrated_system) + | Q(integrated_system__isnull=True), + Q(assigned_users=basket.user) | Q(assigned_users__isnull=True), + automatic=True, + ) + + +def validate_discount_type(discount_type): + """ + Validate the discount type. + + Args: + discount_type (str): The discount type to validate. + + Raises: + ValueError: If the discount type is not valid. + """ + if discount_type not in ALL_DISCOUNT_TYPES: + error_message = f"Invalid discount type: {discount_type}." + raise ValueError(error_message) + + +def validate_payment_type(payment_type): + """ + Validate the payment type. + + Args: + payment_type (str): The payment type to validate. + + Raises: + ValueError: If the payment type is not valid. + """ + if payment_type not in ALL_PAYMENT_TYPES: + error_message = f"Payment type {payment_type} is not valid." + raise ValueError(error_message) + + +def validate_percent_off_amount(discount_type, amount): + """ + Validate the percent off amount. + + Args: + discount_type (str): discount type. + amount (int): discount amount. + + Raises: + ValueError: If the discount amount is not valid for the discount type. + """ + MAX_PERCENT_OFF_AMOUNT = 100 + if discount_type == DISCOUNT_TYPE_PERCENT_OFF and amount > MAX_PERCENT_OFF_AMOUNT: + error_message = ( + f"Discount amount {amount} not valid for discount type " + f"{DISCOUNT_TYPE_PERCENT_OFF}." ) - .filter( - Q(integrated_system=basket.integrated_system) - | Q(integrated_system__isnull=True) + raise ValueError(error_message) + + +def validate_prefix_for_batch(count, prefix): + """ + Validate the prefix for a batch of discount codes. + + Args: + count (int): The number of codes to create. + prefix (str): The prefix to append to the codes. + + Raises: + ValueError: If the prefix is not valid for a batch of codes. + ValueError: If the prefix is too long. + """ + MAX_PREFIX_LENGTH = 63 + if count > 1 and not prefix: + error_message = "You must specify a prefix to create a batch of codes." + raise ValueError(error_message) + if prefix and len(prefix) > MAX_PREFIX_LENGTH: + message = ( + f"Prefix {prefix} is {len(prefix)} - prefixes must be " + "63 characters or less." ) - .filter(Q(assigned_users=basket.user) | Q(assigned_users__isnull=True)) - .filter(automatic=True) - ) + raise ValueError(message) + + +def generate_codes(count, prefix=None, codes=None): + """ + Generate a list of discount codes. + + Args: + count (int): The number of codes to create. + prefix (str, optional): The prefix to append to the codes. Defaults to None. + codes (str, optional): The codes to create. Defaults to None. + + Returns: + list(str): The generated codes. + """ + if count > 1: + return [f"{prefix}{uuid.uuid4()}" for _ in range(count)] + return [codes] + + +def get_redemption_type(kwargs): + """ + Get the redemption type. + + Args: + kwargs (): The keyword arguments passed to the function. + one_time, once_per_user, and redemption_type are the valid arguments. + + """ + if kwargs.get("one_time"): + return REDEMPTION_TYPE_ONE_TIME + if kwargs.get("once_per_user"): + return REDEMPTION_TYPE_ONE_TIME_PER_USER + if ( + "redemption_type" in kwargs + and kwargs["redemption_type"] in ALL_REDEMPTION_TYPES + ): + return kwargs["redemption_type"] + return REDEMPTION_TYPE_UNLIMITED + + +def get_object_or_raise(model, identifier, missing_msg): + """ + Get an object from the model, or raise an error if it doesn't exist. + + Args: + model (Model): The model to get the object from. + identifier (): The identifier of the object. + missing_msg (str): The message to raise if the object doesn't exist. + + Raises: + ValueError: If the object doesn't exist. + + Returns: + Model: The object from the model. + """ + try: + if isinstance(identifier, int) or ( + isinstance(identifier, str) and identifier.isdigit() + ): + return model.objects.get(pk=identifier) + return model.objects.get(slug=identifier) + except ObjectDoesNotExist as err: + raise ValueError(missing_msg) from err -def generate_discount_code(**kwargs): # noqa: C901, PLR0912, PLR0915 +def get_users(users): + """ + Get a list of users from the user identifiers. + + Args: + users (str): The user identifiers. + + Raises: + ValueError: If the user doesn't exist. + + Returns: + User: The list of users. + """ + user_list = [] + for user_identifier in users: + try: + if isinstance(user_identifier, int) or ( + isinstance(user_identifier, str) and user_identifier.isdigit() + ): + user_list.append(User.objects.get(pk=user_identifier)) + else: + user_list.append(User.objects.get(email=user_identifier)) + except ObjectDoesNotExist: + error_message = f"User {user_identifier} does not exist." + raise ValueError(error_message) from None + return user_list + + +def generate_discount_code(**kwargs): """ Generate a discount code (or a batch of discount codes) as specified by the arguments passed. @@ -623,165 +681,79 @@ def generate_discount_code(**kwargs): # noqa: C901, PLR0912, PLR0915 code, type, amount, expiration_date """ - codes_to_generate = [] - discount_type = kwargs["discount_type"] - redemption_type = REDEMPTION_TYPE_UNLIMITED - payment_type = kwargs["payment_type"] - amount = Decimal(kwargs["amount"]) - bulk_discount_collection = None - if kwargs["discount_type"] not in ALL_DISCOUNT_TYPES: - raise ValueError(f"Invalid discount type: {kwargs['discount_type']}.") # noqa: EM102, TRY003 + validate_discount_type(kwargs["discount_type"]) + validate_payment_type(kwargs["payment_type"]) + validate_percent_off_amount(kwargs["discount_type"], Decimal(kwargs["amount"])) + validate_prefix_for_batch(kwargs.get("count", 1), kwargs.get("prefix", "")) - if payment_type not in ALL_PAYMENT_TYPES: - raise ValueError(f"Payment type {payment_type} is not valid.") # noqa: EM102, TRY003 + codes_to_generate = generate_codes( + kwargs.get("count", 1), kwargs.get("prefix", ""), kwargs.get("codes", "") + ) + redemption_type = get_redemption_type(kwargs) + expiration_date = ( + parse_supplied_date(kwargs["expires"]) if kwargs.get("expires") else None + ) + activation_date = ( + parse_supplied_date(kwargs["activates"]) if kwargs.get("activates") else None + ) - if kwargs["discount_type"] == DISCOUNT_TYPE_PERCENT_OFF and amount > 100: # noqa: PLR2004 - message = ( - f"Discount amount {amount} not valid for discount type " - f"{DISCOUNT_TYPE_PERCENT_OFF}." + integrated_system = ( + get_object_or_raise( + IntegratedSystem, + kwargs["integrated_system"], + f"Integrated system {kwargs['integrated_system']} does not exist.", ) - raise ValueError(message) - - if kwargs["count"] > 1 and "prefix" not in kwargs: - raise ValueError("You must specify a prefix to create a batch of codes.") # noqa: EM101, TRY003 - - if kwargs["count"] > 1: - prefix = kwargs["prefix"] - if prefix: - # upped the discount code limit to 100 characters - this used to be 13 (50 - 37 for the UUID) # noqa: E501 - if len(prefix) > 63: # noqa: PLR2004 - raise ValueError( # noqa: TRY003 - f"Prefix {prefix} is {len(prefix)} - prefixes must be 63 characters or less." # noqa: E501, EM102 - ) - bulk_discount_collection, _ = BulkDiscountCollection.objects.get_or_create( - prefix=prefix - ) - - for i in range(kwargs["count"]): # noqa: B007 - generated_uuid = uuid.uuid4() - code = f"{prefix}{generated_uuid}" - - codes_to_generate.append(code) - else: - codes_to_generate = kwargs["codes"] - - if kwargs.get("one_time"): - redemption_type = REDEMPTION_TYPE_ONE_TIME - - if kwargs.get("once_per_user"): - redemption_type = REDEMPTION_TYPE_ONE_TIME_PER_USER - - if ( - "redemption_type" in kwargs - and kwargs["redemption_type"] in ALL_REDEMPTION_TYPES - ): - redemption_type = kwargs["redemption_type"] + if kwargs.get("integrated_system") + else None + ) - if "expires" in kwargs and kwargs["expires"] is not None: - expiration_date = parse_supplied_date(kwargs["expires"]) - else: - expiration_date = None + product = ( + get_object_or_raise( + Product, kwargs["product"], f"Product {kwargs['product']} does not exist." + ) + if kwargs.get("product") + else None + ) - if "activates" in kwargs and kwargs["activates"] is not None: - activation_date = parse_supplied_date(kwargs["activates"]) - else: - activation_date = None + users = get_users(kwargs["users"]) if kwargs.get("users") else None - if "integrated_system" in kwargs and kwargs["integrated_system"] is not None: - # Try to get the integrated system via ID or slug. Raise an exception if it doesn't exist. # noqa: E501 - # check if integrated_system is an integer or a slug - integrated_system_missing_msg = ( - f"Integrated system {kwargs['integrated_system']} does not exist." + company = ( + get_object_or_raise( + Company, kwargs["company"], f"Company {kwargs['company']} does not exist." ) - if kwargs["integrated_system"].isdigit(): - try: - integrated_system = IntegratedSystem.objects.get( - pk=kwargs["integrated_system"] - ) - except IntegratedSystem.DoesNotExist: - raise ValueError(integrated_system_missing_msg) # noqa: B904 - else: - try: - integrated_system = IntegratedSystem.objects.get( - slug=kwargs["integrated_system"] - ) - except IntegratedSystem.DoesNotExist: - raise ValueError(integrated_system_missing_msg) # noqa: B904 - else: - integrated_system = None - - if "product" in kwargs and kwargs["product"] is not None: - # Try to get the product via ID or SKU. Raise an exception if it doesn't exist. - product_missing_msg = f"Product {kwargs['product']} does not exist." - if kwargs["product"].isdigit(): - try: - product = Product.objects.get(pk=kwargs["product"]) - except Product.DoesNotExist: - raise ValueError(product_missing_msg) # noqa: B904 - else: - try: - product = Product.objects.get(sku=kwargs["product"]) - except Product.DoesNotExist: - raise ValueError(product_missing_msg) # noqa: B904 - else: - product = None - - if "users" in kwargs and kwargs["users"] is not None: - # Try to get the users via ID or email. Raise an exception if it doesn't exist. - users = [] - user_missing_msg = "User %s does not exist." - for user_identifier in kwargs["users"]: - if user_identifier.isdigit(): - try: - users.append(User.objects.get(pk=user_identifier)) - except User.DoesNotExist: - raise ValueError(user_missing_msg % user_identifier) # noqa: B904 - else: - try: - user = User.objects.get(email=user_identifier) - users.append(user) - except User.DoesNotExist: - raise ValueError(user_missing_msg % user_identifier) # noqa: B904 - else: - users = None - - if "company" in kwargs and kwargs["company"] is not None: - try: - company = Company.objects.get(pk=kwargs["company"]) - except Company.DoesNotExist: - error_message = f"Company {kwargs['company']} does not exist." - raise ValueError(error_message) from None - else: - company = None + if kwargs.get("company") + else None + ) - if "transaction_number" in kwargs and kwargs["transaction_number"] is not None: - transaction_number = kwargs["transaction_number"] - else: - transaction_number = None + transaction_number = kwargs.get("transaction_number", "") generated_codes = [] - - for code_to_generate in codes_to_generate: + for code in codes_to_generate: with reversion.create_revision(): discount = Discount.objects.create( - discount_type=discount_type, + discount_type=kwargs["discount_type"], redemption_type=redemption_type, - payment_type=payment_type, + payment_type=kwargs["payment_type"], expiration_date=expiration_date, activation_date=activation_date, - discount_code=code_to_generate, - amount=amount, + discount_code=code, + amount=Decimal(kwargs["amount"]), is_bulk=True, integrated_system=integrated_system, product=product, - bulk_discount_collection=bulk_discount_collection, + bulk_discount_collection=( + BulkDiscountCollection.objects.get_or_create( + prefix=kwargs.get("prefix") + )[0] + if kwargs.get("prefix") + else None + ), company=company, transaction_number=transaction_number, ) - if users: - discount.assigned_users.set(users) - - generated_codes.append(discount) + if users: + discount.assigned_users.set(users) + generated_codes.append(discount) return generated_codes @@ -814,31 +786,21 @@ def update_discount_codes(**kwargs): # noqa: C901, PLR0912, PLR0915 * Number of discounts updated """ - discount_codes_to_update = kwargs["discount_codes"] if kwargs.get("discount_type"): - if kwargs["discount_type"] not in ALL_DISCOUNT_TYPES: - error_message = f"Discount type {kwargs['discount_type']} is not valid." - raise ValueError(error_message) - else: - discount_type = kwargs["discount_type"] + validate_discount_type(kwargs["discount_type"]) + discount_type = kwargs["discount_type"] + if kwargs.get("amount"): + validate_percent_off_amount(discount_type, Decimal(kwargs["amount"])) else: discount_type = None if kwargs.get("payment_type"): - if kwargs["payment_type"] not in ALL_PAYMENT_TYPES: - error_message = f"Payment type {kwargs['payment_type']} is not valid." - raise ValueError(error_message) - else: - payment_type = kwargs["payment_type"] + validate_payment_type(kwargs["payment_type"]) + payment_type = kwargs["payment_type"] else: payment_type = None - if kwargs.get("one_time"): - redemption_type = REDEMPTION_TYPE_ONE_TIME - elif kwargs.get("one_time_per_user"): - redemption_type = REDEMPTION_TYPE_ONE_TIME_PER_USER - else: - redemption_type = REDEMPTION_TYPE_UNLIMITED + redemption_type = get_redemption_type(kwargs) amount = Decimal(kwargs["amount"]) if kwargs.get("amount") else None @@ -852,72 +814,28 @@ def update_discount_codes(**kwargs): # noqa: C901, PLR0912, PLR0915 else: expiration_date = None - if kwargs.get("integrated_system"): - # Try to get the integrated system via ID or slug. - # Raise an exception if it doesn't exist. - integrated_system_missing_msg = ( - f"Integrated system {kwargs['integrated_system']} does not exist." + integrated_system = ( + get_object_or_raise( + IntegratedSystem, + kwargs["integrated_system"], + f"Integrated system {kwargs['integrated_system']} does not exist.", ) - if kwargs["integrated_system"].isdigit(): - try: - integrated_system = IntegratedSystem.objects.get( - pk=kwargs["integrated_system"] - ) - except IntegratedSystem.DoesNotExist: - raise ValueError(integrated_system_missing_msg) # noqa: B904 - else: - try: - integrated_system = IntegratedSystem.objects.get( - slug=kwargs["integrated_system"] - ) - except IntegratedSystem.DoesNotExist: - raise ValueError(integrated_system_missing_msg) # noqa: B904 - else: - integrated_system = None + if kwargs.get("integrated_system") + else None + ) - if kwargs.get("product"): - if kwargs.get("clear_products"): - error_message = "Cannot clear and set products at the same time." - raise ValueError(error_message) - # Try to get the product via ID or SKU. Raise an exception if it doesn't exist. - product_missing_msg = f"Product {kwargs['product']} does not exist." - if kwargs["product"].isdigit(): - try: - product = Product.objects.get(pk=kwargs["product"]) - except Product.DoesNotExist: - raise ValueError(product_missing_msg) # noqa: B904 - else: - try: - product = Product.objects.get(sku=kwargs["product"]) - except Product.DoesNotExist: - raise ValueError(product_missing_msg) # noqa: B904 - else: - product = None + product = ( + get_object_or_raise( + Product, kwargs["product"], f"Product {kwargs['product']} does not exist." + ) + if kwargs.get("product") + else None + ) - if kwargs.get("users"): - if kwargs.get("clear_users"): - error_message = "Cannot clear and set users at the same time." - raise ValueError(error_message) - # Try to get the users via ID or email. Raise an exception if it doesn't exist. - users = [] - user_missing_msg = "User %s does not exist." - for user_identifier in kwargs["users"]: - if user_identifier.isdigit(): - try: - users.append(User.objects.get(pk=user_identifier)) - except User.DoesNotExist: - raise ValueError(user_missing_msg % user_identifier) # noqa: B904 - else: - try: - user = User.objects.get(email=user_identifier) - users.append(user) - except User.DoesNotExist: - raise ValueError(user_missing_msg % user) # noqa: B904 - else: - users = None + users = get_users(kwargs["users"]) if kwargs.get("users") else None if kwargs.get("prefix"): - prefix = kwargs["prefix"] + prefix = kwargs.get("prefix") bulk_discount_collection = BulkDiscountCollection.objects.filter( prefix=prefix ).first() @@ -928,6 +846,7 @@ def update_discount_codes(**kwargs): # noqa: C901, PLR0912, PLR0915 raise ValueError(error_message) discounts_to_update = bulk_discount_collection.discounts.all() else: + discount_codes_to_update = kwargs.get("discount_codes", []) discounts_to_update = Discount.objects.filter( discount_code__in=discount_codes_to_update ) @@ -1018,60 +937,49 @@ def locate_customer_for_basket(request, basket, basket_item): def check_blocked_countries(basket, basket_item): """ - Check to see if the product is blocked for this customer. - - We should have the customer's location stored, so now perform the check - to see if the product is blocked or not. If it is, we raise an exception - to stop the process. - - Try this one first so we can kick the user out if they're blocked. + Check if the product is blocked for the customer based on their location. + Raises ProductBlockedError if the product is blocked. Args: - - basket (Basket): the current basket - - basket_item (Product): the item to add to the basket - Returns: - - None + - basket (Basket): The current basket. + - basket_item (Product): The item to add to the basket. Raises: - - ProductBlockedError: if the customer is blocked + - ProductBlockedError: If the customer is blocked from purchasing the product. """ + log.debug("Checking blockages for user: %s", basket.user) - log.debug("check_blocked_countries: checking for blockages for %s", basket.user) - - blocked_qset = BlockedCountry.objects.filter( - country_code=basket.user_blockable_country_code - ).filter(Q(product__isnull=True) | Q(product=basket_item)) - - if blocked_qset.exists(): - log.debug("check_blocked_countries: user is blocked") - errmsg = "Product %s blocked from purchase in country %s" - raise ProductBlockedError( - errmsg, basket_item, basket.user_blockable_country_code + if ( + BlockedCountry.objects.filter( + country_code=basket.user_blockable_country_code, ) + .filter(Q(product__isnull=True) | Q(product=basket_item)) + .exists() + ): + log.debug("User is blocked from purchasing the product.") + message = ( + f"Product {basket_item} blocked in country " + f"{basket.user_blockable_country_code}" + ) + raise ProductBlockedError(message) def check_taxable(basket): """ - Check to see if the product is taxable for this customer. - - We don't consider particular items taxable or not but we may want to - change that in the future. (Maybe one day we'll sell gift cards or - something!) So, this really applies to the basket - if there's an - applicable rate, then we tack it on to the basket. + Check if the basket is taxable based on the user's country code. + If taxable, apply the tax rate to the basket. Args: - basket (Basket): the current basket Returns: - None """ - log.debug("check_taxable: checking for tax for %s", basket.user) - taxable_qset = TaxRate.objects.filter( + taxrate = TaxRate.objects.filter( country_code=basket.user_blockable_country_code - ) + ).first() - if taxable_qset.exists(): - taxrate = taxable_qset.first() + if taxrate: basket.tax_rate = taxrate basket.save() log.debug("check_taxable: charging the tax for %s", taxrate) diff --git a/payments/api_test.py b/payments/api_test.py index 32c26c3a..b5f5e818 100644 --- a/payments/api_test.py +++ b/payments/api_test.py @@ -2,36 +2,57 @@ import random import uuid +from datetime import UTC, datetime +from decimal import Decimal import pytest import reversion from CyberSource.rest import ApiException from django.conf import settings +from django.core.exceptions import ObjectDoesNotExist +from django.http import HttpRequest from django.urls import reverse from factory import Faker, fuzzy from mitol.payment_gateway.api import PaymentGateway, ProcessorResponse from reversion.models import Version from payments.api import ( - check_and_process_pending_orders_for_resolution, + check_blocked_countries, + check_taxable, generate_checkout_payload, + generate_discount_code, get_auto_apply_discounts_for_basket, + get_redemption_type, + get_users, + locate_customer_for_basket, process_cybersource_payment_response, process_post_sale_webhooks, refund_order, + send_post_sale_webhook, send_pre_sale_webhook, + update_discount_codes, ) from payments.constants import ( PAYMENT_HOOK_ACTION_POST_SALE, PAYMENT_HOOK_ACTION_PRE_SALE, ) -from payments.exceptions import PaymentGatewayError, PaypalRefundError +from payments.dataclasses import CustomerLocationMetadata +from payments.exceptions import ( + PaymentGatewayError, + PaypalRefundError, + ProductBlockedError, +) from payments.factories import ( BasketFactory, BasketItemFactory, + BlockedCountryFactory, + BulkDiscountCollectionFactory, + CompanyFactory, DiscountFactory, LineFactory, OrderFactory, + RedeemedDiscountFactory, + TaxRateFactory, TransactionFactory, ) from payments.models import ( @@ -52,9 +73,15 @@ from system_meta.factories import IntegratedSystemFactory, ProductFactory from system_meta.models import IntegratedSystem from unified_ecommerce.constants import ( + ALL_DISCOUNT_TYPES, + ALL_PAYMENT_TYPES, DISCOUNT_TYPE_DOLLARS_OFF, + DISCOUNT_TYPE_PERCENT_OFF, POST_SALE_SOURCE_BACKOFFICE, POST_SALE_SOURCE_REDIRECT, + REDEMPTION_TYPE_ONE_TIME, + REDEMPTION_TYPE_ONE_TIME_PER_USER, + REDEMPTION_TYPE_UNLIMITED, TRANSACTION_TYPE_PAYMENT, TRANSACTION_TYPE_REFUND, ) @@ -155,10 +182,10 @@ def _payment_gateway_settings(): def test_cybersource_refund_no_order(): - """Test that refund_order throws FulfilledOrder.DoesNotExist exception when the order doesn't exist""" + """Test that refund_order throws FulfilledOrder.DoesNotExist exception when the order doesn"t exist""" with pytest.raises(FulfilledOrder.DoesNotExist): - refund_order(order_id=1) # Caling refund with random Id + refund_order(order_id=1) # Calling refund with random Id def create_basket(user, products): @@ -501,103 +528,6 @@ def test_process_cybersource_payment_decline_response( order.refresh_from_db() -@pytest.mark.parametrize("test_type", [None, "fail", "empty"]) -def test_check_and_process_pending_orders_for_resolution(mocker, test_type): - """ - Tests the pending order check. test_type can be: - - None - there's an order and it was found - - fail - there's an order but the payment failed (failed status in CyberSource) - - empty - order isn't pending - """ - order = OrderFactory.create(state=Order.STATE.PENDING) - - test_payload = { - "utf8": "", - "message": "Request was processed successfully.", - "decision": "100", - "auth_code": "888888", - "auth_time": "2023-02-09T20:06:51Z", - "signature": "", - "req_amount": "999", - "req_locale": "en-us", - "auth_amount": "999", - "reason_code": "100", - "req_currency": "USD", - "auth_avs_code": "X", - "auth_response": "100", - "req_card_type": "", - "request_token": "", - "card_type_name": "", - "req_access_key": "", - "req_item_0_sku": "60-2", - "req_profile_id": "2BA30484-75E7-4C99-A7D4-8BD7ADE4552D", - "transaction_id": "6759732112426719104003", - "req_card_number": "", - "req_consumer_id": "8c6976e5b5410415bde908bd4dee15dfb167a9c873fc4bb8a81f6f2ab448a918", - "req_item_0_code": "60", - "req_item_0_name": "course-v1:edX+E2E-101+course", - "signed_date_time": "2023-02-09T20:06:51Z", - "auth_avs_code_raw": "I1", - "auth_trans_ref_no": "123456789619999", - "bill_trans_ref_no": "123456789619999", - "req_bill_to_email": "testlearner@odl.local", - "req_payment_method": "card", - "signed_field_names": "", - "req_bill_to_surname": "LEARNER", - "req_item_0_quantity": 1, - "req_line_item_count": 1, - "req_bill_to_forename": "TEST", - "req_card_expiry_date": "02-2025", - "req_reference_number": f"{order.reference_number}", - "req_transaction_type": "sale", - "req_transaction_uuid": "", - "req_item_0_tax_amount": "0", - "req_item_0_unit_price": "999", - "req_customer_ip_address": "172.19.0.8", - "req_bill_to_address_city": "Tallahasseeeeeeee", - "req_bill_to_address_line1": "555 123 Place", - "req_bill_to_address_state": "FL", - "req_merchant_defined_data1": "1", - "req_bill_to_address_country": "US", - "req_bill_to_address_postal_code": "81992", - "req_override_custom_cancel_page": "https://rc.mitxonline.mit.edu/checkout/result/", - "req_override_custom_receipt_page": "https://rc.mitxonline.mit.edu/checkout/result/", - "req_card_type_selection_indicator": "001", - } - - retval = {} - - if test_type == "fail": - test_payload["reason_code"] = "999" - - if test_type == "empty": - order.state = Order.STATE.CANCELED - order.save() - order.refresh_from_db() - - if test_type is None or test_type == "fail": - retval = {f"{order.reference_number}": test_payload} - - mocked_gateway_func = mocker.patch( - "mitol.payment_gateway.api.CyberSourcePaymentGateway.find_and_get_transactions", - return_value=retval, - ) - - (fulfilled, cancelled, errored) = check_and_process_pending_orders_for_resolution() - - if test_type == "empty": - assert not mocked_gateway_func.called - assert (fulfilled, cancelled, errored) == (0, 0, 0) - elif test_type == "fail": - order.refresh_from_db() - assert order.state == Order.STATE.CANCELED - assert (fulfilled, cancelled, errored) == (0, 1, 0) - else: - order.refresh_from_db() - assert order.state == Order.STATE.FULFILLED - assert (fulfilled, cancelled, errored) == (1, 0, 0) - - @pytest.mark.parametrize( "source", [POST_SALE_SOURCE_BACKOFFICE, POST_SALE_SOURCE_REDIRECT] ) @@ -824,3 +754,799 @@ def test_get_auto_apply_discount_for_basket_no_auto_discount_exists(): discount = get_auto_apply_discounts_for_basket(basket_item.basket.id) assert discount.count() == 0 + + +@pytest.mark.parametrize("source", ["backoffice", "redirect"]) +def test_send_post_sale_webhook_success(mocker, source): + """Test sending the post-sale webhook successfully.""" + + # Mock Order + order = OrderFactory(reference_number="ORDER123") + + order_user = order.purchaser + + # Mock IntegratedSystem + system = IntegratedSystemFactory( + webhook_url="https://example.com/webhook", + slug="system_slug", + api_key="test_api_key", + ) + + # Mock dispatch_webhook.delay + mocked_task = mocker.patch("payments.tasks.dispatch_webhook.delay") + + # Mock logger + mock_logger = mocker.patch("payments.api.log") + + # Execute + send_post_sale_webhook(system.id, order.id, source) + + # Assert + mock_logger.info.assert_called_once_with( + "send_post_sale_webhook: Calling webhook endpoint %s for order %s with source %s", + "https://example.com/webhook", + "ORDER123", + source, + ) + mocked_task.assert_called_once_with( + "https://example.com/webhook", + { + "system_key": "test_api_key", + "type": "postsale", + "user": { + "id": order_user.id, + "global_id": order_user.global_id, + "username": order_user.username, + "email": order_user.email, + "first_name": order_user.first_name, + "last_name": order_user.last_name, + "name": "", + }, + "data": { + "reference_number": "ORDER123", + "total_price_paid": "10.00", + "state": order.state, + "lines": [], + "order": order.id, + }, + "system_slug": "system_slug", + }, + ) + + +@pytest.mark.parametrize("source", ["backoffice", "redirect"]) +def test_send_post_sale_webhook_order_not_found(source): + """Test sending the post-sale webhook when the order does not exist.""" + + # Mock Order to raise ObjectDoesNotExist + order = OrderFactory(reference_number="ORDER123") + + # Mock IntegratedSystem + system = IntegratedSystemFactory() + + # Execute and Assert + with pytest.raises(ObjectDoesNotExist): + send_post_sale_webhook( + system.id, order.id + 1, source + ) # Use a non-existent order ID + + +@pytest.mark.parametrize("source", ["backoffice", "redirect"]) +def test_send_post_sale_webhook_system_not_found(source): + """Test sending the post-sale webhook when the system does not exist.""" + + # Mock Order + order = OrderFactory(reference_number="ORDER123") + + # Execute and Assert + with pytest.raises(ObjectDoesNotExist): + send_post_sale_webhook(999, order.id, source) # Use a non-existent system ID + + +def test_generate_discount_code_single(): + """ + Test generating a single discount code + """ + # Setup + test_user = UserFactory() + company = CompanyFactory() + product = ProductFactory() + integrated_system = IntegratedSystemFactory() + + # Test + codes = generate_discount_code( + discount_type=DISCOUNT_TYPE_PERCENT_OFF, + payment_type="credit_card", + amount=Decimal("20.00"), + codes="ABC123", + count=1, + users=[test_user.id], + company=company.id, + product=product.id, + integrated_system=integrated_system.slug, + activates="2023-01-01", + expires="2023-12-31", + ) + + # Assert + assert len(codes) == 1 + code = codes[0] + assert code.discount_type == DISCOUNT_TYPE_PERCENT_OFF + assert code.payment_type == "credit_card" + assert code.amount == Decimal("20.00") + assert code.discount_code == "ABC123" + assert code.expiration_date == datetime(2023, 12, 31, 0, 0, tzinfo=UTC) + assert code.activation_date == datetime(2023, 1, 1, 0, 0, tzinfo=UTC) + assert code.company == company + assert code.product == product + assert code.integrated_system == integrated_system + + +def test_generate_discount_code_batch(): + """ + Test generating a batch of discount codes + """ + # Setup + prefix = "BATCH" + + # Test + codes = generate_discount_code( + discount_type=DISCOUNT_TYPE_PERCENT_OFF, + payment_type="credit_card", + amount=Decimal("10.00"), + count=5, + prefix=prefix, + ) + + # Assert + assert len(codes) == 5 + for code in codes: + assert code.discount_type == DISCOUNT_TYPE_PERCENT_OFF + assert code.payment_type == "credit_card" + assert code.amount == Decimal("10.00") + assert code.discount_code.startswith(prefix) + + +def test_generate_discount_code_invalid_discount_type(): + """ + Test generating a discount code with an invalid discount type + """ + with pytest.raises(ValueError, match="Invalid discount type") as excinfo: + generate_discount_code( + discount_type="invalid_type", + payment_type="credit_card", + amount=Decimal("10.00"), + ) + assert "Invalid discount type" in str(excinfo.value) + + +def test_generate_discount_code_invalid_payment_type(): + """ + Test generating a discount code with an invalid payment type + """ + with pytest.raises( + ValueError, match="Payment type invalid_type is not valid" + ) as excinfo: + generate_discount_code( + discount_type=DISCOUNT_TYPE_PERCENT_OFF, + payment_type="invalid_type", + amount=Decimal("10.00"), + ) + assert "Payment type invalid_type is not valid" in str(excinfo.value) + + +def test_generate_discount_code_invalid_percent_amount(): + """ + Test generating a discount code with an invalid percent amount + """ + with pytest.raises( + ValueError, + match="Discount amount 150.00 not valid for discount type percent-off", + ) as excinfo: + generate_discount_code( + discount_type=DISCOUNT_TYPE_PERCENT_OFF, + payment_type="credit_card", + amount=Decimal("150.00"), + ) + assert "Discount amount 150.00 not valid for discount type percent-off" in str( + excinfo.value + ) + + +def test_generate_discount_code_missing_prefix_for_batch(): + """ + Test generating a batch of discount codes without a prefix + """ + with pytest.raises(ValueError) as excinfo: # noqa: PT011 + generate_discount_code( + discount_type=DISCOUNT_TYPE_PERCENT_OFF, + payment_type="credit_card", + amount=Decimal("10.00"), + count=2, + ) + assert "You must specify a prefix to create a batch of codes" in str(excinfo.value) + + +def test_generate_discount_code_prefix_too_long(): + """ + Test generating a discount code with a prefix that is too long + """ + with pytest.raises(ValueError) as excinfo: # noqa: PT011 + generate_discount_code( + discount_type=DISCOUNT_TYPE_PERCENT_OFF, + payment_type="credit_card", + amount=Decimal("10.00"), + count=2, + prefix="a" * 64, + ) + assert ( + "Prefix aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa is 64 - prefixes must be 63 characters or less" + in str(excinfo.value) + ) + + +def test_update_discount_codes_with_valid_discount_type(): + """ + Test updating discount codes with a valid discount type + """ + discount = DiscountFactory() + updated_count = update_discount_codes( + discount_codes=[discount.discount_code], discount_type=ALL_DISCOUNT_TYPES[0] + ) + assert updated_count == 1 + discount.refresh_from_db() + assert discount.discount_type == ALL_DISCOUNT_TYPES[0] + + +def test_update_discount_codes_with_invalid_discount_type(): + """ + Test updating discount codes with an invalid discount type + """ + discount = DiscountFactory() + with pytest.raises(ValueError, match="Invalid discount type") as excinfo: + update_discount_codes( + discount_codes=[discount.discount_code], discount_type="INVALID_TYPE" + ) + assert "Invalid discount type: INVALID_TYPE." in str(excinfo.value) + + +def test_update_discount_codes_with_valid_payment_type(): + """ + Test updating discount codes with a valid payment type + """ + discount = DiscountFactory() + updated_count = update_discount_codes( + discount_codes=[discount.discount_code], payment_type=ALL_PAYMENT_TYPES[0] + ) + assert updated_count == 1 + discount.refresh_from_db() + assert discount.payment_type == ALL_PAYMENT_TYPES[0] + + +def test_update_discount_codes_with_invalid_payment_type(): + """ + Test updating discount codes with an invalid payment type + """ + discount = DiscountFactory() + with pytest.raises( + ValueError, match="Payment type INVALID_TYPE is not valid." + ) as excinfo: + update_discount_codes( + discount_codes=[discount.discount_code], payment_type="INVALID_TYPE" + ) + assert "Payment type INVALID_TYPE is not valid." in str(excinfo.value) + + +def test_update_discount_codes_with_amount(): + """ + Test updating discount codes with an amount + """ + discount = DiscountFactory() + updated_count = update_discount_codes( + discount_codes=[discount.discount_code], amount="20.00" + ) + assert updated_count == 1 + discount.refresh_from_db() + assert discount.amount == Decimal("20.00") + + +def test_update_discount_codes_with_activates_and_expires(): + """ + Test updating discount codes with activation and expiration dates + """ + discount = DiscountFactory() + activates = "2023-01-01" + expires = "2023-12-31" + updated_count = update_discount_codes( + discount_codes=[discount.discount_code], activates=activates, expires=expires + ) + assert updated_count == 1 + discount.refresh_from_db() + assert discount.activation_date == datetime(2023, 1, 1, 0, 0, tzinfo=UTC) + assert discount.expiration_date == datetime(2023, 12, 31, 0, 0, tzinfo=UTC) + + +def test_update_discount_codes_with_integrated_system(): + """ + Test updating discount codes with an integrated system + """ + discount = DiscountFactory() + integrated_system = IntegratedSystemFactory() + updated_count = update_discount_codes( + discount_codes=[discount.discount_code], + integrated_system=integrated_system.slug, + ) + assert updated_count == 1 + discount.refresh_from_db() + assert discount.integrated_system == integrated_system + + +def test_update_discount_codes_with_product(): + """ + Test updating discount codes with a product + """ + discount = DiscountFactory() + product = ProductFactory() + updated_count = update_discount_codes( + discount_codes=[discount.discount_code], product=product.id + ) + assert updated_count == 1 + discount.refresh_from_db() + assert discount.product == product + + +def test_update_discount_codes_with_users(): + """ + Test updating discount codes with users + """ + discount = DiscountFactory() + users = UserFactory.create_batch(3) + user_emails = [user.email for user in users] + updated_count = update_discount_codes( + discount_codes=[discount.discount_code], users=user_emails + ) + assert updated_count == 1 + discount.refresh_from_db() + assert set(discount.assigned_users.all()) == set(users) + + +def test_update_discount_codes_with_clear_users(): + """ + Test updating discount codes by clearing the assigned users + """ + discount = DiscountFactory() + users = UserFactory.create_batch(2) + discount.assigned_users.set(users) + updated_count = update_discount_codes( + discount_codes=[discount.discount_code], clear_users=True + ) + assert updated_count == 1 + discount.refresh_from_db() + assert discount.assigned_users.count() == 0 + + +def test_update_discount_codes_with_prefix(): + """ + Test updating discount codes with a prefix + """ + bulk_collection = BulkDiscountCollectionFactory() + discounts = [] + discounts.append( + DiscountFactory( + discount_code="ABC1", + bulk_discount_collection=bulk_collection, + amount="10.00", + ) + ) + discounts.append( + DiscountFactory( + discount_code="ABC2", + bulk_discount_collection=bulk_collection, + amount="10.00", + ) + ) + discounts.append( + DiscountFactory( + discount_code="ABC3", + bulk_discount_collection=bulk_collection, + amount="10.00", + ) + ) + updated_count = update_discount_codes(prefix=bulk_collection.prefix, amount="15.00") + assert updated_count == 3 + for discount in discounts: + discount.refresh_from_db() + assert discount.amount == Decimal("15.00") + + +def test_update_discount_codes_exclude_redeemed_discounts(): + """ + Test updating discount codes with a prefix + """ + discount = DiscountFactory(redemption_type=REDEMPTION_TYPE_ONE_TIME) + RedeemedDiscountFactory(discount=discount) + updated_count = update_discount_codes( + discount_codes=[discount.discount_code], amount="10.00" + ) + assert updated_count == 0 + discount.refresh_from_db() + assert discount.amount != Decimal("10.00") + + +def test_locate_customer_for_basket_sets_customer_location(mocker): + """ + Test that locate_customer_for_basket sets the customer location + """ + test_user = UserFactory() + product = ProductFactory() + basket = BasketFactory(user=test_user) + request = HttpRequest() + request.user = test_user + # Mock dependencies + mock_determine_user_location = mocker.patch( + "payments.api.determine_user_location", + side_effect=["BlockedCountryLocation", "TaxCountryLocation"], + ) + mock_get_flagged_countries = mocker.patch( + "payments.api.get_flagged_countries", + side_effect=[{"BlockedCountry"}, {"TaxCountry"}], + ) + mock_get_client_ip = mocker.patch( + "payments.api.get_client_ip", + return_value="127.0.0.1", + ) + mock_basket_save = mocker.patch.object(basket, "save") + mock_basket_set_customer_location = mocker.patch.object( + basket, "set_customer_location" + ) + + # Call the function + locate_customer_for_basket(request, basket, product) + + # Assertions + mock_get_client_ip.assert_called_once_with(request) + mock_get_flagged_countries.assert_any_call("tax") + mock_determine_user_location.assert_any_call( + request, + {"BlockedCountry"}, + ) + mock_determine_user_location.assert_any_call( + request, + {"TaxCountry"}, + ) + test = CustomerLocationMetadata( + location_block="BlockedCountryLocation", location_tax="TaxCountryLocation" + ) + mock_basket_set_customer_location.assert_any_call(test) + mock_basket_save.assert_called_once() + + +def test_locate_customer_for_basket_logs_debug_info(mocker, caplog): + """ + Test that locate_customer_for_basket logs debug information + """ + test_user = UserFactory() + product = ProductFactory() + basket = BasketFactory(user=test_user) + request = HttpRequest() + request.user = test_user + # Mock dependencies + mocker.patch("payments.api.determine_user_location", return_value="SomeLocation") + mocker.patch("payments.api.get_flagged_countries", return_value=set()) + mocker.patch("payments.api.get_client_ip", return_value="127.0.0.1") + mocker.patch.object(basket, "save") + mocker.patch.object(basket, "set_customer_location") + + # Call the function + with caplog.at_level("DEBUG"): + locate_customer_for_basket(request, basket, product) + + # Assert logs + assert "locate_customer_for_basket: running for" in caplog.text + assert str(request.user) in caplog.text + assert "127.0.0.1" in caplog.text + + +def test_check_blocked_countries_blocked_for_country(): + """ + Test that ProductBlockedError is raised when the product is blocked for the user's country. + """ + # Create test data + test_user = UserFactory() + test_user.profile.country_code = "US" + test_user.save() + product = ProductFactory() + basket = BasketFactory(user=test_user, user_blockable_country_code="US") + + # Block the product for the user's country + BlockedCountryFactory(country_code="US", product=product) + + # Call the function and expect an exception + with pytest.raises(ProductBlockedError) as exc_info: + check_blocked_countries(basket, product) + + # Verify the error message + assert str(exc_info.value) == f"Product {product} blocked in country US" + + +def test_check_blocked_countries_not_blocked_for_other_country(): + """ + Test that no exception is raised when the product is blocked for another country but not the user's country. + """ + # Create test data + test_user = UserFactory() # User is in Canada + test_user.profile.country_code = "CA" # User is in Canada + test_user.save() + product = ProductFactory() + basket = BasketFactory(user=test_user, integrated_system=product.system) + + # Block the product for a different country (US) + BlockedCountryFactory(country_code="US", product=product) + + # Call the function + check_blocked_countries(basket, product) + + # No exception means the test passes + + +def test_get_redemption_type_one_time(): + """ + Test that get_redemption_type returns the correct redemption type for one-time discounts + """ + kwargs = {"one_time": True} + assert get_redemption_type(kwargs) == REDEMPTION_TYPE_ONE_TIME + + +def test_get_redemption_type_once_per_user(): + """ + Test that get_redemption_type returns the correct redemption type for once-per-user discounts + """ + kwargs = {"once_per_user": True} + assert get_redemption_type(kwargs) == REDEMPTION_TYPE_ONE_TIME_PER_USER + + +def test_get_redemption_type_specific_redemption_type(): + """ + Test that get_redemption_type returns the correct redemption type when a specific redemption type is provided + """ + kwargs = {"redemption_type": REDEMPTION_TYPE_ONE_TIME} + assert get_redemption_type(kwargs) == REDEMPTION_TYPE_ONE_TIME + + kwargs = {"redemption_type": REDEMPTION_TYPE_ONE_TIME_PER_USER} + assert get_redemption_type(kwargs) == REDEMPTION_TYPE_ONE_TIME_PER_USER + + kwargs = {"redemption_type": REDEMPTION_TYPE_UNLIMITED} + assert get_redemption_type(kwargs) == REDEMPTION_TYPE_UNLIMITED + + +def test_get_redemption_type_invalid_redemption_type(): + """ + Test that get_redemption_type returns the default redemption type when an invalid redemption type is provided + """ + kwargs = {"redemption_type": "INVALID_TYPE"} + assert get_redemption_type(kwargs) == REDEMPTION_TYPE_UNLIMITED + + +def test_get_redemption_type_no_kwargs(): + """ + Test that get_redemption_type returns the default redemption type when no kwargs are provided + """ + kwargs = {} + assert get_redemption_type(kwargs) == REDEMPTION_TYPE_UNLIMITED + + +def test_get_redemption_type_multiple_kwargs(): + """ + Test that get_redemption_type returns the correct redemption type when multiple kwargs are provided + """ + kwargs = {"one_time": True, "once_per_user": True} + assert get_redemption_type(kwargs) == REDEMPTION_TYPE_ONE_TIME + + kwargs = {"once_per_user": True, "redemption_type": REDEMPTION_TYPE_ONE_TIME} + assert get_redemption_type(kwargs) == REDEMPTION_TYPE_ONE_TIME_PER_USER + + +def test_get_redemption_type_unknown_redemption_type(): + """ + Test that get_redemption_type returns the default redemption type when an unknown redemption type is provided + """ + kwargs = {"redemption_type": "UNKNOWN_TYPE"} + assert get_redemption_type(kwargs) == REDEMPTION_TYPE_UNLIMITED + + +def test_get_users_with_valid_ids(): + """ + Test that get_users returns the correct users when valid user IDs are provided + """ + # Create test users using UserFactory + user1 = UserFactory() + user2 = UserFactory() + + # Call the function with valid user IDs + result = get_users([user1.id, user2.id]) + + # Assert that the correct users are returned + assert result == [user1, user2] + + +def test_get_users_with_valid_emails(): + """ + Test that get_users returns the correct users when valid user emails are provided + """ + # Create test users using UserFactory + user1 = UserFactory() + user2 = UserFactory() + + # Call the function with valid user emails + result = get_users([user1.email, user2.email]) + + # Assert that the correct users are returned + assert result == [user1, user2] + + +def test_get_users_with_mixed_identifiers(): + """ + Test that get_users returns the correct users when a mix of IDs and emails are provided + """ + # Create test users using UserFactory + user1 = UserFactory() + user2 = UserFactory() + + # Call the function with a mix of IDs and emails + result = get_users([user1.id, user2.email]) + + # Assert that the correct users are returned + assert result == [user1, user2] + + +def test_get_users_with_invalid_id(): + """ + Test that get_users raises an error when an invalid user ID is provided + """ + # Create a test user + test_user = UserFactory() + + # Call the function with an invalid ID + with pytest.raises(ValueError) as exc_info: # noqa: PT011 + get_users([test_user.id + 1]) # Assuming this ID does not exist + + # Assert the correct error message + assert str(exc_info.value) == f"User {test_user.id + 1} does not exist." + + +def test_get_users_with_invalid_email(): + """ + Test that get_users raises an error when an invalid user email is provided + """ + # Create a test user + UserFactory() + + # Call the function with an invalid email + invalid_email = "nonexistent@example.com" + with pytest.raises(ValueError) as exc_info: # noqa: PT011 + get_users([invalid_email]) + + # Assert the correct error message + assert str(exc_info.value) == f"User {invalid_email} does not exist." + + +def test_get_users_with_empty_list(): + """ + Test that get_users returns an empty list when an empty list is provided + """ + # Call the function with an empty list + result = get_users([]) + + # Assert that the result is an empty list + assert result == [] + + +def test_get_users_with_string_ids(): + """ + Test that get_users returns the correct users when string representations of user IDs are provided + """ + # Create a test user + test_user = UserFactory() + + # Call the function with a string representation of the ID + result = get_users([str(test_user.id)]) + + # Assert that the correct user is returned + assert result == [test_user] + + +def test_check_taxable_with_taxable_country(): + """ + Test that check_taxable applies a tax rate when the user's country code has a TaxRate. + """ + # Create a TaxRate for a specific country code + country_code = "US" + taxrate = TaxRateFactory(country_code=country_code) + + # Create a Basket with a user whose country code matches the TaxRate + basket = BasketFactory(user_blockable_country_code=country_code) + + # Call the function + check_taxable(basket) + + # Refresh the basket instance from the database + basket.refresh_from_db() + + # Assert that the tax rate was applied to the basket + assert basket.tax_rate == taxrate + + +def test_check_taxable_with_non_taxable_country(): + """ + Test that check_taxable does not apply a tax rate when the user's country code does not have a TaxRate. + """ + # Create a Basket with a user whose country code does not have a TaxRate + country_code = "CA" + basket = BasketFactory(user_blockable_country_code=country_code) + + # Call the function + check_taxable(basket) + + # Refresh the basket instance from the database + basket.refresh_from_db() + + # Assert that no tax rate was applied to the basket + assert basket.tax_rate is None + + +def test_check_taxable_with_multiple_tax_rates(): + """ + Test that check_taxable applies the first matching tax rate to the basket when multiple TaxRate instances exist for the same country code. + """ + # Create multiple TaxRate instances for the same country code + country_code1 = "AB" + country_code2 = "UK" + taxrate1 = TaxRateFactory(country_code=country_code1) + TaxRateFactory(country_code=country_code2) + + # Create a Basket with a user whose country code matches the TaxRate + basket = BasketFactory(user_blockable_country_code=country_code1) + + # Call the function + check_taxable(basket) + + # Refresh the basket instance from the database + basket.refresh_from_db() + + # Assert that the first matching tax rate was applied to the basket + assert basket.tax_rate == taxrate1 + + +def test_check_taxable_with_no_tax_rates(): + """ + Test that check_taxable does not apply a tax rate when the user's country code has no TaxRate. + """ + # Create a Basket with a user whose country code has no TaxRate + country_code = "FR" + basket = BasketFactory(user_blockable_country_code=country_code) + + # Call the function + check_taxable(basket) + + # Refresh the basket instance from the database + basket.refresh_from_db() + + # Assert that no tax rate was applied to the basket + assert basket.tax_rate is None + + +def test_check_taxable_with_empty_country_code(): + """ + Test that check_taxable does not apply a tax rate when the user's country code is empty. + """ + # Create a Basket with an empty country code + basket = BasketFactory(user_blockable_country_code="") + + # Call the function + check_taxable(basket) + + # Refresh the basket instance from the database + basket.refresh_from_db() + + # Assert that no tax rate was applied to the basket + assert basket.tax_rate is None diff --git a/payments/factories.py b/payments/factories.py index aa733e2a..84e37205 100644 --- a/payments/factories.py +++ b/payments/factories.py @@ -40,6 +40,7 @@ class OrderFactory(DjangoModelFactory): total_price_paid = fuzzy.FuzzyDecimal(10.00, 10.00) purchaser = SubFactory(UserFactory) + reference_number = FAKE.unique.word() integrated_system = SubFactory(IntegratedSystemFactory) class Meta: @@ -115,3 +116,38 @@ class Meta: """Meta options for DiscountFactory""" model = models.Discount + + +class CompanyFactory(DjangoModelFactory): + """Factory for Company""" + + name = FAKE.unique.company() + + class Meta: + """Meta options for CompanyFactory""" + + model = models.Company + + +class BulkDiscountCollectionFactory(DjangoModelFactory): + """Factory for BulkDiscountCollection""" + + prefix = FAKE.unique.word() + + class Meta: + """Meta options for BulkDiscountCollectionFactory""" + + model = models.BulkDiscountCollection + + +class RedeemedDiscountFactory(DjangoModelFactory): + """Factory for RedeemedDiscount""" + + discount = SubFactory(DiscountFactory) + order = SubFactory(OrderFactory) + user = SubFactory(UserFactory) + + class Meta: + """Meta options for RedeemedDiscountFactory""" + + model = models.RedeemedDiscount diff --git a/payments/models_test.py b/payments/models_test.py index 6faa835d..2fa80ebd 100644 --- a/payments/models_test.py +++ b/payments/models_test.py @@ -7,6 +7,7 @@ import pytest import pytz import reversion +from django.http import HttpRequest from mitol.payment_gateway.payment_utils import quantize_decimal from reversion.models import Version @@ -14,6 +15,7 @@ from payments.factories import ( BasketFactory, BasketItemFactory, + DiscountFactory, LineFactory, OrderFactory, TaxRateFactory, @@ -527,3 +529,159 @@ def test_order_tax_calculation_precision_check(user): assert order.tax == quantize_decimal(tax_assessed) assert order.lines.first().tax == tax_assessed assert order.lines.first().total_price == taxed_price + + +def test_resolve_discount_version_current_version(): + """ + Test that the current version of a Discount instance is returned when no version is specified. + """ + # Create a Discount instance + discount = DiscountFactory() + + # Call the method with discount_version=None (current version) + result = models.Discount.resolve_discount_version(discount, discount_version=None) + + # Assert that the current version is returned + assert result == discount + + +def test_resolve_discount_version_no_versions(): + """ + Test that an error is raised when no versions of a Discount instance are found. + """ + # Create a Discount instance + discount = DiscountFactory() + + # Call the method with discount_version=None (current version) + result = models.Discount.resolve_discount_version(discount, discount_version=None) + + # Assert that the current version is returned + assert result == discount + + +def test_resolve_discount_version_invalid_version(): + """ + Test that an error is raised when an invalid version is specified. + """ + # Create a Discount instance + discount = DiscountFactory() + + # Create a version of the Discount instance + with reversion.create_revision(): + discount.amount = 50 + discount.save() + reversion.set_comment("Changed amount to 50") + + # Get the version + versions = Version.objects.get_for_object(discount) + versions.first() + + # Call the method with an invalid version + with pytest.raises(TypeError) as exc_info: + models.Discount.resolve_discount_version( + discount, discount_version="invalid_version" + ) + + # Assert the correct error message + assert str(exc_info.value) == "Invalid product version specified" + + +def test_establish_basket_new_basket(): + """ + Test that a new basket is created when a basket does not already exist for the user and integrated system. + """ + # Create a user and an integrated system + user = UserFactory() + integrated_system = IntegratedSystemFactory() + + # Simulate a request object with the user + request = HttpRequest() + request.user = user + + # Call the method + basket = models.Basket.establish_basket(request, integrated_system) + + # Assert that a new basket was created + assert basket.user == user + assert basket.integrated_system == integrated_system + assert models.Basket.objects.filter( + user=user, integrated_system=integrated_system + ).exists() + + +def test_establish_basket_existing_basket(): + """ + Test that an existing basket is returned when a basket already exists for the user and integrated system. + """ + # Create a user, an integrated system, and an existing basket + user = UserFactory() + integrated_system = IntegratedSystemFactory() + existing_basket = BasketFactory(user=user, integrated_system=integrated_system) + + # Simulate a request object with the user + request = HttpRequest() + request.user = user + + # Call the method + basket = models.Basket.establish_basket(request, integrated_system) + + # Assert that the existing basket was returned + assert basket == existing_basket + assert ( + models.Basket.objects.filter( + user=user, integrated_system=integrated_system + ).count() + == 1 + ) + + +def test_establish_basket_multiple_integrated_systems(): + """ + Test that a new basket is created for each integrated system when multiple integrated systems exist. + """ + # Create a user and two integrated systems + user = UserFactory() + integrated_system1 = IntegratedSystemFactory() + integrated_system2 = IntegratedSystemFactory() + + # Simulate a request object with the user + request = HttpRequest() + request.user = user + + # Call the method for the first integrated system + basket1 = models.Basket.establish_basket(request, integrated_system1) + + # Call the method for the second integrated system + basket2 = models.Basket.establish_basket(request, integrated_system2) + + # Assert that two different baskets were created + assert basket1 != basket2 + assert basket1.integrated_system == integrated_system1 + assert basket2.integrated_system == integrated_system2 + assert models.Basket.objects.filter(user=user).count() == 2 + + +def test_establish_basket_unique_constraint(): + """ + Test that a single basket is created when the method is called multiple times. + """ + # Create a user and an integrated system + user = UserFactory() + integrated_system = IntegratedSystemFactory() + + # Simulate a request object with the user + request = HttpRequest() + request.user = user + + # Call the method twice + basket1 = models.Basket.establish_basket(request, integrated_system) + basket2 = models.Basket.establish_basket(request, integrated_system) + + # Assert that the same basket was returned both times + assert basket1 == basket2 + assert ( + models.Basket.objects.filter( + user=user, integrated_system=integrated_system + ).count() + == 1 + ) diff --git a/payments/utils_test.py b/payments/utils_test.py index 58d389aa..b461c5fd 100644 --- a/payments/utils_test.py +++ b/payments/utils_test.py @@ -1,6 +1,8 @@ """Tests for utility functions in payments.""" import pytest +import pytz +from dateutil import parser from payments import models, utils from system_meta.factories import ProductFactory @@ -37,3 +39,75 @@ def test_product_price_with_discount(discount_type): assert utils.product_price_with_discount(discount, product) == 90 if discount_type == DISCOUNT_TYPE_FIXED_PRICE: assert utils.product_price_with_discount(discount, product) == 10 + + +def test_parse_supplied_date_with_timezone(): + """ + Test that the supplied date string is parsed correctly when it includes timezone information. + """ + # Test a date string with timezone information + datearg = "2023-10-15T12:30:00+05:00" + result = utils.parse_supplied_date(datearg) + + # Expected result: timezone should be converted to TIME_ZONE + expected = parser.parse("2023-10-15T07:30:00").replace(tzinfo=pytz.timezone("UTC")) + assert result == expected + + +def test_parse_supplied_date_without_timezone(): + """ + Test that the supplied date string is parsed correctly when it does not include timezone information. + """ + # Test a date string without timezone information + datearg = "2023-10-15T12:30:00" + result = utils.parse_supplied_date(datearg) + + # Expected result: timezone should be set to TIME_ZONE + expected = parser.parse("2023-10-15T12:30:00").replace(tzinfo=pytz.timezone("UTC")) + assert result == expected + + +def test_parse_supplied_date_with_invalid_date(): + """ + Test that an invalid date string raises a ValueError. + """ + # Test an invalid date string + datearg = "invalid-date" + with pytest.raises(ValueError): # noqa: PT011 + utils.parse_supplied_date(datearg) + + +def test_parse_supplied_date_with_empty_string(): + """ + Test that an empty date string raises a ValueError. + """ + # Test an empty date string + datearg = "" + with pytest.raises(ValueError): # noqa: PT011 + utils.parse_supplied_date(datearg) + + +def test_parse_supplied_date_with_only_date(): + """ + Test that the supplied date string is parsed correctly when it only includes the date (no time). + """ + # Test a date string with only date (no time) + datearg = "2023-10-15" + result = utils.parse_supplied_date(datearg) + + # Expected result: time should default to midnight, timezone set to TIME_ZONE + expected = parser.parse("2023-10-15T00:00:00").replace(tzinfo=pytz.timezone("UTC")) + assert result == expected + + +def test_parse_supplied_date_with_different_timezone(): + """ + Test that the supplied date string is parsed correctly when it includes a different timezone. + """ + # Test a date string with a different timezone + datearg = "2023-10-15T12:30:00-07:00" + result = utils.parse_supplied_date(datearg) + + # Expected result: timezone should be converted to TIME_ZONE + expected = parser.parse("2023-10-15T19:30:00").replace(tzinfo=pytz.timezone("UTC")) + assert result == expected