3
3
import contextlib
4
4
import dataclasses
5
5
import typing
6
+ from enum import IntEnum
7
+
8
+ from pydantic import RootModel
9
+ from pydantic .dataclasses import dataclass
6
10
7
11
from ..._exceptions import NextcloudException , NextcloudExceptionNotFound
8
- from ..._misc import clear_from_params_empty , require_capabilities
12
+ from ..._misc import require_capabilities
9
13
from ..._session import AsyncNcSessionApp , NcSessionApp
10
14
11
15
_EP_SUFFIX : str = "ai_provider/task_processing"
12
16
13
17
14
- @dataclasses .dataclass
15
- class TaskProcessingProvider :
16
- """TaskProcessing provider description."""
18
+ class ShapeType (IntEnum ):
19
+ """Enum for shape types."""
20
+
21
+ NUMBER = 0
22
+ TEXT = 1
23
+ IMAGE = 2
24
+ AUDIO = 3
25
+ VIDEO = 4
26
+ FILE = 5
27
+ ENUM = 6
28
+ LIST_OF_NUMBERS = 10
29
+ LIST_OF_TEXTS = 11
30
+ LIST_OF_IMAGES = 12
31
+ LIST_OF_AUDIOS = 13
32
+ LIST_OF_VIDEOS = 14
33
+ LIST_OF_FILES = 15
34
+
35
+
36
+ @dataclass
37
+ class ShapeEnumValue :
38
+ """Data object for input output shape enum slot value."""
39
+
40
+ name : str
41
+ """Name of the enum slot value which will be displayed in the UI"""
42
+ value : str
43
+ """Value of the enum slot value"""
44
+
17
45
18
- def __init__ (self , raw_data : dict ):
19
- self ._raw_data = raw_data
46
+ @dataclass
47
+ class ShapeDescriptor :
48
+ """Data object for input output shape entries."""
20
49
21
- @property
22
- def name (self ) -> str :
23
- """Unique ID for the provider."""
24
- return self ._raw_data ["name" ]
50
+ name : str
51
+ """Name of the shape entry"""
52
+ description : str
53
+ """Description of the shape entry"""
54
+ shape_type : ShapeType
55
+ """Type of the shape entry"""
25
56
26
- @property
27
- def display_name (self ) -> str :
28
- """Providers display name."""
29
- return self ._raw_data ["display_name" ]
30
57
31
- @property
32
- def task_type (self ) -> str :
33
- """The TaskType provided by this provider."""
34
- return self ._raw_data ["task_type" ]
58
+ @dataclass
59
+ class TaskType :
60
+ """TaskType description for the provider."""
61
+
62
+ id : str
63
+ """The unique ID for the task type."""
64
+ name : str
65
+ """The localized name of the task type."""
66
+ description : str
67
+ """The localized description of the task type."""
68
+ input_shape : list [ShapeDescriptor ]
69
+ """The input shape of the task."""
70
+ output_shape : list [ShapeDescriptor ]
71
+ """The output shape of the task."""
72
+
73
+
74
+ @dataclass
75
+ class TaskProcessingProvider :
76
+
77
+ id : str
78
+ """Unique ID for the provider."""
79
+ name : str
80
+ """The localized name of this provider"""
81
+ task_type : str
82
+ """The TaskType provided by this provider."""
83
+ expected_runtime : int = dataclasses .field (default = 0 )
84
+ """Expected runtime of the task in seconds."""
85
+ optional_input_shape : list [ShapeDescriptor ] = dataclasses .field (default_factory = list )
86
+ """Optional input shape of the task."""
87
+ optional_output_shape : list [ShapeDescriptor ] = dataclasses .field (default_factory = list )
88
+ """Optional output shape of the task."""
89
+ input_shape_enum_values : dict [str , list [ShapeEnumValue ]] = dataclasses .field (default_factory = dict )
90
+ """The option dict for each input shape ENUM slot."""
91
+ input_shape_defaults : dict [str , str | int | float ] = dataclasses .field (default_factory = dict )
92
+ """The default values for input shape slots."""
93
+ optional_input_shape_enum_values : dict [str , list [ShapeEnumValue ]] = dataclasses .field (default_factory = dict )
94
+ """The option list for each optional input shape ENUM slot."""
95
+ optional_input_shape_defaults : dict [str , str | int | float ] = dataclasses .field (default_factory = dict )
96
+ """The default values for optional input shape slots."""
97
+ output_shape_enum_values : dict [str , list [ShapeEnumValue ]] = dataclasses .field (default_factory = dict )
98
+ """The option list for each output shape ENUM slot."""
99
+ optional_output_shape_enum_values : dict [str , list [ShapeEnumValue ]] = dataclasses .field (default_factory = dict )
100
+ """The option list for each optional output shape ENUM slot."""
35
101
36
102
def __repr__ (self ):
37
103
return f"<{ self .__class__ .__name__ } name={ self .name } , type={ self .task_type } >"
@@ -44,17 +110,16 @@ def __init__(self, session: NcSessionApp):
44
110
self ._session = session
45
111
46
112
def register (
47
- self , name : str , display_name : str , task_type : str , custom_task_type : dict [str , typing .Any ] | None = None
113
+ self ,
114
+ provider : TaskProcessingProvider ,
115
+ custom_task_type : TaskType | None = None ,
48
116
) -> None :
49
117
"""Registers or edit the TaskProcessing provider."""
50
118
require_capabilities ("app_api" , self ._session .capabilities )
51
119
params = {
52
- "name" : name ,
53
- "displayName" : display_name ,
54
- "taskType" : task_type ,
55
- "customTaskType" : custom_task_type ,
120
+ "provider" : RootModel (provider ).model_dump (),
121
+ ** ({"customTaskType" : RootModel (custom_task_type ).model_dump ()} if custom_task_type else {}),
56
122
}
57
- clear_from_params_empty (["customTaskType" ], params )
58
123
self ._session .ocs ("POST" , f"{ self ._session .ae_url } /{ _EP_SUFFIX } " , json = params )
59
124
60
125
def unregister (self , name : str , not_fail = True ) -> None :
@@ -123,17 +188,16 @@ def __init__(self, session: AsyncNcSessionApp):
123
188
self ._session = session
124
189
125
190
async def register (
126
- self , name : str , display_name : str , task_type : str , custom_task_type : dict [str , typing .Any ] | None = None
191
+ self ,
192
+ provider : TaskProcessingProvider ,
193
+ custom_task_type : TaskType | None = None ,
127
194
) -> None :
128
195
"""Registers or edit the TaskProcessing provider."""
129
196
require_capabilities ("app_api" , await self ._session .capabilities )
130
197
params = {
131
- "name" : name ,
132
- "displayName" : display_name ,
133
- "taskType" : task_type ,
134
- "customTaskType" : custom_task_type ,
198
+ "provider" : RootModel (provider ).model_dump (),
199
+ ** ({"customTaskType" : RootModel (custom_task_type ).model_dump ()} if custom_task_type else {}),
135
200
}
136
- clear_from_params_empty (["customTaskType" ], params )
137
201
await self ._session .ocs ("POST" , f"{ self ._session .ae_url } /{ _EP_SUFFIX } " , json = params )
138
202
139
203
async def unregister (self , name : str , not_fail = True ) -> None :
0 commit comments