@@ -42,7 +42,7 @@ class Connection(metaclass=ConnectionMeta):
|
42 | 42 | """
|
43 | 43 |
|
44 | 44 | __slots__ = ('_protocol', '_transport', '_loop',
|
45 |
| -'_top_xact', '_aborted', |
| 45 | +'_top_xact', '_aborted', '_middlewares' |
46 | 46 | '_pool_release_ctr', '_stmt_cache', '_stmts_to_close',
|
47 | 47 | '_listeners', '_server_version', '_server_caps',
|
48 | 48 | '_intro_query', '_reset_query', '_proxy',
|
@@ -53,7 +53,8 @@ class Connection(metaclass=ConnectionMeta):
|
53 | 53 | def __init__(self, protocol, transport, loop,
|
54 | 54 | addr: (str, int) or str,
|
55 | 55 | config: connect_utils._ClientConfiguration,
|
56 |
| -params: connect_utils._ConnectionParameters): |
| 56 | +params: connect_utils._ConnectionParameters, |
| 57 | +middlewares=None): |
57 | 58 | self._protocol = protocol
|
58 | 59 | self._transport = transport
|
59 | 60 | self._loop = loop
|
@@ -92,7 +93,7 @@ def __init__(self, protocol, transport, loop,
|
92 | 93 |
|
93 | 94 | self._reset_query = None
|
94 | 95 | self._proxy = None
|
95 |
| - |
| 96 | +self._middlewares = _middlewares |
96 | 97 | # Used to serialize operations that might involve anonymous
|
97 | 98 | # statements. Specifically, we want to make the following
|
98 | 99 | # operation atomic:
|
@@ -1410,8 +1411,12 @@ async def reload_schema_state(self):
|
1410 | 1411 |
|
1411 | 1412 | async def _execute(self, query, args, limit, timeout, return_status=False):
|
1412 | 1413 | with self._stmt_exclusive_section:
|
1413 |
| -result, _ = await self.__execute( |
1414 |
| -query, args, limit, timeout, return_status=return_status) |
| 1414 | +wrapped = self.__execute |
| 1415 | +if self._middlewares: |
| 1416 | +for m in reversed(self._middlewares): |
| 1417 | +wrapped = await m(self, wrapped) |
| 1418 | + |
| 1419 | +result, _ = await wrapped(query, args, limit, timeout, return_status=return_status) |
1415 | 1420 | return result
|
1416 | 1421 |
|
1417 | 1422 | async def __execute(self, query, args, limit, timeout,
|
@@ -1502,6 +1507,7 @@ async def connect(dsn=None, *,
|
1502 | 1507 | max_cacheable_statement_size=1024 * 15,
|
1503 | 1508 | command_timeout=None,
|
1504 | 1509 | ssl=None,
|
| 1510 | +middlewares=None, |
1505 | 1511 | connection_class=Connection,
|
1506 | 1512 | server_settings=None):
|
1507 | 1513 | r"""A coroutine to establish a connection to a PostgreSQL server.
|
@@ -1618,6 +1624,10 @@ async def connect(dsn=None, *,
|
1618 | 1624 | PostgreSQL documentation for
|
1619 | 1625 | a `list of supported options <server settings>`_.
|
1620 | 1626 |
|
| 1627 | +:param middlewares: |
| 1628 | +An optional list of middleware functions. Refer to documentation |
| 1629 | +on create_pool. |
| 1630 | +
|
1621 | 1631 | :param Connection connection_class:
|
1622 | 1632 | Class of the returned connection object. Must be a subclass of
|
1623 | 1633 | :class:`~asyncpg.connection.Connection`.
|
@@ -1683,6 +1693,7 @@ async def connect(dsn=None, *,
|
1683 | 1693 | ssl=ssl, database=database,
|
1684 | 1694 | server_settings=server_settings,
|
1685 | 1695 | command_timeout=command_timeout,
|
| 1696 | +middlewares=middlewares, |
1686 | 1697 | statement_cache_size=statement_cache_size,
|
1687 | 1698 | max_cached_statement_lifetime=max_cached_statement_lifetime,
|
1688 | 1699 | max_cacheable_statement_size=max_cacheable_statement_size)
|
|
0 commit comments