22
22
import datetime
23
23
import json
24
24
import logging
25
- from typing import TYPE_CHECKING , Any , Generator , Iterable
25
+ from typing import TYPE_CHECKING , Any , Generator , Iterable , overload
26
26
27
27
import pendulum
28
28
from dateutil import relativedelta
29
29
from sqlalchemy import TIMESTAMP , PickleType , and_ , event , false , nullsfirst , or_ , true , tuple_
30
30
from sqlalchemy .dialects import mssql , mysql
31
31
from sqlalchemy .exc import OperationalError
32
- from sqlalchemy .sql import ColumnElement
32
+ from sqlalchemy .sql import ColumnElement , Select
33
33
from sqlalchemy .sql .expression import ColumnOperators
34
34
from sqlalchemy .types import JSON , Text , TypeDecorator , TypeEngine , UnicodeText
35
35
@@ -515,11 +515,31 @@ def is_lock_not_available_error(error: OperationalError):
515
515
return False
516
516
517
517
518
+ @overload
518
519
def tuple_in_condition (
519
520
columns : tuple [ColumnElement , ...],
520
521
collection : Iterable [Any ],
521
522
) -> ColumnOperators :
522
- """Generates a tuple-in-collection operator to use in ``.filter()``.
523
+ ...
524
+
525
+
526
+ @overload
527
+ def tuple_in_condition (
528
+ columns : tuple [ColumnElement , ...],
529
+ collection : Select ,
530
+ * ,
531
+ session : Session ,
532
+ ) -> ColumnOperators :
533
+ ...
534
+
535
+
536
+ def tuple_in_condition (
537
+ columns : tuple [ColumnElement , ...],
538
+ collection : Iterable [Any ] | Select ,
539
+ * ,
540
+ session : Session | None = None ,
541
+ ) -> ColumnOperators :
542
+ """Generates a tuple-in-collection operator to use in ``.where()``.
523
543
524
544
For most SQL backends, this generates a simple ``([col, ...]) IN [condition]``
525
545
clause. This however does not work with MSSQL, where we need to expand to
@@ -529,25 +549,57 @@ def tuple_in_condition(
529
549
"""
530
550
if settings .engine .dialect .name != "mssql" :
531
551
return tuple_ (* columns ).in_ (collection )
532
- clauses = [and_ (* (c == v for c , v in zip (columns , values ))) for values in collection ]
552
+ if not isinstance (collection , Select ):
553
+ rows = collection
554
+ elif session is None :
555
+ raise TypeError ("session is required when passing in a subquery" )
556
+ else :
557
+ rows = session .execute (collection )
558
+ clauses = [and_ (* (c == v for c , v in zip (columns , values ))) for values in rows ]
533
559
if not clauses :
534
560
return false ()
535
561
return or_ (* clauses )
536
562
537
563
564
+ @overload
538
565
def tuple_not_in_condition (
539
566
columns : tuple [ColumnElement , ...],
540
567
collection : Iterable [Any ],
541
568
) -> ColumnOperators :
542
- """Generates a tuple-not-in-collection operator to use in ``.filter()``.
569
+ ...
570
+
571
+
572
+ @overload
573
+ def tuple_not_in_condition (
574
+ columns : tuple [ColumnElement , ...],
575
+ collection : Select ,
576
+ * ,
577
+ session : Session ,
578
+ ) -> ColumnOperators :
579
+ ...
580
+
581
+
582
+ def tuple_not_in_condition (
583
+ columns : tuple [ColumnElement , ...],
584
+ collection : Iterable [Any ] | Select ,
585
+ * ,
586
+ session : Session | None = None ,
587
+ ) -> ColumnOperators :
588
+ """Generates a tuple-not-in-collection operator to use in ``.where()``.
543
589
544
590
This is similar to ``tuple_in_condition`` except generating ``NOT IN``.
545
591
546
592
:meta private:
547
593
"""
548
594
if settings .engine .dialect .name != "mssql" :
549
595
return tuple_ (* columns ).not_in (collection )
550
- clauses = [or_ (* (c != v for c , v in zip (columns , values ))) for values in collection ]
596
+ if not isinstance (collection , Select ):
597
+ rows = collection
598
+ elif session is None :
599
+ raise TypeError ("session is required when passing in a subquery" )
600
+ else :
601
+ rows = session .execute (collection )
602
+ clauses = [or_ (* (c != v for c , v in zip (columns , values ))) for values in rows ]
551
603
if not clauses :
552
604
return true ()
553
605
return and_ (* clauses )
0 commit comments