1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package io.netty.channel.embedded;
17
18 import io.netty.channel.AbstractChannel;
19 import io.netty.channel.Channel;
20 import io.netty.channel.ChannelConfig;
21 import io.netty.channel.ChannelFuture;
22 import io.netty.channel.ChannelFutureListener;
23 import io.netty.channel.ChannelHandler;
24 import io.netty.channel.ChannelHandlerContext;
25 import io.netty.channel.ChannelInboundHandlerAdapter;
26 import io.netty.channel.ChannelInitializer;
27 import io.netty.channel.ChannelMetadata;
28 import io.netty.channel.ChannelOutboundBuffer;
29 import io.netty.channel.ChannelPipeline;
30 import io.netty.channel.ChannelPromise;
31 import io.netty.channel.DefaultChannelConfig;
32 import io.netty.channel.EventLoop;
33 import io.netty.util.ReferenceCountUtil;
34 import io.netty.util.internal.ObjectUtil;
35 import io.netty.util.internal.PlatformDependent;
36 import io.netty.util.internal.RecyclableArrayList;
37 import io.netty.util.internal.logging.InternalLogger;
38 import io.netty.util.internal.logging.InternalLoggerFactory;
39
40 import java.net.SocketAddress;
41 import java.nio.channels.ClosedChannelException;
42 import java.util.ArrayDeque;
43 import java.util.Queue;
44
45
46
47
48 public class EmbeddedChannel extends AbstractChannel {
49
50 private static final InternalLogger logger = InternalLoggerFactory.getInstance(EmbeddedChannel.class);
51
52 private static final ChannelMetadata METADATA_NO_DISCONNECT = new ChannelMetadata(false);
53 private static final ChannelMetadata METADATA_DISCONNECT = new ChannelMetadata(true);
54
55 private final EmbeddedEventLoop loop = new EmbeddedEventLoop();
56 private final ChannelFutureListener recordExceptionListener = new ChannelFutureListener() {
57 @Override
58 public void operationComplete(ChannelFuture future) throws Exception {
59 recordException(future);
60 }
61 };
62
63 private final ChannelMetadata metadata;
64 private final ChannelConfig config;
65 private final SocketAddress localAddress = new EmbeddedSocketAddress();
66 private final SocketAddress remoteAddress = new EmbeddedSocketAddress();
67
68 private Queue<Object> inboundMessages;
69 private Queue<Object> outboundMessages;
70 private Throwable lastException;
71 private int state;
72
73
74
75
76
77
78 public EmbeddedChannel(final ChannelHandler... handlers) {
79 this(false, handlers);
80 }
81
82
83
84
85
86
87
88
89
90 public EmbeddedChannel(boolean hasDisconnect, final ChannelHandler... handlers) {
91 super(null);
92 ObjectUtil.checkNotNull(handlers, "handlers");
93 metadata = hasDisconnect ? METADATA_DISCONNECT : METADATA_NO_DISCONNECT;
94 config = new DefaultChannelConfig(this);
95
96 ChannelPipeline p = pipeline();
97 p.addLast(new ChannelInitializer<Channel>() {
98 @Override
99 protected void initChannel(Channel ch) throws Exception {
100 ChannelPipeline pipeline = ch.pipeline();
101 for (ChannelHandler h: handlers) {
102 if (h == null) {
103 break;
104 }
105 pipeline.addLast(h);
106 }
107 }
108 });
109
110 ChannelFuture future = loop.register(this);
111 assert future.isDone();
112 p.addLast(new LastInboundHandler());
113 }
114
115 @Override
116 public ChannelMetadata metadata() {
117 return metadata;
118 }
119
120 @Override
121 public ChannelConfig config() {
122 return config;
123 }
124
125 @Override
126 public boolean isOpen() {
127 return state < 2;
128 }
129
130 @Override
131 public boolean isActive() {
132 return state == 1;
133 }
134
135
136
137
138 public Queue<Object> inboundMessages() {
139 if (inboundMessages == null) {
140 inboundMessages = new ArrayDeque<Object>();
141 }
142 return inboundMessages;
143 }
144
145
146
147
148 @Deprecated
149 public Queue<Object> lastInboundBuffer() {
150 return inboundMessages();
151 }
152
153
154
155
156 public Queue<Object> outboundMessages() {
157 if (outboundMessages == null) {
158 outboundMessages = new ArrayDeque<Object>();
159 }
160 return outboundMessages;
161 }
162
163
164
165
166 @Deprecated
167 public Queue<Object> lastOutboundBuffer() {
168 return outboundMessages();
169 }
170
171
172
173
174 public Object readInbound() {
175 return poll(inboundMessages);
176 }
177
178
179
180
181 public Object readOutbound() {
182 return poll(outboundMessages);
183 }
184
185
186
187
188
189
190
191
192 public boolean writeInbound(Object... msgs) {
193 ensureOpen();
194 if (msgs.length == 0) {
195 return isNotEmpty(inboundMessages);
196 }
197
198 ChannelPipeline p = pipeline();
199 for (Object m: msgs) {
200 p.fireChannelRead(m);
201 }
202 p.fireChannelReadComplete();
203 runPendingTasks();
204 checkException();
205 return isNotEmpty(inboundMessages);
206 }
207
208
209
210
211
212
213
214 public boolean writeOutbound(Object... msgs) {
215 ensureOpen();
216 if (msgs.length == 0) {
217 return isNotEmpty(outboundMessages);
218 }
219
220 RecyclableArrayList futures = RecyclableArrayList.newInstance(msgs.length);
221 try {
222 for (Object m: msgs) {
223 if (m == null) {
224 break;
225 }
226 futures.add(write(m));
227 }
228
229
230 runPendingTasks();
231 flush();
232
233 int size = futures.size();
234 for (int i = 0; i < size; i++) {
235 ChannelFuture future = (ChannelFuture) futures.get(i);
236 if (future.isDone()) {
237 recordException(future);
238 } else {
239
240 future.addListener(recordExceptionListener);
241 }
242 }
243
244 checkException();
245 return isNotEmpty(outboundMessages);
246 } finally {
247 futures.recycle();
248 }
249 }
250
251
252
253
254
255
256 public boolean finish() {
257 return finish(false);
258 }
259
260
261
262
263
264
265
266 public boolean finishAndReleaseAll() {
267 return finish(true);
268 }
269
270
271
272
273
274
275
276 private boolean finish(boolean releaseAll) {
277 close();
278 try {
279 checkException();
280 return isNotEmpty(inboundMessages) || isNotEmpty(outboundMessages);
281 } finally {
282 if (releaseAll) {
283 releaseAll(inboundMessages);
284 releaseAll(outboundMessages);
285 }
286 }
287 }
288
289
290
291
292
293 public boolean releaseInbound() {
294 return releaseAll(inboundMessages);
295 }
296
297
298
299
300
301 public boolean releaseOutbound() {
302 return releaseAll(outboundMessages);
303 }
304
305 private static boolean releaseAll(Queue<Object> queue) {
306 if (isNotEmpty(queue)) {
307 for (;;) {
308 Object msg = queue.poll();
309 if (msg == null) {
310 break;
311 }
312 ReferenceCountUtil.release(msg);
313 }
314 return true;
315 }
316 return false;
317 }
318
319 private void finishPendingTasks(boolean cancel) {
320 runPendingTasks();
321 if (cancel) {
322
323 loop.cancelScheduledTasks();
324 }
325 }
326
327 @Override
328 public final ChannelFuture close() {
329 return close(newPromise());
330 }
331
332 @Override
333 public final ChannelFuture disconnect() {
334 return disconnect(newPromise());
335 }
336
337 @Override
338 public final ChannelFuture close(ChannelPromise promise) {
339
340
341 runPendingTasks();
342 ChannelFuture future = super.close(promise);
343
344
345 finishPendingTasks(true);
346 return future;
347 }
348
349 @Override
350 public final ChannelFuture disconnect(ChannelPromise promise) {
351 ChannelFuture future = super.disconnect(promise);
352 finishPendingTasks(!metadata.hasDisconnect());
353 return future;
354 }
355
356 private static boolean isNotEmpty(Queue<Object> queue) {
357 return queue != null && !queue.isEmpty();
358 }
359
360 private static Object poll(Queue<Object> queue) {
361 return queue != null ? queue.poll() : null;
362 }
363
364
365
366
367
368 public void runPendingTasks() {
369 try {
370 loop.runTasks();
371 } catch (Exception e) {
372 recordException(e);
373 }
374
375 try {
376 loop.runScheduledTasks();
377 } catch (Exception e) {
378 recordException(e);
379 }
380 }
381
382
383
384
385
386
387 public long runScheduledPendingTasks() {
388 try {
389 return loop.runScheduledTasks();
390 } catch (Exception e) {
391 recordException(e);
392 return loop.nextScheduledTask();
393 }
394 }
395
396 private void recordException(ChannelFuture future) {
397 if (!future.isSuccess()) {
398 recordException(future.cause());
399 }
400 }
401
402 private void recordException(Throwable cause) {
403 if (lastException == null) {
404 lastException = cause;
405 } else {
406 logger.warn(
407 "More than one exception was raised. " +
408 "Will report only the first one and log others.", cause);
409 }
410 }
411
412
413
414
415 public void checkException() {
416 Throwable t = lastException;
417 if (t == null) {
418 return;
419 }
420
421 lastException = null;
422
423 PlatformDependent.throwException(t);
424 }
425
426
427
428
429 protected final void ensureOpen() {
430 if (!isOpen()) {
431 recordException(new ClosedChannelException());
432 checkException();
433 }
434 }
435
436 @Override
437 protected boolean isCompatible(EventLoop loop) {
438 return loop instanceof EmbeddedEventLoop;
439 }
440
441 @Override
442 protected SocketAddress localAddress0() {
443 return isActive()? localAddress : null;
444 }
445
446 @Override
447 protected SocketAddress remoteAddress0() {
448 return isActive()? remoteAddress : null;
449 }
450
451 @Override
452 protected void doRegister() throws Exception {
453 state = 1;
454 }
455
456 @Override
457 protected void doBind(SocketAddress localAddress) throws Exception {
458
459 }
460
461 @Override
462 protected void doDisconnect() throws Exception {
463 if (!metadata.hasDisconnect()) {
464 doClose();
465 }
466 }
467
468 @Override
469 protected void doClose() throws Exception {
470 state = 2;
471 }
472
473 @Override
474 protected void doBeginRead() throws Exception {
475
476 }
477
478 @Override
479 protected AbstractUnsafe newUnsafe() {
480 return new DefaultUnsafe();
481 }
482
483 @Override
484 protected void doWrite(ChannelOutboundBuffer in) throws Exception {
485 for (;;) {
486 Object msg = in.current();
487 if (msg == null) {
488 break;
489 }
490
491 ReferenceCountUtil.retain(msg);
492 outboundMessages().add(msg);
493 in.remove();
494 }
495 }
496
497 private class DefaultUnsafe extends AbstractUnsafe {
498 @Override
499 public void connect(SocketAddress remoteAddress, SocketAddress localAddress, ChannelPromise promise) {
500 safeSetSuccess(promise);
501 }
502 }
503
504 private final class LastInboundHandler extends ChannelInboundHandlerAdapter {
505 @Override
506 public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
507 inboundMessages().add(msg);
508 }
509
510 @Override
511 public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
512 recordException(cause);
513 }
514 }
515 }