hook-sqlalchemy.py 3.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980
  1. #-----------------------------------------------------------------------------
  2. # Copyright (c) 2005-2021, PyInstaller Development Team.
  3. #
  4. # Distributed under the terms of the GNU General Public License (version 2
  5. # or later) with exception for distributing the bootloader.
  6. #
  7. # The full license is in the file COPYING.txt, distributed with this software.
  8. #
  9. # SPDX-License-Identifier: (GPL-2.0-or-later WITH Bootloader-exception)
  10. #-----------------------------------------------------------------------------
  11. import re
  12. from PyInstaller.utils.hooks import (
  13. exec_statement, is_module_satisfies, logger)
  14. from PyInstaller.lib.modulegraph.modulegraph import SourceModule
  15. from PyInstaller.lib.modulegraph.util import guess_encoding
  16. # 'sqlalchemy.testing' causes bundling a lot of unnecessary modules.
  17. excludedimports = ['sqlalchemy.testing']
  18. # include most common database bindings
  19. # some database bindings are detected and include some
  20. # are not. We should explicitly include database backends.
  21. hiddenimports = ['pysqlite2', 'MySQLdb', 'psycopg2', 'sqlalchemy.ext.baked']
  22. if is_module_satisfies('sqlalchemy >= 1.4'):
  23. hiddenimports.append("sqlalchemy.sql.default_comparator")
  24. # In SQLAlchemy >= 0.6, the "sqlalchemy.dialects" package provides dialects.
  25. if is_module_satisfies('sqlalchemy >= 0.6'):
  26. dialects = exec_statement("import sqlalchemy.dialects;print(sqlalchemy.dialects.__all__)")
  27. dialects = eval(dialects.strip())
  28. for n in dialects:
  29. hiddenimports.append("sqlalchemy.dialects." + n)
  30. # In SQLAlchemy <= 0.5, the "sqlalchemy.databases" package provides dialects.
  31. else:
  32. databases = exec_statement("import sqlalchemy.databases; print(sqlalchemy.databases.__all__)")
  33. databases = eval(databases.strip())
  34. for n in databases:
  35. hiddenimports.append("sqlalchemy.databases." + n)
  36. def hook(hook_api):
  37. """
  38. SQLAlchemy 0.9 introduced the decorator 'util.dependencies'. This
  39. decorator does imports. eg:
  40. @util.dependencies("sqlalchemy.sql.schema")
  41. This hook scans for included SQLAlchemy modules and then scans those modules
  42. for any util.dependencies and marks those modules as hidden imports.
  43. """
  44. if not is_module_satisfies('sqlalchemy >= 0.9'):
  45. return
  46. # this parser is very simplistic but seems to catch all cases as of V1.1
  47. depend_regex = re.compile(r'@util.dependencies\([\'"](.*?)[\'"]\)')
  48. hidden_imports_set = set()
  49. known_imports = set()
  50. for node in hook_api.module_graph.iter_graph(start=hook_api.module):
  51. if isinstance(node, SourceModule) and \
  52. node.identifier.startswith('sqlalchemy.'):
  53. known_imports.add(node.identifier)
  54. # Determine the encoding of the source file.
  55. with open(node.filename, 'rb') as f:
  56. encoding = guess_encoding(f)
  57. # Use that to open the file.
  58. with open(node.filename, 'r', encoding=encoding) as f:
  59. for match in depend_regex.findall(f.read()):
  60. hidden_imports_set.add(match)
  61. hidden_imports_set -= known_imports
  62. if len(hidden_imports_set):
  63. logger.info(" Found %d sqlalchemy hidden imports",
  64. len(hidden_imports_set))
  65. hook_api.add_imports(*list(hidden_imports_set))