Rewrite ICD loader
[ocl-icd] / ocl_icd_loader.c
1 /**
2 Copyright (c) 2012, Brice Videau <brice.videau@imag.fr>
3 Copyright (c) 2012, Vincent Danjean <Vincent.Danjean@ens-lyon.org>
4 All rights reserved.
5       
6 Redistribution and use in source and binary forms, with or without
7 modification, are permitted provided that the following conditions are met:
8     
9 1. Redistributions of source code must retain the above copyright notice, this
10    list of conditions and the following disclaimer.
11 2. Redistributions in binary form must reproduce the above copyright notice,
12    this list of conditions and the following disclaimer in the documentation
13    and/or other materials provided with the distribution.
14         
15 THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
16 ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
17 WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
18 DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
19 ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
20 (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
21 LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
22 ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
23 (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
24 SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
25 */
26
27 #include <dirent.h>
28 #include <stdio.h>
29 #include <stdlib.h>
30 #include <string.h>
31 #include <dlfcn.h>
32 #define CL_USE_DEPRECATED_OPENCL_1_1_APIS
33 #include <CL/opencl.h>
34
35 #pragma GCC visibility push(hidden)
36
37 #include "ocl_icd_loader.h"
38
39 #define DEBUG_OCL_ICD 1
40
41 #define D_WARN 1
42 #define D_LOG 2
43 #define D_ARGS 4
44 #define D_DUMP 8
45 #if defined(DEBUG_OCL_ICD)
46 static int debug_ocl_icd_mask=0;
47 #  define debug(mask, fmt, ...) do {\
48         if (debug_ocl_icd_mask & mask) { \
49                 fprintf(stderr, "ocl-icd: %s: " fmt "\n", __func__, ##__VA_ARGS__); \
50         } \
51    } while(0)
52 #  define RETURN(val) do { \
53         __typeof__(val) ret=(val); \
54         debug(D_ARGS, "return: %ld/0x%lx", (long)ret, (long)ret);       \
55         return ret; \
56    } while(0)
57 #else
58 #  define debug(...) (void)0
59 #  define RETURN(val) return (val)
60 #endif
61
62 typedef __typeof__(clGetExtensionFunctionAddress) *clGetExtensionFunctionAddress_fn;
63 typedef __typeof__(clGetPlatformInfo) *clGetPlatformInfo_fn;
64
65
66 struct vendor_icd {
67   cl_uint       num_platforms;
68   cl_uint       first_platform;
69   void *        dl_handle;
70   clGetExtensionFunctionAddress_fn ext_fn_ptr;
71 };
72
73 struct platform_icd {
74   char *         extension_suffix;
75   struct vendor_icd *vicd;
76   cl_platform_id pid;
77 };
78
79 struct vendor_icd *_icds=NULL;
80 struct platform_icd *_picds=NULL;
81 static cl_uint _num_icds = 0;
82 static cl_uint _num_picds = 0;
83
84 static cl_uint _initialized = 0;
85
86 static const char *_dir_path="/etc/OpenCL/vendors/";
87
88 static inline cl_uint _find_num_icds(DIR *dir) {
89   cl_uint num_icds = 0;
90   struct dirent *ent;
91   while( (ent=readdir(dir)) != NULL ){
92     if( strcmp(ent->d_name,".") == 0 || strcmp(ent->d_name,"..") == 0 )
93       continue;
94     cl_uint d_name_len = strlen(ent->d_name);
95     if( d_name_len>4 && strcmp(ent->d_name + d_name_len - 4, ".icd" ) != 0 )
96       continue;
97 //    printf("%s%s\n", _dir_path, ent->d_name);
98     num_icds++;
99   }
100   rewinddir(dir);
101   RETURN(num_icds);
102 }
103
104 static inline cl_uint _open_drivers(DIR *dir) {
105   cl_uint num_icds = 0;
106   struct dirent *ent;
107   while( (ent=readdir(dir)) != NULL ){
108     if( strcmp(ent->d_name,".") == 0 || strcmp(ent->d_name,"..") == 0 )
109       continue;
110     cl_uint d_name_len = strlen(ent->d_name);
111     if( d_name_len>4 && strcmp(ent->d_name + d_name_len - 4, ".icd" ) != 0 )
112       continue;
113     char * lib_path;
114     char * err;
115     unsigned int lib_path_length = strlen(_dir_path) + strlen(ent->d_name) + 1;
116     lib_path = malloc(lib_path_length*sizeof(char));
117     sprintf(lib_path,"%s%s", _dir_path, ent->d_name);
118     FILE *f = fopen(lib_path,"r");
119     free(lib_path);
120
121     fseek(f, 0, SEEK_END);
122     lib_path_length = ftell(f)+1;
123     fseek(f, 0, SEEK_SET);
124     if(lib_path_length == 1) {
125       fclose(f);
126       continue;
127     }
128     lib_path = malloc(lib_path_length*sizeof(char));
129     err = fgets(lib_path, lib_path_length, f);
130     fclose(f);
131     if( err == NULL ) {
132       free(lib_path);
133       continue;
134     }
135
136     lib_path_length = strlen(lib_path);
137     
138     if( lib_path[lib_path_length-1] == '\n' )
139       lib_path[lib_path_length-1] = '\0';
140
141     _icds[num_icds].dl_handle = dlopen(lib_path, RTLD_LAZY|RTLD_LOCAL);//|RTLD_DEEPBIND);
142     if(_icds[num_icds].dl_handle != NULL) {
143       debug(D_LOG, "Loading ICD[%i] -> '%s'", num_icds, lib_path);
144       num_icds++;
145     }
146     free(lib_path);
147   }
148   RETURN(num_icds);
149 }
150
151 static void* _get_function_addr(void* dlh, clGetExtensionFunctionAddress_fn fn, const char*name) {
152   void *addr1;
153   addr1=dlsym(dlh, name);
154   if (addr1 == NULL) {
155     debug(D_WARN, "Missing global symbol '%s' in ICD, should be skipped", name);
156   }
157   void* addr2=NULL;
158   if (fn) {
159     addr2=(*fn)(name);
160     if (addr2 == NULL) {
161       debug(D_WARN, "Missing function '%s' in ICD, should be skipped", name);
162     }
163 #if defined(DEBUG_OCL_ICD)
164     if (addr1 && addr2 && addr1!=addr2) {
165       debug(D_WARN, "Function and symbol '%s' have different addresses!", name);
166     }
167 #endif
168   }
169   if (!addr2) addr2=addr1;
170   RETURN(addr2);
171 }
172
173 static int _allocate_platforms(int req) {
174   static cl_uint allocated=0;
175   debug(D_LOG,"Requesting allocation for %d platforms",req);
176   if (allocated - _num_picds < req) {
177     if (allocated==0) {
178       _picds=(struct platform_icd*)malloc(req*sizeof(struct platform_icd));
179     } else {
180       req = req - (allocated - _num_picds);
181       _picds=(struct platform_icd*)realloc(_picds, (allocated+req)*sizeof(struct platform_icd));
182     }
183     allocated += req;
184   }
185   RETURN(allocated - _num_picds);
186 }
187
188 static inline void _find_and_check_platforms(cl_uint num_icds) {
189   cl_uint i;
190   _num_icds = 0;
191   for( i=0; i<num_icds; i++){
192     debug(D_LOG, "Checking ICD %i", i);
193     struct vendor_icd *picd = &_icds[_num_icds];
194     void* dlh = _icds[i].dl_handle;
195     picd->ext_fn_ptr = _get_function_addr(dlh, NULL, "clGetExtensionFunctionAddress");
196     clIcdGetPlatformIDsKHR_fn plt_fn_ptr = 
197       _get_function_addr(dlh, picd->ext_fn_ptr, "clIcdGetPlatformIDsKHR");
198     clGetPlatformInfo_fn plt_info_ptr = 
199       _get_function_addr(dlh, picd->ext_fn_ptr, "clGetPlatformInfo");
200     if( picd->ext_fn_ptr == NULL
201         || plt_fn_ptr == NULL
202         || plt_info_ptr == NULL) {
203       debug(D_WARN, "Missing symbols in ICD, skipping it");
204       continue;
205     }
206     cl_uint num_platforms=0;
207     cl_int error;
208     error = (*plt_fn_ptr)(0, NULL, &num_platforms);
209     if( error != CL_SUCCESS || num_platforms == 0) {
210       debug(D_LOG, "No platform in ICD, skipping it");
211       continue;
212     }
213     cl_platform_id *platforms = (cl_platform_id *) malloc( sizeof(cl_platform_id) * num_platforms);
214     error = (*plt_fn_ptr)(num_platforms, platforms, NULL);
215     if( error != CL_SUCCESS ){
216       free(platforms);
217       debug(D_WARN, "Error in loading ICD platforms, skipping ICD");
218       continue;
219     }
220     cl_uint num_valid_platforms=0;
221     cl_uint j;
222     debug(D_LOG, "Try to load %d plateforms", num_platforms);
223     if (_allocate_platforms(num_platforms) < num_platforms) {
224       free(platforms);
225       debug(D_WARN, "Not enought platform allocated. Skipping ICD");
226       continue;
227     }
228     for(j=0; j<num_platforms; j++) {
229       debug(D_LOG, "Checking platform %i", j);
230       size_t param_value_size_ret;
231       struct platform_icd *p=&_picds[_num_picds];
232       p->extension_suffix=NULL;
233       p->vicd=&_icds[i];
234       p->pid=platforms[j];
235       error = plt_info_ptr(p->pid, CL_PLATFORM_EXTENSIONS, 0, NULL, &param_value_size_ret);
236       if (error != CL_SUCCESS) {
237         debug(D_WARN, "Error while loading extensions in platform %i, skipping it",j);
238         continue;
239       }
240       char *param_value = (char *)malloc(sizeof(char)*param_value_size_ret);
241       error = plt_info_ptr(p->pid, CL_PLATFORM_EXTENSIONS, param_value_size_ret, param_value, NULL);
242       if (error != CL_SUCCESS){
243         free(param_value);
244         debug(D_WARN, "Error while loading extensions in platform %i, skipping it", j);
245         continue;
246       }
247       if( strstr(param_value, "cl_khr_icd") == NULL){
248         free(param_value);
249         debug(D_WARN, "Missing khr extension in platform %i, skipping it", j);
250         continue;
251       }
252       free(param_value);
253       error = plt_info_ptr(p->pid, CL_PLATFORM_ICD_SUFFIX_KHR, 0, NULL, &param_value_size_ret);
254       if (error != CL_SUCCESS) {
255         debug(D_WARN, "Error while loading suffix in platform %i, skipping it", j);
256         continue;
257       }
258       param_value = (char *)malloc(sizeof(char)*param_value_size_ret);
259       error = plt_info_ptr(p->pid, CL_PLATFORM_ICD_SUFFIX_KHR, param_value_size_ret, param_value, NULL);
260       if (error != CL_SUCCESS){
261         debug(D_WARN, "Error while loading suffix in platform %i, skipping it", j);
262         free(param_value);
263         continue;
264       }
265       p->extension_suffix = param_value;
266       debug(D_LOG, "Extension suffix: %s", param_value);
267       num_valid_platforms++;
268       _num_picds++;
269     }
270     if( num_valid_platforms != 0 ) {
271       if ( _num_icds != i ) {
272         picd->dl_handle = dlh;
273       }
274       _num_icds++;
275       picd->num_platforms = num_valid_platforms;
276       _icds[i].first_platform = _num_picds - num_valid_platforms;
277     } else {
278       dlclose(dlh);
279     }
280     free(platforms);
281   }
282 }
283
284 static void _initClIcd( void ) {
285   if( _initialized )
286     return;
287 #if defined(DEBUG_OCL_ICD)
288   char *debug=getenv("OCL_ICD_DEBUG");
289   if (debug) {
290     debug_ocl_icd_mask=atoi(debug);
291     if (debug_ocl_icd_mask==0)
292       debug_ocl_icd_mask=1;
293   }
294 #endif
295   cl_uint num_icds = 0;
296   DIR *dir;
297   dir = opendir(_dir_path);
298   if(dir == NULL) {
299     goto abort;
300   }
301
302   num_icds = _find_num_icds(dir);
303   if(num_icds == 0) {
304     goto abort;
305   }
306
307   _icds = (struct vendor_icd*)malloc(num_icds * sizeof(struct vendor_icd));
308   if (_icds == NULL) {
309     goto abort;
310   }
311   
312   num_icds = _open_drivers(dir);
313   if(num_icds == 0) {
314     goto abort;
315   }
316
317   _find_and_check_platforms(num_icds);
318   if(_num_icds == 0){
319     goto abort;
320   }
321
322   if (_num_icds < num_icds) {
323     _icds = (struct vendor_icd*)realloc(_icds, _num_icds * sizeof(struct vendor_icd));
324   }
325   debug(D_WARN, "%d valid vendor(s)!", _num_icds);
326   _initialized = 1;
327   return;
328 abort:
329   _num_icds = 0;
330   _initialized = 1;
331   if (_icds) {
332     free(_icds);
333     _icds = NULL;
334   }
335   return;
336 }
337
338 #pragma GCC visibility pop
339
340 CL_API_ENTRY void * CL_API_CALL clGetExtensionFunctionAddress(const char * func_name) CL_API_SUFFIX__VERSION_1_0 {
341   if( !_initialized )
342     _initClIcd();
343   if( func_name == NULL )
344     return NULL;
345   cl_uint suffix_length;
346   cl_uint i;
347   void * return_value=NULL;
348   struct func_desc const * fn=&function_description[0];
349   while (fn->name != NULL) {
350     if (strcmp(func_name, fn->name)==0)
351       return fn->addr;
352     fn++;
353   }
354   for(i=0; i<_num_picds; i++) {
355     suffix_length = strlen(_picds[i].extension_suffix);
356     if( suffix_length > strlen(func_name) )
357       continue;
358     if(strcmp(_picds[i].extension_suffix, &func_name[strlen(func_name)-suffix_length]) == 0)
359       return (*_picds[i].vicd->ext_fn_ptr)(func_name);
360   }
361   return return_value;
362 }
363 typeof(clGetExtensionFunctionAddress) clGetExtensionFunctionAddress_hid __attribute__ ((alias ("clGetExtensionFunctionAddress"), visibility("hidden")));
364
365 CL_API_ENTRY cl_int CL_API_CALL
366 clGetPlatformIDs(cl_uint          num_entries,
367                  cl_platform_id * platforms,
368                  cl_uint *        num_platforms) CL_API_SUFFIX__VERSION_1_0 {
369   if( !_initialized )
370     _initClIcd();
371   if( platforms == NULL && num_platforms == NULL )
372     return CL_INVALID_VALUE;
373   if( num_entries == 0 && platforms != NULL )
374     return CL_INVALID_VALUE;
375   if( _num_icds == 0)
376     return CL_PLATFORM_NOT_FOUND_KHR;
377
378   cl_uint i;
379   if( num_platforms != NULL ){
380     *num_platforms = _num_picds;
381   }
382   if( platforms != NULL ) {
383     cl_uint n_platforms = _num_picds < num_entries ? _num_picds : num_entries;
384     for( i=0; i<n_platforms; i++) {
385       *(platforms++) = _picds[i].pid;
386     }
387   }
388   return CL_SUCCESS;
389 }
390 typeof(clGetPlatformIDs) clGetPlatformIDs_hid __attribute__ ((alias ("clGetPlatformIDs"), visibility("hidden")));
391