fb1a8bb6c008553861163cc98fb95c8cc81c7201
[barrelfish] / lib / tftp / client.c
1 /**
2  * \file
3  * \brief TFTP library
4  */
5
6 /*
7  * Copyright (c) 2015 ETH Zurich.
8  * All rights reserved.
9  *
10  * This file is distributed under the terms in the attached LICENSE file.
11  * If you do not find this file, copies can be found by writing to:
12  * ETH Zurich D-INFK, Universitaetsstrasse 6, CH-8092 Zurich. Attn: Systems Group.
13  */
14
15 #include <stdlib.h>
16 #include <stdio.h>
17
18 #include <barrelfish/barrelfish.h>
19 #include <barrelfish/waitset.h>
20 #include <barrelfish/nameservice_client.h>
21
22 #include <lwip/udp.h>
23 #include <lwip/init.h>
24
25 #include <tftp/tftp.h>
26
27 #include "tftp_internal.h"
28
29
30 // error definitions
31 #define TFTP_ERR_BUSY 1
32 #define TFTP_ERR_DISCONNECTED 1
33 #define TFTP_ERR_NOT_FOUND 1
34 #define TFTP_ERR_ACCESS_DENIED 1
35 #define TFTP_ERR_FILE_EXISTS 1
36
37
38
39 ///< the TFTP client
40 struct tftp_client
41 {
42     /* client state */
43     tftp_st_t state;
44
45     /* connection information */
46     struct ip_addr server_ip;
47     uint16_t server_port;
48     tftp_mode_t mode;
49
50     /* request information */
51     uint32_t block;
52     size_t bytes;
53     void *buf;
54     size_t buflen;
55
56     /* connection information */
57     struct udp_pcb *pcb;
58     struct pbuf *p;
59     void *ppayload;
60     struct udp_pcb *rcv_pcb;
61 };
62
63
64 struct tftp_client tftp_client;
65
66
67 static errval_t tftp_client_send_data(struct udp_pcb *pcb, uint32_t blockno, void *buf,
68                                       uint32_t length, struct ip_addr *addr, u16_t port,
69                                       struct pbuf *p)
70 {
71     p->len = TFTP_MAX_MSGSIZE;
72     p->tot_len = TFTP_MAX_MSGSIZE;
73     p->payload = tftp_client.ppayload;
74
75     size_t offset = set_opcode(p->payload, TFTP_OP_DATA);
76     offset += set_block_no(p->payload + offset, blockno);
77     if (length > TFTP_BLOCKSIZE) {
78         length = TFTP_BLOCKSIZE;
79     }
80
81     memcpy(p->payload + offset, buf, length);
82     p->len = (uint16_t)length + offset;
83     p->tot_len = (uint16_t)length + offset;
84
85     int r = udp_sendto(pcb, p, addr, port);
86     if (r != ERR_OK) {
87         USER_PANIC("send failed");
88     }
89
90     return SYS_ERR_OK;
91 }
92
93
94 /*
95  * ------------------------------------------------------------------------------
96  * Recv Handlers
97  * ------------------------------------------------------------------------------
98  */
99
100 static void tftp_client_handle_write(struct udp_pcb *pcb, struct pbuf *pbuf,
101                                      struct ip_addr *addr, u16_t port)
102 {
103     USER_PANIC("NYI");
104     tpft_op_t op = get_opcode(pbuf->payload);
105     uint32_t blockno;
106     switch(op) {
107         case TFTP_OP_ACK :
108             blockno = get_block_no(pbuf->payload, pbuf->len);
109             assert(pbuf->len == pbuf->tot_len);
110             if (blockno == TFTP_ERR_INVALID_BUFFER) {
111                 TFTP_DEBUG("failed to decode block number in data packet\n");
112                 break;
113             }
114
115             if (blockno == tftp_client.block) {
116                 if (tftp_client.state == TFTP_ST_LAST_DATA_SENT) {
117                     tftp_client.state = TFTP_ST_CLOSED;
118                     break;
119                 }
120
121                 uint32_t offset = TFTP_BLOCKSIZE * blockno;
122                 uint32_t length = TFTP_BLOCKSIZE;
123                 if (tftp_client.buflen - offset < TFTP_BLOCKSIZE) {
124                     length = tftp_client.buflen - offset;
125                     tftp_client.state = TFTP_ST_LAST_DATA_SENT;
126                 }
127
128                 tftp_client.block++;
129
130                 tftp_client_send_data(pcb, tftp_client.block, tftp_client.buf + offset, length,
131                                       addr, port, tftp_client.p);
132                 tftp_client.state = TFTP_ST_DATA_SENT;
133             } else  {
134                 TFTP_DEBUG("got double packet: %u\n", blockno);
135             }
136
137             break;
138         case TFTP_OP_ERROR :
139             TFTP_DEBUG("got a error packet\n");
140             break;
141         default:
142             tftp_client.state = TFTP_ST_ERROR;
143             break;
144     }
145
146     pbuf_free(pbuf);
147 }
148
149 static void tftp_client_handle_read(struct udp_pcb *pcb, struct pbuf *pbuf,
150                                     struct ip_addr *addr, u16_t port)
151 {
152     tpft_op_t op = get_opcode(pbuf->payload);
153     uint32_t blockno;
154     switch(op) {
155         case TFTP_OP_DATA :
156             blockno = get_block_no(pbuf->payload, pbuf->len);
157             assert(pbuf->len == pbuf->tot_len);
158             if (blockno == TFTP_ERR_INVALID_BUFFER) {
159                 TFTP_DEBUG("failed to decode block number in data packet\n");
160                 break;
161             }
162
163             if (blockno == tftp_client.block) {
164                 if (pbuf->len < 5) {
165                     TFTP_DEBUG("too small pbuf lenth\n");
166                 }
167
168                 void *buf = pbuf->payload + 4;
169                 size_t length = pbuf->len - 4;
170                 TFTP_DEBUG_PACKETS("received block %u of size %lu bytes\n", blockno, length);
171
172                 if (tftp_client.buflen < tftp_client.bytes + length) {
173                     TFTP_DEBUG("too less bufferspace available\n");
174                     length = tftp_client.buflen - tftp_client.bytes;
175                 }
176                 memcpy(tftp_client.buf + tftp_client.bytes, buf, length);
177
178                 int r = tftp_send_ack(pcb, blockno, addr, port, tftp_client.p,
179                                       tftp_client.ppayload);
180                 if (r != ERR_OK) {
181                     tftp_client.state = TFTP_ST_ERROR;
182                     break;
183                 }
184                 tftp_client.state = TFTP_ST_ACK_SENT;
185                 tftp_client.block++;
186                 tftp_client.bytes += length;
187                 if (length < TFTP_BLOCKSIZE) {
188                     TFTP_DEBUG("setting the last ack state\n");
189                     tftp_client.state = TFTP_ST_LAST_ACK_SENT;
190                 }
191             } else  {
192                 TFTP_DEBUG("got double packet: %u\n", blockno);
193                 int r = tftp_send_ack(pcb, blockno, addr, port, tftp_client.p,
194                                       tftp_client.ppayload);
195                 if (r != ERR_OK) {
196                     tftp_client.state = TFTP_ST_ERROR;
197                     break;
198                 }
199                 tftp_client.state = TFTP_ST_ACK_SENT;
200             }
201
202             break;
203         case TFTP_OP_ERROR :
204             TFTP_DEBUG("got a error packet\n");
205             get_error(pbuf->payload, pbuf->len);
206             tftp_client.state = TFTP_ST_ERROR;
207             break;
208         default:
209             tftp_client.state = TFTP_ST_ERROR;
210             TFTP_DEBUG("unexpected packet\n");
211             break;
212     }
213
214     pbuf_free(pbuf);
215 }
216
217
218 static void tftp_client_recv_handler(void *arg, struct udp_pcb *pcb, struct pbuf *pbuf,
219                              struct ip_addr *addr, u16_t port)
220 {
221     switch(tftp_client.state) {
222         case TFTP_ST_WRITE_REQ_SENT:
223         case TFTP_ST_DATA_SENT :
224         case TFTP_ST_LAST_DATA_SENT :
225             tftp_client_handle_write(pcb, pbuf, addr, port);
226             break;
227         case TFTP_ST_READ_REQ_SENT :
228         case TFTP_ST_ACK_SENT :
229             tftp_client_handle_read(pcb, pbuf, addr, port);
230             break;
231         default:
232             TFTP_DEBUG("unexpected state: %u\n", tftp_client.state);
233             break;
234     }
235 }
236
237 static void new_request(char *path, tpft_op_t opcode)
238 {
239     size_t path_length = strlen(path);
240     assert(strlen(path) + 14 < TFTP_MAX_MSGSIZE);
241
242     struct pbuf *p = tftp_client.p;
243     assert(p);
244
245     p->len = TFTP_MAX_MSGSIZE;
246     p->tot_len = TFTP_MAX_MSGSIZE;
247     p->payload = tftp_client.ppayload;
248
249     memset(p->payload, 0, path_length + 16);
250
251     size_t length = set_opcode(p->payload, opcode);
252
253     length += snprintf(p->payload + length, path_length + 1, "%s", path) + 1;
254     length += set_mode(p->payload + length, tftp_client.mode);
255
256     p->len = (uint16_t)length;
257     p->tot_len = (uint16_t)length;
258
259     TFTP_DEBUG("sending udp payload of %lu bytes\n", length);
260
261
262
263     int r = udp_send(tftp_client.pcb, p);
264     if (r != ERR_OK) {
265         TFTP_DEBUG("send failed\n");
266     }
267 }
268
269
270 errval_t tftp_client_write_file(char *name, void *buf, size_t buflen)
271 {
272     if (tftp_client.state < TFTP_ST_IDLE) {
273         TFTP_DEBUG("attempt to read file with no connection");
274         return TFTP_ERR_DISCONNECTED;
275     }
276
277     if (tftp_client.state > TFTP_ST_IDLE) {
278         return TFTP_ERR_BUSY;
279     }
280
281     tftp_client.buf = buf;
282     tftp_client.buflen = buflen;
283     tftp_client.block = 1;
284     tftp_client.state = TFTP_ST_WRITE_REQ_SENT;
285     tftp_client.bytes = 0;
286
287     return SYS_ERR_OK;
288 }
289
290 errval_t tftp_client_read_file(char *path, void *buf, size_t buflen, size_t *ret_size)
291 {
292     if (tftp_client.state < TFTP_ST_IDLE) {
293         TFTP_DEBUG("attempt to read file with no connection");
294         return TFTP_ERR_DISCONNECTED;
295     }
296
297     if (tftp_client.state > TFTP_ST_IDLE) {
298         return TFTP_ERR_BUSY;
299     }
300
301     tftp_client.buf = buf;
302     tftp_client.buflen = buflen;
303     tftp_client.block = 1;
304     tftp_client.state = TFTP_ST_READ_REQ_SENT;
305     tftp_client.bytes = 0;
306
307     assert(tftp_client.pcb);
308
309     TFTP_DEBUG("read request of file %s\n", path);
310
311     new_request(path, TFTP_OP_READ_REQ);
312
313     while(tftp_client.state > TFTP_ST_ERROR) {
314         event_dispatch(get_default_waitset());
315     }
316
317     TFTP_DEBUG("tftp read file done.\n");
318
319     if (ret_size) {
320         *ret_size = tftp_client.bytes;
321     }
322
323     if (tftp_client.state == TFTP_ST_ERROR) {
324         tftp_client.state = TFTP_ST_IDLE;
325         return -1;
326     }
327
328     tftp_client.state = TFTP_ST_IDLE;
329
330     return SYS_ERR_OK;
331 }
332
333
334
335 /**
336  * \brief attempts to initialize a new TFTP connection to a server
337  *
338  * \returns SYS_ERR_OK on success
339  *          TFTP_ERR_* on failure
340  */
341 errval_t tftp_client_connect(char *ip, uint16_t port)
342 {
343     switch(tftp_client.state) {
344         case TFTP_ST_INVALID :
345             lwip_init_auto();
346             tftp_client.pcb = udp_new();
347             TFTP_DEBUG("new connection from uninitialized state\n");
348             break;
349         case TFTP_ST_CLOSED :
350             TFTP_DEBUG("new connection from closed state\n");
351             tftp_client.pcb = udp_new();
352             break;
353         default:
354             TFTP_DEBUG("connection already established, cannot connect\n");
355             return TFTP_ERR_BUSY;
356     }
357
358     if (tftp_client.pcb == NULL) {
359         return LIB_ERR_MALLOC_FAIL;
360     }
361
362     tftp_client.server_port = port;
363
364     struct in_addr peer_ip_gen;
365     int ret = inet_aton(ip, &peer_ip_gen);
366     if (ret == 0) {
367         TFTP_DEBUG("Invalid IP addr: %s\n", ip);
368         return 1;
369     }
370     tftp_client.server_ip.addr = peer_ip_gen.s_addr;
371
372     TFTP_DEBUG("connecting to %s:%" PRIu16 "\n", ip, port);
373     tftp_client.rcv_pcb = udp_new();
374
375     int r = udp_bind(tftp_client.rcv_pcb, IP_ADDR_ANY, 0);
376     if (r != ERR_OK) {
377         USER_PANIC("UDP bind failed");
378     }
379
380     r = udp_connect(tftp_client.pcb, &tftp_client.server_ip, tftp_client.server_port);
381     if (r != ERR_OK) {
382         USER_PANIC("UDP connect failed");
383     }
384     tftp_client.pcb->local_port = tftp_client.rcv_pcb->local_port;
385
386     TFTP_DEBUG("registering recv handler\n");
387     udp_recv(tftp_client.pcb, tftp_client_recv_handler, NULL);
388     udp_recv(tftp_client.rcv_pcb, tftp_client_recv_handler, NULL);
389
390     tftp_client.state = TFTP_ST_IDLE;
391     tftp_client.mode = TFTP_MODE_OCTET;
392     tftp_client.p = pbuf_alloc(PBUF_TRANSPORT, TFTP_MAX_MSGSIZE, PBUF_POOL);
393     if (!tftp_client.p) {
394         USER_PANIC("no buffer");
395     }
396     tftp_client.ppayload = tftp_client.p->payload;
397     TFTP_DEBUG("all set up. connection idle\n");
398     return SYS_ERR_OK;
399 }
400
401 errval_t tftp_client_disconnect(void)
402 {
403     pbuf_free(tftp_client.p);
404     udp_remove(tftp_client.pcb);
405     udp_remove(tftp_client.rcv_pcb);
406     tftp_client.state = TFTP_ST_CLOSED;
407     return SYS_ERR_OK;
408 }
409
410