Commit 8eb6e16: bugfix
1 parent 0a9fee0


src/datatrove/pipeline/dedup/minhash.py (1 file changed: 14 additions, 17 deletions)
@@ -260,25 +260,22 @@ def run(self, data: DocumentsPipeline, rank: int = 0, world_size: int = 1):
 
         logger.info("Sorting buckets...")
         for bi in range(self.config.num_buckets):
-            # read one by one, sort and write back
-            sigs = sorted(
-                read_sigs(
-                    self.output_folder.open(f"bucket_{bi:03d}/{rank:05d}.minhash.sig", mode="rb"),
-                    -1,
-                    self.config,
-                    ensure_order=False,
-                    lines_to_buffer=-1,  # load everything in one go
-                )
+            # read all records, sort and write back
+            dtype = np.dtype(
+                [
+                    (f"field{i + 1}", f"<{self.config.hash_config.struct_format}")
+                    for i in range(self.config.hashes_per_bucket)
+                ]
+                + [(f"field{self.config.hashes_per_bucket + 1}", "<I")]
             )
+            with self.output_folder.open(f"bucket_{bi:03d}/{rank:05d}.minhash.sig", mode="rb") as fi:
+                records = np.frombuffer(fi.read(), dtype=dtype)
+
+            indices = np.argsort(records, order=dtype.names)
+
             with self.output_folder.open(f"bucket_{bi:03d}/{rank:05d}.minhash.sig", mode="wb") as fo:
-                for sig in sigs:
-                    fo.write(
-                        struct.pack(
-                            f"<{self.config.hashes_per_bucket}{self.config.hash_config.struct_format}I",
-                            *sig.sig,
-                            sig.doc_id,
-                        )
-                    )
+                for idx in indices:
+                    fo.write(records[idx].tobytes())
 
 
 class MinhashDedupBuckets(PipelineStep):
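
For reference, here is a minimal standalone sketch of the technique the new code uses: sorting fixed-size binary signature records with a NumPy structured dtype instead of unpacking them into Python objects and re-packing with struct. The record layout (three little-endian uint64 hashes plus a uint32 doc id), the field values, and the file name "bucket.sig" are made up for illustration; they are not datatrove's actual configuration.

import numpy as np

# Illustrative layout: 3 little-endian uint64 hashes followed by a uint32 doc id.
# "<Q" stands in for hash_config.struct_format here; a real config may use "<I".
hashes_per_bucket = 3
dtype = np.dtype(
    [(f"field{i + 1}", "<Q") for i in range(hashes_per_bucket)]
    + [(f"field{hashes_per_bucket + 1}", "<I")]
)

# Write a few unsorted records, standing in for the output of the signature stage.
records = np.array([(9, 1, 4, 0), (2, 7, 7, 1), (2, 3, 5, 2)], dtype=dtype)
with open("bucket.sig", "wb") as fo:
    fo.write(records.tobytes())

# Load the whole file in one go; the structured dtype parses each record.
with open("bucket.sig", "rb") as fi:
    loaded = np.frombuffer(fi.read(), dtype=dtype)

# Lexicographic sort over all fields (hashes first, doc id last), then write back.
indices = np.argsort(loaded, order=dtype.names)
with open("bucket.sig", "wb") as fo:
    for idx in indices:
        fo.write(loaded[idx].tobytes())

Writing fo.write(loaded[indices].tobytes()) in a single call would produce the same bytes without the per-record loop; the diff keeps the loop, which writes the sorted records one at a time.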
