/* SPDX-License-Identifier: GPL-2.0 */
/* Copyright (c) 2017 - 2018 Covalent IO, Inc. http://covalent.io */

#ifndef _LINUX_SKMSG_H
#define _LINUX_SKMSG_H

#include <linux/bpf.h>
#include <linux/filter.h>
#include <linux/scatterlist.h>
#include <linux/skbuff.h>

#include <net/sock.h>
#include <net/tcp.h>
#include <net/strparser.h>

#define MAX_MSG_FRAGS			MAX_SKB_FRAGS

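/* Internal verdict states kept in psock->eval: __SK_DROP and __SK_PASS
 * mirror the SK_DROP/SK_PASS verdicts a BPF program returns, __SK_REDIRECT
 * records a redirect decision, and __SK_NONE means no verdict has been
 * cached yet.
 */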
enum __sk_action {
	__SK_DROP = 0,
	__SK_PASS,
	__SK_REDIRECT,
	__SK_NONE,
};

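/* sk_msg_sg is a ring of scatterlist entries: start/curr/end index into
 * data[] and wrap at MAX_MSG_FRAGS, size is the total byte count held, and
 * copy[i] flags elements whose data is not safe to reference directly
 * (see sk_msg_compute_data_pointers()).
 */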
struct sk_msg_sg {
	u32				start;
	u32				curr;
	u32				end;
	u32				size;
	u32				copybreak;
	bool				copy[MAX_MSG_FRAGS];
	struct scatterlist		data[MAX_MSG_FRAGS];
};

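/* sk_msg is the message view handed to a BPF_PROG_TYPE_SK_MSG program:
 * data/data_end expose the first readable element (set up by
 * sk_msg_compute_data_pointers()), while apply_bytes and cork_bytes carry
 * the state driven by the bpf_msg_apply_bytes()/bpf_msg_cork_bytes()
 * helpers.
 */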
struct sk_msg {
	struct sk_msg_sg		sg;
	void				*data;
	void				*data_end;
	u32				apply_bytes;
	u32				cork_bytes;
	u32				flags;
	struct sk_buff			*skb;
	struct sock			*sk_redir;
	struct sock			*sk;
	struct list_head		list;
};

struct sk_psock_progs {
	struct bpf_prog			*msg_parser;
	struct bpf_prog			*skb_parser;
	struct bpf_prog			*skb_verdict;
};

enum sk_psock_state_bits {
	SK_PSOCK_TX_ENABLED,
};

struct sk_psock_link {
	struct list_head		list;
	struct bpf_map			*map;
	void				*link_raw;
};

struct sk_psock_parser {
	struct strparser		strp;
	bool				enabled;
	void (*saved_data_ready)(struct sock *sk);
};

struct sk_psock_work_state {
	struct sk_buff			*skb;
	u32				len;
	u32				off;
};

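/* sk_psock is the per-socket BPF state hung off sk->sk_user_data: the
 * attached programs, the ingress skb/msg queues, and the socket's original
 * proto callbacks saved by sk_psock_update_proto() so they can be restored
 * on detach.
 */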
struct sk_psock {
	struct sock			*sk;
	struct sock			*sk_redir;
	u32				apply_bytes;
	u32				cork_bytes;
	u32				eval;
	struct sk_msg			*cork;
	struct sk_psock_progs		progs;
	struct sk_psock_parser		parser;
	struct sk_buff_head		ingress_skb;
	struct list_head		ingress_msg;
	unsigned long			state;
	struct list_head		link;
	spinlock_t			link_lock;
	refcount_t			refcnt;
	void (*saved_unhash)(struct sock *sk);
	void (*saved_close)(struct sock *sk, long timeout);
	void (*saved_write_space)(struct sock *sk);
	struct proto			*sk_proto;
	struct sk_psock_work_state	work_state;
	struct work_struct		work;
	union {
		struct rcu_head		rcu;
		struct work_struct	gc;
	};
};

int sk_msg_alloc(struct sock *sk, struct sk_msg *msg, int len,
		 int elem_first_coalesce);
int sk_msg_clone(struct sock *sk, struct sk_msg *dst, struct sk_msg *src,
		 u32 off, u32 len);
void sk_msg_trim(struct sock *sk, struct sk_msg *msg, int len);
int sk_msg_free(struct sock *sk, struct sk_msg *msg);
int sk_msg_free_nocharge(struct sock *sk, struct sk_msg *msg);
void sk_msg_free_partial(struct sock *sk, struct sk_msg *msg, u32 bytes);
void sk_msg_free_partial_nocharge(struct sock *sk, struct sk_msg *msg,
				  u32 bytes);

void sk_msg_return(struct sock *sk, struct sk_msg *msg, int bytes);

int sk_msg_zerocopy_from_iter(struct sock *sk, struct iov_iter *from,
			      struct sk_msg *msg, u32 bytes);
int sk_msg_memcopy_from_iter(struct sock *sk, struct iov_iter *from,
			     struct sk_msg *msg, u32 bytes);

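/* A minimal sketch of the alloc/copy/free flow as a sendmsg-style caller
 * might use it (hypothetical; see tcp_bpf for the real users):
 *
 *	struct sk_msg msg;
 *	int err;
 *
 *	sk_msg_init(&msg);
 *	err = sk_msg_alloc(sk, &msg, len, 0);
 *	if (err)
 *		return err;
 *	err = sk_msg_memcopy_from_iter(sk, from, &msg, len);
 *	if (err)
 *		sk_msg_free(sk, &msg);
 */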
static inline void sk_msg_check_to_free(struct sk_msg *msg, u32 i, u32 bytes)
{
	WARN_ON(i == msg->sg.end && bytes);
}

static inline void sk_msg_apply_bytes(struct sk_psock *psock, u32 bytes)
{
	if (psock->apply_bytes) {
		if (psock->apply_bytes < bytes)
			psock->apply_bytes = 0;
		else
			psock->apply_bytes -= bytes;
	}
}

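/* The iterator helpers below treat sg.data[] as a circular buffer, wrapping
 * the index at MAX_MSG_FRAGS in both directions.
 */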
#define sk_msg_iter_var_prev(var)			\
	do {						\
		if (var == 0)				\
			var = MAX_MSG_FRAGS - 1;	\
		else					\
			var--;				\
	} while (0)

#define sk_msg_iter_var_next(var)			\
	do {						\
		var++;					\
		if (var == MAX_MSG_FRAGS)		\
			var = 0;			\
	} while (0)

#define sk_msg_iter_prev(msg, which)			\
	sk_msg_iter_var_prev(msg->sg.which)

#define sk_msg_iter_next(msg, which)			\
	sk_msg_iter_var_next(msg->sg.which)

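/* Reset the ring metadata (start/curr/end/size/copybreak/copy[]) without
 * touching sg.data[]; this relies on the metadata fields being laid out
 * first in sk_msg_sg, so offsetofend() of the copy[] array covers exactly
 * the metadata.
 */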
static inline void sk_msg_clear_meta(struct sk_msg *msg)
{
	memset(&msg->sg, 0, offsetofend(struct sk_msg_sg, copy));
}

static inline void sk_msg_init(struct sk_msg *msg)
{
	memset(msg, 0, sizeof(*msg));
	sg_init_marker(msg->sg.data, ARRAY_SIZE(msg->sg.data));
}

static inline void sk_msg_xfer(struct sk_msg *dst, struct sk_msg *src,
			       int which, u32 size)
{
	dst->sg.data[which] = src->sg.data[which];
	dst->sg.data[which].length = size;
	src->sg.data[which].length -= size;
	src->sg.data[which].offset += size;
}

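/* Because the ring wraps, end == start is ambiguous between empty and full;
 * sk_msg_full() disambiguates via sg.size, and sk_msg_elem_used() handles
 * the wrapped case explicitly.
 */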
static inline u32 sk_msg_elem_used(const struct sk_msg *msg)
{
	return msg->sg.end >= msg->sg.start ?
		msg->sg.end - msg->sg.start :
		msg->sg.end + (MAX_MSG_FRAGS - msg->sg.start);
}

static inline bool sk_msg_full(const struct sk_msg *msg)
{
	return (msg->sg.end == msg->sg.start) && msg->sg.size;
}

static inline struct scatterlist *sk_msg_elem(struct sk_msg *msg, int which)
{
	return &msg->sg.data[which];
}

static inline struct page *sk_msg_page(struct sk_msg *msg, int which)
{
	return sg_page(sk_msg_elem(msg, which));
}

static inline bool sk_msg_to_ingress(const struct sk_msg *msg)
{
	return msg->flags & BPF_F_INGRESS;
}

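/* Expose the start element to a BPF program via msg->data/msg->data_end.
 * If that element is marked for copy, its contents are not safe for direct
 * access and both pointers are set to NULL; a program would then need
 * bpf_msg_pull_data() before reading the data directly.
 */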
static inline void sk_msg_compute_data_pointers(struct sk_msg *msg)
{
	struct scatterlist *sge = sk_msg_elem(msg, msg->sg.start);

	if (msg->sg.copy[msg->sg.start]) {
		msg->data = NULL;
		msg->data_end = NULL;
	} else {
		msg->data = sg_virt(sge);
		msg->data_end = msg->data + sge->length;
	}
}

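/* Append a page fragment at the ring's end slot. A page reference is taken
 * here and dropped again once the element is freed (sk_msg_free() and
 * friends); the element is marked for copy since the caller retains
 * ownership of the page contents.
 */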
static inline void sk_msg_page_add(struct sk_msg *msg, struct page *page,
				   u32 len, u32 offset)
{
	struct scatterlist *sge;

	get_page(page);
	sge = sk_msg_elem(msg, msg->sg.end);
	sg_set_page(sge, page, len, offset);
	sg_unmark_end(sge);

	msg->sg.copy[msg->sg.end] = true;
	msg->sg.size += len;
	sk_msg_iter_next(msg, end);
}

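/* The psock lives in sk->sk_user_data and is RCU protected: sk_psock() is
 * only valid under rcu_read_lock(), and callers that need the pointer past
 * the read section must take a reference via sk_psock_get().
 */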
static inline struct sk_psock *sk_psock(const struct sock *sk)
{
	return rcu_dereference_sk_user_data(sk);
}

static inline bool sk_has_psock(struct sock *sk)
{
	return sk_psock(sk) != NULL && sk->sk_prot->recvmsg == tcp_bpf_recvmsg;
}

static inline void sk_psock_queue_msg(struct sk_psock *psock,
				      struct sk_msg *msg)
{
	list_add_tail(&msg->list, &psock->ingress_msg);
}

static inline void sk_psock_report_error(struct sk_psock *psock, int err)
{
	struct sock *sk = psock->sk;

	sk->sk_err = err;
	sk->sk_error_report(sk);
}

struct sk_psock *sk_psock_init(struct sock *sk, int node);

int sk_psock_init_strp(struct sock *sk, struct sk_psock *psock);
void sk_psock_start_strp(struct sock *sk, struct sk_psock *psock);
void sk_psock_stop_strp(struct sock *sk, struct sk_psock *psock);

int sk_psock_msg_verdict(struct sock *sk, struct sk_psock *psock,
			 struct sk_msg *msg);

static inline struct sk_psock_link *sk_psock_init_link(void)
{
	return kzalloc(sizeof(struct sk_psock_link),
		       GFP_ATOMIC | __GFP_NOWARN);
}

static inline void sk_psock_free_link(struct sk_psock_link *link)
{
	kfree(link);
}

struct sk_psock_link *sk_psock_link_pop(struct sk_psock *psock);
#if defined(CONFIG_BPF_STREAM_PARSER)
void sk_psock_unlink(struct sock *sk, struct sk_psock_link *link);
#else
static inline void sk_psock_unlink(struct sock *sk,
				   struct sk_psock_link *link)
{
}
#endif

void __sk_psock_purge_ingress_msg(struct sk_psock *psock);

static inline void sk_psock_cork_free(struct sk_psock *psock)
{
	if (psock->cork) {
		sk_msg_free(psock->sk, psock->cork);
		kfree(psock->cork);
		psock->cork = NULL;
	}
}

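/* Swap in the BPF-aware proto ops while remembering the socket's original
 * callbacks; sk_psock_restore_proto() undoes this on detach. Callers are
 * expected to serialize the two updates, presumably under the socket lock.
 */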
static inline void sk_psock_update_proto(struct sock *sk,
					 struct sk_psock *psock,
					 struct proto *ops)
{
	psock->saved_unhash = sk->sk_prot->unhash;
	psock->saved_close = sk->sk_prot->close;
	psock->saved_write_space = sk->sk_write_space;

	psock->sk_proto = sk->sk_prot;
	sk->sk_prot = ops;
}

static inline void sk_psock_restore_proto(struct sock *sk,
					  struct sk_psock *psock)
{
	if (psock->sk_proto) {
		sk->sk_prot = psock->sk_proto;
		psock->sk_proto = NULL;
	}
}

static inline void sk_psock_set_state(struct sk_psock *psock,
				      enum sk_psock_state_bits bit)
{
	set_bit(bit, &psock->state);
}

static inline void sk_psock_clear_state(struct sk_psock *psock,
					enum sk_psock_state_bits bit)
{
	clear_bit(bit, &psock->state);
}

static inline bool sk_psock_test_state(const struct sk_psock *psock,
				       enum sk_psock_state_bits bit)
{
	return test_bit(bit, &psock->state);
}

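/* Take a reference on the psock if one is attached and still alive. Using
 * refcount_inc_not_zero() under RCU avoids racing with concurrent teardown:
 * a psock whose refcount already hit zero is treated as gone and NULL is
 * returned.
 */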
static inline struct sk_psock *sk_psock_get(struct sock *sk)
{
	struct sk_psock *psock;

	rcu_read_lock();
	psock = sk_psock(sk);
	if (psock && !refcount_inc_not_zero(&psock->refcnt))
		psock = NULL;
	rcu_read_unlock();
	return psock;
}

void sk_psock_stop(struct sock *sk, struct sk_psock *psock);
void sk_psock_destroy(struct rcu_head *rcu);
void sk_psock_drop(struct sock *sk, struct sk_psock *psock);

static inline void sk_psock_put(struct sock *sk, struct sk_psock *psock)
{
	if (refcount_dec_and_test(&psock->refcnt))
		sk_psock_drop(sk, psock);
}

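/* Atomically install @prog (or NULL to clear) and release the reference on
 * whatever program was attached before; xchg() makes the swap safe against
 * concurrent updaters without a lock.
 */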
static inline void psock_set_prog(struct bpf_prog **pprog,
				  struct bpf_prog *prog)
{
	prog = xchg(pprog, prog);
	if (prog)
		bpf_prog_put(prog);
}

static inline void psock_progs_drop(struct sk_psock_progs *progs)
{
	psock_set_prog(&progs->msg_parser, NULL);
	psock_set_prog(&progs->skb_parser, NULL);
	psock_set_prog(&progs->skb_verdict, NULL);
}

#endif /* _LINUX_SKMSG_H */