Skip to content

Commit

Permalink
add inline call
Browse files Browse the repository at this point in the history
  • Loading branch information
ChenghaoMou committed Mar 16, 2024
1 parent e51feba commit 3ffc964
Show file tree
Hide file tree
Showing 4 changed files with 81 additions and 52 deletions.
3 changes: 3 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,6 @@ test: run

spark_test: run
docker compose exec spark poetry run pytest -vvv -s --doctest-modules tests/test_minhash_spark.py

clean:
docker system prune -a
52 changes: 25 additions & 27 deletions tests/test_bloom_filter.py
Original file line number Diff line number Diff line change
@@ -1,35 +1,33 @@
import io
import subprocess
from contextlib import redirect_stdout

import click

from text_dedup.bloom_filter import main as bf_main
from text_dedup.utils import BloomFilterArgs
from text_dedup.utils import IOArgs
from text_dedup.utils import MetaArgs

def test_bloom_filter():
result = subprocess.run(
[
"python",
"-m",
"text_dedup.bloom_filter",
"--path",
"allenai/c4",
"--name",
"xh",
"--split",
"train",
"--cache_dir",
".cache",
"--output",
".temp-output",
"--column",
"text",
"--batch_size",
"10000",
],
capture_output=True,
text=True,
)

def test_bloom_filter():
with redirect_stdout(io.StringIO()) as f:
ctx = click.Context(bf_main)
ctx.invoke(
bf_main,
io_args=IOArgs(
path="allenai/c4",
name="xh",
split="train",
cache_dir=".cache",
output=".temp-output",
),
meta_args=MetaArgs(column="text", batch_size=10000),
bloom_filter_args=BloomFilterArgs(),
)
s = f.getvalue()
# check the output
assert (
"69048" in result.stdout and "69048" in result.stdout
), f"Expected before and after are not present in the output: {result.stdout}"
assert "69048" in s and "69048" in s, f"Expected before and after are not present in the output: {s}"

# remove the output and input
subprocess.run(["rm", "-rf", ".cache"])
Expand Down
76 changes: 52 additions & 24 deletions text_dedup/utils/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,11 @@ def option_group(func):
)
@functools.wraps(func)
def wrapper(*args, **kwargs):
    """Forward the I/O options to ``func`` as a single ``io_args`` keyword.

    When the caller already passes ``io_args`` it is forwarded unchanged and
    any loose keyword arguments named after ``IOArgs`` fields are dropped so
    they never reach ``func``. Otherwise those loose keyword arguments are
    removed from ``kwargs`` and used to construct the ``IOArgs``.
    """
    if "io_args" in kwargs:
        io_args = kwargs.pop("io_args")
        # Snapshot the matching keys before deleting: kwargs is mutated here.
        # A plain loop (not a throw-away set) also tolerates unhashable values.
        for key in [k for k in kwargs if k in IOArgs.__annotations__]:
            del kwargs[key]
    else:
        io_args = IOArgs(**{k: kwargs.pop(k) for k in list(kwargs) if k in IOArgs.__annotations__})
    return func(*args, **kwargs, io_args=io_args)

return wrapper
Expand All @@ -70,7 +74,11 @@ def option_group(func):
@optgroup.option("--batch_size", type=int, help="Batch size for deduplication", default=10_000)
@functools.wraps(func)
def wrapper(*args, **kwargs):
    """Forward the meta options to ``func`` as a single ``meta_args`` keyword.

    When the caller already passes ``meta_args`` it is forwarded unchanged and
    any loose keyword arguments named after ``MetaArgs`` fields are dropped so
    they never reach ``func``. Otherwise those loose keyword arguments are
    removed from ``kwargs`` and used to construct the ``MetaArgs``.
    """
    if "meta_args" in kwargs:
        meta_args = kwargs.pop("meta_args")
        # Snapshot the matching keys before deleting: kwargs is mutated here.
        # A plain loop (not a throw-away set) also tolerates unhashable values.
        for key in [k for k in kwargs if k in MetaArgs.__annotations__]:
            del kwargs[key]
    else:
        meta_args = MetaArgs(**{k: kwargs.pop(k) for k in list(kwargs) if k in MetaArgs.__annotations__})
    return func(*args, **kwargs, meta_args=meta_args)

return wrapper
Expand Down Expand Up @@ -119,13 +127,17 @@ def option_group(func):
)
@functools.wraps(func)
def wrapper(*args, **kwargs):
    """Forward the MinHash options to ``func`` as a single ``minhash_args`` keyword.

    When the caller already passes ``minhash_args`` it is forwarded unchanged
    and any loose keyword arguments named after ``MinHashArgs`` fields are
    dropped so they never reach ``func``. Otherwise those loose keyword
    arguments are removed from ``kwargs`` and used to construct the
    ``MinHashArgs``; ``hash_bits`` is converted with ``int()`` on the way in.
    """
    if "minhash_args" in kwargs:
        minhash_args = kwargs.pop("minhash_args")
        # Snapshot the matching keys before deleting: kwargs is mutated here.
        # A plain loop (not a throw-away set) also tolerates unhashable values.
        for key in [k for k in kwargs if k in MinHashArgs.__annotations__]:
            del kwargs[key]
    else:
        minhash_args = MinHashArgs(
            **{
                k: (kwargs.pop(k) if k != "hash_bits" else int(kwargs.pop(k)))
                for k in list(kwargs)
                if k in MinHashArgs.__annotations__
            }
        )
    return func(*args, **kwargs, minhash_args=minhash_args)

return wrapper
Expand Down Expand Up @@ -154,13 +166,17 @@ def option_group(func):
)
@functools.wraps(func)
def wrapper(*args, **kwargs):
    """Forward the SimHash options to ``func`` as a single ``simhash_args`` keyword.

    When the caller already passes ``simhash_args`` it is forwarded unchanged
    and any loose keyword arguments named after ``SimHashArgs`` fields are
    dropped so they never reach ``func``. Otherwise those loose keyword
    arguments are removed from ``kwargs`` and used to construct the
    ``SimHashArgs``; ``f`` is converted with ``int()`` on the way in.
    """
    if "simhash_args" in kwargs:
        simhash_args = kwargs.pop("simhash_args")
        # Snapshot the matching keys before deleting: kwargs is mutated here.
        # A plain loop (not a throw-away set) also tolerates unhashable values.
        for key in [k for k in kwargs if k in SimHashArgs.__annotations__]:
            del kwargs[key]
    else:
        simhash_args = SimHashArgs(
            **{
                k: (kwargs.pop(k) if k != "f" else int(kwargs.pop(k)))
                for k in list(kwargs)
                if k in SimHashArgs.__annotations__
            }
        )
    return func(*args, **kwargs, simhash_args=simhash_args)

return wrapper
Expand Down Expand Up @@ -192,7 +208,11 @@ def option_group(func):
)
@functools.wraps(func)
def wrapper(*args, **kwargs):
    """Forward the suffix-array options to ``func`` as a single ``sa_args`` keyword.

    When the caller already passes ``sa_args`` it is forwarded unchanged and
    any loose keyword arguments named after ``SAArgs`` fields are dropped so
    they never reach ``func``. Otherwise those loose keyword arguments are
    removed from ``kwargs`` and used to construct the ``SAArgs``.
    """
    if "sa_args" in kwargs:
        sa_args = kwargs.pop("sa_args")
        # Snapshot the matching keys before deleting: kwargs is mutated here.
        # A plain loop (not a throw-away set) also tolerates unhashable values.
        for key in [k for k in kwargs if k in SAArgs.__annotations__]:
            del kwargs[key]
    else:
        sa_args = SAArgs(**{k: kwargs.pop(k) for k in list(kwargs) if k in SAArgs.__annotations__})
    return func(*args, **kwargs, sa_args=sa_args)

return wrapper
Expand All @@ -217,10 +237,14 @@ def option_group(func):
@optgroup.option("--initial_capacity", type=int, help="Initial capacity of BloomFilter", default=100)
@functools.wraps(func)
def wrapper(*args, **kwargs):
    """Forward the Bloom filter options to ``func`` as ``bloom_filter_args``.

    When the caller already passes ``bloom_filter_args`` it is forwarded
    unchanged and any loose keyword arguments named after ``BloomFilterArgs``
    fields are dropped so they never reach ``func``. Otherwise those loose
    keyword arguments are removed from ``kwargs`` and used to construct the
    ``BloomFilterArgs``.
    """
    if "bloom_filter_args" in kwargs:
        bloom_filter_args = kwargs.pop("bloom_filter_args")
        # Snapshot the matching keys before deleting: kwargs is mutated here.
        # A plain loop (not a throw-away set) also tolerates unhashable values.
        for key in [k for k in kwargs if k in BloomFilterArgs.__annotations__]:
            del kwargs[key]
    else:
        bloom_filter_args = BloomFilterArgs(
            **{k: kwargs.pop(k) for k in list(kwargs) if k in BloomFilterArgs.__annotations__}
        )
    return func(*args, **kwargs, bloom_filter_args=bloom_filter_args)

return wrapper

Expand All @@ -240,9 +264,13 @@ def option_group(func):
)
@functools.wraps(func)
def wrapper(*args, **kwargs):
    """Forward the exact-hash options to ``func`` as ``exact_hash_args``.

    When the caller already passes ``exact_hash_args`` it is forwarded
    unchanged and any loose keyword arguments named after ``ExactHashArgs``
    fields are dropped so they never reach ``func``. Otherwise those loose
    keyword arguments are removed from ``kwargs`` and used to construct the
    ``ExactHashArgs``.
    """
    if "exact_hash_args" in kwargs:
        exact_hash_args = kwargs.pop("exact_hash_args")
        # Snapshot the matching keys before deleting: kwargs is mutated here.
        # A plain loop (not a throw-away set) also tolerates unhashable values.
        for key in [k for k in kwargs if k in ExactHashArgs.__annotations__]:
            del kwargs[key]
    else:
        exact_hash_args = ExactHashArgs(
            **{k: kwargs.pop(k) for k in list(kwargs) if k in ExactHashArgs.__annotations__}
        )
    return func(*args, **kwargs, exact_hash_args=exact_hash_args)

return wrapper

0 comments on commit 3ffc964

Please sign in to comment.