1
- from dataclasses import dataclass , field , InitVar
2
- from typing import Optional , List , Dict
1
+ from dataclasses import dataclass , field
2
+ from typing import Optional , Any
3
3
4
4
5
5
@dataclass (frozen = True )
@@ -15,67 +15,54 @@ def __post_init__(self) -> None:
15
15
16
16
17
17
@dataclass (frozen = True )
18
- class ModelDescription :
18
+ class BaseModelDescription :
19
19
model : str
20
20
sources : ModelSource
21
21
model_file : str
22
- dim : Optional [int ]
22
+ description : str = ""
23
+ license : str = ""
24
+ size_in_GB : Optional [float ] = None
25
+ additional_files : list [str ] = field (default_factory = list )
23
26
24
- description : str
25
- license : str
26
- size_in_GB : Optional [float ]
27
- additional_files : List [str ] = field (default_factory = list )
28
- tasks : Dict [str , int ] = field (default_factory = dict )
27
+ def validate_info (self ) -> None :
28
+ if self .license == "" :
29
+ raise ValueError ("license is required in builtin model description" )
30
+
31
+ if self .description == "" :
32
+ raise ValueError ("description is required in builtin model description" )
33
+
34
+ if self .size_in_GB is None :
35
+ raise ValueError ("size_in_GB is required in builtin model description" )
36
+
37
+ def __post_init__ (self ) -> None :
38
+ self .validate_info ()
29
39
30
40
31
41
@dataclass (frozen = True )
32
- class MultimodalModelDescription (ModelDescription ):
33
- dim : int
42
+ class DenseModelDescription (BaseModelDescription ):
43
+ dim : Optional [int ] = None
44
+ tasks : Optional [dict [str , Any ]] = None
45
+
46
+ def __post_init__ (self ) -> None :
47
+ assert self .dim is not None , "dim is required for dense model description"
48
+ self .validate_info ()
34
49
35
50
36
51
@dataclass (frozen = True )
37
- class SparseModelDescription (ModelDescription ):
38
- _vocab_size : InitVar [Optional [int ]] = None
39
- _requires_idf : InitVar [Optional [bool ]] = None
40
-
41
- vocab_size : int = field (init = False )
42
- requires_idf : Optional [bool ] = field (init = False , default = None )
43
- dim : Optional [int ] = field (default = None , init = False )
44
-
45
- def __init__ (
46
- self ,
47
- * ,
48
- model : str ,
49
- sources : ModelSource ,
50
- model_file : str ,
51
- description : str ,
52
- license : str ,
53
- size_in_GB : Optional [float ],
54
- dim : Optional [int ] = None ,
55
- additional_files : Optional [List [str ]] = None ,
56
- tasks : Optional [Dict [str , int ]] = None ,
57
- vocab_size : int ,
58
- requires_idf : Optional [bool ] = None ,
59
- ):
60
- # Call the parent initializer with the fields it needs.
61
- object .__setattr__ (self , "model" , model )
62
- object .__setattr__ (self , "sources" , sources )
63
- object .__setattr__ (self , "model_file" , model_file )
64
- object .__setattr__ (self , "dim" , dim if dim else None )
65
- object .__setattr__ (self , "description" , description )
66
- object .__setattr__ (self , "license" , license )
67
- object .__setattr__ (self , "size_in_GB" , size_in_GB )
68
- object .__setattr__ (
69
- self , "additional_files" , additional_files if additional_files is not None else []
70
- )
71
- object .__setattr__ (self , "tasks" , tasks if tasks is not None else {})
72
- # Set new fields.
73
- object .__setattr__ (self , "vocab_size" , vocab_size )
74
- object .__setattr__ (self , "requires_idf" , requires_idf )
52
+ class SparseModelDescription (BaseModelDescription ):
53
+ requires_idf : Optional [bool ] = None
54
+ vocab_size : Optional [int ] = None
75
55
76
56
77
57
@dataclass (frozen = True )
78
- class CustomModelDescription (ModelDescription ):
79
- description : str = ""
80
- license : str = ""
81
- size_in_GB : Optional [float ] = None
58
+ class CustomDenseModelDescription (DenseModelDescription ):
59
+ def __post_init__ (self ) -> None :
60
+ if self .dim is None :
61
+ raise ValueError ("dim is required for custom dense model description" )
62
+ # disable self.validate_info
63
+
64
+
65
+ @dataclass (frozen = True )
66
+ class CustomSparseModelDescription (SparseModelDescription ):
67
+ def __post_init__ (self ) -> None :
68
+ pass # disable self.validate_info
0 commit comments