ba61361368cfc225f1777d2e2da5da699d2ce9ef
[barrelfish] / usr / tests / mt_waitset / main.c
1 #include <stdio.h>
2 #include <string.h>
3
4 #include <barrelfish/barrelfish.h>
5 #include <barrelfish/nameservice_client.h>
6 #include <if/mt_waitset_defs.h>
7 #include <if/mt_waitset_rpcclient_defs.h>
8 #include <barrelfish/deferred.h>
9 #include <barrelfish/inthandler.h>
10 #include <bench/bench.h>
11 #include <sys/time.h>
12 #include "../lib/barrelfish/include/threads_priv.h"
13 #include <barrelfish/debug.h>
14 #include <barrelfish/spawn_client.h>
15 #include <barrelfish/event_mutex.h>
16
17 const static char *service_name = "mt_waitset_service";
18 coreid_t my_core_id, num_cores;
19 struct thread *threads[256];
20
21 static int server_threads = 10;
22 static int client_threads = 1;
23 static int iteration_count = 1000;
24 static int limit;
25
26 static int client_counter = 0;
27 static int64_t server_calls[256];
28 static int64_t client_calls[256][256];
29
30 static void show_stats(void)
31 {
32     debug_printf("Stats: %zd %zd %zd %zd %zd %zd %zd %zd %zd %zd\n",
33         server_calls[0], server_calls[1], server_calls[2], server_calls[3],
34         server_calls[4], server_calls[5], server_calls[6], server_calls[7],
35         server_calls[8], server_calls[9]);
36 }
37
38 static void show_client_stats(void)
39 {
40     int i, j, s;
41     char text[256];
42
43     for (i = 0; i < num_cores; i++) {
44         s = sprintf(text, "Core %d:", i);
45         for (j = 0; j < 16; j++)
46             s += sprintf(text + s, "\t%zd", client_calls[i][j]);
47         s += sprintf(text + s, "\n");
48         debug_printf("%s", text);
49     }
50 }
51
52 static int client_thread(void * arg)
53 {
54     struct mt_waitset_rpc_client *rpc_client;
55     errval_t err;
56     rpc_client = arg;
57     int i, j, k, l;
58     uint64_t payload[256];
59     uint64_t result[256];
60     size_t result_size;
61     uint64_t o1;
62     uint32_t o2;
63     uint32_t i1 = my_core_id << 8 | thread_self()->id;
64     uint64_t mmm = ((uint64_t)my_core_id << 56) | ((uint64_t)thread_self()->id << 48);
65
66     debug_printf("Start\n");
67
68     for (k = 0; k < iteration_count; k++) {
69         uint64_t i2 = (rdtsc() & 0xffffffff) | mmm | (((uint64_t)k & 0xffffL) << 32);
70
71         j = ((i2 >> 5) & 127) + 1;
72
73         i2 &= 0xfffffffffffff000;
74
75         for (i = 0; i < j; i++)
76             payload[i] = i2 + i;
77         if (j > limit)
78             j = limit;
79         err = rpc_client->vtbl.rpc_method(rpc_client, i2, (uint8_t *)payload, 8 * j, i1, &o1, (uint8_t *)result, &result_size, &o2);
80
81         assert(err == SYS_ERR_OK);
82         l = 0;
83         for (i = 0; i < j; i++) {
84             if (result[i] == payload[i] + i)
85                 l++;
86         }
87         if (!(i2 + 1 == o1) || result_size != (8 * j) || l != j) {
88             debug_printf("%d: wrong %016lx != %016lx  %d %zd    %d %d\n", k, i2 + 1, o1, 8 * j, result_size, j, l);
89             for (i = 0; i < j; i++)
90                 debug_printf("\t%d: %016lx %016lx\n", i, payload[i], result[i]);
91         }
92         server_calls[o2]++;
93         if (err_is_fail(err)) {
94             DEBUG_ERR(err, "error sending message\n");
95         }
96     }
97
98     client_counter--;
99     debug_printf("Done, threads left:%d\n", client_counter);
100
101     if (client_counter == 0) {
102         // all threads have finished, we're done, inform the server
103         payload[0] = mmm;
104         err = rpc_client->vtbl.rpc_method(rpc_client, mmm, (uint8_t *)payload, 8, 65536, &o1, (uint8_t *)result, &result_size, &o2);
105         show_stats();
106     }
107     return 0;
108 }
109
110 static void bind_cb(void *st, errval_t err, struct mt_waitset_binding *b)
111 {
112     struct mt_waitset_rpc_client *rpc_client;
113     int i = (long int)st;
114
115     rpc_client = malloc(sizeof(struct mt_waitset_rpc_client));
116     mt_waitset_rpc_client_init(rpc_client, b);
117
118     client_counter = client_threads;
119     for (i = 1; i < client_threads; i++)
120         thread_create(client_thread, rpc_client);
121     client_thread(rpc_client);
122 }
123
124 static void start_client(void)
125 {
126     char name[64];
127     errval_t err;
128     iref_t iref;
129
130     debug_printf("Start client\n");
131     sprintf(name, "%s%d", service_name, 0);
132     err = nameservice_blocking_lookup(service_name, &iref);
133     if (err_is_fail(err)) {
134         USER_PANIC_ERR(err, "nameservice_blocking_lookup failed");
135     }
136     err = mt_waitset_bind(iref, bind_cb,  (void *)0, get_default_waitset(), IDC_BIND_FLAGS_DEFAULT);
137     if (err_is_fail(err)) {
138         USER_PANIC_ERR(err, "bind failed");
139     }
140 }
141
142
143 // server
144
145 static void export_cb(void *st, errval_t err, iref_t iref)
146 {
147     if (err_is_fail(err)) {
148         USER_PANIC_ERR(err, "export failed");
149     }
150     err = nameservice_register(service_name, iref);
151     if (err_is_fail(err)) {
152             USER_PANIC_ERR(err, "nameservice_register failed");
153     }
154 }
155
156 static errval_t server_rpc_method_call(struct mt_waitset_binding *b, uint64_t i1, uint8_t *s, size_t ss, uint32_t i2, uint64_t *o1, uint8_t *r, size_t *rs, uint32_t *o2)
157 {
158     int i, j, k, me;
159     static int count = 0;
160     static uint64_t calls = 0;
161     uint64_t *response = (uint64_t *)r;
162
163     for (i = 0;; i++) {
164         if (thread_self() == threads[i]) {
165             server_calls[i]++;
166             me = i;
167             break;
168         }
169     }
170
171     if (i2 == 65536) {
172         count++;    // client has finished
173     } else
174         client_calls[i2 >> 8][i2 & 255]++;
175
176     j = ss / 8;
177     k = 0;
178     for (i = 0; i < j; i++) {
179         response[i] = ((uint64_t *)s)[i];
180         if (response[i] == i1 + i)
181             k++;
182         response[i] += i;
183     }
184     if (k != j && i2 != 65536)
185         debug_printf("server_zrob_call: binding:%p %08x %08x  %d %d   %016lx:%d\n", b, i2, b->incoming_token, k, j, response[0], me);
186     if (count == num_cores) {
187         bool failed = false;
188
189         debug_printf("Final statistics\n");
190         show_stats();
191         show_client_stats();
192         for (i = 0; i < num_cores; i++) {
193             for (j = 0; j < client_threads; j++) {
194                 if (client_calls[i][j] != iteration_count) {
195                     failed = true;
196                     goto out;
197                 }
198             }
199         }
200 out:
201         if (failed)
202             debug_printf("Test FAILED\n");
203         else
204             debug_printf("Test PASSED\n");
205     }
206     calls++;
207     if ((calls % iteration_count) == 0) {
208         show_stats();
209     }
210
211     *o1 = i1 + 1;
212     *rs = 8 * j;
213     *o2 = me;
214
215     return SYS_ERR_OK;
216 }
217
218 static struct mt_waitset_rpc_rx_vtbl rpc_rx_vtbl = {
219     .rpc_method_call = server_rpc_method_call
220 };
221
222 static errval_t connect_cb(void *st, struct mt_waitset_binding *b)
223 {
224     b->rpc_rx_vtbl = rpc_rx_vtbl;
225     return SYS_ERR_OK;
226 }
227
228 static int run_server(void * arg)
229 {
230     int i = (uintptr_t)arg;
231     struct waitset *ws = get_default_waitset();
232     errval_t err;
233
234
235     debug_printf("Server dispatch loop %d\n", i);
236     threads[i] = thread_self();
237
238     for (;;) {
239         err = event_dispatch(ws);
240         if (err_is_fail(err)) {
241             DEBUG_ERR(err, "in event_dispatch");
242             break;
243         }
244     }
245     return SYS_ERR_OK;
246 }
247
248 static void start_server(void)
249 {
250     struct waitset *ws = get_default_waitset();
251     errval_t err;
252     int i;
253
254     debug_printf("Start server\n");
255
256     err = mt_waitset_export(NULL, export_cb, connect_cb, ws,
257                             IDC_EXPORT_FLAGS_DEFAULT);
258     if (err_is_fail(err)) {
259         USER_PANIC_ERR(err, "export failed");
260     }
261     for (i = 1; i < server_threads; i++) {
262         thread_create(run_server, (void *)(uintptr_t)i);
263     }
264 }
265
266 int main(int argc, char *argv[])
267 {
268     errval_t err;
269     char *my_name = strdup(argv[0]);
270
271     my_core_id = disp_get_core_id();
272
273     memset(server_calls, 0, sizeof(server_calls));
274     memset(client_calls, 0, sizeof(client_calls));
275
276     debug_printf("Got %d args\n", argc);
277
278     if (argc == 1) {
279         debug_printf("Usage: %s server_threads client_threads iteration_count\n", argv[0]);
280     } else if (argc == 5) {
281         char *xargv[] = {my_name, argv[2], argv[3], argv[4], NULL};
282
283         server_threads = atoi(argv[1]);
284         client_threads = atoi(argv[2]);
285         iteration_count = atoi(argv[3]);
286
287         err = spawn_program_on_all_cores(true, xargv[0], xargv, NULL,
288             SPAWN_FLAGS_DEFAULT, NULL, &num_cores);
289         debug_printf("spawn program on all cores (%d)\n", num_cores);
290         assert(err_is_ok(err));
291
292         start_server();
293
294         struct waitset *ws = get_default_waitset();
295
296         threads[0] = thread_self();
297         for (;;) {
298             err = event_dispatch(ws);
299             if (err_is_fail(err)) {
300                 DEBUG_ERR(err, "in event_dispatch");
301                 break;
302             }
303         }
304     } else {
305         client_threads = atoi(argv[1]);
306         iteration_count = atoi(argv[2]);
307         limit = atoi(argv[3]);
308
309         start_client();
310
311         struct waitset *ws = get_default_waitset();
312
313         for (;;) {
314             err = event_dispatch(ws);
315             if (err_is_fail(err)) {
316                 DEBUG_ERR(err, "in event_dispatch");
317                 break;
318             }
319         }
320     }
321     return EXIT_FAILURE;
322 }