Skip to content

Commit 0a92e60

Browse files
committed
Add back the previous implementation
1 parent 41d39a0 commit 0a92e60

File tree

3 files changed

+503
-1
lines changed

3 files changed

+503
-1
lines changed

sdk/go/README.md

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,10 @@ cd sdk/simulator
113113

114114
### Running Tests
115115
```bash
116-
DSTACK_SIMULATOR_ENDPOINT=$(realpath ../simulator/dstack.sock) go test -v ./...
116+
DSTACK_SIMULATOR_ENDPOINT=$(realpath ../simulator/dstack.sock) go test -v ./dstack
117+
118+
# or for the old Tappd client
119+
DSTACK_SIMULATOR_ENDPOINT=$(realpath ../simulator/tappd.sock) go test -v ./tappd
117120
```
118121

119122
## License

sdk/go/tappd/client.go

Lines changed: 369 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,369 @@
1+
// Provides a Dstack SDK Tappd client and related utilities
2+
//
3+
// Author: Franco Barpp Gomes <franco@nethermind.io>
4+
package tappd
5+
6+
import (
7+
"bytes"
8+
"context"
9+
"crypto/sha512"
10+
"encoding/base64"
11+
"encoding/hex"
12+
"encoding/json"
13+
"fmt"
14+
"io"
15+
"log/slog"
16+
"net"
17+
"net/http"
18+
"os"
19+
"strings"
20+
)
21+
22+
// Represents the hash algorithm used in TDX quote generation.
23+
type QuoteHashAlgorithm string
24+
25+
const (
26+
SHA256 QuoteHashAlgorithm = "sha256"
27+
SHA384 QuoteHashAlgorithm = "sha384"
28+
SHA512 QuoteHashAlgorithm = "sha512"
29+
SHA3_256 QuoteHashAlgorithm = "sha3-256"
30+
SHA3_384 QuoteHashAlgorithm = "sha3-384"
31+
SHA3_512 QuoteHashAlgorithm = "sha3-512"
32+
KECCAK256 QuoteHashAlgorithm = "keccak256"
33+
KECCAK384 QuoteHashAlgorithm = "keccak384"
34+
KECCAK512 QuoteHashAlgorithm = "keccak512"
35+
RAW QuoteHashAlgorithm = "raw"
36+
)
37+
38+
// Represents the response from a key derivation request.
39+
type DeriveKeyResponse struct {
40+
Key string `json:"key"`
41+
CertificateChain []string `json:"certificate_chain"`
42+
}
43+
44+
// Decodes the key to bytes, optionally truncating to maxLength. If maxLength
45+
// < 0, the key is not truncated.
46+
func (d *DeriveKeyResponse) ToBytes(maxLength int) ([]byte, error) {
47+
content := d.Key
48+
49+
content = strings.Replace(content, "-----BEGIN PRIVATE KEY-----", "", 1)
50+
content = strings.Replace(content, "-----END PRIVATE KEY-----", "", 1)
51+
content = strings.Replace(content, "\n", "", -1)
52+
53+
binary, err := base64.StdEncoding.DecodeString(content)
54+
if err != nil {
55+
return nil, err
56+
}
57+
58+
if maxLength >= 0 && len(binary) > maxLength {
59+
return binary[:maxLength], nil
60+
}
61+
return binary, nil
62+
}
63+
64+
// Represents the response from a TDX quote request.
65+
type TdxQuoteResponse struct {
66+
Quote string `json:"quote"`
67+
EventLog string `json:"event_log"`
68+
}
69+
70+
// Represents an event log entry in the TCB info
71+
type EventLog struct {
72+
IMR int `json:"imr"`
73+
EventType int `json:"event_type"`
74+
Digest string `json:"digest"`
75+
Event string `json:"event"`
76+
EventPayload string `json:"event_payload"`
77+
}
78+
79+
// Represents the TCB information
80+
type TcbInfo struct {
81+
Mrtd string `json:"mrtd"`
82+
RootfsHash string `json:"rootfs_hash"`
83+
Rtmr0 string `json:"rtmr0"`
84+
Rtmr1 string `json:"rtmr1"`
85+
Rtmr2 string `json:"rtmr2"`
86+
Rtmr3 string `json:"rtmr3"`
87+
EventLog []EventLog `json:"event_log"`
88+
}
89+
90+
// Represents the response from an info request
91+
type TappdInfoResponse struct {
92+
AppID string `json:"app_id"`
93+
InstanceID string `json:"instance_id"`
94+
AppCert string `json:"app_cert"`
95+
TcbInfo TcbInfo `json:"tcb_info"`
96+
AppName string `json:"app_name"`
97+
PublicLogs bool `json:"public_logs"`
98+
PublicSysinfo bool `json:"public_sysinfo"`
99+
}
100+
101+
const INIT_MR = "000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"
102+
103+
// Replays the RTMR history to calculate final RTMR values
104+
func replayRTMR(history []string) (string, error) {
105+
if len(history) == 0 {
106+
return INIT_MR, nil
107+
}
108+
109+
mr := make([]byte, 48)
110+
111+
for _, content := range history {
112+
contentBytes, err := hex.DecodeString(content)
113+
if err != nil {
114+
return "", err
115+
}
116+
117+
if len(contentBytes) < 48 {
118+
padding := make([]byte, 48-len(contentBytes))
119+
contentBytes = append(contentBytes, padding...)
120+
}
121+
122+
h := sha512.New384()
123+
h.Write(append(mr, contentBytes...))
124+
mr = h.Sum(nil)
125+
}
126+
127+
return hex.EncodeToString(mr), nil
128+
}
129+
130+
// Replays the RTMR history to calculate final RTMR values
131+
func (r *TdxQuoteResponse) ReplayRTMRs() (map[int]string, error) {
132+
var eventLog []struct {
133+
IMR int `json:"imr"`
134+
Digest string `json:"digest"`
135+
}
136+
json.Unmarshal([]byte(r.EventLog), &eventLog)
137+
138+
rtmrs := make(map[int]string, 4)
139+
for idx := 0; idx < 4; idx++ {
140+
history := make([]string, 0)
141+
for _, event := range eventLog {
142+
if event.IMR == idx {
143+
history = append(history, event.Digest)
144+
}
145+
}
146+
147+
rtmr, err := replayRTMR(history)
148+
if err != nil {
149+
return nil, err
150+
}
151+
152+
rtmrs[idx] = rtmr
153+
}
154+
155+
return rtmrs, nil
156+
}
157+
158+
// Handles communication with the Tappd service.
159+
type TappdClient struct {
160+
endpoint string
161+
baseURL string
162+
httpClient *http.Client
163+
logger *slog.Logger
164+
}
165+
166+
// Functional option for configuring a TappdClient.
167+
type TappdClientOption func(*TappdClient)
168+
169+
// Sets the endpoint for the TappdClient.
170+
func WithEndpoint(endpoint string) TappdClientOption {
171+
return func(c *TappdClient) {
172+
c.endpoint = endpoint
173+
}
174+
}
175+
176+
// Sets the logger for the TappdClient
177+
func WithLogger(logger *slog.Logger) TappdClientOption {
178+
return func(c *TappdClient) {
179+
c.logger = logger
180+
}
181+
}
182+
183+
// Creates a new TappdClient instance based on the provided endpoint.
184+
// If the endpoint is empty, it will use the simulator endpoint if it is
185+
// set in the environment through DSTACK_SIMULATOR_ENDPOINT. Otherwise, it
186+
// will use the default endpoint at /var/run/tappd.sock.
187+
func NewTappdClient(opts ...TappdClientOption) *TappdClient {
188+
client := &TappdClient{
189+
endpoint: "",
190+
baseURL: "",
191+
httpClient: &http.Client{},
192+
logger: slog.Default(),
193+
}
194+
195+
for _, opt := range opts {
196+
opt(client)
197+
}
198+
199+
client.endpoint = client.getEndpoint()
200+
201+
if strings.HasPrefix(client.endpoint, "http://") || strings.HasPrefix(client.endpoint, "https://") {
202+
client.baseURL = client.endpoint
203+
} else {
204+
client.baseURL = "http://localhost"
205+
client.httpClient = &http.Client{
206+
Transport: &http.Transport{
207+
DialContext: func(_ context.Context, _, _ string) (net.Conn, error) {
208+
return net.Dial("unix", client.endpoint)
209+
},
210+
},
211+
}
212+
}
213+
214+
return client
215+
}
216+
217+
// Returns the appropriate endpoint based on environment and input. If the
218+
// endpoint is empty, it will use the simulator endpoint if it is set in the
219+
// environment through DSTACK_SIMULATOR_ENDPOINT. Otherwise, it will use the
220+
// default endpoint at /var/run/tappd.sock.
221+
func (c *TappdClient) getEndpoint() string {
222+
if c.endpoint != "" {
223+
return c.endpoint
224+
}
225+
if simEndpoint, exists := os.LookupEnv("DSTACK_SIMULATOR_ENDPOINT"); exists {
226+
c.logger.Info("using simulator endpoint", "endpoint", simEndpoint)
227+
return simEndpoint
228+
}
229+
return "/var/run/tappd.sock"
230+
}
231+
232+
// Sends an RPC request to the Tappd service.
233+
func (c *TappdClient) sendRPCRequest(ctx context.Context, path string, payload interface{}) ([]byte, error) {
234+
jsonData, err := json.Marshal(payload)
235+
if err != nil {
236+
return nil, err
237+
}
238+
239+
req, err := http.NewRequestWithContext(ctx, "POST", c.baseURL+path, bytes.NewBuffer(jsonData))
240+
if err != nil {
241+
return nil, err
242+
}
243+
244+
req.Header.Set("Content-Type", "application/json")
245+
resp, err := c.httpClient.Do(req)
246+
if err != nil {
247+
return nil, err
248+
}
249+
defer resp.Body.Close()
250+
251+
if resp.StatusCode != http.StatusOK {
252+
return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode)
253+
}
254+
255+
return io.ReadAll(resp.Body)
256+
}
257+
258+
// Derives a key from the Tappd service. This wraps
259+
// DeriveKeyWithSubjectAndAltNames using the path as the subject and an empty
260+
// altNames.
261+
func (c *TappdClient) DeriveKey(ctx context.Context, path string) (*DeriveKeyResponse, error) {
262+
return c.DeriveKeyWithSubjectAndAltNames(ctx, path, path, nil)
263+
}
264+
265+
// Derives a key from the Tappd service. This wraps
266+
// DeriveKeyWithSubjectAndAltNames using an empty altNames.
267+
func (c *TappdClient) DeriveKeyWithSubject(ctx context.Context, path string, subject string) (*DeriveKeyResponse, error) {
268+
return c.DeriveKeyWithSubjectAndAltNames(ctx, path, subject, nil)
269+
}
270+
271+
// Derives a key from the Tappd service, explicitly setting the subject and
272+
// altNames.
273+
func (c *TappdClient) DeriveKeyWithSubjectAndAltNames(ctx context.Context, path string, subject string, altNames []string) (*DeriveKeyResponse, error) {
274+
if subject == "" {
275+
subject = path
276+
}
277+
278+
payload := map[string]interface{}{
279+
"path": path,
280+
"subject": subject,
281+
}
282+
if len(altNames) > 0 {
283+
payload["alt_names"] = altNames
284+
}
285+
286+
data, err := c.sendRPCRequest(ctx, "/prpc/Tappd.DeriveKey", payload)
287+
if err != nil {
288+
return nil, err
289+
}
290+
291+
var response DeriveKeyResponse
292+
if err := json.Unmarshal(data, &response); err != nil {
293+
return nil, err
294+
}
295+
return &response, nil
296+
}
297+
298+
// Sends a TDX quote request to the Tappd service using SHA512 as the report
299+
// data hash algorithm.
300+
func (c *TappdClient) TdxQuote(ctx context.Context, reportData []byte) (*TdxQuoteResponse, error) {
301+
return c.TdxQuoteWithHashAlgorithm(ctx, reportData, SHA512)
302+
}
303+
304+
// Sends a TDX quote request to the Tappd service with a specific hash
305+
// report data hash algorithm. If the hash algorithm is RAW, the report data
306+
// must be at most 64 bytes - if it's below that, it will be left-padded with
307+
// zeros.
308+
func (c *TappdClient) TdxQuoteWithHashAlgorithm(ctx context.Context, reportData []byte, hashAlgorithm QuoteHashAlgorithm) (*TdxQuoteResponse, error) {
309+
if hashAlgorithm == RAW {
310+
if len(reportData) > 64 {
311+
return nil, fmt.Errorf("report data is too large, it should be at most 64 bytes when hashAlgorithm is RAW")
312+
}
313+
if len(reportData) < 64 {
314+
reportData = append(make([]byte, 64-len(reportData)), reportData...)
315+
}
316+
}
317+
318+
payload := map[string]interface{}{
319+
"report_data": hex.EncodeToString(reportData),
320+
"hash_algorithm": string(hashAlgorithm),
321+
}
322+
323+
data, err := c.sendRPCRequest(ctx, "/prpc/Tappd.TdxQuote", payload)
324+
if err != nil {
325+
return nil, err
326+
}
327+
328+
var response TdxQuoteResponse
329+
if err := json.Unmarshal(data, &response); err != nil {
330+
return nil, err
331+
}
332+
return &response, nil
333+
}
334+
335+
// Sends a request to get information about the Tappd instance
336+
func (c *TappdClient) Info(ctx context.Context) (*TappdInfoResponse, error) {
337+
data, err := c.sendRPCRequest(ctx, "/prpc/Tappd.Info", map[string]interface{}{})
338+
if err != nil {
339+
return nil, err
340+
}
341+
342+
var response struct {
343+
TcbInfo string `json:"tcb_info"`
344+
AppID string `json:"app_id"`
345+
InstanceID string `json:"instance_id"`
346+
AppCert string `json:"app_cert"`
347+
AppName string `json:"app_name"`
348+
PublicLogs bool `json:"public_logs"`
349+
PublicSysinfo bool `json:"public_sysinfo"`
350+
}
351+
if err := json.Unmarshal(data, &response); err != nil {
352+
return nil, err
353+
}
354+
355+
var tcbInfo TcbInfo
356+
if err := json.Unmarshal([]byte(response.TcbInfo), &tcbInfo); err != nil {
357+
return nil, err
358+
}
359+
360+
return &TappdInfoResponse{
361+
AppID: response.AppID,
362+
InstanceID: response.InstanceID,
363+
AppCert: response.AppCert,
364+
TcbInfo: tcbInfo,
365+
AppName: response.AppName,
366+
PublicLogs: response.PublicLogs,
367+
PublicSysinfo: response.PublicSysinfo,
368+
}, nil
369+
}

0 commit comments

Comments
 (0)