507c495c7dddaecb86d55bbac54dc2f79b463ef3
[barrelfish] / usr / tests / mt_waitset / main.c
1 #include <stdio.h>
2 #include <string.h>
3
4 #include <barrelfish/barrelfish.h>
5 #include <barrelfish/nameservice_client.h>
6 #include <if/mt_waitset_defs.h>
7 #include <if/mt_waitset_rpcclient_defs.h>
8 #include <barrelfish/deferred.h>
9 #include <barrelfish/inthandler.h>
10 #include <bench/bench.h>
11 #include <sys/time.h>
12 #include "../lib/barrelfish/include/threads_priv.h"
13 #include <barrelfish/debug.h>
14 #include <barrelfish/spawn_client.h>
15 #include <barrelfish/event_mutex.h>
16
17 const static char *service_name = "mt_waitset_service";
18 coreid_t my_core_id, num_cores;
19 struct thread *threads[256];
20
21 static int server_threads = 10;
22 static int client_threads = 1;
23 static int iteration_count = 1000;
24 static int limit;
25
26 static int client_counter = 0;
27 static int64_t server_calls[256];
28 static int64_t client_calls[256][256];
29
30 static void show_stats(void)
31 {
32     debug_printf("Stats: %zd %zd %zd %zd %zd %zd %zd %zd %zd %zd\n",
33         server_calls[0], server_calls[1], server_calls[2], server_calls[3],
34         server_calls[4], server_calls[5], server_calls[6], server_calls[7],
35         server_calls[8], server_calls[9]);
36 }
37
38 static void show_client_stats(void)
39 {
40     int i, j, s;
41     char text[256];
42
43     for (i = 0; i < num_cores; i++) {
44         s = sprintf(text, "Core %d:", i);
45         for (j = 0; j < 16; j++)
46             s += sprintf(text + s, "\t%zd", client_calls[i][j]);
47         s += sprintf(text + s, "\n");
48         debug_printf("%s", text);
49     }
50 }
51
52 static int client_thread(void * arg)
53 {
54     struct mt_waitset_rpc_client *rpc_client;
55     errval_t err;
56     rpc_client = arg;
57     int i, j, k, l;
58     uint64_t payload[512];
59     uint64_t result[512];
60     size_t result_size;
61     uint64_t o1;
62     uint32_t o2;
63     uint32_t i1 = my_core_id << 8 | thread_self()->id;
64     uint64_t mmm = ((uint64_t)my_core_id << 56) | ((uint64_t)thread_self()->id << 48);
65
66     debug_printf("Start\n");
67
68     for (k = 0; k < iteration_count; k++) {
69         uint64_t i2 = (rdtsc() & 0xffffffff) | mmm | (((uint64_t)k & 0xffffL) << 32);
70
71         j = ((i2 >> 5) & 511) + 1;
72
73         i2 &= 0xfffffffffffff000;
74
75         for (i = 0; i < j; i++)
76             payload[i] = i2 + i;
77         if (j > limit)
78             j = limit;
79         err = rpc_client->vtbl.rpc_method(rpc_client, i2, (uint8_t *)payload, 8 * j, i1, &o1, (uint8_t *)result, &result_size, &o2);
80
81         assert(err == SYS_ERR_OK);
82         l = 0;
83         for (i = 0; i < j; i++) {
84             if (result[i] == payload[i] + i)
85                 l++;
86         }
87         if (!(i2 + 1 == o1) || result_size != (8 * j) || l != j) {
88             debug_printf("%d: wrong %016lx != %016lx  %d %zd    %d %d\n", k, i2 + 1, o1, 8 * j, result_size, j, l);
89             for (i = 0; i < j; i++)
90                 debug_printf("\t%d: %016lx %016lx\n", i, payload[i], result[i]);
91         }
92         server_calls[o2]++;
93         if (err_is_fail(err)) {
94             DEBUG_ERR(err, "error sending message\n");
95         }
96     }
97
98     dispatcher_handle_t handle = disp_disable();
99
100     client_counter--;
101     debug_printf("Done, threads left:%d\n", client_counter);
102
103     if (client_counter == 0) {
104         disp_enable(handle);
105         // all threads have finished, we're done, inform the server
106         payload[0] = mmm;
107         err = rpc_client->vtbl.rpc_method(rpc_client, mmm, (uint8_t *)payload, 8, 65536, &o1, (uint8_t *)result, &result_size, &o2);
108         show_stats();
109     } else
110         disp_enable(handle);
111     return 0;
112 }
113
114 static void bind_cb(void *st, errval_t err, struct mt_waitset_binding *b)
115 {
116     struct mt_waitset_rpc_client *rpc_client;
117     int i = (long int)st;
118
119     rpc_client = malloc(sizeof(struct mt_waitset_rpc_client));
120     mt_waitset_rpc_client_init(rpc_client, b);
121
122     client_counter = client_threads;
123     for (i = 1; i < client_threads; i++)
124         thread_create(client_thread, rpc_client);
125     client_thread(rpc_client);
126 }
127
128 static void start_client(void)
129 {
130     char name[64];
131     errval_t err;
132     iref_t iref;
133
134     debug_printf("Start client\n");
135     sprintf(name, "%s%d", service_name, 0);
136     err = nameservice_blocking_lookup(service_name, &iref);
137     if (err_is_fail(err)) {
138         USER_PANIC_ERR(err, "nameservice_blocking_lookup failed");
139     }
140     err = mt_waitset_bind(iref, bind_cb,  (void *)0, get_default_waitset(), IDC_BIND_FLAGS_DEFAULT);
141     if (err_is_fail(err)) {
142         USER_PANIC_ERR(err, "bind failed");
143     }
144 }
145
146
147 // server
148
149 static void export_cb(void *st, errval_t err, iref_t iref)
150 {
151     if (err_is_fail(err)) {
152         USER_PANIC_ERR(err, "export failed");
153     }
154     err = nameservice_register(service_name, iref);
155     if (err_is_fail(err)) {
156             USER_PANIC_ERR(err, "nameservice_register failed");
157     }
158 }
159
160 static errval_t server_rpc_method_call(struct mt_waitset_binding *b, uint64_t i1, uint8_t *s, size_t ss, uint32_t i2, uint64_t *o1, uint8_t *r, size_t *rs, uint32_t *o2)
161 {
162     int i, j, k, me;
163     static int count = 0;
164     static uint64_t calls = 0;
165     uint64_t *response = (uint64_t *)r;
166
167     for (i = 0;; i++) {
168         if (thread_self() == threads[i]) {
169             server_calls[i]++;
170             me = i;
171             break;
172         }
173     }
174
175     if (i2 == 65536) {
176         count++;    // client has finished
177     } else
178         client_calls[i2 >> 8][i2 & 255]++;
179
180     j = ss / 8;
181     k = 0;
182     for (i = 0; i < j; i++) {
183         response[i] = ((uint64_t *)s)[i];
184         if (response[i] == i1 + i)
185             k++;
186         response[i] += i;
187     }
188     if (k != j && i2 != 65536)
189         debug_printf("%s: binding:%p %08x %08x  %d %d   %016lx:%d\n", __func__, b, i2, b->incoming_token, k, j, response[0], me);
190     if (count == num_cores) {
191         bool failed = false;
192
193         debug_printf("Final statistics\n");
194         show_stats();
195         show_client_stats();
196         for (i = 0; i < num_cores; i++) {
197             for (j = 0; j < client_threads; j++) {
198                 if (client_calls[i][j] != iteration_count) {
199                     failed = true;
200                     goto out;
201                 }
202             }
203         }
204 out:
205         if (failed)
206             debug_printf("Test FAILED\n");
207         else
208             debug_printf("Test PASSED\n");
209     }
210     calls++;
211     if ((calls % 10000) == 0) {
212         show_stats();
213     }
214
215     *o1 = i1 + 1;
216     *rs = 8 * j;
217     *o2 = me;
218
219     return SYS_ERR_OK;
220 }
221
222 static struct mt_waitset_rpc_rx_vtbl rpc_rx_vtbl = {
223     .rpc_method_call = server_rpc_method_call
224 };
225
226 static errval_t connect_cb(void *st, struct mt_waitset_binding *b)
227 {
228     b->rpc_rx_vtbl = rpc_rx_vtbl;
229     return SYS_ERR_OK;
230 }
231
232 static int run_server(void * arg)
233 {
234     int i = (uintptr_t)arg;
235     struct waitset *ws = get_default_waitset();
236     errval_t err;
237
238
239     debug_printf("Server dispatch loop %d\n", i);
240     threads[i] = thread_self();
241
242     for (;;) {
243         err = event_dispatch(ws);
244         if (err_is_fail(err)) {
245             DEBUG_ERR(err, "in event_dispatch");
246             break;
247         }
248     }
249     return SYS_ERR_OK;
250 }
251
252 static void start_server(void)
253 {
254     struct waitset *ws = get_default_waitset();
255     errval_t err;
256     int i;
257
258     debug_printf("Start server\n");
259
260     err = mt_waitset_export(NULL, export_cb, connect_cb, ws,
261                             IDC_EXPORT_FLAGS_DEFAULT);
262     if (err_is_fail(err)) {
263         USER_PANIC_ERR(err, "export failed");
264     }
265     for (i = 1; i < server_threads; i++) {
266         thread_create(run_server, (void *)(uintptr_t)i);
267     }
268 }
269
270 int main(int argc, char *argv[])
271 {
272     errval_t err;
273     char *my_name = strdup(argv[0]);
274
275     my_core_id = disp_get_core_id();
276
277     memset(server_calls, 0, sizeof(server_calls));
278     memset(client_calls, 0, sizeof(client_calls));
279
280     if (argc == 1) {
281         debug_printf("Usage: %s server_threads client_threads iteration_count\n", argv[0]);
282     } else if (argc == 5) {
283         char *xargv[] = {my_name, argv[2], argv[3], argv[4], NULL};
284
285         server_threads = atoi(argv[1]);
286         client_threads = atoi(argv[2]);
287         iteration_count = atoi(argv[3]);
288
289         err = spawn_program_on_all_cores(true, xargv[0], xargv, NULL,
290             SPAWN_FLAGS_DEFAULT, NULL, &num_cores);
291         debug_printf("spawn program on all cores (%d)\n", num_cores);
292         assert(err_is_ok(err));
293
294         start_server();
295
296         struct waitset *ws = get_default_waitset();
297
298         threads[0] = thread_self();
299         for (;;) {
300             err = event_dispatch(ws);
301             if (err_is_fail(err)) {
302                 DEBUG_ERR(err, "in event_dispatch");
303                 break;
304             }
305         }
306     } else {
307         client_threads = atoi(argv[1]);
308         iteration_count = atoi(argv[2]);
309         limit = atoi(argv[3]);
310
311         struct waitset *ws = get_default_waitset();
312         start_client();
313         debug_printf("Client process events\n");
314         for (;;) {
315             err = event_dispatch(ws);
316             if (err_is_fail(err)) {
317                 DEBUG_ERR(err, "in event_dispatch");
318                 break;
319             }
320         }
321     }
322     return EXIT_FAILURE;
323 }