Skip to content

Commit 02d08da

Browse files
VirrageSapaszke
authored andcommitted
Add support for IPv6 in Data Channel TCP (pytorch#53)
1 parent 13a5090 commit 02d08da

File tree

1 file changed

+89
-49
lines changed

1 file changed

+89
-49
lines changed

torch/lib/THD/base/channels/DataChannelTCP.cpp

Lines changed: 89 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -196,67 +196,99 @@ DataChannelTCP::~DataChannelTCP()
196196

197197

198198
void DataChannelTCP::listen(std::uint16_t port = 0) {
199-
SYSCHECK(_socket = ::socket(PF_INET, SOCK_STREAM, 0))
199+
struct addrinfo hints, *res = NULL;
200+
201+
memset(&hints, 0x00, sizeof(hints));
202+
hints.ai_flags = AI_PASSIVE;
203+
hints.ai_family = AF_UNSPEC; // either IPv4 or IPv6
204+
hints.ai_socktype = SOCK_STREAM; // TCP
205+
206+
// `getaddrinfo` will sort addresses according to RFC 3484 and can be tweeked
207+
// by editing `/etc/gai.conf`. so there is no need to manual sorting
208+
// or protcol preference.
209+
int err = getaddrinfo(NULL, std::to_string(port).data(), &hints, &res);
210+
if (err != 0 || !res) {
211+
throw std::invalid_argument("cannot find host to listen on: " + std::string(gai_strerror(err)));
212+
}
200213

201-
struct sockaddr_in addr;
202-
socklen_t addr_len = sizeof(addr);
214+
std::shared_ptr<struct addrinfo> addresses(res, [](struct addrinfo* p) {
215+
::freeaddrinfo(p);
216+
});
203217

204-
memset(&addr, 0, addr_len);
205-
addr.sin_family = AF_INET;
206-
addr.sin_port = htons(port);
207-
addr.sin_addr.s_addr = INADDR_ANY;
218+
struct addrinfo *next_addr = addresses.get();
219+
while (true) {
220+
try {
221+
SYSCHECK(_socket = ::socket(next_addr->ai_family, next_addr->ai_socktype, next_addr->ai_protocol))
208222

209-
int optval = 1;
210-
SYSCHECK(::setsockopt(_socket, SOL_SOCKET, SO_REUSEADDR, &optval, sizeof(int)))
211-
SYSCHECK(::bind(_socket, reinterpret_cast<struct sockaddr*>(&addr), sizeof(addr)))
212-
SYSCHECK(::listen(_socket, LISTEN_QUEUE_SIZE))
213-
SYSCHECK(::getsockname(_socket, reinterpret_cast<struct sockaddr*>(&addr), &addr_len))
223+
int optval = 1;
224+
SYSCHECK(::setsockopt(_socket, SOL_SOCKET, SO_REUSEADDR, &optval, sizeof(int)))
225+
SYSCHECK(::bind(_socket, next_addr->ai_addr, next_addr->ai_addrlen))
226+
SYSCHECK(::listen(_socket, LISTEN_QUEUE_SIZE))
227+
break;
228+
} catch (const std::system_error& e) {
229+
::close(_socket);
230+
next_addr = next_addr->ai_next;
214231

232+
// we have tried all addresses but could not establish listening on any of them
233+
if (!next_addr) {
234+
throw e;
235+
}
236+
}
237+
}
238+
239+
// get listen port
240+
struct sockaddr_in addr;
241+
socklen_t addr_len = sizeof(addr);
242+
SYSCHECK(::getsockname(_socket, reinterpret_cast<struct sockaddr*>(&addr), &addr_len))
215243
_port = ntohs(addr.sin_port);
216244
}
217245

218246

219247
int DataChannelTCP::connect(const std::string& address, std::uint16_t port,
220248
int wait = true) const {
221-
struct sockaddr_in addr;
222-
socklen_t addr_len = sizeof(addr);
223-
224-
memset(&addr, 0, addr_len);
225-
addr.sin_family = AF_INET;
226-
addr.sin_port = htons(port);
227-
struct addrinfo *res;
228-
229-
// get address by host or IP
230-
int err = ::getaddrinfo(address.data(), NULL, NULL, &res);
231-
if (err == 0) {
232-
std::memcpy(
233-
&(addr.sin_addr),
234-
&(reinterpret_cast<struct sockaddr_in*>(res->ai_addr)->sin_addr),
235-
sizeof(struct in_addr)
236-
);
237-
::freeaddrinfo(res);
238-
} else {
239-
SYSCHECK(err = ::inet_pton(AF_INET, address.data(), &(addr.sin_addr)))
240-
if (err == 0)
241-
throw std::invalid_argument("invalid IP address");
249+
struct addrinfo hints, *res = NULL;
250+
251+
memset(&hints, 0x00, sizeof(hints));
252+
hints.ai_flags = AI_NUMERICSERV; // specifies that port (service) is numeric
253+
hints.ai_family = AF_UNSPEC; // either IPv4 or IPv6
254+
hints.ai_socktype = SOCK_STREAM; // TCP
255+
256+
// `getaddrinfo` will sort addresses according to RFC 3484 and can be tweeked
257+
// by editing `/etc/gai.conf`. so there is no need to manual sorting
258+
// or protcol preference.
259+
int err = ::getaddrinfo(address.data(), std::to_string(port).data(), &hints, &res);
260+
if (err != 0 || !res) {
261+
throw std::invalid_argument("host not found: " + std::string(gai_strerror(err)));
242262
}
243263

264+
std::shared_ptr<struct addrinfo> addresses(res, [](struct addrinfo* p) {
265+
::freeaddrinfo(p);
266+
});
267+
268+
struct addrinfo *next_addr = addresses.get();
244269
int socket;
245270
while (true) {
246271
try {
247-
/*
248-
* If connect() fails, the state of the socket is unspecified.
249-
* We should close the socket and create a new one before attempting to reconnect.
250-
*/
251-
SYSCHECK(socket = ::socket(AF_INET, SOCK_STREAM, 0))
252-
SYSCHECK(::connect(socket, reinterpret_cast<const struct sockaddr*>(&addr), addr_len))
272+
SYSCHECK(socket = ::socket(next_addr->ai_family, next_addr->ai_socktype, next_addr->ai_protocol))
273+
SYSCHECK(::connect(socket, next_addr->ai_addr, next_addr->ai_addrlen))
253274
break;
254275
} catch (const std::system_error& e) {
276+
// if `connect` fails, the state of the socket is unspecified.
277+
// we should close the socket and create a new one before attempting to reconnect.
255278
::close(socket);
256-
if (!wait || (errno != ECONNREFUSED))
257-
throw e;
258279

259-
std::this_thread::sleep_for(std::chrono::seconds(1));
280+
if (!wait || (errno != ECONNREFUSED)) {
281+
// we need to move to next address because this was not available
282+
// to connect or to create socket
283+
next_addr = next_addr->ai_next;
284+
285+
// we have tried all addresses but could not connect to any of them
286+
if (!next_addr) {
287+
throw e;
288+
}
289+
} else {
290+
std::this_thread::sleep_for(std::chrono::seconds(1));
291+
}
260292
}
261293
}
262294

@@ -278,16 +310,24 @@ std::tuple<int, std::string> DataChannelTCP::accept() const {
278310
throw std::system_error(ECONNABORTED, std::system_category());
279311
}
280312

281-
struct sockaddr_in addr;
313+
int socket;
314+
SYSCHECK(socket = ::accept(_socket, NULL, NULL))
315+
316+
struct sockaddr_storage addr;
282317
socklen_t addr_len = sizeof(addr);
283-
std::memset(&addr, 0, sizeof(addr));
318+
char address[INET6_ADDRSTRLEN + 1];
284319

285-
int socket;
286-
SYSCHECK(socket = ::accept(_socket, reinterpret_cast<struct sockaddr*>(&addr), &addr_len))
320+
SYSCHECK(::getpeername(socket, reinterpret_cast<struct sockaddr*>(&addr), &addr_len))
287321

288-
char address[INET_ADDRSTRLEN + 1];
289-
SYSCHECK(::inet_ntop(AF_INET, &(addr.sin_addr), address, INET_ADDRSTRLEN))
290-
address[INET_ADDRSTRLEN] = '\0';
322+
if (addr.ss_family == AF_INET) {
323+
struct sockaddr_in *s = reinterpret_cast<struct sockaddr_in*>(&addr);
324+
SYSCHECK(::inet_ntop(AF_INET, &(s->sin_addr), address, INET_ADDRSTRLEN))
325+
address[INET_ADDRSTRLEN] = '\0';
326+
} else {
327+
struct sockaddr_in6 *s = reinterpret_cast<struct sockaddr_in6*>(&addr);
328+
SYSCHECK(::inet_ntop(AF_INET6, &(s->sin6_addr), address, INET6_ADDRSTRLEN))
329+
address[INET6_ADDRSTRLEN] = '\0';
330+
}
291331

292332
return std::make_tuple(socket, std::string(address));
293333
}

0 commit comments

Comments
 (0)