diff --git a/array_api_tests/test_searching_functions.py b/array_api_tests/test_searching_functions.py
index a12e9d52..a159c4a4 100644
--- a/array_api_tests/test_searching_functions.py
+++ b/array_api_tests/test_searching_functions.py
@@ -3,6 +3,7 @@
 import pytest
 from hypothesis import given, note
 from hypothesis import strategies as st
+from hypothesis.control import assume
 
 from . import _array_module as xp
 from . import dtype_helpers as dh
@@ -203,3 +204,102 @@ def test_searchsorted(data):
         expected=xp.__array_namespace_info__().default_dtypes()["indexing"],
     )
     # TODO: shapes and values testing
+
+
+@pytest.mark.unvectorized
+# TODO: Test with signed zeros and NaNs (and ignore them somehow)
+@given(
+    x=hh.arrays(
+        dtype=hh.real_dtypes,
+        shape=hh.shapes(min_dims=1, min_side=1),
+        elements={"allow_nan": False},
+    ),
+    data=st.data()
+)
+def test_top_k(x, data):
+
+    if dh.is_float_dtype(x.dtype):
+        assume(not xp.any(x == -0.0) and not xp.any(x == +0.0))
+
+    axis = data.draw(
+        st.integers(-x.ndim, x.ndim - 1), label='axis')
+    largest = data.draw(st.booleans(), label='largest')
+    if axis is None:
+        k = data.draw(st.integers(1, math.prod(x.shape)))
+    else:
+        k = data.draw(st.integers(1, x.shape[axis]))
+
+    kw = dict(
+        x=x,
+        k=k,
+        axis=axis,
+        largest=largest,
+    )
+
+    (out_values, out_indices) = xp.top_k(x, k, axis, largest=largest)
+    if axis is None:
+        x = xp.reshape(x, (-1,))
+        axis = 0
+
+    ph.assert_dtype("top_k", in_dtype=x.dtype, out_dtype=out_values.dtype)
+    ph.assert_dtype(
+        "top_k",
+        in_dtype=x.dtype,
+        out_dtype=out_indices.dtype,
+        expected=dh.default_int
+    )
+    axes, = sh.normalise_axis(axis, x.ndim)
+    for arr in [out_values, out_indices]:
+        ph.assert_shape(
+            "top_k",
+            out_shape=arr.shape,
+            expected=x.shape[:axes] + (k,) + x.shape[axes + 1:],
+            kw=kw
+        )
+
+    scalar_type = dh.get_scalar_type(x.dtype)
+
+    for indices in sh.axes_ndindex(x.shape, (axes,)):
+
+        # Test if the values indexed by out_indices corresponds to
+        # the correct top_k values.
+        elements = [scalar_type(x[idx]) for idx in indices]
+        size = len(elements)
+        correct_order = sorted(
+            range(size),
+            key=elements.__getitem__,
+            reverse=largest
+        )
+        correct_order = correct_order[:k]
+        test_order = [out_indices[idx] for idx in indices[:k]]
+        # Sort because top_k does not necessarily return the values in
+        # sorted order.
+        test_sorted_order = sorted(
+            test_order,
+            key=elements.__getitem__,
+            reverse=largest
+        )
+
+        for y_o, x_o in zip(correct_order, test_sorted_order):
+            y_idx = indices[y_o]
+            x_idx = indices[x_o]
+            ph.assert_0d_equals(
+                "top_k",
+                x_repr=f"x[{x_idx}]",
+                x_val=x[x_idx],
+                out_repr=f"x[{y_idx}]",
+                out_val=x[y_idx],
+                kw=kw,
+            )
+
+        # Test if the values indexed by out_indices corresponds to out_values.
+        for y_o, x_idx in zip(test_order, indices[:k]):
+            y_idx = indices[y_o]
+            ph.assert_0d_equals(
+                "top_k",
+                x_repr=f"out_values[{x_idx}]",
+                x_val=scalar_type(out_values[x_idx]),
+                out_repr=f"x[{y_idx}]",
+                out_val=x[y_idx],
+                kw=kw
+            )