diff --git a/connection.c b/connection.c index e58031c..9f47c8e 100644 --- a/connection.c +++ b/connection.c @@ -13,6 +13,13 @@ #include "sockaddr.h" #include "tcpsocket.h" +struct flood_ctrl { + int burst; + int current; + int inc_per_sec; + time_t last_inc; +}; + struct connection { struct list_head list; @@ -23,6 +30,9 @@ struct connection { struct event_fd *dst_event; struct event_timeout *timeout; + struct flood_ctrl src_flood; + struct flood_ctrl dst_flood; + time_t login_time; }; @@ -37,7 +47,20 @@ static struct connection * create_connection() } memset(con, 0, sizeof(struct connection)); - con->login_time = time(NULL); + + time_t now = time(NULL); + con->login_time = now; + + con->src_flood.burst = 256; + con->src_flood.inc_per_sec = 1; + con->src_flood.last_inc = now; + + con->dst_flood.burst = 256; + con->dst_flood.inc_per_sec = 1; + con->dst_flood.last_inc = now; + + con->src_flood.current = con->src_flood.burst; + con->dst_flood.current = con->dst_flood.burst; return con; } @@ -61,32 +84,95 @@ static void destroy_connection(struct connection *con) free(con); } +static int forward_throttle_timeout(void *privdata); + static int forward_handler(int fd, void *privdata) { struct connection *con = (struct connection *)privdata; - char buf[256]; - int len = read(fd, buf, sizeof(buf)); - if (len <= 0) { + int src_fd = event_get_fd(con->src_event); + int dst_fd = event_get_fd(con->dst_event); + + struct event_fd *event; + struct flood_ctrl *flood; + int outfd; + + /* client -> device */ + if (src_fd == fd) { + flood = &con->dst_flood; + outfd = dst_fd; + event = con->src_event; + + /* device -> client */ + } else if (dst_fd == fd) { + flood = &con->src_flood; + outfd = src_fd; + event = con->dst_event; + + /* unknown fd */ + } else { list_del(&con->list); destroy_connection(con); return -1; } - int src_fd = event_get_fd(con->src_event); - int dst_fd = event_get_fd(con->dst_event); + /* increment quota */ + time_t now = time(NULL); + flood->current += (flood->inc_per_sec * (now - flood->last_inc)); + flood->last_inc = now; - /* client -> device */ - if (src_fd == fd && dst_fd != -1) - write(dst_fd, buf, len); + /* max burst size */ + if (flood->current > flood->burst) + flood->current = flood->burst; - /* device -> client */ - if (dst_fd == fd && src_fd != -1) - write(src_fd, buf, len); + /* read max. buffer size */ + char buf[256]; + int readsize = sizeof(buf); + /* only read current quota */ + if (readsize > flood->current) + readsize = flood->current; + + /* no quota left */ + if (readsize == 0) { + /* disable fd event */ + event_add_readfd(event, 0, NULL, NULL); + + /* setup timer to reenable fd-event */ + struct timeval tv = { .tv_sec = 1, .tv_usec = 0 }; + event_add_timeout(&tv, forward_throttle_timeout, con); + + return 0; + } + + int len = read(fd, buf, readsize); + if (len <= 0 && readsize > 0) { + list_del(&con->list); + destroy_connection(con); + return -1; + } + + /* not forwarding: discard the data */ + if (outfd == -1) + return 0; + + write(outfd, buf, len); + flood->current -= len; return 0; } +static int forward_throttle_timeout(void *privdata) +{ + struct connection *con = (struct connection *)privdata; + + /* HACK: enable both directions unconditionally */ + event_add_readfd(con->dst_event, 0, forward_handler, con); + event_add_readfd(con->src_event, 0, forward_handler, con); + + /* remove timeout again */ + return -1; +} + static int connect_handler(int fd, void *privdata) { struct connection *con = (struct connection *)privdata; @@ -144,7 +230,7 @@ int client_handler(int fd, void *privdata) // TODO: check destination - int dst_fd = tcp_connect(&con->dst_addr); + int dst_fd = tcp_connect_nonblock(&con->dst_addr); if (dst_fd < 0) { list_del(&con->list); destroy_connection(con); diff --git a/event.c b/event.c index 8e307e7..029c177 100644 --- a/event.c +++ b/event.c @@ -79,6 +79,7 @@ struct event_fd * event_add_fd( } memset(entry, 0, sizeof(struct event_fd)); + entry->flags |= EVENT_NEW; entry->fd = fd; /* put it on the list */ @@ -96,7 +97,6 @@ struct event_fd * event_add_fd( entry->write_priv = privdata; } - entry->flags |= EVENT_NEW; return entry; } @@ -208,32 +208,6 @@ int event_loop(void) } while (1) { - fd_set *readfds = NULL, *writefds = NULL; - struct event_fd *entry, *tmp; - - list_for_each_entry_safe(entry, tmp, &event_fd_list, list) { - entry->flags &= ~EVENT_NEW; - - if (entry->flags & EVENT_DELETE) { - list_del(&entry->list); - free(entry); - - } else if (entry->flags & FD_READ) { - if (readfds == NULL) { - readfds = &fdsets[0]; - FD_ZERO(readfds); - } - FD_SET(entry->fd, readfds); - - } else if (entry->flags & FD_WRITE) { - if (writefds == NULL) { - writefds = &fdsets[1]; - FD_ZERO(writefds); - } - FD_SET(entry->fd, writefds); - } - } - struct timeval timeout, *timeout_p = NULL; if (!list_empty(&event_timeout_list)) { struct timeval now; @@ -272,6 +246,32 @@ int event_loop(void) } } + fd_set *readfds = NULL, *writefds = NULL; + struct event_fd *entry, *tmp; + + list_for_each_entry_safe(entry, tmp, &event_fd_list, list) { + entry->flags &= ~EVENT_NEW; + + if (entry->flags & EVENT_DELETE) { + list_del(&entry->list); + free(entry); + + } else if (entry->flags & FD_READ) { + if (readfds == NULL) { + readfds = &fdsets[0]; + FD_ZERO(readfds); + } + FD_SET(entry->fd, readfds); + + } else if (entry->flags & FD_WRITE) { + if (writefds == NULL) { + writefds = &fdsets[1]; + FD_ZERO(writefds); + } + FD_SET(entry->fd, writefds); + } + } + int i = select(FD_SETSIZE, readfds, writefds, NULL, timeout_p); if (i <= 0) { /* On error, -1 is returned, and errno is set