Repository: tech-srl/code2seq Branch: master Commit: 8ca14173c323 Files: 71 Total size: 312.7 KB Directory structure: gitextract_t37_dsto/ ├── .gitignore ├── CITATION.cff ├── CSharpExtractor/ │ ├── .gitattributes │ ├── .gitignore │ ├── CSharpExtractor/ │ │ ├── .nuget/ │ │ │ └── packages.config │ │ ├── CSharpExtractor.sln │ │ └── Extractor/ │ │ ├── Extractor.cs │ │ ├── Extractor.csproj │ │ ├── PathFinder.cs │ │ ├── Program.cs │ │ ├── Properties/ │ │ │ └── launchSettings.json │ │ ├── Temp.cs │ │ ├── Tree/ │ │ │ └── Tree.cs │ │ ├── Utilities.cs │ │ └── Variable.cs │ └── extract.py ├── Input.java ├── JavaExtractor/ │ ├── JPredict/ │ │ ├── .classpath │ │ ├── .gitignore │ │ ├── src/ │ │ │ └── main/ │ │ │ └── java/ │ │ │ ├── JavaExtractor/ │ │ │ │ ├── App.java │ │ │ │ ├── Common/ │ │ │ │ │ ├── CommandLineValues.java │ │ │ │ │ ├── Common.java │ │ │ │ │ └── MethodContent.java │ │ │ │ ├── ExtractFeaturesTask.java │ │ │ │ ├── FeatureExtractor.java │ │ │ │ ├── FeaturesEntities/ │ │ │ │ │ ├── ProgramFeatures.java │ │ │ │ │ ├── ProgramRelation.java │ │ │ │ │ └── Property.java │ │ │ │ └── Visitors/ │ │ │ │ ├── FunctionVisitor.java │ │ │ │ └── LeavesCollectorVisitor.java │ │ │ └── Test.java │ │ └── target/ │ │ └── JavaExtractor-0.0.1-SNAPSHOT.jar │ └── extract.py ├── LICENSE ├── Python150kExtractor/ │ ├── README.md │ ├── extract.py │ └── preprocess.sh ├── README.md ├── __init__.py ├── baseline_tokenization/ │ ├── input_example.txt │ ├── javalang/ │ │ ├── __init__.py │ │ ├── ast.py │ │ ├── javadoc.py │ │ ├── parse.py │ │ ├── parser.py │ │ ├── test/ │ │ │ ├── __init__.py │ │ │ ├── source/ │ │ │ │ └── package-info/ │ │ │ │ ├── AnnotationJavadoc.java │ │ │ │ ├── AnnotationOnly.java │ │ │ │ ├── JavadocAnnotation.java │ │ │ │ ├── JavadocOnly.java │ │ │ │ └── NoAnnotationNoJavadoc.java │ │ │ ├── test_java_8_syntax.py │ │ │ ├── test_javadoc.py │ │ │ ├── test_package_declaration.py │ │ │ └── test_util.py │ │ ├── tokenizer.py │ │ ├── tree.py │ │ └── util.py │ └── subtokenize_nmt_baseline.py 
├── code2seq.py ├── common.py ├── config.py ├── extractor.py ├── interactive_predict.py ├── model.py ├── preprocess.py ├── preprocess.sh ├── preprocess_csharp.sh ├── reader.py ├── train.sh └── train_python150k.sh ================================================ FILE CONTENTS ================================================ ================================================ FILE: .gitignore ================================================ *.class *.lst .idea/* *.iml *.xml *.pyc ================================================ FILE: CITATION.cff ================================================ @inproceedings{ alon2018codeseq, title={code2seq: Generating Sequences from Structured Representations of Code}, author={Uri Alon and Shaked Brody and Omer Levy and Eran Yahav}, booktitle={International Conference on Learning Representations}, year={2019}, url={https://openreview.net/forum?id=H1gKYo09tX}, } ================================================ FILE: CSharpExtractor/.gitattributes ================================================ ############################################################################### # Set default behavior to automatically normalize line endings. ############################################################################### * text=auto ############################################################################### # Set default behavior for command prompt diff. # # This is need for earlier builds of msysgit that does not have it on by # default for csharp files. # Note: This is only used by command line ############################################################################### #*.cs diff=csharp ############################################################################### # Set the merge driver for project and solution files # # Merging from the command prompt will add diff markers to the files if there # are conflicts (Merging from VS is not affected by the settings below, in VS # the diff markers are never inserted). 
Diff markers may cause the following # file extensions to fail to load in VS. An alternative would be to treat # these files as binary and thus will always conflict and require user # intervention with every merge. To do so, just uncomment the entries below ############################################################################### #*.sln merge=binary #*.csproj merge=binary #*.vbproj merge=binary #*.vcxproj merge=binary #*.vcproj merge=binary #*.dbproj merge=binary #*.fsproj merge=binary #*.lsproj merge=binary #*.wixproj merge=binary #*.modelproj merge=binary #*.sqlproj merge=binary #*.wwaproj merge=binary ############################################################################### # behavior for image files # # image files are treated as binary by default. ############################################################################### #*.jpg binary #*.png binary #*.gif binary ############################################################################### # diff behavior for common document formats # # Convert binary document formats to text before diffing them. This feature # is only available from the command line. Turn it on by uncommenting the # entries below. ############################################################################### #*.doc diff=astextplain #*.DOC diff=astextplain #*.docx diff=astextplain #*.DOCX diff=astextplain #*.dot diff=astextplain #*.DOT diff=astextplain #*.pdf diff=astextplain #*.PDF diff=astextplain #*.rtf diff=astextplain #*.RTF diff=astextplain ================================================ FILE: CSharpExtractor/.gitignore ================================================ ## Ignore Visual Studio temporary files, build results, and ## files generated by popular Visual Studio add-ons. 
# User-specific files *.suo *.user *.userosscache *.sln.docstates # User-specific files (MonoDevelop/Xamarin Studio) *.userprefs # Build results [Dd]ebug/ [Dd]ebugPublic/ [Rr]elease/ [Rr]eleases/ x64/ x86/ bld/ [Bb]in/ [Oo]bj/ [Ll]og/ # Visual Studio 2015 cache/options directory .vs/ # Uncomment if you have tasks that create the project's static files in wwwroot #wwwroot/ # MSTest test Results [Tt]est[Rr]esult*/ [Bb]uild[Ll]og.* # NUNIT *.VisualState.xml TestResult.xml # Build Results of an ATL Project [Dd]ebugPS/ [Rr]eleasePS/ dlldata.c # DNX project.lock.json artifacts/ *_i.c *_p.c *_i.h *.ilk *.meta *.obj *.pch *.pdb *.pgc *.pgd *.rsp *.sbr *.tlb *.tli *.tlh *.tmp *.tmp_proj *.log *.vspscc *.vssscc .builds *.pidb *.svclog *.scc # Chutzpah Test files _Chutzpah* # Visual C++ cache files ipch/ *.aps *.ncb *.opendb *.opensdf *.sdf *.cachefile *.VC.db *.VC.VC.opendb # Visual Studio profiler *.psess *.vsp *.vspx *.sap # TFS 2012 Local Workspace $tf/ # Guidance Automation Toolkit *.gpState # ReSharper is a .NET coding add-in _ReSharper*/ *.[Rr]e[Ss]harper *.DotSettings.user # JustCode is a .NET coding add-in .JustCode # TeamCity is a build add-in _TeamCity* # DotCover is a Code Coverage Tool *.dotCover # NCrunch _NCrunch_* .*crunch*.local.xml nCrunchTemp_* # MightyMoose *.mm.* AutoTest.Net/ # Web workbench (sass) .sass-cache/ # Installshield output folder [Ee]xpress/ # DocProject is a documentation generator add-in DocProject/buildhelp/ DocProject/Help/*.HxT DocProject/Help/*.HxC DocProject/Help/*.hhc DocProject/Help/*.hhk DocProject/Help/*.hhp DocProject/Help/Html2 DocProject/Help/html # Click-Once directory publish/ # Publish Web Output *.[Pp]ublish.xml *.azurePubxml # TODO: Comment the next line if you want to checkin your web deploy settings # but database connection strings (with potential passwords) will be unencrypted *.pubxml *.publishproj # Microsoft Azure Web App publish settings. 
Comment the next line if you want to # checkin your Azure Web App publish settings, but sensitive information contained # in these scripts will be unencrypted PublishScripts/ # NuGet Packages *.nupkg # The packages folder can be ignored because of Package Restore **/packages/* # except build/, which is used as an MSBuild target. !**/packages/build/ # Uncomment if necessary however generally it will be regenerated when needed #!**/packages/repositories.config # NuGet v3's project.json files produces more ignoreable files *.nuget.props *.nuget.targets # Microsoft Azure Build Output csx/ *.build.csdef # Microsoft Azure Emulator ecf/ rcf/ # Windows Store app package directories and files AppPackages/ BundleArtifacts/ Package.StoreAssociation.xml _pkginfo.txt # Visual Studio cache files # files ending in .cache can be ignored *.[Cc]ache # but keep track of directories ending in .cache !*.[Cc]ache/ # Others ClientBin/ ~$* *~ *.dbmdl *.dbproj.schemaview *.pfx *.publishsettings node_modules/ orleans.codegen.cs # Since there are multiple workflows, uncomment next line to ignore bower_components # (https://github.com/github/gitignore/pull/1529#issuecomment-104372622) #bower_components/ # RIA/Silverlight projects Generated_Code/ # Backup & report files from converting an old project file # to a newer Visual Studio version. 
Backup files are not needed, # because we have git ;-) _UpgradeReport_Files/ Backup*/ UpgradeLog*.XML UpgradeLog*.htm # SQL Server files *.mdf *.ldf # Business Intelligence projects *.rdl.data *.bim.layout *.bim_*.settings # Microsoft Fakes FakesAssemblies/ # GhostDoc plugin setting file *.GhostDoc.xml # Node.js Tools for Visual Studio .ntvs_analysis.dat # Visual Studio 6 build log *.plg # Visual Studio 6 workspace options file *.opt # Visual Studio LightSwitch build output **/*.HTMLClient/GeneratedArtifacts **/*.DesktopClient/GeneratedArtifacts **/*.DesktopClient/ModelManifest.xml **/*.Server/GeneratedArtifacts **/*.Server/ModelManifest.xml _Pvt_Extensions # Paket dependency manager .paket/paket.exe paket-files/ # FAKE - F# Make .fake/ # JetBrains Rider .idea/ *.sln.iml # no data data/* backupdata/* ================================================ FILE: CSharpExtractor/CSharpExtractor/.nuget/packages.config ================================================  ================================================ FILE: CSharpExtractor/CSharpExtractor/CSharpExtractor.sln ================================================  Microsoft Visual Studio Solution File, Format Version 12.00 # Visual Studio 15 VisualStudioVersion = 15.0.28307.136 MinimumVisualStudioVersion = 10.0.40219.1 Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Extractor", "Extractor\Extractor.csproj", "{481EDE3F-0ED1-4CB9-814A-63A821022552}" EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU Debug|x64 = Debug|x64 Debug|x86 = Debug|x86 Release|Any CPU = Release|Any CPU Release|x64 = Release|x64 Release|x86 = Release|x86 Release20|Any CPU = Release20|Any CPU Release20|x64 = Release20|x64 Release20|x86 = Release20|x86 EndGlobalSection GlobalSection(ProjectConfigurationPlatforms) = postSolution {481EDE3F-0ED1-4CB9-814A-63A821022552}.Debug|Any CPU.ActiveCfg = Debug|Any CPU {481EDE3F-0ED1-4CB9-814A-63A821022552}.Debug|Any CPU.Build.0 = Debug|Any CPU 
{481EDE3F-0ED1-4CB9-814A-63A821022552}.Debug|x64.ActiveCfg = Debug|Any CPU {481EDE3F-0ED1-4CB9-814A-63A821022552}.Debug|x64.Build.0 = Debug|Any CPU {481EDE3F-0ED1-4CB9-814A-63A821022552}.Debug|x86.ActiveCfg = Debug|Any CPU {481EDE3F-0ED1-4CB9-814A-63A821022552}.Debug|x86.Build.0 = Debug|Any CPU {481EDE3F-0ED1-4CB9-814A-63A821022552}.Release|Any CPU.ActiveCfg = Release|Any CPU {481EDE3F-0ED1-4CB9-814A-63A821022552}.Release|Any CPU.Build.0 = Release|Any CPU {481EDE3F-0ED1-4CB9-814A-63A821022552}.Release|x64.ActiveCfg = Release|Any CPU {481EDE3F-0ED1-4CB9-814A-63A821022552}.Release|x64.Build.0 = Release|Any CPU {481EDE3F-0ED1-4CB9-814A-63A821022552}.Release|x86.ActiveCfg = Release|Any CPU {481EDE3F-0ED1-4CB9-814A-63A821022552}.Release|x86.Build.0 = Release|Any CPU {481EDE3F-0ED1-4CB9-814A-63A821022552}.Release20|Any CPU.ActiveCfg = Release|Any CPU {481EDE3F-0ED1-4CB9-814A-63A821022552}.Release20|Any CPU.Build.0 = Release|Any CPU {481EDE3F-0ED1-4CB9-814A-63A821022552}.Release20|x64.ActiveCfg = Release|Any CPU {481EDE3F-0ED1-4CB9-814A-63A821022552}.Release20|x64.Build.0 = Release|Any CPU {481EDE3F-0ED1-4CB9-814A-63A821022552}.Release20|x86.ActiveCfg = Release|Any CPU {481EDE3F-0ED1-4CB9-814A-63A821022552}.Release20|x86.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE EndGlobalSection GlobalSection(ExtensibilityGlobals) = postSolution SolutionGuid = {13A0DA89-D5D9-4E75-850E-70B9FBE88FF8} EndGlobalSection EndGlobal ================================================ FILE: CSharpExtractor/CSharpExtractor/Extractor/Extractor.cs ================================================ using Extractor.Semantics; using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.CSharp; using Microsoft.CodeAnalysis.CSharp.Syntax; using System; using System.Collections.Generic; using System.Linq; using System.Text; using System.Diagnostics; namespace Extractor { public class Extractor { public const string InternalDelimiter = "|"; 
public const string UpTreeChar = InternalDelimiter; public const string DownTreeChar = InternalDelimiter; public const string MethodNameConst = "METHOD_NAME"; public static SyntaxKind[] ParentTypeToAddChildId = new SyntaxKind[] { SyntaxKind.SimpleAssignmentExpression, SyntaxKind.ElementAccessExpression, SyntaxKind.SimpleMemberAccessExpression, SyntaxKind.InvocationExpression, SyntaxKind.BracketedArgumentList, SyntaxKind.ArgumentList}; private ICollection variables; public int LengthLimit { get; set; } public int WidthLimit { get; set; } public string Code { get; set; } public bool ShouldHash { get; set; } public int MaxContexts { get; set; } public Extractor(string code, Options opts) { LengthLimit = opts.MaxLength; WidthLimit = opts.MaxWidth; ShouldHash = !opts.NoHash; MaxContexts = opts.MaxContexts; Code = code; } StringBuilder builder = new StringBuilder(); private string PathNodesToString(PathFinder.Path path) { builder.Clear(); var nodeTypes = path.LeftSide; if (nodeTypes.Count() > 0) { builder.Append(nodeTypes.First().Kind()); if (ParentTypeToAddChildId.Contains(nodeTypes.First().Parent.Kind())) { builder.Append(GetTruncatedChildId(nodeTypes.First())); } foreach (var n in nodeTypes.Skip(1)) { builder.Append(UpTreeChar).Append(n.Kind()); if (ParentTypeToAddChildId.Contains(n.Parent.Kind())) { builder.Append(GetTruncatedChildId(n)); } } builder.Append(UpTreeChar); } builder.Append(path.Ancesstor.Kind()); nodeTypes = path.RightSide; if (nodeTypes.Count() > 0) { builder.Append(DownTreeChar); builder.Append(nodeTypes.First().Kind()); if (ParentTypeToAddChildId.Contains(nodeTypes.First().Parent.Kind())) { builder.Append(GetTruncatedChildId(nodeTypes.First())); } foreach (var n in nodeTypes.Skip(1)) { builder.Append(DownTreeChar).Append(n.Kind()); if (ParentTypeToAddChildId.Contains(n.Parent.Kind())) { builder.Append(GetTruncatedChildId(n)); } } } return builder.ToString(); } private int GetTruncatedChildId(SyntaxNode n) { var parent = n.Parent; int index = 
parent.ChildNodes().ToList().IndexOf(n); if (index > 3) { index = 3; } return index; } private string PathToString(PathFinder.Path path) { SyntaxNode ancesstor = path.Ancesstor; StringBuilder builder = new StringBuilder(); builder.Append(path.Left.Text).Append(UpTreeChar); builder.Append(this.PathNodesToString(path)); builder.Append(DownTreeChar).Append(path.Right.Text); return builder.ToString(); } internal IEnumerable GetInternalPaths(Tree tree) { var finder = new PathFinder(tree, LengthLimit, WidthLimit); var allPairs = Utilities.ReservoirSample(Utilities.WeakConcat(Utilities.Choose2(variables), variables.Select((arg) => new Tuple(arg, arg))), MaxContexts); //iterate over variable-variable pairs foreach (Tuple varPair in allPairs) { bool pathToSelf = varPair.Item1 == varPair.Item2; foreach (var rhs in varPair.Item2.Leaves) foreach (var lhs in varPair.Item1.Leaves) { if (lhs == rhs) continue; PathFinder.Path path = finder.FindPath(lhs, rhs, limited: true); if (path == null) continue; yield return path; } } } private string SplitNameUnlessEmpty(string original) { var subtokens = Utilities.SplitToSubtokens(original).Where(s => s.Length > 0); String name = String.Join(InternalDelimiter, subtokens); if (name.Length == 0) { name = Utilities.NormalizeName(original); } if (String.IsNullOrWhiteSpace(name)) { name = "SPACE"; } if (String.IsNullOrEmpty(name)) { name = "BLANK"; } if (original == Extractor.MethodNameConst) { name = original; } return name; } static readonly char[] removeFromComments = new char[] {' ', '/', '*', '{', '}'}; public List Extract() { var tree = new Tree(CSharpSyntaxTree.ParseText(Code).GetRoot()); IEnumerable methods = tree.GetRoot().DescendantNodesAndSelf().OfType().ToList(); List results = new List(); foreach(var method in methods) { String methodName = method.Identifier.ValueText; Tree methodTree = new Tree(method); var subtokensMethodName = Utilities.SplitToSubtokens(methodName); var tokenToVar = new Dictionary(); this.variables = 
Variable.CreateFromMethod(methodTree).ToArray(); foreach (var variable in variables) { foreach (SyntaxToken token in variable.Leaves) { tokenToVar[token] = variable; } } List contexts = new List(); foreach (PathFinder.Path path in GetInternalPaths(methodTree)) { String pathString = SplitNameUnlessEmpty(tokenToVar[path.Left].Name) + "," + MaybeHash(this.PathNodesToString(path)) + "," + SplitNameUnlessEmpty(tokenToVar[path.Right].Name); Debug.WriteLine(path.Left.FullSpan+" "+tokenToVar[path.Left].Name+ "," +this.PathNodesToString(path)+ "," + tokenToVar[path.Right].Name+" "+path.Right.FullSpan); contexts.Add(pathString); } var commentNodes = tree.GetRoot().DescendantTrivia().Where( node => node.IsKind(SyntaxKind.MultiLineCommentTrivia) || node.IsKind(SyntaxKind.SingleLineCommentTrivia) || node.IsKind(SyntaxKind.MultiLineDocumentationCommentTrivia)); foreach (SyntaxTrivia trivia in commentNodes) { string commentText = trivia.ToString().Trim(removeFromComments); string normalizedTrivia = SplitNameUnlessEmpty(commentText); var parts = normalizedTrivia.Split('|'); for (int i = 0; i < Math.Ceiling((double)parts.Length / (double)5); i++) { var batch = String.Join("|", parts.Skip(i * 5).Take(5)); contexts.Add(batch + "," + "COMMENT" + "," + batch); } } results.Add(String.Join("|", subtokensMethodName) + " " + String.Join(" ", contexts)); } return results; } private string MaybeHash(string v) { if (this.ShouldHash) { return v.GetHashCode().ToString(); } else { return v; } } } } ================================================ FILE: CSharpExtractor/CSharpExtractor/Extractor/Extractor.csproj ================================================ Exe netcoreapp2.2 Extractor.Program ================================================ FILE: CSharpExtractor/CSharpExtractor/Extractor/PathFinder.cs ================================================ using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.CSharp.Syntax; using System; using System.Collections.Generic; using System.Linq; 
namespace Extractor { internal class PathFinder { internal class Path { public SyntaxToken Left { get; } public List LeftSide { get; } public SyntaxNode Ancesstor { get; } public List RightSide { get; } public SyntaxToken Right { get; } public Path(SyntaxToken left, IEnumerable leftSide, SyntaxNode ancesstor, IEnumerable rightSide, SyntaxToken right) { this.Left = left; this.LeftSide = leftSide.ToList(); this.Ancesstor = ancesstor; this.RightSide = rightSide.ToList(); this.Right = right; } } public int Length { get; } public int Width { get; } Tree tree; public PathFinder(Tree tree, int length = 7, int width = 4) { if (length < 1 || width < 1) throw new ArgumentException("Width and Length params must be positive."); Length = length; Width = width; this.tree = tree; } private int GetDepth(SyntaxNode n) { int depth = 0; while(n.Parent != null) { n = n.Parent; depth++; } return depth; } public SyntaxNode FirstAncestor(SyntaxNode l, SyntaxNode r) { if (l.Equals(r)) return l; if (GetDepth(l) >= GetDepth(r)) { l = l.Parent; } else { r = r.Parent; } return FirstAncestor(l, r); } private IEnumerable CollectPathToParent(SyntaxNode start, SyntaxNode parent) { while (!start.Equals(parent)) { yield return start; start = start.Parent; } } internal Path FindPath(SyntaxToken l, SyntaxToken r, bool limited = true) { SyntaxNode p = FirstAncestor(l.Parent, r.Parent); // + 2 for the distance of the leafs themselves if (GetDepth(r.Parent) + GetDepth(l.Parent) - 2 * GetDepth(p) + 2 > Length) { return null; } var leftSide = CollectPathToParent(l.Parent, p); var rightSide = CollectPathToParent(r.Parent, p); rightSide = rightSide.Reverse(); List widthCheck = p.ChildNodes().ToList(); if (limited && leftSide.Count() != 0 && rightSide.Count() != 0) { int indexOfLeft = widthCheck.IndexOf(leftSide.Last()); int indexOfRight = widthCheck.IndexOf(rightSide.First()); if (Math.Abs(indexOfLeft - indexOfRight) >= Width) { return null; } } return new Path(l, leftSide, p, rightSide, r); } } } 
================================================ FILE: CSharpExtractor/CSharpExtractor/Extractor/Program.cs ================================================ using CommandLine; using CommandLine.Text; using System; using System.Collections.Generic; using System.IO; using System.Linq; namespace Extractor { class Program { static List ExtractSingleFile(string filename, Options opts) { string data = File.ReadAllText(filename); var extractor = new Extractor(data, opts); List result = extractor.Extract(); return result; } static void Main(string[] args) { Options options = new Options(); Parser.Default.ParseArguments(args) .WithParsed(opt => options = opt) .WithNotParsed(errors => { Console.WriteLine(errors); return; }); string path = options.Path; string[] files; if (Directory.Exists(path)) { files = Directory.GetFiles(path, "*.cs", SearchOption.AllDirectories); } else { files = new string[] { path }; } IEnumerable results = null; results = files.AsParallel().WithDegreeOfParallelism(options.Threads).SelectMany(filename => ExtractSingleFile(filename, options)); using (StreamWriter sw = new StreamWriter(options.OFileName, append: true)) { foreach (var res in results) { sw.WriteLine(res); } } } } } ================================================ FILE: CSharpExtractor/CSharpExtractor/Extractor/Properties/launchSettings.json ================================================ { "profiles": { "Extractor": { "commandName": "Project", "commandLineArgs": "--path C:\\Users\\urial\\Source\\Repos\\CSharpExtractor\\CSharpExtractor\\Extractor\\bin\\ --no_hash" } } } ================================================ FILE: CSharpExtractor/CSharpExtractor/Extractor/Temp.cs ================================================ namespace Extractor { class Temp { class NestedClass { void fooBar() { a.b = c; } } } } ================================================ FILE: CSharpExtractor/CSharpExtractor/Extractor/Tree/Tree.cs ================================================ using System; using 
System.Collections.Generic; using System.IO; using System.Linq; using System.Text; using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.CSharp; using Microsoft.CodeAnalysis.CSharp.Syntax; namespace Extractor { public class Tree { public const string DummyClass = "IgnoreDummyClass"; public const string DummyMethodName = "IgnoreDummyMethod"; public const string DummyType = "IgnoreDummyType"; internal static readonly SyntaxKind[] literals = { SyntaxKind.NumericLiteralToken, SyntaxKind.StringLiteralToken, SyntaxKind.CharacterLiteralToken }; internal static readonly HashSet identifiers = new HashSet(new SyntaxKind[] { SyntaxKind.IdentifierToken }); //, SyntaxKind.VoidKeyword, SyntaxKind.StringKeyword }); internal static readonly HashSet keywords = new HashSet(new SyntaxKind[] { SyntaxKind.RefKeyword, SyntaxKind.OutKeyword, SyntaxKind.ConstKeyword }); internal static readonly HashSet declarations = new HashSet(new SyntaxKind[] { SyntaxKind.VariableDeclarator, SyntaxKind.Parameter, SyntaxKind.CatchDeclaration, SyntaxKind.ForEachStatement }); internal static readonly HashSet memberAccesses = new HashSet(new SyntaxKind[] { SyntaxKind.SimpleMemberAccessExpression, SyntaxKind.PointerMemberAccessExpression }); internal static readonly HashSet scopeEnders = new HashSet( new SyntaxKind[]{ SyntaxKind.Block, SyntaxKind.ForStatement, SyntaxKind.MethodDeclaration, SyntaxKind.ForEachStatement, SyntaxKind.CatchClause, SyntaxKind.SwitchSection, SyntaxKind.UsingStatement }); internal static readonly HashSet lambdaScopeStarters = new HashSet( new SyntaxKind[]{ SyntaxKind.AnonymousMethodExpression, SyntaxKind.SimpleLambdaExpression, SyntaxKind.ParenthesizedLambdaExpression }); public static bool IsScopeEnder(SyntaxNode node) { return Tree.scopeEnders.Contains(node.Kind()); } class TreeBuilderWalker : CSharpSyntaxWalker { Dictionary nodes; HashSet visitedNodes; List Desc; List Tokens; Dictionary tokens; internal TreeBuilderWalker(Dictionary nodes, Dictionary tokens) { visitedNodes = 
new HashSet(); this.nodes = nodes; this.tokens = tokens; } public override void Visit(SyntaxNode node) { visitedNodes.Add(node); base.Visit(node); visitedNodes.Remove(node); Desc = new List(); Tokens = new List(); foreach (var c in node.ChildNodes()) { if (!nodes.ContainsKey(c)) { continue; } Desc.AddRange(nodes[c].Descendents); Desc.Add(c); Tokens.AddRange(nodes[c].Leaves); } foreach (var token in node.ChildTokens()) { if (Leaf.IsLeafToken(token)) { tokens[token] = new Leaf(nodes, token); Tokens.Add(token); } } Node res = new Node(This: node, Ancestors: new HashSet(visitedNodes), Descendents: Desc.ToArray(), Leaves: Tokens.ToArray(), Kind: node.Kind()); nodes[node] = res; } } internal SyntaxNode GetRoot() { return tree; } SyntaxNode tree; internal Dictionary nodes = new Dictionary(); internal Dictionary leaves = new Dictionary(); public Tree(SyntaxNode syntaxTree) { this.tree = syntaxTree; /*if (this.tree.ChildNodes().ToList().Count() == 0) { this.tree = CSharpSyntaxTree.ParseText($"private {DummyType} {DummyMethodName}() {{ {code} }}"); }*/ new TreeBuilderWalker(nodes, leaves).Visit(this.tree); List commentNodes = tree.DescendantTrivia().Where( node => node.IsKind(SyntaxKind.MultiLineCommentTrivia) || node.IsKind(SyntaxKind.SingleLineCommentTrivia)).ToList(); } } public class Node { public Node(SyntaxNode This, HashSet Ancestors, SyntaxNode[] Descendents, SyntaxToken[] Leaves, SyntaxKind Kind) { this.This = This; this.Ancestors = Ancestors; this.Descendents = Descendents; this.AncestorsAndSelf = new HashSet(Ancestors); this.AncestorsAndSelf.Add(This); this.Leaves = Leaves; this.Depth = Depth; this.Kind = Kind; this.KindName = Kind.ToString(); } public SyntaxNode This { get; } public HashSet Ancestors { get; } public HashSet AncestorsAndSelf { get; } public SyntaxNode[] Descendents { get; } public SyntaxToken[] Leaves { get; } public SyntaxKind Kind { get; } public string KindName { get; } public int Depth { get; } public override bool Equals(object obj) { var 
item = obj as Node; if (item == null) { return false; } return this.This.Equals(item.This); } public override int GetHashCode() { return this.This.GetHashCode(); } } public class Leaf { internal static bool IsLeafToken(SyntaxToken token) { if (token.Text.Equals("var") && token.IsKind(SyntaxKind.IdentifierToken) && token.Parent.IsKind(SyntaxKind.IdentifierName) && token.Parent.Parent.IsKind(SyntaxKind.VariableDeclaration) && token.Parent.Parent.Parent.IsKind(SyntaxKind.LocalDeclarationStatement)) { return false; } if (token.ValueText == Tree.DummyMethodName || token.ValueText == Tree.DummyType) { return false; } return Tree.identifiers.Contains(token.Kind()) || Tree.literals.Contains(token.Kind()) || token.Parent.Kind() == SyntaxKind.PredefinedType; } public SyntaxToken token { get; } public SyntaxKind Kind { get; } public string KindName { get; } public string Text { get; set; } public bool IsConst { get; } public string VariableName { get; } public Leaf(Dictionary nodes, SyntaxToken token) { this.token = token; Kind = token.Kind(); KindName = Kind.ToString(); IsConst = !(Tree.identifiers.Contains(Kind) && Tree.declarations.Contains(token.Parent.Kind())); Text = token.ValueText; SyntaxNode node = token.Parent.Parent; SyntaxNode current = token.Parent; VariableName = Text; } } public class SyntaxViewer { private string ToDot(SyntaxTree tree) { List nodes = tree.GetRoot().DescendantNodesAndSelf().ToList(); SyntaxToken[] tokens = tree.GetRoot().DescendantTokens().ToArray(); string[] tokenStrings = tokens.Select((arg) => arg.Kind().ToString() + "-" + arg.ToString()).ToArray(); string[] nodeStrings = nodes.Select((arg) => arg.Kind().ToString()).ToArray(); Dictionary counts = new Dictionary(); Dictionary nodeNames = new Dictionary(); IEnumerable allItems = nodeStrings.Concat(tokenStrings); int i = 0; foreach (string name in allItems) { if (!counts.ContainsKey(name)) counts[name] = 0; counts[name] += 1; nodeNames[i] = name + counts[name].ToString(); i++; } StringBuilder 
builder = new StringBuilder();
builder.AppendLine("digraph G {");
// vertexes
for (i = 0; i < allItems.Count(); i++)
{
    builder.AppendFormat("\"{0}\" ;\n", nodeNames[i]);
}
builder.AppendLine();
// edges
for (i = 1; i < nodes.Count(); i++)
{
    builder.AppendFormat("\"{0}\"->\"{1}\" [];\n", nodeNames[nodes.IndexOf(nodes[i].Parent)], nodeNames[i]);
}
for (i = 0; i < tokens.Count(); i++)
{
    builder.AppendFormat("\"{0}\"->\"{1}\" [];\n", nodeNames[nodes.IndexOf(tokens[i].Parent)], nodeNames[i + nodes.Count()]);
}
builder.AppendLine("}");
return builder.ToString();
}

public SyntaxViewer(SyntaxTree tree, string path = "out.ong")
{
    string dotData = ToDot(tree);
    File.WriteAllText("out.dot", dotData);
}
}
}

================================================ FILE: CSharpExtractor/CSharpExtractor/Extractor/Utilities.cs ================================================
using CommandLine;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Diagnostics;
using System.Text.RegularExpressions;

namespace Extractor
{
    // Command-line options for the extractor, parsed by the CommandLine library.
    public class Options
    {
        [Option('t', "threads", Default = 1, HelpText = "How many threads to use <1>")]
        public int Threads { get; set; }

        [Option('p', "path", Default = "./data/", HelpText = "Where to find code files. <.>")]
        public string Path { get; set; }

        [Option('l', "max_length", Default = 9, HelpText = "Max path length")]
        public int MaxLength { get; set; }

        // BUG FIX: the short name 'l' was declared for three different options
        // (max_length, max_width, max_contexts). CommandLineParser rejects
        // duplicate short names at parse time, so the parser would fail before
        // any option could be read. max_width now uses 'w' and max_contexts
        // uses 'c'; all long names are unchanged, and the companion extract.py
        // only passes long names, so this is backward compatible.
        // Also fixed the copy-pasted HelpText ("Max path length" -> width).
        [Option('w', "max_width", Default = 2, HelpText = "Max path width")]
        public int MaxWidth { get; set; }

        [Option('o', "ofile_name", Default = "test.txt", HelpText = "Output file name")]
        public String OFileName { get; set; }

        [Option('h', "no_hash", Default = true, HelpText = "When enabled, prints the whole path strings (not hashed)")]
        public Boolean NoHash { get; set; }

        [Option('c', "max_contexts", Default = 30000, HelpText = "Max number of path contexts to sample. Affects only very large snippets")]
        public int MaxContexts { get; set; }
    }

    public static class Utilities
    {
        // Numeric literals kept verbatim; any other number becomes "NUM" in NormalizeName.
        public static String[] NumbericLiteralsToKeep = new String[] { "0", "1", "2", "3", "4", "5", "10" };

        // Yield every unordered pair of distinct elements from the enumerable.
        // NOTE(review): generic parameters were lost by the source dump; restored
        // from usage as IEnumerable<Tuple<T, T>> / IEnumerable<T> -- confirm against the repository.
        public static IEnumerable<Tuple<T, T>> Choose2<T>(IEnumerable<T> enumerable)
        {
            int index = 0;
            foreach (var e in enumerable)
            {
                ++index;
                foreach (var t in enumerable.Skip(index))
                    yield return Tuple.Create(e, t);
            }
        }

        /// <summary>
        /// Sample uniform randomly numSamples from an enumerable, using reservoir sampling.
        /// See https://en.wikipedia.org/wiki/Reservoir_sampling
        /// </summary>
        public static IEnumerable<T> ReservoirSample<T>(this IEnumerable<T> input, int numSamples)
        {
            var rng = new Random();
            var sampledElements = new List<T>(numSamples);
            int seenElementCount = 0;
            foreach (var element in input)
            {
                seenElementCount++;
                if (sampledElements.Count < numSamples)
                {
                    // Fill the reservoir until it holds numSamples elements.
                    sampledElements.Add(element);
                }
                else
                {
                    // Replace a random slot with probability numSamples / seenElementCount.
                    int position = rng.Next(seenElementCount);
                    if (position < numSamples)
                    {
                        sampledElements[position] = element;
                    }
                }
            }
            Debug.Assert(sampledElements.Count <= numSamples);
            return sampledElements;
        }

        // Lazy concatenation of two sequences (enumerated on demand).
        public static IEnumerable<T> WeakConcat<T>(IEnumerable<T> enumerable1, IEnumerable<T> enumerable2)
        {
            foreach (T t in enumerable1)
                yield return t;
            foreach (T t in enumerable2)
                yield return t;
        }

        // Split an identifier on camelCase / snake_case / digit boundaries and
        // normalize each piece; empty pieces are dropped.
        public static IEnumerable<String> SplitToSubtokens(String name)
        {
            return Regex.Split(name.Trim(), "(?<=[a-z])(?=[A-Z])|_|[0-9]|(?<=[A-Z])(?=[A-Z][a-z])|\\s+")
                .Where(s => s.Length > 0)
                .Select(s => NormalizeName(s))
                .Where(s => s.Length > 0);
        }

        private static Regex Whitespaces = new Regex(@"\s");
        private static Regex NonAlphabetic = new Regex("[^A-Za-z]");

        // Lower-case, strip whitespace and non-ASCII, then keep only letters.
        // Pure digit strings are kept if whitelisted, otherwise collapsed to "NUM".
        public static String NormalizeName(string s)
        {
            String partiallyNormalized = s.ToLowerInvariant()
                .Replace("\\\\n", String.Empty)
                .Replace("[\"',]", String.Empty);
            partiallyNormalized = Whitespaces.Replace(partiallyNormalized, "");
            // Round-trip through ASCII, silently dropping characters ASCII cannot encode.
            partiallyNormalized = Encoding.ASCII.GetString(
                Encoding.Convert(
                    Encoding.UTF8,
                    Encoding.GetEncoding(
                        Encoding.ASCII.EncodingName,
                        new EncoderReplacementFallback(string.Empty),
                        new DecoderExceptionFallback()
                    ),
                    Encoding.UTF8.GetBytes(partiallyNormalized)
                )
            );
            if (partiallyNormalized.Contains('\n'))
            {
                partiallyNormalized = partiallyNormalized.Replace('\n', 'N');
            }
            if (partiallyNormalized.Contains('\r'))
            {
                partiallyNormalized = partiallyNormalized.Replace('\r', 'R');
            }
            if (partiallyNormalized.Contains(','))
            {
                partiallyNormalized = partiallyNormalized.Replace(',', 'C');
            }
            String completelyNormalized = NonAlphabetic.Replace(partiallyNormalized, String.Empty);
            if (completelyNormalized.Length == 0)
            {
                if (Regex.IsMatch(partiallyNormalized, @"^\d+$"))
                {
                    if (NumbericLiteralsToKeep.Contains(partiallyNormalized))
                    {
                        return partiallyNormalized;
                    }
                    else
                    {
                        return "NUM";
                    }
                }
                return String.Empty;
            }
            return completelyNormalized;
        }
    }
}

================================================ FILE: CSharpExtractor/CSharpExtractor/Extractor/Variable.cs ================================================
using System;
using System.Collections.Generic;
using System.Linq;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp.Syntax;

namespace Extractor
{
namespace Semantics
{
// A named variable within a method, grouping all syntax tokens that refer to it.
public class Variable
{
    Tree tree;
    public string Name { get; }
    private HashSet<SyntaxToken> leaves;
    public HashSet<SyntaxToken> Leaves { get { return leaves; } }
    private Nullable<bool> constant;
    public bool Const { get { return constant.Value; } }

    private Variable(string name, SyntaxToken[] leaves, Tree tree)
    {
        this.tree = tree;
        this.Name = name;
        this.leaves = new HashSet<SyntaxToken>(leaves);
        constant = true;
        foreach (var leaf in leaves)
        {
            if (!tree.leaves[leaf].IsConst)
            {
                constant = false; // if not constant then it is a declaration token
                break;
            }
        }
    }

    public override int GetHashCode()
    {
        return this.Name.GetHashCode();
    }

    public bool IsLiteral()
    {
        return Tree.literals.Contains(tree.leaves[Leaves.First()].Kind);
    }

    internal static Boolean isMethodName(SyntaxToken token)
    {
        return token.Parent.IsKind(Microsoft.CodeAnalysis.CSharp.SyntaxKind.MethodDeclaration) &&
token.IsKind(Microsoft.CodeAnalysis.CSharp.SyntaxKind.IdentifierToken); } // Create a variable for each variable in scope from tokens while splitting identically named but differently scoped vars. internal static IEnumerable CreateFromMethod(Tree methodTree) { var root = methodTree.nodes[methodTree.GetRoot()]; var leaves = root.Leaves.ToArray(); Dictionary tokenToName = new Dictionary(); Dictionary> nameToTokens = new Dictionary>(); foreach (SyntaxToken token in root.Leaves) { string name = methodTree.leaves[token].VariableName; if (isMethodName(token)) { name = Extractor.MethodNameConst; } tokenToName[token] = name; if (!nameToTokens.ContainsKey(name)) nameToTokens[name] = new List(); nameToTokens[name].Add(token); } List results = new List(); foreach (SyntaxToken leaf in leaves) { string name = tokenToName[leaf]; SyntaxToken[] syntaxTokens = nameToTokens[name].ToArray(); var v = new Variable(name, syntaxTokens, methodTree); //check if exists var matches = results.Where(p => p.Name == name).ToList(); bool alreadyExists = (matches.Count != 0); if (!alreadyExists) { results.Add(v); } } return results; } } } } ================================================ FILE: CSharpExtractor/extract.py ================================================ #!/usr/bin/python import itertools import multiprocessing import os import sys import shutil import subprocess from threading import Timer import sys from argparse import ArgumentParser from subprocess import Popen, PIPE, STDOUT, call def get_immediate_subdirectories(a_dir): return [(os.path.join(a_dir, name)) for name in os.listdir(a_dir) if os.path.isdir(os.path.join(a_dir, name))] TMP_DIR = "" def ParallelExtractDir(args, dir): ExtractFeaturesForDir(args, dir, "") def ExtractFeaturesForDir(args, dir, prefix): command = ['dotnet', 'run', '--project', args.csproj, '--max_length', str(args.max_path_length), '--max_width', str(args.max_path_width), '--path', dir, '--threads', str(args.num_threads), '--ofile_name', 
str(args.ofile_name)] # print command # os.system(command) kill = lambda process: process.kill() sleeper = subprocess.Popen(command, stderr=subprocess.PIPE) timer = Timer(600000, kill, [sleeper]) try: timer.start() _, stderr = sleeper.communicate() finally: timer.cancel() if sleeper.poll() == 0: if len(stderr) > 0: print(sys.stderr, stderr) else: print(sys.stderr, 'dir: ' + str(dir) + ' was not completed in time') failed = True subdirs = get_immediate_subdirectories(dir) for subdir in subdirs: ExtractFeaturesForDir(args, subdir, prefix + dir.split('/')[-1] + '_') if failed: if os.path.exists(str(args.ofile_name)): os.remove(str(args.ofile_name)) def ExtractFeaturesForDirsList(args, dirs): global TMP_DIR TMP_DIR = "./tmp/feature_extractor%d/" % (os.getpid()) if os.path.exists(TMP_DIR): shutil.rmtree(TMP_DIR, ignore_errors=True) os.makedirs(TMP_DIR) try: p = multiprocessing.Pool(4) p.starmap(ParallelExtractDir, zip(itertools.repeat(args), dirs)) #for dir in dirs: # ExtractFeaturesForDir(args, dir, '') output_files = os.listdir(TMP_DIR) for f in output_files: os.system("cat %s/%s" % (TMP_DIR, f)) finally: shutil.rmtree(TMP_DIR, ignore_errors=True) if __name__ == '__main__': parser = ArgumentParser() parser.add_argument("-maxlen", "--max_path_length", dest="max_path_length", required=False, default=8) parser.add_argument("-maxwidth", "--max_path_width", dest="max_path_width", required=False, default=2) parser.add_argument("-threads", "--num_threads", dest="num_threads", required=False, default=64) parser.add_argument("--csproj", dest="csproj", required=True) parser.add_argument("-dir", "--dir", dest="dir", required=False) parser.add_argument("-ofile_name", "--ofile_name", dest="ofile_name", required=True) args = parser.parse_args() if args.dir is not None: subdirs = get_immediate_subdirectories(args.dir) to_extract = subdirs if len(subdirs) == 0: to_extract = [args.dir.rstrip('/')] ExtractFeaturesForDirsList(args, to_extract) 
================================================ FILE: Input.java ================================================ public String getName() { return name; } ================================================ FILE: JavaExtractor/JPredict/.classpath ================================================ ================================================ FILE: JavaExtractor/JPredict/.gitignore ================================================ /target/ ================================================ FILE: JavaExtractor/JPredict/src/main/java/JavaExtractor/App.java ================================================ package JavaExtractor; import JavaExtractor.Common.CommandLineValues; import org.kohsuke.args4j.CmdLineException; import java.io.IOException; import java.nio.file.Files; import java.nio.file.Paths; import java.util.LinkedList; import java.util.List; import java.util.concurrent.ExecutionException; import java.util.concurrent.Executors; import java.util.concurrent.Future; import java.util.concurrent.ThreadPoolExecutor; public class App { private static CommandLineValues s_CommandLineValues; public static void main(String[] args) { try { s_CommandLineValues = new CommandLineValues(args); } catch (CmdLineException e) { e.printStackTrace(); return; } if (s_CommandLineValues.File != null) { ExtractFeaturesTask extractFeaturesTask = new ExtractFeaturesTask(s_CommandLineValues, s_CommandLineValues.File.toPath()); extractFeaturesTask.processFile(); } else if (s_CommandLineValues.Dir != null) { extractDir(); } } private static void extractDir() { ThreadPoolExecutor executor = (ThreadPoolExecutor) Executors.newFixedThreadPool(s_CommandLineValues.NumThreads); LinkedList tasks = new LinkedList<>(); try { Files.walk(Paths.get(s_CommandLineValues.Dir)).filter(Files::isRegularFile) .filter(p -> p.toString().toLowerCase().endsWith(".java")).forEach(f -> { ExtractFeaturesTask task = new ExtractFeaturesTask(s_CommandLineValues, f); tasks.add(task); }); } catch (IOException e) { 
e.printStackTrace(); return; } List> tasksResults = null; try { tasksResults = executor.invokeAll(tasks); } catch (InterruptedException e) { e.printStackTrace(); } finally { executor.shutdown(); } tasksResults.forEach(f -> { try { f.get(); } catch (InterruptedException | ExecutionException e) { e.printStackTrace(); } }); } } ================================================ FILE: JavaExtractor/JPredict/src/main/java/JavaExtractor/Common/CommandLineValues.java ================================================ package JavaExtractor.Common; import org.kohsuke.args4j.CmdLineException; import org.kohsuke.args4j.CmdLineParser; import org.kohsuke.args4j.Option; import java.io.File; /** * This class handles the programs arguments. */ public class CommandLineValues { @Option(name = "--file", required = false) public File File = null; @Option(name = "--dir", required = false, forbids = "--file") public String Dir = null; @Option(name = "--max_path_length", required = true) public int MaxPathLength; @Option(name = "--max_path_width", required = true) public int MaxPathWidth; @Option(name = "--num_threads", required = false) public int NumThreads = 64; @Option(name = "--min_code_len", required = false) public int MinCodeLength = 1; @Option(name = "--max_code_len", required = false) public int MaxCodeLength = -1; @Option(name = "--max_file_len", required = false) public int MaxFileLength = -1; @Option(name = "--pretty_print", required = false) public boolean PrettyPrint = false; @Option(name = "--max_child_id", required = false) public int MaxChildId = 3; @Option(name = "--json_output", required = false) public boolean JsonOutput = false; public CommandLineValues(String... 
args) throws CmdLineException { CmdLineParser parser = new CmdLineParser(this); try { parser.parseArgument(args); } catch (CmdLineException e) { System.err.println(e.getMessage()); parser.printUsage(System.err); throw e; } } public CommandLineValues() { } } ================================================ FILE: JavaExtractor/JPredict/src/main/java/JavaExtractor/Common/Common.java ================================================ package JavaExtractor.Common; import JavaExtractor.FeaturesEntities.Property; import com.github.javaparser.ast.Node; import com.github.javaparser.ast.UserDataKey; import java.util.ArrayList; import java.util.stream.Collectors; import java.util.stream.Stream; public final class Common { public static final UserDataKey PropertyKey = new UserDataKey() { }; public static final UserDataKey ChildId = new UserDataKey() { }; public static final String EmptyString = ""; public static final String MethodDeclaration = "MethodDeclaration"; public static final String NameExpr = "NameExpr"; public static final String BlankWord = "BLANK"; public static final int c_MaxLabelLength = 50; public static final String methodName = "METHOD_NAME"; public static final String internalSeparator = "|"; public static String normalizeName(String original, String defaultString) { original = original.toLowerCase().replaceAll("\\\\n", "") // escaped new // lines .replaceAll("//s+", "") // whitespaces .replaceAll("[\"',]", "") // quotes, apostrophies, commas .replaceAll("\\P{Print}", ""); // unicode weird characters String stripped = original.replaceAll("[^A-Za-z]", ""); if (stripped.length() == 0) { String carefulStripped = original.replaceAll(" ", "_"); if (carefulStripped.length() == 0) { return defaultString; } else { return carefulStripped; } } else { return stripped; } } public static boolean isMethod(Node node, String type) { Property parentProperty = node.getParentNode().getUserData(Common.PropertyKey); if (parentProperty == null) { return false; } String parentType 
= parentProperty.getType(); return Common.NameExpr.equals(type) && Common.MethodDeclaration.equals(parentType); } public static ArrayList splitToSubtokens(String str1) { String str2 = str1.replace("|", " "); String str3 = str2.trim(); return Stream.of(str3.split("(?<=[a-z])(?=[A-Z])|_|[0-9]|(?<=[A-Z])(?=[A-Z][a-z])|\\s+")) .filter(s -> s.length() > 0).map(s -> Common.normalizeName(s, Common.EmptyString)) .filter(s -> s.length() > 0).collect(Collectors.toCollection(ArrayList::new)); } } ================================================ FILE: JavaExtractor/JPredict/src/main/java/JavaExtractor/Common/MethodContent.java ================================================ package JavaExtractor.Common; import com.github.javaparser.ast.Node; import java.util.ArrayList; public class MethodContent { private final ArrayList leaves; private final String name; private final String content; public MethodContent(ArrayList leaves, String name, String content) { this.leaves = leaves; this.name = name; this.content = content; } public ArrayList getLeaves() { return leaves; } public String getName() { return name; } public String getContent() { return content; } } ================================================ FILE: JavaExtractor/JPredict/src/main/java/JavaExtractor/ExtractFeaturesTask.java ================================================ package JavaExtractor; import JavaExtractor.Common.CommandLineValues; import JavaExtractor.Common.Common; import JavaExtractor.FeaturesEntities.ProgramFeatures; import org.apache.commons.lang3.StringUtils; import java.io.IOException; import java.nio.charset.Charset; import java.nio.file.Files; import java.nio.file.Path; import java.util.ArrayList; import java.util.List; import java.util.concurrent.Callable; import com.google.gson.Gson; class ExtractFeaturesTask implements Callable { private final CommandLineValues commandLineValues; private final Path filePath; public ExtractFeaturesTask(CommandLineValues commandLineValues, Path path) { 
this.commandLineValues = commandLineValues; this.filePath = path; } @Override public Void call() { processFile(); return null; } public void processFile() { ArrayList features; try { features = extractSingleFile(); } catch (IOException e) { e.printStackTrace(); return; } if (features == null) { return; } String toPrint = featuresToString(features); if (toPrint.length() > 0) { System.out.println(toPrint); } } private ArrayList extractSingleFile() throws IOException { String code; if (commandLineValues.MaxFileLength > 0 && Files.lines(filePath, Charset.defaultCharset()).count() > commandLineValues.MaxFileLength) { return new ArrayList<>(); } try { code = new String(Files.readAllBytes(filePath)); } catch (IOException e) { e.printStackTrace(); code = Common.EmptyString; } FeatureExtractor featureExtractor = new FeatureExtractor(commandLineValues, this.filePath); return featureExtractor.extractFeatures(code); } public String featuresToString(ArrayList features) { if (features == null || features.isEmpty()) { return Common.EmptyString; } List methodsOutputs = new ArrayList<>(); for (ProgramFeatures singleMethodFeatures : features) { StringBuilder builder = new StringBuilder(); String toPrint; if (commandLineValues.JsonOutput) { toPrint = new Gson().toJson(singleMethodFeatures); } else { toPrint = singleMethodFeatures.toString(); } if (commandLineValues.PrettyPrint) { toPrint = toPrint.replace(" ", "\n\t"); } builder.append(toPrint); methodsOutputs.add(builder.toString()); } return StringUtils.join(methodsOutputs, "\n"); } } ================================================ FILE: JavaExtractor/JPredict/src/main/java/JavaExtractor/FeatureExtractor.java ================================================ package JavaExtractor; import JavaExtractor.Common.CommandLineValues; import JavaExtractor.Common.Common; import JavaExtractor.Common.MethodContent; import JavaExtractor.FeaturesEntities.ProgramFeatures; import JavaExtractor.FeaturesEntities.Property; import 
JavaExtractor.Visitors.FunctionVisitor; import com.github.javaparser.JavaParser; import com.github.javaparser.ParseProblemException; import com.github.javaparser.ast.CompilationUnit; import com.github.javaparser.ast.Node; import java.io.File; import java.nio.file.Path; import java.util.ArrayList; import java.util.HashSet; import java.util.Set; import java.util.StringJoiner; import java.util.stream.Collectors; import java.util.stream.Stream; @SuppressWarnings("StringEquality") class FeatureExtractor { private final static String upSymbol = "|"; private final static String downSymbol = "|"; private static final Set s_ParentTypeToAddChildId = Stream .of("AssignExpr", "ArrayAccessExpr", "FieldAccessExpr", "MethodCallExpr") .collect(Collectors.toCollection(HashSet::new)); private final CommandLineValues m_CommandLineValues; private final Path filePath; public FeatureExtractor(CommandLineValues commandLineValues, Path filePath) { this.m_CommandLineValues = commandLineValues; this.filePath = filePath; } private static ArrayList getTreeStack(Node node) { ArrayList upStack = new ArrayList<>(); Node current = node; while (current != null) { upStack.add(current); current = current.getParentNode(); } return upStack; } public ArrayList extractFeatures(String code) { CompilationUnit m_CompilationUnit = parseFileWithRetries(code); FunctionVisitor functionVisitor = new FunctionVisitor(m_CommandLineValues); functionVisitor.visit(m_CompilationUnit, null); ArrayList methods = functionVisitor.getMethodContents(); return generatePathFeatures(methods); } private CompilationUnit parseFileWithRetries(String code) { final String classPrefix = "public class Test {"; final String classSuffix = "}"; final String methodPrefix = "SomeUnknownReturnType f() {"; final String methodSuffix = "return noSuchReturnValue; }"; String content = code; CompilationUnit parsed; try { parsed = JavaParser.parse(content); } catch (ParseProblemException e1) { // Wrap with a class and method try { content = 
classPrefix + methodPrefix + code + methodSuffix + classSuffix; parsed = JavaParser.parse(content); } catch (ParseProblemException e2) { // Wrap with a class only content = classPrefix + code + classSuffix; parsed = JavaParser.parse(content); } } return parsed; } private ArrayList generatePathFeatures(ArrayList methods) { ArrayList methodsFeatures = new ArrayList<>(); for (MethodContent content : methods) { ProgramFeatures singleMethodFeatures = generatePathFeaturesForFunction(content); if (!singleMethodFeatures.isEmpty()) { methodsFeatures.add(singleMethodFeatures); } } return methodsFeatures; } private ProgramFeatures generatePathFeaturesForFunction(MethodContent methodContent) { ArrayList functionLeaves = methodContent.getLeaves(); ProgramFeatures programFeatures = new ProgramFeatures( methodContent.getName(), this.filePath, methodContent.getContent()); for (int i = 0; i < functionLeaves.size(); i++) { for (int j = i + 1; j < functionLeaves.size(); j++) { String separator = Common.EmptyString; String path = generatePath(functionLeaves.get(i), functionLeaves.get(j), separator); if (path != Common.EmptyString) { Property source = functionLeaves.get(i).getUserData(Common.PropertyKey); Property target = functionLeaves.get(j).getUserData(Common.PropertyKey); programFeatures.addFeature(source, path, target); } } } return programFeatures; } private String generatePath(Node source, Node target, String separator) { StringJoiner stringBuilder = new StringJoiner(separator); ArrayList sourceStack = getTreeStack(source); ArrayList targetStack = getTreeStack(target); int commonPrefix = 0; int currentSourceAncestorIndex = sourceStack.size() - 1; int currentTargetAncestorIndex = targetStack.size() - 1; while (currentSourceAncestorIndex >= 0 && currentTargetAncestorIndex >= 0 && sourceStack.get(currentSourceAncestorIndex) == targetStack.get(currentTargetAncestorIndex)) { commonPrefix++; currentSourceAncestorIndex--; currentTargetAncestorIndex--; } int pathLength = 
sourceStack.size() + targetStack.size() - 2 * commonPrefix; if (pathLength > m_CommandLineValues.MaxPathLength) { return Common.EmptyString; } if (currentSourceAncestorIndex >= 0 && currentTargetAncestorIndex >= 0) { int pathWidth = targetStack.get(currentTargetAncestorIndex).getUserData(Common.ChildId) - sourceStack.get(currentSourceAncestorIndex).getUserData(Common.ChildId); if (pathWidth > m_CommandLineValues.MaxPathWidth) { return Common.EmptyString; } } for (int i = 0; i < sourceStack.size() - commonPrefix; i++) { Node currentNode = sourceStack.get(i); String childId = Common.EmptyString; String parentRawType = currentNode.getParentNode().getUserData(Common.PropertyKey).getRawType(); if (i == 0 || s_ParentTypeToAddChildId.contains(parentRawType)) { childId = saturateChildId(currentNode.getUserData(Common.ChildId)) .toString(); } stringBuilder.add(String.format("%s%s%s", currentNode.getUserData(Common.PropertyKey).getType(true), childId, upSymbol)); } Node commonNode = sourceStack.get(sourceStack.size() - commonPrefix); String commonNodeChildId = Common.EmptyString; Property parentNodeProperty = commonNode.getParentNode().getUserData(Common.PropertyKey); String commonNodeParentRawType = Common.EmptyString; if (parentNodeProperty != null) { commonNodeParentRawType = parentNodeProperty.getRawType(); } if (s_ParentTypeToAddChildId.contains(commonNodeParentRawType)) { commonNodeChildId = saturateChildId(commonNode.getUserData(Common.ChildId)) .toString(); } stringBuilder.add(String.format("%s%s", commonNode.getUserData(Common.PropertyKey).getType(true), commonNodeChildId)); for (int i = targetStack.size() - commonPrefix - 1; i >= 0; i--) { Node currentNode = targetStack.get(i); String childId = Common.EmptyString; if (i == 0 || s_ParentTypeToAddChildId.contains(currentNode.getUserData(Common.PropertyKey).getRawType())) { childId = saturateChildId(currentNode.getUserData(Common.ChildId)) .toString(); } stringBuilder.add(String.format("%s%s%s", downSymbol, 
currentNode.getUserData(Common.PropertyKey).getType(true), childId)); } return stringBuilder.toString(); } private Integer saturateChildId(int childId) { return Math.min(childId, m_CommandLineValues.MaxChildId); } } ================================================ FILE: JavaExtractor/JPredict/src/main/java/JavaExtractor/FeaturesEntities/ProgramFeatures.java ================================================ package JavaExtractor.FeaturesEntities; import java.nio.file.Path; import java.util.ArrayList; import java.util.stream.Collectors; public class ProgramFeatures { String name; transient ArrayList features = new ArrayList<>(); String textContent; String filePath; public ProgramFeatures(String name, Path filePath, String textContent) { this.name = name; this.filePath = filePath.toAbsolutePath().toString(); this.textContent = textContent; } @SuppressWarnings("StringBufferReplaceableByString") @Override public String toString() { StringBuilder stringBuilder = new StringBuilder(); stringBuilder.append(name).append(" "); stringBuilder.append(features.stream().map(ProgramRelation::toString).collect(Collectors.joining(" "))); return stringBuilder.toString(); } public void addFeature(Property source, String path, Property target) { ProgramRelation newRelation = new ProgramRelation(source, target, path); features.add(newRelation); } public boolean isEmpty() { return features.isEmpty(); } } ================================================ FILE: JavaExtractor/JPredict/src/main/java/JavaExtractor/FeaturesEntities/ProgramRelation.java ================================================ package JavaExtractor.FeaturesEntities; public class ProgramRelation { Property source; Property target; String path; public ProgramRelation(Property sourceName, Property targetName, String path) { source = sourceName; target = targetName; this.path = path; } public String toString() { return String.format("%s,%s,%s", source.getName(), path, target.getName()); } } 
================================================ FILE: JavaExtractor/JPredict/src/main/java/JavaExtractor/FeaturesEntities/Property.java ================================================ package JavaExtractor.FeaturesEntities; import JavaExtractor.Common.Common; import com.github.javaparser.ast.Node; import com.github.javaparser.ast.expr.AssignExpr; import com.github.javaparser.ast.expr.BinaryExpr; import com.github.javaparser.ast.expr.IntegerLiteralExpr; import com.github.javaparser.ast.expr.UnaryExpr; import com.github.javaparser.ast.type.ClassOrInterfaceType; import java.util.*; import java.util.stream.Collectors; import java.util.stream.Stream; public class Property { public static final HashSet NumericalKeepValues = Stream.of("0", "1", "32", "64") .collect(Collectors.toCollection(HashSet::new)); private static final Map shortTypes = Collections.unmodifiableMap(new HashMap() { /** * */ private static final long serialVersionUID = 1L; { put("ArrayAccessExpr", "ArAc"); put("ArrayBracketPair", "ArBr"); put("ArrayCreationExpr", "ArCr"); put("ArrayCreationLevel", "ArCrLvl"); put("ArrayInitializerExpr", "ArIn"); put("ArrayType", "ArTy"); put("AssertStmt", "Asrt"); put("AssignExpr:and", "AsAn"); put("AssignExpr:assign", "As"); put("AssignExpr:lShift", "AsLS"); put("AssignExpr:minus", "AsMi"); put("AssignExpr:or", "AsOr"); put("AssignExpr:plus", "AsP"); put("AssignExpr:rem", "AsRe"); put("AssignExpr:rSignedShift", "AsRSS"); put("AssignExpr:rUnsignedShift", "AsRUS"); put("AssignExpr:slash", "AsSl"); put("AssignExpr:star", "AsSt"); put("AssignExpr:xor", "AsX"); put("BinaryExpr:and", "And"); put("BinaryExpr:binAnd", "BinAnd"); put("BinaryExpr:binOr", "BinOr"); put("BinaryExpr:divide", "Div"); put("BinaryExpr:equals", "Eq"); put("BinaryExpr:greater", "Gt"); put("BinaryExpr:greaterEquals", "Geq"); put("BinaryExpr:less", "Ls"); put("BinaryExpr:lessEquals", "Leq"); put("BinaryExpr:lShift", "LS"); put("BinaryExpr:minus", "Minus"); put("BinaryExpr:notEquals", "Neq"); 
put("BinaryExpr:or", "Or"); put("BinaryExpr:plus", "Plus"); put("BinaryExpr:remainder", "Mod"); put("BinaryExpr:rSignedShift", "RSS"); put("BinaryExpr:rUnsignedShift", "RUS"); put("BinaryExpr:times", "Mul"); put("BinaryExpr:xor", "Xor"); put("BlockStmt", "Bk"); put("BooleanLiteralExpr", "BoolEx"); put("CastExpr", "Cast"); put("CatchClause", "Catch"); put("CharLiteralExpr", "CharEx"); put("ClassExpr", "ClsEx"); put("ClassOrInterfaceDeclaration", "ClsD"); put("ClassOrInterfaceType", "Cls"); put("ConditionalExpr", "Cond"); put("ConstructorDeclaration", "Ctor"); put("DoStmt", "Do"); put("DoubleLiteralExpr", "Dbl"); put("EmptyMemberDeclaration", "Emp"); put("EnclosedExpr", "Enc"); put("ExplicitConstructorInvocationStmt", "ExpCtor"); put("ExpressionStmt", "Ex"); put("FieldAccessExpr", "Fld"); put("FieldDeclaration", "FldDec"); put("ForeachStmt", "Foreach"); put("ForStmt", "For"); put("IfStmt", "If"); put("InitializerDeclaration", "Init"); put("InstanceOfExpr", "InstanceOf"); put("IntegerLiteralExpr", "IntEx"); put("IntegerLiteralMinValueExpr", "IntMinEx"); put("LabeledStmt", "Labeled"); put("LambdaExpr", "Lambda"); put("LongLiteralExpr", "LongEx"); put("MarkerAnnotationExpr", "MarkerExpr"); put("MemberValuePair", "Mvp"); put("MethodCallExpr", "Cal"); put("MethodDeclaration", "Mth"); put("MethodReferenceExpr", "MethRef"); put("NameExpr", "Nm"); put("NormalAnnotationExpr", "NormEx"); put("NullLiteralExpr", "Null"); put("ObjectCreationExpr", "ObjEx"); put("Parameter", "Prm"); put("PrimitiveType", "Prim"); put("QualifiedNameExpr", "Qua"); put("ReturnStmt", "Ret"); put("SingleMemberAnnotationExpr", "SMEx"); put("StringLiteralExpr", "StrEx"); put("SuperExpr", "SupEx"); put("SwitchEntryStmt", "SwiEnt"); put("SwitchStmt", "Switch"); put("SynchronizedStmt", "Sync"); put("ThisExpr", "This"); put("ThrowStmt", "Thro"); put("TryStmt", "Try"); put("TypeDeclarationStmt", "TypeDec"); put("TypeExpr", "Type"); put("TypeParameter", "TypePar"); put("UnaryExpr:inverse", "Inverse"); 
put("UnaryExpr:negative", "Neg"); put("UnaryExpr:not", "Not"); put("UnaryExpr:posDecrement", "PosDec"); put("UnaryExpr:posIncrement", "PosInc"); put("UnaryExpr:positive", "Pos"); put("UnaryExpr:preDecrement", "PreDec"); put("UnaryExpr:preIncrement", "PreInc"); put("UnionType", "Unio"); put("VariableDeclarationExpr", "VDE"); put("VariableDeclarator", "VD"); put("VariableDeclaratorId", "VDID"); put("VoidType", "Void"); put("WhileStmt", "While"); put("WildcardType", "Wild"); } }); private final String RawType; private String Type; private String SplitName; public Property(Node node, boolean isLeaf, boolean isGenericParent) { Class nodeClass = node.getClass(); RawType = Type = nodeClass.getSimpleName(); if (node instanceof ClassOrInterfaceType && ((ClassOrInterfaceType) node).isBoxedType()) { Type = "PrimitiveType"; } String operator = ""; if (node instanceof BinaryExpr) { operator = ((BinaryExpr) node).getOperator().toString(); } else if (node instanceof UnaryExpr) { operator = ((UnaryExpr) node).getOperator().toString(); } else if (node instanceof AssignExpr) { operator = ((AssignExpr) node).getOperator().toString(); } if (operator.length() > 0) { Type += ":" + operator; } String nameToSplit = node.toString(); if (isGenericParent) { nameToSplit = ((ClassOrInterfaceType) node).getName(); if (isLeaf) { // if it is a generic parent which counts as a leaf, then when // it is participating in a path // as a parent, it should be GenericClass and not a simple // ClassOrInterfaceType. 
Type = "GenericClass";
        }
    }
    ArrayList<String> splitNameParts = Common.splitToSubtokens(nameToSplit);
    SplitName = String.join(Common.internalSeparator, splitNameParts);

    String name = Common.normalizeName(node.toString(), Common.BlankWord);
    if (name.length() > Common.c_MaxLabelLength) {
        name = name.substring(0, Common.c_MaxLabelLength);
    } else if (node instanceof ClassOrInterfaceType && ((ClassOrInterfaceType) node).isBoxedType()) {
        name = ((ClassOrInterfaceType) node).toUnboxedType().toString();
    }

    if (Common.isMethod(node, Type)) {
        name = SplitName = Common.methodName;
    }

    if (SplitName.length() == 0) {
        SplitName = name;
        if (node instanceof IntegerLiteralExpr && !NumericalKeepValues.contains(SplitName)) {
            // This is a numeric literal, but not in our white list
            SplitName = "";
        }
    }
}

public String getRawType() {
    return RawType;
}

public String getType() {
    return Type;
}

public String getType(boolean shorten) {
    if (shorten) {
        return shortTypes.getOrDefault(Type, Type);
    } else {
        return Type;
    }
}

public String getName() {
    return SplitName;
}
}

================================================ FILE: JavaExtractor/JPredict/src/main/java/JavaExtractor/Visitors/FunctionVisitor.java ================================================
package JavaExtractor.Visitors;

import JavaExtractor.Common.CommandLineValues;
import JavaExtractor.Common.Common;
import JavaExtractor.Common.MethodContent;
import com.github.javaparser.ast.Node;
import com.github.javaparser.ast.body.MethodDeclaration;
import com.github.javaparser.ast.visitor.VoidVisitorAdapter;

import java.util.ArrayList;
import java.util.Arrays;

/**
 * Collects every method declaration in the compilation unit, together with
 * its leaves (from LeavesCollectorVisitor), its sub-tokenized name, and its
 * source text, subject to the min/max code-length options.
 */
public class FunctionVisitor extends VoidVisitorAdapter<Object> {
    private final ArrayList<MethodContent> methods = new ArrayList<>();
    private final CommandLineValues commandLineValues;

    public FunctionVisitor(CommandLineValues commandLineValues) {
        this.commandLineValues = commandLineValues;
    }

    @Override
    public void visit(MethodDeclaration node, Object arg) {
        visitMethod(node);
        super.visit(node, arg);
    }

    private void visitMethod(MethodDeclaration node) {
        LeavesCollectorVisitor leavesCollectorVisitor = new LeavesCollectorVisitor();
        leavesCollectorVisitor.visitDepthFirst(node);
        ArrayList<Node> leaves = leavesCollectorVisitor.getLeaves();

        String normalizedMethodName = Common.normalizeName(node.getName(), Common.BlankWord);
        ArrayList<String> splitNameParts = Common.splitToSubtokens(node.getName());
        String splitName = normalizedMethodName;
        if (splitNameParts.size() > 0) {
            splitName = String.join(Common.internalSeparator, splitNameParts);
        }
        // Replace the real name with the METHOD_NAME placeholder so the label
        // being predicted never appears in the method body's contexts.
        node.setName(Common.methodName);
        if (node.getBody() != null) {
            long methodLength = getMethodLength(node.getBody().toString());
            if (commandLineValues.MaxCodeLength <= 0 ||
                    (methodLength >= commandLineValues.MinCodeLength && methodLength <= commandLineValues.MaxCodeLength)) {
                methods.add(new MethodContent(leaves, splitName, node.toString()));
            }
        }
    }

    /** Count the effective code lines of a method body (its pretty-printed text). */
    private long getMethodLength(String code) {
        String cleanCode = code.replaceAll("\r\n", "\n").replaceAll("\t", " ");
        if (cleanCode.startsWith("{\n"))
            cleanCode = cleanCode.substring(3).trim();
        if (cleanCode.endsWith("\n}"))
            cleanCode = cleanCode.substring(0, cleanCode.length() - 2).trim();
        if (cleanCode.length() == 0) {
            return 0;
        }
        // BUG FIX: the original compared strings with != ("line.trim() != \"{\"").
        // Since trim() allocates a new String, reference comparison was always
        // true and blank/brace-only lines were never filtered, inflating the
        // length used by the --min/max_code_len options. Use equals()/isEmpty().
        return Arrays.stream(cleanCode.split("\n"))
                .map(String::trim)
                .filter(line -> !line.equals("{") && !line.equals("}") && !line.isEmpty())
                .filter(line -> !line.startsWith("/") && !line.startsWith("*")).count();
    }

    public ArrayList<MethodContent> getMethodContents() {
        return methods;
    }
}

================================================ FILE: JavaExtractor/JPredict/src/main/java/JavaExtractor/Visitors/LeavesCollectorVisitor.java ================================================
package JavaExtractor.Visitors;

import JavaExtractor.Common.Common;
import JavaExtractor.FeaturesEntities.Property;
import com.github.javaparser.ast.Node;
import com.github.javaparser.ast.comments.Comment;
import com.github.javaparser.ast.expr.NullLiteralExpr;
import
com.github.javaparser.ast.stmt.Statement; import com.github.javaparser.ast.type.ClassOrInterfaceType; import com.github.javaparser.ast.visitor.TreeVisitor; import java.util.ArrayList; import java.util.List; public class LeavesCollectorVisitor extends TreeVisitor { private final ArrayList m_Leaves = new ArrayList<>(); @Override public void process(Node node) { if (node instanceof Comment) { return; } boolean isLeaf = false; boolean isGenericParent = isGenericParent(node); if (hasNoChildren(node) && isNotComment(node)) { if (!node.toString().isEmpty() && (!"null".equals(node.toString()) || (node instanceof NullLiteralExpr))) { m_Leaves.add(node); isLeaf = true; } } int childId = getChildId(node); node.setUserData(Common.ChildId, childId); Property property = new Property(node, isLeaf, isGenericParent); node.setUserData(Common.PropertyKey, property); } private boolean isGenericParent(Node node) { return (node instanceof ClassOrInterfaceType) && ((ClassOrInterfaceType) node).getTypeArguments() != null && ((ClassOrInterfaceType) node).getTypeArguments().size() > 0; } private boolean hasNoChildren(Node node) { return node.getChildrenNodes().size() == 0; } private boolean isNotComment(Node node) { return !(node instanceof Comment) && !(node instanceof Statement); } public ArrayList getLeaves() { return m_Leaves; } private int getChildId(Node node) { Node parent = node.getParentNode(); List parentsChildren = parent.getChildrenNodes(); int childId = 0; for (Node child : parentsChildren) { if (child.getRange().equals(node.getRange())) { return childId; } childId++; } return childId; } } ================================================ FILE: JavaExtractor/JPredict/src/main/java/Test.java ================================================ class Test { void fooBar() { System.out.println("http://github.com"); } } ================================================ FILE: JavaExtractor/extract.py ================================================ #!/usr/bin/python import itertools import 
multiprocessing
import os
import shutil
import subprocess
import sys
from argparse import ArgumentParser
from threading import Timer


def get_immediate_subdirectories(a_dir):
    """Return the full paths of the direct subdirectories of `a_dir`."""
    return [os.path.join(a_dir, name) for name in os.listdir(a_dir)
            if os.path.isdir(os.path.join(a_dir, name))]


TMP_DIR = ""


def ParallelExtractDir(args, dir):
    # Top-level wrapper so multiprocessing.Pool.starmap can pickle the callable.
    ExtractFeaturesForDir(args, dir, "")


def ExtractFeaturesForDir(args, dir, prefix):
    """Run the Java extractor on `dir`, writing its stdout to a temp file.

    Recurses into immediate subdirectories; removes the output file when the
    extractor times out (1 hour) or exits with a non-zero status.
    """
    command = ['java', '-Xmx100g', '-XX:MaxNewSize=60g', '-cp', args.jar,
               'JavaExtractor.App',
               '--max_path_length', str(args.max_path_length),
               '--max_path_width', str(args.max_path_width),
               '--dir', dir, '--num_threads', str(args.num_threads)]
    kill = lambda process: process.kill()
    outputFileName = TMP_DIR + prefix + dir.split('/')[-1]
    failed = False
    with open(outputFileName, 'a') as outputFile:
        sleeper = subprocess.Popen(command, stdout=outputFile, stderr=subprocess.PIPE)
        # Hard 1-hour timeout per directory; the timer kills the JVM process.
        timer = Timer(60 * 60, kill, [sleeper])
        try:
            timer.start()
            stdout, stderr = sleeper.communicate()
        finally:
            timer.cancel()

        if sleeper.poll() == 0:
            if len(stderr) > 0:
                print(stderr, file=sys.stderr)
        else:
            print('dir: ' + str(dir) + ' was not completed in time', file=sys.stderr)
            failed = True
            subdirs = get_immediate_subdirectories(dir)
            for subdir in subdirs:
                ExtractFeaturesForDir(args, subdir, prefix + dir.split('/')[-1] + '_')
    if failed:
        if os.path.exists(outputFileName):
            os.remove(outputFileName)


def ExtractFeaturesForDirsList(args, dirs):
    """Extract features for every directory in `dirs` and stream results to stdout."""
    global TMP_DIR
    TMP_DIR = "./tmp/feature_extractor%d/" % (os.getpid())
    if os.path.exists(TMP_DIR):
        shutil.rmtree(TMP_DIR, ignore_errors=True)
    os.makedirs(TMP_DIR)
    try:
        # NOTE(review): pool size is hard-coded to 6 and does not follow
        # args.num_threads (which is passed to the JVM instead) — confirm intent.
        p = multiprocessing.Pool(6)
        p.starmap(ParallelExtractDir, zip(itertools.repeat(args), dirs))
        output_files = os.listdir(TMP_DIR)
        for f in output_files:
            # Bug fix: was os.system("cat %s/%s" % ...), which passed
            # unsanitized directory-derived names through a shell (breaks on
            # spaces/metacharacters). Copy the file to stdout directly instead.
            with open(os.path.join(TMP_DIR, f), 'r') as result_file:
                shutil.copyfileobj(result_file, sys.stdout)
    finally:
        shutil.rmtree(TMP_DIR, ignore_errors=True)


if __name__ == '__main__':
    parser = ArgumentParser()
parser.add_argument("-maxlen", "--max_path_length", dest="max_path_length", required=False, default=8) parser.add_argument("-maxwidth", "--max_path_width", dest="max_path_width", required=False, default=2) parser.add_argument("-threads", "--num_threads", dest="num_threads", required=False, default=64) parser.add_argument("-j", "--jar", dest="jar", required=True) parser.add_argument("-dir", "--dir", dest="dir", required=False) parser.add_argument("-file", "--file", dest="file", required=False) args = parser.parse_args() if args.file is not None: command = 'java -cp ' + args.jar + ' JavaExtractor.App --max_path_length ' + \ str(args.max_path_length) + ' --max_path_width ' + str(args.max_path_width) + ' --file ' + args.file os.system(command) elif args.dir is not None: subdirs = get_immediate_subdirectories(args.dir) if len(subdirs) == 0: subdirs = [args.dir] ExtractFeaturesForDirsList(args, subdirs) ================================================ FILE: LICENSE ================================================ MIT License Copyright (c) 2019 Technion Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. ================================================ FILE: Python150kExtractor/README.md ================================================ # Python150k dataset ## Steps to reproduce 1. Download parsed python dataset from [here](https://www.sri.inf.ethz.ch/py150 ), unarchive and place under `PYTHON150K_DIR`: ```bash # Replace with desired path. >>> PYTHON150K_DIR=/path/to/data/dir >>> mkdir -p $PYTHON150K_DIR >>> cd $PYTHON150K_DIR >>> wget http://files.srl.inf.ethz.ch/data/py150.tar.gz ... >>> tar -xzvf py150.tar.gz ... ``` 2. Extract samples to `DATA_DIR`: ```bash # Replace with desired path. >>> DATA_DIR=$(pwd)/data/default >>> SEED=239 >>> python extract.py \ --data_dir=$PYTHON150K_DIR \ --output_dir=$DATA_DIR \ --seed=$SEED ... ``` 3. Preprocess for training: ```bash >>> ./preprocess.sh $DATA_DIR ... ``` 4. Train: ```bash >>> cd .. >>> DESC=default >>> CUDA=0 >>> ./train_python150k.sh $DATA_DIR $DESC $CUDA $SEED ... 
``` ## Test results (seed=239) ### Best scores **setup#2**: `batch_size=64` **setup#3**: `embedding_size=256,use_momentum=False` **setup#4**: `batch_size=32,embedding_size=256,embeddings_dropout_keep_prob=0.5,use_momentum=False` | params | Precision | Recall | F1 | ROUGE-2 | ROUGE-L | |---|---|---|---|---|---| | default | 0.37 | 0.27 | 0.31 | 0.06 | 0.38 | | setup#2 | 0.40 | 0.31 | 0.34 | 0.08 | 0.41 | | setup#3 | 0.36 | 0.31 | 0.33 | 0.09 | 0.38 | | setup#4 | 0.33 | 0.25 | 0.28 | 0.05 | 0.34 | ### Ablation studies | params | Precision | Recall | F1 | ROUGE-2 | ROUGE-L | |---|---|---|---|---|---| | default | 0.37 | 0.27 | 0.31 | 0.06 | 0.38 | | no ast nodes (5th epoch) | 0.27 | 0.16 | 0.20 | 0.02 | 0.28 | | no token split (4th epoch) | 0.60 | 0.09 | 0.15 | 0.00 | 0.60 | ================================================ FILE: Python150kExtractor/extract.py ================================================ import argparse import re import json import multiprocessing import itertools import tqdm import joblib import numpy as np from pathlib import Path from sklearn import model_selection as sklearn_model_selection METHOD_NAME, NUM = 'METHODNAME', 'NUM' parser = argparse.ArgumentParser() parser.add_argument('--data_dir', required=True, type=str) parser.add_argument('--valid_p', type=float, default=0.2) parser.add_argument('--max_path_length', type=int, default=8) parser.add_argument('--max_path_width', type=int, default=2) parser.add_argument('--use_method_name', type=bool, default=True) parser.add_argument('--use_nums', type=bool, default=True) parser.add_argument('--output_dir', required=True, type=str) parser.add_argument('--n_jobs', type=int, default=multiprocessing.cpu_count()) parser.add_argument('--seed', type=int, default=239) def __collect_asts(json_file): with open(json_file, 'r', encoding='utf-8') as f: for line in tqdm.tqdm(f): yield line def __terminals(ast, node_index, args): stack, paths = [], [] def dfs(v): stack.append(v) v_node = ast[v] if 'value' in 
v_node: if v == node_index: # Top-level func def node. if args.use_method_name: paths.append((stack.copy(), METHOD_NAME)) else: v_type = v_node['type'] if v_type.startswith('Name'): paths.append((stack.copy(), v_node['value'])) elif args.use_nums and v_type == 'Num': paths.append((stack.copy(), NUM)) else: pass if 'children' in v_node: for child in v_node['children']: dfs(child) stack.pop() dfs(node_index) return paths def __merge_terminals2_paths(v_path, u_path): s, n, m = 0, len(v_path), len(u_path) while s < min(n, m) and v_path[s] == u_path[s]: s += 1 prefix = list(reversed(v_path[s:])) lca = v_path[s - 1] suffix = u_path[s:] return prefix, lca, suffix def __raw_tree_paths(ast, node_index, args): tnodes = __terminals(ast, node_index, args) tree_paths = [] for (v_path, v_value), (u_path, u_value) in itertools.combinations( iterable=tnodes, r=2, ): prefix, lca, suffix = __merge_terminals2_paths(v_path, u_path) if (len(prefix) + 1 + len(suffix) <= args.max_path_length) \ and (abs(len(prefix) - len(suffix)) <= args.max_path_width): path = prefix + [lca] + suffix tree_path = v_value, path, u_value tree_paths.append(tree_path) return tree_paths def __delim_name(name): if name in {METHOD_NAME, NUM}: return name def camel_case_split(identifier): matches = re.finditer( '.+?(?:(?<=[a-z])(?=[A-Z])|(?<=[A-Z])(?=[A-Z][a-z])|$)', identifier, ) return [m.group(0) for m in matches] blocks = [] for underscore_block in name.split('_'): blocks.extend(camel_case_split(underscore_block)) return '|'.join(block.lower() for block in blocks) def __collect_sample(ast, fd_index, args): root = ast[fd_index] if root['type'] != 'FunctionDef': raise ValueError('Wrong node type.') target = root['value'] tree_paths = __raw_tree_paths(ast, fd_index, args) contexts = [] for tree_path in tree_paths: start, connector, finish = tree_path start, finish = __delim_name(start), __delim_name(finish) connector = '|'.join(ast[v]['type'] for v in connector) context = f'{start},{connector},{finish}' 
contexts.append(context) if len(contexts) == 0: return None target = __delim_name(target) context = ' '.join(contexts) return f'{target} {context}' def __collect_samples(ast, args): samples = [] for node_index, node in enumerate(ast): if node['type'] == 'FunctionDef': sample = __collect_sample(ast, node_index, args) if sample is not None: samples.append(sample) return samples def __collect_all_and_save(asts, args, output_file): parallel = joblib.Parallel(n_jobs=args.n_jobs) func = joblib.delayed(__collect_samples) samples = parallel(func(ast, args) for ast in tqdm.tqdm(asts)) samples = list(itertools.chain.from_iterable(samples)) with open(output_file, 'w') as f: for line_index, line in enumerate(samples): f.write(line + ('' if line_index == len(samples) - 1 else '\n')) def main(): args = parser.parse_args() np.random.seed(args.seed) data_dir = Path(args.data_dir) trains = list(__collect_asts(data_dir / 'python100k_train.json')) evals = list(__collect_asts(data_dir / 'python50k_eval.json')) train, valid = sklearn_model_selection.train_test_split( trains, test_size=args.valid_p, ) test = evals output_dir = Path(args.output_dir) output_dir.mkdir(exist_ok=True) for split_name, split in zip( ('train', 'valid', 'test'), (train, valid, test), ): output_file = output_dir / f'{split_name}_output_file.txt' __collect_all_and_save((json.loads(line) for line in split), args, output_file) if __name__ == '__main__': main() ================================================ FILE: Python150kExtractor/preprocess.sh ================================================ #!/usr/bin/env bash MAX_CONTEXTS=200 MAX_DATA_CONTEXTS=1000 SUBTOKEN_VOCAB_SIZE=186277 TARGET_VOCAB_SIZE=26347 data_dir=${1:-data} mkdir -p "${data_dir}" train_data_file=$data_dir/train_output_file.txt valid_data_file=$data_dir/valid_output_file.txt test_data_file=$data_dir/test_output_file.txt echo "Creating histograms from the training data..." 
target_histogram_file=$data_dir/histo.tgt.c2s source_subtoken_histogram=$data_dir/histo.ori.c2s node_histogram_file=$data_dir/histo.node.c2s cut <"${train_data_file}" -d' ' -f1 | tr '|' '\n' | awk '{n[$0]++} END {for (i in n) print i,n[i]}' >"${target_histogram_file}" cut <"${train_data_file}" -d' ' -f2- | tr ' ' '\n' | cut -d',' -f1,3 | tr ',|' '\n' | awk '{n[$0]++} END {for (i in n) print i,n[i]}' >"${source_subtoken_histogram}" cut <"${train_data_file}" -d' ' -f2- | tr ' ' '\n' | cut -d',' -f2 | tr '|' '\n' | awk '{n[$0]++} END {for (i in n) print i,n[i]}' >"${node_histogram_file}" echo "Preprocessing..." python ../preprocess.py \ --train_data "${train_data_file}" \ --val_data "${valid_data_file}" \ --test_data "${test_data_file}" \ --max_contexts ${MAX_CONTEXTS} \ --max_data_contexts ${MAX_DATA_CONTEXTS} \ --subtoken_vocab_size ${SUBTOKEN_VOCAB_SIZE} \ --target_vocab_size ${TARGET_VOCAB_SIZE} \ --target_histogram "${target_histogram_file}" \ --subtoken_histogram "${source_subtoken_histogram}" \ --node_histogram "${node_histogram_file}" \ --output_name "${data_dir}"/"$(basename "${data_dir}")" rm \ "${target_histogram_file}" \ "${source_subtoken_histogram}" \ "${node_histogram_file}" ================================================ FILE: README.md ================================================ # code2seq This is an official implementation of the model described in: [Uri Alon](http://urialon.cswp.cs.technion.ac.il), [Shaked Brody](http://www.cs.technion.ac.il/people/shakedbr/), [Omer Levy](https://levyomer.wordpress.com) and [Eran Yahav](http://www.cs.technion.ac.il/~yahave/), "code2seq: Generating Sequences from Structured Representations of Code" [[PDF]](https://openreview.net/pdf?id=H1gKYo09tX) Appeared in **ICLR'2019** (**poster** available [here](https://urialon.cswp.cs.technion.ac.il/wp-content/uploads/sites/83/2019/05/ICLR19_poster_code2seq.pdf)) An **online demo** is available at [https://code2seq.org](https://code2seq.org). 
This is a TensorFlow implementation of the network, with Java and C# extractors for preprocessing the input code. It can be easily extended to other languages, since the TensorFlow network is agnostic to the input programming language (see [Extending to other languages](#extending-to-other-languages)). Contributions are welcome.
## See also: * **Structural Language Models for Code** (ICML'2020) is a new paper that learns to generate the missing code within a larger code snippet. This is similar to code completion, but is able to predict complex expressions rather than a single token at a time. See [PDF](https://arxiv.org/pdf/1910.00577.pdf), demo at [http://AnyCodeGen.org](http://AnyCodeGen.org). * **Adversarial Examples for Models of Code** is a new paper that shows how to slightly mutate the input code snippet of code2vec and GNNs models (thus, introducing adversarial examples), such that the model (code2vec or GNNs) will output a prediction of our choice. See [PDF](https://arxiv.org/pdf/1910.07517.pdf) (code: soon). * **Neural Reverse Engineering of Stripped Binaries** is a new paper that learns to predict procedure names in stripped binaries, thus use neural networks for reverse engineering. See [PDF](https://arxiv.org/pdf/1902.09122) (code: soon). * **code2vec** (POPL'2019) is our previous model. It can only generate a single label at a time (rather than a sequence as code2seq), but it is much faster to train (because of its simplicity). See [PDF](https://urialon.cswp.cs.technion.ac.il/wp-content/uploads/sites/83/2018/12/code2vec-popl19.pdf), demo at [https://code2vec.org](https://code2vec.org) and [code](https://github.com/tech-srl/code2vec/). Table of Contents ================= * [Requirements](#requirements) * [Quickstart](#quickstart) * [Configuration](#configuration) * [Releasing a trained mode](#releasing-a-trained-model) * [Extending to other languages](#extending-to-other-languages) * [Datasets](#datasets) * [Baselines](#baselines) * [Citation](#citation) ## Requirements * [python3](https://www.linuxbabe.com/ubuntu/install-python-3-6-ubuntu-16-04-16-10-17-04) * TensorFlow 1.12 ([install](https://www.tensorflow.org/install/install_linux)). 
To check TensorFlow version: > python3 -c 'import tensorflow as tf; print(tf.\_\_version\_\_)' - For a TensorFlow 2.1 implementation by [@Kolkir](https://github.com/Kolkir/), see: [https://github.com/Kolkir/code2seq](https://github.com/Kolkir/code2seq) * For [creating a new Java dataset](#creating-and-preprocessing-a-new-java-dataset) or [manually examining a trained model](#step-4-manual-examination-of-a-trained-model) (any operation that requires parsing of a new code example): [JDK](https://openjdk.java.net/install/) * For creating a C# dataset: [dotnet-core](https://dotnet.microsoft.com/download) version 2.2 or newer. * `pip install rouge` for computing rouge scores. ## Quickstart ### Step 0: Cloning this repository ``` git clone https://github.com/tech-srl/code2seq cd code2seq ``` ### Step 1: Creating a new dataset from Java sources To obtain a preprocessed dataset to train a network on, you can either download our preprocessed dataset, or create a new dataset from Java source files. #### Download our preprocessed dataset Java-large dataset (~16M examples, compressed: 11G, extracted 125GB) ``` mkdir data cd data wget https://s3.amazonaws.com/code2seq/datasets/java-large-preprocessed.tar.gz tar -xvzf java-large-preprocessed.tar.gz ``` This will create a `data/java-large/` sub-directory, containing the files that hold training, test and validation sets, and a dict file for various dataset properties. #### Creating and preprocessing a new Java dataset To create and preprocess a new dataset (for example, to compare code2seq to another model on another dataset): * Edit the file [preprocess.sh](preprocess.sh) using the instructions there, pointing it to the correct training, validation and test directories. * Run the preprocess.sh file: > bash preprocess.sh ### Step 2: Training a model You can either download an already trained model, or train a new model using a preprocessed dataset. 
#### Downloading a trained model (137 MB) We already trained a model for 52 epochs on the data that was preprocessed in the previous step. This model is the same model that was used in the paper and the same model that serves the demo at [code2seq.org](code2seq.org). ``` wget https://s3.amazonaws.com/code2seq/model/java-large/java-large-model.tar.gz tar -xvzf java-large-model.tar.gz ``` ##### Note: This trained model is in a "released" state, which means that we stripped it from its training parameters. #### Training a model from scratch To train a model from scratch: * Edit the file [train.sh](train.sh) to point it to the right preprocessed data. By default, it points to our "java-large" dataset that was preprocessed in the previous step. * Before training, you can edit the configuration hyper-parameters in the file [config.py](config.py), as explained in [Configuration](#configuration). * Run the [train.sh](train.sh) script: ``` bash train.sh ``` ### Step 3: Evaluating a trained model After `config.PATIENCE` iterations of no improvement on the validation set, training stops by itself. Suppose that iteration #52 is our chosen model, run: ``` python3 code2seq.py --load models/java-large-model/model_iter52.release --test data/java-large/java-large.test.c2s ``` While evaluating, a file named "log.txt" is written to the same dir as the saved models, with each test example name and the model's prediction. ### Step 4: Manual examination of a trained model To manually examine a trained model, run: ``` python3 code2seq.py --load models/java-large-model/model_iter52.release --predict ``` After the model loads, follow the instructions and edit the file `Input.java` and enter a Java method or code snippet, and examine the model's predictions and attention scores. #### Note: Due to TensorFlow's limitations, if using beam search (`config.BEAM_WIDTH > 0`), then `BEAM_WIDTH` hypotheses will be printed, but without attention weights. 
If not using beam search (`config.BEAM_WIDTH == 0`), then a single hypothesis will be printed *with the attention weights* in every decoding timestep. ## Configuration Changing hyper-parameters is possible by editing the file [config.py](config.py). Here are some of the parameters and their description: #### config.NUM_EPOCHS = 3000 The max number of epochs to train the model. #### config.SAVE_EVERY_EPOCHS = 1 The frequency, in epochs, of saving a model and evaluating on the validation set during training. #### config.PATIENCE = 10 Controlling early stopping: how many epochs of no improvement should training continue before stopping. #### config.BATCH_SIZE = 512 Batch size during training. #### config.TEST_BATCH_SIZE = 256 Batch size during evaluation. Affects only the evaluation speed and memory consumption, does not affect the results. #### config.SHUFFLE_BUFFER_SIZE = 10000 The buffer size that the reader uses for shuffling the training data. Controls the randomness of the data. Increasing this value might hurt training throughput. #### config.CSV_BUFFER_SIZE = 100 * 1024 * 1024 The buffer size (in bytes) of the CSV dataset reader. #### config.MAX_CONTEXTS = 200 The number of contexts to sample in each example during training (resampling a different subset of this size every training iteration). #### config.SUBTOKENS_VOCAB_MAX_SIZE = 190000 The max size of the subtoken vocabulary. #### config.TARGET_VOCAB_MAX_SIZE = 27000 The max size of the target words vocabulary. #### config.EMBEDDINGS_SIZE = 128 Embedding size for subtokens, AST nodes and target symbols. #### config.RNN_SIZE = 128 * 2 The total size of the two LSTMs that are used to embed the paths if `config.BIRNN` is `True`, or the size of the single LSTM if `config.BIRNN` is `False`. #### config.DECODER_SIZE = 320 Size of each LSTM layer in the decoder. #### config.NUM_DECODER_LAYERS = 1 Number of decoder LSTM layers. Can be increased to support long target sequences. 
#### config.MAX_PATH_LENGTH = 8 + 1 The max number of nodes in a path #### config.MAX_NAME_PARTS = 5 The max number of subtokens in an input token. If the token is longer, only the first subtokens will be read. #### config.MAX_TARGET_PARTS = 6 The max number of symbols in the target sequence. Set to 6 by default for method names, but can be increased for learning datasets with longer sequences. ### config.BIRNN = True If True, use a bidirectional LSTM to encode each path. If False, use a unidirectional LSTM only. #### config.RANDOM_CONTEXTS = True When True, sample `MAX_CONTEXT` from every example every training iteration. When False, take the first `MAX_CONTEXTS` only. #### config.BEAM_WIDTH = 0 Beam width in beam search. Inactive when 0. #### config.USE_MOMENTUM = True If `True`, use Momentum optimizer with nesterov. If `False`, use Adam (Adam converges in fewer epochs; Momentum leads to slightly better results). ## Releasing a trained model If you wish to keep a trained model for inference only (without the ability to continue training it) you can release the model using: ``` python3 code2seq.py --load models/java-large-model/model_iter52 --release ``` This will save a copy of the trained model with the '.release' suffix. A "released" model usually takes ~3x less disk space. ## Extending to other languages This project currently supports Java and C\# as the input languages. _**March 2020** - a code2seq extractor for **C++** based on LLVM was developed by [@Kolkir](https://github.com/Kolkir/) and is available here: [https://github.com/Kolkir/cppminer](https://github.com/Kolkir/cppminer)._ _**January 2020** - a code2seq extractor for Python (specifically targeting the Python150k dataset) was contributed by [@stasbel](https://github.com/stasbel). 
See: [https://github.com/tech-srl/code2seq/tree/master/Python150kExtractor](https://github.com/tech-srl/code2seq/tree/master/Python150kExtractor)._ _**January 2020** - an extractor for predicting TypeScript type annotations for JavaScript input using code2vec was developed by [@izosak](https://github.com/izosak) and Noa Cohen, and is available here: [https://github.com/tech-srl/id2vec](https://github.com/tech-srl/id2vec)._ ~~_**June 2019** - an extractor for **C** that is compatible with our model was developed by [CMU SEI team](https://github.com/cmu-sei/code2vec-c)._~~ - removed by CMU SEI team. _**June 2019** - a code2vec extractor for **Python, Java, C, C++** by JetBrains Research is available here: [PathMiner](https://github.com/JetBrains-Research/astminer)._ To extend code2seq to other languages other than Java and C#, a new extractor (similar to the [JavaExtractor](JavaExtractor)) should be implemented, and be called by [preprocess.sh](preprocess.sh). Basically, an extractor should be able to output for each directory containing source files: * A single text file, where each row is an example. * Each example is a space-delimited list of fields, where: 1. The first field is the target label, internally delimited by the "|" character (for example: `compare|ignore|case`) 2. Each of the following field are contexts, where each context has three components separated by commas (","). None of these components can include spaces nor commas. We refer to these three components as a token, a path, and another token, but in general other types of ternary contexts can be considered. Each "token" component is a token in the code, split to subtokens using the "|" character. Each path is a path between two tokens, split to path nodes (or other kinds of building blocks) using the "|" character. 
Example for a context: `my|key,StringExression|MethodCall|Name,get|value` Here `my|key` and `get|value` are tokens, and `StringExression|MethodCall|Name` is the syntactic path that connects them. ## Datasets ### Java To download the Java-small, Java-med and Java-large datasets used in the Code Summarization task as raw `*.java` files, use: * [Java-small](https://s3.amazonaws.com/code2seq/datasets/java-small.tar.gz) * [Java-med](https://s3.amazonaws.com/code2seq/datasets/java-med.tar.gz) * [Java-large](https://s3.amazonaws.com/code2seq/datasets/java-large.tar.gz) To download the preprocessed datasets, use: * [Java-small-preprocessed](https://s3.amazonaws.com/code2seq/datasets/java-small-preprocessed.tar.gz) * [Java-med-preprocessed](https://s3.amazonaws.com/code2seq/datasets/java-med-preprocessed.tar.gz) * [Java-large-preprocessed](https://s3.amazonaws.com/code2seq/datasets/java-large-preprocessed.tar.gz) ### C# The C# dataset used in the Code Captioning task can be downloaded from the [CodeNN](https://github.com/sriniiyer/codenn/) repository. ## Baselines ### Using the trained model For the NMT baselines (BiLSTM, Transformer) we used the implementation of [OpenNMT-py](http://opennmt.net/OpenNMT-py/). 
The trained BiLSTM model is available here: `https://code2seq.s3.amazonaws.com/lstm_baseline/model_acc_62.88_ppl_12.03_e16.pt` Test+validation sources and targets: ``` https://code2seq.s3.amazonaws.com/lstm_baseline/test_expected_actual.txt https://code2seq.s3.amazonaws.com/lstm_baseline/test_source.txt https://code2seq.s3.amazonaws.com/lstm_baseline/test_target.txt https://code2seq.s3.amazonaws.com/lstm_baseline/val_source.txt https://code2seq.s3.amazonaws.com/lstm_baseline/val_target.txt ``` The command line for "translating" a "source" file to a "target" is: `python3 translate.py -model model_acc_62.88_ppl_12.03_e16.pt -src test_source.txt -output translation_epoch16.txt -gpu 0` This results in a `translation_epoch16.txt` which we compare to `test_target.txt` to compute the score. The file `test_expected_actual.txt` is a line-by-line concatenation of the true reference ("expected") with the corresponding prediction (the "actual"). ### Creating data for the baseline We first modified the JavaExtractor (the same one as in this) to locate the methods to train on and print them to a file where each method is a single line. This modification is currently not checked in, but instead of extracting paths, it just prints `node.toString()` and replaces "\n" with space, where `node` is the object holding the AST node of type `MethodDeclaration`. Then, we tokenized (including sub-tokenization of identifiers, i.e., `"ArrayList"-> ["Array","List"])` each method body using `javalang`, using [this](baseline_tokenization/subtokenize_nmt_baseline.py) script (which can be run on [this](baseline_tokenization/input_example.txt) input example). So a program of: ``` void methodName(String fooBar) { System.out.println("hello world"); } ``` should be printed by the modified JavaExtractor as: ```method name|void (String fooBar){ System.out.println("hello world");}``` and the tokenization script would turn it into: ```void ( String foo Bar ) { System . out . 
println ( " hello world " ) ; }``` and the label to be predicted, i.e., "method name", into a separate file. OpenNMT-py can then be trained over these training source and target files. ## Citation [code2seq: Generating Sequences from Structured Representations of Code](https://arxiv.org/pdf/1808.01400) ``` @inproceedings{ alon2018codeseq, title={code2seq: Generating Sequences from Structured Representations of Code}, author={Uri Alon and Shaked Brody and Omer Levy and Eran Yahav}, booktitle={International Conference on Learning Representations}, year={2019}, url={https://openreview.net/forum?id=H1gKYo09tX}, } ``` ================================================ FILE: __init__.py ================================================ ================================================ FILE: baseline_tokenization/input_example.txt ================================================ requires landscape|boolean (){ return false; } get parent key|Object (){ return new ContactsUiKey(); } get parent key|Object (){ return new ContactsUiKey(); } get layout id|int (){ return R.layout.loose_screen; } get parent key|Object (){ return new EditContactKey(contactId); } to contact|Contact (){ return new Contact(id, name, email); } to string|String (){ return "Welcome!\nClick to continue."; } get parent key|Object (){ return new EditContactKey(contactId); } tear down services|void (@NonNull Services services){ } get layout id|int (){ return R.layout.landscape_screen; } ================================================ FILE: baseline_tokenization/javalang/__init__.py ================================================ from . import parser from . import parse from . import tokenizer from . 
import javadoc __version__ = "0.10.1" ================================================ FILE: baseline_tokenization/javalang/ast.py ================================================ import pickle import six class MetaNode(type): def __new__(mcs, name, bases, dict): attrs = list(dict['attrs']) dict['attrs'] = list() for base in bases: if hasattr(base, 'attrs'): dict['attrs'].extend(base.attrs) dict['attrs'].extend(attrs) return type.__new__(mcs, name, bases, dict) @six.add_metaclass(MetaNode) class Node(object): attrs = () def __init__(self, **kwargs): values = kwargs.copy() for attr_name in self.attrs: value = values.pop(attr_name, None) setattr(self, attr_name, value) if values: raise ValueError('Extraneous arguments') def __equals__(self, other): if type(other) is not type(self): return False for attr in self.attrs: if getattr(other, attr) != getattr(self, attr): return False return True def __repr__(self): return type(self).__name__ def __iter__(self): return walk_tree(self) def filter(self, pattern): for path, node in self: if ((isinstance(pattern, type) and isinstance(node, pattern)) or (node == pattern)): yield path, node @property def children(self): return [getattr(self, attr_name) for attr_name in self.attrs] def walk_tree(root): children = None if isinstance(root, Node): yield (), root children = root.children else: children = root for child in children: if isinstance(child, (Node, list, tuple)): for path, node in walk_tree(child): yield (root,) + path, node def dump(ast, file): pickle.dump(ast, file) def load(file): return pickle.load(file) ================================================ FILE: baseline_tokenization/javalang/javadoc.py ================================================ import re def join(s): return ' '.join(l.strip() for l in s.split('\n')) class DocBlock(object): def __init__(self): self.description = '' self.return_doc = None self.params = [] self.authors = [] self.deprecated = False # @exception and @throw are equivalent self.throws = {} 
self.exceptions = self.throws self.tags = {} def add_block(self, name, value): value = value.strip() if name == 'param': try: param, description = value.split(None, 1) except ValueError: param, description = value, '' self.params.append((param, join(description))) elif name in ('throws', 'exception'): try: ex, description = value.split(None, 1) except ValueError: ex, description = value, '' self.throws[ex] = join(description) elif name == 'return': self.return_doc = value elif name == 'author': self.authors.append(value) elif name == 'deprecated': self.deprecated = True self.tags.setdefault(name, []).append(value) blocks_re = re.compile('(^@)', re.MULTILINE) leading_space_re = re.compile(r'^\s*\*', re.MULTILINE) blocks_justify_re = re.compile(r'^\s*@', re.MULTILINE) def _sanitize(s): s = s.strip() if not (s[:3] == '/**' and s[-2:] == '*/'): raise ValueError('not a valid Javadoc comment') s = s.replace('\t', ' ') return s def _uncomment(s): # Remove /** and */ s = s[3:-2].strip() return leading_space_re.sub('', s) def _get_indent_level(s): return len(s) - len(s.lstrip()) def _left_justify(s): lines = s.rstrip().splitlines() if not lines: return '' indent_levels = [] for line in lines: if line.strip(): indent_levels.append(_get_indent_level(line)) indent_levels.sort() common_indent = indent_levels[0] if common_indent == 0: return s else: lines = [line[common_indent:] for line in lines] return '\n'.join(lines) def _force_blocks_left(s): return blocks_justify_re.sub('@', s) def parse(raw): sanitized = _sanitize(raw) uncommented = _uncomment(sanitized) justified = _left_justify(uncommented) justified_fixed = _force_blocks_left(justified) prepared = justified_fixed blocks = blocks_re.split(prepared) doc = DocBlock() if blocks[0] != '@': doc.description = blocks[0].strip() blocks = blocks[2::2] else: blocks = blocks[1::2] for block in blocks: try: tag, value = block.split(None, 1) except ValueError: tag, value = block, '' doc.add_block(tag, value) return doc 
================================================ FILE: baseline_tokenization/javalang/parse.py ================================================ from .parser import Parser from .tokenizer import tokenize def parse_expression(exp): if not exp.endswith(';'): exp = exp + ';' tokens = tokenize(exp) parser = Parser(tokens) return parser.parse_expression() def parse_member_signature(sig): if not sig.endswith(';'): sig = sig + ';' tokens = tokenize(sig) parser = Parser(tokens) return parser.parse_member_declaration() def parse_constructor_signature(sig): # Add an empty body to the signature, replacing a ; if necessary if sig.endswith(';'): sig = sig[:-1] sig = sig + '{ }' tokens = tokenize(sig) parser = Parser(tokens) return parser.parse_member_declaration() def parse_type(s): tokens = tokenize(s) parser = Parser(tokens) return parser.parse_type() def parse_type_signature(sig): if sig.endswith(';'): sig = sig[:-1] sig = sig + '{ }' tokens = tokenize(sig) parser = Parser(tokens) return parser.parse_class_or_interface_declaration() def parse(s): tokens = tokenize(s) parser = Parser(tokens) return parser.parse() ================================================ FILE: baseline_tokenization/javalang/parser.py ================================================ import six from . import util from . 
import tree from .tokenizer import ( EndOfInput, Keyword, Modifier, BasicType, Identifier, Annotation, Literal, Operator, JavaToken, ) ENABLE_DEBUG_SUPPORT = False def parse_debug(method): global ENABLE_DEBUG_SUPPORT if ENABLE_DEBUG_SUPPORT: def _method(self): if not hasattr(self, 'recursion_depth'): self.recursion_depth = 0 if self.debug: depth = "%02d" % (self.recursion_depth,) token = six.text_type(self.tokens.look()) start_value = self.tokens.look().value name = method.__name__ sep = ("-" * self.recursion_depth) e_message = "" print("%s %s> %s(%s)" % (depth, sep, name, token)) self.recursion_depth += 1 try: r = method(self) except JavaSyntaxError as e: e_message = e.description raise except Exception as e: e_message = six.text_type(e) raise finally: token = six.text_type(self.tokens.last()) print("%s <%s %s(%s, %s) %s" % (depth, sep, name, start_value, token, e_message)) self.recursion_depth -= 1 else: self.recursion_depth += 1 try: r = method(self) finally: self.recursion_depth -= 1 return r return _method else: return method # ------------------------------------------------------------------------------ # ---- Parsing exception ---- class JavaParserBaseException(Exception): def __init__(self, message=''): super(JavaParserBaseException, self).__init__(message) class JavaSyntaxError(JavaParserBaseException): def __init__(self, description, at=None): super(JavaSyntaxError, self).__init__() self.description = description self.at = at class JavaParserError(JavaParserBaseException): pass # ------------------------------------------------------------------------------ # ---- Parser class ---- class Parser(object): operator_precedence = [ set(('||',)), set(('&&',)), set(('|',)), set(('^',)), set(('&',)), set(('==', '!=')), set(('<', '>', '>=', '<=', 'instanceof')), set(('<<', '>>', '>>>')), set(('+', '-')), set(('*', '/', '%')) ] def __init__(self, tokens): self.tokens = util.LookAheadListIterator(tokens) self.tokens.set_default(EndOfInput(None)) self.debug = False 
# ------------------------------------------------------------------------------ # ---- Debug control ---- def set_debug(self, debug=True): self.debug = debug # ------------------------------------------------------------------------------ # ---- Parsing entry point ---- def parse(self): return self.parse_compilation_unit() # ------------------------------------------------------------------------------ # ---- Helper methods ---- def illegal(self, description, at=None): if not at: at = self.tokens.look() raise JavaSyntaxError(description, at) def accept(self, *accepts): last = None if len(accepts) == 0: raise JavaParserError("Missing acceptable values") for accept in accepts: token = next(self.tokens) if isinstance(accept, six.string_types) and ( not token.value == accept): self.illegal("Expected '%s'" % (accept,)) elif isinstance(accept, type) and not isinstance(token, accept): self.illegal("Expected %s" % (accept.__name__,)) last = token return last.value def would_accept(self, *accepts): if len(accepts) == 0: raise JavaParserError("Missing acceptable values") for i, accept in enumerate(accepts): token = self.tokens.look(i) if isinstance(accept, six.string_types) and ( not token.value == accept): return False elif isinstance(accept, type) and not isinstance(token, accept): return False return True def try_accept(self, *accepts): if len(accepts) == 0: raise JavaParserError("Missing acceptable values") for i, accept in enumerate(accepts): token = self.tokens.look(i) if isinstance(accept, six.string_types) and ( not token.value == accept): return False elif isinstance(accept, type) and not isinstance(token, accept): return False for i in range(0, len(accepts)): next(self.tokens) return True def build_binary_operation(self, parts, start_level=0): if len(parts) == 1: return parts[0] operands = list() operators = list() i = 0 for level in range(start_level, len(self.operator_precedence)): for j in range(1, len(parts) - 1, 2): if parts[j] in 
self.operator_precedence[level]: operand = self.build_binary_operation(parts[i:j], level + 1) operator = parts[j] i = j + 1 operands.append(operand) operators.append(operator) if operands: break operand = self.build_binary_operation(parts[i:], level + 1) operands.append(operand) operation = operands[0] for operator, operandr in zip(operators, operands[1:]): operation = tree.BinaryOperation(operandl=operation) operation.operator = operator operation.operandr = operandr return operation def is_annotation(self, i=0): """ Returns true if the position is the start of an annotation application (as opposed to an annotation declaration) """ return (isinstance(self.tokens.look(i), Annotation) and not self.tokens.look(i + 1).value == 'interface') def is_annotation_declaration(self, i=0): """ Returns true if the position is the start of an annotation application (as opposed to an annotation declaration) """ return (isinstance(self.tokens.look(i), Annotation) and self.tokens.look(i + 1).value == 'interface') # ------------------------------------------------------------------------------ # ---- Parsing methods ---- # ------------------------------------------------------------------------------ # -- Identifiers -- @parse_debug def parse_identifier(self): return self.accept(Identifier) @parse_debug def parse_qualified_identifier(self): qualified_identifier = list() while True: identifier = self.parse_identifier() qualified_identifier.append(identifier) if not self.try_accept('.'): break return '.'.join(qualified_identifier) @parse_debug def parse_qualified_identifier_list(self): qualified_identifiers = list() while True: qualified_identifier = self.parse_qualified_identifier() qualified_identifiers.append(qualified_identifier) if not self.try_accept(','): break return qualified_identifiers # ------------------------------------------------------------------------------ # -- Top level units -- @parse_debug def parse_compilation_unit(self): package = None package_annotations = 
None javadoc = None import_declarations = list() type_declarations = list() self.tokens.push_marker() next_token = self.tokens.look() if next_token: javadoc = next_token.javadoc if self.is_annotation(): package_annotations = self.parse_annotations() if self.try_accept('package'): self.tokens.pop_marker(False) package_name = self.parse_qualified_identifier() package = tree.PackageDeclaration(annotations=package_annotations, name=package_name, documentation=javadoc) self.accept(';') else: self.tokens.pop_marker(True) package_annotations = None while self.would_accept('import'): import_declaration = self.parse_import_declaration() import_declarations.append(import_declaration) while not isinstance(self.tokens.look(), EndOfInput): try: type_declaration = self.parse_type_declaration() except StopIteration: self.illegal("Unexpected end of input") if type_declaration: type_declarations.append(type_declaration) return tree.CompilationUnit(package=package, imports=import_declarations, types=type_declarations) @parse_debug def parse_import_declaration(self): qualified_identifier = list() static = False import_all = False self.accept('import') if self.try_accept('static'): static = True while True: identifier = self.parse_identifier() qualified_identifier.append(identifier) if self.try_accept('.'): if self.try_accept('*'): self.accept(';') import_all = True break else: self.accept(';') break return tree.Import(path='.'.join(qualified_identifier), static=static, wildcard=import_all) @parse_debug def parse_type_declaration(self): if self.try_accept(';'): return None else: return self.parse_class_or_interface_declaration() @parse_debug def parse_class_or_interface_declaration(self): modifiers, annotations, javadoc = self.parse_modifiers() type_declaration = None token = self.tokens.look() if token.value == 'class': type_declaration = self.parse_normal_class_declaration() elif token.value == 'enum': type_declaration = self.parse_enum_declaration() elif token.value == 'interface': 
type_declaration = self.parse_normal_interface_declaration() elif self.is_annotation_declaration(): type_declaration = self.parse_annotation_type_declaration() else: self.illegal("Expected type declaration") type_declaration.modifiers = modifiers type_declaration.annotations = annotations type_declaration.documentation = javadoc return type_declaration @parse_debug def parse_normal_class_declaration(self): name = None type_params = None extends = None implements = None body = None self.accept('class') name = self.parse_identifier() if self.would_accept('<'): type_params = self.parse_type_parameters() if self.try_accept('extends'): extends = self.parse_type() if self.try_accept('implements'): implements = self.parse_type_list() body = self.parse_class_body() return tree.ClassDeclaration(name=name, type_parameters=type_params, extends=extends, implements=implements, body=body) @parse_debug def parse_enum_declaration(self): name = None implements = None body = None self.accept('enum') name = self.parse_identifier() if self.try_accept('implements'): implements = self.parse_type_list() body = self.parse_enum_body() return tree.EnumDeclaration(name=name, implements=implements, body=body) @parse_debug def parse_normal_interface_declaration(self): name = None type_parameters = None extends = None body = None self.accept('interface') name = self.parse_identifier() if self.would_accept('<'): type_parameters = self.parse_type_parameters() if self.try_accept('extends'): extends = self.parse_type_list() body = self.parse_interface_body() return tree.InterfaceDeclaration(name=name, type_parameters=type_parameters, extends=extends, body=body) @parse_debug def parse_annotation_type_declaration(self): name = None body = None self.accept('@', 'interface') name = self.parse_identifier() body = self.parse_annotation_type_body() return tree.AnnotationDeclaration(name=name, body=body) # ------------------------------------------------------------------------------ # -- Types -- 
@parse_debug def parse_type(self): java_type = None if isinstance(self.tokens.look(), BasicType): java_type = self.parse_basic_type() elif isinstance(self.tokens.look(), Identifier): java_type = self.parse_reference_type() else: self.illegal("Expected type") java_type.dimensions = self.parse_array_dimension() return java_type @parse_debug def parse_basic_type(self): return tree.BasicType(name=self.accept(BasicType)) @parse_debug def parse_reference_type(self): reference_type = tree.ReferenceType() tail = reference_type while True: tail.name = self.parse_identifier() if self.would_accept('<'): tail.arguments = self.parse_type_arguments() if self.try_accept('.'): tail.sub_type = tree.ReferenceType() tail = tail.sub_type else: break return reference_type @parse_debug def parse_type_arguments(self): type_arguments = list() self.accept('<') while True: type_argument = self.parse_type_argument() type_arguments.append(type_argument) if self.try_accept('>'): break self.accept(',') return type_arguments @parse_debug def parse_type_argument(self): pattern_type = None base_type = None if self.try_accept('?'): if self.tokens.look().value in ('extends', 'super'): pattern_type = self.tokens.next().value else: return tree.TypeArgument(pattern_type='?') if self.would_accept(BasicType): base_type = self.parse_basic_type() self.accept('[', ']') base_type.dimensions = [None] else: base_type = self.parse_reference_type() base_type.dimensions = [] base_type.dimensions += self.parse_array_dimension() return tree.TypeArgument(type=base_type, pattern_type=pattern_type) @parse_debug def parse_nonwildcard_type_arguments(self): self.accept('<') type_arguments = self.parse_type_list() self.accept('>') return [tree.TypeArgument(type=t) for t in type_arguments] @parse_debug def parse_type_list(self): types = list() while True: if self.would_accept(BasicType): base_type = self.parse_basic_type() self.accept('[', ']') base_type.dimensions = [None] else: base_type = self.parse_reference_type() 
base_type.dimensions = [] base_type.dimensions += self.parse_array_dimension() types.append(base_type) if not self.try_accept(','): break return types @parse_debug def parse_type_arguments_or_diamond(self): if self.try_accept('<', '>'): return list() else: return self.parse_type_arguments() @parse_debug def parse_nonwildcard_type_arguments_or_diamond(self): if self.try_accept('<', '>'): return list() else: return self.parse_nonwildcard_type_arguments() @parse_debug def parse_type_parameters(self): type_parameters = list() self.accept('<') while True: type_parameter = self.parse_type_parameter() type_parameters.append(type_parameter) if self.try_accept('>'): break else: self.accept(',') return type_parameters @parse_debug def parse_type_parameter(self): identifier = self.parse_identifier() extends = None if self.try_accept('extends'): extends = list() while True: reference_type = self.parse_reference_type() extends.append(reference_type) if not self.try_accept('&'): break return tree.TypeParameter(name=identifier, extends=extends) @parse_debug def parse_array_dimension(self): array_dimension = 0 while self.try_accept('[', ']'): array_dimension += 1 return [None] * array_dimension # ------------------------------------------------------------------------------ # -- Annotations and modifiers -- @parse_debug def parse_modifiers(self): annotations = list() modifiers = set() javadoc = None next_token = self.tokens.look() if next_token: javadoc = next_token.javadoc while True: if self.would_accept(Modifier): modifiers.add(self.accept(Modifier)) elif self.is_annotation(): annotation = self.parse_annotation() annotations.append(annotation) else: break return (modifiers, annotations, javadoc) @parse_debug def parse_annotations(self): annotations = list() while True: annotation = self.parse_annotation() annotations.append(annotation) if not self.is_annotation(): break return annotations @parse_debug def parse_annotation(self): qualified_identifier = None annotation_element = 
None self.accept('@') qualified_identifier = self.parse_qualified_identifier() if self.try_accept('('): if not self.would_accept(')'): annotation_element = self.parse_annotation_element() self.accept(')') return tree.Annotation(name=qualified_identifier, element=annotation_element) @parse_debug def parse_annotation_element(self): if self.would_accept(Identifier, '='): return self.parse_element_value_pairs() else: return self.parse_element_value() @parse_debug def parse_element_value_pairs(self): pairs = list() while True: pair = self.parse_element_value_pair() pairs.append(pair) if not self.try_accept(','): break return pairs @parse_debug def parse_element_value_pair(self): identifier = self.parse_identifier() self.accept('=') value = self.parse_element_value() return tree.ElementValuePair(name=identifier, value=value) @parse_debug def parse_element_value(self): if self.is_annotation(): return self.parse_annotation() elif self.would_accept('{'): return self.parse_element_value_array_initializer() else: return self.parse_expressionl() @parse_debug def parse_element_value_array_initializer(self): self.accept('{') if self.try_accept('}'): return list() element_values = self.parse_element_values() self.try_accept(',') self.accept('}') return tree.ElementArrayValue(values=element_values) @parse_debug def parse_element_values(self): element_values = list() while True: element_value = self.parse_element_value() element_values.append(element_value) if self.would_accept('}') or self.would_accept(',', '}'): break self.accept(',') return element_values # ------------------------------------------------------------------------------ # -- Class body -- @parse_debug def parse_class_body(self): declarations = list() self.accept('{') while not self.would_accept('}'): declaration = self.parse_class_body_declaration() if declaration: declarations.append(declaration) self.accept('}') return declarations @parse_debug def parse_class_body_declaration(self): token = self.tokens.look() 
if self.try_accept(';'): return None elif self.would_accept('static', '{'): self.accept('static') return self.parse_block() elif self.would_accept('{'): return self.parse_block() else: return self.parse_member_declaration() @parse_debug def parse_member_declaration(self): modifiers, annotations, javadoc = self.parse_modifiers() member = None token = self.tokens.look() if self.try_accept('void'): method_name = self.parse_identifier() member = self.parse_void_method_declarator_rest() member.name = method_name elif token.value == '<': member = self.parse_generic_method_or_constructor_declaration() elif token.value == 'class': member = self.parse_normal_class_declaration() elif token.value == 'enum': member = self.parse_enum_declaration() elif token.value == 'interface': member = self.parse_normal_interface_declaration() elif self.is_annotation_declaration(): member = self.parse_annotation_type_declaration() elif self.would_accept(Identifier, '('): constructor_name = self.parse_identifier() member = self.parse_constructor_declarator_rest() member.name = constructor_name else: member = self.parse_method_or_field_declaraction() member._position = token.position member.modifiers = modifiers member.annotations = annotations member.documentation = javadoc return member @parse_debug def parse_method_or_field_declaraction(self): member_type = self.parse_type() member_name = self.parse_identifier() member = self.parse_method_or_field_rest() if isinstance(member, tree.MethodDeclaration): member_type.dimensions += member.return_type.dimensions member.name = member_name member.return_type = member_type else: member.type = member_type member.declarators[0].name = member_name return member @parse_debug def parse_method_or_field_rest(self): if self.would_accept('('): return self.parse_method_declarator_rest() else: rest = self.parse_field_declarators_rest() self.accept(';') return rest @parse_debug def parse_field_declarators_rest(self): array_dimension, initializer = 
self.parse_variable_declarator_rest() declarators = [tree.VariableDeclarator(dimensions=array_dimension, initializer=initializer)] while self.try_accept(','): declarator = self.parse_variable_declarator() declarators.append(declarator) return tree.FieldDeclaration(declarators=declarators) @parse_debug def parse_method_declarator_rest(self): formal_parameters = self.parse_formal_parameters() additional_dimensions = self.parse_array_dimension() throws = None body = None if self.try_accept('throws'): throws = self.parse_qualified_identifier_list() if self.would_accept('{'): body = self.parse_block() else: self.accept(';') return tree.MethodDeclaration(parameters=formal_parameters, throws=throws, body=body, return_type=tree.Type(dimensions=additional_dimensions)) @parse_debug def parse_void_method_declarator_rest(self): formal_parameters = self.parse_formal_parameters() throws = None body = None if self.try_accept('throws'): throws = self.parse_qualified_identifier_list() if self.would_accept('{'): body = self.parse_block() else: self.accept(';') return tree.MethodDeclaration(parameters=formal_parameters, throws=throws, body=body) @parse_debug def parse_constructor_declarator_rest(self): formal_parameters = self.parse_formal_parameters() throws = None body = None if self.try_accept('throws'): throws = self.parse_qualified_identifier_list() body = self.parse_block() return tree.ConstructorDeclaration(parameters=formal_parameters, throws=throws, body=body) @parse_debug def parse_generic_method_or_constructor_declaration(self): type_parameters = self.parse_type_parameters() method = None if self.would_accept(Identifier, '('): constructor_name = self.parse_identifier() method = self.parse_constructor_declarator_rest() method.name = constructor_name elif self.try_accept('void'): method_name = self.parse_identifier() method = self.parse_void_method_declarator_rest() method.name = method_name else: method_return_type = self.parse_type() method_name = self.parse_identifier() 
method = self.parse_method_declarator_rest() method_return_type.dimensions += method.return_type.dimensions method.return_type = method_return_type method.name = method_name method.type_parameters = type_parameters return method # ------------------------------------------------------------------------------ # -- Interface body -- @parse_debug def parse_interface_body(self): declarations = list() self.accept('{') while not self.would_accept('}'): declaration = self.parse_interface_body_declaration() if declaration: declarations.append(declaration) self.accept('}') return declarations @parse_debug def parse_interface_body_declaration(self): if self.try_accept(';'): return None modifiers, annotations, javadoc = self.parse_modifiers() declaration = self.parse_interface_member_declaration() declaration.modifiers = modifiers declaration.annotations = annotations declaration.documentation = javadoc return declaration @parse_debug def parse_interface_member_declaration(self): declaration = None if self.would_accept('class'): declaration = self.parse_normal_class_declaration() elif self.would_accept('interface'): declaration = self.parse_normal_interface_declaration() elif self.would_accept('enum'): declaration = self.parse_enum_declaration() elif self.is_annotation_declaration(): declaration = self.parse_annotation_type_declaration() elif self.would_accept('<'): declaration = self.parse_interface_generic_method_declarator() elif self.try_accept('void'): method_name = self.parse_identifier() declaration = self.parse_void_interface_method_declarator_rest() declaration.name = method_name else: declaration = self.parse_interface_method_or_field_declaration() return declaration @parse_debug def parse_interface_method_or_field_declaration(self): java_type = self.parse_type() name = self.parse_identifier() member = self.parse_interface_method_or_field_rest() if isinstance(member, tree.MethodDeclaration): java_type.dimensions += member.return_type.dimensions member.name = name 
member.return_type = java_type else: member.declarators[0].name = name member.type = java_type return member @parse_debug def parse_interface_method_or_field_rest(self): rest = None if self.would_accept('('): rest = self.parse_interface_method_declarator_rest() else: rest = self.parse_constant_declarators_rest() self.accept(';') return rest @parse_debug def parse_constant_declarators_rest(self): array_dimension, initializer = self.parse_constant_declarator_rest() declarators = [tree.VariableDeclarator(dimensions=array_dimension, initializer=initializer)] while self.try_accept(','): declarator = self.parse_constant_declarator() declarators.append(declarator) return tree.ConstantDeclaration(declarators=declarators) @parse_debug def parse_constant_declarator_rest(self): array_dimension = self.parse_array_dimension() self.accept('=') initializer = self.parse_variable_initializer() return (array_dimension, initializer) @parse_debug def parse_constant_declarator(self): name = self.parse_identifier() additional_dimension, initializer = self.parse_constant_declarator_rest() return tree.VariableDeclarator(name=name, dimensions=additional_dimension, initializer=initializer) @parse_debug def parse_interface_method_declarator_rest(self): parameters = self.parse_formal_parameters() array_dimension = self.parse_array_dimension() throws = None body = None if self.try_accept('throws'): throws = self.parse_qualified_identifier_list() if self.would_accept('{'): body = self.parse_block() else: self.accept(';') return tree.MethodDeclaration(parameters=parameters, throws=throws, body=body, return_type=tree.Type(dimensions=array_dimension)) @parse_debug def parse_void_interface_method_declarator_rest(self): parameters = self.parse_formal_parameters() throws = None body = None if self.try_accept('throws'): throws = self.parse_qualified_identifier_list() if self.would_accept('{'): body = self.parse_block() else: self.accept(';') return tree.MethodDeclaration(parameters=parameters, 
throws=throws, body=body) @parse_debug def parse_interface_generic_method_declarator(self): type_parameters = self.parse_type_parameters() return_type = None method_name = None if not self.try_accept('void'): return_type = self.parse_type() method_name = self.parse_identifier() method = self.parse_interface_method_declarator_rest() method.name = method_name method.return_type = return_type method.type_parameters = type_parameters return method # ------------------------------------------------------------------------------ # -- Parameters and variables -- @parse_debug def parse_formal_parameters(self): formal_parameters = list() self.accept('(') if self.try_accept(')'): return formal_parameters while True: modifiers, annotations = self.parse_variable_modifiers() parameter_type = self.parse_type() varargs = False if self.try_accept('...'): varargs = True parameter_name = self.parse_identifier() parameter_type.dimensions += self.parse_array_dimension() parameter = tree.FormalParameter(modifiers=modifiers, annotations=annotations, type=parameter_type, name=parameter_name, varargs=varargs) formal_parameters.append(parameter) if varargs: # varargs parameter must be the last break if not self.try_accept(','): break self.accept(')') return formal_parameters @parse_debug def parse_variable_modifiers(self): modifiers = set() annotations = list() while True: if self.try_accept('final'): modifiers.add('final') elif self.is_annotation(): annotation = self.parse_annotation() annotations.append(annotation) else: break return modifiers, annotations @parse_debug def parse_variable_declators(self): declarators = list() while True: declarator = self.parse_variable_declator() declarators.append(declarator) if not self.try_accept(','): break return declarators @parse_debug def parse_variable_declarators(self): declarators = list() while True: declarator = self.parse_variable_declarator() declarators.append(declarator) if not self.try_accept(','): break return declarators @parse_debug 
def parse_variable_declarator(self):
    """Parse a single declarator: Identifier [dims] ['=' initializer]."""
    identifier = self.parse_identifier()
    array_dimension, initializer = self.parse_variable_declarator_rest()
    return tree.VariableDeclarator(name=identifier,
                                   dimensions=array_dimension,
                                   initializer=initializer)

@parse_debug
def parse_variable_declarator_rest(self):
    """Parse what may follow a declarator's identifier: optional array
    dimensions and an optional '=' initializer.

    Returns a (dimensions, initializer-or-None) tuple.
    """
    array_dimension = self.parse_array_dimension()
    initializer = None
    if self.try_accept('='):
        initializer = self.parse_variable_initializer()
    return (array_dimension, initializer)

@parse_debug
def parse_variable_initializer(self):
    """An initializer is either an array initializer '{...}' or an expression."""
    if self.would_accept('{'):
        return self.parse_array_initializer()
    else:
        return self.parse_expression()

@parse_debug
def parse_array_initializer(self):
    """Parse '{' [initializer (',' initializer)* [',']] '}'."""
    array_initializer = tree.ArrayInitializer(initializers=list())
    self.accept('{')
    if self.try_accept(','):
        # '{ , }' -- a lone comma with no elements is accepted
        self.accept('}')
        return array_initializer
    if self.try_accept('}'):
        return array_initializer
    while True:
        initializer = self.parse_variable_initializer()
        array_initializer.initializers.append(initializer)
        # A missing '}' after an element means a separating (or trailing)
        # comma must follow
        if not self.would_accept('}'):
            self.accept(',')
        if self.try_accept('}'):
            return array_initializer

# ------------------------------------------------------------------------------
# -- Blocks and statements --

@parse_debug
def parse_block(self):
    """Parse '{' block-statement* '}' and return the list of statements."""
    statements = list()
    self.accept('{')
    while not self.would_accept('}'):
        statement = self.parse_block_statement()
        statements.append(statement)
    self.accept('}')
    return statements

@parse_debug
def parse_block_statement(self):
    """Parse one statement inside a block.

    Disambiguates between a local class/interface declaration, a local
    variable declaration, and an ordinary statement by looking ahead in
    the token stream without consuming it.
    """
    if self.would_accept(Identifier, ':'):
        # Labeled statement
        return self.parse_statement()

    if self.would_accept('synchronized'):
        return self.parse_statement()

    token = None
    found_annotations = False
    i = 0

    # Look past annotations and modifiers.
If we find a modifier that is not # 'final' then the statement must be a class or interface declaration while True: token = self.tokens.look(i) if isinstance(token, Modifier): if not token.value == 'final': return self.parse_class_or_interface_declaration() elif self.is_annotation(i): found_annotations = True i += 2 while self.tokens.look(i).value == '.': i += 2 if self.tokens.look(i).value == '(': parens = 1 i += 1 while parens > 0: token = self.tokens.look(i) if token.value == '(': parens += 1 elif token.value == ')': parens -= 1 i += 1 continue else: break i += 1 if token.value in ('class', 'enum', 'interface', '@'): return self.parse_class_or_interface_declaration() if found_annotations or isinstance(token, BasicType): return self.parse_local_variable_declaration_statement() # At this point, if the block statement is a variable definition the next # token MUST be an identifier, so if it isn't we can conclude the block # statement is a normal statement if not isinstance(token, Identifier): return self.parse_statement() # We can't easily determine the statement type. 
Try parsing as a variable # declaration first and fall back to a statement try: with self.tokens: return self.parse_local_variable_declaration_statement() except JavaSyntaxError: return self.parse_statement() @parse_debug def parse_local_variable_declaration_statement(self): modifiers, annotations = self.parse_variable_modifiers() java_type = self.parse_type() declarators = self.parse_variable_declarators() self.accept(';') var = tree.LocalVariableDeclaration(modifiers=modifiers, annotations=annotations, type=java_type, declarators=declarators) return var @parse_debug def parse_statement(self): token = self.tokens.look() if self.would_accept('{'): block = self.parse_block() return tree.BlockStatement(statements=block) elif self.try_accept(';'): return tree.Statement() elif self.would_accept(Identifier, ':'): identifer = self.parse_identifier() self.accept(':') statement = self.parse_statement() statement.label = identifer return statement elif self.try_accept('if'): condition = self.parse_par_expression() then = self.parse_statement() else_statement = None if self.try_accept('else'): else_statement = self.parse_statement() return tree.IfStatement(condition=condition, then_statement=then, else_statement=else_statement) elif self.try_accept('assert'): condition = self.parse_expression() value = None if self.try_accept(':'): value = self.parse_expression() self.accept(';') return tree.AssertStatement(condition=condition, value=value) elif self.try_accept('switch'): switch_expression = self.parse_par_expression() self.accept('{') switch_block = self.parse_switch_block_statement_groups() self.accept('}') return tree.SwitchStatement(expression=switch_expression, cases=switch_block) elif self.try_accept('while'): condition = self.parse_par_expression() action = self.parse_statement() return tree.WhileStatement(condition=condition, body=action) elif self.try_accept('do'): action = self.parse_statement() self.accept('while') condition = self.parse_par_expression() 
self.accept(';') return tree.DoStatement(condition=condition, body=action) elif self.try_accept('for'): self.accept('(') for_control = self.parse_for_control() self.accept(')') for_statement = self.parse_statement() return tree.ForStatement(control=for_control, body=for_statement) elif self.try_accept('break'): label = None if self.would_accept(Identifier): label = self.parse_identifier() self.accept(';') return tree.BreakStatement(goto=label) elif self.try_accept('continue'): label = None if self.would_accept(Identifier): label = self.parse_identifier() self.accept(';') return tree.ContinueStatement(goto=label) elif self.try_accept('return'): value = None if not self.would_accept(';'): value = self.parse_expression() self.accept(';') return tree.ReturnStatement(expression=value) elif self.try_accept('throw'): value = self.parse_expression() self.accept(';') return tree.ThrowStatement(expression=value) elif self.try_accept('synchronized'): lock = self.parse_par_expression() block = self.parse_block() return tree.SynchronizedStatement(lock=lock, block=block) elif self.try_accept('try'): resource_specification = None block = None catches = None finally_block = None if self.would_accept('{'): block = self.parse_block() if self.would_accept('catch'): catches = self.parse_catches() if self.try_accept('finally'): finally_block = self.parse_block() if catches == None and finally_block == None: self.illegal("Expected catch/finally block") else: resource_specification = self.parse_resource_specification() block = self.parse_block() if self.would_accept('catch'): catches = self.parse_catches() if self.try_accept('finally'): finally_block = self.parse_block() return tree.TryStatement(resources=resource_specification, block=block, catches=catches, finally_block=finally_block) else: expression = self.parse_expression() self.accept(';') return tree.StatementExpression(expression=expression) # ------------------------------------------------------------------------------ # -- Try 
/ catch --

@parse_debug
def parse_catches(self):
    """Parse one or more consecutive 'catch' clauses of a try statement."""
    catches = list()
    while True:
        catch = self.parse_catch_clause()
        catches.append(catch)
        if not self.would_accept('catch'):
            break
    return catches

@parse_debug
def parse_catch_clause(self):
    """Parse 'catch' '(' mods Type ('|' Type)* name ')' block.

    Supports Java 7 multi-catch: all alternative types are collected
    into a single CatchClauseParameter.
    """
    self.accept('catch', '(')
    modifiers, annotations = self.parse_variable_modifiers()

    catch_parameter = tree.CatchClauseParameter(types=list())
    while True:
        catch_type = self.parse_qualified_identifier()
        catch_parameter.types.append(catch_type)
        if not self.try_accept('|'):
            break
    catch_parameter.name = self.parse_identifier()

    self.accept(')')
    block = self.parse_block()
    return tree.CatchClause(parameter=catch_parameter, block=block)

@parse_debug
def parse_resource_specification(self):
    """Parse the '(' resource (';' resource)* [';'] ')' part of a
    try-with-resources statement."""
    resources = list()
    self.accept('(')
    while True:
        resource = self.parse_resource()
        resources.append(resource)
        # A ';' separates resources; a trailing ';' before ')' is allowed
        if not self.would_accept(')'):
            self.accept(';')
        if self.try_accept(')'):
            break
    return resources

@parse_debug
def parse_resource(self):
    """Parse one resource declaration: mods Type name '=' expression."""
    modifiers, annotations = self.parse_variable_modifiers()
    reference_type = self.parse_reference_type()
    reference_type.dimensions = self.parse_array_dimension()
    name = self.parse_identifier()
    reference_type.dimensions += self.parse_array_dimension()
    self.accept('=')
    value = self.parse_expression()
    return tree.TryResource(modifiers=modifiers,
                            annotations=annotations,
                            type=reference_type,
                            name=name,
                            value=value)

# ------------------------------------------------------------------------------
# -- Switch and for statements ---

@parse_debug
def parse_switch_block_statement_groups(self):
    """Parse consecutive 'case'/'default' groups inside a switch body."""
    statement_groups = list()
    while self.tokens.look().value in ('case', 'default'):
        statement_group = self.parse_switch_block_statement_group()
        statement_groups.append(statement_group)
    return statement_groups

@parse_debug
def parse_switch_block_statement_group(self):
    """Parse one run of 'case'/'default' labels and the statements that
    follow them, up to the next label or the closing '}'."""
    labels = list()
    statements = list()

    while True:
        case_type = self.tokens.next().value
        case_value = None
        if case_type == 'case':
            if self.would_accept(Identifier, ':'):
                case_value =
self.parse_identifier() else: case_value = self.parse_expression() labels.append(case_value) elif not case_type == 'default': self.illegal("Expected switch case") self.accept(':') if self.tokens.look().value not in ('case', 'default'): break while self.tokens.look().value not in ('case', 'default', '}'): statement = self.parse_block_statement() statements.append(statement) return tree.SwitchStatementCase(case=labels, statements=statements) @parse_debug def parse_for_control(self): # Try for_var_control and fall back to normal three part for control try: with self.tokens: return self.parse_for_var_control() except JavaSyntaxError: pass init = None if not self.would_accept(';'): init = self.parse_for_init_or_update() self.accept(';') condition = None if not self.would_accept(';'): condition = self.parse_expression() self.accept(';') update = None if not self.would_accept(')'): update = self.parse_for_init_or_update() return tree.ForControl(init=init, condition=condition, update=update) @parse_debug def parse_for_var_control(self): modifiers, annotations = self.parse_variable_modifiers() var_type = self.parse_type() var_name = self.parse_identifier() var_type.dimensions += self.parse_array_dimension() var = tree.VariableDeclaration(modifiers=modifiers, annotations=annotations, type=var_type) rest = self.parse_for_var_control_rest() if isinstance(rest, tree.Expression): var.declarators = [tree.VariableDeclarator(name=var_name)] return tree.EnhancedForControl(var=var, iterable=rest) else: declarators, condition, update = rest declarators[0].name = var_name var.declarators = declarators return tree.ForControl(init=var, condition=condition, update=update) @parse_debug def parse_for_var_control_rest(self): if self.try_accept(':'): expression = self.parse_expression() return expression declarators = None if not self.would_accept(';'): declarators = self.parse_for_variable_declarator_rest() else: declarators = [tree.VariableDeclarator()] self.accept(';') condition = None if 
not self.would_accept(';'): condition = self.parse_expression() self.accept(';') update = None if not self.would_accept(')'): update = self.parse_for_init_or_update() return (declarators, condition, update) @parse_debug def parse_for_variable_declarator_rest(self): initializer = None if self.try_accept('='): initializer = self.parse_variable_initializer() declarators = [tree.VariableDeclarator(initializer=initializer)] while self.try_accept(','): declarator = self.parse_variable_declarator() declarators.append(declarator) return declarators @parse_debug def parse_for_init_or_update(self): expressions = list() while True: expression = self.parse_expression() expressions.append(expression) if not self.try_accept(','): break return expressions # ------------------------------------------------------------------------------ # -- Expressions -- @parse_debug def parse_expression(self): expressionl = self.parse_expressionl() assignment_type = None assignment_expression = None if self.tokens.look().value in Operator.ASSIGNMENT: assignment_type = self.tokens.next().value assignment_expression = self.parse_expression() return tree.Assignment(expressionl=expressionl, type=assignment_type, value=assignment_expression) else: return expressionl @parse_debug def parse_expressionl(self): expression_2 = self.parse_expression_2() true_expression = None false_expression = None if self.try_accept('?'): true_expression = self.parse_expression() self.accept(':') false_expression = self.parse_expressionl() return tree.TernaryExpression(condition=expression_2, if_true=true_expression, if_false=false_expression) if self.would_accept('->'): body = self.parse_lambda_method_body() return tree.LambdaExpression(parameters=[expression_2], body=body) if self.try_accept('::'): method_reference, type_arguments = self.parse_method_reference() return tree.MethodReference( expression=expression_2, method=method_reference, type_arguments=type_arguments) return expression_2 @parse_debug def 
parse_expression_2(self):
    """Parse the binary-operator level of an expression (infix operators
    and 'instanceof')."""
    expression_3 = self.parse_expression_3()
    token = self.tokens.look()
    if token.value in Operator.INFIX or token.value == 'instanceof':
        parts = self.parse_expression_2_rest()
        parts.insert(0, expression_3)
        # Fold the flat [operand, op, operand, ...] list into an AST
        return self.build_binary_operation(parts)
    return expression_3

@parse_debug
def parse_expression_2_rest(self):
    """Collect the trailing (operator, operand) pairs of a binary
    expression as a flat list: [op1, expr1, op2, expr2, ...]."""
    parts = list()

    token = self.tokens.look()
    while token.value in Operator.INFIX or token.value == 'instanceof':
        if self.try_accept('instanceof'):
            # 'instanceof' takes a type, not an expression, as its rhs
            comparison_type = self.parse_type()
            parts.extend(('instanceof', comparison_type))
        else:
            operator = self.parse_infix_operator()
            expression = self.parse_expression_3()
            parts.extend((operator, expression))
        token = self.tokens.look()

    return parts

# ------------------------------------------------------------------------------
# -- Expression operators --

@parse_debug
def parse_expression_3(self):
    """Parse a unary-level expression: prefix operators, then a
    speculative lambda or cast, then a primary with its selectors and
    postfix operators."""
    prefix_operators = list()
    while self.tokens.look().value in Operator.PREFIX:
        prefix_operators.append(self.tokens.next().value)

    if self.would_accept('('):
        # Speculatively try a parenthesized lambda, then a cast;
        # 'with self.tokens' rewinds the token stream when a
        # JavaSyntaxError escapes the block
        try:
            with self.tokens:
                lambda_exp = self.parse_lambda_expression()
                if lambda_exp:
                    return lambda_exp
        except JavaSyntaxError:
            pass
        try:
            with self.tokens:
                self.accept('(')
                cast_target = self.parse_type()
                self.accept(')')
                expression = self.parse_expression_3()
                return tree.Cast(type=cast_target, expression=expression)
        except JavaSyntaxError:
            pass

    primary = self.parse_primary()
    primary.prefix_operators = prefix_operators
    primary.selectors = list()
    primary.postfix_operators = list()

    # '[' and '.' both start a selector chain
    token = self.tokens.look()
    while token.value in '[.':
        selector = self.parse_selector()
        primary.selectors.append(selector)
        token = self.tokens.look()

    while token.value in Operator.POSTFIX:
        primary.postfix_operators.append(self.tokens.next().value)
        token = self.tokens.look()

    return primary

@parse_debug
def parse_method_reference(self):
    """Parse the target that follows '::' in a method reference,
    returning (method_reference, type_arguments)."""
    type_arguments = list()
    if self.would_accept('<'):
        type_arguments = self.parse_nonwildcard_type_arguments()
    if self.would_accept('new'):
method_reference = tree.MemberReference(member=self.accept('new')) else: method_reference = self.parse_expression() return method_reference, type_arguments @parse_debug def parse_lambda_expression(self): lambda_expr = None parameters = None if self.would_accept('(', Identifier, ','): self.accept('(') parameters = [] while not self.would_accept(')'): parameters.append(tree.InferredFormalParameter( name=self.parse_identifier())) self.try_accept(',') self.accept(')') else: parameters = self.parse_formal_parameters() body = self.parse_lambda_method_body() return tree.LambdaExpression(parameters=parameters, body=body) @parse_debug def parse_lambda_method_body(self): if self.accept('->'): if self.would_accept('{'): return self.parse_block() else: return self.parse_expression() @parse_debug def parse_infix_operator(self): operator = self.accept(Operator) if not operator in Operator.INFIX: self.illegal("Expected infix operator") if operator == '>' and self.try_accept('>'): operator = '>>' if self.try_accept('>'): operator = '>>>' return operator # ------------------------------------------------------------------------------ # -- Primary expressions -- @parse_debug def parse_primary(self): token = self.tokens.look() if isinstance(token, Literal): return self.parse_literal() elif token.value == '(': return self.parse_par_expression() elif self.try_accept('this'): arguments = None if self.would_accept('('): arguments = self.parse_arguments() return tree.ExplicitConstructorInvocation(arguments=arguments) return tree.This() elif self.would_accept('super', '::'): self.accept('super') return token elif self.try_accept('super'): super_suffix = self.parse_super_suffix() return super_suffix elif self.try_accept('new'): return self.parse_creator() elif token.value == '<': type_arguments = self.parse_nonwildcard_type_arguments() if self.try_accept('this'): arguments = self.parse_arguments() return tree.ExplicitConstructorInvocation(type_arguments=type_arguments, arguments=arguments) 
else: invocation = self.parse_explicit_generic_invocation_suffix() invocation.type_arguments = type_arguments return invocation elif isinstance(token, Identifier): qualified_identifier = [self.parse_identifier()] while self.would_accept('.', Identifier): self.accept('.') identifier = self.parse_identifier() qualified_identifier.append(identifier) identifier_suffix = self.parse_identifier_suffix() if isinstance(identifier_suffix, (tree.MemberReference, tree.MethodInvocation)): # Take the last identifer as the member and leave the rest for the qualifier identifier_suffix.member = qualified_identifier.pop() elif isinstance(identifier_suffix, tree.ClassReference): identifier_suffix.type = tree.ReferenceType(name=qualified_identifier.pop()) identifier_suffix.qualifier = '.'.join(qualified_identifier) return identifier_suffix elif isinstance(token, BasicType): base_type = self.parse_basic_type() base_type.dimensions = self.parse_array_dimension() self.accept('.', 'class') return tree.ClassReference(type=base_type) elif self.try_accept('void'): self.accept('.', 'class') return tree.VoidClassReference() self.illegal("Expected expression") @parse_debug def parse_literal(self): literal = self.accept(Literal) return tree.Literal(value=literal) @parse_debug def parse_par_expression(self): self.accept('(') expression = self.parse_expression() self.accept(')') return expression @parse_debug def parse_arguments(self): expressions = list() self.accept('(') if self.try_accept(')'): return expressions while True: expression = self.parse_expression() expressions.append(expression) if not self.try_accept(','): break self.accept(')') return expressions @parse_debug def parse_super_suffix(self): identifier = None type_arguments = None arguments = None if self.try_accept('.'): if self.would_accept('<'): type_arguments = self.parse_nonwildcard_type_arguments() identifier = self.parse_identifier() if self.would_accept('('): arguments = self.parse_arguments() else: arguments = 
self.parse_arguments() if identifier and arguments is not None: return tree.SuperMethodInvocation(member=identifier, arguments=arguments, type_arguments=type_arguments) elif arguments is not None: return tree.SuperConstructorInvocation(arguments=arguments) else: return tree.SuperMemberReference(member=identifier) @parse_debug def parse_explicit_generic_invocation_suffix(self): identifier = None arguments = None if self.try_accept('super'): return self.parse_super_suffix() else: identifier = self.parse_identifier() arguments = self.parse_arguments() return tree.MethodInvocation(member=identifier, arguments=arguments) # ------------------------------------------------------------------------------ # -- Creators -- @parse_debug def parse_creator(self): constructor_type_arguments = None if self.would_accept(BasicType): created_name = self.parse_basic_type() rest = self.parse_array_creator_rest() rest.type = created_name return rest if self.would_accept('<'): constructor_type_arguments = self.parse_nonwildcard_type_arguments() created_name = self.parse_created_name() if self.would_accept('['): if constructor_type_arguments: self.illegal("Array creator not allowed with generic constructor type arguments") rest = self.parse_array_creator_rest() rest.type = created_name return rest else: arguments, body = self.parse_class_creator_rest() return tree.ClassCreator(constructor_type_arguments=constructor_type_arguments, type=created_name, arguments=arguments, body=body) @parse_debug def parse_created_name(self): created_name = tree.ReferenceType() tail = created_name while True: tail.name = self.parse_identifier() if self.would_accept('<'): tail.arguments = self.parse_type_arguments_or_diamond() if self.try_accept('.'): tail.sub_type = tree.ReferenceType() tail = tail.sub_type else: break return created_name @parse_debug def parse_class_creator_rest(self): arguments = self.parse_arguments() class_body = None if self.would_accept('{'): class_body = self.parse_class_body() return 
(arguments, class_body)

@parse_debug
def parse_array_creator_rest(self):
    """Parse what follows 'new Type' for an array creation: either
    empty dimensions '[]' with a mandatory initializer, or one or more
    sized '[expr]' dimensions followed by optional empty dimensions."""
    if self.would_accept('[', ']'):
        array_dimension = self.parse_array_dimension()
        array_initializer = self.parse_array_initializer()
        return tree.ArrayCreator(dimensions=array_dimension,
                                 initializer=array_initializer)
    else:
        array_dimensions = list()
        # Sized dimensions ('[expr]') must all precede any empty '[]'
        while self.would_accept('[') and not self.would_accept('[', ']'):
            self.accept('[')
            expression = self.parse_expression()
            array_dimensions.append(expression)
            self.accept(']')
        array_dimensions += self.parse_array_dimension()
        return tree.ArrayCreator(dimensions=array_dimensions)

@parse_debug
def parse_identifier_suffix(self):
    """Parse the suffix that may follow a qualified identifier inside a
    primary expression: array/class references, argument lists,
    '.class', '.this', generic invocations, '.new', or '.super(...)'.

    Falls through to a plain MemberReference when no suffix matches.
    """
    if self.try_accept('[', ']'):
        # e.g. 'String[].class' -- leading None marks the consumed '[]'
        array_dimension = [None] + self.parse_array_dimension()
        self.accept('.', 'class')
        return tree.ClassReference(type=tree.Type(dimensions=array_dimension))
    elif self.would_accept('('):
        arguments = self.parse_arguments()
        return tree.MethodInvocation(arguments=arguments)
    elif self.try_accept('.', 'class'):
        return tree.ClassReference()
    elif self.try_accept('.', 'this'):
        return tree.This()
    elif self.would_accept('.', '<'):
        next(self.tokens)  # consume the '.'
        return self.parse_explicit_generic_invocation()
    elif self.try_accept('.', 'new'):
        type_arguments = None
        if self.would_accept('<'):
            type_arguments = self.parse_nonwildcard_type_arguments()
        inner_creator = self.parse_inner_creator()
        inner_creator.constructor_type_arguments = type_arguments
        return inner_creator
    elif self.would_accept('.', 'super', '('):
        self.accept('.', 'super')
        arguments = self.parse_arguments()
        return tree.SuperConstructorInvocation(arguments=arguments)
    else:
        return tree.MemberReference()

@parse_debug
def parse_explicit_generic_invocation(self):
    """Parse '<T,...>' type arguments followed by an invocation suffix
    and attach the type arguments to the resulting node."""
    type_arguments = self.parse_nonwildcard_type_arguments()

    invocation = self.parse_explicit_generic_invocation_suffix()
    invocation.type_arguments = type_arguments

    return invocation

@parse_debug
def parse_inner_creator(self):
    """Parse 'Identifier [<T,...>] (args) [classbody]' after '.new'."""
    identifier = self.parse_identifier()

    type_arguments = None
    if self.would_accept('<'):
type_arguments = self.parse_nonwildcard_type_arguments_or_diamond() java_type = tree.ReferenceType(name=identifier, arguments=type_arguments) arguments, class_body = self.parse_class_creator_rest() return tree.InnerClassCreator(type=java_type, arguments=arguments, body=class_body) @parse_debug def parse_selector(self): if self.try_accept('['): expression = self.parse_expression() self.accept(']') return tree.ArraySelector(index=expression) elif self.try_accept('.'): token = self.tokens.look() if isinstance(token, Identifier): identifier = self.tokens.next().value arguments = None if self.would_accept('('): arguments = self.parse_arguments() return tree.MethodInvocation(member=identifier, arguments=arguments) else: return tree.MemberReference(member=identifier) elif self.would_accept('super', '::'): self.accept('super') return token elif self.would_accept('<'): return self.parse_explicit_generic_invocation() elif self.try_accept('this'): return tree.This() elif self.try_accept('super'): return self.parse_super_suffix() elif self.try_accept('new'): type_arguments = None if self.would_accept('<'): type_arguments = self.parse_nonwildcard_type_arguments() inner_creator = self.parse_inner_creator() inner_creator.constructor_type_arguments = type_arguments return inner_creator self.illegal("Expected selector") # ------------------------------------------------------------------------------ # -- Enum and annotation body -- @parse_debug def parse_enum_body(self): constants = list() body_declarations = list() self.accept('{') if not self.try_accept(','): while not (self.would_accept(';') or self.would_accept('}')): constant = self.parse_enum_constant() constants.append(constant) if not self.try_accept(','): break if self.try_accept(';'): while not self.would_accept('}'): declaration = self.parse_class_body_declaration() if declaration: body_declarations.append(declaration) self.accept('}') return tree.EnumBody(constants=constants, declarations=body_declarations) @parse_debug 
def parse_enum_constant(self): annotations = list() javadoc = None constant_name = None arguments = None body = None next_token = self.tokens.look() if next_token: javadoc = next_token.javadoc if self.would_accept(Annotation): annotations = self.parse_annotations() constant_name = self.parse_identifier() if self.would_accept('('): arguments = self.parse_arguments() if self.would_accept('{'): body = self.parse_class_body() return tree.EnumConstantDeclaration(annotations=annotations, name=constant_name, arguments=arguments, body=body, documentation=javadoc) @parse_debug def parse_annotation_type_body(self): declarations = None self.accept('{') declarations = self.parse_annotation_type_element_declarations() self.accept('}') return declarations @parse_debug def parse_annotation_type_element_declarations(self): declarations = list() while not self.would_accept('}'): declaration = self.parse_annotation_type_element_declaration() declarations.append(declaration) return declarations @parse_debug def parse_annotation_type_element_declaration(self): modifiers, annotations, javadoc = self.parse_modifiers() declaration = None if self.would_accept('class'): declaration = self.parse_normal_class_declaration() elif self.would_accept('interface'): declaration = self.parse_normal_interface_declaration() elif self.would_accept('enum'): declaration = self.parse_enum_declaration() elif self.is_annotation_declaration(): declaration = self.parse_annotation_type_declaration() else: attribute_type = self.parse_type() attribute_name = self.parse_identifier() declaration = self.parse_annotation_method_or_constant_rest() self.accept(';') if isinstance(declaration, tree.AnnotationMethod): declaration.name = attribute_name declaration.return_type = attribute_type else: declaration.declarators[0].name = attribute_name declaration.type = attribute_type declaration.modifiers = modifiers declaration.annotations = annotations declaration.documentation = javadoc return declaration @parse_debug def 
parse_annotation_method_or_constant_rest(self): if self.try_accept('('): self.accept(')') array_dimension = self.parse_array_dimension() default = None if self.try_accept('default'): default = self.parse_element_value() return tree.AnnotationMethod(dimensions=array_dimension, default=default) else: return self.parse_constant_declarators_rest() def parse(tokens, debug=False): parser = Parser(tokens) parser.set_debug(debug) return parser.parse() ================================================ FILE: baseline_tokenization/javalang/test/__init__.py ================================================ ================================================ FILE: baseline_tokenization/javalang/test/source/package-info/AnnotationJavadoc.java ================================================ @Package /** Test that includes java doc first but no annotation */ package org.javalang.test; ================================================ FILE: baseline_tokenization/javalang/test/source/package-info/AnnotationOnly.java ================================================ @Package package org.javalang.test; ================================================ FILE: baseline_tokenization/javalang/test/source/package-info/JavadocAnnotation.java ================================================ /** Test that includes java doc first but no annotation */ @Package package org.javalang.test; ================================================ FILE: baseline_tokenization/javalang/test/source/package-info/JavadocOnly.java ================================================ /** Test that includes java doc first but no annotation */ package org.javalang.test; ================================================ FILE: baseline_tokenization/javalang/test/source/package-info/NoAnnotationNoJavadoc.java ================================================ package org.javalang.test; ================================================ FILE: baseline_tokenization/javalang/test/test_java_8_syntax.py 
================================================
import unittest

from pkg_resources import resource_string

from .. import parse, parser, tree


def setup_java_class(content_to_add):
    """ returns an example java class with the
    given content_to_add contained within a method.
    """
    # NOTE(review): internal whitespace of this template reconstructed
    # from a collapsed extraction -- confirm against upstream
    template = """
public class Lambda {
    public static void main(String args[]) {
        %s
    }
}
"""
    return template % content_to_add


def filter_type_in_method(clazz, the_type, method_name):
    """ yields the result of filtering the given class for the given
    type inside the given method identified by its name.
    """
    for path, node in clazz.filter(the_type):
        # Walk the path from the node back up to find its enclosing method
        for p in reversed(path):
            if isinstance(p, tree.MethodDeclaration):
                if p.name == method_name:
                    yield path, node


class LambdaSupportTest(unittest.TestCase):
    """ Contains tests for java 8 lambda syntax. """

    def assert_contains_lambda_expression_in_m(
            self, clazz, method_name='main'):
        """ asserts that the given tree contains a method with the
        supplied method name containing a lambda expression.
        """
        matches = list(filter_type_in_method(
            clazz, tree.LambdaExpression, method_name))
        if not matches:
            self.fail('No matching lambda expression found.')
        return matches

    def test_lambda_support_no_parameters_no_body(self):
        """ tests support for lambda with no parameters and no body. """
        self.assert_contains_lambda_expression_in_m(
            parse.parse(setup_java_class("() -> {};")))

    def test_lambda_support_no_parameters_expression_body(self):
        """ tests support for lambda with no parameters and an
        expression body.
        """
        test_classes = [
            setup_java_class("() -> 3;"),
            setup_java_class("() -> null;"),
            setup_java_class("() -> { return 21; };"),
            setup_java_class("() -> { System.exit(1); };"),
        ]
        for test_class in test_classes:
            clazz = parse.parse(test_class)
            self.assert_contains_lambda_expression_in_m(clazz)

    def test_lambda_support_no_parameters_complex_expression(self):
        """ tests support for lambda with no parameters and a complex
        expression body.
""" code = """ () -> { if (true) return 21; else { int result = 21; return result / 2; } };""" self.assert_contains_lambda_expression_in_m( parse.parse(setup_java_class(code))) def test_parameter_no_type_expression_body(self): """ tests support for lambda with parameters with inferred types. """ test_classes = [ setup_java_class("(bar) -> bar + 1;"), setup_java_class("bar -> bar + 1;"), setup_java_class("x -> x.length();"), setup_java_class("y -> { y.boom(); };"), ] for test_class in test_classes: clazz = parse.parse(test_class) self.assert_contains_lambda_expression_in_m(clazz) def test_parameter_with_type_expression_body(self): """ tests support for lambda with parameters with formal types. """ test_classes = [ setup_java_class("(int foo) -> { return foo + 2; };"), setup_java_class("(String s) -> s.length();"), setup_java_class("(int foo) -> foo + 1;"), setup_java_class("(Thread th) -> { th.start(); };"), setup_java_class("(String foo, String bar) -> " "foo + bar;"), ] for test_class in test_classes: clazz = parse.parse(test_class) self.assert_contains_lambda_expression_in_m(clazz) def test_parameters_with_no_type_expression_body(self): """ tests support for multiple lambda parameters that are specified without their types. """ self.assert_contains_lambda_expression_in_m( parse.parse(setup_java_class("(x, y) -> x + y;"))) def test_parameters_with_mixed_inferred_and_declared_types(self): """ this tests that lambda type specification mixing is considered invalid as per the specifications. """ with self.assertRaises(parser.JavaSyntaxError): parse.parse(setup_java_class("(x, int y) -> x+y;")) def test_parameters_inferred_types_with_modifiers(self): """ this tests that lambda inferred type parameters with modifiers are considered invalid as per the specifications. 
""" with self.assertRaises(parser.JavaSyntaxError): parse.parse(setup_java_class("(x, final y) -> x+y;")) def test_invalid_parameters_are_invalid(self): """ this tests that invalid lambda parameters are are considered invalid as per the specifications. """ with self.assertRaises(parser.JavaSyntaxError): parse.parse(setup_java_class("(a b c) -> {};")) def test_cast_works(self): """ this tests that a cast expression works as expected. """ parse.parse(setup_java_class("String x = (String) A.x() ;")) class MethodReferenceSyntaxTest(unittest.TestCase): """ Contains tests for java 8 method reference syntax. """ def assert_contains_method_reference_expression_in_m( self, clazz, method_name='main'): """ asserts that the given class contains a method with the supplied method name containing a method reference. """ matches = list(filter_type_in_method( clazz, tree.MethodReference, method_name)) if not matches: self.fail('No matching method reference found.') return matches def test_method_reference(self): """ tests that method references are supported. """ self.assert_contains_method_reference_expression_in_m( parse.parse(setup_java_class("String::length;"))) def test_method_reference_to_the_new_method(self): """ test support for method references to 'new'. """ self.assert_contains_method_reference_expression_in_m( parse.parse(setup_java_class("String::new;"))) def test_method_reference_to_the_new_method_with_explict_type(self): """ test support for method references to 'new' with an explicit type. """ self.assert_contains_method_reference_expression_in_m( parse.parse(setup_java_class("String:: new;"))) def test_method_reference_from_super(self): """ test support for method references from 'super'. """ self.assert_contains_method_reference_expression_in_m( parse.parse(setup_java_class("super::toString;"))) def test_method_reference_from_super_with_identifier(self): """ test support for method references from Identifier.super. 
""" self.assert_contains_method_reference_expression_in_m( parse.parse(setup_java_class("String.super::toString;"))) @unittest.expectedFailure def test_method_reference_explicit_type_arguments_for_generic_type(self): """ currently there is no support for method references for an explicit type. """ self.assert_contains_method_reference_expression_in_m( parse.parse(setup_java_class("List::size;"))) def test_method_reference_explicit_type_arguments(self): """ test support for method references with an explicit type. """ self.assert_contains_method_reference_expression_in_m( parse.parse(setup_java_class("Arrays:: sort;"))) @unittest.expectedFailure def test_method_reference_from_array_type(self): """ currently there is no support for method references from a primary type. """ self.assert_contains_method_reference_expression_in_m( parse.parse(setup_java_class("int[]::new;"))) class InterfaceSupportTest(unittest.TestCase): """ Contains tests for java 8 interface extensions. """ def test_interface_support_static_methods(self): parse.parse(""" interface Foo { void foo(); static Foo create() { return new Foo() { @Override void foo() { System.out.println("foo"); } }; } } """) def test_interface_support_default_methods(self): parse.parse(""" interface Foo { default void foo() { System.out.println("foo"); } } """) def main(): unittest.main() if __name__ == '__main__': main() ================================================ FILE: baseline_tokenization/javalang/test/test_javadoc.py ================================================ import unittest from .. 
import unittest

from .. import javadoc


class TestJavadoc(unittest.TestCase):
    """Smoke tests: the javadoc parser accepts minimal/empty comments."""

    def test_empty_comment(self):
        javadoc.parse('/** */')
        javadoc.parse('/***/')
        javadoc.parse('/**\n *\n */')
        javadoc.parse('/**\n *\n *\n */')


if __name__ == "__main__":
    unittest.main()


# ================ FILE: baseline_tokenization/javalang/test/test_package_declaration.py ================
import unittest

from pkg_resources import resource_string

from .. import parse


# From my reading of the spec (http://docs.oracle.com/javase/specs/jls/se7/html/jls-7.html) the
# allowed order is javadoc, optional annotation, package declaration
class PackageInfo(unittest.TestCase):
    """Tests for package-info.java parsing: which javadoc/annotation
    orderings attach documentation and annotations to the package node.

    Fix: the deprecated ``failUnless``/``failIf`` aliases were removed in
    Python 3.12; they are replaced by ``assertTrue``/``assertFalse``.
    """

    def testPackageDeclarationOnly(self):
        source_file = "source/package-info/NoAnnotationNoJavadoc.java"
        ast = self.get_ast(source_file)
        self.assertTrue(ast.package.name == "org.javalang.test")
        self.assertFalse(ast.package.annotations)
        self.assertFalse(ast.package.documentation)

    def testAnnotationOnly(self):
        source_file = "source/package-info/AnnotationOnly.java"
        ast = self.get_ast(source_file)
        self.assertTrue(ast.package.name == "org.javalang.test")
        self.assertTrue(ast.package.annotations)
        self.assertFalse(ast.package.documentation)

    def testJavadocOnly(self):
        source_file = "source/package-info/JavadocOnly.java"
        ast = self.get_ast(source_file)
        self.assertTrue(ast.package.name == "org.javalang.test")
        self.assertFalse(ast.package.annotations)
        self.assertTrue(ast.package.documentation)

    def testAnnotationThenJavadoc(self):
        # Javadoc appearing after the annotation is not attached to the package.
        source_file = "source/package-info/AnnotationJavadoc.java"
        ast = self.get_ast(source_file)
        self.assertTrue(ast.package.name == "org.javalang.test")
        self.assertTrue(ast.package.annotations)
        self.assertFalse(ast.package.documentation)

    def testJavadocThenAnnotation(self):
        source_file = "source/package-info/JavadocAnnotation.java"
        ast = self.get_ast(source_file)
        self.assertTrue(ast.package.name == "org.javalang.test")
        self.assertTrue(ast.package.annotations)
        self.assertTrue(ast.package.documentation)

    def get_ast(self, filename):
        # resource_string returns bytes; parse.parse accepts raw source data.
        # NOTE(review): pkg_resources is deprecated in modern setuptools;
        # importlib.resources would be the replacement — left as-is to avoid
        # changing the package-data lookup semantics.
        source = resource_string(__name__, filename)
        ast = parse.parse(source)
        return ast


def main():
    unittest.main()


if __name__ == '__main__':
    main()


# ================ FILE: baseline_tokenization/javalang/test/test_util.py ================
import unittest

from ..util import LookAheadIterator


class TestLookAheadIterator(unittest.TestCase):
    """Exercises look-ahead, marker push/pop and context-manager rewind."""

    def test_usage(self):
        i = LookAheadIterator(list(range(0, 10000)))

        self.assertEqual(next(i), 0)
        self.assertEqual(next(i), 1)
        self.assertEqual(next(i), 2)
        self.assertEqual(i.last(), 2)

        # look() peeks without consuming; it also updates last().
        self.assertEqual(i.look(), 3)
        self.assertEqual(i.last(), 3)

        self.assertEqual(i.look(1), 4)
        self.assertEqual(i.look(2), 5)
        self.assertEqual(i.look(3), 6)
        self.assertEqual(i.look(4), 7)
        self.assertEqual(i.last(), 7)

        i.push_marker()
        self.assertEqual(next(i), 3)
        self.assertEqual(next(i), 4)
        self.assertEqual(next(i), 5)
        i.pop_marker(True)  # reset

        self.assertEqual(i.look(), 3)
        self.assertEqual(next(i), 3)

        i.push_marker()  # 1
        self.assertEqual(next(i), 4)
        self.assertEqual(next(i), 5)

        i.push_marker()  # 2
        self.assertEqual(next(i), 6)
        self.assertEqual(next(i), 7)

        i.push_marker()  # 3
        self.assertEqual(next(i), 8)
        self.assertEqual(next(i), 9)
        i.pop_marker(False)  # 3

        self.assertEqual(next(i), 10)
        i.pop_marker(True)  # 2

        self.assertEqual(next(i), 6)
        self.assertEqual(next(i), 7)
        self.assertEqual(next(i), 8)
        i.pop_marker(False)  # 1

        self.assertEqual(next(i), 9)

        # An exception inside the with-block must rewind the iterator.
        # Fix: catch Exception instead of a bare except, so that
        # KeyboardInterrupt/SystemExit are not silently swallowed.
        try:
            with i:
                self.assertEqual(next(i), 10)
                self.assertEqual(next(i), 11)
                raise Exception()
        except Exception:
            self.assertEqual(next(i), 10)
            self.assertEqual(next(i), 11)

        # A clean exit commits the consumed values.
        with i:
            self.assertEqual(next(i), 12)
            self.assertEqual(next(i), 13)
        self.assertEqual(next(i), 14)


if __name__ == "__main__":
    unittest.main()


# ================ FILE: baseline_tokenization/javalang/tokenizer.py ================
import re
import unicodedata

import six


class LexerError(Exception):
    pass


class JavaToken(object):
    """A single lexed token: its text, (line, column) position, and any
    javadoc comment that immediately preceded it."""

    def __init__(self, value, position=None, javadoc=None):
        self.value = value
        self.position = position
        self.javadoc = javadoc

    def __repr__(self):
        if self.position:
            return '%s "%s" line %d, position %d' % (
                self.__class__.__name__, self.value, self.position[0], self.position[1]
            )
        else:
            return '%s "%s"' % (self.__class__.__name__, self.value)

    def __str__(self):
        return repr(self)

    def __eq__(self, other):
        # Tokens are deliberately not comparable: compare .value explicitly.
        raise Exception("Direct comparison not allowed")
def __init__(self, value, position=None, javadoc=None): self.value = value self.position = position self.javadoc = javadoc def __repr__(self): if self.position: return '%s "%s" line %d, position %d' % ( self.__class__.__name__, self.value, self.position[0], self.position[1] ) else: return '%s "%s"' % (self.__class__.__name__, self.value) def __str__(self): return repr(self) def __eq__(self, other): raise Exception("Direct comparison not allowed") class EndOfInput(JavaToken): pass class Keyword(JavaToken): VALUES = set(['abstract', 'assert', 'boolean', 'break', 'byte', 'case', 'catch', 'char', 'class', 'const', 'continue', 'default', 'do', 'double', 'else', 'enum', 'extends', 'final', 'finally', 'float', 'for', 'goto', 'if', 'implements', 'import', 'instanceof', 'int', 'interface', 'long', 'native', 'new', 'package', 'private', 'protected', 'public', 'return', 'short', 'static', 'strictfp', 'super', 'switch', 'synchronized', 'this', 'throw', 'throws', 'transient', 'try', 'void', 'volatile', 'while']) class Modifier(Keyword): VALUES = set(['abstract', 'default', 'final', 'native', 'private', 'protected', 'public', 'static', 'strictfp', 'synchronized', 'transient', 'volatile']) class BasicType(Keyword): VALUES = set(['boolean', 'byte', 'char', 'double', 'float', 'int', 'long', 'short']) class Literal(JavaToken): pass class Integer(Literal): pass class DecimalInteger(Literal): pass class OctalInteger(Integer): pass class BinaryInteger(Integer): pass class HexInteger(Integer): pass class FloatingPoint(Literal): pass class DecimalFloatingPoint(FloatingPoint): pass class HexFloatingPoint(FloatingPoint): pass class Boolean(Literal): VALUES = set(["true", "false"]) class Character(Literal): pass class String(Literal): pass class Null(Literal): pass class Separator(JavaToken): VALUES = set(['(', ')', '{', '}', '[', ']', ';', ',', '.']) class Operator(JavaToken): MAX_LEN = 4 VALUES = set(['>>>=', '>>=', '<<=', '%=', '^=', '|=', '&=', '/=', '*=', '-=', '+=', '<<', '--', '++', 
'||', '&&', '!=', '>=', '<=', '==', '%', '^', '|', '&', '/', '*', '-', '+', ':', '?', '~', '!', '<', '>', '=', '...', '->', '::']) # '>>>' and '>>' are excluded so that >> becomes two tokens and >>> becomes # three. This is done because we can not distinguish the operators >> and # >>> from the closing of multipel type parameter/argument lists when # lexing. The job of potentially recombining these symbols is left to the # parser INFIX = set(['||', '&&', '|', '^', '&', '==', '!=', '<', '>', '<=', '>=', '<<', '>>', '>>>', '+', '-', '*', '/', '%']) PREFIX = set(['++', '--', '!', '~', '+', '-']) POSTFIX = set(['++', '--']) ASSIGNMENT = set(['=', '+=', '-=', '*=', '/=', '&=', '|=', '^=', '%=', '<<=', '>>=', '>>>=']) LAMBDA = set(['->']) METHOD_REFERENCE = set(['::',]) def is_infix(self): return self.value in self.INFIX def is_prefix(self): return self.value in self.PREFIX def is_postfix(self): return self.value in self.POSTFIX def is_assignment(self): return self.value in self.ASSIGNMENT class Annotation(JavaToken): pass class Identifier(JavaToken): pass class JavaTokenizer(object): IDENT_START_CATEGORIES = set(['Lu', 'Ll', 'Lt', 'Lm', 'Lo', 'Nl', 'Pc', 'Sc']) IDENT_PART_CATEGORIES = set(['Lu', 'Ll', 'Lt', 'Lm', 'Lo', 'Mc', 'Mn', 'Nd', 'Nl', 'Pc', 'Sc']) def __init__(self, data): self.data = data self.current_line = 1 self.start_of_line = 0 self.operators = [set() for i in range(0, Operator.MAX_LEN)] for v in Operator.VALUES: self.operators[len(v) - 1].add(v) self.whitespace_consumer = re.compile(r'[^\s]') self.javadoc = None def reset(self): self.i = 0 self.j = 0 def consume_whitespace(self): match = self.whitespace_consumer.search(self.data, self.i + 1) if not match: self.i = self.length return i = match.start() start_of_line = self.data.rfind('\n', self.i, i) if start_of_line != -1: self.start_of_line = start_of_line self.current_line += self.data.count('\n', self.i, i) self.i = i def read_string(self): delim = self.data[self.i] state = 0 j = self.i + 1 length = 
self.length while True: if j >= length: self.error('Unterminated character/string literal') if state == 0: if self.data[j] == '\\': state = 1 elif self.data[j] == delim: break elif state == 1: if self.data[j] in 'btnfru"\'\\': state = 0 elif self.data[j] in '0123': state = 2 elif self.data[j] in '01234567': state = 3 else: self.error('Illegal escape character', self.data[j]) elif state == 2: # Possibly long octal if self.data[j] in '01234567': state = 3 elif self.data[j] == '\\': state = 1 elif self.data[j] == delim: break elif state == 3: state = 0 if self.data[j] == '\\': state = 1 elif self.data[j] == delim: break j += 1 self.j = j + 1 def try_operator(self): for l in range(min(self.length - self.i, Operator.MAX_LEN), 0, -1): if self.data[self.i:self.i + l] in self.operators[l - 1]: self.j = self.i + l return True return False def read_comment(self): if self.data[self.i + 1] == '/': i = self.data.find('\n', self.i + 2) if i == -1: self.i = self.length return i += 1 self.start_of_line = i self.current_line += 1 self.i = i else: i = self.data.find('*/', self.i + 2) if i == -1: self.i = self.length return i += 2 self.start_of_line = i self.current_line += self.data.count('\n', self.i, i) self.i = i def try_javadoc_comment(self): if self.i + 2 >= self.length or self.data[self.i + 2] != '*': return False j = self.data.find('*/', self.i + 2) if j == -1: self.j = self.length return False j += 2 self.start_of_line = j self.current_line += self.data.count('\n', self.i, j) self.j = j return True def read_decimal_float_or_integer(self): orig_i = self.i self.j = self.i self.read_decimal_integer() if self.data[self.j] not in '.eEfFdD': return DecimalInteger if self.data[self.j] == '.': self.i = self.j + 1 self.read_decimal_integer() if self.data[self.j] in 'eE': self.j = self.j + 1 if self.data[self.j] in '-+': self.j = self.j + 1 self.i = self.j self.read_decimal_integer() if self.data[self.j] in 'fFdD': self.j = self.j + 1 self.i = orig_i return DecimalFloatingPoint def 
read_hex_integer_or_float(self): orig_i = self.i self.j = self.i + 2 self.read_hex_integer() if self.data[self.j] not in '.pP': return HexInteger if self.data[self.j] == '.': self.j = self.j + 1 self.read_digits('0123456789abcdefABCDEF') if self.data[self.j] in 'pP': self.j = self.j + 1 else: self.error('Invalid hex float literal') if self.data[self.j] in '-+': self.j = self.j + 1 self.i = self.j self.read_decimal_integer() if self.data[self.j] in 'fFdD': self.j = self.j + 1 self.i = orig_i return HexFloatingPoint def read_digits(self, digits): tmp_i = 0 c = None while True: c = self.data[self.j + tmp_i] if c in digits: self.j += 1 + tmp_i tmp_i = 0 elif c == '_': tmp_i += 1 else: break if c in 'lL': self.j += 1 def read_decimal_integer(self): self.j = self.i self.read_digits('0123456789') def read_hex_integer(self): self.j = self.i + 2 self.read_digits('0123456789abcdefABCDEF') def read_bin_integer(self): self.j = self.i + 2 self.read_digits('01') def read_octal_integer(self): self.j = self.i + 1 self.read_digits('01234567') def read_integer_or_float(self, c, c_next): if c == '0' and c_next in 'xX': return self.read_hex_integer_or_float() elif c == '0' and c_next in 'bB': self.read_bin_integer() return BinaryInteger elif c == '0' and c_next in '01234567': self.read_octal_integer() return OctalInteger else: return self.read_decimal_float_or_integer() def try_separator(self): if self.data[self.i] in Separator.VALUES: self.j = self.i + 1 return True return False def decode_data(self): # Encodings to try in order codecs = ['utf_8', 'iso-8859-1'] # If data is already unicode don't try to redecode if isinstance(self.data, six.text_type): return self.data for codec in codecs: try: data = self.data.decode(codec) return data except UnicodeDecodeError: pass self.error('Could not decode input data') def is_java_identifier_start(self, c): return unicodedata.category(c) in self.IDENT_START_CATEGORIES def read_identifier(self): self.j = self.i + 1 while 
unicodedata.category(self.data[self.j]) in self.IDENT_PART_CATEGORIES: self.j += 1 ident = self.data[self.i:self.j] if ident in Keyword.VALUES: token_type = Keyword if ident in BasicType.VALUES: token_type = BasicType elif ident in Modifier.VALUES: token_type = Modifier elif ident in Boolean.VALUES: token_type = Boolean elif ident == 'null': token_type = Null else: token_type = Identifier return token_type def pre_tokenize(self): new_data = list() data = self.decode_data() i = 0 j = 0 length = len(data) NONE = 0 ELIGIBLE = 1 MARKER_FOUND = 2 state = NONE while j < length: if state == NONE: j = data.find('\\', j) if j == -1: j = length break state = ELIGIBLE elif state == ELIGIBLE: c = data[j] if c == 'u': state = MARKER_FOUND new_data.append(data[i:j - 1]) else: state = NONE elif state == MARKER_FOUND: c = data[j] if c != 'u': try: escape_code = int(data[j:j+4], 16) except ValueError: self.error('Invalid unicode escape', data[j:j+4]) new_data.append(six.unichr(escape_code)) i = j + 4 j = i state = NONE continue j = j + 1 new_data.append(data[i:]) self.data = ''.join(new_data) self.length = len(self.data) def tokenize(self): self.reset() # Convert unicode escapes self.pre_tokenize() while self.i < self.length: token_type = None c = self.data[self.i] c_next = None startswith = c if self.i + 1 < self.length: c_next = self.data[self.i + 1] startswith = c + c_next if c.isspace(): self.consume_whitespace() continue elif startswith in ("//", "/*"): if startswith == "/*" and self.try_javadoc_comment(): self.javadoc = self.data[self.i:self.j] self.i = self.j else: self.read_comment() continue elif startswith == '..' and self.try_operator(): # Ensure we don't mistake a '...' operator as a sequence of # three '.' separators. This is done as an optimization instead # of moving try_operator higher in the chain because operators # aren't as common and try_operator is expensive token_type = Operator elif c == '@': token_type = Annotation self.j = self.i + 1 elif c == '.' 
and c_next.isdigit(): token_type = self.read_decimal_float_or_integer() elif self.try_separator(): token_type = Separator elif c in ("'", '"'): token_type = String self.read_string() elif c in '0123456789': token_type = self.read_integer_or_float(c, c_next) elif self.is_java_identifier_start(c): token_type = self.read_identifier() elif self.try_operator(): token_type = Operator else: self.error('Could not process token', c) position = (self.current_line, self.i - self.start_of_line) token = token_type(self.data[self.i:self.j], position, self.javadoc) yield token if self.javadoc: self.javadoc = None self.i = self.j def error(self, message, char=None): # Provide additional information in the errors message line_start = self.data.rfind('\n', 0, self.i) + 1 line_end = self.data.find('\n', self.i) line = self.data[line_start:line_end].strip() line_number = self.current_line if not char: char = self.data[self.j] message = u'%s at "%s", line %s: %s' % (message, char, line_number, line) raise LexerError(message) def tokenize(code): tokenizer = JavaTokenizer(code) return tokenizer.tokenize() def reformat_tokens(tokens): indent = 0 closed_block = False ident_last = False output = list() for token in tokens: if closed_block: closed_block = False indent -= 4 output.append('\n') output.append(' ' * indent) output.append('}') if isinstance(token, (Literal, Keyword, Identifier)): output.append('\n') output.append(' ' * indent) if token.value == '{': indent += 4 output.append(' {\n') output.append(' ' * indent) elif token.value == '}': closed_block = True elif token.value == ',': output.append(', ') elif isinstance(token, (Literal, Keyword, Identifier)): if ident_last: # If the last token was a literla/keyword/identifer put a space in between output.append(' ') ident_last = True output.append(token.value) elif isinstance(token, Operator): output.append(' ' + token.value + ' ') elif token.value == ';': output.append(';\n') output.append(' ' * indent) else: 
output.append(token.value) ident_last = isinstance(token, (Literal, Keyword, Identifier)) if closed_block: output.append('\n}') output.append('\n') return ''.join(output) ================================================ FILE: baseline_tokenization/javalang/tree.py ================================================ from .ast import Node # ------------------------------------------------------------------------------ class CompilationUnit(Node): attrs = ("package", "imports", "types") class Import(Node): attrs = ("path", "static", "wildcard") class Documented(Node): attrs = ("documentation",) class Declaration(Node): attrs = ("modifiers", "annotations") class TypeDeclaration(Declaration, Documented): attrs = ("name", "body") @property def fields(self): return [decl for decl in self.body if isinstance(decl, FieldDeclaration)] @property def methods(self): return [decl for decl in self.body if isinstance(decl, MethodDeclaration)] @property def constructors(self): return [decl for decl in self.body if isinstance(decl, ConstructorDeclaration)] class PackageDeclaration(Declaration, Documented): attrs = ("name",) class ClassDeclaration(TypeDeclaration): attrs = ("type_parameters", "extends", "implements") class EnumDeclaration(TypeDeclaration): attrs = ("implements",) class InterfaceDeclaration(TypeDeclaration): attrs = ("type_parameters", "extends",) class AnnotationDeclaration(TypeDeclaration): attrs = () # ------------------------------------------------------------------------------ class Type(Node): attrs = ("name", "dimensions",) class BasicType(Type): attrs = () class ReferenceType(Type): attrs = ("arguments", "sub_type") class TypeArgument(Node): attrs = ("type", "pattern_type") # ------------------------------------------------------------------------------ class TypeParameter(Node): attrs = ("name", "extends") # ------------------------------------------------------------------------------ class Annotation(Node): attrs = ("name", "element") class 
ElementValuePair(Node): attrs = ("name", "value") class ElementArrayValue(Node): attrs = ("values",) # ------------------------------------------------------------------------------ class Member(Documented): attrs = () class MethodDeclaration(Member, Declaration): attrs = ("type_parameters", "return_type", "name", "parameters", "throws", "body") class FieldDeclaration(Member, Declaration): attrs = ("type", "declarators") class ConstructorDeclaration(Declaration, Documented): attrs = ("type_parameters", "name", "parameters", "throws", "body") # ------------------------------------------------------------------------------ class ConstantDeclaration(FieldDeclaration): attrs = () class ArrayInitializer(Node): attrs = ("initializers",) class VariableDeclaration(Declaration): attrs = ("type", "declarators") class LocalVariableDeclaration(VariableDeclaration): attrs = () class VariableDeclarator(Node): attrs = ("name", "dimensions", "initializer") class FormalParameter(Declaration): attrs = ("type", "name", "varargs") class InferredFormalParameter(Node): attrs = ('name',) # ------------------------------------------------------------------------------ class Statement(Node): attrs = ("label",) class IfStatement(Statement): attrs = ("condition", "then_statement", "else_statement") class WhileStatement(Statement): attrs = ("condition", "body") class DoStatement(Statement): attrs = ("condition", "body") class ForStatement(Statement): attrs = ("control", "body") class AssertStatement(Statement): attrs = ("condition", "value") class BreakStatement(Statement): attrs = ("goto",) class ContinueStatement(Statement): attrs = ("goto",) class ReturnStatement(Statement): attrs = ("expression",) class ThrowStatement(Statement): attrs = ("expression",) class SynchronizedStatement(Statement): attrs = ("lock", "block") class TryStatement(Statement): attrs = ("resources", "block", "catches", "finally_block") class SwitchStatement(Statement): attrs = ("expression", "cases") class 
BlockStatement(Statement): attrs = ("statements",) class StatementExpression(Statement): attrs = ("expression",) # ------------------------------------------------------------------------------ class TryResource(Declaration): attrs = ("type", "name", "value") class CatchClause(Statement): attrs = ("parameter", "block") class CatchClauseParameter(Declaration): attrs = ("types", "name") # ------------------------------------------------------------------------------ class SwitchStatementCase(Node): attrs = ("case", "statements") class ForControl(Node): attrs = ("init", "condition", "update") class EnhancedForControl(Node): attrs = ("var", "iterable") # ------------------------------------------------------------------------------ class Expression(Node): attrs = () class Assignment(Expression): attrs = ("expressionl", "value", "type") class TernaryExpression(Expression): attrs = ("condition", "if_true", "if_false") class BinaryOperation(Expression): attrs = ("operator", "operandl", "operandr") class Cast(Expression): attrs = ("type", "expression") class MethodReference(Expression): attrs = ("expression", "method", "type_arguments") class LambdaExpression(Expression): attrs = ('parameters', 'body') # ------------------------------------------------------------------------------ class Primary(Expression): attrs = ("prefix_operators", "postfix_operators", "qualifier", "selectors") class Literal(Primary): attrs = ("value",) class This(Primary): attrs = () class MemberReference(Primary): attrs = ("member",) class Invocation(Primary): attrs = ("type_arguments", "arguments") class ExplicitConstructorInvocation(Invocation): attrs = () class SuperConstructorInvocation(Invocation): attrs = () class MethodInvocation(Invocation): attrs = ("member",) class SuperMethodInvocation(Invocation): attrs = ("member",) class SuperMemberReference(Primary): attrs = ("member",) class ArraySelector(Expression): attrs = ("index",) class ClassReference(Primary): attrs = ("type",) class 
VoidClassReference(ClassReference): attrs = () # ------------------------------------------------------------------------------ class Creator(Primary): attrs = ("type",) class ArrayCreator(Creator): attrs = ("dimensions", "initializer") class ClassCreator(Creator): attrs = ("constructor_type_arguments", "arguments", "body") class InnerClassCreator(Creator): attrs = ("constructor_type_arguments", "arguments", "body") # ------------------------------------------------------------------------------ class EnumBody(Node): attrs = ("constants", "declarations") class EnumConstantDeclaration(Declaration, Documented): attrs = ("name", "arguments", "body") class AnnotationMethod(Declaration): attrs = ("name", "return_type", "dimensions", "default") ================================================ FILE: baseline_tokenization/javalang/util.py ================================================ class LookAheadIterator(object): def __init__(self, iterable): self.iterable = iter(iterable) self.look_ahead = list() self.markers = list() self.default = None self.value = None def __iter__(self): return self def set_default(self, value): self.default = value def next(self): return self.__next__() def __next__(self): if self.look_ahead: self.value = self.look_ahead.pop(0) else: self.value = next(self.iterable) if self.markers: self.markers[-1].append(self.value) return self.value def look(self, i=0): """ Look ahead of the iterable by some number of values with advancing past them. If the requested look ahead is past the end of the iterable then None is returned. 
""" length = len(self.look_ahead) if length <= i: try: self.look_ahead.extend([next(self.iterable) for _ in range(length, i + 1)]) except StopIteration: return self.default self.value = self.look_ahead[i] return self.value def last(self): return self.value def __enter__(self): self.push_marker() return self def __exit__(self, exc_type, exc_val, exc_tb): # Reset the iterator if there was an error if exc_type or exc_val or exc_tb: self.pop_marker(True) else: self.pop_marker(False) def push_marker(self): """ Push a marker on to the marker stack """ self.markers.append(list()) def pop_marker(self, reset): """ Pop a marker off of the marker stack. If reset is True then the iterator will be returned to the state it was in before the corresponding call to push_marker(). """ marker = self.markers.pop() if reset: # Make the values available to be read again marker.extend(self.look_ahead) self.look_ahead = marker elif self.markers: # Otherwise, reassign the values to the top marker self.markers[-1].extend(marker) else: # If there are not more markers in the stack then discard the values pass class LookAheadListIterator(object): def __init__(self, iterable): self.list = list(iterable) self.marker = 0 self.saved_markers = [] self.default = None self.value = None def __iter__(self): return self def set_default(self, value): self.default = value def next(self): return self.__next__() def __next__(self): try: self.value = self.list[self.marker] self.marker += 1 except IndexError: raise StopIteration() return self.value def look(self, i=0): """ Look ahead of the iterable by some number of values with advancing past them. If the requested look ahead is past the end of the iterable then None is returned. 
""" try: self.value = self.list[self.marker + i] except IndexError: return self.default return self.value def last(self): return self.value def __enter__(self): self.push_marker() return self def __exit__(self, exc_type, exc_val, exc_tb): # Reset the iterator if there was an error if exc_type or exc_val or exc_tb: self.pop_marker(True) else: self.pop_marker(False) def push_marker(self): """ Push a marker on to the marker stack """ self.saved_markers.append(self.marker) def pop_marker(self, reset): """ Pop a marker off of the marker stack. If reset is True then the iterator will be returned to the state it was in before the corresponding call to push_marker(). """ saved = self.saved_markers.pop() if reset: self.marker = saved elif self.saved_markers: self.saved_markers[-1] = saved ================================================ FILE: baseline_tokenization/subtokenize_nmt_baseline.py ================================================ #!/usr/bin/python import javalang import sys import re modifiers = ['public', 'private', 'protected', 'static'] RE_WORDS = re.compile(r''' # Find words in a string. Order matters! 
RE_WORDS = re.compile(r'''
    # Find words in a string. Order matters!
    [A-Z]+(?=[A-Z][a-z]) |  # All upper case before a capitalized word
    [A-Z]?[a-z]+ |          # Capitalized words / all lower case
    [A-Z]+ |                # All upper case
    \d+ |                   # Numbers
    .+
''', re.VERBOSE)


def split_subtokens(str):
    """Split an identifier into subtokens (camelCase parts, digit runs),
    dropping bare '_' matches.

    NOTE(review): the parameter name shadows the builtin ``str``; kept
    unchanged for backward compatibility with existing callers.
    """
    return [subtok for subtok in RE_WORDS.findall(str) if not subtok == '_']


def tokenizeFile(file_path):
    """Tokenize ``name|content`` lines from *file_path* into two files.

    Writes ``<file_path>method_names.txt`` (one method name per line) and
    ``<file_path>method_subtokens_content.txt`` (space-joined subtokens of
    the method body, with access modifiers filtered out). Malformed or
    untokenizable lines are reported and skipped.
    """
    lines = 0
    # encoding added on the output files too: the input is read as UTF-8,
    # so writing with the locale-dependent default could raise
    # UnicodeEncodeError on non-UTF-8 platforms.
    with open(file_path, 'r', encoding="utf-8") as file:
        with open(file_path + 'method_names.txt', 'w', encoding="utf-8") as method_names_file:
            with open(file_path + 'method_subtokens_content.txt', 'w', encoding="utf-8") as method_contents_file:
                for line in file:
                    lines += 1
                    line = line.rstrip()
                    parts = line.split('|', 1)
                    if len(parts) < 2:
                        # Robustness fix: a line without the name|content
                        # separator previously raised IndexError.
                        print('ERROR in len of: ' + line + ', tokens: []')
                        continue
                    method_name = parts[0]
                    method_content = parts[1]
                    # Bug fix: `tokens` was previously left unbound (first
                    # iteration) or stale (later iterations) when tokenizing
                    # failed; initialize it so the error branch below runs.
                    tokens = []
                    try:
                        tokens = list(javalang.tokenizer.tokenize(method_content))
                    except Exception:
                        # Narrowed from a bare `except:` so Ctrl-C still works.
                        print('ERROR in tokenizing: ' + method_content)
                    if len(method_name) > 0 and len(tokens) > 0:
                        method_names_file.write(method_name + '\n')
                        method_contents_file.write(' '.join(
                            [' '.join(split_subtokens(i.value))
                             for i in tokens if not i.value in modifiers]) + '\n')
                    else:
                        print('ERROR in len of: ' + method_name + ', tokens: ' + str(tokens))
    print(str(lines))


if __name__ == '__main__':
    file = sys.argv[1]
    tokenizeFile(file)


# ================ FILE: code2seq.py ================
from argparse import ArgumentParser

import numpy as np
import tensorflow as tf

from config import Config
from interactive_predict import InteractivePredictor
from model import Model

if __name__ == '__main__':
    parser = ArgumentParser()
    parser.add_argument("-d", "--data", dest="data_path",
                        help="path to preprocessed dataset", required=False)
    parser.add_argument("-te", "--test", dest="test_path",
                        help="path to test file", metavar="FILE", required=False)
    parser.add_argument("-s", "--save_prefix", dest="save_path_prefix",
                        help="path to save file", metavar="FILE", required=False)
    parser.add_argument("-l", "--load", dest="load_path",
                        help="path to saved file", metavar="FILE", required=False)
parser.add_argument('--release', action='store_true', help='if specified and loading a trained model, release the loaded model for a smaller model ' 'size.') parser.add_argument('--predict', action='store_true') parser.add_argument('--debug', action='store_true') parser.add_argument('--seed', type=int, default=239) args = parser.parse_args() np.random.seed(args.seed) tf.set_random_seed(args.seed) if args.debug: config = Config.get_debug_config(args) else: config = Config.get_default_config(args) model = Model(config) print('Created model') if config.TRAIN_PATH: model.train() if config.TEST_PATH and not args.data_path: results, precision, recall, f1, rouge = model.evaluate() print('Accuracy: ' + str(results)) print('Precision: ' + str(precision) + ', recall: ' + str(recall) + ', F1: ' + str(f1)) print('Rouge: ', rouge) if args.predict: predictor = InteractivePredictor(config, model) predictor.predict() if args.release and args.load_path: model.evaluate(release=True) model.close_session() ================================================ FILE: common.py ================================================ import re import subprocess import sys class Common: internal_delimiter = '|' SOS = '' EOS = '' PAD = '' UNK = '' @staticmethod def normalize_word(word): stripped = re.sub(r'[^a-zA-Z]', '', word) if len(stripped) == 0: return word.lower() else: return stripped.lower() @staticmethod def load_histogram(path, max_size=None): histogram = {} with open(path, 'r') as file: for line in file.readlines(): parts = line.split(' ') if not len(parts) == 2: continue histogram[parts[0]] = int(parts[1]) sorted_histogram = [(k, histogram[k]) for k in sorted(histogram, key=histogram.get, reverse=True)] return dict(sorted_histogram[:max_size]) @staticmethod def load_vocab_from_dict(word_to_count, add_values=[], max_size=None): word_to_index, index_to_word = {}, {} current_index = 0 for value in add_values: word_to_index[value] = current_index index_to_word[current_index] = value 
current_index += 1 sorted_counts = [(k, word_to_count[k]) for k in sorted(word_to_count, key=word_to_count.get, reverse=True)] limited_sorted = dict(sorted_counts[:max_size]) for word, count in limited_sorted.items(): word_to_index[word] = current_index index_to_word[current_index] = word current_index += 1 return word_to_index, index_to_word, current_index @staticmethod def binary_to_string(binary_string): return binary_string.decode("utf-8") @staticmethod def binary_to_string_list(binary_string_list): return [Common.binary_to_string(w) for w in binary_string_list] @staticmethod def binary_to_string_matrix(binary_string_matrix): return [Common.binary_to_string_list(l) for l in binary_string_matrix] @staticmethod def binary_to_string_3d(binary_string_tensor): return [Common.binary_to_string_matrix(l) for l in binary_string_tensor] @staticmethod def legal_method_names_checker(name): return not name in [Common.UNK, Common.PAD, Common.EOS] @staticmethod def filter_impossible_names(top_words): result = list(filter(Common.legal_method_names_checker, top_words)) return result @staticmethod def unique(sequence): return list(set(sequence)) @staticmethod def parse_results(result, pc_info_dict, topk=5): prediction_results = {} results_counter = 0 for single_method in result: original_name, top_suggestions, top_scores, attention_per_context = list(single_method) current_method_prediction_results = PredictionResults(original_name) if attention_per_context is not None: word_attention_pairs = [(word, attention) for word, attention in zip(top_suggestions, attention_per_context) if Common.legal_method_names_checker(word)] for predicted_word, attention_timestep in word_attention_pairs: current_timestep_paths = [] for context, attention in [(key, attention_timestep[key]) for key in sorted(attention_timestep, key=attention_timestep.get, reverse=True)][ :topk]: if context in pc_info_dict: pc_info = pc_info_dict[context] current_timestep_paths.append((attention.item(), pc_info)) 
current_method_prediction_results.append_prediction(predicted_word, current_timestep_paths) else: for predicted_seq in top_suggestions: filtered_seq = [word for word in predicted_seq if Common.legal_method_names_checker(word)] current_method_prediction_results.append_prediction(filtered_seq, None) prediction_results[results_counter] = current_method_prediction_results results_counter += 1 return prediction_results @staticmethod def compute_bleu(ref_file_name, predicted_file_name): with open(predicted_file_name) as predicted_file: pipe = subprocess.Popen(["perl", "scripts/multi-bleu.perl", ref_file_name], stdin=predicted_file, stdout=sys.stdout, stderr=sys.stderr) class PredictionResults: def __init__(self, original_name): self.original_name = original_name self.predictions = list() def append_prediction(self, name, current_timestep_paths): self.predictions.append(SingleTimeStepPrediction(name, current_timestep_paths)) class SingleTimeStepPrediction: def __init__(self, prediction, attention_paths): self.prediction = prediction if attention_paths is not None: paths_with_scores = [] for attention_score, pc_info in attention_paths: path_context_dict = {'score': attention_score, 'path': pc_info.longPath, 'token1': pc_info.token1, 'token2': pc_info.token2} paths_with_scores.append(path_context_dict) self.attention_paths = paths_with_scores class PathContextInformation: def __init__(self, context): self.token1 = context['name1'] self.longPath = context['path'] self.shortPath = context['shortPath'] self.token2 = context['name2'] def __str__(self): return '%s,%s,%s' % (self.token1, self.shortPath, self.token2) ================================================ FILE: config.py ================================================ class Config: @staticmethod def get_default_config(args): config = Config(args) config.NUM_EPOCHS = 3000 config.SAVE_EVERY_EPOCHS = 1 config.PATIENCE = 10 config.BATCH_SIZE = 512 config.TEST_BATCH_SIZE = 256 config.READER_NUM_PARALLEL_BATCHES = 1 
        config.SHUFFLE_BUFFER_SIZE = 10000
        config.CSV_BUFFER_SIZE = 100 * 1024 * 1024  # 100 MB
        config.MAX_CONTEXTS = 200
        config.SUBTOKENS_VOCAB_MAX_SIZE = 190000
        config.TARGET_VOCAB_MAX_SIZE = 27000
        config.EMBEDDINGS_SIZE = 128
        config.RNN_SIZE = 128 * 2  # Two LSTMs to embed paths, each of size 128
        config.DECODER_SIZE = 320
        config.NUM_DECODER_LAYERS = 1
        config.MAX_PATH_LENGTH = 8 + 1
        config.MAX_NAME_PARTS = 5
        config.MAX_TARGET_PARTS = 6
        config.EMBEDDINGS_DROPOUT_KEEP_PROB = 0.75
        config.RNN_DROPOUT_KEEP_PROB = 0.5
        config.BIRNN = True
        config.RANDOM_CONTEXTS = True
        config.BEAM_WIDTH = 0  # 0 = greedy decoding (enables attention output)
        config.USE_MOMENTUM = True
        return config

    def take_model_hyperparams_from(self, otherConfig):
        # Copy the architecture-defining hyperparameters from a loaded model's
        # config so that evaluation/prediction rebuilds the exact same graph.
        self.EMBEDDINGS_SIZE = otherConfig.EMBEDDINGS_SIZE
        self.RNN_SIZE = otherConfig.RNN_SIZE
        self.DECODER_SIZE = otherConfig.DECODER_SIZE
        self.NUM_DECODER_LAYERS = otherConfig.NUM_DECODER_LAYERS
        self.BIRNN = otherConfig.BIRNN
        if self.DATA_NUM_CONTEXTS <= 0:
            self.DATA_NUM_CONTEXTS = otherConfig.DATA_NUM_CONTEXTS

    def __init__(self, args):
        """Initialize every field to a placeholder value; the factory methods
        (get_default_config / get_debug_config) fill in the real values."""
        self.NUM_EPOCHS = 0
        self.SAVE_EVERY_EPOCHS = 0
        self.PATIENCE = 0
        self.BATCH_SIZE = 0
        self.TEST_BATCH_SIZE = 0
        self.READER_NUM_PARALLEL_BATCHES = 0
        self.SHUFFLE_BUFFER_SIZE = 0
        self.CSV_BUFFER_SIZE = None
        self.TRAIN_PATH = args.data_path
        self.TEST_PATH = args.test_path if args.test_path is not None else ''
        self.DATA_NUM_CONTEXTS = 0
        self.MAX_CONTEXTS = 0
        self.SUBTOKENS_VOCAB_MAX_SIZE = 0
        self.TARGET_VOCAB_MAX_SIZE = 0
        self.EMBEDDINGS_SIZE = 0
        self.RNN_SIZE = 0
        self.DECODER_SIZE = 0
        self.NUM_DECODER_LAYERS = 0
        self.SAVE_PATH = args.save_path_prefix
        self.LOAD_PATH = args.load_path
        self.MAX_PATH_LENGTH = 0
        self.MAX_NAME_PARTS = 0
        self.MAX_TARGET_PARTS = 0
        self.EMBEDDINGS_DROPOUT_KEEP_PROB = 0
        self.RNN_DROPOUT_KEEP_PROB = 0
        self.BIRNN = False
        self.RANDOM_CONTEXTS = True
        self.BEAM_WIDTH = 1
        self.USE_MOMENTUM = True
        self.RELEASE = args.release

    @staticmethod
    def get_debug_config(args):
        """Tiny hyperparameters for fast local debugging runs."""
        config = Config(args)
        config.NUM_EPOCHS = 3000
        config.SAVE_EVERY_EPOCHS = 100
        config.PATIENCE = 200
        config.BATCH_SIZE = 7
        config.TEST_BATCH_SIZE = 7
        config.READER_NUM_PARALLEL_BATCHES = 1
        config.SHUFFLE_BUFFER_SIZE = 10
        config.CSV_BUFFER_SIZE = None
        config.MAX_CONTEXTS = 5
        config.SUBTOKENS_VOCAB_MAX_SIZE = 190000
        config.TARGET_VOCAB_MAX_SIZE = 27000
        config.EMBEDDINGS_SIZE = 19
        config.RNN_SIZE = 10
        config.DECODER_SIZE = 11
        config.NUM_DECODER_LAYERS = 1
        config.MAX_PATH_LENGTH = 8 + 1
        config.MAX_NAME_PARTS = 5
        config.MAX_TARGET_PARTS = 6
        config.EMBEDDINGS_DROPOUT_KEEP_PROB = 1
        config.RNN_DROPOUT_KEEP_PROB = 1
        config.BIRNN = True
        config.RANDOM_CONTEXTS = True
        config.BEAM_WIDTH = 0
        config.USE_MOMENTUM = False
        return config


================================================
FILE: extractor.py
================================================
import json

import requests

from common import PathContextInformation


class Extractor:
    """Client for the remote path-extraction web service used by the
    interactive predictor."""

    def __init__(self, config, extractor_api_url, max_path_length, max_path_width):
        self.config = config
        self.max_path_length = max_path_length
        self.max_path_width = max_path_width
        self.extractor_api_url = extractor_api_url
        # Translation table that deletes tabs/CR/LF from code strings.
        self.bad_characters_table = str.maketrans('', '', '\t\r\n')

    @staticmethod
    def post_request(url, code_string):
        # 'decompose': True asks the service to split the code into methods.
        return requests.post(url, data=json.dumps({"code": code_string, "decompose": True}, separators=(',', ':')))

    def extract_paths(self, code_string):
        """POST `code_string` to the extractor service and return
        (model-input lines, {(token1, shortPath, token2): PathContextInformation}).

        Raises ValueError when the service reports an extraction error and
        TimeoutError when it reports an execution error/timeout."""
        response = self.post_request(self.extractor_api_url, code_string)
        response_array = json.loads(response.text)
        if 'errorType' in response_array:
            raise ValueError(response.text)
        if 'errorMessage' in response_array:
            raise TimeoutError(response.text)
        pc_info_dict = {}
        result = []
        for single_method in response_array:
            method_name = single_method['target']
            current_result_line_parts = [method_name]
            contexts = single_method['paths']
            for context in contexts[:self.config.DATA_NUM_CONTEXTS]:
                pc_info = PathContextInformation(context)
                current_result_line_parts += [str(pc_info)]
                pc_info_dict[(pc_info.token1, pc_info.shortPath, pc_info.token2)] = pc_info
            # Pad every line out to exactly DATA_NUM_CONTEXTS contexts.
            space_padding = ' ' * \
                (self.config.DATA_NUM_CONTEXTS - len(contexts))
            result_line = ' '.join(current_result_line_parts) + space_padding
            result.append(result_line)
        return result, pc_info_dict


================================================
FILE: interactive_predict.py
================================================
from common import Common
from extractor import Extractor

SHOW_TOP_CONTEXTS = 10
MAX_PATH_LENGTH = 8
MAX_PATH_WIDTH = 2
EXTRACTION_API = 'https://po3g2dx2qa.execute-api.us-east-1.amazonaws.com/production/extractmethods'


class InteractivePredictor:
    """Console loop that reads Input.java, extracts path-contexts via the
    remote service and prints the model's predictions with attention."""

    exit_keywords = ['exit', 'quit', 'q']

    def __init__(self, config, model):
        # Warm-up call: forces the model to build its prediction graph and
        # load weights before the first real request.
        model.predict([])
        self.model = model
        self.config = config
        self.path_extractor = Extractor(config,
                                        EXTRACTION_API,
                                        self.config.MAX_PATH_LENGTH,
                                        max_path_width=2)

    @staticmethod
    def read_file(input_filename):
        with open(input_filename, 'r') as file:
            return file.readlines()

    def predict(self):
        """Loop: wait for the user to edit Input.java, then extract, predict
        and pretty-print results until an exit keyword is entered."""
        input_filename = 'Input.java'
        print('Serving')
        while True:
            print('Modify the file: "' + input_filename + '" and press any key when ready, or "q" / "exit" to exit')
            user_input = input()
            if user_input.lower() in self.exit_keywords:
                print('Exiting...')
                return
            user_input = ' '.join(self.read_file(input_filename))
            try:
                predict_lines, pc_info_dict = self.path_extractor.extract_paths(user_input)
            except ValueError:
                # Extraction failed (e.g. the code does not parse) -- ask again.
                continue
            model_results = self.model.predict(predict_lines)

            prediction_results = Common.parse_results(model_results, pc_info_dict, topk=SHOW_TOP_CONTEXTS)
            for index, method_prediction in prediction_results.items():
                print('Original name:\t' + method_prediction.original_name)
                if self.config.BEAM_WIDTH == 0:
                    # Greedy decoding: one word per timestep, with attention.
                    print('Predicted:\t%s' % [step.prediction for step in method_prediction.predictions])
                    for timestep, single_timestep_prediction in enumerate(method_prediction.predictions):
                        print('Attention:')
                        print('TIMESTEP: %d\t: %s' % (timestep, single_timestep_prediction.prediction))
                        for attention_obj in single_timestep_prediction.attention_paths:
                            print('%f\tcontext: %s,%s,%s' % (
                                attention_obj['score'], attention_obj['token1'], attention_obj['path'],
                                attention_obj['token2']))
                else:
                    # Beam search: one full sequence per beam, no attention.
                    print('Predicted:')
                    for predicted_seq in method_prediction.predictions:
                        print('\t%s' % predicted_seq.prediction)


================================================
FILE: model.py
================================================
import _pickle as pickle
import os
import time

import numpy as np
import shutil
import tensorflow as tf

import reader
from common import Common
from rouge import FilesRouge


class Model:
    topk = 10                 # beam candidates kept per prediction
    num_batches_to_log = 100  # progress-trace frequency

    def __init__(self, config):
        """Create the session and load vocabularies -- either from a saved
        model's '.dict' file (LOAD_PATH) or from the training-data
        '.dict.c2s' histograms (TRAIN_PATH)."""
        self.config = config
        self.sess = tf.Session()

        self.eval_queue = None
        self.predict_queue = None

        self.eval_placeholder = None
        self.predict_placeholder = None
        self.eval_predicted_indices_op, self.eval_top_values_op, self.eval_true_target_strings_op, self.eval_topk_values = None, None, None, None
        self.predict_top_indices_op, self.predict_top_scores_op, self.predict_target_strings_op = None, None, None

        self.subtoken_to_index = None

        if config.LOAD_PATH:
            # sess=None: only the dictionaries are loaded here; weights are
            # restored later when a graph exists.
            self.load_model(sess=None)
        else:
            with open('{}.dict.c2s'.format(config.TRAIN_PATH), 'rb') as file:
                subtoken_to_count = pickle.load(file)
                node_to_count = pickle.load(file)
                target_to_count = pickle.load(file)
                max_contexts = pickle.load(file)
                self.num_training_examples = pickle.load(file)
                print('Dictionaries loaded.')

            if self.config.DATA_NUM_CONTEXTS <= 0:
                self.config.DATA_NUM_CONTEXTS = max_contexts
            self.subtoken_to_index, self.index_to_subtoken, self.subtoken_vocab_size = \
                Common.load_vocab_from_dict(subtoken_to_count, add_values=[Common.PAD, Common.UNK],
                                            max_size=config.SUBTOKENS_VOCAB_MAX_SIZE)
            print('Loaded subtoken vocab. size: %d' % self.subtoken_vocab_size)

            self.target_to_index, self.index_to_target, self.target_vocab_size = \
                Common.load_vocab_from_dict(target_to_count, add_values=[Common.PAD, Common.UNK, Common.SOS],
                                            max_size=config.TARGET_VOCAB_MAX_SIZE)
            print('Loaded target word vocab. '
                  'size: %d' % self.target_vocab_size)

            self.node_to_index, self.index_to_node, self.nodes_vocab_size = \
                Common.load_vocab_from_dict(node_to_count, add_values=[Common.PAD, Common.UNK],
                                            max_size=None)
            print('Loaded nodes vocab. size: %d' % self.nodes_vocab_size)
        self.epochs_trained = 0

    def close_session(self):
        self.sess.close()

    def train(self):
        """Main training loop.

        Runs SAVE_EVERY_EPOCHS epochs of data at a time (the reader signals
        the boundary with OutOfRangeError), evaluates, saves the checkpoint
        when F1 improved, and early-stops after PATIENCE epochs without
        improvement."""
        print('Starting training')
        start_time = time.time()

        batch_num = 0
        sum_loss = 0
        best_f1 = 0
        best_epoch = 0
        best_f1_precision = 0
        best_f1_recall = 0
        epochs_no_improve = 0

        self.queue_thread = reader.Reader(subtoken_to_index=self.subtoken_to_index,
                                          node_to_index=self.node_to_index,
                                          target_to_index=self.target_to_index,
                                          config=self.config)
        optimizer, train_loss = self.build_training_graph(self.queue_thread.get_output())
        self.print_hyperparams()
        print('Number of trainable params:',
              np.sum([np.prod(v.get_shape().as_list()) for v in tf.trainable_variables()]))
        self.initialize_session_variables(self.sess)
        print('Initalized variables')
        if self.config.LOAD_PATH:
            self.load_model(self.sess)

        time.sleep(1)
        print('Started reader...')

        multi_batch_start_time = time.time()
        for iteration in range(1, (self.config.NUM_EPOCHS // self.config.SAVE_EVERY_EPOCHS) + 1):
            self.queue_thread.reset(self.sess)
            try:
                while True:
                    batch_num += 1
                    _, batch_loss = self.sess.run([optimizer, train_loss])
                    sum_loss += batch_loss
                    # print('SINGLE BATCH LOSS', batch_loss)
                    if batch_num % self.num_batches_to_log == 0:
                        self.trace(sum_loss, batch_num, multi_batch_start_time)
                        sum_loss = 0
                        multi_batch_start_time = time.time()
            except tf.errors.OutOfRangeError:
                # The reader exhausted SAVE_EVERY_EPOCHS epochs of data.
                self.epochs_trained += self.config.SAVE_EVERY_EPOCHS
                print('Finished %d epochs' % self.config.SAVE_EVERY_EPOCHS)
                results, precision, recall, f1, rouge = self.evaluate()
                if self.config.BEAM_WIDTH == 0:
                    print('Accuracy after %d epochs: %.5f' % (self.epochs_trained, results))
                else:
                    # Beam search: `results` is a per-rank vector, not a scalar.
                    print('Accuracy after {} epochs: {}'.format(self.epochs_trained, results))
                print('After %d epochs: Precision: %.5f, recall: %.5f, F1: %.5f' % (
                    self.epochs_trained, precision, recall, f1))
                print('Rouge: ', rouge)
                if f1 > best_f1:
                    best_f1 = f1
                    best_f1_precision = precision
                    best_f1_recall = recall
                    best_epoch = self.epochs_trained
                    epochs_no_improve = 0
                    self.save_model(self.sess, self.config.SAVE_PATH)
                else:
                    epochs_no_improve += self.config.SAVE_EVERY_EPOCHS
                    if epochs_no_improve >= self.config.PATIENCE:
                        print('Not improved for %d epochs, stopping training' % self.config.PATIENCE)
                        print('Best scores - epoch %d: ' % best_epoch)
                        print('Precision: %.5f, recall: %.5f, F1: %.5f' % (best_f1_precision, best_f1_recall, best_f1))
                        return

        if self.config.SAVE_PATH:
            self.save_model(self.sess, self.config.SAVE_PATH + '.final')
            print('Model saved in file: %s' % self.config.SAVE_PATH)

        elapsed = int(time.time() - start_time)
        print("Training time: %sh%sm%ss\n" % ((elapsed // 60 // 60), (elapsed // 60) % 60, elapsed % 60))

    def trace(self, sum_loss, batch_num, multi_batch_start_time):
        # Report average loss and throughput over the last
        # num_batches_to_log batches.
        multi_batch_elapsed = time.time() - multi_batch_start_time
        avg_loss = sum_loss / self.num_batches_to_log
        print('Average loss at batch %d: %f, \tthroughput: %d samples/sec' % (
            batch_num, avg_loss, self.config.BATCH_SIZE * self.num_batches_to_log / (
                multi_batch_elapsed if multi_batch_elapsed > 0 else 1)))

    def evaluate(self, release=False):
        """Evaluate on TEST_PATH, writing ref/pred/log files next to the
        checkpoint; returns (accuracy, precision, recall, f1, rouge).

        With release=True (and a loaded model), just re-save the checkpoint
        as a smaller '.release' model and return None."""
        eval_start_time = time.time()
        if self.eval_queue is None:
            # Build the evaluation reader + test graph once, lazily.
            self.eval_queue = reader.Reader(subtoken_to_index=self.subtoken_to_index,
                                            node_to_index=self.node_to_index,
                                            target_to_index=self.target_to_index,
                                            config=self.config, is_evaluating=True)
            reader_output = self.eval_queue.get_output()
            self.eval_predicted_indices_op, self.eval_topk_values, _, _ = \
                self.build_test_graph(reader_output)
            self.eval_true_target_strings_op = reader_output[reader.TARGET_STRING_KEY]
            self.saver = tf.train.Saver(max_to_keep=10)

        if self.config.LOAD_PATH and not self.config.TRAIN_PATH:
            self.initialize_session_variables(self.sess)
            self.load_model(self.sess)
            if release:
                release_name = self.config.LOAD_PATH + '.release'
                print('Releasing model, output model: %s' % release_name)
                self.saver.save(self.sess, release_name)
                shutil.copyfile(src=self.config.LOAD_PATH + '.dict', dst=release_name + '.dict')
                return None

        model_dirname = os.path.dirname(self.config.SAVE_PATH if self.config.SAVE_PATH else self.config.LOAD_PATH)
        ref_file_name = model_dirname + '/ref.txt'
        predicted_file_name = model_dirname + '/pred.txt'
        if not os.path.exists(model_dirname):
            os.makedirs(model_dirname)

        with open(model_dirname + '/log.txt', 'w') as output_file, open(ref_file_name, 'w') as ref_file, open(
                predicted_file_name, 'w') as pred_file:
            # Greedy: scalar counter; beam: one counter per beam rank.
            num_correct_predictions = 0 if self.config.BEAM_WIDTH == 0 \
                else np.zeros([self.config.BEAM_WIDTH], dtype=np.int32)
            total_predictions = 0
            total_prediction_batches = 0
            true_positive, false_positive, false_negative = 0, 0, 0
            self.eval_queue.reset(self.sess)
            start_time = time.time()

            try:
                while True:
                    predicted_indices, true_target_strings, top_values = self.sess.run(
                        [self.eval_predicted_indices_op, self.eval_true_target_strings_op, self.eval_topk_values],
                    )
                    true_target_strings = Common.binary_to_string_list(true_target_strings)
                    ref_file.write(
                        '\n'.join(
                            [name.replace(Common.internal_delimiter, ' ') for name in true_target_strings]) + '\n')
                    if self.config.BEAM_WIDTH > 0:
                        # predicted indices: (batch, time, beam_width)
                        predicted_strings = [[[self.index_to_target[i] for i in timestep] for timestep in example] for
                                             example in predicted_indices]
                        predicted_strings = [list(map(list, zip(*example))) for example in
                                             predicted_strings]  # (batch, top-k, target_length)
                        pred_file.write('\n'.join(
                            [' '.join(Common.filter_impossible_names(words)) for words in
                             predicted_strings[0]]) + '\n')
                    else:
                        predicted_strings = [[self.index_to_target[i] for i in example]
                                             for example in predicted_indices]
                        pred_file.write('\n'.join(
                            [' '.join(Common.filter_impossible_names(words)) for words in predicted_strings]) + '\n')

                    num_correct_predictions = self.update_correct_predictions(num_correct_predictions, output_file,
                                                                              zip(true_target_strings,
                                                                                  predicted_strings))
                    true_positive, false_positive, false_negative = self.update_per_subtoken_statistics(
                        zip(true_target_strings, predicted_strings),
                        true_positive, false_positive, false_negative)
                    total_predictions += len(true_target_strings)
                    total_prediction_batches += 1
                    if total_prediction_batches % self.num_batches_to_log == 0:
                        elapsed = time.time() - start_time
                        self.trace_evaluation(output_file, num_correct_predictions, total_predictions, elapsed)
            except tf.errors.OutOfRangeError:
                pass  # dataset exhausted -- normal end of evaluation

            print('Done testing, epoch reached')
            output_file.write(str(num_correct_predictions / total_predictions) + '\n')
            # Common.compute_bleu(ref_file_name, predicted_file_name)

        elapsed = int(time.time() - eval_start_time)
        precision, recall, f1 = self.calculate_results(true_positive, false_positive, false_negative)
        try:
            files_rouge = FilesRouge()
            rouge = files_rouge.get_scores(
                hyp_path=predicted_file_name, ref_path=ref_file_name, avg=True, ignore_empty=True)
        except ValueError:
            # rouge raises ValueError on empty hypothesis files.
            rouge = 0
        print("Evaluation time: %sh%sm%ss" % ((elapsed // 60 // 60), (elapsed // 60) % 60, elapsed % 60))
        return num_correct_predictions / total_predictions, \
               precision, recall, f1, rouge

    def update_correct_predictions(self, num_correct_predictions, output_file, results):
        """Count exact-match predictions.

        Greedy (BEAM_WIDTH == 0): increments a scalar when the filtered
        subtoken lists match exactly, as sets, or as concatenated strings.
        Beam: adds a 0/1 vector crediting every beam rank from the first
        correct candidate downwards."""
        for original_name, predicted in results:
            original_name_parts = original_name.split(Common.internal_delimiter)  # list
            filtered_original = Common.filter_impossible_names(original_name_parts)  # list
            predicted_first = predicted
            if self.config.BEAM_WIDTH > 0:
                predicted_first = predicted[0]
            filtered_predicted_first_parts = Common.filter_impossible_names(predicted_first)  # list

            if self.config.BEAM_WIDTH == 0:
                output_file.write('Original: ' + Common.internal_delimiter.join(original_name_parts) +
                                  ' , predicted 1st: ' + Common.internal_delimiter.join(
                    filtered_predicted_first_parts) + '\n')
                if filtered_original == filtered_predicted_first_parts or Common.unique(
                        filtered_original) == Common.unique(
                    filtered_predicted_first_parts) or ''.join(filtered_original) == ''.join(
                    filtered_predicted_first_parts):
                    num_correct_predictions += 1
            else:
                filtered_predicted = [Common.internal_delimiter.join(Common.filter_impossible_names(p)) for p in
                                      predicted]

                true_ref = original_name
                output_file.write('Original: ' + ' '.join(original_name_parts) + '\n')
                for i, p in enumerate(filtered_predicted):
                    output_file.write('\t@{}: {}'.format(i + 1, ' '.join(p.split(Common.internal_delimiter))) + '\n')
                if true_ref in filtered_predicted:
                    index_of_correct = filtered_predicted.index(true_ref)
                    # Credit every rank >= the first correct one.
                    update = np.concatenate(
                        [np.zeros(index_of_correct, dtype=np.int32),
                         np.ones(self.config.BEAM_WIDTH - index_of_correct, dtype=np.int32)])
                    num_correct_predictions += update
        return num_correct_predictions

    def update_per_subtoken_statistics(self, results, true_positive, false_positive, false_negative):
        """Accumulate subtoken-level TP/FP/FN counts for precision/recall/F1
        (an exact full-name match credits all subtokens as TP)."""
        for original_name, predicted in results:
            if self.config.BEAM_WIDTH > 0:
                predicted = predicted[0]  # score only the top beam
            filtered_predicted_names = Common.filter_impossible_names(predicted)
            filtered_original_subtokens = Common.filter_impossible_names(
                original_name.split(Common.internal_delimiter))

            if ''.join(filtered_original_subtokens) == ''.join(filtered_predicted_names):
                true_positive += len(filtered_original_subtokens)
                continue

            for subtok in filtered_predicted_names:
                if subtok in filtered_original_subtokens:
                    true_positive += 1
                else:
                    false_positive += 1
            for subtok in filtered_original_subtokens:
                if not subtok in filtered_predicted_names:
                    false_negative += 1
        return true_positive, false_positive, false_negative

    def print_hyperparams(self):
        """Dump the effective hyperparameters at the start of training."""
        print('Training batch size:\t\t\t', self.config.BATCH_SIZE)
        print('Dataset path:\t\t\t\t', self.config.TRAIN_PATH)
        print('Training file path:\t\t\t', self.config.TRAIN_PATH + '.train.c2s')
        print('Validation path:\t\t\t', self.config.TEST_PATH)
        print('Taking max contexts from each example:\t', self.config.MAX_CONTEXTS)
        print('Random path sampling:\t\t\t', self.config.RANDOM_CONTEXTS)
        print('Embedding size:\t\t\t\t', self.config.EMBEDDINGS_SIZE)
        if self.config.BIRNN:
            print('Using BiLSTMs, each of size:\t\t', self.config.RNN_SIZE // 2)
        else:
            print('Uni-directional LSTM of size:\t\t', self.config.RNN_SIZE)
        print('Decoder size:\t\t\t\t', self.config.DECODER_SIZE)
        print('Decoder layers:\t\t\t\t', self.config.NUM_DECODER_LAYERS)
        print('Max path lengths:\t\t\t', self.config.MAX_PATH_LENGTH)
        print('Max subtokens in a token:\t\t', self.config.MAX_NAME_PARTS)
        print('Max target length:\t\t\t', self.config.MAX_TARGET_PARTS)
        print('Embeddings dropout keep_prob:\t\t', self.config.EMBEDDINGS_DROPOUT_KEEP_PROB)
        print('LSTM dropout keep_prob:\t\t\t', self.config.RNN_DROPOUT_KEEP_PROB)
        print('============================================')

    @staticmethod
    def calculate_results(true_positive, false_positive, false_negative):
        """Precision/recall/F1 from raw counts, guarding every zero division."""
        if true_positive + false_positive > 0:
            precision = true_positive / (true_positive + false_positive)
        else:
            precision = 0
        if true_positive + false_negative > 0:
            recall = true_positive / (true_positive + false_negative)
        else:
            recall = 0
        if precision + recall > 0:
            f1 = 2 * precision * recall / (precision + recall)
        else:
            f1 = 0
        return precision, recall, f1

    @staticmethod
    def trace_evaluation(output_file, correct_predictions, total_predictions, elapsed):
        # Periodic progress report during evaluation (log file + console).
        accuracy_message = str(correct_predictions / total_predictions)
        throughput_message = "Prediction throughput: %d" % int(total_predictions / (elapsed if elapsed > 0 else 1))
        output_file.write(accuracy_message + '\n')
        output_file.write(throughput_message)
        # print(accuracy_message)
        print(throughput_message)

    def build_training_graph(self, input_tensors):
        """Build the training graph over the reader's output tensors and
        return (train_op, loss)."""
        target_index = input_tensors[reader.TARGET_INDEX_KEY]
        target_lengths = input_tensors[reader.TARGET_LENGTH_KEY]
        path_source_indices = input_tensors[reader.PATH_SOURCE_INDICES_KEY]
        node_indices = input_tensors[reader.NODE_INDICES_KEY]
        path_target_indices = input_tensors[reader.PATH_TARGET_INDICES_KEY]
        valid_context_mask = input_tensors[reader.VALID_CONTEXT_MASK_KEY]
        path_source_lengths = input_tensors[reader.PATH_SOURCE_LENGTHS_KEY]
        path_lengths = input_tensors[reader.PATH_LENGTHS_KEY]
        path_target_lengths = input_tensors[reader.PATH_TARGET_LENGTHS_KEY]

        with tf.variable_scope('model'):
            subtoken_vocab = tf.get_variable('SUBTOKENS_VOCAB',
                                             shape=(self.subtoken_vocab_size, self.config.EMBEDDINGS_SIZE),
                                             dtype=tf.float32,
                                             initializer=tf.contrib.layers.variance_scaling_initializer(
                                                 factor=1.0, mode='FAN_OUT', uniform=True))
            target_words_vocab = tf.get_variable('TARGET_WORDS_VOCAB',
                                                 shape=(self.target_vocab_size, self.config.EMBEDDINGS_SIZE),
                                                 dtype=tf.float32,
                                                 initializer=tf.contrib.layers.variance_scaling_initializer(
                                                     factor=1.0, mode='FAN_OUT', uniform=True))
            nodes_vocab = tf.get_variable('NODES_VOCAB',
                                          shape=(self.nodes_vocab_size, self.config.EMBEDDINGS_SIZE),
                                          dtype=tf.float32,
                                          initializer=tf.contrib.layers.variance_scaling_initializer(
                                              factor=1.0, mode='FAN_OUT', uniform=True))
            # (batch, max_contexts, decoder_size)
            batched_contexts = self.compute_contexts(subtoken_vocab=subtoken_vocab, nodes_vocab=nodes_vocab,
                                                     source_input=path_source_indices, nodes_input=node_indices,
                                                     target_input=path_target_indices,
                                                     valid_mask=valid_context_mask,
                                                     path_source_lengths=path_source_lengths,
                                                     path_lengths=path_lengths,
                                                     path_target_lengths=path_target_lengths)

            batch_size = tf.shape(target_index)[0]
            outputs, final_states = self.decode_outputs(target_words_vocab=target_words_vocab,
                                                        target_input=target_index, batch_size=batch_size,
                                                        batched_contexts=batched_contexts,
                                                        valid_mask=valid_context_mask)
            step = tf.Variable(0, trainable=False)

            logits = outputs.rnn_output  # (batch, max_output_length, dim * 2 + rnn_size)

            # Cross-entropy over target subtokens, masked past each target's
            # true length (+1 covers the EOS/terminating position).
            crossent = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=target_index, logits=logits)
            target_words_nonzero = tf.sequence_mask(target_lengths + 1,
                                                    maxlen=self.config.MAX_TARGET_PARTS + 1, dtype=tf.float32)
            loss = tf.reduce_sum(crossent * target_words_nonzero) / tf.to_float(batch_size)

            if self.config.USE_MOMENTUM:
                # Decay the LR by 0.95 once per epoch (step counts batches).
                learning_rate = tf.train.exponential_decay(0.01, step * self.config.BATCH_SIZE,
                                                           self.num_training_examples,
                                                           0.95, staircase=True)
                optimizer = tf.train.MomentumOptimizer(learning_rate, 0.95, use_nesterov=True)
                train_op = optimizer.minimize(loss, global_step=step)
            else:
                params = tf.trainable_variables()
                gradients = tf.gradients(loss, params)
                clipped_gradients, _ = tf.clip_by_global_norm(gradients, clip_norm=5)
                optimizer = tf.train.AdamOptimizer()
                train_op = optimizer.apply_gradients(zip(clipped_gradients, params))

            self.saver = tf.train.Saver(max_to_keep=10)

        return train_op, loss

    def decode_outputs(self, target_words_vocab, target_input, batch_size, batched_contexts, valid_mask,
                       is_evaluating=False):
        """Build the attention decoder.

        Training: TrainingHelper with teacher forcing and dropout.
        Evaluation: greedy decoding (with alignment history) when
        BEAM_WIDTH == 0, otherwise beam search (no alignment history).
        The averaged context vector seeds the initial LSTM state."""
        num_contexts_per_example = tf.count_nonzero(valid_mask, axis=-1)
        start_fill = tf.fill([batch_size],
                             self.target_to_index[Common.SOS])  # (batch, )
        decoder_cell = tf.nn.rnn_cell.MultiRNNCell([
            tf.nn.rnn_cell.LSTMCell(self.config.DECODER_SIZE) for _ in range(self.config.NUM_DECODER_LAYERS)
        ])
        contexts_sum = tf.reduce_sum(batched_contexts * tf.expand_dims(valid_mask, -1),
                                     axis=1)  # (batch_size, dim * 2 + rnn_size)
        contexts_average = tf.divide(contexts_sum, tf.to_float(tf.expand_dims(num_contexts_per_example, -1)))
        fake_encoder_state = tuple(tf.nn.rnn_cell.LSTMStateTuple(contexts_average, contexts_average) for _ in
                                   range(self.config.NUM_DECODER_LAYERS))
        projection_layer = tf.layers.Dense(self.target_vocab_size, use_bias=False)
        if is_evaluating and self.config.BEAM_WIDTH > 0:
            # Beam search requires the memory tiled BEAM_WIDTH times.
            batched_contexts = tf.contrib.seq2seq.tile_batch(batched_contexts, multiplier=self.config.BEAM_WIDTH)
            num_contexts_per_example = tf.contrib.seq2seq.tile_batch(num_contexts_per_example,
                                                                     multiplier=self.config.BEAM_WIDTH)
        attention_mechanism = tf.contrib.seq2seq.LuongAttention(
            num_units=self.config.DECODER_SIZE,
            memory=batched_contexts
        )
        # TF doesn't support beam search with alignment history
        should_save_alignment_history = is_evaluating and self.config.BEAM_WIDTH == 0
        decoder_cell = tf.contrib.seq2seq.AttentionWrapper(decoder_cell, attention_mechanism,
                                                           attention_layer_size=self.config.DECODER_SIZE,
                                                           alignment_history=should_save_alignment_history)
        if is_evaluating:
            if self.config.BEAM_WIDTH > 0:
                decoder_initial_state = decoder_cell.zero_state(dtype=tf.float32,
                                                                batch_size=batch_size * self.config.BEAM_WIDTH)
                decoder_initial_state = decoder_initial_state.clone(
                    cell_state=tf.contrib.seq2seq.tile_batch(fake_encoder_state,
                                                             multiplier=self.config.BEAM_WIDTH))
                decoder = tf.contrib.seq2seq.BeamSearchDecoder(
                    cell=decoder_cell,
                    embedding=target_words_vocab,
                    start_tokens=start_fill,
                    end_token=self.target_to_index[Common.PAD],
                    initial_state=decoder_initial_state,
                    beam_width=self.config.BEAM_WIDTH,
                    output_layer=projection_layer,
                    length_penalty_weight=0.0)
            else:
                # End token index 0 == PAD (first entry added to the vocab).
                helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(target_words_vocab, start_fill, 0)
                initial_state = decoder_cell.zero_state(batch_size, tf.float32).clone(
                    cell_state=fake_encoder_state)
                decoder = tf.contrib.seq2seq.BasicDecoder(cell=decoder_cell, helper=helper,
                                                          initial_state=initial_state,
                                                          output_layer=projection_layer)

        else:
            decoder_cell = tf.nn.rnn_cell.DropoutWrapper(decoder_cell,
                                                         output_keep_prob=self.config.RNN_DROPOUT_KEEP_PROB)
            # Teacher forcing: feed <SOS> followed by the gold target tokens.
            target_words_embedding = tf.nn.embedding_lookup(target_words_vocab,
                                                            tf.concat([tf.expand_dims(start_fill, -1), target_input],
                                                                      axis=-1))  # (batch, max_target_parts, dim * 2 + rnn_size)
            helper = tf.contrib.seq2seq.TrainingHelper(inputs=target_words_embedding,
                                                       sequence_length=tf.ones([batch_size], dtype=tf.int32) * (
                                                           self.config.MAX_TARGET_PARTS + 1))

            initial_state = decoder_cell.zero_state(batch_size, tf.float32).clone(cell_state=fake_encoder_state)

            decoder = tf.contrib.seq2seq.BasicDecoder(cell=decoder_cell, helper=helper, initial_state=initial_state,
                                                      output_layer=projection_layer)
        outputs, final_states, final_sequence_lengths = tf.contrib.seq2seq.dynamic_decode(
            decoder, maximum_iterations=self.config.MAX_TARGET_PARTS + 1)
        return outputs, final_states

    def calculate_path_abstraction(self, path_embed, path_lengths, valid_contexts_mask, is_evaluating=False):
        # Encode each AST path (sequence of node embeddings) into one vector.
        return self.path_rnn_last_state(is_evaluating, path_embed, path_lengths, valid_contexts_mask)

    def path_rnn_last_state(self, is_evaluating, path_embed, path_lengths, valid_contexts_mask):
        """Run a (Bi)LSTM over every path and return its final hidden state.

        path_embed:           (batch, max_contexts, max_path_length+1, dim)
        path_length:          (batch, max_contexts)
        valid_contexts_mask:  (batch, max_contexts)
        Returns:              (batch, max_contexts, rnn_size)
        """
        max_contexts = tf.shape(path_embed)[1]
        flat_paths = tf.reshape(path_embed, shape=[-1, self.config.MAX_PATH_LENGTH,
                                                   self.config.EMBEDDINGS_SIZE])  # (batch * max_contexts, max_path_length+1, dim)
        flat_valid_contexts_mask = tf.reshape(valid_contexts_mask, [-1])  # (batch * max_contexts)
        # Zero out the lengths of padding contexts so the RNN skips them.
        lengths = tf.multiply(tf.reshape(path_lengths, [-1]),
                              tf.cast(flat_valid_contexts_mask, tf.int32))  # (batch * max_contexts)
        if self.config.BIRNN:
            # Two half-size LSTMs; their final states are concatenated.
            rnn_cell_fw = tf.nn.rnn_cell.LSTMCell(self.config.RNN_SIZE / 2)
            rnn_cell_bw = tf.nn.rnn_cell.LSTMCell(self.config.RNN_SIZE / 2)
            if not is_evaluating:
                rnn_cell_fw = tf.nn.rnn_cell.DropoutWrapper(rnn_cell_fw,
                                                            output_keep_prob=self.config.RNN_DROPOUT_KEEP_PROB)
                rnn_cell_bw = tf.nn.rnn_cell.DropoutWrapper(rnn_cell_bw,
                                                            output_keep_prob=self.config.RNN_DROPOUT_KEEP_PROB)
            _, (state_fw, state_bw) = tf.nn.bidirectional_dynamic_rnn(
                cell_fw=rnn_cell_fw,
                cell_bw=rnn_cell_bw,
                inputs=flat_paths,
                dtype=tf.float32,
                sequence_length=lengths)
            final_rnn_state = tf.concat([state_fw.h, state_bw.h], axis=-1)  # (batch * max_contexts, rnn_size)
        else:
            rnn_cell = tf.nn.rnn_cell.LSTMCell(self.config.RNN_SIZE)
            if not is_evaluating:
                rnn_cell = tf.nn.rnn_cell.DropoutWrapper(rnn_cell,
                                                         output_keep_prob=self.config.RNN_DROPOUT_KEEP_PROB)
            _, state = tf.nn.dynamic_rnn(
                cell=rnn_cell,
                inputs=flat_paths,
                dtype=tf.float32,
                sequence_length=lengths
            )
            final_rnn_state = state.h  # (batch * max_contexts, rnn_size)

        return tf.reshape(final_rnn_state,
                          shape=[-1, max_contexts, self.config.RNN_SIZE])  # (batch, max_contexts, rnn_size)

    def compute_contexts(self, subtoken_vocab, nodes_vocab, source_input, nodes_input,
                         target_input, valid_mask, path_source_lengths, path_lengths, path_target_lengths,
                         is_evaluating=False):
        """Embed each path-context: sum the subtoken embeddings of both
        terminal tokens, encode the path with the RNN, concatenate, and
        project to DECODER_SIZE with a tanh dense layer."""

        source_word_embed = tf.nn.embedding_lookup(params=subtoken_vocab,
                                                   ids=source_input)  # (batch, max_contexts, max_name_parts, dim)
        path_embed = tf.nn.embedding_lookup(params=nodes_vocab,
                                            ids=nodes_input)  # (batch, max_contexts, max_path_length+1, dim)
        target_word_embed = tf.nn.embedding_lookup(params=subtoken_vocab,
                                                   ids=target_input)  # (batch, max_contexts, max_name_parts, dim)

        source_word_mask = tf.expand_dims(
            tf.sequence_mask(path_source_lengths, maxlen=self.config.MAX_NAME_PARTS, dtype=tf.float32),
            -1)  # (batch, max_contexts, max_name_parts, 1)
        target_word_mask = tf.expand_dims(
            tf.sequence_mask(path_target_lengths, maxlen=self.config.MAX_NAME_PARTS, dtype=tf.float32),
            -1)  # (batch, max_contexts, max_name_parts, 1)

        source_words_sum = tf.reduce_sum(source_word_embed * source_word_mask,
                                         axis=2)  # (batch, max_contexts, dim)
        path_nodes_aggregation = self.calculate_path_abstraction(path_embed, path_lengths, valid_mask,
                                                                 is_evaluating)  # (batch, max_contexts, rnn_size)
        target_words_sum = tf.reduce_sum(target_word_embed * target_word_mask, axis=2)  # (batch, max_contexts, dim)

        context_embed = tf.concat([source_words_sum, path_nodes_aggregation, target_words_sum],
                                  axis=-1)  # (batch, max_contexts, dim * 2 + rnn_size)
        if not is_evaluating:
            context_embed = tf.nn.dropout(context_embed, self.config.EMBEDDINGS_DROPOUT_KEEP_PROB)

        batched_embed = tf.layers.dense(inputs=context_embed, units=self.config.DECODER_SIZE,
                                        activation=tf.nn.tanh, trainable=not is_evaluating, use_bias=False)

        return batched_embed

    def build_test_graph(self, input_tensors):
        """Build the inference graph (frozen vocab variables, is_evaluating
        decoding). Returns (predicted_indices, topk_values, target_index,
        attention_weights); attention is only real when BEAM_WIDTH == 0."""
        target_index = input_tensors[reader.TARGET_INDEX_KEY]
        path_source_indices = input_tensors[reader.PATH_SOURCE_INDICES_KEY]
        node_indices = input_tensors[reader.NODE_INDICES_KEY]
        path_target_indices = input_tensors[reader.PATH_TARGET_INDICES_KEY]
        valid_mask = input_tensors[reader.VALID_CONTEXT_MASK_KEY]
        path_source_lengths = input_tensors[reader.PATH_SOURCE_LENGTHS_KEY]
        path_lengths = input_tensors[reader.PATH_LENGTHS_KEY]
        path_target_lengths = input_tensors[reader.PATH_TARGET_LENGTHS_KEY]

        with tf.variable_scope('model', reuse=self.get_should_reuse_variables()):
            subtoken_vocab = tf.get_variable('SUBTOKENS_VOCAB',
                                             shape=(self.subtoken_vocab_size, self.config.EMBEDDINGS_SIZE),
                                             dtype=tf.float32, trainable=False)
            target_words_vocab = tf.get_variable('TARGET_WORDS_VOCAB',
                                                 shape=(self.target_vocab_size, self.config.EMBEDDINGS_SIZE),
                                                 dtype=tf.float32, trainable=False)
            nodes_vocab = tf.get_variable('NODES_VOCAB',
                                          shape=(self.nodes_vocab_size, self.config.EMBEDDINGS_SIZE),
                                          dtype=tf.float32, trainable=False)

            batched_contexts = self.compute_contexts(subtoken_vocab=subtoken_vocab, nodes_vocab=nodes_vocab,
                                                     source_input=path_source_indices, nodes_input=node_indices,
                                                     target_input=path_target_indices, valid_mask=valid_mask,
                                                     path_source_lengths=path_source_lengths,
                                                     path_lengths=path_lengths,
                                                     path_target_lengths=path_target_lengths,
                                                     is_evaluating=True)

            outputs, final_states = self.decode_outputs(target_words_vocab=target_words_vocab,
                                                        target_input=target_index,
                                                        batch_size=tf.shape(target_index)[0],
                                                        batched_contexts=batched_contexts,
                                                        valid_mask=valid_mask, is_evaluating=True)

        if self.config.BEAM_WIDTH > 0:
            predicted_indices = outputs.predicted_ids
            topk_values = outputs.beam_search_decoder_output.scores
            attention_weights = [tf.no_op()]  # placeholder -- no alignment history with beam search
        else:
            predicted_indices = outputs.sample_id
            topk_values = tf.constant(1, shape=(1, 1), dtype=tf.float32)
            attention_weights = tf.squeeze(final_states.alignment_history.stack(), 1)

        return predicted_indices, topk_values, target_index, attention_weights

    def predict(self, predict_data_lines):
        """Predict names for pre-extracted context lines (one method each).

        Lazily builds a placeholder-fed prediction graph on the first call
        and restores weights from LOAD_PATH. Returns a list of
        (true_target, predicted_strings, top_scores, attention_per_path)."""
        if self.predict_queue is None:
            self.predict_queue = reader.Reader(subtoken_to_index=self.subtoken_to_index,
                                               node_to_index=self.node_to_index,
                                               target_to_index=self.target_to_index,
                                               config=self.config, is_evaluating=True)
            self.predict_placeholder = tf.placeholder(tf.string)
            reader_output = self.predict_queue.process_from_placeholder(self.predict_placeholder)
            # Add a batch dimension of 1 to every tensor.
            reader_output = {key: tf.expand_dims(tensor, 0) for key, tensor in reader_output.items()}
            self.predict_top_indices_op, self.predict_top_scores_op, _, self.attention_weights_op = \
                self.build_test_graph(reader_output)
            self.predict_source_string = reader_output[reader.PATH_SOURCE_STRINGS_KEY]
            self.predict_path_string = reader_output[reader.PATH_STRINGS_KEY]
            self.predict_path_target_string = reader_output[reader.PATH_TARGET_STRINGS_KEY]
            self.predict_target_strings_op = reader_output[reader.TARGET_STRING_KEY]

            self.initialize_session_variables(self.sess)
            self.saver = tf.train.Saver()
            self.load_model(self.sess)

        results = []
        for line in predict_data_lines:
            predicted_indices, top_scores, true_target_strings, attention_weights, path_source_string, path_strings, path_target_string = self.sess.run(
                [self.predict_top_indices_op, self.predict_top_scores_op, self.predict_target_strings_op,
                 self.attention_weights_op, self.predict_source_string, self.predict_path_string,
                 self.predict_path_target_string],
                feed_dict={self.predict_placeholder: line})
            # Drop the artificial batch dimension again.
            top_scores = np.squeeze(top_scores, axis=0)
            path_source_string = path_source_string.reshape((-1))
            path_strings = path_strings.reshape((-1))
            path_target_string = path_target_string.reshape((-1))
            predicted_indices = np.squeeze(predicted_indices, axis=0)
            true_target_strings = Common.binary_to_string(true_target_strings[0])

            if self.config.BEAM_WIDTH > 0:
                predicted_strings = [[self.index_to_target[sugg] for sugg in timestep]
                                     for timestep in predicted_indices]  # (target_length, top-k)
                predicted_strings = list(map(list, zip(*predicted_strings)))  # (top-k, target_length)
                top_scores = [np.exp(np.sum(s)) for s in zip(*top_scores)]
            else:
                predicted_strings = [self.index_to_target[idx]
                                     for idx in predicted_indices]  # (batch, target_length)

            attention_per_path = None
            if self.config.BEAM_WIDTH == 0:
                attention_per_path = self.get_attention_per_path(path_source_string, path_strings,
                                                                 path_target_string, attention_weights)
            results.append((true_target_strings, predicted_strings, top_scores, attention_per_path))
        return results

    @staticmethod
    def get_attention_per_path(source_strings, path_strings, target_strings, attention_weights):
        """For each decoding timestep, map (token1, path, token2) triplets to
        their attention weight. attention_weights: (time, contexts)."""
        results = []
        for time_step in attention_weights:
            attention_per_context = {}
            for source, path, target, weight in zip(source_strings, path_strings, target_strings, time_step):
                string_triplet = (
                    Common.binary_to_string(source), Common.binary_to_string(path), Common.binary_to_string(target))
                attention_per_context[string_triplet] = weight
            results.append(attention_per_context)
        return results

    def save_model(self, sess, path):
        """Save the TF checkpoint plus a companion '.dict' pickle holding the
        vocabularies and config (read back by load_model)."""
        save_target = path + '_iter%d' % self.epochs_trained
        dirname = os.path.dirname(save_target)
        if not os.path.exists(dirname):
            os.makedirs(dirname)
        self.saver.save(sess, save_target)

        dictionaries_path = save_target + '.dict'
        with open(dictionaries_path, 'wb') as file:
            # NOTE: dump order must match the load order in load_model.
            pickle.dump(self.subtoken_to_index, file)
            pickle.dump(self.index_to_subtoken, file)
            pickle.dump(self.subtoken_vocab_size, file)
            pickle.dump(self.target_to_index, file)
            pickle.dump(self.index_to_target, file)
            pickle.dump(self.target_vocab_size, file)
            pickle.dump(self.node_to_index, file)
            pickle.dump(self.index_to_node, file)
            pickle.dump(self.nodes_vocab_size, file)
            pickle.dump(self.num_training_examples, file)
            pickle.dump(self.epochs_trained, file)
            pickle.dump(self.config, file)
        print('Saved after %d epochs in: %s' % (self.epochs_trained, save_target))

    def load_model(self, sess):
        """Restore weights into `sess` (when given) and load the '.dict'
        vocabularies saved next to the checkpoint (skipped when already
        loaded)."""
        if not sess is None:
            self.saver.restore(sess, self.config.LOAD_PATH)
            print('Done loading model')
        with open(self.config.LOAD_PATH + '.dict', 'rb') as file:
            if self.subtoken_to_index is not None:
                return
            print('Loading dictionaries from: ' + self.config.LOAD_PATH)
            self.subtoken_to_index = pickle.load(file)
            self.index_to_subtoken = pickle.load(file)
            self.subtoken_vocab_size = pickle.load(file)
            self.target_to_index = pickle.load(file)
self.index_to_target = pickle.load(file) self.target_vocab_size = pickle.load(file) self.node_to_index = pickle.load(file) self.index_to_node = pickle.load(file) self.nodes_vocab_size = pickle.load(file) self.num_training_examples = pickle.load(file) self.epochs_trained = pickle.load(file) saved_config = pickle.load(file) self.config.take_model_hyperparams_from(saved_config) print('Done loading dictionaries') @staticmethod def initialize_session_variables(sess): sess.run(tf.group(tf.global_variables_initializer(), tf.local_variables_initializer(), tf.tables_initializer())) def get_should_reuse_variables(self): if self.config.TRAIN_PATH: return True else: return None ================================================ FILE: preprocess.py ================================================ import pickle from argparse import ArgumentParser import numpy as np import common ''' This script preprocesses the data from MethodPaths. It truncates methods with too many contexts, and pads methods with less paths with spaces. 
''' def save_dictionaries(dataset_name, subtoken_to_count, node_to_count, target_to_count, max_contexts, num_examples): save_dict_file_path = '{}.dict.c2s'.format(dataset_name) with open(save_dict_file_path, 'wb') as file: pickle.dump(subtoken_to_count, file) pickle.dump(node_to_count, file) pickle.dump(target_to_count, file) pickle.dump(max_contexts, file) pickle.dump(num_examples, file) print('Dictionaries saved to: {}'.format(save_dict_file_path)) def process_file(file_path, data_file_role, dataset_name, max_contexts, max_data_contexts): sum_total = 0 sum_sampled = 0 total = 0 max_unfiltered = 0 max_contexts_to_sample = max_data_contexts if data_file_role == 'train' else max_contexts output_path = '{}.{}.c2s'.format(dataset_name, data_file_role) with open(output_path, 'w') as outfile: with open(file_path, 'r') as file: for line in file: parts = line.rstrip('\n').split(' ') target_name = parts[0] contexts = parts[1:] if len(contexts) > max_unfiltered: max_unfiltered = len(contexts) sum_total += len(contexts) if len(contexts) > max_contexts_to_sample: contexts = np.random.choice(contexts, max_contexts_to_sample, replace=False) sum_sampled += len(contexts) csv_padding = " " * (max_data_contexts - len(contexts)) total += 1 outfile.write(target_name + ' ' + " ".join(contexts) + csv_padding + '\n') print('File: ' + file_path) print('Average total contexts: ' + str(float(sum_total) / total)) print('Average final (after sampling) contexts: ' + str(float(sum_sampled) / total)) print('Total examples: ' + str(total)) print('Max number of contexts per word: ' + str(max_unfiltered)) return total def context_full_found(context_parts, word_to_count, path_to_count): return context_parts[0] in word_to_count \ and context_parts[1] in path_to_count and context_parts[2] in word_to_count def context_partial_found(context_parts, word_to_count, path_to_count): return context_parts[0] in word_to_count \ or context_parts[1] in path_to_count or context_parts[2] in word_to_count if 
__name__ == '__main__':
    # Command-line entry point: read the raw extractor output for each split,
    # truncate/pad contexts, and save the vocabulary histograms.
    parser = ArgumentParser()
    parser.add_argument("-trd", "--train_data", dest="train_data_path",
                        help="path to training data file", required=True)
    parser.add_argument("-ted", "--test_data", dest="test_data_path",
                        help="path to test data file", required=True)
    parser.add_argument("-vd", "--val_data", dest="val_data_path",
                        help="path to validation data file", required=True)
    parser.add_argument("-mc", "--max_contexts", dest="max_contexts", default=200,
                        help="number of max contexts to keep in test+validation", required=False)
    parser.add_argument("-mdc", "--max_data_contexts", dest="max_data_contexts", default=1000,
                        help="number of max contexts to keep in the dataset", required=False)
    parser.add_argument("-svs", "--subtoken_vocab_size", dest="subtoken_vocab_size", default=186277,
                        help="Max number of source subtokens to keep in the vocabulary", required=False)
    parser.add_argument("-tvs", "--target_vocab_size", dest="target_vocab_size", default=26347,
                        help="Max number of target words to keep in the vocabulary", required=False)
    parser.add_argument("-sh", "--subtoken_histogram", dest="subtoken_histogram",
                        help="subtoken histogram file", metavar="FILE", required=True)
    parser.add_argument("-nh", "--node_histogram", dest="node_histogram",
                        help="node_histogram file", metavar="FILE", required=True)
    parser.add_argument("-th", "--target_histogram", dest="target_histogram",
                        help="target histogram file", metavar="FILE", required=True)
    parser.add_argument("-o", "--output_name", dest="output_name",
                        help="output name - the base name for the created dataset",
                        required=True, default='data')
    args = parser.parse_args()

    train_data_path = args.train_data_path
    test_data_path = args.test_data_path
    val_data_path = args.val_data_path
    subtoken_histogram_path = args.subtoken_histogram
    node_histogram_path = args.node_histogram

    # The node vocabulary is small, so it is kept in full (max_size=None).
    subtoken_to_count = common.Common.load_histogram(subtoken_histogram_path,
                                                     max_size=int(args.subtoken_vocab_size))
    node_to_count = common.Common.load_histogram(node_histogram_path,
                                                 max_size=None)
    target_to_count = common.Common.load_histogram(args.target_histogram,
                                                   max_size=int(args.target_vocab_size))
    print('subtoken vocab size: ', len(subtoken_to_count))
    print('node vocab size: ', len(node_to_count))
    print('target vocab size: ', len(target_to_count))

    num_training_examples = 0
    for data_file_path, data_role in zip([test_data_path, val_data_path, train_data_path],
                                         ['test', 'val', 'train']):
        num_examples = process_file(file_path=data_file_path, data_file_role=data_role,
                                    dataset_name=args.output_name, max_contexts=int(args.max_contexts),
                                    max_data_contexts=int(args.max_data_contexts))
        # Only the training split's example count is persisted in the dictionaries.
        if data_role == 'train':
            num_training_examples = num_examples

    save_dictionaries(dataset_name=args.output_name, subtoken_to_count=subtoken_to_count,
                      node_to_count=node_to_count, target_to_count=target_to_count,
                      max_contexts=int(args.max_data_contexts), num_examples=num_training_examples)


================================================
FILE: preprocess.sh
================================================
#!/usr/bin/env bash
###########################################################
# Change the following values to preprocess a new dataset.
# TRAIN_DIR, VAL_DIR and TEST_DIR should be paths to
# directories containing sub-directories with .java files
# DATASET_NAME is just a name for the currently extracted
# dataset.
# MAX_DATA_CONTEXTS is the number of contexts to keep in the dataset for each
# method (by default 1000). At training time, these contexts
# will be downsampled dynamically to MAX_CONTEXTS.
# MAX_CONTEXTS - the number of actual contexts (by default 200)
# that are taken into consideration (out of MAX_DATA_CONTEXTS)
# every training iteration. To avoid randomness at test time,
# for the test and validation sets only MAX_CONTEXTS contexts are kept
# (while for training, MAX_DATA_CONTEXTS are kept and MAX_CONTEXTS are
# selected dynamically during training).
# SUBTOKEN_VOCAB_SIZE / TARGET_VOCAB_SIZE cap the vocabulary at the most
# frequent source subtokens and target words.
# NUM_THREADS sets extractor parallelism; on a multi-core machine, use the
# number of cores.
# PYTHON is the python3 interpreter alias.
TRAIN_DIR=my_training_dir
VAL_DIR=my_val_dir
TEST_DIR=my_test_dir
DATASET_NAME=my_dataset
MAX_DATA_CONTEXTS=1000
MAX_CONTEXTS=200
SUBTOKEN_VOCAB_SIZE=186277
TARGET_VOCAB_SIZE=26347
NUM_THREADS=64
PYTHON=python3
###########################################################

# Raw (pre-truncation) extractor output, one file per split.
TRAIN_DATA_FILE=${DATASET_NAME}.train.raw.txt
VAL_DATA_FILE=${DATASET_NAME}.val.raw.txt
TEST_DATA_FILE=${DATASET_NAME}.test.raw.txt
EXTRACTOR_JAR=JavaExtractor/JPredict/target/JavaExtractor-0.0.1-SNAPSHOT.jar

# -p creates the parent 'data' directory in the same call.
mkdir -p data/${DATASET_NAME}

echo "Extracting paths from validation set..."
${PYTHON} JavaExtractor/extract.py --dir ${VAL_DIR} --max_path_length 8 --max_path_width 2 \
    --num_threads ${NUM_THREADS} --jar ${EXTRACTOR_JAR} > ${VAL_DATA_FILE} 2>> error_log.txt
echo "Finished extracting paths from validation set"

echo "Extracting paths from test set..."
${PYTHON} JavaExtractor/extract.py --dir ${TEST_DIR} --max_path_length 8 --max_path_width 2 \
    --num_threads ${NUM_THREADS} --jar ${EXTRACTOR_JAR} > ${TEST_DATA_FILE} 2>> error_log.txt
echo "Finished extracting paths from test set"

echo "Extracting paths from training set..."
# Train split is piped through shuf so examples are randomly ordered on disk.
# NOTE(review): on this line '2>> error_log.txt' captures shuf's stderr, not the
# extractor's (unlike the val/test lines) — confirm this is intentional.
${PYTHON} JavaExtractor/extract.py --dir ${TRAIN_DIR} --max_path_length 8 --max_path_width 2 --num_threads ${NUM_THREADS} --jar ${EXTRACTOR_JAR} | shuf > ${TRAIN_DATA_FILE} 2>> error_log.txt
echo "Finished extracting paths from training set"

TARGET_HISTOGRAM_FILE=data/${DATASET_NAME}/${DATASET_NAME}.histo.tgt.c2s
SOURCE_SUBTOKEN_HISTOGRAM=data/${DATASET_NAME}/${DATASET_NAME}.histo.ori.c2s
NODE_HISTOGRAM_FILE=data/${DATASET_NAME}/${DATASET_NAME}.histo.node.c2s

echo "Creating histograms from the training data"
# Field 1 of each line is the target name; count its '|'-separated subtokens.
cat ${TRAIN_DATA_FILE} | cut -d' ' -f1 | tr '|' '\n' | awk '{n[$0]++} END {for (i in n) print i,n[i]}' > ${TARGET_HISTOGRAM_FILE}
# Fields 2+ are 'source,path,target' triples; count terminal subtokens (fields 1,3)...
cat ${TRAIN_DATA_FILE} | cut -d' ' -f2- | tr ' ' '\n' | cut -d',' -f1,3 | tr ',|' '\n' | awk '{n[$0]++} END {for (i in n) print i,n[i]}' > ${SOURCE_SUBTOKEN_HISTOGRAM}
# ...and AST node labels (field 2).
cat ${TRAIN_DATA_FILE} | cut -d' ' -f2- | tr ' ' '\n' | cut -d',' -f2 | tr '|' '\n' | awk '{n[$0]++} END {for (i in n) print i,n[i]}' > ${NODE_HISTOGRAM_FILE}

${PYTHON} preprocess.py --train_data ${TRAIN_DATA_FILE} --test_data ${TEST_DATA_FILE} --val_data ${VAL_DATA_FILE} \
  --max_contexts ${MAX_CONTEXTS} --max_data_contexts ${MAX_DATA_CONTEXTS} --subtoken_vocab_size ${SUBTOKEN_VOCAB_SIZE} \
  --target_vocab_size ${TARGET_VOCAB_SIZE} --subtoken_histogram ${SOURCE_SUBTOKEN_HISTOGRAM} \
  --node_histogram ${NODE_HISTOGRAM_FILE} --target_histogram ${TARGET_HISTOGRAM_FILE} --output_name data/${DATASET_NAME}/${DATASET_NAME}

# If all went well, the raw data files can be deleted, because preprocess.py creates new files
# with truncated and padded number of paths for each example.
rm ${TRAIN_DATA_FILE} ${VAL_DATA_FILE} ${TEST_DATA_FILE} ${TARGET_HISTOGRAM_FILE} ${SOURCE_SUBTOKEN_HISTOGRAM} \
  ${NODE_HISTOGRAM_FILE}


================================================
FILE: preprocess_csharp.sh
================================================
#!/usr/bin/env bash
###########################################################
# Change the following values to preprocess a new dataset.
# TRAIN_DIR, VAL_DIR and TEST_DIR point at directories whose sub-directories
# contain the source files to extract (the demo values below reuse this
# repository's own sources).
# DATASET_NAME names the extracted dataset.
# MAX_DATA_CONTEXTS (default 1000) is how many contexts are stored per method;
# at training time they are dynamically downsampled to MAX_CONTEXTS.
# MAX_CONTEXTS (default 200) is how many of those contexts are actually used
# each training iteration. To keep evaluation deterministic, the test and
# validation splits store only MAX_CONTEXTS contexts in the first place.
# SUBTOKEN_VOCAB_SIZE / TARGET_VOCAB_SIZE cap the vocabulary at the most
# frequent source subtokens and target words.
# NUM_THREADS sets extractor parallelism — ideally the machine's core count.
# PYTHON is the python3 interpreter alias.
TRAIN_DIR=JavaExtractor/JPredict/src/main/java/JavaExtractor/Common
VAL_DIR=JavaExtractor/JPredict/src/main/java/JavaExtractor/Common
TEST_DIR=JavaExtractor/JPredict/src/main/java/JavaExtractor/Common
DATASET_NAME=my_dataset
MAX_DATA_CONTEXTS=1000
MAX_CONTEXTS=200
SUBTOKEN_VOCAB_SIZE=186277
TARGET_VOCAB_SIZE=26347
NUM_THREADS=64
PYTHON=python3
###########################################################

# Raw extractor output, one file per split.
TRAIN_DATA_FILE=${DATASET_NAME}.train.raw.txt
VAL_DATA_FILE=${DATASET_NAME}.val.raw.txt
TEST_DATA_FILE=${DATASET_NAME}.test.raw.txt
EXTRACTOR_JAR=CSharpExtractor/CSharpExtractor/Extractor/Extractor.csproj

# -p creates the parent 'data' directory in the same call.
mkdir -p data/${DATASET_NAME}

echo "Extracting paths from validation set..."
# Unlike the Java extractor, the C# extractor writes its output file itself (--ofile_name).
${PYTHON} CSharpExtractor/extract.py --dir ${VAL_DIR} --max_path_length 8 --max_path_width 2 --num_threads ${NUM_THREADS} --csproj ${EXTRACTOR_JAR} --ofile_name ${VAL_DATA_FILE} 2>> error_log.txt
echo "Finished extracting paths from validation set"
echo "Extracting paths from test set..."
${PYTHON} CSharpExtractor/extract.py --dir ${TEST_DIR} --max_path_length 8 --max_path_width 2 --num_threads ${NUM_THREADS} --csproj ${EXTRACTOR_JAR} --ofile_name ${TEST_DATA_FILE} 2>> error_log.txt
echo "Finished extracting paths from test set"
echo "Extracting paths from training set..."
${PYTHON} CSharpExtractor/extract.py --dir ${TRAIN_DIR} --max_path_length 8 --max_path_width 2 --num_threads ${NUM_THREADS} --csproj ${EXTRACTOR_JAR} --ofile_name ${TRAIN_DATA_FILE}_unshuf 2>> error_log.txt
echo "Finished extracting paths from training set"

# Randomize the on-disk order of the training examples.
echo "Shuffling training data"
cat ${TRAIN_DATA_FILE}_unshuf | shuf > ${TRAIN_DATA_FILE}
rm ${TRAIN_DATA_FILE}_unshuf

TARGET_HISTOGRAM_FILE=data/${DATASET_NAME}/${DATASET_NAME}.histo.tgt.c2s
SOURCE_SUBTOKEN_HISTOGRAM=data/${DATASET_NAME}/${DATASET_NAME}.histo.ori.c2s
NODE_HISTOGRAM_FILE=data/${DATASET_NAME}/${DATASET_NAME}.histo.node.c2s

echo "Creating histograms from the training data"
# Target-name subtokens (field 1), context terminal subtokens (fields 1,3 of each
# comma-separated triple) and AST node labels (field 2), counted with awk.
cat ${TRAIN_DATA_FILE} | cut -d' ' -f1 | tr '|' '\n' | awk '{n[$0]++} END {for (i in n) print i,n[i]}' > ${TARGET_HISTOGRAM_FILE}
cat ${TRAIN_DATA_FILE} | cut -d' ' -f2- | tr ' ' '\n' | cut -d',' -f1,3 | tr ',|' '\n' | awk '{n[$0]++} END {for (i in n) print i,n[i]}' > ${SOURCE_SUBTOKEN_HISTOGRAM}
cat ${TRAIN_DATA_FILE} | cut -d' ' -f2- | tr ' ' '\n' | cut -d',' -f2 | tr '|' '\n' | awk '{n[$0]++} END {for (i in n) print i,n[i]}' > ${NODE_HISTOGRAM_FILE}

${PYTHON} preprocess.py --train_data ${TRAIN_DATA_FILE} --test_data ${TEST_DATA_FILE} --val_data ${VAL_DATA_FILE} \
  --max_contexts ${MAX_CONTEXTS} --max_data_contexts ${MAX_DATA_CONTEXTS} --subtoken_vocab_size ${SUBTOKEN_VOCAB_SIZE} \
  --target_vocab_size ${TARGET_VOCAB_SIZE} --subtoken_histogram ${SOURCE_SUBTOKEN_HISTOGRAM} \
  --node_histogram ${NODE_HISTOGRAM_FILE} --target_histogram ${TARGET_HISTOGRAM_FILE} --output_name data/${DATASET_NAME}/${DATASET_NAME}

# If all went well, the raw data files can be deleted, because preprocess.py creates new files
# with truncated and padded number of paths for each example.
rm ${TRAIN_DATA_FILE} ${VAL_DATA_FILE} ${TEST_DATA_FILE} ${TARGET_HISTOGRAM_FILE} ${SOURCE_SUBTOKEN_HISTOGRAM} \
  ${NODE_HISTOGRAM_FILE}


================================================
FILE: reader.py
================================================
import os

import tensorflow as tf

from common import Common

# Keys of the tensor dictionary produced by Reader.process_dataset.
TARGET_INDEX_KEY = 'TARGET_INDEX_KEY'
TARGET_STRING_KEY = 'TARGET_STRING_KEY'
TARGET_LENGTH_KEY = 'TARGET_LENGTH_KEY'
PATH_SOURCE_INDICES_KEY = 'PATH_SOURCE_INDICES_KEY'
NODE_INDICES_KEY = 'NODES_INDICES_KEY'  # NOTE(review): value spelled 'NODES_...', unlike the variable name
PATH_TARGET_INDICES_KEY = 'PATH_TARGET_INDICES_KEY'
VALID_CONTEXT_MASK_KEY = 'VALID_CONTEXT_MASK_KEY'
PATH_SOURCE_LENGTHS_KEY = 'PATH_SOURCE_LENGTHS_KEY'
PATH_LENGTHS_KEY = 'PATH_LENGTHS_KEY'
PATH_TARGET_LENGTHS_KEY = 'PATH_TARGET_LENGTHS_KEY'
PATH_SOURCE_STRINGS_KEY = 'PATH_SOURCE_STRINGS_KEY'
PATH_STRINGS_KEY = 'PATH_STRINGS_KEY'
PATH_TARGET_STRINGS_KEY = 'PATH_TARGET_STRINGS_KEY'


class Reader:
    # Lookup tables are cached at class level so train and eval readers share them.
    class_subtoken_table = None
    class_target_table = None
    class_node_table = None

    def __init__(self, subtoken_to_index, target_to_index, node_to_index, config, is_evaluating=False):
        self.config = config
        self.file_path = config.TEST_PATH if is_evaluating else (config.TRAIN_PATH + '.train.c2s')
        if self.file_path is not None and not os.path.exists(self.file_path):
            print(
                '%s cannot find file: %s' % ('Evaluation reader' if is_evaluating else 'Train reader', self.file_path))
        self.batch_size = config.TEST_BATCH_SIZE if is_evaluating else config.BATCH_SIZE
        self.is_evaluating = is_evaluating

        # A fully-padded 'PAD,PAD,PAD' context, also used as the CSV column default.
        self.context_pad = '{},{},{}'.format(Common.PAD, Common.PAD, Common.PAD)
        self.record_defaults = [[self.context_pad]] * (self.config.DATA_NUM_CONTEXTS + 1)
        self.subtoken_table = Reader.get_subtoken_table(subtoken_to_index)
        self.target_table = Reader.get_target_table(target_to_index)
        self.node_table = Reader.get_node_table(node_to_index)
        # Placeholder-driven readers (prediction) have no file and build no pipeline.
        if self.file_path is not None:
            self.output_tensors = self.compute_output()

    @classmethod
    def get_subtoken_table(cls, subtoken_to_index):
        """Build (once) and return the shared subtoken -> index lookup table."""
        if cls.class_subtoken_table is None:
            cls.class_subtoken_table = cls.initialize_hash_map(subtoken_to_index,
                                                               subtoken_to_index[Common.UNK])
        return cls.class_subtoken_table

    @classmethod
    def get_target_table(cls, target_to_index):
        """Build (once) and return the shared target-word -> index lookup table."""
        if cls.class_target_table is None:
            cls.class_target_table = cls.initialize_hash_map(target_to_index,
                                                             target_to_index[Common.UNK])
        return cls.class_target_table

    @classmethod
    def get_node_table(cls, node_to_index):
        """Build (once) and return the shared AST-node -> index lookup table."""
        if cls.class_node_table is None:
            cls.class_node_table = cls.initialize_hash_map(node_to_index, node_to_index[Common.UNK])
        return cls.class_node_table

    @classmethod
    def initialize_hash_map(cls, word_to_index, default_value):
        """Create a string->int32 HashTable; unknown keys map to default_value (UNK)."""
        return tf.contrib.lookup.HashTable(
            tf.contrib.lookup.KeyValueTensorInitializer(list(word_to_index.keys()), list(word_to_index.values()),
                                                        key_dtype=tf.string,
                                                        value_dtype=tf.int32), default_value)

    def process_from_placeholder(self, row):
        """Parse one raw input line (fed through a placeholder) into model tensors."""
        parts = tf.io.decode_csv(row, record_defaults=self.record_defaults, field_delim=' ',
                                 use_quote_delim=False)
        return self.process_dataset(*parts)

    def process_dataset(self, *row_parts):
        """Turn one CSV row (target word + context columns) into the tensor dict the model consumes."""
        row_parts = list(row_parts)
        word = row_parts[0]  # (, )

        if not self.is_evaluating and self.config.RANDOM_CONTEXTS:
            # Sample MAX_CONTEXTS contexts uniformly from the valid (non-padding) ones.
            all_contexts = tf.stack(row_parts[1:])
            all_contexts_padded = tf.concat([all_contexts, [self.context_pad]], axis=-1)
            index_of_blank_context = tf.where(tf.equal(all_contexts_padded, self.context_pad))
            num_contexts_per_example = tf.reduce_min(index_of_blank_context)

            # if there are less than self.max_contexts valid contexts, still sample self.max_contexts
            safe_limit = tf.cast(tf.maximum(num_contexts_per_example, self.config.MAX_CONTEXTS), tf.int32)
            rand_indices = \
tf.random_shuffle(tf.range(safe_limit))[:self.config.MAX_CONTEXTS]
            contexts = tf.gather(all_contexts, rand_indices)  # (max_contexts,)
        else:
            # Deterministic path: simply take the first MAX_CONTEXTS columns.
            contexts = row_parts[1:(self.config.MAX_CONTEXTS + 1)]  # (max_contexts,)

        # contexts: (max_contexts, )
        # Each context is a 'source,path,target' comma-separated triple.
        split_contexts = tf.string_split(contexts, delimiter=',', skip_empty=False)
        sparse_split_contexts = tf.sparse.SparseTensor(indices=split_contexts.indices,
                                                       values=split_contexts.values,
                                                       dense_shape=[self.config.MAX_CONTEXTS, 3])
        dense_split_contexts = tf.reshape(
            tf.sparse.to_dense(sp_input=sparse_split_contexts, default_value=Common.PAD),
            shape=[self.config.MAX_CONTEXTS, 3])  # (batch, max_contexts, 3)

        # Split the target word into '|'-separated subtokens, clipped to MAX_TARGET_PARTS.
        split_target_labels = tf.string_split(tf.expand_dims(word, -1), delimiter='|')
        target_dense_shape = [1, tf.maximum(tf.to_int64(self.config.MAX_TARGET_PARTS),
                                            split_target_labels.dense_shape[1] + 1)]
        sparse_target_labels = tf.sparse.SparseTensor(indices=split_target_labels.indices,
                                                      values=split_target_labels.values,
                                                      dense_shape=target_dense_shape)
        dense_target_label = tf.reshape(tf.sparse.to_dense(sp_input=sparse_target_labels,
                                                           default_value=Common.PAD), [-1])
        index_of_blank = tf.where(tf.equal(dense_target_label, Common.PAD))
        target_length = tf.reduce_min(index_of_blank)
        dense_target_label = dense_target_label[:self.config.MAX_TARGET_PARTS]
        clipped_target_lengths = tf.clip_by_value(target_length, clip_value_min=0,
                                                  clip_value_max=self.config.MAX_TARGET_PARTS)
        target_word_labels = tf.concat([
            self.target_table.lookup(dense_target_label), [0]], axis=-1)  # (max_target_parts + 1) of int

        # Column 0: the source terminal's '|'-separated subtokens.
        path_source_strings = tf.slice(dense_split_contexts, [0, 0],
                                       [self.config.MAX_CONTEXTS, 1])  # (max_contexts, 1)
        flat_source_strings = tf.reshape(path_source_strings, [-1])  # (max_contexts)
        split_source = tf.string_split(flat_source_strings, delimiter='|',
                                       skip_empty=False)  # (max_contexts, max_name_parts)
        sparse_split_source = tf.sparse.SparseTensor(indices=split_source.indices, values=split_source.values,
                                                     dense_shape=[self.config.MAX_CONTEXTS,
                                                                  tf.maximum(tf.to_int64(self.config.MAX_NAME_PARTS),
                                                                             split_source.dense_shape[1])])
        dense_split_source = tf.sparse.to_dense(sp_input=sparse_split_source,
                                                default_value=Common.PAD)  # (max_contexts, max_name_parts)
        dense_split_source = tf.slice(dense_split_source, [0, 0], [-1, self.config.MAX_NAME_PARTS])
        path_source_indices = self.subtoken_table.lookup(dense_split_source)  # (max_contexts, max_name_parts)
        path_source_lengths = tf.reduce_sum(tf.cast(tf.not_equal(dense_split_source, Common.PAD), tf.int32),
                                            -1)  # (max_contexts)

        # Column 1: the AST path's '|'-separated node labels.
        path_strings = tf.slice(dense_split_contexts, [0, 1], [self.config.MAX_CONTEXTS, 1])
        flat_path_strings = tf.reshape(path_strings, [-1])
        split_path = tf.string_split(flat_path_strings, delimiter='|', skip_empty=False)
        sparse_split_path = tf.sparse.SparseTensor(indices=split_path.indices, values=split_path.values,
                                                   dense_shape=[self.config.MAX_CONTEXTS,
                                                                self.config.MAX_PATH_LENGTH])
        dense_split_path = tf.sparse.to_dense(sp_input=sparse_split_path,
                                              default_value=Common.PAD)  # (batch, max_contexts, max_path_length)
        node_indices = self.node_table.lookup(dense_split_path)  # (max_contexts, max_path_length)
        path_lengths = tf.reduce_sum(tf.cast(tf.not_equal(dense_split_path, Common.PAD), tf.int32),
                                     -1)  # (max_contexts)

        # Column 2: the target terminal's '|'-separated subtokens.
        path_target_strings = tf.slice(dense_split_contexts, [0, 2],
                                       [self.config.MAX_CONTEXTS, 1])  # (max_contexts, 1)
        flat_target_strings = tf.reshape(path_target_strings, [-1])  # (max_contexts)
        split_target = tf.string_split(flat_target_strings, delimiter='|',
                                       skip_empty=False)  # (max_contexts, max_name_parts)
        sparse_split_target = tf.sparse.SparseTensor(indices=split_target.indices, values=split_target.values,
                                                     dense_shape=[self.config.MAX_CONTEXTS,
                                                                  tf.maximum(tf.to_int64(self.config.MAX_NAME_PARTS),
                                                                             split_target.dense_shape[1])])
        dense_split_target = tf.sparse.to_dense(sp_input=sparse_split_target,
                                                default_value=Common.PAD)  # (max_contexts, max_name_parts)
        dense_split_target = tf.slice(dense_split_target, [0, 0], [-1, self.config.MAX_NAME_PARTS])
        path_target_indices = self.subtoken_table.lookup(dense_split_target)  # (max_contexts, max_name_parts)
        path_target_lengths = tf.reduce_sum(tf.cast(tf.not_equal(dense_split_target, Common.PAD), tf.int32),
                                            -1)  # (max_contexts)

        # A context is valid iff at least one of its three parts maps to a non-zero index.
        valid_contexts_mask = tf.to_float(tf.not_equal(
            tf.reduce_max(path_source_indices, -1) + tf.reduce_max(node_indices, -1) + tf.reduce_max(
                path_target_indices, -1), 0))

        return {TARGET_STRING_KEY: word, TARGET_INDEX_KEY: target_word_labels,
                TARGET_LENGTH_KEY: clipped_target_lengths, PATH_SOURCE_INDICES_KEY: path_source_indices,
                NODE_INDICES_KEY: node_indices, PATH_TARGET_INDICES_KEY: path_target_indices,
                VALID_CONTEXT_MASK_KEY: valid_contexts_mask, PATH_SOURCE_LENGTHS_KEY: path_source_lengths,
                PATH_LENGTHS_KEY: path_lengths, PATH_TARGET_LENGTHS_KEY: path_target_lengths,
                PATH_SOURCE_STRINGS_KEY: path_source_strings, PATH_STRINGS_KEY: path_strings,
                PATH_TARGET_STRINGS_KEY: path_target_strings
                }

    def reset(self, sess):
        """Re-initialize the dataset iterator (start a new pass over the data)."""
        sess.run(self.reset_op)

    def get_output(self):
        """Return the next-element tensor dict built by compute_output."""
        return self.output_tensors

    def compute_output(self):
        """Build the tf.data input pipeline and return its next-element tensors."""
        dataset = tf.data.experimental.CsvDataset(self.file_path, record_defaults=self.record_defaults,
                                                  field_delim=' ', use_quote_delim=False,
                                                  buffer_size=self.config.CSV_BUFFER_SIZE)
        if not self.is_evaluating:
            if self.config.SAVE_EVERY_EPOCHS > 1:
                # Repeat so one "iteration" spans SAVE_EVERY_EPOCHS epochs.
                dataset = dataset.repeat(self.config.SAVE_EVERY_EPOCHS)
            dataset = dataset.shuffle(self.config.SHUFFLE_BUFFER_SIZE, reshuffle_each_iteration=True)
        dataset = dataset.apply(tf.data.experimental.map_and_batch(
            map_func=self.process_dataset, batch_size=self.batch_size,
            num_parallel_batches=self.config.READER_NUM_PARALLEL_BATCHES))
        dataset = dataset.prefetch(tf.contrib.data.AUTOTUNE)
        self.iterator = dataset.make_initializable_iterator()
        self.reset_op = self.iterator.initializer
        return self.iterator.get_next()


if __name__ == '__main__':
    # Smoke test: run the reader over a tiny fixture and print every output tensor.
    target_word_to_index = {Common.PAD: 0, Common.UNK: 1, Common.SOS: 2, 'a': 3, 'b': 4, 'c': 5, 'd': 6, 't': 7}
    subtoken_to_index = {Common.PAD: 0, Common.UNK: 1, 'a': 2, 'b': 3,
                         'c': 4, 'd': 5}
    node_to_index = {Common.PAD: 0, Common.UNK: 1, '1': 2, '2': 3, '3': 4, '4': 5}
    import numpy as np

    class Config:
        # Minimal stand-in for the real training config object.
        def __init__(self):
            self.SAVE_EVERY_EPOCHS = 1
            self.TRAIN_PATH = self.TEST_PATH = 'test_input/test_input'
            self.BATCH_SIZE = 2
            self.TEST_BATCH_SIZE = self.BATCH_SIZE
            self.READER_NUM_PARALLEL_BATCHES = 1
            self.READING_BATCH_SIZE = 2
            self.SHUFFLE_BUFFER_SIZE = 100
            self.MAX_CONTEXTS = 4
            self.DATA_NUM_CONTEXTS = 4
            self.MAX_PATH_LENGTH = 3
            self.MAX_NAME_PARTS = 2
            self.MAX_TARGET_PARTS = 4
            self.RANDOM_CONTEXTS = True
            self.CSV_BUFFER_SIZE = None

    config = Config()
    reader = Reader(subtoken_to_index, target_word_to_index, node_to_index, config, False)
    output = reader.get_output()
    target_index_op = output[TARGET_INDEX_KEY]
    target_string_op = output[TARGET_STRING_KEY]
    target_length_op = output[TARGET_LENGTH_KEY]
    path_source_indices_op = output[PATH_SOURCE_INDICES_KEY]
    node_indices_op = output[NODE_INDICES_KEY]
    path_target_indices_op = output[PATH_TARGET_INDICES_KEY]
    valid_context_mask_op = output[VALID_CONTEXT_MASK_KEY]
    path_source_lengths_op = output[PATH_SOURCE_LENGTHS_KEY]
    path_lengths_op = output[PATH_LENGTHS_KEY]
    path_target_lengths_op = output[PATH_TARGET_LENGTHS_KEY]
    path_source_strings_op = output[PATH_SOURCE_STRINGS_KEY]
    path_strings_op = output[PATH_STRINGS_KEY]
    path_target_strings_op = output[PATH_TARGET_STRINGS_KEY]

    sess = tf.InteractiveSession()
    tf.group(tf.global_variables_initializer(), tf.local_variables_initializer(), tf.tables_initializer()).run()
    reader.reset(sess)
    try:
        while True:
            target_indices, target_strings, target_lengths, path_source_indices, \
            node_indices, path_target_indices, valid_context_mask, path_source_lengths, \
            path_lengths, path_target_lengths, path_source_strings, path_strings, \
            path_target_strings = sess.run(
                [target_index_op, target_string_op, target_length_op, path_source_indices_op,
                 node_indices_op, path_target_indices_op, valid_context_mask_op, path_source_lengths_op,
                 path_lengths_op, path_target_lengths_op,
                 path_source_strings_op, path_strings_op, path_target_strings_op])

            print('Target strings: ', Common.binary_to_string_list(target_strings))
            print('Context strings: ', Common.binary_to_string_3d(
                np.concatenate([path_source_strings, path_strings, path_target_strings], -1)))
            print('Target indices: ', target_indices)
            print('Target lengths: ', target_lengths)
            print('Path source strings: ', Common.binary_to_string_3d(path_source_strings))
            print('Path source indices: ', path_source_indices)
            print('Path source lengths: ', path_source_lengths)
            print('Path strings: ', Common.binary_to_string_3d(path_strings))
            print('Node indices: ', node_indices)
            print('Path lengths: ', path_lengths)
            print('Path target strings: ', Common.binary_to_string_3d(path_target_strings))
            print('Path target indices: ', path_target_indices)
            print('Path target lengths: ', path_target_lengths)
            print('Valid context mask: ', valid_context_mask)
    except tf.errors.OutOfRangeError:
        print('Done training, epoch reached')


================================================
FILE: train.sh
================================================
###########################################################
# Change the following values to train a new model.
# type: the name of the new model, only affects the saved file name.
# dataset: the name of the dataset, as was preprocessed using preprocess.sh
# test_data: by default, points to the validation set, since this is the set that
# will be evaluated after each training iteration. If you wish to test
# on the final (held-out) test set, change 'val' to 'test'.
type=java-large-model
dataset_name=java-large
data_dir=data/java-large
data=${data_dir}/${dataset_name}
test_data=${data_dir}/${dataset_name}.val.c2s
model_dir=models/${type}

mkdir -p ${model_dir}
set -e
# -u: unbuffered output so training logs stream in real time.
python3 -u code2seq.py --data ${data} --test ${test_data} --save_prefix ${model_dir}/model


================================================
FILE: train_python150k.sh
================================================
#!/usr/bin/env bash
# Usage: train_python150k.sh <data_dir> <run_name> [cuda_device] [seed]

data_dir=$1
data_name=$(basename "${data_dir}")
data=${data_dir}/${data_name}
test=${data_dir}/${data_name}.val.c2s

run_name=$2
model_dir=models/python150k-${run_name}
save_prefix=${model_dir}/model

# Optional arguments with defaults: GPU index 0, random seed 239.
cuda=${3:-0}
seed=${4:-239}

mkdir -p "${model_dir}"
set -e

CUDA_VISIBLE_DEVICES=$cuda python -u code2seq.py \
  --data="${data}" \
  --test="${test}" \
  --save_prefix="${save_prefix}" \
  --seed="${seed}"