diff --git a/compiler/fory_compiler/generators/go.py b/compiler/fory_compiler/generators/go.py index b25e63c799..c4cc57e73c 100644 --- a/compiler/fory_compiler/generators/go.py +++ b/compiler/fory_compiler/generators/go.py @@ -21,6 +21,7 @@ from typing import Dict, List, Optional, Set, Tuple, Union as TypingUnion from fory_compiler.generators.base import BaseGenerator, GeneratedFile +from fory_compiler.generators.services.go import GoServiceGeneratorMixin from fory_compiler.frontend.utils import parse_idl_file from fory_compiler.ir.ast import ( Message, @@ -38,7 +39,7 @@ from fory_compiler.ir.types import PrimitiveKind -class GoGenerator(BaseGenerator): +class GoGenerator(GoServiceGeneratorMixin, BaseGenerator): """Generates Go structs with fory tags.""" language_name = "go" @@ -206,6 +207,10 @@ def generate(self) -> List[GeneratedFile]: # Generate a single Go file with all types files.append(self.generate_file()) + # Generate gRPC service stubs if requested + if self.options.grpc: + files.extend(self.generate_services()) + return files def get_package_name(self) -> str: diff --git a/compiler/fory_compiler/generators/services/base.py b/compiler/fory_compiler/generators/services/base.py new file mode 100644 index 0000000000..2bc2bcfb2d --- /dev/null +++ b/compiler/fory_compiler/generators/services/base.py @@ -0,0 +1,54 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"Shared utilities for gRPC service stub generators." + +from enum import Enum +from typing import List, Dict +from fory_compiler.ir.ast import RpcMethod + + +class StreamingMode(Enum): + UNARY = 1 + CLIENT_STREAMING = 2 + SERVER_STREAMING = 3 + BIDIRECTIONAL = 4 + + +def streaming_mode(method: RpcMethod) -> StreamingMode: + """Defines the type of rpc streaming patterns.""" + if not method.client_streaming and not method.server_streaming: + return StreamingMode.UNARY + elif method.client_streaming and not method.server_streaming: + return StreamingMode.CLIENT_STREAMING + elif not method.client_streaming and method.server_streaming: + return StreamingMode.SERVER_STREAMING + else: + return StreamingMode.BIDIRECTIONAL + + +class ImportTracker: + """Accumulates cross-package Go imports for generated service stubs.""" + + def __init__(self): + self._imports: Dict[str, str] = {} + + def add(self, alias: str, import_path: str) -> None: + self._imports[alias] = import_path + + def go_imports(self) -> List[str]: + return sorted(self._imports.values()) diff --git a/compiler/fory_compiler/generators/services/go.py b/compiler/fory_compiler/generators/services/go.py new file mode 100644 index 0000000000..281f6afab3 --- /dev/null +++ b/compiler/fory_compiler/generators/services/go.py @@ -0,0 +1,652 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Go gRPC service code generator.""" + +from typing import List +from fory_compiler.generators.services.base import ( + ImportTracker, + StreamingMode, + streaming_mode, +) +from fory_compiler.generators.base import GeneratedFile +from fory_compiler.ir.ast import Service, NamedType + + +class GoServiceGeneratorMixin: + """Generates Go gRPC service stubs.""" + + def generate_services(self) -> List[GeneratedFile]: + local_services = [ + s for s in self.schema.services if not self.is_imported_type(s) + ] + if not local_services: + return [] + return [self._generate_grpc_file(s) for s in local_services] + + def _generate_grpc_file(self, service: Service) -> GeneratedFile: + lines: List[str] = [] + tracker = ImportTracker() + + # License header + lines.append(self.get_license_header("//")) + lines.append("") + + # Package declaration + lines.append(f"package {self.get_package_name()}") + lines.append("") + + # Imports + # save the placeholder index + import_placeholder_index = len(lines) + + lines.extend(self._generate_client_interface(service, tracker)) + lines.extend(self._generate_client_struct(service)) + lines.extend(self._generate_new_client(service)) + lines.extend(self._generate_client_methods(service, tracker)) + lines.extend(self._generate_stream_types(service, tracker)) + lines.extend(self._generate_server_interface(service, tracker)) + lines.extend(self._generate_unimplemented_server(service, tracker)) + lines.extend(self._generate_server_stream_types(service, tracker)) + lines.extend(self._generate_service_desc(service, tracker)) + lines.extend(self._generate_register_server(service)) + + # insert the import block at the saved placeholder index + import_lines = self._build_import_block(tracker) + for i, line in enumerate(import_lines): + lines.insert(import_placeholder_index + i, line) + + return GeneratedFile( + path=f"{self.get_file_name()}_grpc.go", content="\n".join(lines) + ) + + def _build_import_block(self, tracker: ImportTracker) -> List[str]: + imports = [ + '"context"', + '"google.golang.org/grpc"', + '"google.golang.org/grpc/codes"', + '"google.golang.org/grpc/status"', + '"github.com/apache/fory/go/fory"', + ] + + for alias, path in tracker._imports.items(): + imports.append(f'{alias} "{path}"') + + sorted_imports = sorted(set(imports)) + + lines = ["import ("] + for imp in sorted_imports: + lines.append(f"\t{imp}") + lines.append(")") + lines.append("") + + return lines + + def _resolve_go_type(self, named_type: NamedType, tracker: ImportTracker) -> str: + type_ref = self.schema.resolve_type_name(named_type.name) + type_def = self.schema.get_type(type_ref) + if type_def is not None and self.is_imported_type(type_def): + info = self._import_info_for_type(type_def) + if info: + alias, import_path, _ = info + tracker.add(alias, import_path) + return f"*{alias}.{type_ref}" + return f"*{type_ref}" + + def _generate_client_interface( + self, service: Service, tracker: ImportTracker + ) -> List[str]: + lines: List[str] = [] + lines.append( + f"// {service.name}Client is the client API for {service.name} service." + ) + lines.append(f"type {service.name}Client interface {{") + for method in service.methods: + req_type = self._resolve_go_type(method.request_type, tracker) + res_type = self._resolve_go_type(method.response_type, tracker) + mode = streaming_mode(method) + if mode is StreamingMode.UNARY: + signature = f"(ctx context.Context, in {req_type}, opts ...grpc.CallOption) ({res_type}, error)" + lines.append(f"\t{self.to_pascal_case(method.name)}{signature}") + elif mode is StreamingMode.SERVER_STREAMING: + signature = f"(ctx context.Context, in {req_type}, opts ...grpc.CallOption) ({service.name}_{self.to_pascal_case(method.name)}Client, error)" + lines.append(f"\t{self.to_pascal_case(method.name)}{signature}") + else: + signature = f"(ctx context.Context, opts ...grpc.CallOption) ({service.name}_{self.to_pascal_case(method.name)}Client, error)" + lines.append(f"\t{self.to_pascal_case(method.name)}{signature}") + + lines.append("}") + lines.append("") + return lines + + def _generate_client_struct(self, service: Service) -> List[str]: + lines: List[str] = [] + lines.append(f"type {self.to_camel_case(service.name)}Client struct {{") + lines.append("\tcc grpc.ClientConnInterface") + lines.append("\tfory *fory.Fory") + lines.append("}") + lines.append("") + return lines + + def _generate_new_client(self, service: Service) -> List[str]: + lines: List[str] = [] + lines.append( + f"func New{service.name}Client(cc grpc.ClientConnInterface, f *fory.Fory) {service.name}Client {{" + ) + lines.append( + f"\treturn &{self.to_camel_case(service.name)}Client{{cc: cc, fory: f}}" + ) + lines.append("}") + lines.append("") + return lines + + def _generate_client_methods( + self, service: Service, tracker: ImportTracker + ) -> List[str]: + lines: List[str] = [] + tracker.add("forygrpc", "github.com/apache/fory/go/fory/grpc") + stream_index = 0 + for method in service.methods: + req_type = self._resolve_go_type(method.request_type, tracker) + res_type = self._resolve_go_type(method.response_type, tracker) + mode = streaming_mode(method) + if mode is StreamingMode.UNARY: + lines.append( + f"func (c *{self.to_camel_case(service.name)}Client) {self.to_pascal_case(method.name)}(ctx context.Context, in {req_type}, opts ...grpc.CallOption) ({res_type}, error) {{" + ) + lines.append(f"\tout := new({res_type[1:]})") + lines.append( + f'\terr := c.cc.Invoke(ctx, "{self.get_grpc_method_path(service, method)}", in, out, grpc.ForceCodecV2(forygrpc.CodecV2{{Fory: c.fory}}), opts...)' + ) + lines.append("\tif err != nil {") + lines.append("\t\treturn nil, err") + lines.append("\t}") + lines.append("\treturn out, nil") + lines.append("}") + lines.append("") + elif mode is StreamingMode.SERVER_STREAMING: + lines.append( + f"func (c *{self.to_camel_case(service.name)}Client) {self.to_pascal_case(method.name)}(ctx context.Context, in {req_type}, opts ...grpc.CallOption) ({service.name}_{self.to_pascal_case(method.name)}Client, error) {{" + ) + lines.append( + f'\tstream, err := c.cc.NewStream(ctx, &_{service.name}_serviceDesc.Streams[{stream_index}], "{self.get_grpc_method_path(service, method)}", grpc.ForceCodecV2(forygrpc.CodecV2{{Fory: c.fory}}), opts...)' + ) + lines.append("\tif err != nil {") + lines.append("\t\treturn nil, err") + lines.append("\t}") + lines.append( + f"\tx := &{self.to_camel_case(service.name)}{self.to_pascal_case(method.name)}Client{{stream}}" + ) + lines.append("\tif err := x.SendMsg(in); err != nil {") + lines.append("\t\treturn nil, err") + lines.append("\t}") + lines.append("\tif err := x.CloseSend(); err != nil {") + lines.append("\t\treturn nil, err") + lines.append("\t}") + lines.append("\treturn x, nil") + lines.append("}") + lines.append("") + stream_index += 1 + else: + lines.append( + f"func (c *{self.to_camel_case(service.name)}Client) {self.to_pascal_case(method.name)}(ctx context.Context, opts ...grpc.CallOption) ({service.name}_{self.to_pascal_case(method.name)}Client, error) {{" + ) + lines.append( + f'\tstream, err := c.cc.NewStream(ctx, &_{service.name}_serviceDesc.Streams[{stream_index}], "{self.get_grpc_method_path(service, method)}", grpc.ForceCodecV2(forygrpc.CodecV2{{Fory: c.fory}}), opts...)' + ) + lines.append("\tif err != nil {") + lines.append("\t\treturn nil, err") + lines.append("\t}") + lines.append( + f"\treturn &{self.to_camel_case(service.name)}{self.to_pascal_case(method.name)}Client{{stream}}, nil" + ) + lines.append("}") + lines.append("") + stream_index += 1 + return lines + + def _generate_stream_types( + self, service: Service, tracker: ImportTracker + ) -> List[str]: + lines: List[str] = [] + for method in service.methods: + req_type = self._resolve_go_type(method.request_type, tracker) + res_type = self._resolve_go_type(method.response_type, tracker) + mode = streaming_mode(method) + if mode is StreamingMode.UNARY: + continue + if mode is StreamingMode.CLIENT_STREAMING: + # type interface + lines.append( + f"type {service.name}_{self.to_pascal_case(method.name)}Client interface {{" + ) + lines.append(f"\tSend({req_type}) error") + lines.append(f"\tCloseAndRecv() ({res_type}, error)") + lines.append("\tgrpc.ClientStream") + lines.append("}") + lines.append("") + + # struct + lines.append( + f"type {self.to_camel_case(service.name)}{self.to_pascal_case(method.name)}Client struct {{" + ) + lines.append("\tgrpc.ClientStream") + lines.append("}") + lines.append("") + + # methods + lines.append( + f"func (x *{self.to_camel_case(service.name)}{self.to_pascal_case(method.name)}Client) Send(m {req_type}) error {{" + ) + lines.append("\treturn x.ClientStream.SendMsg(m)") + lines.append("}") + lines.append("") + lines.append( + f"func (x *{self.to_camel_case(service.name)}{self.to_pascal_case(method.name)}Client) CloseAndRecv() ({res_type}, error) {{" + ) + lines.append("\tif err := x.ClientStream.CloseSend(); err != nil {") + lines.append("\t\treturn nil, err") + lines.append("\t}") + lines.append(f"\tm := new({res_type[1:]})") + lines.append("\tif err := x.ClientStream.RecvMsg(m); err != nil {") + lines.append("\t\treturn nil, err") + lines.append("\t}") + lines.append("\treturn m, nil") + lines.append("}") + lines.append("") + elif mode is StreamingMode.SERVER_STREAMING: + # type interface + lines.append( + f"type {service.name}_{self.to_pascal_case(method.name)}Client interface {{" + ) + lines.append(f"\tRecv() ({res_type}, error)") + lines.append("\tgrpc.ClientStream") + lines.append("}") + lines.append("") + + # struct + lines.append( + f"type {self.to_camel_case(service.name)}{self.to_pascal_case(method.name)}Client struct {{" + ) + lines.append("\tgrpc.ClientStream") + lines.append("}") + lines.append("") + + # methods + lines.append( + f"func (x *{self.to_camel_case(service.name)}{self.to_pascal_case(method.name)}Client) Recv() ({res_type}, error) {{" + ) + lines.append(f"\tm := new({res_type[1:]})") + lines.append("\tif err := x.ClientStream.RecvMsg(m); err != nil {") + lines.append("\t\treturn nil, err") + lines.append("\t}") + lines.append("\treturn m, nil") + lines.append("}") + lines.append("") + else: + # interface + lines.append( + f"type {service.name}_{self.to_pascal_case(method.name)}Client interface {{" + ) + lines.append(f"\tSend({req_type}) error") + lines.append(f"\tRecv() ({res_type}, error)") + lines.append("\tgrpc.ClientStream") + lines.append("}") + lines.append("") + + # struct + lines.append( + f"type {self.to_camel_case(service.name)}{self.to_pascal_case(method.name)}Client struct {{" + ) + lines.append("\tgrpc.ClientStream") + lines.append("}") + lines.append("") + + # methods + lines.append( + f"func (x *{self.to_camel_case(service.name)}{self.to_pascal_case(method.name)}Client) Send(m {req_type}) error {{" + ) + lines.append("\treturn x.ClientStream.SendMsg(m)") + lines.append("}") + lines.append("") + lines.append( + f"func (x *{self.to_camel_case(service.name)}{self.to_pascal_case(method.name)}Client) Recv() ({res_type}, error) {{" + ) + lines.append(f"\tm := new({res_type[1:]})") + lines.append("\tif err := x.ClientStream.RecvMsg(m); err != nil {") + lines.append("\t\treturn nil, err") + lines.append("\t}") + lines.append("\treturn m, nil") + lines.append("}") + lines.append("") + return lines + + def _generate_server_interface( + self, service: Service, tracker: ImportTracker + ) -> List[str]: + lines: List[str] = [] + lines.append( + f"// {service.name}Server is the server API for {service.name} service." + ) + lines.append( + f"// All implementations must embed Unimplemented{service.name}Server" + ) + lines.append("// for forward compatibility.") + lines.append(f"type {service.name}Server interface {{") + for method in service.methods: + req_type = self._resolve_go_type(method.request_type, tracker) + res_type = self._resolve_go_type(method.response_type, tracker) + mode = streaming_mode(method) + if mode is StreamingMode.UNARY: + lines.append( + f"\t{self.to_pascal_case(method.name)}(context.Context, {req_type}) ({res_type}, error)" + ) + elif mode is StreamingMode.SERVER_STREAMING: + lines.append( + f"\t{self.to_pascal_case(method.name)}({req_type}, {service.name}_{self.to_pascal_case(method.name)}Server) error" + ) + else: + lines.append( + f"\t{self.to_pascal_case(method.name)}({service.name}_{self.to_pascal_case(method.name)}Server) error" + ) + lines.append(f"\tmustEmbedUnimplemented{service.name}Server()") + lines.append("}") + lines.append("") + return lines + + def _generate_unimplemented_server( + self, service: Service, tracker: ImportTracker + ) -> List[str]: + lines: List[str] = [] + lines.append( + f"// Unimplemented{service.name}Server must be embedded to have forward compatible implementation." + ) + lines.append(f"type Unimplemented{service.name}Server struct {{}}") + lines.append("") + lines.append( + f"func (Unimplemented{service.name}Server) mustEmbedUnimplemented{service.name}Server() {{}}" + ) + lines.append("") + for method in service.methods: + req_type = self._resolve_go_type(method.request_type, tracker) + res_type = self._resolve_go_type(method.response_type, tracker) + mode = streaming_mode(method) + if mode is StreamingMode.UNARY: + lines.append( + f"func (Unimplemented{service.name}Server) {self.to_pascal_case(method.name)}(context.Context, {req_type}) ({res_type}, error) {{" + ) + lines.append( + f'\treturn nil, status.Errorf(codes.Unimplemented, "method {self.to_pascal_case(method.name)} not implemented")' + ) + lines.append("}") + lines.append("") + elif mode is StreamingMode.SERVER_STREAMING: + lines.append( + f"func (Unimplemented{service.name}Server) {self.to_pascal_case(method.name)}({req_type}, {service.name}_{self.to_pascal_case(method.name)}Server) error {{" + ) + lines.append( + f'\treturn status.Errorf(codes.Unimplemented, "method {self.to_pascal_case(method.name)} not implemented")' + ) + lines.append("}") + lines.append("") + else: + lines.append( + f"func (Unimplemented{service.name}Server) {self.to_pascal_case(method.name)}({service.name}_{self.to_pascal_case(method.name)}Server) error {{" + ) + lines.append( + f'\treturn status.Errorf(codes.Unimplemented, "method {self.to_pascal_case(method.name)} not implemented")' + ) + lines.append("}") + lines.append("") + return lines + + def _generate_register_server(self, service: Service) -> List[str]: + lines: List[str] = [] + lines.append( + f"func Register{service.name}Server(s grpc.ServiceRegistrar, srv {service.name}Server) {{" + ) + lines.append(f"\ts.RegisterService(&_{service.name}_serviceDesc, srv)") + lines.append("}") + lines.append("") + return lines + + def _generate_server_stream_types( + self, service: Service, tracker: ImportTracker + ) -> List[str]: + lines: List[str] = [] + for method in service.methods: + req_type = self._resolve_go_type(method.request_type, tracker) + res_type = self._resolve_go_type(method.response_type, tracker) + mode = streaming_mode(method) + if mode is StreamingMode.UNARY: + continue + elif mode is StreamingMode.CLIENT_STREAMING: + # type interface + lines.append( + f"type {service.name}_{self.to_pascal_case(method.name)}Server interface {{" + ) + lines.append(f"\tRecv() ({req_type}, error)") + lines.append(f"\tSendAndClose({res_type}) error") + lines.append("\tgrpc.ServerStream") + lines.append("}") + lines.append("") + + # struct + lines.append( + f"type {self.to_camel_case(service.name)}{self.to_pascal_case(method.name)}Server struct {{" + ) + lines.append("\tgrpc.ServerStream") + lines.append("}") + lines.append("") + + # methods + lines.append( + f"func (x *{self.to_camel_case(service.name)}{self.to_pascal_case(method.name)}Server) Recv() ({req_type}, error) {{" + ) + lines.append(f"\tm := new({req_type[1:]})") + lines.append("\tif err := x.ServerStream.RecvMsg(m); err != nil {") + lines.append("\t\treturn nil, err") + lines.append("\t}") + lines.append("\treturn m, nil") + lines.append("}") + lines.append("") + lines.append( + f"func (x *{self.to_camel_case(service.name)}{self.to_pascal_case(method.name)}Server) SendAndClose(m {res_type}) error {{" + ) + lines.append("\treturn x.ServerStream.SendMsg(m)") + lines.append("}") + lines.append("") + elif mode is StreamingMode.SERVER_STREAMING: + # type interface + lines.append( + f"type {service.name}_{self.to_pascal_case(method.name)}Server interface {{" + ) + lines.append(f"\tSend({res_type}) error") + lines.append("\tgrpc.ServerStream") + lines.append("}") + lines.append("") + + # struct + lines.append( + f"type {self.to_camel_case(service.name)}{self.to_pascal_case(method.name)}Server struct {{" + ) + lines.append("\tgrpc.ServerStream") + lines.append("}") + lines.append("") + + # methods + lines.append( + f"func (x *{self.to_camel_case(service.name)}{self.to_pascal_case(method.name)}Server) Send(m {res_type}) error {{" + ) + lines.append("\treturn x.ServerStream.SendMsg(m)") + lines.append("}") + lines.append("") + else: + # type interface + lines.append( + f"type {service.name}_{self.to_pascal_case(method.name)}Server interface {{" + ) + lines.append(f"\tSend({res_type}) error") + lines.append(f"\tRecv() ({req_type}, error)") + lines.append("\tgrpc.ServerStream") + lines.append("}") + lines.append("") + + # struct + lines.append( + f"type {self.to_camel_case(service.name)}{self.to_pascal_case(method.name)}Server struct {{" + ) + lines.append("\tgrpc.ServerStream") + lines.append("}") + lines.append("") + + # methods + lines.append( + f"func (x *{self.to_camel_case(service.name)}{self.to_pascal_case(method.name)}Server) Send(m {res_type}) error {{" + ) + lines.append("\treturn x.ServerStream.SendMsg(m)") + lines.append("}") + lines.append("") + lines.append( + f"func (x *{self.to_camel_case(service.name)}{self.to_pascal_case(method.name)}Server) Recv() ({req_type}, error) {{" + ) + lines.append(f"\tm := new({req_type[1:]})") + lines.append("\tif err := x.ServerStream.RecvMsg(m); err != nil {") + lines.append("\t\treturn nil, err") + lines.append("\t}") + lines.append("\treturn m, nil") + lines.append("}") + lines.append("") + return lines + + def _generate_service_desc( + self, service: Service, tracker: ImportTracker + ) -> List[str]: + lines: List[str] = [] + for method in service.methods: + req_type = self._resolve_go_type(method.request_type, tracker) + mode = streaming_mode(method) + # handlers + if mode is StreamingMode.UNARY: + lines.append( + f"func _{service.name}_{self.to_pascal_case(method.name)}_Handler(srv interface{{}}, ctx context.Context, dec func(interface{{}}) error, interceptor grpc.UnaryServerInterceptor) (interface{{}}, error) {{" + ) + lines.append(f"\tin := new({req_type[1:]})") + lines.append("\tif err := dec(in); err != nil {") + lines.append("\t\treturn nil, err") + lines.append("\t}") + lines.append("\tif interceptor == nil {") + lines.append( + f"\t\treturn srv.({service.name}Server).{self.to_pascal_case(method.name)}(ctx, in)" + ) + lines.append("\t}") + lines.append("\tinfo := &grpc.UnaryServerInfo{") + lines.append("\t\tServer:\tsrv,") + lines.append( + f'\t\tFullMethod:\t"{self.get_grpc_method_path(service, method)}",' + ) + lines.append("\t}") + lines.append( + "\thandler := func(ctx context.Context, req interface{}) (interface{}, error) {" + ) + lines.append( + f"\t\treturn srv.({service.name}Server).{self.to_pascal_case(method.name)}(ctx, req.({req_type}))" + ) + lines.append("\t}") + lines.append("\treturn interceptor(ctx, in, info, handler)") + lines.append("}") + lines.append("") + elif mode is StreamingMode.SERVER_STREAMING: + lines.append( + f"func _{service.name}_{self.to_pascal_case(method.name)}_Handler(srv interface{{}}, stream grpc.ServerStream) error {{" + ) + lines.append(f"\tm := new({req_type[1:]})") + lines.append("\tif err := stream.RecvMsg(m); err != nil {") + lines.append("\t\treturn err") + lines.append("\t}") + lines.append( + f"\treturn srv.({service.name}Server).{self.to_pascal_case(method.name)}(m, &{self.to_camel_case(service.name)}{self.to_pascal_case(method.name)}Server{{stream}})" + ) + lines.append("}") + lines.append("") + else: + lines.append( + f"func _{service.name}_{self.to_pascal_case(method.name)}_Handler(srv interface{{}}, stream grpc.ServerStream) error {{" + ) + lines.append( + f"\treturn srv.({service.name}Server).{self.to_pascal_case(method.name)}(&{self.to_camel_case(service.name)}{self.to_pascal_case(method.name)}Server{{stream}})" + ) + lines.append("}") + lines.append("") + # var + lines.append(f"var _{service.name}_serviceDesc = grpc.ServiceDesc{{") + lines.append(f'\tServiceName: "{self.get_grpc_service_name(service)}",') + lines.append(f"\tHandlerType: (*{service.name}Server)(nil),") + # unary type service descriptors + lines.append("\tMethods: []grpc.MethodDesc{") + lines.extend(self._generate_unary_type_desc(service)) + lines.append("\t},") + # stream type service descriptors + lines.append("\tStreams: []grpc.StreamDesc{") + lines.extend(self._generate_stream_type_desc(service)) + lines.append("\t},") + lines.append(f'\tMetadata: "{self.get_file_name()}.fdl",') + lines.append("}") + lines.append("") + return lines + + def _generate_unary_type_desc(self, service: Service) -> List[str]: + lines: List[str] = [] + for method in service.methods: + mode = streaming_mode(method) + if mode is StreamingMode.UNARY: + lines.append("\t\t{") + lines.append( + f'\t\t\tMethodName:\t"{self.to_pascal_case(method.name)}",' + ) + lines.append( + f"\t\t\tHandler:\t_{service.name}_{self.to_pascal_case(method.name)}_Handler," + ) + lines.append("\t\t},") + return lines + + def _generate_stream_type_desc(self, service: Service) -> List[str]: + lines: List[str] = [] + for method in service.methods: + mode = streaming_mode(method) + if mode is StreamingMode.UNARY: + continue + else: + lines.append("\t\t{") + lines.append( + f'\t\t\tStreamName:\t"{self.to_pascal_case(method.name)}",' + ) + lines.append( + f"\t\t\tHandler:\t_{service.name}_{self.to_pascal_case(method.name)}_Handler," + ) + if ( + mode is StreamingMode.CLIENT_STREAMING + or mode is StreamingMode.BIDIRECTIONAL + ): + lines.append("\t\t\tClientStreams:\ttrue,") + if ( + mode is StreamingMode.SERVER_STREAMING + or mode is StreamingMode.BIDIRECTIONAL + ): + lines.append("\t\t\tServerStreams:\ttrue,") + lines.append("\t\t},") + return lines diff --git a/compiler/fory_compiler/tests/test_service_codegen.py b/compiler/fory_compiler/tests/test_service_codegen.py index 413fbf0131..a1fb686817 100644 --- a/compiler/fory_compiler/tests/test_service_codegen.py +++ b/compiler/fory_compiler/tests/test_service_codegen.py @@ -129,7 +129,7 @@ def test_service_definition_does_not_affect_message_codegen(): def test_generate_services_returns_empty_list_for_unsupported_generators(): schema = parse_fdl(_GREETER_WITH_SERVICE) for generator_cls in GENERATOR_CLASSES: - if generator_cls in (JavaGenerator, PythonGenerator): + if generator_cls in (JavaGenerator, PythonGenerator, GoGenerator): continue options = GeneratorOptions(output_dir=Path("/tmp")) generator = generator_cls(schema, options) @@ -218,6 +218,27 @@ def test_grpc_streaming_method_shapes(): assert "self.client = channel.stream_unary(" in python assert "self.bidi = channel.stream_stream(" in python + go = next(iter(generate_service_files(schema, GoGenerator).values())) + assert "ClientStreams:\ttrue" in go + assert "ServerStreams:\ttrue" in go + assert "grpc.ClientStream" in go + assert "grpc.ServerStream" in go + + +def test_go_grpc_service_codegen(): + schema = parse_fdl(_GREETER_WITH_SERVICE) + files = generate_service_files(schema, GoGenerator) + assert len(files) == 1 + content = next(iter(files.values())) + assert "func NewGreeterClient(" in content + assert "func RegisterGreeterServer(" in content + assert "type GreeterClient interface" in content + assert "type GreeterServer interface" in content + assert "type UnimplementedGreeterServer struct" in content + assert "forygrpc.CodecV2{Fory: c.fory}" in content + assert '"/demo.greeter.Greeter/SayHello"' in content + assert "mustEmbedUnimplementedGreeterServer()" in content + def test_java_outer_classname_service_references_nested_model_types(): schema = parse_fdl( diff --git a/go/fory/grpc/codec.go b/go/fory/grpc/codec.go new file mode 100644 index 0000000000..99b590a69a --- /dev/null +++ b/go/fory/grpc/codec.go @@ -0,0 +1,54 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package forygrpc + +import ( + "github.com/apache/fory/go/fory" + "google.golang.org/grpc/mem" +) + +// CodecV2 implements grpc/encoding.CodecV2, replacing the default protobuf +// codec with Fory binary serialization on both client and server sides. +// Pass a configured *fory.Fory instance with all message types registered. +type CodecV2 struct { + Fory *fory.Fory +} + +// Marshal serializes the message using Fory and wraps the resulting bytes +// in a single-buffer BufferSlice for gRPC transport. +func (c CodecV2) Marshal(v any) (mem.BufferSlice, error) { + b, err := c.Fory.Marshal(v) + if err != nil { + return nil, err + } + return mem.BufferSlice{mem.NewBuffer(&b, nil)}, nil +} + +// Unmarshal materializes the incoming buffer slice into a contiguous byte +// slice and deserializes it into v using Fory. +func (c CodecV2) Unmarshal(data mem.BufferSlice, v any) error { + buf := data.MaterializeToBuffer(mem.DefaultBufferPool()) + defer buf.Free() + return c.Fory.Unmarshal(buf.ReadOnlyData(), v) +} + +// Name returns the codec identifier registered with gRPC. Using "fory" +// ensures this codec does not conflict with the default "proto" codec. +func (CodecV2) Name() string { + return "fory" +} diff --git a/go/fory/grpc/codec_test.go b/go/fory/grpc/codec_test.go new file mode 100644 index 0000000000..65aa8519ae --- /dev/null +++ b/go/fory/grpc/codec_test.go @@ -0,0 +1,80 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package forygrpc + +import ( + "testing" + + "github.com/apache/fory/go/fory" + "google.golang.org/grpc/mem" +) + +// testMessage is a simple struct registered with Fory for use across all codec tests. +type testMessage struct { + Name string + Value int32 +} + +// newTestFory creates a Fory instance with testMessage registered under type ID 200. +func newTestFory(t *testing.T) *fory.Fory { + t.Helper() + f := fory.New(fory.WithXlang(false)) + if err := f.RegisterStruct(testMessage{}, 200); err != nil { + t.Fatalf("RegisterStruct: %v", err) + } + return f +} + +// TestRoundTrip verifies that a message marshaled by CodecV2 can be fully +// recovered by Unmarshal with all fields intact. +func TestRoundTrip(t *testing.T) { + codec := CodecV2{Fory: newTestFory(t)} + original := &testMessage{Name: "hello", Value: 42} + + buf, err := codec.Marshal(original) + if err != nil { + t.Fatalf("Marshal: %v", err) + } + + got := &testMessage{} + if err := codec.Unmarshal(buf, got); err != nil { + t.Fatalf("Unmarshal: %v", err) + } + + if got.Name != original.Name || got.Value != original.Value { + t.Errorf("round-trip mismatch: got %+v, want %+v", got, original) + } +} + +// TestUnmarshalError verifies that Unmarshal returns an error on corrupt input +// rather than silently producing a zero-value result. +func TestUnmarshalError(t *testing.T) { + codec := CodecV2{Fory: newTestFory(t)} + garbage := []byte{0xFF, 0xFE, 0x00, 0x01} + data := mem.BufferSlice{mem.NewBuffer(&garbage, nil)} + if err := codec.Unmarshal(data, &testMessage{}); err == nil { + t.Error("expected error unmarshaling corrupt data, got nil") + } +} + +// TestName verifies the codec identifier matches the value used in grpc.ForceCodecV2 calls. +func TestName(t *testing.T) { + if name := (CodecV2{}).Name(); name != "fory" { + t.Errorf("Name() = %q, want %q", name, "fory") + } +} diff --git a/go/fory/grpc/go.mod b/go/fory/grpc/go.mod new file mode 100644 index 0000000000..47696f94b1 --- /dev/null +++ b/go/fory/grpc/go.mod @@ -0,0 +1,29 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +module github.com/apache/fory/go/fory/grpc + +go 1.24.0 + +require ( + github.com/apache/fory/go/fory v0.0.0 + google.golang.org/grpc v1.68.0 +) + +require golang.org/x/sys v0.25.0 // indirect + +replace github.com/apache/fory/go/fory => ../ diff --git a/go/fory/grpc/go.sum b/go/fory/grpc/go.sum new file mode 100644 index 0000000000..d0471f5332 --- /dev/null +++ b/go/fory/grpc/go.sum @@ -0,0 +1,22 @@ +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= +github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +golang.org/x/net v0.29.0 h1:5ORfpBpCs4HzDYoodCDBbwHzdR5UrLBZ3sOnUJmFoHo= +golang.org/x/net v0.29.0/go.mod h1:gLkgy8jTGERgjzMic6DS9+SP0ajcu6Xu3Orq/SpETg0= +golang.org/x/sys v0.25.0 h1:r+8e+loiHxRqhXVl6ML1nO3l1+oFoWbnlu2Ehimmi34= +golang.org/x/sys v0.25.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/text v0.18.0 h1:XvMDiNzPAl0jr17s6W9lcaIhGUfUORdGCNsuLmPG224= +golang.org/x/text v0.18.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY= +google.golang.org/genproto/googleapis/rpc v0.0.0-20240903143218-8af14fe29dc1 h1:pPJltXNxVzT4pK9yD8vR9X75DaWYYmLGMsEvBfFQZzQ= +google.golang.org/genproto/googleapis/rpc v0.0.0-20240903143218-8af14fe29dc1/go.mod h1:UqMtugtsSgubUsoxbuAoiCXvqvErP7Gf0so0mK9tHxU= +google.golang.org/grpc v1.68.0 h1:aHQeeJbo8zAkAa3pRzrVjZlbz6uSfeOXlJNQM0RAbz0= +google.golang.org/grpc v1.68.0/go.mod h1:fmSPC5AsjSBCK54MyHRx48kpOti1/jRfOlwEWywNjWA= +google.golang.org/protobuf v1.34.2 h1:6xV6lTsCfpGD21XK49h7MhtcApnLqkfYgPcdHftf6hg= +google.golang.org/protobuf v1.34.2/go.mod h1:qYOHts0dSfpeUzUFpOMr/WGzszTmLH+DiWniOlNbLDw= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=