Skip to content

Commit

Permalink
Merge pull request #210 from starquakee/evoxtorch-dev-fcc
Browse files Browse the repository at this point in the history
Change cec2022.py to avoid accuracy blowout in the extreme
  • Loading branch information
BillHuang2001 authored Feb 10, 2025
2 parents c2522d6 + a920c5d commit 8e3770d
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions src/evox/problems/numerical/cec2022.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,12 +127,12 @@ def cf_cal(self, x: torch.Tensor, fit: List[torch.Tensor], delta: List[int], bia
diff = x - shift[:, i * nx : (i + 1) * nx]
w = torch.sum(diff**2, dim=1)
w = torch.where(w != 0, (1 / torch.sqrt(w)) * torch.exp(-w / (2 * nx * d * d)), torch.inf)
w_sum += w
w_sum = w_sum + w
w_all.append(w * (f + b))
w_ret = torch.zeros(x.size(0), device=x.device)
w_sum = torch.where(w_sum == 0, 1e-9, w_sum)
for w in w_all:
w_ret += w / w_sum
w_ret = w_ret + w / w_sum
return w_ret

# cSpell:words Zakharov Rosenbrock Schaffer Rastrigin hgbat katsuura ackley schwefel happycat grie_rosen ellips escaffer griewank
Expand Down Expand Up @@ -217,7 +217,8 @@ def cec2022_f9(self, x: torch.Tensor) -> torch.Tensor:
self.sr_func_rate(x, 1.0, True, True, self.OShift[:, 2 * nx : 3 * nx], self.M[:, 2 * nx : 3 * nx])
)
* 10000
/ 1e30,
/ 1e10 / 1e10 / 1e10,
# if divide by 1e30 , cause NVRTC compilation error(https://github.com/pytorch/pytorch/issues/62962)
self.discus_func(self.sr_func_rate(x, 1.0, True, True, self.OShift[:, 3 * nx : 4 * nx], self.M[:, 3 * nx : 4 * nx]))
* 10000
/ 1e10,
Expand Down Expand Up @@ -310,7 +311,8 @@ def cec2022_f12(self, x: torch.Tensor) -> torch.Tensor:
self.sr_func_rate(x, 1.0, True, True, self.OShift[:, 3 * nx : 4 * nx], self.M[:, 3 * nx : 4 * nx])
)
* 10000
/ 1e30,
/ 1e10 / 1e10 / 1e10,
# if divide by 1e30 , cause NVRTC compilation error(https://github.com/pytorch/pytorch/issues/62962)
self.ellips_func(self.sr_func_rate(x, 1.0, True, True, self.OShift[:, 4 * nx : 5 * nx], self.M[:, 4 * nx : 5 * nx]))
* 10000
/ 1e10,
Expand Down

0 comments on commit 8e3770d

Please sign in to comment.