Skip to content

Commit 08573c0

Browse files
authored
fix: Zero division in inverse estimator functions (#65)
#### Motivation Fixes this warning: ``` packages/text_generation_server/utils/memory_characterizer.py:71: RuntimeWarning: invalid value encountered in scalar divide Shard 0: return (np.sqrt(c0**2 + 4*c1*(mem/batch)) - c0)/(2*c1) ``` #### Modifications When the memory characterizer doesn't find a linear or quadratic behavior, the coefficients are set to zero resulting in division by zero errors in the inverse functions. In this commit this situation is detected and the max float is returned to be consistent with the semantics of the memory estimator Signed-off-by: Max de Bayser <[email protected]>
1 parent cc6e911 commit 08573c0

File tree

1 file changed

+6
-2
lines changed

1 file changed

+6
-2
lines changed

server/text_generation_server/utils/memory_characterizer.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,17 +72,21 @@ def inverse_quadratic_prefill(self, batch, mem):
7272
return (np.sqrt(c0**2 + 4*c1*(mem/batch)) - c0)/(2*c1)
7373

7474
def inverse_prefill(self,batch, mem):
75-
linear = self.inverse_linear_prefill(batch,mem)
76-
quad = self.inverse_quadratic_prefill(batch, mem)
75+
linear = self.inverse_linear_prefill(batch,mem) if self.linear_fit_params[0] != 0.0 else sys.float_info.max
76+
quad = self.inverse_quadratic_prefill(batch, mem) if self.quadratic_fit_params[1] != 0.0 else sys.float_info.max
7777
return min(linear, quad)
7878

7979
def nt_memory_usage(self, batch_size, input_len, output_len):
8080
return batch_size * self.next_token_params[0] * input_len + batch_size * self.next_token_params[1] * output_len
8181

8282
def inverse_next_token_output(self, batch, in_seq, mem):
83+
if self.next_token_params[1] == 0.0:
84+
return sys.float_info.max
8385
return (mem - self.next_token_params[0]*batch*in_seq)/(batch*self.next_token_params[1])
8486

8587
def inverse_next_token_input(self, batch, out_seq, mem):
88+
if self.next_token_params[0] == 0.0:
89+
return sys.float_info.max
8690
return (mem - self.next_token_params[1]*batch*out_seq)/(batch*self.next_token_params[0])
8791

8892
def max_input_len_for_prefill(self, batch_size, max_input_len):

0 commit comments

Comments
 (0)