aboutsummaryrefslogtreecommitdiff
path: root/deeptagger
diff options
context:
space:
mode:
authorPřemysl Eric Janouch <p@janouch.name>2024-01-07 23:26:05 +0100
committerPřemysl Eric Janouch <p@janouch.name>2024-01-18 18:31:09 +0100
commitb4f28814b7f5cf1d2375963db81f554d470aef83 (patch)
tree6598b3076c270ba5f91a20fc2a47d4ed4e0408be /deeptagger
parent054078908a1e4c7429ea0f5a3a0605addfccc46c (diff)
downloadgallery-b4f28814b7f5cf1d2375963db81f554d470aef83.tar.gz
gallery-b4f28814b7f5cf1d2375963db81f554d470aef83.tar.xz
gallery-b4f28814b7f5cf1d2375963db81f554d470aef83.zip
Add a deep tagger in C++
Diffstat (limited to 'deeptagger')
-rw-r--r--deeptagger/CMakeLists.txt20
-rw-r--r--deeptagger/FindONNXRuntime.cmake11
-rw-r--r--deeptagger/README.adoc25
-rwxr-xr-xdeeptagger/bench.sh38
-rw-r--r--deeptagger/deeptagger.cpp671
-rwxr-xr-xdeeptagger/download.sh161
6 files changed, 926 insertions, 0 deletions
diff --git a/deeptagger/CMakeLists.txt b/deeptagger/CMakeLists.txt
new file mode 100644
index 0000000..9c10bef
--- /dev/null
+++ b/deeptagger/CMakeLists.txt
@@ -0,0 +1,20 @@
+# Ubuntu 20.04 LTS
+cmake_minimum_required (VERSION 3.16)
+project (deeptagger VERSION 0.0.1 LANGUAGES CXX)
+
+# Hint: set ONNXRuntime_ROOT to a directory with a pre-built GitHub release.
+# (Useful for development, otherwise you may need to adjust the rpath.)
+set (CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}")
+
+find_package (ONNXRuntime REQUIRED)
+find_package (PkgConfig REQUIRED)
+pkg_check_modules (GM REQUIRED GraphicsMagick++)
+
+add_executable (deeptagger deeptagger.cpp)
+target_compile_features (deeptagger PRIVATE cxx_std_17)
+target_include_directories (deeptagger PRIVATE
+ ${GM_INCLUDE_DIRS} ${ONNXRuntime_INCLUDE_DIRS})
+target_link_directories (deeptagger PRIVATE
+ ${GM_LIBRARY_DIRS})
+target_link_libraries (deeptagger PRIVATE
+ ${GM_LIBRARIES} ${ONNXRuntime_LIBRARIES})
diff --git a/deeptagger/FindONNXRuntime.cmake b/deeptagger/FindONNXRuntime.cmake
new file mode 100644
index 0000000..902c27d
--- /dev/null
+++ b/deeptagger/FindONNXRuntime.cmake
@@ -0,0 +1,11 @@
+# Public Domain
+
+find_path (ONNXRuntime_INCLUDE_DIRS onnxruntime_c_api.h
+ PATH_SUFFIXES onnxruntime)
+find_library (ONNXRuntime_LIBRARIES NAMES onnxruntime)
+
+include (FindPackageHandleStandardArgs)
+FIND_PACKAGE_HANDLE_STANDARD_ARGS (ONNXRuntime DEFAULT_MSG
+ ONNXRuntime_INCLUDE_DIRS ONNXRuntime_LIBRARIES)
+
+mark_as_advanced (ONNXRuntime_LIBRARIES ONNXRuntime_INCLUDE_DIRS)
diff --git a/deeptagger/README.adoc b/deeptagger/README.adoc
new file mode 100644
index 0000000..8ea83cc
--- /dev/null
+++ b/deeptagger/README.adoc
@@ -0,0 +1,25 @@
+deeptagger
+==========
+
+This is an automatic image tagger/classifier written in C++,
+without using any Python, and primarily targets various anime models.
+
+Unfortunately, you will still need Python and some luck to prepare the models,
+achieved by running download.sh. You will need about 20 gigabytes of space.
+
+Very little effort is made to make this work on non-Unix systems.
+
+Getting this to work
+--------------------
+To build the evaluator, install a C++ compiler, CMake, and development packages
+of GraphicsMagick and ONNX Runtime.
+
+Prebuilt ONNX Runtime can be most conveniently downloaded from
+https://github.com/microsoft/onnxruntime/releases[GitHub releases].
+Remember to install CUDA packages, such as _nvidia-cudnn_ on Debian,
+if you plan on using the GPU-enabled options.
+
+ $ cmake -DONNXRuntime_ROOT=/path/to/onnxruntime -B build
+ $ cmake --build build
+ $ ./download.sh
+ $ build/deeptagger models/deepdanbooru-v3-20211112-sgd-e28.model image.jpg
diff --git a/deeptagger/bench.sh b/deeptagger/bench.sh
new file mode 100755
index 0000000..6b62791
--- /dev/null
+++ b/deeptagger/bench.sh
@@ -0,0 +1,38 @@
+#!/bin/sh -e
+if [ $# -lt 2 ] || ! [ -x "$1" ]
+then
+ echo "Usage: $0 DEEPTAGGER FILE..."
+ echo "Run this after using download.sh, from the same directory."
+ exit 1
+fi
+
+runner=$1
+shift
+log=bench.out
+: >$log
+
+run() {
+ opts=$1 batch=$2 model=$3
+ shift 3
+
+ for i in $(seq 1 3)
+ do
+ start=$(date +%s)
+ "$runner" $opts -b "$batch" -t 0.75 "$model" "$@" >/dev/null || :
+ end=$(date +%s)
+ printf '%s\t%s\t%s\t%s\t%s\n' \
+ "$name" "$model" "$opts" "$batch" "$((end - start))" | tee -a $log
+ done
+}
+
+for model in models/*.model
+do
+ name=$(sed -n 's/^name=//p' "$model")
+ run "" 1 "$model" "$@"
+ run "" 4 "$model" "$@"
+ run "" 16 "$model" "$@"
+
+ run --cpu 1 "$model" "$@"
+ run --cpu 4 "$model" "$@"
+ run --cpu 16 "$model" "$@"
+done
diff --git a/deeptagger/deeptagger.cpp b/deeptagger/deeptagger.cpp
new file mode 100644
index 0000000..27be965
--- /dev/null
+++ b/deeptagger/deeptagger.cpp
@@ -0,0 +1,671 @@
+#include <getopt.h>
+#include <Magick++.h>
+#include <onnxruntime_cxx_api.h>
+#ifdef __APPLE__
+#include <coreml_provider_factory.h>
+#endif
+
+#include <algorithm>
+#include <filesystem>
+#include <fstream>
+#include <iostream>
+#include <regex>
+#include <set>
+#include <stdexcept>
+#include <string>
+#include <tuple>
+
+#include <cstdio>
+#include <cstdint>
+#include <climits>
+
+static struct {
+ bool cpu = false;
+ int debug = 0;
+ long batch = 1;
+ float threshold = 0.1;
+
+ // Execution provider name → Key → Value
+ std::map<std::string, std::map<std::string, std::string>> options;
+} g;
+
+// --- Configuration -----------------------------------------------------------
+
+// Arguably, input normalization could be incorporated into models instead.
+struct Config {
+ std::string name;
+ enum class Shape {NHWC, NCHW} shape = Shape::NHWC;
+ enum class Channels {RGB, BGR} channels = Channels::RGB;
+ bool normalize = false;
+ enum class Pad {WHITE, EDGE, STRETCH} pad = Pad::WHITE;
+ int size = -1;
+ bool sigmoid = false;
+
+ std::vector<std::string> tags;
+};
+
+static void
+read_tags(const std::string &path, std::vector<std::string> &tags)
+{
+ std::ifstream f(path);
+ f.exceptions(std::ifstream::badbit);
+ if (!f)
+ throw std::runtime_error("cannot read tags");
+
+ std::string line;
+ while (std::getline(f, line)) {
+ if (!line.empty() && line.back() == '\r')
+ line.erase(line.size() - 1);
+ tags.push_back(line);
+ }
+}
+
+static void
+read_field(Config &config, std::string key, std::string value)
+{
+ if (key == "name") {
+ config.name = value;
+ } else if (key == "shape") {
+ if (value == "nhwc") config.shape = Config::Shape::NHWC;
+ else if (value == "nchw") config.shape = Config::Shape::NCHW;
+ else throw std::invalid_argument("bad value for: " + key);
+ } else if (key == "channels") {
+ if (value == "rgb") config.channels = Config::Channels::RGB;
+ else if (value == "bgr") config.channels = Config::Channels::BGR;
+ else throw std::invalid_argument("bad value for: " + key);
+ } else if (key == "normalize") {
+ if (value == "true") config.normalize = true;
+ else if (value == "false") config.normalize = false;
+ else throw std::invalid_argument("bad value for: " + key);
+ } else if (key == "pad") {
+ if (value == "white") config.pad = Config::Pad::WHITE;
+ else if (value == "edge") config.pad = Config::Pad::EDGE;
+ else if (value == "stretch") config.pad = Config::Pad::STRETCH;
+ else throw std::invalid_argument("bad value for: " + key);
+ } else if (key == "size") {
+ config.size = std::stoi(value);
+ } else if (key == "interpret") {
+ if (value == "false") config.sigmoid = false;
+ else if (value == "sigmoid") config.sigmoid = true;
+ else throw std::invalid_argument("bad value for: " + key);
+ } else {
+ throw std::invalid_argument("unsupported config key: " + key);
+ }
+}
+
+static void
+read_config(Config &config, const char *path)
+{
+ std::ifstream f(path);
+ f.exceptions(std::ifstream::badbit);
+ if (!f)
+ throw std::runtime_error("cannot read configuration");
+
+ std::regex re(R"(^\s*([^#=]+?)\s*=\s*([^#]*?)\s*(?:#|$))",
+ std::regex::optimize);
+ std::smatch m;
+
+ std::string line;
+ while (std::getline(f, line)) {
+ if (std::regex_match(line, m, re))
+ read_field(config, m[1].str(), m[2].str());
+ }
+
+ read_tags(
+ std::filesystem::path(path).replace_extension("tags"), config.tags);
+}
+
+// --- Data preparation --------------------------------------------------------
+
+static float *
+image_to_nhwc(float *data, Magick::Image &image, Config::Channels channels)
+{
+ unsigned int width = image.columns();
+ unsigned int height = image.rows();
+
+ auto pixels = image.getConstPixels(0, 0, width, height);
+ switch (channels) {
+ case Config::Channels::RGB:
+ for (unsigned int y = 0; y < height; y++) {
+ for (unsigned int x = 0; x < width; x++) {
+ auto pixel = *pixels++;
+ *data++ = ScaleQuantumToChar(pixel.red);
+ *data++ = ScaleQuantumToChar(pixel.green);
+ *data++ = ScaleQuantumToChar(pixel.blue);
+ }
+ }
+ break;
+ case Config::Channels::BGR:
+ for (unsigned int y = 0; y < height; y++) {
+ for (unsigned int x = 0; x < width; x++) {
+ auto pixel = *pixels++;
+ *data++ = ScaleQuantumToChar(pixel.blue);
+ *data++ = ScaleQuantumToChar(pixel.green);
+ *data++ = ScaleQuantumToChar(pixel.red);
+ }
+ }
+ }
+ return data;
+}
+
+static float *
+image_to_nchw(float *data, Magick::Image &image, Config::Channels channels)
+{
+ unsigned int width = image.columns();
+ unsigned int height = image.rows();
+
+ auto pixels = image.getConstPixels(0, 0, width, height), pp = pixels;
+ switch (channels) {
+ case Config::Channels::RGB:
+ for (unsigned int y = 0; y < height; y++)
+ for (unsigned int x = 0; x < width; x++)
+ *data++ = ScaleQuantumToChar((*pp++).red);
+ pp = pixels;
+ for (unsigned int y = 0; y < height; y++)
+ for (unsigned int x = 0; x < width; x++)
+ *data++ = ScaleQuantumToChar((*pp++).green);
+ pp = pixels;
+ for (unsigned int y = 0; y < height; y++)
+ for (unsigned int x = 0; x < width; x++)
+ *data++ = ScaleQuantumToChar((*pp++).blue);
+ break;
+ case Config::Channels::BGR:
+ for (unsigned int y = 0; y < height; y++)
+ for (unsigned int x = 0; x < width; x++)
+ *data++ = ScaleQuantumToChar((*pp++).blue);
+ pp = pixels;
+ for (unsigned int y = 0; y < height; y++)
+ for (unsigned int x = 0; x < width; x++)
+ *data++ = ScaleQuantumToChar((*pp++).green);
+ pp = pixels;
+ for (unsigned int y = 0; y < height; y++)
+ for (unsigned int x = 0; x < width; x++)
+ *data++ = ScaleQuantumToChar((*pp++).red);
+ }
+ return data;
+}
+
+static Magick::Image
+load(const std::string filename,
+ const Config &config, int64_t width, int64_t height)
+{
+ Magick::Image image;
+ try {
+ image.read(filename);
+ } catch (const Magick::Warning &warning) {
+ if (g.debug)
+ fprintf(stderr, "%s: %s\n", filename.c_str(), warning.what());
+ }
+
+ image.autoOrient();
+
+ Magick::Geometry adjusted(width, height);
+ switch (config.pad) {
+ case Config::Pad::EDGE:
+ case Config::Pad::WHITE:
+ adjusted.greater(true);
+ break;
+ case Config::Pad::STRETCH:
+ adjusted.aspect(false);
+ }
+
+ image.resize(adjusted, Magick::LanczosFilter);
+
+ // The GraphicsMagick API doesn't offer any good options.
+ if (config.pad == Config::Pad::EDGE) {
+ MagickLib::SetImageVirtualPixelMethod(
+ image.image(), MagickLib::EdgeVirtualPixelMethod);
+
+ auto x = (int64_t(image.columns()) - width) / 2;
+ auto y = (int64_t(image.rows()) - height) / 2;
+ auto source = image.getConstPixels(x, y, width, height);
+ std::vector<MagickLib::PixelPacket>
+ pixels(source, source + width * height);
+
+ Magick::Image edged(Magick::Geometry(width, height), "black");
+ edged.classType(Magick::DirectClass);
+ auto target = edged.setPixels(0, 0, width, height);
+ memcpy(target, pixels.data(), pixels.size() * sizeof pixels[0]);
+ edged.syncPixels();
+
+ image = edged;
+ }
+
+ // Center it in a square patch of white, removing any transparency.
+ // image.extent() could probably be used to do the same thing.
+ Magick::Image white(Magick::Geometry(width, height), "white");
+ auto x = (white.columns() - image.columns()) / 2;
+ auto y = (white.rows() - image.rows()) / 2;
+ white.composite(image, x, y, Magick::OverCompositeOp);
+ white.fileName(filename);
+
+ if (g.debug > 2)
+ white.display();
+
+ return white;
+}
+
+// --- Inference ---------------------------------------------------------------
+
+static void
+run(std::vector<Magick::Image> &images, const Config &config,
+ Ort::Session &session, std::vector<int64_t> shape)
+{
+ auto batch = shape[0] = images.size();
+
+ Ort::AllocatorWithDefaultOptions allocator;
+ auto tensor = Ort::Value::CreateTensor<float>(
+ allocator, shape.data(), shape.size());
+
+ auto input_len = tensor.GetTensorTypeAndShapeInfo().GetElementCount();
+ auto input_data = tensor.GetTensorMutableData<float>(), pi = input_data;
+ for (int64_t i = 0; i < batch; i++) {
+ switch (config.shape) {
+ case Config::Shape::NCHW:
+ pi = image_to_nchw(pi, images.at(i), config.channels);
+ break;
+ case Config::Shape::NHWC:
+ pi = image_to_nhwc(pi, images.at(i), config.channels);
+ }
+ }
+ if (config.normalize) {
+ pi = input_data;
+ for (size_t i = 0; i < input_len; i++)
+ *pi++ /= 255.0;
+ }
+
+ std::string input_name =
+ session.GetInputNameAllocated(0, allocator).get();
+ std::string output_name =
+ session.GetOutputNameAllocated(0, allocator).get();
+
+ std::vector<const char *> input_names = {input_name.c_str()};
+ std::vector<const char *> output_names = {output_name.c_str()};
+
+ auto outputs = session.Run(Ort::RunOptions{},
+ input_names.data(), &tensor, input_names.size(),
+ output_names.data(), output_names.size());
+ if (outputs.size() != 1 || !outputs[0].IsTensor()) {
+ fprintf(stderr, "Wrong output\n");
+ return;
+ }
+
+ auto output_len = outputs[0].GetTensorTypeAndShapeInfo().GetElementCount();
+ auto output_data = outputs.front().GetTensorData<float>(), po = output_data;
+ if (output_len != batch * config.tags.size()) {
+ fprintf(stderr, "Tags don't match the output\n");
+ return;
+ }
+
+ for (size_t i = 0; i < batch; i++) {
+ for (size_t t = 0; t < config.tags.size(); t++) {
+ float value = *po++;
+ if (config.sigmoid)
+ value = 1 / (1 + std::exp(-value));
+ if (value > g.threshold) {
+ printf("%s\t%.2f\t%s\n", images.at(i).fileName().c_str(),
+ value, config.tags.at(t).c_str());
+ }
+ }
+ }
+}
+
+// - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
+
+static void
+parse_options(const std::string &options)
+{
+ auto semicolon = options.find(";");
+ auto name = options.substr(0, semicolon);
+ auto sequence = options.substr(semicolon);
+
+ std::map<std::string, std::string> kv;
+ std::regex re(R"(;*([^;=]+)=([^;=]+))", std::regex::optimize);
+ std::sregex_iterator it(sequence.begin(), sequence.end(), re), end;
+ for (; it != end; ++it)
+ kv[it->str(1)] = it->str(2);
+ g.options.insert_or_assign(name, std::move(kv));
+}
+
+static std::tuple<std::vector<const char *>, std::vector<const char *>>
+unpack_options(const std::string &provider)
+{
+ std::vector<const char *> keys, values;
+ if (g.options.count(provider)) {
+ for (const auto &kv : g.options.at(provider)) {
+ keys.push_back(kv.first.c_str());
+ values.push_back(kv.second.c_str());
+ }
+ }
+ return {keys, values};
+}
+
+static void
+add_providers(Ort::SessionOptions &options)
+{
+ auto api = Ort::GetApi();
+ auto v_providers = Ort::GetAvailableProviders();
+ std::set<std::string> providers(v_providers.begin(), v_providers.end());
+
+ if (g.debug) {
+ printf("Providers:");
+ for (const auto &it : providers)
+ printf(" %s", it.c_str());
+ printf("\n");
+ }
+
+ // There is a string-based AppendExecutionProvider() method,
+ // but it cannot be used with all providers.
+ // TODO: Make it possible to disable providers.
+ // TODO: Providers will deserve some performance tuning.
+
+ if (g.cpu)
+ return;
+
+#ifdef __APPLE__
+ if (providers.count("CoreMLExecutionProvider")) {
+ try {
+ Ort::ThrowOnError(
+ OrtSessionOptionsAppendExecutionProvider_CoreML(options, 0));
+ } catch (const std::exception &e) {
+ fprintf(stderr, "CoreML unavailable: %s\n", e.what());
+ }
+ }
+#endif
+
+#if TENSORRT
+ // TensorRT should be the more performant execution provider, however:
+ // - it is difficult to set up (needs logging in to download),
+ // - with WD v1.4 ONNX models, one gets "Your ONNX model has been generated
+ // with INT64 weights, while TensorRT does not natively support INT64.
+ // Attempting to cast down to INT32." and that's not nice.
+ if (providers.count("TensorrtExecutionProvider")) {
+ OrtTensorRTProviderOptionsV2* tensorrt_options = nullptr;
+ Ort::ThrowOnError(api.CreateTensorRTProviderOptions(&tensorrt_options));
+ auto [keys, values] = unpack_options("TensorrtExecutionProvider");
+ if (!keys.empty()) {
+ Ort::ThrowOnError(api.UpdateTensorRTProviderOptions(
+ tensorrt_options, keys.data(), values.data(), keys.size()));
+ }
+
+ try {
+ options.AppendExecutionProvider_TensorRT_V2(*tensorrt_options);
+ } catch (const std::exception &e) {
+ fprintf(stderr, "TensorRT unavailable: %s\n", e.what());
+ }
+ api.ReleaseTensorRTProviderOptions(tensorrt_options);
+ }
+#endif
+
+ // See CUDA-ExecutionProvider.html for documentation.
+ if (providers.count("CUDAExecutionProvider")) {
+ OrtCUDAProviderOptionsV2* cuda_options = nullptr;
+ Ort::ThrowOnError(api.CreateCUDAProviderOptions(&cuda_options));
+ auto [keys, values] = unpack_options("CUDAExecutionProvider");
+ if (!keys.empty()) {
+ Ort::ThrowOnError(api.UpdateCUDAProviderOptions(
+ cuda_options, keys.data(), values.data(), keys.size()));
+ }
+
+ try {
+ options.AppendExecutionProvider_CUDA_V2(*cuda_options);
+ } catch (const std::exception &e) {
+ fprintf(stderr, "CUDA unavailable: %s\n", e.what());
+ }
+ api.ReleaseCUDAProviderOptions(cuda_options);
+ }
+
+ if (providers.count("ROCMExecutionProvider")) {
+ OrtROCMProviderOptions rocm_options = {};
+ auto [keys, values] = unpack_options("ROCMExecutionProvider");
+ if (!keys.empty()) {
+ Ort::ThrowOnError(api.UpdateROCMProviderOptions(
+ &rocm_options, keys.data(), values.data(), keys.size()));
+ }
+
+ try {
+ options.AppendExecutionProvider_ROCM(rocm_options);
+ } catch (const std::exception &e) {
+ fprintf(stderr, "ROCM unavailable: %s\n", e.what());
+ }
+ }
+
+ // The CPU provider is the default fallback, if everything else fails.
+}
+
+// - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
+
+static std::string
+print_shape(const Ort::ConstTensorTypeAndShapeInfo &info)
+{
+ std::vector<const char *> names(info.GetDimensionsCount());
+ info.GetSymbolicDimensions(names.data(), names.size());
+
+ auto shape = info.GetShape();
+ std::string result;
+ for (size_t i = 0; i < shape.size(); i++) {
+ if (shape[i] < 0)
+ result.append(names.at(i));
+ else
+ result.append(std::to_string(shape[i]));
+ result.append(" x ");
+ }
+ if (!result.empty())
+ result.erase(result.size() - 3);
+ return result;
+}
+
+static void
+print_shapes(const Ort::Session &session)
+{
+ Ort::AllocatorWithDefaultOptions allocator;
+ for (size_t i = 0; i < session.GetInputCount(); i++) {
+ std::string name = session.GetInputNameAllocated(i, allocator).get();
+ auto info = session.GetInputTypeInfo(i);
+ auto shape = print_shape(info.GetTensorTypeAndShapeInfo());
+ printf("Input: %s: %s\n", name.c_str(), shape.c_str());
+ }
+ for (size_t i = 0; i < session.GetOutputCount(); i++) {
+ std::string name = session.GetOutputNameAllocated(i, allocator).get();
+ auto info = session.GetOutputTypeInfo(i);
+ auto shape = print_shape(info.GetTensorTypeAndShapeInfo());
+ printf("Output: %s: %s\n", name.c_str(), shape.c_str());
+ }
+}
+
+static void
+infer(Ort::Env &env, const char *path, const std::vector<std::string> &images)
+{
+ Config config;
+ read_config(config, path);
+
+ Ort::SessionOptions session_options;
+ add_providers(session_options);
+
+ Ort::Session session = Ort::Session(env,
+ std::filesystem::path(path).replace_extension("onnx").c_str(),
+ session_options);
+
+ if (g.debug)
+ print_shapes(session);
+
+ if (session.GetInputCount() != 1 || session.GetOutputCount() != 1) {
+ fprintf(stderr, "Invalid input or output shape\n");
+ exit(EXIT_FAILURE);
+ }
+
+ auto input_info = session.GetInputTypeInfo(0);
+ auto shape = input_info.GetTensorTypeAndShapeInfo().GetShape();
+ if (shape.size() != 4) {
+ fprintf(stderr, "Incompatible input tensor format\n");
+ exit(EXIT_FAILURE);
+ }
+ if (shape.at(0) > 1) {
+ fprintf(stderr, "Fixed batching not supported\n");
+ exit(EXIT_FAILURE);
+ }
+ if (shape.at(0) >= 0 && g.batch > 1) {
+ fprintf(stderr, "Requested batching for a non-batching model\n");
+ exit(EXIT_FAILURE);
+ }
+
+ int64_t *height = {}, *width = {}, *channels = {};
+ switch (config.shape) {
+ case Config::Shape::NCHW:
+ channels = &shape[1];
+ height = &shape[2];
+ width = &shape[3];
+ break;
+ case Config::Shape::NHWC:
+ height = &shape[1];
+ width = &shape[2];
+ channels = &shape[3];
+ break;
+ }
+
+ // Variable dimensions don't combine well with batches.
+ if (*height < 0)
+ *height = config.size;
+ if (*width < 0)
+ *width = config.size;
+ if (*channels != 3 || *height < 1 || *width < 1) {
+ fprintf(stderr, "Incompatible input tensor format\n");
+ return;
+ }
+
+ // TODO: Image loading is heavily parallelizable. In theory.
+ std::vector<Magick::Image> batch;
+ for (const auto &filename : images) {
+ Magick::Image image;
+ try {
+ image = load(filename, config, *width, *height);
+ } catch (const std::exception &e) {
+ fprintf(stderr, "%s: %s\n", filename.c_str(), e.what());
+ continue;
+ }
+
+ if (*height != image.rows() || *width != image.columns()) {
+ fprintf(stderr, "%s: %s\n", filename.c_str(), "tensor mismatch");
+ continue;
+ }
+
+ batch.push_back(image);
+ if (batch.size() == g.batch) {
+ run(batch, config, session, shape);
+ batch.clear();
+ }
+ }
+ if (!batch.empty())
+ run(batch, config, session, shape);
+}
+
+int
+main(int argc, char *argv[])
+{
+ auto invocation_name = argv[0];
+ auto print_usage = [=] {
+ fprintf(stderr,
+ "Usage: %s [-b BATCH] [--cpu] [-d] [-o EP;KEY=VALUE...] "
+ "[-t THRESHOLD] MODEL { --pipe | [IMAGE...] }\n", invocation_name);
+ };
+
+ static option opts[] = {
+ {"batch", required_argument, 0, 'b'},
+ {"cpu", no_argument, 0, 'c'},
+ {"debug", no_argument, 0, 'd'},
+ {"help", no_argument, 0, 'h'},
+ {"options", required_argument, 0, 'o'},
+ {"pipe", no_argument, 0, 'p'},
+ {"threshold", required_argument, 0, 't'},
+ {nullptr, 0, 0, 0},
+ };
+
+ bool pipe = false;
+ while (1) {
+ int option_index = 0;
+ auto c = getopt_long(argc, const_cast<char *const *>(argv),
+ "b:cdho:pt:", opts, &option_index);
+ if (c == -1)
+ break;
+
+ char *end = nullptr;
+ switch (c) {
+ case 'b':
+ errno = 0, g.batch = strtol(optarg, &end, 10);
+ if (errno || *end || g.batch < 1 || g.batch > SHRT_MAX) {
+ fprintf(stderr, "Batch size must be a positive number\n");
+ exit(EXIT_FAILURE);
+ }
+ break;
+ case 'c':
+ g.cpu = true;
+ break;
+ case 'd':
+ g.debug++;
+ break;
+ case 'h':
+ print_usage();
+ return 0;
+ case 'o':
+ parse_options(optarg);
+ break;
+ case 'p':
+ pipe = true;
+ break;
+ case 't':
+ errno = 0, g.threshold = strtod(optarg, &end);
+ if (errno || *end || !std::isfinite(g.threshold) ||
+ g.threshold < 0 || g.threshold > 1) {
+ fprintf(stderr, "Threshold must be a number within 0..1\n");
+ exit(EXIT_FAILURE);
+ }
+ break;
+ default:
+ print_usage();
+ return 1;
+ }
+ }
+
+ argv += optind;
+ argc -= optind;
+
+ // TODO: There's actually no need to slurp all the lines up front.
+ std::vector<std::string> paths;
+ if (pipe) {
+ if (argc != 1) {
+ print_usage();
+ return 1;
+ }
+
+ std::string line;
+ while (std::getline(std::cin, line))
+ paths.push_back(line);
+ } else {
+ if (argc < 1) {
+ print_usage();
+ return 1;
+ }
+
+ paths.assign(argv + 1, argv + argc);
+ }
+
+ // XXX: GraphicsMagick initializes signal handlers here,
+ // one needs to use MagickLib::InitializeMagickEx()
+ // with MAGICK_OPT_NO_SIGNAL_HANDER to prevent that.
+ //
+ // ImageMagick conveniently has the opposite default.
+ //
+ // Once processing images in parallel, consider presetting
+ // OMP_NUM_THREADS=1 (GM) and/or MAGICK_THREAD_LIMIT=1 (IM).
+ Magick::InitializeMagick(nullptr);
+
+ OrtLoggingLevel logging = g.debug > 1
+ ? ORT_LOGGING_LEVEL_VERBOSE
+ : ORT_LOGGING_LEVEL_WARNING;
+
+ // Creating an environment before initializing providers in order to avoid:
+ // "Attempt to use DefaultLogger but none has been registered."
+ Ort::Env env(logging, invocation_name);
+ infer(env, argv[0], paths);
+ return 0;
+}
diff --git a/deeptagger/download.sh b/deeptagger/download.sh
new file mode 100755
index 0000000..29f651e
--- /dev/null
+++ b/deeptagger/download.sh
@@ -0,0 +1,161 @@
+#!/bin/sh -e
+# Requirements: Python ~ 3.11, curl, unzip, git-lfs, awk
+#
+# This script downloads a bunch of models into the models/ directory,
+# after any necessary transformations to run them using the deeptagger binary.
+#
+# Once it succeeds, feel free to remove everything but *.{model,tags,onnx}
+git lfs install
+mkdir -p models
+cd models
+
+# Create a virtual environment for model conversion.
+#
+# If any of the Python stuff fails,
+# retry from within a Conda environment with a different version of Python.
+export VIRTUAL_ENV=$(pwd)/venv
+export TF_ENABLE_ONEDNN_OPTS=0
+if ! [ -f "$VIRTUAL_ENV/ready" ]
+then
+ python3 -m venv "$VIRTUAL_ENV"
+ #"$VIRTUAL_ENV/bin/pip3" install tensorflow[and-cuda]
+ "$VIRTUAL_ENV/bin/pip3" install tf2onnx 'deepdanbooru[tensorflow]'
+ touch "$VIRTUAL_ENV/ready"
+fi
+
+status() {
+ echo "$(tput bold)-- $*$(tput sgr0)"
+}
+
+# Using the deepdanbooru package makes it possible to use other models
+# trained with the project.
+deepdanbooru() {
+ local name=$1 url=$2
+ status "$name"
+
+ local basename=$(basename "$url")
+ if ! [ -e "$basename" ]
+ then curl -LO "$url"
+ fi
+
+ local modelname=${basename%%.*}
+ if ! [ -d "$modelname" ]
+ then unzip -d "$modelname" "$basename"
+ fi
+
+ if ! [ -e "$modelname.tags" ]
+ then ln "$modelname/tags.txt" "$modelname.tags"
+ fi
+
+ if ! [ -d "$modelname.saved" ]
+ then "$VIRTUAL_ENV/bin/python3" - "$modelname" "$modelname.saved" <<-'END'
+ import sys
+ import deepdanbooru.project as ddp
+ model = ddp.load_model_from_project(
+ project_path=sys.argv[1], compile_model=False)
+ model.export(sys.argv[2])
+ END
+ fi
+
+ if ! [ -e "$modelname.onnx" ]
+ then "$VIRTUAL_ENV/bin/python3" -m tf2onnx.convert \
+ --saved-model "$modelname.saved" --output "$modelname.onnx"
+ fi
+
+ cat > "$modelname.model" <<-END
+ name=$name
+ shape=nhwc
+ channels=rgb
+ normalize=true
+ pad=edge
+ END
+}
+
+# ONNX preconversions don't have a symbolic first dimension, thus doing our own.
+wd14() {
+ local name=$1 repository=$2
+ status "$name"
+
+ local modelname=$(basename "$repository")
+ if ! [ -d "$modelname" ]
+ then git clone "https://huggingface.co/$repository"
+ fi
+
+ # Though link the original export as well.
+ if ! [ -e "$modelname.onnx" ]
+ then ln "$modelname/model.onnx" "$modelname.onnx"
+ fi
+
+ if ! [ -e "$modelname.tags" ]
+ then awk -F, 'NR > 1 { print $2 }' "$modelname/selected_tags.csv" \
+ > "$modelname.tags"
+ fi
+
+ cat > "$modelname.model" <<-END
+ name=$name
+ shape=nhwc
+ channels=bgr
+ normalize=false
+ pad=white
+ END
+
+ if ! [ -e "batch-$modelname.onnx" ]
+ then "$VIRTUAL_ENV/bin/python3" -m tf2onnx.convert \
+ --saved-model "$modelname" --output "batch-$modelname.onnx"
+ fi
+
+ if ! [ -e "batch-$modelname.tags" ]
+ then ln "$modelname.tags" "batch-$modelname.tags"
+ fi
+
+ if ! [ -e "batch-$modelname.model" ]
+ then ln "$modelname.model" "batch-$modelname.model"
+ fi
+}
+
+# These models are an undocumented mess, thus using ONNX preconversions.
+mldanbooru() {
+ local name=$1 basename=$2
+ status "$name"
+
+ if ! [ -d ml-danbooru-onnx ]
+ then git clone https://huggingface.co/deepghs/ml-danbooru-onnx
+ fi
+
+ local modelname=${basename%%.*}
+ if ! [ -e "$basename" ]
+ then ln "ml-danbooru-onnx/$basename"
+ fi
+
+ if ! [ -e "$modelname.tags" ]
+ then awk -F, 'NR > 1 { print $1 }' ml-danbooru-onnx/tags.csv \
+ > "$modelname.tags"
+ fi
+
+ cat > "$modelname.model" <<-END
+ name=$name
+ shape=nchw
+ channels=rgb
+ normalize=true
+ pad=stretch
+ size=640
+ interpret=sigmoid
+ END
+}
+
+status "Downloading models, beware that git-lfs doesn't indicate progress"
+
+deepdanbooru DeepDanbooru \
+ 'https://github.com/KichangKim/DeepDanbooru/releases/download/v3-20211112-sgd-e28/deepdanbooru-v3-20211112-sgd-e28.zip'
+
+#wd14 'WD v1.4 ViT v1' 'SmilingWolf/wd-v1-4-vit-tagger'
+wd14 'WD v1.4 ViT v2' 'SmilingWolf/wd-v1-4-vit-tagger-v2'
+#wd14 'WD v1.4 ConvNeXT v1' 'SmilingWolf/wd-v1-4-convnext-tagger'
+wd14 'WD v1.4 ConvNeXT v2' 'SmilingWolf/wd-v1-4-convnext-tagger-v2'
+wd14 'WD v1.4 ConvNeXTV2 v2' 'SmilingWolf/wd-v1-4-convnextv2-tagger-v2'
+wd14 'WD v1.4 SwinV2 v2' 'SmilingWolf/wd-v1-4-swinv2-tagger-v2'
+wd14 'WD v1.4 MOAT v2' 'SmilingWolf/wd-v1-4-moat-tagger-v2'
+
+# As suggested by author https://github.com/IrisRainbowNeko/ML-Danbooru-webui
+mldanbooru 'ML-Danbooru Caformer dec-5-97527' 'ml_caformer_m36_dec-5-97527.onnx'
+mldanbooru 'ML-Danbooru TResNet-D 6-30000' 'TResnet-D-FLq_ema_6-30000.onnx'