about summary refs log tree commit diff
path: root/main.go
diff options
context:
space:
mode:
Diffstat (limited to 'main.go')
-rw-r--r-- main.go 135
1 files changed, 112 insertions, 23 deletions
diff --git a/main.go b/main.go
index 7c9aceb..644d2d4 100644
--- a/main.go
+++ b/main.go
@@ -62,10 +62,33 @@ func hammingDistance(a, b int64) int {
return bits.OnesCount64(uint64(a) ^ uint64(b))
}
+type productAggregator float64
+
+func (pa *productAggregator) Step(v float64) {
+ *pa = productAggregator(float64(*pa) * v)
+}
+
+func (pa *productAggregator) Done() float64 {
+ return float64(*pa)
+}
+
+func newProductAggregator() *productAggregator {
+ pa := productAggregator(1)
+ return &pa
+}
+
func init() {
sql.Register("sqlite3_custom", &sqlite3.SQLiteDriver{
ConnectHook: func(conn *sqlite3.SQLiteConn) error {
- return conn.RegisterFunc("hamming", hammingDistance, true /*pure*/)
+ if err := conn.RegisterFunc(
+ "hamming", hammingDistance, true /*pure*/); err != nil {
+ return err
+ }
+ if err := conn.RegisterAggregator(
+ "product", newProductAggregator, true /*pure*/); err != nil {
+ return err
+ }
+ return nil
},
})
}
@@ -956,17 +979,89 @@ func handleAPISimilar(w http.ResponseWriter, r *http.Request) {
}
// - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
+// This is the most miserable part of the whole program.
-// NOTE: AND will mean MULTIPLY(IFNULL(ta.weight, 0)) per SHA1.
-const searchCTE = `WITH
+const searchCTE1 = `WITH
matches(sha1, thumbw, thumbh, score) AS (
SELECT i.sha1, i.thumbw, i.thumbh, ta.weight AS score
FROM tag_assignment AS ta
JOIN image AS i ON i.sha1 = ta.sha1
- WHERE ta.tag = ?
+ WHERE ta.tag = %d
+ )
+`
+
+const searchCTEMulti = `WITH
+ positive(tag) AS (VALUES %s),
+ candidates(sha1) AS (%s),
+ matches(sha1, thumbw, thumbh, score) AS (
+ SELECT i.sha1, i.thumbw, i.thumbh,
+ product(IFNULL(ta.weight, 0)) AS score
+ FROM image AS i, positive AS p
+ JOIN candidates AS c ON i.sha1 = c.sha1
+ LEFT JOIN tag_assignment AS ta ON ta.sha1 = i.sha1 AND ta.tag = p.tag
+ GROUP BY i.sha1
)
`
+func parseQuery(query string) (string, error) {
+ positive, negative := []int64{}, []int64{}
+ for _, word := range strings.Split(query, " ") {
+ if word == "" {
+ continue
+ }
+
+ space, tag, _ := strings.Cut(word, ":")
+
+ negated := false
+ if strings.HasPrefix(space, "-") {
+ space = space[1:]
+ negated = true
+ }
+
+ var tagID int64
+ err := db.QueryRow(`
+ SELECT t.id FROM tag AS t
+ JOIN tag_space AS ts ON t.space = ts.id
+ WHERE ts.name = ? AND t.name = ?`, space, tag).Scan(&tagID)
+ if err != nil {
+ return "", err
+ }
+
+ if negated {
+ negative = append(negative, tagID)
+ } else {
+ positive = append(positive, tagID)
+ }
+ }
+
+ // Don't return most of the database, and simplify the following builder.
+ if len(positive) == 0 {
+ return "", errors.New("search is too wide")
+ }
+
+ // Optimise single tag searches.
+ if len(positive) == 1 && len(negative) == 0 {
+ return fmt.Sprintf(searchCTE1, positive[0]), nil
+ }
+
+ values := fmt.Sprintf(`(%d)`, positive[0])
+ candidates := fmt.Sprintf(
+ `SELECT sha1 FROM tag_assignment WHERE tag = %d`, positive[0])
+ for _, tagID := range positive[1:] {
+ values += fmt.Sprintf(`, (%d)`, tagID)
+ candidates += fmt.Sprintf(` INTERSECT
+ SELECT sha1 FROM tag_assignment WHERE tag = %d`, tagID)
+ }
+ for _, tagID := range negative {
+ candidates += fmt.Sprintf(` EXCEPT
+ SELECT sha1 FROM tag_assignment WHERE tag = %d`, tagID)
+ }
+
+ return fmt.Sprintf(searchCTEMulti, values, candidates), nil
+}
+
+// - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
+
type webTagMatch struct {
SHA1 string `json:"sha1"`
ThumbW int64 `json:"thumbW"`
@@ -974,10 +1069,10 @@ type webTagMatch struct {
Score float32 `json:"score"`
}
-func getTagMatches(tag int64) (matches []webTagMatch, err error) {
- rows, err := db.Query(searchCTE+`
+func getTagMatches(cte string) (matches []webTagMatch, err error) {
+ rows, err := db.Query(cte + `
SELECT sha1, IFNULL(thumbw, 0), IFNULL(thumbh, 0), score
- FROM matches`, tag)
+ FROM matches`)
if err != nil {
return nil, err
}
@@ -1001,13 +1096,13 @@ type webTagSupertag struct {
score float32
}
-func getTagSupertags(tag int64) (result map[int64]*webTagSupertag, err error) {
- rows, err := db.Query(searchCTE+`
+func getTagSupertags(cte string) (result map[int64]*webTagSupertag, err error) {
+ rows, err := db.Query(cte + `
SELECT DISTINCT ta.tag, ts.name, t.name
FROM tag_assignment AS ta
JOIN matches AS m ON m.sha1 = ta.sha1
JOIN tag AS t ON ta.tag = t.id
- JOIN tag_space AS ts ON ts.id = t.space`, tag)
+ JOIN tag_space AS ts ON ts.id = t.space`)
if err != nil {
return nil, err
}
@@ -1032,18 +1127,18 @@ type webTagRelated struct {
Score float32 `json:"score"`
}
-func getTagRelated(tag int64, matches int) (
+func getTagRelated(cte string, matches int) (
result map[string][]webTagRelated, err error) {
// Not sure if this level of efficiency is achievable directly in SQL.
- supertags, err := getTagSupertags(tag)
+ supertags, err := getTagSupertags(cte)
if err != nil {
return nil, err
}
- rows, err := db.Query(searchCTE+`
+ rows, err := db.Query(cte + `
SELECT ta.tag, ta.weight
FROM tag_assignment AS ta
- JOIN matches AS m ON m.sha1 = ta.sha1`, tag)
+ JOIN matches AS m ON m.sha1 = ta.sha1`)
if err != nil {
return nil, err
}
@@ -1084,13 +1179,7 @@ func handleAPISearch(w http.ResponseWriter, r *http.Request) {
Related map[string][]webTagRelated `json:"related"`
}
- space, tag, _ := strings.Cut(params.Query, ":")
-
- var tagID int64
- err := db.QueryRow(`
- SELECT t.id FROM tag AS t
- JOIN tag_space AS ts ON t.space = ts.id
- WHERE ts.name = ? AND t.name = ?`, space, tag).Scan(&tagID)
+ cte, err := parseQuery(params.Query)
if errors.Is(err, sql.ErrNoRows) {
http.Error(w, err.Error(), http.StatusNotFound)
return
@@ -1099,11 +1188,11 @@ func handleAPISearch(w http.ResponseWriter, r *http.Request) {
return
}
- if result.Matches, err = getTagMatches(tagID); err != nil {
+ if result.Matches, err = getTagMatches(cte); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
- if result.Related, err = getTagRelated(tagID,
+ if result.Related, err = getTagRelated(cte,
len(result.Matches)); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return