/*
* Copyright (c) 2004-2007 Hypertriton, Inc.
* All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions
* are met:
* 1. Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* 2. Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
* ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE FOR
* ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
* USE OF THIS SOFTWARE EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
#include
#ifdef AG_NETWORK
#include "core.h"
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include "net_command.h"
#include "net_client.h"
#include
#include
#if defined(HAVE_GETPWUID) && defined(HAVE_GETUID)
#include
#endif
/* Size limits (in bytes) for the client read buffer and outgoing requests. */
enum {
	RDBUF_INIT = 4 * 1024,		/* Initial read buffer size */
	RDBUF_GROW = 1024,		/* Extra slack added each time it grows */
	RDBUF_MAX = 256 * 1024 * 1024,	/* Hard cap on the read buffer size */
	REQBUF_MAX = 4 * 1024		/* Size of the request format buffer */
};

/* Reconnection policy consulted by NC_Reconnect(). */
int ncWarnOnReconnect = 1;	/* Print a notice on stderr when reconnecting */
int ncReconnectAttempts = 4;	/* Number of NC_Connect() attempts */
int ncReconnectIval = 2;	/* Seconds to sleep between attempts */
/*
 * Initialize the network client subsystem. Nothing needs to be set up at
 * present; the flags argument is accepted but not used.
 */
void
NC_InitSubsystem(Uint flags)
{
	(void)flags;
}
/*
 * Format a request printf-style into a REQBUF_MAX-byte buffer and send it
 * to the server.
 *
 * write(2) on a socket may transfer fewer bytes than requested or be
 * interrupted by a signal, so the remainder is resent until the whole
 * request is out. Returns 0 on success, or -1 with an error set.
 */
int
NC_Write(NC_Session *client, const char *fmt, ...)
{
	char req[REQBUF_MAX];
	const char *p;
	size_t left;
	ssize_t wrote;
	va_list ap;

	va_start(ap, fmt);
	Vsnprintf(req, sizeof(req), fmt, ap);
	va_end(ap);

	p = req;
	left = strlen(req);
	while (left > 0) {
		wrote = write(client->sock, p, left);
		if (wrote == -1) {
			if (errno == EINTR) {
				continue;	/* Interrupted before any data went out */
			}
			AG_SetError("Write error: %s", strerror(errno));
			return (-1);
		}
		if (wrote == 0) {
			/* No progress possible; avoid spinning forever. */
			AG_SetError("Short write");
			return (-1);
		}
		p += wrote;
		left -= (size_t)wrote;
	}
	return (0);
}
/*
 * Parse a decimal item count (or item length) from a server response field.
 *
 * On success returns 0 and stores the value in *ip. An empty field, or one
 * with trailing non-numeric characters, is leniently treated as 0. Returns
 * -1 with an error set if the field is too long to hold a valid count, is
 * negative (it would wrap to a huge unsigned value), or exceeds AG_INT_MAX.
 */
static int
ParseItemCount(const char *buf, unsigned int *ip)
{
	char numbuf[13];
	char *endp;
	long lval;

	/* Refuse over-long fields rather than silently parsing a prefix. */
	if (strlen(buf) >= sizeof(numbuf)) {
		AG_SetError("Item count out of range");
		return (-1);
	}
	Strlcpy(numbuf, buf, sizeof(numbuf));

	errno = 0;
	lval = strtol(numbuf, &endp, 10);
	if (numbuf[0] == '\0' || endp[0] != '\0') {
		*ip = 0;
		return (0);
	}
	if ((errno == ERANGE &&
	    (lval == AG_LONG_MAX || lval == AG_LONG_MIN)) ||
	    lval < 0 || lval > AG_INT_MAX) {
		AG_SetError("Item count out of range");
		return (-1);
	}
	*ip = (unsigned int)lval;
	return (0);
}
/*
 * Return the human-readable part of an error response that occurred
 * during negotiation. Responses of the form "! msg" or "1 msg" have their
 * two-character prefix skipped; anything else is returned unchanged. The
 * returned pointer points into client->read.buf.
 */
static char *
GetServerError(NC_Session *client)
{
	char *msg = client->read.buf;
	int hasPrefix;

	hasPrefix = (msg[0] == '!' || msg[0] == '1') && msg[1] == ' ';
	return (hasPrefix ? msg + 2 : msg);
}
/*
 * Send a query to the server and expect an immediate response consisting
 * of an array of binary items.
 *
 * The server answers with a header line "0 <count>:<len1>:<len2>:...",
 * which must list exactly <count> item lengths. The client then writes
 * "1", reads the concatenated item data, and writes "0". If reading the
 * header fails, a reconnection is attempted and the request is resent.
 *
 * Returns a newly allocated NC_Result (release with NC_FreeResult()), or
 * NULL with an error set.
 */
NC_Result *
NC_Query(NC_Session *client, const char *fmt, ...)
{
	char req[REQBUF_MAX];
	char *bufp, *s;
	NC_Result *res;
	unsigned int count, nitems = 0, nlens = 0, k;
	int haveCount = 0;
	size_t totsz = 0, pos;
	long nread;
	va_list ap;

	if (client == NULL) {
		AG_SetError("Not connected to server.");
		return (NULL);
	}
	va_start(ap, fmt);
	Vsnprintf(req, sizeof(req), fmt, ap);
	Strlcat(req, "\n", sizeof(req));
	va_end(ap);
sendreq:
	/* Issue the server request. */
	if (NC_Write(client, "%s\n", req) == -1) {
		return (NULL);
	}
	if (NC_Read(client, 3) < 1) {
		if (NC_Reconnect(client) == 0) {
			goto sendreq;
		}
		return (NULL);
	}
	switch (client->read.buf[0]) {
	case '0':
		break;
	case '!':
	case '1':
		AG_SetError("Server error: `%s'", GetServerError(client));
		return (NULL);
	default:
		AG_SetError("Illegal server response: `%s'", client->read.buf);
		return (NULL);
	}
	/* The size specification starts at buf[2]; a bare "0" has none. */
	if (client->read.buf[1] == '\0') {
		AG_SetError("Illegal server response: `%s'", client->read.buf);
		return (NULL);
	}

	/*
	 * Parse the list/item size specification. Every unused slot is kept
	 * NULL so that NC_FreeResult() is safe on any error path.
	 */
	res = Malloc(sizeof(NC_Result));
	res->argc = 0;
	res->argv = NULL;
	res->argv_len = NULL;

	bufp = &client->read.buf[2];
	while ((s = AG_Strsep(&bufp, ":")) != NULL && s[0] != '\0') {
		if (ParseItemCount(s, &count) == -1) {
			goto fail;
		}
		if (!haveCount) {
			/* First field: number of items. */
			haveCount = 1;
			if (count == 0) {
				return (res);
			}
			if (count > ((size_t)-1) / sizeof(char *) ||
			    count > ((size_t)-1) / sizeof(size_t)) {
				AG_SetError("Item count out of range");
				goto fail;
			}
			nitems = count;
			res->argv = Malloc(nitems*sizeof(char *));
			res->argv_len = Malloc(nitems*sizeof(size_t));
			for (k = 0; k < nitems; k++) {
				res->argv[k] = NULL;
				res->argv_len[k] = 0;
			}
			res->argc = nitems;
		} else {
			/* Following fields: length of each item in turn. */
			if (nlens >= nitems) {
				AG_SetError("Too many item lengths in server "
				    "response");
				goto fail;
			}
			if ((size_t)count > ((size_t)-1) - totsz) {
				AG_SetError("Item count out of range");
				goto fail;
			}
			res->argv[nlens] = Malloc(count);
			res->argv_len[nlens] = count;
			totsz += count;
			nlens++;
		}
	}
	if (!haveCount) {
		AG_SetError("Missing item count in server response");
		goto fail;
	}
	if (nlens != nitems) {
		AG_SetError("Expected %u item lengths, got %u", nitems, nlens);
		goto fail;
	}

	/* Read the items. */
	if (NC_Write(client, "1\n") == -1) {
		goto fail;
	}
	if (totsz > 0) {
		/* NC_ReadBinary() is signed: -1 must not compare as huge. */
		nread = NC_ReadBinary(client, totsz);
		if (nread < 0) {
			goto fail;
		}
		if ((size_t)nread < totsz) {
			AG_SetError("Incomplete server response");
			goto fail;
		}
	}
	for (k = 0, pos = 0; k < nitems; k++) {
		if (res->argv_len[k] > 0) {
			memcpy(res->argv[k], &client->read.buf[pos],
			    res->argv_len[k]);
		}
		pos += res->argv_len[k];
	}
	if (NC_Write(client, "0\n") == -1) {
		goto fail;
	}
	return (res);
fail:
	NC_FreeResult(res);
	return (NULL);
}
/*
 * Send a query to the server and expect a (possibly slowly delivered)
 * stream of data in response.
 *
 * The server answers with a size line beginning "0 <size>", followed by
 * <size> bytes of raw data that are read directly from the socket. If
 * reading the size line fails, a reconnection is attempted and the request
 * is resent. Returns a newly allocated NC_Result with a single item holding
 * the data (release with NC_FreeResult()), or NULL with an error set.
 */
NC_Result *
NC_QueryBinary(NC_Session *client, const char *fmt, ...)
{
	char sizbuf[16];
	char req[REQBUF_MAX];
	NC_Result *res;
	size_t binread = 0, binsize, n;
	unsigned long lval;
	char *dst, *endp;
	long rv;
	ssize_t nread;
	va_list ap;

	if (client == NULL) {
		AG_SetError("Not connected to server.");
		return (NULL);
	}
	/* Issue the request. */
	va_start(ap, fmt);
	Vsnprintf(req, sizeof(req), fmt, ap);
	va_end(ap);
sendreq:
	if (NC_Write(client, "%s\n\n", req) == -1) {
		return (NULL);
	}
	if ((rv = NC_Read(client, 15)) < 1) {
		if (NC_Reconnect(client) == 0) {
			goto sendreq;
		}
		return (NULL);
	}
	if (rv < 15) {
		AG_SetError("Malformed binary size packet.");
		return (NULL);
	}
	if (client->read.buf[0] != '0') {
		AG_SetError("Bad request: %s", client->read.buf);
		return (NULL);
	}

	/*
	 * Extract the size field following "0 ", stopping at the end of the
	 * line and never writing past the end of sizbuf.
	 */
	for (n = 0; n < sizeof(sizbuf)-1 && (long)n+2 < rv; n++) {
		char c = client->read.buf[n+2];

		if (c == '\n' || c == '\0') {
			break;
		}
		sizbuf[n] = c;
	}
	sizbuf[n] = '\0';

	/* strtoul() silently negates a leading '-', so reject it explicitly. */
	errno = 0;
	lval = strtoul(sizbuf, &endp, 10);
	if (endp == sizbuf || errno == ERANGE ||
	    strchr(sizbuf, '-') != NULL) {
		AG_SetError("Malformed binary size packet.");
		return (NULL);
	}
	binsize = (size_t)lval;

	/* Allocate the response structure. */
	res = Malloc(sizeof(NC_Result));
	res->argv = Malloc(sizeof(char *));
	res->argv_len = Malloc(sizeof(size_t));
	res->argc = 1;
	res->argv[0] = dst = Malloc(binsize);
	res->argv_len[0] = binsize;

	/* Read the binary data, asking only for what is still missing. */
	while (binread < binsize) {
		nread = read(client->sock, dst, binsize - binread);
		if (nread == -1) {
			if (errno == EINTR) {
				continue;
			}
			AG_SetError("Read error: %s", strerror(errno));
			goto fail;
		} else if (nread == 0) {
			break;
		}
		binread += (size_t)nread;
		dst += nread;
	}
	if (binread < binsize) {
		AG_SetError("Incomplete binary response.");
		goto fail;
	}
	return (res);
fail:
	Free(res->argv[0]);
	Free(res->argv_len);
	Free(res->argv);
	Free(res);
	return (NULL);
}
/*
 * Release a result returned by NC_Query() or NC_QueryBinary(). This frees
 * each of the argc item buffers, the item and length arrays (when
 * present), and finally the structure itself.
 */
void
NC_FreeResult(NC_Result *res)
{
	if (res->argv != NULL) {
		int idx = 0;

		while (idx < res->argc)
			Free(res->argv[idx++]);
		Free(res->argv);
	}
	if (res->argv_len != NULL)
		Free(res->argv_len);
	Free(res);
}
/*
 * Destroy the current server connection and establish a new one, using the
 * host, port and credentials stored in the session.
 *
 * Up to ncReconnectAttempts connection attempts are made, sleeping
 * ncReconnectIval seconds between consecutive attempts (not after the last
 * one). Returns 0 on success, or -1 with an error set (including when
 * ncReconnectAttempts is zero or negative).
 */
int
NC_Reconnect(NC_Session *client)
{
	char host_save[NC_HOSTNAME_MAX];
	char port_save[NC_PORTNUM_MAX];
	char user_save[NC_USERNAME_MAX];
	char pass_save[NC_PASSWORD_MAX];
	int try, retries;

	/*
	 * NC_Connect() copies its arguments into client->host etc., so pass
	 * private copies to keep source and destination from overlapping.
	 */
	Strlcpy(host_save, client->host, sizeof(host_save));
	Strlcpy(port_save, client->port, sizeof(port_save));
	Strlcpy(user_save, client->user, sizeof(user_save));
	Strlcpy(pass_save, client->pass, sizeof(pass_save));

	if (ncWarnOnReconnect)
		fprintf(stderr, "%s: reconnecting...\n", AG_GetError());

	NC_Disconnect(client);
	retries = ncReconnectAttempts;
	for (try = 0; try < retries; try++) {
		if (try > 0) {
			sleep(ncReconnectIval);
		}
		if (NC_Connect(client, host_save, port_save, user_save,
		    pass_save) == 0) {
			return (0);
		}
	}
	AG_SetError("Could not reconnect to server.");
	return (-1);
}
/*
 * Read one newline-terminated line from the server into client->read.buf.
 * Each read(2) call requests at most nbytes, and the buffer grows as
 * needed up to RDBUF_MAX.
 *
 * The newline is replaced by a NUL terminator. Returns the total number of
 * bytes received, which includes the newline and any bytes that arrived
 * after it in the same read. Returns -1 with an error set on a read error,
 * on EOF, or if the line would exceed RDBUF_MAX.
 */
long
NC_Read(NC_Session *client, size_t nbytes)
{
	ssize_t rv;
	size_t i, end, newmax;

	client->read.len = 0;
	for (;;) {
		if (client->read.len+nbytes > client->read.maxlen) { /* Grow */
			newmax = client->read.maxlen + nbytes + RDBUF_GROW;
			/*
			 * Commit maxlen only after the buffer really has that
			 * size; a failed grow must not leave maxlen larger
			 * than the allocation.
			 */
			if ((RDBUF_MAX > 0) && (newmax > RDBUF_MAX)) {
				AG_SetError("Illegal server response");
				return (-1);
			}
			client->read.buf = Realloc(client->read.buf, newmax);
			client->read.maxlen = newmax;
		}
		rv = read(client->sock, client->read.buf+client->read.len,
		    nbytes);
		if (rv == -1) {
			if (errno == EINTR) {
				continue;
			}
			AG_SetError("Read error: %s", strerror(errno));
			return (-1);
		} else if (rv == 0) {
			AG_SetError("EOF from server");
			return (-1);
		}
		/* XXX add a timeout; server aborts may cause infinite loop. */
		end = client->read.len + (size_t)rv;
		for (i = client->read.len; i < end; i++) {
			if (client->read.buf[i] == '\n') {
				/* Terminate at the newline itself, not at the
				 * last byte read. */
				client->read.buf[i] = '\0';
				return ((long)end);
			}
		}
		client->read.len = end;
	}
}
/*
 * Read exactly nbytes of raw data from the server into client->read.buf.
 * The buffer grows as needed, up to RDBUF_MAX. Each read(2) requests only
 * the bytes still outstanding, so any data after the requested amount
 * stays on the socket.
 *
 * Returns nbytes (0 immediately if nbytes is 0). Returns -1 with an error
 * set on a read error, on premature EOF, or if nbytes exceeds the buffer
 * limit.
 */
long
NC_ReadBinary(NC_Session *client, size_t nbytes)
{
	ssize_t rv;
	size_t newmax;

	client->read.len = 0;
	if (nbytes == 0) {
		return (0);	/* read(2) of 0 bytes would look like EOF */
	}
	if (nbytes > client->read.maxlen) {			/* Grow */
		/* Check before adding so the new size cannot wrap around. */
		if ((RDBUF_MAX > 0) &&
		    (nbytes > (size_t)RDBUF_MAX - RDBUF_GROW)) {
			AG_SetError("Illegal server response");
			return (-1);
		}
		newmax = nbytes + RDBUF_GROW;
		client->read.buf = Realloc(client->read.buf, newmax);
		client->read.maxlen = newmax;
	}
	while (client->read.len < nbytes) {
		rv = read(client->sock, client->read.buf+client->read.len,
		    nbytes - client->read.len);
		if (rv == -1) {
			if (errno == EINTR) {
				continue;
			}
			AG_SetError("Read error: %s", strerror(errno));
			return (-1);
		} else if (rv == 0) {
			AG_SetError("EOF from server");
			return (-1);
		}
		client->read.len += (size_t)rv;
	}
	return ((long)client->read.len);
}
/*
 * Authenticate using the "password" method. The client announces the
 * method and expects "ok-send-auth", then sends "user:pass" and expects
 * "ok". Returns 0 on success, or -1 with an error set.
 */
static int
Authenticate(NC_Session *client, const char *user, const char *pass)
{
	int ok;

	ok = NC_Write(client, "password\n") != -1 &&
	    NC_Read(client, 32) >= 1 &&
	    strcmp(client->read.buf, "ok-send-auth") == 0;
	if (!ok) {
		AG_SetError("Authentication protocol error: `%s'",
		    client->read.buf);
		return (-1);
	}

	ok = NC_Write(client, "%s:%s\n", user, pass) != -1 &&
	    NC_Read(client, 32) >= 1 &&
	    strcmp(client->read.buf, "ok") == 0;
	if (!ok) {
		AG_SetError("Authentication failed");
		return (-1);
	}
	return (0);
}
/*
 * Negotiate the protocol version. The server's first line is stored in
 * client->server_proto, and client->client_proto is sent in reply. The
 * server's answer must then begin with "auth:". Returns 0 on success, or -1
 * with an error set.
 */
static int
ProtoNegotiate(NC_Session *client)
{
	static const char authPrefix[] = "auth:";

	if (NC_Read(client, 32) < 1) {
		AG_SetError("Server did not respond");
		return (-1);
	}
	Strlcpy(client->server_proto, client->read.buf,
	    sizeof(client->server_proto));

	if (NC_Write(client, "%s\n", client->client_proto) == -1) {
		AG_SetError("Server protocol error");
		return (-1);
	}
	if (NC_Read(client, 32) < 1) {
		AG_SetError("Server protocol error");
		return (-1);
	}
	if (strncmp(client->read.buf, authPrefix, sizeof(authPrefix) - 1)
	    != 0) {
		AG_SetError("Server version mismatch");
		return (-1);
	}
	return (0);
}
/*
 * Establish a connection to the server and authenticate.
 *
 * If any of host, port, user or pass is NULL, login information is read
 * from ~/.<name>rc (or ./<name>rc when getpwuid() is unavailable). Only the
 * first sizeof(fbuf)-1 bytes of the file are read. Each line has the form
 * "key=value" or "key:value"; values cannot contain ':' or '='. Lines
 * beginning with '#' are ignored. Any host, port, user or pass entry found
 * in the file overrides the corresponding argument. An error is returned if
 * any of the four is still missing after the file has been read.
 *
 * On success, the socket and connection parameters are stored in client
 * and 0 is returned. On failure, -1 is returned with an error set.
 */
int
NC_Connect(NC_Session *client, const char *host, const char *port,
    const char *user, const char *pass)
{
	char fbuf[1024];		/* Must outlive the rc-file values */
	const char *cause = NULL;
	struct addrinfo hints, *res, *res0;
	int s, rv, saved_errno = 0;

	/* Look in ~/.rc for the login information. */
	if (host == NULL || port == NULL || user == NULL || pass == NULL) {
		char file[AG_PATHNAME_MAX];
		char *line, *fbufp;
		size_t flen;
		FILE *f;
#if defined(HAVE_GETPWUID) && defined(HAVE_GETUID)
		{
			struct passwd *pwd;

			if ((pwd = getpwuid(getuid())) == NULL) {
				AG_SetError("Who are you?");
				return (-1);
			}
			Strlcpy(file, pwd->pw_dir, sizeof(file));
			Strlcat(file, "/.", sizeof(file));
		}
#else
		Strlcpy(file, "./", sizeof(file));
#endif
		Strlcat(file, client->name, sizeof(file));
		Strlcat(file, "rc", sizeof(file));
		if ((f = fopen(file, "r")) == NULL) {
			AG_SetError("%s: %s", file, strerror(errno));
			return (-1);
		}
		/* Leave room for, and always write, a NUL terminator. */
		flen = fread(fbuf, 1, sizeof(fbuf)-1, f);
		if (ferror(f)) {
			AG_SetError("%s: read error", file);
			fclose(f);
			return (-1);
		}
		fbuf[flen] = '\0';
		fclose(f);

		fbufp = fbuf;
		while ((line = AG_Strsep(&fbufp, "\n")) != NULL) {
			const char *key, *val;

			if (line[0] == '#')
				continue;
			if ((key = AG_Strsep(&line, ":=")) == NULL ||
			    (val = AG_Strsep(&line, ":=")) == NULL)
				continue;
			if (Strcasecmp(key, "host") == 0) host = val;
			if (Strcasecmp(key, "port") == 0) port = val;
			if (Strcasecmp(key, "user") == 0) user = val;
			if (Strcasecmp(key, "pass") == 0) pass = val;
		}
		if (host == NULL || port == NULL || user == NULL ||
		    pass == NULL) {
			AG_SetError("%s: missing host, port, user or pass",
			    file);
			return (-1);
		}
	}

	/* Connect to the server. */
	memset(&hints, 0, sizeof(hints));
	hints.ai_family = PF_UNSPEC;
	hints.ai_socktype = SOCK_STREAM;
	if ((rv = getaddrinfo(host, port, &hints, &res0)) != 0) {
		AG_SetError("%s:%s: %s", host, port, gai_strerror(rv));
		return (-1);
	}
	for (s = -1, res = res0; res != NULL; res = res->ai_next) {
		s = socket(res->ai_family, res->ai_socktype, res->ai_protocol);
		if (s < 0) {
			cause = "socket";
			saved_errno = errno;
			continue;
		}
		if (connect(s, res->ai_addr, res->ai_addrlen) < 0) {
			cause = "connect";
			saved_errno = errno;	/* close() may clobber errno */
			close(s);
			s = -1;
			continue;
		}
		break;
	}
	if (s == -1) {
		if (cause == NULL) {
			AG_SetError("%s:%s: no usable address", host, port);
		} else {
			AG_SetError("%s: %s", cause, strerror(saved_errno));
		}
		goto fail_resolution;
	}
	client->sock = s;
	Strlcpy(client->host, host, sizeof(client->host));
	Strlcpy(client->port, port, sizeof(client->port));
	Strlcpy(client->user, user, sizeof(client->user));
	Strlcpy(client->pass, pass, sizeof(client->pass));

	/* Negotiate the protocol version and authenticate. */
	if (ProtoNegotiate(client) == -1 ||
	    Authenticate(client, user, pass) == -1) {
		AG_SetError("Server error: %s", GetServerError(client));
		goto fail_close;
	}
	freeaddrinfo(res0);
	return (0);
fail_close:
	NC_Disconnect(client);
fail_resolution:
	freeaddrinfo(res0);
	return (-1);
}
/*
 * Close the connection to the server, if one is open, and mark the session
 * as disconnected (sock == -1). Calling it on a closed session does nothing.
 */
void
NC_Disconnect(NC_Session *client)
{
	if (client->sock == -1)
		return;

	close(client->sock);
	client->sock = -1;
}
/*
 * Initialize a client session without connecting.
 *
 * name is stored by reference (not copied). The protocol identification
 * sent to the server is "<name> <ver>", truncated to fit client_proto. A
 * read buffer of RDBUF_INIT bytes is allocated.
 */
void
NC_Init(NC_Session *client, const char *name, const char *ver)
{
	client->name = name;
	client->sock = -1;

	/* No connection parameters are known yet. */
	client->host[0] = '\0';
	client->port[0] = '\0';
	client->user[0] = '\0';
	client->pass[0] = '\0';

	/* Read buffer, later grown on demand by the read routines. */
	client->read.len = 0;
	client->read.maxlen = RDBUF_INIT;
	client->read.buf = Malloc(RDBUF_INIT);

	/* The server's protocol string is learned at connect time. */
	client->server_proto[0] = '\0';
	Strlcpy(client->client_proto, name, sizeof(client->client_proto));
	Strlcat(client->client_proto, " ", sizeof(client->client_proto));
	Strlcat(client->client_proto, ver, sizeof(client->client_proto));
}
/*
 * Close the session's connection and release its read buffer.
 *
 * The buffer pointer and sizes are reset afterwards. A repeated
 * NC_Destroy() therefore passes NULL to Free() rather than a dangling
 * pointer.
 */
void
NC_Destroy(NC_Session *client)
{
	NC_Disconnect(client);
	Free(client->read.buf);
	client->read.buf = NULL;
	client->read.maxlen = 0;
	client->read.len = 0;
}
#endif /* AG_NETWORK */