1- from redis import StrictRedis
2- from typing import Union , Any , AnyStr , ByteString , Sequence
3- from .containers import Script , Model , Tensor
1+ from functools import wraps
2+ from typing import Union , AnyStr , ByteString , List , Sequence
43import warnings
54
6- try :
7- import numpy as np
8- except ImportError :
9- np = None
5+ from redis import StrictRedis
6+ import numpy as np
7+
8+ from . import utils
9+
1010
11- from .constants import Backend , Device , DType
12- from .utils import str_or_strsequence , to_string , list_to_dict
13- from . import convert
11+ def enable_debug (f ):
12+ @wraps (f )
13+ def wrapper (* args ):
14+ print (* args )
15+ return f (* args )
16+ return wrapper
1417
1518
19+ # TODO: typing to use AnyStr
20+
1621class Client (StrictRedis ):
1722 """
1823 RedisAI client that can call Redis with RedisAI specific commands
1924 """
20- def loadbackend (self , identifier : AnyStr , path : AnyStr ) -> AnyStr :
25+ def __init__ (self , debug = False , * args , ** kwargs ):
26+ super ().__init__ (* args , ** kwargs )
27+ if debug :
28+ self .execute_command = enable_debug (super ().execute_command )
29+
30+ def loadbackend (self , identifier : AnyStr , path : AnyStr ) -> str :
2131 """
2232 RedisAI by default won't load any backends. User can either explicitly
2333 load the backend by using this function or let RedisAI load the required
@@ -27,20 +37,36 @@ def loadbackend(self, identifier: AnyStr, path: AnyStr) -> AnyStr:
2737 :param path: Path to the shared object of the backend
2838 :return: byte string represents success or failure
2939 """
30- return self .execute_command ('AI.CONFIG LOADBACKEND' , identifier , path )
40+ return self .execute_command ('AI.CONFIG LOADBACKEND' , identifier , path ). decode ()
3141
3242 def modelset (self ,
3343 name : AnyStr ,
34- backend : Backend ,
35- device : Device ,
44+ backend : str ,
45+ device : str ,
3646 data : ByteString ,
3747 batch : int = None ,
3848 minbatch : int = None ,
3949 tag : str = None ,
40- inputs : Union [AnyStr , Sequence [AnyStr ]] = None ,
41- outputs : Union [AnyStr , Sequence [AnyStr ]] = None
42- ) -> AnyStr :
43- args = ['AI.MODELSET' , name , backend .value , device .value ]
50+ inputs : List [AnyStr ] = None ,
51+ outputs : List [AnyStr ] = None ) -> str :
52+ """
53+ Set the model on provided key.
54+ :param name: str, Key name
55+ :param backend: str, Backend name. Allowed backends are TF, TORCH, TFLITE, ONNX
56+ :param device: str, Device name. Allowed devices are CPU and GPU
57+ :param data: bytes, Model graph read as bytestring
58+ :param batch: int, Number of batches for doing autobatching
59+ :param minbatch: int, Minimum number of samples required in a batch for model
60+ execution
61+ :param tag: str, Any string that will be saved in RedisAI as tags for the model
62+ :param inputs: list, List of strings that represents the input nodes in the graph.
63+ Required only Tensorflow graphs
64+ :param outputs: list, List of strings that represents the output nodes in the graph
65+ Required only for Tensorflow graphs
66+
67+ :return:
68+ """
69+ args = ['AI.MODELSET' , name , backend , device ]
4470
4571 if batch is not None :
4672 args += ['BATCHSIZE' , batch ]
@@ -49,48 +75,47 @@ def modelset(self,
4975 if tag is not None :
5076 args += ['TAG' , tag ]
5177
52- if backend == Backend . tf :
78+ if backend . upper () == 'TF' :
5379 if not (all ((inputs , outputs ))):
5480 raise ValueError (
5581 'Require keyword arguments input and output for TF models' )
56- args += ['INPUTS' ] + str_or_strsequence (inputs )
57- args += ['OUTPUTS' ] + str_or_strsequence (outputs )
58- args += [data ]
59- return self .execute_command (* args )
60-
61- def modelget (self , name : AnyStr , meta_only = False ) -> Model :
62- argname = 'META' if meta_only else 'BLOB'
63- rv = self .execute_command ('AI.MODELGET' , name , argname )
64- rv = list_to_dict (rv )
65- return Model (
66- rv .get ('blob' ),
67- Device (rv ['device' ]),
68- Backend (rv ['backend' ]),
69- rv ['tag' ])
70-
71- def modeldel (self , name : AnyStr ) -> AnyStr :
72- return self .execute_command ('AI.MODELDEL' , name )
82+ args += ['INPUTS' ] + utils .listify (inputs )
83+ args += ['OUTPUTS' ] + utils .listify (outputs )
84+ args .append (data )
85+ return self .execute_command (* args ).decode ()
86+
87+ def modelget (self , name : AnyStr , meta_only = False ) -> dict :
88+ args = ['AI.MODELGET' , name , 'META' ]
89+ if not meta_only :
90+ args .append ('BLOB' )
91+ rv = self .execute_command (* args )
92+ return utils .list2dict (rv )
93+
94+ def modeldel (self , name : AnyStr ) -> str :
95+ return self .execute_command ('AI.MODELDEL' , name ).decode ()
7396
7497 def modelrun (self ,
7598 name : AnyStr ,
76- inputs : Union [AnyStr , Sequence [AnyStr ]],
77- outputs : Union [AnyStr , Sequence [AnyStr ]]
78- ) -> AnyStr :
79- args = ['AI.MODELRUN' , name ]
80- args += ['INPUTS' ] + str_or_strsequence (inputs )
81- args += ['OUTPUTS' ] + str_or_strsequence (outputs )
82- return self .execute_command (* args )
83-
84- def modelist (self ):
99+ inputs : List [AnyStr ],
100+ outputs : List [AnyStr ]
101+ ) -> str :
102+ out = self .execute_command (
103+ 'AI.MODELRUN' , name ,
104+ 'INPUTS' , * utils .listify (inputs ),
105+ 'OUTPUTS' , * utils .listify (outputs )
106+ )
107+ return out .decode ()
108+
109+ def modelscan (self ) -> list :
85110 warnings .warn ("Experimental: Model List API is experimental and might change "
86111 "in the future without any notice" , UserWarning )
87- return self .execute_command ("AI._MODELLIST" )
112+ return utils . un_bytize ( self .execute_command ("AI._MODELSCAN" ), lambda x : x . decode () )
88113
89114 def tensorset (self ,
90115 key : AnyStr ,
91116 tensor : Union [np .ndarray , list , tuple ],
92117 shape : Sequence [int ] = None ,
93- dtype : Union [ DType , type ] = None ) -> Any :
118+ dtype : str = None ) -> str :
94119 """
95120 Set the values of the tensor on the server using the provided Tensor object
96121 :param key: The name of the tensor
@@ -99,20 +124,20 @@ def tensorset(self,
99124 :param dtype: data type of the tensor. Required if `tensor` is list or tuple
100125 """
101126 if np and isinstance (tensor , np .ndarray ):
102- tensor = convert . from_numpy (tensor )
103- args = ['AI.TENSORSET' , key , tensor . dtype . value , * tensor . shape , tensor . argname , tensor . value ]
127+ dtype , shape , blob = utils . numpy2blob (tensor )
128+ args = ['AI.TENSORSET' , key , dtype , * shape , 'BLOB' , blob ]
104129 elif isinstance (tensor , (list , tuple )):
105130 if shape is None :
106131 shape = (len (tensor ),)
107- if not isinstance ( dtype , DType ):
108- dtype = DType . __members__ [ np . dtype ( dtype ). name ]
109- tensor = convert . from_sequence ( tensor , shape , dtype )
110- args = [ 'AI.TENSORSET' , key , tensor . dtype . value , * tensor . shape , tensor . argname , * tensor . value ]
111- return self .execute_command (* args )
132+ args = [ 'AI.TENSORSET' , key , dtype , * shape , 'VALUES' , * tensor ]
133+ else :
134+ raise TypeError ( f"`` tensor`` argument must be a numpy array or a list or a "
135+ f"tuple, but got { type ( tensor ) } " )
136+ return self .execute_command (* args ). decode ()
112137
113138 def tensorget (self ,
114- key : AnyStr , as_numpy : bool = True ,
115- meta_only : bool = False ) -> Union [Tensor , np .ndarray ]:
139+ key : str , as_numpy : bool = True ,
140+ meta_only : bool = False ) -> Union [dict , np .ndarray ]:
116141 """
117142 Retrieve the value of a tensor from the server. By default it returns the numpy array
118143 but it can be controlled using `as_type` argument and `meta_only` argument.
@@ -124,57 +149,63 @@ def tensorget(self,
124149 only the shape and the type
125150 :return: an instance of as_type
126151 """
152+ args = ['AI.TENSORGET' , key , 'META' ]
153+ if not meta_only :
154+ if as_numpy is True :
155+ args .append ('BLOB' )
156+ else :
157+ args .append ('VALUES' )
158+
159+ res = self .execute_command (* args )
160+ res = utils .list2dict (res )
127161 if meta_only :
128- argname = 'META'
162+ return res
129163 elif as_numpy is True :
130- argname = 'BLOB'
164+ return utils . blob2numpy ( res [ 'blob' ], res [ 'shape' ], res [ 'dtype' ])
131165 else :
132- argname = 'VALUES'
166+ target = float if res ['dtype' ] in ('FLOAT' , 'DOUBLE' ) else int
167+ utils .un_bytize (res ['values' ], target )
168+ return res
133169
134- res = self .execute_command ('AI.TENSORGET' , key , argname )
135- dtype , shape = to_string (res [0 ]), res [1 ]
136- if meta_only :
137- return convert .to_sequence ([], shape , dtype )
138- if as_numpy is True :
139- return convert .to_numpy (res [2 ], shape , dtype )
140- else :
141- return convert .to_sequence (res [2 ], shape , dtype )
142-
143- def scriptset (self , name : AnyStr , device : Device , script : AnyStr , tag : str = None ) -> AnyStr :
144- args = ['AI.SCRIPTSET' , name , device .value ]
170+ def scriptset (self , name : str , device : str , script : str , tag : str = None ) -> str :
171+ args = ['AI.SCRIPTSET' , name , device ]
145172 if tag :
146173 args += ['TAG' , tag ]
147- args += [ script ]
148- return self .execute_command (* args )
174+ args . append ( script )
175+ return self .execute_command (* args ). decode ()
149176
150- def scriptget (self , name : AnyStr ) -> Script :
151- ret = self .execute_command ('AI.SCRIPTGET' , name )
152- ret = list_to_dict (ret )
153- return Script (ret ['source' ], Device (ret ['device' ]), ret ['tag' ])
177+ def scriptget (self , name : AnyStr , meta_only = False ) -> dict :
178+ # TODO scripget test
179+ args = ['AI.SCRIPTGET' , name , 'META' ]
180+ if not meta_only :
181+ args .append ('SOURCE' )
182+ ret = self .execute_command (* args )
183+ return utils .list2dict (ret )
154184
155- def scriptdel (self , name ) :
156- return self .execute_command ('AI.SCRIPTDEL' , name )
185+ def scriptdel (self , name : str ) -> str :
186+ return self .execute_command ('AI.SCRIPTDEL' , name ). decode ()
157187
158188 def scriptrun (self ,
159189 name : AnyStr ,
160190 function : AnyStr ,
161191 inputs : Union [AnyStr , Sequence [AnyStr ]],
162192 outputs : Union [AnyStr , Sequence [AnyStr ]]
163193 ) -> AnyStr :
164- args = ['AI.SCRIPTRUN' , name , function , 'INPUTS' ]
165- args += str_or_strsequence (inputs )
166- args += ['OUTPUTS' ]
167- args += str_or_strsequence (outputs )
168- return self .execute_command (* args )
169-
170- def scriptlist (self ):
194+ out = self .execute_command (
195+ 'AI.SCRIPTRUN' , name , function ,
196+ 'INPUTS' , * utils .listify (inputs ),
197+ 'OUTPUTS' , * utils .listify (outputs )
198+ )
199+ return out .decode ()
200+
201+ def scriptscan (self ) -> list :
171202 warnings .warn ("Experimental: Script List API is experimental and might change "
172203 "in the future without any notice" , UserWarning )
173- return self .execute_command ("AI._SCRIPTLIST" )
204+ return utils . un_bytize ( self .execute_command ("AI._SCRIPTSCAN" ), lambda x : x . decode () )
174205
175206 def infoget (self , key : str ) -> dict :
176207 ret = self .execute_command ('AI.INFO' , key )
177- return list_to_dict (ret )
208+ return utils . list2dict (ret )
178209
179- def inforeset (self , key : str ) -> dict :
180- return self .execute_command ('AI.INFO' , key , 'RESETSTAT' )
210+ def inforeset (self , key : str ) -> str :
211+ return self .execute_command ('AI.INFO' , key , 'RESETSTAT' ). decode ()
0 commit comments