diff --git a/Makefile.inc b/Makefile.inc index c4efa3c..22580b7 100644 --- a/Makefile.inc +++ b/Makefile.inc @@ -116,6 +116,7 @@ MANPAGES = ares_cancel.3 \ ares_set_servers_ports.3 \ ares_set_servers_ports_csv.3 \ ares_set_socket_callback.3 \ + ares_set_socket_configure_callback.3 \ ares_set_sortlist.3 \ ares_strerror.3 \ ares_timeout.3 \ @@ -167,6 +168,7 @@ HTMLPAGES = ares_cancel.html \ ares_set_servers_ports.html \ ares_set_servers_ports_csv.html \ ares_set_socket_callback.html \ + ares_set_socket_configure_callback.html \ ares_set_sortlist.html \ ares_strerror.html \ ares_timeout.html \ @@ -218,6 +220,7 @@ PDFPAGES = ares_cancel.pdf \ ares_set_servers_ports.pdf \ ares_set_servers_ports_csv.pdf \ ares_set_socket_callback.pdf \ + ares_set_socket_configure_callback.pdf \ ares_set_sortlist.pdf \ ares_strerror.pdf \ ares_timeout.pdf \ diff --git a/ares.h b/ares.h index fcbcecf..cd94d4f 100644 --- a/ares.h +++ b/ares.h @@ -294,6 +294,10 @@ typedef int (*ares_sock_create_callback)(ares_socket_t socket_fd, int type, void *data); +typedef int (*ares_sock_config_callback)(ares_socket_t socket_fd, + int type, + void *data); + CARES_EXTERN int ares_library_init(int flags); CARES_EXTERN int ares_library_init_mem(int flags, @@ -344,6 +348,10 @@ CARES_EXTERN void ares_set_socket_callback(ares_channel channel, ares_sock_create_callback callback, void *user_data); +CARES_EXTERN void ares_set_socket_configure_callback(ares_channel channel, + ares_sock_config_callback callback, + void *user_data); + CARES_EXTERN int ares_set_sortlist(ares_channel channel, const char *sortstr); diff --git a/ares_init.c b/ares_init.c index 4607944..a6e4297 100644 --- a/ares_init.c +++ b/ares_init.c @@ -164,6 +164,8 @@ int ares_init_options(ares_channel *channelptr, struct ares_options *options, channel->sock_state_cb_data = NULL; channel->sock_create_cb = NULL; channel->sock_create_cb_data = NULL; + channel->sock_config_cb = NULL; + channel->sock_config_cb_data = NULL; channel->last_server = 0; channel->last_timeout_processed = (time_t)now.tv_sec; @@ -291,6 +293,8 @@ int ares_dup(ares_channel *dest, ares_channel src) /* Now clone the options that ares_save_options() doesn't support. */ (*dest)->sock_create_cb = src->sock_create_cb; (*dest)->sock_create_cb_data = src->sock_create_cb_data; + (*dest)->sock_config_cb = src->sock_config_cb; + (*dest)->sock_config_cb_data = src->sock_config_cb_data; strncpy((*dest)->local_dev_name, src->local_dev_name, sizeof(src->local_dev_name)); @@ -2085,6 +2089,14 @@ void ares_set_socket_callback(ares_channel channel, channel->sock_create_cb_data = data; } +void ares_set_socket_configure_callback(ares_channel channel, + ares_sock_config_callback cb, + void *data) +{ + channel->sock_config_cb = cb; + channel->sock_config_cb_data = data; +} + int ares_set_sortlist(ares_channel channel, const char *sortstr) { int nsort = 0; diff --git a/ares_private.h b/ares_private.h index 45f34ab..33a23e7 100644 --- a/ares_private.h +++ b/ares_private.h @@ -311,6 +311,9 @@ struct ares_channeldata { ares_sock_create_callback sock_create_cb; void *sock_create_cb_data; + + ares_sock_config_callback sock_config_cb; + void *sock_config_cb_data; }; /* Memory management functions */ diff --git a/ares_process.c b/ares_process.c index c3ac77b..0325f51 100644 --- a/ares_process.c +++ b/ares_process.c @@ -1031,6 +1031,17 @@ static int open_tcp_socket(ares_channel channel, struct server_state *server) } #endif + if (channel->sock_config_cb) + { + int err = channel->sock_config_cb(s, SOCK_STREAM, + channel->sock_config_cb_data); + if (err < 0) + { + sclose(s); + return err; + } + } + /* Connect to the server. */ if (connect(s, sa, salen) == -1) { @@ -1115,6 +1126,17 @@ static int open_udp_socket(ares_channel channel, struct server_state *server) return -1; } + if (channel->sock_config_cb) + { + int err = channel->sock_config_cb(s, SOCK_DGRAM, + channel->sock_config_cb_data); + if (err < 0) + { + sclose(s); + return err; + } + } + /* Connect to the server. */ if (connect(s, sa, salen) == -1) { diff --git a/ares_set_socket_callback.3 b/ares_set_socket_callback.3 index 68e608b..a92262a 100644 --- a/ares_set_socket_callback.3 +++ b/ares_set_socket_callback.3 @@ -20,7 +20,7 @@ connected to the remote server. The callback must return ARES_SUCCESS if things are fine, or use the standard ares error codes to signal errors back. Returned errors will abort the ares operation. .SH SEE ALSO -.BR ares_init_options (3) +.BR ares_init_options (3), ares_set_socket_configure_callback (3) .SH AVAILABILITY ares_set_socket_callback(3) was added in c-ares 1.6.0 .SH AUTHOR diff --git a/ares_set_socket_configure_callback.3 b/ares_set_socket_configure_callback.3 new file mode 100644 index 0000000..d3b2f93 --- /dev/null +++ b/ares_set_socket_configure_callback.3 @@ -0,0 +1,33 @@ +.\" +.TH ARES_SET_SOCKET_CONFIGURE_CALLBACK 3 "6 Feb 2016" +.SH NAME +ares_set_socket_configure_callback \- Set a socket configuration callback +.SH SYNOPSIS +.nf +.B #include +.PP +.B typedef int (*ares_sock_config_callback)(ares_socket_t \fIsocket_fd\fP, + int \fItype\fP, + void *\fIuserdata\fP) +.PP +.B void ares_set_socket_configure_callback(ares_channel \fIchannel\fP, + ares_sock_config_callback \fIcallback\fP, + void *\fIuserdata\fP) +.PP +.B cc file.c -lcares +.fi +.SH DESCRIPTION +.PP +This function sets a \fIcallback\fP in the given ares channel handle. This +callback function will be invoked after the socket has been created, but +before it has been connected to the remote server, which is an ideal time +to configure various socket options. The callback must return ARES_SUCCESS +if things are fine, or return -1 to signal an error. A returned error will +abort the ares operation. +.SH SEE ALSO +.BR ares_init_options (3), ares_set_socket_callback (3) +.SH AVAILABILITY +ares_set_socket_configure_callback(3) was added in c-ares 1.11.0 +.SH AUTHOR +Andrew Ayer + diff --git a/test/ares-test-mock.cc b/test/ares-test-mock.cc index 6b40fd7..fb0d702 100644 --- a/test/ares-test-mock.cc +++ b/test/ares-test-mock.cc @@ -153,6 +153,51 @@ TEST_P(MockChannelTest, SockFailCallback) { EXPECT_EQ(ARES_ECONNREFUSED, result.status_); } +static int sock_config_cb_count = 0; +static int SocketConfigureCallback(ares_socket_t fd, int type, void *data) { + int rc = *(int*)data; + if (verbose) std::cerr << "SocketConfigureCallback(" << fd << ") invoked" << std::endl; + sock_config_cb_count++; + return rc; +} + +TEST_P(MockChannelTest, SockConfigureCallback) { + DNSPacket rsp; + rsp.set_response().set_aa() + .add_question(new DNSQuestion("www.google.com", ns_t_a)) + .add_answer(new DNSARR("www.google.com", 100, {2, 3, 4, 5})); + EXPECT_CALL(server_, OnRequest("www.google.com", ns_t_a)) + .WillOnce(SetReply(&server_, &rsp)); + + // Get notified of new sockets + int rc = ARES_SUCCESS; + ares_set_socket_configure_callback(channel_, SocketConfigureCallback, &rc); + + HostResult result; + sock_config_cb_count = 0; + ares_gethostbyname(channel_, "www.google.com.", AF_INET, HostCallback, &result); + Process(); + EXPECT_EQ(1, sock_config_cb_count); + EXPECT_TRUE(result.done_); + std::stringstream ss; + ss << result.host_; + EXPECT_EQ("{'www.google.com' aliases=[] addrs=[2.3.4.5]}", ss.str()); +} + +TEST_P(MockChannelTest, SockConfigureFailCallback) { + // Notification of new sockets gives an error. + int rc = -1; + ares_set_socket_configure_callback(channel_, SocketConfigureCallback, &rc); + + HostResult result; + sock_config_cb_count = 0; + ares_gethostbyname(channel_, "www.google.com.", AF_INET, HostCallback, &result); + Process(); + EXPECT_LT(1, sock_config_cb_count); + EXPECT_TRUE(result.done_); + EXPECT_EQ(ARES_ECONNREFUSED, result.status_); +} + // TCP only to prevent retries TEST_P(MockTCPChannelTest, MalformedResponse) { std::vector one = {0x01};