Skip to content

Commit 3216be1

Browse files
author
Manuel Riesen
committed
feat(tasks): add sort_list task
1 parent 29e9f34 commit 3216be1

File tree

1 file changed

+134
-0
lines changed

1 file changed

+134
-0
lines changed
+134
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
from pure_graph_of_thoughts.api.language_model import Prompt, Example
2+
from pure_graph_of_thoughts.api.operation import PromptOperation, OperationType, relative_complexity, ScoreExecOperation
3+
from pure_graph_of_thoughts.api.state import State
4+
from pure_graph_of_thoughts.api.task import Task, Evaluator
5+
6+
op_split = PromptOperation(
7+
name='split',
8+
n_outputs=2,
9+
n_inputs=1,
10+
type=OperationType.GENERATE,
11+
output_complexity=relative_complexity(1, 2),
12+
prompt=Prompt(
13+
instruction='Split the given list into two lists of equal size. '
14+
'Count the number of elements in the list before deciding where to split.'
15+
'Only output the lists in JSON format as the examples show.',
16+
examples=[
17+
Example(
18+
input={
19+
'list': [8, 2, 0, 1]
20+
},
21+
output={
22+
'lists': [
23+
[8, 2],
24+
[0, 1]
25+
]
26+
}
27+
)
28+
],
29+
),
30+
transform_before=lambda states: {
31+
'list': []
32+
} if len(states) == 0 else {
33+
'list': [states[0]['sum']]
34+
} if 'sum' in states[0] else states[0],
35+
transform_after=lambda state: [{'list': state_list} for state_list in state['lists']]
36+
)
37+
38+
op_merge = PromptOperation(
39+
name='merge',
40+
n_outputs=1,
41+
n_inputs=2,
42+
type=OperationType.AGGREGATE,
43+
output_complexity=relative_complexity(2),
44+
prompt=Prompt(
45+
instruction='Combine the given sorted lists to a single sorted list.'
46+
'Apply a merge sort to sort the final list.'
47+
'Only output the list in JSON format as the examples show.',
48+
examples=[
49+
Example(
50+
input={
51+
'lists': [
52+
[8, 2],
53+
[0, 1]
54+
]
55+
},
56+
output={
57+
'list': [8, 2, 0, 1]
58+
}
59+
)
60+
]
61+
),
62+
transform_before=lambda states: {
63+
'lists': [
64+
item for item in [
65+
state['list'] if 'list' in state
66+
else None for state in states
67+
] if item is not None
68+
]
69+
}
70+
)
71+
72+
73+
def score_op_sort(cumulative_score: float, previous_state: State, current_state: State) -> float:
74+
"""
75+
Determines the score of the sort operation.
76+
:param cumulative_score: cumulative score
77+
:param previous_state: previous state
78+
:param current_state: current state
79+
:return: score
80+
"""
81+
if cumulative_score < 0.0:
82+
return -1.0
83+
current_list = current_state['list'] if 'list' in current_state else None
84+
previous_list = (
85+
sorted(previous_state['list']) if 'list' in previous_state
86+
else sorted(
87+
list_item for state_list in previous_state['lists'] for list_item in state_list
88+
) if 'lists' in previous_state
89+
else None
90+
)
91+
if current_list is not None and previous_list is not None and current_list == previous_list:
92+
return 1.0
93+
return -1.0
94+
95+
96+
op_sort = PromptOperation(
97+
name='sort',
98+
type=OperationType.GENERATE,
99+
n_inputs=1,
100+
n_outputs=1,
101+
output_complexity=relative_complexity(1),
102+
prompt=Prompt(
103+
instruction='Sort the given list of single-digit integers in ascending order. Output the sorted list in JSON format.',
104+
examples=[
105+
Example(
106+
input={
107+
'list': [4, 9, 8, 4, 9, 1, 5, 6, 2, 8, 9, 9, 9, 2, 1, 5]
108+
},
109+
output={
110+
'list': [1, 1, 2, 2, 4, 4, 5, 5, 6, 8, 8, 9, 9, 9, 9, 9]
111+
}
112+
)
113+
]
114+
),
115+
transform_before=lambda states: {
116+
'list': []
117+
} if len(states) == 0 else states[0],
118+
score_operation=ScoreExecOperation(
119+
name='score',
120+
type=OperationType.SCORE,
121+
score=score_op_sort,
122+
n_inputs=1,
123+
n_outputs=1
124+
)
125+
)
126+
127+
sort_list_task = Task(
128+
operations=[op_sort, op_split, op_merge],
129+
evaluator=Evaluator(
130+
lambda initial_state, state: 'list' in initial_state
131+
and 'list' in state
132+
and sorted(initial_state['list']) == state['list']
133+
)
134+
)

0 commit comments

Comments
 (0)