Merge pull request #23 from iamqizhao/master

Revise codegen plugin to accommodate more cases and fix some bugs
This commit is contained in:
Qi Zhao
2015-02-02 17:05:46 -08:00

View File

@ -42,6 +42,7 @@
using namespace std;
// TODO(zhaoq): Support go_package option.
namespace grpc_go_generator {
bool NoStreaming(const google::protobuf::MethodDescriptor* method) {
@ -88,25 +89,42 @@ std::string BadToUnderscore(std::string str) {
return str;
}
const string GetFullName(const string& selfPkg,
const string& msgPkg,
const string& msgName) {
if (selfPkg == msgPkg) {
return msgName;
string GenerateFullGoPackage(const google::protobuf::FileDescriptor* file) {
// In opensouce environment, assume each directory has at most one package.
size_t pos = file->name().find_last_of('/');
if (pos != string::npos) {
return file->name().substr(0, pos);
}
return BadToUnderscore(msgPkg) + "." + msgName;
return "";
}
const string GetFullMessageQualifiedName(
const google::protobuf::Descriptor* desc,
set<string>& imports,
map<string, string>& import_alias) {
string pkg = GenerateFullGoPackage(desc->file());
if (imports.find(pkg) == imports.end()) {
// The message is in the same package as the services definition.
return desc->name();
}
if (import_alias.find(pkg) != import_alias.end()) {
// The message is in a package whose name is as same as the one consisting
// of the service definition. Use the alias to differentiate.
return import_alias[pkg] + "." + desc->name();
}
return BadToUnderscore(desc->file()->package()) + "." + desc->name();
}
void PrintClientMethodDef(google::protobuf::io::Printer* printer,
const google::protobuf::MethodDescriptor* method,
map<string, string>* vars) {
map<string, string>* vars,
set<string>& imports,
map<string, string>& import_alias) {
(*vars)["Method"] = method->name();
(*vars)["Request"] = GetFullName((*vars)["PackageName"],
method->input_type()->file()->package(),
method->input_type()->name());
(*vars)["Response"] = GetFullName((*vars)["PackageName"],
method->output_type()->file()->package(),
method->output_type()->name());
(*vars)["Request"] =
GetFullMessageQualifiedName(method->input_type(), imports, import_alias);
(*vars)["Response"] =
GetFullMessageQualifiedName(method->output_type(), imports, import_alias);
if (NoStreaming(method)) {
printer->Print(*vars,
"\t$Method$(ctx context.Context, in *$Request$, opts "
@ -130,14 +148,14 @@ void PrintClientMethodDef(google::protobuf::io::Printer* printer,
void PrintClientMethodImpl(google::protobuf::io::Printer* printer,
const google::protobuf::MethodDescriptor* method,
map<string, string>* vars) {
map<string, string>* vars,
set<string>& imports,
map<string, string>& import_alias) {
(*vars)["Method"] = method->name();
(*vars)["Request"] = GetFullName((*vars)["PackageName"],
method->input_type()->file()->package(),
method->input_type()->name());
(*vars)["Response"] = GetFullName((*vars)["PackageName"],
method->output_type()->file()->package(),
method->output_type()->name());
(*vars)["Request"] =
GetFullMessageQualifiedName(method->input_type(), imports, import_alias);
(*vars)["Response"] =
GetFullMessageQualifiedName(method->output_type(), imports, import_alias);
if (NoStreaming(method)) {
printer->Print(
*vars,
@ -279,12 +297,14 @@ void PrintClientMethodImpl(google::protobuf::io::Printer* printer,
void PrintClient(google::protobuf::io::Printer* printer,
const google::protobuf::ServiceDescriptor* service,
map<string, string>* vars) {
map<string, string>* vars,
set<string>& imports,
map<string, string>& import_alias) {
(*vars)["Service"] = service->name();
(*vars)["ServiceStruct"] = LowerCaseService(service->name());
printer->Print(*vars, "type $Service$Client interface {\n");
for (int i = 0; i < service->method_count(); ++i) {
PrintClientMethodDef(printer, service->method(i), vars);
PrintClientMethodDef(printer, service->method(i), vars, imports, import_alias);
}
printer->Print("}\n\n");
@ -298,20 +318,20 @@ void PrintClient(google::protobuf::io::Printer* printer,
"\treturn &$ServiceStruct$Client{cc}\n"
"}\n\n");
for (int i = 0; i < service->method_count(); ++i) {
PrintClientMethodImpl(printer, service->method(i), vars);
PrintClientMethodImpl(printer, service->method(i), vars, imports, import_alias);
}
}
void PrintServerMethodDef(google::protobuf::io::Printer* printer,
const google::protobuf::MethodDescriptor* method,
map<string, string>* vars) {
map<string, string>* vars,
set<string>& imports,
map<string, string>& import_alias) {
(*vars)["Method"] = method->name();
(*vars)["Request"] = GetFullName((*vars)["PackageName"],
method->input_type()->file()->package(),
method->input_type()->name());
(*vars)["Response"] = GetFullName((*vars)["PackageName"],
method->output_type()->file()->package(),
method->output_type()->name());
(*vars)["Request"] =
GetFullMessageQualifiedName(method->input_type(), imports, import_alias);
(*vars)["Response"] =
GetFullMessageQualifiedName(method->output_type(), imports, import_alias);
if (NoStreaming(method)) {
printer->Print(
*vars,
@ -328,14 +348,14 @@ void PrintServerMethodDef(google::protobuf::io::Printer* printer,
void PrintServerHandler(google::protobuf::io::Printer* printer,
const google::protobuf::MethodDescriptor* method,
map<string, string>* vars) {
map<string, string>* vars,
set<string>& imports,
map<string, string>& import_alias) {
(*vars)["Method"] = method->name();
(*vars)["Request"] = GetFullName((*vars)["PackageName"],
method->input_type()->file()->package(),
method->input_type()->name());
(*vars)["Response"] = GetFullName((*vars)["PackageName"],
method->output_type()->file()->package(),
method->output_type()->name());
(*vars)["Request"] =
GetFullMessageQualifiedName(method->input_type(), imports, import_alias);
(*vars)["Response"] =
GetFullMessageQualifiedName(method->output_type(), imports, import_alias);
if (NoStreaming(method)) {
printer->Print(
*vars,
@ -473,11 +493,13 @@ void PrintServerStreamingMethodDesc(
void PrintServer(google::protobuf::io::Printer* printer,
const google::protobuf::ServiceDescriptor* service,
map<string, string>* vars) {
map<string, string>* vars,
set<string>& imports,
map<string, string>& import_alias) {
(*vars)["Service"] = service->name();
printer->Print(*vars, "type $Service$Server interface {\n");
for (int i = 0; i < service->method_count(); ++i) {
PrintServerMethodDef(printer, service->method(i), vars);
PrintServerMethodDef(printer, service->method(i), vars, imports, import_alias);
}
printer->Print("}\n\n");
@ -487,7 +509,7 @@ void PrintServer(google::protobuf::io::Printer* printer,
"}\n\n");
for (int i = 0; i < service->method_count(); ++i) {
PrintServerHandler(printer, service->method(i), vars);
PrintServerHandler(printer, service->method(i), vars, imports, import_alias);
}
printer->Print(*vars,
@ -513,42 +535,53 @@ void PrintServer(google::protobuf::io::Printer* printer,
"}\n\n");
}
bool IsSelfImport(const google::protobuf::FileDescriptor* self,
const google::protobuf::FileDescriptor* import) {
if (GenerateFullGoPackage(self) == GenerateFullGoPackage(import)) {
return true;
}
return false;
}
void PrintMessageImports(
google::protobuf::io::Printer* printer,
const google::protobuf::FileDescriptor* file,
map<string, string>* vars) {
map<string, string>* vars,
set<string>* imports,
map<string, string>* import_alias) {
set<const google::protobuf::FileDescriptor*> descs;
set<string> importedPkgs;
for (int i = 0; i < file->service_count(); ++i) {
const google::protobuf::ServiceDescriptor* service = file->service(i);
for (int j = 0; j < service->method_count(); ++j) {
const google::protobuf::MethodDescriptor* method = service->method(i);
// Remove duplicated imports.
if (importedPkgs.find(
method->input_type()->file()->package()) == importedPkgs.end()) {
const google::protobuf::MethodDescriptor* method = service->method(j);
if (!IsSelfImport(file, method->input_type()->file())) {
descs.insert(method->input_type()->file());
importedPkgs.insert(method->input_type()->file()->package());
}
if (importedPkgs.find(
method->output_type()->file()->package()) == importedPkgs.end()) {
if (!IsSelfImport(file, method->output_type()->file())) {
descs.insert(method->output_type()->file());
importedPkgs.insert(method->output_type()->file()->package());
}
}
}
int idx = 0;
for (auto fd : descs) {
if (fd->package() == (*vars)["PackageName"]) {
continue;
string pkg = GenerateFullGoPackage(fd);
if (pkg != "") {
auto ret = imports->insert(pkg);
// Use ret.second to guarantee if a package spans multiple files, it only
// gets 1 alias.
if (ret.second && file->package() == fd->package()) {
// the same package name in different directories. Require an alias.
(*import_alias)[pkg] = "apb" + std::to_string(idx++);
}
}
string name = fd->name();
string import_path = "import \"";
if (name.find('/') == string::npos) {
// Assume all the proto in the same directory belong to the same package.
continue;
} else {
import_path += name.substr(0, name.find_last_of('/')) + "\"";
}
for (auto import : *imports) {
string import_path = "import ";
if (import_alias->find(import) != import_alias->end()) {
import_path += (*import_alias)[import] + " ";
}
import_path += "\"" + import + "\"";
printer->Print(import_path.c_str());
printer->Print("\n");
}
@ -560,7 +593,8 @@ string GetServices(const google::protobuf::FileDescriptor* file) {
google::protobuf::io::StringOutputStream output_stream(&output);
google::protobuf::io::Printer printer(&output_stream, '$');
map<string, string> vars;
map<string, string> import_alias;
set<string> imports;
string package_name = !file->options().go_package().empty()
? file->options().go_package()
: file->package();
@ -578,7 +612,7 @@ string GetServices(const google::protobuf::FileDescriptor* file) {
"\tproto \"github.com/golang/protobuf/proto\"\n"
")\n\n");
PrintMessageImports(&printer, file, &vars);
PrintMessageImports(&printer, file, &vars, &imports, &import_alias);
// $Package$ is used to fully qualify method names.
vars["Package"] = file->package();
@ -587,9 +621,9 @@ string GetServices(const google::protobuf::FileDescriptor* file) {
}
for (int i = 0; i < file->service_count(); ++i) {
PrintClient(&printer, file->service(0), &vars);
PrintClient(&printer, file->service(0), &vars, imports, import_alias);
printer.Print("\n");
PrintServer(&printer, file->service(0), &vars);
PrintServer(&printer, file->service(0), &vars, imports, import_alias);
printer.Print("\n");
}
return output;