Cleanup debug
[ocl-icd] / ocl_icd_loader.c
1 /**
2 Copyright (c) 2012, Brice Videau <brice.videau@imag.fr>
3 Copyright (c) 2012, Vincent Danjean <Vincent.Danjean@ens-lyon.org>
4 All rights reserved.
5       
6 Redistribution and use in source and binary forms, with or without
7 modification, are permitted provided that the following conditions are met:
8     
9 1. Redistributions of source code must retain the above copyright notice, this
10    list of conditions and the following disclaimer.
11 2. Redistributions in binary form must reproduce the above copyright notice,
12    this list of conditions and the following disclaimer in the documentation
13    and/or other materials provided with the distribution.
14         
15 THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
16 ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
17 WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
18 DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
19 ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
20 (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
21 LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
22 ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
23 (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
24 SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
25 */
26
27 #include <dirent.h>
28 #include <stdio.h>
29 #include <stdlib.h>
30 #include <string.h>
31 #include <dlfcn.h>
32 #define CL_USE_DEPRECATED_OPENCL_1_1_APIS
33 #include <CL/opencl.h>
34
35 #pragma GCC visibility push(hidden)
36
37 #include "ocl_icd_loader.h"
38 #include "ocl_icd_loader_debug.h"
39
40 int debug_ocl_icd_mask=0;
41
42 typedef __typeof__(clGetExtensionFunctionAddress) *clGetExtensionFunctionAddress_fn;
43 typedef __typeof__(clGetPlatformInfo) *clGetPlatformInfo_fn;
44
45
46 struct vendor_icd {
47   cl_uint       num_platforms;
48   cl_uint       first_platform;
49   void *        dl_handle;
50   clGetExtensionFunctionAddress_fn ext_fn_ptr;
51 };
52
53 struct platform_icd {
54   char *         extension_suffix;
55   struct vendor_icd *vicd;
56   cl_platform_id pid;
57 };
58
59 struct vendor_icd *_icds=NULL;
60 struct platform_icd *_picds=NULL;
61 static cl_uint _num_icds = 0;
62 static cl_uint _num_picds = 0;
63
64 static cl_uint _initialized = 0;
65
66 static const char *_dir_path="/etc/OpenCL/vendors/";
67
68 static inline cl_uint _find_num_icds(DIR *dir) {
69   cl_uint num_icds = 0;
70   struct dirent *ent;
71   while( (ent=readdir(dir)) != NULL ){
72     if( strcmp(ent->d_name,".") == 0 || strcmp(ent->d_name,"..") == 0 )
73       continue;
74     cl_uint d_name_len = strlen(ent->d_name);
75     if( d_name_len>4 && strcmp(ent->d_name + d_name_len - 4, ".icd" ) != 0 )
76       continue;
77 //    printf("%s%s\n", _dir_path, ent->d_name);
78     num_icds++;
79   }
80   rewinddir(dir);
81   RETURN(num_icds);
82 }
83
84 static inline cl_uint _open_drivers(DIR *dir) {
85   cl_uint num_icds = 0;
86   struct dirent *ent;
87   while( (ent=readdir(dir)) != NULL ){
88     if( strcmp(ent->d_name,".") == 0 || strcmp(ent->d_name,"..") == 0 )
89       continue;
90     cl_uint d_name_len = strlen(ent->d_name);
91     if( d_name_len>4 && strcmp(ent->d_name + d_name_len - 4, ".icd" ) != 0 )
92       continue;
93     char * lib_path;
94     char * err;
95     unsigned int lib_path_length = strlen(_dir_path) + strlen(ent->d_name) + 1;
96     lib_path = malloc(lib_path_length*sizeof(char));
97     sprintf(lib_path,"%s%s", _dir_path, ent->d_name);
98     FILE *f = fopen(lib_path,"r");
99     free(lib_path);
100
101     fseek(f, 0, SEEK_END);
102     lib_path_length = ftell(f)+1;
103     fseek(f, 0, SEEK_SET);
104     if(lib_path_length == 1) {
105       fclose(f);
106       continue;
107     }
108     lib_path = malloc(lib_path_length*sizeof(char));
109     err = fgets(lib_path, lib_path_length, f);
110     fclose(f);
111     if( err == NULL ) {
112       free(lib_path);
113       continue;
114     }
115
116     lib_path_length = strlen(lib_path);
117     
118     if( lib_path[lib_path_length-1] == '\n' )
119       lib_path[lib_path_length-1] = '\0';
120
121     _icds[num_icds].dl_handle = dlopen(lib_path, RTLD_LAZY|RTLD_LOCAL);//|RTLD_DEEPBIND);
122     if(_icds[num_icds].dl_handle != NULL) {
123       debug(D_LOG, "Loading ICD[%i] -> '%s'", num_icds, lib_path);
124       num_icds++;
125     }
126     free(lib_path);
127   }
128   RETURN(num_icds);
129 }
130
131 static void* _get_function_addr(void* dlh, clGetExtensionFunctionAddress_fn fn, const char*name) {
132   void *addr1;
133   addr1=dlsym(dlh, name);
134   if (addr1 == NULL) {
135     debug(D_WARN, "Missing global symbol '%s' in ICD, should be skipped", name);
136   }
137   void* addr2=NULL;
138   if (fn) {
139     addr2=(*fn)(name);
140     if (addr2 == NULL) {
141       debug(D_WARN, "Missing function '%s' in ICD, should be skipped", name);
142     }
143 #if DEBUG_OCL_ICD
144     if (addr1 && addr2 && addr1!=addr2) {
145       debug(D_WARN, "Function and symbol '%s' have different addresses!", name);
146     }
147 #endif
148   }
149   if (!addr2) addr2=addr1;
150   RETURN(addr2);
151 }
152
153 static int _allocate_platforms(int req) {
154   static cl_uint allocated=0;
155   debug(D_LOG,"Requesting allocation for %d platforms",req);
156   if (allocated - _num_picds < req) {
157     if (allocated==0) {
158       _picds=(struct platform_icd*)malloc(req*sizeof(struct platform_icd));
159     } else {
160       req = req - (allocated - _num_picds);
161       _picds=(struct platform_icd*)realloc(_picds, (allocated+req)*sizeof(struct platform_icd));
162     }
163     allocated += req;
164   }
165   RETURN(allocated - _num_picds);
166 }
167
168 static inline void _find_and_check_platforms(cl_uint num_icds) {
169   cl_uint i;
170   _num_icds = 0;
171   for( i=0; i<num_icds; i++){
172     debug(D_LOG, "Checking ICD %i", i);
173     struct vendor_icd *picd = &_icds[_num_icds];
174     void* dlh = _icds[i].dl_handle;
175     picd->ext_fn_ptr = _get_function_addr(dlh, NULL, "clGetExtensionFunctionAddress");
176     clIcdGetPlatformIDsKHR_fn plt_fn_ptr = 
177       _get_function_addr(dlh, picd->ext_fn_ptr, "clIcdGetPlatformIDsKHR");
178     clGetPlatformInfo_fn plt_info_ptr = 
179       _get_function_addr(dlh, picd->ext_fn_ptr, "clGetPlatformInfo");
180     if( picd->ext_fn_ptr == NULL
181         || plt_fn_ptr == NULL
182         || plt_info_ptr == NULL) {
183       debug(D_WARN, "Missing symbols in ICD, skipping it");
184       continue;
185     }
186     cl_uint num_platforms=0;
187     cl_int error;
188     error = (*plt_fn_ptr)(0, NULL, &num_platforms);
189     if( error != CL_SUCCESS || num_platforms == 0) {
190       debug(D_LOG, "No platform in ICD, skipping it");
191       continue;
192     }
193     cl_platform_id *platforms = (cl_platform_id *) malloc( sizeof(cl_platform_id) * num_platforms);
194     error = (*plt_fn_ptr)(num_platforms, platforms, NULL);
195     if( error != CL_SUCCESS ){
196       free(platforms);
197       debug(D_WARN, "Error in loading ICD platforms, skipping ICD");
198       continue;
199     }
200     cl_uint num_valid_platforms=0;
201     cl_uint j;
202     debug(D_LOG, "Try to load %d plateforms", num_platforms);
203     if (_allocate_platforms(num_platforms) < num_platforms) {
204       free(platforms);
205       debug(D_WARN, "Not enought platform allocated. Skipping ICD");
206       continue;
207     }
208     for(j=0; j<num_platforms; j++) {
209       debug(D_LOG, "Checking platform %i", j);
210       size_t param_value_size_ret;
211       struct platform_icd *p=&_picds[_num_picds];
212       p->extension_suffix=NULL;
213       p->vicd=&_icds[i];
214       p->pid=platforms[j];
215 #if DEBUG_OCL_ICD
216       if (debug_ocl_icd_mask & D_DUMP) {
217         dump_platform(p->pid);
218       }
219 #endif
220       error = plt_info_ptr(p->pid, CL_PLATFORM_EXTENSIONS, 0, NULL, &param_value_size_ret);
221       if (error != CL_SUCCESS) {
222         debug(D_WARN, "Error while loading extensions in platform %i, skipping it",j);
223         continue;
224       }
225       char *param_value = (char *)malloc(sizeof(char)*param_value_size_ret);
226       error = plt_info_ptr(p->pid, CL_PLATFORM_EXTENSIONS, param_value_size_ret, param_value, NULL);
227       if (error != CL_SUCCESS){
228         free(param_value);
229         debug(D_WARN, "Error while loading extensions in platform %i, skipping it", j);
230         continue;
231       }
232       debug(D_DUMP, "Supported extensions: %s", param_value);
233       if( strstr(param_value, "cl_khr_icd") == NULL){
234         free(param_value);
235         debug(D_WARN, "Missing khr extension in platform %i, skipping it", j);
236         continue;
237       }
238       free(param_value);
239       error = plt_info_ptr(p->pid, CL_PLATFORM_ICD_SUFFIX_KHR, 0, NULL, &param_value_size_ret);
240       if (error != CL_SUCCESS) {
241         debug(D_WARN, "Error while loading suffix in platform %i, skipping it", j);
242         continue;
243       }
244       param_value = (char *)malloc(sizeof(char)*param_value_size_ret);
245       error = plt_info_ptr(p->pid, CL_PLATFORM_ICD_SUFFIX_KHR, param_value_size_ret, param_value, NULL);
246       if (error != CL_SUCCESS){
247         debug(D_WARN, "Error while loading suffix in platform %i, skipping it", j);
248         free(param_value);
249         continue;
250       }
251       p->extension_suffix = param_value;
252       debug(D_LOG, "Extension suffix: %s", param_value);
253       num_valid_platforms++;
254       _num_picds++;
255     }
256     if( num_valid_platforms != 0 ) {
257       if ( _num_icds != i ) {
258         picd->dl_handle = dlh;
259       }
260       _num_icds++;
261       picd->num_platforms = num_valid_platforms;
262       _icds[i].first_platform = _num_picds - num_valid_platforms;
263     } else {
264       dlclose(dlh);
265     }
266     free(platforms);
267   }
268 }
269
270 static void _initClIcd( void ) {
271   if( _initialized )
272     return;
273 #if DEBUG_OCL_ICD
274   char *debug=getenv("OCL_ICD_DEBUG");
275   if (debug) {
276     debug_ocl_icd_mask=atoi(debug);
277     if (debug_ocl_icd_mask==0)
278       debug_ocl_icd_mask=1;
279   }
280 #endif
281   cl_uint num_icds = 0;
282   DIR *dir;
283   dir = opendir(_dir_path);
284   if(dir == NULL) {
285     goto abort;
286   }
287
288   num_icds = _find_num_icds(dir);
289   if(num_icds == 0) {
290     goto abort;
291   }
292
293   _icds = (struct vendor_icd*)malloc(num_icds * sizeof(struct vendor_icd));
294   if (_icds == NULL) {
295     goto abort;
296   }
297   
298   num_icds = _open_drivers(dir);
299   if(num_icds == 0) {
300     goto abort;
301   }
302
303   _find_and_check_platforms(num_icds);
304   if(_num_icds == 0){
305     goto abort;
306   }
307
308   if (_num_icds < num_icds) {
309     _icds = (struct vendor_icd*)realloc(_icds, _num_icds * sizeof(struct vendor_icd));
310   }
311   debug(D_WARN, "%d valid vendor(s)!", _num_icds);
312   _initialized = 1;
313   return;
314 abort:
315   _num_icds = 0;
316   _initialized = 1;
317   if (_icds) {
318     free(_icds);
319     _icds = NULL;
320   }
321   return;
322 }
323
324 #pragma GCC visibility pop
325
326 CL_API_ENTRY void * CL_API_CALL clGetExtensionFunctionAddress(const char * func_name) CL_API_SUFFIX__VERSION_1_0 {
327   if( !_initialized )
328     _initClIcd();
329   if( func_name == NULL )
330     return NULL;
331   cl_uint suffix_length;
332   cl_uint i;
333   void * return_value=NULL;
334   struct func_desc const * fn=&function_description[0];
335   while (fn->name != NULL) {
336     if (strcmp(func_name, fn->name)==0)
337       return fn->addr;
338     fn++;
339   }
340   for(i=0; i<_num_picds; i++) {
341     suffix_length = strlen(_picds[i].extension_suffix);
342     if( suffix_length > strlen(func_name) )
343       continue;
344     if(strcmp(_picds[i].extension_suffix, &func_name[strlen(func_name)-suffix_length]) == 0)
345       return (*_picds[i].vicd->ext_fn_ptr)(func_name);
346   }
347   return return_value;
348 }
349 typeof(clGetExtensionFunctionAddress) clGetExtensionFunctionAddress_hid __attribute__ ((alias ("clGetExtensionFunctionAddress"), visibility("hidden")));
350
351 CL_API_ENTRY cl_int CL_API_CALL
352 clGetPlatformIDs(cl_uint          num_entries,
353                  cl_platform_id * platforms,
354                  cl_uint *        num_platforms) CL_API_SUFFIX__VERSION_1_0 {
355   if( !_initialized )
356     _initClIcd();
357   if( platforms == NULL && num_platforms == NULL )
358     return CL_INVALID_VALUE;
359   if( num_entries == 0 && platforms != NULL )
360     return CL_INVALID_VALUE;
361   if( _num_icds == 0)
362     return CL_PLATFORM_NOT_FOUND_KHR;
363
364   cl_uint i;
365   if( num_platforms != NULL ){
366     *num_platforms = _num_picds;
367   }
368   if( platforms != NULL ) {
369     cl_uint n_platforms = _num_picds < num_entries ? _num_picds : num_entries;
370     for( i=0; i<n_platforms; i++) {
371       *(platforms++) = _picds[i].pid;
372     }
373   }
374   return CL_SUCCESS;
375 }
376 typeof(clGetPlatformIDs) clGetPlatformIDs_hid __attribute__ ((alias ("clGetPlatformIDs"), visibility("hidden")));
377