Skip to content

Commit b2f297f

Browse files
committed
add overloads for transform
fix lint fix lint
1 parent 3be6b77 commit b2f297f

File tree

4 files changed

+116
-5
lines changed

4 files changed

+116
-5
lines changed

graphistry/Plottable.py

Lines changed: 55 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set, Tuple, Union, Protocol
1+
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set, Tuple, Union, Protocol, overload
22
from typing_extensions import Literal, runtime_checkable
33
import pandas as pd
44

@@ -580,6 +580,32 @@ def infer_labels(self) -> 'Plottable':
580580
...
581581

582582

583+
@overload
584+
def transform(self, df: pd.DataFrame,
585+
y: Optional[pd.DataFrame] = None,
586+
kind: str = 'nodes',
587+
min_dist: Union[str, float, int] = 'auto',
588+
n_neighbors: int = 7,
589+
merge_policy: bool = False,
590+
sample: Optional[int] = None,
591+
return_graph: Literal[True] = True,
592+
scaled: bool = True,
593+
verbose: bool = False) -> 'Plottable':
594+
...
595+
596+
@overload
597+
def transform(self, df: pd.DataFrame,
598+
y: Optional[pd.DataFrame] = None,
599+
kind: str = 'nodes',
600+
min_dist: Union[str, float, int] = 'auto',
601+
n_neighbors: int = 7,
602+
merge_policy: bool = False,
603+
sample: Optional[int] = None,
604+
return_graph: Literal[False] = False,
605+
scaled: bool = True,
606+
verbose: bool = False) -> Tuple[pd.DataFrame, pd.DataFrame]:
607+
...
608+
583609
def transform(self, df: pd.DataFrame,
584610
y: Optional[pd.DataFrame] = None,
585611
kind: str = 'nodes',
@@ -593,6 +619,34 @@ def transform(self, df: pd.DataFrame,
593619
...
594620

595621

622+
@overload
623+
def transform_umap(self, df: pd.DataFrame,
624+
y: Optional[pd.DataFrame] = None,
625+
kind: GraphEntityKind = 'nodes',
626+
min_dist: Union[str, float, int] = 'auto',
627+
n_neighbors: int = 7,
628+
merge_policy: bool = False,
629+
sample: Optional[int] = None,
630+
return_graph: Literal[True] = True,
631+
fit_umap_embedding: bool = True,
632+
umap_transform_kwargs: Dict[str, Any] = {}
633+
) -> 'Plottable':
634+
...
635+
636+
@overload
637+
def transform_umap(self, df: pd.DataFrame,
638+
y: Optional[pd.DataFrame] = None,
639+
kind: GraphEntityKind = 'nodes',
640+
min_dist: Union[str, float, int] = 'auto',
641+
n_neighbors: int = 7,
642+
merge_policy: bool = False,
643+
sample: Optional[int] = None,
644+
return_graph: Literal[False] = False,
645+
fit_umap_embedding: bool = True,
646+
umap_transform_kwargs: Dict[str, Any] = {}
647+
) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
648+
...
649+
596650
def transform_umap(self, df: pd.DataFrame,
597651
y: Optional[pd.DataFrame] = None,
598652
kind: GraphEntityKind = 'nodes',

graphistry/compute/cluster.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -392,9 +392,9 @@ def _transform_dbscan(
392392

393393
emb = None
394394
if umap and cols is None:
395-
emb, X, y = res.transform_umap(df, ydf, kind=kind, return_graph=False) # type: ignore
395+
emb, X, y = res.transform_umap(df, ydf, kind=kind, return_graph=False)
396396
else:
397-
X, y = res.transform(df, ydf, kind=kind, return_graph=False) # type: ignore
397+
X, y = res.transform(df, ydf, kind=kind, return_graph=False)
398398
XX = X
399399
if target:
400400
XX = y

graphistry/feature_utils.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,9 @@
1616
Optional,
1717
Tuple,
1818
TYPE_CHECKING,
19+
overload,
1920
)
21+
from typing_extensions import Literal
2022

2123
from graphistry.compute.ComputeMixin import ComputeMixin
2224
from graphistry.config import config as graphistry_config
@@ -2358,6 +2360,32 @@ def _transform(self, encoder: str, df: pd.DataFrame, ydf: Optional[pd.DataFrame]
23582360
)
23592361
raise ValueError(f"Encoder {encoder} not initialized. Call featurize() first.")
23602362

2363+
@overload
2364+
def transform(self, df: pd.DataFrame,
2365+
y: Optional[pd.DataFrame] = None,
2366+
kind: str = 'nodes',
2367+
min_dist: Union[str, float, int] = 'auto',
2368+
n_neighbors: int = 7,
2369+
merge_policy: bool = False,
2370+
sample: Optional[int] = None,
2371+
return_graph: Literal[True] = True,
2372+
scaled: bool = True,
2373+
verbose: bool = False) -> 'Plottable':
2374+
...
2375+
2376+
@overload
2377+
def transform(self, df: pd.DataFrame,
2378+
y: Optional[pd.DataFrame] = None,
2379+
kind: str = 'nodes',
2380+
min_dist: Union[str, float, int] = 'auto',
2381+
n_neighbors: int = 7,
2382+
merge_policy: bool = False,
2383+
sample: Optional[int] = None,
2384+
return_graph: Literal[False] = False,
2385+
scaled: bool = True,
2386+
verbose: bool = False) -> Tuple[pd.DataFrame, pd.DataFrame]:
2387+
...
2388+
23612389
def transform(self, df: pd.DataFrame,
23622390
y: Optional[pd.DataFrame] = None,
23632391
kind: str = 'nodes',

graphistry/umap_utils.py

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import copy
22
from time import time
3-
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union, cast
3+
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union, cast, overload
4+
from typing_extensions import Literal
45
from inspect import getmodule
56
import warnings
67

@@ -384,6 +385,34 @@ def _umap_fit_transform(
384385
return emb
385386

386387

388+
@overload
389+
def transform_umap(self, df: pd.DataFrame,
390+
y: Optional[pd.DataFrame] = None,
391+
kind: GraphEntityKind = 'nodes',
392+
min_dist: Union[str, float, int] = 'auto',
393+
n_neighbors: int = 7,
394+
merge_policy: bool = False,
395+
sample: Optional[int] = None,
396+
return_graph: Literal[True] = True,
397+
fit_umap_embedding: bool = True,
398+
umap_transform_kwargs: Dict[str, Any] = {}
399+
) -> 'Plottable':
400+
...
401+
402+
@overload
403+
def transform_umap(self, df: pd.DataFrame,
404+
y: Optional[pd.DataFrame] = None,
405+
kind: GraphEntityKind = 'nodes',
406+
min_dist: Union[str, float, int] = 'auto',
407+
n_neighbors: int = 7,
408+
merge_policy: bool = False,
409+
sample: Optional[int] = None,
410+
return_graph: Literal[False] = False,
411+
fit_umap_embedding: bool = True,
412+
umap_transform_kwargs: Dict[str, Any] = {}
413+
) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
414+
...
415+
387416
def transform_umap(self, df: pd.DataFrame,
388417
y: Optional[pd.DataFrame] = None,
389418
kind: GraphEntityKind = 'nodes',
@@ -394,7 +423,7 @@ def transform_umap(self, df: pd.DataFrame,
394423
return_graph: bool = True,
395424
fit_umap_embedding: bool = True,
396425
umap_transform_kwargs: Dict[str, Any] = {}
397-
) -> Union[Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame], Plottable]:
426+
) -> Union[Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame], 'Plottable']:
398427
"""Transforms data into UMAP embedding
399428
400429
Args:

0 commit comments

Comments
 (0)