feat(bots/discord): switch to drizzle-orm

This commit is contained in:
PalmDevs
2024-06-23 17:00:17 +07:00
parent 3bca6e5c31
commit e204b7b756
9 changed files with 68 additions and 152 deletions

View File

@@ -179,3 +179,4 @@ config.ts
# DB # DB
*.db *.db
*.sqlite

View File

@@ -30,6 +30,10 @@
"@revanced/bot-api": "workspace:*", "@revanced/bot-api": "workspace:*",
"@revanced/bot-shared": "workspace:*", "@revanced/bot-shared": "workspace:*",
"chalk": "^5.3.0", "chalk": "^5.3.0",
"discord.js": "^14.15.3" "discord.js": "^14.15.3",
"drizzle-orm": "^0.31.2"
},
"devDependencies": {
"drizzle-kit": "^0.22.7"
} }
} }

View File

@@ -1,127 +0,0 @@
import { Database } from 'bun:sqlite'
type BasicSQLBindings = string | number | null
export class BasicDatabase<T extends Record<string, BasicSQLBindings>> {
#db: Database
#table: string
constructor(file: string, struct: string, tableName = 'data') {
const db = new Database(file, {
create: true,
readwrite: true,
})
this.#db = db
this.#table = tableName
db.run(`CREATE TABLE IF NOT EXISTS ${tableName} (${struct});`)
}
run(statement: string) {
this.#db.run(statement)
}
prepare(statement: string) {
return this.#db.prepare<T, BasicSQLBindings[]>(statement)
}
query(statement: string) {
return this.#db.query<T, BasicSQLBindings[]>(statement)
}
insert(...values: BasicSQLBindings[]) {
this.run(`INSERT INTO ${this.#table} VALUES (${values.map(this.#encodeValue).join(', ')});`)
}
update(what: Partial<T>, where: string) {
const set = Object.entries(what)
.map(([key, value]) => `${key} = ${this.#encodeValue(value)}`)
.join(', ')
this.run(`UPDATE ${this.#table} SET ${set} WHERE ${where};`)
}
delete(where: string) {
this.run(`DELETE FROM ${this.#table} WHERE ${where};`)
}
select(columns: string[] | string, where: string) {
const realColumns = Array.isArray(columns) ? columns.join(', ') : columns
return this.query(`SELECT ${realColumns} FROM ${this.#table} WHERE ${where};`).get()
}
#encodeValue(value: unknown) {
if (typeof value === 'string') return `'${value.replaceAll("'", "\\'")}'`
if (typeof value === 'number') return value
if (typeof value === 'boolean') return value ? 1 : 0
if (value === null) return 'NULL'
return null
}
}
export class LabeledResponseDatabase {
#db: BasicDatabase<LabeledResponse>
constructor() {
this.#db = new BasicDatabase<LabeledResponse>(
'responses.db',
`reply TEXT PRIMARY KEY NOT NULL,
channel TEXT NOT NULL,
guild TEXT NOT NULL,
referenceMessage TEXT KEY NOT NULL,
label TEXT NOT NULL,
text TEXT NOT NULL,
correctedBy TEXT,
CHECK (
typeof("text") = 'text' AND
length("text") > 0 AND
length("text") <= 280
)`,
)
}
save({ reply, channel, guild, referenceMessage, label, text }: Omit<LabeledResponse, 'correctedBy'>) {
const actualText = text.slice(0, 280)
this.#db.insert(reply, channel, guild, referenceMessage, label, actualText, null)
}
get(reply: string) {
return this.#db.select('*', `reply = ${reply}`)
}
edit(reply: string, { label, correctedBy }: Pick<LabeledResponse, 'label' | 'correctedBy'>) {
this.#db.update({ label, correctedBy }, `reply = ${reply}`)
}
}
export type LabeledResponse = {
/**
* The label of the response
*/
label: string
/**
* The ID of the user who corrected the response
*/
correctedBy: string | null
/**
* The text content of the response
*/
text: string
/**
* The ID of the message that triggered the response
*/
referenceMessage: string
/**
* The ID of the channel where the response was sent
*/
channel: string
/**
* The ID of the guild where the response was sent
*/
guild: string
/**
* The ID of the reply
*/
reply: string
}

View File

@@ -1,10 +1,14 @@
import { loadCommands } from '$utils/discord/commands' import { Database } from 'bun:sqlite'
import { Client as APIClient } from '@revanced/bot-api' import { Client as APIClient } from '@revanced/bot-api'
import { createLogger } from '@revanced/bot-shared' import { createLogger } from '@revanced/bot-shared'
import { ActivityType, Client as DiscordClient, Partials } from 'discord.js' import { ActivityType, Client as DiscordClient, Partials } from 'discord.js'
import { drizzle } from 'drizzle-orm/bun-sqlite'
import config from '../config' import config from '../config'
import { LabeledResponseDatabase } from './classes/Database' import * as schemas from './database/schemas'
import { pathJoinCurrentDir } from './utils/fs'
import { loadCommands } from '$utils/discord/commands'
import { pathJoinCurrentDir } from '$utils/fs'
export { config } export { config }
export const logger = createLogger({ export const logger = createLogger({
@@ -23,9 +27,11 @@ export const api = {
disconnectCount: 0, disconnectCount: 0,
} }
export const database = { const db = new Database('db.sqlite')
labeledResponses: new LabeledResponseDatabase(),
} as const export const database = drizzle(db, {
schema: schemas,
})
export const discord = { export const discord = {
client: new DiscordClient({ client: new DiscordClient({

View File

@@ -0,0 +1,21 @@
import type { InferSelectModel } from 'drizzle-orm'
import { sqliteTable, text } from 'drizzle-orm/sqlite-core'
export const responses = sqliteTable('responses', {
replyId: text('reply').primaryKey().notNull(),
channelId: text('channel').notNull(),
guildId: text('guild').notNull(),
referenceId: text('ref').notNull(),
label: text('label').notNull(),
content: text('text').notNull(),
correctedById: text('by'),
})
export const appliedPresets = sqliteTable('applied_presets', {
memberId: text('member').primaryKey().notNull(),
guildId: text('guild').notNull(),
presets: text('presets', { mode: 'json' }).$type<string[]>().notNull().default([]),
})
export type Response = InferSelectModel<typeof responses>
export type AppliedPreset = InferSelectModel<typeof appliedPresets>

View File

@@ -1,8 +1,10 @@
import { responses } from '$/database/schemas'
import { handleUserResponseCorrection } from '$/utils/discord/messageScan' import { handleUserResponseCorrection } from '$/utils/discord/messageScan'
import { createErrorEmbed, createStackTraceEmbed, createSuccessEmbed } from '$utils/discord/embeds' import { createErrorEmbed, createStackTraceEmbed, createSuccessEmbed } from '$utils/discord/embeds'
import { on } from '$utils/discord/events' import { on } from '$utils/discord/events'
import type { ButtonInteraction, StringSelectMenuInteraction, TextBasedChannel } from 'discord.js' import type { ButtonInteraction, StringSelectMenuInteraction, TextBasedChannel } from 'discord.js'
import { eq } from 'drizzle-orm'
// No permission check required as it is already done when the user reacts to a bot response // No permission check required as it is already done when the user reacts to a bot response
export default on('interactionCreate', async (context, interaction) => { export default on('interactionCreate', async (context, interaction) => {
@@ -19,7 +21,7 @@ export default on('interactionCreate', async (context, interaction) => {
const [, key, action] = interaction.customId.split('_') as ['cr', string, 'select' | 'cancel' | 'delete'] const [, key, action] = interaction.customId.split('_') as ['cr', string, 'select' | 'cancel' | 'delete']
if (!key || !action) return if (!key || !action) return
const response = db.labeledResponses.get(key) const response = await db.query.responses.findFirst({ where: eq(responses.replyId, key) })
// If the message isn't saved in my DB (unrelated message) // If the message isn't saved in my DB (unrelated message)
if (!response) if (!response)
return void (await interaction.reply({ return void (await interaction.reply({
@@ -30,8 +32,8 @@ export default on('interactionCreate', async (context, interaction) => {
try { try {
// We're gonna pretend reactionChannel is a text-based channel, but it can be many more // We're gonna pretend reactionChannel is a text-based channel, but it can be many more
// But `messages` should always exist as a property // But `messages` should always exist as a property
const reactionGuild = await interaction.client.guilds.fetch(response.guild) const reactionGuild = await interaction.client.guilds.fetch(response.guildId)
const reactionChannel = (await reactionGuild.channels.fetch(response.channel)) as TextBasedChannel | null const reactionChannel = (await reactionGuild.channels.fetch(response.channelId)) as TextBasedChannel | null
const reactionMessage = await reactionChannel?.messages.fetch(key) const reactionMessage = await reactionChannel?.messages.fetch(key)
if (!reactionMessage) { if (!reactionMessage) {
@@ -55,7 +57,7 @@ export default on('interactionCreate', async (context, interaction) => {
const handleCorrection = (label: string) => const handleCorrection = (label: string) =>
handleUserResponseCorrection(context, response, reactionMessage, label, interaction.user) handleUserResponseCorrection(context, response, reactionMessage, label, interaction.user)
if (response.correctedBy) if (response.correctedById)
return await editMessage( return await editMessage(
'Response already corrected', 'Response already corrected',
'Thank you for your feedback! Unfortunately, this response has already been corrected by someone else.', 'Thank you for your feedback! Unfortunately, this response has already been corrected by someone else.',

View File

@@ -1,4 +1,5 @@
import { MessageScanLabeledResponseReactions } from '$/constants' import { MessageScanLabeledResponseReactions } from '$/constants'
import { responses } from '$/database/schemas'
import { getResponseFromText, shouldScanMessage } from '$/utils/discord/messageScan' import { getResponseFromText, shouldScanMessage } from '$/utils/discord/messageScan'
import { createMessageScanResponseEmbed } from '$utils/discord/embeds' import { createMessageScanResponseEmbed } from '$utils/discord/embeds'
import { on } from '$utils/discord/events' import { on } from '$utils/discord/events'
@@ -30,13 +31,13 @@ on('messageCreate', async (ctx, msg) => {
}) })
if (label) if (label)
db.labeledResponses.save({ db.insert(responses).values({
reply: reply.id, replyId: reply.id,
channel: reply.channel.id, channelId: reply.channel.id,
guild: reply.guild!.id, guildId: reply.guild!.id,
referenceMessage: msg.id, referenceId: msg.id,
label, label,
text: msg.content, content: msg.content,
}) })
if (label) { if (label) {

View File

@@ -10,8 +10,10 @@ import {
StringSelectMenuOptionBuilder, StringSelectMenuOptionBuilder,
} from 'discord.js' } from 'discord.js'
import { responses } from '$/database/schemas'
import { handleUserResponseCorrection } from '$/utils/discord/messageScan' import { handleUserResponseCorrection } from '$/utils/discord/messageScan'
import type { ConfigMessageScanResponseLabelConfig } from 'config.schema' import type { ConfigMessageScanResponseLabelConfig } from 'config.schema'
import { eq } from 'drizzle-orm'
const PossibleReactions = Object.values(Reactions) as string[] const PossibleReactions = Object.values(Reactions) as string[]
@@ -57,8 +59,8 @@ on('messageReactionAdd', async (context, rct, user) => {
} }
// Sanity check // Sanity check
const response = db.labeledResponses.get(rct.message.id) const response = await db.query.responses.findFirst({ where: eq(responses.replyId, rct.message.id) })
if (!response || response.correctedBy) return if (!response || response.correctedById) return
const handleCorrection = (label: string) => const handleCorrection = (label: string) =>
handleUserResponseCorrection(context, response, reactionMessage, label, user) handleUserResponseCorrection(context, response, reactionMessage, label, user)

View File

@@ -1,4 +1,4 @@
import type { LabeledResponse } from '$/classes/Database' import { type Response, responses } from '$/database/schemas'
import type { import type {
Config, Config,
ConfigMessageScanResponse, ConfigMessageScanResponse,
@@ -6,6 +6,7 @@ import type {
ConfigMessageScanResponseMessage, ConfigMessageScanResponseMessage,
} from 'config.schema' } from 'config.schema'
import type { Message, PartialUser, User } from 'discord.js' import type { Message, PartialUser, User } from 'discord.js'
import { eq } from 'drizzle-orm'
import { createMessageScanResponseEmbed } from './embeds' import { createMessageScanResponseEmbed } from './embeds'
export const getResponseFromText = async ( export const getResponseFromText = async (
@@ -138,7 +139,7 @@ export const shouldScanMessage = (
export const handleUserResponseCorrection = async ( export const handleUserResponseCorrection = async (
{ api, database: db, config: { messageScan: msConfig }, logger }: typeof import('$/context'), { api, database: db, config: { messageScan: msConfig }, logger }: typeof import('$/context'),
response: LabeledResponse, response: Response,
reply: Message, reply: Message,
label: string, label: string,
user: User | PartialUser, user: User | PartialUser,
@@ -151,14 +152,19 @@ export const handleUserResponseCorrection = async (
if (!correctLabelResponse.response) return void (await reply.delete()) if (!correctLabelResponse.response) return void (await reply.delete())
if (response.label !== label) { if (response.label !== label) {
db.labeledResponses.edit(response.reply, { label, correctedBy: user.id }) db.update(responses)
.set({
label,
correctedById: user.id,
})
.where(eq(responses.replyId, response.replyId))
await reply.edit({ await reply.edit({
embeds: [createMessageScanResponseEmbed(correctLabelResponse.response, 'nlp')], embeds: [createMessageScanResponseEmbed(correctLabelResponse.response, 'nlp')],
}) })
} }
await api.client.trainMessage(response.text, label) await api.client.trainMessage(response.content, label)
logger.debug(`User ${user.id} trained message ${response.reply} as ${label} (positive)`) logger.debug(`User ${user.id} trained message ${response.replyId} as ${label} (positive)`)
await reply.reactions.removeAll() await reply.reactions.removeAll()
} }