Repository: hubdotcom/marlon-tools Branch: master Commit: 9f18e0fd37bd Files: 7 Total size: 36.7 KB Directory structure: gitextract_f07xqyey/ ├── README.md └── tools/ ├── adjustsrt/ │ └── adjustsrt.py ├── dnsproxy/ │ ├── dnsparser.py │ └── dnsproxy.py ├── pingonline/ │ └── pingonline.py └── tcpmon/ ├── TCPMonitor.java └── TCPMonitorSelect.java ================================================ FILE CONTENTS ================================================ ================================================ FILE: README.md ================================================ # marlon-tools Automatically exported from code.google.com/p/marlon-tools **dnsproxy** A simple DNS proxy server supporting wildcard hosts, IPv6 and caching. Instructions: ``` 1. Edit /etc/hosts and add: 127.0.0.1 *.local 2404:6800:8005::62 *.blogspot.com 2. Start dnsproxy (here using the Google DNS server as the delegating server): $ sudo python dnsproxy.py -s 8.8.8.8 3. Then set the system DNS server to 127.0.0.1; you can verify it with dig: $ dig test.local The result should contain 127.0.0.1.
``` Usage: ``` dnsproxy.py [options] Options: -h, --help show this help message and exit -f , --hosts-file= specify hosts file, default /etc/hosts -H HOST, --host=HOST specify the address to listen on -p PORT, --port=PORT specify the port to listen on -s SERVER, --server=SERVER specify the delegating dns server -C, --no-cache disable dns cache ``` ================================================ FILE: tools/adjustsrt/adjustsrt.py ================================================ #!/usr/bin/env python # coding: utf-8 from optparse import OptionParser import re import sys import time def main(): parser = OptionParser(usage='%prog -d delay [src_srt] [dest_srt]', description=u'调整srt字幕文件的延时') parser.add_option('-d', '--delay', dest='delay', metavar='DELAY', help=u'延时时间(秒),负数表示提前') opts, args = parser.parse_args(sys.argv[1:]) if not opts.delay or len(args) > 2: parser.print_help() sys.exit(1) fin, fout = sys.stdin, sys.stdout try: if len(args) >= 1: fin = open(args[0], 'r') if len(args) == 2: fout = open(args[1], 'w') delay(fin, fout, float(opts.delay)) except IOError as e: print >> sys.stderr, u'找不到文件: %s' % e.filename sys.exit(2) finally: fin.close() fout.close() class Time(object): M_SEC = 1000 M_MIN = 60 * 1000 M_HOUR = 3600 * 1000 def __init__(self, hours=0, mins=0, secs=0, millis=0): self.millis = hours*self.M_HOUR + mins*self.M_MIN + secs*self.M_SEC + millis def __add__(self, secs): millis = int(secs * self.M_SEC) return Time(millis=self.millis + millis) def __sub__(self, secs): return self.__add__(-secs) def __str__(self): secs, millis = self.millis/1000, self.millis%1000 mins, secs = secs/60, secs%60 hours, mins = mins/60, mins%60 return '%02d:%02d:%02d,%03d' % (hours, mins, secs, millis) RE_TIME = re.compile(r'^(\d{1,2}):(\d{1,2}):(\d{1,2}),(\d{1,3})') def parse_time(time_str): m = RE_TIME.match(time_str) if m: return Time(int(m.group(1)), int(m.group(2)), int(m.group(3)), int(m.group(4))) return None RE_TIME_LINE = 
re.compile(r'^(\d{1,2}:\d{1,2}:\d{1,2},\d{1,3}) --> (\d{1,2}:\d{1,2}:\d{1,2},\d{1,3})(\s*)$') def delay(fin, fout, delay): for line in fin.readlines(): m = RE_TIME_LINE.match(line) if m: s_time = parse_time(m.group(1)) + delay e_time = parse_time(m.group(2)) + delay fout.write('%s --> %s%s' % (s_time, e_time, m.group(3))) else: fout.write(line) if __name__ == '__main__': main() ================================================ FILE: tools/dnsproxy/dnsparser.py ================================================ # coding: utf-8 from cStringIO import StringIO import struct DNS_FLAG_QR = 0x8000 DNS_FLAG_RD = 0x0100 DNS_FLAG_RA = 0x0080 DNS_TYPE_A = 1 DNS_TYPE_AAAA = 28 DNS_CLASS_IN = 1 ''' 一个简单的DNS解析器。 解析: msg = DNSMessage.parse(message data) 序列化: data = msg.serialize() author: marlonyao ''' class DNSMessageHeader(object): def __init__(self, id, flag, qd_count, an_count, ns_count, ar_count): self.id = id self.flag = flag self.qd_count = qd_count self.an_count = an_count self.ns_count = ns_count self.ar_count = ar_count @staticmethod def parse(message): id, flag = struct.unpack('!HH', message.read(4)) qd_count, an_count, ns_count, ar_count = struct.unpack('!HHHH', message.read(8)) return DNSMessageHeader(id, flag, qd_count, an_count, ns_count, ar_count) def serialize(self, message, memoize=None): message.write(struct.pack('!HHHHHH', self.id, self.flag, self.qd_count, self.an_count, self.ns_count, self.ar_count)) def __str__(self): return 'id: %s, flag: %s, qd_count: %s, an_count: %s, ns_count: %s, ar_count: %s' % ( self.id, self.flag, self.qd_count, self.an_count, self.ns_count, self.ar_count ) class DNSMessageQuestion(object): def __init__(self, qname, qtype, qclass): self.qname = qname self.qtype = qtype self.qclass = qclass @staticmethod def parse(message): qname = parse_domain_name(message) qtype, qclass = struct.unpack('!HH', message.read(4)) return DNSMessageQuestion(qname, qtype, qclass) def serialize(self, message, memoize): unparse_domain_name(self.qname, message, 
memoize) message.write(struct.pack('!HH', self.qtype, self.qclass)) def __str__(self): return 'qname: %s, qtype: %s, qclass: %s' % (self.qname, self.qtype, self.qclass) def __repr__(self): return str(self) class DNSMessageRecord(object): def __init__(self, name, type_, class_, ttl, rdata): self.name = name self.type_ = type_ self.class_ = class_ self.ttl = ttl self.rdata = rdata @staticmethod def parse(message): name = parse_domain_name(message) type_, class_, ttl, rd_len = struct.unpack('!HHIH', message.read(10)) return DNSMessageRecord(name, type_, class_, ttl, message.read(rd_len)) def serialize(self, message, memoize): unparse_domain_name(self.name, message, memoize) message.write(struct.pack('!HHIH%ss'%len(self.rdata), self.type_, self.class_, self.ttl, len(self.rdata), self.rdata, )) def __str__(self): return "name: %s, type: %s, class: %s, ttl: %s, rdata: %s" % ( self.name, self.type_, self.class_, self.ttl, self.rdata ) def __repr__(self): return str(self) class DNSMessage(object): def __init__(self, header, questions=None, answers=None, authorities=None, additionals=None): self.header = header self.questions = questions or [] self.answers = answers or [] self.authorities = authorities or [] self.additionals = additionals or [] @staticmethod def parse(data): message = StringIO(data) header = DNSMessageHeader.parse(message) questions = [] for i in range(0, header.qd_count): quest = DNSMessageQuestion.parse(message) questions.append(quest) answers = [] for i in range(0, header.an_count): answer = DNSMessageRecord.parse(message) answers.append(answer) authorities = [] for i in range(0, header.ns_count): authority = DNSMessageRecord.parse(message) authorities.append(authority) additionals = [] for i in range(0, header.ar_count): additional = DNSMessageRecord.parse(message) additionals.append(additional) return DNSMessage(header, questions, answers, authorities, additionals) def serialize(self): 'serialize to network bytes, not considering name compression' 
message = StringIO() memoize = {} self.header.serialize(message, memoize) for s in self.questions: s.serialize(message, memoize) for s in self.answers: s.serialize(message, memoize) for s in self.authorities: s.serialize(message, memoize) for s in self.additionals: s.serialize(message, memoize) return message.getvalue() def __str__(self): return 'header: %s\n' % self.header +\ 'questions: %s\n' % self.questions +\ 'answers: %s\n' % self.answers +\ 'authorities: %s\n' % self.authorities +\ 'additionals: %s\n' % self.additionals def __repr__(self): return str(self) def _parse_domain_labels(message): labels = [] len = ord(message.read(1)) while len > 0: if len >= 64: # domain name compression len = len & 0x3f offset = (len << 8) + ord(message.read(1)) mesg = StringIO(message.getvalue()) mesg.seek(offset) labels.extend(_parse_domain_labels(mesg)) return labels else: labels.append(message.read(len)) len = ord(message.read(1)) return labels def parse_domain_name(message): return '.'.join(_parse_domain_labels(message)) def unparse_domain_name(name, message, memoize): labels = name.split('.') for i, label in enumerate(labels): qname = '.'.join(labels[i:]) if qname in memoize: offset = (memoize[qname] & 0x3fff) + 0xc000 message.write(struct.pack('!H', offset)) break else: memoize[qname] = message.tell() #print 'add to memoize: %s, %s' % (qname, message.tell()) message.write(struct.pack('!B%ss' % len(label), len(label), label)) else: # write last ending zero message.write('\x00') ================================================ FILE: tools/dnsproxy/dnsproxy.py ================================================ #!/usr/bin/env python # coding: utf-8 from SocketServer import BaseRequestHandler, ThreadingUDPServer from cStringIO import StringIO import os import socket import struct import time ''' A simple DNS proxy server, support wilcard hosts, IPv6, cache. 
Usage: Edit /etc/hosts, add: 127.0.0.1 *.local 2404:6800:8005::62 *.blogspot.com startup dnsproxy(here use Google DNS server as delegating server): $ sudo python dnsproxy.py -s 8.8.8.8 Then set system dns server as 127.0.0.1, you can verify it by dig: $ dig test.local The result should contains 127.0.0.1. author: marlonyao ''' def main(): import optparse, sys parser = optparse.OptionParser() parser.add_option('-f', '--hosts-file', dest='hosts_file', metavar='', default='/etc/hosts', help='specify hosts file, default /etc/hosts') parser.add_option('-H', '--host', dest='host', default='127.0.0.1', help='specify the address to listen on') parser.add_option('-p', '--port', dest='port', default=53, type='int', help='specify the port to listen on') parser.add_option('-s', '--server', dest='dns_server', metavar='SERVER', help='specify the delegating dns server') parser.add_option('-C', '--no-cache', dest='disable_cache', default=False, action='store_true', help='disable dns cache') opts, args = parser.parse_args() if not opts.dns_server: parser.print_help() sys.exit(1) dnsserver = DNSProxyServer(opts.dns_server, disable_cache=opts.disable_cache, host=opts.host, port=opts.port, hosts_file=opts.hosts_file) dnsserver.serve_forever() class Struct(object): def __init__(self, **kwargs): for name, value in kwargs.items(): setattr(self, name, value) def parse_dns_message(data): message = StringIO(data) message.seek(4) # skip id, flag c_qd, c_an, c_ns, c_ar = struct.unpack('!4H', message.read(8)) # parse question question = parse_dns_question(message) for i in range(1, c_qd): # skip other question parse_dns_question(message) records = [] for i in range(c_an+c_ns+c_ar): records.append(parse_dns_record(message)) return Struct(question=question, records=records) def parse_dns_question(message): qname = parse_domain_name(message) qtype, qclass = struct.unpack('!HH', message.read(4)) end_offset = message.tell() return Struct(name=qname, type_=qtype, class_=qclass, 
end_offset=end_offset) def parse_dns_record(message): parse_domain_name(message) # skip name message.seek(4, os.SEEK_CUR) # skip type, class ttl_offset = message.tell() ttl = struct.unpack('!I', message.read(4))[0] rd_len = struct.unpack('!H', message.read(2))[0] message.seek(rd_len, os.SEEK_CUR) # skip rd_content return Struct(ttl_offset=ttl_offset, ttl=ttl) def _parse_domain_labels(message): labels = [] len = ord(message.read(1)) while len > 0: if len >= 64: # domain name compression len = len & 0x3f offset = (len << 8) + ord(message.read(1)) mesg = StringIO(message.getvalue()) mesg.seek(offset) labels.extend(_parse_domain_labels(mesg)) return labels else: labels.append(message.read(len)) len = ord(message.read(1)) return labels def parse_domain_name(message): return '.'.join(_parse_domain_labels(message)) def addr_p2n(addr): try: return socket.inet_pton(socket.AF_INET, addr) except: return socket.inet_pton(socket.AF_INET6, addr) DNS_TYPE_A = 1 DNS_TYPE_AAAA = 28 DNS_CLASS_IN = 1 class DNSProxyHandler(BaseRequestHandler): def handle(self): reqdata, sock = self.request req = parse_dns_message(reqdata) q = req.question if q.type_ in (DNS_TYPE_A, DNS_TYPE_AAAA) and (q.class_ == DNS_CLASS_IN): for packed_ip, host in self.server.host_lines: if q.name.endswith(host): # header, qd=1, an=1, ns=0, ar=0 rspdata = reqdata[:2] + '\x81\x80\x00\x01\x00\x01\x00\x00\x00\x00' rspdata += reqdata[12:q.end_offset] # answer rspdata += '\xc0\x0c' # pointer to domain name # type, 1 for ip4, 28 for ip6 if len(packed_ip) == 4: rspdata += '\x00\x01' # 1 for ip4 else: rspdata += '\x00\x1c' # 28 for ip6 # class: 1, ttl: 2000(0x000007d0) rspdata += '\x00\x01\x00\x00\x07\xd0' rspdata += '\x00' + chr(len(packed_ip)) # rd_len rspdata += packed_ip sock.sendto(rspdata, self.client_address) return # lookup cache if not self.server.disable_cache: cache = self.server.cache cache_key = (q.name, q.type_, q.class_) cache_entry = cache.get(cache_key) if cache_entry: rspdata = update_ttl(reqdata, 
cache_entry) if rspdata: sock.sendto(rspdata, self.client_address) return rspdata = self._get_response(reqdata) if not self.server.disable_cache: cache[cache_key] = Struct(rspdata=rspdata, cache_time=int(time.time())) sock.sendto(rspdata, self.client_address) def _get_response(self, data): sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) # socket for the remote DNS server sock.connect((self.server.dns_server, 53)) sock.sendall(data) sock.settimeout(60) rspdata = sock.recv(65535) sock.close() return rspdata def update_ttl(reqdata, cache_entry): rspdata, cache_time = cache_entry.rspdata, cache_entry.cache_time rspbytes = bytearray(rspdata) rspbytes[:2] = reqdata[:2] # update id current_time = int(time.time()) time_interval = current_time - cache_time rsp = parse_dns_message(rspdata) for record in rsp.records: if record.ttl <= time_interval: return None rspbytes[record.ttl_offset:record.ttl_offset+4] = struct.pack('!I', record.ttl-time_interval) return str(rspbytes) def load_hosts(hosts_file): 'load hosts config, only extract config line contains wildcard domain name' def wildcard_line(line): parts = line.strip().split()[:2] if len(parts) < 2: return False if not parts[1].startswith('*'): return False try: packed_ip = addr_p2n(parts[0]) return packed_ip, parts[1][1:] except: return None with open(hosts_file) as hosts_in: hostlines = [] for line in hosts_in: hostline = wildcard_line(line) if hostline: hostlines.append(hostline) return hostlines class DNSProxyServer(ThreadingUDPServer): def __init__(self, dns_server, disable_cache=False, host='127.0.0.1', port=53, hosts_file='/etc/hosts'): self.dns_server = dns_server self.hosts_file = hosts_file self.host_lines = load_hosts(hosts_file) self.disable_cache = disable_cache self.cache = {} ThreadingUDPServer.__init__(self, (host, port), DNSProxyHandler) if __name__ == '__main__': main() ================================================ FILE: tools/pingonline/pingonline.py 
================================================ #!/usr/bin/env python # coding: utf-8 import cookielib import optparse import re import sys import urllib import urllib2 RE_TOKEN = re.compile(r'(.*?)
', re.DOTALL) RE_IPADDR=re.compile(r'\(([\w.:]*)\)') PING_URL = 'http://www.subnetonline.com/pages/network-tools/online-ping-ipv4.php' PING6_URL = 'http://www.subnetonline.com/pages/ipv6-network-tools/online-ipv6-ping.php' def ping(host, v6=False, count=4, ttl=255, size=32, only_ip=False): url = PING_URL if v6: url = PING6_URL cj = cookielib.LWPCookieJar() opener = urllib2.build_opener(urllib2.HTTPCookieProcessor(cj)) urllib2.install_opener(opener) resp = urllib2.urlopen(url) content = resp.read() #print content # get token m = RE_TOKEN.search(content) if m: token = m.group(1) else: print >> sys.stderr, 'error: cannot find token' return None # post data data = { 'host': host, 'token': token, 'count': str(count), 'ttl': str(ttl), 'size': str(size), } resp = urllib2.urlopen(url, data=urllib.urlencode(data)) content = resp.read() #print content m = RE_IPOUT.search(content) if m: content = m.group(1)[:-1] # remove the last '\n' else: print >> sys.stderr, 'error: cannot find output' return None if not only_ip: return content else: m = RE_IPADDR.search(content) if m: ip = m.group(1) return ip else: print >> sys.stderr, 'error: cannot find ip address' return None def main(): parser = optparse.OptionParser(usage=u'%prog [-hp6] [-c count] [-t ttl] [-s packetsize] destination') parser.add_option('-c', dest='count', type='int', help='Stop after sending count ECHO_REQUEST packets.', default=4) parser.add_option('-t', dest='ttl', type='int', help='Set the IP Time to Live.', default=255) parser.add_option('-s', dest='packetsize', type='int', help='Specifies the number of data bytes to be sent.', default=32) parser.add_option('-6', dest='ipv6', action='store_true', help='Specifies whether use IPv6', default=False) parser.add_option('-p', dest='only_ip', action='store_true', default=False, help='only output ipv6 address, can be used lookup ipv6 address according by host.') opts, args = parser.parse_args() if len(args) != 1: parser.print_help() sys.exit(1) res = ping(args[0], 
count=opts.count, ttl=opts.ttl, size=opts.packetsize, only_ip=opts.only_ip, v6=opts.ipv6) if res: print res elif res is None: sys.exit(2) else: print >> sys.stderr, '%s not reachable' % args[0] sys.exit(3) if __name__ == '__main__': main() ================================================ FILE: tools/tcpmon/TCPMonitor.java ================================================ import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; import java.net.InetSocketAddress; import java.net.ServerSocket; import java.net.Socket; import java.net.SocketException; import java.net.UnknownHostException; import java.util.concurrent.ArrayBlockingQueue; import java.util.concurrent.BlockingQueue; import java.util.concurrent.Executor; import java.util.concurrent.Executors; /** *

TCP监视器,它在本地监听端口,将请求转发到远端服务器,并将响应返回给客户端。

* *

java TCPMonitor -a :8080 localhost:8000
监听本地的8080端口,将请求转发给本机的8000端口,并将所有请求响应输出到标准错误输出。

* *

java TCPMonitor :8080 work:22
监听本地的8080端口,并将请求转发给work主机的22端口。

* * @author marlonyao */ class TCPMonitor { public static void main (String[] args) throws IOException { Options options = null; try { options = parseArgs(args); if (options.help) { usage(); System.exit(0); } } catch (Exception e) { usage(); System.exit(-1); } ServerSocket serverSocket = new ServerSocket(); serverSocket.setReuseAddress(true); serverSocket.bind(new InetSocketAddress(options.host, options.port)); System.out.println("server started on " + serverSocket.getLocalSocketAddress()); Executor executor = Executors.newFixedThreadPool(options.threadCount); while(true) { Socket sock = serverSocket.accept(); System.out.println("accept connection: " + sock.getRemoteSocketAddress()); executor.execute(new MonitorThread(sock, options.remoteHost, options.remotePort, options.dumpRequest, options.dumpResponse)); } } static class Options { boolean help; boolean dumpRequest; boolean dumpResponse; int threadCount = 10; String host = "localhost"; int port; String remoteHost; int remotePort; } private static void usage() { System.err.print( "java TCPMonitor [options] [host]:port remote_host:remote_port\n" + " -h, --help print this help\n" + " -r, --dump-request dump request to stderr\n" + " -s, --dump-response dump response to stderr\n" + " -a, --dump-all dump request and response to stderr\n" + " -n, --threads=N thread count, default is 10\n"); } private static Options parseArgs(String[] args) { Options options = new Options(); int i; // parse options for (i = 0; i < args.length; i++) { if (!args[i].startsWith("-")) { break; } if (args[i].equals("-h") || args[i].equals("--help")) { options.help = true; } else if (args[i].equals("-r") || args[i].equals("--dump-request")) { options.dumpRequest = true; } else if (args[i].equals("-s") || args[i].equals("--dump-response")) { options.dumpResponse = true; } else if (args[i].equals("-a") || args[i].equals("--dump-all")) { options.dumpRequest = true; options.dumpResponse = true; } else if (args[i].equals("-n") || 
args[i].startsWith("--threads=")) { if (args[i].equals("-n")) { options.threadCount = Integer.parseInt(args[++i]); } else { options.threadCount = Integer.parseInt(args[i].substring("--threads=".length())); } } else { throw new RuntimeException("unknown option '" + args[i] + "'"); } } // parse remainder String localPart = args[i++]; String[] bits = localPart.split(":"); if (bits[0].length() > 0) { options.host = bits[0]; } options.port = Integer.parseInt(bits[1]); String remotePart = args[i++]; bits = remotePart.split(":"); options.remoteHost = bits[0]; options.remotePort = Integer.parseInt(bits[1]); return options; } } class MonitorThread implements Runnable { static final int BUFLEN = 1024; static final byte[] EOF = new byte[0]; // flag the end of channel Socket lsock; String rhost; int rport; boolean dumpRequest; boolean dumpResponse; Socket rsock; BlockingQueue lrchannel; // channel between read lsock and write rsock BlockingQueue rlchannel; // channel between read rsock and write lsock volatile boolean shutdownRequested; Thread readLSockThread; Thread writeRSockThread; Thread readRSockThread; Thread writeLSockThread; public MonitorThread(Socket sock, String remoteHost, int remotePort, boolean dumpRequest, boolean dumpResponse) { this.lsock = sock; this.rhost = remoteHost; this.rport = remotePort; this.dumpRequest = dumpRequest; this.dumpResponse = dumpResponse; } public void run() { try { rsock = new Socket(rhost, rport); } catch (UnknownHostException e) { System.out.println("unknown host: " + rhost); return; } catch (IOException e) { System.out.println(e); return; } try { lrchannel = new ArrayBlockingQueue(10); rlchannel = new ArrayBlockingQueue(10); readLSockThread = new Thread(new ReadLSockThread()); readLSockThread.start(); writeRSockThread = new Thread(new WriteRSockThread()); writeRSockThread.start(); readRSockThread = new Thread(new ReadRSockThread()); readRSockThread.start(); writeLSockThread = new Thread(new WriteLSockThread()); 
writeLSockThread.start(); readLSockThread.join(); writeRSockThread.join(); readRSockThread.join(); writeLSockThread.join(); System.out.println("connection closed: " + lsock.getRemoteSocketAddress()); } catch (InterruptedException e) { Thread.currentThread().interrupt(); } } // close all socks and stop all threads private void shutdown() { try { if (shutdownRequested) return; shutdownRequested = true; if (!lsock.isClosed()) lsock.close(); if (!rsock.isClosed()) rsock.close(); readLSockThread.interrupt(); writeRSockThread.interrupt(); readRSockThread.interrupt(); writeLSockThread.interrupt(); } catch (IOException e) { e.printStackTrace(); // ignore this exception } } private void processException(Exception e) { e.printStackTrace(); shutdown(); } private void processSocketClosed(SocketException e) { // this is a normal case, ignore it // System.err.println("lsock should be closed: " + e); } class ReadLSockThread implements Runnable { public void run() { try { bareRun(); } catch (IOException e) { processException(e); } catch (InterruptedException e) { Thread.currentThread().interrupt(); } } private void bareRun() throws IOException, InterruptedException { InputStream in = null; try { in = lsock.getInputStream(); byte[] buf = new byte[BUFLEN]; int len; while ((len = in.read(buf)) != -1) { byte[] copy = new byte[len]; System.arraycopy(buf, 0, copy, 0, len); lrchannel.put(copy); if (dumpRequest) System.err.print(new String(copy)); } lrchannel.put(EOF); // flag the end of lchannel in.close(); } catch (SocketException e) { processSocketClosed(e); } finally { lrchannel.put(EOF); // flag the end of lchannel if (in != null) in.close(); } } } class WriteRSockThread implements Runnable { public void run() { try { bareRun(); } catch (IOException e) { processException(e); } catch (InterruptedException e) { Thread.currentThread().interrupt(); } } private void bareRun() throws IOException, InterruptedException { OutputStream out = null; try { out = rsock.getOutputStream(); while 
(true) { byte[] buf = lrchannel.take(); if (buf == EOF) break; out.write(buf); } } catch (SocketException e) { processSocketClosed(e); } finally { if (out != null) out.close(); } } } class ReadRSockThread implements Runnable { public void run() { try { bareRun(); } catch (IOException e) { processException(e); } catch (InterruptedException e) { Thread.currentThread().interrupt(); } } private void bareRun() throws IOException, InterruptedException { InputStream in = null; try { in = rsock.getInputStream(); byte[] buf = new byte[BUFLEN]; int len; while ((len = in.read(buf)) != -1) { byte[] copy = new byte[len]; System.arraycopy(buf, 0, copy, 0, len); rlchannel.put(copy); } } catch (SocketException e) { processSocketClosed(e); } finally { // flag the end of rlchannel rlchannel.put(EOF); // rsock finished, closed it in.close(); rsock.close(); } } } class WriteLSockThread implements Runnable { public void run() { try { bareRun(); } catch (IOException e) { processException(e); } catch (InterruptedException e) { Thread.currentThread().interrupt(); } } private void bareRun() throws IOException, InterruptedException { OutputStream out = null; try { out = lsock.getOutputStream(); while (true) { byte[] buf = rlchannel.take(); if (buf == EOF) break; out.write(buf); if (dumpResponse) System.err.print(new String(buf)); } } catch (SocketException e) { processSocketClosed(e); } finally { if (out != null) out.close(); lsock.close(); // write finished, close it. } } } } ================================================ FILE: tools/tcpmon/TCPMonitorSelect.java ================================================ import java.io.IOException; import java.net.InetSocketAddress; import java.net.ServerSocket; import java.nio.ByteBuffer; import java.nio.channels.SelectionKey; import java.nio.channels.Selector; import java.nio.channels.ServerSocketChannel; import java.nio.channels.SocketChannel; import java.util.Iterator; import java.util.Set; /** *

TCP监视器,它在本地监听端口,将请求转发到远端服务器,并将响应返回给客户端。功能与TCPMonitor相同,不同的是TCPMonitorSelect使用非阻塞IO,所以理论上性能会高一些。

* *

java TCPMonitorSelect -a :8080 localhost:8000
监听本地的8080端口,将请求转发给本机的8000端口,并将所有请求响应输出到标准错误输出。

* *

java TCPMonitorSelect :8080 work:22
监听本地的8080端口,并将请求转发给work主机的22端口。

*
* @author marlonyao
*/
public class TCPMonitorSelect {
    private static final int BUFLEN = 1024;

    /**
     * Entry point: parse options, then run a single-threaded selector loop
     * that dispatches every ready key to the Handler attached to it.
     */
    public static void main(String[] args) throws IOException {
        Options options = null;
        try {
            options = parseArgs(args);
            if (options.help) {
                usage();
                System.exit(0);
            }
        } catch (Exception e) {
            usage();
            System.exit(-1);
        }

        Selector selector = Selector.open();

        // initial server, start accept connections
        ServerSocketChannel serverChannel = ServerSocketChannel.open();
        serverChannel.configureBlocking(false);
        ServerSocket serverSocket = serverChannel.socket();
        serverSocket.bind(new InetSocketAddress(options.host, options.port));
        System.out.println("server started on " + serverSocket.getLocalSocketAddress());
        SelectionKey serverKey = serverChannel.register(selector, SelectionKey.OP_ACCEPT);
        serverKey.attach(new ServerHandler(selector, serverChannel, options.remoteHost, options.remotePort,
                options.dumpRequest, options.dumpResponse));

        while (true) {
            selector.select();
            Set<SelectionKey> keys = selector.selectedKeys();
            for (Iterator<SelectionKey> itor = keys.iterator(); itor.hasNext();) {
                SelectionKey key = itor.next();
                Handler handler = (Handler) key.attachment();
                handler.execute(key);
            }
            keys.clear();
        }
    }

    static class Options {
        boolean help;
        boolean dumpRequest;
        boolean dumpResponse;
        String host = "localhost";
        int port;
        String remoteHost;
        int remotePort;
    }

    private static void usage() {
        System.err.print(
            "java TCPMonitorSelect [options] [host]:port remote_host:remote_port\n" +
            "  -h, --help           print this help\n" +
            "  -r, --dump-request   dump request to stderr\n" +
            "  -s, --dump-response  dump response to stderr\n" +
            "  -a, --dump-all       dump request and response to stderr\n"
        );
    }

    /**
     * Parse "[options] [host]:port remote_host:remote_port".
     * Throws a RuntimeException (caught by main) on malformed input.
     */
    private static Options parseArgs(String[] args) {
        Options options = new Options();
        int i;
        // parse options
        for (i = 0; i < args.length; i++) {
            if (!args[i].startsWith("-")) {
                break;
            }
            if (args[i].equals("-h") || args[i].equals("--help")) {
                options.help = true;
            } else if (args[i].equals("-r") || args[i].equals("--dump-request")) {
                options.dumpRequest = true;
            } else if (args[i].equals("-s") || args[i].equals("--dump-response")) {
                options.dumpResponse = true;
            } else if (args[i].equals("-a") || args[i].equals("--dump-all")) {
                options.dumpRequest = true;
                options.dumpResponse = true;
            } else {
                throw new RuntimeException("unknown option '" + args[i] + "'");
            }
        }
        // "-h" needs no addresses; otherwise a lone "-h" exits with an error.
        if (options.help) {
            return options;
        }
        // parse remainder
        String localPart = args[i++];
        String[] bits = localPart.split(":");
        if (bits[0].length() > 0) {
            options.host = bits[0];
        }
        options.port = Integer.parseInt(bits[1]);
        String remotePart = args[i++];
        bits = remotePart.split(":");
        options.remoteHost = bits[0];
        options.remotePort = Integer.parseInt(bits[1]);
        return options;
    }

    interface Handler {
        void execute(SelectionKey key);
    }

    /*
     * process accept request.
     */
    static class ServerHandler implements Handler {
        private ServerSocketChannel serverChannel;
        private Selector selector;
        private String remoteHost;
        private int remotePort;
        private boolean dumpRequest;
        private boolean dumpResponse;

        public ServerHandler(Selector selector, ServerSocketChannel serverChannel, String remoteHost, int remotePort,
                boolean dumpRequest, boolean dumpResponse) {
            this.selector = selector;
            this.serverChannel = serverChannel;
            this.remoteHost = remoteHost;
            this.remotePort = remotePort;
            this.dumpRequest = dumpRequest;
            this.dumpResponse = dumpResponse;
        }

        public void execute(SelectionKey key) {
            SocketChannel lsockChannel;
            try {
                lsockChannel = serverChannel.accept();
            } catch (IOException e) {
                System.err.println("fail to accept connection");
                e.printStackTrace();
                return;
            }
            if (lsockChannel == null) {
                // non-blocking accept(): no connection was actually pending
                return;
            }
            System.out.println("accept connection: " + lsockChannel.socket().getRemoteSocketAddress());
            System.out.flush();
            // start client handler
            ClientHandler handler = new ClientHandler(selector, lsockChannel, remoteHost, remotePort, dumpRequest, dumpResponse);
            // start connect to remote host
            handler.startConnect();
        }
    }

    /*
     * process loop: read lsock -> write rsock -> read rsock -> write lsock
     */
    static class ClientHandler implements Handler {
        private Selector selector;
        private String remoteHost;
        private int remotePort;
        private SocketChannel lsockChannel;
        private SocketChannel rsockChannel;
        private SelectionKey lsockKey;
        private SelectionKey rsockKey;
        private ByteBuffer lrBuffer; // buffer between read lsock and write rsock
        private ByteBuffer rlBuffer; // buffer between read rsock and write lsock
        private boolean dumpRequest;
        private boolean dumpResponse;
        private boolean lsockInputClosed; // client has sent EOF; never re-arm lsock reads

        public ClientHandler(Selector selector, SocketChannel lsockChannel, String remoteHost, int remotePort,
                boolean dumpRequest, boolean dumpResponse) {
            this.selector = selector;
            this.lsockChannel = lsockChannel;
            this.remoteHost = remoteHost;
            this.remotePort = remotePort;
            this.lrBuffer = ByteBuffer.allocate(BUFLEN);
            this.rlBuffer = ByteBuffer.allocate(BUFLEN);
            this.dumpRequest = dumpRequest;
            this.dumpResponse = dumpResponse;
        }

        /** Begin a non-blocking connect to the remote host. */
        public void startConnect() {
            try {
                // connect rsock key
                rsockChannel = SocketChannel.open();
                rsockChannel.configureBlocking(false);
                boolean connected = rsockChannel.connect(new InetSocketAddress(remoteHost, remotePort));
                rsockKey = rsockChannel.register(selector, SelectionKey.OP_CONNECT);
                rsockKey.attach(this);
                if (connected) {
                    // connect() can finish immediately (e.g. to localhost); no
                    // OP_CONNECT event follows then, so complete setup now.
                    connectRSock();
                }
            } catch (IOException e) {
                e.printStackTrace();
                cancel();
            } catch (java.nio.channels.UnresolvedAddressException e) {
                // unchecked: would otherwise escape and kill the selector loop
                System.err.println("unknown host: " + remoteHost);
                cancel();
            }
        }

        public void execute(SelectionKey key) {
            if (!key.isValid()) return;
            try {
                if (key.isReadable()) {
                    if (key == lsockKey) {
                        readLSock();
                    } else {
                        readRSock();
                    }
                } else if (key.isWritable()) {
                    if (key == lsockKey) {
                        writeLSock();
                    } else {
                        writeRSock();
                    }
                } else if (key.isConnectable()) {
                    connectRSock();
                }
            } catch (IOException e) {
                e.printStackTrace();
                cancel();
            }
        }

        /** Deregister and close both sides; safe before either key exists. */
        public void cancel() {
            if (lsockKey != null) lsockKey.cancel();
            if (rsockKey != null) rsockKey.cancel();
            // Close the channels directly: when the connect fails, lsockKey was
            // never created and the client channel would otherwise leak.
            closeQuietly(lsockChannel);
            closeQuietly(rsockChannel);
        }

        private static void closeQuietly(SocketChannel channel) {
            if (channel == null) return;
            try {
                channel.close();
            } catch (IOException ignored) {
                // best effort while tearing the connection down
            }
        }

        private void readLSock() throws IOException {
            int n = lsockChannel.read(lrBuffer);
            if (n == -1) {
                // Client finished sending: stop reading it and half-close rsock
                // so the remote server sees end-of-stream.
                lsockInputClosed = true;
                lsockKey.interestOps(0);
                rsockChannel.socket().shutdownOutput();
            } else {
                if (dumpRequest) {
                    System.err.print(new String(lrBuffer.array(), 0, n));
                }
                lrBuffer.flip();
                lsockKey.interestOps(0);
                rsockKey.interestOps(SelectionKey.OP_WRITE);
            }
        }

        private void writeRSock() throws IOException {
            rsockChannel.write(lrBuffer);
            if (lrBuffer.remaining() == 0) {
                lrBuffer.clear(); // write finished
                rsockKey.interestOps(SelectionKey.OP_READ);
                lsockKey.interestOps(lsockInputClosed ? 0 : SelectionKey.OP_READ);
            }
        }

        private void readRSock() throws IOException {
            int n = rsockChannel.read(rlBuffer);
            if (n == -1) {
                System.out.println("close connection: " + lsockChannel.socket().getRemoteSocketAddress());
                System.out.flush();
                rsockKey.cancel();
                rsockChannel.close();
                lsockKey.interestOps(0);
                lsockKey.cancel();
                lsockChannel.close();
            } else {
                rlBuffer.flip();
                rsockKey.interestOps(0);
                lsockKey.interestOps(SelectionKey.OP_WRITE);
            }
        }

        private void writeLSock() throws IOException {
            int n = lsockChannel.write(rlBuffer);
            if (dumpResponse) {
                System.err.print(new String(rlBuffer.array(), rlBuffer.position() - n, n));
            }
            if (rlBuffer.remaining() == 0) {
                rlBuffer.clear(); // write finished
                // After client EOF, re-arming reads would only wake up to read -1 again.
                lsockKey.interestOps(lsockInputClosed ? 0 : SelectionKey.OP_READ);
                rsockKey.interestOps(SelectionKey.OP_READ);
            }
        }

        private void connectRSock() throws IOException {
            if (!rsockChannel.finishConnect()) {
                return; // still connecting; wait for the next OP_CONNECT
            }
            lsockChannel.configureBlocking(false);
            lsockKey = lsockChannel.register(selector, SelectionKey.OP_READ);
            lsockKey.attach(this);
            rsockKey.interestOps(SelectionKey.OP_READ);
        }
    }
}