2222import datetime
2323import json
2424import logging
25- from typing import TYPE_CHECKING , Any , Generator , Iterable
25+ from typing import TYPE_CHECKING , Any , Generator , Iterable , overload
2626
2727import pendulum
2828from dateutil import relativedelta
2929from sqlalchemy import TIMESTAMP , PickleType , and_ , event , false , nullsfirst , or_ , true , tuple_
3030from sqlalchemy .dialects import mssql , mysql
3131from sqlalchemy .exc import OperationalError
32- from sqlalchemy .sql import ColumnElement
32+ from sqlalchemy .sql import ColumnElement , Select
3333from sqlalchemy .sql .expression import ColumnOperators
3434from sqlalchemy .types import JSON , Text , TypeDecorator , TypeEngine , UnicodeText
3535
@@ -515,11 +515,31 @@ def is_lock_not_available_error(error: OperationalError):
515515 return False
516516
517517
518+ @overload
518519def tuple_in_condition (
519520 columns : tuple [ColumnElement , ...],
520521 collection : Iterable [Any ],
521522) -> ColumnOperators :
522- """Generates a tuple-in-collection operator to use in ``.filter()``.
523+ ...
524+
525+
526+ @overload
527+ def tuple_in_condition (
528+ columns : tuple [ColumnElement , ...],
529+ collection : Select ,
530+ * ,
531+ session : Session ,
532+ ) -> ColumnOperators :
533+ ...
534+
535+
536+ def tuple_in_condition (
537+ columns : tuple [ColumnElement , ...],
538+ collection : Iterable [Any ] | Select ,
539+ * ,
540+ session : Session | None = None ,
541+ ) -> ColumnOperators :
542+ """Generates a tuple-in-collection operator to use in ``.where()``.
523543
524544 For most SQL backends, this generates a simple ``([col, ...]) IN [condition]``
525545 clause. This however does not work with MSSQL, where we need to expand to
@@ -529,25 +549,57 @@ def tuple_in_condition(
529549 """
530550 if settings .engine .dialect .name != "mssql" :
531551 return tuple_ (* columns ).in_ (collection )
532- clauses = [and_ (* (c == v for c , v in zip (columns , values ))) for values in collection ]
552+ if not isinstance (collection , Select ):
553+ rows = collection
554+ elif session is None :
555+ raise TypeError ("session is required when passing in a subquery" )
556+ else :
557+ rows = session .execute (collection )
558+ clauses = [and_ (* (c == v for c , v in zip (columns , values ))) for values in rows ]
533559 if not clauses :
534560 return false ()
535561 return or_ (* clauses )
536562
537563
564+ @overload
538565def tuple_not_in_condition (
539566 columns : tuple [ColumnElement , ...],
540567 collection : Iterable [Any ],
541568) -> ColumnOperators :
542- """Generates a tuple-not-in-collection operator to use in ``.filter()``.
569+ ...
570+
571+
572+ @overload
573+ def tuple_not_in_condition (
574+ columns : tuple [ColumnElement , ...],
575+ collection : Select ,
576+ * ,
577+ session : Session ,
578+ ) -> ColumnOperators :
579+ ...
580+
581+
582+ def tuple_not_in_condition (
583+ columns : tuple [ColumnElement , ...],
584+ collection : Iterable [Any ] | Select ,
585+ * ,
586+ session : Session | None = None ,
587+ ) -> ColumnOperators :
588+ """Generates a tuple-not-in-collection operator to use in ``.where()``.
543589
544590 This is similar to ``tuple_in_condition`` except generating ``NOT IN``.
545591
546592 :meta private:
547593 """
548594 if settings .engine .dialect .name != "mssql" :
549595 return tuple_ (* columns ).not_in (collection )
550- clauses = [or_ (* (c != v for c , v in zip (columns , values ))) for values in collection ]
596+ if not isinstance (collection , Select ):
597+ rows = collection
598+ elif session is None :
599+ raise TypeError ("session is required when passing in a subquery" )
600+ else :
601+ rows = session .execute (collection )
602+ clauses = [or_ (* (c != v for c , v in zip (columns , values ))) for values in rows ]
551603 if not clauses :
552604 return true ()
553605 return and_ (* clauses )
0 commit comments