diff --git a/rpc/compiler/go_generator.cc b/rpc/compiler/go_generator.cc index 21484f3b..bbc5a562 100644 --- a/rpc/compiler/go_generator.cc +++ b/rpc/compiler/go_generator.cc @@ -42,6 +42,7 @@ using namespace std; +// TODO(zhaoq): Support go_package option. namespace grpc_go_generator { bool NoStreaming(const google::protobuf::MethodDescriptor* method) { @@ -88,25 +89,42 @@ std::string BadToUnderscore(std::string str) { return str; } -const string GetFullName(const string& selfPkg, - const string& msgPkg, - const string& msgName) { - if (selfPkg == msgPkg) { - return msgName; +string GenerateFullGoPackage(const google::protobuf::FileDescriptor* file) { + // In opensouce environment, assume each directory has at most one package. + size_t pos = file->name().find_last_of('/'); + if (pos != string::npos) { + return file->name().substr(0, pos); } - return BadToUnderscore(msgPkg) + "." + msgName; + return ""; +} + +const string GetFullMessageQualifiedName( + const google::protobuf::Descriptor* desc, + set& imports, + map& import_alias) { + string pkg = GenerateFullGoPackage(desc->file()); + if (imports.find(pkg) == imports.end()) { + // The message is in the same package as the services definition. + return desc->name(); + } + if (import_alias.find(pkg) != import_alias.end()) { + // The message is in a package whose name is as same as the one consisting + // of the service definition. Use the alias to differentiate. + return import_alias[pkg] + "." + desc->name(); + } + return BadToUnderscore(desc->file()->package()) + "." + desc->name(); } void PrintClientMethodDef(google::protobuf::io::Printer* printer, const google::protobuf::MethodDescriptor* method, - map* vars) { + map* vars, + set& imports, + map& import_alias) { (*vars)["Method"] = method->name(); - (*vars)["Request"] = GetFullName((*vars)["PackageName"], - method->input_type()->file()->package(), - method->input_type()->name()); - (*vars)["Response"] = GetFullName((*vars)["PackageName"], - method->output_type()->file()->package(), - method->output_type()->name()); + (*vars)["Request"] = + GetFullMessageQualifiedName(method->input_type(), imports, import_alias); + (*vars)["Response"] = + GetFullMessageQualifiedName(method->output_type(), imports, import_alias); if (NoStreaming(method)) { printer->Print(*vars, "\t$Method$(ctx context.Context, in *$Request$, opts " @@ -130,14 +148,14 @@ void PrintClientMethodDef(google::protobuf::io::Printer* printer, void PrintClientMethodImpl(google::protobuf::io::Printer* printer, const google::protobuf::MethodDescriptor* method, - map* vars) { + map* vars, + set& imports, + map& import_alias) { (*vars)["Method"] = method->name(); - (*vars)["Request"] = GetFullName((*vars)["PackageName"], - method->input_type()->file()->package(), - method->input_type()->name()); - (*vars)["Response"] = GetFullName((*vars)["PackageName"], - method->output_type()->file()->package(), - method->output_type()->name()); + (*vars)["Request"] = + GetFullMessageQualifiedName(method->input_type(), imports, import_alias); + (*vars)["Response"] = + GetFullMessageQualifiedName(method->output_type(), imports, import_alias); if (NoStreaming(method)) { printer->Print( *vars, @@ -279,12 +297,14 @@ void PrintClientMethodImpl(google::protobuf::io::Printer* printer, void PrintClient(google::protobuf::io::Printer* printer, const google::protobuf::ServiceDescriptor* service, - map* vars) { + map* vars, + set& imports, + map& import_alias) { (*vars)["Service"] = service->name(); (*vars)["ServiceStruct"] = LowerCaseService(service->name()); printer->Print(*vars, "type $Service$Client interface {\n"); for (int i = 0; i < service->method_count(); ++i) { - PrintClientMethodDef(printer, service->method(i), vars); + PrintClientMethodDef(printer, service->method(i), vars, imports, import_alias); } printer->Print("}\n\n"); @@ -298,20 +318,20 @@ void PrintClient(google::protobuf::io::Printer* printer, "\treturn &$ServiceStruct$Client{cc}\n" "}\n\n"); for (int i = 0; i < service->method_count(); ++i) { - PrintClientMethodImpl(printer, service->method(i), vars); + PrintClientMethodImpl(printer, service->method(i), vars, imports, import_alias); } } void PrintServerMethodDef(google::protobuf::io::Printer* printer, const google::protobuf::MethodDescriptor* method, - map* vars) { + map* vars, + set& imports, + map& import_alias) { (*vars)["Method"] = method->name(); - (*vars)["Request"] = GetFullName((*vars)["PackageName"], - method->input_type()->file()->package(), - method->input_type()->name()); - (*vars)["Response"] = GetFullName((*vars)["PackageName"], - method->output_type()->file()->package(), - method->output_type()->name()); + (*vars)["Request"] = + GetFullMessageQualifiedName(method->input_type(), imports, import_alias); + (*vars)["Response"] = + GetFullMessageQualifiedName(method->output_type(), imports, import_alias); if (NoStreaming(method)) { printer->Print( *vars, @@ -328,14 +348,14 @@ void PrintServerMethodDef(google::protobuf::io::Printer* printer, void PrintServerHandler(google::protobuf::io::Printer* printer, const google::protobuf::MethodDescriptor* method, - map* vars) { + map* vars, + set& imports, + map& import_alias) { (*vars)["Method"] = method->name(); - (*vars)["Request"] = GetFullName((*vars)["PackageName"], - method->input_type()->file()->package(), - method->input_type()->name()); - (*vars)["Response"] = GetFullName((*vars)["PackageName"], - method->output_type()->file()->package(), - method->output_type()->name()); + (*vars)["Request"] = + GetFullMessageQualifiedName(method->input_type(), imports, import_alias); + (*vars)["Response"] = + GetFullMessageQualifiedName(method->output_type(), imports, import_alias); if (NoStreaming(method)) { printer->Print( *vars, @@ -473,11 +493,13 @@ void PrintServerStreamingMethodDesc( void PrintServer(google::protobuf::io::Printer* printer, const google::protobuf::ServiceDescriptor* service, - map* vars) { + map* vars, + set& imports, + map& import_alias) { (*vars)["Service"] = service->name(); printer->Print(*vars, "type $Service$Server interface {\n"); for (int i = 0; i < service->method_count(); ++i) { - PrintServerMethodDef(printer, service->method(i), vars); + PrintServerMethodDef(printer, service->method(i), vars, imports, import_alias); } printer->Print("}\n\n"); @@ -487,7 +509,7 @@ void PrintServer(google::protobuf::io::Printer* printer, "}\n\n"); for (int i = 0; i < service->method_count(); ++i) { - PrintServerHandler(printer, service->method(i), vars); + PrintServerHandler(printer, service->method(i), vars, imports, import_alias); } printer->Print(*vars, @@ -513,42 +535,53 @@ void PrintServer(google::protobuf::io::Printer* printer, "}\n\n"); } +bool IsSelfImport(const google::protobuf::FileDescriptor* self, + const google::protobuf::FileDescriptor* import) { + if (GenerateFullGoPackage(self) == GenerateFullGoPackage(import)) { + return true; + } + return false; +} + void PrintMessageImports( google::protobuf::io::Printer* printer, const google::protobuf::FileDescriptor* file, - map* vars) { + map* vars, + set* imports, + map* import_alias) { set descs; - set importedPkgs; for (int i = 0; i < file->service_count(); ++i) { const google::protobuf::ServiceDescriptor* service = file->service(i); for (int j = 0; j < service->method_count(); ++j) { - const google::protobuf::MethodDescriptor* method = service->method(i); - // Remove duplicated imports. - if (importedPkgs.find( - method->input_type()->file()->package()) == importedPkgs.end()) { + const google::protobuf::MethodDescriptor* method = service->method(j); + if (!IsSelfImport(file, method->input_type()->file())) { descs.insert(method->input_type()->file()); - importedPkgs.insert(method->input_type()->file()->package()); } - if (importedPkgs.find( - method->output_type()->file()->package()) == importedPkgs.end()) { + if (!IsSelfImport(file, method->output_type()->file())) { descs.insert(method->output_type()->file()); - importedPkgs.insert(method->output_type()->file()->package()); } } } + int idx = 0; for (auto fd : descs) { - if (fd->package() == (*vars)["PackageName"]) { - continue; + string pkg = GenerateFullGoPackage(fd); + if (pkg != "") { + auto ret = imports->insert(pkg); + // Use ret.second to guarantee if a package spans multiple files, it only + // gets 1 alias. + if (ret.second && file->package() == fd->package()) { + // the same package name in different directories. Require an alias. + (*import_alias)[pkg] = "apb" + std::to_string(idx++); + } } - string name = fd->name(); - string import_path = "import \""; - if (name.find('/') == string::npos) { - // Assume all the proto in the same directory belong to the same package. - continue; - } else { - import_path += name.substr(0, name.find_last_of('/')) + "\""; + } + for (auto import : *imports) { + string import_path = "import "; + if (import_alias->find(import) != import_alias->end()) { + import_path += (*import_alias)[import] + " "; } + import_path += "\"" + import + "\""; printer->Print(import_path.c_str()); printer->Print("\n"); } @@ -560,7 +593,8 @@ string GetServices(const google::protobuf::FileDescriptor* file) { google::protobuf::io::StringOutputStream output_stream(&output); google::protobuf::io::Printer printer(&output_stream, '$'); map vars; - + map import_alias; + set imports; string package_name = !file->options().go_package().empty() ? file->options().go_package() : file->package(); @@ -578,7 +612,7 @@ string GetServices(const google::protobuf::FileDescriptor* file) { "\tproto \"github.com/golang/protobuf/proto\"\n" ")\n\n"); - PrintMessageImports(&printer, file, &vars); + PrintMessageImports(&printer, file, &vars, &imports, &import_alias); // $Package$ is used to fully qualify method names. vars["Package"] = file->package(); @@ -587,9 +621,9 @@ string GetServices(const google::protobuf::FileDescriptor* file) { } for (int i = 0; i < file->service_count(); ++i) { - PrintClient(&printer, file->service(0), &vars); + PrintClient(&printer, file->service(0), &vars, imports, import_alias); printer.Print("\n"); - PrintServer(&printer, file->service(0), &vars); + PrintServer(&printer, file->service(0), &vars, imports, import_alias); printer.Print("\n"); } return output;