diff options
Diffstat (limited to 'main.go')
-rw-r--r-- | main.go | 135 |
1 files changed, 112 insertions, 23 deletions
@@ -62,10 +62,33 @@ func hammingDistance(a, b int64) int { return bits.OnesCount64(uint64(a) ^ uint64(b)) } +type productAggregator float64 + +func (pa *productAggregator) Step(v float64) { + *pa = productAggregator(float64(*pa) * v) +} + +func (pa *productAggregator) Done() float64 { + return float64(*pa) +} + +func newProductAggregator() *productAggregator { + pa := productAggregator(1) + return &pa +} + func init() { sql.Register("sqlite3_custom", &sqlite3.SQLiteDriver{ ConnectHook: func(conn *sqlite3.SQLiteConn) error { - return conn.RegisterFunc("hamming", hammingDistance, true /*pure*/) + if err := conn.RegisterFunc( + "hamming", hammingDistance, true /*pure*/); err != nil { + return err + } + if err := conn.RegisterAggregator( + "product", newProductAggregator, true /*pure*/); err != nil { + return err + } + return nil }, }) } @@ -956,17 +979,89 @@ func handleAPISimilar(w http.ResponseWriter, r *http.Request) { } // - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - +// This is the most miserable part of the whole program. -// NOTE: AND will mean MULTIPLY(IFNULL(ta.weight, 0)) per SHA1. -const searchCTE = `WITH +const searchCTE1 = `WITH matches(sha1, thumbw, thumbh, score) AS ( SELECT i.sha1, i.thumbw, i.thumbh, ta.weight AS score FROM tag_assignment AS ta JOIN image AS i ON i.sha1 = ta.sha1 - WHERE ta.tag = ? + WHERE ta.tag = %d + ) +` + +const searchCTEMulti = `WITH + positive(tag) AS (VALUES %s), + candidates(sha1) AS (%s), + matches(sha1, thumbw, thumbh, score) AS ( + SELECT i.sha1, i.thumbw, i.thumbh, + product(IFNULL(ta.weight, 0)) AS score + FROM image AS i, positive AS p + JOIN candidates AS c ON i.sha1 = c.sha1 + LEFT JOIN tag_assignment AS ta ON ta.sha1 = i.sha1 AND ta.tag = p.tag + GROUP BY i.sha1 ) ` +func parseQuery(query string) (string, error) { + positive, negative := []int64{}, []int64{} + for _, word := range strings.Split(query, " ") { + if word == "" { + continue + } + + space, tag, _ := strings.Cut(word, ":") + + negated := false + if strings.HasPrefix(space, "-") { + space = space[1:] + negated = true + } + + var tagID int64 + err := db.QueryRow(` + SELECT t.id FROM tag AS t + JOIN tag_space AS ts ON t.space = ts.id + WHERE ts.name = ? AND t.name = ?`, space, tag).Scan(&tagID) + if err != nil { + return "", err + } + + if negated { + negative = append(negative, tagID) + } else { + positive = append(positive, tagID) + } + } + + // Don't return most of the database, and simplify the following builder. + if len(positive) == 0 { + return "", errors.New("search is too wide") + } + + // Optimise single tag searches. + if len(positive) == 1 && len(negative) == 0 { + return fmt.Sprintf(searchCTE1, positive[0]), nil + } + + values := fmt.Sprintf(`(%d)`, positive[0]) + candidates := fmt.Sprintf( + `SELECT sha1 FROM tag_assignment WHERE tag = %d`, positive[0]) + for _, tagID := range positive[1:] { + values += fmt.Sprintf(`, (%d)`, tagID) + candidates += fmt.Sprintf(` INTERSECT + SELECT sha1 FROM tag_assignment WHERE tag = %d`, tagID) + } + for _, tagID := range negative { + candidates += fmt.Sprintf(` EXCEPT + SELECT sha1 FROM tag_assignment WHERE tag = %d`, tagID) + } + + return fmt.Sprintf(searchCTEMulti, values, candidates), nil +} + +// - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + type webTagMatch struct { SHA1 string `json:"sha1"` ThumbW int64 `json:"thumbW"` @@ -974,10 +1069,10 @@ type webTagMatch struct { Score float32 `json:"score"` } -func getTagMatches(tag int64) (matches []webTagMatch, err error) { - rows, err := db.Query(searchCTE+` +func getTagMatches(cte string) (matches []webTagMatch, err error) { + rows, err := db.Query(cte + ` SELECT sha1, IFNULL(thumbw, 0), IFNULL(thumbh, 0), score - FROM matches`, tag) + FROM matches`) if err != nil { return nil, err } @@ -1001,13 +1096,13 @@ type webTagSupertag struct { score float32 } -func getTagSupertags(tag int64) (result map[int64]*webTagSupertag, err error) { - rows, err := db.Query(searchCTE+` +func getTagSupertags(cte string) (result map[int64]*webTagSupertag, err error) { + rows, err := db.Query(cte + ` SELECT DISTINCT ta.tag, ts.name, t.name FROM tag_assignment AS ta JOIN matches AS m ON m.sha1 = ta.sha1 JOIN tag AS t ON ta.tag = t.id - JOIN tag_space AS ts ON ts.id = t.space`, tag) + JOIN tag_space AS ts ON ts.id = t.space`) if err != nil { return nil, err } @@ -1032,18 +1127,18 @@ type webTagRelated struct { Score float32 `json:"score"` } -func getTagRelated(tag int64, matches int) ( +func getTagRelated(cte string, matches int) ( result map[string][]webTagRelated, err error) { // Not sure if this level of efficiency is achievable directly in SQL. - supertags, err := getTagSupertags(tag) + supertags, err := getTagSupertags(cte) if err != nil { return nil, err } - rows, err := db.Query(searchCTE+` + rows, err := db.Query(cte + ` SELECT ta.tag, ta.weight FROM tag_assignment AS ta - JOIN matches AS m ON m.sha1 = ta.sha1`, tag) + JOIN matches AS m ON m.sha1 = ta.sha1`) if err != nil { return nil, err } @@ -1084,13 +1179,7 @@ func handleAPISearch(w http.ResponseWriter, r *http.Request) { Related map[string][]webTagRelated `json:"related"` } - space, tag, _ := strings.Cut(params.Query, ":") - - var tagID int64 - err := db.QueryRow(` - SELECT t.id FROM tag AS t - JOIN tag_space AS ts ON t.space = ts.id - WHERE ts.name = ? AND t.name = ?`, space, tag).Scan(&tagID) + cte, err := parseQuery(params.Query) if errors.Is(err, sql.ErrNoRows) { http.Error(w, err.Error(), http.StatusNotFound) return @@ -1099,11 +1188,11 @@ func handleAPISearch(w http.ResponseWriter, r *http.Request) { return } - if result.Matches, err = getTagMatches(tagID); err != nil { + if result.Matches, err = getTagMatches(cte); err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return } - if result.Related, err = getTagRelated(tagID, + if result.Related, err = getTagRelated(cte, len(result.Matches)); err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return |