Update the address of the Free Software Foundation.
[wine] / dlls / ntdll / tests / port.c
1 /* Unit test suite for Ntdll Port API functions
2  *
3  * Copyright 2006 James Hawkins
4  *
5  * This library is free software; you can redistribute it and/or
6  * modify it under the terms of the GNU Lesser General Public
7  * License as published by the Free Software Foundation; either
8  * version 2.1 of the License, or (at your option) any later version.
9  *
10  * This library is distributed in the hope that it will be useful,
11  * but WITHOUT ANY WARRANTY; without even the implied warranty of
12  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
13  * Lesser General Public License for more details.
14  *
15  * You should have received a copy of the GNU Lesser General Public
16  * License along with this library; if not, write to the Free Software
17  * Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301, USA
18  */
19
20 #include <stdio.h>
21 #include <stdarg.h>
22
23 #include "ntstatus.h"
24 #define WIN32_NO_STATUS
25 #include "windef.h"
26 #include "winbase.h"
27 #include "winuser.h"
28 #include "winreg.h"
29 #include "winnls.h"
30 #include "wine/test.h"
31 #include "wine/unicode.h"
32 #include "winternl.h"
33
34 #ifndef __WINE_WINTERNL_H
35
36 typedef struct _CLIENT_ID
37 {
38    HANDLE UniqueProcess;
39    HANDLE UniqueThread;
40 } CLIENT_ID, *PCLIENT_ID;
41
42 typedef struct _LPC_SECTION_WRITE
43 {
44   ULONG Length;
45   HANDLE SectionHandle;
46   ULONG SectionOffset;
47   ULONG ViewSize;
48   PVOID ViewBase;
49   PVOID TargetViewBase;
50 } LPC_SECTION_WRITE, *PLPC_SECTION_WRITE;
51
52 typedef struct _LPC_SECTION_READ
53 {
54   ULONG Length;
55   ULONG ViewSize;
56   PVOID ViewBase;
57 } LPC_SECTION_READ, *PLPC_SECTION_READ;
58
59 typedef struct _LPC_MESSAGE
60 {
61   USHORT DataSize;
62   USHORT MessageSize;
63   USHORT MessageType;
64   USHORT VirtualRangesOffset;
65   CLIENT_ID ClientId;
66   ULONG MessageId;
67   ULONG SectionSize;
68   UCHAR Data[ANYSIZE_ARRAY];
69 } LPC_MESSAGE, *PLPC_MESSAGE;
70
71 #endif
72
73 /* Types of LPC messages */
74 #define UNUSED_MSG_TYPE                 0
75 #define LPC_REQUEST                     1
76 #define LPC_REPLY                       2
77 #define LPC_DATAGRAM                    3
78 #define LPC_LOST_REPLY                  4
79 #define LPC_PORT_CLOSED                 5
80 #define LPC_CLIENT_DIED                 6
81 #define LPC_EXCEPTION                   7
82 #define LPC_DEBUG_EVENT                 8
83 #define LPC_ERROR_EVENT                 9
84 #define LPC_CONNECTION_REQUEST         10
85
86 static const WCHAR PORTNAME[] = {'\\','M','y','P','o','r','t',0};
87
88 #define REQUEST1    "Request1"
89 #define REQUEST2    "Request2"
90 #define REPLY       "Reply"
91
92 #define MAX_MESSAGE_LEN    30
93
94 UNICODE_STRING  port;
95 static char     selfname[MAX_PATH];
96 static int      myARGC;
97 static char**   myARGV;
98
99 /* Function pointers for ntdll calls */
100 static HMODULE hntdll = 0;
101 static NTSTATUS (WINAPI *pNtCompleteConnectPort)(HANDLE);
102 static NTSTATUS (WINAPI *pNtAcceptConnectPort)(PHANDLE,ULONG,PLPC_MESSAGE,ULONG,
103                                                ULONG,PLPC_SECTION_READ);
104 static NTSTATUS (WINAPI *pNtReplyPort)(HANDLE,PLPC_MESSAGE);
105 static NTSTATUS (WINAPI *pNtReplyWaitReceivePort)(PHANDLE,PULONG,PLPC_MESSAGE,
106                                                   PLPC_MESSAGE);
107 static NTSTATUS (WINAPI *pNtCreatePort)(PHANDLE,POBJECT_ATTRIBUTES,ULONG,ULONG,ULONG);
108 static NTSTATUS (WINAPI *pNtRequestWaitReplyPort)(HANDLE,PLPC_MESSAGE,PLPC_MESSAGE);
109 static NTSTATUS (WINAPI *pNtRequestPort)(HANDLE,PLPC_MESSAGE);
110 static NTSTATUS (WINAPI *pNtRegisterThreadTerminatePort)(HANDLE);
111 static NTSTATUS (WINAPI *pNtConnectPort)(PHANDLE,PUNICODE_STRING,
112                                          PSECURITY_QUALITY_OF_SERVICE,
113                                          PLPC_SECTION_WRITE,PLPC_SECTION_READ,
114                                          PVOID,PVOID,PULONG);
115 static NTSTATUS (WINAPI *pRtlInitUnicodeString)(PUNICODE_STRING,LPCWSTR);
116 static NTSTATUS (WINAPI *pNtWaitForSingleObject)(HANDLE,BOOLEAN,PLARGE_INTEGER);
117
118 static BOOL init_function_ptrs(void)
119 {
120     hntdll = LoadLibraryA("ntdll.dll");
121
122     if (hntdll)
123     {
124         pNtCompleteConnectPort = (void *)GetProcAddress(hntdll, "NtCompleteConnectPort");
125         pNtAcceptConnectPort = (void *)GetProcAddress(hntdll, "NtAcceptConnectPort");
126         pNtReplyPort = (void *)GetProcAddress(hntdll, "NtReplyPort");
127         pNtReplyWaitReceivePort = (void *)GetProcAddress(hntdll, "NtReplyWaitReceivePort");
128         pNtCreatePort = (void *)GetProcAddress(hntdll, "NtCreatePort");
129         pNtRequestWaitReplyPort = (void *)GetProcAddress(hntdll, "NtRequestWaitReplyPort");
130         pNtRequestPort = (void *)GetProcAddress(hntdll, "NtRequestPort");
131         pNtRegisterThreadTerminatePort = (void *)GetProcAddress(hntdll, "NtRegisterThreadTerminatePort");
132         pNtConnectPort = (void *)GetProcAddress(hntdll, "NtConnectPort");
133         pRtlInitUnicodeString = (void *)GetProcAddress(hntdll, "RtlInitUnicodeString");
134         pNtWaitForSingleObject = (void *)GetProcAddress(hntdll, "NtWaitForSingleObject");
135     }
136
137     if (!pNtCompleteConnectPort || !pNtAcceptConnectPort ||
138         !pNtReplyWaitReceivePort || !pNtCreatePort || !pNtRequestWaitReplyPort ||
139         !pNtRequestPort || !pNtRegisterThreadTerminatePort ||
140         !pNtConnectPort || !pRtlInitUnicodeString)
141     {
142         return FALSE;
143     }
144
145     return TRUE;
146 }
147
148 static void ProcessConnectionRequest(PLPC_MESSAGE LpcMessage, PHANDLE pAcceptPortHandle)
149 {
150     NTSTATUS status;
151
152     ok(LpcMessage->MessageType == LPC_CONNECTION_REQUEST,
153        "Expected LPC_CONNECTION_REQUEST, got %d\n", LpcMessage->MessageType);
154     ok(!*LpcMessage->Data, "Expected empty string!\n");
155
156     status = pNtAcceptConnectPort(pAcceptPortHandle, 0, LpcMessage, 1, 0, NULL);
157     ok(status == STATUS_SUCCESS, "Expected STATUS_SUCCESS, got %ld\n", status);
158     
159     status = pNtCompleteConnectPort(*pAcceptPortHandle);
160     ok(status == STATUS_SUCCESS, "Expected STATUS_SUCCESS, got %ld\n", status);
161 }
162
163 static void ProcessLpcRequest(HANDLE PortHandle, PLPC_MESSAGE LpcMessage)
164 {
165     NTSTATUS status;
166
167     ok(LpcMessage->MessageType == LPC_REQUEST,
168        "Expected LPC_REQUEST, got %d\n", LpcMessage->MessageType);
169     ok(!lstrcmp((LPSTR)LpcMessage->Data, REQUEST2),
170        "Expected %s, got %s\n", REQUEST2, LpcMessage->Data);
171
172     lstrcpy((LPSTR)LpcMessage->Data, REPLY);
173
174     status = pNtReplyPort(PortHandle, LpcMessage);
175     ok(status == STATUS_SUCCESS, "Expected STATUS_SUCCESS, got %ld\n", status);
176     ok(LpcMessage->MessageType == LPC_REQUEST,
177        "Expected LPC_REQUEST, got %d\n", LpcMessage->MessageType);
178     ok(!lstrcmp((LPSTR)LpcMessage->Data, REPLY),
179        "Expected %s, got %s\n", REPLY, LpcMessage->Data);
180 }
181
182 static DWORD WINAPI test_ports_client(LPVOID arg)
183 {
184     SECURITY_QUALITY_OF_SERVICE sqos;
185     LPC_MESSAGE *LpcMessage, *out;
186     HANDLE PortHandle;
187     ULONG len, size;
188     NTSTATUS status;
189
190     sqos.Length = sizeof(SECURITY_QUALITY_OF_SERVICE);
191     sqos.ImpersonationLevel = SecurityImpersonation;
192     sqos.ContextTrackingMode = SECURITY_STATIC_TRACKING;
193     sqos.EffectiveOnly = TRUE;
194
195     status = pNtConnectPort(&PortHandle, &port, &sqos, 0, 0, &len, NULL, NULL);
196     ok(status == STATUS_SUCCESS, "Expected STATUS_SUCCESS, got %ld\n", status);
197
198     status = pNtRegisterThreadTerminatePort(PortHandle);
199     ok(status == STATUS_SUCCESS, "Expected STATUS_SUCCESS, got %ld\n", status);
200
201     size = FIELD_OFFSET(LPC_MESSAGE, Data) + MAX_MESSAGE_LEN;
202     LpcMessage = HeapAlloc(GetProcessHeap(), 0, size);
203     out = HeapAlloc(GetProcessHeap(), 0, size);
204
205     memset(LpcMessage, 0, size);
206     LpcMessage->DataSize = lstrlen(REQUEST1) + 1;
207     LpcMessage->MessageSize = FIELD_OFFSET(LPC_MESSAGE, Data) + LpcMessage->DataSize;
208     lstrcpy((LPSTR)LpcMessage->Data, REQUEST1);
209
210     status = pNtRequestPort(PortHandle, LpcMessage);
211     ok(status == STATUS_SUCCESS, "Expected STATUS_SUCCESS, got %ld\n", status);
212     ok(LpcMessage->MessageType == 0, "Expected 0, got %d\n", LpcMessage->MessageType);
213     ok(!lstrcmp((LPSTR)LpcMessage->Data, REQUEST1),
214        "Expected %s, got %s\n", REQUEST1, LpcMessage->Data);
215
216     /* Fill in the message */
217     memset(LpcMessage, 0, size);
218     LpcMessage->DataSize = lstrlen(REQUEST2) + 1;
219     LpcMessage->MessageSize = FIELD_OFFSET(LPC_MESSAGE, Data) + LpcMessage->DataSize;
220     lstrcpy((LPSTR)LpcMessage->Data, REQUEST2);
221
222     /* Send the message and wait for the reply */
223     status = pNtRequestWaitReplyPort(PortHandle, LpcMessage, out);
224     ok(status == STATUS_SUCCESS, "Expected STATUS_SUCCESS, got %ld\n", status);
225     ok(!lstrcmp((LPSTR)out->Data, REPLY), "Expected %s, got %s\n", REPLY, out->Data);
226     ok(out->MessageType == LPC_REPLY, "Expected LPC_REPLY, got %d\n", out->MessageType);
227
228     return 0;
229 }
230
231 static void test_ports_server(void)
232 {
233     OBJECT_ATTRIBUTES obj;
234     HANDLE PortHandle;
235     HANDLE AcceptPortHandle;
236     PLPC_MESSAGE LpcMessage;
237     ULONG size;
238     NTSTATUS status;
239     BOOL done = FALSE;
240
241     pRtlInitUnicodeString(&port, PORTNAME);
242
243     memset(&obj, 0, sizeof(OBJECT_ATTRIBUTES));
244     obj.Length = sizeof(OBJECT_ATTRIBUTES);
245     obj.ObjectName = &port;
246
247     status = pNtCreatePort(&PortHandle, &obj, 100, 100, 0);
248     todo_wine
249     {
250         ok(status == STATUS_SUCCESS, "Expected STATUS_SUCCESS, got %ld\n", status);
251     }
252
253     size = FIELD_OFFSET(LPC_MESSAGE, Data) + MAX_MESSAGE_LEN;
254     LpcMessage = HeapAlloc(GetProcessHeap(), 0, size);
255     memset(LpcMessage, 0, size);
256
257     while (TRUE)
258     {
259         status = pNtReplyWaitReceivePort(PortHandle, NULL, NULL, LpcMessage);
260         todo_wine
261         {
262             ok(status == STATUS_SUCCESS, "Expected STATUS_SUCCESS, got %ld(%lx)\n", status, status);
263         }
264         /* STATUS_INVALID_HANDLE: win2k without admin rights will perform an
265          *                        endless loop here
266          */
267         if ((status == STATUS_NOT_IMPLEMENTED) ||
268             (status == STATUS_INVALID_HANDLE)) return;
269
270         switch (LpcMessage->MessageType)
271         {
272             case LPC_CONNECTION_REQUEST:
273                 ProcessConnectionRequest(LpcMessage, &AcceptPortHandle);
274                 break;
275
276             case LPC_REQUEST:
277                 ProcessLpcRequest(PortHandle, LpcMessage);
278                 done = TRUE;
279                 break;
280
281             case LPC_DATAGRAM:
282                 ok(!lstrcmp((LPSTR)LpcMessage->Data, REQUEST1),
283                    "Expected %s, got %s\n", REQUEST1, LpcMessage->Data);
284                 break;
285
286             case LPC_CLIENT_DIED:
287                 ok(done, "Expected LPC request to be completed!\n");
288                 return;
289
290             default:
291                 ok(FALSE, "Unexpected message: %d\n", LpcMessage->MessageType);
292                 break;
293         }
294     }
295 }
296
297 START_TEST(port)
298 {
299     HANDLE thread;
300     DWORD id;
301
302     if (!init_function_ptrs())
303         return;
304
305     myARGC = winetest_get_mainargs(&myARGV);
306     strcpy(selfname, myARGV[0]);
307
308     thread = CreateThread(NULL, 0, test_ports_client, NULL, 0, &id);
309     ok(thread != NULL, "Expected non-NULL thread handle!\n");
310
311     test_ports_server();
312     CloseHandle(thread);
313 }