-
Notifications
You must be signed in to change notification settings - Fork 8
/
Copy patharg_utils.py
130 lines (97 loc) · 2.58 KB
/
arg_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
"""
This module defines some common functions for argument parsing.
Author: wangning([email protected])
Date : 2022/12/7 7:41 PM
"""
import argparse
import random
import numpy as np
import paddle
from paddlenlp.utils.log import logger
def set_seed(seed):
    """Seed the Python, NumPy and Paddle random number generators.

    Args:
        seed (int): value passed unchanged to ``random.seed``,
            ``np.random.seed`` and ``paddle.seed`` (in that order).
    """
    # Same order as before: stdlib RNG, then NumPy's global RNG, then Paddle.
    for seeder in (random.seed, np.random.seed, paddle.seed):
        seeder(seed)
def default_logdir():
    """Generate a default log directory name from the current local time.

    Returns:
        str: the current local time formatted with ``"%b%d_%H-%M-%S"``,
            e.g. ``"Dec07_19-41-00"``. This is a plain string, not a path
            object. The month abbreviation (``%b``) follows the active locale.
    """
    from datetime import datetime
    current_time = datetime.now().strftime("%b%d_%H-%M-%S")
    return current_time
def str2bool(v):
    """Convert a command-line string to a boolean (for argparse ``type=``).

    Args:
        v (str or bool): a bool is returned unchanged; a string is matched
            case-insensitively (surrounding whitespace ignored) against
            yes/true/t/y/1 and no/false/f/n/0.
    Returns:
        bool: True or False
    Raises:
        argparse.ArgumentTypeError: if v is neither a bool nor a recognized
            string. Non-string inputs previously crashed with AttributeError.
    """
    if isinstance(v, bool):
        return v
    if not isinstance(v, str):
        # Without this guard, v.lower() raises AttributeError for ints/None,
        # bypassing argparse's clean error reporting.
        raise argparse.ArgumentTypeError("Boolean value expected.")
    lowered = v.strip().lower()
    if lowered in ('yes', 'true', 't', 'y', '1'):
        return True
    if lowered in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError("Boolean value expected.")
def str2list(v):
    """Split a comma-separated string into a list of stripped strings.

    Args:
        v (str or list): a list is returned as-is (same object); a string
            is split on ',' and each piece has surrounding whitespace removed.
            Empty pieces are kept (e.g. "a," -> ["a", ""]).
    Returns:
        list: the resulting items
    Raises:
        argparse.ArgumentTypeError: if v is neither a str nor a list.
    """
    if isinstance(v, str):
        return [piece.strip() for piece in v.split(",")]
    if isinstance(v, list):
        return v
    raise argparse.ArgumentTypeError("Str value seperated by ', ' expected.")
def str2intlist(v):
    """Convert a comma-separated string of integers to a list of ints.

    One trailing separator is tolerated, even if followed by whitespace:
    "1,2," and "1,2, " both give [1, 2]. An empty string gives [].

    Args:
        v (str or list): a list is returned as-is (same object); a string is
            split on ',' and each stripped item is parsed with int().
    Returns:
        list: int list
    Raises:
        ValueError: if an item is not a valid integer literal.
        argparse.ArgumentTypeError: if v is neither a str nor a list.
    """
    if isinstance(v, list):
        return v
    if isinstance(v, str):
        items = [item.strip() for item in v.split(",")]
        # Strip before checking so "1,2, " is treated like "1,2," instead of
        # failing on int("").
        if items[-1] == "":
            items.pop()
        return [int(item) for item in items]
    raise argparse.ArgumentTypeError("Str value seperated by ', ' expected.")
def list2str(list_value):
    """Join a list into a single comma-separated string.

    This is the counterpart of ``str2intlist`` / ``str2list``.

    Args:
        list_value (list): items to join. Each item is converted with str(),
            so int lists work (the old ``res + x`` concatenation raised
            TypeError for ints).
    Returns:
        str: items joined by ','; "" for an empty list.
    Raises:
        NotImplementedError: if list_value is not a list.
    """
    if not isinstance(list_value, list):
        raise NotImplementedError
    return ",".join(str(x) for x in list_value)
def print_config(args=None, key=""):
    """Log the configuration of the experiment at DEBUG level.

    Logs a separator line, a centered title, and the Paddle commit id. It then
    logs one ``name:value`` line per attribute of ``args``.

    Args:
        args (argparse.Namespace, optional): parsed arguments to dump.
            If None, only the header lines are logged.
        key (str): prefix inserted into the title line.
    Returns:
        None
    """
    logger.debug("=" * 60)
    logger.debug('{:^40}'.format("{} Configuration Arguments".format(key)))
    logger.debug('{:30}:{}'.format("paddle commit id", paddle.version.commit))
    # args defaults to None, and vars(None) raises TypeError, so the default
    # would otherwise always crash.
    if args is None:
        return
    for k, v in vars(args).items():
        logger.debug('{:30}:{}'.format(k, v))