websocket 入门系列:三 netty实现简单聊天


一 序

接上一篇《websocket入门系列:二Tomcat实现》,这一篇改用netty来实现。

书上第11章写了个demo,我基于它做了一些修改来实现。netty的版本是4.1.7.Final。

二 server端

package com.netty.websocket;

import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import io.netty.bootstrap.ServerBootstrap;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelId;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioServerSocketChannel;
import io.netty.handler.codec.http.HttpObjectAggregator;
import io.netty.handler.codec.http.HttpServerCodec;
import io.netty.handler.stream.ChunkedWriteHandler;


/**
 * Minimal Netty-based WebSocket chat server.
 *
 * <p>Every accepted connection is recorded in a shared map keyed by channel id and
 * removed again when the channel closes. The same map is handed to each
 * {@link WebSocketServerHandler} so it can broadcast text frames to all peers.
 */
public class WebSocketServer {

    private static final Logger logger = LoggerFactory.getLogger(WebSocketServer.class);

    /** Maximum size, in bytes, of an aggregated HTTP message (i.e. the handshake request). */
    private static final int MAX_CONTENT_LENGTH = 65536;

    /** All currently open client channels; concurrent because several event loops touch it. */
    private final Map<ChannelId, Channel> channelMap = new ConcurrentHashMap<ChannelId, Channel>();

    /** Address to bind to; {@code null} or empty means "all local interfaces". */
    private final String host;

    /** TCP port to bind to. */
    private final int port;

    /**
     * @param host address to bind to, or {@code null}/empty to bind on all interfaces
     * @param port TCP port to listen on
     */
    public WebSocketServer(String host, int port) {
        this.host = host;
        this.port = port;
    }

    /**
     * Starts the server and blocks until the server channel is closed.
     *
     * @throws Exception if binding fails or the calling thread is interrupted while waiting
     */
    public void run() throws Exception {
        EventLoopGroup bossGroup = new NioEventLoopGroup();
        EventLoopGroup workerGroup = new NioEventLoopGroup();
        try {
            ServerBootstrap b = new ServerBootstrap();
            b.group(bossGroup, workerGroup)
                    .channel(NioServerSocketChannel.class)
                    .childHandler(new ChannelInitializer<SocketChannel>() {

                        @Override
                        protected void initChannel(SocketChannel ch) throws Exception {
                            // Track the channel for broadcasting and forget it once it closes.
                            channelMap.put(ch.id(), ch);
                            ch.closeFuture().addListener(new ChannelFutureListener() {
                                @Override
                                public void operationComplete(ChannelFuture future) throws Exception {
                                    logger.info("channel close {}", future.channel());
                                    channelMap.remove(future.channel().id());
                                }
                            });

                            ChannelPipeline pipeline = ch.pipeline();
                            pipeline.addLast("http-codec", new HttpServerCodec());
                            pipeline.addLast("aggregator", new HttpObjectAggregator(MAX_CONTENT_LENGTH));
                            pipeline.addLast("http-chunked", new ChunkedWriteHandler());
                            pipeline.addLast("handler", new WebSocketServerHandler(channelMap));
                        }
                    });

            // Honour the configured host; previously it was stored but never used,
            // so the server always listened on every interface.
            ChannelFuture bindFuture = (host == null || host.isEmpty())
                    ? b.bind(port)
                    : b.bind(host, port);
            Channel ch = bindFuture.sync().channel();

            System.out.println("Web socket server started at port " + port + '.');
            System.out.println("Open your browser and navigate to http://localhost:" + port + '/');

            // Block until the server socket is closed.
            ch.closeFuture().sync();
        } finally {
            bossGroup.shutdownGracefully();
            workerGroup.shutdownGracefully();
        }
    }

    /**
     * Entry point. An optional first argument overrides the default port 8080;
     * an unparsable value is logged and the default is kept.
     */
    public static void main(String[] args) throws Exception {
        int port = 8080;
        if (args.length > 0) {
            try {
                port = Integer.parseInt(args[0]);
            } catch (NumberFormatException e) {
                logger.warn("invalid port argument [{}], falling back to {}", args[0], port, e);
            }
        }
        new WebSocketServer("127.0.0.1", port).run();
    }
}

对应的处理器WebSocketServerHandler:

package com.netty.websocket;

import java.util.Map;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import io.netty.channel.Channel;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelId;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.HttpHeaderNames;
import io.netty.handler.codec.http.HttpHeaders;
import io.netty.handler.codec.http.HttpMethod;
import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame;
import io.netty.handler.codec.http.websocketx.PingWebSocketFrame;
import io.netty.handler.codec.http.websocketx.PongWebSocketFrame;
import io.netty.handler.codec.http.websocketx.TextWebSocketFrame;
import io.netty.handler.codec.http.websocketx.WebSocketFrame;
import io.netty.handler.codec.http.websocketx.WebSocketServerHandshaker;
import io.netty.handler.codec.http.websocketx.WebSocketServerHandshakerFactory;
import io.netty.util.AttributeKey;


public class WebSocketServerHandler extends SimpleChannelInboundHandler<Object> {
	private static final Logger logger = LoggerFactory.getLogger(WebSocketServerHandler.class);


    private static final String WEBSOCKET_UPGRADE = "websocket";  
    private static final String WEBSOCKET_CONNECTION = "Upgrade";  
    private final String WEBSOCKET_URI_ROOT ="ws://127.0.0.1:8080";  
    // handshaker attachment key  
    private static final AttributeKey<WebSocketServerHandshaker> ATTR_HANDSHAKER = AttributeKey.newInstance("ATTR_KEY_CHANNELID");
    /** 
     * 保存所有WebSocket连接 
     */  
    private Map<ChannelId, Channel> channelMap ;
    
    public WebSocketServerHandler(Map channelMap){
    	this.channelMap = channelMap;
    }
    
    @Override
    public void channelReadComplete(ChannelHandlerContext ctx) throws Exception {
	ctx.flush();
    }

    private void handleHttpRequest(ChannelHandlerContext ctx,
	    FullHttpRequest req) throws Exception {

	if (isWebSocketUpgrade(req)) { // 该请求是不是websocket upgrade请求   
        logger.info("upgrade to websocket protocol");  
          
        String subProtocols = req.headers().get(HttpHeaderNames.SEC_WEBSOCKET_PROTOCOL);  
          
        WebSocketServerHandshakerFactory factory = new WebSocketServerHandshakerFactory(WEBSOCKET_URI_ROOT, subProtocols, false);  
        WebSocketServerHandshaker handshaker = factory.newHandshaker(req);  
          
        if (handshaker == null) {// 请求头不合法, 导致handshaker没创建成功  
            WebSocketServerHandshakerFactory.sendUnsupportedVersionResponse(ctx.channel());  
        } else {  
            // 响应该请求  
            handshaker.handshake(ctx.channel(), req);  
            // 把handshaker 绑定给Channel, 以便后面关闭连接用  
            ctx.channel().attr(ATTR_HANDSHAKER).set(handshaker);// attach handshaker to this channel  
        }  
        return;  
    }
	// TODO 忽略普通http请求  
    logger.info("ignoring normal http request");  
    }
    
   

    private void handleWebSocketFrame(ChannelHandlerContext ctx,
	    WebSocketFrame frame) {

    	  // text frame  
        if (frame instanceof TextWebSocketFrame) {  
            String text = ((TextWebSocketFrame) frame).text();  
            for (Channel ch : channelMap.values()) {  
            TextWebSocketFrame rspFrame = new TextWebSocketFrame(text);  
            logger.info("recieve TextWebSocketFrame from channel {}", ctx.channel());  
            // 发给其他所有channel  
//   
//                if (ctx.channel().equals(ch)) {   
//                    continue;   
//                }  
                ch.writeAndFlush(rspFrame);  
                logger.info("write text[{}] to channel {}", text, ch);  
            }  
            return;  
        }  
          
        // ping frame, 回复pong frame即可  
        if (frame instanceof PingWebSocketFrame) {  
            logger.info("recieve PingWebSocketFrame from channel {}", ctx.channel());  
            ctx.channel().writeAndFlush(new PongWebSocketFrame(frame.content().retain()));  
            return;  
        }  
          
        if (frame instanceof PongWebSocketFrame) {  
            logger.info("recieve PongWebSocketFrame from channel {}", ctx.channel());  
            return;  
        }  
        // close frame,   
        if (frame instanceof CloseWebSocketFrame) {  
            logger.info("recieve CloseWebSocketFrame from channel {}", ctx.channel());  
            WebSocketServerHandshaker handshaker = ctx.channel().attr(ATTR_HANDSHAKER).get();  
            if (handshaker == null) {  
                logger.error("channel {} have no HandShaker", ctx.channel());  
                return;  
            }  
            handshaker.close(ctx.channel(), (CloseWebSocketFrame) frame.retain());  
            return;  
        }  
        // 剩下的是binary frame, 忽略  
        logger.warn("unhandle binary frame from channel {}", ctx.channel());  
    }

  
    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause)
	    throws Exception {
	cause.printStackTrace();
	ctx.close();
    }

	@Override
	protected void channelRead0(ChannelHandlerContext ctx, Object msg) throws Exception {
		// TODO Auto-generated method stub
		logger.info("receive msg:"+msg.toString());
		// 传统的HTTP接入
		if (msg instanceof FullHttpRequest) {
		    handleHttpRequest(ctx, (FullHttpRequest) msg);
		}
		// WebSocket接入
		else if (msg instanceof WebSocketFrame) {
		    handleWebSocketFrame(ctx, (WebSocketFrame) msg);
		}
	}
	
	//三者与:1.GET? 2.Upgrade头 包含websocket字符串?  3.Connection头 包含 Upgrade字符串?  
    private boolean isWebSocketUpgrade(FullHttpRequest req) {  
        HttpHeaders headers = req.headers();  
        return req.method().equals(HttpMethod.GET)   
                && headers.get(HttpHeaderNames.UPGRADE).contains(WEBSOCKET_UPGRADE)  
                && headers.get(HttpHeaderNames.CONNECTION).contains(WEBSOCKET_CONNECTION);  
    }  
    
}

这里面是主要业务逻辑的实现。对应websocket的协议,识别出握手过程,对于传输的数据帧类型进行不同的业务处理。

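里面有一些细节。注意一个异常:SimpleChannelInboundHandler会在channelRead0返回后自动释放(release)收到的消息,所以如果要把收到的帧或其内容再写回去(例如回复pong、转发close帧),必须先调用retain(),否则会抛出: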

io.netty.util.IllegalReferenceCountException: refCnt: 0

三 测试

这里的测试,修改下地址:

web页面部署在Tomcat上,端口是80;为避免冲突,websocket的端口改为8080。

web页面代码不再重复发了,参见上一篇,

分别打开火狐,Chrome浏览器。client模拟发一句:

package com.websocket.client;

import java.io.IOException;
import java.net.URI;
import java.util.concurrent.CountDownLatch;

import javax.websocket.ContainerProvider;
import javax.websocket.DeploymentException;
import javax.websocket.Session;
import javax.websocket.WebSocketContainer;

/**
 * Command-line WebSocket client used to exercise the chat server: it connects,
 * sends a single text message and then stays alive so that broadcast messages
 * from other clients keep arriving.
 */
public class Test {

    /** Address of the WebSocket server to connect to. */
    private static final String SERVER_URL = "ws://127.0.0.1:8080";

    /**
     * Connects to the server, sends one message and then blocks forever.
     *
     * @throws DeploymentException if the connection cannot be established
     * @throws IOException if sending the message fails
     * @throws InterruptedException if the thread is interrupted while sleeping or waiting
     */
    public static void main(String[] args) throws DeploymentException, IOException, InterruptedException {
        WebSocketContainer container = ContainerProvider.getWebSocketContainer();
        MyClient endpoint = new MyClient();
        URI serverUri = URI.create(SERVER_URL);
        Session session = container.connectToServer(endpoint, serverUri);

        int turn = 0;
        String greeting = "client send: " + turn;
        session.getBasicRemote().sendText(greeting);
        Thread.sleep(1000);

        // The latch is never counted down: this keeps the JVM alive to receive messages.
        CountDownLatch keepAlive = new CountDownLatch(1);
        keepAlive.await();
    }
}

Java的client输出:

I was accpeted by her!
客户端收到消息: client send: 0
客户端收到消息: chrome
客户端收到消息: hi,ff

页面的输出截屏:


*****************************************************************************

这里只是参照书上的例子修改,简单地实现了一个群发消息的demo。

还得深入学习netty,不然出了异常很难去修复。

参考:

http://lixiaohui.iteye.com/blog/2328183


智能推荐

注意!

本站转载的文章为个人学习借鉴使用,本站对版权不负任何法律责任。如果侵犯了您的隐私权益,请联系我们删除。



 
© 2014-2019 ITdaan.com 粤ICP备14056181号  

赞助商广告