Compare commits
13 Commits
renovate/g
...
master
Author | SHA1 | Date |
---|---|---|
gregandev | 77800c1636 | 1 year ago |
ci-robbot [bot] | 5ee186b8e5 | 1 year ago |
Ettore Di Giacinto | 94817b557c | 1 year ago |
Ettore Di Giacinto | 26e1496075 | 1 year ago |
Ettore Di Giacinto | 92fca8ae74 | 1 year ago |
Stepan | 7fa5b8401d | 1 year ago |
Ettore Di Giacinto | 0eac0402e1 | 1 year ago |
Ettore Di Giacinto | c71c729bc2 | 1 year ago |
Ettore Di Giacinto | e459f114cd | 1 year ago |
Ettore Di Giacinto | 982a7e86a8 | 1 year ago |
Ettore Di Giacinto | 94916749c5 | 1 year ago |
Ettore Di Giacinto | 1d2ae46ddc | 1 year ago |
Ettore Di Giacinto | 47cc95fc9f | 1 year ago |
@ -0,0 +1,42 @@ |
||||
package backend |
||||
|
||||
import ( |
||||
"context" |
||||
"fmt" |
||||
|
||||
config "github.com/go-skynet/LocalAI/api/config" |
||||
|
||||
"github.com/go-skynet/LocalAI/api/options" |
||||
"github.com/go-skynet/LocalAI/pkg/grpc/proto" |
||||
"github.com/go-skynet/LocalAI/pkg/grpc/whisper/api" |
||||
model "github.com/go-skynet/LocalAI/pkg/model" |
||||
) |
||||
|
||||
func ModelTranscription(audio, language string, loader *model.ModelLoader, c config.Config, o *options.Option) (*api.Result, error) { |
||||
opts := []model.Option{ |
||||
model.WithBackendString(model.WhisperBackend), |
||||
model.WithModelFile(c.Model), |
||||
model.WithContext(o.Context), |
||||
model.WithThreads(uint32(c.Threads)), |
||||
model.WithAssetDir(o.AssetsDestination), |
||||
} |
||||
|
||||
for k, v := range o.ExternalGRPCBackends { |
||||
opts = append(opts, model.WithExternalBackend(k, v)) |
||||
} |
||||
|
||||
whisperModel, err := o.Loader.BackendLoader(opts...) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
if whisperModel == nil { |
||||
return nil, fmt.Errorf("could not load whisper model") |
||||
} |
||||
|
||||
return whisperModel.AudioTranscription(context.Background(), &proto.TranscriptRequest{ |
||||
Dst: audio, |
||||
Language: language, |
||||
Threads: uint32(c.Threads), |
||||
}) |
||||
} |
@ -0,0 +1,72 @@ |
||||
package backend |
||||
|
||||
import ( |
||||
"context" |
||||
"fmt" |
||||
"os" |
||||
"path/filepath" |
||||
|
||||
"github.com/go-skynet/LocalAI/api/options" |
||||
"github.com/go-skynet/LocalAI/pkg/grpc/proto" |
||||
model "github.com/go-skynet/LocalAI/pkg/model" |
||||
"github.com/go-skynet/LocalAI/pkg/utils" |
||||
) |
||||
|
||||
func generateUniqueFileName(dir, baseName, ext string) string { |
||||
counter := 1 |
||||
fileName := baseName + ext |
||||
|
||||
for { |
||||
filePath := filepath.Join(dir, fileName) |
||||
_, err := os.Stat(filePath) |
||||
if os.IsNotExist(err) { |
||||
return fileName |
||||
} |
||||
|
||||
counter++ |
||||
fileName = fmt.Sprintf("%s_%d%s", baseName, counter, ext) |
||||
} |
||||
} |
||||
|
||||
func ModelTTS(text, modelFile string, loader *model.ModelLoader, o *options.Option) (string, *proto.Result, error) { |
||||
opts := []model.Option{ |
||||
model.WithBackendString(model.PiperBackend), |
||||
model.WithModelFile(modelFile), |
||||
model.WithContext(o.Context), |
||||
model.WithAssetDir(o.AssetsDestination), |
||||
} |
||||
|
||||
for k, v := range o.ExternalGRPCBackends { |
||||
opts = append(opts, model.WithExternalBackend(k, v)) |
||||
} |
||||
|
||||
piperModel, err := o.Loader.BackendLoader(opts...) |
||||
if err != nil { |
||||
return "", nil, err |
||||
} |
||||
|
||||
if piperModel == nil { |
||||
return "", nil, fmt.Errorf("could not load piper model") |
||||
} |
||||
|
||||
if err := os.MkdirAll(o.AudioDir, 0755); err != nil { |
||||
return "", nil, fmt.Errorf("failed creating audio directory: %s", err) |
||||
} |
||||
|
||||
fileName := generateUniqueFileName(o.AudioDir, "piper", ".wav") |
||||
filePath := filepath.Join(o.AudioDir, fileName) |
||||
|
||||
modelPath := filepath.Join(o.Loader.ModelPath, modelFile) |
||||
|
||||
if err := utils.VerifyPath(modelPath, o.Loader.ModelPath); err != nil { |
||||
return "", nil, err |
||||
} |
||||
|
||||
res, err := piperModel.TTS(context.Background(), &proto.TTSRequest{ |
||||
Text: text, |
||||
Model: modelPath, |
||||
Dst: filePath, |
||||
}) |
||||
|
||||
return filePath, res, err |
||||
} |
@ -0,0 +1,49 @@ |
||||
# -*- coding: utf-8 -*- |
||||
# Generated by the protocol buffer compiler. DO NOT EDIT! |
||||
# source: backend.proto |
||||
"""Generated protocol buffer code.""" |
||||
from google.protobuf import descriptor as _descriptor |
||||
from google.protobuf import descriptor_pool as _descriptor_pool |
||||
from google.protobuf import symbol_database as _symbol_database |
||||
from google.protobuf.internal import builder as _builder |
||||
# @@protoc_insertion_point(imports) |
||||
|
||||
_sym_db = _symbol_database.Default() |
||||
|
||||
|
||||
|
||||
|
||||
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\rbackend.proto\x12\x07\x62\x61\x63kend\"\x0f\n\rHealthMessage\"\xa4\x05\n\x0ePredictOptions\x12\x0e\n\x06Prompt\x18\x01 \x01(\t\x12\x0c\n\x04Seed\x18\x02 \x01(\x05\x12\x0f\n\x07Threads\x18\x03 \x01(\x05\x12\x0e\n\x06Tokens\x18\x04 \x01(\x05\x12\x0c\n\x04TopK\x18\x05 \x01(\x05\x12\x0e\n\x06Repeat\x18\x06 \x01(\x05\x12\r\n\x05\x42\x61tch\x18\x07 \x01(\x05\x12\r\n\x05NKeep\x18\x08 \x01(\x05\x12\x13\n\x0bTemperature\x18\t \x01(\x02\x12\x0f\n\x07Penalty\x18\n \x01(\x02\x12\r\n\x05\x46\x31\x36KV\x18\x0b \x01(\x08\x12\x11\n\tDebugMode\x18\x0c \x01(\x08\x12\x13\n\x0bStopPrompts\x18\r \x03(\t\x12\x11\n\tIgnoreEOS\x18\x0e \x01(\x08\x12\x19\n\x11TailFreeSamplingZ\x18\x0f \x01(\x02\x12\x10\n\x08TypicalP\x18\x10 \x01(\x02\x12\x18\n\x10\x46requencyPenalty\x18\x11 \x01(\x02\x12\x17\n\x0fPresencePenalty\x18\x12 \x01(\x02\x12\x10\n\x08Mirostat\x18\x13 \x01(\x05\x12\x13\n\x0bMirostatETA\x18\x14 \x01(\x02\x12\x13\n\x0bMirostatTAU\x18\x15 \x01(\x02\x12\x12\n\nPenalizeNL\x18\x16 \x01(\x08\x12\x11\n\tLogitBias\x18\x17 \x01(\t\x12\r\n\x05MLock\x18\x19 \x01(\x08\x12\x0c\n\x04MMap\x18\x1a \x01(\x08\x12\x16\n\x0ePromptCacheAll\x18\x1b \x01(\x08\x12\x15\n\rPromptCacheRO\x18\x1c \x01(\x08\x12\x0f\n\x07Grammar\x18\x1d \x01(\t\x12\x0f\n\x07MainGPU\x18\x1e \x01(\t\x12\x13\n\x0bTensorSplit\x18\x1f \x01(\t\x12\x0c\n\x04TopP\x18 \x01(\x02\x12\x17\n\x0fPromptCachePath\x18! \x01(\t\x12\r\n\x05\x44\x65\x62ug\x18\" \x01(\x08\x12\x17\n\x0f\x45mbeddingTokens\x18# \x03(\x05\x12\x12\n\nEmbeddings\x18$ \x01(\t\"\x18\n\x05Reply\x12\x0f\n\x07message\x18\x01 \x01(\t\"\xac\x02\n\x0cModelOptions\x12\r\n\x05Model\x18\x01 \x01(\t\x12\x13\n\x0b\x43ontextSize\x18\x02 \x01(\x05\x12\x0c\n\x04Seed\x18\x03 \x01(\x05\x12\x0e\n\x06NBatch\x18\x04 \x01(\x05\x12\x11\n\tF16Memory\x18\x05 \x01(\x08\x12\r\n\x05MLock\x18\x06 \x01(\x08\x12\x0c\n\x04MMap\x18\x07 \x01(\x08\x12\x11\n\tVocabOnly\x18\x08 \x01(\x08\x12\x0f\n\x07LowVRAM\x18\t \x01(\x08\x12\x12\n\nEmbeddings\x18\n \x01(\x08\x12\x0c\n\x04NUMA\x18\x0b \x01(\x08\x12\x12\n\nNGPULayers\x18\x0c \x01(\x05\x12\x0f\n\x07MainGPU\x18\r \x01(\t\x12\x13\n\x0bTensorSplit\x18\x0e \x01(\t\x12\x0f\n\x07Threads\x18\x0f \x01(\x05\x12\x19\n\x11LibrarySearchPath\x18\x10 \x01(\t\"*\n\x06Result\x12\x0f\n\x07message\x18\x01 \x01(\t\x12\x0f\n\x07success\x18\x02 \x01(\x08\"%\n\x0f\x45mbeddingResult\x12\x12\n\nembeddings\x18\x01 \x03(\x02\"C\n\x11TranscriptRequest\x12\x0b\n\x03\x64st\x18\x02 \x01(\t\x12\x10\n\x08language\x18\x03 \x01(\t\x12\x0f\n\x07threads\x18\x04 \x01(\r\"N\n\x10TranscriptResult\x12,\n\x08segments\x18\x01 \x03(\x0b\x32\x1a.backend.TranscriptSegment\x12\x0c\n\x04text\x18\x02 \x01(\t\"Y\n\x11TranscriptSegment\x12\n\n\x02id\x18\x01 \x01(\x05\x12\r\n\x05start\x18\x02 \x01(\x03\x12\x0b\n\x03\x65nd\x18\x03 \x01(\x03\x12\x0c\n\x04text\x18\x04 \x01(\t\x12\x0e\n\x06tokens\x18\x05 \x03(\x05\"\x9e\x01\n\x14GenerateImageRequest\x12\x0e\n\x06height\x18\x01 \x01(\x05\x12\r\n\x05width\x18\x02 \x01(\x05\x12\x0c\n\x04mode\x18\x03 \x01(\x05\x12\x0c\n\x04step\x18\x04 \x01(\x05\x12\x0c\n\x04seed\x18\x05 \x01(\x05\x12\x17\n\x0fpositive_prompt\x18\x06 \x01(\t\x12\x17\n\x0fnegative_prompt\x18\x07 \x01(\t\x12\x0b\n\x03\x64st\x18\x08 \x01(\t\"6\n\nTTSRequest\x12\x0c\n\x04text\x18\x01 \x01(\t\x12\r\n\x05model\x18\x02 \x01(\t\x12\x0b\n\x03\x64st\x18\x03 \x01(\t2\xeb\x03\n\x07\x42\x61\x63kend\x12\x32\n\x06Health\x12\x16.backend.HealthMessage\x1a\x0e.backend.Reply\"\x00\x12\x34\n\x07Predict\x12\x17.backend.PredictOptions\x1a\x0e.backend.Reply\"\x00\x12\x35\n\tLoadModel\x12\x15.backend.ModelOptions\x1a\x0f.backend.Result\"\x00\x12<\n\rPredictStream\x12\x17.backend.PredictOptions\x1a\x0e.backend.Reply\"\x00\x30\x01\x12@\n\tEmbedding\x12\x17.backend.PredictOptions\x1a\x18.backend.EmbeddingResult\"\x00\x12\x41\n\rGenerateImage\x12\x1d.backend.GenerateImageRequest\x1a\x0f.backend.Result\"\x00\x12M\n\x12\x41udioTranscription\x12\x1a.backend.TranscriptRequest\x1a\x19.backend.TranscriptResult\"\x00\x12-\n\x03TTS\x12\x13.backend.TTSRequest\x1a\x0f.backend.Result\"\x00\x42Z\n\x19io.skynet.localai.backendB\x0eLocalAIBackendP\x01Z+github.com/go-skynet/LocalAI/pkg/grpc/protob\x06proto3') |
||||
|
||||
_globals = globals() |
||||
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) |
||||
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'backend_pb2', _globals) |
||||
if _descriptor._USE_C_DESCRIPTORS == False: |
||||
|
||||
DESCRIPTOR._options = None |
||||
DESCRIPTOR._serialized_options = b'\n\031io.skynet.localai.backendB\016LocalAIBackendP\001Z+github.com/go-skynet/LocalAI/pkg/grpc/proto' |
||||
_globals['_HEALTHMESSAGE']._serialized_start=26 |
||||
_globals['_HEALTHMESSAGE']._serialized_end=41 |
||||
_globals['_PREDICTOPTIONS']._serialized_start=44 |
||||
_globals['_PREDICTOPTIONS']._serialized_end=720 |
||||
_globals['_REPLY']._serialized_start=722 |
||||
_globals['_REPLY']._serialized_end=746 |
||||
_globals['_MODELOPTIONS']._serialized_start=749 |
||||
_globals['_MODELOPTIONS']._serialized_end=1049 |
||||
_globals['_RESULT']._serialized_start=1051 |
||||
_globals['_RESULT']._serialized_end=1093 |
||||
_globals['_EMBEDDINGRESULT']._serialized_start=1095 |
||||
_globals['_EMBEDDINGRESULT']._serialized_end=1132 |
||||
_globals['_TRANSCRIPTREQUEST']._serialized_start=1134 |
||||
_globals['_TRANSCRIPTREQUEST']._serialized_end=1201 |
||||
_globals['_TRANSCRIPTRESULT']._serialized_start=1203 |
||||
_globals['_TRANSCRIPTRESULT']._serialized_end=1281 |
||||
_globals['_TRANSCRIPTSEGMENT']._serialized_start=1283 |
||||
_globals['_TRANSCRIPTSEGMENT']._serialized_end=1372 |
||||
_globals['_GENERATEIMAGEREQUEST']._serialized_start=1375 |
||||
_globals['_GENERATEIMAGEREQUEST']._serialized_end=1533 |
||||
_globals['_TTSREQUEST']._serialized_start=1535 |
||||
_globals['_TTSREQUEST']._serialized_end=1589 |
||||
_globals['_BACKEND']._serialized_start=1592 |
||||
_globals['_BACKEND']._serialized_end=2083 |
||||
# @@protoc_insertion_point(module_scope) |
@ -0,0 +1,297 @@ |
||||
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! |
||||
"""Client and server classes corresponding to protobuf-defined services.""" |
||||
import grpc |
||||
|
||||
import backend_pb2 as backend__pb2 |
||||
|
||||
|
||||
class BackendStub(object): |
||||
"""Missing associated documentation comment in .proto file.""" |
||||
|
||||
def __init__(self, channel): |
||||
"""Constructor. |
||||
|
||||
Args: |
||||
channel: A grpc.Channel. |
||||
""" |
||||
self.Health = channel.unary_unary( |
||||
'/backend.Backend/Health', |
||||
request_serializer=backend__pb2.HealthMessage.SerializeToString, |
||||
response_deserializer=backend__pb2.Reply.FromString, |
||||
) |
||||
self.Predict = channel.unary_unary( |
||||
'/backend.Backend/Predict', |
||||
request_serializer=backend__pb2.PredictOptions.SerializeToString, |
||||
response_deserializer=backend__pb2.Reply.FromString, |
||||
) |
||||
self.LoadModel = channel.unary_unary( |
||||
'/backend.Backend/LoadModel', |
||||
request_serializer=backend__pb2.ModelOptions.SerializeToString, |
||||
response_deserializer=backend__pb2.Result.FromString, |
||||
) |
||||
self.PredictStream = channel.unary_stream( |
||||
'/backend.Backend/PredictStream', |
||||
request_serializer=backend__pb2.PredictOptions.SerializeToString, |
||||
response_deserializer=backend__pb2.Reply.FromString, |
||||
) |
||||
self.Embedding = channel.unary_unary( |
||||
'/backend.Backend/Embedding', |
||||
request_serializer=backend__pb2.PredictOptions.SerializeToString, |
||||
response_deserializer=backend__pb2.EmbeddingResult.FromString, |
||||
) |
||||
self.GenerateImage = channel.unary_unary( |
||||
'/backend.Backend/GenerateImage', |
||||
request_serializer=backend__pb2.GenerateImageRequest.SerializeToString, |
||||
response_deserializer=backend__pb2.Result.FromString, |
||||
) |
||||
self.AudioTranscription = channel.unary_unary( |
||||
'/backend.Backend/AudioTranscription', |
||||
request_serializer=backend__pb2.TranscriptRequest.SerializeToString, |
||||
response_deserializer=backend__pb2.TranscriptResult.FromString, |
||||
) |
||||
self.TTS = channel.unary_unary( |
||||
'/backend.Backend/TTS', |
||||
request_serializer=backend__pb2.TTSRequest.SerializeToString, |
||||
response_deserializer=backend__pb2.Result.FromString, |
||||
) |
||||
|
||||
|
||||
class BackendServicer(object): |
||||
"""Missing associated documentation comment in .proto file.""" |
||||
|
||||
def Health(self, request, context): |
||||
"""Missing associated documentation comment in .proto file.""" |
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED) |
||||
context.set_details('Method not implemented!') |
||||
raise NotImplementedError('Method not implemented!') |
||||
|
||||
def Predict(self, request, context): |
||||
"""Missing associated documentation comment in .proto file.""" |
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED) |
||||
context.set_details('Method not implemented!') |
||||
raise NotImplementedError('Method not implemented!') |
||||
|
||||
def LoadModel(self, request, context): |
||||
"""Missing associated documentation comment in .proto file.""" |
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED) |
||||
context.set_details('Method not implemented!') |
||||
raise NotImplementedError('Method not implemented!') |
||||
|
||||
def PredictStream(self, request, context): |
||||
"""Missing associated documentation comment in .proto file.""" |
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED) |
||||
context.set_details('Method not implemented!') |
||||
raise NotImplementedError('Method not implemented!') |
||||
|
||||
def Embedding(self, request, context): |
||||
"""Missing associated documentation comment in .proto file.""" |
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED) |
||||
context.set_details('Method not implemented!') |
||||
raise NotImplementedError('Method not implemented!') |
||||
|
||||
def GenerateImage(self, request, context): |
||||
"""Missing associated documentation comment in .proto file.""" |
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED) |
||||
context.set_details('Method not implemented!') |
||||
raise NotImplementedError('Method not implemented!') |
||||
|
||||
def AudioTranscription(self, request, context): |
||||
"""Missing associated documentation comment in .proto file.""" |
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED) |
||||
context.set_details('Method not implemented!') |
||||
raise NotImplementedError('Method not implemented!') |
||||
|
||||
def TTS(self, request, context): |
||||
"""Missing associated documentation comment in .proto file.""" |
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED) |
||||
context.set_details('Method not implemented!') |
||||
raise NotImplementedError('Method not implemented!') |
||||
|
||||
|
||||
def add_BackendServicer_to_server(servicer, server): |
||||
rpc_method_handlers = { |
||||
'Health': grpc.unary_unary_rpc_method_handler( |
||||
servicer.Health, |
||||
request_deserializer=backend__pb2.HealthMessage.FromString, |
||||
response_serializer=backend__pb2.Reply.SerializeToString, |
||||
), |
||||
'Predict': grpc.unary_unary_rpc_method_handler( |
||||
servicer.Predict, |
||||
request_deserializer=backend__pb2.PredictOptions.FromString, |
||||
response_serializer=backend__pb2.Reply.SerializeToString, |
||||
), |
||||
'LoadModel': grpc.unary_unary_rpc_method_handler( |
||||
servicer.LoadModel, |
||||
request_deserializer=backend__pb2.ModelOptions.FromString, |
||||
response_serializer=backend__pb2.Result.SerializeToString, |
||||
), |
||||
'PredictStream': grpc.unary_stream_rpc_method_handler( |
||||
servicer.PredictStream, |
||||
request_deserializer=backend__pb2.PredictOptions.FromString, |
||||
response_serializer=backend__pb2.Reply.SerializeToString, |
||||
), |
||||
'Embedding': grpc.unary_unary_rpc_method_handler( |
||||
servicer.Embedding, |
||||
request_deserializer=backend__pb2.PredictOptions.FromString, |
||||
response_serializer=backend__pb2.EmbeddingResult.SerializeToString, |
||||
), |
||||
'GenerateImage': grpc.unary_unary_rpc_method_handler( |
||||
servicer.GenerateImage, |
||||
request_deserializer=backend__pb2.GenerateImageRequest.FromString, |
||||
response_serializer=backend__pb2.Result.SerializeToString, |
||||
), |
||||
'AudioTranscription': grpc.unary_unary_rpc_method_handler( |
||||
servicer.AudioTranscription, |
||||
request_deserializer=backend__pb2.TranscriptRequest.FromString, |
||||
response_serializer=backend__pb2.TranscriptResult.SerializeToString, |
||||
), |
||||
'TTS': grpc.unary_unary_rpc_method_handler( |
||||
servicer.TTS, |
||||
request_deserializer=backend__pb2.TTSRequest.FromString, |
||||
response_serializer=backend__pb2.Result.SerializeToString, |
||||
), |
||||
} |
||||
generic_handler = grpc.method_handlers_generic_handler( |
||||
'backend.Backend', rpc_method_handlers) |
||||
server.add_generic_rpc_handlers((generic_handler,)) |
||||
|
||||
|
||||
# This class is part of an EXPERIMENTAL API. |
||||
class Backend(object): |
||||
"""Missing associated documentation comment in .proto file.""" |
||||
|
||||
@staticmethod |
||||
def Health(request, |
||||
target, |
||||
options=(), |
||||
channel_credentials=None, |
||||
call_credentials=None, |
||||
insecure=False, |
||||
compression=None, |
||||
wait_for_ready=None, |
||||
timeout=None, |
||||
metadata=None): |
||||
return grpc.experimental.unary_unary(request, target, '/backend.Backend/Health', |
||||
backend__pb2.HealthMessage.SerializeToString, |
||||
backend__pb2.Reply.FromString, |
||||
options, channel_credentials, |
||||
insecure, call_credentials, compression, wait_for_ready, timeout, metadata) |
||||
|
||||
@staticmethod |
||||
def Predict(request, |
||||
target, |
||||
options=(), |
||||
channel_credentials=None, |
||||
call_credentials=None, |
||||
insecure=False, |
||||
compression=None, |
||||
wait_for_ready=None, |
||||
timeout=None, |
||||
metadata=None): |
||||
return grpc.experimental.unary_unary(request, target, '/backend.Backend/Predict', |
||||
backend__pb2.PredictOptions.SerializeToString, |
||||
backend__pb2.Reply.FromString, |
||||
options, channel_credentials, |
||||
insecure, call_credentials, compression, wait_for_ready, timeout, metadata) |
||||
|
||||
@staticmethod |
||||
def LoadModel(request, |
||||
target, |
||||
options=(), |
||||
channel_credentials=None, |
||||
call_credentials=None, |
||||
insecure=False, |
||||
compression=None, |
||||
wait_for_ready=None, |
||||
timeout=None, |
||||
metadata=None): |
||||
return grpc.experimental.unary_unary(request, target, '/backend.Backend/LoadModel', |
||||
backend__pb2.ModelOptions.SerializeToString, |
||||
backend__pb2.Result.FromString, |
||||
options, channel_credentials, |
||||
insecure, call_credentials, compression, wait_for_ready, timeout, metadata) |
||||
|
||||
@staticmethod |
||||
def PredictStream(request, |
||||
target, |
||||
options=(), |
||||
channel_credentials=None, |
||||
call_credentials=None, |
||||
insecure=False, |
||||
compression=None, |
||||
wait_for_ready=None, |
||||
timeout=None, |
||||
metadata=None): |
||||
return grpc.experimental.unary_stream(request, target, '/backend.Backend/PredictStream', |
||||
backend__pb2.PredictOptions.SerializeToString, |
||||
backend__pb2.Reply.FromString, |
||||
options, channel_credentials, |
||||
insecure, call_credentials, compression, wait_for_ready, timeout, metadata) |
||||
|
||||
@staticmethod |
||||
def Embedding(request, |
||||
target, |
||||
options=(), |
||||
channel_credentials=None, |
||||
call_credentials=None, |
||||
insecure=False, |
||||
compression=None, |
||||
wait_for_ready=None, |
||||
timeout=None, |
||||
metadata=None): |
||||
return grpc.experimental.unary_unary(request, target, '/backend.Backend/Embedding', |
||||
backend__pb2.PredictOptions.SerializeToString, |
||||
backend__pb2.EmbeddingResult.FromString, |
||||
options, channel_credentials, |
||||
insecure, call_credentials, compression, wait_for_ready, timeout, metadata) |
||||
|
||||
@staticmethod |
||||
def GenerateImage(request, |
||||
target, |
||||
options=(), |
||||
channel_credentials=None, |
||||
call_credentials=None, |
||||
insecure=False, |
||||
compression=None, |
||||
wait_for_ready=None, |
||||
timeout=None, |
||||
metadata=None): |
||||
return grpc.experimental.unary_unary(request, target, '/backend.Backend/GenerateImage', |
||||
backend__pb2.GenerateImageRequest.SerializeToString, |
||||
backend__pb2.Result.FromString, |
||||
options, channel_credentials, |
||||
insecure, call_credentials, compression, wait_for_ready, timeout, metadata) |
||||
|
||||
@staticmethod |
||||
def AudioTranscription(request, |
||||
target, |
||||
options=(), |
||||
channel_credentials=None, |
||||
call_credentials=None, |
||||
insecure=False, |
||||
compression=None, |
||||
wait_for_ready=None, |
||||
timeout=None, |
||||
metadata=None): |
||||
return grpc.experimental.unary_unary(request, target, '/backend.Backend/AudioTranscription', |
||||
backend__pb2.TranscriptRequest.SerializeToString, |
||||
backend__pb2.TranscriptResult.FromString, |
||||
options, channel_credentials, |
||||
insecure, call_credentials, compression, wait_for_ready, timeout, metadata) |
||||
|
||||
@staticmethod |
||||
def TTS(request, |
||||
target, |
||||
options=(), |
||||
channel_credentials=None, |
||||
call_credentials=None, |
||||
insecure=False, |
||||
compression=None, |
||||
wait_for_ready=None, |
||||
timeout=None, |
||||
metadata=None): |
||||
return grpc.experimental.unary_unary(request, target, '/backend.Backend/TTS', |
||||
backend__pb2.TTSRequest.SerializeToString, |
||||
backend__pb2.Result.FromString, |
||||
options, channel_credentials, |
||||
insecure, call_credentials, compression, wait_for_ready, timeout, metadata) |
@ -0,0 +1,67 @@ |
||||
#!/usr/bin/env python3 |
||||
import grpc |
||||
from concurrent import futures |
||||
import time |
||||
import backend_pb2 |
||||
import backend_pb2_grpc |
||||
import argparse |
||||
import signal |
||||
import sys |
||||
import os |
||||
from sentence_transformers import SentenceTransformer |
||||
|
||||
_ONE_DAY_IN_SECONDS = 60 * 60 * 24 |
||||
|
||||
# Implement the BackendServicer class with the service methods |
||||
class BackendServicer(backend_pb2_grpc.BackendServicer): |
||||
def Health(self, request, context): |
||||
return backend_pb2.Reply(message="OK") |
||||
def LoadModel(self, request, context): |
||||
model_name = request.Model |
||||
model_name = os.path.basename(model_name) |
||||
try: |
||||
self.model = SentenceTransformer(model_name) |
||||
except Exception as err: |
||||
return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}") |
||||
# Implement your logic here for the LoadModel service |
||||
# Replace this with your desired response |
||||
return backend_pb2.Result(message="Model loaded successfully", success=True) |
||||
def Embedding(self, request, context): |
||||
# Implement your logic here for the Embedding service |
||||
# Replace this with your desired response |
||||
print("Calculated embeddings for: " + request.Embeddings, file=sys.stderr) |
||||
sentence_embeddings = self.model.encode(request.Embeddings) |
||||
return backend_pb2.EmbeddingResult(embeddings=sentence_embeddings) |
||||
|
||||
|
||||
def serve(address): |
||||
server = grpc.server(futures.ThreadPoolExecutor(max_workers=10)) |
||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server) |
||||
server.add_insecure_port(address) |
||||
server.start() |
||||
print("Server started. Listening on: " + address, file=sys.stderr) |
||||
|
||||
# Define the signal handler function |
||||
def signal_handler(sig, frame): |
||||
print("Received termination signal. Shutting down...") |
||||
server.stop(0) |
||||
sys.exit(0) |
||||
|
||||
# Set the signal handlers for SIGINT and SIGTERM |
||||
signal.signal(signal.SIGINT, signal_handler) |
||||
signal.signal(signal.SIGTERM, signal_handler) |
||||
|
||||
try: |
||||
while True: |
||||
time.sleep(_ONE_DAY_IN_SECONDS) |
||||
except KeyboardInterrupt: |
||||
server.stop(0) |
||||
|
||||
if __name__ == "__main__": |
||||
parser = argparse.ArgumentParser(description="Run the gRPC server.") |
||||
parser.add_argument( |
||||
"--addr", default="localhost:50051", help="The address to bind the server to." |
||||
) |
||||
args = parser.parse_args() |
||||
|
||||
serve(args.addr) |
@ -0,0 +1,4 @@ |
||||
sentence_transformers |
||||
grpcio |
||||
google |
||||
protobuf |
@ -0,0 +1,37 @@ |
||||
package utils |
||||
|
||||
import ( |
||||
"time" |
||||
|
||||
"github.com/rs/zerolog/log" |
||||
) |
||||
|
||||
var lastProgress time.Time = time.Now() |
||||
var startTime time.Time = time.Now() |
||||
|
||||
func ResetDownloadTimers() { |
||||
lastProgress = time.Now() |
||||
startTime = time.Now() |
||||
} |
||||
|
||||
func DisplayDownloadFunction(fileName string, current string, total string, percentage float64) { |
||||
currentTime := time.Now() |
||||
|
||||
if currentTime.Sub(lastProgress) >= 5*time.Second { |
||||
|
||||
lastProgress = currentTime |
||||
|
||||
// calculate ETA based on percentage and elapsed time
|
||||
var eta time.Duration |
||||
if percentage > 0 { |
||||
elapsed := currentTime.Sub(startTime) |
||||
eta = time.Duration(float64(elapsed)*(100/percentage) - float64(elapsed)) |
||||
} |
||||
|
||||
if total != "" { |
||||
log.Debug().Msgf("Downloading %s: %s/%s (%.2f%%) ETA: %s", fileName, current, total, percentage, eta) |
||||
} else { |
||||
log.Debug().Msgf("Downloading: %s", current) |
||||
} |
||||
} |
||||
} |
@ -0,0 +1,5 @@ |
||||
name: code-search-ada-code-001 |
||||
backend: huggingface |
||||
embeddings: true |
||||
parameters: |
||||
model: all-MiniLM-L6-v2 |
Loading…
Reference in new issue