Skip to content

Commit e895bbf

Browse files
Jake VanderPlasThe precondition Authors
Jake VanderPlas
authored and
The precondition Authors
committed
[LSC] Ignore incorrect type annotations related to jax.numpy APIs
PiperOrigin-RevId: 568529477
1 parent e059707 commit e895bbf

File tree

2 files changed

+2
-2
lines changed

2 files changed

+2
-2
lines changed

precondition/distributed_shampoo.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1583,7 +1583,7 @@ def updated_statistics_from_grad(
15831583
for axis in preconditioned_dims:
15841584
update = functools.partial(gram_weighted_update, precision=precision)
15851585
if frequent_directions:
1586-
if _should_compress(self._compression_rank, g.shape[axis]):
1586+
if _should_compress(self._compression_rank, g.shape[axis]): # pytype: disable=wrong-arg-types # jnp-type
15871587
update = frequent_directions_update
15881588
new_stat = update(to_float(stats[index]), g, axis, w1, w2)
15891589
new_stats.append(from_float(new_stat))

precondition/tearfree/reallocation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,7 @@ def score_fn(
154154
score_dict[name] = jnp.mean(
155155
jnp.array([ops_dict[rule](ct) for ct in current_target])
156156
)
157-
return score_dict
157+
return score_dict # pytype: disable=bad-return-type # jnp-type
158158

159159

160160
def create_redist_dict(

0 commit comments

Comments
 (0)