parent
4275bfc8c0
commit
593ff6308c
@ -0,0 +1,75 @@ |
||||
package client |
||||
|
||||
import ( |
||||
"bytes" |
||||
"encoding/json" |
||||
"fmt" |
||||
"net/http" |
||||
) |
||||
|
||||
type Prediction struct { |
||||
Prediction string `json:"prediction"` |
||||
} |
||||
|
||||
type Client struct { |
||||
baseURL string |
||||
client *http.Client |
||||
endpoint string |
||||
} |
||||
|
||||
func NewClient(baseURL string) *Client { |
||||
return &Client{ |
||||
baseURL: baseURL, |
||||
client: &http.Client{}, |
||||
endpoint: "/predict", |
||||
} |
||||
} |
||||
|
||||
type InputData struct { |
||||
Text string `json:"text"` |
||||
TopP float64 `json:"topP,omitempty"` |
||||
TopK int `json:"topK,omitempty"` |
||||
Temperature float64 `json:"temperature,omitempty"` |
||||
Tokens int `json:"tokens,omitempty"` |
||||
} |
||||
|
||||
func (c *Client) Predict(text string, opts ...InputOption) (string, error) { |
||||
input := NewInputData(opts...) |
||||
input.Text = text |
||||
|
||||
// encode input data to JSON format
|
||||
inputBytes, err := json.Marshal(input) |
||||
if err != nil { |
||||
return "", err |
||||
} |
||||
|
||||
// create HTTP request
|
||||
url := c.baseURL + c.endpoint |
||||
req, err := http.NewRequest("POST", url, bytes.NewBuffer(inputBytes)) |
||||
if err != nil { |
||||
return "", err |
||||
} |
||||
|
||||
// set request headers
|
||||
req.Header.Set("Content-Type", "application/json") |
||||
|
||||
// send request and get response
|
||||
resp, err := c.client.Do(req) |
||||
if err != nil { |
||||
return "", err |
||||
} |
||||
defer resp.Body.Close() |
||||
|
||||
if resp.StatusCode != http.StatusOK { |
||||
return "", fmt.Errorf("request failed with status %d", resp.StatusCode) |
||||
} |
||||
|
||||
// decode response body to Prediction struct
|
||||
var prediction Prediction |
||||
err = json.NewDecoder(resp.Body).Decode(&prediction) |
||||
if err != nil { |
||||
return "", err |
||||
} |
||||
|
||||
return prediction.Prediction, nil |
||||
} |
@ -0,0 +1,51 @@ |
||||
package client |
||||
|
||||
import "net/http" |
||||
|
||||
type ClientOption func(c *Client) |
||||
|
||||
func WithHTTPClient(httpClient *http.Client) ClientOption { |
||||
return func(c *Client) { |
||||
c.client = httpClient |
||||
} |
||||
} |
||||
|
||||
func WithEndpoint(endpoint string) ClientOption { |
||||
return func(c *Client) { |
||||
c.endpoint = endpoint |
||||
} |
||||
} |
||||
|
||||
type InputOption func(d *InputData) |
||||
|
||||
func NewInputData(opts ...InputOption) *InputData { |
||||
data := &InputData{} |
||||
for _, opt := range opts { |
||||
opt(data) |
||||
} |
||||
return data |
||||
} |
||||
|
||||
func WithTopP(topP float64) InputOption { |
||||
return func(d *InputData) { |
||||
d.TopP = topP |
||||
} |
||||
} |
||||
|
||||
func WithTopK(topK int) InputOption { |
||||
return func(d *InputData) { |
||||
d.TopK = topK |
||||
} |
||||
} |
||||
|
||||
func WithTemperature(temperature float64) InputOption { |
||||
return func(d *InputData) { |
||||
d.Temperature = temperature |
||||
} |
||||
} |
||||
|
||||
func WithTokens(tokens int) InputOption { |
||||
return func(d *InputData) { |
||||
d.Tokens = tokens |
||||
} |
||||
} |
Loading…
Reference in new issue