bb7302bc8f0f5f045884536a56c34c8b708d6148
[barrelfish] / lib / net_sockets / net_sockets.c
1 #include <stdio.h>
2 #include <stdlib.h>
3 #include <string.h>
4 #include <barrelfish/barrelfish.h>
5 #include <barrelfish/nameservice_client.h>
6 #include <if/net_sockets_defs.h>
7 #include <net_sockets/net_sockets.h>
8
9 #include <barrelfish/waitset_chan.h>
10 #include <barrelfish/waitset.h>
11
12 #include <devif/queue_interface.h>
13 #include <devif/backends/descq.h>
14
15
16 static struct net_sockets_binding *binding;
17 static bool bound_done = false;
18
19 static struct capref buffer_frame;
20 struct descq* descq_queue;
21 static void *buffer_start;
22 static regionid_t regionid;
23 static uint64_t queue_id;
24
25 #define NO_OF_BUFFERS 128
26 #define BUFFER_SIZE 16384
27
28 void *buffers[NO_OF_BUFFERS];
29 uint64_t next_free, next_used;
30 struct net_socket *sockets = NULL;
31
32 /// Dequeue the element from the net_socket queue
33 static void dequeue(struct net_socket **queue,
34                             struct net_socket *element)
35 {
36     if (element->next == element) {
37         assert(element->prev == element);
38         assert(*queue == element);
39         *queue = NULL;
40     } else {
41         element->prev->next = element->next;
42         element->next->prev = element->prev;
43         if (*queue == element) {
44             *queue = element->next;
45         }
46     }
47     element->prev = element->next = NULL;
48 }
49
50 /// Enqueue the element on the net_socket queue
51 static void enqueue(struct net_socket **queue,
52                             struct net_socket *element)
53 {
54     if (*queue == NULL) {
55         *queue = element;
56         element->next = element->prev = element;
57     } else {
58         element->next = *queue;
59         element->prev = (*queue)->prev;
60         element->next->prev = element;
61         element->prev->next = element;
62     }
63 }
64
65 struct net_socket * net_udp_socket(void)
66 {
67     errval_t err;
68     struct net_socket *socket;
69     uint32_t descriptor;
70
71     err = binding->rpc_tx_vtbl.new_udp_socket(binding, &descriptor);
72     assert(err_is_ok(err));
73
74     socket = malloc(sizeof(struct net_socket));
75     assert(socket);
76
77     socket->descriptor = descriptor;
78     socket->received = NULL;
79     socket->connected = NULL;
80     socket->accepted = NULL;
81     socket->user_state = NULL;
82     enqueue(&sockets, socket);
83     return socket;
84 }
85
86 struct net_socket * net_tcp_socket(void)
87 {
88     errval_t err;
89     struct net_socket *socket;
90     uint32_t descriptor;
91
92     err = binding->rpc_tx_vtbl.new_tcp_socket(binding, &descriptor);
93     assert(err_is_ok(err));
94
95     socket = malloc(sizeof(struct net_socket));
96     assert(socket);
97
98     socket->descriptor = descriptor;
99     socket->received = NULL;
100     socket->sent = NULL;
101     socket->connected = NULL;
102     socket->accepted = NULL;
103     socket->user_state = NULL;
104     enqueue(&sockets, socket);
105     return socket;
106 }
107
108 static struct net_socket * get_socket(uint32_t descriptor)
109 {
110     struct net_socket *socket = sockets;
111     
112     while (socket) {
113         if (socket->descriptor == descriptor)
114             return socket;
115         socket = socket->next;
116         if (socket == sockets)
117             break;
118     }
119     debug_printf("%s: %d %p\n", __func__, descriptor, __builtin_return_address(0));
120     assert(0);
121     return NULL;
122 }
123
124 void net_set_user_state(struct net_socket *socket, void *user_state)
125 {
126     socket->user_state = user_state;
127 }
128
129 void net_close(struct net_socket *socket)
130 {
131     errval_t err, error;
132
133     debug_printf("%s(%d):\n", __func__, socket->descriptor);
134     err = binding->rpc_tx_vtbl.delete_socket(binding, socket->descriptor, &error);
135     assert(err_is_ok(err));
136     assert(err_is_ok(error));
137     dequeue(&sockets, socket);
138     free(socket);
139     debug_printf("%s: %ld:%p  %ld:%p\n", __func__, next_free, buffers[next_free], next_used, buffers[next_used]);
140 }
141
142 errval_t net_bind(struct net_socket *socket, struct in_addr ip_address, uint16_t port)
143 {
144     errval_t err, error;
145
146     err = binding->rpc_tx_vtbl.bind(binding, socket->descriptor, ip_address.s_addr, port, &error);
147     assert(err_is_ok(err));
148
149     return error;
150 }
151
152 errval_t net_listen(struct net_socket *socket, uint8_t backlog)
153 {
154     errval_t err, error;
155
156     err = binding->rpc_tx_vtbl.listen(binding, socket->descriptor, backlog, &error);
157     assert(err_is_ok(err));
158
159     return error;
160 }
161
162 void * net_alloc(size_t size)
163 {
164     void *buffer = buffers[next_free];
165     assert(buffer);
166     buffers[next_free] = NULL;
167     next_free = (next_free + 1) % NO_OF_BUFFERS;
168     // debug_printf("%s: %p:%zd  %ld:%p  %ld:%p  %p\n", __func__, buffer + sizeof(struct net_buffer), size, next_free, buffers[next_free], next_used, buffers[next_used], __builtin_return_address(0));
169     return buffer + sizeof(struct net_buffer);
170 }
171
172 void net_free(void *buffer)
173 {
174     assert(!buffers[next_used]);
175     buffers[next_used] = buffer - sizeof(struct net_buffer);
176     next_used = (next_used + 1) % NO_OF_BUFFERS;
177     // debug_printf("%s: %p  %ld:%p  %ld:%p  %p\n", __func__, buffer, next_free, buffers[next_free], next_used, buffers[next_used], __builtin_return_address(0));
178 }
179
180 errval_t net_send(struct net_socket *socket, void *data, size_t size)
181 {
182     errval_t err, error;
183
184     void *buffer = data - sizeof(struct net_buffer);
185     struct net_buffer *nb = buffer;
186     // debug_printf("%s(%d): %ld -> %p\n", __func__, socket->descriptor, size, buffer);
187
188     nb->size = size;
189     nb->descriptor = socket->descriptor;
190     nb->host_address.s_addr = INADDR_NONE;
191     nb->port = 0;
192     // debug_printf("%s: enqueue 2 %lx:%ld\n", __func__, buffer - buffer_start, sizeof(struct net_buffer) + size);
193     err = devq_enqueue((struct devq *)descq_queue, regionid, buffer - buffer_start, sizeof(struct net_buffer) + size,
194                        0, 0, 2);
195     assert(err_is_ok(err));
196     err = devq_notify((struct devq *)descq_queue);
197     assert(err_is_ok(err));
198
199     error = SYS_ERR_OK;
200     return error;
201 }
202
203 errval_t net_send_to(struct net_socket *socket, void *data, size_t size, struct in_addr ip_address, uint16_t port)
204 {
205     errval_t err, error;
206
207     void *buffer = data - sizeof(struct net_buffer);
208     struct net_buffer *nb = buffer;
209     // debug_printf("%s(%d): %ld -> %p\n", __func__, descriptor, size, buffer);
210
211     nb->size = size;
212     nb->descriptor = socket->descriptor;
213     nb->host_address = ip_address;
214     nb->port = port;
215     // debug_printf("%s: enqueue 2 %lx:%ld\n", __func__, buffer - buffer_start, sizeof(struct net_buffer) + size);
216     err = devq_enqueue((struct devq *)descq_queue, regionid, buffer - buffer_start, sizeof(struct net_buffer) + size,
217                        0, 0, 2);
218     assert(err_is_ok(err));
219     err = devq_notify((struct devq *)descq_queue);
220     assert(err_is_ok(err));
221
222     error = SYS_ERR_OK;
223     return error;
224 }
225
226 errval_t net_connect(struct net_socket *socket, struct in_addr ip_address, uint16_t port, net_connected_callback_t cb)
227 {
228     errval_t err, error;
229
230     socket->connected = cb;
231     err = binding->rpc_tx_vtbl.connect(binding, socket->descriptor, ip_address.s_addr, port, &error);
232     assert(err_is_ok(err));
233     assert(err_is_ok(error));
234
235     return error;
236 }
237
238 static void net_connected(struct net_sockets_binding *b, uint32_t descriptor, errval_t error)
239 {
240     struct net_socket *socket = get_socket(descriptor);
241     assert(socket->descriptor == descriptor);
242     assert(err_is_ok(error));
243
244     assert(socket->connected);
245     socket->connected(socket->user_state, socket);
246 }
247
248 void net_accept(struct net_socket *socket, net_accepted_callback_t cb)
249 {
250     socket->accepted = cb;
251 }
252
253 static void net_accepted(uint32_t descriptor, uint32_t accepted_descriptor, struct in_addr host_address, uint16_t port)
254 {
255     struct net_socket *socket = get_socket(descriptor);
256     assert(socket->descriptor == descriptor);
257
258     struct net_socket *accepted_socket = malloc(sizeof(struct net_socket));
259     assert(accepted_socket);
260
261     accepted_socket->descriptor = accepted_descriptor;
262     accepted_socket->received = NULL;
263     accepted_socket->sent = NULL;
264     accepted_socket->connected = NULL;
265     accepted_socket->accepted = NULL;
266     accepted_socket->user_state = NULL;
267     enqueue(&sockets, accepted_socket);
268
269     assert(socket->accepted);
270     socket->accepted(socket->user_state, accepted_socket, host_address, port);
271 }
272
273
274 void net_recv(struct net_socket *socket, net_received_callback_t cb)
275 {
276     socket->received = cb;
277 }
278
279 void net_set_sent(struct net_socket *socket, net_sent_callback_t cb)
280 {
281     socket->sent = cb;
282 }
283
284 static void bind_cb(void *st, errval_t err, struct net_sockets_binding *b)
285 {
286     binding = b;
287     net_sockets_rpc_client_init(binding);
288     bound_done = true;
289 }
290
291 static void alloc_mem(struct capref *frame, void** virt, size_t size)
292 {
293     errval_t r;
294     vregion_flags_t flags;
295
296     r = frame_alloc(frame, size, NULL);
297     assert(err_is_ok(r));
298
299     flags = VREGION_FLAGS_READ_WRITE;
300     r = vspace_map_one_frame_attr(virt, size, *frame, flags, NULL, NULL);
301     assert(err_is_ok(r));
302     memset(*virt, 0, size);
303 }
304
305 static errval_t q_notify(struct descq* q)
306 {
307     assert(descq_queue == q);
308     errval_t err = SYS_ERR_OK;
309     //errval_t err2 = SYS_ERR_OK;
310     regionid_t rid;
311     genoffset_t offset;
312     genoffset_t length;
313     genoffset_t valid_data;
314     genoffset_t valid_length;
315     uint64_t flags;
316     bool notify = 0;
317
318     // debug_printf("%s: \n", __func__);
319     for (;;) {
320         err = devq_dequeue((struct devq *)descq_queue, &rid, &offset, &length,
321                            &valid_data, &valid_length, &flags);
322         if (err_is_fail(err)) {
323             break;
324         } else {
325             // debug_printf("%s: dequeue %lx:%ld %ld\n", __func__, offset, length, flags);
326             void *buffer = buffer_start + offset;
327             struct net_buffer *nb = buffer;
328             // debug_printf("%s: dequeue %lx:%ld %ld  %p socket:%d asocket:%d\n", __func__, offset, length, flags, nb, nb->descriptor, nb->accepted_descriptor);
329             struct net_socket *socket = get_socket(nb->descriptor);
330             void *shb_data = buffer + sizeof(struct net_buffer);
331
332             if (flags == 1) { // receiving buffer
333                 // debug_printf("%s: enqueue 1> %lx:%d\n", __func__, offset, nb->size);
334                 if (nb->accepted_descriptor) { // accept
335                     net_accepted(nb->descriptor, nb->accepted_descriptor, nb->host_address, nb->port);
336
337                     err = devq_enqueue((struct devq *)descq_queue, rid, offset, length, 0, 0, 1);
338                     assert(err_is_ok(err));
339                     notify = 1;
340                 } else {  // receive
341                     if (socket->received) {
342                         // debug_printf("net_received(%d): %d\n", nb->descriptor, nb->size);
343                         socket->received(socket->user_state, socket, shb_data, nb->size, nb->host_address, nb->port);
344                     // debug_printf("%s: enqueue 1< %lx:%d\n", __func__, offset, 2048);
345                     }
346
347                     err = devq_enqueue((struct devq *)descq_queue, rid, offset, length, 0, 0, 1);
348                     assert(err_is_ok(err));
349                     notify = 1;
350                 }
351             } else if (flags == 2) { // transmitting buffer
352                 if (socket->sent) {
353                     // debug_printf("%s: dequeue %lx:%ld %p\n", __func__, offset, length, shb_data);
354                     socket->sent(socket->user_state, socket, shb_data, nb->size);
355                 }
356     // debug_printf("%s: %ld:%p  %ld:%p\n", __func__, next_free, buffers[next_free], next_used, buffers[next_used]);
357                 // assert(!buffers[next_used]);
358                 // buffers[next_used] = buffer_start + offset;
359                 // next_used = (next_used + 1) % NO_OF_BUFFERS;
360             }
361         }
362     }
363
364     if (notify) {
365         // debug_printf("notify>\n");
366         err = devq_notify((struct devq *)descq_queue);
367         assert(err_is_ok(err));
368         // debug_printf("notify<\n");
369     }
370
371     return SYS_ERR_OK;
372 }
373
374 errval_t net_sockets_init(void)
375 {
376     errval_t err;
377     iref_t iref;
378
379     memset(buffers, 0, sizeof(buffers));
380     next_free = 0;
381     next_used = 0;
382
383     alloc_mem(&buffer_frame, &buffer_start, 2 * BUFFER_SIZE * NO_OF_BUFFERS);
384  
385     struct descq_func_pointer f;
386     f.notify = q_notify;
387
388     debug_printf("Descriptor queue test started \n");
389     err = descq_create(&descq_queue, DESCQ_DEFAULT_SIZE, "net_sockets_queue",
390                        false, true, true, &queue_id, &f);
391     assert(err_is_ok(err));
392
393     err = nameservice_blocking_lookup("net_sockets", &iref);
394     assert(err_is_ok(err));
395     err = net_sockets_bind(iref, bind_cb, NULL, get_default_waitset(), IDC_BIND_FLAGS_DEFAULT);
396     assert(err_is_ok(err));
397
398     while (!bound_done) {
399         event_dispatch(get_default_waitset());
400     }
401     debug_printf("%s: initialized\n", __func__);
402     binding->rx_vtbl.connected = net_connected;
403     // binding->rx_vtbl.accepted = net_accepted;
404
405     err = binding->rpc_tx_vtbl.register_queue(binding, queue_id);
406     assert(err_is_ok(err));
407
408     err = devq_register((struct devq *)descq_queue, buffer_frame, &regionid);
409     assert(err_is_ok(err));
410
411     for (int i = 0; i < NO_OF_BUFFERS; i++) {
412         err = devq_enqueue((struct devq *)descq_queue, regionid, i * BUFFER_SIZE, BUFFER_SIZE,
413                            0, 0, 1);
414         if (!err_is_ok(err))
415             debug_printf("%s: %d:%d\n", __func__, i, NO_OF_BUFFERS);
416         assert(err_is_ok(err));
417         buffers[i] = i * BUFFER_SIZE + buffer_start + BUFFER_SIZE * NO_OF_BUFFERS;
418     }
419
420     err = devq_notify((struct devq *)descq_queue);
421     assert(err_is_ok(err));
422
423     return SYS_ERR_OK;
424 }