WebSocket follow-up: image creation with Gen AI
A follow-up to the previous post

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
// The function to trigger the generation of an image
// Legacy REST flow (removed in this post). Calling mutate() turns loading on
// and POSTs the prompt, S3 key and Bedrock flag to /generate-image.
// sendPost, setLoading, post, id and bedrock come from code not shown here.
const { mutate } = useMutation({
mutationFn: () => {
setLoading(true);
return sendPost("/generate-image", {
text: post,
s3Key: `image/${id}`,
bedrock: bedrock,
});
},
});
// The function for polling the /get-image endpoint. When status 200 is returned polling stops and image is displayed
// Polls every 3000 ms, also while the tab is in the background, for as long as
// `refetch` is true. On a 200 it clears `refetch`, stores the image and turns loading off.
// NOTE(review): the query key does not include `id`, so different ids would
// share one cache entry — confirm this is intended.
const { data } = useQuery({
queryKey: ["image"],
refetchInterval: 3000,
refetchIntervalInBackground: true,
enabled: refetch,
queryFn: async () => {
const imageData = await getImage("/get-image", { s3Key: `image/${id}` });
if (imageData.status === 200) {
setRefetch(false);
setImage(imageData.data);
setLoading(false);
}
return imageData.data;
},
});
- Remove both the mutation and the useQuery function.
- Uninstall React Query.
- Create a React context for WebSocket requests, using the browser's built-in WebSocket API.
- Use the context to send requests.
1
2
3
4
5
6
7
// Root component: wraps the page in WebSocketProvider so that MainContent (and
// anything it renders) can reach the shared socket through useWebSocket().
function App() {
  const page = <MainContent />;
  return <WebSocketProvider>{page}</WebSocketProvider>;
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import React, { createContext, useContext, useEffect, useState } from "react";
const API_URL = process.env.REACT_APP_API_URL ?? "";

/** What the WebSocket context exposes to consuming components. */
interface WebSocketContextType {
  /** The provider's socket, or null before a connection has been created. */
  socket: WebSocket | null;
  /** JSON-serializes `message` and sends it when the socket is open; otherwise logs a warning. */
  sendMessage: (message: object) => void;
  /** True from a successful send until a reply arrives or the socket errors/closes. */
  isLoading: boolean;
  /** The `url` from the most recent valid reply, or null if none has arrived. */
  url: string | null;
}

/** Shape of a reply received over the socket. */
interface WebSocketMessage {
  url: string;
}

/** Narrows a parsed JSON value to a WebSocketMessage (an object with a string `url`). */
const isWebSocketMessage = (value: unknown): value is WebSocketMessage =>
  typeof value === "object" &&
  value !== null &&
  typeof (value as { url?: unknown }).url === "string";

const WebSocketContext = createContext<WebSocketContextType | undefined>(
  undefined
);

/**
 * Opens one WebSocket connection to API_URL when mounted, closes it on
 * unmount, and shares the socket, a send helper, the loading flag and the
 * latest reply URL with its children.
 */
export const WebSocketProvider: React.FC<{ children: React.ReactNode }> = ({
  children,
}) => {
  const [socket, setSocket] = useState<WebSocket | null>(null);
  const [isLoading, setIsLoading] = useState<boolean>(false);
  const [url, setUrl] = useState<string | null>(null);

  useEffect(() => {
    // new WebSocket("") throws, which would crash the whole tree.
    if (API_URL === "") {
      console.error("REACT_APP_API_URL is not set; WebSocket connection not opened");
      return;
    }
    const ws = new WebSocket(API_URL);
    setSocket(ws);
    ws.onopen = () => {
      console.log("Connected to WebSocket");
    };
    ws.onmessage = (event) => {
      try {
        const data: unknown = JSON.parse(event.data);
        if (isWebSocketMessage(data)) {
          setUrl(data.url);
        } else {
          console.error("Unexpected WebSocket message: ", data);
        }
      } catch (error) {
        console.error("Error parsing WebSocket message: ", error);
      } finally {
        // The socket is deliberately left open: closing it here would make
        // every later sendMessage call fail its readyState check.
        setIsLoading(false);
      }
    };
    ws.onerror = (error) => {
      console.error("WebSocket Error: ", error);
      setIsLoading(false);
    };
    ws.onclose = () => {
      console.log("WebSocket connection closed");
      // No reply can arrive on a closed socket, so stop waiting for one.
      setIsLoading(false);
    };
    return () => {
      ws.close();
    };
  }, []);

  const sendMessage = (message: object) => {
    if (socket && socket.readyState === WebSocket.OPEN) {
      setIsLoading(true);
      socket.send(JSON.stringify(message));
    } else {
      console.warn("WebSocket is not open; message not sent: ", message);
    }
  };

  return (
    <WebSocketContext.Provider value={{ socket, sendMessage, isLoading, url }}>
      {children}
    </WebSocketContext.Provider>
  );
};

/**
 * Returns the WebSocket context.
 * @throws Error when called outside a WebSocketProvider.
 */
export const useWebSocket = (): WebSocketContextType => {
  const context = useContext(WebSocketContext);
  if (!context) {
    throw new Error("useWebSocket must be used within a WebSocketProvider");
  }
  return context;
};
- I initialize a few state variables:
- socket: holds the WebSocket instance.
- isLoading: tracks whether I have sent a message and am waiting for a reply (so I can show the user that it is loading).
- url: stores the URL from the WebSocket reply.
- useEffect:
- Runs once, when the provider mounts.
- Establishes a WebSocket connection.
- sendMessage function:
- Sends the prompt over the WebSocket to API Gateway.
- Sets loading to true, so I can tell the user that we are waiting for a reply.
- onmessage:
- When a message is received, I parse it so the component can use the URL to show the image.
- Finally, it sets loading to false to indicate that the image has been generated.
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import React from "react";
import { Button, Flex, Layout, Spin, Image, Switch } from "antd";
import TextArea from "antd/es/input/TextArea";
import { Typography } from "antd";
import { useWebSocket } from "../../contexts/websocket";
const { Paragraph, Text } = Typography;
const { Header, Content } = Layout;

// Every style object is typed as React.CSSProperties, so a misspelled CSS
// property or an invalid value is rejected at compile time.
const headerStyle: React.CSSProperties = {
  textAlign: "center",
  color: "#fff",
  height: 64,
  paddingInline: 48,
  lineHeight: "64px",
  backgroundColor: "#00415a",
};

const contentStyle: React.CSSProperties = {
  textAlign: "center",
  minHeight: 120,
  lineHeight: "120px",
  color: "#fff",
  backgroundColor: "#00719c",
};

const layoutStyle: React.CSSProperties = {
  borderRadius: 8,
  overflow: "hidden",
  width: "calc(50% - 8px)",
  maxWidth: "calc(50% - 8px)",
  marginTop: "10vh",
  height: "100%",
};

const textAreaStyle: React.CSSProperties = {
  width: "80%",
  height: "80%",
};

const buttonStyle: React.CSSProperties = {
  width: "80%",
  height: "80%",
  marginBottom: 16,
  backgroundColor: "#009bd6",
};

const paragraphStyle: React.CSSProperties = {
  margin: 10,
};

const textStyle: React.CSSProperties = {
  color: "white",
};
/**
 * Prompt-to-image page. The user writes a post, optionally toggles Bedrock and
 * clicks Generate. The prompt is sent over the shared WebSocket as a
 * "generate" action, and the image URL from the reply is shown below the form.
 * A spinner covers the page while a reply is pending.
 */
export default function MainContent() {
  // NOTE(review): every request reuses this fixed S3 key, so each generation
  // presumably overwrites the previous image (and an identical URL may be served
  // from the browser cache) — confirm with the backend.
  const id = "testing12asswwxc222";
  const { sendMessage, isLoading, url } = useWebSocket();
  const [bedrock, setBedrock] = React.useState(false);
  const [text, setText] = React.useState("");
  // Whitespace-only prompts count as empty and cannot be submitted.
  const isPromptEmpty = text.trim() === "";

  const handleBedrockChange = (checked: boolean) => {
    setBedrock(checked);
  };

  const handleSendMessage = () => {
    if (isPromptEmpty) {
      return;
    }
    const message = {
      action: "generate", // must match the WebSocket route key
      text,
      s3Key: id,
      bedrock,
    };
    sendMessage(message);
    setText("");
  };

  // The TextArea and Switch are controlled (value/checked), so clearing `text`
  // after a send also clears the visible text box.
  return (
    <Spin spinning={isLoading}>
      <Flex justify="center">
        <Layout style={layoutStyle}>
          <Header style={headerStyle}>Epic Post To Image POC</Header>
          <Content style={contentStyle}>
            <Paragraph style={paragraphStyle}>
              <Text strong style={textStyle}>
                Write a post as you would for a social media platform. Click
                generate to create an image for your post.
              </Text>
            </Paragraph>
            <TextArea
              style={textAreaStyle}
              rows={4}
              placeholder="Write here"
              value={text}
              onChange={(e) => setText(e.target.value)}
            />
            <Paragraph style={paragraphStyle}>
              <Text strong style={textStyle}>
                Use Bedrock? (Toggle to use bedrock)
              </Text>
              <Switch checked={bedrock} onChange={handleBedrockChange} />
            </Paragraph>
            {url && (
              <Image width={"80%"} style={{ marginTop: "10px" }} src={url} />
            )}
            <Button
              type="primary"
              style={buttonStyle}
              disabled={isPromptEmpty}
              onClick={handleSendMessage}
            >
              Generate
            </Button>
          </Content>
        </Layout>
      </Flex>
    </Spin>
  );
}
- The OpenAPI definition for API Gateway has been removed.
- WebSocket API definitions have been added:
1
2
3
4
5
6
# WebSocket API. The "action" field of each incoming JSON message selects the route.
WebSocketApi:
  Type: AWS::ApiGatewayV2::Api
  Properties:
    ProtocolType: WEBSOCKET
    Name: WebSocketApi
    RouteSelectionExpression: "$request.body.action"
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
# $connect route, proxied to ConnectIntegration.
# NOTE(review): no authorization is required to connect — confirm this is acceptable for the POC.
ConnectRoute:
  Type: AWS::ApiGatewayV2::Route
  Properties:
    ApiId: !Ref WebSocketApi
    RouteKey: "$connect"
    OperationName: ConnectRoute
    AuthorizationType: NONE
    Target: !Sub "integrations/${ConnectIntegration}"
# Lambda proxy integration that invokes ConnectFunction.
ConnectIntegration:
  Type: AWS::ApiGatewayV2::Integration
  Properties:
    ApiId: !Ref WebSocketApi
    IntegrationType: AWS_PROXY
    Description: Connect Integration
    IntegrationUri: !Sub "arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31/functions/${ConnectFunction.Arn}/invocations"
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
# $disconnect route, proxied to DisconnectIntegration.
DisconnectRoute:
  Type: AWS::ApiGatewayV2::Route
  Properties:
    ApiId: !Ref WebSocketApi
    RouteKey: "$disconnect"
    OperationName: DisconnectRoute
    AuthorizationType: NONE
    Target: !Sub "integrations/${DisconnectIntegration}"
# Lambda proxy integration that invokes DisconnectFunction.
DisconnectIntegration:
  Type: AWS::ApiGatewayV2::Integration
  Properties:
    ApiId: !Ref WebSocketApi
    IntegrationType: AWS_PROXY
    Description: Disconnect Integration
    IntegrationUri: !Sub "arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31/functions/${DisconnectFunction.Arn}/invocations"
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
# Custom route selected when a message carries "action": "generate".
GenerateRoute:
  Type: AWS::ApiGatewayV2::Route
  Properties:
    ApiId: !Ref WebSocketApi
    RouteKey: generate
    OperationName: GenerateRoute
    AuthorizationType: NONE
    Target: !Sub "integrations/${GenerateIntegration}"
# Lambda proxy integration that forwards "generate" messages to EventBridgeProxyFunction.
GenerateIntegration:
  Type: AWS::ApiGatewayV2::Integration
  Properties:
    ApiId: !Ref WebSocketApi
    IntegrationType: AWS_PROXY
    Description: forward Integration
    IntegrationUri: !Sub "arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31/functions/${EventBridgeProxyFunction.Arn}/invocations"
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
# Go (provided.al2023) function invoked by GenerateIntegration. It may only put
# events on PostToImageEventBus, whose name it receives via EVENT_BUS_NAME.
EventBridgeProxyFunction:
  Type: AWS::Serverless::Function
  Properties:
    CodeUri: eventbridge-proxy/
    Runtime: provided.al2023
    Handler: bootstrap
    Environment:
      Variables:
        EVENT_BUS_NAME: !Ref PostToImageEventBus
    Policies:
      - Statement:
          - Effect: Allow
            Action: events:PutEvents
            Resource: !Sub "arn:aws:events:${AWS::Region}:${AWS::AccountId}:event-bus/${PostToImageEventBus}"
# Allows API Gateway to invoke EventBridgeProxyFunction, but only from WebSocketApi.
EventBridgeProxyPermission:
  Type: AWS::Lambda::Permission
  DependsOn:
    - WebSocketApi
  Properties:
    Action: lambda:InvokeFunction
    FunctionName: !Ref EventBridgeProxyFunction
    Principal: apigateway.amazonaws.com
    # Without SourceArn, any API Gateway API (including ones in other accounts)
    # could be configured to invoke this function.
    SourceArn: !Sub "arn:aws:execute-api:${AWS::Region}:${AWS::AccountId}:${WebSocketApi}/*"
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
# Go (provided.al2023) function behind the $connect route; code lives in connect/.
ConnectFunction:
  Type: AWS::Serverless::Function
  Properties:
    CodeUri: connect/
    Runtime: provided.al2023
    Handler: bootstrap
# Allows API Gateway to invoke ConnectFunction, but only from WebSocketApi.
ConnectPermission:
  Type: AWS::Lambda::Permission
  DependsOn:
    - WebSocketApi
  Properties:
    Action: lambda:InvokeFunction
    FunctionName: !Ref ConnectFunction
    Principal: apigateway.amazonaws.com
    # Without SourceArn, any API Gateway API (including ones in other accounts)
    # could be configured to invoke this function.
    SourceArn: !Sub "arn:aws:execute-api:${AWS::Region}:${AWS::AccountId}:${WebSocketApi}/*"
# Go (provided.al2023) function behind the $disconnect route; code lives in disconnect/.
DisconnectFunction:
  Type: AWS::Serverless::Function
  Properties:
    CodeUri: disconnect/
    Runtime: provided.al2023
    Handler: bootstrap
# Allows API Gateway to invoke DisconnectFunction, but only from WebSocketApi.
DisconnectPermission:
  Type: AWS::Lambda::Permission
  DependsOn:
    - WebSocketApi
  Properties:
    Action: lambda:InvokeFunction
    FunctionName: !Ref DisconnectFunction
    Principal: apigateway.amazonaws.com
    # Without SourceArn, any API Gateway API (including ones in other accounts)
    # could be configured to invoke this function.
    SourceArn: !Sub "arn:aws:execute-api:${AWS::Region}:${AWS::AccountId}:${WebSocketApi}/*"
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
package main
import (
"context"
"fmt"
"github.com/aws/aws-lambda-go/events"
"github.com/aws/aws-lambda-go/lambda"
)
// handler handles the WebSocket connect event (presumably the connect/ function
// wired to $connect). It prints the raw request and always accepts with 200.
func handler(ctx context.Context, request events.APIGatewayWebsocketProxyRequest) (events.APIGatewayProxyResponse, error) {
	fmt.Println("Connect Event:", request)

	var resp events.APIGatewayProxyResponse
	resp.StatusCode = 200
	resp.Body = "Connected"
	return resp, nil
}
// main registers handler with the AWS Lambda Go runtime via lambda.Start.
func main() {
lambda.Start(handler)
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
package main
import (
"context"
"fmt"
"github.com/aws/aws-lambda-go/events"
"github.com/aws/aws-lambda-go/lambda"
)
// handler handles the WebSocket disconnect event (presumably the disconnect/
// function wired to $disconnect). It prints the raw request and replies 200.
func handler(ctx context.Context, request events.APIGatewayWebsocketProxyRequest) (events.APIGatewayProxyResponse, error) {
	fmt.Println("Disconnect Event:", request)
	response := events.APIGatewayProxyResponse{
		StatusCode: 200,
		Body:       "Disconnected",
	}
	return response, nil
}
// main registers handler with the AWS Lambda Go runtime via lambda.Start.
func main() {
lambda.Start(handler)
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
package main
import (
"context"
"encoding/json"
"fmt"
"log"
"os"
"github.com/aws/aws-lambda-go/events"
"github.com/aws/aws-lambda-go/lambda"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/eventbridge"
)
// Payload is the JSON body of an incoming WebSocket message, as decoded by
// handler. Keys without a matching field (e.g. "action") are ignored by json.Unmarshal.
type Payload struct {
// Text is the post text the image should be generated from.
Text string `json:"text"`
// S3Key is the object key supplied by the client.
S3Key string `json:"s3Key"`
// Bedrock mirrors the "Use Bedrock?" toggle in the frontend.
Bedrock bool `json:"bedrock"`
}
// EventDetail is the JSON Detail of the "PreparePrompt" event published to
// EventBridge: the Payload fields plus the caller's WebSocket connection ID,
// which lets a downstream consumer post its result back to the same client.
type EventDetail struct {
ConnectionId string `json:"connectionId"`
Text string `json:"text"`
S3Key string `json:"s3Key"`
Bedrock bool `json:"bedrock"`
}
// svc is the EventBridge client used by handler. It is built once during
// package initialisation and reused for every invocation served by this process.
var svc *eventbridge.EventBridge

func init() {
	// session.Must panics if the AWS session cannot be created.
	svc = eventbridge.New(session.Must(session.NewSession()))
}
// handler forwards a WebSocket "generate" message to EventBridge.
//
// It decodes the message body into a Payload, tags it with the caller's
// WebSocket connection ID, and publishes it as a "PreparePrompt" event from
// source "api-gateway" on the bus named by EVENT_BUS_NAME.
//
// Responses:
//   - 400 when the body is not valid JSON.
//   - 500 when the event cannot be marshaled or published, including when
//     EventBridge rejects the entry.
//   - 200 otherwise.
//
// A missing EVENT_BUS_NAME is a deployment error and is also returned as a Lambda error.
func handler(ctx context.Context, request events.APIGatewayWebsocketProxyRequest) (events.APIGatewayProxyResponse, error) {
	eventBusName := os.Getenv("EVENT_BUS_NAME")
	if eventBusName == "" {
		log.Println("Error: EVENT_BUS_NAME environment variable is not set")
		// Misconfiguration is a server-side fault, not a bad request.
		return events.APIGatewayProxyResponse{StatusCode: 500}, fmt.Errorf("missing EVENT_BUS_NAME")
	}

	connectionId := request.RequestContext.ConnectionID

	var payload Payload
	if err := json.Unmarshal([]byte(request.Body), &payload); err != nil {
		log.Printf("Error parsing JSON: %s\n", err)
		return events.APIGatewayProxyResponse{StatusCode: 400}, nil
	}

	eventDetail := EventDetail{
		ConnectionId: connectionId,
		Text:         payload.Text,
		S3Key:        payload.S3Key,
		Bedrock:      payload.Bedrock,
	}
	eventDetailJSON, err := json.Marshal(eventDetail)
	if err != nil {
		log.Printf("Error marshaling event detail: %s\n", err)
		return events.APIGatewayProxyResponse{StatusCode: 500}, nil
	}

	input := &eventbridge.PutEventsInput{
		Entries: []*eventbridge.PutEventsRequestEntry{
			{
				Detail:       aws.String(string(eventDetailJSON)),
				DetailType:   aws.String("PreparePrompt"),
				Source:       aws.String("api-gateway"),
				EventBusName: aws.String(eventBusName),
			},
		},
	}

	// Passing ctx lets the SDK call be cancelled when the invocation deadline is reached.
	result, err := svc.PutEventsWithContext(ctx, input)
	if err != nil {
		log.Printf("Error putting event: %s\n", err)
		return events.APIGatewayProxyResponse{StatusCode: 500}, nil
	}
	// PutEvents reports rejected entries in the response, not as an error,
	// so a failed publish would otherwise be reported as a success.
	if aws.Int64Value(result.FailedEntryCount) > 0 {
		for _, entry := range result.Entries {
			if entry != nil && entry.ErrorCode != nil {
				log.Printf("Error putting event: %s: %s\n", aws.StringValue(entry.ErrorCode), aws.StringValue(entry.ErrorMessage))
			}
		}
		return events.APIGatewayProxyResponse{StatusCode: 500}, nil
	}

	log.Printf("PutEvents result: %v\n", result)
	return events.APIGatewayProxyResponse{StatusCode: 200}, nil
}
// main registers handler with the AWS Lambda Go runtime via lambda.Start.
func main() {
lambda.Start(handler)
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
// Build the S3 virtual-hosted-style URL of the generated object and post it
// to the originating WebSocket connection via the API Gateway Management API.
// NOTE(review): objectKey is not URL-escaped; keys containing spaces or other
// reserved characters would produce a broken URL — confirm keys are URL-safe.
url := fmt.Sprintf("https://%s.s3.%s.amazonaws.com/%s", bucketName, region, objectKey)
apiClient := apigatewaymanagementapi.New(sesh, aws.NewConfig().WithEndpoint(websocketEndpoint))
postToConnectionInput := &apigatewaymanagementapi.PostToConnectionInput{
ConnectionId: aws.String(event.ConnectionId),
// NOTE(review): this JSON is built with Sprintf, so a quote or backslash in url
// yields invalid JSON; json.Marshal(map[string]string{"url": url}) would escape it.
Data: []byte(fmt.Sprintf(`{"url": "%s"}`, url)),
}
_, err = apiClient.PostToConnection(postToConnectionInput)
if err != nil {
log.Println("Error sending response to WebSocket client:", err)
return err
}
log.Println("Successfully sent response to WebSocket client:", event.ConnectionId)