mirror of
https://github.com/ReVanced/revanced-bots.git
synced 2026-01-18 00:33:59 +00:00
feat(bots/discord): switch to drizzle-orm
This commit is contained in:
3
bots/discord/.gitignore
vendored
3
bots/discord/.gitignore
vendored
@@ -178,4 +178,5 @@ dist
|
|||||||
config.ts
|
config.ts
|
||||||
|
|
||||||
# DB
|
# DB
|
||||||
*.db
|
*.db
|
||||||
|
*.sqlite
|
||||||
@@ -30,6 +30,10 @@
|
|||||||
"@revanced/bot-api": "workspace:*",
|
"@revanced/bot-api": "workspace:*",
|
||||||
"@revanced/bot-shared": "workspace:*",
|
"@revanced/bot-shared": "workspace:*",
|
||||||
"chalk": "^5.3.0",
|
"chalk": "^5.3.0",
|
||||||
"discord.js": "^14.15.3"
|
"discord.js": "^14.15.3",
|
||||||
|
"drizzle-orm": "^0.31.2"
|
||||||
|
},
|
||||||
|
"devDependencies": {
|
||||||
|
"drizzle-kit": "^0.22.7"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,127 +0,0 @@
|
|||||||
import { Database } from 'bun:sqlite'
|
|
||||||
|
|
||||||
type BasicSQLBindings = string | number | null
|
|
||||||
|
|
||||||
export class BasicDatabase<T extends Record<string, BasicSQLBindings>> {
|
|
||||||
#db: Database
|
|
||||||
#table: string
|
|
||||||
|
|
||||||
constructor(file: string, struct: string, tableName = 'data') {
|
|
||||||
const db = new Database(file, {
|
|
||||||
create: true,
|
|
||||||
readwrite: true,
|
|
||||||
})
|
|
||||||
|
|
||||||
this.#db = db
|
|
||||||
this.#table = tableName
|
|
||||||
|
|
||||||
db.run(`CREATE TABLE IF NOT EXISTS ${tableName} (${struct});`)
|
|
||||||
}
|
|
||||||
|
|
||||||
run(statement: string) {
|
|
||||||
this.#db.run(statement)
|
|
||||||
}
|
|
||||||
|
|
||||||
prepare(statement: string) {
|
|
||||||
return this.#db.prepare<T, BasicSQLBindings[]>(statement)
|
|
||||||
}
|
|
||||||
|
|
||||||
query(statement: string) {
|
|
||||||
return this.#db.query<T, BasicSQLBindings[]>(statement)
|
|
||||||
}
|
|
||||||
|
|
||||||
insert(...values: BasicSQLBindings[]) {
|
|
||||||
this.run(`INSERT INTO ${this.#table} VALUES (${values.map(this.#encodeValue).join(', ')});`)
|
|
||||||
}
|
|
||||||
|
|
||||||
update(what: Partial<T>, where: string) {
|
|
||||||
const set = Object.entries(what)
|
|
||||||
.map(([key, value]) => `${key} = ${this.#encodeValue(value)}`)
|
|
||||||
.join(', ')
|
|
||||||
|
|
||||||
this.run(`UPDATE ${this.#table} SET ${set} WHERE ${where};`)
|
|
||||||
}
|
|
||||||
|
|
||||||
delete(where: string) {
|
|
||||||
this.run(`DELETE FROM ${this.#table} WHERE ${where};`)
|
|
||||||
}
|
|
||||||
|
|
||||||
select(columns: string[] | string, where: string) {
|
|
||||||
const realColumns = Array.isArray(columns) ? columns.join(', ') : columns
|
|
||||||
return this.query(`SELECT ${realColumns} FROM ${this.#table} WHERE ${where};`).get()
|
|
||||||
}
|
|
||||||
|
|
||||||
#encodeValue(value: unknown) {
|
|
||||||
if (typeof value === 'string') return `'${value.replaceAll("'", "\\'")}'`
|
|
||||||
if (typeof value === 'number') return value
|
|
||||||
if (typeof value === 'boolean') return value ? 1 : 0
|
|
||||||
if (value === null) return 'NULL'
|
|
||||||
return null
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
export class LabeledResponseDatabase {
|
|
||||||
#db: BasicDatabase<LabeledResponse>
|
|
||||||
|
|
||||||
constructor() {
|
|
||||||
this.#db = new BasicDatabase<LabeledResponse>(
|
|
||||||
'responses.db',
|
|
||||||
`reply TEXT PRIMARY KEY NOT NULL,
|
|
||||||
channel TEXT NOT NULL,
|
|
||||||
guild TEXT NOT NULL,
|
|
||||||
referenceMessage TEXT KEY NOT NULL,
|
|
||||||
label TEXT NOT NULL,
|
|
||||||
text TEXT NOT NULL,
|
|
||||||
correctedBy TEXT,
|
|
||||||
CHECK (
|
|
||||||
typeof("text") = 'text' AND
|
|
||||||
length("text") > 0 AND
|
|
||||||
length("text") <= 280
|
|
||||||
)`,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
save({ reply, channel, guild, referenceMessage, label, text }: Omit<LabeledResponse, 'correctedBy'>) {
|
|
||||||
const actualText = text.slice(0, 280)
|
|
||||||
this.#db.insert(reply, channel, guild, referenceMessage, label, actualText, null)
|
|
||||||
}
|
|
||||||
|
|
||||||
get(reply: string) {
|
|
||||||
return this.#db.select('*', `reply = ${reply}`)
|
|
||||||
}
|
|
||||||
|
|
||||||
edit(reply: string, { label, correctedBy }: Pick<LabeledResponse, 'label' | 'correctedBy'>) {
|
|
||||||
this.#db.update({ label, correctedBy }, `reply = ${reply}`)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
export type LabeledResponse = {
|
|
||||||
/**
|
|
||||||
* The label of the response
|
|
||||||
*/
|
|
||||||
label: string
|
|
||||||
/**
|
|
||||||
* The ID of the user who corrected the response
|
|
||||||
*/
|
|
||||||
correctedBy: string | null
|
|
||||||
/**
|
|
||||||
* The text content of the response
|
|
||||||
*/
|
|
||||||
text: string
|
|
||||||
/**
|
|
||||||
* The ID of the message that triggered the response
|
|
||||||
*/
|
|
||||||
referenceMessage: string
|
|
||||||
/**
|
|
||||||
* The ID of the channel where the response was sent
|
|
||||||
*/
|
|
||||||
channel: string
|
|
||||||
/**
|
|
||||||
* The ID of the guild where the response was sent
|
|
||||||
*/
|
|
||||||
guild: string
|
|
||||||
/**
|
|
||||||
* The ID of the reply
|
|
||||||
*/
|
|
||||||
reply: string
|
|
||||||
}
|
|
||||||
@@ -1,10 +1,14 @@
|
|||||||
import { loadCommands } from '$utils/discord/commands'
|
import { Database } from 'bun:sqlite'
|
||||||
import { Client as APIClient } from '@revanced/bot-api'
|
import { Client as APIClient } from '@revanced/bot-api'
|
||||||
import { createLogger } from '@revanced/bot-shared'
|
import { createLogger } from '@revanced/bot-shared'
|
||||||
import { ActivityType, Client as DiscordClient, Partials } from 'discord.js'
|
import { ActivityType, Client as DiscordClient, Partials } from 'discord.js'
|
||||||
|
import { drizzle } from 'drizzle-orm/bun-sqlite'
|
||||||
|
|
||||||
import config from '../config'
|
import config from '../config'
|
||||||
import { LabeledResponseDatabase } from './classes/Database'
|
import * as schemas from './database/schemas'
|
||||||
import { pathJoinCurrentDir } from './utils/fs'
|
|
||||||
|
import { loadCommands } from '$utils/discord/commands'
|
||||||
|
import { pathJoinCurrentDir } from '$utils/fs'
|
||||||
|
|
||||||
export { config }
|
export { config }
|
||||||
export const logger = createLogger({
|
export const logger = createLogger({
|
||||||
@@ -23,9 +27,11 @@ export const api = {
|
|||||||
disconnectCount: 0,
|
disconnectCount: 0,
|
||||||
}
|
}
|
||||||
|
|
||||||
export const database = {
|
const db = new Database('db.sqlite')
|
||||||
labeledResponses: new LabeledResponseDatabase(),
|
|
||||||
} as const
|
export const database = drizzle(db, {
|
||||||
|
schema: schemas,
|
||||||
|
})
|
||||||
|
|
||||||
export const discord = {
|
export const discord = {
|
||||||
client: new DiscordClient({
|
client: new DiscordClient({
|
||||||
|
|||||||
21
bots/discord/src/database/schemas.ts
Normal file
21
bots/discord/src/database/schemas.ts
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
import type { InferSelectModel } from 'drizzle-orm'
|
||||||
|
import { sqliteTable, text } from 'drizzle-orm/sqlite-core'
|
||||||
|
|
||||||
|
export const responses = sqliteTable('responses', {
|
||||||
|
replyId: text('reply').primaryKey().notNull(),
|
||||||
|
channelId: text('channel').notNull(),
|
||||||
|
guildId: text('guild').notNull(),
|
||||||
|
referenceId: text('ref').notNull(),
|
||||||
|
label: text('label').notNull(),
|
||||||
|
content: text('text').notNull(),
|
||||||
|
correctedById: text('by'),
|
||||||
|
})
|
||||||
|
|
||||||
|
export const appliedPresets = sqliteTable('applied_presets', {
|
||||||
|
memberId: text('member').primaryKey().notNull(),
|
||||||
|
guildId: text('guild').notNull(),
|
||||||
|
presets: text('presets', { mode: 'json' }).$type<string[]>().notNull().default([]),
|
||||||
|
})
|
||||||
|
|
||||||
|
export type Response = InferSelectModel<typeof responses>
|
||||||
|
export type AppliedPreset = InferSelectModel<typeof appliedPresets>
|
||||||
@@ -1,8 +1,10 @@
|
|||||||
|
import { responses } from '$/database/schemas'
|
||||||
import { handleUserResponseCorrection } from '$/utils/discord/messageScan'
|
import { handleUserResponseCorrection } from '$/utils/discord/messageScan'
|
||||||
import { createErrorEmbed, createStackTraceEmbed, createSuccessEmbed } from '$utils/discord/embeds'
|
import { createErrorEmbed, createStackTraceEmbed, createSuccessEmbed } from '$utils/discord/embeds'
|
||||||
import { on } from '$utils/discord/events'
|
import { on } from '$utils/discord/events'
|
||||||
|
|
||||||
import type { ButtonInteraction, StringSelectMenuInteraction, TextBasedChannel } from 'discord.js'
|
import type { ButtonInteraction, StringSelectMenuInteraction, TextBasedChannel } from 'discord.js'
|
||||||
|
import { eq } from 'drizzle-orm'
|
||||||
|
|
||||||
// No permission check required as it is already done when the user reacts to a bot response
|
// No permission check required as it is already done when the user reacts to a bot response
|
||||||
export default on('interactionCreate', async (context, interaction) => {
|
export default on('interactionCreate', async (context, interaction) => {
|
||||||
@@ -19,7 +21,7 @@ export default on('interactionCreate', async (context, interaction) => {
|
|||||||
const [, key, action] = interaction.customId.split('_') as ['cr', string, 'select' | 'cancel' | 'delete']
|
const [, key, action] = interaction.customId.split('_') as ['cr', string, 'select' | 'cancel' | 'delete']
|
||||||
if (!key || !action) return
|
if (!key || !action) return
|
||||||
|
|
||||||
const response = db.labeledResponses.get(key)
|
const response = await db.query.responses.findFirst({ where: eq(responses.replyId, key) })
|
||||||
// If the message isn't saved in my DB (unrelated message)
|
// If the message isn't saved in my DB (unrelated message)
|
||||||
if (!response)
|
if (!response)
|
||||||
return void (await interaction.reply({
|
return void (await interaction.reply({
|
||||||
@@ -30,8 +32,8 @@ export default on('interactionCreate', async (context, interaction) => {
|
|||||||
try {
|
try {
|
||||||
// We're gonna pretend reactionChannel is a text-based channel, but it can be many more
|
// We're gonna pretend reactionChannel is a text-based channel, but it can be many more
|
||||||
// But `messages` should always exist as a property
|
// But `messages` should always exist as a property
|
||||||
const reactionGuild = await interaction.client.guilds.fetch(response.guild)
|
const reactionGuild = await interaction.client.guilds.fetch(response.guildId)
|
||||||
const reactionChannel = (await reactionGuild.channels.fetch(response.channel)) as TextBasedChannel | null
|
const reactionChannel = (await reactionGuild.channels.fetch(response.channelId)) as TextBasedChannel | null
|
||||||
const reactionMessage = await reactionChannel?.messages.fetch(key)
|
const reactionMessage = await reactionChannel?.messages.fetch(key)
|
||||||
|
|
||||||
if (!reactionMessage) {
|
if (!reactionMessage) {
|
||||||
@@ -55,7 +57,7 @@ export default on('interactionCreate', async (context, interaction) => {
|
|||||||
const handleCorrection = (label: string) =>
|
const handleCorrection = (label: string) =>
|
||||||
handleUserResponseCorrection(context, response, reactionMessage, label, interaction.user)
|
handleUserResponseCorrection(context, response, reactionMessage, label, interaction.user)
|
||||||
|
|
||||||
if (response.correctedBy)
|
if (response.correctedById)
|
||||||
return await editMessage(
|
return await editMessage(
|
||||||
'Response already corrected',
|
'Response already corrected',
|
||||||
'Thank you for your feedback! Unfortunately, this response has already been corrected by someone else.',
|
'Thank you for your feedback! Unfortunately, this response has already been corrected by someone else.',
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
import { MessageScanLabeledResponseReactions } from '$/constants'
|
import { MessageScanLabeledResponseReactions } from '$/constants'
|
||||||
|
import { responses } from '$/database/schemas'
|
||||||
import { getResponseFromText, shouldScanMessage } from '$/utils/discord/messageScan'
|
import { getResponseFromText, shouldScanMessage } from '$/utils/discord/messageScan'
|
||||||
import { createMessageScanResponseEmbed } from '$utils/discord/embeds'
|
import { createMessageScanResponseEmbed } from '$utils/discord/embeds'
|
||||||
import { on } from '$utils/discord/events'
|
import { on } from '$utils/discord/events'
|
||||||
@@ -30,13 +31,13 @@ on('messageCreate', async (ctx, msg) => {
|
|||||||
})
|
})
|
||||||
|
|
||||||
if (label)
|
if (label)
|
||||||
db.labeledResponses.save({
|
db.insert(responses).values({
|
||||||
reply: reply.id,
|
replyId: reply.id,
|
||||||
channel: reply.channel.id,
|
channelId: reply.channel.id,
|
||||||
guild: reply.guild!.id,
|
guildId: reply.guild!.id,
|
||||||
referenceMessage: msg.id,
|
referenceId: msg.id,
|
||||||
label,
|
label,
|
||||||
text: msg.content,
|
content: msg.content,
|
||||||
})
|
})
|
||||||
|
|
||||||
if (label) {
|
if (label) {
|
||||||
|
|||||||
@@ -10,8 +10,10 @@ import {
|
|||||||
StringSelectMenuOptionBuilder,
|
StringSelectMenuOptionBuilder,
|
||||||
} from 'discord.js'
|
} from 'discord.js'
|
||||||
|
|
||||||
|
import { responses } from '$/database/schemas'
|
||||||
import { handleUserResponseCorrection } from '$/utils/discord/messageScan'
|
import { handleUserResponseCorrection } from '$/utils/discord/messageScan'
|
||||||
import type { ConfigMessageScanResponseLabelConfig } from 'config.schema'
|
import type { ConfigMessageScanResponseLabelConfig } from 'config.schema'
|
||||||
|
import { eq } from 'drizzle-orm'
|
||||||
|
|
||||||
const PossibleReactions = Object.values(Reactions) as string[]
|
const PossibleReactions = Object.values(Reactions) as string[]
|
||||||
|
|
||||||
@@ -57,8 +59,8 @@ on('messageReactionAdd', async (context, rct, user) => {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Sanity check
|
// Sanity check
|
||||||
const response = db.labeledResponses.get(rct.message.id)
|
const response = await db.query.responses.findFirst({ where: eq(responses.replyId, rct.message.id) })
|
||||||
if (!response || response.correctedBy) return
|
if (!response || response.correctedById) return
|
||||||
|
|
||||||
const handleCorrection = (label: string) =>
|
const handleCorrection = (label: string) =>
|
||||||
handleUserResponseCorrection(context, response, reactionMessage, label, user)
|
handleUserResponseCorrection(context, response, reactionMessage, label, user)
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
import type { LabeledResponse } from '$/classes/Database'
|
import { type Response, responses } from '$/database/schemas'
|
||||||
import type {
|
import type {
|
||||||
Config,
|
Config,
|
||||||
ConfigMessageScanResponse,
|
ConfigMessageScanResponse,
|
||||||
@@ -6,6 +6,7 @@ import type {
|
|||||||
ConfigMessageScanResponseMessage,
|
ConfigMessageScanResponseMessage,
|
||||||
} from 'config.schema'
|
} from 'config.schema'
|
||||||
import type { Message, PartialUser, User } from 'discord.js'
|
import type { Message, PartialUser, User } from 'discord.js'
|
||||||
|
import { eq } from 'drizzle-orm'
|
||||||
import { createMessageScanResponseEmbed } from './embeds'
|
import { createMessageScanResponseEmbed } from './embeds'
|
||||||
|
|
||||||
export const getResponseFromText = async (
|
export const getResponseFromText = async (
|
||||||
@@ -138,7 +139,7 @@ export const shouldScanMessage = (
|
|||||||
|
|
||||||
export const handleUserResponseCorrection = async (
|
export const handleUserResponseCorrection = async (
|
||||||
{ api, database: db, config: { messageScan: msConfig }, logger }: typeof import('$/context'),
|
{ api, database: db, config: { messageScan: msConfig }, logger }: typeof import('$/context'),
|
||||||
response: LabeledResponse,
|
response: Response,
|
||||||
reply: Message,
|
reply: Message,
|
||||||
label: string,
|
label: string,
|
||||||
user: User | PartialUser,
|
user: User | PartialUser,
|
||||||
@@ -151,14 +152,19 @@ export const handleUserResponseCorrection = async (
|
|||||||
if (!correctLabelResponse.response) return void (await reply.delete())
|
if (!correctLabelResponse.response) return void (await reply.delete())
|
||||||
|
|
||||||
if (response.label !== label) {
|
if (response.label !== label) {
|
||||||
db.labeledResponses.edit(response.reply, { label, correctedBy: user.id })
|
db.update(responses)
|
||||||
|
.set({
|
||||||
|
label,
|
||||||
|
correctedById: user.id,
|
||||||
|
})
|
||||||
|
.where(eq(responses.replyId, response.replyId))
|
||||||
await reply.edit({
|
await reply.edit({
|
||||||
embeds: [createMessageScanResponseEmbed(correctLabelResponse.response, 'nlp')],
|
embeds: [createMessageScanResponseEmbed(correctLabelResponse.response, 'nlp')],
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
await api.client.trainMessage(response.text, label)
|
await api.client.trainMessage(response.content, label)
|
||||||
logger.debug(`User ${user.id} trained message ${response.reply} as ${label} (positive)`)
|
logger.debug(`User ${user.id} trained message ${response.replyId} as ${label} (positive)`)
|
||||||
|
|
||||||
await reply.reactions.removeAll()
|
await reply.reactions.removeAll()
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user