
Commit 16aaacd

Merge pull request #85 from opentensor/staging
2.1.5 Release
2 parents b498f12 + 91ed7bf commit 16aaacd


7 files changed: +72, -28 lines changed


neurons/validators/validator.py

Lines changed: 1 addition & 1 deletion

@@ -194,7 +194,7 @@ def __init__(self):
             MockRewardModel(RewardModelType.nsfw.value),
         ]
         self.penalty_functions = [
-            TaskValidationPenaltyModel(max_penalty=0.6),
+            TaskValidationPenaltyModel(max_penalty=0.75),
             ContentMatchPenaltyModel(max_penalty=0.2),
             KeywordMatchPenaltyModel(max_penalty=1),
         ]

prompting/validators/__init__.py

Lines changed: 1 addition & 1 deletion

@@ -27,7 +27,7 @@
 from . import event
 from . import dataset

-__version__ = "2.1.4"
+__version__ = "2.1.5"
 version_split = __version__.split(".")
 __spec_version__ = (
     (1000 * int(version_split[0]))

prompting/validators/criteria.py

Lines changed: 1 addition & 1 deletion

@@ -241,4 +241,4 @@ def evaluate(self, completions: List[str]) -> torch.FloatTensor:
         return penalties

     def compose_text(self) -> str:
-        return self.text.format(layout_type=self.layout_type)
+        return self.text.format(layout_type=self.layout_type.value)
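
The .value access matters because compose_text interpolates the layout type into human-readable prompt text. A minimal sketch of the difference, assuming layout_type is a plain Enum member (the member values below are made up for illustration):

    from enum import Enum

    # Stand-in for LayoutMatchTypeEnum; the string values are assumptions.
    class LayoutMatchTypeEnum(Enum):
        UNORDERED_LIST = "bullet points"
        NUMBERED_LIST = "numbered list"

    text = "Your response should be formatted as a {layout_type}."

    # Without .value the enum member itself is interpolated:
    text.format(layout_type=LayoutMatchTypeEnum.NUMBERED_LIST)
    # -> 'Your response should be formatted as a LayoutMatchTypeEnum.NUMBERED_LIST.'

    # With .value the readable string is used, which is what this fix restores:
    text.format(layout_type=LayoutMatchTypeEnum.NUMBERED_LIST.value)
    # -> 'Your response should be formatted as a numbered list.'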

prompting/validators/forward.py

Lines changed: 6 additions & 1 deletion

@@ -221,7 +221,7 @@ async def run_step(self, task: Task, k: int, timeout: float, exclude: list = [])
     return event


-async def forward(self):
+async def questions_and_answers_around_summary_flow(self):
     # Obtain a unique context from the dataset.
     data = next(self.dataset)["text"]

@@ -272,3 +272,8 @@ async def forward(self):
         )

     exclude += qa_event["uids"]
+
+
+async def forward(self):
+    # Definition of flow to be executed at forward step
+    await questions_and_answers_around_summary_flow(self)
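
The rename keeps forward() as a thin entry point that only dispatches to a named flow. A hedged sketch of why that shape helps (the second flow name below is hypothetical and only marks where a future flow would hook in):

    async def forward(self):
        # Definition of flow to be executed at forward step
        await questions_and_answers_around_summary_flow(self)
        # A future flow could be awaited from the same entry point, e.g.:
        # await some_other_validation_flow(self)  # hypothetical name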

prompting/validators/reward/blacklist.py

Lines changed: 4 additions & 4 deletions

@@ -43,9 +43,9 @@ def name(self) -> str:

     def __init__(
         self,
-        boundary: float = 6,
+        boundary: float = 40,
         n_min: int = 5,
-        n_max: int = 14,
+        n_max: int = 10,
         word_limit: int = 2000,
         A: float = 1.3,
         preprocess: str = "[^(\\w|\\s)]",
@@ -213,7 +213,7 @@ def calculate_significance(self) -> dict:
         if len(decoded_ngram.split()) >= self.n_min:
             # calculate significance score for ngram
             significance_scores[decoded_ngram] = (
-                self.A ** (len(decoded_ngram) - 1)
+                self.A ** (len(decoded_ngram.split()) - 1)
                 * ((count[0] + count[1]) / self.num_completion)
                 * self.frequency_multiplier
             )
@@ -302,7 +302,7 @@ def reward(self, prompt: str, completion: str, name: str) -> BlacklistRewardEvent
             and fuzz.partial_ratio(ngram, completion.lower())
             > self.partial_ratio_boundary
         ):
-            reward_event.reward = 1
+            reward_event.reward = 0
             reward_event.matched_ngram = ngram
             reward_event.significance_score = score
         return reward_event
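
The significance-score change swaps a character count for a word count in the exponent: len(decoded_ngram) counts characters, while len(decoded_ngram.split()) counts words, so the old code raised A to a much larger power for any multi-word n-gram. A rough illustration with the default A = 1.3 (the n-gram is made up):

    A = 1.3
    decoded_ngram = "quick brown fox"  # illustrative n-gram

    old_factor = A ** (len(decoded_ngram) - 1)           # character count: 1.3 ** 14
    new_factor = A ** (len(decoded_ngram.split()) - 1)   # word count: 1.3 ** 2

    # old_factor is roughly 39.4, new_factor is roughly 1.69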

prompting/validators/reward/reward.py

Lines changed: 1 addition & 1 deletion

@@ -178,7 +178,7 @@ def apply(
             bt.logging.warning(
                 f"The tensor from {self.name} contains NaN values: {filled_rewards_normalized}"
             )
-            filled_rewards_normalized.nan_to_num_(nan=0.0)
+            filled_rewards_normalized = filled_rewards_normalized.nan_to_num_(nan=0.0)

         # Return the filled rewards.
         return filled_rewards_normalized, reward_events
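
Both forms behave the same because torch's trailing-underscore methods modify the tensor in place and return that same tensor; assigning the result back just makes the returned value explicit. A quick check:

    import torch

    t = torch.tensor([1.0, float("nan"), 3.0])
    out = t.nan_to_num_(nan=0.0)  # in place; also returns the tensor

    # t is now tensor([1., 0., 3.]) and out is the same object as t
    assert out is t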

prompting/validators/tasks.py

Lines changed: 58 additions & 19 deletions

@@ -28,6 +28,8 @@
     ContentMatchTypeEnum,
     SimpleResponseLayoutCriteria,
     MatchContentCriteria,
+    MatchLayoutCriteria,
+    LayoutMatchTypeEnum,
 )


@@ -122,24 +124,61 @@ def compose_prompt(self) -> str:


 def create_summarization_task(base_text: str) -> SummaryTask:
-    possible_criterias = [
-        MatchLengthCriteria(
-            penalty=0.1,
-            target_length=random.randint(50, 200),
-            unit=TextLengthUnitEnum.WORDS,
-        ),
-        MatchLengthCriteria(
-            penalty=0.1,
-            target_length=random.randint(4, 8),
-            unit=TextLengthUnitEnum.SENTENCES,
-        ),
-    ]
+    # scope 1: bullet points, scope 2: numbered list, scope 3: simple layout
+    scope = random.randint(1, 3)
+
+    select_bullet_point_layout = scope == 1
+    select_numbered_list_layout = scope == 2
+
+    # scope 1 or 2: define criteria set for bullet points or numbered list
+    if select_bullet_point_layout or select_numbered_list_layout:
+        if select_bullet_point_layout:
+            layout_criteria = MatchLayoutCriteria(
+                layout_type=LayoutMatchTypeEnum.UNORDERED_LIST,
+                penalty=0.5,
+                text="Your response should be ordered in format of bullet points.",
+            )
+        else:
+            layout_criteria = MatchLayoutCriteria(
+                layout_type=LayoutMatchTypeEnum.NUMBERED_LIST,
+                penalty=0.5,
+            )
+
+        possible_other_criterion = [
+            MatchLengthCriteria(
+                penalty=0.25,
+                target_length=random.randint(100, 200),
+                unit=TextLengthUnitEnum.WORDS,
+            ),
+            MatchLengthCriteria(
+                penalty=0.25,
+                target_length=random.randint(8, 12),
+                unit=TextLengthUnitEnum.SENTENCES,
+            ),
+        ]
+    # scope 3: define criteria set for simple layout
+    else:
+        layout_criteria = SimpleResponseLayoutCriteria(penalty=0.5)
+
+        possible_other_criterion = [
+            MatchLengthCriteria(
+                penalty=0.25,
+                target_length=random.randint(50, 200),
+                unit=TextLengthUnitEnum.WORDS,
+            ),
+            MatchLengthCriteria(
+                penalty=0.25,
+                target_length=random.randint(4, 8),
+                unit=TextLengthUnitEnum.SENTENCES,
+            ),
+        ]

-    sampled_criterias = random.sample(possible_criterias, 1)
+    random_sampled_criterion = random.sample(possible_other_criterion, 1)
+    defined_criteria = [layout_criteria] + random_sampled_criterion

     return SummaryTask(
         base_text=base_text,
-        criteria=sampled_criterias,
+        criteria=defined_criteria,
         task_type="summarization",
         task_name="augment",
     )
@@ -192,12 +231,12 @@ def create_qg_task(base_text: str, index: int) -> QuestionGenerationTask:

     other_random_criteria = [
         MatchLengthCriteria(
-            penalty=0.1,
+            penalty=0.25,
             target_length=random.randint(10, 40),
             unit=TextLengthUnitEnum.WORDS,
         ),
         MatchLengthCriteria(
-            penalty=0.1,
+            penalty=0.25,
             target_length=random.randint(40, 150),
             unit=TextLengthUnitEnum.CHARACTERS,
         ),
@@ -221,16 +260,16 @@ def create_qa_task(base_text: str, index: int) -> QuestionAnswerTask:
     answer_should_not_include_criteria = MatchContentCriteria(
         words_array=["?"],
         n_words=1,
-        penalty=0.2,
+        penalty=0.25,
         contentMatchType=ContentMatchTypeEnum.INCLUDES,
         negate_match=True,
         text="Your response should not include any question marks",
     )

-    simple_response_layout_criteria = SimpleResponseLayoutCriteria(penalty=0.2)
+    simple_response_layout_criteria = SimpleResponseLayoutCriteria(penalty=0.25)

     words_criteria = MatchLengthCriteria(
-        penalty=0.2,
+        penalty=0.25,
         target_length=random.randint(50, 200),
         unit=TextLengthUnitEnum.WORDS,
     )
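
Net effect of the new summarization criteria: each task now gets exactly two criteria, one layout criterion chosen by the random scope (bullet points, numbered list, or simple layout) plus one length criterion sampled from the matching pair. A toy sketch of just the selection shape, using stand-in strings instead of the real criterion classes:

    import random

    scope = random.randint(1, 3)  # 1: bullet points, 2: numbered list, 3: simple layout

    if scope in (1, 2):
        layout_criteria = "bullet-point layout" if scope == 1 else "numbered-list layout"
        possible_other_criterion = ["100-200 word target", "8-12 sentence target"]
    else:
        layout_criteria = "simple layout"
        possible_other_criterion = ["50-200 word target", "4-8 sentence target"]

    defined_criteria = [layout_criteria] + random.sample(possible_other_criterion, 1)
    # e.g. ['numbered-list layout', '8-12 sentence target']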
