From 9d051c5d4fabdf0a8df464df0a385d2393355377 Mon Sep 17 00:00:00 2001 From: Ettore Di Giacinto Date: Tue, 16 May 2023 19:32:53 +0200 Subject: [PATCH] feat: add image generation with ncnn-stablediffusion (#272) --- .dockerignore | 4 +- .gitignore | 3 +- Makefile | 39 +++- README.md | 217 +++++++++++++++++++- api/api.go | 8 +- api/api_test.go | 4 +- api/config.go | 13 +- api/openai.go | 189 ++++++++++++++++- api/prediction.go | 42 +++- examples/chatbot-ui/README.md | 4 +- go.mod | 3 +- go.sum | 4 + main.go | 8 +- pkg/model/initializers.go | 45 ++-- pkg/stablediffusion/generate.go | 23 +++ pkg/stablediffusion/generate_unsupported.go | 10 + pkg/stablediffusion/stablediffusion.go | 20 ++ 17 files changed, 580 insertions(+), 56 deletions(-) create mode 100644 pkg/stablediffusion/generate.go create mode 100644 pkg/stablediffusion/generate_unsupported.go create mode 100644 pkg/stablediffusion/stablediffusion.go diff --git a/.dockerignore b/.dockerignore index dff08bd..4147850 100644 --- a/.dockerignore +++ b/.dockerignore @@ -1,2 +1,4 @@ models -examples/chatbot-ui/models \ No newline at end of file +examples/chatbot-ui/models +examples/rwkv/models +examples/**/models diff --git a/.gitignore b/.gitignore index 98cdf70..12c461c 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,7 @@ # go-llama build artifacts go-llama -go-gpt4all-j +gpt4all +go-stable-diffusion go-gpt2 go-rwkv whisper.cpp diff --git a/Makefile b/Makefile index f296a83..dec3569 100644 --- a/Makefile +++ b/Makefile @@ -4,8 +4,8 @@ GOVET=$(GOCMD) vet BINARY_NAME=local-ai GOLLAMA_VERSION?=7f9ae4246088f0c08ed322acbae21d69cd2eb547 -GPT4ALL_REPO?=https://github.com/go-skynet/gpt4all -GPT4ALL_VERSION?=a330bfe26e9e35ca402e16df18973a3b162fb4db +GPT4ALL_REPO?=https://github.com/nomic-ai/gpt4all +GPT4ALL_VERSION?=a07237e54fcdfdb351913587052ac061a2a7bdff GOGPT2_VERSION?=92421a8cf61ed6e03babd9067af292b094cb1307 RWKV_REPO?=https://github.com/donomii/go-rwkv.cpp RWKV_VERSION?=07166da10cb2a9e8854395a4f210464dcea76e47 @@ -15,6 +15,11 @@ BLOOMZ_VERSION?=e9366e82abdfe70565644fbfae9651976714efd1 BUILD_TYPE?= CGO_LDFLAGS?= CUDA_LIBPATH?=/usr/local/cuda/lib64/ +STABLEDIFFUSION_VERSION?=c0748eca3642d58bcf9521108bcee46959c647dc + +GO_TAGS?= + +OPTIONAL_TARGETS?= GREEN := $(shell tput -Txterm setaf 2) YELLOW := $(shell tput -Txterm setaf 3) @@ -22,8 +27,8 @@ WHITE := $(shell tput -Txterm setaf 7) CYAN := $(shell tput -Txterm setaf 6) RESET := $(shell tput -Txterm sgr0) -C_INCLUDE_PATH=$(shell pwd)/go-llama:$(shell pwd)/gpt4all/gpt4all-bindings/golang/:$(shell pwd)/go-gpt2:$(shell pwd)/go-rwkv:$(shell pwd)/whisper.cpp:$(shell pwd)/go-bert:$(shell pwd)/bloomz -LIBRARY_PATH=$(shell pwd)/go-llama:$(shell pwd)/gpt4all/gpt4all-bindings/golang/:$(shell pwd)/go-gpt2:$(shell pwd)/go-rwkv:$(shell pwd)/whisper.cpp:$(shell pwd)/go-bert:$(shell pwd)/bloomz +C_INCLUDE_PATH=$(shell pwd)/go-llama:$(shell pwd)/go-stable-diffusion/:$(shell pwd)/gpt4all/gpt4all-bindings/golang/:$(shell pwd)/go-gpt2:$(shell pwd)/go-rwkv:$(shell pwd)/whisper.cpp:$(shell pwd)/go-bert:$(shell pwd)/bloomz +LIBRARY_PATH=$(shell pwd)/go-llama:$(shell pwd)/go-stable-diffusion/:$(shell pwd)/gpt4all/gpt4all-bindings/golang/:$(shell pwd)/go-gpt2:$(shell pwd)/go-rwkv:$(shell pwd)/whisper.cpp:$(shell pwd)/go-bert:$(shell pwd)/bloomz ifeq ($(BUILD_TYPE),openblas) CGO_LDFLAGS+=-lopenblas @@ -33,6 +38,11 @@ ifeq ($(BUILD_TYPE),cublas) CGO_LDFLAGS+=-lcublas -lcudart -L$(CUDA_LIBPATH) endif + +ifeq ($(GO_TAGS),stablediffusion) + OPTIONAL_TARGETS+=go-stable-diffusion/libstablediffusion.a +endif + .PHONY: all 
test build vendor all: help @@ -66,6 +76,14 @@ go-bert: @find ./go-bert -type f -name "*.cpp" -exec sed -i'' -e 's/ggml_/ggml_bert_/g' {} + @find ./go-bert -type f -name "*.h" -exec sed -i'' -e 's/ggml_/ggml_bert_/g' {} + +## stable diffusion +go-stable-diffusion: + git clone --recurse-submodules https://github.com/mudler/go-stable-diffusion go-stable-diffusion + cd go-stable-diffusion && git checkout -b build $(STABLEDIFFUSION_VERSION) && git submodule update --init --recursive --depth 1 + +go-stable-diffusion/libstablediffusion.a: + $(MAKE) -C go-stable-diffusion libstablediffusion.a + ## RWKV go-rwkv: git clone --recurse-submodules $(RWKV_REPO) go-rwkv @@ -133,14 +151,15 @@ go-llama/libbinding.a: go-llama replace: $(GOCMD) mod edit -replace github.com/go-skynet/go-llama.cpp=$(shell pwd)/go-llama - $(GOCMD) mod edit -replace github.com/nomic/gpt4all/gpt4all-bindings/golang=$(shell pwd)/gpt4all/gpt4all-bindings/golang + $(GOCMD) mod edit -replace github.com/nomic-ai/gpt4all/gpt4all-bindings/golang=$(shell pwd)/gpt4all/gpt4all-bindings/golang $(GOCMD) mod edit -replace github.com/go-skynet/go-gpt2.cpp=$(shell pwd)/go-gpt2 $(GOCMD) mod edit -replace github.com/donomii/go-rwkv.cpp=$(shell pwd)/go-rwkv $(GOCMD) mod edit -replace github.com/ggerganov/whisper.cpp=$(shell pwd)/whisper.cpp $(GOCMD) mod edit -replace github.com/go-skynet/go-bert.cpp=$(shell pwd)/go-bert $(GOCMD) mod edit -replace github.com/go-skynet/bloomz.cpp=$(shell pwd)/bloomz + $(GOCMD) mod edit -replace github.com/mudler/go-stable-diffusion=$(shell pwd)/go-stable-diffusion -prepare-sources: go-llama go-gpt2 gpt4all go-rwkv whisper.cpp go-bert bloomz replace +prepare-sources: go-llama go-gpt2 gpt4all go-rwkv whisper.cpp go-bert bloomz go-stable-diffusion replace $(GOCMD) mod download ## GENERIC @@ -150,19 +169,22 @@ rebuild: ## Rebuilds the project $(MAKE) -C go-gpt2 clean $(MAKE) -C go-rwkv clean $(MAKE) -C whisper.cpp clean + $(MAKE) -C go-stable-diffusion clean $(MAKE) -C go-bert clean $(MAKE) -C bloomz clean $(MAKE) build -prepare: prepare-sources gpt4all/gpt4all-bindings/golang/libgpt4all.a go-llama/libbinding.a go-bert/libgobert.a go-gpt2/libgpt2.a go-rwkv/librwkv.a whisper.cpp/libwhisper.a bloomz/libbloomz.a ## Prepares for building +prepare: prepare-sources gpt4all/gpt4all-bindings/golang/libgpt4all.a $(OPTIONAL_TARGETS) go-llama/libbinding.a go-bert/libgobert.a go-gpt2/libgpt2.a go-rwkv/librwkv.a whisper.cpp/libwhisper.a bloomz/libbloomz.a ## Prepares for building clean: ## Remove build related file rm -fr ./go-llama rm -rf ./gpt4all + rm -rf ./go-stable-diffusion rm -rf ./go-gpt2 rm -rf ./go-rwkv rm -rf ./go-bert rm -rf ./bloomz + rm -rf ./whisper.cpp rm -rf $(BINARY_NAME) ## Build: @@ -170,7 +192,8 @@ clean: ## Remove build related file build: prepare ## Build the project $(info ${GREEN}I local-ai build info:${RESET}) $(info ${GREEN}I BUILD_TYPE: ${YELLOW}$(BUILD_TYPE)${RESET}) - CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=${C_INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} $(GOCMD) build -x -o $(BINARY_NAME) ./ + $(info ${GREEN}I GO_TAGS: ${YELLOW}$(GO_TAGS)${RESET}) + CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=${C_INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} $(GOCMD) build -tags "$(GO_TAGS)" -x -o $(BINARY_NAME) ./ generic-build: ## Build the project using generic BUILD_TYPE="generic" $(MAKE) build diff --git a/README.md b/README.md index fc64e20..d9986e6 100644 --- a/README.md +++ b/README.md @@ -12,7 +12,7 @@ **LocalAI** is a drop-in replacement REST API compatible with OpenAI API specifications for local 
 inferencing. It allows you to run models locally or on-prem with consumer-grade hardware, supporting multiple model families compatible with the `ggml` format. For a list of the supported model families, see [the model compatibility table below](https://github.com/go-skynet/LocalAI#model-compatibility-table).
 
 - OpenAI drop-in alternative REST API
-- Supports multiple models
+- Supports multiple models and features: audio transcription, text generation with GPTs, and image generation with Stable Diffusion (experimental)
 - Once loaded the first time, it keeps models loaded in memory for faster inference
 - Support for prompt templates
 - Doesn't shell out, but uses C++ bindings for faster inference and better performance.
 
@@ -23,6 +23,7 @@ LocalAI uses C++ bindings for optimizing speed. It is based on [llama.cpp](https
 
 See [examples on how to integrate LocalAI](https://github.com/go-skynet/LocalAI/tree/master/examples/).
 
+
 ### How does it work?
@@ -33,6 +34,14 @@ See [examples on how to integrate LocalAI](https://github.com/go-skynet/LocalAI/
 
 ## News
 
+- 16-05-2023: 🔥🔥🔥 Experimental support for CUDA (https://github.com/go-skynet/LocalAI/pull/258) in the `llama.cpp` backend and Stable Diffusion CPU image generation (https://github.com/go-skynet/LocalAI/pull/272) in `master`.
+
+Now LocalAI can generate images too:
+
+| mode=0                                                                                                   | mode=1 (winograd/sgemm)                                                                                 |
+|----------------------------------------------------------------------------------------------------------|---------------------------------------------------------------------------------------------------------|
+| ![b6441997879](https://github.com/go-skynet/LocalAI/assets/2420543/d50af51c-51b7-4f39-b6c2-bf04c403894c) | ![winograd2](https://github.com/go-skynet/LocalAI/assets/2420543/1935a69a-ecce-4afc-a099-1ac28cb649b3)  |
+
 - 14-05-2023: __v1.11.1__ released! `rwkv` backend patch release
 - 13-05-2023: __v1.11.0__ released! 🔥 Updated `llama.cpp` bindings: This update includes a breaking change in the model files ( https://github.com/ggerganov/llama.cpp/pull/1405 ) - old models should still work with the `gpt4all-llama` backend.
 - 12-05-2023: __v1.10.0__ released! 🔥🔥 Updated `gpt4all` bindings. Added support for GPTNeox (experimental), RedPajama (experimental), Starcoder (experimental), Replit (experimental), MosaicML MPT. Also, the `embeddings` endpoint now supports token arrays. See the [langchain-chroma](https://github.com/go-skynet/LocalAI/tree/master/examples/langchain-chroma) example! Note - this update does NOT include https://github.com/ggerganov/llama.cpp/pull/1405 which makes models incompatible.
@@ -106,7 +115,7 @@ Depending on the model you are attempting to run might need more RAM or CPU reso
-| Backend | Compatible models | Completion/Chat endpoint | Audio transcription | Embeddings support | Token stream support | Github | Bindings | +| Backend | Compatible models | Completion/Chat endpoint | Audio transcription/Image | Embeddings support | Token stream support | Github | Bindings | |-----------------|-----------------------|--------------------------|---------------------|-----------------------------------|----------------------|--------------------------------------------|-------------------------------------------| | llama | Vicuna, Alpaca, LLaMa | yes | no | yes (doesn't seem to be accurate) | yes | https://github.com/ggerganov/llama.cpp | https://github.com/go-skynet/go-llama.cpp | | gpt4all-llama | Vicuna, Alpaca, LLaMa | yes | no | no | yes | https://github.com/nomic-ai/gpt4all | https://github.com/go-skynet/gpt4all | @@ -122,8 +131,8 @@ Depending on the model you are attempting to run might need more RAM or CPU reso | bloomz | Bloom | yes | no | no | no | https://github.com/NouamaneTazi/bloomz.cpp | https://github.com/go-skynet/bloomz.cpp | | rwkv | RWKV | yes | no | no | yes | https://github.com/saharNooby/rwkv.cpp | https://github.com/donomii/go-rwkv.cpp | | bert-embeddings | bert | no | no | yes | no | https://github.com/skeskinen/bert.cpp | https://github.com/go-skynet/go-bert.cpp | -| whisper | whisper | no | yes | no | no | https://github.com/ggerganov/whisper.cpp | https://github.com/ggerganov/whisper.cpp | - +| whisper | whisper | no | Audio | no | no | https://github.com/ggerganov/whisper.cpp | https://github.com/ggerganov/whisper.cpp | +| stablediffusion | stablediffusion | no | Image | no | no | https://github.com/EdVince/Stable-Diffusion-NCNN | https://github.com/mudler/go-stable-diffusion |
 ## Usage
 
@@ -148,7 +157,9 @@ cp your-model.bin models/
 # vim .env
 
 # start with docker-compose
-docker-compose up -d --build
+docker-compose up -d --pull always
+# or you can build the images with:
+# docker-compose up -d --build
 
 # Now API is accessible at localhost:8080
 curl http://localhost:8080/v1/models
@@ -184,8 +195,9 @@ cp -rf prompt-templates/ggml-gpt4all-j.tmpl models/
 # vim .env
 
 # start with docker-compose
-docker-compose up -d --build
-
+docker-compose up -d --pull always
+# or you can build the images with:
+# docker-compose up -d --build
 # Now API is accessible at localhost:8080
 curl http://localhost:8080/v1/models
 # {"object":"list","data":[{"id":"ggml-gpt4all-j","object":"model"}]}
@@ -204,6 +216,8 @@ To build locally, run `make build` (see below).
 
 ### Other examples
 
+![Screenshot from 2023-04-26 23-59-55](https://user-images.githubusercontent.com/2420543/234715439-98d12e03-d3ce-4f94-ab54-2b256808e05e.png)
+
 To see other examples on how to integrate with other projects, for instance for question answering or for using it with chatbot-ui, see: [examples](https://github.com/go-skynet/LocalAI/tree/master/examples/).
 
@@ -294,6 +308,73 @@ Specifying a `config-file` via CLI allows to declare models in a single file as
 
 See also [chatbot-ui](https://github.com/go-skynet/LocalAI/tree/master/examples/chatbot-ui) as an example on how to use config files.
 
+### Full config model file reference
+
+```yaml
+name: gpt-3.5-turbo
+
+# Default model parameters
+parameters:
+  # Relative to the models path
+  model: ggml-gpt4all-j
+  # temperature
+  temperature: 0.3
+  # all of the other OpenAI request options can be set here:
+  top_k:
+  top_p:
+  max_tokens:
+  batch:
+  f16: true
+  ignore_eos: true
+  n_keep: 10
+  seed:
+  mode:
+  step:
+
+# Default context size
+context_size: 512
+# Default number of threads
+threads: 10
+# Define a backend (optional). By default it will try to guess the backend the first time the model is interacted with.
+backend: gptj # available: llama, stablelm, gpt2, gptj, rwkv
+# Stopwords (if supported by the backend)
+stopwords:
+- "HUMAN:"
+- "### Response:"
+# Strings to trim whitespace from
+trimspace:
+- string
+# Strings to cut from the response
+cutstrings:
+- "string"
+# Define chat roles
+roles:
+  user: "HUMAN:"
+  system: "GPT:"
+  assistant: "ASSISTANT:"
+template:
+  # ".tmpl" template files with the prompt templates to use by default on the endpoint call. Note: file names are given without the extension.
+  completion: completion
+  chat: ggml-gpt4all-j
+  edit: edit_template
+
+# Enable F16 if the backend supports it
+f16: true
+# Enable debugging
+debug: true
+# Enable embeddings
+embeddings: true
+# Mirostat configuration (llama.cpp only)
+mirostat_eta: 0.8
+mirostat_tau: 0.9
+mirostat: 1
+
+# GPU layers (only used when built with cublas)
+gpu_layers: 22
+
+# Directory used to store additional assets (used for stablediffusion)
+asset_dir: ""
+```
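
On the Go side, a file like the reference above is unmarshalled into the `Config` struct that this patch extends in `api/config.go` (the new `asset_dir` key maps to the `ImageGenerationAssets` field). As a minimal sketch of how such a file parses — the struct below is an illustrative subset rather than the real `api.Config`, and it assumes the `gopkg.in/yaml.v3` package:

```go
package main

import (
	"fmt"
	"os"

	"gopkg.in/yaml.v3"
)

// modelConfig mirrors a handful of the keys from the reference file above;
// see api/config.go in this patch for the full set of supported fields.
type modelConfig struct {
	Name        string   `yaml:"name"`
	Backend     string   `yaml:"backend"`
	ContextSize int      `yaml:"context_size"`
	Threads     int      `yaml:"threads"`
	F16         bool     `yaml:"f16"`
	Stopwords   []string `yaml:"stopwords"`
	NGPULayers  int      `yaml:"gpu_layers"`
	AssetDir    string   `yaml:"asset_dir"`
}

func main() {
	data, err := os.ReadFile("models/gpt-3.5-turbo.yaml")
	if err != nil {
		panic(err)
	}
	var cfg modelConfig
	if err := yaml.Unmarshal(data, &cfg); err != nil {
		panic(err)
	}
	fmt.Printf("model %q -> backend %q (gpu layers: %d)\n", cfg.Name, cfg.Backend, cfg.NGPULayers)
}
```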
### Prompt templates @@ -351,6 +432,8 @@ local-ai --models-path [--address
] [--threads @@ -443,6 +526,48 @@ curl http://localhost:8080/v1/chat/completions -H "Content-Type: application/jso +### Build with Image generation support + +
+
+**Requirements**: OpenCV, gomp (GNU OpenMP)
+
+Image generation is experimental and requires `GO_TAGS=stablediffusion` to be set during the build:
+
+```
+make GO_TAGS=stablediffusion rebuild
+```
+
+
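
Under the hood, `GO_TAGS` is simply forwarded to `go build -tags` by the Makefile, and the new `pkg/stablediffusion` package (shown at the end of this patch) provides a real implementation and a stub that are selected by Go build constraints. A minimal sketch of that pattern, using a hypothetical `feature` package:

```go
// feature.go
//go:build stablediffusion

package feature

// Enabled reports whether the binary was compiled with -tags stablediffusion.
func Enabled() bool { return true }
```

```go
// feature_unsupported.go
//go:build !stablediffusion

package feature

// Enabled reports whether the binary was compiled with -tags stablediffusion.
func Enabled() bool { return false }
```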
+
+### Acceleration
+
+#### OpenBLAS
+
+
+ +Requirements: OpenBLAS + +``` +make BUILD_TYPE=openblas build +``` + +
+
+#### cuBLAS
+
+
+
+Requirement: Nvidia CUDA toolkit
+
+Note: cuBLAS support is experimental and has not been tested on real hardware. Please report any issues you find!
+
+```
+make BUILD_TYPE=cublas build
+```
+
+
+
 ### Windows compatibility
 
 It should work; however, you need to make sure you give enough resources to the container. See https://github.com/go-skynet/LocalAI/issues/2
@@ -615,6 +740,77 @@ curl http://localhost:8080/v1/audio/transcriptions -H "Content-Type: multipart/f
 
+
+### Image generation
+
+LocalAI supports generating images with Stable Diffusion, running on the CPU.
+
+| mode=0                                                                                                   | mode=1 (winograd/sgemm)                                                                                    |
+|----------------------------------------------------------------------------------------------------------|------------------------------------------------------------------------------------------------------------|
+| ![test](https://github.com/go-skynet/LocalAI/assets/2420543/7145bdee-4134-45bb-84d4-f11cb08a5638)        | ![b643343452981](https://github.com/go-skynet/LocalAI/assets/2420543/abf14de1-4f50-4715-aaa4-411d703a942a) |
+| ![b6441997879](https://github.com/go-skynet/LocalAI/assets/2420543/d50af51c-51b7-4f39-b6c2-bf04c403894c) | ![winograd2](https://github.com/go-skynet/LocalAI/assets/2420543/1935a69a-ecce-4afc-a099-1ac28cb649b3)     |
+| ![winograd](https://github.com/go-skynet/LocalAI/assets/2420543/1979a8c4-a70d-4602-95ed-642f382f6c6a)    | ![winograd3](https://github.com/go-skynet/LocalAI/assets/2420543/e6d184d4-5002-408f-b564-163986e1bdfb)     |
+
+
+To generate an image, send a POST request to the `/v1/images/generations` endpoint with the instruction as the request body:
+
+```bash
+# 512x512 is supported too
+curl http://localhost:8080/v1/images/generations -H "Content-Type: application/json" -d '{
+  "prompt": "A cute baby sea otter",
+  "size": "256x256"
+}'
+```
+
+Available additional parameters: `mode`, `step`.
+
+Note: To set a negative prompt, split the prompt with `|`, for instance: `a cute baby sea otter|malformed`.
+
+```bash
+curl http://localhost:8080/v1/images/generations -H "Content-Type: application/json" -d '{
+  "prompt": "floating hair, portrait, one person, cute face, hidden hands, asymmetrical bangs, beautiful detailed eyes, eye shadow, hair ornament, ribbons, bowties, buttons, pleated skirt, (((masterpiece))), ((best quality)), colorful|((part of the head)), ((((mutated hands and fingers)))), deformed, blurry, bad anatomy, disfigured, poorly drawn face, mutation, mutated, extra limb, ugly, poorly drawn hands, missing limb, floating limbs, disconnected limbs, malformed hands, out of focus, long neck, long body, lowres, bad hands, text",
+  "size": "256x256"
+}'
+```
+
+#### Setup
+
+Note: In order to use the `images/generations` endpoint, you need to build LocalAI with `GO_TAGS=stablediffusion`.
+
+1. Create a model file `stablediffusion.yaml` in the models folder:
+
+```yaml
+name: stablediffusion
+backend: stablediffusion
+asset_dir: stablediffusion_assets
+```
+
+2. Create a `stablediffusion_assets` directory inside your `models` directory.
+3. Download the ncnn assets from https://github.com/EdVince/Stable-Diffusion-NCNN#out-of-box and place them in `stablediffusion_assets`.
+
+The models directory should look like the following:
+
+```
+models
+├── stablediffusion_assets
+│   ├── AutoencoderKL-256-256-fp16-opt.param
+│   ├── AutoencoderKL-512-512-fp16-opt.param
+│   ├── AutoencoderKL-base-fp16.param
+│   ├── AutoencoderKL-encoder-512-512-fp16.bin
+│   ├── AutoencoderKL-fp16.bin
+│   ├── FrozenCLIPEmbedder-fp16.bin
+│   ├── FrozenCLIPEmbedder-fp16.param
+│   ├── log_sigmas.bin
+│   ├── tmp-AutoencoderKL-encoder-256-256-fp16.param
+│   ├── UNetModel-256-256-MHA-fp16-opt.param
+│   ├── UNetModel-512-512-MHA-fp16-opt.param
+│   ├── UNetModel-base-MHA-fp16.param
+│   ├── UNetModel-MHA-fp16.bin
+│   └── vocab.txt
+└── stablediffusion.yaml
+```
+
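
If you prefer calling the endpoint from Go rather than curl, here is a minimal client sketch. The struct names are hypothetical; the JSON field names mirror the `OpenAIRequest` and `Item` structs added in `api/openai.go` further down in this patch:

```go
package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"net/http"
)

// imageRequest carries the fields the new /v1/images/generations
// endpoint reads: prompt, size, and the non-OpenAI mode/step extras.
type imageRequest struct {
	Prompt         string `json:"prompt"`
	Size           string `json:"size"`
	Mode           int    `json:"mode,omitempty"`
	Step           int    `json:"step,omitempty"`
	ResponseFormat string `json:"response_format,omitempty"`
}

// imageResponse matches the OpenAIResponse shape: each item carries
// either a URL or an inline base64-encoded image.
type imageResponse struct {
	Data []struct {
		URL     string `json:"url"`
		B64JSON string `json:"b64_json"`
	} `json:"data"`
}

func main() {
	body, _ := json.Marshal(imageRequest{
		Prompt: "A cute baby sea otter|malformed", // positive|negative
		Size:   "256x256",
		Step:   15,
	})
	resp, err := http.Post("http://localhost:8080/v1/images/generations",
		"application/json", bytes.NewReader(body))
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	var out imageResponse
	if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
		panic(err)
	}
	for _, item := range out.Data {
		fmt.Println("generated image at:", item.URL)
	}
}
```

Note that URL-style responses point at the new `/generated-images` static route, which is only mounted when LocalAI is started with `--image-dir` (see `main.go` and `api/api.go` below); set `"response_format": "b64_json"` to receive the image inline instead.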
+ ## Frequently asked questions Here are answers to some of the most common questions. @@ -716,10 +912,15 @@ MIT ## Acknowledgements +LocalAI couldn't have been built without the help of great software already available from the community. Thank you! + - [llama.cpp](https://github.com/ggerganov/llama.cpp) - https://github.com/tatsu-lab/stanford_alpaca - https://github.com/cornelk/llama-go for the initial ideas -- https://github.com/antimatter15/alpaca.cpp for the light model version (this is compatible and tested only with that checkpoint model!) +- https://github.com/antimatter15/alpaca.cpp +- https://github.com/EdVince/Stable-Diffusion-NCNN +- https://github.com/ggerganov/whisper.cpp +- https://github.com/saharNooby/rwkv.cpp ## Contributors diff --git a/api/api.go b/api/api.go index 59489f7..ecf56b0 100644 --- a/api/api.go +++ b/api/api.go @@ -12,7 +12,7 @@ import ( "github.com/rs/zerolog/log" ) -func App(configFile string, loader *model.ModelLoader, uploadLimitMB, threads, ctxSize int, f16 bool, debug, disableMessage bool) *fiber.App { +func App(configFile string, loader *model.ModelLoader, uploadLimitMB, threads, ctxSize int, f16 bool, debug, disableMessage bool, imageDir string) *fiber.App { zerolog.SetGlobalLevel(zerolog.InfoLevel) if debug { zerolog.SetGlobalLevel(zerolog.DebugLevel) @@ -87,6 +87,12 @@ func App(configFile string, loader *model.ModelLoader, uploadLimitMB, threads, c app.Post("/v1/audio/transcriptions", transcriptEndpoint(cm, debug, loader, threads, ctxSize, f16)) + app.Post("/v1/images/generations", imageEndpoint(cm, debug, loader, imageDir)) + + if imageDir != "" { + app.Static("/generated-images", imageDir) + } + app.Get("/v1/models", listModels(loader, cm)) app.Get("/models", listModels(loader, cm)) diff --git a/api/api_test.go b/api/api_test.go index a05096d..f2af038 100644 --- a/api/api_test.go +++ b/api/api_test.go @@ -25,7 +25,7 @@ var _ = Describe("API test", func() { Context("API query", func() { BeforeEach(func() { modelLoader = model.NewModelLoader(os.Getenv("MODELS_PATH")) - app = App("", modelLoader, 15, 1, 512, false, true, true) + app = App("", modelLoader, 15, 1, 512, false, true, true, "") go app.Listen("127.0.0.1:9090") defaultConfig := openai.DefaultConfig("") @@ -140,7 +140,7 @@ var _ = Describe("API test", func() { Context("Config file", func() { BeforeEach(func() { modelLoader = model.NewModelLoader(os.Getenv("MODELS_PATH")) - app = App(os.Getenv("CONFIG_FILE"), modelLoader, 5, 1, 512, false, true, true) + app = App(os.Getenv("CONFIG_FILE"), modelLoader, 5, 1, 512, false, true, true, "") go app.Listen("127.0.0.1:9090") defaultConfig := openai.DefaultConfig("") diff --git a/api/config.go b/api/config.go index 3791d49..7379978 100644 --- a/api/config.go +++ b/api/config.go @@ -32,6 +32,7 @@ type Config struct { MirostatTAU float64 `yaml:"mirostat_tau"` Mirostat int `yaml:"mirostat"` NGPULayers int `yaml:"gpu_layers"` + ImageGenerationAssets string `yaml:"asset_dir"` PromptStrings, InputStrings []string InputToken [][]int } @@ -211,12 +212,11 @@ func updateConfig(config *Config, input *OpenAIRequest) { } } } - -func readConfig(cm ConfigMerger, c *fiber.Ctx, loader *model.ModelLoader, debug bool, threads, ctx int, f16 bool) (*Config, *OpenAIRequest, error) { +func readInput(c *fiber.Ctx, loader *model.ModelLoader, randomModel bool) (string, *OpenAIRequest, error) { input := new(OpenAIRequest) // Get input data from the request body if err := c.BodyParser(input); err != nil { - return nil, nil, err + return "", nil, err } modelFile := input.Model @@ 
-234,14 +234,14 @@ func readConfig(cm ConfigMerger, c *fiber.Ctx, loader *model.ModelLoader, debug bearerExists := bearer != "" && loader.ExistsInModelPath(bearer) // If no model was specified, take the first available - if modelFile == "" && !bearerExists { + if modelFile == "" && !bearerExists && randomModel { models, _ := loader.ListModels() if len(models) > 0 { modelFile = models[0] log.Debug().Msgf("No model specified, using: %s", modelFile) } else { log.Debug().Msgf("No model specified, returning error") - return nil, nil, fmt.Errorf("no model specified") + return "", nil, fmt.Errorf("no model specified") } } @@ -250,7 +250,10 @@ func readConfig(cm ConfigMerger, c *fiber.Ctx, loader *model.ModelLoader, debug log.Debug().Msgf("Using model from bearer token: %s", bearer) modelFile = bearer } + return modelFile, input, nil +} +func readConfig(modelFile string, input *OpenAIRequest, cm ConfigMerger, loader *model.ModelLoader, debug bool, threads, ctx int, f16 bool) (*Config, *OpenAIRequest, error) { // Load a config file if present after the model name modelConfig := filepath.Join(loader.ModelPath, modelFile+".yaml") if _, err := os.Stat(modelConfig); err == nil { diff --git a/api/openai.go b/api/openai.go index 7b65135..19284f2 100644 --- a/api/openai.go +++ b/api/openai.go @@ -3,13 +3,16 @@ package api import ( "bufio" "bytes" + "encoding/base64" "encoding/json" "fmt" "io" + "io/ioutil" "net/http" "os" "path" "path/filepath" + "strconv" "strings" "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper" @@ -43,6 +46,10 @@ type Item struct { Embedding []float32 `json:"embedding"` Index int `json:"index"` Object string `json:"object,omitempty"` + + // Images + URL string `json:"url,omitempty"` + B64JSON string `json:"b64_json,omitempty"` } type OpenAIResponse struct { @@ -78,11 +85,13 @@ type OpenAIRequest struct { Model string `json:"model" yaml:"model"` // whisper - File string `json:"file" validate:"required"` + File string `json:"file" validate:"required"` + Language string `json:"language"` + //whisper/image ResponseFormat string `json:"response_format"` - Language string `json:"language"` - - // Prompt is read only by completion API calls + // image + Size string `json:"size"` + // Prompt is read only by completion/image API calls Prompt interface{} `json:"prompt" yaml:"prompt"` // Edit endpoint @@ -116,6 +125,10 @@ type OpenAIRequest struct { Mirostat int `json:"mirostat" yaml:"mirostat"` Seed int `json:"seed" yaml:"seed"` + + // Image (not supported by OpenAI) + Mode int `json:"mode"` + Step int `json:"step"` } func defaultRequest(modelFile string) OpenAIRequest { @@ -131,7 +144,13 @@ func defaultRequest(modelFile string) OpenAIRequest { // https://platform.openai.com/docs/api-reference/completions func completionEndpoint(cm ConfigMerger, debug bool, loader *model.ModelLoader, threads, ctx int, f16 bool) func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error { - config, input, err := readConfig(cm, c, loader, debug, threads, ctx, f16) + + model, input, err := readInput(c, loader, true) + if err != nil { + return fmt.Errorf("failed reading parameters from request:%w", err) + } + + config, input, err := readConfig(model, input, cm, loader, debug, threads, ctx, f16) if err != nil { return fmt.Errorf("failed reading parameters from request:%w", err) } @@ -182,7 +201,12 @@ func completionEndpoint(cm ConfigMerger, debug bool, loader *model.ModelLoader, // https://platform.openai.com/docs/api-reference/embeddings func embeddingsEndpoint(cm ConfigMerger, debug bool, loader 
*model.ModelLoader, threads, ctx int, f16 bool) func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error { - config, input, err := readConfig(cm, c, loader, debug, threads, ctx, f16) + model, input, err := readInput(c, loader, true) + if err != nil { + return fmt.Errorf("failed reading parameters from request:%w", err) + } + + config, input, err := readConfig(model, input, cm, loader, debug, threads, ctx, f16) if err != nil { return fmt.Errorf("failed reading parameters from request:%w", err) } @@ -249,7 +273,12 @@ func chatEndpoint(cm ConfigMerger, debug bool, loader *model.ModelLoader, thread close(responses) } return func(c *fiber.Ctx) error { - config, input, err := readConfig(cm, c, loader, debug, threads, ctx, f16) + model, input, err := readInput(c, loader, true) + if err != nil { + return fmt.Errorf("failed reading parameters from request:%w", err) + } + + config, input, err := readConfig(model, input, cm, loader, debug, threads, ctx, f16) if err != nil { return fmt.Errorf("failed reading parameters from request:%w", err) } @@ -349,7 +378,12 @@ func chatEndpoint(cm ConfigMerger, debug bool, loader *model.ModelLoader, thread func editEndpoint(cm ConfigMerger, debug bool, loader *model.ModelLoader, threads, ctx int, f16 bool) func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error { - config, input, err := readConfig(cm, c, loader, debug, threads, ctx, f16) + model, input, err := readInput(c, loader, true) + if err != nil { + return fmt.Errorf("failed reading parameters from request:%w", err) + } + + config, input, err := readConfig(model, input, cm, loader, debug, threads, ctx, f16) if err != nil { return fmt.Errorf("failed reading parameters from request:%w", err) } @@ -398,14 +432,151 @@ func editEndpoint(cm ConfigMerger, debug bool, loader *model.ModelLoader, thread } } +// https://platform.openai.com/docs/api-reference/images/create + +/* +* + + curl http://localhost:8080/v1/images/generations \ + -H "Content-Type: application/json" \ + -d '{ + "prompt": "A cute baby sea otter", + "n": 1, + "size": "512x512" + }' + +* +*/ +func imageEndpoint(cm ConfigMerger, debug bool, loader *model.ModelLoader, imageDir string) func(c *fiber.Ctx) error { + return func(c *fiber.Ctx) error { + m, input, err := readInput(c, loader, false) + if err != nil { + return fmt.Errorf("failed reading parameters from request:%w", err) + } + + if m == "" { + m = model.StableDiffusionBackend + } + log.Debug().Msgf("Loading model: %+v", m) + + config, input, err := readConfig(m, input, cm, loader, debug, 0, 0, false) + if err != nil { + return fmt.Errorf("failed reading parameters from request:%w", err) + } + + log.Debug().Msgf("Parameter Config: %+v", config) + + // XXX: Only stablediffusion is supported for now + if config.Backend == "" { + config.Backend = model.StableDiffusionBackend + } + + sizeParts := strings.Split(input.Size, "x") + if len(sizeParts) != 2 { + return fmt.Errorf("Invalid value for 'size'") + } + width, err := strconv.Atoi(sizeParts[0]) + if err != nil { + return fmt.Errorf("Invalid value for 'size'") + } + height, err := strconv.Atoi(sizeParts[1]) + if err != nil { + return fmt.Errorf("Invalid value for 'size'") + } + + b64JSON := false + if input.ResponseFormat == "b64_json" { + b64JSON = true + } + + var result []Item + for _, i := range config.PromptStrings { + prompts := strings.Split(i, "|") + positive_prompt := prompts[0] + negative_prompt := "" + if len(prompts) > 1 { + negative_prompt = prompts[1] + } + + mode := 0 + step := 15 + + if input.Mode != 0 { + mode = input.Mode + 
} + + if input.Step != 0 { + step = input.Step + } + + tempDir := "" + if !b64JSON { + tempDir = imageDir + } + // Create a temporary file + outputFile, err := ioutil.TempFile(tempDir, "b64") + if err != nil { + return err + } + outputFile.Close() + output := outputFile.Name() + ".png" + // Rename the temporary file + err = os.Rename(outputFile.Name(), output) + if err != nil { + return err + } + + baseURL := c.BaseURL() + + fn, err := ImageGeneration(height, width, mode, step, input.Seed, positive_prompt, negative_prompt, output, loader, *config) + if err != nil { + return err + } + if err := fn(); err != nil { + return err + } + + item := &Item{} + + if b64JSON { + defer os.RemoveAll(output) + data, err := os.ReadFile(output) + if err != nil { + return err + } + item.B64JSON = base64.StdEncoding.EncodeToString(data) + } else { + base := filepath.Base(output) + item.URL = baseURL + "/generated-images/" + base + } + + result = append(result, *item) + } + + resp := &OpenAIResponse{ + Data: result, + } + + jsonResult, _ := json.Marshal(resp) + log.Debug().Msgf("Response: %s", jsonResult) + + // Return the prediction in the response body + return c.JSON(resp) + } +} + // https://platform.openai.com/docs/api-reference/audio/create func transcriptEndpoint(cm ConfigMerger, debug bool, loader *model.ModelLoader, threads, ctx int, f16 bool) func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error { - config, input, err := readConfig(cm, c, loader, debug, threads, ctx, f16) + m, input, err := readInput(c, loader, false) if err != nil { return fmt.Errorf("failed reading parameters from request:%w", err) } + config, input, err := readConfig(m, input, cm, loader, debug, threads, ctx, f16) + if err != nil { + return fmt.Errorf("failed reading parameters from request:%w", err) + } // retrieve the file data from the request file, err := c.FormFile("file") if err != nil { diff --git a/api/prediction.go b/api/prediction.go index 7aa839b..c279e08 100644 --- a/api/prediction.go +++ b/api/prediction.go @@ -8,11 +8,12 @@ import ( "github.com/donomii/go-rwkv.cpp" model "github.com/go-skynet/LocalAI/pkg/model" + "github.com/go-skynet/LocalAI/pkg/stablediffusion" "github.com/go-skynet/bloomz.cpp" bert "github.com/go-skynet/go-bert.cpp" gpt2 "github.com/go-skynet/go-gpt2.cpp" llama "github.com/go-skynet/go-llama.cpp" - gpt4all "github.com/nomic/gpt4all/gpt4all-bindings/golang" + gpt4all "github.com/nomic-ai/gpt4all/gpt4all-bindings/golang" ) // mutex still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784 @@ -38,6 +39,45 @@ func defaultLLamaOpts(c Config) []llama.ModelOption { return llamaOpts } +func ImageGeneration(height, width, mode, step, seed int, positive_prompt, negative_prompt, dst string, loader *model.ModelLoader, c Config) (func() error, error) { + if c.Backend != model.StableDiffusionBackend { + return nil, fmt.Errorf("endpoint only working with stablediffusion models") + } + inferenceModel, err := loader.BackendLoader(c.Backend, c.ImageGenerationAssets, []llama.ModelOption{}, uint32(c.Threads)) + if err != nil { + return nil, err + } + + var fn func() error + switch model := inferenceModel.(type) { + case *stablediffusion.StableDiffusion: + fn = func() error { + return model.GenerateImage(height, width, mode, step, seed, positive_prompt, negative_prompt, dst) + } + + default: + fn = func() error { + return fmt.Errorf("creation of images not supported by the backend") + } + } + + return func() error { + // This is still needed, see: 
https://github.com/ggerganov/llama.cpp/discussions/784 + mutexMap.Lock() + l, ok := mutexes[c.Backend] + if !ok { + m := &sync.Mutex{} + mutexes[c.Backend] = m + l = m + } + mutexMap.Unlock() + l.Lock() + defer l.Unlock() + + return fn() + }, nil +} + func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, c Config) (func() ([]float32, error), error) { if !c.Embeddings { return nil, fmt.Errorf("endpoint disabled for this model by API configuration") diff --git a/examples/chatbot-ui/README.md b/examples/chatbot-ui/README.md index 93459bc..7cf4bbb 100644 --- a/examples/chatbot-ui/README.md +++ b/examples/chatbot-ui/README.md @@ -19,7 +19,9 @@ cd LocalAI/examples/chatbot-ui wget https://gpt4all.io/models/ggml-gpt4all-j.bin -O models/ggml-gpt4all-j # start with docker-compose -docker-compose up -d --build +docker-compose up -d --pull always +# or you can build the images with: +# docker-compose up -d --build ``` ## Pointing chatbot-ui to a separately managed LocalAI service diff --git a/go.mod b/go.mod index 64b2a9d..8ff8084 100644 --- a/go.mod +++ b/go.mod @@ -49,7 +49,8 @@ require ( github.com/mattn/go-colorable v0.1.13 // indirect github.com/mattn/go-isatty v0.0.18 // indirect github.com/mattn/go-runewidth v0.0.14 // indirect - github.com/nomic/gpt4all/gpt4all-bindings/golang v0.0.0-00010101000000-000000000000 // indirect + github.com/mudler/go-stable-diffusion v0.0.0-20230516104333-2f32a16b5b24 // indirect + github.com/nomic-ai/gpt4all/gpt4all-bindings/golang v0.0.0-20230516143155-79d6243fe1bc // indirect github.com/philhofer/fwd v1.1.2 // indirect github.com/rivo/uniseg v0.2.0 // indirect github.com/russross/blackfriday/v2 v2.1.0 // indirect diff --git a/go.sum b/go.sum index 0fdfd9b..b1a1a93 100644 --- a/go.sum +++ b/go.sum @@ -93,8 +93,12 @@ github.com/mattn/go-isatty v0.0.18 h1:DOKFKCQ7FNG2L1rbrmstDN4QVRdS89Nkh85u68Uwp9 github.com/mattn/go-isatty v0.0.18/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-runewidth v0.0.14 h1:+xnbZSEeDbOIg5/mE6JF0w6n9duR1l3/WmbinWVwUuU= github.com/mattn/go-runewidth v0.0.14/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= +github.com/mudler/go-stable-diffusion v0.0.0-20230516104333-2f32a16b5b24 h1:XfRD/bZom6u4zji7aB0urIVOsPe43KlkzSRrVhlzaOM= +github.com/mudler/go-stable-diffusion v0.0.0-20230516104333-2f32a16b5b24/go.mod h1:8ufRkpz/S/9ahkaxzZ5i4WMgO9w4InEhuRoT7vK5Rnw= github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e h1:fD57ERR4JtEqsWbfPhv4DMiApHyliiK5xCTNVSPiaAs= github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno= +github.com/nomic-ai/gpt4all/gpt4all-bindings/golang v0.0.0-20230516143155-79d6243fe1bc h1:OPavP/SUsVWVYPhSUZKZeX8yDSQzf4G+BmUmwzrLTyI= +github.com/nomic-ai/gpt4all/gpt4all-bindings/golang v0.0.0-20230516143155-79d6243fe1bc/go.mod h1:4T3CHXyrt+7FQHXaxULZfPjHbD8/99WuDDJa0YVZARI= github.com/onsi/ginkgo/v2 v2.9.4 h1:xR7vG4IXt5RWx6FfIjyAtsoMAtnc3C/rFXBBd2AjZwE= github.com/onsi/ginkgo/v2 v2.9.4/go.mod h1:gCQYp2Q+kSoIj7ykSVb9nskRSsR6PUj4AiLywzIhbKM= github.com/onsi/gomega v1.27.6 h1:ENqfyGeS5AX/rlXDd/ETokDz93u0YufY1Pgxuy/PvWE= diff --git a/main.go b/main.go index 275fb31..2490e19 100644 --- a/main.go +++ b/main.go @@ -56,6 +56,12 @@ func main() { EnvVars: []string{"ADDRESS"}, Value: ":8080", }, + &cli.StringFlag{ + Name: "image-dir", + DefaultText: "Image directory", + EnvVars: []string{"IMAGE_DIR"}, + Value: "", + }, &cli.IntFlag{ Name: "context-size", DefaultText: "Default context size of the model", @@ -87,7 +93,7 @@ It uses 
llama.cpp, ggml and gpt4all as backend with golang c bindings. Copyright: "go-skynet authors", Action: func(ctx *cli.Context) error { fmt.Printf("Starting LocalAI using %d threads, with models path: %s\n", ctx.Int("threads"), ctx.String("models-path")) - return api.App(ctx.String("config-file"), model.NewModelLoader(ctx.String("models-path")), ctx.Int("upload-limit"), ctx.Int("threads"), ctx.Int("context-size"), ctx.Bool("f16"), ctx.Bool("debug"), false).Listen(ctx.String("address")) + return api.App(ctx.String("config-file"), model.NewModelLoader(ctx.String("models-path")), ctx.Int("upload-limit"), ctx.Int("threads"), ctx.Int("context-size"), ctx.Bool("f16"), ctx.Bool("debug"), false, ctx.String("image-dir")).Listen(ctx.String("address")) }, } diff --git a/pkg/model/initializers.go b/pkg/model/initializers.go index ab5a9af..74c05f2 100644 --- a/pkg/model/initializers.go +++ b/pkg/model/initializers.go @@ -7,33 +7,35 @@ import ( rwkv "github.com/donomii/go-rwkv.cpp" whisper "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper" + "github.com/go-skynet/LocalAI/pkg/stablediffusion" bloomz "github.com/go-skynet/bloomz.cpp" bert "github.com/go-skynet/go-bert.cpp" gpt2 "github.com/go-skynet/go-gpt2.cpp" llama "github.com/go-skynet/go-llama.cpp" "github.com/hashicorp/go-multierror" - gpt4all "github.com/nomic/gpt4all/gpt4all-bindings/golang" + gpt4all "github.com/nomic-ai/gpt4all/gpt4all-bindings/golang" "github.com/rs/zerolog/log" ) const tokenizerSuffix = ".tokenizer.json" const ( - LlamaBackend = "llama" - BloomzBackend = "bloomz" - StarcoderBackend = "starcoder" - StableLMBackend = "stablelm" - DollyBackend = "dolly" - RedPajamaBackend = "redpajama" - GPTNeoXBackend = "gptneox" - ReplitBackend = "replit" - Gpt2Backend = "gpt2" - Gpt4AllLlamaBackend = "gpt4all-llama" - Gpt4AllMptBackend = "gpt4all-mpt" - Gpt4AllJBackend = "gpt4all-j" - BertEmbeddingsBackend = "bert-embeddings" - RwkvBackend = "rwkv" - WhisperBackend = "whisper" + LlamaBackend = "llama" + BloomzBackend = "bloomz" + StarcoderBackend = "starcoder" + StableLMBackend = "stablelm" + DollyBackend = "dolly" + RedPajamaBackend = "redpajama" + GPTNeoXBackend = "gptneox" + ReplitBackend = "replit" + Gpt2Backend = "gpt2" + Gpt4AllLlamaBackend = "gpt4all-llama" + Gpt4AllMptBackend = "gpt4all-mpt" + Gpt4AllJBackend = "gpt4all-j" + BertEmbeddingsBackend = "bert-embeddings" + RwkvBackend = "rwkv" + WhisperBackend = "whisper" + StableDiffusionBackend = "stablediffusion" ) var backends []string = []string{ @@ -48,8 +50,8 @@ var backends []string = []string{ StableLMBackend, DollyBackend, RedPajamaBackend, - GPTNeoXBackend, ReplitBackend, + GPTNeoXBackend, BertEmbeddingsBackend, StarcoderBackend, } @@ -89,6 +91,10 @@ var gpt2LM = func(modelFile string) (interface{}, error) { return gpt2.New(modelFile) } +var stableDiffusion = func(assetDir string) (interface{}, error) { + return stablediffusion.New(assetDir) +} + var whisperModel = func(modelFile string) (interface{}, error) { return whisper.New(modelFile) } @@ -107,6 +113,8 @@ func gpt4allLM(opts ...gpt4all.ModelOption) func(string) (interface{}, error) { func rwkvLM(tokenFile string, threads uint32) func(string) (interface{}, error) { return func(s string) (interface{}, error) { + log.Debug().Msgf("Loading RWKV", s, tokenFile) + model := rwkv.LoadFiles(s, tokenFile, threads) if model == nil { return nil, fmt.Errorf("could not load model") @@ -116,6 +124,7 @@ func rwkvLM(tokenFile string, threads uint32) func(string) (interface{}, error) } func (ml *ModelLoader) BackendLoader(backendString 
string, modelFile string, llamaOpts []llama.ModelOption, threads uint32) (model interface{}, err error) { + log.Debug().Msgf("Loading model %s from %s", backendString, modelFile) switch strings.ToLower(backendString) { case LlamaBackend: return ml.LoadModel(modelFile, llamaLM(llamaOpts...)) @@ -133,6 +142,8 @@ func (ml *ModelLoader) BackendLoader(backendString string, modelFile string, lla return ml.LoadModel(modelFile, gptNeoX) case ReplitBackend: return ml.LoadModel(modelFile, replit) + case StableDiffusionBackend: + return ml.LoadModel(modelFile, stableDiffusion) case StarcoderBackend: return ml.LoadModel(modelFile, starCoder) case Gpt4AllLlamaBackend: diff --git a/pkg/stablediffusion/generate.go b/pkg/stablediffusion/generate.go new file mode 100644 index 0000000..97989e9 --- /dev/null +++ b/pkg/stablediffusion/generate.go @@ -0,0 +1,23 @@ +//go:build stablediffusion +// +build stablediffusion + +package stablediffusion + +import ( + stableDiffusion "github.com/mudler/go-stable-diffusion" +) + +func GenerateImage(height, width, mode, step, seed int, positive_prompt, negative_prompt, dst, asset_dir string) error { + return stableDiffusion.GenerateImage( + height, + width, + mode, + step, + seed, + positive_prompt, + negative_prompt, + dst, + "", + asset_dir, + ) +} diff --git a/pkg/stablediffusion/generate_unsupported.go b/pkg/stablediffusion/generate_unsupported.go new file mode 100644 index 0000000..9563bae --- /dev/null +++ b/pkg/stablediffusion/generate_unsupported.go @@ -0,0 +1,10 @@ +//go:build !stablediffusion +// +build !stablediffusion + +package stablediffusion + +import "fmt" + +func GenerateImage(height, width, mode, step, seed int, positive_prompt, negative_prompt, dst, asset_dir string) error { + return fmt.Errorf("This version of LocalAI was built without the stablediffusion tag") +} diff --git a/pkg/stablediffusion/stablediffusion.go b/pkg/stablediffusion/stablediffusion.go new file mode 100644 index 0000000..e38db17 --- /dev/null +++ b/pkg/stablediffusion/stablediffusion.go @@ -0,0 +1,20 @@ +package stablediffusion + +import "os" + +type StableDiffusion struct { + assetDir string +} + +func New(assetDir string) (*StableDiffusion, error) { + if _, err := os.Stat(assetDir); err != nil { + return nil, err + } + return &StableDiffusion{ + assetDir: assetDir, + }, nil +} + +func (s *StableDiffusion) GenerateImage(height, width, mode, step, seed int, positive_prompt, negative_prompt, dst string) error { + return GenerateImage(height, width, mode, step, seed, positive_prompt, negative_prompt, dst, s.assetDir) +}
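
For completeness, a sketch of how the new wrapper could be driven directly, outside of the API server (assuming a binary built with `GO_TAGS=stablediffusion` and the ncnn assets from the README setup step above):

```go
package main

import (
	"github.com/go-skynet/LocalAI/pkg/stablediffusion"
)

func main() {
	// New fails fast if the asset directory does not exist.
	sd, err := stablediffusion.New("models/stablediffusion_assets")
	if err != nil {
		panic(err)
	}
	// Arguments: height, width, mode, step, seed, positive prompt,
	// negative prompt, destination file. Without the stablediffusion
	// build tag this returns the "built without" error from
	// generate_unsupported.go instead of producing an image.
	if err := sd.GenerateImage(256, 256, 0, 15, 42, "A cute baby sea otter", "malformed", "otter.png"); err != nil {
		panic(err)
	}
}
```

In the server itself this path is reached through `ImageGeneration` in `api/prediction.go`, which serializes calls with the shared mutex map before invoking `GenerateImage`.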