Calvin Lee: 3 database: batch msg inserts contrib/migrate-db: use explicit src/dest network contrib: wrap errors correctly 9 files changed, 160 insertions(+), 115 deletions(-)
Pushed, thanks for the fix!
It seems like this patch doesn't apply anymore -- can you rebase it? Also can we avoid %w unless we need it? %w makes the wrapped errors part of the API. IOW: if the wrapped errors change, it's an API break. (It doesn't really matter here since we're not a library.)
Copy & paste the following snippet into your terminal to import this patchset into git:
curl -s https://lists.sr.ht/~emersion/soju-dev/patches/42519/mbox | git am -3Learn more about email & git
This commit takes insert query compilation and transaction creation out of the critical loop for migrating message logs. I have tested with the sqlite backend, and a speedup of approximately 40x has been achieved for log migration. I would appreciate help testing the postgres change. --- This version removes the `StoreMessage` function from the database interface in favor of `StoreMessages`. contrib/migrate-logs/main.go | 14 ++++-- database/database.go | 2 +- database/postgres.go | 88 +++++++++++++++++++-------------- database/sqlite.go | 94 ++++++++++++++++++++++-------------- msgstore/db.go | 4 +- 5 files changed, 123 insertions(+), 79 deletions(-) diff --git a/contrib/migrate-logs/main.go b/contrib/migrate-logs/main.go index 42ad156..2b53aa8 100644 --- a/contrib/migrate-logs/main.go +++ b/contrib/migrate-logs/main.go @@ -15,6 +15,7 @@ import ( "git.sr.ht/~emersion/soju/database" "git.sr.ht/~emersion/soju/msgstore" "git.sr.ht/~emersion/soju/msgstore/znclog" + "gopkg.in/irc.v4" ) const usage = `usage: migrate-logs <source logs> <destination database> @@ -88,6 +89,7 @@ func migrateNetwork(ctx context.Context, db database.Database, user *database.Us return fmt.Errorf("unable to open entry: %s", entryPath) } sc := bufio.NewScanner(entry) + var msgs []*irc.Message for sc.Scan() { msg, _, err := znclog.UnmarshalLine(sc.Text(), user, network, target, ref, true) if err != nil { @@ -95,14 +97,18 @@ func migrateNetwork(ctx context.Context, db database.Database, user *database.Us } else if msg == nil { continue } - _, err = db.StoreMessage(ctx, network.ID, target, msg) - if err != nil { - return fmt.Errorf("unable to store message: %s: %s: %v", entryPath, sc.Text(), err) - } + msgs = append(msgs, msg) } if sc.Err() != nil { return fmt.Errorf("unable to parse entry: %s: %v", entryPath, sc.Err()) } + if len(msgs) == 0 { + continue
This leaks the file: we need to call entry.Close(). I've just dropped this if since optimizing for this case shouldn't really matter too much.
+ } + _, err = db.StoreMessages(ctx, network.ID, target, msgs) + if err != nil { + return fmt.Errorf("unable to store message: %s: %s: %v", entryPath, sc.Text(), err) + } entry.Close() } } diff --git a/database/database.go b/database/database.go index 15c04b9..44fa60e 100644 --- a/database/database.go +++ b/database/database.go @@ -60,7 +60,7 @@ type Database interface { DeleteWebPushSubscription(ctx context.Context, id int64) error GetMessageLastID(ctx context.Context, networkID int64, name string) (int64, error) - StoreMessage(ctx context.Context, networkID int64, name string, msg *irc.Message) (int64, error) + StoreMessages(ctx context.Context, networkID int64, name string, msgs []*irc.Message) ([]int64, error) ListMessageLastPerTarget(ctx context.Context, networkID int64, options *MessageOptions) ([]MessageTarget, error) ListMessages(ctx context.Context, networkID int64, name string, options *MessageOptions) ([]*irc.Message, error) } diff --git a/database/postgres.go b/database/postgres.go index e18cc05..c31a3d7 100644 --- a/database/postgres.go +++ b/database/postgres.go @@ -928,31 +928,18 @@ func (db *PostgresDB) GetMessageLastID(ctx context.Context, networkID int64, nam return msgID, nil } -func (db *PostgresDB) StoreMessage(ctx context.Context, networkID int64, name string, msg *irc.Message) (int64, error) { - ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout) - defer cancel() - - var t time.Time - if tag, ok := msg.Tags["time"]; ok { - var err error - t, err = time.Parse(xirc.ServerTimeLayout, tag) - if err != nil { - return 0, fmt.Errorf("failed to parse message time tag: %v", err) - } - } else { - t = time.Now() +func (db *PostgresDB) StoreMessages(ctx context.Context, networkID int64, name string, msgs []*irc.Message) ([]int64, error) { + if len(msgs) == 0 { + return nil, nil } - - var text sql.NullString - switch msg.Command { - case "PRIVMSG", "NOTICE": - if len(msg.Params) > 1 { - text.Valid = true - text.String = msg.Params[1] - } + ctx, cancel := context.WithTimeout(ctx, time.Duration(len(msgs))*sqliteQueryTimeout) + defer cancel() + tx, err := db.db.BeginTx(ctx, nil) + if err != nil { + return nil, err } - _, err := db.db.ExecContext(ctx, ` + _, err = tx.ExecContext(ctx, ` INSERT INTO "MessageTarget" (network, target) VALUES ($1, $2) ON CONFLICT DO NOTHING`, @@ -960,27 +947,58 @@ func (db *PostgresDB) StoreMessage(ctx context.Context, networkID int64, name st name, ) if err != nil { - return 0, err + tx.Rollback()
Instead of tx.Rollback() in many places, can use defer to avoid the repetition.
+ return nil, err } - var id int64 - err = db.db.QueryRowContext(ctx, ` + insertStmt, err := tx.PrepareContext(ctx, ` INSERT INTO "Message" (target, raw, time, sender, text) SELECT id, $1, $2, $3, $4 FROM "MessageTarget" as t WHERE network = $5 AND target = $6 - RETURNING id`, - msg.String(), - t, - msg.Name, - text, - networkID, - name, - ).Scan(&id) + RETURNING id`) if err != nil { - return 0, err + tx.Rollback() + return nil, err + } + ids := make([]int64, len(msgs)) + for idx, msg := range msgs { + var t time.Time + if tag, ok := msg.Tags["time"]; ok { + var err error + t, err = time.Parse(xirc.ServerTimeLayout, tag) + if err != nil { + tx.Rollback() + return nil, fmt.Errorf("failed to parse message time tag: %w", err) + } + } else { + t = time.Now() + } + + var text sql.NullString + switch msg.Command { + case "PRIVMSG", "NOTICE": + if len(msg.Params) > 1 { + text.Valid = true + text.String = msg.Params[1] + } + } + + err = insertStmt.QueryRowContext(ctx, + msg.String(), + t, + msg.Name, + text, + networkID, + name, + ).Scan(&ids[idx]) + if err != nil { + tx.Rollback() + return nil, err + } } - return id, nil + tx.Commit()
We need to handle errors here.
+ return ids, nil } func (db *PostgresDB) ListMessageLastPerTarget(ctx context.Context, networkID int64, options *MessageOptions) ([]MessageTarget, error) { diff --git a/database/sqlite.go b/database/sqlite.go index 5c07b66..7392b91 100644 --- a/database/sqlite.go +++ b/database/sqlite.go
(Comments above also apply to the SQLite variant.)
@@ -1191,31 +1191,18 @@ func (db *SqliteDB) GetMessageLastID(ctx context.Context, networkID int64, name return msgID, nil } -func (db *SqliteDB) StoreMessage(ctx context.Context, networkID int64, name string, msg *irc.Message) (int64, error) { - ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout) - defer cancel() - - var t time.Time - if tag, ok := msg.Tags["time"]; ok { - var err error - t, err = time.Parse(xirc.ServerTimeLayout, tag) - if err != nil { - return 0, fmt.Errorf("failed to parse message time tag: %v", err) - } - } else { - t = time.Now() +func (db *SqliteDB) StoreMessages(ctx context.Context, networkID int64, name string, msgs []*irc.Message) ([]int64, error) { + if len(msgs) == 0 { + return nil, nil } - - var text sql.NullString - switch msg.Command { - case "PRIVMSG", "NOTICE": - if len(msg.Params) > 1 { - text.Valid = true - text.String = msg.Params[1] - } + ctx, cancel := context.WithTimeout(ctx, time.Duration(len(msgs))*sqliteQueryTimeout) + defer cancel() + tx, err := db.db.BeginTx(ctx, nil) + if err != nil { + return nil, err } - res, err := db.db.ExecContext(ctx, ` + res, err := tx.ExecContext(ctx, ` INSERT INTO MessageTarget(network, target) VALUES (:network, :target) ON CONFLICT DO NOTHING`, @@ -1223,29 +1210,62 @@ func (db *SqliteDB) StoreMessage(ctx context.Context, networkID int64, name stri sql.Named("target", name), ) if err != nil { - return 0, err + tx.Rollback() + return nil, err } - res, err = db.db.ExecContext(ctx, ` + insertStmt, err := tx.PrepareContext(ctx, ` INSERT INTO Message(target, raw, time, sender, text) SELECT id, :raw, :time, :sender, :text FROM MessageTarget as t - WHERE network = :network AND target = :target`, - sql.Named("network", networkID), - sql.Named("target", name), - sql.Named("raw", msg.String()), - sql.Named("time", sqliteTime{t}), - sql.Named("sender", msg.Name), - sql.Named("text", text), - ) + WHERE network = :network AND target = :target`) if err != nil { - return 0, err + tx.Rollback() + return nil, err } - id, err := res.LastInsertId() - if err != nil { - return 0, err + ids := make([]int64, len(msgs)) + for idx, msg := range msgs { + var t time.Time + if tag, ok := msg.Tags["time"]; ok { + var err error + t, err = time.Parse(xirc.ServerTimeLayout, tag) + if err != nil { + return nil, fmt.Errorf("failed to parse message time tag: %w", err) + } + } else { + t = time.Now() + } + + var text sql.NullString + switch msg.Command { + case "PRIVMSG", "NOTICE": + if len(msg.Params) > 1 { + text.Valid = true + text.String = msg.Params[1] + } + } + + res, err = insertStmt.ExecContext(ctx, + sql.Named("network", networkID), + sql.Named("target", name), + sql.Named("raw", msg.String()), + sql.Named("time", sqliteTime{t}), + sql.Named("sender", msg.Name), + sql.Named("text", text), + ) + if err != nil { + tx.Rollback() + return nil, err + } + id, err := res.LastInsertId() + if err != nil { + tx.Rollback() + return nil, err + } + ids[idx] = id } - return id, nil + tx.Commit() + return ids, nil } func (db *SqliteDB) ListMessageLastPerTarget(ctx context.Context, networkID int64, options *MessageOptions) ([]MessageTarget, error) { diff --git a/msgstore/db.go b/msgstore/db.go index 289253a..46e0f03 100644 --- a/msgstore/db.go +++ b/msgstore/db.go @@ -81,11 +81,11 @@ func (ms *dbMessageStore) LoadLatestID(ctx context.Context, id string, options * } func (ms *dbMessageStore) Append(network *database.Network, entity string, msg *irc.Message) (string, error) { - id, err := ms.db.StoreMessage(context.TODO(), network.ID, entity, msg) + ids, err := ms.db.StoreMessages(context.TODO(), network.ID, entity, []*irc.Message{msg}) if err != nil { return "", err } - return formatDBMsgID(network.ID, entity, id), nil + return formatDBMsgID(network.ID, entity, ids[0]), nil } func (ms *dbMessageStore) ListTargets(ctx context.Context, network *database.Network, start, end time.Time, limit int, events bool) ([]ChatHistoryTarget, error) { -- 2.41.0
Thanks for the patch! I've pushed it with minor edits, see below.
This commit makes the source and destination network distinction explicit. This is necessary, as the source and destination network may not have the same ID in the database, and thus associations will be broken when migrated. --- contrib/migrate-db/main.go | 37 +++++++++++++++++++------------------ 1 file changed, 19 insertions(+), 18 deletions(-) diff --git a/contrib/migrate-db/main.go b/contrib/migrate-db/main.go index 823ca30..bb45c9d 100644 --- a/contrib/migrate-db/main.go +++ b/contrib/migrate-db/main.go @@ -71,19 +71,20 @@ func main() { log.Fatalf("unable to get source networks for user: #%d %s", user.ID, user.Username) } - for _, network := range networks { - log.Printf("Storing network: %s\n", network.Name) + for _, srcNetwork := range networks { + log.Printf("Storing network: %s\n", srcNetwork.Name) + destNetwork := srcNetwork - network.ID = 0 + destNetwork.ID = 0 - err := destinationdb.StoreNetwork(ctx, user.ID, &network) + err := destinationdb.StoreNetwork(ctx, user.ID, &destNetwork) if err != nil { - log.Fatalf("unable to store network: #%d %s", network.ID, network.Name) + log.Fatalf("unable to store network: #%d %s", srcNetwork.ID, srcNetwork.Name) } - channels, err := sourcedb.ListChannels(ctx, network.ID) + channels, err := sourcedb.ListChannels(ctx, srcNetwork.ID) if err != nil { - log.Fatalf("unable to get source channels for network: #%d %s", network.ID, network.Name) + log.Fatalf("unable to get source channels for network: #%d %s", srcNetwork.ID, srcNetwork.Name) } for _, channel := range channels { @@ -91,15 +92,15 @@ func main() { channel.ID = 0 - err := destinationdb.StoreChannel(ctx, network.ID, &channel) + err := destinationdb.StoreChannel(ctx, destNetwork.ID, &channel) if err != nil { log.Fatalf("unable to store channel: #%d %s", channel.ID, channel.Name) } } - deliveryReceipts, err := sourcedb.ListDeliveryReceipts(ctx, network.ID) + deliveryReceipts, err := sourcedb.ListDeliveryReceipts(ctx, srcNetwork.ID) if err != nil { - log.Fatalf("unable to get source delivery receipts for network: #%d %s", network.ID, network.Name) + log.Fatalf("unable to get source delivery receipts for network: #%d %s", srcNetwork.ID, srcNetwork.Name) } drcpts := make(map[string][]database.DeliveryReceipt) @@ -115,28 +116,28 @@ func main() { } for client, rcpts := range drcpts { - log.Printf("Storing delivery receipt for: %s.%s.%s", user.Username, network.Name, client) - err := destinationdb.StoreClientDeliveryReceipts(ctx, network.ID, client, rcpts) + log.Printf("Storing delivery receipt for: %s.%s.%s", user.Username, srcNetwork.Name, client) + err := destinationdb.StoreClientDeliveryReceipts(ctx, destNetwork.ID, client, rcpts) if err != nil { - log.Fatalf("unable to store delivery receipts for network and client: %s %s", network.Name, client) + log.Fatalf("unable to store delivery receipts for network and client: %s %s", srcNetwork.Name, client) } } // TODO: migrate read receipts as well - webPushSubscriptions, err := sourcedb.ListWebPushSubscriptions(ctx, user.ID, network.ID) + webPushSubscriptions, err := sourcedb.ListWebPushSubscriptions(ctx, user.ID, srcNetwork.ID) if err != nil { - log.Fatalf("unable to get source web push subscriptions for user and network: %s %s", user.Username, network.Name) + log.Fatalf("unable to get source web push subscriptions for user and network: %s %s", user.Username, srcNetwork.Name) } for _, sub := range webPushSubscriptions { - log.Printf("Storing web push subscription: %s.%s.%d", user.Username, network.Name, sub.ID) + log.Printf("Storing web push subscription: %s.%s.%d", user.Username, srcNetwork.Name, sub.ID) sub.ID = 0 - err := destinationdb.StoreWebPushSubscription(ctx, user.ID, network.ID, &sub) + err := destinationdb.StoreWebPushSubscription(ctx, user.ID, destNetwork.ID, &sub) if err != nil { - log.Fatalf("unable to store web push subscription for user and network: %s %s", user.Username, network.Name) + log.Fatalf("unable to store web push subscription for user and network: %s %s", user.Username, srcNetwork.Name) } } } -- 2.41.0
Pushed, thanks for the fix!
Several error messages in `migrate-logs` are not correctly wrapping the underlying error, and do not display the underlying error message. This commit should make debugging a log migration easier, as it wraps all underlying errors. --- contrib/migrate-db/main.go | 20 ++++++++++---------- contrib/migrate-logs/main.go | 14 +++++++------- msgstore/znclog/reader.go | 2 +- 3 files changed, 18 insertions(+), 18 deletions(-) diff --git a/contrib/migrate-db/main.go b/contrib/migrate-db/main.go index bb45c9d..567ca09 100644 --- a/contrib/migrate-db/main.go +++ b/contrib/migrate-db/main.go @@ -68,7 +68,7 @@ func main() { networks, err := sourcedb.ListNetworks(ctx, user.ID) if err != nil { - log.Fatalf("unable to get source networks for user: #%d %s", user.ID, user.Username) + log.Fatalf("unable to get source networks for user: #%d %s: %v", user.ID, user.Username, err) } for _, srcNetwork := range networks { @@ -79,12 +79,12 @@ func main() { err := destinationdb.StoreNetwork(ctx, user.ID, &destNetwork) if err != nil { - log.Fatalf("unable to store network: #%d %s", srcNetwork.ID, srcNetwork.Name) + log.Fatalf("unable to store network: #%d %s: %v", srcNetwork.ID, srcNetwork.Name, err) } channels, err := sourcedb.ListChannels(ctx, srcNetwork.ID) if err != nil { - log.Fatalf("unable to get source channels for network: #%d %s", srcNetwork.ID, srcNetwork.Name) + log.Fatalf("unable to get source channels for network: #%d %s: %v", srcNetwork.ID, srcNetwork.Name, err) } for _, channel := range channels { @@ -94,13 +94,13 @@ func main() { err := destinationdb.StoreChannel(ctx, destNetwork.ID, &channel) if err != nil { - log.Fatalf("unable to store channel: #%d %s", channel.ID, channel.Name) + log.Fatalf("unable to store channel: #%d %s: %v", channel.ID, channel.Name, err) } } deliveryReceipts, err := sourcedb.ListDeliveryReceipts(ctx, srcNetwork.ID) if err != nil { - log.Fatalf("unable to get source delivery receipts for network: #%d %s", srcNetwork.ID, srcNetwork.Name) + log.Fatalf("unable to get source delivery receipts for network: #%d %s: %v", srcNetwork.ID, srcNetwork.Name, err) } drcpts := make(map[string][]database.DeliveryReceipt) @@ -119,7 +119,7 @@ func main() { log.Printf("Storing delivery receipt for: %s.%s.%s", user.Username, srcNetwork.Name, client) err := destinationdb.StoreClientDeliveryReceipts(ctx, destNetwork.ID, client, rcpts) if err != nil { - log.Fatalf("unable to store delivery receipts for network and client: %s %s", srcNetwork.Name, client) + log.Fatalf("unable to store delivery receipts for network and client: %s %s: %v", srcNetwork.Name, client, err) } } @@ -127,7 +127,7 @@ func main() { webPushSubscriptions, err := sourcedb.ListWebPushSubscriptions(ctx, user.ID, srcNetwork.ID) if err != nil { - log.Fatalf("unable to get source web push subscriptions for user and network: %s %s", user.Username, srcNetwork.Name) + log.Fatalf("unable to get source web push subscriptions for user and network: %s %s: %v", user.Username, srcNetwork.Name, err) } for _, sub := range webPushSubscriptions { @@ -137,7 +137,7 @@ func main() { err := destinationdb.StoreWebPushSubscription(ctx, user.ID, destNetwork.ID, &sub) if err != nil { - log.Fatalf("unable to store web push subscription for user and network: %s %s", user.Username, srcNetwork.Name) + log.Fatalf("unable to store web push subscription for user and network: %s %s, %v", user.Username, srcNetwork.Name, err) } } } @@ -145,7 +145,7 @@ func main() { webPushConfigs, err := sourcedb.ListWebPushConfigs(ctx) if err != nil { - log.Fatal("unable to get source web push configs") + log.Fatalf("unable to get source web push configs: %v", err) } for _, config := range webPushConfigs { @@ -153,7 +153,7 @@ func main() { config.ID = 0 err := destinationdb.StoreWebPushConfig(ctx, &config) if err != nil { - log.Fatalf("unable to store web push config: #%d", config.ID) + log.Fatalf("unable to store web push config: #%d: %v", config.ID, err) } } } diff --git a/contrib/migrate-logs/main.go b/contrib/migrate-logs/main.go index 2b53aa8..db6446a 100644 --- a/contrib/migrate-logs/main.go +++ b/contrib/migrate-logs/main.go @@ -64,13 +64,13 @@ func migrateNetwork(ctx context.Context, db database.Database, user *database.Us targetPath := filepath.Join(rootPath, target) targetDir, err := os.Open(targetPath) if err != nil { - return fmt.Errorf("unable to open target folder: %s", targetPath) + return fmt.Errorf("unable to open target folder '%s': %w", targetPath, err) } entryNames, err := targetDir.Readdirnames(0) targetDir.Close() if err != nil { - return fmt.Errorf("unable to read target folder: %s", targetPath) + return fmt.Errorf("unable to read target folder '%s': %w", targetPath, err) } sort.Strings(entryNames) @@ -80,34 +80,34 @@ func migrateNetwork(ctx context.Context, db database.Database, user *database.Us var year, month, day int _, err := fmt.Sscanf(entryName, "%04d-%02d-%02d.log", &year, &month, &day) if err != nil { - return fmt.Errorf("invalid entry name: %s", entryName) + return fmt.Errorf("invalid entry name '%s': %w", entryName, err) } ref := time.Date(year, time.Month(month), day, 0, 0, 0, 0, time.UTC) entry, err := os.Open(entryPath) if err != nil { - return fmt.Errorf("unable to open entry: %s", entryPath) + return fmt.Errorf("unable to open entry '%s': %w", entryPath, err) } sc := bufio.NewScanner(entry) var msgs []*irc.Message for sc.Scan() { msg, _, err := znclog.UnmarshalLine(sc.Text(), user, network, target, ref, true) if err != nil { - return fmt.Errorf("unable to parse entry: %s: %s", entryPath, sc.Text()) + return fmt.Errorf("unable to parse entry '%s: %s': %w", entryPath, sc.Text(), err) } else if msg == nil { continue } msgs = append(msgs, msg) } if sc.Err() != nil { - return fmt.Errorf("unable to parse entry: %s: %v", entryPath, sc.Err()) + return fmt.Errorf("unable to parse entry: %s: %w", entryPath, sc.Err()) } if len(msgs) == 0 { continue } _, err = db.StoreMessages(ctx, network.ID, target, msgs) if err != nil { - return fmt.Errorf("unable to store message: %s: %s: %v", entryPath, sc.Text(), err) + return fmt.Errorf("unable to store message: %s: %s: %w", entryPath, sc.Text(), err) } entry.Close() } diff --git a/msgstore/znclog/reader.go b/msgstore/znclog/reader.go index 5d7fe1f..befddc5 100644 --- a/msgstore/znclog/reader.go +++ b/msgstore/znclog/reader.go @@ -17,7 +17,7 @@ func UnmarshalLine(line string, user *database.User, network *database.Network, var hour, minute, second int _, err := fmt.Sscanf(line, "[%02d:%02d:%02d] ", &hour, &minute, &second) if err != nil { - return nil, time.Time{}, fmt.Errorf("malformed timestamp prefix: %v", err) + return nil, time.Time{}, fmt.Errorf("malformed timestamp prefix: %w", err) } else if len(line) < timestampPrefixLen { return nil, time.Time{}, fmt.Errorf("malformed timestamp prefix: too short") } -- 2.41.0
It seems like this patch doesn't apply anymore -- can you rebase it? Also can we avoid %w unless we need it? %w makes the wrapped errors part of the API. IOW: if the wrapped errors change, it's an API break. (It doesn't really matter here since we're not a library.)