test-db.ts 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157
  1. import { DataSource, EntityManager, Repository } from 'typeorm';
  2. import { vi, beforeEach, afterEach } from 'vitest';
  /**
   * Builds a plain-object stand-in for a TypeORM `DataSource`, backed by Vitest mocks.
   *
   * Behaviour of the returned object:
   * - `initialize`, `destroy`, `synchronize` and `dropDatabase` resolve to `undefined`.
   * - `isInitialized` is fixed to `true`.
   * - `manager` is one mock entity manager, created once per mock data source.
   * - `getRepository` builds a fresh mock repository on every call.
   * - `createQueryBuilder` returns the same mock query builder on every call.
   * - `transaction` invokes the supplied callback with `manager` and resolves to
   *   whatever the callback returns. Both call shapes are accepted:
   *   `transaction(callback)` and `transaction(isolationLevel, callback)`.
   *
   * @returns The mock data source object.
   */
  export function createMockDataSource() {
    const manager = createMockEntityManager();

    const mockDataSource = {
      initialize: vi.fn().mockResolvedValue(undefined),
      destroy: vi.fn().mockResolvedValue(undefined),
      isInitialized: true,
      manager,
      getRepository: vi.fn().mockImplementation(() => createMockRepository()),
      createQueryBuilder: vi.fn().mockReturnValue(createMockQueryBuilder()),
      transaction: vi
        .fn()
        .mockImplementation(async (levelOrCallback: unknown, maybeCallback?: unknown) => {
          // An optional isolation level may precede the callback; pick whichever
          // argument is actually the function.
          const callback =
            typeof levelOrCallback === 'function' ? levelOrCallback : maybeCallback;
          return (callback as (em: typeof manager) => unknown)(manager);
        }),
      synchronize: vi.fn().mockResolvedValue(undefined),
      dropDatabase: vi.fn().mockResolvedValue(undefined)
    };

    return mockDataSource;
  }
  /**
   * Builds a partial mock of a TypeORM `EntityManager`, backed by Vitest mocks.
   *
   * Defaults of the returned object:
   * - `find` resolves to `[]`; `findOne` resolves to `null`.
   * - `save` resolves to the entity it was given.
   * - `update` and `delete` resolve to `{ affected: 1 }`.
   * - `createQueryBuilder` returns the same mock query builder on every call.
   * - `getRepository` builds a fresh mock repository on every call.
   * - `transaction` invokes the supplied callback with a newly created mock
   *   entity manager and resolves to whatever the callback returns. Both call
   *   shapes are accepted: `transaction(callback)` and
   *   `transaction(isolationLevel, callback)`.
   *
   * Only the members listed above exist on the object; it is cast to
   * `EntityManager` for convenience.
   *
   * @returns The mock entity manager.
   */
  export function createMockEntityManager(): EntityManager {
    return {
      find: vi.fn().mockResolvedValue([]),
      findOne: vi.fn().mockResolvedValue(null),
      save: vi.fn().mockImplementation((entity) => Promise.resolve(entity)),
      update: vi.fn().mockResolvedValue({ affected: 1 }),
      delete: vi.fn().mockResolvedValue({ affected: 1 }),
      createQueryBuilder: vi.fn().mockReturnValue(createMockQueryBuilder()),
      transaction: vi
        .fn()
        .mockImplementation(async (levelOrCallback: unknown, maybeCallback?: unknown) => {
          // An optional isolation level may precede the callback; pick whichever
          // argument is actually the function.
          const callback =
            typeof levelOrCallback === 'function' ? levelOrCallback : maybeCallback;
          // The nested manager is created lazily, only when a transaction is run.
          return (callback as (em: EntityManager) => unknown)(createMockEntityManager());
        }),
      getRepository: vi.fn().mockImplementation(() => createMockRepository())
      // Deliberately partial mock: the cast hides the members that are not stubbed.
    } as any;
  }
  /**
   * Builds a partial mock of a TypeORM `Repository<T>`, backed by Vitest mocks.
   *
   * Defaults of the returned object:
   * - `find` / `findBy` resolve to `[]`.
   * - `findOne` / `findOneBy` / `findOneByOrFail` resolve to `null`.
   * - `findAndCount` / `findAndCountBy` resolve to `[[], 0]`.
   * - `save` resolves to the entity it was given.
   * - `update` / `delete` resolve to `{ affected: 1 }`.
   * - `create` returns a shallow copy of its argument. An `id` already present on
   *   the argument is kept; otherwise `Date.now()` is used as the id.
   * - `createQueryBuilder` returns the same mock query builder on every call.
   * - `count` / `countBy` resolve to `0`; `exist` resolves to `false`.
   *
   * Only the members listed above exist on the object; it is cast to
   * `Repository<T>` for convenience.
   *
   * @typeParam T - Entity type the repository is presented as handling.
   * @returns The mock repository.
   */
  export function createMockRepository<T extends object = any>(): Repository<T> {
    return {
      find: vi.fn().mockResolvedValue([]),
      findOne: vi.fn().mockResolvedValue(null),
      findOneBy: vi.fn().mockResolvedValue(null),
      // NOTE(review): the real TypeORM method rejects when nothing matches; tests
      // that exercise the "not found" path should override this with a rejection.
      findOneByOrFail: vi.fn().mockResolvedValue(null),
      findBy: vi.fn().mockResolvedValue([]),
      findAndCount: vi.fn().mockResolvedValue([[], 0]),
      findAndCountBy: vi.fn().mockResolvedValue([[], 0]),
      save: vi.fn().mockImplementation((entity) => Promise.resolve(entity)),
      update: vi.fn().mockResolvedValue({ affected: 1 }),
      delete: vi.fn().mockResolvedValue({ affected: 1 }),
      // Keep a caller-supplied id; only generate one when it is missing.
      create: vi
        .fn()
        .mockImplementation((entity) => ({ ...entity, id: entity?.id ?? Date.now() })),
      createQueryBuilder: vi.fn().mockReturnValue(createMockQueryBuilder()),
      count: vi.fn().mockResolvedValue(0),
      countBy: vi.fn().mockResolvedValue(0),
      exist: vi.fn().mockResolvedValue(false)
      // Deliberately partial mock: the cast hides the members that are not stubbed.
    } as any;
  }
  /**
   * Builds a mock query builder out of Vitest mocks.
   *
   * Clause-building methods (`select`, `from`, `where`, `andWhere`, `orWhere`,
   * `leftJoin`, `innerJoin`, `orderBy`, `groupBy`, `having`, `skip`, `take`,
   * `setParameter`, `setParameters`) return their receiver so calls can be chained.
   * Result methods resolve to empty defaults: `getMany` / `getRawMany` to `[]`,
   * `getOne` / `getRawOne` to `null`, `getCount` to `0`, `execute` to `undefined`.
   *
   * @returns The mock query builder object.
   */
  export function createMockQueryBuilder() {
    // Fluent step: a mock that returns the object it was called on.
    const chain = () => vi.fn().mockReturnThis();
    // Terminal step: a mock that resolves to the given default result.
    const resolves = (result: unknown) => vi.fn().mockResolvedValue(result);

    return {
      select: chain(),
      from: chain(),
      where: chain(),
      andWhere: chain(),
      orWhere: chain(),
      leftJoin: chain(),
      innerJoin: chain(),
      orderBy: chain(),
      groupBy: chain(),
      having: chain(),
      skip: chain(),
      take: chain(),
      getMany: resolves([]),
      getOne: resolves(null),
      getCount: resolves(0),
      getRawMany: resolves([]),
      getRawOne: resolves(null),
      execute: resolves(undefined),
      setParameter: chain(),
      setParameters: chain(),
    };
  }
  /**
   * Database helper for tests. Holds one shared `DataSource` in a static field.
   */
  export class TestDatabase {
    private static dataSource: DataSource | null = null;

    /**
     * Returns the shared data source, creating and initialising it on first use.
     *
     * The data source uses the `better-sqlite3` driver with an in-memory
     * database, schema synchronisation enabled and logging disabled. The entity
     * classes are loaded with dynamic imports.
     *
     * The instance is stored in the static field only after `initialize()` has
     * succeeded, so a failed initialisation never leaves a half-built data
     * source visible through `getDataSource()`.
     *
     * @returns The initialised data source.
     */
    static async initialize(): Promise<DataSource> {
      const existing = this.dataSource;
      if (existing?.isInitialized) {
        return existing;
      }

      // Load the real entity classes.
      const { UserEntity } = await import('@/server/modules/users/user.entity');
      const { Role } = await import('@/server/modules/users/role.entity');

      // Use an in-memory SQLite database for tests.
      const created = new DataSource({
        type: 'better-sqlite3',
        database: ':memory:',
        synchronize: true,
        logging: false,
        entities: [UserEntity, Role]
      });
      await created.initialize();

      this.dataSource = created;
      return created;
    }

    /**
     * Destroys the shared data source if it is initialised, then clears the
     * stored reference. A stored but non-initialised instance is simply dropped.
     * If `destroy()` rejects, the reference is kept and the error propagates.
     */
    static async cleanup(): Promise<void> {
      const current = this.dataSource;
      if (current === null) {
        return;
      }
      if (current.isInitialized) {
        await current.destroy();
      }
      this.dataSource = null;
    }

    /**
     * Returns the currently stored data source.
     *
     * @returns The data source, or `null` when none is stored.
     */
    static getDataSource(): DataSource | null {
      return this.dataSource;
    }
  }
  /**
   * Registers lifecycle hooks for the test database: `TestDatabase.initialize()`
   * is awaited before each test and `TestDatabase.cleanup()` after each test.
   */
  export function setupDatabaseHooks() {
    const openDatabase = async (): Promise<void> => {
      await TestDatabase.initialize();
    };
    const closeDatabase = async (): Promise<void> => {
      await TestDatabase.cleanup();
    };

    beforeEach(openDatabase);
    afterEach(closeDatabase);
  }