Rename files and cleanup includes
[ocl-icd] / ocl_icd_loader.c
1 /**
2 Copyright (c) 2012, Brice Videau <brice.videau@imag.fr>
3 All rights reserved.
4       
5 Redistribution and use in source and binary forms, with or without
6 modification, are permitted provided that the following conditions are met:
7     
8 1. Redistributions of source code must retain the above copyright notice, this
9    list of conditions and the following disclaimer.
10 2. Redistributions in binary form must reproduce the above copyright notice,
11    this list of conditions and the following disclaimer in the documentation
12    and/or other materials provided with the distribution.
13         
14 THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
15 ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
16 WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
17 DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
18 ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
19 (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
20 LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
21 ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
22 (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
23 SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
24 */
25
26 #include <dirent.h>
27 #include <stdio.h>
28 #include <stdlib.h>
29 #include <string.h>
30 #include <dlfcn.h>
31 #include <CL/opencl.h>
32 #include "ocl_icd_loader.h"
33
34 typedef CL_API_ENTRY void * (CL_API_CALL *clGetExtensionFunctionAddress_fn)(const char * /* func_name */) CL_API_SUFFIX__VERSION_1_0;
35
36 static cl_uint _initialized = 0;
37 static cl_uint _num_valid_vendors = 0;
38 static cl_uint *_vendors_num_platforms = NULL;
39 static cl_platform_id **_vendors_platforms = NULL;
40 static void **_vendor_dl_handles = NULL;
41 static clGetExtensionFunctionAddress_fn *_ext_fn_ptr = NULL;
42 static char** _vendors_extension_suffixes = NULL;
43 static const char *_dir_path="/etc/OpenCL/vendors/";
44
45 static inline cl_uint _find_num_vendors(DIR *dir) {
46   cl_uint num_vendors = 0;
47   struct dirent *ent;
48   while( (ent=readdir(dir)) != NULL ){
49     if( strcmp(ent->d_name,".") == 0 || strcmp(ent->d_name,"..") == 0 )
50       continue;
51     cl_uint d_name_len = strlen(ent->d_name);
52     if( strcmp(ent->d_name + d_name_len - 4, ".icd" ) != 0 )
53       continue;
54 //    printf("%s%s\n", _dir_path, ent->d_name);
55     num_vendors++;
56   }
57   rewinddir(dir);
58   return num_vendors;
59 }
60
61 static inline cl_uint _open_drivers(DIR *dir) {
62   cl_uint num_vendors = 0;
63   struct dirent *ent;
64   while( (ent=readdir(dir)) != NULL ){
65     if( strcmp(ent->d_name,".") == 0 || strcmp(ent->d_name,"..") == 0 )
66       continue;
67     cl_uint d_name_len = strlen(ent->d_name);
68     if( strcmp(ent->d_name + d_name_len - 4, ".icd" ) != 0 )
69       continue;
70     char * lib_path;
71     char * err;
72     unsigned int lib_path_length = strlen(_dir_path) + strlen(ent->d_name) + 1;
73     lib_path = malloc(lib_path_length*sizeof(char));
74     sprintf(lib_path,"%s%s", _dir_path, ent->d_name);
75     FILE *f = fopen(lib_path,"r");
76     free(lib_path);
77
78     fseek(f, 0, SEEK_END);
79     lib_path_length = ftell(f)+1;
80     fseek(f, 0, SEEK_SET);
81     if(lib_path_length == 1) {
82       fclose(f);
83       continue;
84     }
85     lib_path = malloc(lib_path_length*sizeof(char));
86     err = fgets(lib_path, lib_path_length, f);
87     fclose(f);
88     if( err == NULL ) {
89       free(lib_path);
90       continue;
91     }
92
93     lib_path_length = strlen(lib_path);
94     
95     if( lib_path[lib_path_length-1] == '\n' )
96       lib_path[lib_path_length-1] = '\0';
97
98     _vendor_dl_handles[num_vendors] = dlopen(lib_path, RTLD_LAZY|RTLD_LOCAL);//|RTLD_DEEPBIND);
99     free(lib_path);
100     if(_vendor_dl_handles[num_vendors] != NULL)      
101       num_vendors++;
102   }
103   return num_vendors;
104 }
105
106 static inline void _find_and_check_platforms(cl_uint num_vendors) {
107   cl_uint i;
108   _num_valid_vendors = 0;
109   for( i=0; i<num_vendors; i++){
110     cl_uint num_valid_platforms=0;
111     cl_uint num_platforms=0;
112     cl_platform_id *platforms;
113     cl_int error;
114     _ext_fn_ptr[_num_valid_vendors] = dlsym(_vendor_dl_handles[i], "clGetExtensionFunctionAddress");
115     clIcdGetPlatformIDsKHR_fn plt_fn_ptr;
116     if( _ext_fn_ptr[_num_valid_vendors] == NULL )
117       continue;
118     plt_fn_ptr = (*_ext_fn_ptr[_num_valid_vendors])("clIcdGetPlatformIDsKHR");
119     if( plt_fn_ptr == NULL )
120       continue;
121     error = (*plt_fn_ptr)(0, NULL, &num_platforms);
122     if( error != CL_SUCCESS || num_platforms == 0)
123       continue;
124     platforms = (cl_platform_id *) malloc( sizeof(cl_platform_id) * num_platforms);
125     error = (*plt_fn_ptr)(num_platforms, platforms, NULL);
126     if( error != CL_SUCCESS ){
127       free(platforms);
128       continue;
129     }
130     _vendors_platforms[_num_valid_vendors] = (cl_platform_id *)malloc(num_platforms * sizeof(cl_platform_id));
131     cl_uint j;
132     for(j=0; j<num_platforms; j++) {
133       size_t param_value_size_ret;
134       error = ((struct _cl_platform_id *)platforms[j])->dispatch->clGetPlatformInfo(platforms[j], CL_PLATFORM_EXTENSIONS, 0, NULL, &param_value_size_ret);
135       if (error != CL_SUCCESS)
136         continue;
137       char *param_value = (char *)malloc(sizeof(char)*param_value_size_ret);
138       error = ((struct _cl_platform_id *)platforms[j])->dispatch->clGetPlatformInfo(platforms[j], CL_PLATFORM_EXTENSIONS, param_value_size_ret, param_value, NULL);
139       if (error != CL_SUCCESS){
140         free(param_value);
141         continue;
142       }
143       if( strstr(param_value, "cl_khr_icd") == NULL){
144         free(param_value);
145         continue;
146       }
147       free(param_value);
148       error = ((struct _cl_platform_id *)platforms[j])->dispatch->clGetPlatformInfo(platforms[j], CL_PLATFORM_ICD_SUFFIX_KHR, 0, NULL, &param_value_size_ret);
149       if (error != CL_SUCCESS)
150         continue;
151       param_value = (char *)malloc(sizeof(char)*param_value_size_ret);
152       error = ((struct _cl_platform_id *)platforms[j])->dispatch->clGetPlatformInfo(platforms[j], CL_PLATFORM_ICD_SUFFIX_KHR, param_value_size_ret, param_value, NULL);
153       if (error != CL_SUCCESS){
154         free(param_value);
155         continue;
156       }
157       _vendors_extension_suffixes[_num_valid_vendors] = param_value;
158       _vendors_platforms[_num_valid_vendors][num_valid_platforms] = platforms[j];
159       num_valid_platforms++;
160     }
161     if( num_valid_platforms != 0 ) {
162       _vendors_num_platforms[_num_valid_vendors] = num_valid_platforms;
163       _num_valid_vendors++;
164     } else {
165       free(_vendors_platforms[_num_valid_vendors]);
166       dlclose(_vendor_dl_handles[i]);
167     }
168     free(platforms);
169   }
170 }
171
172 static void _initClIcd( void ) {
173   if( _initialized )
174     return;
175   cl_uint num_vendors = 0;
176   DIR *dir;
177   dir = opendir(_dir_path);
178   if(dir == NULL) {
179     _num_valid_vendors = 0;
180     _initialized = 1;
181     return;
182   }
183
184   num_vendors = _find_num_vendors(dir);
185 //  printf("%d vendor(s)!\n", num_vendors);
186   if(num_vendors == 0) {
187     _num_valid_vendors = 0;
188     _initialized = 1;
189     return;
190   }
191
192   _vendor_dl_handles = (void **)malloc(num_vendors * sizeof(void *));
193   num_vendors = _open_drivers(dir);
194 //  printf("%d vendor(s)!\n", num_vendors);
195   if(num_vendors == 0) {
196     free( _vendor_dl_handles );
197     _num_valid_vendors = 0;
198     _initialized = 1;
199     return;
200   }
201
202   _ext_fn_ptr = (clGetExtensionFunctionAddress_fn *)malloc(num_vendors * sizeof(clGetExtensionFunctionAddress_fn));
203   _vendors_extension_suffixes = (char **) malloc (sizeof(char *) * num_vendors);
204   _vendors_num_platforms = (cl_uint *)malloc(num_vendors * sizeof(cl_uint));
205   _vendors_platforms = (cl_platform_id **)malloc(num_vendors * sizeof(cl_platform_id *));
206   _find_and_check_platforms(num_vendors);
207   if(_num_valid_vendors == 0){
208     free( _vendor_dl_handles );
209     free( _ext_fn_ptr );
210     free( _vendors_extension_suffixes );
211     free( _vendors_platforms );
212     free( _vendors_num_platforms );
213   }
214 //  printf("%d valid vendor(s)!\n", _num_valid_vendors);
215   _initialized = 1;
216 }
217
218 CL_API_ENTRY void * CL_API_CALL clGetExtensionFunctionAddress(const char * func_name) CL_API_SUFFIX__VERSION_1_0 {
219   if( !_initialized )
220     _initClIcd();
221   if( func_name == NULL )
222     return NULL;
223   cl_uint suffix_length;
224   cl_uint i;
225   void * return_value=NULL;
226   for(i=0; i<_num_valid_vendors; i++) {
227     suffix_length = strlen(_vendors_extension_suffixes[i]);
228     if( suffix_length > strlen(func_name) )
229       continue;
230     if(strcmp(_vendors_extension_suffixes[i], &func_name[strlen(func_name)-suffix_length]) == 0)
231       return (*_ext_fn_ptr[i])(func_name);
232   }
233   return return_value;
234 }
235
236 CL_API_ENTRY cl_int CL_API_CALL
237 clGetPlatformIDs(cl_uint          num_entries,
238                  cl_platform_id * platforms,
239                  cl_uint *        num_platforms) CL_API_SUFFIX__VERSION_1_0 {
240   if( !_initialized )
241     _initClIcd();
242   if( platforms == NULL && num_platforms == NULL )
243     return CL_INVALID_VALUE;
244   if( num_entries == 0 && platforms != NULL )
245     return CL_INVALID_VALUE;
246   if( _num_valid_vendors == 0)
247     return CL_PLATFORM_NOT_FOUND_KHR;
248
249   cl_uint i;
250   cl_uint n_platforms=0;
251   for(i=0; i<_num_valid_vendors; i++) {
252     n_platforms += _vendors_num_platforms[i];
253   }
254   if( num_platforms != NULL ){
255     *num_platforms = n_platforms;
256   }
257   if( platforms != NULL ) {
258     n_platforms = n_platforms < num_entries ? n_platforms : num_entries;
259     for( i=0; i<_num_valid_vendors; i++) {
260       cl_uint j;
261       for(j=0; j<_vendors_num_platforms[i]; j++) {
262         *(platforms++) = _vendors_platforms[i][j];
263         n_platforms--;
264         if( n_platforms == 0 )
265           return CL_SUCCESS;
266       }
267     }
268   }
269   return CL_SUCCESS;
270 }
271
272