tftpclient: converting to net sockets
[barrelfish] / lib / tftp / client.c
1 /**
2  * \file
3  * \brief TFTP library
4  */
5
6 /*
7  * Copyright (c) 2015 ETH Zurich.
8  * All rights reserved.
9  *
10  * This file is distributed under the terms in the attached LICENSE file.
11  * If you do not find this file, copies can be found by writing to:
12  * ETH Zurich D-INFK, Universitaetsstrasse 6, CH-8092 Zurich. Attn: Systems Group.
13  */
14
15 #include <stdlib.h>
16 #include <stdio.h>
17
18 #include <barrelfish/barrelfish.h>
19 #include <barrelfish/waitset.h>
20 #include <barrelfish/nameservice_client.h>
21
22 #include <net_sockets/net_sockets.h>
23 #include <arpa/inet.h>
24
25 #include <tftp/tftp.h>
26
27 #include "tftp_internal.h"
28
29
30 // error definitions
31 #define TFTP_ERR_BUSY 1
32 #define TFTP_ERR_DISCONNECTED 1
33 #define TFTP_ERR_NOT_FOUND 1
34 #define TFTP_ERR_ACCESS_DENIED 1
35 #define TFTP_ERR_FILE_EXISTS 1
36
37
38
39 ///< the TFTP client
40 struct tftp_client
41 {
42     /* client state */
43     tftp_st_t state;
44
45     /* connection information */
46     struct in_addr server_ip;
47     uint16_t server_port;
48     tftp_mode_t mode;
49
50     /* request information */
51     uint32_t block;
52     size_t bytes;
53     void *buf;
54     size_t buflen;
55
56     /* connection information */
57     struct net_socket *pcb;
58     void *ppayload;
59 };
60
61
62 struct tftp_client tftp_client;
63
64
65 static errval_t tftp_client_send_data(struct net_socket *socket, uint32_t blockno, void *buf,
66                                       uint32_t length, struct in_addr addr, uint16_t port)
67 {
68     void *payload = tftp_client.ppayload;
69     errval_t err;
70
71     size_t offset = set_opcode(payload, TFTP_OP_DATA);
72     offset += set_block_no(payload + offset, blockno);
73     if (length > TFTP_BLOCKSIZE) {
74         length = TFTP_BLOCKSIZE;
75     }
76
77     memcpy(payload + offset, buf, length);
78     err = net_send_to(socket, payload, length + offset, addr, port);
79     assert(err_is_ok(err));
80     return SYS_ERR_OK;
81 }
82
83
84 /*
85  * ------------------------------------------------------------------------------
86  * Recv Handlers
87  * ------------------------------------------------------------------------------
88  */
89
90 static void tftp_client_handle_write(struct net_socket *socket, void *data,
91     size_t size, struct in_addr ip_address, uint16_t port)
92 {
93     USER_PANIC("NYI");
94     tpft_op_t op = get_opcode(data);
95     uint32_t blockno;
96     switch(op) {
97         case TFTP_OP_ACK :
98             blockno = get_block_no(data, size);
99             if (blockno == TFTP_ERR_INVALID_BUFFER) {
100                 TFTP_DEBUG("failed to decode block number in data packet\n");
101                 break;
102             }
103
104             if (blockno == tftp_client.block) {
105                 if (tftp_client.state == TFTP_ST_LAST_DATA_SENT) {
106                     tftp_client.state = TFTP_ST_CLOSED;
107                     break;
108                 }
109
110                 uint32_t offset = TFTP_BLOCKSIZE * blockno;
111                 uint32_t length = TFTP_BLOCKSIZE;
112                 if (tftp_client.buflen - offset < TFTP_BLOCKSIZE) {
113                     length = tftp_client.buflen - offset;
114                     tftp_client.state = TFTP_ST_LAST_DATA_SENT;
115                 }
116
117                 tftp_client.block++;
118
119                 tftp_client_send_data(socket, tftp_client.block, tftp_client.buf + offset, length,
120                                       ip_address, port);
121                 tftp_client.state = TFTP_ST_DATA_SENT;
122             } else  {
123                 TFTP_DEBUG("got double packet: %u\n", blockno);
124             }
125
126             break;
127         case TFTP_OP_ERROR :
128             TFTP_DEBUG("got a error packet\n");
129             break;
130         default:
131             tftp_client.state = TFTP_ST_ERROR;
132             break;
133     }
134 }
135
136 static void tftp_client_handle_read(struct net_socket *socket, void *data,
137     size_t size, struct in_addr ip_address, uint16_t port)
138 {
139     tpft_op_t op = get_opcode(data);
140     uint32_t blockno;
141     switch(op) {
142         case TFTP_OP_DATA :
143             blockno = get_block_no(data, size);
144             if (blockno == TFTP_ERR_INVALID_BUFFER) {
145                 TFTP_DEBUG("failed to decode block number in data packet\n");
146                 break;
147             }
148
149             if (blockno == tftp_client.block) {
150                 if (size < 5) {
151                     TFTP_DEBUG("too small pbuf lenth\n");
152                 }
153
154                 void *buf = data + 4;
155                 size_t length = size - 4;
156                 TFTP_DEBUG_PACKETS("received block %u of size %lu bytes\n", blockno, length);
157
158                 if (tftp_client.buflen < tftp_client.bytes + length) {
159                     TFTP_DEBUG("too less bufferspace available\n");
160                     length = tftp_client.buflen - tftp_client.bytes;
161                 }
162                 memcpy(tftp_client.buf + tftp_client.bytes, buf, length);
163
164                 int r = tftp_send_ack(socket, blockno, ip_address, port,
165                                       tftp_client.ppayload);
166                 if (r != SYS_ERR_OK) {
167                     tftp_client.state = TFTP_ST_ERROR;
168                     break;
169                 }
170                 tftp_client.state = TFTP_ST_ACK_SENT;
171                 tftp_client.block++;
172                 tftp_client.bytes += length;
173                 if (length < TFTP_BLOCKSIZE) {
174                     TFTP_DEBUG("setting the last ack state\n");
175                     tftp_client.state = TFTP_ST_LAST_ACK_SENT;
176                 }
177             } else  {
178                 TFTP_DEBUG("got double packet: %u\n", blockno);
179                 int r = tftp_send_ack(socket, blockno, ip_address, port,
180                                       tftp_client.ppayload);
181                 if (r != SYS_ERR_OK) {
182                     tftp_client.state = TFTP_ST_ERROR;
183                     break;
184                 }
185                 tftp_client.state = TFTP_ST_ACK_SENT;
186             }
187
188             break;
189         case TFTP_OP_ERROR :
190             TFTP_DEBUG("got a error packet\n");
191             get_error(data, size);
192             tftp_client.state = TFTP_ST_ERROR;
193             break;
194         default:
195             tftp_client.state = TFTP_ST_ERROR;
196             TFTP_DEBUG("unexpected packet\n");
197             break;
198     }
199 }
200
201
202 static void tftp_client_recv_handler(void *user_state, struct net_socket *socket,
203     void *data, size_t size, struct in_addr ip_address, uint16_t port)
204 {
205     switch(tftp_client.state) {
206         case TFTP_ST_WRITE_REQ_SENT:
207         case TFTP_ST_DATA_SENT :
208         case TFTP_ST_LAST_DATA_SENT :
209             tftp_client_handle_write(socket, data, size, ip_address, port);
210             break;
211         case TFTP_ST_READ_REQ_SENT :
212         case TFTP_ST_ACK_SENT :
213             tftp_client_handle_read(socket, data, size, ip_address, port);
214             break;
215         default:
216             TFTP_DEBUG("unexpected state: %u\n", tftp_client.state);
217             break;
218     }
219 }
220
221 static void new_request(char *path, tpft_op_t opcode)
222 {
223     size_t path_length = strlen(path);
224     assert(strlen(path) + 14 < TFTP_MAX_MSGSIZE);
225
226     void *payload = tftp_client.ppayload;
227
228     memset(payload, 0, path_length + 16);
229
230     size_t length = set_opcode(payload, opcode);
231
232     length += snprintf(payload + length, path_length + 1, "%s", path) + 1;
233     length += set_mode(payload + length, tftp_client.mode);
234
235     TFTP_DEBUG("sending udp payload of %lu bytes\n", length);
236
237     errval_t err;
238     err = net_send_to(tftp_client.pcb, payload, length, tftp_client.server_ip, tftp_client.server_port);
239     if (err != SYS_ERR_OK) {
240         TFTP_DEBUG("send failed\n");
241     }
242 }
243
244
245 errval_t tftp_client_write_file(char *name, void *buf, size_t buflen)
246 {
247     if (tftp_client.state < TFTP_ST_IDLE) {
248         TFTP_DEBUG("attempt to read file with no connection");
249         return TFTP_ERR_DISCONNECTED;
250     }
251
252     if (tftp_client.state > TFTP_ST_IDLE) {
253         return TFTP_ERR_BUSY;
254     }
255
256     tftp_client.buf = buf;
257     tftp_client.buflen = buflen;
258     tftp_client.block = 1;
259     tftp_client.state = TFTP_ST_WRITE_REQ_SENT;
260     tftp_client.bytes = 0;
261
262     return SYS_ERR_OK;
263 }
264
265 errval_t tftp_client_read_file(char *path, void *buf, size_t buflen, size_t *ret_size)
266 {
267     if (tftp_client.state < TFTP_ST_IDLE) {
268         TFTP_DEBUG("attempt to read file with no connection");
269         return TFTP_ERR_DISCONNECTED;
270     }
271
272     if (tftp_client.state > TFTP_ST_IDLE) {
273         return TFTP_ERR_BUSY;
274     }
275
276     tftp_client.buf = buf;
277     tftp_client.buflen = buflen;
278     tftp_client.block = 1;
279     tftp_client.state = TFTP_ST_READ_REQ_SENT;
280     tftp_client.bytes = 0;
281
282     assert(tftp_client.pcb);
283
284     TFTP_DEBUG("read request of file %s\n", path);
285
286     new_request(path, TFTP_OP_READ_REQ);
287
288     while(tftp_client.state > TFTP_ST_ERROR) {
289         event_dispatch(get_default_waitset());
290     }
291
292     TFTP_DEBUG("tftp read file done.\n");
293
294     if (ret_size) {
295         *ret_size = tftp_client.bytes;
296     }
297
298     if (tftp_client.state == TFTP_ST_ERROR) {
299         tftp_client.state = TFTP_ST_IDLE;
300         return -1;
301     }
302
303     tftp_client.state = TFTP_ST_IDLE;
304
305     return SYS_ERR_OK;
306 }
307
308
309
310 /**
311  * \brief attempts to initialize a new TFTP connection to a server
312  *
313  * \returns SYS_ERR_OK on success
314  *          TFTP_ERR_* on failure
315  */
316 errval_t tftp_client_connect(char *ip, uint16_t port)
317 {
318     switch(tftp_client.state) {
319         case TFTP_ST_INVALID :
320             net_sockets_init();
321             tftp_client.pcb = net_udp_socket();
322             TFTP_DEBUG("new connection from uninitialized state\n");
323             break;
324         case TFTP_ST_CLOSED :
325             TFTP_DEBUG("new connection from closed state\n");
326             tftp_client.pcb = net_udp_socket();
327             break;
328         default:
329             TFTP_DEBUG("connection already established, cannot connect\n");
330             return TFTP_ERR_BUSY;
331     }
332
333     if (tftp_client.pcb == NULL) {
334         return LIB_ERR_MALLOC_FAIL;
335     }
336
337     tftp_client.server_port = port;
338
339     int ret = inet_aton(ip, &tftp_client.server_ip);
340     if (ret == 0) {
341         TFTP_DEBUG("Invalid IP addr: %s\n", ip);
342         return 1;
343     }
344
345     TFTP_DEBUG("connecting to %s:%" PRIu16 "\n", ip, port);
346
347     errval_t r;
348     r = net_bind(tftp_client.pcb, (struct in_addr){(INADDR_ANY)}, 0);
349     if (r != SYS_ERR_OK) {
350         USER_PANIC("UDP bind failed");
351     }
352     debug_printf("bound to %d\n", tftp_client.pcb->bound_port);
353
354     // r = net_connect(tftp_client.pcb, tftp_client.server_ip, tftp_client.server_port, NULL);
355     // if (r != SYS_ERR_OK) {
356     //     USER_PANIC("UDP connect failed");
357     // }
358
359     TFTP_DEBUG("registering recv handler\n");
360     net_recv(tftp_client.pcb, tftp_client_recv_handler);
361
362     tftp_client.state = TFTP_ST_IDLE;
363     tftp_client.mode = TFTP_MODE_OCTET;
364     tftp_client.ppayload = net_alloc(TFTP_MAX_MSGSIZE);
365     TFTP_DEBUG("all set up. connection idle\n");
366     return SYS_ERR_OK;
367 }
368
369 errval_t tftp_client_disconnect(void)
370 {
371     net_free(tftp_client.ppayload);
372     net_close(tftp_client.pcb);
373     tftp_client.state = TFTP_ST_CLOSED;
374     return SYS_ERR_OK;
375 }