tftpclient: converting to net sockets
[barrelfish] / lib / net_sockets / net_sockets.c
1 #include <stdio.h>
2 #include <stdlib.h>
3 #include <string.h>
4 #include <barrelfish/barrelfish.h>
5 #include <barrelfish/nameservice_client.h>
6 #include <if/net_sockets_defs.h>
7 #include <net_sockets/net_sockets.h>
8 #include <arpa/inet.h>
9
10 #include <barrelfish/waitset_chan.h>
11 #include <barrelfish/waitset.h>
12
13 #include <devif/queue_interface.h>
14 #include <devif/backends/descq.h>
15
16
17 static struct net_sockets_binding *binding;
18 static bool bound_done = false;
19
20 static struct capref buffer_frame;
21 struct descq* descq_queue;
22 static void *buffer_start;
23 static regionid_t regionid;
24 static uint64_t queue_id;
25
26 #define NO_OF_BUFFERS 128
27 #define BUFFER_SIZE 16384
28
29 void *buffers[NO_OF_BUFFERS];
30 uint64_t next_free, next_used;
31 struct net_socket *sockets = NULL;
32
33 /// Dequeue the element from the net_socket queue
34 static void dequeue(struct net_socket **queue,
35                             struct net_socket *element)
36 {
37     if (element->next == element) {
38         assert(element->prev == element);
39         assert(*queue == element);
40         *queue = NULL;
41     } else {
42         element->prev->next = element->next;
43         element->next->prev = element->prev;
44         if (*queue == element) {
45             *queue = element->next;
46         }
47     }
48     element->prev = element->next = NULL;
49 }
50
51 /// Enqueue the element on the net_socket queue
52 static void enqueue(struct net_socket **queue,
53                             struct net_socket *element)
54 {
55     if (*queue == NULL) {
56         *queue = element;
57         element->next = element->prev = element;
58     } else {
59         element->next = *queue;
60         element->prev = (*queue)->prev;
61         element->next->prev = element;
62         element->prev->next = element;
63     }
64 }
65
66 static struct net_socket * allocate_socket(uint32_t descriptor)
67 {
68     struct net_socket *socket;
69
70     socket = malloc(sizeof(struct net_socket));
71     assert(socket);
72
73     socket->descriptor = descriptor;
74     socket->received = NULL;
75     socket->connected = NULL;
76     socket->accepted = NULL;
77     socket->user_state = NULL;
78     socket->bound_address.s_addr = 0;
79     socket->bound_port = 0;
80     socket->connected_address.s_addr = 0;
81     socket->connected_port = 0;
82     enqueue(&sockets, socket);
83     return socket;
84 }
85
86 struct net_socket * net_udp_socket(void)
87 {
88     errval_t err;
89     struct net_socket *socket;
90     uint32_t descriptor;
91
92     err = binding->rpc_tx_vtbl.new_udp_socket(binding, &descriptor);
93     assert(err_is_ok(err));
94
95     socket = allocate_socket(descriptor);
96     return socket;
97 }
98
99 struct net_socket * net_tcp_socket(void)
100 {
101     errval_t err;
102     struct net_socket *socket;
103     uint32_t descriptor;
104
105     err = binding->rpc_tx_vtbl.new_tcp_socket(binding, &descriptor);
106     assert(err_is_ok(err));
107
108     socket = allocate_socket(descriptor);
109     return socket;
110 }
111
112 static struct net_socket * get_socket(uint32_t descriptor)
113 {
114     struct net_socket *socket = sockets;
115     
116     while (socket) {
117         if (socket->descriptor == descriptor)
118             return socket;
119         socket = socket->next;
120         if (socket == sockets)
121             break;
122     }
123     debug_printf("%s: socket not found %d %p\n", __func__, descriptor, __builtin_return_address(0));
124     assert(0);
125     return NULL;
126 }
127
128 void net_set_user_state(struct net_socket *socket, void *user_state)
129 {
130     socket->user_state = user_state;
131 }
132
133 void net_close(struct net_socket *socket)
134 {
135     errval_t err, error;
136
137     // debug_printf("%s(%d):\n", __func__, socket->descriptor);
138     err = binding->rpc_tx_vtbl.delete_socket(binding, socket->descriptor, &error);
139     assert(err_is_ok(err));
140     assert(err_is_ok(error));
141     dequeue(&sockets, socket);
142     free(socket);
143     // debug_printf("%s: %ld:%p  %ld:%p\n", __func__, next_free, buffers[next_free], next_used, buffers[next_used]);
144 }
145
146 errval_t net_bind(struct net_socket *socket, struct in_addr ip_address, uint16_t port)
147 {
148     errval_t err, error;
149     uint16_t bound_port;
150
151     err = binding->rpc_tx_vtbl.bind(binding, socket->descriptor, ip_address.s_addr, port, &error, &bound_port);
152     assert(err_is_ok(err));
153     socket->bound_address = ip_address;
154     socket->bound_port = bound_port;
155
156     return error;
157 }
158
159 errval_t net_listen(struct net_socket *socket, uint8_t backlog)
160 {
161     errval_t err, error;
162
163     err = binding->rpc_tx_vtbl.listen(binding, socket->descriptor, backlog, &error);
164     assert(err_is_ok(err));
165
166     return error;
167 }
168
169 void * net_alloc(size_t size)
170 {
171     void *buffer = buffers[next_free];
172     assert(buffer);
173     buffers[next_free] = NULL;
174     next_free = (next_free + 1) % NO_OF_BUFFERS;
175     // debug_printf("%s: %p:%zd  %ld:%p  %ld:%p  %p\n", __func__, buffer + sizeof(struct net_buffer), size, next_free, buffers[next_free], next_used, buffers[next_used], __builtin_return_address(0));
176     return buffer + sizeof(struct net_buffer);
177 }
178
179 void net_free(void *buffer)
180 {
181     assert(!buffers[next_used]);
182     buffers[next_used] = buffer - sizeof(struct net_buffer);
183     next_used = (next_used + 1) % NO_OF_BUFFERS;
184     // debug_printf("%s: %p  %ld:%p  %ld:%p  %p\n", __func__, buffer, next_free, buffers[next_free], next_used, buffers[next_used], __builtin_return_address(0));
185 }
186
187 errval_t net_send(struct net_socket *socket, void *data, size_t size)
188 {
189     errval_t err, error;
190
191     void *buffer = data - sizeof(struct net_buffer);
192     struct net_buffer *nb = buffer;
193     // debug_printf("%s(%d): %ld -> %p\n", __func__, socket->descriptor, size, buffer);
194
195     nb->size = size;
196     nb->descriptor = socket->descriptor;
197     nb->host_address.s_addr = INADDR_NONE;
198     nb->port = 0;
199     // debug_printf("%s: enqueue 2 %lx:%ld\n", __func__, buffer - buffer_start, sizeof(struct net_buffer) + size);
200     err = devq_enqueue((struct devq *)descq_queue, regionid, buffer - buffer_start, sizeof(struct net_buffer) + size,
201                        0, 0, 2);
202     assert(err_is_ok(err));
203     err = devq_notify((struct devq *)descq_queue);
204     assert(err_is_ok(err));
205
206     error = SYS_ERR_OK;
207     return error;
208 }
209
210 errval_t net_send_to(struct net_socket *socket, void *data, size_t size, struct in_addr ip_address, uint16_t port)
211 {
212     errval_t err, error;
213
214     void *buffer = data - sizeof(struct net_buffer);
215     struct net_buffer *nb = buffer;
216     // debug_printf("%s(%d): %ld -> %p\n", __func__, descriptor, size, buffer);
217
218     nb->size = size;
219     nb->descriptor = socket->descriptor;
220     nb->host_address = ip_address;
221     nb->port = port;
222     // debug_printf("%s: enqueue 2 %lx:%ld\n", __func__, buffer - buffer_start, sizeof(struct net_buffer) + size);
223     err = devq_enqueue((struct devq *)descq_queue, regionid, buffer - buffer_start, sizeof(struct net_buffer) + size,
224                        0, 0, 2);
225     assert(err_is_ok(err));
226     err = devq_notify((struct devq *)descq_queue);
227     assert(err_is_ok(err));
228
229     error = SYS_ERR_OK;
230     return error;
231 }
232
233 errval_t net_connect(struct net_socket *socket, struct in_addr ip_address, uint16_t port, net_connected_callback_t cb)
234 {
235     errval_t err, error;
236
237     socket->connected = cb;
238     err = binding->rpc_tx_vtbl.connect(binding, socket->descriptor, ip_address.s_addr, port, &error);
239     assert(err_is_ok(err));
240     assert(err_is_ok(error));
241
242     return error;
243 }
244
245 static void net_connected(struct net_sockets_binding *b, uint32_t descriptor, errval_t error, uint32_t connected_address, uint16_t connected_port)
246 {
247     struct net_socket *socket = get_socket(descriptor);
248     assert(socket->descriptor == descriptor);
249     assert(err_is_ok(error));
250
251     socket->connected_address.s_addr = connected_address;
252     socket->connected_port = connected_port;
253     assert(socket->connected);
254     socket->connected(socket->user_state, socket);
255 }
256
257 void net_accept(struct net_socket *socket, net_accepted_callback_t cb)
258 {
259     socket->accepted = cb;
260 }
261
262 static void net_accepted(uint32_t descriptor, uint32_t accepted_descriptor, struct in_addr host_address, uint16_t port)
263 {
264     struct net_socket *socket = get_socket(descriptor);
265     assert(socket->descriptor == descriptor);
266
267     struct net_socket *accepted_socket = allocate_socket(accepted_descriptor);
268     accepted_socket->connected_address = host_address;
269     accepted_socket->connected_port = port;
270     socket->accepted(socket->user_state, accepted_socket);
271 }
272
273
274 void net_recv(struct net_socket *socket, net_received_callback_t cb)
275 {
276     socket->received = cb;
277 }
278
279 void net_set_sent(struct net_socket *socket, net_sent_callback_t cb)
280 {
281     socket->sent = cb;
282 }
283
284 static void bind_cb(void *st, errval_t err, struct net_sockets_binding *b)
285 {
286     binding = b;
287     net_sockets_rpc_client_init(binding);
288     bound_done = true;
289 }
290
291 static void alloc_mem(struct capref *frame, void** virt, size_t size)
292 {
293     errval_t r;
294     vregion_flags_t flags;
295
296     r = frame_alloc(frame, size, NULL);
297     assert(err_is_ok(r));
298
299     flags = VREGION_FLAGS_READ_WRITE;
300     r = vspace_map_one_frame_attr(virt, size, *frame, flags, NULL, NULL);
301     assert(err_is_ok(r));
302     memset(*virt, 0, size);
303 }
304
305 static errval_t q_notify(struct descq* q)
306 {
307     assert(descq_queue == q);
308     errval_t err = SYS_ERR_OK;
309     //errval_t err2 = SYS_ERR_OK;
310     regionid_t rid;
311     genoffset_t offset;
312     genoffset_t length;
313     genoffset_t valid_data;
314     genoffset_t valid_length;
315     uint64_t flags;
316     bool notify = 0;
317
318     // debug_printf("%s: \n", __func__);
319     for (;;) {
320         err = devq_dequeue((struct devq *)descq_queue, &rid, &offset, &length,
321                            &valid_data, &valid_length, &flags);
322         if (err_is_fail(err)) {
323             break;
324         } else {
325             // debug_printf("%s: dequeue %lx:%ld %ld\n", __func__, offset, length, flags);
326             void *buffer = buffer_start + offset;
327             struct net_buffer *nb = buffer;
328             // debug_printf("%s: dequeue %lx:%ld %ld  %p socket:%d asocket:%d\n", __func__, offset, length, flags, nb, nb->descriptor, nb->accepted_descriptor);
329             struct net_socket *socket = get_socket(nb->descriptor);
330             void *shb_data = buffer + sizeof(struct net_buffer);
331
332             if (flags == 1) { // receiving buffer
333                 // debug_printf("%s: enqueue 1> %lx:%d\n", __func__, offset, nb->size);
334                 if (nb->accepted_descriptor) { // accept
335                     net_accepted(nb->descriptor, nb->accepted_descriptor, nb->host_address, nb->port);
336
337                     err = devq_enqueue((struct devq *)descq_queue, rid, offset, length, 0, 0, 1);
338                     assert(err_is_ok(err));
339                     notify = 1;
340                 } else {  // receive
341                     if (socket->received) {
342                         // debug_printf("net_received(%d): %d\n", nb->descriptor, nb->size);
343                         socket->received(socket->user_state, socket, shb_data, nb->size, nb->host_address, nb->port);
344                     // debug_printf("%s: enqueue 1< %lx:%d\n", __func__, offset, 2048);
345                     }
346
347                     err = devq_enqueue((struct devq *)descq_queue, rid, offset, length, 0, 0, 1);
348                     assert(err_is_ok(err));
349                     notify = 1;
350                 }
351             } else if (flags == 2) { // transmitting buffer
352                 if (socket->sent) {
353                     // debug_printf("%s: dequeue %lx:%ld %p\n", __func__, offset, length, shb_data);
354                     socket->sent(socket->user_state, socket, shb_data, nb->size);
355                 }
356     // debug_printf("%s: %ld:%p  %ld:%p\n", __func__, next_free, buffers[next_free], next_used, buffers[next_used]);
357                 // assert(!buffers[next_used]);
358                 // buffers[next_used] = buffer_start + offset;
359                 // next_used = (next_used + 1) % NO_OF_BUFFERS;
360             }
361         }
362     }
363
364     if (notify) {
365         // debug_printf("notify>\n");
366         err = devq_notify((struct devq *)descq_queue);
367         assert(err_is_ok(err));
368         // debug_printf("notify<\n");
369     }
370
371     return SYS_ERR_OK;
372 }
373
374 errval_t net_sockets_init(void)
375 {
376     errval_t err;
377     iref_t iref;
378
379     memset(buffers, 0, sizeof(buffers));
380     next_free = 0;
381     next_used = 0;
382
383     alloc_mem(&buffer_frame, &buffer_start, 2 * BUFFER_SIZE * NO_OF_BUFFERS);
384  
385     struct descq_func_pointer f;
386     f.notify = q_notify;
387
388     debug_printf("net socket client started \n");
389     err = descq_create(&descq_queue, DESCQ_DEFAULT_SIZE, "net_sockets_queue",
390                        false, true, true, &queue_id, &f);
391     assert(err_is_ok(err));
392
393     err = nameservice_blocking_lookup("net_sockets", &iref);
394     assert(err_is_ok(err));
395     err = net_sockets_bind(iref, bind_cb, NULL, get_default_waitset(), IDC_BIND_FLAGS_DEFAULT);
396     assert(err_is_ok(err));
397
398     while (!bound_done) {
399         event_dispatch(get_default_waitset());
400     }
401     debug_printf("%s: initialized\n", __func__);
402     binding->rx_vtbl.connected = net_connected;
403     // binding->rx_vtbl.accepted = net_accepted;
404
405     err = binding->rpc_tx_vtbl.register_queue(binding, queue_id);
406     assert(err_is_ok(err));
407
408     err = devq_register((struct devq *)descq_queue, buffer_frame, &regionid);
409     assert(err_is_ok(err));
410
411     for (int i = 0; i < NO_OF_BUFFERS; i++) {
412         err = devq_enqueue((struct devq *)descq_queue, regionid, i * BUFFER_SIZE, BUFFER_SIZE,
413                            0, 0, 1);
414         if (!err_is_ok(err))
415             debug_printf("%s: %d:%d\n", __func__, i, NO_OF_BUFFERS);
416         assert(err_is_ok(err));
417         buffers[i] = i * BUFFER_SIZE + buffer_start + BUFFER_SIZE * NO_OF_BUFFERS;
418     }
419
420     err = devq_notify((struct devq *)descq_queue);
421     assert(err_is_ok(err));
422
423     return SYS_ERR_OK;
424 }