parent
4275bfc8c0
commit
593ff6308c
@ -0,0 +1,75 @@ |
|||||||
|
package client |
||||||
|
|
||||||
|
import ( |
||||||
|
"bytes" |
||||||
|
"encoding/json" |
||||||
|
"fmt" |
||||||
|
"net/http" |
||||||
|
) |
||||||
|
|
||||||
|
type Prediction struct { |
||||||
|
Prediction string `json:"prediction"` |
||||||
|
} |
||||||
|
|
||||||
|
type Client struct { |
||||||
|
baseURL string |
||||||
|
client *http.Client |
||||||
|
endpoint string |
||||||
|
} |
||||||
|
|
||||||
|
func NewClient(baseURL string) *Client { |
||||||
|
return &Client{ |
||||||
|
baseURL: baseURL, |
||||||
|
client: &http.Client{}, |
||||||
|
endpoint: "/predict", |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
type InputData struct { |
||||||
|
Text string `json:"text"` |
||||||
|
TopP float64 `json:"topP,omitempty"` |
||||||
|
TopK int `json:"topK,omitempty"` |
||||||
|
Temperature float64 `json:"temperature,omitempty"` |
||||||
|
Tokens int `json:"tokens,omitempty"` |
||||||
|
} |
||||||
|
|
||||||
|
func (c *Client) Predict(text string, opts ...InputOption) (string, error) { |
||||||
|
input := NewInputData(opts...) |
||||||
|
input.Text = text |
||||||
|
|
||||||
|
// encode input data to JSON format
|
||||||
|
inputBytes, err := json.Marshal(input) |
||||||
|
if err != nil { |
||||||
|
return "", err |
||||||
|
} |
||||||
|
|
||||||
|
// create HTTP request
|
||||||
|
url := c.baseURL + c.endpoint |
||||||
|
req, err := http.NewRequest("POST", url, bytes.NewBuffer(inputBytes)) |
||||||
|
if err != nil { |
||||||
|
return "", err |
||||||
|
} |
||||||
|
|
||||||
|
// set request headers
|
||||||
|
req.Header.Set("Content-Type", "application/json") |
||||||
|
|
||||||
|
// send request and get response
|
||||||
|
resp, err := c.client.Do(req) |
||||||
|
if err != nil { |
||||||
|
return "", err |
||||||
|
} |
||||||
|
defer resp.Body.Close() |
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK { |
||||||
|
return "", fmt.Errorf("request failed with status %d", resp.StatusCode) |
||||||
|
} |
||||||
|
|
||||||
|
// decode response body to Prediction struct
|
||||||
|
var prediction Prediction |
||||||
|
err = json.NewDecoder(resp.Body).Decode(&prediction) |
||||||
|
if err != nil { |
||||||
|
return "", err |
||||||
|
} |
||||||
|
|
||||||
|
return prediction.Prediction, nil |
||||||
|
} |
@ -0,0 +1,51 @@ |
|||||||
|
package client |
||||||
|
|
||||||
|
import "net/http" |
||||||
|
|
||||||
|
type ClientOption func(c *Client) |
||||||
|
|
||||||
|
func WithHTTPClient(httpClient *http.Client) ClientOption { |
||||||
|
return func(c *Client) { |
||||||
|
c.client = httpClient |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func WithEndpoint(endpoint string) ClientOption { |
||||||
|
return func(c *Client) { |
||||||
|
c.endpoint = endpoint |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
type InputOption func(d *InputData) |
||||||
|
|
||||||
|
func NewInputData(opts ...InputOption) *InputData { |
||||||
|
data := &InputData{} |
||||||
|
for _, opt := range opts { |
||||||
|
opt(data) |
||||||
|
} |
||||||
|
return data |
||||||
|
} |
||||||
|
|
||||||
|
func WithTopP(topP float64) InputOption { |
||||||
|
return func(d *InputData) { |
||||||
|
d.TopP = topP |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func WithTopK(topK int) InputOption { |
||||||
|
return func(d *InputData) { |
||||||
|
d.TopK = topK |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func WithTemperature(temperature float64) InputOption { |
||||||
|
return func(d *InputData) { |
||||||
|
d.Temperature = temperature |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func WithTokens(tokens int) InputOption { |
||||||
|
return func(d *InputData) { |
||||||
|
d.Tokens = tokens |
||||||
|
} |
||||||
|
} |
Loading…
Reference in new issue