diff --git a/packages/core/database/src/__tests__/collection-factory.test.ts b/packages/core/database/src/__tests__/collection-factory.test.ts new file mode 100644 index 0000000000..e5022b2944 --- /dev/null +++ b/packages/core/database/src/__tests__/collection-factory.test.ts @@ -0,0 +1,40 @@ +import Database from '../database'; +import { mockDatabase } from './index'; +import { Collection } from '../collection'; + +describe('collection factory', function () { + let db: Database; + + beforeEach(async () => { + db = mockDatabase(); + + await db.clean({ drop: true }); + }); + + afterEach(async () => { + await db.close(); + }); + + it('should register new collection type', async () => { + class ChildCollection extends Collection { + static type = 'child'; + } + + db.collectionFactory.registerCollectionType(ChildCollection, (options) => { + return options.child == true; + }); + + const collection = db.collectionFactory.createCollection({ + name: 'child', + child: true, + }); + + expect(collection).toBeInstanceOf(ChildCollection); + + const collection2 = db.collectionFactory.createCollection({ + name: 'collection', + }); + + expect(collection2).toBeInstanceOf(Collection); + }); +}); diff --git a/packages/core/database/src/collection-factory.ts b/packages/core/database/src/collection-factory.ts new file mode 100644 index 0000000000..56a31d9423 --- /dev/null +++ b/packages/core/database/src/collection-factory.ts @@ -0,0 +1,29 @@ +import { Collection, CollectionOptions } from './collection'; +import Database from './database'; + +export class CollectionFactory { + private collectionTypes: Array<{ + ctor: typeof Collection; + condition: (options: CollectionOptions) => boolean; + }> = []; + + constructor(private database: Database) {} + + registerCollectionType(collectionClass: typeof Collection, condition: (options: CollectionOptions) => boolean) { + this.collectionTypes.push({ ctor: collectionClass, condition }); + } + + createCollection(options: CollectionOptions): T { + let klass = Collection; + for (const { ctor, condition } of this.collectionTypes) { + if (condition(options)) { + klass = ctor; + break; + } + } + + return new klass(options, { + database: this.database, + }) as T; + } +} diff --git a/packages/core/database/src/database.ts b/packages/core/database/src/database.ts index 00bf9c6202..613f5386a2 100644 --- a/packages/core/database/src/database.ts +++ b/packages/core/database/src/database.ts @@ -72,6 +72,7 @@ import { import { patchSequelizeQueryInterface, snakeCase } from './utils'; import { BaseValueParser, registerFieldValueParsers } from './value-parsers'; import { ViewCollection } from './view-collection'; +import { CollectionFactory } from './collection-factory'; export type MergeOptions = merge.Options; @@ -177,23 +178,17 @@ export class Database extends EventEmitter implements AsyncEmitter { tableNameCollectionMap = new Map(); context: any = {}; queryInterface: QueryInterface; - - _instanceId: string; - utils = new DatabaseUtils(this); referenceMap = new ReferencesMap(); inheritanceMap = new InheritanceMap(); - importedFrom = new Map>(); - modelHook: ModelHook; version: DatabaseVersion; - delayCollectionExtend = new Map(); - logger: Logger; - collectionGroupManager = new CollectionGroupManager(this); + + collectionFactory: CollectionFactory = new CollectionFactory(this); declare emitAsync: (event: string | symbol, ...args: any[]) => Promise; constructor(options: DatabaseOptions) { @@ -307,12 +302,30 @@ export class Database extends EventEmitter implements AsyncEmitter { this.initListener(); patchSequelizeQueryInterface(this); + + this.registerCollectionType(); } + _instanceId: string; + get instanceId() { return this._instanceId; } + registerCollectionType() { + this.collectionFactory.registerCollectionType(InheritedCollection, (options) => { + return options.inherits && lodash.castArray(options.inherits).length > 0; + }); + + this.collectionFactory.registerCollectionType(ViewCollection, (options) => { + return options.viewName || options.view; + }); + + this.collectionFactory.registerCollectionType(SqlCollection, (options) => { + return options.sql; + }); + } + setContext(context: any) { this.context = context; } @@ -447,10 +460,6 @@ export class Database extends EventEmitter implements AsyncEmitter { return dialect.includes(this.sequelize.getDialect()); } - escapeId(identifier: string) { - return this.inDialect('mysql') ? `\`${identifier}\`` : `"${identifier}"`; - } - /** * Add collection to database * @param options @@ -466,29 +475,7 @@ export class Database extends EventEmitter implements AsyncEmitter { this.emit('beforeDefineCollection', options); - const hasValidInheritsOptions = (() => { - return options.inherits && lodash.castArray(options.inherits).length > 0; - })(); - - const hasViewOptions = options.viewName || options.view; - - const collectionKlass = (() => { - if (hasValidInheritsOptions) { - return InheritedCollection; - } - - if (hasViewOptions) { - return ViewCollection; - } - - if (options.sql) { - return SqlCollection; - } - - return Collection; - })(); - - const collection = new collectionKlass(options, { database: this }); + const collection = this.collectionFactory.createCollection(options); this.collections.set(collection.name, collection);