about summary refs log tree commit diff
diff options
context:
space:
mode:
author Přemysl Eric Janouch <p@janouch.name> 2024-01-22 19:29:51 +0100
committer Přemysl Eric Janouch <p@janouch.name> 2024-01-22 19:52:35 +0100
commit 083739fd4e227b8323a50d29055327ce9c5bfa2d (patch)
tree 86e7fb22530e4dac108625650b8c4eba072c4b08
parent 4f174972e3f3040f78f87f2b96a0e7bfc94fed6e (diff)
download gallery-083739fd4e227b8323a50d29055327ce9c5bfa2d.tar.gz
gallery-083739fd4e227b8323a50d29055327ce9c5bfa2d.tar.xz
gallery-083739fd4e227b8323a50d29055327ce9c5bfa2d.zip
gallery: implement AND/NOT for tag search
-rw-r--r-- initialize.sql 2
-rw-r--r-- main.go 135
-rw-r--r-- public/gallery.js 6
-rw-r--r-- public/style.css 2
4 files changed, 120 insertions, 25 deletions
diff --git a/initialize.sql b/initialize.sql
index 5a54a7f..292c50b 100644
--- a/initialize.sql
+++ b/initialize.sql
@@ -76,7 +76,7 @@ CREATE TABLE IF NOT EXISTS tag_space(
id INTEGER NOT NULL,
name TEXT NOT NULL,
description TEXT,
- CHECK (name NOT LIKE '%:%'),
+ CHECK (name NOT LIKE '%:%' AND name NOT LIKE '-%'),
PRIMARY KEY (id)
) STRICT;
diff --git a/main.go b/main.go
index 7c9aceb..644d2d4 100644
--- a/main.go
+++ b/main.go
@@ -62,10 +62,33 @@ func hammingDistance(a, b int64) int {
return bits.OnesCount64(uint64(a) ^ uint64(b))
}
+type productAggregator float64
+
+func (pa *productAggregator) Step(v float64) {
+ *pa = productAggregator(float64(*pa) * v)
+}
+
+func (pa *productAggregator) Done() float64 {
+ return float64(*pa)
+}
+
+func newProductAggregator() *productAggregator {
+ pa := productAggregator(1)
+ return &pa
+}
+
func init() {
sql.Register("sqlite3_custom", &sqlite3.SQLiteDriver{
ConnectHook: func(conn *sqlite3.SQLiteConn) error {
- return conn.RegisterFunc("hamming", hammingDistance, true /*pure*/)
+ if err := conn.RegisterFunc(
+ "hamming", hammingDistance, true /*pure*/); err != nil {
+ return err
+ }
+ if err := conn.RegisterAggregator(
+ "product", newProductAggregator, true /*pure*/); err != nil {
+ return err
+ }
+ return nil
},
})
}
@@ -956,17 +979,89 @@ func handleAPISimilar(w http.ResponseWriter, r *http.Request) {
}
// - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
+// This is the most miserable part of the whole program.
-// NOTE: AND will mean MULTIPLY(IFNULL(ta.weight, 0)) per SHA1.
-const searchCTE = `WITH
+const searchCTE1 = `WITH
matches(sha1, thumbw, thumbh, score) AS (
SELECT i.sha1, i.thumbw, i.thumbh, ta.weight AS score
FROM tag_assignment AS ta
JOIN image AS i ON i.sha1 = ta.sha1
- WHERE ta.tag = ?
+ WHERE ta.tag = %d
+ )
+`
+
+const searchCTEMulti = `WITH
+ positive(tag) AS (VALUES %s),
+ candidates(sha1) AS (%s),
+ matches(sha1, thumbw, thumbh, score) AS (
+ SELECT i.sha1, i.thumbw, i.thumbh,
+ product(IFNULL(ta.weight, 0)) AS score
+ FROM image AS i, positive AS p
+ JOIN candidates AS c ON i.sha1 = c.sha1
+ LEFT JOIN tag_assignment AS ta ON ta.sha1 = i.sha1 AND ta.tag = p.tag
+ GROUP BY i.sha1
)
`
+func parseQuery(query string) (string, error) {
+ positive, negative := []int64{}, []int64{}
+ for _, word := range strings.Split(query, " ") {
+ if word == "" {
+ continue
+ }
+
+ space, tag, _ := strings.Cut(word, ":")
+
+ negated := false
+ if strings.HasPrefix(space, "-") {
+ space = space[1:]
+ negated = true
+ }
+
+ var tagID int64
+ err := db.QueryRow(`
+ SELECT t.id FROM tag AS t
+ JOIN tag_space AS ts ON t.space = ts.id
+ WHERE ts.name = ? AND t.name = ?`, space, tag).Scan(&tagID)
+ if err != nil {
+ return "", err
+ }
+
+ if negated {
+ negative = append(negative, tagID)
+ } else {
+ positive = append(positive, tagID)
+ }
+ }
+
+ // Don't return most of the database, and simplify the following builder.
+ if len(positive) == 0 {
+ return "", errors.New("search is too wide")
+ }
+
+ // Optimise single tag searches.
+ if len(positive) == 1 && len(negative) == 0 {
+ return fmt.Sprintf(searchCTE1, positive[0]), nil
+ }
+
+ values := fmt.Sprintf(`(%d)`, positive[0])
+ candidates := fmt.Sprintf(
+ `SELECT sha1 FROM tag_assignment WHERE tag = %d`, positive[0])
+ for _, tagID := range positive[1:] {
+ values += fmt.Sprintf(`, (%d)`, tagID)
+ candidates += fmt.Sprintf(` INTERSECT
+ SELECT sha1 FROM tag_assignment WHERE tag = %d`, tagID)
+ }
+ for _, tagID := range negative {
+ candidates += fmt.Sprintf(` EXCEPT
+ SELECT sha1 FROM tag_assignment WHERE tag = %d`, tagID)
+ }
+
+ return fmt.Sprintf(searchCTEMulti, values, candidates), nil
+}
+
+// - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
+
type webTagMatch struct {
SHA1 string `json:"sha1"`
ThumbW int64 `json:"thumbW"`
@@ -974,10 +1069,10 @@ type webTagMatch struct {
Score float32 `json:"score"`
}
-func getTagMatches(tag int64) (matches []webTagMatch, err error) {
- rows, err := db.Query(searchCTE+`
+func getTagMatches(cte string) (matches []webTagMatch, err error) {
+ rows, err := db.Query(cte + `
SELECT sha1, IFNULL(thumbw, 0), IFNULL(thumbh, 0), score
- FROM matches`, tag)
+ FROM matches`)
if err != nil {
return nil, err
}
@@ -1001,13 +1096,13 @@ type webTagSupertag struct {
score float32
}
-func getTagSupertags(tag int64) (result map[int64]*webTagSupertag, err error) {
- rows, err := db.Query(searchCTE+`
+func getTagSupertags(cte string) (result map[int64]*webTagSupertag, err error) {
+ rows, err := db.Query(cte + `
SELECT DISTINCT ta.tag, ts.name, t.name
FROM tag_assignment AS ta
JOIN matches AS m ON m.sha1 = ta.sha1
JOIN tag AS t ON ta.tag = t.id
- JOIN tag_space AS ts ON ts.id = t.space`, tag)
+ JOIN tag_space AS ts ON ts.id = t.space`)
if err != nil {
return nil, err
}
@@ -1032,18 +1127,18 @@ type webTagRelated struct {
Score float32 `json:"score"`
}
-func getTagRelated(tag int64, matches int) (
+func getTagRelated(cte string, matches int) (
result map[string][]webTagRelated, err error) {
// Not sure if this level of efficiency is achievable directly in SQL.
- supertags, err := getTagSupertags(tag)
+ supertags, err := getTagSupertags(cte)
if err != nil {
return nil, err
}
- rows, err := db.Query(searchCTE+`
+ rows, err := db.Query(cte + `
SELECT ta.tag, ta.weight
FROM tag_assignment AS ta
- JOIN matches AS m ON m.sha1 = ta.sha1`, tag)
+ JOIN matches AS m ON m.sha1 = ta.sha1`)
if err != nil {
return nil, err
}
@@ -1084,13 +1179,7 @@ func handleAPISearch(w http.ResponseWriter, r *http.Request) {
Related map[string][]webTagRelated `json:"related"`
}
- space, tag, _ := strings.Cut(params.Query, ":")
-
- var tagID int64
- err := db.QueryRow(`
- SELECT t.id FROM tag AS t
- JOIN tag_space AS ts ON t.space = ts.id
- WHERE ts.name = ? AND t.name = ?`, space, tag).Scan(&tagID)
+ cte, err := parseQuery(params.Query)
if errors.Is(err, sql.ErrNoRows) {
http.Error(w, err.Error(), http.StatusNotFound)
return
@@ -1099,11 +1188,11 @@ func handleAPISearch(w http.ResponseWriter, r *http.Request) {
return
}
- if result.Matches, err = getTagMatches(tagID); err != nil {
+ if result.Matches, err = getTagMatches(cte); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
- if result.Related, err = getTagRelated(tagID,
+ if result.Related, err = getTagRelated(cte,
len(result.Matches)); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
diff --git a/public/gallery.js b/public/gallery.js
index 9d3b067..01439f7 100644
--- a/public/gallery.js
+++ b/public/gallery.js
@@ -646,7 +646,11 @@ let Search = {
m(Header),
m('.body', {}, [
m('.sidebar', [
- m('p', SearchModel.query),
+ m('input', {
+ value: SearchModel.query,
+ onchange: event => m.route.set(
+ `/search/:key`, {key: event.target.value}),
+ }),
m(SearchRelated),
]),
m(SearchView),
diff --git a/public/style.css b/public/style.css
index 1bdeb3f..7fd0079 100644
--- a/public/style.css
+++ b/public/style.css
@@ -27,6 +27,8 @@ a { color: inherit; }
.sidebar { padding: .25rem .5rem; background: var(--shade-color);
border-right: 1px solid #ccc; overflow: auto;
min-width: 10rem; max-width: 20rem; flex-shrink: 0; }
+.sidebar input { width: 100%; box-sizing: border-box; margin: .5rem 0;
+ font-size: inherit; }
.sidebar h2 { margin: 0.5em 0 0.25em 0; padding: 0; font-size: 1.2rem; }
.sidebar ul { margin: .5rem 0; padding: 0; }