Skip to content

Commit 07c97ea

Browse files
Kludexxrmx
andauthored
Add type hints to psycopg (#3067)
* Add type hints to Psycopg * fix tests * fix * Add psycopg.Connection to nitpick * Add py.typed * add psycopg to nitpick again * add psycopg to nitpick again * move py.typed to the right folder --------- Co-authored-by: Riccardo Magliocchetti <[email protected]>
1 parent 52871b8 commit 07c97ea

File tree

5 files changed

+45
-34
lines changed

5 files changed

+45
-34
lines changed

docs/nitpick-exceptions.ini

+2
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@ py-class=
4141
callable
4242
Consumer
4343
confluent_kafka.Message
44+
psycopg.Connection
45+
psycopg.AsyncConnection
4446
ObjectProxy
4547

4648
any=

instrumentation/opentelemetry-instrumentation-psycopg/src/opentelemetry/instrumentation/psycopg/__init__.py

+39-30
Original file line numberDiff line numberDiff line change
@@ -137,27 +137,28 @@
137137
---
138138
"""
139139

140+
from __future__ import annotations
141+
140142
import logging
141-
import typing
142-
from typing import Collection
143+
from typing import Any, Callable, Collection, TypeVar
143144

144145
import psycopg # pylint: disable=import-self
145-
from psycopg import (
146-
AsyncCursor as pg_async_cursor, # pylint: disable=import-self,no-name-in-module
147-
)
148-
from psycopg import (
149-
Cursor as pg_cursor, # pylint: disable=no-name-in-module,import-self
150-
)
151146
from psycopg.sql import Composed # pylint: disable=no-name-in-module
152147

153148
from opentelemetry.instrumentation import dbapi
154149
from opentelemetry.instrumentation.instrumentor import BaseInstrumentor
155150
from opentelemetry.instrumentation.psycopg.package import _instruments
156151
from opentelemetry.instrumentation.psycopg.version import __version__
152+
from opentelemetry.trace import TracerProvider
157153

158154
_logger = logging.getLogger(__name__)
159155
_OTEL_CURSOR_FACTORY_KEY = "_otel_orig_cursor_factory"
160156

157+
ConnectionT = TypeVar(
158+
"ConnectionT", psycopg.Connection, psycopg.AsyncConnection
159+
)
160+
CursorT = TypeVar("CursorT", psycopg.Cursor, psycopg.AsyncCursor)
161+
161162

162163
class PsycopgInstrumentor(BaseInstrumentor):
163164
_CONNECTION_ATTRIBUTES = {
@@ -172,7 +173,7 @@ class PsycopgInstrumentor(BaseInstrumentor):
172173
def instrumentation_dependencies(self) -> Collection[str]:
173174
return _instruments
174175

175-
def _instrument(self, **kwargs):
176+
def _instrument(self, **kwargs: Any):
176177
"""Integrate with PostgreSQL Psycopg library.
177178
Psycopg: http://initd.org/psycopg/
178179
"""
@@ -223,7 +224,7 @@ def _instrument(self, **kwargs):
223224
enable_attribute_commenter=enable_attribute_commenter,
224225
)
225226

226-
def _uninstrument(self, **kwargs):
227+
def _uninstrument(self, **kwargs: Any):
227228
""" "Disable Psycopg instrumentation"""
228229
dbapi.unwrap_connect(psycopg, "connect") # pylint: disable=no-member
229230
dbapi.unwrap_connect(
@@ -237,7 +238,9 @@ def _uninstrument(self, **kwargs):
237238

238239
# TODO(owais): check if core dbapi can do this for all dbapi implementations e.g, pymysql and mysql
239240
@staticmethod
240-
def instrument_connection(connection, tracer_provider=None):
241+
def instrument_connection(
242+
connection: ConnectionT, tracer_provider: TracerProvider | None = None
243+
) -> ConnectionT:
241244
"""Enable instrumentation in a psycopg connection.
242245
243246
Args:
@@ -269,7 +272,7 @@ def instrument_connection(connection, tracer_provider=None):
269272

270273
# TODO(owais): check if core dbapi can do this for all dbapi implementations e.g, pymysql and mysql
271274
@staticmethod
272-
def uninstrument_connection(connection):
275+
def uninstrument_connection(connection: ConnectionT) -> ConnectionT:
273276
connection.cursor_factory = getattr(
274277
connection, _OTEL_CURSOR_FACTORY_KEY, None
275278
)
@@ -281,9 +284,9 @@ def uninstrument_connection(connection):
281284
class DatabaseApiIntegration(dbapi.DatabaseApiIntegration):
282285
def wrapped_connection(
283286
self,
284-
connect_method: typing.Callable[..., typing.Any],
285-
args: typing.Tuple[typing.Any, typing.Any],
286-
kwargs: typing.Dict[typing.Any, typing.Any],
287+
connect_method: Callable[..., Any],
288+
args: tuple[Any, Any],
289+
kwargs: dict[Any, Any],
287290
):
288291
"""Add object proxy to connection object."""
289292
base_cursor_factory = kwargs.pop("cursor_factory", None)
@@ -299,9 +302,9 @@ def wrapped_connection(
299302
class DatabaseApiAsyncIntegration(dbapi.DatabaseApiIntegration):
300303
async def wrapped_connection(
301304
self,
302-
connect_method: typing.Callable[..., typing.Any],
303-
args: typing.Tuple[typing.Any, typing.Any],
304-
kwargs: typing.Dict[typing.Any, typing.Any],
305+
connect_method: Callable[..., Any],
306+
args: tuple[Any, Any],
307+
kwargs: dict[Any, Any],
305308
):
306309
"""Add object proxy to connection object."""
307310
base_cursor_factory = kwargs.pop("cursor_factory", None)
@@ -317,7 +320,7 @@ async def wrapped_connection(
317320

318321

319322
class CursorTracer(dbapi.CursorTracer):
320-
def get_operation_name(self, cursor, args):
323+
def get_operation_name(self, cursor: CursorT, args: list[Any]) -> str:
321324
if not args:
322325
return ""
323326

@@ -332,7 +335,7 @@ def get_operation_name(self, cursor, args):
332335

333336
return ""
334337

335-
def get_statement(self, cursor, args):
338+
def get_statement(self, cursor: CursorT, args: list[Any]) -> str:
336339
if not args:
337340
return ""
338341

@@ -342,7 +345,11 @@ def get_statement(self, cursor, args):
342345
return statement
343346

344347

345-
def _new_cursor_factory(db_api=None, base_factory=None, tracer_provider=None):
348+
def _new_cursor_factory(
349+
db_api: DatabaseApiIntegration | None = None,
350+
base_factory: type[psycopg.Cursor] | None = None,
351+
tracer_provider: TracerProvider | None = None,
352+
):
346353
if not db_api:
347354
db_api = DatabaseApiIntegration(
348355
__name__,
@@ -352,21 +359,21 @@ def _new_cursor_factory(db_api=None, base_factory=None, tracer_provider=None):
352359
tracer_provider=tracer_provider,
353360
)
354361

355-
base_factory = base_factory or pg_cursor
362+
base_factory = base_factory or psycopg.Cursor
356363
_cursor_tracer = CursorTracer(db_api)
357364

358365
class TracedCursorFactory(base_factory):
359-
def execute(self, *args, **kwargs):
366+
def execute(self, *args: Any, **kwargs: Any):
360367
return _cursor_tracer.traced_execution(
361368
self, super().execute, *args, **kwargs
362369
)
363370

364-
def executemany(self, *args, **kwargs):
371+
def executemany(self, *args: Any, **kwargs: Any):
365372
return _cursor_tracer.traced_execution(
366373
self, super().executemany, *args, **kwargs
367374
)
368375

369-
def callproc(self, *args, **kwargs):
376+
def callproc(self, *args: Any, **kwargs: Any):
370377
return _cursor_tracer.traced_execution(
371378
self, super().callproc, *args, **kwargs
372379
)
@@ -375,7 +382,9 @@ def callproc(self, *args, **kwargs):
375382

376383

377384
def _new_cursor_async_factory(
378-
db_api=None, base_factory=None, tracer_provider=None
385+
db_api: DatabaseApiAsyncIntegration | None = None,
386+
base_factory: type[psycopg.AsyncCursor] | None = None,
387+
tracer_provider: TracerProvider | None = None,
379388
):
380389
if not db_api:
381390
db_api = DatabaseApiAsyncIntegration(
@@ -385,21 +394,21 @@ def _new_cursor_async_factory(
385394
version=__version__,
386395
tracer_provider=tracer_provider,
387396
)
388-
base_factory = base_factory or pg_async_cursor
397+
base_factory = base_factory or psycopg.AsyncCursor
389398
_cursor_tracer = CursorTracer(db_api)
390399

391400
class TracedCursorAsyncFactory(base_factory):
392-
async def execute(self, *args, **kwargs):
401+
async def execute(self, *args: Any, **kwargs: Any):
393402
return await _cursor_tracer.traced_execution(
394403
self, super().execute, *args, **kwargs
395404
)
396405

397-
async def executemany(self, *args, **kwargs):
406+
async def executemany(self, *args: Any, **kwargs: Any):
398407
return await _cursor_tracer.traced_execution(
399408
self, super().executemany, *args, **kwargs
400409
)
401410

402-
async def callproc(self, *args, **kwargs):
411+
async def callproc(self, *args: Any, **kwargs: Any):
403412
return await _cursor_tracer.traced_execution(
404413
self, super().callproc, *args, **kwargs
405414
)

instrumentation/opentelemetry-instrumentation-psycopg/src/opentelemetry/instrumentation/psycopg/package.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,6 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
from __future__ import annotations
1415

15-
16-
_instruments = ("psycopg >= 3.1.0",)
16+
_instruments: tuple[str, ...] = ("psycopg >= 3.1.0",)

instrumentation/opentelemetry-instrumentation-psycopg/src/opentelemetry/instrumentation/psycopg/py.typed

Whitespace-only changes.

instrumentation/opentelemetry-instrumentation-psycopg/tests/test_psycopg_integration.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -132,10 +132,10 @@ class PostgresqlIntegrationTestMixin:
132132
def setUp(self):
133133
super().setUp()
134134
self.cursor_mock = mock.patch(
135-
"opentelemetry.instrumentation.psycopg.pg_cursor", MockCursor
135+
"opentelemetry.instrumentation.psycopg.psycopg.Cursor", MockCursor
136136
)
137137
self.cursor_async_mock = mock.patch(
138-
"opentelemetry.instrumentation.psycopg.pg_async_cursor",
138+
"opentelemetry.instrumentation.psycopg.psycopg.AsyncCursor",
139139
MockAsyncCursor,
140140
)
141141
self.connection_mock = mock.patch("psycopg.connect", MockConnection)

0 commit comments

Comments
 (0)