Skip to content

Commit ad33b1a

Browse files
committed
feat: Add enum and default value support in task processing
Signed-off-by: Anupam Kumar <[email protected]>
1 parent 3d26987 commit ad33b1a

File tree

1 file changed

+103
-24
lines changed

1 file changed

+103
-24
lines changed

nc_py_api/ex_app/providers/task_processing.py

+103-24
Original file line numberDiff line numberDiff line change
@@ -3,35 +3,104 @@
33
import contextlib
44
import dataclasses
55
import typing
6+
from enum import IntEnum
7+
8+
from pydantic import RootModel
9+
from pydantic.dataclasses import dataclass
610

711
from ..._exceptions import NextcloudException, NextcloudExceptionNotFound
8-
from ..._misc import clear_from_params_empty, require_capabilities
12+
from ..._misc import require_capabilities
913
from ..._session import AsyncNcSessionApp, NcSessionApp
1014

1115
_EP_SUFFIX: str = "ai_provider/task_processing"
1216

1317

14-
@dataclasses.dataclass
15-
class TaskProcessingProvider:
16-
"""TaskProcessing provider description."""
18+
class ShapeType(IntEnum):
19+
"""Enum for shape types."""
20+
21+
NUMBER = 0
22+
TEXT = 1
23+
IMAGE = 2
24+
AUDIO = 3
25+
VIDEO = 4
26+
FILE = 5
27+
ENUM = 6
28+
LISTOFNUMBERS = 10
29+
LISTOFTEXTS = 11
30+
LISTOFIMAGES = 12
31+
LISTOFAUDIOS = 13
32+
LISTOFVIDEOS = 14
33+
LISTOFFILES = 15
34+
35+
36+
@dataclass
37+
class ShapeEnumValue:
38+
"""Data object for input output shape enum slot value."""
39+
40+
name: str
41+
"""Name of the enum slot value which will be displayed in the UI"""
42+
value: str
43+
"""Value of the enum slot value"""
44+
45+
46+
@dataclass
47+
class ShapeDescriptor:
48+
"""Data object for input output shape entries."""
1749

18-
def __init__(self, raw_data: dict):
19-
self._raw_data = raw_data
50+
name: str
51+
"""Name of the shape entry"""
52+
description: str
53+
"""Description of the shape entry"""
54+
shape_type: ShapeType
55+
"""Type of the shape entry"""
2056

21-
@property
22-
def name(self) -> str:
23-
"""Unique ID for the provider."""
24-
return self._raw_data["name"]
2557

26-
@property
27-
def display_name(self) -> str:
28-
"""Providers display name."""
29-
return self._raw_data["display_name"]
58+
@dataclass
59+
class TaskType:
60+
"""TaskType description for the provider."""
3061

31-
@property
32-
def task_type(self) -> str:
33-
"""The TaskType provided by this provider."""
34-
return self._raw_data["task_type"]
62+
id: str
63+
"""The unique ID for the task type."""
64+
name: str
65+
"""The localized name of the task type."""
66+
description: str
67+
"""The localized description of the task type."""
68+
input_shape: list[ShapeDescriptor]
69+
"""The input shape of the task."""
70+
output_shape: list[ShapeDescriptor]
71+
"""The output shape of the task."""
72+
73+
74+
@dataclass
75+
class TaskProcessingProvider:
76+
"""TaskProcessing provider description."""
77+
78+
# pylint: disable=too-many-instance-attributes
79+
80+
id: str
81+
"""Unique ID for the provider."""
82+
name: str
83+
"""The localized name of this provider"""
84+
task_type: str
85+
"""The TaskType provided by this provider."""
86+
expected_runtime: int = dataclasses.field(default=0)
87+
"""Expected runtime of the task in seconds."""
88+
optional_input_shape: list[ShapeDescriptor] = dataclasses.field(default_factory=list)
89+
"""Optional input shape of the task."""
90+
optional_output_shape: list[ShapeDescriptor] = dataclasses.field(default_factory=list)
91+
"""Optional output shape of the task."""
92+
input_shape_enum_values: dict[str, list[ShapeEnumValue]] = dataclasses.field(default_factory=dict)
93+
"""The option dict for each input shape ENUM slot."""
94+
input_shape_defaults: dict[str, str | int | float] = dataclasses.field(default_factory=dict)
95+
"""The default values for input shape slots."""
96+
optional_input_shape_enum_values: dict[str, list[ShapeEnumValue]] = dataclasses.field(default_factory=dict)
97+
"""The option list for each optional input shape ENUM slot."""
98+
optional_input_shape_defaults: dict[str, str | int | float] = dataclasses.field(default_factory=dict)
99+
"""The default values for optional input shape slots."""
100+
output_shape_enum_values: dict[str, list[ShapeEnumValue]] = dataclasses.field(default_factory=dict)
101+
"""The option list for each output shape ENUM slot."""
102+
optional_output_shape_enum_values: dict[str, list[ShapeEnumValue]] = dataclasses.field(default_factory=dict)
103+
"""The option list for each optional output shape ENUM slot."""
35104

36105
def __repr__(self):
37106
return f"<{self.__class__.__name__} name={self.name}, type={self.task_type}>"
@@ -44,17 +113,22 @@ def __init__(self, session: NcSessionApp):
44113
self._session = session
45114

46115
def register(
47-
self, name: str, display_name: str, task_type: str, custom_task_type: dict[str, typing.Any] | None = None
116+
self,
117+
name: str,
118+
display_name: str,
119+
task_type: str,
120+
provider: TaskProcessingProvider,
121+
custom_task_type: TaskType | None = None,
48122
) -> None:
49123
"""Registers or edit the TaskProcessing provider."""
50124
require_capabilities("app_api", self._session.capabilities)
51125
params = {
52126
"name": name,
53127
"displayName": display_name,
54128
"taskType": task_type,
55-
"customTaskType": custom_task_type,
129+
"provider": RootModel(provider).model_dump(),
130+
**({"customTaskType": RootModel(custom_task_type).model_dump()} if custom_task_type else {}),
56131
}
57-
clear_from_params_empty(["customTaskType"], params)
58132
self._session.ocs("POST", f"{self._session.ae_url}/{_EP_SUFFIX}", json=params)
59133

60134
def unregister(self, name: str, not_fail=True) -> None:
@@ -123,17 +197,22 @@ def __init__(self, session: AsyncNcSessionApp):
123197
self._session = session
124198

125199
async def register(
126-
self, name: str, display_name: str, task_type: str, custom_task_type: dict[str, typing.Any] | None = None
200+
self,
201+
name: str,
202+
display_name: str,
203+
task_type: str,
204+
provider: TaskProcessingProvider,
205+
custom_task_type: TaskType | None = None,
127206
) -> None:
128207
"""Registers or edit the TaskProcessing provider."""
129208
require_capabilities("app_api", await self._session.capabilities)
130209
params = {
131210
"name": name,
132211
"displayName": display_name,
133212
"taskType": task_type,
134-
"customTaskType": custom_task_type,
213+
"provider": RootModel(provider).model_dump(),
214+
**({"customTaskType": RootModel(custom_task_type).model_dump()} if custom_task_type else {}),
135215
}
136-
clear_from_params_empty(["customTaskType"], params)
137216
await self._session.ocs("POST", f"{self._session.ae_url}/{_EP_SUFFIX}", json=params)
138217

139218
async def unregister(self, name: str, not_fail=True) -> None:

0 commit comments

Comments
 (0)