test-db.ts 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152
  1. import { DataSource, EntityManager, Repository } from 'typeorm';
  2. import { vi, beforeEach, afterEach } from 'vitest';
  /**
   * Creates a mocked TypeORM `DataSource` for unit tests.
   *
   * Every member is a `vi.fn()` (except `isInitialized` and `manager`) so tests can
   * override or assert on it. Defaults:
   * - `initialize`, `destroy`, `synchronize` and `dropDatabase` resolve to `undefined`.
   * - `getRepository(target)` returns a mock repository cached per `target`, so a
   *   repository configured by a test is the same instance the code under test receives.
   * - `createQueryBuilder` always returns one shared mock query builder.
   * - `transaction(cb)` and `transaction(isolationLevel, cb)` both invoke the callback
   *   with this data source's `manager` and resolve to the callback's result.
   *
   * @returns A plain object mimicking the `DataSource` members listed above.
   */
  export function createMockDataSource() {
    const repositories = new Map<unknown, unknown>();
    const mockDataSource = {
      initialize: vi.fn().mockResolvedValue(undefined),
      destroy: vi.fn().mockResolvedValue(undefined),
      isInitialized: true,
      manager: createMockEntityManager(),
      getRepository: vi.fn().mockImplementation((target: unknown) => {
        let repository = repositories.get(target);
        if (repository === undefined) {
          repository = createMockRepository();
          repositories.set(target, repository);
        }
        return repository;
      }),
      createQueryBuilder: vi.fn().mockReturnValue(createMockQueryBuilder()),
      // TypeORM accepts (runInTransaction) or (isolationLevel, runInTransaction); the
      // callback is always the last argument. The explicit return type keeps TypeScript
      // from needing `mockDataSource`'s type while still inferring it.
      transaction: vi.fn().mockImplementation(async (...args: unknown[]): Promise<unknown> => {
        const runInTransaction = args[args.length - 1] as (entityManager: unknown) => unknown;
        return runInTransaction(mockDataSource.manager);
      }),
      synchronize: vi.fn().mockResolvedValue(undefined),
      dropDatabase: vi.fn().mockResolvedValue(undefined),
    };
    return mockDataSource;
  }
  /**
   * Creates a mocked TypeORM `EntityManager` for unit tests.
   *
   * Defaults:
   * - `find` resolves to `[]` and `findOne` resolves to `null`.
   * - `save` resolves to its argument.
   * - `update` and `delete` resolve to `{ affected: 1 }`.
   * - `createQueryBuilder` always returns one shared mock query builder.
   * - `getRepository(target)` returns a mock repository cached per `target`.
   * - `transaction(cb)` and `transaction(isolationLevel, cb)` invoke the callback with
   *   this same manager and resolve to the callback's result.
   *
   * Only the members listed above exist. The cast to `EntityManager` is deliberate, so
   * the partial mock can be passed wherever an `EntityManager` is expected.
   */
  export function createMockEntityManager(): EntityManager {
    const repositories = new Map<unknown, unknown>();
    const manager = {
      find: vi.fn().mockResolvedValue([]),
      findOne: vi.fn().mockResolvedValue(null),
      save: vi.fn().mockImplementation((entity: unknown) => Promise.resolve(entity)),
      update: vi.fn().mockResolvedValue({ affected: 1 }),
      delete: vi.fn().mockResolvedValue({ affected: 1 }),
      createQueryBuilder: vi.fn().mockReturnValue(createMockQueryBuilder()),
      // The transactional callback receives this manager itself; no data source is in
      // scope here. The callback is the last argument so the isolation-level overload
      // works too. The explicit return type avoids a self-referential inference cycle.
      transaction: vi.fn().mockImplementation(async (...args: unknown[]): Promise<unknown> => {
        const runInTransaction = args[args.length - 1] as (entityManager: unknown) => unknown;
        return runInTransaction(manager);
      }),
      getRepository: vi.fn().mockImplementation((target: unknown) => {
        let repository = repositories.get(target);
        if (repository === undefined) {
          repository = createMockRepository();
          repositories.set(target, repository);
        }
        return repository;
      }),
    };
    return manager as unknown as EntityManager;
  }
  /**
   * Creates a mocked TypeORM `Repository<T>` for unit tests.
   *
   * Lookups resolve to "nothing found": `find`/`findBy` → `[]`, `findOne`/`findOneBy` →
   * `null`, `findAndCount`/`findAndCountBy` → `[[], 0]`, `count`/`countBy` → `0`,
   * `exist`/`exists` → `false`.
   *
   * Writes: `save` resolves to its argument, and `update`/`delete` resolve to
   * `{ affected: 1 }`.
   *
   * `create` mirrors `Repository.create`:
   * - an array input yields an array of entities;
   * - each entity keeps a caller-supplied `id`, or else gets a numeric id that is unique
   *   within this repository (counter seeded from `Date.now()`).
   *
   * NOTE: `findOneByOrFail` resolves to `null` by default. The real TypeORM method
   * rejects when nothing matches, so override it in tests that depend on that.
   *
   * @typeParam T - Entity type the returned repository is typed for.
   */
  export function createMockRepository<T extends object = any>(): Repository<T> {
    // Monotonic counter so entities created within the same millisecond get distinct ids.
    let nextId = Date.now();
    const withId = (entityLike: Record<string, unknown> = {}): Record<string, unknown> => ({
      ...entityLike,
      id: entityLike.id ?? nextId++,
    });
    return {
      find: vi.fn().mockResolvedValue([]),
      findOne: vi.fn().mockResolvedValue(null),
      findOneBy: vi.fn().mockResolvedValue(null),
      findOneByOrFail: vi.fn().mockResolvedValue(null),
      findBy: vi.fn().mockResolvedValue([]),
      findAndCount: vi.fn().mockResolvedValue([[], 0]),
      findAndCountBy: vi.fn().mockResolvedValue([[], 0]),
      save: vi.fn().mockImplementation((entity: unknown) => Promise.resolve(entity)),
      update: vi.fn().mockResolvedValue({ affected: 1 }),
      delete: vi.fn().mockResolvedValue({ affected: 1 }),
      create: vi
        .fn()
        .mockImplementation(
          (entityLike?: Record<string, unknown> | Record<string, unknown>[]) =>
            Array.isArray(entityLike) ? entityLike.map((item) => withId(item)) : withId(entityLike),
        ),
      createQueryBuilder: vi.fn().mockReturnValue(createMockQueryBuilder()),
      count: vi.fn().mockResolvedValue(0),
      countBy: vi.fn().mockResolvedValue(0),
      exist: vi.fn().mockResolvedValue(false),
      // Newer TypeORM 0.3.x releases deprecate `exist` in favour of `exists`.
      exists: vi.fn().mockResolvedValue(false),
    } as unknown as Repository<T>;
  }
  /**
   * Creates a mocked TypeORM query builder for unit tests.
   *
   * Builder methods (`select`, `where`, `leftJoinAndSelect`, `orderBy`, `skip`,
   * `update`, `set`, …) use `mockReturnThis`, so chains such as
   * `qb.where(...).andWhere(...)` keep returning the builder. This relies on each call
   * being made as a method on the builder.
   *
   * Terminal methods resolve to empty results:
   * - `getMany` → `[]`, `getOne` → `null`, `getCount` → `0`, `getManyAndCount` → `[[], 0]`
   * - `getRawMany` → `[]`, `getRawOne` → `null`,
   *   `getRawAndEntities` → `{ entities: [], raw: [] }`
   * - `execute` → `undefined`
   */
  export function createMockQueryBuilder() {
    const mockQueryBuilder = {
      // Selection and source
      select: vi.fn().mockReturnThis(),
      addSelect: vi.fn().mockReturnThis(),
      from: vi.fn().mockReturnThis(),
      // Filtering
      where: vi.fn().mockReturnThis(),
      andWhere: vi.fn().mockReturnThis(),
      orWhere: vi.fn().mockReturnThis(),
      withDeleted: vi.fn().mockReturnThis(),
      // Joins
      leftJoin: vi.fn().mockReturnThis(),
      innerJoin: vi.fn().mockReturnThis(),
      leftJoinAndSelect: vi.fn().mockReturnThis(),
      innerJoinAndSelect: vi.fn().mockReturnThis(),
      // Ordering, grouping, pagination
      orderBy: vi.fn().mockReturnThis(),
      addOrderBy: vi.fn().mockReturnThis(),
      groupBy: vi.fn().mockReturnThis(),
      addGroupBy: vi.fn().mockReturnThis(),
      having: vi.fn().mockReturnThis(),
      andHaving: vi.fn().mockReturnThis(),
      skip: vi.fn().mockReturnThis(),
      take: vi.fn().mockReturnThis(),
      offset: vi.fn().mockReturnThis(),
      limit: vi.fn().mockReturnThis(),
      // Write builders
      update: vi.fn().mockReturnThis(),
      set: vi.fn().mockReturnThis(),
      insert: vi.fn().mockReturnThis(),
      into: vi.fn().mockReturnThis(),
      values: vi.fn().mockReturnThis(),
      delete: vi.fn().mockReturnThis(),
      softDelete: vi.fn().mockReturnThis(),
      returning: vi.fn().mockReturnThis(),
      // Parameters
      setParameter: vi.fn().mockReturnThis(),
      setParameters: vi.fn().mockReturnThis(),
      // Terminal operations
      getMany: vi.fn().mockResolvedValue([]),
      getOne: vi.fn().mockResolvedValue(null),
      getCount: vi.fn().mockResolvedValue(0),
      getManyAndCount: vi.fn().mockResolvedValue([[], 0]),
      getRawMany: vi.fn().mockResolvedValue([]),
      getRawOne: vi.fn().mockResolvedValue(null),
      getRawAndEntities: vi.fn().mockResolvedValue({ entities: [], raw: [] }),
      execute: vi.fn().mockResolvedValue(undefined),
    };
    return mockQueryBuilder;
  }
  /**
   * Holds a single shared in-memory SQLite (`better-sqlite3`) `DataSource` for tests.
   *
   * The data source is created lazily by `initialize()` and released by `cleanup()`.
   * Static state is accessed through the class name rather than `this`.
   */
  export class TestDatabase {
    private static dataSource: DataSource | null = null;

    /**
     * Initializes the test database, or returns the already-initialized one.
     *
     * The database is `:memory:` with `synchronize: true` and logging disabled, so the
     * schema is built from `entities` during initialization.
     *
     * @param entities - Entities to register (same forms the `DataSource` `entities`
     *   option accepts). Defaults to none. Ignored when a data source is already
     *   initialized; call `cleanup()` first to switch entities.
     * @returns The initialized data source.
     * @throws Whatever `DataSource.initialize()` throws. A data source that failed to
     *   initialize is never stored, so `getDataSource()` does not expose it.
     */
    static async initialize(
      entities: ConstructorParameters<typeof DataSource>[0]['entities'] = [],
    ): Promise<DataSource> {
      const existing = TestDatabase.dataSource;
      if (existing?.isInitialized) {
        return existing;
      }
      const dataSource = new DataSource({
        type: 'better-sqlite3',
        database: ':memory:',
        synchronize: true,
        logging: false,
        entities,
      });
      await dataSource.initialize();
      // Publish only after a successful initialize, so no half-initialized handle is visible.
      TestDatabase.dataSource = dataSource;
      return dataSource;
    }

    /**
     * Destroys the current data source (if initialized) and forgets it.
     *
     * The stored reference is cleared before `destroy()` runs. If destroy rejects, the
     * error propagates but no stale handle remains. A stored data source that is not
     * initialized is simply dropped.
     */
    static async cleanup(): Promise<void> {
      const dataSource = TestDatabase.dataSource;
      TestDatabase.dataSource = null;
      if (dataSource?.isInitialized) {
        await dataSource.destroy();
      }
    }

    /**
     * Returns the stored data source, or `null` if none is stored.
     */
    static getDataSource(): DataSource | null {
      return TestDatabase.dataSource;
    }
  }
  /**
   * Registers Vitest lifecycle hooks that give every test a fresh test database.
   *
   * A `beforeEach` hook awaits `TestDatabase.initialize()`, and an `afterEach` hook
   * awaits `TestDatabase.cleanup()`. Call it at the top of a test file or inside a
   * `describe` block.
   */
  export function setupDatabaseHooks(): void {
    const openDatabase = async (): Promise<void> => {
      await TestDatabase.initialize();
    };
    const closeDatabase = async (): Promise<void> => {
      await TestDatabase.cleanup();
    };

    beforeEach(openDatabase);
    afterEach(closeDatabase);
  }