13
13
# limitations under the License.
14
14
15
15
import asyncio
16
- from typing import Callable , Union , List , Dict , NamedTuple
17
- import queue
16
+ from typing import Callable , List , Dict , NamedTuple
18
17
19
- from google .api_core .exceptions import FailedPrecondition , GoogleAPICallError
18
+ from google .api_core .exceptions import GoogleAPICallError
20
19
from google .cloud .pubsub_v1 .subscriber .message import Message
21
20
from google .pubsub_v1 import PubsubMessage
22
21
23
- from google .cloud .pubsublite .internal .wait_ignore_cancelled import wait_ignore_cancelled
24
22
from google .cloud .pubsublite .internal .wire .permanent_failable import adapt_error
25
- from google .cloud .pubsublite .internal import fast_serialize
26
23
from google .cloud .pubsublite .types import FlowControlSettings
27
24
from google .cloud .pubsublite .cloudpubsub .internal .ack_set_tracker import AckSetTracker
25
+ from google .cloud .pubsublite .cloudpubsub .internal .wrapped_message import (
26
+ AckId ,
27
+ WrappedMessage ,
28
+ )
28
29
from google .cloud .pubsublite .cloudpubsub .message_transformer import MessageTransformer
29
30
from google .cloud .pubsublite .cloudpubsub .nack_handler import NackHandler
30
31
from google .cloud .pubsublite .cloudpubsub .internal .single_subscriber import (
36
37
SubscriberResetHandler ,
37
38
)
38
39
from google .cloud .pubsublite_v1 import FlowControlRequest , SequencedMessage
39
- from google .cloud .pubsub_v1 .subscriber ._protocol import requests
40
40
41
41
42
42
class _SizedMessage (NamedTuple ):
43
43
message : PubsubMessage
44
44
size_bytes : int
45
45
46
46
47
- class _AckId (NamedTuple ):
48
- generation : int
49
- offset : int
50
-
51
- def encode (self ) -> str :
52
- return fast_serialize .dump ([self .generation , self .offset ])
53
-
54
- @staticmethod
55
- def parse (payload : str ) -> "_AckId" : # pytype: disable=invalid-annotation
56
- loaded = fast_serialize .load (payload )
57
- return _AckId (generation = loaded [0 ], offset = loaded [1 ])
58
-
59
-
60
47
ResettableSubscriberFactory = Callable [[SubscriberResetHandler ], Subscriber ]
61
48
62
49
@@ -69,10 +56,10 @@ class SinglePartitionSingleSubscriber(
69
56
_nack_handler : NackHandler
70
57
_transformer : MessageTransformer
71
58
72
- _queue : queue .Queue
73
59
_ack_generation_id : int
74
- _messages_by_ack_id : Dict [str , _SizedMessage ]
75
- _looper_future : asyncio .Future
60
+ _messages_by_ack_id : Dict [AckId , _SizedMessage ]
61
+
62
+ _loop : asyncio .AbstractEventLoop
76
63
77
64
def __init__ (
78
65
self ,
@@ -89,7 +76,6 @@ def __init__(
89
76
self ._nack_handler = nack_handler
90
77
self ._transformer = transformer
91
78
92
- self ._queue = queue .Queue ()
93
79
self ._ack_generation_id = 0
94
80
self ._messages_by_ack_id = {}
95
81
@@ -104,19 +90,33 @@ def _wrap_message(self, message: SequencedMessage.meta.pb) -> Message:
104
90
rewrapped ._pb = message
105
91
cps_message = self ._transformer .transform (rewrapped )
106
92
offset = message .cursor .offset
107
- ack_id_str = _AckId (self ._ack_generation_id , offset ). encode ( )
93
+ ack_id = AckId (self ._ack_generation_id , offset )
108
94
self ._ack_set_tracker .track (offset )
109
- self ._messages_by_ack_id [ack_id_str ] = _SizedMessage (
95
+ self ._messages_by_ack_id [ack_id ] = _SizedMessage (
110
96
cps_message , message .size_bytes
111
97
)
112
- wrapped_message = Message (
113
- cps_message ._pb ,
114
- ack_id = ack_id_str ,
115
- delivery_attempt = 0 ,
116
- request_queue = self ._queue ,
98
+ wrapped_message = WrappedMessage (
99
+ pb = cps_message ._pb ,
100
+ ack_id = ack_id ,
101
+ ack_handler = lambda id , ack : self ._on_ack_threadsafe (id , ack ),
117
102
)
118
103
return wrapped_message
119
104
105
+ def _on_ack_threadsafe (self , ack_id : AckId , should_ack : bool ) -> None :
106
+ """A function called when a message is acked, may happen from any thread."""
107
+ if should_ack :
108
+ self ._loop .call_soon_threadsafe (lambda : self ._handle_ack (ack_id ))
109
+ return
110
+ try :
111
+ sized_message = self ._messages_by_ack_id [ack_id ]
112
+ # Call the threadsafe version on ack since the callback may be called from another thread.
113
+ self ._nack_handler .on_nack (
114
+ sized_message .message , lambda : self ._on_ack_threadsafe (ack_id , True )
115
+ )
116
+ except Exception as e :
117
+ e2 = adapt_error (e )
118
+ self ._loop .call_soon_threadsafe (lambda : self .fail (e2 ))
119
+
120
120
async def read (self ) -> List [Message ]:
121
121
try :
122
122
latest_batch = await self .await_unless_failed (self ._underlying .read ())
@@ -126,78 +126,23 @@ async def read(self) -> List[Message]:
126
126
self .fail (e )
127
127
raise e
128
128
129
- def _handle_ack (self , message : requests . AckRequest ):
129
+ def _handle_ack (self , ack_id : AckId ):
130
130
flow_control = FlowControlRequest ()
131
131
flow_control ._pb .allowed_messages = 1
132
- flow_control ._pb .allowed_bytes = self ._messages_by_ack_id [
133
- message .ack_id
134
- ].size_bytes
132
+ flow_control ._pb .allowed_bytes = self ._messages_by_ack_id [ack_id ].size_bytes
135
133
self ._underlying .allow_flow (flow_control )
136
- del self ._messages_by_ack_id [message . ack_id ]
134
+ del self ._messages_by_ack_id [ack_id ]
137
135
# Always refill flow control tokens, but do not commit offsets from outdated generations.
138
- ack_id = _AckId .parse (message .ack_id )
139
136
if ack_id .generation == self ._ack_generation_id :
140
137
try :
141
138
self ._ack_set_tracker .ack (ack_id .offset )
142
139
except GoogleAPICallError as e :
143
140
self .fail (e )
144
141
145
- def _handle_nack (self , message : requests .NackRequest ):
146
- sized_message = self ._messages_by_ack_id [message .ack_id ]
147
- try :
148
- # Put the ack request back into the queue since the callback may be called from another thread.
149
- self ._nack_handler .on_nack (
150
- sized_message .message ,
151
- lambda : self ._queue .put (
152
- requests .AckRequest (
153
- ack_id = message .ack_id ,
154
- byte_size = 0 , # Ignored
155
- time_to_ack = 0 , # Ignored
156
- ordering_key = "" , # Ignored
157
- )
158
- ),
159
- )
160
- except GoogleAPICallError as e :
161
- self .fail (e )
162
-
163
- async def _handle_queue_message (
164
- self ,
165
- message : Union [
166
- requests .AckRequest ,
167
- requests .DropRequest ,
168
- requests .ModAckRequest ,
169
- requests .NackRequest ,
170
- ],
171
- ):
172
- if isinstance (message , requests .DropRequest ) or isinstance (
173
- message , requests .ModAckRequest
174
- ):
175
- self .fail (
176
- FailedPrecondition (
177
- "Called internal method of google.cloud.pubsub_v1.subscriber.message.Message "
178
- f"Pub/Sub Lite does not support: { message } "
179
- )
180
- )
181
- elif isinstance (message , requests .AckRequest ):
182
- self ._handle_ack (message )
183
- else :
184
- self ._handle_nack (message )
185
-
186
- async def _looper (self ):
187
- while True :
188
- try :
189
- # This is not an asyncio.Queue, and therefore we cannot do `await self._queue.get()`.
190
- # A blocking wait would block the event loop, this needs to be a queue.Queue for
191
- # compatibility with the Cloud Pub/Sub Message's requirements.
192
- queue_message = self ._queue .get_nowait ()
193
- await self ._handle_queue_message (queue_message )
194
- except queue .Empty :
195
- await asyncio .sleep (0.1 )
196
-
197
142
async def __aenter__ (self ):
143
+ self ._loop = asyncio .get_event_loop ()
198
144
await self ._ack_set_tracker .__aenter__ ()
199
145
await self ._underlying .__aenter__ ()
200
- self ._looper_future = asyncio .ensure_future (self ._looper ())
201
146
self ._underlying .allow_flow (
202
147
FlowControlRequest (
203
148
allowed_messages = self ._flow_control_settings .messages_outstanding ,
@@ -207,7 +152,5 @@ async def __aenter__(self):
207
152
return self
208
153
209
154
async def __aexit__ (self , exc_type , exc_value , traceback ):
210
- self ._looper_future .cancel ()
211
- await wait_ignore_cancelled (self ._looper_future )
212
155
await self ._underlying .__aexit__ (exc_type , exc_value , traceback )
213
156
await self ._ack_set_tracker .__aexit__ (exc_type , exc_value , traceback )
0 commit comments