threaded_extension.py 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171
  1. """An ISAPI extension base class implemented using a thread-pool."""
  2. # $Id$
  3. import sys
  4. import time
  5. from isapi import isapicon, ExtensionError
  6. import isapi.simple
  7. from win32file import GetQueuedCompletionStatus, CreateIoCompletionPort, \
  8. PostQueuedCompletionStatus, CloseHandle
  9. from win32security import SetThreadToken
  10. from win32event import INFINITE
  11. from pywintypes import OVERLAPPED
  12. import threading
  13. import traceback
  14. ISAPI_REQUEST = 1
  15. ISAPI_SHUTDOWN = 2
  16. class WorkerThread(threading.Thread):
  17. def __init__(self, extension, io_req_port):
  18. self.running = False
  19. self.io_req_port = io_req_port
  20. self.extension = extension
  21. threading.Thread.__init__(self)
  22. # We wait 15 seconds for a thread to terminate, but if it fails to,
  23. # we don't want the process to hang at exit waiting for it...
  24. self.setDaemon(True)
  25. def run(self):
  26. self.running = True
  27. while self.running:
  28. errCode, bytes, key, overlapped = \
  29. GetQueuedCompletionStatus(self.io_req_port, INFINITE)
  30. if key == ISAPI_SHUTDOWN and overlapped is None:
  31. break
  32. # Let the parent extension handle the command.
  33. dispatcher = self.extension.dispatch_map.get(key)
  34. if dispatcher is None:
  35. raise RuntimeError("Bad request '%s'" % (key,))
  36. dispatcher(errCode, bytes, key, overlapped)
  37. def call_handler(self, cblock):
  38. self.extension.Dispatch(cblock)
# A generic thread-pool based extension, using IO Completion Ports.
# Sub-classes can override one method to implement a simple extension, or
# may leverage the CompletionPort to queue their own requests, and implement a
# fully asynchronous extension.
  43. class ThreadPoolExtension(isapi.simple.SimpleExtension):
  44. "Base class for an ISAPI extension based around a thread-pool"
  45. max_workers = 20
  46. worker_shutdown_wait = 15000 # 15 seconds for workers to quit...
  47. def __init__(self):
  48. self.workers = []
  49. # extensible dispatch map, for sub-classes that need to post their
  50. # own requests to the completion port.
  51. # Each of these functions is called with the result of
  52. # GetQueuedCompletionStatus for our port.
  53. self.dispatch_map = {
  54. ISAPI_REQUEST: self.DispatchConnection,
  55. }
  56. def GetExtensionVersion(self, vi):
  57. isapi.simple.SimpleExtension.GetExtensionVersion(self, vi)
  58. # As per Q192800, the CompletionPort should be created with the number
  59. # of processors, even if the number of worker threads is much larger.
  60. # Passing 0 means the system picks the number.
  61. self.io_req_port = CreateIoCompletionPort(-1, None, 0, 0)
  62. # start up the workers
  63. self.workers = []
  64. for i in range(self.max_workers):
  65. worker = WorkerThread(self, self.io_req_port)
  66. worker.start()
  67. self.workers.append(worker)
  68. def HttpExtensionProc(self, control_block):
  69. overlapped = OVERLAPPED()
  70. overlapped.object = control_block
  71. PostQueuedCompletionStatus(self.io_req_port, 0, ISAPI_REQUEST, overlapped)
  72. return isapicon.HSE_STATUS_PENDING
  73. def TerminateExtension(self, status):
  74. for worker in self.workers:
  75. worker.running = False
  76. for worker in self.workers:
  77. PostQueuedCompletionStatus(self.io_req_port, 0, ISAPI_SHUTDOWN, None)
  78. # wait for them to terminate - pity we aren't using 'native' threads
  79. # as then we could do a smart wait - but now we need to poll....
  80. end_time = time.time() + self.worker_shutdown_wait/1000
  81. alive = self.workers
  82. while alive:
  83. if time.time() > end_time:
  84. # xxx - might be nice to log something here.
  85. break
  86. time.sleep(0.2)
  87. alive = [w for w in alive if w.is_alive()]
  88. self.dispatch_map = {} # break circles
  89. CloseHandle(self.io_req_port)
  90. # This is the one operation the base class supports - a simple
  91. # Connection request. We setup the thread-token, and dispatch to the
  92. # sub-class's 'Dispatch' method.
  93. def DispatchConnection(self, errCode, bytes, key, overlapped):
  94. control_block = overlapped.object
  95. # setup the correct user for this request
  96. hRequestToken = control_block.GetImpersonationToken()
  97. SetThreadToken(None, hRequestToken)
  98. try:
  99. try:
  100. self.Dispatch(control_block)
  101. except:
  102. self.HandleDispatchError(control_block)
  103. finally:
  104. # reset the security context
  105. SetThreadToken(None, None)
  106. def Dispatch(self, ecb):
  107. """Overridden by the sub-class to handle connection requests.
  108. This class creates a thread-pool using a Windows completion port,
  109. and dispatches requests via this port. Sub-classes can generally
  110. implement each connection request using blocking reads and writes, and
  111. the thread-pool will still provide decent response to the end user.
  112. The sub-class can set a max_workers attribute (default is 20). Note
  113. that this generally does *not* mean 20 threads will all be concurrently
  114. running, via the magic of Windows completion ports.
  115. There is no default implementation - sub-classes must implement this.
  116. """
  117. raise NotImplementedError("sub-classes should override Dispatch")
  118. def HandleDispatchError(self, ecb):
  119. """Handles errors in the Dispatch method.
  120. When a Dispatch method call fails, this method is called to handle
  121. the exception. The default implementation formats the traceback
  122. in the browser.
  123. """
  124. ecb.HttpStatusCode = isapicon.HSE_STATUS_ERROR
  125. #control_block.LogData = "we failed!"
  126. exc_typ, exc_val, exc_tb = sys.exc_info()
  127. limit = None
  128. try:
  129. try:
  130. import cgi
  131. ecb.SendResponseHeaders("200 OK", "Content-type: text/html\r\n\r\n",
  132. False)
  133. print(file=ecb)
  134. print("<H3>Traceback (most recent call last):</H3>", file=ecb)
  135. list = traceback.format_tb(exc_tb, limit) + \
  136. traceback.format_exception_only(exc_typ, exc_val)
  137. print("<PRE>%s<B>%s</B></PRE>" % (
  138. cgi.escape("".join(list[:-1])), cgi.escape(list[-1]),), file=ecb)
  139. except ExtensionError:
  140. # The client disconnected without reading the error body -
  141. # its probably not a real browser at the other end, ignore it.
  142. pass
  143. except:
  144. print("FAILED to render the error message!")
  145. traceback.print_exc()
  146. print("ORIGINAL extension error:")
  147. traceback.print_exception(exc_typ, exc_val, exc_tb)
  148. finally:
  149. # holding tracebacks in a local of a frame that may itself be
  150. # part of a traceback used to be evil and cause leaks!
  151. exc_tb = None
  152. ecb.DoneWithSession()