1
+
2
+ import io
3
+ import os
4
+ import tarfile
5
+ import time
6
+ from os import PathLike
7
+ from pathlib import Path
8
+ from typing import List , Optional
9
+
10
+ import numpy as np
11
+ from PIL import Image
12
+
13
+
14
+ class TarWriter :
15
+ '''
16
+ Write samples of AFM images, molecules and descriptors to tar files. Use as a context manager and add samples with
17
+ :meth:`add_sample`.
18
+
19
+ Each tar file has a maximum number of samples, and whenever that maximum is reached, a new tar file is created.
20
+ The generated tar files are named as ``{base_name}_{n}.tar`` and saved into the specified folder. The current tar file
21
+ handle is always available in the attribute :attr:`ft`, and is automatically closed when the context ends.
22
+
23
+ Arguments:
24
+ base_path: Path to directory where tar files are saved.
25
+ base_name: Base name for output tar files. The number of the tar file is appended to the name.
26
+ max_count: Maximum number of samples per tar file.
27
+ png_compress_level: Compression level 1-9 for saved png images. Larger value for smaller file size but slower
28
+ write speed.
29
+ '''
30
+
31
+ def __init__ (self , base_path : PathLike = './' , base_name : str = '' , max_count : int = 100 , png_compress_level = 4 ):
32
+ self .base_path = Path (base_path )
33
+ self .base_name = base_name
34
+ self .max_count = max_count
35
+ self .png_compress_level = png_compress_level
36
+
37
+ def __enter__ (self ):
38
+ self .sample_count = 0
39
+ self .total_count = 0
40
+ self .tar_count = 0
41
+ self .ft = self ._get_tar_file ()
42
+ return self
43
+
44
+ def __exit__ (self , exc_type , exc_value , exc_traceback ):
45
+ self .ft .close ()
46
+
47
+ def _get_tar_file (self ):
48
+ file_path = self .base_path / f'{ self .base_name } _{ self .tar_count } .tar'
49
+ if os .path .exists (file_path ):
50
+ raise RuntimeError (f'Tar file already exists at `{ file_path } `' )
51
+ return tarfile .open (file_path , 'w' , format = tarfile .GNU_FORMAT )
52
+
53
+ def add_sample (self , X : List [np .ndarray ], xyzs : np .ndarray , Y : Optional [np .ndarray ]= None , comment_str : str = '' ):
54
+ """
55
+ Add a sample to the current tar file.
56
+
57
+ Arguments:
58
+ X: AFM images. Each list item corresponds to an AFM tip and is an array of shape (nx, ny, nz).
59
+ xyzs: Atom coordinates and elements. Each row is one atom and is of the form [x, y, z, element].
60
+ Y: Image descriptors. Each list item is one descriptor and is an array of shape (nx, ny).
61
+ comment_str: Comment line (second line) to add to the xyz file.
62
+ """
63
+
64
+ if self .sample_count >= self .max_count :
65
+ self .tar_count += 1
66
+ self .sample_count = 0
67
+ self .ft .close ()
68
+ self .ft = self ._get_tar_file ()
69
+
70
+ # Write AFM images
71
+ for i , x in enumerate (X ):
72
+ for j in range (x .shape [- 1 ]):
73
+ xj = x [:, :, j ]
74
+ xj = ((xj - xj .min ()) / np .ptp (xj ) * (2 ** 8 - 1 )).astype (np .uint8 ) # Convert range to 0-255 integers
75
+ img_bytes = io .BytesIO ()
76
+ Image .fromarray (xj .T [::- 1 ], mode = 'L' ).save (img_bytes , 'png' , compress_level = self .png_compress_level )
77
+ img_bytes .seek (0 ) # Return stream to start so that addfile can read it correctly
78
+ self .ft .addfile (get_tarinfo (f'{ self .total_count } .{ j :02d} .{ i } .png' , img_bytes ), img_bytes )
79
+ img_bytes .close ()
80
+
81
+ # Write xyz file
82
+ xyz_bytes = io .BytesIO ()
83
+ xyz_bytes .write (bytearray (f'{ len (xyzs )} \n { comment_str } \n ' , 'utf-8' ))
84
+ for xyz in xyzs :
85
+ xyz_bytes .write (bytearray (f'{ int (xyz [- 1 ])} \t ' , 'utf-8' ))
86
+ for i in range (len (xyz )- 1 ):
87
+ xyz_bytes .write (bytearray (f'{ xyz [i ]:10.8f} \t ' , 'utf-8' ))
88
+ xyz_bytes .write (bytearray ('\n ' , 'utf-8' ))
89
+ xyz_bytes .seek (0 ) # Return stream to start so that addfile can read it correctly
90
+ self .ft .addfile (get_tarinfo (f'{ self .total_count } .xyz' , xyz_bytes ), xyz_bytes )
91
+ xyz_bytes .close ()
92
+
93
+ # Write image descriptors (if any)
94
+ if Y is not None :
95
+ for i , y in enumerate (Y ):
96
+ img_bytes = io .BytesIO ()
97
+ np .save (img_bytes , y .astype (np .float32 ))
98
+ img_bytes .seek (0 ) # Return stream to start so that addfile can read it correctly
99
+ self .ft .addfile (get_tarinfo (f'{ self .total_count } .desc_{ i } .npy' , img_bytes ), img_bytes )
100
+ img_bytes .close ()
101
+
102
+ self .sample_count += 1
103
+ self .total_count += 1
104
+
105
+ def get_tarinfo (fname : str , file_bytes : io .BytesIO ):
106
+ info = tarfile .TarInfo (fname )
107
+ info .size = file_bytes .getbuffer ().nbytes
108
+ info .mtime = time .time ()
109
+ return info
0 commit comments