From e204b7b7566fd7fa423baef32977a8575d44a9e0 Mon Sep 17 00:00:00 2001 From: PalmDevs Date: Sun, 23 Jun 2024 17:00:17 +0700 Subject: [PATCH] feat(bots/discord): switch to `drizzle-orm` --- bots/discord/.gitignore | 3 +- bots/discord/package.json | 6 +- bots/discord/src/classes/Database.ts | 127 ------------------ bots/discord/src/context.ts | 18 ++- bots/discord/src/database/schemas.ts | 21 +++ .../interactionCreate/correct-response.ts | 10 +- .../src/events/discord/messageCreate/scan.ts | 13 +- .../messageReactionAdd/correct-response.ts | 6 +- bots/discord/src/utils/discord/messageScan.ts | 16 ++- 9 files changed, 68 insertions(+), 152 deletions(-) delete mode 100644 bots/discord/src/classes/Database.ts create mode 100644 bots/discord/src/database/schemas.ts diff --git a/bots/discord/.gitignore b/bots/discord/.gitignore index 0614777..7f0bd6c 100644 --- a/bots/discord/.gitignore +++ b/bots/discord/.gitignore @@ -178,4 +178,5 @@ dist config.ts # DB -*.db \ No newline at end of file +*.db +*.sqlite \ No newline at end of file diff --git a/bots/discord/package.json b/bots/discord/package.json index a9bf588..042cfaf 100644 --- a/bots/discord/package.json +++ b/bots/discord/package.json @@ -30,6 +30,10 @@ "@revanced/bot-api": "workspace:*", "@revanced/bot-shared": "workspace:*", "chalk": "^5.3.0", - "discord.js": "^14.15.3" + "discord.js": "^14.15.3", + "drizzle-orm": "^0.31.2" + }, + "devDependencies": { + "drizzle-kit": "^0.22.7" } } diff --git a/bots/discord/src/classes/Database.ts b/bots/discord/src/classes/Database.ts deleted file mode 100644 index 7da2bb0..0000000 --- a/bots/discord/src/classes/Database.ts +++ /dev/null @@ -1,127 +0,0 @@ -import { Database } from 'bun:sqlite' - -type BasicSQLBindings = string | number | null - -export class BasicDatabase> { - #db: Database - #table: string - - constructor(file: string, struct: string, tableName = 'data') { - const db = new Database(file, { - create: true, - readwrite: true, - }) - - this.#db = db - this.#table = tableName - - db.run(`CREATE TABLE IF NOT EXISTS ${tableName} (${struct});`) - } - - run(statement: string) { - this.#db.run(statement) - } - - prepare(statement: string) { - return this.#db.prepare(statement) - } - - query(statement: string) { - return this.#db.query(statement) - } - - insert(...values: BasicSQLBindings[]) { - this.run(`INSERT INTO ${this.#table} VALUES (${values.map(this.#encodeValue).join(', ')});`) - } - - update(what: Partial, where: string) { - const set = Object.entries(what) - .map(([key, value]) => `${key} = ${this.#encodeValue(value)}`) - .join(', ') - - this.run(`UPDATE ${this.#table} SET ${set} WHERE ${where};`) - } - - delete(where: string) { - this.run(`DELETE FROM ${this.#table} WHERE ${where};`) - } - - select(columns: string[] | string, where: string) { - const realColumns = Array.isArray(columns) ? columns.join(', ') : columns - return this.query(`SELECT ${realColumns} FROM ${this.#table} WHERE ${where};`).get() - } - - #encodeValue(value: unknown) { - if (typeof value === 'string') return `'${value.replaceAll("'", "\\'")}'` - if (typeof value === 'number') return value - if (typeof value === 'boolean') return value ? 1 : 0 - if (value === null) return 'NULL' - return null - } -} - -export class LabeledResponseDatabase { - #db: BasicDatabase - - constructor() { - this.#db = new BasicDatabase( - 'responses.db', - `reply TEXT PRIMARY KEY NOT NULL, - channel TEXT NOT NULL, - guild TEXT NOT NULL, - referenceMessage TEXT KEY NOT NULL, - label TEXT NOT NULL, - text TEXT NOT NULL, - correctedBy TEXT, - CHECK ( - typeof("text") = 'text' AND - length("text") > 0 AND - length("text") <= 280 - )`, - ) - } - - save({ reply, channel, guild, referenceMessage, label, text }: Omit) { - const actualText = text.slice(0, 280) - this.#db.insert(reply, channel, guild, referenceMessage, label, actualText, null) - } - - get(reply: string) { - return this.#db.select('*', `reply = ${reply}`) - } - - edit(reply: string, { label, correctedBy }: Pick) { - this.#db.update({ label, correctedBy }, `reply = ${reply}`) - } -} - -export type LabeledResponse = { - /** - * The label of the response - */ - label: string - /** - * The ID of the user who corrected the response - */ - correctedBy: string | null - /** - * The text content of the response - */ - text: string - /** - * The ID of the message that triggered the response - */ - referenceMessage: string - /** - * The ID of the channel where the response was sent - */ - channel: string - /** - * The ID of the guild where the response was sent - */ - guild: string - /** - * The ID of the reply - */ - reply: string -} diff --git a/bots/discord/src/context.ts b/bots/discord/src/context.ts index f3216d0..3c8bfc7 100644 --- a/bots/discord/src/context.ts +++ b/bots/discord/src/context.ts @@ -1,10 +1,14 @@ -import { loadCommands } from '$utils/discord/commands' +import { Database } from 'bun:sqlite' import { Client as APIClient } from '@revanced/bot-api' import { createLogger } from '@revanced/bot-shared' import { ActivityType, Client as DiscordClient, Partials } from 'discord.js' +import { drizzle } from 'drizzle-orm/bun-sqlite' + import config from '../config' -import { LabeledResponseDatabase } from './classes/Database' -import { pathJoinCurrentDir } from './utils/fs' +import * as schemas from './database/schemas' + +import { loadCommands } from '$utils/discord/commands' +import { pathJoinCurrentDir } from '$utils/fs' export { config } export const logger = createLogger({ @@ -23,9 +27,11 @@ export const api = { disconnectCount: 0, } -export const database = { - labeledResponses: new LabeledResponseDatabase(), -} as const +const db = new Database('db.sqlite') + +export const database = drizzle(db, { + schema: schemas, +}) export const discord = { client: new DiscordClient({ diff --git a/bots/discord/src/database/schemas.ts b/bots/discord/src/database/schemas.ts new file mode 100644 index 0000000..23abaa3 --- /dev/null +++ b/bots/discord/src/database/schemas.ts @@ -0,0 +1,21 @@ +import type { InferSelectModel } from 'drizzle-orm' +import { sqliteTable, text } from 'drizzle-orm/sqlite-core' + +export const responses = sqliteTable('responses', { + replyId: text('reply').primaryKey().notNull(), + channelId: text('channel').notNull(), + guildId: text('guild').notNull(), + referenceId: text('ref').notNull(), + label: text('label').notNull(), + content: text('text').notNull(), + correctedById: text('by'), +}) + +export const appliedPresets = sqliteTable('applied_presets', { + memberId: text('member').primaryKey().notNull(), + guildId: text('guild').notNull(), + presets: text('presets', { mode: 'json' }).$type().notNull().default([]), +}) + +export type Response = InferSelectModel +export type AppliedPreset = InferSelectModel diff --git a/bots/discord/src/events/discord/interactionCreate/correct-response.ts b/bots/discord/src/events/discord/interactionCreate/correct-response.ts index 9ded5dc..9c1c810 100644 --- a/bots/discord/src/events/discord/interactionCreate/correct-response.ts +++ b/bots/discord/src/events/discord/interactionCreate/correct-response.ts @@ -1,8 +1,10 @@ +import { responses } from '$/database/schemas' import { handleUserResponseCorrection } from '$/utils/discord/messageScan' import { createErrorEmbed, createStackTraceEmbed, createSuccessEmbed } from '$utils/discord/embeds' import { on } from '$utils/discord/events' import type { ButtonInteraction, StringSelectMenuInteraction, TextBasedChannel } from 'discord.js' +import { eq } from 'drizzle-orm' // No permission check required as it is already done when the user reacts to a bot response export default on('interactionCreate', async (context, interaction) => { @@ -19,7 +21,7 @@ export default on('interactionCreate', async (context, interaction) => { const [, key, action] = interaction.customId.split('_') as ['cr', string, 'select' | 'cancel' | 'delete'] if (!key || !action) return - const response = db.labeledResponses.get(key) + const response = await db.query.responses.findFirst({ where: eq(responses.replyId, key) }) // If the message isn't saved in my DB (unrelated message) if (!response) return void (await interaction.reply({ @@ -30,8 +32,8 @@ export default on('interactionCreate', async (context, interaction) => { try { // We're gonna pretend reactionChannel is a text-based channel, but it can be many more // But `messages` should always exist as a property - const reactionGuild = await interaction.client.guilds.fetch(response.guild) - const reactionChannel = (await reactionGuild.channels.fetch(response.channel)) as TextBasedChannel | null + const reactionGuild = await interaction.client.guilds.fetch(response.guildId) + const reactionChannel = (await reactionGuild.channels.fetch(response.channelId)) as TextBasedChannel | null const reactionMessage = await reactionChannel?.messages.fetch(key) if (!reactionMessage) { @@ -55,7 +57,7 @@ export default on('interactionCreate', async (context, interaction) => { const handleCorrection = (label: string) => handleUserResponseCorrection(context, response, reactionMessage, label, interaction.user) - if (response.correctedBy) + if (response.correctedById) return await editMessage( 'Response already corrected', 'Thank you for your feedback! Unfortunately, this response has already been corrected by someone else.', diff --git a/bots/discord/src/events/discord/messageCreate/scan.ts b/bots/discord/src/events/discord/messageCreate/scan.ts index f3195fe..4400172 100644 --- a/bots/discord/src/events/discord/messageCreate/scan.ts +++ b/bots/discord/src/events/discord/messageCreate/scan.ts @@ -1,4 +1,5 @@ import { MessageScanLabeledResponseReactions } from '$/constants' +import { responses } from '$/database/schemas' import { getResponseFromText, shouldScanMessage } from '$/utils/discord/messageScan' import { createMessageScanResponseEmbed } from '$utils/discord/embeds' import { on } from '$utils/discord/events' @@ -30,13 +31,13 @@ on('messageCreate', async (ctx, msg) => { }) if (label) - db.labeledResponses.save({ - reply: reply.id, - channel: reply.channel.id, - guild: reply.guild!.id, - referenceMessage: msg.id, + db.insert(responses).values({ + replyId: reply.id, + channelId: reply.channel.id, + guildId: reply.guild!.id, + referenceId: msg.id, label, - text: msg.content, + content: msg.content, }) if (label) { diff --git a/bots/discord/src/events/discord/messageReactionAdd/correct-response.ts b/bots/discord/src/events/discord/messageReactionAdd/correct-response.ts index d562c0c..b90a21a 100644 --- a/bots/discord/src/events/discord/messageReactionAdd/correct-response.ts +++ b/bots/discord/src/events/discord/messageReactionAdd/correct-response.ts @@ -10,8 +10,10 @@ import { StringSelectMenuOptionBuilder, } from 'discord.js' +import { responses } from '$/database/schemas' import { handleUserResponseCorrection } from '$/utils/discord/messageScan' import type { ConfigMessageScanResponseLabelConfig } from 'config.schema' +import { eq } from 'drizzle-orm' const PossibleReactions = Object.values(Reactions) as string[] @@ -57,8 +59,8 @@ on('messageReactionAdd', async (context, rct, user) => { } // Sanity check - const response = db.labeledResponses.get(rct.message.id) - if (!response || response.correctedBy) return + const response = await db.query.responses.findFirst({ where: eq(responses.replyId, rct.message.id) }) + if (!response || response.correctedById) return const handleCorrection = (label: string) => handleUserResponseCorrection(context, response, reactionMessage, label, user) diff --git a/bots/discord/src/utils/discord/messageScan.ts b/bots/discord/src/utils/discord/messageScan.ts index ce9c180..41f6f18 100644 --- a/bots/discord/src/utils/discord/messageScan.ts +++ b/bots/discord/src/utils/discord/messageScan.ts @@ -1,4 +1,4 @@ -import type { LabeledResponse } from '$/classes/Database' +import { type Response, responses } from '$/database/schemas' import type { Config, ConfigMessageScanResponse, @@ -6,6 +6,7 @@ import type { ConfigMessageScanResponseMessage, } from 'config.schema' import type { Message, PartialUser, User } from 'discord.js' +import { eq } from 'drizzle-orm' import { createMessageScanResponseEmbed } from './embeds' export const getResponseFromText = async ( @@ -138,7 +139,7 @@ export const shouldScanMessage = ( export const handleUserResponseCorrection = async ( { api, database: db, config: { messageScan: msConfig }, logger }: typeof import('$/context'), - response: LabeledResponse, + response: Response, reply: Message, label: string, user: User | PartialUser, @@ -151,14 +152,19 @@ export const handleUserResponseCorrection = async ( if (!correctLabelResponse.response) return void (await reply.delete()) if (response.label !== label) { - db.labeledResponses.edit(response.reply, { label, correctedBy: user.id }) + db.update(responses) + .set({ + label, + correctedById: user.id, + }) + .where(eq(responses.replyId, response.replyId)) await reply.edit({ embeds: [createMessageScanResponseEmbed(correctLabelResponse.response, 'nlp')], }) } - await api.client.trainMessage(response.text, label) - logger.debug(`User ${user.id} trained message ${response.reply} as ${label} (positive)`) + await api.client.trainMessage(response.content, label) + logger.debug(`User ${user.id} trained message ${response.replyId} as ${label} (positive)`) await reply.reactions.removeAll() }