Skip to content

Commit 451bbdf

Browse files
committed
Make it compatible with shim
1 parent 3eab0cc commit 451bbdf

File tree

4 files changed

+46
-43
lines changed

4 files changed

+46
-43
lines changed

Diff for: Dockerfile

+30
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
FROM ghcr.io/tinfoilanalytics/nitro-attestation-shim:v0.2.2 AS shim
2+
3+
FROM ollama/ollama AS ollama
4+
5+
FROM golang:1.21 AS build
6+
7+
WORKDIR /app
8+
COPY main.go go.mod go.sum ./
9+
RUN CGO_ENABLED=0 GOOS=linux go build -o /contentmod main.go
10+
11+
FROM alpine:3
12+
13+
RUN apk add --no-cache iproute2 ca-certificates
14+
15+
# Copy binaries from previous stages
16+
COPY --from=shim /nitro-attestation-shim /nitro-attestation-shim
17+
COPY --from=ollama /bin/ollama /bin/ollama
18+
COPY --from=build /contentmod /contentmod
19+
20+
# Set environment variable for port
21+
ENV PORT=80
22+
23+
# Create script to initialize ollama and start the application
24+
RUN echo '#!/bin/sh\n\
25+
ollama serve &\n\
26+
sleep 5\n\
27+
ollama pull llama-guard3:1b\n\
28+
exec "$@"' > /start.sh && chmod +x /start.sh
29+
30+
ENTRYPOINT ["/start.sh", "/nitro-attestation-shim", "-e", "[email protected]", "-u", "80", "--", "/contentmod"]

Diff for: go.mod

+1-8
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,3 @@
1-
module content-moderator-api
1+
module contentmod
22

33
go 1.22.1
4-
5-
require github.com/replicate/replicate-go v0.26.0
6-
7-
require (
8-
github.com/vincent-petithory/dataurl v1.0.0 // indirect
9-
golang.org/x/sync v0.10.0 // indirect
10-
)

Diff for: go.sum

-14
This file was deleted.

Diff for: main.go

+15-21
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ import (
1515
const (
1616
systemPrompt = `You are a content moderator. Analyze the following text and respond with 'safe' if the content is safe, or 'unsafe' followed by the category codes (e.g., 'unsafe\nS1,S2') if any violations are detected.`
1717
modelName = "llama-guard3:1b"
18+
ollamaURL = "http://localhost:11434"
1819
)
1920

2021
type analyzeRequest struct {
@@ -46,15 +47,7 @@ type ollamaResponse struct {
4647
Message Message `json:"message"`
4748
}
4849

49-
type server struct {
50-
ollamaURL string
51-
}
52-
53-
func newServer(ollamaURL string) *server {
54-
return &server{ollamaURL: ollamaURL}
55-
}
56-
57-
func (s *server) handleAnalyze(w http.ResponseWriter, r *http.Request) {
50+
func handleAnalyze(w http.ResponseWriter, r *http.Request) {
5851
if r.Method != http.MethodPost {
5952
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
6053
return
@@ -73,7 +66,7 @@ func (s *server) handleAnalyze(w http.ResponseWriter, r *http.Request) {
7366

7467
results := make([]analysisResult, 0, len(req.Messages))
7568
for _, message := range req.Messages {
76-
result, err := s.analyzeMessage(r.Context(), message)
69+
result, err := analyzeMessage(r.Context(), message)
7770
if err != nil {
7871
log.Printf("Error analyzing message '%s': %v", message, err)
7972
continue
@@ -87,7 +80,7 @@ func (s *server) handleAnalyze(w http.ResponseWriter, r *http.Request) {
8780
}
8881
}
8982

90-
func (s *server) analyzeMessage(ctx context.Context, message string) (analysisResult, error) {
83+
func analyzeMessage(ctx context.Context, message string) (analysisResult, error) {
9184
ollamaReq := ollamaRequest{
9285
Model: modelName,
9386
Messages: []Message{
@@ -101,7 +94,7 @@ func (s *server) analyzeMessage(ctx context.Context, message string) (analysisRe
10194
return analysisResult{}, fmt.Errorf("marshaling request: %w", err)
10295
}
10396

104-
req, err := http.NewRequestWithContext(ctx, http.MethodPost, s.ollamaURL+"/v1/chat/completions", bytes.NewReader(reqBody))
97+
req, err := http.NewRequestWithContext(ctx, http.MethodPost, ollamaURL+"/api/chat", bytes.NewReader(reqBody))
10598
if err != nil {
10699
return analysisResult{}, fmt.Errorf("creating request: %w", err)
107100
}
@@ -179,22 +172,23 @@ func corsMiddleware(next http.HandlerFunc) http.HandlerFunc {
179172
}
180173

181174
func main() {
182-
ollamaURL := os.Getenv("OLLAMA_URL")
183-
if ollamaURL == "" {
184-
ollamaURL = "http://localhost:11434"
185-
}
186-
187-
srv := newServer(ollamaURL)
188175
mux := http.NewServeMux()
189-
mux.HandleFunc("/api/analyze", corsMiddleware(srv.handleAnalyze))
176+
177+
// Health check endpoint
178+
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
179+
w.Write([]byte("Content moderation service is running"))
180+
})
181+
182+
// Analysis endpoint
183+
mux.HandleFunc("/api/analyze", corsMiddleware(handleAnalyze))
190184

191185
port := os.Getenv("PORT")
192186
if port == "" {
193-
port = "8080"
187+
port = "80"
194188
}
195189
addr := ":" + port
196190

197-
log.Printf("Server starting on %s, connecting to Ollama at %s", addr, ollamaURL)
191+
log.Printf("Server starting on %s", addr)
198192
if err := http.ListenAndServe(addr, mux); err != nil {
199193
log.Fatal(err)
200194
}

0 commit comments

Comments
 (0)