3
3
import contextlib
4
4
import dataclasses
5
5
import typing
6
+ from enum import IntEnum
7
+
8
+ from pydantic import RootModel
9
+ from pydantic .dataclasses import dataclass
6
10
7
11
from ..._exceptions import NextcloudException , NextcloudExceptionNotFound
8
- from ..._misc import clear_from_params_empty , require_capabilities
12
+ from ..._misc import require_capabilities
9
13
from ..._session import AsyncNcSessionApp , NcSessionApp
10
14
11
15
_EP_SUFFIX : str = "ai_provider/task_processing"
12
16
13
17
14
- @dataclasses .dataclass
15
- class TaskProcessingProvider :
16
- """TaskProcessing provider description."""
18
+ class ShapeType (IntEnum ):
19
+ """Enum for shape types."""
20
+
21
+ NUMBER = 0
22
+ TEXT = 1
23
+ IMAGE = 2
24
+ AUDIO = 3
25
+ VIDEO = 4
26
+ FILE = 5
27
+ ENUM = 6
28
+ LISTOFNUMBERS = 10
29
+ LISTOFTEXTS = 11
30
+ LISTOFIMAGES = 12
31
+ LISTOFAUDIOS = 13
32
+ LISTOFVIDEOS = 14
33
+ LISTOFFILES = 15
34
+
35
+
36
+ @dataclass
37
+ class ShapeEnumValue :
38
+ """Data object for input output shape enum slot value."""
39
+
40
+ name : str
41
+ """Name of the enum slot value which will be displayed in the UI"""
42
+ value : str
43
+ """Value of the enum slot value"""
44
+
45
+
46
+ @dataclass
47
+ class ShapeDescriptor :
48
+ """Data object for input output shape entries."""
17
49
18
- def __init__ (self , raw_data : dict ):
19
- self ._raw_data = raw_data
50
+ name : str
51
+ """Name of the shape entry"""
52
+ description : str
53
+ """Description of the shape entry"""
54
+ shape_type : ShapeType
55
+ """Type of the shape entry"""
20
56
21
- @property
22
- def name (self ) -> str :
23
- """Unique ID for the provider."""
24
- return self ._raw_data ["name" ]
25
57
26
- @property
27
- def display_name (self ) -> str :
28
- """Providers display name."""
29
- return self ._raw_data ["display_name" ]
58
+ @dataclass
59
+ class TaskType :
60
+ """TaskType description for the provider."""
30
61
31
- @property
32
- def task_type (self ) -> str :
33
- """The TaskType provided by this provider."""
34
- return self ._raw_data ["task_type" ]
62
+ id : str
63
+ """The unique ID for the task type."""
64
+ name : str
65
+ """The localized name of the task type."""
66
+ description : str
67
+ """The localized description of the task type."""
68
+ input_shape : list [ShapeDescriptor ]
69
+ """The input shape of the task."""
70
+ output_shape : list [ShapeDescriptor ]
71
+ """The output shape of the task."""
72
+
73
+
74
+ @dataclass
75
+ class TaskProcessingProvider :
76
+ """TaskProcessing provider description."""
77
+
78
+ # pylint: disable=too-many-instance-attributes
79
+
80
+ id : str
81
+ """Unique ID for the provider."""
82
+ name : str
83
+ """The localized name of this provider"""
84
+ task_type : str
85
+ """The TaskType provided by this provider."""
86
+ expected_runtime : int = dataclasses .field (default = 0 )
87
+ """Expected runtime of the task in seconds."""
88
+ optional_input_shape : list [ShapeDescriptor ] = dataclasses .field (default_factory = list )
89
+ """Optional input shape of the task."""
90
+ optional_output_shape : list [ShapeDescriptor ] = dataclasses .field (default_factory = list )
91
+ """Optional output shape of the task."""
92
+ input_shape_enum_values : dict [str , list [ShapeEnumValue ]] = dataclasses .field (default_factory = dict )
93
+ """The option dict for each input shape ENUM slot."""
94
+ input_shape_defaults : dict [str , str | int | float ] = dataclasses .field (default_factory = dict )
95
+ """The default values for input shape slots."""
96
+ optional_input_shape_enum_values : dict [str , list [ShapeEnumValue ]] = dataclasses .field (default_factory = dict )
97
+ """The option list for each optional input shape ENUM slot."""
98
+ optional_input_shape_defaults : dict [str , str | int | float ] = dataclasses .field (default_factory = dict )
99
+ """The default values for optional input shape slots."""
100
+ output_shape_enum_values : dict [str , list [ShapeEnumValue ]] = dataclasses .field (default_factory = dict )
101
+ """The option list for each output shape ENUM slot."""
102
+ optional_output_shape_enum_values : dict [str , list [ShapeEnumValue ]] = dataclasses .field (default_factory = dict )
103
+ """The option list for each optional output shape ENUM slot."""
35
104
36
105
def __repr__ (self ):
37
106
return f"<{ self .__class__ .__name__ } name={ self .name } , type={ self .task_type } >"
@@ -44,17 +113,22 @@ def __init__(self, session: NcSessionApp):
44
113
self ._session = session
45
114
46
115
def register (
47
- self , name : str , display_name : str , task_type : str , custom_task_type : dict [str , typing .Any ] | None = None
116
+ self ,
117
+ name : str ,
118
+ display_name : str ,
119
+ task_type : str ,
120
+ provider : TaskProcessingProvider ,
121
+ custom_task_type : TaskType | None = None ,
48
122
) -> None :
49
123
"""Registers or edit the TaskProcessing provider."""
50
124
require_capabilities ("app_api" , self ._session .capabilities )
51
125
params = {
52
126
"name" : name ,
53
127
"displayName" : display_name ,
54
128
"taskType" : task_type ,
55
- "customTaskType" : custom_task_type ,
129
+ "provider" : RootModel (provider ).model_dump (),
130
+ ** ({"customTaskType" : RootModel (custom_task_type ).model_dump ()} if custom_task_type else {}),
56
131
}
57
- clear_from_params_empty (["customTaskType" ], params )
58
132
self ._session .ocs ("POST" , f"{ self ._session .ae_url } /{ _EP_SUFFIX } " , json = params )
59
133
60
134
def unregister (self , name : str , not_fail = True ) -> None :
@@ -123,17 +197,22 @@ def __init__(self, session: AsyncNcSessionApp):
123
197
self ._session = session
124
198
125
199
async def register (
126
- self , name : str , display_name : str , task_type : str , custom_task_type : dict [str , typing .Any ] | None = None
200
+ self ,
201
+ name : str ,
202
+ display_name : str ,
203
+ task_type : str ,
204
+ provider : TaskProcessingProvider ,
205
+ custom_task_type : TaskType | None = None ,
127
206
) -> None :
128
207
"""Registers or edit the TaskProcessing provider."""
129
208
require_capabilities ("app_api" , await self ._session .capabilities )
130
209
params = {
131
210
"name" : name ,
132
211
"displayName" : display_name ,
133
212
"taskType" : task_type ,
134
- "customTaskType" : custom_task_type ,
213
+ "provider" : RootModel (provider ).model_dump (),
214
+ ** ({"customTaskType" : RootModel (custom_task_type ).model_dump ()} if custom_task_type else {}),
135
215
}
136
- clear_from_params_empty (["customTaskType" ], params )
137
216
await self ._session .ocs ("POST" , f"{ self ._session .ae_url } /{ _EP_SUFFIX } " , json = params )
138
217
139
218
async def unregister (self , name : str , not_fail = True ) -> None :
0 commit comments