1+
2+ import io
3+ import os
4+ import tarfile
5+ import time
6+ from os import PathLike
7+ from pathlib import Path
8+ from typing import List , Optional
9+
10+ import numpy as np
11+ from PIL import Image
12+
13+
class TarWriter:
    """
    Write samples of AFM images, molecules and descriptors to tar files. Use as a context manager and add samples with
    :meth:`add_sample`.

    Each tar file has a maximum number of samples, and whenever that maximum is reached, a new tar file is created.
    The generated tar files are named as ``{base_name}_{n}.tar`` and saved into the specified folder. The current tar file
    handle is always available in the attribute :attr:`ft`, and is automatically closed when the context ends.

    Within the tar files, the members of sample number ``k`` (counted over all tar files) are named

    - ``{k}.{j:02d}.{i}.png`` for slice ``j`` along the last axis of the ``i``-th array in ``X``,
    - ``{k}.xyz`` for the molecule geometry,
    - ``{k}.desc_{i}.npy`` for the ``i``-th descriptor in ``Y``.

    Arguments:
        base_path: Path to directory where tar files are saved.
        base_name: Base name for output tar files. The number of the tar file is appended to the name.
        max_count: Maximum number of samples per tar file.
        png_compress_level: Compression level 1-9 for saved png images. Larger value for smaller file size but slower
            write speed.

    Raises:
        RuntimeError: When entering the context or rolling over to a new tar file, if the target tar file
            already exists.
    """

    def __init__(self, base_path: PathLike = './', base_name: str = '', max_count: int = 100, png_compress_level: int = 4):
        self.base_path = Path(base_path)
        self.base_name = base_name
        self.max_count = max_count
        self.png_compress_level = png_compress_level

    def __enter__(self):
        # Counters are reset on every entry, so a re-entered writer starts again from tar file number 0.
        self.sample_count = 0  # Samples written to the currently open tar file
        self.total_count = 0  # Samples written over all tar files; used as the member name prefix
        self.tar_count = 0  # Index of the currently open tar file
        self.ft = self._get_tar_file()
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        self.ft.close()

    def _get_tar_file(self):
        """Open the tar file with index :attr:`tar_count` for writing, refusing to overwrite an existing file."""
        file_path = self.base_path / f'{self.base_name}_{self.tar_count}.tar'
        if file_path.exists():
            raise RuntimeError(f'Tar file already exists at `{file_path}`')
        return tarfile.open(file_path, 'w', format=tarfile.GNU_FORMAT)

    @staticmethod
    def _to_uint8(img: np.ndarray) -> np.ndarray:
        """Linearly rescale an array so that its minimum maps to 0 and its maximum to 255, truncated to uint8."""
        value_range = np.ptp(img)
        if value_range == 0:
            # A constant image has no contrast to stretch. Dividing by the zero range would produce NaNs whose
            # conversion to uint8 is undefined, so map the flat image to all zeros instead.
            return np.zeros(np.shape(img), dtype=np.uint8)
        return ((img - img.min()) / value_range * (2 ** 8 - 1)).astype(np.uint8)

    @staticmethod
    def _format_xyz(xyzs: np.ndarray, comment_str: str = '') -> bytes:
        """Encode atoms as xyz-file content: atom count, comment line, then one tab-separated line per atom."""
        lines = [f'{len(xyzs)}\n{comment_str}\n']
        for xyz in xyzs:
            # The element is stored in the last column but is written first on the line. Every field, including
            # the last coordinate, is followed by a tab.
            coords = ''.join(f'{c:10.8f}\t' for c in xyz[:-1])
            lines.append(f'{int(xyz[-1])}\t{coords}\n')
        return ''.join(lines).encode('utf-8')

    def _add_member(self, name: str, file_bytes: io.BytesIO):
        """Add the full contents of an in-memory buffer to the current tar file under the given member name."""
        file_bytes.seek(0)  # Return stream to start so that addfile can read it correctly
        self.ft.addfile(get_tarinfo(name, file_bytes), file_bytes)

    def add_sample(self, X: List[np.ndarray], xyzs: np.ndarray, Y: Optional[np.ndarray] = None, comment_str: str = ''):
        """
        Add a sample to the current tar file.

        If the current tar file already holds :attr:`max_count` samples, it is closed and the next tar file is
        opened before the sample is written.

        Arguments:
            X: AFM images. Each list item corresponds to an AFM tip and is an array of shape (nx, ny, nz).
            xyzs: Atom coordinates and elements. Each row is one atom and is of the form [x, y, z, element].
            Y: Image descriptors. Each list item is one descriptor and is an array of shape (nx, ny).
            comment_str: Comment line (second line) to add to the xyz file.

        Raises:
            RuntimeError: If a new tar file has to be started and a file with its name already exists.
        """

        if self.sample_count >= self.max_count:
            self.tar_count += 1
            self.sample_count = 0
            self.ft.close()
            self.ft = self._get_tar_file()

        # Write AFM images, one 8-bit grayscale png per slice along the last axis. Each slice is rescaled
        # independently to the full 0-255 range.
        for i, x in enumerate(X):
            for j in range(x.shape[-1]):
                xj = self._to_uint8(x[:, :, j])
                with io.BytesIO() as img_bytes:
                    # The slice is transposed and its rows reversed before saving (presumably so that x runs
                    # along the image width and y points up -- confirm against the reader).
                    Image.fromarray(xj.T[::-1], mode='L').save(img_bytes, 'png', compress_level=self.png_compress_level)
                    self._add_member(f'{self.total_count}.{j:02d}.{i}.png', img_bytes)

        # Write xyz file
        with io.BytesIO(self._format_xyz(xyzs, comment_str)) as xyz_bytes:
            self._add_member(f'{self.total_count}.xyz', xyz_bytes)

        # Write image descriptors (if any) as float32 npy files
        if Y is not None:
            for i, y in enumerate(Y):
                with io.BytesIO() as desc_bytes:
                    np.save(desc_bytes, y.astype(np.float32))
                    self._add_member(f'{self.total_count}.desc_{i}.npy', desc_bytes)

        self.sample_count += 1
        self.total_count += 1
104+
def get_tarinfo(fname: str, file_bytes: io.BytesIO):
    """
    Create the tar header for an in-memory file.

    Arguments:
        fname: Member name to record in the header.
        file_bytes: Buffer holding the file contents. Its total length is used as the member size, regardless of
            the current stream position, and the buffer itself is not modified.

    Returns:
        A :class:`tarfile.TarInfo` with the name, size and modification time (the current time) filled in.
    """
    # Read the length through a buffer view, and release the view straight away so the stream stays usable.
    with file_bytes.getbuffer() as view:
        n_bytes = view.nbytes

    tar_info = tarfile.TarInfo(fname)
    tar_info.mtime = time.time()
    tar_info.size = n_bytes
    return tar_info
0 commit comments